Merge branch 'MaiM-with-u:dev' into dev

This commit is contained in:
infinitycat
2025-10-09 14:44:55 +08:00
committed by GitHub
151 changed files with 10610 additions and 14541 deletions

1
.envrc
View File

@@ -1 +0,0 @@
use flake

3
.gitattributes vendored
View File

@@ -1,3 +1,2 @@
*.bat text eol=crlf
*.cmd text eol=crlf
MaiLauncher.bat text eol=crlf working-tree-encoding=GBK
*.cmd text eol=crlf

12
.gitignore vendored
View File

@@ -20,6 +20,8 @@ MaiBot-Napcat-Adapter
nonebot-maibot-adapter/
MaiMBot-LPMM
*.zip
run_bot.bat
run_na.bat
run.bat
log_debug/
run_amds.bat
@@ -41,16 +43,13 @@ config/bot_config.toml
config/bot_config.toml.bak
config/lpmm_config.toml
config/lpmm_config.toml.bak
src/mais4u/config/s4u_config.toml
src/mais4u/config/old
template/compare/bot_config_template.toml
template/compare/model_config_template.toml
(测试版)麦麦生成人格.bat
(临时版)麦麦开始学习.bat
src/plugins/utils/statistic.py
CLAUDE.md
s4u.s4u
s4u.s4u1
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
@@ -321,9 +320,14 @@ run_pet.bat
/plugins/*
!/plugins
!/plugins/hello_world_plugin
!/plugins/emoji_manage_plugin
!/plugins/take_picture_plugin
!/plugins/deep_think
!/plugins/MaiFrequencyControl
!/plugins/__init__.py
config.toml
interested_rates.txt
MaiBot.code-workspace
*.lock

View File

@@ -26,12 +26,10 @@
**🍔MaiCore 是一个基于大语言模型的可交互智能体**
- 💭 **智能对话系统**:基于 LLM 的自然语言交互,聊天时机控制。
- 🔌 **强大插件系统**全面重构的插件架构更多API。
- 🤔 **实时思维系统**:模拟人类思考过程。
- 🧠 **表达学习功能**:学习群友的说话风格和表达方式
- 💝 **情感表达系统**:情绪系统和表情包系统。
- 🧠 **持久记忆系统**基于图的长期记忆存储
- 🔄 **动态人格系统**:自适应的性格特征和表达方式。
- 🔌 **强大插件系统**提供API和事件系统可编写强大插件
<div style="text-align: center">
<a href="https://www.bilibili.com/video/BV1amAneGE3P" target="_blank">
@@ -46,7 +44,7 @@
## 🔥 更新和安装
**最新版本: v0.10.1** ([更新日志](changelogs/changelog.md))
**最新版本: v0.10.3** ([更新日志](changelogs/changelog.md))
可前往 [Release](https://github.com/MaiM-with-u/MaiBot/releases/) 页面下载最新版本
可前往 [启动器发布页面](https://github.com/MaiM-with-u/mailauncher/releases/)下载最新启动器
@@ -64,7 +62,7 @@
> - QQ 机器人存在被限制风险,请自行了解,谨慎使用。
> - 由于程序处于开发中,可能消耗较多 token。
## 麦麦MC项目早期开发
## 麦麦MC项目MaiCraft(早期开发)
[让麦麦玩MC](https://github.com/MaiM-with-u/Maicraft)
交流群1058573197
@@ -72,13 +70,13 @@
## 💬 讨论
**技术交流群:**
- [一群](https://qm.qq.com/q/VQ3XZrWgMs) |
[二群](https://qm.qq.com/q/RzmCiRtHEW) |
[三群](https://qm.qq.com/q/wlH5eT8OmQ) |
[四群](https://qm.qq.com/q/wGePTl1UyY)
[麦麦脑电图](https://qm.qq.com/q/RzmCiRtHEW) |
[麦麦脑磁图](https://qm.qq.com/q/wlH5eT8OmQ) |
[麦麦大脑磁共振](https://qm.qq.com/q/VQ3XZrWgMs) |
[麦麦要当VTB](https://qm.qq.com/q/wGePTl1UyY)
**聊天吹水群:**
- [](https://qm.qq.com/q/JxvHZnxyec)
- [麦麦之闲聊](https://qm.qq.com/q/JxvHZnxyec)
**插件开发测试版群:**
- [插件开发群](https://qm.qq.com/q/1036092828)

24
bot.py
View File

@@ -5,16 +5,29 @@ import sys
import time
import platform
import traceback
import shutil
from dotenv import load_dotenv
from pathlib import Path
from rich.traceback import install
if os.path.exists(".env"):
load_dotenv(".env", override=True)
env_path = Path(__file__).parent / ".env"
template_env_path = Path(__file__).parent / "template" / "template.env"
if env_path.exists():
load_dotenv(str(env_path), override=True)
print("成功加载环境变量配置")
else:
print("未找到.env文件请确保程序所需的环境变量被正确设置")
raise FileNotFoundError(".env 文件不存在,请创建并配置所需的环境变量")
try:
if template_env_path.exists():
shutil.copyfile(template_env_path, env_path)
print("未找到.env已从 template/template.env 自动创建")
load_dotenv(str(env_path), override=True)
else:
print("未找到.env文件也未找到模板 template/template.env")
raise FileNotFoundError(".env 文件不存在,请创建并配置所需的环境变量")
except Exception as e:
print(f"自动创建 .env 失败: {e}")
raise
# 最早期初始化日志系统,确保所有后续模块都使用正确的日志格式
from src.common.logger import initialize_logging, get_logger, shutdown_logging
@@ -62,9 +75,10 @@ def easter_egg():
async def graceful_shutdown(): # sourcery skip: use-named-expression
try:
logger.info("正在优雅关闭麦麦...")
from src.plugin_system.core.events_manager import events_manager
from src.plugin_system.base.component_types import EventType
# 触发 ON_STOP 事件
await events_manager.handle_mai_events(event_type=EventType.ON_STOP)

View File

@@ -1,5 +1,41 @@
# Changelog
## [0.11.0] - 2025-9-22
### 🌟 主要功能更改
- 重构记忆系统,新的记忆系统更可靠,记忆能力更强大
- 麦麦好奇功能,麦麦会自主提出问题
- 添加deepthink插件默认关闭让麦麦可以深度思考一些问题
- 添加表情包管理插件
### 细节功能更改
- 修复配置文件转义问题
- 情绪系统现在可以由配置文件控制开关
- 修复平行动作控制失效的问题
- 添加planner防抖防止短时间快速消耗token
- 修复吞字问题
- 更新依赖表
- 修复负载均衡
- 优化了对gemini和不同模型的支持
## [0.10.3] - 2025-9-22
### 🌟 主要功能更改
- planner支持多动作移除Sub_planner
- 移除激活度系统现在回复完全由planner控制
- 现可自定义planner行为更优化的聊天频率控制
- 支持发送转发和合并转发
- 关系现在支持多人的信息
- 更好的event系统正式建立
### 细节功能更改
- 支持所有表达方式互通
- 现可使用付费嵌入模型
- 添加多种发送类型
- 优化识图token限制
- 为空回复添加重试机制
- 加入brainchat模式为私聊支持做准备
- 修复qq号格式
## [0.10.2] - 2025-8-31
### 🌟 主要功能更改

View File

@@ -1,51 +0,0 @@
# Changelog
## [1.0.3] - 2025-3-31
### Added
- 新增了心流相关配置项:
- `heartflow` 配置项,用于控制心流功能
### Removed
- 移除了 `response` 配置项中的 `model_r1_probability``model_v3_probability` 选项
- 移除了次级推理模型相关配置
## [1.0.1] - 2025-3-30
### Added
- 增加了流式输出控制项 `stream`
- 修复 `LLM_Request` 不会自动为 `payload` 增加流式输出标志的问题
## [1.0.0] - 2025-3-30
### Added
- 修复了错误的版本命名
- 杀掉了所有无关文件
## [0.0.11] - 2025-3-12
### Added
- 新增了 `schedule` 配置项,用于配置日程表生成功能
- 新增了 `response_splitter` 配置项,用于控制回复分割
- 新增了 `experimental` 配置项,用于实验性功能开关
- 新增了 `llm_observation``llm_sub_heartflow` 模型配置
- 新增了 `llm_heartflow` 模型配置
-`personality` 配置项中新增了 `prompt_schedule_gen` 参数
### Changed
- 优化了模型配置的组织结构
- 调整了部分配置项的默认值
- 调整了配置项的顺序,将 `groups` 配置项移到了更靠前的位置
-`message` 配置项中:
- 新增了 `model_max_output_length` 参数
-`willing` 配置项中新增了 `emoji_response_penalty` 参数
-`personality` 配置项中的 `prompt_schedule` 重命名为 `prompt_schedule_gen`
### Removed
- 移除了 `min_text_length` 配置项
- 移除了 `cq_code` 配置项
- 移除了 `others` 配置项(其功能已整合到 `experimental` 中)
## [0.0.5] - 2025-3-11
### Added
- 新增了 `alias_names` 配置项,用于指定麦麦的别名。
## [0.0.4] - 2025-3-9
### Added
- 新增了 `memory_ban_words` 配置项,用于指定不希望记忆的词汇。

View File

@@ -28,7 +28,7 @@ version = "1.1.1"
```toml
[[api_providers]]
name = "DeepSeek" # 服务商名称(自定义)
base_url = "https://api.deepseek.cn/v1" # API服务的基础URL
base_url = "https://api.deepseek.com/v1" # API服务的基础URL
api_key = "your-api-key-here" # API密钥
client_type = "openai" # 客户端类型
max_retry = 2 # 最大重试次数
@@ -43,19 +43,19 @@ retry_interval = 10 # 重试间隔(秒)
| `name` | ✅ | 服务商名称,需要在模型配置中引用 | - |
| `base_url` | ✅ | API服务的基础URL | - |
| `api_key` | ✅ | API密钥请替换为实际密钥 | - |
| `client_type` | ❌ | 客户端类型:`openai`OpenAI格式`gemini`Gemini格式,现在支持不良好 | `openai` |
| `client_type` | ❌ | 客户端类型:`openai`OpenAI格式`gemini`Gemini格式 | `openai` |
| `max_retry` | ❌ | API调用失败时的最大重试次数 | 2 |
| `timeout` | ❌ | API请求超时时间 | 30 |
| `retry_interval` | ❌ | 重试间隔时间(秒) | 10 |
**请注意,对于`client_type`为`gemini`的模型,`base_url`字段无效**
**请注意,对于`client_type`为`gemini`的模型,`retry`字段由`gemini`自己决定**
### 2.3 支持的服务商示例
#### DeepSeek
```toml
[[api_providers]]
name = "DeepSeek"
base_url = "https://api.deepseek.cn/v1"
base_url = "https://api.deepseek.com/v1"
api_key = "your-deepseek-api-key"
client_type = "openai"
```
@@ -73,7 +73,7 @@ client_type = "openai"
```toml
[[api_providers]]
name = "Google"
base_url = "https://api.google.com/v1"
base_url = "https://generativelanguage.googleapis.com/v1beta"
api_key = "your-google-api-key"
client_type = "gemini" # 注意Gemini需要使用特殊客户端
```
@@ -131,9 +131,20 @@ enable_thinking = false # 禁用思考
[models.extra_params]
thinking = {type = "disabled"} # 禁用思考
```
而对于`gemini`需要单独进行配置
```toml
[[models]]
model_identifier = "gemini-2.5-flash"
name = "gemini-2.5-flash"
api_provider = "Google"
[models.extra_params]
thinking_budget = 0 # 禁用思考
# thinking_budget = -1 由模型自己决定
```
请注意,`extra_params` 的配置应该构成一个合法的TOML字典结构具体内容取决于API服务商的要求。
**请注意,对于`client_type`为`gemini`的模型,此字段无效。**
### 3.3 配置参数说明
| 参数 | 必填 | 说明 |

57
flake.lock generated
View File

@@ -1,57 +0,0 @@
{
"nodes": {
"nixpkgs": {
"locked": {
"lastModified": 0,
"narHash": "sha256-nJj8f78AYAxl/zqLiFGXn5Im1qjFKU8yBPKoWEeZN5M=",
"path": "/nix/store/f30jn7l0bf7a01qj029fq55i466vmnkh-source",
"type": "path"
},
"original": {
"id": "nixpkgs",
"type": "indirect"
}
},
"root": {
"inputs": {
"nixpkgs": "nixpkgs",
"utils": "utils"
}
},
"systems": {
"locked": {
"lastModified": 1681028828,
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
"owner": "nix-systems",
"repo": "default",
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
"type": "github"
},
"original": {
"owner": "nix-systems",
"repo": "default",
"type": "github"
}
},
"utils": {
"inputs": {
"systems": "systems"
},
"locked": {
"lastModified": 1731533236,
"narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
"owner": "numtide",
"repo": "flake-utils",
"rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
"type": "github"
},
"original": {
"owner": "numtide",
"repo": "flake-utils",
"type": "github"
}
}
},
"root": "root",
"version": 7
}

View File

@@ -1,39 +0,0 @@
{
description = "MaiMBot Nix Dev Env";
inputs = {
utils.url = "github:numtide/flake-utils";
};
outputs = {
self,
nixpkgs,
utils,
...
}:
utils.lib.eachDefaultSystem (system: let
pkgs = import nixpkgs {inherit system;};
pythonPackages = pkgs.python3Packages;
in {
devShells.default = pkgs.mkShell {
name = "python-venv";
venvDir = "./.venv";
buildInputs = with pythonPackages; [
python
venvShellHook
scipy
numpy
];
postVenvCreation = ''
unset SOURCE_DATE_EPOCH
pip install -r requirements.txt
'';
postShellHook = ''
# allow pip to install wheels
unset SOURCE_DATE_EPOCH
'';
};
});
}

View File

@@ -828,7 +828,7 @@ class LogViewer:
parts, tags = self.formatter.format_log_entry(log_entry)
line_text = " ".join(parts)
log_lines.append(line_text)
with open(filename, "w", encoding="utf-8") as f:
f.write("\n".join(log_lines))
messagebox.showinfo("导出成功", f"日志已导出到: {filename}")
@@ -1188,15 +1188,16 @@ class LogViewer:
line_count += 1
except json.JSONDecodeError:
continue
# 如果发现了新模块,在主线程中更新模块集合
if new_modules:
def update_modules():
self.modules.update(new_modules)
self.update_module_list()
self.root.after(0, update_modules)
return new_entries
def append_new_logs(self, new_entries):
@@ -1424,4 +1425,3 @@ def main():
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,34 @@
{
"manifest_version": 1,
"name": "Deep Think插件 (Deep Think Actions)",
"version": "1.0.0",
"description": "可以深度思考",
"author": {
"name": "SengokuCola",
"url": "https://github.com/MaiM-with-u"
},
"license": "GPL-v3.0-or-later",
"host_application": {
"min_version": "0.11.0"
},
"homepage_url": "https://github.com/MaiM-with-u/maibot",
"repository_url": "https://github.com/MaiM-with-u/maibot",
"keywords": ["deep", "think", "action", "built-in"],
"categories": ["Deep Think"],
"default_locale": "zh-CN",
"locales_path": "_locales",
"plugin_info": {
"is_built_in": true,
"plugin_type": "action_provider",
"components": [
{
"type": "action",
"name": "deep_think",
"description": "发送深度思考"
}
]
}
}

View File

@@ -0,0 +1,102 @@
from typing import List, Tuple, Type, Any
# 导入新插件系统
from src.plugin_system import BasePlugin, register_plugin, ComponentInfo
from src.plugin_system.base.config_types import ConfigField
from src.person_info.person_info import Person
from src.plugin_system.base.base_tool import BaseTool, ToolParamType
# 导入依赖的系统组件
from src.common.logger import get_logger
from src.plugins.built_in.relation.relation import BuildRelationAction
from src.plugin_system.apis import llm_api
logger = get_logger("relation_actions")
class DeepThinkTool(BaseTool):
"""获取用户信息"""
name = "deep_think"
description = "深度思考,对某个知识,概念或逻辑问题进行全面且深入的思考,当面临复杂环境或重要问题时,使用此获得更好的解决方案。"
parameters = [
("question", ToolParamType.STRING, "需要思考的问题,越具体越好(从上下文中总结)", True, None),
]
available_for_llm = True
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
"""执行比较两个数的大小
Args:
function_args: 工具参数
Returns:
dict: 工具执行结果
"""
question: str = function_args.get("question") # type: ignore
print(f"question: {question}")
prompt = f"""
请你思考以下问题,以简洁的一段话回答:
{question}
"""
models = llm_api.get_available_models()
chat_model_config = models.get("replyer") # 使用字典访问方式
success, thinking_result, _, _ = await llm_api.generate_with_model(
prompt, model_config=chat_model_config, request_type="deep_think"
)
logger.info(f"{question}: {thinking_result}")
thinking_result =f"思考结果:{thinking_result}\n**注意** 因为你进行了深度思考,最后的回复内容可以回复的长一些,更加详细一些,不用太简洁。\n"
return {"content": thinking_result}
@register_plugin
class DeepThinkPlugin(BasePlugin):
"""关系动作插件
系统内置插件,提供基础的聊天交互功能:
- Reply: 回复动作
- NoReply: 不回复动作
- Emoji: 表情动作
注意插件基本信息优先从_manifest.json文件中读取
"""
# 插件基本信息
plugin_name: str = "deep_think" # 内部标识符
enable_plugin: bool = True
dependencies: list[str] = [] # 插件依赖列表
python_dependencies: list[str] = [] # Python包依赖列表
config_file_name: str = "config.toml"
# 配置节描述
config_section_descriptions = {
"plugin": "插件启用配置",
"components": "核心组件启用配置",
}
# 配置Schema定义
config_schema: dict = {
"plugin": {
"enabled": ConfigField(type=bool, default=False, description="是否启用插件"),
"config_version": ConfigField(type=str, default="2.0.0", description="配置文件版本"),
}
}
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
"""返回插件包含的组件列表"""
# --- 根据配置注册组件 ---
components = []
components.append((DeepThinkTool.get_tool_info(), DeepThinkTool))
return components

View File

@@ -0,0 +1,53 @@
{
"manifest_version": 1,
"name": "BetterEmoji",
"version": "1.0.0",
"description": "更好的表情包管理插件",
"author": {
"name": "SengokuCola",
"url": "https://github.com/SengokuCola"
},
"license": "GPL-v3.0-or-later",
"host_application": {
"min_version": "0.10.4"
},
"homepage_url": "https://github.com/SengokuCola/BetterEmoji",
"repository_url": "https://github.com/SengokuCola/BetterEmoji",
"keywords": ["emoji", "manage", "plugin"],
"categories": ["Examples", "Tutorial"],
"default_locale": "zh-CN",
"locales_path": "_locales",
"plugin_info": {
"is_built_in": false,
"plugin_type": "emoji_manage",
"components": [
{
"type": "action",
"name": "hello_greeting",
"description": "向用户发送问候消息"
},
{
"type": "action",
"name": "bye_greeting",
"description": "向用户发送告别消息",
"activation_modes": ["keyword"],
"keywords": ["再见", "bye", "88", "拜拜"]
},
{
"type": "command",
"name": "time",
"description": "查询当前时间",
"pattern": "/time"
}
],
"features": [
"问候和告别功能",
"时间查询命令",
"配置文件示例",
"新手教程代码"
]
}
}

View File

@@ -0,0 +1,399 @@
from typing import List, Tuple, Type
from src.plugin_system import (
BasePlugin,
register_plugin,
BaseCommand,
ComponentInfo,
ConfigField,
ReplyContentType,
emoji_api,
)
from maim_message import Seg
from src.common.logger import get_logger
logger = get_logger("emoji_manage_plugin")
class AddEmojiCommand(BaseCommand):
command_name = "add_emoji"
command_description = "添加表情包"
command_pattern = r".*/emoji add.*"
async def execute(self) -> Tuple[bool, str, bool]:
# 查找消息中的表情包
# logger.info(f"查找消息中的表情包: {self.message.message_segment}")
emoji_base64_list = self.find_and_return_emoji_in_message(self.message.message_segment)
if not emoji_base64_list:
return False, "未在消息中找到表情包或图片", False
# 注册找到的表情包
success_count = 0
fail_count = 0
results = []
for i, emoji_base64 in enumerate(emoji_base64_list):
try:
# 使用emoji_api注册表情包让API自动生成唯一文件名
result = await emoji_api.register_emoji(emoji_base64)
if result["success"]:
success_count += 1
description = result.get("description", "未知描述")
emotions = result.get("emotions", [])
replaced = result.get("replaced", False)
result_msg = f"表情包 {i + 1} 注册成功{'(替换旧表情包)' if replaced else '(新增表情包)'}"
if description:
result_msg += f"\n描述: {description}"
if emotions:
result_msg += f"\n情感标签: {', '.join(emotions)}"
results.append(result_msg)
else:
fail_count += 1
error_msg = result.get("message", "注册失败")
results.append(f"表情包 {i + 1} 注册失败: {error_msg}")
except Exception as e:
fail_count += 1
results.append(f"表情包 {i + 1} 注册时发生错误: {str(e)}")
# 构建返回消息
total_count = success_count + fail_count
summary_msg = f"表情包注册完成: 成功 {success_count} 个,失败 {fail_count} 个,共处理 {total_count}"
# 如果有结果详情,添加到返回消息中
details_msg = ""
if results:
details_msg = "\n" + "\n".join(results)
final_msg = summary_msg + details_msg
else:
final_msg = summary_msg
# 使用表达器重写回复
try:
from src.plugin_system.apis import generator_api
# 构建重写数据
rewrite_data = {
"raw_reply": summary_msg,
"reason": f"注册了表情包:{details_msg}\n",
}
# 调用表达器重写
result_status, data = await generator_api.rewrite_reply(
chat_stream=self.message.chat_stream,
reply_data=rewrite_data,
)
if result_status:
# 发送重写后的回复
for reply_seg in data.reply_set.reply_data:
send_data = reply_seg.content
await self.send_text(send_data)
return success_count > 0, final_msg, success_count > 0
else:
# 如果重写失败,发送原始消息
await self.send_text(final_msg)
return success_count > 0, final_msg, success_count > 0
except Exception as e:
# 如果表达器调用失败,发送原始消息
logger.error(f"[add_emoji] 表达器重写失败: {e}")
await self.send_text(final_msg)
return success_count > 0, final_msg, success_count > 0
def find_and_return_emoji_in_message(self, message_segments) -> List[str]:
emoji_base64_list = []
# 处理单个Seg对象的情况
if isinstance(message_segments, Seg):
if message_segments.type == "emoji":
emoji_base64_list.append(message_segments.data)
elif message_segments.type == "image":
# 假设图片数据是base64编码的
emoji_base64_list.append(message_segments.data)
elif message_segments.type == "seglist":
# 递归处理嵌套的Seg列表
emoji_base64_list.extend(self.find_and_return_emoji_in_message(message_segments.data))
return emoji_base64_list
# 处理Seg列表的情况
for seg in message_segments:
if seg.type == "emoji":
emoji_base64_list.append(seg.data)
elif seg.type == "image":
# 假设图片数据是base64编码的
emoji_base64_list.append(seg.data)
elif seg.type == "seglist":
# 递归处理嵌套的Seg列表
emoji_base64_list.extend(self.find_and_return_emoji_in_message(seg.data))
return emoji_base64_list
class ListEmojiCommand(BaseCommand):
"""列表表情包Command - 响应/emoji list命令"""
command_name = "emoji_list"
command_description = "列表表情包"
# === 命令设置(必须填写)===
command_pattern = r"^/emoji list(\s+\d+)?$" # 匹配 "/emoji list" 或 "/emoji list 数量"
async def execute(self) -> Tuple[bool, str, bool]:
"""执行列表表情包"""
from src.plugin_system.apis import emoji_api
import datetime
# 解析命令参数
import re
match = re.match(r"^/emoji list(?:\s+(\d+))?$", self.message.raw_message)
max_count = 10 # 默认显示10个
if match and match.group(1):
max_count = min(int(match.group(1)), 50) # 最多显示50个
# 获取当前时间
time_format: str = self.get_config("time.format", "%Y-%m-%d %H:%M:%S") # type: ignore
now = datetime.datetime.now()
time_str = now.strftime(time_format)
# 获取表情包信息
emoji_count = emoji_api.get_count()
emoji_info = emoji_api.get_info()
# 构建返回消息
message_lines = [
f"📊 表情包统计信息 ({time_str})",
f"• 总数: {emoji_count} / {emoji_info['max_count']}",
f"• 可用: {emoji_info['available_emojis']}",
]
if emoji_count == 0:
message_lines.append("\n❌ 暂无表情包")
final_message = "\n".join(message_lines)
await self.send_text(final_message)
return True, final_message, True
# 获取所有表情包
all_emojis = await emoji_api.get_all()
if not all_emojis:
message_lines.append("\n❌ 无法获取表情包列表")
final_message = "\n".join(message_lines)
await self.send_text(final_message)
return False, final_message, True
# 显示前N个表情包
display_emojis = all_emojis[:max_count]
message_lines.append(f"\n📋 显示前 {len(display_emojis)} 个表情包:")
for i, (_, description, emotion) in enumerate(display_emojis, 1):
# 截断过长的描述
short_desc = description[:50] + "..." if len(description) > 50 else description
message_lines.append(f"{i}. {short_desc} [{emotion}]")
# 如果还有更多表情包,显示总数
if len(all_emojis) > max_count:
message_lines.append(f"\n💡 还有 {len(all_emojis) - max_count} 个表情包未显示")
final_message = "\n".join(message_lines)
# 直接发送文本消息
await self.send_text(final_message)
return True, final_message, True
class DeleteEmojiCommand(BaseCommand):
command_name = "delete_emoji"
command_description = "删除表情包"
command_pattern = r".*/emoji delete.*"
async def execute(self) -> Tuple[bool, str, bool]:
# 查找消息中的表情包图片
logger.info(f"查找消息中的表情包用于删除: {self.message.message_segment}")
emoji_base64_list = self.find_and_return_emoji_in_message(self.message.message_segment)
if not emoji_base64_list:
return False, "未在消息中找到表情包或图片", False
# 删除找到的表情包
success_count = 0
fail_count = 0
results = []
for i, emoji_base64 in enumerate(emoji_base64_list):
try:
# 计算图片的哈希值来查找对应的表情包
import base64
import hashlib
# 确保base64字符串只包含ASCII字符
if isinstance(emoji_base64, str):
emoji_base64_clean = emoji_base64.encode("ascii", errors="ignore").decode("ascii")
else:
emoji_base64_clean = str(emoji_base64)
# 计算哈希值
image_bytes = base64.b64decode(emoji_base64_clean)
emoji_hash = hashlib.md5(image_bytes).hexdigest()
# 使用emoji_api删除表情包
result = await emoji_api.delete_emoji(emoji_hash)
if result["success"]:
success_count += 1
description = result.get("description", "未知描述")
count_before = result.get("count_before", 0)
count_after = result.get("count_after", 0)
emotions = result.get("emotions", [])
result_msg = f"表情包 {i + 1} 删除成功"
if description:
result_msg += f"\n描述: {description}"
if emotions:
result_msg += f"\n情感标签: {', '.join(emotions)}"
result_msg += f"\n表情包数量: {count_before}{count_after}"
results.append(result_msg)
else:
fail_count += 1
error_msg = result.get("message", "删除失败")
results.append(f"表情包 {i + 1} 删除失败: {error_msg}")
except Exception as e:
fail_count += 1
results.append(f"表情包 {i + 1} 删除时发生错误: {str(e)}")
# 构建返回消息
total_count = success_count + fail_count
summary_msg = f"表情包删除完成: 成功 {success_count} 个,失败 {fail_count} 个,共处理 {total_count}"
# 如果有结果详情,添加到返回消息中
details_msg = ""
if results:
details_msg = "\n" + "\n".join(results)
final_msg = summary_msg + details_msg
else:
final_msg = summary_msg
# 使用表达器重写回复
try:
from src.plugin_system.apis import generator_api
# 构建重写数据
rewrite_data = {
"raw_reply": summary_msg,
"reason": f"删除了表情包:{details_msg}\n",
}
# 调用表达器重写
result_status, data = await generator_api.rewrite_reply(
chat_stream=self.message.chat_stream,
reply_data=rewrite_data,
)
if result_status:
# 发送重写后的回复
for reply_seg in data.reply_set.reply_data:
send_data = reply_seg.content
await self.send_text(send_data)
return success_count > 0, final_msg, success_count > 0
else:
# 如果重写失败,发送原始消息
await self.send_text(final_msg)
return success_count > 0, final_msg, success_count > 0
except Exception as e:
# 如果表达器调用失败,发送原始消息
logger.error(f"[delete_emoji] 表达器重写失败: {e}")
await self.send_text(final_msg)
return success_count > 0, final_msg, success_count > 0
def find_and_return_emoji_in_message(self, message_segments) -> List[str]:
emoji_base64_list = []
# 处理单个Seg对象的情况
if isinstance(message_segments, Seg):
if message_segments.type == "emoji":
emoji_base64_list.append(message_segments.data)
elif message_segments.type == "image":
# 假设图片数据是base64编码的
emoji_base64_list.append(message_segments.data)
elif message_segments.type == "seglist":
# 递归处理嵌套的Seg列表
emoji_base64_list.extend(self.find_and_return_emoji_in_message(message_segments.data))
return emoji_base64_list
# 处理Seg列表的情况
for seg in message_segments:
if seg.type == "emoji":
emoji_base64_list.append(seg.data)
elif seg.type == "image":
# 假设图片数据是base64编码的
emoji_base64_list.append(seg.data)
elif seg.type == "seglist":
# 递归处理嵌套的Seg列表
emoji_base64_list.extend(self.find_and_return_emoji_in_message(seg.data))
return emoji_base64_list
class RandomEmojis(BaseCommand):
command_name = "random_emojis"
command_description = "发送多张随机表情包"
command_pattern = r"^/random_emojis$"
async def execute(self):
emojis = await emoji_api.get_random(5)
if not emojis:
return False, "未找到表情包", False
emoji_base64_list = []
for emoji in emojis:
emoji_base64_list.append(emoji[0])
return await self.forward_images(emoji_base64_list)
async def forward_images(self, images: List[str]):
"""
把多张图片用合并转发的方式发给用户
"""
success = await self.send_forward([("0", "神秘用户", [(ReplyContentType.IMAGE, img)]) for img in images])
return (True, "已发送随机表情包", True) if success else (False, "发送随机表情包失败", False)
# ===== 插件注册 =====
@register_plugin
class EmojiManagePlugin(BasePlugin):
"""表情包管理插件 - 管理表情包"""
# 插件基本信息
plugin_name: str = "emoji_manage_plugin" # 内部标识符
enable_plugin: bool = False
dependencies: List[str] = [] # 插件依赖列表
python_dependencies: List[str] = [] # Python包依赖列表
config_file_name: str = "config.toml" # 配置文件名
# 配置节描述
config_section_descriptions = {"plugin": "插件基本信息", "emoji": "表情包功能配置"}
# 配置Schema定义
config_schema: dict = {
"plugin": {
"enabled": ConfigField(type=bool, default=True, description="是否启用插件"),
"config_version": ConfigField(type=str, default="1.0.1", description="配置文件版本"),
},
}
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
return [
(RandomEmojis.get_command_info(), RandomEmojis),
(AddEmojiCommand.get_command_info(), AddEmojiCommand),
(ListEmojiCommand.get_command_info(), ListEmojiCommand),
(DeleteEmojiCommand.get_command_info(), DeleteEmojiCommand),
]

View File

@@ -1,3 +1,4 @@
import random
from typing import List, Tuple, Type, Any
from src.plugin_system import (
BasePlugin,
@@ -12,7 +13,10 @@ from src.plugin_system import (
EventType,
MaiMessages,
ToolParamType,
ReplyContentType,
emoji_api,
)
from src.config.config import global_config
class CompareNumbersTool(BaseTool):
@@ -24,6 +28,7 @@ class CompareNumbersTool(BaseTool):
("num1", ToolParamType.FLOAT, "第一个数字", True, None),
("num2", ToolParamType.FLOAT, "第二个数字", True, None),
]
available_for_llm = True
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
"""执行比较两个数的大小
@@ -136,12 +141,80 @@ class PrintMessage(BaseEventHandler):
handler_name = "print_message_handler"
handler_description = "打印接收到的消息"
async def execute(self, message: MaiMessages | None) -> Tuple[bool, bool, str | None, None]:
async def execute(self, message: MaiMessages | None) -> Tuple[bool, bool, str | None, None, None]:
"""执行打印消息事件处理"""
# 打印接收到的消息
if self.get_config("print_message.enabled", False):
print(f"接收到消息: {message.raw_message if message else '无效消息'}")
return True, True, "消息已打印", None
return True, True, "消息已打印", None, None
class ForwardMessages(BaseEventHandler):
"""
把接收到的消息转发到指定聊天ID
此组件是HYBRID消息和FORWARD消息的使用示例。
每收到10条消息就会以1%的概率使用HYBRID消息转发否则使用FORWARD消息转发。
"""
event_type = EventType.ON_MESSAGE
handler_name = "forward_messages_handler"
handler_description = "把接收到的消息转发到指定聊天ID"
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.counter = 0 # 用于计数转发的消息数量
self.messages: List[str] = []
async def execute(self, message: MaiMessages | None) -> Tuple[bool, bool, None, None, None]:
if not message:
return True, True, None, None, None
stream_id = message.stream_id or ""
if message.plain_text:
self.messages.append(message.plain_text)
self.counter += 1
if self.counter % 10 == 0:
if random.random() < 0.01:
success = await self.send_hybrid(stream_id, [(ReplyContentType.TEXT, msg) for msg in self.messages])
else:
success = await self.send_forward(
stream_id,
[
(
str(global_config.bot.qq_account),
str(global_config.bot.nickname),
[(ReplyContentType.TEXT, msg)],
)
for msg in self.messages
],
)
if not success:
raise ValueError("转发消息失败")
self.messages = []
return True, True, None, None, None
class RandomEmojis(BaseCommand):
command_name = "random_emojis"
command_description = "发送多张随机表情包"
command_pattern = r"^/random_emojis$"
async def execute(self):
emojis = await emoji_api.get_random(5)
if not emojis:
return False, "未找到表情包", False
emoji_base64_list = []
for emoji in emojis:
emoji_base64_list.append(emoji[0])
return await self.forward_images(emoji_base64_list)
async def forward_images(self, images: List[str]):
"""
把多张图片用合并转发的方式发给用户
"""
success = await self.send_forward([("0", "神秘用户", [(ReplyContentType.IMAGE, img)]) for img in images])
return (True, "已发送随机表情包", True) if success else (False, "发送随机表情包失败", False)
# ===== 插件注册 =====
@@ -153,7 +226,7 @@ class HelloWorldPlugin(BasePlugin):
# 插件基本信息
plugin_name: str = "hello_world_plugin" # 内部标识符
enable_plugin: bool = True
enable_plugin: bool = False
dependencies: List[str] = [] # 插件依赖列表
python_dependencies: List[str] = [] # Python包依赖列表
config_file_name: str = "config.toml" # 配置文件名
@@ -164,8 +237,7 @@ class HelloWorldPlugin(BasePlugin):
# 配置Schema定义
config_schema: dict = {
"plugin": {
"name": ConfigField(type=str, default="hello_world_plugin", description="插件名称"),
"version": ConfigField(type=str, default="1.0.0", description="插件版本"),
"config_version": ConfigField(type=str, default="1.0.0", description="配置文件版本"),
"enabled": ConfigField(type=bool, default=False, description="是否启用插件"),
},
"greeting": {
@@ -185,6 +257,8 @@ class HelloWorldPlugin(BasePlugin):
(ByeAction.get_action_info(), ByeAction), # 添加告别Action
(TimeCommand.get_command_info(), TimeCommand),
(PrintMessage.get_handler_info(), PrintMessage),
(ForwardMessages.get_handler_info(), ForwardMessages),
(RandomEmojis.get_command_info(), RandomEmojis),
]

View File

@@ -1,56 +1,37 @@
[project]
name = "MaiBot"
version = "0.8.1"
version = "0.11.0"
description = "MaiCore 是一个基于大语言模型的可交互智能体"
requires-python = ">=3.10"
dependencies = [
"aiohttp>=3.12.14",
"apscheduler>=3.11.0",
"aiohttp-cors>=0.8.1",
"colorama>=0.4.6",
"cryptography>=45.0.5",
"customtkinter>=5.2.2",
"dotenv>=0.9.9",
"faiss-cpu>=1.11.0",
"fastapi>=0.116.0",
"google-genai>=1.39.1",
"jieba>=0.42.1",
"json-repair>=0.47.6",
"jsonlines>=4.0.0",
"maim-message>=0.3.8",
"maim-message",
"matplotlib>=3.10.3",
"networkx>=3.4.2",
"numpy>=2.2.6",
"openai>=1.95.0",
"packaging>=25.0",
"pandas>=2.3.1",
"peewee>=3.18.2",
"pillow>=11.3.0",
"psutil>=7.0.0",
"pyarrow>=20.0.0",
"pydantic>=2.11.7",
"pymongo>=4.13.2",
"pypinyin>=0.54.0",
"python-dateutil>=2.9.0.post0",
"python-dotenv>=1.1.1",
"python-igraph>=0.11.9",
"quick-algo>=0.1.3",
"reportportal-client>=5.6.5",
"requests>=2.32.4",
"rich>=14.0.0",
"ruff>=0.12.2",
"scikit-learn>=1.7.0",
"scipy>=1.15.3",
"seaborn>=0.13.2",
"setuptools>=80.9.0",
"strawberry-graphql[fastapi]>=0.275.5",
"structlog>=25.4.0",
"toml>=0.10.2",
"tomli>=2.2.1",
"tomli-w>=1.2.0",
"tomlkit>=0.13.3",
"tqdm>=4.67.1",
"urllib3>=2.5.0",
"uvicorn>=0.35.0",
"websockets>=15.0.1",
]

View File

@@ -1,271 +0,0 @@
# This file was autogenerated by uv via the following command:
# uv pip compile requirements.txt -o requirements.lock
aenum==3.1.16
# via reportportal-client
aiohappyeyeballs==2.6.1
# via aiohttp
aiohttp==3.12.14
# via
# -r requirements.txt
# maim-message
# reportportal-client
aiosignal==1.4.0
# via aiohttp
annotated-types==0.7.0
# via pydantic
anyio==4.9.0
# via
# httpx
# openai
# starlette
apscheduler==3.11.0
# via -r requirements.txt
attrs==25.3.0
# via
# aiohttp
# jsonlines
certifi==2025.7.9
# via
# httpcore
# httpx
# reportportal-client
# requests
cffi==1.17.1
# via cryptography
charset-normalizer==3.4.2
# via requests
click==8.2.1
# via uvicorn
colorama==0.4.6
# via
# -r requirements.txt
# click
# tqdm
contourpy==1.3.2
# via matplotlib
cryptography==45.0.5
# via
# -r requirements.txt
# maim-message
customtkinter==5.2.2
# via -r requirements.txt
cycler==0.12.1
# via matplotlib
darkdetect==0.8.0
# via customtkinter
distro==1.9.0
# via openai
dnspython==2.7.0
# via pymongo
dotenv==0.9.9
# via -r requirements.txt
faiss-cpu==1.11.0
# via -r requirements.txt
fastapi==0.116.0
# via
# -r requirements.txt
# maim-message
# strawberry-graphql
fonttools==4.58.5
# via matplotlib
frozenlist==1.7.0
# via
# aiohttp
# aiosignal
graphql-core==3.2.6
# via strawberry-graphql
h11==0.16.0
# via
# httpcore
# uvicorn
httpcore==1.0.9
# via httpx
httpx==0.28.1
# via openai
idna==3.10
# via
# anyio
# httpx
# requests
# yarl
igraph==0.11.9
# via python-igraph
jieba==0.42.1
# via -r requirements.txt
jiter==0.10.0
# via openai
joblib==1.5.1
# via scikit-learn
json-repair==0.47.6
# via -r requirements.txt
jsonlines==4.0.0
# via -r requirements.txt
kiwisolver==1.4.8
# via matplotlib
maim-message==0.3.8
# via -r requirements.txt
markdown-it-py==3.0.0
# via rich
matplotlib==3.10.3
# via
# -r requirements.txt
# seaborn
mdurl==0.1.2
# via markdown-it-py
multidict==6.6.3
# via
# aiohttp
# yarl
networkx==3.5
# via -r requirements.txt
numpy==2.3.1
# via
# -r requirements.txt
# contourpy
# faiss-cpu
# matplotlib
# pandas
# scikit-learn
# scipy
# seaborn
openai==1.95.0
# via -r requirements.txt
packaging==25.0
# via
# -r requirements.txt
# customtkinter
# faiss-cpu
# matplotlib
# strawberry-graphql
pandas==2.3.1
# via
# -r requirements.txt
# seaborn
peewee==3.18.2
# via -r requirements.txt
pillow==11.3.0
# via
# -r requirements.txt
# matplotlib
propcache==0.3.2
# via
# aiohttp
# yarl
psutil==7.0.0
# via -r requirements.txt
pyarrow==20.0.0
# via -r requirements.txt
pycparser==2.22
# via cffi
pydantic==2.11.7
# via
# -r requirements.txt
# fastapi
# maim-message
# openai
pydantic-core==2.33.2
# via pydantic
pygments==2.19.2
# via rich
pymongo==4.13.2
# via -r requirements.txt
pyparsing==3.2.3
# via matplotlib
pypinyin==0.54.0
# via -r requirements.txt
python-dateutil==2.9.0.post0
# via
# -r requirements.txt
# matplotlib
# pandas
# strawberry-graphql
python-dotenv==1.1.1
# via
# -r requirements.txt
# dotenv
python-igraph==0.11.9
# via -r requirements.txt
python-multipart==0.0.20
# via strawberry-graphql
pytz==2025.2
# via pandas
quick-algo==0.1.3
# via -r requirements.txt
reportportal-client==5.6.5
# via -r requirements.txt
requests==2.32.4
# via
# -r requirements.txt
# reportportal-client
rich==14.0.0
# via -r requirements.txt
ruff==0.12.2
# via -r requirements.txt
scikit-learn==1.7.0
# via -r requirements.txt
scipy==1.16.0
# via
# -r requirements.txt
# scikit-learn
seaborn==0.13.2
# via -r requirements.txt
setuptools==80.9.0
# via -r requirements.txt
six==1.17.0
# via python-dateutil
sniffio==1.3.1
# via
# anyio
# openai
starlette==0.46.2
# via fastapi
strawberry-graphql==0.275.5
# via -r requirements.txt
structlog==25.4.0
# via -r requirements.txt
texttable==1.7.0
# via igraph
threadpoolctl==3.6.0
# via scikit-learn
toml==0.10.2
# via -r requirements.txt
tomli==2.2.1
# via -r requirements.txt
tomli-w==1.2.0
# via -r requirements.txt
tomlkit==0.13.3
# via -r requirements.txt
tqdm==4.67.1
# via
# -r requirements.txt
# openai
typing-extensions==4.14.1
# via
# fastapi
# openai
# pydantic
# pydantic-core
# strawberry-graphql
# typing-inspection
typing-inspection==0.4.1
# via pydantic
tzdata==2025.2
# via
# pandas
# tzlocal
tzlocal==5.3.1
# via apscheduler
urllib3==2.5.0
# via
# -r requirements.txt
# requests
uvicorn==0.35.0
# via
# -r requirements.txt
# maim-message
websockets==15.0.1
# via
# -r requirements.txt
# maim-message
yarl==1.20.1
# via aiohttp

View File

@@ -1,49 +1,28 @@
APScheduler
Pillow
aiohttp
aiohttp-cors
colorama
customtkinter
dotenv
faiss-cpu
fastapi
jieba
jsonlines
maim_message
quick_algo
matplotlib
networkx
numpy
openai
pandas
peewee
pyarrow
pydantic
pypinyin
python-dateutil
python-dotenv
python-igraph
pymongo
requests
ruff
scipy
setuptools
toml
tomli
tomli_w
tomlkit
tqdm
urllib3
uvicorn
websockets
strawberry-graphql[fastapi]
packaging
rich
psutil
cryptography
json-repair
reportportal-client
scikit-learn
seaborn
structlog
google.genai
aiohttp>=3.12.14
aiohttp-cors>=0.8.1
colorama>=0.4.6
faiss-cpu>=1.11.0
fastapi>=0.116.0
google-genai>=1.39.1
jieba>=0.42.1
json-repair>=0.47.6
maim-message
matplotlib>=3.10.3
numpy>=2.2.6
openai>=1.95.0
pandas>=2.3.1
peewee>=3.18.2
pillow>=11.3.0
pyarrow>=20.0.0
pydantic>=2.11.7
pypinyin>=0.54.0
python-dotenv>=1.1.1
quick-algo>=0.1.3
rich>=14.0.0
ruff>=0.12.2
setuptools>=80.9.0
structlog>=25.4.0
toml>=0.10.2
tomlkit>=0.13.3
urllib3>=2.5.0
uvicorn>=0.35.0

389
scripts/build_io_pairs.py Normal file
View File

@@ -0,0 +1,389 @@
import argparse
import json
import random
import re
import sys
import os
from datetime import datetime
from typing import Dict, Iterable, List, Optional, Tuple
# 确保可从任意工作目录运行:将项目根目录加入 sys.pathscripts 的上一级)
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if PROJECT_ROOT not in sys.path:
sys.path.insert(0, PROJECT_ROOT)
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.message_repository import find_messages
from src.chat.utils.chat_message_builder import build_readable_messages
SECONDS_5_MINUTES = 5 * 60
def clean_output_text(text: str) -> str:
"""
清理输出文本,移除表情包和回复内容
- 移除 [表情包:...] 格式的内容
- 移除 [回复...] 格式的内容
"""
if not text:
return text
# 移除表情包内容:[表情包:...]
text = re.sub(r'\[表情包:[^\]]*\]', '', text)
# 移除回复内容:[回复...],说:... 的完整模式
text = re.sub(r'\[回复[^\]]*\],说:[^@]*@[^:]*:', '', text)
# 清理多余的空格和换行
text = re.sub(r'\s+', ' ', text).strip()
return text
def parse_datetime_to_timestamp(value: str) -> float:
"""
接受多种常见格式并转换为时间戳(秒)
支持示例:
- 2025-09-29
- 2025-09-29 00:00:00
- 2025/09/29 00:00
- 2025-09-29T00:00:00
"""
value = value.strip()
fmts = [
"%Y-%m-%d %H:%M:%S",
"%Y-%m-%d %H:%M",
"%Y/%m/%d %H:%M:%S",
"%Y/%m/%d %H:%M",
"%Y-%m-%d",
"%Y/%m/%d",
"%Y-%m-%dT%H:%M:%S",
"%Y-%m-%dT%H:%M",
]
last_err: Optional[Exception] = None
for fmt in fmts:
try:
dt = datetime.strptime(value, fmt)
return dt.timestamp()
except Exception as e: # noqa: BLE001
last_err = e
raise ValueError(f"无法解析时间: {value} ({last_err})")
def fetch_messages_between(
start_ts: float,
end_ts: float,
platform: Optional[str] = None,
) -> List[DatabaseMessages]:
"""使用 find_messages 获取指定区间的消息,可选按 chat_info_platform 过滤。按时间升序返回。"""
filter_query: Dict[str, object] = {"time": {"$gt": start_ts, "$lt": end_ts}}
if platform:
filter_query["chat_info_platform"] = platform
# 当 limit==0 时sort 生效,这里按时间升序
return find_messages(message_filter=filter_query, sort=[("time", 1)], limit=0)
def group_by_chat(messages: Iterable[DatabaseMessages]) -> Dict[str, List[DatabaseMessages]]:
groups: Dict[str, List[DatabaseMessages]] = {}
for msg in messages:
groups.setdefault(msg.chat_id, []).append(msg)
# 保证每个分组内按时间升序
for chat_id, msgs in groups.items():
msgs.sort(key=lambda m: m.time or 0)
return groups
def _merge_bucket_to_message(bucket: List[DatabaseMessages]) -> DatabaseMessages:
"""
将相邻、同一 user_id 且 5 分钟内的消息 bucket 合并为一条。
processed_plain_text 合并(以换行连接),其余字段取最新一条(时间最大)。
"""
if not bucket:
raise ValueError("bucket 为空,无法合并")
latest = bucket[-1]
merged_texts: List[str] = []
for m in bucket:
text = m.processed_plain_text or ""
if text:
merged_texts.append(text)
merged = DatabaseMessages(
# 其他信息采用最新消息
message_id=latest.message_id,
time=latest.time,
chat_id=latest.chat_id,
reply_to=latest.reply_to,
interest_value=latest.interest_value,
key_words=latest.key_words,
key_words_lite=latest.key_words_lite,
is_mentioned=latest.is_mentioned,
is_at=latest.is_at,
reply_probability_boost=latest.reply_probability_boost,
processed_plain_text="\n".join(merged_texts) if merged_texts else latest.processed_plain_text,
display_message=latest.display_message,
priority_mode=latest.priority_mode,
priority_info=latest.priority_info,
additional_config=latest.additional_config,
is_emoji=latest.is_emoji,
is_picid=latest.is_picid,
is_command=latest.is_command,
is_notify=latest.is_notify,
selected_expressions=latest.selected_expressions,
user_id=latest.user_info.user_id,
user_nickname=latest.user_info.user_nickname,
user_cardname=latest.user_info.user_cardname,
user_platform=latest.user_info.platform,
chat_info_group_id=(latest.group_info.group_id if latest.group_info else None),
chat_info_group_name=(latest.group_info.group_name if latest.group_info else None),
chat_info_group_platform=(latest.group_info.group_platform if latest.group_info else None),
chat_info_user_id=latest.chat_info.user_info.user_id,
chat_info_user_nickname=latest.chat_info.user_info.user_nickname,
chat_info_user_cardname=latest.chat_info.user_info.user_cardname,
chat_info_user_platform=latest.chat_info.user_info.platform,
chat_info_stream_id=latest.chat_info.stream_id,
chat_info_platform=latest.chat_info.platform,
chat_info_create_time=latest.chat_info.create_time,
chat_info_last_active_time=latest.chat_info.last_active_time,
)
return merged
def merge_adjacent_same_user(messages: List[DatabaseMessages]) -> List[DatabaseMessages]:
"""按 5 分钟窗口合并相邻同 user_id 的消息。输入需按时间升序。"""
if not messages:
return []
merged: List[DatabaseMessages] = []
bucket: List[DatabaseMessages] = []
def flush_bucket() -> None:
nonlocal bucket
if bucket:
merged.append(_merge_bucket_to_message(bucket))
bucket = []
for msg in messages:
if not bucket:
bucket = [msg]
continue
last = bucket[-1]
same_user = (msg.user_info.user_id == last.user_info.user_id)
close_enough = ((msg.time or 0) - (last.time or 0) <= SECONDS_5_MINUTES)
if same_user and close_enough:
bucket.append(msg)
else:
flush_bucket()
bucket = [msg]
flush_bucket()
return merged
def build_pairs_for_chat(
original_messages: List[DatabaseMessages],
merged_messages: List[DatabaseMessages],
min_ctx: int,
max_ctx: int,
target_user_id: Optional[str] = None,
) -> List[Tuple[str, str, str]]:
"""
对每条合并后的消息作为 output从其前面取 20-30 条(可配置)的原始消息作为 input。
input 使用原始未合并的消息构建上下文。
output 使用合并后消息的 processed_plain_text。
如果指定了 target_user_id则只处理该用户的消息作为 output。
"""
pairs: List[Tuple[str, str, str]] = []
n_merged = len(merged_messages)
n_original = len(original_messages)
if n_merged == 0 or n_original == 0:
return pairs
# 为每个合并后的消息找到对应的原始消息位置
merged_to_original_map = {}
original_idx = 0
for merged_idx, merged_msg in enumerate(merged_messages):
# 找到这个合并消息对应的第一个原始消息
while (original_idx < n_original and
original_messages[original_idx].time < merged_msg.time):
original_idx += 1
# 如果找到了时间匹配的原始消息,建立映射
if (original_idx < n_original and
original_messages[original_idx].time == merged_msg.time):
merged_to_original_map[merged_idx] = original_idx
for merged_idx in range(n_merged):
merged_msg = merged_messages[merged_idx]
# 如果指定了 target_user_id只处理该用户的消息作为 output
if target_user_id and merged_msg.user_info.user_id != target_user_id:
continue
# 找到对应的原始消息位置
if merged_idx not in merged_to_original_map:
continue
original_idx = merged_to_original_map[merged_idx]
# 选择上下文窗口大小
window = random.randint(min_ctx, max_ctx) if max_ctx > min_ctx else min_ctx
start = max(0, original_idx - window)
context_msgs = original_messages[start:original_idx]
# 使用原始未合并消息构建 input
input_str = build_readable_messages(
messages=context_msgs,
timestamp_mode="normal_no_YMD",
show_actions=False,
show_pic=True,
)
# 输出取合并后消息的 processed_plain_text 并清理表情包和回复内容
output_text = merged_msg.processed_plain_text or ""
output_text = clean_output_text(output_text)
output_id = merged_msg.message_id or ""
pairs.append((input_str, output_text, output_id))
return pairs
def build_pairs(
start_ts: float,
end_ts: float,
platform: Optional[str],
user_id: Optional[str],
min_ctx: int,
max_ctx: int,
) -> List[Tuple[str, str, str]]:
# 获取所有消息不按user_id过滤这样input上下文可以包含所有用户的消息
messages = fetch_messages_between(start_ts, end_ts, platform)
groups = group_by_chat(messages)
all_pairs: List[Tuple[str, str, str]] = []
for chat_id, msgs in groups.items(): # noqa: F841 - chat_id 未直接使用
# 对消息进行合并用于output
merged = merge_adjacent_same_user(msgs)
# 传递原始消息和合并后消息input使用原始消息output使用合并后消息
pairs = build_pairs_for_chat(msgs, merged, min_ctx, max_ctx, user_id)
all_pairs.extend(pairs)
return all_pairs
def main(argv: Optional[List[str]] = None) -> int:
# 若未提供参数,则进入交互模式
if argv is None:
argv = sys.argv[1:]
if len(argv) == 0:
return run_interactive()
parser = argparse.ArgumentParser(description="构建 (input_str, output_str, message_id) 列表支持按用户ID筛选消息")
parser.add_argument("start", help="起始时间,如 2025-09-28 00:00:00")
parser.add_argument("end", help="结束时间,如 2025-09-29 00:00:00")
parser.add_argument("--platform", default=None, help="仅选择 chat_info_platform 为该值的消息")
parser.add_argument("--user_id", default=None, help="仅选择指定 user_id 的消息")
parser.add_argument("--min_ctx", type=int, default=20, help="输入上下文的最少条数默认20")
parser.add_argument("--max_ctx", type=int, default=30, help="输入上下文的最多条数默认30")
parser.add_argument(
"--output",
default=None,
help="输出保存路径,支持 .jsonl每行 {input, output}若不指定则打印到stdout",
)
args = parser.parse_args(argv)
start_ts = parse_datetime_to_timestamp(args.start)
end_ts = parse_datetime_to_timestamp(args.end)
if end_ts <= start_ts:
raise ValueError("结束时间必须大于起始时间")
if args.max_ctx < args.min_ctx:
raise ValueError("max_ctx 不能小于 min_ctx")
pairs = build_pairs(start_ts, end_ts, args.platform, args.user_id, args.min_ctx, args.max_ctx)
if args.output:
# 保存为 JSONL每行一个 {input, output, message_id}
with open(args.output, "w", encoding="utf-8") as f:
for input_str, output_str, message_id in pairs:
obj = {"input": input_str, "output": output_str, "message_id": message_id}
f.write(json.dumps(obj, ensure_ascii=False) + "\n")
print(f"已保存 {len(pairs)} 条到 {args.output}")
else:
# 打印到 stdout
for input_str, output_str, message_id in pairs:
print(json.dumps({"input": input_str, "output": output_str, "message_id": message_id}, ensure_ascii=False))
return 0
def _prompt_with_default(prompt_text: str, default: Optional[str]) -> str:
suffix = f"[{default}]" if default not in (None, "") else ""
value = input(f"{prompt_text}{' ' + suffix if suffix else ''}: ").strip()
if value == "" and default is not None:
return default
return value
def run_interactive() -> int:
print("进入交互模式直接回车采用默认值。时间格式例如2025-09-28 00:00:00 或 2025-09-28")
start_str = _prompt_with_default("请输入起始时间", None)
end_str = _prompt_with_default("请输入结束时间", None)
platform = _prompt_with_default("平台(可留空表示不限)", "")
user_id = _prompt_with_default("用户ID可留空表示不限", "")
try:
min_ctx = int(_prompt_with_default("输入上下文最少条数", "20"))
max_ctx = int(_prompt_with_default("输入上下文最多条数", "30"))
except Exception:
print("上下文条数输入有误,使用默认 20/30")
min_ctx, max_ctx = 20, 30
output_path = _prompt_with_default("输出路径(.jsonl可留空打印到控制台", "")
if not start_str or not end_str:
print("必须提供起始与结束时间。")
return 2
try:
start_ts = parse_datetime_to_timestamp(start_str)
end_ts = parse_datetime_to_timestamp(end_str)
except Exception as e: # noqa: BLE001
print(f"时间解析失败:{e}")
return 2
if end_ts <= start_ts:
print("结束时间必须大于起始时间。")
return 2
if max_ctx < min_ctx:
print("最多条数不能小于最少条数。")
return 2
platform_val = platform if platform != "" else None
user_id_val = user_id if user_id != "" else None
pairs = build_pairs(start_ts, end_ts, platform_val, user_id_val, min_ctx, max_ctx)
if output_path:
with open(output_path, "w", encoding="utf-8") as f:
for input_str, output_str, message_id in pairs:
obj = {"input": input_str, "output": output_str, "message_id": message_id}
f.write(json.dumps(obj, ensure_ascii=False) + "\n")
print(f"已保存 {len(pairs)} 条到 {output_path}")
else:
for input_str, output_str, message_id in pairs:
print(json.dumps({"input": input_str, "output": output_str, "message_id": message_id}, ensure_ascii=False))
print(f"总计 {len(pairs)} 条。")
return 0
if __name__ == "__main__":
sys.exit(main())

View File

@@ -0,0 +1,334 @@
import time
import sys
import os
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
from datetime import datetime
from typing import List, Tuple
import numpy as np
# Add project root to Python path
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, project_root)
from src.common.database.database_model import Expression, ChatStreams
# 设置中文字体
plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'DejaVu Sans']
plt.rcParams['axes.unicode_minus'] = False
def get_chat_name(chat_id: str) -> str:
"""Get chat name from chat_id by querying ChatStreams table directly"""
try:
chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id)
if chat_stream is None:
return f"未知聊天 ({chat_id})"
if chat_stream.group_name:
return f"{chat_stream.group_name} ({chat_id})"
elif chat_stream.user_nickname:
return f"{chat_stream.user_nickname}的私聊 ({chat_id})"
else:
return f"未知聊天 ({chat_id})"
except Exception:
return f"查询失败 ({chat_id})"
def get_expression_data() -> List[Tuple[float, float, str, str]]:
"""获取Expression表中的数据返回(create_date, count, chat_id, expression_type)的列表"""
expressions = Expression.select()
data = []
for expr in expressions:
# 如果create_date为空跳过该记录
if expr.create_date is None:
continue
data.append((
expr.create_date,
expr.count,
expr.chat_id,
expr.type
))
return data
def create_scatter_plot(data: List[Tuple[float, float, str, str]], save_path: str = None):
"""创建散点图"""
if not data:
print("没有找到有效的表达式数据")
return
# 分离数据
create_dates = [item[0] for item in data]
counts = [item[1] for item in data]
chat_ids = [item[2] for item in data]
expression_types = [item[3] for item in data]
# 转换时间戳为datetime对象
dates = [datetime.fromtimestamp(ts) for ts in create_dates]
# 计算时间跨度,自动调整显示格式
time_span = max(dates) - min(dates)
if time_span.days > 30: # 超过30天按月显示
date_format = '%Y-%m-%d'
major_locator = mdates.MonthLocator()
minor_locator = mdates.DayLocator(interval=7)
elif time_span.days > 7: # 超过7天按天显示
date_format = '%Y-%m-%d'
major_locator = mdates.DayLocator(interval=1)
minor_locator = mdates.HourLocator(interval=12)
else: # 7天内按小时显示
date_format = '%Y-%m-%d %H:%M'
major_locator = mdates.HourLocator(interval=6)
minor_locator = mdates.HourLocator(interval=1)
# 创建图形
fig, ax = plt.subplots(figsize=(12, 8))
# 创建散点图
scatter = ax.scatter(dates, counts, alpha=0.6, s=30, c=range(len(dates)), cmap='viridis')
# 设置标签和标题
ax.set_xlabel('创建日期 (Create Date)', fontsize=12)
ax.set_ylabel('使用次数 (Count)', fontsize=12)
ax.set_title('表达式使用次数随时间分布散点图', fontsize=14, fontweight='bold')
# 设置x轴日期格式 - 根据时间跨度自动调整
ax.xaxis.set_major_formatter(mdates.DateFormatter(date_format))
ax.xaxis.set_major_locator(major_locator)
ax.xaxis.set_minor_locator(minor_locator)
plt.xticks(rotation=45)
# 添加网格
ax.grid(True, alpha=0.3)
# 添加颜色条
cbar = plt.colorbar(scatter)
cbar.set_label('数据点顺序', fontsize=10)
# 调整布局
plt.tight_layout()
# 显示统计信息
print(f"\n=== 数据统计 ===")
print(f"总数据点数量: {len(data)}")
print(f"时间范围: {min(dates).strftime('%Y-%m-%d %H:%M:%S')}{max(dates).strftime('%Y-%m-%d %H:%M:%S')}")
print(f"使用次数范围: {min(counts):.1f}{max(counts):.1f}")
print(f"平均使用次数: {np.mean(counts):.2f}")
print(f"中位数使用次数: {np.median(counts):.2f}")
# 保存图片
if save_path:
plt.savefig(save_path, dpi=300, bbox_inches='tight')
print(f"\n散点图已保存到: {save_path}")
# 显示图片
plt.show()
def create_grouped_scatter_plot(data: List[Tuple[float, float, str, str]], save_path: str = None):
"""创建按聊天分组的散点图"""
if not data:
print("没有找到有效的表达式数据")
return
# 按chat_id分组
chat_groups = {}
for item in data:
chat_id = item[2]
if chat_id not in chat_groups:
chat_groups[chat_id] = []
chat_groups[chat_id].append(item)
# 计算时间跨度,自动调整显示格式
all_dates = [datetime.fromtimestamp(item[0]) for item in data]
time_span = max(all_dates) - min(all_dates)
if time_span.days > 30: # 超过30天按月显示
date_format = '%Y-%m-%d'
major_locator = mdates.MonthLocator()
minor_locator = mdates.DayLocator(interval=7)
elif time_span.days > 7: # 超过7天按天显示
date_format = '%Y-%m-%d'
major_locator = mdates.DayLocator(interval=1)
minor_locator = mdates.HourLocator(interval=12)
else: # 7天内按小时显示
date_format = '%Y-%m-%d %H:%M'
major_locator = mdates.HourLocator(interval=6)
minor_locator = mdates.HourLocator(interval=1)
# 创建图形
fig, ax = plt.subplots(figsize=(14, 10))
# 为每个聊天分配不同颜色
colors = plt.cm.Set3(np.linspace(0, 1, len(chat_groups)))
for i, (chat_id, chat_data) in enumerate(chat_groups.items()):
create_dates = [item[0] for item in chat_data]
counts = [item[1] for item in chat_data]
dates = [datetime.fromtimestamp(ts) for ts in create_dates]
chat_name = get_chat_name(chat_id)
# 截断过长的聊天名称
display_name = chat_name[:20] + "..." if len(chat_name) > 20 else chat_name
ax.scatter(dates, counts, alpha=0.7, s=40,
c=[colors[i]], label=f"{display_name} ({len(chat_data)}个)",
edgecolors='black', linewidth=0.5)
# 设置标签和标题
ax.set_xlabel('创建日期 (Create Date)', fontsize=12)
ax.set_ylabel('使用次数 (Count)', fontsize=12)
ax.set_title('按聊天分组的表达式使用次数散点图', fontsize=14, fontweight='bold')
# 设置x轴日期格式 - 根据时间跨度自动调整
ax.xaxis.set_major_formatter(mdates.DateFormatter(date_format))
ax.xaxis.set_major_locator(major_locator)
ax.xaxis.set_minor_locator(minor_locator)
plt.xticks(rotation=45)
# 添加图例
ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=8)
# 添加网格
ax.grid(True, alpha=0.3)
# 调整布局
plt.tight_layout()
# 显示统计信息
print(f"\n=== 分组统计 ===")
print(f"总聊天数量: {len(chat_groups)}")
for chat_id, chat_data in chat_groups.items():
chat_name = get_chat_name(chat_id)
counts = [item[1] for item in chat_data]
print(f"{chat_name}: {len(chat_data)}个表达式, 平均使用次数: {np.mean(counts):.2f}")
# 保存图片
if save_path:
plt.savefig(save_path, dpi=300, bbox_inches='tight')
print(f"\n分组散点图已保存到: {save_path}")
# 显示图片
plt.show()
def create_type_scatter_plot(data: List[Tuple[float, float, str, str]], save_path: str = None):
"""创建按表达式类型分组的散点图"""
if not data:
print("没有找到有效的表达式数据")
return
# 按type分组
type_groups = {}
for item in data:
expr_type = item[3]
if expr_type not in type_groups:
type_groups[expr_type] = []
type_groups[expr_type].append(item)
# 计算时间跨度,自动调整显示格式
all_dates = [datetime.fromtimestamp(item[0]) for item in data]
time_span = max(all_dates) - min(all_dates)
if time_span.days > 30: # 超过30天按月显示
date_format = '%Y-%m-%d'
major_locator = mdates.MonthLocator()
minor_locator = mdates.DayLocator(interval=7)
elif time_span.days > 7: # 超过7天按天显示
date_format = '%Y-%m-%d'
major_locator = mdates.DayLocator(interval=1)
minor_locator = mdates.HourLocator(interval=12)
else: # 7天内按小时显示
date_format = '%Y-%m-%d %H:%M'
major_locator = mdates.HourLocator(interval=6)
minor_locator = mdates.HourLocator(interval=1)
# 创建图形
fig, ax = plt.subplots(figsize=(12, 8))
# 为每个类型分配不同颜色
colors = plt.cm.tab10(np.linspace(0, 1, len(type_groups)))
for i, (expr_type, type_data) in enumerate(type_groups.items()):
create_dates = [item[0] for item in type_data]
counts = [item[1] for item in type_data]
dates = [datetime.fromtimestamp(ts) for ts in create_dates]
ax.scatter(dates, counts, alpha=0.7, s=40,
c=[colors[i]], label=f"{expr_type} ({len(type_data)}个)",
edgecolors='black', linewidth=0.5)
# 设置标签和标题
ax.set_xlabel('创建日期 (Create Date)', fontsize=12)
ax.set_ylabel('使用次数 (Count)', fontsize=12)
ax.set_title('按表达式类型分组的散点图', fontsize=14, fontweight='bold')
# 设置x轴日期格式 - 根据时间跨度自动调整
ax.xaxis.set_major_formatter(mdates.DateFormatter(date_format))
ax.xaxis.set_major_locator(major_locator)
ax.xaxis.set_minor_locator(minor_locator)
plt.xticks(rotation=45)
# 添加图例
ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
# 添加网格
ax.grid(True, alpha=0.3)
# 调整布局
plt.tight_layout()
# 显示统计信息
print(f"\n=== 类型统计 ===")
for expr_type, type_data in type_groups.items():
counts = [item[1] for item in type_data]
print(f"{expr_type}: {len(type_data)}个表达式, 平均使用次数: {np.mean(counts):.2f}")
# 保存图片
if save_path:
plt.savefig(save_path, dpi=300, bbox_inches='tight')
print(f"\n类型散点图已保存到: {save_path}")
# 显示图片
plt.show()
def main():
"""主函数"""
print("开始分析表达式数据...")
# 获取数据
data = get_expression_data()
if not data:
print("没有找到有效的表达式数据create_date不为空的数据")
return
print(f"找到 {len(data)} 条有效数据")
# 创建输出目录
output_dir = os.path.join(project_root, "data", "temp")
os.makedirs(output_dir, exist_ok=True)
# 生成时间戳用于文件名
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
# 1. 创建基础散点图
print("\n1. 创建基础散点图...")
create_scatter_plot(data, os.path.join(output_dir, f"expression_scatter_{timestamp}.png"))
# 2. 创建按聊天分组的散点图
print("\n2. 创建按聊天分组的散点图...")
create_grouped_scatter_plot(data, os.path.join(output_dir, f"expression_scatter_by_chat_{timestamp}.png"))
# 3. 创建按类型分组的散点图
print("\n3. 创建按类型分组的散点图...")
create_type_scatter_plot(data, os.path.join(output_dir, f"expression_scatter_by_type_{timestamp}.png"))
print("\n分析完成!")
if __name__ == "__main__":
main()

View File

@@ -5,12 +5,11 @@ from typing import Dict, List
# Add project root to Python path
from src.common.database.database_model import Expression, ChatStreams
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, project_root)
def get_chat_name(chat_id: str) -> str:
"""Get chat name from chat_id by querying ChatStreams table directly"""
try:
@@ -18,7 +17,7 @@ def get_chat_name(chat_id: str) -> str:
chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id)
if chat_stream is None:
return f"未知聊天 ({chat_id})"
# 如果有群组信息,显示群组名称
if chat_stream.group_name:
return f"{chat_stream.group_name} ({chat_id})"
@@ -35,117 +34,106 @@ def calculate_time_distribution(expressions) -> Dict[str, int]:
"""Calculate distribution of last active time in days"""
now = time.time()
distribution = {
'0-1天': 0,
'1-3天': 0,
'3-7天': 0,
'7-14天': 0,
'14-30天': 0,
'30-60天': 0,
'60-90天': 0,
'90+天': 0
"0-1天": 0,
"1-3天": 0,
"3-7天": 0,
"7-14天": 0,
"14-30天": 0,
"30-60天": 0,
"60-90天": 0,
"90+天": 0,
}
for expr in expressions:
diff_days = (now - expr.last_active_time) / (24*3600)
diff_days = (now - expr.last_active_time) / (24 * 3600)
if diff_days < 1:
distribution['0-1天'] += 1
distribution["0-1天"] += 1
elif diff_days < 3:
distribution['1-3天'] += 1
distribution["1-3天"] += 1
elif diff_days < 7:
distribution['3-7天'] += 1
distribution["3-7天"] += 1
elif diff_days < 14:
distribution['7-14天'] += 1
distribution["7-14天"] += 1
elif diff_days < 30:
distribution['14-30天'] += 1
distribution["14-30天"] += 1
elif diff_days < 60:
distribution['30-60天'] += 1
distribution["30-60天"] += 1
elif diff_days < 90:
distribution['60-90天'] += 1
distribution["60-90天"] += 1
else:
distribution['90+天'] += 1
distribution["90+天"] += 1
return distribution
def calculate_count_distribution(expressions) -> Dict[str, int]:
"""Calculate distribution of count values"""
distribution = {
'0-1': 0,
'1-2': 0,
'2-3': 0,
'3-4': 0,
'4-5': 0,
'5-10': 0,
'10+': 0
}
distribution = {"0-1": 0, "1-2": 0, "2-3": 0, "3-4": 0, "4-5": 0, "5-10": 0, "10+": 0}
for expr in expressions:
cnt = expr.count
if cnt < 1:
distribution['0-1'] += 1
distribution["0-1"] += 1
elif cnt < 2:
distribution['1-2'] += 1
distribution["1-2"] += 1
elif cnt < 3:
distribution['2-3'] += 1
distribution["2-3"] += 1
elif cnt < 4:
distribution['3-4'] += 1
distribution["3-4"] += 1
elif cnt < 5:
distribution['4-5'] += 1
distribution["4-5"] += 1
elif cnt < 10:
distribution['5-10'] += 1
distribution["5-10"] += 1
else:
distribution['10+'] += 1
distribution["10+"] += 1
return distribution
def get_top_expressions_by_chat(chat_id: str, top_n: int = 5) -> List[Expression]:
"""Get top N most used expressions for a specific chat_id"""
return (Expression.select()
.where(Expression.chat_id == chat_id)
.order_by(Expression.count.desc())
.limit(top_n))
return Expression.select().where(Expression.chat_id == chat_id).order_by(Expression.count.desc()).limit(top_n)
def show_overall_statistics(expressions, total: int) -> None:
"""Show overall statistics"""
time_dist = calculate_time_distribution(expressions)
count_dist = calculate_count_distribution(expressions)
print("\n=== 总体统计 ===")
print(f"总表达式数量: {total}")
print("\n上次激活时间分布:")
for period, count in time_dist.items():
print(f"{period}: {count} ({count/total*100:.2f}%)")
print(f"{period}: {count} ({count / total * 100:.2f}%)")
print("\ncount分布:")
for range_, count in count_dist.items():
print(f"{range_}: {count} ({count/total*100:.2f}%)")
print(f"{range_}: {count} ({count / total * 100:.2f}%)")
def show_chat_statistics(chat_id: str, chat_name: str) -> None:
"""Show statistics for a specific chat"""
chat_exprs = list(Expression.select().where(Expression.chat_id == chat_id))
chat_total = len(chat_exprs)
print(f"\n=== {chat_name} ===")
print(f"表达式数量: {chat_total}")
if chat_total == 0:
print("该聊天没有表达式数据")
return
# Time distribution for this chat
time_dist = calculate_time_distribution(chat_exprs)
print("\n上次激活时间分布:")
for period, count in time_dist.items():
if count > 0:
print(f"{period}: {count} ({count/chat_total*100:.2f}%)")
print(f"{period}: {count} ({count / chat_total * 100:.2f}%)")
# Count distribution for this chat
count_dist = calculate_count_distribution(chat_exprs)
print("\ncount分布:")
for range_, count in count_dist.items():
if count > 0:
print(f"{range_}: {count} ({count/chat_total*100:.2f}%)")
print(f"{range_}: {count} ({count / chat_total * 100:.2f}%)")
# Top expressions
print("\nTop 10使用最多的表达式:")
top_exprs = get_top_expressions_by_chat(chat_id, 10)
@@ -163,32 +151,32 @@ def interactive_menu() -> None:
if not expressions:
print("数据库中没有找到表达式")
return
total = len(expressions)
# Get unique chat_ids and their names
chat_ids = list(set(expr.chat_id for expr in expressions))
chat_info = [(chat_id, get_chat_name(chat_id)) for chat_id in chat_ids]
chat_info.sort(key=lambda x: x[1]) # Sort by chat name
while True:
print("\n" + "="*50)
print("\n" + "=" * 50)
print("表达式统计分析")
print("="*50)
print("=" * 50)
print("0. 显示总体统计")
for i, (chat_id, chat_name) in enumerate(chat_info, 1):
chat_count = sum(1 for expr in expressions if expr.chat_id == chat_id)
print(f"{i}. {chat_name} ({chat_count}个表达式)")
print("q. 退出")
choice = input("\n请选择要查看的统计 (输入序号): ").strip()
if choice.lower() == 'q':
if choice.lower() == "q":
print("再见!")
break
try:
choice_num = int(choice)
if choice_num == 0:
@@ -200,9 +188,9 @@ def interactive_menu() -> None:
print("无效的选择,请重新输入")
except ValueError:
print("请输入有效的数字")
input("\n按回车键继续...")
if __name__ == "__main__":
interactive_menu()
interactive_menu()

View File

@@ -23,6 +23,7 @@ OPENIE_DIR = os.path.join(ROOT_PATH, "data", "openie")
logger = get_logger("OpenIE导入")
def ensure_openie_dir():
"""确保OpenIE数据目录存在"""
if not os.path.exists(OPENIE_DIR):
@@ -253,7 +254,7 @@ def main():
# 没有运行的事件循环,创建新的
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
# 在新的事件循环中运行异步主函数
loop.run_until_complete(main_async())

View File

@@ -12,6 +12,7 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from rich.progress import Progress # 替换为 rich 进度条
from src.common.logger import get_logger
# from src.chat.knowledge.lpmmconfig import global_config
from src.chat.knowledge.ie_process import info_extract_from_str
from src.chat.knowledge.open_ie import OpenIE
@@ -36,6 +37,7 @@ TEMP_DIR = os.path.join(ROOT_PATH, "temp")
# IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, "data", "imported_lpmm_data")
OPENIE_OUTPUT_DIR = os.path.join(ROOT_PATH, "data", "openie")
def ensure_dirs():
"""确保临时目录和输出目录存在"""
if not os.path.exists(TEMP_DIR):
@@ -48,6 +50,7 @@ def ensure_dirs():
os.makedirs(RAW_DATA_PATH)
logger.info(f"已创建原始数据目录: {RAW_DATA_PATH}")
# 创建一个线程安全的锁,用于保护文件操作和共享数据
file_lock = Lock()
open_ie_doc_lock = Lock()
@@ -56,13 +59,11 @@ open_ie_doc_lock = Lock()
shutdown_event = Event()
lpmm_entity_extract_llm = LLMRequest(
model_set=model_config.model_task_config.lpmm_entity_extract,
request_type="lpmm.entity_extract"
)
lpmm_rdf_build_llm = LLMRequest(
model_set=model_config.model_task_config.lpmm_rdf_build,
request_type="lpmm.rdf_build"
model_set=model_config.model_task_config.lpmm_entity_extract, request_type="lpmm.entity_extract"
)
lpmm_rdf_build_llm = LLMRequest(model_set=model_config.model_task_config.lpmm_rdf_build, request_type="lpmm.rdf_build")
def process_single_text(pg_hash, raw_data):
"""处理单个文本的函数,用于线程池"""
temp_file_path = f"{TEMP_DIR}/{pg_hash}.json"

View File

@@ -1,287 +0,0 @@
import time
import sys
import os
from typing import Dict, List, Tuple, Optional
from datetime import datetime
# Add project root to Python path
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, project_root)
from src.common.database.database_model import Messages, ChatStreams #noqa
def get_chat_name(chat_id: str) -> str:
"""Get chat name from chat_id by querying ChatStreams table directly"""
try:
chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id)
if chat_stream is None:
return f"未知聊天 ({chat_id})"
if chat_stream.group_name:
return f"{chat_stream.group_name} ({chat_id})"
elif chat_stream.user_nickname:
return f"{chat_stream.user_nickname}的私聊 ({chat_id})"
else:
return f"未知聊天 ({chat_id})"
except Exception:
return f"查询失败 ({chat_id})"
def format_timestamp(timestamp: float) -> str:
"""Format timestamp to readable date string"""
try:
return datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %H:%M:%S")
except (ValueError, OSError):
return "未知时间"
def calculate_interest_value_distribution(messages) -> Dict[str, int]:
"""Calculate distribution of interest_value"""
distribution = {
'0.000-0.010': 0,
'0.010-0.050': 0,
'0.050-0.100': 0,
'0.100-0.500': 0,
'0.500-1.000': 0,
'1.000-2.000': 0,
'2.000-5.000': 0,
'5.000-10.000': 0,
'10.000+': 0
}
for msg in messages:
if msg.interest_value is None or msg.interest_value == 0.0:
continue
value = float(msg.interest_value)
if value < 0.010:
distribution['0.000-0.010'] += 1
elif value < 0.050:
distribution['0.010-0.050'] += 1
elif value < 0.100:
distribution['0.050-0.100'] += 1
elif value < 0.500:
distribution['0.100-0.500'] += 1
elif value < 1.000:
distribution['0.500-1.000'] += 1
elif value < 2.000:
distribution['1.000-2.000'] += 1
elif value < 5.000:
distribution['2.000-5.000'] += 1
elif value < 10.000:
distribution['5.000-10.000'] += 1
else:
distribution['10.000+'] += 1
return distribution
def get_interest_value_stats(messages) -> Dict[str, float]:
"""Calculate basic statistics for interest_value"""
values = [float(msg.interest_value) for msg in messages if msg.interest_value is not None and msg.interest_value != 0.0]
if not values:
return {
'count': 0,
'min': 0,
'max': 0,
'avg': 0,
'median': 0
}
values.sort()
count = len(values)
return {
'count': count,
'min': min(values),
'max': max(values),
'avg': sum(values) / count,
'median': values[count // 2] if count % 2 == 1 else (values[count // 2 - 1] + values[count // 2]) / 2
}
def get_available_chats() -> List[Tuple[str, str, int]]:
"""Get all available chats with message counts"""
try:
# 获取所有有消息的chat_id
chat_counts = {}
for msg in Messages.select(Messages.chat_id).distinct():
chat_id = msg.chat_id
count = Messages.select().where(
(Messages.chat_id == chat_id) &
(Messages.interest_value.is_null(False)) &
(Messages.interest_value != 0.0)
).count()
if count > 0:
chat_counts[chat_id] = count
# 获取聊天名称
result = []
for chat_id, count in chat_counts.items():
chat_name = get_chat_name(chat_id)
result.append((chat_id, chat_name, count))
# 按消息数量排序
result.sort(key=lambda x: x[2], reverse=True)
return result
except Exception as e:
print(f"获取聊天列表失败: {e}")
return []
def get_time_range_input() -> Tuple[Optional[float], Optional[float]]:
"""Get time range input from user"""
print("\n时间范围选择:")
print("1. 最近1天")
print("2. 最近3天")
print("3. 最近7天")
print("4. 最近30天")
print("5. 自定义时间范围")
print("6. 不限制时间")
choice = input("请选择时间范围 (1-6): ").strip()
now = time.time()
if choice == "1":
return now - 24*3600, now
elif choice == "2":
return now - 3*24*3600, now
elif choice == "3":
return now - 7*24*3600, now
elif choice == "4":
return now - 30*24*3600, now
elif choice == "5":
print("请输入开始时间 (格式: YYYY-MM-DD HH:MM:SS):")
start_str = input().strip()
print("请输入结束时间 (格式: YYYY-MM-DD HH:MM:SS):")
end_str = input().strip()
try:
start_time = datetime.strptime(start_str, "%Y-%m-%d %H:%M:%S").timestamp()
end_time = datetime.strptime(end_str, "%Y-%m-%d %H:%M:%S").timestamp()
return start_time, end_time
except ValueError:
print("时间格式错误,将不限制时间范围")
return None, None
else:
return None, None
def analyze_interest_values(chat_id: Optional[str] = None, start_time: Optional[float] = None, end_time: Optional[float] = None) -> None:
"""Analyze interest values with optional filters"""
# 构建查询条件
query = Messages.select().where(
(Messages.interest_value.is_null(False)) &
(Messages.interest_value != 0.0)
)
if chat_id:
query = query.where(Messages.chat_id == chat_id)
if start_time:
query = query.where(Messages.time >= start_time)
if end_time:
query = query.where(Messages.time <= end_time)
messages = list(query)
if not messages:
print("没有找到符合条件的消息")
return
# 计算统计信息
distribution = calculate_interest_value_distribution(messages)
stats = get_interest_value_stats(messages)
# 显示结果
print("\n=== Interest Value 分析结果 ===")
if chat_id:
print(f"聊天: {get_chat_name(chat_id)}")
else:
print("聊天: 全部聊天")
if start_time and end_time:
print(f"时间范围: {format_timestamp(start_time)}{format_timestamp(end_time)}")
elif start_time:
print(f"时间范围: {format_timestamp(start_time)} 之后")
elif end_time:
print(f"时间范围: {format_timestamp(end_time)} 之前")
else:
print("时间范围: 不限制")
print("\n基本统计:")
print(f"有效消息数量: {stats['count']} (排除null和0值)")
print(f"最小值: {stats['min']:.3f}")
print(f"最大值: {stats['max']:.3f}")
print(f"平均值: {stats['avg']:.3f}")
print(f"中位数: {stats['median']:.3f}")
print("\nInterest Value 分布:")
total = stats['count']
for range_name, count in distribution.items():
if count > 0:
percentage = count / total * 100
print(f"{range_name}: {count} ({percentage:.2f}%)")
def interactive_menu() -> None:
"""Interactive menu for interest value analysis"""
while True:
print("\n" + "="*50)
print("Interest Value 分析工具")
print("="*50)
print("1. 分析全部聊天")
print("2. 选择特定聊天分析")
print("q. 退出")
choice = input("\n请选择分析模式 (1-2, q): ").strip()
if choice.lower() == 'q':
print("再见!")
break
chat_id = None
if choice == "2":
# 显示可用的聊天列表
chats = get_available_chats()
if not chats:
print("没有找到有interest_value数据的聊天")
continue
print(f"\n可用的聊天 (共{len(chats)}个):")
for i, (_cid, name, count) in enumerate(chats, 1):
print(f"{i}. {name} ({count}条有效消息)")
try:
chat_choice = int(input(f"\n请选择聊天 (1-{len(chats)}): ").strip())
if 1 <= chat_choice <= len(chats):
chat_id = chats[chat_choice - 1][0]
else:
print("无效选择")
continue
except ValueError:
print("请输入有效数字")
continue
elif choice != "1":
print("无效选择")
continue
# 获取时间范围
start_time, end_time = get_time_range_input()
# 执行分析
analyze_interest_values(chat_id, start_time, end_time)
input("\n按回车键继续...")
if __name__ == "__main__":
interactive_menu()

View File

@@ -2,6 +2,7 @@ import os
from pathlib import Path
import sys # 新增系统模块导入
from src.chat.knowledge.utils.hash import get_sha256
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from src.common.logger import get_logger
@@ -10,6 +11,7 @@ ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
RAW_DATA_PATH = os.path.join(ROOT_PATH, "data/lpmm_raw_data")
# IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, "data/imported_lpmm_data")
def _process_text_file(file_path):
"""处理单个文本文件,返回段落列表"""
with open(file_path, "r", encoding="utf-8") as f:
@@ -44,6 +46,7 @@ def _process_multi_files() -> list:
all_paragraphs.extend(paragraphs)
return all_paragraphs
def load_raw_data() -> tuple[list[str], list[str]]:
"""加载原始数据文件
@@ -72,4 +75,4 @@ def load_raw_data() -> tuple[list[str], list[str]]:
raw_data.append(item)
logger.info(f"共读取到{len(raw_data)}条数据")
return sha256_list, raw_data
return sha256_list, raw_data

View File

@@ -1,394 +0,0 @@
import time
import sys
import os
import re
from typing import Dict, List, Tuple, Optional
from datetime import datetime
# Add project root to Python path
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, project_root)
from src.common.database.database_model import Messages, ChatStreams #noqa
def contains_emoji_or_image_tags(text: str) -> bool:
"""Check if text contains [表情包xxxxx] or [图片xxxxx] tags"""
if not text:
return False
# 检查是否包含 [表情包] 或 [图片] 标记
emoji_pattern = r'\[表情包[^\]]*\]'
image_pattern = r'\[图片[^\]]*\]'
return bool(re.search(emoji_pattern, text) or re.search(image_pattern, text))
def clean_reply_text(text: str) -> str:
"""Remove reply references like [回复 xxxx...] from text"""
if not text:
return text
# 匹配 [回复 xxxx...] 格式的内容
# 使用非贪婪匹配,匹配到第一个 ] 就停止
cleaned_text = re.sub(r'\[回复[^\]]*\]', '', text)
# 去除多余的空白字符
cleaned_text = cleaned_text.strip()
return cleaned_text
def get_chat_name(chat_id: str) -> str:
"""Get chat name from chat_id by querying ChatStreams table directly"""
try:
chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id)
if chat_stream is None:
return f"未知聊天 ({chat_id})"
if chat_stream.group_name:
return f"{chat_stream.group_name} ({chat_id})"
elif chat_stream.user_nickname:
return f"{chat_stream.user_nickname}的私聊 ({chat_id})"
else:
return f"未知聊天 ({chat_id})"
except Exception:
return f"查询失败 ({chat_id})"
def format_timestamp(timestamp: float) -> str:
"""Format timestamp to readable date string"""
try:
return datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %H:%M:%S")
except (ValueError, OSError):
return "未知时间"
def calculate_text_length_distribution(messages) -> Dict[str, int]:
"""Calculate distribution of processed_plain_text length"""
distribution = {
'0': 0, # 空文本
'1-5': 0, # 极短文本
'6-10': 0, # 很短文本
'11-20': 0, # 短文本
'21-30': 0, # 较短文本
'31-50': 0, # 中短文本
'51-70': 0, # 中等文本
'71-100': 0, # 较长文本
'101-150': 0, # 长文本
'151-200': 0, # 很长文本
'201-300': 0, # 超长文本
'301-500': 0, # 极长文本
'501-1000': 0, # 巨长文本
'1000+': 0 # 超巨长文本
}
for msg in messages:
if msg.processed_plain_text is None:
continue
# 排除包含表情包或图片标记的消息
if contains_emoji_or_image_tags(msg.processed_plain_text):
continue
# 清理文本中的回复引用
cleaned_text = clean_reply_text(msg.processed_plain_text)
length = len(cleaned_text)
if length == 0:
distribution['0'] += 1
elif length <= 5:
distribution['1-5'] += 1
elif length <= 10:
distribution['6-10'] += 1
elif length <= 20:
distribution['11-20'] += 1
elif length <= 30:
distribution['21-30'] += 1
elif length <= 50:
distribution['31-50'] += 1
elif length <= 70:
distribution['51-70'] += 1
elif length <= 100:
distribution['71-100'] += 1
elif length <= 150:
distribution['101-150'] += 1
elif length <= 200:
distribution['151-200'] += 1
elif length <= 300:
distribution['201-300'] += 1
elif length <= 500:
distribution['301-500'] += 1
elif length <= 1000:
distribution['501-1000'] += 1
else:
distribution['1000+'] += 1
return distribution
def get_text_length_stats(messages) -> Dict[str, float]:
"""Calculate basic statistics for processed_plain_text length"""
lengths = []
null_count = 0
excluded_count = 0 # 被排除的消息数量
for msg in messages:
if msg.processed_plain_text is None:
null_count += 1
elif contains_emoji_or_image_tags(msg.processed_plain_text):
# 排除包含表情包或图片标记的消息
excluded_count += 1
else:
# 清理文本中的回复引用
cleaned_text = clean_reply_text(msg.processed_plain_text)
lengths.append(len(cleaned_text))
if not lengths:
return {
'count': 0,
'null_count': null_count,
'excluded_count': excluded_count,
'min': 0,
'max': 0,
'avg': 0,
'median': 0
}
lengths.sort()
count = len(lengths)
return {
'count': count,
'null_count': null_count,
'excluded_count': excluded_count,
'min': min(lengths),
'max': max(lengths),
'avg': sum(lengths) / count,
'median': lengths[count // 2] if count % 2 == 1 else (lengths[count // 2 - 1] + lengths[count // 2]) / 2
}
def get_available_chats() -> List[Tuple[str, str, int]]:
"""Get all available chats with message counts"""
try:
# 获取所有有消息的chat_id排除特殊类型消息
chat_counts = {}
for msg in Messages.select(Messages.chat_id).distinct():
chat_id = msg.chat_id
count = Messages.select().where(
(Messages.chat_id == chat_id) &
(Messages.is_emoji != 1) &
(Messages.is_picid != 1) &
(Messages.is_command != 1)
).count()
if count > 0:
chat_counts[chat_id] = count
# 获取聊天名称
result = []
for chat_id, count in chat_counts.items():
chat_name = get_chat_name(chat_id)
result.append((chat_id, chat_name, count))
# 按消息数量排序
result.sort(key=lambda x: x[2], reverse=True)
return result
except Exception as e:
print(f"获取聊天列表失败: {e}")
return []
def get_time_range_input() -> Tuple[Optional[float], Optional[float]]:
"""Get time range input from user"""
print("\n时间范围选择:")
print("1. 最近1天")
print("2. 最近3天")
print("3. 最近7天")
print("4. 最近30天")
print("5. 自定义时间范围")
print("6. 不限制时间")
choice = input("请选择时间范围 (1-6): ").strip()
now = time.time()
if choice == "1":
return now - 24*3600, now
elif choice == "2":
return now - 3*24*3600, now
elif choice == "3":
return now - 7*24*3600, now
elif choice == "4":
return now - 30*24*3600, now
elif choice == "5":
print("请输入开始时间 (格式: YYYY-MM-DD HH:MM:SS):")
start_str = input().strip()
print("请输入结束时间 (格式: YYYY-MM-DD HH:MM:SS):")
end_str = input().strip()
try:
start_time = datetime.strptime(start_str, "%Y-%m-%d %H:%M:%S").timestamp()
end_time = datetime.strptime(end_str, "%Y-%m-%d %H:%M:%S").timestamp()
return start_time, end_time
except ValueError:
print("时间格式错误,将不限制时间范围")
return None, None
else:
return None, None
def get_top_longest_messages(messages, top_n: int = 10) -> List[Tuple[str, int, str, str]]:
"""Get top N longest messages"""
message_lengths = []
for msg in messages:
if msg.processed_plain_text is not None:
# 排除包含表情包或图片标记的消息
if contains_emoji_or_image_tags(msg.processed_plain_text):
continue
# 清理文本中的回复引用
cleaned_text = clean_reply_text(msg.processed_plain_text)
length = len(cleaned_text)
chat_name = get_chat_name(msg.chat_id)
time_str = format_timestamp(msg.time)
# 截取前100个字符作为预览
preview = cleaned_text[:100] + "..." if len(cleaned_text) > 100 else cleaned_text
message_lengths.append((chat_name, length, time_str, preview))
# 按长度排序取前N个
message_lengths.sort(key=lambda x: x[1], reverse=True)
return message_lengths[:top_n]
def analyze_text_lengths(chat_id: Optional[str] = None, start_time: Optional[float] = None, end_time: Optional[float] = None) -> None:
"""Analyze processed_plain_text lengths with optional filters"""
# 构建查询条件,排除特殊类型的消息
query = Messages.select().where(
(Messages.is_emoji != 1) &
(Messages.is_picid != 1) &
(Messages.is_command != 1)
)
if chat_id:
query = query.where(Messages.chat_id == chat_id)
if start_time:
query = query.where(Messages.time >= start_time)
if end_time:
query = query.where(Messages.time <= end_time)
messages = list(query)
if not messages:
print("没有找到符合条件的消息")
return
# 计算统计信息
distribution = calculate_text_length_distribution(messages)
stats = get_text_length_stats(messages)
top_longest = get_top_longest_messages(messages, 10)
# 显示结果
print("\n=== Processed Plain Text 长度分析结果 ===")
print("(已排除表情、图片ID、命令类型消息已排除[表情包]和[图片]标记消息,已清理回复引用)")
if chat_id:
print(f"聊天: {get_chat_name(chat_id)}")
else:
print("聊天: 全部聊天")
if start_time and end_time:
print(f"时间范围: {format_timestamp(start_time)}{format_timestamp(end_time)}")
elif start_time:
print(f"时间范围: {format_timestamp(start_time)} 之后")
elif end_time:
print(f"时间范围: {format_timestamp(end_time)} 之前")
else:
print("时间范围: 不限制")
print("\n基本统计:")
print(f"总消息数量: {len(messages)}")
print(f"有文本消息数量: {stats['count']}")
print(f"空文本消息数量: {stats['null_count']}")
print(f"被排除的消息数量: {stats['excluded_count']}")
if stats['count'] > 0:
print(f"最短长度: {stats['min']} 字符")
print(f"最长长度: {stats['max']} 字符")
print(f"平均长度: {stats['avg']:.2f} 字符")
print(f"中位数长度: {stats['median']:.2f} 字符")
print("\n文本长度分布:")
total = stats['count']
if total > 0:
for range_name, count in distribution.items():
if count > 0:
percentage = count / total * 100
print(f"{range_name} 字符: {count} ({percentage:.2f}%)")
# 显示最长的消息
if top_longest:
print(f"\n最长的 {len(top_longest)} 条消息:")
for i, (chat_name, length, time_str, preview) in enumerate(top_longest, 1):
print(f"{i}. [{chat_name}] {time_str}")
print(f" 长度: {length} 字符")
print(f" 预览: {preview}")
print()
def interactive_menu() -> None:
"""Interactive menu for text length analysis"""
while True:
print("\n" + "="*50)
print("Processed Plain Text 长度分析工具")
print("="*50)
print("1. 分析全部聊天")
print("2. 选择特定聊天分析")
print("q. 退出")
choice = input("\n请选择分析模式 (1-2, q): ").strip()
if choice.lower() == 'q':
print("再见!")
break
chat_id = None
if choice == "2":
# 显示可用的聊天列表
chats = get_available_chats()
if not chats:
print("没有找到聊天数据")
continue
print(f"\n可用的聊天 (共{len(chats)}个):")
for i, (_cid, name, count) in enumerate(chats, 1):
print(f"{i}. {name} ({count}条消息)")
try:
chat_choice = int(input(f"\n请选择聊天 (1-{len(chats)}): ").strip())
if 1 <= chat_choice <= len(chats):
chat_id = chats[chat_choice - 1][0]
else:
print("无效选择")
continue
except ValueError:
print("请输入有效数字")
continue
elif choice != "1":
print("无效选择")
continue
# 获取时间范围
start_time, end_time = get_time_range_input()
# 执行分析
analyze_text_lengths(chat_id, start_time, end_time)
input("\n按回车键继续...")
if __name__ == "__main__":
interactive_menu()

View File

@@ -0,0 +1,570 @@
import asyncio
import time
import traceback
import random
from typing import List, Optional, Dict, Any, Tuple, TYPE_CHECKING
from rich.traceback import install
from src.config.config import global_config
from src.common.logger import get_logger
from src.common.data_models.info_data_model import ActionPlannerInfo
from src.common.data_models.message_data_model import ReplyContentType
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
from src.chat.utils.prompt_builder import global_prompt_manager
from src.chat.utils.timer_calculator import Timer
from src.chat.brain_chat.brain_planner import BrainPlanner
from src.chat.planner_actions.action_modifier import ActionModifier
from src.chat.planner_actions.action_manager import ActionManager
from src.chat.heart_flow.hfc_utils import CycleDetail
from src.chat.express.expression_learner import expression_learner_manager
from src.person_info.person_info import Person
from src.plugin_system.base.component_types import EventType, ActionInfo
from src.plugin_system.core import events_manager
from src.plugin_system.apis import generator_api, send_api, message_api, database_api
from src.chat.utils.chat_message_builder import (
build_readable_messages_with_id,
get_raw_msg_before_timestamp_with_chat,
)
if TYPE_CHECKING:
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.data_models.message_data_model import ReplySetModel
ERROR_LOOP_INFO = {
"loop_plan_info": {
"action_result": {
"action_type": "error",
"action_data": {},
"reasoning": "循环处理失败",
},
},
"loop_action_info": {
"action_taken": False,
"reply_text": "",
"command": "",
"taken_time": time.time(),
},
}
install(extra_lines=3)
# 注释:原来的动作修改超时常量已移除,因为改为顺序执行
logger = get_logger("bc") # Logger Name Changed
class BrainChatting:
"""
管理一个连续的私聊Brain Chat循环
用于在特定聊天流中生成回复。
"""
def __init__(self, chat_id: str):
"""
BrainChatting 初始化函数
参数:
chat_id: 聊天流唯一标识符(如stream_id)
on_stop_focus_chat: 当收到stop_focus_chat命令时调用的回调函数
performance_version: 性能记录版本号,用于区分不同启动版本
"""
# 基础属性
self.stream_id: str = chat_id # 聊天流ID
self.chat_stream: ChatStream = get_chat_manager().get_stream(self.stream_id) # type: ignore
if not self.chat_stream:
raise ValueError(f"无法找到聊天流: {self.stream_id}")
self.log_prefix = f"[{get_chat_manager().get_stream_name(self.stream_id) or self.stream_id}]"
self.expression_learner = expression_learner_manager.get_expression_learner(self.stream_id)
self.action_manager = ActionManager()
self.action_planner = BrainPlanner(chat_id=self.stream_id, action_manager=self.action_manager)
self.action_modifier = ActionModifier(action_manager=self.action_manager, chat_id=self.stream_id)
# 循环控制内部状态
self.running: bool = False
self._loop_task: Optional[asyncio.Task] = None # 主循环任务
# 添加循环信息管理相关的属性
self.history_loop: List[CycleDetail] = []
self._cycle_counter = 0
self._current_cycle_detail: CycleDetail = None # type: ignore
self.last_read_time = time.time() - 2
self.more_plan = False
async def start(self):
"""检查是否需要启动主循环,如果未激活则启动。"""
# 如果循环已经激活,直接返回
if self.running:
logger.debug(f"{self.log_prefix} BrainChatting 已激活,无需重复启动")
return
try:
# 标记为活动状态,防止重复启动
self.running = True
self._loop_task = asyncio.create_task(self._main_chat_loop())
self._loop_task.add_done_callback(self._handle_loop_completion)
logger.info(f"{self.log_prefix} BrainChatting 启动完成")
except Exception as e:
# 启动失败时重置状态
self.running = False
self._loop_task = None
logger.error(f"{self.log_prefix} BrainChatting 启动失败: {e}")
raise
def _handle_loop_completion(self, task: asyncio.Task):
"""当 _hfc_loop 任务完成时执行的回调。"""
try:
if exception := task.exception():
logger.error(f"{self.log_prefix} BrainChatting: 脱离了聊天(异常): {exception}")
logger.error(traceback.format_exc()) # Log full traceback for exceptions
else:
logger.info(f"{self.log_prefix} BrainChatting: 脱离了聊天 (外部停止)")
except asyncio.CancelledError:
logger.info(f"{self.log_prefix} BrainChatting: 结束了聊天")
def start_cycle(self) -> Tuple[Dict[str, float], str]:
self._cycle_counter += 1
self._current_cycle_detail = CycleDetail(self._cycle_counter)
self._current_cycle_detail.thinking_id = f"tid{str(round(time.time(), 2))}"
cycle_timers = {}
return cycle_timers, self._current_cycle_detail.thinking_id
def end_cycle(self, loop_info, cycle_timers):
self._current_cycle_detail.set_loop_info(loop_info)
self.history_loop.append(self._current_cycle_detail)
self._current_cycle_detail.timers = cycle_timers
self._current_cycle_detail.end_time = time.time()
def print_cycle_info(self, cycle_timers):
# 记录循环信息和计时器结果
timer_strings = []
for name, elapsed in cycle_timers.items():
formatted_time = f"{elapsed * 1000:.2f}毫秒" if elapsed < 1 else f"{elapsed:.2f}"
timer_strings.append(f"{name}: {formatted_time}")
logger.info(
f"{self.log_prefix}{self._current_cycle_detail.cycle_id}次思考,"
f"耗时: {self._current_cycle_detail.end_time - self._current_cycle_detail.start_time:.1f}" # type: ignore
+ (f"\n详情: {'; '.join(timer_strings)}" if timer_strings else "")
)
async def _loopbody(self): # sourcery skip: hoist-if-from-if
recent_messages_list = message_api.get_messages_by_time_in_chat(
chat_id=self.stream_id,
start_time=self.last_read_time,
end_time=time.time(),
limit=20,
limit_mode="latest",
filter_mai=True,
filter_command=True,
)
if len(recent_messages_list) >= 1:
self.last_read_time = time.time()
await self._observe(recent_messages_list=recent_messages_list)
else:
# Normal模式消息数量不足等待
await asyncio.sleep(0.2)
return True
return True
async def _send_and_store_reply(
self,
response_set: "ReplySetModel",
action_message: "DatabaseMessages",
cycle_timers: Dict[str, float],
thinking_id,
actions,
selected_expressions: Optional[List[int]] = None,
) -> Tuple[Dict[str, Any], str, Dict[str, float]]:
with Timer("回复发送", cycle_timers):
reply_text = await self._send_response(
reply_set=response_set,
message_data=action_message,
selected_expressions=selected_expressions,
)
# 获取 platform如果不存在则从 chat_stream 获取,如果还是 None 则使用默认值
platform = action_message.chat_info.platform
if platform is None:
platform = getattr(self.chat_stream, "platform", "unknown")
person = Person(platform=platform, user_id=action_message.user_info.user_id)
person_name = person.person_name
action_prompt_display = f"你对{person_name}进行了回复:{reply_text}"
await database_api.store_action_info(
chat_stream=self.chat_stream,
action_build_into_prompt=False,
action_prompt_display=action_prompt_display,
action_done=True,
thinking_id=thinking_id,
action_data={"reply_text": reply_text},
action_name="reply",
)
# 构建循环信息
loop_info: Dict[str, Any] = {
"loop_plan_info": {
"action_result": actions,
},
"loop_action_info": {
"action_taken": True,
"reply_text": reply_text,
"command": "",
"taken_time": time.time(),
},
}
return loop_info, reply_text, cycle_timers
async def _observe(
self, # interest_value: float = 0.0,
recent_messages_list: Optional[List["DatabaseMessages"]] = None,
) -> bool: # sourcery skip: merge-else-if-into-elif, remove-redundant-if
if recent_messages_list is None:
recent_messages_list = []
_reply_text = "" # 初始化reply_text变量避免UnboundLocalError
async with global_prompt_manager.async_message_scope(self.chat_stream.context.get_template_name()):
await self.expression_learner.trigger_learning_for_chat()
cycle_timers, thinking_id = self.start_cycle()
logger.info(f"{self.log_prefix} 开始第{self._cycle_counter}次思考")
# 第一步:动作检查
available_actions: Dict[str, ActionInfo] = {}
try:
await self.action_modifier.modify_actions()
available_actions = self.action_manager.get_using_actions()
except Exception as e:
logger.error(f"{self.log_prefix} 动作修改失败: {e}")
# 执行planner
is_group_chat, chat_target_info, _ = self.action_planner.get_necessary_info()
message_list_before_now = get_raw_msg_before_timestamp_with_chat(
chat_id=self.stream_id,
timestamp=time.time(),
limit=int(global_config.chat.max_context_size * 0.6),
)
chat_content_block, message_id_list = build_readable_messages_with_id(
messages=message_list_before_now,
timestamp_mode="normal_no_YMD",
read_mark=self.action_planner.last_obs_time_mark,
truncate=True,
show_actions=True,
)
prompt_info = await self.action_planner.build_planner_prompt(
is_group_chat=is_group_chat,
chat_target_info=chat_target_info,
current_available_actions=available_actions,
chat_content_block=chat_content_block,
message_id_list=message_id_list,
interest=global_config.personality.interest,
)
continue_flag, modified_message = await events_manager.handle_mai_events(
EventType.ON_PLAN, None, prompt_info[0], None, self.chat_stream.stream_id
)
if not continue_flag:
return False
if modified_message and modified_message._modify_flags.modify_llm_prompt:
prompt_info = (modified_message.llm_prompt, prompt_info[1])
with Timer("规划器", cycle_timers):
action_to_use_info = await self.action_planner.plan(
loop_start_time=self.last_read_time,
available_actions=available_actions,
)
# 3. 并行执行所有动作
action_tasks = [
asyncio.create_task(
self._execute_action(action, action_to_use_info, thinking_id, available_actions, cycle_timers)
)
for action in action_to_use_info
]
# 并行执行所有任务
results = await asyncio.gather(*action_tasks, return_exceptions=True)
# 处理执行结果
reply_loop_info = None
reply_text_from_reply = ""
action_success = False
action_reply_text = ""
for result in results:
if isinstance(result, BaseException):
logger.error(f"{self.log_prefix} 动作执行异常: {result}")
continue
if result["action_type"] != "reply":
action_success = result["success"]
action_reply_text = result["reply_text"]
elif result["action_type"] == "reply":
if result["success"]:
reply_loop_info = result["loop_info"]
reply_text_from_reply = result["reply_text"]
else:
logger.warning(f"{self.log_prefix} 回复动作执行失败")
# 构建最终的循环信息
if reply_loop_info:
# 如果有回复信息使用回复的loop_info作为基础
loop_info = reply_loop_info
# 更新动作执行信息
loop_info["loop_action_info"].update(
{
"action_taken": action_success,
"taken_time": time.time(),
}
)
_reply_text = reply_text_from_reply
else:
# 没有回复信息构建纯动作的loop_info
loop_info = {
"loop_plan_info": {
"action_result": action_to_use_info,
},
"loop_action_info": {
"action_taken": action_success,
"reply_text": action_reply_text,
"taken_time": time.time(),
},
}
_reply_text = action_reply_text
self.end_cycle(loop_info, cycle_timers)
self.print_cycle_info(cycle_timers)
return True
async def _main_chat_loop(self):
"""主循环,持续进行计划并可能回复消息,直到被外部取消。"""
try:
while self.running:
# 主循环
success = await self._loopbody()
await asyncio.sleep(0.1)
if not success:
break
except asyncio.CancelledError:
# 设置了关闭标志位后被取消是正常流程
logger.info(f"{self.log_prefix} 麦麦已关闭聊天")
except Exception:
logger.error(f"{self.log_prefix} 麦麦聊天意外错误将于3s后尝试重新启动")
print(traceback.format_exc())
await asyncio.sleep(3)
self._loop_task = asyncio.create_task(self._main_chat_loop())
logger.error(f"{self.log_prefix} 结束了当前聊天循环")
async def _handle_action(
self,
action: str,
reasoning: str,
action_data: dict,
cycle_timers: Dict[str, float],
thinking_id: str,
action_message: Optional["DatabaseMessages"] = None,
) -> tuple[bool, str, str]:
"""
处理规划动作,使用动作工厂创建相应的动作处理器
参数:
action: 动作类型
reasoning: 决策理由
action_data: 动作数据,包含不同动作需要的参数
cycle_timers: 计时器字典
thinking_id: 思考ID
返回:
tuple[bool, str, str]: (是否执行了动作, 思考消息ID, 命令)
"""
try:
# 使用工厂创建动作处理器实例
try:
action_handler = self.action_manager.create_action(
action_name=action,
action_data=action_data,
action_reasoning=reasoning,
cycle_timers=cycle_timers,
thinking_id=thinking_id,
chat_stream=self.chat_stream,
log_prefix=self.log_prefix,
action_message=action_message,
)
except Exception as e:
logger.error(f"{self.log_prefix} 创建动作处理器时出错: {e}")
traceback.print_exc()
return False, "", ""
if not action_handler:
logger.warning(f"{self.log_prefix} 未能创建动作处理器: {action}")
return False, "", ""
# 处理动作并获取结果(固定记录一次动作信息)
result = await action_handler.run()
success, action_text = result
command = ""
return success, action_text, command
except Exception as e:
logger.error(f"{self.log_prefix} 处理{action}时出错: {e}")
traceback.print_exc()
return False, "", ""
async def _send_response(
self,
reply_set: "ReplySetModel",
message_data: "DatabaseMessages",
selected_expressions: Optional[List[int]] = None,
) -> str:
new_message_count = message_api.count_new_messages(
chat_id=self.chat_stream.stream_id, start_time=self.last_read_time, end_time=time.time()
)
need_reply = new_message_count >= random.randint(2, 4)
if need_reply:
logger.info(f"{self.log_prefix} 从思考到回复,共有{new_message_count}条新消息,使用引用回复")
reply_text = ""
first_replied = False
for reply_content in reply_set.reply_data:
if reply_content.content_type != ReplyContentType.TEXT:
continue
data: str = reply_content.content # type: ignore
if not first_replied:
await send_api.text_to_stream(
text=data,
stream_id=self.chat_stream.stream_id,
reply_message=message_data,
set_reply=need_reply,
typing=False,
selected_expressions=selected_expressions,
)
first_replied = True
else:
await send_api.text_to_stream(
text=data,
stream_id=self.chat_stream.stream_id,
reply_message=message_data,
set_reply=False,
typing=True,
selected_expressions=selected_expressions,
)
reply_text += data
return reply_text
async def _execute_action(
self,
action_planner_info: ActionPlannerInfo,
chosen_action_plan_infos: List[ActionPlannerInfo],
thinking_id: str,
available_actions: Dict[str, ActionInfo],
cycle_timers: Dict[str, float],
):
"""执行单个动作的通用函数"""
try:
with Timer(f"动作{action_planner_info.action_type}", cycle_timers):
if action_planner_info.action_type == "no_reply":
# 直接处理no_reply逻辑不再通过动作系统
reason = action_planner_info.reasoning or "选择不回复"
# logger.info(f"{self.log_prefix} 选择不回复,原因: {reason}")
# 存储no_reply信息到数据库
await database_api.store_action_info(
chat_stream=self.chat_stream,
action_build_into_prompt=False,
action_prompt_display=reason,
action_done=True,
thinking_id=thinking_id,
action_data={"reason": reason},
action_name="no_reply",
)
return {"action_type": "no_reply", "success": True, "reply_text": "", "command": ""}
elif action_planner_info.action_type == "reply":
try:
success, llm_response = await generator_api.generate_reply(
chat_stream=self.chat_stream,
reply_message=action_planner_info.action_message,
available_actions=available_actions,
chosen_actions=chosen_action_plan_infos,
reply_reason=action_planner_info.reasoning or "",
enable_tool=global_config.tool.enable_tool,
request_type="replyer",
from_plugin=False,
)
if not success or not llm_response or not llm_response.reply_set:
if action_planner_info.action_message:
logger.info(
f"{action_planner_info.action_message.processed_plain_text} 的回复生成失败"
)
else:
logger.info("回复生成失败")
return {"action_type": "reply", "success": False, "reply_text": "", "loop_info": None}
except asyncio.CancelledError:
logger.debug(f"{self.log_prefix} 并行执行:回复生成任务已被取消")
return {"action_type": "reply", "success": False, "reply_text": "", "loop_info": None}
response_set = llm_response.reply_set
selected_expressions = llm_response.selected_expressions
loop_info, reply_text, _ = await self._send_and_store_reply(
response_set=response_set,
action_message=action_planner_info.action_message, # type: ignore
cycle_timers=cycle_timers,
thinking_id=thinking_id,
actions=chosen_action_plan_infos,
selected_expressions=selected_expressions,
)
return {
"action_type": "reply",
"success": True,
"reply_text": reply_text,
"loop_info": loop_info,
}
# 其他动作
else:
# 执行普通动作
with Timer("动作执行", cycle_timers):
success, reply_text, command = await self._handle_action(
action_planner_info.action_type,
action_planner_info.reasoning or "",
action_planner_info.action_data or {},
cycle_timers,
thinking_id,
action_planner_info.action_message,
)
return {
"action_type": action_planner_info.action_type,
"success": success,
"reply_text": reply_text,
"command": command,
}
except Exception as e:
logger.error(f"{self.log_prefix} 执行动作时出错: {e}")
logger.error(f"{self.log_prefix} 错误信息: {traceback.format_exc()}")
return {
"action_type": action_planner_info.action_type,
"success": False,
"reply_text": "",
"loop_info": None,
"error": str(e),
}

View File

@@ -0,0 +1,538 @@
import json
import time
import traceback
import random
import re
from typing import Dict, Optional, Tuple, List, TYPE_CHECKING
from rich.traceback import install
from datetime import datetime
from json_repair import repair_json
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
from src.common.logger import get_logger
from src.common.data_models.info_data_model import ActionPlannerInfo
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.chat_message_builder import (
build_readable_actions,
get_actions_by_timestamp_with_chat,
build_readable_messages_with_id,
get_raw_msg_before_timestamp_with_chat,
)
from src.chat.utils.utils import get_chat_type_and_target_info
from src.chat.planner_actions.action_manager import ActionManager
from src.chat.message_receive.chat_stream import get_chat_manager
from src.plugin_system.base.component_types import ActionInfo, ComponentType, ActionActivationType
from src.plugin_system.core.component_registry import component_registry
if TYPE_CHECKING:
from src.common.data_models.info_data_model import TargetPersonInfo
from src.common.data_models.database_data_model import DatabaseMessages
logger = get_logger("planner")
install(extra_lines=3)
def init_prompt():
Prompt(
"""
{time_block}
{name_block}
你的兴趣是:{interest}
{chat_context_description},以下是具体的聊天内容
**聊天内容**
{chat_content_block}
**动作记录**
{actions_before_now_block}
**可用的action**
reply
动作描述:
进行回复,你可以自然的顺着正在进行的聊天内容进行回复或自然的提出一个问题
{{
"action": "reply",
"target_message_id":"想要回复的消息id",
"reason":"回复的原因"
}}
no_reply
动作描述:
等待,保持沉默,等待对方发言
{{
"action": "no_reply",
}}
{action_options_text}
请选择合适的action并说明触发action的消息id和选择该action的原因。消息id格式:m+数字
先输出你的选择思考理由再输出你选择的action理由是一段平文本不要分点精简。
**动作选择要求**
请你根据聊天内容,用户的最新消息和以下标准选择合适的动作:
{plan_style}
{moderation_prompt}
请选择所有符合使用要求的action动作用json格式输出如果输出多个json每个json都要单独用```json包裹你可以重复使用同一个动作或不同动作:
**示例**
// 理由文本
```json
{{
"action":"动作名",
"target_message_id":"触发动作的消息id",
//对应参数
}}
```
```json
{{
"action":"动作名",
"target_message_id":"触发动作的消息id",
//对应参数
}}
```
""",
"brain_planner_prompt",
)
Prompt(
"""
{action_name}
动作描述:{action_description}
使用条件:
{action_require}
{{
"action": "{action_name}",{action_parameters},
"target_message_id":"触发action的消息id",
"reason":"触发action的原因"
}}
""",
"brain_action_prompt",
)
class BrainPlanner:
def __init__(self, chat_id: str, action_manager: ActionManager):
self.chat_id = chat_id
self.log_prefix = f"[{get_chat_manager().get_stream_name(chat_id) or chat_id}]"
self.action_manager = action_manager
# LLM规划器配置
self.planner_llm = LLMRequest(
model_set=model_config.model_task_config.planner, request_type="planner"
) # 用于动作规划
self.last_obs_time_mark = 0.0
def find_message_by_id(
self, message_id: str, message_id_list: List[Tuple[str, "DatabaseMessages"]]
) -> Optional["DatabaseMessages"]:
# sourcery skip: use-next
"""
根据message_id从message_id_list中查找对应的原始消息
Args:
message_id: 要查找的消息ID
message_id_list: 消息ID列表格式为[{'id': str, 'message': dict}, ...]
Returns:
找到的原始消息字典如果未找到则返回None
"""
for item in message_id_list:
if item[0] == message_id:
return item[1]
return None
def _parse_single_action(
self,
action_json: dict,
message_id_list: List[Tuple[str, "DatabaseMessages"]],
current_available_actions: List[Tuple[str, ActionInfo]],
) -> List[ActionPlannerInfo]:
"""解析单个action JSON并返回ActionPlannerInfo列表"""
action_planner_infos = []
try:
action = action_json.get("action", "no_reply")
reasoning = action_json.get("reason", "未提供原因")
action_data = {key: value for key, value in action_json.items() if key not in ["action", "reason"]}
# 非no_reply动作需要target_message_id
target_message = None
if target_message_id := action_json.get("target_message_id"):
# 根据target_message_id查找原始消息
target_message = self.find_message_by_id(target_message_id, message_id_list)
if target_message is None:
logger.warning(f"{self.log_prefix}无法找到target_message_id '{target_message_id}' 对应的消息")
# 选择最新消息作为target_message
target_message = message_id_list[-1][1]
else:
target_message = message_id_list[-1][1]
logger.debug(f"{self.log_prefix}动作'{action}'缺少target_message_id使用最新消息作为target_message")
# 验证action是否可用
available_action_names = [action_name for action_name, _ in current_available_actions]
internal_action_names = ["no_reply", "reply", "wait_time"]
if action not in internal_action_names and action not in available_action_names:
logger.warning(
f"{self.log_prefix}LLM 返回了当前不可用或无效的动作: '{action}' (可用: {available_action_names}),将强制使用 'no_reply'"
)
reasoning = (
f"LLM 返回了当前不可用的动作 '{action}' (可用: {available_action_names})。原始理由: {reasoning}"
)
action = "no_reply"
# 创建ActionPlannerInfo对象
# 将列表转换为字典格式
available_actions_dict = dict(current_available_actions)
action_planner_infos.append(
ActionPlannerInfo(
action_type=action,
reasoning=reasoning,
action_data=action_data,
action_message=target_message,
available_actions=available_actions_dict,
)
)
except Exception as e:
logger.error(f"{self.log_prefix}解析单个action时出错: {e}")
# 将列表转换为字典格式
available_actions_dict = dict(current_available_actions)
action_planner_infos.append(
ActionPlannerInfo(
action_type="no_reply",
reasoning=f"解析单个action时出错: {e}",
action_data={},
action_message=None,
available_actions=available_actions_dict,
)
)
return action_planner_infos
async def plan(
self,
available_actions: Dict[str, ActionInfo],
loop_start_time: float = 0.0,
) -> List[ActionPlannerInfo]:
# sourcery skip: use-named-expression
"""
规划器 (Planner): 使用LLM根据上下文决定做出什么动作。
"""
# 获取聊天上下文
message_list_before_now = get_raw_msg_before_timestamp_with_chat(
chat_id=self.chat_id,
timestamp=time.time(),
limit=int(global_config.chat.max_context_size * 0.6),
)
message_id_list: list[Tuple[str, "DatabaseMessages"]] = []
chat_content_block, message_id_list = build_readable_messages_with_id(
messages=message_list_before_now,
timestamp_mode="normal_no_YMD",
read_mark=self.last_obs_time_mark,
truncate=True,
show_actions=True,
)
message_list_before_now_short = message_list_before_now[-int(global_config.chat.max_context_size * 0.3) :]
chat_content_block_short, message_id_list_short = build_readable_messages_with_id(
messages=message_list_before_now_short,
timestamp_mode="normal_no_YMD",
truncate=False,
show_actions=False,
)
self.last_obs_time_mark = time.time()
# 获取必要信息
is_group_chat, chat_target_info, current_available_actions = self.get_necessary_info()
# 应用激活类型过滤
filtered_actions = self._filter_actions_by_activation_type(available_actions, chat_content_block_short)
logger.debug(f"{self.log_prefix}过滤后有{len(filtered_actions)}个可用动作")
# 构建包含所有动作的提示词
prompt, message_id_list = await self.build_planner_prompt(
is_group_chat=is_group_chat,
chat_target_info=chat_target_info,
current_available_actions=filtered_actions,
chat_content_block=chat_content_block,
message_id_list=message_id_list,
interest=global_config.personality.interest,
)
# 调用LLM获取决策
actions = await self._execute_main_planner(
prompt=prompt,
message_id_list=message_id_list,
filtered_actions=filtered_actions,
available_actions=available_actions,
loop_start_time=loop_start_time,
)
return actions
async def build_planner_prompt(
self,
is_group_chat: bool,
chat_target_info: Optional["TargetPersonInfo"],
current_available_actions: Dict[str, ActionInfo],
message_id_list: List[Tuple[str, "DatabaseMessages"]],
chat_content_block: str = "",
interest: str = "",
) -> tuple[str, List[Tuple[str, "DatabaseMessages"]]]:
"""构建 Planner LLM 的提示词 (获取模板并填充数据)"""
try:
# 获取最近执行过的动作
actions_before_now = get_actions_by_timestamp_with_chat(
chat_id=self.chat_id,
timestamp_start=time.time() - 600,
timestamp_end=time.time(),
limit=6,
)
actions_before_now_block = build_readable_actions(actions=actions_before_now)
if actions_before_now_block:
actions_before_now_block = f"你刚刚选择并执行过的action是\n{actions_before_now_block}"
else:
actions_before_now_block = ""
if chat_target_info:
# 构建聊天上下文描述
chat_context_description = (
f"你正在和 {chat_target_info.person_name or chat_target_info.user_nickname or '对方'} 聊天中"
)
# 构建动作选项块
action_options_block = await self._build_action_options_block(current_available_actions)
# 其他信息
moderation_prompt_block = "请不要输出违法违规内容,不要输出色情,暴力,政治相关内容,如有敏感内容,请规避。"
time_block = f"当前时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
bot_name = global_config.bot.nickname
bot_nickname = (
f",也可以叫你{','.join(global_config.bot.alias_names)}" if global_config.bot.alias_names else ""
)
name_block = f"你的名字是{bot_name}{bot_nickname},请注意哪些是你自己的发言。"
# 获取主规划器模板并填充
planner_prompt_template = await global_prompt_manager.get_prompt_async("brain_planner_prompt")
prompt = planner_prompt_template.format(
time_block=time_block,
chat_context_description=chat_context_description,
chat_content_block=chat_content_block,
actions_before_now_block=actions_before_now_block,
action_options_text=action_options_block,
moderation_prompt=moderation_prompt_block,
name_block=name_block,
interest=interest,
plan_style=global_config.personality.private_plan_style,
)
return prompt, message_id_list
except Exception as e:
logger.error(f"构建 Planner 提示词时出错: {e}")
logger.error(traceback.format_exc())
return "构建 Planner Prompt 时出错", []
def get_necessary_info(self) -> Tuple[bool, Optional["TargetPersonInfo"], Dict[str, ActionInfo]]:
"""
获取 Planner 需要的必要信息
"""
is_group_chat = True
is_group_chat, chat_target_info = get_chat_type_and_target_info(self.chat_id)
logger.debug(f"{self.log_prefix}获取到聊天信息 - 群聊: {is_group_chat}, 目标信息: {chat_target_info}")
current_available_actions_dict = self.action_manager.get_using_actions()
# 获取完整的动作信息
all_registered_actions: Dict[str, ActionInfo] = component_registry.get_components_by_type( # type: ignore
ComponentType.ACTION
)
current_available_actions = {}
for action_name in current_available_actions_dict:
if action_name in all_registered_actions:
current_available_actions[action_name] = all_registered_actions[action_name]
else:
logger.warning(f"{self.log_prefix}使用中的动作 {action_name} 未在已注册动作中找到")
return is_group_chat, chat_target_info, current_available_actions
def _filter_actions_by_activation_type(
self, available_actions: Dict[str, ActionInfo], chat_content_block: str
) -> Dict[str, ActionInfo]:
"""根据激活类型过滤动作"""
filtered_actions = {}
for action_name, action_info in available_actions.items():
if action_info.activation_type == ActionActivationType.NEVER:
logger.debug(f"{self.log_prefix}动作 {action_name} 设置为 NEVER 激活类型,跳过")
continue
elif action_info.activation_type in [ActionActivationType.LLM_JUDGE, ActionActivationType.ALWAYS]:
filtered_actions[action_name] = action_info
elif action_info.activation_type == ActionActivationType.RANDOM:
if random.random() < action_info.random_activation_probability:
filtered_actions[action_name] = action_info
elif action_info.activation_type == ActionActivationType.KEYWORD:
if action_info.activation_keywords:
for keyword in action_info.activation_keywords:
if keyword in chat_content_block:
filtered_actions[action_name] = action_info
break
else:
logger.warning(f"{self.log_prefix}未知的激活类型: {action_info.activation_type},跳过处理")
return filtered_actions
async def _build_action_options_block(self, current_available_actions: Dict[str, ActionInfo]) -> str:
# sourcery skip: use-join
"""构建动作选项块"""
if not current_available_actions:
return ""
action_options_block = ""
for action_name, action_info in current_available_actions.items():
# 构建参数文本
param_text = ""
if action_info.action_parameters:
param_text = "\n"
for param_name, param_description in action_info.action_parameters.items():
param_text += f' "{param_name}":"{param_description}"\n'
param_text = param_text.rstrip("\n")
# 构建要求文本
require_text = ""
for require_item in action_info.action_require:
require_text += f"- {require_item}\n"
require_text = require_text.rstrip("\n")
# 获取动作提示模板并填充
using_action_prompt = await global_prompt_manager.get_prompt_async("brain_action_prompt")
using_action_prompt = using_action_prompt.format(
action_name=action_name,
action_description=action_info.description,
action_parameters=param_text,
action_require=require_text,
)
action_options_block += using_action_prompt
return action_options_block
async def _execute_main_planner(
self,
prompt: str,
message_id_list: List[Tuple[str, "DatabaseMessages"]],
filtered_actions: Dict[str, ActionInfo],
available_actions: Dict[str, ActionInfo],
loop_start_time: float,
) -> List[ActionPlannerInfo]:
"""执行主规划器"""
llm_content = None
actions: List[ActionPlannerInfo] = []
try:
# 调用LLM
llm_content, (reasoning_content, _, _) = await self.planner_llm.generate_response_async(prompt=prompt)
# logger.info(f"{self.log_prefix}规划器原始提示词: {prompt}")
# logger.info(f"{self.log_prefix}规划器原始响应: {llm_content}")
if global_config.debug.show_prompt:
logger.info(f"{self.log_prefix}规划器原始提示词: {prompt}")
logger.info(f"{self.log_prefix}规划器原始响应: {llm_content}")
if reasoning_content:
logger.info(f"{self.log_prefix}规划器推理: {reasoning_content}")
else:
logger.debug(f"{self.log_prefix}规划器原始提示词: {prompt}")
logger.debug(f"{self.log_prefix}规划器原始响应: {llm_content}")
if reasoning_content:
logger.debug(f"{self.log_prefix}规划器推理: {reasoning_content}")
except Exception as req_e:
logger.error(f"{self.log_prefix}LLM 请求执行失败: {req_e}")
return [
ActionPlannerInfo(
action_type="no_reply",
reasoning=f"LLM 请求失败,模型出现问题: {req_e}",
action_data={},
action_message=None,
available_actions=available_actions,
)
]
# 解析LLM响应
if llm_content:
try:
if json_objects := self._extract_json_from_markdown(llm_content):
logger.debug(f"{self.log_prefix}从响应中提取到{len(json_objects)}个JSON对象")
filtered_actions_list = list(filtered_actions.items())
for json_obj in json_objects:
actions.extend(self._parse_single_action(json_obj, message_id_list, filtered_actions_list))
else:
# 尝试解析为直接的JSON
logger.warning(f"{self.log_prefix}LLM没有返回可用动作: {llm_content}")
actions = self._create_no_reply("LLM没有返回可用动作", available_actions)
except Exception as json_e:
logger.warning(f"{self.log_prefix}解析LLM响应JSON失败 {json_e}. LLM原始输出: '{llm_content}'")
actions = self._create_no_reply(f"解析LLM响应JSON失败: {json_e}", available_actions)
traceback.print_exc()
else:
actions = self._create_no_reply("规划器没有获得LLM响应", available_actions)
# 添加循环开始时间到所有非no_reply动作
for action in actions:
action.action_data = action.action_data or {}
action.action_data["loop_start_time"] = loop_start_time
logger.debug(
f"{self.log_prefix}规划器决定执行{len(actions)}个动作: {' '.join([a.action_type for a in actions])}"
)
return actions
def _create_no_reply(self, reasoning: str, available_actions: Dict[str, ActionInfo]) -> List[ActionPlannerInfo]:
"""创建no_reply"""
return [
ActionPlannerInfo(
action_type="no_reply",
reasoning=reasoning,
action_data={},
action_message=None,
available_actions=available_actions,
)
]
def _extract_json_from_markdown(self, content: str) -> List[dict]:
# sourcery skip: for-append-to-extend
"""从Markdown格式的内容中提取JSON对象"""
json_objects = []
# 使用正则表达式查找```json包裹的JSON内容
json_pattern = r"```json\s*(.*?)\s*```"
matches = re.findall(json_pattern, content, re.DOTALL)
for match in matches:
try:
# 清理可能的注释和格式问题
json_str = re.sub(r"//.*?\n", "\n", match) # 移除单行注释
json_str = re.sub(r"/\*.*?\*/", "", json_str, flags=re.DOTALL) # 移除多行注释
if json_str := json_str.strip():
json_obj = json.loads(repair_json(json_str))
if isinstance(json_obj, dict):
json_objects.append(json_obj)
elif isinstance(json_obj, list):
for item in json_obj:
if isinstance(item, dict):
json_objects.append(item)
except Exception as e:
logger.warning(f"解析JSON块失败: {e}, 块内容: {match[:100]}...")
continue
return json_objects
init_prompt()

View File

@@ -708,7 +708,7 @@ class EmojiManager:
if not emoji.is_deleted and emoji.hash == emoji_hash:
return emoji
return None # 如果循环结束还没找到,则返回 None
async def get_emoji_tag_by_hash(self, emoji_hash: str) -> Optional[List[str]]:
"""根据哈希值获取已注册表情包的情感标签列表
@@ -731,7 +731,7 @@ class EmojiManager:
emoji_record = Emoji.get_or_none(Emoji.emoji_hash == emoji_hash)
if emoji_record and emoji_record.emotion:
logger.info(f"[缓存命中] 从数据库获取表情包情感标签: {emoji_record.emotion[:50]}...")
return emoji_record.emotion.split(',')
return emoji_record.emotion.split(",")
except Exception as e:
logger.error(f"从数据库查询表情包情感标签时出错: {e}")

View File

@@ -3,21 +3,25 @@ import random
import json
import os
from datetime import datetime
import jieba
from typing import List, Dict, Optional, Any, Tuple
import traceback
from src.common.logger import get_logger
from src.common.database.database_model import Expression
from src.common.data_models.database_data_model import DatabaseMessages
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config, global_config
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat_inclusive, build_anonymous_messages
from src.chat.utils.chat_message_builder import (
get_raw_msg_by_timestamp_with_chat_inclusive,
build_anonymous_messages,
build_bare_messages,
)
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.message_receive.chat_stream import get_chat_manager
from json_repair import repair_json
MAX_EXPRESSION_COUNT = 300
DECAY_DAYS = 30 # 30天衰减到0.01
DECAY_DAYS = 15 # 30天衰减到0.01
DECAY_MIN = 0.01 # 最小衰减值
logger = get_logger("expressor")
@@ -46,48 +50,63 @@ def init_prompt() -> None:
例如:当"AAAAA"时,可以"BBBBB", AAAAA代表某个具体的场景不超过20个字。BBBBB代表对应的语言风格特定句式或表达方式不超过20个字。
例如:
"对某件事表示十分惊叹,有些意外"时,使用"我嘞个xxxx"
"表示讽刺的赞同,不讲道理"时,使用"对对对"
"想说明某个具体的事实观点,但懒得明说,或者不便明说,或表达一种默契"使用"懂的都懂"
"当涉及游戏相关时,表示意外的夸赞,略带戏谑意味"时,使用"这么强!"
"对某件事表示十分惊叹"时,使用"我嘞个xxxx"
"表示讽刺的赞同,不讲道理"时,使用"对对对"
"想说明某个具体的事实观点,但懒得明说,使用"懂的都懂"
"当涉及游戏相关时,夸赞,略带戏谑意味"时,使用"这么强!"
请注意不要总结你自己SELF的发言尽量保证总结内容的逻辑性
现在请你概括
"""
Prompt(learn_style_prompt, "learn_style_prompt")
match_expression_context_prompt = """
**聊天内容**
{chat_str}
**从聊天内容总结的表达方式pairs**
{expression_pairs}
请你为上面的每一条表达方式找到该表达方式的原文句子并输出匹配结果expression_pair不能有重复每个expression_pair仅输出一个最合适的context。
如果找不到原句,就不输出该句的匹配结果。
以json格式输出
格式如下:
{{
"expression_pair": "表达方式pair的序号数字",
"context": "与表达方式对应的原文句子的原始内容,不要修改原文句子的内容",
}}
{{
"expression_pair": "表达方式pair的序号数字",
"context": "与表达方式对应的原文句子的原始内容,不要修改原文句子的内容",
}}
...
现在请你输出匹配结果:
"""
Prompt(match_expression_context_prompt, "match_expression_context_prompt")
class ExpressionLearner:
def __init__(self, chat_id: str) -> None:
self.express_learn_model: LLMRequest = LLMRequest(
model_set=model_config.model_task_config.replyer, request_type="expression.learner"
model_set=model_config.model_task_config.utils, request_type="expression.learner"
)
self.embedding_model: LLMRequest = LLMRequest(
model_set=model_config.model_task_config.embedding, request_type="expression.embedding"
)
self.chat_id = chat_id
self.chat_stream = get_chat_manager().get_stream(chat_id)
self.chat_name = get_chat_manager().get_stream_name(chat_id) or chat_id
# 维护每个chat的上次学习时间
self.last_learning_time: float = time.time()
# 学习参数
self.min_messages_for_learning = 25 # 触发学习所需的最少消息数
self.min_learning_interval = 300 # 最短学习时间间隔(秒)
def can_learn_for_chat(self) -> bool:
"""
检查指定聊天流是否允许学习表达
Args:
chat_id: 聊天流ID
Returns:
bool: 是否允许学习
"""
try:
use_expression, enable_learning, _ = global_config.expression.get_expression_config_for_chat(self.chat_id)
return enable_learning
except Exception as e:
logger.error(f"检查学习权限失败: {e}")
return False
_, self.enable_learning, self.learning_intensity = global_config.expression.get_expression_config_for_chat(
self.chat_id
)
self.min_messages_for_learning = 15 / self.learning_intensity # 触发学习所需的最少消息数
self.min_learning_interval = 150 / self.learning_intensity
def should_trigger_learning(self) -> bool:
"""
@@ -99,27 +118,13 @@ class ExpressionLearner:
Returns:
bool: 是否应该触发学习
"""
current_time = time.time()
# 获取该聊天流的学习强度
try:
_, enable_learning, learning_intensity = global_config.expression.get_expression_config_for_chat(
self.chat_id
)
except Exception as e:
logger.error(f"获取聊天流 {self.chat_id} 的学习配置失败: {e}")
return False
# 检查是否允许学习
if not enable_learning:
if not self.enable_learning:
return False
# 根据学习强度计算最短学习时间间隔
min_interval = self.min_learning_interval / learning_intensity
# 检查时间间隔
time_diff = current_time - self.last_learning_time
if time_diff < min_interval:
time_diff = time.time() - self.last_learning_time
if time_diff < self.min_learning_interval:
return False
# 检查消息数量(只检查指定聊天流的消息)
@@ -165,6 +170,7 @@ class ExpressionLearner:
except Exception as e:
logger.error(f"为聊天流 {self.chat_name} 触发学习失败: {e}")
traceback.print_exc()
return False
def _apply_global_decay_to_database(self, current_time: float) -> None:
@@ -229,39 +235,43 @@ class ExpressionLearner:
"""
学习并存储表达方式
"""
# 检查是否允许在此聊天流中学习(在函数最前面检查)
if not self.can_learn_for_chat():
logger.debug(f"聊天流 {self.chat_name} 不允许学习表达,跳过学习")
return []
res = await self.learn_expression(num)
if res is None:
return []
learnt_expressions, chat_id = res
chat_stream = get_chat_manager().get_stream(chat_id)
if chat_stream is None:
group_name = f"聊天流 {chat_id}"
elif chat_stream.group_info:
group_name = chat_stream.group_info.group_name
else:
group_name = f"{chat_stream.user_info.user_nickname}的私聊"
learnt_expressions_str = ""
for _chat_id, situation, style in learnt_expressions:
learnt_expressions_str += f"{situation}->{style}\n"
logger.info(f"{group_name} 学习到表达风格:\n{learnt_expressions_str}")
if not learnt_expressions:
logger.info("没有学习到表达风格")
return []
learnt_expressions = res
learnt_expressions_str = ""
for (
_chat_id,
situation,
style,
_context,
_context_words,
) in learnt_expressions:
learnt_expressions_str += f"{situation}->{style}\n"
logger.info(f"{self.chat_name} 学习到表达风格:\n{learnt_expressions_str}")
# 按chat_id分组
chat_dict: Dict[str, List[Dict[str, Any]]] = {}
for chat_id, situation, style in learnt_expressions:
for (
chat_id,
situation,
style,
context,
context_words,
) in learnt_expressions:
if chat_id not in chat_dict:
chat_dict[chat_id] = []
chat_dict[chat_id].append({"situation": situation, "style": style})
chat_dict[chat_id].append(
{
"situation": situation,
"style": style,
"context": context,
"context_words": context_words,
}
)
current_time = time.time()
@@ -281,6 +291,8 @@ class ExpressionLearner:
if random.random() < 0.5:
expr_obj.situation = new_expr["situation"]
expr_obj.style = new_expr["style"]
expr_obj.context = new_expr["context"]
expr_obj.context_words = new_expr["context_words"]
expr_obj.count = expr_obj.count + 1
expr_obj.last_active_time = current_time
expr_obj.save()
@@ -293,6 +305,8 @@ class ExpressionLearner:
chat_id=chat_id,
type="style",
create_date=current_time, # 手动设置创建日期
context=new_expr["context"],
context_words=new_expr["context_words"],
)
# 限制最大数量
exprs = list(
@@ -306,7 +320,105 @@ class ExpressionLearner:
expr.delete_instance()
return learnt_expressions
async def learn_expression(self, num: int = 10) -> Optional[Tuple[List[Tuple[str, str, str]], str]]:
async def match_expression_context(
self, expression_pairs: List[Tuple[str, str]], random_msg_match_str: str
) -> List[Tuple[str, str, str]]:
# 为expression_pairs逐个条目赋予编号并构建成字符串
numbered_pairs = []
for i, (situation, style) in enumerate(expression_pairs, 1):
numbered_pairs.append(f'{i}. 当"{situation}"时,使用"{style}"')
expression_pairs_str = "\n".join(numbered_pairs)
prompt = "match_expression_context_prompt"
prompt = await global_prompt_manager.format_prompt(
prompt,
expression_pairs=expression_pairs_str,
chat_str=random_msg_match_str,
)
response, _ = await self.express_learn_model.generate_response_async(prompt, temperature=0.3)
print(f"match_expression_context_prompt: {prompt}")
print(f"random_msg_match_str: {response}")
# 解析JSON响应
match_responses = []
try:
response = response.strip()
# 检查是否已经是标准JSON数组格式
if response.startswith("[") and response.endswith("]"):
match_responses = json.loads(response)
else:
# 尝试直接解析多个JSON对象
try:
# 如果是多个JSON对象用逗号分隔包装成数组
if response.startswith("{") and not response.startswith("["):
response = "[" + response + "]"
match_responses = json.loads(response)
else:
# 使用repair_json处理响应
repaired_content = repair_json(response)
# 确保repaired_content是列表格式
if isinstance(repaired_content, str):
try:
parsed_data = json.loads(repaired_content)
if isinstance(parsed_data, dict):
# 如果是字典,包装成列表
match_responses = [parsed_data]
elif isinstance(parsed_data, list):
match_responses = parsed_data
else:
match_responses = []
except json.JSONDecodeError:
match_responses = []
elif isinstance(repaired_content, dict):
# 如果是字典,包装成列表
match_responses = [repaired_content]
elif isinstance(repaired_content, list):
match_responses = repaired_content
else:
match_responses = []
except json.JSONDecodeError:
# 如果还是失败尝试repair_json
repaired_content = repair_json(response)
if isinstance(repaired_content, str):
parsed_data = json.loads(repaired_content)
match_responses = parsed_data if isinstance(parsed_data, list) else [parsed_data]
else:
match_responses = repaired_content if isinstance(repaired_content, list) else [repaired_content]
except (json.JSONDecodeError, Exception) as e:
logger.error(f"解析匹配响应JSON失败: {e}, 响应内容: \n{response}")
return []
matched_expressions = []
used_pair_indices = set() # 用于跟踪已经使用的expression_pair索引
for match_response in match_responses:
try:
# 获取表达方式序号
pair_index = int(match_response["expression_pair"]) - 1 # 转换为0-based索引
# 检查索引是否有效且未被使用过
if 0 <= pair_index < len(expression_pairs) and pair_index not in used_pair_indices:
situation, style = expression_pairs[pair_index]
context = match_response["context"]
matched_expressions.append((situation, style, context))
used_pair_indices.add(pair_index) # 标记该索引已使用
logger.debug(f"成功匹配表达方式 {pair_index + 1}: {situation} -> {style}")
elif pair_index in used_pair_indices:
logger.debug(f"跳过重复的表达方式 {pair_index + 1}")
except (ValueError, KeyError, IndexError) as e:
logger.error(f"解析匹配条目失败: {e}, 条目: {match_response}")
continue
return matched_expressions
async def learn_expression(
self, num: int = 10
) -> Optional[List[Tuple[str, str, str, List[str]]]]:
"""从指定聊天流学习表达方式
Args:
@@ -317,7 +429,7 @@ class ExpressionLearner:
current_time = time.time()
# 获取上次学习时间
# 获取上次学习之后的消息
random_msg = get_raw_msg_by_timestamp_with_chat_inclusive(
chat_id=self.chat_id,
timestamp_start=self.last_learning_time,
@@ -328,17 +440,18 @@ class ExpressionLearner:
if not random_msg or random_msg == []:
return None
# 转化成str
chat_id: str = random_msg[0].chat_id
_chat_id: str = random_msg[0].chat_id
# random_msg_str: str = build_readable_messages(random_msg, timestamp_mode="normal")
random_msg_str: str = await build_anonymous_messages(random_msg)
# print(f"random_msg_str:{random_msg_str}")
random_msg_match_str: str = await build_bare_messages(random_msg)
prompt: str = await global_prompt_manager.format_prompt(
prompt,
chat_str=random_msg_str,
)
logger.debug(f"学习{type_str}的prompt: {prompt}")
# print(f"random_msg_str:{random_msg_str}")
# logger.info(f"学习{type_str}的prompt: {prompt}")
try:
response, _ = await self.express_learn_model.generate_response_async(prompt, temperature=0.3)
@@ -346,13 +459,48 @@ class ExpressionLearner:
logger.error(f"学习{type_str}失败: {e}")
return None
logger.debug(f"学习{type_str}的response: {response}")
# logger.debug(f"学习{type_str}的response: {response}")
expressions: List[Tuple[str, str, str]] = self.parse_expression_response(response, chat_id)
expressions: List[Tuple[str, str]] = self.parse_expression_response(response)
return expressions, chat_id
matched_expressions: List[Tuple[str, str, str]] = await self.match_expression_context(
expressions, random_msg_match_str
)
def parse_expression_response(self, response: str, chat_id: str) -> List[Tuple[str, str, str]]:
split_matched_expressions: List[Tuple[str, str, str, List[str]]] = self.split_expression_context(
matched_expressions
)
split_matched_expressions_w_emb = []
for situation, style, context, context_words in split_matched_expressions:
split_matched_expressions_w_emb.append(
(self.chat_id, situation, style, context, context_words)
)
return split_matched_expressions_w_emb
def split_expression_context(
self, matched_expressions: List[Tuple[str, str, str]]
) -> List[Tuple[str, str, str, List[str]]]:
"""
对matched_expressions中的context部分进行jieba分词
Args:
matched_expressions: 匹配到的表达方式列表,每个元素为(situation, style, context)
Returns:
添加了分词结果的表达方式列表,每个元素为(situation, style, context, context_words)
"""
result = []
for situation, style, context in matched_expressions:
# 使用jieba进行分词
context_words = list(jieba.cut(context))
result.append((situation, style, context, context_words))
return result
def parse_expression_response(self, response: str) -> List[Tuple[str, str, str]]:
"""
解析LLM返回的表达风格总结每一行提取"""使用"之间的内容,存储为(situation, style)元组
"""
@@ -379,7 +527,7 @@ class ExpressionLearner:
if idx_quote4 == -1:
continue
style = line[idx_quote3 + 1 : idx_quote4]
expressions.append((chat_id, situation, style))
expressions.append((situation, style))
return expressions
@@ -391,8 +539,6 @@ class ExpressionLearnerManager:
self.expression_learners = {}
self._ensure_expression_directories()
self._auto_migrate_json_to_db()
self._migrate_old_data_create_date()
def get_expression_learner(self, chat_id: str) -> ExpressionLearner:
if chat_id not in self.expression_learners:
@@ -417,189 +563,5 @@ class ExpressionLearnerManager:
except Exception as e:
logger.error(f"创建目录失败 {directory}: {e}")
def _auto_migrate_json_to_db(self):
"""
自动将/data/expression/learnt_style 和 learnt_grammar 下所有expressions.json迁移到数据库。
迁移完成后在/data/expression/done.done写入标记文件存在则跳过。
然后检查done.done2如果没有就删除所有grammar表达并创建该标记文件。
"""
base_dir = os.path.join("data", "expression")
done_flag = os.path.join(base_dir, "done.done")
done_flag2 = os.path.join(base_dir, "done.done2")
# 确保基础目录存在
try:
os.makedirs(base_dir, exist_ok=True)
logger.debug(f"确保目录存在: {base_dir}")
except Exception as e:
logger.error(f"创建表达方式目录失败: {e}")
return
if os.path.exists(done_flag):
logger.info("表达方式JSON已迁移无需重复迁移。")
else:
logger.info("开始迁移表达方式JSON到数据库...")
migrated_count = 0
for type in ["learnt_style", "learnt_grammar"]:
type_str = "style" if type == "learnt_style" else "grammar"
type_dir = os.path.join(base_dir, type)
if not os.path.exists(type_dir):
logger.debug(f"目录不存在,跳过: {type_dir}")
continue
try:
chat_ids = os.listdir(type_dir)
logger.debug(f"{type_dir} 中找到 {len(chat_ids)} 个聊天ID目录")
except Exception as e:
logger.error(f"读取目录失败 {type_dir}: {e}")
continue
for chat_id in chat_ids:
expr_file = os.path.join(type_dir, chat_id, "expressions.json")
if not os.path.exists(expr_file):
continue
try:
with open(expr_file, "r", encoding="utf-8") as f:
expressions = json.load(f)
if not isinstance(expressions, list):
logger.warning(f"表达方式文件格式错误,跳过: {expr_file}")
continue
for expr in expressions:
if not isinstance(expr, dict):
continue
situation = expr.get("situation")
style_val = expr.get("style")
count = expr.get("count", 1)
last_active_time = expr.get("last_active_time", time.time())
if not situation or not style_val:
logger.warning(f"表达方式缺少必要字段,跳过: {expr}")
continue
# 查重同chat_id+type+situation+style
from src.common.database.database_model import Expression
query = Expression.select().where(
(Expression.chat_id == chat_id)
& (Expression.type == type_str)
& (Expression.situation == situation)
& (Expression.style == style_val)
)
if query.exists():
expr_obj = query.get()
expr_obj.count = max(expr_obj.count, count)
expr_obj.last_active_time = max(expr_obj.last_active_time, last_active_time)
expr_obj.save()
else:
Expression.create(
situation=situation,
style=style_val,
count=count,
last_active_time=last_active_time,
chat_id=chat_id,
type=type_str,
create_date=last_active_time, # 迁移时使用last_active_time作为创建时间
)
migrated_count += 1
logger.info(f"已迁移 {expr_file} 到数据库,包含 {len(expressions)} 个表达方式")
except json.JSONDecodeError as e:
logger.error(f"JSON解析失败 {expr_file}: {e}")
except Exception as e:
logger.error(f"迁移表达方式 {expr_file} 失败: {e}")
# 标记迁移完成
try:
# 确保done.done文件的父目录存在
done_parent_dir = os.path.dirname(done_flag)
if not os.path.exists(done_parent_dir):
os.makedirs(done_parent_dir, exist_ok=True)
logger.debug(f"为done.done创建父目录: {done_parent_dir}")
with open(done_flag, "w", encoding="utf-8") as f:
f.write("done\n")
logger.info(f"表达方式JSON迁移已完成共迁移 {migrated_count} 个表达方式已写入done.done标记文件")
except PermissionError as e:
logger.error(f"权限不足无法写入done.done标记文件: {e}")
except OSError as e:
logger.error(f"文件系统错误无法写入done.done标记文件: {e}")
except Exception as e:
logger.error(f"写入done.done标记文件失败: {e}")
# 检查并处理grammar表达删除
if not os.path.exists(done_flag2):
logger.info("开始删除所有grammar类型的表达...")
try:
deleted_count = self.delete_all_grammar_expressions()
logger.info(f"grammar表达删除完成共删除 {deleted_count} 个表达")
# 创建done.done2标记文件
with open(done_flag2, "w", encoding="utf-8") as f:
f.write("done\n")
logger.info("已创建done.done2标记文件grammar表达删除标记完成")
except Exception as e:
logger.error(f"删除grammar表达或创建标记文件失败: {e}")
else:
logger.info("grammar表达已删除跳过重复删除")
def _migrate_old_data_create_date(self):
"""
为没有create_date的老数据设置创建日期
使用last_active_time作为create_date的默认值
"""
try:
# 查找所有create_date为空的表达方式
old_expressions = Expression.select().where(Expression.create_date.is_null())
updated_count = 0
for expr in old_expressions:
# 使用last_active_time作为create_date
expr.create_date = expr.last_active_time
expr.save()
updated_count += 1
if updated_count > 0:
logger.info(f"已为 {updated_count} 个老的表达方式设置创建日期")
except Exception as e:
logger.error(f"迁移老数据创建日期失败: {e}")
def delete_all_grammar_expressions(self) -> int:
"""
检查expression库中所有type为"grammar"的表达并全部删除
Returns:
int: 删除的grammar表达数量
"""
try:
# 查询所有type为"grammar"的表达
grammar_expressions = Expression.select().where(Expression.type == "grammar")
grammar_count = grammar_expressions.count()
if grammar_count == 0:
logger.info("expression库中没有找到grammar类型的表达")
return 0
logger.info(f"找到 {grammar_count} 个grammar类型的表达开始删除...")
# 删除所有grammar类型的表达
deleted_count = 0
for expr in grammar_expressions:
try:
expr.delete_instance()
deleted_count += 1
except Exception as e:
logger.error(f"删除grammar表达失败: {e}")
continue
logger.info(f"成功删除 {deleted_count} 个grammar类型的表达")
return deleted_count
except Exception as e:
logger.error(f"删除grammar表达过程中发生错误: {e}")
return 0
expression_learner_manager = ExpressionLearnerManager()

View File

@@ -77,10 +77,10 @@ class ExpressionSelector:
def can_use_expression_for_chat(self, chat_id: str) -> bool:
"""
检查指定聊天流是否允许使用表达
Args:
chat_id: 聊天流ID
Returns:
bool: 是否允许使用表达
"""
@@ -114,6 +114,20 @@ class ExpressionSelector:
def get_related_chat_ids(self, chat_id: str) -> List[str]:
"""根据expression_groups配置获取与当前chat_id相关的所有chat_id包括自身"""
groups = global_config.expression.expression_groups
# 检查是否存在全局共享组(包含"*"的组)
global_group_exists = any("*" in group for group in groups)
if global_group_exists:
# 如果存在全局共享组则返回所有可用的chat_id
all_chat_ids = set()
for group in groups:
for stream_config_str in group:
if chat_id_candidate := self._parse_stream_config_to_chat_id(stream_config_str):
all_chat_ids.add(chat_id_candidate)
return list(all_chat_ids) if all_chat_ids else [chat_id]
# 否则使用现有的组逻辑
for group in groups:
group_chat_ids = []
for stream_config_str in group:
@@ -123,9 +137,7 @@ class ExpressionSelector:
return group_chat_ids
return [chat_id]
def get_random_expressions(
self, chat_id: str, total_num: int
) -> List[Dict[str, Any]]:
def get_random_expressions(self, chat_id: str, total_num: int) -> List[Dict[str, Any]]:
# sourcery skip: extract-duplicate-method, move-assign
# 支持多chat_id合并抽选
related_chat_ids = self.get_related_chat_ids(chat_id)
@@ -200,15 +212,15 @@ class ExpressionSelector:
) -> Tuple[List[Dict[str, Any]], List[int]]:
# sourcery skip: inline-variable, list-comprehension
"""使用LLM选择适合的表达方式"""
# 检查是否允许在此聊天流中使用表达
if not self.can_use_expression_for_chat(chat_id):
logger.debug(f"聊天流 {chat_id} 不允许使用表达,返回空列表")
return [], []
# 1. 获取20个随机表达方式现在按权重抽取
style_exprs = self.get_random_expressions(chat_id, 10)
style_exprs = self.get_random_expressions(chat_id, 20)
if len(style_exprs) < 10:
logger.info(f"聊天流 {chat_id} 表达方式正在积累中")
return [], []
@@ -248,7 +260,6 @@ class ExpressionSelector:
# 4. 调用LLM
try:
# start_time = time.time()
content, (reasoning_content, model_name, _) = await self.llm_model.generate_response_async(prompt=prompt)
# logger.info(f"LLM请求时间: {model_name} {time.time() - start_time} \n{prompt}")
@@ -295,7 +306,6 @@ class ExpressionSelector:
except Exception as e:
logger.error(f"LLM处理表达方式选择时出错: {e}")
return [], []
init_prompt()

View File

@@ -1,122 +0,0 @@
from typing import Optional
from src.config.config import global_config
from src.chat.frequency_control.utils import parse_stream_config_to_chat_id
def get_config_base_focus_value(chat_id: Optional[str] = None) -> float:
"""
根据当前时间和聊天流获取对应的 focus_value
"""
if not global_config.chat.focus_value_adjust:
return global_config.chat.focus_value
if chat_id:
stream_focus_value = get_stream_specific_focus_value(chat_id)
if stream_focus_value is not None:
return stream_focus_value
global_focus_value = get_global_focus_value()
if global_focus_value is not None:
return global_focus_value
return global_config.chat.focus_value
def get_stream_specific_focus_value(chat_id: str) -> Optional[float]:
"""
获取特定聊天流在当前时间的专注度
Args:
chat_stream_id: 聊天流ID哈希值
Returns:
float: 专注度值,如果没有配置则返回 None
"""
# 查找匹配的聊天流配置
for config_item in global_config.chat.focus_value_adjust:
if not config_item or len(config_item) < 2:
continue
stream_config_str = config_item[0] # 例如 "qq:1026294844:group"
# 解析配置字符串并生成对应的 chat_id
config_chat_id = parse_stream_config_to_chat_id(stream_config_str)
if config_chat_id is None:
continue
# 比较生成的 chat_id
if config_chat_id != chat_id:
continue
# 使用通用的时间专注度解析方法
return get_time_based_focus_value(config_item[1:])
return None
def get_time_based_focus_value(time_focus_list: list[str]) -> Optional[float]:
"""
根据时间配置列表获取当前时段的专注度
Args:
time_focus_list: 时间专注度配置列表,格式为 ["HH:MM,focus_value", ...]
Returns:
float: 专注度值,如果没有配置则返回 None
"""
from datetime import datetime
current_time = datetime.now().strftime("%H:%M")
current_hour, current_minute = map(int, current_time.split(":"))
current_minutes = current_hour * 60 + current_minute
# 解析时间专注度配置
time_focus_pairs = []
for time_focus_str in time_focus_list:
try:
time_str, focus_str = time_focus_str.split(",")
hour, minute = map(int, time_str.split(":"))
focus_value = float(focus_str)
minutes = hour * 60 + minute
time_focus_pairs.append((minutes, focus_value))
except (ValueError, IndexError):
continue
if not time_focus_pairs:
return None
# 按时间排序
time_focus_pairs.sort(key=lambda x: x[0])
# 查找当前时间对应的专注度
current_focus_value = None
for minutes, focus_value in time_focus_pairs:
if current_minutes >= minutes:
current_focus_value = focus_value
else:
break
# 如果当前时间在所有配置时间之前,使用最后一个时间段的专注度(跨天逻辑)
if current_focus_value is None and time_focus_pairs:
current_focus_value = time_focus_pairs[-1][1]
return current_focus_value
def get_global_focus_value() -> Optional[float]:
"""
获取全局默认专注度配置
Returns:
float: 专注度值,如果没有配置则返回 None
"""
for config_item in global_config.chat.focus_value_adjust:
if not config_item or len(config_item) < 2:
continue
# 检查是否为全局默认配置(第一个元素为空字符串)
if config_item[0] == "":
return get_time_based_focus_value(config_item[1:])
return None

View File

@@ -1,501 +1,46 @@
import time
from typing import Optional, Dict, List
from src.plugin_system.apis import message_api
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
from src.common.logger import get_logger
from src.config.config import global_config
from src.chat.frequency_control.talk_frequency_control import get_config_base_talk_frequency
from src.chat.frequency_control.focus_value_control import get_config_base_focus_value
logger = get_logger("frequency_control")
from typing import Dict
class FrequencyControl:
"""
频率控制类,可以根据最近时间段的发言数量和发言人数动态调整频率
特点:
- 发言频率调整基于最近10分钟的数据评估单位为"消息数/10分钟"
- 专注度调整基于最近10分钟的数据评估单位为"消息数/10分钟"
- 历史基准值基于最近一周的数据按小时统计每小时都有独立的基准值需要至少50条历史消息
- 统一标准两个调整都使用10分钟窗口确保逻辑一致性和响应速度
- 双向调整:根据活跃度高低,既能提高也能降低频率和专注度
- 数据充足性检查当历史数据不足50条时不更新基准值当基准值为默认值时不进行动态调整
- 基准值更新:直接使用新计算的周均值,无平滑更新
"""
"""简化的频率控制类仅管理不同chat_id的频率值"""
def __init__(self, chat_id: str):
self.chat_id = chat_id
self.chat_stream: ChatStream = get_chat_manager().get_stream(self.chat_id)
if not self.chat_stream:
raise ValueError(f"无法找到聊天流: {chat_id}")
self.log_prefix = f"[{get_chat_manager().get_stream_name(self.chat_id) or self.chat_id}]"
# 发言频率调整值
self.talk_frequency_adjust: float = 1.0
self.talk_frequency_external_adjust: float = 1.0
# 专注度调整值
self.focus_value_adjust: float = 1.0
self.focus_value_external_adjust: float = 1.0
# 动态调整相关参数
self.last_update_time = time.time()
self.update_interval = 60 # 每60秒更新一次
# 历史数据缓存
self._message_count_cache = 0
self._user_count_cache = 0
self._last_cache_time = 0
self._cache_duration = 30 # 缓存30秒
# 调整参数
self.min_adjust = 0.3 # 最小调整值
self.max_adjust = 2.0 # 最大调整值
# 动态基准值(将根据历史数据计算)
self.base_message_count = 5 # 默认基准消息数量,将被动态更新
self.base_user_count = 3 # 默认基准用户数量,将被动态更新
# 平滑因子
self.smoothing_factor = 0.3
# 历史数据相关参数
self._last_historical_update = 0
self._historical_update_interval = 600 # 每十分钟更新一次历史基准值
self._historical_days = 7 # 使用最近7天的数据计算基准值
# 按小时统计的历史基准值
self._hourly_baseline = {
'messages': {}, # {0-23: 平均消息数}
'users': {} # {0-23: 平均用户数}
}
# 初始化24小时的默认基准值
for hour in range(24):
self._hourly_baseline['messages'][hour] = 0.0
self._hourly_baseline['users'][hour] = 0.0
def _update_historical_baseline(self):
"""
更新基于历史数据的基准值
使用最近一周的数据,按小时统计平均消息数量和用户数量
"""
current_time = time.time()
# 检查是否需要更新历史基准值
if current_time - self._last_historical_update < self._historical_update_interval:
return
try:
# 计算一周前的时间戳
week_ago = current_time - (self._historical_days * 24 * 3600)
# 获取最近一周的消息数据
historical_messages = message_api.get_messages_by_time_in_chat(
chat_id=self.chat_stream.stream_id,
start_time=week_ago,
end_time=current_time,
filter_mai=True,
filter_command=True
)
if historical_messages and len(historical_messages) >= 50:
# 按小时统计消息数和用户数
hourly_stats = {hour: {'messages': [], 'users': set()} for hour in range(24)}
for msg in historical_messages:
# 获取消息的小时UTC时间
msg_time = time.localtime(msg.time)
msg_hour = msg_time.tm_hour
# 统计消息数
hourly_stats[msg_hour]['messages'].append(msg)
# 统计用户数
if msg.user_info and msg.user_info.user_id:
hourly_stats[msg_hour]['users'].add(msg.user_info.user_id)
# 计算每个小时的平均值(基于一周的数据)
for hour in range(24):
# 计算该小时的平均消息数(一周内该小时的总消息数 / 7天
total_messages = len(hourly_stats[hour]['messages'])
total_users = len(hourly_stats[hour]['users'])
# 只计算有消息的时段没有消息的时段设为0
if total_messages > 0:
avg_messages = total_messages / self._historical_days
avg_users = total_users / self._historical_days
self._hourly_baseline['messages'][hour] = avg_messages
self._hourly_baseline['users'][hour] = avg_users
else:
# 没有消息的时段设为0表示该时段不活跃
self._hourly_baseline['messages'][hour] = 0.0
self._hourly_baseline['users'][hour] = 0.0
# 更新整体基准值(用于兼容性)- 基于原始数据计算不受max(1.0)限制影响
overall_avg_messages = sum(len(hourly_stats[hour]['messages']) for hour in range(24)) / (24 * self._historical_days)
overall_avg_users = sum(len(hourly_stats[hour]['users']) for hour in range(24)) / (24 * self._historical_days)
self.base_message_count = overall_avg_messages
self.base_user_count = overall_avg_users
logger.info(
f"{self.log_prefix} 历史基准值更新完成: "
f"整体平均消息数={overall_avg_messages:.2f}, 整体平均用户数={overall_avg_users:.2f}"
)
# 记录几个关键时段的基准值
key_hours = [8, 12, 18, 22] # 早、中、晚、夜
for hour in key_hours:
# 计算该小时平均每10分钟的消息数和用户数
hourly_10min_messages = self._hourly_baseline['messages'][hour] / 6 # 1小时 = 6个10分钟
hourly_10min_users = self._hourly_baseline['users'][hour] / 6
logger.info(
f"{self.log_prefix} {hour}时基准值: "
f"消息数={self._hourly_baseline['messages'][hour]:.2f}/小时 "
f"({hourly_10min_messages:.2f}/10分钟), "
f"用户数={self._hourly_baseline['users'][hour]:.2f}/小时 "
f"({hourly_10min_users:.2f}/10分钟)"
)
elif historical_messages and len(historical_messages) < 50:
# 历史数据不足50条不更新基准值
logger.info(f"{self.log_prefix} 历史数据不足50条({len(historical_messages)}条),不更新基准值")
else:
# 如果没有历史数据,不更新基准值
logger.info(f"{self.log_prefix} 无历史数据,不更新基准值")
except Exception as e:
logger.error(f"{self.log_prefix} 更新历史基准值时出错: {e}")
# 出错时保持原有基准值不变
self._last_historical_update = current_time
def _get_current_hour_baseline(self) -> tuple[float, float]:
"""
获取当前小时的基准值
Returns:
tuple: (基准消息数, 基准用户数)
"""
current_hour = time.localtime().tm_hour
return (
self._hourly_baseline['messages'][current_hour],
self._hourly_baseline['users'][current_hour]
)
def get_dynamic_talk_frequency_adjust(self) -> float:
"""
获取纯动态调整值(不包含配置文件基础值)
Returns:
float: 动态调整值
"""
self._update_talk_frequency_adjust()
def get_talk_frequency_adjust(self) -> float:
"""获取发言频率调整值"""
return self.talk_frequency_adjust
def get_dynamic_focus_value_adjust(self) -> float:
"""
获取纯动态调整值(不包含配置文件基础值)
Returns:
float: 动态调整值
"""
self._update_focus_value_adjust()
return self.focus_value_adjust
def _update_talk_frequency_adjust(self):
"""
更新发言频率调整值
适合人少话多的时候:人少但消息多,提高回复频率
"""
current_time = time.time()
# 检查是否需要更新
if current_time - self.last_update_time < self.update_interval:
return
# 先更新历史基准值
self._update_historical_baseline()
try:
# 获取最近10分钟的数据发言频率更敏感
recent_messages = message_api.get_messages_by_time_in_chat(
chat_id=self.chat_stream.stream_id,
start_time=current_time - 600, # 10分钟前
end_time=current_time,
filter_mai=True,
filter_command=True
)
# 计算消息数量和用户数量
message_count = len(recent_messages)
user_ids = set()
for msg in recent_messages:
if msg.user_info and msg.user_info.user_id:
user_ids.add(msg.user_info.user_id)
user_count = len(user_ids)
# 获取当前小时的基准值
current_hour_base_messages, current_hour_base_users = self._get_current_hour_baseline()
# 计算当前小时平均每10分钟的基准值
current_hour_10min_messages = current_hour_base_messages / 6 # 1小时 = 6个10分钟
current_hour_10min_users = current_hour_base_users / 6
# 发言频率调整逻辑:根据活跃度双向调整
# 检查是否有足够的数据进行分析
if user_count > 0 and message_count >= 2: # 至少需要2条消息才能进行有意义的分析
# 检查历史基准值是否有效(该时段有活跃度)
if current_hour_base_messages > 0.0 and current_hour_base_users > 0.0:
# 计算人均消息数10分钟窗口
messages_per_user = message_count / user_count
# 使用当前小时每10分钟的基准人均消息数
base_messages_per_user = current_hour_10min_messages / current_hour_10min_users if current_hour_10min_users > 0 else 1.0
# 双向调整逻辑
if messages_per_user > base_messages_per_user * 1.2:
# 活跃度很高:提高回复频率
target_talk_adjust = min(self.max_adjust, messages_per_user / base_messages_per_user)
elif messages_per_user < base_messages_per_user * 0.8:
# 活跃度很低:降低回复频率
target_talk_adjust = max(self.min_adjust, messages_per_user / base_messages_per_user)
else:
# 活跃度正常:保持正常
target_talk_adjust = 1.0
else:
# 历史基准值不足,不调整
target_talk_adjust = 1.0
else:
# 数据不足:不调整
target_talk_adjust = 1.0
# 限制调整范围
target_talk_adjust = max(self.min_adjust, min(self.max_adjust, target_talk_adjust))
# 记录调整前的值
old_adjust = self.talk_frequency_adjust
# 平滑调整
self.talk_frequency_adjust = (
self.talk_frequency_adjust * (1 - self.smoothing_factor) +
target_talk_adjust * self.smoothing_factor
)
# 判断调整方向
if target_talk_adjust > 1.0:
adjust_direction = "提高"
elif target_talk_adjust < 1.0:
adjust_direction = "降低"
else:
if current_hour_base_messages <= 0.0 or current_hour_base_users <= 0.0:
adjust_direction = "不调整(该时段无活跃度)"
else:
adjust_direction = "保持"
# 计算实际变化方向
actual_change = ""
if self.talk_frequency_adjust > old_adjust:
actual_change = f"(实际提高: {old_adjust:.2f}{self.talk_frequency_adjust:.2f})"
elif self.talk_frequency_adjust < old_adjust:
actual_change = f"(实际降低: {old_adjust:.2f}{self.talk_frequency_adjust:.2f})"
else:
actual_change = f"(无变化: {self.talk_frequency_adjust:.2f})"
logger.info(
f"{self.log_prefix} 发言频率调整: "
f"当前: {message_count}消息/{user_count}用户, 人均: {message_count/user_count if user_count > 0 else 0:.2f}消息|"
f"基准: {current_hour_10min_messages:.2f}消息/{current_hour_10min_users:.2f}用户,人均:{current_hour_10min_messages/current_hour_10min_users if current_hour_10min_users > 0 else 0:.2f}消息|"
f"目标调整: {adjust_direction}{target_talk_adjust:.2f}, 实际结果: {self.talk_frequency_adjust:.2f} {actual_change}"
)
except Exception as e:
logger.error(f"{self.log_prefix} 更新发言频率调整值时出错: {e}")
def _update_focus_value_adjust(self):
"""
更新专注度调整值
适合人多话多的时候人多且消息多提高专注度LLM消耗更多但回复更精准
"""
current_time = time.time()
# 检查是否需要更新
if current_time - self.last_update_time < self.update_interval:
return
try:
# 获取最近10分钟的数据与发言频率保持一致
recent_messages = message_api.get_messages_by_time_in_chat(
chat_id=self.chat_stream.stream_id,
start_time=current_time - 600, # 10分钟前
end_time=current_time,
filter_mai=True,
filter_command=True
)
# 计算消息数量和用户数量
message_count = len(recent_messages)
user_ids = set()
for msg in recent_messages:
if msg.user_info and msg.user_info.user_id:
user_ids.add(msg.user_info.user_id)
user_count = len(user_ids)
# 获取当前小时的基准值
current_hour_base_messages, current_hour_base_users = self._get_current_hour_baseline()
# 计算当前小时平均每10分钟的基准值
current_hour_10min_messages = current_hour_base_messages / 6 # 1小时 = 6个10分钟
current_hour_10min_users = current_hour_base_users / 6
# 专注度调整逻辑:根据活跃度双向调整
# 检查是否有足够的数据进行分析
if user_count > 0 and current_hour_10min_users > 0 and message_count >= 2:
# 检查历史基准值是否有效(该时段有活跃度)
if current_hour_base_messages > 0.0 and current_hour_base_users > 0.0:
# 计算用户活跃度比率基于10分钟数据
user_ratio = user_count / current_hour_10min_users
# 计算消息活跃度比率基于10分钟数据
message_ratio = message_count / current_hour_10min_messages if current_hour_10min_messages > 0 else 1.0
# 双向调整逻辑
if user_ratio > 1.3 and message_ratio > 1.3:
# 活跃度很高提高专注度消耗更多LLM资源但回复更精准
target_focus_adjust = min(self.max_adjust, (user_ratio + message_ratio) / 2)
elif user_ratio > 1.1 and message_ratio > 1.1:
# 活跃度较高:适度提高专注度
target_focus_adjust = min(self.max_adjust, 1.0 + (user_ratio + message_ratio - 2.0) * 0.2)
elif user_ratio < 0.7 or message_ratio < 0.7:
# 活跃度很低降低专注度节省LLM资源
target_focus_adjust = max(self.min_adjust, min(user_ratio, message_ratio))
else:
# 正常情况:保持默认专注度
target_focus_adjust = 1.0
else:
# 历史基准值不足,不调整
target_focus_adjust = 1.0
else:
# 数据不足:不调整
target_focus_adjust = 1.0
# 限制调整范围
target_focus_adjust = max(self.min_adjust, min(self.max_adjust, target_focus_adjust))
# 记录调整前的值
old_focus_adjust = self.focus_value_adjust
# 平滑调整
self.focus_value_adjust = (
self.focus_value_adjust * (1 - self.smoothing_factor) +
target_focus_adjust * self.smoothing_factor
)
# 计算当前小时平均每10分钟的基准值
current_hour_10min_messages = current_hour_base_messages / 6 # 1小时 = 6个10分钟
current_hour_10min_users = current_hour_base_users / 6
# 判断调整方向
if target_focus_adjust > 1.0:
adjust_direction = "提高"
elif target_focus_adjust < 1.0:
adjust_direction = "降低"
else:
if current_hour_base_messages <= 0.0 or current_hour_base_users <= 0.0:
adjust_direction = "不调整(该时段无活跃度)"
else:
adjust_direction = "保持"
# 计算实际变化方向
actual_change = ""
if self.focus_value_adjust > old_focus_adjust:
actual_change = f"(实际提高: {old_focus_adjust:.2f}{self.focus_value_adjust:.2f})"
elif self.focus_value_adjust < old_focus_adjust:
actual_change = f"(实际降低: {old_focus_adjust:.2f}{self.focus_value_adjust:.2f})"
else:
actual_change = f"(无变化: {self.focus_value_adjust:.2f})"
logger.info(
f"{self.log_prefix} 专注度调整(10分钟): "
f"当前: {message_count}消息/{user_count}用户,人均:{message_count/user_count if user_count > 0 else 0:.2f}消息|"
f"基准: {current_hour_10min_messages:.2f}消息/{current_hour_10min_users:.2f}用户,人均:{current_hour_10min_messages/current_hour_10min_users if current_hour_10min_users > 0 else 0:.2f}消息|"
f"比率: 用户{user_count/current_hour_10min_users if current_hour_10min_users > 0 else 0:.2f}x, 消息{message_count/current_hour_10min_messages if current_hour_10min_messages > 0 else 0:.2f}x, "
f"目标调整: {adjust_direction}{target_focus_adjust:.2f}, 实际结果: {self.focus_value_adjust:.2f} {actual_change}"
)
except Exception as e:
logger.error(f"{self.log_prefix} 更新专注度调整值时出错: {e}")
def get_final_talk_frequency(self) -> float:
return get_config_base_talk_frequency(self.chat_stream.stream_id) * self.get_dynamic_talk_frequency_adjust() * self.talk_frequency_external_adjust
def get_final_focus_value(self) -> float:
return get_config_base_focus_value(self.chat_stream.stream_id) * self.get_dynamic_focus_value_adjust() * self.focus_value_external_adjust
def set_adjustment_parameters(
self,
min_adjust: Optional[float] = None,
max_adjust: Optional[float] = None,
base_message_count: Optional[int] = None,
base_user_count: Optional[int] = None,
smoothing_factor: Optional[float] = None,
update_interval: Optional[int] = None,
historical_update_interval: Optional[int] = None,
historical_days: Optional[int] = None
):
"""
设置调整参数
Args:
min_adjust: 最小调整值
max_adjust: 最大调整值
base_message_count: 基准消息数量
base_user_count: 基准用户数量
smoothing_factor: 平滑因子
update_interval: 更新间隔(秒)
"""
if min_adjust is not None:
self.min_adjust = max(0.1, min_adjust)
if max_adjust is not None:
self.max_adjust = max(1.0, max_adjust)
if base_message_count is not None:
self.base_message_count = max(1, base_message_count)
if base_user_count is not None:
self.base_user_count = max(1, base_user_count)
if smoothing_factor is not None:
self.smoothing_factor = max(0.0, min(1.0, smoothing_factor))
if update_interval is not None:
self.update_interval = max(10, update_interval)
if historical_update_interval is not None:
self._historical_update_interval = max(300, historical_update_interval) # 最少5分钟
if historical_days is not None:
self._historical_days = max(1, min(30, historical_days)) # 1-30天之间
def set_talk_frequency_adjust(self, value: float) -> None:
"""设置发言频率调整值"""
self.talk_frequency_adjust = max(0.1, min(5.0, value))
class FrequencyControlManager:
"""
频率控制管理器,管理多个聊天流的频率控制实例
"""
"""频率控制管理器,管理多个聊天流的频率控制实例"""
def __init__(self):
self.frequency_control_dict: Dict[str, FrequencyControl] = {}
def get_or_create_frequency_control(self, chat_id: str) -> FrequencyControl:
"""
获取或创建指定聊天流的频率控制实例
Args:
chat_id: 聊天流ID
Returns:
FrequencyControl: 频率控制实例
"""
"""获取或创建指定聊天流的频率控制实例"""
if chat_id not in self.frequency_control_dict:
self.frequency_control_dict[chat_id] = FrequencyControl(chat_id)
return self.frequency_control_dict[chat_id]
def remove_frequency_control(self, chat_id: str) -> bool:
"""移除指定聊天流的频率控制实例"""
if chat_id in self.frequency_control_dict:
del self.frequency_control_dict[chat_id]
return True
return False
def get_all_chat_ids(self) -> list[str]:
"""获取所有有频率控制的聊天ID"""
return list(self.frequency_control_dict.keys())
# 创建全局实例
frequency_control_manager = FrequencyControlManager()

View File

@@ -1,128 +0,0 @@
from typing import Optional
from src.config.config import global_config
from src.chat.frequency_control.utils import parse_stream_config_to_chat_id
def get_config_base_talk_frequency(chat_id: Optional[str] = None) -> float:
"""
根据当前时间和聊天流获取对应的 talk_frequency
Args:
chat_stream_id: 聊天流ID格式为 "platform:chat_id:type"
Returns:
float: 对应的频率值
"""
if not global_config.chat.talk_frequency_adjust:
return global_config.chat.talk_frequency
# 优先检查聊天流特定的配置
if chat_id:
stream_frequency = get_stream_specific_frequency(chat_id)
if stream_frequency is not None:
return stream_frequency
# 检查全局时段配置(第一个元素为空字符串的配置)
global_frequency = get_global_frequency()
return global_config.chat.talk_frequency if global_frequency is None else global_frequency
def get_time_based_frequency(time_freq_list: list[str]) -> Optional[float]:
"""
根据时间配置列表获取当前时段的频率
Args:
time_freq_list: 时间频率配置列表,格式为 ["HH:MM,frequency", ...]
Returns:
float: 频率值,如果没有配置则返回 None
"""
from datetime import datetime
current_time = datetime.now().strftime("%H:%M")
current_hour, current_minute = map(int, current_time.split(":"))
current_minutes = current_hour * 60 + current_minute
# 解析时间频率配置
time_freq_pairs = []
for time_freq_str in time_freq_list:
try:
time_str, freq_str = time_freq_str.split(",")
hour, minute = map(int, time_str.split(":"))
frequency = float(freq_str)
minutes = hour * 60 + minute
time_freq_pairs.append((minutes, frequency))
except (ValueError, IndexError):
continue
if not time_freq_pairs:
return None
# 按时间排序
time_freq_pairs.sort(key=lambda x: x[0])
# 查找当前时间对应的频率
current_frequency = None
for minutes, frequency in time_freq_pairs:
if current_minutes >= minutes:
current_frequency = frequency
else:
break
# 如果当前时间在所有配置时间之前,使用最后一个时间段的频率(跨天逻辑)
if current_frequency is None and time_freq_pairs:
current_frequency = time_freq_pairs[-1][1]
return current_frequency
def get_stream_specific_frequency(chat_stream_id: str):
"""
获取特定聊天流在当前时间的频率
Args:
chat_stream_id: 聊天流ID哈希值
Returns:
float: 频率值,如果没有配置则返回 None
"""
# 查找匹配的聊天流配置
for config_item in global_config.chat.talk_frequency_adjust:
if not config_item or len(config_item) < 2:
continue
stream_config_str = config_item[0] # 例如 "qq:1026294844:group"
# 解析配置字符串并生成对应的 chat_id
config_chat_id = parse_stream_config_to_chat_id(stream_config_str)
if config_chat_id is None:
continue
# 比较生成的 chat_id
if config_chat_id != chat_stream_id:
continue
# 使用通用的时间频率解析方法
return get_time_based_frequency(config_item[1:])
return None
def get_global_frequency() -> Optional[float]:
"""
获取全局默认频率配置
Returns:
float: 频率值,如果没有配置则返回 None
"""
for config_item in global_config.chat.talk_frequency_adjust:
if not config_item or len(config_item) < 2:
continue
# 检查是否为全局默认配置(第一个元素为空字符串)
if config_item[0] == "":
return get_time_based_frequency(config_item[1:])
return None

View File

@@ -1,37 +0,0 @@
from typing import Optional
import hashlib
def parse_stream_config_to_chat_id(stream_config_str: str) -> Optional[str]:
"""
解析流配置字符串并生成对应的 chat_id
Args:
stream_config_str: 格式为 "platform:id:type" 的字符串
Returns:
str: 生成的 chat_id如果解析失败则返回 None
"""
try:
parts = stream_config_str.split(":")
if len(parts) != 3:
return None
platform = parts[0]
id_str = parts[1]
stream_type = parts[2]
# 判断是否为群聊
is_group = stream_type == "group"
# 使用与 ChatStream.get_stream_id 相同的逻辑生成 chat_id
if is_group:
components = [platform, str(id_str)]
else:
components = [platform, str(id_str), "private"]
key = "_".join(components)
return hashlib.md5(key.encode()).hexdigest()
except (ValueError, IndexError):
return None

View File

@@ -1,15 +1,15 @@
import asyncio
from multiprocessing import context
import time
import traceback
import math
import random
from typing import List, Optional, Dict, Any, Tuple, TYPE_CHECKING
from rich.traceback import install
from collections import deque
from src.config.config import global_config
from src.common.logger import get_logger
from src.common.data_models.info_data_model import ActionPlannerInfo
from src.common.data_models.message_data_model import ReplyContentType
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
from src.chat.utils.prompt_builder import global_prompt_manager
from src.chat.utils.timer_calculator import Timer
@@ -18,14 +18,15 @@ from src.chat.planner_actions.action_modifier import ActionModifier
from src.chat.planner_actions.action_manager import ActionManager
from src.chat.heart_flow.hfc_utils import CycleDetail
from src.chat.heart_flow.hfc_utils import send_typing, stop_typing
from src.chat.frequency_control.frequency_control import frequency_control_manager
from src.chat.express.expression_learner import expression_learner_manager
from src.chat.frequency_control.frequency_control import frequency_control_manager
from src.memory_system.question_maker import QuestionMaker
from src.memory_system.questions import global_conflict_tracker
from src.person_info.person_info import Person
from src.plugin_system.base.component_types import ChatMode, EventType, ActionInfo
from src.plugin_system.base.component_types import EventType, ActionInfo
from src.plugin_system.core import events_manager
from src.plugin_system.apis import generator_api, send_api, message_api, database_api
from src.mais4u.mai_think import mai_thinking_manager
from src.mais4u.s4u_config import s4u_config
from src.memory_system.Memory_chest import global_memory_chest
from src.chat.utils.chat_message_builder import (
build_readable_messages_with_id,
get_raw_msg_before_timestamp_with_chat,
@@ -33,6 +34,7 @@ from src.chat.utils.chat_message_builder import (
if TYPE_CHECKING:
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.data_models.message_data_model import ReplySetModel
ERROR_LOOP_INFO = {
@@ -84,8 +86,6 @@ class HeartFChatting:
self.expression_learner = expression_learner_manager.get_expression_learner(self.stream_id)
self.frequency_control = frequency_control_manager.get_or_create_frequency_control(self.stream_id)
self.action_manager = ActionManager()
self.action_planner = ActionPlanner(chat_id=self.stream_id, action_manager=self.action_manager)
self.action_modifier = ActionModifier(action_manager=self.action_manager, chat_id=self.stream_id)
@@ -99,8 +99,15 @@ class HeartFChatting:
self._cycle_counter = 0
self._current_cycle_detail: CycleDetail = None # type: ignore
self.last_read_time = time.time() - 10
self.last_read_time = time.time() - 2
self.no_reply_until_call = False
self.is_mute = False
self.last_active_time = time.time() # 记录上一次非noreply时间
self.questioned = False
async def start(self):
"""检查是否需要启动主循环,如果未激活则启动。"""
@@ -153,63 +160,117 @@ class HeartFChatting:
# 记录循环信息和计时器结果
timer_strings = []
for name, elapsed in cycle_timers.items():
formatted_time = f"{elapsed * 1000:.2f}毫秒" if elapsed < 1 else f"{elapsed:.2f}"
if elapsed < 0.1:
# 不显示小于0.1秒的计时器
continue
formatted_time = f"{elapsed:.2f}"
timer_strings.append(f"{name}: {formatted_time}")
# 获取动作类型,兼容新旧格式
action_type = "未知动作"
if hasattr(self, "_current_cycle_detail") and self._current_cycle_detail:
loop_plan_info = self._current_cycle_detail.loop_plan_info
if isinstance(loop_plan_info, dict):
action_result = loop_plan_info.get("action_result", {})
if isinstance(action_result, dict):
# 旧格式action_result是字典
action_type = action_result.get("action_type", "未知动作")
elif isinstance(action_result, list) and action_result:
# 新格式action_result是actions列表
# TODO: 把这里写明白
action_type = action_result[0].action_type or "未知动作"
elif isinstance(loop_plan_info, list) and loop_plan_info:
# 直接是actions列表的情况
action_type = loop_plan_info[0].get("action_type", "未知动作")
logger.info(
f"{self.log_prefix}{self._current_cycle_detail.cycle_id}次思考,"
f"耗时: {self._current_cycle_detail.end_time - self._current_cycle_detail.start_time:.1f}, " # type: ignore
f"选择动作: {action_type}" + (f"\n详情: {'; '.join(timer_strings)}" if timer_strings else "")
f"耗时: {self._current_cycle_detail.end_time - self._current_cycle_detail.start_time:.1f};" # type: ignore
+ (f"详情: {'; '.join(timer_strings)}" if timer_strings else "")
)
async def caculate_interest_value(self, recent_messages_list: List["DatabaseMessages"]) -> float:
total_interest = 0.0
for msg in recent_messages_list:
interest_value = msg.interest_value
if interest_value is not None and msg.processed_plain_text:
total_interest += float(interest_value)
return total_interest / len(recent_messages_list)
async def _loopbody(self):
async def _loopbody(self):
recent_messages_list = message_api.get_messages_by_time_in_chat(
chat_id=self.stream_id,
start_time=self.last_read_time,
end_time=time.time(),
limit=10,
limit=20,
limit_mode="latest",
filter_mai=True,
filter_command=True,
)
if recent_messages_list:
self.last_read_time = time.time()
await self._observe(interest_value=await self.caculate_interest_value(recent_messages_list),recent_messages_list=recent_messages_list)
question_probability = 0
if time.time() - self.last_active_time > 3600:
question_probability = 0.01
elif time.time() - self.last_active_time > 1200:
question_probability = 0.005
elif time.time() - self.last_active_time > 600:
question_probability = 0.001
else:
question_probability = 0.0003
question_probability = question_probability * global_config.chat.get_auto_chat_value(self.stream_id)
# print(f"{self.log_prefix} questioned: {self.questioned},len: {len(global_conflict_tracker.get_questions_by_chat_id(self.stream_id))}")
if question_probability > 0 and not self.questioned and len(global_conflict_tracker.get_questions_by_chat_id(self.stream_id)) == 0: #长久没有回复,可以试试主动发言,提问概率随着时间增加
# logger.info(f"{self.log_prefix} 长久没有回复,可以试试主动发言,概率: {question_probability}")
if random.random() < question_probability: # 30%概率主动发言
try:
self.questioned = True
self.last_active_time = time.time()
# print(f"{self.log_prefix} 长久没有回复,可以试试主动发言,开始生成问题")
logger.info(f"{self.log_prefix} 长久没有回复,可以试试主动发言,开始生成问题")
cycle_timers, thinking_id = self.start_cycle()
question_maker = QuestionMaker(self.stream_id)
question, context,conflict_context = await question_maker.make_question()
if question:
logger.info(f"{self.log_prefix} 问题: {question}")
await global_conflict_tracker.track_conflict(question, conflict_context, True, self.stream_id)
await self._lift_question_reply(question,context,cycle_timers,thinking_id)
else:
logger.info(f"{self.log_prefix} 无问题")
# self.end_cycle(cycle_timers, thinking_id)
except Exception as e:
logger.error(f"{self.log_prefix} 主动提问失败: {e}")
print(traceback.format_exc())
if len(recent_messages_list) >= 1:
# for message in recent_messages_list:
# print(message.processed_plain_text)
# !处理no_reply_until_call逻辑
if self.no_reply_until_call:
for message in recent_messages_list:
if (
message.is_mentioned
or message.is_at
or len(recent_messages_list) >= 8
or time.time() - self.last_read_time > 600
):
self.no_reply_until_call = False
self.last_read_time = time.time()
break
# 没有提到,继续保持沉默
if self.no_reply_until_call:
# logger.info(f"{self.log_prefix} 没有提到,继续保持沉默")
await asyncio.sleep(1)
return True
self.last_read_time = time.time()
# !此处使at或者提及必定回复
mentioned_message = None
for message in recent_messages_list:
if (message.is_mentioned or message.is_at) and global_config.chat.mentioned_bot_reply:
mentioned_message = message
logger.info(f"{self.log_prefix} 当前talk_value: {global_config.chat.get_talk_value(self.stream_id)}")
# *控制频率用
if mentioned_message:
await self._observe(recent_messages_list=recent_messages_list, force_reply_message=mentioned_message)
elif (
random.random()
< global_config.chat.get_talk_value(self.stream_id)
* frequency_control_manager.get_or_create_frequency_control(self.stream_id).get_talk_frequency_adjust()
):
await self._observe(recent_messages_list=recent_messages_list)
else:
# 没有提到继续保持沉默等待5秒防止频繁触发
await asyncio.sleep(10)
return True
else:
# Normal模式消息数量不足等待
await asyncio.sleep(0.2)
return True
return True
async def _send_and_store_reply(
self,
response_set,
response_set: "ReplySetModel",
action_message: "DatabaseMessages",
cycle_timers: Dict[str, float],
thinking_id,
@@ -257,191 +318,166 @@ class HeartFChatting:
return loop_info, reply_text, cycle_timers
async def _observe(self, interest_value: float = 0.0,recent_messages_list: List["DatabaseMessages"] = []) -> bool:
async def _observe(
self, # interest_value: float = 0.0,
recent_messages_list: Optional[List["DatabaseMessages"]] = None,
force_reply_message: Optional["DatabaseMessages"] = None,
) -> bool: # sourcery skip: merge-else-if-into-elif, remove-redundant-if
if recent_messages_list is None:
recent_messages_list = []
reply_text = "" # 初始化reply_text变量避免UnboundLocalError
# 使用sigmoid函数将interest_value转换为概率
# 当interest_value为0时概率接近0使用Focus模式
# 当interest_value很高时概率接近1使用Normal模式
def calculate_normal_mode_probability(interest_val: float) -> float:
# 使用sigmoid函数调整参数使概率分布更合理
# 当interest_value = 0时概率约为0.1
# 当interest_value = 1时概率约为0.5
# 当interest_value = 2时概率约为0.8
# 当interest_value = 3时概率约为0.95
k = 2.0 # 控制曲线陡峭程度
x0 = 1.0 # 控制曲线中心点
return 1.0 / (1.0 + math.exp(-k * (interest_val - x0)))
start_time = time.time()
normal_mode_probability = (
calculate_normal_mode_probability(interest_value)
* 2
* self.frequency_control.get_final_talk_frequency()
)
#对呼唤名字进行增幅
for msg in recent_messages_list:
if msg.reply_probability_boost is not None and msg.reply_probability_boost > 0.0:
normal_mode_probability += msg.reply_probability_boost
if global_config.chat.mentioned_bot_reply and msg.is_mentioned:
normal_mode_probability += global_config.chat.mentioned_bot_reply
if global_config.chat.at_bot_inevitable_reply and msg.is_at:
normal_mode_probability += global_config.chat.at_bot_inevitable_reply
# 根据概率决定使用直接回复
interest_triggerd = False
focus_triggerd = False
if random.random() < normal_mode_probability:
interest_triggerd = True
logger.info(
f"{self.log_prefix} 有新消息,在{normal_mode_probability * 100:.0f}%概率下选择回复"
)
if s4u_config.enable_s4u:
await send_typing()
async with global_prompt_manager.async_message_scope(self.chat_stream.context.get_template_name()):
await self.expression_learner.trigger_learning_for_chat()
await global_memory_chest.build_running_content(chat_id=self.stream_id)
cycle_timers, thinking_id = self.start_cycle()
logger.info(f"{self.log_prefix} 开始第{self._cycle_counter}次思考")
# 第一步:动作检查
available_actions: Dict[str, ActionInfo] = {}
#如果兴趣度不足以激活
if not interest_triggerd:
#看看专注值够不够
if random.random() < self.frequency_control.get_final_focus_value():
#专注值足够,仍然进入正式思考
focus_triggerd = True #都没触发,路边
try:
await self.action_modifier.modify_actions()
available_actions = self.action_manager.get_using_actions()
except Exception as e:
logger.error(f"{self.log_prefix} 动作修改失败: {e}")
# 任意一种触发都行
if interest_triggerd or focus_triggerd:
# 进入正式思考模式
cycle_timers, thinking_id = self.start_cycle()
logger.info(f"{self.log_prefix} 开始第{self._cycle_counter}次思考")
# 第一步:动作检查
try:
await self.action_modifier.modify_actions()
available_actions = self.action_manager.get_using_actions()
except Exception as e:
logger.error(f"{self.log_prefix} 动作修改失败: {e}")
# 执行planner
is_group_chat, chat_target_info, _ = self.action_planner.get_necessary_info()
# 执行planner
is_group_chat, chat_target_info, _ = self.action_planner.get_necessary_info()
message_list_before_now = get_raw_msg_before_timestamp_with_chat(
chat_id=self.stream_id,
timestamp=time.time(),
limit=int(global_config.chat.max_context_size * 0.6),
)
chat_content_block, message_id_list = build_readable_messages_with_id(
messages=message_list_before_now,
timestamp_mode="normal_no_YMD",
read_mark=self.action_planner.last_obs_time_mark,
truncate=True,
show_actions=True,
)
message_list_before_now = get_raw_msg_before_timestamp_with_chat(
chat_id=self.stream_id,
timestamp=time.time(),
limit=int(global_config.chat.max_context_size * 0.6),
)
chat_content_block, message_id_list = build_readable_messages_with_id(
messages=message_list_before_now,
timestamp_mode="normal_no_YMD",
read_mark=self.action_planner.last_obs_time_mark,
truncate=True,
show_actions=True,
prompt_info = await self.action_planner.build_planner_prompt(
is_group_chat=is_group_chat,
chat_target_info=chat_target_info,
current_available_actions=available_actions,
chat_content_block=chat_content_block,
message_id_list=message_id_list,
interest=global_config.personality.interest,
)
continue_flag, modified_message = await events_manager.handle_mai_events(
EventType.ON_PLAN, None, prompt_info[0], None, self.chat_stream.stream_id
)
if not continue_flag:
return False
if modified_message and modified_message._modify_flags.modify_llm_prompt:
prompt_info = (modified_message.llm_prompt, prompt_info[1])
with Timer("规划器", cycle_timers):
action_to_use_info = await self.action_planner.plan(
loop_start_time=self.last_read_time,
available_actions=available_actions,
)
prompt_info = await self.action_planner.build_planner_prompt(
is_group_chat=is_group_chat,
chat_target_info=chat_target_info,
# current_available_actions=planner_info[2],
chat_content_block=chat_content_block,
# actions_before_now_block=actions_before_now_block,
message_id_list=message_id_list,
)
if not await events_manager.handle_mai_events(
EventType.ON_PLAN, None, prompt_info[0], None, self.chat_stream.stream_id
):
return False
with Timer("规划器", cycle_timers):
# 根据不同触发进入不同plan
if focus_triggerd:
mode = ChatMode.FOCUS
else:
mode = ChatMode.NORMAL
action_to_use_info, _ = await self.action_planner.plan(
mode=mode,
loop_start_time=self.last_read_time,
has_reply = False
for action in action_to_use_info:
if action.action_type == "reply":
has_reply = True
break
if not has_reply and force_reply_message:
action_to_use_info.append(
ActionPlannerInfo(
action_type="reply",
reasoning="有人提到了你,进行回复",
action_data={},
action_message=force_reply_message,
available_actions=available_actions,
)
)
# 3. 并行执行所有动作
action_tasks = [
asyncio.create_task(
self._execute_action(action, action_to_use_info, thinking_id, available_actions, cycle_timers)
)
for action in action_to_use_info
]
logger.info(
f"{self.log_prefix} 决定执行{len(action_to_use_info)}个动作: {' '.join([a.action_type for a in action_to_use_info])}"
)
# 并行执行所有任务
results = await asyncio.gather(*action_tasks, return_exceptions=True)
# 3. 并行执行所有动作
action_tasks = [
asyncio.create_task(
self._execute_action(action, action_to_use_info, thinking_id, available_actions, cycle_timers)
)
for action in action_to_use_info
]
# 处理执行结果
reply_loop_info = None
reply_text_from_reply = ""
action_success = False
action_reply_text = ""
action_command = ""
# 并行执行所有任务
results = await asyncio.gather(*action_tasks, return_exceptions=True)
for i, result in enumerate(results):
if isinstance(result, BaseException):
logger.error(f"{self.log_prefix} 动作执行异常: {result}")
continue
# 处理执行结果
reply_loop_info = None
reply_text_from_reply = ""
action_success = False
action_reply_text = ""
_cur_action = action_to_use_info[i]
if result["action_type"] != "reply":
action_success = result["success"]
action_reply_text = result["reply_text"]
action_command = result.get("command", "")
elif result["action_type"] == "reply":
if result["success"]:
reply_loop_info = result["loop_info"]
reply_text_from_reply = result["reply_text"]
else:
logger.warning(f"{self.log_prefix} 回复动作执行失败")
excute_result_str = ""
for result in results:
excute_result_str += f"{result['action_type']} 执行结果:{result['result']}\n"
# 构建最终的循环信息
if reply_loop_info:
# 如果有回复信息使用回复的loop_info作为基础
loop_info = reply_loop_info
# 更新动作执行信息
loop_info["loop_action_info"].update(
{
"action_taken": action_success,
"command": action_command,
"taken_time": time.time(),
}
)
reply_text = reply_text_from_reply
else:
# 没有回复信息构建纯动作的loop_info
loop_info = {
"loop_plan_info": {
"action_result": action_to_use_info,
},
"loop_action_info": {
"action_taken": action_success,
"reply_text": action_reply_text,
"command": action_command,
"taken_time": time.time(),
},
if isinstance(result, BaseException):
logger.error(f"{self.log_prefix} 动作执行异常: {result}")
continue
if result["action_type"] != "reply":
action_success = result["success"]
action_reply_text = result["result"]
elif result["action_type"] == "reply":
if result["success"]:
reply_loop_info = result["loop_info"]
reply_text_from_reply = result["result"]
else:
logger.warning(f"{self.log_prefix} 回复动作执行失败")
self.action_planner.add_plan_excute_log(result=excute_result_str)
# 构建最终的循环信息
if reply_loop_info:
# 如果有回复信息使用回复的loop_info作为基础
loop_info = reply_loop_info
# 更新动作执行信息
loop_info["loop_action_info"].update(
{
"action_taken": action_success,
"taken_time": time.time(),
}
reply_text = action_reply_text
self.end_cycle(loop_info, cycle_timers)
self.print_cycle_info(cycle_timers)
)
reply_text = reply_text_from_reply
else:
# 没有回复信息构建纯动作的loop_info
loop_info = {
"loop_plan_info": {
"action_result": action_to_use_info,
},
"loop_action_info": {
"action_taken": action_success,
"reply_text": action_reply_text,
"taken_time": time.time(),
},
}
reply_text = action_reply_text
"""S4U内容暂时保留"""
if s4u_config.enable_s4u:
await stop_typing()
await mai_thinking_manager.get_mai_think(self.stream_id).do_think_after_response(reply_text)
"""S4U内容暂时保留"""
self.end_cycle(loop_info, cycle_timers)
self.print_cycle_info(cycle_timers)
end_time = time.time()
if end_time - start_time < global_config.chat.planner_smooth:
wait_time = global_config.chat.planner_smooth - (end_time - start_time)
await asyncio.sleep(wait_time)
else:
await asyncio.sleep(0.1)
return True
async def _main_chat_loop(self):
@@ -466,7 +502,7 @@ class HeartFChatting:
async def _handle_action(
self,
action: str,
reasoning: str,
action_reasoning: str,
action_data: dict,
cycle_timers: Dict[str, float],
thinking_id: str,
@@ -477,11 +513,11 @@ class HeartFChatting:
参数:
action: 动作类型
reasoning: 决策理由
action_reasoning: 决策理由
action_data: 动作数据,包含不同动作需要的参数
cycle_timers: 计时器字典
thinking_id: 思考ID
action_message: 消息数据
返回:
tuple[bool, str, str]: (是否执行了动作, 思考消息ID, 命令)
"""
@@ -491,37 +527,105 @@ class HeartFChatting:
action_handler = self.action_manager.create_action(
action_name=action,
action_data=action_data,
reasoning=reasoning,
cycle_timers=cycle_timers,
thinking_id=thinking_id,
chat_stream=self.chat_stream,
log_prefix=self.log_prefix,
action_reasoning=action_reasoning,
action_message=action_message,
)
except Exception as e:
logger.error(f"{self.log_prefix} 创建动作处理器时出错: {e}")
traceback.print_exc()
return False, "", ""
return False, ""
if not action_handler:
logger.warning(f"{self.log_prefix} 未能创建动作处理器: {action}")
return False, "", ""
# 处理动作并获取结果
result = await action_handler.handle_action()
# 处理动作并获取结果(固定记录一次动作信息)
result = await action_handler.execute()
success, action_text = result
command = ""
return success, action_text, command
return success, action_text
except Exception as e:
logger.error(f"{self.log_prefix} 处理{action}时出错: {e}")
traceback.print_exc()
return False, "", ""
return False, ""
async def _lift_question_reply(self, question: str, context: str, cycle_timers: Dict[str, float], thinking_id: str):
reason = f"在聊天中:\n{context}\n你对问题\"{question}\"感到好奇,想要和群友讨论"
new_msg = get_raw_msg_before_timestamp_with_chat(
chat_id=self.stream_id,
timestamp=time.time(),
limit=1,
)
reply_action_info = ActionPlannerInfo(
action_type="reply",
reasoning= "",
action_data={},
action_message=new_msg[0],
available_actions=None,
loop_start_time=time.time(),
action_reasoning=reason)
self.action_planner.add_plan_log(reasoning=f"你对问题\"{question}\"感到好奇,想要和群友讨论", actions=[reply_action_info])
success, llm_response = await generator_api.rewrite_reply(
chat_stream=self.chat_stream,
reply_data={
"raw_reply": f"我对这个问题感到好奇:{question}",
"reason": reason,
},
)
if not success or not llm_response or not llm_response.reply_set:
logger.info("主动提问发言失败")
self.action_planner.add_plan_excute_log(result="主动回复生成失败")
return {"action_type": "reply", "success": False, "result": "主动回复生成失败", "loop_info": None}
if success:
for reply_seg in llm_response.reply_set.reply_data:
send_data = reply_seg.content
await send_api.text_to_stream(
text=send_data,
stream_id=self.stream_id,
)
await database_api.store_action_info(
chat_stream=self.chat_stream,
action_build_into_prompt=False,
action_prompt_display=reason,
action_done=True,
thinking_id=thinking_id,
action_data={"reply_text": llm_response.reply_set.reply_data[0].content},
action_name="reply",
)
# 构建循环信息
loop_info: Dict[str, Any] = {
"loop_plan_info": {
"action_result": [reply_action_info],
},
"loop_action_info": {
"action_taken": True,
"reply_text": llm_response.reply_set.reply_data[0].content,
"command": "",
"taken_time": time.time(),
},
}
self.last_active_time = time.time()
self.action_planner.add_plan_excute_log(result=f"你提问:{question}")
return {
"action_type": "reply",
"success": True,
"result": f"你提问:{question}",
"loop_info": loop_info,
}
async def _send_response(
self,
reply_set,
reply_set: "ReplySetModel",
message_data: "DatabaseMessages",
selected_expressions: Optional[List[int]] = None,
) -> str:
@@ -529,15 +633,17 @@ class HeartFChatting:
chat_id=self.chat_stream.stream_id, start_time=self.last_read_time, end_time=time.time()
)
need_reply = new_message_count >= random.randint(2, 4)
need_reply = new_message_count >= random.randint(2, 3)
if need_reply:
logger.info(f"{self.log_prefix} 从思考到回复,共有{new_message_count}条新消息,使用引用回复")
reply_text = ""
first_replied = False
for reply_seg in reply_set:
data = reply_seg[1]
for reply_content in reply_set.reply_data:
if reply_content.content_type != ReplyContentType.TEXT:
continue
data: str = reply_content.content # type: ignore
if not first_replied:
await send_api.text_to_stream(
text=data,
@@ -571,86 +677,127 @@ class HeartFChatting:
):
"""执行单个动作的通用函数"""
try:
if action_planner_info.action_type == "no_action":
# 直接处理no_action逻辑不再通过动作系统
reason = action_planner_info.reasoning or "选择不回复"
logger.info(f"{self.log_prefix} 选择不回复,原因: {reason}")
with Timer(f"动作{action_planner_info.action_type}", cycle_timers):
# 直接当场执行no_reply逻辑
if action_planner_info.action_type == "no_reply":
# 直接处理no_reply逻辑不再通过动作系统
reason = action_planner_info.reasoning or "选择不回复"
# logger.info(f"{self.log_prefix} 选择不回复,原因: {reason}")
# 存储no_action信息到数据库
await database_api.store_action_info(
chat_stream=self.chat_stream,
action_build_into_prompt=False,
action_prompt_display=reason,
action_done=True,
thinking_id=thinking_id,
action_data={"reason": reason},
action_name="no_action",
)
return {"action_type": "no_action", "success": True, "reply_text": "", "command": ""}
elif action_planner_info.action_type != "reply":
# 执行普通动作
with Timer("动作执行", cycle_timers):
success, reply_text, command = await self._handle_action(
action_planner_info.action_type,
action_planner_info.reasoning or "",
action_planner_info.action_data or {},
cycle_timers,
thinking_id,
action_planner_info.action_message,
await database_api.store_action_info(
chat_stream=self.chat_stream,
action_build_into_prompt=False,
action_prompt_display=reason,
action_done=True,
thinking_id=thinking_id,
action_data={},
action_name="no_reply",
action_reasoning=reason,
)
return {
"action_type": action_planner_info.action_type,
"success": success,
"reply_text": reply_text,
"command": command,
}
else:
try:
return {"action_type": "no_reply", "success": True, "result": "选择不回复", "command": ""}
elif action_planner_info.action_type == "no_reply_until_call":
# 直接当场执行no_reply_until_call逻辑
logger.info(f"{self.log_prefix} 保持沉默,直到有人直接叫的名字")
reason = action_planner_info.reasoning or "选择不回复"
self.no_reply_until_call = True
await database_api.store_action_info(
chat_stream=self.chat_stream,
action_build_into_prompt=False,
action_prompt_display=reason,
action_done=True,
thinking_id=thinking_id,
action_data={},
action_name="no_reply_until_call",
action_reasoning=reason,
)
return {"action_type": "no_reply_until_call", "success": True, "result": "保持沉默,直到有人直接叫的名字", "command": ""}
elif action_planner_info.action_type == "reply":
# 直接当场执行reply逻辑
self.questioned = False
# 刷新主动发言状态
reason = action_planner_info.reasoning or "选择回复"
await database_api.store_action_info(
chat_stream=self.chat_stream,
action_build_into_prompt=False,
action_prompt_display=reason,
action_done=True,
thinking_id=thinking_id,
action_data={},
action_name="reply",
action_reasoning=reason,
)
success, llm_response = await generator_api.generate_reply(
chat_stream=self.chat_stream,
reply_message=action_planner_info.action_message,
available_actions=available_actions,
chosen_actions=chosen_action_plan_infos,
reply_reason=action_planner_info.reasoning or "",
reply_reason=reason,
enable_tool=global_config.tool.enable_tool,
request_type="replyer",
from_plugin=False,
reply_time_point = action_planner_info.action_data.get("loop_start_time", time.time()),
)
if not success or not llm_response or not llm_response.reply_set:
if action_planner_info.action_message:
logger.info(f"{action_planner_info.action_message.processed_plain_text} 的回复生成失败")
logger.info(
f"{action_planner_info.action_message.processed_plain_text} 的回复生成失败"
)
else:
logger.info("回复生成失败")
return {"action_type": "reply", "success": False, "reply_text": "", "loop_info": None}
return {"action_type": "reply", "success": False, "result": "回复生成失败", "loop_info": None}
response_set = llm_response.reply_set
selected_expressions = llm_response.selected_expressions
loop_info, reply_text, _ = await self._send_and_store_reply(
response_set=response_set,
action_message=action_planner_info.action_message, # type: ignore
cycle_timers=cycle_timers,
thinking_id=thinking_id,
actions=chosen_action_plan_infos,
selected_expressions=selected_expressions,
)
self.last_active_time = time.time()
return {
"action_type": "reply",
"success": True,
"result": f"你回复内容{reply_text}",
"loop_info": loop_info,
}
else:
# 执行普通动作
with Timer("动作执行", cycle_timers):
success, result = await self._handle_action(
action = action_planner_info.action_type,
action_reasoning = action_planner_info.action_reasoning or "",
action_data = action_planner_info.action_data or {},
cycle_timers = cycle_timers,
thinking_id = thinking_id,
action_message= action_planner_info.action_message,
)
self.last_active_time = time.time()
return {
"action_type": action_planner_info.action_type,
"success": success,
"result": result,
}
except asyncio.CancelledError:
logger.debug(f"{self.log_prefix} 并行执行:回复生成任务已被取消")
return {"action_type": "reply", "success": False, "reply_text": "", "loop_info": None}
response_set = llm_response.reply_set
selected_expressions = llm_response.selected_expressions
loop_info, reply_text, _ = await self._send_and_store_reply(
response_set=response_set,
action_message=action_planner_info.action_message, # type: ignore
cycle_timers=cycle_timers,
thinking_id=thinking_id,
actions=chosen_action_plan_infos,
selected_expressions=selected_expressions,
)
return {
"action_type": "reply",
"success": True,
"reply_text": reply_text,
"loop_info": loop_info,
}
except Exception as e:
logger.error(f"{self.log_prefix} 执行动作时出错: {e}")
logger.error(f"{self.log_prefix} 错误信息: {traceback.format_exc()}")
return {
"action_type": action_planner_info.action_type,
"success": False,
"reply_text": "",
"result": "",
"loop_info": None,
"error": str(e),
}

View File

@@ -1,24 +1,35 @@
import traceback
from typing import Any, Optional, Dict
from src.chat.message_receive.chat_stream import get_chat_manager
from src.common.logger import get_logger
from src.chat.heart_flow.heartFC_chat import HeartFChatting
from src.chat.brain_chat.brain_chat import BrainChatting
from src.chat.message_receive.chat_stream import ChatStream
logger = get_logger("heartflow")
class Heartflow:
"""主心流协调器,负责初始化并协调聊天"""
def __init__(self):
self.heartflow_chat_list: Dict[Any, HeartFChatting] = {}
async def get_or_create_heartflow_chat(self, chat_id: Any) -> Optional[HeartFChatting]:
self.heartflow_chat_list: Dict[Any, HeartFChatting | BrainChatting] = {}
async def get_or_create_heartflow_chat(self, chat_id: Any) -> Optional[HeartFChatting | BrainChatting]:
"""获取或创建一个新的HeartFChatting实例"""
try:
if chat_id in self.heartflow_chat_list:
if chat := self.heartflow_chat_list.get(chat_id):
return chat
else:
new_chat = HeartFChatting(chat_id = chat_id)
chat_stream: ChatStream | None = get_chat_manager().get_stream(chat_id)
if not chat_stream:
raise ValueError(f"未找到 chat_id={chat_id} 的聊天流")
if chat_stream.group_info:
new_chat = HeartFChatting(chat_id=chat_id)
else:
new_chat = BrainChatting(chat_id=chat_id)
await new_chat.start()
self.heartflow_chat_list[chat_id] = new_chat
return new_chat
@@ -27,4 +38,5 @@ class Heartflow:
traceback.print_exc()
return None
heartflow = Heartflow()

View File

@@ -1,20 +1,14 @@
import asyncio
import re
import math
import traceback
from typing import Tuple, TYPE_CHECKING
from src.config.config import global_config
from src.chat.memory_system.Hippocampus import hippocampus_manager
from src.chat.message_receive.message import MessageRecv
from src.chat.message_receive.storage import MessageStorage
from src.chat.heart_flow.heartflow import heartflow
from src.chat.utils.utils import is_mentioned_bot_in_message
from src.chat.utils.timer_calculator import Timer
from src.chat.utils.chat_message_builder import replace_user_references
from src.common.logger import get_logger
from src.mood.mood_manager import mood_manager
from src.person_info.person_info import Person
from src.common.database.database_model import Images
@@ -23,6 +17,7 @@ if TYPE_CHECKING:
logger = get_logger("chat")
async def _calculate_interest(message: MessageRecv) -> Tuple[float, list[str]]:
"""计算消息的兴趣度
@@ -34,58 +29,17 @@ async def _calculate_interest(message: MessageRecv) -> Tuple[float, list[str]]:
"""
if message.is_picid or message.is_emoji:
return 0.0, []
is_mentioned,is_at,reply_probability_boost = is_mentioned_bot_in_message(message)
interested_rate = 0.0
with Timer("记忆激活"):
interested_rate, keywords,keywords_lite = await hippocampus_manager.get_activate_from_text(
message.processed_plain_text,
max_depth= 4,
fast_retrieval=global_config.chat.interest_rate_mode == "fast",
)
message.key_words = keywords
message.key_words_lite = keywords_lite
logger.debug(f"记忆激活率: {interested_rate:.2f}, 关键词: {keywords}")
is_mentioned, is_at, reply_probability_boost = is_mentioned_bot_in_message(message)
# interested_rate = 0.0
keywords = []
text_len = len(message.processed_plain_text)
# 根据文本长度分布调整兴趣度,采用分段函数实现更精确的兴趣度计算
# 基于实际分布0-5字符(26.57%), 6-10字符(27.18%), 11-20字符(22.76%), 21-30字符(10.33%), 31+字符(13.86%)
if text_len == 0:
base_interest = 0.01 # 空消息最低兴趣度
elif text_len <= 5:
# 1-5字符线性增长 0.01 -> 0.03
base_interest = 0.01 + (text_len - 1) * (0.03 - 0.01) / 4
elif text_len <= 10:
# 6-10字符线性增长 0.03 -> 0.06
base_interest = 0.03 + (text_len - 5) * (0.06 - 0.03) / 5
elif text_len <= 20:
# 11-20字符线性增长 0.06 -> 0.12
base_interest = 0.06 + (text_len - 10) * (0.12 - 0.06) / 10
elif text_len <= 30:
# 21-30字符线性增长 0.12 -> 0.18
base_interest = 0.12 + (text_len - 20) * (0.18 - 0.12) / 10
elif text_len <= 50:
# 31-50字符线性增长 0.18 -> 0.22
base_interest = 0.18 + (text_len - 30) * (0.22 - 0.18) / 20
elif text_len <= 100:
# 51-100字符线性增长 0.22 -> 0.26
base_interest = 0.22 + (text_len - 50) * (0.26 - 0.22) / 50
else:
# 100+字符:对数增长 0.26 -> 0.3,增长率递减
base_interest = 0.26 + (0.3 - 0.26) * (math.log10(text_len - 99) / math.log10(901)) # 1000-99=901
# 确保在范围内
base_interest = min(max(base_interest, 0.01), 0.3)
message.interest_value = base_interest
message.interest_value = 1
message.is_mentioned = is_mentioned
message.is_at = is_at
message.reply_probability_boost = reply_probability_boost
return base_interest, keywords
return 1, keywords
class HeartFCMessageReceiver:
@@ -114,17 +68,11 @@ class HeartFCMessageReceiver:
chat = message.chat_stream
# 2. 兴趣度计算与更新
interested_rate, keywords = await _calculate_interest(message)
_, keywords = await _calculate_interest(message)
await self.storage.store_message(message, chat)
heartflow_chat: HeartFChatting = await heartflow.get_or_create_heartflow_chat(chat.stream_id) # type: ignore
# subheartflow.add_message_to_normal_chat_cache(message, interested_rate, is_mentioned)
if global_config.mood.enable_mood:
chat_mood = mood_manager.get_mood_by_chat_id(heartflow_chat.stream_id)
asyncio.create_task(chat_mood.update_mood_by_message(message, interested_rate))
_heartflow_chat: HeartFChatting = await heartflow.get_or_create_heartflow_chat(chat.stream_id) # type: ignore
# 3. 日志记录
mes_name = chat.group_info.group_name if chat.group_info else "私聊"
@@ -132,7 +80,7 @@ class HeartFCMessageReceiver:
# 用这个pattern截取出id部分picid是一个list并替换成对应的图片描述
picid_pattern = r"\[picid:([^\]]+)\]"
picid_list = re.findall(picid_pattern, message.processed_plain_text)
# 创建替换后的文本
processed_text = message.processed_plain_text
if picid_list:
@@ -145,18 +93,22 @@ class HeartFCMessageReceiver:
# 如果没有找到图片描述,则移除[picid:xxxx]标记
processed_text = processed_text.replace(f"[picid:{picid}]", "[图片:网络不好,图片无法加载]")
# 应用用户引用格式替换,将回复<aaa:bbb>和@<aaa:bbb>格式转换为可读格式
processed_plain_text = replace_user_references(
processed_text,
message.message_info.platform, # type: ignore
replace_bot_name=True
message.message_info.platform, # type: ignore
replace_bot_name=True,
)
# if not processed_plain_text:
# print(message)
logger.info(f"[{mes_name}]{userinfo.user_nickname}:{processed_plain_text}") # type: ignore
logger.info(f"[{mes_name}]{userinfo.user_nickname}:{processed_plain_text}[{interested_rate:.2f}]") # type: ignore
_ = Person.register_person(platform=message.message_info.platform, user_id=message.message_info.user_info.user_id,nickname=userinfo.user_nickname) # type: ignore
_ = Person.register_person(
platform=message.message_info.platform, # type: ignore
user_id=message.message_info.user_info.user_id, # type: ignore
nickname=userinfo.user_nickname, # type: ignore
)
except Exception as e:
logger.error(f"消息处理失败: {e}")

View File

@@ -124,6 +124,7 @@ async def send_typing():
message_type="state", content="typing", stream_id=chat.stream_id, storage_message=False
)
async def stop_typing():
group_info = GroupInfo(platform="amaidesu_default", group_id="114514", group_name="内心")
@@ -135,4 +136,4 @@ async def stop_typing():
await send_api.custom_to_stream(
message_type="state", content="stop_typing", stream_id=chat.stream_id, storage_message=False
)
)

View File

@@ -30,6 +30,7 @@ DATA_PATH = os.path.join(ROOT_PATH, "data")
qa_manager = None
inspire_manager = None
def lpmm_start_up(): # sourcery skip: extract-duplicate-method
# 检查LPMM知识库是否启用
if global_config.lpmm_knowledge.enable:

View File

@@ -25,7 +25,6 @@ from rich.progress import (
SpinnerColumn,
TextColumn,
)
from src.chat.utils.utils import get_embedding
from src.config.config import global_config
@@ -33,11 +32,11 @@ install(extra_lines=3)
# 多线程embedding配置常量
DEFAULT_MAX_WORKERS = 10 # 默认最大线程数
DEFAULT_CHUNK_SIZE = 10 # 默认每个线程处理的数据块大小
MIN_CHUNK_SIZE = 1 # 最小分块大小
MAX_CHUNK_SIZE = 50 # 最大分块大小
MIN_WORKERS = 1 # 最小线程数
MAX_WORKERS = 20 # 最大线程数
DEFAULT_CHUNK_SIZE = 10 # 默认每个线程处理的数据块大小
MIN_CHUNK_SIZE = 1 # 最小分块大小
MAX_CHUNK_SIZE = 50 # 最大分块大小
MIN_WORKERS = 1 # 最小线程数
MAX_WORKERS = 20 # 最大线程数
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
EMBEDDING_DATA_DIR = os.path.join(ROOT_PATH, "data", "embedding")
@@ -94,7 +93,13 @@ class EmbeddingStoreItem:
class EmbeddingStore:
def __init__(self, namespace: str, dir_path: str, max_workers: int = DEFAULT_MAX_WORKERS, chunk_size: int = DEFAULT_CHUNK_SIZE):
def __init__(
self,
namespace: str,
dir_path: str,
max_workers: int = DEFAULT_MAX_WORKERS,
chunk_size: int = DEFAULT_CHUNK_SIZE,
):
self.namespace = namespace
self.dir = dir_path
self.embedding_file_path = f"{dir_path}/{namespace}.parquet"
@@ -104,12 +109,16 @@ class EmbeddingStore:
# 多线程配置参数验证和设置
self.max_workers = max(MIN_WORKERS, min(MAX_WORKERS, max_workers))
self.chunk_size = max(MIN_CHUNK_SIZE, min(MAX_CHUNK_SIZE, chunk_size))
# 如果配置值被调整,记录日志
if self.max_workers != max_workers:
logger.warning(f"max_workers 已从 {max_workers} 调整为 {self.max_workers} (范围: {MIN_WORKERS}-{MAX_WORKERS})")
logger.warning(
f"max_workers 已从 {max_workers} 调整为 {self.max_workers} (范围: {MIN_WORKERS}-{MAX_WORKERS})"
)
if self.chunk_size != chunk_size:
logger.warning(f"chunk_size 已从 {chunk_size} 调整为 {self.chunk_size} (范围: {MIN_CHUNK_SIZE}-{MAX_CHUNK_SIZE})")
logger.warning(
f"chunk_size 已从 {chunk_size} 调整为 {self.chunk_size} (范围: {MIN_CHUNK_SIZE}-{MAX_CHUNK_SIZE})"
)
self.store = {}
@@ -121,23 +130,23 @@ class EmbeddingStore:
# 创建新的事件循环并在完成后立即关闭
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
# 创建新的LLMRequest实例
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config
llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding")
# 使用新的事件循环运行异步方法
embedding, _ = loop.run_until_complete(llm.get_embedding(s))
if embedding and len(embedding) > 0:
return embedding
else:
logger.error(f"获取嵌入失败: {s}")
return []
except Exception as e:
logger.error(f"获取嵌入时发生异常: {s}, 错误: {e}")
return []
@@ -148,43 +157,45 @@ class EmbeddingStore:
except Exception:
pass
def _get_embeddings_batch_threaded(self, strs: List[str], chunk_size: int = 10, max_workers: int = 10, progress_callback=None) -> List[Tuple[str, List[float]]]:
def _get_embeddings_batch_threaded(
self, strs: List[str], chunk_size: int = 10, max_workers: int = 10, progress_callback=None
) -> List[Tuple[str, List[float]]]:
"""使用多线程批量获取嵌入向量
Args:
strs: 要获取嵌入的字符串列表
chunk_size: 每个线程处理的数据块大小
max_workers: 最大线程数
progress_callback: 进度回调函数,接收一个参数表示完成的数量
Returns:
包含(原始字符串, 嵌入向量)的元组列表,保持与输入顺序一致
"""
if not strs:
return []
# 分块
chunks = []
for i in range(0, len(strs), chunk_size):
chunk = strs[i:i + chunk_size]
chunk = strs[i : i + chunk_size]
chunks.append((i, chunk)) # 保存起始索引以维持顺序
# 结果存储,使用字典按索引存储以保证顺序
results = {}
def process_chunk(chunk_data):
"""处理单个数据块的函数"""
start_idx, chunk_strs = chunk_data
chunk_results = []
# 为每个线程创建独立的LLMRequest实例
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config
try:
# 创建线程专用的LLM实例
llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding")
for i, s in enumerate(chunk_strs):
try:
# 在线程中创建独立的事件循环
@@ -194,25 +205,25 @@ class EmbeddingStore:
embedding = loop.run_until_complete(llm.get_embedding(s))
finally:
loop.close()
if embedding and len(embedding) > 0:
chunk_results.append((start_idx + i, s, embedding[0])) # embedding[0] 是实际的向量
else:
logger.error(f"获取嵌入失败: {s}")
chunk_results.append((start_idx + i, s, []))
# 每完成一个嵌入立即更新进度
if progress_callback:
progress_callback(1)
except Exception as e:
logger.error(f"获取嵌入时发生异常: {s}, 错误: {e}")
chunk_results.append((start_idx + i, s, []))
# 即使失败也要更新进度
if progress_callback:
progress_callback(1)
except Exception as e:
logger.error(f"创建LLM实例失败: {e}")
# 如果创建LLM实例失败返回空结果
@@ -221,14 +232,14 @@ class EmbeddingStore:
# 即使失败也要更新进度
if progress_callback:
progress_callback(1)
return chunk_results
# 使用线程池处理
with ThreadPoolExecutor(max_workers=max_workers) as executor:
# 提交所有任务
future_to_chunk = {executor.submit(process_chunk, chunk): chunk for chunk in chunks}
# 收集结果进度已在process_chunk中实时更新
for future in as_completed(future_to_chunk):
try:
@@ -242,7 +253,7 @@ class EmbeddingStore:
start_idx, chunk_strs = chunk
for i, s in enumerate(chunk_strs):
results[start_idx + i] = (s, [])
# 按原始顺序返回结果
ordered_results = []
for i in range(len(strs)):
@@ -251,7 +262,7 @@ class EmbeddingStore:
else:
# 防止遗漏
ordered_results.append((strs[i], []))
return ordered_results
def get_test_file_path(self):
@@ -260,14 +271,14 @@ class EmbeddingStore:
def save_embedding_test_vectors(self):
"""保存测试字符串的嵌入到本地(使用多线程优化)"""
logger.info("开始保存测试字符串的嵌入向量...")
# 使用多线程批量获取测试字符串的嵌入
embedding_results = self._get_embeddings_batch_threaded(
EMBEDDING_TEST_STRINGS,
chunk_size=min(self.chunk_size, len(EMBEDDING_TEST_STRINGS)),
max_workers=min(self.max_workers, len(EMBEDDING_TEST_STRINGS))
max_workers=min(self.max_workers, len(EMBEDDING_TEST_STRINGS)),
)
# 构建测试向量字典
test_vectors = {}
for idx, (s, embedding) in enumerate(embedding_results):
@@ -277,10 +288,10 @@ class EmbeddingStore:
logger.error(f"获取测试字符串嵌入失败: {s}")
# 使用原始单线程方法作为后备
test_vectors[str(idx)] = self._get_embedding(s)
with open(self.get_test_file_path(), "w", encoding="utf-8") as f:
json.dump(test_vectors, f, ensure_ascii=False, indent=2)
logger.info("测试字符串嵌入向量保存完成")
def load_embedding_test_vectors(self):
@@ -298,35 +309,35 @@ class EmbeddingStore:
logger.warning("未检测到本地嵌入模型测试文件,将保存当前模型的测试嵌入。")
self.save_embedding_test_vectors()
return True
# 检查本地向量完整性
for idx in range(len(EMBEDDING_TEST_STRINGS)):
if local_vectors.get(str(idx)) is None:
logger.warning("本地嵌入模型测试文件缺失部分测试字符串,将重新保存。")
self.save_embedding_test_vectors()
return True
logger.info("开始检验嵌入模型一致性...")
# 使用多线程批量获取当前模型的嵌入
embedding_results = self._get_embeddings_batch_threaded(
EMBEDDING_TEST_STRINGS,
chunk_size=min(self.chunk_size, len(EMBEDDING_TEST_STRINGS)),
max_workers=min(self.max_workers, len(EMBEDDING_TEST_STRINGS))
max_workers=min(self.max_workers, len(EMBEDDING_TEST_STRINGS)),
)
# 检查一致性
for idx, (s, new_emb) in enumerate(embedding_results):
local_emb = local_vectors.get(str(idx))
if not new_emb:
logger.error(f"获取测试字符串嵌入失败: {s}")
return False
sim = cosine_similarity(local_emb, new_emb)
if sim < EMBEDDING_SIM_THRESHOLD:
logger.error(f"嵌入模型一致性校验失败,字符串: {s}, 相似度: {sim:.4f}")
return False
logger.info("嵌入模型一致性校验通过。")
return True
@@ -334,22 +345,22 @@ class EmbeddingStore:
"""向库中存入字符串(使用多线程优化)"""
if not strs:
return
total = len(strs)
# 过滤已存在的字符串
new_strs = []
for s in strs:
item_hash = self.namespace + "-" + get_sha256(s)
if item_hash not in self.store:
new_strs.append(s)
if not new_strs:
logger.info(f"所有字符串已存在于{self.namespace}嵌入库中,跳过处理")
return
logger.info(f"需要处理 {len(new_strs)}/{total} 个新字符串")
with Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
@@ -363,31 +374,39 @@ class EmbeddingStore:
transient=False,
) as progress:
task = progress.add_task(f"存入嵌入库:({times}/{TOTAL_EMBEDDING_TIMES})", total=total)
# 首先更新已存在项的进度
already_processed = total - len(new_strs)
if already_processed > 0:
progress.update(task, advance=already_processed)
if new_strs:
# 使用实例配置的参数,智能调整分块和线程数
optimal_chunk_size = max(MIN_CHUNK_SIZE, min(self.chunk_size, len(new_strs) // self.max_workers if self.max_workers > 0 else self.chunk_size))
optimal_max_workers = min(self.max_workers, max(MIN_WORKERS, len(new_strs) // optimal_chunk_size if optimal_chunk_size > 0 else 1))
optimal_chunk_size = max(
MIN_CHUNK_SIZE,
min(
self.chunk_size, len(new_strs) // self.max_workers if self.max_workers > 0 else self.chunk_size
),
)
optimal_max_workers = min(
self.max_workers,
max(MIN_WORKERS, len(new_strs) // optimal_chunk_size if optimal_chunk_size > 0 else 1),
)
logger.debug(f"使用多线程处理: chunk_size={optimal_chunk_size}, max_workers={optimal_max_workers}")
# 定义进度更新回调函数
def update_progress(count):
progress.update(task, advance=count)
# 批量获取嵌入,并实时更新进度
embedding_results = self._get_embeddings_batch_threaded(
new_strs,
chunk_size=optimal_chunk_size,
new_strs,
chunk_size=optimal_chunk_size,
max_workers=optimal_max_workers,
progress_callback=update_progress
progress_callback=update_progress,
)
# 存入结果(不再需要在这里更新进度,因为已经在回调中更新了)
for s, embedding in embedding_results:
item_hash = self.namespace + "-" + get_sha256(s)
@@ -520,7 +539,7 @@ class EmbeddingManager:
def __init__(self, max_workers: int = DEFAULT_MAX_WORKERS, chunk_size: int = DEFAULT_CHUNK_SIZE):
"""
初始化EmbeddingManager
Args:
max_workers: 最大线程数
chunk_size: 每个线程处理的数据块大小

View File

@@ -426,9 +426,7 @@ class KGManager:
# 获取最终结果
# 从搜索结果中提取文段节点的结果
passage_node_res = [
(node_key, score)
for node_key, score in ppr_res.items()
if node_key.startswith("paragraph")
(node_key, score) for node_key, score in ppr_res.items() if node_key.startswith("paragraph")
]
del ppr_res

View File

@@ -1,8 +1,8 @@
raise DeprecationWarning("MemoryActiveManager is not used yet, please do not import it")
from .lpmmconfig import global_config
from .embedding_store import EmbeddingManager
from .llm_client import LLMClient
from .utils.dyn_topk import dyn_select_top_k
from .lpmmconfig import global_config # noqa
from .embedding_store import EmbeddingManager # noqa
from .llm_client import LLMClient # noqa
from .utils.dyn_topk import dyn_select_top_k # noqa
class MemoryActiveManager:

View File

@@ -8,7 +8,7 @@ def dyn_select_top_k(
# 检查输入列表是否为空
if not score:
return []
# 按照分数排序(降序)
sorted_score = sorted(score, key=lambda x: x[1], reverse=True)

File diff suppressed because it is too large Load Diff

View File

@@ -1,241 +0,0 @@
import json
import random
from json_repair import repair_json
from typing import List, Tuple
from src.config.config import global_config, model_config
from src.common.logger import get_logger
from src.common.data_models.database_data_model import DatabaseMessages
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.utils import parse_keywords_string
from src.chat.utils.chat_message_builder import build_readable_messages
from src.chat.memory_system.Hippocampus import hippocampus_manager
from src.llm_models.utils_model import LLMRequest
logger = get_logger("memory_activator")
def get_keywords_from_json(json_str) -> List:
"""
从JSON字符串中提取关键词列表
Args:
json_str: JSON格式的字符串
Returns:
List[str]: 关键词列表
"""
try:
# 使用repair_json修复JSON格式
fixed_json = repair_json(json_str)
# 如果repair_json返回的是字符串需要解析为Python对象
result = json.loads(fixed_json) if isinstance(fixed_json, str) else fixed_json
return result.get("keywords", [])
except Exception as e:
logger.error(f"解析关键词JSON失败: {e}")
return []
def init_prompt():
# --- Group Chat Prompt ---
memory_activator_prompt = """
你需要根据以下信息来挑选合适的记忆编号
以下是一段聊天记录,请根据这些信息,和下方的记忆,挑选和群聊内容有关的记忆编号
聊天记录:
{obs_info_text}
你想要回复的消息:
{target_message}
记忆:
{memory_info}
请输出一个json格式包含以下字段
{{
"memory_ids": "记忆1编号,记忆2编号,记忆3编号,......"
}}
不要输出其他多余内容只输出json格式就好
"""
Prompt(memory_activator_prompt, "memory_activator_prompt")
class MemoryActivator:
def __init__(self):
self.key_words_model = LLMRequest(
model_set=model_config.model_task_config.utils_small,
request_type="memory.activator",
)
# 用于记忆选择的 LLM 模型
self.memory_selection_model = LLMRequest(
model_set=model_config.model_task_config.utils_small,
request_type="memory.selection",
)
async def activate_memory_with_chat_history(
self, target_message, chat_history: List[DatabaseMessages]
) -> List[Tuple[str, str]]:
"""
激活记忆
"""
# 如果记忆系统被禁用,直接返回空列表
if not global_config.memory.enable_memory:
return []
keywords_list = set()
for msg in chat_history:
keywords = parse_keywords_string(msg.key_words)
if keywords:
if len(keywords_list) < 30:
# 最多容纳30个关键词
keywords_list.update(keywords)
logger.debug(f"提取关键词: {keywords_list}")
else:
break
if not keywords_list:
logger.debug("没有提取到关键词,返回空记忆列表")
return []
# 从海马体获取相关记忆
related_memory = await hippocampus_manager.get_memory_from_topic(
valid_keywords=list(keywords_list), max_memory_num=5, max_memory_length=3, max_depth=3
)
# logger.info(f"当前记忆关键词: {keywords_list}")
logger.debug(f"获取到的记忆: {related_memory}")
if not related_memory:
logger.debug("海马体没有返回相关记忆")
return []
used_ids = set()
candidate_memories = []
# 为每个记忆分配随机ID并过滤相关记忆
for memory in related_memory:
keyword, content = memory
found = any(kw in content for kw in keywords_list)
if found:
# 随机分配一个不重复的2位数id
while True:
random_id = "{:02d}".format(random.randint(0, 99))
if random_id not in used_ids:
used_ids.add(random_id)
break
candidate_memories.append({"memory_id": random_id, "keyword": keyword, "content": content})
if not candidate_memories:
logger.info("没有找到相关的候选记忆")
return []
# 如果只有少量记忆,直接返回
if len(candidate_memories) <= 2:
logger.debug(f"候选记忆较少({len(candidate_memories)}个),直接返回")
# 转换为 (keyword, content) 格式
return [(mem["keyword"], mem["content"]) for mem in candidate_memories]
return await self._select_memories_with_llm(target_message, chat_history, candidate_memories)
async def _select_memories_with_llm(
self, target_message, chat_history: List[DatabaseMessages], candidate_memories
) -> List[Tuple[str, str]]:
"""
使用 LLM 选择合适的记忆
Args:
target_message: 目标消息
chat_history_prompt: 聊天历史
candidate_memories: 候选记忆列表,每个记忆包含 memory_id、keyword、content
Returns:
List[Tuple[str, str]]: 选择的记忆列表,格式为 (keyword, content)
"""
try:
# 构建聊天历史字符串
obs_info_text = build_readable_messages(
chat_history,
replace_bot_name=True,
timestamp_mode="relative",
read_mark=0.0,
show_actions=True,
)
# 构建记忆信息字符串
memory_lines = []
for memory in candidate_memories:
memory_id = memory["memory_id"]
keyword = memory["keyword"]
content = memory["content"]
# 将 content 列表转换为字符串
if isinstance(content, list):
content_str = " | ".join(str(item) for item in content)
else:
content_str = str(content)
memory_lines.append(f"记忆编号 {memory_id}: [关键词: {keyword}] {content_str}")
memory_info = "\n".join(memory_lines)
# 获取并格式化 prompt
prompt_template = await global_prompt_manager.get_prompt_async("memory_activator_prompt")
formatted_prompt = prompt_template.format(
obs_info_text=obs_info_text, target_message=target_message, memory_info=memory_info
)
# 调用 LLM
response, (reasoning_content, model_name, _) = await self.memory_selection_model.generate_response_async(
formatted_prompt, temperature=0.3, max_tokens=150
)
if global_config.debug.show_prompt:
logger.info(f"记忆选择 prompt: {formatted_prompt}")
logger.info(f"LLM 记忆选择响应: {response}")
else:
logger.debug(f"记忆选择 prompt: {formatted_prompt}")
logger.debug(f"LLM 记忆选择响应: {response}")
# 解析响应获取选择的记忆编号
try:
fixed_json = repair_json(response)
# 解析为 Python 对象
result = json.loads(fixed_json) if isinstance(fixed_json, str) else fixed_json
# 提取 memory_ids 字段并解析逗号分隔的编号
if memory_ids_str := result.get("memory_ids", ""):
memory_ids = [mid.strip() for mid in str(memory_ids_str).split(",") if mid.strip()]
# 过滤掉空字符串和无效编号
valid_memory_ids = [mid for mid in memory_ids if mid and len(mid) <= 3]
selected_memory_ids = valid_memory_ids
else:
selected_memory_ids = []
except Exception as e:
logger.error(f"解析记忆选择响应失败: {e}", exc_info=True)
selected_memory_ids = []
# 根据编号筛选记忆
selected_memories = []
memory_id_to_memory = {mem["memory_id"]: mem for mem in candidate_memories}
selected_memories = [
memory_id_to_memory[memory_id] for memory_id in selected_memory_ids if memory_id in memory_id_to_memory
]
logger.info(f"LLM 选择的记忆编号: {selected_memory_ids}")
logger.info(f"最终选择的记忆数量: {len(selected_memories)}")
# 转换为 (keyword, content) 格式
return [(mem["keyword"], mem["content"]) for mem in selected_memories]
except Exception as e:
logger.error(f"LLM 选择记忆时出错: {e}", exc_info=True)
# 出错时返回前3个候选记忆作为备选转换为 (keyword, content) 格式
return [(mem["keyword"], mem["content"]) for mem in candidate_memories[:3]]
init_prompt()

View File

@@ -3,19 +3,18 @@ import os
import re
from typing import Dict, Any, Optional
from maim_message import UserInfo
from maim_message import UserInfo, Seg, GroupInfo
from src.common.logger import get_logger
from src.config.config import global_config
from src.mood.mood_manager import mood_manager # 导入情绪管理器
from src.chat.message_receive.chat_stream import get_chat_manager, ChatStream
from src.chat.message_receive.message import MessageRecv, MessageRecvS4U
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.message_receive.message import MessageRecv
from src.chat.message_receive.storage import MessageStorage
from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiver
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.plugin_system.core import component_registry, events_manager, global_announcement_manager
from src.plugin_system.base import BaseCommand, EventType
from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor
from src.person_info.person_info import Person
# 定义日志配置
@@ -27,7 +26,7 @@ PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../..
logger = get_logger("chat")
def _check_ban_words(text: str, chat: ChatStream, userinfo: UserInfo) -> bool:
def _check_ban_words(text: str, userinfo: UserInfo, group_info: Optional[GroupInfo] = None) -> bool:
"""检查消息是否包含过滤词
Args:
@@ -40,14 +39,14 @@ def _check_ban_words(text: str, chat: ChatStream, userinfo: UserInfo) -> bool:
"""
for word in global_config.message_receive.ban_words:
if word in text:
chat_name = chat.group_info.group_name if chat.group_info else "私聊"
chat_name = group_info.group_name if group_info else "私聊"
logger.info(f"[{chat_name}]{userinfo.user_nickname}:{text}")
logger.info(f"[过滤词识别]消息中含有{word}filtered")
return True
return False
def _check_ban_regex(text: str, chat: ChatStream, userinfo: UserInfo) -> bool:
def _check_ban_regex(text: str, userinfo: UserInfo, group_info: Optional[GroupInfo] = None) -> bool:
"""检查消息是否匹配过滤正则表达式
Args:
@@ -58,9 +57,13 @@ def _check_ban_regex(text: str, chat: ChatStream, userinfo: UserInfo) -> bool:
Returns:
bool: 是否匹配过滤正则
"""
# 检查text是否为None或空字符串
if text is None or not text:
return False
for pattern in global_config.message_receive.ban_msgs_regex:
if re.search(pattern, text):
chat_name = chat.group_info.group_name if chat.group_info else "私聊"
chat_name = group_info.group_name if group_info else "私聊"
logger.info(f"[{chat_name}]{userinfo.user_nickname}:{text}")
logger.info(f"[正则表达式过滤]消息匹配到{pattern}filtered")
return True
@@ -74,8 +77,6 @@ class ChatBot:
self.mood_manager = mood_manager # 获取情绪管理器单例
self.heartflow_message_receiver = HeartFCMessageReceiver() # 新增
self.s4u_message_processor = S4UMessageProcessor()
async def _ensure_started(self):
"""确保所有任务已启动"""
if not self._started:
@@ -149,33 +150,29 @@ class ChatBot:
if message.message_info.message_id == "notice":
message.is_notify = True
logger.info("notice消息")
# print(message)
print(message)
return True
async def do_s4u(self, message_data: Dict[str, Any]):
message = MessageRecvS4U(message_data)
group_info = message.message_info.group_info
user_info = message.message_info.user_info
get_chat_manager().register_message(message)
chat = await get_chat_manager().get_or_create_stream(
platform=message.message_info.platform, # type: ignore
user_info=user_info, # type: ignore
group_info=group_info,
)
message.update_chat_stream(chat)
# 处理消息内容
await message.process()
_ = Person.register_person(platform=message.message_info.platform, user_id=message.message_info.user_info.user_id,nickname=user_info.user_nickname) # type: ignore
await self.s4u_message_processor.process_message(message)
return
async def echo_message_process(self, raw_data: Dict[str, Any]) -> None:
"""
用于专门处理回送消息ID的函数
"""
message_data: Dict[str, Any] = raw_data.get("content", {})
if not message_data:
return
message_type = message_data.get("type")
if message_type != "echo":
return
mmc_message_id = message_data.get("echo")
actual_message_id = message_data.get("actual_id")
if MessageStorage.update_message(mmc_message_id, actual_message_id):
logger.debug(f"更新消息ID成功: {mmc_message_id} -> {actual_message_id}")
else:
logger.warning(f"更新消息ID失败: {mmc_message_id} -> {actual_message_id}")
async def message_process(self, message_data: Dict[str, Any]) -> None:
"""处理转化后的统一格式消息
这个函数本质是预处理一些数据,根据配置信息和消息内容,预处理消息,并分发到合适的消息处理器中
@@ -194,11 +191,6 @@ class ChatBot:
# 确保所有任务已启动
await self._ensure_started()
platform = message_data["message_info"].get("platform")
if platform == "amaidesu_default":
await self.do_s4u(message_data)
return
if message_data["message_info"].get("group_info") is not None:
message_data["message_info"]["group_info"]["group_id"] = str(
@@ -211,18 +203,35 @@ class ChatBot:
# print(message_data)
# logger.debug(str(message_data))
message = MessageRecv(message_data)
group_info = message.message_info.group_info
user_info = message.message_info.user_info
continue_flag, modified_message = await events_manager.handle_mai_events(
EventType.ON_MESSAGE_PRE_PROCESS, message
)
if not continue_flag:
return
if modified_message and modified_message._modify_flags.modify_message_segments:
message.message_segment = Seg(type="seglist", data=modified_message.message_segments)
if await self.handle_notice_message(message):
# return
pass
group_info = message.message_info.group_info
user_info = message.message_info.user_info
if message.message_info.additional_config:
sent_message = message.message_info.additional_config.get("echo", False)
if sent_message: # 这一段只是为了在一切处理前劫持上报的自身消息用于更新message_id需要ada支持上报事件实际测试中不会对正常使用造成任何问题
await MessageStorage.update_message(message)
return
# 处理消息内容,生成纯文本
await message.process()
# 过滤检查
if _check_ban_words(
message.processed_plain_text,
user_info, # type: ignore
group_info,
) or _check_ban_regex(
message.raw_message, # type: ignore
user_info, # type: ignore
group_info,
):
return
get_chat_manager().register_message(message)
@@ -234,21 +243,10 @@ class ChatBot:
message.update_chat_stream(chat)
# 处理消息内容,生成纯文本
await message.process()
# if await self.check_ban_content(message):
# logger.warning(f"检测到消息中含有违法,色情,暴力,反动,敏感内容,消息内容:{message.processed_plain_text},发送者:{message.message_info.user_info.user_nickname}")
# return
# 过滤检查
if _check_ban_words(message.processed_plain_text, chat, user_info) or _check_ban_regex( # type: ignore
message.raw_message, # type: ignore
chat,
user_info, # type: ignore
):
return
# 命令处理 - 使用新插件系统检查并处理命令
is_command, cmd_result, continue_process = await self._process_commands_with_new_system(message)
@@ -258,8 +256,11 @@ class ChatBot:
logger.info(f"命令处理完成,跳过后续消息处理: {cmd_result}")
return
if not await events_manager.handle_mai_events(EventType.ON_MESSAGE, message):
continue_flag, modified_message = await events_manager.handle_mai_events(EventType.ON_MESSAGE, message)
if not continue_flag:
return
if modified_message and modified_message._modify_flags.modify_plain_text:
message.processed_plain_text = modified_message.plain_text
# 确认从接口发来的message是否有自定义的prompt模板信息
if message.message_info.template_info and not message.message_info.template_info.template_default:

View File

@@ -8,6 +8,7 @@ from typing import Optional, Any, List
from maim_message import Seg, UserInfo, BaseMessageInfo, MessageBase
from src.common.logger import get_logger
from src.config.config import global_config
from src.chat.utils.utils_image import get_image_manager
from src.chat.utils.utils_voice import get_voice_text
from .chat_stream import ChatStream
@@ -79,6 +80,14 @@ class Message(MessageBase):
if processed:
segments_text.append(processed)
return " ".join(segments_text)
elif segment.type == "forward":
segments_text = []
for node_dict in segment.data:
message = MessageBase.from_dict(node_dict) # type: ignore
processed_text = await self._process_message_segments(message.message_segment)
if processed_text:
segments_text.append(f"{global_config.bot.nickname}: {processed_text}")
return "[合并消息]: " + "\n-- ".join(segments_text)
else:
# 处理单个消息段
return await self._process_single_segment(segment) # type: ignore
@@ -129,6 +138,7 @@ class MessageRecv(Message):
这个方法必须在创建实例后显式调用,因为它包含异步操作。
"""
# print(f"self.message_segment: {self.message_segment}")
self.processed_plain_text = await self._process_message_segments(self.message_segment)
async def _process_single_segment(self, segment: Seg) -> str:
@@ -199,129 +209,6 @@ class MessageRecv(Message):
return f"[处理失败的{segment.type}消息]"
@dataclass
class MessageRecvS4U(MessageRecv):
def __init__(self, message_dict: dict[str, Any]):
super().__init__(message_dict)
self.is_gift = False
self.is_fake_gift = False
self.is_superchat = False
self.gift_info = None
self.gift_name = None
self.gift_count: Optional[str] = None
self.superchat_info = None
self.superchat_price = None
self.superchat_message_text = None
self.is_screen = False
self.is_internal = False
self.voice_done = None
self.chat_info = None
async def process(self) -> None:
self.processed_plain_text = await self._process_message_segments(self.message_segment)
async def _process_single_segment(self, segment: Seg) -> str:
"""处理单个消息段
Args:
segment: 消息段
Returns:
str: 处理后的文本
"""
try:
if segment.type == "text":
self.is_voice = False
self.is_picid = False
self.is_emoji = False
return segment.data # type: ignore
elif segment.type == "image":
self.is_voice = False
# 如果是base64图片数据
if isinstance(segment.data, str):
self.has_picid = True
self.is_picid = True
self.is_emoji = False
image_manager = get_image_manager()
# print(f"segment.data: {segment.data}")
_, processed_text = await image_manager.process_image(segment.data)
return processed_text
return "[发了一张图片,网卡了加载不出来]"
elif segment.type == "emoji":
self.has_emoji = True
self.is_emoji = True
self.is_picid = False
if isinstance(segment.data, str):
return await get_image_manager().get_emoji_description(segment.data)
return "[发了一个表情包,网卡了加载不出来]"
elif segment.type == "voice":
self.has_picid = False
self.is_picid = False
self.is_emoji = False
self.is_voice = True
if isinstance(segment.data, str):
return await get_voice_text(segment.data)
return "[发了一段语音,网卡了加载不出来]"
elif segment.type == "mention_bot":
self.is_voice = False
self.is_picid = False
self.is_emoji = False
self.is_mentioned = float(segment.data) # type: ignore
return ""
elif segment.type == "priority_info":
self.is_voice = False
self.is_picid = False
self.is_emoji = False
if isinstance(segment.data, dict):
# 处理优先级信息
self.priority_mode = "priority"
self.priority_info = segment.data
"""
{
'message_type': 'vip', # vip or normal
'message_priority': 1.0, # 优先级大为优先float
}
"""
return ""
elif segment.type == "gift":
self.is_voice = False
self.is_gift = True
# 解析gift_info格式为"名称:数量"
name, count = segment.data.split(":", 1) # type: ignore
self.gift_info = segment.data
self.gift_name = name.strip()
self.gift_count = int(count.strip())
return ""
elif segment.type == "voice_done":
msg_id = segment.data
logger.info(f"voice_done: {msg_id}")
self.voice_done = msg_id
return ""
elif segment.type == "superchat":
self.is_superchat = True
self.superchat_info = segment.data
price, message_text = segment.data.split(":", 1) # type: ignore
self.superchat_price = price.strip()
self.superchat_message_text = message_text.strip()
self.processed_plain_text = str(self.superchat_message_text)
self.processed_plain_text += (
f"(注意:这是一条超级弹幕信息,价值{self.superchat_price}元,请你认真回复)"
)
return self.processed_plain_text
elif segment.type == "screen":
self.is_screen = True
self.screen_info = segment.data
return "屏幕信息"
else:
return ""
except Exception as e:
logger.error(f"处理消息段失败: {str(e)}, 类型: {segment.type}, 数据: {segment.data}")
return f"[处理失败的{segment.type}消息]"
@dataclass
class MessageProcessBase(Message):
"""消息处理基类,用于处理中和发送中的消息"""

View File

@@ -18,7 +18,7 @@ class MessageStorage:
if isinstance(keywords, list):
return json.dumps(keywords, ensure_ascii=False)
return "[]"
@staticmethod
def _deserialize_keywords(keywords_str: str) -> list:
"""将JSON字符串反序列化为关键词列表"""
@@ -33,7 +33,6 @@ class MessageStorage:
async def store_message(message: Union[MessageSending, MessageRecv], chat_stream: ChatStream) -> None:
"""存储消息到数据库"""
try:
# 莫越权 救世啊
pattern = r"<MainRule>.*?</MainRule>|<schedule>.*?</schedule>|<UserMessage>.*?</UserMessage>"
# print(message)
@@ -85,7 +84,7 @@ class MessageStorage:
key_words = MessageStorage._serialize_keywords(message.key_words)
key_words_lite = MessageStorage._serialize_keywords(message.key_words_lite)
selected_expressions = ""
chat_info_dict = chat_stream.to_dict()
user_info_dict = message.message_info.user_info.to_dict() # type: ignore
@@ -143,31 +142,26 @@ class MessageStorage:
# 如果需要其他存储相关的函数,可以在这里添加
@staticmethod
async def update_message(
message: MessageRecv,
) -> None: # 用于实时更新数据库的自身发送消息ID目前能处理text,reply,image和emoji
"""更新最新一条匹配消息的message_id"""
def update_message(mmc_message_id: str | None, qq_message_id: str | None) -> bool:
"""实时更新数据库的自身发送消息ID"""
try:
if message.message_segment.type == "notify":
mmc_message_id = message.message_segment.data.get("echo") # type: ignore
qq_message_id = message.message_segment.data.get("actual_id") # type: ignore
else:
logger.info(f"更新消息ID错误seg类型为{message.message_segment.type}")
return
if not qq_message_id:
logger.info("消息不存在message_id无法更新")
return
return False
if matched_message := (
Messages.select().where((Messages.message_id == mmc_message_id)).order_by(Messages.time.desc()).first()
):
# 更新找到的消息记录
Messages.update(message_id=qq_message_id).where(Messages.id == matched_message.id).execute() # type: ignore
logger.debug(f"更新消息ID成功: {matched_message.message_id} -> {qq_message_id}")
return True
else:
logger.debug("未找到匹配的消息")
return False
except Exception as e:
logger.error(f"更新消息ID失败: {e}")
return False
@staticmethod
def replace_image_descriptions(text: str) -> str:

View File

@@ -2,6 +2,7 @@ import asyncio
import traceback
from rich.traceback import install
from maim_message import Seg
from src.common.message.api import get_global_api
from src.common.logger import get_logger
@@ -15,7 +16,7 @@ install(extra_lines=3)
logger = get_logger("sender")
async def send_message(message: MessageSending, show_log=True) -> bool:
async def _send_message(message: MessageSending, show_log=True) -> bool:
"""合并后的消息发送函数包含WS发送和日志记录"""
message_preview = truncate_message(message.processed_plain_text, max_length=200)
@@ -32,7 +33,7 @@ async def send_message(message: MessageSending, show_log=True) -> bool:
raise e # 重新抛出其他异常
class HeartFCSender:
class UniversalMessageSender:
"""管理消息的注册、即时处理、发送和存储,并跟踪思考状态。"""
def __init__(self):
@@ -66,8 +67,36 @@ class HeartFCSender:
message.build_reply()
logger.debug(f"[{chat_id}] 选择回复引用消息: {message.processed_plain_text[:20]}...")
from src.plugin_system.core.events_manager import events_manager
from src.plugin_system.base.component_types import EventType
continue_flag, modified_message = await events_manager.handle_mai_events(
EventType.POST_SEND_PRE_PROCESS, message=message, stream_id=chat_id
)
if not continue_flag:
logger.info(f"[{chat_id}] 消息发送被插件取消: {str(message.message_segment)[:100]}...")
return False
if modified_message:
if modified_message._modify_flags.modify_message_segments:
message.message_segment = Seg(type="seglist", data=modified_message.message_segments)
if modified_message._modify_flags.modify_plain_text:
logger.warning(f"[{chat_id}] 插件修改了消息的纯文本内容,可能导致此内容被覆盖。")
message.processed_plain_text = modified_message.plain_text
await message.process()
continue_flag, modified_message = await events_manager.handle_mai_events(
EventType.POST_SEND, message=message, stream_id=chat_id
)
if not continue_flag:
logger.info(f"[{chat_id}] 消息发送被插件取消: {str(message.message_segment)[:100]}...")
return False
if modified_message:
if modified_message._modify_flags.modify_message_segments:
message.message_segment = Seg(type="seglist", data=modified_message.message_segments)
if modified_message._modify_flags.modify_plain_text:
message.processed_plain_text = modified_message.plain_text
if typing:
typing_time = calculate_typing_time(
input_string=message.processed_plain_text,
@@ -76,10 +105,22 @@ class HeartFCSender:
)
await asyncio.sleep(typing_time)
sent_msg = await send_message(message, show_log=show_log)
sent_msg = await _send_message(message, show_log=show_log)
if not sent_msg:
return False
continue_flag, modified_message = await events_manager.handle_mai_events(
EventType.AFTER_SEND, message=message, stream_id=chat_id
)
if not continue_flag:
logger.info(f"[{chat_id}] 消息发送后续处理被插件取消: {str(message.message_segment)[:100]}...")
return True
if modified_message:
if modified_message._modify_flags.modify_message_segments:
message.message_segment = Seg(type="seglist", data=modified_message.message_segments)
if modified_message._modify_flags.modify_plain_text:
message.processed_plain_text = modified_message.plain_text
if storage_message:
await self.storage.store_message(message, message.chat_stream)

View File

@@ -32,7 +32,7 @@ class ActionManager:
self,
action_name: str,
action_data: dict,
reasoning: str,
action_reasoning: str,
cycle_timers: dict,
thinking_id: str,
chat_stream: ChatStream,
@@ -46,7 +46,7 @@ class ActionManager:
Args:
action_name: 动作名称
action_data: 动作数据
reasoning: 执行理由
action_reasoning: 执行理由
cycle_timers: 计时器字典
thinking_id: 思考ID
chat_stream: 聊天流
@@ -77,7 +77,7 @@ class ActionManager:
# 创建动作实例
instance = component_class(
action_data=action_data,
reasoning=reasoning,
action_reasoning=action_reasoning,
cycle_timers=cycle_timers,
thinking_id=thinking_id,
chat_stream=chat_stream,
@@ -124,4 +124,4 @@ class ActionManager:
"""恢复到默认动作集"""
actions_to_restore = list(self._using_actions.keys())
self._using_actions = component_registry.get_default_actions()
logger.debug(f"恢复动作集: 从 {actions_to_restore} 恢复到默认动作集 {list(self._using_actions.keys())}")
logger.debug(f"恢复动作集: 从 {actions_to_restore} 恢复到默认动作集 {list(self._using_actions.keys())}")

View File

@@ -103,25 +103,23 @@ class ActionModifier:
self.action_manager.remove_action_from_using(action_name)
logger.debug(f"{self.log_prefix}阶段二移除动作: {action_name},原因: {reason}")
# === 第三阶段:激活类型判定 ===
# if chat_content is not None:
# logger.debug(f"{self.log_prefix}开始激活类型判定阶段")
# logger.debug(f"{self.log_prefix}开始激活类型判定阶段")
# 获取当前使用的动作集(经过第一阶段处理)
# current_using_actions = self.action_manager.get_using_actions()
# 获取当前使用的动作集(经过第一阶段处理)
# current_using_actions = self.action_manager.get_using_actions()
# 获取因激活类型判定而需要移除的动作
# removals_s3 = await self._get_deactivated_actions_by_type(
# current_using_actions,
# chat_content,
# )
# 获取因激活类型判定而需要移除的动作
# removals_s3 = await self._get_deactivated_actions_by_type(
# current_using_actions,
# chat_content,
# )
# 应用第三阶段的移除
# for action_name, reason in removals_s3:
# self.action_manager.remove_action_from_using(action_name)
# logger.debug(f"{self.log_prefix}阶段三移除动作: {action_name},原因: {reason}")
# 应用第三阶段的移除
# for action_name, reason in removals_s3:
# self.action_manager.remove_action_from_using(action_name)
# logger.debug(f"{self.log_prefix}阶段三移除动作: {action_name},原因: {reason}")
# === 统一日志记录 ===
all_removals = removals_s1 + removals_s2
@@ -131,9 +129,7 @@ class ActionModifier:
available_actions = list(self.action_manager.get_using_actions().keys())
available_actions_text = "".join(available_actions) if available_actions else ""
logger.debug(
f"{self.log_prefix} 当前可用动作: {available_actions_text}||移除: {removals_summary}"
)
logger.debug(f"{self.log_prefix} 当前可用动作: {available_actions_text}||移除: {removals_summary}")
def _check_action_associated_types(self, all_actions: Dict[str, ActionInfo], chat_context: ChatMessageContext):
type_mismatched_actions: List[Tuple[str, str]] = []

File diff suppressed because it is too large Load Diff

View File

@@ -6,7 +6,8 @@ import re
from typing import List, Optional, Dict, Any, Tuple
from datetime import datetime
from src.mais4u.mai_think import mai_thinking_manager
from src.memory_system.Memory_chest import global_memory_chest
from src.memory_system.questions import global_conflict_tracker
from src.common.logger import get_logger
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.data_models.info_data_model import ActionPlannerInfo
@@ -15,124 +16,36 @@ from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
from src.chat.message_receive.message import UserInfo, Seg, MessageRecv, MessageSending
from src.chat.message_receive.chat_stream import ChatStream
from src.chat.message_receive.uni_message_sender import HeartFCSender
from src.chat.message_receive.uni_message_sender import UniversalMessageSender
from src.chat.utils.timer_calculator import Timer # <--- Import Timer
from src.chat.utils.utils import get_chat_type_and_target_info
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.prompt_builder import global_prompt_manager
from src.mood.mood_manager import mood_manager
from src.chat.utils.chat_message_builder import (
build_readable_messages,
get_raw_msg_before_timestamp_with_chat,
replace_user_references,
)
from src.chat.express.expression_selector import expression_selector
from src.chat.memory_system.memory_activator import MemoryActivator
from src.mood.mood_manager import mood_manager
from src.person_info.person_info import Person, is_person_known
from src.plugin_system.apis.message_api import translate_pid_to_description
# from src.memory_system.memory_activator import MemoryActivator
from src.person_info.person_info import Person
from src.plugin_system.base.component_types import ActionInfo, EventType
from src.plugin_system.apis import llm_api
from src.chat.replyer.prompt.lpmm_prompt import init_lpmm_prompt
from src.chat.replyer.prompt.replyer_prompt import init_replyer_prompt
from src.chat.replyer.prompt.rewrite_prompt import init_rewrite_prompt
init_lpmm_prompt()
init_replyer_prompt()
init_rewrite_prompt()
logger = get_logger("replyer")
def init_prompt():
Prompt("你正在qq群里聊天下面是群里在聊的内容", "chat_target_group1")
Prompt("你正在和{sender_name}聊天,这是你们之前聊的内容:", "chat_target_private1")
Prompt("在群里聊天", "chat_target_group2")
Prompt("{sender_name}聊天", "chat_target_private2")
Prompt(
"""
{expression_habits_block}
{relation_info_block}
{chat_target}
{time_block}
{chat_info}
{identity}
你现在的心情是{mood_state}
你正在{chat_target_2},{reply_target_block}
你想要对上述的发言进行回复回复的具体内容原句{raw_reply}
原因是{reason}
现在请你将这条具体内容改写成一条适合在群聊中发送的回复消息
你需要使用合适的语法和句法参考聊天内容组织一条日常且口语化的回复请你修改你想表达的原句符合你的表达风格和语言习惯
{reply_style}
你可以完全重组回复保留最基本的表达含义就好但重组后保持语意通顺
{keywords_reaction_prompt}
{moderation_prompt}
不要输出多余内容(包括前后缀冒号和引号括号表情包emoji,at或 @等 )只输出一条回复就好
现在你说
""",
"default_expressor_prompt",
)
# s4u 风格的 prompt 模板
Prompt(
"""{identity}
你正在群聊中聊天你想要回复 {sender_name} 的发言同时也有其他用户会参与聊天你可以参考他们的回复内容但是你现在想回复{sender_name}的发言
{time_block}
{background_dialogue_prompt}
{core_dialogue_prompt}
{expression_habits_block}{tool_info_block}
{knowledge_prompt}{memory_block}{relation_info_block}
{extra_info_block}
{reply_target_block}
你的心情{mood_state}
{reply_style}
注意不要复读你说过的话
{keywords_reaction_prompt}
请注意不要输出多余内容(包括前后缀冒号和引号at或 @等 )只输出回复内容
{moderation_prompt}
不要输出多余内容(包括前后缀冒号和引号括号()表情包emoji,at或 @等 )只输出一条回复就好
现在你说""",
"replyer_prompt",
)
Prompt(
"""{identity}
{time_block}
你现在正在一个QQ群里聊天以下是正在进行的聊天内容
{background_dialogue_prompt}
{expression_habits_block}{tool_info_block}
{knowledge_prompt}{memory_block}{relation_info_block}
{extra_info_block}
你现在想补充说明你刚刚自己的发言内容{target}原因是{reason}
请你根据聊天内容组织一条新回复注意{target} 是刚刚你自己的发言你要在这基础上进一步发言请按照你自己的角度来继续进行回复
注意保持上下文的连贯性
你现在的心情是{mood_state}
{reply_style}
{keywords_reaction_prompt}
请注意不要输出多余内容(包括前后缀冒号和引号at或 @等 )只输出回复内容
{moderation_prompt}
不要输出多余内容(包括前后缀冒号和引号括号()表情包emoji,at或 @等 )只输出一条回复就好
现在你说
""",
"replyer_self_prompt",
)
Prompt(
"""
你是一个专门获取知识的助手你的名字是{bot_name}现在是{time_now}
群里正在进行的聊天内容
{chat_history}
现在{sender}发送了内容:{target_message},你想要回复ta
请仔细分析聊天内容考虑以下几点
1. 内容中是否包含需要查询信息的问题
2. 是否有明确的知识获取指令
If you need to use the search tool, please directly call the function "lpmm_search_knowledge". If you do not need to use any tool, simply output "No tool needed".
""",
name="lpmm_get_knowledge_prompt",
)
class DefaultReplyer:
def __init__(
self,
@@ -142,8 +55,8 @@ class DefaultReplyer:
self.express_model = LLMRequest(model_set=model_config.model_task_config.replyer, request_type=request_type)
self.chat_stream = chat_stream
self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_stream.stream_id)
self.heart_fc_sender = HeartFCSender()
self.memory_activator = MemoryActivator()
self.heart_fc_sender = UniversalMessageSender()
# self.memory_activator = MemoryActivator()
from src.plugin_system.core.tool_use import ToolExecutor # 延迟导入ToolExecutor不然会循环依赖
@@ -159,6 +72,7 @@ class DefaultReplyer:
from_plugin: bool = True,
stream_id: Optional[str] = None,
reply_message: Optional[DatabaseMessages] = None,
reply_time_point: Optional[float] = time.time(),
) -> Tuple[bool, LLMGenerationDataModel]:
# sourcery skip: merge-nested-ifs
"""
@@ -192,6 +106,7 @@ class DefaultReplyer:
enable_tool=enable_tool,
reply_message=reply_message,
reply_reason=reply_reason,
reply_time_point=reply_time_point,
)
llm_response.prompt = prompt
llm_response.selected_expressions = selected_expressions
@@ -202,10 +117,14 @@ class DefaultReplyer:
from src.plugin_system.core.events_manager import events_manager
if not from_plugin:
if not await events_manager.handle_mai_events(
continue_flag, modified_message = await events_manager.handle_mai_events(
EventType.POST_LLM, None, prompt, None, stream_id=stream_id
):
)
if not continue_flag:
raise UserWarning("插件于请求前中断了内容生成")
if modified_message and modified_message._modify_flags.modify_llm_prompt:
llm_response.prompt = modified_message.llm_prompt
prompt = str(modified_message.llm_prompt)
# 4. 调用 LLM 生成回复
content = None
@@ -219,10 +138,19 @@ class DefaultReplyer:
llm_response.reasoning = reasoning_content
llm_response.model = model_name
llm_response.tool_calls = tool_call
if not from_plugin and not await events_manager.handle_mai_events(
continue_flag, modified_message = await events_manager.handle_mai_events(
EventType.AFTER_LLM, None, prompt, llm_response, stream_id=stream_id
):
)
if not from_plugin and not continue_flag:
raise UserWarning("插件于请求后取消了内容生成")
if modified_message:
if modified_message._modify_flags.modify_llm_prompt:
logger.warning("警告插件在内容生成后才修改了prompt此修改不会生效")
llm_response.prompt = modified_message.llm_prompt # 虽然我不知道为什么在这里需要改prompt
if modified_message._modify_flags.modify_llm_response_content:
llm_response.content = modified_message.llm_response_content
if modified_message._modify_flags.modify_llm_response_reasoning:
llm_response.reasoning = modified_message.llm_response_reasoning
except UserWarning as e:
raise e
except Exception as llm_e:
@@ -293,24 +221,6 @@ class DefaultReplyer:
traceback.print_exc()
return False, llm_response
async def build_relation_info(self, sender: str, target: str):
if not global_config.relationship.enable_relationship:
return ""
if not sender:
return ""
if sender == global_config.bot.nickname:
return ""
# 获取用户ID
person = Person(person_name=sender)
if not is_person_known(person_name=sender):
logger.warning(f"未找到用户 {sender} 的ID跳过信息提取")
return f"你完全不认识{sender}不理解ta的相关信息。"
return person.build_relationship()
async def build_expression_habits(self, chat_history: str, target: str) -> Tuple[str, List[int]]:
# sourcery skip: for-append-to-extend
"""构建表达习惯块
@@ -348,46 +258,42 @@ class DefaultReplyer:
expression_habits_block = ""
expression_habits_title = ""
if style_habits_str.strip():
expression_habits_title = (
"你可以参考以下的语言习惯,当情景合适就使用,但不要生硬使用,以合理的方式结合到你的回复中:"
)
expression_habits_title = "在回复时,你可以参考以下的语言习惯,不要生硬使用:"
expression_habits_block += f"{style_habits_str}\n"
return f"{expression_habits_title}\n{expression_habits_block}", selected_ids
async def build_memory_block(self, chat_history: List[DatabaseMessages], target: str) -> str:
async def build_mood_state_prompt(self) -> str:
"""构建情绪状态提示"""
if not global_config.mood.enable_mood:
return ""
mood_state = await mood_manager.get_mood_by_chat_id(self.chat_stream.stream_id).get_mood()
return f"你现在的心情是:{mood_state}"
async def build_memory_block(self) -> str:
"""构建记忆块
Args:
chat_history: 聊天历史记录
target: 目标消息内容
Returns:
str: 记忆信息字符串
"""
# if not global_config.memory.enable_memory:
# return ""
if not global_config.memory.enable_memory:
if global_memory_chest.get_chat_memories_as_string(self.chat_stream.stream_id):
return f"你有以下记忆:\n{global_memory_chest.get_chat_memories_as_string(self.chat_stream.stream_id)}"
else:
return ""
instant_memory = None
running_memories = await self.memory_activator.activate_memory_with_chat_history(
target_message=target, chat_history=chat_history
)
running_memories = None
if not running_memories:
async def build_question_block(self) -> str:
"""构建问题块"""
# if not global_config.question.enable_question:
# return ""
questions = global_conflict_tracker.get_questions_by_chat_id(self.chat_stream.stream_id)
questions_str = ""
for question in questions:
questions_str += f"- {question.question}\n"
if questions_str:
return f"你在聊天中,有以下问题想要得到解答:\n{questions_str}"
else:
return ""
memory_str = "以下是当前在聊天中,你回忆起的记忆:\n"
for running_memory in running_memories:
keywords, content = running_memory
memory_str += f"- {keywords}{content}\n"
if instant_memory:
memory_str += f"- {instant_memory}\n"
return memory_str
async def build_tool_info(self, chat_history: str, sender: str, target: str, enable_tool: bool = True) -> str:
"""构建工具信息块
@@ -417,7 +323,7 @@ class DefaultReplyer:
content = tool_result.get("content", "")
result_type = tool_result.get("type", "tool_result")
tool_info_str += f"- 【{tool_name}{result_type}: {content}\n"
tool_info_str += f"- 【{tool_name}】: {content}\n"
tool_info_str += "以上是你获取到的实时信息,请在回复时参考这些信息。"
logger.info(f"获取到 {len(tool_results)} 个工具结果")
@@ -453,6 +359,64 @@ class DefaultReplyer:
target = parts[1].strip()
return sender, target
def _replace_picids_with_descriptions(self, text: str) -> str:
"""将文本中的[picid:xxx]替换为具体的图片描述
Args:
text: 包含picid标记的文本
Returns:
替换后的文本
"""
# 匹配 [picid:xxxxx] 格式
pic_pattern = r"\[picid:([^\]]+)\]"
def replace_pic_id(match: re.Match) -> str:
pic_id = match.group(1)
description = translate_pid_to_description(pic_id)
return f"[图片:{description}]"
return re.sub(pic_pattern, replace_pic_id, text)
def _analyze_target_content(self, target: str) -> Tuple[bool, bool, str, str]:
"""分析target内容类型基于原始picid格式
Args:
target: 目标消息内容包含[picid:xxx]格式
Returns:
Tuple[bool, bool, str, str]: (是否只包含图片, 是否包含文字, 图片部分, 文字部分)
"""
if not target or not target.strip():
return False, False, "", ""
# 检查是否只包含picid标记
picid_pattern = r"\[picid:[^\]]+\]"
picid_matches = re.findall(picid_pattern, target)
# 移除所有picid标记后检查是否还有文字内容
text_without_picids = re.sub(picid_pattern, "", target).strip()
has_only_pics = len(picid_matches) > 0 and not text_without_picids
has_text = bool(text_without_picids)
# 提取图片部分(转换为[图片:描述]格式)
pic_part = ""
if picid_matches:
pic_descriptions = []
for picid_match in picid_matches:
pic_id = picid_match[7:-1] # 提取picid:xxx中的xxx部分从第7个字符开始
description = translate_pid_to_description(pic_id)
logger.info(f"图片ID: {pic_id}, 描述: {description}")
# 如果description已经是[图片]格式,直接使用;否则包装为[图片:描述]格式
if description == "[图片]":
pic_descriptions.append(description)
else:
pic_descriptions.append(f"[图片:{description}]")
pic_part = "".join(pic_descriptions)
return has_only_pics, has_text, pic_part, text_without_picids
async def build_keywords_reaction_prompt(self, target: Optional[str]) -> str:
"""构建关键词反应提示
@@ -511,11 +475,10 @@ class DefaultReplyer:
duration = end_time - start_time
return name, result, duration
def build_s4u_chat_history_prompts(
def build_chat_history_prompts(
self, message_list_before_now: List[DatabaseMessages], target_user_id: str, sender: str
) -> Tuple[str, str]:
"""
构建 s4u 风格的分离对话 prompt
Args:
message_list_before_now: 历史消息列表
@@ -539,18 +502,6 @@ class DefaultReplyer:
except Exception as e:
logger.error(f"处理消息记录时出错: {msg}, 错误: {e}")
# 构建背景对话 prompt
all_dialogue_prompt = ""
if message_list_before_now:
latest_25_msgs = message_list_before_now[-int(global_config.chat.max_context_size) :]
all_dialogue_prompt_str = build_readable_messages(
latest_25_msgs,
replace_bot_name=True,
timestamp_mode="normal_no_YMD",
truncate=True,
)
all_dialogue_prompt = f"所有用户的发言:\n{all_dialogue_prompt_str}"
# 构建核心对话 prompt
core_dialogue_prompt = ""
if core_dialogue_list:
@@ -583,60 +534,30 @@ class DefaultReplyer:
--------------------------------
"""
# 构建背景对话 prompt
all_dialogue_prompt = ""
if message_list_before_now:
latest_25_msgs = message_list_before_now[-int(global_config.chat.max_context_size) :]
all_dialogue_prompt_str = build_readable_messages(
latest_25_msgs,
replace_bot_name=True,
timestamp_mode="normal_no_YMD",
truncate=True,
)
if core_dialogue_prompt:
all_dialogue_prompt = f"所有用户的发言:\n{all_dialogue_prompt_str}"
else:
all_dialogue_prompt = f"{all_dialogue_prompt_str}"
return core_dialogue_prompt, all_dialogue_prompt
def build_mai_think_context(
self,
chat_id: str,
memory_block: str,
relation_info: str,
time_block: str,
chat_target_1: str,
chat_target_2: str,
mood_prompt: str,
identity_block: str,
sender: str,
target: str,
chat_info: str,
) -> Any:
"""构建 mai_think 上下文信息
Args:
chat_id: 聊天ID
memory_block: 记忆块内容
relation_info: 关系信息
time_block: 时间块内容
chat_target_1: 聊天目标1
chat_target_2: 聊天目标2
mood_prompt: 情绪提示
identity_block: 身份块内容
sender: 发送者名称
target: 目标消息内容
chat_info: 聊天信息
Returns:
Any: mai_think 实例
"""
mai_think = mai_thinking_manager.get_mai_think(chat_id)
mai_think.memory_block = memory_block
mai_think.relation_info_block = relation_info
mai_think.time_block = time_block
mai_think.chat_target = chat_target_1
mai_think.chat_target_2 = chat_target_2
mai_think.chat_info = chat_info
mai_think.mood_state = mood_prompt
mai_think.identity = identity_block
mai_think.sender = sender
mai_think.target = target
return mai_think
async def build_actions_prompt(
self, available_actions: Dict[str, ActionInfo], chosen_actions_info: Optional[List[ActionPlannerInfo]] = None
) -> str:
"""构建动作提示"""
action_descriptions = ""
skip_names = ["emoji","build_memory","build_relation","reply"]
skip_names = ["emoji", "build_memory", "build_relation", "reply"]
if available_actions:
action_descriptions = "除了进行回复之外,你可以做以下这些动作,不过这些动作由另一个模型决定,:\n"
for action_name, action_info in available_actions.items():
@@ -673,19 +594,18 @@ class DefaultReplyer:
else:
bot_nickname = ""
prompt_personality = (
f"{global_config.personality.personality};"
)
prompt_personality = f"{global_config.personality.personality};"
return f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}"
async def build_prompt_reply_context(
self,
reply_message: DatabaseMessages,
reply_message: Optional[DatabaseMessages] = None,
extra_info: str = "",
reply_reason: str = "",
available_actions: Optional[Dict[str, ActionInfo]] = None,
chosen_actions: Optional[List[ActionPlannerInfo]] = None,
enable_tool: bool = True,
reply_time_point: Optional[float] = time.time(),
) -> Tuple[str, List[int]]:
"""
构建回复器上下文
@@ -720,26 +640,46 @@ class DefaultReplyer:
sender = person_name
target = reply_message.processed_plain_text
mood_prompt: str = ""
if global_config.mood.enable_mood:
chat_mood = mood_manager.get_mood_by_chat_id(chat_id)
mood_prompt = chat_mood.mood_state
target = replace_user_references(target, chat_stream.platform, replace_bot_name=True)
target = re.sub(r"\\[picid:[^\\]]+\\]", "[图片]", target)
# 在picid替换之前分析内容类型防止prompt注入
has_only_pics, has_text, pic_part, text_part = self._analyze_target_content(target)
# 将[picid:xxx]替换为具体的图片描述
target = self._replace_picids_with_descriptions(target)
message_list_before_now_long = get_raw_msg_before_timestamp_with_chat(
chat_id=chat_id,
timestamp=time.time(),
timestamp=reply_time_point,
limit=global_config.chat.max_context_size * 1,
)
message_list_before_short = get_raw_msg_before_timestamp_with_chat(
chat_id=chat_id,
timestamp=time.time(),
timestamp=reply_time_point,
limit=int(global_config.chat.max_context_size * 0.33),
)
person_list_short: List[Person] = []
for msg in message_list_before_short:
if (
global_config.bot.qq_account == msg.user_info.user_id
and global_config.bot.platform == msg.user_info.platform
):
continue
if (
reply_message
and reply_message.user_info.user_id == msg.user_info.user_id
and reply_message.user_info.platform == msg.user_info.platform
):
continue
person = Person(platform=msg.user_info.platform, user_id=msg.user_info.user_id)
if person.is_known:
person_list_short.append(person)
# for person in person_list_short:
# print(person.person_name)
chat_talking_prompt_short = build_readable_messages(
message_list_before_short,
replace_bot_name=True,
@@ -753,25 +693,29 @@ class DefaultReplyer:
self._time_and_run_task(
self.build_expression_habits(chat_talking_prompt_short, target), "expression_habits"
),
self._time_and_run_task(self.build_relation_info(sender, target), "relation_info"),
self._time_and_run_task(self.build_memory_block(message_list_before_short, target), "memory_block"),
self._time_and_run_task(self.build_memory_block(), "memory_block"),
self._time_and_run_task(
self.build_tool_info(chat_talking_prompt_short, sender, target, enable_tool=enable_tool), "tool_info"
),
self._time_and_run_task(self.get_prompt_info(chat_talking_prompt_short, sender, target), "prompt_info"),
self._time_and_run_task(self.build_actions_prompt(available_actions, chosen_actions), "actions_info"),
self._time_and_run_task(self.build_personality_prompt(), "personality_prompt"),
self._time_and_run_task(self.build_mood_state_prompt(), "mood_state_prompt"),
self._time_and_run_task(self.build_question_block(), "question_block"),
)
# 任务名称中英文映射
task_name_mapping = {
"expression_habits": "选取表达方式",
"relation_info": "感受关系",
"memory_block": "回忆",
# "memory_block": "回忆",
"memory_block": "记忆",
"tool_info": "使用工具",
"prompt_info": "获取知识",
"actions_info": "动作信息",
"personality_prompt": "人格信息",
"mood_state_prompt": "情绪状态",
"question_block": "问题",
}
# 处理结果
@@ -794,13 +738,16 @@ class DefaultReplyer:
expression_habits_block, selected_expressions = results_dict["expression_habits"]
expression_habits_block: str
selected_expressions: List[int]
relation_info: str = results_dict["relation_info"]
# relation_info: str = results_dict["relation_info"]
# memory_block: str = results_dict["memory_block"]
memory_block: str = results_dict["memory_block"]
tool_info: str = results_dict["tool_info"]
prompt_info: str = results_dict["prompt_info"] # 直接使用格式化后的结果
actions_info: str = results_dict["actions_info"]
personality_prompt: str = results_dict["personality_prompt"]
question_block: str = results_dict["question_block"]
keywords_reaction_prompt = await self.build_keywords_reaction_prompt(target)
mood_state_prompt: str = results_dict["mood_state_prompt"]
if extra_info:
extra_info_block = f"以下是你在回复时需要参考的信息,现在请你阅读以下内容,进行决策\n{extra_info}\n以上是你在回复时需要参考的信息,现在请你阅读以下内容,进行决策"
@@ -811,69 +758,50 @@ class DefaultReplyer:
moderation_prompt_block = "请不要输出违法违规内容,不要输出色情,暴力,政治相关内容,如有敏感内容,请规避。"
if sender:
if is_group_chat:
reply_target_block = (
f"现在{sender}说的:{target}。引起了你的注意,你想要在群里发言或者回复这条消息。原因是{reply_reason}"
)
else: # private chat
reply_target_block = (
f"现在{sender}说的:{target}。引起了你的注意,针对这条消息回复。原因是{reply_reason}"
)
# 使用预先分析的内容类型结果
if has_only_pics and not has_text:
# 只包含图片
reply_target_block = f"现在{sender}发送的图片:{pic_part}。引起了你的注意"
elif has_text and pic_part:
# 既有图片又有文字
reply_target_block = f"现在{sender}发送了图片:{pic_part},并说:{text_part}。引起了你的注意"
elif has_text:
# 只包含文字
reply_target_block = f"现在{sender}说的:{text_part}。引起了你的注意"
else:
# 其他情况(空内容等)
reply_target_block = f"现在{sender}说的:{target}。引起了你的注意"
else:
reply_target_block = ""
# 构建分离的对话 prompt
core_dialogue_prompt, background_dialogue_prompt = self.build_s4u_chat_history_prompts(
core_dialogue_prompt, background_dialogue_prompt = self.build_chat_history_prompts(
message_list_before_now_long, user_id, sender
)
if global_config.bot.qq_account == user_id and platform == global_config.bot.platform:
return await global_prompt_manager.format_prompt(
"replyer_self_prompt",
expression_habits_block=expression_habits_block,
tool_info_block=tool_info,
knowledge_prompt=prompt_info,
memory_block=memory_block,
relation_info_block=relation_info,
extra_info_block=extra_info_block,
identity=personality_prompt,
action_descriptions=actions_info,
mood_state=mood_prompt,
background_dialogue_prompt=background_dialogue_prompt,
time_block=time_block,
target=target,
reason=reply_reason,
reply_style=global_config.personality.reply_style,
keywords_reaction_prompt=keywords_reaction_prompt,
moderation_prompt=moderation_prompt_block,
), selected_expressions
else:
return await global_prompt_manager.format_prompt(
"replyer_prompt",
expression_habits_block=expression_habits_block,
tool_info_block=tool_info,
knowledge_prompt=prompt_info,
memory_block=memory_block,
relation_info_block=relation_info,
extra_info_block=extra_info_block,
identity=personality_prompt,
action_descriptions=actions_info,
sender_name=sender,
mood_state=mood_prompt,
background_dialogue_prompt=background_dialogue_prompt,
time_block=time_block,
core_dialogue_prompt=core_dialogue_prompt,
reply_target_block=reply_target_block,
reply_style=global_config.personality.reply_style,
keywords_reaction_prompt=keywords_reaction_prompt,
moderation_prompt=moderation_prompt_block,
), selected_expressions
return await global_prompt_manager.format_prompt(
"replyer_prompt",
expression_habits_block=expression_habits_block,
tool_info_block=tool_info,
memory_block=memory_block,
knowledge_prompt=prompt_info,
mood_state=mood_state_prompt,
# memory_block=memory_block,
# relation_info_block=relation_info,
extra_info_block=extra_info_block,
identity=personality_prompt,
action_descriptions=actions_info,
sender_name=sender,
background_dialogue_prompt=background_dialogue_prompt,
time_block=time_block,
core_dialogue_prompt=core_dialogue_prompt,
reply_target_block=reply_target_block,
reply_style=global_config.personality.reply_style,
keywords_reaction_prompt=keywords_reaction_prompt,
moderation_prompt=moderation_prompt_block,
question_block=question_block,
), selected_expressions
async def build_prompt_rewrite_context(
self,
@@ -887,14 +815,12 @@ class DefaultReplyer:
sender, target = self._parse_reply_target(reply_to)
target = replace_user_references(target, chat_stream.platform, replace_bot_name=True)
target = re.sub(r"\\[picid:[^\\]]+\\]", "[图片]", target)
# 添加情绪状态获取
if global_config.mood.enable_mood:
chat_mood = mood_manager.get_mood_by_chat_id(chat_id)
mood_prompt = chat_mood.mood_state
else:
mood_prompt = ""
# 在picid替换之前分析内容类型防止prompt注入
has_only_pics, has_text, pic_part, text_part = self._analyze_target_content(target)
# 将[picid:xxx]替换为具体的图片描述
target = self._replace_picids_with_descriptions(target)
message_list_before_now_half = get_raw_msg_before_timestamp_with_chat(
chat_id=chat_id,
@@ -910,9 +836,8 @@ class DefaultReplyer:
)
# 并行执行2个构建任务
(expression_habits_block, _), relation_info, personality_prompt = await asyncio.gather(
(expression_habits_block, _), personality_prompt = await asyncio.gather(
self.build_expression_habits(chat_talking_prompt_half, target),
self.build_relation_info(sender, target),
self.build_personality_prompt(),
)
@@ -925,18 +850,39 @@ class DefaultReplyer:
)
if sender and target:
# 使用预先分析的内容类型结果
if is_group_chat:
if sender:
reply_target_block = (
f"现在{sender}说的:{target}。引起了你的注意,你想要在群里发言或者回复这条消息。"
)
if has_only_pics and not has_text:
# 只包含图片
reply_target_block = (
f"现在{sender}发送的图片:{pic_part}。引起了你的注意,你想要在群里发言或者回复这条消息。"
)
elif has_text and pic_part:
# 既有图片又有文字
reply_target_block = (
f"现在{sender}发送了图片:{pic_part},并说:{text_part}。引起了你的注意,你想要在群里发言或者回复这条消息。"
)
else:
# 只包含文字
reply_target_block = (
f"现在{sender}说的:{text_part}。引起了你的注意,你想要在群里发言或者回复这条消息。"
)
elif target:
reply_target_block = f"现在{target}引起了你的注意,你想要在群里发言或者回复这条消息。"
else:
reply_target_block = "现在,你想要在群里发言或者回复消息。"
else: # private chat
if sender:
reply_target_block = f"现在{sender}说的:{target}。引起了你的注意,针对这条消息回复。"
if has_only_pics and not has_text:
# 只包含图片
reply_target_block = f"现在{sender}发送的图片:{pic_part}。引起了你的注意,针对这条消息回复。"
elif has_text and pic_part:
# 既有图片又有文字
reply_target_block = f"现在{sender}发送了图片:{pic_part},并说:{text_part}。引起了你的注意,针对这条消息回复。"
else:
# 只包含文字
reply_target_block = f"现在{sender}说的:{text_part}。引起了你的注意,针对这条消息回复。"
elif target:
reply_target_block = f"现在{target}引起了你的注意,针对这条消息回复。"
else:
@@ -963,7 +909,7 @@ class DefaultReplyer:
return await global_prompt_manager.format_prompt(
template_name,
expression_habits_block=expression_habits_block,
relation_info_block=relation_info,
# relation_info_block=relation_info,
chat_target=chat_target_1,
time_block=time_block,
chat_info=chat_talking_prompt_half,
@@ -972,7 +918,6 @@ class DefaultReplyer:
reply_target_block=reply_target_block,
raw_reply=raw_reply,
reason=reason,
mood_state=mood_prompt, # 添加情绪状态参数
reply_style=global_config.personality.reply_style,
keywords_reaction_prompt=keywords_reaction_prompt,
moderation_prompt=moderation_prompt_block,
@@ -1015,20 +960,18 @@ class DefaultReplyer:
async def llm_generate_content(self, prompt: str):
with Timer("LLM生成", {}): # 内部计时器,可选保留
# 直接使用已初始化的模型实例
logger.info(f"使用模型集生成回复: {', '.join(map(str, self.express_model.model_for_task.model_list))}")
logger.info(f"\n{prompt}\n")
if global_config.debug.show_prompt:
logger.info(f"\n{prompt}\n")
else:
logger.debug(f"\n{prompt}\n")
logger.debug(f"\nreplyer_Prompt:{prompt}\n")
content, (reasoning_content, model_name, tool_calls) = await self.express_model.generate_response_async(
prompt
)
logger.debug(f"replyer生成内容: {content}")
logger.info(f"使用 {model_name} 生成回复内容: {content}")
return content, reasoning_content, model_name, tool_calls
async def get_prompt_info(self, message: str, sender: str, target: str):
@@ -1115,6 +1058,3 @@ def weighted_sample_no_replacement(items, weights, k) -> list:
pool.pop(idx)
break
return selected
init_prompt()

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,20 @@
from src.chat.utils.prompt_builder import Prompt
# from src.memory_system.memory_activator import MemoryActivator
def init_lpmm_prompt():
Prompt(
"""
你是一个专门获取知识的助手。你的名字是{bot_name}。现在是{time_now}
群里正在进行的聊天内容:
{chat_history}
现在,{sender}发送了内容:{target_message},你想要回复ta。
请仔细分析聊天内容,考虑以下几点:
1. 内容中是否包含需要查询信息的问题
2. 是否有明确的知识获取指令
If you need to use the search tool, please directly call the function "lpmm_search_knowledge". If you do not need to use any tool, simply output "No tool needed".
""",
name="lpmm_get_knowledge_prompt",
)

View File

@@ -0,0 +1,71 @@
from src.chat.utils.prompt_builder import Prompt
# from src.chat.memory_system.memory_activator import MemoryActivator
def init_replyer_prompt():
Prompt("你正在qq群里聊天下面是群里正在聊的内容:", "chat_target_group1")
Prompt("你正在和{sender_name}聊天,这是你们之前聊的内容:", "chat_target_private1")
Prompt("正在群里聊天", "chat_target_group2")
Prompt("{sender_name}聊天", "chat_target_private2")
Prompt(
"""{knowledge_prompt}{tool_info_block}{extra_info_block}
{expression_habits_block}{memory_block}{question_block}
你正在qq群里聊天下面是群里正在聊的内容:
{time_block}
{background_dialogue_prompt}
{core_dialogue_prompt}
{reply_target_block}
{identity}
你正在群里聊天,现在请你读读之前的聊天记录,然后给出日常且口语化的回复,平淡一些,{mood_state}
尽量简短一些。{keywords_reaction_prompt}请注意把握聊天内容,不要回复的太有条理,可以有个性。
{reply_style}
请注意不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只输出一句回复内容就好。
{moderation_prompt}不要输出多余内容(包括前后缀冒号和引号括号表情包at或 @等 )。
现在,你说:""",
"replyer_prompt",
)
Prompt(
"""{knowledge_prompt}{tool_info_block}{extra_info_block}
{expression_habits_block}{memory_block}
你正在和{sender_name}聊天,这是你们之前聊的内容:
{time_block}
{dialogue_prompt}
{reply_target_block}
{identity}
你正在和{sender_name}聊天,现在请你读读之前的聊天记录,然后给出日常且口语化的回复,平淡一些,{mood_state}
尽量简短一些。{keywords_reaction_prompt}请注意把握聊天内容,不要回复的太有条理,可以有个性。
{reply_style}
请注意不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只输出回复内容。
{moderation_prompt}不要输出多余内容(包括前后缀冒号和引号括号表情包at或 @等 )。""",
"private_replyer_prompt",
)
Prompt(
"""{knowledge_prompt}{tool_info_block}{extra_info_block}
{expression_habits_block}{memory_block}
你正在和{sender_name}聊天,这是你们之前聊的内容:
{time_block}
{dialogue_prompt}
你现在想补充说明你刚刚自己的发言内容:{target},原因是{reason}
请你根据聊天内容,组织一条新回复。注意,{target} 是刚刚你自己的发言,你要在这基础上进一步发言,请按照你自己的角度来继续进行回复。注意保持上下文的连贯性。{mood_state}
{identity}
尽量简短一些。{keywords_reaction_prompt}请注意把握聊天内容,不要回复的太有条理,可以有个性。
{reply_style}
请注意不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只输出回复内容。
{moderation_prompt}不要输出多余内容(包括冒号和引号括号表情包at或 @等 )。
""",
"private_replyer_self_prompt",
)

View File

@@ -0,0 +1,31 @@
from src.chat.utils.prompt_builder import Prompt
# from src.memory_system.memory_activator import MemoryActivator
def init_rewrite_prompt():
Prompt("你正在qq群里聊天下面是群里正在聊的内容:", "chat_target_group1")
Prompt("你正在和{sender_name}聊天,这是你们之前聊的内容:", "chat_target_private1")
Prompt("正在群里聊天", "chat_target_group2")
Prompt("{sender_name}聊天", "chat_target_private2")
Prompt(
"""
{expression_habits_block}
{chat_target}
{chat_info}
{identity}
你正在{chat_target_2},{reply_target_block}
现在请你对这句内容进行改写,请你参考上述内容进行改写,原句是:{raw_reply}
原因是:{reason}
现在请你将这条具体内容改写成一条适合在群聊中发送的回复消息。
你需要使用合适的语法和句法,参考聊天内容,组织一条日常且口语化的回复。请你修改你想表达的原句,符合你的表达风格和语言习惯
{reply_style}
你可以完全重组回复,保留最基本的表达含义就好,但重组后保持语意通顺。
{keywords_reaction_prompt}
{moderation_prompt}
不要输出多余内容(包括冒号和引号表情包emoji,at或 @等 ),只输出一条回复就好。
改写后的回复:
""",
"default_expressor_prompt",
)

View File

@@ -2,21 +2,22 @@ from typing import Dict, Optional
from src.common.logger import get_logger
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
from src.chat.replyer.default_generator import DefaultReplyer
from src.chat.replyer.group_generator import DefaultReplyer
from src.chat.replyer.private_generator import PrivateReplyer
logger = get_logger("ReplyerManager")
class ReplyerManager:
def __init__(self):
self._repliers: Dict[str, DefaultReplyer] = {}
self._repliers: Dict[str, DefaultReplyer | PrivateReplyer] = {}
def get_replyer(
self,
chat_stream: Optional[ChatStream] = None,
chat_id: Optional[str] = None,
request_type: str = "replyer",
) -> Optional[DefaultReplyer]:
) -> Optional[DefaultReplyer | PrivateReplyer]:
"""
获取或创建回复器实例。
@@ -46,10 +47,17 @@ class ReplyerManager:
return None
# model_configs 只在此时(初始化时)生效
replyer = DefaultReplyer(
chat_stream=target_stream,
request_type=request_type,
)
if target_stream.group_info:
replyer = DefaultReplyer(
chat_stream=target_stream,
request_type=request_type,
)
else:
replyer = PrivateReplyer(
chat_stream=target_stream,
request_type=request_type,
)
self._repliers[stream_id] = replyer
return replyer

View File

@@ -2,7 +2,7 @@ import time
import random
import re
from typing import List, Dict, Any, Tuple, Optional, Callable
from typing import List, Dict, Any, Tuple, Optional, Callable, Iterable
from rich.traceback import install
from src.config.config import global_config
@@ -124,6 +124,7 @@ def get_raw_msg_by_timestamp_with_chat(
# 只有当 limit 为 0 时才应用外部 sort
sort_order = [("time", 1)] if limit == 0 else None
# 直接将 limit_mode 传递给 find_messages
# print(f"get_raw_msg_by_timestamp_with_chat: {chat_id}, {timestamp_start}, {timestamp_end}, {limit}, {limit_mode}, {filter_bot}, {filter_command}")
return find_messages(
message_filter=filter_query,
sort=sort_order,
@@ -215,6 +216,7 @@ def get_actions_by_timestamp_with_chat(
chat_id=action.chat_id,
chat_info_stream_id=action.chat_info_stream_id,
chat_info_platform=action.chat_info_platform,
action_reasoning=action.action_reasoning,
)
for action in actions
]
@@ -417,12 +419,6 @@ def _build_readable_messages_internal(
timestamp = message.time
content = message.display_message or message.processed_plain_text or ""
# 向下兼容
if "" in content:
content = content.replace("", "")
if "" in content:
content = content.replace("", "")
# 处理图片ID
if show_pic:
content = process_pic_ids(content)
@@ -564,14 +560,12 @@ def build_readable_actions(actions: List[DatabaseActionRecords], mode: str = "re
output_lines = []
current_time = time.time()
# The get functions return actions sorted ascending by time. Let's reverse it to show newest first.
# sorted_actions = sorted(actions, key=lambda x: x.get("time", 0), reverse=True)
for action in actions:
action_time = action.time or current_time
action_name = action.action_name or "未知动作"
# action_reason = action.get(action_data")
if action_name in ["no_action", "no_action"]:
if action_name in ["no_reply", "no_reply"]:
continue
action_prompt_display = action.action_prompt_display or "无具体内容"
@@ -593,6 +587,7 @@ def build_readable_actions(actions: List[DatabaseActionRecords], mode: str = "re
line = f"{time_ago_str},你使用了“{action_name}”,具体内容是:“{action_prompt_display}"
output_lines.append(line)
return "\n".join(output_lines)
@@ -628,6 +623,7 @@ def build_readable_messages_with_id(
truncate: bool = False,
show_actions: bool = False,
show_pic: bool = True,
remove_emoji_stickers: bool = False,
) -> Tuple[str, List[Tuple[str, DatabaseMessages]]]:
"""
将消息列表转换为可读的文本格式,并返回原始(时间戳, 昵称, 内容)列表。
@@ -644,6 +640,7 @@ def build_readable_messages_with_id(
show_pic=show_pic,
read_mark=read_mark,
message_id_list=message_id_list,
remove_emoji_stickers=remove_emoji_stickers,
)
return formatted_string, message_id_list
@@ -658,6 +655,7 @@ def build_readable_messages(
show_actions: bool = False,
show_pic: bool = True,
message_id_list: Optional[List[Tuple[str, DatabaseMessages]]] = None,
remove_emoji_stickers: bool = False,
) -> str: # sourcery skip: extract-method
"""
将消息列表转换为可读的文本格式。
@@ -672,13 +670,40 @@ def build_readable_messages(
read_mark: 已读标记时间戳
truncate: 是否截断长消息
show_actions: 是否显示动作记录
remove_emoji_stickers: 是否移除表情包并过滤空消息
"""
# WIP HERE and BELOW ----------------------------------------------
# 创建messages的深拷贝避免修改原始列表
if not messages:
return ""
copy_messages: List[MessageAndActionModel] = [MessageAndActionModel.from_DatabaseMessages(msg) for msg in messages]
# 如果启用移除表情包,先过滤消息
if remove_emoji_stickers:
filtered_messages = []
for msg in messages:
# 获取消息内容
content = msg.processed_plain_text
# 移除表情包
emoji_pattern = r"\[表情包:[^\]]+\]"
content = re.sub(emoji_pattern, "", content)
# 如果移除表情包后内容不为空,则保留消息
if content.strip():
filtered_messages.append(msg)
messages = filtered_messages
copy_messages: List[MessageAndActionModel] = []
for msg in messages:
if remove_emoji_stickers:
# 创建 MessageAndActionModel 但移除表情包
model = MessageAndActionModel.from_DatabaseMessages(msg)
# 移除表情包
if model.processed_plain_text:
model.processed_plain_text = re.sub(r"\[表情包:[^\]]+\]", "", model.processed_plain_text)
copy_messages.append(model)
else:
copy_messages.append(MessageAndActionModel.from_DatabaseMessages(msg))
if show_actions and copy_messages:
# 获取所有消息的时间范围
@@ -862,17 +887,9 @@ async def build_anonymous_messages(messages: List[DatabaseMessages]) -> str:
user_id = msg.user_info.user_id
content = msg.display_message or msg.processed_plain_text or ""
if "" in content:
content = content.replace("", "")
if "" in content:
content = content.replace("", "")
# 处理图片ID
content = process_pic_ids(content)
# if not all([platform, user_id, timestamp is not None]):
# continue
anon_name = get_anon_name(platform, user_id)
# print(f"anon_name:{anon_name}")
@@ -909,6 +926,7 @@ async def build_anonymous_messages(messages: List[DatabaseMessages]) -> str:
return formatted_string
async def get_person_id_list(messages: List[Dict[str, Any]]) -> List[str]:
"""
从消息列表中提取不重复的 person_id 列表 (忽略机器人自身)。
@@ -937,3 +955,45 @@ async def get_person_id_list(messages: List[Dict[str, Any]]) -> List[str]:
person_ids_set.add(person_id)
return list(person_ids_set) # 将集合转换为列表返回
async def build_bare_messages(messages: List[DatabaseMessages]) -> str:
"""
构建简化版消息字符串只包含processed_plain_text内容不考虑用户名和时间戳
Args:
messages: 消息列表
Returns:
只包含消息内容的字符串
"""
if not messages:
return ""
output_lines = []
for msg in messages:
# 获取纯文本内容
content = msg.processed_plain_text or ""
# 处理图片ID
pic_pattern = r"\[picid:[^\]]+\]"
def replace_pic_id(match):
return "[图片]"
content = re.sub(pic_pattern, replace_pic_id, content)
# 处理用户引用格式,移除回复和@标记
reply_pattern = r"回复<[^:<>]+:[^:<>]+>"
content = re.sub(reply_pattern, "回复[某人]", content)
at_pattern = r"@<[^:<>]+:[^:<>]+>"
content = re.sub(at_pattern, "@[某人]", content)
# 清理并添加到输出
content = content.strip()
if content:
output_lines.append(content)
return "\n".join(output_lines)

View File

@@ -151,7 +151,7 @@ class Prompt(str):
@staticmethod
def _process_escaped_braces(template) -> str:
"""处理模板中的转义花括号,将 \{\} 替换为临时标记""" # type: ignore
"""处理模板中的转义花括号,替换为临时标记""" # type: ignore
# 如果传入的是列表,将其转换为字符串
if isinstance(template, list):
template = "\n".join(str(item) for item in template)

View File

@@ -385,18 +385,18 @@ class StatisticOutputTask(AsyncTask):
time_cost_key = f"time_costs_by_{category.split('_')[-1]}"
avg_key = f"avg_time_costs_by_{category.split('_')[-1]}"
std_key = f"std_time_costs_by_{category.split('_')[-1]}"
for item_name in stats[period_key][category]:
time_costs = stats[period_key][time_cost_key].get(item_name, [])
if time_costs:
# 计算平均耗时
avg_time_cost = sum(time_costs) / len(time_costs)
stats[period_key][avg_key][item_name] = round(avg_time_cost, 3)
# 计算标准差
if len(time_costs) > 1:
variance = sum((x - avg_time_cost) ** 2 for x in time_costs) / len(time_costs)
std_time_cost = variance ** 0.5
std_time_cost = variance**0.5
stats[period_key][std_key][item_name] = round(std_time_cost, 3)
else:
stats[period_key][std_key][item_name] = 0.0
@@ -506,8 +506,6 @@ class StatisticOutputTask(AsyncTask):
break
return stats
def _collect_all_statistics(self, now: datetime) -> Dict[str, Dict[str, Any]]:
"""
收集各时间段的统计数据
@@ -639,7 +637,9 @@ class StatisticOutputTask(AsyncTask):
cost = stats[COST_BY_MODEL][model_name]
avg_time_cost = stats[AVG_TIME_COST_BY_MODEL][model_name]
std_time_cost = stats[STD_TIME_COST_BY_MODEL][model_name]
output.append(data_fmt.format(name, count, in_tokens, out_tokens, tokens, cost, avg_time_cost, std_time_cost))
output.append(
data_fmt.format(name, count, in_tokens, out_tokens, tokens, cost, avg_time_cost, std_time_cost)
)
output.append("")
return "\n".join(output)
@@ -728,7 +728,9 @@ class StatisticOutputTask(AsyncTask):
f"<td>{stat_data[STD_TIME_COST_BY_MODEL][model_name]:.1f} 秒</td>"
f"</tr>"
for model_name, count in sorted(stat_data[REQ_CNT_BY_MODEL].items())
] if stat_data[REQ_CNT_BY_MODEL] else ["<tr><td colspan='8' style='text-align: center; color: #999;'>暂无数据</td></tr>"]
]
if stat_data[REQ_CNT_BY_MODEL]
else ["<tr><td colspan='8' style='text-align: center; color: #999;'>暂无数据</td></tr>"]
)
# 按请求类型分类统计
type_rows = "\n".join(
@@ -744,7 +746,9 @@ class StatisticOutputTask(AsyncTask):
f"<td>{stat_data[STD_TIME_COST_BY_TYPE][req_type]:.1f} 秒</td>"
f"</tr>"
for req_type, count in sorted(stat_data[REQ_CNT_BY_TYPE].items())
] if stat_data[REQ_CNT_BY_TYPE] else ["<tr><td colspan='8' style='text-align: center; color: #999;'>暂无数据</td></tr>"]
]
if stat_data[REQ_CNT_BY_TYPE]
else ["<tr><td colspan='8' style='text-align: center; color: #999;'>暂无数据</td></tr>"]
)
# 按模块分类统计
module_rows = "\n".join(
@@ -760,7 +764,9 @@ class StatisticOutputTask(AsyncTask):
f"<td>{stat_data[STD_TIME_COST_BY_MODULE][module_name]:.1f} 秒</td>"
f"</tr>"
for module_name, count in sorted(stat_data[REQ_CNT_BY_MODULE].items())
] if stat_data[REQ_CNT_BY_MODULE] else ["<tr><td colspan='8' style='text-align: center; color: #999;'>暂无数据</td></tr>"]
]
if stat_data[REQ_CNT_BY_MODULE]
else ["<tr><td colspan='8' style='text-align: center; color: #999;'>暂无数据</td></tr>"]
)
# 聊天消息统计
@@ -768,7 +774,9 @@ class StatisticOutputTask(AsyncTask):
[
f"<tr><td>{self.name_mapping[chat_id][0]}</td><td>{count}</td></tr>"
for chat_id, count in sorted(stat_data[MSG_CNT_BY_CHAT].items())
] if stat_data[MSG_CNT_BY_CHAT] else ["<tr><td colspan='2' style='text-align: center; color: #999;'>暂无数据</td></tr>"]
]
if stat_data[MSG_CNT_BY_CHAT]
else ["<tr><td colspan='2' style='text-align: center; color: #999;'>暂无数据</td></tr>"]
)
# 生成HTML
return f"""

View File

@@ -49,9 +49,9 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, bool, float
reply_probability = 0.0
is_at = False
is_mentioned = False
# 这部分怎么处理啊啊啊啊
#我觉得可以给消息加一个 reply_probability_boost字段
# 我觉得可以给消息加一个 reply_probability_boost字段
if (
message.message_info.additional_config is not None
and message.message_info.additional_config.get("is_mentioned") is not None
@@ -339,7 +339,7 @@ def process_llm_response(text: str, enable_splitter: bool = True, enable_chinese
else:
split_sentences = [cleaned_text]
sentences = []
sentences: List[str] = []
for sentence in split_sentences:
if global_config.chinese_typo.enable and enable_chinese_typo:
typoed_text, typo_corrections = typo_generator.create_typo_sentence(sentence)
@@ -383,10 +383,6 @@ def calculate_typing_time(
- 在所有输入结束后额外加上回车时间0.3秒
- 如果is_emoji为True将使用固定1秒的输入时间
"""
# # 将0-1的唤醒度映射到-1到1
# mood_arousal = mood_manager.current_mood.arousal
# # 映射到0.5到2倍的速度系数
# typing_speed_multiplier = 1.5**mood_arousal # 唤醒度为1时速度翻倍,为-1时速度减半
# chinese_time *= 1 / typing_speed_multiplier
# english_time *= 1 / typing_speed_multiplier
# 计算中文字符数
@@ -826,20 +822,48 @@ def parse_keywords_string(keywords_input) -> list[str]:
return [keywords_str] if keywords_str else []
def cut_key_words(concept_name: str) -> list[str]:
"""对概念名称进行jieba分词并过滤掉关键词列表中的关键词"""
concept_name_tokens = list(jieba.cut(concept_name))
# 定义常见连词、停用词与标点
conjunctions = {
"", "", "", "", "以及", "并且", "而且", "", "或者", ""
}
conjunctions = {"", "", "", "", "以及", "并且", "而且", "", "或者", ""}
stop_words = {
"", "", "", "", "", "", "", "", "", "", "", "",
"", "", "", "", "", "", "", "", "", "", "", "",
"", "", "", "", "", "", "", "而且", "或者", "", "以及"
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"而且",
"或者",
"",
"以及",
}
chinese_punctuations = set(",。!?、;:()【】《》“”‘’—…·-——,.!?;:()[]<>'\"/\\")
@@ -864,11 +888,16 @@ def cut_key_words(concept_name: str) -> list[str]:
left = merged_tokens[-1]
right = cleaned_tokens[i + 1]
# 左右都需要是有效词
if left and right \
and left not in conjunctions and right not in conjunctions \
and left not in stop_words and right not in stop_words \
and not all(ch in chinese_punctuations for ch in left) \
and not all(ch in chinese_punctuations for ch in right):
if (
left
and right
and left not in conjunctions
and right not in conjunctions
and left not in stop_words
and right not in stop_words
and not all(ch in chinese_punctuations for ch in left)
and not all(ch in chinese_punctuations for ch in right)
):
# 合并为一个新词,并替换掉左侧与跳过右侧
combined = f"{left}{tok}{right}"
merged_tokens[-1] = combined
@@ -889,7 +918,7 @@ def cut_key_words(concept_name: str) -> list[str]:
if tok in stop_words:
continue
# if tok in ban_words:
# continue
# continue
if all(ch in chinese_punctuations for ch in tok):
continue
if tok.strip() == "":
@@ -899,4 +928,4 @@ def cut_key_words(concept_name: str) -> list[str]:
result_tokens.append(tok)
filtered_concept_name_tokens = result_tokens
return filtered_concept_name_tokens
return filtered_concept_name_tokens

View File

@@ -91,9 +91,10 @@ class ImageManager:
desc_obj.save()
except Exception as e:
logger.error(f"保存描述到数据库失败 (Peewee): {str(e)}")
async def get_emoji_tag(self, image_base64: str) -> str:
from src.chat.emoji_system.emoji_manager import get_emoji_manager
emoji_manager = get_emoji_manager()
if isinstance(image_base64, str):
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
@@ -120,6 +121,7 @@ class ImageManager:
# 优先使用EmojiManager查询已注册表情包的描述
try:
from src.chat.emoji_system.emoji_manager import get_emoji_manager
emoji_manager = get_emoji_manager()
tags = await emoji_manager.get_emoji_tag_by_hash(image_hash)
if tags:
@@ -144,14 +146,14 @@ class ImageManager:
return "[表情包(GIF处理失败)]"
vlm_prompt = "这是一个动态图表情包,每一张图代表了动态图的某一帧,黑色背景代表透明,描述一下表情包表达的情感和内容,描述细节,从互联网梗,meme的角度去分析"
detailed_description, _ = await self.vlm.generate_response_for_image(
vlm_prompt, image_base64_processed, "jpg", temperature=0.4, max_tokens=300
vlm_prompt, image_base64_processed, "jpg", temperature=0.4
)
else:
vlm_prompt = (
"这是一个表情包,请详细描述一下表情包所表达的情感和内容,描述细节,从互联网梗,meme的角度去分析"
)
detailed_description, _ = await self.vlm.generate_response_for_image(
vlm_prompt, image_base64, image_format, temperature=0.4, max_tokens=300
vlm_prompt, image_base64, image_format, temperature=0.4
)
if detailed_description is None:
@@ -172,9 +174,7 @@ class ImageManager:
# 使用较低温度确保输出稳定
emotion_llm = LLMRequest(model_set=model_config.model_task_config.utils, request_type="emoji")
emotion_result, _ = await emotion_llm.generate_response_async(
emotion_prompt, temperature=0.3, max_tokens=50
)
emotion_result, _ = await emotion_llm.generate_response_async(emotion_prompt, temperature=0.3)
if not emotion_result:
logger.warning("LLM未能生成情感标签使用详细描述的前几个词")
@@ -220,11 +220,13 @@ class ImageManager:
img_obj.save()
except Images.DoesNotExist: # type: ignore
Images.create(
image_id=str(uuid.uuid4()),
emoji_hash=image_hash,
path=file_path,
type="emoji",
description=detailed_description, # 保存详细描述
timestamp=current_timestamp,
vlm_processed=True,
)
except Exception as e:
logger.error(f"保存表情包文件或元数据失败: {str(e)}")
@@ -268,7 +270,7 @@ class ImageManager:
# 调用AI获取描述
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
prompt = global_config.custom_prompt.image_prompt
prompt = global_config.personality.visual_style
logger.info(f"[VLM调用] 为图片生成新描述 (Hash: {image_hash[:8]}...)")
description, _ = await self.vlm.generate_response_for_image(
prompt, image_base64, image_format, temperature=0.4, max_tokens=300
@@ -564,7 +566,7 @@ class ImageManager:
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
# 构建prompt
prompt = global_config.custom_prompt.image_prompt
prompt = global_config.personality.visual_style
# 获取VLM描述
description, _ = await self.vlm.generate_response_for_image(
@@ -621,3 +623,41 @@ def image_path_to_base64(image_path: str) -> str:
return base64.b64encode(image_data).decode("utf-8")
else:
raise IOError(f"读取图片文件失败: {image_path}")
def base64_to_image(image_base64: str, output_path: str) -> bool:
"""将base64编码的图片保存为文件
Args:
image_base64: 图片的base64编码
output_path: 输出文件路径
Returns:
bool: 是否成功保存
Raises:
ValueError: 当base64编码无效时
IOError: 当保存文件失败时
"""
try:
# 确保base64字符串只包含ASCII字符
if isinstance(image_base64, str):
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
# 解码base64
image_bytes = base64.b64decode(image_base64)
# 确保输出目录存在
output_dir = os.path.dirname(output_path)
if output_dir:
os.makedirs(output_dir, exist_ok=True)
# 保存文件
with open(output_path, "wb") as f:
f.write(image_bytes)
return True
except Exception as e:
logger.error(f"保存base64图片失败: {e}")
return False

View File

@@ -6,7 +6,8 @@ class BaseDataModel:
def deepcopy(self):
return copy.deepcopy(self)
def temporarily_transform_class_to_dict(obj: Any) -> Any:
def transform_class_to_dict(obj: Any) -> Any:
# sourcery skip: assign-if-exp, reintroduce-else
"""
将对象或容器中的 BaseDataModel 子类(类对象)或 BaseDataModel 实例

View File

@@ -205,6 +205,7 @@ class DatabaseMessages(BaseDataModel):
"chat_info_user_cardname": self.chat_info.user_info.user_cardname,
}
@dataclass(init=False)
class DatabaseActionRecords(BaseDataModel):
def __init__(
@@ -219,6 +220,7 @@ class DatabaseActionRecords(BaseDataModel):
chat_id: str,
chat_info_stream_id: str,
chat_info_platform: str,
action_reasoning:str
):
self.action_id = action_id
self.time = time
@@ -232,4 +234,5 @@ class DatabaseActionRecords(BaseDataModel):
self.action_prompt_display = action_prompt_display
self.chat_id = chat_id
self.chat_info_stream_id = chat_info_stream_id
self.chat_info_platform = chat_info_platform
self.chat_info_platform = chat_info_platform
self.action_reasoning = action_reasoning

View File

@@ -23,3 +23,5 @@ class ActionPlannerInfo(BaseDataModel):
action_data: Optional[Dict] = None
action_message: Optional["DatabaseMessages"] = None
available_actions: Optional[Dict[str, "ActionInfo"]] = None
loop_start_time: Optional[float] = None
action_reasoning: Optional[str] = None

View File

@@ -1,10 +1,13 @@
from dataclasses import dataclass
from typing import Optional, List, Tuple, TYPE_CHECKING, Any
from typing import Optional, List, TYPE_CHECKING
from . import BaseDataModel
if TYPE_CHECKING:
from src.common.data_models.message_data_model import ReplySetModel
from src.llm_models.payload_content.tool_option import ToolCall
@dataclass
class LLMGenerationDataModel(BaseDataModel):
content: Optional[str] = None
@@ -13,4 +16,4 @@ class LLMGenerationDataModel(BaseDataModel):
tool_calls: Optional[List["ToolCall"]] = None
prompt: Optional[str] = None
selected_expressions: Optional[List[int]] = None
reply_set: Optional[List[Tuple[str, Any]]] = None
reply_set: Optional["ReplySetModel"] = None

View File

@@ -1,5 +1,6 @@
from typing import Optional, TYPE_CHECKING
from typing import Optional, TYPE_CHECKING, List, Tuple, Union, Dict, Any
from dataclasses import dataclass, field
from enum import Enum
from . import BaseDataModel
@@ -34,3 +35,172 @@ class MessageAndActionModel(BaseDataModel):
display_message=message.display_message,
chat_info_platform=message.chat_info.platform,
)
class ReplyContentType(Enum):
TEXT = "text"
IMAGE = "image"
EMOJI = "emoji"
COMMAND = "command"
VOICE = "voice"
FORWARD = "forward"
HYBRID = "hybrid" # 混合类型,包含多种内容
def __repr__(self) -> str:
return self.value
@dataclass
class ForwardNode(BaseDataModel):
user_id: Optional[str] = None
user_nickname: Optional[str] = None
content: Union[List["ReplyContent"], str] = field(default_factory=list)
@classmethod
def construct_as_id_reference(cls, message_id: str) -> "ForwardNode":
return cls(user_id="", user_nickname="", content=message_id)
@classmethod
def construct_as_created_node(
cls, user_id: str, user_nickname: str, content: List["ReplyContent"]
) -> "ForwardNode":
return cls(user_id=user_id, user_nickname=user_nickname, content=content)
@dataclass
class ReplyContent(BaseDataModel):
content_type: ReplyContentType | str
content: Union[str, Dict, List[ForwardNode], List["ReplyContent"]] # 支持嵌套的 ReplyContent
@classmethod
def construct_as_text(cls, text: str):
return cls(content_type=ReplyContentType.TEXT, content=text)
@classmethod
def construct_as_image(cls, image_base64: str):
return cls(content_type=ReplyContentType.IMAGE, content=image_base64)
@classmethod
def construct_as_voice(cls, voice_base64: str):
return cls(content_type=ReplyContentType.VOICE, content=voice_base64)
@classmethod
def construct_as_emoji(cls, emoji_str: str):
return cls(content_type=ReplyContentType.EMOJI, content=emoji_str)
@classmethod
def construct_as_command(cls, command_arg: Dict):
return cls(content_type=ReplyContentType.COMMAND, content=command_arg)
@classmethod
def construct_as_hybrid(cls, hybrid_content: List[Tuple[ReplyContentType | str, str]]):
hybrid_content_list: List[ReplyContent] = []
for content_type, content in hybrid_content:
assert content_type not in [
ReplyContentType.HYBRID,
ReplyContentType.FORWARD,
ReplyContentType.VOICE,
ReplyContentType.COMMAND,
], "混合内容的每个项不能是混合、转发、语音或命令类型"
assert isinstance(content, str), "混合内容的每个项必须是字符串"
hybrid_content_list.append(ReplyContent(content_type=content_type, content=content))
return cls(content_type=ReplyContentType.HYBRID, content=hybrid_content_list)
@classmethod
def construct_as_forward(cls, forward_nodes: List[ForwardNode]):
return cls(content_type=ReplyContentType.FORWARD, content=forward_nodes)
def __post_init__(self):
if isinstance(self.content_type, ReplyContentType):
if self.content_type not in [ReplyContentType.HYBRID, ReplyContentType.FORWARD] and isinstance(
self.content, List
):
raise ValueError(
f"非混合类型/转发类型的内容不能是列表content_type: {self.content_type}, content: {self.content}"
)
elif self.content_type in [ReplyContentType.HYBRID, ReplyContentType.FORWARD]:
if not isinstance(self.content, List):
raise ValueError(
f"混合类型/转发类型的内容必须是列表content_type: {self.content_type}, content: {self.content}"
)
@dataclass
class ReplySetModel(BaseDataModel):
"""
回复集数据模型,用于多种回复类型的返回
"""
reply_data: List[ReplyContent] = field(default_factory=list)
def __len__(self):
return len(self.reply_data)
def add_text_content(self, text: str):
"""
添加文本内容
Args:
text: 文本内容
"""
self.reply_data.append(ReplyContent(content_type=ReplyContentType.TEXT, content=text))
def add_image_content(self, image_base64: str):
"""
添加图片内容base64编码的图片数据
Args:
image_base64: base64编码的图片数据
"""
self.reply_data.append(ReplyContent(content_type=ReplyContentType.IMAGE, content=image_base64))
def add_voice_content(self, voice_base64: str):
"""
添加语音内容base64编码的音频数据
Args:
voice_base64: base64编码的音频数据
"""
self.reply_data.append(ReplyContent(content_type=ReplyContentType.VOICE, content=voice_base64))
def add_hybrid_content_by_raw(self, hybrid_content: List[Tuple[ReplyContentType | str, str]]):
"""
添加混合型内容可以包含text, image, emoji的任意组合
Args:
hybrid_content: 元组 (类型, 消息内容) 构成的列表,如[(ReplyContentType.TEXT, "Hello"), (ReplyContentType.IMAGE, "<base64")]
"""
hybrid_content_list: List[ReplyContent] = []
for content_type, content in hybrid_content:
assert content_type not in [
ReplyContentType.HYBRID,
ReplyContentType.FORWARD,
ReplyContentType.VOICE,
ReplyContentType.COMMAND,
], "混合内容的每个项不能是混合、转发、语音或命令类型"
assert isinstance(content, str), "混合内容的每个项必须是字符串"
hybrid_content_list.append(ReplyContent(content_type=content_type, content=content))
self.reply_data.append(ReplyContent(content_type=ReplyContentType.HYBRID, content=hybrid_content_list))
def add_hybrid_content(self, hybrid_content: List[ReplyContent]):
"""
添加混合型内容,使用已经构造好的 ReplyContent 列表
Args:
hybrid_content: ReplyContent 构成的列表,如[ReplyContent(ReplyContentType.TEXT, "Hello"), ReplyContent(ReplyContentType.IMAGE, "<base64")]
"""
for content in hybrid_content:
assert content.content_type not in [
ReplyContentType.HYBRID,
ReplyContentType.FORWARD,
ReplyContentType.VOICE,
ReplyContentType.COMMAND,
], "混合内容的每个项不能是混合、转发、语音或命令类型"
assert isinstance(content.content, str), "混合内容的每个项必须是字符串"
self.reply_data.append(ReplyContent(content_type=ReplyContentType.HYBRID, content=hybrid_content))
def add_custom_content(self, content_type: str, content: Any):
"""
添加自定义类型的内容"""
self.reply_data.append(ReplyContent(content_type=content_type, content=content))
def add_forward_content(self, forward_content: List[ForwardNode]):
"""添加转发内容可以是字符串或ReplyContent嵌套的转发内容需要自己构造放入"""
self.reply_data.append(ReplyContent(content_type=ReplyContentType.FORWARD, content=forward_content))

View File

@@ -0,0 +1,36 @@
# 对于`message_data_model.py`中`class ReplyContent`的规划解读
分类讨论如下:
- `ReplyContent.TEXT`: 单独的文本,`_level = 0``content``str`类型。
- `ReplyContent.IMAGE`: 单独的图片,`_level = 0``content``str`类型图片base64
- `ReplyContent.EMOJI`: 单独的表情包,`_level = 0``content``str`类型图片base64
- `ReplyContent.VOICE`: 单独的语音,`_level = 0``content``str`类型语音base64
- `ReplyContent.HYBRID`: 混合内容,`_level = 0`
- 其应该是一个列表,列表内应该只接受`str`类型的内容(图片和文本混合体)
- `ReplyContent.FORWARD`: 转发消息,`_level = n`
- 其应该是一个列表,列表接受`str`类型(图片/文本),`ReplyContent`类型(嵌套转发,嵌套有最高层数限制)
- `ReplyContent.COMMAND`: 指令消息,`_level = 0`
- 其应该是一个列表,列表内应该只接受`Dict`类型的内容
未来规划:
- `ReplyContent.AT` 单独的艾特,`_level = 0``content``str`类型用户ID
内容构造方式:
- 对于`TEXT`, `IMAGE`, `EMOJI`, `VOICE`,直接传入对应类型的内容,且`content`应该为`str`
- 对于`COMMAND`,传入一个字典,字典内的内容类型应符合上述规定。
- 对于`HYBRID`, `FORWARD`,传入一个列表,列表内的内容类型应符合上述规定。
因此,我们的类型注解应该是:
```python
from typing import Union, List, Dict
ReplyContentType = Union[
str, # TEXT, IMAGE, EMOJI, VOICE
List[Union[str, 'ReplyContent']], # HYBRID, FORWARD
Dict # COMMAND
]
```
现在`_level`被移除了,在解析的时候显式地检查内容的类型和结构即可。
`send_api`的custom_reply_set_to_stream仅在特定的类型下提供reply)message

View File

@@ -0,0 +1,57 @@
# 有关转发消息和其他消息的构建类型说明
```mermaid
graph LR;
direction TB;
A[ReplySet] --- B[ReplyContent];
A --- C["ReplyContent"];
A --- K["ReplyContent"];
A --- L["ReplyContent"];
A --- N["ReplyContent"];
A --- D[...];
B --- E["Text (in str)"];
B --- F["Image (in base64)"];
C --- G["Voice (in base64)"];
B --- I["Emoji (in base64)"];
subgraph "可行内容(以下的任意组合)";
subgraph "转发消息(Forward)"
M["List[ForwardNode]"]
end
subgraph "混合消息(Hybrid)"
J["List[ReplyContent] (要求只能包含普通消息)"]
end
subgraph "命令消息(Command)"
H["Command (in Dict)"]
end
subgraph "语音消息"
G
end
subgraph "普通消息"
E
F
I
end
end
N --- H
K --- J
L --- M
subgraph ForwardNodes
O["ForwardNode"]
P["ForwardNode"]
Q["ForwardNode"]
end
M --- O
M --- P
M --- Q
subgraph "内容 (message_id引用法)"
P --- U["content: str, 引用已有消息的有效ID"];
end
subgraph "内容 (生成法)"
O --- R["user_id: str"];
O --- S["user_nickname: str"];
O --- T["content: List[ReplyContent], 为这个转发节点的消息内容"];
end
```
另外,自定义消息类型我们在这里不做讨论。
以上列出了所有可能的ReplySet构建方式下面我们来解释一下各个类型的含义。

View File

@@ -1,64 +1,9 @@
import os
from pymongo import MongoClient
from peewee import SqliteDatabase
from pymongo.database import Database
from rich.traceback import install
install(extra_lines=3)
_client = None
_db = None
def __create_database_instance():
uri = os.getenv("MONGODB_URI")
host = os.getenv("MONGODB_HOST", "127.0.0.1")
port = int(os.getenv("MONGODB_PORT", "27017"))
# db_name 变量在创建连接时不需要,在获取数据库实例时才使用
username = os.getenv("MONGODB_USERNAME")
password = os.getenv("MONGODB_PASSWORD")
auth_source = os.getenv("MONGODB_AUTH_SOURCE")
if uri:
# 支持标准mongodb://和mongodb+srv://连接字符串
if uri.startswith(("mongodb://", "mongodb+srv://")):
return MongoClient(uri)
else:
raise ValueError(
"Invalid MongoDB URI format. URI must start with 'mongodb://' or 'mongodb+srv://'. "
"For MongoDB Atlas, use 'mongodb+srv://' format. "
"See: https://www.mongodb.com/docs/manual/reference/connection-string/"
)
if username and password:
# 如果有用户名和密码,使用认证连接
return MongoClient(host, port, username=username, password=password, authSource=auth_source)
# 否则使用无认证连接
return MongoClient(host, port)
def get_db():
"""获取数据库连接实例,延迟初始化。"""
global _client, _db
if _client is None:
_client = __create_database_instance()
_db = _client[os.getenv("DATABASE_NAME", "MegBot")]
return _db
class DBWrapper:
"""数据库代理类,保持接口兼容性同时实现懒加载。"""
def __getattr__(self, name):
return getattr(get_db(), name)
def __getitem__(self, key):
return get_db()[key] # type: ignore
# 全局数据库访问点
memory_db: Database = DBWrapper() # type: ignore
# 定义数据库文件路径
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))

View File

@@ -135,7 +135,7 @@ class Messages(BaseModel):
interest_value = DoubleField(null=True)
key_words = TextField(null=True)
key_words_lite = TextField(null=True)
is_mentioned = BooleanField(null=True)
is_at = BooleanField(null=True)
reply_probability_boost = DoubleField(null=True)
@@ -169,7 +169,7 @@ class Messages(BaseModel):
is_picid = BooleanField(default=False)
is_command = BooleanField(default=False)
is_notify = BooleanField(default=False)
selected_expressions = TextField(null=True)
class Meta:
@@ -185,6 +185,8 @@ class ActionRecords(BaseModel):
action_id = TextField(index=True) # 消息 ID (更改自 IntegerField)
time = DoubleField() # 消息时间戳
action_reasoning = TextField(null=True)
action_name = TextField()
action_data = TextField()
action_done = BooleanField(default=False)
@@ -267,12 +269,6 @@ class PersonInfo(BaseModel):
know_times = FloatField(null=True) # 认识时间 (时间戳)
know_since = FloatField(null=True) # 首次印象总结时间
last_know = FloatField(null=True) # 最后一次印象总结时间
attitude_to_me = TextField(null=True) # 对bot的态度
attitude_to_me_confidence = FloatField(null=True) # 对bot的态度置信度
class Meta:
# database = db # 继承自 BaseModel
@@ -299,6 +295,7 @@ class GroupInfo(BaseModel):
# database = db # 继承自 BaseModel
table_name = "group_info"
class Expression(BaseModel):
"""
用于存储表达风格的模型。
@@ -307,6 +304,11 @@ class Expression(BaseModel):
situation = TextField()
style = TextField()
count = FloatField()
# new mode fields
context = TextField(null=True)
context_words = TextField(null=True)
last_active_time = FloatField()
chat_id = TextField(index=True)
type = TextField()
@@ -315,36 +317,35 @@ class Expression(BaseModel):
class Meta:
table_name = "expression"
class GraphNodes(BaseModel):
class MemoryChest(BaseModel):
"""
用于存储记忆图节点的模型
用于存储记忆仓库的模型
"""
concept = TextField(unique=True, index=True) # 节点概念
memory_items = TextField() # JSON格式存储的记忆列表
weight = FloatField(default=0.0) # 节点权重
hash = TextField() # 节点哈希值
created_time = FloatField() # 创建时间戳
last_modified = FloatField() # 最后修改时间戳
title = TextField() # 标题
content = TextField() # 内容
chat_id = TextField(null=True) # 聊天ID
locked = BooleanField(default=False) # 是否锁定
class Meta:
table_name = "graph_nodes"
table_name = "memory_chest"
class GraphEdges(BaseModel):
class MemoryConflict(BaseModel):
"""
用于存储记忆图边的模型
用于存储记忆整合过程中冲突内容的模型
"""
source = TextField(index=True) # 源节点
target = TextField(index=True) # 目标节点
strength = IntegerField() # 连接强度
hash = TextField() # 边哈希值
created_time = FloatField() # 创建时间戳
last_modified = FloatField() # 最后修改时间戳
conflict_content = TextField() # 冲突内容
answer = TextField(null=True) # 回答内容
create_time = FloatField() # 创建时间
update_time = FloatField() # 更新时间
context = TextField(null=True) # 上下文
chat_id = TextField(null=True) # 聊天ID
raise_time = FloatField(null=True) # 触发次数
class Meta:
table_name = "graph_edges"
table_name = "memory_conflicts"
def create_tables():
@@ -363,9 +364,9 @@ def create_tables():
OnlineTime,
PersonInfo,
Expression,
GraphNodes, # 添加图节点表
GraphEdges, # 添加图边表
ActionRecords, # 添加 ActionRecords 到初始化列表
MemoryChest,
MemoryConflict, # 添加记忆冲突表
]
)
@@ -374,7 +375,7 @@ def initialize_database(sync_constraints=False):
"""
检查所有定义的表是否存在,如果不存在则创建它们。
检查所有表的所有字段是否存在,如果缺失则自动添加。
Args:
sync_constraints (bool): 是否同步字段约束。默认为 False。
如果为 True会检查并修复字段的 NULL 约束不一致问题。
@@ -390,9 +391,9 @@ def initialize_database(sync_constraints=False):
OnlineTime,
PersonInfo,
Expression,
GraphNodes,
GraphEdges,
ActionRecords, # 添加 ActionRecords 到初始化列表
MemoryChest,
MemoryConflict,
]
try:
@@ -456,13 +457,13 @@ def initialize_database(sync_constraints=False):
logger.info(f"字段 '{field_name}' 删除成功")
except Exception as e:
logger.error(f"删除字段 '{field_name}' 失败: {e}")
# 如果启用了约束同步,执行约束检查和修复
if sync_constraints:
logger.debug("开始同步数据库字段约束...")
sync_field_constraints()
logger.debug("数据库字段约束同步完成")
except Exception as e:
logger.exception(f"检查表或字段是否存在时出错: {e}")
# 如果检查失败(例如数据库不可用),则退出
@@ -476,7 +477,7 @@ def sync_field_constraints():
同步数据库字段约束,确保现有数据库字段的 NULL 约束与模型定义一致。
如果发现不一致,会自动修复字段约束。
"""
models = [
ChatStreams,
LLMUsage,
@@ -487,9 +488,9 @@ def sync_field_constraints():
OnlineTime,
PersonInfo,
Expression,
GraphNodes,
GraphEdges,
ActionRecords,
MemoryChest,
MemoryConflict,
]
try:
@@ -501,50 +502,55 @@ def sync_field_constraints():
continue
logger.debug(f"检查表 '{table_name}' 的字段约束...")
# 获取当前表结构信息
cursor = db.execute_sql(f"PRAGMA table_info('{table_name}')")
current_schema = {row[1]: {'type': row[2], 'notnull': bool(row[3]), 'default': row[4]}
for row in cursor.fetchall()}
current_schema = {
row[1]: {"type": row[2], "notnull": bool(row[3]), "default": row[4]} for row in cursor.fetchall()
}
# 检查每个模型字段的约束
constraints_to_fix = []
for field_name, field_obj in model._meta.fields.items():
if field_name not in current_schema:
continue # 字段不存在,跳过
current_notnull = current_schema[field_name]['notnull']
current_notnull = current_schema[field_name]["notnull"]
model_allows_null = field_obj.null
# 如果模型允许 null 但数据库字段不允许 null需要修复
if model_allows_null and current_notnull:
constraints_to_fix.append({
'field_name': field_name,
'field_obj': field_obj,
'action': 'allow_null',
'current_constraint': 'NOT NULL',
'target_constraint': 'NULL'
})
constraints_to_fix.append(
{
"field_name": field_name,
"field_obj": field_obj,
"action": "allow_null",
"current_constraint": "NOT NULL",
"target_constraint": "NULL",
}
)
logger.warning(f"字段 '{field_name}' 约束不一致: 模型允许NULL但数据库为NOT NULL")
# 如果模型不允许 null 但数据库字段允许 null也需要修复但要小心
elif not model_allows_null and not current_notnull:
constraints_to_fix.append({
'field_name': field_name,
'field_obj': field_obj,
'action': 'disallow_null',
'current_constraint': 'NULL',
'target_constraint': 'NOT NULL'
})
constraints_to_fix.append(
{
"field_name": field_name,
"field_obj": field_obj,
"action": "disallow_null",
"current_constraint": "NULL",
"target_constraint": "NOT NULL",
}
)
logger.warning(f"字段 '{field_name}' 约束不一致: 模型不允许NULL但数据库允许NULL")
# 修复约束不一致的字段
if constraints_to_fix:
logger.info(f"'{table_name}' 需要修复 {len(constraints_to_fix)} 个字段约束")
_fix_table_constraints(table_name, model, constraints_to_fix)
else:
logger.debug(f"'{table_name}' 的字段约束已同步")
except Exception as e:
logger.exception(f"同步字段约束时出错: {e}")
@@ -557,40 +563,39 @@ def _fix_table_constraints(table_name, model, constraints_to_fix):
try:
# 备份表名
backup_table = f"{table_name}_backup_{int(datetime.datetime.now().timestamp())}"
logger.info(f"开始修复表 '{table_name}' 的字段约束...")
# 1. 创建备份表
db.execute_sql(f"CREATE TABLE {backup_table} AS SELECT * FROM {table_name}")
logger.info(f"已创建备份表 '{backup_table}'")
# 2. 删除原表
db.execute_sql(f"DROP TABLE {table_name}")
logger.info(f"已删除原表 '{table_name}'")
# 3. 重新创建表(使用当前模型定义)
db.create_tables([model])
logger.info(f"已重新创建表 '{table_name}' 使用新的约束")
# 4. 从备份表恢复数据
# 获取字段列表
fields = list(model._meta.fields.keys())
fields_str = ', '.join(fields)
fields_str = ", ".join(fields)
# 对于需要从 NOT NULL 改为 NULL 的字段,直接复制数据
# 对于需要从 NULL 改为 NOT NULL 的字段,需要处理 NULL 值
insert_sql = f"INSERT INTO {table_name} ({fields_str}) SELECT {fields_str} FROM {backup_table}"
# 检查是否有字段需要从 NULL 改为 NOT NULL
null_to_notnull_fields = [
constraint['field_name'] for constraint in constraints_to_fix
if constraint['action'] == 'disallow_null'
constraint["field_name"] for constraint in constraints_to_fix if constraint["action"] == "disallow_null"
]
if null_to_notnull_fields:
# 需要处理 NULL 值,为这些字段设置默认值
logger.warning(f"字段 {null_to_notnull_fields} 将从允许NULL改为不允许NULL需要处理现有的NULL值")
# 构建更复杂的 SELECT 语句来处理 NULL 值
select_fields = []
for field_name in fields:
@@ -607,21 +612,21 @@ def _fix_table_constraints(table_name, model, constraints_to_fix):
default_value = f"'{datetime.datetime.now()}'"
else:
default_value = "''"
select_fields.append(f"COALESCE({field_name}, {default_value}) as {field_name}")
else:
select_fields.append(field_name)
select_str = ', '.join(select_fields)
select_str = ", ".join(select_fields)
insert_sql = f"INSERT INTO {table_name} ({fields_str}) SELECT {select_str} FROM {backup_table}"
db.execute_sql(insert_sql)
logger.info(f"已从备份表恢复数据到 '{table_name}'")
# 5. 验证数据完整性
original_count = db.execute_sql(f"SELECT COUNT(*) FROM {backup_table}").fetchone()[0]
new_count = db.execute_sql(f"SELECT COUNT(*) FROM {table_name}").fetchone()[0]
if original_count == new_count:
logger.info(f"数据完整性验证通过: {original_count} 行数据")
# 删除备份表
@@ -630,12 +635,14 @@ def _fix_table_constraints(table_name, model, constraints_to_fix):
else:
logger.error(f"数据完整性验证失败: 原始 {original_count} 行,新表 {new_count}")
logger.error(f"备份表 '{backup_table}' 已保留,请手动检查")
# 记录修复的约束
for constraint in constraints_to_fix:
logger.info(f"已修复字段 '{constraint['field_name']}': "
f"{constraint['current_constraint']} -> {constraint['target_constraint']}")
logger.info(
f"已修复字段 '{constraint['field_name']}': "
f"{constraint['current_constraint']} -> {constraint['target_constraint']}"
)
except Exception as e:
logger.exception(f"修复表 '{table_name}' 约束时出错: {e}")
# 尝试恢复
@@ -654,7 +661,7 @@ def check_field_constraints():
检查但不修复字段约束,返回不一致的字段信息。
用于在修复前预览需要修复的内容。
"""
models = [
ChatStreams,
LLMUsage,
@@ -665,13 +672,13 @@ def check_field_constraints():
OnlineTime,
PersonInfo,
Expression,
GraphNodes,
GraphEdges,
ActionRecords,
MemoryChest,
MemoryConflict,
]
inconsistencies = {}
try:
with db:
for model in models:
@@ -681,49 +688,67 @@ def check_field_constraints():
# 获取当前表结构信息
cursor = db.execute_sql(f"PRAGMA table_info('{table_name}')")
current_schema = {row[1]: {'type': row[2], 'notnull': bool(row[3]), 'default': row[4]}
for row in cursor.fetchall()}
current_schema = {
row[1]: {"type": row[2], "notnull": bool(row[3]), "default": row[4]} for row in cursor.fetchall()
}
table_inconsistencies = []
# 检查每个模型字段的约束
for field_name, field_obj in model._meta.fields.items():
if field_name not in current_schema:
continue
current_notnull = current_schema[field_name]['notnull']
current_notnull = current_schema[field_name]["notnull"]
model_allows_null = field_obj.null
if model_allows_null and current_notnull:
table_inconsistencies.append({
'field_name': field_name,
'issue': 'model_allows_null_but_db_not_null',
'model_constraint': 'NULL',
'db_constraint': 'NOT NULL',
'recommended_action': 'allow_null'
})
table_inconsistencies.append(
{
"field_name": field_name,
"issue": "model_allows_null_but_db_not_null",
"model_constraint": "NULL",
"db_constraint": "NOT NULL",
"recommended_action": "allow_null",
}
)
elif not model_allows_null and not current_notnull:
table_inconsistencies.append({
'field_name': field_name,
'issue': 'model_not_null_but_db_allows_null',
'model_constraint': 'NOT NULL',
'db_constraint': 'NULL',
'recommended_action': 'disallow_null'
})
table_inconsistencies.append(
{
"field_name": field_name,
"issue": "model_not_null_but_db_allows_null",
"model_constraint": "NOT NULL",
"db_constraint": "NULL",
"recommended_action": "disallow_null",
}
)
if table_inconsistencies:
inconsistencies[table_name] = table_inconsistencies
except Exception as e:
logger.exception(f"检查字段约束时出错: {e}")
return inconsistencies
def fix_image_id():
"""
修复表情包的 image_id 字段
"""
import uuid
try:
with db:
for img in Images.select():
if not img.image_id:
img.image_id = str(uuid.uuid4())
img.save()
logger.info(f"已为表情包 {img.id} 生成新的 image_id: {img.image_id}")
except Exception as e:
logger.exception(f"修复 image_id 时出错: {e}")
# 模块加载时调用初始化函数
initialize_database(sync_constraints=True)
fix_image_id()

View File

@@ -339,24 +339,18 @@ MODULE_COLORS = {
# 67 具体的颜色编号0-255这里是较暗的蓝色
"sender": "\033[38;5;24m", # 67号色较暗的蓝色适合不显眼的日志
"send_api": "\033[38;5;24m", # 208号色橙色适合突出显示
# 生成
"replyer": "\033[38;5;208m", # 橙色
"llm_api": "\033[38;5;208m", # 橙色
# 消息处理
"chat": "\033[38;5;82m", # 亮蓝色
"chat_image": "\033[38;5;68m", # 浅蓝色
#emoji
# emoji
"emoji": "\033[38;5;214m", # 橙黄色,偏向橙色
"emoji_api": "\033[38;5;214m", # 橙黄色,偏向橙色
# 核心模块
"main": "\033[1;97m", # 亮白色+粗体 (主程序)
"memory": "\033[38;5;34m", # 天蓝色
"config": "\033[93m", # 亮黄色
"common": "\033[95m", # 亮紫色
"tools": "\033[96m", # 亮青色
@@ -367,24 +361,17 @@ MODULE_COLORS = {
"llm_models": "\033[36m", # 青色
"remote": "\033[38;5;242m", # 深灰色,更不显眼
"planner": "\033[36m",
"relation": "\033[38;5;139m", # 柔和的紫色,不刺眼
# 聊天相关模块
"normal_chat": "\033[38;5;81m", # 亮蓝绿色
"heartflow": "\033[38;5;175m", # 柔和的粉色,不显眼但保持粉色系
"hfc": "\033[38;5;175m", # 柔和的粉色,不显眼但保持粉色系
"bc": "\033[38;5;175m", # 柔和的粉色,不显眼但保持粉色系
"sub_heartflow": "\033[38;5;207m", # 粉紫色
"subheartflow_manager": "\033[38;5;201m", # 深粉色
"background_tasks": "\033[38;5;240m", # 灰色
"chat_message": "\033[38;5;45m", # 青色
"chat_stream": "\033[38;5;51m", # 亮青色
"message_storage": "\033[38;5;33m", # 深蓝色
"expressor": "\033[38;5;166m", # 橙色
# 专注聊天模块
"memory_activator": "\033[38;5;117m", # 天蓝色
# 插件系统
"plugins": "\033[31m", # 红色
"plugin_api": "\033[33m", # 黄色
@@ -412,7 +399,6 @@ MODULE_COLORS = {
# 工具和实用模块
"prompt_build": "\033[38;5;105m", # 紫色
"chat_utils": "\033[38;5;111m", # 蓝色
"maibot_statistic": "\033[38;5;129m", # 紫色
# 特殊功能插件
"mute_plugin": "\033[38;5;240m", # 灰色
@@ -420,7 +406,7 @@ MODULE_COLORS = {
"tts_action": "\033[38;5;58m", # 深黄色
"doubao_pic_plugin": "\033[38;5;64m", # 深绿色
# Action组件
"no_action_action": "\033[38;5;214m", # 亮橙色,显眼但不像警告
"no_reply_action": "\033[38;5;214m", # 亮橙色,显眼但不像警告
"reply_action": "\033[38;5;46m", # 亮绿色
"base_action": "\033[38;5;250m", # 浅灰色
# 数据库和消息
@@ -433,9 +419,7 @@ MODULE_COLORS = {
"model_utils": "\033[38;5;164m", # 紫红色
"relationship_fetcher": "\033[38;5;170m", # 浅紫色
"relationship_builder": "\033[38;5;93m", # 浅蓝色
# s4u
"context_web_api": "\033[38;5;240m", # 深灰色
"S4U_chat": "\033[92m", # 深灰色
"conflict_tracker": "\033[38;5;82m", # 柔和的粉色,不显眼但保持粉色系
}
# 定义模块别名映射 - 将真实的logger名称映射到显示的别名
@@ -447,10 +431,8 @@ MODULE_ALIASES = {
"llm_api": "生成API",
"emoji": "表情包",
"emoji_api": "表情包API",
"chat": "所见",
"chat_image": "识图",
"action_manager": "动作",
"memory_activator": "记忆",
"tool_use": "工具",
@@ -460,7 +442,6 @@ MODULE_ALIASES = {
"memory": "记忆",
"tool_executor": "工具",
"hfc": "聊天节奏",
"plugin_manager": "插件",
"relationship_builder": "关系",
"llm_models": "模型",

View File

@@ -81,7 +81,8 @@ def find_messages(
query = query.where(Messages.user_id != global_config.bot.qq_account)
if filter_command:
query = query.where(not Messages.is_command)
# 使用按位取反构造 Peewee 的 NOT 条件,避免直接与 False 比较
query = query.where(~Messages.is_command)
if limit > 0:
if limit_mode == "earliest":

View File

@@ -102,9 +102,6 @@ class ModelTaskConfig(ConfigBase):
replyer: TaskConfig
"""normal_chat首要回复模型模型配置"""
emotion: TaskConfig
"""情绪模型配置"""
vlm: TaskConfig
"""视觉语言模型配置"""
@@ -117,9 +114,6 @@ class ModelTaskConfig(ConfigBase):
planner: TaskConfig
"""规划模型配置"""
planner_small: TaskConfig
"""副规划模型配置"""
embedding: TaskConfig
"""嵌入模型配置"""

View File

@@ -18,8 +18,6 @@ from src.config.official_configs import (
ExpressionConfig,
ChatConfig,
EmojiConfig,
MemoryConfig,
MoodConfig,
KeywordReactionConfig,
ChineseTypoConfig,
ResponsePostProcessConfig,
@@ -32,8 +30,9 @@ from src.config.official_configs import (
RelationshipConfig,
ToolConfig,
VoiceConfig,
MoodConfig,
MemoryConfig,
DebugConfig,
CustomPromptConfig,
)
from .api_ada_configs import (
@@ -56,7 +55,7 @@ TEMPLATE_DIR = os.path.join(PROJECT_ROOT, "template")
# 考虑到实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码
# 对该字段的更新请严格参照语义化版本规范https://semver.org/lang/zh-CN/
MMC_VERSION = "0.10.2-snapshot.3"
MMC_VERSION = "0.11.0-snapshot.3"
def get_key_comment(toml_table, key):
@@ -114,7 +113,7 @@ def set_value_by_path(d, path, value):
if k not in d or not isinstance(d[k], dict):
d[k] = {}
d = d[k]
# 使用 tomlkit.item 来保持 TOML 格式
try:
d[path[-1]] = tomlkit.item(value)
@@ -175,13 +174,8 @@ def _update_dict(target: TOMLDocument | dict | Table, source: TOMLDocument | dic
_update_dict(target_value, value)
else:
try:
# 对数组类型进行特殊处理
if isinstance(value, list):
# 如果是空数组,确保它保持为空数组
target[key] = tomlkit.array(str(value)) if value else tomlkit.array()
else:
# 其他类型使用item方法创建新值
target[key] = tomlkit.item(value)
# 统一使用 tomlkit.item 来保持原生类型与转义,不对列表做字符串化处理
target[key] = tomlkit.item(value)
except (TypeError, ValueError):
# 如果转换失败,直接赋值
target[key] = value
@@ -253,7 +247,7 @@ def _update_config_generic(config_name: str, template_name: str):
f"已自动将{config_name}配置 {'.'.join(path)} 的值从旧默认值 {old_default} 更新为新默认值 {new_default}"
)
config_updated = True
# 如果配置有更新,立即保存到文件
if config_updated:
with open(old_config_path, "w", encoding="utf-8") as f:
@@ -347,8 +341,6 @@ class Config(ConfigBase):
message_receive: MessageReceiveConfig
emoji: EmojiConfig
expression: ExpressionConfig
memory: MemoryConfig
mood: MoodConfig
keyword_reaction: KeywordReactionConfig
chinese_typo: ChineseTypoConfig
response_post_process: ResponsePostProcessConfig
@@ -358,8 +350,9 @@ class Config(ConfigBase):
maim_message: MaimMessageConfig
lpmm_knowledge: LPMMKnowledgeConfig
tool: ToolConfig
memory: MemoryConfig
debug: DebugConfig
custom_prompt: CustomPromptConfig
mood: MoodConfig
voice: VoiceConfig

View File

@@ -2,6 +2,7 @@ import re
from dataclasses import dataclass, field
from typing import Literal, Optional
import time
from src.config.config_base import ConfigBase
@@ -38,15 +39,22 @@ class PersonalityConfig(ConfigBase):
personality: str
"""人格"""
emotion_style: str
"""情感特征"""
reply_style: str = ""
"""表达风格"""
interest: str = ""
"""兴趣"""
plan_style: str = ""
"""说话规则,行为风格"""
visual_style: str = ""
"""图片提示词"""
private_plan_style: str = ""
"""私聊说话规则,行为风格"""
@dataclass
class RelationshipConfig(ConfigBase):
"""关系配置类"""
@@ -61,56 +69,221 @@ class ChatConfig(ConfigBase):
max_context_size: int = 18
"""上下文长度"""
interest_rate_mode: Literal["fast", "accurate"] = "fast"
"""兴趣值计算模式fast为快速计算accurate为精确计算"""
mentioned_bot_reply: float = 1
"""提及 bot 必然回复1为100%回复0为不额外增幅"""
planner_size: float = 1.5
"""副规划器大小越小麦麦的动作执行能力越精细但是消耗更多token调大可以缓解429类错误"""
mentioned_bot_reply: bool = True
"""是否启用提及必回复"""
auto_chat_value: float = 1
"""自动聊天,越小,麦麦主动聊天的概率越低"""
at_bot_inevitable_reply: float = 1
"""@bot 必然回复1为100%回复0为不额外增幅"""
talk_frequency: float = 0.5
"""回复频率阈值"""
# 合并后的时段频率配置
talk_frequency_adjust: list[list[str]] = field(default_factory=lambda: [])
planner_smooth: float = 3
"""规划器平滑增大数值会减小planner负荷略微降低反应速度推荐2-50为关闭必须大于等于0"""
talk_value: float = 1
"""思考频率"""
focus_value: float = 0.5
"""麦麦的专注思考能力越低越容易专注消耗token也越多"""
focus_value_adjust: list[list[str]] = field(default_factory=lambda: [])
talk_value_rules: list[dict] = field(default_factory=lambda: [])
"""
统一的活跃度和专注度配置
格式:[["platform:chat_id:type", "HH:MM,frequency", "HH:MM,frequency", ...], ...]
全局配置示例
[["", "8:00,1", "12:00,2", "18:00,1.5", "00:00,0.5"]]
特定聊天流配置示例:
思考频率规则列表,支持按聊天流/按日内时段配置
规则格式:{ target="platform:id:type" "", time="HH:MM-HH:MM", value=0.5 }
示例:
[
["", "8:00,1", "12:00,1.2", "18:00,1.5", "01:00,0.6"], # 全局默认配置
["qq:1026294844:group", "12:20,1", "16:10,2", "20:10,1", "00:10,0.3"], # 特定群聊配置
["qq:729957033:private", "8:20,1", "12:10,2", "20:10,1.5", "00:10,0.2"] # 特定私聊配置
["", "00:00-08:59", 0.2], # 全局规则:凌晨到早上更安静
["", "09:00-22:59", 1.0], # 全局规则:白天正常
["qq:1919810:group", "20:00-23:59", 0.6], # 指定群在晚高峰降低发言
["qq:114514:private", "00:00-23:59", 0.3],# 指定私聊全时段较安静
]
说明:
- 当第一个元素为空字符串""时,表示全局默认配置
- 当第一个元素为"platform:id:type"格式时,表示特定聊天流配置
- 后续元素是"时间,频率"格式,表示从该时间开始使用该频率,直到下一个时间点
- 优先级:特定聊天流配置 > 全局配置 > 默认值
注意:
- talk_frequency_adjust 控制回复频率,数值越高回复越频繁
- focus_value_adjust 控制专注思考能力数值越低越容易专注消耗token也越多
匹配优先级: 先匹配指定 chat 流规则,再匹配全局规则(\"\").
时间区间支持跨夜,例如 "23:00-02:00"
"""
auto_chat_value_rules: list[dict] = field(default_factory=lambda: [])
"""
自动聊天频率规则列表,支持按聊天流/按日内时段配置。
规则格式:{ target="platform:id:type""", time="HH:MM-HH:MM", value=0.5 }
示例:
[
["", "00:00-08:59", 0.2], # 全局规则:凌晨到早上更安静
["", "09:00-22:59", 1.0], # 全局规则:白天正常
["qq:1919810:group", "20:00-23:59", 0.6], # 指定群在晚高峰降低发言
["qq:114514:private", "00:00-23:59", 0.3],# 指定私聊全时段较安静
]
匹配优先级: 先匹配指定 chat 流规则,再匹配全局规则(\"\").
时间区间支持跨夜,例如 "23:00-02:00"
"""
def _parse_stream_config_to_chat_id(self, stream_config_str: str) -> Optional[str]:
"""与 ChatStream.get_stream_id 一致地从 "platform:id:type" 生成 chat_id。"""
try:
parts = stream_config_str.split(":")
if len(parts) != 3:
return None
platform = parts[0]
id_str = parts[1]
stream_type = parts[2]
is_group = stream_type == "group"
import hashlib
if is_group:
components = [platform, str(id_str)]
else:
components = [platform, str(id_str), "private"]
key = "_".join(components)
return hashlib.md5(key.encode()).hexdigest()
except (ValueError, IndexError):
return None
def _now_minutes(self) -> int:
"""返回本地时间的分钟数(0-1439)。"""
lt = time.localtime()
return lt.tm_hour * 60 + lt.tm_min
def _parse_range(self, range_str: str) -> Optional[tuple[int, int]]:
"""解析 "HH:MM-HH:MM" 到 (start_min, end_min)。"""
try:
start_str, end_str = [s.strip() for s in range_str.split("-")]
sh, sm = [int(x) for x in start_str.split(":")]
eh, em = [int(x) for x in end_str.split(":")]
return sh * 60 + sm, eh * 60 + em
except Exception:
return None
def _in_range(self, now_min: int, start_min: int, end_min: int) -> bool:
"""
判断 now_min 是否在 [start_min, end_min] 区间内。
支持跨夜:如果 start > end则表示跨越午夜。
"""
if start_min <= end_min:
return start_min <= now_min <= end_min
# 跨夜:例如 23:00-02:00
return now_min >= start_min or now_min <= end_min
def get_talk_value(self, chat_id: Optional[str]) -> float:
"""根据规则返回当前 chat 的动态 talk_value未匹配则回退到基础值。"""
if not self.talk_value_rules:
return self.talk_value
now_min = self._now_minutes()
# 1) 先尝试匹配指定 chat 的规则
if chat_id:
for rule in self.talk_value_rules:
if not isinstance(rule, dict):
continue
target = rule.get("target", "")
time_range = rule.get("time", "")
value = rule.get("value", None)
if not isinstance(time_range, str):
continue
# 跳过全局
if target == "":
continue
config_chat_id = self._parse_stream_config_to_chat_id(str(target))
if config_chat_id is None or config_chat_id != chat_id:
continue
parsed = self._parse_range(time_range)
if not parsed:
continue
start_min, end_min = parsed
if self._in_range(now_min, start_min, end_min):
try:
return float(value)
except Exception:
continue
# 2) 再匹配全局规则("")
for rule in self.talk_value_rules:
if not isinstance(rule, dict):
continue
target = rule.get("target", None)
time_range = rule.get("time", "")
value = rule.get("value", None)
if target != "" or not isinstance(time_range, str):
continue
parsed = self._parse_range(time_range)
if not parsed:
continue
start_min, end_min = parsed
if self._in_range(now_min, start_min, end_min):
try:
return float(value)
except Exception:
continue
# 3) 未命中规则返回基础值
return self.talk_value
def get_auto_chat_value(self, chat_id: Optional[str]) -> float:
"""根据规则返回当前 chat 的动态 auto_chat_value未匹配则回退到基础值。"""
if not self.auto_chat_value_rules:
return self.auto_chat_value
now_min = self._now_minutes()
# 1) 先尝试匹配指定 chat 的规则
if chat_id:
for rule in self.auto_chat_value_rules:
if not isinstance(rule, dict):
continue
target = rule.get("target", "")
time_range = rule.get("time", "")
value = rule.get("value", None)
if not isinstance(time_range, str):
continue
# 跳过全局
if target == "":
continue
config_chat_id = self._parse_stream_config_to_chat_id(str(target))
if config_chat_id is None or config_chat_id != chat_id:
continue
parsed = self._parse_range(time_range)
if not parsed:
continue
start_min, end_min = parsed
if self._in_range(now_min, start_min, end_min):
try:
return float(value)
except Exception:
continue
# 2) 再匹配全局规则("")
for rule in self.auto_chat_value_rules:
if not isinstance(rule, dict):
continue
target = rule.get("target", None)
time_range = rule.get("time", "")
value = rule.get("value", None)
if target != "" or not isinstance(time_range, str):
continue
parsed = self._parse_range(time_range)
if not parsed:
continue
start_min, end_min = parsed
if self._in_range(now_min, start_min, end_min):
try:
return float(value)
except Exception:
continue
# 3) 未命中规则返回基础值
return self.auto_chat_value
@dataclass
@@ -123,10 +296,23 @@ class MessageReceiveConfig(ConfigBase):
ban_msgs_regex: set[str] = field(default_factory=lambda: set())
"""过滤正则表达式列表"""
@dataclass
class MemoryConfig(ConfigBase):
"""记忆配置类"""
max_memory_number: int = 100
"""记忆最大数量"""
memory_build_frequency: int = 1
"""记忆构建频率"""
@dataclass
class ExpressionConfig(ConfigBase):
"""表达配置类"""
mode: Literal["llm", "context", "full-context"] = "context"
"""表达方式模式可选llm模式context上下文模式full-context 完整上下文嵌入模式"""
learning_list: list[list] = field(default_factory=lambda: [])
"""
表达学习配置列表,支持按聊天流配置
@@ -287,6 +473,19 @@ class ToolConfig(ConfigBase):
"""是否在聊天中启用工具"""
@dataclass
class MoodConfig(ConfigBase):
"""情绪配置类"""
enable_mood: bool = True
"""是否启用情绪系统"""
mood_update_threshold: float = 1
"""情绪更新阈值,越高,更新越慢"""
emotion_style: str = "情绪较为稳定,但遭遇特定事件的时候起伏较大"
"""情感特征,影响情绪的变化情况"""
@dataclass
class VoiceConfig(ConfigBase):
"""语音识别配置类"""
@@ -321,37 +520,6 @@ class EmojiConfig(ConfigBase):
"""表情包过滤要求"""
@dataclass
class MemoryConfig(ConfigBase):
"""记忆配置类"""
enable_memory: bool = True
"""是否启用记忆系统"""
forget_memory_interval: int = 1500
"""记忆遗忘间隔(秒)"""
memory_forget_time: int = 24
"""记忆遗忘时间(小时)"""
memory_forget_percentage: float = 0.01
"""记忆遗忘比例"""
memory_ban_words: list[str] = field(default_factory=lambda: ["表情包", "图片", "回复", "聊天记录"])
"""不允许记忆的词列表"""
@dataclass
class MoodConfig(ConfigBase):
"""情绪配置类"""
enable_mood: bool = False
"""是否启用情绪系统"""
mood_update_threshold: float = 1.0
"""情绪更新阈值,越高,更新越慢"""
@dataclass
class KeywordRuleConfig(ConfigBase):
"""关键词规则配置类"""
@@ -399,14 +567,6 @@ class KeywordReactionConfig(ConfigBase):
raise ValueError(f"规则必须是KeywordRuleConfig类型而不是{type(rule).__name__}")
@dataclass
class CustomPromptConfig(ConfigBase):
"""自定义提示词配置类"""
image_prompt: str = ""
"""图片提示词"""
@dataclass
class ResponsePostProcessConfig(ConfigBase):
"""回复后处理配置类"""
@@ -475,9 +635,6 @@ class ExperimentalConfig(ConfigBase):
enable_friend_chat: bool = False
"""是否启用好友聊天"""
pfc_chatting: bool = False
"""是否启用PFC"""
@dataclass
class MaimMessageConfig(ConfigBase):

View File

@@ -65,39 +65,6 @@ class RespParseException(Exception):
return self.message or "解析响应内容时发生未知错误,请检查是否配置了正确的解析方法"
class PayLoadTooLargeError(Exception):
"""自定义异常类,用于处理请求体过大错误"""
def __init__(self, message: str):
super().__init__(message)
self.message = message
def __str__(self):
return "请求体过大,请尝试压缩图片或减少输入内容。"
class RequestAbortException(Exception):
"""自定义异常类,用于处理请求中断异常"""
def __init__(self, message: str):
super().__init__(message)
self.message = message
def __str__(self):
return self.message
class PermissionDeniedException(Exception):
"""自定义异常类,用于处理访问拒绝的异常"""
def __init__(self, message: str):
super().__init__(message)
self.message = message
def __str__(self):
return self.message
class EmptyResponseException(Exception):
"""响应内容为空"""
@@ -107,3 +74,15 @@ class EmptyResponseException(Exception):
def __str__(self):
return self.message
class ModelAttemptFailed(Exception):
"""当在单个模型上的所有重试都失败后,由“执行者”函数抛出,以通知“调度器”切换模型。"""
def __init__(self, message: str, original_exception: Exception | None = None):
super().__init__(message)
self.message = message
self.original_exception = original_exception
def __str__(self):
return self.message

View File

@@ -72,8 +72,8 @@ class BaseClient(ABC):
model_info: ModelInfo,
message_list: list[Message],
tool_options: list[ToolOption] | None = None,
max_tokens: int = 1024,
temperature: float = 0.7,
max_tokens: Optional[int] = None,
temperature: Optional[float] = None,
response_format: RespFormat | None = None,
stream_response_handler: Optional[
Callable[[Any, asyncio.Event | None], tuple[APIResponse, tuple[int, int, int]]]
@@ -117,6 +117,7 @@ class BaseClient(ABC):
self,
model_info: ModelInfo,
audio_base64: str,
max_tokens: Optional[int] = None,
extra_params: dict[str, Any] | None = None,
) -> APIResponse:
"""
@@ -174,7 +175,7 @@ class ClientRegistry:
return client_class(api_provider)
else:
raise KeyError(f"'{api_provider.client_type}' 类型的 Client 未注册")
# 正常的缓存逻辑
if api_provider.name not in self.client_instance_cache:
if client_class := self.client_registry.get(api_provider.client_type):

View File

@@ -1,7 +1,7 @@
import asyncio
import io
import base64
from typing import Callable, AsyncIterator, Optional, Coroutine, Any, List
from typing import Callable, AsyncIterator, Optional, Coroutine, Any, List, Dict
from google import genai
from google.genai.types import (
@@ -17,6 +17,7 @@ from google.genai.types import (
EmbedContentResponse,
EmbedContentConfig,
SafetySetting,
HttpOptions,
HarmCategory,
HarmBlockThreshold,
)
@@ -182,6 +183,14 @@ def _process_delta(
if delta.text:
fc_delta_buffer.write(delta.text)
# 处理 thoughtGemini 的特殊字段)
for c in getattr(delta, "candidates", []):
if c.content and getattr(c.content, "parts", None):
for p in c.content.parts:
if getattr(p, "thought", False) and getattr(p, "text", None):
# 把 thought 写入 buffer避免 resp.content 永远为空
fc_delta_buffer.write(p.text)
if delta.function_calls: # 为什么不用hasattr呢是因为这个属性一定有即使是个空的
for call in delta.function_calls:
try:
@@ -203,6 +212,7 @@ def _process_delta(
def _build_stream_api_resp(
_fc_delta_buffer: io.StringIO,
_tool_calls_buffer: list[tuple[str, str, dict]],
last_resp: GenerateContentResponse | None = None, # 传入 last_resp
) -> APIResponse:
# sourcery skip: simplify-len-comparison, use-assigned-variable
resp = APIResponse()
@@ -227,6 +237,21 @@ def _build_stream_api_resp(
resp.tool_calls.append(ToolCall(call_id, function_name, arguments))
# 检查是否因为 max_tokens 截断
reason = None
if last_resp and getattr(last_resp, "candidates", None):
c0 = last_resp.candidates[0]
reason = getattr(c0, "finish_reason", None) or getattr(c0, "finishReason", None)
if str(reason).endswith("MAX_TOKENS"):
if resp.content and resp.content.strip():
logger.warning(
"⚠ Gemini 响应因达到 max_tokens 限制被部分截断,\n"
" 可能会对回复内容造成影响,建议修改模型 max_tokens 配置!"
)
else:
logger.warning("⚠ Gemini 响应因达到 max_tokens 限制被截断,\n 请修改模型 max_tokens 配置!")
if not resp.content and not resp.tool_calls:
raise EmptyResponseException()
@@ -245,12 +270,14 @@ async def _default_stream_response_handler(
_fc_delta_buffer = io.StringIO() # 正式内容缓冲区,用于存储接收到的正式内容
_tool_calls_buffer: list[tuple[str, str, dict]] = [] # 工具调用缓冲区,用于存储接收到的工具调用
_usage_record = None # 使用情况记录
last_resp: GenerateContentResponse | None = None # 保存最后一个 chunk
def _insure_buffer_closed():
if _fc_delta_buffer and not _fc_delta_buffer.closed:
_fc_delta_buffer.close()
async for chunk in resp_stream:
last_resp = chunk # 保存最后一个响应
# 检查是否有中断量
if interrupt_flag and interrupt_flag.is_set():
# 如果中断量被设置则抛出ReqAbortException
@@ -269,10 +296,12 @@ async def _default_stream_response_handler(
(chunk.usage_metadata.candidates_token_count or 0) + (chunk.usage_metadata.thoughts_token_count or 0),
chunk.usage_metadata.total_token_count or 0,
)
try:
return _build_stream_api_resp(
_fc_delta_buffer,
_tool_calls_buffer,
last_resp=last_resp,
), _usage_record
except Exception:
# 确保缓冲区被关闭
@@ -332,6 +361,35 @@ def _default_normal_response_parser(
api_response.raw_data = resp
# 检查是否因为 max_tokens 截断
try:
if resp.candidates:
c0 = resp.candidates[0]
reason = getattr(c0, "finish_reason", None) or getattr(c0, "finishReason", None)
if reason and "MAX_TOKENS" in str(reason):
# 检查第二个及之后的 parts 是否有内容
has_real_output = False
if getattr(c0, "content", None) and getattr(c0.content, "parts", None):
for p in c0.content.parts[1:]: # 跳过第一个 thought
if getattr(p, "text", None) and p.text.strip():
has_real_output = True
break
if not has_real_output and getattr(resp, "text", None):
has_real_output = True
if has_real_output:
logger.warning(
"⚠ Gemini 响应因达到 max_tokens 限制被部分截断,\n"
" 可能会对回复内容造成影响,建议修改模型 max_tokens 配置!"
)
else:
logger.warning("⚠ Gemini 响应因达到 max_tokens 限制被截断,\n 请修改模型 max_tokens 配置!")
return api_response, _usage_record
except Exception as e:
logger.debug(f"检查 MAX_TOKENS 截断时异常: {e}")
# 最终的、唯一的空响应检查
if not api_response.content and not api_response.tool_calls:
raise EmptyResponseException("响应中既无文本内容也无工具调用")
@@ -345,17 +403,45 @@ class GeminiClient(BaseClient):
def __init__(self, api_provider: APIProvider):
super().__init__(api_provider)
# 增加传入参数处理
http_options_kwargs: Dict[str, Any] = {}
# 秒转换为毫秒传入
if api_provider.timeout is not None:
http_options_kwargs["timeout"] = int(api_provider.timeout * 1000)
# 传入并处理地址和版本(必须为Gemini格式)
if api_provider.base_url:
parts = api_provider.base_url.rstrip("/").rsplit("/", 1)
if len(parts) == 2 and parts[1].startswith("v"):
http_options_kwargs["base_url"] = f"{parts[0]}/"
http_options_kwargs["api_version"] = parts[1]
else:
http_options_kwargs["base_url"] = api_provider.base_url
http_options_kwargs["api_version"] = None
self.client = genai.Client(
http_options=HttpOptions(**http_options_kwargs),
api_key=api_provider.api_key,
) # 这里和openai不一样gemini会自己决定自己是否需要retry
@staticmethod
def clamp_thinking_budget(tb: int, model_id: str) -> int:
def clamp_thinking_budget(extra_params: dict[str, Any] | None, model_id: str) -> int:
"""
按模型限制思考预算范围,仅支持指定的模型(支持带数字后缀的新版本)
"""
limits = None
# 参数传入处理
tb = THINKING_BUDGET_AUTO
if extra_params and "thinking_budget" in extra_params:
try:
tb = int(extra_params["thinking_budget"])
except (ValueError, TypeError):
logger.warning(
f"无效的 thinking_budget 值 {extra_params['thinking_budget']},将使用模型自动预算模式 {tb}"
)
# 优先尝试精确匹配
if model_id in THINKING_BUDGET_LIMITS:
limits = THINKING_BUDGET_LIMITS[model_id]
@@ -368,20 +454,29 @@ class GeminiClient(BaseClient):
limits = THINKING_BUDGET_LIMITS[key]
break
# 特殊值处理
# 预算值处理
if tb == THINKING_BUDGET_AUTO:
return THINKING_BUDGET_AUTO
if tb == THINKING_BUDGET_DISABLED:
if limits and limits.get("can_disable", False):
return THINKING_BUDGET_DISABLED
return limits["min"] if limits else THINKING_BUDGET_AUTO
if limits:
logger.warning(f"模型 {model_id} 不支持禁用思考预算,已回退到最小值 {limits['min']}")
return limits["min"]
return THINKING_BUDGET_AUTO
# 已知模型裁剪到范围
# 已知模型范围裁剪 + 提示
if limits:
return max(limits["min"], min(tb, limits["max"]))
if tb < limits["min"]:
logger.warning(f"模型 {model_id} 的 thinking_budget={tb} 过小,已调整为最小值 {limits['min']}")
return limits["min"]
if tb > limits["max"]:
logger.warning(f"模型 {model_id} 的 thinking_budget={tb} 过大,已调整为最大值 {limits['max']}")
return limits["max"]
return tb
# 未知模型,返回动态模式
logger.warning(f"模型 {model_id} 未在 THINKING_BUDGET_LIMITS 中定义,将使用动态模式 tb=-1 兼容")
# 未知模型 → 默认自动模式
logger.warning(f"模型 {model_id} 未在 THINKING_BUDGET_LIMITS 中定义,已启用模型自动预算兼容")
return THINKING_BUDGET_AUTO
async def get_response(
@@ -389,8 +484,8 @@ class GeminiClient(BaseClient):
model_info: ModelInfo,
message_list: list[Message],
tool_options: list[ToolOption] | None = None,
max_tokens: int = 1024,
temperature: float = 0.4,
max_tokens: Optional[int] = 1024,
temperature: Optional[float] = 0.4,
response_format: RespFormat | None = None,
stream_response_handler: Optional[
Callable[
@@ -429,16 +524,8 @@ class GeminiClient(BaseClient):
messages = _convert_messages(message_list)
# 将tool_options转换为Gemini API所需的格式
tools = _convert_tool_options(tool_options) if tool_options else None
tb = THINKING_BUDGET_AUTO
# 空处理
if extra_params and "thinking_budget" in extra_params:
try:
tb = int(extra_params["thinking_budget"])
except (ValueError, TypeError):
logger.warning(f"无效的 thinking_budget 值 {extra_params['thinking_budget']},将使用默认动态模式 {tb}")
# 裁剪到模型支持的范围
tb = self.clamp_thinking_budget(tb, model_info.model_identifier)
# 解析并裁剪 thinking_budget
tb = self.clamp_thinking_budget(extra_params, model_info.model_identifier)
# 将response_format转换为Gemini API所需的格式
generation_config_dict = {
@@ -497,15 +584,20 @@ class GeminiClient(BaseClient):
resp, usage_record = async_response_parser(req_task.result())
except (ClientError, ServerError) as e:
# 重封装ClientErrorServerErrorRespNotOkException
# 重封装 ClientErrorServerErrorRespNotOkException
raise RespNotOkException(e.code, e.message) from None
except (
UnknownFunctionCallArgumentError,
UnsupportedFunctionError,
FunctionInvocationError,
) as e:
raise ValueError(f"工具类型错误:请检查工具选项和参数:{str(e)}") from None
# 工具调用相关错误
raise RespParseException(None, f"工具调用参数错误: {str(e)}") from None
except EmptyResponseException as e:
# 保持原始异常,便于区分“空响应”和网络异常
raise e
except Exception as e:
# 其他未预料的错误,才归为网络连接类
raise NetworkConnectionError() from e
if usage_record:
@@ -561,41 +653,51 @@ class GeminiClient(BaseClient):
return response
def get_audio_transcriptions(
self, model_info: ModelInfo, audio_base64: str, extra_params: dict[str, Any] | None = None
async def get_audio_transcriptions(
self,
model_info: ModelInfo,
audio_base64: str,
max_tokens: Optional[int] = 2048,
extra_params: dict[str, Any] | None = None,
) -> APIResponse:
"""
获取音频转录
:param model_info: 模型信息
:param audio_base64: 音频文件的Base64编码字符串
:param max_tokens: 最大输出token数默认2048
:param extra_params: 额外参数(可选)
:return: 转录响应
"""
# 解析并裁剪 thinking_budget
tb = self.clamp_thinking_budget(extra_params, model_info.model_identifier)
# 构造 prompt + 音频输入
prompt = "Generate a transcript of the speech. The language of the transcript should **match the language of the speech**."
contents = [
Content(
role="user",
parts=[
Part.from_text(text=prompt),
Part.from_bytes(data=base64.b64decode(audio_base64), mime_type="audio/wav"),
],
)
]
generation_config_dict = {
"max_output_tokens": 2048,
"max_output_tokens": max_tokens,
"response_modalities": ["TEXT"],
"thinking_config": ThinkingConfig(
include_thoughts=True,
thinking_budget=(
extra_params["thinking_budget"] if extra_params and "thinking_budget" in extra_params else 1024
),
thinking_budget=tb,
),
"safety_settings": gemini_safe_settings,
}
generate_content_config = GenerateContentConfig(**generation_config_dict)
prompt = "Generate a transcript of the speech. The language of the transcript should **match the language of the speech**."
try:
raw_response: GenerateContentResponse = self.client.models.generate_content(
raw_response: GenerateContentResponse = await self.client.aio.models.generate_content(
model=model_info.model_identifier,
contents=[
Content(
role="user",
parts=[
Part.from_text(text=prompt),
Part.from_bytes(data=base64.b64decode(audio_base64), mime_type="audio/wav"),
],
)
],
contents=contents,
config=generate_content_config,
)
resp, usage_record = _default_normal_response_parser(raw_response)

View File

@@ -403,8 +403,8 @@ class OpenaiClient(BaseClient):
model_info: ModelInfo,
message_list: list[Message],
tool_options: list[ToolOption] | None = None,
max_tokens: int = 1024,
temperature: float = 0.7,
max_tokens: Optional[int] = 1024,
temperature: Optional[float] = 0.7,
response_format: RespFormat | None = None,
stream_response_handler: Optional[
Callable[
@@ -488,6 +488,9 @@ class OpenaiClient(BaseClient):
raise ReqAbortException("请求被外部信号中断")
await asyncio.sleep(0.1) # 等待0.5秒后再次检查任务&中断信号量状态
# logger.
logger.debug(f"OpenAI API响应(非流式): {req_task.result()}")
# logger.info(f"OpenAI请求时间: {model_info.model_identifier} {time.time() - start_time} \n{messages}")
resp, usage_record = async_response_parser(req_task.result())
@@ -507,6 +510,8 @@ class OpenaiClient(BaseClient):
total_tokens=usage_record[2],
)
# logger.debug(f"OpenAI API响应: {resp}")
return resp
async def get_embedding(
@@ -531,7 +536,7 @@ class OpenaiClient(BaseClient):
# 添加详细的错误信息以便调试
logger.error(f"OpenAI API连接错误嵌入模型: {str(e)}")
logger.error(f"错误类型: {type(e)}")
if hasattr(e, '__cause__') and e.__cause__:
if hasattr(e, "__cause__") and e.__cause__:
logger.error(f"底层错误: {str(e.__cause__)}")
raise NetworkConnectionError() from e
except APIStatusError as e:
@@ -555,7 +560,7 @@ class OpenaiClient(BaseClient):
model_name=model_info.name,
provider_name=model_info.api_provider,
prompt_tokens=raw_response.usage.prompt_tokens or 0,
completion_tokens=raw_response.usage.completion_tokens or 0, # type: ignore
completion_tokens=getattr(raw_response.usage, "completion_tokens", 0),
total_tokens=raw_response.usage.total_tokens or 0,
)

View File

@@ -1,3 +1,3 @@
from .tool_option import ToolCall
__all__ = ["ToolCall"]
__all__ = ["ToolCall"]

View File

@@ -48,8 +48,7 @@ def _json_schema_type_check(instance) -> str | None:
elif not isinstance(instance["name"], str) or instance["name"].strip() == "":
return "schema的'name'字段必须是非空字符串"
if "description" in instance and (
not isinstance(instance["description"], str)
or instance["description"].strip() == ""
not isinstance(instance["description"], str) or instance["description"].strip() == ""
):
return "schema的'description'字段只能填入非空字符串"
if "schema" not in instance:
@@ -101,9 +100,7 @@ def _link_definitions(schema: dict[str, Any]) -> dict[str, Any]:
# 如果当前Schema是列表则遍历每个元素
for i in range(len(sub_schema)):
if isinstance(sub_schema[i], dict):
sub_schema[i] = link_definitions_recursive(
f"{path}/{str(i)}", sub_schema[i], defs
)
sub_schema[i] = link_definitions_recursive(f"{path}/{str(i)}", sub_schema[i], defs)
else:
# 否则为字典
if "$defs" in sub_schema:
@@ -125,9 +122,7 @@ def _link_definitions(schema: dict[str, Any]) -> dict[str, Any]:
for key, value in sub_schema.items():
if isinstance(value, (dict, list)):
# 如果当前值是字典或列表,则递归调用
sub_schema[key] = link_definitions_recursive(
f"{path}/{key}", value, defs
)
sub_schema[key] = link_definitions_recursive(f"{path}/{key}", value, defs)
return sub_schema
@@ -163,9 +158,7 @@ class RespFormat:
def _generate_schema_from_model(schema):
json_schema = {
"name": schema.__name__,
"schema": _remove_defs(
_link_definitions(_remove_title(schema.model_json_schema()))
),
"schema": _remove_defs(_link_definitions(_remove_title(schema.model_json_schema()))),
"strict": False,
}
if schema.__doc__:

View File

@@ -155,7 +155,13 @@ class LLMUsageRecorder:
logger.error(f"创建 LLMUsage 表失败: {str(e)}")
def record_usage_to_database(
self, model_info: ModelInfo, model_usage: UsageRecord, user_id: str, request_type: str, endpoint: str, time_cost: float = 0.0
self,
model_info: ModelInfo,
model_usage: UsageRecord,
user_id: str,
request_type: str,
endpoint: str,
time_cost: float = 0.0,
):
input_cost = (model_usage.prompt_tokens / 1000000) * model_info.price_in
output_cost = (model_usage.completion_tokens / 1000000) * model_info.price_out
@@ -173,7 +179,7 @@ class LLMUsageRecorder:
completion_tokens=model_usage.completion_tokens or 0,
total_tokens=model_usage.total_tokens or 0,
cost=total_cost or 0.0,
time_cost = round(time_cost or 0.0, 3),
time_cost=round(time_cost or 0.0, 3),
status="success",
timestamp=datetime.now(), # Peewee 会处理 DateTimeField
)
@@ -186,4 +192,5 @@ class LLMUsageRecorder:
except Exception as e:
logger.error(f"记录token使用情况失败: {str(e)}")
llm_usage_recorder = LLMUsageRecorder()
llm_usage_recorder = LLMUsageRecorder()

View File

@@ -4,7 +4,8 @@ import time
from enum import Enum
from rich.traceback import install
from typing import Tuple, List, Dict, Optional, Callable, Any
from typing import Tuple, List, Dict, Optional, Callable, Any, Set
import traceback
from src.common.logger import get_logger
from src.config.config import model_config
@@ -16,28 +17,15 @@ from .model_client.base_client import BaseClient, APIResponse, client_registry
from .utils import compress_messages, llm_usage_recorder
from .exceptions import (
NetworkConnectionError,
ReqAbortException,
RespNotOkException,
RespParseException,
EmptyResponseException,
ModelAttemptFailed,
)
install(extra_lines=3)
logger = get_logger("model_utils")
# 常见Error Code Mapping
error_code_mapping = {
400: "参数不正确",
401: "API key 错误,认证失败,请检查 config/model_config.toml 中的配置是否正确",
402: "账号余额不足",
403: "需要实名,或余额不足",
404: "Not Found",
429: "请求过于频繁,请稍后再试",
500: "服务器内部故障",
503: "服务器负载过高",
}
class RequestType(Enum):
"""请求类型枚举"""
@@ -76,32 +64,25 @@ class LLMRequest:
Returns:
(Tuple[str, str, str, Optional[List[ToolCall]]]): 响应内容、推理内容、模型名称、工具调用列表
"""
# 模型选择
start_time = time.time()
model_info, api_provider, client = self._select_model()
# 请求体构建
message_builder = MessageBuilder()
message_builder.add_text_content(prompt)
message_builder.add_image_content(
image_base64=image_base64, image_format=image_format, support_formats=client.get_support_image_formats()
)
messages = [message_builder.build()]
def message_factory(client: BaseClient) -> List[Message]:
message_builder = MessageBuilder()
message_builder.add_text_content(prompt)
message_builder.add_image_content(
image_base64=image_base64, image_format=image_format, support_formats=client.get_support_image_formats()
)
return [message_builder.build()]
# 请求并处理返回值
response = await self._execute_request(
api_provider=api_provider,
client=client,
response, model_info = await self._execute_request(
request_type=RequestType.RESPONSE,
model_info=model_info,
message_list=messages,
message_factory=message_factory,
temperature=temperature,
max_tokens=max_tokens,
)
content = response.content or ""
reasoning_content = response.reasoning_content or ""
tool_calls = response.tool_calls
# 从内容中提取<think>标签的推理内容(向后兼容)
if not reasoning_content and content:
content, extracted_reasoning = self._extract_reasoning(content)
reasoning_content = extracted_reasoning
@@ -124,15 +105,8 @@ class LLMRequest:
Returns:
(Optional[str]): 生成的文本描述或None
"""
# 模型选择
model_info, api_provider, client = self._select_model()
# 请求并处理返回值
response = await self._execute_request(
api_provider=api_provider,
client=client,
response, _ = await self._execute_request(
request_type=RequestType.AUDIO,
model_info=model_info,
audio_base64=voice_base64,
)
return response.content or None
@@ -151,43 +125,37 @@ class LLMRequest:
prompt (str): 提示词
temperature (float, optional): 温度参数
max_tokens (int, optional): 最大token数
tools (Optional[List[Dict[str, Any]]]): 工具列表
raise_when_empty (bool): 当响应为空时是否抛出异常
Returns:
(Tuple[str, str, str, Optional[List[ToolCall]]]): 响应内容、推理内容、模型名称、工具调用列表
"""
# 请求体构建
start_time = time.time()
message_builder = MessageBuilder()
message_builder.add_text_content(prompt)
messages = [message_builder.build()]
def message_factory(client: BaseClient) -> List[Message]:
message_builder = MessageBuilder()
message_builder.add_text_content(prompt)
return [message_builder.build()]
tool_built = self._build_tool_options(tools)
# 模型选择
model_info, api_provider, client = self._select_model()
# 请求并处理返回值
logger.debug(f"LLM选择耗时: {model_info.name} {time.time() - start_time}")
response = await self._execute_request(
api_provider=api_provider,
client=client,
response, model_info = await self._execute_request(
request_type=RequestType.RESPONSE,
model_info=model_info,
message_list=messages,
message_factory=message_factory,
temperature=temperature,
max_tokens=max_tokens,
tool_options=tool_built,
)
logger.debug(f"LLM请求总耗时: {time.time() - start_time}")
logger.debug(f"LLM生成内容: {response}")
content = response.content
reasoning_content = response.reasoning_content or ""
tool_calls = response.tool_calls
# 从内容中提取<think>标签的推理内容(向后兼容)
if not reasoning_content and content:
content, extracted_reasoning = self._extract_reasoning(content)
reasoning_content = extracted_reasoning
if usage := response.usage:
llm_usage_recorder.record_usage_to_database(
model_info=model_info,
@@ -197,31 +165,22 @@ class LLMRequest:
endpoint="/chat/completions",
time_cost=time.time() - start_time,
)
return content or "", (reasoning_content, model_info.name, tool_calls)
async def get_embedding(self, embedding_input: str) -> Tuple[List[float], str]:
"""获取嵌入向量
"""
获取嵌入向量
Args:
embedding_input (str): 获取嵌入的目标
Returns:
(Tuple[List[float], str]): (嵌入向量,使用的模型名称)
"""
# 无需构建消息体,直接使用输入文本
start_time = time.time()
model_info, api_provider, client = self._select_model()
# 请求并处理返回值
response = await self._execute_request(
api_provider=api_provider,
client=client,
response, model_info = await self._execute_request(
request_type=RequestType.EMBEDDING,
model_info=model_info,
embedding_input=embedding_input,
)
embedding = response.embedding
if usage := response.usage:
llm_usage_recorder.record_usage_to_database(
model_info=model_info,
@@ -231,59 +190,61 @@ class LLMRequest:
endpoint="/embeddings",
time_cost=time.time() - start_time,
)
if not embedding:
raise RuntimeError("获取embedding失败")
return embedding, model_info.name
def _select_model(self) -> Tuple[ModelInfo, APIProvider, BaseClient]:
def _select_model(self, exclude_models: Optional[Set[str]] = None) -> Tuple[ModelInfo, APIProvider, BaseClient]:
"""
根据总tokens和惩罚值选择的模型
"""
available_models = {
model: scores
for model, scores in self.model_usage.items()
if not exclude_models or model not in exclude_models
}
if not available_models:
raise RuntimeError("没有可用的模型可供选择。所有模型均已尝试失败。")
least_used_model_name = min(
self.model_usage,
key=lambda k: self.model_usage[k][0] + self.model_usage[k][1] * 300 + self.model_usage[k][2] * 1000,
available_models,
key=lambda k: available_models[k][0] + available_models[k][1] * 300 + available_models[k][2] * 1000,
)
model_info = model_config.get_model_info(least_used_model_name)
api_provider = model_config.get_provider(model_info.api_provider)
# 对于嵌入任务,强制创建新的客户端实例以避免事件循环问题
force_new_client = self.request_type == "embedding"
client = client_registry.get_client_class_instance(api_provider, force_new=force_new_client)
logger.debug(f"选择请求模型: {model_info.name}")
total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty + 1) # 增加使用惩罚值防止连续使用
self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty + 1)
return model_info, api_provider, client
async def _execute_request(
async def _attempt_request_on_model(
self,
model_info: ModelInfo,
api_provider: APIProvider,
client: BaseClient,
request_type: RequestType,
model_info: ModelInfo,
message_list: List[Message] | None = None,
tool_options: list[ToolOption] | None = None,
response_format: RespFormat | None = None,
stream_response_handler: Optional[Callable] = None,
async_response_parser: Optional[Callable] = None,
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
embedding_input: str = "",
audio_base64: str = "",
message_list: List[Message],
tool_options: list[ToolOption] | None,
response_format: RespFormat | None,
stream_response_handler: Optional[Callable],
async_response_parser: Optional[Callable],
temperature: Optional[float],
max_tokens: Optional[int],
embedding_input: str | None,
audio_base64: str | None,
) -> APIResponse:
"""
实际执行请求的方法
包含了重试和异常处理逻辑
在单个模型上执行请求,包含针对临时错误的重试逻辑。
如果成功返回APIResponse。如果失败重试耗尽或硬错误则抛出ModelAttemptFailed异常。
"""
retry_remain = api_provider.max_retry
compressed_messages: Optional[List[Message]] = None
while retry_remain > 0:
try:
if request_type == RequestType.RESPONSE:
assert message_list is not None, "message_list cannot be None for response requests"
return await client.get_response(
model_info=model_info,
message_list=(compressed_messages or message_list),
@@ -296,201 +257,125 @@ class LLMRequest:
extra_params=model_info.extra_params,
)
elif request_type == RequestType.EMBEDDING:
assert embedding_input, "embedding_input cannot be empty for embedding requests"
assert embedding_input is not None, "嵌入输入不能为空"
return await client.get_embedding(
model_info=model_info,
embedding_input=embedding_input,
extra_params=model_info.extra_params,
)
elif request_type == RequestType.AUDIO:
assert audio_base64 is not None, "audio_base64 cannot be None for audio requests"
assert audio_base64 is not None, "音频Base64不能为空"
return await client.get_audio_transcriptions(
model_info=model_info,
audio_base64=audio_base64,
extra_params=model_info.extra_params,
)
except Exception as e:
logger.debug(f"请求失败: {str(e)}")
# 处理异常
total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
self.model_usage[model_info.name] = (total_tokens, penalty + 1, usage_penalty)
wait_interval, compressed_messages = self._default_exception_handler(
e,
self.task_name,
model_name=model_info.name,
remain_try=retry_remain,
retry_interval=api_provider.retry_interval,
messages=(message_list, compressed_messages is not None) if message_list else None,
)
if wait_interval == -1:
retry_remain = 0 # 不再重试
elif wait_interval > 0:
logger.info(f"等待 {wait_interval} 秒后重试...")
await asyncio.sleep(wait_interval)
finally:
# 放在finally防止死循环
except (EmptyResponseException, NetworkConnectionError) as e:
retry_remain -= 1
total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty - 1) # 使用结束,减少使用惩罚值
logger.error(f"模型 '{model_info.name}' 请求失败,达到最大重试次数 {api_provider.max_retry}")
raise RuntimeError("请求失败,已达到最大重试次数")
if retry_remain <= 0:
logger.error(f"模型 '{model_info.name}' 在用尽对临时错误的重试次数后仍然失败。")
raise ModelAttemptFailed(f"模型 '{model_info.name}' 重试耗尽", original_exception=e) from e
def _default_exception_handler(
logger.warning(f"模型 '{model_info.name}' 遇到可重试错误: {str(e)}。剩余重试次数: {retry_remain}")
await asyncio.sleep(api_provider.retry_interval)
except RespNotOkException as e:
# 可重试的HTTP错误
if e.status_code == 429 or e.status_code >= 500:
retry_remain -= 1
if retry_remain <= 0:
logger.error(f"模型 '{model_info.name}' 在遇到 {e.status_code} 错误并用尽重试次数后仍然失败。")
raise ModelAttemptFailed(f"模型 '{model_info.name}' 重试耗尽", original_exception=e) from e
logger.warning(
f"模型 '{model_info.name}' 遇到可重试的HTTP错误: {str(e)}。剩余重试次数: {retry_remain}"
)
await asyncio.sleep(api_provider.retry_interval)
continue
# 特殊处理413尝试压缩
if e.status_code == 413 and message_list and not compressed_messages:
logger.warning(f"模型 '{model_info.name}' 返回413请求体过大尝试压缩后重试...")
# 压缩消息本身不消耗重试次数
compressed_messages = compress_messages(message_list)
continue
# 不可重试的HTTP错误
logger.warning(f"模型 '{model_info.name}' 遇到不可重试的HTTP错误: {str(e)}")
raise ModelAttemptFailed(f"模型 '{model_info.name}' 遇到硬错误", original_exception=e) from e
except Exception as e:
logger.error(traceback.format_exc())
logger.warning(f"模型 '{model_info.name}' 遇到未知的不可重试错误: {str(e)}")
raise ModelAttemptFailed(f"模型 '{model_info.name}' 遇到硬错误", original_exception=e) from e
raise ModelAttemptFailed(f"模型 '{model_info.name}' 未被尝试因为重试次数已配置为0或更少。")
async def _execute_request(
self,
e: Exception,
task_name: str,
model_name: str,
remain_try: int,
retry_interval: int = 10,
messages: Tuple[List[Message], bool] | None = None,
) -> Tuple[int, List[Message] | None]:
request_type: RequestType,
message_factory: Optional[Callable[[BaseClient], List[Message]]] = None,
tool_options: list[ToolOption] | None = None,
response_format: RespFormat | None = None,
stream_response_handler: Optional[Callable] = None,
async_response_parser: Optional[Callable] = None,
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
embedding_input: str | None = None,
audio_base64: str | None = None,
) -> Tuple[APIResponse, ModelInfo]:
"""
默认异常处理函数
Args:
e (Exception): 异常对象
task_name (str): 任务名称
model_name (str): 模型名称
remain_try (int): 剩余尝试次数
retry_interval (int): 重试间隔
messages (tuple[list[Message], bool] | None): (消息列表, 是否已压缩过)
Returns:
(等待间隔如果为0则不等待为-1则不再请求该模型, 新的消息列表(适用于压缩消息))
调度器函数,负责模型选择、故障切换。
"""
failed_models_this_request: Set[str] = set()
max_attempts = len(self.model_for_task.model_list)
last_exception: Optional[Exception] = None
if isinstance(e, NetworkConnectionError): # 网络连接错误
return self._check_retry(
remain_try,
retry_interval,
can_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 连接异常,将于{retry_interval}秒后重试",
cannot_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 连接异常超过最大重试次数请检查网络连接状态或URL是否正确",
)
elif isinstance(e, EmptyResponseException): # 空响应错误
return self._check_retry(
remain_try,
retry_interval,
can_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 收到空响应,将于{retry_interval}秒后重试。原因: {e}",
cannot_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 收到空响应,超过最大重试次数,放弃请求",
)
elif isinstance(e, ReqAbortException):
logger.warning(f"任务-'{task_name}' 模型-'{model_name}': 请求被中断,详细信息-{str(e.message)}")
return -1, None # 不再重试请求该模型
elif isinstance(e, RespNotOkException):
return self._handle_resp_not_ok(
e,
task_name,
model_name,
remain_try,
retry_interval,
messages,
)
elif isinstance(e, RespParseException):
# 响应解析错误
logger.error(f"任务-'{task_name}' 模型-'{model_name}': 响应解析错误,错误信息-{e.message}")
logger.debug(f"附加内容: {str(e.ext_info)}")
return -1, None # 不再重试请求该模型
else:
logger.error(f"任务-'{task_name}' 模型-'{model_name}': 未知异常,错误信息-{str(e)}")
return -1, None # 不再重试请求该模型
for _ in range(max_attempts):
model_info, api_provider, client = self._select_model(exclude_models=failed_models_this_request)
def _check_retry(
self,
remain_try: int,
retry_interval: int,
can_retry_msg: str,
cannot_retry_msg: str,
can_retry_callable: Callable | None = None,
**kwargs,
) -> Tuple[int, List[Message] | None]:
"""辅助函数:检查是否可以重试
Args:
remain_try (int): 剩余尝试次数
retry_interval (int): 重试间隔
can_retry_msg (str): 可以重试时的提示信息
cannot_retry_msg (str): 不可以重试时的提示信息
can_retry_callable (Callable | None): 可以重试时调用的函数(如果有)
**kwargs: 其他参数
message_list = []
if message_factory:
message_list = message_factory(client)
Returns:
(Tuple[int, List[Message] | None]): (等待间隔如果为0则不等待为-1则不再请求该模型, 新的消息列表(适用于压缩消息))
"""
if remain_try > 0:
# 还有重试机会
logger.warning(f"{can_retry_msg}")
if can_retry_callable is not None:
return retry_interval, can_retry_callable(**kwargs)
else:
return retry_interval, None
else:
# 达到最大重试次数
logger.warning(f"{cannot_retry_msg}")
return -1, None # 不再重试请求该模型
def _handle_resp_not_ok(
self,
e: RespNotOkException,
task_name: str,
model_name: str,
remain_try: int,
retry_interval: int = 10,
messages: tuple[list[Message], bool] | None = None,
):
"""
处理响应错误异常
Args:
e (RespNotOkException): 响应错误异常对象
task_name (str): 任务名称
model_name (str): 模型名称
remain_try (int): 剩余尝试次数
retry_interval (int): 重试间隔
messages (tuple[list[Message], bool] | None): (消息列表, 是否已压缩过)
Returns:
(等待间隔如果为0则不等待为-1则不再请求该模型, 新的消息列表(适用于压缩消息))
"""
# 响应错误
if e.status_code in [400, 401, 402, 403, 404]:
# 客户端错误
logger.warning(
f"任务-'{task_name}' 模型-'{model_name}': 请求失败,错误代码-{e.status_code},错误信息-{e.message}"
)
return -1, None # 不再重试请求该模型
elif e.status_code == 413:
if messages and not messages[1]:
# 消息列表不为空且未压缩,尝试压缩消息
return self._check_retry(
remain_try,
0,
can_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 请求体过大,尝试压缩消息后重试",
cannot_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 请求体过大,压缩消息后仍然过大,放弃请求",
can_retry_callable=compress_messages,
messages=messages[0],
try:
response = await self._attempt_request_on_model(
model_info,
api_provider,
client,
request_type,
message_list=message_list,
tool_options=tool_options,
response_format=response_format,
stream_response_handler=stream_response_handler,
async_response_parser=async_response_parser,
temperature=temperature,
max_tokens=max_tokens,
embedding_input=embedding_input,
audio_base64=audio_base64,
)
# 没有消息可压缩
logger.warning(f"任务-'{task_name}' 模型-'{model_name}': 请求体过大,无法压缩消息,放弃请求。")
return -1, None
elif e.status_code == 429:
# 请求过于频繁
return self._check_retry(
remain_try,
retry_interval,
can_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 请求过于频繁,将于{retry_interval}秒后重试",
cannot_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 请求过于频繁,超过最大重试次数,放弃请求",
)
elif e.status_code >= 500:
# 服务器错误
return self._check_retry(
remain_try,
retry_interval,
can_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 服务器错误,将于{retry_interval}秒后重试",
cannot_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 服务器错误,超过最大重试次数,请稍后再试",
)
else:
# 未知错误
logger.warning(
f"任务-'{task_name}' 模型-'{model_name}': 未知错误,错误代码-{e.status_code},错误信息-{e.message}"
)
return -1, None
total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
if response_usage := response.usage:
total_tokens += response_usage.total_tokens
self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty - 1)
return response, model_info
except ModelAttemptFailed as e:
last_exception = e.original_exception or e
logger.warning(f"模型 '{model_info.name}' 尝试失败,切换到下一个模型。原因: {e}")
total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
self.model_usage[model_info.name] = (total_tokens, penalty + 1, usage_penalty - 1)
failed_models_this_request.add(model_info.name)
if isinstance(last_exception, RespNotOkException) and last_exception.status_code == 400:
logger.error("收到不可恢复的客户端错误 (400),中止所有尝试。")
raise last_exception from e
logger.error(f"所有 {max_attempts} 个模型均尝试失败。")
if last_exception:
raise last_exception
raise RuntimeError("请求失败,所有可用模型均已尝试失败。")
def _build_tool_options(self, tools: Optional[List[Dict[str, Any]]]) -> Optional[List[ToolOption]]:
# sourcery skip: extract-method

View File

@@ -13,8 +13,8 @@ from src.common.logger import get_logger
from src.common.server import get_global_server, Server
from src.mood.mood_manager import mood_manager
from src.chat.knowledge import lpmm_start_up
from src.memory_system.memory_management_task import MemoryManagementTask
from rich.traceback import install
from src.migrate_helper.migrate import check_and_run_migrations
# from src.api.main import start_api_server
# 导入新的插件管理器
@@ -23,10 +23,6 @@ from src.plugin_system.core.plugin_manager import plugin_manager
# 导入消息API和traceback模块
from src.common.message import get_global_api
# 条件导入记忆系统
if global_config.memory.enable_memory:
from src.chat.memory_system.Hippocampus import hippocampus_manager
# 插件系统现在使用统一的插件加载器
install(extra_lines=3)
@@ -36,11 +32,6 @@ logger = get_logger("main")
class MainSystem:
def __init__(self):
# 根据配置条件性地初始化记忆系统
self.hippocampus_manager = None
if global_config.memory.enable_memory:
self.hippocampus_manager = hippocampus_manager
# 使用消息API替代直接的FastAPI实例
self.app: MessageServer = get_global_api()
self.server: Server = get_global_server()
@@ -92,29 +83,26 @@ class MainSystem:
logger.info("表情包管理器初始化成功")
# 启动情绪管理器
await mood_manager.start()
logger.info("情绪管理器初始化成功")
if global_config.mood.enable_mood:
await mood_manager.start()
logger.info("情绪管理器初始化成功")
# 初始化聊天管理器
await get_chat_manager()._initialize()
asyncio.create_task(get_chat_manager()._auto_save_task())
logger.info("聊天管理器初始化成功")
# 根据配置条件性地初始化记忆系统
if global_config.memory.enable_memory:
if self.hippocampus_manager:
self.hippocampus_manager.initialize()
logger.info("记忆系统初始化成功")
else:
logger.info("记忆系统已禁用,跳过初始化")
# 添加记忆管理任务
await async_task_manager.add_task(MemoryManagementTask())
logger.info("记忆管理任务已启动")
# await asyncio.sleep(0.5) #防止logger输出飞了
# 将bot.py中的chat_bot.message_process消息处理函数注册到api.py的消息处理基类中
self.app.register_message_handler(chat_bot.message_process)
self.app.register_custom_message_handler("message_id_echo", chat_bot.echo_message_process)
await check_and_run_migrations()
# 触发 ON_START 事件
from src.plugin_system.core.events_manager import events_manager
@@ -138,25 +126,15 @@ class MainSystem:
self.server.run(),
]
# 根据配置条件性地添加记忆系统相关任务
if global_config.memory.enable_memory and self.hippocampus_manager:
tasks.extend(
[
# 移除记忆构建的定期调用改为在heartFC_chat.py中调用
# self.build_memory_task(),
self.forget_memory_task(),
]
)
await asyncio.gather(*tasks)
async def forget_memory_task(self):
"""记忆遗忘任务"""
while True:
await asyncio.sleep(global_config.memory.forget_memory_interval)
logger.info("[记忆遗忘] 开始遗忘记忆...")
await self.hippocampus_manager.forget_memory(percentage=global_config.memory.memory_forget_percentage) # type: ignore
logger.info("[记忆遗忘] 记忆遗忘完成")
# async def forget_memory_task(self):
# """记忆遗忘任务"""
# while True:
# await asyncio.sleep(global_config.memory.forget_memory_interval)
# logger.info("[记忆遗忘] 开始遗忘记忆...")
# await self.hippocampus_manager.forget_memory(percentage=global_config.memory.memory_forget_percentage) # type: ignore
# logger.info("[记忆遗忘] 记忆遗忘完成")
async def main():

View File

@@ -1,36 +0,0 @@
[inner]
version = "1.0.0"
#----以下是S4U聊天系统配置文件----
# S4U (Smart 4 U) 聊天系统是MaiBot的核心对话模块
# 支持优先级队列、消息中断、VIP用户等高级功能
#
# 如果你想要修改配置文件请在修改后将version的值进行变更
# 如果新增项目请参考src/mais4u/s4u_config.py中的S4UConfig类
#
# 版本格式:主版本号.次版本号.修订号
#----S4U配置说明结束----
[s4u]
# 消息管理配置
message_timeout_seconds = 120 # 普通消息存活时间(秒),超过此时间的消息将被丢弃
recent_message_keep_count = 6 # 保留最近N条消息超出范围的普通消息将被移除
# 优先级系统配置
at_bot_priority_bonus = 100.0 # @机器人时的优先级加成分数
vip_queue_priority = true # 是否启用VIP队列优先级系统
enable_message_interruption = true # 是否允许高优先级消息中断当前回复
# 打字效果配置
typing_delay = 0.1 # 打字延迟时间(秒),模拟真实打字速度
enable_dynamic_typing_delay = false # 是否启用基于文本长度的动态打字延迟
# 动态打字延迟参数仅在enable_dynamic_typing_delay=true时生效
chars_per_second = 15.0 # 每秒字符数,用于计算动态打字延迟
min_typing_delay = 0.2 # 最小打字延迟(秒)
max_typing_delay = 2.0 # 最大打字延迟(秒)
# 系统功能开关
enable_old_message_cleanup = true # 是否自动清理过旧的普通消息
enable_loading_indicator = true # 是否显示加载提示

View File

@@ -1,68 +0,0 @@
[inner]
version = "1.2.0"
#----以下是S4U聊天系统配置文件----
# S4U (Smart 4 U) 聊天系统是MaiBot的核心对话模块
# 支持优先级队列、消息中断、VIP用户等高级功能
#
# 如果你想要修改配置文件请在修改后将version的值进行变更
# 如果新增项目请参考src/mais4u/s4u_config.py中的S4UConfig类
#
# 版本格式:主版本号.次版本号.修订号
#----S4U配置说明结束----
[s4u]
enable_s4u = false
# 消息管理配置
message_timeout_seconds = 120 # 普通消息存活时间(秒),超过此时间的消息将被丢弃
recent_message_keep_count = 6 # 保留最近N条消息超出范围的普通消息将被移除
# 优先级系统配置
at_bot_priority_bonus = 100.0 # @机器人时的优先级加成分数
vip_queue_priority = true # 是否启用VIP队列优先级系统
enable_message_interruption = true # 是否允许高优先级消息中断当前回复
# 打字效果配置
typing_delay = 0.1 # 打字延迟时间(秒),模拟真实打字速度
enable_dynamic_typing_delay = false # 是否启用基于文本长度的动态打字延迟
# 动态打字延迟参数仅在enable_dynamic_typing_delay=true时生效
chars_per_second = 15.0 # 每秒字符数,用于计算动态打字延迟
min_typing_delay = 0.2 # 最小打字延迟(秒)
max_typing_delay = 2.0 # 最大打字延迟(秒)
# 系统功能开关
enable_old_message_cleanup = true # 是否自动清理过旧的普通消息
enable_streaming_output = true # 是否启用流式输出false时全部生成后一次性发送
max_context_message_length = 20
max_core_message_length = 30
# 模型配置
[models]
# 主要对话模型配置
[models.chat]
name = "qwen3-8b"
provider = "BAILIAN"
pri_in = 0.5
pri_out = 2
temp = 0.7
enable_thinking = false
# 规划模型配置
[models.motion]
name = "qwen3-32b"
provider = "BAILIAN"
pri_in = 0.5
pri_out = 2
temp = 0.7
enable_thinking = false
# 情感分析模型配置
[models.emotion]
name = "qwen3-8b"
provider = "BAILIAN"
pri_in = 0.5
pri_out = 2
temp = 0.7

View File

@@ -1,167 +0,0 @@
from src.chat.message_receive.chat_stream import get_chat_manager
import time
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config
from src.chat.message_receive.message import MessageRecvS4U
from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor
from src.mais4u.mais4u_chat.internal_manager import internal_manager
from src.common.logger import get_logger
logger = get_logger(__name__)
def init_prompt():
Prompt(
"""
你之前的内心想法是:{mind}
{memory_block}
{relation_info_block}
{chat_target}
{time_block}
{chat_info}
{identity}
你刚刚在{chat_target_2},你你刚刚的心情是:{mood_state}
---------------------
在这样的情况下,你对上面的内容,你对 {sender} 发送的 消息 “{target}” 进行了回复
你刚刚选择回复的内容是:{reponse}
现在,根据你之前的想法和回复的内容,推测你现在的想法,思考你现在的想法是什么,为什么做出上面的回复内容
请不要浮夸和夸张修辞,不要输出多余内容(包括前后缀,冒号和引号,括号()表情包at或 @等 )。只输出想法:""",
"after_response_think_prompt",
)
class MaiThinking:
def __init__(self, chat_id):
self.chat_id = chat_id
self.chat_stream = get_chat_manager().get_stream(chat_id)
self.platform = self.chat_stream.platform
if self.chat_stream.group_info:
self.is_group = True
else:
self.is_group = False
self.s4u_message_processor = S4UMessageProcessor()
self.mind = ""
self.memory_block = ""
self.relation_info_block = ""
self.time_block = ""
self.chat_target = ""
self.chat_target_2 = ""
self.chat_info = ""
self.mood_state = ""
self.identity = ""
self.sender = ""
self.target = ""
self.thinking_model = LLMRequest(model_set=model_config.model_task_config.replyer, request_type="thinking")
async def do_think_before_response(self):
pass
async def do_think_after_response(self, reponse: str):
prompt = await global_prompt_manager.format_prompt(
"after_response_think_prompt",
mind=self.mind,
reponse=reponse,
memory_block=self.memory_block,
relation_info_block=self.relation_info_block,
time_block=self.time_block,
chat_target=self.chat_target,
chat_target_2=self.chat_target_2,
chat_info=self.chat_info,
mood_state=self.mood_state,
identity=self.identity,
sender=self.sender,
target=self.target,
)
result, _ = await self.thinking_model.generate_response_async(prompt)
self.mind = result
logger.info(f"[{self.chat_id}] 思考前想法:{self.mind}")
# logger.info(f"[{self.chat_id}] 思考前prompt{prompt}")
logger.info(f"[{self.chat_id}] 思考后想法:{self.mind}")
msg_recv = await self.build_internal_message_recv(self.mind)
await self.s4u_message_processor.process_message(msg_recv)
internal_manager.set_internal_state(self.mind)
async def do_think_when_receive_message(self):
pass
async def build_internal_message_recv(self, message_text: str):
msg_id = f"internal_{time.time()}"
message_dict = {
"message_info": {
"message_id": msg_id,
"time": time.time(),
"user_info": {
"user_id": "internal", # 内部用户ID
"user_nickname": "内心", # 内部昵称
"platform": self.platform, # 平台标记为 internal
# 其他 user_info 字段按需补充
},
"platform": self.platform, # 平台
# 其他 message_info 字段按需补充
},
"message_segment": {
"type": "text", # 消息类型
"data": message_text, # 消息内容
# 其他 segment 字段按需补充
},
"raw_message": message_text, # 原始消息内容
"processed_plain_text": message_text, # 处理后的纯文本
# 下面这些字段可选,根据 MessageRecv 需要
"is_emoji": False,
"has_emoji": False,
"is_picid": False,
"has_picid": False,
"is_voice": False,
"is_mentioned": False,
"is_command": False,
"is_internal": True,
"priority_mode": "interest",
"priority_info": {"message_priority": 10.0}, # 内部消息可设高优先级
"interest_value": 1.0,
}
if self.is_group:
message_dict["message_info"]["group_info"] = {
"platform": self.platform,
"group_id": self.chat_stream.group_info.group_id,
"group_name": self.chat_stream.group_info.group_name,
}
msg_recv = MessageRecvS4U(message_dict)
msg_recv.chat_info = self.chat_info
msg_recv.chat_stream = self.chat_stream
msg_recv.is_internal = True
return msg_recv
class MaiThinkingManager:
def __init__(self):
self.mai_think_list = []
def get_mai_think(self, chat_id):
for mai_think in self.mai_think_list:
if mai_think.chat_id == chat_id:
return mai_think
mai_think = MaiThinking(chat_id)
self.mai_think_list.append(mai_think)
return mai_think
mai_thinking_manager = MaiThinkingManager()
init_prompt()

View File

@@ -1,342 +0,0 @@
import json
import time
from json_repair import repair_json
from src.chat.message_receive.message import MessageRecv
from src.llm_models.utils_model import LLMRequest
from src.common.logger import get_logger
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive
from src.config.config import global_config, model_config
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.manager.async_task_manager import AsyncTask, async_task_manager
from src.plugin_system.apis import send_api
from src.mais4u.s4u_config import s4u_config
logger = get_logger("action")
# 使用字典作为默认值但通过Prompt来注册以便外部重载
DEFAULT_HEAD_CODE = {
"看向上方": "(0,0.5,0)",
"看向下方": "(0,-0.5,0)",
"看向左边": "(-1,0,0)",
"看向右边": "(1,0,0)",
"随意朝向": "random",
"看向摄像机": "camera",
"注视对方": "(0,0,0)",
"看向正前方": "(0,0,0)",
}
DEFAULT_BODY_CODE = {
"双手背后向前弯腰": "010_0070",
"歪头双手合十": "010_0100",
"标准文静站立": "010_0101",
"双手交叠腹部站立": "010_0150",
"帅气的姿势": "010_0190",
"另一个帅气的姿势": "010_0191",
"手掌朝前可爱": "010_0210",
"平静,双手后放": "平静,双手后放",
"思考": "思考",
"优雅,左手放在腰上": "优雅,左手放在腰上",
"一般": "一般",
"可爱,双手前放": "可爱,双手前放",
}
async def get_head_code() -> dict:
"""获取头部动作代码字典"""
head_code_str = await global_prompt_manager.format_prompt("head_code_prompt")
if not head_code_str:
return DEFAULT_HEAD_CODE
try:
return json.loads(head_code_str)
except Exception as e:
logger.error(f"解析head_code_prompt失败使用默认值: {e}")
return DEFAULT_HEAD_CODE
async def get_body_code() -> dict:
"""获取身体动作代码字典"""
body_code_str = await global_prompt_manager.format_prompt("body_code_prompt")
if not body_code_str:
return DEFAULT_BODY_CODE
try:
return json.loads(body_code_str)
except Exception as e:
logger.error(f"解析body_code_prompt失败使用默认值: {e}")
return DEFAULT_BODY_CODE
def init_prompt():
# 注册头部动作代码
Prompt(
json.dumps(DEFAULT_HEAD_CODE, ensure_ascii=False, indent=2),
"head_code_prompt",
)
# 注册身体动作代码
Prompt(
json.dumps(DEFAULT_BODY_CODE, ensure_ascii=False, indent=2),
"body_code_prompt",
)
# 注册原有提示模板
Prompt(
"""
{chat_talking_prompt}
以上是群里正在进行的聊天记录
{indentify_block}
你现在的动作状态是:
- 身体动作:{body_action}
现在,因为你发送了消息,或者群里其他人发送了消息,引起了你的注意,你对其进行了阅读和思考,请你更新你的动作状态。
身体动作可选:
{all_actions}
请只按照以下json格式输出描述你新的动作状态确保每个字段都存在
{{
"body_action": "..."
}}
""",
"change_action_prompt",
)
Prompt(
"""
{chat_talking_prompt}
以上是群里最近的聊天记录
{indentify_block}
你之前的动作状态是
- 身体动作:{body_action}
身体动作可选:
{all_actions}
距离你上次关注群里消息已经过去了一段时间,你冷静了下来,你的动作会趋于平缓或静止,请你输出你现在新的动作状态,用中文。
请只按照以下json格式输出描述你新的动作状态确保每个字段都存在
{{
"body_action": "..."
}}
""",
"regress_action_prompt",
)
class ChatAction:
def __init__(self, chat_id: str):
self.chat_id: str = chat_id
self.body_action: str = "一般"
self.head_action: str = "注视摄像机"
self.regression_count: int = 0
# 新增body_action冷却池key为动作名value为剩余冷却次数
self.body_action_cooldown: dict[str, int] = {}
print(s4u_config.models.motion)
print(model_config.model_task_config.emotion)
self.action_model = LLMRequest(model_set=model_config.model_task_config.emotion, request_type="motion")
self.last_change_time: float = 0
async def send_action_update(self):
"""发送动作更新到前端"""
body_code = (await get_body_code()).get(self.body_action, "")
await send_api.custom_to_stream(
message_type="body_action",
content=body_code,
stream_id=self.chat_id,
storage_message=False,
show_log=True,
)
async def update_action_by_message(self, message: MessageRecv):
self.regression_count = 0
message_time: float = message.message_info.time # type: ignore
message_list_before_now = get_raw_msg_by_timestamp_with_chat_inclusive(
chat_id=self.chat_id,
timestamp_start=self.last_change_time,
timestamp_end=message_time,
limit=15,
limit_mode="last",
)
chat_talking_prompt = build_readable_messages(
message_list_before_now,
replace_bot_name=True,
timestamp_mode="normal_no_YMD",
read_mark=0.0,
truncate=True,
show_actions=True,
)
bot_name = global_config.bot.nickname
if global_config.bot.alias_names:
bot_nickname = f",也有人叫你{','.join(global_config.bot.alias_names)}"
else:
bot_nickname = ""
prompt_personality = global_config.personality.personality
indentify_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}"
try:
# 冷却池处理:过滤掉冷却中的动作
self._update_body_action_cooldown()
available_actions = [k for k in (await get_body_code()).keys() if k not in self.body_action_cooldown]
all_actions = "\n".join(available_actions)
prompt = await global_prompt_manager.format_prompt(
"change_action_prompt",
chat_talking_prompt=chat_talking_prompt,
indentify_block=indentify_block,
body_action=self.body_action,
all_actions=all_actions,
)
logger.info(f"prompt: {prompt}")
response, (reasoning_content, _, _) = await self.action_model.generate_response_async(
prompt=prompt, temperature=0.7
)
logger.info(f"response: {response}")
logger.info(f"reasoning_content: {reasoning_content}")
if action_data := json.loads(repair_json(response)):
# 记录原动作,切换后进入冷却
prev_body_action = self.body_action
new_body_action = action_data.get("body_action", self.body_action)
if new_body_action != prev_body_action and prev_body_action:
self.body_action_cooldown[prev_body_action] = 3
self.body_action = new_body_action
self.head_action = action_data.get("head_action", self.head_action)
# 发送动作更新
await self.send_action_update()
self.last_change_time = message_time
except Exception as e:
logger.error(f"update_action_by_message error: {e}")
async def regress_action(self):
message_time = time.time()
message_list_before_now = get_raw_msg_by_timestamp_with_chat_inclusive(
chat_id=self.chat_id,
timestamp_start=self.last_change_time,
timestamp_end=message_time,
limit=10,
limit_mode="last",
)
chat_talking_prompt = build_readable_messages(
message_list_before_now,
replace_bot_name=True,
timestamp_mode="normal_no_YMD",
read_mark=0.0,
truncate=True,
show_actions=True,
)
bot_name = global_config.bot.nickname
if global_config.bot.alias_names:
bot_nickname = f",也有人叫你{','.join(global_config.bot.alias_names)}"
else:
bot_nickname = ""
prompt_personality = global_config.personality.personality
indentify_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}"
try:
# 冷却池处理:过滤掉冷却中的动作
self._update_body_action_cooldown()
available_actions = [k for k in (await get_body_code()).keys() if k not in self.body_action_cooldown]
all_actions = "\n".join(available_actions)
prompt = await global_prompt_manager.format_prompt(
"regress_action_prompt",
chat_talking_prompt=chat_talking_prompt,
indentify_block=indentify_block,
body_action=self.body_action,
all_actions=all_actions,
)
logger.info(f"prompt: {prompt}")
response, (reasoning_content, _, _) = await self.action_model.generate_response_async(
prompt=prompt, temperature=0.7
)
logger.info(f"response: {response}")
logger.info(f"reasoning_content: {reasoning_content}")
if action_data := json.loads(repair_json(response)):
prev_body_action = self.body_action
new_body_action = action_data.get("body_action", self.body_action)
if new_body_action != prev_body_action and prev_body_action:
self.body_action_cooldown[prev_body_action] = 6
self.body_action = new_body_action
# 发送动作更新
await self.send_action_update()
self.regression_count += 1
self.last_change_time = message_time
except Exception as e:
logger.error(f"regress_action error: {e}")
# 新增:冷却池维护方法
def _update_body_action_cooldown(self):
remove_keys = []
for k in self.body_action_cooldown:
self.body_action_cooldown[k] -= 1
if self.body_action_cooldown[k] <= 0:
remove_keys.append(k)
for k in remove_keys:
del self.body_action_cooldown[k]
class ActionRegressionTask(AsyncTask):
def __init__(self, action_manager: "ActionManager"):
super().__init__(task_name="ActionRegressionTask", run_interval=3)
self.action_manager = action_manager
async def run(self):
logger.debug("Running action regression task...")
now = time.time()
for action_state in self.action_manager.action_state_list:
if action_state.last_change_time == 0:
continue
if now - action_state.last_change_time > 10:
if action_state.regression_count >= 3:
continue
logger.info(f"chat {action_state.chat_id} 开始动作回归, 这是第 {action_state.regression_count + 1}")
await action_state.regress_action()
class ActionManager:
def __init__(self):
self.action_state_list: list[ChatAction] = []
"""当前动作状态"""
self.task_started: bool = False
async def start(self):
"""启动动作回归后台任务"""
if self.task_started:
return
logger.info("启动动作回归任务...")
task = ActionRegressionTask(self)
await async_task_manager.add_task(task)
self.task_started = True
logger.info("动作回归任务已启动")
def get_action_state_by_chat_id(self, chat_id: str) -> ChatAction:
for action_state in self.action_state_list:
if action_state.chat_id == chat_id:
return action_state
new_action_state = ChatAction(chat_id)
self.action_state_list.append(new_action_state)
return new_action_state
init_prompt()
action_manager = ActionManager()
"""全局动作管理器"""

View File

@@ -1,685 +0,0 @@
import asyncio
import json
from collections import deque
from datetime import datetime
from typing import Dict, List, Optional
from aiohttp import web, WSMsgType
import aiohttp_cors
from src.chat.message_receive.message import MessageRecv
from src.common.logger import get_logger
logger = get_logger("context_web")
class ContextMessage:
"""上下文消息类"""
def __init__(self, message: MessageRecv):
self.user_name = message.message_info.user_info.user_nickname
self.user_id = message.message_info.user_info.user_id
self.content = message.processed_plain_text
self.timestamp = datetime.now()
self.group_name = message.message_info.group_info.group_name if message.message_info.group_info else "私聊"
# 识别消息类型
self.is_gift = getattr(message, 'is_gift', False)
self.is_superchat = getattr(message, 'is_superchat', False)
# 添加礼物和SC相关信息
if self.is_gift:
self.gift_name = getattr(message, 'gift_name', '')
self.gift_count = getattr(message, 'gift_count', '1')
self.content = f"送出了 {self.gift_name} x{self.gift_count}"
elif self.is_superchat:
self.superchat_price = getattr(message, 'superchat_price', '0')
self.superchat_message = getattr(message, 'superchat_message_text', '')
if self.superchat_message:
self.content = f"{self.superchat_price}] {self.superchat_message}"
else:
self.content = f"{self.superchat_price}] {self.content}"
def to_dict(self):
return {
"user_name": self.user_name,
"user_id": self.user_id,
"content": self.content,
"timestamp": self.timestamp.strftime("%m-%d %H:%M:%S"),
"group_name": self.group_name,
"is_gift": self.is_gift,
"is_superchat": self.is_superchat
}
class ContextWebManager:
"""上下文网页管理器"""
def __init__(self, max_messages: int = 10, port: int = 8765):
self.max_messages = max_messages
self.port = port
self.contexts: Dict[str, deque] = {} # chat_id -> deque of ContextMessage
self.websockets: List[web.WebSocketResponse] = []
self.app = None
self.runner = None
self.site = None
self._server_starting = False # 添加启动标志防止并发
async def start_server(self):
"""启动web服务器"""
if self.site is not None:
logger.debug("Web服务器已经启动跳过重复启动")
return
if self._server_starting:
logger.debug("Web服务器正在启动中等待启动完成...")
# 等待启动完成
while self._server_starting and self.site is None:
await asyncio.sleep(0.1)
return
self._server_starting = True
try:
self.app = web.Application()
# 设置CORS
cors = aiohttp_cors.setup(self.app, defaults={
"*": aiohttp_cors.ResourceOptions(
allow_credentials=True,
expose_headers="*",
allow_headers="*",
allow_methods="*"
)
})
# 添加路由
self.app.router.add_get('/', self.index_handler)
self.app.router.add_get('/ws', self.websocket_handler)
self.app.router.add_get('/api/contexts', self.get_contexts_handler)
self.app.router.add_get('/debug', self.debug_handler)
# 为所有路由添加CORS
for route in list(self.app.router.routes()):
cors.add(route)
self.runner = web.AppRunner(self.app)
await self.runner.setup()
self.site = web.TCPSite(self.runner, 'localhost', self.port)
await self.site.start()
logger.info(f"🌐 上下文网页服务器启动成功在 http://localhost:{self.port}")
except Exception as e:
logger.error(f"❌ 启动Web服务器失败: {e}")
# 清理部分启动的资源
if self.runner:
await self.runner.cleanup()
self.app = None
self.runner = None
self.site = None
raise
finally:
self._server_starting = False
async def stop_server(self):
"""停止web服务器"""
if self.site:
await self.site.stop()
if self.runner:
await self.runner.cleanup()
self.app = None
self.runner = None
self.site = None
self._server_starting = False
async def index_handler(self, request):
"""主页处理器"""
html_content = '''
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<title>聊天上下文</title>
<style>
html, body {
background: transparent !important;
background-color: transparent !important;
margin: 0;
padding: 20px;
font-family: 'Microsoft YaHei', Arial, sans-serif;
color: #ffffff;
text-shadow: 2px 2px 4px rgba(0,0,0,0.8);
}
.container {
max-width: 800px;
margin: 0 auto;
background: transparent !important;
}
.message {
background: rgba(0, 0, 0, 0.3);
margin: 10px 0;
padding: 15px;
border-radius: 10px;
border-left: 4px solid #00ff88;
backdrop-filter: blur(5px);
animation: slideIn 0.3s ease-out;
transform: translateY(0);
transition: transform 0.5s ease, opacity 0.5s ease;
}
.message:hover {
background: rgba(0, 0, 0, 0.5);
transform: translateX(5px);
transition: all 0.3s ease;
}
.message.gift {
border-left: 4px solid #ff8800;
background: rgba(255, 136, 0, 0.2);
}
.message.gift:hover {
background: rgba(255, 136, 0, 0.3);
}
.message.gift .username {
color: #ff8800;
}
.message.superchat {
border-left: 4px solid #ff6b6b;
background: linear-gradient(135deg, rgba(255, 107, 107, 0.2), rgba(107, 255, 107, 0.2), rgba(107, 107, 255, 0.2));
background-size: 200% 200%;
animation: rainbow 3s ease infinite;
}
.message.superchat:hover {
background: linear-gradient(135deg, rgba(255, 107, 107, 0.4), rgba(107, 255, 107, 0.4), rgba(107, 107, 255, 0.4));
background-size: 200% 200%;
}
.message.superchat .username {
background: linear-gradient(45deg, #ff6b6b, #4ecdc4, #45b7d1, #96ceb4, #feca57);
background-size: 300% 300%;
animation: rainbow-text 2s ease infinite;
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
background-clip: text;
}
@keyframes rainbow {
0% { background-position: 0% 50%; }
50% { background-position: 100% 50%; }
100% { background-position: 0% 50%; }
}
@keyframes rainbow-text {
0% { background-position: 0% 50%; }
50% { background-position: 100% 50%; }
100% { background-position: 0% 50%; }
}
.message-line {
line-height: 1.4;
word-wrap: break-word;
font-size: 24px;
}
.username {
color: #00ff88;
}
.content {
color: #ffffff;
}
.new-message {
animation: slideInNew 0.6s ease-out;
}
.debug-btn {
position: fixed;
bottom: 20px;
right: 20px;
background: rgba(0, 0, 0, 0.7);
color: #00ff88;
font-size: 12px;
padding: 8px 12px;
border-radius: 20px;
backdrop-filter: blur(10px);
z-index: 1000;
text-decoration: none;
border: 1px solid #00ff88;
}
.debug-btn:hover {
background: rgba(0, 255, 136, 0.2);
}
@keyframes slideIn {
from {
opacity: 0;
transform: translateY(-20px);
}
to {
opacity: 1;
transform: translateY(0);
}
}
@keyframes slideInNew {
from {
opacity: 0;
transform: translateY(50px) scale(0.95);
}
to {
opacity: 1;
transform: translateY(0) scale(1);
}
}
.no-messages {
text-align: center;
color: #666;
font-style: italic;
margin-top: 50px;
}
</style>
</head>
<body>
<div class="container">
<a href="/debug" class="debug-btn">🔧 调试</a>
<div id="messages">
<div class="no-messages">暂无消息</div>
</div>
</div>
<script>
let ws;
let reconnectInterval;
let currentMessages = []; // 存储当前显示的消息
function connectWebSocket() {
console.log('正在连接WebSocket...');
ws = new WebSocket('ws://localhost:''' + str(self.port) + '''/ws');
ws.onopen = function() {
console.log('WebSocket连接已建立');
if (reconnectInterval) {
clearInterval(reconnectInterval);
reconnectInterval = null;
}
};
ws.onmessage = function(event) {
console.log('收到WebSocket消息:', event.data);
try {
const data = JSON.parse(event.data);
updateMessages(data.contexts);
} catch (e) {
console.error('解析消息失败:', e, event.data);
}
};
ws.onclose = function(event) {
console.log('WebSocket连接关闭:', event.code, event.reason);
if (!reconnectInterval) {
reconnectInterval = setInterval(connectWebSocket, 3000);
}
};
ws.onerror = function(error) {
console.error('WebSocket错误:', error);
};
}
function updateMessages(contexts) {
const messagesDiv = document.getElementById('messages');
if (!contexts || contexts.length === 0) {
messagesDiv.innerHTML = '<div class="no-messages">暂无消息</div>';
currentMessages = [];
return;
}
// 如果是第一次加载或者消息完全不同,进行完全重新渲染
if (currentMessages.length === 0) {
console.log('首次加载消息,数量:', contexts.length);
messagesDiv.innerHTML = '';
contexts.forEach(function(msg) {
const messageDiv = createMessageElement(msg);
messagesDiv.appendChild(messageDiv);
});
currentMessages = [...contexts];
window.scrollTo(0, document.body.scrollHeight);
return;
}
// 检测新消息 - 使用更可靠的方法
const newMessages = findNewMessages(contexts, currentMessages);
if (newMessages.length > 0) {
console.log('添加新消息,数量:', newMessages.length);
// 先检查是否需要移除老消息保持DOM清洁
const maxDisplayMessages = 15; // 比服务器端稍多一些,确保流畅性
const currentMessageElements = messagesDiv.querySelectorAll('.message');
const willExceedLimit = currentMessageElements.length + newMessages.length > maxDisplayMessages;
if (willExceedLimit) {
const removeCount = (currentMessageElements.length + newMessages.length) - maxDisplayMessages;
console.log('需要移除老消息数量:', removeCount);
for (let i = 0; i < removeCount && i < currentMessageElements.length; i++) {
const oldMessage = currentMessageElements[i];
oldMessage.style.transition = 'opacity 0.3s ease, transform 0.3s ease';
oldMessage.style.opacity = '0';
oldMessage.style.transform = 'translateY(-20px)';
setTimeout(() => {
if (oldMessage.parentNode) {
oldMessage.parentNode.removeChild(oldMessage);
}
}, 300);
}
}
// 添加新消息
newMessages.forEach(function(msg) {
const messageDiv = createMessageElement(msg, true); // true表示是新消息
messagesDiv.appendChild(messageDiv);
// 移除动画类,避免重复动画
setTimeout(() => {
messageDiv.classList.remove('new-message');
}, 600);
});
// 更新当前消息列表
currentMessages = [...contexts];
// 平滑滚动到底部
setTimeout(() => {
window.scrollTo({
top: document.body.scrollHeight,
behavior: 'smooth'
});
}, 100);
}
}
function findNewMessages(contexts, currentMessages) {
// 如果当前消息为空,所有消息都是新的
if (currentMessages.length === 0) {
return contexts;
}
// 找到最后一条当前消息在新消息列表中的位置
const lastCurrentMsg = currentMessages[currentMessages.length - 1];
let lastIndex = -1;
// 从后往前找,因为新消息通常在末尾
for (let i = contexts.length - 1; i >= 0; i--) {
const msg = contexts[i];
if (msg.user_id === lastCurrentMsg.user_id &&
msg.content === lastCurrentMsg.content &&
msg.timestamp === lastCurrentMsg.timestamp) {
lastIndex = i;
break;
}
}
// 如果找到了,返回之后的消息;否则返回所有消息(可能是完全刷新)
if (lastIndex >= 0) {
return contexts.slice(lastIndex + 1);
} else {
console.log('未找到匹配的最后消息,可能需要完全刷新');
return contexts.slice(Math.max(0, contexts.length - (currentMessages.length + 1)));
}
}
function createMessageElement(msg, isNew = false) {
const messageDiv = document.createElement('div');
let className = 'message';
// 根据消息类型添加对应的CSS类
if (msg.is_gift) {
className += ' gift';
} else if (msg.is_superchat) {
className += ' superchat';
}
if (isNew) {
className += ' new-message';
}
messageDiv.className = className;
messageDiv.innerHTML = `
<div class="message-line">
<span class="username">${escapeHtml(msg.user_name)}</span><span class="content">${escapeHtml(msg.content)}</span>
</div>
`;
return messageDiv;
}
function escapeHtml(text) {
const div = document.createElement('div');
div.textContent = text;
return div.innerHTML;
}
// 初始加载数据
fetch('/api/contexts')
.then(response => response.json())
.then(data => {
console.log('初始数据加载成功:', data);
updateMessages(data.contexts);
})
.catch(err => console.error('加载初始数据失败:', err));
// 连接WebSocket
connectWebSocket();
</script>
</body>
</html>
'''
return web.Response(text=html_content, content_type='text/html')
async def websocket_handler(self, request):
"""WebSocket处理器"""
ws = web.WebSocketResponse()
await ws.prepare(request)
self.websockets.append(ws)
logger.debug(f"WebSocket连接建立当前连接数: {len(self.websockets)}")
# 发送初始数据
await self.send_contexts_to_websocket(ws)
async for msg in ws:
if msg.type == WSMsgType.ERROR:
logger.error(f'WebSocket错误: {ws.exception()}')
break
# 清理断开的连接
if ws in self.websockets:
self.websockets.remove(ws)
logger.debug(f"WebSocket连接断开当前连接数: {len(self.websockets)}")
return ws
async def get_contexts_handler(self, request):
"""获取上下文API"""
all_context_msgs = []
for _chat_id, contexts in self.contexts.items():
all_context_msgs.extend(list(contexts))
# 按时间排序,最新的在最后
all_context_msgs.sort(key=lambda x: x.timestamp)
# 转换为字典格式
contexts_data = [msg.to_dict() for msg in all_context_msgs[-self.max_messages:]]
logger.debug(f"返回上下文数据,共 {len(contexts_data)} 条消息")
return web.json_response({"contexts": contexts_data})
async def debug_handler(self, request):
"""调试信息处理器"""
debug_info = {
"server_status": "running",
"websocket_connections": len(self.websockets),
"total_chats": len(self.contexts),
"total_messages": sum(len(contexts) for contexts in self.contexts.values()),
}
# 构建聊天详情HTML
chats_html = ""
for chat_id, contexts in self.contexts.items():
messages_html = ""
for msg in contexts:
timestamp = msg.timestamp.strftime("%H:%M:%S")
content = msg.content[:50] + "..." if len(msg.content) > 50 else msg.content
messages_html += f'<div class="message">[{timestamp}] {msg.user_name}: {content}</div>'
chats_html += f'''
<div class="chat">
<h3>聊天 {chat_id} ({len(contexts)} 条消息)</h3>
{messages_html}
</div>
'''
html_content = f'''
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<title>调试信息</title>
<style>
body {{ font-family: monospace; margin: 20px; }}
.section {{ margin: 20px 0; padding: 10px; border: 1px solid #ccc; }}
.chat {{ margin: 10px 0; padding: 10px; background: #f5f5f5; }}
.message {{ margin: 5px 0; padding: 5px; background: white; }}
</style>
</head>
<body>
<h1>上下文网页管理器调试信息</h1>
<div class="section">
<h2>服务器状态</h2>
<p>状态: {debug_info["server_status"]}</p>
<p>WebSocket连接数: {debug_info["websocket_connections"]}</p>
<p>聊天总数: {debug_info["total_chats"]}</p>
<p>消息总数: {debug_info["total_messages"]}</p>
</div>
<div class="section">
<h2>聊天详情</h2>
{chats_html}
</div>
<div class="section">
<h2>操作</h2>
<button onclick="location.reload()">刷新页面</button>
<button onclick="window.location.href='/'">返回主页</button>
<button onclick="window.location.href='/api/contexts'">查看API数据</button>
</div>
<script>
console.log('调试信息:', {json.dumps(debug_info, ensure_ascii=False, indent=2)});
setTimeout(() => location.reload(), 5000); // 5秒自动刷新
</script>
</body>
</html>
'''
return web.Response(text=html_content, content_type='text/html')
async def add_message(self, chat_id: str, message: MessageRecv):
"""添加新消息到上下文"""
if chat_id not in self.contexts:
self.contexts[chat_id] = deque(maxlen=self.max_messages)
logger.debug(f"为聊天 {chat_id} 创建新的上下文队列")
context_msg = ContextMessage(message)
self.contexts[chat_id].append(context_msg)
# 统计当前总消息数
total_messages = sum(len(contexts) for contexts in self.contexts.values())
logger.info(f"✅ 添加消息到上下文 [总数: {total_messages}]: [{context_msg.group_name}] {context_msg.user_name}: {context_msg.content}")
# 调试:打印当前所有消息
logger.info("📝 当前上下文中的所有消息:")
for cid, contexts in self.contexts.items():
logger.info(f" 聊天 {cid}: {len(contexts)} 条消息")
for i, msg in enumerate(contexts):
logger.info(f" {i+1}. [{msg.timestamp.strftime('%H:%M:%S')}] {msg.user_name}: {msg.content[:30]}...")
# 广播更新给所有WebSocket连接
await self.broadcast_contexts()
async def send_contexts_to_websocket(self, ws: web.WebSocketResponse):
"""向单个WebSocket发送上下文数据"""
all_context_msgs = []
for _chat_id, contexts in self.contexts.items():
all_context_msgs.extend(list(contexts))
# 按时间排序,最新的在最后
all_context_msgs.sort(key=lambda x: x.timestamp)
# 转换为字典格式
contexts_data = [msg.to_dict() for msg in all_context_msgs[-self.max_messages:]]
data = {"contexts": contexts_data}
await ws.send_str(json.dumps(data, ensure_ascii=False))
async def broadcast_contexts(self):
"""向所有WebSocket连接广播上下文更新"""
if not self.websockets:
logger.debug("没有WebSocket连接跳过广播")
return
all_context_msgs = []
for _chat_id, contexts in self.contexts.items():
all_context_msgs.extend(list(contexts))
# 按时间排序,最新的在最后
all_context_msgs.sort(key=lambda x: x.timestamp)
# 转换为字典格式
contexts_data = [msg.to_dict() for msg in all_context_msgs[-self.max_messages:]]
data = {"contexts": contexts_data}
message = json.dumps(data, ensure_ascii=False)
logger.info(f"广播 {len(contexts_data)} 条消息到 {len(self.websockets)} 个WebSocket连接")
# 创建WebSocket列表的副本避免在遍历时修改
websockets_copy = self.websockets.copy()
removed_count = 0
for ws in websockets_copy:
if ws.closed:
if ws in self.websockets:
self.websockets.remove(ws)
removed_count += 1
else:
try:
await ws.send_str(message)
logger.debug("消息发送成功")
except Exception as e:
logger.error(f"发送WebSocket消息失败: {e}")
if ws in self.websockets:
self.websockets.remove(ws)
removed_count += 1
if removed_count > 0:
logger.debug(f"清理了 {removed_count} 个断开的WebSocket连接")
# 全局实例
_context_web_manager: Optional[ContextWebManager] = None
def get_context_web_manager() -> ContextWebManager:
"""获取上下文网页管理器实例"""
global _context_web_manager
if _context_web_manager is None:
_context_web_manager = ContextWebManager()
return _context_web_manager
async def init_context_web_manager():
"""初始化上下文网页管理器"""
manager = get_context_web_manager()
await manager.start_server()
return manager

View File

@@ -1,155 +0,0 @@
import asyncio
from typing import Dict, Tuple, Callable, Optional
from dataclasses import dataclass
from src.chat.message_receive.message import MessageRecvS4U
from src.common.logger import get_logger
logger = get_logger("gift_manager")
@dataclass
class PendingGift:
"""等待中的礼物消息"""
message: MessageRecvS4U
total_count: int
timer_task: asyncio.Task
callback: Callable[[MessageRecvS4U], None]
class GiftManager:
"""礼物管理器,提供防抖功能"""
def __init__(self):
"""初始化礼物管理器"""
self.pending_gifts: Dict[Tuple[str, str], PendingGift] = {}
self.debounce_timeout = 5.0 # 3秒防抖时间
async def handle_gift(self, message: MessageRecvS4U, callback: Optional[Callable[[MessageRecvS4U], None]] = None) -> bool:
"""处理礼物消息,返回是否应该立即处理
Args:
message: 礼物消息
callback: 防抖完成后的回调函数
Returns:
bool: False表示消息被暂存等待防抖True表示应该立即处理
"""
if not message.is_gift:
return True
# 构建礼物的唯一键:(发送人ID, 礼物名称)
gift_key = (message.message_info.user_info.user_id, message.gift_name)
# 如果已经有相同的礼物在等待中,则合并
if gift_key in self.pending_gifts:
await self._merge_gift(gift_key, message)
return False
# 创建新的等待礼物
await self._create_pending_gift(gift_key, message, callback)
return False
async def _merge_gift(self, gift_key: Tuple[str, str], new_message: MessageRecvS4U) -> None:
"""合并礼物消息"""
pending_gift = self.pending_gifts[gift_key]
# 取消之前的定时器
if not pending_gift.timer_task.cancelled():
pending_gift.timer_task.cancel()
# 累加礼物数量
try:
new_count = int(new_message.gift_count)
pending_gift.total_count += new_count
# 更新消息为最新的(保留最新的消息,但累加数量)
pending_gift.message = new_message
pending_gift.message.gift_count = str(pending_gift.total_count)
pending_gift.message.gift_info = f"{pending_gift.message.gift_name}:{pending_gift.total_count}"
except ValueError:
logger.warning(f"无法解析礼物数量: {new_message.gift_count}")
# 如果无法解析数量,保持原有数量不变
# 重新创建定时器
pending_gift.timer_task = asyncio.create_task(
self._gift_timeout(gift_key)
)
logger.debug(f"合并礼物: {gift_key}, 总数量: {pending_gift.total_count}")
async def _create_pending_gift(
self,
gift_key: Tuple[str, str],
message: MessageRecvS4U,
callback: Optional[Callable[[MessageRecvS4U], None]]
) -> None:
"""创建新的等待礼物"""
try:
initial_count = int(message.gift_count)
except ValueError:
initial_count = 1
logger.warning(f"无法解析礼物数量: {message.gift_count}默认设为1")
# 创建定时器任务
timer_task = asyncio.create_task(self._gift_timeout(gift_key))
# 创建等待礼物对象
pending_gift = PendingGift(
message=message,
total_count=initial_count,
timer_task=timer_task,
callback=callback
)
self.pending_gifts[gift_key] = pending_gift
logger.debug(f"创建等待礼物: {gift_key}, 初始数量: {initial_count}")
async def _gift_timeout(self, gift_key: Tuple[str, str]) -> None:
"""礼物防抖超时处理"""
try:
# 等待防抖时间
await asyncio.sleep(self.debounce_timeout)
# 获取等待中的礼物
if gift_key not in self.pending_gifts:
return
pending_gift = self.pending_gifts.pop(gift_key)
logger.info(f"礼物防抖完成: {gift_key}, 最终数量: {pending_gift.total_count}")
message = pending_gift.message
message.processed_plain_text = f"用户{message.message_info.user_info.user_nickname}送出了礼物{message.gift_name} x{pending_gift.total_count}"
# 执行回调
if pending_gift.callback:
try:
pending_gift.callback(message)
except Exception as e:
logger.error(f"礼物回调执行失败: {e}", exc_info=True)
except asyncio.CancelledError:
# 定时器被取消,不需要处理
pass
except Exception as e:
logger.error(f"礼物防抖处理异常: {e}", exc_info=True)
def get_pending_count(self) -> int:
"""获取当前等待中的礼物数量"""
return len(self.pending_gifts)
async def flush_all(self) -> None:
"""立即处理所有等待中的礼物"""
for gift_key in list(self.pending_gifts.keys()):
pending_gift = self.pending_gifts.get(gift_key)
if pending_gift and not pending_gift.timer_task.cancelled():
pending_gift.timer_task.cancel()
await self._gift_timeout(gift_key)
# 创建全局礼物管理器实例
gift_manager = GiftManager()

View File

@@ -1,14 +0,0 @@
class InternalManager:
def __init__(self):
self.now_internal_state = str()
def set_internal_state(self,internal_state:str):
self.now_internal_state = internal_state
def get_internal_state(self):
return self.now_internal_state
def get_internal_state_str(self):
return f"你今天的直播内容是直播QQ水群你正在一边回复弹幕一边在QQ群聊天你在QQ群聊天中产生的想法是{self.now_internal_state}"
internal_manager = InternalManager()

View File

@@ -1,579 +0,0 @@
import asyncio
import traceback
import time
import random
from typing import Optional, Dict, Tuple, List # 导入类型提示
from maim_message import UserInfo, Seg
from src.common.logger import get_logger
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
from .s4u_stream_generator import S4UStreamGenerator
from src.chat.message_receive.message import MessageSending, MessageRecv, MessageRecvS4U
from src.config.config import global_config
from src.common.message.api import get_global_api
from src.chat.message_receive.storage import MessageStorage
from .s4u_watching_manager import watching_manager
import json
from .s4u_mood_manager import mood_manager
from src.mais4u.s4u_config import s4u_config
from src.person_info.person_info import get_person_id
from .super_chat_manager import get_super_chat_manager
from .yes_or_no import yes_or_no_head
logger = get_logger("S4U_chat")
class MessageSenderContainer:
"""一个简单的容器,用于按顺序发送消息并模拟打字效果。"""
def __init__(self, chat_stream: ChatStream, original_message: MessageRecv):
self.chat_stream = chat_stream
self.original_message = original_message
self.queue = asyncio.Queue()
self.storage = MessageStorage()
self._task: Optional[asyncio.Task] = None
self._paused_event = asyncio.Event()
self._paused_event.set() # 默认设置为非暂停状态
self.msg_id = ""
self.last_msg_id = ""
self.voice_done = ""
async def add_message(self, chunk: str):
"""向队列中添加一个消息块。"""
await self.queue.put(chunk)
async def close(self):
"""表示没有更多消息了,关闭队列。"""
await self.queue.put(None) # Sentinel
def pause(self):
"""暂停发送。"""
self._paused_event.clear()
def resume(self):
"""恢复发送。"""
self._paused_event.set()
def _calculate_typing_delay(self, text: str) -> float:
"""根据文本长度计算模拟打字延迟。"""
chars_per_second = s4u_config.chars_per_second
min_delay = s4u_config.min_typing_delay
max_delay = s4u_config.max_typing_delay
delay = len(text) / chars_per_second
return max(min_delay, min(delay, max_delay))
async def _send_worker(self):
"""从队列中取出消息并发送。"""
while True:
try:
# This structure ensures that task_done() is called for every item retrieved,
# even if the worker is cancelled while processing the item.
chunk = await self.queue.get()
except asyncio.CancelledError:
break
try:
if chunk is None:
break
# Check for pause signal *after* getting an item.
await self._paused_event.wait()
# 根据配置选择延迟模式
if s4u_config.enable_dynamic_typing_delay:
delay = self._calculate_typing_delay(chunk)
else:
delay = s4u_config.typing_delay
await asyncio.sleep(delay)
message_segment = Seg(type="tts_text", data=f"{self.msg_id}:{chunk}")
bot_message = MessageSending(
message_id=self.msg_id,
chat_stream=self.chat_stream,
bot_user_info=UserInfo(
user_id=global_config.bot.qq_account,
user_nickname=global_config.bot.nickname,
platform=self.original_message.message_info.platform,
),
sender_info=self.original_message.message_info.user_info,
message_segment=message_segment,
reply=self.original_message,
is_emoji=False,
apply_set_reply_logic=True,
reply_to=f"{self.original_message.message_info.user_info.platform}:{self.original_message.message_info.user_info.user_id}",
)
await bot_message.process()
await get_global_api().send_message(bot_message)
logger.info(f"已将消息 '{self.msg_id}:{chunk}' 发往平台 '{bot_message.message_info.platform}'")
message_segment = Seg(type="text", data=chunk)
bot_message = MessageSending(
message_id=self.msg_id,
chat_stream=self.chat_stream,
bot_user_info=UserInfo(
user_id=global_config.bot.qq_account,
user_nickname=global_config.bot.nickname,
platform=self.original_message.message_info.platform,
),
sender_info=self.original_message.message_info.user_info,
message_segment=message_segment,
reply=self.original_message,
is_emoji=False,
apply_set_reply_logic=True,
reply_to=f"{self.original_message.message_info.user_info.platform}:{self.original_message.message_info.user_info.user_id}",
)
await bot_message.process()
await self.storage.store_message(bot_message, self.chat_stream)
except Exception as e:
logger.error(f"[消息流: {self.chat_stream.stream_id}] 消息发送或存储时出现错误: {e}", exc_info=True)
finally:
# CRUCIAL: Always call task_done() for any item that was successfully retrieved.
self.queue.task_done()
def start(self):
"""启动发送任务。"""
if self._task is None:
self._task = asyncio.create_task(self._send_worker())
async def join(self):
"""等待所有消息发送完毕。"""
if self._task:
await self._task
class S4UChatManager:
def __init__(self):
self.s4u_chats: Dict[str, "S4UChat"] = {}
def get_or_create_chat(self, chat_stream: ChatStream) -> "S4UChat":
if chat_stream.stream_id not in self.s4u_chats:
stream_name = get_chat_manager().get_stream_name(chat_stream.stream_id) or chat_stream.stream_id
logger.info(f"Creating new S4UChat for stream: {stream_name}")
self.s4u_chats[chat_stream.stream_id] = S4UChat(chat_stream)
return self.s4u_chats[chat_stream.stream_id]
if not s4u_config.enable_s4u:
s4u_chat_manager = None
else:
s4u_chat_manager = S4UChatManager()
def get_s4u_chat_manager() -> S4UChatManager:
return s4u_chat_manager
class S4UChat:
def __init__(self, chat_stream: ChatStream):
"""初始化 S4UChat 实例。"""
self.chat_stream = chat_stream
self.stream_id = chat_stream.stream_id
self.stream_name = get_chat_manager().get_stream_name(self.stream_id) or self.stream_id
# 两个消息队列
self._vip_queue = asyncio.PriorityQueue()
self._normal_queue = asyncio.PriorityQueue()
self._entry_counter = 0 # 保证FIFO的全局计数器
self._new_message_event = asyncio.Event() # 用于唤醒处理器
self._processing_task = asyncio.create_task(self._message_processor())
self._current_generation_task: Optional[asyncio.Task] = None
# 当前消息的元数据:(队列类型, 优先级分数, 计数器, 消息对象)
self._current_message_being_replied: Optional[Tuple[str, float, int, MessageRecv]] = None
self._is_replying = False
self.gpt = S4UStreamGenerator()
self.gpt.chat_stream = self.chat_stream
self.interest_dict: Dict[str, float] = {} # 用户兴趣分
self.internal_message :List[MessageRecvS4U] = []
self.msg_id = ""
self.voice_done = ""
logger.info(f"[{self.stream_name}] S4UChat with two-queue system initialized.")
def _get_priority_info(self, message: MessageRecv) -> dict:
"""安全地从消息中提取和解析 priority_info"""
priority_info_raw = message.priority_info
priority_info = {}
if isinstance(priority_info_raw, str):
try:
priority_info = json.loads(priority_info_raw)
except json.JSONDecodeError:
logger.warning(f"Failed to parse priority_info JSON: {priority_info_raw}")
elif isinstance(priority_info_raw, dict):
priority_info = priority_info_raw
return priority_info
def _is_vip(self, priority_info: dict) -> bool:
"""检查消息是否来自VIP用户。"""
return priority_info.get("message_type") == "vip"
def _get_interest_score(self, user_id: str) -> float:
"""获取用户的兴趣分默认为1.0"""
return self.interest_dict.get(user_id, 1.0)
def go_processing(self):
if self.voice_done == self.last_msg_id:
return True
return False
def _calculate_base_priority_score(self, message: MessageRecv, priority_info: dict) -> float:
"""
为消息计算基础优先级分数。分数越高,优先级越高。
"""
score = 0.0
# 加上消息自带的优先级
score += priority_info.get("message_priority", 0.0)
# 加上用户的固有兴趣分
score += self._get_interest_score(message.message_info.user_info.user_id)
return score
def decay_interest_score(self):
for person_id, score in self.interest_dict.items():
if score > 0:
self.interest_dict[person_id] = score * 0.95
else:
self.interest_dict[person_id] = 0
async def add_message(self, message: MessageRecvS4U|MessageRecv) -> None:
self.decay_interest_score()
"""根据VIP状态和中断逻辑将消息放入相应队列。"""
user_id = message.message_info.user_info.user_id
platform = message.message_info.platform
person_id = get_person_id(platform, user_id)
# try:
# is_gift = message.is_gift
# is_superchat = message.is_superchat
# # print(is_gift)
# # print(is_superchat)
# if is_gift:
# await self.relationship_builder.build_relation(immediate_build=person_id)
# # 安全地增加兴趣分如果person_id不存在则先初始化为1.0
# current_score = self.interest_dict.get(person_id, 1.0)
# self.interest_dict[person_id] = current_score + 0.1 * message.gift_count
# elif is_superchat:
# await self.relationship_builder.build_relation(immediate_build=person_id)
# # 安全地增加兴趣分如果person_id不存在则先初始化为1.0
# current_score = self.interest_dict.get(person_id, 1.0)
# self.interest_dict[person_id] = current_score + 0.1 * float(message.superchat_price)
# # 添加SuperChat到管理器
# super_chat_manager = get_super_chat_manager()
# await super_chat_manager.add_superchat(message)
# else:
# await self.relationship_builder.build_relation(20)
# except Exception:
# traceback.print_exc()
logger.info(f"[{self.stream_name}] 消息处理完毕,消息内容:{message.processed_plain_text}")
priority_info = self._get_priority_info(message)
is_vip = self._is_vip(priority_info)
new_priority_score = self._calculate_base_priority_score(message, priority_info)
should_interrupt = False
if (s4u_config.enable_message_interruption and
self._current_generation_task and not self._current_generation_task.done()):
if self._current_message_being_replied:
current_queue, current_priority, _, current_msg = self._current_message_being_replied
# 规则VIP从不被打断
if current_queue == "vip":
pass # Do nothing
# 规则:普通消息可以被打断
elif current_queue == "normal":
# VIP消息可以打断普通消息
if is_vip:
should_interrupt = True
logger.info(f"[{self.stream_name}] VIP message received, interrupting current normal task.")
# 普通消息的内部打断逻辑
else:
new_sender_id = message.message_info.user_info.user_id
current_sender_id = current_msg.message_info.user_info.user_id
# 新消息优先级更高
if new_priority_score > current_priority:
should_interrupt = True
logger.info(f"[{self.stream_name}] New normal message has higher priority, interrupting.")
# 同用户,新消息的优先级不能更低
elif new_sender_id == current_sender_id and new_priority_score >= current_priority:
should_interrupt = True
logger.info(f"[{self.stream_name}] Same user sent new message, interrupting.")
if should_interrupt:
if self.gpt.partial_response:
logger.warning(
f"[{self.stream_name}] Interrupting reply. Already generated: '{self.gpt.partial_response}'"
)
self._current_generation_task.cancel()
# asyncio.PriorityQueue 是最小堆,所以我们存入分数的相反数
# 这样,原始分数越高的消息,在队列中的优先级数字越小,越靠前
item = (-new_priority_score, self._entry_counter, time.time(), message)
if is_vip and s4u_config.vip_queue_priority:
await self._vip_queue.put(item)
logger.info(f"[{self.stream_name}] VIP message added to queue.")
else:
await self._normal_queue.put(item)
self._entry_counter += 1
self._new_message_event.set() # 唤醒处理器
def _cleanup_old_normal_messages(self):
"""清理普通队列中不在最近N条消息范围内的消息"""
if not s4u_config.enable_old_message_cleanup or self._normal_queue.empty():
return
# 计算阈值:保留最近 recent_message_keep_count 条消息
cutoff_counter = max(0, self._entry_counter - s4u_config.recent_message_keep_count)
# 临时存储需要保留的消息
temp_messages = []
removed_count = 0
# 取出所有普通队列中的消息
while not self._normal_queue.empty():
try:
item = self._normal_queue.get_nowait()
neg_priority, entry_count, timestamp, message = item
# 如果消息在最近N条消息范围内保留它
logger.info(f"检查消息:{message.processed_plain_text},entry_count:{entry_count} cutoff_counter:{cutoff_counter}")
if entry_count >= cutoff_counter:
temp_messages.append(item)
else:
removed_count += 1
self._normal_queue.task_done() # 标记被移除的任务为完成
except asyncio.QueueEmpty:
break
# 将保留的消息重新放入队列
for item in temp_messages:
self._normal_queue.put_nowait(item)
if removed_count > 0:
logger.info(f"消息{message.processed_plain_text}超过{s4u_config.recent_message_keep_count}现在counter:{self._entry_counter}被移除")
logger.info(f"[{self.stream_name}] Cleaned up {removed_count} old normal messages outside recent {s4u_config.recent_message_keep_count} range.")
async def _message_processor(self):
"""调度器优先处理VIP队列然后处理普通队列。"""
while True:
try:
# 等待有新消息的信号,避免空转
await self._new_message_event.wait()
self._new_message_event.clear()
# 清理普通队列中的过旧消息
self._cleanup_old_normal_messages()
# 优先处理VIP队列
if not self._vip_queue.empty():
neg_priority, entry_count, _, message = self._vip_queue.get_nowait()
priority = -neg_priority
queue_name = "vip"
# 其次处理普通队列
elif not self._normal_queue.empty():
neg_priority, entry_count, timestamp, message = self._normal_queue.get_nowait()
priority = -neg_priority
# 检查普通消息是否超时
if time.time() - timestamp > s4u_config.message_timeout_seconds:
logger.info(
f"[{self.stream_name}] Discarding stale normal message: {message.processed_plain_text[:20]}..."
)
self._normal_queue.task_done()
continue # 处理下一条
queue_name = "normal"
else:
if self.internal_message:
message = self.internal_message[-1]
self.internal_message = []
priority = 0
neg_priority = 0
entry_count = 0
queue_name = "internal"
logger.info(f"[{self.stream_name}] normal/vip 队列都空,触发 internal_message 回复: {getattr(message, 'processed_plain_text', str(message))[:20]}...")
else:
continue # 没有消息了,回去等事件
self._current_message_being_replied = (queue_name, priority, entry_count, message)
self._current_generation_task = asyncio.create_task(self._generate_and_send(message))
try:
await self._current_generation_task
except asyncio.CancelledError:
logger.info(
f"[{self.stream_name}] Reply generation was interrupted externally for {queue_name} message. The message will be discarded."
)
# 被中断的消息应该被丢弃,而不是重新排队,以响应最新的用户输入。
# 旧的重新入队逻辑会导致所有中断的消息最终都被回复。
except Exception as e:
logger.error(f"[{self.stream_name}] _generate_and_send task error: {e}", exc_info=True)
finally:
self._current_generation_task = None
self._current_message_being_replied = None
# 标记任务完成
if queue_name == "vip":
self._vip_queue.task_done()
elif queue_name == "internal":
# 如果使用 internal_message 生成回复,则不从 normal 队列中移除
pass
else:
self._normal_queue.task_done()
# 检查是否还有任务,有则立即再次触发事件
if not self._vip_queue.empty() or not self._normal_queue.empty():
self._new_message_event.set()
except asyncio.CancelledError:
logger.info(f"[{self.stream_name}] Message processor is shutting down.")
break
except Exception as e:
logger.error(f"[{self.stream_name}] Message processor main loop error: {e}", exc_info=True)
await asyncio.sleep(1)
def get_processing_message_id(self):
self.last_msg_id = self.msg_id
self.msg_id = f"{time.time()}_{random.randint(1000, 9999)}"
async def _generate_and_send(self, message: MessageRecv):
"""为单个消息生成文本回复。整个过程可以被中断。"""
self._is_replying = True
total_chars_sent = 0 # 跟踪发送的总字符数
self.get_processing_message_id()
# 视线管理:开始生成回复时切换视线状态
chat_watching = watching_manager.get_watching_by_chat_id(self.stream_id)
if message.is_internal:
await chat_watching.on_internal_message_start()
else:
await chat_watching.on_reply_start()
sender_container = MessageSenderContainer(self.chat_stream, message)
sender_container.start()
async def generate_and_send_inner():
nonlocal total_chars_sent
logger.info(f"[S4U] 开始为消息生成文本和音频流: '{message.processed_plain_text[:30]}...'")
if s4u_config.enable_streaming_output:
logger.info("[S4U] 开始流式输出")
# 流式输出,边生成边发送
gen = self.gpt.generate_response(message, "")
async for chunk in gen:
sender_container.msg_id = self.msg_id
await sender_container.add_message(chunk)
total_chars_sent += len(chunk)
else:
logger.info("[S4U] 开始一次性输出")
# 一次性输出先收集所有chunk
all_chunks = []
gen = self.gpt.generate_response(message, "")
async for chunk in gen:
all_chunks.append(chunk)
total_chars_sent += len(chunk)
# 一次性发送
sender_container.msg_id = self.msg_id
await sender_container.add_message("".join(all_chunks))
try:
try:
await asyncio.wait_for(generate_and_send_inner(), timeout=10)
except asyncio.TimeoutError:
logger.warning(f"[{self.stream_name}] 回复生成超时,发送默认回复。")
sender_container.msg_id = self.msg_id
await sender_container.add_message("麦麦不知道哦")
total_chars_sent = len("麦麦不知道哦")
mood = mood_manager.get_mood_by_chat_id(self.stream_id)
await yes_or_no_head(text = total_chars_sent,emotion = mood.mood_state,chat_history=message.processed_plain_text,chat_id=self.stream_id)
# 等待所有文本消息发送完成
await sender_container.close()
await sender_container.join()
await chat_watching.on_thinking_finished()
start_time = time.time()
logged = False
while not self.go_processing():
if time.time() - start_time > 60:
logger.warning(f"[{self.stream_name}] 等待消息发送超时60秒强制跳出循环。")
break
if not logged:
logger.info(f"[{self.stream_name}] 等待消息发送完成...")
logged = True
await asyncio.sleep(0.2)
logger.info(f"[{self.stream_name}] 所有文本块处理完毕。")
except asyncio.CancelledError:
logger.info(f"[{self.stream_name}] 回复流程(文本)被中断。")
raise # 将取消异常向上传播
except Exception as e:
traceback.print_exc()
logger.error(f"[{self.stream_name}] 回复生成过程中出现错误: {e}", exc_info=True)
# 回复生成实时展示:清空内容(出错时)
finally:
self._is_replying = False
# 视线管理:回复结束时切换视线状态
chat_watching = watching_manager.get_watching_by_chat_id(self.stream_id)
await chat_watching.on_reply_finished()
# 确保发送器被妥善关闭(即使已关闭,再次调用也是安全的)
sender_container.resume()
if not sender_container._task.done():
await sender_container.close()
await sender_container.join()
logger.info(f"[{self.stream_name}] _generate_and_send 任务结束,资源已清理。")
async def shutdown(self):
"""平滑关闭处理任务。"""
logger.info(f"正在关闭 S4UChat: {self.stream_name}")
# 取消正在运行的任务
if self._current_generation_task and not self._current_generation_task.done():
self._current_generation_task.cancel()
if self._processing_task and not self._processing_task.done():
self._processing_task.cancel()
# 等待任务响应取消
try:
await self._processing_task
except asyncio.CancelledError:
logger.info(f"处理任务已成功取消: {self.stream_name}")

View File

@@ -1,456 +0,0 @@
import asyncio
import json
import time
from src.chat.message_receive.message import MessageRecv
from src.llm_models.utils_model import LLMRequest
from src.common.logger import get_logger
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive
from src.config.config import global_config, model_config
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.manager.async_task_manager import AsyncTask, async_task_manager
from src.plugin_system.apis import send_api
from src.mais4u.s4u_config import s4u_config
"""
情绪管理系统使用说明:
1. 情绪数值系统:
- 情绪包含四个维度joy(喜), anger(怒), sorrow(哀), fear(惧)
- 每个维度的取值范围为1-10
- 当情绪发生变化时会自动发送到ws端处理
2. 情绪更新机制:
- 接收到新消息时会更新情绪状态
- 定期进行情绪回归(冷静下来)
- 每次情绪变化都会发送到ws端格式为
type: "emotion"
data: {"joy": 5, "anger": 1, "sorrow": 1, "fear": 1}
3. ws端处理
- 本地只负责情绪计算和发送情绪数值
- 表情渲染和动作由ws端根据情绪数值处理
"""
logger = get_logger("mood")
def init_prompt():
Prompt(
"""
{chat_talking_prompt}
以上是直播间里正在进行的对话
{indentify_block}
你刚刚的情绪状态是:{mood_state}
现在,发送了消息,引起了你的注意,你对其进行了阅读和思考,请你输出一句话描述你新的情绪状态,不要输出任何其他内容
请只输出情绪状态,不要输出其他内容:
""",
"change_mood_prompt_vtb",
)
Prompt(
"""
{chat_talking_prompt}
以上是直播间里最近的对话
{indentify_block}
你之前的情绪状态是:{mood_state}
距离你上次关注直播间消息已经过去了一段时间,你冷静了下来,请你输出一句话描述你现在的情绪状态
请只输出情绪状态,不要输出其他内容:
""",
"regress_mood_prompt_vtb",
)
Prompt(
"""
{chat_talking_prompt}
以上是直播间里正在进行的对话
{indentify_block}
你刚刚的情绪状态是:{mood_state}
具体来说从1-10分你的情绪状态是
喜(Joy): {joy}
怒(Anger): {anger}
哀(Sorrow): {sorrow}
惧(Fear): {fear}
现在,发送了消息,引起了你的注意,你对其进行了阅读和思考。请基于对话内容,评估你新的情绪状态。
请以JSON格式输出你新的情绪状态包含"喜怒哀惧"四个维度每个维度的取值范围为1-10。
键值请使用英文: "joy", "anger", "sorrow", "fear".
例如: {{"joy": 5, "anger": 1, "sorrow": 1, "fear": 1}}
不要输出任何其他内容只输出JSON。
""",
"change_mood_numerical_prompt",
)
Prompt(
"""
{chat_talking_prompt}
以上是直播间里最近的对话
{indentify_block}
你之前的情绪状态是:{mood_state}
具体来说从1-10分你的情绪状态是
喜(Joy): {joy}
怒(Anger): {anger}
哀(Sorrow): {sorrow}
惧(Fear): {fear}
距离你上次关注直播间消息已经过去了一段时间,你冷静了下来。请基于此,评估你现在的情绪状态。
请以JSON格式输出你新的情绪状态包含"喜怒哀惧"四个维度每个维度的取值范围为1-10。
键值请使用英文: "joy", "anger", "sorrow", "fear".
例如: {{"joy": 5, "anger": 1, "sorrow": 1, "fear": 1}}
不要输出任何其他内容只输出JSON。
""",
"regress_mood_numerical_prompt",
)
class ChatMood:
def __init__(self, chat_id: str):
self.chat_id: str = chat_id
self.mood_state: str = "感觉很平静"
self.mood_values: dict[str, int] = {"joy": 5, "anger": 1, "sorrow": 1, "fear": 1}
self.regression_count: int = 0
self.mood_model = LLMRequest(model_set=model_config.model_task_config.emotion, request_type="mood_text")
self.mood_model_numerical = LLMRequest(
model_set=model_config.model_task_config.emotion, request_type="mood_numerical"
)
self.last_change_time: float = 0
# 发送初始情绪状态到ws端
asyncio.create_task(self.send_emotion_update(self.mood_values))
def _parse_numerical_mood(self, response: str) -> dict[str, int] | None:
try:
# The LLM might output markdown with json inside
if "```json" in response:
response = response.split("```json")[1].split("```")[0]
elif "```" in response:
response = response.split("```")[1].split("```")[0]
data = json.loads(response)
# Validate
required_keys = {"joy", "anger", "sorrow", "fear"}
if not required_keys.issubset(data.keys()):
logger.warning(f"Numerical mood response missing keys: {response}")
return None
for key in required_keys:
value = data[key]
if not isinstance(value, int) or not (1 <= value <= 10):
logger.warning(f"Numerical mood response invalid value for {key}: {value} in {response}")
return None
return {key: data[key] for key in required_keys}
except json.JSONDecodeError:
logger.warning(f"Failed to parse numerical mood JSON: {response}")
return None
except Exception as e:
logger.error(f"Error parsing numerical mood: {e}, response: {response}")
return None
async def update_mood_by_message(self, message: MessageRecv):
self.regression_count = 0
message_time: float = message.message_info.time # type: ignore
message_list_before_now = get_raw_msg_by_timestamp_with_chat_inclusive(
chat_id=self.chat_id,
timestamp_start=self.last_change_time,
timestamp_end=message_time,
limit=10,
limit_mode="last",
)
chat_talking_prompt = build_readable_messages(
message_list_before_now,
replace_bot_name=True,
timestamp_mode="normal_no_YMD",
read_mark=0.0,
truncate=True,
show_actions=True,
)
bot_name = global_config.bot.nickname
if global_config.bot.alias_names:
bot_nickname = f",也有人叫你{','.join(global_config.bot.alias_names)}"
else:
bot_nickname = ""
prompt_personality = global_config.personality.personality
indentify_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}"
async def _update_text_mood():
prompt = await global_prompt_manager.format_prompt(
"change_mood_prompt_vtb",
chat_talking_prompt=chat_talking_prompt,
indentify_block=indentify_block,
mood_state=self.mood_state,
)
logger.debug(f"text mood prompt: {prompt}")
response, (reasoning_content, _, _) = await self.mood_model.generate_response_async(
prompt=prompt, temperature=0.7
)
logger.info(f"text mood response: {response}")
logger.debug(f"text mood reasoning_content: {reasoning_content}")
return response
async def _update_numerical_mood():
prompt = await global_prompt_manager.format_prompt(
"change_mood_numerical_prompt",
chat_talking_prompt=chat_talking_prompt,
indentify_block=indentify_block,
mood_state=self.mood_state,
joy=self.mood_values["joy"],
anger=self.mood_values["anger"],
sorrow=self.mood_values["sorrow"],
fear=self.mood_values["fear"],
)
logger.debug(f"numerical mood prompt: {prompt}")
response, (reasoning_content, _, _) = await self.mood_model_numerical.generate_response_async(
prompt=prompt, temperature=0.4
)
logger.info(f"numerical mood response: {response}")
logger.debug(f"numerical mood reasoning_content: {reasoning_content}")
return self._parse_numerical_mood(response)
results = await asyncio.gather(_update_text_mood(), _update_numerical_mood())
text_mood_response, numerical_mood_response = results
if text_mood_response:
self.mood_state = text_mood_response
if numerical_mood_response:
_old_mood_values = self.mood_values.copy()
self.mood_values = numerical_mood_response
# 发送情绪更新到ws端
await self.send_emotion_update(self.mood_values)
logger.info(f"[{self.chat_id}] 情绪变化: {_old_mood_values} -> {self.mood_values}")
self.last_change_time = message_time
async def regress_mood(self):
message_time = time.time()
message_list_before_now = get_raw_msg_by_timestamp_with_chat_inclusive(
chat_id=self.chat_id,
timestamp_start=self.last_change_time,
timestamp_end=message_time,
limit=5,
limit_mode="last",
)
chat_talking_prompt = build_readable_messages(
message_list_before_now,
replace_bot_name=True,
timestamp_mode="normal_no_YMD",
read_mark=0.0,
truncate=True,
show_actions=True,
)
bot_name = global_config.bot.nickname
if global_config.bot.alias_names:
bot_nickname = f",也有人叫你{','.join(global_config.bot.alias_names)}"
else:
bot_nickname = ""
prompt_personality = global_config.personality.personality
indentify_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}"
async def _regress_text_mood():
prompt = await global_prompt_manager.format_prompt(
"regress_mood_prompt_vtb",
chat_talking_prompt=chat_talking_prompt,
indentify_block=indentify_block,
mood_state=self.mood_state,
)
logger.debug(f"text regress prompt: {prompt}")
response, (reasoning_content, _, _) = await self.mood_model.generate_response_async(
prompt=prompt, temperature=0.7
)
logger.info(f"text regress response: {response}")
logger.debug(f"text regress reasoning_content: {reasoning_content}")
return response
async def _regress_numerical_mood():
prompt = await global_prompt_manager.format_prompt(
"regress_mood_numerical_prompt",
chat_talking_prompt=chat_talking_prompt,
indentify_block=indentify_block,
mood_state=self.mood_state,
joy=self.mood_values["joy"],
anger=self.mood_values["anger"],
sorrow=self.mood_values["sorrow"],
fear=self.mood_values["fear"],
)
logger.debug(f"numerical regress prompt: {prompt}")
response, (reasoning_content, _, _) = await self.mood_model_numerical.generate_response_async(
prompt=prompt,
temperature=0.4,
)
logger.info(f"numerical regress response: {response}")
logger.debug(f"numerical regress reasoning_content: {reasoning_content}")
return self._parse_numerical_mood(response)
results = await asyncio.gather(_regress_text_mood(), _regress_numerical_mood())
text_mood_response, numerical_mood_response = results
if text_mood_response:
self.mood_state = text_mood_response
if numerical_mood_response:
_old_mood_values = self.mood_values.copy()
self.mood_values = numerical_mood_response
# 发送情绪更新到ws端
await self.send_emotion_update(self.mood_values)
logger.info(f"[{self.chat_id}] 情绪回归: {_old_mood_values} -> {self.mood_values}")
self.regression_count += 1
async def send_emotion_update(self, mood_values: dict[str, int]):
"""发送情绪更新到ws端"""
emotion_data = {
"joy": mood_values.get("joy", 5),
"anger": mood_values.get("anger", 1),
"sorrow": mood_values.get("sorrow", 1),
"fear": mood_values.get("fear", 1),
}
await send_api.custom_to_stream(
message_type="emotion",
content=emotion_data,
stream_id=self.chat_id,
storage_message=False,
show_log=True,
)
logger.info(f"[{self.chat_id}] 发送情绪更新: {emotion_data}")
class MoodRegressionTask(AsyncTask):
def __init__(self, mood_manager: "MoodManager"):
super().__init__(task_name="MoodRegressionTask", run_interval=30)
self.mood_manager = mood_manager
self.run_count = 0
async def run(self):
self.run_count += 1
logger.info(f"[回归任务] 第{self.run_count}次检查,当前管理{len(self.mood_manager.mood_list)}个聊天的情绪状态")
now = time.time()
regression_executed = 0
for mood in self.mood_manager.mood_list:
chat_info = f"chat {mood.chat_id}"
if mood.last_change_time == 0:
logger.debug(f"[回归任务] {chat_info} 尚未有情绪变化,跳过回归")
continue
time_since_last_change = now - mood.last_change_time
# 检查是否有极端情绪需要快速回归
high_emotions = {k: v for k, v in mood.mood_values.items() if v >= 8}
has_extreme_emotion = len(high_emotions) > 0
# 回归条件1. 正常时间间隔(120s) 或 2. 有极端情绪且距上次变化>=30s
should_regress = False
regress_reason = ""
if time_since_last_change > 120:
should_regress = True
regress_reason = f"常规回归(距上次变化{int(time_since_last_change)}秒)"
elif has_extreme_emotion and time_since_last_change > 30:
should_regress = True
high_emotion_str = ", ".join([f"{k}={v}" for k, v in high_emotions.items()])
regress_reason = f"极端情绪快速回归({high_emotion_str}, 距上次变化{int(time_since_last_change)}秒)"
if should_regress:
if mood.regression_count >= 3:
logger.debug(f"[回归任务] {chat_info} 已达到最大回归次数(3次),停止回归")
continue
logger.info(
f"[回归任务] {chat_info} 开始情绪回归 ({regress_reason},第{mood.regression_count + 1}次回归)"
)
await mood.regress_mood()
regression_executed += 1
else:
if has_extreme_emotion:
remaining_time = 5 - time_since_last_change
high_emotion_str = ", ".join([f"{k}={v}" for k, v in high_emotions.items()])
logger.debug(
f"[回归任务] {chat_info} 存在极端情绪({high_emotion_str}),距离快速回归还需等待{int(remaining_time)}"
)
else:
remaining_time = 120 - time_since_last_change
logger.debug(f"[回归任务] {chat_info} 距离回归还需等待{int(remaining_time)}")
if regression_executed > 0:
logger.info(f"[回归任务] 本次执行了{regression_executed}个聊天的情绪回归")
else:
logger.debug("[回归任务] 本次没有符合回归条件的聊天")
class MoodManager:
def __init__(self):
self.mood_list: list[ChatMood] = []
"""当前情绪状态"""
self.task_started: bool = False
async def start(self):
"""启动情绪回归后台任务"""
if self.task_started:
return
logger.info("启动情绪管理任务...")
# 启动情绪回归任务
regression_task = MoodRegressionTask(self)
await async_task_manager.add_task(regression_task)
self.task_started = True
logger.info("情绪管理任务已启动(情绪回归)")
def get_mood_by_chat_id(self, chat_id: str) -> ChatMood:
for mood in self.mood_list:
if mood.chat_id == chat_id:
return mood
new_mood = ChatMood(chat_id)
self.mood_list.append(new_mood)
return new_mood
def reset_mood_by_chat_id(self, chat_id: str):
for mood in self.mood_list:
if mood.chat_id == chat_id:
mood.mood_state = "感觉很平静"
mood.mood_values = {"joy": 5, "anger": 1, "sorrow": 1, "fear": 1}
mood.regression_count = 0
# 发送重置后的情绪状态到ws端
asyncio.create_task(mood.send_emotion_update(mood.mood_values))
return
# 如果没有找到现有的mood创建新的
new_mood = ChatMood(chat_id)
self.mood_list.append(new_mood)
# 发送初始情绪状态到ws端
asyncio.create_task(new_mood.send_emotion_update(new_mood.mood_values))
if s4u_config.enable_s4u:
init_prompt()
mood_manager = MoodManager()
else:
mood_manager = None
"""全局情绪管理器"""

View File

@@ -1,264 +0,0 @@
import asyncio
import math
from typing import Tuple
from src.chat.memory_system.Hippocampus import hippocampus_manager
from src.chat.message_receive.message import MessageRecv, MessageRecvS4U
from maim_message.message_base import GroupInfo
from src.chat.message_receive.storage import MessageStorage
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.utils.timer_calculator import Timer
from src.chat.utils.utils import is_mentioned_bot_in_message
from src.common.logger import get_logger
from src.config.config import global_config
from src.mais4u.mais4u_chat.body_emotion_action_manager import action_manager
from src.mais4u.mais4u_chat.s4u_mood_manager import mood_manager
from src.mais4u.mais4u_chat.s4u_watching_manager import watching_manager
from src.mais4u.mais4u_chat.context_web_manager import get_context_web_manager
from src.mais4u.mais4u_chat.gift_manager import gift_manager
from src.mais4u.mais4u_chat.screen_manager import screen_manager
from .s4u_chat import get_s4u_chat_manager
# from ..message_receive.message_buffer import message_buffer
logger = get_logger("chat")
async def _calculate_interest(message: MessageRecv) -> Tuple[float, bool]:
"""计算消息的兴趣度
Args:
message: 待处理的消息对象
Returns:
Tuple[float, bool]: (兴趣度, 是否被提及)
"""
is_mentioned, _ = is_mentioned_bot_in_message(message)
interested_rate = 0.0
if global_config.memory.enable_memory:
with Timer("记忆激活"):
interested_rate,_ ,_= await hippocampus_manager.get_activate_from_text(
message.processed_plain_text,
fast_retrieval=True,
)
logger.debug(f"记忆激活率: {interested_rate:.2f}")
text_len = len(message.processed_plain_text)
# 根据文本长度分布调整兴趣度,采用分段函数实现更精确的兴趣度计算
# 基于实际分布0-5字符(26.57%), 6-10字符(27.18%), 11-20字符(22.76%), 21-30字符(10.33%), 31+字符(13.86%)
if text_len == 0:
base_interest = 0.01 # 空消息最低兴趣度
elif text_len <= 5:
# 1-5字符线性增长 0.01 -> 0.03
base_interest = 0.01 + (text_len - 1) * (0.03 - 0.01) / 4
elif text_len <= 10:
# 6-10字符线性增长 0.03 -> 0.06
base_interest = 0.03 + (text_len - 5) * (0.06 - 0.03) / 5
elif text_len <= 20:
# 11-20字符线性增长 0.06 -> 0.12
base_interest = 0.06 + (text_len - 10) * (0.12 - 0.06) / 10
elif text_len <= 30:
# 21-30字符线性增长 0.12 -> 0.18
base_interest = 0.12 + (text_len - 20) * (0.18 - 0.12) / 10
elif text_len <= 50:
# 31-50字符线性增长 0.18 -> 0.22
base_interest = 0.18 + (text_len - 30) * (0.22 - 0.18) / 20
elif text_len <= 100:
# 51-100字符线性增长 0.22 -> 0.26
base_interest = 0.22 + (text_len - 50) * (0.26 - 0.22) / 50
else:
# 100+字符:对数增长 0.26 -> 0.3,增长率递减
base_interest = 0.26 + (0.3 - 0.26) * (math.log10(text_len - 99) / math.log10(901)) # 1000-99=901
# 确保在范围内
base_interest = min(max(base_interest, 0.01), 0.3)
interested_rate += base_interest
if is_mentioned:
interest_increase_on_mention = 1
interested_rate += interest_increase_on_mention
return interested_rate, is_mentioned
class S4UMessageProcessor:
"""心流处理器,负责处理接收到的消息并计算兴趣度"""
def __init__(self):
"""初始化心流处理器,创建消息存储实例"""
self.storage = MessageStorage()
async def process_message(self, message: MessageRecvS4U, skip_gift_debounce: bool = False) -> None:
"""处理接收到的原始消息数据
主要流程:
1. 消息解析与初始化
2. 消息缓冲处理
3. 过滤检查
4. 兴趣度计算
5. 关系处理
Args:
message_data: 原始消息字符串
"""
# 1. 消息解析与初始化
groupinfo = message.message_info.group_info
userinfo = message.message_info.user_info
message_info = message.message_info
chat = await get_chat_manager().get_or_create_stream(
platform=message_info.platform,
user_info=userinfo,
group_info=groupinfo,
)
if await self.handle_internal_message(message):
return
if await self.hadle_if_voice_done(message):
return
# 处理礼物消息,如果消息被暂存则停止当前处理流程
if not skip_gift_debounce and not await self.handle_if_gift(message):
return
await self.check_if_fake_gift(message)
# 处理屏幕消息
if await self.handle_screen_message(message):
return
await self.storage.store_message(message, chat)
s4u_chat = get_s4u_chat_manager().get_or_create_chat(chat)
await s4u_chat.add_message(message)
_interested_rate, _ = await _calculate_interest(message)
await mood_manager.start()
# 一系列llm驱动的前处理
chat_mood = mood_manager.get_mood_by_chat_id(chat.stream_id)
asyncio.create_task(chat_mood.update_mood_by_message(message))
chat_action = action_manager.get_action_state_by_chat_id(chat.stream_id)
asyncio.create_task(chat_action.update_action_by_message(message))
# 视线管理:收到消息时切换视线状态
chat_watching = watching_manager.get_watching_by_chat_id(chat.stream_id)
await chat_watching.on_message_received()
# 上下文网页管理启动独立task处理消息上下文
asyncio.create_task(self._handle_context_web_update(chat.stream_id, message))
# 日志记录
if message.is_gift:
logger.info(f"[S4U-礼物] {userinfo.user_nickname} 送出了 {message.gift_name} x{message.gift_count}")
else:
logger.info(f"[S4U]{userinfo.user_nickname}:{message.processed_plain_text}")
async def handle_internal_message(self, message: MessageRecvS4U):
if message.is_internal:
group_info = GroupInfo(platform = "amaidesu_default",group_id = 660154,group_name = "内心")
chat = await get_chat_manager().get_or_create_stream(
platform = "amaidesu_default",
user_info = message.message_info.user_info,
group_info = group_info
)
s4u_chat = get_s4u_chat_manager().get_or_create_chat(chat)
message.message_info.group_info = s4u_chat.chat_stream.group_info
message.message_info.platform = s4u_chat.chat_stream.platform
s4u_chat.internal_message.append(message)
s4u_chat._new_message_event.set()
logger.info(f"[{s4u_chat.stream_name}] 添加内部消息-------------------------------------------------------: {message.processed_plain_text}")
return True
return False
async def handle_screen_message(self, message: MessageRecvS4U):
if message.is_screen:
screen_manager.set_screen(message.screen_info)
return True
return False
async def hadle_if_voice_done(self, message: MessageRecvS4U):
if message.voice_done:
s4u_chat = get_s4u_chat_manager().get_or_create_chat(message.chat_stream)
s4u_chat.voice_done = message.voice_done
return True
return False
async def check_if_fake_gift(self, message: MessageRecvS4U) -> bool:
"""检查消息是否为假礼物"""
if message.is_gift:
return False
gift_keywords = ["送出了礼物", "礼物", "送出了","投喂"]
if any(keyword in message.processed_plain_text for keyword in gift_keywords):
message.is_fake_gift = True
return True
return False
async def handle_if_gift(self, message: MessageRecvS4U) -> bool:
"""处理礼物消息
Returns:
bool: True表示应该继续处理消息False表示消息已被暂存不需要继续处理
"""
if message.is_gift:
# 定义防抖完成后的回调函数
def gift_callback(merged_message: MessageRecvS4U):
"""礼物防抖完成后的回调"""
# 创建异步任务来处理合并后的礼物消息,跳过防抖处理
asyncio.create_task(self.process_message(merged_message, skip_gift_debounce=True))
# 交给礼物管理器处理,并传入回调函数
# 对于礼物消息handle_gift 总是返回 False消息被暂存
await gift_manager.handle_gift(message, gift_callback)
return False # 消息被暂存,不继续处理
return True # 非礼物消息,继续正常处理
async def _handle_context_web_update(self, chat_id: str, message: MessageRecv):
"""处理上下文网页更新的独立task
Args:
chat_id: 聊天ID
message: 消息对象
"""
try:
logger.debug(f"🔄 开始处理上下文网页更新: {message.message_info.user_info.user_nickname}")
context_manager = get_context_web_manager()
# 只在服务器未启动时启动(避免重复启动)
if context_manager.site is None:
logger.info("🚀 首次启动上下文网页服务器...")
await context_manager.start_server()
# 添加消息到上下文并更新网页
await asyncio.sleep(1.5)
await context_manager.add_message(chat_id, message)
logger.debug(f"✅ 上下文网页更新完成: {message.message_info.user_info.user_nickname}")
except Exception as e:
logger.error(f"❌ 处理上下文网页更新失败: {e}", exc_info=True)

View File

@@ -1,396 +0,0 @@
from src.config.config import global_config
from src.common.logger import get_logger
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat
import time
from src.chat.utils.utils import get_recent_group_speaker
from src.chat.memory_system.Hippocampus import hippocampus_manager
import random
from datetime import datetime
import asyncio
from src.mais4u.s4u_config import s4u_config
from src.chat.message_receive.message import MessageRecvS4U
from src.person_info.person_info import Person, get_person_id
from src.chat.message_receive.chat_stream import ChatStream
from src.mais4u.mais4u_chat.super_chat_manager import get_super_chat_manager
from src.mais4u.mais4u_chat.screen_manager import screen_manager
from src.chat.express.expression_selector import expression_selector
from .s4u_mood_manager import mood_manager
from src.mais4u.mais4u_chat.internal_manager import internal_manager
from src.common.data_models.database_data_model import DatabaseMessages
from typing import List
logger = get_logger("prompt")
def init_prompt():
Prompt("\n你有以下这些**知识**\n{prompt_info}\n请你**记住上面的知识**,之后可能会用到。\n", "knowledge_prompt")
Prompt("\n关于你们的关系,你需要知道:\n{relation_info}\n", "relation_prompt")
Prompt("你回想起了一些事情:\n{memory_info}\n", "memory_prompt")
Prompt(
"""
你的名字是麦麦, 是千石可乐开发的程序可以在QQ微信等平台发言你现在正在哔哩哔哩作为虚拟主播进行直播
虽然是个AI但你在努力扮演真人, 人设的配色是橙色和紫色和绿色
你有一头法式橙色卷发,你很可爱,穿戴英式侦探风格学院风裙子和帽子,你正在一个教室场景里进行虚拟线上直播。不过不要刻意提及这些内容,自然回复
你可以看见用户发送的弹幕礼物和superchat
{screen_info}
{internal_state}
{relation_info_block}
{memory_block}
{expression_habits_block}
你现在的主要任务是和 {sender_name} 发送的弹幕聊天。同时,也有其他用户会参与你们的聊天,你可以参考他们的回复内容,但是你主要还是关注你和{sender_name}的聊天内容。
{sc_info}
{background_dialogue_prompt}
--------------------------------
{time_block}
这是你和{sender_name}的对话,你们正在交流中:
{core_dialogue_prompt}
对方最新发送的内容:{message_txt}
{gift_info}
回复简短一些,平淡一些,可以参考贴吧,知乎和微博的回复风格,回复不要浮夸,不要用夸张修辞。
表现的有个性,不要随意服从他人要求,积极互动。你现在的心情是:{mood_state}
不要输出多余内容(包括前后缀,冒号和引号,括号()表情包at或 @等 )。只输出回复内容,现在{sender_name}正在等待你的回复。
你的回复风格不要浮夸,有逻辑和条理,请你继续回复{sender_name}
你的发言:
""",
"s4u_prompt", # New template for private CHAT chat
)
Prompt(
"""
你的名字是麦麦, 是千石可乐开发的程序可以在QQ微信等平台发言你现在正在哔哩哔哩作为虚拟主播进行直播
虽然是个AI但你在努力扮演真人, 人设的配色是橙色和紫色和绿色
你有一头法式橙色卷发,你很可爱,穿戴英式侦探风格学院风裙子和帽子,你正在一个教室场景里进行虚拟线上直播。不过不要刻意提及这些内容,自然回复
你可以看见用户发送的弹幕礼物和superchat
你可以看见面前的屏幕,目前屏幕的内容是:
{screen_info}
{memory_block}
{expression_habits_block}
{sc_info}
{time_block}
{chat_info_danmu}
--------------------------------
以上是你和弹幕的对话与此同时你在与QQ群友聊天聊天记录如下
{chat_info_qq}
--------------------------------
你刚刚回复了QQ群你内心的想法是{mind}
请根据你内心的想法,组织一条回复,在直播间进行发言,可以点名吐槽对象,让观众知道你在说谁
{gift_info}
回复简短一些,平淡一些,可以参考贴吧,知乎和微博的回复风格。不要浮夸,有逻辑和条理。
表现的有个性,不要随意服从他人要求,积极互动。你现在的心情是:{mood_state}
不要输出多余内容(包括前后缀,冒号和引号,括号()表情包at或 @等 )。
你的发言:
""",
"s4u_prompt_internal", # New template for private CHAT chat
)
class PromptBuilder:
def __init__(self):
self.prompt_built = ""
self.activate_messages = ""
async def build_expression_habits(self, chat_stream: ChatStream, chat_history, target):
style_habits = []
# 使用从处理器传来的选中表达方式
# LLM模式调用LLM选择5-10个然后随机选5个
selected_expressions, _ = await expression_selector.select_suitable_expressions_llm(
chat_stream.stream_id, chat_history, max_num=12, target_message=target
)
if selected_expressions:
logger.debug(f" 使用处理器选中的{len(selected_expressions)}个表达方式")
for expr in selected_expressions:
if isinstance(expr, dict) and "situation" in expr and "style" in expr:
style_habits.append(f"{expr['situation']}时,使用 {expr['style']}")
else:
logger.debug("没有从处理器获得表达方式,将使用空的表达方式")
# 不再在replyer中进行随机选择全部交给处理器处理
style_habits_str = "\n".join(style_habits)
# 动态构建expression habits块
expression_habits_block = ""
if style_habits_str.strip():
expression_habits_block += f"你可以参考以下的语言习惯,如果情景合适就使用,不要盲目使用,不要生硬使用,而是结合到表达中:\n{style_habits_str}\n\n"
return expression_habits_block
async def build_relation_info(self, chat_stream) -> str:
is_group_chat = bool(chat_stream.group_info)
who_chat_in_group = []
if is_group_chat:
who_chat_in_group = get_recent_group_speaker(
chat_stream.stream_id,
(chat_stream.user_info.platform, chat_stream.user_info.user_id) if chat_stream.user_info else None,
limit=global_config.chat.max_context_size,
)
elif chat_stream.user_info:
who_chat_in_group.append(
(chat_stream.user_info.platform, chat_stream.user_info.user_id, chat_stream.user_info.user_nickname)
)
relation_prompt = ""
if global_config.relationship.enable_relationship and who_chat_in_group:
# 将 (platform, user_id, nickname) 转换为 person_id
person_ids = []
for person in who_chat_in_group:
person_id = get_person_id(person[0], person[1])
person_ids.append(person_id)
# 使用 Person 的 build_relationship 方法,设置 points_num=3 保持与原来相同的行为
relation_info_list = [Person(person_id=person_id).build_relationship() for person_id in person_ids]
if relation_info := "".join(relation_info_list):
relation_prompt = await global_prompt_manager.format_prompt(
"relation_prompt", relation_info=relation_info
)
return relation_prompt
async def build_memory_block(self, text: str) -> str:
# 待更新记忆系统
return ""
related_memory = await hippocampus_manager.get_memory_from_text(
text=text, max_memory_num=2, max_memory_length=2, max_depth=3, fast_retrieval=False
)
related_memory_info = ""
if related_memory:
for memory in related_memory:
related_memory_info += memory[1]
return await global_prompt_manager.format_prompt("memory_prompt", memory_info=related_memory_info)
return ""
def build_chat_history_prompts(self, chat_stream: ChatStream, message: MessageRecvS4U):
message_list_before_now = get_raw_msg_before_timestamp_with_chat(
chat_id=chat_stream.stream_id,
timestamp=time.time(),
# sourcery skip: lift-duplicated-conditional, merge-duplicate-blocks, remove-redundant-if
limit=300,
)
talk_type = f"{message.message_info.platform}:{str(message.chat_stream.user_info.user_id)}"
core_dialogue_list: List[DatabaseMessages] = []
background_dialogue_list: List[DatabaseMessages] = []
bot_id = str(global_config.bot.qq_account)
target_user_id = str(message.chat_stream.user_info.user_id)
for msg in message_list_before_now:
try:
msg_user_id = str(msg.user_info.user_id)
if msg_user_id == bot_id:
if msg.reply_to and talk_type == msg.reply_to:
core_dialogue_list.append(msg)
elif msg.reply_to and talk_type != msg.reply_to:
background_dialogue_list.append(msg)
# else:
# background_dialogue_list.append(msg_dict)
elif msg_user_id == target_user_id:
core_dialogue_list.append(msg)
else:
background_dialogue_list.append(msg)
except Exception as e:
logger.error(f"无法处理历史消息记录: {msg.__dict__}, 错误: {e}")
background_dialogue_prompt = ""
if background_dialogue_list:
context_msgs = background_dialogue_list[-s4u_config.max_context_message_length :]
background_dialogue_prompt_str = build_readable_messages(
context_msgs,
timestamp_mode="normal_no_YMD",
show_pic=False,
)
background_dialogue_prompt = f"这是其他用户的发言:\n{background_dialogue_prompt_str}"
core_msg_str = ""
if core_dialogue_list:
core_dialogue_list = core_dialogue_list[-s4u_config.max_core_message_length :]
first_msg = core_dialogue_list[0]
start_speaking_user_id = first_msg.user_info.user_id
if start_speaking_user_id == bot_id:
last_speaking_user_id = bot_id
msg_seg_str = "你的发言:\n"
else:
start_speaking_user_id = target_user_id
last_speaking_user_id = start_speaking_user_id
msg_seg_str = "对方的发言:\n"
msg_seg_str += f"{time.strftime('%H:%M:%S', time.localtime(first_msg.time))}: {first_msg.processed_plain_text}\n"
all_msg_seg_list = []
for msg in core_dialogue_list[1:]:
speaker = msg.user_info.user_id
if speaker == last_speaking_user_id:
msg_seg_str += f"{time.strftime('%H:%M:%S', time.localtime(msg.time))}: {msg.processed_plain_text}\n"
else:
msg_seg_str = f"{msg_seg_str}\n"
all_msg_seg_list.append(msg_seg_str)
if speaker == bot_id:
msg_seg_str = "你的发言:\n"
else:
msg_seg_str = "对方的发言:\n"
msg_seg_str += f"{time.strftime('%H:%M:%S', time.localtime(msg.get('time')))}: {msg.get('processed_plain_text')}\n"
last_speaking_user_id = speaker
all_msg_seg_list.append(msg_seg_str)
for msg in all_msg_seg_list:
core_msg_str += msg
all_dialogue_history = get_raw_msg_before_timestamp_with_chat(
chat_id=chat_stream.stream_id,
timestamp=time.time(),
limit=20,
)
all_dialogue_prompt_str = build_readable_messages(
all_dialogue_history,
timestamp_mode="normal_no_YMD",
show_pic=False,
)
return core_msg_str, background_dialogue_prompt, all_dialogue_prompt_str
def build_gift_info(self, message: MessageRecvS4U):
if message.is_gift:
return f"这是一条礼物信息,{message.gift_name} x{message.gift_count},请注意这位用户"
else:
if message.is_fake_gift:
return f"{message.processed_plain_text}(注意:这是一条普通弹幕信息,对方没有真的发送礼物,不是礼物信息,注意区分,如果对方在发假的礼物骗你,请反击)"
return ""
def build_sc_info(self, message: MessageRecvS4U):
super_chat_manager = get_super_chat_manager()
return super_chat_manager.build_superchat_summary_string(message.chat_stream.stream_id)
async def build_prompt_normal(
self,
message: MessageRecvS4U,
message_txt: str,
) -> str:
chat_stream = message.chat_stream
person = Person(platform=message.chat_stream.user_info.platform, user_id=message.chat_stream.user_info.user_id)
person_name = person.person_name
if message.chat_stream.user_info.user_nickname:
if person_name:
sender_name = f"[{message.chat_stream.user_info.user_nickname}]你叫ta{person_name}"
else:
sender_name = f"[{message.chat_stream.user_info.user_nickname}]"
else:
sender_name = f"用户({message.chat_stream.user_info.user_id})"
relation_info_block, memory_block, expression_habits_block = await asyncio.gather(
self.build_relation_info(chat_stream),
self.build_memory_block(message_txt),
self.build_expression_habits(chat_stream, message_txt, sender_name),
)
core_dialogue_prompt, background_dialogue_prompt, all_dialogue_prompt = self.build_chat_history_prompts(
chat_stream, message
)
gift_info = self.build_gift_info(message)
sc_info = self.build_sc_info(message)
screen_info = screen_manager.get_screen_str()
internal_state = internal_manager.get_internal_state_str()
time_block = f"当前时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
mood = mood_manager.get_mood_by_chat_id(chat_stream.stream_id)
template_name = "s4u_prompt"
if not message.is_internal:
prompt = await global_prompt_manager.format_prompt(
template_name,
time_block=time_block,
expression_habits_block=expression_habits_block,
relation_info_block=relation_info_block,
memory_block=memory_block,
screen_info=screen_info,
internal_state=internal_state,
gift_info=gift_info,
sc_info=sc_info,
sender_name=sender_name,
core_dialogue_prompt=core_dialogue_prompt,
background_dialogue_prompt=background_dialogue_prompt,
message_txt=message_txt,
mood_state=mood.mood_state,
)
else:
prompt = await global_prompt_manager.format_prompt(
"s4u_prompt_internal",
time_block=time_block,
expression_habits_block=expression_habits_block,
relation_info_block=relation_info_block,
memory_block=memory_block,
screen_info=screen_info,
gift_info=gift_info,
sc_info=sc_info,
chat_info_danmu=all_dialogue_prompt,
chat_info_qq=message.chat_info,
mind=message.processed_plain_text,
mood_state=mood.mood_state,
)
# print(prompt)
return prompt
def weighted_sample_no_replacement(items, weights, k) -> list:
"""
加权且不放回地随机抽取k个元素。
参数:
items: 待抽取的元素列表
weights: 每个元素对应的权重与items等长且为正数
k: 需要抽取的元素个数
返回:
selected: 按权重加权且不重复抽取的k个元素组成的列表
如果items中的元素不足k就只会返回所有可用的元素
实现思路:
每次从当前池中按权重加权随机选出一个元素选中后将其从池中移除重复k次。
这样保证了:
1. count越大被选中概率越高
2. 不会重复选中同一个元素
"""
selected = []
pool = list(zip(items, weights, strict=False))
for _ in range(min(k, len(pool))):
total = sum(w for _, w in pool)
r = random.uniform(0, total)
upto = 0
for idx, (item, weight) in enumerate(pool):
upto += weight
if upto >= r:
selected.append(item)
pool.pop(idx)
break
return selected
init_prompt()
prompt_builder = PromptBuilder()

Some files were not shown because too many files have changed in this diff Show More