fix:使用chat_manager而不是计算cha_id
This commit is contained in:
@@ -22,7 +22,7 @@ sys.path.insert(0, project_root)
|
|||||||
# Import after setting up path (required for project imports)
|
# Import after setting up path (required for project imports)
|
||||||
from src.common.database.database_model import Expression, ChatStreams # noqa: E402
|
from src.common.database.database_model import Expression, ChatStreams # noqa: E402
|
||||||
from src.config.config import global_config # noqa: E402
|
from src.config.config import global_config # noqa: E402
|
||||||
import hashlib # noqa: E402
|
from src.chat.message_receive.chat_stream import get_chat_manager # noqa: E402
|
||||||
|
|
||||||
|
|
||||||
class TeeOutput:
|
class TeeOutput:
|
||||||
@@ -57,7 +57,7 @@ class TeeOutput:
|
|||||||
|
|
||||||
def _parse_stream_config_to_chat_id(stream_config_str: str) -> str | None:
|
def _parse_stream_config_to_chat_id(stream_config_str: str) -> str | None:
|
||||||
"""
|
"""
|
||||||
解析'platform:id:type'为chat_id(与ExpressionSelector中的逻辑一致)
|
解析'platform:id:type'为chat_id,直接复用 ChatManager 的逻辑
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
parts = stream_config_str.split(":")
|
parts = stream_config_str.split(":")
|
||||||
@@ -67,12 +67,7 @@ def _parse_stream_config_to_chat_id(stream_config_str: str) -> str | None:
|
|||||||
id_str = parts[1]
|
id_str = parts[1]
|
||||||
stream_type = parts[2]
|
stream_type = parts[2]
|
||||||
is_group = stream_type == "group"
|
is_group = stream_type == "group"
|
||||||
if is_group:
|
return get_chat_manager().get_stream_id(platform, str(id_str), is_group=is_group)
|
||||||
components = [platform, str(id_str)]
|
|
||||||
else:
|
|
||||||
components = [platform, str(id_str), "private"]
|
|
||||||
key = "_".join(components)
|
|
||||||
return hashlib.md5(key.encode()).hexdigest()
|
|
||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
import hashlib
|
|
||||||
|
|
||||||
from typing import List, Dict, Optional, Any, Tuple
|
from typing import List, Dict, Optional, Any, Tuple
|
||||||
from json_repair import repair_json
|
from json_repair import repair_json
|
||||||
@@ -11,6 +10,7 @@ from src.common.logger import get_logger
|
|||||||
from src.common.database.database_model import Expression
|
from src.common.database.database_model import Expression
|
||||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||||
from src.bw_learner.learner_utils import weighted_sample
|
from src.bw_learner.learner_utils import weighted_sample
|
||||||
|
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||||
|
|
||||||
logger = get_logger("expression_selector")
|
logger = get_logger("expression_selector")
|
||||||
|
|
||||||
@@ -67,7 +67,7 @@ class ExpressionSelector:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _parse_stream_config_to_chat_id(stream_config_str: str) -> Optional[str]:
|
def _parse_stream_config_to_chat_id(stream_config_str: str) -> Optional[str]:
|
||||||
"""解析'platform:id:type'为chat_id(与get_stream_id一致)"""
|
"""解析'platform:id:type'为chat_id,直接使用 ChatManager 提供的接口"""
|
||||||
try:
|
try:
|
||||||
parts = stream_config_str.split(":")
|
parts = stream_config_str.split(":")
|
||||||
if len(parts) != 3:
|
if len(parts) != 3:
|
||||||
@@ -76,12 +76,8 @@ class ExpressionSelector:
|
|||||||
id_str = parts[1]
|
id_str = parts[1]
|
||||||
stream_type = parts[2]
|
stream_type = parts[2]
|
||||||
is_group = stream_type == "group"
|
is_group = stream_type == "group"
|
||||||
if is_group:
|
# 统一通过 chat_manager 生成 stream_id,避免各处自行实现哈希逻辑
|
||||||
components = [platform, str(id_str)]
|
return get_chat_manager().get_stream_id(platform, str(id_str), is_group=is_group)
|
||||||
else:
|
|
||||||
components = [platform, str(id_str), "private"]
|
|
||||||
key = "_".join(components)
|
|
||||||
return hashlib.md5(key.encode()).hexdigest()
|
|
||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|||||||
@@ -554,16 +554,10 @@ class PrivateReplyer:
|
|||||||
# 判断是否为群聊
|
# 判断是否为群聊
|
||||||
is_group = stream_type == "group"
|
is_group = stream_type == "group"
|
||||||
|
|
||||||
# 使用与 ChatStream.get_stream_id 相同的逻辑生成 chat_id
|
# 使用 ChatManager 提供的接口生成 chat_id,避免在此重复实现逻辑
|
||||||
import hashlib
|
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||||
|
|
||||||
if is_group:
|
|
||||||
components = [platform, str(id_str)]
|
|
||||||
else:
|
|
||||||
components = [platform, str(id_str), "private"]
|
|
||||||
key = "_".join(components)
|
|
||||||
chat_id = hashlib.md5(key.encode()).hexdigest()
|
|
||||||
|
|
||||||
|
chat_id = get_chat_manager().get_stream_id(platform, str(id_str), is_group=is_group)
|
||||||
return chat_id, prompt_content
|
return chat_id, prompt_content
|
||||||
|
|
||||||
except (ValueError, IndexError):
|
except (ValueError, IndexError):
|
||||||
|
|||||||
Reference in New Issue
Block a user