HFC基本重构框架和TODO
This commit is contained in:
committed by
SengokuCola
parent
045bd5e183
commit
46cb0278d7
@@ -15,7 +15,7 @@ from src.chat.utils.timer_calculator import Timer
|
|||||||
from src.chat.brain_chat.brain_planner import BrainPlanner
|
from src.chat.brain_chat.brain_planner import BrainPlanner
|
||||||
from src.chat.planner_actions.action_modifier import ActionModifier
|
from src.chat.planner_actions.action_modifier import ActionModifier
|
||||||
from src.chat.planner_actions.action_manager import ActionManager
|
from src.chat.planner_actions.action_manager import ActionManager
|
||||||
from src.chat.heart_flow.hfc_utils import CycleDetail
|
from src.chat.heart_flow.hfc_utils_old import CycleDetail
|
||||||
from src.bw_learner.expression_learner_old import expression_learner_manager
|
from src.bw_learner.expression_learner_old import expression_learner_manager
|
||||||
from src.bw_learner.message_recorder_old import extract_and_distribute_messages
|
from src.bw_learner.message_recorder_old import extract_and_distribute_messages
|
||||||
from src.person_info.person_info import Person
|
from src.person_info.person_info import Person
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
from typing import Optional, List, TYPE_CHECKING
|
from typing import Optional, List, TYPE_CHECKING, Tuple, Dict
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import time
|
import time
|
||||||
@@ -7,7 +7,7 @@ import traceback
|
|||||||
import random
|
import random
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.common.utils.utils_session import SessionUtils
|
from src.common.utils.utils_config import ExpressionConfigUtils, ChatConfigUtils
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from src.config.file_watcher import FileChange
|
from src.config.file_watcher import FileChange
|
||||||
from src.chat.message_receive.chat_manager import chat_manager
|
from src.chat.message_receive.chat_manager import chat_manager
|
||||||
@@ -15,6 +15,8 @@ from src.bw_learner.expression_reflector import ExpressionReflector
|
|||||||
from src.bw_learner.expression_learner import ExpressionLearner
|
from src.bw_learner.expression_learner import ExpressionLearner
|
||||||
from src.bw_learner.jargon_miner import JargonMiner
|
from src.bw_learner.jargon_miner import JargonMiner
|
||||||
|
|
||||||
|
from .heartFC_utils import CycleDetail, CycleActionInfo, CyclePlanInfo
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from src.chat.message_receive.message import SessionMessage
|
from src.chat.message_receive.message import SessionMessage
|
||||||
|
|
||||||
@@ -40,6 +42,7 @@ class HeartFChatting:
|
|||||||
self.session_id = session_id
|
self.session_id = session_id
|
||||||
session_name = chat_manager.get_session_name(session_id) or session_id
|
session_name = chat_manager.get_session_name(session_id) or session_id
|
||||||
self.log_prefix = f"[{session_name}]"
|
self.log_prefix = f"[{session_name}]"
|
||||||
|
self.session_name = session_name
|
||||||
|
|
||||||
# 系统运行状态
|
# 系统运行状态
|
||||||
self._running: bool = False
|
self._running: bool = False
|
||||||
@@ -57,12 +60,23 @@ class HeartFChatting:
|
|||||||
self._cycle_event = asyncio.Event()
|
self._cycle_event = asyncio.Event()
|
||||||
|
|
||||||
# 表达方式相关内容
|
# 表达方式相关内容
|
||||||
|
self._min_messages_for_extraction = 30 # 最少提取消息数
|
||||||
|
self._min_extraction_interval = 60 # 最小提取时间间隔,单位为秒
|
||||||
|
self._last_extraction_time: float = 0.0 # 上次提取的时间戳
|
||||||
|
expr_use, jargon_learn, expr_learn = ExpressionConfigUtils.get_expression_config_for_chat(session_id)
|
||||||
|
self._enable_expression_use = expr_use # 允许使用表达方式,但不一定启用学习
|
||||||
|
self._enable_expression_learning = expr_learn # 允许学习表达方式
|
||||||
|
self._enable_jargon_learning = jargon_learn # 允许学习黑话
|
||||||
# 反思器
|
# 反思器
|
||||||
self._reflector: Optional[ExpressionReflector] = None
|
self._reflector: ExpressionReflector = ExpressionReflector(session_id)
|
||||||
# 表达学习器
|
# 表达学习器
|
||||||
self._expression_learner: Optional[ExpressionLearner] = None
|
self._expression_learner: ExpressionLearner = ExpressionLearner(session_id)
|
||||||
# 黑话挖掘器
|
# 黑话挖掘器
|
||||||
self._jargon_miner: Optional[JargonMiner] = None
|
self._jargon_miner: JargonMiner = JargonMiner(session_id, session_name=session_name)
|
||||||
|
|
||||||
|
# TODO: ChatSummarizer 聊天总结器重构
|
||||||
|
|
||||||
|
# ====== 公开方法 ======
|
||||||
|
|
||||||
async def start(self):
|
async def start(self):
|
||||||
"""启动 HeartFChatting 的主循环"""
|
"""启动 HeartFChatting 的主循环"""
|
||||||
@@ -149,10 +163,18 @@ class HeartFChatting:
|
|||||||
await self.stop() # 确保状态正确
|
await self.stop() # 确保状态正确
|
||||||
await asyncio.sleep(3)
|
await asyncio.sleep(3)
|
||||||
await self.start() # 尝试重新启动
|
await self.start() # 尝试重新启动
|
||||||
|
|
||||||
async def _config_callback(self, file_change: FileChange):
|
|
||||||
|
|
||||||
|
|
||||||
|
async def _config_callback(self, file_change: Optional[FileChange] = None):
|
||||||
|
"""配置文件变更回调函数"""
|
||||||
|
# TODO: 根据配置文件变动重新计算相关参数:
|
||||||
|
"""
|
||||||
|
需要计算的参数:
|
||||||
|
self._enable_expression_use = expr_use # 允许使用表达方式,但不一定启用学习
|
||||||
|
self._enable_expression_learning = expr_learn # 允许学习表达方式
|
||||||
|
self._enable_jargon_learning = jargon_learn # 允许学习黑话
|
||||||
|
"""
|
||||||
|
|
||||||
|
# ====== 心流聊天核心逻辑 ======
|
||||||
async def _hfc_func(self, mentioned_message: Optional["SessionMessage"] = None):
|
async def _hfc_func(self, mentioned_message: Optional["SessionMessage"] = None):
|
||||||
"""心流聊天的主循环逻辑"""
|
"""心流聊天的主循环逻辑"""
|
||||||
if self._consecutive_no_reply_count >= 5:
|
if self._consecutive_no_reply_count >= 5:
|
||||||
@@ -166,7 +188,9 @@ class HeartFChatting:
|
|||||||
await asyncio.sleep(0.2)
|
await asyncio.sleep(0.2)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
talk_value_threshold = random.random() * self._get_talk_value(self.session_id) * self._talk_frequency_adjust
|
talk_value_threshold = (
|
||||||
|
random.random() * ChatConfigUtils.get_talk_value(self.session_id) * self._talk_frequency_adjust
|
||||||
|
)
|
||||||
if mentioned_message and global_config.chat.mentioned_bot_reply:
|
if mentioned_message and global_config.chat.mentioned_bot_reply:
|
||||||
await self._judge_and_response(mentioned_message)
|
await self._judge_and_response(mentioned_message)
|
||||||
elif random.random() < talk_value_threshold:
|
elif random.random() < talk_value_threshold:
|
||||||
@@ -175,14 +199,23 @@ class HeartFChatting:
|
|||||||
|
|
||||||
async def _judge_and_response(self, mentioned_message: Optional["SessionMessage"] = None):
|
async def _judge_and_response(self, mentioned_message: Optional["SessionMessage"] = None):
|
||||||
"""判定和生成回复"""
|
"""判定和生成回复"""
|
||||||
if self._reflector:
|
await self._trigger_reflector()
|
||||||
await self._reflector.check_and_ask()
|
asyncio.create_task(self._trigger_expression_learning(self.message_cache))
|
||||||
if self._reflector.reflect_tracker.tracking and await self._reflector.reflect_tracker.trigger_tracker():
|
|
||||||
logger.info(f"{self.log_prefix} 追踪检查已解决,结束追踪器")
|
|
||||||
self._reflector.reflect_tracker.reset_tracker() # 结束当前追踪器
|
|
||||||
|
|
||||||
# TODO: 完成反思器之后的逻辑
|
# TODO: 完成反思器之后的逻辑
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
current_cycle_detail = self._start_cycle()
|
||||||
|
logger.info(f"{self.log_prefix} 开始第{self._cycle_counter}次思考")
|
||||||
|
|
||||||
|
# TODO: 动作检查逻辑
|
||||||
|
# TODO: Planner逻辑
|
||||||
|
# TODO: 动作执行逻辑
|
||||||
|
|
||||||
|
cycle_detail = self._end_cycle(current_cycle_detail)
|
||||||
|
if wait_time := global_config.chat.planner_smooth - (time.time() - start_time) > 0:
|
||||||
|
await asyncio.sleep(wait_time)
|
||||||
|
else:
|
||||||
|
await asyncio.sleep(0.1) # 最小等待时间,避免过快循环
|
||||||
|
return True
|
||||||
|
|
||||||
def _handle_loop_completion(self, task: asyncio.Task):
|
def _handle_loop_completion(self, task: asyncio.Task):
|
||||||
"""当 _hfc_func 任务完成时执行的回调。"""
|
"""当 _hfc_func 任务完成时执行的回调。"""
|
||||||
@@ -195,59 +228,72 @@ class HeartFChatting:
|
|||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
logger.info(f"{self.log_prefix} HeartFChatting: 结束了聊天")
|
logger.info(f"{self.log_prefix} HeartFChatting: 结束了聊天")
|
||||||
|
|
||||||
def _get_talk_value(self, session_id: Optional[str]) -> float:
|
# ====== 反思器和学习器触发逻辑 ======
|
||||||
result = global_config.chat.talk_value or 0.0
|
async def _trigger_reflector(self):
|
||||||
if not global_config.chat.enable_talk_value_rules or not global_config.chat.talk_value_rules:
|
await self._reflector.check_and_ask()
|
||||||
return result
|
if self._reflector.reflect_tracker.tracking and await self._reflector.reflect_tracker.trigger_tracker():
|
||||||
local_time = time.localtime()
|
logger.info(f"{self.log_prefix} 追踪检查已解决,结束追踪器")
|
||||||
now_min = local_time.tm_hour * 60 + local_time.tm_min
|
self._reflector.reflect_tracker.reset_tracker() # 结束当前追踪器
|
||||||
|
|
||||||
# 优先匹配会话相关的规则
|
async def _trigger_expression_learning(self, messages: List["SessionMessage"]):
|
||||||
if session_id:
|
self._expression_learner.add_messages(messages)
|
||||||
for rule in global_config.chat.talk_value_rules:
|
if time.time() - self._last_extraction_time < self._min_extraction_interval:
|
||||||
if not rule.platform and not rule.item_id:
|
return
|
||||||
continue # 一起留空表示全局
|
if self._expression_learner.get_cache_size() < self._min_messages_for_extraction:
|
||||||
if rule.rule_type == "group":
|
return
|
||||||
rule_session_id = SessionUtils.calculate_session_id(rule.platform, group_id=str(rule.item_id))
|
extraction_end_time = time.time()
|
||||||
else:
|
logger.info(
|
||||||
rule_session_id = SessionUtils.calculate_session_id(rule.platform, user_id=str(rule.item_id))
|
f"聊天流 {self.session_name} 提取到 {len(messages)} 条消息,"
|
||||||
if rule_session_id != session_id:
|
f"时间窗口: {self._last_extraction_time:.2f} - {extraction_end_time:.2f}"
|
||||||
continue # 不匹配的会话ID,跳过
|
)
|
||||||
parsed_range = self._parse_range(rule.time)
|
self._last_extraction_time = extraction_end_time
|
||||||
if not parsed_range:
|
if self._enable_expression_learning:
|
||||||
continue # 无法解析的时间范围,跳过
|
asyncio.create_task(self._expression_learning())
|
||||||
start_min, end_min = parsed_range
|
|
||||||
in_range: bool = False
|
|
||||||
if start_min <= end_min:
|
|
||||||
in_range = start_min <= now_min <= end_min
|
|
||||||
else: # 跨天的时间范围
|
|
||||||
in_range = now_min >= start_min or now_min <= end_min
|
|
||||||
if in_range:
|
|
||||||
return rule.value or 0.0 # 如果规则生效但没有设置值,返回0.0
|
|
||||||
|
|
||||||
# 没有匹配到会话相关的规则,继续匹配全局规则
|
async def _expression_learning(self):
|
||||||
for rule in global_config.chat.talk_value_rules:
|
|
||||||
if rule.platform or rule.item_id:
|
|
||||||
continue # 只匹配全局规则
|
|
||||||
parsed_range = self._parse_range(rule.time)
|
|
||||||
if not parsed_range:
|
|
||||||
continue # 无法解析的时间范围,跳过
|
|
||||||
start_min, end_min = parsed_range
|
|
||||||
in_range: bool = False
|
|
||||||
if start_min <= end_min:
|
|
||||||
in_range = start_min <= now_min <= end_min
|
|
||||||
else: # 跨天的时间范围
|
|
||||||
in_range = now_min >= start_min or now_min <= end_min
|
|
||||||
if in_range:
|
|
||||||
return rule.value or 0.0 # 如果规则生效但没有设置值,返回0.0
|
|
||||||
return result # 如果没有任何规则生效,返回默认值
|
|
||||||
|
|
||||||
def _parse_range(self, range_str: str) -> Optional[tuple[int, int]]:
|
|
||||||
"""解析 "HH:MM-HH:MM" 到 (start_min, end_min)。"""
|
|
||||||
try:
|
try:
|
||||||
start_str, end_str = [s.strip() for s in range_str.split("-")]
|
learnt_style = await self._expression_learner.learn()
|
||||||
sh, sm = [int(x) for x in start_str.split(":")]
|
if learnt_style:
|
||||||
eh, em = [int(x) for x in end_str.split(":")]
|
logger.info(f"{self.log_prefix} 表达学习完成")
|
||||||
return sh * 60 + sm, eh * 60 + em
|
else:
|
||||||
except Exception:
|
logger.debug(f"{self.log_prefix} 表达学习未获得有效结果")
|
||||||
return None
|
except Exception as e:
|
||||||
|
logger.error(f"{self.log_prefix} 表达学习失败: {e}", exc_info=True)
|
||||||
|
|
||||||
|
# ====== 记录循环执行信息相关逻辑 ======
|
||||||
|
def _start_cycle(self) -> CycleDetail:
|
||||||
|
self._cycle_counter += 1
|
||||||
|
current_cycle_detail = CycleDetail(cycle_id=self._cycle_counter)
|
||||||
|
current_cycle_detail.thinking_id = f"tid{str(round(time.time(), 2))}"
|
||||||
|
return current_cycle_detail
|
||||||
|
|
||||||
|
def _end_cycle(self, cycle_detail: CycleDetail, only_long_execution: bool = True):
|
||||||
|
cycle_detail.end_time = time.time()
|
||||||
|
timer_strings: List[str] = [
|
||||||
|
f"{name}: {duration:.2f}s"
|
||||||
|
for name, duration in cycle_detail.time_records.items()
|
||||||
|
if not only_long_execution or duration >= 0.1
|
||||||
|
]
|
||||||
|
logger.info(
|
||||||
|
f"{self.log_prefix} 第 {cycle_detail.cycle_id} 个心流循环完成"
|
||||||
|
f"耗时: {cycle_detail.end_time - cycle_detail.start_time:.2f}秒\n"
|
||||||
|
f"详细计时: {', '.join(timer_strings) if timer_strings else '无'}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return cycle_detail
|
||||||
|
|
||||||
|
# ====== Action相关逻辑 ======
|
||||||
|
async def _execute_action(self, *args, **kwargs):
|
||||||
|
"""原ExecuteAction"""
|
||||||
|
raise NotImplementedError("执行动作的逻辑尚未实现") # TODO: 实现动作执行的逻辑,替换掉*args, **kwargs*占位符
|
||||||
|
|
||||||
|
async def _execute_other_actions(self, *args, **kwargs):
|
||||||
|
"""原HandleAction"""
|
||||||
|
raise NotImplementedError(
|
||||||
|
"执行其他动作的逻辑尚未实现"
|
||||||
|
) # TODO: 实现其他动作执行的逻辑, 替换掉*args, **kwargs*占位符
|
||||||
|
|
||||||
|
# ====== 响应发送相关方法 ======
|
||||||
|
async def _send_response(self, *args, **kwargs):
|
||||||
|
raise NotImplementedError("发送回复的逻辑尚未实现") # TODO: 实现发送回复的逻辑,替换掉*args, **kwargs*占位符
|
||||||
|
# 传入的消息至少应该是个MessageSequence实例,最好是SessionMessage实例,随后可直接转化为MessageSending实例
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ from src.chat.utils.timer_calculator import Timer
|
|||||||
from src.chat.planner_actions.planner import ActionPlanner
|
from src.chat.planner_actions.planner import ActionPlanner
|
||||||
from src.chat.planner_actions.action_modifier import ActionModifier
|
from src.chat.planner_actions.action_modifier import ActionModifier
|
||||||
from src.chat.planner_actions.action_manager import ActionManager
|
from src.chat.planner_actions.action_manager import ActionManager
|
||||||
from src.chat.heart_flow.hfc_utils import CycleDetail
|
from src.chat.heart_flow.hfc_utils_old import CycleDetail
|
||||||
from src.bw_learner.expression_learner_old import expression_learner_manager
|
from src.bw_learner.expression_learner_old import expression_learner_manager
|
||||||
from src.chat.heart_flow.frequency_control import frequency_control_manager
|
from src.chat.heart_flow.frequency_control import frequency_control_manager
|
||||||
from src.bw_learner.reflect_tracker import reflect_tracker_manager
|
from src.bw_learner.reflect_tracker import reflect_tracker_manager
|
||||||
@@ -155,7 +155,7 @@ class HeartFChatting:
|
|||||||
|
|
||||||
def start_cycle(self) -> Tuple[Dict[str, float], str]:
|
def start_cycle(self) -> Tuple[Dict[str, float], str]:
|
||||||
self._cycle_counter += 1
|
self._cycle_counter += 1
|
||||||
self._current_cycle_detail = CycleDetail(self._cycle_counter)
|
self._current_cycle_detail = CycleDetail(cycle_id=self._cycle_counter)
|
||||||
self._current_cycle_detail.thinking_id = f"tid{str(round(time.time(), 2))}"
|
self._current_cycle_detail.thinking_id = f"tid{str(round(time.time(), 2))}"
|
||||||
cycle_timers = {}
|
cycle_timers = {}
|
||||||
return cycle_timers, self._current_cycle_detail.thinking_id
|
return cycle_timers, self._current_cycle_detail.thinking_id
|
||||||
|
|||||||
31
src/chat/heart_flow/heartFC_utils.py
Normal file
31
src/chat/heart_flow/heartFC_utils.py
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional, Dict, TypedDict
|
||||||
|
|
||||||
|
import time
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CyclePlanInfo(TypedDict): ... # TODO: 根据实际需要补充字段
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CycleActionInfo(TypedDict): ... # TODO: 根据实际需要补充字段
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CycleDetail:
|
||||||
|
"""循环信息记录类"""
|
||||||
|
|
||||||
|
cycle_id: int
|
||||||
|
thinking_id: str = ""
|
||||||
|
"""思考ID"""
|
||||||
|
start_time: float = time.time()
|
||||||
|
"""开始时间,单位为秒"""
|
||||||
|
end_time: Optional[float] = None
|
||||||
|
"""结束时间,单位为秒,None表示未结束"""
|
||||||
|
time_records: Dict[str, float] = {}
|
||||||
|
"""计时器记录,key为计时器名称,value为用时,单位为秒"""
|
||||||
|
loop_plan_info: Optional[CyclePlanInfo] = None
|
||||||
|
"""循环计划记录"""
|
||||||
|
loop_action_info: Optional[CycleActionInfo] = None
|
||||||
|
"""循环Action调用记录"""
|
||||||
@@ -31,61 +31,6 @@ class CycleDetail:
|
|||||||
self.end_time: Optional[float] = None
|
self.end_time: Optional[float] = None
|
||||||
self.timers: Dict[str, float] = {}
|
self.timers: Dict[str, float] = {}
|
||||||
|
|
||||||
self.loop_plan_info: CyclePlanInfo = CyclePlanInfo()
|
|
||||||
self.loop_action_info: CycleActionInfo = CycleActionInfo()
|
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
|
||||||
"""将循环信息转换为字典格式"""
|
|
||||||
|
|
||||||
def convert_to_serializable(obj, depth=0, seen=None):
|
|
||||||
if seen is None:
|
|
||||||
seen = set()
|
|
||||||
|
|
||||||
# 防止递归过深
|
|
||||||
if depth > 5: # 降低递归深度限制
|
|
||||||
return str(obj)
|
|
||||||
|
|
||||||
# 防止循环引用
|
|
||||||
obj_id = id(obj)
|
|
||||||
if obj_id in seen:
|
|
||||||
return str(obj)
|
|
||||||
seen.add(obj_id)
|
|
||||||
|
|
||||||
try:
|
|
||||||
if hasattr(obj, "to_dict"):
|
|
||||||
# 对于有to_dict方法的对象,直接调用其to_dict方法
|
|
||||||
return obj.to_dict()
|
|
||||||
elif isinstance(obj, dict):
|
|
||||||
# 对于字典,只保留基本类型和可序列化的值
|
|
||||||
return {
|
|
||||||
k: convert_to_serializable(v, depth + 1, seen)
|
|
||||||
for k, v in obj.items()
|
|
||||||
if isinstance(k, (str, int, float, bool))
|
|
||||||
}
|
|
||||||
elif isinstance(obj, (list, tuple)):
|
|
||||||
# 对于列表和元组,只保留可序列化的元素
|
|
||||||
return [
|
|
||||||
convert_to_serializable(item, depth + 1, seen)
|
|
||||||
for item in obj
|
|
||||||
if not isinstance(item, (dict, list, tuple))
|
|
||||||
or isinstance(item, (str, int, float, bool, type(None)))
|
|
||||||
]
|
|
||||||
elif isinstance(obj, (str, int, float, bool, type(None))):
|
|
||||||
return obj
|
|
||||||
else:
|
|
||||||
return str(obj)
|
|
||||||
finally:
|
|
||||||
seen.remove(obj_id)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"cycle_id": self.cycle_id,
|
|
||||||
"start_time": self.start_time,
|
|
||||||
"end_time": self.end_time,
|
|
||||||
"timers": self.timers,
|
|
||||||
"thinking_id": self.thinking_id,
|
|
||||||
"loop_plan_info": convert_to_serializable(self.loop_plan_info),
|
|
||||||
"loop_action_info": convert_to_serializable(self.loop_action_info),
|
|
||||||
}
|
|
||||||
|
|
||||||
def set_loop_info(self, loop_info: Dict[str, Any]):
|
def set_loop_info(self, loop_info: Dict[str, Any]):
|
||||||
"""设置循环信息"""
|
"""设置循环信息"""
|
||||||
133
src/common/utils/utils_config.py
Normal file
133
src/common/utils/utils_config.py
Normal file
@@ -0,0 +1,133 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import time
|
||||||
|
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
from src.config.config import global_config
|
||||||
|
|
||||||
|
logger = get_logger("config_utils")
|
||||||
|
|
||||||
|
|
||||||
|
class ExpressionConfigUtils:
|
||||||
|
@staticmethod
|
||||||
|
def get_expression_config_for_chat(session_id: Optional[str] = None) -> tuple[bool, bool, bool]:
|
||||||
|
# sourcery skip: use-next
|
||||||
|
"""
|
||||||
|
根据聊天会话ID获取表达配置
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_id: 聊天会话ID,格式为哈希值
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: (是否使用表达, 是否学习表达, 是否启用jargon学习)
|
||||||
|
"""
|
||||||
|
if not global_config.expression.learning_list:
|
||||||
|
return True, True, True
|
||||||
|
|
||||||
|
if session_id:
|
||||||
|
for config_item in global_config.expression.learning_list:
|
||||||
|
if not config_item.platform and not config_item.item_id:
|
||||||
|
continue # 这是全局的
|
||||||
|
stream_id = ExpressionConfigUtils._get_stream_id(
|
||||||
|
config_item.platform,
|
||||||
|
str(config_item.item_id),
|
||||||
|
(config_item.rule_type == "group"),
|
||||||
|
)
|
||||||
|
if stream_id is None:
|
||||||
|
continue
|
||||||
|
if stream_id == session_id:
|
||||||
|
continue
|
||||||
|
return config_item.use_expression, config_item.enable_learning, config_item.enable_jargon_learning
|
||||||
|
for config_item in global_config.expression.learning_list:
|
||||||
|
if not config_item.platform and not config_item.item_id:
|
||||||
|
return config_item.use_expression, config_item.enable_learning, config_item.enable_jargon_learning
|
||||||
|
|
||||||
|
return True, True, True
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_stream_id(platform: str, id_str: str, is_group: bool = False) -> Optional[str]:
|
||||||
|
# sourcery skip: remove-unnecessary-cast
|
||||||
|
"""
|
||||||
|
根据平台、ID字符串和是否为群聊生成聊天流ID
|
||||||
|
|
||||||
|
Args:
|
||||||
|
platform: 平台名称
|
||||||
|
id_str: 用户或群组的原始ID字符串
|
||||||
|
is_group: 是否为群聊
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 生成的聊天流ID(哈希值)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from src.common.utils.utils_session import SessionUtils
|
||||||
|
|
||||||
|
if is_group:
|
||||||
|
return SessionUtils.calculate_session_id(platform, group_id=str(id_str))
|
||||||
|
else:
|
||||||
|
return SessionUtils.calculate_session_id(platform, user_id=str(id_str))
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"生成聊天流ID失败: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class ChatConfigUtils:
|
||||||
|
@staticmethod
|
||||||
|
def get_talk_value(session_id: Optional[str]) -> float:
|
||||||
|
result = global_config.chat.talk_value or 0.0
|
||||||
|
if not global_config.chat.enable_talk_value_rules or not global_config.chat.talk_value_rules:
|
||||||
|
return result
|
||||||
|
local_time = time.localtime()
|
||||||
|
now_min = local_time.tm_hour * 60 + local_time.tm_min
|
||||||
|
|
||||||
|
# 优先匹配会话相关的规则
|
||||||
|
if session_id:
|
||||||
|
from src.common.utils.utils_session import SessionUtils
|
||||||
|
|
||||||
|
for rule in global_config.chat.talk_value_rules:
|
||||||
|
if not rule.platform and not rule.item_id:
|
||||||
|
continue # 一起留空表示全局
|
||||||
|
if rule.rule_type == "group":
|
||||||
|
rule_session_id = SessionUtils.calculate_session_id(rule.platform, group_id=str(rule.item_id))
|
||||||
|
else:
|
||||||
|
rule_session_id = SessionUtils.calculate_session_id(rule.platform, user_id=str(rule.item_id))
|
||||||
|
if rule_session_id != session_id:
|
||||||
|
continue # 不匹配的会话ID,跳过
|
||||||
|
parsed_range = ChatConfigUtils.parse_range(rule.time)
|
||||||
|
if not parsed_range:
|
||||||
|
continue # 无法解析的时间范围,跳过
|
||||||
|
start_min, end_min = parsed_range
|
||||||
|
in_range: bool = False
|
||||||
|
if start_min <= end_min:
|
||||||
|
in_range = start_min <= now_min <= end_min
|
||||||
|
else: # 跨天的时间范围
|
||||||
|
in_range = now_min >= start_min or now_min <= end_min
|
||||||
|
if in_range:
|
||||||
|
return rule.value or 0.0 # 如果规则生效但没有设置值,返回0.0
|
||||||
|
|
||||||
|
# 没有匹配到会话相关的规则,继续匹配全局规则
|
||||||
|
for rule in global_config.chat.talk_value_rules:
|
||||||
|
if rule.platform or rule.item_id:
|
||||||
|
continue # 只匹配全局规则
|
||||||
|
parsed_range = ChatConfigUtils.parse_range(rule.time)
|
||||||
|
if not parsed_range:
|
||||||
|
continue # 无法解析的时间范围,跳过
|
||||||
|
start_min, end_min = parsed_range
|
||||||
|
in_range: bool = False
|
||||||
|
if start_min <= end_min:
|
||||||
|
in_range = start_min <= now_min <= end_min
|
||||||
|
else: # 跨天的时间范围
|
||||||
|
in_range = now_min >= start_min or now_min <= end_min
|
||||||
|
if in_range:
|
||||||
|
return rule.value or 0.0 # 如果规则生效但没有设置值,返回0.0
|
||||||
|
return result # 如果没有任何规则生效,返回默认值
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def parse_range(range_str: str) -> Optional[tuple[int, int]]:
|
||||||
|
"""解析 "HH:MM-HH:MM" 到 (start_min, end_min)。"""
|
||||||
|
try:
|
||||||
|
start_str, end_str = [s.strip() for s in range_str.split("-")]
|
||||||
|
sh, sm = [int(x) for x in start_str.split(":")]
|
||||||
|
eh, em = [int(x) for x in end_str.split(":")]
|
||||||
|
return sh * 60 + sm, eh * 60 + em
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
Reference in New Issue
Block a user