Refactor chat stream handling to use BotChatSession

- Updated imports and references from ChatStream to BotChatSession across multiple files.
- Adjusted method signatures and internal logic to accommodate the new session management.
- Ensured compatibility with existing functionality while improving code clarity and maintainability.
This commit is contained in:
DrSmoothl
2026-03-07 00:57:33 +08:00
parent 8712fc0d05
commit 2e3dd44ee9
43 changed files with 706 additions and 563 deletions

View File

@@ -12,7 +12,7 @@ from src.chat.utils.chat_message_builder import (
build_anonymous_messages,
)
from src.prompt.prompt_manager import prompt_manager
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
from src.bw_learner.learner_utils import (
filter_message_content,
is_bot_message,
@@ -42,8 +42,8 @@ class ExpressionLearner:
)
self.check_model: Optional[LLMRequest] = None # 检查用的 LLM 实例,延迟初始化
self.chat_id = chat_id
self.chat_stream = get_chat_manager().get_stream(chat_id)
self.chat_name = get_chat_manager().get_stream_name(chat_id) or chat_id
self.chat_stream = _chat_manager.get_session_by_session_id(chat_id)
self.chat_name = _chat_manager.get_session_name(chat_id) or chat_id
# 学习锁,防止并发执行学习任务
self._learning_lock = asyncio.Lock()

View File

@@ -10,7 +10,7 @@ from src.common.logger import get_logger
from src.common.database.database_model import Expression
from src.prompt.prompt_manager import prompt_manager
from src.bw_learner.learner_utils import weighted_sample
from src.chat.message_receive.chat_stream import get_chat_manager
from src.common.utils.utils_session import SessionUtils
from src.chat.utils.common_utils import TempMethodsExpression
logger = get_logger("expression_selector")
@@ -50,8 +50,9 @@ class ExpressionSelector:
id_str = parts[1]
stream_type = parts[2]
is_group = stream_type == "group"
# 统一通过 chat_manager 生成 stream_id,避免各处自行实现哈希逻辑
return get_chat_manager().get_stream_id(platform, str(id_str), is_group=is_group)
return SessionUtils.calculate_session_id(
platform, group_id=str(id_str) if is_group else None, user_id=None if is_group else str(id_str)
)
except Exception:
return None
@@ -127,8 +128,7 @@ class ExpressionSelector:
logger.info(f"聊天流 {chat_id} 没有满足 count > 1 且未被拒绝的表达方式,简单模式不进行选择")
# 完全没有高 count 样本时退化为全量随机抽样不进入LLM流程
fallback_num = min(3, max_num) if max_num > 0 else 3
fallback_selected = self._random_expressions(chat_id, fallback_num)
if fallback_selected:
if fallback_selected := self._random_expressions(chat_id, fallback_num):
self.update_expressions_last_active_time(fallback_selected)
selected_ids = [expr["id"] for expr in fallback_selected]
logger.info(
@@ -199,12 +199,7 @@ class ExpressionSelector:
]
# 随机抽样
if style_exprs:
selected_style = weighted_sample(style_exprs, total_num)
else:
selected_style = []
return selected_style
return weighted_sample(style_exprs, total_num) if style_exprs else []
except Exception as e:
logger.error(f"随机选择表达方式失败: {e}")

View File

@@ -10,7 +10,7 @@ from src.common.logger import get_logger
from src.common.database.database_model import Jargon
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config, global_config
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
from src.prompt.prompt_manager import prompt_manager
from src.bw_learner.learner_utils import (
parse_chat_id_list,
@@ -99,9 +99,9 @@ class JargonMiner:
)
# 初始化 stream_name 作为类属性,避免重复提取
chat_manager = get_chat_manager()
stream_name = chat_manager.get_stream_name(self.chat_id)
self.stream_name = stream_name if stream_name else self.chat_id
chat_manager = _chat_manager
stream_name = chat_manager.get_session_name(self.chat_id)
self.stream_name = stream_name or self.chat_id
self.cache_limit = 50
self.cache: OrderedDict[str, None] = OrderedDict()

View File

@@ -2,7 +2,7 @@ import time
import asyncio
from typing import List, Any
from src.common.logger import get_logger
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat_inclusive
from src.chat.utils.common_utils import TempMethodsExpression
from src.bw_learner.expression_learner import expression_learner_manager
@@ -18,8 +18,8 @@ class MessageRecorder:
def __init__(self, chat_id: str) -> None:
self.chat_id = chat_id
self.chat_stream = get_chat_manager().get_stream(chat_id)
self.chat_name = get_chat_manager().get_stream_name(chat_id) or chat_id
self.chat_stream = _chat_manager.get_session_by_session_id(chat_id)
self.chat_name = _chat_manager.get_session_name(chat_id) or chat_id
# 维护每个chat的上次提取时间
self.last_extraction_time: float = time.time()

View File

@@ -5,20 +5,19 @@ from src.common.database.database_model import Expression
from src.llm_models.utils_model import LLMRequest
from src.prompt.prompt_manager import prompt_manager
from src.config.config import model_config
from src.chat.message_receive.chat_stream import ChatStream
from src.chat.message_receive.chat_manager import BotChatSession
from src.chat.utils.chat_message_builder import (
get_raw_msg_by_timestamp_with_chat,
build_readable_messages,
)
if TYPE_CHECKING:
pass
logger = get_logger("reflect_tracker")
class ReflectTracker:
def __init__(self, chat_stream: ChatStream, expression: Expression, created_time: float):
def __init__(self, chat_stream: BotChatSession, expression: Expression, created_time: float):
self.chat_stream = chat_stream
self.expression = expression
self.created_time = created_time
@@ -42,7 +41,7 @@ class ReflectTracker:
# Fetch messages since creation
msg_list = get_raw_msg_by_timestamp_with_chat(
chat_id=self.chat_stream.stream_id,
chat_id=self.chat_stream.session_id,
timestamp_start=self.created_time,
timestamp_end=time.time(),
)
@@ -90,10 +89,7 @@ class ReflectTracker:
from json_repair import repair_json
json_pattern = r"```json\s*(.*?)\s*```"
matches = re.findall(json_pattern, response, re.DOTALL)
if not matches:
# Try to parse raw response if no code block
matches = [response]
matches = re.findall(json_pattern, response, re.DOTALL) or [response]
json_obj = json.loads(repair_json(matches[0]))
@@ -122,10 +118,7 @@ class ReflectTracker:
self.expression.style = corrected_style
# 如果拒绝但未更新,标记为 rejected=1
if not has_update:
self.expression.rejected = True
else:
self.expression.rejected = False
self.expression.rejected = not has_update
self.expression.save()