Merge branch 'Mai-with-u:dev' into dev
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@@ -323,7 +323,7 @@ run_pet.bat
|
||||
!/plugins/emoji_manage_plugin
|
||||
!/plugins/take_picture_plugin
|
||||
!/plugins/deep_think
|
||||
!/plugins/MaiFrequencyControl
|
||||
!/plugins/ChatFrequency/
|
||||
!/plugins/__init__.py
|
||||
|
||||
config.toml
|
||||
|
||||
@@ -44,7 +44,7 @@
|
||||
|
||||
## 🔥 更新和安装
|
||||
|
||||
**最新版本: v0.10.3** ([更新日志](changelogs/changelog.md))
|
||||
**最新版本: v0.11.0** ([更新日志](changelogs/changelog.md))
|
||||
|
||||
可前往 [Release](https://github.com/MaiM-with-u/MaiBot/releases/) 页面下载最新版本
|
||||
可前往 [启动器发布页面](https://github.com/MaiM-with-u/mailauncher/releases/)下载最新启动器
|
||||
|
||||
@@ -1,21 +1,40 @@
|
||||
# Changelog
|
||||
|
||||
## [0.11.0] - 2025-9-22
|
||||
## [0.11.1] - 2025-11-4
|
||||
### 功能更改和修复
|
||||
- 记忆现在能够被遗忘,并且拥有更好的合并
|
||||
- 修复部分llm请求问题
|
||||
- 优化记忆提取
|
||||
- 提供replyer的细节debug配置
|
||||
|
||||
|
||||
## [0.11.0] - 2025-10-27
|
||||
### 🌟 主要功能更改
|
||||
- 重构记忆系统,新的记忆系统更可靠,记忆能力更强大
|
||||
- 麦麦好奇功能,麦麦会自主提出问题
|
||||
- 添加deepthink插件(默认关闭),让麦麦可以深度思考一些问题
|
||||
- 重构记忆系统,新的记忆系统更可靠,双通道查询,可以查询文本记忆和过去聊天记录
|
||||
- 主动发言功能,麦麦会自主提出问题(可精细调控频率)
|
||||
- 支持多重人格设定,可以随机切换成不同状态
|
||||
- 新增表达方式学习新模式,更少的占用
|
||||
- 添加表情包管理插件
|
||||
- 现可更好的支持多平台
|
||||
- 添加deepthink插件(默认关闭),让麦麦可以深度思考一些问题
|
||||
- 现已内置BetterFrequency插件
|
||||
|
||||
|
||||
### 细节功能更改
|
||||
- 修复配置文件转义问题
|
||||
- 情绪系统现在可以由配置文件控制开关
|
||||
- 修复平行动作控制失效的问题
|
||||
- 添加planner防抖,防止短时间快速消耗token
|
||||
- 优化planner历史状态记录
|
||||
- 修复吞字问题
|
||||
- 修复意外换行问题
|
||||
- 移除VLM的token限制
|
||||
- 为tool工具添加chat_id字段
|
||||
- 更新依赖表
|
||||
- 修复负载均衡
|
||||
- 优化了对gemini和不同模型的支持
|
||||
- 现统计模型名而不是模型标识符
|
||||
- 修改默认推荐模型为ds v3.2
|
||||
- 优化了对gemini和不同模型的支持,优化了对gemini搜索的支持
|
||||
|
||||
## [0.10.3] - 2025-9-22
|
||||
### 🌟 主要功能更改
|
||||
|
||||
50
plugins/ChatFrequency/_manifest.json
Normal file
50
plugins/ChatFrequency/_manifest.json
Normal file
@@ -0,0 +1,50 @@
|
||||
{
|
||||
"manifest_version": 1,
|
||||
"name": "发言频率控制插件|BetterFrequency Plugin",
|
||||
"version": "2.0.0",
|
||||
"description": "控制聊天频率,支持设置focus_value和talk_frequency调整值,提供命令",
|
||||
"author": {
|
||||
"name": "SengokuCola",
|
||||
"url": "https://github.com/MaiM-with-u"
|
||||
},
|
||||
"license": "GPL-v3.0-or-later",
|
||||
|
||||
"host_application": {
|
||||
"min_version": "0.10.3"
|
||||
},
|
||||
"homepage_url": "https://github.com/SengokuCola/BetterFrequency",
|
||||
"repository_url": "https://github.com/SengokuCola/BetterFrequency",
|
||||
"keywords": ["frequency", "control", "talk_frequency", "plugin", "shortcut"],
|
||||
"categories": ["Chat", "Frequency", "Control"],
|
||||
|
||||
"default_locale": "zh-CN",
|
||||
"locales_path": "_locales",
|
||||
|
||||
"plugin_info": {
|
||||
"is_built_in": false,
|
||||
"plugin_type": "frequency",
|
||||
"components": [
|
||||
{
|
||||
"type": "command",
|
||||
"name": "set_talk_frequency",
|
||||
"description": "设置当前聊天的talk_frequency调整值",
|
||||
"pattern": "/chat talk_frequency <数字> 或 /chat t <数字>"
|
||||
},
|
||||
{
|
||||
"type": "command",
|
||||
"name": "show_frequency",
|
||||
"description": "显示当前聊天的频率控制状态",
|
||||
"pattern": "/chat show 或 /chat s"
|
||||
}
|
||||
],
|
||||
"features": [
|
||||
"设置talk_frequency调整值",
|
||||
"调整当前聊天的发言频率",
|
||||
"显示当前频率控制状态",
|
||||
"实时频率控制调整",
|
||||
"命令执行反馈(不保存消息)",
|
||||
"支持完整命令和简化命令",
|
||||
"快速操作支持"
|
||||
]
|
||||
}
|
||||
}
|
||||
150
plugins/ChatFrequency/plugin.py
Normal file
150
plugins/ChatFrequency/plugin.py
Normal file
@@ -0,0 +1,150 @@
|
||||
from typing import List, Tuple, Type, Optional
|
||||
from src.plugin_system import (
|
||||
BasePlugin,
|
||||
register_plugin,
|
||||
BaseCommand,
|
||||
ComponentInfo,
|
||||
ConfigField
|
||||
)
|
||||
from src.plugin_system.apis import send_api, frequency_api
|
||||
|
||||
class SetTalkFrequencyCommand(BaseCommand):
|
||||
"""设置当前聊天的talk_frequency值"""
|
||||
command_name = "set_talk_frequency"
|
||||
command_description = "设置当前聊天的talk_frequency值:/chat talk_frequency <数字> 或 /chat t <数字>"
|
||||
command_pattern = r"^/chat\s+(?:talk_frequency|t)\s+(?P<value>[+-]?\d*\.?\d+)$"
|
||||
|
||||
async def execute(self) -> Tuple[bool, Optional[str], bool]:
|
||||
try:
|
||||
# 获取命令参数 - 使用命名捕获组
|
||||
if not self.matched_groups or "value" not in self.matched_groups:
|
||||
return False, "命令格式错误", False
|
||||
|
||||
value_str = self.matched_groups["value"]
|
||||
if not value_str:
|
||||
return False, "无法获取数值参数", False
|
||||
|
||||
value = float(value_str)
|
||||
|
||||
# 获取聊天流ID
|
||||
if not self.message.chat_stream or not hasattr(self.message.chat_stream, "stream_id"):
|
||||
return False, "无法获取聊天流信息", False
|
||||
|
||||
chat_id = self.message.chat_stream.stream_id
|
||||
|
||||
# 设置talk_frequency
|
||||
frequency_api.set_talk_frequency_adjust(chat_id, value)
|
||||
|
||||
final_value = frequency_api.get_current_talk_value(chat_id)
|
||||
adjust_value = frequency_api.get_talk_frequency_adjust(chat_id)
|
||||
base_value = final_value / adjust_value
|
||||
|
||||
# 发送反馈消息(不保存到数据库)
|
||||
await send_api.text_to_stream(
|
||||
f"已设置当前聊天的talk_frequency调整值为: {value}\n当前talk_value: {final_value:.2f}\n发言频率调整: {adjust_value:.2f}\n基础值: {base_value:.2f}",
|
||||
chat_id,
|
||||
storage_message=False
|
||||
)
|
||||
|
||||
return True, None, False
|
||||
|
||||
except ValueError:
|
||||
error_msg = "数值格式错误,请输入有效的数字"
|
||||
await self.send_text(error_msg, storage_message=False)
|
||||
return False, error_msg, False
|
||||
except Exception as e:
|
||||
error_msg = f"设置talk_frequency失败: {str(e)}"
|
||||
await self.send_text(error_msg, storage_message=False)
|
||||
return False, error_msg, False
|
||||
|
||||
|
||||
class ShowFrequencyCommand(BaseCommand):
|
||||
"""显示当前聊天的频率控制状态"""
|
||||
command_name = "show_frequency"
|
||||
command_description = "显示当前聊天的频率控制状态:/chat show 或 /chat s"
|
||||
command_pattern = r"^/chat\s+(?:show|s)$"
|
||||
|
||||
async def execute(self) -> Tuple[bool, Optional[str], bool]:
|
||||
try:
|
||||
# 获取聊天流ID
|
||||
if not self.message.chat_stream or not hasattr(self.message.chat_stream, "stream_id"):
|
||||
return False, "无法获取聊天流信息", False
|
||||
|
||||
chat_id = self.message.chat_stream.stream_id
|
||||
|
||||
# 获取当前频率控制状态
|
||||
current_talk_frequency = frequency_api.get_current_talk_value(chat_id)
|
||||
talk_frequency_adjust = frequency_api.get_talk_frequency_adjust(chat_id)
|
||||
base_value = current_talk_frequency / talk_frequency_adjust
|
||||
|
||||
# 构建显示消息
|
||||
status_msg = f"""当前聊天频率控制状态
|
||||
Talk Value (发言频率):
|
||||
|
||||
• 基础值: {base_value:.2f}
|
||||
• 发言频率调整: {talk_frequency_adjust:.2f}
|
||||
• 当前值: {current_talk_frequency:.2f}
|
||||
|
||||
使用命令:
|
||||
• /chat talk_frequency <数字> 或 /chat t <数字> - 设置发言频率调整
|
||||
• /chat show 或 /chat s - 显示当前状态"""
|
||||
|
||||
# 发送状态消息(不保存到数据库)
|
||||
await send_api.text_to_stream(status_msg, chat_id, storage_message=False)
|
||||
|
||||
return True, None, False
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"获取频率控制状态失败: {str(e)}"
|
||||
# 使用内置的send_text方法发送错误消息
|
||||
await self.send_text(error_msg, storage_message=False)
|
||||
return False, error_msg, False
|
||||
|
||||
|
||||
# ===== 插件注册 =====
|
||||
|
||||
|
||||
@register_plugin
|
||||
class BetterFrequencyPlugin(BasePlugin):
|
||||
"""BetterFrequency插件 - 控制聊天频率的插件"""
|
||||
|
||||
# 插件基本信息
|
||||
plugin_name: str = "better_frequency_plugin"
|
||||
enable_plugin: bool = True
|
||||
dependencies: List[str] = []
|
||||
python_dependencies: List[str] = []
|
||||
config_file_name: str = "config.toml"
|
||||
|
||||
# 配置节描述
|
||||
config_section_descriptions = {
|
||||
"plugin": "插件基本信息",
|
||||
"frequency": "频率控制配置",
|
||||
"features": "功能开关配置"
|
||||
}
|
||||
|
||||
# 配置Schema定义
|
||||
config_schema: dict = {
|
||||
"plugin": {
|
||||
"name": ConfigField(type=str, default="better_frequency_plugin", description="插件名称"),
|
||||
"version": ConfigField(type=str, default="1.0.2", description="插件版本"),
|
||||
"enabled": ConfigField(type=bool, default=True, description="是否启用插件"),
|
||||
},
|
||||
"frequency": {
|
||||
"default_talk_adjust": ConfigField(type=float, default=1.0, description="默认talk_frequency调整值"),
|
||||
"max_adjust_value": ConfigField(type=float, default=1.0, description="最大调整值"),
|
||||
"min_adjust_value": ConfigField(type=float, default=0.0, description="最小调整值"),
|
||||
},
|
||||
}
|
||||
|
||||
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
|
||||
components = []
|
||||
|
||||
# 根据配置决定是否注册命令组件
|
||||
if self.config.get("features", {}).get("enable_commands", True):
|
||||
components.extend([
|
||||
(SetTalkFrequencyCommand.get_command_info(), SetTalkFrequencyCommand),
|
||||
(ShowFrequencyCommand.get_command_info(), ShowFrequencyCommand),
|
||||
])
|
||||
|
||||
|
||||
return components
|
||||
@@ -177,6 +177,7 @@ def main(): # sourcery skip: comprehension-to-generator, extract-method
|
||||
elif doc_item:
|
||||
with open_ie_doc_lock:
|
||||
open_ie_doc.append(doc_item)
|
||||
logger.info('已处理"%s"', doc_item.get("passage", ""))
|
||||
progress.update(task, advance=1)
|
||||
except KeyboardInterrupt:
|
||||
logger.info("\n接收到中断信号,正在优雅地关闭程序...")
|
||||
|
||||
@@ -16,11 +16,12 @@ from src.chat.brain_chat.brain_planner import BrainPlanner
|
||||
from src.chat.planner_actions.action_modifier import ActionModifier
|
||||
from src.chat.planner_actions.action_manager import ActionManager
|
||||
from src.chat.heart_flow.hfc_utils import CycleDetail
|
||||
from src.chat.express.expression_learner import expression_learner_manager
|
||||
from src.express.expression_learner import expression_learner_manager
|
||||
from src.person_info.person_info import Person
|
||||
from src.plugin_system.base.component_types import EventType, ActionInfo
|
||||
from src.plugin_system.core import events_manager
|
||||
from src.plugin_system.apis import generator_api, send_api, message_api, database_api
|
||||
from src.memory_system.Memory_chest import global_memory_chest
|
||||
from src.chat.utils.chat_message_builder import (
|
||||
build_readable_messages_with_id,
|
||||
get_raw_msg_before_timestamp_with_chat,
|
||||
@@ -236,7 +237,8 @@ class BrainChatting:
|
||||
_reply_text = "" # 初始化reply_text变量,避免UnboundLocalError
|
||||
|
||||
async with global_prompt_manager.async_message_scope(self.chat_stream.context.get_template_name()):
|
||||
await self.expression_learner.trigger_learning_for_chat()
|
||||
asyncio.create_task(self.expression_learner.trigger_learning_for_chat())
|
||||
asyncio.create_task(global_memory_chest.build_running_content(chat_id=self.stream_id))
|
||||
|
||||
cycle_timers, thinking_id = self.start_cycle()
|
||||
logger.info(f"{self.log_prefix} 开始第{self._cycle_counter}次思考")
|
||||
@@ -414,7 +416,9 @@ class BrainChatting:
|
||||
return False, "", ""
|
||||
|
||||
# 处理动作并获取结果(固定记录一次动作信息)
|
||||
result = await action_handler.run()
|
||||
# BaseAction 定义了异步方法 execute() 作为统一执行入口
|
||||
# 这里调用 execute() 以兼容所有 Action 实现
|
||||
result = await action_handler.execute()
|
||||
success, action_text = result
|
||||
command = ""
|
||||
|
||||
|
||||
@@ -249,6 +249,8 @@ class BrainPlanner:
|
||||
# 获取必要信息
|
||||
is_group_chat, chat_target_info, current_available_actions = self.get_necessary_info()
|
||||
|
||||
# 提及/被@ 的处理由心流或统一判定模块驱动;Planner 不再做硬编码强制回复
|
||||
|
||||
# 应用激活类型过滤
|
||||
filtered_actions = self._filter_actions_by_activation_type(available_actions, chat_content_block_short)
|
||||
|
||||
|
||||
@@ -379,7 +379,7 @@ class EmojiManager:
|
||||
|
||||
self._scan_task = None
|
||||
|
||||
self.vlm = LLMRequest(model_set=model_config.model_task_config.vlm, request_type="emoji")
|
||||
self.vlm = LLMRequest(model_set=model_config.model_task_config.vlm, request_type="emoji.see")
|
||||
self.llm_emotion_judge = LLMRequest(
|
||||
model_set=model_config.model_task_config.utils, request_type="emoji"
|
||||
) # 更高的温度,更少的token(后续可以根据情绪来调整温度)
|
||||
@@ -940,16 +940,16 @@ class EmojiManager:
|
||||
image_base64 = get_image_manager().transform_gif(image_base64) # type: ignore
|
||||
if not image_base64:
|
||||
raise RuntimeError("GIF表情包转换失败")
|
||||
prompt = "这是一个动态图表情包,每一张图代表了动态图的某一帧,黑色背景代表透明,描述一下表情包表达的情感和内容,描述细节,从互联网梗,meme的角度去分析"
|
||||
prompt = "这是一个动态图表情包,每一张图代表了动态图的某一帧,黑色背景代表透明,简短描述一下表情包表达的情感和内容,描述细节,从互联网梗,meme的角度去分析"
|
||||
description, _ = await self.vlm.generate_response_for_image(
|
||||
prompt, image_base64, "jpg", temperature=0.3, max_tokens=1000
|
||||
prompt, image_base64, "jpg", temperature=0.5
|
||||
)
|
||||
else:
|
||||
prompt = (
|
||||
"这是一个表情包,请详细描述一下表情包所表达的情感和内容,描述细节,从互联网梗,meme的角度去分析"
|
||||
"这是一个表情包,请详细描述一下表情包所表达的情感和内容,简短描述细节,从互联网梗,meme的角度去分析"
|
||||
)
|
||||
description, _ = await self.vlm.generate_response_for_image(
|
||||
prompt, image_base64, image_format, temperature=0.3, max_tokens=1000
|
||||
prompt, image_base64, image_format, temperature=0.5
|
||||
)
|
||||
|
||||
# 审核表情包
|
||||
@@ -970,13 +970,14 @@ class EmojiManager:
|
||||
|
||||
# 第二步:LLM情感分析 - 基于详细描述生成情感标签列表
|
||||
emotion_prompt = f"""
|
||||
请你识别这个表情包的含义和适用场景,给我简短的描述,每个描述不要超过15个字
|
||||
这是一个基于这个表情包的描述:'{description}'
|
||||
你可以关注其幽默和讽刺意味,动用贴吧,微博,小红书的知识,必须从互联网梗,meme的角度去分析
|
||||
请直接输出描述,不要出现任何其他内容,如果有多个描述,可以用逗号分隔
|
||||
这是一个聊天场景中的表情包描述:'{description}'
|
||||
|
||||
请你识别这个表情包的含义和适用场景,给我简短的描述,每个描述不要超过15个字
|
||||
你可以关注其幽默和讽刺意味,动用贴吧,微博,小红书的知识,必须从互联网梗,meme的角度去分析
|
||||
请直接输出描述,不要出现任何其他内容,如果有多个描述,可以用逗号分隔
|
||||
"""
|
||||
emotions_text, _ = await self.llm_emotion_judge.generate_response_async(
|
||||
emotion_prompt, temperature=0.7, max_tokens=600
|
||||
emotion_prompt, temperature=0.7, max_tokens=256
|
||||
)
|
||||
|
||||
# 处理情感列表
|
||||
|
||||
@@ -1,316 +0,0 @@
|
||||
import json
|
||||
import time
|
||||
import random
|
||||
import hashlib
|
||||
|
||||
from typing import List, Dict, Optional, Any, Tuple
|
||||
from json_repair import repair_json
|
||||
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config, model_config
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import Expression
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
|
||||
logger = get_logger("expression_selector")
|
||||
|
||||
|
||||
def init_prompt():
|
||||
expression_evaluation_prompt = """
|
||||
以下是正在进行的聊天内容:
|
||||
{chat_observe_info}
|
||||
|
||||
你的名字是{bot_name}{target_message}
|
||||
|
||||
以下是可选的表达情境:
|
||||
{all_situations}
|
||||
|
||||
请你分析聊天内容的语境、情绪、话题类型,从上述情境中选择最适合当前聊天情境的,最多{max_num}个情境。
|
||||
考虑因素包括:
|
||||
1. 聊天的情绪氛围(轻松、严肃、幽默等)
|
||||
2. 话题类型(日常、技术、游戏、情感等)
|
||||
3. 情境与当前语境的匹配度
|
||||
{target_message_extra_block}
|
||||
|
||||
请以JSON格式输出,只需要输出选中的情境编号:
|
||||
例如:
|
||||
{{
|
||||
"selected_situations": [2, 3, 5, 7, 19]
|
||||
}}
|
||||
|
||||
请严格按照JSON格式输出,不要包含其他内容:
|
||||
"""
|
||||
Prompt(expression_evaluation_prompt, "expression_evaluation_prompt")
|
||||
|
||||
|
||||
def weighted_sample(population: List[Dict], weights: List[float], k: int) -> List[Dict]:
|
||||
"""按权重随机抽样"""
|
||||
if not population or not weights or k <= 0:
|
||||
return []
|
||||
|
||||
if len(population) <= k:
|
||||
return population.copy()
|
||||
|
||||
# 使用累积权重的方法进行加权抽样
|
||||
selected = []
|
||||
population_copy = population.copy()
|
||||
weights_copy = weights.copy()
|
||||
|
||||
for _ in range(k):
|
||||
if not population_copy:
|
||||
break
|
||||
|
||||
# 选择一个元素
|
||||
chosen_idx = random.choices(range(len(population_copy)), weights=weights_copy)[0]
|
||||
selected.append(population_copy.pop(chosen_idx))
|
||||
weights_copy.pop(chosen_idx)
|
||||
|
||||
return selected
|
||||
|
||||
|
||||
class ExpressionSelector:
|
||||
def __init__(self):
|
||||
self.llm_model = LLMRequest(
|
||||
model_set=model_config.model_task_config.utils_small, request_type="expression.selector"
|
||||
)
|
||||
|
||||
def can_use_expression_for_chat(self, chat_id: str) -> bool:
|
||||
"""
|
||||
检查指定聊天流是否允许使用表达
|
||||
|
||||
Args:
|
||||
chat_id: 聊天流ID
|
||||
|
||||
Returns:
|
||||
bool: 是否允许使用表达
|
||||
"""
|
||||
try:
|
||||
use_expression, _, _ = global_config.expression.get_expression_config_for_chat(chat_id)
|
||||
return use_expression
|
||||
except Exception as e:
|
||||
logger.error(f"检查表达使用权限失败: {e}")
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _parse_stream_config_to_chat_id(stream_config_str: str) -> Optional[str]:
|
||||
"""解析'platform:id:type'为chat_id(与get_stream_id一致)"""
|
||||
try:
|
||||
parts = stream_config_str.split(":")
|
||||
if len(parts) != 3:
|
||||
return None
|
||||
platform = parts[0]
|
||||
id_str = parts[1]
|
||||
stream_type = parts[2]
|
||||
is_group = stream_type == "group"
|
||||
if is_group:
|
||||
components = [platform, str(id_str)]
|
||||
else:
|
||||
components = [platform, str(id_str), "private"]
|
||||
key = "_".join(components)
|
||||
return hashlib.md5(key.encode()).hexdigest()
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_related_chat_ids(self, chat_id: str) -> List[str]:
|
||||
"""根据expression_groups配置,获取与当前chat_id相关的所有chat_id(包括自身)"""
|
||||
groups = global_config.expression.expression_groups
|
||||
|
||||
# 检查是否存在全局共享组(包含"*"的组)
|
||||
global_group_exists = any("*" in group for group in groups)
|
||||
|
||||
if global_group_exists:
|
||||
# 如果存在全局共享组,则返回所有可用的chat_id
|
||||
all_chat_ids = set()
|
||||
for group in groups:
|
||||
for stream_config_str in group:
|
||||
if chat_id_candidate := self._parse_stream_config_to_chat_id(stream_config_str):
|
||||
all_chat_ids.add(chat_id_candidate)
|
||||
return list(all_chat_ids) if all_chat_ids else [chat_id]
|
||||
|
||||
# 否则使用现有的组逻辑
|
||||
for group in groups:
|
||||
group_chat_ids = []
|
||||
for stream_config_str in group:
|
||||
if chat_id_candidate := self._parse_stream_config_to_chat_id(stream_config_str):
|
||||
group_chat_ids.append(chat_id_candidate)
|
||||
if chat_id in group_chat_ids:
|
||||
return group_chat_ids
|
||||
return [chat_id]
|
||||
|
||||
def get_random_expressions(self, chat_id: str, total_num: int) -> List[Dict[str, Any]]:
|
||||
# sourcery skip: extract-duplicate-method, move-assign
|
||||
# 支持多chat_id合并抽选
|
||||
related_chat_ids = self.get_related_chat_ids(chat_id)
|
||||
|
||||
# 优化:一次性查询所有相关chat_id的表达方式
|
||||
style_query = Expression.select().where(
|
||||
(Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "style")
|
||||
)
|
||||
|
||||
style_exprs = [
|
||||
{
|
||||
"id": expr.id,
|
||||
"situation": expr.situation,
|
||||
"style": expr.style,
|
||||
"count": expr.count,
|
||||
"last_active_time": expr.last_active_time,
|
||||
"source_id": expr.chat_id,
|
||||
"type": "style",
|
||||
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
|
||||
}
|
||||
for expr in style_query
|
||||
]
|
||||
|
||||
# 按权重抽样(使用count作为权重)
|
||||
if style_exprs:
|
||||
style_weights = [expr.get("count", 1) for expr in style_exprs]
|
||||
selected_style = weighted_sample(style_exprs, style_weights, total_num)
|
||||
else:
|
||||
selected_style = []
|
||||
return selected_style
|
||||
|
||||
def update_expressions_count_batch(self, expressions_to_update: List[Dict[str, Any]], increment: float = 0.1):
|
||||
"""对一批表达方式更新count值,按chat_id+type分组后一次性写入数据库"""
|
||||
if not expressions_to_update:
|
||||
return
|
||||
updates_by_key = {}
|
||||
for expr in expressions_to_update:
|
||||
source_id: str = expr.get("source_id") # type: ignore
|
||||
expr_type: str = expr.get("type", "style")
|
||||
situation: str = expr.get("situation") # type: ignore
|
||||
style: str = expr.get("style") # type: ignore
|
||||
if not source_id or not situation or not style:
|
||||
logger.warning(f"表达方式缺少必要字段,无法更新: {expr}")
|
||||
continue
|
||||
key = (source_id, expr_type, situation, style)
|
||||
if key not in updates_by_key:
|
||||
updates_by_key[key] = expr
|
||||
for chat_id, expr_type, situation, style in updates_by_key:
|
||||
query = Expression.select().where(
|
||||
(Expression.chat_id == chat_id)
|
||||
& (Expression.type == expr_type)
|
||||
& (Expression.situation == situation)
|
||||
& (Expression.style == style)
|
||||
)
|
||||
if query.exists():
|
||||
expr_obj = query.get()
|
||||
current_count = expr_obj.count
|
||||
new_count = min(current_count + increment, 5.0)
|
||||
expr_obj.count = new_count
|
||||
expr_obj.last_active_time = time.time()
|
||||
expr_obj.save()
|
||||
logger.debug(
|
||||
f"表达方式激活: 原count={current_count:.3f}, 增量={increment}, 新count={new_count:.3f} in db"
|
||||
)
|
||||
|
||||
async def select_suitable_expressions_llm(
|
||||
self,
|
||||
chat_id: str,
|
||||
chat_info: str,
|
||||
max_num: int = 10,
|
||||
target_message: Optional[str] = None,
|
||||
) -> Tuple[List[Dict[str, Any]], List[int]]:
|
||||
# sourcery skip: inline-variable, list-comprehension
|
||||
"""使用LLM选择适合的表达方式"""
|
||||
|
||||
# 检查是否允许在此聊天流中使用表达
|
||||
if not self.can_use_expression_for_chat(chat_id):
|
||||
logger.debug(f"聊天流 {chat_id} 不允许使用表达,返回空列表")
|
||||
return [], []
|
||||
|
||||
# 1. 获取20个随机表达方式(现在按权重抽取)
|
||||
style_exprs = self.get_random_expressions(chat_id, 20)
|
||||
|
||||
if len(style_exprs) < 10:
|
||||
logger.info(f"聊天流 {chat_id} 表达方式正在积累中")
|
||||
return [], []
|
||||
|
||||
# 2. 构建所有表达方式的索引和情境列表
|
||||
all_expressions: List[Dict[str, Any]] = []
|
||||
all_situations: List[str] = []
|
||||
|
||||
# 添加style表达方式
|
||||
for expr in style_exprs:
|
||||
expr = expr.copy()
|
||||
all_expressions.append(expr)
|
||||
all_situations.append(f"{len(all_expressions)}.当 {expr['situation']} 时,使用 {expr['style']}")
|
||||
|
||||
if not all_expressions:
|
||||
logger.warning("没有找到可用的表达方式")
|
||||
return [], []
|
||||
|
||||
all_situations_str = "\n".join(all_situations)
|
||||
|
||||
if target_message:
|
||||
target_message_str = f",现在你想要回复消息:{target_message}"
|
||||
target_message_extra_block = "4.考虑你要回复的目标消息"
|
||||
else:
|
||||
target_message_str = ""
|
||||
target_message_extra_block = ""
|
||||
|
||||
# 3. 构建prompt(只包含情境,不包含完整的表达方式)
|
||||
prompt = (await global_prompt_manager.get_prompt_async("expression_evaluation_prompt")).format(
|
||||
bot_name=global_config.bot.nickname,
|
||||
chat_observe_info=chat_info,
|
||||
all_situations=all_situations_str,
|
||||
max_num=max_num,
|
||||
target_message=target_message_str,
|
||||
target_message_extra_block=target_message_extra_block,
|
||||
)
|
||||
|
||||
# 4. 调用LLM
|
||||
try:
|
||||
# start_time = time.time()
|
||||
content, (reasoning_content, model_name, _) = await self.llm_model.generate_response_async(prompt=prompt)
|
||||
# logger.info(f"LLM请求时间: {model_name} {time.time() - start_time} \n{prompt}")
|
||||
|
||||
# logger.info(f"模型名称: {model_name}")
|
||||
# logger.info(f"LLM返回结果: {content}")
|
||||
# if reasoning_content:
|
||||
# logger.info(f"LLM推理: {reasoning_content}")
|
||||
# else:
|
||||
# logger.info(f"LLM推理: 无")
|
||||
|
||||
if not content:
|
||||
logger.warning("LLM返回空结果")
|
||||
return [], []
|
||||
|
||||
# 5. 解析结果
|
||||
result = repair_json(content)
|
||||
if isinstance(result, str):
|
||||
result = json.loads(result)
|
||||
|
||||
if not isinstance(result, dict) or "selected_situations" not in result:
|
||||
logger.error("LLM返回格式错误")
|
||||
logger.info(f"LLM返回结果: \n{content}")
|
||||
return [], []
|
||||
|
||||
selected_indices = result["selected_situations"]
|
||||
|
||||
# 根据索引获取完整的表达方式
|
||||
valid_expressions: List[Dict[str, Any]] = []
|
||||
selected_ids = []
|
||||
for idx in selected_indices:
|
||||
if isinstance(idx, int) and 1 <= idx <= len(all_expressions):
|
||||
expression = all_expressions[idx - 1] # 索引从1开始
|
||||
selected_ids.append(expression["id"])
|
||||
valid_expressions.append(expression)
|
||||
|
||||
# 对选中的所有表达方式,一次性更新count数
|
||||
if valid_expressions:
|
||||
self.update_expressions_count_batch(valid_expressions, 0.006)
|
||||
|
||||
# logger.info(f"LLM从{len(all_expressions)}个情境中选择了{len(valid_expressions)}个")
|
||||
return valid_expressions, selected_ids
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"LLM处理表达方式选择时出错: {e}")
|
||||
return [], []
|
||||
|
||||
|
||||
init_prompt()
|
||||
|
||||
try:
|
||||
expression_selector = ExpressionSelector()
|
||||
except Exception as e:
|
||||
logger.error(f"ExpressionSelector初始化失败: {e}")
|
||||
@@ -1,5 +1,37 @@
|
||||
from datetime import datetime
|
||||
import time
|
||||
from typing import Dict
|
||||
|
||||
from src.chat.utils.chat_message_builder import (
|
||||
get_raw_msg_by_timestamp_with_chat,
|
||||
build_readable_messages,
|
||||
)
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.apis import frequency_api
|
||||
|
||||
def init_prompt():
|
||||
Prompt(
|
||||
"""{name_block}
|
||||
{time_block}
|
||||
你现在正在聊天,请根据下面的聊天记录判断是否有用户觉得你的发言过于频繁或者发言过少
|
||||
{message_str}
|
||||
|
||||
如果用户觉得你的发言过于频繁,请输出"过于频繁",否则输出"正常"
|
||||
如果用户觉得你的发言过少,请输出"过少",否则输出"正常"
|
||||
**你只能输出以下三个词之一,不要输出任何其他文字、解释或标点:**
|
||||
- 正常
|
||||
- 过于频繁
|
||||
- 过少
|
||||
""",
|
||||
"frequency_adjust_prompt",
|
||||
)
|
||||
|
||||
|
||||
logger = get_logger("frequency_control")
|
||||
|
||||
|
||||
class FrequencyControl:
|
||||
"""简化的频率控制类,仅管理不同chat_id的频率值"""
|
||||
@@ -8,6 +40,11 @@ class FrequencyControl:
|
||||
self.chat_id = chat_id
|
||||
# 发言频率调整值
|
||||
self.talk_frequency_adjust: float = 1.0
|
||||
|
||||
self.last_frequency_adjust_time: float = 0.0
|
||||
self.frequency_model = LLMRequest(
|
||||
model_set=model_config.model_task_config.utils_small, request_type="frequency.adjust"
|
||||
)
|
||||
|
||||
def get_talk_frequency_adjust(self) -> float:
|
||||
"""获取发言频率调整值"""
|
||||
@@ -16,7 +53,72 @@ class FrequencyControl:
|
||||
def set_talk_frequency_adjust(self, value: float) -> None:
|
||||
"""设置发言频率调整值"""
|
||||
self.talk_frequency_adjust = max(0.1, min(5.0, value))
|
||||
|
||||
|
||||
async def trigger_frequency_adjust(self) -> None:
|
||||
msg_list = get_raw_msg_by_timestamp_with_chat(
|
||||
chat_id=self.chat_id,
|
||||
timestamp_start=self.last_frequency_adjust_time,
|
||||
timestamp_end=time.time(),
|
||||
)
|
||||
|
||||
|
||||
if time.time() - self.last_frequency_adjust_time < 120 or len(msg_list) <= 5:
|
||||
return
|
||||
else:
|
||||
new_msg_list = get_raw_msg_by_timestamp_with_chat(
|
||||
chat_id=self.chat_id,
|
||||
timestamp_start=self.last_frequency_adjust_time,
|
||||
timestamp_end=time.time(),
|
||||
limit=5,
|
||||
limit_mode="latest",
|
||||
)
|
||||
|
||||
message_str = build_readable_messages(
|
||||
new_msg_list,
|
||||
replace_bot_name=True,
|
||||
timestamp_mode="relative",
|
||||
read_mark=0.0,
|
||||
show_actions=False,
|
||||
)
|
||||
time_block = f"当前时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
|
||||
bot_name = global_config.bot.nickname
|
||||
bot_nickname = (
|
||||
f",也有人叫你{','.join(global_config.bot.alias_names)}" if global_config.bot.alias_names else ""
|
||||
)
|
||||
name_block = f"你的名字是{bot_name}{bot_nickname},请注意哪些是你自己的发言。"
|
||||
|
||||
prompt = await global_prompt_manager.format_prompt(
|
||||
"frequency_adjust_prompt",
|
||||
name_block=name_block,
|
||||
time_block=time_block,
|
||||
message_str=message_str,
|
||||
)
|
||||
response, (reasoning_content, _, _) = await self.frequency_model.generate_response_async(
|
||||
prompt,
|
||||
)
|
||||
|
||||
# logger.info(f"频率调整 prompt: {prompt}")
|
||||
# logger.info(f"频率调整 response: {response}")
|
||||
|
||||
if global_config.debug.show_prompt:
|
||||
logger.info(f"频率调整 prompt: {prompt}")
|
||||
logger.info(f"频率调整 response: {response}")
|
||||
logger.info(f"频率调整 reasoning_content: {reasoning_content}")
|
||||
|
||||
final_value_by_api = frequency_api.get_current_talk_value(self.chat_id)
|
||||
|
||||
# LLM依然输出过多内容时取消本次调整。合法最多4个字,但有的模型可能会输出一些markdown换行符等,需要长度宽限
|
||||
if len(response) < 20:
|
||||
if "过于频繁" in response:
|
||||
logger.info(f"频率调整: 过于频繁,调整值到{final_value_by_api}")
|
||||
self.talk_frequency_adjust = max(0.1, min(3.0, self.talk_frequency_adjust * 0.8))
|
||||
elif "过少" in response:
|
||||
logger.info(f"频率调整: 过少,调整值到{final_value_by_api}")
|
||||
self.talk_frequency_adjust = max(0.1, min(3.0, self.talk_frequency_adjust * 1.2))
|
||||
self.last_frequency_adjust_time = time.time()
|
||||
else:
|
||||
logger.info(f"频率调整:response不符合要求,取消本次调整")
|
||||
|
||||
class FrequencyControlManager:
|
||||
"""频率控制管理器,管理多个聊天流的频率控制实例"""
|
||||
@@ -41,6 +143,7 @@ class FrequencyControlManager:
|
||||
"""获取所有有频率控制的聊天ID"""
|
||||
return list(self.frequency_control_dict.keys())
|
||||
|
||||
init_prompt()
|
||||
|
||||
# 创建全局实例
|
||||
frequency_control_manager = FrequencyControlManager()
|
||||
|
||||
@@ -18,10 +18,11 @@ from src.chat.planner_actions.action_modifier import ActionModifier
|
||||
from src.chat.planner_actions.action_manager import ActionManager
|
||||
from src.chat.heart_flow.hfc_utils import CycleDetail
|
||||
from src.chat.heart_flow.hfc_utils import send_typing, stop_typing
|
||||
from src.chat.express.expression_learner import expression_learner_manager
|
||||
from src.express.expression_learner import expression_learner_manager
|
||||
from src.chat.frequency_control.frequency_control import frequency_control_manager
|
||||
from src.memory_system.question_maker import QuestionMaker
|
||||
from src.memory_system.questions import global_conflict_tracker
|
||||
from src.memory_system.curious import check_and_make_question
|
||||
from src.person_info.person_info import Person
|
||||
from src.plugin_system.base.component_types import EventType, ActionInfo
|
||||
from src.plugin_system.core import events_manager
|
||||
@@ -184,14 +185,12 @@ class HeartFChatting:
|
||||
)
|
||||
|
||||
question_probability = 0
|
||||
if time.time() - self.last_active_time > 3600:
|
||||
question_probability = 0.01
|
||||
elif time.time() - self.last_active_time > 1200:
|
||||
question_probability = 0.005
|
||||
elif time.time() - self.last_active_time > 600:
|
||||
question_probability = 0.001
|
||||
else:
|
||||
if time.time() - self.last_active_time > 7200:
|
||||
question_probability = 0.0003
|
||||
elif time.time() - self.last_active_time > 3600:
|
||||
question_probability = 0.0001
|
||||
else:
|
||||
question_probability = 0.00003
|
||||
|
||||
question_probability = question_probability * global_config.chat.get_auto_chat_value(self.stream_id)
|
||||
|
||||
@@ -210,7 +209,7 @@ class HeartFChatting:
|
||||
if question:
|
||||
logger.info(f"{self.log_prefix} 问题: {question}")
|
||||
await global_conflict_tracker.track_conflict(question, conflict_context, True, self.stream_id)
|
||||
await self._lift_question_reply(question,context,cycle_timers,thinking_id)
|
||||
await self._lift_question_reply(question,context,thinking_id)
|
||||
else:
|
||||
logger.info(f"{self.log_prefix} 无问题")
|
||||
# self.end_cycle(cycle_timers, thinking_id)
|
||||
@@ -331,9 +330,12 @@ class HeartFChatting:
|
||||
|
||||
|
||||
async with global_prompt_manager.async_message_scope(self.chat_stream.context.get_template_name()):
|
||||
await self.expression_learner.trigger_learning_for_chat()
|
||||
asyncio.create_task(self.expression_learner.trigger_learning_for_chat())
|
||||
asyncio.create_task(global_memory_chest.build_running_content(chat_id=self.stream_id))
|
||||
asyncio.create_task(frequency_control_manager.get_or_create_frequency_control(self.stream_id).trigger_frequency_adjust())
|
||||
|
||||
await global_memory_chest.build_running_content(chat_id=self.stream_id)
|
||||
# 添加curious检测任务 - 检测聊天记录中的矛盾、冲突或需要提问的内容
|
||||
asyncio.create_task(check_and_make_question(self.stream_id, recent_messages_list))
|
||||
|
||||
|
||||
cycle_timers, thinking_id = self.start_cycle()
|
||||
@@ -551,8 +553,8 @@ class HeartFChatting:
|
||||
traceback.print_exc()
|
||||
return False, ""
|
||||
|
||||
async def _lift_question_reply(self, question: str, context: str, cycle_timers: Dict[str, float], thinking_id: str):
|
||||
reason = f"在聊天中:\n{context}\n你对问题\"{question}\"感到好奇,想要和群友讨论"
|
||||
async def _lift_question_reply(self, question: str, question_context: str, thinking_id: str):
|
||||
reason = f"在聊天中:\n{question_context}\n你对问题\"{question}\"感到好奇,想要和群友讨论"
|
||||
new_msg = get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=self.stream_id,
|
||||
timestamp=time.time(),
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import re
|
||||
import traceback
|
||||
|
||||
from typing import Tuple, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from src.chat.message_receive.message import MessageRecv
|
||||
from src.chat.message_receive.storage import MessageStorage
|
||||
@@ -17,31 +17,6 @@ if TYPE_CHECKING:
|
||||
|
||||
logger = get_logger("chat")
|
||||
|
||||
|
||||
async def _calculate_interest(message: MessageRecv) -> Tuple[float, list[str]]:
|
||||
"""计算消息的兴趣度
|
||||
|
||||
Args:
|
||||
message: 待处理的消息对象
|
||||
|
||||
Returns:
|
||||
Tuple[float, bool, list[str]]: (兴趣度, 是否被提及, 关键词)
|
||||
"""
|
||||
if message.is_picid or message.is_emoji:
|
||||
return 0.0, []
|
||||
|
||||
is_mentioned, is_at, reply_probability_boost = is_mentioned_bot_in_message(message)
|
||||
# interested_rate = 0.0
|
||||
keywords = []
|
||||
|
||||
message.interest_value = 1
|
||||
message.is_mentioned = is_mentioned
|
||||
message.is_at = is_at
|
||||
message.reply_probability_boost = reply_probability_boost
|
||||
|
||||
return 1, keywords
|
||||
|
||||
|
||||
class HeartFCMessageReceiver:
|
||||
"""心流处理器,负责处理接收到的消息并计算兴趣度"""
|
||||
|
||||
@@ -67,12 +42,16 @@ class HeartFCMessageReceiver:
|
||||
userinfo = message.message_info.user_info
|
||||
chat = message.chat_stream
|
||||
|
||||
# 2. 兴趣度计算与更新
|
||||
_, keywords = await _calculate_interest(message)
|
||||
# 2. 计算at信息
|
||||
is_mentioned, is_at, reply_probability_boost = is_mentioned_bot_in_message(message)
|
||||
# print(f"is_mentioned: {is_mentioned}, is_at: {is_at}, reply_probability_boost: {reply_probability_boost}")
|
||||
message.is_mentioned = is_mentioned
|
||||
message.is_at = is_at
|
||||
message.reply_probability_boost = reply_probability_boost
|
||||
|
||||
await self.storage.store_message(message, chat)
|
||||
|
||||
_heartflow_chat: HeartFChatting = await heartflow.get_or_create_heartflow_chat(chat.stream_id) # type: ignore
|
||||
await heartflow.get_or_create_heartflow_chat(chat.stream_id) # type: ignore
|
||||
|
||||
# 3. 日志记录
|
||||
mes_name = chat.group_info.group_name if chat.group_info else "私聊"
|
||||
|
||||
@@ -149,8 +149,51 @@ class ChatBot:
|
||||
async def handle_notice_message(self, message: MessageRecv):
|
||||
if message.message_info.message_id == "notice":
|
||||
message.is_notify = True
|
||||
logger.info("notice消息")
|
||||
print(message)
|
||||
logger.debug("notice消息")
|
||||
try:
|
||||
seg = message.message_segment
|
||||
mi = message.message_info
|
||||
sub_type = None
|
||||
scene = None
|
||||
msg_id = None
|
||||
recalled_id = None
|
||||
|
||||
if getattr(seg, "type", None) == "notify" and isinstance(getattr(seg, "data", None), dict):
|
||||
sub_type = seg.data.get("sub_type")
|
||||
scene = seg.data.get("scene")
|
||||
msg_id = seg.data.get("message_id")
|
||||
recalled = seg.data.get("recalled_user_info") or {}
|
||||
if isinstance(recalled, dict):
|
||||
recalled_id = recalled.get("user_id")
|
||||
|
||||
op = mi.user_info
|
||||
gid = mi.group_info.group_id if mi.group_info else None
|
||||
|
||||
# 撤回事件打印;无法获取被撤回者则省略
|
||||
if sub_type == "recall":
|
||||
op_name = getattr(op, "user_cardname", None) or getattr(op, "user_nickname", None) or str(getattr(op, "user_id", None))
|
||||
recalled_name = None
|
||||
try:
|
||||
if isinstance(recalled, dict):
|
||||
recalled_name = (
|
||||
recalled.get("user_cardname")
|
||||
or recalled.get("user_nickname")
|
||||
or str(recalled.get("user_id"))
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if recalled_name and str(recalled_id) != str(getattr(op, "user_id", None)):
|
||||
logger.info(f"{op_name} 撤回了 {recalled_name} 的消息")
|
||||
else:
|
||||
logger.info(f"{op_name} 撤回了消息")
|
||||
else:
|
||||
logger.debug(
|
||||
f"[notice] sub_type={sub_type} scene={scene} op={getattr(op,'user_nickname',None)}({getattr(op,'user_id',None)}) "
|
||||
f"gid={gid} msg_id={msg_id} recalled={recalled_id}"
|
||||
)
|
||||
except Exception:
|
||||
logger.info("[notice] (简略) 收到一条通知事件")
|
||||
|
||||
return True
|
||||
|
||||
@@ -215,12 +258,13 @@ class ChatBot:
|
||||
message.message_segment = Seg(type="seglist", data=modified_message.message_segments)
|
||||
|
||||
if await self.handle_notice_message(message):
|
||||
# return
|
||||
pass
|
||||
return
|
||||
|
||||
# 处理消息内容,生成纯文本
|
||||
await message.process()
|
||||
|
||||
# 平台层的 @ 检测由底层 is_mentioned_bot_in_message 统一处理;此处不做用户名硬编码匹配
|
||||
|
||||
# 过滤检查
|
||||
if _check_ban_words(
|
||||
message.processed_plain_text,
|
||||
|
||||
@@ -130,6 +130,16 @@ class MessageRecv(Message):
|
||||
self.key_words = []
|
||||
self.key_words_lite = []
|
||||
|
||||
# 兼容适配器通过 additional_config 传入的 @ 标记
|
||||
try:
|
||||
msg_info_dict = message_dict.get("message_info", {})
|
||||
add_cfg = msg_info_dict.get("additional_config") or {}
|
||||
if isinstance(add_cfg, dict) and add_cfg.get("at_bot"):
|
||||
# 标记为被提及,提高后续回复优先级
|
||||
self.is_mentioned = True # type: ignore
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def update_chat_stream(self, chat_stream: "ChatStream"):
|
||||
self.chat_stream = chat_stream
|
||||
|
||||
|
||||
@@ -14,8 +14,6 @@ from src.common.logger import get_logger
|
||||
from src.common.data_models.info_data_model import ActionPlannerInfo
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.chat.utils.chat_message_builder import (
|
||||
build_readable_actions,
|
||||
get_actions_by_timestamp_with_chat,
|
||||
build_readable_messages_with_id,
|
||||
get_raw_msg_before_timestamp_with_chat,
|
||||
)
|
||||
@@ -69,6 +67,7 @@ no_reply_until_call
|
||||
动作描述:
|
||||
保持沉默,直到有人直接叫你的名字
|
||||
当前话题不感兴趣时使用,或有人不喜欢你的发言时使用
|
||||
当你频繁选择no_reply时使用,表示话题暂时与你无关
|
||||
{{
|
||||
"action": "no_reply_until_call",
|
||||
}}
|
||||
@@ -418,7 +417,6 @@ class ActionPlanner:
|
||||
return filtered_actions
|
||||
|
||||
async def _build_action_options_block(self, current_available_actions: Dict[str, ActionInfo]) -> str:
|
||||
# sourcery skip: use-join
|
||||
"""构建动作选项块"""
|
||||
if not current_available_actions:
|
||||
return ""
|
||||
|
||||
@@ -26,7 +26,7 @@ from src.chat.utils.chat_message_builder import (
|
||||
get_raw_msg_before_timestamp_with_chat,
|
||||
replace_user_references,
|
||||
)
|
||||
from src.chat.express.expression_selector import expression_selector
|
||||
from src.express.expression_selector import expression_selector
|
||||
from src.plugin_system.apis.message_api import translate_pid_to_description
|
||||
|
||||
# from src.memory_system.memory_activator import MemoryActivator
|
||||
@@ -133,7 +133,13 @@ class DefaultReplyer:
|
||||
|
||||
try:
|
||||
content, reasoning_content, model_name, tool_call = await self.llm_generate_content(prompt)
|
||||
logger.debug(f"replyer生成内容: {content}")
|
||||
# logger.debug(f"replyer生成内容: {content}")
|
||||
|
||||
logger.info(f"replyer生成内容: {content}")
|
||||
if global_config.debug.show_replyer_reasoning:
|
||||
logger.info(f"replyer生成推理:\n{reasoning_content}")
|
||||
logger.info(f"replyer生成模型: {model_name}")
|
||||
|
||||
llm_response.content = content
|
||||
llm_response.reasoning = reasoning_content
|
||||
llm_response.model = model_name
|
||||
@@ -238,8 +244,8 @@ class DefaultReplyer:
|
||||
return "", []
|
||||
style_habits = []
|
||||
# 使用从处理器传来的选中表达方式
|
||||
# LLM模式:调用LLM选择5-10个,然后随机选5个
|
||||
selected_expressions, selected_ids = await expression_selector.select_suitable_expressions_llm(
|
||||
# 根据配置模式选择表达方式:exp_model模式直接使用模型预测,classic模式使用LLM选择
|
||||
selected_expressions, selected_ids = await expression_selector.select_suitable_expressions(
|
||||
self.chat_stream.stream_id, chat_history, max_num=8, target_message=target
|
||||
)
|
||||
|
||||
@@ -480,6 +486,31 @@ class DefaultReplyer:
|
||||
) -> Tuple[str, str]:
|
||||
"""
|
||||
|
||||
Args:
|
||||
message_list_before_now: 历史消息列表
|
||||
target_user_id: 目标用户ID(当前对话对象)
|
||||
|
||||
Returns:
|
||||
Tuple[str, str]: (核心对话prompt, 背景对话prompt)
|
||||
"""
|
||||
# 构建背景对话 prompt
|
||||
all_dialogue_prompt = ""
|
||||
if message_list_before_now:
|
||||
latest_msgs = message_list_before_now[-int(global_config.chat.max_context_size) :]
|
||||
all_dialogue_prompt = build_readable_messages(
|
||||
latest_msgs,
|
||||
replace_bot_name=True,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
truncate=True,
|
||||
)
|
||||
|
||||
return all_dialogue_prompt
|
||||
|
||||
def core_background_build_chat_history_prompts(
|
||||
self, message_list_before_now: List[DatabaseMessages], target_user_id: str, sender: str
|
||||
) -> Tuple[str, str]:
|
||||
"""
|
||||
|
||||
Args:
|
||||
message_list_before_now: 历史消息列表
|
||||
target_user_id: 目标用户ID(当前对话对象)
|
||||
@@ -529,7 +560,7 @@ class DefaultReplyer:
|
||||
show_actions=True,
|
||||
)
|
||||
core_dialogue_prompt = f"""--------------------------------
|
||||
这是你和{sender}的对话,你们正在交流中:
|
||||
这是上述中你和{sender}的对话摘要,内容从上面的对话中截取,便于你理解:
|
||||
{core_dialogue_prompt_str}
|
||||
--------------------------------
|
||||
"""
|
||||
@@ -594,7 +625,18 @@ class DefaultReplyer:
|
||||
else:
|
||||
bot_nickname = ""
|
||||
|
||||
prompt_personality = f"{global_config.personality.personality};"
|
||||
# 获取基础personality
|
||||
prompt_personality = global_config.personality.personality
|
||||
|
||||
# 检查是否需要随机替换为状态
|
||||
if (global_config.personality.states and
|
||||
global_config.personality.state_probability > 0 and
|
||||
random.random() < global_config.personality.state_probability):
|
||||
# 随机选择一个状态替换personality
|
||||
selected_state = random.choice(global_config.personality.states)
|
||||
prompt_personality = selected_state
|
||||
|
||||
prompt_personality = f"{prompt_personality};"
|
||||
return f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}"
|
||||
|
||||
async def build_prompt_reply_context(
|
||||
@@ -731,7 +773,7 @@ class DefaultReplyer:
|
||||
continue
|
||||
|
||||
timing_logs.append(f"{chinese_name}: {duration:.1f}s")
|
||||
if duration > 8:
|
||||
if duration > 12:
|
||||
logger.warning(f"回复生成前信息获取耗时过长: {chinese_name} 耗时: {duration:.1f}s,请使用更快的模型")
|
||||
logger.info(f"回复准备: {'; '.join(timing_logs)}; {almost_zero_str} <0.1s")
|
||||
|
||||
@@ -776,9 +818,7 @@ class DefaultReplyer:
|
||||
reply_target_block = ""
|
||||
|
||||
# 构建分离的对话 prompt
|
||||
core_dialogue_prompt, background_dialogue_prompt = self.build_chat_history_prompts(
|
||||
message_list_before_now_long, user_id, sender
|
||||
)
|
||||
dialogue_prompt = self.build_chat_history_prompts(message_list_before_now_long, user_id, sender)
|
||||
|
||||
return await global_prompt_manager.format_prompt(
|
||||
"replyer_prompt",
|
||||
@@ -793,9 +833,8 @@ class DefaultReplyer:
|
||||
identity=personality_prompt,
|
||||
action_descriptions=actions_info,
|
||||
sender_name=sender,
|
||||
background_dialogue_prompt=background_dialogue_prompt,
|
||||
dialogue_prompt=dialogue_prompt,
|
||||
time_block=time_block,
|
||||
core_dialogue_prompt=core_dialogue_prompt,
|
||||
reply_target_block=reply_target_block,
|
||||
reply_style=global_config.personality.reply_style,
|
||||
keywords_reaction_prompt=keywords_reaction_prompt,
|
||||
@@ -960,9 +999,9 @@ class DefaultReplyer:
|
||||
async def llm_generate_content(self, prompt: str):
|
||||
with Timer("LLM生成", {}): # 内部计时器,可选保留
|
||||
# 直接使用已初始化的模型实例
|
||||
logger.info(f"\n{prompt}\n")
|
||||
# logger.info(f"\n{prompt}\n")
|
||||
|
||||
if global_config.debug.show_prompt:
|
||||
if global_config.debug.show_replyer_prompt:
|
||||
logger.info(f"\n{prompt}\n")
|
||||
else:
|
||||
logger.debug(f"\nreplyer_Prompt:{prompt}\n")
|
||||
@@ -971,6 +1010,9 @@ class DefaultReplyer:
|
||||
prompt
|
||||
)
|
||||
|
||||
# 移除 content 前后的换行符和空格
|
||||
content = content.strip()
|
||||
|
||||
logger.info(f"使用 {model_name} 生成回复内容: {content}")
|
||||
return content, reasoning_content, model_name, tool_calls
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ from src.chat.utils.chat_message_builder import (
|
||||
get_raw_msg_before_timestamp_with_chat,
|
||||
replace_user_references,
|
||||
)
|
||||
from src.chat.express.expression_selector import expression_selector
|
||||
from src.express.expression_selector import expression_selector
|
||||
from src.plugin_system.apis.message_api import translate_pid_to_description
|
||||
from src.mood.mood_manager import mood_manager
|
||||
|
||||
@@ -256,8 +256,8 @@ class PrivateReplyer:
|
||||
return "", []
|
||||
style_habits = []
|
||||
# 使用从处理器传来的选中表达方式
|
||||
# LLM模式:调用LLM选择5-10个,然后随机选5个
|
||||
selected_expressions, selected_ids = await expression_selector.select_suitable_expressions_llm(
|
||||
# 根据配置模式选择表达方式:exp_model模式直接使用模型预测,classic模式使用LLM选择
|
||||
selected_expressions, selected_ids = await expression_selector.select_suitable_expressions(
|
||||
self.chat_stream.stream_id, chat_history, max_num=8, target_message=target
|
||||
)
|
||||
|
||||
@@ -522,7 +522,18 @@ class PrivateReplyer:
|
||||
else:
|
||||
bot_nickname = ""
|
||||
|
||||
prompt_personality = f"{global_config.personality.personality};"
|
||||
# 获取基础personality
|
||||
prompt_personality = global_config.personality.personality
|
||||
|
||||
# 检查是否需要随机替换为状态
|
||||
if (global_config.personality.states and
|
||||
global_config.personality.state_probability > 0 and
|
||||
random.random() < global_config.personality.state_probability):
|
||||
# 随机选择一个状态替换personality
|
||||
selected_state = random.choice(global_config.personality.states)
|
||||
prompt_personality = selected_state
|
||||
|
||||
prompt_personality = f"{prompt_personality};"
|
||||
return f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}"
|
||||
|
||||
async def build_prompt_reply_context(
|
||||
@@ -668,7 +679,7 @@ class PrivateReplyer:
|
||||
continue
|
||||
|
||||
timing_logs.append(f"{chinese_name}: {duration:.1f}s")
|
||||
if duration > 8:
|
||||
if duration > 12:
|
||||
logger.warning(f"回复生成前信息获取耗时过长: {chinese_name} 耗时: {duration:.1f}s,请使用更快的模型")
|
||||
logger.info(f"回复准备: {'; '.join(timing_logs)}; {almost_zero_str} <0.1s")
|
||||
|
||||
@@ -911,7 +922,7 @@ class PrivateReplyer:
|
||||
# 直接使用已初始化的模型实例
|
||||
logger.info(f"\n{prompt}\n")
|
||||
|
||||
if global_config.debug.show_prompt:
|
||||
if global_config.debug.show_replyer_prompt:
|
||||
logger.info(f"\n{prompt}\n")
|
||||
else:
|
||||
logger.debug(f"\n{prompt}\n")
|
||||
@@ -919,8 +930,12 @@ class PrivateReplyer:
|
||||
content, (reasoning_content, model_name, tool_calls) = await self.express_model.generate_response_async(
|
||||
prompt
|
||||
)
|
||||
|
||||
content = content.strip()
|
||||
|
||||
logger.info(f"使用 {model_name} 生成回复内容: {content}")
|
||||
if global_config.debug.show_replyer_reasoning:
|
||||
logger.info(f"使用 {model_name} 生成回复推理:\n{reasoning_content}")
|
||||
return content, reasoning_content, model_name, tool_calls
|
||||
|
||||
async def get_prompt_info(self, message: str, sender: str, target: str):
|
||||
|
||||
@@ -17,8 +17,7 @@ def init_replyer_prompt():
|
||||
|
||||
你正在qq群里聊天,下面是群里正在聊的内容:
|
||||
{time_block}
|
||||
{background_dialogue_prompt}
|
||||
{core_dialogue_prompt}
|
||||
{dialogue_prompt}
|
||||
|
||||
{reply_target_block}。
|
||||
{identity}
|
||||
@@ -26,7 +25,7 @@ def init_replyer_prompt():
|
||||
尽量简短一些。{keywords_reaction_prompt}请注意把握聊天内容,不要回复的太有条理,可以有个性。
|
||||
{reply_style}
|
||||
请注意不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只输出一句回复内容就好。
|
||||
{moderation_prompt}不要输出多余内容(包括前后缀,冒号和引号,括号,表情包,at或 @等 )。
|
||||
{moderation_prompt}不要输出多余内容(包括前后缀,冒号和引号,括号,表情包,at或 @等 )。请不要思考太长
|
||||
现在,你说:""",
|
||||
"replyer_prompt",
|
||||
)
|
||||
|
||||
@@ -24,7 +24,7 @@ def init_rewrite_prompt():
|
||||
你可以完全重组回复,保留最基本的表达含义就好,但重组后保持语意通顺。
|
||||
{keywords_reaction_prompt}
|
||||
{moderation_prompt}
|
||||
不要输出多余内容(包括冒号和引号,表情包,emoji,at或 @等 ),只输出一条回复就好。
|
||||
不要输出多余内容(包括冒号和引号,表情包,emoji,at或 @等 ),只输出一条回复就好。不要思考的太长。
|
||||
改写后的回复:
|
||||
""",
|
||||
"default_expressor_prompt",
|
||||
|
||||
@@ -43,9 +43,12 @@ def replace_user_references(
|
||||
if name_resolver is None:
|
||||
|
||||
def default_resolver(platform: str, user_id: str) -> str:
|
||||
# 检查是否是机器人自己
|
||||
if replace_bot_name and user_id == global_config.bot.qq_account:
|
||||
return f"{global_config.bot.nickname}(你)"
|
||||
# 检查是否是机器人自己(支持多平台)
|
||||
if replace_bot_name:
|
||||
if platform == "qq" and user_id == global_config.bot.qq_account:
|
||||
return f"{global_config.bot.nickname}(你)"
|
||||
if platform == "telegram" and user_id == getattr(global_config.bot, "telegram_account", ""):
|
||||
return f"{global_config.bot.nickname}(你)"
|
||||
person = Person(platform=platform, user_id=user_id)
|
||||
return person.person_name or user_id # type: ignore
|
||||
|
||||
@@ -92,6 +95,8 @@ def replace_user_references(
|
||||
new_content += content[last_end:]
|
||||
content = new_content
|
||||
|
||||
# Telegram 文本 @username 的显示映射交由适配器或平台层处理;此处不做硬编码替换
|
||||
|
||||
return content
|
||||
|
||||
|
||||
@@ -432,7 +437,10 @@ def _build_readable_messages_internal(
|
||||
person_name = (
|
||||
person.person_name or f"{user_nickname}" or (f"昵称:{user_cardname}" if user_cardname else "某人")
|
||||
)
|
||||
if replace_bot_name and user_id == global_config.bot.qq_account:
|
||||
if replace_bot_name and (
|
||||
(platform == global_config.bot.platform and user_id == global_config.bot.qq_account)
|
||||
or (platform == "telegram" and user_id == getattr(global_config.bot, "telegram_account", ""))
|
||||
):
|
||||
person_name = f"{global_config.bot.nickname}(你)"
|
||||
|
||||
# 使用独立函数处理用户引用格式
|
||||
@@ -866,7 +874,9 @@ async def build_anonymous_messages(messages: List[DatabaseMessages]) -> str:
|
||||
# print(f"get_anon_name: platform:{platform}, user_id:{user_id}")
|
||||
# print(f"global_config.bot.qq_account:{global_config.bot.qq_account}")
|
||||
|
||||
if user_id == global_config.bot.qq_account:
|
||||
if (platform == "qq" and user_id == global_config.bot.qq_account) or (
|
||||
platform == "telegram" and user_id == getattr(global_config.bot, "telegram_account", "")
|
||||
):
|
||||
# print("SELF11111111111111")
|
||||
return "SELF"
|
||||
try:
|
||||
|
||||
@@ -334,7 +334,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
|
||||
request_type = record.request_type or "unknown"
|
||||
user_id = record.user_id or "unknown" # user_id is TextField, already string
|
||||
model_name = record.model_name or "unknown"
|
||||
model_name = record.model_assign_name or record.model_name or "unknown"
|
||||
|
||||
# 提取模块名:如果请求类型包含".",取第一个"."之前的部分
|
||||
module_name = request_type.split(".")[0] if "." in request_type else request_type
|
||||
@@ -492,10 +492,15 @@ class StatisticOutputTask(AsyncTask):
|
||||
continue
|
||||
|
||||
# Update name_mapping
|
||||
if chat_id in self.name_mapping:
|
||||
if chat_name != self.name_mapping[chat_id][0] and message_time_ts > self.name_mapping[chat_id][1]:
|
||||
try:
|
||||
if chat_id in self.name_mapping:
|
||||
if chat_name != self.name_mapping[chat_id][0] and message_time_ts > self.name_mapping[chat_id][1]:
|
||||
self.name_mapping[chat_id] = (chat_name, message_time_ts)
|
||||
else:
|
||||
self.name_mapping[chat_id] = (chat_name, message_time_ts)
|
||||
else:
|
||||
except (IndexError, TypeError) as e:
|
||||
logger.warning(f"更新 name_mapping 时发生错误,chat_id: {chat_id}, 错误: {e}")
|
||||
# 重置为正确的格式
|
||||
self.name_mapping[chat_id] = (chat_name, message_time_ts)
|
||||
|
||||
for idx, (_, period_start_dt) in enumerate(collect_period):
|
||||
@@ -514,15 +519,32 @@ class StatisticOutputTask(AsyncTask):
|
||||
|
||||
last_all_time_stat = None
|
||||
|
||||
if "last_full_statistics" in local_storage:
|
||||
# 如果存在上次完整统计数据,则使用该数据进行增量统计
|
||||
last_stat: Dict[str, Any] = local_storage["last_full_statistics"] # 上次完整统计数据 # type: ignore
|
||||
try:
|
||||
if "last_full_statistics" in local_storage:
|
||||
# 如果存在上次完整统计数据,则使用该数据进行增量统计
|
||||
last_stat: Dict[str, Any] = local_storage["last_full_statistics"] # 上次完整统计数据 # type: ignore
|
||||
|
||||
self.name_mapping = last_stat["name_mapping"] # 上次完整统计数据的名称映射
|
||||
last_all_time_stat = last_stat["stat_data"] # 上次完整统计的统计数据
|
||||
last_stat_timestamp = datetime.fromtimestamp(last_stat["timestamp"]) # 上次完整统计数据的时间戳
|
||||
self.stat_period = [item for item in self.stat_period if item[0] != "all_time"] # 删除"所有时间"的统计时段
|
||||
self.stat_period.append(("all_time", now - last_stat_timestamp, "自部署以来的"))
|
||||
# 修复 name_mapping 数据类型不匹配问题
|
||||
# JSON 中存储为列表,但代码期望为元组
|
||||
raw_name_mapping = last_stat["name_mapping"]
|
||||
self.name_mapping = {}
|
||||
for chat_id, value in raw_name_mapping.items():
|
||||
if isinstance(value, list) and len(value) == 2:
|
||||
# 将列表转换为元组
|
||||
self.name_mapping[chat_id] = (value[0], value[1])
|
||||
elif isinstance(value, tuple) and len(value) == 2:
|
||||
# 已经是元组,直接使用
|
||||
self.name_mapping[chat_id] = value
|
||||
else:
|
||||
# 数据格式不正确,跳过或使用默认值
|
||||
logger.warning(f"name_mapping 中 chat_id {chat_id} 的数据格式不正确: {value}")
|
||||
continue
|
||||
last_all_time_stat = last_stat["stat_data"] # 上次完整统计的统计数据
|
||||
last_stat_timestamp = datetime.fromtimestamp(last_stat["timestamp"]) # 上次完整统计数据的时间戳
|
||||
self.stat_period = [item for item in self.stat_period if item[0] != "all_time"] # 删除"所有时间"的统计时段
|
||||
self.stat_period.append(("all_time", now - last_stat_timestamp, "自部署以来的"))
|
||||
except Exception as e:
|
||||
logger.warning(f"加载上次完整统计数据失败,进行全量统计,错误信息:{e}")
|
||||
|
||||
stat_start_timestamp = [(period[0], now - period[1]) for period in self.stat_period]
|
||||
|
||||
@@ -571,8 +593,14 @@ class StatisticOutputTask(AsyncTask):
|
||||
# 更新上次完整统计数据的时间戳
|
||||
# 将所有defaultdict转换为普通dict以避免类型冲突
|
||||
clean_stat_data = self._convert_defaultdict_to_dict(stat["all_time"])
|
||||
|
||||
# 将 name_mapping 中的元组转换为列表,因为JSON不支持元组
|
||||
json_safe_name_mapping = {}
|
||||
for chat_id, (chat_name, timestamp) in self.name_mapping.items():
|
||||
json_safe_name_mapping[chat_id] = [chat_name, timestamp]
|
||||
|
||||
local_storage["last_full_statistics"] = {
|
||||
"name_mapping": self.name_mapping,
|
||||
"name_mapping": json_safe_name_mapping,
|
||||
"stat_data": clean_stat_data,
|
||||
"timestamp": now.timestamp(),
|
||||
}
|
||||
@@ -651,10 +679,13 @@ class StatisticOutputTask(AsyncTask):
|
||||
if stats[TOTAL_MSG_CNT] <= 0:
|
||||
return ""
|
||||
output = ["聊天消息统计:", " 联系人/群组名称 消息数量"]
|
||||
output.extend(
|
||||
f"{self.name_mapping[chat_id][0][:32]:<32} {count:>10}"
|
||||
for chat_id, count in sorted(stats[MSG_CNT_BY_CHAT].items())
|
||||
)
|
||||
for chat_id, count in sorted(stats[MSG_CNT_BY_CHAT].items()):
|
||||
try:
|
||||
chat_name = self.name_mapping.get(chat_id, ("未知聊天", 0))[0]
|
||||
output.append(f"{chat_name[:32]:<32} {count:>10}")
|
||||
except (IndexError, TypeError) as e:
|
||||
logger.warning(f"格式化聊天统计时发生错误,chat_id: {chat_id}, 错误: {e}")
|
||||
output.append(f"{'未知聊天':<32} {count:>10}")
|
||||
output.append("")
|
||||
return "\n".join(output)
|
||||
|
||||
@@ -770,14 +801,16 @@ class StatisticOutputTask(AsyncTask):
|
||||
)
|
||||
|
||||
# 聊天消息统计
|
||||
chat_rows = "\n".join(
|
||||
[
|
||||
f"<tr><td>{self.name_mapping[chat_id][0]}</td><td>{count}</td></tr>"
|
||||
for chat_id, count in sorted(stat_data[MSG_CNT_BY_CHAT].items())
|
||||
]
|
||||
if stat_data[MSG_CNT_BY_CHAT]
|
||||
else ["<tr><td colspan='2' style='text-align: center; color: #999;'>暂无数据</td></tr>"]
|
||||
)
|
||||
chat_rows = []
|
||||
for chat_id, count in sorted(stat_data[MSG_CNT_BY_CHAT].items()):
|
||||
try:
|
||||
chat_name = self.name_mapping.get(chat_id, ("未知聊天", 0))[0]
|
||||
chat_rows.append(f"<tr><td>{chat_name}</td><td>{count}</td></tr>")
|
||||
except (IndexError, TypeError) as e:
|
||||
logger.warning(f"生成HTML聊天统计时发生错误,chat_id: {chat_id}, 错误: {e}")
|
||||
chat_rows.append(f"<tr><td>未知聊天</td><td>{count}</td></tr>")
|
||||
|
||||
chat_rows_html = "\n".join(chat_rows) if chat_rows else "<tr><td colspan='2' style='text-align: center; color: #999;'>暂无数据</td></tr>"
|
||||
# 生成HTML
|
||||
return f"""
|
||||
<div id=\"{div_id}\" class=\"tab-content\">
|
||||
@@ -824,7 +857,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
<tr><th>联系人/群组名称</th><th>消息数量</th></tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{chat_rows}
|
||||
{chat_rows_html}
|
||||
</tbody>
|
||||
</table>
|
||||
|
||||
@@ -975,7 +1008,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
}}
|
||||
|
||||
// 聊天消息分布饼图
|
||||
const chatLabels = {[self.name_mapping[chat_id][0] for chat_id in sorted(stat_data[MSG_CNT_BY_CHAT].keys())] if stat_data[MSG_CNT_BY_CHAT] else []};
|
||||
const chatLabels = {[self.name_mapping.get(chat_id, ("未知聊天", 0))[0] for chat_id in sorted(stat_data[MSG_CNT_BY_CHAT].keys())] if stat_data[MSG_CNT_BY_CHAT] else []};
|
||||
if (chatLabels.length > 0) {{
|
||||
const chatData = {{
|
||||
labels: chatLabels,
|
||||
@@ -1233,7 +1266,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
total_cost_data[interval_index] += cost # type: ignore
|
||||
|
||||
# 累加按模型分类的花费
|
||||
model_name = record.model_name or "unknown"
|
||||
model_name = record.model_assign_name or record.model_name or "unknown"
|
||||
if model_name not in cost_by_model:
|
||||
cost_by_model[model_name] = [0] * len(time_points)
|
||||
cost_by_model[model_name][interval_index] += cost
|
||||
|
||||
@@ -30,76 +30,146 @@ def is_english_letter(char: str) -> bool:
|
||||
return "a" <= char.lower() <= "z"
|
||||
|
||||
|
||||
def db_message_to_str(message_dict: dict) -> str:
|
||||
logger.debug(f"message_dict: {message_dict}")
|
||||
time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(message_dict["time"]))
|
||||
try:
|
||||
name = f"[({message_dict['user_id']}){message_dict.get('user_nickname', '')}]{message_dict.get('user_cardname', '')}"
|
||||
except Exception:
|
||||
name = message_dict.get("user_nickname", "") or f"用户{message_dict['user_id']}"
|
||||
content = message_dict.get("processed_plain_text", "")
|
||||
result = f"[{time_str}] {name}: {content}\n"
|
||||
logger.debug(f"result: {result}")
|
||||
def parse_platform_accounts(platforms: list[str]) -> dict[str, str]:
|
||||
"""解析 platforms 列表,返回平台到账号的映射
|
||||
|
||||
Args:
|
||||
platforms: 格式为 ["platform:account"] 的列表,如 ["tg:123456789", "wx:wxid123"]
|
||||
|
||||
Returns:
|
||||
字典,键为平台名,值为账号
|
||||
"""
|
||||
result = {}
|
||||
for platform_entry in platforms:
|
||||
if ":" in platform_entry:
|
||||
platform_name, account = platform_entry.split(":", 1)
|
||||
result[platform_name.strip()] = account.strip()
|
||||
return result
|
||||
|
||||
|
||||
def get_current_platform_account(platform: str, platform_accounts: dict[str, str], qq_account: str) -> str:
|
||||
"""根据当前平台获取对应的账号
|
||||
|
||||
Args:
|
||||
platform: 当前消息的平台
|
||||
platform_accounts: 从 platforms 列表解析的平台账号映射
|
||||
qq_account: QQ 账号(兼容旧配置)
|
||||
|
||||
Returns:
|
||||
当前平台对应的账号
|
||||
"""
|
||||
if platform == "qq":
|
||||
return qq_account
|
||||
elif platform == "telegram":
|
||||
# 优先使用 tg,其次使用 telegram
|
||||
return platform_accounts.get("tg", "") or platform_accounts.get("telegram", "")
|
||||
else:
|
||||
# 其他平台直接使用平台名作为键
|
||||
return platform_accounts.get(platform, "")
|
||||
|
||||
|
||||
def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, bool, float]:
|
||||
"""检查消息是否提到了机器人"""
|
||||
keywords = [global_config.bot.nickname] + list(global_config.bot.alias_names)
|
||||
"""检查消息是否提到了机器人(统一多平台实现)"""
|
||||
text = message.processed_plain_text or ""
|
||||
platform = getattr(message.message_info, "platform", "") or ""
|
||||
|
||||
# 获取各平台账号
|
||||
platforms_list = getattr(global_config.bot, "platforms", []) or []
|
||||
platform_accounts = parse_platform_accounts(platforms_list)
|
||||
qq_account = str(getattr(global_config.bot, "qq_account", "") or "")
|
||||
|
||||
# 获取当前平台对应的账号
|
||||
current_account = get_current_platform_account(platform, platform_accounts, qq_account)
|
||||
|
||||
nickname = str(global_config.bot.nickname or "")
|
||||
alias_names = list(getattr(global_config.bot, "alias_names", []) or [])
|
||||
keywords = [nickname] + alias_names
|
||||
|
||||
reply_probability = 0.0
|
||||
is_at = False
|
||||
is_mentioned = False
|
||||
|
||||
# 这部分怎么处理啊啊啊啊
|
||||
# 我觉得可以给消息加一个 reply_probability_boost字段
|
||||
if (
|
||||
message.message_info.additional_config is not None
|
||||
and message.message_info.additional_config.get("is_mentioned") is not None
|
||||
):
|
||||
# 1) 直接的 additional_config 标记
|
||||
add_cfg = getattr(message.message_info, "additional_config", None) or {}
|
||||
if isinstance(add_cfg, dict):
|
||||
if add_cfg.get("at_bot") or add_cfg.get("is_mentioned"):
|
||||
is_mentioned = True
|
||||
# 当提供数值型 is_mentioned 时,当作概率提升
|
||||
try:
|
||||
if add_cfg.get("is_mentioned") not in (None, ""):
|
||||
reply_probability = float(add_cfg.get("is_mentioned")) # type: ignore
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 2) 已经在上游设置过的 message.is_mentioned
|
||||
if getattr(message, "is_mentioned", False):
|
||||
is_mentioned = True
|
||||
|
||||
# 3) 扫描分段:是否包含 mention_bot(适配器插入)
|
||||
def _has_mention_bot(seg) -> bool:
|
||||
try:
|
||||
reply_probability = float(message.message_info.additional_config.get("is_mentioned")) # type: ignore
|
||||
is_mentioned = True
|
||||
return is_mentioned, is_at, reply_probability
|
||||
except Exception as e:
|
||||
logger.warning(str(e))
|
||||
logger.warning(
|
||||
f"消息中包含不合理的设置 is_mentioned: {message.message_info.additional_config.get('is_mentioned')}"
|
||||
)
|
||||
if seg is None:
|
||||
return False
|
||||
if getattr(seg, "type", None) == "mention_bot":
|
||||
return True
|
||||
if getattr(seg, "type", None) == "seglist":
|
||||
for s in getattr(seg, "data", []) or []:
|
||||
if _has_mention_bot(s):
|
||||
return True
|
||||
return False
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
for keyword in keywords:
|
||||
if keyword in message.processed_plain_text:
|
||||
is_mentioned = True
|
||||
|
||||
# 判断是否被@
|
||||
if re.search(rf"@<(.+?):{global_config.bot.qq_account}>", message.processed_plain_text):
|
||||
if _has_mention_bot(getattr(message, "message_segment", None)):
|
||||
is_at = True
|
||||
is_mentioned = True
|
||||
|
||||
if is_at and global_config.chat.at_bot_inevitable_reply:
|
||||
# 4) 统一的 @ 检测逻辑
|
||||
if current_account and not is_at and not is_mentioned:
|
||||
if platform == "qq":
|
||||
# QQ 格式: @<name:qq_id>
|
||||
if re.search(rf"@<(.+?):{re.escape(current_account)}>", text):
|
||||
is_at = True
|
||||
is_mentioned = True
|
||||
else:
|
||||
# 其他平台格式: @username 或 @account
|
||||
if re.search(rf"@{re.escape(current_account)}(\b|$)", text, flags=re.IGNORECASE):
|
||||
is_at = True
|
||||
is_mentioned = True
|
||||
|
||||
# 5) 统一的回复检测逻辑
|
||||
if not is_mentioned:
|
||||
# 通用回复格式:包含 "(你)" 或 "(你)"
|
||||
if re.search(r"\[回复 .*?\(你\):", text) or re.search(r"\[回复 .*?(你):", text):
|
||||
is_mentioned = True
|
||||
# ID 形式的回复检测
|
||||
elif current_account:
|
||||
if re.search(rf"\[回复 (.+?)\({re.escape(current_account)}\):(.+?)\],说:", text):
|
||||
is_mentioned = True
|
||||
elif re.search(rf"\[回复<(.+?)(?=:{re.escape(current_account)}>)\:{re.escape(current_account)}>:(.+?)\],说:", text):
|
||||
is_mentioned = True
|
||||
|
||||
# 6) 名称/别名 提及(去除 @/回复标记后再匹配)
|
||||
if not is_mentioned and keywords:
|
||||
msg_content = text
|
||||
# 去除各种 @ 与 回复标记,避免误判
|
||||
msg_content = re.sub(r"@(.+?)((\d+))", "", msg_content)
|
||||
msg_content = re.sub(r"@<(.+?)(?=:(\d+))\:(\d+)>", "", msg_content)
|
||||
msg_content = re.sub(r"\[回复 (.+?)\(((\d+)|未知id|你)\):(.+?)\],说:", "", msg_content)
|
||||
msg_content = re.sub(r"\[回复<(.+?)(?=:(\d+))\:(\d+)>:(.+?)\],说:", "", msg_content)
|
||||
for kw in keywords:
|
||||
if kw and kw in msg_content:
|
||||
is_mentioned = True
|
||||
break
|
||||
|
||||
# 7) 概率设置
|
||||
if is_at and getattr(global_config.chat, "at_bot_inevitable_reply", 1):
|
||||
reply_probability = 1.0
|
||||
logger.debug("被@,回复概率设置为100%")
|
||||
else:
|
||||
if not is_mentioned:
|
||||
# 判断是否被回复
|
||||
if re.match(
|
||||
rf"\[回复 (.+?)\({str(global_config.bot.qq_account)}\):(.+?)\],说:", message.processed_plain_text
|
||||
) or re.match(
|
||||
rf"\[回复<(.+?)(?=:{str(global_config.bot.qq_account)}>)\:{str(global_config.bot.qq_account)}>:(.+?)\],说:",
|
||||
message.processed_plain_text,
|
||||
):
|
||||
is_mentioned = True
|
||||
else:
|
||||
# 判断内容中是否被提及
|
||||
message_content = re.sub(r"@(.+?)((\d+))", "", message.processed_plain_text)
|
||||
message_content = re.sub(r"@<(.+?)(?=:(\d+))\:(\d+)>", "", message_content)
|
||||
message_content = re.sub(r"\[回复 (.+?)\(((\d+)|未知id)\):(.+?)\],说:", "", message_content)
|
||||
message_content = re.sub(r"\[回复<(.+?)(?=:(\d+))\:(\d+)>:(.+?)\],说:", "", message_content)
|
||||
for keyword in keywords:
|
||||
if keyword in message_content:
|
||||
is_mentioned = True
|
||||
if is_mentioned and global_config.chat.mentioned_bot_reply:
|
||||
reply_probability = 1.0
|
||||
logger.debug("被提及,回复概率设置为100%")
|
||||
elif is_mentioned and getattr(global_config.chat, "mentioned_bot_reply", 1):
|
||||
reply_probability = max(reply_probability, 1.0)
|
||||
logger.debug("被提及,回复概率设置为100%")
|
||||
|
||||
return is_mentioned, is_at, reply_probability
|
||||
|
||||
|
||||
@@ -115,45 +185,6 @@ async def get_embedding(text, request_type="embedding") -> Optional[List[float]]
|
||||
return embedding
|
||||
|
||||
|
||||
def get_recent_group_speaker(chat_stream_id: str, sender, limit: int = 12) -> list:
|
||||
# 获取当前群聊记录内发言的人
|
||||
filter_query = {"chat_id": chat_stream_id}
|
||||
sort_order = [("time", -1)]
|
||||
recent_messages = find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
|
||||
|
||||
if not recent_messages:
|
||||
return []
|
||||
|
||||
who_chat_in_group = []
|
||||
for db_msg in recent_messages:
|
||||
# user_info = UserInfo.from_dict(
|
||||
# {
|
||||
# "platform": msg_db_data["user_platform"],
|
||||
# "user_id": msg_db_data["user_id"],
|
||||
# "user_nickname": msg_db_data["user_nickname"],
|
||||
# "user_cardname": msg_db_data.get("user_cardname", ""),
|
||||
# }
|
||||
# )
|
||||
# if (
|
||||
# (user_info.platform, user_info.user_id) != sender
|
||||
# and user_info.user_id != global_config.bot.qq_account
|
||||
# and (user_info.platform, user_info.user_id, user_info.user_nickname) not in who_chat_in_group
|
||||
# and len(who_chat_in_group) < 5
|
||||
# ): # 排除重复,排除消息发送者,排除bot,限制加载的关系数目
|
||||
# who_chat_in_group.append((user_info.platform, user_info.user_id, user_info.user_nickname))
|
||||
if (
|
||||
(db_msg.user_info.platform, db_msg.user_info.user_id) != sender
|
||||
and db_msg.user_info.user_id != global_config.bot.qq_account
|
||||
and (db_msg.user_info.platform, db_msg.user_info.user_id, db_msg.user_info.user_nickname)
|
||||
not in who_chat_in_group
|
||||
and len(who_chat_in_group) < 5
|
||||
): # 排除重复,排除消息发送者,排除bot,限制加载的关系数目
|
||||
who_chat_in_group.append(
|
||||
(db_msg.user_info.platform, db_msg.user_info.user_id, db_msg.user_info.user_nickname)
|
||||
)
|
||||
|
||||
return who_chat_in_group
|
||||
|
||||
|
||||
def split_into_sentences_w_remove_punctuation(text: str) -> list[str]:
|
||||
"""将文本分割成句子,并根据概率合并
|
||||
@@ -410,42 +441,6 @@ def calculate_typing_time(
|
||||
return total_time # 加上回车时间
|
||||
|
||||
|
||||
def cosine_similarity(v1, v2):
|
||||
"""计算余弦相似度"""
|
||||
dot_product = np.dot(v1, v2)
|
||||
norm1 = np.linalg.norm(v1)
|
||||
norm2 = np.linalg.norm(v2)
|
||||
return 0 if norm1 == 0 or norm2 == 0 else dot_product / (norm1 * norm2)
|
||||
|
||||
|
||||
def text_to_vector(text):
|
||||
"""将文本转换为词频向量"""
|
||||
# 分词
|
||||
words = jieba.lcut(text)
|
||||
return Counter(words)
|
||||
|
||||
|
||||
def find_similar_topics_simple(text: str, topics: list, top_k: int = 5) -> list:
|
||||
"""使用简单的余弦相似度计算文本相似度"""
|
||||
# 将输入文本转换为词频向量
|
||||
text_vector = text_to_vector(text)
|
||||
|
||||
# 计算每个主题的相似度
|
||||
similarities = []
|
||||
for topic in topics:
|
||||
topic_vector = text_to_vector(topic)
|
||||
# 获取所有唯一词
|
||||
all_words = set(text_vector.keys()) | set(topic_vector.keys())
|
||||
# 构建向量
|
||||
v1 = [text_vector.get(word, 0) for word in all_words]
|
||||
v2 = [topic_vector.get(word, 0) for word in all_words]
|
||||
# 计算相似度
|
||||
similarity = cosine_similarity(v1, v2)
|
||||
similarities.append((topic, similarity))
|
||||
|
||||
# 按相似度降序排序并返回前k个
|
||||
return sorted(similarities, key=lambda x: x[1], reverse=True)[:top_k]
|
||||
|
||||
|
||||
def truncate_message(message: str, max_length=20) -> str:
|
||||
"""截断消息,使其不超过指定长度"""
|
||||
@@ -523,47 +518,6 @@ def get_western_ratio(paragraph):
|
||||
return western_count / len(alnum_chars)
|
||||
|
||||
|
||||
def count_messages_between(start_time: float, end_time: float, stream_id: str) -> tuple[int, int]:
|
||||
"""计算两个时间点之间的消息数量和文本总长度
|
||||
|
||||
Args:
|
||||
start_time (float): 起始时间戳 (不包含)
|
||||
end_time (float): 结束时间戳 (包含)
|
||||
stream_id (str): 聊天流ID
|
||||
|
||||
Returns:
|
||||
tuple[int, int]: (消息数量, 文本总长度)
|
||||
"""
|
||||
count = 0
|
||||
total_length = 0
|
||||
|
||||
# 参数校验 (可选但推荐)
|
||||
if start_time >= end_time:
|
||||
# logger.debug(f"开始时间 {start_time} 大于或等于结束时间 {end_time},返回 0, 0")
|
||||
return 0, 0
|
||||
if not stream_id:
|
||||
logger.error("stream_id 不能为空")
|
||||
return 0, 0
|
||||
|
||||
# 使用message_repository中的count_messages和find_messages函数
|
||||
|
||||
# 构建查询条件
|
||||
filter_query = {"chat_id": stream_id, "time": {"$gt": start_time, "$lte": end_time}}
|
||||
|
||||
try:
|
||||
# 先获取消息数量
|
||||
count = count_messages(filter_query)
|
||||
|
||||
# 获取消息内容计算总长度
|
||||
messages = find_messages(message_filter=filter_query)
|
||||
total_length = sum(len(msg.processed_plain_text or "") for msg in messages)
|
||||
|
||||
return count, total_length
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"计算消息数量时发生意外错误: {e}")
|
||||
return 0, 0
|
||||
|
||||
|
||||
def translate_timestamp_to_human_readable(timestamp: float, mode: str = "normal") -> str:
|
||||
# sourcery skip: merge-comparisons, merge-duplicate-blocks, switch
|
||||
@@ -698,65 +652,6 @@ def assign_message_ids(messages: List[DatabaseMessages]) -> List[Tuple[str, Data
|
||||
return result
|
||||
|
||||
|
||||
# def assign_message_ids_flexible(
|
||||
# messages: list, prefix: str = "msg", id_length: int = 6, use_timestamp: bool = False
|
||||
# ) -> list:
|
||||
# """
|
||||
# 为消息列表中的每个消息分配唯一的简短随机ID(增强版)
|
||||
|
||||
# Args:
|
||||
# messages: 消息列表
|
||||
# prefix: ID前缀,默认为"msg"
|
||||
# id_length: ID的总长度(不包括前缀),默认为6
|
||||
# use_timestamp: 是否在ID中包含时间戳,默认为False
|
||||
|
||||
# Returns:
|
||||
# 包含 {'id': str, 'message': any} 格式的字典列表
|
||||
# """
|
||||
# result = []
|
||||
# used_ids = set()
|
||||
|
||||
# for i, message in enumerate(messages):
|
||||
# # 生成唯一的ID
|
||||
# while True:
|
||||
# if use_timestamp:
|
||||
# # 使用时间戳的后几位 + 随机字符
|
||||
# timestamp_suffix = str(int(time.time() * 1000))[-3:]
|
||||
# remaining_length = id_length - 3
|
||||
# random_chars = "".join(random.choices(string.ascii_lowercase + string.digits, k=remaining_length))
|
||||
# message_id = f"{prefix}{timestamp_suffix}{random_chars}"
|
||||
# else:
|
||||
# # 使用索引 + 随机字符
|
||||
# index_str = str(i + 1)
|
||||
# remaining_length = max(1, id_length - len(index_str))
|
||||
# random_chars = "".join(random.choices(string.ascii_lowercase + string.digits, k=remaining_length))
|
||||
# message_id = f"{prefix}{index_str}{random_chars}"
|
||||
|
||||
# if message_id not in used_ids:
|
||||
# used_ids.add(message_id)
|
||||
# break
|
||||
|
||||
# result.append({"id": message_id, "message": message})
|
||||
|
||||
# return result
|
||||
|
||||
|
||||
# 使用示例:
|
||||
# messages = ["Hello", "World", "Test message"]
|
||||
#
|
||||
# # 基础版本
|
||||
# result1 = assign_message_ids(messages)
|
||||
# # 结果: [{'id': 'm1123', 'message': 'Hello'}, {'id': 'm2456', 'message': 'World'}, {'id': 'm3789', 'message': 'Test message'}]
|
||||
#
|
||||
# # 增强版本 - 自定义前缀和长度
|
||||
# result2 = assign_message_ids_flexible(messages, prefix="chat", id_length=8)
|
||||
# # 结果: [{'id': 'chat1abc2', 'message': 'Hello'}, {'id': 'chat2def3', 'message': 'World'}, {'id': 'chat3ghi4', 'message': 'Test message'}]
|
||||
#
|
||||
# # 增强版本 - 使用时间戳
|
||||
# result3 = assign_message_ids_flexible(messages, prefix="ts", use_timestamp=True)
|
||||
# # 结果: [{'id': 'ts123a1b', 'message': 'Hello'}, {'id': 'ts123c2d', 'message': 'World'}, {'id': 'ts123e3f', 'message': 'Test message'}]
|
||||
|
||||
|
||||
def parse_keywords_string(keywords_input) -> list[str]:
|
||||
# sourcery skip: use-contextlib-suppress
|
||||
"""
|
||||
|
||||
@@ -44,6 +44,11 @@ class ImageManager:
|
||||
except Exception as e:
|
||||
logger.error(f"数据库连接或表创建失败: {e}")
|
||||
|
||||
try:
|
||||
self._cleanup_invalid_descriptions()
|
||||
except Exception as e:
|
||||
logger.warning(f"数据库清理失败: {e}")
|
||||
|
||||
self._initialized = True
|
||||
|
||||
def _ensure_image_dir(self):
|
||||
@@ -92,6 +97,26 @@ class ImageManager:
|
||||
except Exception as e:
|
||||
logger.error(f"保存描述到数据库失败 (Peewee): {str(e)}")
|
||||
|
||||
@staticmethod
|
||||
def _cleanup_invalid_descriptions():
|
||||
"""清理数据库中 description 为空或为 'None' 的记录"""
|
||||
invalid_values = ["", "None"]
|
||||
|
||||
# 清理 Images 表
|
||||
deleted_images = Images.delete().where(
|
||||
(Images.description >> None) | (Images.description << invalid_values)
|
||||
).execute()
|
||||
|
||||
# 清理 ImageDescriptions 表
|
||||
deleted_descriptions = ImageDescriptions.delete().where(
|
||||
(ImageDescriptions.description >> None) | (ImageDescriptions.description << invalid_values)
|
||||
).execute()
|
||||
|
||||
if deleted_images or deleted_descriptions:
|
||||
logger.info(f"[清理完成] 删除 Images: {deleted_images} 条, ImageDescriptions: {deleted_descriptions} 条")
|
||||
else:
|
||||
logger.info("[清理完成] 未发现无效描述记录")
|
||||
|
||||
async def get_emoji_tag(self, image_base64: str) -> str:
|
||||
from src.chat.emoji_system.emoji_manager import get_emoji_manager
|
||||
|
||||
@@ -273,7 +298,7 @@ class ImageManager:
|
||||
prompt = global_config.personality.visual_style
|
||||
logger.info(f"[VLM调用] 为图片生成新描述 (Hash: {image_hash[:8]}...)")
|
||||
description, _ = await self.vlm.generate_response_for_image(
|
||||
prompt, image_base64, image_format, temperature=0.4, max_tokens=300
|
||||
prompt, image_base64, image_format, temperature=0.4
|
||||
)
|
||||
|
||||
if description is None:
|
||||
@@ -570,7 +595,7 @@ class ImageManager:
|
||||
|
||||
# 获取VLM描述
|
||||
description, _ = await self.vlm.generate_response_for_image(
|
||||
prompt, image_base64, image_format, temperature=0.4, max_tokens=300
|
||||
prompt, image_base64, image_format, temperature=0.4
|
||||
)
|
||||
|
||||
if description is None:
|
||||
|
||||
@@ -303,15 +303,13 @@ class Expression(BaseModel):
|
||||
|
||||
situation = TextField()
|
||||
style = TextField()
|
||||
count = FloatField()
|
||||
|
||||
# new mode fields
|
||||
context = TextField(null=True)
|
||||
context_words = TextField(null=True)
|
||||
up_content = TextField(null=True)
|
||||
|
||||
last_active_time = FloatField()
|
||||
chat_id = TextField(index=True)
|
||||
type = TextField()
|
||||
create_date = FloatField(null=True) # 创建日期,允许为空以兼容老数据
|
||||
|
||||
class Meta:
|
||||
|
||||
@@ -55,7 +55,7 @@ TEMPLATE_DIR = os.path.join(PROJECT_ROOT, "template")
|
||||
|
||||
# 考虑到,实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码
|
||||
# 对该字段的更新,请严格参照语义化版本规范:https://semver.org/lang/zh-CN/
|
||||
MMC_VERSION = "0.11.0-snapshot.3"
|
||||
MMC_VERSION = "0.11.1-snapshot.1"
|
||||
|
||||
|
||||
def get_key_comment(toml_table, key):
|
||||
|
||||
@@ -27,6 +27,9 @@ class BotConfig(ConfigBase):
|
||||
|
||||
nickname: str
|
||||
"""昵称"""
|
||||
|
||||
platforms: list[str] = field(default_factory=lambda: [])
|
||||
"""其他平台列表"""
|
||||
|
||||
alias_names: list[str] = field(default_factory=lambda: [])
|
||||
"""别名列表"""
|
||||
@@ -54,6 +57,12 @@ class PersonalityConfig(ConfigBase):
|
||||
private_plan_style: str = ""
|
||||
"""私聊说话规则,行为风格"""
|
||||
|
||||
states: list[str] = field(default_factory=lambda: [])
|
||||
"""状态列表,用于随机替换personality"""
|
||||
|
||||
state_probability: float = 0.0
|
||||
"""状态概率,每次构建人格时替换personality的概率"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class RelationshipConfig(ConfigBase):
|
||||
@@ -82,6 +91,9 @@ class ChatConfig(ConfigBase):
|
||||
auto_chat_value: float = 1
|
||||
"""自动聊天,越小,麦麦主动聊天的概率越低"""
|
||||
|
||||
enable_auto_chat_value_rules: bool = True
|
||||
"""是否启用动态自动聊天频率规则"""
|
||||
|
||||
at_bot_inevitable_reply: float = 1
|
||||
"""@bot 必然回复,1为100%回复,0为不额外增幅"""
|
||||
|
||||
@@ -91,6 +103,9 @@ class ChatConfig(ConfigBase):
|
||||
talk_value: float = 1
|
||||
"""思考频率"""
|
||||
|
||||
enable_talk_value_rules: bool = True
|
||||
"""是否启用动态发言频率规则"""
|
||||
|
||||
talk_value_rules: list[dict] = field(default_factory=lambda: [])
|
||||
"""
|
||||
思考频率规则列表,支持按聊天流/按日内时段配置。
|
||||
@@ -177,7 +192,7 @@ class ChatConfig(ConfigBase):
|
||||
|
||||
def get_talk_value(self, chat_id: Optional[str]) -> float:
|
||||
"""根据规则返回当前 chat 的动态 talk_value,未匹配则回退到基础值。"""
|
||||
if not self.talk_value_rules:
|
||||
if not self.enable_talk_value_rules or not self.talk_value_rules:
|
||||
return self.talk_value
|
||||
|
||||
now_min = self._now_minutes()
|
||||
@@ -232,7 +247,7 @@ class ChatConfig(ConfigBase):
|
||||
|
||||
def get_auto_chat_value(self, chat_id: Optional[str]) -> float:
|
||||
"""根据规则返回当前 chat 的动态 auto_chat_value,未匹配则回退到基础值。"""
|
||||
if not self.auto_chat_value_rules:
|
||||
if not self.enable_auto_chat_value_rules or not self.auto_chat_value_rules:
|
||||
return self.auto_chat_value
|
||||
|
||||
now_min = self._now_minutes()
|
||||
@@ -310,8 +325,8 @@ class MemoryConfig(ConfigBase):
|
||||
class ExpressionConfig(ConfigBase):
|
||||
"""表达配置类"""
|
||||
|
||||
mode: Literal["llm", "context", "full-context"] = "context"
|
||||
"""表达方式模式,可选:llm模式,context上下文模式,full-context 完整上下文嵌入模式"""
|
||||
mode: str = "classic"
|
||||
"""表达方式模式,可选:classic经典模式,exp_model 表达模型模式"""
|
||||
|
||||
learning_list: list[list] = field(default_factory=lambda: [])
|
||||
"""
|
||||
@@ -626,6 +641,12 @@ class DebugConfig(ConfigBase):
|
||||
|
||||
show_prompt: bool = False
|
||||
"""是否显示prompt"""
|
||||
|
||||
show_replyer_prompt: bool = True
|
||||
"""是否显示回复器prompt"""
|
||||
|
||||
show_replyer_reasoning: bool = True
|
||||
"""是否显示回复器推理"""
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
93
src/express/express_utils.py
Normal file
93
src/express/express_utils.py
Normal file
@@ -0,0 +1,93 @@
|
||||
import re
|
||||
import difflib
|
||||
import random
|
||||
from datetime import datetime
|
||||
from typing import Optional, List, Dict
|
||||
from collections import defaultdict
|
||||
|
||||
|
||||
def filter_message_content(content: Optional[str]) -> str:
|
||||
"""
|
||||
过滤消息内容,移除回复、@、图片等格式
|
||||
|
||||
Args:
|
||||
content: 原始消息内容
|
||||
|
||||
Returns:
|
||||
str: 过滤后的内容
|
||||
"""
|
||||
if not content:
|
||||
return ""
|
||||
|
||||
# 移除以[回复开头、]结尾的部分,包括后面的",说:"部分
|
||||
content = re.sub(r'\[回复.*?\],说:\s*', '', content)
|
||||
# 移除@<...>格式的内容
|
||||
content = re.sub(r'@<[^>]*>', '', content)
|
||||
# 移除[picid:...]格式的图片ID
|
||||
content = re.sub(r'\[picid:[^\]]*\]', '', content)
|
||||
# 移除[表情包:...]格式的内容
|
||||
content = re.sub(r'\[表情包:[^\]]*\]', '', content)
|
||||
|
||||
return content.strip()
|
||||
|
||||
|
||||
def calculate_similarity(text1: str, text2: str) -> float:
|
||||
"""
|
||||
计算两个文本的相似度,返回0-1之间的值
|
||||
使用SequenceMatcher计算相似度
|
||||
|
||||
Args:
|
||||
text1: 第一个文本
|
||||
text2: 第二个文本
|
||||
|
||||
Returns:
|
||||
float: 相似度值,范围0-1
|
||||
"""
|
||||
return difflib.SequenceMatcher(None, text1, text2).ratio()
|
||||
|
||||
|
||||
def format_create_date(timestamp: float) -> str:
|
||||
"""
|
||||
将时间戳格式化为可读的日期字符串
|
||||
|
||||
Args:
|
||||
timestamp: 时间戳
|
||||
|
||||
Returns:
|
||||
str: 格式化后的日期字符串
|
||||
"""
|
||||
try:
|
||||
return datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %H:%M:%S")
|
||||
except (ValueError, OSError):
|
||||
return "未知时间"
|
||||
|
||||
|
||||
def weighted_sample(population: List[Dict], k: int) -> List[Dict]:
|
||||
"""
|
||||
随机抽样函数
|
||||
|
||||
Args:
|
||||
population: 总体数据列表
|
||||
k: 需要抽取的数量
|
||||
|
||||
Returns:
|
||||
List[Dict]: 抽取的数据列表
|
||||
"""
|
||||
if not population or k <= 0:
|
||||
return []
|
||||
|
||||
if len(population) <= k:
|
||||
return population.copy()
|
||||
|
||||
# 使用随机抽样
|
||||
selected = []
|
||||
population_copy = population.copy()
|
||||
|
||||
for _ in range(k):
|
||||
if not population_copy:
|
||||
break
|
||||
# 随机选择一个元素
|
||||
idx = random.randint(0, len(population_copy) - 1)
|
||||
selected.append(population_copy.pop(idx))
|
||||
|
||||
return selected
|
||||
@@ -1,10 +1,8 @@
|
||||
import time
|
||||
import random
|
||||
import json
|
||||
import os
|
||||
from datetime import datetime
|
||||
import jieba
|
||||
from typing import List, Dict, Optional, Any, Tuple
|
||||
from typing import List, Optional, Tuple
|
||||
import traceback
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import Expression
|
||||
@@ -17,26 +15,16 @@ from src.chat.utils.chat_message_builder import (
|
||||
)
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.express.style_learner import style_learner_manager
|
||||
from src.express.express_utils import filter_message_content, calculate_similarity
|
||||
from json_repair import repair_json
|
||||
|
||||
|
||||
MAX_EXPRESSION_COUNT = 300
|
||||
DECAY_DAYS = 15 # 30天衰减到0.01
|
||||
DECAY_MIN = 0.01 # 最小衰减值
|
||||
# MAX_EXPRESSION_COUNT = 300
|
||||
|
||||
logger = get_logger("expressor")
|
||||
|
||||
|
||||
def format_create_date(timestamp: float) -> str:
|
||||
"""
|
||||
将时间戳格式化为可读的日期字符串
|
||||
"""
|
||||
try:
|
||||
return datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %H:%M:%S")
|
||||
except (ValueError, OSError):
|
||||
return "未知时间"
|
||||
|
||||
|
||||
def init_prompt() -> None:
|
||||
learn_style_prompt = """
|
||||
{chat_str}
|
||||
@@ -105,8 +93,8 @@ class ExpressionLearner:
|
||||
_, self.enable_learning, self.learning_intensity = global_config.expression.get_expression_config_for_chat(
|
||||
self.chat_id
|
||||
)
|
||||
self.min_messages_for_learning = 15 / self.learning_intensity # 触发学习所需的最少消息数
|
||||
self.min_learning_interval = 150 / self.learning_intensity
|
||||
self.min_messages_for_learning = 30 / self.learning_intensity # 触发学习所需的最少消息数
|
||||
self.min_learning_interval = 300 / self.learning_intensity
|
||||
|
||||
def should_trigger_learning(self) -> bool:
|
||||
"""
|
||||
@@ -139,7 +127,7 @@ class ExpressionLearner:
|
||||
|
||||
return True
|
||||
|
||||
async def trigger_learning_for_chat(self) -> bool:
|
||||
async def trigger_learning_for_chat(self):
|
||||
"""
|
||||
为指定聊天流触发学习
|
||||
|
||||
@@ -150,11 +138,10 @@ class ExpressionLearner:
|
||||
bool: 是否成功触发学习
|
||||
"""
|
||||
if not self.should_trigger_learning():
|
||||
return False
|
||||
return
|
||||
|
||||
try:
|
||||
logger.info(f"为聊天流 {self.chat_name} 触发表达学习")
|
||||
|
||||
logger.info(f"在聊天流 {self.chat_name} 学习表达方式")
|
||||
# 学习语言风格
|
||||
learnt_style = await self.learn_and_store(num=25)
|
||||
|
||||
@@ -163,161 +150,105 @@ class ExpressionLearner:
|
||||
|
||||
if learnt_style:
|
||||
logger.info(f"聊天流 {self.chat_name} 表达学习完成")
|
||||
return True
|
||||
else:
|
||||
logger.warning(f"聊天流 {self.chat_name} 表达学习未获得有效结果")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"为聊天流 {self.chat_name} 触发学习失败: {e}")
|
||||
traceback.print_exc()
|
||||
return False
|
||||
return
|
||||
|
||||
def _apply_global_decay_to_database(self, current_time: float) -> None:
|
||||
"""
|
||||
对数据库中的所有表达方式应用全局衰减
|
||||
"""
|
||||
try:
|
||||
# 获取所有表达方式
|
||||
all_expressions = Expression.select()
|
||||
|
||||
updated_count = 0
|
||||
deleted_count = 0
|
||||
|
||||
for expr in all_expressions:
|
||||
# 计算时间差
|
||||
last_active = expr.last_active_time
|
||||
time_diff_days = (current_time - last_active) / (24 * 3600) # 转换为天
|
||||
|
||||
# 计算衰减值
|
||||
decay_value = self.calculate_decay_factor(time_diff_days)
|
||||
new_count = max(0.01, expr.count - decay_value)
|
||||
|
||||
if new_count <= 0.01:
|
||||
# 如果count太小,删除这个表达方式
|
||||
expr.delete_instance()
|
||||
deleted_count += 1
|
||||
else:
|
||||
# 更新count
|
||||
expr.count = new_count
|
||||
expr.save()
|
||||
updated_count += 1
|
||||
|
||||
if updated_count > 0 or deleted_count > 0:
|
||||
logger.info(f"全局衰减完成:更新了 {updated_count} 个表达方式,删除了 {deleted_count} 个表达方式")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"数据库全局衰减失败: {e}")
|
||||
|
||||
def calculate_decay_factor(self, time_diff_days: float) -> float:
|
||||
"""
|
||||
计算衰减值
|
||||
当时间差为0天时,衰减值为0(最近活跃的不衰减)
|
||||
当时间差为7天时,衰减值为0.002(中等衰减)
|
||||
当时间差为30天或更长时,衰减值为0.01(高衰减)
|
||||
使用二次函数进行曲线插值
|
||||
"""
|
||||
if time_diff_days <= 0:
|
||||
return 0.0 # 刚激活的表达式不衰减
|
||||
|
||||
if time_diff_days >= DECAY_DAYS:
|
||||
return 0.01 # 长时间未活跃的表达式大幅衰减
|
||||
|
||||
# 使用二次函数插值:在0-30天之间从0衰减到0.01
|
||||
# 使用简单的二次函数:y = a * x^2
|
||||
# 当x=30时,y=0.01,所以 a = 0.01 / (30^2) = 0.01 / 900
|
||||
a = 0.01 / (DECAY_DAYS**2)
|
||||
decay = a * (time_diff_days**2)
|
||||
|
||||
return min(0.01, decay)
|
||||
|
||||
async def learn_and_store(self, num: int = 10) -> List[Tuple[str, str, str]]:
|
||||
"""
|
||||
学习并存储表达方式
|
||||
"""
|
||||
res = await self.learn_expression(num)
|
||||
learnt_expressions = await self.learn_expression(num)
|
||||
|
||||
if res is None:
|
||||
if learnt_expressions is None:
|
||||
logger.info("没有学习到表达风格")
|
||||
return []
|
||||
learnt_expressions = res
|
||||
|
||||
# 展示学到的表达方式
|
||||
learnt_expressions_str = ""
|
||||
for (
|
||||
_chat_id,
|
||||
situation,
|
||||
style,
|
||||
_context,
|
||||
_context_words,
|
||||
_up_content,
|
||||
) in learnt_expressions:
|
||||
learnt_expressions_str += f"{situation}->{style}\n"
|
||||
|
||||
logger.info(f"在 {self.chat_name} 学习到表达风格:\n{learnt_expressions_str}")
|
||||
|
||||
# 按chat_id分组
|
||||
chat_dict: Dict[str, List[Dict[str, Any]]] = {}
|
||||
for (
|
||||
chat_id,
|
||||
situation,
|
||||
style,
|
||||
context,
|
||||
context_words,
|
||||
) in learnt_expressions:
|
||||
if chat_id not in chat_dict:
|
||||
chat_dict[chat_id] = []
|
||||
chat_dict[chat_id].append(
|
||||
{
|
||||
"situation": situation,
|
||||
"style": style,
|
||||
"context": context,
|
||||
"context_words": context_words,
|
||||
}
|
||||
)
|
||||
|
||||
current_time = time.time()
|
||||
|
||||
# 存储到数据库 Expression 表
|
||||
for chat_id, expr_list in chat_dict.items():
|
||||
for new_expr in expr_list:
|
||||
# 查找是否已存在相似表达方式
|
||||
query = Expression.select().where(
|
||||
(Expression.chat_id == chat_id)
|
||||
& (Expression.type == "style")
|
||||
& (Expression.situation == new_expr["situation"])
|
||||
& (Expression.style == new_expr["style"])
|
||||
)
|
||||
if query.exists():
|
||||
expr_obj = query.get()
|
||||
# 50%概率替换内容
|
||||
if random.random() < 0.5:
|
||||
expr_obj.situation = new_expr["situation"]
|
||||
expr_obj.style = new_expr["style"]
|
||||
expr_obj.context = new_expr["context"]
|
||||
expr_obj.context_words = new_expr["context_words"]
|
||||
expr_obj.count = expr_obj.count + 1
|
||||
expr_obj.last_active_time = current_time
|
||||
expr_obj.save()
|
||||
else:
|
||||
Expression.create(
|
||||
situation=new_expr["situation"],
|
||||
style=new_expr["style"],
|
||||
count=1,
|
||||
last_active_time=current_time,
|
||||
chat_id=chat_id,
|
||||
type="style",
|
||||
create_date=current_time, # 手动设置创建日期
|
||||
context=new_expr["context"],
|
||||
context_words=new_expr["context_words"],
|
||||
)
|
||||
# 限制最大数量
|
||||
exprs = list(
|
||||
Expression.select()
|
||||
.where((Expression.chat_id == chat_id) & (Expression.type == "style"))
|
||||
.order_by(Expression.count.asc())
|
||||
# 存储到数据库 Expression 表并训练 style_learner
|
||||
has_new_expressions = False # 记录是否有新的表达方式
|
||||
learner = style_learner_manager.get_learner(self.chat_id) # 获取 learner 实例
|
||||
|
||||
for (
|
||||
situation,
|
||||
style,
|
||||
context,
|
||||
up_content,
|
||||
) in learnt_expressions:
|
||||
# 查找是否已存在相似表达方式
|
||||
query = Expression.select().where(
|
||||
(Expression.chat_id == self.chat_id)
|
||||
& (Expression.situation == situation)
|
||||
& (Expression.style == style)
|
||||
)
|
||||
if len(exprs) > MAX_EXPRESSION_COUNT:
|
||||
# 删除count最小的多余表达方式
|
||||
for expr in exprs[: len(exprs) - MAX_EXPRESSION_COUNT]:
|
||||
expr.delete_instance()
|
||||
if query.exists():
|
||||
# 表达方式完全相同,只更新时间戳
|
||||
expr_obj = query.get()
|
||||
expr_obj.last_active_time = current_time
|
||||
expr_obj.save()
|
||||
continue
|
||||
else:
|
||||
Expression.create(
|
||||
situation=situation,
|
||||
style=style,
|
||||
last_active_time=current_time,
|
||||
chat_id=self.chat_id,
|
||||
create_date=current_time, # 手动设置创建日期
|
||||
context=context,
|
||||
up_content=up_content,
|
||||
)
|
||||
has_new_expressions = True
|
||||
|
||||
# 训练 style_learner(up_content 和 style 必定存在)
|
||||
try:
|
||||
learner.add_style(style, situation)
|
||||
|
||||
# 学习映射关系
|
||||
success = style_learner_manager.learn_mapping(
|
||||
self.chat_id,
|
||||
up_content,
|
||||
style
|
||||
)
|
||||
if success:
|
||||
logger.debug(f"StyleLearner学习成功: {self.chat_id} - {up_content} -> {style}" + (f" (situation: {situation})" if situation else ""))
|
||||
else:
|
||||
logger.warning(f"StyleLearner学习失败: {self.chat_id} - {up_content} -> {style}")
|
||||
except Exception as e:
|
||||
logger.error(f"StyleLearner学习异常: {self.chat_id} - {e}")
|
||||
|
||||
|
||||
# 保存当前聊天室的 style_learner 模型
|
||||
if has_new_expressions:
|
||||
try:
|
||||
logger.info(f"开始保存聊天室 {self.chat_id} 的 StyleLearner 模型...")
|
||||
save_success = learner.save(style_learner_manager.model_save_path)
|
||||
|
||||
if save_success:
|
||||
logger.info(f"StyleLearner 模型保存成功,聊天室: {self.chat_id}")
|
||||
else:
|
||||
logger.warning(f"StyleLearner 模型保存失败,聊天室: {self.chat_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"StyleLearner 模型保存异常: {e}")
|
||||
|
||||
return learnt_expressions
|
||||
|
||||
async def match_expression_context(
|
||||
@@ -339,8 +270,8 @@ class ExpressionLearner:
|
||||
|
||||
response, _ = await self.express_learn_model.generate_response_async(prompt, temperature=0.3)
|
||||
|
||||
print(f"match_expression_context_prompt: {prompt}")
|
||||
print(f"random_msg_match_str: {response}")
|
||||
# print(f"match_expression_context_prompt: {prompt}")
|
||||
# print(f"{response}")
|
||||
|
||||
# 解析JSON响应
|
||||
match_responses = []
|
||||
@@ -393,24 +324,44 @@ class ExpressionLearner:
|
||||
logger.error(f"解析匹配响应JSON失败: {e}, 响应内容: \n{response}")
|
||||
return []
|
||||
|
||||
# 确保 match_responses 是一个列表
|
||||
if not isinstance(match_responses, list):
|
||||
if isinstance(match_responses, dict):
|
||||
match_responses = [match_responses]
|
||||
else:
|
||||
logger.error(f"match_responses 不是列表或字典类型: {type(match_responses)}, 内容: {match_responses}")
|
||||
return []
|
||||
|
||||
matched_expressions = []
|
||||
used_pair_indices = set() # 用于跟踪已经使用的expression_pair索引
|
||||
|
||||
logger.debug(f"match_responses 类型: {type(match_responses)}, 长度: {len(match_responses)}")
|
||||
logger.debug(f"match_responses 内容: {match_responses}")
|
||||
|
||||
for match_response in match_responses:
|
||||
try:
|
||||
# 检查 match_response 的类型
|
||||
if not isinstance(match_response, dict):
|
||||
logger.error(f"match_response 不是字典类型: {type(match_response)}, 内容: {match_response}")
|
||||
continue
|
||||
|
||||
# 获取表达方式序号
|
||||
if "expression_pair" not in match_response:
|
||||
logger.error(f"match_response 缺少 'expression_pair' 字段: {match_response}")
|
||||
continue
|
||||
|
||||
pair_index = int(match_response["expression_pair"]) - 1 # 转换为0-based索引
|
||||
|
||||
# 检查索引是否有效且未被使用过
|
||||
if 0 <= pair_index < len(expression_pairs) and pair_index not in used_pair_indices:
|
||||
situation, style = expression_pairs[pair_index]
|
||||
context = match_response["context"]
|
||||
context = match_response.get("context", "")
|
||||
matched_expressions.append((situation, style, context))
|
||||
used_pair_indices.add(pair_index) # 标记该索引已使用
|
||||
logger.debug(f"成功匹配表达方式 {pair_index + 1}: {situation} -> {style}")
|
||||
elif pair_index in used_pair_indices:
|
||||
logger.debug(f"跳过重复的表达方式 {pair_index + 1}")
|
||||
except (ValueError, KeyError, IndexError) as e:
|
||||
except (ValueError, KeyError, IndexError, TypeError) as e:
|
||||
logger.error(f"解析匹配条目失败: {e}, 条目: {match_response}")
|
||||
continue
|
||||
|
||||
@@ -418,15 +369,12 @@ class ExpressionLearner:
|
||||
|
||||
async def learn_expression(
|
||||
self, num: int = 10
|
||||
) -> Optional[List[Tuple[str, str, str, List[str]]]]:
|
||||
) -> Optional[List[Tuple[str, str, str, str]]]:
|
||||
"""从指定聊天流学习表达方式
|
||||
|
||||
Args:
|
||||
num: 学习数量
|
||||
"""
|
||||
type_str = "语言风格"
|
||||
prompt = "learn_style_prompt"
|
||||
|
||||
current_time = time.time()
|
||||
|
||||
# 获取上次学习之后的消息
|
||||
@@ -439,14 +387,14 @@ class ExpressionLearner:
|
||||
# print(random_msg)
|
||||
if not random_msg or random_msg == []:
|
||||
return None
|
||||
# 转化成str
|
||||
_chat_id: str = random_msg[0].chat_id
|
||||
# random_msg_str: str = build_readable_messages(random_msg, timestamp_mode="normal")
|
||||
|
||||
# 学习用
|
||||
random_msg_str: str = await build_anonymous_messages(random_msg)
|
||||
# 溯源用
|
||||
random_msg_match_str: str = await build_bare_messages(random_msg)
|
||||
|
||||
prompt: str = await global_prompt_manager.format_prompt(
|
||||
prompt,
|
||||
"learn_style_prompt",
|
||||
chat_str=random_msg_str,
|
||||
)
|
||||
|
||||
@@ -456,49 +404,51 @@ class ExpressionLearner:
|
||||
try:
|
||||
response, _ = await self.express_learn_model.generate_response_async(prompt, temperature=0.3)
|
||||
except Exception as e:
|
||||
logger.error(f"学习{type_str}失败: {e}")
|
||||
logger.error(f"学习表达方式失败,模型生成出错: {e}")
|
||||
return None
|
||||
|
||||
expressions: List[Tuple[str, str]] = self.parse_expression_response(response)
|
||||
# logger.debug(f"学习{type_str}的response: {response}")
|
||||
|
||||
expressions: List[Tuple[str, str]] = self.parse_expression_response(response)
|
||||
|
||||
|
||||
# 对表达方式溯源
|
||||
matched_expressions: List[Tuple[str, str, str]] = await self.match_expression_context(
|
||||
expressions, random_msg_match_str
|
||||
)
|
||||
|
||||
split_matched_expressions: List[Tuple[str, str, str, List[str]]] = self.split_expression_context(
|
||||
matched_expressions
|
||||
)
|
||||
|
||||
split_matched_expressions_w_emb = []
|
||||
|
||||
for situation, style, context, context_words in split_matched_expressions:
|
||||
split_matched_expressions_w_emb.append(
|
||||
(self.chat_id, situation, style, context, context_words)
|
||||
)
|
||||
|
||||
return split_matched_expressions_w_emb
|
||||
|
||||
def split_expression_context(
|
||||
self, matched_expressions: List[Tuple[str, str, str]]
|
||||
) -> List[Tuple[str, str, str, List[str]]]:
|
||||
"""
|
||||
对matched_expressions中的context部分进行jieba分词
|
||||
|
||||
Args:
|
||||
matched_expressions: 匹配到的表达方式列表,每个元素为(situation, style, context)
|
||||
|
||||
Returns:
|
||||
添加了分词结果的表达方式列表,每个元素为(situation, style, context, context_words)
|
||||
"""
|
||||
result = []
|
||||
# 为每条消息构建精简文本列表,保留到原消息索引的映射
|
||||
bare_lines: List[Tuple[int, str]] = self._build_bare_lines(random_msg)
|
||||
# 将 matched_expressions 结合上一句 up_content(若不存在上一句则跳过)
|
||||
filtered_with_up: List[Tuple[str, str, str, str]] = [] # (situation, style, context, up_content)
|
||||
for situation, style, context in matched_expressions:
|
||||
# 使用jieba进行分词
|
||||
context_words = list(jieba.cut(context))
|
||||
result.append((situation, style, context, context_words))
|
||||
# 在 bare_lines 中找到第一处相似度达到85%的行
|
||||
pos = None
|
||||
for i, (_, c) in enumerate(bare_lines):
|
||||
similarity = calculate_similarity(c, context)
|
||||
if similarity >= 0.85: # 85%相似度阈值
|
||||
pos = i
|
||||
break
|
||||
|
||||
if pos is None or pos == 0:
|
||||
# 没有匹配到目标句或没有上一句,跳过该表达
|
||||
continue
|
||||
|
||||
# 检查目标句是否为空
|
||||
target_content = bare_lines[pos][1]
|
||||
if not target_content:
|
||||
# 目标句为空,跳过该表达
|
||||
continue
|
||||
|
||||
prev_original_idx = bare_lines[pos - 1][0]
|
||||
up_content = filter_message_content(random_msg[prev_original_idx].processed_plain_text or "")
|
||||
if not up_content:
|
||||
# 上一句为空,跳过该表达
|
||||
continue
|
||||
filtered_with_up.append((situation, style, context, up_content))
|
||||
|
||||
if not filtered_with_up:
|
||||
return None
|
||||
|
||||
return filtered_with_up
|
||||
|
||||
return result
|
||||
|
||||
def parse_expression_response(self, response: str) -> List[Tuple[str, str, str]]:
|
||||
"""
|
||||
@@ -530,6 +480,26 @@ class ExpressionLearner:
|
||||
expressions.append((situation, style))
|
||||
return expressions
|
||||
|
||||
def _build_bare_lines(self, messages: List) -> List[Tuple[int, str]]:
|
||||
"""
|
||||
为每条消息构建精简文本列表,保留到原消息索引的映射
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
|
||||
Returns:
|
||||
List[Tuple[int, str]]: (original_index, bare_content) 元组列表
|
||||
"""
|
||||
bare_lines: List[Tuple[int, str]] = []
|
||||
|
||||
for idx, msg in enumerate(messages):
|
||||
content = msg.processed_plain_text or ""
|
||||
content = filter_message_content(content)
|
||||
# 即使content为空也要记录,防止错位
|
||||
bare_lines.append((idx, content))
|
||||
|
||||
return bare_lines
|
||||
|
||||
|
||||
init_prompt()
|
||||
|
||||
446
src/express/expression_selector.py
Normal file
446
src/express/expression_selector.py
Normal file
@@ -0,0 +1,446 @@
|
||||
import json
|
||||
import time
|
||||
import random
|
||||
import hashlib
|
||||
import re
|
||||
|
||||
from typing import List, Dict, Optional, Any, Tuple
|
||||
from json_repair import repair_json
|
||||
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config, model_config
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import Expression
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.express.style_learner import style_learner_manager
|
||||
from src.express.express_utils import filter_message_content, weighted_sample
|
||||
|
||||
logger = get_logger("expression_selector")
|
||||
|
||||
|
||||
def init_prompt():
|
||||
expression_evaluation_prompt = """
|
||||
以下是正在进行的聊天内容:
|
||||
{chat_observe_info}
|
||||
|
||||
你的名字是{bot_name}{target_message}
|
||||
|
||||
以下是可选的表达情境:
|
||||
{all_situations}
|
||||
|
||||
请你分析聊天内容的语境、情绪、话题类型,从上述情境中选择最适合当前聊天情境的,最多{max_num}个情境。
|
||||
考虑因素包括:
|
||||
1. 聊天的情绪氛围(轻松、严肃、幽默等)
|
||||
2. 话题类型(日常、技术、游戏、情感等)
|
||||
3. 情境与当前语境的匹配度
|
||||
{target_message_extra_block}
|
||||
|
||||
请以JSON格式输出,只需要输出选中的情境编号:
|
||||
例如:
|
||||
{{
|
||||
"selected_situations": [2, 3, 5, 7, 19]
|
||||
}}
|
||||
|
||||
请严格按照JSON格式输出,不要包含其他内容:
|
||||
"""
|
||||
Prompt(expression_evaluation_prompt, "expression_evaluation_prompt")
|
||||
|
||||
|
||||
class ExpressionSelector:
|
||||
def __init__(self):
|
||||
self.llm_model = LLMRequest(
|
||||
model_set=model_config.model_task_config.utils_small, request_type="expression.selector"
|
||||
)
|
||||
|
||||
def can_use_expression_for_chat(self, chat_id: str) -> bool:
|
||||
"""
|
||||
检查指定聊天流是否允许使用表达
|
||||
|
||||
Args:
|
||||
chat_id: 聊天流ID
|
||||
|
||||
Returns:
|
||||
bool: 是否允许使用表达
|
||||
"""
|
||||
try:
|
||||
use_expression, _, _ = global_config.expression.get_expression_config_for_chat(chat_id)
|
||||
return use_expression
|
||||
except Exception as e:
|
||||
logger.error(f"检查表达使用权限失败: {e}")
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _parse_stream_config_to_chat_id(stream_config_str: str) -> Optional[str]:
|
||||
"""解析'platform:id:type'为chat_id(与get_stream_id一致)"""
|
||||
try:
|
||||
parts = stream_config_str.split(":")
|
||||
if len(parts) != 3:
|
||||
return None
|
||||
platform = parts[0]
|
||||
id_str = parts[1]
|
||||
stream_type = parts[2]
|
||||
is_group = stream_type == "group"
|
||||
if is_group:
|
||||
components = [platform, str(id_str)]
|
||||
else:
|
||||
components = [platform, str(id_str), "private"]
|
||||
key = "_".join(components)
|
||||
return hashlib.md5(key.encode()).hexdigest()
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_related_chat_ids(self, chat_id: str) -> List[str]:
|
||||
"""根据expression_groups配置,获取与当前chat_id相关的所有chat_id(包括自身)"""
|
||||
groups = global_config.expression.expression_groups
|
||||
|
||||
# 检查是否存在全局共享组(包含"*"的组)
|
||||
global_group_exists = any("*" in group for group in groups)
|
||||
|
||||
if global_group_exists:
|
||||
# 如果存在全局共享组,则返回所有可用的chat_id
|
||||
all_chat_ids = set()
|
||||
for group in groups:
|
||||
for stream_config_str in group:
|
||||
if chat_id_candidate := self._parse_stream_config_to_chat_id(stream_config_str):
|
||||
all_chat_ids.add(chat_id_candidate)
|
||||
return list(all_chat_ids) if all_chat_ids else [chat_id]
|
||||
|
||||
# 否则使用现有的组逻辑
|
||||
for group in groups:
|
||||
group_chat_ids = []
|
||||
for stream_config_str in group:
|
||||
if chat_id_candidate := self._parse_stream_config_to_chat_id(stream_config_str):
|
||||
group_chat_ids.append(chat_id_candidate)
|
||||
if chat_id in group_chat_ids:
|
||||
return group_chat_ids
|
||||
return [chat_id]
|
||||
|
||||
def get_model_predicted_expressions(self, chat_id: str, target_message: str, total_num: int = 10) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
使用 style_learner 模型预测最合适的表达方式
|
||||
|
||||
Args:
|
||||
chat_id: 聊天室ID
|
||||
target_message: 目标消息内容
|
||||
total_num: 需要预测的数量
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 预测的表达方式列表
|
||||
"""
|
||||
try:
|
||||
# 过滤目标消息内容,移除回复、表情包等特殊格式
|
||||
filtered_target_message = filter_message_content(target_message)
|
||||
|
||||
logger.info(f"为{chat_id} 预测表达方式,过滤后的目标消息内容: {filtered_target_message}")
|
||||
|
||||
# 支持多chat_id合并预测
|
||||
related_chat_ids = self.get_related_chat_ids(chat_id)
|
||||
|
||||
|
||||
predicted_expressions = []
|
||||
|
||||
# 为每个相关的chat_id进行预测
|
||||
for related_chat_id in related_chat_ids:
|
||||
try:
|
||||
# 使用 style_learner 预测最合适的风格
|
||||
best_style, scores = style_learner_manager.predict_style(
|
||||
related_chat_id, filtered_target_message, top_k=total_num
|
||||
)
|
||||
|
||||
if best_style and scores:
|
||||
# 获取预测风格的完整信息
|
||||
learner = style_learner_manager.get_learner(related_chat_id)
|
||||
style_id, situation = learner.get_style_info(best_style)
|
||||
|
||||
if style_id and situation:
|
||||
# 从数据库查找对应的表达记录
|
||||
expr_query = Expression.select().where(
|
||||
(Expression.chat_id == related_chat_id) &
|
||||
(Expression.situation == situation) &
|
||||
(Expression.style == best_style)
|
||||
)
|
||||
|
||||
if expr_query.exists():
|
||||
expr = expr_query.get()
|
||||
predicted_expressions.append({
|
||||
"id": expr.id,
|
||||
"situation": expr.situation,
|
||||
"style": expr.style,
|
||||
"last_active_time": expr.last_active_time,
|
||||
"source_id": expr.chat_id,
|
||||
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
|
||||
"prediction_score": scores.get(best_style, 0.0),
|
||||
"prediction_input": filtered_target_message
|
||||
})
|
||||
else:
|
||||
logger.warning(f"为聊天室 {related_chat_id} 预测表达方式失败: {best_style} 没有找到对应的表达方式")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"为聊天室 {related_chat_id} 预测表达方式失败: {e}")
|
||||
continue
|
||||
|
||||
# 按预测分数排序,取前 total_num 个
|
||||
predicted_expressions.sort(key=lambda x: x.get("prediction_score", 0.0), reverse=True)
|
||||
selected_expressions = predicted_expressions[:total_num]
|
||||
|
||||
logger.info(f"为{chat_id} 预测到 {len(selected_expressions)} 个表达方式")
|
||||
return selected_expressions
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"模型预测表达方式失败: {e}")
|
||||
# 如果预测失败,回退到随机选择
|
||||
return self._random_expressions(chat_id, total_num)
|
||||
|
||||
def _random_expressions(self, chat_id: str, total_num: int) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
随机选择表达方式
|
||||
|
||||
Args:
|
||||
chat_id: 聊天室ID
|
||||
total_num: 需要选择的数量
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 随机选择的表达方式列表
|
||||
"""
|
||||
try:
|
||||
# 支持多chat_id合并抽选
|
||||
related_chat_ids = self.get_related_chat_ids(chat_id)
|
||||
|
||||
# 优化:一次性查询所有相关chat_id的表达方式
|
||||
style_query = Expression.select().where(
|
||||
(Expression.chat_id.in_(related_chat_ids))
|
||||
)
|
||||
|
||||
style_exprs = [
|
||||
{
|
||||
"id": expr.id,
|
||||
"situation": expr.situation,
|
||||
"style": expr.style,
|
||||
"last_active_time": expr.last_active_time,
|
||||
"source_id": expr.chat_id,
|
||||
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
|
||||
}
|
||||
for expr in style_query
|
||||
]
|
||||
|
||||
# 随机抽样
|
||||
if style_exprs:
|
||||
selected_style = weighted_sample(style_exprs, total_num)
|
||||
else:
|
||||
selected_style = []
|
||||
|
||||
logger.info(f"随机选择,为聊天室 {chat_id} 选择了 {len(selected_style)} 个表达方式")
|
||||
return selected_style
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"随机选择表达方式失败: {e}")
|
||||
return []
|
||||
|
||||
|
||||
async def select_suitable_expressions(
|
||||
self,
|
||||
chat_id: str,
|
||||
chat_info: str,
|
||||
max_num: int = 10,
|
||||
target_message: Optional[str] = None,
|
||||
) -> Tuple[List[Dict[str, Any]], List[int]]:
|
||||
"""
|
||||
根据配置模式选择适合的表达方式
|
||||
|
||||
Args:
|
||||
chat_id: 聊天流ID
|
||||
chat_info: 聊天内容信息
|
||||
max_num: 最大选择数量
|
||||
target_message: 目标消息内容
|
||||
|
||||
Returns:
|
||||
Tuple[List[Dict[str, Any]], List[int]]: 选中的表达方式列表和ID列表
|
||||
"""
|
||||
# 检查是否允许在此聊天流中使用表达
|
||||
if not self.can_use_expression_for_chat(chat_id):
|
||||
logger.debug(f"聊天流 {chat_id} 不允许使用表达,返回空列表")
|
||||
return [], []
|
||||
|
||||
# 获取配置模式
|
||||
expression_mode = global_config.expression.mode
|
||||
|
||||
if expression_mode == "exp_model":
|
||||
# exp_model模式:直接使用模型预测,不经过LLM
|
||||
logger.debug(f"使用exp_model模式为聊天流 {chat_id} 选择表达方式")
|
||||
return await self._select_expressions_model_only(chat_id, target_message, max_num)
|
||||
elif expression_mode == "classic":
|
||||
# classic模式:随机选择+LLM选择
|
||||
logger.debug(f"使用classic模式为聊天流 {chat_id} 选择表达方式")
|
||||
return await self._select_expressions_classic(chat_id, chat_info, max_num, target_message)
|
||||
else:
|
||||
logger.warning(f"未知的表达模式: {expression_mode},回退到classic模式")
|
||||
return await self._select_expressions_classic(chat_id, chat_info, max_num, target_message)
|
||||
|
||||
async def _select_expressions_model_only(
|
||||
self,
|
||||
chat_id: str,
|
||||
target_message: str,
|
||||
max_num: int = 10,
|
||||
) -> Tuple[List[Dict[str, Any]], List[int]]:
|
||||
"""
|
||||
exp_model模式:直接使用模型预测,不经过LLM
|
||||
|
||||
Args:
|
||||
chat_id: 聊天流ID
|
||||
target_message: 目标消息内容
|
||||
max_num: 最大选择数量
|
||||
|
||||
Returns:
|
||||
Tuple[List[Dict[str, Any]], List[int]]: 选中的表达方式列表和ID列表
|
||||
"""
|
||||
try:
|
||||
# 使用模型预测最合适的表达方式
|
||||
selected_expressions = self.get_model_predicted_expressions(chat_id, target_message, max_num)
|
||||
selected_ids = [expr["id"] for expr in selected_expressions]
|
||||
|
||||
# 更新last_active_time
|
||||
if selected_expressions:
|
||||
self.update_expressions_last_active_time(selected_expressions)
|
||||
|
||||
logger.info(f"exp_model模式为聊天流 {chat_id} 选择了 {len(selected_expressions)} 个表达方式")
|
||||
return selected_expressions, selected_ids
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"exp_model模式选择表达方式失败: {e}")
|
||||
return [], []
|
||||
|
||||
async def _select_expressions_classic(
|
||||
self,
|
||||
chat_id: str,
|
||||
chat_info: str,
|
||||
max_num: int = 10,
|
||||
target_message: Optional[str] = None,
|
||||
) -> Tuple[List[Dict[str, Any]], List[int]]:
|
||||
"""
|
||||
classic模式:随机选择+LLM选择
|
||||
|
||||
Args:
|
||||
chat_id: 聊天流ID
|
||||
chat_info: 聊天内容信息
|
||||
max_num: 最大选择数量
|
||||
target_message: 目标消息内容
|
||||
|
||||
Returns:
|
||||
Tuple[List[Dict[str, Any]], List[int]]: 选中的表达方式列表和ID列表
|
||||
"""
|
||||
try:
|
||||
# 1. 使用随机抽样选择表达方式
|
||||
style_exprs = self._random_expressions(chat_id, 20)
|
||||
|
||||
if len(style_exprs) < 10:
|
||||
logger.info(f"聊天流 {chat_id} 表达方式正在积累中")
|
||||
return [], []
|
||||
|
||||
# 2. 构建所有表达方式的索引和情境列表
|
||||
all_expressions: List[Dict[str, Any]] = []
|
||||
all_situations: List[str] = []
|
||||
|
||||
# 添加style表达方式
|
||||
for expr in style_exprs:
|
||||
expr = expr.copy()
|
||||
all_expressions.append(expr)
|
||||
all_situations.append(f"{len(all_expressions)}.当 {expr['situation']} 时,使用 {expr['style']}")
|
||||
|
||||
if not all_expressions:
|
||||
logger.warning("没有找到可用的表达方式")
|
||||
return [], []
|
||||
|
||||
all_situations_str = "\n".join(all_situations)
|
||||
|
||||
if target_message:
|
||||
target_message_str = f",现在你想要回复消息:{target_message}"
|
||||
target_message_extra_block = "4.考虑你要回复的目标消息"
|
||||
else:
|
||||
target_message_str = ""
|
||||
target_message_extra_block = ""
|
||||
|
||||
# 3. 构建prompt(只包含情境,不包含完整的表达方式)
|
||||
prompt = (await global_prompt_manager.get_prompt_async("expression_evaluation_prompt")).format(
|
||||
bot_name=global_config.bot.nickname,
|
||||
chat_observe_info=chat_info,
|
||||
all_situations=all_situations_str,
|
||||
max_num=max_num,
|
||||
target_message=target_message_str,
|
||||
target_message_extra_block=target_message_extra_block,
|
||||
)
|
||||
|
||||
# 4. 调用LLM
|
||||
content, (reasoning_content, model_name, _) = await self.llm_model.generate_response_async(prompt=prompt)
|
||||
|
||||
if not content:
|
||||
logger.warning("LLM返回空结果")
|
||||
return [], []
|
||||
|
||||
# 5. 解析结果
|
||||
result = repair_json(content)
|
||||
if isinstance(result, str):
|
||||
result = json.loads(result)
|
||||
|
||||
if not isinstance(result, dict) or "selected_situations" not in result:
|
||||
logger.error("LLM返回格式错误")
|
||||
logger.info(f"LLM返回结果: \n{content}")
|
||||
return [], []
|
||||
|
||||
selected_indices = result["selected_situations"]
|
||||
|
||||
# 根据索引获取完整的表达方式
|
||||
valid_expressions: List[Dict[str, Any]] = []
|
||||
selected_ids = []
|
||||
for idx in selected_indices:
|
||||
if isinstance(idx, int) and 1 <= idx <= len(all_expressions):
|
||||
expression = all_expressions[idx - 1] # 索引从1开始
|
||||
selected_ids.append(expression["id"])
|
||||
valid_expressions.append(expression)
|
||||
|
||||
# 对选中的所有表达方式,更新last_active_time
|
||||
if valid_expressions:
|
||||
self.update_expressions_last_active_time(valid_expressions)
|
||||
|
||||
logger.info(f"classic模式从{len(all_expressions)}个情境中选择了{len(valid_expressions)}个")
|
||||
return valid_expressions, selected_ids
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"classic模式处理表达方式选择时出错: {e}")
|
||||
return [], []
|
||||
|
||||
def update_expressions_last_active_time(self, expressions_to_update: List[Dict[str, Any]]):
|
||||
"""对一批表达方式更新last_active_time"""
|
||||
if not expressions_to_update:
|
||||
return
|
||||
updates_by_key = {}
|
||||
for expr in expressions_to_update:
|
||||
source_id: str = expr.get("source_id") # type: ignore
|
||||
situation: str = expr.get("situation") # type: ignore
|
||||
style: str = expr.get("style") # type: ignore
|
||||
if not source_id or not situation or not style:
|
||||
logger.warning(f"表达方式缺少必要字段,无法更新: {expr}")
|
||||
continue
|
||||
key = (source_id, situation, style)
|
||||
if key not in updates_by_key:
|
||||
updates_by_key[key] = expr
|
||||
for chat_id, situation, style in updates_by_key:
|
||||
query = Expression.select().where(
|
||||
(Expression.chat_id == chat_id)
|
||||
& (Expression.situation == situation)
|
||||
& (Expression.style == style)
|
||||
)
|
||||
if query.exists():
|
||||
expr_obj = query.get()
|
||||
expr_obj.last_active_time = time.time()
|
||||
expr_obj.save()
|
||||
logger.debug(
|
||||
"表达方式激活: 更新last_active_time in db"
|
||||
)
|
||||
|
||||
|
||||
init_prompt()
|
||||
|
||||
try:
|
||||
expression_selector = ExpressionSelector()
|
||||
except Exception as e:
|
||||
logger.error(f"ExpressionSelector初始化失败: {e}")
|
||||
141
src/express/expressor_model/model.py
Normal file
141
src/express/expressor_model/model.py
Normal file
@@ -0,0 +1,141 @@
|
||||
from typing import Dict, Optional, Tuple, List
|
||||
from collections import Counter, defaultdict
|
||||
import pickle
|
||||
import os
|
||||
|
||||
from .tokenizer import Tokenizer
|
||||
from .online_nb import OnlineNaiveBayes
|
||||
|
||||
class ExpressorModel:
|
||||
"""
|
||||
直接使用朴素贝叶斯精排(可在线学习)
|
||||
支持存储situation字段,不参与计算,仅与style对应
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
alpha: float = 0.5,
|
||||
beta: float = 0.5,
|
||||
gamma: float = 1.0,
|
||||
vocab_size: int = 200000,
|
||||
use_jieba: bool = True):
|
||||
self.tokenizer = Tokenizer(stopwords=set(), use_jieba=use_jieba)
|
||||
self.nb = OnlineNaiveBayes(alpha=alpha, beta=beta, gamma=gamma, vocab_size=vocab_size)
|
||||
self._candidates: Dict[str, str] = {} # cid -> text (style)
|
||||
self._situations: Dict[str, str] = {} # cid -> situation (不参与计算)
|
||||
|
||||
def add_candidate(self, cid: str, text: str, situation: str = None):
|
||||
"""添加候选文本和对应的situation"""
|
||||
self._candidates[cid] = text
|
||||
if situation is not None:
|
||||
self._situations[cid] = situation
|
||||
|
||||
# 确保在nb模型中初始化该候选的计数
|
||||
if cid not in self.nb.cls_counts:
|
||||
self.nb.cls_counts[cid] = 0.0
|
||||
if cid not in self.nb.token_counts:
|
||||
self.nb.token_counts[cid] = defaultdict(float)
|
||||
|
||||
def add_candidates_bulk(self, items: List[Tuple[str, str]], situations: List[str] = None):
|
||||
"""批量添加候选文本和对应的situations"""
|
||||
for i, (cid, text) in enumerate(items):
|
||||
situation = situations[i] if situations and i < len(situations) else None
|
||||
self.add_candidate(cid, text, situation)
|
||||
|
||||
def predict(self, text: str, k: int = None) -> Tuple[Optional[str], Dict[str, float]]:
|
||||
"""直接对所有候选进行朴素贝叶斯评分"""
|
||||
toks = self.tokenizer.tokenize(text)
|
||||
if not toks:
|
||||
return None, {}
|
||||
|
||||
if not self._candidates:
|
||||
return None, {}
|
||||
|
||||
# 对所有候选进行评分
|
||||
tf = Counter(toks)
|
||||
all_cids = list(self._candidates.keys())
|
||||
scores = self.nb.score_batch(tf, all_cids)
|
||||
|
||||
# 取最高分
|
||||
if not scores:
|
||||
return None, {}
|
||||
|
||||
# 根据k参数限制返回的候选数量
|
||||
if k is not None and k > 0:
|
||||
# 按分数降序排序,取前k个
|
||||
sorted_scores = sorted(scores.items(), key=lambda x: x[1], reverse=True)
|
||||
limited_scores = dict(sorted_scores[:k])
|
||||
best = sorted_scores[0][0] if sorted_scores else None
|
||||
return best, limited_scores
|
||||
else:
|
||||
# 如果没有指定k,返回所有分数
|
||||
best = max(scores.items(), key=lambda x: x[1])[0]
|
||||
return best, scores
|
||||
|
||||
def update_positive(self, text: str, cid: str):
|
||||
"""更新正反馈学习"""
|
||||
toks = self.tokenizer.tokenize(text)
|
||||
if not toks:
|
||||
return
|
||||
tf = Counter(toks)
|
||||
self.nb.update_positive(tf, cid)
|
||||
|
||||
def decay(self, factor: float):
|
||||
self.nb.decay(factor=factor)
|
||||
|
||||
def get_situation(self, cid: str) -> Optional[str]:
|
||||
"""获取候选对应的situation"""
|
||||
return self._situations.get(cid)
|
||||
|
||||
def get_style(self, cid: str) -> Optional[str]:
|
||||
"""获取候选对应的style"""
|
||||
return self._candidates.get(cid)
|
||||
|
||||
def get_candidate_info(self, cid: str) -> Tuple[Optional[str], Optional[str]]:
|
||||
"""获取候选的style和situation信息"""
|
||||
return self._candidates.get(cid), self._situations.get(cid)
|
||||
|
||||
def get_all_candidates(self) -> Dict[str, Tuple[str, Optional[str]]]:
|
||||
"""获取所有候选的style和situation信息"""
|
||||
return {cid: (style, self._situations.get(cid))
|
||||
for cid, style in self._candidates.items()}
|
||||
|
||||
def save(self, path: str):
|
||||
"""保存模型"""
|
||||
os.makedirs(os.path.dirname(path), exist_ok=True)
|
||||
with open(path, "wb") as f:
|
||||
pickle.dump({
|
||||
"candidates": self._candidates,
|
||||
"situations": self._situations,
|
||||
"nb": {
|
||||
"cls_counts": dict(self.nb.cls_counts),
|
||||
"token_counts": {cid: dict(tc) for cid, tc in self.nb.token_counts.items()},
|
||||
"alpha": self.nb.alpha,
|
||||
"beta": self.nb.beta,
|
||||
"gamma": self.nb.gamma,
|
||||
"V": self.nb.V,
|
||||
}
|
||||
}, f)
|
||||
|
||||
def load(self, path: str):
|
||||
"""加载模型"""
|
||||
with open(path, "rb") as f:
|
||||
obj = pickle.load(f)
|
||||
# 还原候选文本
|
||||
self._candidates = obj["candidates"]
|
||||
# 还原situations(兼容旧版本)
|
||||
self._situations = obj.get("situations", {})
|
||||
# 还原朴素贝叶斯模型
|
||||
self.nb.cls_counts = obj["nb"]["cls_counts"]
|
||||
self.nb.token_counts = defaultdict_dict(obj["nb"]["token_counts"])
|
||||
self.nb.alpha = obj["nb"]["alpha"]
|
||||
self.nb.beta = obj["nb"]["beta"]
|
||||
self.nb.gamma = obj["nb"]["gamma"]
|
||||
self.nb.V = obj["nb"]["V"]
|
||||
self.nb._logZ.clear()
|
||||
|
||||
def defaultdict_dict(d: Dict[str, Dict[str, float]]):
|
||||
from collections import defaultdict
|
||||
outer = defaultdict(lambda: defaultdict(float))
|
||||
for k, inner in d.items():
|
||||
outer[k].update(inner)
|
||||
return outer
|
||||
60
src/express/expressor_model/online_nb.py
Normal file
60
src/express/expressor_model/online_nb.py
Normal file
@@ -0,0 +1,60 @@
|
||||
import math
|
||||
from typing import Dict, List
|
||||
from collections import defaultdict, Counter
|
||||
|
||||
class OnlineNaiveBayes:
|
||||
def __init__(self, alpha: float = 0.5, beta: float = 0.5, gamma: float = 1.0, vocab_size: int = 200000):
|
||||
self.alpha = alpha
|
||||
self.beta = beta
|
||||
self.gamma = gamma
|
||||
self.V = vocab_size
|
||||
|
||||
self.cls_counts: Dict[str, float] = defaultdict(float) # cid -> total token count
|
||||
self.token_counts: Dict[str, Dict[str, float]] = defaultdict(lambda: defaultdict(float)) # cid -> term -> count
|
||||
self._logZ: Dict[str, float] = {} # cache log(∑counts + Vα)
|
||||
|
||||
def _invalidate(self, cid: str):
|
||||
if cid in self._logZ:
|
||||
del self._logZ[cid]
|
||||
|
||||
def _logZ_c(self, cid: str) -> float:
|
||||
if cid not in self._logZ:
|
||||
Z = self.cls_counts[cid] + self.V * self.alpha
|
||||
self._logZ[cid] = math.log(max(Z, 1e-12))
|
||||
return self._logZ[cid]
|
||||
|
||||
def score_batch(self, tf: Counter, cids: List[str]) -> Dict[str, float]:
|
||||
total_cls = sum(self.cls_counts.values())
|
||||
n_cls = max(1, len(self.cls_counts))
|
||||
denom_prior = math.log(total_cls + self.beta * n_cls)
|
||||
|
||||
out: Dict[str, float] = {}
|
||||
for cid in cids:
|
||||
prior = math.log(self.cls_counts[cid] + self.beta) - denom_prior
|
||||
s = prior
|
||||
logZ = self._logZ_c(cid)
|
||||
tc = self.token_counts[cid]
|
||||
for term, qtf in tf.items():
|
||||
num = tc.get(term, 0.0) + self.alpha
|
||||
s += qtf * (math.log(num) - logZ)
|
||||
out[cid] = s
|
||||
return out
|
||||
|
||||
def update_positive(self, tf: Counter, cid: str):
|
||||
inc = 0.0
|
||||
tc = self.token_counts[cid]
|
||||
for term, c in tf.items():
|
||||
tc[term] += float(c)
|
||||
inc += float(c)
|
||||
self.cls_counts[cid] += inc
|
||||
self._invalidate(cid)
|
||||
|
||||
def decay(self, factor: float = None):
|
||||
g = self.gamma if factor is None else factor
|
||||
if g >= 1.0:
|
||||
return
|
||||
for cid in list(self.cls_counts.keys()):
|
||||
self.cls_counts[cid] *= g
|
||||
for term in list(self.token_counts[cid].keys()):
|
||||
self.token_counts[cid][term] *= g
|
||||
self._invalidate(cid)
|
||||
31
src/express/expressor_model/tokenizer.py
Normal file
31
src/express/expressor_model/tokenizer.py
Normal file
@@ -0,0 +1,31 @@
|
||||
import re
|
||||
from typing import List, Optional, Set
|
||||
|
||||
try:
|
||||
import jieba
|
||||
_HAS_JIEBA = True
|
||||
except Exception:
|
||||
_HAS_JIEBA = False
|
||||
|
||||
_WORD_RE = re.compile(r"[A-Za-z0-9_]+")
|
||||
# 匹配纯符号的正则表达式
|
||||
_SYMBOL_RE = re.compile(r'^[^\w\u4e00-\u9fff]+$')
|
||||
|
||||
def simple_en_tokenize(text: str) -> List[str]:
|
||||
return _WORD_RE.findall(text.lower())
|
||||
|
||||
class Tokenizer:
|
||||
def __init__(self, stopwords: Optional[Set[str]] = None, use_jieba: bool = True):
|
||||
self.stopwords = stopwords or set()
|
||||
self.use_jieba = use_jieba and _HAS_JIEBA
|
||||
|
||||
def tokenize(self, text: str) -> List[str]:
|
||||
text = (text or "").strip()
|
||||
if not text:
|
||||
return []
|
||||
if self.use_jieba:
|
||||
toks = [t.strip().lower() for t in jieba.cut(text) if t.strip()]
|
||||
else:
|
||||
toks = simple_en_tokenize(text)
|
||||
# 过滤掉纯符号和停用词
|
||||
return [t for t in toks if t not in self.stopwords and not _SYMBOL_RE.match(t)]
|
||||
628
src/express/style_learner.py
Normal file
628
src/express/style_learner.py
Normal file
@@ -0,0 +1,628 @@
|
||||
"""
|
||||
多聊天室表达风格学习系统
|
||||
支持为每个chat_id维护独立的表达模型,学习从up_content到style的映射
|
||||
"""
|
||||
|
||||
import os
|
||||
import pickle
|
||||
import traceback
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from collections import defaultdict
|
||||
import asyncio
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from .expressor_model.model import ExpressorModel
|
||||
|
||||
logger = get_logger("style_learner")
|
||||
|
||||
|
||||
class StyleLearner:
|
||||
"""
|
||||
单个聊天室的表达风格学习器
|
||||
学习从up_content到style的映射关系
|
||||
支持动态管理风格集合(最多2000个)
|
||||
"""
|
||||
|
||||
def __init__(self, chat_id: str, model_config: Optional[Dict] = None):
|
||||
self.chat_id = chat_id
|
||||
self.model_config = model_config or {
|
||||
"alpha": 0.5,
|
||||
"beta": 0.5,
|
||||
"gamma": 0.99, # 衰减因子,支持遗忘
|
||||
"vocab_size": 200000,
|
||||
"use_jieba": True
|
||||
}
|
||||
|
||||
# 初始化表达模型
|
||||
self.expressor = ExpressorModel(**self.model_config)
|
||||
|
||||
# 动态风格管理
|
||||
self.max_styles = 2000 # 每个chat_id最多2000个风格
|
||||
self.style_to_id: Dict[str, str] = {} # style文本 -> style_id
|
||||
self.id_to_style: Dict[str, str] = {} # style_id -> style文本
|
||||
self.id_to_situation: Dict[str, str] = {} # style_id -> situation文本
|
||||
self.next_style_id = 0 # 下一个可用的style_id
|
||||
|
||||
# 学习统计
|
||||
self.learning_stats = {
|
||||
"total_samples": 0,
|
||||
"style_counts": defaultdict(int),
|
||||
"last_update": None,
|
||||
"style_usage_frequency": defaultdict(int) # 风格使用频率
|
||||
}
|
||||
|
||||
def add_style(self, style: str, situation: str = None) -> bool:
|
||||
"""
|
||||
动态添加一个新的风格
|
||||
|
||||
Args:
|
||||
style: 风格文本
|
||||
situation: 对应的situation文本(可选)
|
||||
|
||||
Returns:
|
||||
bool: 添加是否成功
|
||||
"""
|
||||
try:
|
||||
# 检查是否已存在
|
||||
if style in self.style_to_id:
|
||||
logger.debug(f"[{self.chat_id}] 风格 '{style}' 已存在")
|
||||
return True
|
||||
|
||||
# 检查是否超过最大限制
|
||||
if len(self.style_to_id) >= self.max_styles:
|
||||
logger.warning(f"[{self.chat_id}] 已达到最大风格数量限制 ({self.max_styles})")
|
||||
return False
|
||||
|
||||
# 生成新的style_id
|
||||
style_id = f"style_{self.next_style_id}"
|
||||
self.next_style_id += 1
|
||||
|
||||
# 添加到映射
|
||||
self.style_to_id[style] = style_id
|
||||
self.id_to_style[style_id] = style
|
||||
if situation:
|
||||
self.id_to_situation[style_id] = situation
|
||||
|
||||
# 添加到expressor模型
|
||||
self.expressor.add_candidate(style_id, style, situation)
|
||||
|
||||
logger.info(f"[{self.chat_id}] 已添加风格: '{style}' (ID: {style_id})" +
|
||||
(f", situation: '{situation}'" if situation else ""))
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[{self.chat_id}] 添加风格失败: {e}")
|
||||
return False
|
||||
|
||||
def remove_style(self, style: str) -> bool:
|
||||
"""
|
||||
删除一个风格
|
||||
|
||||
Args:
|
||||
style: 要删除的风格文本
|
||||
|
||||
Returns:
|
||||
bool: 删除是否成功
|
||||
"""
|
||||
try:
|
||||
if style not in self.style_to_id:
|
||||
logger.warning(f"[{self.chat_id}] 风格 '{style}' 不存在")
|
||||
return False
|
||||
|
||||
style_id = self.style_to_id[style]
|
||||
|
||||
# 从映射中删除
|
||||
del self.style_to_id[style]
|
||||
del self.id_to_style[style_id]
|
||||
if style_id in self.id_to_situation:
|
||||
del self.id_to_situation[style_id]
|
||||
|
||||
# 从expressor模型中删除(通过重新构建)
|
||||
self._rebuild_expressor()
|
||||
|
||||
logger.info(f"[{self.chat_id}] 已删除风格: '{style}' (ID: {style_id})")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[{self.chat_id}] 删除风格失败: {e}")
|
||||
return False
|
||||
|
||||
def update_style(self, old_style: str, new_style: str) -> bool:
|
||||
"""
|
||||
更新一个风格
|
||||
|
||||
Args:
|
||||
old_style: 原风格文本
|
||||
new_style: 新风格文本
|
||||
|
||||
Returns:
|
||||
bool: 更新是否成功
|
||||
"""
|
||||
try:
|
||||
if old_style not in self.style_to_id:
|
||||
logger.warning(f"[{self.chat_id}] 原风格 '{old_style}' 不存在")
|
||||
return False
|
||||
|
||||
if new_style in self.style_to_id and new_style != old_style:
|
||||
logger.warning(f"[{self.chat_id}] 新风格 '{new_style}' 已存在")
|
||||
return False
|
||||
|
||||
style_id = self.style_to_id[old_style]
|
||||
|
||||
# 更新映射
|
||||
del self.style_to_id[old_style]
|
||||
self.style_to_id[new_style] = style_id
|
||||
self.id_to_style[style_id] = new_style
|
||||
|
||||
# 更新expressor模型(保留原有的situation)
|
||||
situation = self.id_to_situation.get(style_id)
|
||||
self.expressor.add_candidate(style_id, new_style, situation)
|
||||
|
||||
logger.info(f"[{self.chat_id}] 已更新风格: '{old_style}' -> '{new_style}'")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[{self.chat_id}] 更新风格失败: {e}")
|
||||
return False
|
||||
|
||||
def add_styles_batch(self, styles: List[str], situations: List[str] = None) -> int:
|
||||
"""
|
||||
批量添加风格
|
||||
|
||||
Args:
|
||||
styles: 风格文本列表
|
||||
situations: 对应的situation文本列表(可选)
|
||||
|
||||
Returns:
|
||||
int: 成功添加的数量
|
||||
"""
|
||||
success_count = 0
|
||||
for i, style in enumerate(styles):
|
||||
situation = situations[i] if situations and i < len(situations) else None
|
||||
if self.add_style(style, situation):
|
||||
success_count += 1
|
||||
|
||||
logger.info(f"[{self.chat_id}] 批量添加风格: {success_count}/{len(styles)} 成功")
|
||||
return success_count
|
||||
|
||||
def get_all_styles(self) -> List[str]:
|
||||
"""获取所有已注册的风格"""
|
||||
return list(self.style_to_id.keys())
|
||||
|
||||
def get_style_count(self) -> int:
|
||||
"""获取当前风格数量"""
|
||||
return len(self.style_to_id)
|
||||
|
||||
def get_situation(self, style: str) -> Optional[str]:
|
||||
"""
|
||||
获取风格对应的situation
|
||||
|
||||
Args:
|
||||
style: 风格文本
|
||||
|
||||
Returns:
|
||||
Optional[str]: 对应的situation,如果不存在则返回None
|
||||
"""
|
||||
if style not in self.style_to_id:
|
||||
return None
|
||||
|
||||
style_id = self.style_to_id[style]
|
||||
return self.id_to_situation.get(style_id)
|
||||
|
||||
def get_style_info(self, style: str) -> Tuple[Optional[str], Optional[str]]:
|
||||
"""
|
||||
获取风格的完整信息
|
||||
|
||||
Args:
|
||||
style: 风格文本
|
||||
|
||||
Returns:
|
||||
Tuple[Optional[str], Optional[str]]: (style_id, situation)
|
||||
"""
|
||||
if style not in self.style_to_id:
|
||||
return None, None
|
||||
|
||||
style_id = self.style_to_id[style]
|
||||
situation = self.id_to_situation.get(style_id)
|
||||
return style_id, situation
|
||||
|
||||
def get_all_style_info(self) -> Dict[str, Tuple[str, Optional[str]]]:
|
||||
"""
|
||||
获取所有风格的完整信息
|
||||
|
||||
Returns:
|
||||
Dict[str, Tuple[str, Optional[str]]]: {style: (style_id, situation)}
|
||||
"""
|
||||
result = {}
|
||||
for style, style_id in self.style_to_id.items():
|
||||
situation = self.id_to_situation.get(style_id)
|
||||
result[style] = (style_id, situation)
|
||||
return result
|
||||
|
||||
def _rebuild_expressor(self):
|
||||
"""重新构建expressor模型(删除风格后使用)"""
|
||||
try:
|
||||
# 重新创建expressor
|
||||
self.expressor = ExpressorModel(**self.model_config)
|
||||
|
||||
# 重新添加所有风格和situation
|
||||
for style_id, style_text in self.id_to_style.items():
|
||||
situation = self.id_to_situation.get(style_id)
|
||||
self.expressor.add_candidate(style_id, style_text, situation)
|
||||
|
||||
logger.debug(f"[{self.chat_id}] 已重新构建expressor模型")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[{self.chat_id}] 重新构建expressor失败: {e}")
|
||||
|
||||
def learn_mapping(self, up_content: str, style: str) -> bool:
|
||||
"""
|
||||
学习一个up_content到style的映射
|
||||
如果style不存在,会自动添加
|
||||
|
||||
Args:
|
||||
up_content: 输入内容
|
||||
style: 对应的style文本
|
||||
|
||||
Returns:
|
||||
bool: 学习是否成功
|
||||
"""
|
||||
try:
|
||||
# 如果style不存在,先添加它
|
||||
if style not in self.style_to_id:
|
||||
if not self.add_style(style):
|
||||
logger.warning(f"[{self.chat_id}] 无法添加风格 '{style}',学习失败")
|
||||
return False
|
||||
|
||||
# 获取style_id
|
||||
style_id = self.style_to_id[style]
|
||||
|
||||
# 使用正反馈学习
|
||||
self.expressor.update_positive(up_content, style_id)
|
||||
|
||||
# 更新统计
|
||||
self.learning_stats["total_samples"] += 1
|
||||
self.learning_stats["style_counts"][style_id] += 1
|
||||
self.learning_stats["style_usage_frequency"][style] += 1
|
||||
self.learning_stats["last_update"] = asyncio.get_event_loop().time()
|
||||
|
||||
logger.debug(f"[{self.chat_id}] 学习映射: '{up_content}' -> '{style}'")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[{self.chat_id}] 学习映射失败: {e}")
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
def predict_style(self, up_content: str, top_k: int = 5) -> Tuple[Optional[str], Dict[str, float]]:
|
||||
"""
|
||||
根据up_content预测最合适的style
|
||||
|
||||
Args:
|
||||
up_content: 输入内容
|
||||
top_k: 返回前k个候选
|
||||
|
||||
Returns:
|
||||
Tuple[最佳style文本, 所有候选的分数]
|
||||
"""
|
||||
try:
|
||||
best_style_id, scores = self.expressor.predict(up_content, k=top_k)
|
||||
|
||||
if best_style_id is None:
|
||||
return None, {}
|
||||
|
||||
# 将style_id转换为style文本
|
||||
best_style = self.id_to_style.get(best_style_id)
|
||||
|
||||
# 转换所有分数
|
||||
style_scores = {}
|
||||
for sid, score in scores.items():
|
||||
style_text = self.id_to_style.get(sid)
|
||||
if style_text:
|
||||
style_scores[style_text] = score
|
||||
|
||||
return best_style, style_scores
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[{self.chat_id}] 预测style失败: {e}")
|
||||
traceback.print_exc()
|
||||
return None, {}
|
||||
|
||||
def decay_learning(self, factor: Optional[float] = None) -> None:
|
||||
"""
|
||||
对学习到的知识进行衰减(遗忘)
|
||||
|
||||
Args:
|
||||
factor: 衰减因子,None则使用配置中的gamma
|
||||
"""
|
||||
self.expressor.decay(factor)
|
||||
logger.debug(f"[{self.chat_id}] 执行知识衰减")
|
||||
|
||||
def get_stats(self) -> Dict:
|
||||
"""获取学习统计信息"""
|
||||
return {
|
||||
"chat_id": self.chat_id,
|
||||
"total_samples": self.learning_stats["total_samples"],
|
||||
"style_count": len(self.style_to_id),
|
||||
"max_styles": self.max_styles,
|
||||
"style_counts": dict(self.learning_stats["style_counts"]),
|
||||
"style_usage_frequency": dict(self.learning_stats["style_usage_frequency"]),
|
||||
"last_update": self.learning_stats["last_update"],
|
||||
"all_styles": list(self.style_to_id.keys())
|
||||
}
|
||||
|
||||
def save(self, base_path: str) -> bool:
|
||||
"""
|
||||
保存模型到文件
|
||||
|
||||
Args:
|
||||
base_path: 基础路径,实际文件为 {base_path}/{chat_id}_style_model.pkl
|
||||
"""
|
||||
try:
|
||||
os.makedirs(base_path, exist_ok=True)
|
||||
file_path = os.path.join(base_path, f"{self.chat_id}_style_model.pkl")
|
||||
|
||||
# 保存模型和统计信息
|
||||
save_data = {
|
||||
"model_config": self.model_config,
|
||||
"style_to_id": self.style_to_id,
|
||||
"id_to_style": self.id_to_style,
|
||||
"id_to_situation": self.id_to_situation,
|
||||
"next_style_id": self.next_style_id,
|
||||
"max_styles": self.max_styles,
|
||||
"learning_stats": self.learning_stats
|
||||
}
|
||||
|
||||
# 先保存expressor模型
|
||||
expressor_path = os.path.join(base_path, f"{self.chat_id}_expressor.pkl")
|
||||
self.expressor.save(expressor_path)
|
||||
|
||||
# 保存其他数据
|
||||
with open(file_path, "wb") as f:
|
||||
pickle.dump(save_data, f)
|
||||
|
||||
logger.info(f"[{self.chat_id}] 模型已保存到 {file_path}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[{self.chat_id}] 保存模型失败: {e}")
|
||||
return False
|
||||
|
||||
def load(self, base_path: str) -> bool:
|
||||
"""
|
||||
从文件加载模型
|
||||
|
||||
Args:
|
||||
base_path: 基础路径
|
||||
"""
|
||||
try:
|
||||
file_path = os.path.join(base_path, f"{self.chat_id}_style_model.pkl")
|
||||
expressor_path = os.path.join(base_path, f"{self.chat_id}_expressor.pkl")
|
||||
|
||||
if not os.path.exists(file_path) or not os.path.exists(expressor_path):
|
||||
logger.warning(f"[{self.chat_id}] 模型文件不存在,将使用默认配置")
|
||||
return False
|
||||
|
||||
# 加载其他数据
|
||||
with open(file_path, "rb") as f:
|
||||
save_data = pickle.load(f)
|
||||
|
||||
# 恢复配置和状态
|
||||
self.model_config = save_data["model_config"]
|
||||
self.style_to_id = save_data["style_to_id"]
|
||||
self.id_to_style = save_data["id_to_style"]
|
||||
self.id_to_situation = save_data.get("id_to_situation", {}) # 兼容旧版本
|
||||
self.next_style_id = save_data["next_style_id"]
|
||||
self.max_styles = save_data.get("max_styles", 2000)
|
||||
self.learning_stats = save_data["learning_stats"]
|
||||
|
||||
# 重新创建expressor并加载
|
||||
self.expressor = ExpressorModel(**self.model_config)
|
||||
self.expressor.load(expressor_path)
|
||||
|
||||
logger.info(f"[{self.chat_id}] 模型已从 {file_path} 加载")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[{self.chat_id}] 加载模型失败: {e}")
|
||||
return False
|
||||
|
||||
|
||||
class StyleLearnerManager:
|
||||
"""
|
||||
多聊天室表达风格学习管理器
|
||||
为每个chat_id维护独立的StyleLearner实例
|
||||
每个chat_id可以动态管理自己的风格集合(最多2000个)
|
||||
"""
|
||||
|
||||
def __init__(self, model_save_path: str = "data/style_models"):
|
||||
self.model_save_path = model_save_path
|
||||
self.learners: Dict[str, StyleLearner] = {}
|
||||
|
||||
# 自动保存配置
|
||||
self.auto_save_interval = 300 # 5分钟
|
||||
self._auto_save_task: Optional[asyncio.Task] = None
|
||||
|
||||
logger.info("StyleLearnerManager 已初始化")
|
||||
|
||||
def get_learner(self, chat_id: str, model_config: Optional[Dict] = None) -> StyleLearner:
|
||||
"""
|
||||
获取或创建指定chat_id的学习器
|
||||
|
||||
Args:
|
||||
chat_id: 聊天室ID
|
||||
model_config: 模型配置,None则使用默认配置
|
||||
|
||||
Returns:
|
||||
StyleLearner实例
|
||||
"""
|
||||
if chat_id not in self.learners:
|
||||
# 创建新的学习器
|
||||
learner = StyleLearner(chat_id, model_config)
|
||||
|
||||
# 尝试加载已保存的模型
|
||||
learner.load(self.model_save_path)
|
||||
|
||||
self.learners[chat_id] = learner
|
||||
logger.info(f"为 chat_id={chat_id} 创建新的StyleLearner")
|
||||
|
||||
return self.learners[chat_id]
|
||||
|
||||
def add_style(self, chat_id: str, style: str) -> bool:
|
||||
"""
|
||||
为指定chat_id添加风格
|
||||
|
||||
Args:
|
||||
chat_id: 聊天室ID
|
||||
style: 风格文本
|
||||
|
||||
Returns:
|
||||
bool: 添加是否成功
|
||||
"""
|
||||
learner = self.get_learner(chat_id)
|
||||
return learner.add_style(style)
|
||||
|
||||
def remove_style(self, chat_id: str, style: str) -> bool:
|
||||
"""
|
||||
为指定chat_id删除风格
|
||||
|
||||
Args:
|
||||
chat_id: 聊天室ID
|
||||
style: 风格文本
|
||||
|
||||
Returns:
|
||||
bool: 删除是否成功
|
||||
"""
|
||||
learner = self.get_learner(chat_id)
|
||||
return learner.remove_style(style)
|
||||
|
||||
def update_style(self, chat_id: str, old_style: str, new_style: str) -> bool:
|
||||
"""
|
||||
为指定chat_id更新风格
|
||||
|
||||
Args:
|
||||
chat_id: 聊天室ID
|
||||
old_style: 原风格文本
|
||||
new_style: 新风格文本
|
||||
|
||||
Returns:
|
||||
bool: 更新是否成功
|
||||
"""
|
||||
learner = self.get_learner(chat_id)
|
||||
return learner.update_style(old_style, new_style)
|
||||
|
||||
def get_chat_styles(self, chat_id: str) -> List[str]:
|
||||
"""
|
||||
获取指定chat_id的所有风格
|
||||
|
||||
Args:
|
||||
chat_id: 聊天室ID
|
||||
|
||||
Returns:
|
||||
List[str]: 风格列表
|
||||
"""
|
||||
learner = self.get_learner(chat_id)
|
||||
return learner.get_all_styles()
|
||||
|
||||
def learn_mapping(self, chat_id: str, up_content: str, style: str) -> bool:
|
||||
"""
|
||||
学习一个映射关系
|
||||
|
||||
Args:
|
||||
chat_id: 聊天室ID
|
||||
up_content: 输入内容
|
||||
style: 对应的style
|
||||
|
||||
Returns:
|
||||
bool: 学习是否成功
|
||||
"""
|
||||
learner = self.get_learner(chat_id)
|
||||
return learner.learn_mapping(up_content, style)
|
||||
|
||||
def predict_style(self, chat_id: str, up_content: str, top_k: int = 5) -> Tuple[Optional[str], Dict[str, float]]:
|
||||
"""
|
||||
预测最合适的style
|
||||
|
||||
Args:
|
||||
chat_id: 聊天室ID
|
||||
up_content: 输入内容
|
||||
top_k: 返回前k个候选
|
||||
|
||||
Returns:
|
||||
Tuple[最佳style, 所有候选分数]
|
||||
"""
|
||||
learner = self.get_learner(chat_id)
|
||||
return learner.predict_style(up_content, top_k)
|
||||
|
||||
def decay_all_learners(self, factor: Optional[float] = None) -> None:
|
||||
"""
|
||||
对所有学习器执行衰减
|
||||
|
||||
Args:
|
||||
factor: 衰减因子
|
||||
"""
|
||||
for learner in self.learners.values():
|
||||
learner.decay_learning(factor)
|
||||
logger.info("已对所有学习器执行衰减")
|
||||
|
||||
def get_all_stats(self) -> Dict[str, Dict]:
|
||||
"""获取所有学习器的统计信息"""
|
||||
return {chat_id: learner.get_stats() for chat_id, learner in self.learners.items()}
|
||||
|
||||
def save_all_models(self) -> bool:
|
||||
"""保存所有模型"""
|
||||
success_count = 0
|
||||
for learner in self.learners.values():
|
||||
if learner.save(self.model_save_path):
|
||||
success_count += 1
|
||||
|
||||
logger.info(f"已保存 {success_count}/{len(self.learners)} 个模型")
|
||||
return success_count == len(self.learners)
|
||||
|
||||
def load_all_models(self) -> int:
|
||||
"""加载所有已保存的模型"""
|
||||
if not os.path.exists(self.model_save_path):
|
||||
return 0
|
||||
|
||||
loaded_count = 0
|
||||
for filename in os.listdir(self.model_save_path):
|
||||
if filename.endswith("_style_model.pkl"):
|
||||
chat_id = filename.replace("_style_model.pkl", "")
|
||||
learner = StyleLearner(chat_id)
|
||||
if learner.load(self.model_save_path):
|
||||
self.learners[chat_id] = learner
|
||||
loaded_count += 1
|
||||
|
||||
logger.info(f"已加载 {loaded_count} 个模型")
|
||||
return loaded_count
|
||||
|
||||
async def start_auto_save(self) -> None:
|
||||
"""启动自动保存任务"""
|
||||
if self._auto_save_task is None or self._auto_save_task.done():
|
||||
self._auto_save_task = asyncio.create_task(self._auto_save_loop())
|
||||
logger.info("已启动自动保存任务")
|
||||
|
||||
async def stop_auto_save(self) -> None:
|
||||
"""停止自动保存任务"""
|
||||
if self._auto_save_task and not self._auto_save_task.done():
|
||||
self._auto_save_task.cancel()
|
||||
try:
|
||||
await self._auto_save_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
logger.info("已停止自动保存任务")
|
||||
|
||||
async def _auto_save_loop(self) -> None:
|
||||
"""自动保存循环"""
|
||||
while True:
|
||||
try:
|
||||
await asyncio.sleep(self.auto_save_interval)
|
||||
self.save_all_models()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"自动保存失败: {e}")
|
||||
|
||||
|
||||
# 全局管理器实例
|
||||
style_learner_manager = StyleLearnerManager()
|
||||
@@ -13,6 +13,7 @@ from google.genai.types import (
|
||||
ContentUnion,
|
||||
ThinkingConfig,
|
||||
Tool,
|
||||
GoogleSearch,
|
||||
GenerateContentConfig,
|
||||
EmbedContentResponse,
|
||||
EmbedContentConfig,
|
||||
@@ -176,19 +177,21 @@ def _process_delta(
|
||||
delta: GenerateContentResponse,
|
||||
fc_delta_buffer: io.StringIO,
|
||||
tool_calls_buffer: list[tuple[str, str, dict[str, Any]]],
|
||||
resp: APIResponse | None = None,
|
||||
):
|
||||
if not hasattr(delta, "candidates") or not delta.candidates:
|
||||
raise RespParseException(delta, "响应解析失败,缺失candidates字段")
|
||||
|
||||
if delta.text:
|
||||
fc_delta_buffer.write(delta.text)
|
||||
|
||||
# 处理 thought(Gemini 的特殊字段)
|
||||
for c in getattr(delta, "candidates", []):
|
||||
if c.content and getattr(c.content, "parts", None):
|
||||
for p in c.content.parts:
|
||||
if getattr(p, "thought", False) and getattr(p, "text", None):
|
||||
# 把 thought 写入 buffer,避免 resp.content 永远为空
|
||||
# 保存到 reasoning_content
|
||||
if resp is not None:
|
||||
resp.reasoning_content = (resp.reasoning_content or "") + p.text
|
||||
elif getattr(p, "text", None):
|
||||
# 正常输出写入 buffer
|
||||
fc_delta_buffer.write(p.text)
|
||||
|
||||
if delta.function_calls: # 为什么不用hasattr呢,是因为这个属性一定有,即使是个空的
|
||||
@@ -213,9 +216,11 @@ def _build_stream_api_resp(
|
||||
_fc_delta_buffer: io.StringIO,
|
||||
_tool_calls_buffer: list[tuple[str, str, dict]],
|
||||
last_resp: GenerateContentResponse | None = None, # 传入 last_resp
|
||||
resp: APIResponse | None = None,
|
||||
) -> APIResponse:
|
||||
# sourcery skip: simplify-len-comparison, use-assigned-variable
|
||||
resp = APIResponse()
|
||||
if resp is None:
|
||||
resp = APIResponse()
|
||||
|
||||
if _fc_delta_buffer.tell() > 0:
|
||||
# 如果正式内容缓冲区不为空,则将其写入APIResponse对象
|
||||
@@ -240,11 +245,15 @@ def _build_stream_api_resp(
|
||||
# 检查是否因为 max_tokens 截断
|
||||
reason = None
|
||||
if last_resp and getattr(last_resp, "candidates", None):
|
||||
c0 = last_resp.candidates[0]
|
||||
reason = getattr(c0, "finish_reason", None) or getattr(c0, "finishReason", None)
|
||||
|
||||
for c in last_resp.candidates:
|
||||
fr = getattr(c, "finish_reason", None) or getattr(c, "finishReason", None)
|
||||
if fr:
|
||||
reason = str(fr)
|
||||
break
|
||||
|
||||
if str(reason).endswith("MAX_TOKENS"):
|
||||
if resp.content and resp.content.strip():
|
||||
has_visible_output = bool(resp.content and resp.content.strip())
|
||||
if has_visible_output:
|
||||
logger.warning(
|
||||
"⚠ Gemini 响应因达到 max_tokens 限制被部分截断,\n"
|
||||
" 可能会对回复内容造成影响,建议修改模型 max_tokens 配置!"
|
||||
@@ -253,7 +262,8 @@ def _build_stream_api_resp(
|
||||
logger.warning("⚠ Gemini 响应因达到 max_tokens 限制被截断,\n 请修改模型 max_tokens 配置!")
|
||||
|
||||
if not resp.content and not resp.tool_calls:
|
||||
raise EmptyResponseException()
|
||||
if not getattr(resp, "reasoning_content", None):
|
||||
raise EmptyResponseException()
|
||||
|
||||
return resp
|
||||
|
||||
@@ -271,7 +281,8 @@ async def _default_stream_response_handler(
|
||||
_tool_calls_buffer: list[tuple[str, str, dict]] = [] # 工具调用缓冲区,用于存储接收到的工具调用
|
||||
_usage_record = None # 使用情况记录
|
||||
last_resp: GenerateContentResponse | None = None # 保存最后一个 chunk
|
||||
|
||||
resp = APIResponse()
|
||||
|
||||
def _insure_buffer_closed():
|
||||
if _fc_delta_buffer and not _fc_delta_buffer.closed:
|
||||
_fc_delta_buffer.close()
|
||||
@@ -287,6 +298,7 @@ async def _default_stream_response_handler(
|
||||
chunk,
|
||||
_fc_delta_buffer,
|
||||
_tool_calls_buffer,
|
||||
resp=resp,
|
||||
)
|
||||
|
||||
if chunk.usage_metadata:
|
||||
@@ -302,6 +314,7 @@ async def _default_stream_response_handler(
|
||||
_fc_delta_buffer,
|
||||
_tool_calls_buffer,
|
||||
last_resp=last_resp,
|
||||
resp=resp,
|
||||
), _usage_record
|
||||
except Exception:
|
||||
# 确保缓冲区被关闭
|
||||
@@ -526,6 +539,15 @@ class GeminiClient(BaseClient):
|
||||
tools = _convert_tool_options(tool_options) if tool_options else None
|
||||
# 解析并裁剪 thinking_budget
|
||||
tb = self.clamp_thinking_budget(extra_params, model_info.model_identifier)
|
||||
# 检测是否为带 -search 的模型
|
||||
enable_google_search = False
|
||||
model_identifier = model_info.model_identifier
|
||||
if model_identifier.endswith("-search"):
|
||||
enable_google_search = True
|
||||
# 去掉后缀并更新模型ID
|
||||
model_identifier = model_identifier.removesuffix("-search")
|
||||
model_info.model_identifier = model_identifier
|
||||
logger.info(f"模型已启用 GoogleSearch 功能:{model_identifier}")
|
||||
|
||||
# 将response_format转换为Gemini API所需的格式
|
||||
generation_config_dict = {
|
||||
@@ -548,6 +570,17 @@ class GeminiClient(BaseClient):
|
||||
elif response_format and response_format.format_type in (RespFormatType.JSON_OBJ, RespFormatType.JSON_SCHEMA):
|
||||
generation_config_dict["response_mime_type"] = "application/json"
|
||||
generation_config_dict["response_schema"] = response_format.to_dict()
|
||||
# 自动启用 GoogleSearch grounding_tool
|
||||
if enable_google_search:
|
||||
grounding_tool = Tool(google_search=GoogleSearch())
|
||||
if "tools" in generation_config_dict:
|
||||
existing = generation_config_dict["tools"]
|
||||
if isinstance(existing, list):
|
||||
existing.append(grounding_tool)
|
||||
else:
|
||||
generation_config_dict["tools"] = [existing, grounding_tool]
|
||||
else:
|
||||
generation_config_dict["tools"] = [grounding_tool]
|
||||
|
||||
generation_config = GenerateContentConfig(**generation_config_dict)
|
||||
|
||||
|
||||
@@ -199,6 +199,7 @@ def _build_stream_api_resp(
|
||||
_fc_delta_buffer: io.StringIO,
|
||||
_rc_delta_buffer: io.StringIO,
|
||||
_tool_calls_buffer: list[tuple[str, str, io.StringIO]],
|
||||
finish_reason: str | None = None,
|
||||
) -> APIResponse:
|
||||
resp = APIResponse()
|
||||
|
||||
@@ -236,6 +237,9 @@ def _build_stream_api_resp(
|
||||
|
||||
resp.tool_calls.append(ToolCall(call_id, function_name, arguments))
|
||||
|
||||
# 检查 max_tokens 截断(流式的告警改由处理函数统一输出,这里不再输出)
|
||||
# 保留 finish_reason 仅用于上层判断
|
||||
|
||||
if not resp.content and not resp.tool_calls:
|
||||
raise EmptyResponseException()
|
||||
|
||||
@@ -258,6 +262,8 @@ async def _default_stream_response_handler(
|
||||
_fc_delta_buffer = io.StringIO() # 正式内容缓冲区,用于存储接收到的正式内容
|
||||
_tool_calls_buffer: list[tuple[str, str, io.StringIO]] = [] # 工具调用缓冲区,用于存储接收到的工具调用
|
||||
_usage_record = None # 使用情况记录
|
||||
finish_reason: str | None = None # 记录最后的 finish_reason
|
||||
_model_name: str | None = None # 记录模型名
|
||||
|
||||
def _insure_buffer_closed():
|
||||
# 确保缓冲区被关闭
|
||||
@@ -285,6 +291,12 @@ async def _default_stream_response_handler(
|
||||
continue # 跳过本帧,避免访问 choices[0]
|
||||
delta = event.choices[0].delta # 获取当前块的delta内容
|
||||
|
||||
if hasattr(event.choices[0], "finish_reason") and event.choices[0].finish_reason:
|
||||
finish_reason = event.choices[0].finish_reason
|
||||
|
||||
if hasattr(event, "model") and event.model and not _model_name:
|
||||
_model_name = event.model # 记录模型名
|
||||
|
||||
if hasattr(delta, "reasoning_content") and delta.reasoning_content: # type: ignore
|
||||
# 标记:有独立的推理内容块
|
||||
_has_rc_attr_flag = True
|
||||
@@ -307,11 +319,34 @@ async def _default_stream_response_handler(
|
||||
)
|
||||
|
||||
try:
|
||||
return _build_stream_api_resp(
|
||||
resp = _build_stream_api_resp(
|
||||
_fc_delta_buffer,
|
||||
_rc_delta_buffer,
|
||||
_tool_calls_buffer,
|
||||
), _usage_record
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
# 统一在这里输出 max_tokens 截断的警告,并从 resp 中读取
|
||||
if finish_reason == "length":
|
||||
# 把模型名塞到 resp.raw_data,后续严格“从 resp 提取”
|
||||
try:
|
||||
if _model_name:
|
||||
resp.raw_data = {"model": _model_name}
|
||||
except Exception:
|
||||
pass
|
||||
model_dbg = None
|
||||
try:
|
||||
if isinstance(resp.raw_data, dict):
|
||||
model_dbg = resp.raw_data.get("model")
|
||||
except Exception:
|
||||
model_dbg = None
|
||||
|
||||
# 统一日志格式
|
||||
logger.info(
|
||||
"模型%s因为超过最大max_token限制,可能仅输出部分内容,可视情况调整"
|
||||
% (model_dbg or "")
|
||||
)
|
||||
|
||||
return resp, _usage_record
|
||||
except Exception:
|
||||
# 确保缓冲区被关闭
|
||||
_insure_buffer_closed()
|
||||
@@ -335,9 +370,32 @@ def _default_normal_response_parser(
|
||||
"""
|
||||
api_response = APIResponse()
|
||||
|
||||
if not hasattr(resp, "choices") or len(resp.choices) == 0:
|
||||
raise EmptyResponseException("响应解析失败,缺失choices字段或choices列表为空")
|
||||
message_part = resp.choices[0].message
|
||||
# 兼容部分 OpenAI 兼容服务在空回复时返回 choices=None 的情况
|
||||
choices = getattr(resp, "choices", None)
|
||||
if not choices:
|
||||
try:
|
||||
model_dbg = getattr(resp, "model", None)
|
||||
id_dbg = getattr(resp, "id", None)
|
||||
usage_dbg = None
|
||||
if hasattr(resp, "usage") and resp.usage:
|
||||
usage_dbg = {
|
||||
"prompt": getattr(resp.usage, "prompt_tokens", None),
|
||||
"completion": getattr(resp.usage, "completion_tokens", None),
|
||||
"total": getattr(resp.usage, "total_tokens", None),
|
||||
}
|
||||
try:
|
||||
raw_snippet = str(resp)[:300]
|
||||
except Exception:
|
||||
raw_snippet = "<unserializable>"
|
||||
logger.debug(
|
||||
f"empty choices: model={model_dbg} id={id_dbg} usage={usage_dbg} raw≈{raw_snippet}"
|
||||
)
|
||||
except Exception:
|
||||
# 日志采集失败不应影响控制流
|
||||
pass
|
||||
# 统一抛出可重试的 EmptyResponseException,触发上层重试逻辑
|
||||
raise EmptyResponseException("响应解析失败,choices 为空或缺失")
|
||||
message_part = choices[0].message
|
||||
|
||||
if hasattr(message_part, "reasoning_content") and message_part.reasoning_content: # type: ignore
|
||||
# 有有效的推理字段
|
||||
@@ -381,6 +439,22 @@ def _default_normal_response_parser(
|
||||
# 将原始响应存储在原始数据中
|
||||
api_response.raw_data = resp
|
||||
|
||||
# 检查 max_tokens 截断
|
||||
try:
|
||||
choice0 = resp.choices[0]
|
||||
reason = getattr(choice0, "finish_reason", None)
|
||||
if reason and reason == "length":
|
||||
print(resp)
|
||||
_model_name = resp.model
|
||||
# 统一日志格式
|
||||
logger.info(
|
||||
"模型%s因为超过最大max_token限制,可能仅输出部分内容,可视情况调整"
|
||||
% (_model_name or "")
|
||||
)
|
||||
return api_response, _usage_record
|
||||
except Exception as e:
|
||||
logger.debug(f"检查 MAX_TOKENS 截断时异常: {e}")
|
||||
|
||||
if not api_response.content and not api_response.tool_calls:
|
||||
raise EmptyResponseException()
|
||||
|
||||
@@ -489,7 +563,7 @@ class OpenaiClient(BaseClient):
|
||||
await asyncio.sleep(0.1) # 等待0.5秒后再次检查任务&中断信号量状态
|
||||
|
||||
# logger.
|
||||
logger.debug(f"OpenAI API响应(非流式): {req_task.result()}")
|
||||
# logger.debug(f"OpenAI API响应(非流式): {req_task.result()}")
|
||||
|
||||
# logger.info(f"OpenAI请求时间: {model_info.model_identifier} {time.time() - start_time} \n{messages}")
|
||||
|
||||
|
||||
@@ -29,12 +29,19 @@ def compress_messages(messages: list[Message], img_target_size: int = 1 * 1024 *
|
||||
:return: 转换后的图片数据
|
||||
"""
|
||||
try:
|
||||
image = Image.open(image_data)
|
||||
image = Image.open(io.BytesIO(image_data))
|
||||
|
||||
if image.format and (image.format.upper() in ["JPEG", "JPG", "PNG", "WEBP"]):
|
||||
# 静态图像,转换为JPEG格式
|
||||
# 仅在非动图时进行格式转换
|
||||
if (
|
||||
not getattr(image, "is_animated", False)
|
||||
and image.format
|
||||
and (image.format.upper() in ["JPEG", "JPG", "PNG", "WEBP"])
|
||||
):
|
||||
reformated_image_data = io.BytesIO()
|
||||
image.save(reformated_image_data, format="JPEG", quality=95, optimize=True)
|
||||
img_to_save = image
|
||||
if img_to_save.mode in ("RGBA", "LA", "P"):
|
||||
img_to_save = img_to_save.convert("RGB")
|
||||
img_to_save.save(reformated_image_data, format="JPEG", quality=95, optimize=True)
|
||||
image_data = reformated_image_data.getvalue()
|
||||
|
||||
return image_data
|
||||
@@ -50,20 +57,22 @@ def compress_messages(messages: list[Message], img_target_size: int = 1 * 1024 *
|
||||
:return: 缩放后的图片数据
|
||||
"""
|
||||
try:
|
||||
image = Image.open(image_data)
|
||||
image = Image.open(io.BytesIO(image_data))
|
||||
|
||||
# 原始尺寸
|
||||
original_size = (image.width, image.height)
|
||||
|
||||
# 计算新的尺寸
|
||||
new_size = (int(original_size[0] * scale), int(original_size[1] * scale))
|
||||
# 计算新的尺寸,防止为0
|
||||
new_w = max(1, int(original_size[0] * scale))
|
||||
new_h = max(1, int(original_size[1] * scale))
|
||||
new_size = (new_w, new_h)
|
||||
|
||||
output_buffer = io.BytesIO()
|
||||
|
||||
if getattr(image, "is_animated", False):
|
||||
# 动态图片,处理所有帧
|
||||
frames = []
|
||||
new_size = (new_size[0] // 2, new_size[1] // 2) # 动图,缩放尺寸再打折
|
||||
new_size = (max(1, new_size[0] // 2), max(1, new_size[1] // 2)) # 动图,缩放尺寸再打折
|
||||
for frame_idx in range(getattr(image, "n_frames", 1)):
|
||||
image.seek(frame_idx)
|
||||
new_frame = image.copy()
|
||||
@@ -83,6 +92,8 @@ def compress_messages(messages: list[Message], img_target_size: int = 1 * 1024 *
|
||||
else:
|
||||
# 静态图片,直接缩放保存
|
||||
resized_image = image.resize(new_size, Image.Resampling.LANCZOS)
|
||||
if resized_image.mode in ("RGBA", "LA", "P"):
|
||||
resized_image = resized_image.convert("RGB")
|
||||
resized_image.save(output_buffer, format="JPEG", quality=95, optimize=True)
|
||||
|
||||
return output_buffer.getvalue(), original_size, new_size
|
||||
|
||||
@@ -270,13 +270,28 @@ class LLMRequest:
|
||||
audio_base64=audio_base64,
|
||||
extra_params=model_info.extra_params,
|
||||
)
|
||||
except (EmptyResponseException, NetworkConnectionError) as e:
|
||||
except EmptyResponseException as e:
|
||||
# 空回复:通常为临时问题,单独记录并重试
|
||||
retry_remain -= 1
|
||||
if retry_remain <= 0:
|
||||
logger.error(f"模型 '{model_info.name}' 在用尽对临时错误的重试次数后仍然失败。")
|
||||
logger.error(f"模型 '{model_info.name}' 在多次出现空回复后仍然失败。")
|
||||
raise ModelAttemptFailed(f"模型 '{model_info.name}' 重试耗尽", original_exception=e) from e
|
||||
|
||||
logger.warning(f"模型 '{model_info.name}' 遇到可重试错误: {str(e)}。剩余重试次数: {retry_remain}")
|
||||
logger.warning(
|
||||
f"模型 '{model_info.name}' 返回空回复(可重试)。剩余重试次数: {retry_remain}"
|
||||
)
|
||||
await asyncio.sleep(api_provider.retry_interval)
|
||||
|
||||
except NetworkConnectionError as e:
|
||||
# 网络错误:单独记录并重试
|
||||
retry_remain -= 1
|
||||
if retry_remain <= 0:
|
||||
logger.error(f"模型 '{model_info.name}' 在网络错误重试用尽后仍然失败。")
|
||||
raise ModelAttemptFailed(f"模型 '{model_info.name}' 重试耗尽", original_exception=e) from e
|
||||
|
||||
logger.warning(
|
||||
f"模型 '{model_info.name}' 遇到网络错误(可重试): {str(e)}。剩余重试次数: {retry_remain}"
|
||||
)
|
||||
await asyncio.sleep(api_provider.retry_interval)
|
||||
|
||||
except RespNotOkException as e:
|
||||
@@ -369,8 +384,8 @@ class LLMRequest:
|
||||
failed_models_this_request.add(model_info.name)
|
||||
|
||||
if isinstance(last_exception, RespNotOkException) and last_exception.status_code == 400:
|
||||
logger.error("收到不可恢复的客户端错误 (400),中止所有尝试。")
|
||||
raise last_exception from e
|
||||
logger.warning("收到客户端错误 (400),跳过当前模型并继续尝试其他模型。")
|
||||
continue
|
||||
|
||||
logger.error(f"所有 {max_attempts} 个模型均尝试失败。")
|
||||
if last_exception:
|
||||
|
||||
@@ -18,7 +18,7 @@ from .memory_utils import (
|
||||
find_best_matching_memory,
|
||||
check_title_exists_fuzzy,
|
||||
get_all_titles,
|
||||
get_memory_titles_by_chat_id_weighted,
|
||||
find_most_similar_memory_by_chat_id,
|
||||
|
||||
)
|
||||
|
||||
@@ -41,6 +41,80 @@ class MemoryChest:
|
||||
self.running_content_list = {} # {chat_id: {"content": running_content, "last_update_time": timestamp, "create_time": timestamp}}
|
||||
self.fetched_memory_list = [] # [(chat_id, (question, answer, timestamp)), ...]
|
||||
|
||||
def remove_one_memory_by_age_weight(self) -> bool:
|
||||
"""
|
||||
删除一条记忆:按“越老/越新更易被删”的权重随机选择(老=较小id,新=较大id)。
|
||||
|
||||
返回:是否删除成功
|
||||
"""
|
||||
try:
|
||||
memories = list(MemoryChestModel.select())
|
||||
if not memories:
|
||||
return False
|
||||
|
||||
# 排除锁定项
|
||||
candidates = [m for m in memories if not getattr(m, "locked", False)]
|
||||
if not candidates:
|
||||
return False
|
||||
|
||||
# 按 id 排序,使用 id 近似时间顺序(小 -> 老,大 -> 新)
|
||||
candidates.sort(key=lambda m: m.id)
|
||||
n = len(candidates)
|
||||
if n == 1:
|
||||
MemoryChestModel.delete().where(MemoryChestModel.id == candidates[0].id).execute()
|
||||
logger.info(f"[记忆管理] 已删除一条记忆(权重抽样):{candidates[0].title}")
|
||||
return True
|
||||
|
||||
# 计算U型权重:中间最低,两端最高
|
||||
# r ∈ [0,1] 为位置归一化,w = 0.1 + 0.9 * (abs(r-0.5)*2)**1.5
|
||||
weights = []
|
||||
for idx, _m in enumerate(candidates):
|
||||
r = idx / (n - 1)
|
||||
w = 0.1 + 0.9 * (abs(r - 0.5) * 2) ** 1.5
|
||||
weights.append(w)
|
||||
|
||||
import random as _random
|
||||
selected = _random.choices(candidates, weights=weights, k=1)[0]
|
||||
|
||||
MemoryChestModel.delete().where(MemoryChestModel.id == selected.id).execute()
|
||||
logger.info(f"[记忆管理] 已删除一条记忆(权重抽样):{selected.title}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"[记忆管理] 按年龄权重删除记忆时出错: {e}")
|
||||
return False
|
||||
|
||||
def _compute_merge_similarity_threshold(self) -> float:
|
||||
"""
|
||||
根据当前记忆数量占比动态计算合并相似度阈值。
|
||||
|
||||
规则:占比越高,阈值越低。
|
||||
- < 60%: 0.80(更严格,避免早期误合并)
|
||||
- < 80%: 0.70
|
||||
- < 100%: 0.60
|
||||
- < 120%: 0.50
|
||||
- >= 120%: 0.45(最宽松,加速收敛)
|
||||
"""
|
||||
try:
|
||||
current_count = MemoryChestModel.select().count()
|
||||
max_count = max(1, int(global_config.memory.max_memory_number))
|
||||
percentage = current_count / max_count
|
||||
|
||||
if percentage < 0.6:
|
||||
return 0.70
|
||||
elif percentage < 0.8:
|
||||
return 0.60
|
||||
elif percentage < 1.0:
|
||||
return 0.50
|
||||
elif percentage < 1.5:
|
||||
return 0.40
|
||||
elif percentage < 2:
|
||||
return 0.30
|
||||
else:
|
||||
return 0.25
|
||||
except Exception:
|
||||
# 发生异常时使用保守阈值
|
||||
return 0.70
|
||||
|
||||
async def build_running_content(self, chat_id: str = None) -> str:
|
||||
"""
|
||||
构建记忆仓库的运行内容
|
||||
@@ -430,75 +504,43 @@ class MemoryChest:
|
||||
except Exception as e:
|
||||
logger.error(f"保存记忆仓库内容时出错: {e}")
|
||||
|
||||
async def choose_merge_target(self, memory_title: str, chat_id: str = None) -> list[str]:
|
||||
async def choose_merge_target(self, memory_title: str, chat_id: str = None) -> tuple[list[str], list[str]]:
|
||||
"""
|
||||
选择与给定记忆标题相关的记忆目标
|
||||
选择与给定记忆标题相关的记忆目标(基于文本相似度)
|
||||
|
||||
Args:
|
||||
memory_title: 要匹配的记忆标题
|
||||
chat_id: 聊天ID,用于加权抽样
|
||||
chat_id: 聊天ID,用于筛选同chat_id的记忆
|
||||
|
||||
Returns:
|
||||
list[str]: 选中的记忆内容列表
|
||||
tuple[list[str], list[str]]: (选中的记忆标题列表, 选中的记忆内容列表)
|
||||
"""
|
||||
try:
|
||||
# 如果提供了chat_id,使用加权抽样
|
||||
all_titles = get_memory_titles_by_chat_id_weighted(chat_id)
|
||||
# 剔除掉输入的 memory_title 本身
|
||||
all_titles = [title for title in all_titles if title and title.strip() != (memory_title or "").strip()]
|
||||
if not chat_id:
|
||||
logger.warning("未提供chat_id,无法进行记忆匹配")
|
||||
return [], []
|
||||
|
||||
content = ""
|
||||
display_index = 1
|
||||
for title in all_titles:
|
||||
content += f"{display_index}. {title}\n"
|
||||
display_index += 1
|
||||
# 动态计算相似度阈值(占比越高阈值越低)
|
||||
dynamic_threshold = self._compute_merge_similarity_threshold()
|
||||
|
||||
# 使用相似度匹配查找最相似的记忆(基于动态阈值)
|
||||
similar_memory = find_most_similar_memory_by_chat_id(
|
||||
target_title=memory_title,
|
||||
target_chat_id=chat_id,
|
||||
similarity_threshold=dynamic_threshold
|
||||
)
|
||||
|
||||
prompt = f"""
|
||||
所有记忆列表
|
||||
{content}
|
||||
|
||||
请根据以上记忆列表,选择一个与"{memory_title}"相关的记忆,用json输出:
|
||||
如果没有相关记忆,输出:
|
||||
{{
|
||||
"selected_title": ""
|
||||
}}
|
||||
可以选择多个相关的记忆,但最多不超过5个
|
||||
例如:
|
||||
{{
|
||||
"selected_title": "选择的相关记忆标题"
|
||||
}},
|
||||
{{
|
||||
"selected_title": "选择的相关记忆标题"
|
||||
}},
|
||||
{{
|
||||
"selected_title": "选择的相关记忆标题"
|
||||
}}
|
||||
...
|
||||
注意:请返回原始标题本身作为 selected_title,不要包含前面的序号或多余字符。
|
||||
请输出JSON格式,不要输出其他内容:
|
||||
"""
|
||||
|
||||
# logger.info(f"选择合并目标 prompt: {prompt}")
|
||||
|
||||
if global_config.debug.show_prompt:
|
||||
logger.info(f"选择合并目标 prompt: {prompt}")
|
||||
if similar_memory:
|
||||
selected_title, selected_content, similarity = similar_memory
|
||||
logger.info(f"为 '{memory_title}' 找到相似记忆: '{selected_title}' (相似度: {similarity:.3f} 阈值: {dynamic_threshold:.2f})")
|
||||
return [selected_title], [selected_content]
|
||||
else:
|
||||
logger.debug(f"选择合并目标 prompt: {prompt}")
|
||||
|
||||
merge_target_response, (reasoning_content, model_name, tool_calls) = await self.LLMRequest_build.generate_response_async(prompt)
|
||||
|
||||
# 解析JSON响应
|
||||
selected_titles = self._parse_merge_target_json(merge_target_response)
|
||||
|
||||
# 根据标题查找对应的内容
|
||||
selected_contents = self._get_memories_by_titles(selected_titles)
|
||||
|
||||
logger.info(f"选择合并目标结果: {len(selected_contents)} 条记忆:{selected_titles}")
|
||||
return selected_titles,selected_contents
|
||||
logger.info(f"为 '{memory_title}' 未找到相似度 >= {dynamic_threshold:.2f} 的记忆")
|
||||
return [], []
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"选择合并目标时出错: {e}")
|
||||
return []
|
||||
return [], []
|
||||
|
||||
def _get_memories_by_titles(self, titles: list[str]) -> list[str]:
|
||||
"""
|
||||
@@ -659,6 +701,11 @@ class MemoryChest:
|
||||
合并记忆
|
||||
"""
|
||||
try:
|
||||
# 在记忆整合前先清理空chat_id的记忆
|
||||
cleaned_count = self.cleanup_empty_chat_id_memories()
|
||||
if cleaned_count > 0:
|
||||
logger.info(f"记忆整合前清理了 {cleaned_count} 条空chat_id记忆")
|
||||
|
||||
content = ""
|
||||
for memory in memory_list:
|
||||
content += f"{memory}\n"
|
||||
@@ -793,5 +840,37 @@ MutePlugin 是禁言插件的名称
|
||||
logger.error(f"生成合并记忆标题时出错: {e}")
|
||||
return f"合并记忆_{int(time.time())}"
|
||||
|
||||
def cleanup_empty_chat_id_memories(self) -> int:
|
||||
"""
|
||||
清理chat_id为空的记忆记录
|
||||
|
||||
Returns:
|
||||
int: 被清理的记忆数量
|
||||
"""
|
||||
try:
|
||||
# 查找所有chat_id为空的记忆
|
||||
empty_chat_id_memories = MemoryChestModel.select().where(
|
||||
(MemoryChestModel.chat_id.is_null()) |
|
||||
(MemoryChestModel.chat_id == "") |
|
||||
(MemoryChestModel.chat_id == "None")
|
||||
)
|
||||
|
||||
count = 0
|
||||
for memory in empty_chat_id_memories:
|
||||
logger.info(f"清理空chat_id记忆: 标题='{memory.title}', ID={memory.id}")
|
||||
memory.delete_instance()
|
||||
count += 1
|
||||
|
||||
if count > 0:
|
||||
logger.info(f"已清理 {count} 条chat_id为空的记忆记录")
|
||||
else:
|
||||
logger.debug("未发现需要清理的空chat_id记忆记录")
|
||||
|
||||
return count
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"清理空chat_id记忆时出错: {e}")
|
||||
return 0
|
||||
|
||||
|
||||
global_memory_chest = MemoryChest()
|
||||
185
src/memory_system/curious.py
Normal file
185
src/memory_system/curious.py
Normal file
@@ -0,0 +1,185 @@
|
||||
import time
|
||||
import asyncio
|
||||
from typing import List, Optional, Tuple
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.utils.chat_message_builder import (
|
||||
get_raw_msg_before_timestamp_with_chat,
|
||||
build_readable_messages_with_id,
|
||||
)
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config, global_config
|
||||
from src.memory_system.questions import global_conflict_tracker
|
||||
from src.memory_system.memory_utils import parse_md_json
|
||||
|
||||
logger = get_logger("curious")
|
||||
|
||||
|
||||
class CuriousDetector:
|
||||
"""
|
||||
好奇心检测器 - 检测聊天记录中的矛盾、冲突或需要提问的内容
|
||||
"""
|
||||
|
||||
def __init__(self, chat_id: str):
|
||||
self.chat_id = chat_id
|
||||
self.llm_request = LLMRequest(
|
||||
model_set=model_config.model_task_config.utils,
|
||||
request_type="curious_detector",
|
||||
)
|
||||
|
||||
async def detect_questions(self, recent_messages: List) -> Optional[str]:
|
||||
"""
|
||||
检测最近消息中是否有需要提问的内容
|
||||
|
||||
Args:
|
||||
recent_messages: 最近的消息列表
|
||||
|
||||
Returns:
|
||||
Optional[str]: 如果检测到需要提问的内容,返回问题文本;否则返回None
|
||||
"""
|
||||
try:
|
||||
if not recent_messages or len(recent_messages) < 2:
|
||||
return None
|
||||
|
||||
# 构建聊天内容
|
||||
chat_content_block, _ = build_readable_messages_with_id(
|
||||
messages=recent_messages,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
read_mark=0.0,
|
||||
truncate=True,
|
||||
show_actions=True,
|
||||
)
|
||||
|
||||
# 检查是否已经有问题在跟踪中
|
||||
existing_questions = global_conflict_tracker.get_questions_by_chat_id(self.chat_id)
|
||||
if len(existing_questions) > 0:
|
||||
logger.debug(f"当前已有{len(existing_questions)}个问题在跟踪中,跳过检测")
|
||||
return None
|
||||
|
||||
# 构建检测提示词
|
||||
prompt = f"""你是一个严谨的聊天内容分析器。请分析以下聊天记录,检测是否存在需要提问的内容。
|
||||
|
||||
检测条件:
|
||||
1. 聊天中存在逻辑矛盾或冲突的信息
|
||||
2. 有人反对或否定之前提出的信息
|
||||
3. 存在观点不一致的情况
|
||||
4. 有模糊不清或需要澄清的概念
|
||||
5. 有人提出了质疑或反驳
|
||||
|
||||
**重要限制:**
|
||||
- 忽略涉及违法、暴力、色情、政治等敏感话题的内容
|
||||
- 不要对敏感话题提问
|
||||
- 只有在确实存在矛盾或冲突时才提问
|
||||
- 如果聊天内容正常,没有矛盾,请输出:NO
|
||||
|
||||
**聊天记录**
|
||||
{chat_content_block}
|
||||
|
||||
请分析上述聊天记录,如果发现需要提问的内容,请用JSON格式输出:
|
||||
```json
|
||||
{{
|
||||
"question": "具体的问题描述,要完整描述涉及的概念和问题",
|
||||
"reason": "为什么需要提问这个问题的理由"
|
||||
}}
|
||||
```
|
||||
|
||||
如果没有需要提问的内容,请只输出:NO"""
|
||||
|
||||
if global_config.debug.show_prompt:
|
||||
logger.info(f"好奇心检测提示词: {prompt}")
|
||||
else:
|
||||
logger.debug("已发送好奇心检测提示词")
|
||||
|
||||
result_text, _ = await self.llm_request.generate_response_async(prompt, temperature=0.3)
|
||||
|
||||
if not result_text:
|
||||
return None
|
||||
|
||||
result_text = result_text.strip()
|
||||
|
||||
# 检查是否输出NO
|
||||
if result_text.upper() == "NO":
|
||||
logger.debug("未检测到需要提问的内容")
|
||||
return None
|
||||
|
||||
# 尝试解析JSON
|
||||
try:
|
||||
questions, reasoning = parse_md_json(result_text)
|
||||
if questions and len(questions) > 0:
|
||||
question_data = questions[0]
|
||||
question = question_data.get("question", "")
|
||||
reason = question_data.get("reason", "")
|
||||
|
||||
if question and question.strip():
|
||||
logger.info(f"检测到需要提问的内容: {question}")
|
||||
logger.info(f"提问理由: {reason}")
|
||||
return question
|
||||
except Exception as e:
|
||||
logger.warning(f"解析问题JSON失败: {e}")
|
||||
logger.debug(f"原始响应: {result_text}")
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"好奇心检测失败: {e}")
|
||||
return None
|
||||
|
||||
async def make_question_from_detection(self, question: str, context: str = "") -> bool:
|
||||
"""
|
||||
将检测到的问题记录到冲突追踪器中
|
||||
|
||||
Args:
|
||||
question: 检测到的问题
|
||||
context: 问题上下文
|
||||
|
||||
Returns:
|
||||
bool: 是否成功记录
|
||||
"""
|
||||
try:
|
||||
if not question or not question.strip():
|
||||
return False
|
||||
|
||||
# 记录问题到冲突追踪器,并开始跟踪
|
||||
await global_conflict_tracker.track_conflict(
|
||||
question=question.strip(),
|
||||
context=context,
|
||||
start_following=False,
|
||||
chat_id=self.chat_id
|
||||
)
|
||||
|
||||
logger.info(f"已记录问题到冲突追踪器: {question}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"记录问题失败: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def check_and_make_question(chat_id: str, recent_messages: List) -> bool:
|
||||
"""
|
||||
检查聊天记录并生成问题(如果检测到需要提问的内容)
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
recent_messages: 最近的消息列表
|
||||
|
||||
Returns:
|
||||
bool: 是否检测到并记录了问题
|
||||
"""
|
||||
try:
|
||||
detector = CuriousDetector(chat_id)
|
||||
|
||||
# 检测是否需要提问
|
||||
question = await detector.detect_questions(recent_messages)
|
||||
|
||||
if question:
|
||||
# 记录问题
|
||||
success = await detector.make_question_from_detection(question)
|
||||
if success:
|
||||
logger.info(f"成功检测并记录问题: {question}")
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"检查并生成问题失败: {e}")
|
||||
return False
|
||||
@@ -8,7 +8,6 @@ from src.memory_system.Memory_chest import global_memory_chest
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import MemoryChest as MemoryChestModel
|
||||
from src.config.config import global_config
|
||||
from src.memory_system.memory_utils import get_all_titles
|
||||
|
||||
logger = get_logger("memory")
|
||||
|
||||
@@ -56,19 +55,19 @@ class MemoryManagementTask(AsyncTask):
|
||||
current_count = self._get_memory_count()
|
||||
percentage = current_count / self.max_memory_number
|
||||
|
||||
if percentage < 0.5:
|
||||
if percentage < 0.6:
|
||||
# 小于50%,每600秒执行一次
|
||||
return 3600
|
||||
elif percentage < 0.7:
|
||||
elif percentage < 1:
|
||||
# 大于等于50%,每300秒执行一次
|
||||
return 1800
|
||||
elif percentage < 0.9:
|
||||
# 大于等于70%,每120秒执行一次
|
||||
return 300
|
||||
elif percentage < 1.2:
|
||||
return 30
|
||||
elif percentage < 1.5:
|
||||
# 大于等于100%,每120秒执行一次
|
||||
return 600
|
||||
elif percentage < 1.8:
|
||||
return 120
|
||||
else:
|
||||
return 10
|
||||
return 30
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[记忆管理] 计算执行间隔时出错: {e}")
|
||||
@@ -93,6 +92,22 @@ class MemoryManagementTask(AsyncTask):
|
||||
percentage = current_count / self.max_memory_number
|
||||
logger.info(f"当前记忆数量: {current_count}/{self.max_memory_number} ({percentage:.1%})")
|
||||
|
||||
# 当占比 > 1.6 时,持续删除直到占比 <= 1.6(越老/越新更易被删)
|
||||
if percentage > 2:
|
||||
logger.info("记忆过多,开始遗忘记忆")
|
||||
while True:
|
||||
if percentage <= 1.8:
|
||||
break
|
||||
removed = global_memory_chest.remove_one_memory_by_age_weight()
|
||||
if not removed:
|
||||
logger.warning("没有可删除的记忆,停止连续删除")
|
||||
break
|
||||
# 重新计算占比
|
||||
current_count = self._get_memory_count()
|
||||
percentage = current_count / self.max_memory_number
|
||||
logger.info(f"遗忘进度: 当前 {current_count}/{self.max_memory_number} ({percentage:.1%})")
|
||||
logger.info("遗忘记忆结束")
|
||||
|
||||
# 如果记忆数量为0,跳过执行
|
||||
if current_count < 10:
|
||||
return
|
||||
@@ -109,7 +124,7 @@ class MemoryManagementTask(AsyncTask):
|
||||
logger.info("无合适合并内容,跳过本次合并")
|
||||
return
|
||||
|
||||
logger.info(f"为 [{selected_title}] 找到 {len(related_contents)} 条相关记忆:{related_titles}")
|
||||
logger.info(f"{selected_chat_id} 为 [{selected_title}] 找到 {len(related_contents)} 条相关记忆:{related_titles}")
|
||||
|
||||
# 执行merge_memory合并记忆
|
||||
merged_title, merged_content = await global_memory_chest.merge_memory(related_contents,selected_chat_id)
|
||||
|
||||
@@ -303,4 +303,55 @@ def get_memory_titles_by_chat_id_weighted(target_chat_id: str, same_chat_weight:
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"按chat_id加权抽样记忆标题时出错: {e}")
|
||||
return []
|
||||
return []
|
||||
|
||||
|
||||
def find_most_similar_memory_by_chat_id(target_title: str, target_chat_id: str, similarity_threshold: float = 0.5) -> Optional[Tuple[str, str, float]]:
|
||||
"""
|
||||
在指定chat_id的记忆中查找最相似的记忆
|
||||
|
||||
Args:
|
||||
target_title: 目标标题
|
||||
target_chat_id: 目标聊天ID
|
||||
similarity_threshold: 相似度阈值,默认0.7
|
||||
|
||||
Returns:
|
||||
Optional[Tuple[str, str, float]]: 最相似的记忆(title, content, similarity)或None
|
||||
"""
|
||||
try:
|
||||
# 获取指定chat_id的所有记忆
|
||||
same_chat_memories = []
|
||||
for memory in MemoryChestModel.select():
|
||||
if memory.title and not memory.locked and memory.chat_id == target_chat_id:
|
||||
same_chat_memories.append((memory.title, memory.content))
|
||||
|
||||
if not same_chat_memories:
|
||||
logger.warning(f"未找到chat_id为 '{target_chat_id}' 的记忆")
|
||||
return None
|
||||
|
||||
# 计算相似度并找到最佳匹配
|
||||
best_match = None
|
||||
best_similarity = 0.0
|
||||
|
||||
for title, content in same_chat_memories:
|
||||
# 跳过目标标题本身
|
||||
if title.strip() == target_title.strip():
|
||||
continue
|
||||
|
||||
similarity = calculate_similarity(target_title, title)
|
||||
|
||||
if similarity > best_similarity:
|
||||
best_similarity = similarity
|
||||
best_match = (title, content, similarity)
|
||||
|
||||
# 检查是否超过阈值
|
||||
if best_match and best_similarity >= similarity_threshold:
|
||||
logger.info(f"找到最相似记忆: '{best_match[0]}' (相似度: {best_similarity:.3f})")
|
||||
return best_match
|
||||
else:
|
||||
logger.info(f"未找到相似度 >= {similarity_threshold} 的记忆,最高相似度: {best_similarity:.3f}")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"查找最相似记忆时出错: {e}")
|
||||
return None
|
||||
@@ -50,41 +50,34 @@ class QuestionMaker:
|
||||
"""按权重随机选取一个未回答的冲突并自增 raise_time。
|
||||
|
||||
选择规则:
|
||||
- 若存在 `raise_time == 0` 的项:按权重抽样(0 次权重 1.0,≥1 次权重 0.05)。
|
||||
- 若不存在 `raise_time == 0` 的项:仅 5% 概率返回其中任意一条,否则返回 None。
|
||||
- 若存在 `raise_time == 0` 的项:按权重抽样(0 次权重 1.0,≥1 次权重 0.01)。
|
||||
- 若不存在,返回 None。
|
||||
- 每次成功选中后,将该条目的 `raise_time` 自增 1 并保存。
|
||||
"""
|
||||
conflicts = await self.get_un_answered_conflict()
|
||||
if not conflicts:
|
||||
return None
|
||||
|
||||
# 如果没有 raise_time==0 的项,则仅有 5% 概率抽样一个
|
||||
conflicts_with_zero = [c for c in conflicts if (getattr(c, "raise_time", 0) or 0) == 0]
|
||||
if not conflicts_with_zero:
|
||||
if random.random() >= 0.05:
|
||||
return None
|
||||
# 以均匀概率选择一个(此时权重都等同于 0.05,无需再按权重)
|
||||
chosen_conflict = random.choice(conflicts)
|
||||
else:
|
||||
# 权重规则:raise_time == 0 -> 1.0;raise_time >= 1 -> 0.05
|
||||
if conflicts_with_zero:
|
||||
# 权重规则:raise_time == 0 -> 1.0;raise_time >= 1 -> 0.01
|
||||
weights = []
|
||||
for conflict in conflicts:
|
||||
current_raise_time = getattr(conflict, "raise_time", 0) or 0
|
||||
weight = 1.0 if current_raise_time == 0 else 0.05
|
||||
weight = 1.0 if current_raise_time == 0 else 0.01
|
||||
weights.append(weight)
|
||||
|
||||
# 按权重随机选择
|
||||
chosen_conflict = random.choices(conflicts, weights=weights, k=1)[0]
|
||||
|
||||
# 选中后,自增 raise_time 并保存
|
||||
try:
|
||||
# 选中后,自增 raise_time 并保存
|
||||
chosen_conflict.raise_time = (getattr(chosen_conflict, "raise_time", 0) or 0) + 1
|
||||
chosen_conflict.save()
|
||||
except Exception:
|
||||
# 静默失败不影响流程
|
||||
pass
|
||||
|
||||
return chosen_conflict
|
||||
return chosen_conflict
|
||||
else:
|
||||
# 如果没有 raise_time == 0 的冲突,返回 None
|
||||
return None
|
||||
|
||||
async def make_question(self) -> Tuple[Optional[str], Optional[str], Optional[str]]:
|
||||
"""生成一条用于询问用户的冲突问题与上下文。
|
||||
|
||||
@@ -278,15 +278,41 @@ class ConflictTracker:
|
||||
# 无新消息时稍作等待
|
||||
await asyncio.sleep(poll_interval)
|
||||
|
||||
# 未获取到答案,仅存储问题
|
||||
await self.add_or_update_conflict(
|
||||
conflict_content=original_question,
|
||||
create_time=time.time(),
|
||||
update_time=time.time(),
|
||||
answer="",
|
||||
chat_id=tracker.chat_id,
|
||||
# 未获取到答案,检查是否需要删除记录
|
||||
# 查找现有的冲突记录
|
||||
existing_conflict = MemoryConflict.get_or_none(
|
||||
MemoryConflict.conflict_content == original_question,
|
||||
MemoryConflict.chat_id == tracker.chat_id
|
||||
)
|
||||
logger.info(f"记录冲突内容(未解答): {len(original_question)} 字符")
|
||||
|
||||
if existing_conflict:
|
||||
# 检查raise_time是否大于3且没有答案
|
||||
current_raise_time = getattr(existing_conflict, "raise_time", 0) or 0
|
||||
if current_raise_time > 0 and not existing_conflict.answer:
|
||||
# 删除该条目
|
||||
await self.delete_conflict(original_question, tracker.chat_id)
|
||||
logger.info(f"追踪结束后删除条目(raise_time={current_raise_time}且无答案): {original_question}")
|
||||
else:
|
||||
# 更新记录但不删除
|
||||
await self.add_or_update_conflict(
|
||||
conflict_content=original_question,
|
||||
create_time=existing_conflict.create_time,
|
||||
update_time=time.time(),
|
||||
answer="",
|
||||
chat_id=tracker.chat_id,
|
||||
)
|
||||
logger.info(f"记录冲突内容(未解答): {len(original_question)} 字符")
|
||||
else:
|
||||
# 如果没有现有记录,创建新记录
|
||||
await self.add_or_update_conflict(
|
||||
conflict_content=original_question,
|
||||
create_time=time.time(),
|
||||
update_time=time.time(),
|
||||
answer="",
|
||||
chat_id=tracker.chat_id,
|
||||
)
|
||||
logger.info(f"记录冲突内容(未解答): {len(original_question)} 字符")
|
||||
|
||||
logger.info(f"问题跟踪结束:{original_question}")
|
||||
except Exception as e:
|
||||
logger.error(f"后台问题跟踪任务异常: {e}")
|
||||
@@ -374,6 +400,7 @@ class ConflictTracker:
|
||||
1.提问必须具体,明确
|
||||
2.提问最好涉及指向明确的事物,而不是代称
|
||||
3.如果缺少上下文,不要强行提问,可以忽略
|
||||
4.请忽略涉及违法,暴力,色情,政治等敏感话题的内容
|
||||
|
||||
请用json格式输出,不要输出其他内容,仅输出提问理由和具体提的提问:
|
||||
**示例**
|
||||
@@ -420,5 +447,33 @@ class ConflictTracker:
|
||||
logger.error(f"获取冲突记录数量时出错: {e}")
|
||||
return 0
|
||||
|
||||
async def delete_conflict(self, conflict_content: str, chat_id: str) -> bool:
|
||||
"""
|
||||
删除指定的冲突记录
|
||||
|
||||
Args:
|
||||
conflict_content: 冲突内容
|
||||
chat_id: 聊天ID
|
||||
|
||||
Returns:
|
||||
bool: 是否成功删除
|
||||
"""
|
||||
try:
|
||||
conflict = MemoryConflict.get_or_none(
|
||||
MemoryConflict.conflict_content == conflict_content,
|
||||
MemoryConflict.chat_id == chat_id
|
||||
)
|
||||
|
||||
if conflict:
|
||||
conflict.delete_instance()
|
||||
logger.info(f"已删除冲突记录: {conflict_content}")
|
||||
return True
|
||||
else:
|
||||
logger.warning(f"未找到要删除的冲突记录: {conflict_content}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"删除冲突记录时出错: {e}")
|
||||
return False
|
||||
|
||||
# 全局冲突追踪器实例
|
||||
global_conflict_tracker = ConflictTracker()
|
||||
@@ -1,11 +1,12 @@
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.frequency_control.frequency_control import frequency_control_manager
|
||||
from src.config.config import global_config
|
||||
|
||||
logger = get_logger("frequency_api")
|
||||
|
||||
|
||||
def get_current_talk_frequency(chat_id: str) -> float:
|
||||
return frequency_control_manager.get_or_create_frequency_control(chat_id).get_talk_frequency_adjust()
|
||||
def get_current_talk_value(chat_id: str) -> float:
|
||||
return frequency_control_manager.get_or_create_frequency_control(chat_id).get_talk_frequency_adjust() * global_config.chat.get_talk_value(chat_id)
|
||||
|
||||
|
||||
def set_talk_frequency_adjust(chat_id: str, talk_frequency_adjust: float) -> None:
|
||||
|
||||
@@ -1,14 +1,25 @@
|
||||
from typing import Optional, Type
|
||||
from typing import Optional, Type, TYPE_CHECKING
|
||||
from src.plugin_system.base.base_tool import BaseTool
|
||||
from src.plugin_system.base.component_types import ComponentType
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
|
||||
logger = get_logger("tool_api")
|
||||
|
||||
|
||||
def get_tool_instance(tool_name: str) -> Optional[BaseTool]:
|
||||
"""获取公开工具实例"""
|
||||
def get_tool_instance(tool_name: str, chat_stream: Optional["ChatStream"] = None) -> Optional[BaseTool]:
|
||||
"""获取公开工具实例
|
||||
|
||||
Args:
|
||||
tool_name: 工具名称
|
||||
chat_stream: 聊天流对象,用于传递聊天上下文信息
|
||||
|
||||
Returns:
|
||||
Optional[BaseTool]: 工具实例,如果未找到则返回None
|
||||
"""
|
||||
from src.plugin_system.core import component_registry
|
||||
|
||||
# 获取插件配置
|
||||
@@ -19,7 +30,7 @@ def get_tool_instance(tool_name: str) -> Optional[BaseTool]:
|
||||
plugin_config = None
|
||||
|
||||
tool_class: Type[BaseTool] = component_registry.get_component_class(tool_name, ComponentType.TOOL) # type: ignore
|
||||
return tool_class(plugin_config) if tool_class else None
|
||||
return tool_class(plugin_config, chat_stream) if tool_class else None
|
||||
|
||||
|
||||
def get_llm_available_tool_definitions():
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, List, Optional, Tuple
|
||||
from typing import Any, List, Optional, Tuple, TYPE_CHECKING
|
||||
from rich.traceback import install
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.component_types import ComponentType, ToolInfo, ToolParamType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
logger = get_logger("base_tool")
|
||||
@@ -29,8 +32,23 @@ class BaseTool(ABC):
|
||||
available_for_llm: bool = False
|
||||
"""是否可供LLM使用"""
|
||||
|
||||
def __init__(self, plugin_config: Optional[dict] = None):
|
||||
def __init__(self, plugin_config: Optional[dict] = None, chat_stream: Optional["ChatStream"] = None):
|
||||
"""初始化工具基类
|
||||
|
||||
Args:
|
||||
plugin_config: 插件配置字典
|
||||
chat_stream: 聊天流对象,用于获取聊天上下文信息
|
||||
"""
|
||||
self.plugin_config = plugin_config or {} # 直接存储插件配置字典
|
||||
|
||||
# =============================================================================
|
||||
# 便捷属性 - 直接在初始化时获取常用聊天信息(与BaseAction保持一致)
|
||||
# =============================================================================
|
||||
|
||||
# 获取聊天流对象
|
||||
self.chat_stream = chat_stream
|
||||
self.chat_id = self.chat_stream.stream_id if self.chat_stream else None
|
||||
self.platform = getattr(self.chat_stream, "platform", None) if self.chat_stream else None
|
||||
|
||||
@classmethod
|
||||
def get_tool_definition(cls) -> dict[str, Any]:
|
||||
|
||||
@@ -223,7 +223,7 @@ class ToolExecutor:
|
||||
function_args["llm_called"] = True # 标记为LLM调用
|
||||
|
||||
# 获取对应工具实例
|
||||
tool_instance = tool_instance or get_tool_instance(function_name)
|
||||
tool_instance = tool_instance or get_tool_instance(function_name, self.chat_stream)
|
||||
if not tool_instance:
|
||||
logger.warning(f"未知工具名称: {function_name}")
|
||||
return None
|
||||
|
||||
@@ -1,34 +0,0 @@
|
||||
{
|
||||
"manifest_version": 1,
|
||||
"name": "MaiCurious插件 (MaiCurious Actions)",
|
||||
"version": "1.0.0",
|
||||
"description": "可以好奇",
|
||||
"author": {
|
||||
"name": "SengokuCola",
|
||||
"url": "https://github.com/MaiM-with-u"
|
||||
},
|
||||
"license": "GPL-v3.0-or-later",
|
||||
|
||||
"host_application": {
|
||||
"min_version": "0.11.0"
|
||||
},
|
||||
"homepage_url": "https://github.com/MaiM-with-u/maibot",
|
||||
"repository_url": "https://github.com/MaiM-with-u/maibot",
|
||||
"keywords": ["curious", "action", "built-in"],
|
||||
"categories": ["Deep Think"],
|
||||
|
||||
"default_locale": "zh-CN",
|
||||
"locales_path": "_locales",
|
||||
|
||||
"plugin_info": {
|
||||
"is_built_in": true,
|
||||
"plugin_type": "action_provider",
|
||||
"components": [
|
||||
{
|
||||
"type": "action",
|
||||
"name": "maicurious",
|
||||
"description": "发送好奇"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
@@ -1,117 +0,0 @@
|
||||
from typing import List, Tuple, Type, Any
|
||||
|
||||
# 导入新插件系统
|
||||
from src.plugin_system import BasePlugin, register_plugin, ComponentInfo
|
||||
from src.plugin_system.base.config_types import ConfigField
|
||||
from src.person_info.person_info import Person
|
||||
from src.plugin_system.base.base_tool import BaseTool, ToolParamType
|
||||
|
||||
# 导入依赖的系统组件
|
||||
from src.common.logger import get_logger
|
||||
|
||||
from src.plugins.built_in.relation.relation import BuildRelationAction
|
||||
from src.plugin_system.apis import llm_api
|
||||
from src.plugin_system.base.base_action import BaseAction
|
||||
from src.plugin_system.base.component_types import ActionActivationType
|
||||
from src.plugin_system.apis import config_api
|
||||
from src.plugin_system.apis import frequency_api
|
||||
from src.plugin_system.apis import generator_api
|
||||
from src.memory_system.questions import global_conflict_tracker
|
||||
|
||||
logger = get_logger("question_actions")
|
||||
|
||||
|
||||
|
||||
class CuriousAction(BaseAction):
|
||||
"""频率调节动作 - 调整聊天发言频率"""
|
||||
|
||||
activation_type = ActionActivationType.ALWAYS
|
||||
parallel_action = True
|
||||
|
||||
# 动作基本信息
|
||||
action_name = "make_question"
|
||||
|
||||
action_description = "提出一个问题,当有人反驳你的观点,或其他人之间有观点冲突时使用"
|
||||
|
||||
# 动作参数定义
|
||||
action_parameters = {
|
||||
"question": "对存在疑问的信息提出一个问题,描述全面,完整的描述涉及的概念和问题",
|
||||
}
|
||||
|
||||
action_require = [
|
||||
f"当聊天记录中的概念存在逻辑上的矛盾时使用",
|
||||
f"当有人反对或否定你提出的信息时使用",
|
||||
f"或当你对现有的概念或事物存在疑问时使用",
|
||||
f"有人认为你的观点是错误的,请选择question动作",
|
||||
f"有人与你观点不一致,请选择question动作",
|
||||
]
|
||||
|
||||
# 关联类型
|
||||
associated_types = ["text"]
|
||||
|
||||
async def execute(self) -> Tuple[bool, str]:
|
||||
"""执行频率调节动作"""
|
||||
try:
|
||||
if len(global_conflict_tracker.question_tracker_list) > 1:
|
||||
return False, "当前已有问题,请先解答完再提问,不要再使用make_question动作"
|
||||
|
||||
question = self.action_data.get("question", "")
|
||||
|
||||
# 存储问题到冲突追踪器
|
||||
if question:
|
||||
await global_conflict_tracker.record_conflict(conflict_content=question, start_following=True,chat_id=self.chat_id)
|
||||
logger.info(f"已存储问题到冲突追踪器: {question}")
|
||||
await self.store_action_info(
|
||||
action_build_into_prompt=True,
|
||||
action_prompt_display=f"你产生了一个问题:{question}",
|
||||
action_done=True,
|
||||
)
|
||||
return True, f"问题{question}已记录,不要重复提问该问题"
|
||||
except Exception as e:
|
||||
error_msg = f"问题生成失败: {str(e)}"
|
||||
logger.error(f"{self.log_prefix} {error_msg}", exc_info=True)
|
||||
await self.send_text("问题生成失败")
|
||||
return False, error_msg
|
||||
|
||||
|
||||
@register_plugin
|
||||
class CuriousPlugin(BasePlugin):
|
||||
"""关系动作插件
|
||||
|
||||
系统内置插件,提供基础的聊天交互功能:
|
||||
- Reply: 回复动作
|
||||
- NoReply: 不回复动作
|
||||
- Emoji: 表情动作
|
||||
|
||||
注意:插件基本信息优先从_manifest.json文件中读取
|
||||
"""
|
||||
|
||||
# 插件基本信息
|
||||
plugin_name: str = "maicurious" # 内部标识符
|
||||
enable_plugin: bool = True
|
||||
dependencies: list[str] = [] # 插件依赖列表
|
||||
python_dependencies: list[str] = [] # Python包依赖列表
|
||||
config_file_name: str = "config.toml"
|
||||
|
||||
# 配置节描述
|
||||
config_section_descriptions = {
|
||||
"plugin": "插件启用配置",
|
||||
"components": "核心组件启用配置",
|
||||
}
|
||||
|
||||
# 配置Schema定义
|
||||
config_schema: dict = {
|
||||
"plugin": {
|
||||
"enabled": ConfigField(type=bool, default=True, description="是否启用插件"),
|
||||
"config_version": ConfigField(type=str, default="3.0.0", description="配置文件版本"),
|
||||
}
|
||||
}
|
||||
|
||||
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
|
||||
"""返回插件包含的组件列表"""
|
||||
|
||||
# --- 根据配置注册组件 ---
|
||||
components = []
|
||||
components.append((CuriousAction.get_action_info(), CuriousAction))
|
||||
|
||||
return components
|
||||
@@ -1,29 +1,77 @@
|
||||
from typing import Tuple
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.chat.utils.prompt_builder import Prompt
|
||||
from src.llm_models.payload_content.tool_option import ToolParamType
|
||||
from src.plugin_system import BaseAction, ActionActivationType
|
||||
from src.chat.utils.utils import cut_key_words
|
||||
from src.memory_system.Memory_chest import global_memory_chest
|
||||
from src.plugin_system.base.base_tool import BaseTool
|
||||
from src.plugin_system.apis.message_api import get_messages_by_time_in_chat, build_readable_messages
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config
|
||||
from typing import Any
|
||||
|
||||
logger = get_logger("memory")
|
||||
|
||||
def parse_datetime_to_timestamp(value: str) -> float:
|
||||
"""
|
||||
接受多种常见格式并转换为时间戳(秒)
|
||||
支持示例:
|
||||
- 2025-09-29
|
||||
- 2025-09-29 00:00:00
|
||||
- 2025/09/29 00:00
|
||||
- 2025-09-29T00:00:00
|
||||
"""
|
||||
value = value.strip()
|
||||
fmts = [
|
||||
"%Y-%m-%d %H:%M:%S",
|
||||
"%Y-%m-%d %H:%M",
|
||||
"%Y/%m/%d %H:%M:%S",
|
||||
"%Y/%m/%d %H:%M",
|
||||
"%Y-%m-%d",
|
||||
"%Y/%m/%d",
|
||||
"%Y-%m-%dT%H:%M:%S",
|
||||
"%Y-%m-%dT%H:%M",
|
||||
]
|
||||
last_err = None
|
||||
for fmt in fmts:
|
||||
try:
|
||||
dt = datetime.strptime(value, fmt)
|
||||
return dt.timestamp()
|
||||
except Exception as e:
|
||||
last_err = e
|
||||
raise ValueError(f"无法解析时间: {value} ({last_err})")
|
||||
|
||||
def parse_time_range(time_range: str) -> tuple[float, float]:
|
||||
"""
|
||||
解析时间范围字符串,返回开始和结束时间戳
|
||||
格式: "YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS"
|
||||
"""
|
||||
if " - " not in time_range:
|
||||
raise ValueError("时间范围格式错误,应使用 ' - ' 分隔开始和结束时间")
|
||||
|
||||
start_str, end_str = time_range.split(" - ", 1)
|
||||
start_timestamp = parse_datetime_to_timestamp(start_str.strip())
|
||||
end_timestamp = parse_datetime_to_timestamp(end_str.strip())
|
||||
|
||||
if start_timestamp > end_timestamp:
|
||||
raise ValueError("开始时间不能晚于结束时间")
|
||||
|
||||
return start_timestamp, end_timestamp
|
||||
class GetMemoryTool(BaseTool):
|
||||
"""获取用户信息"""
|
||||
|
||||
name = "get_memory"
|
||||
description = "在记忆中搜索,获取某个问题的答案"
|
||||
description = "在记忆中搜索,获取某个问题的答案,可以指定搜索的时间范围或时间点"
|
||||
parameters = [
|
||||
("question", ToolParamType.STRING, "需要获取答案的问题", True, None)
|
||||
("question", ToolParamType.STRING, "需要获取答案的问题", True, None),
|
||||
("time_point", ToolParamType.STRING, "需要获取记忆的时间点,格式为YYYY-MM-DD HH:MM:SS", False, None),
|
||||
("time_range", ToolParamType.STRING, "需要获取记忆的时间范围,格式为YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS", False, None)
|
||||
]
|
||||
|
||||
available_for_llm = True
|
||||
|
||||
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""执行比较两个数的大小
|
||||
"""执行记忆搜索
|
||||
|
||||
Args:
|
||||
function_args: 工具参数
|
||||
@@ -32,59 +80,166 @@ class GetMemoryTool(BaseTool):
|
||||
dict: 工具执行结果
|
||||
"""
|
||||
question: str = function_args.get("question") # type: ignore
|
||||
time_point: str = function_args.get("time_point") # type: ignore
|
||||
time_range: str = function_args.get("time_range") # type: ignore
|
||||
|
||||
answer = await global_memory_chest.get_answer_by_question(question=question)
|
||||
if not answer:
|
||||
return {"content": f"问题:{question},没有找到相关记忆"}
|
||||
# 检查是否指定了时间参数
|
||||
has_time_params = bool(time_point or time_range)
|
||||
|
||||
return {"content": f"问题:{question},答案:{answer}"}
|
||||
|
||||
class GetMemoryAction(BaseAction):
|
||||
"""关系动作 - 获取记忆"""
|
||||
|
||||
activation_type = ActionActivationType.LLM_JUDGE
|
||||
parallel_action = True
|
||||
|
||||
# 动作基本信息
|
||||
action_name = "get_memory"
|
||||
action_description = (
|
||||
"在记忆中搜寻某个问题的答案"
|
||||
)
|
||||
|
||||
# 动作参数定义
|
||||
action_parameters = {
|
||||
"question": "需要搜寻或回答的问题",
|
||||
}
|
||||
|
||||
# 动作使用场景
|
||||
action_require = [
|
||||
"在记忆中搜寻某个问题的答案",
|
||||
"有你不了解的概念",
|
||||
"有人提问关于过去的事情"
|
||||
"你需要根据记忆回答某个问题",
|
||||
]
|
||||
|
||||
# 关联类型
|
||||
associated_types = ["text"]
|
||||
|
||||
async def execute(self) -> Tuple[bool, str]:
|
||||
"""执行关系动作"""
|
||||
if has_time_params and not self.chat_id:
|
||||
return {"content": f"问题:{question},无法获取聊天记录:缺少chat_id"}
|
||||
|
||||
question = self.action_data.get("question", "")
|
||||
answer = await global_memory_chest.get_answer_by_question(self.chat_id, question)
|
||||
if not answer:
|
||||
await self.store_action_info(
|
||||
action_build_into_prompt=True,
|
||||
action_prompt_display=f"你回忆了有关问题:{question}的记忆,但是没有找到相关记忆",
|
||||
action_done=True,
|
||||
)
|
||||
|
||||
return False, f"问题:{question},没有找到相关记忆"
|
||||
# 创建并行任务
|
||||
tasks = []
|
||||
|
||||
await self.store_action_info(
|
||||
action_build_into_prompt=True,
|
||||
action_prompt_display=f"你回忆了有关问题:{question}的记忆,答案是:{answer}",
|
||||
action_done=True,
|
||||
# 原任务:从记忆仓库获取答案
|
||||
memory_task = asyncio.create_task(
|
||||
global_memory_chest.get_answer_by_question(question=question)
|
||||
)
|
||||
tasks.append(("memory", memory_task))
|
||||
|
||||
return True, f"成功获取记忆: {answer}"
|
||||
# 新任务:从聊天记录获取答案(如果指定了时间参数)
|
||||
chat_task = None
|
||||
if has_time_params:
|
||||
chat_task = asyncio.create_task(
|
||||
self._get_answer_from_chat_history(question, time_point, time_range)
|
||||
)
|
||||
tasks.append(("chat", chat_task))
|
||||
|
||||
# 等待所有任务完成
|
||||
results = {}
|
||||
for task_name, task in tasks:
|
||||
try:
|
||||
results[task_name] = await task
|
||||
except Exception as e:
|
||||
logger.error(f"任务 {task_name} 执行失败: {e}")
|
||||
results[task_name] = None
|
||||
|
||||
# 处理结果
|
||||
memory_answer = results.get("memory")
|
||||
chat_answer = results.get("chat")
|
||||
|
||||
# 构建返回内容
|
||||
content_parts = []
|
||||
|
||||
if memory_answer:
|
||||
content_parts.append(f"对问题'{question}',你回忆的信息是:{memory_answer}")
|
||||
|
||||
if chat_answer:
|
||||
content_parts.append(f"对问题'{question}',基于聊天记录的回答:{chat_answer}")
|
||||
elif has_time_params:
|
||||
if time_point:
|
||||
content_parts.append(f"在 {time_point} 的时间点,你没有参与聊天")
|
||||
elif time_range:
|
||||
content_parts.append(f"在 {time_range} 的时间范围内,你没有参与聊天")
|
||||
|
||||
if content_parts:
|
||||
retrieval_content = f"问题:{question}" + "\n".join(content_parts)
|
||||
return {"content": retrieval_content}
|
||||
else:
|
||||
return {"content": ""}
|
||||
|
||||
|
||||
async def _get_answer_from_chat_history(self, question: str, time_point: str = None, time_range: str = None) -> str:
|
||||
"""从聊天记录中获取问题的答案"""
|
||||
try:
|
||||
# 确定时间范围
|
||||
print(f"time_point: {time_point}, time_range: {time_range}")
|
||||
|
||||
# 检查time_range的两个时间值是否相同,如果相同则按照time_point处理
|
||||
if time_range and not time_point:
|
||||
try:
|
||||
start_timestamp, end_timestamp = parse_time_range(time_range)
|
||||
if start_timestamp == end_timestamp:
|
||||
# 两个时间值相同,按照time_point处理
|
||||
time_point = time_range.split(" - ")[0].strip()
|
||||
time_range = None
|
||||
print(f"time_range两个值相同,按照time_point处理: {time_point}")
|
||||
except Exception as e:
|
||||
logger.warning(f"解析time_range失败: {e}")
|
||||
|
||||
if time_point:
|
||||
# 时间点:搜索前后25条记录
|
||||
target_timestamp = parse_datetime_to_timestamp(time_point)
|
||||
# 获取前后各25条记录,总共50条
|
||||
messages_before = get_messages_by_time_in_chat(
|
||||
chat_id=self.chat_id,
|
||||
start_time=0,
|
||||
end_time=target_timestamp,
|
||||
limit=25,
|
||||
limit_mode="latest"
|
||||
)
|
||||
messages_after = get_messages_by_time_in_chat(
|
||||
chat_id=self.chat_id,
|
||||
start_time=target_timestamp,
|
||||
end_time=float('inf'),
|
||||
limit=25,
|
||||
limit_mode="earliest"
|
||||
)
|
||||
messages = messages_before + messages_after
|
||||
elif time_range:
|
||||
# 时间范围:搜索范围内最多50条记录
|
||||
start_timestamp, end_timestamp = parse_time_range(time_range)
|
||||
messages = get_messages_by_time_in_chat(
|
||||
chat_id=self.chat_id,
|
||||
start_time=start_timestamp,
|
||||
end_time=end_timestamp,
|
||||
limit=50,
|
||||
limit_mode="latest"
|
||||
)
|
||||
else:
|
||||
return "未指定时间参数"
|
||||
|
||||
if not messages:
|
||||
return "没有找到相关聊天记录"
|
||||
|
||||
# 将消息转换为可读格式
|
||||
chat_content = build_readable_messages(messages, timestamp_mode="relative")
|
||||
|
||||
if not chat_content.strip():
|
||||
return "聊天记录为空"
|
||||
|
||||
# 使用LLM分析聊天内容并回答问题
|
||||
try:
|
||||
llm_request = LLMRequest(
|
||||
model_set=model_config.model_task_config.utils_small,
|
||||
request_type="chat_history_analysis"
|
||||
)
|
||||
|
||||
analysis_prompt = f"""请根据以下聊天记录内容,回答用户的问题。请输出一段平文本,不要有特殊格式。
|
||||
聊天记录:
|
||||
{chat_content}
|
||||
|
||||
用户问题:{question}
|
||||
|
||||
请仔细分析聊天记录,提取与问题相关的信息,并给出准确的答案。如果聊天记录中没有相关信息,无法回答问题,输出"无有效信息"即可,不要输出其他内容。
|
||||
|
||||
答案:"""
|
||||
|
||||
print(f"analysis_prompt: {analysis_prompt}")
|
||||
|
||||
|
||||
response, (reasoning, model_name, tool_calls) = await llm_request.generate_response_async(
|
||||
prompt=analysis_prompt,
|
||||
temperature=0.3,
|
||||
max_tokens=256
|
||||
)
|
||||
|
||||
|
||||
print(f"response: {response}")
|
||||
|
||||
if "无有效信息" in response:
|
||||
return ""
|
||||
|
||||
return response
|
||||
|
||||
except Exception as llm_error:
|
||||
logger.error(f"LLM分析聊天记录失败: {llm_error}")
|
||||
# 如果LLM分析失败,返回聊天内容的摘要
|
||||
if len(chat_content) > 300:
|
||||
chat_content = chat_content[:300] + "..."
|
||||
return chat_content
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"从聊天记录获取答案失败: {e}")
|
||||
return ""
|
||||
|
||||
@@ -7,7 +7,7 @@ from src.plugin_system.base.config_types import ConfigField
|
||||
# 导入依赖的系统组件
|
||||
from src.common.logger import get_logger
|
||||
|
||||
from src.plugins.built_in.memory.build_memory import GetMemoryAction, GetMemoryTool
|
||||
from src.plugins.built_in.memory.build_memory import GetMemoryTool
|
||||
|
||||
logger = get_logger("memory_build")
|
||||
|
||||
@@ -48,7 +48,6 @@ class MemoryBuildPlugin(BasePlugin):
|
||||
|
||||
# --- 根据配置注册组件 ---
|
||||
components = []
|
||||
# components.append((GetMemoryAction.get_action_info(), GetMemoryAction))
|
||||
components.append((GetMemoryTool.get_tool_info(), GetMemoryTool))
|
||||
|
||||
return components
|
||||
|
||||
@@ -78,7 +78,7 @@ class RelationActionsPlugin(BasePlugin):
|
||||
config_schema: dict = {
|
||||
"plugin": {
|
||||
"enabled": ConfigField(type=bool, default=False, description="是否启用插件"),
|
||||
"config_version": ConfigField(type=str, default="1.0.1", description="配置文件版本"),
|
||||
"config_version": ConfigField(type=str, default="1.0.2", description="配置文件版本"),
|
||||
},
|
||||
"components": {
|
||||
"relation_max_memory_num": ConfigField(type=int, default=10, description="关系记忆最大数量"),
|
||||
@@ -90,7 +90,7 @@ class RelationActionsPlugin(BasePlugin):
|
||||
|
||||
# --- 根据配置注册组件 ---
|
||||
components = []
|
||||
components.append((BuildRelationAction.get_action_info(), BuildRelationAction))
|
||||
components.append((GetPersonInfoTool.get_tool_info(), GetPersonInfoTool))
|
||||
# components.append((BuildRelationAction.get_action_info(), BuildRelationAction))
|
||||
# components.append((GetPersonInfoTool.get_tool_info(), GetPersonInfoTool))
|
||||
|
||||
return components
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
[inner]
|
||||
version = "6.18.3"
|
||||
version = "6.19.2"
|
||||
|
||||
#----以下是给开发人员阅读的,如果你只是部署了麦麦,不需要阅读----
|
||||
#如果你想要修改配置文件,请递增version的值
|
||||
@@ -14,6 +14,9 @@ version = "6.18.3"
|
||||
[bot]
|
||||
platform = "qq"
|
||||
qq_account = "1145141919810" # 麦麦的QQ账号
|
||||
|
||||
platforms = ["wx:114514","xx:1919810"] # 麦麦的其他平台账号
|
||||
|
||||
nickname = "麦麦" # 麦麦的昵称
|
||||
alias_names = ["麦叠", "牢麦"] # 麦麦的别名
|
||||
|
||||
@@ -44,10 +47,20 @@ private_plan_style = """
|
||||
2.如果相同的内容已经被执行,请不要重复执行
|
||||
3.某句话如果已经被回复过,不要重复回复"""
|
||||
|
||||
# 状态,可以理解为人格多样性,会随机替换人格
|
||||
states = [
|
||||
"是一个女大学生,喜欢上网聊天,会刷小红书。" ,
|
||||
"是一个大二心理学生,会刷贴吧和中国知网。" ,
|
||||
"是一个赛博网友,最近很想吐槽人。"
|
||||
]
|
||||
|
||||
# 替换概率,每次构建人格时替换personality的概率(0.0-1.0)
|
||||
state_probability = 0.3
|
||||
|
||||
[expression]
|
||||
# 表达方式模式(此选项暂未使用)
|
||||
mode = "context"
|
||||
# 可选:llm模式,context上下文模式
|
||||
# 表达方式模式
|
||||
mode = "classic"
|
||||
# 可选:classic经典模式,exp_model 表达模型模式,这个模式需要一定时间学习才会有比较好的效果
|
||||
|
||||
# 表达学习配置
|
||||
learning_list = [ # 表达学习配置列表,支持按聊天流配置
|
||||
@@ -79,6 +92,9 @@ max_context_size = 30 # 上下文长度
|
||||
auto_chat_value = 1 # 自动聊天,越小,麦麦主动聊天的概率越低
|
||||
planner_smooth = 5 #规划器平滑,增大数值会减小planner负荷,略微降低反应速度,推荐2-8,0为关闭,必须大于等于0
|
||||
|
||||
enable_talk_value_rules = true # 是否启用动态发言频率规则
|
||||
enable_auto_chat_value_rules = false # 是否启用动态自动聊天频率规则
|
||||
|
||||
# 动态发言频率规则:按时段/按chat_id调整 talk_value(优先匹配具体chat,再匹配全局)
|
||||
# 推荐格式(对象数组):{ target="platform:id:type" 或 "", time="HH:MM-HH:MM", value=0.5 }
|
||||
# 说明:
|
||||
@@ -205,6 +221,8 @@ library_log_levels = { aiohttp = "WARNING"} # 设置特定库的日志级别
|
||||
|
||||
[debug]
|
||||
show_prompt = false # 是否显示prompt
|
||||
show_replyer_prompt = false # 是否显示回复器prompt
|
||||
show_replyer_reasoning = false # 是否显示回复器推理
|
||||
|
||||
[maim_message]
|
||||
auth_token = [] # 认证令牌,用于API验证,为空则不启用验证
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
[inner]
|
||||
version = "1.7.4"
|
||||
version = "1.7.7"
|
||||
|
||||
# 配置文件版本号迭代规则同bot_config.toml
|
||||
|
||||
@@ -35,9 +35,9 @@ name = "SiliconFlow"
|
||||
base_url = "https://api.siliconflow.cn/v1"
|
||||
api_key = "your-siliconflow-api-key"
|
||||
client_type = "openai"
|
||||
max_retry = 2
|
||||
max_retry = 3
|
||||
timeout = 120
|
||||
retry_interval = 10
|
||||
retry_interval = 5
|
||||
|
||||
|
||||
[[models]] # 模型(可以配置多个)
|
||||
@@ -49,21 +49,31 @@ price_out = 8.0 # 输出价格(用于API调用统计,单
|
||||
#force_stream_mode = true # 强制流式输出模式(若模型不支持非流式输出,请取消该注释,启用强制流式输出,若无该字段,默认值为false)
|
||||
|
||||
[[models]]
|
||||
model_identifier = "deepseek-ai/DeepSeek-V3"
|
||||
name = "siliconflow-deepseek-v3"
|
||||
model_identifier = "deepseek-ai/DeepSeek-V3.2-Exp"
|
||||
name = "siliconflow-deepseek-v3.2"
|
||||
api_provider = "SiliconFlow"
|
||||
price_in = 2.0
|
||||
price_out = 8.0
|
||||
|
||||
[[models]]
|
||||
model_identifier = "Qwen/Qwen3-8B"
|
||||
name = "qwen3-8b"
|
||||
api_provider = "SiliconFlow"
|
||||
price_in = 0
|
||||
price_out = 0
|
||||
price_out = 3.0
|
||||
[models.extra_params] # 可选的额外参数配置
|
||||
enable_thinking = false # 不启用思考
|
||||
|
||||
[[models]]
|
||||
model_identifier = "deepseek-ai/DeepSeek-V3.2-Exp"
|
||||
name = "siliconflow-deepseek-v3.2-think"
|
||||
api_provider = "SiliconFlow"
|
||||
price_in = 2.0
|
||||
price_out = 3.0
|
||||
[models.extra_params] # 可选的额外参数配置
|
||||
enable_thinking = true # 不启用思考
|
||||
|
||||
[[models]]
|
||||
model_identifier = "deepseek-ai/DeepSeek-R1"
|
||||
name = "siliconflow-deepseek-r1"
|
||||
api_provider = "SiliconFlow"
|
||||
price_in = 4.0
|
||||
price_out = 16.0
|
||||
|
||||
|
||||
[[models]]
|
||||
model_identifier = "Qwen/Qwen3-30B-A3B-Instruct-2507"
|
||||
name = "qwen3-30b"
|
||||
@@ -72,8 +82,8 @@ price_in = 0.7
|
||||
price_out = 2.8
|
||||
|
||||
[[models]]
|
||||
model_identifier = "Qwen/Qwen2.5-VL-72B-Instruct"
|
||||
name = "qwen2.5-vl-72b"
|
||||
model_identifier = "Qwen/Qwen3-VL-30B-A3B-Instruct"
|
||||
name = "qwen3-vl-30"
|
||||
api_provider = "SiliconFlow"
|
||||
price_in = 4.13
|
||||
price_out = 4.13
|
||||
@@ -94,12 +104,12 @@ price_out = 0
|
||||
|
||||
|
||||
[model_task_config.utils] # 在麦麦的一些组件中使用的模型,例如表情包模块,取名模块,关系模块,麦麦的情绪变化等,是麦麦必须的模型
|
||||
model_list = ["siliconflow-deepseek-v3"] # 使用的模型列表,每个子项对应上面的模型名称(name)
|
||||
model_list = ["siliconflow-deepseek-v3.2"] # 使用的模型列表,每个子项对应上面的模型名称(name)
|
||||
temperature = 0.2 # 模型温度,新V3建议0.1-0.3
|
||||
max_tokens = 2048 # 最大输出token数
|
||||
|
||||
[model_task_config.utils_small] # 在麦麦的一些组件中使用的小模型,消耗量较大,建议使用速度较快的小模型
|
||||
model_list = ["qwen3-8b","qwen3-30b"]
|
||||
model_list = ["qwen3-30b"]
|
||||
temperature = 0.7
|
||||
max_tokens = 2048
|
||||
|
||||
@@ -109,18 +119,18 @@ temperature = 0.7
|
||||
max_tokens = 800
|
||||
|
||||
[model_task_config.replyer] # 首要回复模型,还用于表达器和表达方式学习
|
||||
model_list = ["siliconflow-deepseek-v3"]
|
||||
model_list = ["siliconflow-deepseek-v3.2-think","siliconflow-deepseek-r1","siliconflow-deepseek-v3.2"]
|
||||
temperature = 0.3 # 模型温度,新V3建议0.1-0.3
|
||||
max_tokens = 800
|
||||
max_tokens = 2048
|
||||
|
||||
[model_task_config.planner] #决策:负责决定麦麦该什么时候回复的模型
|
||||
model_list = ["siliconflow-deepseek-v3"]
|
||||
model_list = ["siliconflow-deepseek-v3.2"]
|
||||
temperature = 0.3
|
||||
max_tokens = 800
|
||||
|
||||
[model_task_config.vlm] # 图像识别模型
|
||||
model_list = ["qwen2.5-vl-72b"]
|
||||
max_tokens = 800
|
||||
model_list = ["qwen3-vl-30"]
|
||||
max_tokens = 256
|
||||
|
||||
[model_task_config.voice] # 语音识别模型
|
||||
model_list = ["sensevoice-small"]
|
||||
@@ -132,16 +142,16 @@ model_list = ["bge-m3"]
|
||||
#------------LPMM知识库模型------------
|
||||
|
||||
[model_task_config.lpmm_entity_extract] # 实体提取模型
|
||||
model_list = ["siliconflow-deepseek-v3"]
|
||||
model_list = ["siliconflow-deepseek-v3.2"]
|
||||
temperature = 0.2
|
||||
max_tokens = 800
|
||||
|
||||
[model_task_config.lpmm_rdf_build] # RDF构建模型
|
||||
model_list = ["siliconflow-deepseek-v3"]
|
||||
model_list = ["siliconflow-deepseek-v3.2"]
|
||||
temperature = 0.2
|
||||
max_tokens = 800
|
||||
|
||||
[model_task_config.lpmm_qa] # 问答模型
|
||||
model_list = ["qwen3-30b"]
|
||||
model_list = ["siliconflow-deepseek-v3.2"]
|
||||
temperature = 0.7
|
||||
max_tokens = 800
|
||||
|
||||
391
test_style_learner_db.py
Normal file
391
test_style_learner_db.py
Normal file
@@ -0,0 +1,391 @@
|
||||
"""
|
||||
StyleLearner 数据库测试脚本
|
||||
使用数据库中的expression数据测试style_learner功能
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from typing import List, Dict, Tuple
|
||||
from sklearn.model_selection import train_test_split
|
||||
from sklearn.metrics import precision_recall_fscore_support
|
||||
|
||||
# 添加项目根目录到Python路径
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from src.common.database.database_model import Expression, db
|
||||
from src.express.style_learner import StyleLearnerManager
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("style_learner_test")
|
||||
|
||||
|
||||
class StyleLearnerDatabaseTest:
|
||||
"""使用数据库数据测试StyleLearner"""
|
||||
|
||||
def __init__(self, random_state: int = 42):
|
||||
self.random_state = random_state
|
||||
self.manager = StyleLearnerManager(model_save_path="data/test_style_models")
|
||||
|
||||
# 测试结果
|
||||
self.test_results = {
|
||||
"total_samples": 0,
|
||||
"train_samples": 0,
|
||||
"test_samples": 0,
|
||||
"unique_styles": 0,
|
||||
"unique_chat_ids": 0,
|
||||
"accuracy": 0.0,
|
||||
"precision": 0.0,
|
||||
"recall": 0.0,
|
||||
"f1_score": 0.0,
|
||||
"predictions": [],
|
||||
"ground_truth": [],
|
||||
"model_save_success": False,
|
||||
"model_save_path": self.manager.model_save_path
|
||||
}
|
||||
|
||||
def load_data_from_database(self) -> List[Dict]:
|
||||
"""
|
||||
从数据库加载expression数据
|
||||
|
||||
Returns:
|
||||
List[Dict]: 包含up_content, style, chat_id的数据列表
|
||||
"""
|
||||
try:
|
||||
# 连接数据库
|
||||
db.connect(reuse_if_open=True)
|
||||
|
||||
# 查询所有expression数据
|
||||
expressions = Expression.select().where(
|
||||
(Expression.up_content.is_null(False)) &
|
||||
(Expression.style.is_null(False)) &
|
||||
(Expression.chat_id.is_null(False)) &
|
||||
(Expression.type == "style")
|
||||
)
|
||||
|
||||
data = []
|
||||
for expr in expressions:
|
||||
if expr.up_content and expr.style and expr.chat_id:
|
||||
data.append({
|
||||
"up_content": expr.up_content,
|
||||
"style": expr.style,
|
||||
"chat_id": expr.chat_id,
|
||||
"last_active_time": expr.last_active_time,
|
||||
"context": expr.context,
|
||||
"situation": expr.situation
|
||||
})
|
||||
|
||||
logger.info(f"从数据库加载了 {len(data)} 条expression数据")
|
||||
return data
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"从数据库加载数据失败: {e}")
|
||||
return []
|
||||
|
||||
def preprocess_data(self, data: List[Dict]) -> List[Dict]:
|
||||
"""
|
||||
数据预处理
|
||||
|
||||
Args:
|
||||
data: 原始数据
|
||||
|
||||
Returns:
|
||||
List[Dict]: 预处理后的数据
|
||||
"""
|
||||
# 过滤掉空值或过短的数据
|
||||
filtered_data = []
|
||||
for item in data:
|
||||
up_content = item["up_content"].strip()
|
||||
style = item["style"].strip()
|
||||
|
||||
if len(up_content) >= 2 and len(style) >= 2:
|
||||
filtered_data.append({
|
||||
"up_content": up_content,
|
||||
"style": style,
|
||||
"chat_id": item["chat_id"],
|
||||
"last_active_time": item["last_active_time"],
|
||||
"context": item["context"],
|
||||
"situation": item["situation"]
|
||||
})
|
||||
|
||||
logger.info(f"预处理后剩余 {len(filtered_data)} 条数据")
|
||||
return filtered_data
|
||||
|
||||
def split_data(self, data: List[Dict]) -> Tuple[List[Dict], List[Dict]]:
|
||||
"""
|
||||
分割训练集和测试集
|
||||
训练集使用所有数据,测试集从训练集中随机选择5%
|
||||
|
||||
Args:
|
||||
data: 预处理后的数据
|
||||
|
||||
Returns:
|
||||
Tuple[List[Dict], List[Dict]]: (训练集, 测试集)
|
||||
"""
|
||||
# 训练集使用所有数据
|
||||
train_data = data.copy()
|
||||
|
||||
# 测试集从训练集中随机选择5%
|
||||
test_size = 0.05 # 5%
|
||||
test_data = train_test_split(
|
||||
train_data, test_size=test_size, random_state=self.random_state
|
||||
)[1] # 只取测试集部分
|
||||
|
||||
logger.info(f"数据分割完成: 训练集 {len(train_data)} 条, 测试集 {len(test_data)} 条")
|
||||
logger.info(f"训练集使用所有数据,测试集从训练集中随机选择 {test_size*100:.1f}%")
|
||||
return train_data, test_data
|
||||
|
||||
def train_model(self, train_data: List[Dict]) -> None:
|
||||
"""
|
||||
训练模型
|
||||
|
||||
Args:
|
||||
train_data: 训练数据
|
||||
"""
|
||||
logger.info("开始训练模型...")
|
||||
|
||||
# 统计信息
|
||||
chat_ids = set()
|
||||
styles = set()
|
||||
|
||||
for item in train_data:
|
||||
chat_id = item["chat_id"]
|
||||
up_content = item["up_content"]
|
||||
style = item["style"]
|
||||
|
||||
chat_ids.add(chat_id)
|
||||
styles.add(style)
|
||||
|
||||
# 学习映射关系
|
||||
success = self.manager.learn_mapping(chat_id, up_content, style)
|
||||
if not success:
|
||||
logger.warning(f"学习失败: {chat_id} - {up_content} -> {style}")
|
||||
|
||||
self.test_results["train_samples"] = len(train_data)
|
||||
self.test_results["unique_styles"] = len(styles)
|
||||
self.test_results["unique_chat_ids"] = len(chat_ids)
|
||||
|
||||
logger.info(f"训练完成: {len(train_data)} 个样本, {len(styles)} 种风格, {len(chat_ids)} 个聊天室")
|
||||
|
||||
# 保存训练好的模型
|
||||
logger.info("开始保存训练好的模型...")
|
||||
save_success = self.manager.save_all_models()
|
||||
self.test_results["model_save_success"] = save_success
|
||||
|
||||
if save_success:
|
||||
logger.info(f"所有模型已成功保存到: {self.manager.model_save_path}")
|
||||
print(f"✅ 模型已保存到: {self.manager.model_save_path}")
|
||||
else:
|
||||
logger.warning("部分模型保存失败")
|
||||
print(f"⚠️ 模型保存失败,请检查路径: {self.manager.model_save_path}")
|
||||
|
||||
def test_model(self, test_data: List[Dict]) -> None:
|
||||
"""
|
||||
测试模型
|
||||
|
||||
Args:
|
||||
test_data: 测试数据
|
||||
"""
|
||||
logger.info("开始测试模型...")
|
||||
|
||||
predictions = []
|
||||
ground_truth = []
|
||||
correct_predictions = 0
|
||||
|
||||
for item in test_data:
|
||||
chat_id = item["chat_id"]
|
||||
up_content = item["up_content"]
|
||||
true_style = item["style"]
|
||||
|
||||
# 预测风格
|
||||
predicted_style, scores = self.manager.predict_style(chat_id, up_content, top_k=1)
|
||||
|
||||
predictions.append(predicted_style)
|
||||
ground_truth.append(true_style)
|
||||
|
||||
# 检查预测是否正确
|
||||
if predicted_style == true_style:
|
||||
correct_predictions += 1
|
||||
|
||||
# 记录详细预测结果
|
||||
self.test_results["predictions"].append({
|
||||
"chat_id": chat_id,
|
||||
"up_content": up_content,
|
||||
"true_style": true_style,
|
||||
"predicted_style": predicted_style,
|
||||
"scores": scores
|
||||
})
|
||||
|
||||
# 计算准确率
|
||||
accuracy = correct_predictions / len(test_data) if test_data else 0
|
||||
|
||||
# 计算其他指标(需要处理None值)
|
||||
valid_predictions = [p for p in predictions if p is not None]
|
||||
valid_ground_truth = [gt for p, gt in zip(predictions, ground_truth, strict=False) if p is not None]
|
||||
|
||||
if valid_predictions:
|
||||
precision, recall, f1, _ = precision_recall_fscore_support(
|
||||
valid_ground_truth, valid_predictions, average='weighted', zero_division=0
|
||||
)
|
||||
else:
|
||||
precision = recall = f1 = 0.0
|
||||
|
||||
self.test_results["test_samples"] = len(test_data)
|
||||
self.test_results["accuracy"] = accuracy
|
||||
self.test_results["precision"] = precision
|
||||
self.test_results["recall"] = recall
|
||||
self.test_results["f1_score"] = f1
|
||||
|
||||
logger.info(f"测试完成: 准确率 {accuracy:.4f}, 精确率 {precision:.4f}, 召回率 {recall:.4f}, F1分数 {f1:.4f}")
|
||||
|
||||
def analyze_results(self) -> None:
|
||||
"""分析测试结果"""
|
||||
logger.info("=== 测试结果分析 ===")
|
||||
|
||||
print("\n📊 数据统计:")
|
||||
print(f" 总样本数: {self.test_results['total_samples']}")
|
||||
print(f" 训练样本数: {self.test_results['train_samples']}")
|
||||
print(f" 测试样本数: {self.test_results['test_samples']}")
|
||||
print(f" 唯一风格数: {self.test_results['unique_styles']}")
|
||||
print(f" 唯一聊天室数: {self.test_results['unique_chat_ids']}")
|
||||
|
||||
print("\n🎯 模型性能:")
|
||||
print(f" 准确率: {self.test_results['accuracy']:.4f}")
|
||||
print(f" 精确率: {self.test_results['precision']:.4f}")
|
||||
print(f" 召回率: {self.test_results['recall']:.4f}")
|
||||
print(f" F1分数: {self.test_results['f1_score']:.4f}")
|
||||
|
||||
print("\n💾 模型保存:")
|
||||
save_status = "成功" if self.test_results['model_save_success'] else "失败"
|
||||
print(f" 保存状态: {save_status}")
|
||||
print(f" 保存路径: {self.test_results['model_save_path']}")
|
||||
|
||||
# 分析各聊天室的性能
|
||||
chat_performance = {}
|
||||
for pred in self.test_results["predictions"]:
|
||||
chat_id = pred["chat_id"]
|
||||
if chat_id not in chat_performance:
|
||||
chat_performance[chat_id] = {"correct": 0, "total": 0}
|
||||
|
||||
chat_performance[chat_id]["total"] += 1
|
||||
if pred["predicted_style"] == pred["true_style"]:
|
||||
chat_performance[chat_id]["correct"] += 1
|
||||
|
||||
print("\n📈 各聊天室性能:")
|
||||
for chat_id, perf in chat_performance.items():
|
||||
accuracy = perf["correct"] / perf["total"] if perf["total"] > 0 else 0
|
||||
print(f" {chat_id}: {accuracy:.4f} ({perf['correct']}/{perf['total']})")
|
||||
|
||||
# 分析风格分布
|
||||
style_counts = {}
|
||||
for pred in self.test_results["predictions"]:
|
||||
style = pred["true_style"]
|
||||
style_counts[style] = style_counts.get(style, 0) + 1
|
||||
|
||||
print("\n🎨 风格分布 (前10个):")
|
||||
sorted_styles = sorted(style_counts.items(), key=lambda x: x[1], reverse=True)
|
||||
for style, count in sorted_styles[:10]:
|
||||
print(f" {style}: {count} 次")
|
||||
|
||||
def show_sample_predictions(self, num_samples: int = 10) -> None:
|
||||
"""显示样本预测结果"""
|
||||
print(f"\n🔍 样本预测结果 (前{num_samples}个):")
|
||||
|
||||
for i, pred in enumerate(self.test_results["predictions"][:num_samples]):
|
||||
status = "✓" if pred["predicted_style"] == pred["true_style"] else "✗"
|
||||
print(f"\n {i+1}. {status}")
|
||||
print(f" 聊天室: {pred['chat_id']}")
|
||||
print(f" 输入内容: {pred['up_content']}")
|
||||
print(f" 真实风格: {pred['true_style']}")
|
||||
print(f" 预测风格: {pred['predicted_style']}")
|
||||
if pred["scores"]:
|
||||
top_scores = dict(list(pred["scores"].items())[:3])
|
||||
print(f" 分数: {top_scores}")
|
||||
|
||||
def save_results(self, output_file: str = "style_learner_test_results.txt") -> None:
|
||||
"""保存测试结果到文件"""
|
||||
try:
|
||||
with open(output_file, "w", encoding="utf-8") as f:
|
||||
f.write("StyleLearner 数据库测试结果\n")
|
||||
f.write("=" * 50 + "\n\n")
|
||||
|
||||
f.write("数据统计:\n")
|
||||
f.write(f" 总样本数: {self.test_results['total_samples']}\n")
|
||||
f.write(f" 训练样本数: {self.test_results['train_samples']}\n")
|
||||
f.write(f" 测试样本数: {self.test_results['test_samples']}\n")
|
||||
f.write(f" 唯一风格数: {self.test_results['unique_styles']}\n")
|
||||
f.write(f" 唯一聊天室数: {self.test_results['unique_chat_ids']}\n\n")
|
||||
|
||||
f.write("模型性能:\n")
|
||||
f.write(f" 准确率: {self.test_results['accuracy']:.4f}\n")
|
||||
f.write(f" 精确率: {self.test_results['precision']:.4f}\n")
|
||||
f.write(f" 召回率: {self.test_results['recall']:.4f}\n")
|
||||
f.write(f" F1分数: {self.test_results['f1_score']:.4f}\n\n")
|
||||
|
||||
f.write("模型保存:\n")
|
||||
save_status = "成功" if self.test_results['model_save_success'] else "失败"
|
||||
f.write(f" 保存状态: {save_status}\n")
|
||||
f.write(f" 保存路径: {self.test_results['model_save_path']}\n\n")
|
||||
|
||||
f.write("详细预测结果:\n")
|
||||
for i, pred in enumerate(self.test_results["predictions"]):
|
||||
status = "✓" if pred["predicted_style"] == pred["true_style"] else "✗"
|
||||
f.write(f"{i+1}. {status} [{pred['chat_id']}] {pred['up_content']} -> {pred['predicted_style']} (真实: {pred['true_style']})\n")
|
||||
|
||||
logger.info(f"测试结果已保存到 {output_file}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"保存测试结果失败: {e}")
|
||||
|
||||
def run_test(self) -> None:
|
||||
"""运行完整测试"""
|
||||
logger.info("开始StyleLearner数据库测试...")
|
||||
|
||||
# 1. 加载数据
|
||||
raw_data = self.load_data_from_database()
|
||||
if not raw_data:
|
||||
logger.error("没有加载到数据,测试终止")
|
||||
return
|
||||
|
||||
# 2. 数据预处理
|
||||
processed_data = self.preprocess_data(raw_data)
|
||||
if not processed_data:
|
||||
logger.error("预处理后没有数据,测试终止")
|
||||
return
|
||||
|
||||
self.test_results["total_samples"] = len(processed_data)
|
||||
|
||||
# 3. 分割数据
|
||||
train_data, test_data = self.split_data(processed_data)
|
||||
|
||||
# 4. 训练模型
|
||||
self.train_model(train_data)
|
||||
|
||||
# 5. 测试模型
|
||||
self.test_model(test_data)
|
||||
|
||||
# 6. 分析结果
|
||||
self.analyze_results()
|
||||
|
||||
# 7. 显示样本预测
|
||||
self.show_sample_predictions(10)
|
||||
|
||||
# 8. 保存结果
|
||||
self.save_results()
|
||||
|
||||
logger.info("StyleLearner数据库测试完成!")
|
||||
|
||||
|
||||
def main():
|
||||
"""主函数"""
|
||||
print("StyleLearner 数据库测试脚本")
|
||||
print("=" * 50)
|
||||
|
||||
# 创建测试实例
|
||||
test = StyleLearnerDatabaseTest(random_state=42)
|
||||
|
||||
# 运行测试
|
||||
test.run_test()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
76
view_pkl.py
Normal file
76
view_pkl.py
Normal file
@@ -0,0 +1,76 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
查看 .pkl 文件内容的工具脚本
|
||||
"""
|
||||
|
||||
import pickle
|
||||
import sys
|
||||
import os
|
||||
from pprint import pprint
|
||||
|
||||
def view_pkl_file(file_path):
|
||||
"""查看 pkl 文件内容"""
|
||||
if not os.path.exists(file_path):
|
||||
print(f"❌ 文件不存在: {file_path}")
|
||||
return
|
||||
|
||||
try:
|
||||
with open(file_path, 'rb') as f:
|
||||
data = pickle.load(f)
|
||||
|
||||
print(f"📁 文件: {file_path}")
|
||||
print(f"📊 数据类型: {type(data)}")
|
||||
print("=" * 50)
|
||||
|
||||
if isinstance(data, dict):
|
||||
print("🔑 字典键:")
|
||||
for key in data.keys():
|
||||
print(f" - {key}: {type(data[key])}")
|
||||
print()
|
||||
|
||||
print("📋 详细内容:")
|
||||
pprint(data, width=120, depth=10)
|
||||
|
||||
elif isinstance(data, list):
|
||||
print(f"📝 列表长度: {len(data)}")
|
||||
if data:
|
||||
print(f"📊 第一个元素类型: {type(data[0])}")
|
||||
print("📋 前几个元素:")
|
||||
for i, item in enumerate(data[:3]):
|
||||
print(f" [{i}]: {item}")
|
||||
|
||||
else:
|
||||
print("📋 内容:")
|
||||
pprint(data, width=120, depth=10)
|
||||
|
||||
# 如果是 expressor 模型,特别显示 token_counts 的详细信息
|
||||
if isinstance(data, dict) and 'nb' in data and 'token_counts' in data['nb']:
|
||||
print("\n" + "="*50)
|
||||
print("🔍 详细词汇统计 (token_counts):")
|
||||
token_counts = data['nb']['token_counts']
|
||||
for style_id, tokens in token_counts.items():
|
||||
print(f"\n📝 {style_id}:")
|
||||
if tokens:
|
||||
# 按词频排序显示前10个词
|
||||
sorted_tokens = sorted(tokens.items(), key=lambda x: x[1], reverse=True)
|
||||
for word, count in sorted_tokens[:10]:
|
||||
print(f" '{word}': {count}")
|
||||
if len(sorted_tokens) > 10:
|
||||
print(f" ... 还有 {len(sorted_tokens) - 10} 个词")
|
||||
else:
|
||||
print(" (无词汇数据)")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ 读取文件失败: {e}")
|
||||
|
||||
def main():
|
||||
if len(sys.argv) != 2:
|
||||
print("用法: python view_pkl.py <pkl文件路径>")
|
||||
print("示例: python view_pkl.py data/test_style_models/chat_001_style_model.pkl")
|
||||
return
|
||||
|
||||
file_path = sys.argv[1]
|
||||
view_pkl_file(file_path)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
63
view_tokens.py
Normal file
63
view_tokens.py
Normal file
@@ -0,0 +1,63 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
专门查看 expressor.pkl 文件中 token_counts 的脚本
|
||||
"""
|
||||
|
||||
import pickle
|
||||
import sys
|
||||
import os
|
||||
|
||||
def view_token_counts(file_path):
|
||||
"""查看 expressor.pkl 文件中的词汇统计"""
|
||||
if not os.path.exists(file_path):
|
||||
print(f"❌ 文件不存在: {file_path}")
|
||||
return
|
||||
|
||||
try:
|
||||
with open(file_path, 'rb') as f:
|
||||
data = pickle.load(f)
|
||||
|
||||
print(f"📁 文件: {file_path}")
|
||||
print("=" * 60)
|
||||
|
||||
if 'nb' not in data or 'token_counts' not in data['nb']:
|
||||
print("❌ 这不是一个 expressor 模型文件")
|
||||
return
|
||||
|
||||
token_counts = data['nb']['token_counts']
|
||||
candidates = data.get('candidates', {})
|
||||
|
||||
print(f"🎯 找到 {len(token_counts)} 个风格")
|
||||
print("=" * 60)
|
||||
|
||||
for style_id, tokens in token_counts.items():
|
||||
style_text = candidates.get(style_id, "未知风格")
|
||||
print(f"\n📝 {style_id}: {style_text}")
|
||||
print(f"📊 词汇数量: {len(tokens)}")
|
||||
|
||||
if tokens:
|
||||
# 按词频排序
|
||||
sorted_tokens = sorted(tokens.items(), key=lambda x: x[1], reverse=True)
|
||||
|
||||
print("🔤 词汇统计 (按频率排序):")
|
||||
for i, (word, count) in enumerate(sorted_tokens):
|
||||
print(f" {i+1:2d}. '{word}': {count}")
|
||||
else:
|
||||
print(" (无词汇数据)")
|
||||
|
||||
print("-" * 40)
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ 读取文件失败: {e}")
|
||||
|
||||
def main():
|
||||
if len(sys.argv) != 2:
|
||||
print("用法: python view_tokens.py <expressor.pkl文件路径>")
|
||||
print("示例: python view_tokens.py data/test_style_models/chat_001_expressor.pkl")
|
||||
return
|
||||
|
||||
file_path = sys.argv[1]
|
||||
view_token_counts(file_path)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user