From efaff7ac605c0587c436f4d062f2580c4676af24 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=99=B4=E7=8C=AB?= Date: Tue, 3 Mar 2026 02:18:04 +0900 Subject: [PATCH] =?UTF-8?q?=E5=AE=9E=E7=8E=B0=E8=BF=BD=E8=B8=AA=E6=A3=80?= =?UTF-8?q?=E6=9F=A5=E9=80=BB=E8=BE=91=EF=BC=8C=E6=B7=BB=E5=8A=A0=E6=B6=88?= =?UTF-8?q?=E6=81=AF=E6=95=B0=E9=87=8F=E5=92=8CLLM=E5=88=A4=E6=96=AD?= =?UTF-8?q?=EF=BC=8C=E6=9B=B4=E6=96=B0=E8=A1=A8=E8=BE=BE=E5=B9=B6=E6=8C=81?= =?UTF-8?q?=E4=B9=85=E5=8C=96=E5=88=B0=E6=95=B0=E6=8D=AE=E5=BA=93?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/bw_learner/expression_reflect_tracker.py | 132 ++++++++++++++++++- 1 file changed, 127 insertions(+), 5 deletions(-) diff --git a/src/bw_learner/expression_reflect_tracker.py b/src/bw_learner/expression_reflect_tracker.py index 1ab56d3f..e4173c6a 100644 --- a/src/bw_learner/expression_reflect_tracker.py +++ b/src/bw_learner/expression_reflect_tracker.py @@ -1,16 +1,23 @@ +import json +import re +import time from typing import TYPE_CHECKING, Optional -import time +from json_repair import repair_json -from src.common.logger import get_logger +from src.chat.utils.chat_message_builder import ( + build_readable_messages, + get_raw_msg_by_timestamp_with_chat, +) from src.common.database.database import get_db_session -from src.llm_models.utils_model import LLMRequest +from src.common.logger import get_logger from src.config.config import model_config +from src.llm_models.utils_model import LLMRequest +from src.prompt.prompt_manager import prompt_manager if TYPE_CHECKING: from src.common.data_models.expression_data_model import MaiExpression -# TODO: 这个LLMRequest实例被更优雅的方式替换掉 judge_model = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="reflect.tracker") logger = get_logger("reflect_tracker") @@ -57,4 +64,119 @@ class ReflectTracker: self._reset_tracker() return True - # TODO: 完成追踪检查逻辑 \ No newline at end of file + # 获取消息列表 + msg_list = get_raw_msg_by_timestamp_with_chat( + chat_id=self.session_id, + timestamp_start=self.tracking_start_time, + timestamp_end=time.time(), + ) + + current_msg_count = len(msg_list) + + # 检查消息数量是否超限 + if current_msg_count > self.max_msg_count: + logger.info(f"ReflectTracker for expr {self.expression.item_id} timed out (message count).") + self._reset_tracker() + return True + + # 如果没有新消息,跳过本次检查 + if current_msg_count <= self.last_check_msg_count: + return False + + self.last_check_msg_count = current_msg_count + + # 构建上下文 + context_block = build_readable_messages( + msg_list, + replace_bot_name=True, + timestamp_mode="relative", + read_mark=0.0, + show_actions=False, + ) + + # LLM 判断 + try: + prompt_template = prompt_manager.get_prompt("reflect_judge") + prompt_template.add_context("situation", str(self.expression.situation)) + prompt_template.add_context("style", str(self.expression.style)) + prompt_template.add_context("context_block", context_block) + prompt = await prompt_manager.render_prompt(prompt_template) + + logger.info(f"ReflectTracker LLM Prompt: {prompt}") + + response, _ = await judge_model.generate_response_async(prompt, temperature=0.1) + + logger.info(f"ReflectTracker LLM Response: {response}") + + # 解析 JSON 响应 + json_pattern = r"```json\s*(.*?)\s*```" + matches = re.findall(json_pattern, response, re.DOTALL) + if not matches: + matches = [response] + + json_obj = json.loads(repair_json(matches[0])) + judgment = json_obj.get("judgment") + + if judgment == "Approve": + self._update_expression(checked=True, rejected=False, modified_by="ai") + logger.info(f"Expression {self.expression.item_id} approved by operator.") + self._reset_tracker() + return True + + elif judgment == "Reject": + corrected_situation = json_obj.get("corrected_situation") + corrected_style = json_obj.get("corrected_style") + has_update = bool(corrected_situation or corrected_style) + + update_kwargs = {"checked": True, "modified_by": "ai"} + if corrected_situation: + update_kwargs["situation"] = corrected_situation + if corrected_style: + update_kwargs["style"] = corrected_style + if not has_update: + update_kwargs["rejected"] = True + else: + update_kwargs["rejected"] = False + + self._update_expression(**update_kwargs) + + if has_update: + logger.info( + f"Expression {self.expression.item_id} rejected and updated. " + f"New situation: {corrected_situation}, New style: {corrected_style}" + ) + else: + logger.info( + f"Expression {self.expression.item_id} rejected but no correction provided, marked as rejected." + ) + self._reset_tracker() + return True + + elif judgment == "Ignore": + logger.info(f"ReflectTracker for expr {self.expression.item_id} judged as Ignore.") + return False + + except Exception as e: + logger.error(f"Error in ReflectTracker check: {e}") + return False + + return False + + def _update_expression(self, **kwargs): + """更新表达并持久化到数据库""" + if not self.expression: + return + + # 更新内存中的表达对象 + for key, value in kwargs.items(): + if hasattr(self.expression, key): + setattr(self.expression, key, value) + + # 持久化到数据库 + try: + with get_db_session() as session: + db_expr = self.expression.to_db_instance() + session.merge(db_expr) + session.commit() + except Exception as e: + logger.error(f"Failed to persist expression update: {e}") \ No newline at end of file