fix:优化聊天流信息的展示和检索,优化chat_prompt无效的问题,优化部分群定义问题
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
from typing import Optional
|
||||
from typing import Iterator, Optional
|
||||
|
||||
import time
|
||||
|
||||
@@ -18,16 +18,8 @@ class ExpressionConfigUtils:
|
||||
for config_item in global_config.expression.learning_list:
|
||||
if not config_item.platform and not config_item.item_id:
|
||||
continue
|
||||
stream_id = ExpressionConfigUtils._get_stream_id(
|
||||
config_item.platform,
|
||||
str(config_item.item_id),
|
||||
(config_item.rule_type == "group"),
|
||||
)
|
||||
if stream_id is None:
|
||||
continue
|
||||
if stream_id != session_id:
|
||||
continue
|
||||
return config_item
|
||||
if ChatConfigUtils.target_matches_session(config_item, session_id):
|
||||
return config_item
|
||||
|
||||
for config_item in global_config.expression.learning_list:
|
||||
if not config_item.platform and not config_item.item_id:
|
||||
@@ -84,6 +76,180 @@ class ExpressionConfigUtils:
|
||||
|
||||
|
||||
class ChatConfigUtils:
|
||||
@staticmethod
|
||||
def _iter_matching_chat_prompts(session_id: str, is_group_chat: Optional[bool]) -> Iterator[str]:
|
||||
try:
|
||||
from src.chat.message_receive.chat_manager import chat_manager
|
||||
from src.common.utils.utils_session import SessionUtils
|
||||
|
||||
chat_stream = chat_manager.get_session_by_session_id(session_id)
|
||||
session_utils = SessionUtils
|
||||
except Exception as e:
|
||||
logger.debug(f"解析额外 Prompt 聊天流失败: session_id={session_id} error={e}")
|
||||
chat_stream = None
|
||||
session_utils = None
|
||||
|
||||
for chat_prompt_item in global_config.chat.chat_prompts:
|
||||
if hasattr(chat_prompt_item, "platform"):
|
||||
platform = str(chat_prompt_item.platform or "").strip()
|
||||
item_id = str(chat_prompt_item.item_id or "").strip()
|
||||
rule_type = str(chat_prompt_item.rule_type or "").strip()
|
||||
prompt_content = str(chat_prompt_item.prompt or "").strip()
|
||||
elif isinstance(chat_prompt_item, str):
|
||||
parts = chat_prompt_item.split(":", 3)
|
||||
if len(parts) != 4:
|
||||
continue
|
||||
|
||||
platform, item_id, rule_type, prompt_content = parts
|
||||
platform = platform.strip()
|
||||
item_id = item_id.strip()
|
||||
rule_type = rule_type.strip()
|
||||
prompt_content = prompt_content.strip()
|
||||
else:
|
||||
continue
|
||||
|
||||
if not platform or not item_id or not prompt_content:
|
||||
continue
|
||||
|
||||
if rule_type == "group":
|
||||
config_is_group = True
|
||||
target_attr = "group_id"
|
||||
elif rule_type == "private":
|
||||
config_is_group = False
|
||||
target_attr = "user_id"
|
||||
else:
|
||||
continue
|
||||
|
||||
if is_group_chat is not None and config_is_group != is_group_chat:
|
||||
continue
|
||||
|
||||
if chat_stream is not None:
|
||||
chat_stream_platform = str(chat_stream.platform or "").strip()
|
||||
chat_stream_target_id = str(getattr(chat_stream, target_attr) or "").strip()
|
||||
if chat_stream_platform == platform and chat_stream_target_id == item_id:
|
||||
yield prompt_content
|
||||
continue
|
||||
|
||||
if session_utils is None:
|
||||
continue
|
||||
|
||||
try:
|
||||
if rule_type == "group":
|
||||
config_chat_id = session_utils.calculate_session_id(platform, group_id=item_id)
|
||||
else:
|
||||
config_chat_id = session_utils.calculate_session_id(platform, user_id=item_id)
|
||||
except Exception as e:
|
||||
logger.debug(f"生成额外 Prompt 聊天流 ID 失败: platform={platform} item_id={item_id} error={e}")
|
||||
continue
|
||||
|
||||
if config_chat_id == session_id:
|
||||
yield prompt_content
|
||||
|
||||
@staticmethod
|
||||
def get_chat_prompt_for_chat(session_id: str, is_group_chat: Optional[bool]) -> str:
|
||||
"""根据聊天流 ID 获取匹配的额外 Prompt,允许同一聊天流配置多条。"""
|
||||
if not session_id or not global_config.chat.chat_prompts:
|
||||
return ""
|
||||
|
||||
prompt_contents = list(ChatConfigUtils._iter_matching_chat_prompts(session_id, is_group_chat))
|
||||
if not prompt_contents:
|
||||
return ""
|
||||
|
||||
logger.debug(f"匹配到 {len(prompt_contents)} 条聊天额外 Prompt: session_id={session_id}")
|
||||
return "\n".join(prompt_contents)
|
||||
|
||||
@staticmethod
|
||||
def _target_values(target_item) -> tuple[str, str, str]:
|
||||
platform = str(target_item.platform or "").strip()
|
||||
item_id = str(target_item.item_id or "").strip()
|
||||
rule_type = str(target_item.rule_type or "").strip()
|
||||
return platform, item_id, rule_type
|
||||
|
||||
@staticmethod
|
||||
def _get_chat_stream(session_id: str):
|
||||
try:
|
||||
from src.chat.message_receive.chat_manager import chat_manager
|
||||
|
||||
return chat_manager.get_session_by_session_id(session_id)
|
||||
except Exception as e:
|
||||
logger.debug(f"获取聊天流失败: session_id={session_id} error={e}")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _get_stream_id(platform: str, id_str: str, is_group: bool = False) -> Optional[str]:
|
||||
try:
|
||||
from src.common.utils.utils_session import SessionUtils
|
||||
|
||||
if is_group:
|
||||
return SessionUtils.calculate_session_id(platform, group_id=str(id_str))
|
||||
return SessionUtils.calculate_session_id(platform, user_id=str(id_str))
|
||||
except Exception as e:
|
||||
logger.error(f"生成聊天流 ID 失败: {e}")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def target_matches_session(target_item, session_id: str, is_group_chat: Optional[bool] = None) -> bool:
|
||||
"""判断 platform/item_id/rule_type 配置目标是否命中当前聊天流。"""
|
||||
if not session_id:
|
||||
return False
|
||||
|
||||
platform, item_id, rule_type = ChatConfigUtils._target_values(target_item)
|
||||
if not platform or not item_id:
|
||||
return False
|
||||
|
||||
if rule_type == "group":
|
||||
config_is_group = True
|
||||
target_attr = "group_id"
|
||||
elif rule_type == "private":
|
||||
config_is_group = False
|
||||
target_attr = "user_id"
|
||||
else:
|
||||
return False
|
||||
|
||||
if is_group_chat is not None and config_is_group != is_group_chat:
|
||||
return False
|
||||
|
||||
chat_stream = ChatConfigUtils._get_chat_stream(session_id)
|
||||
if chat_stream is not None:
|
||||
chat_stream_platform = str(chat_stream.platform or "").strip()
|
||||
chat_stream_target_id = str(getattr(chat_stream, target_attr) or "").strip()
|
||||
return chat_stream_platform == platform and chat_stream_target_id == item_id
|
||||
|
||||
return ChatConfigUtils._get_stream_id(platform, item_id, config_is_group) == session_id
|
||||
|
||||
@staticmethod
|
||||
def get_target_session_ids(target_item) -> set[str]:
|
||||
"""获取配置目标对应的已知聊天流 ID,并保留无路由 ID 作为兼容回退。"""
|
||||
platform, item_id, rule_type = ChatConfigUtils._target_values(target_item)
|
||||
if not platform or not item_id:
|
||||
return set()
|
||||
|
||||
if rule_type == "group":
|
||||
is_group = True
|
||||
target_attr = "group_id"
|
||||
elif rule_type == "private":
|
||||
is_group = False
|
||||
target_attr = "user_id"
|
||||
else:
|
||||
return set()
|
||||
|
||||
session_ids: set[str] = set()
|
||||
if fallback_session_id := ChatConfigUtils._get_stream_id(platform, item_id, is_group):
|
||||
session_ids.add(fallback_session_id)
|
||||
|
||||
try:
|
||||
from src.chat.message_receive.chat_manager import chat_manager
|
||||
|
||||
for session_id, chat_stream in chat_manager.sessions.items():
|
||||
chat_stream_platform = str(chat_stream.platform or "").strip()
|
||||
chat_stream_target_id = str(getattr(chat_stream, target_attr) or "").strip()
|
||||
if chat_stream_platform == platform and chat_stream_target_id == item_id:
|
||||
session_ids.add(session_id)
|
||||
except Exception as e:
|
||||
logger.debug(f"解析配置目标已知聊天流失败: platform={platform} item_id={item_id} error={e}")
|
||||
|
||||
return session_ids
|
||||
|
||||
@staticmethod
|
||||
def _resolve_is_group_chat(session_id: Optional[str]) -> Optional[bool]:
|
||||
if not session_id:
|
||||
@@ -117,16 +283,10 @@ class ChatConfigUtils:
|
||||
|
||||
# 优先匹配会话相关的规则
|
||||
if session_id:
|
||||
from src.common.utils.utils_session import SessionUtils
|
||||
|
||||
for rule in global_config.chat.talk_value_rules:
|
||||
if not rule.platform and not rule.item_id:
|
||||
continue # 一起留空表示全局
|
||||
if rule.rule_type == "group":
|
||||
rule_session_id = SessionUtils.calculate_session_id(rule.platform, group_id=str(rule.item_id))
|
||||
else:
|
||||
rule_session_id = SessionUtils.calculate_session_id(rule.platform, user_id=str(rule.item_id))
|
||||
if rule_session_id != session_id:
|
||||
if not ChatConfigUtils.target_matches_session(rule, session_id, is_group_chat):
|
||||
continue # 不匹配的会话 ID,跳过
|
||||
parsed_range = ChatConfigUtils.parse_range(rule.time)
|
||||
if not parsed_range:
|
||||
|
||||
Reference in New Issue
Block a user