diff --git a/src/bw_learner/expression_reflect_tracker.py b/src/bw_learner/expression_reflect_tracker.py index 1ab56d3f..e4173c6a 100644 --- a/src/bw_learner/expression_reflect_tracker.py +++ b/src/bw_learner/expression_reflect_tracker.py @@ -1,16 +1,23 @@ +import json +import re +import time from typing import TYPE_CHECKING, Optional -import time +from json_repair import repair_json -from src.common.logger import get_logger +from src.chat.utils.chat_message_builder import ( + build_readable_messages, + get_raw_msg_by_timestamp_with_chat, +) from src.common.database.database import get_db_session -from src.llm_models.utils_model import LLMRequest +from src.common.logger import get_logger from src.config.config import model_config +from src.llm_models.utils_model import LLMRequest +from src.prompt.prompt_manager import prompt_manager if TYPE_CHECKING: from src.common.data_models.expression_data_model import MaiExpression -# TODO: 这个LLMRequest实例被更优雅的方式替换掉 judge_model = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="reflect.tracker") logger = get_logger("reflect_tracker") @@ -57,4 +64,119 @@ class ReflectTracker: self._reset_tracker() return True - # TODO: 完成追踪检查逻辑 \ No newline at end of file + # 获取消息列表 + msg_list = get_raw_msg_by_timestamp_with_chat( + chat_id=self.session_id, + timestamp_start=self.tracking_start_time, + timestamp_end=time.time(), + ) + + current_msg_count = len(msg_list) + + # 检查消息数量是否超限 + if current_msg_count > self.max_msg_count: + logger.info(f"ReflectTracker for expr {self.expression.item_id} timed out (message count).") + self._reset_tracker() + return True + + # 如果没有新消息,跳过本次检查 + if current_msg_count <= self.last_check_msg_count: + return False + + self.last_check_msg_count = current_msg_count + + # 构建上下文 + context_block = build_readable_messages( + msg_list, + replace_bot_name=True, + timestamp_mode="relative", + read_mark=0.0, + show_actions=False, + ) + + # LLM 判断 + try: + prompt_template = prompt_manager.get_prompt("reflect_judge") + prompt_template.add_context("situation", str(self.expression.situation)) + prompt_template.add_context("style", str(self.expression.style)) + prompt_template.add_context("context_block", context_block) + prompt = await prompt_manager.render_prompt(prompt_template) + + logger.info(f"ReflectTracker LLM Prompt: {prompt}") + + response, _ = await judge_model.generate_response_async(prompt, temperature=0.1) + + logger.info(f"ReflectTracker LLM Response: {response}") + + # 解析 JSON 响应 + json_pattern = r"```json\s*(.*?)\s*```" + matches = re.findall(json_pattern, response, re.DOTALL) + if not matches: + matches = [response] + + json_obj = json.loads(repair_json(matches[0])) + judgment = json_obj.get("judgment") + + if judgment == "Approve": + self._update_expression(checked=True, rejected=False, modified_by="ai") + logger.info(f"Expression {self.expression.item_id} approved by operator.") + self._reset_tracker() + return True + + elif judgment == "Reject": + corrected_situation = json_obj.get("corrected_situation") + corrected_style = json_obj.get("corrected_style") + has_update = bool(corrected_situation or corrected_style) + + update_kwargs = {"checked": True, "modified_by": "ai"} + if corrected_situation: + update_kwargs["situation"] = corrected_situation + if corrected_style: + update_kwargs["style"] = corrected_style + if not has_update: + update_kwargs["rejected"] = True + else: + update_kwargs["rejected"] = False + + self._update_expression(**update_kwargs) + + if has_update: + logger.info( + f"Expression {self.expression.item_id} rejected and updated. " + f"New situation: {corrected_situation}, New style: {corrected_style}" + ) + else: + logger.info( + f"Expression {self.expression.item_id} rejected but no correction provided, marked as rejected." + ) + self._reset_tracker() + return True + + elif judgment == "Ignore": + logger.info(f"ReflectTracker for expr {self.expression.item_id} judged as Ignore.") + return False + + except Exception as e: + logger.error(f"Error in ReflectTracker check: {e}") + return False + + return False + + def _update_expression(self, **kwargs): + """更新表达并持久化到数据库""" + if not self.expression: + return + + # 更新内存中的表达对象 + for key, value in kwargs.items(): + if hasattr(self.expression, key): + setattr(self.expression, key, value) + + # 持久化到数据库 + try: + with get_db_session() as session: + db_expr = self.expression.to_db_instance() + session.merge(db_expr) + session.commit() + except Exception as e: + logger.error(f"Failed to persist expression update: {e}") \ No newline at end of file