@@ -7,4 +7,4 @@ mongodb
|
||||
napcat
|
||||
docs/
|
||||
.github/
|
||||
# test
|
||||
# test
|
||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -2,6 +2,7 @@ data/
|
||||
data1/
|
||||
mongodb/
|
||||
NapCat.Framework.Windows.Once/
|
||||
NapCat.Framework.Windows.OneKey/
|
||||
log/
|
||||
logs/
|
||||
out/
|
||||
@@ -49,6 +50,7 @@ template/compare/model_config_template.toml
|
||||
(临时版)麦麦开始学习.bat
|
||||
src/plugins/utils/statistic.py
|
||||
CLAUDE.md
|
||||
MaiBot-Dashboard/
|
||||
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
|
||||
38
Dockerfile
38
Dockerfile
@@ -1,4 +1,22 @@
|
||||
FROM python:3.13.5-slim-bookworm
|
||||
# 编译 LPMM
|
||||
FROM python:3.13-slim AS lpmm-builder
|
||||
COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
|
||||
|
||||
WORKDIR /MaiMBot-LPMM
|
||||
|
||||
# 同级目录下需要有 MaiMBot-LPMM
|
||||
COPY MaiMBot-LPMM /MaiMBot-LPMM
|
||||
|
||||
# 安装编译器和编译依赖
|
||||
RUN apt-get update && apt-get install -y build-essential
|
||||
RUN uv pip install --system --upgrade pip
|
||||
RUN cd /MaiMBot-LPMM && uv pip install --system -r requirements.txt
|
||||
|
||||
# 编译 LPMM
|
||||
RUN cd /MaiMBot-LPMM/lib/quick_algo && python build_lib.py --cleanup --cythonize --install
|
||||
|
||||
# 运行环境
|
||||
FROM python:3.13-slim
|
||||
COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
|
||||
|
||||
# 工作目录
|
||||
@@ -6,22 +24,12 @@ WORKDIR /MaiMBot
|
||||
|
||||
# 复制依赖列表
|
||||
COPY requirements.txt .
|
||||
# 同级目录下需要有 maim_message MaiMBot-LPMM
|
||||
#COPY maim_message /maim_message
|
||||
COPY MaiMBot-LPMM /MaiMBot-LPMM
|
||||
|
||||
# 编译器
|
||||
RUN apt-get update && apt-get install -y build-essential
|
||||
# 从编译阶段复制 LPMM 编译结果
|
||||
COPY --from=lpmm-builder /usr/local/lib/python3.13/site-packages/ /usr/local/lib/python3.13/site-packages/
|
||||
|
||||
# lpmm编译安装
|
||||
RUN cd /MaiMBot-LPMM && uv pip install --system -r requirements.txt
|
||||
RUN uv pip install --system Cython py-cpuinfo setuptools
|
||||
RUN cd /MaiMBot-LPMM/lib/quick_algo && python build_lib.py --cleanup --cythonize --install
|
||||
|
||||
|
||||
# 安装依赖
|
||||
# 安装运行时依赖
|
||||
RUN uv pip install --system --upgrade pip
|
||||
#RUN uv pip install --system -e /maim_message
|
||||
RUN uv pip install --system -r requirements.txt
|
||||
|
||||
# 复制项目代码
|
||||
@@ -29,4 +37,4 @@ COPY . .
|
||||
|
||||
EXPOSE 8000
|
||||
|
||||
ENTRYPOINT [ "python","bot.py" ]
|
||||
ENTRYPOINT [ "python","bot.py" ]
|
||||
|
||||
2
bot.py
2
bot.py
@@ -30,7 +30,7 @@ else:
|
||||
raise
|
||||
|
||||
# 最早期初始化日志系统,确保所有后续模块都使用正确的日志格式
|
||||
from src.common.logger import initialize_logging, get_logger, shutdown_logging
|
||||
from src.common.logger import initialize_logging, get_logger, shutdown_logging #noqa
|
||||
|
||||
initialize_logging()
|
||||
|
||||
|
||||
@@ -1,5 +1,45 @@
|
||||
# Changelog
|
||||
|
||||
## [0.11.2] - 2025-11-15
|
||||
### 🌟 主要功能更改
|
||||
- "海马体Agent"记忆系统上线,最新最好的记忆系统,默认已接入lpmm
|
||||
- 添加黑话jargon学习系统
|
||||
- 添加群特殊Prompt系统
|
||||
- 优化直接提及时的回复速度
|
||||
|
||||
### 细节功能更改
|
||||
- 添加 WebUI 模块及相关 API 路由和 Token 管理功能
|
||||
- 可通过海马体Agent记录和查询群昵称
|
||||
- 添加聊天记录总结模块
|
||||
- 添加大量新统计指标
|
||||
|
||||
### 功能更改和修复
|
||||
- 移除表达方式学习上限限制
|
||||
- 移除部分未使用代码
|
||||
- 移除问题追踪和旧版记忆
|
||||
- 移除Exp+model表达方式,移除无用代码
|
||||
- 移除问题跟踪和记忆整理
|
||||
- 移除主动发言功能
|
||||
- 优化自我识别和情绪
|
||||
- 优化记忆提取能力
|
||||
- 优化planner,提及时消耗更少,连续no_reply时降低敏感度
|
||||
- 压缩1/3的planner消耗
|
||||
- 优化记忆检索占用
|
||||
- 优化记忆提取和聊天压缩
|
||||
- 优化错别字生成和分段
|
||||
- 优化log和添加changelog
|
||||
- 美化统计界面
|
||||
- 修正记忆提取LLM统计
|
||||
- 修复docker问题
|
||||
- 修复一些潜在问题
|
||||
- 修复bool和boolean问题
|
||||
- 修复超时给到所有信息的Bug
|
||||
- 修复回复超长现可返回原文
|
||||
- 修复私聊记忆
|
||||
- 修复prompt问题
|
||||
- 修复(bot): 恢复戳一戳正常响应
|
||||
- 提供更多细节debug配置
|
||||
|
||||
## [0.11.1] - 2025-11-4
|
||||
### 功能更改和修复
|
||||
- 记忆现在能够被遗忘,并且拥有更好的合并
|
||||
|
||||
@@ -1,15 +1,11 @@
|
||||
from typing import List, Tuple, Type, Optional
|
||||
from src.plugin_system import (
|
||||
BasePlugin,
|
||||
register_plugin,
|
||||
BaseCommand,
|
||||
ComponentInfo,
|
||||
ConfigField
|
||||
)
|
||||
from src.plugin_system import BasePlugin, register_plugin, BaseCommand, ComponentInfo, ConfigField
|
||||
from src.plugin_system.apis import send_api, frequency_api
|
||||
|
||||
|
||||
class SetTalkFrequencyCommand(BaseCommand):
|
||||
"""设置当前聊天的talk_frequency值"""
|
||||
|
||||
command_name = "set_talk_frequency"
|
||||
command_description = "设置当前聊天的talk_frequency值:/chat talk_frequency <数字> 或 /chat t <数字>"
|
||||
command_pattern = r"^/chat\s+(?:talk_frequency|t)\s+(?P<value>[+-]?\d*\.?\d+)$"
|
||||
@@ -19,35 +15,35 @@ class SetTalkFrequencyCommand(BaseCommand):
|
||||
# 获取命令参数 - 使用命名捕获组
|
||||
if not self.matched_groups or "value" not in self.matched_groups:
|
||||
return False, "命令格式错误", False
|
||||
|
||||
|
||||
value_str = self.matched_groups["value"]
|
||||
if not value_str:
|
||||
return False, "无法获取数值参数", False
|
||||
|
||||
|
||||
value = float(value_str)
|
||||
|
||||
|
||||
# 获取聊天流ID
|
||||
if not self.message.chat_stream or not hasattr(self.message.chat_stream, "stream_id"):
|
||||
return False, "无法获取聊天流信息", False
|
||||
|
||||
|
||||
chat_id = self.message.chat_stream.stream_id
|
||||
|
||||
|
||||
# 设置talk_frequency
|
||||
frequency_api.set_talk_frequency_adjust(chat_id, value)
|
||||
|
||||
|
||||
final_value = frequency_api.get_current_talk_value(chat_id)
|
||||
adjust_value = frequency_api.get_talk_frequency_adjust(chat_id)
|
||||
base_value = final_value / adjust_value
|
||||
|
||||
|
||||
# 发送反馈消息(不保存到数据库)
|
||||
await send_api.text_to_stream(
|
||||
f"已设置当前聊天的talk_frequency调整值为: {value}\n当前talk_value: {final_value:.2f}\n发言频率调整: {adjust_value:.2f}\n基础值: {base_value:.2f}",
|
||||
chat_id,
|
||||
storage_message=False
|
||||
storage_message=False,
|
||||
)
|
||||
|
||||
|
||||
return True, None, False
|
||||
|
||||
|
||||
except ValueError:
|
||||
error_msg = "数值格式错误,请输入有效的数字"
|
||||
await self.send_text(error_msg, storage_message=False)
|
||||
@@ -60,6 +56,7 @@ class SetTalkFrequencyCommand(BaseCommand):
|
||||
|
||||
class ShowFrequencyCommand(BaseCommand):
|
||||
"""显示当前聊天的频率控制状态"""
|
||||
|
||||
command_name = "show_frequency"
|
||||
command_description = "显示当前聊天的频率控制状态:/chat show 或 /chat s"
|
||||
command_pattern = r"^/chat\s+(?:show|s)$"
|
||||
@@ -116,11 +113,7 @@ class BetterFrequencyPlugin(BasePlugin):
|
||||
config_file_name: str = "config.toml"
|
||||
|
||||
# 配置节描述
|
||||
config_section_descriptions = {
|
||||
"plugin": "插件基本信息",
|
||||
"frequency": "频率控制配置",
|
||||
"features": "功能开关配置"
|
||||
}
|
||||
config_section_descriptions = {"plugin": "插件基本信息", "frequency": "频率控制配置", "features": "功能开关配置"}
|
||||
|
||||
# 配置Schema定义
|
||||
config_schema: dict = {
|
||||
@@ -138,13 +131,14 @@ class BetterFrequencyPlugin(BasePlugin):
|
||||
|
||||
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
|
||||
components = []
|
||||
|
||||
|
||||
# 根据配置决定是否注册命令组件
|
||||
if self.config.get("features", {}).get("enable_commands", True):
|
||||
components.extend([
|
||||
(SetTalkFrequencyCommand.get_command_info(), SetTalkFrequencyCommand),
|
||||
(ShowFrequencyCommand.get_command_info(), ShowFrequencyCommand),
|
||||
])
|
||||
|
||||
|
||||
components.extend(
|
||||
[
|
||||
(SetTalkFrequencyCommand.get_command_info(), SetTalkFrequencyCommand),
|
||||
(ShowFrequencyCommand.get_command_info(), ShowFrequencyCommand),
|
||||
]
|
||||
)
|
||||
|
||||
return components
|
||||
|
||||
@@ -1,34 +0,0 @@
|
||||
{
|
||||
"manifest_version": 1,
|
||||
"name": "Deep Think插件 (Deep Think Actions)",
|
||||
"version": "1.0.0",
|
||||
"description": "可以深度思考",
|
||||
"author": {
|
||||
"name": "SengokuCola",
|
||||
"url": "https://github.com/MaiM-with-u"
|
||||
},
|
||||
"license": "GPL-v3.0-or-later",
|
||||
|
||||
"host_application": {
|
||||
"min_version": "0.11.0"
|
||||
},
|
||||
"homepage_url": "https://github.com/MaiM-with-u/maibot",
|
||||
"repository_url": "https://github.com/MaiM-with-u/maibot",
|
||||
"keywords": ["deep", "think", "action", "built-in"],
|
||||
"categories": ["Deep Think"],
|
||||
|
||||
"default_locale": "zh-CN",
|
||||
"locales_path": "_locales",
|
||||
|
||||
"plugin_info": {
|
||||
"is_built_in": true,
|
||||
"plugin_type": "action_provider",
|
||||
"components": [
|
||||
{
|
||||
"type": "action",
|
||||
"name": "deep_think",
|
||||
"description": "发送深度思考"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
@@ -1,102 +0,0 @@
|
||||
from typing import List, Tuple, Type, Any
|
||||
|
||||
# 导入新插件系统
|
||||
from src.plugin_system import BasePlugin, register_plugin, ComponentInfo
|
||||
from src.plugin_system.base.config_types import ConfigField
|
||||
from src.person_info.person_info import Person
|
||||
from src.plugin_system.base.base_tool import BaseTool, ToolParamType
|
||||
|
||||
# 导入依赖的系统组件
|
||||
from src.common.logger import get_logger
|
||||
|
||||
from src.plugins.built_in.relation.relation import BuildRelationAction
|
||||
from src.plugin_system.apis import llm_api
|
||||
|
||||
logger = get_logger("relation_actions")
|
||||
|
||||
|
||||
|
||||
class DeepThinkTool(BaseTool):
|
||||
"""获取用户信息"""
|
||||
|
||||
name = "deep_think"
|
||||
description = "深度思考,对某个知识,概念或逻辑问题进行全面且深入的思考,当面临复杂环境或重要问题时,使用此获得更好的解决方案。"
|
||||
parameters = [
|
||||
("question", ToolParamType.STRING, "需要思考的问题,越具体越好(从上下文中总结)", True, None),
|
||||
]
|
||||
|
||||
available_for_llm = True
|
||||
|
||||
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""执行比较两个数的大小
|
||||
|
||||
Args:
|
||||
function_args: 工具参数
|
||||
|
||||
Returns:
|
||||
dict: 工具执行结果
|
||||
"""
|
||||
question: str = function_args.get("question") # type: ignore
|
||||
|
||||
print(f"question: {question}")
|
||||
|
||||
prompt = f"""
|
||||
请你思考以下问题,以简洁的一段话回答:
|
||||
{question}
|
||||
"""
|
||||
|
||||
models = llm_api.get_available_models()
|
||||
chat_model_config = models.get("replyer") # 使用字典访问方式
|
||||
|
||||
success, thinking_result, _, _ = await llm_api.generate_with_model(
|
||||
prompt, model_config=chat_model_config, request_type="deep_think"
|
||||
)
|
||||
|
||||
logger.info(f"{question}: {thinking_result}")
|
||||
|
||||
thinking_result =f"思考结果:{thinking_result}\n**注意** 因为你进行了深度思考,最后的回复内容可以回复的长一些,更加详细一些,不用太简洁。\n"
|
||||
|
||||
return {"content": thinking_result}
|
||||
|
||||
|
||||
@register_plugin
|
||||
class DeepThinkPlugin(BasePlugin):
|
||||
"""关系动作插件
|
||||
|
||||
系统内置插件,提供基础的聊天交互功能:
|
||||
- Reply: 回复动作
|
||||
- NoReply: 不回复动作
|
||||
- Emoji: 表情动作
|
||||
|
||||
注意:插件基本信息优先从_manifest.json文件中读取
|
||||
"""
|
||||
|
||||
# 插件基本信息
|
||||
plugin_name: str = "deep_think" # 内部标识符
|
||||
enable_plugin: bool = True
|
||||
dependencies: list[str] = [] # 插件依赖列表
|
||||
python_dependencies: list[str] = [] # Python包依赖列表
|
||||
config_file_name: str = "config.toml"
|
||||
|
||||
# 配置节描述
|
||||
config_section_descriptions = {
|
||||
"plugin": "插件启用配置",
|
||||
"components": "核心组件启用配置",
|
||||
}
|
||||
|
||||
# 配置Schema定义
|
||||
config_schema: dict = {
|
||||
"plugin": {
|
||||
"enabled": ConfigField(type=bool, default=False, description="是否启用插件"),
|
||||
"config_version": ConfigField(type=str, default="2.0.0", description="配置文件版本"),
|
||||
}
|
||||
}
|
||||
|
||||
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
|
||||
"""返回插件包含的组件列表"""
|
||||
|
||||
# --- 根据配置注册组件 ---
|
||||
components = []
|
||||
components.append((DeepThinkTool.get_tool_info(), DeepThinkTool))
|
||||
|
||||
return components
|
||||
@@ -25,4 +25,6 @@ structlog>=25.4.0
|
||||
toml>=0.10.2
|
||||
tomlkit>=0.13.3
|
||||
urllib3>=2.5.0
|
||||
uvicorn>=0.35.0
|
||||
uvicorn>=0.35.0
|
||||
msgpack
|
||||
zstandard
|
||||
@@ -6,15 +6,16 @@ import sys
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import Dict, Iterable, List, Optional, Tuple
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.common.message_repository import find_messages
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages
|
||||
|
||||
# 确保可从任意工作目录运行:将项目根目录加入 sys.path(scripts 的上一级)
|
||||
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
if PROJECT_ROOT not in sys.path:
|
||||
sys.path.insert(0, PROJECT_ROOT)
|
||||
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.common.message_repository import find_messages
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages
|
||||
|
||||
|
||||
|
||||
SECONDS_5_MINUTES = 5 * 60
|
||||
@@ -28,16 +29,16 @@ def clean_output_text(text: str) -> str:
|
||||
"""
|
||||
if not text:
|
||||
return text
|
||||
|
||||
|
||||
# 移除表情包内容:[表情包:...]
|
||||
text = re.sub(r'\[表情包:[^\]]*\]', '', text)
|
||||
|
||||
text = re.sub(r"\[表情包:[^\]]*\]", "", text)
|
||||
|
||||
# 移除回复内容:[回复...],说:... 的完整模式
|
||||
text = re.sub(r'\[回复[^\]]*\],说:[^@]*@[^:]*:', '', text)
|
||||
|
||||
text = re.sub(r"\[回复[^\]]*\],说:[^@]*@[^:]*:", "", text)
|
||||
|
||||
# 清理多余的空格和换行
|
||||
text = re.sub(r'\s+', ' ', text).strip()
|
||||
|
||||
text = re.sub(r"\s+", " ", text).strip()
|
||||
|
||||
return text
|
||||
|
||||
|
||||
@@ -89,7 +90,7 @@ def group_by_chat(messages: Iterable[DatabaseMessages]) -> Dict[str, List[Databa
|
||||
for msg in messages:
|
||||
groups.setdefault(msg.chat_id, []).append(msg)
|
||||
# 保证每个分组内按时间升序
|
||||
for chat_id, msgs in groups.items():
|
||||
for _chat_id, msgs in groups.items():
|
||||
msgs.sort(key=lambda m: m.time or 0)
|
||||
return groups
|
||||
|
||||
@@ -170,8 +171,8 @@ def merge_adjacent_same_user(messages: List[DatabaseMessages]) -> List[DatabaseM
|
||||
continue
|
||||
|
||||
last = bucket[-1]
|
||||
same_user = (msg.user_info.user_id == last.user_info.user_id)
|
||||
close_enough = ((msg.time or 0) - (last.time or 0) <= SECONDS_5_MINUTES)
|
||||
same_user = msg.user_info.user_id == last.user_info.user_id
|
||||
close_enough = (msg.time or 0) - (last.time or 0) <= SECONDS_5_MINUTES
|
||||
|
||||
if same_user and close_enough:
|
||||
bucket.append(msg)
|
||||
@@ -199,38 +200,36 @@ def build_pairs_for_chat(
|
||||
pairs: List[Tuple[str, str, str]] = []
|
||||
n_merged = len(merged_messages)
|
||||
n_original = len(original_messages)
|
||||
|
||||
|
||||
if n_merged == 0 or n_original == 0:
|
||||
return pairs
|
||||
|
||||
# 为每个合并后的消息找到对应的原始消息位置
|
||||
merged_to_original_map = {}
|
||||
original_idx = 0
|
||||
|
||||
|
||||
for merged_idx, merged_msg in enumerate(merged_messages):
|
||||
# 找到这个合并消息对应的第一个原始消息
|
||||
while (original_idx < n_original and
|
||||
original_messages[original_idx].time < merged_msg.time):
|
||||
while original_idx < n_original and original_messages[original_idx].time < merged_msg.time:
|
||||
original_idx += 1
|
||||
|
||||
|
||||
# 如果找到了时间匹配的原始消息,建立映射
|
||||
if (original_idx < n_original and
|
||||
original_messages[original_idx].time == merged_msg.time):
|
||||
if original_idx < n_original and original_messages[original_idx].time == merged_msg.time:
|
||||
merged_to_original_map[merged_idx] = original_idx
|
||||
|
||||
for merged_idx in range(n_merged):
|
||||
merged_msg = merged_messages[merged_idx]
|
||||
|
||||
|
||||
# 如果指定了 target_user_id,只处理该用户的消息作为 output
|
||||
if target_user_id and merged_msg.user_info.user_id != target_user_id:
|
||||
continue
|
||||
|
||||
|
||||
# 找到对应的原始消息位置
|
||||
if merged_idx not in merged_to_original_map:
|
||||
continue
|
||||
|
||||
|
||||
original_idx = merged_to_original_map[merged_idx]
|
||||
|
||||
|
||||
# 选择上下文窗口大小
|
||||
window = random.randint(min_ctx, max_ctx) if max_ctx > min_ctx else min_ctx
|
||||
start = max(0, original_idx - window)
|
||||
@@ -266,7 +265,7 @@ def build_pairs(
|
||||
groups = group_by_chat(messages)
|
||||
|
||||
all_pairs: List[Tuple[str, str, str]] = []
|
||||
for chat_id, msgs in groups.items(): # noqa: F841 - chat_id 未直接使用
|
||||
for _chat_id, msgs in groups.items(): # noqa: F841 - chat_id 未直接使用
|
||||
# 对消息进行合并,用于output
|
||||
merged = merge_adjacent_same_user(msgs)
|
||||
# 传递原始消息和合并后消息,input使用原始消息,output使用合并后消息
|
||||
@@ -385,5 +384,3 @@ def run_interactive() -> int:
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import time
|
||||
import sys
|
||||
import os
|
||||
import matplotlib.pyplot as plt
|
||||
@@ -6,16 +5,17 @@ import matplotlib.dates as mdates
|
||||
from datetime import datetime
|
||||
from typing import List, Tuple
|
||||
import numpy as np
|
||||
from src.common.database.database_model import Expression, ChatStreams
|
||||
|
||||
# Add project root to Python path
|
||||
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
sys.path.insert(0, project_root)
|
||||
|
||||
from src.common.database.database_model import Expression, ChatStreams
|
||||
|
||||
|
||||
# 设置中文字体
|
||||
plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'DejaVu Sans']
|
||||
plt.rcParams['axes.unicode_minus'] = False
|
||||
plt.rcParams["font.sans-serif"] = ["SimHei", "Microsoft YaHei", "DejaVu Sans"]
|
||||
plt.rcParams["axes.unicode_minus"] = False
|
||||
|
||||
|
||||
def get_chat_name(chat_id: str) -> str:
|
||||
@@ -39,19 +39,14 @@ def get_expression_data() -> List[Tuple[float, float, str, str]]:
|
||||
"""获取Expression表中的数据,返回(create_date, count, chat_id, expression_type)的列表"""
|
||||
expressions = Expression.select()
|
||||
data = []
|
||||
|
||||
|
||||
for expr in expressions:
|
||||
# 如果create_date为空,跳过该记录
|
||||
if expr.create_date is None:
|
||||
continue
|
||||
|
||||
data.append((
|
||||
expr.create_date,
|
||||
expr.count,
|
||||
expr.chat_id,
|
||||
expr.type
|
||||
))
|
||||
|
||||
|
||||
data.append((expr.create_date, expr.count, expr.chat_id, expr.type))
|
||||
|
||||
return data
|
||||
|
||||
|
||||
@@ -60,71 +55,71 @@ def create_scatter_plot(data: List[Tuple[float, float, str, str]], save_path: st
|
||||
if not data:
|
||||
print("没有找到有效的表达式数据")
|
||||
return
|
||||
|
||||
|
||||
# 分离数据
|
||||
create_dates = [item[0] for item in data]
|
||||
counts = [item[1] for item in data]
|
||||
chat_ids = [item[2] for item in data]
|
||||
expression_types = [item[3] for item in data]
|
||||
|
||||
_chat_ids = [item[2] for item in data]
|
||||
_expression_types = [item[3] for item in data]
|
||||
|
||||
# 转换时间戳为datetime对象
|
||||
dates = [datetime.fromtimestamp(ts) for ts in create_dates]
|
||||
|
||||
|
||||
# 计算时间跨度,自动调整显示格式
|
||||
time_span = max(dates) - min(dates)
|
||||
if time_span.days > 30: # 超过30天,按月显示
|
||||
date_format = '%Y-%m-%d'
|
||||
date_format = "%Y-%m-%d"
|
||||
major_locator = mdates.MonthLocator()
|
||||
minor_locator = mdates.DayLocator(interval=7)
|
||||
elif time_span.days > 7: # 超过7天,按天显示
|
||||
date_format = '%Y-%m-%d'
|
||||
date_format = "%Y-%m-%d"
|
||||
major_locator = mdates.DayLocator(interval=1)
|
||||
minor_locator = mdates.HourLocator(interval=12)
|
||||
else: # 7天内,按小时显示
|
||||
date_format = '%Y-%m-%d %H:%M'
|
||||
date_format = "%Y-%m-%d %H:%M"
|
||||
major_locator = mdates.HourLocator(interval=6)
|
||||
minor_locator = mdates.HourLocator(interval=1)
|
||||
|
||||
|
||||
# 创建图形
|
||||
fig, ax = plt.subplots(figsize=(12, 8))
|
||||
|
||||
|
||||
# 创建散点图
|
||||
scatter = ax.scatter(dates, counts, alpha=0.6, s=30, c=range(len(dates)), cmap='viridis')
|
||||
|
||||
scatter = ax.scatter(dates, counts, alpha=0.6, s=30, c=range(len(dates)), cmap="viridis")
|
||||
|
||||
# 设置标签和标题
|
||||
ax.set_xlabel('创建日期 (Create Date)', fontsize=12)
|
||||
ax.set_ylabel('使用次数 (Count)', fontsize=12)
|
||||
ax.set_title('表达式使用次数随时间分布散点图', fontsize=14, fontweight='bold')
|
||||
|
||||
ax.set_xlabel("创建日期 (Create Date)", fontsize=12)
|
||||
ax.set_ylabel("使用次数 (Count)", fontsize=12)
|
||||
ax.set_title("表达式使用次数随时间分布散点图", fontsize=14, fontweight="bold")
|
||||
|
||||
# 设置x轴日期格式 - 根据时间跨度自动调整
|
||||
ax.xaxis.set_major_formatter(mdates.DateFormatter(date_format))
|
||||
ax.xaxis.set_major_locator(major_locator)
|
||||
ax.xaxis.set_minor_locator(minor_locator)
|
||||
plt.xticks(rotation=45)
|
||||
|
||||
|
||||
# 添加网格
|
||||
ax.grid(True, alpha=0.3)
|
||||
|
||||
|
||||
# 添加颜色条
|
||||
cbar = plt.colorbar(scatter)
|
||||
cbar.set_label('数据点顺序', fontsize=10)
|
||||
|
||||
cbar.set_label("数据点顺序", fontsize=10)
|
||||
|
||||
# 调整布局
|
||||
plt.tight_layout()
|
||||
|
||||
|
||||
# 显示统计信息
|
||||
print(f"\n=== 数据统计 ===")
|
||||
print("\n=== 数据统计 ===")
|
||||
print(f"总数据点数量: {len(data)}")
|
||||
print(f"时间范围: {min(dates).strftime('%Y-%m-%d %H:%M:%S')} 到 {max(dates).strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
print(f"使用次数范围: {min(counts):.1f} 到 {max(counts):.1f}")
|
||||
print(f"平均使用次数: {np.mean(counts):.2f}")
|
||||
print(f"中位数使用次数: {np.median(counts):.2f}")
|
||||
|
||||
|
||||
# 保存图片
|
||||
if save_path:
|
||||
plt.savefig(save_path, dpi=300, bbox_inches='tight')
|
||||
plt.savefig(save_path, dpi=300, bbox_inches="tight")
|
||||
print(f"\n散点图已保存到: {save_path}")
|
||||
|
||||
|
||||
# 显示图片
|
||||
plt.show()
|
||||
|
||||
@@ -134,7 +129,7 @@ def create_grouped_scatter_plot(data: List[Tuple[float, float, str, str]], save_
|
||||
if not data:
|
||||
print("没有找到有效的表达式数据")
|
||||
return
|
||||
|
||||
|
||||
# 按chat_id分组
|
||||
chat_groups = {}
|
||||
for item in data:
|
||||
@@ -142,75 +137,82 @@ def create_grouped_scatter_plot(data: List[Tuple[float, float, str, str]], save_
|
||||
if chat_id not in chat_groups:
|
||||
chat_groups[chat_id] = []
|
||||
chat_groups[chat_id].append(item)
|
||||
|
||||
|
||||
# 计算时间跨度,自动调整显示格式
|
||||
all_dates = [datetime.fromtimestamp(item[0]) for item in data]
|
||||
time_span = max(all_dates) - min(all_dates)
|
||||
if time_span.days > 30: # 超过30天,按月显示
|
||||
date_format = '%Y-%m-%d'
|
||||
date_format = "%Y-%m-%d"
|
||||
major_locator = mdates.MonthLocator()
|
||||
minor_locator = mdates.DayLocator(interval=7)
|
||||
elif time_span.days > 7: # 超过7天,按天显示
|
||||
date_format = '%Y-%m-%d'
|
||||
date_format = "%Y-%m-%d"
|
||||
major_locator = mdates.DayLocator(interval=1)
|
||||
minor_locator = mdates.HourLocator(interval=12)
|
||||
else: # 7天内,按小时显示
|
||||
date_format = '%Y-%m-%d %H:%M'
|
||||
date_format = "%Y-%m-%d %H:%M"
|
||||
major_locator = mdates.HourLocator(interval=6)
|
||||
minor_locator = mdates.HourLocator(interval=1)
|
||||
|
||||
|
||||
# 创建图形
|
||||
fig, ax = plt.subplots(figsize=(14, 10))
|
||||
|
||||
|
||||
# 为每个聊天分配不同颜色
|
||||
colors = plt.cm.Set3(np.linspace(0, 1, len(chat_groups)))
|
||||
|
||||
|
||||
for i, (chat_id, chat_data) in enumerate(chat_groups.items()):
|
||||
create_dates = [item[0] for item in chat_data]
|
||||
counts = [item[1] for item in chat_data]
|
||||
dates = [datetime.fromtimestamp(ts) for ts in create_dates]
|
||||
|
||||
|
||||
chat_name = get_chat_name(chat_id)
|
||||
# 截断过长的聊天名称
|
||||
display_name = chat_name[:20] + "..." if len(chat_name) > 20 else chat_name
|
||||
|
||||
ax.scatter(dates, counts, alpha=0.7, s=40,
|
||||
c=[colors[i]], label=f"{display_name} ({len(chat_data)}个)",
|
||||
edgecolors='black', linewidth=0.5)
|
||||
|
||||
|
||||
ax.scatter(
|
||||
dates,
|
||||
counts,
|
||||
alpha=0.7,
|
||||
s=40,
|
||||
c=[colors[i]],
|
||||
label=f"{display_name} ({len(chat_data)}个)",
|
||||
edgecolors="black",
|
||||
linewidth=0.5,
|
||||
)
|
||||
|
||||
# 设置标签和标题
|
||||
ax.set_xlabel('创建日期 (Create Date)', fontsize=12)
|
||||
ax.set_ylabel('使用次数 (Count)', fontsize=12)
|
||||
ax.set_title('按聊天分组的表达式使用次数散点图', fontsize=14, fontweight='bold')
|
||||
|
||||
ax.set_xlabel("创建日期 (Create Date)", fontsize=12)
|
||||
ax.set_ylabel("使用次数 (Count)", fontsize=12)
|
||||
ax.set_title("按聊天分组的表达式使用次数散点图", fontsize=14, fontweight="bold")
|
||||
|
||||
# 设置x轴日期格式 - 根据时间跨度自动调整
|
||||
ax.xaxis.set_major_formatter(mdates.DateFormatter(date_format))
|
||||
ax.xaxis.set_major_locator(major_locator)
|
||||
ax.xaxis.set_minor_locator(minor_locator)
|
||||
plt.xticks(rotation=45)
|
||||
|
||||
|
||||
# 添加图例
|
||||
ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=8)
|
||||
|
||||
ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left", fontsize=8)
|
||||
|
||||
# 添加网格
|
||||
ax.grid(True, alpha=0.3)
|
||||
|
||||
|
||||
# 调整布局
|
||||
plt.tight_layout()
|
||||
|
||||
|
||||
# 显示统计信息
|
||||
print(f"\n=== 分组统计 ===")
|
||||
print("\n=== 分组统计 ===")
|
||||
print(f"总聊天数量: {len(chat_groups)}")
|
||||
for chat_id, chat_data in chat_groups.items():
|
||||
chat_name = get_chat_name(chat_id)
|
||||
counts = [item[1] for item in chat_data]
|
||||
print(f"{chat_name}: {len(chat_data)}个表达式, 平均使用次数: {np.mean(counts):.2f}")
|
||||
|
||||
|
||||
# 保存图片
|
||||
if save_path:
|
||||
plt.savefig(save_path, dpi=300, bbox_inches='tight')
|
||||
plt.savefig(save_path, dpi=300, bbox_inches="tight")
|
||||
print(f"\n分组散点图已保存到: {save_path}")
|
||||
|
||||
|
||||
# 显示图片
|
||||
plt.show()
|
||||
|
||||
@@ -220,7 +222,7 @@ def create_type_scatter_plot(data: List[Tuple[float, float, str, str]], save_pat
|
||||
if not data:
|
||||
print("没有找到有效的表达式数据")
|
||||
return
|
||||
|
||||
|
||||
# 按type分组
|
||||
type_groups = {}
|
||||
for item in data:
|
||||
@@ -228,69 +230,76 @@ def create_type_scatter_plot(data: List[Tuple[float, float, str, str]], save_pat
|
||||
if expr_type not in type_groups:
|
||||
type_groups[expr_type] = []
|
||||
type_groups[expr_type].append(item)
|
||||
|
||||
|
||||
# 计算时间跨度,自动调整显示格式
|
||||
all_dates = [datetime.fromtimestamp(item[0]) for item in data]
|
||||
time_span = max(all_dates) - min(all_dates)
|
||||
if time_span.days > 30: # 超过30天,按月显示
|
||||
date_format = '%Y-%m-%d'
|
||||
date_format = "%Y-%m-%d"
|
||||
major_locator = mdates.MonthLocator()
|
||||
minor_locator = mdates.DayLocator(interval=7)
|
||||
elif time_span.days > 7: # 超过7天,按天显示
|
||||
date_format = '%Y-%m-%d'
|
||||
date_format = "%Y-%m-%d"
|
||||
major_locator = mdates.DayLocator(interval=1)
|
||||
minor_locator = mdates.HourLocator(interval=12)
|
||||
else: # 7天内,按小时显示
|
||||
date_format = '%Y-%m-%d %H:%M'
|
||||
date_format = "%Y-%m-%d %H:%M"
|
||||
major_locator = mdates.HourLocator(interval=6)
|
||||
minor_locator = mdates.HourLocator(interval=1)
|
||||
|
||||
|
||||
# 创建图形
|
||||
fig, ax = plt.subplots(figsize=(12, 8))
|
||||
|
||||
|
||||
# 为每个类型分配不同颜色
|
||||
colors = plt.cm.tab10(np.linspace(0, 1, len(type_groups)))
|
||||
|
||||
|
||||
for i, (expr_type, type_data) in enumerate(type_groups.items()):
|
||||
create_dates = [item[0] for item in type_data]
|
||||
counts = [item[1] for item in type_data]
|
||||
dates = [datetime.fromtimestamp(ts) for ts in create_dates]
|
||||
|
||||
ax.scatter(dates, counts, alpha=0.7, s=40,
|
||||
c=[colors[i]], label=f"{expr_type} ({len(type_data)}个)",
|
||||
edgecolors='black', linewidth=0.5)
|
||||
|
||||
|
||||
ax.scatter(
|
||||
dates,
|
||||
counts,
|
||||
alpha=0.7,
|
||||
s=40,
|
||||
c=[colors[i]],
|
||||
label=f"{expr_type} ({len(type_data)}个)",
|
||||
edgecolors="black",
|
||||
linewidth=0.5,
|
||||
)
|
||||
|
||||
# 设置标签和标题
|
||||
ax.set_xlabel('创建日期 (Create Date)', fontsize=12)
|
||||
ax.set_ylabel('使用次数 (Count)', fontsize=12)
|
||||
ax.set_title('按表达式类型分组的散点图', fontsize=14, fontweight='bold')
|
||||
|
||||
ax.set_xlabel("创建日期 (Create Date)", fontsize=12)
|
||||
ax.set_ylabel("使用次数 (Count)", fontsize=12)
|
||||
ax.set_title("按表达式类型分组的散点图", fontsize=14, fontweight="bold")
|
||||
|
||||
# 设置x轴日期格式 - 根据时间跨度自动调整
|
||||
ax.xaxis.set_major_formatter(mdates.DateFormatter(date_format))
|
||||
ax.xaxis.set_major_locator(major_locator)
|
||||
ax.xaxis.set_minor_locator(minor_locator)
|
||||
plt.xticks(rotation=45)
|
||||
|
||||
|
||||
# 添加图例
|
||||
ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
|
||||
|
||||
ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
|
||||
|
||||
# 添加网格
|
||||
ax.grid(True, alpha=0.3)
|
||||
|
||||
|
||||
# 调整布局
|
||||
plt.tight_layout()
|
||||
|
||||
|
||||
# 显示统计信息
|
||||
print(f"\n=== 类型统计 ===")
|
||||
print("\n=== 类型统计 ===")
|
||||
for expr_type, type_data in type_groups.items():
|
||||
counts = [item[1] for item in type_data]
|
||||
print(f"{expr_type}: {len(type_data)}个表达式, 平均使用次数: {np.mean(counts):.2f}")
|
||||
|
||||
|
||||
# 保存图片
|
||||
if save_path:
|
||||
plt.savefig(save_path, dpi=300, bbox_inches='tight')
|
||||
plt.savefig(save_path, dpi=300, bbox_inches="tight")
|
||||
print(f"\n类型散点图已保存到: {save_path}")
|
||||
|
||||
|
||||
# 显示图片
|
||||
plt.show()
|
||||
|
||||
@@ -298,35 +307,35 @@ def create_type_scatter_plot(data: List[Tuple[float, float, str, str]], save_pat
|
||||
def main():
|
||||
"""主函数"""
|
||||
print("开始分析表达式数据...")
|
||||
|
||||
|
||||
# 获取数据
|
||||
data = get_expression_data()
|
||||
|
||||
|
||||
if not data:
|
||||
print("没有找到有效的表达式数据(create_date不为空的数据)")
|
||||
return
|
||||
|
||||
|
||||
print(f"找到 {len(data)} 条有效数据")
|
||||
|
||||
|
||||
# 创建输出目录
|
||||
output_dir = os.path.join(project_root, "data", "temp")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
|
||||
# 生成时间戳用于文件名
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
|
||||
|
||||
# 1. 创建基础散点图
|
||||
print("\n1. 创建基础散点图...")
|
||||
create_scatter_plot(data, os.path.join(output_dir, f"expression_scatter_{timestamp}.png"))
|
||||
|
||||
|
||||
# 2. 创建按聊天分组的散点图
|
||||
print("\n2. 创建按聊天分组的散点图...")
|
||||
create_grouped_scatter_plot(data, os.path.join(output_dir, f"expression_scatter_by_chat_{timestamp}.png"))
|
||||
|
||||
|
||||
# 3. 创建按类型分组的散点图
|
||||
print("\n3. 创建按类型分组的散点图...")
|
||||
create_type_scatter_plot(data, os.path.join(output_dir, f"expression_scatter_by_type_{timestamp}.png"))
|
||||
|
||||
|
||||
print("\n分析完成!")
|
||||
|
||||
|
||||
|
||||
1125
scripts/mmipkg_tool.py
Normal file
1125
scripts/mmipkg_tool.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -21,7 +21,6 @@ from src.person_info.person_info import Person
|
||||
from src.plugin_system.base.component_types import EventType, ActionInfo
|
||||
from src.plugin_system.core import events_manager
|
||||
from src.plugin_system.apis import generator_api, send_api, message_api, database_api
|
||||
from src.memory_system.Memory_chest import global_memory_chest
|
||||
from src.chat.utils.chat_message_builder import (
|
||||
build_readable_messages_with_id,
|
||||
get_raw_msg_before_timestamp_with_chat,
|
||||
@@ -238,7 +237,6 @@ class BrainChatting:
|
||||
|
||||
async with global_prompt_manager.async_message_scope(self.chat_stream.context.get_template_name()):
|
||||
asyncio.create_task(self.expression_learner.trigger_learning_for_chat())
|
||||
asyncio.create_task(global_memory_chest.build_running_content(chat_id=self.stream_id))
|
||||
|
||||
cycle_timers, thinking_id = self.start_cycle()
|
||||
logger.info(f"{self.log_prefix} 开始第{self._cycle_counter}次思考")
|
||||
|
||||
@@ -442,7 +442,7 @@ class BrainPlanner:
|
||||
# logger.info(f"{self.log_prefix}规划器原始提示词: {prompt}")
|
||||
# logger.info(f"{self.log_prefix}规划器原始响应: {llm_content}")
|
||||
|
||||
if global_config.debug.show_prompt:
|
||||
if global_config.debug.show_planner_prompt:
|
||||
logger.info(f"{self.log_prefix}规划器原始提示词: {prompt}")
|
||||
logger.info(f"{self.log_prefix}规划器原始响应: {llm_content}")
|
||||
if reasoning_content:
|
||||
|
||||
@@ -940,14 +940,12 @@ class EmojiManager:
|
||||
image_base64 = get_image_manager().transform_gif(image_base64) # type: ignore
|
||||
if not image_base64:
|
||||
raise RuntimeError("GIF表情包转换失败")
|
||||
prompt = "这是一个动态图表情包,每一张图代表了动态图的某一帧,黑色背景代表透明,简短描述一下表情包表达的情感和内容,描述细节,从互联网梗,meme的角度去分析"
|
||||
prompt = "这是一个动态图表情包,每一张图代表了动态图的某一帧,黑色背景代表透明,简短描述一下表情包表达的情感和内容,从互联网梗,meme的角度去分析,精简回答"
|
||||
description, _ = await self.vlm.generate_response_for_image(
|
||||
prompt, image_base64, "jpg", temperature=0.5
|
||||
)
|
||||
else:
|
||||
prompt = (
|
||||
"这是一个表情包,请详细描述一下表情包所表达的情感和内容,简短描述细节,从互联网梗,meme的角度去分析"
|
||||
)
|
||||
prompt = "这是一个表情包,请详细描述一下表情包所表达的情感和内容,简短描述细节,从互联网梗,meme的角度去分析,精简回答"
|
||||
description, _ = await self.vlm.generate_response_for_image(
|
||||
prompt, image_base64, image_format, temperature=0.5
|
||||
)
|
||||
|
||||
@@ -12,6 +12,7 @@ from src.llm_models.utils_model import LLMRequest
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.apis import frequency_api
|
||||
|
||||
|
||||
def init_prompt():
|
||||
Prompt(
|
||||
"""{name_block}
|
||||
@@ -28,7 +29,7 @@ def init_prompt():
|
||||
""",
|
||||
"frequency_adjust_prompt",
|
||||
)
|
||||
|
||||
|
||||
|
||||
logger = get_logger("frequency_control")
|
||||
|
||||
@@ -40,7 +41,7 @@ class FrequencyControl:
|
||||
self.chat_id = chat_id
|
||||
# 发言频率调整值
|
||||
self.talk_frequency_adjust: float = 1.0
|
||||
|
||||
|
||||
self.last_frequency_adjust_time: float = 0.0
|
||||
self.frequency_model = LLMRequest(
|
||||
model_set=model_config.model_task_config.utils_small, request_type="frequency.adjust"
|
||||
@@ -53,27 +54,25 @@ class FrequencyControl:
|
||||
def set_talk_frequency_adjust(self, value: float) -> None:
|
||||
"""设置发言频率调整值"""
|
||||
self.talk_frequency_adjust = max(0.1, min(5.0, value))
|
||||
|
||||
|
||||
|
||||
async def trigger_frequency_adjust(self) -> None:
|
||||
msg_list = get_raw_msg_by_timestamp_with_chat(
|
||||
chat_id=self.chat_id,
|
||||
timestamp_start=self.last_frequency_adjust_time,
|
||||
timestamp_end=time.time(),
|
||||
)
|
||||
|
||||
|
||||
if time.time() - self.last_frequency_adjust_time < 120 or len(msg_list) <= 5:
|
||||
|
||||
if time.time() - self.last_frequency_adjust_time < 160 or len(msg_list) <= 20:
|
||||
return
|
||||
else:
|
||||
new_msg_list = get_raw_msg_by_timestamp_with_chat(
|
||||
chat_id=self.chat_id,
|
||||
timestamp_start=self.last_frequency_adjust_time,
|
||||
timestamp_end=time.time(),
|
||||
limit=5,
|
||||
limit=20,
|
||||
limit_mode="latest",
|
||||
)
|
||||
|
||||
|
||||
message_str = build_readable_messages(
|
||||
new_msg_list,
|
||||
replace_bot_name=True,
|
||||
@@ -97,28 +96,29 @@ class FrequencyControl:
|
||||
response, (reasoning_content, _, _) = await self.frequency_model.generate_response_async(
|
||||
prompt,
|
||||
)
|
||||
|
||||
|
||||
# logger.info(f"频率调整 prompt: {prompt}")
|
||||
# logger.info(f"频率调整 response: {response}")
|
||||
|
||||
|
||||
if global_config.debug.show_prompt:
|
||||
logger.info(f"频率调整 prompt: {prompt}")
|
||||
logger.info(f"频率调整 response: {response}")
|
||||
logger.info(f"频率调整 reasoning_content: {reasoning_content}")
|
||||
|
||||
|
||||
final_value_by_api = frequency_api.get_current_talk_value(self.chat_id)
|
||||
|
||||
# LLM依然输出过多内容时取消本次调整。合法最多4个字,但有的模型可能会输出一些markdown换行符等,需要长度宽限
|
||||
if len(response) < 20:
|
||||
if "过于频繁" in response:
|
||||
logger.info(f"频率调整: 过于频繁,调整值到{final_value_by_api}")
|
||||
self.talk_frequency_adjust = max(0.1, min(3.0, self.talk_frequency_adjust * 0.8))
|
||||
self.talk_frequency_adjust = max(0.1, min(1.5, self.talk_frequency_adjust * 0.8))
|
||||
elif "过少" in response:
|
||||
logger.info(f"频率调整: 过少,调整值到{final_value_by_api}")
|
||||
self.talk_frequency_adjust = max(0.1, min(3.0, self.talk_frequency_adjust * 1.2))
|
||||
self.talk_frequency_adjust = max(0.1, min(1.5, self.talk_frequency_adjust * 1.2))
|
||||
self.last_frequency_adjust_time = time.time()
|
||||
else:
|
||||
logger.info(f"频率调整:response不符合要求,取消本次调整")
|
||||
logger.info("频率调整:response不符合要求,取消本次调整")
|
||||
|
||||
|
||||
class FrequencyControlManager:
|
||||
"""频率控制管理器,管理多个聊天流的频率控制实例"""
|
||||
@@ -143,6 +143,7 @@ class FrequencyControlManager:
|
||||
"""获取所有有频率控制的聊天ID"""
|
||||
return list(self.frequency_control_dict.keys())
|
||||
|
||||
|
||||
init_prompt()
|
||||
|
||||
# 创建全局实例
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import asyncio
|
||||
from multiprocessing import context
|
||||
import time
|
||||
import traceback
|
||||
import random
|
||||
@@ -17,21 +16,18 @@ from src.chat.planner_actions.planner import ActionPlanner
|
||||
from src.chat.planner_actions.action_modifier import ActionModifier
|
||||
from src.chat.planner_actions.action_manager import ActionManager
|
||||
from src.chat.heart_flow.hfc_utils import CycleDetail
|
||||
from src.chat.heart_flow.hfc_utils import send_typing, stop_typing
|
||||
from src.express.expression_learner import expression_learner_manager
|
||||
from src.chat.frequency_control.frequency_control import frequency_control_manager
|
||||
from src.memory_system.question_maker import QuestionMaker
|
||||
from src.memory_system.questions import global_conflict_tracker
|
||||
from src.memory_system.curious import check_and_make_question
|
||||
from src.jargon import extract_and_store_jargon
|
||||
from src.person_info.person_info import Person
|
||||
from src.plugin_system.base.component_types import EventType, ActionInfo
|
||||
from src.plugin_system.core import events_manager
|
||||
from src.plugin_system.apis import generator_api, send_api, message_api, database_api
|
||||
from src.memory_system.Memory_chest import global_memory_chest
|
||||
from src.chat.utils.chat_message_builder import (
|
||||
build_readable_messages_with_id,
|
||||
get_raw_msg_before_timestamp_with_chat,
|
||||
)
|
||||
from src.chat.utils.chat_history_summarizer import ChatHistorySummarizer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
@@ -105,10 +101,16 @@ class HeartFChatting:
|
||||
|
||||
self.is_mute = False
|
||||
|
||||
self.last_active_time = time.time() # 记录上一次非noreply时间
|
||||
self.last_active_time = time.time() # 记录上一次非noreply时间
|
||||
|
||||
self.question_probability_multiplier = 1
|
||||
self.questioned = False
|
||||
|
||||
|
||||
# 跟踪连续 no_reply 次数,用于动态调整阈值
|
||||
self.consecutive_no_reply_count = 0
|
||||
|
||||
# 聊天内容概括器
|
||||
self.chat_history_summarizer = ChatHistorySummarizer(chat_id=self.stream_id)
|
||||
|
||||
async def start(self):
|
||||
"""检查是否需要启动主循环,如果未激活则启动。"""
|
||||
@@ -124,6 +126,10 @@ class HeartFChatting:
|
||||
|
||||
self._loop_task = asyncio.create_task(self._main_chat_loop())
|
||||
self._loop_task.add_done_callback(self._handle_loop_completion)
|
||||
|
||||
# 启动聊天内容概括器的后台定期检查循环
|
||||
await self.chat_history_summarizer.start()
|
||||
|
||||
logger.info(f"{self.log_prefix} HeartFChatting 启动完成")
|
||||
|
||||
except Exception as e:
|
||||
@@ -173,7 +179,7 @@ class HeartFChatting:
|
||||
+ (f"详情: {'; '.join(timer_strings)}" if timer_strings else "")
|
||||
)
|
||||
|
||||
async def _loopbody(self):
|
||||
async def _loopbody(self):
|
||||
recent_messages_list = message_api.get_messages_by_time_in_chat(
|
||||
chat_id=self.stream_id,
|
||||
start_time=self.last_read_time,
|
||||
@@ -184,43 +190,20 @@ class HeartFChatting:
|
||||
filter_command=True,
|
||||
)
|
||||
|
||||
question_probability = 0
|
||||
if time.time() - self.last_active_time > 7200:
|
||||
question_probability = 0.0003
|
||||
elif time.time() - self.last_active_time > 3600:
|
||||
question_probability = 0.0001
|
||||
# 根据连续 no_reply 次数动态调整阈值
|
||||
# 3次 no_reply 时,阈值调高到 1.5(50%概率为1,50%概率为2)
|
||||
# 5次 no_reply 时,提高到 2(大于等于两条消息的阈值)
|
||||
if self.consecutive_no_reply_count >= 5:
|
||||
threshold = 2
|
||||
elif self.consecutive_no_reply_count >= 3:
|
||||
# 1.5 的含义:50%概率为1,50%概率为2
|
||||
threshold = 2 if random.random() < 0.5 else 1
|
||||
else:
|
||||
question_probability = 0.00003
|
||||
threshold = 1
|
||||
|
||||
question_probability = question_probability * global_config.chat.get_auto_chat_value(self.stream_id)
|
||||
|
||||
# print(f"{self.log_prefix} questioned: {self.questioned},len: {len(global_conflict_tracker.get_questions_by_chat_id(self.stream_id))}")
|
||||
if question_probability > 0 and not self.questioned and len(global_conflict_tracker.get_questions_by_chat_id(self.stream_id)) == 0: #长久没有回复,可以试试主动发言,提问概率随着时间增加
|
||||
# logger.info(f"{self.log_prefix} 长久没有回复,可以试试主动发言,概率: {question_probability}")
|
||||
if random.random() < question_probability: # 30%概率主动发言
|
||||
try:
|
||||
self.questioned = True
|
||||
self.last_active_time = time.time()
|
||||
# print(f"{self.log_prefix} 长久没有回复,可以试试主动发言,开始生成问题")
|
||||
logger.info(f"{self.log_prefix} 长久没有回复,可以试试主动发言,开始生成问题")
|
||||
cycle_timers, thinking_id = self.start_cycle()
|
||||
question_maker = QuestionMaker(self.stream_id)
|
||||
question, context,conflict_context = await question_maker.make_question()
|
||||
if question:
|
||||
logger.info(f"{self.log_prefix} 问题: {question}")
|
||||
await global_conflict_tracker.track_conflict(question, conflict_context, True, self.stream_id)
|
||||
await self._lift_question_reply(question,context,thinking_id)
|
||||
else:
|
||||
logger.info(f"{self.log_prefix} 无问题")
|
||||
# self.end_cycle(cycle_timers, thinking_id)
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 主动提问失败: {e}")
|
||||
print(traceback.format_exc())
|
||||
|
||||
|
||||
if len(recent_messages_list) >= 1:
|
||||
if len(recent_messages_list) >= threshold:
|
||||
# for message in recent_messages_list:
|
||||
# print(message.processed_plain_text)
|
||||
# print(message.processed_plain_text)
|
||||
# !处理no_reply_until_call逻辑
|
||||
if self.no_reply_until_call:
|
||||
for message in recent_messages_list:
|
||||
@@ -317,6 +300,91 @@ class HeartFChatting:
|
||||
|
||||
return loop_info, reply_text, cycle_timers
|
||||
|
||||
async def _run_planner_without_reply(
|
||||
self,
|
||||
available_actions: Dict[str, ActionInfo],
|
||||
cycle_timers: Dict[str, float],
|
||||
) -> List[ActionPlannerInfo]:
|
||||
"""执行planner,但不包含reply动作(用于并行执行场景,提及时使用简化版提示词)"""
|
||||
try:
|
||||
with Timer("规划器", cycle_timers):
|
||||
action_to_use_info = await self.action_planner.plan(
|
||||
loop_start_time=self.last_read_time,
|
||||
available_actions=available_actions,
|
||||
is_mentioned=True, # 标记为提及时,使用简化版提示词
|
||||
)
|
||||
# 过滤掉reply动作(虽然提及时不应该有reply,但为了安全还是过滤一下)
|
||||
return [action for action in action_to_use_info if action.action_type != "reply"]
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} Planner执行失败: {e}")
|
||||
traceback.print_exc()
|
||||
return []
|
||||
|
||||
async def _generate_mentioned_reply(
|
||||
self,
|
||||
force_reply_message: "DatabaseMessages",
|
||||
thinking_id: str,
|
||||
cycle_timers: Dict[str, float],
|
||||
available_actions: Dict[str, ActionInfo],
|
||||
) -> Dict[str, Any]:
|
||||
"""当被提及时,独立生成回复的任务"""
|
||||
try:
|
||||
self.questioned = False
|
||||
# 重置连续 no_reply 计数
|
||||
self.consecutive_no_reply_count = 0
|
||||
reason = ""
|
||||
|
||||
|
||||
await database_api.store_action_info(
|
||||
chat_stream=self.chat_stream,
|
||||
action_build_into_prompt=False,
|
||||
action_prompt_display=reason,
|
||||
action_done=True,
|
||||
thinking_id=thinking_id,
|
||||
action_data={},
|
||||
action_name="reply",
|
||||
action_reasoning=reason,
|
||||
)
|
||||
|
||||
with Timer("提及回复生成", cycle_timers):
|
||||
success, llm_response = await generator_api.generate_reply(
|
||||
chat_stream=self.chat_stream,
|
||||
reply_message=force_reply_message,
|
||||
available_actions=available_actions,
|
||||
chosen_actions=[], # 独立回复,不依赖planner的动作
|
||||
reply_reason=reason,
|
||||
enable_tool=global_config.tool.enable_tool,
|
||||
request_type="replyer",
|
||||
from_plugin=False,
|
||||
reply_time_point=self.last_read_time,
|
||||
)
|
||||
|
||||
if not success or not llm_response or not llm_response.reply_set:
|
||||
logger.warning(f"{self.log_prefix} 提及回复生成失败")
|
||||
return {"action_type": "reply", "success": False, "result": "提及回复生成失败", "loop_info": None}
|
||||
|
||||
response_set = llm_response.reply_set
|
||||
selected_expressions = llm_response.selected_expressions
|
||||
loop_info, reply_text, _ = await self._send_and_store_reply(
|
||||
response_set=response_set,
|
||||
action_message=force_reply_message,
|
||||
cycle_timers=cycle_timers,
|
||||
thinking_id=thinking_id,
|
||||
actions=[], # 独立回复,不依赖planner的动作
|
||||
selected_expressions=selected_expressions,
|
||||
)
|
||||
self.last_active_time = time.time()
|
||||
return {
|
||||
"action_type": "reply",
|
||||
"success": True,
|
||||
"result": f"你回复内容{reply_text}",
|
||||
"loop_info": loop_info,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 提及回复生成异常: {e}")
|
||||
traceback.print_exc()
|
||||
return {"action_type": "reply", "success": False, "result": f"提及回复生成异常: {e}", "loop_info": None}
|
||||
|
||||
async def _observe(
|
||||
self, # interest_value: float = 0.0,
|
||||
recent_messages_list: Optional[List["DatabaseMessages"]] = None,
|
||||
@@ -324,20 +392,24 @@ class HeartFChatting:
|
||||
) -> bool: # sourcery skip: merge-else-if-into-elif, remove-redundant-if
|
||||
if recent_messages_list is None:
|
||||
recent_messages_list = []
|
||||
reply_text = "" # 初始化reply_text变量,避免UnboundLocalError
|
||||
_reply_text = "" # 初始化reply_text变量,避免UnboundLocalError
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
|
||||
async with global_prompt_manager.async_message_scope(self.chat_stream.context.get_template_name()):
|
||||
asyncio.create_task(self.expression_learner.trigger_learning_for_chat())
|
||||
asyncio.create_task(global_memory_chest.build_running_content(chat_id=self.stream_id))
|
||||
asyncio.create_task(frequency_control_manager.get_or_create_frequency_control(self.stream_id).trigger_frequency_adjust())
|
||||
|
||||
asyncio.create_task(
|
||||
frequency_control_manager.get_or_create_frequency_control(self.stream_id).trigger_frequency_adjust()
|
||||
)
|
||||
|
||||
# 添加curious检测任务 - 检测聊天记录中的矛盾、冲突或需要提问的内容
|
||||
asyncio.create_task(check_and_make_question(self.stream_id, recent_messages_list))
|
||||
|
||||
|
||||
# asyncio.create_task(check_and_make_question(self.stream_id))
|
||||
# 添加jargon提取任务 - 提取聊天中的黑话/俚语并入库(内部自行取消息并带冷却)
|
||||
asyncio.create_task(extract_and_store_jargon(self.stream_id))
|
||||
# 添加聊天内容概括任务 - 累积、打包和压缩聊天记录
|
||||
# 注意:后台循环已在start()中启动,这里作为额外触发点,在有思考时立即处理
|
||||
# asyncio.create_task(self.chat_history_summarizer.process())
|
||||
|
||||
cycle_timers, thinking_id = self.start_cycle()
|
||||
logger.info(f"{self.log_prefix} 开始第{self._cycle_counter}次思考")
|
||||
|
||||
@@ -349,66 +421,94 @@ class HeartFChatting:
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 动作修改失败: {e}")
|
||||
|
||||
# 执行planner
|
||||
is_group_chat, chat_target_info, _ = self.action_planner.get_necessary_info()
|
||||
# 如果被提及,让回复生成和planner并行执行
|
||||
if force_reply_message:
|
||||
logger.info(f"{self.log_prefix} 检测到提及,回复生成与planner并行执行")
|
||||
|
||||
message_list_before_now = get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=self.stream_id,
|
||||
timestamp=time.time(),
|
||||
limit=int(global_config.chat.max_context_size * 0.6),
|
||||
)
|
||||
chat_content_block, message_id_list = build_readable_messages_with_id(
|
||||
messages=message_list_before_now,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
read_mark=self.action_planner.last_obs_time_mark,
|
||||
truncate=True,
|
||||
show_actions=True,
|
||||
)
|
||||
|
||||
prompt_info = await self.action_planner.build_planner_prompt(
|
||||
is_group_chat=is_group_chat,
|
||||
chat_target_info=chat_target_info,
|
||||
current_available_actions=available_actions,
|
||||
chat_content_block=chat_content_block,
|
||||
message_id_list=message_id_list,
|
||||
interest=global_config.personality.interest,
|
||||
)
|
||||
continue_flag, modified_message = await events_manager.handle_mai_events(
|
||||
EventType.ON_PLAN, None, prompt_info[0], None, self.chat_stream.stream_id
|
||||
)
|
||||
if not continue_flag:
|
||||
return False
|
||||
if modified_message and modified_message._modify_flags.modify_llm_prompt:
|
||||
prompt_info = (modified_message.llm_prompt, prompt_info[1])
|
||||
|
||||
with Timer("规划器", cycle_timers):
|
||||
action_to_use_info = await self.action_planner.plan(
|
||||
loop_start_time=self.last_read_time,
|
||||
available_actions=available_actions,
|
||||
# 并行执行planner和回复生成
|
||||
planner_task = asyncio.create_task(
|
||||
self._run_planner_without_reply(
|
||||
available_actions=available_actions,
|
||||
cycle_timers=cycle_timers,
|
||||
)
|
||||
)
|
||||
|
||||
has_reply = False
|
||||
for action in action_to_use_info:
|
||||
if action.action_type == "reply":
|
||||
has_reply = True
|
||||
break
|
||||
|
||||
if not has_reply and force_reply_message:
|
||||
action_to_use_info.append(
|
||||
ActionPlannerInfo(
|
||||
action_type="reply",
|
||||
reasoning="有人提到了你,进行回复",
|
||||
action_data={},
|
||||
action_message=force_reply_message,
|
||||
reply_task = asyncio.create_task(
|
||||
self._generate_mentioned_reply(
|
||||
force_reply_message=force_reply_message,
|
||||
thinking_id=thinking_id,
|
||||
cycle_timers=cycle_timers,
|
||||
available_actions=available_actions,
|
||||
)
|
||||
)
|
||||
|
||||
# 等待两个任务完成
|
||||
planner_result, reply_result = await asyncio.gather(planner_task, reply_task, return_exceptions=True)
|
||||
|
||||
# 处理planner结果
|
||||
if isinstance(planner_result, BaseException):
|
||||
logger.error(f"{self.log_prefix} Planner执行异常: {planner_result}")
|
||||
action_to_use_info = []
|
||||
else:
|
||||
action_to_use_info = planner_result
|
||||
|
||||
# 处理回复结果
|
||||
if isinstance(reply_result, BaseException):
|
||||
logger.error(f"{self.log_prefix} 回复生成异常: {reply_result}")
|
||||
reply_result = {
|
||||
"action_type": "reply",
|
||||
"success": False,
|
||||
"result": "回复生成异常",
|
||||
"loop_info": None,
|
||||
}
|
||||
else:
|
||||
# 正常流程:只执行planner
|
||||
is_group_chat, chat_target_info, _ = self.action_planner.get_necessary_info()
|
||||
|
||||
message_list_before_now = get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=self.stream_id,
|
||||
timestamp=time.time(),
|
||||
limit=int(global_config.chat.max_context_size * 0.6),
|
||||
)
|
||||
chat_content_block, message_id_list = build_readable_messages_with_id(
|
||||
messages=message_list_before_now,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
read_mark=self.action_planner.last_obs_time_mark,
|
||||
truncate=True,
|
||||
show_actions=True,
|
||||
)
|
||||
|
||||
prompt_info = await self.action_planner.build_planner_prompt(
|
||||
is_group_chat=is_group_chat,
|
||||
chat_target_info=chat_target_info,
|
||||
current_available_actions=available_actions,
|
||||
chat_content_block=chat_content_block,
|
||||
message_id_list=message_id_list,
|
||||
interest=global_config.personality.interest,
|
||||
)
|
||||
continue_flag, modified_message = await events_manager.handle_mai_events(
|
||||
EventType.ON_PLAN, None, prompt_info[0], None, self.chat_stream.stream_id
|
||||
)
|
||||
if not continue_flag:
|
||||
return False
|
||||
if modified_message and modified_message._modify_flags.modify_llm_prompt:
|
||||
prompt_info = (modified_message.llm_prompt, prompt_info[1])
|
||||
|
||||
with Timer("规划器", cycle_timers):
|
||||
action_to_use_info = await self.action_planner.plan(
|
||||
loop_start_time=self.last_read_time,
|
||||
available_actions=available_actions,
|
||||
)
|
||||
reply_result = None
|
||||
|
||||
# 只在提及情况下过滤掉planner返回的reply动作(提及时已有独立回复生成)
|
||||
if force_reply_message:
|
||||
action_to_use_info = [action for action in action_to_use_info if action.action_type != "reply"]
|
||||
|
||||
logger.info(
|
||||
f"{self.log_prefix} 决定执行{len(action_to_use_info)}个动作: {' '.join([a.action_type for a in action_to_use_info])}"
|
||||
)
|
||||
|
||||
# 3. 并行执行所有动作
|
||||
# 3. 并行执行所有动作(不包括reply,reply已经独立执行)
|
||||
action_tasks = [
|
||||
asyncio.create_task(
|
||||
self._execute_action(action, action_to_use_info, thinking_id, available_actions, cycle_timers)
|
||||
@@ -419,6 +519,10 @@ class HeartFChatting:
|
||||
# 并行执行所有任务
|
||||
results = await asyncio.gather(*action_tasks, return_exceptions=True)
|
||||
|
||||
# 如果有独立的回复结果,添加到结果列表中
|
||||
if reply_result:
|
||||
results = list(results) + [reply_result]
|
||||
|
||||
# 处理执行结果
|
||||
reply_loop_info = None
|
||||
reply_text_from_reply = ""
|
||||
@@ -456,7 +560,7 @@ class HeartFChatting:
|
||||
"taken_time": time.time(),
|
||||
}
|
||||
)
|
||||
reply_text = reply_text_from_reply
|
||||
_reply_text = reply_text_from_reply
|
||||
else:
|
||||
# 没有回复信息,构建纯动作的loop_info
|
||||
loop_info = {
|
||||
@@ -469,7 +573,7 @@ class HeartFChatting:
|
||||
"taken_time": time.time(),
|
||||
},
|
||||
}
|
||||
reply_text = action_reply_text
|
||||
_reply_text = action_reply_text
|
||||
|
||||
self.end_cycle(loop_info, cycle_timers)
|
||||
self.print_cycle_info(cycle_timers)
|
||||
@@ -545,7 +649,6 @@ class HeartFChatting:
|
||||
result = await action_handler.execute()
|
||||
success, action_text = result
|
||||
|
||||
|
||||
return success, action_text
|
||||
|
||||
except Exception as e:
|
||||
@@ -553,78 +656,6 @@ class HeartFChatting:
|
||||
traceback.print_exc()
|
||||
return False, ""
|
||||
|
||||
async def _lift_question_reply(self, question: str, question_context: str, thinking_id: str):
|
||||
reason = f"在聊天中:\n{question_context}\n你对问题\"{question}\"感到好奇,想要和群友讨论"
|
||||
new_msg = get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=self.stream_id,
|
||||
timestamp=time.time(),
|
||||
limit=1,
|
||||
)
|
||||
|
||||
reply_action_info = ActionPlannerInfo(
|
||||
action_type="reply",
|
||||
reasoning= "",
|
||||
action_data={},
|
||||
action_message=new_msg[0],
|
||||
available_actions=None,
|
||||
loop_start_time=time.time(),
|
||||
action_reasoning=reason)
|
||||
self.action_planner.add_plan_log(reasoning=f"你对问题\"{question}\"感到好奇,想要和群友讨论", actions=[reply_action_info])
|
||||
|
||||
success, llm_response = await generator_api.rewrite_reply(
|
||||
chat_stream=self.chat_stream,
|
||||
reply_data={
|
||||
"raw_reply": f"我对这个问题感到好奇:{question}",
|
||||
"reason": reason,
|
||||
},
|
||||
)
|
||||
|
||||
if not success or not llm_response or not llm_response.reply_set:
|
||||
logger.info("主动提问发言失败")
|
||||
self.action_planner.add_plan_excute_log(result="主动回复生成失败")
|
||||
return {"action_type": "reply", "success": False, "result": "主动回复生成失败", "loop_info": None}
|
||||
|
||||
if success:
|
||||
for reply_seg in llm_response.reply_set.reply_data:
|
||||
send_data = reply_seg.content
|
||||
await send_api.text_to_stream(
|
||||
text=send_data,
|
||||
stream_id=self.stream_id,
|
||||
)
|
||||
|
||||
await database_api.store_action_info(
|
||||
chat_stream=self.chat_stream,
|
||||
action_build_into_prompt=False,
|
||||
action_prompt_display=reason,
|
||||
action_done=True,
|
||||
thinking_id=thinking_id,
|
||||
action_data={"reply_text": llm_response.reply_set.reply_data[0].content},
|
||||
action_name="reply",
|
||||
)
|
||||
|
||||
# 构建循环信息
|
||||
loop_info: Dict[str, Any] = {
|
||||
"loop_plan_info": {
|
||||
"action_result": [reply_action_info],
|
||||
},
|
||||
"loop_action_info": {
|
||||
"action_taken": True,
|
||||
"reply_text": llm_response.reply_set.reply_data[0].content,
|
||||
"command": "",
|
||||
"taken_time": time.time(),
|
||||
},
|
||||
}
|
||||
self.last_active_time = time.time()
|
||||
self.action_planner.add_plan_excute_log(result=f"你提问:{question}")
|
||||
|
||||
return {
|
||||
"action_type": "reply",
|
||||
"success": True,
|
||||
"result": f"你提问:{question}",
|
||||
"loop_info": loop_info,
|
||||
}
|
||||
|
||||
|
||||
async def _send_response(
|
||||
self,
|
||||
reply_set: "ReplySetModel",
|
||||
@@ -686,6 +717,9 @@ class HeartFChatting:
|
||||
reason = action_planner_info.reasoning or "选择不回复"
|
||||
# logger.info(f"{self.log_prefix} 选择不回复,原因: {reason}")
|
||||
|
||||
# 增加连续 no_reply 计数
|
||||
self.consecutive_no_reply_count += 1
|
||||
|
||||
await database_api.store_action_info(
|
||||
chat_stream=self.chat_stream,
|
||||
action_build_into_prompt=False,
|
||||
@@ -697,7 +731,6 @@ class HeartFChatting:
|
||||
action_reasoning=reason,
|
||||
)
|
||||
|
||||
|
||||
return {"action_type": "no_reply", "success": True, "result": "选择不回复", "command": ""}
|
||||
|
||||
elif action_planner_info.action_type == "no_reply_until_call":
|
||||
@@ -705,6 +738,8 @@ class HeartFChatting:
|
||||
logger.info(f"{self.log_prefix} 保持沉默,直到有人直接叫的名字")
|
||||
reason = action_planner_info.reasoning or "选择不回复"
|
||||
|
||||
# 增加连续 no_reply 计数
|
||||
self.consecutive_no_reply_count += 1
|
||||
self.no_reply_until_call = True
|
||||
await database_api.store_action_info(
|
||||
chat_stream=self.chat_stream,
|
||||
@@ -716,14 +751,23 @@ class HeartFChatting:
|
||||
action_name="no_reply_until_call",
|
||||
action_reasoning=reason,
|
||||
)
|
||||
return {"action_type": "no_reply_until_call", "success": True, "result": "保持沉默,直到有人直接叫的名字", "command": ""}
|
||||
return {
|
||||
"action_type": "no_reply_until_call",
|
||||
"success": True,
|
||||
"result": "保持沉默,直到有人直接叫的名字",
|
||||
"command": "",
|
||||
}
|
||||
|
||||
elif action_planner_info.action_type == "reply":
|
||||
# 直接当场执行reply逻辑
|
||||
self.questioned = False
|
||||
# 刷新主动发言状态
|
||||
# 重置连续 no_reply 计数
|
||||
self.consecutive_no_reply_count = 0
|
||||
|
||||
reason = action_planner_info.reasoning or "选择回复"
|
||||
reason = action_planner_info.reasoning or ""
|
||||
# 使用 action_reasoning(planner 的整体思考理由)作为 reply_reason
|
||||
planner_reasoning = action_planner_info.action_reasoning or reason
|
||||
await database_api.store_action_info(
|
||||
chat_stream=self.chat_stream,
|
||||
action_build_into_prompt=False,
|
||||
@@ -740,23 +784,20 @@ class HeartFChatting:
|
||||
reply_message=action_planner_info.action_message,
|
||||
available_actions=available_actions,
|
||||
chosen_actions=chosen_action_plan_infos,
|
||||
reply_reason=reason,
|
||||
reply_reason=planner_reasoning,
|
||||
enable_tool=global_config.tool.enable_tool,
|
||||
request_type="replyer",
|
||||
from_plugin=False,
|
||||
reply_time_point = action_planner_info.action_data.get("loop_start_time", time.time()),
|
||||
reply_time_point=action_planner_info.action_data.get("loop_start_time", time.time()),
|
||||
)
|
||||
|
||||
if not success or not llm_response or not llm_response.reply_set:
|
||||
if action_planner_info.action_message:
|
||||
logger.info(
|
||||
f"对 {action_planner_info.action_message.processed_plain_text} 的回复生成失败"
|
||||
)
|
||||
logger.info(f"对 {action_planner_info.action_message.processed_plain_text} 的回复生成失败")
|
||||
else:
|
||||
logger.info("回复生成失败")
|
||||
return {"action_type": "reply", "success": False, "result": "回复生成失败", "loop_info": None}
|
||||
|
||||
|
||||
response_set = llm_response.reply_set
|
||||
selected_expressions = llm_response.selected_expressions
|
||||
loop_info, reply_text, _ = await self._send_and_store_reply(
|
||||
@@ -778,12 +819,12 @@ class HeartFChatting:
|
||||
# 执行普通动作
|
||||
with Timer("动作执行", cycle_timers):
|
||||
success, result = await self._handle_action(
|
||||
action = action_planner_info.action_type,
|
||||
action_reasoning = action_planner_info.action_reasoning or "",
|
||||
action_data = action_planner_info.action_data or {},
|
||||
cycle_timers = cycle_timers,
|
||||
thinking_id = thinking_id,
|
||||
action_message= action_planner_info.action_message,
|
||||
action=action_planner_info.action_type,
|
||||
action_reasoning=action_planner_info.action_reasoning or "",
|
||||
action_data=action_planner_info.action_data or {},
|
||||
cycle_timers=cycle_timers,
|
||||
thinking_id=thinking_id,
|
||||
action_message=action_planner_info.action_message,
|
||||
)
|
||||
|
||||
self.last_active_time = time.time()
|
||||
|
||||
@@ -13,10 +13,11 @@ from src.person_info.person_info import Person
|
||||
from src.common.database.database_model import Images
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.chat.heart_flow.heartFC_chat import HeartFChatting
|
||||
pass
|
||||
|
||||
logger = get_logger("chat")
|
||||
|
||||
|
||||
class HeartFCMessageReceiver:
|
||||
"""心流处理器,负责处理接收到的消息并计算兴趣度"""
|
||||
|
||||
@@ -83,10 +84,19 @@ class HeartFCMessageReceiver:
|
||||
|
||||
logger.info(f"[{mes_name}]{userinfo.user_nickname}:{processed_plain_text}") # type: ignore
|
||||
|
||||
# 如果是群聊,获取群号和群昵称
|
||||
group_id = None
|
||||
group_nick_name = None
|
||||
if chat.group_info:
|
||||
group_id = chat.group_info.group_id # type: ignore
|
||||
group_nick_name = userinfo.user_cardname # type: ignore
|
||||
|
||||
_ = Person.register_person(
|
||||
platform=message.message_info.platform, # type: ignore
|
||||
user_id=message.message_info.user_info.user_id, # type: ignore
|
||||
nickname=userinfo.user_nickname, # type: ignore
|
||||
group_id=group_id,
|
||||
group_nick_name=group_nick_name,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -30,6 +30,8 @@ DATA_PATH = os.path.join(ROOT_PATH, "data")
|
||||
qa_manager = None
|
||||
inspire_manager = None
|
||||
|
||||
def get_qa_manager():
|
||||
return qa_manager
|
||||
|
||||
def lpmm_start_up(): # sourcery skip: extract-duplicate-method
|
||||
# 检查LPMM知识库是否启用
|
||||
|
||||
@@ -15,7 +15,6 @@ from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiv
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.plugin_system.core import component_registry, events_manager, global_announcement_manager
|
||||
from src.plugin_system.base import BaseCommand, EventType
|
||||
from src.person_info.person_info import Person
|
||||
|
||||
# 定义日志配置
|
||||
|
||||
@@ -171,7 +170,11 @@ class ChatBot:
|
||||
|
||||
# 撤回事件打印;无法获取被撤回者则省略
|
||||
if sub_type == "recall":
|
||||
op_name = getattr(op, "user_cardname", None) or getattr(op, "user_nickname", None) or str(getattr(op, "user_id", None))
|
||||
op_name = (
|
||||
getattr(op, "user_cardname", None)
|
||||
or getattr(op, "user_nickname", None)
|
||||
or str(getattr(op, "user_id", None))
|
||||
)
|
||||
recalled_name = None
|
||||
try:
|
||||
if isinstance(recalled, dict):
|
||||
@@ -189,7 +192,7 @@ class ChatBot:
|
||||
logger.info(f"{op_name} 撤回了消息")
|
||||
else:
|
||||
logger.debug(
|
||||
f"[notice] sub_type={sub_type} scene={scene} op={getattr(op,'user_nickname',None)}({getattr(op,'user_id',None)}) "
|
||||
f"[notice] sub_type={sub_type} scene={scene} op={getattr(op, 'user_nickname', None)}({getattr(op, 'user_id', None)}) "
|
||||
f"gid={gid} msg_id={msg_id} recalled={recalled_id}"
|
||||
)
|
||||
except Exception:
|
||||
@@ -234,7 +237,6 @@ class ChatBot:
|
||||
# 确保所有任务已启动
|
||||
await self._ensure_started()
|
||||
|
||||
|
||||
if message_data["message_info"].get("group_info") is not None:
|
||||
message_data["message_info"]["group_info"]["group_id"] = str(
|
||||
message_data["message_info"]["group_info"]["group_id"]
|
||||
@@ -258,7 +260,7 @@ class ChatBot:
|
||||
message.message_segment = Seg(type="seglist", data=modified_message.message_segments)
|
||||
|
||||
if await self.handle_notice_message(message):
|
||||
return
|
||||
pass
|
||||
|
||||
# 处理消息内容,生成纯文本
|
||||
await message.process()
|
||||
|
||||
@@ -49,28 +49,52 @@ reply
|
||||
2.你可以自然的顺着正在进行的聊天内容进行回复或自然的提出一个问题
|
||||
3.不要回复你自己发送的消息
|
||||
4.不要单独对表情包进行回复
|
||||
{{
|
||||
"action": "reply",
|
||||
"target_message_id":"想要回复的消息id",
|
||||
"reason":"回复的原因"
|
||||
}}
|
||||
{{"action":"reply", "target_message_id":"消息id(m+数字)", "reason":"原因"}}
|
||||
|
||||
no_reply
|
||||
动作描述:
|
||||
保持沉默,不回复直到有新消息
|
||||
控制聊天频率,不要太过频繁的发言
|
||||
{{
|
||||
"action": "no_reply",
|
||||
}}
|
||||
{{"action":"no_reply"}}
|
||||
|
||||
no_reply_until_call
|
||||
{no_reply_until_call_block}
|
||||
|
||||
{action_options_text}
|
||||
|
||||
|
||||
**你之前的action执行和思考记录**
|
||||
{actions_before_now_block}
|
||||
|
||||
请选择**可选的**且符合使用条件的action,并说明触发action的消息id(消息id格式:m+数字)
|
||||
不要回复你自己发送的消息
|
||||
先输出你的简短的选择思考理由,再输出你选择的action,理由不要分点,精简。
|
||||
**动作选择要求**
|
||||
请你根据聊天内容,用户的最新消息和以下标准选择合适的动作:
|
||||
{plan_style}
|
||||
{moderation_prompt}
|
||||
|
||||
请选择所有符合使用要求的action,动作用json格式输出,用```json包裹,如果输出多个json,每个json都要单独一行放在同一个```json代码块内,你可以重复使用同一个动作或不同动作:
|
||||
**示例**
|
||||
// 理由文本(简短)
|
||||
```json
|
||||
{{"action":"动作名", "target_message_id":"m123", "reason":"原因"}}
|
||||
{{"action":"动作名", "target_message_id":"m456", "reason":"原因"}}
|
||||
```""",
|
||||
"planner_prompt",
|
||||
)
|
||||
|
||||
Prompt(
|
||||
"""{time_block}
|
||||
{name_block}
|
||||
{chat_context_description},以下是具体的聊天内容
|
||||
**聊天内容**
|
||||
{chat_content_block}
|
||||
|
||||
**可选的action**
|
||||
no_reply
|
||||
动作描述:
|
||||
保持沉默,直到有人直接叫你的名字
|
||||
当前话题不感兴趣时使用,或有人不喜欢你的发言时使用
|
||||
当你频繁选择no_reply时使用,表示话题暂时与你无关
|
||||
{{
|
||||
"action": "no_reply_until_call",
|
||||
}}
|
||||
没有合适的可以使用的动作,不使用action
|
||||
{{"action":"no_reply"}}
|
||||
|
||||
{action_options_text}
|
||||
|
||||
@@ -78,31 +102,21 @@ no_reply_until_call
|
||||
{actions_before_now_block}
|
||||
|
||||
请选择**可选的**且符合使用条件的action,并说明触发action的消息id(消息id格式:m+数字)
|
||||
不要回复你自己发送的消息
|
||||
先输出你的选择思考理由,再输出你选择的action,理由是一段平文本,不要分点,精简。
|
||||
先输出你的简短的选择思考理由,再输出你选择的action,理由不要分点,精简。
|
||||
**动作选择要求**
|
||||
请你根据聊天内容,用户的最新消息和以下标准选择合适的动作:
|
||||
{plan_style}
|
||||
1.思考**所有**的可用的action中的**每个动作**是否符合当下条件,如果动作使用条件符合聊天内容就使用
|
||||
2.如果相同的内容已经被执行,请不要重复执行
|
||||
{moderation_prompt}
|
||||
|
||||
请选择所有符合使用要求的action,动作用json格式输出,如果输出多个json,每个json都要单独用```json包裹,你可以重复使用同一个动作或不同动作:
|
||||
请选择所有符合使用要求的action,动作用json格式输出,用```json包裹,如果输出多个json,每个json都要单独一行放在同一个```json代码块内,你可以重复使用同一个动作或不同动作:
|
||||
**示例**
|
||||
// 理由文本
|
||||
// 理由文本(简短)
|
||||
```json
|
||||
{{
|
||||
"action":"动作名",
|
||||
"target_message_id":"触发动作的消息id",
|
||||
//对应参数
|
||||
}}
|
||||
```
|
||||
```json
|
||||
{{
|
||||
"action":"动作名",
|
||||
"target_message_id":"触发动作的消息id",
|
||||
//对应参数
|
||||
}}
|
||||
{{"action":"动作名", "target_message_id":"m123", "reason":"原因"}}
|
||||
{{"action":"动作名", "target_message_id":"m456", "reason":"原因"}}
|
||||
```""",
|
||||
"planner_prompt",
|
||||
"planner_prompt_mentioned",
|
||||
)
|
||||
|
||||
Prompt(
|
||||
@@ -111,11 +125,7 @@ no_reply_until_call
|
||||
动作描述:{action_description}
|
||||
使用条件{parallel_text}:
|
||||
{action_require}
|
||||
{{
|
||||
"action": "{action_name}",{action_parameters},
|
||||
"target_message_id":"触发action的消息id",
|
||||
"reason":"触发action的原因"
|
||||
}}
|
||||
{{"action":"{action_name}",{action_parameters}, "target_message_id":"消息id(m+数字)", "reason":"原因"}}
|
||||
""",
|
||||
"action_prompt",
|
||||
)
|
||||
@@ -133,7 +143,6 @@ class ActionPlanner:
|
||||
|
||||
self.last_obs_time_mark = 0.0
|
||||
|
||||
|
||||
self.plan_log: List[Tuple[str, float, Union[List[ActionPlannerInfo], str]]] = []
|
||||
|
||||
def find_message_by_id(
|
||||
@@ -231,6 +240,7 @@ class ActionPlanner:
|
||||
self,
|
||||
available_actions: Dict[str, ActionInfo],
|
||||
loop_start_time: float = 0.0,
|
||||
is_mentioned: bool = False,
|
||||
) -> List[ActionPlannerInfo]:
|
||||
# sourcery skip: use-named-expression
|
||||
"""
|
||||
@@ -270,6 +280,11 @@ class ActionPlanner:
|
||||
|
||||
logger.debug(f"{self.log_prefix}过滤后有{len(filtered_actions)}个可用动作")
|
||||
|
||||
# 如果是提及时且没有可用动作,直接返回空列表,不调用LLM以节省token
|
||||
if is_mentioned and not filtered_actions:
|
||||
logger.info(f"{self.log_prefix}提及时没有可用动作,跳过plan调用")
|
||||
return []
|
||||
|
||||
# 构建包含所有动作的提示词
|
||||
prompt, message_id_list = await self.build_planner_prompt(
|
||||
is_group_chat=is_group_chat,
|
||||
@@ -278,6 +293,7 @@ class ActionPlanner:
|
||||
chat_content_block=chat_content_block,
|
||||
message_id_list=message_id_list,
|
||||
interest=global_config.personality.interest,
|
||||
is_mentioned=is_mentioned,
|
||||
)
|
||||
|
||||
# 调用LLM获取决策
|
||||
@@ -289,7 +305,9 @@ class ActionPlanner:
|
||||
loop_start_time=loop_start_time,
|
||||
)
|
||||
|
||||
logger.info(f"{self.log_prefix}Planner:{reasoning}。选择了{len(actions)}个动作: {' '.join([a.action_type for a in actions])}")
|
||||
logger.info(
|
||||
f"{self.log_prefix}Planner:{reasoning}。选择了{len(actions)}个动作: {' '.join([a.action_type for a in actions])}"
|
||||
)
|
||||
|
||||
self.add_plan_log(reasoning, actions)
|
||||
|
||||
@@ -299,24 +317,79 @@ class ActionPlanner:
|
||||
self.plan_log.append((reasoning, time.time(), actions))
|
||||
if len(self.plan_log) > 20:
|
||||
self.plan_log.pop(0)
|
||||
|
||||
|
||||
def add_plan_excute_log(self, result: str):
|
||||
self.plan_log.append(("", time.time(), result))
|
||||
if len(self.plan_log) > 20:
|
||||
self.plan_log.pop(0)
|
||||
|
||||
def get_plan_log_str(self) -> str:
|
||||
plan_log_str = ""
|
||||
for reasoning, time, content in self.plan_log:
|
||||
def get_plan_log_str(self, max_action_records: int = 2, max_execution_records: int = 5) -> str:
|
||||
"""
|
||||
获取计划日志字符串
|
||||
|
||||
Args:
|
||||
max_action_records: 显示多少条最新的action记录,默认2
|
||||
max_execution_records: 显示多少条最新执行结果记录,默认8
|
||||
|
||||
Returns:
|
||||
格式化的日志字符串
|
||||
"""
|
||||
action_records = []
|
||||
execution_records = []
|
||||
|
||||
# 从后往前遍历,收集最新的记录
|
||||
for reasoning, timestamp, content in reversed(self.plan_log):
|
||||
if isinstance(content, list) and all(isinstance(action, ActionPlannerInfo) for action in content):
|
||||
time = datetime.fromtimestamp(time).strftime("%H:%M:%S")
|
||||
plan_log_str += f"{time}:{reasoning}|你使用了{','.join([action.action_type for action in content])}\n"
|
||||
# 这是action记录
|
||||
if len(action_records) < max_action_records:
|
||||
action_records.append((reasoning, timestamp, content, "action"))
|
||||
else:
|
||||
time = datetime.fromtimestamp(time).strftime("%H:%M:%S")
|
||||
plan_log_str += f"{time}:{content}\n"
|
||||
|
||||
# 这是执行结果记录
|
||||
if len(execution_records) < max_execution_records:
|
||||
execution_records.append((reasoning, timestamp, content, "execution"))
|
||||
|
||||
# 合并所有记录并按时间戳排序
|
||||
all_records = action_records + execution_records
|
||||
all_records.sort(key=lambda x: x[1]) # 按时间戳排序
|
||||
|
||||
plan_log_str = ""
|
||||
|
||||
# 按时间顺序添加所有记录
|
||||
for reasoning, timestamp, content, record_type in all_records:
|
||||
time_str = datetime.fromtimestamp(timestamp).strftime("%H:%M:%S")
|
||||
if record_type == "action":
|
||||
# plan_log_str += f"{time_str}:{reasoning}|你使用了{','.join([action.action_type for action in content])}\n"
|
||||
plan_log_str += f"{time_str}:{reasoning}\n"
|
||||
else:
|
||||
plan_log_str += f"{time_str}:你执行了action:{content}\n"
|
||||
|
||||
return plan_log_str
|
||||
|
||||
def _has_consecutive_no_reply(self, min_count: int = 3) -> bool:
|
||||
"""
|
||||
检查是否有连续min_count次以上的no_reply
|
||||
|
||||
Args:
|
||||
min_count: 需要连续的最少次数,默认3
|
||||
|
||||
Returns:
|
||||
如果有连续min_count次以上no_reply返回True,否则返回False
|
||||
"""
|
||||
consecutive_count = 0
|
||||
|
||||
# 从后往前遍历plan_log,检查最新的连续记录
|
||||
for _reasoning, _timestamp, content in reversed(self.plan_log):
|
||||
if isinstance(content, list) and all(isinstance(action, ActionPlannerInfo) for action in content):
|
||||
# 检查所有action是否都是no_reply
|
||||
if all(action.action_type == "no_reply" for action in content):
|
||||
consecutive_count += 1
|
||||
if consecutive_count >= min_count:
|
||||
return True
|
||||
else:
|
||||
# 如果遇到非no_reply的action,重置计数
|
||||
break
|
||||
|
||||
return False
|
||||
|
||||
async def build_planner_prompt(
|
||||
self,
|
||||
@@ -326,11 +399,11 @@ class ActionPlanner:
|
||||
message_id_list: List[Tuple[str, "DatabaseMessages"]],
|
||||
chat_content_block: str = "",
|
||||
interest: str = "",
|
||||
is_mentioned: bool = False,
|
||||
) -> tuple[str, List[Tuple[str, "DatabaseMessages"]]]:
|
||||
"""构建 Planner LLM 的提示词 (获取模板并填充数据)"""
|
||||
try:
|
||||
|
||||
actions_before_now_block=self.get_plan_log_str()
|
||||
actions_before_now_block = self.get_plan_log_str()
|
||||
|
||||
# 构建聊天上下文描述
|
||||
chat_context_description = "你现在正在一个群聊中"
|
||||
@@ -347,19 +420,47 @@ class ActionPlanner:
|
||||
)
|
||||
name_block = f"你的名字是{bot_name}{bot_nickname},请注意哪些是你自己的发言。"
|
||||
|
||||
# 获取主规划器模板并填充
|
||||
planner_prompt_template = await global_prompt_manager.get_prompt_async("planner_prompt")
|
||||
prompt = planner_prompt_template.format(
|
||||
time_block=time_block,
|
||||
chat_context_description=chat_context_description,
|
||||
chat_content_block=chat_content_block,
|
||||
actions_before_now_block=actions_before_now_block,
|
||||
action_options_text=action_options_block,
|
||||
moderation_prompt=moderation_prompt_block,
|
||||
name_block=name_block,
|
||||
interest=interest,
|
||||
plan_style=global_config.personality.plan_style,
|
||||
)
|
||||
# 根据是否是提及时选择不同的模板
|
||||
if is_mentioned:
|
||||
# 提及时使用简化版提示词,不需要reply、no_reply、no_reply_until_call
|
||||
planner_prompt_template = await global_prompt_manager.get_prompt_async("planner_prompt_mentioned")
|
||||
prompt = planner_prompt_template.format(
|
||||
time_block=time_block,
|
||||
chat_context_description=chat_context_description,
|
||||
chat_content_block=chat_content_block,
|
||||
actions_before_now_block=actions_before_now_block,
|
||||
action_options_text=action_options_block,
|
||||
moderation_prompt=moderation_prompt_block,
|
||||
name_block=name_block,
|
||||
interest=interest,
|
||||
plan_style=global_config.personality.plan_style,
|
||||
)
|
||||
else:
|
||||
# 正常流程使用完整版提示词
|
||||
# 检查是否有连续3次以上no_reply,如果有则添加no_reply_until_call选项
|
||||
no_reply_until_call_block = ""
|
||||
if self._has_consecutive_no_reply(min_count=3):
|
||||
no_reply_until_call_block = """no_reply_until_call
|
||||
动作描述:
|
||||
保持沉默,直到有人直接叫你的名字
|
||||
当前话题不感兴趣时使用,或有人不喜欢你的发言时使用
|
||||
当你频繁选择no_reply时使用,表示话题暂时与你无关
|
||||
{{"action":"no_reply_until_call"}}
|
||||
"""
|
||||
|
||||
planner_prompt_template = await global_prompt_manager.get_prompt_async("planner_prompt")
|
||||
prompt = planner_prompt_template.format(
|
||||
time_block=time_block,
|
||||
chat_context_description=chat_context_description,
|
||||
chat_content_block=chat_content_block,
|
||||
actions_before_now_block=actions_before_now_block,
|
||||
action_options_text=action_options_block,
|
||||
no_reply_until_call_block=no_reply_until_call_block,
|
||||
moderation_prompt=moderation_prompt_block,
|
||||
name_block=name_block,
|
||||
interest=interest,
|
||||
plan_style=global_config.personality.plan_style,
|
||||
)
|
||||
|
||||
return prompt, message_id_list
|
||||
except Exception as e:
|
||||
@@ -436,7 +537,7 @@ class ActionPlanner:
|
||||
for require_item in action_info.action_require:
|
||||
require_text += f"- {require_item}\n"
|
||||
require_text = require_text.rstrip("\n")
|
||||
|
||||
|
||||
if not action_info.parallel_action:
|
||||
parallel_text = "(当选择这个动作时,请不要选择其他动作)"
|
||||
else:
|
||||
@@ -463,7 +564,7 @@ class ActionPlanner:
|
||||
filtered_actions: Dict[str, ActionInfo],
|
||||
available_actions: Dict[str, ActionInfo],
|
||||
loop_start_time: float,
|
||||
) -> Tuple[str,List[ActionPlannerInfo]]:
|
||||
) -> Tuple[str, List[ActionPlannerInfo]]:
|
||||
"""执行主规划器"""
|
||||
llm_content = None
|
||||
actions: List[ActionPlannerInfo] = []
|
||||
@@ -475,7 +576,7 @@ class ActionPlanner:
|
||||
# logger.info(f"{self.log_prefix}规划器原始提示词: {prompt}")
|
||||
# logger.info(f"{self.log_prefix}规划器原始响应: {llm_content}")
|
||||
|
||||
if global_config.debug.show_prompt:
|
||||
if global_config.debug.show_planner_prompt:
|
||||
logger.info(f"{self.log_prefix}规划器原始提示词: {prompt}")
|
||||
logger.info(f"{self.log_prefix}规划器原始响应: {llm_content}")
|
||||
if reasoning_content:
|
||||
@@ -488,7 +589,7 @@ class ActionPlanner:
|
||||
|
||||
except Exception as req_e:
|
||||
logger.error(f"{self.log_prefix}LLM 请求执行失败: {req_e}")
|
||||
return f"LLM 请求失败,模型出现问题: {req_e}",[
|
||||
return f"LLM 请求失败,模型出现问题: {req_e}", [
|
||||
ActionPlannerInfo(
|
||||
action_type="no_reply",
|
||||
reasoning=f"LLM 请求失败,模型出现问题: {req_e}",
|
||||
@@ -507,7 +608,11 @@ class ActionPlanner:
|
||||
logger.debug(f"{self.log_prefix}从响应中提取到{len(json_objects)}个JSON对象")
|
||||
filtered_actions_list = list(filtered_actions.items())
|
||||
for json_obj in json_objects:
|
||||
actions.extend(self._parse_single_action(json_obj, message_id_list, filtered_actions_list, extracted_reasoning))
|
||||
actions.extend(
|
||||
self._parse_single_action(
|
||||
json_obj, message_id_list, filtered_actions_list, extracted_reasoning
|
||||
)
|
||||
)
|
||||
else:
|
||||
# 尝试解析为直接的JSON
|
||||
logger.warning(f"{self.log_prefix}LLM没有返回可用动作: {llm_content}")
|
||||
@@ -530,7 +635,7 @@ class ActionPlanner:
|
||||
|
||||
logger.debug(f"{self.log_prefix}规划器选择了{len(actions)}个动作: {' '.join([a.action_type for a in actions])}")
|
||||
|
||||
return extracted_reasoning,actions
|
||||
return extracted_reasoning, actions
|
||||
|
||||
def _create_no_reply(self, reasoning: str, available_actions: Dict[str, ActionInfo]) -> List[ActionPlannerInfo]:
|
||||
"""创建no_reply"""
|
||||
@@ -552,10 +657,11 @@ class ActionPlanner:
|
||||
|
||||
# 使用正则表达式查找```json包裹的JSON内容
|
||||
json_pattern = r"```json\s*(.*?)\s*```"
|
||||
matches = re.findall(json_pattern, content, re.DOTALL)
|
||||
markdown_matches = re.findall(json_pattern, content, re.DOTALL)
|
||||
|
||||
# 提取JSON之前的内容作为推理文本
|
||||
if matches:
|
||||
first_json_pos = len(content)
|
||||
if markdown_matches:
|
||||
# 找到第一个```json的位置
|
||||
first_json_pos = content.find("```json")
|
||||
if first_json_pos > 0:
|
||||
@@ -564,19 +670,38 @@ class ActionPlanner:
|
||||
reasoning_content = re.sub(r"^//\s*", "", reasoning_content, flags=re.MULTILINE)
|
||||
reasoning_content = reasoning_content.strip()
|
||||
|
||||
for match in matches:
|
||||
# 处理```json包裹的JSON
|
||||
for match in markdown_matches:
|
||||
try:
|
||||
# 清理可能的注释和格式问题
|
||||
json_str = re.sub(r"//.*?\n", "\n", match) # 移除单行注释
|
||||
json_str = re.sub(r"/\*.*?\*/", "", json_str, flags=re.DOTALL) # 移除多行注释
|
||||
if json_str := json_str.strip():
|
||||
json_obj = json.loads(repair_json(json_str))
|
||||
if isinstance(json_obj, dict):
|
||||
json_objects.append(json_obj)
|
||||
elif isinstance(json_obj, list):
|
||||
for item in json_obj:
|
||||
if isinstance(item, dict):
|
||||
json_objects.append(item)
|
||||
# 尝试按行分割,每行可能是一个JSON对象
|
||||
lines = [line.strip() for line in json_str.split("\n") if line.strip()]
|
||||
for line in lines:
|
||||
try:
|
||||
# 尝试解析每一行作为独立的JSON对象
|
||||
json_obj = json.loads(repair_json(line))
|
||||
if isinstance(json_obj, dict):
|
||||
json_objects.append(json_obj)
|
||||
elif isinstance(json_obj, list):
|
||||
for item in json_obj:
|
||||
if isinstance(item, dict):
|
||||
json_objects.append(item)
|
||||
except json.JSONDecodeError:
|
||||
# 如果单行解析失败,尝试将整个块作为一个JSON对象或数组
|
||||
pass
|
||||
|
||||
# 如果按行解析没有成功,尝试将整个块作为一个JSON对象或数组
|
||||
if not json_objects:
|
||||
json_obj = json.loads(repair_json(json_str))
|
||||
if isinstance(json_obj, dict):
|
||||
json_objects.append(json_obj)
|
||||
elif isinstance(json_obj, list):
|
||||
for item in json_obj:
|
||||
if isinstance(item, dict):
|
||||
json_objects.append(item)
|
||||
except Exception as e:
|
||||
logger.warning(f"解析JSON块失败: {e}, 块内容: {match[:100]}...")
|
||||
continue
|
||||
|
||||
@@ -6,8 +6,6 @@ import re
|
||||
|
||||
from typing import List, Optional, Dict, Any, Tuple
|
||||
from datetime import datetime
|
||||
from src.memory_system.Memory_chest import global_memory_chest
|
||||
from src.memory_system.questions import global_conflict_tracker
|
||||
from src.common.logger import get_logger
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.common.data_models.info_data_model import ActionPlannerInfo
|
||||
@@ -37,10 +35,12 @@ from src.plugin_system.apis import llm_api
|
||||
from src.chat.replyer.prompt.lpmm_prompt import init_lpmm_prompt
|
||||
from src.chat.replyer.prompt.replyer_prompt import init_replyer_prompt
|
||||
from src.chat.replyer.prompt.rewrite_prompt import init_rewrite_prompt
|
||||
from src.memory_system.memory_retrieval import init_memory_retrieval_prompt, build_memory_retrieval_prompt
|
||||
|
||||
init_lpmm_prompt()
|
||||
init_replyer_prompt()
|
||||
init_rewrite_prompt()
|
||||
init_memory_retrieval_prompt()
|
||||
|
||||
|
||||
logger = get_logger("replyer")
|
||||
@@ -56,7 +56,6 @@ class DefaultReplyer:
|
||||
self.chat_stream = chat_stream
|
||||
self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_stream.stream_id)
|
||||
self.heart_fc_sender = UniversalMessageSender()
|
||||
# self.memory_activator = MemoryActivator()
|
||||
|
||||
from src.plugin_system.core.tool_use import ToolExecutor # 延迟导入ToolExecutor,不然会循环依赖
|
||||
|
||||
@@ -134,12 +133,12 @@ class DefaultReplyer:
|
||||
try:
|
||||
content, reasoning_content, model_name, tool_call = await self.llm_generate_content(prompt)
|
||||
# logger.debug(f"replyer生成内容: {content}")
|
||||
|
||||
|
||||
logger.info(f"replyer生成内容: {content}")
|
||||
if global_config.debug.show_replyer_reasoning:
|
||||
logger.info(f"replyer生成推理:\n{reasoning_content}")
|
||||
logger.info(f"replyer生成模型: {model_name}")
|
||||
|
||||
|
||||
llm_response.content = content
|
||||
llm_response.reasoning = reasoning_content
|
||||
llm_response.model = model_name
|
||||
@@ -227,13 +226,14 @@ class DefaultReplyer:
|
||||
traceback.print_exc()
|
||||
return False, llm_response
|
||||
|
||||
async def build_expression_habits(self, chat_history: str, target: str) -> Tuple[str, List[int]]:
|
||||
async def build_expression_habits(self, chat_history: str, target: str, reply_reason: str = "") -> Tuple[str, List[int]]:
|
||||
# sourcery skip: for-append-to-extend
|
||||
"""构建表达习惯块
|
||||
|
||||
Args:
|
||||
chat_history: 聊天历史记录
|
||||
target: 目标消息内容
|
||||
reply_reason: planner给出的回复理由
|
||||
|
||||
Returns:
|
||||
str: 表达习惯信息字符串
|
||||
@@ -244,9 +244,9 @@ class DefaultReplyer:
|
||||
return "", []
|
||||
style_habits = []
|
||||
# 使用从处理器传来的选中表达方式
|
||||
# 根据配置模式选择表达方式:exp_model模式直接使用模型预测,classic模式使用LLM选择
|
||||
# 使用模型预测选择表达方式
|
||||
selected_expressions, selected_ids = await expression_selector.select_suitable_expressions(
|
||||
self.chat_stream.stream_id, chat_history, max_num=8, target_message=target
|
||||
self.chat_stream.stream_id, chat_history, max_num=8, target_message=target, reply_reason=reply_reason
|
||||
)
|
||||
|
||||
if selected_expressions:
|
||||
@@ -268,38 +268,13 @@ class DefaultReplyer:
|
||||
expression_habits_block += f"{style_habits_str}\n"
|
||||
|
||||
return f"{expression_habits_title}\n{expression_habits_block}", selected_ids
|
||||
|
||||
|
||||
async def build_mood_state_prompt(self) -> str:
|
||||
"""构建情绪状态提示"""
|
||||
if not global_config.mood.enable_mood:
|
||||
return ""
|
||||
mood_state = await mood_manager.get_mood_by_chat_id(self.chat_stream.stream_id).get_mood()
|
||||
return f"你现在的心情是:{mood_state}"
|
||||
|
||||
async def build_memory_block(self) -> str:
|
||||
"""构建记忆块
|
||||
"""
|
||||
# if not global_config.memory.enable_memory:
|
||||
# return ""
|
||||
|
||||
if global_memory_chest.get_chat_memories_as_string(self.chat_stream.stream_id):
|
||||
return f"你有以下记忆:\n{global_memory_chest.get_chat_memories_as_string(self.chat_stream.stream_id)}"
|
||||
else:
|
||||
return ""
|
||||
|
||||
async def build_question_block(self) -> str:
|
||||
"""构建问题块"""
|
||||
# if not global_config.question.enable_question:
|
||||
# return ""
|
||||
questions = global_conflict_tracker.get_questions_by_chat_id(self.chat_stream.stream_id)
|
||||
questions_str = ""
|
||||
for question in questions:
|
||||
questions_str += f"- {question.question}\n"
|
||||
if questions_str:
|
||||
return f"你在聊天中,有以下问题想要得到解答:\n{questions_str}"
|
||||
else:
|
||||
return ""
|
||||
|
||||
|
||||
async def build_tool_info(self, chat_history: str, sender: str, target: str, enable_tool: bool = True) -> str:
|
||||
"""构建工具信息块
|
||||
@@ -327,7 +302,7 @@ class DefaultReplyer:
|
||||
for tool_result in tool_results:
|
||||
tool_name = tool_result.get("tool_name", "unknown")
|
||||
content = tool_result.get("content", "")
|
||||
result_type = tool_result.get("type", "tool_result")
|
||||
_result_type = tool_result.get("type", "tool_result")
|
||||
|
||||
tool_info_str += f"- 【{tool_name}】: {content}\n"
|
||||
|
||||
@@ -367,45 +342,45 @@ class DefaultReplyer:
|
||||
|
||||
def _replace_picids_with_descriptions(self, text: str) -> str:
|
||||
"""将文本中的[picid:xxx]替换为具体的图片描述
|
||||
|
||||
|
||||
Args:
|
||||
text: 包含picid标记的文本
|
||||
|
||||
|
||||
Returns:
|
||||
替换后的文本
|
||||
"""
|
||||
# 匹配 [picid:xxxxx] 格式
|
||||
pic_pattern = r"\[picid:([^\]]+)\]"
|
||||
|
||||
|
||||
def replace_pic_id(match: re.Match) -> str:
|
||||
pic_id = match.group(1)
|
||||
description = translate_pid_to_description(pic_id)
|
||||
return f"[图片:{description}]"
|
||||
|
||||
|
||||
return re.sub(pic_pattern, replace_pic_id, text)
|
||||
|
||||
def _analyze_target_content(self, target: str) -> Tuple[bool, bool, str, str]:
|
||||
"""分析target内容类型(基于原始picid格式)
|
||||
|
||||
|
||||
Args:
|
||||
target: 目标消息内容(包含[picid:xxx]格式)
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple[bool, bool, str, str]: (是否只包含图片, 是否包含文字, 图片部分, 文字部分)
|
||||
"""
|
||||
if not target or not target.strip():
|
||||
return False, False, "", ""
|
||||
|
||||
|
||||
# 检查是否只包含picid标记
|
||||
picid_pattern = r"\[picid:[^\]]+\]"
|
||||
picid_matches = re.findall(picid_pattern, target)
|
||||
|
||||
|
||||
# 移除所有picid标记后检查是否还有文字内容
|
||||
text_without_picids = re.sub(picid_pattern, "", target).strip()
|
||||
|
||||
|
||||
has_only_pics = len(picid_matches) > 0 and not text_without_picids
|
||||
has_text = bool(text_without_picids)
|
||||
|
||||
|
||||
# 提取图片部分(转换为[图片:描述]格式)
|
||||
pic_part = ""
|
||||
if picid_matches:
|
||||
@@ -420,7 +395,7 @@ class DefaultReplyer:
|
||||
else:
|
||||
pic_descriptions.append(f"[图片:{description}]")
|
||||
pic_part = "".join(pic_descriptions)
|
||||
|
||||
|
||||
return has_only_pics, has_text, pic_part, text_without_picids
|
||||
|
||||
async def build_keywords_reaction_prompt(self, target: Optional[str]) -> str:
|
||||
@@ -505,7 +480,7 @@ class DefaultReplyer:
|
||||
)
|
||||
|
||||
return all_dialogue_prompt
|
||||
|
||||
|
||||
def core_background_build_chat_history_prompts(
|
||||
self, message_list_before_now: List[DatabaseMessages], target_user_id: str, sender: str
|
||||
) -> Tuple[str, str]:
|
||||
@@ -627,18 +602,97 @@ class DefaultReplyer:
|
||||
|
||||
# 获取基础personality
|
||||
prompt_personality = global_config.personality.personality
|
||||
|
||||
|
||||
# 检查是否需要随机替换为状态
|
||||
if (global_config.personality.states and
|
||||
global_config.personality.state_probability > 0 and
|
||||
random.random() < global_config.personality.state_probability):
|
||||
if (
|
||||
global_config.personality.states
|
||||
and global_config.personality.state_probability > 0
|
||||
and random.random() < global_config.personality.state_probability
|
||||
):
|
||||
# 随机选择一个状态替换personality
|
||||
selected_state = random.choice(global_config.personality.states)
|
||||
prompt_personality = selected_state
|
||||
|
||||
|
||||
prompt_personality = f"{prompt_personality};"
|
||||
return f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}"
|
||||
|
||||
def _parse_chat_prompt_config_to_chat_id(self, chat_prompt_str: str) -> Optional[tuple[str, str]]:
|
||||
"""
|
||||
解析聊天prompt配置字符串并生成对应的 chat_id 和 prompt内容
|
||||
|
||||
Args:
|
||||
chat_prompt_str: 格式为 "platform:id:type:prompt内容" 的字符串
|
||||
|
||||
Returns:
|
||||
tuple: (chat_id, prompt_content),如果解析失败则返回 None
|
||||
"""
|
||||
try:
|
||||
# 使用 split 分割,但限制分割次数为3,因为prompt内容可能包含冒号
|
||||
parts = chat_prompt_str.split(":", 3)
|
||||
if len(parts) != 4:
|
||||
return None
|
||||
|
||||
platform = parts[0]
|
||||
id_str = parts[1]
|
||||
stream_type = parts[2]
|
||||
prompt_content = parts[3]
|
||||
|
||||
# 判断是否为群聊
|
||||
is_group = stream_type == "group"
|
||||
|
||||
# 使用与 ChatStream.get_stream_id 相同的逻辑生成 chat_id
|
||||
import hashlib
|
||||
|
||||
if is_group:
|
||||
components = [platform, str(id_str)]
|
||||
else:
|
||||
components = [platform, str(id_str), "private"]
|
||||
key = "_".join(components)
|
||||
chat_id = hashlib.md5(key.encode()).hexdigest()
|
||||
|
||||
return chat_id, prompt_content
|
||||
|
||||
except (ValueError, IndexError):
|
||||
return None
|
||||
|
||||
def get_chat_prompt_for_chat(self, chat_id: str) -> str:
|
||||
"""
|
||||
根据聊天流ID获取匹配的额外prompt(仅匹配group类型)
|
||||
|
||||
Args:
|
||||
chat_id: 聊天流ID(哈希值)
|
||||
|
||||
Returns:
|
||||
str: 匹配的额外prompt内容,如果没有匹配则返回空字符串
|
||||
"""
|
||||
if not global_config.experimental.chat_prompts:
|
||||
return ""
|
||||
|
||||
for chat_prompt_str in global_config.experimental.chat_prompts:
|
||||
if not isinstance(chat_prompt_str, str):
|
||||
continue
|
||||
|
||||
# 解析配置字符串,检查类型是否为group
|
||||
parts = chat_prompt_str.split(":", 3)
|
||||
if len(parts) != 4:
|
||||
continue
|
||||
|
||||
stream_type = parts[2]
|
||||
# 只匹配group类型
|
||||
if stream_type != "group":
|
||||
continue
|
||||
|
||||
result = self._parse_chat_prompt_config_to_chat_id(chat_prompt_str)
|
||||
if result is None:
|
||||
continue
|
||||
|
||||
config_chat_id, prompt_content = result
|
||||
if config_chat_id == chat_id:
|
||||
logger.debug(f"匹配到群聊prompt配置,chat_id: {chat_id}, prompt: {prompt_content[:50]}...")
|
||||
return prompt_content
|
||||
|
||||
return ""
|
||||
|
||||
async def build_prompt_reply_context(
|
||||
self,
|
||||
reply_message: Optional[DatabaseMessages] = None,
|
||||
@@ -667,7 +721,7 @@ class DefaultReplyer:
|
||||
available_actions = {}
|
||||
chat_stream = self.chat_stream
|
||||
chat_id = chat_stream.stream_id
|
||||
is_group_chat = bool(chat_stream.group_info)
|
||||
_is_group_chat = bool(chat_stream.group_info)
|
||||
platform = chat_stream.platform
|
||||
|
||||
user_id = "用户ID"
|
||||
@@ -683,10 +737,10 @@ class DefaultReplyer:
|
||||
target = reply_message.processed_plain_text
|
||||
|
||||
target = replace_user_references(target, chat_stream.platform, replace_bot_name=True)
|
||||
|
||||
|
||||
# 在picid替换之前分析内容类型(防止prompt注入)
|
||||
has_only_pics, has_text, pic_part, text_part = self._analyze_target_content(target)
|
||||
|
||||
|
||||
# 将[picid:xxx]替换为具体的图片描述
|
||||
target = self._replace_picids_with_descriptions(target)
|
||||
|
||||
@@ -730,12 +784,11 @@ class DefaultReplyer:
|
||||
show_actions=True,
|
||||
)
|
||||
|
||||
# 并行执行五个构建任务
|
||||
# 并行执行七个构建任务
|
||||
task_results = await asyncio.gather(
|
||||
self._time_and_run_task(
|
||||
self.build_expression_habits(chat_talking_prompt_short, target), "expression_habits"
|
||||
self.build_expression_habits(chat_talking_prompt_short, target, reply_reason), "expression_habits"
|
||||
),
|
||||
self._time_and_run_task(self.build_memory_block(), "memory_block"),
|
||||
self._time_and_run_task(
|
||||
self.build_tool_info(chat_talking_prompt_short, sender, target, enable_tool=enable_tool), "tool_info"
|
||||
),
|
||||
@@ -743,21 +796,24 @@ class DefaultReplyer:
|
||||
self._time_and_run_task(self.build_actions_prompt(available_actions, chosen_actions), "actions_info"),
|
||||
self._time_and_run_task(self.build_personality_prompt(), "personality_prompt"),
|
||||
self._time_and_run_task(self.build_mood_state_prompt(), "mood_state_prompt"),
|
||||
self._time_and_run_task(self.build_question_block(), "question_block"),
|
||||
self._time_and_run_task(
|
||||
build_memory_retrieval_prompt(
|
||||
chat_talking_prompt_short, sender, target, self.chat_stream, self.tool_executor
|
||||
),
|
||||
"memory_retrieval",
|
||||
),
|
||||
)
|
||||
|
||||
# 任务名称中英文映射
|
||||
task_name_mapping = {
|
||||
"expression_habits": "选取表达方式",
|
||||
"relation_info": "感受关系",
|
||||
# "memory_block": "回忆",
|
||||
"memory_block": "记忆",
|
||||
"tool_info": "使用工具",
|
||||
"prompt_info": "获取知识",
|
||||
"actions_info": "动作信息",
|
||||
"personality_prompt": "人格信息",
|
||||
"mood_state_prompt": "情绪状态",
|
||||
"question_block": "问题",
|
||||
"memory_retrieval": "记忆检索",
|
||||
}
|
||||
|
||||
# 处理结果
|
||||
@@ -781,16 +837,20 @@ class DefaultReplyer:
|
||||
expression_habits_block: str
|
||||
selected_expressions: List[int]
|
||||
# relation_info: str = results_dict["relation_info"]
|
||||
# memory_block: str = results_dict["memory_block"]
|
||||
memory_block: str = results_dict["memory_block"]
|
||||
tool_info: str = results_dict["tool_info"]
|
||||
prompt_info: str = results_dict["prompt_info"] # 直接使用格式化后的结果
|
||||
actions_info: str = results_dict["actions_info"]
|
||||
personality_prompt: str = results_dict["personality_prompt"]
|
||||
question_block: str = results_dict["question_block"]
|
||||
memory_retrieval: str = results_dict["memory_retrieval"]
|
||||
keywords_reaction_prompt = await self.build_keywords_reaction_prompt(target)
|
||||
mood_state_prompt: str = results_dict["mood_state_prompt"]
|
||||
|
||||
# 从 chosen_actions 中提取 planner 的整体思考理由
|
||||
planner_reasoning = ""
|
||||
if global_config.chat.include_planner_reasoning and reply_reason:
|
||||
# 如果没有 chosen_actions,使用 reply_reason 作为备选
|
||||
planner_reasoning = f"你的想法是:{reply_reason}"
|
||||
|
||||
if extra_info:
|
||||
extra_info_block = f"以下是你在回复时需要参考的信息,现在请你阅读以下内容,进行决策\n{extra_info}\n以上是你在回复时需要参考的信息,现在请你阅读以下内容,进行决策"
|
||||
else:
|
||||
@@ -820,14 +880,18 @@ class DefaultReplyer:
|
||||
# 构建分离的对话 prompt
|
||||
dialogue_prompt = self.build_chat_history_prompts(message_list_before_now_long, user_id, sender)
|
||||
|
||||
# 获取匹配的额外prompt
|
||||
chat_prompt_content = self.get_chat_prompt_for_chat(chat_id)
|
||||
chat_prompt_block = f"{chat_prompt_content}\n" if chat_prompt_content else ""
|
||||
|
||||
# 固定使用群聊回复模板
|
||||
return await global_prompt_manager.format_prompt(
|
||||
"replyer_prompt",
|
||||
expression_habits_block=expression_habits_block,
|
||||
tool_info_block=tool_info,
|
||||
memory_block=memory_block,
|
||||
bot_name=global_config.bot.nickname,
|
||||
knowledge_prompt=prompt_info,
|
||||
mood_state=mood_state_prompt,
|
||||
# memory_block=memory_block,
|
||||
# relation_info_block=relation_info,
|
||||
extra_info_block=extra_info_block,
|
||||
identity=personality_prompt,
|
||||
@@ -839,7 +903,9 @@ class DefaultReplyer:
|
||||
reply_style=global_config.personality.reply_style,
|
||||
keywords_reaction_prompt=keywords_reaction_prompt,
|
||||
moderation_prompt=moderation_prompt_block,
|
||||
question_block=question_block,
|
||||
memory_retrieval=memory_retrieval,
|
||||
chat_prompt=chat_prompt_block,
|
||||
planner_reasoning=planner_reasoning,
|
||||
), selected_expressions
|
||||
|
||||
async def build_prompt_rewrite_context(
|
||||
@@ -854,10 +920,10 @@ class DefaultReplyer:
|
||||
|
||||
sender, target = self._parse_reply_target(reply_to)
|
||||
target = replace_user_references(target, chat_stream.platform, replace_bot_name=True)
|
||||
|
||||
|
||||
# 在picid替换之前分析内容类型(防止prompt注入)
|
||||
has_only_pics, has_text, pic_part, text_part = self._analyze_target_content(target)
|
||||
|
||||
|
||||
# 将[picid:xxx]替换为具体的图片描述
|
||||
target = self._replace_picids_with_descriptions(target)
|
||||
|
||||
@@ -899,9 +965,7 @@ class DefaultReplyer:
|
||||
)
|
||||
elif has_text and pic_part:
|
||||
# 既有图片又有文字
|
||||
reply_target_block = (
|
||||
f"现在{sender}发送了图片:{pic_part},并说:{text_part}。引起了你的注意,你想要在群里发言或者回复这条消息。"
|
||||
)
|
||||
reply_target_block = f"现在{sender}发送了图片:{pic_part},并说:{text_part}。引起了你的注意,你想要在群里发言或者回复这条消息。"
|
||||
else:
|
||||
# 只包含文字
|
||||
reply_target_block = (
|
||||
@@ -918,7 +982,9 @@ class DefaultReplyer:
|
||||
reply_target_block = f"现在{sender}发送的图片:{pic_part}。引起了你的注意,针对这条消息回复。"
|
||||
elif has_text and pic_part:
|
||||
# 既有图片又有文字
|
||||
reply_target_block = f"现在{sender}发送了图片:{pic_part},并说:{text_part}。引起了你的注意,针对这条消息回复。"
|
||||
reply_target_block = (
|
||||
f"现在{sender}发送了图片:{pic_part},并说:{text_part}。引起了你的注意,针对这条消息回复。"
|
||||
)
|
||||
else:
|
||||
# 只包含文字
|
||||
reply_target_block = f"现在{sender}说的:{text_part}。引起了你的注意,针对这条消息回复。"
|
||||
@@ -1028,6 +1094,10 @@ class DefaultReplyer:
|
||||
if not global_config.lpmm_knowledge.enable:
|
||||
logger.debug("LPMM知识库未启用,跳过获取知识库内容")
|
||||
return ""
|
||||
|
||||
if global_config.lpmm_knowledge.lpmm_mode == "agent":
|
||||
return ""
|
||||
|
||||
time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
|
||||
|
||||
bot_name = global_config.bot.nickname
|
||||
@@ -1045,6 +1115,10 @@ class DefaultReplyer:
|
||||
model_config=model_config.model_task_config.tool_use,
|
||||
tool_options=[SearchKnowledgeFromLPMMTool.get_tool_definition()],
|
||||
)
|
||||
|
||||
# logger.info(f"工具调用提示词: {prompt}")
|
||||
# logger.info(f"工具调用: {tool_calls}")
|
||||
|
||||
if tool_calls:
|
||||
result = await self.tool_executor.execute_tool_call(tool_calls[0], SearchKnowledgeFromLPMMTool())
|
||||
end_time = time.time()
|
||||
@@ -1052,7 +1126,7 @@ class DefaultReplyer:
|
||||
logger.debug("从LPMM知识库获取知识失败,返回空知识...")
|
||||
return ""
|
||||
found_knowledge_from_lpmm = result.get("content", "")
|
||||
logger.debug(
|
||||
logger.info(
|
||||
f"从LPMM知识库获取知识,相关信息:{found_knowledge_from_lpmm[:100]}...,信息长度: {len(found_knowledge_from_lpmm)}"
|
||||
)
|
||||
related_info += found_knowledge_from_lpmm
|
||||
|
||||
@@ -6,7 +6,6 @@ import re
|
||||
|
||||
from typing import List, Optional, Dict, Any, Tuple
|
||||
from datetime import datetime
|
||||
from src.memory_system.Memory_chest import global_memory_chest
|
||||
from src.common.logger import get_logger
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.common.data_models.info_data_model import ActionPlannerInfo
|
||||
@@ -37,14 +36,17 @@ from src.plugin_system.apis import llm_api
|
||||
from src.chat.replyer.prompt.lpmm_prompt import init_lpmm_prompt
|
||||
from src.chat.replyer.prompt.replyer_prompt import init_replyer_prompt
|
||||
from src.chat.replyer.prompt.rewrite_prompt import init_rewrite_prompt
|
||||
from src.memory_system.memory_retrieval import init_memory_retrieval_prompt, build_memory_retrieval_prompt
|
||||
|
||||
init_lpmm_prompt()
|
||||
init_replyer_prompt()
|
||||
init_rewrite_prompt()
|
||||
init_memory_retrieval_prompt()
|
||||
|
||||
|
||||
logger = get_logger("replyer")
|
||||
|
||||
|
||||
class PrivateReplyer:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -239,13 +241,14 @@ class PrivateReplyer:
|
||||
|
||||
return f"{sender_relation}"
|
||||
|
||||
async def build_expression_habits(self, chat_history: str, target: str) -> Tuple[str, List[int]]:
|
||||
async def build_expression_habits(self, chat_history: str, target: str, reply_reason: str = "") -> Tuple[str, List[int]]:
|
||||
# sourcery skip: for-append-to-extend
|
||||
"""构建表达习惯块
|
||||
|
||||
Args:
|
||||
chat_history: 聊天历史记录
|
||||
target: 目标消息内容
|
||||
reply_reason: planner给出的回复理由
|
||||
|
||||
Returns:
|
||||
str: 表达习惯信息字符串
|
||||
@@ -256,9 +259,9 @@ class PrivateReplyer:
|
||||
return "", []
|
||||
style_habits = []
|
||||
# 使用从处理器传来的选中表达方式
|
||||
# 根据配置模式选择表达方式:exp_model模式直接使用模型预测,classic模式使用LLM选择
|
||||
# 使用模型预测选择表达方式
|
||||
selected_expressions, selected_ids = await expression_selector.select_suitable_expressions(
|
||||
self.chat_stream.stream_id, chat_history, max_num=8, target_message=target
|
||||
self.chat_stream.stream_id, chat_history, max_num=8, target_message=target, reply_reason=reply_reason
|
||||
)
|
||||
|
||||
if selected_expressions:
|
||||
@@ -276,9 +279,7 @@ class PrivateReplyer:
|
||||
expression_habits_block = ""
|
||||
expression_habits_title = ""
|
||||
if style_habits_str.strip():
|
||||
expression_habits_title = (
|
||||
"在回复时,你可以参考以下的语言习惯,不要生硬使用:"
|
||||
)
|
||||
expression_habits_title = "在回复时,你可以参考以下的语言习惯,不要生硬使用:"
|
||||
expression_habits_block += f"{style_habits_str}\n"
|
||||
|
||||
return f"{expression_habits_title}\n{expression_habits_block}", selected_ids
|
||||
@@ -290,15 +291,6 @@ class PrivateReplyer:
|
||||
mood_state = await mood_manager.get_mood_by_chat_id(self.chat_stream.stream_id).get_mood()
|
||||
return f"你现在的心情是:{mood_state}"
|
||||
|
||||
|
||||
async def build_memory_block(self) -> str:
|
||||
"""构建记忆块
|
||||
"""
|
||||
if global_memory_chest.get_chat_memories_as_string(self.chat_stream.stream_id):
|
||||
return f"你有以下记忆:\n{global_memory_chest.get_chat_memories_as_string(self.chat_stream.stream_id)}"
|
||||
else:
|
||||
return ""
|
||||
|
||||
async def build_tool_info(self, chat_history: str, sender: str, target: str, enable_tool: bool = True) -> str:
|
||||
"""构建工具信息块
|
||||
|
||||
@@ -365,45 +357,45 @@ class PrivateReplyer:
|
||||
|
||||
def _replace_picids_with_descriptions(self, text: str) -> str:
|
||||
"""将文本中的[picid:xxx]替换为具体的图片描述
|
||||
|
||||
|
||||
Args:
|
||||
text: 包含picid标记的文本
|
||||
|
||||
|
||||
Returns:
|
||||
替换后的文本
|
||||
"""
|
||||
# 匹配 [picid:xxxxx] 格式
|
||||
pic_pattern = r"\[picid:([^\]]+)\]"
|
||||
|
||||
|
||||
def replace_pic_id(match: re.Match) -> str:
|
||||
pic_id = match.group(1)
|
||||
description = translate_pid_to_description(pic_id)
|
||||
return f"[图片:{description}]"
|
||||
|
||||
|
||||
return re.sub(pic_pattern, replace_pic_id, text)
|
||||
|
||||
def _analyze_target_content(self, target: str) -> Tuple[bool, bool, str, str]:
|
||||
"""分析target内容类型(基于原始picid格式)
|
||||
|
||||
|
||||
Args:
|
||||
target: 目标消息内容(包含[picid:xxx]格式)
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple[bool, bool, str, str]: (是否只包含图片, 是否包含文字, 图片部分, 文字部分)
|
||||
"""
|
||||
if not target or not target.strip():
|
||||
return False, False, "", ""
|
||||
|
||||
|
||||
# 检查是否只包含picid标记
|
||||
picid_pattern = r"\[picid:[^\]]+\]"
|
||||
picid_matches = re.findall(picid_pattern, target)
|
||||
|
||||
|
||||
# 移除所有picid标记后检查是否还有文字内容
|
||||
text_without_picids = re.sub(picid_pattern, "", target).strip()
|
||||
|
||||
|
||||
has_only_pics = len(picid_matches) > 0 and not text_without_picids
|
||||
has_text = bool(text_without_picids)
|
||||
|
||||
|
||||
# 提取图片部分(转换为[图片:描述]格式)
|
||||
pic_part = ""
|
||||
if picid_matches:
|
||||
@@ -418,7 +410,7 @@ class PrivateReplyer:
|
||||
else:
|
||||
pic_descriptions.append(f"[图片:{description}]")
|
||||
pic_part = "".join(pic_descriptions)
|
||||
|
||||
|
||||
return has_only_pics, has_text, pic_part, text_without_picids
|
||||
|
||||
async def build_keywords_reaction_prompt(self, target: Optional[str]) -> str:
|
||||
@@ -524,18 +516,97 @@ class PrivateReplyer:
|
||||
|
||||
# 获取基础personality
|
||||
prompt_personality = global_config.personality.personality
|
||||
|
||||
|
||||
# 检查是否需要随机替换为状态
|
||||
if (global_config.personality.states and
|
||||
global_config.personality.state_probability > 0 and
|
||||
random.random() < global_config.personality.state_probability):
|
||||
if (
|
||||
global_config.personality.states
|
||||
and global_config.personality.state_probability > 0
|
||||
and random.random() < global_config.personality.state_probability
|
||||
):
|
||||
# 随机选择一个状态替换personality
|
||||
selected_state = random.choice(global_config.personality.states)
|
||||
prompt_personality = selected_state
|
||||
|
||||
|
||||
prompt_personality = f"{prompt_personality};"
|
||||
return f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}"
|
||||
|
||||
def _parse_chat_prompt_config_to_chat_id(self, chat_prompt_str: str) -> Optional[tuple[str, str]]:
|
||||
"""
|
||||
解析聊天prompt配置字符串并生成对应的 chat_id 和 prompt内容
|
||||
|
||||
Args:
|
||||
chat_prompt_str: 格式为 "platform:id:type:prompt内容" 的字符串
|
||||
|
||||
Returns:
|
||||
tuple: (chat_id, prompt_content),如果解析失败则返回 None
|
||||
"""
|
||||
try:
|
||||
# 使用 split 分割,但限制分割次数为3,因为prompt内容可能包含冒号
|
||||
parts = chat_prompt_str.split(":", 3)
|
||||
if len(parts) != 4:
|
||||
return None
|
||||
|
||||
platform = parts[0]
|
||||
id_str = parts[1]
|
||||
stream_type = parts[2]
|
||||
prompt_content = parts[3]
|
||||
|
||||
# 判断是否为群聊
|
||||
is_group = stream_type == "group"
|
||||
|
||||
# 使用与 ChatStream.get_stream_id 相同的逻辑生成 chat_id
|
||||
import hashlib
|
||||
|
||||
if is_group:
|
||||
components = [platform, str(id_str)]
|
||||
else:
|
||||
components = [platform, str(id_str), "private"]
|
||||
key = "_".join(components)
|
||||
chat_id = hashlib.md5(key.encode()).hexdigest()
|
||||
|
||||
return chat_id, prompt_content
|
||||
|
||||
except (ValueError, IndexError):
|
||||
return None
|
||||
|
||||
def get_chat_prompt_for_chat(self, chat_id: str) -> str:
|
||||
"""
|
||||
根据聊天流ID获取匹配的额外prompt(仅匹配private类型)
|
||||
|
||||
Args:
|
||||
chat_id: 聊天流ID(哈希值)
|
||||
|
||||
Returns:
|
||||
str: 匹配的额外prompt内容,如果没有匹配则返回空字符串
|
||||
"""
|
||||
if not global_config.experimental.chat_prompts:
|
||||
return ""
|
||||
|
||||
for chat_prompt_str in global_config.experimental.chat_prompts:
|
||||
if not isinstance(chat_prompt_str, str):
|
||||
continue
|
||||
|
||||
# 解析配置字符串,检查类型是否为private
|
||||
parts = chat_prompt_str.split(":", 3)
|
||||
if len(parts) != 4:
|
||||
continue
|
||||
|
||||
stream_type = parts[2]
|
||||
# 只匹配private类型
|
||||
if stream_type != "private":
|
||||
continue
|
||||
|
||||
result = self._parse_chat_prompt_config_to_chat_id(chat_prompt_str)
|
||||
if result is None:
|
||||
continue
|
||||
|
||||
config_chat_id, prompt_content = result
|
||||
if config_chat_id == chat_id:
|
||||
logger.debug(f"匹配到私聊prompt配置,chat_id: {chat_id}, prompt: {prompt_content[:50]}...")
|
||||
return prompt_content
|
||||
|
||||
return ""
|
||||
|
||||
async def build_prompt_reply_context(
|
||||
self,
|
||||
reply_message: Optional[DatabaseMessages] = None,
|
||||
@@ -577,13 +648,11 @@ class PrivateReplyer:
|
||||
sender = person_name
|
||||
target = reply_message.processed_plain_text
|
||||
|
||||
|
||||
|
||||
target = replace_user_references(target, chat_stream.platform, replace_bot_name=True)
|
||||
|
||||
|
||||
# 在picid替换之前分析内容类型(防止prompt注入)
|
||||
has_only_pics, has_text, pic_part, text_part = self._analyze_target_content(target)
|
||||
|
||||
|
||||
# 将[picid:xxx]替换为具体的图片描述
|
||||
target = self._replace_picids_with_descriptions(target)
|
||||
|
||||
@@ -592,7 +661,7 @@ class PrivateReplyer:
|
||||
timestamp=time.time(),
|
||||
limit=global_config.chat.max_context_size,
|
||||
)
|
||||
|
||||
|
||||
dialogue_prompt = build_readable_messages(
|
||||
message_list_before_now_long,
|
||||
replace_bot_name=True,
|
||||
@@ -635,16 +704,12 @@ class PrivateReplyer:
|
||||
show_actions=True,
|
||||
)
|
||||
|
||||
# 并行执行五个构建任务
|
||||
# 并行执行八个构建任务
|
||||
task_results = await asyncio.gather(
|
||||
self._time_and_run_task(
|
||||
self.build_expression_habits(chat_talking_prompt_short, target), "expression_habits"
|
||||
self.build_expression_habits(chat_talking_prompt_short, target, reply_reason), "expression_habits"
|
||||
),
|
||||
self._time_and_run_task(
|
||||
self.build_relation_info(chat_talking_prompt_short, sender), "relation_info"
|
||||
),
|
||||
self._time_and_run_task(self.build_memory_block(), "memory_block"),
|
||||
# self._time_and_run_task(self.build_memory_block(message_list_before_short, target), "memory_block"),
|
||||
self._time_and_run_task(self.build_relation_info(chat_talking_prompt_short, sender), "relation_info"),
|
||||
self._time_and_run_task(
|
||||
self.build_tool_info(chat_talking_prompt_short, sender, target, enable_tool=enable_tool), "tool_info"
|
||||
),
|
||||
@@ -652,18 +717,24 @@ class PrivateReplyer:
|
||||
self._time_and_run_task(self.build_actions_prompt(available_actions, chosen_actions), "actions_info"),
|
||||
self._time_and_run_task(self.build_personality_prompt(), "personality_prompt"),
|
||||
self._time_and_run_task(self.build_mood_state_prompt(), "mood_state_prompt"),
|
||||
self._time_and_run_task(
|
||||
build_memory_retrieval_prompt(
|
||||
chat_talking_prompt_short, sender, target, self.chat_stream, self.tool_executor
|
||||
),
|
||||
"memory_retrieval",
|
||||
),
|
||||
)
|
||||
|
||||
# 任务名称中英文映射
|
||||
task_name_mapping = {
|
||||
"expression_habits": "选取表达方式",
|
||||
"relation_info": "感受关系",
|
||||
"memory_block": "回忆",
|
||||
"tool_info": "使用工具",
|
||||
"prompt_info": "获取知识",
|
||||
"actions_info": "动作信息",
|
||||
"personality_prompt": "人格信息",
|
||||
"mood_state_prompt": "情绪状态",
|
||||
"memory_retrieval": "记忆检索",
|
||||
}
|
||||
|
||||
# 处理结果
|
||||
@@ -687,14 +758,20 @@ class PrivateReplyer:
|
||||
expression_habits_block: str
|
||||
selected_expressions: List[int]
|
||||
relation_info: str = results_dict["relation_info"]
|
||||
memory_block: str = results_dict["memory_block"]
|
||||
tool_info: str = results_dict["tool_info"]
|
||||
prompt_info: str = results_dict["prompt_info"] # 直接使用格式化后的结果
|
||||
actions_info: str = results_dict["actions_info"]
|
||||
personality_prompt: str = results_dict["personality_prompt"]
|
||||
mood_state_prompt: str = results_dict["mood_state_prompt"]
|
||||
memory_retrieval: str = results_dict["memory_retrieval"]
|
||||
keywords_reaction_prompt = await self.build_keywords_reaction_prompt(target)
|
||||
|
||||
# 从 chosen_actions 中提取 planner 的整体思考理由
|
||||
planner_reasoning = ""
|
||||
if global_config.chat.include_planner_reasoning and reply_reason:
|
||||
# 如果没有 chosen_actions,使用 reply_reason 作为备选
|
||||
planner_reasoning = f"你的想法是:{reply_reason}"
|
||||
|
||||
if extra_info:
|
||||
extra_info_block = f"以下是你在回复时需要参考的信息,现在请你阅读以下内容,进行决策\n{extra_info}\n以上是你在回复时需要参考的信息,现在请你阅读以下内容,进行决策"
|
||||
else:
|
||||
@@ -718,6 +795,10 @@ class PrivateReplyer:
|
||||
# 其他情况(空内容等)
|
||||
reply_target_block = f"现在对方说的:{target}。引起了你的注意"
|
||||
|
||||
# 获取匹配的额外prompt
|
||||
chat_prompt_content = self.get_chat_prompt_for_chat(chat_id)
|
||||
chat_prompt_block = f"{chat_prompt_content}\n" if chat_prompt_content else ""
|
||||
|
||||
if global_config.bot.qq_account == user_id and platform == global_config.bot.platform:
|
||||
return await global_prompt_manager.format_prompt(
|
||||
"private_replyer_self_prompt",
|
||||
@@ -725,7 +806,6 @@ class PrivateReplyer:
|
||||
tool_info_block=tool_info,
|
||||
knowledge_prompt=prompt_info,
|
||||
mood_state=mood_state_prompt,
|
||||
memory_block=memory_block,
|
||||
relation_info_block=relation_info,
|
||||
extra_info_block=extra_info_block,
|
||||
identity=personality_prompt,
|
||||
@@ -738,6 +818,8 @@ class PrivateReplyer:
|
||||
reply_style=global_config.personality.reply_style,
|
||||
keywords_reaction_prompt=keywords_reaction_prompt,
|
||||
moderation_prompt=moderation_prompt_block,
|
||||
memory_retrieval=memory_retrieval,
|
||||
chat_prompt=chat_prompt_block,
|
||||
), selected_expressions
|
||||
else:
|
||||
return await global_prompt_manager.format_prompt(
|
||||
@@ -746,7 +828,6 @@ class PrivateReplyer:
|
||||
tool_info_block=tool_info,
|
||||
knowledge_prompt=prompt_info,
|
||||
mood_state=mood_state_prompt,
|
||||
memory_block=memory_block,
|
||||
relation_info_block=relation_info,
|
||||
extra_info_block=extra_info_block,
|
||||
identity=personality_prompt,
|
||||
@@ -758,6 +839,9 @@ class PrivateReplyer:
|
||||
keywords_reaction_prompt=keywords_reaction_prompt,
|
||||
moderation_prompt=moderation_prompt_block,
|
||||
sender_name=sender,
|
||||
memory_retrieval=memory_retrieval,
|
||||
chat_prompt=chat_prompt_block,
|
||||
planner_reasoning=planner_reasoning,
|
||||
), selected_expressions
|
||||
|
||||
async def build_prompt_rewrite_context(
|
||||
@@ -772,15 +856,13 @@ class PrivateReplyer:
|
||||
|
||||
sender, target = self._parse_reply_target(reply_to)
|
||||
target = replace_user_references(target, chat_stream.platform, replace_bot_name=True)
|
||||
|
||||
|
||||
# 在picid替换之前分析内容类型(防止prompt注入)
|
||||
has_only_pics, has_text, pic_part, text_part = self._analyze_target_content(target)
|
||||
|
||||
|
||||
# 将[picid:xxx]替换为具体的图片描述
|
||||
target = self._replace_picids_with_descriptions(target)
|
||||
|
||||
|
||||
|
||||
message_list_before_now_half = get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=chat_id,
|
||||
timestamp=time.time(),
|
||||
@@ -820,9 +902,7 @@ class PrivateReplyer:
|
||||
)
|
||||
elif has_text and pic_part:
|
||||
# 既有图片又有文字
|
||||
reply_target_block = (
|
||||
f"现在{sender}发送了图片:{pic_part},并说:{text_part}。引起了你的注意,你想要在群里发言或者回复这条消息。"
|
||||
)
|
||||
reply_target_block = f"现在{sender}发送了图片:{pic_part},并说:{text_part}。引起了你的注意,你想要在群里发言或者回复这条消息。"
|
||||
else:
|
||||
# 只包含文字
|
||||
reply_target_block = (
|
||||
@@ -839,7 +919,9 @@ class PrivateReplyer:
|
||||
reply_target_block = f"现在{sender}发送的图片:{pic_part}。引起了你的注意,针对这条消息回复。"
|
||||
elif has_text and pic_part:
|
||||
# 既有图片又有文字
|
||||
reply_target_block = f"现在{sender}发送了图片:{pic_part},并说:{text_part}。引起了你的注意,针对这条消息回复。"
|
||||
reply_target_block = (
|
||||
f"现在{sender}发送了图片:{pic_part},并说:{text_part}。引起了你的注意,针对这条消息回复。"
|
||||
)
|
||||
else:
|
||||
# 只包含文字
|
||||
reply_target_block = f"现在{sender}说的:{text_part}。引起了你的注意,针对这条消息回复。"
|
||||
@@ -930,7 +1012,7 @@ class PrivateReplyer:
|
||||
content, (reasoning_content, model_name, tool_calls) = await self.express_model.generate_response_async(
|
||||
prompt
|
||||
)
|
||||
|
||||
|
||||
content = content.strip()
|
||||
|
||||
logger.info(f"使用 {model_name} 生成回复内容: {content}")
|
||||
@@ -950,6 +1032,10 @@ class PrivateReplyer:
|
||||
if not global_config.lpmm_knowledge.enable:
|
||||
logger.debug("LPMM知识库未启用,跳过获取知识库内容")
|
||||
return ""
|
||||
|
||||
if global_config.lpmm_knowledge.lpmm_mode == "agent":
|
||||
return ""
|
||||
|
||||
time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
|
||||
|
||||
bot_name = global_config.bot.nickname
|
||||
@@ -1022,6 +1108,3 @@ def weighted_sample_no_replacement(items, weights, k) -> list:
|
||||
pool.pop(idx)
|
||||
break
|
||||
return selected
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -1,58 +1,54 @@
|
||||
|
||||
from src.chat.utils.prompt_builder import Prompt
|
||||
# from src.chat.memory_system.memory_activator import MemoryActivator
|
||||
|
||||
|
||||
|
||||
def init_replyer_prompt():
|
||||
Prompt("你正在qq群里聊天,下面是群里正在聊的内容:", "chat_target_group1")
|
||||
Prompt("你正在和{sender_name}聊天,这是你们之前聊的内容:", "chat_target_private1")
|
||||
Prompt("正在群里聊天", "chat_target_group2")
|
||||
Prompt("和{sender_name}聊天", "chat_target_private2")
|
||||
|
||||
|
||||
Prompt(
|
||||
"""{knowledge_prompt}{tool_info_block}{extra_info_block}
|
||||
{expression_habits_block}{memory_block}{question_block}
|
||||
|
||||
你正在qq群里聊天,下面是群里正在聊的内容:
|
||||
Prompt(
|
||||
"""{knowledge_prompt}{tool_info_block}{extra_info_block}
|
||||
{expression_habits_block}{memory_retrieval}
|
||||
|
||||
你正在qq群里聊天,下面是群里正在聊的内容,其中包含聊天记录和聊天中的图片
|
||||
其中标注 {bot_name}(你) 的发言是你自己的发言,请注意区分:
|
||||
{time_block}
|
||||
{dialogue_prompt}
|
||||
|
||||
{reply_target_block}。
|
||||
{planner_reasoning}
|
||||
{identity}
|
||||
你正在群里聊天,现在请你读读之前的聊天记录,然后给出日常且口语化的回复,平淡一些,{mood_state}
|
||||
{chat_prompt}你正在群里聊天,现在请你读读之前的聊天记录,然后给出日常且口语化的回复,平淡一些,{mood_state}
|
||||
尽量简短一些。{keywords_reaction_prompt}请注意把握聊天内容,不要回复的太有条理,可以有个性。
|
||||
{reply_style}
|
||||
请注意不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只输出一句回复内容就好。
|
||||
{moderation_prompt}不要输出多余内容(包括前后缀,冒号和引号,括号,表情包,at或 @等 )。请不要思考太长
|
||||
不要输出多余内容(包括前后缀,冒号和引号,括号,表情包,at或 @等 )。
|
||||
现在,你说:""",
|
||||
"replyer_prompt",
|
||||
)
|
||||
|
||||
|
||||
|
||||
Prompt(
|
||||
"""{knowledge_prompt}{tool_info_block}{extra_info_block}
|
||||
{expression_habits_block}{memory_block}
|
||||
"""{knowledge_prompt}{tool_info_block}{extra_info_block}
|
||||
{expression_habits_block}{memory_retrieval}
|
||||
|
||||
你正在和{sender_name}聊天,这是你们之前聊的内容:
|
||||
{time_block}
|
||||
{dialogue_prompt}
|
||||
|
||||
{reply_target_block}。
|
||||
{planner_reasoning}
|
||||
{identity}
|
||||
你正在和{sender_name}聊天,现在请你读读之前的聊天记录,然后给出日常且口语化的回复,平淡一些,{mood_state}
|
||||
{chat_prompt}你正在和{sender_name}聊天,现在请你读读之前的聊天记录,然后给出日常且口语化的回复,平淡一些,{mood_state}
|
||||
尽量简短一些。{keywords_reaction_prompt}请注意把握聊天内容,不要回复的太有条理,可以有个性。
|
||||
{reply_style}
|
||||
请注意不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只输出回复内容。
|
||||
{moderation_prompt}不要输出多余内容(包括前后缀,冒号和引号,括号,表情包,at或 @等 )。""",
|
||||
"private_replyer_prompt",
|
||||
)
|
||||
|
||||
|
||||
|
||||
Prompt(
|
||||
"""{knowledge_prompt}{tool_info_block}{extra_info_block}
|
||||
{expression_habits_block}{memory_block}
|
||||
"""{knowledge_prompt}{tool_info_block}{extra_info_block}
|
||||
{expression_habits_block}{memory_retrieval}
|
||||
|
||||
你正在和{sender_name}聊天,这是你们之前聊的内容:
|
||||
{time_block}
|
||||
@@ -61,10 +57,10 @@ def init_replyer_prompt():
|
||||
你现在想补充说明你刚刚自己的发言内容:{target},原因是{reason}
|
||||
请你根据聊天内容,组织一条新回复。注意,{target} 是刚刚你自己的发言,你要在这基础上进一步发言,请按照你自己的角度来继续进行回复。注意保持上下文的连贯性。{mood_state}
|
||||
{identity}
|
||||
尽量简短一些。{keywords_reaction_prompt}请注意把握聊天内容,不要回复的太有条理,可以有个性。
|
||||
{chat_prompt}尽量简短一些。{keywords_reaction_prompt}请注意把握聊天内容,不要回复的太有条理,可以有个性。
|
||||
{reply_style}
|
||||
请注意不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只输出回复内容。
|
||||
{moderation_prompt}不要输出多余内容(包括冒号和引号,括号,表情包,at或 @等 )。
|
||||
""",
|
||||
"private_replyer_self_prompt",
|
||||
)
|
||||
)
|
||||
|
||||
493
src/chat/utils/chat_history_summarizer.py
Normal file
493
src/chat/utils/chat_history_summarizer.py
Normal file
@@ -0,0 +1,493 @@
|
||||
"""
|
||||
聊天内容概括器
|
||||
用于累积、打包和压缩聊天记录
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from typing import List, Optional, Set
|
||||
from dataclasses import dataclass
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.plugin_system.apis import message_api
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages
|
||||
from src.person_info.person_info import Person
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
|
||||
logger = get_logger("chat_history_summarizer")
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageBatch:
|
||||
"""消息批次"""
|
||||
|
||||
messages: List[DatabaseMessages]
|
||||
start_time: float
|
||||
end_time: float
|
||||
is_preparing: bool = False # 是否处于准备结束模式
|
||||
|
||||
|
||||
class ChatHistorySummarizer:
|
||||
"""聊天内容概括器"""
|
||||
|
||||
def __init__(self, chat_id: str, check_interval: int = 60):
|
||||
"""
|
||||
初始化聊天内容概括器
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
check_interval: 定期检查间隔(秒),默认60秒
|
||||
"""
|
||||
self.chat_id = chat_id
|
||||
self._chat_display_name = self._get_chat_display_name()
|
||||
self.log_prefix = f"[{self._chat_display_name}]"
|
||||
|
||||
# 记录时间点,用于计算新消息
|
||||
self.last_check_time = time.time()
|
||||
|
||||
# 当前累积的消息批次
|
||||
self.current_batch: Optional[MessageBatch] = None
|
||||
|
||||
# LLM请求器,用于压缩聊天内容
|
||||
self.summarizer_llm = LLMRequest(
|
||||
model_set=model_config.model_task_config.utils, request_type="chat_history_summarizer"
|
||||
)
|
||||
|
||||
# 后台循环相关
|
||||
self.check_interval = check_interval # 检查间隔(秒)
|
||||
self._periodic_task: Optional[asyncio.Task] = None
|
||||
self._running = False
|
||||
|
||||
def _get_chat_display_name(self) -> str:
|
||||
"""获取聊天显示名称"""
|
||||
try:
|
||||
chat_name = get_chat_manager().get_stream_name(self.chat_id)
|
||||
if chat_name:
|
||||
return chat_name
|
||||
# 如果获取失败,使用简化的chat_id显示
|
||||
if len(self.chat_id) > 20:
|
||||
return f"{self.chat_id[:8]}..."
|
||||
return self.chat_id
|
||||
except Exception:
|
||||
# 如果获取失败,使用简化的chat_id显示
|
||||
if len(self.chat_id) > 20:
|
||||
return f"{self.chat_id[:8]}..."
|
||||
return self.chat_id
|
||||
|
||||
async def process(self, current_time: Optional[float] = None):
|
||||
"""
|
||||
处理聊天内容概括
|
||||
|
||||
Args:
|
||||
current_time: 当前时间戳,如果为None则使用time.time()
|
||||
"""
|
||||
if current_time is None:
|
||||
current_time = time.time()
|
||||
|
||||
try:
|
||||
# 获取从上次检查时间到当前时间的新消息
|
||||
new_messages = message_api.get_messages_by_time_in_chat(
|
||||
chat_id=self.chat_id,
|
||||
start_time=self.last_check_time,
|
||||
end_time=current_time,
|
||||
limit=0,
|
||||
limit_mode="latest",
|
||||
filter_mai=False, # 不过滤bot消息,因为需要检查bot是否发言
|
||||
filter_command=False,
|
||||
)
|
||||
|
||||
if not new_messages:
|
||||
# 没有新消息,检查是否需要打包
|
||||
if self.current_batch and self.current_batch.messages:
|
||||
await self._check_and_package(current_time)
|
||||
self.last_check_time = current_time
|
||||
return
|
||||
|
||||
logger.info(
|
||||
f"{self.log_prefix} 开始处理聊天概括,时间窗口: {self.last_check_time:.2f} -> {current_time:.2f}"
|
||||
)
|
||||
|
||||
# 有新消息,更新最后检查时间
|
||||
self.last_check_time = current_time
|
||||
|
||||
# 如果有当前批次,添加新消息
|
||||
if self.current_batch:
|
||||
before_count = len(self.current_batch.messages)
|
||||
self.current_batch.messages.extend(new_messages)
|
||||
self.current_batch.end_time = current_time
|
||||
logger.info(f"{self.log_prefix} 批次更新: {before_count} -> {len(self.current_batch.messages)} 条消息")
|
||||
else:
|
||||
# 创建新批次
|
||||
self.current_batch = MessageBatch(
|
||||
messages=new_messages,
|
||||
start_time=new_messages[0].time if new_messages else current_time,
|
||||
end_time=current_time,
|
||||
)
|
||||
logger.info(f"{self.log_prefix} 新建批次: {len(new_messages)} 条消息")
|
||||
|
||||
# 检查是否需要打包
|
||||
await self._check_and_package(current_time)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 处理聊天内容概括时出错: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
async def _check_and_package(self, current_time: float):
|
||||
"""检查是否需要打包"""
|
||||
if not self.current_batch or not self.current_batch.messages:
|
||||
return
|
||||
|
||||
messages = self.current_batch.messages
|
||||
message_count = len(messages)
|
||||
last_message_time = messages[-1].time if messages else current_time
|
||||
time_since_last_message = current_time - last_message_time
|
||||
|
||||
# 格式化时间差显示
|
||||
if time_since_last_message < 60:
|
||||
time_str = f"{time_since_last_message:.1f}秒"
|
||||
elif time_since_last_message < 3600:
|
||||
time_str = f"{time_since_last_message / 60:.1f}分钟"
|
||||
else:
|
||||
time_str = f"{time_since_last_message / 3600:.1f}小时"
|
||||
|
||||
preparing_status = "是" if self.current_batch.is_preparing else "否"
|
||||
|
||||
logger.info(
|
||||
f"{self.log_prefix} 批次状态检查 | 消息数: {message_count} | 距最后消息: {time_str} | 准备结束模式: {preparing_status}"
|
||||
)
|
||||
|
||||
# 检查打包条件
|
||||
should_package = False
|
||||
|
||||
# 条件1: 消息长度超过120,直接打包
|
||||
if message_count >= 120:
|
||||
should_package = True
|
||||
logger.info(f"{self.log_prefix} 触发打包条件: 消息数量达到 {message_count} 条(阈值: 120条)")
|
||||
|
||||
# 条件2: 最后一条消息的时间和当前时间差>600秒,直接打包
|
||||
elif time_since_last_message > 600:
|
||||
should_package = True
|
||||
logger.info(f"{self.log_prefix} 触发打包条件: 距最后消息 {time_str}(阈值: 10分钟)")
|
||||
|
||||
# 条件3: 消息长度超过100,进入准备结束模式
|
||||
elif message_count > 100:
|
||||
if not self.current_batch.is_preparing:
|
||||
self.current_batch.is_preparing = True
|
||||
logger.info(f"{self.log_prefix} 消息数量 {message_count} 条超过阈值(100条),进入准备结束模式")
|
||||
|
||||
# 在准备结束模式下,如果最后一条消息的时间和当前时间差>10秒,就打包
|
||||
if time_since_last_message > 10:
|
||||
should_package = True
|
||||
logger.info(f"{self.log_prefix} 触发打包条件: 准备结束模式下,距最后消息 {time_str}(阈值: 10秒)")
|
||||
|
||||
if should_package:
|
||||
await self._package_and_store()
|
||||
|
||||
async def _package_and_store(self):
|
||||
"""打包并存储聊天记录"""
|
||||
if not self.current_batch or not self.current_batch.messages:
|
||||
return
|
||||
|
||||
messages = self.current_batch.messages
|
||||
start_time = self.current_batch.start_time
|
||||
end_time = self.current_batch.end_time
|
||||
|
||||
logger.info(
|
||||
f"{self.log_prefix} 开始打包批次 | 消息数: {len(messages)} | 时间范围: {start_time:.2f} - {end_time:.2f}"
|
||||
)
|
||||
|
||||
# 检查是否有bot发言
|
||||
# 第一条消息前推600s到最后一条消息的时间内
|
||||
check_start_time = max(start_time - 600, 0)
|
||||
check_end_time = end_time
|
||||
|
||||
# 使用包含边界的时间范围查询
|
||||
bot_messages = message_api.get_messages_by_time_in_chat_inclusive(
|
||||
chat_id=self.chat_id,
|
||||
start_time=check_start_time,
|
||||
end_time=check_end_time,
|
||||
limit=0,
|
||||
limit_mode="latest",
|
||||
filter_mai=False,
|
||||
filter_command=False,
|
||||
)
|
||||
|
||||
# 检查是否有bot的发言
|
||||
has_bot_message = False
|
||||
bot_user_id = str(global_config.bot.qq_account)
|
||||
for msg in bot_messages:
|
||||
if msg.user_info.user_id == bot_user_id:
|
||||
has_bot_message = True
|
||||
break
|
||||
|
||||
if not has_bot_message:
|
||||
logger.info(
|
||||
f"{self.log_prefix} 批次内无Bot发言,丢弃批次 | 检查时间范围: {check_start_time:.2f} - {check_end_time:.2f}"
|
||||
)
|
||||
self.current_batch = None
|
||||
return
|
||||
|
||||
# 有bot发言,进行压缩和存储
|
||||
try:
|
||||
# 构建对话原文
|
||||
original_text = build_readable_messages(
|
||||
messages=messages,
|
||||
replace_bot_name=True,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
read_mark=0.0,
|
||||
truncate=False,
|
||||
show_actions=False,
|
||||
)
|
||||
|
||||
# 获取参与的所有人的昵称
|
||||
participants_set: Set[str] = set()
|
||||
for msg in messages:
|
||||
# 使用 msg.user_platform(扁平化字段)或 msg.user_info.platform
|
||||
platform = (
|
||||
getattr(msg, "user_platform", None)
|
||||
or (msg.user_info.platform if msg.user_info else None)
|
||||
or msg.chat_info.platform
|
||||
)
|
||||
person = Person(platform=platform, user_id=msg.user_info.user_id)
|
||||
person_name = person.person_name
|
||||
if person_name:
|
||||
participants_set.add(person_name)
|
||||
participants = list(participants_set)
|
||||
logger.info(f"{self.log_prefix} 批次参与者: {', '.join(participants) if participants else '未知'}")
|
||||
|
||||
# 使用LLM压缩聊天内容
|
||||
success, theme, keywords, summary = await self._compress_with_llm(original_text)
|
||||
|
||||
if not success:
|
||||
logger.warning(f"{self.log_prefix} LLM压缩失败,不存储到数据库 | 消息数: {len(messages)}")
|
||||
# 清空当前批次,避免重复处理
|
||||
self.current_batch = None
|
||||
return
|
||||
|
||||
logger.info(
|
||||
f"{self.log_prefix} LLM压缩完成 | 主题: {theme} | 关键词数: {len(keywords)} | 概括长度: {len(summary)} 字"
|
||||
)
|
||||
|
||||
# 存储到数据库
|
||||
await self._store_to_database(
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
original_text=original_text,
|
||||
participants=participants,
|
||||
theme=theme,
|
||||
keywords=keywords,
|
||||
summary=summary,
|
||||
)
|
||||
|
||||
logger.info(f"{self.log_prefix} 成功打包并存储聊天记录 | 消息数: {len(messages)} | 主题: {theme}")
|
||||
|
||||
# 清空当前批次
|
||||
self.current_batch = None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 打包和存储聊天记录时出错: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
# 出错时也清空批次,避免重复处理
|
||||
self.current_batch = None
|
||||
|
||||
async def _compress_with_llm(self, original_text: str) -> tuple[bool, str, List[str], str]:
|
||||
"""
|
||||
使用LLM压缩聊天内容
|
||||
|
||||
Returns:
|
||||
tuple[bool, str, List[str], str]: (是否成功, 主题, 关键词列表, 概括)
|
||||
"""
|
||||
prompt = f"""请对以下聊天记录进行概括,提取以下信息:
|
||||
|
||||
1. 主题:这段对话的主要内容,一个简短的标题(不超过20字)
|
||||
2. 关键词:这段对话的关键词,用列表形式返回(3-10个关键词)
|
||||
3. 概括:对这段话的平文本概括(50-200字)
|
||||
|
||||
请以JSON格式返回,格式如下:
|
||||
{{
|
||||
"theme": "主题",
|
||||
"keywords": ["关键词1", "关键词2", ...],
|
||||
"summary": "概括内容"
|
||||
}}
|
||||
|
||||
聊天记录:
|
||||
{original_text}
|
||||
|
||||
请直接返回JSON,不要包含其他内容。"""
|
||||
|
||||
try:
|
||||
response, _ = await self.summarizer_llm.generate_response_async(
|
||||
prompt=prompt,
|
||||
temperature=0.3,
|
||||
max_tokens=500,
|
||||
)
|
||||
|
||||
# 解析JSON响应
|
||||
import re
|
||||
|
||||
# 移除可能的markdown代码块标记
|
||||
json_str = response.strip()
|
||||
json_str = re.sub(r"^```json\s*", "", json_str, flags=re.MULTILINE)
|
||||
json_str = re.sub(r"^```\s*", "", json_str, flags=re.MULTILINE)
|
||||
json_str = json_str.strip()
|
||||
|
||||
# 尝试找到JSON对象的开始和结束位置
|
||||
# 查找第一个 { 和最后一个匹配的 }
|
||||
start_idx = json_str.find("{")
|
||||
if start_idx == -1:
|
||||
raise ValueError("未找到JSON对象开始标记")
|
||||
|
||||
# 从后往前查找最后一个 }
|
||||
end_idx = json_str.rfind("}")
|
||||
if end_idx == -1 or end_idx <= start_idx:
|
||||
raise ValueError("未找到JSON对象结束标记")
|
||||
|
||||
# 提取JSON字符串
|
||||
json_str = json_str[start_idx : end_idx + 1]
|
||||
|
||||
# 尝试解析JSON
|
||||
try:
|
||||
result = json.loads(json_str)
|
||||
except json.JSONDecodeError:
|
||||
# 如果解析失败,尝试修复字符串值中的中文引号
|
||||
# 简单方法:将字符串值中的中文引号替换为转义的英文引号
|
||||
# 使用状态机方法:遍历字符串,在字符串值内部替换中文引号
|
||||
fixed_chars = []
|
||||
in_string = False
|
||||
escape_next = False
|
||||
i = 0
|
||||
while i < len(json_str):
|
||||
char = json_str[i]
|
||||
if escape_next:
|
||||
fixed_chars.append(char)
|
||||
escape_next = False
|
||||
elif char == "\\":
|
||||
fixed_chars.append(char)
|
||||
escape_next = True
|
||||
elif char == '"' and not escape_next:
|
||||
fixed_chars.append(char)
|
||||
in_string = not in_string
|
||||
elif in_string and (char == '"' or char == '"'):
|
||||
# 在字符串值内部,将中文引号替换为转义的英文引号
|
||||
fixed_chars.append('\\"')
|
||||
else:
|
||||
fixed_chars.append(char)
|
||||
i += 1
|
||||
|
||||
json_str = "".join(fixed_chars)
|
||||
# 再次尝试解析
|
||||
result = json.loads(json_str)
|
||||
|
||||
theme = result.get("theme", "未命名对话")
|
||||
keywords = result.get("keywords", [])
|
||||
summary = result.get("summary", "无概括")
|
||||
|
||||
# 确保keywords是列表
|
||||
if isinstance(keywords, str):
|
||||
keywords = [keywords]
|
||||
|
||||
return True, theme, keywords, summary
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} LLM压缩聊天内容时出错: {e}")
|
||||
logger.error(f"{self.log_prefix} LLM响应: {response if 'response' in locals() else 'N/A'}")
|
||||
# 返回失败标志和默认值
|
||||
return False, "未命名对话", [], "压缩失败,无法生成概括"
|
||||
|
||||
async def _store_to_database(
|
||||
self,
|
||||
start_time: float,
|
||||
end_time: float,
|
||||
original_text: str,
|
||||
participants: List[str],
|
||||
theme: str,
|
||||
keywords: List[str],
|
||||
summary: str,
|
||||
):
|
||||
"""存储到数据库"""
|
||||
try:
|
||||
from src.common.database.database_model import ChatHistory
|
||||
from src.plugin_system.apis import database_api
|
||||
|
||||
# 准备数据
|
||||
data = {
|
||||
"chat_id": self.chat_id,
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
"original_text": original_text,
|
||||
"participants": json.dumps(participants, ensure_ascii=False),
|
||||
"theme": theme,
|
||||
"keywords": json.dumps(keywords, ensure_ascii=False),
|
||||
"summary": summary,
|
||||
"count": 0,
|
||||
}
|
||||
|
||||
# 使用db_save存储(使用start_time和chat_id作为唯一标识)
|
||||
# 由于可能有多条记录,我们使用组合键,但peewee不支持,所以使用start_time作为唯一标识
|
||||
# 但为了避免冲突,我们使用组合键:chat_id + start_time
|
||||
# 由于peewee不支持组合键,我们直接创建新记录(不提供key_field和key_value)
|
||||
saved_record = await database_api.db_save(
|
||||
ChatHistory,
|
||||
data=data,
|
||||
)
|
||||
|
||||
if saved_record:
|
||||
logger.debug(f"{self.log_prefix} 成功存储聊天历史记录到数据库")
|
||||
else:
|
||||
logger.warning(f"{self.log_prefix} 存储聊天历史记录到数据库失败")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 存储到数据库时出错: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
raise
|
||||
|
||||
async def start(self):
|
||||
"""启动后台定期检查循环"""
|
||||
if self._running:
|
||||
logger.warning(f"{self.log_prefix} 后台循环已在运行,无需重复启动")
|
||||
return
|
||||
|
||||
self._running = True
|
||||
self._periodic_task = asyncio.create_task(self._periodic_check_loop())
|
||||
logger.info(f"{self.log_prefix} 已启动后台定期检查循环 | 检查间隔: {self.check_interval}秒")
|
||||
|
||||
async def stop(self):
|
||||
"""停止后台定期检查循环"""
|
||||
self._running = False
|
||||
if self._periodic_task:
|
||||
self._periodic_task.cancel()
|
||||
try:
|
||||
await self._periodic_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._periodic_task = None
|
||||
logger.info(f"{self.log_prefix} 已停止后台定期检查循环")
|
||||
|
||||
async def _periodic_check_loop(self):
|
||||
"""后台定期检查循环"""
|
||||
try:
|
||||
while self._running:
|
||||
# 执行一次检查
|
||||
await self.process()
|
||||
|
||||
# 等待指定间隔后再次检查
|
||||
await asyncio.sleep(self.check_interval)
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"{self.log_prefix} 后台检查循环被取消")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 后台检查循环出错: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
self._running = False
|
||||
@@ -2,7 +2,7 @@ import time
|
||||
import random
|
||||
import re
|
||||
|
||||
from typing import List, Dict, Any, Tuple, Optional, Callable, Iterable
|
||||
from typing import List, Dict, Any, Tuple, Optional, Callable
|
||||
from rich.traceback import install
|
||||
|
||||
from src.config.config import global_config
|
||||
@@ -568,7 +568,6 @@ def build_readable_actions(actions: List[DatabaseActionRecords], mode: str = "re
|
||||
output_lines = []
|
||||
current_time = time.time()
|
||||
|
||||
|
||||
for action in actions:
|
||||
action_time = action.time or current_time
|
||||
action_name = action.action_name or "未知动作"
|
||||
@@ -595,7 +594,6 @@ def build_readable_actions(actions: List[DatabaseActionRecords], mode: str = "re
|
||||
|
||||
line = f"{time_ago_str},你使用了“{action_name}”,具体内容是:“{action_prompt_display}”"
|
||||
output_lines.append(line)
|
||||
|
||||
|
||||
return "\n".join(output_lines)
|
||||
|
||||
@@ -674,7 +672,7 @@ def build_readable_messages(
|
||||
messages: 消息列表
|
||||
replace_bot_name: 是否替换机器人名称为"你"
|
||||
merge_messages: 是否合并连续消息
|
||||
timestamp_mode: 时间戳显示模式
|
||||
timestamp_mode: 时间戳显示模式,"normal"或"normal_no_YMD"或"relative"
|
||||
read_mark: 已读标记时间戳
|
||||
truncate: 是否截断长消息
|
||||
show_actions: 是否显示动作记录
|
||||
@@ -936,7 +934,6 @@ async def build_anonymous_messages(messages: List[DatabaseMessages]) -> str:
|
||||
return formatted_string
|
||||
|
||||
|
||||
|
||||
async def get_person_id_list(messages: List[Dict[str, Any]]) -> List[str]:
|
||||
"""
|
||||
从消息列表中提取不重复的 person_id 列表 (忽略机器人自身)。
|
||||
|
||||
362
src/chat/utils/memory_forget_task.py
Normal file
362
src/chat/utils/memory_forget_task.py
Normal file
@@ -0,0 +1,362 @@
|
||||
"""
|
||||
记忆遗忘任务
|
||||
每5分钟进行一次遗忘检查,根据不同的遗忘阶段删除记忆
|
||||
"""
|
||||
|
||||
import time
|
||||
import random
|
||||
from typing import List
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import ChatHistory
|
||||
from src.manager.async_task_manager import AsyncTask
|
||||
|
||||
logger = get_logger("memory_forget_task")
|
||||
|
||||
|
||||
class MemoryForgetTask(AsyncTask):
|
||||
"""记忆遗忘任务,每5分钟执行一次"""
|
||||
|
||||
def __init__(self):
|
||||
# 每5分钟执行一次(300秒)
|
||||
super().__init__(task_name="Memory Forget Task", wait_before_start=0, run_interval=300)
|
||||
|
||||
async def run(self):
|
||||
"""执行遗忘检查"""
|
||||
try:
|
||||
current_time = time.time()
|
||||
logger.info("[记忆遗忘] 开始遗忘检查...")
|
||||
|
||||
# 执行4个阶段的遗忘检查
|
||||
await self._forget_stage_1(current_time)
|
||||
await self._forget_stage_2(current_time)
|
||||
await self._forget_stage_3(current_time)
|
||||
await self._forget_stage_4(current_time)
|
||||
|
||||
logger.info("[记忆遗忘] 遗忘检查完成")
|
||||
except Exception as e:
|
||||
logger.error(f"[记忆遗忘] 执行遗忘检查时出错: {e}", exc_info=True)
|
||||
|
||||
async def _forget_stage_1(self, current_time: float):
|
||||
"""
|
||||
第一次遗忘检查:
|
||||
搜集所有:记忆还未被遗忘检查过(forget_times=0),且已经是30分钟之外的记忆
|
||||
取count最高25%和最低25%,删除,然后标记被遗忘检查次数为1
|
||||
"""
|
||||
try:
|
||||
# 30分钟 = 1800秒
|
||||
time_threshold = current_time - 1800
|
||||
|
||||
# 查询符合条件的记忆:forget_times=0 且 end_time < time_threshold
|
||||
candidates = list(
|
||||
ChatHistory.select().where((ChatHistory.forget_times == 0) & (ChatHistory.end_time < time_threshold))
|
||||
)
|
||||
|
||||
if not candidates:
|
||||
logger.debug("[记忆遗忘-阶段1] 没有符合条件的记忆")
|
||||
return
|
||||
|
||||
logger.info(f"[记忆遗忘-阶段1] 找到 {len(candidates)} 条符合条件的记忆")
|
||||
|
||||
# 按count排序
|
||||
candidates.sort(key=lambda x: x.count, reverse=True)
|
||||
|
||||
# 计算要删除的数量(最高25%和最低25%)
|
||||
total_count = len(candidates)
|
||||
delete_count = int(total_count * 0.25) # 25%
|
||||
|
||||
if delete_count == 0:
|
||||
logger.debug("[记忆遗忘-阶段1] 删除数量为0,跳过")
|
||||
return
|
||||
|
||||
# 选择要删除的记录(处理count相同的情况:随机选择)
|
||||
to_delete = []
|
||||
to_delete.extend(self._handle_same_count_random(candidates, delete_count, "high"))
|
||||
to_delete.extend(self._handle_same_count_random(candidates, delete_count, "low"))
|
||||
|
||||
# 去重(避免重复删除),使用id去重
|
||||
seen_ids = set()
|
||||
unique_to_delete = []
|
||||
for record in to_delete:
|
||||
if record.id not in seen_ids:
|
||||
seen_ids.add(record.id)
|
||||
unique_to_delete.append(record)
|
||||
to_delete = unique_to_delete
|
||||
|
||||
# 删除记录并更新forget_times
|
||||
deleted_count = 0
|
||||
for record in to_delete:
|
||||
try:
|
||||
record.delete_instance()
|
||||
deleted_count += 1
|
||||
except Exception as e:
|
||||
logger.error(f"[记忆遗忘-阶段1] 删除记录失败: {e}")
|
||||
|
||||
# 更新剩余记录的forget_times为1
|
||||
to_delete_ids = {r.id for r in to_delete}
|
||||
remaining = [r for r in candidates if r.id not in to_delete_ids]
|
||||
if remaining:
|
||||
# 批量更新
|
||||
ids_to_update = [r.id for r in remaining]
|
||||
ChatHistory.update(forget_times=1).where(ChatHistory.id.in_(ids_to_update)).execute()
|
||||
|
||||
logger.info(
|
||||
f"[记忆遗忘-阶段1] 完成:删除了 {deleted_count} 条记忆,更新了 {len(remaining)} 条记忆的forget_times为1"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[记忆遗忘-阶段1] 执行失败: {e}", exc_info=True)
|
||||
|
||||
async def _forget_stage_2(self, current_time: float):
|
||||
"""
|
||||
第二次遗忘检查:
|
||||
搜集所有:记忆遗忘检查为1,且已经是8小时之外的记忆
|
||||
取count最高7%和最低7%,删除,然后标记被遗忘检查次数为2
|
||||
"""
|
||||
try:
|
||||
# 8小时 = 28800秒
|
||||
time_threshold = current_time - 28800
|
||||
|
||||
# 查询符合条件的记忆:forget_times=1 且 end_time < time_threshold
|
||||
candidates = list(
|
||||
ChatHistory.select().where((ChatHistory.forget_times == 1) & (ChatHistory.end_time < time_threshold))
|
||||
)
|
||||
|
||||
if not candidates:
|
||||
logger.debug("[记忆遗忘-阶段2] 没有符合条件的记忆")
|
||||
return
|
||||
|
||||
logger.info(f"[记忆遗忘-阶段2] 找到 {len(candidates)} 条符合条件的记忆")
|
||||
|
||||
# 按count排序
|
||||
candidates.sort(key=lambda x: x.count, reverse=True)
|
||||
|
||||
# 计算要删除的数量(最高7%和最低7%)
|
||||
total_count = len(candidates)
|
||||
delete_count = int(total_count * 0.07) # 7%
|
||||
|
||||
if delete_count == 0:
|
||||
logger.debug("[记忆遗忘-阶段2] 删除数量为0,跳过")
|
||||
return
|
||||
|
||||
# 选择要删除的记录
|
||||
to_delete = []
|
||||
to_delete.extend(self._handle_same_count_random(candidates, delete_count, "high"))
|
||||
to_delete.extend(self._handle_same_count_random(candidates, delete_count, "low"))
|
||||
|
||||
# 去重
|
||||
to_delete = list(set(to_delete))
|
||||
|
||||
# 删除记录
|
||||
deleted_count = 0
|
||||
for record in to_delete:
|
||||
try:
|
||||
record.delete_instance()
|
||||
deleted_count += 1
|
||||
except Exception as e:
|
||||
logger.error(f"[记忆遗忘-阶段2] 删除记录失败: {e}")
|
||||
|
||||
# 更新剩余记录的forget_times为2
|
||||
to_delete_ids = {r.id for r in to_delete}
|
||||
remaining = [r for r in candidates if r.id not in to_delete_ids]
|
||||
if remaining:
|
||||
ids_to_update = [r.id for r in remaining]
|
||||
ChatHistory.update(forget_times=2).where(ChatHistory.id.in_(ids_to_update)).execute()
|
||||
|
||||
logger.info(
|
||||
f"[记忆遗忘-阶段2] 完成:删除了 {deleted_count} 条记忆,更新了 {len(remaining)} 条记忆的forget_times为2"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[记忆遗忘-阶段2] 执行失败: {e}", exc_info=True)
|
||||
|
||||
async def _forget_stage_3(self, current_time: float):
|
||||
"""
|
||||
第三次遗忘检查:
|
||||
搜集所有:记忆遗忘检查为2,且已经是48小时之外的记忆
|
||||
取count最高5%和最低5%,删除,然后标记被遗忘检查次数为3
|
||||
"""
|
||||
try:
|
||||
# 48小时 = 172800秒
|
||||
time_threshold = current_time - 172800
|
||||
|
||||
# 查询符合条件的记忆:forget_times=2 且 end_time < time_threshold
|
||||
candidates = list(
|
||||
ChatHistory.select().where((ChatHistory.forget_times == 2) & (ChatHistory.end_time < time_threshold))
|
||||
)
|
||||
|
||||
if not candidates:
|
||||
logger.debug("[记忆遗忘-阶段3] 没有符合条件的记忆")
|
||||
return
|
||||
|
||||
logger.info(f"[记忆遗忘-阶段3] 找到 {len(candidates)} 条符合条件的记忆")
|
||||
|
||||
# 按count排序
|
||||
candidates.sort(key=lambda x: x.count, reverse=True)
|
||||
|
||||
# 计算要删除的数量(最高5%和最低5%)
|
||||
total_count = len(candidates)
|
||||
delete_count = int(total_count * 0.05) # 5%
|
||||
|
||||
if delete_count == 0:
|
||||
logger.debug("[记忆遗忘-阶段3] 删除数量为0,跳过")
|
||||
return
|
||||
|
||||
# 选择要删除的记录
|
||||
to_delete = []
|
||||
to_delete.extend(self._handle_same_count_random(candidates, delete_count, "high"))
|
||||
to_delete.extend(self._handle_same_count_random(candidates, delete_count, "low"))
|
||||
|
||||
# 去重
|
||||
to_delete = list(set(to_delete))
|
||||
|
||||
# 删除记录
|
||||
deleted_count = 0
|
||||
for record in to_delete:
|
||||
try:
|
||||
record.delete_instance()
|
||||
deleted_count += 1
|
||||
except Exception as e:
|
||||
logger.error(f"[记忆遗忘-阶段3] 删除记录失败: {e}")
|
||||
|
||||
# 更新剩余记录的forget_times为3
|
||||
to_delete_ids = {r.id for r in to_delete}
|
||||
remaining = [r for r in candidates if r.id not in to_delete_ids]
|
||||
if remaining:
|
||||
ids_to_update = [r.id for r in remaining]
|
||||
ChatHistory.update(forget_times=3).where(ChatHistory.id.in_(ids_to_update)).execute()
|
||||
|
||||
logger.info(
|
||||
f"[记忆遗忘-阶段3] 完成:删除了 {deleted_count} 条记忆,更新了 {len(remaining)} 条记忆的forget_times为3"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[记忆遗忘-阶段3] 执行失败: {e}", exc_info=True)
|
||||
|
||||
async def _forget_stage_4(self, current_time: float):
|
||||
"""
|
||||
第四次遗忘检查:
|
||||
搜集所有:记忆遗忘检查为3,且已经是7天之外的记忆
|
||||
取count最高2%和最低2%,删除,然后标记被遗忘检查次数为4
|
||||
"""
|
||||
try:
|
||||
# 7天 = 604800秒
|
||||
time_threshold = current_time - 604800
|
||||
|
||||
# 查询符合条件的记忆:forget_times=3 且 end_time < time_threshold
|
||||
candidates = list(
|
||||
ChatHistory.select().where((ChatHistory.forget_times == 3) & (ChatHistory.end_time < time_threshold))
|
||||
)
|
||||
|
||||
if not candidates:
|
||||
logger.debug("[记忆遗忘-阶段4] 没有符合条件的记忆")
|
||||
return
|
||||
|
||||
logger.info(f"[记忆遗忘-阶段4] 找到 {len(candidates)} 条符合条件的记忆")
|
||||
|
||||
# 按count排序
|
||||
candidates.sort(key=lambda x: x.count, reverse=True)
|
||||
|
||||
# 计算要删除的数量(最高2%和最低2%)
|
||||
total_count = len(candidates)
|
||||
delete_count = int(total_count * 0.02) # 2%
|
||||
|
||||
if delete_count == 0:
|
||||
logger.debug("[记忆遗忘-阶段4] 删除数量为0,跳过")
|
||||
return
|
||||
|
||||
# 选择要删除的记录
|
||||
to_delete = []
|
||||
to_delete.extend(self._handle_same_count_random(candidates, delete_count, "high"))
|
||||
to_delete.extend(self._handle_same_count_random(candidates, delete_count, "low"))
|
||||
|
||||
# 去重
|
||||
to_delete = list(set(to_delete))
|
||||
|
||||
# 删除记录
|
||||
deleted_count = 0
|
||||
for record in to_delete:
|
||||
try:
|
||||
record.delete_instance()
|
||||
deleted_count += 1
|
||||
except Exception as e:
|
||||
logger.error(f"[记忆遗忘-阶段4] 删除记录失败: {e}")
|
||||
|
||||
# 更新剩余记录的forget_times为4
|
||||
to_delete_ids = {r.id for r in to_delete}
|
||||
remaining = [r for r in candidates if r.id not in to_delete_ids]
|
||||
if remaining:
|
||||
ids_to_update = [r.id for r in remaining]
|
||||
ChatHistory.update(forget_times=4).where(ChatHistory.id.in_(ids_to_update)).execute()
|
||||
|
||||
logger.info(
|
||||
f"[记忆遗忘-阶段4] 完成:删除了 {deleted_count} 条记忆,更新了 {len(remaining)} 条记忆的forget_times为4"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[记忆遗忘-阶段4] 执行失败: {e}", exc_info=True)
|
||||
|
||||
def _handle_same_count_random(
|
||||
self, candidates: List[ChatHistory], delete_count: int, mode: str
|
||||
) -> List[ChatHistory]:
|
||||
"""
|
||||
处理count相同的情况,随机选择要删除的记录
|
||||
|
||||
Args:
|
||||
candidates: 候选记录列表(已按count排序)
|
||||
delete_count: 要删除的数量
|
||||
mode: "high" 表示选择最高count的记录,"low" 表示选择最低count的记录
|
||||
|
||||
Returns:
|
||||
要删除的记录列表
|
||||
"""
|
||||
if not candidates or delete_count == 0:
|
||||
return []
|
||||
|
||||
to_delete = []
|
||||
|
||||
if mode == "high":
|
||||
# 从最高count开始选择
|
||||
start_idx = 0
|
||||
while start_idx < len(candidates) and len(to_delete) < delete_count:
|
||||
# 找到所有count相同的记录
|
||||
current_count = candidates[start_idx].count
|
||||
same_count_records = []
|
||||
idx = start_idx
|
||||
while idx < len(candidates) and candidates[idx].count == current_count:
|
||||
same_count_records.append(candidates[idx])
|
||||
idx += 1
|
||||
|
||||
# 如果相同count的记录数量 <= 还需要删除的数量,全部选择
|
||||
needed = delete_count - len(to_delete)
|
||||
if len(same_count_records) <= needed:
|
||||
to_delete.extend(same_count_records)
|
||||
else:
|
||||
# 随机选择需要的数量
|
||||
to_delete.extend(random.sample(same_count_records, needed))
|
||||
|
||||
start_idx = idx
|
||||
|
||||
else: # mode == "low"
|
||||
# 从最低count开始选择
|
||||
start_idx = len(candidates) - 1
|
||||
while start_idx >= 0 and len(to_delete) < delete_count:
|
||||
# 找到所有count相同的记录
|
||||
current_count = candidates[start_idx].count
|
||||
same_count_records = []
|
||||
idx = start_idx
|
||||
while idx >= 0 and candidates[idx].count == current_count:
|
||||
same_count_records.append(candidates[idx])
|
||||
idx -= 1
|
||||
|
||||
# 如果相同count的记录数量 <= 还需要删除的数量,全部选择
|
||||
needed = delete_count - len(to_delete)
|
||||
if len(same_count_records) <= needed:
|
||||
to_delete.extend(same_count_records)
|
||||
else:
|
||||
# 随机选择需要的数量
|
||||
to_delete.extend(random.sample(same_count_records, needed))
|
||||
|
||||
start_idx = idx
|
||||
|
||||
return to_delete
|
||||
@@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import json
|
||||
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timedelta
|
||||
@@ -10,6 +11,7 @@ from src.common.database.database import db
|
||||
from src.common.database.database_model import OnlineTime, LLMUsage, Messages
|
||||
from src.manager.async_task_manager import AsyncTask
|
||||
from src.manager.local_store_manager import local_storage
|
||||
from src.config.config import global_config
|
||||
|
||||
logger = get_logger("maibot_statistic")
|
||||
|
||||
@@ -51,6 +53,7 @@ STD_TIME_COST_BY_MODULE = "std_time_costs_by_module"
|
||||
ONLINE_TIME = "online_time"
|
||||
TOTAL_MSG_CNT = "total_messages"
|
||||
MSG_CNT_BY_CHAT = "messages_by_chat"
|
||||
TOTAL_REPLY_CNT = "total_replies"
|
||||
|
||||
|
||||
class OnlineTimeRecordTask(AsyncTask):
|
||||
@@ -134,6 +137,37 @@ def _format_online_time(online_seconds: int) -> str:
|
||||
return f"{minutes}分钟{seconds}秒"
|
||||
|
||||
|
||||
def _format_large_number(num: float | int, html: bool = False) -> str:
|
||||
"""
|
||||
格式化大数字,使用K后缀节省空间(大于9999时)
|
||||
:param num: 要格式化的数字
|
||||
:param html: 是否用于HTML输出(如果是,K会着色)
|
||||
:return: 格式化后的字符串,如 12K, 1.3K, 120K
|
||||
"""
|
||||
if num >= 10000:
|
||||
# 大于等于10000,使用K后缀
|
||||
value = num / 1000.0
|
||||
if value >= 10:
|
||||
number_part = str(int(value))
|
||||
k_suffix = "K"
|
||||
else:
|
||||
number_part = f"{value:.1f}"
|
||||
k_suffix = "K"
|
||||
|
||||
if html:
|
||||
# HTML输出:K着色为主题色并加粗大写
|
||||
return f"{number_part}<span style='color: #8b5cf6; font-weight: bold;'>K</span>"
|
||||
else:
|
||||
# 控制台输出:纯文本,K大写
|
||||
return f"{number_part}{k_suffix}"
|
||||
else:
|
||||
# 小于10000,直接显示
|
||||
if isinstance(num, float):
|
||||
return f"{num:.1f}" if num != int(num) else str(int(num))
|
||||
else:
|
||||
return str(num)
|
||||
|
||||
|
||||
class StatisticOutputTask(AsyncTask):
|
||||
"""统计输出任务"""
|
||||
|
||||
@@ -165,11 +199,13 @@ class StatisticOutputTask(AsyncTask):
|
||||
|
||||
self.stat_period: List[Tuple[str, timedelta, str]] = [
|
||||
("all_time", now - deploy_time, "自部署以来"), # 必须保留"all_time"
|
||||
("last_7_days", timedelta(days=7), "最近7天"),
|
||||
("last_3_days", timedelta(days=3), "最近3天"),
|
||||
("last_24_hours", timedelta(days=1), "最近24小时"),
|
||||
("last_3_hours", timedelta(hours=3), "最近3小时"),
|
||||
("last_hour", timedelta(hours=1), "最近1小时"),
|
||||
("last_30_days", timedelta(days=30), "近30天"),
|
||||
("last_7_days", timedelta(days=7), "近7天"),
|
||||
("last_3_days", timedelta(days=3), "近3天"),
|
||||
("last_24_hours", timedelta(days=1), "近1天"),
|
||||
("last_3_hours", timedelta(hours=3), "近3小时"),
|
||||
("last_hour", timedelta(hours=1), "近1小时"),
|
||||
("last_15_minutes", timedelta(minutes=15), "近15分钟"),
|
||||
]
|
||||
"""
|
||||
统计时间段 [(统计名称, 统计时间段, 统计描述), ...]
|
||||
@@ -462,10 +498,18 @@ class StatisticOutputTask(AsyncTask):
|
||||
period_key: {
|
||||
TOTAL_MSG_CNT: 0,
|
||||
MSG_CNT_BY_CHAT: defaultdict(int),
|
||||
TOTAL_REPLY_CNT: 0,
|
||||
}
|
||||
for period_key, _ in collect_period
|
||||
}
|
||||
|
||||
# 获取bot的QQ账号
|
||||
bot_qq_account = (
|
||||
str(global_config.bot.qq_account)
|
||||
if hasattr(global_config, "bot") and hasattr(global_config.bot, "qq_account")
|
||||
else ""
|
||||
)
|
||||
|
||||
query_start_timestamp = collect_period[-1][1].timestamp() # Messages.time is a DoubleField (timestamp)
|
||||
for message in Messages.select().where(Messages.time >= query_start_timestamp): # type: ignore
|
||||
message_time_ts = message.time # This is a float timestamp
|
||||
@@ -503,11 +547,18 @@ class StatisticOutputTask(AsyncTask):
|
||||
# 重置为正确的格式
|
||||
self.name_mapping[chat_id] = (chat_name, message_time_ts)
|
||||
|
||||
# 检查是否是bot发送的消息(回复)
|
||||
is_bot_reply = False
|
||||
if bot_qq_account and message.user_id == bot_qq_account:
|
||||
is_bot_reply = True
|
||||
|
||||
for idx, (_, period_start_dt) in enumerate(collect_period):
|
||||
if message_time_ts >= period_start_dt.timestamp():
|
||||
for period_key, _ in collect_period[idx:]:
|
||||
stats[period_key][TOTAL_MSG_CNT] += 1
|
||||
stats[period_key][MSG_CNT_BY_CHAT][chat_id] += 1
|
||||
if is_bot_reply:
|
||||
stats[period_key][TOTAL_REPLY_CNT] += 1
|
||||
break
|
||||
return stats
|
||||
|
||||
@@ -541,7 +592,9 @@ class StatisticOutputTask(AsyncTask):
|
||||
continue
|
||||
last_all_time_stat = last_stat["stat_data"] # 上次完整统计的统计数据
|
||||
last_stat_timestamp = datetime.fromtimestamp(last_stat["timestamp"]) # 上次完整统计数据的时间戳
|
||||
self.stat_period = [item for item in self.stat_period if item[0] != "all_time"] # 删除"所有时间"的统计时段
|
||||
self.stat_period = [
|
||||
item for item in self.stat_period if item[0] != "all_time"
|
||||
] # 删除"所有时间"的统计时段
|
||||
self.stat_period.append(("all_time", now - last_stat_timestamp, "自部署以来的"))
|
||||
except Exception as e:
|
||||
logger.warning(f"加载上次完整统计数据失败,进行全量统计,错误信息:{e}")
|
||||
@@ -593,12 +646,12 @@ class StatisticOutputTask(AsyncTask):
|
||||
# 更新上次完整统计数据的时间戳
|
||||
# 将所有defaultdict转换为普通dict以避免类型冲突
|
||||
clean_stat_data = self._convert_defaultdict_to_dict(stat["all_time"])
|
||||
|
||||
|
||||
# 将 name_mapping 中的元组转换为列表,因为JSON不支持元组
|
||||
json_safe_name_mapping = {}
|
||||
for chat_id, (chat_name, timestamp) in self.name_mapping.items():
|
||||
json_safe_name_mapping[chat_id] = [chat_name, timestamp]
|
||||
|
||||
|
||||
local_storage["last_full_statistics"] = {
|
||||
"name_mapping": json_safe_name_mapping,
|
||||
"stat_data": clean_stat_data,
|
||||
@@ -633,12 +686,45 @@ class StatisticOutputTask(AsyncTask):
|
||||
"""
|
||||
格式化总统计数据
|
||||
"""
|
||||
# 计算总token数(从所有模型的token数中累加)
|
||||
total_tokens = sum(stats[TOTAL_TOK_BY_MODEL].values()) if stats[TOTAL_TOK_BY_MODEL] else 0
|
||||
|
||||
# 计算花费/消息数量指标(每100条)
|
||||
cost_per_100_messages = (stats[TOTAL_COST] / stats[TOTAL_MSG_CNT] * 100) if stats[TOTAL_MSG_CNT] > 0 else 0.0
|
||||
|
||||
# 计算花费/时间指标(花费/小时)
|
||||
online_hours = stats[ONLINE_TIME] / 3600.0 if stats[ONLINE_TIME] > 0 else 0.0
|
||||
cost_per_hour = stats[TOTAL_COST] / online_hours if online_hours > 0 else 0.0
|
||||
|
||||
# 计算token/时间指标(token/小时)
|
||||
tokens_per_hour = (total_tokens / online_hours) if online_hours > 0 else 0.0
|
||||
|
||||
# 计算花费/回复数量指标(每100条)
|
||||
total_replies = stats.get(TOTAL_REPLY_CNT, 0)
|
||||
cost_per_100_replies = (stats[TOTAL_COST] / total_replies * 100) if total_replies > 0 else 0.0
|
||||
|
||||
# 计算花费/消息数量(排除自己回复)指标(每100条)
|
||||
total_messages_excluding_replies = stats[TOTAL_MSG_CNT] - total_replies
|
||||
cost_per_100_messages_excluding_replies = (
|
||||
(stats[TOTAL_COST] / total_messages_excluding_replies * 100)
|
||||
if total_messages_excluding_replies > 0
|
||||
else 0.0
|
||||
)
|
||||
|
||||
output = [
|
||||
f"总在线时间: {_format_online_time(stats[ONLINE_TIME])}",
|
||||
f"总消息数: {stats[TOTAL_MSG_CNT]}",
|
||||
f"总请求数: {stats[TOTAL_REQ_CNT]}",
|
||||
f"总消息数: {_format_large_number(stats[TOTAL_MSG_CNT])}",
|
||||
f"总回复数: {_format_large_number(total_replies)}",
|
||||
f"总请求数: {_format_large_number(stats[TOTAL_REQ_CNT])}",
|
||||
f"总Token数: {_format_large_number(total_tokens)}",
|
||||
f"总花费: {stats[TOTAL_COST]:.2f}¥",
|
||||
f"花费/消息数量: {cost_per_100_messages:.4f}¥/100条" if stats[TOTAL_MSG_CNT] > 0 else "花费/消息数量: N/A",
|
||||
f"花费/接受消息数量: {cost_per_100_messages_excluding_replies:.4f}¥/100条"
|
||||
if total_messages_excluding_replies > 0
|
||||
else "花费/消息数量(排除回复): N/A",
|
||||
f"花费/回复消息数量: {cost_per_100_replies:.4f}¥/100条" if total_replies > 0 else "花费/回复数量: N/A",
|
||||
f"花费/时间: {cost_per_hour:.2f}¥/小时" if online_hours > 0 else "花费/时间: N/A",
|
||||
f"Token/时间: {_format_large_number(tokens_per_hour)}/小时" if online_hours > 0 else "Token/时间: N/A",
|
||||
"",
|
||||
]
|
||||
|
||||
@@ -665,8 +751,22 @@ class StatisticOutputTask(AsyncTask):
|
||||
cost = stats[COST_BY_MODEL][model_name]
|
||||
avg_time_cost = stats[AVG_TIME_COST_BY_MODEL][model_name]
|
||||
std_time_cost = stats[STD_TIME_COST_BY_MODEL][model_name]
|
||||
# 格式化大数字
|
||||
formatted_count = _format_large_number(count)
|
||||
formatted_in_tokens = _format_large_number(in_tokens)
|
||||
formatted_out_tokens = _format_large_number(out_tokens)
|
||||
formatted_tokens = _format_large_number(tokens)
|
||||
output.append(
|
||||
data_fmt.format(name, count, in_tokens, out_tokens, tokens, cost, avg_time_cost, std_time_cost)
|
||||
data_fmt.format(
|
||||
name,
|
||||
formatted_count,
|
||||
formatted_in_tokens,
|
||||
formatted_out_tokens,
|
||||
formatted_tokens,
|
||||
cost,
|
||||
avg_time_cost,
|
||||
std_time_cost,
|
||||
)
|
||||
)
|
||||
|
||||
output.append("")
|
||||
@@ -682,10 +782,12 @@ class StatisticOutputTask(AsyncTask):
|
||||
for chat_id, count in sorted(stats[MSG_CNT_BY_CHAT].items()):
|
||||
try:
|
||||
chat_name = self.name_mapping.get(chat_id, ("未知聊天", 0))[0]
|
||||
output.append(f"{chat_name[:32]:<32} {count:>10}")
|
||||
formatted_count = _format_large_number(count)
|
||||
output.append(f"{chat_name[:32]:<32} {formatted_count:>10}")
|
||||
except (IndexError, TypeError) as e:
|
||||
logger.warning(f"格式化聊天统计时发生错误,chat_id: {chat_id}, 错误: {e}")
|
||||
output.append(f"{'未知聊天':<32} {count:>10}")
|
||||
formatted_count = _format_large_number(count)
|
||||
output.append(f"{'未知聊天':<32} {formatted_count:>10}")
|
||||
output.append("")
|
||||
return "\n".join(output)
|
||||
|
||||
@@ -735,6 +837,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
for period in self.stat_period
|
||||
]
|
||||
tab_list.append('<button class="tab-link" onclick="showTab(event, \'charts\')">数据图表</button>')
|
||||
tab_list.append('<button class="tab-link" onclick="showTab(event, \'metrics\')">指标趋势</button>')
|
||||
|
||||
def _format_stat_data(stat_data: dict[str, Any], div_id: str, start_time: datetime) -> str:
|
||||
"""
|
||||
@@ -750,10 +853,10 @@ class StatisticOutputTask(AsyncTask):
|
||||
[
|
||||
f"<tr>"
|
||||
f"<td>{model_name}</td>"
|
||||
f"<td>{count}</td>"
|
||||
f"<td>{stat_data[IN_TOK_BY_MODEL][model_name]}</td>"
|
||||
f"<td>{stat_data[OUT_TOK_BY_MODEL][model_name]}</td>"
|
||||
f"<td>{stat_data[TOTAL_TOK_BY_MODEL][model_name]}</td>"
|
||||
f"<td>{_format_large_number(count, html=True)}</td>"
|
||||
f"<td>{_format_large_number(stat_data[IN_TOK_BY_MODEL][model_name], html=True)}</td>"
|
||||
f"<td>{_format_large_number(stat_data[OUT_TOK_BY_MODEL][model_name], html=True)}</td>"
|
||||
f"<td>{_format_large_number(stat_data[TOTAL_TOK_BY_MODEL][model_name], html=True)}</td>"
|
||||
f"<td>{stat_data[COST_BY_MODEL][model_name]:.2f} ¥</td>"
|
||||
f"<td>{stat_data[AVG_TIME_COST_BY_MODEL][model_name]:.1f} 秒</td>"
|
||||
f"<td>{stat_data[STD_TIME_COST_BY_MODEL][model_name]:.1f} 秒</td>"
|
||||
@@ -768,10 +871,10 @@ class StatisticOutputTask(AsyncTask):
|
||||
[
|
||||
f"<tr>"
|
||||
f"<td>{req_type}</td>"
|
||||
f"<td>{count}</td>"
|
||||
f"<td>{stat_data[IN_TOK_BY_TYPE][req_type]}</td>"
|
||||
f"<td>{stat_data[OUT_TOK_BY_TYPE][req_type]}</td>"
|
||||
f"<td>{stat_data[TOTAL_TOK_BY_TYPE][req_type]}</td>"
|
||||
f"<td>{_format_large_number(count, html=True)}</td>"
|
||||
f"<td>{_format_large_number(stat_data[IN_TOK_BY_TYPE][req_type], html=True)}</td>"
|
||||
f"<td>{_format_large_number(stat_data[OUT_TOK_BY_TYPE][req_type], html=True)}</td>"
|
||||
f"<td>{_format_large_number(stat_data[TOTAL_TOK_BY_TYPE][req_type], html=True)}</td>"
|
||||
f"<td>{stat_data[COST_BY_TYPE][req_type]:.2f} ¥</td>"
|
||||
f"<td>{stat_data[AVG_TIME_COST_BY_TYPE][req_type]:.1f} 秒</td>"
|
||||
f"<td>{stat_data[STD_TIME_COST_BY_TYPE][req_type]:.1f} 秒</td>"
|
||||
@@ -786,10 +889,10 @@ class StatisticOutputTask(AsyncTask):
|
||||
[
|
||||
f"<tr>"
|
||||
f"<td>{module_name}</td>"
|
||||
f"<td>{count}</td>"
|
||||
f"<td>{stat_data[IN_TOK_BY_MODULE][module_name]}</td>"
|
||||
f"<td>{stat_data[OUT_TOK_BY_MODULE][module_name]}</td>"
|
||||
f"<td>{stat_data[TOTAL_TOK_BY_MODULE][module_name]}</td>"
|
||||
f"<td>{_format_large_number(count, html=True)}</td>"
|
||||
f"<td>{_format_large_number(stat_data[IN_TOK_BY_MODULE][module_name], html=True)}</td>"
|
||||
f"<td>{_format_large_number(stat_data[OUT_TOK_BY_MODULE][module_name], html=True)}</td>"
|
||||
f"<td>{_format_large_number(stat_data[TOTAL_TOK_BY_MODULE][module_name], html=True)}</td>"
|
||||
f"<td>{stat_data[COST_BY_MODULE][module_name]:.2f} ¥</td>"
|
||||
f"<td>{stat_data[AVG_TIME_COST_BY_MODULE][module_name]:.1f} 秒</td>"
|
||||
f"<td>{stat_data[STD_TIME_COST_BY_MODULE][module_name]:.1f} 秒</td>"
|
||||
@@ -805,12 +908,16 @@ class StatisticOutputTask(AsyncTask):
|
||||
for chat_id, count in sorted(stat_data[MSG_CNT_BY_CHAT].items()):
|
||||
try:
|
||||
chat_name = self.name_mapping.get(chat_id, ("未知聊天", 0))[0]
|
||||
chat_rows.append(f"<tr><td>{chat_name}</td><td>{count}</td></tr>")
|
||||
chat_rows.append(f"<tr><td>{chat_name}</td><td>{_format_large_number(count, html=True)}</td></tr>")
|
||||
except (IndexError, TypeError) as e:
|
||||
logger.warning(f"生成HTML聊天统计时发生错误,chat_id: {chat_id}, 错误: {e}")
|
||||
chat_rows.append(f"<tr><td>未知聊天</td><td>{count}</td></tr>")
|
||||
|
||||
chat_rows_html = "\n".join(chat_rows) if chat_rows else "<tr><td colspan='2' style='text-align: center; color: #999;'>暂无数据</td></tr>"
|
||||
chat_rows.append(f"<tr><td>未知聊天</td><td>{_format_large_number(count, html=True)}</td></tr>")
|
||||
|
||||
chat_rows_html = (
|
||||
"\n".join(chat_rows)
|
||||
if chat_rows
|
||||
else "<tr><td colspan='2' style='text-align: center; color: #999;'>暂无数据</td></tr>"
|
||||
)
|
||||
# 生成HTML
|
||||
return f"""
|
||||
<div id=\"{div_id}\" class=\"tab-content\">
|
||||
@@ -818,48 +925,98 @@ class StatisticOutputTask(AsyncTask):
|
||||
<strong>统计时段: </strong>
|
||||
{start_time.strftime("%Y-%m-%d %H:%M:%S")} ~ {now.strftime("%Y-%m-%d %H:%M:%S")}
|
||||
</p>
|
||||
<p class=\"info-item\"><strong>总在线时间: </strong>{_format_online_time(stat_data[ONLINE_TIME])}</p>
|
||||
<p class=\"info-item\"><strong>总消息数: </strong>{stat_data[TOTAL_MSG_CNT]}</p>
|
||||
<p class=\"info-item\"><strong>总请求数: </strong>{stat_data[TOTAL_REQ_CNT]}</p>
|
||||
<p class=\"info-item\"><strong>总花费: </strong>{stat_data[TOTAL_COST]:.2f} ¥</p>
|
||||
<div class=\"kpi-cards\">
|
||||
<div class=\"kpi-card\">
|
||||
<div class=\"kpi-title\">总在线时间</div>
|
||||
<div class=\"kpi-value\">{_format_online_time(stat_data[ONLINE_TIME])}</div>
|
||||
</div>
|
||||
<div class=\"kpi-card\">
|
||||
<div class=\"kpi-title\">总消息数</div>
|
||||
<div class=\"kpi-value\">{_format_large_number(stat_data[TOTAL_MSG_CNT], html=True)}</div>
|
||||
</div>
|
||||
<div class=\"kpi-card\">
|
||||
<div class=\"kpi-title\">总回复数</div>
|
||||
<div class=\"kpi-value\">{_format_large_number(stat_data.get(TOTAL_REPLY_CNT, 0), html=True)}</div>
|
||||
</div>
|
||||
<div class=\"kpi-card\">
|
||||
<div class=\"kpi-title\">总请求数</div>
|
||||
<div class=\"kpi-value\">{_format_large_number(stat_data[TOTAL_REQ_CNT], html=True)}</div>
|
||||
</div>
|
||||
<div class=\"kpi-card\">
|
||||
<div class=\"kpi-title\">总Token数</div>
|
||||
<div class=\"kpi-value\">{_format_large_number(sum(stat_data[TOTAL_TOK_BY_MODEL].values()) if stat_data[TOTAL_TOK_BY_MODEL] else 0, html=True)}</div>
|
||||
</div>
|
||||
<div class=\"kpi-card\">
|
||||
<div class=\"kpi-title\">总花费</div>
|
||||
<div class=\"kpi-value\">{stat_data[TOTAL_COST]:.2f} ¥</div>
|
||||
</div>
|
||||
<div class=\"kpi-card\">
|
||||
<div class=\"kpi-title\">花费/消息数量</div>
|
||||
<div class=\"kpi-value\">{(stat_data[TOTAL_COST] / stat_data[TOTAL_MSG_CNT] * 100 if stat_data[TOTAL_MSG_CNT] > 0 else 0.0):.4f} ¥/100条</div>
|
||||
</div>
|
||||
<div class=\"kpi-card\">
|
||||
<div class=\"kpi-title\">花费/消息数量(排除回复)</div>
|
||||
<div class=\"kpi-value\">{(stat_data[TOTAL_COST] / (stat_data[TOTAL_MSG_CNT] - stat_data.get(TOTAL_REPLY_CNT, 0)) * 100 if (stat_data[TOTAL_MSG_CNT] - stat_data.get(TOTAL_REPLY_CNT, 0)) > 0 else 0.0):.4f} ¥/100条</div>
|
||||
</div>
|
||||
<div class=\"kpi-card\">
|
||||
<div class=\"kpi-title\">花费/回复数量</div>
|
||||
<div class=\"kpi-value\">{(stat_data[TOTAL_COST] / stat_data.get(TOTAL_REPLY_CNT, 0) * 100 if stat_data.get(TOTAL_REPLY_CNT, 0) > 0 else 0.0):.4f} ¥/100条</div>
|
||||
</div>
|
||||
<div class=\"kpi-card\">
|
||||
<div class=\"kpi-title\">花费/时间</div>
|
||||
<div class=\"kpi-value\">{(stat_data[TOTAL_COST] / (stat_data[ONLINE_TIME] / 3600.0) if stat_data[ONLINE_TIME] > 0 else 0.0):.2f} ¥/小时</div>
|
||||
</div>
|
||||
<div class=\"kpi-card\">
|
||||
<div class=\"kpi-title\">Token/时间</div>
|
||||
<div class=\"kpi-value\">{_format_large_number(sum(stat_data[TOTAL_TOK_BY_MODEL].values()) / (stat_data[ONLINE_TIME] / 3600.0) if stat_data[ONLINE_TIME] > 0 and stat_data[TOTAL_TOK_BY_MODEL] else 0.0, html=True)}/小时</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<h2>按模型分类统计</h2>
|
||||
<table>
|
||||
<thead><tr><th>模型名称</th><th>调用次数</th><th>输入Token</th><th>输出Token</th><th>Token总量</th><th>累计花费</th><th>平均耗时(秒)</th><th>标准差(秒)</th></tr></thead>
|
||||
<tbody>
|
||||
{model_rows}
|
||||
</tbody>
|
||||
</table>
|
||||
<div class=\"table-wrap\">
|
||||
<table>
|
||||
<thead><tr><th>模型名称</th><th>调用次数</th><th>输入Token</th><th>输出Token</th><th>Token总量</th><th>累计花费</th><th>平均耗时(秒)</th><th>标准差(秒)</th></tr></thead>
|
||||
<tbody>
|
||||
{model_rows}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
|
||||
<h2>按模块分类统计</h2>
|
||||
<table>
|
||||
<thead>
|
||||
<tr><th>模块名称</th><th>调用次数</th><th>输入Token</th><th>输出Token</th><th>Token总量</th><th>累计花费</th><th>平均耗时(秒)</th><th>标准差(秒)</th></tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{module_rows}
|
||||
</tbody>
|
||||
</table>
|
||||
<div class=\"table-wrap\">
|
||||
<table>
|
||||
<thead>
|
||||
<tr><th>模块名称</th><th>调用次数</th><th>输入Token</th><th>输出Token</th><th>Token总量</th><th>累计花费</th><th>平均耗时(秒)</th><th>标准差(秒)</th></tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{module_rows}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
|
||||
<h2>按请求类型分类统计</h2>
|
||||
<table>
|
||||
<thead>
|
||||
<tr><th>请求类型</th><th>调用次数</th><th>输入Token</th><th>输出Token</th><th>Token总量</th><th>累计花费</th><th>平均耗时(秒)</th><th>标准差(秒)</th></tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{type_rows}
|
||||
</tbody>
|
||||
</table>
|
||||
<div class=\"table-wrap\">
|
||||
<table>
|
||||
<thead>
|
||||
<tr><th>请求类型</th><th>调用次数</th><th>输入Token</th><th>输出Token</th><th>Token总量</th><th>累计花费</th><th>平均耗时(秒)</th><th>标准差(秒)</th></tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{type_rows}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
|
||||
<h2>聊天消息统计</h2>
|
||||
<table>
|
||||
<thead>
|
||||
<tr><th>联系人/群组名称</th><th>消息数量</th></tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{chat_rows_html}
|
||||
</tbody>
|
||||
</table>
|
||||
<div class=\"table-wrap\">
|
||||
<table>
|
||||
<thead>
|
||||
<tr><th>联系人/群组名称</th><th>消息数量</th></tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{chat_rows_html}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
|
||||
<h2>数据分布图表</h2>
|
||||
<div style="display: flex; flex-wrap: wrap; gap: 20px; margin-top: 20px;">
|
||||
@@ -1066,6 +1223,10 @@ class StatisticOutputTask(AsyncTask):
|
||||
chart_data = self._generate_chart_data(stat)
|
||||
tab_content_list.append(self._generate_chart_tab(chart_data))
|
||||
|
||||
# 添加指标趋势图表
|
||||
metrics_data = self._generate_metrics_data(now)
|
||||
tab_content_list.append(self._generate_metrics_tab(metrics_data))
|
||||
|
||||
joined_tab_list = "\n".join(tab_list)
|
||||
joined_tab_content = "\n".join(tab_content_list)
|
||||
|
||||
@@ -1083,21 +1244,22 @@ class StatisticOutputTask(AsyncTask):
|
||||
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, "Helvetica Neue", Arial, sans-serif;
|
||||
margin: 0;
|
||||
padding: 20px;
|
||||
background-color: #f4f7f6;
|
||||
color: #333;
|
||||
background-color: #faf7ff;
|
||||
color: #3a2f57;
|
||||
line-height: 1.6;
|
||||
}
|
||||
.container {
|
||||
max-width: 900px;
|
||||
margin: 20px auto;
|
||||
background-color: #fff;
|
||||
background-color: #ffffff;
|
||||
padding: 25px;
|
||||
border-radius: 8px;
|
||||
box-shadow: 0 2px 10px rgba(0,0,0,0.1);
|
||||
border-radius: 10px;
|
||||
box-shadow: 0 10px 28px rgba(122, 98, 182, 0.12);
|
||||
border: 1px solid #e5dcff;
|
||||
}
|
||||
h1, h2 {
|
||||
color: #2c3e50;
|
||||
border-bottom: 2px solid #3498db;
|
||||
color: #473673;
|
||||
border-bottom: 2px solid #9f8efb;
|
||||
padding-bottom: 10px;
|
||||
margin-top: 0;
|
||||
}
|
||||
@@ -1113,33 +1275,62 @@ class StatisticOutputTask(AsyncTask):
|
||||
margin-bottom: 10px;
|
||||
}
|
||||
.info-item {
|
||||
background-color: #ecf0f1;
|
||||
background-color: #f3eeff;
|
||||
padding: 8px 12px;
|
||||
border-radius: 4px;
|
||||
border-radius: 6px;
|
||||
margin-bottom: 8px;
|
||||
font-size: 0.95em;
|
||||
}
|
||||
.info-item strong {
|
||||
color: #2980b9;
|
||||
color: #7162bf;
|
||||
}
|
||||
/* 新增:顶部工具条与按钮 */
|
||||
.toolbar { display: flex; align-items: center; justify-content: space-between; gap: 12px; margin-bottom: 8px; }
|
||||
.toolbar .right { display: flex; gap: 8px; align-items: center; }
|
||||
.btn {
|
||||
border: 1px solid #e3daff;
|
||||
background-color: #fbf9ff;
|
||||
color: #4a3c75;
|
||||
padding: 8px 12px;
|
||||
border-radius: 6px;
|
||||
cursor: pointer;
|
||||
transition: all .2s ease;
|
||||
}
|
||||
.btn:hover { border-color: #9f8efb; color: #7c6bcf; background-color: #f1ecff; }
|
||||
/* 新增:KPI 卡片 */
|
||||
.kpi-cards { display: grid; grid-template-columns: repeat(5, 1fr); gap: 12px; margin: 12px 0 6px; }
|
||||
.kpi-card {
|
||||
background: linear-gradient(145deg, #ffffff 0%, #f6f2ff 100%);
|
||||
border: 1px solid #e3dbff;
|
||||
border-radius: 10px;
|
||||
padding: 14px 16px;
|
||||
box-shadow: 0 6px 16px rgba(113, 98, 191, 0.1);
|
||||
}
|
||||
.kpi-title { font-size: 12px; color: #8579a6; letter-spacing: .3px; margin-bottom: 6px; }
|
||||
.kpi-value { font-size: 20px; font-weight: 700; letter-spacing: .2px; color: #8b5cf6; }
|
||||
table {
|
||||
width: 100%;
|
||||
border-collapse: collapse;
|
||||
margin-top: 15px;
|
||||
font-size: 0.9em;
|
||||
}
|
||||
/* 新增:表格包裹容器,支持横向滚动 */
|
||||
.table-wrap { width: 100%; overflow-x: auto; border-radius: 6px; }
|
||||
th, td {
|
||||
border: 1px solid #ddd;
|
||||
border: 1px solid #e6ddff;
|
||||
padding: 10px;
|
||||
text-align: left;
|
||||
}
|
||||
th {
|
||||
background-color: #3498db;
|
||||
background-color: #9f8efb;
|
||||
color: white;
|
||||
font-weight: bold;
|
||||
position: sticky;
|
||||
top: 0;
|
||||
z-index: 1;
|
||||
}
|
||||
tr:nth-child(even) {
|
||||
background-color: #f9f9f9;
|
||||
background-color: #f6f1ff;
|
||||
}
|
||||
.footer {
|
||||
text-align: center;
|
||||
@@ -1149,25 +1340,32 @@ class StatisticOutputTask(AsyncTask):
|
||||
}
|
||||
.tabs {
|
||||
overflow: hidden;
|
||||
background: #ecf0f1;
|
||||
background: #f9f6ff;
|
||||
display: flex;
|
||||
border: 1px solid #e4dcff;
|
||||
border-radius: 10px;
|
||||
box-shadow: 0 8px 18px rgba(120, 101, 179, 0.08);
|
||||
}
|
||||
.tabs button {
|
||||
background: inherit; border: none; outline: none;
|
||||
padding: 14px 16px; cursor: pointer;
|
||||
transition: 0.3s; font-size: 16px;
|
||||
padding: 12px 14px; cursor: pointer;
|
||||
transition: 0.2s; font-size: 15px;
|
||||
color: #52467a;
|
||||
}
|
||||
.tabs button:hover {
|
||||
background-color: #d4dbdc;
|
||||
background-color: #efe9ff;
|
||||
}
|
||||
.tabs button.active {
|
||||
background-color: #b3bbbd;
|
||||
background-color: rgba(159, 142, 251, 0.25);
|
||||
color: #6253a9;
|
||||
}
|
||||
.tab-content {
|
||||
display: none;
|
||||
padding: 20px;
|
||||
background-color: #fff;
|
||||
border: 1px solid #ccc;
|
||||
background-color: #fefcff;
|
||||
border: 1px solid #e4dcff;
|
||||
border-top: none;
|
||||
border-radius: 0 0 10px 10px;
|
||||
}
|
||||
.tab-content.active {
|
||||
display: block;
|
||||
@@ -1178,14 +1376,19 @@ class StatisticOutputTask(AsyncTask):
|
||||
"""
|
||||
+ f"""
|
||||
<div class="container">
|
||||
<h1>MaiBot运行统计报告</h1>
|
||||
<p class="info-item"><strong>统计截止时间:</strong> {now.strftime("%Y-%m-%d %H:%M:%S")}</p>
|
||||
<div class="toolbar">
|
||||
<h1 style="margin: 0;">MaiBot运行统计报告</h1>
|
||||
<div class="right">
|
||||
<span class="info-item" style="margin: 0;"><strong>统计截止时间:</strong> {now.strftime("%Y-%m-%d %H:%M:%S")}</span>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="tabs">
|
||||
{joined_tab_list}
|
||||
</div>
|
||||
|
||||
{joined_tab_content}
|
||||
<div class="footer">Made with ❤️ by MaiBot • 本页会定期自动覆盖生成</div>
|
||||
</div>
|
||||
"""
|
||||
+ """
|
||||
@@ -1319,16 +1522,16 @@ class StatisticOutputTask(AsyncTask):
|
||||
|
||||
# 生成不同颜色的调色板
|
||||
colors = [
|
||||
"#3498db",
|
||||
"#e74c3c",
|
||||
"#2ecc71",
|
||||
"#f39c12",
|
||||
"#9b59b6",
|
||||
"#1abc9c",
|
||||
"#34495e",
|
||||
"#e67e22",
|
||||
"#95a5a6",
|
||||
"#f1c40f",
|
||||
"#8b5cf6",
|
||||
"#9f8efb",
|
||||
"#b5a6ff",
|
||||
"#c7bbff",
|
||||
"#d9ceff",
|
||||
"#a78bfa",
|
||||
"#9073d8",
|
||||
"#bfaefc",
|
||||
"#cabdfd",
|
||||
"#e6e0ff",
|
||||
]
|
||||
|
||||
# 默认使用24小时数据生成数据集
|
||||
@@ -1510,7 +1713,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
|
||||
function createChart(chartType, data, timeRange) {{
|
||||
const config = chartConfigs[chartType];
|
||||
const colors = ['#3498db', '#e74c3c', '#2ecc71', '#f39c12', '#9b59b6', '#1abc9c', '#34495e', '#e67e22', '#95a5a6', '#f1c40f'];
|
||||
const colors = ['#8b5cf6', '#9f8efb', '#b5a6ff', '#c7bbff', '#d9ceff', '#a78bfa', '#9073d8', '#bfaefc', '#cabdfd', '#e6e0ff'];
|
||||
|
||||
let datasets = [];
|
||||
|
||||
@@ -1591,6 +1794,320 @@ class StatisticOutputTask(AsyncTask):
|
||||
</div>
|
||||
"""
|
||||
|
||||
def _generate_metrics_data(self, now: datetime) -> dict:
|
||||
"""生成指标趋势数据"""
|
||||
metrics_data = {}
|
||||
|
||||
# 24小时尺度:1小时为单位
|
||||
metrics_data["24h"] = self._collect_metrics_interval_data(now, hours=24, interval_hours=1)
|
||||
|
||||
# 7天尺度:1天为单位
|
||||
metrics_data["7d"] = self._collect_metrics_interval_data(now, hours=24 * 7, interval_hours=24)
|
||||
|
||||
# 30天尺度:1天为单位
|
||||
metrics_data["30d"] = self._collect_metrics_interval_data(now, hours=24 * 30, interval_hours=24)
|
||||
|
||||
return metrics_data
|
||||
|
||||
def _collect_metrics_interval_data(self, now: datetime, hours: int, interval_hours: int) -> dict:
|
||||
"""收集指定时间范围内每个间隔的指标数据"""
|
||||
start_time = now - timedelta(hours=hours)
|
||||
time_points = []
|
||||
current_time = start_time
|
||||
|
||||
# 生成时间点
|
||||
while current_time <= now:
|
||||
time_points.append(current_time)
|
||||
current_time += timedelta(hours=interval_hours)
|
||||
|
||||
# 初始化数据结构
|
||||
cost_per_100_messages = [0.0] * len(time_points) # 花费/消息数量(每100条)
|
||||
cost_per_hour = [0.0] * len(time_points) # 花费/时间(每小时)
|
||||
tokens_per_hour = [0.0] * len(time_points) # Token/时间(每小时)
|
||||
cost_per_100_replies = [0.0] * len(time_points) # 花费/回复数量(每100条)
|
||||
|
||||
# 每个时间点的累计数据
|
||||
total_costs = [0.0] * len(time_points)
|
||||
total_tokens = [0] * len(time_points)
|
||||
total_messages = [0] * len(time_points)
|
||||
total_replies = [0] * len(time_points)
|
||||
total_online_hours = [0.0] * len(time_points)
|
||||
|
||||
# 获取bot的QQ账号
|
||||
bot_qq_account = (
|
||||
str(global_config.bot.qq_account)
|
||||
if hasattr(global_config, "bot") and hasattr(global_config.bot, "qq_account")
|
||||
else ""
|
||||
)
|
||||
|
||||
interval_seconds = interval_hours * 3600
|
||||
|
||||
# 查询LLM使用记录
|
||||
query_start_time = start_time
|
||||
for record in LLMUsage.select().where(LLMUsage.timestamp >= query_start_time): # type: ignore
|
||||
record_time = record.timestamp
|
||||
|
||||
# 找到对应的时间间隔索引
|
||||
time_diff = (record_time - start_time).total_seconds()
|
||||
interval_index = int(time_diff // interval_seconds)
|
||||
|
||||
if 0 <= interval_index < len(time_points):
|
||||
cost = record.cost or 0.0
|
||||
prompt_tokens = record.prompt_tokens or 0
|
||||
completion_tokens = record.completion_tokens or 0
|
||||
total_token = prompt_tokens + completion_tokens
|
||||
|
||||
total_costs[interval_index] += cost
|
||||
total_tokens[interval_index] += total_token
|
||||
|
||||
# 查询消息记录
|
||||
query_start_timestamp = start_time.timestamp()
|
||||
for message in Messages.select().where(Messages.time >= query_start_timestamp): # type: ignore
|
||||
message_time_ts = message.time
|
||||
|
||||
time_diff = message_time_ts - query_start_timestamp
|
||||
interval_index = int(time_diff // interval_seconds)
|
||||
|
||||
if 0 <= interval_index < len(time_points):
|
||||
total_messages[interval_index] += 1
|
||||
# 检查是否是bot发送的消息(回复)
|
||||
if bot_qq_account and message.user_id == bot_qq_account:
|
||||
total_replies[interval_index] += 1
|
||||
|
||||
# 查询在线时间记录
|
||||
for record in OnlineTime.select().where(OnlineTime.end_timestamp >= start_time): # type: ignore
|
||||
record_start = record.start_timestamp
|
||||
record_end = record.end_timestamp
|
||||
|
||||
# 找到记录覆盖的所有时间间隔
|
||||
for idx, time_point in enumerate(time_points):
|
||||
interval_start = time_point
|
||||
interval_end = time_point + timedelta(hours=interval_hours)
|
||||
|
||||
# 计算重叠部分
|
||||
overlap_start = max(record_start, interval_start)
|
||||
overlap_end = min(record_end, interval_end)
|
||||
|
||||
if overlap_end > overlap_start:
|
||||
overlap_hours = (overlap_end - overlap_start).total_seconds() / 3600.0
|
||||
total_online_hours[idx] += overlap_hours
|
||||
|
||||
# 计算指标
|
||||
for idx in range(len(time_points)):
|
||||
# 花费/消息数量(每100条)
|
||||
if total_messages[idx] > 0:
|
||||
cost_per_100_messages[idx] = total_costs[idx] / total_messages[idx] * 100
|
||||
|
||||
# 花费/时间(每小时)
|
||||
if total_online_hours[idx] > 0:
|
||||
cost_per_hour[idx] = total_costs[idx] / total_online_hours[idx]
|
||||
|
||||
# Token/时间(每小时)
|
||||
if total_online_hours[idx] > 0:
|
||||
tokens_per_hour[idx] = total_tokens[idx] / total_online_hours[idx]
|
||||
|
||||
# 花费/回复数量(每100条)
|
||||
if total_replies[idx] > 0:
|
||||
cost_per_100_replies[idx] = total_costs[idx] / total_replies[idx] * 100
|
||||
|
||||
# 生成时间标签
|
||||
if interval_hours == 1:
|
||||
time_labels = [t.strftime("%H:%M") for t in time_points]
|
||||
else:
|
||||
time_labels = [t.strftime("%m-%d") for t in time_points]
|
||||
|
||||
return {
|
||||
"time_labels": time_labels,
|
||||
"cost_per_100_messages": cost_per_100_messages,
|
||||
"cost_per_hour": cost_per_hour,
|
||||
"tokens_per_hour": tokens_per_hour,
|
||||
"cost_per_100_replies": cost_per_100_replies,
|
||||
}
|
||||
|
||||
def _generate_metrics_tab(self, metrics_data: dict) -> str:
|
||||
"""生成指标趋势图表选项卡HTML内容"""
|
||||
colors = {
|
||||
"cost_per_100_messages": "#8b5cf6",
|
||||
"cost_per_hour": "#9f8efb",
|
||||
"tokens_per_hour": "#c7bbff",
|
||||
"cost_per_100_replies": "#d9ceff",
|
||||
}
|
||||
|
||||
return f"""
|
||||
<div id="metrics" class="tab-content">
|
||||
<h2>指标趋势图表</h2>
|
||||
|
||||
<!-- 时间尺度选择按钮 -->
|
||||
<div style="margin: 20px 0; text-align: center;">
|
||||
<label style="margin-right: 10px; font-weight: bold;">时间尺度:</label>
|
||||
<button class="time-scale-btn" onclick="switchMetricsTimeScale('24h')">24小时</button>
|
||||
<button class="time-scale-btn active" onclick="switchMetricsTimeScale('7d')">7天</button>
|
||||
<button class="time-scale-btn" onclick="switchMetricsTimeScale('30d')">30天</button>
|
||||
</div>
|
||||
|
||||
<div style="margin-top: 20px;">
|
||||
<div style="margin-bottom: 40px;">
|
||||
<canvas id="costPer100MessagesChart" width="800" height="400"></canvas>
|
||||
</div>
|
||||
<div style="margin-bottom: 40px;">
|
||||
<canvas id="costPerHourChart" width="800" height="400"></canvas>
|
||||
</div>
|
||||
<div style="margin-bottom: 40px;">
|
||||
<canvas id="tokensPerHourChart" width="800" height="400"></canvas>
|
||||
</div>
|
||||
<div>
|
||||
<canvas id="costPer100RepliesChart" width="800" height="400"></canvas>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<style>
|
||||
.time-scale-btn {{
|
||||
background-color: #ecf0f1;
|
||||
border: 1px solid #bdc3c7;
|
||||
color: #2c3e50;
|
||||
padding: 8px 16px;
|
||||
margin: 0 5px;
|
||||
border-radius: 4px;
|
||||
cursor: pointer;
|
||||
font-size: 14px;
|
||||
transition: all 0.3s ease;
|
||||
}}
|
||||
|
||||
.time-scale-btn:hover {{
|
||||
background-color: #d5dbdb;
|
||||
}}
|
||||
|
||||
.time-scale-btn.active {{
|
||||
background-color: #8b5cf6;
|
||||
color: white;
|
||||
border-color: #7c6bcf;
|
||||
}}
|
||||
</style>
|
||||
|
||||
<script>
|
||||
const allMetricsData = {json.dumps(metrics_data)};
|
||||
let currentMetricsCharts = {{}};
|
||||
|
||||
const metricsConfigs = {{
|
||||
costPer100Messages: {{
|
||||
id: 'costPer100MessagesChart',
|
||||
title: '花费/消息数量',
|
||||
yAxisLabel: '花费 (¥/100条)',
|
||||
dataKey: 'cost_per_100_messages',
|
||||
color: '{colors["cost_per_100_messages"]}'
|
||||
}},
|
||||
costPerHour: {{
|
||||
id: 'costPerHourChart',
|
||||
title: '花费/时间',
|
||||
yAxisLabel: '花费 (¥/小时)',
|
||||
dataKey: 'cost_per_hour',
|
||||
color: '{colors["cost_per_hour"]}'
|
||||
}},
|
||||
tokensPerHour: {{
|
||||
id: 'tokensPerHourChart',
|
||||
title: 'Token/时间',
|
||||
yAxisLabel: 'Token (/小时)',
|
||||
dataKey: 'tokens_per_hour',
|
||||
color: '{colors["tokens_per_hour"]}'
|
||||
}},
|
||||
costPer100Replies: {{
|
||||
id: 'costPer100RepliesChart',
|
||||
title: '花费/回复数量',
|
||||
yAxisLabel: '花费 (¥/100条)',
|
||||
dataKey: 'cost_per_100_replies',
|
||||
color: '{colors["cost_per_100_replies"]}'
|
||||
}}
|
||||
}};
|
||||
|
||||
function switchMetricsTimeScale(timeScale) {{
|
||||
// 更新按钮状态
|
||||
document.querySelectorAll('.time-scale-btn').forEach(btn => {{
|
||||
btn.classList.remove('active');
|
||||
}});
|
||||
event.target.classList.add('active');
|
||||
|
||||
// 更新图表数据
|
||||
const data = allMetricsData[timeScale];
|
||||
updateAllMetricsCharts(data, timeScale);
|
||||
}}
|
||||
|
||||
function updateAllMetricsCharts(data, timeScale) {{
|
||||
// 销毁现有图表
|
||||
Object.values(currentMetricsCharts).forEach(chart => {{
|
||||
if (chart) chart.destroy();
|
||||
}});
|
||||
|
||||
currentMetricsCharts = {{}};
|
||||
|
||||
// 重新创建图表
|
||||
createMetricsChart('costPer100Messages', data, timeScale);
|
||||
createMetricsChart('costPerHour', data, timeScale);
|
||||
createMetricsChart('tokensPerHour', data, timeScale);
|
||||
createMetricsChart('costPer100Replies', data, timeScale);
|
||||
}}
|
||||
|
||||
function createMetricsChart(chartType, data, timeScale) {{
|
||||
const config = metricsConfigs[chartType];
|
||||
|
||||
currentMetricsCharts[chartType] = new Chart(document.getElementById(config.id), {{
|
||||
type: 'line',
|
||||
data: {{
|
||||
labels: data.time_labels,
|
||||
datasets: [{{
|
||||
label: config.title,
|
||||
data: data[config.dataKey],
|
||||
borderColor: config.color,
|
||||
backgroundColor: config.color + '20',
|
||||
tension: 0.4,
|
||||
fill: false
|
||||
}}]
|
||||
}},
|
||||
options: {{
|
||||
responsive: true,
|
||||
plugins: {{
|
||||
title: {{
|
||||
display: true,
|
||||
text: timeScale + '内' + config.title + '趋势',
|
||||
font: {{ size: 16 }}
|
||||
}},
|
||||
legend: {{
|
||||
display: false
|
||||
}}
|
||||
}},
|
||||
scales: {{
|
||||
x: {{
|
||||
title: {{
|
||||
display: true,
|
||||
text: '时间'
|
||||
}},
|
||||
ticks: {{
|
||||
maxTicksLimit: 12
|
||||
}}
|
||||
}},
|
||||
y: {{
|
||||
title: {{
|
||||
display: true,
|
||||
text: config.yAxisLabel
|
||||
}},
|
||||
beginAtZero: true
|
||||
}}
|
||||
}},
|
||||
interaction: {{
|
||||
intersect: false,
|
||||
mode: 'index'
|
||||
}}
|
||||
}}
|
||||
}});
|
||||
}}
|
||||
|
||||
// 初始化图表(默认7天)
|
||||
document.addEventListener('DOMContentLoaded', function() {{
|
||||
updateAllMetricsCharts(allMetricsData['7d'], '7d');
|
||||
}});
|
||||
</script>
|
||||
</div>
|
||||
"""
|
||||
|
||||
|
||||
class AsyncStatisticOutputTask(AsyncTask):
|
||||
"""完全异步的统计输出任务 - 更高性能版本"""
|
||||
@@ -1682,6 +2199,15 @@ class AsyncStatisticOutputTask(AsyncTask):
|
||||
def _generate_chart_tab(self, chart_data: dict) -> str:
|
||||
return StatisticOutputTask._generate_chart_tab(self, chart_data) # type: ignore
|
||||
|
||||
def _generate_metrics_data(self, now: datetime) -> dict:
|
||||
return StatisticOutputTask._generate_metrics_data(self, now) # type: ignore
|
||||
|
||||
def _collect_metrics_interval_data(self, now: datetime, hours: int, interval_hours: int) -> dict:
|
||||
return StatisticOutputTask._collect_metrics_interval_data(self, now, hours, interval_hours) # type: ignore
|
||||
|
||||
def _generate_metrics_tab(self, metrics_data: dict) -> str:
|
||||
return StatisticOutputTask._generate_metrics_tab(self, metrics_data) # type: ignore
|
||||
|
||||
def _get_chat_display_name_from_id(self, chat_id: str) -> str:
|
||||
return StatisticOutputTask._get_chat_display_name_from_id(self, chat_id) # type: ignore
|
||||
|
||||
|
||||
@@ -4,14 +4,11 @@ import time
|
||||
import jieba
|
||||
import json
|
||||
import ast
|
||||
import numpy as np
|
||||
|
||||
from collections import Counter
|
||||
from typing import Optional, Tuple, List, TYPE_CHECKING
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.common.message_repository import find_messages, count_messages
|
||||
from src.config.config import global_config, model_config
|
||||
from src.chat.message_receive.message import MessageRecv
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
@@ -32,10 +29,10 @@ def is_english_letter(char: str) -> bool:
|
||||
|
||||
def parse_platform_accounts(platforms: list[str]) -> dict[str, str]:
|
||||
"""解析 platforms 列表,返回平台到账号的映射
|
||||
|
||||
|
||||
Args:
|
||||
platforms: 格式为 ["platform:account"] 的列表,如 ["tg:123456789", "wx:wxid123"]
|
||||
|
||||
|
||||
Returns:
|
||||
字典,键为平台名,值为账号
|
||||
"""
|
||||
@@ -49,12 +46,12 @@ def parse_platform_accounts(platforms: list[str]) -> dict[str, str]:
|
||||
|
||||
def get_current_platform_account(platform: str, platform_accounts: dict[str, str], qq_account: str) -> str:
|
||||
"""根据当前平台获取对应的账号
|
||||
|
||||
|
||||
Args:
|
||||
platform: 当前消息的平台
|
||||
platform_accounts: 从 platforms 列表解析的平台账号映射
|
||||
qq_account: QQ 账号(兼容旧配置)
|
||||
|
||||
|
||||
Returns:
|
||||
当前平台对应的账号
|
||||
"""
|
||||
@@ -72,12 +69,12 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, bool, float
|
||||
"""检查消息是否提到了机器人(统一多平台实现)"""
|
||||
text = message.processed_plain_text or ""
|
||||
platform = getattr(message.message_info, "platform", "") or ""
|
||||
|
||||
|
||||
# 获取各平台账号
|
||||
platforms_list = getattr(global_config.bot, "platforms", []) or []
|
||||
platform_accounts = parse_platform_accounts(platforms_list)
|
||||
qq_account = str(getattr(global_config.bot, "qq_account", "") or "")
|
||||
|
||||
|
||||
# 获取当前平台对应的账号
|
||||
current_account = get_current_platform_account(platform, platform_accounts, qq_account)
|
||||
|
||||
@@ -146,7 +143,9 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, bool, float
|
||||
elif current_account:
|
||||
if re.search(rf"\[回复 (.+?)\({re.escape(current_account)}\):(.+?)\],说:", text):
|
||||
is_mentioned = True
|
||||
elif re.search(rf"\[回复<(.+?)(?=:{re.escape(current_account)}>)\:{re.escape(current_account)}>:(.+?)\],说:", text):
|
||||
elif re.search(
|
||||
rf"\[回复<(.+?)(?=:{re.escape(current_account)}>)\:{re.escape(current_account)}>:(.+?)\],说:", text
|
||||
):
|
||||
is_mentioned = True
|
||||
|
||||
# 6) 名称/别名 提及(去除 @/回复标记后再匹配)
|
||||
@@ -185,7 +184,6 @@ async def get_embedding(text, request_type="embedding") -> Optional[List[float]]
|
||||
return embedding
|
||||
|
||||
|
||||
|
||||
def split_into_sentences_w_remove_punctuation(text: str) -> list[str]:
|
||||
"""将文本分割成句子,并根据概率合并
|
||||
1. 识别分割点(, , 。 ; 空格),但如果分割点左右都是英文字母则不分割。
|
||||
@@ -221,14 +219,17 @@ def split_into_sentences_w_remove_punctuation(text: str) -> list[str]:
|
||||
while i < len(text):
|
||||
char = text[i]
|
||||
if char in separators:
|
||||
# 检查分割条件:如果分隔符左右都是英文字母,则不分割
|
||||
# 检查分割条件:如果空格左右都是英文字母、数字,或数字和英文之间,则不分割(仅对空格应用此规则)
|
||||
can_split = True
|
||||
if 0 < i < len(text) - 1:
|
||||
prev_char = text[i - 1]
|
||||
next_char = text[i + 1]
|
||||
# if is_english_letter(prev_char) and is_english_letter(next_char) and char == ' ': # 原计划只对空格应用此规则,现应用于所有分隔符
|
||||
if is_english_letter(prev_char) and is_english_letter(next_char):
|
||||
can_split = False
|
||||
# 只对空格应用"不分割数字和数字、数字和英文、英文和数字、英文和英文之间的空格"规则
|
||||
if char == " ":
|
||||
prev_is_alnum = prev_char.isdigit() or is_english_letter(prev_char)
|
||||
next_is_alnum = next_char.isdigit() or is_english_letter(next_char)
|
||||
if prev_is_alnum and next_is_alnum:
|
||||
can_split = False
|
||||
|
||||
if can_split:
|
||||
# 只有当当前段不为空时才添加
|
||||
@@ -328,6 +329,20 @@ def random_remove_punctuation(text: str) -> str:
|
||||
return result
|
||||
|
||||
|
||||
def _get_random_default_reply() -> str:
|
||||
"""获取随机默认回复"""
|
||||
default_replies = [
|
||||
f"{global_config.bot.nickname}不知道哦",
|
||||
f"{global_config.bot.nickname}不知道",
|
||||
"不知道哦",
|
||||
"不知道",
|
||||
"不晓得",
|
||||
"懒得说",
|
||||
"()",
|
||||
]
|
||||
return random.choice(default_replies)
|
||||
|
||||
|
||||
def process_llm_response(text: str, enable_splitter: bool = True, enable_chinese_typo: bool = True) -> list[str]:
|
||||
if not global_config.response_post_process.enable_response_post_process:
|
||||
return [text]
|
||||
@@ -356,7 +371,7 @@ def process_llm_response(text: str, enable_splitter: bool = True, enable_chinese
|
||||
# 如果基本上是中文,则进行长度过滤
|
||||
if get_western_ratio(cleaned_text) < 0.1 and len(cleaned_text) > max_length:
|
||||
logger.warning(f"回复过长 ({len(cleaned_text)} 字符),返回默认回复")
|
||||
return ["懒得说"]
|
||||
return [_get_random_default_reply()]
|
||||
|
||||
typo_generator = ChineseTypoGenerator(
|
||||
error_rate=global_config.chinese_typo.error_rate,
|
||||
@@ -374,15 +389,26 @@ def process_llm_response(text: str, enable_splitter: bool = True, enable_chinese
|
||||
for sentence in split_sentences:
|
||||
if global_config.chinese_typo.enable and enable_chinese_typo:
|
||||
typoed_text, typo_corrections = typo_generator.create_typo_sentence(sentence)
|
||||
sentences.append(typoed_text)
|
||||
if typo_corrections:
|
||||
sentences.append(typo_corrections)
|
||||
# 50%概率新增正确字/词,50%概率用正确分句替换错别字分句
|
||||
if random.random() < 0.5:
|
||||
sentences.append(typoed_text)
|
||||
sentences.append(typo_corrections)
|
||||
else:
|
||||
# 用正确的分句替换错别字分句
|
||||
sentences.append(sentence)
|
||||
else:
|
||||
sentences.append(typoed_text)
|
||||
else:
|
||||
sentences.append(sentence)
|
||||
|
||||
if len(sentences) > max_sentence_num:
|
||||
logger.warning(f"分割后消息数量过多 ({len(sentences)} 条),返回默认回复")
|
||||
return [f"{global_config.bot.nickname}不知道哦"]
|
||||
if global_config.response_splitter.enable_overflow_return_all:
|
||||
logger.warning(f"分割后消息数量过多 ({len(sentences)} 条),直接返回原文")
|
||||
sentences = [cleaned_text]
|
||||
else:
|
||||
logger.warning(f"分割后消息数量过多 ({len(sentences)} 条),返回默认回复")
|
||||
return [_get_random_default_reply()]
|
||||
|
||||
# if extracted_contents:
|
||||
# for content in extracted_contents:
|
||||
@@ -441,7 +467,6 @@ def calculate_typing_time(
|
||||
return total_time # 加上回车时间
|
||||
|
||||
|
||||
|
||||
def truncate_message(message: str, max_length=20) -> str:
|
||||
"""截断消息,使其不超过指定长度"""
|
||||
return f"{message[:max_length]}..." if len(message) > max_length else message
|
||||
@@ -518,7 +543,6 @@ def get_western_ratio(paragraph):
|
||||
return western_count / len(alnum_chars)
|
||||
|
||||
|
||||
|
||||
def translate_timestamp_to_human_readable(timestamp: float, mode: str = "normal") -> str:
|
||||
# sourcery skip: merge-comparisons, merge-duplicate-blocks, switch
|
||||
"""将时间戳转换为人类可读的时间格式
|
||||
|
||||
@@ -103,14 +103,16 @@ class ImageManager:
|
||||
invalid_values = ["", "None"]
|
||||
|
||||
# 清理 Images 表
|
||||
deleted_images = Images.delete().where(
|
||||
(Images.description >> None) | (Images.description << invalid_values)
|
||||
).execute()
|
||||
deleted_images = (
|
||||
Images.delete().where((Images.description >> None) | (Images.description << invalid_values)).execute()
|
||||
)
|
||||
|
||||
# 清理 ImageDescriptions 表
|
||||
deleted_descriptions = ImageDescriptions.delete().where(
|
||||
(ImageDescriptions.description >> None) | (ImageDescriptions.description << invalid_values)
|
||||
).execute()
|
||||
deleted_descriptions = (
|
||||
ImageDescriptions.delete()
|
||||
.where((ImageDescriptions.description >> None) | (ImageDescriptions.description << invalid_values))
|
||||
.execute()
|
||||
)
|
||||
|
||||
if deleted_images or deleted_descriptions:
|
||||
logger.info(f"[清理完成] 删除 Images: {deleted_images} 条, ImageDescriptions: {deleted_descriptions} 条")
|
||||
|
||||
@@ -220,7 +220,7 @@ class DatabaseActionRecords(BaseDataModel):
|
||||
chat_id: str,
|
||||
chat_info_stream_id: str,
|
||||
chat_info_platform: str,
|
||||
action_reasoning:str
|
||||
action_reasoning: str,
|
||||
):
|
||||
self.action_id = action_id
|
||||
self.time = time
|
||||
@@ -235,4 +235,4 @@ class DatabaseActionRecords(BaseDataModel):
|
||||
self.chat_id = chat_id
|
||||
self.chat_info_stream_id = chat_info_stream_id
|
||||
self.chat_info_platform = chat_info_platform
|
||||
self.action_reasoning = action_reasoning
|
||||
self.action_reasoning = action_reasoning
|
||||
|
||||
@@ -20,6 +20,8 @@ logger = get_logger("database_model")
|
||||
|
||||
# 定义一个基础模型是一个好习惯,所有其他模型都应继承自它。
|
||||
# 这允许您在一个地方为所有模型指定数据库。
|
||||
|
||||
|
||||
class BaseModel(Model):
|
||||
class Meta:
|
||||
# 将下面的 'db' 替换为您实际的数据库实例变量名。
|
||||
@@ -265,6 +267,7 @@ class PersonInfo(BaseModel):
|
||||
platform = TextField() # 平台
|
||||
user_id = TextField(index=True) # 用户ID
|
||||
nickname = TextField(null=True) # 用户昵称
|
||||
group_nick_name = TextField(null=True) # 群昵称列表 (JSON格式,存储 [{"group_id": str, "group_nick_name": str}])
|
||||
memory_points = TextField(null=True) # 个人印象的点
|
||||
know_times = FloatField(null=True) # 认识时间 (时间戳)
|
||||
know_since = FloatField(null=True) # 首次印象总结时间
|
||||
@@ -315,58 +318,91 @@ class Expression(BaseModel):
|
||||
class Meta:
|
||||
table_name = "expression"
|
||||
|
||||
class MemoryChest(BaseModel):
|
||||
|
||||
class Jargon(BaseModel):
|
||||
"""
|
||||
用于存储记忆仓库的模型
|
||||
用于存储俚语的模型
|
||||
"""
|
||||
|
||||
title = TextField() # 标题
|
||||
content = TextField() # 内容
|
||||
chat_id = TextField(null=True) # 聊天ID
|
||||
locked = BooleanField(default=False) # 是否锁定
|
||||
content = TextField()
|
||||
raw_content = TextField(null=True)
|
||||
type = TextField(null=True)
|
||||
translation = TextField(null=True)
|
||||
meaning = TextField(null=True)
|
||||
chat_id = TextField(index=True)
|
||||
is_global = BooleanField(default=False)
|
||||
count = IntegerField(default=0)
|
||||
is_jargon = BooleanField(null=True) # None表示未判定,True表示是黑话,False表示不是黑话
|
||||
last_inference_count = IntegerField(null=True) # 最后一次判定的count值,用于避免重启后重复判定
|
||||
is_complete = BooleanField(default=False) # 是否已完成所有推断(count>=100后不再推断)
|
||||
inference_with_context = TextField(null=True) # 基于上下文的推断结果(JSON格式)
|
||||
inference_content_only = TextField(null=True) # 仅基于词条的推断结果(JSON格式)
|
||||
|
||||
class Meta:
|
||||
table_name = "memory_chest"
|
||||
table_name = "jargon"
|
||||
|
||||
class MemoryConflict(BaseModel):
|
||||
|
||||
class ChatHistory(BaseModel):
|
||||
"""
|
||||
用于存储记忆整合过程中冲突内容的模型
|
||||
用于存储聊天历史概括的模型
|
||||
"""
|
||||
|
||||
conflict_content = TextField() # 冲突内容
|
||||
answer = TextField(null=True) # 回答内容
|
||||
create_time = FloatField() # 创建时间
|
||||
update_time = FloatField() # 更新时间
|
||||
context = TextField(null=True) # 上下文
|
||||
chat_id = TextField(null=True) # 聊天ID
|
||||
raise_time = FloatField(null=True) # 触发次数
|
||||
chat_id = TextField(index=True) # 聊天ID
|
||||
start_time = DoubleField() # 起始时间
|
||||
end_time = DoubleField() # 结束时间
|
||||
original_text = TextField() # 对话原文
|
||||
participants = TextField() # 参与的所有人的昵称,JSON格式存储
|
||||
theme = TextField() # 主题:这段对话的主要内容,一个简短的标题
|
||||
keywords = TextField() # 关键词:这段对话的关键词,JSON格式存储
|
||||
summary = TextField() # 概括:对这段话的平文本概括
|
||||
count = IntegerField(default=0) # 被检索次数
|
||||
forget_times = IntegerField(default=0) # 被遗忘检查的次数
|
||||
|
||||
class Meta:
|
||||
table_name = "memory_conflicts"
|
||||
table_name = "chat_history"
|
||||
|
||||
|
||||
class ThinkingBack(BaseModel):
|
||||
"""
|
||||
用于存储记忆检索思考过程的模型
|
||||
"""
|
||||
|
||||
chat_id = TextField(index=True) # 聊天ID
|
||||
question = TextField() # 提出的问题
|
||||
context = TextField(null=True) # 上下文信息
|
||||
found_answer = BooleanField(default=False) # 是否找到答案
|
||||
answer = TextField(null=True) # 答案内容
|
||||
thinking_steps = TextField(null=True) # 思考步骤(JSON格式)
|
||||
create_time = DoubleField() # 创建时间
|
||||
update_time = DoubleField() # 更新时间
|
||||
|
||||
class Meta:
|
||||
table_name = "thinking_back"
|
||||
|
||||
|
||||
MODELS = [
|
||||
ChatStreams,
|
||||
LLMUsage,
|
||||
Emoji,
|
||||
Messages,
|
||||
Images,
|
||||
ImageDescriptions,
|
||||
OnlineTime,
|
||||
PersonInfo,
|
||||
Expression,
|
||||
ActionRecords,
|
||||
Jargon,
|
||||
ChatHistory,
|
||||
ThinkingBack,
|
||||
]
|
||||
|
||||
|
||||
def create_tables():
|
||||
"""
|
||||
创建所有在模型中定义的数据库表。
|
||||
"""
|
||||
with db:
|
||||
db.create_tables(
|
||||
[
|
||||
ChatStreams,
|
||||
LLMUsage,
|
||||
Emoji,
|
||||
Messages,
|
||||
Images,
|
||||
ImageDescriptions,
|
||||
OnlineTime,
|
||||
PersonInfo,
|
||||
Expression,
|
||||
ActionRecords, # 添加 ActionRecords 到初始化列表
|
||||
MemoryChest,
|
||||
MemoryConflict, # 添加记忆冲突表
|
||||
]
|
||||
)
|
||||
db.create_tables(MODELS)
|
||||
|
||||
|
||||
def initialize_database(sync_constraints=False):
|
||||
@@ -379,24 +415,9 @@ def initialize_database(sync_constraints=False):
|
||||
如果为 True,会检查并修复字段的 NULL 约束不一致问题。
|
||||
"""
|
||||
|
||||
models = [
|
||||
ChatStreams,
|
||||
LLMUsage,
|
||||
Emoji,
|
||||
Messages,
|
||||
Images,
|
||||
ImageDescriptions,
|
||||
OnlineTime,
|
||||
PersonInfo,
|
||||
Expression,
|
||||
ActionRecords, # 添加 ActionRecords 到初始化列表
|
||||
MemoryChest,
|
||||
MemoryConflict,
|
||||
]
|
||||
|
||||
try:
|
||||
with db: # 管理 table_exists 检查的连接
|
||||
for model in models:
|
||||
for model in MODELS:
|
||||
table_name = model._meta.table_name
|
||||
if not db.table_exists(model):
|
||||
logger.warning(f"表 '{table_name}' 未找到,正在创建...")
|
||||
@@ -476,24 +497,9 @@ def sync_field_constraints():
|
||||
如果发现不一致,会自动修复字段约束。
|
||||
"""
|
||||
|
||||
models = [
|
||||
ChatStreams,
|
||||
LLMUsage,
|
||||
Emoji,
|
||||
Messages,
|
||||
Images,
|
||||
ImageDescriptions,
|
||||
OnlineTime,
|
||||
PersonInfo,
|
||||
Expression,
|
||||
ActionRecords,
|
||||
MemoryChest,
|
||||
MemoryConflict,
|
||||
]
|
||||
|
||||
try:
|
||||
with db:
|
||||
for model in models:
|
||||
for model in MODELS:
|
||||
table_name = model._meta.table_name
|
||||
if not db.table_exists(model):
|
||||
logger.warning(f"表 '{table_name}' 不存在,跳过约束检查")
|
||||
@@ -660,26 +666,11 @@ def check_field_constraints():
|
||||
用于在修复前预览需要修复的内容。
|
||||
"""
|
||||
|
||||
models = [
|
||||
ChatStreams,
|
||||
LLMUsage,
|
||||
Emoji,
|
||||
Messages,
|
||||
Images,
|
||||
ImageDescriptions,
|
||||
OnlineTime,
|
||||
PersonInfo,
|
||||
Expression,
|
||||
ActionRecords,
|
||||
MemoryChest,
|
||||
MemoryConflict,
|
||||
]
|
||||
|
||||
inconsistencies = {}
|
||||
|
||||
try:
|
||||
with db:
|
||||
for model in models:
|
||||
for model in MODELS:
|
||||
table_name = model._meta.table_name
|
||||
if not db.table_exists(model):
|
||||
continue
|
||||
|
||||
@@ -351,6 +351,7 @@ MODULE_COLORS = {
|
||||
# 核心模块
|
||||
"main": "\033[1;97m", # 亮白色+粗体 (主程序)
|
||||
"memory": "\033[38;5;34m", # 天蓝色
|
||||
"memory_retrieval": "\033[38;5;34m", # 天蓝色
|
||||
"config": "\033[93m", # 亮黄色
|
||||
"common": "\033[95m", # 亮紫色
|
||||
"tools": "\033[96m", # 亮青色
|
||||
@@ -372,6 +373,8 @@ MODULE_COLORS = {
|
||||
"chat_stream": "\033[38;5;51m", # 亮青色
|
||||
"message_storage": "\033[38;5;33m", # 深蓝色
|
||||
"expressor": "\033[38;5;166m", # 橙色
|
||||
# jargon相关
|
||||
"jargon": "\033[38;5;220m", # 金黄色,突出显示
|
||||
# 插件系统
|
||||
"plugins": "\033[31m", # 红色
|
||||
"plugin_api": "\033[33m", # 黄色
|
||||
@@ -440,6 +443,7 @@ MODULE_ALIASES = {
|
||||
"database_model": "数据库",
|
||||
"mood": "情绪",
|
||||
"memory": "记忆",
|
||||
"memory_retrieval": "回忆",
|
||||
"tool_executor": "工具",
|
||||
"hfc": "聊天节奏",
|
||||
"plugin_manager": "插件",
|
||||
@@ -450,6 +454,7 @@ MODULE_ALIASES = {
|
||||
"planner": "规划器",
|
||||
"config": "配置",
|
||||
"main": "主程序",
|
||||
"chat_history_summarizer": "聊天概括器",
|
||||
}
|
||||
|
||||
RESET_COLOR = "\033[0m"
|
||||
|
||||
@@ -18,8 +18,8 @@ class Server:
|
||||
|
||||
# 配置 CORS
|
||||
origins = [
|
||||
"http://localhost:3000", # 允许的前端源
|
||||
"http://127.0.0.1:3000",
|
||||
"http://localhost:7999", # 允许的前端源
|
||||
"http://127.0.0.1:7999",
|
||||
# 在生产环境中,您应该添加实际的前端域名
|
||||
]
|
||||
|
||||
|
||||
@@ -33,6 +33,7 @@ from src.config.official_configs import (
|
||||
MoodConfig,
|
||||
MemoryConfig,
|
||||
DebugConfig,
|
||||
JargonConfig,
|
||||
)
|
||||
|
||||
from .api_ada_configs import (
|
||||
@@ -55,7 +56,7 @@ TEMPLATE_DIR = os.path.join(PROJECT_ROOT, "template")
|
||||
|
||||
# 考虑到,实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码
|
||||
# 对该字段的更新,请严格参照语义化版本规范:https://semver.org/lang/zh-CN/
|
||||
MMC_VERSION = "0.11.1-snapshot.1"
|
||||
MMC_VERSION = "0.11.2"
|
||||
|
||||
|
||||
def get_key_comment(toml_table, key):
|
||||
@@ -354,6 +355,7 @@ class Config(ConfigBase):
|
||||
debug: DebugConfig
|
||||
mood: MoodConfig
|
||||
voice: VoiceConfig
|
||||
jargon: JargonConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -27,7 +27,7 @@ class BotConfig(ConfigBase):
|
||||
|
||||
nickname: str
|
||||
"""昵称"""
|
||||
|
||||
|
||||
platforms: list[str] = field(default_factory=lambda: [])
|
||||
"""其他平台列表"""
|
||||
|
||||
@@ -88,12 +88,6 @@ class ChatConfig(ConfigBase):
|
||||
mentioned_bot_reply: bool = True
|
||||
"""是否启用提及必回复"""
|
||||
|
||||
auto_chat_value: float = 1
|
||||
"""自动聊天,越小,麦麦主动聊天的概率越低"""
|
||||
|
||||
enable_auto_chat_value_rules: bool = True
|
||||
"""是否启用动态自动聊天频率规则"""
|
||||
|
||||
at_bot_inevitable_reply: float = 1
|
||||
"""@bot 必然回复,1为100%回复,0为不额外增幅"""
|
||||
|
||||
@@ -119,26 +113,12 @@ class ChatConfig(ConfigBase):
|
||||
["qq:114514:private", "00:00-23:59", 0.3],# 指定私聊全时段较安静
|
||||
]
|
||||
|
||||
匹配优先级: 先匹配指定 chat 流规则,再匹配全局规则(\"\").
|
||||
匹配优先级: 先匹配指定 chat 流规则,再匹配全局规则(\"\").
|
||||
时间区间支持跨夜,例如 "23:00-02:00"。
|
||||
"""
|
||||
|
||||
auto_chat_value_rules: list[dict] = field(default_factory=lambda: [])
|
||||
"""
|
||||
自动聊天频率规则列表,支持按聊天流/按日内时段配置。
|
||||
规则格式:{ target="platform:id:type" 或 "", time="HH:MM-HH:MM", value=0.5 }
|
||||
|
||||
示例:
|
||||
[
|
||||
["", "00:00-08:59", 0.2], # 全局规则:凌晨到早上更安静
|
||||
["", "09:00-22:59", 1.0], # 全局规则:白天正常
|
||||
["qq:1919810:group", "20:00-23:59", 0.6], # 指定群在晚高峰降低发言
|
||||
["qq:114514:private", "00:00-23:59", 0.3],# 指定私聊全时段较安静
|
||||
]
|
||||
|
||||
匹配优先级: 先匹配指定 chat 流规则,再匹配全局规则(\"\").
|
||||
时间区间支持跨夜,例如 "23:00-02:00"。
|
||||
"""
|
||||
include_planner_reasoning: bool = False
|
||||
"""是否将planner推理加入replyer,默认关闭(不加入)"""
|
||||
|
||||
def _parse_stream_config_to_chat_id(self, stream_config_str: str) -> Optional[str]:
|
||||
"""与 ChatStream.get_stream_id 一致地从 "platform:id:type" 生成 chat_id。"""
|
||||
@@ -245,61 +225,6 @@ class ChatConfig(ConfigBase):
|
||||
# 3) 未命中规则返回基础值
|
||||
return self.talk_value
|
||||
|
||||
def get_auto_chat_value(self, chat_id: Optional[str]) -> float:
|
||||
"""根据规则返回当前 chat 的动态 auto_chat_value,未匹配则回退到基础值。"""
|
||||
if not self.enable_auto_chat_value_rules or not self.auto_chat_value_rules:
|
||||
return self.auto_chat_value
|
||||
|
||||
now_min = self._now_minutes()
|
||||
|
||||
# 1) 先尝试匹配指定 chat 的规则
|
||||
if chat_id:
|
||||
for rule in self.auto_chat_value_rules:
|
||||
if not isinstance(rule, dict):
|
||||
continue
|
||||
target = rule.get("target", "")
|
||||
time_range = rule.get("time", "")
|
||||
value = rule.get("value", None)
|
||||
if not isinstance(time_range, str):
|
||||
continue
|
||||
# 跳过全局
|
||||
if target == "":
|
||||
continue
|
||||
config_chat_id = self._parse_stream_config_to_chat_id(str(target))
|
||||
if config_chat_id is None or config_chat_id != chat_id:
|
||||
continue
|
||||
parsed = self._parse_range(time_range)
|
||||
if not parsed:
|
||||
continue
|
||||
start_min, end_min = parsed
|
||||
if self._in_range(now_min, start_min, end_min):
|
||||
try:
|
||||
return float(value)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
# 2) 再匹配全局规则("")
|
||||
for rule in self.auto_chat_value_rules:
|
||||
if not isinstance(rule, dict):
|
||||
continue
|
||||
target = rule.get("target", None)
|
||||
time_range = rule.get("time", "")
|
||||
value = rule.get("value", None)
|
||||
if target != "" or not isinstance(time_range, str):
|
||||
continue
|
||||
parsed = self._parse_range(time_range)
|
||||
if not parsed:
|
||||
continue
|
||||
start_min, end_min = parsed
|
||||
if self._in_range(now_min, start_min, end_min):
|
||||
try:
|
||||
return float(value)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
# 3) 未命中规则返回基础值
|
||||
return self.auto_chat_value
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageReceiveConfig(ConfigBase):
|
||||
@@ -311,23 +236,24 @@ class MessageReceiveConfig(ConfigBase):
|
||||
ban_msgs_regex: set[str] = field(default_factory=lambda: set())
|
||||
"""过滤正则表达式列表"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class MemoryConfig(ConfigBase):
|
||||
"""记忆配置类"""
|
||||
|
||||
max_memory_number: int = 100
|
||||
"""记忆最大数量"""
|
||||
|
||||
memory_build_frequency: int = 1
|
||||
"""记忆构建频率"""
|
||||
|
||||
max_agent_iterations: int = 5
|
||||
"""Agent最多迭代轮数(最低为1)"""
|
||||
|
||||
def __post_init__(self):
|
||||
"""验证配置值"""
|
||||
if self.max_agent_iterations < 1:
|
||||
raise ValueError(f"max_agent_iterations 必须至少为1,当前值: {self.max_agent_iterations}")
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExpressionConfig(ConfigBase):
|
||||
"""表达配置类"""
|
||||
|
||||
mode: str = "classic"
|
||||
"""表达方式模式,可选:classic经典模式,exp_model 表达模型模式"""
|
||||
|
||||
learning_list: list[list] = field(default_factory=lambda: [])
|
||||
"""
|
||||
表达学习配置列表,支持按聊天流配置
|
||||
@@ -494,13 +420,14 @@ class MoodConfig(ConfigBase):
|
||||
|
||||
enable_mood: bool = True
|
||||
"""是否启用情绪系统"""
|
||||
|
||||
|
||||
mood_update_threshold: float = 1
|
||||
"""情绪更新阈值,越高,更新越慢"""
|
||||
|
||||
|
||||
emotion_style: str = "情绪较为稳定,但遭遇特定事件的时候起伏较大"
|
||||
"""情感特征,影响情绪的变化情况"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class VoiceConfig(ConfigBase):
|
||||
"""语音识别配置类"""
|
||||
@@ -626,6 +553,9 @@ class ResponseSplitterConfig(ConfigBase):
|
||||
enable_kaomoji_protection: bool = False
|
||||
"""是否启用颜文字保护"""
|
||||
|
||||
enable_overflow_return_all: bool = False
|
||||
"""是否在超出句子数量限制时合并后一次性返回"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class TelemetryConfig(ConfigBase):
|
||||
@@ -641,13 +571,19 @@ class DebugConfig(ConfigBase):
|
||||
|
||||
show_prompt: bool = False
|
||||
"""是否显示prompt"""
|
||||
|
||||
|
||||
show_replyer_prompt: bool = True
|
||||
"""是否显示回复器prompt"""
|
||||
|
||||
|
||||
show_replyer_reasoning: bool = True
|
||||
"""是否显示回复器推理"""
|
||||
|
||||
show_jargon_prompt: bool = False
|
||||
"""是否显示jargon相关提示词"""
|
||||
|
||||
show_planner_prompt: bool = False
|
||||
"""是否显示planner相关提示词"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExperimentalConfig(ConfigBase):
|
||||
@@ -656,6 +592,25 @@ class ExperimentalConfig(ConfigBase):
|
||||
enable_friend_chat: bool = False
|
||||
"""是否启用好友聊天"""
|
||||
|
||||
chat_prompts: list[str] = field(default_factory=lambda: [])
|
||||
"""
|
||||
为指定聊天添加额外的prompt配置列表
|
||||
格式: ["platform:id:type:prompt内容", ...]
|
||||
|
||||
示例:
|
||||
[
|
||||
"qq:114514:group:这是一个摄影群,你精通摄影知识",
|
||||
"qq:19198:group:这是一个二次元交流群",
|
||||
"qq:114514:private:这是你与好朋友的私聊"
|
||||
]
|
||||
|
||||
说明:
|
||||
- platform: 平台名称,如 "qq"
|
||||
- id: 群ID或用户ID
|
||||
- type: "group" 或 "private"
|
||||
- prompt内容: 要添加的额外prompt文本
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class MaimMessageConfig(ConfigBase):
|
||||
@@ -692,6 +647,9 @@ class LPMMKnowledgeConfig(ConfigBase):
|
||||
|
||||
enable: bool = True
|
||||
"""是否启用LPMM知识库"""
|
||||
|
||||
lpmm_mode: Literal["classic", "agent"] = "classic"
|
||||
"""LPMM知识库模式,可选:classic经典模式,agent 模式,结合最新的记忆一同使用"""
|
||||
|
||||
rag_synonym_search_top_k: int = 10
|
||||
"""RAG同义词搜索的Top K数量"""
|
||||
@@ -725,3 +683,11 @@ class LPMMKnowledgeConfig(ConfigBase):
|
||||
|
||||
embedding_dimension: int = 1024
|
||||
"""嵌入向量维度,应该与模型的输出维度一致"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class JargonConfig(ConfigBase):
|
||||
"""Jargon配置类"""
|
||||
|
||||
all_global: bool = False
|
||||
"""是否将所有新增的jargon项目默认为全局(is_global=True),chat_id记录第一次存储时的id"""
|
||||
@@ -3,31 +3,30 @@ import difflib
|
||||
import random
|
||||
from datetime import datetime
|
||||
from typing import Optional, List, Dict
|
||||
from collections import defaultdict
|
||||
|
||||
|
||||
def filter_message_content(content: Optional[str]) -> str:
|
||||
"""
|
||||
过滤消息内容,移除回复、@、图片等格式
|
||||
|
||||
|
||||
Args:
|
||||
content: 原始消息内容
|
||||
|
||||
|
||||
Returns:
|
||||
str: 过滤后的内容
|
||||
"""
|
||||
if not content:
|
||||
return ""
|
||||
|
||||
|
||||
# 移除以[回复开头、]结尾的部分,包括后面的",说:"部分
|
||||
content = re.sub(r'\[回复.*?\],说:\s*', '', content)
|
||||
content = re.sub(r"\[回复.*?\],说:\s*", "", content)
|
||||
# 移除@<...>格式的内容
|
||||
content = re.sub(r'@<[^>]*>', '', content)
|
||||
content = re.sub(r"@<[^>]*>", "", content)
|
||||
# 移除[picid:...]格式的图片ID
|
||||
content = re.sub(r'\[picid:[^\]]*\]', '', content)
|
||||
content = re.sub(r"\[picid:[^\]]*\]", "", content)
|
||||
# 移除[表情包:...]格式的内容
|
||||
content = re.sub(r'\[表情包:[^\]]*\]', '', content)
|
||||
|
||||
content = re.sub(r"\[表情包:[^\]]*\]", "", content)
|
||||
|
||||
return content.strip()
|
||||
|
||||
|
||||
@@ -35,11 +34,11 @@ def calculate_similarity(text1: str, text2: str) -> float:
|
||||
"""
|
||||
计算两个文本的相似度,返回0-1之间的值
|
||||
使用SequenceMatcher计算相似度
|
||||
|
||||
|
||||
Args:
|
||||
text1: 第一个文本
|
||||
text2: 第二个文本
|
||||
|
||||
|
||||
Returns:
|
||||
float: 相似度值,范围0-1
|
||||
"""
|
||||
@@ -49,10 +48,10 @@ def calculate_similarity(text1: str, text2: str) -> float:
|
||||
def format_create_date(timestamp: float) -> str:
|
||||
"""
|
||||
将时间戳格式化为可读的日期字符串
|
||||
|
||||
|
||||
Args:
|
||||
timestamp: 时间戳
|
||||
|
||||
|
||||
Returns:
|
||||
str: 格式化后的日期字符串
|
||||
"""
|
||||
@@ -65,11 +64,11 @@ def format_create_date(timestamp: float) -> str:
|
||||
def weighted_sample(population: List[Dict], k: int) -> List[Dict]:
|
||||
"""
|
||||
随机抽样函数
|
||||
|
||||
|
||||
Args:
|
||||
population: 总体数据列表
|
||||
k: 需要抽取的数量
|
||||
|
||||
|
||||
Returns:
|
||||
List[Dict]: 抽取的数据列表
|
||||
"""
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import time
|
||||
import json
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import List, Optional, Tuple
|
||||
import traceback
|
||||
from src.common.logger import get_logger
|
||||
@@ -15,7 +14,6 @@ from src.chat.utils.chat_message_builder import (
|
||||
)
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.express.style_learner import style_learner_manager
|
||||
from src.express.express_utils import filter_message_content, calculate_similarity
|
||||
from json_repair import repair_json
|
||||
|
||||
@@ -158,8 +156,6 @@ class ExpressionLearner:
|
||||
traceback.print_exc()
|
||||
return
|
||||
|
||||
|
||||
|
||||
async def learn_and_store(self, num: int = 10) -> List[Tuple[str, str, str]]:
|
||||
"""
|
||||
学习并存储表达方式
|
||||
@@ -169,7 +165,7 @@ class ExpressionLearner:
|
||||
if learnt_expressions is None:
|
||||
logger.info("没有学习到表达风格")
|
||||
return []
|
||||
|
||||
|
||||
# 展示学到的表达方式
|
||||
learnt_expressions_str = ""
|
||||
for (
|
||||
@@ -183,10 +179,7 @@ class ExpressionLearner:
|
||||
|
||||
current_time = time.time()
|
||||
|
||||
# 存储到数据库 Expression 表并训练 style_learner
|
||||
has_new_expressions = False # 记录是否有新的表达方式
|
||||
learner = style_learner_manager.get_learner(self.chat_id) # 获取 learner 实例
|
||||
|
||||
# 存储到数据库 Expression 表
|
||||
for (
|
||||
situation,
|
||||
style,
|
||||
@@ -195,16 +188,13 @@ class ExpressionLearner:
|
||||
) in learnt_expressions:
|
||||
# 查找是否已存在相似表达方式
|
||||
query = Expression.select().where(
|
||||
(Expression.chat_id == self.chat_id)
|
||||
& (Expression.situation == situation)
|
||||
& (Expression.style == style)
|
||||
(Expression.chat_id == self.chat_id) & (Expression.situation == situation) & (Expression.style == style)
|
||||
)
|
||||
if query.exists():
|
||||
# 表达方式完全相同,只更新时间戳
|
||||
expr_obj = query.get()
|
||||
expr_obj.last_active_time = current_time
|
||||
expr_obj.save()
|
||||
continue
|
||||
else:
|
||||
Expression.create(
|
||||
situation=situation,
|
||||
@@ -215,40 +205,7 @@ class ExpressionLearner:
|
||||
context=context,
|
||||
up_content=up_content,
|
||||
)
|
||||
has_new_expressions = True
|
||||
|
||||
# 训练 style_learner(up_content 和 style 必定存在)
|
||||
try:
|
||||
learner.add_style(style, situation)
|
||||
|
||||
# 学习映射关系
|
||||
success = style_learner_manager.learn_mapping(
|
||||
self.chat_id,
|
||||
up_content,
|
||||
style
|
||||
)
|
||||
if success:
|
||||
logger.debug(f"StyleLearner学习成功: {self.chat_id} - {up_content} -> {style}" + (f" (situation: {situation})" if situation else ""))
|
||||
else:
|
||||
logger.warning(f"StyleLearner学习失败: {self.chat_id} - {up_content} -> {style}")
|
||||
except Exception as e:
|
||||
logger.error(f"StyleLearner学习异常: {self.chat_id} - {e}")
|
||||
|
||||
|
||||
# 保存当前聊天室的 style_learner 模型
|
||||
if has_new_expressions:
|
||||
try:
|
||||
logger.info(f"开始保存聊天室 {self.chat_id} 的 StyleLearner 模型...")
|
||||
save_success = learner.save(style_learner_manager.model_save_path)
|
||||
|
||||
if save_success:
|
||||
logger.info(f"StyleLearner 模型保存成功,聊天室: {self.chat_id}")
|
||||
else:
|
||||
logger.warning(f"StyleLearner 模型保存失败,聊天室: {self.chat_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"StyleLearner 模型保存异常: {e}")
|
||||
|
||||
|
||||
return learnt_expressions
|
||||
|
||||
async def match_expression_context(
|
||||
@@ -334,7 +291,7 @@ class ExpressionLearner:
|
||||
|
||||
matched_expressions = []
|
||||
used_pair_indices = set() # 用于跟踪已经使用的expression_pair索引
|
||||
|
||||
|
||||
logger.debug(f"match_responses 类型: {type(match_responses)}, 长度: {len(match_responses)}")
|
||||
logger.debug(f"match_responses 内容: {match_responses}")
|
||||
|
||||
@@ -344,12 +301,12 @@ class ExpressionLearner:
|
||||
if not isinstance(match_response, dict):
|
||||
logger.error(f"match_response 不是字典类型: {type(match_response)}, 内容: {match_response}")
|
||||
continue
|
||||
|
||||
|
||||
# 获取表达方式序号
|
||||
if "expression_pair" not in match_response:
|
||||
logger.error(f"match_response 缺少 'expression_pair' 字段: {match_response}")
|
||||
continue
|
||||
|
||||
|
||||
pair_index = int(match_response["expression_pair"]) - 1 # 转换为0-based索引
|
||||
|
||||
# 检查索引是否有效且未被使用过
|
||||
@@ -367,9 +324,7 @@ class ExpressionLearner:
|
||||
|
||||
return matched_expressions
|
||||
|
||||
async def learn_expression(
|
||||
self, num: int = 10
|
||||
) -> Optional[List[Tuple[str, str, str, str]]]:
|
||||
async def learn_expression(self, num: int = 10) -> Optional[List[Tuple[str, str, str, str]]]:
|
||||
"""从指定聊天流学习表达方式
|
||||
|
||||
Args:
|
||||
@@ -409,7 +364,6 @@ class ExpressionLearner:
|
||||
expressions: List[Tuple[str, str]] = self.parse_expression_response(response)
|
||||
# logger.debug(f"学习{type_str}的response: {response}")
|
||||
|
||||
|
||||
# 对表达方式溯源
|
||||
matched_expressions: List[Tuple[str, str, str]] = await self.match_expression_context(
|
||||
expressions, random_msg_match_str
|
||||
@@ -426,17 +380,17 @@ class ExpressionLearner:
|
||||
if similarity >= 0.85: # 85%相似度阈值
|
||||
pos = i
|
||||
break
|
||||
|
||||
|
||||
if pos is None or pos == 0:
|
||||
# 没有匹配到目标句或没有上一句,跳过该表达
|
||||
continue
|
||||
|
||||
|
||||
# 检查目标句是否为空
|
||||
target_content = bare_lines[pos][1]
|
||||
if not target_content:
|
||||
# 目标句为空,跳过该表达
|
||||
continue
|
||||
|
||||
|
||||
prev_original_idx = bare_lines[pos - 1][0]
|
||||
up_content = filter_message_content(random_msg[prev_original_idx].processed_plain_text or "")
|
||||
if not up_content:
|
||||
@@ -449,7 +403,6 @@ class ExpressionLearner:
|
||||
|
||||
return filtered_with_up
|
||||
|
||||
|
||||
def parse_expression_response(self, response: str) -> List[Tuple[str, str, str]]:
|
||||
"""
|
||||
解析LLM返回的表达风格总结,每一行提取"当"和"使用"之间的内容,存储为(situation, style)元组
|
||||
@@ -483,21 +436,21 @@ class ExpressionLearner:
|
||||
def _build_bare_lines(self, messages: List) -> List[Tuple[int, str]]:
|
||||
"""
|
||||
为每条消息构建精简文本列表,保留到原消息索引的映射
|
||||
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
|
||||
|
||||
Returns:
|
||||
List[Tuple[int, str]]: (original_index, bare_content) 元组列表
|
||||
"""
|
||||
bare_lines: List[Tuple[int, str]] = []
|
||||
|
||||
|
||||
for idx, msg in enumerate(messages):
|
||||
content = msg.processed_plain_text or ""
|
||||
content = filter_message_content(content)
|
||||
# 即使content为空也要记录,防止错位
|
||||
bare_lines.append((idx, content))
|
||||
|
||||
|
||||
return bare_lines
|
||||
|
||||
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
import json
|
||||
import time
|
||||
import random
|
||||
import hashlib
|
||||
import re
|
||||
|
||||
from typing import List, Dict, Optional, Any, Tuple
|
||||
from json_repair import repair_json
|
||||
@@ -12,27 +10,25 @@ from src.config.config import global_config, model_config
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import Expression
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.express.style_learner import style_learner_manager
|
||||
from src.express.express_utils import filter_message_content, weighted_sample
|
||||
from src.express.express_utils import weighted_sample
|
||||
|
||||
logger = get_logger("expression_selector")
|
||||
|
||||
|
||||
def init_prompt():
|
||||
expression_evaluation_prompt = """
|
||||
以下是正在进行的聊天内容:
|
||||
{chat_observe_info}
|
||||
expression_evaluation_prompt = """{chat_observe_info}
|
||||
|
||||
你的名字是{bot_name}{target_message}
|
||||
{reply_reason_block}
|
||||
|
||||
以下是可选的表达情境:
|
||||
{all_situations}
|
||||
|
||||
请你分析聊天内容的语境、情绪、话题类型,从上述情境中选择最适合当前聊天情境的,最多{max_num}个情境。
|
||||
考虑因素包括:
|
||||
1. 聊天的情绪氛围(轻松、严肃、幽默等)
|
||||
2. 话题类型(日常、技术、游戏、情感等)
|
||||
3. 情境与当前语境的匹配度
|
||||
1.聊天的情绪氛围(轻松、严肃、幽默等)
|
||||
2.话题类型(日常、技术、游戏、情感等)
|
||||
3.情境与当前语境的匹配度
|
||||
{target_message_extra_block}
|
||||
|
||||
请以JSON格式输出,只需要输出选中的情境编号:
|
||||
@@ -46,6 +42,8 @@ def init_prompt():
|
||||
Prompt(expression_evaluation_prompt, "expression_evaluation_prompt")
|
||||
|
||||
|
||||
|
||||
|
||||
class ExpressionSelector:
|
||||
def __init__(self):
|
||||
self.llm_model = LLMRequest(
|
||||
@@ -115,90 +113,14 @@ class ExpressionSelector:
|
||||
return group_chat_ids
|
||||
return [chat_id]
|
||||
|
||||
def get_model_predicted_expressions(self, chat_id: str, target_message: str, total_num: int = 10) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
使用 style_learner 模型预测最合适的表达方式
|
||||
|
||||
Args:
|
||||
chat_id: 聊天室ID
|
||||
target_message: 目标消息内容
|
||||
total_num: 需要预测的数量
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 预测的表达方式列表
|
||||
"""
|
||||
try:
|
||||
# 过滤目标消息内容,移除回复、表情包等特殊格式
|
||||
filtered_target_message = filter_message_content(target_message)
|
||||
|
||||
logger.info(f"为{chat_id} 预测表达方式,过滤后的目标消息内容: {filtered_target_message}")
|
||||
|
||||
# 支持多chat_id合并预测
|
||||
related_chat_ids = self.get_related_chat_ids(chat_id)
|
||||
|
||||
|
||||
predicted_expressions = []
|
||||
|
||||
# 为每个相关的chat_id进行预测
|
||||
for related_chat_id in related_chat_ids:
|
||||
try:
|
||||
# 使用 style_learner 预测最合适的风格
|
||||
best_style, scores = style_learner_manager.predict_style(
|
||||
related_chat_id, filtered_target_message, top_k=total_num
|
||||
)
|
||||
|
||||
if best_style and scores:
|
||||
# 获取预测风格的完整信息
|
||||
learner = style_learner_manager.get_learner(related_chat_id)
|
||||
style_id, situation = learner.get_style_info(best_style)
|
||||
|
||||
if style_id and situation:
|
||||
# 从数据库查找对应的表达记录
|
||||
expr_query = Expression.select().where(
|
||||
(Expression.chat_id == related_chat_id) &
|
||||
(Expression.situation == situation) &
|
||||
(Expression.style == best_style)
|
||||
)
|
||||
|
||||
if expr_query.exists():
|
||||
expr = expr_query.get()
|
||||
predicted_expressions.append({
|
||||
"id": expr.id,
|
||||
"situation": expr.situation,
|
||||
"style": expr.style,
|
||||
"last_active_time": expr.last_active_time,
|
||||
"source_id": expr.chat_id,
|
||||
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
|
||||
"prediction_score": scores.get(best_style, 0.0),
|
||||
"prediction_input": filtered_target_message
|
||||
})
|
||||
else:
|
||||
logger.warning(f"为聊天室 {related_chat_id} 预测表达方式失败: {best_style} 没有找到对应的表达方式")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"为聊天室 {related_chat_id} 预测表达方式失败: {e}")
|
||||
continue
|
||||
|
||||
# 按预测分数排序,取前 total_num 个
|
||||
predicted_expressions.sort(key=lambda x: x.get("prediction_score", 0.0), reverse=True)
|
||||
selected_expressions = predicted_expressions[:total_num]
|
||||
|
||||
logger.info(f"为{chat_id} 预测到 {len(selected_expressions)} 个表达方式")
|
||||
return selected_expressions
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"模型预测表达方式失败: {e}")
|
||||
# 如果预测失败,回退到随机选择
|
||||
return self._random_expressions(chat_id, total_num)
|
||||
|
||||
def _random_expressions(self, chat_id: str, total_num: int) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
随机选择表达方式
|
||||
|
||||
|
||||
Args:
|
||||
chat_id: 聊天室ID
|
||||
total_num: 需要选择的数量
|
||||
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 随机选择的表达方式列表
|
||||
"""
|
||||
@@ -207,9 +129,7 @@ class ExpressionSelector:
|
||||
related_chat_ids = self.get_related_chat_ids(chat_id)
|
||||
|
||||
# 优化:一次性查询所有相关chat_id的表达方式
|
||||
style_query = Expression.select().where(
|
||||
(Expression.chat_id.in_(related_chat_ids))
|
||||
)
|
||||
style_query = Expression.select().where((Expression.chat_id.in_(related_chat_ids)))
|
||||
|
||||
style_exprs = [
|
||||
{
|
||||
@@ -228,31 +148,32 @@ class ExpressionSelector:
|
||||
selected_style = weighted_sample(style_exprs, total_num)
|
||||
else:
|
||||
selected_style = []
|
||||
|
||||
|
||||
logger.info(f"随机选择,为聊天室 {chat_id} 选择了 {len(selected_style)} 个表达方式")
|
||||
return selected_style
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"随机选择表达方式失败: {e}")
|
||||
return []
|
||||
|
||||
|
||||
async def select_suitable_expressions(
|
||||
self,
|
||||
chat_id: str,
|
||||
chat_info: str,
|
||||
max_num: int = 10,
|
||||
target_message: Optional[str] = None,
|
||||
reply_reason: Optional[str] = None,
|
||||
) -> Tuple[List[Dict[str, Any]], List[int]]:
|
||||
"""
|
||||
根据配置模式选择适合的表达方式
|
||||
|
||||
选择适合的表达方式(使用classic模式:随机选择+LLM选择)
|
||||
|
||||
Args:
|
||||
chat_id: 聊天流ID
|
||||
chat_info: 聊天内容信息
|
||||
max_num: 最大选择数量
|
||||
target_message: 目标消息内容
|
||||
|
||||
reply_reason: planner给出的回复理由
|
||||
|
||||
Returns:
|
||||
Tuple[List[Dict[str, Any]], List[int]]: 选中的表达方式列表和ID列表
|
||||
"""
|
||||
@@ -261,53 +182,9 @@ class ExpressionSelector:
|
||||
logger.debug(f"聊天流 {chat_id} 不允许使用表达,返回空列表")
|
||||
return [], []
|
||||
|
||||
# 获取配置模式
|
||||
expression_mode = global_config.expression.mode
|
||||
|
||||
if expression_mode == "exp_model":
|
||||
# exp_model模式:直接使用模型预测,不经过LLM
|
||||
logger.debug(f"使用exp_model模式为聊天流 {chat_id} 选择表达方式")
|
||||
return await self._select_expressions_model_only(chat_id, target_message, max_num)
|
||||
elif expression_mode == "classic":
|
||||
# classic模式:随机选择+LLM选择
|
||||
logger.debug(f"使用classic模式为聊天流 {chat_id} 选择表达方式")
|
||||
return await self._select_expressions_classic(chat_id, chat_info, max_num, target_message)
|
||||
else:
|
||||
logger.warning(f"未知的表达模式: {expression_mode},回退到classic模式")
|
||||
return await self._select_expressions_classic(chat_id, chat_info, max_num, target_message)
|
||||
|
||||
async def _select_expressions_model_only(
|
||||
self,
|
||||
chat_id: str,
|
||||
target_message: str,
|
||||
max_num: int = 10,
|
||||
) -> Tuple[List[Dict[str, Any]], List[int]]:
|
||||
"""
|
||||
exp_model模式:直接使用模型预测,不经过LLM
|
||||
|
||||
Args:
|
||||
chat_id: 聊天流ID
|
||||
target_message: 目标消息内容
|
||||
max_num: 最大选择数量
|
||||
|
||||
Returns:
|
||||
Tuple[List[Dict[str, Any]], List[int]]: 选中的表达方式列表和ID列表
|
||||
"""
|
||||
try:
|
||||
# 使用模型预测最合适的表达方式
|
||||
selected_expressions = self.get_model_predicted_expressions(chat_id, target_message, max_num)
|
||||
selected_ids = [expr["id"] for expr in selected_expressions]
|
||||
|
||||
# 更新last_active_time
|
||||
if selected_expressions:
|
||||
self.update_expressions_last_active_time(selected_expressions)
|
||||
|
||||
logger.info(f"exp_model模式为聊天流 {chat_id} 选择了 {len(selected_expressions)} 个表达方式")
|
||||
return selected_expressions, selected_ids
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"exp_model模式选择表达方式失败: {e}")
|
||||
return [], []
|
||||
# 使用classic模式(随机选择+LLM选择)
|
||||
logger.debug(f"使用classic模式为聊天流 {chat_id} 选择表达方式")
|
||||
return await self._select_expressions_classic(chat_id, chat_info, max_num, target_message, reply_reason)
|
||||
|
||||
async def _select_expressions_classic(
|
||||
self,
|
||||
@@ -315,16 +192,18 @@ class ExpressionSelector:
|
||||
chat_info: str,
|
||||
max_num: int = 10,
|
||||
target_message: Optional[str] = None,
|
||||
reply_reason: Optional[str] = None,
|
||||
) -> Tuple[List[Dict[str, Any]], List[int]]:
|
||||
"""
|
||||
classic模式:随机选择+LLM选择
|
||||
|
||||
|
||||
Args:
|
||||
chat_id: 聊天流ID
|
||||
chat_info: 聊天内容信息
|
||||
max_num: 最大选择数量
|
||||
target_message: 目标消息内容
|
||||
|
||||
reply_reason: planner给出的回复理由
|
||||
|
||||
Returns:
|
||||
Tuple[List[Dict[str, Any]], List[int]]: 选中的表达方式列表和ID列表
|
||||
"""
|
||||
@@ -353,25 +232,38 @@ class ExpressionSelector:
|
||||
all_situations_str = "\n".join(all_situations)
|
||||
|
||||
if target_message:
|
||||
target_message_str = f",现在你想要回复消息:{target_message}"
|
||||
target_message_str = f",现在你想要对这条消息进行回复:“{target_message}”"
|
||||
target_message_extra_block = "4.考虑你要回复的目标消息"
|
||||
else:
|
||||
target_message_str = ""
|
||||
target_message_extra_block = ""
|
||||
|
||||
chat_context = f"以下是正在进行的聊天内容:{chat_info}"
|
||||
|
||||
# 构建reply_reason块
|
||||
if reply_reason:
|
||||
reply_reason_block = f"你的回复理由是:{reply_reason}"
|
||||
chat_context = ""
|
||||
else:
|
||||
reply_reason_block = ""
|
||||
|
||||
# 3. 构建prompt(只包含情境,不包含完整的表达方式)
|
||||
prompt = (await global_prompt_manager.get_prompt_async("expression_evaluation_prompt")).format(
|
||||
bot_name=global_config.bot.nickname,
|
||||
chat_observe_info=chat_info,
|
||||
chat_observe_info=chat_context,
|
||||
all_situations=all_situations_str,
|
||||
max_num=max_num,
|
||||
target_message=target_message_str,
|
||||
target_message_extra_block=target_message_extra_block,
|
||||
reply_reason_block=reply_reason_block,
|
||||
)
|
||||
|
||||
# 4. 调用LLM
|
||||
content, (reasoning_content, model_name, _) = await self.llm_model.generate_response_async(prompt=prompt)
|
||||
|
||||
|
||||
# print(prompt)
|
||||
|
||||
if not content:
|
||||
logger.warning("LLM返回空结果")
|
||||
return [], []
|
||||
@@ -425,17 +317,13 @@ class ExpressionSelector:
|
||||
updates_by_key[key] = expr
|
||||
for chat_id, situation, style in updates_by_key:
|
||||
query = Expression.select().where(
|
||||
(Expression.chat_id == chat_id)
|
||||
& (Expression.situation == situation)
|
||||
& (Expression.style == style)
|
||||
(Expression.chat_id == chat_id) & (Expression.situation == situation) & (Expression.style == style)
|
||||
)
|
||||
if query.exists():
|
||||
expr_obj = query.get()
|
||||
expr_obj.last_active_time = time.time()
|
||||
expr_obj.save()
|
||||
logger.debug(
|
||||
"表达方式激活: 更新last_active_time in db"
|
||||
)
|
||||
logger.debug("表达方式激活: 更新last_active_time in db")
|
||||
|
||||
|
||||
init_prompt()
|
||||
|
||||
@@ -1,141 +0,0 @@
|
||||
from typing import Dict, Optional, Tuple, List
|
||||
from collections import Counter, defaultdict
|
||||
import pickle
|
||||
import os
|
||||
|
||||
from .tokenizer import Tokenizer
|
||||
from .online_nb import OnlineNaiveBayes
|
||||
|
||||
class ExpressorModel:
|
||||
"""
|
||||
直接使用朴素贝叶斯精排(可在线学习)
|
||||
支持存储situation字段,不参与计算,仅与style对应
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
alpha: float = 0.5,
|
||||
beta: float = 0.5,
|
||||
gamma: float = 1.0,
|
||||
vocab_size: int = 200000,
|
||||
use_jieba: bool = True):
|
||||
self.tokenizer = Tokenizer(stopwords=set(), use_jieba=use_jieba)
|
||||
self.nb = OnlineNaiveBayes(alpha=alpha, beta=beta, gamma=gamma, vocab_size=vocab_size)
|
||||
self._candidates: Dict[str, str] = {} # cid -> text (style)
|
||||
self._situations: Dict[str, str] = {} # cid -> situation (不参与计算)
|
||||
|
||||
def add_candidate(self, cid: str, text: str, situation: str = None):
|
||||
"""添加候选文本和对应的situation"""
|
||||
self._candidates[cid] = text
|
||||
if situation is not None:
|
||||
self._situations[cid] = situation
|
||||
|
||||
# 确保在nb模型中初始化该候选的计数
|
||||
if cid not in self.nb.cls_counts:
|
||||
self.nb.cls_counts[cid] = 0.0
|
||||
if cid not in self.nb.token_counts:
|
||||
self.nb.token_counts[cid] = defaultdict(float)
|
||||
|
||||
def add_candidates_bulk(self, items: List[Tuple[str, str]], situations: List[str] = None):
|
||||
"""批量添加候选文本和对应的situations"""
|
||||
for i, (cid, text) in enumerate(items):
|
||||
situation = situations[i] if situations and i < len(situations) else None
|
||||
self.add_candidate(cid, text, situation)
|
||||
|
||||
def predict(self, text: str, k: int = None) -> Tuple[Optional[str], Dict[str, float]]:
|
||||
"""直接对所有候选进行朴素贝叶斯评分"""
|
||||
toks = self.tokenizer.tokenize(text)
|
||||
if not toks:
|
||||
return None, {}
|
||||
|
||||
if not self._candidates:
|
||||
return None, {}
|
||||
|
||||
# 对所有候选进行评分
|
||||
tf = Counter(toks)
|
||||
all_cids = list(self._candidates.keys())
|
||||
scores = self.nb.score_batch(tf, all_cids)
|
||||
|
||||
# 取最高分
|
||||
if not scores:
|
||||
return None, {}
|
||||
|
||||
# 根据k参数限制返回的候选数量
|
||||
if k is not None and k > 0:
|
||||
# 按分数降序排序,取前k个
|
||||
sorted_scores = sorted(scores.items(), key=lambda x: x[1], reverse=True)
|
||||
limited_scores = dict(sorted_scores[:k])
|
||||
best = sorted_scores[0][0] if sorted_scores else None
|
||||
return best, limited_scores
|
||||
else:
|
||||
# 如果没有指定k,返回所有分数
|
||||
best = max(scores.items(), key=lambda x: x[1])[0]
|
||||
return best, scores
|
||||
|
||||
def update_positive(self, text: str, cid: str):
|
||||
"""更新正反馈学习"""
|
||||
toks = self.tokenizer.tokenize(text)
|
||||
if not toks:
|
||||
return
|
||||
tf = Counter(toks)
|
||||
self.nb.update_positive(tf, cid)
|
||||
|
||||
def decay(self, factor: float):
|
||||
self.nb.decay(factor=factor)
|
||||
|
||||
def get_situation(self, cid: str) -> Optional[str]:
|
||||
"""获取候选对应的situation"""
|
||||
return self._situations.get(cid)
|
||||
|
||||
def get_style(self, cid: str) -> Optional[str]:
|
||||
"""获取候选对应的style"""
|
||||
return self._candidates.get(cid)
|
||||
|
||||
def get_candidate_info(self, cid: str) -> Tuple[Optional[str], Optional[str]]:
|
||||
"""获取候选的style和situation信息"""
|
||||
return self._candidates.get(cid), self._situations.get(cid)
|
||||
|
||||
def get_all_candidates(self) -> Dict[str, Tuple[str, Optional[str]]]:
|
||||
"""获取所有候选的style和situation信息"""
|
||||
return {cid: (style, self._situations.get(cid))
|
||||
for cid, style in self._candidates.items()}
|
||||
|
||||
def save(self, path: str):
|
||||
"""保存模型"""
|
||||
os.makedirs(os.path.dirname(path), exist_ok=True)
|
||||
with open(path, "wb") as f:
|
||||
pickle.dump({
|
||||
"candidates": self._candidates,
|
||||
"situations": self._situations,
|
||||
"nb": {
|
||||
"cls_counts": dict(self.nb.cls_counts),
|
||||
"token_counts": {cid: dict(tc) for cid, tc in self.nb.token_counts.items()},
|
||||
"alpha": self.nb.alpha,
|
||||
"beta": self.nb.beta,
|
||||
"gamma": self.nb.gamma,
|
||||
"V": self.nb.V,
|
||||
}
|
||||
}, f)
|
||||
|
||||
def load(self, path: str):
|
||||
"""加载模型"""
|
||||
with open(path, "rb") as f:
|
||||
obj = pickle.load(f)
|
||||
# 还原候选文本
|
||||
self._candidates = obj["candidates"]
|
||||
# 还原situations(兼容旧版本)
|
||||
self._situations = obj.get("situations", {})
|
||||
# 还原朴素贝叶斯模型
|
||||
self.nb.cls_counts = obj["nb"]["cls_counts"]
|
||||
self.nb.token_counts = defaultdict_dict(obj["nb"]["token_counts"])
|
||||
self.nb.alpha = obj["nb"]["alpha"]
|
||||
self.nb.beta = obj["nb"]["beta"]
|
||||
self.nb.gamma = obj["nb"]["gamma"]
|
||||
self.nb.V = obj["nb"]["V"]
|
||||
self.nb._logZ.clear()
|
||||
|
||||
def defaultdict_dict(d: Dict[str, Dict[str, float]]):
|
||||
from collections import defaultdict
|
||||
outer = defaultdict(lambda: defaultdict(float))
|
||||
for k, inner in d.items():
|
||||
outer[k].update(inner)
|
||||
return outer
|
||||
@@ -1,60 +0,0 @@
|
||||
import math
|
||||
from typing import Dict, List
|
||||
from collections import defaultdict, Counter
|
||||
|
||||
class OnlineNaiveBayes:
|
||||
def __init__(self, alpha: float = 0.5, beta: float = 0.5, gamma: float = 1.0, vocab_size: int = 200000):
|
||||
self.alpha = alpha
|
||||
self.beta = beta
|
||||
self.gamma = gamma
|
||||
self.V = vocab_size
|
||||
|
||||
self.cls_counts: Dict[str, float] = defaultdict(float) # cid -> total token count
|
||||
self.token_counts: Dict[str, Dict[str, float]] = defaultdict(lambda: defaultdict(float)) # cid -> term -> count
|
||||
self._logZ: Dict[str, float] = {} # cache log(∑counts + Vα)
|
||||
|
||||
def _invalidate(self, cid: str):
|
||||
if cid in self._logZ:
|
||||
del self._logZ[cid]
|
||||
|
||||
def _logZ_c(self, cid: str) -> float:
|
||||
if cid not in self._logZ:
|
||||
Z = self.cls_counts[cid] + self.V * self.alpha
|
||||
self._logZ[cid] = math.log(max(Z, 1e-12))
|
||||
return self._logZ[cid]
|
||||
|
||||
def score_batch(self, tf: Counter, cids: List[str]) -> Dict[str, float]:
|
||||
total_cls = sum(self.cls_counts.values())
|
||||
n_cls = max(1, len(self.cls_counts))
|
||||
denom_prior = math.log(total_cls + self.beta * n_cls)
|
||||
|
||||
out: Dict[str, float] = {}
|
||||
for cid in cids:
|
||||
prior = math.log(self.cls_counts[cid] + self.beta) - denom_prior
|
||||
s = prior
|
||||
logZ = self._logZ_c(cid)
|
||||
tc = self.token_counts[cid]
|
||||
for term, qtf in tf.items():
|
||||
num = tc.get(term, 0.0) + self.alpha
|
||||
s += qtf * (math.log(num) - logZ)
|
||||
out[cid] = s
|
||||
return out
|
||||
|
||||
def update_positive(self, tf: Counter, cid: str):
|
||||
inc = 0.0
|
||||
tc = self.token_counts[cid]
|
||||
for term, c in tf.items():
|
||||
tc[term] += float(c)
|
||||
inc += float(c)
|
||||
self.cls_counts[cid] += inc
|
||||
self._invalidate(cid)
|
||||
|
||||
def decay(self, factor: float = None):
|
||||
g = self.gamma if factor is None else factor
|
||||
if g >= 1.0:
|
||||
return
|
||||
for cid in list(self.cls_counts.keys()):
|
||||
self.cls_counts[cid] *= g
|
||||
for term in list(self.token_counts[cid].keys()):
|
||||
self.token_counts[cid][term] *= g
|
||||
self._invalidate(cid)
|
||||
@@ -1,31 +0,0 @@
|
||||
import re
|
||||
from typing import List, Optional, Set
|
||||
|
||||
try:
|
||||
import jieba
|
||||
_HAS_JIEBA = True
|
||||
except Exception:
|
||||
_HAS_JIEBA = False
|
||||
|
||||
_WORD_RE = re.compile(r"[A-Za-z0-9_]+")
|
||||
# 匹配纯符号的正则表达式
|
||||
_SYMBOL_RE = re.compile(r'^[^\w\u4e00-\u9fff]+$')
|
||||
|
||||
def simple_en_tokenize(text: str) -> List[str]:
|
||||
return _WORD_RE.findall(text.lower())
|
||||
|
||||
class Tokenizer:
|
||||
def __init__(self, stopwords: Optional[Set[str]] = None, use_jieba: bool = True):
|
||||
self.stopwords = stopwords or set()
|
||||
self.use_jieba = use_jieba and _HAS_JIEBA
|
||||
|
||||
def tokenize(self, text: str) -> List[str]:
|
||||
text = (text or "").strip()
|
||||
if not text:
|
||||
return []
|
||||
if self.use_jieba:
|
||||
toks = [t.strip().lower() for t in jieba.cut(text) if t.strip()]
|
||||
else:
|
||||
toks = simple_en_tokenize(text)
|
||||
# 过滤掉纯符号和停用词
|
||||
return [t for t in toks if t not in self.stopwords and not _SYMBOL_RE.match(t)]
|
||||
@@ -1,628 +0,0 @@
|
||||
"""
|
||||
多聊天室表达风格学习系统
|
||||
支持为每个chat_id维护独立的表达模型,学习从up_content到style的映射
|
||||
"""
|
||||
|
||||
import os
|
||||
import pickle
|
||||
import traceback
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from collections import defaultdict
|
||||
import asyncio
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from .expressor_model.model import ExpressorModel
|
||||
|
||||
logger = get_logger("style_learner")
|
||||
|
||||
|
||||
class StyleLearner:
|
||||
"""
|
||||
单个聊天室的表达风格学习器
|
||||
学习从up_content到style的映射关系
|
||||
支持动态管理风格集合(最多2000个)
|
||||
"""
|
||||
|
||||
def __init__(self, chat_id: str, model_config: Optional[Dict] = None):
|
||||
self.chat_id = chat_id
|
||||
self.model_config = model_config or {
|
||||
"alpha": 0.5,
|
||||
"beta": 0.5,
|
||||
"gamma": 0.99, # 衰减因子,支持遗忘
|
||||
"vocab_size": 200000,
|
||||
"use_jieba": True
|
||||
}
|
||||
|
||||
# 初始化表达模型
|
||||
self.expressor = ExpressorModel(**self.model_config)
|
||||
|
||||
# 动态风格管理
|
||||
self.max_styles = 2000 # 每个chat_id最多2000个风格
|
||||
self.style_to_id: Dict[str, str] = {} # style文本 -> style_id
|
||||
self.id_to_style: Dict[str, str] = {} # style_id -> style文本
|
||||
self.id_to_situation: Dict[str, str] = {} # style_id -> situation文本
|
||||
self.next_style_id = 0 # 下一个可用的style_id
|
||||
|
||||
# 学习统计
|
||||
self.learning_stats = {
|
||||
"total_samples": 0,
|
||||
"style_counts": defaultdict(int),
|
||||
"last_update": None,
|
||||
"style_usage_frequency": defaultdict(int) # 风格使用频率
|
||||
}
|
||||
|
||||
def add_style(self, style: str, situation: str = None) -> bool:
|
||||
"""
|
||||
动态添加一个新的风格
|
||||
|
||||
Args:
|
||||
style: 风格文本
|
||||
situation: 对应的situation文本(可选)
|
||||
|
||||
Returns:
|
||||
bool: 添加是否成功
|
||||
"""
|
||||
try:
|
||||
# 检查是否已存在
|
||||
if style in self.style_to_id:
|
||||
logger.debug(f"[{self.chat_id}] 风格 '{style}' 已存在")
|
||||
return True
|
||||
|
||||
# 检查是否超过最大限制
|
||||
if len(self.style_to_id) >= self.max_styles:
|
||||
logger.warning(f"[{self.chat_id}] 已达到最大风格数量限制 ({self.max_styles})")
|
||||
return False
|
||||
|
||||
# 生成新的style_id
|
||||
style_id = f"style_{self.next_style_id}"
|
||||
self.next_style_id += 1
|
||||
|
||||
# 添加到映射
|
||||
self.style_to_id[style] = style_id
|
||||
self.id_to_style[style_id] = style
|
||||
if situation:
|
||||
self.id_to_situation[style_id] = situation
|
||||
|
||||
# 添加到expressor模型
|
||||
self.expressor.add_candidate(style_id, style, situation)
|
||||
|
||||
logger.info(f"[{self.chat_id}] 已添加风格: '{style}' (ID: {style_id})" +
|
||||
(f", situation: '{situation}'" if situation else ""))
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[{self.chat_id}] 添加风格失败: {e}")
|
||||
return False
|
||||
|
||||
def remove_style(self, style: str) -> bool:
|
||||
"""
|
||||
删除一个风格
|
||||
|
||||
Args:
|
||||
style: 要删除的风格文本
|
||||
|
||||
Returns:
|
||||
bool: 删除是否成功
|
||||
"""
|
||||
try:
|
||||
if style not in self.style_to_id:
|
||||
logger.warning(f"[{self.chat_id}] 风格 '{style}' 不存在")
|
||||
return False
|
||||
|
||||
style_id = self.style_to_id[style]
|
||||
|
||||
# 从映射中删除
|
||||
del self.style_to_id[style]
|
||||
del self.id_to_style[style_id]
|
||||
if style_id in self.id_to_situation:
|
||||
del self.id_to_situation[style_id]
|
||||
|
||||
# 从expressor模型中删除(通过重新构建)
|
||||
self._rebuild_expressor()
|
||||
|
||||
logger.info(f"[{self.chat_id}] 已删除风格: '{style}' (ID: {style_id})")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[{self.chat_id}] 删除风格失败: {e}")
|
||||
return False
|
||||
|
||||
def update_style(self, old_style: str, new_style: str) -> bool:
|
||||
"""
|
||||
更新一个风格
|
||||
|
||||
Args:
|
||||
old_style: 原风格文本
|
||||
new_style: 新风格文本
|
||||
|
||||
Returns:
|
||||
bool: 更新是否成功
|
||||
"""
|
||||
try:
|
||||
if old_style not in self.style_to_id:
|
||||
logger.warning(f"[{self.chat_id}] 原风格 '{old_style}' 不存在")
|
||||
return False
|
||||
|
||||
if new_style in self.style_to_id and new_style != old_style:
|
||||
logger.warning(f"[{self.chat_id}] 新风格 '{new_style}' 已存在")
|
||||
return False
|
||||
|
||||
style_id = self.style_to_id[old_style]
|
||||
|
||||
# 更新映射
|
||||
del self.style_to_id[old_style]
|
||||
self.style_to_id[new_style] = style_id
|
||||
self.id_to_style[style_id] = new_style
|
||||
|
||||
# 更新expressor模型(保留原有的situation)
|
||||
situation = self.id_to_situation.get(style_id)
|
||||
self.expressor.add_candidate(style_id, new_style, situation)
|
||||
|
||||
logger.info(f"[{self.chat_id}] 已更新风格: '{old_style}' -> '{new_style}'")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[{self.chat_id}] 更新风格失败: {e}")
|
||||
return False
|
||||
|
||||
def add_styles_batch(self, styles: List[str], situations: List[str] = None) -> int:
|
||||
"""
|
||||
批量添加风格
|
||||
|
||||
Args:
|
||||
styles: 风格文本列表
|
||||
situations: 对应的situation文本列表(可选)
|
||||
|
||||
Returns:
|
||||
int: 成功添加的数量
|
||||
"""
|
||||
success_count = 0
|
||||
for i, style in enumerate(styles):
|
||||
situation = situations[i] if situations and i < len(situations) else None
|
||||
if self.add_style(style, situation):
|
||||
success_count += 1
|
||||
|
||||
logger.info(f"[{self.chat_id}] 批量添加风格: {success_count}/{len(styles)} 成功")
|
||||
return success_count
|
||||
|
||||
def get_all_styles(self) -> List[str]:
|
||||
"""获取所有已注册的风格"""
|
||||
return list(self.style_to_id.keys())
|
||||
|
||||
def get_style_count(self) -> int:
|
||||
"""获取当前风格数量"""
|
||||
return len(self.style_to_id)
|
||||
|
||||
def get_situation(self, style: str) -> Optional[str]:
|
||||
"""
|
||||
获取风格对应的situation
|
||||
|
||||
Args:
|
||||
style: 风格文本
|
||||
|
||||
Returns:
|
||||
Optional[str]: 对应的situation,如果不存在则返回None
|
||||
"""
|
||||
if style not in self.style_to_id:
|
||||
return None
|
||||
|
||||
style_id = self.style_to_id[style]
|
||||
return self.id_to_situation.get(style_id)
|
||||
|
||||
def get_style_info(self, style: str) -> Tuple[Optional[str], Optional[str]]:
|
||||
"""
|
||||
获取风格的完整信息
|
||||
|
||||
Args:
|
||||
style: 风格文本
|
||||
|
||||
Returns:
|
||||
Tuple[Optional[str], Optional[str]]: (style_id, situation)
|
||||
"""
|
||||
if style not in self.style_to_id:
|
||||
return None, None
|
||||
|
||||
style_id = self.style_to_id[style]
|
||||
situation = self.id_to_situation.get(style_id)
|
||||
return style_id, situation
|
||||
|
||||
def get_all_style_info(self) -> Dict[str, Tuple[str, Optional[str]]]:
|
||||
"""
|
||||
获取所有风格的完整信息
|
||||
|
||||
Returns:
|
||||
Dict[str, Tuple[str, Optional[str]]]: {style: (style_id, situation)}
|
||||
"""
|
||||
result = {}
|
||||
for style, style_id in self.style_to_id.items():
|
||||
situation = self.id_to_situation.get(style_id)
|
||||
result[style] = (style_id, situation)
|
||||
return result
|
||||
|
||||
def _rebuild_expressor(self):
|
||||
"""重新构建expressor模型(删除风格后使用)"""
|
||||
try:
|
||||
# 重新创建expressor
|
||||
self.expressor = ExpressorModel(**self.model_config)
|
||||
|
||||
# 重新添加所有风格和situation
|
||||
for style_id, style_text in self.id_to_style.items():
|
||||
situation = self.id_to_situation.get(style_id)
|
||||
self.expressor.add_candidate(style_id, style_text, situation)
|
||||
|
||||
logger.debug(f"[{self.chat_id}] 已重新构建expressor模型")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[{self.chat_id}] 重新构建expressor失败: {e}")
|
||||
|
||||
def learn_mapping(self, up_content: str, style: str) -> bool:
|
||||
"""
|
||||
学习一个up_content到style的映射
|
||||
如果style不存在,会自动添加
|
||||
|
||||
Args:
|
||||
up_content: 输入内容
|
||||
style: 对应的style文本
|
||||
|
||||
Returns:
|
||||
bool: 学习是否成功
|
||||
"""
|
||||
try:
|
||||
# 如果style不存在,先添加它
|
||||
if style not in self.style_to_id:
|
||||
if not self.add_style(style):
|
||||
logger.warning(f"[{self.chat_id}] 无法添加风格 '{style}',学习失败")
|
||||
return False
|
||||
|
||||
# 获取style_id
|
||||
style_id = self.style_to_id[style]
|
||||
|
||||
# 使用正反馈学习
|
||||
self.expressor.update_positive(up_content, style_id)
|
||||
|
||||
# 更新统计
|
||||
self.learning_stats["total_samples"] += 1
|
||||
self.learning_stats["style_counts"][style_id] += 1
|
||||
self.learning_stats["style_usage_frequency"][style] += 1
|
||||
self.learning_stats["last_update"] = asyncio.get_event_loop().time()
|
||||
|
||||
logger.debug(f"[{self.chat_id}] 学习映射: '{up_content}' -> '{style}'")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[{self.chat_id}] 学习映射失败: {e}")
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
def predict_style(self, up_content: str, top_k: int = 5) -> Tuple[Optional[str], Dict[str, float]]:
|
||||
"""
|
||||
根据up_content预测最合适的style
|
||||
|
||||
Args:
|
||||
up_content: 输入内容
|
||||
top_k: 返回前k个候选
|
||||
|
||||
Returns:
|
||||
Tuple[最佳style文本, 所有候选的分数]
|
||||
"""
|
||||
try:
|
||||
best_style_id, scores = self.expressor.predict(up_content, k=top_k)
|
||||
|
||||
if best_style_id is None:
|
||||
return None, {}
|
||||
|
||||
# 将style_id转换为style文本
|
||||
best_style = self.id_to_style.get(best_style_id)
|
||||
|
||||
# 转换所有分数
|
||||
style_scores = {}
|
||||
for sid, score in scores.items():
|
||||
style_text = self.id_to_style.get(sid)
|
||||
if style_text:
|
||||
style_scores[style_text] = score
|
||||
|
||||
return best_style, style_scores
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[{self.chat_id}] 预测style失败: {e}")
|
||||
traceback.print_exc()
|
||||
return None, {}
|
||||
|
||||
def decay_learning(self, factor: Optional[float] = None) -> None:
|
||||
"""
|
||||
对学习到的知识进行衰减(遗忘)
|
||||
|
||||
Args:
|
||||
factor: 衰减因子,None则使用配置中的gamma
|
||||
"""
|
||||
self.expressor.decay(factor)
|
||||
logger.debug(f"[{self.chat_id}] 执行知识衰减")
|
||||
|
||||
def get_stats(self) -> Dict:
|
||||
"""获取学习统计信息"""
|
||||
return {
|
||||
"chat_id": self.chat_id,
|
||||
"total_samples": self.learning_stats["total_samples"],
|
||||
"style_count": len(self.style_to_id),
|
||||
"max_styles": self.max_styles,
|
||||
"style_counts": dict(self.learning_stats["style_counts"]),
|
||||
"style_usage_frequency": dict(self.learning_stats["style_usage_frequency"]),
|
||||
"last_update": self.learning_stats["last_update"],
|
||||
"all_styles": list(self.style_to_id.keys())
|
||||
}
|
||||
|
||||
def save(self, base_path: str) -> bool:
|
||||
"""
|
||||
保存模型到文件
|
||||
|
||||
Args:
|
||||
base_path: 基础路径,实际文件为 {base_path}/{chat_id}_style_model.pkl
|
||||
"""
|
||||
try:
|
||||
os.makedirs(base_path, exist_ok=True)
|
||||
file_path = os.path.join(base_path, f"{self.chat_id}_style_model.pkl")
|
||||
|
||||
# 保存模型和统计信息
|
||||
save_data = {
|
||||
"model_config": self.model_config,
|
||||
"style_to_id": self.style_to_id,
|
||||
"id_to_style": self.id_to_style,
|
||||
"id_to_situation": self.id_to_situation,
|
||||
"next_style_id": self.next_style_id,
|
||||
"max_styles": self.max_styles,
|
||||
"learning_stats": self.learning_stats
|
||||
}
|
||||
|
||||
# 先保存expressor模型
|
||||
expressor_path = os.path.join(base_path, f"{self.chat_id}_expressor.pkl")
|
||||
self.expressor.save(expressor_path)
|
||||
|
||||
# 保存其他数据
|
||||
with open(file_path, "wb") as f:
|
||||
pickle.dump(save_data, f)
|
||||
|
||||
logger.info(f"[{self.chat_id}] 模型已保存到 {file_path}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[{self.chat_id}] 保存模型失败: {e}")
|
||||
return False
|
||||
|
||||
def load(self, base_path: str) -> bool:
|
||||
"""
|
||||
从文件加载模型
|
||||
|
||||
Args:
|
||||
base_path: 基础路径
|
||||
"""
|
||||
try:
|
||||
file_path = os.path.join(base_path, f"{self.chat_id}_style_model.pkl")
|
||||
expressor_path = os.path.join(base_path, f"{self.chat_id}_expressor.pkl")
|
||||
|
||||
if not os.path.exists(file_path) or not os.path.exists(expressor_path):
|
||||
logger.warning(f"[{self.chat_id}] 模型文件不存在,将使用默认配置")
|
||||
return False
|
||||
|
||||
# 加载其他数据
|
||||
with open(file_path, "rb") as f:
|
||||
save_data = pickle.load(f)
|
||||
|
||||
# 恢复配置和状态
|
||||
self.model_config = save_data["model_config"]
|
||||
self.style_to_id = save_data["style_to_id"]
|
||||
self.id_to_style = save_data["id_to_style"]
|
||||
self.id_to_situation = save_data.get("id_to_situation", {}) # 兼容旧版本
|
||||
self.next_style_id = save_data["next_style_id"]
|
||||
self.max_styles = save_data.get("max_styles", 2000)
|
||||
self.learning_stats = save_data["learning_stats"]
|
||||
|
||||
# 重新创建expressor并加载
|
||||
self.expressor = ExpressorModel(**self.model_config)
|
||||
self.expressor.load(expressor_path)
|
||||
|
||||
logger.info(f"[{self.chat_id}] 模型已从 {file_path} 加载")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[{self.chat_id}] 加载模型失败: {e}")
|
||||
return False
|
||||
|
||||
|
||||
class StyleLearnerManager:
|
||||
"""
|
||||
多聊天室表达风格学习管理器
|
||||
为每个chat_id维护独立的StyleLearner实例
|
||||
每个chat_id可以动态管理自己的风格集合(最多2000个)
|
||||
"""
|
||||
|
||||
def __init__(self, model_save_path: str = "data/style_models"):
|
||||
self.model_save_path = model_save_path
|
||||
self.learners: Dict[str, StyleLearner] = {}
|
||||
|
||||
# 自动保存配置
|
||||
self.auto_save_interval = 300 # 5分钟
|
||||
self._auto_save_task: Optional[asyncio.Task] = None
|
||||
|
||||
logger.info("StyleLearnerManager 已初始化")
|
||||
|
||||
def get_learner(self, chat_id: str, model_config: Optional[Dict] = None) -> StyleLearner:
|
||||
"""
|
||||
获取或创建指定chat_id的学习器
|
||||
|
||||
Args:
|
||||
chat_id: 聊天室ID
|
||||
model_config: 模型配置,None则使用默认配置
|
||||
|
||||
Returns:
|
||||
StyleLearner实例
|
||||
"""
|
||||
if chat_id not in self.learners:
|
||||
# 创建新的学习器
|
||||
learner = StyleLearner(chat_id, model_config)
|
||||
|
||||
# 尝试加载已保存的模型
|
||||
learner.load(self.model_save_path)
|
||||
|
||||
self.learners[chat_id] = learner
|
||||
logger.info(f"为 chat_id={chat_id} 创建新的StyleLearner")
|
||||
|
||||
return self.learners[chat_id]
|
||||
|
||||
def add_style(self, chat_id: str, style: str) -> bool:
|
||||
"""
|
||||
为指定chat_id添加风格
|
||||
|
||||
Args:
|
||||
chat_id: 聊天室ID
|
||||
style: 风格文本
|
||||
|
||||
Returns:
|
||||
bool: 添加是否成功
|
||||
"""
|
||||
learner = self.get_learner(chat_id)
|
||||
return learner.add_style(style)
|
||||
|
||||
def remove_style(self, chat_id: str, style: str) -> bool:
|
||||
"""
|
||||
为指定chat_id删除风格
|
||||
|
||||
Args:
|
||||
chat_id: 聊天室ID
|
||||
style: 风格文本
|
||||
|
||||
Returns:
|
||||
bool: 删除是否成功
|
||||
"""
|
||||
learner = self.get_learner(chat_id)
|
||||
return learner.remove_style(style)
|
||||
|
||||
def update_style(self, chat_id: str, old_style: str, new_style: str) -> bool:
|
||||
"""
|
||||
为指定chat_id更新风格
|
||||
|
||||
Args:
|
||||
chat_id: 聊天室ID
|
||||
old_style: 原风格文本
|
||||
new_style: 新风格文本
|
||||
|
||||
Returns:
|
||||
bool: 更新是否成功
|
||||
"""
|
||||
learner = self.get_learner(chat_id)
|
||||
return learner.update_style(old_style, new_style)
|
||||
|
||||
def get_chat_styles(self, chat_id: str) -> List[str]:
|
||||
"""
|
||||
获取指定chat_id的所有风格
|
||||
|
||||
Args:
|
||||
chat_id: 聊天室ID
|
||||
|
||||
Returns:
|
||||
List[str]: 风格列表
|
||||
"""
|
||||
learner = self.get_learner(chat_id)
|
||||
return learner.get_all_styles()
|
||||
|
||||
def learn_mapping(self, chat_id: str, up_content: str, style: str) -> bool:
|
||||
"""
|
||||
学习一个映射关系
|
||||
|
||||
Args:
|
||||
chat_id: 聊天室ID
|
||||
up_content: 输入内容
|
||||
style: 对应的style
|
||||
|
||||
Returns:
|
||||
bool: 学习是否成功
|
||||
"""
|
||||
learner = self.get_learner(chat_id)
|
||||
return learner.learn_mapping(up_content, style)
|
||||
|
||||
def predict_style(self, chat_id: str, up_content: str, top_k: int = 5) -> Tuple[Optional[str], Dict[str, float]]:
|
||||
"""
|
||||
预测最合适的style
|
||||
|
||||
Args:
|
||||
chat_id: 聊天室ID
|
||||
up_content: 输入内容
|
||||
top_k: 返回前k个候选
|
||||
|
||||
Returns:
|
||||
Tuple[最佳style, 所有候选分数]
|
||||
"""
|
||||
learner = self.get_learner(chat_id)
|
||||
return learner.predict_style(up_content, top_k)
|
||||
|
||||
def decay_all_learners(self, factor: Optional[float] = None) -> None:
|
||||
"""
|
||||
对所有学习器执行衰减
|
||||
|
||||
Args:
|
||||
factor: 衰减因子
|
||||
"""
|
||||
for learner in self.learners.values():
|
||||
learner.decay_learning(factor)
|
||||
logger.info("已对所有学习器执行衰减")
|
||||
|
||||
def get_all_stats(self) -> Dict[str, Dict]:
|
||||
"""获取所有学习器的统计信息"""
|
||||
return {chat_id: learner.get_stats() for chat_id, learner in self.learners.items()}
|
||||
|
||||
def save_all_models(self) -> bool:
|
||||
"""保存所有模型"""
|
||||
success_count = 0
|
||||
for learner in self.learners.values():
|
||||
if learner.save(self.model_save_path):
|
||||
success_count += 1
|
||||
|
||||
logger.info(f"已保存 {success_count}/{len(self.learners)} 个模型")
|
||||
return success_count == len(self.learners)
|
||||
|
||||
def load_all_models(self) -> int:
|
||||
"""加载所有已保存的模型"""
|
||||
if not os.path.exists(self.model_save_path):
|
||||
return 0
|
||||
|
||||
loaded_count = 0
|
||||
for filename in os.listdir(self.model_save_path):
|
||||
if filename.endswith("_style_model.pkl"):
|
||||
chat_id = filename.replace("_style_model.pkl", "")
|
||||
learner = StyleLearner(chat_id)
|
||||
if learner.load(self.model_save_path):
|
||||
self.learners[chat_id] = learner
|
||||
loaded_count += 1
|
||||
|
||||
logger.info(f"已加载 {loaded_count} 个模型")
|
||||
return loaded_count
|
||||
|
||||
async def start_auto_save(self) -> None:
|
||||
"""启动自动保存任务"""
|
||||
if self._auto_save_task is None or self._auto_save_task.done():
|
||||
self._auto_save_task = asyncio.create_task(self._auto_save_loop())
|
||||
logger.info("已启动自动保存任务")
|
||||
|
||||
async def stop_auto_save(self) -> None:
|
||||
"""停止自动保存任务"""
|
||||
if self._auto_save_task and not self._auto_save_task.done():
|
||||
self._auto_save_task.cancel()
|
||||
try:
|
||||
await self._auto_save_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
logger.info("已停止自动保存任务")
|
||||
|
||||
async def _auto_save_loop(self) -> None:
|
||||
"""自动保存循环"""
|
||||
while True:
|
||||
try:
|
||||
await asyncio.sleep(self.auto_save_interval)
|
||||
self.save_all_models()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"自动保存失败: {e}")
|
||||
|
||||
|
||||
# 全局管理器实例
|
||||
style_learner_manager = StyleLearnerManager()
|
||||
5
src/jargon/__init__.py
Normal file
5
src/jargon/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from .jargon_miner import extract_and_store_jargon
|
||||
|
||||
__all__ = [
|
||||
"extract_and_store_jargon",
|
||||
]
|
||||
861
src/jargon/jargon_miner.py
Normal file
861
src/jargon/jargon_miner.py
Normal file
@@ -0,0 +1,861 @@
|
||||
import time
|
||||
import json
|
||||
import asyncio
|
||||
from typing import List, Dict, Optional, Any
|
||||
from json_repair import repair_json
|
||||
from peewee import fn
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import Jargon
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config, global_config
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.plugin_system.apis import llm_api
|
||||
from src.chat.utils.chat_message_builder import (
|
||||
build_anonymous_messages,
|
||||
get_raw_msg_by_timestamp_with_chat_inclusive,
|
||||
get_raw_msg_before_timestamp_with_chat,
|
||||
build_readable_messages_with_list,
|
||||
)
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
|
||||
|
||||
logger = get_logger("jargon")
|
||||
|
||||
|
||||
def _init_prompt() -> None:
|
||||
prompt_str = """
|
||||
**聊天内容,其中的SELF是你自己的发言**
|
||||
{chat_str}
|
||||
|
||||
请从上面这段聊天内容中提取"可能是黑话"的候选项(黑话/俚语/网络缩写/口头禅)。
|
||||
- 必须为对话中真实出现过的短词或短语
|
||||
- 必须是你无法理解含义的词语,没有明确含义的词语
|
||||
- 请不要选择有明确含义,或者含义清晰的词语
|
||||
- 排除:人名、@、表情包/图片中的内容、纯标点、常规功能词(如的、了、呢、啊等)
|
||||
- 每个词条长度建议 2-8 个字符(不强制),尽量短小
|
||||
- 合并重复项,去重
|
||||
|
||||
黑话必须为以下几种类型:
|
||||
- 由字母构成的,汉语拼音首字母的简写词,例如:nb、yyds、xswl
|
||||
- 英文词语的缩写,用英文字母概括一个词汇或含义,例如:CPU、GPU、API
|
||||
- 中文词语的缩写,用几个汉字概括一个词汇或含义,例如:社死、内卷
|
||||
|
||||
以 JSON 数组输出,元素为对象(严格按以下结构):
|
||||
[
|
||||
{{"content": "词条", "raw_content": "包含该词条的完整对话上下文原文"}},
|
||||
{{"content": "词条2", "raw_content": "包含该词条的完整对话上下文原文"}}
|
||||
]
|
||||
|
||||
现在请输出:
|
||||
"""
|
||||
Prompt(prompt_str, "extract_jargon_prompt")
|
||||
|
||||
|
||||
def _init_inference_prompts() -> None:
|
||||
"""初始化含义推断相关的prompt"""
|
||||
# Prompt 1: 基于raw_content和content推断
|
||||
prompt1_str = """
|
||||
**词条内容**
|
||||
{content}
|
||||
**词条出现的上下文(raw_content)其中的SELF是你自己的发言**
|
||||
{raw_content_list}
|
||||
|
||||
请根据以上词条内容和上下文,推断这个词条的含义。
|
||||
- 如果这是一个黑话、俚语或网络用语,请推断其含义
|
||||
- 如果含义明确(常规词汇),也请说明
|
||||
- 如果上下文信息不足,无法推断含义,请设置 no_info 为 true
|
||||
|
||||
以 JSON 格式输出:
|
||||
{{
|
||||
"meaning": "详细含义说明(包含使用场景、来源、具体解释等)",
|
||||
"no_info": false
|
||||
}}
|
||||
注意:如果信息不足无法推断,请设置 "no_info": true,此时 meaning 可以为空字符串
|
||||
"""
|
||||
Prompt(prompt1_str, "jargon_inference_with_context_prompt")
|
||||
|
||||
# Prompt 2: 仅基于content推断
|
||||
prompt2_str = """
|
||||
**词条内容**
|
||||
{content}
|
||||
|
||||
请仅根据这个词条本身,推断其含义。
|
||||
- 如果这是一个黑话、俚语或网络用语,请推断其含义
|
||||
- 如果含义明确(常规词汇),也请说明
|
||||
|
||||
以 JSON 格式输出:
|
||||
{{
|
||||
"meaning": "详细含义说明(包含使用场景、来源、具体解释等)"
|
||||
}}
|
||||
"""
|
||||
Prompt(prompt2_str, "jargon_inference_content_only_prompt")
|
||||
|
||||
# Prompt 3: 比较两个推断结果
|
||||
prompt3_str = """
|
||||
**推断结果1(基于上下文)**
|
||||
{inference1}
|
||||
|
||||
**推断结果2(仅基于词条)**
|
||||
{inference2}
|
||||
|
||||
请比较这两个推断结果,判断它们是否相同或类似。
|
||||
- 如果两个推断结果的"含义"相同或类似,说明这个词条不是黑话(含义明确)
|
||||
- 如果两个推断结果有差异,说明这个词条可能是黑话(需要上下文才能理解)
|
||||
|
||||
以 JSON 格式输出:
|
||||
{{
|
||||
"is_similar": true/false,
|
||||
"reason": "判断理由"
|
||||
}}
|
||||
"""
|
||||
Prompt(prompt3_str, "jargon_compare_inference_prompt")
|
||||
|
||||
|
||||
_init_prompt()
|
||||
_init_inference_prompts()
|
||||
|
||||
|
||||
async def _enrich_raw_content_if_needed(
|
||||
content: str,
|
||||
raw_content_list: List[str],
|
||||
chat_id: str,
|
||||
messages: List[Any],
|
||||
extraction_start_time: float,
|
||||
extraction_end_time: float,
|
||||
) -> List[str]:
|
||||
"""
|
||||
检查raw_content是否只包含黑话本身,如果是,则获取该消息的前三条消息作为原始内容
|
||||
|
||||
Args:
|
||||
content: 黑话内容
|
||||
raw_content_list: 原始raw_content列表
|
||||
chat_id: 聊天ID
|
||||
messages: 当前时间窗口内的消息列表
|
||||
extraction_start_time: 提取开始时间
|
||||
extraction_end_time: 提取结束时间
|
||||
|
||||
Returns:
|
||||
处理后的raw_content列表
|
||||
"""
|
||||
enriched_list = []
|
||||
|
||||
for raw_content in raw_content_list:
|
||||
# 检查raw_content是否只包含黑话本身(去除空白字符后比较)
|
||||
raw_content_clean = raw_content.strip()
|
||||
content_clean = content.strip()
|
||||
|
||||
# 如果raw_content只包含黑话本身(可能有一些标点或空白),则尝试获取上下文
|
||||
# 去除所有空白字符后比较,确保只包含黑话本身
|
||||
raw_content_normalized = raw_content_clean.replace(" ", "").replace("\n", "").replace("\t", "")
|
||||
content_normalized = content_clean.replace(" ", "").replace("\n", "").replace("\t", "")
|
||||
|
||||
if raw_content_normalized == content_normalized:
|
||||
# 在消息列表中查找只包含该黑话的消息(去除空白后比较)
|
||||
target_message = None
|
||||
for msg in messages:
|
||||
msg_content = (msg.processed_plain_text or msg.display_message or "").strip()
|
||||
msg_content_normalized = msg_content.replace(" ", "").replace("\n", "").replace("\t", "")
|
||||
# 检查消息内容是否只包含黑话本身(去除空白后完全匹配)
|
||||
if msg_content_normalized == content_normalized:
|
||||
target_message = msg
|
||||
break
|
||||
|
||||
if target_message and target_message.time:
|
||||
# 获取该消息的前三条消息
|
||||
try:
|
||||
previous_messages = get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=chat_id,
|
||||
timestamp=target_message.time,
|
||||
limit=3
|
||||
)
|
||||
|
||||
if previous_messages:
|
||||
# 将前三条消息和当前消息一起格式化
|
||||
context_messages = previous_messages + [target_message]
|
||||
# 按时间排序
|
||||
context_messages.sort(key=lambda x: x.time or 0)
|
||||
|
||||
# 格式化为可读消息
|
||||
formatted_context, _ = await build_readable_messages_with_list(
|
||||
context_messages,
|
||||
replace_bot_name=True,
|
||||
timestamp_mode="relative",
|
||||
truncate=False,
|
||||
)
|
||||
|
||||
if formatted_context.strip():
|
||||
enriched_list.append(formatted_context.strip())
|
||||
logger.warning(f"为黑话 {content} 补充了上下文消息")
|
||||
else:
|
||||
# 如果格式化失败,使用原始raw_content
|
||||
enriched_list.append(raw_content)
|
||||
else:
|
||||
# 没有找到前三条消息,使用原始raw_content
|
||||
enriched_list.append(raw_content)
|
||||
except Exception as e:
|
||||
logger.warning(f"获取黑话 {content} 的上下文消息失败: {e}")
|
||||
# 出错时使用原始raw_content
|
||||
enriched_list.append(raw_content)
|
||||
else:
|
||||
# 没有找到包含黑话的消息,使用原始raw_content
|
||||
enriched_list.append(raw_content)
|
||||
else:
|
||||
# raw_content包含更多内容,直接使用
|
||||
enriched_list.append(raw_content)
|
||||
|
||||
return enriched_list
|
||||
|
||||
|
||||
def _should_infer_meaning(jargon_obj: Jargon) -> bool:
|
||||
"""
|
||||
判断是否需要进行含义推断
|
||||
在 count 达到 3,6, 10, 20, 40, 60, 100 时进行推断
|
||||
并且count必须大于last_inference_count,避免重启后重复判定
|
||||
如果is_complete为True,不再进行推断
|
||||
"""
|
||||
# 如果已完成所有推断,不再推断
|
||||
if jargon_obj.is_complete:
|
||||
return False
|
||||
|
||||
count = jargon_obj.count or 0
|
||||
last_inference = jargon_obj.last_inference_count or 0
|
||||
|
||||
# 阈值列表:3,6, 10, 20, 40, 60, 100
|
||||
thresholds = [3,6, 10, 20, 40, 60, 100]
|
||||
|
||||
if count < thresholds[0]:
|
||||
return False
|
||||
|
||||
# 如果count没有超过上次判定值,不需要判定
|
||||
if count <= last_inference:
|
||||
return False
|
||||
|
||||
# 找到第一个大于last_inference的阈值
|
||||
next_threshold = None
|
||||
for threshold in thresholds:
|
||||
if threshold > last_inference:
|
||||
next_threshold = threshold
|
||||
break
|
||||
|
||||
# 如果没有找到下一个阈值,说明已经超过100,不应该再推断
|
||||
if next_threshold is None:
|
||||
return False
|
||||
|
||||
# 检查count是否达到或超过这个阈值
|
||||
return count >= next_threshold
|
||||
|
||||
|
||||
class JargonMiner:
|
||||
def __init__(self, chat_id: str) -> None:
|
||||
self.chat_id = chat_id
|
||||
self.last_learning_time: float = time.time()
|
||||
# 频率控制,可按需调整
|
||||
self.min_messages_for_learning: int = 15
|
||||
self.min_learning_interval: float = 20
|
||||
|
||||
self.llm = LLMRequest(
|
||||
model_set=model_config.model_task_config.utils,
|
||||
request_type="jargon.extract",
|
||||
)
|
||||
|
||||
# 初始化stream_name作为类属性,避免重复提取
|
||||
chat_manager = get_chat_manager()
|
||||
stream_name = chat_manager.get_stream_name(self.chat_id)
|
||||
self.stream_name = stream_name if stream_name else self.chat_id
|
||||
|
||||
async def _infer_meaning_by_id(self, jargon_id: int) -> None:
|
||||
"""通过ID加载对象并推断"""
|
||||
try:
|
||||
jargon_obj = Jargon.get_by_id(jargon_id)
|
||||
# 再次检查is_complete,因为可能在异步任务执行时已被标记为完成
|
||||
if jargon_obj.is_complete:
|
||||
logger.debug(f"jargon {jargon_obj.content} 已完成所有推断,跳过")
|
||||
return
|
||||
await self.infer_meaning(jargon_obj)
|
||||
except Exception as e:
|
||||
logger.error(f"通过ID推断jargon失败: {e}")
|
||||
|
||||
async def infer_meaning(self, jargon_obj: Jargon) -> None:
|
||||
"""
|
||||
对jargon进行含义推断
|
||||
"""
|
||||
try:
|
||||
content = jargon_obj.content
|
||||
raw_content_str = jargon_obj.raw_content or ""
|
||||
|
||||
# 解析raw_content列表
|
||||
raw_content_list = []
|
||||
if raw_content_str:
|
||||
try:
|
||||
raw_content_list = json.loads(raw_content_str) if isinstance(raw_content_str, str) else raw_content_str
|
||||
if not isinstance(raw_content_list, list):
|
||||
raw_content_list = [raw_content_list] if raw_content_list else []
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
raw_content_list = [raw_content_str] if raw_content_str else []
|
||||
|
||||
if not raw_content_list:
|
||||
logger.warning(f"jargon {content} 没有raw_content,跳过推断")
|
||||
return
|
||||
|
||||
# 步骤1: 基于raw_content和content推断
|
||||
raw_content_text = "\n".join(raw_content_list)
|
||||
prompt1 = await global_prompt_manager.format_prompt(
|
||||
"jargon_inference_with_context_prompt",
|
||||
content=content,
|
||||
raw_content_list=raw_content_text,
|
||||
)
|
||||
|
||||
response1, _ = await self.llm.generate_response_async(prompt1, temperature=0.3)
|
||||
if not response1:
|
||||
logger.warning(f"jargon {content} 推断1失败:无响应")
|
||||
return
|
||||
|
||||
# 解析推断1结果
|
||||
inference1 = None
|
||||
try:
|
||||
resp1 = response1.strip()
|
||||
if resp1.startswith("{") and resp1.endswith("}"):
|
||||
inference1 = json.loads(resp1)
|
||||
else:
|
||||
repaired = repair_json(resp1)
|
||||
inference1 = json.loads(repaired) if isinstance(repaired, str) else repaired
|
||||
if not isinstance(inference1, dict):
|
||||
logger.warning(f"jargon {content} 推断1结果格式错误")
|
||||
return
|
||||
except Exception as e:
|
||||
logger.error(f"jargon {content} 推断1解析失败: {e}")
|
||||
return
|
||||
|
||||
# 检查推断1是否表示信息不足无法推断
|
||||
no_info = inference1.get("no_info", False)
|
||||
meaning1 = inference1.get("meaning", "").strip()
|
||||
if no_info or not meaning1:
|
||||
logger.info(f"jargon {content} 推断1表示信息不足无法推断,放弃本次推断,待下次更新")
|
||||
# 更新最后一次判定的count值,避免在同一阈值重复尝试
|
||||
jargon_obj.last_inference_count = jargon_obj.count or 0
|
||||
jargon_obj.save()
|
||||
return
|
||||
|
||||
|
||||
# 步骤2: 仅基于content推断
|
||||
prompt2 = await global_prompt_manager.format_prompt(
|
||||
"jargon_inference_content_only_prompt",
|
||||
content=content,
|
||||
)
|
||||
|
||||
response2, _ = await self.llm.generate_response_async(prompt2, temperature=0.3)
|
||||
if not response2:
|
||||
logger.warning(f"jargon {content} 推断2失败:无响应")
|
||||
return
|
||||
|
||||
# 解析推断2结果
|
||||
inference2 = None
|
||||
try:
|
||||
resp2 = response2.strip()
|
||||
if resp2.startswith("{") and resp2.endswith("}"):
|
||||
inference2 = json.loads(resp2)
|
||||
else:
|
||||
repaired = repair_json(resp2)
|
||||
inference2 = json.loads(repaired) if isinstance(repaired, str) else repaired
|
||||
if not isinstance(inference2, dict):
|
||||
logger.warning(f"jargon {content} 推断2结果格式错误")
|
||||
return
|
||||
except Exception as e:
|
||||
logger.error(f"jargon {content} 推断2解析失败: {e}")
|
||||
return
|
||||
|
||||
|
||||
logger.info(f"jargon {content} 推断2提示词: {prompt2}")
|
||||
logger.info(f"jargon {content} 推断2结果: {response2}")
|
||||
logger.info(f"jargon {content} 推断1提示词: {prompt1}")
|
||||
logger.info(f"jargon {content} 推断1结果: {response1}")
|
||||
|
||||
if global_config.debug.show_jargon_prompt:
|
||||
logger.info(f"jargon {content} 推断2提示词: {prompt2}")
|
||||
logger.info(f"jargon {content} 推断2结果: {response2}")
|
||||
logger.info(f"jargon {content} 推断1提示词: {prompt1}")
|
||||
logger.info(f"jargon {content} 推断1结果: {response1}")
|
||||
else:
|
||||
logger.debug(f"jargon {content} 推断2提示词: {prompt2}")
|
||||
logger.debug(f"jargon {content} 推断2结果: {response2}")
|
||||
logger.debug(f"jargon {content} 推断1提示词: {prompt1}")
|
||||
logger.debug(f"jargon {content} 推断1结果: {response1}")
|
||||
|
||||
# 步骤3: 比较两个推断结果
|
||||
prompt3 = await global_prompt_manager.format_prompt(
|
||||
"jargon_compare_inference_prompt",
|
||||
inference1=json.dumps(inference1, ensure_ascii=False),
|
||||
inference2=json.dumps(inference2, ensure_ascii=False),
|
||||
)
|
||||
|
||||
if global_config.debug.show_jargon_prompt:
|
||||
logger.info(f"jargon {content} 比较提示词: {prompt3}")
|
||||
|
||||
response3, _ = await self.llm.generate_response_async(prompt3, temperature=0.3)
|
||||
if not response3:
|
||||
logger.warning(f"jargon {content} 比较失败:无响应")
|
||||
return
|
||||
|
||||
# 解析比较结果
|
||||
comparison = None
|
||||
try:
|
||||
resp3 = response3.strip()
|
||||
if resp3.startswith("{") and resp3.endswith("}"):
|
||||
comparison = json.loads(resp3)
|
||||
else:
|
||||
repaired = repair_json(resp3)
|
||||
comparison = json.loads(repaired) if isinstance(repaired, str) else repaired
|
||||
if not isinstance(comparison, dict):
|
||||
logger.warning(f"jargon {content} 比较结果格式错误")
|
||||
return
|
||||
except Exception as e:
|
||||
logger.error(f"jargon {content} 比较解析失败: {e}")
|
||||
return
|
||||
|
||||
# 判断是否为黑话
|
||||
is_similar = comparison.get("is_similar", False)
|
||||
is_jargon = not is_similar # 如果相似,说明不是黑话;如果有差异,说明是黑话
|
||||
|
||||
# 更新数据库记录
|
||||
jargon_obj.is_jargon = is_jargon
|
||||
if is_jargon:
|
||||
# 是黑话,使用推断1的结果(基于上下文,更准确)
|
||||
jargon_obj.meaning = inference1.get("meaning", "")
|
||||
else:
|
||||
# 不是黑话,也记录含义(使用推断2的结果,因为含义明确)
|
||||
jargon_obj.meaning = inference2.get("meaning", "")
|
||||
|
||||
# 更新最后一次判定的count值,避免重启后重复判定
|
||||
jargon_obj.last_inference_count = jargon_obj.count or 0
|
||||
|
||||
# 如果count>=100,标记为完成,不再进行推断
|
||||
if (jargon_obj.count or 0) >= 100:
|
||||
jargon_obj.is_complete = True
|
||||
|
||||
jargon_obj.save()
|
||||
logger.info(f"jargon {content} 推断完成: is_jargon={is_jargon}, meaning={jargon_obj.meaning}, last_inference_count={jargon_obj.last_inference_count}, is_complete={jargon_obj.is_complete}")
|
||||
|
||||
# 固定输出推断结果,格式化为可读形式
|
||||
if is_jargon:
|
||||
# 是黑话,输出格式:[聊天名]xxx的含义是 xxxxxxxxxxx
|
||||
meaning = jargon_obj.meaning or "无详细说明"
|
||||
is_global = jargon_obj.is_global
|
||||
if is_global:
|
||||
logger.info(f"[通用黑话]{content}的含义是 {meaning}")
|
||||
else:
|
||||
logger.info(f"[{self.stream_name}]{content}的含义是 {meaning}")
|
||||
else:
|
||||
# 不是黑话,输出格式:[聊天名]xxx 不是黑话
|
||||
logger.info(f"[{self.stream_name}]{content} 不是黑话")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"jargon推断失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
def should_trigger(self) -> bool:
|
||||
# 冷却时间检查
|
||||
if time.time() - self.last_learning_time < self.min_learning_interval:
|
||||
return False
|
||||
|
||||
# 拉取最近消息数量是否足够
|
||||
recent_messages = get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
chat_id=self.chat_id,
|
||||
timestamp_start=self.last_learning_time,
|
||||
timestamp_end=time.time(),
|
||||
)
|
||||
return bool(recent_messages and len(recent_messages) >= self.min_messages_for_learning)
|
||||
|
||||
async def run_once(self) -> None:
|
||||
try:
|
||||
if not self.should_trigger():
|
||||
return
|
||||
|
||||
chat_stream = get_chat_manager().get_stream(self.chat_id)
|
||||
if not chat_stream:
|
||||
return
|
||||
|
||||
# 记录本次提取的时间窗口,避免重复提取
|
||||
extraction_start_time = self.last_learning_time
|
||||
extraction_end_time = time.time()
|
||||
|
||||
# 拉取学习窗口内的消息
|
||||
messages = get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
chat_id=self.chat_id,
|
||||
timestamp_start=extraction_start_time,
|
||||
timestamp_end=extraction_end_time,
|
||||
limit=20,
|
||||
)
|
||||
if not messages:
|
||||
return
|
||||
|
||||
chat_str: str = await build_anonymous_messages(messages)
|
||||
if not chat_str.strip():
|
||||
return
|
||||
|
||||
prompt: str = await global_prompt_manager.format_prompt(
|
||||
"extract_jargon_prompt",
|
||||
chat_str=chat_str,
|
||||
)
|
||||
|
||||
response, _ = await self.llm.generate_response_async(prompt, temperature=0.2)
|
||||
if not response:
|
||||
return
|
||||
|
||||
if global_config.debug.show_jargon_prompt:
|
||||
logger.info(f"jargon提取提示词: {prompt}")
|
||||
logger.info(f"jargon提取结果: {response}")
|
||||
|
||||
# 解析为JSON
|
||||
entries: List[dict] = []
|
||||
try:
|
||||
resp = response.strip()
|
||||
parsed = None
|
||||
if resp.startswith("[") and resp.endswith("]"):
|
||||
parsed = json.loads(resp)
|
||||
else:
|
||||
repaired = repair_json(resp)
|
||||
if isinstance(repaired, str):
|
||||
parsed = json.loads(repaired)
|
||||
else:
|
||||
parsed = repaired
|
||||
|
||||
if isinstance(parsed, dict):
|
||||
parsed = [parsed]
|
||||
|
||||
if not isinstance(parsed, list):
|
||||
return
|
||||
|
||||
for item in parsed:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
content = str(item.get("content", "")).strip()
|
||||
raw_content_value = item.get("raw_content", "")
|
||||
|
||||
# 处理raw_content:可能是字符串或列表
|
||||
raw_content_list = []
|
||||
if isinstance(raw_content_value, list):
|
||||
raw_content_list = [str(rc).strip() for rc in raw_content_value if str(rc).strip()]
|
||||
# 去重
|
||||
raw_content_list = list(dict.fromkeys(raw_content_list))
|
||||
elif isinstance(raw_content_value, str):
|
||||
raw_content_str = raw_content_value.strip()
|
||||
if raw_content_str:
|
||||
raw_content_list = [raw_content_str]
|
||||
|
||||
if content and raw_content_list:
|
||||
entries.append({
|
||||
"content": content,
|
||||
"raw_content": raw_content_list
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"解析jargon JSON失败: {e}; 原始: {response}")
|
||||
return
|
||||
|
||||
if not entries:
|
||||
return
|
||||
|
||||
# 去重并写入DB(按 chat_id + content 去重)
|
||||
# 使用content作为去重键
|
||||
seen = set()
|
||||
uniq_entries = []
|
||||
for entry in entries:
|
||||
content_key = entry["content"]
|
||||
if content_key not in seen:
|
||||
seen.add(content_key)
|
||||
uniq_entries.append(entry)
|
||||
|
||||
saved = 0
|
||||
updated = 0
|
||||
for entry in uniq_entries:
|
||||
content = entry["content"]
|
||||
raw_content_list = entry["raw_content"] # 已经是列表
|
||||
|
||||
# 检查并补充raw_content:如果只包含黑话本身,则获取前三条消息作为上下文
|
||||
raw_content_list = await _enrich_raw_content_if_needed(
|
||||
content=content,
|
||||
raw_content_list=raw_content_list,
|
||||
chat_id=self.chat_id,
|
||||
messages=messages,
|
||||
extraction_start_time=extraction_start_time,
|
||||
extraction_end_time=extraction_end_time,
|
||||
)
|
||||
|
||||
try:
|
||||
# 根据all_global配置决定查询逻辑
|
||||
if global_config.jargon.all_global:
|
||||
# 开启all_global:无视chat_id,查询所有content匹配的记录(所有记录都是全局的)
|
||||
query = (
|
||||
Jargon.select()
|
||||
.where(Jargon.content == content)
|
||||
)
|
||||
else:
|
||||
# 关闭all_global:只查询chat_id匹配的记录(不考虑is_global)
|
||||
query = (
|
||||
Jargon.select()
|
||||
.where(
|
||||
(Jargon.chat_id == self.chat_id) &
|
||||
(Jargon.content == content)
|
||||
)
|
||||
)
|
||||
|
||||
if query.exists():
|
||||
obj = query.get()
|
||||
try:
|
||||
obj.count = (obj.count or 0) + 1
|
||||
except Exception:
|
||||
obj.count = 1
|
||||
|
||||
# 合并raw_content列表:读取现有列表,追加新值,去重
|
||||
existing_raw_content = []
|
||||
if obj.raw_content:
|
||||
try:
|
||||
existing_raw_content = json.loads(obj.raw_content) if isinstance(obj.raw_content, str) else obj.raw_content
|
||||
if not isinstance(existing_raw_content, list):
|
||||
existing_raw_content = [existing_raw_content] if existing_raw_content else []
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
existing_raw_content = [obj.raw_content] if obj.raw_content else []
|
||||
|
||||
# 合并并去重
|
||||
merged_list = list(dict.fromkeys(existing_raw_content + raw_content_list))
|
||||
obj.raw_content = json.dumps(merged_list, ensure_ascii=False)
|
||||
|
||||
# 开启all_global时,确保记录标记为is_global=True
|
||||
if global_config.jargon.all_global:
|
||||
obj.is_global = True
|
||||
# 关闭all_global时,保持原有is_global不变(不修改)
|
||||
|
||||
obj.save()
|
||||
|
||||
# 检查是否需要推断(达到阈值且超过上次判定值)
|
||||
if _should_infer_meaning(obj):
|
||||
# 异步触发推断,不阻塞主流程
|
||||
# 重新加载对象以确保数据最新
|
||||
jargon_id = obj.id
|
||||
asyncio.create_task(self._infer_meaning_by_id(jargon_id))
|
||||
|
||||
updated += 1
|
||||
else:
|
||||
# 没找到匹配记录,创建新记录
|
||||
if global_config.jargon.all_global:
|
||||
# 开启all_global:新记录默认为is_global=True
|
||||
is_global_new = True
|
||||
else:
|
||||
# 关闭all_global:新记录is_global=False
|
||||
is_global_new = False
|
||||
|
||||
Jargon.create(
|
||||
content=content,
|
||||
raw_content=json.dumps(raw_content_list, ensure_ascii=False),
|
||||
chat_id=self.chat_id,
|
||||
is_global=is_global_new,
|
||||
count=1
|
||||
)
|
||||
saved += 1
|
||||
except Exception as e:
|
||||
logger.error(f"保存jargon失败: chat_id={self.chat_id}, content={content}, err={e}")
|
||||
continue
|
||||
|
||||
# 固定输出提取的jargon结果,格式化为可读形式(只要有提取结果就输出)
|
||||
if uniq_entries:
|
||||
# 收集所有提取的jargon内容
|
||||
jargon_list = [entry["content"] for entry in uniq_entries]
|
||||
jargon_str = ",".join(jargon_list)
|
||||
|
||||
# 输出格式化的结果(使用logger.info会自动应用jargon模块的颜色)
|
||||
logger.info(f"[{self.stream_name}]疑似黑话: {jargon_str}")
|
||||
|
||||
# 更新为本次提取的结束时间,确保不会重复提取相同的消息窗口
|
||||
self.last_learning_time = extraction_end_time
|
||||
|
||||
if saved or updated:
|
||||
logger.info(f"jargon写入: 新增 {saved} 条,更新 {updated} 条,chat_id={self.chat_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"JargonMiner 运行失败: {e}")
|
||||
|
||||
|
||||
class JargonMinerManager:
|
||||
def __init__(self) -> None:
|
||||
self._miners: dict[str, JargonMiner] = {}
|
||||
|
||||
def get_miner(self, chat_id: str) -> JargonMiner:
|
||||
if chat_id not in self._miners:
|
||||
self._miners[chat_id] = JargonMiner(chat_id)
|
||||
return self._miners[chat_id]
|
||||
|
||||
|
||||
miner_manager = JargonMinerManager()
|
||||
|
||||
|
||||
async def extract_and_store_jargon(chat_id: str) -> None:
|
||||
miner = miner_manager.get_miner(chat_id)
|
||||
await miner.run_once()
|
||||
|
||||
|
||||
def search_jargon(
|
||||
keyword: str,
|
||||
chat_id: Optional[str] = None,
|
||||
limit: int = 10,
|
||||
case_sensitive: bool = False,
|
||||
fuzzy: bool = True
|
||||
) -> List[Dict[str, str]]:
|
||||
"""
|
||||
搜索jargon,支持大小写不敏感和模糊搜索
|
||||
|
||||
Args:
|
||||
keyword: 搜索关键词
|
||||
chat_id: 可选的聊天ID
|
||||
- 如果开启了all_global:此参数被忽略,查询所有is_global=True的记录
|
||||
- 如果关闭了all_global:如果提供则优先搜索该聊天或global的jargon
|
||||
limit: 返回结果数量限制,默认10
|
||||
case_sensitive: 是否大小写敏感,默认False(不敏感)
|
||||
fuzzy: 是否模糊搜索,默认True(使用LIKE匹配)
|
||||
|
||||
Returns:
|
||||
List[Dict[str, str]]: 包含content, meaning的字典列表
|
||||
"""
|
||||
if not keyword or not keyword.strip():
|
||||
return []
|
||||
|
||||
keyword = keyword.strip()
|
||||
|
||||
# 构建查询
|
||||
query = Jargon.select(
|
||||
Jargon.content,
|
||||
Jargon.meaning
|
||||
)
|
||||
|
||||
# 构建搜索条件
|
||||
if case_sensitive:
|
||||
# 大小写敏感
|
||||
if fuzzy:
|
||||
# 模糊搜索
|
||||
search_condition = Jargon.content.contains(keyword)
|
||||
else:
|
||||
# 精确匹配
|
||||
search_condition = (Jargon.content == keyword)
|
||||
else:
|
||||
# 大小写不敏感
|
||||
if fuzzy:
|
||||
# 模糊搜索(使用LOWER函数)
|
||||
search_condition = fn.LOWER(Jargon.content).contains(keyword.lower())
|
||||
else:
|
||||
# 精确匹配(使用LOWER函数)
|
||||
search_condition = (fn.LOWER(Jargon.content) == keyword.lower())
|
||||
|
||||
query = query.where(search_condition)
|
||||
|
||||
# 根据all_global配置决定查询逻辑
|
||||
if global_config.jargon.all_global:
|
||||
# 开启all_global:所有记录都是全局的,查询所有is_global=True的记录(无视chat_id)
|
||||
query = query.where(Jargon.is_global)
|
||||
else:
|
||||
# 关闭all_global:如果提供了chat_id,优先搜索该聊天或global的jargon
|
||||
if chat_id:
|
||||
query = query.where(
|
||||
(Jargon.chat_id == chat_id) | Jargon.is_global
|
||||
)
|
||||
|
||||
# 只返回有meaning的记录
|
||||
query = query.where(
|
||||
(Jargon.meaning.is_null(False)) & (Jargon.meaning != "")
|
||||
)
|
||||
|
||||
# 按count降序排序,优先返回出现频率高的
|
||||
query = query.order_by(Jargon.count.desc())
|
||||
|
||||
# 限制结果数量
|
||||
query = query.limit(limit)
|
||||
|
||||
# 执行查询并返回结果
|
||||
results = []
|
||||
for jargon in query:
|
||||
results.append({
|
||||
"content": jargon.content or "",
|
||||
"meaning": jargon.meaning or ""
|
||||
})
|
||||
|
||||
return results
|
||||
|
||||
|
||||
async def store_jargon_from_answer(jargon_keyword: str, answer: str, chat_id: str) -> None:
|
||||
"""将黑话存入jargon系统
|
||||
|
||||
Args:
|
||||
jargon_keyword: 黑话关键词
|
||||
answer: 答案内容(将概括为raw_content)
|
||||
chat_id: 聊天ID
|
||||
"""
|
||||
try:
|
||||
# 概括答案为简短的raw_content
|
||||
summary_prompt = f"""请将以下答案概括为一句简短的话(不超过50字),作为黑话"{jargon_keyword}"的使用示例:
|
||||
|
||||
答案:{answer}
|
||||
|
||||
只输出概括后的内容,不要输出其他内容:"""
|
||||
|
||||
success, summary, _, _ = await llm_api.generate_with_model(
|
||||
summary_prompt,
|
||||
model_config=model_config.model_task_config.utils_small,
|
||||
request_type="memory.summarize_jargon",
|
||||
)
|
||||
|
||||
logger.info(f"概括答案提示: {summary_prompt}")
|
||||
logger.info(f"概括答案: {summary}")
|
||||
|
||||
if not success:
|
||||
logger.warning(f"概括答案失败,使用原始答案: {summary}")
|
||||
summary = answer[:100] # 截取前100字符作为备用
|
||||
|
||||
raw_content = summary.strip()[:200] # 限制长度
|
||||
|
||||
# 检查是否已存在
|
||||
if global_config.jargon.all_global:
|
||||
query = Jargon.select().where(Jargon.content == jargon_keyword)
|
||||
else:
|
||||
query = Jargon.select().where(
|
||||
(Jargon.chat_id == chat_id) &
|
||||
(Jargon.content == jargon_keyword)
|
||||
)
|
||||
|
||||
if query.exists():
|
||||
# 更新现有记录
|
||||
obj = query.get()
|
||||
obj.count = (obj.count or 0) + 1
|
||||
|
||||
# 合并raw_content列表
|
||||
existing_raw_content = []
|
||||
if obj.raw_content:
|
||||
try:
|
||||
existing_raw_content = json.loads(obj.raw_content) if isinstance(obj.raw_content, str) else obj.raw_content
|
||||
if not isinstance(existing_raw_content, list):
|
||||
existing_raw_content = [existing_raw_content] if existing_raw_content else []
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
existing_raw_content = [obj.raw_content] if obj.raw_content else []
|
||||
|
||||
# 合并并去重
|
||||
merged_list = list(dict.fromkeys(existing_raw_content + [raw_content]))
|
||||
obj.raw_content = json.dumps(merged_list, ensure_ascii=False)
|
||||
|
||||
if global_config.jargon.all_global:
|
||||
obj.is_global = True
|
||||
|
||||
obj.save()
|
||||
logger.info(f"更新jargon记录: {jargon_keyword}")
|
||||
else:
|
||||
# 创建新记录
|
||||
is_global_new = True if global_config.jargon.all_global else False
|
||||
Jargon.create(
|
||||
content=jargon_keyword,
|
||||
raw_content=json.dumps([raw_content], ensure_ascii=False),
|
||||
chat_id=chat_id,
|
||||
is_global=is_global_new,
|
||||
count=1
|
||||
)
|
||||
logger.info(f"创建新jargon记录: {jargon_keyword}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"存储jargon失败: {e}")
|
||||
|
||||
|
||||
@@ -143,8 +143,13 @@ def _convert_tool_options(tool_options: list[ToolOption]) -> list[FunctionDeclar
|
||||
:param tool_option_param: 工具参数对象
|
||||
:return: 转换后的工具参数字典
|
||||
"""
|
||||
# JSON Schema要求使用"boolean"而不是"bool"
|
||||
param_type_value = tool_option_param.param_type.value
|
||||
if param_type_value == "bool":
|
||||
param_type_value = "boolean"
|
||||
|
||||
return_dict: dict[str, Any] = {
|
||||
"type": tool_option_param.param_type.value,
|
||||
"type": param_type_value,
|
||||
"description": tool_option_param.description,
|
||||
}
|
||||
if tool_option_param.enum_values:
|
||||
@@ -250,7 +255,7 @@ def _build_stream_api_resp(
|
||||
if fr:
|
||||
reason = str(fr)
|
||||
break
|
||||
|
||||
|
||||
if str(reason).endswith("MAX_TOKENS"):
|
||||
has_visible_output = bool(resp.content and resp.content.strip())
|
||||
if has_visible_output:
|
||||
@@ -281,8 +286,8 @@ async def _default_stream_response_handler(
|
||||
_tool_calls_buffer: list[tuple[str, str, dict]] = [] # 工具调用缓冲区,用于存储接收到的工具调用
|
||||
_usage_record = None # 使用情况记录
|
||||
last_resp: GenerateContentResponse | None = None # 保存最后一个 chunk
|
||||
resp = APIResponse()
|
||||
|
||||
resp = APIResponse()
|
||||
|
||||
def _insure_buffer_closed():
|
||||
if _fc_delta_buffer and not _fc_delta_buffer.closed:
|
||||
_fc_delta_buffer.close()
|
||||
@@ -298,7 +303,7 @@ async def _default_stream_response_handler(
|
||||
chunk,
|
||||
_fc_delta_buffer,
|
||||
_tool_calls_buffer,
|
||||
resp=resp,
|
||||
resp=resp,
|
||||
)
|
||||
|
||||
if chunk.usage_metadata:
|
||||
@@ -314,7 +319,7 @@ async def _default_stream_response_handler(
|
||||
_fc_delta_buffer,
|
||||
_tool_calls_buffer,
|
||||
last_resp=last_resp,
|
||||
resp=resp,
|
||||
resp=resp,
|
||||
), _usage_record
|
||||
except Exception:
|
||||
# 确保缓冲区被关闭
|
||||
|
||||
@@ -36,7 +36,7 @@ from ..payload_content.message import Message, RoleType
|
||||
from ..payload_content.resp_format import RespFormat
|
||||
from ..payload_content.tool_option import ToolOption, ToolParam, ToolCall
|
||||
|
||||
logger = get_logger("OpenAI客户端")
|
||||
logger = get_logger("llm_models")
|
||||
|
||||
|
||||
def _convert_messages(messages: list[Message]) -> list[ChatCompletionMessageParam]:
|
||||
@@ -77,6 +77,23 @@ def _convert_messages(messages: list[Message]) -> list[ChatCompletionMessagePara
|
||||
"content": content,
|
||||
}
|
||||
|
||||
if message.role == RoleType.Assistant and getattr(message, "tool_calls", None):
|
||||
tool_calls_payload: list[dict[str, Any]] = []
|
||||
for call in message.tool_calls or []:
|
||||
tool_calls_payload.append(
|
||||
{
|
||||
"id": call.call_id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": call.func_name,
|
||||
"arguments": json.dumps(call.args or {}, ensure_ascii=False),
|
||||
},
|
||||
}
|
||||
)
|
||||
ret["tool_calls"] = tool_calls_payload
|
||||
if ret["content"] == []:
|
||||
ret["content"] = ""
|
||||
|
||||
# 添加工具调用ID
|
||||
if message.role == RoleType.Tool:
|
||||
if not message.tool_call_id:
|
||||
@@ -101,8 +118,13 @@ def _convert_tool_options(tool_options: list[ToolOption]) -> list[dict[str, Any]
|
||||
:param tool_option_param: 工具参数对象
|
||||
:return: 转换后的工具参数字典
|
||||
"""
|
||||
# JSON Schema要求使用"boolean"而不是"bool"
|
||||
param_type_value = tool_option_param.param_type.value
|
||||
if param_type_value == "bool":
|
||||
param_type_value = "boolean"
|
||||
|
||||
return_dict: dict[str, Any] = {
|
||||
"type": tool_option_param.param_type.value,
|
||||
"type": param_type_value,
|
||||
"description": tool_option_param.description,
|
||||
}
|
||||
if tool_option_param.enum_values:
|
||||
@@ -239,7 +261,7 @@ def _build_stream_api_resp(
|
||||
|
||||
# 检查 max_tokens 截断(流式的告警改由处理函数统一输出,这里不再输出)
|
||||
# 保留 finish_reason 仅用于上层判断
|
||||
|
||||
|
||||
if not resp.content and not resp.tool_calls:
|
||||
raise EmptyResponseException()
|
||||
|
||||
@@ -293,7 +315,7 @@ async def _default_stream_response_handler(
|
||||
|
||||
if hasattr(event.choices[0], "finish_reason") and event.choices[0].finish_reason:
|
||||
finish_reason = event.choices[0].finish_reason
|
||||
|
||||
|
||||
if hasattr(event, "model") and event.model and not _model_name:
|
||||
_model_name = event.model # 记录模型名
|
||||
|
||||
@@ -341,10 +363,7 @@ async def _default_stream_response_handler(
|
||||
model_dbg = None
|
||||
|
||||
# 统一日志格式
|
||||
logger.info(
|
||||
"模型%s因为超过最大max_token限制,可能仅输出部分内容,可视情况调整"
|
||||
% (model_dbg or "")
|
||||
)
|
||||
logger.info("模型%s因为超过最大max_token限制,可能仅输出部分内容,可视情况调整" % (model_dbg or ""))
|
||||
|
||||
return resp, _usage_record
|
||||
except Exception:
|
||||
@@ -387,9 +406,7 @@ def _default_normal_response_parser(
|
||||
raw_snippet = str(resp)[:300]
|
||||
except Exception:
|
||||
raw_snippet = "<unserializable>"
|
||||
logger.debug(
|
||||
f"empty choices: model={model_dbg} id={id_dbg} usage={usage_dbg} raw≈{raw_snippet}"
|
||||
)
|
||||
logger.debug(f"empty choices: model={model_dbg} id={id_dbg} usage={usage_dbg} raw≈{raw_snippet}")
|
||||
except Exception:
|
||||
# 日志采集失败不应影响控制流
|
||||
pass
|
||||
@@ -444,17 +461,14 @@ def _default_normal_response_parser(
|
||||
choice0 = resp.choices[0]
|
||||
reason = getattr(choice0, "finish_reason", None)
|
||||
if reason and reason == "length":
|
||||
print(resp)
|
||||
# print(resp)
|
||||
_model_name = resp.model
|
||||
# 统一日志格式
|
||||
logger.info(
|
||||
"模型%s因为超过最大max_token限制,可能仅输出部分内容,可视情况调整"
|
||||
% (_model_name or "")
|
||||
)
|
||||
logger.info("模型%s因为超过最大max_token限制,可能仅输出部分内容,可视情况调整" % (_model_name or ""))
|
||||
return api_response, _usage_record
|
||||
except Exception as e:
|
||||
logger.debug(f"检查 MAX_TOKENS 截断时异常: {e}")
|
||||
|
||||
|
||||
if not api_response.content and not api_response.tool_calls:
|
||||
raise EmptyResponseException()
|
||||
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
from enum import Enum
|
||||
from typing import List, Optional
|
||||
|
||||
from .tool_option import ToolCall
|
||||
|
||||
|
||||
# 设计这系列类的目的是为未来可能的扩展做准备
|
||||
@@ -20,6 +23,7 @@ class Message:
|
||||
role: RoleType,
|
||||
content: str | list[tuple[str, str] | str],
|
||||
tool_call_id: str | None = None,
|
||||
tool_calls: Optional[List[ToolCall]] = None,
|
||||
):
|
||||
"""
|
||||
初始化消息对象
|
||||
@@ -28,6 +32,13 @@ class Message:
|
||||
self.role: RoleType = role
|
||||
self.content: str | list[tuple[str, str] | str] = content
|
||||
self.tool_call_id: str | None = tool_call_id
|
||||
self.tool_calls: Optional[List[ToolCall]] = tool_calls
|
||||
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
f"Role: {self.role}, Content: {self.content}, "
|
||||
f"Tool Call ID: {self.tool_call_id}, Tool Calls: {self.tool_calls}"
|
||||
)
|
||||
|
||||
|
||||
class MessageBuilder:
|
||||
@@ -35,6 +46,7 @@ class MessageBuilder:
|
||||
self.__role: RoleType = RoleType.User
|
||||
self.__content: list[tuple[str, str] | str] = []
|
||||
self.__tool_call_id: str | None = None
|
||||
self.__tool_calls: Optional[List[ToolCall]] = None
|
||||
|
||||
def set_role(self, role: RoleType = RoleType.User) -> "MessageBuilder":
|
||||
"""
|
||||
@@ -86,12 +98,27 @@ class MessageBuilder:
|
||||
self.__tool_call_id = tool_call_id
|
||||
return self
|
||||
|
||||
def set_tool_calls(self, tool_calls: List[ToolCall]) -> "MessageBuilder":
|
||||
"""
|
||||
设置助手消息的工具调用列表
|
||||
:param tool_calls: 工具调用列表
|
||||
:return: MessageBuilder对象
|
||||
"""
|
||||
if self.__role != RoleType.Assistant:
|
||||
raise ValueError("仅当角色为Assistant时才能设置工具调用列表")
|
||||
if not tool_calls:
|
||||
raise ValueError("工具调用列表不能为空")
|
||||
self.__tool_calls = tool_calls
|
||||
return self
|
||||
|
||||
def build(self) -> Message:
|
||||
"""
|
||||
构建消息对象
|
||||
:return: Message对象
|
||||
"""
|
||||
if len(self.__content) == 0:
|
||||
if len(self.__content) == 0 and not (
|
||||
self.__role == RoleType.Assistant and self.__tool_calls
|
||||
):
|
||||
raise ValueError("内容不能为空")
|
||||
if self.__role == RoleType.Tool and self.__tool_call_id is None:
|
||||
raise ValueError("Tool角色的工具调用ID不能为空")
|
||||
@@ -104,4 +131,5 @@ class MessageBuilder:
|
||||
else self.__content
|
||||
),
|
||||
tool_call_id=self.__tool_call_id,
|
||||
tool_calls=self.__tool_calls,
|
||||
)
|
||||
|
||||
@@ -166,6 +166,57 @@ class LLMRequest:
|
||||
time_cost=time.time() - start_time,
|
||||
)
|
||||
return content or "", (reasoning_content, model_info.name, tool_calls)
|
||||
|
||||
async def generate_response_with_message_async(
|
||||
self,
|
||||
message_factory: Callable[[BaseClient], List[Message]],
|
||||
temperature: Optional[float] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
tools: Optional[List[Dict[str, Any]]] = None,
|
||||
raise_when_empty: bool = True,
|
||||
) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]:
|
||||
"""
|
||||
异步生成响应
|
||||
Args:
|
||||
message_factory (Callable[[BaseClient], List[Message]]): 已构建好的消息工厂
|
||||
temperature (float, optional): 温度参数
|
||||
max_tokens (int, optional): 最大token数
|
||||
tools (Optional[List[Dict[str, Any]]]): 工具列表
|
||||
raise_when_empty (bool): 当响应为空时是否抛出异常
|
||||
Returns:
|
||||
(Tuple[str, str, str, Optional[List[ToolCall]]]): 响应内容、推理内容、模型名称、工具调用列表
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
tool_built = self._build_tool_options(tools)
|
||||
|
||||
response, model_info = await self._execute_request(
|
||||
request_type=RequestType.RESPONSE,
|
||||
message_factory=message_factory,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
tool_options=tool_built,
|
||||
)
|
||||
|
||||
logger.debug(f"LLM请求总耗时: {time.time() - start_time}")
|
||||
logger.debug(f"LLM生成内容: {response}")
|
||||
|
||||
content = response.content
|
||||
reasoning_content = response.reasoning_content or ""
|
||||
tool_calls = response.tool_calls
|
||||
if not reasoning_content and content:
|
||||
content, extracted_reasoning = self._extract_reasoning(content)
|
||||
reasoning_content = extracted_reasoning
|
||||
if usage := response.usage:
|
||||
llm_usage_recorder.record_usage_to_database(
|
||||
model_info=model_info,
|
||||
model_usage=usage,
|
||||
user_id="system",
|
||||
request_type=self.request_type,
|
||||
endpoint="/chat/completions",
|
||||
time_cost=time.time() - start_time,
|
||||
)
|
||||
return content or "", (reasoning_content, model_info.name, tool_calls)
|
||||
|
||||
async def get_embedding(self, embedding_input: str) -> Tuple[List[float], str]:
|
||||
"""
|
||||
@@ -277,9 +328,7 @@ class LLMRequest:
|
||||
logger.error(f"模型 '{model_info.name}' 在多次出现空回复后仍然失败。")
|
||||
raise ModelAttemptFailed(f"模型 '{model_info.name}' 重试耗尽", original_exception=e) from e
|
||||
|
||||
logger.warning(
|
||||
f"模型 '{model_info.name}' 返回空回复(可重试)。剩余重试次数: {retry_remain}"
|
||||
)
|
||||
logger.warning(f"模型 '{model_info.name}' 返回空回复(可重试)。剩余重试次数: {retry_remain}")
|
||||
await asyncio.sleep(api_provider.retry_interval)
|
||||
|
||||
except NetworkConnectionError as e:
|
||||
@@ -289,9 +338,7 @@ class LLMRequest:
|
||||
logger.error(f"模型 '{model_info.name}' 在网络错误重试用尽后仍然失败。")
|
||||
raise ModelAttemptFailed(f"模型 '{model_info.name}' 重试耗尽", original_exception=e) from e
|
||||
|
||||
logger.warning(
|
||||
f"模型 '{model_info.name}' 遇到网络错误(可重试): {str(e)}。剩余重试次数: {retry_remain}"
|
||||
)
|
||||
logger.warning(f"模型 '{model_info.name}' 遇到网络错误(可重试): {str(e)}。剩余重试次数: {retry_remain}")
|
||||
await asyncio.sleep(api_provider.retry_interval)
|
||||
|
||||
except RespNotOkException as e:
|
||||
|
||||
47
src/main.py
47
src/main.py
@@ -5,6 +5,8 @@ from maim_message import MessageServer
|
||||
from src.common.remote import TelemetryHeartBeatTask
|
||||
from src.manager.async_task_manager import async_task_manager
|
||||
from src.chat.utils.statistic import OnlineTimeRecordTask, StatisticOutputTask
|
||||
|
||||
# from src.chat.utils.token_statistics import TokenStatisticsTask
|
||||
from src.chat.emoji_system.emoji_manager import get_emoji_manager
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.config.config import global_config
|
||||
@@ -13,7 +15,6 @@ from src.common.logger import get_logger
|
||||
from src.common.server import get_global_server, Server
|
||||
from src.mood.mood_manager import mood_manager
|
||||
from src.chat.knowledge import lpmm_start_up
|
||||
from src.memory_system.memory_management_task import MemoryManagementTask
|
||||
from rich.traceback import install
|
||||
# from src.api.main import start_api_server
|
||||
|
||||
@@ -35,6 +36,37 @@ class MainSystem:
|
||||
# 使用消息API替代直接的FastAPI实例
|
||||
self.app: MessageServer = get_global_api()
|
||||
self.server: Server = get_global_server()
|
||||
|
||||
# 注册 WebUI API 路由
|
||||
self._register_webui_routes()
|
||||
|
||||
# 设置 WebUI(开发/生产模式)
|
||||
self._setup_webui()
|
||||
|
||||
def _register_webui_routes(self):
|
||||
"""注册 WebUI API 路由"""
|
||||
try:
|
||||
from src.webui.routes import router as webui_router
|
||||
self.server.register_router(webui_router)
|
||||
logger.info("WebUI API 路由已注册")
|
||||
except Exception as e:
|
||||
logger.warning(f"注册 WebUI API 路由失败: {e}")
|
||||
|
||||
def _setup_webui(self):
|
||||
"""设置 WebUI(根据环境变量决定模式)"""
|
||||
import os
|
||||
webui_enabled = os.getenv("WEBUI_ENABLED", "false").lower() == "true"
|
||||
if not webui_enabled:
|
||||
logger.info("WebUI 已禁用")
|
||||
return
|
||||
|
||||
webui_mode = os.getenv("WEBUI_MODE", "production").lower()
|
||||
|
||||
try:
|
||||
from src.webui.manager import setup_webui
|
||||
setup_webui(mode=webui_mode)
|
||||
except Exception as e:
|
||||
logger.error(f"设置 WebUI 失败: {e}")
|
||||
|
||||
async def initialize(self):
|
||||
"""初始化系统组件"""
|
||||
@@ -65,9 +97,17 @@ class MainSystem:
|
||||
# 添加统计信息输出任务
|
||||
await async_task_manager.add_task(StatisticOutputTask())
|
||||
|
||||
# 添加聊天流统计任务(每5分钟生成一次报告,统计最近30天的数据)
|
||||
# await async_task_manager.add_task(TokenStatisticsTask())
|
||||
|
||||
# 添加遥测心跳任务
|
||||
await async_task_manager.add_task(TelemetryHeartBeatTask())
|
||||
|
||||
# 添加记忆遗忘任务
|
||||
from src.chat.utils.memory_forget_task import MemoryForgetTask
|
||||
|
||||
await async_task_manager.add_task(MemoryForgetTask())
|
||||
|
||||
# 启动API服务器
|
||||
# start_api_server()
|
||||
# logger.info("API服务器启动成功")
|
||||
@@ -92,10 +132,6 @@ class MainSystem:
|
||||
asyncio.create_task(get_chat_manager()._auto_save_task())
|
||||
|
||||
logger.info("聊天管理器初始化成功")
|
||||
|
||||
# 添加记忆管理任务
|
||||
await async_task_manager.add_task(MemoryManagementTask())
|
||||
logger.info("记忆管理任务已启动")
|
||||
|
||||
# await asyncio.sleep(0.5) #防止logger输出飞了
|
||||
|
||||
@@ -103,7 +139,6 @@ class MainSystem:
|
||||
self.app.register_message_handler(chat_bot.message_process)
|
||||
self.app.register_custom_message_handler("message_id_echo", chat_bot.echo_message_process)
|
||||
|
||||
|
||||
# 触发 ON_START 事件
|
||||
from src.plugin_system.core.events_manager import events_manager
|
||||
from src.plugin_system.base.component_types import EventType
|
||||
|
||||
@@ -1,876 +0,0 @@
|
||||
import asyncio
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
import random
|
||||
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config
|
||||
from src.common.database.database_model import MemoryChest as MemoryChestModel
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.plugin_system.apis.message_api import build_readable_messages
|
||||
from src.plugin_system.apis.message_api import get_raw_msg_by_timestamp_with_chat
|
||||
from json_repair import repair_json
|
||||
from src.memory_system.questions import global_conflict_tracker
|
||||
|
||||
from .memory_utils import (
|
||||
find_best_matching_memory,
|
||||
check_title_exists_fuzzy,
|
||||
get_all_titles,
|
||||
find_most_similar_memory_by_chat_id,
|
||||
|
||||
)
|
||||
|
||||
logger = get_logger("memory")
|
||||
|
||||
class MemoryChest:
|
||||
def __init__(self):
|
||||
|
||||
self.LLMRequest = LLMRequest(
|
||||
model_set=model_config.model_task_config.utils_small,
|
||||
request_type="memory_chest",
|
||||
)
|
||||
|
||||
self.LLMRequest_build = LLMRequest(
|
||||
model_set=model_config.model_task_config.utils,
|
||||
request_type="memory_chest_build",
|
||||
)
|
||||
|
||||
|
||||
self.running_content_list = {} # {chat_id: {"content": running_content, "last_update_time": timestamp, "create_time": timestamp}}
|
||||
self.fetched_memory_list = [] # [(chat_id, (question, answer, timestamp)), ...]
|
||||
|
||||
def remove_one_memory_by_age_weight(self) -> bool:
|
||||
"""
|
||||
删除一条记忆:按“越老/越新更易被删”的权重随机选择(老=较小id,新=较大id)。
|
||||
|
||||
返回:是否删除成功
|
||||
"""
|
||||
try:
|
||||
memories = list(MemoryChestModel.select())
|
||||
if not memories:
|
||||
return False
|
||||
|
||||
# 排除锁定项
|
||||
candidates = [m for m in memories if not getattr(m, "locked", False)]
|
||||
if not candidates:
|
||||
return False
|
||||
|
||||
# 按 id 排序,使用 id 近似时间顺序(小 -> 老,大 -> 新)
|
||||
candidates.sort(key=lambda m: m.id)
|
||||
n = len(candidates)
|
||||
if n == 1:
|
||||
MemoryChestModel.delete().where(MemoryChestModel.id == candidates[0].id).execute()
|
||||
logger.info(f"[记忆管理] 已删除一条记忆(权重抽样):{candidates[0].title}")
|
||||
return True
|
||||
|
||||
# 计算U型权重:中间最低,两端最高
|
||||
# r ∈ [0,1] 为位置归一化,w = 0.1 + 0.9 * (abs(r-0.5)*2)**1.5
|
||||
weights = []
|
||||
for idx, _m in enumerate(candidates):
|
||||
r = idx / (n - 1)
|
||||
w = 0.1 + 0.9 * (abs(r - 0.5) * 2) ** 1.5
|
||||
weights.append(w)
|
||||
|
||||
import random as _random
|
||||
selected = _random.choices(candidates, weights=weights, k=1)[0]
|
||||
|
||||
MemoryChestModel.delete().where(MemoryChestModel.id == selected.id).execute()
|
||||
logger.info(f"[记忆管理] 已删除一条记忆(权重抽样):{selected.title}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"[记忆管理] 按年龄权重删除记忆时出错: {e}")
|
||||
return False
|
||||
|
||||
def _compute_merge_similarity_threshold(self) -> float:
|
||||
"""
|
||||
根据当前记忆数量占比动态计算合并相似度阈值。
|
||||
|
||||
规则:占比越高,阈值越低。
|
||||
- < 60%: 0.80(更严格,避免早期误合并)
|
||||
- < 80%: 0.70
|
||||
- < 100%: 0.60
|
||||
- < 120%: 0.50
|
||||
- >= 120%: 0.45(最宽松,加速收敛)
|
||||
"""
|
||||
try:
|
||||
current_count = MemoryChestModel.select().count()
|
||||
max_count = max(1, int(global_config.memory.max_memory_number))
|
||||
percentage = current_count / max_count
|
||||
|
||||
if percentage < 0.6:
|
||||
return 0.70
|
||||
elif percentage < 0.8:
|
||||
return 0.60
|
||||
elif percentage < 1.0:
|
||||
return 0.50
|
||||
elif percentage < 1.5:
|
||||
return 0.40
|
||||
elif percentage < 2:
|
||||
return 0.30
|
||||
else:
|
||||
return 0.25
|
||||
except Exception:
|
||||
# 发生异常时使用保守阈值
|
||||
return 0.70
|
||||
|
||||
async def build_running_content(self, chat_id: str = None) -> str:
|
||||
"""
|
||||
构建记忆仓库的运行内容
|
||||
|
||||
Args:
|
||||
message_str: 消息内容
|
||||
chat_id: 聊天ID,用于提取对应的运行内容
|
||||
|
||||
Returns:
|
||||
str: 构建后的运行内容
|
||||
"""
|
||||
# 检查是否需要更新:基于消息数量和最新消息时间差的智能更新机制
|
||||
#
|
||||
# 更新机制说明:
|
||||
# 1. 消息数量 > 100:直接触发更新(高频消息场景)
|
||||
# 2. 消息数量 > 70 且最新消息时间差 > 30秒:触发更新(中高频消息场景)
|
||||
# 3. 消息数量 > 50 且最新消息时间差 > 60秒:触发更新(中频消息场景)
|
||||
# 4. 消息数量 > 30 且最新消息时间差 > 300秒:触发更新(低频消息场景)
|
||||
#
|
||||
# 设计理念:
|
||||
# - 消息越密集,时间阈值越短,确保及时更新记忆
|
||||
# - 消息越稀疏,时间阈值越长,避免频繁无意义的更新
|
||||
# - 通过最新消息时间差判断消息活跃度,而非简单的总时间差
|
||||
# - 平衡更新频率与性能,在保证记忆及时性的同时减少计算开销
|
||||
if chat_id not in self.running_content_list:
|
||||
self.running_content_list[chat_id] = {
|
||||
"content": "",
|
||||
"last_update_time": time.time(),
|
||||
"create_time": time.time()
|
||||
}
|
||||
|
||||
should_update = True
|
||||
if chat_id and chat_id in self.running_content_list:
|
||||
last_update_time = self.running_content_list[chat_id]["last_update_time"]
|
||||
current_time = time.time()
|
||||
# 使用message_api获取消息数量
|
||||
message_list = get_raw_msg_by_timestamp_with_chat(
|
||||
timestamp_start=last_update_time,
|
||||
timestamp_end=current_time,
|
||||
chat_id=chat_id,
|
||||
limit=global_config.chat.max_context_size * 2,
|
||||
)
|
||||
|
||||
new_messages_count = len(message_list)
|
||||
|
||||
# 获取最新消息的时间戳
|
||||
latest_message_time = last_update_time
|
||||
if message_list:
|
||||
# 假设消息列表按时间排序,取最后一条消息的时间戳
|
||||
latest_message = message_list[-1]
|
||||
if hasattr(latest_message, 'timestamp'):
|
||||
latest_message_time = latest_message.timestamp
|
||||
elif isinstance(latest_message, dict) and 'timestamp' in latest_message:
|
||||
latest_message_time = latest_message['timestamp']
|
||||
|
||||
# 计算最新消息时间与现在时间的差(秒)
|
||||
latest_message_time_diff = current_time - latest_message_time
|
||||
|
||||
# 智能更新条件判断 - 按优先级从高到低检查
|
||||
should_update = False
|
||||
update_reason = ""
|
||||
|
||||
if global_config.memory.memory_build_frequency > 0:
|
||||
if new_messages_count > 100/global_config.memory.memory_build_frequency:
|
||||
# 条件1:消息数量 > 100,直接触发更新
|
||||
# 适用场景:群聊刷屏、高频讨论等消息密集场景
|
||||
# 无需时间限制,确保重要信息不被遗漏
|
||||
should_update = True
|
||||
update_reason = f"消息数量 {new_messages_count} > 100,直接触发更新"
|
||||
elif new_messages_count > 70/global_config.memory.memory_build_frequency and latest_message_time_diff > 30:
|
||||
# 条件2:消息数量 > 70 且最新消息时间差 > 30秒
|
||||
# 适用场景:中高频讨论,但需要确保消息流已稳定
|
||||
# 30秒的时间差确保不是正在进行的实时对话
|
||||
should_update = True
|
||||
update_reason = f"消息数量 {new_messages_count} > 70 且最新消息时间差 {latest_message_time_diff:.1f}s > 30s"
|
||||
elif new_messages_count > 50/global_config.memory.memory_build_frequency and latest_message_time_diff > 60:
|
||||
# 条件3:消息数量 > 50 且最新消息时间差 > 60秒
|
||||
# 适用场景:中等频率讨论,等待1分钟确保对话告一段落
|
||||
# 平衡及时性与稳定性
|
||||
should_update = True
|
||||
update_reason = f"消息数量 {new_messages_count} > 50 且最新消息时间差 {latest_message_time_diff:.1f}s > 60s"
|
||||
elif new_messages_count > 30/global_config.memory.memory_build_frequency and latest_message_time_diff > 300:
|
||||
# 条件4:消息数量 > 30 且最新消息时间差 > 300秒(5分钟)
|
||||
# 适用场景:低频但有一定信息量的讨论
|
||||
# 5分钟的时间差确保对话完全结束,避免频繁更新
|
||||
should_update = True
|
||||
update_reason = f"消息数量 {new_messages_count} > 30 且最新消息时间差 {latest_message_time_diff:.1f}s > 300s"
|
||||
|
||||
logger.debug(f"chat_id {chat_id} 更新检查: {update_reason if should_update else f'消息数量 {new_messages_count},最新消息时间差 {latest_message_time_diff:.1f}s,不满足更新条件'}")
|
||||
|
||||
|
||||
if should_update:
|
||||
# 如果有chat_id,先提取对应的running_content
|
||||
message_str = build_readable_messages(
|
||||
message_list,
|
||||
replace_bot_name=True,
|
||||
timestamp_mode="relative",
|
||||
read_mark=0.0,
|
||||
show_actions=False,
|
||||
remove_emoji_stickers=True,
|
||||
)
|
||||
|
||||
# 随机从格式示例列表中选取若干行用于提示
|
||||
format_candidates = [
|
||||
"[概念] 是 [概念的含义(简短描述,不超过十个字)]",
|
||||
"[概念] 不是 [对概念的负面含义(简短描述,不超过十个字)]",
|
||||
"[概念1] 与 [概念2] 是 [概念1和概念2的关联(简短描述,不超过二十个字)]",
|
||||
"[概念1] 包含 [概念2] 和 [概念3]",
|
||||
"[概念1] 属于 [概念2]",
|
||||
"[概念1] 的例子是 [例子1] 和 [例子2]",
|
||||
"[概念] 的特征是 [特征1]、[特征2]",
|
||||
"[概念1] 导致 [概念2]",
|
||||
"[概念1] 需要 [条件1] 和 [条件2]",
|
||||
"[概念1] 的用途是 [用途1] 和 [用途2]",
|
||||
"[概念1] 与 [概念2] 的区别是 [区别点]",
|
||||
"[概念] 的别名是 [别名]",
|
||||
"[概念1] 包括但不限于 [概念2]、[概念3]",
|
||||
"[概念] 的反义是 [反义概念]",
|
||||
"[概念] 的组成有 [部分1]、[部分2]",
|
||||
"[概念] 出现于 [时间或场景]",
|
||||
"[概念] 的方法有 [方法1]、[方法2]",
|
||||
]
|
||||
|
||||
selected_count = random.randint(3, 6)
|
||||
selected_lines = random.sample(format_candidates, selected_count)
|
||||
format_section = "\n".join(selected_lines) + "\n......(不要包含中括号)"
|
||||
|
||||
prompt = f"""
|
||||
以下是一段你参与的聊天记录,请你在其中总结出记忆:
|
||||
|
||||
<聊天记录>
|
||||
{message_str}
|
||||
</聊天记录>
|
||||
聊天记录中可能包含有效信息,也可能信息密度很低,请你根据聊天记录中的信息,总结出记忆内容
|
||||
--------------------------------
|
||||
对[图片]的处理:
|
||||
1.除非与文本有关,不要将[图片]的内容整合到记忆中
|
||||
2.如果图片与某个概念相关,将图片中的关键内容也整合到记忆中,不要写入图片原文,例如:
|
||||
|
||||
聊天记录(与图片有关):
|
||||
用户说:[图片1:这是一个黄色的龙形状玩偶,被一只手拿着。]
|
||||
用户说:这个玩偶看起来很可爱,是我新买的奶龙
|
||||
总结的记忆内容:
|
||||
黄色的龙形状玩偶 是 奶龙
|
||||
|
||||
聊天记录(概念与图片无关):
|
||||
用户说:[图片1:这是一个台电脑,屏幕上显示了某种游戏。]
|
||||
用户说:使命召唤今天发售了新一代,有没有人玩
|
||||
总结的记忆内容:
|
||||
使命召唤新一代 是 最新发售的游戏
|
||||
|
||||
请主要关注概念和知识或者时效性较强的信息!!,而不是聊天的琐事
|
||||
1.不要关注诸如某个用户做了什么,说了什么,不要关注某个用户的行为,而是关注其中的概念性信息
|
||||
2.概念要求精确,不啰嗦,像科普读物或教育课本那样
|
||||
3.记忆为一段纯文本,逻辑清晰,指出概念的含义,并说明关系
|
||||
|
||||
记忆内容的格式,你必须仿照下面的格式,但不一定全部使用:
|
||||
{format_section}
|
||||
|
||||
请仿照上述格式输出,每个知识点一句话。输出成一段平文本
|
||||
现在请你输出,不要输出其他内容,注意一定要直白,白话,口语化不要浮夸,修辞。:
|
||||
"""
|
||||
|
||||
if global_config.debug.show_prompt:
|
||||
logger.info(f"记忆仓库构建运行内容 prompt: {prompt}")
|
||||
else:
|
||||
logger.debug(f"记忆仓库构建运行内容 prompt: {prompt}")
|
||||
|
||||
running_content, (reasoning_content, model_name, tool_calls) = await self.LLMRequest_build.generate_response_async(prompt)
|
||||
|
||||
print(f"prompt: {prompt}\n记忆仓库构建运行内容: {running_content}")
|
||||
|
||||
# 直接保存:每次构建后立即入库,并刷新时间戳窗口
|
||||
if chat_id and running_content:
|
||||
await self._save_to_database_and_clear(chat_id, running_content)
|
||||
|
||||
|
||||
return running_content
|
||||
|
||||
|
||||
async def get_answer_by_question(self, chat_id: str = "", question: str = "") -> str:
|
||||
"""
|
||||
根据问题获取答案
|
||||
"""
|
||||
logger.info(f"正在回忆问题答案: {question}")
|
||||
|
||||
title = await self.select_title_by_question(question)
|
||||
|
||||
if not title:
|
||||
return ""
|
||||
|
||||
for memory in MemoryChestModel.select():
|
||||
if memory.title == title:
|
||||
content = memory.content
|
||||
|
||||
if random.random() < 0.5:
|
||||
type = "要求原文能够较为全面的回答问题"
|
||||
else:
|
||||
type = "要求提取简短的内容"
|
||||
|
||||
prompt = f"""
|
||||
目标文段:
|
||||
{content}
|
||||
|
||||
你现在需要从目标文段中找出合适的信息来回答问题:{question}
|
||||
请务必从目标文段中提取相关信息的**原文**并输出,{type}
|
||||
如果没有原文能够回答问题,输出"无有效信息"即可,不要输出其他内容:
|
||||
"""
|
||||
|
||||
if global_config.debug.show_prompt:
|
||||
logger.info(f"记忆仓库获取答案 prompt: {prompt}")
|
||||
else:
|
||||
logger.debug(f"记忆仓库获取答案 prompt: {prompt}")
|
||||
|
||||
answer, (reasoning_content, model_name, tool_calls) = await self.LLMRequest.generate_response_async(prompt)
|
||||
|
||||
if "无有效" in answer or "无有效信息" in answer or "无信息" in answer:
|
||||
logger.info(f"没有能够回答{question}的记忆")
|
||||
return ""
|
||||
|
||||
logger.info(f"记忆仓库对问题 “{question}” 获取答案: {answer}")
|
||||
|
||||
# 将问题和答案存到fetched_memory_list
|
||||
if chat_id and answer:
|
||||
self.fetched_memory_list.append((chat_id, (question, answer, time.time())))
|
||||
|
||||
# 清理fetched_memory_list
|
||||
self._cleanup_fetched_memory_list()
|
||||
|
||||
return answer
|
||||
|
||||
def get_chat_memories_as_string(self, chat_id: str) -> str:
|
||||
"""
|
||||
获取某个chat_id的所有记忆,并构建成字符串
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
|
||||
Returns:
|
||||
str: 格式化的记忆字符串,格式:问题:xxx,答案:xxxxx\n问题:xxx,答案:xxxxx\n...
|
||||
"""
|
||||
try:
|
||||
memories = []
|
||||
|
||||
# 从fetched_memory_list中获取该chat_id的所有记忆
|
||||
for cid, (question, answer, timestamp) in self.fetched_memory_list:
|
||||
if cid == chat_id:
|
||||
memories.append(f"问题:{question},答案:{answer}")
|
||||
|
||||
# 按时间戳排序(最新的在后面)
|
||||
memories.sort()
|
||||
|
||||
# 用换行符连接所有记忆
|
||||
result = "\n".join(memories)
|
||||
|
||||
# logger.info(f"chat_id {chat_id} 共有 {len(memories)} 条记忆")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取chat_id {chat_id} 的记忆时出错: {e}")
|
||||
return ""
|
||||
|
||||
|
||||
async def select_title_by_question(self, question: str) -> str:
|
||||
"""
|
||||
根据消息内容选择最匹配的标题
|
||||
|
||||
Args:
|
||||
question: 问题
|
||||
|
||||
Returns:
|
||||
str: 选择的标题
|
||||
"""
|
||||
# 获取所有标题并构建格式化字符串(排除锁定的记忆)
|
||||
titles = get_all_titles(exclude_locked=True)
|
||||
formatted_titles = ""
|
||||
for title in titles:
|
||||
formatted_titles += f"{title}\n"
|
||||
|
||||
prompt = f"""
|
||||
所有主题:
|
||||
{formatted_titles}
|
||||
|
||||
请根据以下问题,选择一个能够回答问题的主题:
|
||||
问题:{question}
|
||||
请你输出主题,不要输出其他内容,完整输出主题名:
|
||||
"""
|
||||
|
||||
if global_config.debug.show_prompt:
|
||||
logger.info(f"记忆仓库选择标题 prompt: {prompt}")
|
||||
else:
|
||||
logger.debug(f"记忆仓库选择标题 prompt: {prompt}")
|
||||
|
||||
|
||||
title, (reasoning_content, model_name, tool_calls) = await self.LLMRequest.generate_response_async(prompt)
|
||||
|
||||
# 根据 title 获取 titles 里的对应项
|
||||
selected_title = None
|
||||
|
||||
# 使用模糊查找匹配标题
|
||||
best_match = find_best_matching_memory(title, similarity_threshold=0.8)
|
||||
if best_match:
|
||||
selected_title = best_match[0] # 获取匹配的标题
|
||||
logger.info(f"记忆仓库选择标题: {selected_title} (相似度: {best_match[2]:.3f})")
|
||||
else:
|
||||
logger.warning(f"未找到相似度 >= 0.7 的标题匹配: {title}")
|
||||
selected_title = None
|
||||
|
||||
return selected_title
|
||||
|
||||
def _cleanup_fetched_memory_list(self):
|
||||
"""
|
||||
清理fetched_memory_list,移除超过10分钟的记忆和超过10条的最旧记忆
|
||||
"""
|
||||
try:
|
||||
current_time = time.time()
|
||||
ten_minutes_ago = current_time - 600 # 10分钟 = 600秒
|
||||
|
||||
# 移除超过10分钟的记忆
|
||||
self.fetched_memory_list = [
|
||||
(chat_id, (question, answer, timestamp))
|
||||
for chat_id, (question, answer, timestamp) in self.fetched_memory_list
|
||||
if timestamp > ten_minutes_ago
|
||||
]
|
||||
|
||||
# 如果记忆条数超过10条,移除最旧的5条
|
||||
if len(self.fetched_memory_list) > 10:
|
||||
# 按时间戳排序,移除最旧的5条
|
||||
self.fetched_memory_list.sort(key=lambda x: x[1][2]) # 按timestamp排序
|
||||
self.fetched_memory_list = self.fetched_memory_list[5:] # 保留最新的5条
|
||||
|
||||
logger.debug(f"fetched_memory_list清理后,当前有 {len(self.fetched_memory_list)} 条记忆")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"清理fetched_memory_list时出错: {e}")
|
||||
|
||||
async def _save_to_database_and_clear(self, chat_id: str, content: str):
|
||||
"""
|
||||
生成标题,保存到数据库,并清空对应chat_id的running_content
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
content: 要保存的内容
|
||||
"""
|
||||
try:
|
||||
# 生成标题
|
||||
title = ""
|
||||
title_prompt = f"""
|
||||
请为以下内容生成一个描述全面的标题,要求描述内容的主要概念和事件:
|
||||
{content}
|
||||
|
||||
标题不要分点,不要换行,不要输出其他内容
|
||||
请只输出标题,不要输出其他内容:
|
||||
"""
|
||||
|
||||
if global_config.debug.show_prompt:
|
||||
logger.info(f"记忆仓库生成标题 prompt: {title_prompt}")
|
||||
else:
|
||||
logger.debug(f"记忆仓库生成标题 prompt: {title_prompt}")
|
||||
|
||||
title, (reasoning_content, model_name, tool_calls) = await self.LLMRequest_build.generate_response_async(title_prompt)
|
||||
|
||||
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
if title:
|
||||
# 保存到数据库
|
||||
MemoryChestModel.create(
|
||||
title=title.strip(),
|
||||
content=content,
|
||||
chat_id=chat_id
|
||||
)
|
||||
logger.info(f"已保存记忆仓库内容,标题: {title.strip()}, chat_id: {chat_id}")
|
||||
|
||||
# 清空内容并刷新时间戳,但保留条目用于增量计算
|
||||
if chat_id in self.running_content_list:
|
||||
current_time = time.time()
|
||||
self.running_content_list[chat_id] = {
|
||||
"content": "",
|
||||
"last_update_time": current_time,
|
||||
"create_time": current_time
|
||||
}
|
||||
logger.info(f"已保存并刷新chat_id {chat_id} 的时间戳,准备下一次增量构建")
|
||||
else:
|
||||
logger.warning(f"生成标题失败,chat_id: {chat_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"保存记忆仓库内容时出错: {e}")
|
||||
|
||||
async def choose_merge_target(self, memory_title: str, chat_id: str = None) -> tuple[list[str], list[str]]:
|
||||
"""
|
||||
选择与给定记忆标题相关的记忆目标(基于文本相似度)
|
||||
|
||||
Args:
|
||||
memory_title: 要匹配的记忆标题
|
||||
chat_id: 聊天ID,用于筛选同chat_id的记忆
|
||||
|
||||
Returns:
|
||||
tuple[list[str], list[str]]: (选中的记忆标题列表, 选中的记忆内容列表)
|
||||
"""
|
||||
try:
|
||||
if not chat_id:
|
||||
logger.warning("未提供chat_id,无法进行记忆匹配")
|
||||
return [], []
|
||||
|
||||
# 动态计算相似度阈值(占比越高阈值越低)
|
||||
dynamic_threshold = self._compute_merge_similarity_threshold()
|
||||
|
||||
# 使用相似度匹配查找最相似的记忆(基于动态阈值)
|
||||
similar_memory = find_most_similar_memory_by_chat_id(
|
||||
target_title=memory_title,
|
||||
target_chat_id=chat_id,
|
||||
similarity_threshold=dynamic_threshold
|
||||
)
|
||||
|
||||
if similar_memory:
|
||||
selected_title, selected_content, similarity = similar_memory
|
||||
logger.info(f"为 '{memory_title}' 找到相似记忆: '{selected_title}' (相似度: {similarity:.3f} 阈值: {dynamic_threshold:.2f})")
|
||||
return [selected_title], [selected_content]
|
||||
else:
|
||||
logger.info(f"为 '{memory_title}' 未找到相似度 >= {dynamic_threshold:.2f} 的记忆")
|
||||
return [], []
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"选择合并目标时出错: {e}")
|
||||
return [], []
|
||||
|
||||
def _get_memories_by_titles(self, titles: list[str]) -> list[str]:
|
||||
"""
|
||||
根据标题列表查找对应的记忆内容
|
||||
|
||||
Args:
|
||||
titles: 记忆标题列表
|
||||
|
||||
Returns:
|
||||
list[str]: 记忆内容列表
|
||||
"""
|
||||
try:
|
||||
contents = []
|
||||
for title in titles:
|
||||
if not title or not title.strip():
|
||||
continue
|
||||
|
||||
# 使用模糊查找匹配记忆
|
||||
try:
|
||||
best_match = find_best_matching_memory(title.strip(), similarity_threshold=0.8)
|
||||
if best_match:
|
||||
# 检查记忆是否被锁定
|
||||
memory_title = best_match[0]
|
||||
memory_content = best_match[1]
|
||||
|
||||
# 查询数据库中的锁定状态
|
||||
for memory in MemoryChestModel.select():
|
||||
if memory.title == memory_title and memory.locked:
|
||||
logger.warning(f"记忆 '{memory_title}' 已锁定,跳过合并")
|
||||
continue
|
||||
|
||||
contents.append(memory_content)
|
||||
logger.debug(f"找到记忆: {memory_title} (相似度: {best_match[2]:.3f})")
|
||||
else:
|
||||
logger.warning(f"未找到相似度 >= 0.8 的标题匹配: '{title}'")
|
||||
except Exception as e:
|
||||
logger.error(f"查找标题 '{title}' 的记忆时出错: {e}")
|
||||
continue
|
||||
|
||||
# logger.info(f"成功找到 {len(contents)} 条记忆内容")
|
||||
return contents
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"根据标题查找记忆时出错: {e}")
|
||||
return []
|
||||
|
||||
def _parse_merged_parts(self, merged_response: str) -> tuple[str, str]:
|
||||
"""
|
||||
解析合并记忆的part1和part2内容
|
||||
|
||||
Args:
|
||||
merged_response: LLM返回的合并记忆响应
|
||||
|
||||
Returns:
|
||||
tuple[str, str]: (part1_content, part2_content)
|
||||
"""
|
||||
try:
|
||||
# 使用正则表达式提取part1和part2内容
|
||||
import re
|
||||
|
||||
# 提取part1内容
|
||||
part1_pattern = r'<part1>(.*?)</part1>'
|
||||
part1_match = re.search(part1_pattern, merged_response, re.DOTALL)
|
||||
part1_content = part1_match.group(1).strip() if part1_match else ""
|
||||
|
||||
# 提取part2内容
|
||||
part2_pattern = r'<part2>(.*?)</part2>'
|
||||
part2_match = re.search(part2_pattern, merged_response, re.DOTALL)
|
||||
part2_content = part2_match.group(1).strip() if part2_match else ""
|
||||
|
||||
# 检查是否包含none或None(不区分大小写)
|
||||
def is_none_content(content: str) -> bool:
|
||||
if not content:
|
||||
return True
|
||||
# 检查是否只包含"none"或"None"(不区分大小写)
|
||||
return re.match(r'^\s*none\s*$', content, re.IGNORECASE) is not None
|
||||
|
||||
# 如果包含none,则设置为空字符串
|
||||
if is_none_content(part1_content):
|
||||
part1_content = ""
|
||||
logger.info("part1内容为none,设置为空")
|
||||
|
||||
if is_none_content(part2_content):
|
||||
part2_content = ""
|
||||
logger.info("part2内容为none,设置为空")
|
||||
|
||||
return part1_content, part2_content
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"解析合并记忆part1/part2时出错: {e}")
|
||||
return "", ""
|
||||
|
||||
def _parse_merge_target_json(self, json_text: str) -> list[str]:
|
||||
"""
|
||||
解析choose_merge_target生成的JSON响应
|
||||
|
||||
Args:
|
||||
json_text: LLM返回的JSON文本
|
||||
|
||||
Returns:
|
||||
list[str]: 解析出的记忆标题列表
|
||||
"""
|
||||
try:
|
||||
# 清理JSON文本,移除可能的额外内容
|
||||
repaired_content = repair_json(json_text)
|
||||
|
||||
# 尝试直接解析JSON
|
||||
try:
|
||||
parsed_data = json.loads(repaired_content)
|
||||
if isinstance(parsed_data, list):
|
||||
# 如果是列表,提取selected_title字段
|
||||
titles = []
|
||||
for item in parsed_data:
|
||||
if isinstance(item, dict) and "selected_title" in item:
|
||||
value = item.get("selected_title", "")
|
||||
if isinstance(value, str) and value.strip():
|
||||
titles.append(value)
|
||||
return titles
|
||||
elif isinstance(parsed_data, dict) and "selected_title" in parsed_data:
|
||||
# 如果是单个对象
|
||||
value = parsed_data.get("selected_title", "")
|
||||
if isinstance(value, str) and value.strip():
|
||||
return [value]
|
||||
else:
|
||||
# 空字符串表示没有相关记忆
|
||||
return []
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# 如果直接解析失败,尝试提取JSON对象
|
||||
# 查找所有包含selected_title的JSON对象
|
||||
pattern = r'\{[^}]*"selected_title"[^}]*\}'
|
||||
matches = re.findall(pattern, repaired_content)
|
||||
|
||||
titles = []
|
||||
for match in matches:
|
||||
try:
|
||||
obj = json.loads(match)
|
||||
if "selected_title" in obj:
|
||||
value = obj.get("selected_title", "")
|
||||
if isinstance(value, str) and value.strip():
|
||||
titles.append(value)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
if titles:
|
||||
return titles
|
||||
|
||||
logger.warning(f"无法解析JSON响应: {json_text[:200]}...")
|
||||
return []
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"解析合并目标JSON时出错: {e}")
|
||||
return []
|
||||
|
||||
async def merge_memory(self,memory_list: list[str], chat_id: str = None) -> tuple[str, str]:
|
||||
"""
|
||||
合并记忆
|
||||
"""
|
||||
try:
|
||||
# 在记忆整合前先清理空chat_id的记忆
|
||||
cleaned_count = self.cleanup_empty_chat_id_memories()
|
||||
if cleaned_count > 0:
|
||||
logger.info(f"记忆整合前清理了 {cleaned_count} 条空chat_id记忆")
|
||||
|
||||
content = ""
|
||||
for memory in memory_list:
|
||||
content += f"{memory}\n"
|
||||
|
||||
prompt = f"""
|
||||
以下是多段记忆内容,请将它们进行整合和修改:
|
||||
{content}
|
||||
--------------------------------
|
||||
请将上面的多段记忆内容,合并成两部分内容,第一部分是可以整合,不冲突的概念和知识,第二部分是相互有冲突的概念和知识
|
||||
请主要关注概念和知识,而不是聊天的琐事
|
||||
重要!!你要关注的概念和知识必须是较为不常见的信息,或者时效性较强的信息!!
|
||||
不要!!关注常见的只是,或者已经过时的信息!!
|
||||
1.不要关注诸如某个用户做了什么,说了什么,不要关注某个用户的行为,而是关注其中的概念性信息
|
||||
2.概念要求精确,不啰嗦,像科普读物或教育课本那样
|
||||
3.如果有图片,请只关注图片和文本结合的知识和概念性内容
|
||||
4.记忆为一段纯文本,逻辑清晰,指出概念的含义,并说明关系
|
||||
**第一部分**
|
||||
1.如果两个概念在描述同一件事情,且相互之间逻辑不冲突(请你严格判断),且相互之间没有矛盾,请将它们整合成一个概念,并输出到第一部分
|
||||
2.如果某个概念在时间上更新了另一个概念,请用新概念更新就概念来整合,并输出到第一部分
|
||||
3.如果没有可整合的概念,请你输出none
|
||||
**第二部分**
|
||||
1.如果记忆中有无法整合的地方,例如概念不一致,有逻辑上的冲突,请你输出到第二部分
|
||||
2.如果两个概念在描述同一件事情,但相互之间逻辑冲突,请将它们输出到第二部分
|
||||
3.如果没有无法整合的概念,请你输出none
|
||||
|
||||
**输出格式要求**
|
||||
请你按以下格式输出:
|
||||
<part1>
|
||||
第一部分内容,整合后的概念,如果第一部分为none,请输出none
|
||||
</part1>
|
||||
<part2>
|
||||
第二部分内容,无法整合,冲突的概念,如果第二部分为none,请输出none
|
||||
</part2>
|
||||
不要输出其他内容,现在请你输出,不要输出其他内容,注意一定要直白,白话,口语化不要浮夸,修辞。:
|
||||
"""
|
||||
|
||||
if global_config.debug.show_prompt:
|
||||
logger.info(f"合并记忆 prompt: {prompt}")
|
||||
else:
|
||||
logger.debug(f"合并记忆 prompt: {prompt}")
|
||||
|
||||
merged_memory, (reasoning_content, model_name, tool_calls) = await self.LLMRequest_build.generate_response_async(prompt)
|
||||
|
||||
# 解析part1和part2
|
||||
part1_content, part2_content = self._parse_merged_parts(merged_memory)
|
||||
|
||||
# 处理part2:独立记录冲突内容(无论part1是否为空)
|
||||
if part2_content and part2_content.strip() != "none":
|
||||
logger.info(f"合并记忆part2记录冲突内容: {len(part2_content)} 字符")
|
||||
# 记录冲突到数据库
|
||||
await global_conflict_tracker.record_memory_merge_conflict(part2_content,chat_id)
|
||||
|
||||
# 处理part1:生成标题并保存
|
||||
if part1_content and part1_content.strip() != "none":
|
||||
merged_title = await self._generate_title_for_merged_memory(part1_content)
|
||||
|
||||
# 保存part1到数据库
|
||||
MemoryChestModel.create(
|
||||
title=merged_title,
|
||||
content=part1_content,
|
||||
chat_id=chat_id
|
||||
)
|
||||
|
||||
logger.info(f"合并记忆part1已保存: {merged_title}")
|
||||
|
||||
return merged_title, part1_content
|
||||
else:
|
||||
logger.warning("合并记忆part1为空,跳过保存")
|
||||
return "", ""
|
||||
except Exception as e:
|
||||
logger.error(f"合并记忆时出错: {e}")
|
||||
return "", ""
|
||||
|
||||
async def _generate_title_for_merged_memory(self, merged_content: str) -> str:
|
||||
"""
|
||||
为合并后的记忆生成标题
|
||||
|
||||
Args:
|
||||
merged_content: 合并后的记忆内容
|
||||
|
||||
Returns:
|
||||
str: 生成的标题
|
||||
"""
|
||||
try:
|
||||
prompt = f"""
|
||||
请为以下内容生成一个描述全面的标题,要求描述内容的主要概念和事件:
|
||||
例如:
|
||||
<example>
|
||||
标题:达尔文的自然选择理论
|
||||
内容:达尔文的自然选择是生物进化理论的重要组成部分,它解释了生物进化过程中的自然选择机制。
|
||||
</example>
|
||||
<example>
|
||||
标题:麦麦的禁言插件和支持版本
|
||||
内容:
|
||||
麦麦的禁言插件是一款能够实现禁言的插件
|
||||
麦麦的禁言插件可能不支持0.10.2
|
||||
MutePlugin 是禁言插件的名称
|
||||
</example>
|
||||
|
||||
|
||||
需要对以下内容生成标题:
|
||||
{merged_content}
|
||||
|
||||
|
||||
标题不要分点,不要换行,不要输出其他内容,不要浮夸,以白话简洁的风格输出标题
|
||||
请只输出标题,不要输出其他内容:
|
||||
"""
|
||||
|
||||
if global_config.debug.show_prompt:
|
||||
logger.info(f"生成合并记忆标题 prompt: {prompt}")
|
||||
else:
|
||||
logger.debug(f"生成合并记忆标题 prompt: {prompt}")
|
||||
|
||||
title_response, (reasoning_content, model_name, tool_calls) = await self.LLMRequest.generate_response_async(prompt)
|
||||
|
||||
# 清理标题,移除可能的引号或多余字符
|
||||
title = title_response.strip().strip('"').strip("'").strip()
|
||||
|
||||
if title:
|
||||
# 检查是否存在相似标题
|
||||
if check_title_exists_fuzzy(title, similarity_threshold=0.9):
|
||||
logger.warning(f"生成的标题 '{title}' 与现有标题相似,使用时间戳后缀")
|
||||
title = f"{title}_{int(time.time())}"
|
||||
|
||||
logger.info(f"生成合并记忆标题: {title}")
|
||||
return title
|
||||
else:
|
||||
logger.warning("生成合并记忆标题失败,使用默认标题")
|
||||
return f"合并记忆_{int(time.time())}"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"生成合并记忆标题时出错: {e}")
|
||||
return f"合并记忆_{int(time.time())}"
|
||||
|
||||
def cleanup_empty_chat_id_memories(self) -> int:
|
||||
"""
|
||||
清理chat_id为空的记忆记录
|
||||
|
||||
Returns:
|
||||
int: 被清理的记忆数量
|
||||
"""
|
||||
try:
|
||||
# 查找所有chat_id为空的记忆
|
||||
empty_chat_id_memories = MemoryChestModel.select().where(
|
||||
(MemoryChestModel.chat_id.is_null()) |
|
||||
(MemoryChestModel.chat_id == "") |
|
||||
(MemoryChestModel.chat_id == "None")
|
||||
)
|
||||
|
||||
count = 0
|
||||
for memory in empty_chat_id_memories:
|
||||
logger.info(f"清理空chat_id记忆: 标题='{memory.title}', ID={memory.id}")
|
||||
memory.delete_instance()
|
||||
count += 1
|
||||
|
||||
if count > 0:
|
||||
logger.info(f"已清理 {count} 条chat_id为空的记忆记录")
|
||||
else:
|
||||
logger.debug("未发现需要清理的空chat_id记忆记录")
|
||||
|
||||
return count
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"清理空chat_id记忆时出错: {e}")
|
||||
return 0
|
||||
|
||||
|
||||
global_memory_chest = MemoryChest()
|
||||
@@ -1,185 +0,0 @@
|
||||
import time
|
||||
import asyncio
|
||||
from typing import List, Optional, Tuple
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.utils.chat_message_builder import (
|
||||
get_raw_msg_before_timestamp_with_chat,
|
||||
build_readable_messages_with_id,
|
||||
)
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config, global_config
|
||||
from src.memory_system.questions import global_conflict_tracker
|
||||
from src.memory_system.memory_utils import parse_md_json
|
||||
|
||||
logger = get_logger("curious")
|
||||
|
||||
|
||||
class CuriousDetector:
|
||||
"""
|
||||
好奇心检测器 - 检测聊天记录中的矛盾、冲突或需要提问的内容
|
||||
"""
|
||||
|
||||
def __init__(self, chat_id: str):
|
||||
self.chat_id = chat_id
|
||||
self.llm_request = LLMRequest(
|
||||
model_set=model_config.model_task_config.utils,
|
||||
request_type="curious_detector",
|
||||
)
|
||||
|
||||
async def detect_questions(self, recent_messages: List) -> Optional[str]:
|
||||
"""
|
||||
检测最近消息中是否有需要提问的内容
|
||||
|
||||
Args:
|
||||
recent_messages: 最近的消息列表
|
||||
|
||||
Returns:
|
||||
Optional[str]: 如果检测到需要提问的内容,返回问题文本;否则返回None
|
||||
"""
|
||||
try:
|
||||
if not recent_messages or len(recent_messages) < 2:
|
||||
return None
|
||||
|
||||
# 构建聊天内容
|
||||
chat_content_block, _ = build_readable_messages_with_id(
|
||||
messages=recent_messages,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
read_mark=0.0,
|
||||
truncate=True,
|
||||
show_actions=True,
|
||||
)
|
||||
|
||||
# 检查是否已经有问题在跟踪中
|
||||
existing_questions = global_conflict_tracker.get_questions_by_chat_id(self.chat_id)
|
||||
if len(existing_questions) > 0:
|
||||
logger.debug(f"当前已有{len(existing_questions)}个问题在跟踪中,跳过检测")
|
||||
return None
|
||||
|
||||
# 构建检测提示词
|
||||
prompt = f"""你是一个严谨的聊天内容分析器。请分析以下聊天记录,检测是否存在需要提问的内容。
|
||||
|
||||
检测条件:
|
||||
1. 聊天中存在逻辑矛盾或冲突的信息
|
||||
2. 有人反对或否定之前提出的信息
|
||||
3. 存在观点不一致的情况
|
||||
4. 有模糊不清或需要澄清的概念
|
||||
5. 有人提出了质疑或反驳
|
||||
|
||||
**重要限制:**
|
||||
- 忽略涉及违法、暴力、色情、政治等敏感话题的内容
|
||||
- 不要对敏感话题提问
|
||||
- 只有在确实存在矛盾或冲突时才提问
|
||||
- 如果聊天内容正常,没有矛盾,请输出:NO
|
||||
|
||||
**聊天记录**
|
||||
{chat_content_block}
|
||||
|
||||
请分析上述聊天记录,如果发现需要提问的内容,请用JSON格式输出:
|
||||
```json
|
||||
{{
|
||||
"question": "具体的问题描述,要完整描述涉及的概念和问题",
|
||||
"reason": "为什么需要提问这个问题的理由"
|
||||
}}
|
||||
```
|
||||
|
||||
如果没有需要提问的内容,请只输出:NO"""
|
||||
|
||||
if global_config.debug.show_prompt:
|
||||
logger.info(f"好奇心检测提示词: {prompt}")
|
||||
else:
|
||||
logger.debug("已发送好奇心检测提示词")
|
||||
|
||||
result_text, _ = await self.llm_request.generate_response_async(prompt, temperature=0.3)
|
||||
|
||||
if not result_text:
|
||||
return None
|
||||
|
||||
result_text = result_text.strip()
|
||||
|
||||
# 检查是否输出NO
|
||||
if result_text.upper() == "NO":
|
||||
logger.debug("未检测到需要提问的内容")
|
||||
return None
|
||||
|
||||
# 尝试解析JSON
|
||||
try:
|
||||
questions, reasoning = parse_md_json(result_text)
|
||||
if questions and len(questions) > 0:
|
||||
question_data = questions[0]
|
||||
question = question_data.get("question", "")
|
||||
reason = question_data.get("reason", "")
|
||||
|
||||
if question and question.strip():
|
||||
logger.info(f"检测到需要提问的内容: {question}")
|
||||
logger.info(f"提问理由: {reason}")
|
||||
return question
|
||||
except Exception as e:
|
||||
logger.warning(f"解析问题JSON失败: {e}")
|
||||
logger.debug(f"原始响应: {result_text}")
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"好奇心检测失败: {e}")
|
||||
return None
|
||||
|
||||
async def make_question_from_detection(self, question: str, context: str = "") -> bool:
|
||||
"""
|
||||
将检测到的问题记录到冲突追踪器中
|
||||
|
||||
Args:
|
||||
question: 检测到的问题
|
||||
context: 问题上下文
|
||||
|
||||
Returns:
|
||||
bool: 是否成功记录
|
||||
"""
|
||||
try:
|
||||
if not question or not question.strip():
|
||||
return False
|
||||
|
||||
# 记录问题到冲突追踪器,并开始跟踪
|
||||
await global_conflict_tracker.track_conflict(
|
||||
question=question.strip(),
|
||||
context=context,
|
||||
start_following=False,
|
||||
chat_id=self.chat_id
|
||||
)
|
||||
|
||||
logger.info(f"已记录问题到冲突追踪器: {question}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"记录问题失败: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def check_and_make_question(chat_id: str, recent_messages: List) -> bool:
|
||||
"""
|
||||
检查聊天记录并生成问题(如果检测到需要提问的内容)
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
recent_messages: 最近的消息列表
|
||||
|
||||
Returns:
|
||||
bool: 是否检测到并记录了问题
|
||||
"""
|
||||
try:
|
||||
detector = CuriousDetector(chat_id)
|
||||
|
||||
# 检测是否需要提问
|
||||
question = await detector.detect_questions(recent_messages)
|
||||
|
||||
if question:
|
||||
# 记录问题
|
||||
success = await detector.make_question_from_detection(question)
|
||||
if success:
|
||||
logger.info(f"成功检测并记录问题: {question}")
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"检查并生成问题失败: {e}")
|
||||
return False
|
||||
@@ -1,182 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import asyncio
|
||||
import random
|
||||
from typing import List
|
||||
|
||||
from src.manager.async_task_manager import AsyncTask
|
||||
from src.memory_system.Memory_chest import global_memory_chest
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import MemoryChest as MemoryChestModel
|
||||
from src.config.config import global_config
|
||||
|
||||
logger = get_logger("memory")
|
||||
|
||||
|
||||
class MemoryManagementTask(AsyncTask):
|
||||
"""记忆管理定时任务
|
||||
|
||||
根据Memory_chest中的记忆数量与MAX_MEMORY_NUMBER的比例来决定执行频率:
|
||||
- 小于50%:每600秒执行一次
|
||||
- 大于等于50%:每300秒执行一次
|
||||
|
||||
每次执行时随机选择一个title,执行choose_merge_target和merge_memory,
|
||||
然后删除原始记忆
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
task_name="Memory Management Task",
|
||||
wait_before_start=10, # 启动后等待10秒再开始
|
||||
run_interval=300 # 默认300秒间隔,会根据记忆数量动态调整
|
||||
)
|
||||
self.max_memory_number = global_config.memory.max_memory_number
|
||||
|
||||
async def start_task(self, abort_flag: asyncio.Event):
|
||||
"""重写start_task方法,支持动态调整执行间隔"""
|
||||
if self.wait_before_start > 0:
|
||||
# 等待指定时间后开始任务
|
||||
await asyncio.sleep(self.wait_before_start)
|
||||
|
||||
while not abort_flag.is_set():
|
||||
await self.run()
|
||||
|
||||
# 动态调整执行间隔
|
||||
current_interval = self._calculate_interval()
|
||||
logger.info(f"[记忆管理] 下次执行间隔: {current_interval}秒")
|
||||
|
||||
if current_interval > 0:
|
||||
await asyncio.sleep(current_interval)
|
||||
else:
|
||||
break
|
||||
|
||||
def _calculate_interval(self) -> int:
|
||||
"""根据当前记忆数量计算执行间隔"""
|
||||
try:
|
||||
current_count = self._get_memory_count()
|
||||
percentage = current_count / self.max_memory_number
|
||||
|
||||
if percentage < 0.6:
|
||||
# 小于50%,每600秒执行一次
|
||||
return 3600
|
||||
elif percentage < 1:
|
||||
# 大于等于50%,每300秒执行一次
|
||||
return 1800
|
||||
elif percentage < 1.5:
|
||||
# 大于等于100%,每120秒执行一次
|
||||
return 600
|
||||
elif percentage < 1.8:
|
||||
return 120
|
||||
else:
|
||||
return 30
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[记忆管理] 计算执行间隔时出错: {e}")
|
||||
return 300 # 默认300秒
|
||||
|
||||
def _get_memory_count(self) -> int:
|
||||
"""获取当前记忆数量"""
|
||||
try:
|
||||
count = MemoryChestModel.select().count()
|
||||
logger.debug(f"[记忆管理] 当前记忆数量: {count}")
|
||||
return count
|
||||
except Exception as e:
|
||||
logger.error(f"[记忆管理] 获取记忆数量时出错: {e}")
|
||||
return 0
|
||||
|
||||
async def run(self):
|
||||
"""执行记忆管理任务"""
|
||||
try:
|
||||
|
||||
# 获取当前记忆数量
|
||||
current_count = self._get_memory_count()
|
||||
percentage = current_count / self.max_memory_number
|
||||
logger.info(f"当前记忆数量: {current_count}/{self.max_memory_number} ({percentage:.1%})")
|
||||
|
||||
# 当占比 > 1.6 时,持续删除直到占比 <= 1.6(越老/越新更易被删)
|
||||
if percentage > 2:
|
||||
logger.info("记忆过多,开始遗忘记忆")
|
||||
while True:
|
||||
if percentage <= 1.8:
|
||||
break
|
||||
removed = global_memory_chest.remove_one_memory_by_age_weight()
|
||||
if not removed:
|
||||
logger.warning("没有可删除的记忆,停止连续删除")
|
||||
break
|
||||
# 重新计算占比
|
||||
current_count = self._get_memory_count()
|
||||
percentage = current_count / self.max_memory_number
|
||||
logger.info(f"遗忘进度: 当前 {current_count}/{self.max_memory_number} ({percentage:.1%})")
|
||||
logger.info("遗忘记忆结束")
|
||||
|
||||
# 如果记忆数量为0,跳过执行
|
||||
if current_count < 10:
|
||||
return
|
||||
|
||||
# 随机选择一个记忆标题和chat_id
|
||||
selected_title, selected_chat_id = self._get_random_memory_title()
|
||||
if not selected_title:
|
||||
logger.warning("无法获取随机记忆标题,跳过执行")
|
||||
return
|
||||
|
||||
# 执行choose_merge_target获取相关记忆(标题与内容)
|
||||
related_titles, related_contents = await global_memory_chest.choose_merge_target(selected_title, selected_chat_id)
|
||||
if not related_titles or not related_contents:
|
||||
logger.info("无合适合并内容,跳过本次合并")
|
||||
return
|
||||
|
||||
logger.info(f"{selected_chat_id} 为 [{selected_title}] 找到 {len(related_contents)} 条相关记忆:{related_titles}")
|
||||
|
||||
# 执行merge_memory合并记忆
|
||||
merged_title, merged_content = await global_memory_chest.merge_memory(related_contents,selected_chat_id)
|
||||
if not merged_title or not merged_content:
|
||||
logger.warning("[记忆管理] 记忆合并失败,跳过删除")
|
||||
return
|
||||
|
||||
logger.info(f"记忆合并成功,新标题: {merged_title}")
|
||||
|
||||
# 删除原始记忆(包括选中的标题和相关的记忆标题)
|
||||
titles_to_delete = [selected_title] + related_titles
|
||||
deleted_count = self._delete_original_memories(titles_to_delete)
|
||||
logger.info(f"已删除 {deleted_count} 条原始记忆")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[记忆管理] 执行记忆管理任务时发生错误: {e}", exc_info=True)
|
||||
|
||||
def _get_random_memory_title(self) -> tuple[str, str]:
|
||||
"""随机获取一个记忆标题和对应的chat_id"""
|
||||
try:
|
||||
# 获取所有记忆记录
|
||||
all_memories = MemoryChestModel.select()
|
||||
if not all_memories:
|
||||
return "", ""
|
||||
|
||||
# 随机选择一个记忆
|
||||
selected_memory = random.choice(list(all_memories))
|
||||
return selected_memory.title, selected_memory.chat_id or ""
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[记忆管理] 获取随机记忆标题时发生错误: {e}")
|
||||
return "", ""
|
||||
|
||||
def _delete_original_memories(self, related_titles: List[str]) -> int:
|
||||
"""按标题删除原始记忆"""
|
||||
try:
|
||||
deleted_count = 0
|
||||
# 删除相关记忆(通过标题匹配)
|
||||
for title in related_titles:
|
||||
try:
|
||||
# 通过标题查找并删除对应的记忆
|
||||
memories_to_delete = MemoryChestModel.select().where(MemoryChestModel.title == title)
|
||||
for memory in memories_to_delete:
|
||||
MemoryChestModel.delete().where(MemoryChestModel.id == memory.id).execute()
|
||||
deleted_count += 1
|
||||
logger.debug(f"[记忆管理] 删除相关记忆: {memory.title}")
|
||||
except Exception as e:
|
||||
logger.error(f"[记忆管理] 删除相关记忆时出错: {e}")
|
||||
continue
|
||||
|
||||
return deleted_count
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[记忆管理] 删除原始记忆时发生错误: {e}")
|
||||
return 0
|
||||
1278
src/memory_system/memory_retrieval.py
Normal file
1278
src/memory_system/memory_retrieval.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -5,40 +5,15 @@
|
||||
"""
|
||||
import json
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Tuple
|
||||
from difflib import SequenceMatcher
|
||||
from typing import List, Tuple, Optional
|
||||
|
||||
from src.common.database.database_model import MemoryChest as MemoryChestModel
|
||||
from src.common.logger import get_logger
|
||||
from json_repair import repair_json
|
||||
|
||||
|
||||
logger = get_logger("memory_utils")
|
||||
|
||||
def get_all_titles(exclude_locked: bool = False) -> list[str]:
|
||||
"""
|
||||
获取记忆仓库中的所有标题
|
||||
|
||||
Args:
|
||||
exclude_locked: 是否排除锁定的记忆,默认为 False
|
||||
|
||||
Returns:
|
||||
list: 包含所有标题的列表
|
||||
"""
|
||||
try:
|
||||
# 查询所有记忆记录的标题
|
||||
titles = []
|
||||
for memory in MemoryChestModel.select():
|
||||
if memory.title:
|
||||
# 如果 exclude_locked 为 True 且记忆已锁定,则跳过
|
||||
if exclude_locked and memory.locked:
|
||||
continue
|
||||
titles.append(memory.title)
|
||||
return titles
|
||||
except Exception as e:
|
||||
print(f"获取记忆标题时出错: {e}")
|
||||
return []
|
||||
|
||||
def parse_md_json(json_text: str) -> list[str]:
|
||||
"""从Markdown格式的内容中提取JSON对象和推理内容"""
|
||||
json_objects = []
|
||||
@@ -134,224 +109,59 @@ def preprocess_text(text: str) -> str:
|
||||
return text
|
||||
|
||||
|
||||
def fuzzy_find_memory_by_title(target_title: str, similarity_threshold: float = 0.9) -> List[Tuple[str, str, float]]:
|
||||
|
||||
def parse_datetime_to_timestamp(value: str) -> float:
|
||||
"""
|
||||
根据标题模糊查找记忆
|
||||
接受多种常见格式并转换为时间戳(秒)
|
||||
支持示例:
|
||||
- 2025-09-29
|
||||
- 2025-09-29 00:00:00
|
||||
- 2025/09/29 00:00
|
||||
- 2025-09-29T00:00:00
|
||||
"""
|
||||
value = value.strip()
|
||||
fmts = [
|
||||
"%Y-%m-%d %H:%M:%S",
|
||||
"%Y-%m-%d %H:%M",
|
||||
"%Y/%m/%d %H:%M:%S",
|
||||
"%Y/%m/%d %H:%M",
|
||||
"%Y-%m-%d",
|
||||
"%Y/%m/%d",
|
||||
"%Y-%m-%dT%H:%M:%S",
|
||||
"%Y-%m-%dT%H:%M",
|
||||
]
|
||||
last_err = None
|
||||
for fmt in fmts:
|
||||
try:
|
||||
dt = datetime.strptime(value, fmt)
|
||||
return dt.timestamp()
|
||||
except Exception as e:
|
||||
last_err = e
|
||||
raise ValueError(f"无法解析时间: {value} ({last_err})")
|
||||
|
||||
|
||||
def parse_time_range(time_range: str) -> Tuple[float, float]:
|
||||
"""
|
||||
解析时间范围字符串,返回开始和结束时间戳
|
||||
|
||||
Args:
|
||||
target_title: 目标标题
|
||||
similarity_threshold: 相似度阈值,默认0.9
|
||||
time_range: 时间范围字符串,格式:"YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS"
|
||||
|
||||
Returns:
|
||||
List[Tuple[str, str, float]]: 匹配的记忆列表,每个元素为(title, content, similarity_score)
|
||||
Tuple[float, float]: (开始时间戳, 结束时间戳)
|
||||
"""
|
||||
try:
|
||||
# 获取所有记忆
|
||||
all_memories = MemoryChestModel.select()
|
||||
|
||||
matches = []
|
||||
for memory in all_memories:
|
||||
similarity = calculate_similarity(target_title, memory.title)
|
||||
if similarity >= similarity_threshold:
|
||||
matches.append((memory.title, memory.content, similarity))
|
||||
|
||||
# 按相似度降序排序
|
||||
matches.sort(key=lambda x: x[2], reverse=True)
|
||||
|
||||
# logger.info(f"模糊查找标题 '{target_title}' 找到 {len(matches)} 个匹配项")
|
||||
return matches
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"模糊查找记忆时出错: {e}")
|
||||
return []
|
||||
|
||||
|
||||
def find_best_matching_memory(target_title: str, similarity_threshold: float = 0.9) -> Optional[Tuple[str, str, float]]:
|
||||
"""
|
||||
查找最佳匹配的记忆
|
||||
if " - " not in time_range:
|
||||
raise ValueError(f"时间范围格式错误,应为 '开始时间 - 结束时间': {time_range}")
|
||||
|
||||
Args:
|
||||
target_title: 目标标题
|
||||
similarity_threshold: 相似度阈值
|
||||
|
||||
Returns:
|
||||
Optional[Tuple[str, str, float]]: 最佳匹配的记忆(title, content, similarity)或None
|
||||
"""
|
||||
try:
|
||||
matches = fuzzy_find_memory_by_title(target_title, similarity_threshold)
|
||||
|
||||
if matches:
|
||||
best_match = matches[0] # 已经按相似度排序,第一个是最佳匹配
|
||||
# logger.info(f"找到最佳匹配: '{best_match[0]}' (相似度: {best_match[2]:.3f})")
|
||||
return best_match
|
||||
else:
|
||||
logger.info(f"未找到相似度 >= {similarity_threshold} 的记忆")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"查找最佳匹配记忆时出错: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def check_title_exists_fuzzy(target_title: str, similarity_threshold: float = 0.9) -> bool:
|
||||
"""
|
||||
检查标题是否已存在(模糊匹配)
|
||||
parts = time_range.split(" - ", 1)
|
||||
if len(parts) != 2:
|
||||
raise ValueError(f"时间范围格式错误: {time_range}")
|
||||
|
||||
Args:
|
||||
target_title: 目标标题
|
||||
similarity_threshold: 相似度阈值,默认0.9(较高阈值避免误判)
|
||||
|
||||
Returns:
|
||||
bool: 是否存在相似标题
|
||||
"""
|
||||
try:
|
||||
matches = fuzzy_find_memory_by_title(target_title, similarity_threshold)
|
||||
exists = len(matches) > 0
|
||||
|
||||
if exists:
|
||||
logger.info(f"发现相似标题: '{matches[0][0]}' (相似度: {matches[0][2]:.3f})")
|
||||
else:
|
||||
logger.debug("未发现相似标题")
|
||||
|
||||
return exists
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"检查标题是否存在时出错: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def get_memories_by_chat_id_weighted(target_chat_id: str, same_chat_weight: float = 0.95, other_chat_weight: float = 0.05) -> List[Tuple[str, str, str]]:
|
||||
"""
|
||||
根据chat_id进行加权抽样获取记忆列表
|
||||
start_str = parts[0].strip()
|
||||
end_str = parts[1].strip()
|
||||
|
||||
Args:
|
||||
target_chat_id: 目标聊天ID
|
||||
same_chat_weight: 同chat_id记忆的权重,默认0.95(95%概率)
|
||||
other_chat_weight: 其他chat_id记忆的权重,默认0.05(5%概率)
|
||||
|
||||
Returns:
|
||||
List[Tuple[str, str, str]]: 选中的记忆列表,每个元素为(title, content, chat_id)
|
||||
"""
|
||||
try:
|
||||
# 获取所有记忆
|
||||
all_memories = MemoryChestModel.select()
|
||||
|
||||
# 按chat_id分组
|
||||
same_chat_memories = []
|
||||
other_chat_memories = []
|
||||
|
||||
for memory in all_memories:
|
||||
if memory.title and not memory.locked: # 排除锁定的记忆
|
||||
if memory.chat_id == target_chat_id:
|
||||
same_chat_memories.append((memory.title, memory.content, memory.chat_id))
|
||||
else:
|
||||
other_chat_memories.append((memory.title, memory.content, memory.chat_id))
|
||||
|
||||
# 如果没有同chat_id的记忆,返回空列表
|
||||
if not same_chat_memories:
|
||||
logger.warning(f"未找到chat_id为 '{target_chat_id}' 的记忆")
|
||||
return []
|
||||
|
||||
# 计算抽样数量
|
||||
total_same = len(same_chat_memories)
|
||||
total_other = len(other_chat_memories)
|
||||
|
||||
# 根据权重计算抽样数量
|
||||
if total_other > 0:
|
||||
# 计算其他chat_id记忆的抽样数量(至少1个,最多不超过总数的10%)
|
||||
other_sample_count = max(1, min(total_other, int(total_same * other_chat_weight / same_chat_weight)))
|
||||
else:
|
||||
other_sample_count = 0
|
||||
|
||||
# 随机抽样
|
||||
selected_memories = []
|
||||
|
||||
# 选择同chat_id的记忆(全部选择,因为权重很高)
|
||||
selected_memories.extend(same_chat_memories)
|
||||
|
||||
# 随机选择其他chat_id的记忆
|
||||
if other_sample_count > 0 and total_other > 0:
|
||||
import random
|
||||
other_selected = random.sample(other_chat_memories, min(other_sample_count, total_other))
|
||||
selected_memories.extend(other_selected)
|
||||
|
||||
logger.info(f"加权抽样结果: 同chat_id记忆 {len(same_chat_memories)} 条,其他chat_id记忆 {min(other_sample_count, total_other)} 条")
|
||||
|
||||
return selected_memories
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"按chat_id加权抽样记忆时出错: {e}")
|
||||
return []
|
||||
|
||||
|
||||
def get_memory_titles_by_chat_id_weighted(target_chat_id: str, same_chat_weight: float = 0.95, other_chat_weight: float = 0.05) -> List[str]:
|
||||
"""
|
||||
根据chat_id进行加权抽样获取记忆标题列表(用于合并选择)
|
||||
start_timestamp = parse_datetime_to_timestamp(start_str)
|
||||
end_timestamp = parse_datetime_to_timestamp(end_str)
|
||||
|
||||
Args:
|
||||
target_chat_id: 目标聊天ID
|
||||
same_chat_weight: 同chat_id记忆的权重,默认0.95(95%概率)
|
||||
other_chat_weight: 其他chat_id记忆的权重,默认0.05(5%概率)
|
||||
|
||||
Returns:
|
||||
List[str]: 选中的记忆标题列表
|
||||
"""
|
||||
try:
|
||||
memories = get_memories_by_chat_id_weighted(target_chat_id, same_chat_weight, other_chat_weight)
|
||||
titles = [memory[0] for memory in memories] # 提取标题
|
||||
return titles
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"按chat_id加权抽样记忆标题时出错: {e}")
|
||||
return []
|
||||
return start_timestamp, end_timestamp
|
||||
|
||||
|
||||
def find_most_similar_memory_by_chat_id(target_title: str, target_chat_id: str, similarity_threshold: float = 0.5) -> Optional[Tuple[str, str, float]]:
|
||||
"""
|
||||
在指定chat_id的记忆中查找最相似的记忆
|
||||
|
||||
Args:
|
||||
target_title: 目标标题
|
||||
target_chat_id: 目标聊天ID
|
||||
similarity_threshold: 相似度阈值,默认0.7
|
||||
|
||||
Returns:
|
||||
Optional[Tuple[str, str, float]]: 最相似的记忆(title, content, similarity)或None
|
||||
"""
|
||||
try:
|
||||
# 获取指定chat_id的所有记忆
|
||||
same_chat_memories = []
|
||||
for memory in MemoryChestModel.select():
|
||||
if memory.title and not memory.locked and memory.chat_id == target_chat_id:
|
||||
same_chat_memories.append((memory.title, memory.content))
|
||||
|
||||
if not same_chat_memories:
|
||||
logger.warning(f"未找到chat_id为 '{target_chat_id}' 的记忆")
|
||||
return None
|
||||
|
||||
# 计算相似度并找到最佳匹配
|
||||
best_match = None
|
||||
best_similarity = 0.0
|
||||
|
||||
for title, content in same_chat_memories:
|
||||
# 跳过目标标题本身
|
||||
if title.strip() == target_title.strip():
|
||||
continue
|
||||
|
||||
similarity = calculate_similarity(target_title, title)
|
||||
|
||||
if similarity > best_similarity:
|
||||
best_similarity = similarity
|
||||
best_match = (title, content, similarity)
|
||||
|
||||
# 检查是否超过阈值
|
||||
if best_match and best_similarity >= similarity_threshold:
|
||||
logger.info(f"找到最相似记忆: '{best_match[0]}' (相似度: {best_similarity:.3f})")
|
||||
return best_match
|
||||
else:
|
||||
logger.info(f"未找到相似度 >= {similarity_threshold} 的记忆,最高相似度: {best_similarity:.3f}")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"查找最相似记忆时出错: {e}")
|
||||
return None
|
||||
@@ -1,98 +0,0 @@
|
||||
import time
|
||||
import random
|
||||
from typing import List, Optional, Tuple
|
||||
from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat, build_readable_messages
|
||||
from src.common.database.database_model import MemoryConflict
|
||||
from src.config.config import global_config
|
||||
|
||||
|
||||
class QuestionMaker:
|
||||
def __init__(self, chat_id: str, context: str = "") -> None:
|
||||
"""问题生成器。
|
||||
|
||||
- chat_id: 会话 ID,用于筛选该会话下的冲突记录。
|
||||
- context: 额外上下文,可用于后续扩展。
|
||||
|
||||
用法示例:
|
||||
>>> qm = QuestionMaker(chat_id="some_chat")
|
||||
>>> question, chat_ctx, conflict_ctx = await qm.make_question()
|
||||
"""
|
||||
self.chat_id = chat_id
|
||||
self.context = context
|
||||
|
||||
def get_context(self, timestamp: float = time.time()) -> str:
|
||||
"""获取指定时间点之前的对话上下文字符串。"""
|
||||
latest_30_msgs = get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=self.chat_id,
|
||||
timestamp=timestamp,
|
||||
limit=30,
|
||||
)
|
||||
|
||||
all_dialogue_prompt_str = build_readable_messages(
|
||||
latest_30_msgs,
|
||||
replace_bot_name=True,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
)
|
||||
return all_dialogue_prompt_str
|
||||
|
||||
|
||||
async def get_all_conflicts(self) -> List[MemoryConflict]:
|
||||
"""获取当前会话下的所有记忆冲突记录。"""
|
||||
conflicts: List[MemoryConflict] = list(MemoryConflict.select().where(MemoryConflict.chat_id == self.chat_id))
|
||||
return conflicts
|
||||
|
||||
async def get_un_answered_conflict(self) -> List[MemoryConflict]:
|
||||
"""获取未回答的记忆冲突记录(answer 为空)。"""
|
||||
conflicts = await self.get_all_conflicts()
|
||||
return [conflict for conflict in conflicts if not conflict.answer]
|
||||
|
||||
async def get_random_unanswered_conflict(self) -> Optional[MemoryConflict]:
|
||||
"""按权重随机选取一个未回答的冲突并自增 raise_time。
|
||||
|
||||
选择规则:
|
||||
- 若存在 `raise_time == 0` 的项:按权重抽样(0 次权重 1.0,≥1 次权重 0.01)。
|
||||
- 若不存在,返回 None。
|
||||
- 每次成功选中后,将该条目的 `raise_time` 自增 1 并保存。
|
||||
"""
|
||||
conflicts = await self.get_un_answered_conflict()
|
||||
if not conflicts:
|
||||
return None
|
||||
|
||||
conflicts_with_zero = [c for c in conflicts if (getattr(c, "raise_time", 0) or 0) == 0]
|
||||
if conflicts_with_zero:
|
||||
# 权重规则:raise_time == 0 -> 1.0;raise_time >= 1 -> 0.01
|
||||
weights = []
|
||||
for conflict in conflicts:
|
||||
current_raise_time = getattr(conflict, "raise_time", 0) or 0
|
||||
weight = 1.0 if current_raise_time == 0 else 0.01
|
||||
weights.append(weight)
|
||||
|
||||
# 按权重随机选择
|
||||
chosen_conflict = random.choices(conflicts, weights=weights, k=1)[0]
|
||||
|
||||
# 选中后,自增 raise_time 并保存
|
||||
chosen_conflict.raise_time = (getattr(chosen_conflict, "raise_time", 0) or 0) + 1
|
||||
chosen_conflict.save()
|
||||
|
||||
return chosen_conflict
|
||||
else:
|
||||
# 如果没有 raise_time == 0 的冲突,返回 None
|
||||
return None
|
||||
|
||||
async def make_question(self) -> Tuple[Optional[str], Optional[str], Optional[str]]:
|
||||
"""生成一条用于询问用户的冲突问题与上下文。
|
||||
|
||||
返回三元组 (question, chat_context, conflict_context):
|
||||
- question: 冲突文本;若本次未选中任何冲突则为 None。
|
||||
- chat_context: 该冲突创建时间点前的会话上下文字符串;若无则为 None。
|
||||
- conflict_context: 冲突在 DB 中存储的上下文;若无则为 None。
|
||||
"""
|
||||
conflict = await self.get_random_unanswered_conflict()
|
||||
if not conflict:
|
||||
return None, None, None
|
||||
question = conflict.conflict_content
|
||||
conflict_context = conflict.context
|
||||
create_time = conflict.create_time
|
||||
chat_context = self.get_context(create_time)
|
||||
|
||||
return question, chat_context, conflict_context
|
||||
@@ -1,479 +0,0 @@
|
||||
import time
|
||||
import asyncio
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import MemoryConflict
|
||||
from src.chat.utils.chat_message_builder import (
|
||||
get_raw_msg_by_timestamp_with_chat,
|
||||
build_readable_messages,
|
||||
)
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config, global_config
|
||||
from typing import List
|
||||
from src.memory_system.memory_utils import parse_md_json
|
||||
|
||||
logger = get_logger("conflict_tracker")
|
||||
|
||||
class QuestionTracker:
|
||||
"""
|
||||
用于跟踪一个问题在后续聊天中的解答情况
|
||||
"""
|
||||
|
||||
def __init__(self, question: str, chat_id: str, context: str = "") -> None:
|
||||
self.question = question
|
||||
self.chat_id = chat_id
|
||||
now = time.time()
|
||||
self.context = context
|
||||
self.start_time = now
|
||||
self.last_read_time = now
|
||||
self.last_judge_time = now # 上次判定的时间
|
||||
self.judge_debounce_interval = 10.0 # 判定防抖间隔:10秒
|
||||
self.consecutive_end_count = 0 # 连续END计数
|
||||
self.active = True
|
||||
# 将 LLM 实例作为类属性,使用 utils 模型
|
||||
self.llm_request = LLMRequest(model_set=model_config.model_task_config.utils, request_type="conflict.judge")
|
||||
|
||||
def stop(self) -> None:
|
||||
self.active = False
|
||||
|
||||
def should_judge_now(self) -> bool:
|
||||
"""
|
||||
检查是否应该进行判定(防抖检查)
|
||||
|
||||
Returns:
|
||||
bool: 是否可以判定
|
||||
"""
|
||||
now = time.time()
|
||||
# 检查是否已经过了10秒的防抖间隔
|
||||
return (now - self.last_judge_time) >= self.judge_debounce_interval
|
||||
|
||||
def __eq__(self, other) -> bool:
|
||||
"""比较两个追踪器是否相等(基于问题内容和聊天ID)"""
|
||||
if not isinstance(other, QuestionTracker):
|
||||
return False
|
||||
return self.question == other.question and self.chat_id == other.chat_id
|
||||
|
||||
def __hash__(self) -> int:
|
||||
"""为对象提供哈希值,支持集合操作"""
|
||||
return hash((self.question, self.chat_id))
|
||||
|
||||
async def judge_answer(self, conversation_text: str,chat_len: int) -> tuple[bool, str, str]:
|
||||
"""
|
||||
使用模型判定问题是否已得到解答。
|
||||
|
||||
Returns:
|
||||
tuple[bool, str, str]: (是否结束跟踪, 结束原因或答案, 判定类型)
|
||||
- True: 结束跟踪(已解答、话题转向等)
|
||||
- False: 继续跟踪
|
||||
判定类型: "ANSWERED", "END", "CONTINUE"
|
||||
"""
|
||||
|
||||
end_prompt = ""
|
||||
if chat_len > 20:
|
||||
end_prompt = "\n- 如果最新20条聊天记录内容与问题无关,话题已转向其他方向,请只输出:END"
|
||||
|
||||
prompt = f"""你是一个严谨的判定器。下面给出聊天记录以及一个问题。
|
||||
任务:判断在这段聊天中,该问题是否已经得到明确解答。
|
||||
**你必须严格按照聊天记录的内容,不要添加额外的信息**
|
||||
|
||||
输出规则:
|
||||
- 如果聊天记录内容的信息已解答问题,请只输出:YES: <简短答案>{end_prompt}
|
||||
- 如果问题尚未解答但聊天仍在相关话题上,请只输出:NO
|
||||
|
||||
**问题**
|
||||
{self.question}
|
||||
|
||||
|
||||
**聊天记录**
|
||||
{conversation_text}
|
||||
"""
|
||||
|
||||
if global_config.debug.show_prompt:
|
||||
logger.info(f"判定提示词: {prompt}")
|
||||
else:
|
||||
logger.debug("已发送判定提示词")
|
||||
|
||||
result_text, _ = await self.llm_request.generate_response_async(prompt, temperature=0.5)
|
||||
|
||||
logger.info(f"判定结果: {prompt}\n{result_text}")
|
||||
|
||||
# 更新上次判定时间
|
||||
self.last_judge_time = time.time()
|
||||
|
||||
if not result_text:
|
||||
return False, "", "CONTINUE"
|
||||
|
||||
text = result_text.strip()
|
||||
if text.upper().startswith("YES:"):
|
||||
answer = text[4:].strip()
|
||||
return True, answer, "ANSWERED"
|
||||
if text.upper().startswith("YES"):
|
||||
# 兼容仅输出 YES 或 YES <answer>
|
||||
answer = text[3:].strip().lstrip(":").strip()
|
||||
return True, answer, "ANSWERED"
|
||||
if text.upper().startswith("END"):
|
||||
# 聊天内容与问题无关,放弃该问题思考
|
||||
return True, "话题已转向其他方向,放弃该问题思考", "END"
|
||||
return False, "", "CONTINUE"
|
||||
|
||||
class ConflictTracker:
|
||||
"""
|
||||
记忆整合冲突追踪器
|
||||
|
||||
用于记录和存储记忆整合过程中的冲突内容
|
||||
"""
|
||||
def __init__(self):
|
||||
self.question_tracker_list:List[QuestionTracker] = []
|
||||
|
||||
self.LLMRequest_tracker = LLMRequest(
|
||||
model_set=model_config.model_task_config.utils,
|
||||
request_type="conflict_tracker",
|
||||
)
|
||||
|
||||
def get_questions_by_chat_id(self, chat_id: str) -> List[QuestionTracker]:
|
||||
return [tracker for tracker in self.question_tracker_list if tracker.chat_id == chat_id]
|
||||
|
||||
async def track_conflict(self, question: str, context: str = "",start_following: bool = False,chat_id: str = "") -> bool:
|
||||
"""
|
||||
跟踪冲突内容
|
||||
"""
|
||||
tracker = QuestionTracker(question.strip(), chat_id, context)
|
||||
self.question_tracker_list.append(tracker)
|
||||
asyncio.create_task(self._follow_and_record(tracker, question.strip()))
|
||||
return True
|
||||
|
||||
async def record_conflict(self, conflict_content: str, context: str = "",start_following: bool = False,chat_id: str = "") -> bool:
|
||||
"""
|
||||
记录冲突内容
|
||||
|
||||
Args:k
|
||||
conflict_content: 冲突内容
|
||||
|
||||
Returns:
|
||||
bool: 是否成功记录
|
||||
"""
|
||||
try:
|
||||
if not conflict_content or conflict_content.strip() == "":
|
||||
return False
|
||||
|
||||
# 若需要跟随后续消息以判断是否得到解答,则进入跟踪流程
|
||||
if start_following and chat_id:
|
||||
tracker = QuestionTracker(conflict_content.strip(), chat_id, context)
|
||||
self.question_tracker_list.append(tracker)
|
||||
# 后台启动跟踪任务,避免阻塞
|
||||
asyncio.create_task(self._follow_and_record(tracker, conflict_content.strip()))
|
||||
return True
|
||||
|
||||
# 默认:直接记录,不进行跟踪
|
||||
MemoryConflict.create(
|
||||
conflict_content=conflict_content,
|
||||
create_time=time.time(),
|
||||
update_time=time.time(),
|
||||
answer="",
|
||||
chat_id=chat_id,
|
||||
)
|
||||
|
||||
logger.info(f"记录冲突内容: {len(conflict_content)} 字符")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"记录冲突内容时出错: {e}")
|
||||
return False
|
||||
|
||||
async def _follow_and_record(self, tracker: QuestionTracker, original_question: str) -> None:
|
||||
"""
|
||||
后台任务:跟踪问题是否被解答,并写入数据库。
|
||||
"""
|
||||
try:
|
||||
max_duration = 10 * 60 # 30 分钟
|
||||
max_messages = 50 # 最多 100 条消息
|
||||
poll_interval = 2.0 # 秒
|
||||
logger.info(f"开始跟踪问题: {original_question}")
|
||||
while tracker.active:
|
||||
now_ts = time.time()
|
||||
# 终止条件:时长达到上限
|
||||
if now_ts - tracker.start_time >= max_duration:
|
||||
logger.info("问题跟踪达到10分钟上限,判定为未解答")
|
||||
break
|
||||
|
||||
# 统计最近一段是否有新消息(不过滤机器人,过滤命令)
|
||||
recent_msgs = get_raw_msg_by_timestamp_with_chat(
|
||||
chat_id=tracker.chat_id,
|
||||
timestamp_start=tracker.last_read_time,
|
||||
timestamp_end=now_ts,
|
||||
limit=30,
|
||||
limit_mode="latest",
|
||||
filter_bot=False,
|
||||
filter_command=True,
|
||||
)
|
||||
|
||||
if len(recent_msgs) > 0:
|
||||
tracker.last_read_time = now_ts
|
||||
|
||||
# 统计从开始到现在的总消息数(用于触发100条上限)
|
||||
all_msgs = get_raw_msg_by_timestamp_with_chat(
|
||||
chat_id=tracker.chat_id,
|
||||
timestamp_start=tracker.start_time,
|
||||
timestamp_end=now_ts,
|
||||
limit=0,
|
||||
limit_mode="latest",
|
||||
filter_bot=False,
|
||||
filter_command=True,
|
||||
)
|
||||
|
||||
# 检查是否应该进行判定(防抖检查)
|
||||
if not tracker.should_judge_now():
|
||||
logger.debug(f"判定防抖中,跳过本次判定: {tracker.question}")
|
||||
await asyncio.sleep(poll_interval)
|
||||
continue
|
||||
|
||||
# 构建可读聊天文本
|
||||
chat_text = build_readable_messages(
|
||||
all_msgs,
|
||||
replace_bot_name=True,
|
||||
timestamp_mode="relative",
|
||||
read_mark=0.0,
|
||||
truncate=False,
|
||||
show_actions=False,
|
||||
show_pic=False,
|
||||
remove_emoji_stickers=True,
|
||||
)
|
||||
chat_len = len(all_msgs)
|
||||
# 让小模型判断是否有答案
|
||||
answered, answer_text, judge_type = await tracker.judge_answer(chat_text,chat_len)
|
||||
|
||||
if judge_type == "ANSWERED":
|
||||
# 问题已解答,直接结束跟踪
|
||||
logger.info("问题已得到解答,结束跟踪并写入答案")
|
||||
await self.add_or_update_conflict(
|
||||
conflict_content=tracker.question,
|
||||
create_time=tracker.start_time,
|
||||
update_time=time.time(),
|
||||
answer=answer_text or "",
|
||||
chat_id=tracker.chat_id,
|
||||
)
|
||||
return
|
||||
elif judge_type == "END":
|
||||
# 话题转向,增加END计数
|
||||
tracker.consecutive_end_count += 1
|
||||
logger.info(f"话题已转向,连续END次数: {tracker.consecutive_end_count}")
|
||||
|
||||
if tracker.consecutive_end_count >= 2:
|
||||
# 连续两次END,结束跟踪
|
||||
logger.info("连续两次END,结束跟踪")
|
||||
break
|
||||
else:
|
||||
# 第一次END,重置计数器并继续跟踪
|
||||
logger.info("第一次END,继续跟踪")
|
||||
continue
|
||||
elif judge_type == "CONTINUE":
|
||||
# 继续跟踪,重置END计数器
|
||||
tracker.consecutive_end_count = 0
|
||||
continue
|
||||
|
||||
if len(all_msgs) >= max_messages:
|
||||
logger.info("问题跟踪达到100条消息上限,判定为未解答")
|
||||
logger.info(f"追踪结束:{tracker.question}")
|
||||
break
|
||||
|
||||
# 无新消息时稍作等待
|
||||
await asyncio.sleep(poll_interval)
|
||||
|
||||
# 未获取到答案,检查是否需要删除记录
|
||||
# 查找现有的冲突记录
|
||||
existing_conflict = MemoryConflict.get_or_none(
|
||||
MemoryConflict.conflict_content == original_question,
|
||||
MemoryConflict.chat_id == tracker.chat_id
|
||||
)
|
||||
|
||||
if existing_conflict:
|
||||
# 检查raise_time是否大于3且没有答案
|
||||
current_raise_time = getattr(existing_conflict, "raise_time", 0) or 0
|
||||
if current_raise_time > 0 and not existing_conflict.answer:
|
||||
# 删除该条目
|
||||
await self.delete_conflict(original_question, tracker.chat_id)
|
||||
logger.info(f"追踪结束后删除条目(raise_time={current_raise_time}且无答案): {original_question}")
|
||||
else:
|
||||
# 更新记录但不删除
|
||||
await self.add_or_update_conflict(
|
||||
conflict_content=original_question,
|
||||
create_time=existing_conflict.create_time,
|
||||
update_time=time.time(),
|
||||
answer="",
|
||||
chat_id=tracker.chat_id,
|
||||
)
|
||||
logger.info(f"记录冲突内容(未解答): {len(original_question)} 字符")
|
||||
else:
|
||||
# 如果没有现有记录,创建新记录
|
||||
await self.add_or_update_conflict(
|
||||
conflict_content=original_question,
|
||||
create_time=time.time(),
|
||||
update_time=time.time(),
|
||||
answer="",
|
||||
chat_id=tracker.chat_id,
|
||||
)
|
||||
logger.info(f"记录冲突内容(未解答): {len(original_question)} 字符")
|
||||
|
||||
logger.info(f"问题跟踪结束:{original_question}")
|
||||
except Exception as e:
|
||||
logger.error(f"后台问题跟踪任务异常: {e}")
|
||||
finally:
|
||||
# 无论任务成功还是失败,都要从追踪列表中移除
|
||||
tracker.stop()
|
||||
self.remove_tracker(tracker)
|
||||
|
||||
def remove_tracker(self, tracker: QuestionTracker) -> None:
|
||||
"""
|
||||
从追踪列表中移除指定的追踪器
|
||||
|
||||
Args:
|
||||
tracker: 要移除的追踪器对象
|
||||
"""
|
||||
try:
|
||||
if tracker in self.question_tracker_list:
|
||||
self.question_tracker_list.remove(tracker)
|
||||
logger.info(f"已从追踪列表中移除追踪器: {tracker.question}")
|
||||
else:
|
||||
logger.warning(f"尝试移除不存在的追踪器: {tracker.question}")
|
||||
except Exception as e:
|
||||
logger.error(f"移除追踪器时出错: {e}")
|
||||
|
||||
async def add_or_update_conflict(
|
||||
self,
|
||||
conflict_content: str,
|
||||
create_time: float,
|
||||
update_time: float,
|
||||
answer: str = "",
|
||||
context: str = "",
|
||||
chat_id: str = None
|
||||
) -> bool:
|
||||
"""
|
||||
根据conflict_content匹配数据库内容,如果找到相同的就更新update_time和answer,
|
||||
如果没有相同的,就新建一条保存全部内容
|
||||
"""
|
||||
try:
|
||||
# 尝试根据conflict_content查找现有记录
|
||||
existing_conflict = MemoryConflict.get_or_none(
|
||||
MemoryConflict.conflict_content == conflict_content,
|
||||
MemoryConflict.chat_id == chat_id
|
||||
)
|
||||
|
||||
if existing_conflict:
|
||||
# 如果找到相同的conflict_content,更新update_time和answer
|
||||
existing_conflict.update_time = update_time
|
||||
existing_conflict.answer = answer
|
||||
existing_conflict.save()
|
||||
return True
|
||||
else:
|
||||
# 如果没有找到相同的,创建新记录
|
||||
MemoryConflict.create(
|
||||
conflict_content=conflict_content,
|
||||
create_time=create_time,
|
||||
update_time=update_time,
|
||||
answer=answer,
|
||||
context=context,
|
||||
chat_id=chat_id,
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
# 记录错误并返回False
|
||||
logger.error(f"添加或更新冲突记录时出错: {e}")
|
||||
return False
|
||||
|
||||
async def record_memory_merge_conflict(self, part2_content: str, chat_id: str = None) -> bool:
|
||||
"""
|
||||
记录记忆整合过程中的冲突内容(part2)
|
||||
|
||||
Args:
|
||||
part2_content: 冲突内容(part2)
|
||||
|
||||
Returns:
|
||||
bool: 是否成功记录
|
||||
"""
|
||||
if not part2_content or part2_content.strip() == "":
|
||||
return False
|
||||
|
||||
prompt = f"""以下是一段有冲突的信息,请你根据这些信息总结出几个具体的提问:
|
||||
冲突信息:
|
||||
{part2_content}
|
||||
|
||||
要求:
|
||||
1.提问必须具体,明确
|
||||
2.提问最好涉及指向明确的事物,而不是代称
|
||||
3.如果缺少上下文,不要强行提问,可以忽略
|
||||
4.请忽略涉及违法,暴力,色情,政治等敏感话题的内容
|
||||
|
||||
请用json格式输出,不要输出其他内容,仅输出提问理由和具体提的提问:
|
||||
**示例**
|
||||
// 理由文本
|
||||
```json
|
||||
{{
|
||||
"question":"提问",
|
||||
}}
|
||||
```
|
||||
```json
|
||||
{{
|
||||
"question":"提问"
|
||||
}}
|
||||
```
|
||||
...提问数量在1-3个之间,不要重复,现在请输出:"""
|
||||
|
||||
question_response, (reasoning_content, model_name, tool_calls) = await self.LLMRequest_tracker.generate_response_async(prompt)
|
||||
|
||||
# 解析JSON响应
|
||||
questions, reasoning_content = parse_md_json(question_response)
|
||||
|
||||
print(prompt)
|
||||
print(question_response)
|
||||
|
||||
for question in questions:
|
||||
await self.record_conflict(
|
||||
conflict_content=question["question"],
|
||||
context=reasoning_content,
|
||||
start_following=False,
|
||||
chat_id=chat_id,
|
||||
)
|
||||
return True
|
||||
|
||||
async def get_conflict_count(self) -> int:
|
||||
"""
|
||||
获取冲突记录数量
|
||||
|
||||
Returns:
|
||||
int: 记录数量
|
||||
"""
|
||||
try:
|
||||
return MemoryConflict.select().count()
|
||||
except Exception as e:
|
||||
logger.error(f"获取冲突记录数量时出错: {e}")
|
||||
return 0
|
||||
|
||||
async def delete_conflict(self, conflict_content: str, chat_id: str) -> bool:
|
||||
"""
|
||||
删除指定的冲突记录
|
||||
|
||||
Args:
|
||||
conflict_content: 冲突内容
|
||||
chat_id: 聊天ID
|
||||
|
||||
Returns:
|
||||
bool: 是否成功删除
|
||||
"""
|
||||
try:
|
||||
conflict = MemoryConflict.get_or_none(
|
||||
MemoryConflict.conflict_content == conflict_content,
|
||||
MemoryConflict.chat_id == chat_id
|
||||
)
|
||||
|
||||
if conflict:
|
||||
conflict.delete_instance()
|
||||
logger.info(f"已删除冲突记录: {conflict_content}")
|
||||
return True
|
||||
else:
|
||||
logger.warning(f"未找到要删除的冲突记录: {conflict_content}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"删除冲突记录时出错: {e}")
|
||||
return False
|
||||
|
||||
# 全局冲突追踪器实例
|
||||
global_conflict_tracker = ConflictTracker()
|
||||
161
src/memory_system/retrieval_tools/README.md
Normal file
161
src/memory_system/retrieval_tools/README.md
Normal file
@@ -0,0 +1,161 @@
|
||||
# 记忆检索工具模块
|
||||
|
||||
这个模块提供了统一的工具注册和管理系统,用于记忆检索功能。
|
||||
|
||||
## 目录结构
|
||||
|
||||
```
|
||||
retrieval_tools/
|
||||
├── __init__.py # 模块导出
|
||||
├── tool_registry.py # 工具注册系统
|
||||
├── tool_utils.py # 工具函数库(共用函数)
|
||||
├── query_jargon.py # 查询jargon工具
|
||||
├── query_chat_history.py # 查询聊天历史工具
|
||||
├── query_lpmm_knowledge.py # 查询LPMM知识库工具
|
||||
└── README.md # 本文件
|
||||
```
|
||||
|
||||
## 模块说明
|
||||
|
||||
### `tool_registry.py`
|
||||
包含工具注册系统的核心类:
|
||||
- `MemoryRetrievalTool`: 工具基类
|
||||
- `MemoryRetrievalToolRegistry`: 工具注册器
|
||||
- `register_memory_retrieval_tool()`: 便捷注册函数
|
||||
- `get_tool_registry()`: 获取注册器实例
|
||||
|
||||
### `tool_utils.py`
|
||||
包含所有工具共用的工具函数:
|
||||
- `parse_datetime_to_timestamp()`: 解析时间字符串为时间戳
|
||||
- `parse_time_range()`: 解析时间范围字符串
|
||||
|
||||
### 工具文件
|
||||
每个工具都有独立的文件:
|
||||
- `query_jargon.py`: 根据关键词在jargon库中查询
|
||||
- `query_chat_history.py`: 根据时间或关键词在chat_history中查询(支持查询时间点事件、时间范围事件、关键词搜索)
|
||||
|
||||
## 如何添加新工具
|
||||
|
||||
1. 创建新的工具文件,例如 `query_new_tool.py`:
|
||||
|
||||
```python
|
||||
"""
|
||||
新工具 - 工具实现
|
||||
"""
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from .tool_registry import register_memory_retrieval_tool
|
||||
from .tool_utils import parse_datetime_to_timestamp # 如果需要使用工具函数
|
||||
|
||||
logger = get_logger("memory_retrieval_tools")
|
||||
|
||||
|
||||
async def query_new_tool(param1: str, param2: str, chat_id: str) -> str:
|
||||
"""新工具的实现
|
||||
|
||||
Args:
|
||||
param1: 参数1
|
||||
param2: 参数2
|
||||
chat_id: 聊天ID
|
||||
|
||||
Returns:
|
||||
str: 查询结果
|
||||
"""
|
||||
try:
|
||||
# 实现逻辑
|
||||
return "结果"
|
||||
except Exception as e:
|
||||
logger.error(f"新工具执行失败: {e}")
|
||||
return f"查询失败: {str(e)}"
|
||||
|
||||
|
||||
def register_tool():
|
||||
"""注册工具"""
|
||||
register_memory_retrieval_tool(
|
||||
name="query_new_tool",
|
||||
description="新工具的描述",
|
||||
parameters=[
|
||||
{
|
||||
"name": "param1",
|
||||
"type": "string",
|
||||
"description": "参数1的描述",
|
||||
"required": True
|
||||
},
|
||||
{
|
||||
"name": "param2",
|
||||
"type": "string",
|
||||
"description": "参数2的描述",
|
||||
"required": True
|
||||
}
|
||||
],
|
||||
execute_func=query_new_tool
|
||||
)
|
||||
```
|
||||
|
||||
2. 在 `__init__.py` 中导入并注册新工具:
|
||||
|
||||
```python
|
||||
from .query_new_tool import register_tool as register_query_new_tool
|
||||
|
||||
def init_all_tools():
|
||||
"""初始化并注册所有记忆检索工具"""
|
||||
register_query_jargon()
|
||||
register_query_chat_history()
|
||||
register_query_new_tool() # 添加新工具
|
||||
```
|
||||
|
||||
3. 工具会自动:
|
||||
- 出现在 ReAct Agent 的 prompt 中
|
||||
- 在动作类型列表中可用
|
||||
- 被 ReAct Agent 自动调用
|
||||
|
||||
## 使用示例
|
||||
|
||||
```python
|
||||
from src.memory_system.retrieval_tools import init_all_tools, get_tool_registry
|
||||
|
||||
# 初始化所有工具
|
||||
init_all_tools()
|
||||
|
||||
# 获取工具注册器
|
||||
registry = get_tool_registry()
|
||||
|
||||
# 获取特定工具
|
||||
tool = registry.get_tool("query_chat_history")
|
||||
|
||||
# 执行工具(查询时间点事件)
|
||||
result = await tool.execute(time_point="2025-01-15 14:30:00", chat_id="chat123")
|
||||
|
||||
# 或者查询关键词
|
||||
result = await tool.execute(keyword="小丑AI", chat_id="chat123")
|
||||
|
||||
# 或者查询时间范围
|
||||
result = await tool.execute(time_range="2025-01-15 10:00:00 - 2025-01-15 20:00:00", chat_id="chat123")
|
||||
```
|
||||
|
||||
## 现有工具说明
|
||||
|
||||
### query_jargon
|
||||
根据关键词在jargon库中查询黑话/俚语/缩写的含义
|
||||
- 参数:`keyword` (必填) - 关键词
|
||||
|
||||
### query_chat_history
|
||||
根据时间或关键词在chat_history中查询相关聊天记录。可以查询某个时间点发生了什么、某个时间范围内的事件,或根据关键词搜索消息
|
||||
- 参数:
|
||||
- `keyword` (可选) - 关键词,用于搜索消息内容
|
||||
- `time_point` (可选) - 时间点,格式:YYYY-MM-DD HH:MM:SS,用于查询某个时间点附近发生了什么(与time_range二选一)
|
||||
- `time_range` (可选) - 时间范围,格式:'YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS'(与time_point二选一)
|
||||
|
||||
### query_lpmm_knowledge
|
||||
从LPMM知识库中检索与关键词相关的知识内容
|
||||
- 参数:
|
||||
- `query` (必填) - 查询的关键词或问题描述
|
||||
|
||||
## 注意事项
|
||||
|
||||
- 所有工具函数必须是异步函数(`async def`)
|
||||
- 如果工具函数签名需要 `chat_id` 参数,系统会自动添加(通过函数签名检测)
|
||||
- 工具参数定义中的 `required` 字段用于生成 prompt 描述
|
||||
- 工具执行失败时应返回错误信息字符串,而不是抛出异常
|
||||
- 共用函数放在 `tool_utils.py` 中,避免代码重复
|
||||
|
||||
36
src/memory_system/retrieval_tools/__init__.py
Normal file
36
src/memory_system/retrieval_tools/__init__.py
Normal file
@@ -0,0 +1,36 @@
|
||||
"""
|
||||
记忆检索工具模块
|
||||
提供统一的工具注册和管理系统
|
||||
"""
|
||||
|
||||
from .tool_registry import (
|
||||
MemoryRetrievalTool,
|
||||
MemoryRetrievalToolRegistry,
|
||||
register_memory_retrieval_tool,
|
||||
get_tool_registry,
|
||||
)
|
||||
|
||||
# 导入所有工具的注册函数
|
||||
from .query_jargon import register_tool as register_query_jargon
|
||||
from .query_chat_history import register_tool as register_query_chat_history
|
||||
from .query_lpmm_knowledge import register_tool as register_lpmm_knowledge
|
||||
from .query_person_info import register_tool as register_query_person_info
|
||||
from src.config.config import global_config
|
||||
|
||||
def init_all_tools():
|
||||
"""初始化并注册所有记忆检索工具"""
|
||||
register_query_jargon()
|
||||
register_query_chat_history()
|
||||
register_query_person_info()
|
||||
|
||||
if global_config.lpmm_knowledge.lpmm_mode == "agent":
|
||||
register_lpmm_knowledge()
|
||||
|
||||
|
||||
__all__ = [
|
||||
"MemoryRetrievalTool",
|
||||
"MemoryRetrievalToolRegistry",
|
||||
"register_memory_retrieval_tool",
|
||||
"get_tool_registry",
|
||||
"init_all_tools",
|
||||
]
|
||||
218
src/memory_system/retrieval_tools/query_chat_history.py
Normal file
218
src/memory_system/retrieval_tools/query_chat_history.py
Normal file
@@ -0,0 +1,218 @@
|
||||
"""
|
||||
根据时间或关键词在chat_history中查询 - 工具实现
|
||||
从ChatHistory表的聊天记录概述库中查询
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Optional
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import ChatHistory
|
||||
from src.chat.utils.utils import parse_keywords_string
|
||||
from .tool_registry import register_memory_retrieval_tool
|
||||
from ..memory_utils import parse_datetime_to_timestamp, parse_time_range
|
||||
|
||||
logger = get_logger("memory_retrieval_tools")
|
||||
|
||||
|
||||
async def query_chat_history(
|
||||
chat_id: str,
|
||||
keyword: Optional[str] = None,
|
||||
time_range: Optional[str] = None,
|
||||
fuzzy: bool = True
|
||||
) -> str:
|
||||
"""根据时间或关键词在chat_history表中查询聊天记录概述
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
keyword: 关键词(可选,支持多个关键词,可用空格、逗号等分隔)
|
||||
time_range: 时间范围或时间点,格式:
|
||||
- 时间范围:"YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS"
|
||||
- 时间点:"YYYY-MM-DD HH:MM:SS"(查询包含该时间点的记录)
|
||||
fuzzy: 是否使用模糊匹配模式(默认True)
|
||||
- True: 模糊匹配,只要包含任意一个关键词即匹配(OR关系)
|
||||
- False: 全匹配,必须包含所有关键词才匹配(AND关系)
|
||||
|
||||
Returns:
|
||||
str: 查询结果
|
||||
"""
|
||||
try:
|
||||
# 检查参数
|
||||
if not keyword and not time_range:
|
||||
return "未指定查询参数(需要提供keyword或time_range之一)"
|
||||
|
||||
# 构建查询条件
|
||||
query = ChatHistory.select().where(ChatHistory.chat_id == chat_id)
|
||||
|
||||
# 时间过滤条件
|
||||
if time_range:
|
||||
# 判断是时间点还是时间范围
|
||||
if " - " in time_range:
|
||||
# 时间范围:查询与时间范围有交集的记录
|
||||
start_timestamp, end_timestamp = parse_time_range(time_range)
|
||||
# 交集条件:start_time < end_timestamp AND end_time > start_timestamp
|
||||
time_filter = (
|
||||
(ChatHistory.start_time < end_timestamp) &
|
||||
(ChatHistory.end_time > start_timestamp)
|
||||
)
|
||||
else:
|
||||
# 时间点:查询包含该时间点的记录(start_time <= time_point <= end_time)
|
||||
target_timestamp = parse_datetime_to_timestamp(time_range)
|
||||
time_filter = (
|
||||
(ChatHistory.start_time <= target_timestamp) &
|
||||
(ChatHistory.end_time >= target_timestamp)
|
||||
)
|
||||
query = query.where(time_filter)
|
||||
|
||||
# 执行查询
|
||||
records = list(query.order_by(ChatHistory.start_time.desc()).limit(50))
|
||||
|
||||
# 如果有关键词,进一步过滤
|
||||
if keyword:
|
||||
# 解析多个关键词(支持空格、逗号等分隔符)
|
||||
keywords_list = parse_keywords_string(keyword)
|
||||
if not keywords_list:
|
||||
keywords_list = [keyword.strip()] if keyword.strip() else []
|
||||
|
||||
# 转换为小写以便匹配
|
||||
keywords_lower = [kw.lower() for kw in keywords_list if kw.strip()]
|
||||
|
||||
if not keywords_lower:
|
||||
return "关键词为空"
|
||||
|
||||
filtered_records = []
|
||||
|
||||
for record in records:
|
||||
# 在theme、keywords、summary、original_text中搜索
|
||||
theme = (record.theme or "").lower()
|
||||
summary = (record.summary or "").lower()
|
||||
original_text = (record.original_text or "").lower()
|
||||
|
||||
# 解析record中的keywords JSON
|
||||
record_keywords_list = []
|
||||
if record.keywords:
|
||||
try:
|
||||
keywords_data = json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
|
||||
if isinstance(keywords_data, list):
|
||||
record_keywords_list = [str(k).lower() for k in keywords_data]
|
||||
except (json.JSONDecodeError, TypeError, ValueError):
|
||||
pass
|
||||
|
||||
# 根据匹配模式检查关键词
|
||||
matched = False
|
||||
if fuzzy:
|
||||
# 模糊匹配:只要包含任意一个关键词即匹配(OR关系)
|
||||
for kw in keywords_lower:
|
||||
if (kw in theme or
|
||||
kw in summary or
|
||||
kw in original_text or
|
||||
any(kw in k for k in record_keywords_list)):
|
||||
matched = True
|
||||
break
|
||||
else:
|
||||
# 全匹配:必须包含所有关键词才匹配(AND关系)
|
||||
matched = True
|
||||
for kw in keywords_lower:
|
||||
kw_matched = (kw in theme or
|
||||
kw in summary or
|
||||
kw in original_text or
|
||||
any(kw in k for k in record_keywords_list))
|
||||
if not kw_matched:
|
||||
matched = False
|
||||
break
|
||||
|
||||
if matched:
|
||||
filtered_records.append(record)
|
||||
|
||||
if not filtered_records:
|
||||
keywords_str = "、".join(keywords_list)
|
||||
match_mode = "包含任意一个关键词" if fuzzy else "包含所有关键词"
|
||||
if time_range:
|
||||
return f"未找到{match_mode}'{keywords_str}'且在指定时间范围内的聊天记录概述"
|
||||
else:
|
||||
return f"未找到{match_mode}'{keywords_str}'的聊天记录概述"
|
||||
|
||||
records = filtered_records
|
||||
|
||||
# 如果没有记录(可能是时间范围查询但没有匹配的记录)
|
||||
if not records:
|
||||
if time_range:
|
||||
return "未找到指定时间范围内的聊天记录概述"
|
||||
else:
|
||||
return "未找到相关聊天记录概述"
|
||||
|
||||
# 对即将返回的记录增加使用计数
|
||||
records_to_use = records[:3]
|
||||
for record in records_to_use:
|
||||
try:
|
||||
ChatHistory.update(count=ChatHistory.count + 1).where(ChatHistory.id == record.id).execute()
|
||||
record.count = (record.count or 0) + 1
|
||||
except Exception as update_error:
|
||||
logger.error(f"更新聊天记录概述计数失败: {update_error}")
|
||||
|
||||
# 构建结果文本
|
||||
results = []
|
||||
for record in records_to_use: # 最多返回3条记录
|
||||
result_parts = []
|
||||
|
||||
# 添加主题
|
||||
if record.theme:
|
||||
result_parts.append(f"主题:{record.theme}")
|
||||
|
||||
# 添加时间范围
|
||||
from datetime import datetime
|
||||
start_str = datetime.fromtimestamp(record.start_time).strftime("%Y-%m-%d %H:%M:%S")
|
||||
end_str = datetime.fromtimestamp(record.end_time).strftime("%Y-%m-%d %H:%M:%S")
|
||||
result_parts.append(f"时间:{start_str} - {end_str}")
|
||||
|
||||
# 添加概括(优先使用summary,如果没有则使用original_text的前200字符)
|
||||
if record.summary:
|
||||
result_parts.append(f"概括:{record.summary}")
|
||||
elif record.original_text:
|
||||
text_preview = record.original_text[:200]
|
||||
if len(record.original_text) > 200:
|
||||
text_preview += "..."
|
||||
result_parts.append(f"内容:{text_preview}")
|
||||
|
||||
results.append("\n".join(result_parts))
|
||||
|
||||
if not results:
|
||||
return "未找到相关聊天记录概述"
|
||||
|
||||
response_text = "\n\n---\n\n".join(results)
|
||||
if len(records) > len(records_to_use):
|
||||
omitted_count = len(records) - len(records_to_use)
|
||||
response_text += f"\n\n(还有{omitted_count}条历史记录已省略)"
|
||||
return response_text
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"查询聊天历史概述失败: {e}")
|
||||
return f"查询失败: {str(e)}"
|
||||
|
||||
|
||||
def register_tool():
|
||||
"""注册工具"""
|
||||
register_memory_retrieval_tool(
|
||||
name="query_chat_history",
|
||||
description="根据时间或关键词在chat_history表的聊天记录概述库中查询。可以查询某个时间点发生了什么、某个时间范围内的事件,或根据关键词搜索消息概述。支持两种匹配模式:模糊匹配(默认,只要包含任意一个关键词即匹配)和全匹配(必须包含所有关键词才匹配)",
|
||||
parameters=[
|
||||
{
|
||||
"name": "keyword",
|
||||
"type": "string",
|
||||
"description": "关键词(可选,支持多个关键词,可用空格、逗号、斜杠等分隔,如:'麦麦 百度网盘' 或 '麦麦,百度网盘'。用于在主题、关键词、概括、原文中搜索)",
|
||||
"required": False
|
||||
},
|
||||
{
|
||||
"name": "time_range",
|
||||
"type": "string",
|
||||
"description": "时间范围或时间点(可选)。格式:'YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS'(时间范围,查询与时间范围有交集的记录)或 'YYYY-MM-DD HH:MM:SS'(时间点,查询包含该时间点的记录)",
|
||||
"required": False
|
||||
},
|
||||
{
|
||||
"name": "fuzzy",
|
||||
"type": "boolean",
|
||||
"description": "是否使用模糊匹配模式(默认True)。True表示模糊匹配(只要包含任意一个关键词即匹配,OR关系),False表示全匹配(必须包含所有关键词才匹配,AND关系)",
|
||||
"required": False
|
||||
}
|
||||
],
|
||||
execute_func=query_chat_history
|
||||
)
|
||||
76
src/memory_system/retrieval_tools/query_jargon.py
Normal file
76
src/memory_system/retrieval_tools/query_jargon.py
Normal file
@@ -0,0 +1,76 @@
|
||||
"""
|
||||
根据关键词在jargon库中查询 - 工具实现
|
||||
"""
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.jargon.jargon_miner import search_jargon
|
||||
from .tool_registry import register_memory_retrieval_tool
|
||||
|
||||
logger = get_logger("memory_retrieval_tools")
|
||||
|
||||
|
||||
async def query_jargon(keyword: str, chat_id: str) -> str:
|
||||
"""根据关键词在jargon库中查询
|
||||
|
||||
Args:
|
||||
keyword: 关键词(黑话/俚语/缩写)
|
||||
chat_id: 聊天ID
|
||||
|
||||
Returns:
|
||||
str: 查询结果
|
||||
"""
|
||||
try:
|
||||
content = str(keyword).strip()
|
||||
if not content:
|
||||
return "关键词为空"
|
||||
|
||||
# 先尝试精确匹配
|
||||
results = search_jargon(keyword=content, chat_id=chat_id, limit=10, case_sensitive=False, fuzzy=False)
|
||||
|
||||
is_fuzzy_match = False
|
||||
|
||||
# 如果精确匹配未找到,尝试模糊搜索
|
||||
if not results:
|
||||
results = search_jargon(keyword=content, chat_id=chat_id, limit=10, case_sensitive=False, fuzzy=True)
|
||||
is_fuzzy_match = True
|
||||
|
||||
if results:
|
||||
# 如果是模糊匹配,显示找到的实际jargon内容
|
||||
if is_fuzzy_match:
|
||||
# 处理多个结果
|
||||
output_parts = [f"未精确匹配到'{content}'"]
|
||||
for result in results:
|
||||
found_content = result.get("content", "").strip()
|
||||
meaning = result.get("meaning", "").strip()
|
||||
if found_content and meaning:
|
||||
output_parts.append(f"找到 '{found_content}' 的含义为:{meaning}")
|
||||
output = ",".join(output_parts)
|
||||
logger.info(f"在jargon库中找到匹配(当前会话或全局,模糊搜索): {content},找到{len(results)}条结果")
|
||||
else:
|
||||
# 精确匹配,可能有多条(相同content但不同chat_id的情况)
|
||||
output_parts = []
|
||||
for result in results:
|
||||
meaning = result.get("meaning", "").strip()
|
||||
if meaning:
|
||||
output_parts.append(f"'{content}' 为黑话或者网络简写,含义为:{meaning}")
|
||||
output = ";".join(output_parts) if len(output_parts) > 1 else output_parts[0]
|
||||
logger.info(f"在jargon库中找到匹配(当前会话或全局,精确匹配): {content},找到{len(results)}条结果")
|
||||
return output
|
||||
|
||||
# 未命中
|
||||
logger.info(f"在jargon库中未找到匹配(当前会话或全局,精确匹配和模糊搜索都未找到): {content}")
|
||||
return f"未在jargon库中找到'{content}'的解释"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"查询jargon失败: {e}")
|
||||
return f"查询失败: {str(e)}"
|
||||
|
||||
|
||||
def register_tool():
|
||||
"""注册工具"""
|
||||
register_memory_retrieval_tool(
|
||||
name="query_jargon",
|
||||
description="根据关键词在jargon库中查询黑话/俚语/缩写的含义。支持大小写不敏感搜索,默认会先尝试精确匹配,如果找不到则自动使用模糊搜索。仅搜索当前会话或全局jargon。",
|
||||
parameters=[{"name": "keyword", "type": "string", "description": "关键词(黑话/俚语/缩写)", "required": True}],
|
||||
execute_func=query_jargon,
|
||||
)
|
||||
65
src/memory_system/retrieval_tools/query_lpmm_knowledge.py
Normal file
65
src/memory_system/retrieval_tools/query_lpmm_knowledge.py
Normal file
@@ -0,0 +1,65 @@
|
||||
"""
|
||||
通过LPMM知识库查询信息 - 工具实现
|
||||
"""
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.chat.knowledge import get_qa_manager
|
||||
from .tool_registry import register_memory_retrieval_tool
|
||||
|
||||
logger = get_logger("memory_retrieval_tools")
|
||||
|
||||
|
||||
async def query_lpmm_knowledge(query: str) -> str:
|
||||
"""在LPMM知识库中查询相关信息
|
||||
|
||||
Args:
|
||||
query: 查询关键词
|
||||
|
||||
Returns:
|
||||
str: 查询结果
|
||||
"""
|
||||
try:
|
||||
content = str(query).strip()
|
||||
if not content:
|
||||
return "查询关键词为空"
|
||||
|
||||
if not global_config.lpmm_knowledge.enable:
|
||||
logger.debug("LPMM知识库未启用")
|
||||
return "LPMM知识库未启用"
|
||||
|
||||
qa_manager = get_qa_manager()
|
||||
if qa_manager is None:
|
||||
logger.debug("LPMM知识库未初始化,跳过查询")
|
||||
return "LPMM知识库未初始化"
|
||||
|
||||
knowledge_info = await qa_manager.get_knowledge(content)
|
||||
logger.debug(f"LPMM知识库查询结果: {knowledge_info}")
|
||||
|
||||
if knowledge_info:
|
||||
return f"你从LPMM知识库中找到以下信息:\n{knowledge_info}"
|
||||
|
||||
return f"在LPMM知识库中未找到与“{content}”相关的信息"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"LPMM知识库查询失败: {e}")
|
||||
return f"LPMM知识库查询失败:{str(e)}"
|
||||
|
||||
|
||||
def register_tool():
|
||||
"""注册LPMM知识库查询工具"""
|
||||
register_memory_retrieval_tool(
|
||||
name="lpmm_search_knowledge",
|
||||
description="从LPMM知识库中搜索相关信息,适用于需要知识支持的场景。",
|
||||
parameters=[
|
||||
{
|
||||
"name": "query",
|
||||
"type": "string",
|
||||
"description": "需要查询的关键词或问题",
|
||||
"required": True,
|
||||
}
|
||||
],
|
||||
execute_func=query_lpmm_knowledge,
|
||||
)
|
||||
|
||||
|
||||
287
src/memory_system/retrieval_tools/query_person_info.py
Normal file
287
src/memory_system/retrieval_tools/query_person_info.py
Normal file
@@ -0,0 +1,287 @@
|
||||
"""
|
||||
根据person_name查询用户信息 - 工具实现
|
||||
支持模糊查询,可以查询某个用户的所有信息
|
||||
"""
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import PersonInfo
|
||||
from .tool_registry import register_memory_retrieval_tool
|
||||
|
||||
logger = get_logger("memory_retrieval_tools")
|
||||
|
||||
|
||||
def _format_group_nick_names(group_nick_name_field) -> str:
|
||||
"""格式化群昵称信息
|
||||
|
||||
Args:
|
||||
group_nick_name_field: 群昵称字段(可能是字符串JSON或None)
|
||||
|
||||
Returns:
|
||||
str: 格式化后的群昵称信息字符串
|
||||
"""
|
||||
if not group_nick_name_field:
|
||||
return ""
|
||||
|
||||
try:
|
||||
# 解析JSON格式的群昵称列表
|
||||
group_nick_names_data = json.loads(group_nick_name_field) if isinstance(group_nick_name_field, str) else group_nick_name_field
|
||||
|
||||
if not isinstance(group_nick_names_data, list) or not group_nick_names_data:
|
||||
return ""
|
||||
|
||||
# 格式化群昵称列表
|
||||
group_nick_list = []
|
||||
for item in group_nick_names_data:
|
||||
if isinstance(item, dict):
|
||||
group_id = item.get("group_id", "未知群号")
|
||||
group_nick_name = item.get("group_nick_name", "未知群昵称")
|
||||
group_nick_list.append(f" - 群号 {group_id}:{group_nick_name}")
|
||||
elif isinstance(item, str):
|
||||
# 兼容旧格式(如果存在)
|
||||
group_nick_list.append(f" - {item}")
|
||||
|
||||
if group_nick_list:
|
||||
return "群昵称:\n" + "\n".join(group_nick_list)
|
||||
return ""
|
||||
except (json.JSONDecodeError, TypeError, ValueError) as e:
|
||||
logger.warning(f"解析群昵称信息失败: {e}")
|
||||
# 如果解析失败,尝试显示原始内容(截断)
|
||||
if isinstance(group_nick_name_field, str):
|
||||
preview = group_nick_name_field[:200]
|
||||
if len(group_nick_name_field) > 200:
|
||||
preview += "..."
|
||||
return f"群昵称(原始数据):{preview}"
|
||||
return ""
|
||||
|
||||
|
||||
async def query_person_info(person_name: str) -> str:
|
||||
"""根据person_name查询用户信息,使用模糊查询
|
||||
|
||||
Args:
|
||||
person_name: 用户名称(person_name字段)
|
||||
|
||||
Returns:
|
||||
str: 查询结果,包含用户的所有信息
|
||||
"""
|
||||
try:
|
||||
person_name = str(person_name).strip()
|
||||
if not person_name:
|
||||
return "用户名称为空"
|
||||
|
||||
# 构建查询条件(使用模糊查询)
|
||||
query = PersonInfo.select().where(
|
||||
PersonInfo.person_name.contains(person_name)
|
||||
)
|
||||
|
||||
# 执行查询
|
||||
records = list(query.limit(20)) # 最多返回20条记录
|
||||
|
||||
if not records:
|
||||
return f"未找到模糊匹配'{person_name}'的用户信息"
|
||||
|
||||
# 区分精确匹配和模糊匹配的结果
|
||||
exact_matches = []
|
||||
fuzzy_matches = []
|
||||
|
||||
for record in records:
|
||||
# 检查是否是精确匹配
|
||||
if record.person_name and record.person_name.strip() == person_name:
|
||||
exact_matches.append(record)
|
||||
else:
|
||||
fuzzy_matches.append(record)
|
||||
|
||||
# 构建结果文本
|
||||
results = []
|
||||
|
||||
# 先处理精确匹配的结果
|
||||
for record in exact_matches:
|
||||
result_parts = []
|
||||
result_parts.append("【精确匹配】") # 标注为精确匹配
|
||||
|
||||
# 基本信息
|
||||
if record.person_name:
|
||||
result_parts.append(f"用户名称:{record.person_name}")
|
||||
if record.nickname:
|
||||
result_parts.append(f"昵称:{record.nickname}")
|
||||
if record.person_id:
|
||||
result_parts.append(f"用户ID:{record.person_id}")
|
||||
if record.platform:
|
||||
result_parts.append(f"平台:{record.platform}")
|
||||
if record.user_id:
|
||||
result_parts.append(f"平台用户ID:{record.user_id}")
|
||||
|
||||
# 群昵称信息
|
||||
group_nick_name_str = _format_group_nick_names(getattr(record, "group_nick_name", None))
|
||||
if group_nick_name_str:
|
||||
result_parts.append(group_nick_name_str)
|
||||
|
||||
# 名称设定原因
|
||||
if record.name_reason:
|
||||
result_parts.append(f"名称设定原因:{record.name_reason}")
|
||||
|
||||
# 认识状态
|
||||
result_parts.append(f"是否已认识:{'是' if record.is_known else '否'}")
|
||||
|
||||
# 时间信息
|
||||
if record.know_since:
|
||||
know_since_str = datetime.fromtimestamp(record.know_since).strftime("%Y-%m-%d %H:%M:%S")
|
||||
result_parts.append(f"首次认识时间:{know_since_str}")
|
||||
if record.last_know:
|
||||
last_know_str = datetime.fromtimestamp(record.last_know).strftime("%Y-%m-%d %H:%M:%S")
|
||||
result_parts.append(f"最后认识时间:{last_know_str}")
|
||||
if record.know_times:
|
||||
result_parts.append(f"认识次数:{int(record.know_times)}")
|
||||
|
||||
# 记忆点(memory_points)
|
||||
if record.memory_points:
|
||||
try:
|
||||
memory_points_data = json.loads(record.memory_points) if isinstance(record.memory_points, str) else record.memory_points
|
||||
if isinstance(memory_points_data, list) and memory_points_data:
|
||||
# 解析记忆点格式:category:content:weight
|
||||
memory_list = []
|
||||
for memory_point in memory_points_data:
|
||||
if memory_point and isinstance(memory_point, str):
|
||||
parts = memory_point.split(":", 2)
|
||||
if len(parts) >= 3:
|
||||
category = parts[0].strip()
|
||||
content = parts[1].strip()
|
||||
weight = parts[2].strip()
|
||||
memory_list.append(f" - [{category}] {content} (权重: {weight})")
|
||||
else:
|
||||
memory_list.append(f" - {memory_point}")
|
||||
|
||||
if memory_list:
|
||||
result_parts.append("记忆点:\n" + "\n".join(memory_list))
|
||||
except (json.JSONDecodeError, TypeError, ValueError) as e:
|
||||
logger.warning(f"解析用户 {record.person_id} 的memory_points失败: {e}")
|
||||
# 如果解析失败,直接显示原始内容(截断)
|
||||
memory_preview = str(record.memory_points)[:200]
|
||||
if len(str(record.memory_points)) > 200:
|
||||
memory_preview += "..."
|
||||
result_parts.append(f"记忆点(原始数据):{memory_preview}")
|
||||
|
||||
results.append("\n".join(result_parts))
|
||||
|
||||
# 再处理模糊匹配的结果
|
||||
for record in fuzzy_matches:
|
||||
result_parts = []
|
||||
result_parts.append("【模糊匹配】") # 标注为模糊匹配
|
||||
|
||||
# 基本信息
|
||||
if record.person_name:
|
||||
result_parts.append(f"用户名称:{record.person_name}")
|
||||
if record.nickname:
|
||||
result_parts.append(f"昵称:{record.nickname}")
|
||||
if record.person_id:
|
||||
result_parts.append(f"用户ID:{record.person_id}")
|
||||
if record.platform:
|
||||
result_parts.append(f"平台:{record.platform}")
|
||||
if record.user_id:
|
||||
result_parts.append(f"平台用户ID:{record.user_id}")
|
||||
|
||||
# 群昵称信息
|
||||
group_nick_name_str = _format_group_nick_names(getattr(record, "group_nick_name", None))
|
||||
if group_nick_name_str:
|
||||
result_parts.append(group_nick_name_str)
|
||||
|
||||
# 名称设定原因
|
||||
if record.name_reason:
|
||||
result_parts.append(f"名称设定原因:{record.name_reason}")
|
||||
|
||||
# 认识状态
|
||||
result_parts.append(f"是否已认识:{'是' if record.is_known else '否'}")
|
||||
|
||||
# 时间信息
|
||||
if record.know_since:
|
||||
know_since_str = datetime.fromtimestamp(record.know_since).strftime("%Y-%m-%d %H:%M:%S")
|
||||
result_parts.append(f"首次认识时间:{know_since_str}")
|
||||
if record.last_know:
|
||||
last_know_str = datetime.fromtimestamp(record.last_know).strftime("%Y-%m-%d %H:%M:%S")
|
||||
result_parts.append(f"最后认识时间:{last_know_str}")
|
||||
if record.know_times:
|
||||
result_parts.append(f"认识次数:{int(record.know_times)}")
|
||||
|
||||
# 记忆点(memory_points)
|
||||
if record.memory_points:
|
||||
try:
|
||||
memory_points_data = json.loads(record.memory_points) if isinstance(record.memory_points, str) else record.memory_points
|
||||
if isinstance(memory_points_data, list) and memory_points_data:
|
||||
# 解析记忆点格式:category:content:weight
|
||||
memory_list = []
|
||||
for memory_point in memory_points_data:
|
||||
if memory_point and isinstance(memory_point, str):
|
||||
parts = memory_point.split(":", 2)
|
||||
if len(parts) >= 3:
|
||||
category = parts[0].strip()
|
||||
content = parts[1].strip()
|
||||
weight = parts[2].strip()
|
||||
memory_list.append(f" - [{category}] {content} (权重: {weight})")
|
||||
else:
|
||||
memory_list.append(f" - {memory_point}")
|
||||
|
||||
if memory_list:
|
||||
result_parts.append("记忆点:\n" + "\n".join(memory_list))
|
||||
except (json.JSONDecodeError, TypeError, ValueError) as e:
|
||||
logger.warning(f"解析用户 {record.person_id} 的memory_points失败: {e}")
|
||||
# 如果解析失败,直接显示原始内容(截断)
|
||||
memory_preview = str(record.memory_points)[:200]
|
||||
if len(str(record.memory_points)) > 200:
|
||||
memory_preview += "..."
|
||||
result_parts.append(f"记忆点(原始数据):{memory_preview}")
|
||||
|
||||
results.append("\n".join(result_parts))
|
||||
|
||||
# 组合所有结果
|
||||
if not results:
|
||||
return f"未找到匹配'{person_name}'的用户信息"
|
||||
|
||||
response_text = "\n\n---\n\n".join(results)
|
||||
|
||||
# 添加统计信息
|
||||
total_count = len(records)
|
||||
exact_count = len(exact_matches)
|
||||
fuzzy_count = len(fuzzy_matches)
|
||||
|
||||
# 显示精确匹配和模糊匹配的统计
|
||||
if exact_count > 0 or fuzzy_count > 0:
|
||||
stats_parts = []
|
||||
if exact_count > 0:
|
||||
stats_parts.append(f"精确匹配:{exact_count} 条")
|
||||
if fuzzy_count > 0:
|
||||
stats_parts.append(f"模糊匹配:{fuzzy_count} 条")
|
||||
stats_text = ",".join(stats_parts)
|
||||
response_text = f"找到 {total_count} 条匹配的用户信息({stats_text}):\n\n{response_text}"
|
||||
elif total_count > 1:
|
||||
response_text = f"找到 {total_count} 条匹配的用户信息:\n\n{response_text}"
|
||||
else:
|
||||
response_text = f"找到用户信息:\n\n{response_text}"
|
||||
|
||||
# 如果结果数量达到限制,添加提示
|
||||
if total_count >= 20:
|
||||
response_text += "\n\n(已显示前20条结果,可能还有更多匹配记录)"
|
||||
|
||||
return response_text
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"查询用户信息失败: {e}")
|
||||
return f"查询失败: {str(e)}"
|
||||
|
||||
|
||||
def register_tool():
|
||||
"""注册工具"""
|
||||
register_memory_retrieval_tool(
|
||||
name="query_person_info",
|
||||
description="根据查询某个用户的所有信息。名称、昵称、平台、用户ID、qq号、群昵称等",
|
||||
parameters=[
|
||||
{
|
||||
"name": "person_name",
|
||||
"type": "string",
|
||||
"description": "用户名称,用于查询用户信息",
|
||||
"required": True
|
||||
}
|
||||
],
|
||||
execute_func=query_person_info
|
||||
)
|
||||
|
||||
160
src/memory_system/retrieval_tools/tool_registry.py
Normal file
160
src/memory_system/retrieval_tools/tool_registry.py
Normal file
@@ -0,0 +1,160 @@
|
||||
"""
|
||||
工具注册系统
|
||||
提供统一的工具注册和管理接口
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Any, Optional, Callable, Awaitable
|
||||
from src.common.logger import get_logger
|
||||
from src.llm_models.payload_content.tool_option import ToolParamType
|
||||
|
||||
logger = get_logger("memory_retrieval_tools")
|
||||
|
||||
|
||||
class MemoryRetrievalTool:
|
||||
"""记忆检索工具基类"""
|
||||
|
||||
def __init__(
|
||||
self, name: str, description: str, parameters: List[Dict[str, Any]], execute_func: Callable[..., Awaitable[str]]
|
||||
):
|
||||
"""
|
||||
初始化工具
|
||||
|
||||
Args:
|
||||
name: 工具名称
|
||||
description: 工具描述
|
||||
parameters: 参数定义列表,格式:[{"name": "param_name", "type": "string", "description": "参数描述", "required": True}]
|
||||
execute_func: 执行函数,必须是异步函数
|
||||
"""
|
||||
self.name = name
|
||||
self.description = description
|
||||
self.parameters = parameters
|
||||
self.execute_func = execute_func
|
||||
|
||||
def get_tool_description(self) -> str:
|
||||
"""获取工具的文本描述,用于prompt"""
|
||||
param_descriptions = []
|
||||
for param in self.parameters:
|
||||
param_name = param.get("name", "")
|
||||
param_type = param.get("type", "string")
|
||||
param_desc = param.get("description", "")
|
||||
required = param.get("required", True)
|
||||
required_str = "必填" if required else "可选"
|
||||
param_descriptions.append(f" - {param_name} ({param_type}, {required_str}): {param_desc}")
|
||||
|
||||
params_str = "\n".join(param_descriptions) if param_descriptions else " 无参数"
|
||||
return f"{self.name}({', '.join([p['name'] for p in self.parameters])}): {self.description}\n{params_str}"
|
||||
|
||||
async def execute(self, **kwargs) -> str:
|
||||
"""执行工具"""
|
||||
return await self.execute_func(**kwargs)
|
||||
|
||||
def get_tool_definition(self) -> Dict[str, Any]:
|
||||
"""获取工具定义,用于LLM function calling
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 工具定义字典,格式与BaseTool一致
|
||||
格式: {"name": str, "description": str, "parameters": List[Tuple]}
|
||||
"""
|
||||
# 转换参数格式为元组列表,格式与BaseTool一致
|
||||
# 格式: [("param_name", ToolParamType, "description", required, enum_values)]
|
||||
param_tuples = []
|
||||
|
||||
for param in self.parameters:
|
||||
param_name = param.get("name", "")
|
||||
param_type_str = param.get("type", "string").lower()
|
||||
param_desc = param.get("description", "")
|
||||
is_required = param.get("required", False)
|
||||
enum_values = param.get("enum", None)
|
||||
|
||||
# 转换类型字符串到ToolParamType
|
||||
type_mapping = {
|
||||
"string": ToolParamType.STRING,
|
||||
"integer": ToolParamType.INTEGER,
|
||||
"int": ToolParamType.INTEGER,
|
||||
"float": ToolParamType.FLOAT,
|
||||
"boolean": ToolParamType.BOOLEAN,
|
||||
"bool": ToolParamType.BOOLEAN,
|
||||
}
|
||||
param_type = type_mapping.get(param_type_str, ToolParamType.STRING)
|
||||
|
||||
# 构建参数元组
|
||||
param_tuple = (param_name, param_type, param_desc, is_required, enum_values)
|
||||
param_tuples.append(param_tuple)
|
||||
|
||||
# 构建工具定义,格式与BaseTool.get_tool_definition()一致
|
||||
tool_def = {
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"parameters": param_tuples
|
||||
}
|
||||
|
||||
return tool_def
|
||||
|
||||
|
||||
class MemoryRetrievalToolRegistry:
|
||||
"""工具注册器"""
|
||||
|
||||
def __init__(self):
|
||||
self.tools: Dict[str, MemoryRetrievalTool] = {}
|
||||
|
||||
def register_tool(self, tool: MemoryRetrievalTool) -> None:
|
||||
"""注册工具"""
|
||||
if tool.name in self.tools:
|
||||
logger.debug(f"记忆检索工具 {tool.name} 已存在,跳过重复注册")
|
||||
return
|
||||
self.tools[tool.name] = tool
|
||||
logger.info(f"注册记忆检索工具: {tool.name}")
|
||||
|
||||
def get_tool(self, name: str) -> Optional[MemoryRetrievalTool]:
|
||||
"""获取工具"""
|
||||
return self.tools.get(name)
|
||||
|
||||
def get_all_tools(self) -> Dict[str, MemoryRetrievalTool]:
|
||||
"""获取所有工具"""
|
||||
return self.tools.copy()
|
||||
|
||||
def get_tools_description(self) -> str:
|
||||
"""获取所有工具的描述,用于prompt"""
|
||||
descriptions = []
|
||||
for i, tool in enumerate(self.tools.values(), 1):
|
||||
descriptions.append(f"{i}. {tool.get_tool_description()}")
|
||||
return "\n".join(descriptions)
|
||||
|
||||
def get_action_types_list(self) -> str:
|
||||
"""获取所有动作类型的列表,用于prompt(已废弃,保留用于兼容)"""
|
||||
action_types = [tool.name for tool in self.tools.values()]
|
||||
action_types.append("final_answer")
|
||||
action_types.append("no_answer")
|
||||
return " 或 ".join([f'"{at}"' for at in action_types])
|
||||
|
||||
def get_tool_definitions(self) -> List[Dict[str, Any]]:
|
||||
"""获取所有工具的定义列表,用于LLM function calling
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 工具定义列表,每个元素是一个工具定义字典
|
||||
"""
|
||||
return [tool.get_tool_definition() for tool in self.tools.values()]
|
||||
|
||||
|
||||
# 全局工具注册器实例
|
||||
_tool_registry = MemoryRetrievalToolRegistry()
|
||||
|
||||
|
||||
def register_memory_retrieval_tool(
|
||||
name: str, description: str, parameters: List[Dict[str, Any]], execute_func: Callable[..., Awaitable[str]]
|
||||
) -> None:
|
||||
"""注册记忆检索工具的便捷函数
|
||||
|
||||
Args:
|
||||
name: 工具名称
|
||||
description: 工具描述
|
||||
parameters: 参数定义列表
|
||||
execute_func: 执行函数
|
||||
"""
|
||||
tool = MemoryRetrievalTool(name, description, parameters, execute_func)
|
||||
_tool_registry.register_tool(tool)
|
||||
|
||||
|
||||
def get_tool_registry() -> MemoryRetrievalToolRegistry:
|
||||
"""获取工具注册器实例"""
|
||||
return _tool_registry
|
||||
@@ -1,10 +1,7 @@
|
||||
import math
|
||||
import random
|
||||
import time
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.chat.message_receive.message import MessageRecv
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive
|
||||
@@ -25,7 +22,7 @@ def init_prompt():
|
||||
你先前的情绪状态是:{mood_state}
|
||||
你的情绪特点是:{emotion_style}
|
||||
|
||||
现在,请你根据先前的情绪状态和现在的聊天内容,总结推断你现在的情绪状态
|
||||
现在,请你根据先前的情绪状态和现在的聊天内容,总结推断你现在的情绪状态,用简短的词句来描述情绪状态
|
||||
请只输出新的情绪状态,不要输出其他内容:
|
||||
""",
|
||||
"get_mood_prompt",
|
||||
@@ -39,7 +36,7 @@ def init_prompt():
|
||||
{identity_block}
|
||||
你之前的情绪状态是:{mood_state}
|
||||
|
||||
距离你上次关注群里消息已经过去了一段时间,你冷静了下来,请你输出一句话描述你现在的情绪状态
|
||||
距离你上次关注群里消息已经过去了一段时间,你冷静了下来,请你输出一句话或几个词来描述你现在的情绪状态
|
||||
你的情绪特点是:{emotion_style}
|
||||
请只输出新的情绪状态,不要输出其他内容:
|
||||
""",
|
||||
|
||||
@@ -13,6 +13,7 @@ from src.common.database.database import db
|
||||
from src.common.database.database_model import PersonInfo
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config, model_config
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
|
||||
|
||||
logger = get_logger("person_info")
|
||||
@@ -160,7 +161,9 @@ def levenshtein_distance(s1: str, s2: str) -> int:
|
||||
|
||||
class Person:
|
||||
@classmethod
|
||||
def register_person(cls, platform: str, user_id: str, nickname: str):
|
||||
def register_person(
|
||||
cls, platform: str, user_id: str, nickname: str, group_id: Optional[str] = None, group_nick_name: Optional[str] = None
|
||||
):
|
||||
"""
|
||||
注册新用户的类方法
|
||||
必须输入 platform、user_id 和 nickname 参数
|
||||
@@ -169,6 +172,8 @@ class Person:
|
||||
platform: 平台名称
|
||||
user_id: 用户ID
|
||||
nickname: 用户昵称
|
||||
group_id: 群号(可选,仅在群聊时提供)
|
||||
group_nick_name: 群昵称(可选,仅在群聊时提供)
|
||||
|
||||
Returns:
|
||||
Person: 新注册的Person实例
|
||||
@@ -182,7 +187,11 @@ class Person:
|
||||
|
||||
if is_person_known(person_id=person_id):
|
||||
logger.debug(f"用户 {nickname} 已存在")
|
||||
return Person(person_id=person_id)
|
||||
person = Person(person_id=person_id)
|
||||
# 如果是群聊,更新群昵称
|
||||
if group_id and group_nick_name:
|
||||
person.add_group_nick_name(group_id, group_nick_name)
|
||||
return person
|
||||
|
||||
# 创建Person实例
|
||||
person = cls.__new__(cls)
|
||||
@@ -201,6 +210,11 @@ class Person:
|
||||
person.know_since = time.time()
|
||||
person.last_know = time.time()
|
||||
person.memory_points = []
|
||||
person.group_nick_name = [] # 初始化群昵称列表
|
||||
|
||||
# 如果是群聊,添加群昵称
|
||||
if group_id and group_nick_name:
|
||||
person.add_group_nick_name(group_id, group_nick_name)
|
||||
|
||||
# 同步到数据库
|
||||
person.sync_to_database()
|
||||
@@ -217,6 +231,7 @@ class Person:
|
||||
self.platform = platform
|
||||
self.nickname = global_config.bot.nickname
|
||||
self.person_name = global_config.bot.nickname
|
||||
self.group_nick_name: list[dict[str, str]] = []
|
||||
return
|
||||
|
||||
self.user_id = ""
|
||||
@@ -255,6 +270,7 @@ class Person:
|
||||
self.know_since = None
|
||||
self.last_know: Optional[float] = None
|
||||
self.memory_points = []
|
||||
self.group_nick_name: list[dict[str, str]] = [] # 群昵称列表,存储 {"group_id": str, "group_nick_name": str}
|
||||
|
||||
# 从数据库加载数据
|
||||
self.load_from_database()
|
||||
@@ -342,6 +358,31 @@ class Person:
|
||||
return memory_list
|
||||
return random.sample(memory_list, num)
|
||||
|
||||
def add_group_nick_name(self, group_id: str, group_nick_name: str):
|
||||
"""
|
||||
添加或更新群昵称
|
||||
|
||||
Args:
|
||||
group_id: 群号
|
||||
group_nick_name: 群昵称
|
||||
"""
|
||||
if not group_id or not group_nick_name:
|
||||
return
|
||||
|
||||
# 检查是否已存在该群号的记录
|
||||
for item in self.group_nick_name:
|
||||
if item.get("group_id") == group_id:
|
||||
# 更新现有记录
|
||||
item["group_nick_name"] = group_nick_name
|
||||
self.sync_to_database()
|
||||
logger.debug(f"更新用户 {self.person_id} 在群 {group_id} 的群昵称为 {group_nick_name}")
|
||||
return
|
||||
|
||||
# 添加新记录
|
||||
self.group_nick_name.append({"group_id": group_id, "group_nick_name": group_nick_name})
|
||||
self.sync_to_database()
|
||||
logger.debug(f"添加用户 {self.person_id} 在群 {group_id} 的群昵称 {group_nick_name}")
|
||||
|
||||
def load_from_database(self):
|
||||
"""从数据库加载个人信息数据"""
|
||||
try:
|
||||
@@ -372,6 +413,21 @@ class Person:
|
||||
else:
|
||||
self.memory_points = []
|
||||
|
||||
# 处理group_nick_name字段(JSON格式的列表)
|
||||
if record.group_nick_name:
|
||||
try:
|
||||
loaded_group_nick_names = json.loads(record.group_nick_name)
|
||||
# 确保是列表格式
|
||||
if isinstance(loaded_group_nick_names, list):
|
||||
self.group_nick_name = loaded_group_nick_names
|
||||
else:
|
||||
self.group_nick_name = []
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
logger.warning(f"解析用户 {self.person_id} 的group_nick_name字段失败,使用默认值")
|
||||
self.group_nick_name = []
|
||||
else:
|
||||
self.group_nick_name = []
|
||||
|
||||
logger.debug(f"已从数据库加载用户 {self.person_id} 的信息")
|
||||
else:
|
||||
self.sync_to_database()
|
||||
@@ -403,6 +459,9 @@ class Person:
|
||||
)
|
||||
if self.memory_points
|
||||
else json.dumps([], ensure_ascii=False),
|
||||
"group_nick_name": json.dumps(self.group_nick_name, ensure_ascii=False)
|
||||
if self.group_nick_name
|
||||
else json.dumps([], ensure_ascii=False),
|
||||
}
|
||||
|
||||
# 检查记录是否存在
|
||||
@@ -664,3 +723,74 @@ class PersonInfoManager:
|
||||
|
||||
|
||||
person_info_manager = PersonInfoManager()
|
||||
|
||||
|
||||
async def store_person_memory_from_answer(person_name: str, memory_content: str, chat_id: str) -> None:
|
||||
"""将人物信息存入person_info的memory_points
|
||||
|
||||
Args:
|
||||
person_name: 人物名称
|
||||
memory_content: 记忆内容
|
||||
chat_id: 聊天ID
|
||||
"""
|
||||
try:
|
||||
# 从chat_id获取chat_stream
|
||||
chat_stream = get_chat_manager().get_stream(chat_id)
|
||||
if not chat_stream:
|
||||
logger.warning(f"无法获取chat_stream for chat_id: {chat_id}")
|
||||
return
|
||||
|
||||
platform = chat_stream.platform
|
||||
|
||||
# 尝试从person_name查找person_id
|
||||
# 首先尝试通过person_name查找
|
||||
person_id = get_person_id_by_person_name(person_name)
|
||||
|
||||
if not person_id:
|
||||
# 如果通过person_name找不到,尝试从chat_stream获取user_info
|
||||
if chat_stream.user_info:
|
||||
user_id = chat_stream.user_info.user_id
|
||||
person_id = get_person_id(platform, user_id)
|
||||
else:
|
||||
logger.warning(f"无法确定person_id for person_name: {person_name}, chat_id: {chat_id}")
|
||||
return
|
||||
|
||||
# 创建或获取Person对象
|
||||
person = Person(person_id=person_id)
|
||||
|
||||
if not person.is_known:
|
||||
logger.warning(f"用户 {person_name} (person_id: {person_id}) 尚未认识,无法存储记忆")
|
||||
return
|
||||
|
||||
# 确定记忆分类(可以根据memory_content判断,这里使用通用分类)
|
||||
category = "其他" # 默认分类,可以根据需要调整
|
||||
|
||||
# 记忆点格式:category:content:weight
|
||||
weight = "1.0" # 默认权重
|
||||
memory_point = f"{category}:{memory_content}:{weight}"
|
||||
|
||||
# 添加到memory_points
|
||||
if not person.memory_points:
|
||||
person.memory_points = []
|
||||
|
||||
# 检查是否已存在相似的记忆点(避免重复)
|
||||
is_duplicate = False
|
||||
for existing_point in person.memory_points:
|
||||
if existing_point and isinstance(existing_point, str):
|
||||
parts = existing_point.split(":", 2)
|
||||
if len(parts) >= 2:
|
||||
existing_content = parts[1].strip()
|
||||
# 简单相似度检查(如果内容相同或非常相似,则跳过)
|
||||
if existing_content == memory_content or memory_content in existing_content or existing_content in memory_content:
|
||||
is_duplicate = True
|
||||
break
|
||||
|
||||
if not is_duplicate:
|
||||
person.memory_points.append(memory_point)
|
||||
person.sync_to_database()
|
||||
logger.info(f"成功添加记忆点到 {person_name} (person_id: {person_id}): {memory_point}")
|
||||
else:
|
||||
logger.debug(f"记忆点已存在,跳过: {memory_point}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"存储人物记忆失败: {e}")
|
||||
|
||||
@@ -53,6 +53,7 @@ from .apis import (
|
||||
person_api,
|
||||
plugin_manage_api,
|
||||
send_api,
|
||||
auto_talk_api,
|
||||
register_plugin,
|
||||
get_logger,
|
||||
)
|
||||
@@ -83,6 +84,7 @@ __all__ = [
|
||||
"person_api",
|
||||
"plugin_manage_api",
|
||||
"send_api",
|
||||
"auto_talk_api",
|
||||
"register_plugin",
|
||||
"get_logger",
|
||||
# 基础类
|
||||
|
||||
@@ -20,6 +20,7 @@ from src.plugin_system.apis import (
|
||||
tool_api,
|
||||
frequency_api,
|
||||
mood_api,
|
||||
auto_talk_api,
|
||||
)
|
||||
from .logging_api import get_logger
|
||||
from .plugin_register_api import register_plugin
|
||||
@@ -42,4 +43,5 @@ __all__ = [
|
||||
"tool_api",
|
||||
"frequency_api",
|
||||
"mood_api",
|
||||
"auto_talk_api",
|
||||
]
|
||||
|
||||
56
src/plugin_system/apis/auto_talk_api.py
Normal file
56
src/plugin_system/apis/auto_talk_api.py
Normal file
@@ -0,0 +1,56 @@
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("auto_talk_api")
|
||||
|
||||
|
||||
def set_question_probability_multiplier(chat_id: str, multiplier: float) -> bool:
|
||||
"""
|
||||
设置指定 chat_id 的主动发言概率乘数。
|
||||
|
||||
返回:
|
||||
bool: 设置是否成功。仅当目标聊天为群聊(HeartFChatting)且存在时为 True。
|
||||
"""
|
||||
try:
|
||||
if not isinstance(chat_id, str):
|
||||
raise TypeError("chat_id 必须是 str")
|
||||
if not isinstance(multiplier, (int, float)):
|
||||
raise TypeError("multiplier 必须是数值类型")
|
||||
|
||||
# 延迟导入以避免循环依赖
|
||||
from src.chat.heart_flow.heartflow import heartflow as _heartflow
|
||||
|
||||
chat = _heartflow.heartflow_chat_list.get(chat_id)
|
||||
if chat is None:
|
||||
logger.warning(f"未找到 chat_id={chat_id} 的心流实例,无法设置乘数")
|
||||
return False
|
||||
|
||||
# 仅对拥有该属性的群聊心流生效(鸭子类型,避免导入类)
|
||||
if not hasattr(chat, "question_probability_multiplier"):
|
||||
logger.warning(f"chat_id={chat_id} 实例不支持主动发言乘数设置")
|
||||
return False
|
||||
|
||||
# 约束:不允许负值
|
||||
value = float(multiplier)
|
||||
if value < 0:
|
||||
value = 0.0
|
||||
|
||||
chat.question_probability_multiplier = value
|
||||
logger.info(f"[auto_talk_api] chat_id={chat_id} 主动发言乘数已设为 {value}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"设置主动发言乘数失败: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def get_question_probability_multiplier(chat_id: str) -> float:
|
||||
"""获取指定 chat_id 的主动发言概率乘数,未找到则返回 0。"""
|
||||
try:
|
||||
# 延迟导入以避免循环依赖
|
||||
from src.chat.heart_flow.heartflow import heartflow as _heartflow
|
||||
|
||||
chat = _heartflow.heartflow_chat_list.get(chat_id)
|
||||
if chat is None:
|
||||
return 0.0
|
||||
return float(getattr(chat, "question_probability_multiplier", 0.0))
|
||||
except Exception:
|
||||
return 0.0
|
||||
@@ -6,7 +6,9 @@ logger = get_logger("frequency_api")
|
||||
|
||||
|
||||
def get_current_talk_value(chat_id: str) -> float:
|
||||
return frequency_control_manager.get_or_create_frequency_control(chat_id).get_talk_frequency_adjust() * global_config.chat.get_talk_value(chat_id)
|
||||
return frequency_control_manager.get_or_create_frequency_control(
|
||||
chat_id
|
||||
).get_talk_frequency_adjust() * global_config.chat.get_talk_value(chat_id)
|
||||
|
||||
|
||||
def set_talk_frequency_adjust(chat_id: str, talk_frequency_adjust: float) -> None:
|
||||
|
||||
@@ -7,9 +7,11 @@
|
||||
success, response, reasoning, model_name = await llm_api.generate_with_model(prompt, model_config)
|
||||
"""
|
||||
|
||||
from typing import Tuple, Dict, List, Any, Optional
|
||||
from typing import Tuple, Dict, List, Any, Optional, Callable
|
||||
from src.common.logger import get_logger
|
||||
from src.llm_models.payload_content.tool_option import ToolCall
|
||||
from src.llm_models.payload_content.message import Message
|
||||
from src.llm_models.model_client.base_client import BaseClient
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config
|
||||
from src.config.api_ada_configs import TaskConfig
|
||||
@@ -120,3 +122,44 @@ async def generate_with_model_with_tools(
|
||||
error_msg = f"生成内容时出错: {str(e)}"
|
||||
logger.error(f"[LLMAPI] {error_msg}")
|
||||
return False, error_msg, "", "", None
|
||||
|
||||
|
||||
async def generate_with_model_with_tools_by_message_factory(
|
||||
message_factory: Callable[[BaseClient], List[Message]],
|
||||
model_config: TaskConfig,
|
||||
tool_options: List[Dict[str, Any]] | None = None,
|
||||
request_type: str = "plugin.generate",
|
||||
temperature: Optional[float] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
) -> Tuple[bool, str, str, str, List[ToolCall] | None]:
|
||||
"""使用指定模型和工具生成内容(通过消息工厂构建消息列表)
|
||||
|
||||
Args:
|
||||
message_factory: 消息工厂函数
|
||||
model_config: 模型配置
|
||||
tool_options: 工具选项列表
|
||||
request_type: 请求类型标识
|
||||
temperature: 温度参数
|
||||
max_tokens: 最大token数
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str, str, str, List[ToolCall] | None]: (是否成功, 生成的内容, 推理过程, 模型名称, 工具调用列表)
|
||||
"""
|
||||
try:
|
||||
model_name_list = model_config.model_list
|
||||
logger.info(f"[LLMAPI] 使用模型集合 {model_name_list} 生成内容(消息工厂)")
|
||||
|
||||
llm_request = LLMRequest(model_set=model_config, request_type=request_type)
|
||||
|
||||
response, (reasoning_content, model_name, tool_call) = await llm_request.generate_response_with_message_async(
|
||||
message_factory=message_factory,
|
||||
tools=tool_options,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
return True, response, reasoning_content, model_name, tool_call
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"生成内容时出错: {str(e)}"
|
||||
logger.error(f"[LLMAPI] {error_msg}")
|
||||
return False, error_msg, "", "", None
|
||||
|
||||
@@ -109,7 +109,7 @@ def get_messages_by_time_in_chat(
|
||||
limit=limit,
|
||||
limit_mode=limit_mode,
|
||||
filter_bot=filter_mai,
|
||||
filter_command=filter_command
|
||||
filter_command=filter_command,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -12,11 +12,11 @@ logger = get_logger("tool_api")
|
||||
|
||||
def get_tool_instance(tool_name: str, chat_stream: Optional["ChatStream"] = None) -> Optional[BaseTool]:
|
||||
"""获取公开工具实例
|
||||
|
||||
|
||||
Args:
|
||||
tool_name: 工具名称
|
||||
chat_stream: 聊天流对象,用于传递聊天上下文信息
|
||||
|
||||
|
||||
Returns:
|
||||
Optional[BaseTool]: 工具实例,如果未找到则返回None
|
||||
"""
|
||||
|
||||
@@ -77,7 +77,7 @@ class BaseAction(ABC):
|
||||
self.action_require: list[str] = getattr(self.__class__, "action_require", []).copy()
|
||||
|
||||
"""NORMAL模式下的激活类型"""
|
||||
self.activation_type = getattr(self.__class__, "activation_type")
|
||||
self.activation_type = self.__class__.activation_type
|
||||
"""激活类型"""
|
||||
self.random_activation_probability: float = getattr(self.__class__, "random_activation_probability", 0.0)
|
||||
"""当激活类型为RANDOM时的概率"""
|
||||
@@ -108,21 +108,16 @@ class BaseAction(ABC):
|
||||
self.is_group = False
|
||||
self.target_id = None
|
||||
|
||||
|
||||
self.group_id = (
|
||||
str(self.action_message.chat_info.group_info.group_id)
|
||||
if self.action_message.chat_info.group_info
|
||||
else None
|
||||
str(self.action_message.chat_info.group_info.group_id) if self.action_message.chat_info.group_info else None
|
||||
)
|
||||
self.group_name = (
|
||||
self.action_message.chat_info.group_info.group_name
|
||||
if self.action_message.chat_info.group_info
|
||||
else None
|
||||
self.action_message.chat_info.group_info.group_name if self.action_message.chat_info.group_info else None
|
||||
)
|
||||
|
||||
self.user_id = str(self.action_message.user_info.user_id)
|
||||
self.user_nickname = self.action_message.user_info.user_nickname
|
||||
|
||||
|
||||
if self.group_id:
|
||||
self.is_group = True
|
||||
self.target_id = self.group_id
|
||||
@@ -132,7 +127,6 @@ class BaseAction(ABC):
|
||||
self.target_id = self.user_id
|
||||
self.log_prefix = f"[{self.user_nickname} 的 私聊]"
|
||||
|
||||
|
||||
logger.debug(
|
||||
f"{self.log_prefix} 聊天信息: 类型={'群聊' if self.is_group else '私聊'}, 平台={self.platform}, 目标={self.target_id}"
|
||||
)
|
||||
@@ -448,7 +442,6 @@ class BaseAction(ABC):
|
||||
|
||||
wait_start_time = asyncio.get_event_loop().time()
|
||||
while True:
|
||||
|
||||
# 检查新消息
|
||||
current_time = time.time()
|
||||
new_message_count = message_api.count_new_messages(
|
||||
@@ -497,7 +490,7 @@ class BaseAction(ABC):
|
||||
raise ValueError(f"Action名称 '{name}' 包含非法字符 '.',请使用下划线替代")
|
||||
# 获取focus_activation_type和normal_activation_type
|
||||
focus_activation_type = getattr(cls, "focus_activation_type", ActionActivationType.ALWAYS)
|
||||
normal_activation_type = getattr(cls, "normal_activation_type", ActionActivationType.ALWAYS)
|
||||
_normal_activation_type = getattr(cls, "normal_activation_type", ActionActivationType.ALWAYS)
|
||||
|
||||
# 处理activation_type:如果插件中声明了就用插件的值,否则默认使用focus_activation_type
|
||||
activation_type = getattr(cls, "activation_type", focus_activation_type)
|
||||
|
||||
@@ -34,17 +34,17 @@ class BaseTool(ABC):
|
||||
|
||||
def __init__(self, plugin_config: Optional[dict] = None, chat_stream: Optional["ChatStream"] = None):
|
||||
"""初始化工具基类
|
||||
|
||||
|
||||
Args:
|
||||
plugin_config: 插件配置字典
|
||||
chat_stream: 聊天流对象,用于获取聊天上下文信息
|
||||
"""
|
||||
self.plugin_config = plugin_config or {} # 直接存储插件配置字典
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 便捷属性 - 直接在初始化时获取常用聊天信息(与BaseAction保持一致)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
# 获取聊天流对象
|
||||
self.chat_stream = chat_stream
|
||||
self.chat_id = self.chat_stream.stream_id if self.chat_stream else None
|
||||
@@ -57,7 +57,7 @@ class BaseTool(ABC):
|
||||
Returns:
|
||||
dict: 工具定义字典
|
||||
"""
|
||||
if not cls.name or not cls.description or not cls.parameters:
|
||||
if not cls.name or not cls.description or cls.parameters is None:
|
||||
raise NotImplementedError(f"工具类 {cls.__name__} 必须定义 name, description 和 parameters 属性")
|
||||
|
||||
return {"name": cls.name, "description": cls.description, "parameters": cls.parameters}
|
||||
@@ -65,7 +65,7 @@ class BaseTool(ABC):
|
||||
@classmethod
|
||||
def get_tool_info(cls) -> ToolInfo:
|
||||
"""获取工具信息"""
|
||||
if not cls.name or not cls.description or not cls.parameters:
|
||||
if not cls.name or not cls.description or cls.parameters is None:
|
||||
raise NotImplementedError(f"工具类 {cls.__name__} 必须定义 name, description 和 parameters 属性")
|
||||
|
||||
return ToolInfo(
|
||||
|
||||
@@ -346,9 +346,7 @@ class EventsManager:
|
||||
|
||||
if not isinstance(result, tuple) or len(result) != 5:
|
||||
if isinstance(result, tuple):
|
||||
annotated = ", ".join(
|
||||
f"{name}={val!r}" for name, val in zip(expected_fields, result)
|
||||
)
|
||||
annotated = ", ".join(f"{name}={val!r}" for name, val in zip(expected_fields, result, strict=False))
|
||||
actual_desc = f"{len(result)} 个元素 ({annotated})"
|
||||
else:
|
||||
actual_desc = f"非 tuple 类型: {type(result)}"
|
||||
@@ -380,7 +378,6 @@ class EventsManager:
|
||||
logger.error(f"EventHandler {handler.handler_name} 发生异常: {e}", exc_info=True)
|
||||
return True, None # 发生异常时默认不中断其他处理
|
||||
|
||||
|
||||
def _task_done_callback(
|
||||
self,
|
||||
task: asyncio.Task[Tuple[bool, bool, str | None, CustomEventHandlerResult | None, MaiMessages | None]],
|
||||
|
||||
@@ -93,6 +93,14 @@ class ToolExecutor:
|
||||
# 获取可用工具
|
||||
tools = self._get_tool_definitions()
|
||||
|
||||
# 如果没有可用工具,直接返回空内容
|
||||
if not tools:
|
||||
logger.info(f"{self.log_prefix}没有可用工具,直接返回空内容")
|
||||
if return_details:
|
||||
return [], [], ""
|
||||
else:
|
||||
return [], [], ""
|
||||
|
||||
# print(f"tools: {tools}")
|
||||
|
||||
# 获取当前时间
|
||||
@@ -116,6 +124,7 @@ class ToolExecutor:
|
||||
response, (reasoning_content, model_name, tool_calls) = await self.llm_model.generate_response_async(
|
||||
prompt=prompt, tools=tools, raise_when_empty=False
|
||||
)
|
||||
|
||||
|
||||
# 执行工具调用
|
||||
tool_results, used_tools = await self.execute_tool_calls(tool_calls)
|
||||
@@ -180,9 +189,8 @@ class ToolExecutor:
|
||||
tool_info["content"] = str(content)
|
||||
# 空内容直接跳过(空字符串、全空白字符串、空列表/空元组)
|
||||
content_check = tool_info["content"]
|
||||
if (
|
||||
(isinstance(content_check, str) and not content_check.strip())
|
||||
or (isinstance(content_check, (list, tuple)) and len(content_check) == 0)
|
||||
if (isinstance(content_check, str) and not content_check.strip()) or (
|
||||
isinstance(content_check, (list, tuple)) and len(content_check) == 0
|
||||
):
|
||||
logger.debug(f"{self.log_prefix}工具{tool_name}无有效内容,跳过展示")
|
||||
continue
|
||||
|
||||
@@ -1,34 +0,0 @@
|
||||
{
|
||||
"manifest_version": 1,
|
||||
"name": "Memory Build组件",
|
||||
"version": "1.0.0",
|
||||
"description": "可以构建和管理记忆",
|
||||
"author": {
|
||||
"name": "Mai",
|
||||
"url": "https://github.com/MaiM-with-u"
|
||||
},
|
||||
"license": "GPL-v3.0-or-later",
|
||||
|
||||
"host_application": {
|
||||
"min_version": "0.10.4"
|
||||
},
|
||||
"homepage_url": "https://github.com/MaiM-with-u/maibot",
|
||||
"repository_url": "https://github.com/MaiM-with-u/maibot",
|
||||
"keywords": ["memory", "build", "built-in"],
|
||||
"categories": ["Memory"],
|
||||
|
||||
"default_locale": "zh-CN",
|
||||
"locales_path": "_locales",
|
||||
|
||||
"plugin_info": {
|
||||
"is_built_in": true,
|
||||
"plugin_type": "action_provider",
|
||||
"components": [
|
||||
{
|
||||
"type": "build_memory",
|
||||
"name": "build_memory",
|
||||
"description": "构建记忆"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
@@ -1,245 +0,0 @@
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.llm_models.payload_content.tool_option import ToolParamType
|
||||
from src.memory_system.Memory_chest import global_memory_chest
|
||||
from src.plugin_system.base.base_tool import BaseTool
|
||||
from src.plugin_system.apis.message_api import get_messages_by_time_in_chat, build_readable_messages
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config
|
||||
from typing import Any
|
||||
|
||||
logger = get_logger("memory")
|
||||
|
||||
def parse_datetime_to_timestamp(value: str) -> float:
|
||||
"""
|
||||
接受多种常见格式并转换为时间戳(秒)
|
||||
支持示例:
|
||||
- 2025-09-29
|
||||
- 2025-09-29 00:00:00
|
||||
- 2025/09/29 00:00
|
||||
- 2025-09-29T00:00:00
|
||||
"""
|
||||
value = value.strip()
|
||||
fmts = [
|
||||
"%Y-%m-%d %H:%M:%S",
|
||||
"%Y-%m-%d %H:%M",
|
||||
"%Y/%m/%d %H:%M:%S",
|
||||
"%Y/%m/%d %H:%M",
|
||||
"%Y-%m-%d",
|
||||
"%Y/%m/%d",
|
||||
"%Y-%m-%dT%H:%M:%S",
|
||||
"%Y-%m-%dT%H:%M",
|
||||
]
|
||||
last_err = None
|
||||
for fmt in fmts:
|
||||
try:
|
||||
dt = datetime.strptime(value, fmt)
|
||||
return dt.timestamp()
|
||||
except Exception as e:
|
||||
last_err = e
|
||||
raise ValueError(f"无法解析时间: {value} ({last_err})")
|
||||
|
||||
def parse_time_range(time_range: str) -> tuple[float, float]:
|
||||
"""
|
||||
解析时间范围字符串,返回开始和结束时间戳
|
||||
格式: "YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS"
|
||||
"""
|
||||
if " - " not in time_range:
|
||||
raise ValueError("时间范围格式错误,应使用 ' - ' 分隔开始和结束时间")
|
||||
|
||||
start_str, end_str = time_range.split(" - ", 1)
|
||||
start_timestamp = parse_datetime_to_timestamp(start_str.strip())
|
||||
end_timestamp = parse_datetime_to_timestamp(end_str.strip())
|
||||
|
||||
if start_timestamp > end_timestamp:
|
||||
raise ValueError("开始时间不能晚于结束时间")
|
||||
|
||||
return start_timestamp, end_timestamp
|
||||
class GetMemoryTool(BaseTool):
|
||||
"""获取用户信息"""
|
||||
|
||||
name = "get_memory"
|
||||
description = "在记忆中搜索,获取某个问题的答案,可以指定搜索的时间范围或时间点"
|
||||
parameters = [
|
||||
("question", ToolParamType.STRING, "需要获取答案的问题", True, None),
|
||||
("time_point", ToolParamType.STRING, "需要获取记忆的时间点,格式为YYYY-MM-DD HH:MM:SS", False, None),
|
||||
("time_range", ToolParamType.STRING, "需要获取记忆的时间范围,格式为YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS", False, None)
|
||||
]
|
||||
|
||||
available_for_llm = True
|
||||
|
||||
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""执行记忆搜索
|
||||
|
||||
Args:
|
||||
function_args: 工具参数
|
||||
|
||||
Returns:
|
||||
dict: 工具执行结果
|
||||
"""
|
||||
question: str = function_args.get("question") # type: ignore
|
||||
time_point: str = function_args.get("time_point") # type: ignore
|
||||
time_range: str = function_args.get("time_range") # type: ignore
|
||||
|
||||
# 检查是否指定了时间参数
|
||||
has_time_params = bool(time_point or time_range)
|
||||
|
||||
if has_time_params and not self.chat_id:
|
||||
return {"content": f"问题:{question},无法获取聊天记录:缺少chat_id"}
|
||||
|
||||
# 创建并行任务
|
||||
tasks = []
|
||||
|
||||
# 原任务:从记忆仓库获取答案
|
||||
memory_task = asyncio.create_task(
|
||||
global_memory_chest.get_answer_by_question(question=question)
|
||||
)
|
||||
tasks.append(("memory", memory_task))
|
||||
|
||||
# 新任务:从聊天记录获取答案(如果指定了时间参数)
|
||||
chat_task = None
|
||||
if has_time_params:
|
||||
chat_task = asyncio.create_task(
|
||||
self._get_answer_from_chat_history(question, time_point, time_range)
|
||||
)
|
||||
tasks.append(("chat", chat_task))
|
||||
|
||||
# 等待所有任务完成
|
||||
results = {}
|
||||
for task_name, task in tasks:
|
||||
try:
|
||||
results[task_name] = await task
|
||||
except Exception as e:
|
||||
logger.error(f"任务 {task_name} 执行失败: {e}")
|
||||
results[task_name] = None
|
||||
|
||||
# 处理结果
|
||||
memory_answer = results.get("memory")
|
||||
chat_answer = results.get("chat")
|
||||
|
||||
# 构建返回内容
|
||||
content_parts = []
|
||||
|
||||
if memory_answer:
|
||||
content_parts.append(f"对问题'{question}',你回忆的信息是:{memory_answer}")
|
||||
|
||||
if chat_answer:
|
||||
content_parts.append(f"对问题'{question}',基于聊天记录的回答:{chat_answer}")
|
||||
elif has_time_params:
|
||||
if time_point:
|
||||
content_parts.append(f"在 {time_point} 的时间点,你没有参与聊天")
|
||||
elif time_range:
|
||||
content_parts.append(f"在 {time_range} 的时间范围内,你没有参与聊天")
|
||||
|
||||
if content_parts:
|
||||
retrieval_content = f"问题:{question}" + "\n".join(content_parts)
|
||||
return {"content": retrieval_content}
|
||||
else:
|
||||
return {"content": ""}
|
||||
|
||||
|
||||
async def _get_answer_from_chat_history(self, question: str, time_point: str = None, time_range: str = None) -> str:
|
||||
"""从聊天记录中获取问题的答案"""
|
||||
try:
|
||||
# 确定时间范围
|
||||
print(f"time_point: {time_point}, time_range: {time_range}")
|
||||
|
||||
# 检查time_range的两个时间值是否相同,如果相同则按照time_point处理
|
||||
if time_range and not time_point:
|
||||
try:
|
||||
start_timestamp, end_timestamp = parse_time_range(time_range)
|
||||
if start_timestamp == end_timestamp:
|
||||
# 两个时间值相同,按照time_point处理
|
||||
time_point = time_range.split(" - ")[0].strip()
|
||||
time_range = None
|
||||
print(f"time_range两个值相同,按照time_point处理: {time_point}")
|
||||
except Exception as e:
|
||||
logger.warning(f"解析time_range失败: {e}")
|
||||
|
||||
if time_point:
|
||||
# 时间点:搜索前后25条记录
|
||||
target_timestamp = parse_datetime_to_timestamp(time_point)
|
||||
# 获取前后各25条记录,总共50条
|
||||
messages_before = get_messages_by_time_in_chat(
|
||||
chat_id=self.chat_id,
|
||||
start_time=0,
|
||||
end_time=target_timestamp,
|
||||
limit=25,
|
||||
limit_mode="latest"
|
||||
)
|
||||
messages_after = get_messages_by_time_in_chat(
|
||||
chat_id=self.chat_id,
|
||||
start_time=target_timestamp,
|
||||
end_time=float('inf'),
|
||||
limit=25,
|
||||
limit_mode="earliest"
|
||||
)
|
||||
messages = messages_before + messages_after
|
||||
elif time_range:
|
||||
# 时间范围:搜索范围内最多50条记录
|
||||
start_timestamp, end_timestamp = parse_time_range(time_range)
|
||||
messages = get_messages_by_time_in_chat(
|
||||
chat_id=self.chat_id,
|
||||
start_time=start_timestamp,
|
||||
end_time=end_timestamp,
|
||||
limit=50,
|
||||
limit_mode="latest"
|
||||
)
|
||||
else:
|
||||
return "未指定时间参数"
|
||||
|
||||
if not messages:
|
||||
return "没有找到相关聊天记录"
|
||||
|
||||
# 将消息转换为可读格式
|
||||
chat_content = build_readable_messages(messages, timestamp_mode="relative")
|
||||
|
||||
if not chat_content.strip():
|
||||
return "聊天记录为空"
|
||||
|
||||
# 使用LLM分析聊天内容并回答问题
|
||||
try:
|
||||
llm_request = LLMRequest(
|
||||
model_set=model_config.model_task_config.utils_small,
|
||||
request_type="chat_history_analysis"
|
||||
)
|
||||
|
||||
analysis_prompt = f"""请根据以下聊天记录内容,回答用户的问题。请输出一段平文本,不要有特殊格式。
|
||||
聊天记录:
|
||||
{chat_content}
|
||||
|
||||
用户问题:{question}
|
||||
|
||||
请仔细分析聊天记录,提取与问题相关的信息,并给出准确的答案。如果聊天记录中没有相关信息,无法回答问题,输出"无有效信息"即可,不要输出其他内容。
|
||||
|
||||
答案:"""
|
||||
|
||||
print(f"analysis_prompt: {analysis_prompt}")
|
||||
|
||||
|
||||
response, (reasoning, model_name, tool_calls) = await llm_request.generate_response_async(
|
||||
prompt=analysis_prompt,
|
||||
temperature=0.3,
|
||||
max_tokens=256
|
||||
)
|
||||
|
||||
|
||||
print(f"response: {response}")
|
||||
|
||||
if "无有效信息" in response:
|
||||
return ""
|
||||
|
||||
return response
|
||||
|
||||
except Exception as llm_error:
|
||||
logger.error(f"LLM分析聊天记录失败: {llm_error}")
|
||||
# 如果LLM分析失败,返回聊天内容的摘要
|
||||
if len(chat_content) > 300:
|
||||
chat_content = chat_content[:300] + "..."
|
||||
return chat_content
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"从聊天记录获取答案失败: {e}")
|
||||
return ""
|
||||
@@ -1,53 +0,0 @@
|
||||
from typing import List, Tuple, Type
|
||||
|
||||
# 导入新插件系统
|
||||
from src.plugin_system import BasePlugin, ComponentInfo, register_plugin
|
||||
from src.plugin_system.base.config_types import ConfigField
|
||||
|
||||
# 导入依赖的系统组件
|
||||
from src.common.logger import get_logger
|
||||
|
||||
from src.plugins.built_in.memory.build_memory import GetMemoryTool
|
||||
|
||||
logger = get_logger("memory_build")
|
||||
|
||||
|
||||
@register_plugin
|
||||
class MemoryBuildPlugin(BasePlugin):
|
||||
"""记忆构建插件
|
||||
|
||||
系统内置插件,提供基础的聊天交互功能:
|
||||
- GetMemory: 获取记忆
|
||||
|
||||
注意:插件基本信息优先从_manifest.json文件中读取
|
||||
"""
|
||||
|
||||
# 插件基本信息
|
||||
plugin_name: str = "memory_build" # 内部标识符
|
||||
enable_plugin: bool = True
|
||||
dependencies: list[str] = [] # 插件依赖列表
|
||||
python_dependencies: list[str] = [] # Python包依赖列表
|
||||
config_file_name: str = "config.toml"
|
||||
|
||||
# 配置节描述
|
||||
config_section_descriptions = {
|
||||
"plugin": "插件启用配置",
|
||||
"components": "核心组件启用配置",
|
||||
}
|
||||
|
||||
# 配置Schema定义
|
||||
config_schema: dict = {
|
||||
"plugin": {
|
||||
"enabled": ConfigField(type=bool, default=True, description="是否启用插件"),
|
||||
"config_version": ConfigField(type=str, default="1.1.1", description="配置文件版本"),
|
||||
},
|
||||
}
|
||||
|
||||
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
|
||||
"""返回插件包含的组件列表"""
|
||||
|
||||
# --- 根据配置注册组件 ---
|
||||
components = []
|
||||
components.append((GetMemoryTool.get_tool_info(), GetMemoryTool))
|
||||
|
||||
return components
|
||||
@@ -1,34 +0,0 @@
|
||||
{
|
||||
"manifest_version": 1,
|
||||
"name": "Relation插件 (Relation Actions)",
|
||||
"version": "1.0.0",
|
||||
"description": "可以构建和管理关系",
|
||||
"author": {
|
||||
"name": "SengokuCola",
|
||||
"url": "https://github.com/MaiM-with-u"
|
||||
},
|
||||
"license": "GPL-v3.0-or-later",
|
||||
|
||||
"host_application": {
|
||||
"min_version": "0.10.0"
|
||||
},
|
||||
"homepage_url": "https://github.com/MaiM-with-u/maibot",
|
||||
"repository_url": "https://github.com/MaiM-with-u/maibot",
|
||||
"keywords": ["relation", "action", "built-in"],
|
||||
"categories": ["Relation"],
|
||||
|
||||
"default_locale": "zh-CN",
|
||||
"locales_path": "_locales",
|
||||
|
||||
"plugin_info": {
|
||||
"is_built_in": true,
|
||||
"plugin_type": "action_provider",
|
||||
"components": [
|
||||
{
|
||||
"type": "action",
|
||||
"name": "relation",
|
||||
"description": "发送关系"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
@@ -1,96 +0,0 @@
|
||||
from typing import List, Tuple, Type, Any
|
||||
|
||||
# 导入新插件系统
|
||||
from src.plugin_system import BasePlugin, register_plugin, ComponentInfo
|
||||
from src.plugin_system.base.config_types import ConfigField
|
||||
from src.person_info.person_info import Person
|
||||
from src.plugin_system.base.base_tool import BaseTool, ToolParamType
|
||||
|
||||
# 导入依赖的系统组件
|
||||
from src.common.logger import get_logger
|
||||
|
||||
from src.plugins.built_in.relation.relation import BuildRelationAction
|
||||
|
||||
logger = get_logger("relation_actions")
|
||||
|
||||
|
||||
class GetPersonInfoTool(BaseTool):
|
||||
"""获取用户信息"""
|
||||
|
||||
name = "get_person_info"
|
||||
description = "获取某个人的信息,包括印象,特征点,与用户的关系等等"
|
||||
parameters = [
|
||||
("person_name", ToolParamType.STRING, "需要获取信息的人的名称", True, None),
|
||||
("info_type", ToolParamType.STRING, "需要获取信息的类型", True, None),
|
||||
]
|
||||
|
||||
available_for_llm = True
|
||||
|
||||
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""执行比较两个数的大小
|
||||
|
||||
Args:
|
||||
function_args: 工具参数
|
||||
|
||||
Returns:
|
||||
dict: 工具执行结果
|
||||
"""
|
||||
person_name: str = function_args.get("person_name") # type: ignore
|
||||
info_type: str = function_args.get("info_type") # type: ignore
|
||||
|
||||
person = Person(person_name=person_name)
|
||||
if not person:
|
||||
return {"content": f"用户 {person_name} 不存在"}
|
||||
if not person.is_known:
|
||||
return {"content": f"不认识用户 {person_name}"}
|
||||
|
||||
relation_str = await person.build_relationship(info_type=info_type)
|
||||
|
||||
return {"content": relation_str}
|
||||
|
||||
|
||||
@register_plugin
|
||||
class RelationActionsPlugin(BasePlugin):
|
||||
"""关系动作插件
|
||||
|
||||
系统内置插件,提供基础的聊天交互功能:
|
||||
- Reply: 回复动作
|
||||
- NoReply: 不回复动作
|
||||
- Emoji: 表情动作
|
||||
|
||||
注意:插件基本信息优先从_manifest.json文件中读取
|
||||
"""
|
||||
|
||||
# 插件基本信息
|
||||
plugin_name: str = "relation_actions" # 内部标识符
|
||||
enable_plugin: bool = True
|
||||
dependencies: list[str] = [] # 插件依赖列表
|
||||
python_dependencies: list[str] = [] # Python包依赖列表
|
||||
config_file_name: str = "config.toml"
|
||||
|
||||
# 配置节描述
|
||||
config_section_descriptions = {
|
||||
"plugin": "插件启用配置",
|
||||
"components": "核心组件启用配置",
|
||||
}
|
||||
|
||||
# 配置Schema定义
|
||||
config_schema: dict = {
|
||||
"plugin": {
|
||||
"enabled": ConfigField(type=bool, default=False, description="是否启用插件"),
|
||||
"config_version": ConfigField(type=str, default="1.0.2", description="配置文件版本"),
|
||||
},
|
||||
"components": {
|
||||
"relation_max_memory_num": ConfigField(type=int, default=10, description="关系记忆最大数量"),
|
||||
},
|
||||
}
|
||||
|
||||
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
|
||||
"""返回插件包含的组件列表"""
|
||||
|
||||
# --- 根据配置注册组件 ---
|
||||
components = []
|
||||
# components.append((BuildRelationAction.get_action_info(), BuildRelationAction))
|
||||
# components.append((GetPersonInfoTool.get_tool_info(), GetPersonInfoTool))
|
||||
|
||||
return components
|
||||
@@ -1,230 +0,0 @@
|
||||
import json
|
||||
from json_repair import repair_json
|
||||
from typing import Tuple
|
||||
import time
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.person_info.person_info import Person, get_memory_content_from_memory, get_weight_from_memory
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.plugin_system import BaseAction, ActionActivationType
|
||||
from src.plugin_system.apis import llm_api
|
||||
|
||||
|
||||
logger = get_logger("relation")
|
||||
|
||||
|
||||
def init_prompt():
|
||||
Prompt(
|
||||
"""
|
||||
以下是一些记忆条目的分类:
|
||||
----------------------
|
||||
{category_list}
|
||||
----------------------
|
||||
每一个分类条目类型代表了你对用户:"{person_name}"的印象的一个类别
|
||||
|
||||
现在,你有一条对 {person_name} 的新记忆内容:
|
||||
{memory_point}
|
||||
|
||||
请判断该记忆内容是否属于上述分类,请给出分类的名称。
|
||||
如果不属于上述分类,请输出一个合适的分类名称,对新记忆内容进行概括。要求分类名具有概括性。
|
||||
注意分类数一般不超过5个
|
||||
请严格用json格式输出,不要输出任何其他内容:
|
||||
{{
|
||||
"category": "分类名称"
|
||||
}} """,
|
||||
"relation_category",
|
||||
)
|
||||
|
||||
Prompt(
|
||||
"""
|
||||
以下是有关{category}的现有记忆:
|
||||
----------------------
|
||||
{memory_list}
|
||||
----------------------
|
||||
|
||||
现在,你有一条对 {person_name} 的新记忆内容:
|
||||
{memory_point}
|
||||
|
||||
请判断该新记忆内容是否已经存在于现有记忆中,你可以对现有进行进行以下修改:
|
||||
注意,一般来说记忆内容不超过5个,且记忆文本不应太长
|
||||
|
||||
1.新增:当记忆内容不存在于现有记忆,且不存在矛盾,请用json格式输出:
|
||||
{{
|
||||
"new_memory": "需要新增的记忆内容"
|
||||
}}
|
||||
2.加深印象:如果这个新记忆已经存在于现有记忆中,在内容上与现有记忆类似,请用json格式输出:
|
||||
{{
|
||||
"memory_id": 1, #请输出你认为需要加深印象的,与新记忆内容类似的,已经存在的记忆的序号
|
||||
"integrate_memory": "加深后的记忆内容,合并内容类似的新记忆和旧记忆"
|
||||
}}
|
||||
3.整合:如果这个新记忆与现有记忆产生矛盾,请你结合其他记忆进行整合,用json格式输出:
|
||||
{{
|
||||
"memory_id": 1, #请输出你认为需要整合的,与新记忆存在矛盾的,已经存在的记忆的序号
|
||||
"integrate_memory": "整合后的记忆内容,合并内容矛盾的新记忆和旧记忆"
|
||||
}}
|
||||
|
||||
现在,请你根据情况选出合适的修改方式,并输出json,不要输出其他内容:
|
||||
""",
|
||||
"relation_category_update",
|
||||
)
|
||||
|
||||
|
||||
class BuildRelationAction(BaseAction):
|
||||
"""关系动作 - 构建关系"""
|
||||
|
||||
activation_type = ActionActivationType.LLM_JUDGE
|
||||
parallel_action = True
|
||||
|
||||
# 动作基本信息
|
||||
action_name = "build_relation"
|
||||
action_description = "了解对于某人的记忆,并添加到你对对方的印象中"
|
||||
|
||||
# 动作参数定义
|
||||
action_parameters = {"person_name": "需要了解或记忆的人的名称", "impression": "需要了解的对某人的记忆或印象"}
|
||||
|
||||
# 动作使用场景
|
||||
action_require = [
|
||||
"了解对于某人的记忆,并添加到你对对方的印象中",
|
||||
"对方与有明确提到有关其自身的事件",
|
||||
"对方有提到其个人信息,包括喜好,身份,等等",
|
||||
"对方希望你记住对方的信息",
|
||||
]
|
||||
|
||||
# 关联类型
|
||||
associated_types = ["text"]
|
||||
|
||||
async def execute(self) -> Tuple[bool, str]:
|
||||
"""执行关系动作"""
|
||||
|
||||
try:
|
||||
# 1. 获取构建关系的原因
|
||||
impression = self.action_data.get("impression", "")
|
||||
logger.info(f"{self.log_prefix} 添加关系印象原因: {self.reasoning}")
|
||||
person_name = self.action_data.get("person_name", "")
|
||||
# 2. 获取目标用户信息
|
||||
person = Person(person_name=person_name)
|
||||
if not person.is_known:
|
||||
logger.warning(f"{self.log_prefix} 用户 {person_name} 不存在,跳过添加记忆")
|
||||
return False, f"用户 {person_name} 不存在,跳过添加记忆"
|
||||
|
||||
person.last_know = time.time()
|
||||
person.know_times += 1
|
||||
person.sync_to_database()
|
||||
|
||||
category_list = person.get_all_category()
|
||||
if not category_list:
|
||||
category_list_str = "无分类"
|
||||
else:
|
||||
category_list_str = "\n".join(category_list)
|
||||
|
||||
prompt = await global_prompt_manager.format_prompt(
|
||||
"relation_category",
|
||||
category_list=category_list_str,
|
||||
memory_point=impression,
|
||||
person_name=person.person_name,
|
||||
)
|
||||
|
||||
if global_config.debug.show_prompt:
|
||||
logger.info(f"{self.log_prefix} 生成的LLM Prompt: {prompt}")
|
||||
else:
|
||||
logger.debug(f"{self.log_prefix} 生成的LLM Prompt: {prompt}")
|
||||
|
||||
# 5. 调用LLM
|
||||
models = llm_api.get_available_models()
|
||||
chat_model_config = models.get("utils_small") # 使用字典访问方式
|
||||
if not chat_model_config:
|
||||
logger.error(f"{self.log_prefix} 未找到'utils_small'模型配置,无法调用LLM")
|
||||
return False, "未找到'utils_small'模型配置"
|
||||
|
||||
success, category, _, _ = await llm_api.generate_with_model(
|
||||
prompt, model_config=chat_model_config, request_type="relation.category"
|
||||
)
|
||||
|
||||
category_data = json.loads(repair_json(category))
|
||||
category = category_data.get("category", "")
|
||||
if not category:
|
||||
logger.warning(f"{self.log_prefix} LLM未给出分类,跳过添加记忆")
|
||||
return False, "LLM未给出分类,跳过添加记忆"
|
||||
|
||||
# 第二部分:更新记忆
|
||||
|
||||
memory_list = person.get_memory_list_by_category(category)
|
||||
if not memory_list:
|
||||
logger.info(f"{self.log_prefix} {person.person_name} 的 {category} 的记忆为空,进行创建")
|
||||
person.memory_points.append(f"{category}:{impression}:1.0")
|
||||
person.sync_to_database()
|
||||
|
||||
return True, f"未找到分类为{category}的记忆点,进行添加"
|
||||
|
||||
memory_list_str = ""
|
||||
memory_list_id = {}
|
||||
for id, memory in enumerate(memory_list, start=1):
|
||||
memory_content = get_memory_content_from_memory(memory)
|
||||
memory_list_str += f"{id}. {memory_content}\n"
|
||||
memory_list_id[id] = memory
|
||||
prompt = await global_prompt_manager.format_prompt(
|
||||
"relation_category_update",
|
||||
category=category,
|
||||
memory_list=memory_list_str,
|
||||
memory_point=impression,
|
||||
person_name=person.person_name,
|
||||
)
|
||||
|
||||
if global_config.debug.show_prompt:
|
||||
logger.info(f"{self.log_prefix} 生成的LLM Prompt: {prompt}")
|
||||
else:
|
||||
logger.debug(f"{self.log_prefix} 生成的LLM Prompt: {prompt}")
|
||||
|
||||
chat_model_config = models.get("utils")
|
||||
success, update_memory, _, _ = await llm_api.generate_with_model(
|
||||
prompt,
|
||||
model_config=chat_model_config, # type: ignore
|
||||
request_type="relation.category.update", # type: ignore
|
||||
)
|
||||
|
||||
update_memory_data = json.loads(repair_json(update_memory))
|
||||
new_memory = update_memory_data.get("new_memory", "")
|
||||
memory_id = update_memory_data.get("memory_id", "")
|
||||
integrate_memory = update_memory_data.get("integrate_memory", "")
|
||||
|
||||
if new_memory:
|
||||
# 新记忆
|
||||
person.memory_points.append(f"{category}:{new_memory}:1.0")
|
||||
person.sync_to_database()
|
||||
|
||||
logger.info(f"{self.log_prefix} 为{person.person_name}新增记忆点: {new_memory}")
|
||||
|
||||
return True, f"为{person.person_name}新增记忆点: {new_memory}"
|
||||
elif memory_id and integrate_memory:
|
||||
# 现存或冲突记忆
|
||||
memory = memory_list_id[memory_id]
|
||||
memory_content = get_memory_content_from_memory(memory)
|
||||
del_count = person.del_memory(category, memory_content)
|
||||
|
||||
if del_count > 0:
|
||||
# logger.info(f"{self.log_prefix} 删除记忆点: {memory_content}")
|
||||
|
||||
memory_weight = get_weight_from_memory(memory)
|
||||
person.memory_points.append(f"{category}:{integrate_memory}:{memory_weight + 1.0}")
|
||||
person.sync_to_database()
|
||||
|
||||
logger.info(
|
||||
f"{self.log_prefix} 更新{person.person_name}的记忆点: {memory_content} -> {integrate_memory}"
|
||||
)
|
||||
|
||||
return True, f"更新{person.person_name}的记忆点: {memory_content} -> {integrate_memory}"
|
||||
|
||||
else:
|
||||
logger.warning(f"{self.log_prefix} 删除记忆点失败: {memory_content}")
|
||||
return False, f"删除{person.person_name}的记忆点失败: {memory_content}"
|
||||
|
||||
return True, "关系动作执行成功"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 关系构建动作执行失败: {e}", exc_info=True)
|
||||
return False, f"关系动作执行失败: {str(e)}"
|
||||
|
||||
|
||||
# 还缺一个关系的太多遗忘和对应的提取
|
||||
init_prompt()
|
||||
1
src/webui/__init__.py
Normal file
1
src/webui/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""WebUI 模块"""
|
||||
93
src/webui/manager.py
Normal file
93
src/webui/manager.py
Normal file
@@ -0,0 +1,93 @@
|
||||
"""WebUI 管理器 - 处理开发/生产环境的 WebUI 启动"""
|
||||
import os
|
||||
from pathlib import Path
|
||||
from src.common.logger import get_logger
|
||||
from .token_manager import get_token_manager
|
||||
|
||||
logger = get_logger("webui")
|
||||
|
||||
|
||||
def setup_webui(mode: str = "production") -> bool:
|
||||
"""
|
||||
设置 WebUI
|
||||
|
||||
Args:
|
||||
mode: 运行模式,"development" 或 "production"
|
||||
|
||||
Returns:
|
||||
bool: 是否成功设置
|
||||
"""
|
||||
# 初始化 Token 管理器(确保 token 文件存在)
|
||||
token_manager = get_token_manager()
|
||||
current_token = token_manager.get_token()
|
||||
logger.info(f"🔑 WebUI Access Token: {current_token}")
|
||||
logger.info("💡 请使用此 Token 登录 WebUI")
|
||||
|
||||
if mode == "development":
|
||||
return setup_dev_mode()
|
||||
else:
|
||||
return setup_production_mode()
|
||||
|
||||
|
||||
def setup_dev_mode() -> bool:
|
||||
"""设置开发模式 - 仅启用 CORS,前端自行启动"""
|
||||
logger.info("📝 WebUI 开发模式已启用")
|
||||
logger.info("🌐 请手动启动前端开发服务器: cd webui && npm run dev")
|
||||
logger.info("💡 前端将运行在 http://localhost:7999")
|
||||
return True
|
||||
|
||||
|
||||
def setup_production_mode() -> bool:
|
||||
"""设置生产模式 - 挂载静态文件"""
|
||||
try:
|
||||
from src.common.server import get_global_server
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.responses import FileResponse
|
||||
|
||||
server = get_global_server()
|
||||
base_dir = Path(__file__).parent.parent.parent
|
||||
static_path = base_dir / "webui" / "dist"
|
||||
|
||||
if not static_path.exists():
|
||||
logger.warning(f"❌ WebUI 静态文件目录不存在: {static_path}")
|
||||
logger.warning("💡 请先构建前端: cd webui && npm run build")
|
||||
return False
|
||||
|
||||
if not (static_path / "index.html").exists():
|
||||
logger.warning(f"❌ 未找到 index.html: {static_path / 'index.html'}")
|
||||
logger.warning("💡 请确认前端已正确构建")
|
||||
return False
|
||||
|
||||
# 挂载静态资源
|
||||
if (static_path / "assets").exists():
|
||||
server.app.mount(
|
||||
"/assets",
|
||||
StaticFiles(directory=str(static_path / "assets")),
|
||||
name="assets"
|
||||
)
|
||||
|
||||
# 处理 SPA 路由
|
||||
@server.app.get("/{full_path:path}")
|
||||
async def serve_spa(full_path: str):
|
||||
"""服务单页应用"""
|
||||
# API 路由不处理
|
||||
if full_path.startswith("api/"):
|
||||
return None
|
||||
|
||||
# 检查文件是否存在
|
||||
file_path = static_path / full_path
|
||||
if file_path.is_file():
|
||||
return FileResponse(file_path)
|
||||
|
||||
# 返回 index.html(SPA 路由)
|
||||
return FileResponse(static_path / "index.html")
|
||||
|
||||
host = os.getenv("HOST", "127.0.0.1")
|
||||
port = os.getenv("PORT", "8000")
|
||||
logger.info("✅ WebUI 生产模式已挂载")
|
||||
logger.info(f"🌐 访问 http://{host}:{port} 查看 WebUI")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"挂载 WebUI 静态文件失败: {e}")
|
||||
return False
|
||||
154
src/webui/routes.py
Normal file
154
src/webui/routes.py
Normal file
@@ -0,0 +1,154 @@
|
||||
"""WebUI API 路由"""
|
||||
from fastapi import APIRouter, HTTPException, Header
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional
|
||||
from src.common.logger import get_logger
|
||||
from .token_manager import get_token_manager
|
||||
|
||||
logger = get_logger("webui.api")
|
||||
|
||||
# 创建路由器
|
||||
router = APIRouter(prefix="/api/webui", tags=["WebUI"])
|
||||
|
||||
|
||||
class TokenVerifyRequest(BaseModel):
|
||||
"""Token 验证请求"""
|
||||
token: str = Field(..., description="访问令牌")
|
||||
|
||||
|
||||
class TokenVerifyResponse(BaseModel):
|
||||
"""Token 验证响应"""
|
||||
valid: bool = Field(..., description="Token 是否有效")
|
||||
message: str = Field(..., description="验证结果消息")
|
||||
|
||||
|
||||
class TokenUpdateRequest(BaseModel):
|
||||
"""Token 更新请求"""
|
||||
new_token: str = Field(..., description="新的访问令牌", min_length=10)
|
||||
|
||||
|
||||
class TokenUpdateResponse(BaseModel):
|
||||
"""Token 更新响应"""
|
||||
success: bool = Field(..., description="是否更新成功")
|
||||
message: str = Field(..., description="更新结果消息")
|
||||
|
||||
|
||||
class TokenRegenerateResponse(BaseModel):
|
||||
"""Token 重新生成响应"""
|
||||
success: bool = Field(..., description="是否生成成功")
|
||||
token: str = Field(..., description="新生成的令牌")
|
||||
message: str = Field(..., description="生成结果消息")
|
||||
|
||||
|
||||
@router.get("/health")
|
||||
async def health_check():
|
||||
"""健康检查"""
|
||||
return {"status": "healthy", "service": "MaiBot WebUI"}
|
||||
|
||||
|
||||
@router.post("/auth/verify", response_model=TokenVerifyResponse)
|
||||
async def verify_token(request: TokenVerifyRequest):
|
||||
"""
|
||||
验证访问令牌
|
||||
|
||||
Args:
|
||||
request: 包含 token 的验证请求
|
||||
|
||||
Returns:
|
||||
验证结果
|
||||
"""
|
||||
try:
|
||||
token_manager = get_token_manager()
|
||||
is_valid = token_manager.verify_token(request.token)
|
||||
|
||||
if is_valid:
|
||||
return TokenVerifyResponse(
|
||||
valid=True,
|
||||
message="Token 验证成功"
|
||||
)
|
||||
else:
|
||||
return TokenVerifyResponse(
|
||||
valid=False,
|
||||
message="Token 无效或已过期"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Token 验证失败: {e}")
|
||||
raise HTTPException(status_code=500, detail="Token 验证失败") from e
|
||||
|
||||
|
||||
@router.post("/auth/update", response_model=TokenUpdateResponse)
|
||||
async def update_token(
|
||||
request: TokenUpdateRequest,
|
||||
authorization: Optional[str] = Header(None)
|
||||
):
|
||||
"""
|
||||
更新访问令牌(需要当前有效的 token)
|
||||
|
||||
Args:
|
||||
request: 包含新 token 的更新请求
|
||||
authorization: Authorization header (Bearer token)
|
||||
|
||||
Returns:
|
||||
更新结果
|
||||
"""
|
||||
try:
|
||||
# 验证当前 token
|
||||
if not authorization or not authorization.startswith("Bearer "):
|
||||
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
|
||||
|
||||
current_token = authorization.replace("Bearer ", "")
|
||||
token_manager = get_token_manager()
|
||||
|
||||
if not token_manager.verify_token(current_token):
|
||||
raise HTTPException(status_code=401, detail="当前 Token 无效")
|
||||
|
||||
# 更新 token
|
||||
success, message = token_manager.update_token(request.new_token)
|
||||
|
||||
return TokenUpdateResponse(
|
||||
success=success,
|
||||
message=message
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Token 更新失败: {e}")
|
||||
raise HTTPException(status_code=500, detail="Token 更新失败") from e
|
||||
|
||||
|
||||
@router.post("/auth/regenerate", response_model=TokenRegenerateResponse)
|
||||
async def regenerate_token(authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
重新生成访问令牌(需要当前有效的 token)
|
||||
|
||||
Args:
|
||||
authorization: Authorization header (Bearer token)
|
||||
|
||||
Returns:
|
||||
新生成的 token
|
||||
"""
|
||||
try:
|
||||
# 验证当前 token
|
||||
if not authorization or not authorization.startswith("Bearer "):
|
||||
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
|
||||
|
||||
current_token = authorization.replace("Bearer ", "")
|
||||
token_manager = get_token_manager()
|
||||
|
||||
if not token_manager.verify_token(current_token):
|
||||
raise HTTPException(status_code=401, detail="当前 Token 无效")
|
||||
|
||||
# 重新生成 token
|
||||
new_token = token_manager.regenerate_token()
|
||||
|
||||
return TokenRegenerateResponse(
|
||||
success=True,
|
||||
token=new_token,
|
||||
message="Token 已重新生成"
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Token 重新生成失败: {e}")
|
||||
raise HTTPException(status_code=500, detail="Token 重新生成失败") from e
|
||||
|
||||
244
src/webui/token_manager.py
Normal file
244
src/webui/token_manager.py
Normal file
@@ -0,0 +1,244 @@
|
||||
"""
|
||||
WebUI Token 管理模块
|
||||
负责生成、保存、验证和更新访问令牌
|
||||
"""
|
||||
|
||||
import json
|
||||
import secrets
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("webui")
|
||||
|
||||
|
||||
class TokenManager:
|
||||
"""Token 管理器"""
|
||||
|
||||
def __init__(self, config_path: Optional[Path] = None):
|
||||
"""
|
||||
初始化 Token 管理器
|
||||
|
||||
Args:
|
||||
config_path: 配置文件路径,默认为项目根目录的 data/webui.json
|
||||
"""
|
||||
if config_path is None:
|
||||
# 获取项目根目录 (src/webui -> src -> 根目录)
|
||||
project_root = Path(__file__).parent.parent.parent
|
||||
config_path = project_root / "data" / "webui.json"
|
||||
|
||||
self.config_path = config_path
|
||||
self.config_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 确保配置文件存在并包含有效的 token
|
||||
self._ensure_config()
|
||||
|
||||
def _ensure_config(self):
|
||||
"""确保配置文件存在且包含有效的 token"""
|
||||
if not self.config_path.exists():
|
||||
logger.info(f"WebUI 配置文件不存在,正在创建: {self.config_path}")
|
||||
self._create_new_token()
|
||||
else:
|
||||
# 验证配置文件格式
|
||||
try:
|
||||
config = self._load_config()
|
||||
if not config.get("access_token"):
|
||||
logger.warning("WebUI 配置文件中缺少 access_token,正在重新生成")
|
||||
self._create_new_token()
|
||||
else:
|
||||
logger.info(f"WebUI Token 已加载: {config['access_token'][:8]}...")
|
||||
except Exception as e:
|
||||
logger.error(f"读取 WebUI 配置文件失败: {e},正在重新创建")
|
||||
self._create_new_token()
|
||||
|
||||
def _load_config(self) -> dict:
|
||||
"""加载配置文件"""
|
||||
try:
|
||||
with open(self.config_path, "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
except Exception as e:
|
||||
logger.error(f"加载 WebUI 配置失败: {e}")
|
||||
return {}
|
||||
|
||||
def _save_config(self, config: dict):
|
||||
"""保存配置文件"""
|
||||
try:
|
||||
with open(self.config_path, "w", encoding="utf-8") as f:
|
||||
json.dump(config, f, ensure_ascii=False, indent=2)
|
||||
logger.info(f"WebUI 配置已保存到: {self.config_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"保存 WebUI 配置失败: {e}")
|
||||
raise
|
||||
|
||||
def _create_new_token(self) -> str:
|
||||
"""生成新的 64 位随机 token"""
|
||||
# 生成 64 位十六进制字符串 (32 字节 = 64 hex 字符)
|
||||
token = secrets.token_hex(32)
|
||||
|
||||
config = {
|
||||
"access_token": token,
|
||||
"created_at": self._get_current_timestamp(),
|
||||
"updated_at": self._get_current_timestamp()
|
||||
}
|
||||
|
||||
self._save_config(config)
|
||||
logger.info(f"新的 WebUI Token 已生成: {token[:8]}...")
|
||||
|
||||
return token
|
||||
|
||||
def _get_current_timestamp(self) -> str:
|
||||
"""获取当前时间戳字符串"""
|
||||
from datetime import datetime
|
||||
return datetime.now().isoformat()
|
||||
|
||||
def get_token(self) -> str:
|
||||
"""获取当前有效的 token"""
|
||||
config = self._load_config()
|
||||
return config.get("access_token", "")
|
||||
|
||||
def verify_token(self, token: str) -> bool:
|
||||
"""
|
||||
验证 token 是否有效
|
||||
|
||||
Args:
|
||||
token: 待验证的 token
|
||||
|
||||
Returns:
|
||||
bool: token 是否有效
|
||||
"""
|
||||
if not token:
|
||||
return False
|
||||
|
||||
current_token = self.get_token()
|
||||
if not current_token:
|
||||
logger.error("系统中没有有效的 token")
|
||||
return False
|
||||
|
||||
# 使用 secrets.compare_digest 防止时序攻击
|
||||
is_valid = secrets.compare_digest(token, current_token)
|
||||
|
||||
if is_valid:
|
||||
logger.debug("Token 验证成功")
|
||||
else:
|
||||
logger.warning("Token 验证失败")
|
||||
|
||||
return is_valid
|
||||
|
||||
def update_token(self, new_token: str) -> tuple[bool, str]:
|
||||
"""
|
||||
更新 token
|
||||
|
||||
Args:
|
||||
new_token: 新的 token (最少 10 位,必须包含大小写字母和特殊符号)
|
||||
|
||||
Returns:
|
||||
tuple[bool, str]: (是否更新成功, 错误消息)
|
||||
"""
|
||||
# 验证新 token 格式
|
||||
is_valid, error_msg = self._validate_custom_token(new_token)
|
||||
if not is_valid:
|
||||
logger.error(f"Token 格式无效: {error_msg}")
|
||||
return False, error_msg
|
||||
|
||||
try:
|
||||
config = self._load_config()
|
||||
old_token = config.get("access_token", "")[:8]
|
||||
|
||||
config["access_token"] = new_token
|
||||
config["updated_at"] = self._get_current_timestamp()
|
||||
|
||||
self._save_config(config)
|
||||
logger.info(f"Token 已更新: {old_token}... -> {new_token[:8]}...")
|
||||
|
||||
return True, "Token 更新成功"
|
||||
except Exception as e:
|
||||
logger.error(f"更新 Token 失败: {e}")
|
||||
return False, f"更新失败: {str(e)}"
|
||||
|
||||
def regenerate_token(self) -> str:
|
||||
"""
|
||||
重新生成 token
|
||||
|
||||
Returns:
|
||||
str: 新生成的 token
|
||||
"""
|
||||
logger.info("正在重新生成 WebUI Token...")
|
||||
return self._create_new_token()
|
||||
|
||||
def _validate_token_format(self, token: str) -> bool:
|
||||
"""
|
||||
验证 token 格式是否正确(旧的 64 位十六进制验证,保留用于系统生成的 token)
|
||||
|
||||
Args:
|
||||
token: 待验证的 token
|
||||
|
||||
Returns:
|
||||
bool: 格式是否正确
|
||||
"""
|
||||
if not token or not isinstance(token, str):
|
||||
return False
|
||||
|
||||
# 必须是 64 位十六进制字符串
|
||||
if len(token) != 64:
|
||||
return False
|
||||
|
||||
# 验证是否为有效的十六进制字符串
|
||||
try:
|
||||
int(token, 16)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
def _validate_custom_token(self, token: str) -> tuple[bool, str]:
|
||||
"""
|
||||
验证自定义 token 格式
|
||||
|
||||
要求:
|
||||
- 最少 10 位
|
||||
- 包含大写字母
|
||||
- 包含小写字母
|
||||
- 包含特殊符号
|
||||
|
||||
Args:
|
||||
token: 待验证的 token
|
||||
|
||||
Returns:
|
||||
tuple[bool, str]: (是否有效, 错误消息)
|
||||
"""
|
||||
if not token or not isinstance(token, str):
|
||||
return False, "Token 不能为空"
|
||||
|
||||
# 检查长度
|
||||
if len(token) < 10:
|
||||
return False, "Token 长度至少为 10 位"
|
||||
|
||||
# 检查是否包含大写字母
|
||||
has_upper = any(c.isupper() for c in token)
|
||||
if not has_upper:
|
||||
return False, "Token 必须包含大写字母"
|
||||
|
||||
# 检查是否包含小写字母
|
||||
has_lower = any(c.islower() for c in token)
|
||||
if not has_lower:
|
||||
return False, "Token 必须包含小写字母"
|
||||
|
||||
# 检查是否包含特殊符号
|
||||
special_chars = "!@#$%^&*()_+-=[]{}|;:,.<>?/"
|
||||
has_special = any(c in special_chars for c in token)
|
||||
if not has_special:
|
||||
return False, f"Token 必须包含特殊符号 ({special_chars})"
|
||||
|
||||
return True, "Token 格式正确"
|
||||
|
||||
|
||||
# 全局单例
|
||||
_token_manager_instance: Optional[TokenManager] = None
|
||||
|
||||
|
||||
def get_token_manager() -> TokenManager:
|
||||
"""获取 TokenManager 单例"""
|
||||
global _token_manager_instance
|
||||
if _token_manager_instance is None:
|
||||
_token_manager_instance = TokenManager()
|
||||
return _token_manager_instance
|
||||
@@ -1,5 +1,5 @@
|
||||
[inner]
|
||||
version = "6.19.2"
|
||||
version = "6.21.4"
|
||||
|
||||
#----以下是给开发人员阅读的,如果你只是部署了麦麦,不需要阅读----
|
||||
#如果你想要修改配置文件,请递增version的值
|
||||
@@ -58,10 +58,6 @@ states = [
|
||||
state_probability = 0.3
|
||||
|
||||
[expression]
|
||||
# 表达方式模式
|
||||
mode = "classic"
|
||||
# 可选:classic经典模式,exp_model 表达模型模式,这个模式需要一定时间学习才会有比较好的效果
|
||||
|
||||
# 表达学习配置
|
||||
learning_list = [ # 表达学习配置列表,支持按聊天流配置
|
||||
["", "enable", "enable", "1.0"], # 全局配置:使用表达,启用学习,学习强度1.0
|
||||
@@ -89,11 +85,9 @@ expression_groups = [
|
||||
talk_value = 1 #聊天频率,越小越沉默,范围0-1
|
||||
mentioned_bot_reply = true # 是否启用提及必回复
|
||||
max_context_size = 30 # 上下文长度
|
||||
auto_chat_value = 1 # 自动聊天,越小,麦麦主动聊天的概率越低
|
||||
planner_smooth = 5 #规划器平滑,增大数值会减小planner负荷,略微降低反应速度,推荐2-8,0为关闭,必须大于等于0
|
||||
planner_smooth = 2 #规划器平滑,增大数值会减小planner负荷,略微降低反应速度,推荐1-5,0为关闭,必须大于等于0
|
||||
|
||||
enable_talk_value_rules = true # 是否启用动态发言频率规则
|
||||
enable_auto_chat_value_rules = false # 是否启用动态自动聊天频率规则
|
||||
|
||||
# 动态发言频率规则:按时段/按chat_id调整 talk_value(优先匹配具体chat,再匹配全局)
|
||||
# 推荐格式(对象数组):{ target="platform:id:type" 或 "", time="HH:MM-HH:MM", value=0.5 }
|
||||
@@ -107,23 +101,13 @@ talk_value_rules = [
|
||||
{ target = "qq:114514:private", time = "00:00-23:59", value = 0.3 },
|
||||
]
|
||||
|
||||
# 动态自动聊天频率规则:按时段/按chat_id调整 auto_chat_value(优先匹配具体chat,再匹配全局)
|
||||
# 推荐格式(对象数组):{ target="platform:id:type" 或 "", time="HH:MM-HH:MM", value=0.5 }
|
||||
# 说明:
|
||||
# - target 为空字符串表示全局;type 为 group/private,例如:"qq:1919810:group" 或 "qq:114514:private";
|
||||
# - 支持跨夜区间,例如 "23:00-02:00";数值范围建议 0-1。
|
||||
auto_chat_value_rules = [
|
||||
{ target = "", time = "00:00-08:59", value = 0.3 },
|
||||
{ target = "", time = "09:00-22:59", value = 1.0 },
|
||||
{ target = "qq:1919810:group", time = "20:00-23:59", value = 0.8 },
|
||||
{ target = "qq:114514:private", time = "00:00-23:59", value = 0.5 },
|
||||
]
|
||||
include_planner_reasoning = false # 是否将planner推理加入replyer,默认关闭(不加入)
|
||||
|
||||
[memory]
|
||||
max_memory_number = 100 # 记忆最大数量
|
||||
max_memory_size = 2048 # 记忆最大大小
|
||||
memory_build_frequency = 1 # 记忆构建频率
|
||||
max_agent_iterations = 5 # 记忆思考深度(最低为1(不深入思考))
|
||||
|
||||
[jargon]
|
||||
all_global = true # 是否开启全局黑话模式,注意,此功能关闭后,已经记录的全局黑话不会改变,需要手动删除
|
||||
|
||||
[tool]
|
||||
enable_tool = true # 是否启用工具
|
||||
@@ -161,6 +145,8 @@ ban_msgs_regex = [
|
||||
|
||||
[lpmm_knowledge] # lpmm知识库配置
|
||||
enable = false # 是否启用lpmm知识库
|
||||
lpmm_mode = "agent"
|
||||
# 可选:classic经典模式,agent 模式,结合最新的记忆一同使用
|
||||
rag_synonym_search_top_k = 10 # 同义词搜索TopK
|
||||
rag_synonym_threshold = 0.8 # 同义词阈值(相似度高于此阈值的词语会被认为是同义词)
|
||||
info_extraction_workers = 3 # 实体提取同时执行线程数,非Pro模型不要设置超过5
|
||||
@@ -206,6 +192,7 @@ enable = true # 是否启用回复分割器
|
||||
max_length = 512 # 回复允许的最大长度
|
||||
max_sentence_num = 8 # 回复允许的最大句子数
|
||||
enable_kaomoji_protection = false # 是否启用颜文字保护
|
||||
enable_overflow_return_all = false # 是否在句子数量超出回复允许的最大句子数时一次性返回全部内容
|
||||
|
||||
[log]
|
||||
date_style = "m-d H:i:s" # 日期格式
|
||||
@@ -223,6 +210,7 @@ library_log_levels = { aiohttp = "WARNING"} # 设置特定库的日志级别
|
||||
show_prompt = false # 是否显示prompt
|
||||
show_replyer_prompt = false # 是否显示回复器prompt
|
||||
show_replyer_reasoning = false # 是否显示回复器推理
|
||||
show_jargon_prompt = false # 是否显示jargon相关提示词
|
||||
|
||||
[maim_message]
|
||||
auth_token = [] # 认证令牌,用于API验证,为空则不启用验证
|
||||
@@ -239,9 +227,18 @@ key_file = "" # SSL密钥文件路径,仅在use_wss=true时有效
|
||||
enable = true
|
||||
|
||||
[experimental] #实验性功能
|
||||
none = false # 暂无
|
||||
# 为指定聊天添加额外的prompt配置
|
||||
# 格式: ["platform:id:type:prompt内容", ...]
|
||||
# 示例:
|
||||
# chat_prompts = [
|
||||
# "qq:114514:group:这是一个摄影群,你精通摄影知识",
|
||||
# "qq:19198:group:这是一个二次元交流群",
|
||||
# "qq:114514:private:这是你与好朋友的私聊"
|
||||
# ]
|
||||
chat_prompts = []
|
||||
|
||||
|
||||
#此系统暂时移除,无效配置
|
||||
[relationship]
|
||||
enable_relationship = true # 是否启用关系系统
|
||||
enable_relationship = true # 是否启用关系系统
|
||||
|
||||
|
||||
@@ -103,23 +103,24 @@ price_in = 0
|
||||
price_out = 0
|
||||
|
||||
|
||||
|
||||
[model_task_config.utils] # 在麦麦的一些组件中使用的模型,例如表情包模块,取名模块,关系模块,麦麦的情绪变化等,是麦麦必须的模型
|
||||
model_list = ["siliconflow-deepseek-v3.2"] # 使用的模型列表,每个子项对应上面的模型名称(name)
|
||||
temperature = 0.2 # 模型温度,新V3建议0.1-0.3
|
||||
max_tokens = 2048 # 最大输出token数
|
||||
|
||||
[model_task_config.utils_small] # 在麦麦的一些组件中使用的小模型,消耗量较大,建议使用速度较快的小模型
|
||||
model_list = ["qwen3-30b"]
|
||||
model_list = ["qwen3-30b","qwen3-next-80b"]
|
||||
temperature = 0.7
|
||||
max_tokens = 2048
|
||||
|
||||
[model_task_config.tool_use] #工具调用模型,需要使用支持工具调用的模型
|
||||
model_list = ["qwen3-30b"]
|
||||
model_list = ["qwen3-30b","qwen3-next-80b"]
|
||||
temperature = 0.7
|
||||
max_tokens = 800
|
||||
|
||||
[model_task_config.replyer] # 首要回复模型,还用于表达器和表达方式学习
|
||||
model_list = ["siliconflow-deepseek-v3.2-think","siliconflow-deepseek-r1","siliconflow-deepseek-v3.2"]
|
||||
model_list = ["siliconflow-deepseek-v3.2-think","siliconflow-glm-4.6-think","siliconflow-glm-4.6"]
|
||||
temperature = 0.3 # 模型温度,新V3建议0.1-0.3
|
||||
max_tokens = 2048
|
||||
|
||||
|
||||
@@ -1,2 +1,7 @@
|
||||
HOST=127.0.0.1
|
||||
PORT=8000
|
||||
PORT=8000
|
||||
|
||||
# WebUI 配置
|
||||
# WEBUI_ENABLED=true
|
||||
# WEBUI_MODE=development # 开发模式(需手动启动前端: cd webui && npm run dev,端口 7999)
|
||||
# WEBUI_MODE=production # 生产模式(需先构建前端: cd webui && npm run build)
|
||||
@@ -1,391 +0,0 @@
|
||||
"""
|
||||
StyleLearner 数据库测试脚本
|
||||
使用数据库中的expression数据测试style_learner功能
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from typing import List, Dict, Tuple
|
||||
from sklearn.model_selection import train_test_split
|
||||
from sklearn.metrics import precision_recall_fscore_support
|
||||
|
||||
# 添加项目根目录到Python路径
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from src.common.database.database_model import Expression, db
|
||||
from src.express.style_learner import StyleLearnerManager
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("style_learner_test")
|
||||
|
||||
|
||||
class StyleLearnerDatabaseTest:
|
||||
"""使用数据库数据测试StyleLearner"""
|
||||
|
||||
def __init__(self, random_state: int = 42):
|
||||
self.random_state = random_state
|
||||
self.manager = StyleLearnerManager(model_save_path="data/test_style_models")
|
||||
|
||||
# 测试结果
|
||||
self.test_results = {
|
||||
"total_samples": 0,
|
||||
"train_samples": 0,
|
||||
"test_samples": 0,
|
||||
"unique_styles": 0,
|
||||
"unique_chat_ids": 0,
|
||||
"accuracy": 0.0,
|
||||
"precision": 0.0,
|
||||
"recall": 0.0,
|
||||
"f1_score": 0.0,
|
||||
"predictions": [],
|
||||
"ground_truth": [],
|
||||
"model_save_success": False,
|
||||
"model_save_path": self.manager.model_save_path
|
||||
}
|
||||
|
||||
def load_data_from_database(self) -> List[Dict]:
|
||||
"""
|
||||
从数据库加载expression数据
|
||||
|
||||
Returns:
|
||||
List[Dict]: 包含up_content, style, chat_id的数据列表
|
||||
"""
|
||||
try:
|
||||
# 连接数据库
|
||||
db.connect(reuse_if_open=True)
|
||||
|
||||
# 查询所有expression数据
|
||||
expressions = Expression.select().where(
|
||||
(Expression.up_content.is_null(False)) &
|
||||
(Expression.style.is_null(False)) &
|
||||
(Expression.chat_id.is_null(False)) &
|
||||
(Expression.type == "style")
|
||||
)
|
||||
|
||||
data = []
|
||||
for expr in expressions:
|
||||
if expr.up_content and expr.style and expr.chat_id:
|
||||
data.append({
|
||||
"up_content": expr.up_content,
|
||||
"style": expr.style,
|
||||
"chat_id": expr.chat_id,
|
||||
"last_active_time": expr.last_active_time,
|
||||
"context": expr.context,
|
||||
"situation": expr.situation
|
||||
})
|
||||
|
||||
logger.info(f"从数据库加载了 {len(data)} 条expression数据")
|
||||
return data
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"从数据库加载数据失败: {e}")
|
||||
return []
|
||||
|
||||
def preprocess_data(self, data: List[Dict]) -> List[Dict]:
|
||||
"""
|
||||
数据预处理
|
||||
|
||||
Args:
|
||||
data: 原始数据
|
||||
|
||||
Returns:
|
||||
List[Dict]: 预处理后的数据
|
||||
"""
|
||||
# 过滤掉空值或过短的数据
|
||||
filtered_data = []
|
||||
for item in data:
|
||||
up_content = item["up_content"].strip()
|
||||
style = item["style"].strip()
|
||||
|
||||
if len(up_content) >= 2 and len(style) >= 2:
|
||||
filtered_data.append({
|
||||
"up_content": up_content,
|
||||
"style": style,
|
||||
"chat_id": item["chat_id"],
|
||||
"last_active_time": item["last_active_time"],
|
||||
"context": item["context"],
|
||||
"situation": item["situation"]
|
||||
})
|
||||
|
||||
logger.info(f"预处理后剩余 {len(filtered_data)} 条数据")
|
||||
return filtered_data
|
||||
|
||||
def split_data(self, data: List[Dict]) -> Tuple[List[Dict], List[Dict]]:
|
||||
"""
|
||||
分割训练集和测试集
|
||||
训练集使用所有数据,测试集从训练集中随机选择5%
|
||||
|
||||
Args:
|
||||
data: 预处理后的数据
|
||||
|
||||
Returns:
|
||||
Tuple[List[Dict], List[Dict]]: (训练集, 测试集)
|
||||
"""
|
||||
# 训练集使用所有数据
|
||||
train_data = data.copy()
|
||||
|
||||
# 测试集从训练集中随机选择5%
|
||||
test_size = 0.05 # 5%
|
||||
test_data = train_test_split(
|
||||
train_data, test_size=test_size, random_state=self.random_state
|
||||
)[1] # 只取测试集部分
|
||||
|
||||
logger.info(f"数据分割完成: 训练集 {len(train_data)} 条, 测试集 {len(test_data)} 条")
|
||||
logger.info(f"训练集使用所有数据,测试集从训练集中随机选择 {test_size*100:.1f}%")
|
||||
return train_data, test_data
|
||||
|
||||
def train_model(self, train_data: List[Dict]) -> None:
|
||||
"""
|
||||
训练模型
|
||||
|
||||
Args:
|
||||
train_data: 训练数据
|
||||
"""
|
||||
logger.info("开始训练模型...")
|
||||
|
||||
# 统计信息
|
||||
chat_ids = set()
|
||||
styles = set()
|
||||
|
||||
for item in train_data:
|
||||
chat_id = item["chat_id"]
|
||||
up_content = item["up_content"]
|
||||
style = item["style"]
|
||||
|
||||
chat_ids.add(chat_id)
|
||||
styles.add(style)
|
||||
|
||||
# 学习映射关系
|
||||
success = self.manager.learn_mapping(chat_id, up_content, style)
|
||||
if not success:
|
||||
logger.warning(f"学习失败: {chat_id} - {up_content} -> {style}")
|
||||
|
||||
self.test_results["train_samples"] = len(train_data)
|
||||
self.test_results["unique_styles"] = len(styles)
|
||||
self.test_results["unique_chat_ids"] = len(chat_ids)
|
||||
|
||||
logger.info(f"训练完成: {len(train_data)} 个样本, {len(styles)} 种风格, {len(chat_ids)} 个聊天室")
|
||||
|
||||
# 保存训练好的模型
|
||||
logger.info("开始保存训练好的模型...")
|
||||
save_success = self.manager.save_all_models()
|
||||
self.test_results["model_save_success"] = save_success
|
||||
|
||||
if save_success:
|
||||
logger.info(f"所有模型已成功保存到: {self.manager.model_save_path}")
|
||||
print(f"✅ 模型已保存到: {self.manager.model_save_path}")
|
||||
else:
|
||||
logger.warning("部分模型保存失败")
|
||||
print(f"⚠️ 模型保存失败,请检查路径: {self.manager.model_save_path}")
|
||||
|
||||
def test_model(self, test_data: List[Dict]) -> None:
|
||||
"""
|
||||
测试模型
|
||||
|
||||
Args:
|
||||
test_data: 测试数据
|
||||
"""
|
||||
logger.info("开始测试模型...")
|
||||
|
||||
predictions = []
|
||||
ground_truth = []
|
||||
correct_predictions = 0
|
||||
|
||||
for item in test_data:
|
||||
chat_id = item["chat_id"]
|
||||
up_content = item["up_content"]
|
||||
true_style = item["style"]
|
||||
|
||||
# 预测风格
|
||||
predicted_style, scores = self.manager.predict_style(chat_id, up_content, top_k=1)
|
||||
|
||||
predictions.append(predicted_style)
|
||||
ground_truth.append(true_style)
|
||||
|
||||
# 检查预测是否正确
|
||||
if predicted_style == true_style:
|
||||
correct_predictions += 1
|
||||
|
||||
# 记录详细预测结果
|
||||
self.test_results["predictions"].append({
|
||||
"chat_id": chat_id,
|
||||
"up_content": up_content,
|
||||
"true_style": true_style,
|
||||
"predicted_style": predicted_style,
|
||||
"scores": scores
|
||||
})
|
||||
|
||||
# 计算准确率
|
||||
accuracy = correct_predictions / len(test_data) if test_data else 0
|
||||
|
||||
# 计算其他指标(需要处理None值)
|
||||
valid_predictions = [p for p in predictions if p is not None]
|
||||
valid_ground_truth = [gt for p, gt in zip(predictions, ground_truth, strict=False) if p is not None]
|
||||
|
||||
if valid_predictions:
|
||||
precision, recall, f1, _ = precision_recall_fscore_support(
|
||||
valid_ground_truth, valid_predictions, average='weighted', zero_division=0
|
||||
)
|
||||
else:
|
||||
precision = recall = f1 = 0.0
|
||||
|
||||
self.test_results["test_samples"] = len(test_data)
|
||||
self.test_results["accuracy"] = accuracy
|
||||
self.test_results["precision"] = precision
|
||||
self.test_results["recall"] = recall
|
||||
self.test_results["f1_score"] = f1
|
||||
|
||||
logger.info(f"测试完成: 准确率 {accuracy:.4f}, 精确率 {precision:.4f}, 召回率 {recall:.4f}, F1分数 {f1:.4f}")
|
||||
|
||||
def analyze_results(self) -> None:
|
||||
"""分析测试结果"""
|
||||
logger.info("=== 测试结果分析 ===")
|
||||
|
||||
print("\n📊 数据统计:")
|
||||
print(f" 总样本数: {self.test_results['total_samples']}")
|
||||
print(f" 训练样本数: {self.test_results['train_samples']}")
|
||||
print(f" 测试样本数: {self.test_results['test_samples']}")
|
||||
print(f" 唯一风格数: {self.test_results['unique_styles']}")
|
||||
print(f" 唯一聊天室数: {self.test_results['unique_chat_ids']}")
|
||||
|
||||
print("\n🎯 模型性能:")
|
||||
print(f" 准确率: {self.test_results['accuracy']:.4f}")
|
||||
print(f" 精确率: {self.test_results['precision']:.4f}")
|
||||
print(f" 召回率: {self.test_results['recall']:.4f}")
|
||||
print(f" F1分数: {self.test_results['f1_score']:.4f}")
|
||||
|
||||
print("\n💾 模型保存:")
|
||||
save_status = "成功" if self.test_results['model_save_success'] else "失败"
|
||||
print(f" 保存状态: {save_status}")
|
||||
print(f" 保存路径: {self.test_results['model_save_path']}")
|
||||
|
||||
# 分析各聊天室的性能
|
||||
chat_performance = {}
|
||||
for pred in self.test_results["predictions"]:
|
||||
chat_id = pred["chat_id"]
|
||||
if chat_id not in chat_performance:
|
||||
chat_performance[chat_id] = {"correct": 0, "total": 0}
|
||||
|
||||
chat_performance[chat_id]["total"] += 1
|
||||
if pred["predicted_style"] == pred["true_style"]:
|
||||
chat_performance[chat_id]["correct"] += 1
|
||||
|
||||
print("\n📈 各聊天室性能:")
|
||||
for chat_id, perf in chat_performance.items():
|
||||
accuracy = perf["correct"] / perf["total"] if perf["total"] > 0 else 0
|
||||
print(f" {chat_id}: {accuracy:.4f} ({perf['correct']}/{perf['total']})")
|
||||
|
||||
# 分析风格分布
|
||||
style_counts = {}
|
||||
for pred in self.test_results["predictions"]:
|
||||
style = pred["true_style"]
|
||||
style_counts[style] = style_counts.get(style, 0) + 1
|
||||
|
||||
print("\n🎨 风格分布 (前10个):")
|
||||
sorted_styles = sorted(style_counts.items(), key=lambda x: x[1], reverse=True)
|
||||
for style, count in sorted_styles[:10]:
|
||||
print(f" {style}: {count} 次")
|
||||
|
||||
def show_sample_predictions(self, num_samples: int = 10) -> None:
|
||||
"""显示样本预测结果"""
|
||||
print(f"\n🔍 样本预测结果 (前{num_samples}个):")
|
||||
|
||||
for i, pred in enumerate(self.test_results["predictions"][:num_samples]):
|
||||
status = "✓" if pred["predicted_style"] == pred["true_style"] else "✗"
|
||||
print(f"\n {i+1}. {status}")
|
||||
print(f" 聊天室: {pred['chat_id']}")
|
||||
print(f" 输入内容: {pred['up_content']}")
|
||||
print(f" 真实风格: {pred['true_style']}")
|
||||
print(f" 预测风格: {pred['predicted_style']}")
|
||||
if pred["scores"]:
|
||||
top_scores = dict(list(pred["scores"].items())[:3])
|
||||
print(f" 分数: {top_scores}")
|
||||
|
||||
def save_results(self, output_file: str = "style_learner_test_results.txt") -> None:
|
||||
"""保存测试结果到文件"""
|
||||
try:
|
||||
with open(output_file, "w", encoding="utf-8") as f:
|
||||
f.write("StyleLearner 数据库测试结果\n")
|
||||
f.write("=" * 50 + "\n\n")
|
||||
|
||||
f.write("数据统计:\n")
|
||||
f.write(f" 总样本数: {self.test_results['total_samples']}\n")
|
||||
f.write(f" 训练样本数: {self.test_results['train_samples']}\n")
|
||||
f.write(f" 测试样本数: {self.test_results['test_samples']}\n")
|
||||
f.write(f" 唯一风格数: {self.test_results['unique_styles']}\n")
|
||||
f.write(f" 唯一聊天室数: {self.test_results['unique_chat_ids']}\n\n")
|
||||
|
||||
f.write("模型性能:\n")
|
||||
f.write(f" 准确率: {self.test_results['accuracy']:.4f}\n")
|
||||
f.write(f" 精确率: {self.test_results['precision']:.4f}\n")
|
||||
f.write(f" 召回率: {self.test_results['recall']:.4f}\n")
|
||||
f.write(f" F1分数: {self.test_results['f1_score']:.4f}\n\n")
|
||||
|
||||
f.write("模型保存:\n")
|
||||
save_status = "成功" if self.test_results['model_save_success'] else "失败"
|
||||
f.write(f" 保存状态: {save_status}\n")
|
||||
f.write(f" 保存路径: {self.test_results['model_save_path']}\n\n")
|
||||
|
||||
f.write("详细预测结果:\n")
|
||||
for i, pred in enumerate(self.test_results["predictions"]):
|
||||
status = "✓" if pred["predicted_style"] == pred["true_style"] else "✗"
|
||||
f.write(f"{i+1}. {status} [{pred['chat_id']}] {pred['up_content']} -> {pred['predicted_style']} (真实: {pred['true_style']})\n")
|
||||
|
||||
logger.info(f"测试结果已保存到 {output_file}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"保存测试结果失败: {e}")
|
||||
|
||||
def run_test(self) -> None:
|
||||
"""运行完整测试"""
|
||||
logger.info("开始StyleLearner数据库测试...")
|
||||
|
||||
# 1. 加载数据
|
||||
raw_data = self.load_data_from_database()
|
||||
if not raw_data:
|
||||
logger.error("没有加载到数据,测试终止")
|
||||
return
|
||||
|
||||
# 2. 数据预处理
|
||||
processed_data = self.preprocess_data(raw_data)
|
||||
if not processed_data:
|
||||
logger.error("预处理后没有数据,测试终止")
|
||||
return
|
||||
|
||||
self.test_results["total_samples"] = len(processed_data)
|
||||
|
||||
# 3. 分割数据
|
||||
train_data, test_data = self.split_data(processed_data)
|
||||
|
||||
# 4. 训练模型
|
||||
self.train_model(train_data)
|
||||
|
||||
# 5. 测试模型
|
||||
self.test_model(test_data)
|
||||
|
||||
# 6. 分析结果
|
||||
self.analyze_results()
|
||||
|
||||
# 7. 显示样本预测
|
||||
self.show_sample_predictions(10)
|
||||
|
||||
# 8. 保存结果
|
||||
self.save_results()
|
||||
|
||||
logger.info("StyleLearner数据库测试完成!")
|
||||
|
||||
|
||||
def main():
|
||||
"""主函数"""
|
||||
print("StyleLearner 数据库测试脚本")
|
||||
print("=" * 50)
|
||||
|
||||
# 创建测试实例
|
||||
test = StyleLearnerDatabaseTest(random_state=42)
|
||||
|
||||
# 运行测试
|
||||
test.run_test()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
76
view_pkl.py
76
view_pkl.py
@@ -1,76 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
查看 .pkl 文件内容的工具脚本
|
||||
"""
|
||||
|
||||
import pickle
|
||||
import sys
|
||||
import os
|
||||
from pprint import pprint
|
||||
|
||||
def view_pkl_file(file_path):
|
||||
"""查看 pkl 文件内容"""
|
||||
if not os.path.exists(file_path):
|
||||
print(f"❌ 文件不存在: {file_path}")
|
||||
return
|
||||
|
||||
try:
|
||||
with open(file_path, 'rb') as f:
|
||||
data = pickle.load(f)
|
||||
|
||||
print(f"📁 文件: {file_path}")
|
||||
print(f"📊 数据类型: {type(data)}")
|
||||
print("=" * 50)
|
||||
|
||||
if isinstance(data, dict):
|
||||
print("🔑 字典键:")
|
||||
for key in data.keys():
|
||||
print(f" - {key}: {type(data[key])}")
|
||||
print()
|
||||
|
||||
print("📋 详细内容:")
|
||||
pprint(data, width=120, depth=10)
|
||||
|
||||
elif isinstance(data, list):
|
||||
print(f"📝 列表长度: {len(data)}")
|
||||
if data:
|
||||
print(f"📊 第一个元素类型: {type(data[0])}")
|
||||
print("📋 前几个元素:")
|
||||
for i, item in enumerate(data[:3]):
|
||||
print(f" [{i}]: {item}")
|
||||
|
||||
else:
|
||||
print("📋 内容:")
|
||||
pprint(data, width=120, depth=10)
|
||||
|
||||
# 如果是 expressor 模型,特别显示 token_counts 的详细信息
|
||||
if isinstance(data, dict) and 'nb' in data and 'token_counts' in data['nb']:
|
||||
print("\n" + "="*50)
|
||||
print("🔍 详细词汇统计 (token_counts):")
|
||||
token_counts = data['nb']['token_counts']
|
||||
for style_id, tokens in token_counts.items():
|
||||
print(f"\n📝 {style_id}:")
|
||||
if tokens:
|
||||
# 按词频排序显示前10个词
|
||||
sorted_tokens = sorted(tokens.items(), key=lambda x: x[1], reverse=True)
|
||||
for word, count in sorted_tokens[:10]:
|
||||
print(f" '{word}': {count}")
|
||||
if len(sorted_tokens) > 10:
|
||||
print(f" ... 还有 {len(sorted_tokens) - 10} 个词")
|
||||
else:
|
||||
print(" (无词汇数据)")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ 读取文件失败: {e}")
|
||||
|
||||
def main():
|
||||
if len(sys.argv) != 2:
|
||||
print("用法: python view_pkl.py <pkl文件路径>")
|
||||
print("示例: python view_pkl.py data/test_style_models/chat_001_style_model.pkl")
|
||||
return
|
||||
|
||||
file_path = sys.argv[1]
|
||||
view_pkl_file(file_path)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,63 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
专门查看 expressor.pkl 文件中 token_counts 的脚本
|
||||
"""
|
||||
|
||||
import pickle
|
||||
import sys
|
||||
import os
|
||||
|
||||
def view_token_counts(file_path):
|
||||
"""查看 expressor.pkl 文件中的词汇统计"""
|
||||
if not os.path.exists(file_path):
|
||||
print(f"❌ 文件不存在: {file_path}")
|
||||
return
|
||||
|
||||
try:
|
||||
with open(file_path, 'rb') as f:
|
||||
data = pickle.load(f)
|
||||
|
||||
print(f"📁 文件: {file_path}")
|
||||
print("=" * 60)
|
||||
|
||||
if 'nb' not in data or 'token_counts' not in data['nb']:
|
||||
print("❌ 这不是一个 expressor 模型文件")
|
||||
return
|
||||
|
||||
token_counts = data['nb']['token_counts']
|
||||
candidates = data.get('candidates', {})
|
||||
|
||||
print(f"🎯 找到 {len(token_counts)} 个风格")
|
||||
print("=" * 60)
|
||||
|
||||
for style_id, tokens in token_counts.items():
|
||||
style_text = candidates.get(style_id, "未知风格")
|
||||
print(f"\n📝 {style_id}: {style_text}")
|
||||
print(f"📊 词汇数量: {len(tokens)}")
|
||||
|
||||
if tokens:
|
||||
# 按词频排序
|
||||
sorted_tokens = sorted(tokens.items(), key=lambda x: x[1], reverse=True)
|
||||
|
||||
print("🔤 词汇统计 (按频率排序):")
|
||||
for i, (word, count) in enumerate(sorted_tokens):
|
||||
print(f" {i+1:2d}. '{word}': {count}")
|
||||
else:
|
||||
print(" (无词汇数据)")
|
||||
|
||||
print("-" * 40)
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ 读取文件失败: {e}")
|
||||
|
||||
def main():
|
||||
if len(sys.argv) != 2:
|
||||
print("用法: python view_tokens.py <expressor.pkl文件路径>")
|
||||
print("示例: python view_tokens.py data/test_style_models/chat_001_expressor.pkl")
|
||||
return
|
||||
|
||||
file_path = sys.argv[1]
|
||||
view_token_counts(file_path)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user