实现追踪检查逻辑,添加消息数量和LLM判断,更新表达并持久化到数据库

This commit is contained in:
晴猫
2026-03-03 02:18:04 +09:00
parent c6545afa2e
commit efaff7ac60

View File

@@ -1,16 +1,23 @@
import json
import re
import time
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Optional
import time from json_repair import repair_json
from src.common.logger import get_logger from src.chat.utils.chat_message_builder import (
build_readable_messages,
get_raw_msg_by_timestamp_with_chat,
)
from src.common.database.database import get_db_session from src.common.database.database import get_db_session
from src.llm_models.utils_model import LLMRequest from src.common.logger import get_logger
from src.config.config import model_config from src.config.config import model_config
from src.llm_models.utils_model import LLMRequest
from src.prompt.prompt_manager import prompt_manager
if TYPE_CHECKING: if TYPE_CHECKING:
from src.common.data_models.expression_data_model import MaiExpression from src.common.data_models.expression_data_model import MaiExpression
# TODO: 这个LLMRequest实例被更优雅的方式替换掉
judge_model = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="reflect.tracker") judge_model = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="reflect.tracker")
logger = get_logger("reflect_tracker") logger = get_logger("reflect_tracker")
@@ -57,4 +64,119 @@ class ReflectTracker:
self._reset_tracker() self._reset_tracker()
return True return True
# TODO: 完成追踪检查逻辑 # 获取消息列表
msg_list = get_raw_msg_by_timestamp_with_chat(
chat_id=self.session_id,
timestamp_start=self.tracking_start_time,
timestamp_end=time.time(),
)
current_msg_count = len(msg_list)
# 检查消息数量是否超限
if current_msg_count > self.max_msg_count:
logger.info(f"ReflectTracker for expr {self.expression.item_id} timed out (message count).")
self._reset_tracker()
return True
# 如果没有新消息,跳过本次检查
if current_msg_count <= self.last_check_msg_count:
return False
self.last_check_msg_count = current_msg_count
# 构建上下文
context_block = build_readable_messages(
msg_list,
replace_bot_name=True,
timestamp_mode="relative",
read_mark=0.0,
show_actions=False,
)
# LLM 判断
try:
prompt_template = prompt_manager.get_prompt("reflect_judge")
prompt_template.add_context("situation", str(self.expression.situation))
prompt_template.add_context("style", str(self.expression.style))
prompt_template.add_context("context_block", context_block)
prompt = await prompt_manager.render_prompt(prompt_template)
logger.info(f"ReflectTracker LLM Prompt: {prompt}")
response, _ = await judge_model.generate_response_async(prompt, temperature=0.1)
logger.info(f"ReflectTracker LLM Response: {response}")
# 解析 JSON 响应
json_pattern = r"```json\s*(.*?)\s*```"
matches = re.findall(json_pattern, response, re.DOTALL)
if not matches:
matches = [response]
json_obj = json.loads(repair_json(matches[0]))
judgment = json_obj.get("judgment")
if judgment == "Approve":
self._update_expression(checked=True, rejected=False, modified_by="ai")
logger.info(f"Expression {self.expression.item_id} approved by operator.")
self._reset_tracker()
return True
elif judgment == "Reject":
corrected_situation = json_obj.get("corrected_situation")
corrected_style = json_obj.get("corrected_style")
has_update = bool(corrected_situation or corrected_style)
update_kwargs = {"checked": True, "modified_by": "ai"}
if corrected_situation:
update_kwargs["situation"] = corrected_situation
if corrected_style:
update_kwargs["style"] = corrected_style
if not has_update:
update_kwargs["rejected"] = True
else:
update_kwargs["rejected"] = False
self._update_expression(**update_kwargs)
if has_update:
logger.info(
f"Expression {self.expression.item_id} rejected and updated. "
f"New situation: {corrected_situation}, New style: {corrected_style}"
)
else:
logger.info(
f"Expression {self.expression.item_id} rejected but no correction provided, marked as rejected."
)
self._reset_tracker()
return True
elif judgment == "Ignore":
logger.info(f"ReflectTracker for expr {self.expression.item_id} judged as Ignore.")
return False
except Exception as e:
logger.error(f"Error in ReflectTracker check: {e}")
return False
return False
def _update_expression(self, **kwargs):
"""更新表达并持久化到数据库"""
if not self.expression:
return
# 更新内存中的表达对象
for key, value in kwargs.items():
if hasattr(self.expression, key):
setattr(self.expression, key, value)
# 持久化到数据库
try:
with get_db_session() as session:
db_expr = self.expression.to_db_instance()
session.merge(db_expr)
session.commit()
except Exception as e:
logger.error(f"Failed to persist expression update: {e}")