feat:合并timing和plan展示,回复频率控制
This commit is contained in:
@@ -123,6 +123,7 @@ async def test_reply_tool_puts_monitor_detail_into_metadata(monkeypatch: pytest.
|
||||
session_id="session-1",
|
||||
_chat_history=[],
|
||||
_clear_force_continue_until_reply=lambda: None,
|
||||
_record_reply_sent=lambda: None,
|
||||
run_sub_agent=None,
|
||||
)
|
||||
engine = SimpleNamespace(_get_runtime_manager=lambda: None)
|
||||
@@ -214,15 +215,26 @@ async def test_emit_planner_finalized_broadcasts_new_protocol(monkeypatch: pytes
|
||||
await emit_planner_finalized(
|
||||
session_id="session-1",
|
||||
cycle_id=3,
|
||||
request_messages=[{"role": "user", "content": "你好"}],
|
||||
selected_history_count=5,
|
||||
tool_count=2,
|
||||
timing_request_messages=[{"role": "user", "content": "先看看要不要继续"}],
|
||||
timing_selected_history_count=3,
|
||||
timing_tool_count=1,
|
||||
timing_action="continue",
|
||||
timing_content="继续",
|
||||
timing_tool_calls=[SimpleNamespace(call_id="timing-call-1", func_name="continue", args={})],
|
||||
timing_tool_results=["- continue [成功]: 继续执行"],
|
||||
timing_prompt_tokens=40,
|
||||
timing_completion_tokens=5,
|
||||
timing_total_tokens=45,
|
||||
timing_duration_ms=11.2,
|
||||
planner_request_messages=[{"role": "user", "content": "你好"}],
|
||||
planner_selected_history_count=5,
|
||||
planner_tool_count=2,
|
||||
planner_content="先查询再回复",
|
||||
planner_tool_calls=[SimpleNamespace(call_id="call-1", func_name="reply", args={"msg_id": "m1"})],
|
||||
prompt_tokens=100,
|
||||
completion_tokens=30,
|
||||
total_tokens=130,
|
||||
duration_ms=88.5,
|
||||
planner_prompt_tokens=100,
|
||||
planner_completion_tokens=30,
|
||||
planner_total_tokens=130,
|
||||
planner_duration_ms=88.5,
|
||||
tools=[
|
||||
{
|
||||
"tool_call_id": "call-1",
|
||||
@@ -240,6 +252,8 @@ async def test_emit_planner_finalized_broadcasts_new_protocol(monkeypatch: pytes
|
||||
|
||||
assert captured["event"] == "planner.finalized"
|
||||
payload = captured["data"]
|
||||
assert payload["timing_gate"]["result"]["action"] == "continue"
|
||||
assert payload["timing_gate"]["result"]["tool_results"] == ["- continue [成功]: 继续执行"]
|
||||
assert payload["request"]["messages"][0]["content"] == "你好"
|
||||
assert payload["request"]["tool_count"] == 2
|
||||
assert payload["planner"]["tool_calls"][0]["id"] == "call-1"
|
||||
@@ -247,6 +261,51 @@ async def test_emit_planner_finalized_broadcasts_new_protocol(monkeypatch: pytes
|
||||
assert payload["final_state"]["agent_state"] == "stop"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_emit_planner_finalized_supports_timing_only_cycle(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
captured: dict[str, Any] = {}
|
||||
|
||||
async def _fake_broadcast(event: str, data: dict[str, Any]) -> None:
|
||||
captured["event"] = event
|
||||
captured["data"] = data
|
||||
|
||||
monkeypatch.setattr("src.maisaka.monitor_events._broadcast", _fake_broadcast)
|
||||
|
||||
await emit_planner_finalized(
|
||||
session_id="session-2",
|
||||
cycle_id=7,
|
||||
timing_request_messages=[{"role": "user", "content": "先别回"}],
|
||||
timing_selected_history_count=2,
|
||||
timing_tool_count=1,
|
||||
timing_action="no_reply",
|
||||
timing_content="当前不适合继续",
|
||||
timing_tool_calls=[SimpleNamespace(call_id="timing-call-2", func_name="no_reply", args={})],
|
||||
timing_tool_results=["- no_reply [成功]: 暂停当前对话"],
|
||||
timing_prompt_tokens=18,
|
||||
timing_completion_tokens=4,
|
||||
timing_total_tokens=22,
|
||||
timing_duration_ms=6.5,
|
||||
planner_request_messages=None,
|
||||
planner_selected_history_count=None,
|
||||
planner_tool_count=None,
|
||||
planner_content=None,
|
||||
planner_tool_calls=None,
|
||||
planner_prompt_tokens=None,
|
||||
planner_completion_tokens=None,
|
||||
planner_total_tokens=None,
|
||||
planner_duration_ms=None,
|
||||
tools=[],
|
||||
time_records={"timing_gate": 0.02},
|
||||
agent_state="stop",
|
||||
)
|
||||
|
||||
assert captured["event"] == "planner.finalized"
|
||||
payload = captured["data"]
|
||||
assert payload["timing_gate"]["result"]["action"] == "no_reply"
|
||||
assert payload["planner"] is None
|
||||
assert payload["request"] is None
|
||||
|
||||
|
||||
def test_reasoning_engine_build_tool_monitor_result_keeps_non_reply_tool_without_detail() -> None:
|
||||
engine = object.__new__(MaisakaReasoningEngine)
|
||||
tool_call = SimpleNamespace(call_id="call-2", func_name="query_memory")
|
||||
|
||||
@@ -1,284 +0,0 @@
|
||||
from rich.traceback import install
|
||||
from typing import List, Optional, TYPE_CHECKING
|
||||
|
||||
import asyncio
|
||||
import random
|
||||
import time
|
||||
import traceback
|
||||
|
||||
from src.chat.message_receive.chat_manager import chat_manager
|
||||
from src.common.logger import get_logger
|
||||
from src.common.utils.utils_config import ChatConfigUtils, ExpressionConfigUtils
|
||||
from src.config.config import global_config
|
||||
from src.config.file_watcher import FileChange
|
||||
from src.learners.expression_learner import ExpressionLearner
|
||||
from src.learners.jargon_miner import JargonMiner
|
||||
|
||||
from .heartFC_utils import CycleDetail
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
|
||||
install(extra_lines=5)
|
||||
|
||||
logger = get_logger("heartFC_chat")
|
||||
|
||||
|
||||
class HeartFChatting:
|
||||
"""
|
||||
管理一个连续的Focus Chat聊天会话
|
||||
用于在特定的聊天会话里面生成回复
|
||||
"""
|
||||
|
||||
def __init__(self, session_id: str):
|
||||
"""
|
||||
初始化 HeartFChatting 实例
|
||||
|
||||
Args:
|
||||
session_id: 聊天会话ID
|
||||
"""
|
||||
# 基础属性
|
||||
self.session_id = session_id
|
||||
session_name = chat_manager.get_session_name(session_id) or session_id
|
||||
self.log_prefix = f"[{session_name}]"
|
||||
self.session_name = session_name
|
||||
|
||||
# 系统运行状态
|
||||
self._running: bool = False
|
||||
self._loop_task: Optional[asyncio.Task] = None
|
||||
self._cycle_counter: int = 0
|
||||
self._hfc_lock: asyncio.Lock = asyncio.Lock() # 用于保护 _hfc_func 的并发访问
|
||||
# 聊天频率相关
|
||||
self._consecutive_no_reply_count = 0 # 跟踪连续 no_reply 次数,用于动态调整阈值
|
||||
self._talk_frequency_adjust: float = 1.0 # 发言频率修正值,默认为1.0,可以根据需要调整
|
||||
|
||||
# HFC内消息缓存
|
||||
self.message_cache: List[SessionMessage] = []
|
||||
|
||||
# Asyncio Event 用于控制循环的开始和结束
|
||||
self._cycle_event = asyncio.Event()
|
||||
|
||||
# 表达方式相关内容
|
||||
self._min_messages_for_extraction = 30 # 最少提取消息数
|
||||
self._min_extraction_interval = 60 # 最小提取时间间隔,单位为秒
|
||||
self._last_extraction_time: float = 0.0 # 上次提取的时间戳
|
||||
expr_use, jargon_learn, expr_learn = ExpressionConfigUtils.get_expression_config_for_chat(session_id)
|
||||
self._enable_expression_use = expr_use # 允许使用表达方式,但不一定启用学习
|
||||
self._enable_expression_learning = expr_learn # 允许学习表达方式
|
||||
self._enable_jargon_learning = jargon_learn # 允许学习黑话
|
||||
# 表达学习器
|
||||
self._expression_learner: ExpressionLearner = ExpressionLearner(session_id)
|
||||
# 黑话挖掘器
|
||||
self._jargon_miner: JargonMiner = JargonMiner(session_id, session_name=session_name)
|
||||
|
||||
# TODO: ChatSummarizer 聊天总结器重构
|
||||
|
||||
# ====== 公开方法 ======
|
||||
|
||||
async def start(self):
|
||||
"""启动 HeartFChatting 的主循环"""
|
||||
# 先检查是否已经启动运行
|
||||
if self._running:
|
||||
logger.debug(f"{self.log_prefix} 已经在运行中,无需重复启动")
|
||||
return
|
||||
|
||||
try:
|
||||
self._running = True
|
||||
self._cycle_event.clear() # 确保事件初始状态为未设置
|
||||
|
||||
self._loop_task = asyncio.create_task(self.main_loop())
|
||||
self._loop_task.add_done_callback(self._handle_loop_completion)
|
||||
|
||||
logger.info(f"{self.log_prefix} HeartFChatting 启动完成")
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 启动 HeartFChatting 失败: {e}", exc_info=True)
|
||||
self._running = False # 确保状态正确
|
||||
self._cycle_event.set() # 确保事件被设置,避免死锁
|
||||
self._loop_task = None # 确保任务引用被清理
|
||||
raise
|
||||
|
||||
async def stop(self):
|
||||
"""停止 HeartFChatting 的主循环"""
|
||||
if not self._running:
|
||||
logger.debug(f"{self.log_prefix} HeartFChatting 已经停止,无需重复停止")
|
||||
return
|
||||
|
||||
self._running = False
|
||||
self._cycle_event.set() # 触发事件,通知循环结束
|
||||
|
||||
if self._loop_task:
|
||||
self._loop_task.cancel() # 取消主循环任务
|
||||
try:
|
||||
await self._loop_task # 等待任务完成
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"{self.log_prefix} HeartFChatting 主循环已成功取消")
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 停止 HeartFChatting 时发生错误: {e}", exc_info=True)
|
||||
finally:
|
||||
self._loop_task = None # 确保任务引用被清理
|
||||
|
||||
logger.info(f"{self.log_prefix} HeartFChatting 已停止")
|
||||
|
||||
def adjust_talk_frequency(self, new_value: float):
|
||||
"""调整发言频率的调整值
|
||||
|
||||
Args:
|
||||
new_value: 新的修正值,必须为非负数。值越大,修正发言频率越高;值越小,修正发言频率越低。
|
||||
"""
|
||||
self._talk_frequency_adjust = max(0.0, new_value)
|
||||
|
||||
async def register_message(self, message: "SessionMessage"):
|
||||
"""注册一条消息到 HeartFChatting 的缓存中,并检测其是否产生提及,决定是否唤醒聊天
|
||||
|
||||
Args:
|
||||
message: 待注册的消息对象
|
||||
"""
|
||||
self.message_cache.append(message)
|
||||
# 先检查at必回复
|
||||
if global_config.chat.inevitable_at_reply and message.is_at:
|
||||
async with self._hfc_lock: # 确保与主循环逻辑的互斥访问
|
||||
await self._judge_and_response(message)
|
||||
return # 直接返回,避免同一条消息被主循环再次处理
|
||||
# 再检查提及必回复
|
||||
if global_config.chat.mentioned_bot_reply and message.is_mentioned:
|
||||
# 直接获取锁,确保一定一定触发回复逻辑,不受当前是否正在执行主循环的影响
|
||||
async with self._hfc_lock: # 确保与主循环逻辑的互斥访问
|
||||
await self._judge_and_response(message)
|
||||
return
|
||||
|
||||
async def main_loop(self):
|
||||
try:
|
||||
while self._running and not self._cycle_event.is_set():
|
||||
if not self._hfc_lock.locked():
|
||||
async with self._hfc_lock: # 确保主循环逻辑的互斥访问
|
||||
await self._hfc_func()
|
||||
await asyncio.sleep(5)
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"{self.log_prefix} HeartFChatting: 主循环被取消,正在关闭")
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 麦麦聊天意外错误: {e},将于3s后尝试重新启动")
|
||||
await self.stop() # 确保状态正确
|
||||
await asyncio.sleep(3)
|
||||
await self.start() # 尝试重新启动
|
||||
|
||||
async def _config_callback(self, file_change: Optional[FileChange] = None):
|
||||
"""配置文件变更回调函数"""
|
||||
# TODO: 根据配置文件变动重新计算相关参数:
|
||||
"""
|
||||
需要计算的参数:
|
||||
self._enable_expression_use = expr_use # 允许使用表达方式,但不一定启用学习
|
||||
self._enable_expression_learning = expr_learn # 允许学习表达方式
|
||||
self._enable_jargon_learning = jargon_learn # 允许学习黑话
|
||||
"""
|
||||
|
||||
# ====== 心流聊天核心逻辑 ======
|
||||
async def _hfc_func(self, mentioned_message: Optional["SessionMessage"] = None):
|
||||
"""心流聊天的主循环逻辑"""
|
||||
if self._consecutive_no_reply_count >= 5:
|
||||
threshold = 2
|
||||
elif self._consecutive_no_reply_count >= 3:
|
||||
threshold = 2 if random.random() < 0.5 else 1
|
||||
else:
|
||||
threshold = 1
|
||||
|
||||
if len(self.message_cache) < threshold:
|
||||
await asyncio.sleep(0.2)
|
||||
return True
|
||||
|
||||
talk_value_threshold = (
|
||||
random.random() * ChatConfigUtils.get_talk_value(self.session_id) * self._talk_frequency_adjust
|
||||
)
|
||||
if mentioned_message and global_config.chat.mentioned_bot_reply:
|
||||
await self._judge_and_response(mentioned_message)
|
||||
elif random.random() < talk_value_threshold:
|
||||
await self._judge_and_response()
|
||||
return True
|
||||
|
||||
async def _judge_and_response(self, mentioned_message: Optional["SessionMessage"] = None):
|
||||
"""判定和生成回复"""
|
||||
asyncio.create_task(self._trigger_expression_learning(self.message_cache))
|
||||
# TODO: 完成反思器之后的逻辑
|
||||
current_cycle_detail = self._start_cycle()
|
||||
logger.info(f"{self.log_prefix} 开始第{self._cycle_counter}次思考")
|
||||
|
||||
# TODO: 动作检查逻辑
|
||||
# TODO: Planner逻辑
|
||||
# TODO: 动作执行逻辑
|
||||
|
||||
self._end_cycle(current_cycle_detail)
|
||||
await asyncio.sleep(0.1) # 最小等待时间,避免过快循环
|
||||
return True
|
||||
|
||||
def _handle_loop_completion(self, task: asyncio.Task):
|
||||
"""当 _hfc_func 任务完成时执行的回调。"""
|
||||
try:
|
||||
if exception := task.exception():
|
||||
logger.error(f"{self.log_prefix} HeartFChatting: 脱离了聊天(异常): {exception}")
|
||||
logger.error(traceback.format_exc()) # Log full traceback for exceptions
|
||||
else:
|
||||
logger.info(f"{self.log_prefix} HeartFChatting: 脱离了聊天 (外部停止)")
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"{self.log_prefix} HeartFChatting: 结束了聊天")
|
||||
|
||||
# ====== 学习器触发逻辑 ======
|
||||
async def _trigger_expression_learning(self, messages: List["SessionMessage"]):
|
||||
self._expression_learner.add_messages(messages)
|
||||
if time.time() - self._last_extraction_time < self._min_extraction_interval:
|
||||
return
|
||||
if self._expression_learner.get_cache_size() < self._min_messages_for_extraction:
|
||||
return
|
||||
if not self._enable_expression_learning:
|
||||
return
|
||||
extraction_end_time = time.time()
|
||||
logger.info(
|
||||
f"聊天流 {self.session_name} 提取到 {len(messages)} 条消息,"
|
||||
f"时间窗口: {self._last_extraction_time:.2f} - {extraction_end_time:.2f}"
|
||||
)
|
||||
self._last_extraction_time = extraction_end_time
|
||||
try:
|
||||
jargon_miner = self._jargon_miner if self._enable_jargon_learning else None
|
||||
learnt_style = await self._expression_learner.learn(jargon_miner)
|
||||
if learnt_style:
|
||||
logger.info(f"{self.log_prefix} 表达学习完成")
|
||||
else:
|
||||
logger.debug(f"{self.log_prefix} 表达学习未获得有效结果")
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 表达学习失败: {e}", exc_info=True)
|
||||
|
||||
# ====== 记录循环执行信息相关逻辑 ======
|
||||
def _start_cycle(self) -> CycleDetail:
|
||||
self._cycle_counter += 1
|
||||
current_cycle_detail = CycleDetail(cycle_id=self._cycle_counter)
|
||||
current_cycle_detail.thinking_id = f"tid{str(round(time.time(), 2))}"
|
||||
return current_cycle_detail
|
||||
|
||||
def _end_cycle(self, cycle_detail: CycleDetail, only_long_execution: bool = True):
|
||||
cycle_detail.end_time = time.time()
|
||||
timer_strings: List[str] = [
|
||||
f"{name}: {duration:.2f}s"
|
||||
for name, duration in cycle_detail.time_records.items()
|
||||
if not only_long_execution or duration >= 0.1
|
||||
]
|
||||
logger.info(
|
||||
f"{self.log_prefix} 第 {cycle_detail.cycle_id} 个心流循环完成"
|
||||
f"耗时: {cycle_detail.end_time - cycle_detail.start_time:.2f}秒\n"
|
||||
f"详细计时: {', '.join(timer_strings) if timer_strings else '无'}"
|
||||
)
|
||||
|
||||
return cycle_detail
|
||||
|
||||
# ====== Action相关逻辑 ======
|
||||
async def _execute_action(self, *args, **kwargs):
|
||||
"""原ExecuteAction"""
|
||||
raise NotImplementedError("执行动作的逻辑尚未实现") # TODO: 实现动作执行的逻辑,替换掉*args, **kwargs*占位符
|
||||
|
||||
async def _execute_other_actions(self, *args, **kwargs):
|
||||
"""原HandleAction"""
|
||||
raise NotImplementedError(
|
||||
"执行其他动作的逻辑尚未实现"
|
||||
) # TODO: 实现其他动作执行的逻辑, 替换掉*args, **kwargs*占位符
|
||||
|
||||
# ====== 响应发送相关方法 ======
|
||||
async def _send_response(self, *args, **kwargs):
|
||||
raise NotImplementedError("发送回复的逻辑尚未实现") # TODO: 实现发送回复的逻辑,替换掉*args, **kwargs*占位符
|
||||
# 传入的消息至少应该是个MessageSequence实例,最好是SessionMessage实例,随后可直接转化为MessageSending实例
|
||||
@@ -26,8 +26,6 @@ from src.services.message_service import (
|
||||
replace_user_references,
|
||||
translate_pid_to_description,
|
||||
)
|
||||
from src.learners.expression_selector import expression_selector
|
||||
|
||||
# from src.memory_system.memory_activator import MemoryActivator
|
||||
from src.person_info.person_info import Person
|
||||
from src.core.types import ActionInfo, EventType
|
||||
@@ -295,53 +293,19 @@ class DefaultReplyer:
|
||||
async def build_expression_habits(
|
||||
self, chat_history: str, target: str, reply_reason: str = "", think_level: int = 1
|
||||
) -> Tuple[str, List[int]]:
|
||||
# sourcery skip: for-append-to-extend
|
||||
"""构建表达习惯块
|
||||
"""构建表达习惯块。"""
|
||||
del chat_history
|
||||
del target
|
||||
del reply_reason
|
||||
del think_level
|
||||
|
||||
Args:
|
||||
chat_history: 聊天历史记录
|
||||
target: 目标消息内容
|
||||
reply_reason: planner给出的回复理由
|
||||
think_level: 思考级别,0/1/2
|
||||
|
||||
Returns:
|
||||
str: 表达习惯信息字符串
|
||||
"""
|
||||
# 检查是否允许在此聊天流中使用表达
|
||||
use_expression, _, _ = TempMethodsExpression.get_expression_config_for_chat(self.chat_stream.session_id)
|
||||
if not use_expression:
|
||||
return "", []
|
||||
style_habits = []
|
||||
# 使用从处理器传来的选中表达方式
|
||||
# 使用模型预测选择表达方式
|
||||
selected_expressions, selected_ids = await expression_selector.select_suitable_expressions(
|
||||
self.chat_stream.session_id,
|
||||
chat_history,
|
||||
max_num=8,
|
||||
target_message=target,
|
||||
reply_reason=reply_reason,
|
||||
think_level=think_level,
|
||||
)
|
||||
|
||||
if selected_expressions:
|
||||
logger.debug(f"使用处理器选中的{len(selected_expressions)}个表达方式")
|
||||
for expr in selected_expressions:
|
||||
if isinstance(expr, dict) and "situation" in expr and "style" in expr:
|
||||
style_habits.append(f"当{expr['situation']}时:{expr['style']}")
|
||||
else:
|
||||
logger.debug("没有从处理器获得表达方式,将使用空的表达方式")
|
||||
# 不再在replyer中进行随机选择,全部交给处理器处理
|
||||
|
||||
style_habits_str = "\n".join(style_habits)
|
||||
|
||||
# 动态构建expression habits块
|
||||
expression_habits_block = ""
|
||||
expression_habits_title = ""
|
||||
if style_habits_str.strip():
|
||||
expression_habits_title = "在回复时,你可以参考以下的语言习惯,不要生硬使用:"
|
||||
expression_habits_block += f"{style_habits_str}\n"
|
||||
|
||||
return f"{expression_habits_title}\n{expression_habits_block}", selected_ids
|
||||
# 旧 replyer 的表达方式选择链路已停用,这里不再执行额外的模型筛选。
|
||||
logger.debug("旧 replyer 表达方式选择已停用,跳过 expression habits 构建")
|
||||
return "", []
|
||||
|
||||
async def build_tool_info(self, chat_history: str, sender: str, target: str) -> str:
|
||||
del chat_history
|
||||
|
||||
@@ -31,7 +31,6 @@ from .official_configs import (
|
||||
MessageReceiveConfig,
|
||||
PersonalityConfig,
|
||||
PluginRuntimeConfig,
|
||||
RelationshipConfig,
|
||||
ResponsePostProcessConfig,
|
||||
ResponseSplitterConfig,
|
||||
TelemetryConfig,
|
||||
@@ -56,7 +55,7 @@ CONFIG_DIR: Path = PROJECT_ROOT / "config"
|
||||
BOT_CONFIG_PATH: Path = (CONFIG_DIR / "bot_config.toml").resolve().absolute()
|
||||
MODEL_CONFIG_PATH: Path = (CONFIG_DIR / "model_config.toml").resolve().absolute()
|
||||
MMC_VERSION: str = "1.0.0"
|
||||
CONFIG_VERSION: str = "8.5.1"
|
||||
CONFIG_VERSION: str = "8.5.2"
|
||||
MODEL_CONFIG_VERSION: str = "1.13.1"
|
||||
|
||||
logger = get_logger("config")
|
||||
|
||||
@@ -414,6 +414,7 @@ def try_migrate_legacy_bot_config_dict(data: dict[str, Any]) -> MigrationResult:
|
||||
|
||||
maisaka = _as_dict(data.get("maisaka"))
|
||||
mem = _as_dict(data.get("memory"))
|
||||
debug = _as_dict(data.get("debug"))
|
||||
if maisaka is not None:
|
||||
moved_memory_keys = ("enable_memory_query_tool", "memory_query_default_limit")
|
||||
if any(key in maisaka for key in moved_memory_keys) and mem is None:
|
||||
@@ -426,11 +427,19 @@ def try_migrate_legacy_bot_config_dict(data: dict[str, Any]) -> MigrationResult:
|
||||
migrated_any = True
|
||||
reasons.append(f"maisaka.{moved_key}_moved_to_memory")
|
||||
|
||||
if mem is not None and "show_memory_prompt" in mem and debug is None:
|
||||
debug = {}
|
||||
data["debug"] = debug
|
||||
|
||||
if mem is not None:
|
||||
if _migrate_target_item_list(mem, "global_memory_blacklist"):
|
||||
migrated_any = True
|
||||
reasons.append("memory.global_memory_blacklist")
|
||||
|
||||
if debug is not None and _move_section_key(mem, debug, "show_memory_prompt"):
|
||||
migrated_any = True
|
||||
reasons.append("memory.show_memory_prompt_moved_to_debug")
|
||||
|
||||
for removed_key in (
|
||||
"agent_timeout_seconds",
|
||||
"max_agent_iterations",
|
||||
@@ -440,6 +449,12 @@ def try_migrate_legacy_bot_config_dict(data: dict[str, Any]) -> MigrationResult:
|
||||
migrated_any = True
|
||||
reasons.append(f"memory.{removed_key}_removed")
|
||||
|
||||
relationship = _as_dict(data.get("relationship"))
|
||||
if relationship is not None:
|
||||
data.pop("relationship", None)
|
||||
migrated_any = True
|
||||
reasons.append("relationship_removed")
|
||||
|
||||
exp = _as_dict(data.get("experimental"))
|
||||
if exp is not None:
|
||||
if _migrate_extra_prompt_list(exp, "chat_prompts"):
|
||||
|
||||
@@ -1,587 +0,0 @@
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import json
|
||||
import time
|
||||
|
||||
from json_repair import repair_json
|
||||
|
||||
from src.chat.utils.common_utils import TempMethodsExpression
|
||||
from src.common.database.database_model import Expression
|
||||
from src.common.logger import get_logger
|
||||
from src.common.utils.utils_session import SessionUtils
|
||||
from src.config.config import global_config
|
||||
from src.learners.learner_utils_old import weighted_sample
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
|
||||
logger = get_logger("expression_selector")
|
||||
|
||||
|
||||
class ExpressionSelector:
|
||||
def __init__(self) -> None:
|
||||
"""初始化表达方式选择器。"""
|
||||
|
||||
self.llm_model = LLMServiceClient(
|
||||
task_name="utils", request_type="expression.selector"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _get_runtime_manager() -> Any:
|
||||
"""获取插件运行时管理器。
|
||||
|
||||
Returns:
|
||||
Any: 插件运行时管理器单例。
|
||||
"""
|
||||
|
||||
from src.plugin_runtime.integration import get_plugin_runtime_manager
|
||||
|
||||
return get_plugin_runtime_manager()
|
||||
|
||||
@staticmethod
|
||||
def _coerce_int(value: Any, default: int) -> int:
|
||||
"""将任意值安全转换为整数。
|
||||
|
||||
Args:
|
||||
value: 待转换的值。
|
||||
default: 转换失败时的默认值。
|
||||
|
||||
Returns:
|
||||
int: 转换后的整数结果。
|
||||
"""
|
||||
|
||||
try:
|
||||
return int(value)
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
|
||||
@staticmethod
|
||||
def _normalize_selected_expressions(raw_expressions: Any) -> List[Dict[str, Any]]:
|
||||
"""从 Hook 载荷恢复表达方式选择结果。
|
||||
|
||||
Args:
|
||||
raw_expressions: Hook 返回的表达方式列表。
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 恢复后的表达方式列表。
|
||||
"""
|
||||
|
||||
if not isinstance(raw_expressions, list):
|
||||
return []
|
||||
|
||||
normalized_expressions: List[Dict[str, Any]] = []
|
||||
for raw_expression in raw_expressions:
|
||||
if not isinstance(raw_expression, dict):
|
||||
continue
|
||||
expression_id = raw_expression.get("id")
|
||||
situation = str(raw_expression.get("situation") or "").strip()
|
||||
style = str(raw_expression.get("style") or "").strip()
|
||||
source_id = str(raw_expression.get("source_id") or "").strip()
|
||||
if not isinstance(expression_id, int) or not situation or not style or not source_id:
|
||||
continue
|
||||
normalized_expression = dict(raw_expression)
|
||||
normalized_expression["id"] = expression_id
|
||||
normalized_expression["situation"] = situation
|
||||
normalized_expression["style"] = style
|
||||
normalized_expression["source_id"] = source_id
|
||||
normalized_expressions.append(normalized_expression)
|
||||
return normalized_expressions
|
||||
|
||||
@staticmethod
|
||||
def _normalize_selected_expression_ids(raw_ids: Any, expressions: List[Dict[str, Any]]) -> List[int]:
|
||||
"""规范化最终选中的表达方式 ID 列表。
|
||||
|
||||
Args:
|
||||
raw_ids: Hook 返回的 ID 列表。
|
||||
expressions: 当前最终表达方式列表。
|
||||
|
||||
Returns:
|
||||
List[int]: 规范化后的 ID 列表。
|
||||
"""
|
||||
|
||||
if isinstance(raw_ids, list):
|
||||
normalized_ids = [item for item in raw_ids if isinstance(item, int)]
|
||||
if normalized_ids:
|
||||
return normalized_ids
|
||||
return [expression["id"] for expression in expressions if isinstance(expression.get("id"), int)]
|
||||
|
||||
def can_use_expression_for_chat(self, chat_id: str) -> bool:
|
||||
"""
|
||||
检查指定聊天流是否允许使用表达
|
||||
|
||||
Args:
|
||||
chat_id: 聊天流ID
|
||||
|
||||
Returns:
|
||||
bool: 是否允许使用表达
|
||||
"""
|
||||
try:
|
||||
use_expression, _, _ = TempMethodsExpression.get_expression_config_for_chat(chat_id)
|
||||
return use_expression
|
||||
except Exception as e:
|
||||
logger.error(f"检查表达使用权限失败: {e}")
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _parse_stream_config_to_chat_id(stream_config_str: str) -> Optional[str]:
|
||||
"""解析'platform:id:type'为chat_id,直接使用 ChatManager 提供的接口"""
|
||||
try:
|
||||
parts = stream_config_str.split(":")
|
||||
if len(parts) != 3:
|
||||
return None
|
||||
platform = parts[0]
|
||||
id_str = parts[1]
|
||||
stream_type = parts[2]
|
||||
is_group = stream_type == "group"
|
||||
return SessionUtils.calculate_session_id(
|
||||
platform, group_id=str(id_str) if is_group else None, user_id=None if is_group else str(id_str)
|
||||
)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_related_chat_ids(self, chat_id: str) -> List[str]:
|
||||
"""根据expression_groups配置,获取与当前chat_id相关的所有chat_id(包括自身)"""
|
||||
groups = global_config.expression.expression_groups
|
||||
|
||||
# 检查是否存在全局共享组(包含"*"的组)
|
||||
global_group_exists = any("*" in group for group in groups)
|
||||
|
||||
if global_group_exists:
|
||||
# 如果存在全局共享组,则返回所有可用的chat_id
|
||||
all_chat_ids = set()
|
||||
for group in groups:
|
||||
for stream_config_str in group:
|
||||
if chat_id_candidate := self._parse_stream_config_to_chat_id(stream_config_str):
|
||||
all_chat_ids.add(chat_id_candidate)
|
||||
return list(all_chat_ids) if all_chat_ids else [chat_id]
|
||||
|
||||
# 否则使用现有的组逻辑
|
||||
for group in groups:
|
||||
group_chat_ids = []
|
||||
for stream_config_str in group:
|
||||
if chat_id_candidate := self._parse_stream_config_to_chat_id(stream_config_str):
|
||||
group_chat_ids.append(chat_id_candidate)
|
||||
if chat_id in group_chat_ids:
|
||||
return group_chat_ids
|
||||
return [chat_id]
|
||||
|
||||
def _select_expressions_simple(self, chat_id: str, max_num: int) -> Tuple[List[Dict[str, Any]], List[int]]:
|
||||
"""
|
||||
简单模式:只选择 count > 1 的项目,要求至少有10个才进行选择,随机选5个,不进行LLM选择
|
||||
|
||||
Args:
|
||||
chat_id: 聊天流ID
|
||||
max_num: 最大选择数量(此参数在此模式下不使用,固定选择5个)
|
||||
|
||||
Returns:
|
||||
Tuple[List[Dict[str, Any]], List[int]]: 选中的表达方式列表和ID列表
|
||||
"""
|
||||
try:
|
||||
# 支持多chat_id合并抽选
|
||||
related_chat_ids = self.get_related_chat_ids(chat_id)
|
||||
|
||||
# 查询所有相关chat_id的表达方式,排除 rejected=1 的,且只选择 count > 1 的
|
||||
# 如果 expression_checked_only 为 True,则只选择 checked=True 且 rejected=False 的
|
||||
base_conditions = (
|
||||
(Expression.chat_id.in_(related_chat_ids)) & (~Expression.rejected) & (Expression.count > 1)
|
||||
)
|
||||
if global_config.expression.expression_checked_only:
|
||||
base_conditions = base_conditions & (Expression.checked)
|
||||
style_query = Expression.select().where(base_conditions)
|
||||
|
||||
style_exprs = [
|
||||
{
|
||||
"id": expr.id,
|
||||
"situation": expr.situation,
|
||||
"style": expr.style,
|
||||
"last_active_time": expr.last_active_time,
|
||||
"source_id": expr.chat_id,
|
||||
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
|
||||
"count": expr.count if getattr(expr, "count", None) is not None else 1,
|
||||
"checked": expr.checked if getattr(expr, "checked", None) is not None else False,
|
||||
}
|
||||
for expr in style_query
|
||||
]
|
||||
|
||||
# 要求至少有一定数量的 count > 1 的表达方式才进行“完整简单模式”选择
|
||||
min_required = 8
|
||||
if len(style_exprs) < min_required:
|
||||
# 高 count 样本不足:如果还有候选,就降级为随机选 3 个;如果一个都没有,则直接返回空
|
||||
if not style_exprs:
|
||||
logger.info(f"聊天流 {chat_id} 没有满足 count > 1 且未被拒绝的表达方式,简单模式不进行选择")
|
||||
# 完全没有高 count 样本时,退化为全量随机抽样(不进入LLM流程)
|
||||
fallback_num = min(3, max_num) if max_num > 0 else 3
|
||||
if fallback_selected := self._random_expressions(chat_id, fallback_num):
|
||||
self.update_expressions_last_active_time(fallback_selected)
|
||||
selected_ids = [expr["id"] for expr in fallback_selected]
|
||||
logger.info(
|
||||
f"聊天流 {chat_id} 使用简单模式降级随机抽选 {len(fallback_selected)} 个表达(无 count>1 样本)"
|
||||
)
|
||||
return fallback_selected, selected_ids
|
||||
return [], []
|
||||
logger.info(
|
||||
f"聊天流 {chat_id} count > 1 的表达方式不足 {min_required} 个(实际 {len(style_exprs)} 个),"
|
||||
f"简单模式降级为随机选择 3 个"
|
||||
)
|
||||
select_count = min(3, len(style_exprs))
|
||||
else:
|
||||
# 高 count 数量达标时,固定选择 5 个
|
||||
select_count = 5
|
||||
import random
|
||||
|
||||
selected_style = random.sample(style_exprs, select_count)
|
||||
|
||||
# 更新last_active_time
|
||||
if selected_style:
|
||||
self.update_expressions_last_active_time(selected_style)
|
||||
|
||||
selected_ids = [expr["id"] for expr in selected_style]
|
||||
logger.debug(
|
||||
f"think_level=0: 从 {len(style_exprs)} 个 count>1 的表达方式中随机选择了 {len(selected_style)} 个"
|
||||
)
|
||||
return selected_style, selected_ids
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"简单模式选择表达方式失败: {e}")
|
||||
return [], []
|
||||
|
||||
def _random_expressions(self, chat_id: str, total_num: int) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
随机选择表达方式
|
||||
|
||||
Args:
|
||||
chat_id: 聊天室ID
|
||||
total_num: 需要选择的数量
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 随机选择的表达方式列表
|
||||
"""
|
||||
try:
|
||||
# 支持多chat_id合并抽选
|
||||
related_chat_ids = self.get_related_chat_ids(chat_id)
|
||||
|
||||
# 优化:一次性查询所有相关chat_id的表达方式,排除 rejected=1 的表达
|
||||
# 如果 expression_checked_only 为 True,则只选择 checked=True 且 rejected=False 的
|
||||
base_conditions = (Expression.chat_id.in_(related_chat_ids)) & (~Expression.rejected)
|
||||
if global_config.expression.expression_checked_only:
|
||||
base_conditions = base_conditions & (Expression.checked)
|
||||
style_query = Expression.select().where(base_conditions)
|
||||
|
||||
style_exprs = [
|
||||
{
|
||||
"id": expr.id,
|
||||
"situation": expr.situation,
|
||||
"style": expr.style,
|
||||
"last_active_time": expr.last_active_time,
|
||||
"source_id": expr.chat_id,
|
||||
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
|
||||
"count": expr.count if getattr(expr, "count", None) is not None else 1,
|
||||
"checked": expr.checked if getattr(expr, "checked", None) is not None else False,
|
||||
}
|
||||
for expr in style_query
|
||||
]
|
||||
|
||||
# 随机抽样
|
||||
return weighted_sample(style_exprs, total_num) if style_exprs else []
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"随机选择表达方式失败: {e}")
|
||||
return []
|
||||
|
||||
async def select_suitable_expressions(
|
||||
self,
|
||||
chat_id: str,
|
||||
chat_info: str,
|
||||
max_num: int = 10,
|
||||
target_message: Optional[str] = None,
|
||||
reply_reason: Optional[str] = None,
|
||||
think_level: int = 1,
|
||||
) -> Tuple[List[Dict[str, Any]], List[int]]:
|
||||
"""选择适合的表达方式。
|
||||
|
||||
Args:
|
||||
chat_id: 聊天流ID
|
||||
chat_info: 聊天内容信息
|
||||
max_num: 最大选择数量
|
||||
target_message: 目标消息内容
|
||||
reply_reason: planner给出的回复理由
|
||||
think_level: 思考级别,0/1
|
||||
|
||||
Returns:
|
||||
Tuple[List[Dict[str, Any]], List[int]]: 选中的表达方式列表和ID列表
|
||||
"""
|
||||
# 检查是否允许在此聊天流中使用表达
|
||||
if not self.can_use_expression_for_chat(chat_id):
|
||||
logger.debug(f"聊天流 {chat_id} 不允许使用表达,返回空列表")
|
||||
return [], []
|
||||
|
||||
before_select_result = await self._get_runtime_manager().invoke_hook(
|
||||
"expression.select.before_select",
|
||||
chat_id=chat_id,
|
||||
chat_info=chat_info,
|
||||
max_num=max_num,
|
||||
target_message=target_message or "",
|
||||
reply_reason=reply_reason or "",
|
||||
think_level=think_level,
|
||||
)
|
||||
if before_select_result.aborted:
|
||||
logger.info(f"聊天流 {chat_id} 的表达方式选择被 Hook 中止")
|
||||
return [], []
|
||||
|
||||
before_select_kwargs = before_select_result.kwargs
|
||||
chat_id = str(before_select_kwargs.get("chat_id", chat_id) or "").strip() or chat_id
|
||||
chat_info = str(before_select_kwargs.get("chat_info", chat_info) or "")
|
||||
max_num = max(self._coerce_int(before_select_kwargs.get("max_num"), max_num), 1)
|
||||
raw_target_message = before_select_kwargs.get("target_message", target_message or "")
|
||||
target_message = str(raw_target_message or "").strip() or None
|
||||
raw_reply_reason = before_select_kwargs.get("reply_reason", reply_reason or "")
|
||||
reply_reason = str(raw_reply_reason or "").strip() or None
|
||||
think_level = self._coerce_int(before_select_kwargs.get("think_level"), think_level)
|
||||
|
||||
# 使用classic模式(随机选择+LLM选择)
|
||||
logger.debug(f"使用classic模式为聊天流 {chat_id} 选择表达方式,think_level={think_level}")
|
||||
selected_expressions, selected_ids = await self._select_expressions_classic(
|
||||
chat_id, chat_info, max_num, target_message, reply_reason, think_level
|
||||
)
|
||||
after_selection_result = await self._get_runtime_manager().invoke_hook(
|
||||
"expression.select.after_selection",
|
||||
chat_id=chat_id,
|
||||
chat_info=chat_info,
|
||||
max_num=max_num,
|
||||
target_message=target_message or "",
|
||||
reply_reason=reply_reason or "",
|
||||
think_level=think_level,
|
||||
selected_expressions=[dict(item) for item in selected_expressions],
|
||||
selected_expression_ids=list(selected_ids),
|
||||
)
|
||||
if after_selection_result.aborted:
|
||||
logger.info(f"聊天流 {chat_id} 的表达方式选择结果被 Hook 中止")
|
||||
return [], []
|
||||
|
||||
after_selection_kwargs = after_selection_result.kwargs
|
||||
raw_selected_expressions = after_selection_kwargs.get("selected_expressions")
|
||||
if raw_selected_expressions is not None:
|
||||
selected_expressions = self._normalize_selected_expressions(raw_selected_expressions)
|
||||
selected_ids = self._normalize_selected_expression_ids(
|
||||
after_selection_kwargs.get("selected_expression_ids"),
|
||||
selected_expressions,
|
||||
)
|
||||
if selected_expressions:
|
||||
self.update_expressions_last_active_time(selected_expressions)
|
||||
return selected_expressions, selected_ids
|
||||
|
||||
async def _select_expressions_classic(
|
||||
self,
|
||||
chat_id: str,
|
||||
chat_info: str,
|
||||
max_num: int = 10,
|
||||
target_message: Optional[str] = None,
|
||||
reply_reason: Optional[str] = None,
|
||||
think_level: int = 1,
|
||||
) -> Tuple[List[Dict[str, Any]], List[int]]:
|
||||
"""
|
||||
classic模式:随机选择+LLM选择
|
||||
|
||||
Args:
|
||||
chat_id: 聊天流ID
|
||||
chat_info: 聊天内容信息
|
||||
max_num: 最大选择数量
|
||||
target_message: 目标消息内容
|
||||
reply_reason: planner给出的回复理由
|
||||
think_level: 思考级别,0/1
|
||||
|
||||
Returns:
|
||||
Tuple[List[Dict[str, Any]], List[int]]: 选中的表达方式列表和ID列表
|
||||
"""
|
||||
try:
|
||||
# think_level == 0: 只选择 count > 1 的项目,随机选10个,不进行LLM选择
|
||||
if think_level == 0:
|
||||
return self._select_expressions_simple(chat_id, max_num)
|
||||
|
||||
# think_level == 1: 先选高count,再从所有表达方式中随机抽样
|
||||
# 1. 获取所有表达方式并分离 count > 1 和 count <= 1 的
|
||||
related_chat_ids = self.get_related_chat_ids(chat_id)
|
||||
# 如果 expression_checked_only 为 True,则只选择 checked=True 且 rejected=False 的
|
||||
base_conditions = (Expression.chat_id.in_(related_chat_ids)) & (~Expression.rejected)
|
||||
if global_config.expression.expression_checked_only:
|
||||
base_conditions = base_conditions & (Expression.checked)
|
||||
style_query = Expression.select().where(base_conditions)
|
||||
|
||||
all_style_exprs = [
|
||||
{
|
||||
"id": expr.id,
|
||||
"situation": expr.situation,
|
||||
"style": expr.style,
|
||||
"last_active_time": expr.last_active_time,
|
||||
"source_id": expr.chat_id,
|
||||
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
|
||||
"count": expr.count if getattr(expr, "count", None) is not None else 1,
|
||||
"checked": expr.checked if getattr(expr, "checked", None) is not None else False,
|
||||
}
|
||||
for expr in style_query
|
||||
]
|
||||
|
||||
# 分离 count > 1 和 count <= 1 的表达方式
|
||||
high_count_exprs = [expr for expr in all_style_exprs if (expr.get("count", 1) or 1) > 1]
|
||||
|
||||
# 根据 think_level 设置要求(仅支持 0/1,0 已在上方返回)
|
||||
min_high_count = 10
|
||||
min_total_count = 10
|
||||
select_high_count = 5
|
||||
select_random_count = 5
|
||||
|
||||
# 检查数量要求
|
||||
# 对于高 count 表达:如果数量不足,不再直接停止,而是仅跳过“高 count 优先选择”
|
||||
if len(high_count_exprs) < min_high_count:
|
||||
logger.info(
|
||||
f"聊天流 {chat_id} count > 1 的表达方式不足 {min_high_count} 个(实际 {len(high_count_exprs)} 个),"
|
||||
f"将跳过高 count 优先选择,仅从全部表达中随机抽样"
|
||||
)
|
||||
high_count_valid = False
|
||||
else:
|
||||
high_count_valid = True
|
||||
|
||||
# 总量不足仍然直接返回,避免样本过少导致选择质量过低
|
||||
if len(all_style_exprs) < min_total_count:
|
||||
logger.info(
|
||||
f"聊天流 {chat_id} 总表达方式不足 {min_total_count} 个(实际 {len(all_style_exprs)} 个),不进行选择"
|
||||
)
|
||||
return [], []
|
||||
|
||||
# 先选取高count的表达方式(如果数量达标)
|
||||
if high_count_valid:
|
||||
selected_high = weighted_sample(high_count_exprs, min(len(high_count_exprs), select_high_count))
|
||||
else:
|
||||
selected_high = []
|
||||
|
||||
# 然后从所有表达方式中随机抽样(使用加权抽样)
|
||||
remaining_num = select_random_count
|
||||
selected_random = weighted_sample(all_style_exprs, min(len(all_style_exprs), remaining_num))
|
||||
|
||||
# 合并候选池(去重,避免重复)
|
||||
candidate_exprs = selected_high.copy()
|
||||
candidate_ids = {expr["id"] for expr in candidate_exprs}
|
||||
for expr in selected_random:
|
||||
if expr["id"] not in candidate_ids:
|
||||
candidate_exprs.append(expr)
|
||||
candidate_ids.add(expr["id"])
|
||||
|
||||
# 打乱顺序,避免高count的都在前面
|
||||
import random
|
||||
|
||||
random.shuffle(candidate_exprs)
|
||||
|
||||
# 2. 构建所有表达方式的索引和情境列表
|
||||
all_expressions: List[Dict[str, Any]] = []
|
||||
all_situations: List[str] = []
|
||||
|
||||
# 添加style表达方式
|
||||
for expr in candidate_exprs:
|
||||
expr = expr.copy()
|
||||
all_expressions.append(expr)
|
||||
all_situations.append(f"{len(all_expressions)}.当 {expr['situation']} 时,使用 {expr['style']}")
|
||||
|
||||
if not all_expressions:
|
||||
logger.warning("没有找到可用的表达方式")
|
||||
return [], []
|
||||
|
||||
all_situations_str = "\n".join(all_situations)
|
||||
|
||||
if target_message:
|
||||
target_message_str = f',现在你想要对这条消息进行回复:"{target_message}"'
|
||||
target_message_extra_block = "4.考虑你要回复的目标消息"
|
||||
else:
|
||||
target_message_str = ""
|
||||
target_message_extra_block = ""
|
||||
|
||||
chat_context = f"以下是正在进行的聊天内容:{chat_info}"
|
||||
|
||||
# 构建reply_reason块
|
||||
if reply_reason:
|
||||
reply_reason_block = f"你的回复理由是:{reply_reason}"
|
||||
chat_context = ""
|
||||
else:
|
||||
reply_reason_block = ""
|
||||
|
||||
# 3. 构建prompt(只包含情境,不包含完整的表达方式)
|
||||
prompt_template = prompt_manager.get_prompt("expression_select")
|
||||
prompt_template.add_context("bot_name", global_config.bot.nickname)
|
||||
prompt_template.add_context("chat_observe_info", chat_context)
|
||||
prompt_template.add_context("all_situations", all_situations_str)
|
||||
prompt_template.add_context("max_num", str(max_num))
|
||||
prompt_template.add_context("target_message", target_message_str)
|
||||
prompt_template.add_context("target_message_extra_block", target_message_extra_block)
|
||||
prompt_template.add_context("reply_reason_block", reply_reason_block)
|
||||
prompt = await prompt_manager.render_prompt(prompt_template)
|
||||
|
||||
# 4. 调用LLM
|
||||
generation_result = await self.llm_model.generate_response(prompt=prompt)
|
||||
content = generation_result.response
|
||||
# print(prompt)
|
||||
# print(content)
|
||||
|
||||
if not content:
|
||||
logger.warning("LLM返回空结果")
|
||||
return [], []
|
||||
|
||||
# 5. 解析结果
|
||||
result = repair_json(content)
|
||||
if isinstance(result, str):
|
||||
result = json.loads(result)
|
||||
|
||||
if not isinstance(result, dict) or "selected_situations" not in result:
|
||||
logger.error("LLM返回格式错误")
|
||||
logger.info(f"LLM返回结果: \n{content}")
|
||||
return [], []
|
||||
|
||||
selected_indices = result["selected_situations"]
|
||||
|
||||
# 根据索引获取完整的表达方式
|
||||
valid_expressions: List[Dict[str, Any]] = []
|
||||
selected_ids = []
|
||||
for idx in selected_indices:
|
||||
if isinstance(idx, int) and 1 <= idx <= len(all_expressions):
|
||||
expression = all_expressions[idx - 1] # 索引从1开始
|
||||
selected_ids.append(expression["id"])
|
||||
valid_expressions.append(expression)
|
||||
|
||||
# 对选中的所有表达方式,更新last_active_time
|
||||
if valid_expressions:
|
||||
self.update_expressions_last_active_time(valid_expressions)
|
||||
|
||||
logger.debug(f"从{len(all_expressions)}个情境中选择了{len(valid_expressions)}个")
|
||||
return valid_expressions, selected_ids
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"classic模式处理表达方式选择时出错: {e}")
|
||||
return [], []
|
||||
|
||||
def update_expressions_last_active_time(self, expressions_to_update: List[Dict[str, Any]]):
|
||||
"""对一批表达方式更新last_active_time"""
|
||||
if not expressions_to_update:
|
||||
return
|
||||
updates_by_key = {}
|
||||
for expr in expressions_to_update:
|
||||
source_id: str = expr.get("source_id") # type: ignore
|
||||
situation: str = expr.get("situation") # type: ignore
|
||||
style: str = expr.get("style") # type: ignore
|
||||
if not source_id or not situation or not style:
|
||||
logger.warning(f"表达方式缺少必要字段,无法更新: {expr}")
|
||||
continue
|
||||
key = (source_id, situation, style)
|
||||
if key not in updates_by_key:
|
||||
updates_by_key[key] = expr
|
||||
for chat_id, situation, style in updates_by_key:
|
||||
query = Expression.select().where(
|
||||
(Expression.chat_id == chat_id) & (Expression.situation == situation) & (Expression.style == style)
|
||||
)
|
||||
if query.exists():
|
||||
expr_obj = query.get()
|
||||
expr_obj.last_active_time = time.time()
|
||||
expr_obj.save()
|
||||
logger.debug("表达方式激活: 更新last_active_time in db")
|
||||
|
||||
|
||||
try:
|
||||
expression_selector = ExpressionSelector()
|
||||
except Exception as e:
|
||||
logger.error(f"ExpressionSelector初始化失败: {e}")
|
||||
@@ -1,134 +0,0 @@
|
||||
from json_repair import repair_json
|
||||
from typing import List, Tuple
|
||||
|
||||
import re
|
||||
import json
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("learner_utils")
|
||||
|
||||
|
||||
def fix_chinese_quotes_in_json(text):
|
||||
"""使用状态机修复 JSON 字符串值中的中文引号"""
|
||||
result = []
|
||||
i = 0
|
||||
in_string = False
|
||||
escape_next = False
|
||||
|
||||
while i < len(text):
|
||||
char = text[i]
|
||||
if escape_next:
|
||||
# 当前字符是转义字符后的字符,直接添加
|
||||
result.append(char)
|
||||
escape_next = False
|
||||
i += 1
|
||||
continue
|
||||
if char == "\\":
|
||||
# 转义字符
|
||||
result.append(char)
|
||||
escape_next = True
|
||||
i += 1
|
||||
continue
|
||||
if char == '"' and not escape_next:
|
||||
# 遇到英文引号,切换字符串状态
|
||||
in_string = not in_string
|
||||
result.append(char)
|
||||
i += 1
|
||||
continue
|
||||
if in_string and char in ["“", "”"]:
|
||||
result.append('\\"')
|
||||
else:
|
||||
result.append(char)
|
||||
i += 1
|
||||
|
||||
return "".join(result)
|
||||
|
||||
|
||||
def parse_expression_response(response: str) -> Tuple[List[Tuple[str, str, str]], List[Tuple[str, str]]]:
|
||||
"""
|
||||
解析 LLM 返回的表达风格总结和黑话 JSON,提取两个列表。
|
||||
|
||||
期望的 JSON 结构:
|
||||
[
|
||||
{"situation": "AAAAA", "style": "BBBBB", "source_id": "3"}, // 表达方式
|
||||
{"content": "词条", "source_id": "12"}, // 黑话
|
||||
...
|
||||
]
|
||||
|
||||
Returns:
|
||||
Tuple[List[Tuple[str, str, str]], List[Tuple[str, str]]]:
|
||||
第一个列表是表达方式 (situation, style, source_id)
|
||||
第二个列表是黑话 (content, source_id)
|
||||
"""
|
||||
if not response:
|
||||
return [], []
|
||||
|
||||
raw = response.strip()
|
||||
|
||||
if match := re.search(r"```json\s*(.*?)\s*```", raw, re.DOTALL):
|
||||
raw = match[1].strip()
|
||||
else:
|
||||
# 去掉可能存在的通用 ``` 包裹
|
||||
raw = re.sub(r"^```\s*", "", raw, flags=re.MULTILINE)
|
||||
raw = re.sub(r"```\s*$", "", raw, flags=re.MULTILINE)
|
||||
raw = raw.strip()
|
||||
|
||||
parsed = None
|
||||
expressions: List[Tuple[str, str, str]] = [] # (situation, style, source_id)
|
||||
jargon_entries: List[Tuple[str, str]] = [] # (content, source_id)
|
||||
|
||||
try:
|
||||
# 优先尝试直接解析
|
||||
if raw.startswith("[") and raw.endswith("]"):
|
||||
parsed = json.loads(raw)
|
||||
else:
|
||||
repaired = repair_json(raw)
|
||||
parsed = json.loads(repaired) if isinstance(repaired, str) else repaired
|
||||
except Exception as parse_error:
|
||||
# 如果解析失败,尝试修复中文引号问题
|
||||
# 使用状态机方法,在 JSON 字符串值内部将中文引号替换为转义的英文引号
|
||||
try:
|
||||
fixed_raw = fix_chinese_quotes_in_json(raw)
|
||||
|
||||
# 再次尝试解析
|
||||
if fixed_raw.startswith("[") and fixed_raw.endswith("]"):
|
||||
parsed = json.loads(fixed_raw)
|
||||
else:
|
||||
repaired = repair_json(fixed_raw)
|
||||
parsed = json.loads(repaired) if isinstance(repaired, str) else repaired
|
||||
except Exception as fix_error:
|
||||
logger.error(f"解析表达风格 JSON 失败,初始错误: {type(parse_error).__name__}: {str(parse_error)}")
|
||||
logger.error(f"修复中文引号后仍失败,错误: {type(fix_error).__name__}: {str(fix_error)}")
|
||||
logger.error(f"解析表达风格 JSON 失败,原始响应:{response}")
|
||||
logger.error(f"处理后的 JSON 字符串(前500字符):{raw[:500]}")
|
||||
return [], []
|
||||
|
||||
if isinstance(parsed, dict):
|
||||
parsed_list = [parsed]
|
||||
elif isinstance(parsed, list):
|
||||
parsed_list = parsed
|
||||
else:
|
||||
logger.error(f"表达风格解析结果类型异常: {type(parsed)}, 内容: {parsed}")
|
||||
return [], []
|
||||
|
||||
for item in parsed_list:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
|
||||
# 检查是否是表达方式条目(有 situation 和 style)
|
||||
situation = str(item.get("situation", "")).strip()
|
||||
style = str(item.get("style", "")).strip()
|
||||
source_id = str(item.get("source_id", "")).strip()
|
||||
|
||||
if situation and style and source_id:
|
||||
# 表达方式条目
|
||||
expressions.append((situation, style, source_id))
|
||||
elif item.get("content"):
|
||||
# 黑话条目(有 content 字段)
|
||||
content = str(item.get("content", "")).strip()
|
||||
source_id = str(item.get("source_id", "")).strip()
|
||||
if content and source_id:
|
||||
jargon_entries.append((content, source_id))
|
||||
|
||||
return expressions, jargon_entries
|
||||
@@ -205,7 +205,7 @@ async def handle_tool(
|
||||
else:
|
||||
for sent_message in sent_messages:
|
||||
tool_ctx.append_sent_message_to_chat_history(sent_message)
|
||||
tool_ctx.runtime._clear_force_continue_until_reply()
|
||||
tool_ctx.runtime._record_reply_sent()
|
||||
return tool_ctx.build_success_result(
|
||||
invocation.tool_name,
|
||||
"回复已生成并发送。",
|
||||
|
||||
@@ -163,6 +163,103 @@ def _serialize_tool_results(tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]
|
||||
return serialized_tools
|
||||
|
||||
|
||||
def _serialize_request_block(
|
||||
messages: Optional[List[Any]],
|
||||
selected_history_count: Optional[int],
|
||||
tool_count: Optional[int],
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""标准化请求区块。"""
|
||||
|
||||
if messages is None and selected_history_count is None and tool_count is None:
|
||||
return None
|
||||
|
||||
return {
|
||||
"messages": _serialize_messages(list(messages or [])),
|
||||
"selected_history_count": int(selected_history_count or 0),
|
||||
"tool_count": int(tool_count or 0),
|
||||
}
|
||||
|
||||
|
||||
def _serialize_planner_block(
|
||||
content: Optional[str],
|
||||
tool_calls: Optional[List[Any]],
|
||||
prompt_tokens: Optional[int],
|
||||
completion_tokens: Optional[int],
|
||||
total_tokens: Optional[int],
|
||||
duration_ms: Optional[float],
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""标准化 planner 结果区块。"""
|
||||
|
||||
if (
|
||||
content is None
|
||||
and tool_calls is None
|
||||
and prompt_tokens is None
|
||||
and completion_tokens is None
|
||||
and total_tokens is None
|
||||
and duration_ms is None
|
||||
):
|
||||
return None
|
||||
|
||||
return {
|
||||
"content": content,
|
||||
"tool_calls": _serialize_tool_calls_from_objects(list(tool_calls or [])),
|
||||
"prompt_tokens": int(prompt_tokens or 0),
|
||||
"completion_tokens": int(completion_tokens or 0),
|
||||
"total_tokens": int(total_tokens or 0),
|
||||
"duration_ms": float(duration_ms or 0.0),
|
||||
}
|
||||
|
||||
|
||||
def _serialize_timing_gate_block(
|
||||
*,
|
||||
request_messages: Optional[List[Any]],
|
||||
selected_history_count: Optional[int],
|
||||
tool_count: Optional[int],
|
||||
action: Optional[str],
|
||||
content: Optional[str],
|
||||
tool_calls: Optional[List[Any]],
|
||||
tool_results: Optional[List[str]],
|
||||
prompt_tokens: Optional[int],
|
||||
completion_tokens: Optional[int],
|
||||
total_tokens: Optional[int],
|
||||
duration_ms: Optional[float],
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""标准化 Timing Gate 结果区块。"""
|
||||
|
||||
if (
|
||||
request_messages is None
|
||||
and selected_history_count is None
|
||||
and tool_count is None
|
||||
and action is None
|
||||
and content is None
|
||||
and tool_calls is None
|
||||
and tool_results is None
|
||||
and prompt_tokens is None
|
||||
and completion_tokens is None
|
||||
and total_tokens is None
|
||||
and duration_ms is None
|
||||
):
|
||||
return None
|
||||
|
||||
return {
|
||||
"request": _serialize_request_block(
|
||||
request_messages,
|
||||
selected_history_count,
|
||||
tool_count,
|
||||
),
|
||||
"result": {
|
||||
"action": action,
|
||||
"content": content,
|
||||
"tool_calls": _serialize_tool_calls_from_objects(list(tool_calls or [])),
|
||||
"tool_results": _normalize_payload_value(list(tool_results or [])),
|
||||
"prompt_tokens": int(prompt_tokens or 0),
|
||||
"completion_tokens": int(completion_tokens or 0),
|
||||
"total_tokens": int(total_tokens or 0),
|
||||
"duration_ms": float(duration_ms or 0.0),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
async def _broadcast(event: str, data: Dict[str, Any]) -> None:
|
||||
"""通过统一 WebSocket 管理器向监控主题广播事件。"""
|
||||
|
||||
@@ -268,16 +365,27 @@ async def emit_planner_finalized(
|
||||
*,
|
||||
session_id: str,
|
||||
cycle_id: int,
|
||||
request_messages: List[Any],
|
||||
selected_history_count: int,
|
||||
tool_count: int,
|
||||
timing_request_messages: Optional[List[Any]],
|
||||
timing_selected_history_count: Optional[int],
|
||||
timing_tool_count: Optional[int],
|
||||
timing_action: Optional[str],
|
||||
timing_content: Optional[str],
|
||||
timing_tool_calls: Optional[List[Any]],
|
||||
timing_tool_results: Optional[List[str]],
|
||||
timing_prompt_tokens: Optional[int],
|
||||
timing_completion_tokens: Optional[int],
|
||||
timing_total_tokens: Optional[int],
|
||||
timing_duration_ms: Optional[float],
|
||||
planner_request_messages: Optional[List[Any]],
|
||||
planner_selected_history_count: Optional[int],
|
||||
planner_tool_count: Optional[int],
|
||||
planner_content: Optional[str],
|
||||
planner_tool_calls: List[Any],
|
||||
prompt_tokens: int,
|
||||
completion_tokens: int,
|
||||
total_tokens: int,
|
||||
duration_ms: float,
|
||||
tools: List[Dict[str, Any]],
|
||||
planner_tool_calls: Optional[List[Any]],
|
||||
planner_prompt_tokens: Optional[int],
|
||||
planner_completion_tokens: Optional[int],
|
||||
planner_total_tokens: Optional[int],
|
||||
planner_duration_ms: Optional[float],
|
||||
tools: Optional[List[Dict[str, Any]]],
|
||||
time_records: Dict[str, float],
|
||||
agent_state: str,
|
||||
) -> None:
|
||||
@@ -287,20 +395,33 @@ async def emit_planner_finalized(
|
||||
"session_id": session_id,
|
||||
"cycle_id": cycle_id,
|
||||
"timestamp": time.time(),
|
||||
"request": {
|
||||
"messages": _serialize_messages(request_messages),
|
||||
"selected_history_count": selected_history_count,
|
||||
"tool_count": tool_count,
|
||||
},
|
||||
"planner": {
|
||||
"content": planner_content,
|
||||
"tool_calls": _serialize_tool_calls_from_objects(planner_tool_calls),
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"completion_tokens": completion_tokens,
|
||||
"total_tokens": total_tokens,
|
||||
"duration_ms": duration_ms,
|
||||
},
|
||||
"tools": _serialize_tool_results(tools),
|
||||
"timing_gate": _serialize_timing_gate_block(
|
||||
request_messages=timing_request_messages,
|
||||
selected_history_count=timing_selected_history_count,
|
||||
tool_count=timing_tool_count,
|
||||
action=timing_action,
|
||||
content=timing_content,
|
||||
tool_calls=timing_tool_calls,
|
||||
tool_results=timing_tool_results,
|
||||
prompt_tokens=timing_prompt_tokens,
|
||||
completion_tokens=timing_completion_tokens,
|
||||
total_tokens=timing_total_tokens,
|
||||
duration_ms=timing_duration_ms,
|
||||
),
|
||||
"request": _serialize_request_block(
|
||||
planner_request_messages,
|
||||
planner_selected_history_count,
|
||||
planner_tool_count,
|
||||
),
|
||||
"planner": _serialize_planner_block(
|
||||
planner_content,
|
||||
planner_tool_calls,
|
||||
planner_prompt_tokens,
|
||||
planner_completion_tokens,
|
||||
planner_total_tokens,
|
||||
planner_duration_ms,
|
||||
),
|
||||
"tools": _serialize_tool_results(list(tools or [])),
|
||||
"final_state": {
|
||||
"time_records": _normalize_payload_value(time_records),
|
||||
"agent_state": agent_state,
|
||||
|
||||
@@ -347,6 +347,10 @@ class MaisakaReasoningEngine:
|
||||
)
|
||||
planner_started_at = 0.0
|
||||
planner_duration_ms = 0.0
|
||||
timing_duration_ms = 0.0
|
||||
timing_action: Optional[str] = None
|
||||
timing_response: Optional[ChatResponse] = None
|
||||
timing_tool_results: Optional[list[str]] = None
|
||||
response: Optional[ChatResponse] = None
|
||||
tool_monitor_results: list[dict[str, Any]] = []
|
||||
try:
|
||||
@@ -458,23 +462,43 @@ class MaisakaReasoningEngine:
|
||||
break
|
||||
finally:
|
||||
completed_cycle = self._end_cycle(cycle_detail)
|
||||
if response is not None:
|
||||
await emit_planner_finalized(
|
||||
session_id=self._runtime.session_id,
|
||||
cycle_id=cycle_detail.cycle_id,
|
||||
request_messages=response.request_messages,
|
||||
selected_history_count=response.selected_history_count,
|
||||
tool_count=response.tool_count,
|
||||
planner_content=response.content,
|
||||
planner_tool_calls=response.tool_calls,
|
||||
prompt_tokens=response.prompt_tokens,
|
||||
completion_tokens=response.completion_tokens,
|
||||
total_tokens=response.total_tokens,
|
||||
duration_ms=planner_duration_ms,
|
||||
tools=tool_monitor_results,
|
||||
time_records=dict(completed_cycle.time_records),
|
||||
agent_state=self._runtime._agent_state,
|
||||
)
|
||||
await emit_planner_finalized(
|
||||
session_id=self._runtime.session_id,
|
||||
cycle_id=cycle_detail.cycle_id,
|
||||
timing_request_messages=(
|
||||
timing_response.request_messages if timing_response is not None else None
|
||||
),
|
||||
timing_selected_history_count=(
|
||||
timing_response.selected_history_count if timing_response is not None else None
|
||||
),
|
||||
timing_tool_count=timing_response.tool_count if timing_response is not None else None,
|
||||
timing_action=timing_action,
|
||||
timing_content=timing_response.content if timing_response is not None else None,
|
||||
timing_tool_calls=timing_response.tool_calls if timing_response is not None else None,
|
||||
timing_tool_results=timing_tool_results,
|
||||
timing_prompt_tokens=timing_response.prompt_tokens if timing_response is not None else None,
|
||||
timing_completion_tokens=(
|
||||
timing_response.completion_tokens if timing_response is not None else None
|
||||
),
|
||||
timing_total_tokens=timing_response.total_tokens if timing_response is not None else None,
|
||||
timing_duration_ms=timing_duration_ms if timing_response is not None else None,
|
||||
planner_request_messages=response.request_messages if response is not None else None,
|
||||
planner_selected_history_count=(
|
||||
response.selected_history_count if response is not None else None
|
||||
),
|
||||
planner_tool_count=response.tool_count if response is not None else None,
|
||||
planner_content=response.content if response is not None else None,
|
||||
planner_tool_calls=response.tool_calls if response is not None else None,
|
||||
planner_prompt_tokens=response.prompt_tokens if response is not None else None,
|
||||
planner_completion_tokens=(
|
||||
response.completion_tokens if response is not None else None
|
||||
),
|
||||
planner_total_tokens=response.total_tokens if response is not None else None,
|
||||
planner_duration_ms=planner_duration_ms if response is not None else None,
|
||||
tools=tool_monitor_results,
|
||||
time_records=dict(completed_cycle.time_records),
|
||||
agent_state=self._runtime._agent_state,
|
||||
)
|
||||
finally:
|
||||
if self._runtime._agent_state == self._runtime._STATE_RUNNING:
|
||||
self._runtime._agent_state = self._runtime._STATE_STOP
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
"""Maisaka 非 CLI 运行时。"""
|
||||
|
||||
from collections import deque
|
||||
from math import ceil
|
||||
from typing import Any, Literal, Optional, Sequence
|
||||
|
||||
import asyncio
|
||||
@@ -17,7 +19,7 @@ from src.chat.message_receive.message import SessionMessage
|
||||
from src.chat.utils.utils import is_mentioned_bot_in_message
|
||||
from src.common.data_models.mai_message_data_model import GroupInfo, UserInfo
|
||||
from src.common.logger import get_logger
|
||||
from src.common.utils.utils_config import ExpressionConfigUtils
|
||||
from src.common.utils.utils_config import ChatConfigUtils, ExpressionConfigUtils
|
||||
from src.config.config import global_config
|
||||
from src.core.tooling import ToolRegistry
|
||||
from src.learners.expression_learner import ExpressionLearner
|
||||
@@ -77,9 +79,14 @@ class MaisakaHeartFlowChatting:
|
||||
self._cycle_counter = 0
|
||||
self._internal_loop_task: Optional[asyncio.Task] = None
|
||||
self._message_turn_scheduled = False
|
||||
self._deferred_message_turn_task: Optional[asyncio.Task[None]] = None
|
||||
self._message_debounce_seconds = 1.0
|
||||
self._message_debounce_required = False
|
||||
self._message_received_at_by_id: dict[str, float] = {}
|
||||
self._last_message_received_at = 0.0
|
||||
self._talk_frequency_adjust = 1.0
|
||||
self._reply_latency_measurement_started_at: Optional[float] = None
|
||||
self._recent_reply_latencies: deque[tuple[float, float]] = deque()
|
||||
self._wait_timeout_task: Optional[asyncio.Task[None]] = None
|
||||
self._max_internal_rounds = MAX_INTERNAL_ROUNDS
|
||||
self._max_context_size = max(1, int(global_config.chat.max_context_size))
|
||||
@@ -132,6 +139,7 @@ class MaisakaHeartFlowChatting:
|
||||
self._running = False
|
||||
self._message_turn_scheduled = False
|
||||
self._message_debounce_required = False
|
||||
self._cancel_deferred_message_turn_task()
|
||||
self._cancel_wait_timeout_task()
|
||||
while not self._internal_turn_queue.empty():
|
||||
_ = self._internal_turn_queue.get_nowait()
|
||||
@@ -152,16 +160,19 @@ class MaisakaHeartFlowChatting:
|
||||
logger.info(f"{self.log_prefix} Maisaka 运行时已停止")
|
||||
|
||||
def adjust_talk_frequency(self, frequency: float) -> None:
|
||||
"""兼容现有管理器接口的占位方法。"""
|
||||
_ = frequency
|
||||
"""调整当前会话的回复频率倍率。"""
|
||||
self._talk_frequency_adjust = max(0.01, float(frequency))
|
||||
self._schedule_message_turn()
|
||||
|
||||
async def register_message(self, message: SessionMessage) -> None:
|
||||
"""缓存一条新消息并唤醒主循环。"""
|
||||
if self._running:
|
||||
self._ensure_background_tasks_running()
|
||||
self._last_message_received_at = time.time()
|
||||
received_at = time.time()
|
||||
self._last_message_received_at = received_at
|
||||
self._update_message_trigger_state(message)
|
||||
self.message_cache.append(message)
|
||||
self._message_received_at_by_id[message.message_id] = received_at
|
||||
self._source_messages_by_id[message.message_id] = message
|
||||
if self._agent_state == self._STATE_WAIT:
|
||||
self._cancel_wait_timeout_task()
|
||||
@@ -197,6 +208,93 @@ class MaisakaHeartFlowChatting:
|
||||
if self._running:
|
||||
self._schedule_message_turn()
|
||||
|
||||
def _get_effective_reply_frequency(self) -> float:
|
||||
"""返回当前会话生效的回复频率。"""
|
||||
talk_value = max(0.01, float(ChatConfigUtils.get_talk_value(self.session_id)))
|
||||
return max(0.01, talk_value * self._talk_frequency_adjust)
|
||||
|
||||
def _get_message_trigger_threshold(self) -> int:
|
||||
"""根据回复频率折算出触发一轮循环所需的消息数。"""
|
||||
effective_frequency = min(1.0, self._get_effective_reply_frequency())
|
||||
return max(1, int(ceil(1.0 / effective_frequency)))
|
||||
|
||||
def _get_pending_message_count(self) -> int:
|
||||
"""统计当前尚未进入内部循环的新消息数量。"""
|
||||
pending_messages = self.message_cache[self._last_processed_index :]
|
||||
if not pending_messages:
|
||||
return 0
|
||||
|
||||
seen_message_ids: set[str] = set()
|
||||
for message in pending_messages:
|
||||
seen_message_ids.add(message.message_id)
|
||||
return len(seen_message_ids)
|
||||
|
||||
def _prune_recent_reply_latencies(self, now: Optional[float] = None) -> None:
|
||||
"""仅保留最近 10 分钟内的回复时长记录。"""
|
||||
current_time = time.time() if now is None else now
|
||||
expire_before = current_time - 600.0
|
||||
while self._recent_reply_latencies and self._recent_reply_latencies[0][0] < expire_before:
|
||||
self._recent_reply_latencies.popleft()
|
||||
|
||||
def _get_recent_average_reply_latency(self) -> Optional[float]:
|
||||
"""获取最近 10 分钟平均消息回复时长。"""
|
||||
self._prune_recent_reply_latencies()
|
||||
if not self._recent_reply_latencies:
|
||||
return None
|
||||
|
||||
total_duration = sum(duration for _, duration in self._recent_reply_latencies)
|
||||
return total_duration / len(self._recent_reply_latencies)
|
||||
|
||||
def _record_reply_sent(self) -> None:
|
||||
"""在成功发送 reply 后记录本轮消息回复时长。"""
|
||||
self._clear_force_continue_until_reply()
|
||||
if self._reply_latency_measurement_started_at is None:
|
||||
return
|
||||
|
||||
reply_duration = max(0.0, time.time() - self._reply_latency_measurement_started_at)
|
||||
self._reply_latency_measurement_started_at = None
|
||||
self._recent_reply_latencies.append((time.time(), reply_duration))
|
||||
self._prune_recent_reply_latencies()
|
||||
logger.info(
|
||||
f"{self.log_prefix} 已记录消息回复时长: {reply_duration:.2f} 秒 "
|
||||
f"最近10分钟样本数={len(self._recent_reply_latencies)}"
|
||||
)
|
||||
|
||||
def _should_trigger_message_turn_by_idle_compensation(
|
||||
self,
|
||||
*,
|
||||
pending_count: int,
|
||||
trigger_threshold: int,
|
||||
) -> bool:
|
||||
"""在新消息不足阈值时,按空窗时间折算补齐触发条件。"""
|
||||
average_reply_latency = self._get_recent_average_reply_latency()
|
||||
if average_reply_latency is None or average_reply_latency <= 0:
|
||||
return False
|
||||
|
||||
idle_seconds = max(0.0, time.time() - self._last_message_received_at)
|
||||
equivalent_message_count = pending_count + idle_seconds / average_reply_latency
|
||||
return equivalent_message_count >= trigger_threshold
|
||||
|
||||
def _cancel_deferred_message_turn_task(self) -> None:
|
||||
"""取消等待空窗补偿触发的延迟任务。"""
|
||||
if self._deferred_message_turn_task is None:
|
||||
return
|
||||
self._deferred_message_turn_task.cancel()
|
||||
self._deferred_message_turn_task = None
|
||||
|
||||
async def _schedule_deferred_message_turn(self, delay_seconds: float) -> None:
|
||||
"""在预计满足空窗补偿条件时再次检查是否应触发循环。"""
|
||||
try:
|
||||
if delay_seconds > 0:
|
||||
await asyncio.sleep(delay_seconds)
|
||||
if not self._running:
|
||||
return
|
||||
self._schedule_message_turn()
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
finally:
|
||||
self._deferred_message_turn_task = None
|
||||
|
||||
def _update_message_trigger_state(self, message: SessionMessage) -> None:
|
||||
"""补齐消息中的 @/提及 标记,并在命中时启用强制 continue。"""
|
||||
|
||||
@@ -356,8 +454,30 @@ class MaisakaHeartFlowChatting:
|
||||
if not self._has_pending_messages() or self._message_turn_scheduled:
|
||||
return
|
||||
|
||||
self._message_turn_scheduled = True
|
||||
self._internal_turn_queue.put_nowait("message")
|
||||
pending_count = self._get_pending_message_count()
|
||||
if pending_count <= 0:
|
||||
return
|
||||
|
||||
trigger_threshold = self._get_message_trigger_threshold()
|
||||
if pending_count >= trigger_threshold or self._should_trigger_message_turn_by_idle_compensation(
|
||||
pending_count=pending_count,
|
||||
trigger_threshold=trigger_threshold,
|
||||
):
|
||||
self._cancel_deferred_message_turn_task()
|
||||
self._message_turn_scheduled = True
|
||||
self._internal_turn_queue.put_nowait("message")
|
||||
return
|
||||
|
||||
average_reply_latency = self._get_recent_average_reply_latency()
|
||||
if average_reply_latency is None or average_reply_latency <= 0:
|
||||
return
|
||||
|
||||
idle_seconds = max(0.0, time.time() - self._last_message_received_at)
|
||||
delay_seconds = max(0.0, (trigger_threshold - pending_count) * average_reply_latency - idle_seconds)
|
||||
self._cancel_deferred_message_turn_task()
|
||||
self._deferred_message_turn_task = asyncio.create_task(
|
||||
self._schedule_deferred_message_turn(delay_seconds)
|
||||
)
|
||||
|
||||
def _collect_pending_messages(self) -> list[SessionMessage]:
|
||||
"""从消息缓存中收集一批尚未处理的消息。"""
|
||||
@@ -380,6 +500,13 @@ class MaisakaHeartFlowChatting:
|
||||
# f"{self.log_prefix} 已从消息缓存区[{start_index}:{self._last_processed_index}] "
|
||||
# f"收集 {len(unique_messages)} 条新消息"
|
||||
# )
|
||||
if unique_messages and self._reply_latency_measurement_started_at is None:
|
||||
self._reply_latency_measurement_started_at = min(
|
||||
self._message_received_at_by_id.get(message.message_id, self._last_message_received_at)
|
||||
for message in unique_messages
|
||||
)
|
||||
for message in unique_messages:
|
||||
self._message_received_at_by_id.pop(message.message_id, None)
|
||||
return unique_messages
|
||||
|
||||
async def _wait_for_message_quiet_period(self) -> None:
|
||||
|
||||
@@ -30,7 +30,6 @@ from src.config.official_configs import (
|
||||
MemoryConfig,
|
||||
MessageReceiveConfig,
|
||||
PersonalityConfig,
|
||||
RelationshipConfig,
|
||||
ResponsePostProcessConfig,
|
||||
ResponseSplitterConfig,
|
||||
TelemetryConfig,
|
||||
@@ -97,7 +96,6 @@ async def get_config_section_schema(section_name: str):
|
||||
支持的section_name:
|
||||
- bot: BotConfig
|
||||
- personality: PersonalityConfig
|
||||
- relationship: RelationshipConfig
|
||||
- chat: ChatConfig
|
||||
- message_receive: MessageReceiveConfig
|
||||
- emoji: EmojiConfig
|
||||
@@ -119,7 +117,6 @@ async def get_config_section_schema(section_name: str):
|
||||
section_map = {
|
||||
"bot": BotConfig,
|
||||
"personality": PersonalityConfig,
|
||||
"relationship": RelationshipConfig,
|
||||
"chat": ChatConfig,
|
||||
"message_receive": MessageReceiveConfig,
|
||||
"emoji": EmojiConfig,
|
||||
|
||||
Reference in New Issue
Block a user