Merge branch 'Mai-with-u:dev' into dev

This commit is contained in:
infinitycat
2025-11-04 21:08:50 +08:00
committed by GitHub
58 changed files with 3808 additions and 1252 deletions

2
.gitignore vendored
View File

@@ -323,7 +323,7 @@ run_pet.bat
!/plugins/emoji_manage_plugin
!/plugins/take_picture_plugin
!/plugins/deep_think
!/plugins/MaiFrequencyControl
!/plugins/ChatFrequency/
!/plugins/__init__.py
config.toml

View File

@@ -44,7 +44,7 @@
## 🔥 更新和安装
**最新版本: v0.10.3** ([更新日志](changelogs/changelog.md))
**最新版本: v0.11.0** ([更新日志](changelogs/changelog.md))
可前往 [Release](https://github.com/MaiM-with-u/MaiBot/releases/) 页面下载最新版本
可前往 [启动器发布页面](https://github.com/MaiM-with-u/mailauncher/releases/)下载最新启动器

View File

@@ -1,21 +1,40 @@
# Changelog
## [0.11.0] - 2025-9-22
## [0.11.1] - 2025-11-04
### 功能更改和修复
- 记忆现在能够被遗忘,并且拥有更好的合并
- 修复部分llm请求问题
- 优化记忆提取
- 提供replyer的细节debug配置
## [0.11.0] - 2025-10-27
### 🌟 主要功能更改
- 重构记忆系统,新的记忆系统更可靠,记忆能力更强大
- 麦麦好奇功能,麦麦会自主提出问题
- 添加deepthink插件默认关闭让麦麦可以深度思考一些问题
- 重构记忆系统,新的记忆系统更可靠,双通道查询,可以查询文本记忆和过去聊天记录
- 主动发言功能,麦麦会自主提出问题(可精细调控频率)
- 支持多重人格设定,可以随机切换成不同状态
- 新增表达方式学习新模式,更少的占用
- 添加表情包管理插件
- 现可更好的支持多平台
- 添加deepthink插件默认关闭让麦麦可以深度思考一些问题
- 现已内置BetterFrequency插件
### 细节功能更改
- 修复配置文件转义问题
- 情绪系统现在可以由配置文件控制开关
- 修复平行动作控制失效的问题
- 添加planner防抖防止短时间快速消耗token
- 优化planner历史状态记录
- 修复吞字问题
- 修复意外换行问题
- 移除VLM的token限制
- 为tool工具添加chat_id字段
- 更新依赖表
- 修复负载均衡
- 优化了对gemini和不同模型的支持
- 现统计模型名而不是模型标识符
- 修改默认推荐模型为ds v3.2
- 优化了对gemini和不同模型的支持,优化了对gemini搜索的支持
## [0.10.3] - 2025-9-22
### 🌟 主要功能更改

View File

@@ -0,0 +1,50 @@
{
"manifest_version": 1,
"name": "发言频率控制插件|BetterFrequency Plugin",
"version": "2.0.0",
"description": "控制聊天频率支持设置focus_value和talk_frequency调整值提供命令",
"author": {
"name": "SengokuCola",
"url": "https://github.com/MaiM-with-u"
},
"license": "GPL-v3.0-or-later",
"host_application": {
"min_version": "0.10.3"
},
"homepage_url": "https://github.com/SengokuCola/BetterFrequency",
"repository_url": "https://github.com/SengokuCola/BetterFrequency",
"keywords": ["frequency", "control", "talk_frequency", "plugin", "shortcut"],
"categories": ["Chat", "Frequency", "Control"],
"default_locale": "zh-CN",
"locales_path": "_locales",
"plugin_info": {
"is_built_in": false,
"plugin_type": "frequency",
"components": [
{
"type": "command",
"name": "set_talk_frequency",
"description": "设置当前聊天的talk_frequency调整值",
"pattern": "/chat talk_frequency <数字> 或 /chat t <数字>"
},
{
"type": "command",
"name": "show_frequency",
"description": "显示当前聊天的频率控制状态",
"pattern": "/chat show 或 /chat s"
}
],
"features": [
"设置talk_frequency调整值",
"调整当前聊天的发言频率",
"显示当前频率控制状态",
"实时频率控制调整",
"命令执行反馈(不保存消息)",
"支持完整命令和简化命令",
"快速操作支持"
]
}
}

View File

@@ -0,0 +1,150 @@
from typing import List, Tuple, Type, Optional
from src.plugin_system import (
BasePlugin,
register_plugin,
BaseCommand,
ComponentInfo,
ConfigField
)
from src.plugin_system.apis import send_api, frequency_api
class SetTalkFrequencyCommand(BaseCommand):
    """设置当前聊天的talk_frequency值"""

    command_name = "set_talk_frequency"
    command_description = "设置当前聊天的talk_frequency值/chat talk_frequency <数字> 或 /chat t <数字>"
    command_pattern = r"^/chat\s+(?:talk_frequency|t)\s+(?P<value>[+-]?\d*\.?\d+)$"

    async def execute(self) -> Tuple[bool, Optional[str], bool]:
        """Apply the requested talk_frequency adjust value to the current chat.

        Returns:
            Tuple[bool, Optional[str], bool]: (success, error message or None,
            whether the message should be intercepted/continue processing —
            semantics defined by BaseCommand).
        """
        try:
            # The command pattern captures the numeric argument in the named
            # group "value"; bail out if matching produced nothing usable.
            if not self.matched_groups or "value" not in self.matched_groups:
                return False, "命令格式错误", False
            value_str = self.matched_groups["value"]
            if not value_str:
                return False, "无法获取数值参数", False
            # May raise ValueError; handled below with user feedback.
            value = float(value_str)
            # Resolve the chat stream this command was issued in.
            if not self.message.chat_stream or not hasattr(self.message.chat_stream, "stream_id"):
                return False, "无法获取聊天流信息", False
            chat_id = self.message.chat_stream.stream_id
            # Apply the adjust value, then read back the effective values.
            frequency_api.set_talk_frequency_adjust(chat_id, value)
            final_value = frequency_api.get_current_talk_value(chat_id)
            adjust_value = frequency_api.get_talk_frequency_adjust(chat_id)
            # Guard against a zero adjust value (the plugin config allows a
            # min_adjust_value of 0.0), which would raise ZeroDivisionError.
            base_value = final_value / adjust_value if adjust_value else final_value
            # Send feedback without persisting it to the message store.
            await send_api.text_to_stream(
                f"已设置当前聊天的talk_frequency调整值为: {value}\n当前talk_value: {final_value:.2f}\n发言频率调整: {adjust_value:.2f}\n基础值: {base_value:.2f}",
                chat_id,
                storage_message=False
            )
            return True, None, False
        except ValueError:
            error_msg = "数值格式错误,请输入有效的数字"
            await self.send_text(error_msg, storage_message=False)
            return False, error_msg, False
        except Exception as e:
            error_msg = f"设置talk_frequency失败: {str(e)}"
            await self.send_text(error_msg, storage_message=False)
            return False, error_msg, False
class ShowFrequencyCommand(BaseCommand):
    """显示当前聊天的频率控制状态"""

    command_name = "show_frequency"
    command_description = "显示当前聊天的频率控制状态:/chat show 或 /chat s"
    command_pattern = r"^/chat\s+(?:show|s)$"

    async def execute(self) -> Tuple[bool, Optional[str], bool]:
        """Report the current chat's talk-frequency values to the stream.

        Returns:
            Tuple[bool, Optional[str], bool]: (success, error message or None,
            continue-processing flag).
        """
        try:
            # Resolve the chat stream this command was issued in.
            if not self.message.chat_stream or not hasattr(self.message.chat_stream, "stream_id"):
                return False, "无法获取聊天流信息", False
            chat_id = self.message.chat_stream.stream_id
            # Read the effective value and the user-set adjust factor.
            current_talk_frequency = frequency_api.get_current_talk_value(chat_id)
            talk_frequency_adjust = frequency_api.get_talk_frequency_adjust(chat_id)
            # Guard against a zero adjust value (config allows 0.0), which
            # would otherwise raise ZeroDivisionError when deriving the base.
            base_value = (
                current_talk_frequency / talk_frequency_adjust
                if talk_frequency_adjust
                else current_talk_frequency
            )
            # Build the human-readable status report.
            status_msg = f"""当前聊天频率控制状态
Talk Value (发言频率):
• 基础值: {base_value:.2f}
• 发言频率调整: {talk_frequency_adjust:.2f}
• 当前值: {current_talk_frequency:.2f}
使用命令:
• /chat talk_frequency <数字> 或 /chat t <数字> - 设置发言频率调整
• /chat show 或 /chat s - 显示当前状态"""
            # Send status without persisting it to the message store.
            await send_api.text_to_stream(status_msg, chat_id, storage_message=False)
            return True, None, False
        except Exception as e:
            error_msg = f"获取频率控制状态失败: {str(e)}"
            # 使用内置的send_text方法发送错误消息
            await self.send_text(error_msg, storage_message=False)
            return False, error_msg, False
# ===== 插件注册 =====
@register_plugin
class BetterFrequencyPlugin(BasePlugin):
    """BetterFrequency插件 - 控制聊天频率的插件"""

    # 插件基本信息
    plugin_name: str = "better_frequency_plugin"
    enable_plugin: bool = True
    dependencies: List[str] = []
    python_dependencies: List[str] = []
    config_file_name: str = "config.toml"

    # 配置节描述
    config_section_descriptions = {
        "plugin": "插件基本信息",
        "frequency": "频率控制配置",
        "features": "功能开关配置"
    }

    # 配置Schema定义
    config_schema: dict = {
        "plugin": {
            "name": ConfigField(type=str, default="better_frequency_plugin", description="插件名称"),
            "version": ConfigField(type=str, default="1.0.2", description="插件版本"),
            "enabled": ConfigField(type=bool, default=True, description="是否启用插件"),
        },
        "frequency": {
            "default_talk_adjust": ConfigField(type=float, default=1.0, description="默认talk_frequency调整值"),
            "max_adjust_value": ConfigField(type=float, default=1.0, description="最大调整值"),
            "min_adjust_value": ConfigField(type=float, default=0.0, description="最小调整值"),
        },
        # "features" was described above and read in get_plugin_components()
        # but missing from the schema; declare it so the generated config file
        # exposes the switch. Default True keeps existing behavior.
        "features": {
            "enable_commands": ConfigField(type=bool, default=True, description="是否启用命令组件"),
        },
    }

    def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
        """Return the command components to register for this plugin."""
        components = []
        # 根据配置决定是否注册命令组件
        if self.config.get("features", {}).get("enable_commands", True):
            components.extend([
                (SetTalkFrequencyCommand.get_command_info(), SetTalkFrequencyCommand),
                (ShowFrequencyCommand.get_command_info(), ShowFrequencyCommand),
            ])
        return components

View File

@@ -177,6 +177,7 @@ def main(): # sourcery skip: comprehension-to-generator, extract-method
elif doc_item:
with open_ie_doc_lock:
open_ie_doc.append(doc_item)
logger.info('已处理"%s"', doc_item.get("passage", ""))
progress.update(task, advance=1)
except KeyboardInterrupt:
logger.info("\n接收到中断信号,正在优雅地关闭程序...")

View File

@@ -16,11 +16,12 @@ from src.chat.brain_chat.brain_planner import BrainPlanner
from src.chat.planner_actions.action_modifier import ActionModifier
from src.chat.planner_actions.action_manager import ActionManager
from src.chat.heart_flow.hfc_utils import CycleDetail
from src.chat.express.expression_learner import expression_learner_manager
from src.express.expression_learner import expression_learner_manager
from src.person_info.person_info import Person
from src.plugin_system.base.component_types import EventType, ActionInfo
from src.plugin_system.core import events_manager
from src.plugin_system.apis import generator_api, send_api, message_api, database_api
from src.memory_system.Memory_chest import global_memory_chest
from src.chat.utils.chat_message_builder import (
build_readable_messages_with_id,
get_raw_msg_before_timestamp_with_chat,
@@ -236,7 +237,8 @@ class BrainChatting:
_reply_text = "" # 初始化reply_text变量避免UnboundLocalError
async with global_prompt_manager.async_message_scope(self.chat_stream.context.get_template_name()):
await self.expression_learner.trigger_learning_for_chat()
asyncio.create_task(self.expression_learner.trigger_learning_for_chat())
asyncio.create_task(global_memory_chest.build_running_content(chat_id=self.stream_id))
cycle_timers, thinking_id = self.start_cycle()
logger.info(f"{self.log_prefix} 开始第{self._cycle_counter}次思考")
@@ -414,7 +416,9 @@ class BrainChatting:
return False, "", ""
# 处理动作并获取结果(固定记录一次动作信息)
result = await action_handler.run()
# BaseAction 定义了异步方法 execute() 作为统一执行入口
# 这里调用 execute() 以兼容所有 Action 实现
result = await action_handler.execute()
success, action_text = result
command = ""

View File

@@ -249,6 +249,8 @@ class BrainPlanner:
# 获取必要信息
is_group_chat, chat_target_info, current_available_actions = self.get_necessary_info()
# 提及/被@ 的处理由心流或统一判定模块驱动Planner 不再做硬编码强制回复
# 应用激活类型过滤
filtered_actions = self._filter_actions_by_activation_type(available_actions, chat_content_block_short)

View File

@@ -379,7 +379,7 @@ class EmojiManager:
self._scan_task = None
self.vlm = LLMRequest(model_set=model_config.model_task_config.vlm, request_type="emoji")
self.vlm = LLMRequest(model_set=model_config.model_task_config.vlm, request_type="emoji.see")
self.llm_emotion_judge = LLMRequest(
model_set=model_config.model_task_config.utils, request_type="emoji"
) # 更高的温度更少的token后续可以根据情绪来调整温度
@@ -940,16 +940,16 @@ class EmojiManager:
image_base64 = get_image_manager().transform_gif(image_base64) # type: ignore
if not image_base64:
raise RuntimeError("GIF表情包转换失败")
prompt = "这是一个动态图表情包,每一张图代表了动态图的某一帧,黑色背景代表透明,描述一下表情包表达的情感和内容,描述细节,从互联网梗,meme的角度去分析"
prompt = "这是一个动态图表情包,每一张图代表了动态图的某一帧,黑色背景代表透明,简短描述一下表情包表达的情感和内容,描述细节,从互联网梗,meme的角度去分析"
description, _ = await self.vlm.generate_response_for_image(
prompt, image_base64, "jpg", temperature=0.3, max_tokens=1000
prompt, image_base64, "jpg", temperature=0.5
)
else:
prompt = (
"这是一个表情包,请详细描述一下表情包所表达的情感和内容,描述细节,从互联网梗,meme的角度去分析"
"这是一个表情包,请详细描述一下表情包所表达的情感和内容,简短描述细节,从互联网梗,meme的角度去分析"
)
description, _ = await self.vlm.generate_response_for_image(
prompt, image_base64, image_format, temperature=0.3, max_tokens=1000
prompt, image_base64, image_format, temperature=0.5
)
# 审核表情包
@@ -970,13 +970,14 @@ class EmojiManager:
# 第二步LLM情感分析 - 基于详细描述生成情感标签列表
emotion_prompt = f"""
请你识别这个表情包的含义和适用场景给我简短的描述每个描述不要超过15个字
这是一个基于这个表情包的描述:'{description}'
你可以关注其幽默和讽刺意味,动用贴吧,微博,小红书的知识,必须从互联网梗,meme的角度去分析
请直接输出描述,不要出现任何其他内容,如果有多个描述,可以用逗号分隔
这是一个聊天场景中的表情包描述:'{description}'
请你识别这个表情包的含义和适用场景给我简短的描述每个描述不要超过15个字
你可以关注其幽默和讽刺意味,动用贴吧,微博,小红书的知识,必须从互联网梗,meme的角度去分析
请直接输出描述,不要出现任何其他内容,如果有多个描述,可以用逗号分隔
"""
emotions_text, _ = await self.llm_emotion_judge.generate_response_async(
emotion_prompt, temperature=0.7, max_tokens=600
emotion_prompt, temperature=0.7, max_tokens=256
)
# 处理情感列表

View File

@@ -1,316 +0,0 @@
import json
import time
import random
import hashlib
from typing import List, Dict, Optional, Any, Tuple
from json_repair import repair_json
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
from src.common.logger import get_logger
from src.common.database.database_model import Expression
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
logger = get_logger("expression_selector")
def init_prompt():
    """Register the expression-evaluation prompt template under its global name."""
    template = """
以下是正在进行的聊天内容:
{chat_observe_info}
你的名字是{bot_name}{target_message}
以下是可选的表达情境:
{all_situations}
请你分析聊天内容的语境、情绪、话题类型,从上述情境中选择最适合当前聊天情境的,最多{max_num}个情境。
考虑因素包括:
1. 聊天的情绪氛围(轻松、严肃、幽默等)
2. 话题类型(日常、技术、游戏、情感等)
3. 情境与当前语境的匹配度
{target_message_extra_block}
请以JSON格式输出只需要输出选中的情境编号
例如:
{{
    "selected_situations": [2, 3, 5, 7, 19]
}}
请严格按照JSON格式输出不要包含其他内容
"""
    Prompt(template, "expression_evaluation_prompt")
def weighted_sample(population: List[Dict], weights: List[float], k: int) -> List[Dict]:
    """Draw up to *k* items from *population* without replacement.

    Each draw picks one remaining item with probability proportional to its
    weight, then removes it (and its weight) from the candidate pool.
    Returns a copy of the whole population when it has at most *k* items,
    and an empty list for empty inputs or non-positive *k*.
    """
    # Guard clauses: nothing to sample, or nothing requested.
    if k <= 0 or not population or not weights:
        return []
    if k >= len(population):
        return population.copy()

    pool = population.copy()
    pool_weights = weights.copy()
    picked: List[Dict] = []
    # Repeated single weighted draws implement sampling without replacement.
    while len(picked) < k and pool:
        idx = random.choices(range(len(pool)), weights=pool_weights)[0]
        picked.append(pool.pop(idx))
        pool_weights.pop(idx)
    return picked
class ExpressionSelector:
    """Selects learned expression entries ("当…时,使用…") suited to a chat.

    Candidates are weighted-randomly sampled from the Expression table
    (pooled across related chats per the expression_groups config), then an
    LLM picks the situations that match the current conversation.
    """

    def __init__(self):
        # Small utility model dedicated to expression selection.
        self.llm_model = LLMRequest(
            model_set=model_config.model_task_config.utils_small, request_type="expression.selector"
        )

    def can_use_expression_for_chat(self, chat_id: str) -> bool:
        """Return whether the given chat stream is allowed to use expressions.

        Args:
            chat_id: Chat stream ID.

        Returns:
            bool: True if allowed; False on denial or on any lookup error.
        """
        try:
            use_expression, _, _ = global_config.expression.get_expression_config_for_chat(chat_id)
            return use_expression
        except Exception as e:
            logger.error(f"检查表达使用权限失败: {e}")
            return False

    @staticmethod
    def _parse_stream_config_to_chat_id(stream_config_str: str) -> Optional[str]:
        """Parse 'platform:id:type' into a chat_id (consistent with get_stream_id).

        Returns None for malformed input instead of raising.
        """
        try:
            parts = stream_config_str.split(":")
            if len(parts) != 3:
                return None
            platform = parts[0]
            id_str = parts[1]
            stream_type = parts[2]
            is_group = stream_type == "group"
            # Group streams hash "platform_id"; private streams append "private".
            if is_group:
                components = [platform, str(id_str)]
            else:
                components = [platform, str(id_str), "private"]
            key = "_".join(components)
            return hashlib.md5(key.encode()).hexdigest()
        except Exception:
            return None

    def get_related_chat_ids(self, chat_id: str) -> List[str]:
        """Return all chat_ids that share expressions with chat_id (itself included),
        derived from the expression_groups configuration."""
        groups = global_config.expression.expression_groups
        # A group containing "*" means global sharing across every configured stream.
        global_group_exists = any("*" in group for group in groups)
        if global_group_exists:
            # Global sharing: collect every chat_id parsable from the config.
            all_chat_ids = set()
            for group in groups:
                for stream_config_str in group:
                    if chat_id_candidate := self._parse_stream_config_to_chat_id(stream_config_str):
                        all_chat_ids.add(chat_id_candidate)
            return list(all_chat_ids) if all_chat_ids else [chat_id]
        # Otherwise use per-group membership: first group containing chat_id wins.
        for group in groups:
            group_chat_ids = []
            for stream_config_str in group:
                if chat_id_candidate := self._parse_stream_config_to_chat_id(stream_config_str):
                    group_chat_ids.append(chat_id_candidate)
            if chat_id in group_chat_ids:
                return group_chat_ids
        return [chat_id]

    def get_random_expressions(self, chat_id: str, total_num: int) -> List[Dict[str, Any]]:
        # sourcery skip: extract-duplicate-method, move-assign
        """Weighted-randomly pick up to total_num "style" expressions for this chat.

        Candidates are pooled across all chat_ids related via expression_groups;
        each expression's `count` serves as its sampling weight.
        """
        # Merge candidate pools from every related chat_id.
        related_chat_ids = self.get_related_chat_ids(chat_id)
        # Single query covering all related chat_ids.
        style_query = Expression.select().where(
            (Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "style")
        )
        style_exprs = [
            {
                "id": expr.id,
                "situation": expr.situation,
                "style": expr.style,
                "count": expr.count,
                "last_active_time": expr.last_active_time,
                "source_id": expr.chat_id,
                "type": "style",
                # Fall back to last_active_time for rows predating create_date.
                "create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
            }
            for expr in style_query
        ]
        # Weighted sampling using each expression's count as its weight.
        if style_exprs:
            style_weights = [expr.get("count", 1) for expr in style_exprs]
            selected_style = weighted_sample(style_exprs, style_weights, total_num)
        else:
            selected_style = []
        return selected_style

    def update_expressions_count_batch(self, expressions_to_update: List[Dict[str, Any]], increment: float = 0.1):
        """Increment `count` for a batch of expressions and persist the rows.

        Entries are deduplicated by (chat_id, type, situation, style) before
        writing, so each matching DB row is updated at most once per call.
        """
        if not expressions_to_update:
            return
        updates_by_key = {}
        for expr in expressions_to_update:
            source_id: str = expr.get("source_id")  # type: ignore
            expr_type: str = expr.get("type", "style")
            situation: str = expr.get("situation")  # type: ignore
            style: str = expr.get("style")  # type: ignore
            if not source_id or not situation or not style:
                logger.warning(f"表达方式缺少必要字段,无法更新: {expr}")
                continue
            key = (source_id, expr_type, situation, style)
            if key not in updates_by_key:
                updates_by_key[key] = expr
        # Iterating the dict yields the (chat_id, type, situation, style) keys.
        for chat_id, expr_type, situation, style in updates_by_key:
            query = Expression.select().where(
                (Expression.chat_id == chat_id)
                & (Expression.type == expr_type)
                & (Expression.situation == situation)
                & (Expression.style == style)
            )
            if query.exists():
                expr_obj = query.get()
                current_count = expr_obj.count
                # Cap count at 5.0 so heavily used expressions don't dominate sampling.
                new_count = min(current_count + increment, 5.0)
                expr_obj.count = new_count
                expr_obj.last_active_time = time.time()
                expr_obj.save()
                logger.debug(
                    f"表达方式激活: 原count={current_count:.3f}, 增量={increment}, 新count={new_count:.3f} in db"
                )

    async def select_suitable_expressions_llm(
        self,
        chat_id: str,
        chat_info: str,
        max_num: int = 10,
        target_message: Optional[str] = None,
    ) -> Tuple[List[Dict[str, Any]], List[int]]:
        # sourcery skip: inline-variable, list-comprehension
        """Use an LLM to choose expressions fitting the current conversation.

        Args:
            chat_id: Chat stream ID.
            chat_info: Readable recent chat content shown to the LLM.
            max_num: Upper bound on the number of situations to select.
            target_message: Optional message being replied to, shown as context.

        Returns:
            Tuple of (selected expression dicts, their DB ids). Both lists are
            empty when expressions are disabled, too few candidates exist, or
            the LLM call/parse fails.
        """
        # Respect per-chat permission before doing any work.
        if not self.can_use_expression_for_chat(chat_id):
            logger.debug(f"聊天流 {chat_id} 不允许使用表达,返回空列表")
            return [], []
        # 1. Weighted-randomly draw 20 candidate expressions.
        style_exprs = self.get_random_expressions(chat_id, 20)
        if len(style_exprs) < 10:
            logger.info(f"聊天流 {chat_id} 表达方式正在积累中")
            return [], []
        # 2. Build the indexed candidate list and its 1-based situation lines.
        all_expressions: List[Dict[str, Any]] = []
        all_situations: List[str] = []
        # Add the style expressions.
        for expr in style_exprs:
            expr = expr.copy()
            all_expressions.append(expr)
            all_situations.append(f"{len(all_expressions)}.当 {expr['situation']} 时,使用 {expr['style']}")
        if not all_expressions:
            logger.warning("没有找到可用的表达方式")
            return [], []
        all_situations_str = "\n".join(all_situations)
        if target_message:
            target_message_str = f",现在你想要回复消息:{target_message}"
            target_message_extra_block = "4.考虑你要回复的目标消息"
        else:
            target_message_str = ""
            target_message_extra_block = ""
        # 3. Build the prompt: only situations are shown, not full expressions.
        prompt = (await global_prompt_manager.get_prompt_async("expression_evaluation_prompt")).format(
            bot_name=global_config.bot.nickname,
            chat_observe_info=chat_info,
            all_situations=all_situations_str,
            max_num=max_num,
            target_message=target_message_str,
            target_message_extra_block=target_message_extra_block,
        )
        # 4. Call the LLM.
        try:
            # start_time = time.time()
            content, (reasoning_content, model_name, _) = await self.llm_model.generate_response_async(prompt=prompt)
            # logger.info(f"LLM请求时间: {model_name} {time.time() - start_time} \n{prompt}")
            # logger.info(f"模型名称: {model_name}")
            # logger.info(f"LLM返回结果: {content}")
            # if reasoning_content:
            #     logger.info(f"LLM推理: {reasoning_content}")
            # else:
            #     logger.info(f"LLM推理: 无")
            if not content:
                logger.warning("LLM返回空结果")
                return [], []
            # 5. Parse the response; repair_json tolerates malformed model output.
            result = repair_json(content)
            if isinstance(result, str):
                result = json.loads(result)
            if not isinstance(result, dict) or "selected_situations" not in result:
                logger.error("LLM返回格式错误")
                logger.info(f"LLM返回结果: \n{content}")
                return [], []
            selected_indices = result["selected_situations"]
            # Map the 1-based indices back to the full expression dicts,
            # silently skipping out-of-range or non-integer entries.
            valid_expressions: List[Dict[str, Any]] = []
            selected_ids = []
            for idx in selected_indices:
                if isinstance(idx, int) and 1 <= idx <= len(all_expressions):
                    expression = all_expressions[idx - 1]  # 索引从1开始
                    selected_ids.append(expression["id"])
                    valid_expressions.append(expression)
            # Batch-increment the count of every selected expression.
            if valid_expressions:
                self.update_expressions_count_batch(valid_expressions, 0.006)
            # logger.info(f"LLM从{len(all_expressions)}个情境中选择了{len(valid_expressions)}个")
            return valid_expressions, selected_ids
        except Exception as e:
            logger.error(f"LLM处理表达方式选择时出错: {e}")
            return [], []
# Register the prompt template at import time.
init_prompt()

# Module-level singleton; other modules import `expression_selector` directly.
try:
    expression_selector = ExpressionSelector()
except Exception as e:
    # NOTE(review): on failure the name `expression_selector` stays undefined,
    # so `from ... import expression_selector` elsewhere raises ImportError —
    # confirm this fail-fast-on-import behavior is intended.
    logger.error(f"ExpressionSelector初始化失败: {e}")

View File

@@ -1,5 +1,37 @@
from datetime import datetime
import time
from typing import Dict
from src.chat.utils.chat_message_builder import (
get_raw_msg_by_timestamp_with_chat,
build_readable_messages,
)
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
from src.common.logger import get_logger
from src.plugin_system.apis import frequency_api
def init_prompt():
Prompt(
"""{name_block}
{time_block}
你现在正在聊天,请根据下面的聊天记录判断是否有用户觉得你的发言过于频繁或者发言过少
{message_str}
如果用户觉得你的发言过于频繁,请输出"过于频繁",否则输出"正常"
如果用户觉得你的发言过少,请输出"过少",否则输出"正常"
**你只能输出以下三个词之一,不要输出任何其他文字、解释或标点:**
- 正常
- 过于频繁
- 过少
""",
"frequency_adjust_prompt",
)
logger = get_logger("frequency_control")
class FrequencyControl:
"""简化的频率控制类仅管理不同chat_id的频率值"""
@@ -8,6 +40,11 @@ class FrequencyControl:
self.chat_id = chat_id
# 发言频率调整值
self.talk_frequency_adjust: float = 1.0
self.last_frequency_adjust_time: float = 0.0
self.frequency_model = LLMRequest(
model_set=model_config.model_task_config.utils_small, request_type="frequency.adjust"
)
def get_talk_frequency_adjust(self) -> float:
"""获取发言频率调整值"""
@@ -16,7 +53,72 @@ class FrequencyControl:
def set_talk_frequency_adjust(self, value: float) -> None:
"""设置发言频率调整值"""
self.talk_frequency_adjust = max(0.1, min(5.0, value))
async def trigger_frequency_adjust(self) -> None:
msg_list = get_raw_msg_by_timestamp_with_chat(
chat_id=self.chat_id,
timestamp_start=self.last_frequency_adjust_time,
timestamp_end=time.time(),
)
if time.time() - self.last_frequency_adjust_time < 120 or len(msg_list) <= 5:
return
else:
new_msg_list = get_raw_msg_by_timestamp_with_chat(
chat_id=self.chat_id,
timestamp_start=self.last_frequency_adjust_time,
timestamp_end=time.time(),
limit=5,
limit_mode="latest",
)
message_str = build_readable_messages(
new_msg_list,
replace_bot_name=True,
timestamp_mode="relative",
read_mark=0.0,
show_actions=False,
)
time_block = f"当前时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
bot_name = global_config.bot.nickname
bot_nickname = (
f",也有人叫你{','.join(global_config.bot.alias_names)}" if global_config.bot.alias_names else ""
)
name_block = f"你的名字是{bot_name}{bot_nickname},请注意哪些是你自己的发言。"
prompt = await global_prompt_manager.format_prompt(
"frequency_adjust_prompt",
name_block=name_block,
time_block=time_block,
message_str=message_str,
)
response, (reasoning_content, _, _) = await self.frequency_model.generate_response_async(
prompt,
)
# logger.info(f"频率调整 prompt: {prompt}")
# logger.info(f"频率调整 response: {response}")
if global_config.debug.show_prompt:
logger.info(f"频率调整 prompt: {prompt}")
logger.info(f"频率调整 response: {response}")
logger.info(f"频率调整 reasoning_content: {reasoning_content}")
final_value_by_api = frequency_api.get_current_talk_value(self.chat_id)
# LLM依然输出过多内容时取消本次调整。合法最多4个字但有的模型可能会输出一些markdown换行符等需要长度宽限
if len(response) < 20:
if "过于频繁" in response:
logger.info(f"频率调整: 过于频繁,调整值到{final_value_by_api}")
self.talk_frequency_adjust = max(0.1, min(3.0, self.talk_frequency_adjust * 0.8))
elif "过少" in response:
logger.info(f"频率调整: 过少,调整值到{final_value_by_api}")
self.talk_frequency_adjust = max(0.1, min(3.0, self.talk_frequency_adjust * 1.2))
self.last_frequency_adjust_time = time.time()
else:
logger.info(f"频率调整response不符合要求取消本次调整")
class FrequencyControlManager:
"""频率控制管理器,管理多个聊天流的频率控制实例"""
@@ -41,6 +143,7 @@ class FrequencyControlManager:
"""获取所有有频率控制的聊天ID"""
return list(self.frequency_control_dict.keys())
init_prompt()
# 创建全局实例
frequency_control_manager = FrequencyControlManager()

View File

@@ -18,10 +18,11 @@ from src.chat.planner_actions.action_modifier import ActionModifier
from src.chat.planner_actions.action_manager import ActionManager
from src.chat.heart_flow.hfc_utils import CycleDetail
from src.chat.heart_flow.hfc_utils import send_typing, stop_typing
from src.chat.express.expression_learner import expression_learner_manager
from src.express.expression_learner import expression_learner_manager
from src.chat.frequency_control.frequency_control import frequency_control_manager
from src.memory_system.question_maker import QuestionMaker
from src.memory_system.questions import global_conflict_tracker
from src.memory_system.curious import check_and_make_question
from src.person_info.person_info import Person
from src.plugin_system.base.component_types import EventType, ActionInfo
from src.plugin_system.core import events_manager
@@ -184,14 +185,12 @@ class HeartFChatting:
)
question_probability = 0
if time.time() - self.last_active_time > 3600:
question_probability = 0.01
elif time.time() - self.last_active_time > 1200:
question_probability = 0.005
elif time.time() - self.last_active_time > 600:
question_probability = 0.001
else:
if time.time() - self.last_active_time > 7200:
question_probability = 0.0003
elif time.time() - self.last_active_time > 3600:
question_probability = 0.0001
else:
question_probability = 0.00003
question_probability = question_probability * global_config.chat.get_auto_chat_value(self.stream_id)
@@ -210,7 +209,7 @@ class HeartFChatting:
if question:
logger.info(f"{self.log_prefix} 问题: {question}")
await global_conflict_tracker.track_conflict(question, conflict_context, True, self.stream_id)
await self._lift_question_reply(question,context,cycle_timers,thinking_id)
await self._lift_question_reply(question,context,thinking_id)
else:
logger.info(f"{self.log_prefix} 无问题")
# self.end_cycle(cycle_timers, thinking_id)
@@ -331,9 +330,12 @@ class HeartFChatting:
async with global_prompt_manager.async_message_scope(self.chat_stream.context.get_template_name()):
await self.expression_learner.trigger_learning_for_chat()
asyncio.create_task(self.expression_learner.trigger_learning_for_chat())
asyncio.create_task(global_memory_chest.build_running_content(chat_id=self.stream_id))
asyncio.create_task(frequency_control_manager.get_or_create_frequency_control(self.stream_id).trigger_frequency_adjust())
await global_memory_chest.build_running_content(chat_id=self.stream_id)
# 添加curious检测任务 - 检测聊天记录中的矛盾、冲突或需要提问的内容
asyncio.create_task(check_and_make_question(self.stream_id, recent_messages_list))
cycle_timers, thinking_id = self.start_cycle()
@@ -551,8 +553,8 @@ class HeartFChatting:
traceback.print_exc()
return False, ""
async def _lift_question_reply(self, question: str, context: str, cycle_timers: Dict[str, float], thinking_id: str):
reason = f"在聊天中:\n{context}\n你对问题\"{question}\"感到好奇,想要和群友讨论"
async def _lift_question_reply(self, question: str, question_context: str, thinking_id: str):
reason = f"在聊天中:\n{question_context}\n你对问题\"{question}\"感到好奇,想要和群友讨论"
new_msg = get_raw_msg_before_timestamp_with_chat(
chat_id=self.stream_id,
timestamp=time.time(),

View File

@@ -1,7 +1,7 @@
import re
import traceback
from typing import Tuple, TYPE_CHECKING
from typing import TYPE_CHECKING
from src.chat.message_receive.message import MessageRecv
from src.chat.message_receive.storage import MessageStorage
@@ -17,31 +17,6 @@ if TYPE_CHECKING:
logger = get_logger("chat")
async def _calculate_interest(message: MessageRecv) -> Tuple[float, list[str]]:
"""计算消息的兴趣度
Args:
message: 待处理的消息对象
Returns:
Tuple[float, bool, list[str]]: (兴趣度, 是否被提及, 关键词)
"""
if message.is_picid or message.is_emoji:
return 0.0, []
is_mentioned, is_at, reply_probability_boost = is_mentioned_bot_in_message(message)
# interested_rate = 0.0
keywords = []
message.interest_value = 1
message.is_mentioned = is_mentioned
message.is_at = is_at
message.reply_probability_boost = reply_probability_boost
return 1, keywords
class HeartFCMessageReceiver:
"""心流处理器,负责处理接收到的消息并计算兴趣度"""
@@ -67,12 +42,16 @@ class HeartFCMessageReceiver:
userinfo = message.message_info.user_info
chat = message.chat_stream
# 2. 兴趣度计算与更新
_, keywords = await _calculate_interest(message)
# 2. 计算at信息
is_mentioned, is_at, reply_probability_boost = is_mentioned_bot_in_message(message)
# print(f"is_mentioned: {is_mentioned}, is_at: {is_at}, reply_probability_boost: {reply_probability_boost}")
message.is_mentioned = is_mentioned
message.is_at = is_at
message.reply_probability_boost = reply_probability_boost
await self.storage.store_message(message, chat)
_heartflow_chat: HeartFChatting = await heartflow.get_or_create_heartflow_chat(chat.stream_id) # type: ignore
await heartflow.get_or_create_heartflow_chat(chat.stream_id) # type: ignore
# 3. 日志记录
mes_name = chat.group_info.group_name if chat.group_info else "私聊"

View File

@@ -149,8 +149,51 @@ class ChatBot:
async def handle_notice_message(self, message: MessageRecv):
if message.message_info.message_id == "notice":
message.is_notify = True
logger.info("notice消息")
print(message)
logger.debug("notice消息")
try:
seg = message.message_segment
mi = message.message_info
sub_type = None
scene = None
msg_id = None
recalled_id = None
if getattr(seg, "type", None) == "notify" and isinstance(getattr(seg, "data", None), dict):
sub_type = seg.data.get("sub_type")
scene = seg.data.get("scene")
msg_id = seg.data.get("message_id")
recalled = seg.data.get("recalled_user_info") or {}
if isinstance(recalled, dict):
recalled_id = recalled.get("user_id")
op = mi.user_info
gid = mi.group_info.group_id if mi.group_info else None
# 撤回事件打印;无法获取被撤回者则省略
if sub_type == "recall":
op_name = getattr(op, "user_cardname", None) or getattr(op, "user_nickname", None) or str(getattr(op, "user_id", None))
recalled_name = None
try:
if isinstance(recalled, dict):
recalled_name = (
recalled.get("user_cardname")
or recalled.get("user_nickname")
or str(recalled.get("user_id"))
)
except Exception:
pass
if recalled_name and str(recalled_id) != str(getattr(op, "user_id", None)):
logger.info(f"{op_name} 撤回了 {recalled_name} 的消息")
else:
logger.info(f"{op_name} 撤回了消息")
else:
logger.debug(
f"[notice] sub_type={sub_type} scene={scene} op={getattr(op,'user_nickname',None)}({getattr(op,'user_id',None)}) "
f"gid={gid} msg_id={msg_id} recalled={recalled_id}"
)
except Exception:
logger.info("[notice] (简略) 收到一条通知事件")
return True
@@ -215,12 +258,13 @@ class ChatBot:
message.message_segment = Seg(type="seglist", data=modified_message.message_segments)
if await self.handle_notice_message(message):
# return
pass
return
# 处理消息内容,生成纯文本
await message.process()
# 平台层的 @ 检测由底层 is_mentioned_bot_in_message 统一处理;此处不做用户名硬编码匹配
# 过滤检查
if _check_ban_words(
message.processed_plain_text,

View File

@@ -130,6 +130,16 @@ class MessageRecv(Message):
self.key_words = []
self.key_words_lite = []
# 兼容适配器通过 additional_config 传入的 @ 标记
try:
msg_info_dict = message_dict.get("message_info", {})
add_cfg = msg_info_dict.get("additional_config") or {}
if isinstance(add_cfg, dict) and add_cfg.get("at_bot"):
# 标记为被提及,提高后续回复优先级
self.is_mentioned = True # type: ignore
except Exception:
pass
def update_chat_stream(self, chat_stream: "ChatStream"):
self.chat_stream = chat_stream

View File

@@ -14,8 +14,6 @@ from src.common.logger import get_logger
from src.common.data_models.info_data_model import ActionPlannerInfo
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.chat_message_builder import (
build_readable_actions,
get_actions_by_timestamp_with_chat,
build_readable_messages_with_id,
get_raw_msg_before_timestamp_with_chat,
)
@@ -69,6 +67,7 @@ no_reply_until_call
动作描述:
保持沉默,直到有人直接叫你的名字
当前话题不感兴趣时使用,或有人不喜欢你的发言时使用
当你频繁选择no_reply时使用表示话题暂时与你无关
{{
"action": "no_reply_until_call",
}}
@@ -418,7 +417,6 @@ class ActionPlanner:
return filtered_actions
async def _build_action_options_block(self, current_available_actions: Dict[str, ActionInfo]) -> str:
# sourcery skip: use-join
"""构建动作选项块"""
if not current_available_actions:
return ""

View File

@@ -26,7 +26,7 @@ from src.chat.utils.chat_message_builder import (
get_raw_msg_before_timestamp_with_chat,
replace_user_references,
)
from src.chat.express.expression_selector import expression_selector
from src.express.expression_selector import expression_selector
from src.plugin_system.apis.message_api import translate_pid_to_description
# from src.memory_system.memory_activator import MemoryActivator
@@ -133,7 +133,13 @@ class DefaultReplyer:
try:
content, reasoning_content, model_name, tool_call = await self.llm_generate_content(prompt)
logger.debug(f"replyer生成内容: {content}")
# logger.debug(f"replyer生成内容: {content}")
logger.info(f"replyer生成内容: {content}")
if global_config.debug.show_replyer_reasoning:
logger.info(f"replyer生成推理:\n{reasoning_content}")
logger.info(f"replyer生成模型: {model_name}")
llm_response.content = content
llm_response.reasoning = reasoning_content
llm_response.model = model_name
@@ -238,8 +244,8 @@ class DefaultReplyer:
return "", []
style_habits = []
# 使用从处理器传来的选中表达方式
# LLM模式调用LLM选择5-10个然后随机选5个
selected_expressions, selected_ids = await expression_selector.select_suitable_expressions_llm(
# 根据配置模式选择表达方式exp_model模式直接使用模型预测classic模式使用LLM选择
selected_expressions, selected_ids = await expression_selector.select_suitable_expressions(
self.chat_stream.stream_id, chat_history, max_num=8, target_message=target
)
@@ -480,6 +486,31 @@ class DefaultReplyer:
) -> Tuple[str, str]:
"""
Args:
message_list_before_now: 历史消息列表
target_user_id: 目标用户ID当前对话对象
Returns:
Tuple[str, str]: (核心对话prompt, 背景对话prompt)
"""
# 构建背景对话 prompt
all_dialogue_prompt = ""
if message_list_before_now:
latest_msgs = message_list_before_now[-int(global_config.chat.max_context_size) :]
all_dialogue_prompt = build_readable_messages(
latest_msgs,
replace_bot_name=True,
timestamp_mode="normal_no_YMD",
truncate=True,
)
return all_dialogue_prompt
def core_background_build_chat_history_prompts(
self, message_list_before_now: List[DatabaseMessages], target_user_id: str, sender: str
) -> Tuple[str, str]:
"""
Args:
message_list_before_now: 历史消息列表
target_user_id: 目标用户ID当前对话对象
@@ -529,7 +560,7 @@ class DefaultReplyer:
show_actions=True,
)
core_dialogue_prompt = f"""--------------------------------
这是你和{sender}的对话,你们正在交流中
这是上述中你和{sender}的对话摘要,内容从上面的对话中截取,便于你理解
{core_dialogue_prompt_str}
--------------------------------
"""
@@ -594,7 +625,18 @@ class DefaultReplyer:
else:
bot_nickname = ""
prompt_personality = f"{global_config.personality.personality};"
# 获取基础personality
prompt_personality = global_config.personality.personality
# 检查是否需要随机替换为状态
if (global_config.personality.states and
global_config.personality.state_probability > 0 and
random.random() < global_config.personality.state_probability):
# 随机选择一个状态替换personality
selected_state = random.choice(global_config.personality.states)
prompt_personality = selected_state
prompt_personality = f"{prompt_personality};"
return f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}"
async def build_prompt_reply_context(
@@ -731,7 +773,7 @@ class DefaultReplyer:
continue
timing_logs.append(f"{chinese_name}: {duration:.1f}s")
if duration > 8:
if duration > 12:
logger.warning(f"回复生成前信息获取耗时过长: {chinese_name} 耗时: {duration:.1f}s请使用更快的模型")
logger.info(f"回复准备: {'; '.join(timing_logs)}; {almost_zero_str} <0.1s")
@@ -776,9 +818,7 @@ class DefaultReplyer:
reply_target_block = ""
# 构建分离的对话 prompt
core_dialogue_prompt, background_dialogue_prompt = self.build_chat_history_prompts(
message_list_before_now_long, user_id, sender
)
dialogue_prompt = self.build_chat_history_prompts(message_list_before_now_long, user_id, sender)
return await global_prompt_manager.format_prompt(
"replyer_prompt",
@@ -793,9 +833,8 @@ class DefaultReplyer:
identity=personality_prompt,
action_descriptions=actions_info,
sender_name=sender,
background_dialogue_prompt=background_dialogue_prompt,
dialogue_prompt=dialogue_prompt,
time_block=time_block,
core_dialogue_prompt=core_dialogue_prompt,
reply_target_block=reply_target_block,
reply_style=global_config.personality.reply_style,
keywords_reaction_prompt=keywords_reaction_prompt,
@@ -960,9 +999,9 @@ class DefaultReplyer:
async def llm_generate_content(self, prompt: str):
with Timer("LLM生成", {}): # 内部计时器,可选保留
# 直接使用已初始化的模型实例
logger.info(f"\n{prompt}\n")
# logger.info(f"\n{prompt}\n")
if global_config.debug.show_prompt:
if global_config.debug.show_replyer_prompt:
logger.info(f"\n{prompt}\n")
else:
logger.debug(f"\nreplyer_Prompt:{prompt}\n")
@@ -971,6 +1010,9 @@ class DefaultReplyer:
prompt
)
# 移除 content 前后的换行符和空格
content = content.strip()
logger.info(f"使用 {model_name} 生成回复内容: {content}")
return content, reasoning_content, model_name, tool_calls

View File

@@ -24,7 +24,7 @@ from src.chat.utils.chat_message_builder import (
get_raw_msg_before_timestamp_with_chat,
replace_user_references,
)
from src.chat.express.expression_selector import expression_selector
from src.express.expression_selector import expression_selector
from src.plugin_system.apis.message_api import translate_pid_to_description
from src.mood.mood_manager import mood_manager
@@ -256,8 +256,8 @@ class PrivateReplyer:
return "", []
style_habits = []
# 使用从处理器传来的选中表达方式
# LLM模式调用LLM选择5-10个然后随机选5个
selected_expressions, selected_ids = await expression_selector.select_suitable_expressions_llm(
# 根据配置模式选择表达方式exp_model模式直接使用模型预测classic模式使用LLM选择
selected_expressions, selected_ids = await expression_selector.select_suitable_expressions(
self.chat_stream.stream_id, chat_history, max_num=8, target_message=target
)
@@ -522,7 +522,18 @@ class PrivateReplyer:
else:
bot_nickname = ""
prompt_personality = f"{global_config.personality.personality};"
# 获取基础personality
prompt_personality = global_config.personality.personality
# 检查是否需要随机替换为状态
if (global_config.personality.states and
global_config.personality.state_probability > 0 and
random.random() < global_config.personality.state_probability):
# 随机选择一个状态替换personality
selected_state = random.choice(global_config.personality.states)
prompt_personality = selected_state
prompt_personality = f"{prompt_personality};"
return f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}"
async def build_prompt_reply_context(
@@ -668,7 +679,7 @@ class PrivateReplyer:
continue
timing_logs.append(f"{chinese_name}: {duration:.1f}s")
if duration > 8:
if duration > 12:
logger.warning(f"回复生成前信息获取耗时过长: {chinese_name} 耗时: {duration:.1f}s请使用更快的模型")
logger.info(f"回复准备: {'; '.join(timing_logs)}; {almost_zero_str} <0.1s")
@@ -911,7 +922,7 @@ class PrivateReplyer:
# 直接使用已初始化的模型实例
logger.info(f"\n{prompt}\n")
if global_config.debug.show_prompt:
if global_config.debug.show_replyer_prompt:
logger.info(f"\n{prompt}\n")
else:
logger.debug(f"\n{prompt}\n")
@@ -919,8 +930,12 @@ class PrivateReplyer:
content, (reasoning_content, model_name, tool_calls) = await self.express_model.generate_response_async(
prompt
)
content = content.strip()
logger.info(f"使用 {model_name} 生成回复内容: {content}")
if global_config.debug.show_replyer_reasoning:
logger.info(f"使用 {model_name} 生成回复推理:\n{reasoning_content}")
return content, reasoning_content, model_name, tool_calls
async def get_prompt_info(self, message: str, sender: str, target: str):

View File

@@ -17,8 +17,7 @@ def init_replyer_prompt():
你正在qq群里聊天下面是群里正在聊的内容:
{time_block}
{background_dialogue_prompt}
{core_dialogue_prompt}
{dialogue_prompt}
{reply_target_block}
{identity}
@@ -26,7 +25,7 @@ def init_replyer_prompt():
尽量简短一些。{keywords_reaction_prompt}请注意把握聊天内容,不要回复的太有条理,可以有个性。
{reply_style}
请注意不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只输出一句回复内容就好。
{moderation_prompt}不要输出多余内容(包括前后缀冒号和引号括号表情包at或 @等 )。
{moderation_prompt}不要输出多余内容(包括前后缀冒号和引号括号表情包at或 @等 )。请不要思考太长
现在,你说:""",
"replyer_prompt",
)

View File

@@ -24,7 +24,7 @@ def init_rewrite_prompt():
你可以完全重组回复,保留最基本的表达含义就好,但重组后保持语意通顺。
{keywords_reaction_prompt}
{moderation_prompt}
不要输出多余内容(包括冒号和引号表情包emoji,at或 @等 ),只输出一条回复就好。
不要输出多余内容(包括冒号和引号表情包emoji,at或 @等 ),只输出一条回复就好。不要思考的太长。
改写后的回复:
""",
"default_expressor_prompt",

View File

@@ -43,9 +43,12 @@ def replace_user_references(
if name_resolver is None:
def default_resolver(platform: str, user_id: str) -> str:
# 检查是否是机器人自己
if replace_bot_name and user_id == global_config.bot.qq_account:
return f"{global_config.bot.nickname}(你)"
# 检查是否是机器人自己(支持多平台)
if replace_bot_name:
if platform == "qq" and user_id == global_config.bot.qq_account:
return f"{global_config.bot.nickname}(你)"
if platform == "telegram" and user_id == getattr(global_config.bot, "telegram_account", ""):
return f"{global_config.bot.nickname}(你)"
person = Person(platform=platform, user_id=user_id)
return person.person_name or user_id # type: ignore
@@ -92,6 +95,8 @@ def replace_user_references(
new_content += content[last_end:]
content = new_content
# Telegram 文本 @username 的显示映射交由适配器或平台层处理;此处不做硬编码替换
return content
@@ -432,7 +437,10 @@ def _build_readable_messages_internal(
person_name = (
person.person_name or f"{user_nickname}" or (f"昵称:{user_cardname}" if user_cardname else "某人")
)
if replace_bot_name and user_id == global_config.bot.qq_account:
if replace_bot_name and (
(platform == global_config.bot.platform and user_id == global_config.bot.qq_account)
or (platform == "telegram" and user_id == getattr(global_config.bot, "telegram_account", ""))
):
person_name = f"{global_config.bot.nickname}(你)"
# 使用独立函数处理用户引用格式
@@ -866,7 +874,9 @@ async def build_anonymous_messages(messages: List[DatabaseMessages]) -> str:
# print(f"get_anon_name: platform:{platform}, user_id:{user_id}")
# print(f"global_config.bot.qq_account:{global_config.bot.qq_account}")
if user_id == global_config.bot.qq_account:
if (platform == "qq" and user_id == global_config.bot.qq_account) or (
platform == "telegram" and user_id == getattr(global_config.bot, "telegram_account", "")
):
# print("SELF11111111111111")
return "SELF"
try:

View File

@@ -334,7 +334,7 @@ class StatisticOutputTask(AsyncTask):
request_type = record.request_type or "unknown"
user_id = record.user_id or "unknown" # user_id is TextField, already string
model_name = record.model_name or "unknown"
model_name = record.model_assign_name or record.model_name or "unknown"
# 提取模块名:如果请求类型包含".",取第一个"."之前的部分
module_name = request_type.split(".")[0] if "." in request_type else request_type
@@ -492,10 +492,15 @@ class StatisticOutputTask(AsyncTask):
continue
# Update name_mapping
if chat_id in self.name_mapping:
if chat_name != self.name_mapping[chat_id][0] and message_time_ts > self.name_mapping[chat_id][1]:
try:
if chat_id in self.name_mapping:
if chat_name != self.name_mapping[chat_id][0] and message_time_ts > self.name_mapping[chat_id][1]:
self.name_mapping[chat_id] = (chat_name, message_time_ts)
else:
self.name_mapping[chat_id] = (chat_name, message_time_ts)
else:
except (IndexError, TypeError) as e:
logger.warning(f"更新 name_mapping 时发生错误chat_id: {chat_id}, 错误: {e}")
# 重置为正确的格式
self.name_mapping[chat_id] = (chat_name, message_time_ts)
for idx, (_, period_start_dt) in enumerate(collect_period):
@@ -514,15 +519,32 @@ class StatisticOutputTask(AsyncTask):
last_all_time_stat = None
if "last_full_statistics" in local_storage:
# 如果存在上次完整统计数据,则使用该数据进行增量统计
last_stat: Dict[str, Any] = local_storage["last_full_statistics"] # 上次完整统计数据 # type: ignore
try:
if "last_full_statistics" in local_storage:
# 如果存在上次完整统计数据,则使用该数据进行增量统计
last_stat: Dict[str, Any] = local_storage["last_full_statistics"] # 上次完整统计数据 # type: ignore
self.name_mapping = last_stat["name_mapping"] # 上次完整统计数据的名称映射
last_all_time_stat = last_stat["stat_data"] # 上次完整统计的统计数据
last_stat_timestamp = datetime.fromtimestamp(last_stat["timestamp"]) # 上次完整统计数据的时间戳
self.stat_period = [item for item in self.stat_period if item[0] != "all_time"] # 删除"所有时间"的统计时段
self.stat_period.append(("all_time", now - last_stat_timestamp, "自部署以来的"))
# 修复 name_mapping 数据类型不匹配问题
# JSON 中存储为列表,但代码期望为元组
raw_name_mapping = last_stat["name_mapping"]
self.name_mapping = {}
for chat_id, value in raw_name_mapping.items():
if isinstance(value, list) and len(value) == 2:
# 将列表转换为元组
self.name_mapping[chat_id] = (value[0], value[1])
elif isinstance(value, tuple) and len(value) == 2:
# 已经是元组,直接使用
self.name_mapping[chat_id] = value
else:
# 数据格式不正确,跳过或使用默认值
logger.warning(f"name_mapping 中 chat_id {chat_id} 的数据格式不正确: {value}")
continue
last_all_time_stat = last_stat["stat_data"] # 上次完整统计的统计数据
last_stat_timestamp = datetime.fromtimestamp(last_stat["timestamp"]) # 上次完整统计数据的时间戳
self.stat_period = [item for item in self.stat_period if item[0] != "all_time"] # 删除"所有时间"的统计时段
self.stat_period.append(("all_time", now - last_stat_timestamp, "自部署以来的"))
except Exception as e:
logger.warning(f"加载上次完整统计数据失败,进行全量统计,错误信息:{e}")
stat_start_timestamp = [(period[0], now - period[1]) for period in self.stat_period]
@@ -571,8 +593,14 @@ class StatisticOutputTask(AsyncTask):
# 更新上次完整统计数据的时间戳
# 将所有defaultdict转换为普通dict以避免类型冲突
clean_stat_data = self._convert_defaultdict_to_dict(stat["all_time"])
# 将 name_mapping 中的元组转换为列表因为JSON不支持元组
json_safe_name_mapping = {}
for chat_id, (chat_name, timestamp) in self.name_mapping.items():
json_safe_name_mapping[chat_id] = [chat_name, timestamp]
local_storage["last_full_statistics"] = {
"name_mapping": self.name_mapping,
"name_mapping": json_safe_name_mapping,
"stat_data": clean_stat_data,
"timestamp": now.timestamp(),
}
@@ -651,10 +679,13 @@ class StatisticOutputTask(AsyncTask):
if stats[TOTAL_MSG_CNT] <= 0:
return ""
output = ["聊天消息统计:", " 联系人/群组名称 消息数量"]
output.extend(
f"{self.name_mapping[chat_id][0][:32]:<32} {count:>10}"
for chat_id, count in sorted(stats[MSG_CNT_BY_CHAT].items())
)
for chat_id, count in sorted(stats[MSG_CNT_BY_CHAT].items()):
try:
chat_name = self.name_mapping.get(chat_id, ("未知聊天", 0))[0]
output.append(f"{chat_name[:32]:<32} {count:>10}")
except (IndexError, TypeError) as e:
logger.warning(f"格式化聊天统计时发生错误chat_id: {chat_id}, 错误: {e}")
output.append(f"{'未知聊天':<32} {count:>10}")
output.append("")
return "\n".join(output)
@@ -770,14 +801,16 @@ class StatisticOutputTask(AsyncTask):
)
# 聊天消息统计
chat_rows = "\n".join(
[
f"<tr><td>{self.name_mapping[chat_id][0]}</td><td>{count}</td></tr>"
for chat_id, count in sorted(stat_data[MSG_CNT_BY_CHAT].items())
]
if stat_data[MSG_CNT_BY_CHAT]
else ["<tr><td colspan='2' style='text-align: center; color: #999;'>暂无数据</td></tr>"]
)
chat_rows = []
for chat_id, count in sorted(stat_data[MSG_CNT_BY_CHAT].items()):
try:
chat_name = self.name_mapping.get(chat_id, ("未知聊天", 0))[0]
chat_rows.append(f"<tr><td>{chat_name}</td><td>{count}</td></tr>")
except (IndexError, TypeError) as e:
logger.warning(f"生成HTML聊天统计时发生错误chat_id: {chat_id}, 错误: {e}")
chat_rows.append(f"<tr><td>未知聊天</td><td>{count}</td></tr>")
chat_rows_html = "\n".join(chat_rows) if chat_rows else "<tr><td colspan='2' style='text-align: center; color: #999;'>暂无数据</td></tr>"
# 生成HTML
return f"""
<div id=\"{div_id}\" class=\"tab-content\">
@@ -824,7 +857,7 @@ class StatisticOutputTask(AsyncTask):
<tr><th>联系人/群组名称</th><th>消息数量</th></tr>
</thead>
<tbody>
{chat_rows}
{chat_rows_html}
</tbody>
</table>
@@ -975,7 +1008,7 @@ class StatisticOutputTask(AsyncTask):
}}
// 聊天消息分布饼图
const chatLabels = {[self.name_mapping[chat_id][0] for chat_id in sorted(stat_data[MSG_CNT_BY_CHAT].keys())] if stat_data[MSG_CNT_BY_CHAT] else []};
const chatLabels = {[self.name_mapping.get(chat_id, ("未知聊天", 0))[0] for chat_id in sorted(stat_data[MSG_CNT_BY_CHAT].keys())] if stat_data[MSG_CNT_BY_CHAT] else []};
if (chatLabels.length > 0) {{
const chatData = {{
labels: chatLabels,
@@ -1233,7 +1266,7 @@ class StatisticOutputTask(AsyncTask):
total_cost_data[interval_index] += cost # type: ignore
# 累加按模型分类的花费
model_name = record.model_name or "unknown"
model_name = record.model_assign_name or record.model_name or "unknown"
if model_name not in cost_by_model:
cost_by_model[model_name] = [0] * len(time_points)
cost_by_model[model_name][interval_index] += cost

View File

@@ -30,76 +30,146 @@ def is_english_letter(char: str) -> bool:
return "a" <= char.lower() <= "z"
def db_message_to_str(message_dict: dict) -> str:
logger.debug(f"message_dict: {message_dict}")
time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(message_dict["time"]))
try:
name = f"[({message_dict['user_id']}){message_dict.get('user_nickname', '')}]{message_dict.get('user_cardname', '')}"
except Exception:
name = message_dict.get("user_nickname", "") or f"用户{message_dict['user_id']}"
content = message_dict.get("processed_plain_text", "")
result = f"[{time_str}] {name}: {content}\n"
logger.debug(f"result: {result}")
def parse_platform_accounts(platforms: list[str]) -> dict[str, str]:
    """Parse the ``platforms`` list into a platform -> account mapping.

    Args:
        platforms: entries of the form ``"platform:account"``,
            e.g. ``["tg:123456789", "wx:wxid123"]``.

    Returns:
        Mapping from platform name to account; entries without a
        ``:`` separator are silently ignored.
    """
    mapping: dict[str, str] = {}
    for entry in platforms:
        name, sep, account = entry.partition(":")
        if sep:
            mapping[name.strip()] = account.strip()
    return mapping
def get_current_platform_account(platform: str, platform_accounts: dict[str, str], qq_account: str) -> str:
    """Return the bot account bound to the given platform.

    Args:
        platform: platform of the current message.
        platform_accounts: mapping parsed from the ``platforms`` config list.
        qq_account: QQ account (kept for backward-compatible configs).

    Returns:
        The account for ``platform``, or "" when none is configured.
    """
    if platform == "qq":
        return qq_account
    if platform == "telegram":
        # Prefer the short "tg" key, then fall back to the full "telegram" key.
        return platform_accounts.get("tg", "") or platform_accounts.get("telegram", "")
    # Every other platform is looked up directly by its name.
    return platform_accounts.get(platform, "")
def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, bool, float]:
"""检查消息是否提到了机器人"""
keywords = [global_config.bot.nickname] + list(global_config.bot.alias_names)
"""检查消息是否提到了机器人(统一多平台实现)"""
text = message.processed_plain_text or ""
platform = getattr(message.message_info, "platform", "") or ""
# 获取各平台账号
platforms_list = getattr(global_config.bot, "platforms", []) or []
platform_accounts = parse_platform_accounts(platforms_list)
qq_account = str(getattr(global_config.bot, "qq_account", "") or "")
# 获取当前平台对应的账号
current_account = get_current_platform_account(platform, platform_accounts, qq_account)
nickname = str(global_config.bot.nickname or "")
alias_names = list(getattr(global_config.bot, "alias_names", []) or [])
keywords = [nickname] + alias_names
reply_probability = 0.0
is_at = False
is_mentioned = False
# 这部分怎么处理啊啊啊啊
# 我觉得可以给消息加一个 reply_probability_boost字段
if (
message.message_info.additional_config is not None
and message.message_info.additional_config.get("is_mentioned") is not None
):
# 1) 直接的 additional_config 标记
add_cfg = getattr(message.message_info, "additional_config", None) or {}
if isinstance(add_cfg, dict):
if add_cfg.get("at_bot") or add_cfg.get("is_mentioned"):
is_mentioned = True
# 当提供数值型 is_mentioned 时,当作概率提升
try:
if add_cfg.get("is_mentioned") not in (None, ""):
reply_probability = float(add_cfg.get("is_mentioned")) # type: ignore
except Exception:
pass
# 2) 已经在上游设置过的 message.is_mentioned
if getattr(message, "is_mentioned", False):
is_mentioned = True
# 3) 扫描分段:是否包含 mention_bot适配器插入
def _has_mention_bot(seg) -> bool:
try:
reply_probability = float(message.message_info.additional_config.get("is_mentioned")) # type: ignore
is_mentioned = True
return is_mentioned, is_at, reply_probability
except Exception as e:
logger.warning(str(e))
logger.warning(
f"消息中包含不合理的设置 is_mentioned: {message.message_info.additional_config.get('is_mentioned')}"
)
if seg is None:
return False
if getattr(seg, "type", None) == "mention_bot":
return True
if getattr(seg, "type", None) == "seglist":
for s in getattr(seg, "data", []) or []:
if _has_mention_bot(s):
return True
return False
except Exception:
return False
for keyword in keywords:
if keyword in message.processed_plain_text:
is_mentioned = True
# 判断是否被@
if re.search(rf"@<(.+?):{global_config.bot.qq_account}>", message.processed_plain_text):
if _has_mention_bot(getattr(message, "message_segment", None)):
is_at = True
is_mentioned = True
if is_at and global_config.chat.at_bot_inevitable_reply:
# 4) 统一的 @ 检测逻辑
if current_account and not is_at and not is_mentioned:
if platform == "qq":
# QQ 格式: @<name:qq_id>
if re.search(rf"@<(.+?):{re.escape(current_account)}>", text):
is_at = True
is_mentioned = True
else:
# 其他平台格式: @username 或 @account
if re.search(rf"@{re.escape(current_account)}(\b|$)", text, flags=re.IGNORECASE):
is_at = True
is_mentioned = True
# 5) 统一的回复检测逻辑
if not is_mentioned:
# 通用回复格式:包含 "(你)" 或 "(你)"
if re.search(r"\[回复 .*?\(你\)", text) or re.search(r"\[回复 .*?(你):", text):
is_mentioned = True
# ID 形式的回复检测
elif current_account:
if re.search(rf"\[回复 (.+?)\({re.escape(current_account)}\)(.+?)\],说:", text):
is_mentioned = True
elif re.search(rf"\[回复<(.+?)(?=:{re.escape(current_account)}>)\:{re.escape(current_account)}>(.+?)\],说:", text):
is_mentioned = True
# 6) 名称/别名 提及(去除 @/回复标记后再匹配)
if not is_mentioned and keywords:
msg_content = text
# 去除各种 @ 与 回复标记,避免误判
msg_content = re.sub(r"@(.+?)(\d+)", "", msg_content)
msg_content = re.sub(r"@<(.+?)(?=:(\d+))\:(\d+)>", "", msg_content)
msg_content = re.sub(r"\[回复 (.+?)\(((\d+)|未知id|你)\)(.+?)\],说:", "", msg_content)
msg_content = re.sub(r"\[回复<(.+?)(?=:(\d+))\:(\d+)>(.+?)\],说:", "", msg_content)
for kw in keywords:
if kw and kw in msg_content:
is_mentioned = True
break
# 7) 概率设置
if is_at and getattr(global_config.chat, "at_bot_inevitable_reply", 1):
reply_probability = 1.0
logger.debug("被@回复概率设置为100%")
else:
if not is_mentioned:
# 判断是否被回复
if re.match(
rf"\[回复 (.+?)\({str(global_config.bot.qq_account)}\)(.+?)\],说:", message.processed_plain_text
) or re.match(
rf"\[回复<(.+?)(?=:{str(global_config.bot.qq_account)}>)\:{str(global_config.bot.qq_account)}>(.+?)\],说:",
message.processed_plain_text,
):
is_mentioned = True
else:
# 判断内容中是否被提及
message_content = re.sub(r"@(.+?)(\d+)", "", message.processed_plain_text)
message_content = re.sub(r"@<(.+?)(?=:(\d+))\:(\d+)>", "", message_content)
message_content = re.sub(r"\[回复 (.+?)\(((\d+)|未知id)\)(.+?)\],说:", "", message_content)
message_content = re.sub(r"\[回复<(.+?)(?=:(\d+))\:(\d+)>(.+?)\],说:", "", message_content)
for keyword in keywords:
if keyword in message_content:
is_mentioned = True
if is_mentioned and global_config.chat.mentioned_bot_reply:
reply_probability = 1.0
logger.debug("被提及回复概率设置为100%")
elif is_mentioned and getattr(global_config.chat, "mentioned_bot_reply", 1):
reply_probability = max(reply_probability, 1.0)
logger.debug("被提及回复概率设置为100%")
return is_mentioned, is_at, reply_probability
@@ -115,45 +185,6 @@ async def get_embedding(text, request_type="embedding") -> Optional[List[float]]
return embedding
def get_recent_group_speaker(chat_stream_id: str, sender, limit: int = 12) -> list:
# 获取当前群聊记录内发言的人
filter_query = {"chat_id": chat_stream_id}
sort_order = [("time", -1)]
recent_messages = find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
if not recent_messages:
return []
who_chat_in_group = []
for db_msg in recent_messages:
# user_info = UserInfo.from_dict(
# {
# "platform": msg_db_data["user_platform"],
# "user_id": msg_db_data["user_id"],
# "user_nickname": msg_db_data["user_nickname"],
# "user_cardname": msg_db_data.get("user_cardname", ""),
# }
# )
# if (
# (user_info.platform, user_info.user_id) != sender
# and user_info.user_id != global_config.bot.qq_account
# and (user_info.platform, user_info.user_id, user_info.user_nickname) not in who_chat_in_group
# and len(who_chat_in_group) < 5
# ): # 排除重复排除消息发送者排除bot限制加载的关系数目
# who_chat_in_group.append((user_info.platform, user_info.user_id, user_info.user_nickname))
if (
(db_msg.user_info.platform, db_msg.user_info.user_id) != sender
and db_msg.user_info.user_id != global_config.bot.qq_account
and (db_msg.user_info.platform, db_msg.user_info.user_id, db_msg.user_info.user_nickname)
not in who_chat_in_group
and len(who_chat_in_group) < 5
): # 排除重复排除消息发送者排除bot限制加载的关系数目
who_chat_in_group.append(
(db_msg.user_info.platform, db_msg.user_info.user_id, db_msg.user_info.user_nickname)
)
return who_chat_in_group
def split_into_sentences_w_remove_punctuation(text: str) -> list[str]:
"""将文本分割成句子,并根据概率合并
@@ -410,42 +441,6 @@ def calculate_typing_time(
return total_time # 加上回车时间
def cosine_similarity(v1, v2):
    """Return the cosine similarity of two vectors (0 when either has zero norm)."""
    n1 = np.linalg.norm(v1)
    n2 = np.linalg.norm(v2)
    if n1 == 0 or n2 == 0:
        return 0
    return np.dot(v1, v2) / (n1 * n2)
def text_to_vector(text):
    """Tokenize *text* with jieba and return a word-frequency Counter (bag of words)."""
    return Counter(jieba.lcut(text))
def find_similar_topics_simple(text: str, topics: list, top_k: int = 5) -> list:
    """Rank *topics* by bag-of-words cosine similarity against *text*.

    Args:
        text: query text.
        topics: candidate topic strings.
        top_k: maximum number of results to return.

    Returns:
        Up to ``top_k`` ``(topic, similarity)`` pairs, most similar first.
    """
    query_vec = text_to_vector(text)

    def _score(topic: str) -> float:
        # Build aligned frequency vectors over the combined vocabulary.
        topic_vec = text_to_vector(topic)
        vocab = set(query_vec) | set(topic_vec)
        a = [query_vec.get(word, 0) for word in vocab]
        b = [topic_vec.get(word, 0) for word in vocab]
        return cosine_similarity(a, b)

    scored = [(topic, _score(topic)) for topic in topics]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:top_k]
def truncate_message(message: str, max_length=20) -> str:
"""截断消息,使其不超过指定长度"""
@@ -523,47 +518,6 @@ def get_western_ratio(paragraph):
return western_count / len(alnum_chars)
def count_messages_between(start_time: float, end_time: float, stream_id: str) -> tuple[int, int]:
    """Count messages and total text length in the interval (start_time, end_time].

    Args:
        start_time: exclusive lower timestamp bound.
        end_time: inclusive upper timestamp bound.
        stream_id: chat stream id; must be non-empty.

    Returns:
        ``(message_count, total_text_length)``; ``(0, 0)`` on invalid
        arguments or any lookup failure.
    """
    # Guard clauses: empty interval or missing stream id yield the zero result.
    if start_time >= end_time:
        return 0, 0
    if not stream_id:
        logger.error("stream_id 不能为空")
        return 0, 0

    query = {"chat_id": stream_id, "time": {"$gt": start_time, "$lte": end_time}}
    try:
        msg_count = count_messages(query)
        msgs = find_messages(message_filter=query)
        text_length = sum(len(m.processed_plain_text or "") for m in msgs)
        return msg_count, text_length
    except Exception as e:
        logger.error(f"计算消息数量时发生意外错误: {e}")
        return 0, 0
def translate_timestamp_to_human_readable(timestamp: float, mode: str = "normal") -> str:
# sourcery skip: merge-comparisons, merge-duplicate-blocks, switch
@@ -698,65 +652,6 @@ def assign_message_ids(messages: List[DatabaseMessages]) -> List[Tuple[str, Data
return result
# def assign_message_ids_flexible(
# messages: list, prefix: str = "msg", id_length: int = 6, use_timestamp: bool = False
# ) -> list:
# """
# 为消息列表中的每个消息分配唯一的简短随机ID增强版
# Args:
# messages: 消息列表
# prefix: ID前缀默认为"msg"
# id_length: ID的总长度不包括前缀默认为6
# use_timestamp: 是否在ID中包含时间戳默认为False
# Returns:
# 包含 {'id': str, 'message': any} 格式的字典列表
# """
# result = []
# used_ids = set()
# for i, message in enumerate(messages):
# # 生成唯一的ID
# while True:
# if use_timestamp:
# # 使用时间戳的后几位 + 随机字符
# timestamp_suffix = str(int(time.time() * 1000))[-3:]
# remaining_length = id_length - 3
# random_chars = "".join(random.choices(string.ascii_lowercase + string.digits, k=remaining_length))
# message_id = f"{prefix}{timestamp_suffix}{random_chars}"
# else:
# # 使用索引 + 随机字符
# index_str = str(i + 1)
# remaining_length = max(1, id_length - len(index_str))
# random_chars = "".join(random.choices(string.ascii_lowercase + string.digits, k=remaining_length))
# message_id = f"{prefix}{index_str}{random_chars}"
# if message_id not in used_ids:
# used_ids.add(message_id)
# break
# result.append({"id": message_id, "message": message})
# return result
# 使用示例:
# messages = ["Hello", "World", "Test message"]
#
# # 基础版本
# result1 = assign_message_ids(messages)
# # 结果: [{'id': 'm1123', 'message': 'Hello'}, {'id': 'm2456', 'message': 'World'}, {'id': 'm3789', 'message': 'Test message'}]
#
# # 增强版本 - 自定义前缀和长度
# result2 = assign_message_ids_flexible(messages, prefix="chat", id_length=8)
# # 结果: [{'id': 'chat1abc2', 'message': 'Hello'}, {'id': 'chat2def3', 'message': 'World'}, {'id': 'chat3ghi4', 'message': 'Test message'}]
#
# # 增强版本 - 使用时间戳
# result3 = assign_message_ids_flexible(messages, prefix="ts", use_timestamp=True)
# # 结果: [{'id': 'ts123a1b', 'message': 'Hello'}, {'id': 'ts123c2d', 'message': 'World'}, {'id': 'ts123e3f', 'message': 'Test message'}]
def parse_keywords_string(keywords_input) -> list[str]:
# sourcery skip: use-contextlib-suppress
"""

View File

@@ -44,6 +44,11 @@ class ImageManager:
except Exception as e:
logger.error(f"数据库连接或表创建失败: {e}")
try:
self._cleanup_invalid_descriptions()
except Exception as e:
logger.warning(f"数据库清理失败: {e}")
self._initialized = True
def _ensure_image_dir(self):
@@ -92,6 +97,26 @@ class ImageManager:
except Exception as e:
logger.error(f"保存描述到数据库失败 (Peewee): {str(e)}")
@staticmethod
def _cleanup_invalid_descriptions():
"""清理数据库中 description 为空或为 'None' 的记录"""
invalid_values = ["", "None"]
# 清理 Images 表
deleted_images = Images.delete().where(
(Images.description >> None) | (Images.description << invalid_values)
).execute()
# 清理 ImageDescriptions 表
deleted_descriptions = ImageDescriptions.delete().where(
(ImageDescriptions.description >> None) | (ImageDescriptions.description << invalid_values)
).execute()
if deleted_images or deleted_descriptions:
logger.info(f"[清理完成] 删除 Images: {deleted_images} 条, ImageDescriptions: {deleted_descriptions}")
else:
logger.info("[清理完成] 未发现无效描述记录")
async def get_emoji_tag(self, image_base64: str) -> str:
from src.chat.emoji_system.emoji_manager import get_emoji_manager
@@ -273,7 +298,7 @@ class ImageManager:
prompt = global_config.personality.visual_style
logger.info(f"[VLM调用] 为图片生成新描述 (Hash: {image_hash[:8]}...)")
description, _ = await self.vlm.generate_response_for_image(
prompt, image_base64, image_format, temperature=0.4, max_tokens=300
prompt, image_base64, image_format, temperature=0.4
)
if description is None:
@@ -570,7 +595,7 @@ class ImageManager:
# 获取VLM描述
description, _ = await self.vlm.generate_response_for_image(
prompt, image_base64, image_format, temperature=0.4, max_tokens=300
prompt, image_base64, image_format, temperature=0.4
)
if description is None:

View File

@@ -303,15 +303,13 @@ class Expression(BaseModel):
situation = TextField()
style = TextField()
count = FloatField()
# new mode fields
context = TextField(null=True)
context_words = TextField(null=True)
up_content = TextField(null=True)
last_active_time = FloatField()
chat_id = TextField(index=True)
type = TextField()
create_date = FloatField(null=True) # 创建日期,允许为空以兼容老数据
class Meta:

View File

@@ -55,7 +55,7 @@ TEMPLATE_DIR = os.path.join(PROJECT_ROOT, "template")
# 考虑到实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码
# 对该字段的更新请严格参照语义化版本规范https://semver.org/lang/zh-CN/
MMC_VERSION = "0.11.0-snapshot.3"
MMC_VERSION = "0.11.1-snapshot.1"
def get_key_comment(toml_table, key):

View File

@@ -27,6 +27,9 @@ class BotConfig(ConfigBase):
nickname: str
"""昵称"""
platforms: list[str] = field(default_factory=lambda: [])
"""其他平台列表"""
alias_names: list[str] = field(default_factory=lambda: [])
"""别名列表"""
@@ -54,6 +57,12 @@ class PersonalityConfig(ConfigBase):
private_plan_style: str = ""
"""私聊说话规则,行为风格"""
states: list[str] = field(default_factory=lambda: [])
"""状态列表用于随机替换personality"""
state_probability: float = 0.0
"""状态概率每次构建人格时替换personality的概率"""
@dataclass
class RelationshipConfig(ConfigBase):
@@ -82,6 +91,9 @@ class ChatConfig(ConfigBase):
auto_chat_value: float = 1
"""自动聊天,越小,麦麦主动聊天的概率越低"""
enable_auto_chat_value_rules: bool = True
"""是否启用动态自动聊天频率规则"""
at_bot_inevitable_reply: float = 1
"""@bot 必然回复1为100%回复0为不额外增幅"""
@@ -91,6 +103,9 @@ class ChatConfig(ConfigBase):
talk_value: float = 1
"""思考频率"""
enable_talk_value_rules: bool = True
"""是否启用动态发言频率规则"""
talk_value_rules: list[dict] = field(default_factory=lambda: [])
"""
思考频率规则列表,支持按聊天流/按日内时段配置。
@@ -177,7 +192,7 @@ class ChatConfig(ConfigBase):
def get_talk_value(self, chat_id: Optional[str]) -> float:
"""根据规则返回当前 chat 的动态 talk_value未匹配则回退到基础值。"""
if not self.talk_value_rules:
if not self.enable_talk_value_rules or not self.talk_value_rules:
return self.talk_value
now_min = self._now_minutes()
@@ -232,7 +247,7 @@ class ChatConfig(ConfigBase):
def get_auto_chat_value(self, chat_id: Optional[str]) -> float:
"""根据规则返回当前 chat 的动态 auto_chat_value未匹配则回退到基础值。"""
if not self.auto_chat_value_rules:
if not self.enable_auto_chat_value_rules or not self.auto_chat_value_rules:
return self.auto_chat_value
now_min = self._now_minutes()
@@ -310,8 +325,8 @@ class MemoryConfig(ConfigBase):
class ExpressionConfig(ConfigBase):
"""表达配置类"""
mode: Literal["llm", "context", "full-context"] = "context"
"""表达方式模式,可选:llm模式context上下文模式full-context 完整上下文嵌入模式"""
mode: str = "classic"
"""表达方式模式,可选:classic经典模式exp_model 表达模型模式"""
learning_list: list[list] = field(default_factory=lambda: [])
"""
@@ -626,6 +641,12 @@ class DebugConfig(ConfigBase):
show_prompt: bool = False
"""是否显示prompt"""
show_replyer_prompt: bool = True
"""是否显示回复器prompt"""
show_replyer_reasoning: bool = True
"""是否显示回复器推理"""
@dataclass

View File

@@ -0,0 +1,93 @@
import re
import difflib
import random
from datetime import datetime
from typing import Optional, List, Dict
from collections import defaultdict
def filter_message_content(content: Optional[str]) -> str:
    """Strip reply headers, @-mentions, picture ids and sticker tags from a message.

    Args:
        content: raw message text (may be ``None``).

    Returns:
        The cleaned, whitespace-trimmed text ("" for empty input).
    """
    if not content:
        return ""
    # Markup fragments to remove, in order: reply header, @<...> mention,
    # [picid:...] image reference, [表情包:...] sticker tag.
    patterns = (
        r'\[回复.*?\],说:\s*',
        r'@<[^>]*>',
        r'\[picid:[^\]]*\]',
        r'\[表情包:[^\]]*\]',
    )
    cleaned = content
    for pattern in patterns:
        cleaned = re.sub(pattern, '', cleaned)
    return cleaned.strip()
def calculate_similarity(text1: str, text2: str) -> float:
    """
    Compute the similarity of two texts as a value in [0, 1].

    Uses difflib.SequenceMatcher's ratio (2*M/T over matching blocks).

    Args:
        text1: first text
        text2: second text

    Returns:
        float: similarity in the range 0-1
    """
    matcher = difflib.SequenceMatcher(a=text1, b=text2)
    return matcher.ratio()
def format_create_date(timestamp: float) -> str:
    """
    Format a Unix timestamp as a human-readable local-time string.

    Args:
        timestamp: seconds since the epoch

    Returns:
        str: "YYYY-mm-dd HH:MM:SS", or the fallback string "未知时间"
        when the timestamp is out of the platform-representable range
    """
    try:
        formatted = datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %H:%M:%S")
    except (ValueError, OSError):
        return "未知时间"
    return formatted
def weighted_sample(population: List[Dict], k: int) -> List[Dict]:
    """
    Uniformly sample up to *k* items from *population* without replacement.

    NOTE: despite the historical name, no weighting is applied — every item
    is equally likely. The name is kept for backward compatibility with
    existing callers.

    Args:
        population: candidate items
        k: number of items to draw

    Returns:
        List[Dict]: the sampled items; a shallow copy of *population* in its
        original order when it has at most *k* items, and [] when the
        population is empty or k <= 0
    """
    if not population or k <= 0:
        return []
    if len(population) <= k:
        return population.copy()
    # random.sample draws k distinct items directly, replacing the previous
    # O(k*n) loop of random.randint + list.pop on a full copy.
    return random.sample(population, k)

View File

@@ -1,10 +1,8 @@
import time
import random
import json
import os
from datetime import datetime
import jieba
from typing import List, Dict, Optional, Any, Tuple
from typing import List, Optional, Tuple
import traceback
from src.common.logger import get_logger
from src.common.database.database_model import Expression
@@ -17,26 +15,16 @@ from src.chat.utils.chat_message_builder import (
)
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.message_receive.chat_stream import get_chat_manager
from src.express.style_learner import style_learner_manager
from src.express.express_utils import filter_message_content, calculate_similarity
from json_repair import repair_json
MAX_EXPRESSION_COUNT = 300
DECAY_DAYS = 15 # 30天衰减到0.01
DECAY_MIN = 0.01 # 最小衰减值
# MAX_EXPRESSION_COUNT = 300
logger = get_logger("expressor")
def format_create_date(timestamp: float) -> str:
"""
将时间戳格式化为可读的日期字符串
"""
try:
return datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %H:%M:%S")
except (ValueError, OSError):
return "未知时间"
def init_prompt() -> None:
learn_style_prompt = """
{chat_str}
@@ -105,8 +93,8 @@ class ExpressionLearner:
_, self.enable_learning, self.learning_intensity = global_config.expression.get_expression_config_for_chat(
self.chat_id
)
self.min_messages_for_learning = 15 / self.learning_intensity # 触发学习所需的最少消息数
self.min_learning_interval = 150 / self.learning_intensity
self.min_messages_for_learning = 30 / self.learning_intensity # 触发学习所需的最少消息数
self.min_learning_interval = 300 / self.learning_intensity
def should_trigger_learning(self) -> bool:
"""
@@ -139,7 +127,7 @@ class ExpressionLearner:
return True
async def trigger_learning_for_chat(self) -> bool:
async def trigger_learning_for_chat(self):
"""
为指定聊天流触发学习
@@ -150,11 +138,10 @@ class ExpressionLearner:
bool: 是否成功触发学习
"""
if not self.should_trigger_learning():
return False
return
try:
logger.info(f"聊天流 {self.chat_name} 触发表达学习")
logger.info(f"聊天流 {self.chat_name} 学习表达方式")
# 学习语言风格
learnt_style = await self.learn_and_store(num=25)
@@ -163,161 +150,105 @@ class ExpressionLearner:
if learnt_style:
logger.info(f"聊天流 {self.chat_name} 表达学习完成")
return True
else:
logger.warning(f"聊天流 {self.chat_name} 表达学习未获得有效结果")
return False
except Exception as e:
logger.error(f"为聊天流 {self.chat_name} 触发学习失败: {e}")
traceback.print_exc()
return False
return
def _apply_global_decay_to_database(self, current_time: float) -> None:
"""
对数据库中的所有表达方式应用全局衰减
"""
try:
# 获取所有表达方式
all_expressions = Expression.select()
updated_count = 0
deleted_count = 0
for expr in all_expressions:
# 计算时间差
last_active = expr.last_active_time
time_diff_days = (current_time - last_active) / (24 * 3600) # 转换为天
# 计算衰减值
decay_value = self.calculate_decay_factor(time_diff_days)
new_count = max(0.01, expr.count - decay_value)
if new_count <= 0.01:
# 如果count太小删除这个表达方式
expr.delete_instance()
deleted_count += 1
else:
# 更新count
expr.count = new_count
expr.save()
updated_count += 1
if updated_count > 0 or deleted_count > 0:
logger.info(f"全局衰减完成:更新了 {updated_count} 个表达方式,删除了 {deleted_count} 个表达方式")
except Exception as e:
logger.error(f"数据库全局衰减失败: {e}")
def calculate_decay_factor(self, time_diff_days: float) -> float:
"""
计算衰减值
当时间差为0天时衰减值为0最近活跃的不衰减
当时间差为7天时衰减值为0.002中等衰减
当时间差为30天或更长时衰减值为0.01高衰减
使用二次函数进行曲线插值
"""
if time_diff_days <= 0:
return 0.0 # 刚激活的表达式不衰减
if time_diff_days >= DECAY_DAYS:
return 0.01 # 长时间未活跃的表达式大幅衰减
# 使用二次函数插值在0-30天之间从0衰减到0.01
# 使用简单的二次函数y = a * x^2
# 当x=30时y=0.01,所以 a = 0.01 / (30^2) = 0.01 / 900
a = 0.01 / (DECAY_DAYS**2)
decay = a * (time_diff_days**2)
return min(0.01, decay)
async def learn_and_store(self, num: int = 10) -> List[Tuple[str, str, str]]:
"""
学习并存储表达方式
"""
res = await self.learn_expression(num)
learnt_expressions = await self.learn_expression(num)
if res is None:
if learnt_expressions is None:
logger.info("没有学习到表达风格")
return []
learnt_expressions = res
# 展示学到的表达方式
learnt_expressions_str = ""
for (
_chat_id,
situation,
style,
_context,
_context_words,
_up_content,
) in learnt_expressions:
learnt_expressions_str += f"{situation}->{style}\n"
logger.info(f"{self.chat_name} 学习到表达风格:\n{learnt_expressions_str}")
# 按chat_id分组
chat_dict: Dict[str, List[Dict[str, Any]]] = {}
for (
chat_id,
situation,
style,
context,
context_words,
) in learnt_expressions:
if chat_id not in chat_dict:
chat_dict[chat_id] = []
chat_dict[chat_id].append(
{
"situation": situation,
"style": style,
"context": context,
"context_words": context_words,
}
)
current_time = time.time()
# 存储到数据库 Expression 表
for chat_id, expr_list in chat_dict.items():
for new_expr in expr_list:
# 查找是否已存在相似表达方式
query = Expression.select().where(
(Expression.chat_id == chat_id)
& (Expression.type == "style")
& (Expression.situation == new_expr["situation"])
& (Expression.style == new_expr["style"])
)
if query.exists():
expr_obj = query.get()
# 50%概率替换内容
if random.random() < 0.5:
expr_obj.situation = new_expr["situation"]
expr_obj.style = new_expr["style"]
expr_obj.context = new_expr["context"]
expr_obj.context_words = new_expr["context_words"]
expr_obj.count = expr_obj.count + 1
expr_obj.last_active_time = current_time
expr_obj.save()
else:
Expression.create(
situation=new_expr["situation"],
style=new_expr["style"],
count=1,
last_active_time=current_time,
chat_id=chat_id,
type="style",
create_date=current_time, # 手动设置创建日期
context=new_expr["context"],
context_words=new_expr["context_words"],
)
# 限制最大数量
exprs = list(
Expression.select()
.where((Expression.chat_id == chat_id) & (Expression.type == "style"))
.order_by(Expression.count.asc())
# 存储到数据库 Expression 表并训练 style_learner
has_new_expressions = False # 记录是否有新的表达方式
learner = style_learner_manager.get_learner(self.chat_id) # 获取 learner 实例
for (
situation,
style,
context,
up_content,
) in learnt_expressions:
# 查找是否已存在相似表达方式
query = Expression.select().where(
(Expression.chat_id == self.chat_id)
& (Expression.situation == situation)
& (Expression.style == style)
)
if len(exprs) > MAX_EXPRESSION_COUNT:
# 删除count最小的多余表达方式
for expr in exprs[: len(exprs) - MAX_EXPRESSION_COUNT]:
expr.delete_instance()
if query.exists():
# 表达方式完全相同,只更新时间戳
expr_obj = query.get()
expr_obj.last_active_time = current_time
expr_obj.save()
continue
else:
Expression.create(
situation=situation,
style=style,
last_active_time=current_time,
chat_id=self.chat_id,
create_date=current_time, # 手动设置创建日期
context=context,
up_content=up_content,
)
has_new_expressions = True
# 训练 style_learnerup_content 和 style 必定存在)
try:
learner.add_style(style, situation)
# 学习映射关系
success = style_learner_manager.learn_mapping(
self.chat_id,
up_content,
style
)
if success:
logger.debug(f"StyleLearner学习成功: {self.chat_id} - {up_content} -> {style}" + (f" (situation: {situation})" if situation else ""))
else:
logger.warning(f"StyleLearner学习失败: {self.chat_id} - {up_content} -> {style}")
except Exception as e:
logger.error(f"StyleLearner学习异常: {self.chat_id} - {e}")
# 保存当前聊天室的 style_learner 模型
if has_new_expressions:
try:
logger.info(f"开始保存聊天室 {self.chat_id} 的 StyleLearner 模型...")
save_success = learner.save(style_learner_manager.model_save_path)
if save_success:
logger.info(f"StyleLearner 模型保存成功,聊天室: {self.chat_id}")
else:
logger.warning(f"StyleLearner 模型保存失败,聊天室: {self.chat_id}")
except Exception as e:
logger.error(f"StyleLearner 模型保存异常: {e}")
return learnt_expressions
async def match_expression_context(
@@ -339,8 +270,8 @@ class ExpressionLearner:
response, _ = await self.express_learn_model.generate_response_async(prompt, temperature=0.3)
print(f"match_expression_context_prompt: {prompt}")
print(f"random_msg_match_str: {response}")
# print(f"match_expression_context_prompt: {prompt}")
# print(f"{response}")
# 解析JSON响应
match_responses = []
@@ -393,24 +324,44 @@ class ExpressionLearner:
logger.error(f"解析匹配响应JSON失败: {e}, 响应内容: \n{response}")
return []
# 确保 match_responses 是一个列表
if not isinstance(match_responses, list):
if isinstance(match_responses, dict):
match_responses = [match_responses]
else:
logger.error(f"match_responses 不是列表或字典类型: {type(match_responses)}, 内容: {match_responses}")
return []
matched_expressions = []
used_pair_indices = set() # 用于跟踪已经使用的expression_pair索引
logger.debug(f"match_responses 类型: {type(match_responses)}, 长度: {len(match_responses)}")
logger.debug(f"match_responses 内容: {match_responses}")
for match_response in match_responses:
try:
# 检查 match_response 的类型
if not isinstance(match_response, dict):
logger.error(f"match_response 不是字典类型: {type(match_response)}, 内容: {match_response}")
continue
# 获取表达方式序号
if "expression_pair" not in match_response:
logger.error(f"match_response 缺少 'expression_pair' 字段: {match_response}")
continue
pair_index = int(match_response["expression_pair"]) - 1 # 转换为0-based索引
# 检查索引是否有效且未被使用过
if 0 <= pair_index < len(expression_pairs) and pair_index not in used_pair_indices:
situation, style = expression_pairs[pair_index]
context = match_response["context"]
context = match_response.get("context", "")
matched_expressions.append((situation, style, context))
used_pair_indices.add(pair_index) # 标记该索引已使用
logger.debug(f"成功匹配表达方式 {pair_index + 1}: {situation} -> {style}")
elif pair_index in used_pair_indices:
logger.debug(f"跳过重复的表达方式 {pair_index + 1}")
except (ValueError, KeyError, IndexError) as e:
except (ValueError, KeyError, IndexError, TypeError) as e:
logger.error(f"解析匹配条目失败: {e}, 条目: {match_response}")
continue
@@ -418,15 +369,12 @@ class ExpressionLearner:
async def learn_expression(
self, num: int = 10
) -> Optional[List[Tuple[str, str, str, List[str]]]]:
) -> Optional[List[Tuple[str, str, str, str]]]:
"""从指定聊天流学习表达方式
Args:
num: 学习数量
"""
type_str = "语言风格"
prompt = "learn_style_prompt"
current_time = time.time()
# 获取上次学习之后的消息
@@ -439,14 +387,14 @@ class ExpressionLearner:
# print(random_msg)
if not random_msg or random_msg == []:
return None
# 转化成str
_chat_id: str = random_msg[0].chat_id
# random_msg_str: str = build_readable_messages(random_msg, timestamp_mode="normal")
# 学习用
random_msg_str: str = await build_anonymous_messages(random_msg)
# 溯源用
random_msg_match_str: str = await build_bare_messages(random_msg)
prompt: str = await global_prompt_manager.format_prompt(
prompt,
"learn_style_prompt",
chat_str=random_msg_str,
)
@@ -456,49 +404,51 @@ class ExpressionLearner:
try:
response, _ = await self.express_learn_model.generate_response_async(prompt, temperature=0.3)
except Exception as e:
logger.error(f"学习{type_str}失败: {e}")
logger.error(f"学习表达方式失败,模型生成出错: {e}")
return None
expressions: List[Tuple[str, str]] = self.parse_expression_response(response)
# logger.debug(f"学习{type_str}的response: {response}")
expressions: List[Tuple[str, str]] = self.parse_expression_response(response)
# 对表达方式溯源
matched_expressions: List[Tuple[str, str, str]] = await self.match_expression_context(
expressions, random_msg_match_str
)
split_matched_expressions: List[Tuple[str, str, str, List[str]]] = self.split_expression_context(
matched_expressions
)
split_matched_expressions_w_emb = []
for situation, style, context, context_words in split_matched_expressions:
split_matched_expressions_w_emb.append(
(self.chat_id, situation, style, context, context_words)
)
return split_matched_expressions_w_emb
def split_expression_context(
self, matched_expressions: List[Tuple[str, str, str]]
) -> List[Tuple[str, str, str, List[str]]]:
"""
对matched_expressions中的context部分进行jieba分词
Args:
matched_expressions: 匹配到的表达方式列表每个元素为(situation, style, context)
Returns:
添加了分词结果的表达方式列表每个元素为(situation, style, context, context_words)
"""
result = []
# 为每条消息构建精简文本列表,保留到原消息索引的映射
bare_lines: List[Tuple[int, str]] = self._build_bare_lines(random_msg)
# 将 matched_expressions 结合上一句 up_content若不存在上一句则跳过
filtered_with_up: List[Tuple[str, str, str, str]] = [] # (situation, style, context, up_content)
for situation, style, context in matched_expressions:
# 使用jieba进行分词
context_words = list(jieba.cut(context))
result.append((situation, style, context, context_words))
# 在 bare_lines 中找到第一处相似度达到85%的行
pos = None
for i, (_, c) in enumerate(bare_lines):
similarity = calculate_similarity(c, context)
if similarity >= 0.85: # 85%相似度阈值
pos = i
break
if pos is None or pos == 0:
# 没有匹配到目标句或没有上一句,跳过该表达
continue
# 检查目标句是否为空
target_content = bare_lines[pos][1]
if not target_content:
# 目标句为空,跳过该表达
continue
prev_original_idx = bare_lines[pos - 1][0]
up_content = filter_message_content(random_msg[prev_original_idx].processed_plain_text or "")
if not up_content:
# 上一句为空,跳过该表达
continue
filtered_with_up.append((situation, style, context, up_content))
if not filtered_with_up:
return None
return filtered_with_up
return result
def parse_expression_response(self, response: str) -> List[Tuple[str, str, str]]:
"""
@@ -530,6 +480,26 @@ class ExpressionLearner:
expressions.append((situation, style))
return expressions
    def _build_bare_lines(self, messages: List) -> List[Tuple[int, str]]:
        """
        Build a stripped-down text line for every message, keeping a mapping
        back to the original message index.

        Args:
            messages: message list (each item is expected to expose a
                ``processed_plain_text`` attribute — confirm against callers)

        Returns:
            List[Tuple[int, str]]: (original_index, bare_content) tuples
        """
        bare_lines: List[Tuple[int, str]] = []
        for idx, msg in enumerate(messages):
            content = msg.processed_plain_text or ""
            content = filter_message_content(content)
            # Record the entry even when content is empty so that positions
            # stay aligned with the original message indices.
            bare_lines.append((idx, content))
        return bare_lines
init_prompt()

View File

@@ -0,0 +1,446 @@
import json
import time
import random
import hashlib
import re
from typing import List, Dict, Optional, Any, Tuple
from json_repair import repair_json
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
from src.common.logger import get_logger
from src.common.database.database_model import Expression
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.express.style_learner import style_learner_manager
from src.express.express_utils import filter_message_content, weighted_sample
logger = get_logger("expression_selector")
def init_prompt():
    """Register the expression-evaluation prompt template with the global
    prompt manager so it can later be formatted by name."""
    # NOTE: the template text below is runtime data fed to the LLM and is
    # intentionally kept in Chinese.
    expression_evaluation_prompt = """
以下是正在进行的聊天内容:
{chat_observe_info}
你的名字是{bot_name}{target_message}
以下是可选的表达情境:
{all_situations}
请你分析聊天内容的语境、情绪、话题类型,从上述情境中选择最适合当前聊天情境的,最多{max_num}个情境。
考虑因素包括:
1. 聊天的情绪氛围(轻松、严肃、幽默等)
2. 话题类型(日常、技术、游戏、情感等)
3. 情境与当前语境的匹配度
{target_message_extra_block}
请以JSON格式输出只需要输出选中的情境编号
例如:
{{
    "selected_situations": [2, 3, 5, 7, 19]
}}
请严格按照JSON格式输出不要包含其他内容
"""
    Prompt(expression_evaluation_prompt, "expression_evaluation_prompt")
class ExpressionSelector:
    """
    Selects expression styles for a chat, either via the trained
    style-learner model ("exp_model" mode) or via random sampling followed
    by an LLM ranking pass ("classic" mode).
    """

    def __init__(self):
        # Small utility model used only for the classic-mode ranking prompt.
        self.llm_model = LLMRequest(
            model_set=model_config.model_task_config.utils_small, request_type="expression.selector"
        )

    def can_use_expression_for_chat(self, chat_id: str) -> bool:
        """
        Check whether the given chat stream is allowed to use expressions.

        Args:
            chat_id: chat stream id

        Returns:
            bool: True when expressions are enabled for this chat; False on
            any error (fails closed).
        """
        try:
            use_expression, _, _ = global_config.expression.get_expression_config_for_chat(chat_id)
            return use_expression
        except Exception as e:
            logger.error(f"检查表达使用权限失败: {e}")
            return False

    @staticmethod
    def _parse_stream_config_to_chat_id(stream_config_str: str) -> Optional[str]:
        """Parse a 'platform:id:type' config string into a chat_id, using the
        same hashing scheme as get_stream_id; None on malformed input."""
        try:
            parts = stream_config_str.split(":")
            if len(parts) != 3:
                return None
            platform = parts[0]
            id_str = parts[1]
            stream_type = parts[2]
            is_group = stream_type == "group"
            # Group streams hash (platform, id); private streams append "private".
            if is_group:
                components = [platform, str(id_str)]
            else:
                components = [platform, str(id_str), "private"]
            key = "_".join(components)
            return hashlib.md5(key.encode()).hexdigest()
        except Exception:
            return None

    def get_related_chat_ids(self, chat_id: str) -> List[str]:
        """Return every chat_id that shares expressions with *chat_id*
        (itself included), per the expression_groups configuration."""
        groups = global_config.expression.expression_groups
        # A group containing "*" means all configured chats share expressions.
        global_group_exists = any("*" in group for group in groups)
        if global_group_exists:
            # Collect every chat_id mentioned anywhere in the configuration.
            all_chat_ids = set()
            for group in groups:
                for stream_config_str in group:
                    if chat_id_candidate := self._parse_stream_config_to_chat_id(stream_config_str):
                        all_chat_ids.add(chat_id_candidate)
            return list(all_chat_ids) if all_chat_ids else [chat_id]
        # Otherwise return the first group this chat belongs to.
        for group in groups:
            group_chat_ids = []
            for stream_config_str in group:
                if chat_id_candidate := self._parse_stream_config_to_chat_id(stream_config_str):
                    group_chat_ids.append(chat_id_candidate)
            if chat_id in group_chat_ids:
                return group_chat_ids
        return [chat_id]

    def get_model_predicted_expressions(self, chat_id: str, target_message: str, total_num: int = 10) -> List[Dict[str, Any]]:
        """
        Predict the most suitable expressions using the style_learner model.

        Args:
            chat_id: chat room id
            target_message: message being replied to
            total_num: number of predictions wanted

        Returns:
            List[Dict[str, Any]]: predicted expression records, best first;
            falls back to random selection when prediction fails entirely.
        """
        try:
            # Strip reply/mention/sticker markup before feeding the model.
            filtered_target_message = filter_message_content(target_message)
            logger.info(f"{chat_id} 预测表达方式,过滤后的目标消息内容: {filtered_target_message}")
            # Predictions are merged across all chats sharing expressions.
            related_chat_ids = self.get_related_chat_ids(chat_id)
            predicted_expressions = []
            for related_chat_id in related_chat_ids:
                try:
                    best_style, scores = style_learner_manager.predict_style(
                        related_chat_id, filtered_target_message, top_k=total_num
                    )
                    if best_style and scores:
                        # Resolve the predicted style back to its situation text.
                        learner = style_learner_manager.get_learner(related_chat_id)
                        style_id, situation = learner.get_style_info(best_style)
                        if style_id and situation:
                            # Look up the matching database record.
                            expr_query = Expression.select().where(
                                (Expression.chat_id == related_chat_id) &
                                (Expression.situation == situation) &
                                (Expression.style == best_style)
                            )
                            if expr_query.exists():
                                expr = expr_query.get()
                                predicted_expressions.append({
                                    "id": expr.id,
                                    "situation": expr.situation,
                                    "style": expr.style,
                                    "last_active_time": expr.last_active_time,
                                    "source_id": expr.chat_id,
                                    "create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
                                    "prediction_score": scores.get(best_style, 0.0),
                                    "prediction_input": filtered_target_message
                                })
                        else:
                            logger.warning(f"为聊天室 {related_chat_id} 预测表达方式失败: {best_style} 没有找到对应的表达方式")
                except Exception as e:
                    logger.warning(f"为聊天室 {related_chat_id} 预测表达方式失败: {e}")
                    continue
            # Keep only the highest-scoring predictions.
            predicted_expressions.sort(key=lambda x: x.get("prediction_score", 0.0), reverse=True)
            selected_expressions = predicted_expressions[:total_num]
            logger.info(f"{chat_id} 预测到 {len(selected_expressions)} 个表达方式")
            return selected_expressions
        except Exception as e:
            logger.error(f"模型预测表达方式失败: {e}")
            # Fall back to random sampling rather than returning nothing.
            return self._random_expressions(chat_id, total_num)

    def _random_expressions(self, chat_id: str, total_num: int) -> List[Dict[str, Any]]:
        """
        Randomly pick expressions from the database.

        Args:
            chat_id: chat room id
            total_num: number of expressions to pick

        Returns:
            List[Dict[str, Any]]: sampled expression records ([] on error)
        """
        try:
            # Sample across all chats sharing expressions with this one.
            related_chat_ids = self.get_related_chat_ids(chat_id)
            # Single query covering every related chat_id.
            style_query = Expression.select().where(
                (Expression.chat_id.in_(related_chat_ids))
            )
            style_exprs = [
                {
                    "id": expr.id,
                    "situation": expr.situation,
                    "style": expr.style,
                    "last_active_time": expr.last_active_time,
                    "source_id": expr.chat_id,
                    "create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
                }
                for expr in style_query
            ]
            # weighted_sample currently performs uniform sampling without
            # replacement despite its name.
            if style_exprs:
                selected_style = weighted_sample(style_exprs, total_num)
            else:
                selected_style = []
            logger.info(f"随机选择,为聊天室 {chat_id} 选择了 {len(selected_style)} 个表达方式")
            return selected_style
        except Exception as e:
            logger.error(f"随机选择表达方式失败: {e}")
            return []

    async def select_suitable_expressions(
        self,
        chat_id: str,
        chat_info: str,
        max_num: int = 10,
        target_message: Optional[str] = None,
    ) -> Tuple[List[Dict[str, Any]], List[int]]:
        """
        Select suitable expressions according to the configured mode.

        Args:
            chat_id: chat stream id
            chat_info: readable chat context
            max_num: maximum number of expressions to select
            target_message: message being replied to, if any

        Returns:
            Tuple[List[Dict[str, Any]], List[int]]: selected expression
            records and their database ids.
        """
        # Respect the per-chat permission before doing any work.
        if not self.can_use_expression_for_chat(chat_id):
            logger.debug(f"聊天流 {chat_id} 不允许使用表达,返回空列表")
            return [], []
        expression_mode = global_config.expression.mode
        if expression_mode == "exp_model":
            # exp_model: pure model prediction, no LLM call.
            logger.debug(f"使用exp_model模式为聊天流 {chat_id} 选择表达方式")
            return await self._select_expressions_model_only(chat_id, target_message, max_num)
        elif expression_mode == "classic":
            # classic: random sampling followed by LLM ranking.
            logger.debug(f"使用classic模式为聊天流 {chat_id} 选择表达方式")
            return await self._select_expressions_classic(chat_id, chat_info, max_num, target_message)
        else:
            logger.warning(f"未知的表达模式: {expression_mode}回退到classic模式")
            return await self._select_expressions_classic(chat_id, chat_info, max_num, target_message)

    async def _select_expressions_model_only(
        self,
        chat_id: str,
        target_message: Optional[str],
        max_num: int = 10,
    ) -> Tuple[List[Dict[str, Any]], List[int]]:
        """
        exp_model mode: use style-learner prediction directly (no LLM pass).

        Args:
            chat_id: chat stream id
            target_message: message being replied to (may be None; downstream
                filtering treats None as empty)
            max_num: maximum number of expressions to select

        Returns:
            Tuple[List[Dict[str, Any]], List[int]]: selected expression
            records and their database ids.
        """
        try:
            selected_expressions = self.get_model_predicted_expressions(chat_id, target_message, max_num)
            selected_ids = [expr["id"] for expr in selected_expressions]
            # Touch last_active_time so selected expressions stay "fresh".
            if selected_expressions:
                self.update_expressions_last_active_time(selected_expressions)
            logger.info(f"exp_model模式为聊天流 {chat_id} 选择了 {len(selected_expressions)} 个表达方式")
            return selected_expressions, selected_ids
        except Exception as e:
            logger.error(f"exp_model模式选择表达方式失败: {e}")
            return [], []

    async def _select_expressions_classic(
        self,
        chat_id: str,
        chat_info: str,
        max_num: int = 10,
        target_message: Optional[str] = None,
    ) -> Tuple[List[Dict[str, Any]], List[int]]:
        """
        classic mode: random sampling followed by an LLM ranking pass.

        Args:
            chat_id: chat stream id
            chat_info: readable chat context
            max_num: maximum number of expressions to select
            target_message: message being replied to, if any

        Returns:
            Tuple[List[Dict[str, Any]], List[int]]: selected expression
            records and their database ids.
        """
        try:
            # 1. Randomly sample a pool of candidate expressions.
            style_exprs = self._random_expressions(chat_id, 20)
            if len(style_exprs) < 10:
                logger.info(f"聊天流 {chat_id} 表达方式正在积累中")
                return [], []
            # 2. Build the numbered situation list shown to the LLM.
            all_expressions: List[Dict[str, Any]] = []
            all_situations: List[str] = []
            for expr in style_exprs:
                expr = expr.copy()
                all_expressions.append(expr)
                all_situations.append(f"{len(all_expressions)}.当 {expr['situation']} 时,使用 {expr['style']}")
            if not all_expressions:
                logger.warning("没有找到可用的表达方式")
                return [], []
            all_situations_str = "\n".join(all_situations)
            if target_message:
                target_message_str = f",现在你想要回复消息:{target_message}"
                target_message_extra_block = "4.考虑你要回复的目标消息"
            else:
                target_message_str = ""
                target_message_extra_block = ""
            # 3. Format the ranking prompt (situations only, not full styles).
            prompt = (await global_prompt_manager.get_prompt_async("expression_evaluation_prompt")).format(
                bot_name=global_config.bot.nickname,
                chat_observe_info=chat_info,
                all_situations=all_situations_str,
                max_num=max_num,
                target_message=target_message_str,
                target_message_extra_block=target_message_extra_block,
            )
            # 4. Ask the LLM to pick situation numbers.
            content, (reasoning_content, model_name, _) = await self.llm_model.generate_response_async(prompt=prompt)
            if not content:
                logger.warning("LLM返回空结果")
                return [], []
            # 5. Parse the (possibly malformed) JSON answer.
            result = repair_json(content)
            if isinstance(result, str):
                result = json.loads(result)
            if not isinstance(result, dict) or "selected_situations" not in result:
                logger.error("LLM返回格式错误")
                logger.info(f"LLM返回结果: \n{content}")
                return [], []
            selected_indices = result["selected_situations"]
            # Map 1-based indices back to expression records.
            valid_expressions: List[Dict[str, Any]] = []
            selected_ids = []
            for idx in selected_indices:
                if isinstance(idx, int) and 1 <= idx <= len(all_expressions):
                    expression = all_expressions[idx - 1]  # indices start at 1
                    selected_ids.append(expression["id"])
                    valid_expressions.append(expression)
            # Touch last_active_time for everything selected.
            if valid_expressions:
                self.update_expressions_last_active_time(valid_expressions)
            logger.info(f"classic模式从{len(all_expressions)}个情境中选择了{len(valid_expressions)}个")
            return valid_expressions, selected_ids
        except Exception as e:
            logger.error(f"classic模式处理表达方式选择时出错: {e}")
            return [], []

    def update_expressions_last_active_time(self, expressions_to_update: List[Dict[str, Any]]):
        """Refresh last_active_time in the database for a batch of expressions."""
        if not expressions_to_update:
            return
        # De-duplicate by (source_id, situation, style) so each row is
        # written at most once per call.
        updates_by_key = {}
        for expr in expressions_to_update:
            source_id: str = expr.get("source_id")  # type: ignore
            situation: str = expr.get("situation")  # type: ignore
            style: str = expr.get("style")  # type: ignore
            if not source_id or not situation or not style:
                logger.warning(f"表达方式缺少必要字段,无法更新: {expr}")
                continue
            key = (source_id, situation, style)
            if key not in updates_by_key:
                updates_by_key[key] = expr
        # Iterating the dict yields its (chat_id, situation, style) key tuples.
        for chat_id, situation, style in updates_by_key:
            query = Expression.select().where(
                (Expression.chat_id == chat_id)
                & (Expression.situation == situation)
                & (Expression.style == style)
            )
            if query.exists():
                expr_obj = query.get()
                expr_obj.last_active_time = time.time()
                expr_obj.save()
                logger.debug(
                    "表达方式激活: 更新last_active_time in db"
                )
# Register the prompt template at import time.
init_prompt()

# Module-level singleton used by the rest of the system.
# NOTE(review): if construction fails, the error is logged but
# `expression_selector` is left undefined — importers doing
# `from ... import expression_selector` will then raise ImportError;
# confirm whether a None fallback is intended.
try:
    expression_selector = ExpressionSelector()
except Exception as e:
    logger.error(f"ExpressionSelector初始化失败: {e}")

View File

@@ -0,0 +1,141 @@
from typing import Dict, Optional, Tuple, List
from collections import Counter, defaultdict
import pickle
import os
from .tokenizer import Tokenizer
from .online_nb import OnlineNaiveBayes
class ExpressorModel:
    """
    Candidate-style ranker backed by an online (incrementally trainable)
    naive Bayes classifier.

    Each candidate may carry an optional ``situation`` string; situations
    never take part in scoring — they only travel alongside the style text.
    """

    def __init__(self,
                 alpha: float = 0.5,
                 beta: float = 0.5,
                 gamma: float = 1.0,
                 vocab_size: int = 200000,
                 use_jieba: bool = True):
        # alpha/beta: NB smoothing priors; gamma: decay factor (forgetting).
        self.tokenizer = Tokenizer(stopwords=set(), use_jieba=use_jieba)
        self.nb = OnlineNaiveBayes(alpha=alpha, beta=beta, gamma=gamma, vocab_size=vocab_size)
        self._candidates: Dict[str, str] = {}  # cid -> style text
        self._situations: Dict[str, str] = {}  # cid -> situation (not scored)

    def add_candidate(self, cid: str, text: str, situation: Optional[str] = None):
        """Register a candidate style (and optional situation) under *cid*."""
        self._candidates[cid] = text
        if situation is not None:
            self._situations[cid] = situation
        # Make sure the NB model tracks this candidate from the start.
        if cid not in self.nb.cls_counts:
            self.nb.cls_counts[cid] = 0.0
        if cid not in self.nb.token_counts:
            self.nb.token_counts[cid] = defaultdict(float)

    def add_candidates_bulk(self, items: List[Tuple[str, str]], situations: Optional[List[str]] = None):
        """Register many (cid, text) pairs; *situations* aligns by index."""
        for i, (cid, text) in enumerate(items):
            situation = situations[i] if situations and i < len(situations) else None
            self.add_candidate(cid, text, situation)

    def predict(self, text: str, k: Optional[int] = None) -> Tuple[Optional[str], Dict[str, float]]:
        """
        Score *text* against every candidate with naive Bayes.

        Args:
            text: query text
            k: when given and > 0, only the top-k scores are returned

        Returns:
            (best_cid, scores): best-scoring candidate id and a cid->score
            map; (None, {}) when the text tokenizes to nothing or there are
            no candidates.
        """
        toks = self.tokenizer.tokenize(text)
        if not toks:
            return None, {}
        if not self._candidates:
            return None, {}
        # Score every candidate against the query term frequencies.
        tf = Counter(toks)
        all_cids = list(self._candidates.keys())
        scores = self.nb.score_batch(tf, all_cids)
        if not scores:
            return None, {}
        if k is not None and k > 0:
            # Keep only the k highest-scoring candidates.
            sorted_scores = sorted(scores.items(), key=lambda x: x[1], reverse=True)
            limited_scores = dict(sorted_scores[:k])
            best = sorted_scores[0][0] if sorted_scores else None
            return best, limited_scores
        else:
            # No limit requested: return the full score map.
            best = max(scores.items(), key=lambda x: x[1])[0]
            return best, scores

    def update_positive(self, text: str, cid: str):
        """Positive-feedback update: reinforce *cid* with *text*'s tokens."""
        toks = self.tokenizer.tokenize(text)
        if not toks:
            return
        tf = Counter(toks)
        self.nb.update_positive(tf, cid)

    def decay(self, factor: float):
        # Multiplicative forgetting; factors >= 1.0 are no-ops downstream.
        self.nb.decay(factor=factor)

    def get_situation(self, cid: str) -> Optional[str]:
        """Return the situation stored for *cid*, if any."""
        return self._situations.get(cid)

    def get_style(self, cid: str) -> Optional[str]:
        """Return the style text stored for *cid*, if any."""
        return self._candidates.get(cid)

    def get_candidate_info(self, cid: str) -> Tuple[Optional[str], Optional[str]]:
        """Return (style, situation) for *cid*; either may be None."""
        return self._candidates.get(cid), self._situations.get(cid)

    def get_all_candidates(self) -> Dict[str, Tuple[str, Optional[str]]]:
        """Return every candidate as cid -> (style, situation)."""
        return {cid: (style, self._situations.get(cid))
                for cid, style in self._candidates.items()}

    def save(self, path: str):
        """Pickle candidates, situations and the NB state to *path*."""
        os.makedirs(os.path.dirname(path), exist_ok=True)
        with open(path, "wb") as f:
            pickle.dump({
                "candidates": self._candidates,
                "situations": self._situations,
                "nb": {
                    "cls_counts": dict(self.nb.cls_counts),
                    "token_counts": {cid: dict(tc) for cid, tc in self.nb.token_counts.items()},
                    "alpha": self.nb.alpha,
                    "beta": self.nb.beta,
                    "gamma": self.nb.gamma,
                    "V": self.nb.V,
                }
            }, f)

    def load(self, path: str):
        """
        Restore state previously written by :meth:`save`.

        NOTE(review): pickle is unsafe on untrusted input — only load model
        files this application wrote itself.
        """
        with open(path, "rb") as f:
            obj = pickle.load(f)
        # Restore candidate texts.
        self._candidates = obj["candidates"]
        # "situations" is absent in files written by older versions.
        self._situations = obj.get("situations", {})
        # Restore the naive Bayes model.
        # NOTE(review): cls_counts is restored as a plain dict although
        # __init__ created a defaultdict — lookups for unseen cids will
        # raise KeyError after load; confirm whether that is intended.
        self.nb.cls_counts = obj["nb"]["cls_counts"]
        self.nb.token_counts = defaultdict_dict(obj["nb"]["token_counts"])
        self.nb.alpha = obj["nb"]["alpha"]
        self.nb.beta = obj["nb"]["beta"]
        self.nb.gamma = obj["nb"]["gamma"]
        self.nb.V = obj["nb"]["V"]
        self.nb._logZ.clear()
def defaultdict_dict(d: Dict[str, Dict[str, float]]):
    """
    Convert a plain nested dict into a nested defaultdict of floats.

    Used when restoring pickled state so the token-count mapping regains
    the auto-initialising behaviour it had before serialisation.

    Args:
        d: cid -> {term -> count} mapping

    Returns:
        defaultdict whose missing outer keys yield defaultdict(float)
    """
    # The module already imports defaultdict at the top; the previous
    # function-scope re-import was redundant shadowing.
    outer = defaultdict(lambda: defaultdict(float))
    for cid, inner in d.items():
        outer[cid].update(inner)
    return outer

View File

@@ -0,0 +1,60 @@
import math
from typing import Dict, List
from collections import defaultdict, Counter
class OnlineNaiveBayes:
    """
    Multinomial naive Bayes with additive smoothing, supporting incremental
    positive updates and multiplicative decay (gradual forgetting).

    alpha: per-term smoothing; beta: class-prior smoothing; gamma: default
    decay factor; vocab_size: assumed vocabulary size V for normalisation.
    """

    def __init__(self, alpha: float = 0.5, beta: float = 0.5, gamma: float = 1.0, vocab_size: int = 200000):
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.V = vocab_size
        self.cls_counts: Dict[str, float] = defaultdict(float)  # cid -> total token count
        self.token_counts: Dict[str, Dict[str, float]] = defaultdict(lambda: defaultdict(float))  # cid -> term -> count
        self._logZ: Dict[str, float] = {}  # cached log(sum(counts) + V*alpha) per class

    def _invalidate(self, cid: str):
        # Drop the cached normaliser; recomputed lazily on next score.
        self._logZ.pop(cid, None)

    def _logZ_c(self, cid: str) -> float:
        if cid not in self._logZ:
            normalizer = self.cls_counts[cid] + self.V * self.alpha
            self._logZ[cid] = math.log(max(normalizer, 1e-12))
        return self._logZ[cid]

    def score_batch(self, tf: Counter, cids: List[str]) -> Dict[str, float]:
        """Return the log-posterior (up to a constant) of *tf* for each cid."""
        grand_total = sum(self.cls_counts.values())
        num_classes = max(1, len(self.cls_counts))
        log_prior_denom = math.log(grand_total + self.beta * num_classes)
        scores: Dict[str, float] = {}
        for cid in cids:
            # Smoothed log-prior, then add the per-term log-likelihoods.
            acc = math.log(self.cls_counts[cid] + self.beta) - log_prior_denom
            log_norm = self._logZ_c(cid)
            term_map = self.token_counts[cid]
            for term, freq in tf.items():
                smoothed = term_map.get(term, 0.0) + self.alpha
                acc += freq * (math.log(smoothed) - log_norm)
            scores[cid] = acc
        return scores

    def update_positive(self, tf: Counter, cid: str):
        """Add *tf*'s counts to class *cid* (online positive feedback)."""
        added = 0.0
        term_map = self.token_counts[cid]
        for term, cnt in tf.items():
            term_map[term] += float(cnt)
            added += float(cnt)
        self.cls_counts[cid] += added
        self._invalidate(cid)

    def decay(self, factor: float = None):
        """Multiply every count by *factor* (or gamma); >= 1.0 is a no-op."""
        rate = self.gamma if factor is None else factor
        if rate >= 1.0:
            return
        for cid in list(self.cls_counts.keys()):
            self.cls_counts[cid] *= rate
            term_map = self.token_counts[cid]
            for term in list(term_map.keys()):
                term_map[term] *= rate
            self._invalidate(cid)

View File

@@ -0,0 +1,31 @@
import re
from typing import List, Optional, Set
try:
import jieba
_HAS_JIEBA = True
except Exception:
_HAS_JIEBA = False
# Runs of ASCII word characters (used by the English fallback tokenizer).
_WORD_RE = re.compile(r"[A-Za-z0-9_]+")
# Matches strings made up entirely of symbols (no word chars, no CJK ideographs).
_SYMBOL_RE = re.compile(r'^[^\w\u4e00-\u9fff]+$')


def simple_en_tokenize(text: str) -> List[str]:
    """Lowercase *text* and split it into ASCII word tokens."""
    lowered = text.lower()
    return _WORD_RE.findall(lowered)
class Tokenizer:
    """
    Thin tokenizer wrapper: jieba-based segmentation when available, ASCII
    word splitting otherwise. Drops stopwords and pure-symbol tokens.
    """

    def __init__(self, stopwords: Optional[Set[str]] = None, use_jieba: bool = True):
        self.stopwords = stopwords or set()
        # jieba is only honoured when its import actually succeeded.
        self.use_jieba = use_jieba and _HAS_JIEBA

    def tokenize(self, text: str) -> List[str]:
        stripped = (text or "").strip()
        if not stripped:
            return []
        if self.use_jieba:
            raw = [piece.strip().lower() for piece in jieba.cut(stripped) if piece.strip()]
        else:
            raw = simple_en_tokenize(stripped)
        # Filter out stopwords and tokens consisting purely of symbols.
        return [tok for tok in raw if tok not in self.stopwords and not _SYMBOL_RE.match(tok)]

View File

@@ -0,0 +1,628 @@
"""
多聊天室表达风格学习系统
支持为每个chat_id维护独立的表达模型学习从up_content到style的映射
"""
import os
import pickle
import traceback
from typing import Dict, List, Optional, Tuple
from collections import defaultdict
import asyncio
from src.common.logger import get_logger
from .expressor_model.model import ExpressorModel
logger = get_logger("style_learner")
class StyleLearner:
    """
    Per-chatroom expression-style learner.

    Learns a mapping from up_content (incoming message text) to a reply
    style, with a dynamically managed style set of at most 2000 entries.
    """
    def __init__(self, chat_id: str, model_config: Optional[Dict] = None):
        self.chat_id = chat_id
        self.model_config = model_config or {
            "alpha": 0.5,
            "beta": 0.5,
            "gamma": 0.99,  # decay factor, enables forgetting
            "vocab_size": 200000,
            "use_jieba": True
        }
        # the underlying Naive-Bayes style scorer
        self.expressor = ExpressorModel(**self.model_config)
        # dynamic style management
        self.max_styles = 2000  # hard cap on the number of styles per chat_id
        self.style_to_id: Dict[str, str] = {}  # style text -> style_id
        self.id_to_style: Dict[str, str] = {}  # style_id -> style text
        self.id_to_situation: Dict[str, str] = {}  # style_id -> situation text
        self.next_style_id = 0  # next numeric suffix used to mint a style_id
        # learning statistics
        self.learning_stats = {
            "total_samples": 0,
            "style_counts": defaultdict(int),
            "last_update": None,
            "style_usage_frequency": defaultdict(int)  # how often each style text was learned
        }
    def add_style(self, style: str, situation: Optional[str] = None) -> bool:
        """
        Register a new style (idempotent).

        Args:
            style: style text
            situation: optional situation text associated with the style
        Returns:
            bool: True on success or if the style already existed
        """
        try:
            # already registered -> nothing to do
            if style in self.style_to_id:
                logger.debug(f"[{self.chat_id}] 风格 '{style}' 已存在")
                return True
            # enforce the per-chat style cap
            if len(self.style_to_id) >= self.max_styles:
                logger.warning(f"[{self.chat_id}] 已达到最大风格数量限制 ({self.max_styles})")
                return False
            # mint a fresh style_id
            style_id = f"style_{self.next_style_id}"
            self.next_style_id += 1
            # register the bidirectional mapping
            self.style_to_id[style] = style_id
            self.id_to_style[style_id] = style
            if situation:
                self.id_to_situation[style_id] = situation
            # make the style a scoring candidate in the expressor model
            self.expressor.add_candidate(style_id, style, situation)
            logger.info(f"[{self.chat_id}] 已添加风格: '{style}' (ID: {style_id})" +
                        (f", situation: '{situation}'" if situation else ""))
            return True
        except Exception as e:
            logger.error(f"[{self.chat_id}] 添加风格失败: {e}")
            return False
    def remove_style(self, style: str) -> bool:
        """
        Delete a style.

        Args:
            style: style text to remove
        Returns:
            bool: True if the style was removed
        """
        try:
            if style not in self.style_to_id:
                logger.warning(f"[{self.chat_id}] 风格 '{style}' 不存在")
                return False
            style_id = self.style_to_id[style]
            # drop all mapping entries
            del self.style_to_id[style]
            del self.id_to_style[style_id]
            if style_id in self.id_to_situation:
                del self.id_to_situation[style_id]
            # the expressor has no removal API, so rebuild it from scratch
            self._rebuild_expressor()
            logger.info(f"[{self.chat_id}] 已删除风格: '{style}' (ID: {style_id})")
            return True
        except Exception as e:
            logger.error(f"[{self.chat_id}] 删除风格失败: {e}")
            return False
    def update_style(self, old_style: str, new_style: str) -> bool:
        """
        Rename a style (the style_id and learned statistics are kept).

        Args:
            old_style: current style text
            new_style: replacement style text
        Returns:
            bool: True if the rename succeeded
        """
        try:
            if old_style not in self.style_to_id:
                logger.warning(f"[{self.chat_id}] 原风格 '{old_style}' 不存在")
                return False
            if new_style in self.style_to_id and new_style != old_style:
                logger.warning(f"[{self.chat_id}] 新风格 '{new_style}' 已存在")
                return False
            style_id = self.style_to_id[old_style]
            # re-point the mappings at the new text
            del self.style_to_id[old_style]
            self.style_to_id[new_style] = style_id
            self.id_to_style[style_id] = new_style
            # update the expressor model, keeping the original situation
            situation = self.id_to_situation.get(style_id)
            self.expressor.add_candidate(style_id, new_style, situation)
            logger.info(f"[{self.chat_id}] 已更新风格: '{old_style}' -> '{new_style}'")
            return True
        except Exception as e:
            logger.error(f"[{self.chat_id}] 更新风格失败: {e}")
            return False
    def add_styles_batch(self, styles: List[str], situations: Optional[List[str]] = None) -> int:
        """
        Register several styles at once.

        Args:
            styles: style texts
            situations: optional situation texts, positionally matched to styles
        Returns:
            int: how many styles were added successfully
        """
        success_count = 0
        for i, style in enumerate(styles):
            situation = situations[i] if situations and i < len(situations) else None
            if self.add_style(style, situation):
                success_count += 1
        logger.info(f"[{self.chat_id}] 批量添加风格: {success_count}/{len(styles)} 成功")
        return success_count
    def get_all_styles(self) -> List[str]:
        """Return every registered style text."""
        return list(self.style_to_id.keys())
    def get_style_count(self) -> int:
        """Return the current number of registered styles."""
        return len(self.style_to_id)
    def get_situation(self, style: str) -> Optional[str]:
        """
        Look up the situation attached to a style.

        Args:
            style: style text
        Returns:
            Optional[str]: the situation, or None if the style is unknown
        """
        if style not in self.style_to_id:
            return None
        style_id = self.style_to_id[style]
        return self.id_to_situation.get(style_id)
    def get_style_info(self, style: str) -> Tuple[Optional[str], Optional[str]]:
        """
        Return the full record for a style.

        Args:
            style: style text
        Returns:
            Tuple[Optional[str], Optional[str]]: (style_id, situation);
            (None, None) if the style is unknown
        """
        if style not in self.style_to_id:
            return None, None
        style_id = self.style_to_id[style]
        situation = self.id_to_situation.get(style_id)
        return style_id, situation
    def get_all_style_info(self) -> Dict[str, Tuple[str, Optional[str]]]:
        """
        Return the full record of every style.

        Returns:
            Dict[str, Tuple[str, Optional[str]]]: {style: (style_id, situation)}
        """
        result = {}
        for style, style_id in self.style_to_id.items():
            situation = self.id_to_situation.get(style_id)
            result[style] = (style_id, situation)
        return result
    def _rebuild_expressor(self):
        """Recreate the expressor model from the mappings (used after a removal)."""
        try:
            # fresh model, then re-register every surviving style/situation
            self.expressor = ExpressorModel(**self.model_config)
            for style_id, style_text in self.id_to_style.items():
                situation = self.id_to_situation.get(style_id)
                self.expressor.add_candidate(style_id, style_text, situation)
            logger.debug(f"[{self.chat_id}] 已重新构建expressor模型")
        except Exception as e:
            logger.error(f"[{self.chat_id}] 重新构建expressor失败: {e}")
    def learn_mapping(self, up_content: str, style: str) -> bool:
        """
        Learn one up_content -> style example.
        The style is auto-registered if it does not exist yet.

        Args:
            up_content: input text
            style: the style text it should map to
        Returns:
            bool: True if the example was learned
        """
        try:
            # auto-register unknown styles before training on them
            if style not in self.style_to_id:
                if not self.add_style(style):
                    logger.warning(f"[{self.chat_id}] 无法添加风格 '{style}',学习失败")
                    return False
            style_id = self.style_to_id[style]
            # positive-feedback update of the NB counts
            self.expressor.update_positive(up_content, style_id)
            # bookkeeping
            self.learning_stats["total_samples"] += 1
            self.learning_stats["style_counts"][style_id] += 1
            self.learning_stats["style_usage_frequency"][style] += 1
            # NOTE(review): get_event_loop() outside a running loop is
            # deprecated (3.10+) and may raise — confirm callers are async.
            self.learning_stats["last_update"] = asyncio.get_event_loop().time()
            logger.debug(f"[{self.chat_id}] 学习映射: '{up_content}' -> '{style}'")
            return True
        except Exception as e:
            logger.error(f"[{self.chat_id}] 学习映射失败: {e}")
            traceback.print_exc()
            return False
    def predict_style(self, up_content: str, top_k: int = 5) -> Tuple[Optional[str], Dict[str, float]]:
        """
        Predict the best style for an input text.

        Args:
            up_content: input text
            top_k: how many candidates to score
        Returns:
            Tuple[best style text or None, {style text: score}]
        """
        try:
            best_style_id, scores = self.expressor.predict(up_content, k=top_k)
            if best_style_id is None:
                return None, {}
            # translate style_ids back to style texts
            best_style = self.id_to_style.get(best_style_id)
            style_scores = {}
            for sid, score in scores.items():
                style_text = self.id_to_style.get(sid)
                if style_text:
                    style_scores[style_text] = score
            return best_style, style_scores
        except Exception as e:
            logger.error(f"[{self.chat_id}] 预测style失败: {e}")
            traceback.print_exc()
            return None, {}
    def decay_learning(self, factor: Optional[float] = None) -> None:
        """
        Decay (forget) the learned statistics.

        Args:
            factor: decay multiplier; None uses the configured gamma
        """
        self.expressor.decay(factor)
        logger.debug(f"[{self.chat_id}] 执行知识衰减")
    def get_stats(self) -> Dict:
        """Return a snapshot of the learning statistics."""
        return {
            "chat_id": self.chat_id,
            "total_samples": self.learning_stats["total_samples"],
            "style_count": len(self.style_to_id),
            "max_styles": self.max_styles,
            "style_counts": dict(self.learning_stats["style_counts"]),
            "style_usage_frequency": dict(self.learning_stats["style_usage_frequency"]),
            "last_update": self.learning_stats["last_update"],
            "all_styles": list(self.style_to_id.keys())
        }
    def save(self, base_path: str) -> bool:
        """
        Persist the learner to disk.

        Args:
            base_path: directory; files are {base_path}/{chat_id}_style_model.pkl
                       and {base_path}/{chat_id}_expressor.pkl
        Returns:
            bool: True on success
        """
        try:
            os.makedirs(base_path, exist_ok=True)
            file_path = os.path.join(base_path, f"{self.chat_id}_style_model.pkl")
            # mappings + statistics (the expressor is pickled separately)
            save_data = {
                "model_config": self.model_config,
                "style_to_id": self.style_to_id,
                "id_to_style": self.id_to_style,
                "id_to_situation": self.id_to_situation,
                "next_style_id": self.next_style_id,
                "max_styles": self.max_styles,
                "learning_stats": self.learning_stats
            }
            # save the expressor model first
            expressor_path = os.path.join(base_path, f"{self.chat_id}_expressor.pkl")
            self.expressor.save(expressor_path)
            # then the bookkeeping data
            with open(file_path, "wb") as f:
                pickle.dump(save_data, f)
            logger.info(f"[{self.chat_id}] 模型已保存到 {file_path}")
            return True
        except Exception as e:
            logger.error(f"[{self.chat_id}] 保存模型失败: {e}")
            return False
    def load(self, base_path: str) -> bool:
        """
        Load a previously saved learner.

        Args:
            base_path: directory used by save()
        Returns:
            bool: True if both files existed and loaded cleanly
        """
        try:
            file_path = os.path.join(base_path, f"{self.chat_id}_style_model.pkl")
            expressor_path = os.path.join(base_path, f"{self.chat_id}_expressor.pkl")
            if not os.path.exists(file_path) or not os.path.exists(expressor_path):
                logger.warning(f"[{self.chat_id}] 模型文件不存在,将使用默认配置")
                return False
            # restore the bookkeeping data
            with open(file_path, "rb") as f:
                save_data = pickle.load(f)
            self.model_config = save_data["model_config"]
            self.style_to_id = save_data["style_to_id"]
            self.id_to_style = save_data["id_to_style"]
            self.id_to_situation = save_data.get("id_to_situation", {})  # backward compat with old saves
            self.next_style_id = save_data["next_style_id"]
            self.max_styles = save_data.get("max_styles", 2000)
            self.learning_stats = save_data["learning_stats"]
            # recreate and load the expressor
            self.expressor = ExpressorModel(**self.model_config)
            self.expressor.load(expressor_path)
            logger.info(f"[{self.chat_id}] 模型已从 {file_path} 加载")
            return True
        except Exception as e:
            logger.error(f"[{self.chat_id}] 加载模型失败: {e}")
            return False
class StyleLearnerManager:
    """
    Multi-chatroom style-learning manager.

    Keeps one independent StyleLearner per chat_id; each chat manages its
    own style set (up to 2000 styles).
    """
    def __init__(self, model_save_path: str = "data/style_models"):
        self.model_save_path = model_save_path
        self.learners: Dict[str, StyleLearner] = {}
        # auto-save configuration
        self.auto_save_interval = 300  # seconds (5 minutes)
        self._auto_save_task: Optional[asyncio.Task] = None
        logger.info("StyleLearnerManager 已初始化")
    def get_learner(self, chat_id: str, model_config: Optional[Dict] = None) -> StyleLearner:
        """
        Get (or lazily create) the learner for a chat_id.

        Args:
            chat_id: chatroom id
            model_config: model config for a new learner; None uses defaults
        Returns:
            the StyleLearner instance
        """
        if chat_id not in self.learners:
            # create a new learner and try to restore its saved state
            learner = StyleLearner(chat_id, model_config)
            learner.load(self.model_save_path)
            self.learners[chat_id] = learner
            logger.info(f"为 chat_id={chat_id} 创建新的StyleLearner")
        return self.learners[chat_id]
    def add_style(self, chat_id: str, style: str) -> bool:
        """
        Register a style for a chat_id.

        Args:
            chat_id: chatroom id
            style: style text
        Returns:
            bool: True on success
        """
        learner = self.get_learner(chat_id)
        return learner.add_style(style)
    def remove_style(self, chat_id: str, style: str) -> bool:
        """
        Remove a style from a chat_id.

        Args:
            chat_id: chatroom id
            style: style text
        Returns:
            bool: True on success
        """
        learner = self.get_learner(chat_id)
        return learner.remove_style(style)
    def update_style(self, chat_id: str, old_style: str, new_style: str) -> bool:
        """
        Rename a style for a chat_id.

        Args:
            chat_id: chatroom id
            old_style: current style text
            new_style: replacement style text
        Returns:
            bool: True on success
        """
        learner = self.get_learner(chat_id)
        return learner.update_style(old_style, new_style)
    def get_chat_styles(self, chat_id: str) -> List[str]:
        """
        List all styles registered for a chat_id.

        Args:
            chat_id: chatroom id
        Returns:
            List[str]: style texts
        """
        learner = self.get_learner(chat_id)
        return learner.get_all_styles()
    def learn_mapping(self, chat_id: str, up_content: str, style: str) -> bool:
        """
        Learn one up_content -> style example for a chat_id.

        Args:
            chat_id: chatroom id
            up_content: input text
            style: target style text
        Returns:
            bool: True if the example was learned
        """
        learner = self.get_learner(chat_id)
        return learner.learn_mapping(up_content, style)
    def predict_style(self, chat_id: str, up_content: str, top_k: int = 5) -> Tuple[Optional[str], Dict[str, float]]:
        """
        Predict the best style for an input text in a chat.

        Args:
            chat_id: chatroom id
            up_content: input text
            top_k: number of candidates to score
        Returns:
            Tuple[best style or None, {style: score}]
        """
        learner = self.get_learner(chat_id)
        return learner.predict_style(up_content, top_k)
    def decay_all_learners(self, factor: Optional[float] = None) -> None:
        """
        Apply decay (forgetting) to every learner.

        Args:
            factor: decay multiplier; None uses each learner's gamma
        """
        for learner in self.learners.values():
            learner.decay_learning(factor)
        logger.info("已对所有学习器执行衰减")
    def get_all_stats(self) -> Dict[str, Dict]:
        """Return the statistics snapshot of every learner, keyed by chat_id."""
        return {chat_id: learner.get_stats() for chat_id, learner in self.learners.items()}
    def save_all_models(self) -> bool:
        """Persist every learner; True only if all of them saved successfully."""
        success_count = 0
        for learner in self.learners.values():
            if learner.save(self.model_save_path):
                success_count += 1
        logger.info(f"已保存 {success_count}/{len(self.learners)} 个模型")
        return success_count == len(self.learners)
    def load_all_models(self) -> int:
        """Load every saved model found in the save directory; returns the count."""
        if not os.path.exists(self.model_save_path):
            return 0
        loaded_count = 0
        for filename in os.listdir(self.model_save_path):
            if filename.endswith("_style_model.pkl"):
                # recover the chat_id from the filename convention
                chat_id = filename.replace("_style_model.pkl", "")
                learner = StyleLearner(chat_id)
                if learner.load(self.model_save_path):
                    self.learners[chat_id] = learner
                    loaded_count += 1
        logger.info(f"已加载 {loaded_count} 个模型")
        return loaded_count
    async def start_auto_save(self) -> None:
        """Start the periodic auto-save task (no-op if already running)."""
        if self._auto_save_task is None or self._auto_save_task.done():
            self._auto_save_task = asyncio.create_task(self._auto_save_loop())
            logger.info("已启动自动保存任务")
    async def stop_auto_save(self) -> None:
        """Cancel the auto-save task and wait for it to finish."""
        if self._auto_save_task and not self._auto_save_task.done():
            self._auto_save_task.cancel()
            try:
                await self._auto_save_task
            except asyncio.CancelledError:
                pass
            logger.info("已停止自动保存任务")
    async def _auto_save_loop(self) -> None:
        """Background loop: save all models every auto_save_interval seconds."""
        while True:
            try:
                await asyncio.sleep(self.auto_save_interval)
                self.save_all_models()
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"自动保存失败: {e}")
# Global manager instance shared by the whole process
style_learner_manager = StyleLearnerManager()

View File

@@ -13,6 +13,7 @@ from google.genai.types import (
ContentUnion,
ThinkingConfig,
Tool,
GoogleSearch,
GenerateContentConfig,
EmbedContentResponse,
EmbedContentConfig,
@@ -176,19 +177,21 @@ def _process_delta(
delta: GenerateContentResponse,
fc_delta_buffer: io.StringIO,
tool_calls_buffer: list[tuple[str, str, dict[str, Any]]],
resp: APIResponse | None = None,
):
if not hasattr(delta, "candidates") or not delta.candidates:
raise RespParseException(delta, "响应解析失败缺失candidates字段")
if delta.text:
fc_delta_buffer.write(delta.text)
# 处理 thoughtGemini 的特殊字段)
for c in getattr(delta, "candidates", []):
if c.content and getattr(c.content, "parts", None):
for p in c.content.parts:
if getattr(p, "thought", False) and getattr(p, "text", None):
# 把 thought 写入 buffer避免 resp.content 永远为空
# 保存到 reasoning_content
if resp is not None:
resp.reasoning_content = (resp.reasoning_content or "") + p.text
elif getattr(p, "text", None):
# 正常输出写入 buffer
fc_delta_buffer.write(p.text)
if delta.function_calls: # 为什么不用hasattr呢是因为这个属性一定有即使是个空的
@@ -213,9 +216,11 @@ def _build_stream_api_resp(
_fc_delta_buffer: io.StringIO,
_tool_calls_buffer: list[tuple[str, str, dict]],
last_resp: GenerateContentResponse | None = None, # 传入 last_resp
resp: APIResponse | None = None,
) -> APIResponse:
# sourcery skip: simplify-len-comparison, use-assigned-variable
resp = APIResponse()
if resp is None:
resp = APIResponse()
if _fc_delta_buffer.tell() > 0:
# 如果正式内容缓冲区不为空则将其写入APIResponse对象
@@ -240,11 +245,15 @@ def _build_stream_api_resp(
# 检查是否因为 max_tokens 截断
reason = None
if last_resp and getattr(last_resp, "candidates", None):
c0 = last_resp.candidates[0]
reason = getattr(c0, "finish_reason", None) or getattr(c0, "finishReason", None)
for c in last_resp.candidates:
fr = getattr(c, "finish_reason", None) or getattr(c, "finishReason", None)
if fr:
reason = str(fr)
break
if str(reason).endswith("MAX_TOKENS"):
if resp.content and resp.content.strip():
has_visible_output = bool(resp.content and resp.content.strip())
if has_visible_output:
logger.warning(
"⚠ Gemini 响应因达到 max_tokens 限制被部分截断,\n"
" 可能会对回复内容造成影响,建议修改模型 max_tokens 配置!"
@@ -253,7 +262,8 @@ def _build_stream_api_resp(
logger.warning("⚠ Gemini 响应因达到 max_tokens 限制被截断,\n 请修改模型 max_tokens 配置!")
if not resp.content and not resp.tool_calls:
raise EmptyResponseException()
if not getattr(resp, "reasoning_content", None):
raise EmptyResponseException()
return resp
@@ -271,7 +281,8 @@ async def _default_stream_response_handler(
_tool_calls_buffer: list[tuple[str, str, dict]] = [] # 工具调用缓冲区,用于存储接收到的工具调用
_usage_record = None # 使用情况记录
last_resp: GenerateContentResponse | None = None # 保存最后一个 chunk
resp = APIResponse()
def _insure_buffer_closed():
if _fc_delta_buffer and not _fc_delta_buffer.closed:
_fc_delta_buffer.close()
@@ -287,6 +298,7 @@ async def _default_stream_response_handler(
chunk,
_fc_delta_buffer,
_tool_calls_buffer,
resp=resp,
)
if chunk.usage_metadata:
@@ -302,6 +314,7 @@ async def _default_stream_response_handler(
_fc_delta_buffer,
_tool_calls_buffer,
last_resp=last_resp,
resp=resp,
), _usage_record
except Exception:
# 确保缓冲区被关闭
@@ -526,6 +539,15 @@ class GeminiClient(BaseClient):
tools = _convert_tool_options(tool_options) if tool_options else None
# 解析并裁剪 thinking_budget
tb = self.clamp_thinking_budget(extra_params, model_info.model_identifier)
# 检测是否为带 -search 的模型
enable_google_search = False
model_identifier = model_info.model_identifier
if model_identifier.endswith("-search"):
enable_google_search = True
# 去掉后缀并更新模型ID
model_identifier = model_identifier.removesuffix("-search")
model_info.model_identifier = model_identifier
logger.info(f"模型已启用 GoogleSearch 功能:{model_identifier}")
# 将response_format转换为Gemini API所需的格式
generation_config_dict = {
@@ -548,6 +570,17 @@ class GeminiClient(BaseClient):
elif response_format and response_format.format_type in (RespFormatType.JSON_OBJ, RespFormatType.JSON_SCHEMA):
generation_config_dict["response_mime_type"] = "application/json"
generation_config_dict["response_schema"] = response_format.to_dict()
# 自动启用 GoogleSearch grounding_tool
if enable_google_search:
grounding_tool = Tool(google_search=GoogleSearch())
if "tools" in generation_config_dict:
existing = generation_config_dict["tools"]
if isinstance(existing, list):
existing.append(grounding_tool)
else:
generation_config_dict["tools"] = [existing, grounding_tool]
else:
generation_config_dict["tools"] = [grounding_tool]
generation_config = GenerateContentConfig(**generation_config_dict)

View File

@@ -199,6 +199,7 @@ def _build_stream_api_resp(
_fc_delta_buffer: io.StringIO,
_rc_delta_buffer: io.StringIO,
_tool_calls_buffer: list[tuple[str, str, io.StringIO]],
finish_reason: str | None = None,
) -> APIResponse:
resp = APIResponse()
@@ -236,6 +237,9 @@ def _build_stream_api_resp(
resp.tool_calls.append(ToolCall(call_id, function_name, arguments))
# 检查 max_tokens 截断(流式的告警改由处理函数统一输出,这里不再输出)
# 保留 finish_reason 仅用于上层判断
if not resp.content and not resp.tool_calls:
raise EmptyResponseException()
@@ -258,6 +262,8 @@ async def _default_stream_response_handler(
_fc_delta_buffer = io.StringIO() # 正式内容缓冲区,用于存储接收到的正式内容
_tool_calls_buffer: list[tuple[str, str, io.StringIO]] = [] # 工具调用缓冲区,用于存储接收到的工具调用
_usage_record = None # 使用情况记录
finish_reason: str | None = None # 记录最后的 finish_reason
_model_name: str | None = None # 记录模型名
def _insure_buffer_closed():
# 确保缓冲区被关闭
@@ -285,6 +291,12 @@ async def _default_stream_response_handler(
continue # 跳过本帧,避免访问 choices[0]
delta = event.choices[0].delta # 获取当前块的delta内容
if hasattr(event.choices[0], "finish_reason") and event.choices[0].finish_reason:
finish_reason = event.choices[0].finish_reason
if hasattr(event, "model") and event.model and not _model_name:
_model_name = event.model # 记录模型名
if hasattr(delta, "reasoning_content") and delta.reasoning_content: # type: ignore
# 标记:有独立的推理内容块
_has_rc_attr_flag = True
@@ -307,11 +319,34 @@ async def _default_stream_response_handler(
)
try:
return _build_stream_api_resp(
resp = _build_stream_api_resp(
_fc_delta_buffer,
_rc_delta_buffer,
_tool_calls_buffer,
), _usage_record
finish_reason=finish_reason,
)
# 统一在这里输出 max_tokens 截断的警告,并从 resp 中读取
if finish_reason == "length":
# 把模型名塞到 resp.raw_data后续严格“从 resp 提取”
try:
if _model_name:
resp.raw_data = {"model": _model_name}
except Exception:
pass
model_dbg = None
try:
if isinstance(resp.raw_data, dict):
model_dbg = resp.raw_data.get("model")
except Exception:
model_dbg = None
# 统一日志格式
logger.info(
"模型%s因为超过最大max_token限制可能仅输出部分内容可视情况调整"
% (model_dbg or "")
)
return resp, _usage_record
except Exception:
# 确保缓冲区被关闭
_insure_buffer_closed()
@@ -335,9 +370,32 @@ def _default_normal_response_parser(
"""
api_response = APIResponse()
if not hasattr(resp, "choices") or len(resp.choices) == 0:
raise EmptyResponseException("响应解析失败缺失choices字段或choices列表为空")
message_part = resp.choices[0].message
# 兼容部分 OpenAI 兼容服务在空回复时返回 choices=None 的情况
choices = getattr(resp, "choices", None)
if not choices:
try:
model_dbg = getattr(resp, "model", None)
id_dbg = getattr(resp, "id", None)
usage_dbg = None
if hasattr(resp, "usage") and resp.usage:
usage_dbg = {
"prompt": getattr(resp.usage, "prompt_tokens", None),
"completion": getattr(resp.usage, "completion_tokens", None),
"total": getattr(resp.usage, "total_tokens", None),
}
try:
raw_snippet = str(resp)[:300]
except Exception:
raw_snippet = "<unserializable>"
logger.debug(
f"empty choices: model={model_dbg} id={id_dbg} usage={usage_dbg} raw≈{raw_snippet}"
)
except Exception:
# 日志采集失败不应影响控制流
pass
# 统一抛出可重试的 EmptyResponseException触发上层重试逻辑
raise EmptyResponseException("响应解析失败choices 为空或缺失")
message_part = choices[0].message
if hasattr(message_part, "reasoning_content") and message_part.reasoning_content: # type: ignore
# 有有效的推理字段
@@ -381,6 +439,22 @@ def _default_normal_response_parser(
# 将原始响应存储在原始数据中
api_response.raw_data = resp
# 检查 max_tokens 截断
try:
choice0 = resp.choices[0]
reason = getattr(choice0, "finish_reason", None)
if reason and reason == "length":
print(resp)
_model_name = resp.model
# 统一日志格式
logger.info(
"模型%s因为超过最大max_token限制可能仅输出部分内容可视情况调整"
% (_model_name or "")
)
return api_response, _usage_record
except Exception as e:
logger.debug(f"检查 MAX_TOKENS 截断时异常: {e}")
if not api_response.content and not api_response.tool_calls:
raise EmptyResponseException()
@@ -489,7 +563,7 @@ class OpenaiClient(BaseClient):
await asyncio.sleep(0.1) # 等待0.5秒后再次检查任务&中断信号量状态
# logger.
logger.debug(f"OpenAI API响应(非流式): {req_task.result()}")
# logger.debug(f"OpenAI API响应(非流式): {req_task.result()}")
# logger.info(f"OpenAI请求时间: {model_info.model_identifier} {time.time() - start_time} \n{messages}")

View File

@@ -29,12 +29,19 @@ def compress_messages(messages: list[Message], img_target_size: int = 1 * 1024 *
:return: 转换后的图片数据
"""
try:
image = Image.open(image_data)
image = Image.open(io.BytesIO(image_data))
if image.format and (image.format.upper() in ["JPEG", "JPG", "PNG", "WEBP"]):
# 静态图像转换为JPEG格式
# 仅在非动图时进行格式转换
if (
not getattr(image, "is_animated", False)
and image.format
and (image.format.upper() in ["JPEG", "JPG", "PNG", "WEBP"])
):
reformated_image_data = io.BytesIO()
image.save(reformated_image_data, format="JPEG", quality=95, optimize=True)
img_to_save = image
if img_to_save.mode in ("RGBA", "LA", "P"):
img_to_save = img_to_save.convert("RGB")
img_to_save.save(reformated_image_data, format="JPEG", quality=95, optimize=True)
image_data = reformated_image_data.getvalue()
return image_data
@@ -50,20 +57,22 @@ def compress_messages(messages: list[Message], img_target_size: int = 1 * 1024 *
:return: 缩放后的图片数据
"""
try:
image = Image.open(image_data)
image = Image.open(io.BytesIO(image_data))
# 原始尺寸
original_size = (image.width, image.height)
# 计算新的尺寸
new_size = (int(original_size[0] * scale), int(original_size[1] * scale))
# 计算新的尺寸防止为0
new_w = max(1, int(original_size[0] * scale))
new_h = max(1, int(original_size[1] * scale))
new_size = (new_w, new_h)
output_buffer = io.BytesIO()
if getattr(image, "is_animated", False):
# 动态图片,处理所有帧
frames = []
new_size = (new_size[0] // 2, new_size[1] // 2) # 动图,缩放尺寸再打折
new_size = (max(1, new_size[0] // 2), max(1, new_size[1] // 2)) # 动图,缩放尺寸再打折
for frame_idx in range(getattr(image, "n_frames", 1)):
image.seek(frame_idx)
new_frame = image.copy()
@@ -83,6 +92,8 @@ def compress_messages(messages: list[Message], img_target_size: int = 1 * 1024 *
else:
# 静态图片,直接缩放保存
resized_image = image.resize(new_size, Image.Resampling.LANCZOS)
if resized_image.mode in ("RGBA", "LA", "P"):
resized_image = resized_image.convert("RGB")
resized_image.save(output_buffer, format="JPEG", quality=95, optimize=True)
return output_buffer.getvalue(), original_size, new_size

View File

@@ -270,13 +270,28 @@ class LLMRequest:
audio_base64=audio_base64,
extra_params=model_info.extra_params,
)
except (EmptyResponseException, NetworkConnectionError) as e:
except EmptyResponseException as e:
# 空回复:通常为临时问题,单独记录并重试
retry_remain -= 1
if retry_remain <= 0:
logger.error(f"模型 '{model_info.name}'用尽对临时错误的重试次数后仍然失败。")
logger.error(f"模型 '{model_info.name}'多次出现空回复后仍然失败。")
raise ModelAttemptFailed(f"模型 '{model_info.name}' 重试耗尽", original_exception=e) from e
logger.warning(f"模型 '{model_info.name}' 遇到可重试错误: {str(e)}。剩余重试次数: {retry_remain}")
logger.warning(
f"模型 '{model_info.name}' 返回空回复(可重试)。剩余重试次数: {retry_remain}"
)
await asyncio.sleep(api_provider.retry_interval)
except NetworkConnectionError as e:
# 网络错误:单独记录并重试
retry_remain -= 1
if retry_remain <= 0:
logger.error(f"模型 '{model_info.name}' 在网络错误重试用尽后仍然失败。")
raise ModelAttemptFailed(f"模型 '{model_info.name}' 重试耗尽", original_exception=e) from e
logger.warning(
f"模型 '{model_info.name}' 遇到网络错误(可重试): {str(e)}。剩余重试次数: {retry_remain}"
)
await asyncio.sleep(api_provider.retry_interval)
except RespNotOkException as e:
@@ -369,8 +384,8 @@ class LLMRequest:
failed_models_this_request.add(model_info.name)
if isinstance(last_exception, RespNotOkException) and last_exception.status_code == 400:
logger.error("收到不可恢复的客户端错误 (400)中止所有尝试")
raise last_exception from e
logger.warning("收到客户端错误 (400)跳过当前模型并继续尝试其他模型")
continue
logger.error(f"所有 {max_attempts} 个模型均尝试失败。")
if last_exception:

View File

@@ -18,7 +18,7 @@ from .memory_utils import (
find_best_matching_memory,
check_title_exists_fuzzy,
get_all_titles,
get_memory_titles_by_chat_id_weighted,
find_most_similar_memory_by_chat_id,
)
@@ -41,6 +41,80 @@ class MemoryChest:
self.running_content_list = {} # {chat_id: {"content": running_content, "last_update_time": timestamp, "create_time": timestamp}}
self.fetched_memory_list = [] # [(chat_id, (question, answer, timestamp)), ...]
    def remove_one_memory_by_age_weight(self) -> bool:
        """
        Delete one memory, sampled so the oldest and newest are the most
        likely victims (U-shaped weight over id order: small id ≈ old,
        large id ≈ new). Locked memories are never deleted.

        Returns:
            bool: True if a memory was deleted
        """
        try:
            memories = list(MemoryChestModel.select())
            if not memories:
                return False
            # never delete locked entries
            candidates = [m for m in memories if not getattr(m, "locked", False)]
            if not candidates:
                return False
            # sort by id as a proxy for creation time (small -> old, large -> new)
            candidates.sort(key=lambda m: m.id)
            n = len(candidates)
            if n == 1:
                MemoryChestModel.delete().where(MemoryChestModel.id == candidates[0].id).execute()
                logger.info(f"[记忆管理] 已删除一条记忆(权重抽样){candidates[0].title}")
                return True
            # U-shaped weights: lowest in the middle, highest at both ends
            # r in [0,1] is the normalized position; w = 0.1 + 0.9 * (abs(r-0.5)*2)**1.5
            weights = []
            for idx, _m in enumerate(candidates):
                r = idx / (n - 1)
                w = 0.1 + 0.9 * (abs(r - 0.5) * 2) ** 1.5
                weights.append(w)
            import random as _random
            selected = _random.choices(candidates, weights=weights, k=1)[0]
            MemoryChestModel.delete().where(MemoryChestModel.id == selected.id).execute()
            logger.info(f"[记忆管理] 已删除一条记忆(权重抽样){selected.title}")
            return True
        except Exception as e:
            logger.error(f"[记忆管理] 按年龄权重删除记忆时出错: {e}")
            return False
    def _compute_merge_similarity_threshold(self) -> float:
        """
        Dynamically pick the merge-similarity threshold from how full the
        memory chest is (current count / max_memory_number).
        The fuller the chest, the lower (more permissive) the threshold:
        - < 60%:  0.70 (strict, avoids early mis-merges)
        - < 80%:  0.60
        - < 100%: 0.50
        - < 150%: 0.40
        - < 200%: 0.30
        - >= 200%: 0.25 (most permissive, speeds up convergence)
        """
        try:
            current_count = MemoryChestModel.select().count()
            max_count = max(1, int(global_config.memory.max_memory_number))
            percentage = current_count / max_count
            if percentage < 0.6:
                return 0.70
            elif percentage < 0.8:
                return 0.60
            elif percentage < 1.0:
                return 0.50
            elif percentage < 1.5:
                return 0.40
            elif percentage < 2:
                return 0.30
            else:
                return 0.25
        except Exception:
            # fall back to a conservative threshold on any error
            return 0.70
async def build_running_content(self, chat_id: str = None) -> str:
"""
构建记忆仓库的运行内容
@@ -430,75 +504,43 @@ class MemoryChest:
except Exception as e:
logger.error(f"保存记忆仓库内容时出错: {e}")
async def choose_merge_target(self, memory_title: str, chat_id: str = None) -> list[str]:
async def choose_merge_target(self, memory_title: str, chat_id: str = None) -> tuple[list[str], list[str]]:
"""
选择与给定记忆标题相关的记忆目标
选择与给定记忆标题相关的记忆目标(基于文本相似度)
Args:
memory_title: 要匹配的记忆标题
chat_id: 聊天ID用于加权抽样
chat_id: 聊天ID用于筛选同chat_id的记忆
Returns:
list[str]: 选中的记忆内容列表
tuple[list[str], list[str]]: (选中的记忆标题列表, 选中的记忆内容列表)
"""
try:
# 如果提供了chat_id使用加权抽样
all_titles = get_memory_titles_by_chat_id_weighted(chat_id)
# 剔除掉输入的 memory_title 本身
all_titles = [title for title in all_titles if title and title.strip() != (memory_title or "").strip()]
if not chat_id:
logger.warning("未提供chat_id无法进行记忆匹配")
return [], []
content = ""
display_index = 1
for title in all_titles:
content += f"{display_index}. {title}\n"
display_index += 1
# 动态计算相似度阈值(占比越高阈值越低)
dynamic_threshold = self._compute_merge_similarity_threshold()
# 使用相似度匹配查找最相似的记忆(基于动态阈值)
similar_memory = find_most_similar_memory_by_chat_id(
target_title=memory_title,
target_chat_id=chat_id,
similarity_threshold=dynamic_threshold
)
prompt = f"""
所有记忆列表
{content}
请根据以上记忆列表,选择一个与"{memory_title}"相关的记忆用json输出
如果没有相关记忆,输出:
{{
"selected_title": ""
}}
可以选择多个相关的记忆但最多不超过5个
例如:
{{
"selected_title": "选择的相关记忆标题"
}},
{{
"selected_title": "选择的相关记忆标题"
}},
{{
"selected_title": "选择的相关记忆标题"
}}
...
注意:请返回原始标题本身作为 selected_title不要包含前面的序号或多余字符。
请输出JSON格式不要输出其他内容
"""
# logger.info(f"选择合并目标 prompt: {prompt}")
if global_config.debug.show_prompt:
logger.info(f"选择合并目标 prompt: {prompt}")
if similar_memory:
selected_title, selected_content, similarity = similar_memory
logger.info(f"'{memory_title}' 找到相似记忆: '{selected_title}' (相似度: {similarity:.3f} 阈值: {dynamic_threshold:.2f})")
return [selected_title], [selected_content]
else:
logger.debug(f"选择合并目标 prompt: {prompt}")
merge_target_response, (reasoning_content, model_name, tool_calls) = await self.LLMRequest_build.generate_response_async(prompt)
# 解析JSON响应
selected_titles = self._parse_merge_target_json(merge_target_response)
# 根据标题查找对应的内容
selected_contents = self._get_memories_by_titles(selected_titles)
logger.info(f"选择合并目标结果: {len(selected_contents)} 条记忆:{selected_titles}")
return selected_titles,selected_contents
logger.info(f"'{memory_title}' 未找到相似度 >= {dynamic_threshold:.2f} 的记忆")
return [], []
except Exception as e:
logger.error(f"选择合并目标时出错: {e}")
return []
return [], []
def _get_memories_by_titles(self, titles: list[str]) -> list[str]:
"""
@@ -659,6 +701,11 @@ class MemoryChest:
合并记忆
"""
try:
# 在记忆整合前先清理空chat_id的记忆
cleaned_count = self.cleanup_empty_chat_id_memories()
if cleaned_count > 0:
logger.info(f"记忆整合前清理了 {cleaned_count} 条空chat_id记忆")
content = ""
for memory in memory_list:
content += f"{memory}\n"
@@ -793,5 +840,37 @@ MutePlugin 是禁言插件的名称
logger.error(f"生成合并记忆标题时出错: {e}")
return f"合并记忆_{int(time.time())}"
    def cleanup_empty_chat_id_memories(self) -> int:
        """
        Delete memory rows whose chat_id is NULL, "" or the literal "None".

        Returns:
            int: number of rows removed
        """
        try:
            # collect all memories whose chat_id is effectively empty
            empty_chat_id_memories = MemoryChestModel.select().where(
                (MemoryChestModel.chat_id.is_null()) |
                (MemoryChestModel.chat_id == "") |
                (MemoryChestModel.chat_id == "None")
            )
            count = 0
            for memory in empty_chat_id_memories:
                logger.info(f"清理空chat_id记忆: 标题='{memory.title}', ID={memory.id}")
                memory.delete_instance()
                count += 1
            if count > 0:
                logger.info(f"已清理 {count} 条chat_id为空的记忆记录")
            else:
                logger.debug("未发现需要清理的空chat_id记忆记录")
            return count
        except Exception as e:
            logger.error(f"清理空chat_id记忆时出错: {e}")
            return 0
global_memory_chest = MemoryChest()

View File

@@ -0,0 +1,185 @@
import time
import asyncio
from typing import List, Optional, Tuple
from src.common.logger import get_logger
from src.chat.utils.chat_message_builder import (
get_raw_msg_before_timestamp_with_chat,
build_readable_messages_with_id,
)
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config, global_config
from src.memory_system.questions import global_conflict_tracker
from src.memory_system.memory_utils import parse_md_json
logger = get_logger("curious")
class CuriousDetector:
    """Curiosity detector.

    Scans recent chat history for contradictions, disputed statements, or
    unclear concepts that are worth asking a follow-up question about, using
    an LLM classifier. Detected questions are handed to the global conflict
    tracker.
    """

    def __init__(self, chat_id: str):
        """
        Args:
            chat_id: id of the chat stream this detector inspects
        """
        self.chat_id = chat_id
        # Uses the shared "utils" model set; requests are tagged for stats.
        self.llm_request = LLMRequest(
            model_set=model_config.model_task_config.utils,
            request_type="curious_detector",
        )

    async def detect_questions(self, recent_messages: List) -> Optional[str]:
        """Check whether the recent messages contain something worth asking.

        Args:
            recent_messages: list of recent message objects

        Returns:
            Optional[str]: the question text if one was detected, else None.
        """
        try:
            if not recent_messages or len(recent_messages) < 2:
                return None

            # Fix: perform the cheap "already tracking a question" check
            # BEFORE rendering the chat content, so we skip the expensive
            # formatting work when we are going to bail out anyway.
            existing_questions = global_conflict_tracker.get_questions_by_chat_id(self.chat_id)
            if len(existing_questions) > 0:
                logger.debug(f"当前已有{len(existing_questions)}个问题在跟踪中,跳过检测")
                return None

            # Render the recent messages into a readable text block for the LLM.
            chat_content_block, _ = build_readable_messages_with_id(
                messages=recent_messages,
                timestamp_mode="normal_no_YMD",
                read_mark=0.0,
                truncate=True,
                show_actions=True,
            )

            # Detection prompt: asks for a JSON question or the literal "NO".
            prompt = f"""你是一个严谨的聊天内容分析器。请分析以下聊天记录,检测是否存在需要提问的内容。
检测条件:
1. 聊天中存在逻辑矛盾或冲突的信息
2. 有人反对或否定之前提出的信息
3. 存在观点不一致的情况
4. 有模糊不清或需要澄清的概念
5. 有人提出了质疑或反驳
**重要限制:**
- 忽略涉及违法、暴力、色情、政治等敏感话题的内容
- 不要对敏感话题提问
- 只有在确实存在矛盾或冲突时才提问
- 如果聊天内容正常没有矛盾请输出NO
**聊天记录**
{chat_content_block}
请分析上述聊天记录如果发现需要提问的内容请用JSON格式输出
```json
{{
    "question": "具体的问题描述,要完整描述涉及的概念和问题",
    "reason": "为什么需要提问这个问题的理由"
}}
```
如果没有需要提问的内容请只输出NO"""

            if global_config.debug.show_prompt:
                logger.info(f"好奇心检测提示词: {prompt}")
            else:
                logger.debug("已发送好奇心检测提示词")

            result_text, _ = await self.llm_request.generate_response_async(prompt, temperature=0.3)
            if not result_text:
                return None
            result_text = result_text.strip()

            # The model signals "nothing to ask" with a bare NO.
            if result_text.upper() == "NO":
                logger.debug("未检测到需要提问的内容")
                return None

            # Otherwise parse the markdown-fenced JSON answer.
            try:
                questions, reasoning = parse_md_json(result_text)
                if questions and len(questions) > 0:
                    question_data = questions[0]
                    question = question_data.get("question", "")
                    reason = question_data.get("reason", "")
                    if question and question.strip():
                        logger.info(f"检测到需要提问的内容: {question}")
                        logger.info(f"提问理由: {reason}")
                        return question
            except Exception as e:
                logger.warning(f"解析问题JSON失败: {e}")
                logger.debug(f"原始响应: {result_text}")
            return None
        except Exception as e:
            logger.error(f"好奇心检测失败: {e}")
            return None

    async def make_question_from_detection(self, question: str, context: str = "") -> bool:
        """Record a detected question in the global conflict tracker.

        NOTE: tracking is registered with start_following=False, i.e. the
        question is stored but no follow-up task is launched from here.

        Args:
            question: the detected question text
            context: optional context for the question

        Returns:
            bool: True if the question was recorded successfully.
        """
        try:
            if not question or not question.strip():
                return False
            await global_conflict_tracker.track_conflict(
                question=question.strip(),
                context=context,
                start_following=False,
                chat_id=self.chat_id
            )
            logger.info(f"已记录问题到冲突追踪器: {question}")
            return True
        except Exception as e:
            logger.error(f"记录问题失败: {e}")
            return False
async def check_and_make_question(chat_id: str, recent_messages: List) -> bool:
    """Run curiosity detection over recent messages and record any result.

    Args:
        chat_id: id of the chat stream to inspect
        recent_messages: list of recent message objects

    Returns:
        bool: True only when a question was both detected and recorded.
    """
    try:
        detector = CuriousDetector(chat_id)
        question = await detector.detect_questions(recent_messages)
        # Guard clauses: bail out as soon as any step yields nothing.
        if not question:
            return False
        if not await detector.make_question_from_detection(question):
            return False
        logger.info(f"成功检测并记录问题: {question}")
        return True
    except Exception as e:
        logger.error(f"检查并生成问题失败: {e}")
        return False

View File

@@ -8,7 +8,6 @@ from src.memory_system.Memory_chest import global_memory_chest
from src.common.logger import get_logger
from src.common.database.database_model import MemoryChest as MemoryChestModel
from src.config.config import global_config
from src.memory_system.memory_utils import get_all_titles
logger = get_logger("memory")
@@ -56,19 +55,19 @@ class MemoryManagementTask(AsyncTask):
current_count = self._get_memory_count()
percentage = current_count / self.max_memory_number
if percentage < 0.5:
if percentage < 0.6:
# 小于50%每600秒执行一次
return 3600
elif percentage < 0.7:
elif percentage < 1:
# 大于等于50%每300秒执行一次
return 1800
elif percentage < 0.9:
# 大于等于70%每120秒执行一次
return 300
elif percentage < 1.2:
return 30
elif percentage < 1.5:
# 大于等于100%每120秒执行一次
return 600
elif percentage < 1.8:
return 120
else:
return 10
return 30
except Exception as e:
logger.error(f"[记忆管理] 计算执行间隔时出错: {e}")
@@ -93,6 +92,22 @@ class MemoryManagementTask(AsyncTask):
percentage = current_count / self.max_memory_number
logger.info(f"当前记忆数量: {current_count}/{self.max_memory_number} ({percentage:.1%})")
# 当占比 > 1.6 时,持续删除直到占比 <= 1.6(越老/越新更易被删)
if percentage > 2:
logger.info("记忆过多,开始遗忘记忆")
while True:
if percentage <= 1.8:
break
removed = global_memory_chest.remove_one_memory_by_age_weight()
if not removed:
logger.warning("没有可删除的记忆,停止连续删除")
break
# 重新计算占比
current_count = self._get_memory_count()
percentage = current_count / self.max_memory_number
logger.info(f"遗忘进度: 当前 {current_count}/{self.max_memory_number} ({percentage:.1%})")
logger.info("遗忘记忆结束")
# 如果记忆数量为0跳过执行
if current_count < 10:
return
@@ -109,7 +124,7 @@ class MemoryManagementTask(AsyncTask):
logger.info("无合适合并内容,跳过本次合并")
return
logger.info(f"为 [{selected_title}] 找到 {len(related_contents)} 条相关记忆:{related_titles}")
logger.info(f"{selected_chat_id} 为 [{selected_title}] 找到 {len(related_contents)} 条相关记忆:{related_titles}")
# 执行merge_memory合并记忆
merged_title, merged_content = await global_memory_chest.merge_memory(related_contents,selected_chat_id)

View File

@@ -303,4 +303,55 @@ def get_memory_titles_by_chat_id_weighted(target_chat_id: str, same_chat_weight:
except Exception as e:
logger.error(f"按chat_id加权抽样记忆标题时出错: {e}")
return []
return []
def find_most_similar_memory_by_chat_id(target_title: str, target_chat_id: str, similarity_threshold: float = 0.5) -> Optional[Tuple[str, str, float]]:
    """Find the memory most similar to target_title within one chat.

    Only unlocked memories whose chat_id equals target_chat_id are
    considered; the target title itself is skipped.

    Args:
        target_title: title to compare candidates against
        target_chat_id: chat id whose memories are searched
        similarity_threshold: minimum similarity to accept (default 0.5).
            Fix: the original docstring claimed the default was 0.7, which
            contradicted the actual signature.

    Returns:
        Optional[Tuple[str, str, float]]: (title, content, similarity) of the
        best match, or None when no candidate reaches the threshold.
    """
    try:
        # Collect candidate (title, content) pairs belonging to this chat.
        same_chat_memories = [
            (memory.title, memory.content)
            for memory in MemoryChestModel.select()
            if memory.title and not memory.locked and memory.chat_id == target_chat_id
        ]
        if not same_chat_memories:
            logger.warning(f"未找到chat_id为 '{target_chat_id}' 的记忆")
            return None

        # Linear scan for the highest-similarity candidate.
        best_match = None
        best_similarity = 0.0
        for title, content in same_chat_memories:
            # Never match the target against itself.
            if title.strip() == target_title.strip():
                continue
            similarity = calculate_similarity(target_title, title)
            if similarity > best_similarity:
                best_similarity = similarity
                best_match = (title, content, similarity)

        if best_match and best_similarity >= similarity_threshold:
            logger.info(f"找到最相似记忆: '{best_match[0]}' (相似度: {best_similarity:.3f})")
            return best_match

        logger.info(f"未找到相似度 >= {similarity_threshold} 的记忆,最高相似度: {best_similarity:.3f}")
        return None
    except Exception as e:
        logger.error(f"查找最相似记忆时出错: {e}")
        return None

View File

@@ -50,41 +50,34 @@ class QuestionMaker:
"""按权重随机选取一个未回答的冲突并自增 raise_time。
选择规则:
- 若存在 `raise_time == 0` 的项按权重抽样0 次权重 1.0≥1 次权重 0.05)。
- 若不存在 `raise_time == 0` 的项:仅 5% 概率返回其中任意一条,否则返回 None。
- 若存在 `raise_time == 0` 的项按权重抽样0 次权重 1.0≥1 次权重 0.01)。
- 若不存在返回 None。
- 每次成功选中后,将该条目的 `raise_time` 自增 1 并保存。
"""
conflicts = await self.get_un_answered_conflict()
if not conflicts:
return None
# 如果没有 raise_time==0 的项,则仅有 5% 概率抽样一个
conflicts_with_zero = [c for c in conflicts if (getattr(c, "raise_time", 0) or 0) == 0]
if not conflicts_with_zero:
if random.random() >= 0.05:
return None
# 以均匀概率选择一个(此时权重都等同于 0.05,无需再按权重)
chosen_conflict = random.choice(conflicts)
else:
# 权重规则raise_time == 0 -> 1.0raise_time >= 1 -> 0.05
if conflicts_with_zero:
# 权重规则raise_time == 0 -> 1.0raise_time >= 1 -> 0.01
weights = []
for conflict in conflicts:
current_raise_time = getattr(conflict, "raise_time", 0) or 0
weight = 1.0 if current_raise_time == 0 else 0.05
weight = 1.0 if current_raise_time == 0 else 0.01
weights.append(weight)
# 按权重随机选择
chosen_conflict = random.choices(conflicts, weights=weights, k=1)[0]
# 选中后,自增 raise_time 并保存
try:
# 选中后,自增 raise_time 并保存
chosen_conflict.raise_time = (getattr(chosen_conflict, "raise_time", 0) or 0) + 1
chosen_conflict.save()
except Exception:
# 静默失败不影响流程
pass
return chosen_conflict
return chosen_conflict
else:
# 如果没有 raise_time == 0 的冲突,返回 None
return None
async def make_question(self) -> Tuple[Optional[str], Optional[str], Optional[str]]:
"""生成一条用于询问用户的冲突问题与上下文。

View File

@@ -278,15 +278,41 @@ class ConflictTracker:
# 无新消息时稍作等待
await asyncio.sleep(poll_interval)
# 未获取到答案,仅存储问题
await self.add_or_update_conflict(
conflict_content=original_question,
create_time=time.time(),
update_time=time.time(),
answer="",
chat_id=tracker.chat_id,
# 未获取到答案,检查是否需要删除记录
# 查找现有的冲突记录
existing_conflict = MemoryConflict.get_or_none(
MemoryConflict.conflict_content == original_question,
MemoryConflict.chat_id == tracker.chat_id
)
logger.info(f"记录冲突内容(未解答): {len(original_question)} 字符")
if existing_conflict:
# 检查raise_time是否大于3且没有答案
current_raise_time = getattr(existing_conflict, "raise_time", 0) or 0
if current_raise_time > 0 and not existing_conflict.answer:
# 删除该条目
await self.delete_conflict(original_question, tracker.chat_id)
logger.info(f"追踪结束后删除条目(raise_time={current_raise_time}且无答案): {original_question}")
else:
# 更新记录但不删除
await self.add_or_update_conflict(
conflict_content=original_question,
create_time=existing_conflict.create_time,
update_time=time.time(),
answer="",
chat_id=tracker.chat_id,
)
logger.info(f"记录冲突内容(未解答): {len(original_question)} 字符")
else:
# 如果没有现有记录,创建新记录
await self.add_or_update_conflict(
conflict_content=original_question,
create_time=time.time(),
update_time=time.time(),
answer="",
chat_id=tracker.chat_id,
)
logger.info(f"记录冲突内容(未解答): {len(original_question)} 字符")
logger.info(f"问题跟踪结束:{original_question}")
except Exception as e:
logger.error(f"后台问题跟踪任务异常: {e}")
@@ -374,6 +400,7 @@ class ConflictTracker:
1.提问必须具体,明确
2.提问最好涉及指向明确的事物,而不是代称
3.如果缺少上下文,不要强行提问,可以忽略
4.请忽略涉及违法,暴力,色情,政治等敏感话题的内容
请用json格式输出不要输出其他内容仅输出提问理由和具体提的提问:
**示例**
@@ -420,5 +447,33 @@ class ConflictTracker:
logger.error(f"获取冲突记录数量时出错: {e}")
return 0
async def delete_conflict(self, conflict_content: str, chat_id: str) -> bool:
    """Remove one conflict record identified by its content and chat id.

    Args:
        conflict_content: exact conflict text of the record
        chat_id: chat the record belongs to

    Returns:
        bool: True when a record was found and deleted, False otherwise.
    """
    try:
        record = MemoryConflict.get_or_none(
            MemoryConflict.conflict_content == conflict_content,
            MemoryConflict.chat_id == chat_id,
        )
        # Nothing to delete: report and return early.
        if record is None:
            logger.warning(f"未找到要删除的冲突记录: {conflict_content}")
            return False
        record.delete_instance()
        logger.info(f"已删除冲突记录: {conflict_content}")
        return True
    except Exception as e:
        logger.error(f"删除冲突记录时出错: {e}")
        return False
# 全局冲突追踪器实例
global_conflict_tracker = ConflictTracker()

View File

@@ -1,11 +1,12 @@
from src.common.logger import get_logger
from src.chat.frequency_control.frequency_control import frequency_control_manager
from src.config.config import global_config
logger = get_logger("frequency_api")
def get_current_talk_frequency(chat_id: str) -> float:
return frequency_control_manager.get_or_create_frequency_control(chat_id).get_talk_frequency_adjust()
def get_current_talk_value(chat_id: str) -> float:
return frequency_control_manager.get_or_create_frequency_control(chat_id).get_talk_frequency_adjust() * global_config.chat.get_talk_value(chat_id)
def set_talk_frequency_adjust(chat_id: str, talk_frequency_adjust: float) -> None:

View File

@@ -1,14 +1,25 @@
from typing import Optional, Type
from typing import Optional, Type, TYPE_CHECKING
from src.plugin_system.base.base_tool import BaseTool
from src.plugin_system.base.component_types import ComponentType
from src.common.logger import get_logger
if TYPE_CHECKING:
from src.chat.message_receive.chat_stream import ChatStream
logger = get_logger("tool_api")
def get_tool_instance(tool_name: str) -> Optional[BaseTool]:
"""获取公开工具实例"""
def get_tool_instance(tool_name: str, chat_stream: Optional["ChatStream"] = None) -> Optional[BaseTool]:
"""获取公开工具实例
Args:
tool_name: 工具名称
chat_stream: 聊天流对象,用于传递聊天上下文信息
Returns:
Optional[BaseTool]: 工具实例如果未找到则返回None
"""
from src.plugin_system.core import component_registry
# 获取插件配置
@@ -19,7 +30,7 @@ def get_tool_instance(tool_name: str) -> Optional[BaseTool]:
plugin_config = None
tool_class: Type[BaseTool] = component_registry.get_component_class(tool_name, ComponentType.TOOL) # type: ignore
return tool_class(plugin_config) if tool_class else None
return tool_class(plugin_config, chat_stream) if tool_class else None
def get_llm_available_tool_definitions():

View File

@@ -1,10 +1,13 @@
from abc import ABC, abstractmethod
from typing import Any, List, Optional, Tuple
from typing import Any, List, Optional, Tuple, TYPE_CHECKING
from rich.traceback import install
from src.common.logger import get_logger
from src.plugin_system.base.component_types import ComponentType, ToolInfo, ToolParamType
if TYPE_CHECKING:
from src.chat.message_receive.chat_stream import ChatStream
install(extra_lines=3)
logger = get_logger("base_tool")
@@ -29,8 +32,23 @@ class BaseTool(ABC):
available_for_llm: bool = False
"""是否可供LLM使用"""
def __init__(self, plugin_config: Optional[dict] = None):
def __init__(self, plugin_config: Optional[dict] = None, chat_stream: Optional["ChatStream"] = None):
"""初始化工具基类
Args:
plugin_config: 插件配置字典
chat_stream: 聊天流对象,用于获取聊天上下文信息
"""
self.plugin_config = plugin_config or {} # 直接存储插件配置字典
# =============================================================================
# 便捷属性 - 直接在初始化时获取常用聊天信息与BaseAction保持一致
# =============================================================================
# 获取聊天流对象
self.chat_stream = chat_stream
self.chat_id = self.chat_stream.stream_id if self.chat_stream else None
self.platform = getattr(self.chat_stream, "platform", None) if self.chat_stream else None
@classmethod
def get_tool_definition(cls) -> dict[str, Any]:

View File

@@ -223,7 +223,7 @@ class ToolExecutor:
function_args["llm_called"] = True # 标记为LLM调用
# 获取对应工具实例
tool_instance = tool_instance or get_tool_instance(function_name)
tool_instance = tool_instance or get_tool_instance(function_name, self.chat_stream)
if not tool_instance:
logger.warning(f"未知工具名称: {function_name}")
return None

View File

@@ -1,34 +0,0 @@
{
"manifest_version": 1,
"name": "MaiCurious插件 (MaiCurious Actions)",
"version": "1.0.0",
"description": "可以好奇",
"author": {
"name": "SengokuCola",
"url": "https://github.com/MaiM-with-u"
},
"license": "GPL-v3.0-or-later",
"host_application": {
"min_version": "0.11.0"
},
"homepage_url": "https://github.com/MaiM-with-u/maibot",
"repository_url": "https://github.com/MaiM-with-u/maibot",
"keywords": ["curious", "action", "built-in"],
"categories": ["Deep Think"],
"default_locale": "zh-CN",
"locales_path": "_locales",
"plugin_info": {
"is_built_in": true,
"plugin_type": "action_provider",
"components": [
{
"type": "action",
"name": "maicurious",
"description": "发送好奇"
}
]
}
}

View File

@@ -1,117 +0,0 @@
from typing import List, Tuple, Type, Any
# 导入新插件系统
from src.plugin_system import BasePlugin, register_plugin, ComponentInfo
from src.plugin_system.base.config_types import ConfigField
from src.person_info.person_info import Person
from src.plugin_system.base.base_tool import BaseTool, ToolParamType
# 导入依赖的系统组件
from src.common.logger import get_logger
from src.plugins.built_in.relation.relation import BuildRelationAction
from src.plugin_system.apis import llm_api
from src.plugin_system.base.base_action import BaseAction
from src.plugin_system.base.component_types import ActionActivationType
from src.plugin_system.apis import config_api
from src.plugin_system.apis import frequency_api
from src.plugin_system.apis import generator_api
from src.memory_system.questions import global_conflict_tracker
logger = get_logger("question_actions")
class CuriousAction(BaseAction):
    """Curiosity action: raise a question when someone disputes the bot's
    view or when participants' statements conflict with each other.

    Fix: the original docstrings described a "frequency adjustment action",
    a copy-paste leftover from another plugin.
    """

    # Always offered to the planner; may run alongside other actions.
    activation_type = ActionActivationType.ALWAYS
    parallel_action = True

    # Basic action info shown to the planner LLM.
    action_name = "make_question"
    action_description = "提出一个问题,当有人反驳你的观点,或其他人之间有观点冲突时使用"

    # Parameters the planner must fill in when choosing this action.
    action_parameters = {
        "question": "对存在疑问的信息提出一个问题,描述全面,完整的描述涉及的概念和问题",
    }

    # Usage hints (plain strings; the original f-prefixes had no placeholders).
    action_require = [
        "当聊天记录中的概念存在逻辑上的矛盾时使用",
        "当有人反对或否定你提出的信息时使用",
        "或当你对现有的概念或事物存在疑问时使用",
        "有人认为你的观点是错误的请选择question动作",
        "有人与你观点不一致请选择question动作",
    ]

    # Message types this action may emit.
    associated_types = ["text"]

    async def execute(self) -> Tuple[bool, str]:
        """Record the planner-supplied question into the conflict tracker."""
        try:
            # Refuse to pile up questions while another one is unresolved.
            if len(global_conflict_tracker.question_tracker_list) > 1:
                return False, "当前已有问题请先解答完再提问不要再使用make_question动作"
            question = self.action_data.get("question", "")
            if question:
                await global_conflict_tracker.record_conflict(conflict_content=question, start_following=True, chat_id=self.chat_id)
                logger.info(f"已存储问题到冲突追踪器: {question}")
                await self.store_action_info(
                    action_build_into_prompt=True,
                    action_prompt_display=f"你产生了一个问题:{question}",
                    action_done=True,
                )
            return True, f"问题{question}已记录,不要重复提问该问题"
        except Exception as e:
            error_msg = f"问题生成失败: {str(e)}"
            logger.error(f"{self.log_prefix} {error_msg}", exc_info=True)
            await self.send_text("问题生成失败")
            return False, error_msg
@register_plugin
class CuriousPlugin(BasePlugin):
    """Built-in curiosity plugin.

    Registers the CuriousAction component, which lets the bot raise a
    question when it detects conflicting opinions in a chat.

    Fix: the original docstring described a different plugin (the relation
    plugin with Reply/NoReply/Emoji components).

    Note: basic plugin metadata is read from _manifest.json first.
    """

    # Plugin identity / wiring.
    plugin_name: str = "maicurious"  # internal identifier
    enable_plugin: bool = True
    dependencies: list[str] = []  # other plugins this one depends on
    python_dependencies: list[str] = []  # required python packages
    config_file_name: str = "config.toml"

    # Human-readable descriptions of config sections.
    config_section_descriptions = {
        "plugin": "插件启用配置",
        "components": "核心组件启用配置",
    }

    # Config schema definition.
    config_schema: dict = {
        "plugin": {
            "enabled": ConfigField(type=bool, default=True, description="是否启用插件"),
            "config_version": ConfigField(type=str, default="3.0.0", description="配置文件版本"),
        }
    }

    def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
        """Return the component list this plugin provides."""
        return [(CuriousAction.get_action_info(), CuriousAction)]

View File

@@ -1,29 +1,77 @@
from typing import Tuple
import asyncio
from datetime import datetime
from src.common.logger import get_logger
from src.config.config import global_config
from src.chat.utils.prompt_builder import Prompt
from src.llm_models.payload_content.tool_option import ToolParamType
from src.plugin_system import BaseAction, ActionActivationType
from src.chat.utils.utils import cut_key_words
from src.memory_system.Memory_chest import global_memory_chest
from src.plugin_system.base.base_tool import BaseTool
from src.plugin_system.apis.message_api import get_messages_by_time_in_chat, build_readable_messages
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config
from typing import Any
logger = get_logger("memory")
def parse_datetime_to_timestamp(value: str) -> float:
    """Parse a date/time string in one of several common formats and return
    its POSIX timestamp (seconds, interpreted in local time).

    Accepted examples:
        - 2025-09-29
        - 2025-09-29 00:00:00
        - 2025/09/29 00:00
        - 2025-09-29T00:00:00

    Raises:
        ValueError: if the string matches none of the supported formats.
    """
    value = value.strip()
    # Ordered candidate formats; the first successful parse wins.
    fmts = [
        "%Y-%m-%d %H:%M:%S",
        "%Y-%m-%d %H:%M",
        "%Y/%m/%d %H:%M:%S",
        "%Y/%m/%d %H:%M",
        "%Y-%m-%d",
        "%Y/%m/%d",
        "%Y-%m-%dT%H:%M:%S",
        "%Y-%m-%dT%H:%M",
    ]
    last_err = None
    for fmt in fmts:
        try:
            return datetime.strptime(value, fmt).timestamp()
        except ValueError as e:  # narrowed from bare Exception: strptime raises ValueError
            last_err = e
    raise ValueError(f"无法解析时间: {value} ({last_err})")
def parse_time_range(time_range: str) -> tuple[float, float]:
    """Parse a "start - end" time-range string into two POSIX timestamps.

    Expected form: "YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS" (either side
    may use any format accepted by parse_datetime_to_timestamp).

    Raises:
        ValueError: if the separator is missing or start is later than end.
    """
    separator = " - "
    if separator not in time_range:
        raise ValueError("时间范围格式错误,应使用 ' - ' 分隔开始和结束时间")
    start_part, end_part = time_range.split(separator, 1)
    start_ts = parse_datetime_to_timestamp(start_part.strip())
    end_ts = parse_datetime_to_timestamp(end_part.strip())
    if start_ts > end_ts:
        raise ValueError("开始时间不能晚于结束时间")
    return start_ts, end_ts
class GetMemoryTool(BaseTool):
"""获取用户信息"""
name = "get_memory"
description = "在记忆中搜索,获取某个问题的答案"
description = "在记忆中搜索,获取某个问题的答案,可以指定搜索的时间范围或时间点"
parameters = [
("question", ToolParamType.STRING, "需要获取答案的问题", True, None)
("question", ToolParamType.STRING, "需要获取答案的问题", True, None),
("time_point", ToolParamType.STRING, "需要获取记忆的时间点格式为YYYY-MM-DD HH:MM:SS", False, None),
("time_range", ToolParamType.STRING, "需要获取记忆的时间范围格式为YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS", False, None)
]
available_for_llm = True
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
"""执行比较两个数的大小
"""执行记忆搜索
Args:
function_args: 工具参数
@@ -32,59 +80,166 @@ class GetMemoryTool(BaseTool):
dict: 工具执行结果
"""
question: str = function_args.get("question") # type: ignore
time_point: str = function_args.get("time_point") # type: ignore
time_range: str = function_args.get("time_range") # type: ignore
answer = await global_memory_chest.get_answer_by_question(question=question)
if not answer:
return {"content": f"问题:{question},没有找到相关记忆"}
# 检查是否指定了时间参数
has_time_params = bool(time_point or time_range)
return {"content": f"问题:{question},答案:{answer}"}
class GetMemoryAction(BaseAction):
"""关系动作 - 获取记忆"""
activation_type = ActionActivationType.LLM_JUDGE
parallel_action = True
# 动作基本信息
action_name = "get_memory"
action_description = (
"在记忆中搜寻某个问题的答案"
)
# 动作参数定义
action_parameters = {
"question": "需要搜寻或回答的问题",
}
# 动作使用场景
action_require = [
"在记忆中搜寻某个问题的答案",
"有你不了解的概念",
"有人提问关于过去的事情"
"你需要根据记忆回答某个问题",
]
# 关联类型
associated_types = ["text"]
async def execute(self) -> Tuple[bool, str]:
"""执行关系动作"""
if has_time_params and not self.chat_id:
return {"content": f"问题:{question}无法获取聊天记录缺少chat_id"}
question = self.action_data.get("question", "")
answer = await global_memory_chest.get_answer_by_question(self.chat_id, question)
if not answer:
await self.store_action_info(
action_build_into_prompt=True,
action_prompt_display=f"你回忆了有关问题:{question}的记忆,但是没有找到相关记忆",
action_done=True,
)
return False, f"问题:{question},没有找到相关记忆"
# 创建并行任务
tasks = []
await self.store_action_info(
action_build_into_prompt=True,
action_prompt_display=f"你回忆了有关问题:{question}的记忆,答案是:{answer}",
action_done=True,
# 原任务:从记忆仓库获取答案
memory_task = asyncio.create_task(
global_memory_chest.get_answer_by_question(question=question)
)
tasks.append(("memory", memory_task))
return True, f"成功获取记忆: {answer}"
# 新任务:从聊天记录获取答案(如果指定了时间参数)
chat_task = None
if has_time_params:
chat_task = asyncio.create_task(
self._get_answer_from_chat_history(question, time_point, time_range)
)
tasks.append(("chat", chat_task))
# 等待所有任务完成
results = {}
for task_name, task in tasks:
try:
results[task_name] = await task
except Exception as e:
logger.error(f"任务 {task_name} 执行失败: {e}")
results[task_name] = None
# 处理结果
memory_answer = results.get("memory")
chat_answer = results.get("chat")
# 构建返回内容
content_parts = []
if memory_answer:
content_parts.append(f"对问题'{question}',你回忆的信息是:{memory_answer}")
if chat_answer:
content_parts.append(f"对问题'{question}',基于聊天记录的回答:{chat_answer}")
elif has_time_params:
if time_point:
content_parts.append(f"{time_point} 的时间点,你没有参与聊天")
elif time_range:
content_parts.append(f"{time_range} 的时间范围内,你没有参与聊天")
if content_parts:
retrieval_content = f"问题:{question}" + "\n".join(content_parts)
return {"content": retrieval_content}
else:
return {"content": ""}
async def _get_answer_from_chat_history(self, question: str, time_point: "str | None" = None, time_range: "str | None" = None) -> str:
    """Answer a question from this chat's message history.

    Either a single time point (fetches ~25 messages on each side) or an
    explicit time range (up to 50 most recent messages inside it) selects a
    history window; an LLM then extracts the answer from that window.

    Fix: stray print() debug statements now go through logger.debug, and the
    optional parameters are annotated as optional instead of "str = None".

    Args:
        question: the question to answer
        time_point: "YYYY-MM-DD HH:MM:SS" style point in time, optional
        time_range: "start - end" range string, optional

    Returns:
        str: the answer text, a user-facing notice, or "" when nothing
        useful was found.
    """
    try:
        logger.debug(f"time_point: {time_point}, time_range: {time_range}")
        # A range whose two endpoints are identical is really a time point.
        if time_range and not time_point:
            try:
                start_timestamp, end_timestamp = parse_time_range(time_range)
                if start_timestamp == end_timestamp:
                    time_point = time_range.split(" - ")[0].strip()
                    time_range = None
                    logger.debug(f"time_range两个值相同按照time_point处理: {time_point}")
            except Exception as e:
                logger.warning(f"解析time_range失败: {e}")
        if time_point:
            # Point lookup: take 25 messages before and 25 after the moment.
            target_timestamp = parse_datetime_to_timestamp(time_point)
            messages_before = get_messages_by_time_in_chat(
                chat_id=self.chat_id,
                start_time=0,
                end_time=target_timestamp,
                limit=25,
                limit_mode="latest"
            )
            messages_after = get_messages_by_time_in_chat(
                chat_id=self.chat_id,
                start_time=target_timestamp,
                end_time=float('inf'),
                limit=25,
                limit_mode="earliest"
            )
            messages = messages_before + messages_after
        elif time_range:
            # Range lookup: at most the 50 latest messages inside the range.
            start_timestamp, end_timestamp = parse_time_range(time_range)
            messages = get_messages_by_time_in_chat(
                chat_id=self.chat_id,
                start_time=start_timestamp,
                end_time=end_timestamp,
                limit=50,
                limit_mode="latest"
            )
        else:
            return "未指定时间参数"
        if not messages:
            return "没有找到相关聊天记录"
        # Render the selected messages into readable text for the LLM.
        chat_content = build_readable_messages(messages, timestamp_mode="relative")
        if not chat_content.strip():
            return "聊天记录为空"
        try:
            llm_request = LLMRequest(
                model_set=model_config.model_task_config.utils_small,
                request_type="chat_history_analysis"
            )
            analysis_prompt = f"""请根据以下聊天记录内容,回答用户的问题。请输出一段平文本,不要有特殊格式。
聊天记录:
{chat_content}
用户问题:{question}
请仔细分析聊天记录,提取与问题相关的信息,并给出准确的答案。如果聊天记录中没有相关信息,无法回答问题,输出"无有效信息"即可,不要输出其他内容。
答案:"""
            logger.debug(f"analysis_prompt: {analysis_prompt}")
            response, (reasoning, model_name, tool_calls) = await llm_request.generate_response_async(
                prompt=analysis_prompt,
                temperature=0.3,
                max_tokens=256
            )
            logger.debug(f"response: {response}")
            # The model signals "nothing relevant" with this exact phrase.
            if "无有效信息" in response:
                return ""
            return response
        except Exception as llm_error:
            logger.error(f"LLM分析聊天记录失败: {llm_error}")
            # Fall back to a trimmed excerpt of the raw chat content.
            if len(chat_content) > 300:
                chat_content = chat_content[:300] + "..."
            return chat_content
    except Exception as e:
        logger.error(f"从聊天记录获取答案失败: {e}")
        return ""

View File

@@ -7,7 +7,7 @@ from src.plugin_system.base.config_types import ConfigField
# 导入依赖的系统组件
from src.common.logger import get_logger
from src.plugins.built_in.memory.build_memory import GetMemoryAction, GetMemoryTool
from src.plugins.built_in.memory.build_memory import GetMemoryTool
logger = get_logger("memory_build")
@@ -48,7 +48,6 @@ class MemoryBuildPlugin(BasePlugin):
# --- 根据配置注册组件 ---
components = []
# components.append((GetMemoryAction.get_action_info(), GetMemoryAction))
components.append((GetMemoryTool.get_tool_info(), GetMemoryTool))
return components

View File

@@ -78,7 +78,7 @@ class RelationActionsPlugin(BasePlugin):
config_schema: dict = {
"plugin": {
"enabled": ConfigField(type=bool, default=False, description="是否启用插件"),
"config_version": ConfigField(type=str, default="1.0.1", description="配置文件版本"),
"config_version": ConfigField(type=str, default="1.0.2", description="配置文件版本"),
},
"components": {
"relation_max_memory_num": ConfigField(type=int, default=10, description="关系记忆最大数量"),
@@ -90,7 +90,7 @@ class RelationActionsPlugin(BasePlugin):
# --- 根据配置注册组件 ---
components = []
components.append((BuildRelationAction.get_action_info(), BuildRelationAction))
components.append((GetPersonInfoTool.get_tool_info(), GetPersonInfoTool))
# components.append((BuildRelationAction.get_action_info(), BuildRelationAction))
# components.append((GetPersonInfoTool.get_tool_info(), GetPersonInfoTool))
return components

View File

@@ -1,5 +1,5 @@
[inner]
version = "6.18.3"
version = "6.19.2"
#----以下是给开发人员阅读的,如果你只是部署了麦麦,不需要阅读----
#如果你想要修改配置文件请递增version的值
@@ -14,6 +14,9 @@ version = "6.18.3"
[bot]
platform = "qq"
qq_account = "1145141919810" # 麦麦的QQ账号
platforms = ["wx:114514","xx:1919810"] # 麦麦的其他平台账号
nickname = "麦麦" # 麦麦的昵称
alias_names = ["麦叠", "牢麦"] # 麦麦的别名
@@ -44,10 +47,20 @@ private_plan_style = """
2.如果相同的内容已经被执行,请不要重复执行
3.某句话如果已经被回复过,不要重复回复"""
# 状态,可以理解为人格多样性,会随机替换人格
states = [
"是一个女大学生,喜欢上网聊天,会刷小红书。" ,
"是一个大二心理学生,会刷贴吧和中国知网。" ,
"是一个赛博网友,最近很想吐槽人。"
]
# 替换概率每次构建人格时替换personality的概率0.0-1.0
state_probability = 0.3
[expression]
# 表达方式模式(此选项暂未使用)
mode = "context"
# 可选:llm模式context上下文模式
# 表达方式模式
mode = "classic"
# 可选:classic经典模式exp_model 表达模型模式,这个模式需要一定时间学习才会有比较好的效果
# 表达学习配置
learning_list = [ # 表达学习配置列表,支持按聊天流配置
@@ -79,6 +92,9 @@ max_context_size = 30 # 上下文长度
auto_chat_value = 1 # 自动聊天,越小,麦麦主动聊天的概率越低
planner_smooth = 5 #规划器平滑增大数值会减小planner负荷略微降低反应速度推荐2-80为关闭必须大于等于0
enable_talk_value_rules = true # 是否启用动态发言频率规则
enable_auto_chat_value_rules = false # 是否启用动态自动聊天频率规则
# 动态发言频率规则:按时段/按chat_id调整 talk_value优先匹配具体chat再匹配全局
# 推荐格式(对象数组):{ target="platform:id:type" 或 "", time="HH:MM-HH:MM", value=0.5 }
# 说明:
@@ -205,6 +221,8 @@ library_log_levels = { aiohttp = "WARNING"} # 设置特定库的日志级别
[debug]
show_prompt = false # 是否显示prompt
show_replyer_prompt = false # 是否显示回复器prompt
show_replyer_reasoning = false # 是否显示回复器推理
[maim_message]
auth_token = [] # 认证令牌用于API验证为空则不启用验证

View File

@@ -1,5 +1,5 @@
[inner]
version = "1.7.4"
version = "1.7.7"
# 配置文件版本号迭代规则同bot_config.toml
@@ -35,9 +35,9 @@ name = "SiliconFlow"
base_url = "https://api.siliconflow.cn/v1"
api_key = "your-siliconflow-api-key"
client_type = "openai"
max_retry = 2
max_retry = 3
timeout = 120
retry_interval = 10
retry_interval = 5
[[models]] # 模型(可以配置多个)
@@ -49,21 +49,31 @@ price_out = 8.0 # 输出价格用于API调用统计
#force_stream_mode = true # 强制流式输出模式若模型不支持非流式输出请取消该注释启用强制流式输出若无该字段默认值为false
[[models]]
model_identifier = "deepseek-ai/DeepSeek-V3"
name = "siliconflow-deepseek-v3"
model_identifier = "deepseek-ai/DeepSeek-V3.2-Exp"
name = "siliconflow-deepseek-v3.2"
api_provider = "SiliconFlow"
price_in = 2.0
price_out = 8.0
[[models]]
model_identifier = "Qwen/Qwen3-8B"
name = "qwen3-8b"
api_provider = "SiliconFlow"
price_in = 0
price_out = 0
price_out = 3.0
[models.extra_params] # 可选的额外参数配置
enable_thinking = false # 不启用思考
[[models]]
model_identifier = "deepseek-ai/DeepSeek-V3.2-Exp"
name = "siliconflow-deepseek-v3.2-think"
api_provider = "SiliconFlow"
price_in = 2.0
price_out = 3.0
[models.extra_params] # 可选的额外参数配置
enable_thinking = true # 启用思考
[[models]]
model_identifier = "deepseek-ai/DeepSeek-R1"
name = "siliconflow-deepseek-r1"
api_provider = "SiliconFlow"
price_in = 4.0
price_out = 16.0
[[models]]
model_identifier = "Qwen/Qwen3-30B-A3B-Instruct-2507"
name = "qwen3-30b"
@@ -72,8 +82,8 @@ price_in = 0.7
price_out = 2.8
[[models]]
model_identifier = "Qwen/Qwen2.5-VL-72B-Instruct"
name = "qwen2.5-vl-72b"
model_identifier = "Qwen/Qwen3-VL-30B-A3B-Instruct"
name = "qwen3-vl-30"
api_provider = "SiliconFlow"
price_in = 4.13
price_out = 4.13
@@ -94,12 +104,12 @@ price_out = 0
[model_task_config.utils] # 在麦麦的一些组件中使用的模型,例如表情包模块,取名模块,关系模块,麦麦的情绪变化等,是麦麦必须的模型
model_list = ["siliconflow-deepseek-v3"] # 使用的模型列表,每个子项对应上面的模型名称(name)
model_list = ["siliconflow-deepseek-v3.2"] # 使用的模型列表,每个子项对应上面的模型名称(name)
temperature = 0.2 # 模型温度新V3建议0.1-0.3
max_tokens = 2048 # 最大输出token数
[model_task_config.utils_small] # 在麦麦的一些组件中使用的小模型,消耗量较大,建议使用速度较快的小模型
model_list = ["qwen3-8b","qwen3-30b"]
model_list = ["qwen3-30b"]
temperature = 0.7
max_tokens = 2048
@@ -109,18 +119,18 @@ temperature = 0.7
max_tokens = 800
[model_task_config.replyer] # 首要回复模型,还用于表达器和表达方式学习
model_list = ["siliconflow-deepseek-v3"]
model_list = ["siliconflow-deepseek-v3.2-think","siliconflow-deepseek-r1","siliconflow-deepseek-v3.2"]
temperature = 0.3 # 模型温度新V3建议0.1-0.3
max_tokens = 800
max_tokens = 2048
[model_task_config.planner] #决策:负责决定麦麦该什么时候回复的模型
model_list = ["siliconflow-deepseek-v3"]
model_list = ["siliconflow-deepseek-v3.2"]
temperature = 0.3
max_tokens = 800
[model_task_config.vlm] # 图像识别模型
model_list = ["qwen2.5-vl-72b"]
max_tokens = 800
model_list = ["qwen3-vl-30"]
max_tokens = 256
[model_task_config.voice] # 语音识别模型
model_list = ["sensevoice-small"]
@@ -132,16 +142,16 @@ model_list = ["bge-m3"]
#------------LPMM知识库模型------------
[model_task_config.lpmm_entity_extract] # 实体提取模型
model_list = ["siliconflow-deepseek-v3"]
model_list = ["siliconflow-deepseek-v3.2"]
temperature = 0.2
max_tokens = 800
[model_task_config.lpmm_rdf_build] # RDF构建模型
model_list = ["siliconflow-deepseek-v3"]
model_list = ["siliconflow-deepseek-v3.2"]
temperature = 0.2
max_tokens = 800
[model_task_config.lpmm_qa] # 问答模型
model_list = ["qwen3-30b"]
model_list = ["siliconflow-deepseek-v3.2"]
temperature = 0.7
max_tokens = 800

391
test_style_learner_db.py Normal file
View File

@@ -0,0 +1,391 @@
"""
StyleLearner 数据库测试脚本
使用数据库中的expression数据测试style_learner功能
"""
import os
import sys
from typing import List, Dict, Tuple
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_recall_fscore_support
# 添加项目根目录到Python路径
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from src.common.database.database_model import Expression, db
from src.express.style_learner import StyleLearnerManager
from src.common.logger import get_logger
logger = get_logger("style_learner_test")
class StyleLearnerDatabaseTest:
    """Test harness for StyleLearner, driven by Expression rows from the database.

    Pipeline: load -> preprocess -> split -> train -> test -> analyze -> report.
    Results are accumulated in ``self.test_results`` across the steps.
    """

    def __init__(self, random_state: int = 42):
        # Seed forwarded to sklearn's train_test_split for reproducible sampling.
        self.random_state = random_state
        # NOTE(review): assumes StyleLearnerManager persists models under this
        # path — confirm against src/express/style_learner.
        self.manager = StyleLearnerManager(model_save_path="data/test_style_models")
        # Aggregated results, filled in by the train/test/analyze steps below.
        self.test_results = {
            "total_samples": 0,
            "train_samples": 0,
            "test_samples": 0,
            "unique_styles": 0,
            "unique_chat_ids": 0,
            "accuracy": 0.0,
            "precision": 0.0,
            "recall": 0.0,
            "f1_score": 0.0,
            "predictions": [],
            "ground_truth": [],
            "model_save_success": False,
            "model_save_path": self.manager.model_save_path
        }

    def load_data_from_database(self) -> List[Dict]:
        """
        Load expression data from the database.

        Returns:
            List[Dict]: records with up_content, style, chat_id (plus
            last_active_time, context, situation); empty list on any failure.
        """
        try:
            # Reuse an already-open connection if one exists.
            db.connect(reuse_if_open=True)
            # Only "style"-type rows with all three key fields non-NULL.
            expressions = Expression.select().where(
                (Expression.up_content.is_null(False)) &
                (Expression.style.is_null(False)) &
                (Expression.chat_id.is_null(False)) &
                (Expression.type == "style")
            )
            data = []
            for expr in expressions:
                # Second guard also drops empty strings, which the SQL
                # NULL check above does not catch.
                if expr.up_content and expr.style and expr.chat_id:
                    data.append({
                        "up_content": expr.up_content,
                        "style": expr.style,
                        "chat_id": expr.chat_id,
                        "last_active_time": expr.last_active_time,
                        "context": expr.context,
                        "situation": expr.situation
                    })
            logger.info(f"从数据库加载了 {len(data)} 条expression数据")
            return data
        except Exception as e:
            logger.error(f"从数据库加载数据失败: {e}")
            return []

    def preprocess_data(self, data: List[Dict]) -> List[Dict]:
        """
        Preprocess raw records.

        Args:
            data: raw records from load_data_from_database.

        Returns:
            List[Dict]: records whose stripped up_content and style are both
            at least 2 characters long.
        """
        # Drop entries that are empty or too short to be meaningful.
        filtered_data = []
        for item in data:
            up_content = item["up_content"].strip()
            style = item["style"].strip()
            if len(up_content) >= 2 and len(style) >= 2:
                filtered_data.append({
                    "up_content": up_content,
                    "style": style,
                    "chat_id": item["chat_id"],
                    "last_active_time": item["last_active_time"],
                    "context": item["context"],
                    "situation": item["situation"]
                })
        logger.info(f"预处理后剩余 {len(filtered_data)} 条数据")
        return filtered_data

    def split_data(self, data: List[Dict]) -> Tuple[List[Dict], List[Dict]]:
        """
        Split into training and test sets.

        The training set is ALL of the data; the test set is a random 5%
        sample of that same data.

        NOTE(review): because the test samples are also trained on, the
        metrics below measure memorization rather than generalization —
        confirm this is intentional.

        Args:
            data: preprocessed records.

        Returns:
            Tuple[List[Dict], List[Dict]]: (train set, test set)
        """
        # Training uses everything.
        train_data = data.copy()
        # Test set is a random 5% sample drawn from the training data.
        test_size = 0.05  # 5%
        test_data = train_test_split(
            train_data, test_size=test_size, random_state=self.random_state
        )[1]  # keep only the test portion of the split
        logger.info(f"数据分割完成: 训练集 {len(train_data)} 条, 测试集 {len(test_data)} 条")
        logger.info(f"训练集使用所有数据,测试集从训练集中随机选择 {test_size*100:.1f}%")
        return train_data, test_data

    def train_model(self, train_data: List[Dict]) -> None:
        """
        Train the models by feeding each record to the manager, then persist
        all models to disk.

        Args:
            train_data: preprocessed training records.
        """
        logger.info("开始训练模型...")
        # Track distinct chat rooms and styles encountered during training.
        chat_ids = set()
        styles = set()
        for item in train_data:
            chat_id = item["chat_id"]
            up_content = item["up_content"]
            style = item["style"]
            chat_ids.add(chat_id)
            styles.add(style)
            # One (input -> style) mapping per record; a failure is logged
            # but does not abort the run.
            success = self.manager.learn_mapping(chat_id, up_content, style)
            if not success:
                logger.warning(f"学习失败: {chat_id} - {up_content} -> {style}")
        self.test_results["train_samples"] = len(train_data)
        self.test_results["unique_styles"] = len(styles)
        self.test_results["unique_chat_ids"] = len(chat_ids)
        logger.info(f"训练完成: {len(train_data)} 个样本, {len(styles)} 种风格, {len(chat_ids)} 个聊天室")
        # Persist the trained models so later steps / tools can inspect them.
        logger.info("开始保存训练好的模型...")
        save_success = self.manager.save_all_models()
        self.test_results["model_save_success"] = save_success
        if save_success:
            logger.info(f"所有模型已成功保存到: {self.manager.model_save_path}")
            print(f"✅ 模型已保存到: {self.manager.model_save_path}")
        else:
            logger.warning("部分模型保存失败")
            print(f"⚠️ 模型保存失败,请检查路径: {self.manager.model_save_path}")

    def test_model(self, test_data: List[Dict]) -> None:
        """
        Evaluate the trained models on the test set and record accuracy,
        precision, recall and F1 in ``self.test_results``.

        Args:
            test_data: test records (see split_data).
        """
        logger.info("开始测试模型...")
        predictions = []
        ground_truth = []
        correct_predictions = 0
        for item in test_data:
            chat_id = item["chat_id"]
            up_content = item["up_content"]
            true_style = item["style"]
            # NOTE(review): assumes predict_style returns (style_or_None,
            # score_dict) — confirm against the manager's API.
            predicted_style, scores = self.manager.predict_style(chat_id, up_content, top_k=1)
            predictions.append(predicted_style)
            ground_truth.append(true_style)
            if predicted_style == true_style:
                correct_predictions += 1
            # Keep per-sample details for later analysis / the report file.
            self.test_results["predictions"].append({
                "chat_id": chat_id,
                "up_content": up_content,
                "true_style": true_style,
                "predicted_style": predicted_style,
                "scores": scores
            })
        # Plain accuracy over all test samples (None predictions count wrong).
        accuracy = correct_predictions / len(test_data) if test_data else 0
        # sklearn metrics cannot handle None labels, so filter them out
        # while keeping predictions/ground truth aligned.
        valid_predictions = [p for p in predictions if p is not None]
        valid_ground_truth = [gt for p, gt in zip(predictions, ground_truth, strict=False) if p is not None]
        if valid_predictions:
            precision, recall, f1, _ = precision_recall_fscore_support(
                valid_ground_truth, valid_predictions, average='weighted', zero_division=0
            )
        else:
            precision = recall = f1 = 0.0
        self.test_results["test_samples"] = len(test_data)
        self.test_results["accuracy"] = accuracy
        self.test_results["precision"] = precision
        self.test_results["recall"] = recall
        self.test_results["f1_score"] = f1
        logger.info(f"测试完成: 准确率 {accuracy:.4f}, 精确率 {precision:.4f}, 召回率 {recall:.4f}, F1分数 {f1:.4f}")

    def analyze_results(self) -> None:
        """Print a summary: dataset stats, overall metrics, model-save status,
        per-chat accuracy and the ground-truth style distribution."""
        logger.info("=== 测试结果分析 ===")
        print("\n📊 数据统计:")
        print(f" 总样本数: {self.test_results['total_samples']}")
        print(f" 训练样本数: {self.test_results['train_samples']}")
        print(f" 测试样本数: {self.test_results['test_samples']}")
        print(f" 唯一风格数: {self.test_results['unique_styles']}")
        print(f" 唯一聊天室数: {self.test_results['unique_chat_ids']}")
        print("\n🎯 模型性能:")
        print(f" 准确率: {self.test_results['accuracy']:.4f}")
        print(f" 精确率: {self.test_results['precision']:.4f}")
        print(f" 召回率: {self.test_results['recall']:.4f}")
        print(f" F1分数: {self.test_results['f1_score']:.4f}")
        print("\n💾 模型保存:")
        save_status = "成功" if self.test_results['model_save_success'] else "失败"
        print(f" 保存状态: {save_status}")
        print(f" 保存路径: {self.test_results['model_save_path']}")
        # Per-chat-room accuracy, derived from the recorded predictions.
        chat_performance = {}
        for pred in self.test_results["predictions"]:
            chat_id = pred["chat_id"]
            if chat_id not in chat_performance:
                chat_performance[chat_id] = {"correct": 0, "total": 0}
            chat_performance[chat_id]["total"] += 1
            if pred["predicted_style"] == pred["true_style"]:
                chat_performance[chat_id]["correct"] += 1
        print("\n📈 各聊天室性能:")
        for chat_id, perf in chat_performance.items():
            accuracy = perf["correct"] / perf["total"] if perf["total"] > 0 else 0
            print(f" {chat_id}: {accuracy:.4f} ({perf['correct']}/{perf['total']})")
        # Distribution of ground-truth styles in the test set (top 10).
        style_counts = {}
        for pred in self.test_results["predictions"]:
            style = pred["true_style"]
            style_counts[style] = style_counts.get(style, 0) + 1
        print("\n🎨 风格分布 (前10个):")
        sorted_styles = sorted(style_counts.items(), key=lambda x: x[1], reverse=True)
        for style, count in sorted_styles[:10]:
            print(f" {style}: {count}")

    def show_sample_predictions(self, num_samples: int = 10) -> None:
        """Print the first ``num_samples`` recorded predictions, each with a
        status marker and (up to) its top 3 candidate scores."""
        print(f"\n🔍 样本预测结果 (前{num_samples}个):")
        for i, pred in enumerate(self.test_results["predictions"][:num_samples]):
            # NOTE(review): both branches are empty strings — the pass/fail
            # marker characters appear to have been lost; confirm intended symbols.
            status = "" if pred["predicted_style"] == pred["true_style"] else ""
            print(f"\n {i+1}. {status}")
            print(f" 聊天室: {pred['chat_id']}")
            print(f" 输入内容: {pred['up_content']}")
            print(f" 真实风格: {pred['true_style']}")
            print(f" 预测风格: {pred['predicted_style']}")
            if pred["scores"]:
                # Takes the first 3 entries in dict order — presumably the
                # predictor returns scores already sorted; verify.
                top_scores = dict(list(pred["scores"].items())[:3])
                print(f" 分数: {top_scores}")

    def save_results(self, output_file: str = "style_learner_test_results.txt") -> None:
        """Write the aggregated results (stats, metrics, save status and every
        individual prediction) to a UTF-8 text report.

        Args:
            output_file: destination path for the report.
        """
        try:
            with open(output_file, "w", encoding="utf-8") as f:
                f.write("StyleLearner 数据库测试结果\n")
                f.write("=" * 50 + "\n\n")
                f.write("数据统计:\n")
                f.write(f" 总样本数: {self.test_results['total_samples']}\n")
                f.write(f" 训练样本数: {self.test_results['train_samples']}\n")
                f.write(f" 测试样本数: {self.test_results['test_samples']}\n")
                f.write(f" 唯一风格数: {self.test_results['unique_styles']}\n")
                f.write(f" 唯一聊天室数: {self.test_results['unique_chat_ids']}\n\n")
                f.write("模型性能:\n")
                f.write(f" 准确率: {self.test_results['accuracy']:.4f}\n")
                f.write(f" 精确率: {self.test_results['precision']:.4f}\n")
                f.write(f" 召回率: {self.test_results['recall']:.4f}\n")
                f.write(f" F1分数: {self.test_results['f1_score']:.4f}\n\n")
                f.write("模型保存:\n")
                save_status = "成功" if self.test_results['model_save_success'] else "失败"
                f.write(f" 保存状态: {save_status}\n")
                f.write(f" 保存路径: {self.test_results['model_save_path']}\n\n")
                f.write("详细预测结果:\n")
                for i, pred in enumerate(self.test_results["predictions"]):
                    # NOTE(review): both branches are empty strings — marker
                    # characters appear lost (see show_sample_predictions).
                    status = "" if pred["predicted_style"] == pred["true_style"] else ""
                    f.write(f"{i+1}. {status} [{pred['chat_id']}] {pred['up_content']} -> {pred['predicted_style']} (真实: {pred['true_style']})\n")
            logger.info(f"测试结果已保存到 {output_file}")
        except Exception as e:
            logger.error(f"保存测试结果失败: {e}")

    def run_test(self) -> None:
        """Run the full test pipeline end to end."""
        logger.info("开始StyleLearner数据库测试...")
        # 1. Load raw data from the database.
        raw_data = self.load_data_from_database()
        if not raw_data:
            logger.error("没有加载到数据,测试终止")
            return
        # 2. Clean and filter it.
        processed_data = self.preprocess_data(raw_data)
        if not processed_data:
            logger.error("预处理后没有数据,测试终止")
            return
        self.test_results["total_samples"] = len(processed_data)
        # 3. Split into train/test sets.
        train_data, test_data = self.split_data(processed_data)
        # 4. Train (and persist) the models.
        self.train_model(train_data)
        # 5. Evaluate on the sampled test set.
        self.test_model(test_data)
        # 6. Print the analysis.
        self.analyze_results()
        # 7. Show a few individual predictions.
        self.show_sample_predictions(10)
        # 8. Write the report file.
        self.save_results()
        logger.info("StyleLearner数据库测试完成!")
def main():
    """Entry point: build the database-backed test harness and run it."""
    print("StyleLearner 数据库测试脚本")
    print("=" * 50)
    runner = StyleLearnerDatabaseTest(random_state=42)
    runner.run_test()


if __name__ == "__main__":
    main()

76
view_pkl.py Normal file
View File

@@ -0,0 +1,76 @@
#!/usr/bin/env python3
"""
查看 .pkl 文件内容的工具脚本
"""
import pickle
import sys
import os
from pprint import pprint
def view_pkl_file(file_path):
    """Pretty-print the contents of a pickle file for manual inspection.

    Args:
        file_path: path to the .pkl file to inspect.

    Prints a type summary and the (depth-limited) content. When the data
    looks like an expressor model — a dict with data['nb']['token_counts'] —
    also prints a per-style breakdown of the most frequent tokens.

    SECURITY: pickle.load executes arbitrary code embedded in the file;
    only run this tool on pickle files from a trusted source.
    """
    if not os.path.exists(file_path):
        print(f"❌ 文件不存在: {file_path}")
        return
    try:
        with open(file_path, 'rb') as f:
            data = pickle.load(f)
        print(f"📁 文件: {file_path}")
        print(f"📊 数据类型: {type(data)}")
        print("=" * 50)
        if isinstance(data, dict):
            print("🔑 字典键:")
            # Iterate the dict directly; calling .keys() is redundant.
            for key in data:
                print(f" - {key}: {type(data[key])}")
            print()
            print("📋 详细内容:")
            pprint(data, width=120, depth=10)
        elif isinstance(data, list):
            print(f"📝 列表长度: {len(data)}")
            if data:
                print(f"📊 第一个元素类型: {type(data[0])}")
                print("📋 前几个元素:")
                for i, item in enumerate(data[:3]):
                    print(f" [{i}]: {item}")
        else:
            print("📋 内容:")
            pprint(data, width=120, depth=10)
        # Expressor models get an extra token-frequency section.
        if isinstance(data, dict) and 'nb' in data and 'token_counts' in data['nb']:
            print("\n" + "="*50)
            print("🔍 详细词汇统计 (token_counts):")
            token_counts = data['nb']['token_counts']
            for style_id, tokens in token_counts.items():
                print(f"\n📝 {style_id}:")
                if tokens:
                    # Show the 10 most frequent tokens for this style.
                    sorted_tokens = sorted(tokens.items(), key=lambda x: x[1], reverse=True)
                    for word, count in sorted_tokens[:10]:
                        print(f" '{word}': {count}")
                    if len(sorted_tokens) > 10:
                        print(f" ... 还有 {len(sorted_tokens) - 10} 个词")
                else:
                    print(" (无词汇数据)")
    except Exception as e:
        print(f"❌ 读取文件失败: {e}")
def main():
    """CLI entry: expects exactly one argument, the path of the pkl file."""
    if len(sys.argv) != 2:
        print("用法: python view_pkl.py <pkl文件路径>")
        print("示例: python view_pkl.py data/test_style_models/chat_001_style_model.pkl")
        return
    view_pkl_file(sys.argv[1])


if __name__ == "__main__":
    main()

63
view_tokens.py Normal file
View File

@@ -0,0 +1,63 @@
#!/usr/bin/env python3
"""
专门查看 expressor.pkl 文件中 token_counts 的脚本
"""
import pickle
import sys
import os
def view_token_counts(file_path):
    """Print per-style token frequency statistics from an expressor.pkl.

    The file must unpickle to a dict containing data['nb']['token_counts'];
    human-readable style labels are looked up in data['candidates'] when
    present. Everything is printed to stdout.
    """
    if not os.path.exists(file_path):
        print(f"❌ 文件不存在: {file_path}")
        return
    try:
        with open(file_path, 'rb') as fh:
            data = pickle.load(fh)
        print(f"📁 文件: {file_path}")
        print("=" * 60)
        if 'nb' not in data or 'token_counts' not in data['nb']:
            print("❌ 这不是一个 expressor 模型文件")
            return
        freq_by_style = data['nb']['token_counts']
        style_labels = data.get('candidates', {})
        print(f"🎯 找到 {len(freq_by_style)} 个风格")
        print("=" * 60)
        for style_id, counts in freq_by_style.items():
            label = style_labels.get(style_id, "未知风格")
            print(f"\n📝 {style_id}: {label}")
            print(f"📊 词汇数量: {len(counts)}")
            if counts:
                # Most frequent tokens first.
                ranked = sorted(counts.items(), key=lambda kv: kv[1], reverse=True)
                print("🔤 词汇统计 (按频率排序):")
                for rank, (word, count) in enumerate(ranked, start=1):
                    print(f" {rank:2d}. '{word}': {count}")
            else:
                print(" (无词汇数据)")
            print("-" * 40)
    except Exception as e:
        print(f"❌ 读取文件失败: {e}")
def main():
    """CLI entry: expects exactly one argument, the expressor.pkl path."""
    if len(sys.argv) != 2:
        print("用法: python view_tokens.py <expressor.pkl文件路径>")
        print("示例: python view_tokens.py data/test_style_models/chat_001_expressor.pkl")
        return
    view_token_counts(sys.argv[1])


if __name__ == "__main__":
    main()