fix: 修复一些bug

This commit is contained in:
SengokuCola
2026-03-29 18:28:56 +08:00
parent 82bbf0fd52
commit 96844a9bf5
8 changed files with 898 additions and 444 deletions

View File

@@ -1,11 +1,16 @@
from dataclasses import dataclass, field
from datetime import datetime
from typing import Dict, List, Optional, Tuple
import random
import time
from sqlmodel import select
from src.chat.message_receive.chat_manager import BotChatSession
from src.chat.message_receive.message import SessionMessage
from src.common.database.database import get_db_session
from src.common.database.database_model import Expression
from src.common.data_models.reply_generation_data_models import (
GenerationMetrics,
LLMCompletionResult,
@@ -28,6 +33,23 @@ from src.maisaka.message_adapter import (
logger = get_logger("maisaka_replyer")
@dataclass
class MaisakaReplyContext:
    """Reply context used by the Maisaka replyer."""

    # Pre-formatted expression-habit reference block; empty when none apply.
    expression_habits: str = ""
    # Primary keys of the Expression rows that produced the habit block.
    selected_expression_ids: List[int] = field(default_factory=list)
@dataclass
class _ExpressionRecord:
    """Lightweight record of an expression row (plain data, no ORM object)."""

    expression_id: Optional[int]  # DB primary key; may be None
    situation: str  # situation in which the expression applies
    style: str  # the expression/style text itself
class MaisakaReplyGenerator:
"""生成 Maisaka 的最终可见回复。"""
@@ -182,6 +204,89 @@ class MaisakaReplyGenerator:
user_prompt = "\n\n".join(user_sections)
return f"System: {system_prompt}\n\nUser: {user_prompt}"
def _resolve_session_id(self, stream_id: Optional[str]) -> str:
"""解析当前回复使用的会话 ID。"""
if stream_id:
return stream_id
if self.chat_stream is not None:
return self.chat_stream.session_id
return ""
async def _build_reply_context(
    self,
    chat_history: List[SessionMessage],
    reply_message: Optional[SessionMessage],
    reply_reason: str,
    stream_id: Optional[str],
) -> MaisakaReplyContext:
    """Build the expression-habit context inside the replyer itself."""
    session_id = self._resolve_session_id(stream_id)
    if not session_id:
        logger.warning("Failed to build Maisaka reply context: session_id is missing")
        return MaisakaReplyContext()
    habits_text, habit_ids = self._build_expression_habits(
        session_id=session_id,
        chat_history=chat_history,
        reply_message=reply_message,
        reply_reason=reply_reason,
    )
    return MaisakaReplyContext(
        expression_habits=habits_text,
        selected_expression_ids=habit_ids,
    )
def _build_expression_habits(
    self,
    session_id: str,
    chat_history: List[SessionMessage],
    reply_message: Optional[SessionMessage],
    reply_reason: str,
) -> tuple[str, List[int]]:
    """Query and format the expression habits suited to the current session."""
    # Context arguments are part of the interface but currently unused.
    del chat_history, reply_message, reply_reason

    records = self._load_expression_records(session_id)
    if not records:
        return "", []

    selected_ids: List[int] = [
        record.expression_id for record in records if record.expression_id is not None
    ]
    habit_lines = [
        f"- 当{record.situation}时,可以自然地用{record.style}这种表达习惯。"
        for record in records
    ]
    block = "【表达习惯参考】\n" + "\n".join(habit_lines)
    logger.info(
        f"Built Maisaka expression habits: session_id={session_id} "
        f"count={len(selected_ids)} ids={selected_ids!r}"
    )
    return block, selected_ids
def _load_expression_records(self, session_id: str) -> List[_ExpressionRecord]:
    """Extract plain expression data so no detached ORM objects escape the session."""
    with get_db_session(auto_commit=False) as session:
        stmt = select(Expression).where(Expression.rejected.is_(False))  # type: ignore[attr-defined]
        if global_config.expression.expression_checked_only:
            stmt = stmt.where(Expression.checked.is_(True))  # type: ignore[attr-defined]
        stmt = (
            stmt.where(
                (Expression.session_id == session_id) | (Expression.session_id.is_(None))  # type: ignore[attr-defined]
            )
            .order_by(Expression.count.desc(), Expression.last_active_time.desc())  # type: ignore[attr-defined]
            .limit(5)
        )
        rows = session.exec(stmt).all()
        # Copy fields out while the session is still open.
        return [
            _ExpressionRecord(
                expression_id=row.id,
                situation=row.situation,
                style=row.style,
            )
            for row in rows
        ]
async def generate_reply_with_context(
self,
extra_info: str = "",
@@ -212,8 +317,6 @@ class MaisakaReplyGenerator:
del unknown_words
result = ReplyGenerationResult()
result.selected_expression_ids = list(selected_expression_ids or [])
if chat_history is None:
result.error_message = "chat_history is empty"
return False, result
@@ -221,8 +324,7 @@ class MaisakaReplyGenerator:
logger.info(
f"Maisaka replyer start: stream_id={stream_id} reply_reason={reply_reason!r} "
f"history_size={len(chat_history)} target_message_id="
f"{reply_message.message_id if reply_message else None} "
f"expression_count={len(result.selected_expression_ids)}"
f"{reply_message.message_id if reply_message else None}"
)
filtered_history = [
@@ -232,11 +334,52 @@ class MaisakaReplyGenerator:
and get_message_kind(message) != "perception"
and get_message_source(message) != "user_reference"
]
prompt = self._build_prompt(
chat_history=filtered_history,
reply_reason=reply_reason or "",
expression_habits=expression_habits,
logger.debug(f"Maisaka replyer: filtered_history size={len(filtered_history)}")
# Validate that express_model is properly initialized
if self.express_model is None:
logger.error("Maisaka replyer: express_model is None!")
result.error_message = "express_model is not initialized"
return False, result
try:
reply_context = await self._build_reply_context(
chat_history=filtered_history,
reply_message=reply_message,
reply_reason=reply_reason or "",
stream_id=stream_id,
)
except Exception as exc:
import traceback
logger.error(f"Maisaka replyer: _build_reply_context failed: {exc}\n{traceback.format_exc()}")
result.error_message = f"_build_reply_context failed: {exc}"
return False, result
merged_expression_habits = expression_habits.strip() or reply_context.expression_habits
result.selected_expression_ids = (
list(selected_expression_ids)
if selected_expression_ids is not None
else list(reply_context.selected_expression_ids)
)
logger.info(
f"Maisaka reply context built: stream_id={stream_id} "
f"selected_expression_ids={result.selected_expression_ids!r}"
)
try:
prompt = self._build_prompt(
chat_history=filtered_history,
reply_reason=reply_reason or "",
expression_habits=merged_expression_habits,
)
except Exception as exc:
import traceback
logger.error(f"Maisaka replyer: _build_prompt failed: {exc}\n{traceback.format_exc()}")
result.error_message = f"_build_prompt failed: {exc}"
return False, result
result.completion.request_prompt = prompt
if global_config.debug.show_replyer_prompt:

View File

@@ -2,7 +2,7 @@ from datetime import datetime
from enum import Enum
from typing import Optional
from sqlalchemy import Column, DateTime, Enum as SQLEnum, Float
from sqlalchemy import Column, DateTime, Enum as SQLEnum, Float, Text
from sqlmodel import Field, LargeBinary, SQLModel
@@ -17,8 +17,8 @@ class ImageType(str, Enum):
class ModifiedBy(str, Enum):
AI = "ai"
USER = "user"
AI = "AI"
USER = "USER"
class Messages(SQLModel, table=True):
@@ -223,18 +223,40 @@ class Jargon(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True) # 自增主键
content: str = Field(index=True, max_length=255) # 黑话内容
raw_content: Optional[str] = Field(default=None, nullable=True) # 原始内容未处理的黑话内容为List[str]
raw_content: Optional[str] = Field(
default=None, sa_column=Column(Text, nullable=True)
) # 原始内容未处理的黑话内容为List[str]
meaning: str # 黑话含义
session_id_dict: str = Field(default=r"{}") # 会话ID列表格式为{"session_id": session_count, ...}
meaning: str = Field(sa_column=Column(Text, nullable=False)) # 黑话含义
session_id_dict: str = Field(
default=r"{}", sa_column=Column(Text, nullable=False)
) # 会话ID列表格式为{"session_id": session_count, ...}
count: int = Field(default=0) # 使用次数
is_jargon: Optional[bool] = Field(default=True) # 是否为黑话False表示为白话
is_complete: bool = Field(default=False) # 是否为已经完成全部推断count > 100后不再推断
is_global: bool = Field(default=False) # 是否为全局黑话独立于session_id_dict
last_inference_count: int = Field(default=0) # 上一次进行推断时的count值用于判断是否需要重新推断
inference_with_context: Optional[str] = Field(default=None, nullable=True) # 带上下文的推断结果JSON格式
inference_with_content_only: Optional[str] = Field(default=None, nullable=True) # 只基于词条的推断结果JSON格式
inference_with_context: Optional[str] = Field(
default=None, sa_column=Column(Text, nullable=True)
) # 带上下文的推断结果JSON格式
inference_with_content_only: Optional[str] = Field(
default=None, sa_column=Column(Text, nullable=True)
) # 只基于词条的推断结果JSON格式
class MaiKnowledge(SQLModel, table=True):
    """Stores Maisaka's user-profile knowledge."""

    __tablename__ = "mai_knowledge"  # type: ignore

    id: Optional[int] = Field(default=None, primary_key=True)  # auto-increment primary key
    knowledge_id: str = Field(index=True, max_length=255)  # external knowledge identifier
    category_id: str = Field(index=True, max_length=32)  # knowledge category key
    content: str  # knowledge text as shown to the model
    normalized_content: str = Field(index=True)  # whitespace-normalized text used for dedup lookups
    metadata_json: Optional[str] = Field(default=None, nullable=True)  # optional JSON metadata blob
    created_at: datetime = Field(default_factory=datetime.now, sa_column=Column(DateTime, index=True))  # creation time, indexed for ordering
class ChatHistory(SQLModel, table=True):

View File

@@ -8,7 +8,11 @@ from typing import Any, Dict, List, Optional
import json
# 数据目录位于项目根目录下的 mai_knowledge
from sqlmodel import select
from src.common.database.database import DATABASE_URL, get_db_session
from src.common.database.database_model import MaiKnowledge
PROJECT_ROOT = Path(__file__).resolve().parents[2]
KNOWLEDGE_DATA_DIR = PROJECT_ROOT / "mai_knowledge"
KNOWLEDGE_FILE = KNOWLEDGE_DATA_DIR / "knowledge.json"
@@ -18,7 +22,7 @@ KNOWLEDGE_CATEGORIES = {
"1": "性别",
"2": "性格",
"3": "饮食口味",
"4": "交友",
"4": "交友",
"5": "情绪/理性倾向",
"6": "兴趣爱好",
"7": "职业/专业",
@@ -31,77 +35,128 @@ KNOWLEDGE_CATEGORIES = {
class KnowledgeStore:
"""
简单的 Maisaka 知识存储。
特性:
- 持久化到 JSON 文件
- 按分类存储用户画像类知识
- 支持基础去重
"""
"""存储 Maisaka 的用户画像知识。"""
def __init__(self) -> None:
"""初始化知识存储。"""
self._knowledge: Dict[str, List[Dict[str, Any]]] = {
category_id: [] for category_id in KNOWLEDGE_CATEGORIES
}
self._ensure_data_dir()
self._load()
"""初始化知识存储,并在需要时迁移旧版 JSON 数据"""
self._ensure_legacy_data_dir()
self._migrate_legacy_file_if_needed()
def _ensure_data_dir(self) -> None:
"""确保数据目录存在"""
def _ensure_legacy_data_dir(self) -> None:
"""确保旧版知识目录存在,便于兼容历史数据。"""
KNOWLEDGE_DATA_DIR.mkdir(parents=True, exist_ok=True)
def _load(self) -> None:
"""从文件加载知识数据。"""
if not KNOWLEDGE_FILE.exists():
self._knowledge = {category_id: [] for category_id in KNOWLEDGE_CATEGORIES}
return
try:
with open(KNOWLEDGE_FILE, "r", encoding="utf-8") as file:
loaded = json.load(file)
normalized_knowledge: Dict[str, List[Dict[str, Any]]] = {
category_id: [] for category_id in KNOWLEDGE_CATEGORIES
}
for category_id in KNOWLEDGE_CATEGORIES:
category_items = loaded.get(category_id, [])
if isinstance(category_items, list):
normalized_knowledge[category_id] = [
item for item in category_items if isinstance(item, dict)
]
self._knowledge = normalized_knowledge
except Exception:
self._knowledge = {category_id: [] for category_id in KNOWLEDGE_CATEGORIES}
def _save(self) -> None:
"""保存知识数据到文件。"""
with open(KNOWLEDGE_FILE, "w", encoding="utf-8") as file:
json.dump(self._knowledge, file, ensure_ascii=False, indent=2)
@staticmethod
def _normalize_content(content: str) -> str:
"""标准化知识内容,便于去重。"""
return " ".join(str(content).strip().split())
@staticmethod
def _serialize_metadata(metadata: Optional[Dict[str, Any]]) -> Optional[str]:
"""将元数据序列化为 JSON 文本。"""
if not metadata:
return None
return json.dumps(metadata, ensure_ascii=False, sort_keys=True)
@staticmethod
def _deserialize_metadata(raw_text: Optional[str]) -> Dict[str, Any]:
"""将 JSON 文本反序列化为元数据字典。"""
if not raw_text:
return {}
try:
parsed = json.loads(raw_text)
except json.JSONDecodeError:
return {}
return parsed if isinstance(parsed, dict) else {}
@staticmethod
def _parse_created_at(raw_value: Any) -> datetime:
"""解析旧版数据中的创建时间。"""
if isinstance(raw_value, datetime):
return raw_value
if isinstance(raw_value, str):
raw_text = raw_value.strip()
if raw_text:
try:
return datetime.fromisoformat(raw_text)
except ValueError:
pass
return datetime.now()
@classmethod
def _build_item_dict(cls, record: MaiKnowledge) -> Dict[str, Any]:
    """Convert a database record into the dict shape of the legacy JSON API."""
    return {
        "id": record.knowledge_id,
        "content": record.content,
        "metadata": cls._deserialize_metadata(record.metadata_json),
        "created_at": record.created_at.isoformat(),  # ISO-8601 string, matching legacy storage
    }
def _load_legacy_knowledge_file(self) -> Dict[str, List[Dict[str, Any]]]:
    """Read the legacy JSON knowledge file, tolerating a missing or corrupt file."""
    if not KNOWLEDGE_FILE.exists():
        return {}
    try:
        with open(KNOWLEDGE_FILE, "r", encoding="utf-8") as file:
            raw = json.load(file)
    except Exception:
        # Any read/parse failure is treated as "no legacy data".
        return {}
    if not isinstance(raw, dict):
        return {}
    result: Dict[str, List[Dict[str, Any]]] = {}
    for category_id in KNOWLEDGE_CATEGORIES:
        entries = raw.get(category_id, [])
        if isinstance(entries, list):
            # Keep only dict-shaped items; anything else is silently dropped.
            result[category_id] = [entry for entry in entries if isinstance(entry, dict)]
    return result
def _migrate_legacy_file_if_needed(self) -> None:
    """Import knowledge from the legacy JSON file when the database is empty."""
    legacy_knowledge = self._load_legacy_knowledge_file()
    if not legacy_knowledge:
        return
    with get_db_session(auto_commit=False) as session:
        # Migration runs at most once: any existing row means the DB is already populated.
        existing_record = session.exec(select(MaiKnowledge.id).limit(1)).first()
        if existing_record is not None:
            return
        for category_id, items in legacy_knowledge.items():
            if category_id not in KNOWLEDGE_CATEGORIES:
                continue
            for item in items:
                content = self._normalize_content(str(item.get("content", "")))
                if not content:
                    continue
                metadata = item.get("metadata")
                session.add(
                    MaiKnowledge(
                        # NOTE(review): timestamp-based fallback IDs may collide within
                        # the same tick for multiple items — confirm uniqueness needs.
                        knowledge_id=str(item.get("id") or f"know_{category_id}_{datetime.now().timestamp()}"),
                        category_id=category_id,
                        content=content,
                        normalized_content=content,
                        metadata_json=self._serialize_metadata(metadata if isinstance(metadata, dict) else None),
                        created_at=self._parse_created_at(item.get("created_at")),
                    )
                )
        session.commit()
def add_knowledge(
self,
category_id: str,
content: str,
metadata: Optional[Dict[str, Any]] = None,
) -> bool:
"""
添加一条知识信息。
Args:
category_id: 分类编号
content: 知识内容
metadata: 附加元数据
Returns:
是否新增成功;若命中去重则返回 False
"""
"""添加一条知识信息。"""
if category_id not in KNOWLEDGE_CATEGORIES:
return False
@@ -109,29 +164,59 @@ class KnowledgeStore:
if not normalized_content:
return False
existing_items = self._knowledge.get(category_id, [])
for item in existing_items:
existing_content = self._normalize_content(str(item.get("content", "")))
if existing_content == normalized_content:
with get_db_session(auto_commit=False) as session:
existing_record = session.exec(
select(MaiKnowledge).where(
MaiKnowledge.category_id == category_id,
MaiKnowledge.normalized_content == normalized_content,
)
).first()
if existing_record is not None:
return False
knowledge_item = {
"id": f"know_{category_id}_{datetime.now().timestamp()}",
"content": normalized_content,
"metadata": metadata or {},
"created_at": datetime.now().isoformat(),
}
self._knowledge[category_id].append(knowledge_item)
self._save()
session.add(
MaiKnowledge(
knowledge_id=f"know_{category_id}_{datetime.now().timestamp()}",
category_id=category_id,
content=normalized_content,
normalized_content=normalized_content,
metadata_json=self._serialize_metadata(metadata),
created_at=datetime.now(),
)
)
session.commit()
return True
def get_category_knowledge(self, category_id: str) -> List[Dict[str, Any]]:
"""获取某个分类下的所有知识。"""
return self._knowledge.get(category_id, [])
if category_id not in KNOWLEDGE_CATEGORIES:
return []
with get_db_session() as session:
records = session.exec(
select(MaiKnowledge)
.where(MaiKnowledge.category_id == category_id)
.order_by(MaiKnowledge.created_at.asc(), MaiKnowledge.id.asc())
).all()
return [self._build_item_dict(record) for record in records]
def get_all_knowledge(self) -> Dict[str, List[Dict[str, Any]]]:
"""获取全部知识。"""
return self._knowledge
all_knowledge: Dict[str, List[Dict[str, Any]]] = {
category_id: [] for category_id in KNOWLEDGE_CATEGORIES
}
with get_db_session() as session:
records = session.exec(
select(MaiKnowledge).order_by(
MaiKnowledge.category_id.asc(),
MaiKnowledge.created_at.asc(),
MaiKnowledge.id.asc(),
)
).all()
for record in records:
all_knowledge.setdefault(record.category_id, []).append(self._build_item_dict(record))
return all_knowledge
def get_category_name(self, category_id: str) -> str:
"""获取分类名称。"""
@@ -139,24 +224,23 @@ class KnowledgeStore:
def get_categories_summary(self) -> str:
"""获取分类摘要,供模型判断是否需要检索。"""
counts: Dict[str, int] = {category_id: 0 for category_id in KNOWLEDGE_CATEGORIES}
with get_db_session() as session:
records = session.exec(select(MaiKnowledge.category_id)).all()
for category_id in records:
if category_id in counts:
counts[category_id] += 1
lines: List[str] = []
for category_id, category_name in KNOWLEDGE_CATEGORIES.items():
count = len(self._knowledge.get(category_id, []))
count = counts.get(category_id, 0)
count_text = f"{count}" if count > 0 else "无数据"
lines.append(f"{category_id}. {category_name} ({count_text})")
return "\n".join(lines)
def get_formatted_knowledge(self, category_ids: List[str], limit_per_category: int = 5) -> str:
"""
获取指定分类的格式化知识内容。
Args:
category_ids: 分类编号列表
limit_per_category: 每个分类最多返回多少条
Returns:
格式化后的知识内容
"""
"""获取指定分类的格式化知识内容。"""
parts: List[str] = []
for category_id in category_ids:
items = self.get_category_knowledge(category_id)
@@ -176,13 +260,18 @@ class KnowledgeStore:
def get_stats(self) -> Dict[str, Any]:
"""获取知识数据统计。"""
total_items = sum(len(items) for items in self._knowledge.values())
with get_db_session() as session:
total_items = len(session.exec(select(MaiKnowledge.id)).all())
return {
"total_categories": len(KNOWLEDGE_CATEGORIES),
"total_items": total_items,
"data_file": str(KNOWLEDGE_FILE),
"data_exists": KNOWLEDGE_FILE.exists(),
"data_size_kb": KNOWLEDGE_FILE.stat().st_size / 1024 if KNOWLEDGE_FILE.exists() else 0,
"data_file": DATABASE_URL,
"data_exists": True,
"data_size_kb": 0,
"legacy_data_file": str(KNOWLEDGE_FILE),
"legacy_data_exists": KNOWLEDGE_FILE.exists(),
"storage_type": "database",
}

View File

@@ -30,6 +30,7 @@ from .builtin_tools import get_builtin_tools
from .message_adapter import (
build_message,
format_speaker_content,
get_message_role,
to_llm_message,
)
@@ -303,6 +304,7 @@ class MaisakaChatLoopService:
async def chat_loop_step(self, chat_history: List[SessionMessage]) -> ChatResponse:
await self.ensure_chat_prompt_loaded()
selected_history, selection_reason = self._select_llm_context_messages(chat_history)
def message_factory(_client: BaseClient) -> List[Message]:
messages: List[Message] = []
@@ -310,7 +312,7 @@ class MaisakaChatLoopService:
system_msg.add_text_content(self._chat_system_prompt)
messages.append(system_msg.build())
for msg in chat_history:
for msg in selected_history:
llm_message = to_llm_message(msg)
if llm_message is not None:
messages.append(llm_message)
@@ -333,6 +335,7 @@ class MaisakaChatLoopService:
Panel(
Group(*ordered_panels),
title="MaiSaka LLM Request - chat_loop_step",
subtitle=selection_reason,
border_style="cyan",
padding=(0, 1),
)
@@ -374,6 +377,38 @@ class MaisakaChatLoopService:
raw_message=raw_message,
)
@staticmethod
def _select_llm_context_messages(chat_history: List[SessionMessage]) -> tuple[List[SessionMessage], str]:
    """Pick the messages that are actually sent to the LLM as context.

    Walks the history backwards, keeping every convertible message until
    ``max_context_size`` user/assistant messages have been collected.
    """
    window = max(1, int(global_config.chat.max_context_size))
    roles_that_count = {"user", "assistant"}

    kept: List[SessionMessage] = []
    counted = 0
    for message in reversed(chat_history):
        # Skip messages that cannot be converted to LLM messages at all.
        if to_llm_message(message) is None:
            continue
        kept.append(message)
        if get_message_role(message) in roles_that_count:
            counted += 1
            if counted >= window:
                break

    if not kept:
        return [], f"上下文判定:最近 {window} 条 user/assistant当前 0 条)"

    kept.reverse()
    return (
        kept,
        (
            f"上下文判定:最近 {window} 条 user/assistant"
            f"展示并发送窗口内消息 {len(kept)}"
        ),
    )
@staticmethod
def build_chat_context(user_text: str) -> List[SessionMessage]:
return [

View File

@@ -14,6 +14,7 @@ from sqlmodel import select
from src.chat.heart_flow.heartFC_utils import CycleDetail
from src.chat.message_receive.message import SessionMessage
from src.chat.replyer.replyer_manager import replyer_manager
from src.chat.utils.utils import get_bot_account
from src.common.database.database import get_db_session
from src.common.database.database_model import Jargon
from src.common.data_models.mai_message_data_model import UserInfo
@@ -33,7 +34,6 @@ from .message_adapter import (
get_message_text,
get_message_role,
)
from .reply_context_builder import MaisakaReplyContextBuilder
from .tool_handlers import (
handle_mcp_tool,
handle_unknown_tool,
@@ -50,8 +50,8 @@ class MaisakaReasoningEngine:
def __init__(self, runtime: "MaisakaHeartFlowChatting") -> None:
self._runtime = runtime
self._reply_context_builder = MaisakaReplyContextBuilder(runtime.session_id)
self._last_reasoning_content: str = ""
self._shown_jargons: set[str] = set() # 已在参考消息中展示过的 jargon
async def run_loop(self) -> None:
"""独立消费消息批次,并执行对应的内部思考轮次。"""
@@ -72,11 +72,19 @@ class MaisakaReasoningEngine:
self._runtime._log_cycle_started(cycle_detail, round_index)
try:
# 每次LLM生成前动态添加参考消息到最新位置
self._append_jargon_reference_message()
reference_added = self._append_jargon_reference_message()
planner_started_at = time.time()
response = await self._runtime._chat_loop_service.chat_loop_step(self._runtime._chat_history)
cycle_detail.time_records["planner"] = time.time() - planner_started_at
# LLM调用后移除刚才添加的参考消息一次性使用
if reference_added and self._runtime._chat_history:
# 从末尾往前查找并移除参考消息
for i in range(len(self._runtime._chat_history) - 1, -1, -1):
if get_message_source(self._runtime._chat_history[i]) == "user_reference":
self._runtime._chat_history.pop(i)
break
reasoning_content = response.content or ""
if self._should_replace_reasoning(reasoning_content):
response.content = "让我根据新情况重新思考:"
@@ -218,15 +226,23 @@ class MaisakaReasoningEngine:
self._runtime._chat_history.insert(insert_at, message)
return insert_at
def _append_jargon_reference_message(self) -> None:
"""每次LLM生成前如果命中了黑话词条则添加一条参考信息消息到聊天历史末尾。"""
def _append_jargon_reference_message(self) -> bool:
"""每次LLM生成前如果命中了黑话词条则添加一条参考信息消息到聊天历史末尾。
Returns:
bool: 是否添加了参考消息
"""
content = self._build_user_history_corpus()
if not content:
return
return False
matched_words = self._find_jargon_words_in_text(content)
if not matched_words:
return
return False
# 记录已展示的 jargon
for word in matched_words:
self._shown_jargons.add(word.lower())
reference_text = (
"[参考信息]\n"
@@ -248,6 +264,7 @@ class MaisakaReasoningEngine:
display_text=reference_text,
)
self._runtime._chat_history.append(reference_message)
return True
def _build_user_history_corpus(self) -> str:
"""拼接当前聊天记录内所有用户消息的正文,用于统一匹配黑话。"""
@@ -282,9 +299,15 @@ class MaisakaReasoningEngine:
jargon_content = str(jargon.content or "").strip()
if not jargon_content:
continue
# meaning 为空的不匹配
if not str(jargon.meaning or "").strip():
continue
normalized_content = jargon_content.lower()
if normalized_content in seen_words:
continue
# 跳过已经展示过的 jargon
if normalized_content in self._shown_jargons:
continue
if not self._is_visible_jargon(jargon):
continue
match_position = self._get_jargon_match_position(jargon_content, lowered_content, content)
@@ -573,34 +596,8 @@ class MaisakaReasoningEngine:
return False
logger.info(f"{self._runtime.log_prefix} acquired Maisaka reply generator successfully")
logger.info(
f"{self._runtime.log_prefix} building reply context: "
f"target_msg_id={target_message_id} unknown_words={unknown_words!r}"
)
try:
reply_context = await self._reply_context_builder.build(
chat_history=self._runtime._chat_history,
reply_message=target_message,
reply_reason=latest_thought,
)
except Exception:
logger.exception(
f"{self._runtime.log_prefix} reply context builder crashed: "
f"target_msg_id={target_message_id}"
)
self._runtime._chat_history.append(
self._build_tool_message(tool_call, "Reply context preparation crashed.")
)
return False
logger.info(
f"{self._runtime.log_prefix} reply context built: "
f"target_msg_id={target_message_id} "
f"selected_expression_ids={reply_context.selected_expression_ids!r} "
f"has_jargon_explanation={bool(reply_context.jargon_explanation.strip())}"
)
logger.info(f"{self._runtime.log_prefix} calling generate_reply_with_context: target_msg_id={target_message_id}")
try:
success, reply_result = await replyer.generate_reply_with_context(
reply_reason=latest_thought,
@@ -609,11 +606,13 @@ class MaisakaReasoningEngine:
chat_history=self._runtime._chat_history,
unknown_words=unknown_words,
log_reply=False,
expression_habits=reply_context.expression_habits,
selected_expression_ids=reply_context.selected_expression_ids,
)
except Exception:
logger.exception(f"{self._runtime.log_prefix} reply generator crashed: target_msg_id={target_message_id}")
except Exception as exc:
import traceback
logger.error(
f"{self._runtime.log_prefix} reply generator crashed: target_msg_id={target_message_id} "
f"exc_type={type(exc).__name__} exc_msg={str(exc)}\n{traceback.format_exc()}"
)
self._runtime._chat_history.append(
self._build_tool_message(tool_call, "Visible reply generation crashed.")
)
@@ -686,18 +685,26 @@ class MaisakaReasoningEngine:
tool_reasoning=latest_thought,
)
target_platform = target_message.platform or anchor_message.platform
bot_name = global_config.bot.nickname.strip() or "MaiSaka"
self._runtime._chat_history.append(
build_message(
role="user",
content=format_speaker_content(bot_name, reply_text, datetime.now()),
source="guided_reply",
platform=target_message.platform or anchor_message.platform,
session_id=self._runtime.session_id,
group_info=self._runtime._build_group_info(target_message),
user_info=self._runtime._build_runtime_user_info(),
)
bot_user_info = UserInfo(
user_id=get_bot_account(target_platform) or "maisaka_assistant",
user_nickname=bot_name,
user_cardname=None,
)
history_message = build_message(
role="assistant",
content=reply_text,
source="guided_reply",
platform=target_platform,
session_id=self._runtime.session_id,
group_info=self._runtime._build_group_info(target_message),
user_info=bot_user_info,
)
structured_visible_text = f"{self._build_planner_user_prefix(history_message)}{reply_text}"
history_message.display_message = structured_visible_text
history_message.processed_plain_text = structured_visible_text
self._runtime._chat_history.append(history_message)
return True
async def _handle_send_emoji(self, tool_call: ToolCall, anchor_message: SessionMessage) -> None:

View File

@@ -1,248 +0,0 @@
from __future__ import annotations
from dataclasses import dataclass, field
from typing import List, Optional
import json
import re
from sqlmodel import select
from src.chat.message_receive.message import SessionMessage
from src.common.database.database import get_db_session
from src.common.database.database_model import Expression, Jargon
from src.common.logger import get_logger
from src.config.config import global_config
from .message_adapter import get_message_role, get_message_source, get_message_text, parse_speaker_content
logger = get_logger("maisaka_reply_context")
@dataclass
class ReplyContextBuildResult:
    """Result of building the pre-reply context."""

    expression_habits: str = ""  # formatted expression-habit reference block
    jargon_explanation: str = ""  # formatted jargon-explanation block
    selected_expression_ids: List[int] = field(default_factory=list)  # Expression row IDs used
@dataclass
class _ExpressionRecord:
    """Lightweight snapshot of an Expression row (plain data, no ORM object)."""

    expression_id: Optional[int]  # DB primary key; may be None
    situation: str  # situation in which the expression applies
    style: str  # the expression/style text itself
@dataclass
class _JargonRecord:
    """Lightweight snapshot of a Jargon row (plain data, no ORM object)."""

    jargon_id: Optional[int]  # DB primary key; may be None
    content: str  # the jargon term itself
    count: int  # usage count
    meaning: str  # explanation of the term
    session_id_dict: str  # JSON text mapping session_id -> count
    is_global: bool  # visible to every session when True
class MaisakaReplyContextBuilder:
    """Builds expression habits and jargon explanations for a Maisaka reply."""

    def __init__(self, session_id: str) -> None:
        # Session scope used to filter expressions and jargon visibility.
        self._session_id = session_id

    async def build(
        self,
        chat_history: List[SessionMessage],
        reply_message: Optional[SessionMessage],
        reply_reason: str,
    ) -> ReplyContextBuildResult:
        """Build the pre-reply context (expression habits + jargon explanation)."""
        expression_habits, selected_expression_ids = self._build_expression_habits(
            chat_history=chat_history,
            reply_message=reply_message,
            reply_reason=reply_reason,
        )
        jargon_explanation = self._build_jargon_explanation(
            chat_history=chat_history,
            reply_message=reply_message,
        )
        return ReplyContextBuildResult(
            expression_habits=expression_habits,
            jargon_explanation=jargon_explanation,
            selected_expression_ids=selected_expression_ids,
        )

    def _build_expression_habits(
        self,
        chat_history: List[SessionMessage],
        reply_message: Optional[SessionMessage],
        reply_reason: str,
    ) -> tuple[str, List[int]]:
        """Query and format expressions suitable for the current session."""
        # Context arguments are part of the interface but currently unused.
        del chat_history
        del reply_message
        del reply_reason

        expression_records = self._load_expression_records()
        if not expression_records:
            return "", []

        lines: List[str] = []
        selected_ids: List[int] = []
        for expression in expression_records:
            if expression.expression_id is not None:
                selected_ids.append(expression.expression_id)
            lines.append(f"- 当{expression.situation}时,可以自然地用{expression.style}这种表达习惯。")

        block = "【表达习惯参考】\n" + "\n".join(lines)
        logger.info(
            f"Built Maisaka expression habits: session_id={self._session_id} "
            f"count={len(selected_ids)} ids={selected_ids!r}"
        )
        return block, selected_ids

    def _load_expression_records(self) -> List[_ExpressionRecord]:
        """Extract static expression data inside the session to avoid detached ORM objects."""
        with get_db_session(auto_commit=False) as session:
            query = select(Expression).where(Expression.rejected.is_(False))  # type: ignore[attr-defined]
            if global_config.expression.expression_checked_only:
                query = query.where(Expression.checked.is_(True))  # type: ignore[attr-defined]
            query = query.where(
                (Expression.session_id == self._session_id) | (Expression.session_id.is_(None))  # type: ignore[attr-defined]
            ).order_by(Expression.count.desc(), Expression.last_active_time.desc())  # type: ignore[attr-defined]
            expressions = session.exec(query.limit(5)).all()
            # Copy fields out while the session is still open.
            return [
                _ExpressionRecord(
                    expression_id=expression.id,
                    situation=expression.situation,
                    style=expression.style,
                )
                for expression in expressions
            ]

    def _build_jargon_explanation(
        self,
        chat_history: List[SessionMessage],
        reply_message: Optional[SessionMessage],
    ) -> str:
        """Query and format jargon explanations; empty when disabled by config."""
        if not global_config.expression.enable_jargon_explanation:
            return ""
        return self._build_context_jargon_explanation(chat_history, reply_message)

    def _build_context_jargon_explanation(
        self,
        chat_history: List[SessionMessage],
        reply_message: Optional[SessionMessage],
    ) -> str:
        """Automatically match jargon entries against the current context."""
        corpus = self._build_context_corpus(chat_history, reply_message)
        if not corpus:
            return ""
        jargon_records = self._load_jargon_records()
        # Sort key: earliest occurrence, then longer content, then higher count.
        matched_records: List[tuple[int, int, int, _JargonRecord]] = []
        seen_contents: set[str] = set()
        for jargon in jargon_records:
            if not jargon.content or not jargon.meaning:
                continue
            normalized_content = jargon.content.lower()
            if normalized_content in seen_contents:
                continue
            if not self._is_visible_jargon(jargon):
                continue
            match_position = self._get_jargon_match_position(jargon.content, corpus)
            if match_position is None:
                continue
            seen_contents.add(normalized_content)
            matched_records.append((match_position, -len(jargon.content), -jargon.count, jargon))
        matched_records.sort()
        # Cap the explanation to the first 8 matches.
        lines = [f"- {jargon.content}: {jargon.meaning}" for _, _, _, jargon in matched_records[:8]]
        if not lines:
            return ""
        logger.info(
            f"Built Maisaka jargon explanation: session_id={self._session_id} "
            f"count={len(lines)}"
        )
        return "【黑话解释】\n" + "\n".join(lines)

    def _load_jargon_records(self) -> List[_JargonRecord]:
        """Extract static jargon data inside the session to avoid detached ORM objects."""
        with get_db_session(auto_commit=False) as session:
            query = select(Jargon).where(Jargon.is_jargon.is_(True), Jargon.meaning != "")  # type: ignore[attr-defined]
            query = query.order_by(Jargon.count.desc())  # type: ignore[attr-defined]
            jargons = session.exec(query).all()
            return [
                _JargonRecord(
                    jargon_id=jargon.id,
                    content=(jargon.content or "").strip(),
                    count=int(jargon.count or 0),
                    meaning=(jargon.meaning or "").strip(),
                    session_id_dict=jargon.session_id_dict or "{}",
                    is_global=bool(jargon.is_global),
                )
                for jargon in jargons
            ]

    def _build_context_corpus(
        self,
        chat_history: List[SessionMessage],
        reply_message: Optional[SessionMessage],
    ) -> str:
        """Join every user message in the chat history into one matchable text."""
        parts: List[str] = []
        for message in chat_history:
            if get_message_role(message) != "user":
                continue
            if get_message_source(message) != "user":
                continue
            text = get_message_text(message).strip()
            if not text:
                continue
            # Strip the "speaker: body" prefix; fall back to the raw text.
            _, body = parse_speaker_content(text)
            parts.append(body.strip() or text)
        if reply_message is not None and get_message_source(reply_message) == "user":
            reply_text = get_message_text(reply_message).strip()
            if reply_text:
                _, body = parse_speaker_content(reply_text)
                normalized_reply_text = body.strip() or reply_text
                # Avoid duplicating the reply text if it is already in the history.
                if normalized_reply_text not in parts:
                    parts.append(normalized_reply_text)
        return "\n".join(parts)

    def _is_visible_jargon(self, jargon: _JargonRecord) -> bool:
        """Return whether the current session may see this jargon entry."""
        if global_config.expression.all_global_jargon or jargon.is_global:
            return True
        try:
            session_id_dict = json.loads(jargon.session_id_dict or "{}")
        except (TypeError, json.JSONDecodeError):
            logger.warning(f"Failed to parse jargon.session_id_dict: jargon_id={jargon.jargon_id}")
            return False
        return self._session_id in session_id_dict

    @staticmethod
    def _get_jargon_match_position(content: str, corpus: str) -> Optional[int]:
        """Return the first match position of the jargon in the corpus, or `None`."""
        # CJK terms have no word boundaries, so match them as plain substrings.
        if re.search(r"[\u4e00-\u9fff]", content):
            match = re.search(re.escape(content), corpus, flags=re.IGNORECASE)
            if match is None:
                return None
            return match.start()
        # Latin-script terms must sit on word boundaries to avoid substring hits.
        pattern = rf"\b{re.escape(content)}\b"
        match = re.search(pattern, corpus, flags=re.IGNORECASE)
        if match is None:
            return None
        return match.start()