fix: improve the display and retrieval of chat-stream information, fix chat_prompt entries not taking effect, and fix several group-definition config issues

SengokuCola
2026-05-07 18:06:55 +08:00
parent 93cef02d92
commit b6808d4b73
21 changed files with 1219 additions and 165 deletions

View File

@@ -11,8 +11,7 @@ from src.chat.message_receive.message import SessionMessage
from src.common.database.database import get_db_session
from src.common.database.database_model import Expression
from src.common.logger import get_logger
from src.common.utils.utils_config import ExpressionConfigUtils
from src.common.utils.utils_session import SessionUtils
from src.common.utils.utils_config import ChatConfigUtils, ExpressionConfigUtils
from src.config.config import global_config
from src.learners.learner_utils_old import weighted_sample
from src.maisaka.context_messages import LLMContextMessage
@@ -65,14 +64,9 @@ class MaisakaExpressionSelector:
if not platform or not item_id:
continue
rule_type = target_item.rule_type
target_session_id = SessionUtils.calculate_session_id(
platform,
group_id=item_id if rule_type == "group" else None,
user_id=None if rule_type == "group" else item_id,
)
group_session_ids.add(target_session_id)
if target_session_id == session_id:
target_session_ids = ChatConfigUtils.get_target_session_ids(target_item)
group_session_ids.update(target_session_ids)
if ChatConfigUtils.target_matches_session(target_item, session_id):
contains_current_session = True
if contains_global_share_marker:

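For reference, a minimal sketch of how the two new ChatConfigUtils helpers used above are called. The SimpleNamespace target and the id values are hypothetical; only the helper names and the platform/item_id/rule_type fields come from this diff.

from types import SimpleNamespace

from src.common.utils.utils_config import ChatConfigUtils

# Hypothetical config target with the three fields the helpers read.
target = SimpleNamespace(platform="qq", item_id="123456", rule_type="group")
current_session_id = "example-session-id"  # placeholder for the active stream id

# Every known session id routed to that group, plus the legacy no-route fallback id.
known_ids = ChatConfigUtils.get_target_session_ids(target)

# True when the current stream belongs to the configured group.
if ChatConfigUtils.target_matches_session(target, current_session_id):
    known_ids.add(current_session_id)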
View File

@@ -30,7 +30,7 @@ from src.common.data_models.message_component_data_model import (
VoiceComponent,
)
from src.common.logger import get_logger
from src.common.utils.utils_session import SessionUtils
from src.common.utils.utils_config import ChatConfigUtils
from src.config.config import global_config
from src.config.model_configs import ModelInfo
from src.core.types import ActionInfo
@@ -211,46 +211,7 @@ class BaseMaisakaReplyGenerator:
@staticmethod
def _get_chat_prompt_for_chat(chat_id: str, is_group_chat: Optional[bool]) -> str:
"""根据聊天流 ID 获取匹配的额外 prompt。"""
if not global_config.chat.chat_prompts:
return ""
for chat_prompt_item in global_config.chat.chat_prompts:
if hasattr(chat_prompt_item, "platform"):
platform = str(chat_prompt_item.platform or "").strip()
item_id = str(chat_prompt_item.item_id or "").strip()
rule_type = str(chat_prompt_item.rule_type or "").strip()
prompt_content = str(chat_prompt_item.prompt or "").strip()
elif isinstance(chat_prompt_item, str):
parts = chat_prompt_item.split(":", 3)
if len(parts) != 4:
continue
platform, item_id, rule_type, prompt_content = parts
platform = platform.strip()
item_id = item_id.strip()
rule_type = rule_type.strip()
prompt_content = prompt_content.strip()
else:
continue
if not platform or not item_id or not prompt_content:
continue
if rule_type == "group":
config_is_group = True
config_chat_id = SessionUtils.calculate_session_id(platform, group_id=item_id)
elif rule_type == "private":
config_is_group = False
config_chat_id = SessionUtils.calculate_session_id(platform, user_id=item_id)
else:
continue
if config_is_group != is_group_chat:
continue
if config_chat_id == chat_id:
return prompt_content
return ""
return ChatConfigUtils.get_chat_prompt_for_chat(chat_id, is_group_chat)
def _build_group_chat_attention_block(self, session_id: str) -> str:
"""构建当前聊天场景下的额外注意事项块。"""

View File

@@ -1,7 +1,7 @@
from typing import Optional
from src.common.logger import get_logger
from src.config.config import global_config
from src.common.utils.utils_config import ExpressionConfigUtils
logger = get_logger("common_utils")
@@ -11,29 +11,7 @@ class TempMethodsExpression:
@staticmethod
def _find_expression_config_item(chat_stream_id: Optional[str] = None):
if not global_config.expression.learning_list:
return None
if chat_stream_id:
for config_item in global_config.expression.learning_list:
if not config_item.platform and not config_item.item_id:
continue
stream_id = TempMethodsExpression._get_stream_id(
config_item.platform,
str(config_item.item_id),
(config_item.rule_type == "group"),
)
if stream_id is None:
continue
if stream_id != chat_stream_id:
continue
return config_item
for config_item in global_config.expression.learning_list:
if not config_item.platform and not config_item.item_id:
return config_item
return None
return ExpressionConfigUtils._find_expression_config_item(chat_stream_id)
@staticmethod
def get_expression_config_for_chat(chat_stream_id: Optional[str] = None) -> tuple[bool, bool, bool]:
@@ -46,15 +24,7 @@ class TempMethodsExpression:
Returns:
tuple: (是否使用表达, 是否学习表达, 是否启用 jargon 学习)
"""
config_item = TempMethodsExpression._find_expression_config_item(chat_stream_id)
if config_item is None:
return True, True, True
return (
config_item.use_expression,
config_item.enable_learning,
config_item.enable_jargon_learning,
)
return ExpressionConfigUtils.get_expression_config_for_chat(chat_stream_id)
@staticmethod
def _get_stream_id(

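For callers, the delegation above keeps the old signature; a short sketch of the resulting call (the chat_stream_id value is a placeholder, and the tuple order follows the docstring in this diff):

from src.common.utils.utils_config import ExpressionConfigUtils

chat_stream_id = "example-session-id"  # placeholder

# Falls back to (True, True, True) when no learning_list entry matches the stream.
use_expression, enable_learning, enable_jargon_learning = (
    ExpressionConfigUtils.get_expression_config_for_chat(chat_stream_id)
)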
View File

@@ -150,9 +150,11 @@ class BaseImageDataModel(BaseDatabaseDataModel[Images]):
file_ext = self.file_name.split(".")[-1].lower()
if file_ext != self.image_format:
logger.warning(
f"[初始化] {self.file_name} 文件扩展名与实际格式不符: ext`{file_ext}`!=`{self.image_format}`"
)
log_message = f"[初始化] {self.file_name} 文件扩展名与实际格式不符: ext`{file_ext}`!=`{self.image_format}`"
if file_ext == "tmp":
logger.debug(log_message)
else:
logger.warning(log_message)
self._rename_file_to_match_format()
return True

View File

@@ -1,4 +1,4 @@
from typing import Optional
from typing import Iterator, Optional
import time
@@ -18,16 +18,8 @@ class ExpressionConfigUtils:
for config_item in global_config.expression.learning_list:
if not config_item.platform and not config_item.item_id:
continue
stream_id = ExpressionConfigUtils._get_stream_id(
config_item.platform,
str(config_item.item_id),
(config_item.rule_type == "group"),
)
if stream_id is None:
continue
if stream_id != session_id:
continue
return config_item
if ChatConfigUtils.target_matches_session(config_item, session_id):
return config_item
for config_item in global_config.expression.learning_list:
if not config_item.platform and not config_item.item_id:
@@ -84,6 +76,180 @@ class ExpressionConfigUtils:
class ChatConfigUtils:
@staticmethod
def _iter_matching_chat_prompts(session_id: str, is_group_chat: Optional[bool]) -> Iterator[str]:
try:
from src.chat.message_receive.chat_manager import chat_manager
from src.common.utils.utils_session import SessionUtils
chat_stream = chat_manager.get_session_by_session_id(session_id)
session_utils = SessionUtils
except Exception as e:
logger.debug(f"解析额外 Prompt 聊天流失败: session_id={session_id} error={e}")
chat_stream = None
session_utils = None
for chat_prompt_item in global_config.chat.chat_prompts:
if hasattr(chat_prompt_item, "platform"):
platform = str(chat_prompt_item.platform or "").strip()
item_id = str(chat_prompt_item.item_id or "").strip()
rule_type = str(chat_prompt_item.rule_type or "").strip()
prompt_content = str(chat_prompt_item.prompt or "").strip()
elif isinstance(chat_prompt_item, str):
parts = chat_prompt_item.split(":", 3)
if len(parts) != 4:
continue
platform, item_id, rule_type, prompt_content = parts
platform = platform.strip()
item_id = item_id.strip()
rule_type = rule_type.strip()
prompt_content = prompt_content.strip()
else:
continue
if not platform or not item_id or not prompt_content:
continue
if rule_type == "group":
config_is_group = True
target_attr = "group_id"
elif rule_type == "private":
config_is_group = False
target_attr = "user_id"
else:
continue
if is_group_chat is not None and config_is_group != is_group_chat:
continue
if chat_stream is not None:
chat_stream_platform = str(chat_stream.platform or "").strip()
chat_stream_target_id = str(getattr(chat_stream, target_attr) or "").strip()
if chat_stream_platform == platform and chat_stream_target_id == item_id:
yield prompt_content
continue
if session_utils is None:
continue
try:
if rule_type == "group":
config_chat_id = session_utils.calculate_session_id(platform, group_id=item_id)
else:
config_chat_id = session_utils.calculate_session_id(platform, user_id=item_id)
except Exception as e:
logger.debug(f"生成额外 Prompt 聊天流 ID 失败: platform={platform} item_id={item_id} error={e}")
continue
if config_chat_id == session_id:
yield prompt_content
@staticmethod
def get_chat_prompt_for_chat(session_id: str, is_group_chat: Optional[bool]) -> str:
"""根据聊天流 ID 获取匹配的额外 Prompt允许同一聊天流配置多条。"""
if not session_id or not global_config.chat.chat_prompts:
return ""
prompt_contents = list(ChatConfigUtils._iter_matching_chat_prompts(session_id, is_group_chat))
if not prompt_contents:
return ""
logger.debug(f"匹配到 {len(prompt_contents)} 条聊天额外 Prompt: session_id={session_id}")
return "\n".join(prompt_contents)
@staticmethod
def _target_values(target_item) -> tuple[str, str, str]:
platform = str(target_item.platform or "").strip()
item_id = str(target_item.item_id or "").strip()
rule_type = str(target_item.rule_type or "").strip()
return platform, item_id, rule_type
@staticmethod
def _get_chat_stream(session_id: str):
try:
from src.chat.message_receive.chat_manager import chat_manager
return chat_manager.get_session_by_session_id(session_id)
except Exception as e:
logger.debug(f"获取聊天流失败: session_id={session_id} error={e}")
return None
@staticmethod
def _get_stream_id(platform: str, id_str: str, is_group: bool = False) -> Optional[str]:
try:
from src.common.utils.utils_session import SessionUtils
if is_group:
return SessionUtils.calculate_session_id(platform, group_id=str(id_str))
return SessionUtils.calculate_session_id(platform, user_id=str(id_str))
except Exception as e:
logger.error(f"生成聊天流 ID 失败: {e}")
return None
@staticmethod
def target_matches_session(target_item, session_id: str, is_group_chat: Optional[bool] = None) -> bool:
"""判断 platform/item_id/rule_type 配置目标是否命中当前聊天流。"""
if not session_id:
return False
platform, item_id, rule_type = ChatConfigUtils._target_values(target_item)
if not platform or not item_id:
return False
if rule_type == "group":
config_is_group = True
target_attr = "group_id"
elif rule_type == "private":
config_is_group = False
target_attr = "user_id"
else:
return False
if is_group_chat is not None and config_is_group != is_group_chat:
return False
chat_stream = ChatConfigUtils._get_chat_stream(session_id)
if chat_stream is not None:
chat_stream_platform = str(chat_stream.platform or "").strip()
chat_stream_target_id = str(getattr(chat_stream, target_attr) or "").strip()
return chat_stream_platform == platform and chat_stream_target_id == item_id
return ChatConfigUtils._get_stream_id(platform, item_id, config_is_group) == session_id
@staticmethod
def get_target_session_ids(target_item) -> set[str]:
"""获取配置目标对应的已知聊天流 ID并保留无路由 ID 作为兼容回退。"""
platform, item_id, rule_type = ChatConfigUtils._target_values(target_item)
if not platform or not item_id:
return set()
if rule_type == "group":
is_group = True
target_attr = "group_id"
elif rule_type == "private":
is_group = False
target_attr = "user_id"
else:
return set()
session_ids: set[str] = set()
if fallback_session_id := ChatConfigUtils._get_stream_id(platform, item_id, is_group):
session_ids.add(fallback_session_id)
try:
from src.chat.message_receive.chat_manager import chat_manager
for session_id, chat_stream in chat_manager.sessions.items():
chat_stream_platform = str(chat_stream.platform or "").strip()
chat_stream_target_id = str(getattr(chat_stream, target_attr) or "").strip()
if chat_stream_platform == platform and chat_stream_target_id == item_id:
session_ids.add(session_id)
except Exception as e:
logger.debug(f"解析配置目标已知聊天流失败: platform={platform} item_id={item_id} error={e}")
return session_ids
@staticmethod
def _resolve_is_group_chat(session_id: Optional[str]) -> Optional[bool]:
if not session_id:
@@ -117,16 +283,10 @@ class ChatConfigUtils:
# 优先匹配会话相关的规则
if session_id:
from src.common.utils.utils_session import SessionUtils
for rule in global_config.chat.talk_value_rules:
if not rule.platform and not rule.item_id:
continue # 一起留空表示全局
if rule.rule_type == "group":
rule_session_id = SessionUtils.calculate_session_id(rule.platform, group_id=str(rule.item_id))
else:
rule_session_id = SessionUtils.calculate_session_id(rule.platform, user_id=str(rule.item_id))
if rule_session_id != session_id:
if not ChatConfigUtils.target_matches_session(rule, session_id, is_group_chat):
continue # 不匹配的会话 ID，跳过
parsed_range = ChatConfigUtils.parse_range(rule.time)
if not parsed_range:

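A sketch of the two chat_prompts shapes accepted by _iter_matching_chat_prompts and of the new multi-match behaviour. The entry values are hypothetical; the "platform:item_id:rule_type:prompt" string format and the newline join come from this diff.

from src.common.utils.utils_config import ChatConfigUtils

# String form:     "platform:item_id:rule_type:prompt" (split on ":" into exactly 4 parts)
# Structured form: any object exposing platform / item_id / rule_type / prompt
# Hypothetical config, e.g. in global_config.chat.chat_prompts:
#   "qq:123456:group:Keep replies short in this group."
#   "qq:123456:group:Avoid posting links here."

session_id = "example-session-id"  # placeholder for that group's stream id
prompt = ChatConfigUtils.get_chat_prompt_for_chat(session_id, is_group_chat=True)
# With both hypothetical entries matching, the helper returns them joined by "\n".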
View File

@@ -45,6 +45,7 @@ class MainSystem:
self.app: MessageServer = get_global_api()
self.server: Server = get_global_server()
self.webui_task: asyncio.Task[None] | None = None
self.webui_server: WebUIServer | None = None # 独立的 WebUI 服务器
def _setup_webui_server(self) -> None:
@@ -69,16 +70,23 @@ class MainSystem:
enable_stage_status_board()
logger.info(t("startup.waking_up", nickname=global_config.bot.nickname))
await self._auto_update_webui_dashboard()
# 设置独立的 WebUI 服务器
self._setup_webui_server()
# 其他初始化任务
await asyncio.gather(self._init_components())
self.webui_task = asyncio.create_task(self._run_webui_startup_sequence(), name="webui_startup")
try:
await self._init_components()
except Exception:
self.webui_task.cancel()
await asyncio.gather(self.webui_task, return_exceptions=True)
raise
logger.info(t("startup.initialization_completed_banner", nickname=global_config.bot.nickname))
async def _run_webui_startup_sequence(self) -> None:
"""按顺序检查 WebUI 更新并启动 WebUI同时允许主初始化并行执行。"""
await self._auto_update_webui_dashboard()
self._setup_webui_server()
if self.webui_server:
await self.webui_server.start()
async def _auto_update_webui_dashboard(self) -> None:
"""启动时自动检查并更新 WebUI dashboard。"""
if not global_config.webui.enabled:
@@ -166,7 +174,9 @@ class MainSystem:
]
# 如果 WebUI 服务器已初始化,添加到任务列表
if self.webui_server:
if self.webui_task:
tasks.append(self.webui_task)
elif self.webui_server:
tasks.append(self.webui_server.start())
await asyncio.gather(*tasks)

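The startup change above follows a common asyncio pattern: launch the WebUI sequence as a background task, then cancel and drain it if the main initialization fails so the original exception still propagates. A self-contained sketch of the pattern (names are illustrative, not the project's):

import asyncio

async def webui_startup() -> None:
    await asyncio.sleep(0.1)  # stands in for the dashboard update check and server start

async def init_components() -> None:
    await asyncio.sleep(0)    # stands in for the real component initialization

async def main() -> None:
    webui_task = asyncio.create_task(webui_startup(), name="webui_startup")
    try:
        await init_components()
    except Exception:
        webui_task.cancel()
        # return_exceptions=True absorbs the CancelledError so the original error is re-raised.
        await asyncio.gather(webui_task, return_exceptions=True)
        raise
    await webui_task  # in the diff this is awaited later, alongside the other long-running tasks

asyncio.run(main())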
View File

@@ -10,7 +10,7 @@ from rich.console import RenderableType
from src.common.data_models.llm_service_data_models import LLMGenerationOptions
from src.common.logger import get_logger
from src.common.prompt_i18n import load_prompt
from src.common.utils.utils_session import SessionUtils
from src.common.utils.utils_config import ChatConfigUtils
from src.config.config import global_config
from src.core.tooling import ToolAvailabilityContext, ToolRegistry
from src.llm_models.model_client.base_client import BaseClient
@@ -384,49 +384,7 @@ class MaisakaChatLoopService:
@staticmethod
def _get_chat_prompt_for_chat(chat_id: str, is_group_chat: Optional[bool]) -> str:
"""根据聊天流 ID 获取匹配的额外提示。"""
if not global_config.chat.chat_prompts:
return ""
for chat_prompt_item in global_config.chat.chat_prompts:
if hasattr(chat_prompt_item, "platform"):
platform = str(chat_prompt_item.platform or "").strip()
item_id = str(chat_prompt_item.item_id or "").strip()
rule_type = str(chat_prompt_item.rule_type or "").strip()
prompt_content = str(chat_prompt_item.prompt or "").strip()
elif isinstance(chat_prompt_item, str):
parts = chat_prompt_item.split(":", 3)
if len(parts) != 4:
continue
platform, item_id, rule_type, prompt_content = parts
platform = platform.strip()
item_id = item_id.strip()
rule_type = rule_type.strip()
prompt_content = prompt_content.strip()
else:
continue
if not platform or not item_id or not prompt_content:
continue
if rule_type == "group":
config_is_group = True
config_chat_id = SessionUtils.calculate_session_id(platform, group_id=item_id)
elif rule_type == "private":
config_is_group = False
config_chat_id = SessionUtils.calculate_session_id(platform, user_id=item_id)
else:
continue
if is_group_chat is not None and config_is_group != is_group_chat:
continue
if config_chat_id == chat_id:
logger.debug(f"匹配到 Maisaka 聊天额外提示chat_id: {chat_id}, prompt: {prompt_content[:50]}...")
return prompt_content
return ""
return ChatConfigUtils.get_chat_prompt_for_chat(chat_id, is_group_chat)
def set_extra_tools(self, tools: Sequence[ToolDefinitionInput]) -> None:
"""设置额外工具定义。

View File

@@ -0,0 +1,197 @@
"""推理过程日志浏览接口。"""
from pathlib import Path
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.responses import FileResponse
from pydantic import BaseModel, Field
from src.webui.dependencies import require_auth
router = APIRouter(prefix="/reasoning-process", tags=["reasoning-process"], dependencies=[Depends(require_auth)])
PROJECT_ROOT = Path(__file__).resolve().parents[3]
PROMPT_LOG_ROOT = (PROJECT_ROOT / "logs" / "maisaka_prompt").resolve()
ALLOWED_SUFFIXES = {".txt", ".html"}
class ReasoningPromptFile(BaseModel):
"""推理过程日志条目。"""
stage: str
session_id: str
stem: str
timestamp: int | None = None
text_path: str | None = None
html_path: str | None = None
size: int = 0
modified_at: float = 0
class ReasoningPromptListResponse(BaseModel):
"""推理过程日志列表响应。"""
items: list[ReasoningPromptFile]
total: int
page: int
page_size: int
stages: list[str] = Field(default_factory=list)
sessions: list[str] = Field(default_factory=list)
class ReasoningPromptContentResponse(BaseModel):
"""推理过程文本内容响应。"""
path: str
content: str
size: int
modified_at: float
def _to_safe_relative_path(relative_path: str) -> Path:
safe_path = Path(relative_path)
if safe_path.is_absolute() or ".." in safe_path.parts:
raise HTTPException(status_code=400, detail="路径不合法")
return safe_path
def _resolve_prompt_log_path(relative_path: str, allowed_suffixes: set[str]) -> Path:
safe_path = _to_safe_relative_path(relative_path)
resolved_path = (PROMPT_LOG_ROOT / safe_path).resolve()
try:
resolved_path.relative_to(PROMPT_LOG_ROOT)
except ValueError as exc:
raise HTTPException(status_code=400, detail="路径不合法") from exc
if resolved_path.suffix.lower() not in allowed_suffixes:
raise HTTPException(status_code=400, detail="不支持的文件类型")
if not resolved_path.is_file():
raise HTTPException(status_code=404, detail="文件不存在")
return resolved_path
def _relative_posix_path(path: Path) -> str:
return path.relative_to(PROMPT_LOG_ROOT).as_posix()
def _collect_prompt_files() -> tuple[list[ReasoningPromptFile], list[str], list[str]]:
if not PROMPT_LOG_ROOT.is_dir():
return [], [], []
records: dict[tuple[str, str, str], dict[str, object]] = {}
stages: set[str] = set()
sessions: set[str] = set()
for file_path in PROMPT_LOG_ROOT.rglob("*"):
if not file_path.is_file() or file_path.suffix.lower() not in ALLOWED_SUFFIXES:
continue
try:
relative_path = file_path.relative_to(PROMPT_LOG_ROOT)
except ValueError:
continue
parts = relative_path.parts
if len(parts) < 3:
continue
stage, session_id = parts[0], parts[1]
stem = file_path.stem
key = (stage, session_id, stem)
stat = file_path.stat()
stages.add(stage)
sessions.add(session_id)
record = records.setdefault(
key,
{
"stage": stage,
"session_id": session_id,
"stem": stem,
"timestamp": int(stem) if stem.isdigit() else None,
"text_path": None,
"html_path": None,
"size": 0,
"modified_at": 0.0,
},
)
record["size"] = int(record["size"]) + stat.st_size
record["modified_at"] = max(float(record["modified_at"]), stat.st_mtime)
if file_path.suffix.lower() == ".txt":
record["text_path"] = _relative_posix_path(file_path)
elif file_path.suffix.lower() == ".html":
record["html_path"] = _relative_posix_path(file_path)
items = [ReasoningPromptFile(**record) for record in records.values()]
items.sort(key=lambda item: (item.modified_at, item.timestamp or 0), reverse=True)
return items, sorted(stages), sorted(sessions)
@router.get("/files", response_model=ReasoningPromptListResponse)
async def list_reasoning_prompt_files(
stage: str = Query("all"),
session: str = Query("all"),
search: str = Query(""),
page: int = Query(1, ge=1),
page_size: int = Query(50, ge=10, le=200),
):
"""列出 logs/maisaka_prompt 下的推理过程日志。"""
items, stages, sessions = _collect_prompt_files()
normalized_search = search.strip().lower()
if stage != "all":
items = [item for item in items if item.stage == stage]
if session != "all":
items = [item for item in items if item.session_id == session]
if normalized_search:
items = [
item
for item in items
if normalized_search in item.stage.lower()
or normalized_search in item.session_id.lower()
or normalized_search in item.stem.lower()
]
total = len(items)
start = (page - 1) * page_size
end = start + page_size
return ReasoningPromptListResponse(
items=items[start:end],
total=total,
page=page,
page_size=page_size,
stages=stages,
sessions=sessions,
)
@router.get("/file", response_model=ReasoningPromptContentResponse)
async def get_reasoning_prompt_file(path: str = Query(...)):
"""读取推理过程 txt 日志内容。"""
file_path = _resolve_prompt_log_path(path, {".txt"})
stat = file_path.stat()
return ReasoningPromptContentResponse(
path=_relative_posix_path(file_path),
content=file_path.read_text(encoding="utf-8", errors="replace"),
size=stat.st_size,
modified_at=stat.st_mtime,
)
@router.get("/html")
async def get_reasoning_prompt_html(path: str = Query(...)):
"""预览推理过程 html 日志内容。"""
file_path = _resolve_prompt_log_path(path, {".html"})
return FileResponse(
file_path,
media_type="text/html; charset=utf-8",
headers={"X-Robots-Tag": "noindex, nofollow"},
)

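A quick client-side sketch against the new endpoints. The base URL, any mount prefix, and the auth header are assumptions; the routes and query parameters are the ones defined above.

import httpx

BASE = "http://127.0.0.1:8000"                 # hypothetical WebUI address (plus prefix, if any)
HEADERS = {"Authorization": "Bearer <token>"}  # whatever require_auth expects

listing = httpx.get(
    f"{BASE}/reasoning-process/files",
    params={"stage": "all", "session": "all", "search": "", "page": 1, "page_size": 50},
    headers=HEADERS,
).json()

for item in listing["items"]:
    if item["text_path"]:
        detail = httpx.get(
            f"{BASE}/reasoning-process/file",
            params={"path": item["text_path"]},
            headers=HEADERS,
        ).json()
        print(item["stage"], item["session_id"], len(detail["content"]))
        break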
View File

@@ -20,6 +20,7 @@ from src.webui.routers.memory import router as memory_router
from src.webui.routers.model import router as model_router
from src.webui.routers.person import router as person_router
from src.webui.routers.plugin import router as plugin_router
from src.webui.routers.reasoning_process import router as reasoning_process_router
from src.webui.routers.statistics import router as statistics_router
from src.webui.routers.system import router as system_router
from src.webui.routers.websocket.auth import router as ws_auth_router
@@ -46,6 +47,7 @@ router.include_router(emoji_router)
router.include_router(plugin_router)
# 注册系统控制路由
router.include_router(system_router)
router.include_router(reasoning_process_router)
# 注册模型列表获取路由
router.include_router(model_router)
# 注册长期记忆管理路由