fix: 修复一些bug

This commit is contained in:
SengokuCola
2026-03-29 18:28:56 +08:00
parent 82bbf0fd52
commit 96844a9bf5
8 changed files with 898 additions and 444 deletions

View File

@@ -8,6 +8,7 @@ from typing import Any, Optional
import json
import sqlite3
from sqlalchemy import text
from sqlmodel import Session, SQLModel, create_engine, delete
ROOT_PATH = Path(__file__).resolve().parent.parent
@@ -92,11 +93,38 @@ def get_table_columns(connection: sqlite3.Connection, table_name: str) -> set[st
return {str(row["name"]) for row in rows}
def get_table_nullable_map(connection: sqlite3.Connection, table_name: str) -> dict[str, bool]:
    """Return a mapping of column name -> whether the column accepts NULL."""
    column_infos = connection.execute(f"PRAGMA table_info('{table_name}')").fetchall()
    nullable_by_column: dict[str, bool] = {}
    for info in column_infos:
        # PRAGMA table_info reports "notnull" as 1 for NOT NULL columns; invert it.
        nullable_by_column[str(info["name"])] = not bool(info["notnull"])
    return nullable_by_column
def load_rows(connection: sqlite3.Connection, table_name: str) -> list[sqlite3.Row]:
    """Fetch every row of *table_name* from the source database."""
    # Table names cannot be bound as SQL parameters, hence the f-string;
    # callers only pass hard-coded, trusted table names.
    cursor = connection.execute(f"SELECT * FROM {table_name}")
    return cursor.fetchall()
def normalize_optional_text(raw_value: Any) -> Optional[str]:
    """Normalize a nullable text field: None stays None, anything else becomes str."""
    return None if raw_value is None else str(raw_value)
def ensure_nullable_compatibility(
    table_name: str,
    column_name: str,
    row_id: Any,
    value: Any,
    nullable_map: dict[str, bool],
) -> None:
    """Raise ValueError when *value* is NULL but the target column forbids NULL.

    Columns missing from *nullable_map* are treated as nullable, so only a
    known NOT NULL column can trigger the error.
    """
    if value is not None:
        return
    if nullable_map.get(column_name, True):
        return
    raise ValueError(
        f"目标表 {table_name}.{column_name} 不允许 NULL但源记录 id={row_id} 的该字段为 NULL。"
    )
def normalize_string_list(raw_value: Any) -> list[str]:
"""将旧库中的 JSON/文本字段标准化为字符串列表。"""
if raw_value is None:
@@ -126,14 +154,55 @@ def normalize_modified_by(raw_value: Any) -> Optional[ModifiedBy]:
"""标准化审核来源字段。"""
if raw_value is None:
return None
value = str(raw_value).strip().lower()
if value == ModifiedBy.AI.value:
normalized_raw_value = raw_value
if isinstance(raw_value, str):
raw_text = raw_value.strip()
if raw_text.startswith('"') and raw_text.endswith('"'):
try:
normalized_raw_value = json.loads(raw_text)
except json.JSONDecodeError:
normalized_raw_value = raw_text
else:
normalized_raw_value = raw_text
value = str(normalized_raw_value).strip().lower()
if value in {"", "none", "null"}:
return None
if value in {ModifiedBy.AI.value, ModifiedBy.AI.name.lower()}:
return ModifiedBy.AI
if value == ModifiedBy.USER.value:
if value in {ModifiedBy.USER.value, ModifiedBy.USER.name.lower()}:
return ModifiedBy.USER
return None
def parse_optional_bool(raw_value: Any) -> Optional[bool]:
    """Parse a nullable boolean, accepting bools, numbers and common string forms.

    Returns None for NULL-like inputs; raises ValueError for unrecognized text.
    """
    if raw_value is None:
        return None
    # bool must be checked before int: bool is a subclass of int.
    if isinstance(raw_value, bool):
        return raw_value
    if isinstance(raw_value, int):
        return bool(raw_value)
    if isinstance(raw_value, float):
        # Truncate toward zero first, matching legacy semantics (0.5 -> False).
        return bool(int(raw_value))
    normalized = str(raw_value).strip().lower()
    if normalized in {"", "none", "null"}:
        return None
    if normalized in {"1", "true", "t", "yes", "y"}:
        return True
    if normalized in {"0", "false", "f", "no", "n"}:
        return False
    raise ValueError(f"无法解析布尔值:{raw_value}")
def parse_bool(raw_value: Any, default: bool = False) -> bool:
    """Parse a non-nullable boolean, substituting *default* for NULL-like values."""
    parsed = parse_optional_bool(raw_value)
    if parsed is None:
        return default
    return parsed
def timestamp_to_datetime(raw_value: Any, fallback_now: bool) -> Optional[datetime]:
"""将旧库中的 Unix 时间戳转换为 datetime。"""
if raw_value is None or raw_value == "":
@@ -212,6 +281,10 @@ def migrate_expressions(
) -> int:
"""迁移 expression 数据。"""
migrated_count = 0
modified_by_ai_count = 0
modified_by_user_count = 0
modified_by_null_count = 0
unknown_modified_by_values: dict[str, int] = {}
for row in old_rows:
create_time = timestamp_to_datetime(row["create_date"] if "create_date" in expression_columns else None, True)
last_active_time = timestamp_to_datetime(
@@ -219,22 +292,72 @@ def migrate_expressions(
True,
)
content_list = normalize_string_list(row["content_list"] if "content_list" in expression_columns else None)
raw_modified_by = row["modified_by"] if "modified_by" in expression_columns else None
modified_by = normalize_modified_by(raw_modified_by)
if modified_by == ModifiedBy.AI:
modified_by_ai_count += 1
elif modified_by == ModifiedBy.USER:
modified_by_user_count += 1
else:
modified_by_null_count += 1
if raw_modified_by not in (None, "", "null", "NULL", "None"):
unknown_key = str(raw_modified_by)
unknown_modified_by_values[unknown_key] = unknown_modified_by_values.get(unknown_key, 0) + 1
expression = Expression(
id=int(row["id"]) if row["id"] is not None else None,
situation=str(row["situation"]).strip(),
style=str(row["style"]).strip(),
content_list=json.dumps(content_list, ensure_ascii=False),
count=int(row["count"]) if "count" in expression_columns and row["count"] is not None else 1,
last_active_time=last_active_time or datetime.now(),
create_time=create_time or datetime.now(),
session_id=str(row["chat_id"]).strip() if "chat_id" in expression_columns and row["chat_id"] else None,
checked=bool(row["checked"]) if "checked" in expression_columns and row["checked"] is not None else False,
rejected=bool(row["rejected"]) if "rejected" in expression_columns and row["rejected"] is not None else False,
modified_by=normalize_modified_by(row["modified_by"] if "modified_by" in expression_columns else None),
target_session.execute(
text(
"""
INSERT INTO expressions (
id,
situation,
style,
content_list,
count,
last_active_time,
create_time,
session_id,
checked,
rejected,
modified_by
) VALUES (
:id,
:situation,
:style,
:content_list,
:count,
:last_active_time,
:create_time,
:session_id,
:checked,
:rejected,
:modified_by
)
"""
),
{
"id": int(row["id"]) if row["id"] is not None else None,
"situation": str(row["situation"]).strip(),
"style": str(row["style"]).strip(),
"content_list": json.dumps(content_list, ensure_ascii=False),
"count": int(row["count"]) if "count" in expression_columns and row["count"] is not None else 1,
"last_active_time": last_active_time or datetime.now(),
"create_time": create_time or datetime.now(),
"session_id": str(row["chat_id"]).strip() if "chat_id" in expression_columns and row["chat_id"] else None,
"checked": parse_bool(row["checked"] if "checked" in expression_columns else None, default=False),
"rejected": parse_bool(row["rejected"] if "rejected" in expression_columns else None, default=False),
"modified_by": modified_by.name if modified_by is not None else None,
},
)
target_session.add(expression)
migrated_count += 1
print(
"Expression modified_by 迁移统计:"
f" AI={modified_by_ai_count}, USER={modified_by_user_count}, NULL={modified_by_null_count}"
)
if unknown_modified_by_values:
preview_items = list(unknown_modified_by_values.items())[:10]
preview_text = ", ".join(f"{value!r} x{count}" for value, count in preview_items)
print(f"警告:以下旧 modified_by 值未识别,已按 NULL 迁移:{preview_text}")
return migrated_count
@@ -242,12 +365,17 @@ def migrate_jargons(
old_rows: Iterable[sqlite3.Row],
target_session: Session,
jargon_columns: set[str],
jargon_nullable_map: dict[str, bool],
) -> int:
"""迁移 jargon 数据。"""
migrated_count = 0
coerced_meaning_null_count = 0
for row in old_rows:
count = int(row["count"]) if "count" in jargon_columns and row["count"] is not None else 0
raw_content_list = normalize_string_list(row["raw_content"] if "raw_content" in jargon_columns else None)
raw_content_value = row["raw_content"] if "raw_content" in jargon_columns else None
raw_content_list = normalize_string_list(raw_content_value)
meaning_value = normalize_optional_text(row["meaning"] if "meaning" in jargon_columns else None)
is_jargon_value = parse_optional_bool(row["is_jargon"] if "is_jargon" in jargon_columns else None)
inference_content_key = (
"inference_content_only"
if "inference_content_only" in jargon_columns
@@ -256,35 +384,81 @@ def migrate_jargons(
else None
)
jargon = Jargon(
id=int(row["id"]) if row["id"] is not None else None,
content=str(row["content"]).strip(),
raw_content=json.dumps(raw_content_list, ensure_ascii=False),
meaning=str(row["meaning"]).strip() if "meaning" in jargon_columns and row["meaning"] is not None else "",
session_id_dict=build_session_id_dict(
row["chat_id"] if "chat_id" in jargon_columns else None,
fallback_count=count,
),
count=count,
is_jargon=bool(row["is_jargon"]) if "is_jargon" in jargon_columns and row["is_jargon"] is not None else None,
is_complete=bool(row["is_complete"]) if "is_complete" in jargon_columns and row["is_complete"] is not None else False,
is_global=bool(row["is_global"]) if "is_global" in jargon_columns and row["is_global"] is not None else False,
last_inference_count=(
int(row["last_inference_count"])
if "last_inference_count" in jargon_columns and row["last_inference_count"] is not None
else 0
),
inference_with_context=(
str(row["inference_with_context"])
if "inference_with_context" in jargon_columns and row["inference_with_context"] is not None
else None
),
inference_with_content_only=(
str(row[inference_content_key]) if inference_content_key and row[inference_content_key] is not None else None
ensure_nullable_compatibility("jargons", "is_jargon", row["id"], is_jargon_value, jargon_nullable_map)
if meaning_value is None and not jargon_nullable_map.get("meaning", True):
meaning_value = ""
coerced_meaning_null_count += 1
# 显式执行 SQL避免 ORM 在 None 场景下回填模型默认值。
target_session.execute(
text(
"""
INSERT INTO jargons (
id,
content,
raw_content,
meaning,
session_id_dict,
count,
is_jargon,
is_complete,
is_global,
last_inference_count,
inference_with_context,
inference_with_content_only
) VALUES (
:id,
:content,
:raw_content,
:meaning,
:session_id_dict,
:count,
:is_jargon,
:is_complete,
:is_global,
:last_inference_count,
:inference_with_context,
:inference_with_content_only
)
"""
),
{
"id": int(row["id"]) if row["id"] is not None else None,
"content": str(row["content"]).strip(),
"raw_content": json.dumps(raw_content_list, ensure_ascii=False) if raw_content_value is not None else None,
"meaning": meaning_value,
"session_id_dict": build_session_id_dict(
row["chat_id"] if "chat_id" in jargon_columns else None,
fallback_count=count,
),
"count": count,
"is_jargon": is_jargon_value,
"is_complete": parse_bool(row["is_complete"] if "is_complete" in jargon_columns else None, default=False),
"is_global": parse_bool(row["is_global"] if "is_global" in jargon_columns else None, default=False),
"last_inference_count": (
int(row["last_inference_count"])
if "last_inference_count" in jargon_columns and row["last_inference_count"] is not None
else 0
),
"inference_with_context": (
str(row["inference_with_context"])
if "inference_with_context" in jargon_columns and row["inference_with_context"] is not None
else None
),
"inference_with_content_only": (
str(row[inference_content_key])
if inference_content_key and row[inference_content_key] is not None
else None
),
},
)
target_session.add(jargon)
migrated_count += 1
if coerced_meaning_null_count > 0:
print(
f"警告:目标表 jargons.meaning 不允许 NULL已将 {coerced_meaning_null_count} 条旧记录的 NULL meaning 转为空字符串。"
)
return migrated_count
@@ -337,13 +511,19 @@ def main() -> None:
target_engine = create_target_engine(target_db_path)
SQLModel.metadata.create_all(target_engine)
target_sqlite_connection = connect_sqlite(target_db_path)
try:
jargon_nullable_map = get_table_nullable_map(target_sqlite_connection, "jargons")
finally:
target_sqlite_connection.close()
with Session(target_engine) as target_session:
if clear_target:
clear_target_tables(target_session)
target_session.commit()
expression_count = migrate_expressions(expression_rows, target_session, expression_columns)
jargon_count = migrate_jargons(jargon_rows, target_session, jargon_columns)
jargon_count = migrate_jargons(jargon_rows, target_session, jargon_columns, jargon_nullable_map)
target_session.commit()
print("迁移完成。")

View File

@@ -62,6 +62,24 @@
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:14:39.496335"
},
{
"id": "know_1_1774773435.68612",
"content": "用户名为小千",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:37:15.686120"
},
{
"id": "know_1_1774773676.69252",
"content": "用户自称猫娘(二次元人设)",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:41:16.692520"
}
],
"2": [
@@ -156,7 +174,17 @@
"created_at": "2026-03-29T16:13:13.481732"
}
],
"3": [],
"3": [
{
"id": "know_3_1774773676.695521",
"content": "喜欢冰淇淋",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:41:16.695521"
}
],
"4": [],
"5": [],
"6": [
@@ -384,6 +412,78 @@
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:25:30.867535"
},
{
"id": "know_6_1774773338.849271",
"content": "熟悉《原神》等二次元游戏及网络梗文化",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:35:38.849271"
},
{
"id": "know_6_1774773371.406209",
"content": "关注高分屏字体显示效果",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:36:11.406209"
},
{
"id": "know_6_1774773401.48921",
"content": "熟悉电脑显示技术(如高分屏字体选择)",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:36:41.489210"
},
{
"id": "know_6_1774773435.688119",
"content": "关注高分屏显示效果与字体选择(无衬线/衬线体)",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:37:15.688119"
},
{
"id": "know_6_1774773608.256103",
"content": "关注屏幕字体与分辨率(无衬线/有衬线)",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:40:08.256103"
},
{
"id": "know_6_1774773645.671546",
"content": "关注屏幕分辨率与字体显示效果(高分屏/无衬线体)",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:40:45.671546"
},
{
"id": "know_6_1774773676.698035",
"content": "关注字体设计(无衬线体)",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:41:16.698035"
},
{
"id": "know_6_1774773740.83822",
"content": "喜欢二次元文化及 VTuber 风格内容",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:42:20.838220"
}
],
"7": [
@@ -458,6 +558,33 @@
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:15:54.870238"
},
{
"id": "know_7_1774773185.194069",
"content": "使用 NapCat 框架",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:33:05.194069"
},
{
"id": "know_7_1774773338.851275",
"content": "使用 NapCat 框架,具备技术平台认知能力",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:35:38.851275"
},
{
"id": "know_7_1774773371.403696",
"content": "熟悉 NapCat 框架",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:36:11.403696"
}
],
"8": [
@@ -523,6 +650,24 @@
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:15:54.875743"
},
{
"id": "know_8_1774773435.690121",
"content": "习惯使用表情包表达情绪或进行网络互动",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:37:15.690121"
},
{
"id": "know_8_1774773676.701034",
"content": "备战中考",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:41:16.701034"
}
],
"9": [],
@@ -634,6 +779,69 @@
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:14:39.503333"
},
{
"id": "know_10_1774773338.853274",
"content": "沟通风格幽默风趣,擅长玩梗与自嘲",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:35:38.853274"
},
{
"id": "know_10_1774773371.408719",
"content": "喜欢用幽默调侃的方式回应他人",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:36:11.408719"
},
{
"id": "know_10_1774773401.491209",
"content": "沟通风格幽默风趣,擅长玩梗和角色扮演",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:36:41.491209"
},
{
"id": "know_10_1774773435.693121",
"content": "沟通风格幽默、喜欢玩梗和自嘲,擅长接话茬",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:37:15.693121"
},
{
"id": "know_10_1774773532.488374",
"content": "沟通风格幽默,喜欢使用网络梗和表情包活跃气氛",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:38:52.488374"
},
{
"id": "know_10_1774773532.490959",
"content": "在争论中倾向于据理力争,并自嘲或调侃对方阅读理解能力",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:38:52.490959"
},
{
"id": "know_10_1774773569.709356",
"content": "喜欢用幽默、夸张和自嘲的方式活跃气氛",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:39:29.709356"
}
],
"11": [
@@ -656,6 +864,24 @@
"source": "maisaka_learning"
},
"created_at": "2026-03-29T15:50:54.657355"
},
{
"id": "know_12_1774773185.196068",
"content": "备战中考",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:33:05.196068"
},
{
"id": "know_12_1774773740.836223",
"content": "面对压力或冲突时,倾向于通过撒娇、耍赖和寻求盟友支持来应对",
"metadata": {
"session_id": "628336b082552269377e9d0648e26c60",
"source": "maisaka_learning"
},
"created_at": "2026-03-29T16:42:20.836223"
}
]
}

View File

@@ -1,11 +1,16 @@
from dataclasses import dataclass, field
from datetime import datetime
from typing import Dict, List, Optional, Tuple
import random
import time
from sqlmodel import select
from src.chat.message_receive.chat_manager import BotChatSession
from src.chat.message_receive.message import SessionMessage
from src.common.database.database import get_db_session
from src.common.database.database_model import Expression
from src.common.data_models.reply_generation_data_models import (
GenerationMetrics,
LLMCompletionResult,
@@ -28,6 +33,23 @@ from src.maisaka.message_adapter import (
logger = get_logger("maisaka_replyer")
@dataclass
class MaisakaReplyContext:
    """Reply context used by the Maisaka replyer."""

    # Pre-formatted expression-habit prompt block; empty when none apply.
    expression_habits: str = ""
    # Database ids of the Expression rows that contributed to the habits block.
    selected_expression_ids: List[int] = field(default_factory=list)
@dataclass
class _ExpressionRecord:
    """Lightweight snapshot of an expression row, used to avoid detached ORM objects."""

    # Primary key of the source Expression row; may be None for unsaved rows.
    expression_id: Optional[int]
    situation: str
    style: str
class MaisakaReplyGenerator:
"""生成 Maisaka 的最终可见回复。"""
@@ -182,6 +204,89 @@ class MaisakaReplyGenerator:
user_prompt = "\n\n".join(user_sections)
return f"System: {system_prompt}\n\nUser: {user_prompt}"
def _resolve_session_id(self, stream_id: Optional[str]) -> str:
"""解析当前回复使用的会话 ID。"""
if stream_id:
return stream_id
if self.chat_stream is not None:
return self.chat_stream.session_id
return ""
async def _build_reply_context(
    self,
    chat_history: List[SessionMessage],
    reply_message: Optional[SessionMessage],
    reply_reason: str,
    stream_id: Optional[str],
) -> MaisakaReplyContext:
    """Build the expression-habit context inside the replyer itself."""
    session_id = self._resolve_session_id(stream_id)
    if not session_id:
        # Without a session id we cannot query session-scoped habits.
        logger.warning("Failed to build Maisaka reply context: session_id is missing")
        return MaisakaReplyContext()
    habits_block, habit_ids = self._build_expression_habits(
        session_id=session_id,
        chat_history=chat_history,
        reply_message=reply_message,
        reply_reason=reply_reason,
    )
    return MaisakaReplyContext(
        expression_habits=habits_block,
        selected_expression_ids=habit_ids,
    )
def _build_expression_habits(
    self,
    session_id: str,
    chat_history: List[SessionMessage],
    reply_message: Optional[SessionMessage],
    reply_reason: str,
) -> tuple[str, List[int]]:
    """Query and format expression habits suited to the current session.

    Returns the formatted prompt block and the ids of the selected rows;
    ("", []) when no records match.
    """
    # These inputs are currently unused; reserved for future relevance filtering.
    del chat_history
    del reply_message
    del reply_reason
    records = self._load_expression_records(session_id)
    if not records:
        return "", []
    selected_ids = [record.expression_id for record in records if record.expression_id is not None]
    habit_lines = [
        f"- 当{record.situation}时,可以自然地用{record.style}这种表达习惯。"
        for record in records
    ]
    block = "【表达习惯参考】\n" + "\n".join(habit_lines)
    logger.info(
        f"Built Maisaka expression habits: session_id={session_id} "
        f"count={len(selected_ids)} ids={selected_ids!r}"
    )
    return block, selected_ids
def _load_expression_records(self, session_id: str) -> List[_ExpressionRecord]:
    """Snapshot matching expression rows as plain data to avoid detached ORM objects.

    Selects non-rejected rows (optionally only checked ones), scoped to this
    session or global (NULL session_id), most-used and most-recent first,
    capped at 5.
    """
    with get_db_session(auto_commit=False) as session:
        statement = select(Expression).where(Expression.rejected.is_(False))  # type: ignore[attr-defined]
        if global_config.expression.expression_checked_only:
            statement = statement.where(Expression.checked.is_(True))  # type: ignore[attr-defined]
        statement = statement.where(
            (Expression.session_id == session_id) | (Expression.session_id.is_(None))  # type: ignore[attr-defined]
        ).order_by(Expression.count.desc(), Expression.last_active_time.desc())  # type: ignore[attr-defined]
        rows = session.exec(statement.limit(5)).all()
        records: List[_ExpressionRecord] = []
        for row in rows:
            records.append(
                _ExpressionRecord(
                    expression_id=row.id,
                    situation=row.situation,
                    style=row.style,
                )
            )
        return records
async def generate_reply_with_context(
self,
extra_info: str = "",
@@ -212,8 +317,6 @@ class MaisakaReplyGenerator:
del unknown_words
result = ReplyGenerationResult()
result.selected_expression_ids = list(selected_expression_ids or [])
if chat_history is None:
result.error_message = "chat_history is empty"
return False, result
@@ -221,8 +324,7 @@ class MaisakaReplyGenerator:
logger.info(
f"Maisaka replyer start: stream_id={stream_id} reply_reason={reply_reason!r} "
f"history_size={len(chat_history)} target_message_id="
f"{reply_message.message_id if reply_message else None} "
f"expression_count={len(result.selected_expression_ids)}"
f"{reply_message.message_id if reply_message else None}"
)
filtered_history = [
@@ -232,11 +334,52 @@ class MaisakaReplyGenerator:
and get_message_kind(message) != "perception"
and get_message_source(message) != "user_reference"
]
prompt = self._build_prompt(
chat_history=filtered_history,
reply_reason=reply_reason or "",
expression_habits=expression_habits,
logger.debug(f"Maisaka replyer: filtered_history size={len(filtered_history)}")
# Validate that express_model is properly initialized
if self.express_model is None:
logger.error("Maisaka replyer: express_model is None!")
result.error_message = "express_model is not initialized"
return False, result
try:
reply_context = await self._build_reply_context(
chat_history=filtered_history,
reply_message=reply_message,
reply_reason=reply_reason or "",
stream_id=stream_id,
)
except Exception as exc:
import traceback
logger.error(f"Maisaka replyer: _build_reply_context failed: {exc}\n{traceback.format_exc()}")
result.error_message = f"_build_reply_context failed: {exc}"
return False, result
merged_expression_habits = expression_habits.strip() or reply_context.expression_habits
result.selected_expression_ids = (
list(selected_expression_ids)
if selected_expression_ids is not None
else list(reply_context.selected_expression_ids)
)
logger.info(
f"Maisaka reply context built: stream_id={stream_id} "
f"selected_expression_ids={result.selected_expression_ids!r}"
)
try:
prompt = self._build_prompt(
chat_history=filtered_history,
reply_reason=reply_reason or "",
expression_habits=merged_expression_habits,
)
except Exception as exc:
import traceback
logger.error(f"Maisaka replyer: _build_prompt failed: {exc}\n{traceback.format_exc()}")
result.error_message = f"_build_prompt failed: {exc}"
return False, result
result.completion.request_prompt = prompt
if global_config.debug.show_replyer_prompt:

View File

@@ -2,7 +2,7 @@ from datetime import datetime
from enum import Enum
from typing import Optional
from sqlalchemy import Column, DateTime, Enum as SQLEnum, Float
from sqlalchemy import Column, DateTime, Enum as SQLEnum, Float, Text
from sqlmodel import Field, LargeBinary, SQLModel
@@ -17,8 +17,8 @@ class ImageType(str, Enum):
class ModifiedBy(str, Enum):
AI = "ai"
USER = "user"
AI = "AI"
USER = "USER"
class Messages(SQLModel, table=True):
@@ -223,18 +223,40 @@ class Jargon(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True) # 自增主键
content: str = Field(index=True, max_length=255) # 黑话内容
raw_content: Optional[str] = Field(default=None, nullable=True) # 原始内容未处理的黑话内容为List[str]
raw_content: Optional[str] = Field(
default=None, sa_column=Column(Text, nullable=True)
) # 原始内容未处理的黑话内容为List[str]
meaning: str # 黑话含义
session_id_dict: str = Field(default=r"{}") # 会话ID列表格式为{"session_id": session_count, ...}
meaning: str = Field(sa_column=Column(Text, nullable=False)) # 黑话含义
session_id_dict: str = Field(
default=r"{}", sa_column=Column(Text, nullable=False)
) # 会话ID列表格式为{"session_id": session_count, ...}
count: int = Field(default=0) # 使用次数
is_jargon: Optional[bool] = Field(default=True) # 是否为黑话False表示为白话
is_complete: bool = Field(default=False) # 是否为已经完成全部推断count > 100后不再推断
is_global: bool = Field(default=False) # 是否为全局黑话独立于session_id_dict
last_inference_count: int = Field(default=0) # 上一次进行推断时的count值用于判断是否需要重新推断
inference_with_context: Optional[str] = Field(default=None, nullable=True) # 带上下文的推断结果JSON格式
inference_with_content_only: Optional[str] = Field(default=None, nullable=True) # 只基于词条的推断结果JSON格式
inference_with_context: Optional[str] = Field(
default=None, sa_column=Column(Text, nullable=True)
) # 带上下文的推断结果JSON格式
inference_with_content_only: Optional[str] = Field(
default=None, sa_column=Column(Text, nullable=True)
) # 只基于词条的推断结果JSON格式
class MaiKnowledge(SQLModel, table=True):
    """Stores Maisaka's user-profile knowledge entries."""

    __tablename__ = "mai_knowledge"  # type: ignore

    # Surrogate auto-increment primary key.
    id: Optional[int] = Field(default=None, primary_key=True)
    # Stable external id (e.g. "know_<category>_<timestamp>"); indexed for lookups.
    knowledge_id: str = Field(index=True, max_length=255)
    # Knowledge category number stored as a short string key.
    category_id: str = Field(index=True, max_length=32)
    # Raw knowledge text.
    content: str
    # Whitespace-normalized copy of content; indexed so de-duplication queries are cheap.
    normalized_content: str = Field(index=True)
    # Optional JSON-serialized metadata blob; NULL when no metadata was supplied.
    metadata_json: Optional[str] = Field(default=None, nullable=True)
    # Creation timestamp, indexed to support chronological ordering.
    created_at: datetime = Field(default_factory=datetime.now, sa_column=Column(DateTime, index=True))
class ChatHistory(SQLModel, table=True):

View File

@@ -8,7 +8,11 @@ from typing import Any, Dict, List, Optional
import json
# 数据目录位于项目根目录下的 mai_knowledge
from sqlmodel import select
from src.common.database.database import DATABASE_URL, get_db_session
from src.common.database.database_model import MaiKnowledge
PROJECT_ROOT = Path(__file__).resolve().parents[2]
KNOWLEDGE_DATA_DIR = PROJECT_ROOT / "mai_knowledge"
KNOWLEDGE_FILE = KNOWLEDGE_DATA_DIR / "knowledge.json"
@@ -18,7 +22,7 @@ KNOWLEDGE_CATEGORIES = {
"1": "性别",
"2": "性格",
"3": "饮食口味",
"4": "交友",
"4": "交友",
"5": "情绪/理性倾向",
"6": "兴趣爱好",
"7": "职业/专业",
@@ -31,77 +35,128 @@ KNOWLEDGE_CATEGORIES = {
class KnowledgeStore:
"""
简单的 Maisaka 知识存储。
特性:
- 持久化到 JSON 文件
- 按分类存储用户画像类知识
- 支持基础去重
"""
"""存储 Maisaka 的用户画像知识。"""
def __init__(self) -> None:
"""初始化知识存储。"""
self._knowledge: Dict[str, List[Dict[str, Any]]] = {
category_id: [] for category_id in KNOWLEDGE_CATEGORIES
}
self._ensure_data_dir()
self._load()
"""初始化知识存储,并在需要时迁移旧版 JSON 数据"""
self._ensure_legacy_data_dir()
self._migrate_legacy_file_if_needed()
def _ensure_data_dir(self) -> None:
"""确保数据目录存在"""
def _ensure_legacy_data_dir(self) -> None:
    """Ensure the legacy knowledge directory exists, for compatibility with historical data."""
    # parents/exist_ok make this safe to call on every startup.
    KNOWLEDGE_DATA_DIR.mkdir(parents=True, exist_ok=True)
def _load(self) -> None:
"""从文件加载知识数据。"""
if not KNOWLEDGE_FILE.exists():
self._knowledge = {category_id: [] for category_id in KNOWLEDGE_CATEGORIES}
return
try:
with open(KNOWLEDGE_FILE, "r", encoding="utf-8") as file:
loaded = json.load(file)
normalized_knowledge: Dict[str, List[Dict[str, Any]]] = {
category_id: [] for category_id in KNOWLEDGE_CATEGORIES
}
for category_id in KNOWLEDGE_CATEGORIES:
category_items = loaded.get(category_id, [])
if isinstance(category_items, list):
normalized_knowledge[category_id] = [
item for item in category_items if isinstance(item, dict)
]
self._knowledge = normalized_knowledge
except Exception:
self._knowledge = {category_id: [] for category_id in KNOWLEDGE_CATEGORIES}
def _save(self) -> None:
"""保存知识数据到文件。"""
with open(KNOWLEDGE_FILE, "w", encoding="utf-8") as file:
json.dump(self._knowledge, file, ensure_ascii=False, indent=2)
@staticmethod
def _normalize_content(content: str) -> str:
"""标准化知识内容,便于去重。"""
return " ".join(str(content).strip().split())
@staticmethod
def _serialize_metadata(metadata: Optional[Dict[str, Any]]) -> Optional[str]:
"""将元数据序列化为 JSON 文本。"""
if not metadata:
return None
return json.dumps(metadata, ensure_ascii=False, sort_keys=True)
@staticmethod
def _deserialize_metadata(raw_text: Optional[str]) -> Dict[str, Any]:
"""将 JSON 文本反序列化为元数据字典。"""
if not raw_text:
return {}
try:
parsed = json.loads(raw_text)
except json.JSONDecodeError:
return {}
return parsed if isinstance(parsed, dict) else {}
@staticmethod
def _parse_created_at(raw_value: Any) -> datetime:
"""解析旧版数据中的创建时间。"""
if isinstance(raw_value, datetime):
return raw_value
if isinstance(raw_value, str):
raw_text = raw_value.strip()
if raw_text:
try:
return datetime.fromisoformat(raw_text)
except ValueError:
pass
return datetime.now()
@classmethod
def _build_item_dict(cls, record: MaiKnowledge) -> Dict[str, Any]:
    """Convert a database record into the dict shape of the legacy JSON interface."""
    metadata = cls._deserialize_metadata(record.metadata_json)
    return {
        "id": record.knowledge_id,
        "content": record.content,
        "metadata": metadata,
        "created_at": record.created_at.isoformat(),
    }
def _load_legacy_knowledge_file(self) -> Dict[str, List[Dict[str, Any]]]:
    """Read the legacy JSON knowledge file; return {} on any problem.

    Only known category ids are kept, and only list-shaped categories
    containing dict items survive normalization.
    """
    if not KNOWLEDGE_FILE.exists():
        return {}
    try:
        with open(KNOWLEDGE_FILE, "r", encoding="utf-8") as file:
            payload = json.load(file)
    except Exception:
        # Best-effort read: a corrupt legacy file simply means nothing to migrate.
        return {}
    if not isinstance(payload, dict):
        return {}
    normalized: Dict[str, List[Dict[str, Any]]] = {}
    for category_id in KNOWLEDGE_CATEGORIES:
        entries = payload.get(category_id, [])
        if isinstance(entries, list):
            normalized[category_id] = [entry for entry in entries if isinstance(entry, dict)]
    return normalized
def _migrate_legacy_file_if_needed(self) -> None:
    """Import knowledge from the legacy JSON file into the database when the DB is empty."""
    legacy_knowledge = self._load_legacy_knowledge_file()
    if not legacy_knowledge:
        return
    with get_db_session(auto_commit=False) as session:
        # Only migrate into an empty table: any existing row means a previous
        # migration (or live usage) already populated the database.
        existing_record = session.exec(select(MaiKnowledge.id).limit(1)).first()
        if existing_record is not None:
            return
        for category_id, items in legacy_knowledge.items():
            # Skip categories no longer recognized by the current schema.
            if category_id not in KNOWLEDGE_CATEGORIES:
                continue
            for item in items:
                content = self._normalize_content(str(item.get("content", "")))
                if not content:
                    continue
                metadata = item.get("metadata")
                session.add(
                    MaiKnowledge(
                        # NOTE(review): the timestamp-based fallback id could collide
                        # when several items in one category lack an id — confirm acceptable.
                        knowledge_id=str(item.get("id") or f"know_{category_id}_{datetime.now().timestamp()}"),
                        category_id=category_id,
                        content=content,
                        # Content is already normalized, so the dedup column matches it.
                        normalized_content=content,
                        metadata_json=self._serialize_metadata(metadata if isinstance(metadata, dict) else None),
                        created_at=self._parse_created_at(item.get("created_at")),
                    )
                )
        session.commit()
def add_knowledge(
self,
category_id: str,
content: str,
metadata: Optional[Dict[str, Any]] = None,
) -> bool:
"""
添加一条知识信息。
Args:
category_id: 分类编号
content: 知识内容
metadata: 附加元数据
Returns:
是否新增成功;若命中去重则返回 False
"""
"""添加一条知识信息。"""
if category_id not in KNOWLEDGE_CATEGORIES:
return False
@@ -109,29 +164,59 @@ class KnowledgeStore:
if not normalized_content:
return False
existing_items = self._knowledge.get(category_id, [])
for item in existing_items:
existing_content = self._normalize_content(str(item.get("content", "")))
if existing_content == normalized_content:
with get_db_session(auto_commit=False) as session:
existing_record = session.exec(
select(MaiKnowledge).where(
MaiKnowledge.category_id == category_id,
MaiKnowledge.normalized_content == normalized_content,
)
).first()
if existing_record is not None:
return False
knowledge_item = {
"id": f"know_{category_id}_{datetime.now().timestamp()}",
"content": normalized_content,
"metadata": metadata or {},
"created_at": datetime.now().isoformat(),
}
self._knowledge[category_id].append(knowledge_item)
self._save()
session.add(
MaiKnowledge(
knowledge_id=f"know_{category_id}_{datetime.now().timestamp()}",
category_id=category_id,
content=normalized_content,
normalized_content=normalized_content,
metadata_json=self._serialize_metadata(metadata),
created_at=datetime.now(),
)
)
session.commit()
return True
def get_category_knowledge(self, category_id: str) -> List[Dict[str, Any]]:
"""获取某个分类下的所有知识。"""
return self._knowledge.get(category_id, [])
if category_id not in KNOWLEDGE_CATEGORIES:
return []
with get_db_session() as session:
records = session.exec(
select(MaiKnowledge)
.where(MaiKnowledge.category_id == category_id)
.order_by(MaiKnowledge.created_at.asc(), MaiKnowledge.id.asc())
).all()
return [self._build_item_dict(record) for record in records]
def get_all_knowledge(self) -> Dict[str, List[Dict[str, Any]]]:
"""获取全部知识。"""
return self._knowledge
all_knowledge: Dict[str, List[Dict[str, Any]]] = {
category_id: [] for category_id in KNOWLEDGE_CATEGORIES
}
with get_db_session() as session:
records = session.exec(
select(MaiKnowledge).order_by(
MaiKnowledge.category_id.asc(),
MaiKnowledge.created_at.asc(),
MaiKnowledge.id.asc(),
)
).all()
for record in records:
all_knowledge.setdefault(record.category_id, []).append(self._build_item_dict(record))
return all_knowledge
def get_category_name(self, category_id: str) -> str:
"""获取分类名称。"""
@@ -139,24 +224,23 @@ class KnowledgeStore:
def get_categories_summary(self) -> str:
"""获取分类摘要,供模型判断是否需要检索。"""
counts: Dict[str, int] = {category_id: 0 for category_id in KNOWLEDGE_CATEGORIES}
with get_db_session() as session:
records = session.exec(select(MaiKnowledge.category_id)).all()
for category_id in records:
if category_id in counts:
counts[category_id] += 1
lines: List[str] = []
for category_id, category_name in KNOWLEDGE_CATEGORIES.items():
count = len(self._knowledge.get(category_id, []))
count = counts.get(category_id, 0)
count_text = f"{count}" if count > 0 else "无数据"
lines.append(f"{category_id}. {category_name} ({count_text})")
return "\n".join(lines)
def get_formatted_knowledge(self, category_ids: List[str], limit_per_category: int = 5) -> str:
"""
获取指定分类的格式化知识内容。
Args:
category_ids: 分类编号列表
limit_per_category: 每个分类最多返回多少条
Returns:
格式化后的知识内容
"""
"""获取指定分类的格式化知识内容。"""
parts: List[str] = []
for category_id in category_ids:
items = self.get_category_knowledge(category_id)
@@ -176,13 +260,18 @@ class KnowledgeStore:
def get_stats(self) -> Dict[str, Any]:
"""获取知识数据统计。"""
total_items = sum(len(items) for items in self._knowledge.values())
with get_db_session() as session:
total_items = len(session.exec(select(MaiKnowledge.id)).all())
return {
"total_categories": len(KNOWLEDGE_CATEGORIES),
"total_items": total_items,
"data_file": str(KNOWLEDGE_FILE),
"data_exists": KNOWLEDGE_FILE.exists(),
"data_size_kb": KNOWLEDGE_FILE.stat().st_size / 1024 if KNOWLEDGE_FILE.exists() else 0,
"data_file": DATABASE_URL,
"data_exists": True,
"data_size_kb": 0,
"legacy_data_file": str(KNOWLEDGE_FILE),
"legacy_data_exists": KNOWLEDGE_FILE.exists(),
"storage_type": "database",
}

View File

@@ -30,6 +30,7 @@ from .builtin_tools import get_builtin_tools
from .message_adapter import (
build_message,
format_speaker_content,
get_message_role,
to_llm_message,
)
@@ -303,6 +304,7 @@ class MaisakaChatLoopService:
async def chat_loop_step(self, chat_history: List[SessionMessage]) -> ChatResponse:
await self.ensure_chat_prompt_loaded()
selected_history, selection_reason = self._select_llm_context_messages(chat_history)
def message_factory(_client: BaseClient) -> List[Message]:
messages: List[Message] = []
@@ -310,7 +312,7 @@ class MaisakaChatLoopService:
system_msg.add_text_content(self._chat_system_prompt)
messages.append(system_msg.build())
for msg in chat_history:
for msg in selected_history:
llm_message = to_llm_message(msg)
if llm_message is not None:
messages.append(llm_message)
@@ -333,6 +335,7 @@ class MaisakaChatLoopService:
Panel(
Group(*ordered_panels),
title="MaiSaka LLM Request - chat_loop_step",
subtitle=selection_reason,
border_style="cyan",
padding=(0, 1),
)
@@ -374,6 +377,38 @@ class MaisakaChatLoopService:
raw_message=raw_message,
)
@staticmethod
def _select_llm_context_messages(chat_history: List[SessionMessage]) -> tuple[List[SessionMessage], str]:
    """Select the subset of messages actually sent to the LLM as context.

    Scans the history backwards, keeping every convertible message, and stops
    once `max_context_size` user/assistant messages have been collected.
    Non-user/assistant messages (e.g. tool/system entries) inside the window
    are kept but do not count toward the limit.

    Args:
        chat_history: Full session history, oldest first.

    Returns:
        A tuple of (selected messages in original order, a human-readable
        reason string describing the selection, used as a display subtitle).
    """
    # Clamp to at least 1 so a misconfigured value never empties the context.
    max_context_size = max(1, int(global_config.chat.max_context_size))
    counted_roles = {"user", "assistant"}
    selected_indices: List[int] = []
    counted_message_count = 0
    # Walk from newest to oldest so the most recent messages win.
    for index in range(len(chat_history) - 1, -1, -1):
        message = chat_history[index]
        # Skip messages that cannot be converted into an LLM message at all.
        if to_llm_message(message) is None:
            continue
        selected_indices.append(index)
        if get_message_role(message) in counted_roles:
            counted_message_count += 1
            # Stop only after a counted (user/assistant) message fills the quota.
            if counted_message_count >= max_context_size:
                break
    if not selected_indices:
        return [], f"上下文判定:最近 {max_context_size} 条 user/assistant(当前 0 条)"
    # Indices were collected newest-first; restore chronological order.
    selected_indices.reverse()
    selected_history = [chat_history[index] for index in selected_indices]
    return (
        selected_history,
        (
            f"上下文判定:最近 {max_context_size} 条 user/assistant;"
            f"展示并发送窗口内消息 {len(selected_history)} 条"
        ),
    )
@staticmethod
def build_chat_context(user_text: str) -> List[SessionMessage]:
return [

View File

@@ -14,6 +14,7 @@ from sqlmodel import select
from src.chat.heart_flow.heartFC_utils import CycleDetail
from src.chat.message_receive.message import SessionMessage
from src.chat.replyer.replyer_manager import replyer_manager
from src.chat.utils.utils import get_bot_account
from src.common.database.database import get_db_session
from src.common.database.database_model import Jargon
from src.common.data_models.mai_message_data_model import UserInfo
@@ -33,7 +34,6 @@ from .message_adapter import (
get_message_text,
get_message_role,
)
from .reply_context_builder import MaisakaReplyContextBuilder
from .tool_handlers import (
handle_mcp_tool,
handle_unknown_tool,
@@ -50,8 +50,8 @@ class MaisakaReasoningEngine:
def __init__(self, runtime: "MaisakaHeartFlowChatting") -> None:
self._runtime = runtime
self._reply_context_builder = MaisakaReplyContextBuilder(runtime.session_id)
self._last_reasoning_content: str = ""
self._shown_jargons: set[str] = set() # 已在参考消息中展示过的 jargon
async def run_loop(self) -> None:
"""独立消费消息批次,并执行对应的内部思考轮次。"""
@@ -72,11 +72,19 @@ class MaisakaReasoningEngine:
self._runtime._log_cycle_started(cycle_detail, round_index)
try:
# 每次LLM生成前动态添加参考消息到最新位置
self._append_jargon_reference_message()
reference_added = self._append_jargon_reference_message()
planner_started_at = time.time()
response = await self._runtime._chat_loop_service.chat_loop_step(self._runtime._chat_history)
cycle_detail.time_records["planner"] = time.time() - planner_started_at
# LLM调用后移除刚才添加的参考消息一次性使用
if reference_added and self._runtime._chat_history:
# 从末尾往前查找并移除参考消息
for i in range(len(self._runtime._chat_history) - 1, -1, -1):
if get_message_source(self._runtime._chat_history[i]) == "user_reference":
self._runtime._chat_history.pop(i)
break
reasoning_content = response.content or ""
if self._should_replace_reasoning(reasoning_content):
response.content = "让我根据新情况重新思考:"
@@ -218,15 +226,23 @@ class MaisakaReasoningEngine:
self._runtime._chat_history.insert(insert_at, message)
return insert_at
def _append_jargon_reference_message(self) -> None:
"""每次LLM生成前如果命中了黑话词条则添加一条参考信息消息到聊天历史末尾。"""
def _append_jargon_reference_message(self) -> bool:
"""每次LLM生成前如果命中了黑话词条则添加一条参考信息消息到聊天历史末尾。
Returns:
bool: 是否添加了参考消息
"""
content = self._build_user_history_corpus()
if not content:
return
return False
matched_words = self._find_jargon_words_in_text(content)
if not matched_words:
return
return False
# 记录已展示的 jargon
for word in matched_words:
self._shown_jargons.add(word.lower())
reference_text = (
"[参考信息]\n"
@@ -248,6 +264,7 @@ class MaisakaReasoningEngine:
display_text=reference_text,
)
self._runtime._chat_history.append(reference_message)
return True
def _build_user_history_corpus(self) -> str:
"""拼接当前聊天记录内所有用户消息的正文,用于统一匹配黑话。"""
@@ -282,9 +299,15 @@ class MaisakaReasoningEngine:
jargon_content = str(jargon.content or "").strip()
if not jargon_content:
continue
# meaning 为空的不匹配
if not str(jargon.meaning or "").strip():
continue
normalized_content = jargon_content.lower()
if normalized_content in seen_words:
continue
# 跳过已经展示过的 jargon
if normalized_content in self._shown_jargons:
continue
if not self._is_visible_jargon(jargon):
continue
match_position = self._get_jargon_match_position(jargon_content, lowered_content, content)
@@ -573,34 +596,8 @@ class MaisakaReasoningEngine:
return False
logger.info(f"{self._runtime.log_prefix} acquired Maisaka reply generator successfully")
logger.info(
f"{self._runtime.log_prefix} building reply context: "
f"target_msg_id={target_message_id} unknown_words={unknown_words!r}"
)
try:
reply_context = await self._reply_context_builder.build(
chat_history=self._runtime._chat_history,
reply_message=target_message,
reply_reason=latest_thought,
)
except Exception:
logger.exception(
f"{self._runtime.log_prefix} reply context builder crashed: "
f"target_msg_id={target_message_id}"
)
self._runtime._chat_history.append(
self._build_tool_message(tool_call, "Reply context preparation crashed.")
)
return False
logger.info(
f"{self._runtime.log_prefix} reply context built: "
f"target_msg_id={target_message_id} "
f"selected_expression_ids={reply_context.selected_expression_ids!r} "
f"has_jargon_explanation={bool(reply_context.jargon_explanation.strip())}"
)
logger.info(f"{self._runtime.log_prefix} calling generate_reply_with_context: target_msg_id={target_message_id}")
try:
success, reply_result = await replyer.generate_reply_with_context(
reply_reason=latest_thought,
@@ -609,11 +606,13 @@ class MaisakaReasoningEngine:
chat_history=self._runtime._chat_history,
unknown_words=unknown_words,
log_reply=False,
expression_habits=reply_context.expression_habits,
selected_expression_ids=reply_context.selected_expression_ids,
)
except Exception:
logger.exception(f"{self._runtime.log_prefix} reply generator crashed: target_msg_id={target_message_id}")
except Exception as exc:
import traceback
logger.error(
f"{self._runtime.log_prefix} reply generator crashed: target_msg_id={target_message_id} "
f"exc_type={type(exc).__name__} exc_msg={str(exc)}\n{traceback.format_exc()}"
)
self._runtime._chat_history.append(
self._build_tool_message(tool_call, "Visible reply generation crashed.")
)
@@ -686,18 +685,26 @@ class MaisakaReasoningEngine:
tool_reasoning=latest_thought,
)
target_platform = target_message.platform or anchor_message.platform
bot_name = global_config.bot.nickname.strip() or "MaiSaka"
self._runtime._chat_history.append(
build_message(
role="user",
content=format_speaker_content(bot_name, reply_text, datetime.now()),
source="guided_reply",
platform=target_message.platform or anchor_message.platform,
session_id=self._runtime.session_id,
group_info=self._runtime._build_group_info(target_message),
user_info=self._runtime._build_runtime_user_info(),
)
bot_user_info = UserInfo(
user_id=get_bot_account(target_platform) or "maisaka_assistant",
user_nickname=bot_name,
user_cardname=None,
)
history_message = build_message(
role="assistant",
content=reply_text,
source="guided_reply",
platform=target_platform,
session_id=self._runtime.session_id,
group_info=self._runtime._build_group_info(target_message),
user_info=bot_user_info,
)
structured_visible_text = f"{self._build_planner_user_prefix(history_message)}{reply_text}"
history_message.display_message = structured_visible_text
history_message.processed_plain_text = structured_visible_text
self._runtime._chat_history.append(history_message)
return True
async def _handle_send_emoji(self, tool_call: ToolCall, anchor_message: SessionMessage) -> None:

View File

@@ -1,248 +0,0 @@
from __future__ import annotations
from dataclasses import dataclass, field
from typing import List, Optional
import json
import re
from sqlmodel import select
from src.chat.message_receive.message import SessionMessage
from src.common.database.database import get_db_session
from src.common.database.database_model import Expression, Jargon
from src.common.logger import get_logger
from src.config.config import global_config
from .message_adapter import get_message_role, get_message_source, get_message_text, parse_speaker_content
logger = get_logger("maisaka_reply_context")
@dataclass
class ReplyContextBuildResult:
    """Result of building the pre-reply context for a Maisaka reply."""

    # Formatted "expression habits" reference block ("" when none matched).
    expression_habits: str = ""
    # Formatted jargon-explanation block ("" when disabled or nothing matched).
    jargon_explanation: str = ""
    # Database ids of the Expression rows included in expression_habits.
    selected_expression_ids: List[int] = field(default_factory=list)
@dataclass
class _ExpressionRecord:
    """Detached snapshot of an Expression row (avoids detached-ORM access)."""

    # Primary key of the Expression row; None for unsaved records.
    expression_id: Optional[int]
    # The situation in which the expression applies.
    situation: str
    # The expression style/wording to use in that situation.
    style: str
@dataclass
class _JargonRecord:
    """Detached snapshot of a Jargon row (avoids detached-ORM access)."""

    # Primary key of the Jargon row; None for unsaved records.
    jargon_id: Optional[int]
    # The jargon term itself (stripped).
    content: str
    # Usage count, used to rank matches.
    count: int
    # Explanation of the jargon (stripped; empty means unexplained).
    meaning: str
    # JSON-encoded mapping of session ids where this jargon is visible.
    session_id_dict: str
    # Whether the jargon is visible in every session.
    is_global: bool
class MaisakaReplyContextBuilder:
    """Builds expression-habit and jargon-explanation context for Maisaka replies."""

    def __init__(self, session_id: str) -> None:
        # Session this builder serves; used for visibility filtering and logging.
        self._session_id = session_id

    async def build(
        self,
        chat_history: List[SessionMessage],
        reply_message: Optional[SessionMessage],
        reply_reason: str,
    ) -> ReplyContextBuildResult:
        """Build the pre-reply context (expression habits + jargon explanation).

        Args:
            chat_history: Current session history used for jargon matching.
            reply_message: The message being replied to, if any.
            reply_reason: The reasoning text that triggered this reply.

        Returns:
            A ReplyContextBuildResult with formatted blocks and selected ids.
        """
        expression_habits, selected_expression_ids = self._build_expression_habits(
            chat_history=chat_history,
            reply_message=reply_message,
            reply_reason=reply_reason,
        )
        jargon_explanation = self._build_jargon_explanation(
            chat_history=chat_history,
            reply_message=reply_message,
        )
        return ReplyContextBuildResult(
            expression_habits=expression_habits,
            jargon_explanation=jargon_explanation,
            selected_expression_ids=selected_expression_ids,
        )

    def _build_expression_habits(
        self,
        chat_history: List[SessionMessage],
        reply_message: Optional[SessionMessage],
        reply_reason: str,
    ) -> tuple[str, List[int]]:
        """Query and format the expression habits suited to this session.

        Returns:
            (formatted block, list of Expression ids included). Both empty
            when no expression records are available.
        """
        # NOTE(review): these parameters are currently unused — selection is
        # purely DB-ranked; kept for signature stability with future logic.
        del chat_history
        del reply_message
        del reply_reason
        expression_records = self._load_expression_records()
        if not expression_records:
            return "", []
        lines: List[str] = []
        selected_ids: List[int] = []
        for expression in expression_records:
            if expression.expression_id is not None:
                selected_ids.append(expression.expression_id)
            lines.append(f"- 当{expression.situation}时,可以自然地用{expression.style}这种表达习惯。")
        block = "【表达习惯参考】\n" + "\n".join(lines)
        logger.info(
            f"Built Maisaka expression habits: session_id={self._session_id} "
            f"count={len(selected_ids)} ids={selected_ids!r}"
        )
        return block, selected_ids

    def _load_expression_records(self) -> List[_ExpressionRecord]:
        """Extract static expression data inside the session to avoid detached ORM objects."""
        with get_db_session(auto_commit=False) as session:
            # Exclude rejected expressions; optionally require human-checked ones.
            query = select(Expression).where(Expression.rejected.is_(False))  # type: ignore[attr-defined]
            if global_config.expression.expression_checked_only:
                query = query.where(Expression.checked.is_(True))  # type: ignore[attr-defined]
            # Session-specific rows plus global (NULL session) rows, most-used first.
            query = query.where(
                (Expression.session_id == self._session_id) | (Expression.session_id.is_(None))  # type: ignore[attr-defined]
            ).order_by(Expression.count.desc(), Expression.last_active_time.desc())  # type: ignore[attr-defined]
            expressions = session.exec(query.limit(5)).all()
            # Copy plain values out while the session is still open.
            return [
                _ExpressionRecord(
                    expression_id=expression.id,
                    situation=expression.situation,
                    style=expression.style,
                )
                for expression in expressions
            ]

    def _build_jargon_explanation(
        self,
        chat_history: List[SessionMessage],
        reply_message: Optional[SessionMessage],
    ) -> str:
        """Query and format jargon explanations; "" when the feature is disabled."""
        if not global_config.expression.enable_jargon_explanation:
            return ""
        return self._build_context_jargon_explanation(chat_history, reply_message)

    def _build_context_jargon_explanation(
        self,
        chat_history: List[SessionMessage],
        reply_message: Optional[SessionMessage],
    ) -> str:
        """Automatically match jargon terms against the current context.

        Returns:
            A formatted explanation block, or "" when nothing matched.
        """
        corpus = self._build_context_corpus(chat_history, reply_message)
        if not corpus:
            return ""
        jargon_records = self._load_jargon_records()
        # Tuples sort by (first match position, longer term first, higher count first).
        matched_records: List[tuple[int, int, int, _JargonRecord]] = []
        seen_contents: set[str] = set()
        for jargon in jargon_records:
            # Skip terms without content or without an explanation to show.
            if not jargon.content or not jargon.meaning:
                continue
            normalized_content = jargon.content.lower()
            # De-duplicate case-insensitively.
            if normalized_content in seen_contents:
                continue
            if not self._is_visible_jargon(jargon):
                continue
            match_position = self._get_jargon_match_position(jargon.content, corpus)
            if match_position is None:
                continue
            seen_contents.add(normalized_content)
            # Negated length/count so ascending sort prefers longer, more frequent terms.
            matched_records.append((match_position, -len(jargon.content), -jargon.count, jargon))
        matched_records.sort()
        # Cap the block at 8 entries to keep the prompt compact.
        lines = [f"- {jargon.content}: {jargon.meaning}" for _, _, _, jargon in matched_records[:8]]
        if not lines:
            return ""
        logger.info(
            f"Built Maisaka jargon explanation: session_id={self._session_id} "
            f"count={len(lines)}"
        )
        return "【黑话解释】\n" + "\n".join(lines)

    def _load_jargon_records(self) -> List[_JargonRecord]:
        """Extract static jargon data inside the session to avoid detached ORM objects."""
        with get_db_session(auto_commit=False) as session:
            # Only confirmed jargon entries that already have a meaning.
            query = select(Jargon).where(Jargon.is_jargon.is_(True), Jargon.meaning != "")  # type: ignore[attr-defined]
            query = query.order_by(Jargon.count.desc())  # type: ignore[attr-defined]
            jargons = session.exec(query).all()
            # Copy plain values out while the session is still open.
            return [
                _JargonRecord(
                    jargon_id=jargon.id,
                    content=(jargon.content or "").strip(),
                    count=int(jargon.count or 0),
                    meaning=(jargon.meaning or "").strip(),
                    session_id_dict=jargon.session_id_dict or "{}",
                    is_global=bool(jargon.is_global),
                )
                for jargon in jargons
            ]

    def _build_context_corpus(
        self,
        chat_history: List[SessionMessage],
        reply_message: Optional[SessionMessage],
    ) -> str:
        """Join the bodies of all user messages in the history into one matchable text."""
        parts: List[str] = []
        for message in chat_history:
            # Only genuine user-authored messages participate in matching.
            if get_message_role(message) != "user":
                continue
            if get_message_source(message) != "user":
                continue
            text = get_message_text(message).strip()
            if not text:
                continue
            # Strip the "speaker:" prefix; fall back to the raw text if empty.
            _, body = parse_speaker_content(text)
            parts.append(body.strip() or text)
        # Append the reply target's body too, unless it is already present.
        if reply_message is not None and get_message_source(reply_message) == "user":
            reply_text = get_message_text(reply_message).strip()
            if reply_text:
                _, body = parse_speaker_content(reply_text)
                normalized_reply_text = body.strip() or reply_text
                if normalized_reply_text not in parts:
                    parts.append(normalized_reply_text)
        return "\n".join(parts)

    def _is_visible_jargon(self, jargon: _JargonRecord) -> bool:
        """Return True when this session is allowed to see the given jargon."""
        # Global flags bypass the per-session visibility map entirely.
        if global_config.expression.all_global_jargon or jargon.is_global:
            return True
        try:
            session_id_dict = json.loads(jargon.session_id_dict or "{}")
        except (TypeError, json.JSONDecodeError):
            # Malformed visibility map: fail closed and log for follow-up.
            logger.warning(f"Failed to parse jargon.session_id_dict: jargon_id={jargon.jargon_id}")
            return False
        return self._session_id in session_id_dict

    @staticmethod
    def _get_jargon_match_position(content: str, corpus: str) -> Optional[int]:
        """Return the first match offset of the jargon in the corpus, or `None`.

        CJK terms use a plain substring search (no word boundaries in Chinese);
        other terms require word boundaries to avoid partial-word hits.
        """
        if re.search(r"[\u4e00-\u9fff]", content):
            match = re.search(re.escape(content), corpus, flags=re.IGNORECASE)
            if match is None:
                return None
            return match.start()
        pattern = rf"\b{re.escape(content)}\b"
        match = re.search(pattern, corpus, flags=re.IGNORECASE)
        if match is None:
            return None
        return match.start()