feat:修复一些bug

This commit is contained in:
SengokuCola
2026-03-29 18:28:56 +08:00
parent 82bbf0fd52
commit 96844a9bf5
8 changed files with 898 additions and 444 deletions

View File

@@ -1,11 +1,16 @@
from dataclasses import dataclass, field
from datetime import datetime
from typing import Dict, List, Optional, Tuple
import random
import time
from sqlmodel import select
from src.chat.message_receive.chat_manager import BotChatSession
from src.chat.message_receive.message import SessionMessage
from src.common.database.database import get_db_session
from src.common.database.database_model import Expression
from src.common.data_models.reply_generation_data_models import (
GenerationMetrics,
LLMCompletionResult,
@@ -28,6 +33,23 @@ from src.maisaka.message_adapter import (
logger = get_logger("maisaka_replyer")
@dataclass
class MaisakaReplyContext:
"""Maisaka replyer 使用的回复上下文。"""
expression_habits: str = ""
selected_expression_ids: List[int] = field(default_factory=list)
@dataclass
class _ExpressionRecord:
"""表达方式的轻量记录。"""
expression_id: Optional[int]
situation: str
style: str
class MaisakaReplyGenerator:
"""生成 Maisaka 的最终可见回复。"""
@@ -182,6 +204,89 @@ class MaisakaReplyGenerator:
user_prompt = "\n\n".join(user_sections)
return f"System: {system_prompt}\n\nUser: {user_prompt}"
def _resolve_session_id(self, stream_id: Optional[str]) -> str:
"""解析当前回复使用的会话 ID。"""
if stream_id:
return stream_id
if self.chat_stream is not None:
return self.chat_stream.session_id
return ""
async def _build_reply_context(
self,
chat_history: List[SessionMessage],
reply_message: Optional[SessionMessage],
reply_reason: str,
stream_id: Optional[str],
) -> MaisakaReplyContext:
"""在 replyer 内部构建表达习惯和黑话解释。"""
session_id = self._resolve_session_id(stream_id)
if not session_id:
logger.warning("Failed to build Maisaka reply context: session_id is missing")
return MaisakaReplyContext()
expression_habits, selected_expression_ids = self._build_expression_habits(
session_id=session_id,
chat_history=chat_history,
reply_message=reply_message,
reply_reason=reply_reason,
)
return MaisakaReplyContext(
expression_habits=expression_habits,
selected_expression_ids=selected_expression_ids,
)
def _build_expression_habits(
self,
session_id: str,
chat_history: List[SessionMessage],
reply_message: Optional[SessionMessage],
reply_reason: str,
) -> tuple[str, List[int]]:
"""查询并格式化适合当前会话的表达习惯。"""
del chat_history
del reply_message
del reply_reason
expression_records = self._load_expression_records(session_id)
if not expression_records:
return "", []
lines: List[str] = []
selected_ids: List[int] = []
for expression in expression_records:
if expression.expression_id is not None:
selected_ids.append(expression.expression_id)
lines.append(f"- 当{expression.situation}时,可以自然地用{expression.style}这种表达习惯。")
block = "【表达习惯参考】\n" + "\n".join(lines)
logger.info(
f"Built Maisaka expression habits: session_id={session_id} "
f"count={len(selected_ids)} ids={selected_ids!r}"
)
return block, selected_ids
def _load_expression_records(self, session_id: str) -> List[_ExpressionRecord]:
"""提取表达方式静态数据,避免 detached ORM 对象。"""
with get_db_session(auto_commit=False) as session:
query = select(Expression).where(Expression.rejected.is_(False)) # type: ignore[attr-defined]
if global_config.expression.expression_checked_only:
query = query.where(Expression.checked.is_(True)) # type: ignore[attr-defined]
query = query.where(
(Expression.session_id == session_id) | (Expression.session_id.is_(None)) # type: ignore[attr-defined]
).order_by(Expression.count.desc(), Expression.last_active_time.desc()) # type: ignore[attr-defined]
expressions = session.exec(query.limit(5)).all()
return [
_ExpressionRecord(
expression_id=expression.id,
situation=expression.situation,
style=expression.style,
)
for expression in expressions
]
async def generate_reply_with_context(
self,
extra_info: str = "",
@@ -212,8 +317,6 @@ class MaisakaReplyGenerator:
del unknown_words
result = ReplyGenerationResult()
result.selected_expression_ids = list(selected_expression_ids or [])
if chat_history is None:
result.error_message = "chat_history is empty"
return False, result
@@ -221,8 +324,7 @@ class MaisakaReplyGenerator:
logger.info(
f"Maisaka replyer start: stream_id={stream_id} reply_reason={reply_reason!r} "
f"history_size={len(chat_history)} target_message_id="
f"{reply_message.message_id if reply_message else None} "
f"expression_count={len(result.selected_expression_ids)}"
f"{reply_message.message_id if reply_message else None}"
)
filtered_history = [
@@ -232,11 +334,52 @@ class MaisakaReplyGenerator:
and get_message_kind(message) != "perception"
and get_message_source(message) != "user_reference"
]
prompt = self._build_prompt(
chat_history=filtered_history,
reply_reason=reply_reason or "",
expression_habits=expression_habits,
logger.debug(f"Maisaka replyer: filtered_history size={len(filtered_history)}")
# Validate that express_model is properly initialized
if self.express_model is None:
logger.error("Maisaka replyer: express_model is None!")
result.error_message = "express_model is not initialized"
return False, result
try:
reply_context = await self._build_reply_context(
chat_history=filtered_history,
reply_message=reply_message,
reply_reason=reply_reason or "",
stream_id=stream_id,
)
except Exception as exc:
import traceback
logger.error(f"Maisaka replyer: _build_reply_context failed: {exc}\n{traceback.format_exc()}")
result.error_message = f"_build_reply_context failed: {exc}"
return False, result
merged_expression_habits = expression_habits.strip() or reply_context.expression_habits
result.selected_expression_ids = (
list(selected_expression_ids)
if selected_expression_ids is not None
else list(reply_context.selected_expression_ids)
)
logger.info(
f"Maisaka reply context built: stream_id={stream_id} "
f"selected_expression_ids={result.selected_expression_ids!r}"
)
try:
prompt = self._build_prompt(
chat_history=filtered_history,
reply_reason=reply_reason or "",
expression_habits=merged_expression_habits,
)
except Exception as exc:
import traceback
logger.error(f"Maisaka replyer: _build_prompt failed: {exc}\n{traceback.format_exc()}")
result.error_message = f"_build_prompt failed: {exc}"
return False, result
result.completion.request_prompt = prompt
if global_config.debug.show_replyer_prompt: