feat:修复一些bug
This commit is contained in:
@@ -1,11 +1,16 @@
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import random
|
||||
import time
|
||||
|
||||
from sqlmodel import select
|
||||
|
||||
from src.chat.message_receive.chat_manager import BotChatSession
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import Expression
|
||||
from src.common.data_models.reply_generation_data_models import (
|
||||
GenerationMetrics,
|
||||
LLMCompletionResult,
|
||||
@@ -28,6 +33,23 @@ from src.maisaka.message_adapter import (
|
||||
logger = get_logger("maisaka_replyer")
|
||||
|
||||
|
||||
@dataclass
|
||||
class MaisakaReplyContext:
|
||||
"""Maisaka replyer 使用的回复上下文。"""
|
||||
|
||||
expression_habits: str = ""
|
||||
selected_expression_ids: List[int] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _ExpressionRecord:
|
||||
"""表达方式的轻量记录。"""
|
||||
|
||||
expression_id: Optional[int]
|
||||
situation: str
|
||||
style: str
|
||||
|
||||
|
||||
class MaisakaReplyGenerator:
|
||||
"""生成 Maisaka 的最终可见回复。"""
|
||||
|
||||
@@ -182,6 +204,89 @@ class MaisakaReplyGenerator:
|
||||
user_prompt = "\n\n".join(user_sections)
|
||||
return f"System: {system_prompt}\n\nUser: {user_prompt}"
|
||||
|
||||
def _resolve_session_id(self, stream_id: Optional[str]) -> str:
|
||||
"""解析当前回复使用的会话 ID。"""
|
||||
if stream_id:
|
||||
return stream_id
|
||||
if self.chat_stream is not None:
|
||||
return self.chat_stream.session_id
|
||||
return ""
|
||||
|
||||
async def _build_reply_context(
|
||||
self,
|
||||
chat_history: List[SessionMessage],
|
||||
reply_message: Optional[SessionMessage],
|
||||
reply_reason: str,
|
||||
stream_id: Optional[str],
|
||||
) -> MaisakaReplyContext:
|
||||
"""在 replyer 内部构建表达习惯和黑话解释。"""
|
||||
session_id = self._resolve_session_id(stream_id)
|
||||
if not session_id:
|
||||
logger.warning("Failed to build Maisaka reply context: session_id is missing")
|
||||
return MaisakaReplyContext()
|
||||
|
||||
expression_habits, selected_expression_ids = self._build_expression_habits(
|
||||
session_id=session_id,
|
||||
chat_history=chat_history,
|
||||
reply_message=reply_message,
|
||||
reply_reason=reply_reason,
|
||||
)
|
||||
return MaisakaReplyContext(
|
||||
expression_habits=expression_habits,
|
||||
selected_expression_ids=selected_expression_ids,
|
||||
)
|
||||
|
||||
def _build_expression_habits(
|
||||
self,
|
||||
session_id: str,
|
||||
chat_history: List[SessionMessage],
|
||||
reply_message: Optional[SessionMessage],
|
||||
reply_reason: str,
|
||||
) -> tuple[str, List[int]]:
|
||||
"""查询并格式化适合当前会话的表达习惯。"""
|
||||
del chat_history
|
||||
del reply_message
|
||||
del reply_reason
|
||||
|
||||
expression_records = self._load_expression_records(session_id)
|
||||
if not expression_records:
|
||||
return "", []
|
||||
|
||||
lines: List[str] = []
|
||||
selected_ids: List[int] = []
|
||||
for expression in expression_records:
|
||||
if expression.expression_id is not None:
|
||||
selected_ids.append(expression.expression_id)
|
||||
lines.append(f"- 当{expression.situation}时,可以自然地用{expression.style}这种表达习惯。")
|
||||
|
||||
block = "【表达习惯参考】\n" + "\n".join(lines)
|
||||
logger.info(
|
||||
f"Built Maisaka expression habits: session_id={session_id} "
|
||||
f"count={len(selected_ids)} ids={selected_ids!r}"
|
||||
)
|
||||
return block, selected_ids
|
||||
|
||||
def _load_expression_records(self, session_id: str) -> List[_ExpressionRecord]:
|
||||
"""提取表达方式静态数据,避免 detached ORM 对象。"""
|
||||
with get_db_session(auto_commit=False) as session:
|
||||
query = select(Expression).where(Expression.rejected.is_(False)) # type: ignore[attr-defined]
|
||||
if global_config.expression.expression_checked_only:
|
||||
query = query.where(Expression.checked.is_(True)) # type: ignore[attr-defined]
|
||||
|
||||
query = query.where(
|
||||
(Expression.session_id == session_id) | (Expression.session_id.is_(None)) # type: ignore[attr-defined]
|
||||
).order_by(Expression.count.desc(), Expression.last_active_time.desc()) # type: ignore[attr-defined]
|
||||
|
||||
expressions = session.exec(query.limit(5)).all()
|
||||
return [
|
||||
_ExpressionRecord(
|
||||
expression_id=expression.id,
|
||||
situation=expression.situation,
|
||||
style=expression.style,
|
||||
)
|
||||
for expression in expressions
|
||||
]
|
||||
|
||||
async def generate_reply_with_context(
|
||||
self,
|
||||
extra_info: str = "",
|
||||
@@ -212,8 +317,6 @@ class MaisakaReplyGenerator:
|
||||
del unknown_words
|
||||
|
||||
result = ReplyGenerationResult()
|
||||
result.selected_expression_ids = list(selected_expression_ids or [])
|
||||
|
||||
if chat_history is None:
|
||||
result.error_message = "chat_history is empty"
|
||||
return False, result
|
||||
@@ -221,8 +324,7 @@ class MaisakaReplyGenerator:
|
||||
logger.info(
|
||||
f"Maisaka replyer start: stream_id={stream_id} reply_reason={reply_reason!r} "
|
||||
f"history_size={len(chat_history)} target_message_id="
|
||||
f"{reply_message.message_id if reply_message else None} "
|
||||
f"expression_count={len(result.selected_expression_ids)}"
|
||||
f"{reply_message.message_id if reply_message else None}"
|
||||
)
|
||||
|
||||
filtered_history = [
|
||||
@@ -232,11 +334,52 @@ class MaisakaReplyGenerator:
|
||||
and get_message_kind(message) != "perception"
|
||||
and get_message_source(message) != "user_reference"
|
||||
]
|
||||
prompt = self._build_prompt(
|
||||
chat_history=filtered_history,
|
||||
reply_reason=reply_reason or "",
|
||||
expression_habits=expression_habits,
|
||||
|
||||
logger.debug(f"Maisaka replyer: filtered_history size={len(filtered_history)}")
|
||||
|
||||
# Validate that express_model is properly initialized
|
||||
if self.express_model is None:
|
||||
logger.error("Maisaka replyer: express_model is None!")
|
||||
result.error_message = "express_model is not initialized"
|
||||
return False, result
|
||||
|
||||
try:
|
||||
reply_context = await self._build_reply_context(
|
||||
chat_history=filtered_history,
|
||||
reply_message=reply_message,
|
||||
reply_reason=reply_reason or "",
|
||||
stream_id=stream_id,
|
||||
)
|
||||
except Exception as exc:
|
||||
import traceback
|
||||
logger.error(f"Maisaka replyer: _build_reply_context failed: {exc}\n{traceback.format_exc()}")
|
||||
result.error_message = f"_build_reply_context failed: {exc}"
|
||||
return False, result
|
||||
|
||||
merged_expression_habits = expression_habits.strip() or reply_context.expression_habits
|
||||
result.selected_expression_ids = (
|
||||
list(selected_expression_ids)
|
||||
if selected_expression_ids is not None
|
||||
else list(reply_context.selected_expression_ids)
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Maisaka reply context built: stream_id={stream_id} "
|
||||
f"selected_expression_ids={result.selected_expression_ids!r}"
|
||||
)
|
||||
|
||||
try:
|
||||
prompt = self._build_prompt(
|
||||
chat_history=filtered_history,
|
||||
reply_reason=reply_reason or "",
|
||||
expression_habits=merged_expression_habits,
|
||||
)
|
||||
except Exception as exc:
|
||||
import traceback
|
||||
logger.error(f"Maisaka replyer: _build_prompt failed: {exc}\n{traceback.format_exc()}")
|
||||
result.error_message = f"_build_prompt failed: {exc}"
|
||||
return False, result
|
||||
|
||||
result.completion.request_prompt = prompt
|
||||
|
||||
if global_config.debug.show_replyer_prompt:
|
||||
|
||||
Reference in New Issue
Block a user