Refactor chat stream handling to use BotChatSession
- Updated imports and references from ChatStream to BotChatSession across multiple files. - Adjusted method signatures and internal logic to accommodate the new session management. - Ensured compatibility with existing functionality while improving code clarity and maintainability.
This commit is contained in:
@@ -12,7 +12,7 @@ from src.chat.utils.chat_message_builder import (
|
||||
build_anonymous_messages,
|
||||
)
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
||||
from src.bw_learner.learner_utils import (
|
||||
filter_message_content,
|
||||
is_bot_message,
|
||||
@@ -42,8 +42,8 @@ class ExpressionLearner:
|
||||
)
|
||||
self.check_model: Optional[LLMRequest] = None # 检查用的 LLM 实例,延迟初始化
|
||||
self.chat_id = chat_id
|
||||
self.chat_stream = get_chat_manager().get_stream(chat_id)
|
||||
self.chat_name = get_chat_manager().get_stream_name(chat_id) or chat_id
|
||||
self.chat_stream = _chat_manager.get_session_by_session_id(chat_id)
|
||||
self.chat_name = _chat_manager.get_session_name(chat_id) or chat_id
|
||||
|
||||
# 学习锁,防止并发执行学习任务
|
||||
self._learning_lock = asyncio.Lock()
|
||||
|
||||
@@ -10,7 +10,7 @@ from src.common.logger import get_logger
|
||||
from src.common.database.database_model import Expression
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
from src.bw_learner.learner_utils import weighted_sample
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.common.utils.utils_session import SessionUtils
|
||||
from src.chat.utils.common_utils import TempMethodsExpression
|
||||
|
||||
logger = get_logger("expression_selector")
|
||||
@@ -50,8 +50,9 @@ class ExpressionSelector:
|
||||
id_str = parts[1]
|
||||
stream_type = parts[2]
|
||||
is_group = stream_type == "group"
|
||||
# 统一通过 chat_manager 生成 stream_id,避免各处自行实现哈希逻辑
|
||||
return get_chat_manager().get_stream_id(platform, str(id_str), is_group=is_group)
|
||||
return SessionUtils.calculate_session_id(
|
||||
platform, group_id=str(id_str) if is_group else None, user_id=None if is_group else str(id_str)
|
||||
)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
@@ -127,8 +128,7 @@ class ExpressionSelector:
|
||||
logger.info(f"聊天流 {chat_id} 没有满足 count > 1 且未被拒绝的表达方式,简单模式不进行选择")
|
||||
# 完全没有高 count 样本时,退化为全量随机抽样(不进入LLM流程)
|
||||
fallback_num = min(3, max_num) if max_num > 0 else 3
|
||||
fallback_selected = self._random_expressions(chat_id, fallback_num)
|
||||
if fallback_selected:
|
||||
if fallback_selected := self._random_expressions(chat_id, fallback_num):
|
||||
self.update_expressions_last_active_time(fallback_selected)
|
||||
selected_ids = [expr["id"] for expr in fallback_selected]
|
||||
logger.info(
|
||||
@@ -199,12 +199,7 @@ class ExpressionSelector:
|
||||
]
|
||||
|
||||
# 随机抽样
|
||||
if style_exprs:
|
||||
selected_style = weighted_sample(style_exprs, total_num)
|
||||
else:
|
||||
selected_style = []
|
||||
|
||||
return selected_style
|
||||
return weighted_sample(style_exprs, total_num) if style_exprs else []
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"随机选择表达方式失败: {e}")
|
||||
|
||||
@@ -10,7 +10,7 @@ from src.common.logger import get_logger
|
||||
from src.common.database.database_model import Jargon
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config, global_config
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
from src.bw_learner.learner_utils import (
|
||||
parse_chat_id_list,
|
||||
@@ -99,9 +99,9 @@ class JargonMiner:
|
||||
)
|
||||
|
||||
# 初始化stream_name作为类属性,避免重复提取
|
||||
chat_manager = get_chat_manager()
|
||||
stream_name = chat_manager.get_stream_name(self.chat_id)
|
||||
self.stream_name = stream_name if stream_name else self.chat_id
|
||||
chat_manager = _chat_manager
|
||||
stream_name = chat_manager.get_session_name(self.chat_id)
|
||||
self.stream_name = stream_name or self.chat_id
|
||||
self.cache_limit = 50
|
||||
self.cache: OrderedDict[str, None] = OrderedDict()
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ import time
|
||||
import asyncio
|
||||
from typing import List, Any
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
||||
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat_inclusive
|
||||
from src.chat.utils.common_utils import TempMethodsExpression
|
||||
from src.bw_learner.expression_learner import expression_learner_manager
|
||||
@@ -18,8 +18,8 @@ class MessageRecorder:
|
||||
|
||||
def __init__(self, chat_id: str) -> None:
|
||||
self.chat_id = chat_id
|
||||
self.chat_stream = get_chat_manager().get_stream(chat_id)
|
||||
self.chat_name = get_chat_manager().get_stream_name(chat_id) or chat_id
|
||||
self.chat_stream = _chat_manager.get_session_by_session_id(chat_id)
|
||||
self.chat_name = _chat_manager.get_session_name(chat_id) or chat_id
|
||||
|
||||
# 维护每个chat的上次提取时间
|
||||
self.last_extraction_time: float = time.time()
|
||||
|
||||
@@ -5,20 +5,19 @@ from src.common.database.database_model import Expression
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
from src.config.config import model_config
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
from src.chat.message_receive.chat_manager import BotChatSession
|
||||
from src.chat.utils.chat_message_builder import (
|
||||
get_raw_msg_by_timestamp_with_chat,
|
||||
build_readable_messages,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
|
||||
logger = get_logger("reflect_tracker")
|
||||
|
||||
|
||||
class ReflectTracker:
|
||||
def __init__(self, chat_stream: ChatStream, expression: Expression, created_time: float):
|
||||
def __init__(self, chat_stream: BotChatSession, expression: Expression, created_time: float):
|
||||
self.chat_stream = chat_stream
|
||||
self.expression = expression
|
||||
self.created_time = created_time
|
||||
@@ -42,7 +41,7 @@ class ReflectTracker:
|
||||
|
||||
# Fetch messages since creation
|
||||
msg_list = get_raw_msg_by_timestamp_with_chat(
|
||||
chat_id=self.chat_stream.stream_id,
|
||||
chat_id=self.chat_stream.session_id,
|
||||
timestamp_start=self.created_time,
|
||||
timestamp_end=time.time(),
|
||||
)
|
||||
@@ -90,10 +89,7 @@ class ReflectTracker:
|
||||
from json_repair import repair_json
|
||||
|
||||
json_pattern = r"```json\s*(.*?)\s*```"
|
||||
matches = re.findall(json_pattern, response, re.DOTALL)
|
||||
if not matches:
|
||||
# Try to parse raw response if no code block
|
||||
matches = [response]
|
||||
matches = re.findall(json_pattern, response, re.DOTALL) or [response]
|
||||
|
||||
json_obj = json.loads(repair_json(matches[0]))
|
||||
|
||||
@@ -122,10 +118,7 @@ class ReflectTracker:
|
||||
self.expression.style = corrected_style
|
||||
|
||||
# 如果拒绝但未更新,标记为 rejected=1
|
||||
if not has_update:
|
||||
self.expression.rejected = True
|
||||
else:
|
||||
self.expression.rejected = False
|
||||
self.expression.rejected = not has_update
|
||||
|
||||
self.expression.save()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user