fix:优化聊天流信息的展示和检索,优化chat_prompt无效的问题,优化部分群定义问题
This commit is contained in:
@@ -11,8 +11,7 @@ from src.chat.message_receive.message import SessionMessage
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import Expression
|
||||
from src.common.logger import get_logger
|
||||
from src.common.utils.utils_config import ExpressionConfigUtils
|
||||
from src.common.utils.utils_session import SessionUtils
|
||||
from src.common.utils.utils_config import ChatConfigUtils, ExpressionConfigUtils
|
||||
from src.config.config import global_config
|
||||
from src.learners.learner_utils_old import weighted_sample
|
||||
from src.maisaka.context_messages import LLMContextMessage
|
||||
@@ -65,14 +64,9 @@ class MaisakaExpressionSelector:
|
||||
if not platform or not item_id:
|
||||
continue
|
||||
|
||||
rule_type = target_item.rule_type
|
||||
target_session_id = SessionUtils.calculate_session_id(
|
||||
platform,
|
||||
group_id=item_id if rule_type == "group" else None,
|
||||
user_id=None if rule_type == "group" else item_id,
|
||||
)
|
||||
group_session_ids.add(target_session_id)
|
||||
if target_session_id == session_id:
|
||||
target_session_ids = ChatConfigUtils.get_target_session_ids(target_item)
|
||||
group_session_ids.update(target_session_ids)
|
||||
if ChatConfigUtils.target_matches_session(target_item, session_id):
|
||||
contains_current_session = True
|
||||
|
||||
if contains_global_share_marker:
|
||||
|
||||
@@ -30,7 +30,7 @@ from src.common.data_models.message_component_data_model import (
|
||||
VoiceComponent,
|
||||
)
|
||||
from src.common.logger import get_logger
|
||||
from src.common.utils.utils_session import SessionUtils
|
||||
from src.common.utils.utils_config import ChatConfigUtils
|
||||
from src.config.config import global_config
|
||||
from src.config.model_configs import ModelInfo
|
||||
from src.core.types import ActionInfo
|
||||
@@ -211,46 +211,7 @@ class BaseMaisakaReplyGenerator:
|
||||
@staticmethod
|
||||
def _get_chat_prompt_for_chat(chat_id: str, is_group_chat: Optional[bool]) -> str:
|
||||
"""根据聊天流 ID 获取匹配的额外 prompt。"""
|
||||
if not global_config.chat.chat_prompts:
|
||||
return ""
|
||||
|
||||
for chat_prompt_item in global_config.chat.chat_prompts:
|
||||
if hasattr(chat_prompt_item, "platform"):
|
||||
platform = str(chat_prompt_item.platform or "").strip()
|
||||
item_id = str(chat_prompt_item.item_id or "").strip()
|
||||
rule_type = str(chat_prompt_item.rule_type or "").strip()
|
||||
prompt_content = str(chat_prompt_item.prompt or "").strip()
|
||||
elif isinstance(chat_prompt_item, str):
|
||||
parts = chat_prompt_item.split(":", 3)
|
||||
if len(parts) != 4:
|
||||
continue
|
||||
|
||||
platform, item_id, rule_type, prompt_content = parts
|
||||
platform = platform.strip()
|
||||
item_id = item_id.strip()
|
||||
rule_type = rule_type.strip()
|
||||
prompt_content = prompt_content.strip()
|
||||
else:
|
||||
continue
|
||||
|
||||
if not platform or not item_id or not prompt_content:
|
||||
continue
|
||||
|
||||
if rule_type == "group":
|
||||
config_is_group = True
|
||||
config_chat_id = SessionUtils.calculate_session_id(platform, group_id=item_id)
|
||||
elif rule_type == "private":
|
||||
config_is_group = False
|
||||
config_chat_id = SessionUtils.calculate_session_id(platform, user_id=item_id)
|
||||
else:
|
||||
continue
|
||||
|
||||
if config_is_group != is_group_chat:
|
||||
continue
|
||||
if config_chat_id == chat_id:
|
||||
return prompt_content
|
||||
|
||||
return ""
|
||||
return ChatConfigUtils.get_chat_prompt_for_chat(chat_id, is_group_chat)
|
||||
|
||||
def _build_group_chat_attention_block(self, session_id: str) -> str:
|
||||
"""构建当前聊天场景下的额外注意事项块。"""
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.common.utils.utils_config import ExpressionConfigUtils
|
||||
|
||||
logger = get_logger("common_utils")
|
||||
|
||||
@@ -11,29 +11,7 @@ class TempMethodsExpression:
|
||||
|
||||
@staticmethod
|
||||
def _find_expression_config_item(chat_stream_id: Optional[str] = None):
|
||||
if not global_config.expression.learning_list:
|
||||
return None
|
||||
|
||||
if chat_stream_id:
|
||||
for config_item in global_config.expression.learning_list:
|
||||
if not config_item.platform and not config_item.item_id:
|
||||
continue
|
||||
stream_id = TempMethodsExpression._get_stream_id(
|
||||
config_item.platform,
|
||||
str(config_item.item_id),
|
||||
(config_item.rule_type == "group"),
|
||||
)
|
||||
if stream_id is None:
|
||||
continue
|
||||
if stream_id != chat_stream_id:
|
||||
continue
|
||||
return config_item
|
||||
|
||||
for config_item in global_config.expression.learning_list:
|
||||
if not config_item.platform and not config_item.item_id:
|
||||
return config_item
|
||||
|
||||
return None
|
||||
return ExpressionConfigUtils._find_expression_config_item(chat_stream_id)
|
||||
|
||||
@staticmethod
|
||||
def get_expression_config_for_chat(chat_stream_id: Optional[str] = None) -> tuple[bool, bool, bool]:
|
||||
@@ -46,15 +24,7 @@ class TempMethodsExpression:
|
||||
Returns:
|
||||
tuple: (是否使用表达, 是否学习表达, 是否启用 jargon 学习)
|
||||
"""
|
||||
config_item = TempMethodsExpression._find_expression_config_item(chat_stream_id)
|
||||
if config_item is None:
|
||||
return True, True, True
|
||||
|
||||
return (
|
||||
config_item.use_expression,
|
||||
config_item.enable_learning,
|
||||
config_item.enable_jargon_learning,
|
||||
)
|
||||
return ExpressionConfigUtils.get_expression_config_for_chat(chat_stream_id)
|
||||
|
||||
@staticmethod
|
||||
def _get_stream_id(
|
||||
|
||||
@@ -150,9 +150,11 @@ class BaseImageDataModel(BaseDatabaseDataModel[Images]):
|
||||
|
||||
file_ext = self.file_name.split(".")[-1].lower()
|
||||
if file_ext != self.image_format:
|
||||
logger.warning(
|
||||
f"[初始化] {self.file_name} 文件扩展名与实际格式不符: ext`{file_ext}`!=`{self.image_format}`"
|
||||
)
|
||||
log_message = f"[初始化] {self.file_name} 文件扩展名与实际格式不符: ext`{file_ext}`!=`{self.image_format}`"
|
||||
if file_ext == "tmp":
|
||||
logger.debug(log_message)
|
||||
else:
|
||||
logger.warning(log_message)
|
||||
self._rename_file_to_match_format()
|
||||
|
||||
return True
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Optional
|
||||
from typing import Iterator, Optional
|
||||
|
||||
import time
|
||||
|
||||
@@ -18,16 +18,8 @@ class ExpressionConfigUtils:
|
||||
for config_item in global_config.expression.learning_list:
|
||||
if not config_item.platform and not config_item.item_id:
|
||||
continue
|
||||
stream_id = ExpressionConfigUtils._get_stream_id(
|
||||
config_item.platform,
|
||||
str(config_item.item_id),
|
||||
(config_item.rule_type == "group"),
|
||||
)
|
||||
if stream_id is None:
|
||||
continue
|
||||
if stream_id != session_id:
|
||||
continue
|
||||
return config_item
|
||||
if ChatConfigUtils.target_matches_session(config_item, session_id):
|
||||
return config_item
|
||||
|
||||
for config_item in global_config.expression.learning_list:
|
||||
if not config_item.platform and not config_item.item_id:
|
||||
@@ -84,6 +76,180 @@ class ExpressionConfigUtils:
|
||||
|
||||
|
||||
class ChatConfigUtils:
|
||||
@staticmethod
|
||||
def _iter_matching_chat_prompts(session_id: str, is_group_chat: Optional[bool]) -> Iterator[str]:
|
||||
try:
|
||||
from src.chat.message_receive.chat_manager import chat_manager
|
||||
from src.common.utils.utils_session import SessionUtils
|
||||
|
||||
chat_stream = chat_manager.get_session_by_session_id(session_id)
|
||||
session_utils = SessionUtils
|
||||
except Exception as e:
|
||||
logger.debug(f"解析额外 Prompt 聊天流失败: session_id={session_id} error={e}")
|
||||
chat_stream = None
|
||||
session_utils = None
|
||||
|
||||
for chat_prompt_item in global_config.chat.chat_prompts:
|
||||
if hasattr(chat_prompt_item, "platform"):
|
||||
platform = str(chat_prompt_item.platform or "").strip()
|
||||
item_id = str(chat_prompt_item.item_id or "").strip()
|
||||
rule_type = str(chat_prompt_item.rule_type or "").strip()
|
||||
prompt_content = str(chat_prompt_item.prompt or "").strip()
|
||||
elif isinstance(chat_prompt_item, str):
|
||||
parts = chat_prompt_item.split(":", 3)
|
||||
if len(parts) != 4:
|
||||
continue
|
||||
|
||||
platform, item_id, rule_type, prompt_content = parts
|
||||
platform = platform.strip()
|
||||
item_id = item_id.strip()
|
||||
rule_type = rule_type.strip()
|
||||
prompt_content = prompt_content.strip()
|
||||
else:
|
||||
continue
|
||||
|
||||
if not platform or not item_id or not prompt_content:
|
||||
continue
|
||||
|
||||
if rule_type == "group":
|
||||
config_is_group = True
|
||||
target_attr = "group_id"
|
||||
elif rule_type == "private":
|
||||
config_is_group = False
|
||||
target_attr = "user_id"
|
||||
else:
|
||||
continue
|
||||
|
||||
if is_group_chat is not None and config_is_group != is_group_chat:
|
||||
continue
|
||||
|
||||
if chat_stream is not None:
|
||||
chat_stream_platform = str(chat_stream.platform or "").strip()
|
||||
chat_stream_target_id = str(getattr(chat_stream, target_attr) or "").strip()
|
||||
if chat_stream_platform == platform and chat_stream_target_id == item_id:
|
||||
yield prompt_content
|
||||
continue
|
||||
|
||||
if session_utils is None:
|
||||
continue
|
||||
|
||||
try:
|
||||
if rule_type == "group":
|
||||
config_chat_id = session_utils.calculate_session_id(platform, group_id=item_id)
|
||||
else:
|
||||
config_chat_id = session_utils.calculate_session_id(platform, user_id=item_id)
|
||||
except Exception as e:
|
||||
logger.debug(f"生成额外 Prompt 聊天流 ID 失败: platform={platform} item_id={item_id} error={e}")
|
||||
continue
|
||||
|
||||
if config_chat_id == session_id:
|
||||
yield prompt_content
|
||||
|
||||
@staticmethod
|
||||
def get_chat_prompt_for_chat(session_id: str, is_group_chat: Optional[bool]) -> str:
|
||||
"""根据聊天流 ID 获取匹配的额外 Prompt,允许同一聊天流配置多条。"""
|
||||
if not session_id or not global_config.chat.chat_prompts:
|
||||
return ""
|
||||
|
||||
prompt_contents = list(ChatConfigUtils._iter_matching_chat_prompts(session_id, is_group_chat))
|
||||
if not prompt_contents:
|
||||
return ""
|
||||
|
||||
logger.debug(f"匹配到 {len(prompt_contents)} 条聊天额外 Prompt: session_id={session_id}")
|
||||
return "\n".join(prompt_contents)
|
||||
|
||||
@staticmethod
|
||||
def _target_values(target_item) -> tuple[str, str, str]:
|
||||
platform = str(target_item.platform or "").strip()
|
||||
item_id = str(target_item.item_id or "").strip()
|
||||
rule_type = str(target_item.rule_type or "").strip()
|
||||
return platform, item_id, rule_type
|
||||
|
||||
@staticmethod
|
||||
def _get_chat_stream(session_id: str):
|
||||
try:
|
||||
from src.chat.message_receive.chat_manager import chat_manager
|
||||
|
||||
return chat_manager.get_session_by_session_id(session_id)
|
||||
except Exception as e:
|
||||
logger.debug(f"获取聊天流失败: session_id={session_id} error={e}")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _get_stream_id(platform: str, id_str: str, is_group: bool = False) -> Optional[str]:
|
||||
try:
|
||||
from src.common.utils.utils_session import SessionUtils
|
||||
|
||||
if is_group:
|
||||
return SessionUtils.calculate_session_id(platform, group_id=str(id_str))
|
||||
return SessionUtils.calculate_session_id(platform, user_id=str(id_str))
|
||||
except Exception as e:
|
||||
logger.error(f"生成聊天流 ID 失败: {e}")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def target_matches_session(target_item, session_id: str, is_group_chat: Optional[bool] = None) -> bool:
|
||||
"""判断 platform/item_id/rule_type 配置目标是否命中当前聊天流。"""
|
||||
if not session_id:
|
||||
return False
|
||||
|
||||
platform, item_id, rule_type = ChatConfigUtils._target_values(target_item)
|
||||
if not platform or not item_id:
|
||||
return False
|
||||
|
||||
if rule_type == "group":
|
||||
config_is_group = True
|
||||
target_attr = "group_id"
|
||||
elif rule_type == "private":
|
||||
config_is_group = False
|
||||
target_attr = "user_id"
|
||||
else:
|
||||
return False
|
||||
|
||||
if is_group_chat is not None and config_is_group != is_group_chat:
|
||||
return False
|
||||
|
||||
chat_stream = ChatConfigUtils._get_chat_stream(session_id)
|
||||
if chat_stream is not None:
|
||||
chat_stream_platform = str(chat_stream.platform or "").strip()
|
||||
chat_stream_target_id = str(getattr(chat_stream, target_attr) or "").strip()
|
||||
return chat_stream_platform == platform and chat_stream_target_id == item_id
|
||||
|
||||
return ChatConfigUtils._get_stream_id(platform, item_id, config_is_group) == session_id
|
||||
|
||||
@staticmethod
|
||||
def get_target_session_ids(target_item) -> set[str]:
|
||||
"""获取配置目标对应的已知聊天流 ID,并保留无路由 ID 作为兼容回退。"""
|
||||
platform, item_id, rule_type = ChatConfigUtils._target_values(target_item)
|
||||
if not platform or not item_id:
|
||||
return set()
|
||||
|
||||
if rule_type == "group":
|
||||
is_group = True
|
||||
target_attr = "group_id"
|
||||
elif rule_type == "private":
|
||||
is_group = False
|
||||
target_attr = "user_id"
|
||||
else:
|
||||
return set()
|
||||
|
||||
session_ids: set[str] = set()
|
||||
if fallback_session_id := ChatConfigUtils._get_stream_id(platform, item_id, is_group):
|
||||
session_ids.add(fallback_session_id)
|
||||
|
||||
try:
|
||||
from src.chat.message_receive.chat_manager import chat_manager
|
||||
|
||||
for session_id, chat_stream in chat_manager.sessions.items():
|
||||
chat_stream_platform = str(chat_stream.platform or "").strip()
|
||||
chat_stream_target_id = str(getattr(chat_stream, target_attr) or "").strip()
|
||||
if chat_stream_platform == platform and chat_stream_target_id == item_id:
|
||||
session_ids.add(session_id)
|
||||
except Exception as e:
|
||||
logger.debug(f"解析配置目标已知聊天流失败: platform={platform} item_id={item_id} error={e}")
|
||||
|
||||
return session_ids
|
||||
|
||||
@staticmethod
|
||||
def _resolve_is_group_chat(session_id: Optional[str]) -> Optional[bool]:
|
||||
if not session_id:
|
||||
@@ -117,16 +283,10 @@ class ChatConfigUtils:
|
||||
|
||||
# 优先匹配会话相关的规则
|
||||
if session_id:
|
||||
from src.common.utils.utils_session import SessionUtils
|
||||
|
||||
for rule in global_config.chat.talk_value_rules:
|
||||
if not rule.platform and not rule.item_id:
|
||||
continue # 一起留空表示全局
|
||||
if rule.rule_type == "group":
|
||||
rule_session_id = SessionUtils.calculate_session_id(rule.platform, group_id=str(rule.item_id))
|
||||
else:
|
||||
rule_session_id = SessionUtils.calculate_session_id(rule.platform, user_id=str(rule.item_id))
|
||||
if rule_session_id != session_id:
|
||||
if not ChatConfigUtils.target_matches_session(rule, session_id, is_group_chat):
|
||||
continue # 不匹配的会话 ID,跳过
|
||||
parsed_range = ChatConfigUtils.parse_range(rule.time)
|
||||
if not parsed_range:
|
||||
|
||||
26
src/main.py
26
src/main.py
@@ -45,6 +45,7 @@ class MainSystem:
|
||||
|
||||
self.app: MessageServer = get_global_api()
|
||||
self.server: Server = get_global_server()
|
||||
self.webui_task: asyncio.Task[None] | None = None
|
||||
self.webui_server: WebUIServer | None = None # 独立的 WebUI 服务器
|
||||
|
||||
def _setup_webui_server(self) -> None:
|
||||
@@ -69,16 +70,23 @@ class MainSystem:
|
||||
enable_stage_status_board()
|
||||
logger.info(t("startup.waking_up", nickname=global_config.bot.nickname))
|
||||
|
||||
await self._auto_update_webui_dashboard()
|
||||
|
||||
# 设置独立的 WebUI 服务器
|
||||
self._setup_webui_server()
|
||||
|
||||
# 其他初始化任务
|
||||
await asyncio.gather(self._init_components())
|
||||
self.webui_task = asyncio.create_task(self._run_webui_startup_sequence(), name="webui_startup")
|
||||
try:
|
||||
await self._init_components()
|
||||
except Exception:
|
||||
self.webui_task.cancel()
|
||||
await asyncio.gather(self.webui_task, return_exceptions=True)
|
||||
raise
|
||||
|
||||
logger.info(t("startup.initialization_completed_banner", nickname=global_config.bot.nickname))
|
||||
|
||||
async def _run_webui_startup_sequence(self) -> None:
|
||||
"""按顺序检查 WebUI 更新并启动 WebUI,同时允许主初始化并行执行。"""
|
||||
await self._auto_update_webui_dashboard()
|
||||
self._setup_webui_server()
|
||||
if self.webui_server:
|
||||
await self.webui_server.start()
|
||||
|
||||
async def _auto_update_webui_dashboard(self) -> None:
|
||||
"""启动时自动检查并更新 WebUI dashboard。"""
|
||||
if not global_config.webui.enabled:
|
||||
@@ -166,7 +174,9 @@ class MainSystem:
|
||||
]
|
||||
|
||||
# 如果 WebUI 服务器已初始化,添加到任务列表
|
||||
if self.webui_server:
|
||||
if self.webui_task:
|
||||
tasks.append(self.webui_task)
|
||||
elif self.webui_server:
|
||||
tasks.append(self.webui_server.start())
|
||||
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
@@ -10,7 +10,7 @@ from rich.console import RenderableType
|
||||
from src.common.data_models.llm_service_data_models import LLMGenerationOptions
|
||||
from src.common.logger import get_logger
|
||||
from src.common.prompt_i18n import load_prompt
|
||||
from src.common.utils.utils_session import SessionUtils
|
||||
from src.common.utils.utils_config import ChatConfigUtils
|
||||
from src.config.config import global_config
|
||||
from src.core.tooling import ToolAvailabilityContext, ToolRegistry
|
||||
from src.llm_models.model_client.base_client import BaseClient
|
||||
@@ -384,49 +384,7 @@ class MaisakaChatLoopService:
|
||||
@staticmethod
|
||||
def _get_chat_prompt_for_chat(chat_id: str, is_group_chat: Optional[bool]) -> str:
|
||||
"""根据聊天流 ID 获取匹配的额外提示。"""
|
||||
|
||||
if not global_config.chat.chat_prompts:
|
||||
return ""
|
||||
|
||||
for chat_prompt_item in global_config.chat.chat_prompts:
|
||||
if hasattr(chat_prompt_item, "platform"):
|
||||
platform = str(chat_prompt_item.platform or "").strip()
|
||||
item_id = str(chat_prompt_item.item_id or "").strip()
|
||||
rule_type = str(chat_prompt_item.rule_type or "").strip()
|
||||
prompt_content = str(chat_prompt_item.prompt or "").strip()
|
||||
elif isinstance(chat_prompt_item, str):
|
||||
parts = chat_prompt_item.split(":", 3)
|
||||
if len(parts) != 4:
|
||||
continue
|
||||
|
||||
platform, item_id, rule_type, prompt_content = parts
|
||||
platform = platform.strip()
|
||||
item_id = item_id.strip()
|
||||
rule_type = rule_type.strip()
|
||||
prompt_content = prompt_content.strip()
|
||||
else:
|
||||
continue
|
||||
|
||||
if not platform or not item_id or not prompt_content:
|
||||
continue
|
||||
|
||||
if rule_type == "group":
|
||||
config_is_group = True
|
||||
config_chat_id = SessionUtils.calculate_session_id(platform, group_id=item_id)
|
||||
elif rule_type == "private":
|
||||
config_is_group = False
|
||||
config_chat_id = SessionUtils.calculate_session_id(platform, user_id=item_id)
|
||||
else:
|
||||
continue
|
||||
|
||||
if is_group_chat is not None and config_is_group != is_group_chat:
|
||||
continue
|
||||
|
||||
if config_chat_id == chat_id:
|
||||
logger.debug(f"匹配到 Maisaka 聊天额外提示,chat_id: {chat_id}, prompt: {prompt_content[:50]}...")
|
||||
return prompt_content
|
||||
|
||||
return ""
|
||||
return ChatConfigUtils.get_chat_prompt_for_chat(chat_id, is_group_chat)
|
||||
|
||||
def set_extra_tools(self, tools: Sequence[ToolDefinitionInput]) -> None:
|
||||
"""设置额外工具定义。
|
||||
|
||||
197
src/webui/routers/reasoning_process.py
Normal file
197
src/webui/routers/reasoning_process.py
Normal file
@@ -0,0 +1,197 @@
|
||||
"""推理过程日志浏览接口。"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from fastapi.responses import FileResponse
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from src.webui.dependencies import require_auth
|
||||
|
||||
router = APIRouter(prefix="/reasoning-process", tags=["reasoning-process"], dependencies=[Depends(require_auth)])
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[3]
|
||||
PROMPT_LOG_ROOT = (PROJECT_ROOT / "logs" / "maisaka_prompt").resolve()
|
||||
ALLOWED_SUFFIXES = {".txt", ".html"}
|
||||
|
||||
|
||||
class ReasoningPromptFile(BaseModel):
|
||||
"""推理过程日志条目。"""
|
||||
|
||||
stage: str
|
||||
session_id: str
|
||||
stem: str
|
||||
timestamp: int | None = None
|
||||
text_path: str | None = None
|
||||
html_path: str | None = None
|
||||
size: int = 0
|
||||
modified_at: float = 0
|
||||
|
||||
|
||||
class ReasoningPromptListResponse(BaseModel):
|
||||
"""推理过程日志列表响应。"""
|
||||
|
||||
items: list[ReasoningPromptFile]
|
||||
total: int
|
||||
page: int
|
||||
page_size: int
|
||||
stages: list[str] = Field(default_factory=list)
|
||||
sessions: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ReasoningPromptContentResponse(BaseModel):
|
||||
"""推理过程文本内容响应。"""
|
||||
|
||||
path: str
|
||||
content: str
|
||||
size: int
|
||||
modified_at: float
|
||||
|
||||
|
||||
def _to_safe_relative_path(relative_path: str) -> Path:
|
||||
safe_path = Path(relative_path)
|
||||
if safe_path.is_absolute() or ".." in safe_path.parts:
|
||||
raise HTTPException(status_code=400, detail="路径不合法")
|
||||
return safe_path
|
||||
|
||||
|
||||
def _resolve_prompt_log_path(relative_path: str, allowed_suffixes: set[str]) -> Path:
|
||||
safe_path = _to_safe_relative_path(relative_path)
|
||||
resolved_path = (PROMPT_LOG_ROOT / safe_path).resolve()
|
||||
|
||||
try:
|
||||
resolved_path.relative_to(PROMPT_LOG_ROOT)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=400, detail="路径不合法") from exc
|
||||
|
||||
if resolved_path.suffix.lower() not in allowed_suffixes:
|
||||
raise HTTPException(status_code=400, detail="不支持的文件类型")
|
||||
if not resolved_path.is_file():
|
||||
raise HTTPException(status_code=404, detail="文件不存在")
|
||||
|
||||
return resolved_path
|
||||
|
||||
|
||||
def _relative_posix_path(path: Path) -> str:
|
||||
return path.relative_to(PROMPT_LOG_ROOT).as_posix()
|
||||
|
||||
|
||||
def _collect_prompt_files() -> tuple[list[ReasoningPromptFile], list[str], list[str]]:
|
||||
if not PROMPT_LOG_ROOT.is_dir():
|
||||
return [], [], []
|
||||
|
||||
records: dict[tuple[str, str, str], dict[str, object]] = {}
|
||||
stages: set[str] = set()
|
||||
sessions: set[str] = set()
|
||||
|
||||
for file_path in PROMPT_LOG_ROOT.rglob("*"):
|
||||
if not file_path.is_file() or file_path.suffix.lower() not in ALLOWED_SUFFIXES:
|
||||
continue
|
||||
|
||||
try:
|
||||
relative_path = file_path.relative_to(PROMPT_LOG_ROOT)
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
parts = relative_path.parts
|
||||
if len(parts) < 3:
|
||||
continue
|
||||
|
||||
stage, session_id = parts[0], parts[1]
|
||||
stem = file_path.stem
|
||||
key = (stage, session_id, stem)
|
||||
stat = file_path.stat()
|
||||
|
||||
stages.add(stage)
|
||||
sessions.add(session_id)
|
||||
record = records.setdefault(
|
||||
key,
|
||||
{
|
||||
"stage": stage,
|
||||
"session_id": session_id,
|
||||
"stem": stem,
|
||||
"timestamp": int(stem) if stem.isdigit() else None,
|
||||
"text_path": None,
|
||||
"html_path": None,
|
||||
"size": 0,
|
||||
"modified_at": 0.0,
|
||||
},
|
||||
)
|
||||
record["size"] = int(record["size"]) + stat.st_size
|
||||
record["modified_at"] = max(float(record["modified_at"]), stat.st_mtime)
|
||||
|
||||
if file_path.suffix.lower() == ".txt":
|
||||
record["text_path"] = _relative_posix_path(file_path)
|
||||
elif file_path.suffix.lower() == ".html":
|
||||
record["html_path"] = _relative_posix_path(file_path)
|
||||
|
||||
items = [ReasoningPromptFile(**record) for record in records.values()]
|
||||
items.sort(key=lambda item: (item.modified_at, item.timestamp or 0), reverse=True)
|
||||
return items, sorted(stages), sorted(sessions)
|
||||
|
||||
|
||||
@router.get("/files", response_model=ReasoningPromptListResponse)
|
||||
async def list_reasoning_prompt_files(
|
||||
stage: str = Query("all"),
|
||||
session: str = Query("all"),
|
||||
search: str = Query(""),
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(50, ge=10, le=200),
|
||||
):
|
||||
"""列出 logs/maisaka_prompt 下的推理过程日志。"""
|
||||
|
||||
items, stages, sessions = _collect_prompt_files()
|
||||
normalized_search = search.strip().lower()
|
||||
|
||||
if stage != "all":
|
||||
items = [item for item in items if item.stage == stage]
|
||||
if session != "all":
|
||||
items = [item for item in items if item.session_id == session]
|
||||
if normalized_search:
|
||||
items = [
|
||||
item
|
||||
for item in items
|
||||
if normalized_search in item.stage.lower()
|
||||
or normalized_search in item.session_id.lower()
|
||||
or normalized_search in item.stem.lower()
|
||||
]
|
||||
|
||||
total = len(items)
|
||||
start = (page - 1) * page_size
|
||||
end = start + page_size
|
||||
|
||||
return ReasoningPromptListResponse(
|
||||
items=items[start:end],
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
stages=stages,
|
||||
sessions=sessions,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/file", response_model=ReasoningPromptContentResponse)
|
||||
async def get_reasoning_prompt_file(path: str = Query(...)):
|
||||
"""读取推理过程 txt 日志内容。"""
|
||||
|
||||
file_path = _resolve_prompt_log_path(path, {".txt"})
|
||||
stat = file_path.stat()
|
||||
|
||||
return ReasoningPromptContentResponse(
|
||||
path=_relative_posix_path(file_path),
|
||||
content=file_path.read_text(encoding="utf-8", errors="replace"),
|
||||
size=stat.st_size,
|
||||
modified_at=stat.st_mtime,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/html")
|
||||
async def get_reasoning_prompt_html(path: str = Query(...)):
|
||||
"""预览推理过程 html 日志内容。"""
|
||||
|
||||
file_path = _resolve_prompt_log_path(path, {".html"})
|
||||
return FileResponse(
|
||||
file_path,
|
||||
media_type="text/html; charset=utf-8",
|
||||
headers={"X-Robots-Tag": "noindex, nofollow"},
|
||||
)
|
||||
@@ -20,6 +20,7 @@ from src.webui.routers.memory import router as memory_router
|
||||
from src.webui.routers.model import router as model_router
|
||||
from src.webui.routers.person import router as person_router
|
||||
from src.webui.routers.plugin import router as plugin_router
|
||||
from src.webui.routers.reasoning_process import router as reasoning_process_router
|
||||
from src.webui.routers.statistics import router as statistics_router
|
||||
from src.webui.routers.system import router as system_router
|
||||
from src.webui.routers.websocket.auth import router as ws_auth_router
|
||||
@@ -46,6 +47,7 @@ router.include_router(emoji_router)
|
||||
router.include_router(plugin_router)
|
||||
# 注册系统控制路由
|
||||
router.include_router(system_router)
|
||||
router.include_router(reasoning_process_router)
|
||||
# 注册模型列表获取路由
|
||||
router.include_router(model_router)
|
||||
# 注册长期记忆管理路由
|
||||
|
||||
Reference in New Issue
Block a user