Refactor chat stream handling to use BotChatSession
- Updated imports and references from ChatStream to BotChatSession across multiple files. - Adjusted method signatures and internal logic to accommodate the new session management. - Ensured compatibility with existing functionality while improving code clarity and maintainability.
This commit is contained in:
@@ -21,7 +21,7 @@ from src.plugin_system.apis import message_api
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages
|
||||
from src.chat.utils.utils import is_bot_self
|
||||
from src.person_info.person_info import Person
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
|
||||
logger = get_logger("chat_history_summarizer")
|
||||
@@ -100,7 +100,7 @@ class ChatHistorySummarizer:
|
||||
def _get_chat_display_name(self) -> str:
|
||||
"""获取聊天显示名称"""
|
||||
try:
|
||||
chat_name = get_chat_manager().get_stream_name(self.chat_id)
|
||||
chat_name = _chat_manager.get_session_name(self.chat_id)
|
||||
if chat_name:
|
||||
return chat_name
|
||||
# 如果获取失败,使用简化的chat_id显示
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import contextlib
|
||||
import time
|
||||
import json
|
||||
import asyncio
|
||||
@@ -12,7 +13,7 @@ from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import ThinkingQuestion
|
||||
from src.memory_system.retrieval_tools import get_tool_registry, init_all_tools
|
||||
from src.llm_models.payload_content.message import MessageBuilder, RoleType, Message
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
||||
from src.bw_learner.jargon_explainer import retrieve_concepts_with_jargon
|
||||
|
||||
logger = get_logger("memory_retrieval")
|
||||
@@ -133,10 +134,10 @@ async def _react_agent_solve_question(
|
||||
Tuple[bool, str, List[Dict[str, Any]], bool]: (是否找到答案, 答案内容, 思考步骤列表, 是否超时)
|
||||
"""
|
||||
start_time = time.time()
|
||||
collected_info = initial_info if initial_info else ""
|
||||
collected_info = initial_info or ""
|
||||
# 构造日志前缀:[聊天流名称],用于在日志中标识聊天流
|
||||
try:
|
||||
chat_name = get_chat_manager().get_stream_name(chat_id) or chat_id
|
||||
chat_name = _chat_manager.get_session_name(chat_id) or chat_id
|
||||
except Exception:
|
||||
chat_name = chat_id
|
||||
react_log_prefix = f"[{chat_name}] "
|
||||
@@ -235,7 +236,7 @@ async def _react_agent_solve_question(
|
||||
# head_prompt应该只构建一次,使用初始的collected_info,后续迭代都复用同一个
|
||||
if first_head_prompt is None:
|
||||
# 第一次构建,使用初始的collected_info(即initial_info)
|
||||
initial_collected_info = initial_info if initial_info else ""
|
||||
initial_collected_info = initial_info or ""
|
||||
# 根据配置选择使用哪个 prompt
|
||||
prompt_name = (
|
||||
"memory_retrieval_react_prompt_head_lpmm"
|
||||
@@ -362,7 +363,7 @@ async def _react_agent_solve_question(
|
||||
return information
|
||||
except (json.JSONDecodeError, ValueError, TypeError):
|
||||
# 如果JSON解析失败,尝试在文本中查找JSON对象
|
||||
try:
|
||||
with contextlib.suppress(json.JSONDecodeError, ValueError, TypeError):
|
||||
# 查找第一个 { 和最后一个 } 之间的内容(更健壮的JSON提取)
|
||||
first_brace = text.find("{")
|
||||
if first_brace != -1:
|
||||
@@ -384,8 +385,6 @@ async def _react_agent_solve_question(
|
||||
if isinstance(data, dict) and "return_information" in data:
|
||||
information = data.get("information", "")
|
||||
return information
|
||||
except (json.JSONDecodeError, ValueError, TypeError):
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
@@ -679,7 +678,7 @@ async def _react_agent_solve_question(
|
||||
evaluation_prompt_template.add_context("bot_name", bot_name)
|
||||
evaluation_prompt_template.add_context("time_now", time_now)
|
||||
evaluation_prompt_template.add_context("chat_history", chat_history)
|
||||
evaluation_prompt_template.add_context("collected_info", collected_info if collected_info else "暂无信息")
|
||||
evaluation_prompt_template.add_context("collected_info", collected_info or "暂无信息")
|
||||
evaluation_prompt_template.add_context("current_iteration", str(current_iteration))
|
||||
evaluation_prompt_template.add_context("remaining_iterations", str(remaining_iterations))
|
||||
evaluation_prompt_template.add_context("max_iterations", str(max_iterations))
|
||||
@@ -800,8 +799,7 @@ def _get_recent_query_history(chat_id: str, time_window_seconds: float = 600.0)
|
||||
if not records:
|
||||
return ""
|
||||
|
||||
history_lines = []
|
||||
history_lines.append("最近已查询的问题和结果:")
|
||||
history_lines = ["最近已查询的问题和结果:"]
|
||||
|
||||
for record in records:
|
||||
status = "✓ 已找到答案" if record.found_answer else "✗ 未找到答案"
|
||||
@@ -813,8 +811,7 @@ def _get_recent_query_history(chat_id: str, time_window_seconds: float = 600.0)
|
||||
if len(record.answer) > 100:
|
||||
answer_preview += "..."
|
||||
|
||||
history_lines.append(f"- 问题:{record.question}")
|
||||
history_lines.append(f" 状态:{status}")
|
||||
history_lines.extend([f"- 问题:{record.question}", f" 状态:{status}"])
|
||||
if answer_preview:
|
||||
history_lines.append(f" 答案:{answer_preview}")
|
||||
history_lines.append("") # 空行分隔
|
||||
@@ -855,12 +852,11 @@ def _get_recent_found_answers(chat_id: str, time_window_seconds: float = 600.0)
|
||||
if not records:
|
||||
return []
|
||||
|
||||
found_answers = []
|
||||
for record in records:
|
||||
if record.answer:
|
||||
found_answers.append(f"问题:{record.question}\n答案:{record.answer}")
|
||||
|
||||
return found_answers
|
||||
return [
|
||||
f"问题:{record.question}\n答案:{record.answer}"
|
||||
for record in records
|
||||
if record.answer
|
||||
]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取最近已找到答案的记录失败: {e}")
|
||||
@@ -892,8 +888,7 @@ def _store_thinking_back(
|
||||
.order_by(col(ThinkingQuestion.updated_timestamp).desc())
|
||||
.limit(1)
|
||||
)
|
||||
record = session.exec(statement).first()
|
||||
if record:
|
||||
if record := session.exec(statement).first():
|
||||
record.context = context
|
||||
record.found_answer = found_answer
|
||||
record.answer = answer
|
||||
@@ -957,10 +952,7 @@ async def _process_memory_retrieval(
|
||||
if is_timeout:
|
||||
logger.info("ReAct Agent超时,不返回结果")
|
||||
|
||||
if found_answer and answer:
|
||||
return answer
|
||||
|
||||
return None
|
||||
return answer if found_answer and answer else None
|
||||
|
||||
|
||||
async def build_memory_retrieval_prompt(
|
||||
@@ -1013,8 +1005,7 @@ async def build_memory_retrieval_prompt(
|
||||
cleaned_concepts = []
|
||||
for word in unknown_words:
|
||||
if isinstance(word, str):
|
||||
cleaned = word.strip()
|
||||
if cleaned:
|
||||
if cleaned := word.strip():
|
||||
cleaned_concepts.append(cleaned)
|
||||
if cleaned_concepts:
|
||||
# 对匹配到的概念进行jargon检索,作为初始信息
|
||||
|
||||
@@ -30,9 +30,8 @@ def _parse_blacklist_to_chat_ids(blacklist: list[str]) -> Set[str]:
|
||||
return chat_ids
|
||||
|
||||
try:
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.common.utils.utils_session import SessionUtils
|
||||
|
||||
chat_manager = get_chat_manager()
|
||||
for blacklist_item in blacklist:
|
||||
if not isinstance(blacklist_item, str):
|
||||
continue
|
||||
@@ -51,7 +50,10 @@ def _parse_blacklist_to_chat_ids(blacklist: list[str]) -> Set[str]:
|
||||
is_group = stream_type == "group"
|
||||
|
||||
# 转换为chat_id
|
||||
chat_id = chat_manager.get_stream_id(platform, str(id_str), is_group=is_group)
|
||||
if is_group:
|
||||
chat_id = SessionUtils.calculate_session_id(platform, group_id=str(id_str))
|
||||
else:
|
||||
chat_id = SessionUtils.calculate_session_id(platform, user_id=str(id_str))
|
||||
if chat_id:
|
||||
chat_ids.add(chat_id)
|
||||
else:
|
||||
@@ -225,9 +227,9 @@ async def search_chat_history(
|
||||
if keyword:
|
||||
keyword_matched = False
|
||||
# 解析多个关键词(支持空格、逗号等分隔符)
|
||||
keywords_list = parse_keywords_string(keyword)
|
||||
if not keywords_list:
|
||||
keywords_list = [keyword.strip()] if keyword.strip() else []
|
||||
keywords_list = parse_keywords_string(keyword) or (
|
||||
[keyword.strip()] if keyword.strip() else []
|
||||
)
|
||||
|
||||
# 转换为小写以便匹配
|
||||
keywords_lower = [kw.lower() for kw in keywords_list if kw.strip()]
|
||||
|
||||
Reference in New Issue
Block a user