Merge branch 'MaiM-with-u:dev' into dev
This commit is contained in:
3
.gitattributes
vendored
3
.gitattributes
vendored
@@ -1,3 +1,2 @@
|
||||
*.bat text eol=crlf
|
||||
*.cmd text eol=crlf
|
||||
MaiLauncher.bat text eol=crlf working-tree-encoding=GBK
|
||||
*.cmd text eol=crlf
|
||||
12
.gitignore
vendored
12
.gitignore
vendored
@@ -20,6 +20,8 @@ MaiBot-Napcat-Adapter
|
||||
nonebot-maibot-adapter/
|
||||
MaiMBot-LPMM
|
||||
*.zip
|
||||
run_bot.bat
|
||||
run_na.bat
|
||||
run.bat
|
||||
log_debug/
|
||||
run_amds.bat
|
||||
@@ -41,16 +43,13 @@ config/bot_config.toml
|
||||
config/bot_config.toml.bak
|
||||
config/lpmm_config.toml
|
||||
config/lpmm_config.toml.bak
|
||||
src/mais4u/config/s4u_config.toml
|
||||
src/mais4u/config/old
|
||||
template/compare/bot_config_template.toml
|
||||
template/compare/model_config_template.toml
|
||||
(测试版)麦麦生成人格.bat
|
||||
(临时版)麦麦开始学习.bat
|
||||
src/plugins/utils/statistic.py
|
||||
CLAUDE.md
|
||||
s4u.s4u
|
||||
s4u.s4u1
|
||||
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
@@ -321,9 +320,14 @@ run_pet.bat
|
||||
/plugins/*
|
||||
!/plugins
|
||||
!/plugins/hello_world_plugin
|
||||
!/plugins/emoji_manage_plugin
|
||||
!/plugins/take_picture_plugin
|
||||
!/plugins/deep_think
|
||||
!/plugins/MaiFrequencyControl
|
||||
!/plugins/__init__.py
|
||||
|
||||
config.toml
|
||||
|
||||
interested_rates.txt
|
||||
MaiBot.code-workspace
|
||||
*.lock
|
||||
18
README.md
18
README.md
@@ -26,12 +26,10 @@
|
||||
**🍔MaiCore 是一个基于大语言模型的可交互智能体**
|
||||
|
||||
- 💭 **智能对话系统**:基于 LLM 的自然语言交互,聊天时机控制。
|
||||
- 🔌 **强大插件系统**:全面重构的插件架构,更多API。
|
||||
- 🤔 **实时思维系统**:模拟人类思考过程。
|
||||
- 🧠 **表达学习功能**:学习群友的说话风格和表达方式
|
||||
- 💝 **情感表达系统**:情绪系统和表情包系统。
|
||||
- 🧠 **持久记忆系统**:基于图的长期记忆存储。
|
||||
- 🔄 **动态人格系统**:自适应的性格特征和表达方式。
|
||||
- 🔌 **强大插件系统**:提供API和事件系统,可编写强大插件。
|
||||
|
||||
<div style="text-align: center">
|
||||
<a href="https://www.bilibili.com/video/BV1amAneGE3P" target="_blank">
|
||||
@@ -46,7 +44,7 @@
|
||||
|
||||
## 🔥 更新和安装
|
||||
|
||||
**最新版本: v0.10.1** ([更新日志](changelogs/changelog.md))
|
||||
**最新版本: v0.10.3** ([更新日志](changelogs/changelog.md))
|
||||
|
||||
可前往 [Release](https://github.com/MaiM-with-u/MaiBot/releases/) 页面下载最新版本
|
||||
可前往 [启动器发布页面](https://github.com/MaiM-with-u/mailauncher/releases/)下载最新启动器
|
||||
@@ -64,7 +62,7 @@
|
||||
> - QQ 机器人存在被限制风险,请自行了解,谨慎使用。
|
||||
> - 由于程序处于开发中,可能消耗较多 token。
|
||||
|
||||
## 麦麦MC项目(早期开发)
|
||||
## 麦麦MC项目MaiCraft(早期开发)
|
||||
[让麦麦玩MC](https://github.com/MaiM-with-u/Maicraft)
|
||||
|
||||
交流群:1058573197
|
||||
@@ -72,13 +70,13 @@
|
||||
## 💬 讨论
|
||||
|
||||
**技术交流群:**
|
||||
- [一群](https://qm.qq.com/q/VQ3XZrWgMs) |
|
||||
[二群](https://qm.qq.com/q/RzmCiRtHEW) |
|
||||
[三群](https://qm.qq.com/q/wlH5eT8OmQ) |
|
||||
[四群](https://qm.qq.com/q/wGePTl1UyY)
|
||||
[麦麦脑电图](https://qm.qq.com/q/RzmCiRtHEW) |
|
||||
[麦麦脑磁图](https://qm.qq.com/q/wlH5eT8OmQ) |
|
||||
[麦麦大脑磁共振](https://qm.qq.com/q/VQ3XZrWgMs) |
|
||||
[麦麦要当VTB](https://qm.qq.com/q/wGePTl1UyY)
|
||||
|
||||
**聊天吹水群:**
|
||||
- [五群](https://qm.qq.com/q/JxvHZnxyec)
|
||||
- [麦麦之闲聊群](https://qm.qq.com/q/JxvHZnxyec)
|
||||
|
||||
**插件开发测试版群:**
|
||||
- [插件开发群](https://qm.qq.com/q/1036092828)
|
||||
|
||||
24
bot.py
24
bot.py
@@ -5,16 +5,29 @@ import sys
|
||||
import time
|
||||
import platform
|
||||
import traceback
|
||||
import shutil
|
||||
from dotenv import load_dotenv
|
||||
from pathlib import Path
|
||||
from rich.traceback import install
|
||||
|
||||
if os.path.exists(".env"):
|
||||
load_dotenv(".env", override=True)
|
||||
env_path = Path(__file__).parent / ".env"
|
||||
template_env_path = Path(__file__).parent / "template" / "template.env"
|
||||
|
||||
if env_path.exists():
|
||||
load_dotenv(str(env_path), override=True)
|
||||
print("成功加载环境变量配置")
|
||||
else:
|
||||
print("未找到.env文件,请确保程序所需的环境变量被正确设置")
|
||||
raise FileNotFoundError(".env 文件不存在,请创建并配置所需的环境变量")
|
||||
try:
|
||||
if template_env_path.exists():
|
||||
shutil.copyfile(template_env_path, env_path)
|
||||
print("未找到.env,已从 template/template.env 自动创建")
|
||||
load_dotenv(str(env_path), override=True)
|
||||
else:
|
||||
print("未找到.env文件,也未找到模板 template/template.env")
|
||||
raise FileNotFoundError(".env 文件不存在,请创建并配置所需的环境变量")
|
||||
except Exception as e:
|
||||
print(f"自动创建 .env 失败: {e}")
|
||||
raise
|
||||
|
||||
# 最早期初始化日志系统,确保所有后续模块都使用正确的日志格式
|
||||
from src.common.logger import initialize_logging, get_logger, shutdown_logging
|
||||
@@ -62,9 +75,10 @@ def easter_egg():
|
||||
async def graceful_shutdown(): # sourcery skip: use-named-expression
|
||||
try:
|
||||
logger.info("正在优雅关闭麦麦...")
|
||||
|
||||
|
||||
from src.plugin_system.core.events_manager import events_manager
|
||||
from src.plugin_system.base.component_types import EventType
|
||||
|
||||
# 触发 ON_STOP 事件
|
||||
await events_manager.handle_mai_events(event_type=EventType.ON_STOP)
|
||||
|
||||
|
||||
@@ -1,5 +1,41 @@
|
||||
# Changelog
|
||||
|
||||
## [0.11.0] - 2025-9-22
|
||||
### 🌟 主要功能更改
|
||||
- 重构记忆系统,新的记忆系统更可靠,记忆能力更强大
|
||||
- 麦麦好奇功能,麦麦会自主提出问题
|
||||
- 添加deepthink插件(默认关闭),让麦麦可以深度思考一些问题
|
||||
- 添加表情包管理插件
|
||||
|
||||
### 细节功能更改
|
||||
- 修复配置文件转义问题
|
||||
- 情绪系统现在可以由配置文件控制开关
|
||||
- 修复平行动作控制失效的问题
|
||||
- 添加planner防抖,防止短时间快速消耗token
|
||||
- 修复吞字问题
|
||||
- 更新依赖表
|
||||
- 修复负载均衡
|
||||
- 优化了对gemini和不同模型的支持
|
||||
|
||||
## [0.10.3] - 2025-9-22
|
||||
### 🌟 主要功能更改
|
||||
- planner支持多动作,移除Sub_planner
|
||||
- 移除激活度系统,现在回复完全由planner控制
|
||||
- 现可自定义planner行为,更优化的聊天频率控制
|
||||
- 支持发送转发和合并转发
|
||||
- 关系现在支持多人的信息
|
||||
- 更好的event系统,正式建立
|
||||
|
||||
### 细节功能更改
|
||||
- 支持所有表达方式互通
|
||||
- 现可使用付费嵌入模型
|
||||
- 添加多种发送类型
|
||||
- 优化识图token限制
|
||||
- 为空回复添加重试机制
|
||||
- 加入brainchat模式,为私聊支持做准备
|
||||
- 修复qq号格式
|
||||
|
||||
|
||||
|
||||
## [0.10.2] - 2025-8-31
|
||||
### 🌟 主要功能更改
|
||||
|
||||
@@ -1,51 +0,0 @@
|
||||
# Changelog
|
||||
|
||||
## [1.0.3] - 2025-3-31
|
||||
### Added
|
||||
- 新增了心流相关配置项:
|
||||
- `heartflow` 配置项,用于控制心流功能
|
||||
|
||||
### Removed
|
||||
- 移除了 `response` 配置项中的 `model_r1_probability` 和 `model_v3_probability` 选项
|
||||
- 移除了次级推理模型相关配置
|
||||
|
||||
## [1.0.1] - 2025-3-30
|
||||
### Added
|
||||
- 增加了流式输出控制项 `stream`
|
||||
- 修复 `LLM_Request` 不会自动为 `payload` 增加流式输出标志的问题
|
||||
|
||||
## [1.0.0] - 2025-3-30
|
||||
### Added
|
||||
- 修复了错误的版本命名
|
||||
- 杀掉了所有无关文件
|
||||
|
||||
## [0.0.11] - 2025-3-12
|
||||
### Added
|
||||
- 新增了 `schedule` 配置项,用于配置日程表生成功能
|
||||
- 新增了 `response_splitter` 配置项,用于控制回复分割
|
||||
- 新增了 `experimental` 配置项,用于实验性功能开关
|
||||
- 新增了 `llm_observation` 和 `llm_sub_heartflow` 模型配置
|
||||
- 新增了 `llm_heartflow` 模型配置
|
||||
- 在 `personality` 配置项中新增了 `prompt_schedule_gen` 参数
|
||||
|
||||
### Changed
|
||||
- 优化了模型配置的组织结构
|
||||
- 调整了部分配置项的默认值
|
||||
- 调整了配置项的顺序,将 `groups` 配置项移到了更靠前的位置
|
||||
- 在 `message` 配置项中:
|
||||
- 新增了 `model_max_output_length` 参数
|
||||
- 在 `willing` 配置项中新增了 `emoji_response_penalty` 参数
|
||||
- 将 `personality` 配置项中的 `prompt_schedule` 重命名为 `prompt_schedule_gen`
|
||||
|
||||
### Removed
|
||||
- 移除了 `min_text_length` 配置项
|
||||
- 移除了 `cq_code` 配置项
|
||||
- 移除了 `others` 配置项(其功能已整合到 `experimental` 中)
|
||||
|
||||
## [0.0.5] - 2025-3-11
|
||||
### Added
|
||||
- 新增了 `alias_names` 配置项,用于指定麦麦的别名。
|
||||
|
||||
## [0.0.4] - 2025-3-9
|
||||
### Added
|
||||
- 新增了 `memory_ban_words` 配置项,用于指定不希望记忆的词汇。
|
||||
@@ -28,7 +28,7 @@ version = "1.1.1"
|
||||
```toml
|
||||
[[api_providers]]
|
||||
name = "DeepSeek" # 服务商名称(自定义)
|
||||
base_url = "https://api.deepseek.cn/v1" # API服务的基础URL
|
||||
base_url = "https://api.deepseek.com/v1" # API服务的基础URL
|
||||
api_key = "your-api-key-here" # API密钥
|
||||
client_type = "openai" # 客户端类型
|
||||
max_retry = 2 # 最大重试次数
|
||||
@@ -43,19 +43,19 @@ retry_interval = 10 # 重试间隔(秒)
|
||||
| `name` | ✅ | 服务商名称,需要在模型配置中引用 | - |
|
||||
| `base_url` | ✅ | API服务的基础URL | - |
|
||||
| `api_key` | ✅ | API密钥,请替换为实际密钥 | - |
|
||||
| `client_type` | ❌ | 客户端类型:`openai`(OpenAI格式)或 `gemini`(Gemini格式,现在支持不良好) | `openai` |
|
||||
| `client_type` | ❌ | 客户端类型:`openai`(OpenAI格式)或 `gemini`(Gemini格式) | `openai` |
|
||||
| `max_retry` | ❌ | API调用失败时的最大重试次数 | 2 |
|
||||
| `timeout` | ❌ | API请求超时时间(秒) | 30 |
|
||||
| `retry_interval` | ❌ | 重试间隔时间(秒) | 10 |
|
||||
|
||||
**请注意,对于`client_type`为`gemini`的模型,`base_url`字段无效。**
|
||||
**请注意,对于`client_type`为`gemini`的模型,`retry`字段由`gemini`自己决定。**
|
||||
### 2.3 支持的服务商示例
|
||||
|
||||
#### DeepSeek
|
||||
```toml
|
||||
[[api_providers]]
|
||||
name = "DeepSeek"
|
||||
base_url = "https://api.deepseek.cn/v1"
|
||||
base_url = "https://api.deepseek.com/v1"
|
||||
api_key = "your-deepseek-api-key"
|
||||
client_type = "openai"
|
||||
```
|
||||
@@ -73,7 +73,7 @@ client_type = "openai"
|
||||
```toml
|
||||
[[api_providers]]
|
||||
name = "Google"
|
||||
base_url = "https://api.google.com/v1"
|
||||
base_url = "https://generativelanguage.googleapis.com/v1beta"
|
||||
api_key = "your-google-api-key"
|
||||
client_type = "gemini" # 注意:Gemini需要使用特殊客户端
|
||||
```
|
||||
@@ -131,9 +131,20 @@ enable_thinking = false # 禁用思考
|
||||
[models.extra_params]
|
||||
thinking = {type = "disabled"} # 禁用思考
|
||||
```
|
||||
|
||||
而对于`gemini`需要单独进行配置
|
||||
```toml
|
||||
[[models]]
|
||||
model_identifier = "gemini-2.5-flash"
|
||||
name = "gemini-2.5-flash"
|
||||
api_provider = "Google"
|
||||
[models.extra_params]
|
||||
thinking_budget = 0 # 禁用思考
|
||||
# thinking_budget = -1 由模型自己决定
|
||||
```
|
||||
|
||||
请注意,`extra_params` 的配置应该构成一个合法的TOML字典结构,具体内容取决于API服务商的要求。
|
||||
|
||||
**请注意,对于`client_type`为`gemini`的模型,此字段无效。**
|
||||
### 3.3 配置参数说明
|
||||
|
||||
| 参数 | 必填 | 说明 |
|
||||
|
||||
57
flake.lock
generated
57
flake.lock
generated
@@ -1,57 +0,0 @@
|
||||
{
|
||||
"nodes": {
|
||||
"nixpkgs": {
|
||||
"locked": {
|
||||
"lastModified": 0,
|
||||
"narHash": "sha256-nJj8f78AYAxl/zqLiFGXn5Im1qjFKU8yBPKoWEeZN5M=",
|
||||
"path": "/nix/store/f30jn7l0bf7a01qj029fq55i466vmnkh-source",
|
||||
"type": "path"
|
||||
},
|
||||
"original": {
|
||||
"id": "nixpkgs",
|
||||
"type": "indirect"
|
||||
}
|
||||
},
|
||||
"root": {
|
||||
"inputs": {
|
||||
"nixpkgs": "nixpkgs",
|
||||
"utils": "utils"
|
||||
}
|
||||
},
|
||||
"systems": {
|
||||
"locked": {
|
||||
"lastModified": 1681028828,
|
||||
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
|
||||
"owner": "nix-systems",
|
||||
"repo": "default",
|
||||
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "nix-systems",
|
||||
"repo": "default",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"utils": {
|
||||
"inputs": {
|
||||
"systems": "systems"
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1731533236,
|
||||
"narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
|
||||
"owner": "numtide",
|
||||
"repo": "flake-utils",
|
||||
"rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "numtide",
|
||||
"repo": "flake-utils",
|
||||
"type": "github"
|
||||
}
|
||||
}
|
||||
},
|
||||
"root": "root",
|
||||
"version": 7
|
||||
}
|
||||
39
flake.nix
39
flake.nix
@@ -1,39 +0,0 @@
|
||||
{
|
||||
description = "MaiMBot Nix Dev Env";
|
||||
|
||||
inputs = {
|
||||
utils.url = "github:numtide/flake-utils";
|
||||
};
|
||||
|
||||
outputs = {
|
||||
self,
|
||||
nixpkgs,
|
||||
utils,
|
||||
...
|
||||
}:
|
||||
utils.lib.eachDefaultSystem (system: let
|
||||
pkgs = import nixpkgs {inherit system;};
|
||||
pythonPackages = pkgs.python3Packages;
|
||||
in {
|
||||
devShells.default = pkgs.mkShell {
|
||||
name = "python-venv";
|
||||
venvDir = "./.venv";
|
||||
buildInputs = with pythonPackages; [
|
||||
python
|
||||
venvShellHook
|
||||
scipy
|
||||
numpy
|
||||
];
|
||||
|
||||
postVenvCreation = ''
|
||||
unset SOURCE_DATE_EPOCH
|
||||
pip install -r requirements.txt
|
||||
'';
|
||||
|
||||
postShellHook = ''
|
||||
# allow pip to install wheels
|
||||
unset SOURCE_DATE_EPOCH
|
||||
'';
|
||||
};
|
||||
});
|
||||
}
|
||||
@@ -828,7 +828,7 @@ class LogViewer:
|
||||
parts, tags = self.formatter.format_log_entry(log_entry)
|
||||
line_text = " ".join(parts)
|
||||
log_lines.append(line_text)
|
||||
|
||||
|
||||
with open(filename, "w", encoding="utf-8") as f:
|
||||
f.write("\n".join(log_lines))
|
||||
messagebox.showinfo("导出成功", f"日志已导出到: {filename}")
|
||||
@@ -1188,15 +1188,16 @@ class LogViewer:
|
||||
line_count += 1
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
|
||||
# 如果发现了新模块,在主线程中更新模块集合
|
||||
if new_modules:
|
||||
|
||||
def update_modules():
|
||||
self.modules.update(new_modules)
|
||||
self.update_module_list()
|
||||
|
||||
|
||||
self.root.after(0, update_modules)
|
||||
|
||||
|
||||
return new_entries
|
||||
|
||||
def append_new_logs(self, new_entries):
|
||||
@@ -1424,4 +1425,3 @@ def main():
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
34
plugins/deep_think/_manifest.json
Normal file
34
plugins/deep_think/_manifest.json
Normal file
@@ -0,0 +1,34 @@
|
||||
{
|
||||
"manifest_version": 1,
|
||||
"name": "Deep Think插件 (Deep Think Actions)",
|
||||
"version": "1.0.0",
|
||||
"description": "可以深度思考",
|
||||
"author": {
|
||||
"name": "SengokuCola",
|
||||
"url": "https://github.com/MaiM-with-u"
|
||||
},
|
||||
"license": "GPL-v3.0-or-later",
|
||||
|
||||
"host_application": {
|
||||
"min_version": "0.11.0"
|
||||
},
|
||||
"homepage_url": "https://github.com/MaiM-with-u/maibot",
|
||||
"repository_url": "https://github.com/MaiM-with-u/maibot",
|
||||
"keywords": ["deep", "think", "action", "built-in"],
|
||||
"categories": ["Deep Think"],
|
||||
|
||||
"default_locale": "zh-CN",
|
||||
"locales_path": "_locales",
|
||||
|
||||
"plugin_info": {
|
||||
"is_built_in": true,
|
||||
"plugin_type": "action_provider",
|
||||
"components": [
|
||||
{
|
||||
"type": "action",
|
||||
"name": "deep_think",
|
||||
"description": "发送深度思考"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
102
plugins/deep_think/plugin.py
Normal file
102
plugins/deep_think/plugin.py
Normal file
@@ -0,0 +1,102 @@
|
||||
from typing import List, Tuple, Type, Any
|
||||
|
||||
# 导入新插件系统
|
||||
from src.plugin_system import BasePlugin, register_plugin, ComponentInfo
|
||||
from src.plugin_system.base.config_types import ConfigField
|
||||
from src.person_info.person_info import Person
|
||||
from src.plugin_system.base.base_tool import BaseTool, ToolParamType
|
||||
|
||||
# 导入依赖的系统组件
|
||||
from src.common.logger import get_logger
|
||||
|
||||
from src.plugins.built_in.relation.relation import BuildRelationAction
|
||||
from src.plugin_system.apis import llm_api
|
||||
|
||||
logger = get_logger("relation_actions")
|
||||
|
||||
|
||||
|
||||
class DeepThinkTool(BaseTool):
|
||||
"""获取用户信息"""
|
||||
|
||||
name = "deep_think"
|
||||
description = "深度思考,对某个知识,概念或逻辑问题进行全面且深入的思考,当面临复杂环境或重要问题时,使用此获得更好的解决方案。"
|
||||
parameters = [
|
||||
("question", ToolParamType.STRING, "需要思考的问题,越具体越好(从上下文中总结)", True, None),
|
||||
]
|
||||
|
||||
available_for_llm = True
|
||||
|
||||
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""执行比较两个数的大小
|
||||
|
||||
Args:
|
||||
function_args: 工具参数
|
||||
|
||||
Returns:
|
||||
dict: 工具执行结果
|
||||
"""
|
||||
question: str = function_args.get("question") # type: ignore
|
||||
|
||||
print(f"question: {question}")
|
||||
|
||||
prompt = f"""
|
||||
请你思考以下问题,以简洁的一段话回答:
|
||||
{question}
|
||||
"""
|
||||
|
||||
models = llm_api.get_available_models()
|
||||
chat_model_config = models.get("replyer") # 使用字典访问方式
|
||||
|
||||
success, thinking_result, _, _ = await llm_api.generate_with_model(
|
||||
prompt, model_config=chat_model_config, request_type="deep_think"
|
||||
)
|
||||
|
||||
logger.info(f"{question}: {thinking_result}")
|
||||
|
||||
thinking_result =f"思考结果:{thinking_result}\n**注意** 因为你进行了深度思考,最后的回复内容可以回复的长一些,更加详细一些,不用太简洁。\n"
|
||||
|
||||
return {"content": thinking_result}
|
||||
|
||||
|
||||
@register_plugin
|
||||
class DeepThinkPlugin(BasePlugin):
|
||||
"""关系动作插件
|
||||
|
||||
系统内置插件,提供基础的聊天交互功能:
|
||||
- Reply: 回复动作
|
||||
- NoReply: 不回复动作
|
||||
- Emoji: 表情动作
|
||||
|
||||
注意:插件基本信息优先从_manifest.json文件中读取
|
||||
"""
|
||||
|
||||
# 插件基本信息
|
||||
plugin_name: str = "deep_think" # 内部标识符
|
||||
enable_plugin: bool = True
|
||||
dependencies: list[str] = [] # 插件依赖列表
|
||||
python_dependencies: list[str] = [] # Python包依赖列表
|
||||
config_file_name: str = "config.toml"
|
||||
|
||||
# 配置节描述
|
||||
config_section_descriptions = {
|
||||
"plugin": "插件启用配置",
|
||||
"components": "核心组件启用配置",
|
||||
}
|
||||
|
||||
# 配置Schema定义
|
||||
config_schema: dict = {
|
||||
"plugin": {
|
||||
"enabled": ConfigField(type=bool, default=False, description="是否启用插件"),
|
||||
"config_version": ConfigField(type=str, default="2.0.0", description="配置文件版本"),
|
||||
}
|
||||
}
|
||||
|
||||
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
|
||||
"""返回插件包含的组件列表"""
|
||||
|
||||
# --- 根据配置注册组件 ---
|
||||
components = []
|
||||
components.append((DeepThinkTool.get_tool_info(), DeepThinkTool))
|
||||
|
||||
return components
|
||||
53
plugins/emoji_manage_plugin/_manifest.json
Normal file
53
plugins/emoji_manage_plugin/_manifest.json
Normal file
@@ -0,0 +1,53 @@
|
||||
{
|
||||
"manifest_version": 1,
|
||||
"name": "BetterEmoji",
|
||||
"version": "1.0.0",
|
||||
"description": "更好的表情包管理插件",
|
||||
"author": {
|
||||
"name": "SengokuCola",
|
||||
"url": "https://github.com/SengokuCola"
|
||||
},
|
||||
"license": "GPL-v3.0-or-later",
|
||||
|
||||
"host_application": {
|
||||
"min_version": "0.10.4"
|
||||
},
|
||||
"homepage_url": "https://github.com/SengokuCola/BetterEmoji",
|
||||
"repository_url": "https://github.com/SengokuCola/BetterEmoji",
|
||||
"keywords": ["emoji", "manage", "plugin"],
|
||||
"categories": ["Examples", "Tutorial"],
|
||||
|
||||
"default_locale": "zh-CN",
|
||||
"locales_path": "_locales",
|
||||
|
||||
"plugin_info": {
|
||||
"is_built_in": false,
|
||||
"plugin_type": "emoji_manage",
|
||||
"components": [
|
||||
{
|
||||
"type": "action",
|
||||
"name": "hello_greeting",
|
||||
"description": "向用户发送问候消息"
|
||||
},
|
||||
{
|
||||
"type": "action",
|
||||
"name": "bye_greeting",
|
||||
"description": "向用户发送告别消息",
|
||||
"activation_modes": ["keyword"],
|
||||
"keywords": ["再见", "bye", "88", "拜拜"]
|
||||
},
|
||||
{
|
||||
"type": "command",
|
||||
"name": "time",
|
||||
"description": "查询当前时间",
|
||||
"pattern": "/time"
|
||||
}
|
||||
],
|
||||
"features": [
|
||||
"问候和告别功能",
|
||||
"时间查询命令",
|
||||
"配置文件示例",
|
||||
"新手教程代码"
|
||||
]
|
||||
}
|
||||
}
|
||||
399
plugins/emoji_manage_plugin/plugin.py
Normal file
399
plugins/emoji_manage_plugin/plugin.py
Normal file
@@ -0,0 +1,399 @@
|
||||
from typing import List, Tuple, Type
|
||||
from src.plugin_system import (
|
||||
BasePlugin,
|
||||
register_plugin,
|
||||
BaseCommand,
|
||||
ComponentInfo,
|
||||
ConfigField,
|
||||
ReplyContentType,
|
||||
emoji_api,
|
||||
)
|
||||
from maim_message import Seg
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("emoji_manage_plugin")
|
||||
|
||||
|
||||
class AddEmojiCommand(BaseCommand):
|
||||
command_name = "add_emoji"
|
||||
command_description = "添加表情包"
|
||||
command_pattern = r".*/emoji add.*"
|
||||
|
||||
async def execute(self) -> Tuple[bool, str, bool]:
|
||||
# 查找消息中的表情包
|
||||
# logger.info(f"查找消息中的表情包: {self.message.message_segment}")
|
||||
|
||||
emoji_base64_list = self.find_and_return_emoji_in_message(self.message.message_segment)
|
||||
|
||||
if not emoji_base64_list:
|
||||
return False, "未在消息中找到表情包或图片", False
|
||||
|
||||
# 注册找到的表情包
|
||||
success_count = 0
|
||||
fail_count = 0
|
||||
results = []
|
||||
|
||||
for i, emoji_base64 in enumerate(emoji_base64_list):
|
||||
try:
|
||||
# 使用emoji_api注册表情包(让API自动生成唯一文件名)
|
||||
result = await emoji_api.register_emoji(emoji_base64)
|
||||
|
||||
if result["success"]:
|
||||
success_count += 1
|
||||
description = result.get("description", "未知描述")
|
||||
emotions = result.get("emotions", [])
|
||||
replaced = result.get("replaced", False)
|
||||
|
||||
result_msg = f"表情包 {i + 1} 注册成功{'(替换旧表情包)' if replaced else '(新增表情包)'}"
|
||||
if description:
|
||||
result_msg += f"\n描述: {description}"
|
||||
if emotions:
|
||||
result_msg += f"\n情感标签: {', '.join(emotions)}"
|
||||
|
||||
results.append(result_msg)
|
||||
else:
|
||||
fail_count += 1
|
||||
error_msg = result.get("message", "注册失败")
|
||||
results.append(f"表情包 {i + 1} 注册失败: {error_msg}")
|
||||
|
||||
except Exception as e:
|
||||
fail_count += 1
|
||||
results.append(f"表情包 {i + 1} 注册时发生错误: {str(e)}")
|
||||
|
||||
# 构建返回消息
|
||||
total_count = success_count + fail_count
|
||||
summary_msg = f"表情包注册完成: 成功 {success_count} 个,失败 {fail_count} 个,共处理 {total_count} 个"
|
||||
|
||||
# 如果有结果详情,添加到返回消息中
|
||||
details_msg = ""
|
||||
if results:
|
||||
details_msg = "\n" + "\n".join(results)
|
||||
final_msg = summary_msg + details_msg
|
||||
else:
|
||||
final_msg = summary_msg
|
||||
|
||||
# 使用表达器重写回复
|
||||
try:
|
||||
from src.plugin_system.apis import generator_api
|
||||
|
||||
# 构建重写数据
|
||||
rewrite_data = {
|
||||
"raw_reply": summary_msg,
|
||||
"reason": f"注册了表情包:{details_msg}\n",
|
||||
}
|
||||
|
||||
# 调用表达器重写
|
||||
result_status, data = await generator_api.rewrite_reply(
|
||||
chat_stream=self.message.chat_stream,
|
||||
reply_data=rewrite_data,
|
||||
)
|
||||
|
||||
if result_status:
|
||||
# 发送重写后的回复
|
||||
for reply_seg in data.reply_set.reply_data:
|
||||
send_data = reply_seg.content
|
||||
await self.send_text(send_data)
|
||||
|
||||
return success_count > 0, final_msg, success_count > 0
|
||||
else:
|
||||
# 如果重写失败,发送原始消息
|
||||
await self.send_text(final_msg)
|
||||
return success_count > 0, final_msg, success_count > 0
|
||||
|
||||
except Exception as e:
|
||||
# 如果表达器调用失败,发送原始消息
|
||||
logger.error(f"[add_emoji] 表达器重写失败: {e}")
|
||||
await self.send_text(final_msg)
|
||||
return success_count > 0, final_msg, success_count > 0
|
||||
|
||||
def find_and_return_emoji_in_message(self, message_segments) -> List[str]:
|
||||
emoji_base64_list = []
|
||||
|
||||
# 处理单个Seg对象的情况
|
||||
if isinstance(message_segments, Seg):
|
||||
if message_segments.type == "emoji":
|
||||
emoji_base64_list.append(message_segments.data)
|
||||
elif message_segments.type == "image":
|
||||
# 假设图片数据是base64编码的
|
||||
emoji_base64_list.append(message_segments.data)
|
||||
elif message_segments.type == "seglist":
|
||||
# 递归处理嵌套的Seg列表
|
||||
emoji_base64_list.extend(self.find_and_return_emoji_in_message(message_segments.data))
|
||||
return emoji_base64_list
|
||||
|
||||
# 处理Seg列表的情况
|
||||
for seg in message_segments:
|
||||
if seg.type == "emoji":
|
||||
emoji_base64_list.append(seg.data)
|
||||
elif seg.type == "image":
|
||||
# 假设图片数据是base64编码的
|
||||
emoji_base64_list.append(seg.data)
|
||||
elif seg.type == "seglist":
|
||||
# 递归处理嵌套的Seg列表
|
||||
emoji_base64_list.extend(self.find_and_return_emoji_in_message(seg.data))
|
||||
return emoji_base64_list
|
||||
|
||||
|
||||
class ListEmojiCommand(BaseCommand):
|
||||
"""列表表情包Command - 响应/emoji list命令"""
|
||||
|
||||
command_name = "emoji_list"
|
||||
command_description = "列表表情包"
|
||||
|
||||
# === 命令设置(必须填写)===
|
||||
command_pattern = r"^/emoji list(\s+\d+)?$" # 匹配 "/emoji list" 或 "/emoji list 数量"
|
||||
|
||||
async def execute(self) -> Tuple[bool, str, bool]:
|
||||
"""执行列表表情包"""
|
||||
from src.plugin_system.apis import emoji_api
|
||||
import datetime
|
||||
|
||||
# 解析命令参数
|
||||
import re
|
||||
|
||||
match = re.match(r"^/emoji list(?:\s+(\d+))?$", self.message.raw_message)
|
||||
max_count = 10 # 默认显示10个
|
||||
if match and match.group(1):
|
||||
max_count = min(int(match.group(1)), 50) # 最多显示50个
|
||||
|
||||
# 获取当前时间
|
||||
time_format: str = self.get_config("time.format", "%Y-%m-%d %H:%M:%S") # type: ignore
|
||||
now = datetime.datetime.now()
|
||||
time_str = now.strftime(time_format)
|
||||
|
||||
# 获取表情包信息
|
||||
emoji_count = emoji_api.get_count()
|
||||
emoji_info = emoji_api.get_info()
|
||||
|
||||
# 构建返回消息
|
||||
message_lines = [
|
||||
f"📊 表情包统计信息 ({time_str})",
|
||||
f"• 总数: {emoji_count} / {emoji_info['max_count']}",
|
||||
f"• 可用: {emoji_info['available_emojis']}",
|
||||
]
|
||||
|
||||
if emoji_count == 0:
|
||||
message_lines.append("\n❌ 暂无表情包")
|
||||
final_message = "\n".join(message_lines)
|
||||
await self.send_text(final_message)
|
||||
return True, final_message, True
|
||||
|
||||
# 获取所有表情包
|
||||
all_emojis = await emoji_api.get_all()
|
||||
if not all_emojis:
|
||||
message_lines.append("\n❌ 无法获取表情包列表")
|
||||
final_message = "\n".join(message_lines)
|
||||
await self.send_text(final_message)
|
||||
return False, final_message, True
|
||||
|
||||
# 显示前N个表情包
|
||||
display_emojis = all_emojis[:max_count]
|
||||
message_lines.append(f"\n📋 显示前 {len(display_emojis)} 个表情包:")
|
||||
|
||||
for i, (_, description, emotion) in enumerate(display_emojis, 1):
|
||||
# 截断过长的描述
|
||||
short_desc = description[:50] + "..." if len(description) > 50 else description
|
||||
message_lines.append(f"{i}. {short_desc} [{emotion}]")
|
||||
|
||||
# 如果还有更多表情包,显示总数
|
||||
if len(all_emojis) > max_count:
|
||||
message_lines.append(f"\n💡 还有 {len(all_emojis) - max_count} 个表情包未显示")
|
||||
|
||||
final_message = "\n".join(message_lines)
|
||||
|
||||
# 直接发送文本消息
|
||||
await self.send_text(final_message)
|
||||
|
||||
return True, final_message, True
|
||||
|
||||
|
||||
class DeleteEmojiCommand(BaseCommand):
|
||||
command_name = "delete_emoji"
|
||||
command_description = "删除表情包"
|
||||
command_pattern = r".*/emoji delete.*"
|
||||
|
||||
async def execute(self) -> Tuple[bool, str, bool]:
|
||||
# 查找消息中的表情包图片
|
||||
logger.info(f"查找消息中的表情包用于删除: {self.message.message_segment}")
|
||||
|
||||
emoji_base64_list = self.find_and_return_emoji_in_message(self.message.message_segment)
|
||||
|
||||
if not emoji_base64_list:
|
||||
return False, "未在消息中找到表情包或图片", False
|
||||
|
||||
# 删除找到的表情包
|
||||
success_count = 0
|
||||
fail_count = 0
|
||||
results = []
|
||||
|
||||
for i, emoji_base64 in enumerate(emoji_base64_list):
|
||||
try:
|
||||
# 计算图片的哈希值来查找对应的表情包
|
||||
import base64
|
||||
import hashlib
|
||||
|
||||
# 确保base64字符串只包含ASCII字符
|
||||
if isinstance(emoji_base64, str):
|
||||
emoji_base64_clean = emoji_base64.encode("ascii", errors="ignore").decode("ascii")
|
||||
else:
|
||||
emoji_base64_clean = str(emoji_base64)
|
||||
|
||||
# 计算哈希值
|
||||
image_bytes = base64.b64decode(emoji_base64_clean)
|
||||
emoji_hash = hashlib.md5(image_bytes).hexdigest()
|
||||
|
||||
# 使用emoji_api删除表情包
|
||||
result = await emoji_api.delete_emoji(emoji_hash)
|
||||
|
||||
if result["success"]:
|
||||
success_count += 1
|
||||
description = result.get("description", "未知描述")
|
||||
count_before = result.get("count_before", 0)
|
||||
count_after = result.get("count_after", 0)
|
||||
emotions = result.get("emotions", [])
|
||||
|
||||
result_msg = f"表情包 {i + 1} 删除成功"
|
||||
if description:
|
||||
result_msg += f"\n描述: {description}"
|
||||
if emotions:
|
||||
result_msg += f"\n情感标签: {', '.join(emotions)}"
|
||||
result_msg += f"\n表情包数量: {count_before} → {count_after}"
|
||||
|
||||
results.append(result_msg)
|
||||
else:
|
||||
fail_count += 1
|
||||
error_msg = result.get("message", "删除失败")
|
||||
results.append(f"表情包 {i + 1} 删除失败: {error_msg}")
|
||||
|
||||
except Exception as e:
|
||||
fail_count += 1
|
||||
results.append(f"表情包 {i + 1} 删除时发生错误: {str(e)}")
|
||||
|
||||
# 构建返回消息
|
||||
total_count = success_count + fail_count
|
||||
summary_msg = f"表情包删除完成: 成功 {success_count} 个,失败 {fail_count} 个,共处理 {total_count} 个"
|
||||
|
||||
# 如果有结果详情,添加到返回消息中
|
||||
details_msg = ""
|
||||
if results:
|
||||
details_msg = "\n" + "\n".join(results)
|
||||
final_msg = summary_msg + details_msg
|
||||
else:
|
||||
final_msg = summary_msg
|
||||
|
||||
# 使用表达器重写回复
|
||||
try:
|
||||
from src.plugin_system.apis import generator_api
|
||||
|
||||
# 构建重写数据
|
||||
rewrite_data = {
|
||||
"raw_reply": summary_msg,
|
||||
"reason": f"删除了表情包:{details_msg}\n",
|
||||
}
|
||||
|
||||
# 调用表达器重写
|
||||
result_status, data = await generator_api.rewrite_reply(
|
||||
chat_stream=self.message.chat_stream,
|
||||
reply_data=rewrite_data,
|
||||
)
|
||||
|
||||
if result_status:
|
||||
# 发送重写后的回复
|
||||
for reply_seg in data.reply_set.reply_data:
|
||||
send_data = reply_seg.content
|
||||
await self.send_text(send_data)
|
||||
|
||||
return success_count > 0, final_msg, success_count > 0
|
||||
else:
|
||||
# 如果重写失败,发送原始消息
|
||||
await self.send_text(final_msg)
|
||||
return success_count > 0, final_msg, success_count > 0
|
||||
|
||||
except Exception as e:
|
||||
# 如果表达器调用失败,发送原始消息
|
||||
logger.error(f"[delete_emoji] 表达器重写失败: {e}")
|
||||
await self.send_text(final_msg)
|
||||
return success_count > 0, final_msg, success_count > 0
|
||||
|
||||
def find_and_return_emoji_in_message(self, message_segments) -> List[str]:
|
||||
emoji_base64_list = []
|
||||
|
||||
# 处理单个Seg对象的情况
|
||||
if isinstance(message_segments, Seg):
|
||||
if message_segments.type == "emoji":
|
||||
emoji_base64_list.append(message_segments.data)
|
||||
elif message_segments.type == "image":
|
||||
# 假设图片数据是base64编码的
|
||||
emoji_base64_list.append(message_segments.data)
|
||||
elif message_segments.type == "seglist":
|
||||
# 递归处理嵌套的Seg列表
|
||||
emoji_base64_list.extend(self.find_and_return_emoji_in_message(message_segments.data))
|
||||
return emoji_base64_list
|
||||
|
||||
# 处理Seg列表的情况
|
||||
for seg in message_segments:
|
||||
if seg.type == "emoji":
|
||||
emoji_base64_list.append(seg.data)
|
||||
elif seg.type == "image":
|
||||
# 假设图片数据是base64编码的
|
||||
emoji_base64_list.append(seg.data)
|
||||
elif seg.type == "seglist":
|
||||
# 递归处理嵌套的Seg列表
|
||||
emoji_base64_list.extend(self.find_and_return_emoji_in_message(seg.data))
|
||||
return emoji_base64_list
|
||||
|
||||
|
||||
class RandomEmojis(BaseCommand):
|
||||
command_name = "random_emojis"
|
||||
command_description = "发送多张随机表情包"
|
||||
command_pattern = r"^/random_emojis$"
|
||||
|
||||
async def execute(self):
|
||||
emojis = await emoji_api.get_random(5)
|
||||
if not emojis:
|
||||
return False, "未找到表情包", False
|
||||
emoji_base64_list = []
|
||||
for emoji in emojis:
|
||||
emoji_base64_list.append(emoji[0])
|
||||
return await self.forward_images(emoji_base64_list)
|
||||
|
||||
async def forward_images(self, images: List[str]):
|
||||
"""
|
||||
把多张图片用合并转发的方式发给用户
|
||||
"""
|
||||
success = await self.send_forward([("0", "神秘用户", [(ReplyContentType.IMAGE, img)]) for img in images])
|
||||
return (True, "已发送随机表情包", True) if success else (False, "发送随机表情包失败", False)
|
||||
|
||||
|
||||
# ===== 插件注册 =====
|
||||
|
||||
|
||||
@register_plugin
|
||||
class EmojiManagePlugin(BasePlugin):
|
||||
"""表情包管理插件 - 管理表情包"""
|
||||
|
||||
# 插件基本信息
|
||||
plugin_name: str = "emoji_manage_plugin" # 内部标识符
|
||||
enable_plugin: bool = False
|
||||
dependencies: List[str] = [] # 插件依赖列表
|
||||
python_dependencies: List[str] = [] # Python包依赖列表
|
||||
config_file_name: str = "config.toml" # 配置文件名
|
||||
|
||||
# 配置节描述
|
||||
config_section_descriptions = {"plugin": "插件基本信息", "emoji": "表情包功能配置"}
|
||||
|
||||
# 配置Schema定义
|
||||
config_schema: dict = {
|
||||
"plugin": {
|
||||
"enabled": ConfigField(type=bool, default=True, description="是否启用插件"),
|
||||
"config_version": ConfigField(type=str, default="1.0.1", description="配置文件版本"),
|
||||
},
|
||||
}
|
||||
|
||||
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
|
||||
return [
|
||||
(RandomEmojis.get_command_info(), RandomEmojis),
|
||||
(AddEmojiCommand.get_command_info(), AddEmojiCommand),
|
||||
(ListEmojiCommand.get_command_info(), ListEmojiCommand),
|
||||
(DeleteEmojiCommand.get_command_info(), DeleteEmojiCommand),
|
||||
]
|
||||
@@ -1,3 +1,4 @@
|
||||
import random
|
||||
from typing import List, Tuple, Type, Any
|
||||
from src.plugin_system import (
|
||||
BasePlugin,
|
||||
@@ -12,7 +13,10 @@ from src.plugin_system import (
|
||||
EventType,
|
||||
MaiMessages,
|
||||
ToolParamType,
|
||||
ReplyContentType,
|
||||
emoji_api,
|
||||
)
|
||||
from src.config.config import global_config
|
||||
|
||||
|
||||
class CompareNumbersTool(BaseTool):
|
||||
@@ -24,6 +28,7 @@ class CompareNumbersTool(BaseTool):
|
||||
("num1", ToolParamType.FLOAT, "第一个数字", True, None),
|
||||
("num2", ToolParamType.FLOAT, "第二个数字", True, None),
|
||||
]
|
||||
available_for_llm = True
|
||||
|
||||
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""执行比较两个数的大小
|
||||
@@ -136,12 +141,80 @@ class PrintMessage(BaseEventHandler):
|
||||
handler_name = "print_message_handler"
|
||||
handler_description = "打印接收到的消息"
|
||||
|
||||
async def execute(self, message: MaiMessages | None) -> Tuple[bool, bool, str | None, None]:
|
||||
async def execute(self, message: MaiMessages | None) -> Tuple[bool, bool, str | None, None, None]:
|
||||
"""执行打印消息事件处理"""
|
||||
# 打印接收到的消息
|
||||
if self.get_config("print_message.enabled", False):
|
||||
print(f"接收到消息: {message.raw_message if message else '无效消息'}")
|
||||
return True, True, "消息已打印", None
|
||||
return True, True, "消息已打印", None, None
|
||||
|
||||
|
||||
class ForwardMessages(BaseEventHandler):
|
||||
"""
|
||||
把接收到的消息转发到指定聊天ID
|
||||
|
||||
此组件是HYBRID消息和FORWARD消息的使用示例。
|
||||
每收到10条消息,就会以1%的概率使用HYBRID消息转发,否则使用FORWARD消息转发。
|
||||
"""
|
||||
|
||||
event_type = EventType.ON_MESSAGE
|
||||
handler_name = "forward_messages_handler"
|
||||
handler_description = "把接收到的消息转发到指定聊天ID"
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.counter = 0 # 用于计数转发的消息数量
|
||||
self.messages: List[str] = []
|
||||
|
||||
async def execute(self, message: MaiMessages | None) -> Tuple[bool, bool, None, None, None]:
|
||||
if not message:
|
||||
return True, True, None, None, None
|
||||
stream_id = message.stream_id or ""
|
||||
|
||||
if message.plain_text:
|
||||
self.messages.append(message.plain_text)
|
||||
self.counter += 1
|
||||
if self.counter % 10 == 0:
|
||||
if random.random() < 0.01:
|
||||
success = await self.send_hybrid(stream_id, [(ReplyContentType.TEXT, msg) for msg in self.messages])
|
||||
else:
|
||||
success = await self.send_forward(
|
||||
stream_id,
|
||||
[
|
||||
(
|
||||
str(global_config.bot.qq_account),
|
||||
str(global_config.bot.nickname),
|
||||
[(ReplyContentType.TEXT, msg)],
|
||||
)
|
||||
for msg in self.messages
|
||||
],
|
||||
)
|
||||
if not success:
|
||||
raise ValueError("转发消息失败")
|
||||
self.messages = []
|
||||
return True, True, None, None, None
|
||||
|
||||
|
||||
class RandomEmojis(BaseCommand):
|
||||
command_name = "random_emojis"
|
||||
command_description = "发送多张随机表情包"
|
||||
command_pattern = r"^/random_emojis$"
|
||||
|
||||
async def execute(self):
|
||||
emojis = await emoji_api.get_random(5)
|
||||
if not emojis:
|
||||
return False, "未找到表情包", False
|
||||
emoji_base64_list = []
|
||||
for emoji in emojis:
|
||||
emoji_base64_list.append(emoji[0])
|
||||
return await self.forward_images(emoji_base64_list)
|
||||
|
||||
async def forward_images(self, images: List[str]):
|
||||
"""
|
||||
把多张图片用合并转发的方式发给用户
|
||||
"""
|
||||
success = await self.send_forward([("0", "神秘用户", [(ReplyContentType.IMAGE, img)]) for img in images])
|
||||
return (True, "已发送随机表情包", True) if success else (False, "发送随机表情包失败", False)
|
||||
|
||||
|
||||
# ===== 插件注册 =====
|
||||
@@ -153,7 +226,7 @@ class HelloWorldPlugin(BasePlugin):
|
||||
|
||||
# 插件基本信息
|
||||
plugin_name: str = "hello_world_plugin" # 内部标识符
|
||||
enable_plugin: bool = True
|
||||
enable_plugin: bool = False
|
||||
dependencies: List[str] = [] # 插件依赖列表
|
||||
python_dependencies: List[str] = [] # Python包依赖列表
|
||||
config_file_name: str = "config.toml" # 配置文件名
|
||||
@@ -164,8 +237,7 @@ class HelloWorldPlugin(BasePlugin):
|
||||
# 配置Schema定义
|
||||
config_schema: dict = {
|
||||
"plugin": {
|
||||
"name": ConfigField(type=str, default="hello_world_plugin", description="插件名称"),
|
||||
"version": ConfigField(type=str, default="1.0.0", description="插件版本"),
|
||||
"config_version": ConfigField(type=str, default="1.0.0", description="配置文件版本"),
|
||||
"enabled": ConfigField(type=bool, default=False, description="是否启用插件"),
|
||||
},
|
||||
"greeting": {
|
||||
@@ -185,6 +257,8 @@ class HelloWorldPlugin(BasePlugin):
|
||||
(ByeAction.get_action_info(), ByeAction), # 添加告别Action
|
||||
(TimeCommand.get_command_info(), TimeCommand),
|
||||
(PrintMessage.get_handler_info(), PrintMessage),
|
||||
(ForwardMessages.get_handler_info(), ForwardMessages),
|
||||
(RandomEmojis.get_command_info(), RandomEmojis),
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -1,56 +1,37 @@
|
||||
[project]
|
||||
name = "MaiBot"
|
||||
version = "0.8.1"
|
||||
version = "0.11.0"
|
||||
description = "MaiCore 是一个基于大语言模型的可交互智能体"
|
||||
requires-python = ">=3.10"
|
||||
dependencies = [
|
||||
"aiohttp>=3.12.14",
|
||||
"apscheduler>=3.11.0",
|
||||
"aiohttp-cors>=0.8.1",
|
||||
"colorama>=0.4.6",
|
||||
"cryptography>=45.0.5",
|
||||
"customtkinter>=5.2.2",
|
||||
"dotenv>=0.9.9",
|
||||
"faiss-cpu>=1.11.0",
|
||||
"fastapi>=0.116.0",
|
||||
"google-genai>=1.39.1",
|
||||
"jieba>=0.42.1",
|
||||
"json-repair>=0.47.6",
|
||||
"jsonlines>=4.0.0",
|
||||
"maim-message>=0.3.8",
|
||||
"maim-message",
|
||||
"matplotlib>=3.10.3",
|
||||
"networkx>=3.4.2",
|
||||
"numpy>=2.2.6",
|
||||
"openai>=1.95.0",
|
||||
"packaging>=25.0",
|
||||
"pandas>=2.3.1",
|
||||
"peewee>=3.18.2",
|
||||
"pillow>=11.3.0",
|
||||
"psutil>=7.0.0",
|
||||
"pyarrow>=20.0.0",
|
||||
"pydantic>=2.11.7",
|
||||
"pymongo>=4.13.2",
|
||||
"pypinyin>=0.54.0",
|
||||
"python-dateutil>=2.9.0.post0",
|
||||
"python-dotenv>=1.1.1",
|
||||
"python-igraph>=0.11.9",
|
||||
"quick-algo>=0.1.3",
|
||||
"reportportal-client>=5.6.5",
|
||||
"requests>=2.32.4",
|
||||
"rich>=14.0.0",
|
||||
"ruff>=0.12.2",
|
||||
"scikit-learn>=1.7.0",
|
||||
"scipy>=1.15.3",
|
||||
"seaborn>=0.13.2",
|
||||
"setuptools>=80.9.0",
|
||||
"strawberry-graphql[fastapi]>=0.275.5",
|
||||
"structlog>=25.4.0",
|
||||
"toml>=0.10.2",
|
||||
"tomli>=2.2.1",
|
||||
"tomli-w>=1.2.0",
|
||||
"tomlkit>=0.13.3",
|
||||
"tqdm>=4.67.1",
|
||||
"urllib3>=2.5.0",
|
||||
"uvicorn>=0.35.0",
|
||||
"websockets>=15.0.1",
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -1,271 +0,0 @@
|
||||
# This file was autogenerated by uv via the following command:
|
||||
# uv pip compile requirements.txt -o requirements.lock
|
||||
aenum==3.1.16
|
||||
# via reportportal-client
|
||||
aiohappyeyeballs==2.6.1
|
||||
# via aiohttp
|
||||
aiohttp==3.12.14
|
||||
# via
|
||||
# -r requirements.txt
|
||||
# maim-message
|
||||
# reportportal-client
|
||||
aiosignal==1.4.0
|
||||
# via aiohttp
|
||||
annotated-types==0.7.0
|
||||
# via pydantic
|
||||
anyio==4.9.0
|
||||
# via
|
||||
# httpx
|
||||
# openai
|
||||
# starlette
|
||||
apscheduler==3.11.0
|
||||
# via -r requirements.txt
|
||||
attrs==25.3.0
|
||||
# via
|
||||
# aiohttp
|
||||
# jsonlines
|
||||
certifi==2025.7.9
|
||||
# via
|
||||
# httpcore
|
||||
# httpx
|
||||
# reportportal-client
|
||||
# requests
|
||||
cffi==1.17.1
|
||||
# via cryptography
|
||||
charset-normalizer==3.4.2
|
||||
# via requests
|
||||
click==8.2.1
|
||||
# via uvicorn
|
||||
colorama==0.4.6
|
||||
# via
|
||||
# -r requirements.txt
|
||||
# click
|
||||
# tqdm
|
||||
contourpy==1.3.2
|
||||
# via matplotlib
|
||||
cryptography==45.0.5
|
||||
# via
|
||||
# -r requirements.txt
|
||||
# maim-message
|
||||
customtkinter==5.2.2
|
||||
# via -r requirements.txt
|
||||
cycler==0.12.1
|
||||
# via matplotlib
|
||||
darkdetect==0.8.0
|
||||
# via customtkinter
|
||||
distro==1.9.0
|
||||
# via openai
|
||||
dnspython==2.7.0
|
||||
# via pymongo
|
||||
dotenv==0.9.9
|
||||
# via -r requirements.txt
|
||||
faiss-cpu==1.11.0
|
||||
# via -r requirements.txt
|
||||
fastapi==0.116.0
|
||||
# via
|
||||
# -r requirements.txt
|
||||
# maim-message
|
||||
# strawberry-graphql
|
||||
fonttools==4.58.5
|
||||
# via matplotlib
|
||||
frozenlist==1.7.0
|
||||
# via
|
||||
# aiohttp
|
||||
# aiosignal
|
||||
graphql-core==3.2.6
|
||||
# via strawberry-graphql
|
||||
h11==0.16.0
|
||||
# via
|
||||
# httpcore
|
||||
# uvicorn
|
||||
httpcore==1.0.9
|
||||
# via httpx
|
||||
httpx==0.28.1
|
||||
# via openai
|
||||
idna==3.10
|
||||
# via
|
||||
# anyio
|
||||
# httpx
|
||||
# requests
|
||||
# yarl
|
||||
igraph==0.11.9
|
||||
# via python-igraph
|
||||
jieba==0.42.1
|
||||
# via -r requirements.txt
|
||||
jiter==0.10.0
|
||||
# via openai
|
||||
joblib==1.5.1
|
||||
# via scikit-learn
|
||||
json-repair==0.47.6
|
||||
# via -r requirements.txt
|
||||
jsonlines==4.0.0
|
||||
# via -r requirements.txt
|
||||
kiwisolver==1.4.8
|
||||
# via matplotlib
|
||||
maim-message==0.3.8
|
||||
# via -r requirements.txt
|
||||
markdown-it-py==3.0.0
|
||||
# via rich
|
||||
matplotlib==3.10.3
|
||||
# via
|
||||
# -r requirements.txt
|
||||
# seaborn
|
||||
mdurl==0.1.2
|
||||
# via markdown-it-py
|
||||
multidict==6.6.3
|
||||
# via
|
||||
# aiohttp
|
||||
# yarl
|
||||
networkx==3.5
|
||||
# via -r requirements.txt
|
||||
numpy==2.3.1
|
||||
# via
|
||||
# -r requirements.txt
|
||||
# contourpy
|
||||
# faiss-cpu
|
||||
# matplotlib
|
||||
# pandas
|
||||
# scikit-learn
|
||||
# scipy
|
||||
# seaborn
|
||||
openai==1.95.0
|
||||
# via -r requirements.txt
|
||||
packaging==25.0
|
||||
# via
|
||||
# -r requirements.txt
|
||||
# customtkinter
|
||||
# faiss-cpu
|
||||
# matplotlib
|
||||
# strawberry-graphql
|
||||
pandas==2.3.1
|
||||
# via
|
||||
# -r requirements.txt
|
||||
# seaborn
|
||||
peewee==3.18.2
|
||||
# via -r requirements.txt
|
||||
pillow==11.3.0
|
||||
# via
|
||||
# -r requirements.txt
|
||||
# matplotlib
|
||||
propcache==0.3.2
|
||||
# via
|
||||
# aiohttp
|
||||
# yarl
|
||||
psutil==7.0.0
|
||||
# via -r requirements.txt
|
||||
pyarrow==20.0.0
|
||||
# via -r requirements.txt
|
||||
pycparser==2.22
|
||||
# via cffi
|
||||
pydantic==2.11.7
|
||||
# via
|
||||
# -r requirements.txt
|
||||
# fastapi
|
||||
# maim-message
|
||||
# openai
|
||||
pydantic-core==2.33.2
|
||||
# via pydantic
|
||||
pygments==2.19.2
|
||||
# via rich
|
||||
pymongo==4.13.2
|
||||
# via -r requirements.txt
|
||||
pyparsing==3.2.3
|
||||
# via matplotlib
|
||||
pypinyin==0.54.0
|
||||
# via -r requirements.txt
|
||||
python-dateutil==2.9.0.post0
|
||||
# via
|
||||
# -r requirements.txt
|
||||
# matplotlib
|
||||
# pandas
|
||||
# strawberry-graphql
|
||||
python-dotenv==1.1.1
|
||||
# via
|
||||
# -r requirements.txt
|
||||
# dotenv
|
||||
python-igraph==0.11.9
|
||||
# via -r requirements.txt
|
||||
python-multipart==0.0.20
|
||||
# via strawberry-graphql
|
||||
pytz==2025.2
|
||||
# via pandas
|
||||
quick-algo==0.1.3
|
||||
# via -r requirements.txt
|
||||
reportportal-client==5.6.5
|
||||
# via -r requirements.txt
|
||||
requests==2.32.4
|
||||
# via
|
||||
# -r requirements.txt
|
||||
# reportportal-client
|
||||
rich==14.0.0
|
||||
# via -r requirements.txt
|
||||
ruff==0.12.2
|
||||
# via -r requirements.txt
|
||||
scikit-learn==1.7.0
|
||||
# via -r requirements.txt
|
||||
scipy==1.16.0
|
||||
# via
|
||||
# -r requirements.txt
|
||||
# scikit-learn
|
||||
seaborn==0.13.2
|
||||
# via -r requirements.txt
|
||||
setuptools==80.9.0
|
||||
# via -r requirements.txt
|
||||
six==1.17.0
|
||||
# via python-dateutil
|
||||
sniffio==1.3.1
|
||||
# via
|
||||
# anyio
|
||||
# openai
|
||||
starlette==0.46.2
|
||||
# via fastapi
|
||||
strawberry-graphql==0.275.5
|
||||
# via -r requirements.txt
|
||||
structlog==25.4.0
|
||||
# via -r requirements.txt
|
||||
texttable==1.7.0
|
||||
# via igraph
|
||||
threadpoolctl==3.6.0
|
||||
# via scikit-learn
|
||||
toml==0.10.2
|
||||
# via -r requirements.txt
|
||||
tomli==2.2.1
|
||||
# via -r requirements.txt
|
||||
tomli-w==1.2.0
|
||||
# via -r requirements.txt
|
||||
tomlkit==0.13.3
|
||||
# via -r requirements.txt
|
||||
tqdm==4.67.1
|
||||
# via
|
||||
# -r requirements.txt
|
||||
# openai
|
||||
typing-extensions==4.14.1
|
||||
# via
|
||||
# fastapi
|
||||
# openai
|
||||
# pydantic
|
||||
# pydantic-core
|
||||
# strawberry-graphql
|
||||
# typing-inspection
|
||||
typing-inspection==0.4.1
|
||||
# via pydantic
|
||||
tzdata==2025.2
|
||||
# via
|
||||
# pandas
|
||||
# tzlocal
|
||||
tzlocal==5.3.1
|
||||
# via apscheduler
|
||||
urllib3==2.5.0
|
||||
# via
|
||||
# -r requirements.txt
|
||||
# requests
|
||||
uvicorn==0.35.0
|
||||
# via
|
||||
# -r requirements.txt
|
||||
# maim-message
|
||||
websockets==15.0.1
|
||||
# via
|
||||
# -r requirements.txt
|
||||
# maim-message
|
||||
yarl==1.20.1
|
||||
# via aiohttp
|
||||
@@ -1,49 +1,28 @@
|
||||
APScheduler
|
||||
Pillow
|
||||
aiohttp
|
||||
aiohttp-cors
|
||||
colorama
|
||||
customtkinter
|
||||
dotenv
|
||||
faiss-cpu
|
||||
fastapi
|
||||
jieba
|
||||
jsonlines
|
||||
maim_message
|
||||
quick_algo
|
||||
matplotlib
|
||||
networkx
|
||||
numpy
|
||||
openai
|
||||
pandas
|
||||
peewee
|
||||
pyarrow
|
||||
pydantic
|
||||
pypinyin
|
||||
python-dateutil
|
||||
python-dotenv
|
||||
python-igraph
|
||||
pymongo
|
||||
requests
|
||||
ruff
|
||||
scipy
|
||||
setuptools
|
||||
toml
|
||||
tomli
|
||||
tomli_w
|
||||
tomlkit
|
||||
tqdm
|
||||
urllib3
|
||||
uvicorn
|
||||
websockets
|
||||
strawberry-graphql[fastapi]
|
||||
packaging
|
||||
rich
|
||||
psutil
|
||||
cryptography
|
||||
json-repair
|
||||
reportportal-client
|
||||
scikit-learn
|
||||
seaborn
|
||||
structlog
|
||||
google.genai
|
||||
aiohttp>=3.12.14
|
||||
aiohttp-cors>=0.8.1
|
||||
colorama>=0.4.6
|
||||
faiss-cpu>=1.11.0
|
||||
fastapi>=0.116.0
|
||||
google-genai>=1.39.1
|
||||
jieba>=0.42.1
|
||||
json-repair>=0.47.6
|
||||
maim-message
|
||||
matplotlib>=3.10.3
|
||||
numpy>=2.2.6
|
||||
openai>=1.95.0
|
||||
pandas>=2.3.1
|
||||
peewee>=3.18.2
|
||||
pillow>=11.3.0
|
||||
pyarrow>=20.0.0
|
||||
pydantic>=2.11.7
|
||||
pypinyin>=0.54.0
|
||||
python-dotenv>=1.1.1
|
||||
quick-algo>=0.1.3
|
||||
rich>=14.0.0
|
||||
ruff>=0.12.2
|
||||
setuptools>=80.9.0
|
||||
structlog>=25.4.0
|
||||
toml>=0.10.2
|
||||
tomlkit>=0.13.3
|
||||
urllib3>=2.5.0
|
||||
uvicorn>=0.35.0
|
||||
389
scripts/build_io_pairs.py
Normal file
389
scripts/build_io_pairs.py
Normal file
@@ -0,0 +1,389 @@
|
||||
import argparse
|
||||
import json
|
||||
import random
|
||||
import re
|
||||
import sys
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import Dict, Iterable, List, Optional, Tuple
|
||||
|
||||
# 确保可从任意工作目录运行:将项目根目录加入 sys.path(scripts 的上一级)
|
||||
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
if PROJECT_ROOT not in sys.path:
|
||||
sys.path.insert(0, PROJECT_ROOT)
|
||||
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.common.message_repository import find_messages
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages
|
||||
|
||||
|
||||
SECONDS_5_MINUTES = 5 * 60
|
||||
|
||||
|
||||
def clean_output_text(text: str) -> str:
|
||||
"""
|
||||
清理输出文本,移除表情包和回复内容
|
||||
- 移除 [表情包:...] 格式的内容
|
||||
- 移除 [回复...] 格式的内容
|
||||
"""
|
||||
if not text:
|
||||
return text
|
||||
|
||||
# 移除表情包内容:[表情包:...]
|
||||
text = re.sub(r'\[表情包:[^\]]*\]', '', text)
|
||||
|
||||
# 移除回复内容:[回复...],说:... 的完整模式
|
||||
text = re.sub(r'\[回复[^\]]*\],说:[^@]*@[^:]*:', '', text)
|
||||
|
||||
# 清理多余的空格和换行
|
||||
text = re.sub(r'\s+', ' ', text).strip()
|
||||
|
||||
return text
|
||||
|
||||
|
||||
def parse_datetime_to_timestamp(value: str) -> float:
|
||||
"""
|
||||
接受多种常见格式并转换为时间戳(秒)
|
||||
支持示例:
|
||||
- 2025-09-29
|
||||
- 2025-09-29 00:00:00
|
||||
- 2025/09/29 00:00
|
||||
- 2025-09-29T00:00:00
|
||||
"""
|
||||
value = value.strip()
|
||||
fmts = [
|
||||
"%Y-%m-%d %H:%M:%S",
|
||||
"%Y-%m-%d %H:%M",
|
||||
"%Y/%m/%d %H:%M:%S",
|
||||
"%Y/%m/%d %H:%M",
|
||||
"%Y-%m-%d",
|
||||
"%Y/%m/%d",
|
||||
"%Y-%m-%dT%H:%M:%S",
|
||||
"%Y-%m-%dT%H:%M",
|
||||
]
|
||||
last_err: Optional[Exception] = None
|
||||
for fmt in fmts:
|
||||
try:
|
||||
dt = datetime.strptime(value, fmt)
|
||||
return dt.timestamp()
|
||||
except Exception as e: # noqa: BLE001
|
||||
last_err = e
|
||||
raise ValueError(f"无法解析时间: {value} ({last_err})")
|
||||
|
||||
|
||||
def fetch_messages_between(
|
||||
start_ts: float,
|
||||
end_ts: float,
|
||||
platform: Optional[str] = None,
|
||||
) -> List[DatabaseMessages]:
|
||||
"""使用 find_messages 获取指定区间的消息,可选按 chat_info_platform 过滤。按时间升序返回。"""
|
||||
filter_query: Dict[str, object] = {"time": {"$gt": start_ts, "$lt": end_ts}}
|
||||
if platform:
|
||||
filter_query["chat_info_platform"] = platform
|
||||
# 当 limit==0 时,sort 生效,这里按时间升序
|
||||
return find_messages(message_filter=filter_query, sort=[("time", 1)], limit=0)
|
||||
|
||||
|
||||
def group_by_chat(messages: Iterable[DatabaseMessages]) -> Dict[str, List[DatabaseMessages]]:
|
||||
groups: Dict[str, List[DatabaseMessages]] = {}
|
||||
for msg in messages:
|
||||
groups.setdefault(msg.chat_id, []).append(msg)
|
||||
# 保证每个分组内按时间升序
|
||||
for chat_id, msgs in groups.items():
|
||||
msgs.sort(key=lambda m: m.time or 0)
|
||||
return groups
|
||||
|
||||
|
||||
def _merge_bucket_to_message(bucket: List[DatabaseMessages]) -> DatabaseMessages:
|
||||
"""
|
||||
将相邻、同一 user_id 且 5 分钟内的消息 bucket 合并为一条。
|
||||
processed_plain_text 合并(以换行连接),其余字段取最新一条(时间最大)。
|
||||
"""
|
||||
if not bucket:
|
||||
raise ValueError("bucket 为空,无法合并")
|
||||
|
||||
latest = bucket[-1]
|
||||
merged_texts: List[str] = []
|
||||
for m in bucket:
|
||||
text = m.processed_plain_text or ""
|
||||
if text:
|
||||
merged_texts.append(text)
|
||||
|
||||
merged = DatabaseMessages(
|
||||
# 其他信息采用最新消息
|
||||
message_id=latest.message_id,
|
||||
time=latest.time,
|
||||
chat_id=latest.chat_id,
|
||||
reply_to=latest.reply_to,
|
||||
interest_value=latest.interest_value,
|
||||
key_words=latest.key_words,
|
||||
key_words_lite=latest.key_words_lite,
|
||||
is_mentioned=latest.is_mentioned,
|
||||
is_at=latest.is_at,
|
||||
reply_probability_boost=latest.reply_probability_boost,
|
||||
processed_plain_text="\n".join(merged_texts) if merged_texts else latest.processed_plain_text,
|
||||
display_message=latest.display_message,
|
||||
priority_mode=latest.priority_mode,
|
||||
priority_info=latest.priority_info,
|
||||
additional_config=latest.additional_config,
|
||||
is_emoji=latest.is_emoji,
|
||||
is_picid=latest.is_picid,
|
||||
is_command=latest.is_command,
|
||||
is_notify=latest.is_notify,
|
||||
selected_expressions=latest.selected_expressions,
|
||||
user_id=latest.user_info.user_id,
|
||||
user_nickname=latest.user_info.user_nickname,
|
||||
user_cardname=latest.user_info.user_cardname,
|
||||
user_platform=latest.user_info.platform,
|
||||
chat_info_group_id=(latest.group_info.group_id if latest.group_info else None),
|
||||
chat_info_group_name=(latest.group_info.group_name if latest.group_info else None),
|
||||
chat_info_group_platform=(latest.group_info.group_platform if latest.group_info else None),
|
||||
chat_info_user_id=latest.chat_info.user_info.user_id,
|
||||
chat_info_user_nickname=latest.chat_info.user_info.user_nickname,
|
||||
chat_info_user_cardname=latest.chat_info.user_info.user_cardname,
|
||||
chat_info_user_platform=latest.chat_info.user_info.platform,
|
||||
chat_info_stream_id=latest.chat_info.stream_id,
|
||||
chat_info_platform=latest.chat_info.platform,
|
||||
chat_info_create_time=latest.chat_info.create_time,
|
||||
chat_info_last_active_time=latest.chat_info.last_active_time,
|
||||
)
|
||||
return merged
|
||||
|
||||
|
||||
def merge_adjacent_same_user(messages: List[DatabaseMessages]) -> List[DatabaseMessages]:
|
||||
"""按 5 分钟窗口合并相邻同 user_id 的消息。输入需按时间升序。"""
|
||||
if not messages:
|
||||
return []
|
||||
|
||||
merged: List[DatabaseMessages] = []
|
||||
bucket: List[DatabaseMessages] = []
|
||||
|
||||
def flush_bucket() -> None:
|
||||
nonlocal bucket
|
||||
if bucket:
|
||||
merged.append(_merge_bucket_to_message(bucket))
|
||||
bucket = []
|
||||
|
||||
for msg in messages:
|
||||
if not bucket:
|
||||
bucket = [msg]
|
||||
continue
|
||||
|
||||
last = bucket[-1]
|
||||
same_user = (msg.user_info.user_id == last.user_info.user_id)
|
||||
close_enough = ((msg.time or 0) - (last.time or 0) <= SECONDS_5_MINUTES)
|
||||
|
||||
if same_user and close_enough:
|
||||
bucket.append(msg)
|
||||
else:
|
||||
flush_bucket()
|
||||
bucket = [msg]
|
||||
|
||||
flush_bucket()
|
||||
return merged
|
||||
|
||||
|
||||
def build_pairs_for_chat(
|
||||
original_messages: List[DatabaseMessages],
|
||||
merged_messages: List[DatabaseMessages],
|
||||
min_ctx: int,
|
||||
max_ctx: int,
|
||||
target_user_id: Optional[str] = None,
|
||||
) -> List[Tuple[str, str, str]]:
|
||||
"""
|
||||
对每条合并后的消息作为 output,从其前面取 20-30 条(可配置)的原始消息作为 input。
|
||||
input 使用原始未合并的消息构建上下文。
|
||||
output 使用合并后消息的 processed_plain_text。
|
||||
如果指定了 target_user_id,则只处理该用户的消息作为 output。
|
||||
"""
|
||||
pairs: List[Tuple[str, str, str]] = []
|
||||
n_merged = len(merged_messages)
|
||||
n_original = len(original_messages)
|
||||
|
||||
if n_merged == 0 or n_original == 0:
|
||||
return pairs
|
||||
|
||||
# 为每个合并后的消息找到对应的原始消息位置
|
||||
merged_to_original_map = {}
|
||||
original_idx = 0
|
||||
|
||||
for merged_idx, merged_msg in enumerate(merged_messages):
|
||||
# 找到这个合并消息对应的第一个原始消息
|
||||
while (original_idx < n_original and
|
||||
original_messages[original_idx].time < merged_msg.time):
|
||||
original_idx += 1
|
||||
|
||||
# 如果找到了时间匹配的原始消息,建立映射
|
||||
if (original_idx < n_original and
|
||||
original_messages[original_idx].time == merged_msg.time):
|
||||
merged_to_original_map[merged_idx] = original_idx
|
||||
|
||||
for merged_idx in range(n_merged):
|
||||
merged_msg = merged_messages[merged_idx]
|
||||
|
||||
# 如果指定了 target_user_id,只处理该用户的消息作为 output
|
||||
if target_user_id and merged_msg.user_info.user_id != target_user_id:
|
||||
continue
|
||||
|
||||
# 找到对应的原始消息位置
|
||||
if merged_idx not in merged_to_original_map:
|
||||
continue
|
||||
|
||||
original_idx = merged_to_original_map[merged_idx]
|
||||
|
||||
# 选择上下文窗口大小
|
||||
window = random.randint(min_ctx, max_ctx) if max_ctx > min_ctx else min_ctx
|
||||
start = max(0, original_idx - window)
|
||||
context_msgs = original_messages[start:original_idx]
|
||||
|
||||
# 使用原始未合并消息构建 input
|
||||
input_str = build_readable_messages(
|
||||
messages=context_msgs,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
show_actions=False,
|
||||
show_pic=True,
|
||||
)
|
||||
|
||||
# 输出取合并后消息的 processed_plain_text 并清理表情包和回复内容
|
||||
output_text = merged_msg.processed_plain_text or ""
|
||||
output_text = clean_output_text(output_text)
|
||||
output_id = merged_msg.message_id or ""
|
||||
pairs.append((input_str, output_text, output_id))
|
||||
|
||||
return pairs
|
||||
|
||||
|
||||
def build_pairs(
|
||||
start_ts: float,
|
||||
end_ts: float,
|
||||
platform: Optional[str],
|
||||
user_id: Optional[str],
|
||||
min_ctx: int,
|
||||
max_ctx: int,
|
||||
) -> List[Tuple[str, str, str]]:
|
||||
# 获取所有消息(不按user_id过滤),这样input上下文可以包含所有用户的消息
|
||||
messages = fetch_messages_between(start_ts, end_ts, platform)
|
||||
groups = group_by_chat(messages)
|
||||
|
||||
all_pairs: List[Tuple[str, str, str]] = []
|
||||
for chat_id, msgs in groups.items(): # noqa: F841 - chat_id 未直接使用
|
||||
# 对消息进行合并,用于output
|
||||
merged = merge_adjacent_same_user(msgs)
|
||||
# 传递原始消息和合并后消息,input使用原始消息,output使用合并后消息
|
||||
pairs = build_pairs_for_chat(msgs, merged, min_ctx, max_ctx, user_id)
|
||||
all_pairs.extend(pairs)
|
||||
|
||||
return all_pairs
|
||||
|
||||
|
||||
def main(argv: Optional[List[str]] = None) -> int:
|
||||
# 若未提供参数,则进入交互模式
|
||||
if argv is None:
|
||||
argv = sys.argv[1:]
|
||||
|
||||
if len(argv) == 0:
|
||||
return run_interactive()
|
||||
|
||||
parser = argparse.ArgumentParser(description="构建 (input_str, output_str, message_id) 列表,支持按用户ID筛选消息")
|
||||
parser.add_argument("start", help="起始时间,如 2025-09-28 00:00:00")
|
||||
parser.add_argument("end", help="结束时间,如 2025-09-29 00:00:00")
|
||||
parser.add_argument("--platform", default=None, help="仅选择 chat_info_platform 为该值的消息")
|
||||
parser.add_argument("--user_id", default=None, help="仅选择指定 user_id 的消息")
|
||||
parser.add_argument("--min_ctx", type=int, default=20, help="输入上下文的最少条数,默认20")
|
||||
parser.add_argument("--max_ctx", type=int, default=30, help="输入上下文的最多条数,默认30")
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
default=None,
|
||||
help="输出保存路径,支持 .jsonl(每行 {input, output}),若不指定则打印到stdout",
|
||||
)
|
||||
|
||||
args = parser.parse_args(argv)
|
||||
|
||||
start_ts = parse_datetime_to_timestamp(args.start)
|
||||
end_ts = parse_datetime_to_timestamp(args.end)
|
||||
if end_ts <= start_ts:
|
||||
raise ValueError("结束时间必须大于起始时间")
|
||||
|
||||
if args.max_ctx < args.min_ctx:
|
||||
raise ValueError("max_ctx 不能小于 min_ctx")
|
||||
|
||||
pairs = build_pairs(start_ts, end_ts, args.platform, args.user_id, args.min_ctx, args.max_ctx)
|
||||
|
||||
if args.output:
|
||||
# 保存为 JSONL,每行一个 {input, output, message_id}
|
||||
with open(args.output, "w", encoding="utf-8") as f:
|
||||
for input_str, output_str, message_id in pairs:
|
||||
obj = {"input": input_str, "output": output_str, "message_id": message_id}
|
||||
f.write(json.dumps(obj, ensure_ascii=False) + "\n")
|
||||
print(f"已保存 {len(pairs)} 条到 {args.output}")
|
||||
else:
|
||||
# 打印到 stdout
|
||||
for input_str, output_str, message_id in pairs:
|
||||
print(json.dumps({"input": input_str, "output": output_str, "message_id": message_id}, ensure_ascii=False))
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
def _prompt_with_default(prompt_text: str, default: Optional[str]) -> str:
|
||||
suffix = f"[{default}]" if default not in (None, "") else ""
|
||||
value = input(f"{prompt_text}{' ' + suffix if suffix else ''}: ").strip()
|
||||
if value == "" and default is not None:
|
||||
return default
|
||||
return value
|
||||
|
||||
|
||||
def run_interactive() -> int:
|
||||
print("进入交互模式(直接回车采用默认值)。时间格式例如:2025-09-28 00:00:00 或 2025-09-28")
|
||||
start_str = _prompt_with_default("请输入起始时间", None)
|
||||
end_str = _prompt_with_default("请输入结束时间", None)
|
||||
platform = _prompt_with_default("平台(可留空表示不限)", "")
|
||||
user_id = _prompt_with_default("用户ID(可留空表示不限)", "")
|
||||
try:
|
||||
min_ctx = int(_prompt_with_default("输入上下文最少条数", "20"))
|
||||
max_ctx = int(_prompt_with_default("输入上下文最多条数", "30"))
|
||||
except Exception:
|
||||
print("上下文条数输入有误,使用默认 20/30")
|
||||
min_ctx, max_ctx = 20, 30
|
||||
output_path = _prompt_with_default("输出路径(.jsonl,可留空打印到控制台)", "")
|
||||
|
||||
if not start_str or not end_str:
|
||||
print("必须提供起始与结束时间。")
|
||||
return 2
|
||||
|
||||
try:
|
||||
start_ts = parse_datetime_to_timestamp(start_str)
|
||||
end_ts = parse_datetime_to_timestamp(end_str)
|
||||
except Exception as e: # noqa: BLE001
|
||||
print(f"时间解析失败:{e}")
|
||||
return 2
|
||||
|
||||
if end_ts <= start_ts:
|
||||
print("结束时间必须大于起始时间。")
|
||||
return 2
|
||||
|
||||
if max_ctx < min_ctx:
|
||||
print("最多条数不能小于最少条数。")
|
||||
return 2
|
||||
|
||||
platform_val = platform if platform != "" else None
|
||||
user_id_val = user_id if user_id != "" else None
|
||||
pairs = build_pairs(start_ts, end_ts, platform_val, user_id_val, min_ctx, max_ctx)
|
||||
|
||||
if output_path:
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
for input_str, output_str, message_id in pairs:
|
||||
obj = {"input": input_str, "output": output_str, "message_id": message_id}
|
||||
f.write(json.dumps(obj, ensure_ascii=False) + "\n")
|
||||
print(f"已保存 {len(pairs)} 条到 {output_path}")
|
||||
else:
|
||||
for input_str, output_str, message_id in pairs:
|
||||
print(json.dumps({"input": input_str, "output": output_str, "message_id": message_id}, ensure_ascii=False))
|
||||
print(f"总计 {len(pairs)} 条。")
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
|
||||
|
||||
334
scripts/expression_scatter_analysis.py
Normal file
334
scripts/expression_scatter_analysis.py
Normal file
@@ -0,0 +1,334 @@
|
||||
import time
|
||||
import sys
|
||||
import os
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.dates as mdates
|
||||
from datetime import datetime
|
||||
from typing import List, Tuple
|
||||
import numpy as np
|
||||
|
||||
# Add project root to Python path
|
||||
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
sys.path.insert(0, project_root)
|
||||
|
||||
from src.common.database.database_model import Expression, ChatStreams
|
||||
|
||||
# 设置中文字体
|
||||
plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'DejaVu Sans']
|
||||
plt.rcParams['axes.unicode_minus'] = False
|
||||
|
||||
|
||||
def get_chat_name(chat_id: str) -> str:
|
||||
"""Get chat name from chat_id by querying ChatStreams table directly"""
|
||||
try:
|
||||
chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id)
|
||||
if chat_stream is None:
|
||||
return f"未知聊天 ({chat_id})"
|
||||
|
||||
if chat_stream.group_name:
|
||||
return f"{chat_stream.group_name} ({chat_id})"
|
||||
elif chat_stream.user_nickname:
|
||||
return f"{chat_stream.user_nickname}的私聊 ({chat_id})"
|
||||
else:
|
||||
return f"未知聊天 ({chat_id})"
|
||||
except Exception:
|
||||
return f"查询失败 ({chat_id})"
|
||||
|
||||
|
||||
def get_expression_data() -> List[Tuple[float, float, str, str]]:
|
||||
"""获取Expression表中的数据,返回(create_date, count, chat_id, expression_type)的列表"""
|
||||
expressions = Expression.select()
|
||||
data = []
|
||||
|
||||
for expr in expressions:
|
||||
# 如果create_date为空,跳过该记录
|
||||
if expr.create_date is None:
|
||||
continue
|
||||
|
||||
data.append((
|
||||
expr.create_date,
|
||||
expr.count,
|
||||
expr.chat_id,
|
||||
expr.type
|
||||
))
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def create_scatter_plot(data: List[Tuple[float, float, str, str]], save_path: str = None):
|
||||
"""创建散点图"""
|
||||
if not data:
|
||||
print("没有找到有效的表达式数据")
|
||||
return
|
||||
|
||||
# 分离数据
|
||||
create_dates = [item[0] for item in data]
|
||||
counts = [item[1] for item in data]
|
||||
chat_ids = [item[2] for item in data]
|
||||
expression_types = [item[3] for item in data]
|
||||
|
||||
# 转换时间戳为datetime对象
|
||||
dates = [datetime.fromtimestamp(ts) for ts in create_dates]
|
||||
|
||||
# 计算时间跨度,自动调整显示格式
|
||||
time_span = max(dates) - min(dates)
|
||||
if time_span.days > 30: # 超过30天,按月显示
|
||||
date_format = '%Y-%m-%d'
|
||||
major_locator = mdates.MonthLocator()
|
||||
minor_locator = mdates.DayLocator(interval=7)
|
||||
elif time_span.days > 7: # 超过7天,按天显示
|
||||
date_format = '%Y-%m-%d'
|
||||
major_locator = mdates.DayLocator(interval=1)
|
||||
minor_locator = mdates.HourLocator(interval=12)
|
||||
else: # 7天内,按小时显示
|
||||
date_format = '%Y-%m-%d %H:%M'
|
||||
major_locator = mdates.HourLocator(interval=6)
|
||||
minor_locator = mdates.HourLocator(interval=1)
|
||||
|
||||
# 创建图形
|
||||
fig, ax = plt.subplots(figsize=(12, 8))
|
||||
|
||||
# 创建散点图
|
||||
scatter = ax.scatter(dates, counts, alpha=0.6, s=30, c=range(len(dates)), cmap='viridis')
|
||||
|
||||
# 设置标签和标题
|
||||
ax.set_xlabel('创建日期 (Create Date)', fontsize=12)
|
||||
ax.set_ylabel('使用次数 (Count)', fontsize=12)
|
||||
ax.set_title('表达式使用次数随时间分布散点图', fontsize=14, fontweight='bold')
|
||||
|
||||
# 设置x轴日期格式 - 根据时间跨度自动调整
|
||||
ax.xaxis.set_major_formatter(mdates.DateFormatter(date_format))
|
||||
ax.xaxis.set_major_locator(major_locator)
|
||||
ax.xaxis.set_minor_locator(minor_locator)
|
||||
plt.xticks(rotation=45)
|
||||
|
||||
# 添加网格
|
||||
ax.grid(True, alpha=0.3)
|
||||
|
||||
# 添加颜色条
|
||||
cbar = plt.colorbar(scatter)
|
||||
cbar.set_label('数据点顺序', fontsize=10)
|
||||
|
||||
# 调整布局
|
||||
plt.tight_layout()
|
||||
|
||||
# 显示统计信息
|
||||
print(f"\n=== 数据统计 ===")
|
||||
print(f"总数据点数量: {len(data)}")
|
||||
print(f"时间范围: {min(dates).strftime('%Y-%m-%d %H:%M:%S')} 到 {max(dates).strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
print(f"使用次数范围: {min(counts):.1f} 到 {max(counts):.1f}")
|
||||
print(f"平均使用次数: {np.mean(counts):.2f}")
|
||||
print(f"中位数使用次数: {np.median(counts):.2f}")
|
||||
|
||||
# 保存图片
|
||||
if save_path:
|
||||
plt.savefig(save_path, dpi=300, bbox_inches='tight')
|
||||
print(f"\n散点图已保存到: {save_path}")
|
||||
|
||||
# 显示图片
|
||||
plt.show()
|
||||
|
||||
|
||||
def create_grouped_scatter_plot(data: List[Tuple[float, float, str, str]], save_path: str = None):
|
||||
"""创建按聊天分组的散点图"""
|
||||
if not data:
|
||||
print("没有找到有效的表达式数据")
|
||||
return
|
||||
|
||||
# 按chat_id分组
|
||||
chat_groups = {}
|
||||
for item in data:
|
||||
chat_id = item[2]
|
||||
if chat_id not in chat_groups:
|
||||
chat_groups[chat_id] = []
|
||||
chat_groups[chat_id].append(item)
|
||||
|
||||
# 计算时间跨度,自动调整显示格式
|
||||
all_dates = [datetime.fromtimestamp(item[0]) for item in data]
|
||||
time_span = max(all_dates) - min(all_dates)
|
||||
if time_span.days > 30: # 超过30天,按月显示
|
||||
date_format = '%Y-%m-%d'
|
||||
major_locator = mdates.MonthLocator()
|
||||
minor_locator = mdates.DayLocator(interval=7)
|
||||
elif time_span.days > 7: # 超过7天,按天显示
|
||||
date_format = '%Y-%m-%d'
|
||||
major_locator = mdates.DayLocator(interval=1)
|
||||
minor_locator = mdates.HourLocator(interval=12)
|
||||
else: # 7天内,按小时显示
|
||||
date_format = '%Y-%m-%d %H:%M'
|
||||
major_locator = mdates.HourLocator(interval=6)
|
||||
minor_locator = mdates.HourLocator(interval=1)
|
||||
|
||||
# 创建图形
|
||||
fig, ax = plt.subplots(figsize=(14, 10))
|
||||
|
||||
# 为每个聊天分配不同颜色
|
||||
colors = plt.cm.Set3(np.linspace(0, 1, len(chat_groups)))
|
||||
|
||||
for i, (chat_id, chat_data) in enumerate(chat_groups.items()):
|
||||
create_dates = [item[0] for item in chat_data]
|
||||
counts = [item[1] for item in chat_data]
|
||||
dates = [datetime.fromtimestamp(ts) for ts in create_dates]
|
||||
|
||||
chat_name = get_chat_name(chat_id)
|
||||
# 截断过长的聊天名称
|
||||
display_name = chat_name[:20] + "..." if len(chat_name) > 20 else chat_name
|
||||
|
||||
ax.scatter(dates, counts, alpha=0.7, s=40,
|
||||
c=[colors[i]], label=f"{display_name} ({len(chat_data)}个)",
|
||||
edgecolors='black', linewidth=0.5)
|
||||
|
||||
# 设置标签和标题
|
||||
ax.set_xlabel('创建日期 (Create Date)', fontsize=12)
|
||||
ax.set_ylabel('使用次数 (Count)', fontsize=12)
|
||||
ax.set_title('按聊天分组的表达式使用次数散点图', fontsize=14, fontweight='bold')
|
||||
|
||||
# 设置x轴日期格式 - 根据时间跨度自动调整
|
||||
ax.xaxis.set_major_formatter(mdates.DateFormatter(date_format))
|
||||
ax.xaxis.set_major_locator(major_locator)
|
||||
ax.xaxis.set_minor_locator(minor_locator)
|
||||
plt.xticks(rotation=45)
|
||||
|
||||
# 添加图例
|
||||
ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=8)
|
||||
|
||||
# 添加网格
|
||||
ax.grid(True, alpha=0.3)
|
||||
|
||||
# 调整布局
|
||||
plt.tight_layout()
|
||||
|
||||
# 显示统计信息
|
||||
print(f"\n=== 分组统计 ===")
|
||||
print(f"总聊天数量: {len(chat_groups)}")
|
||||
for chat_id, chat_data in chat_groups.items():
|
||||
chat_name = get_chat_name(chat_id)
|
||||
counts = [item[1] for item in chat_data]
|
||||
print(f"{chat_name}: {len(chat_data)}个表达式, 平均使用次数: {np.mean(counts):.2f}")
|
||||
|
||||
# 保存图片
|
||||
if save_path:
|
||||
plt.savefig(save_path, dpi=300, bbox_inches='tight')
|
||||
print(f"\n分组散点图已保存到: {save_path}")
|
||||
|
||||
# 显示图片
|
||||
plt.show()
|
||||
|
||||
|
||||
def create_type_scatter_plot(data: List[Tuple[float, float, str, str]], save_path: str = None):
|
||||
"""创建按表达式类型分组的散点图"""
|
||||
if not data:
|
||||
print("没有找到有效的表达式数据")
|
||||
return
|
||||
|
||||
# 按type分组
|
||||
type_groups = {}
|
||||
for item in data:
|
||||
expr_type = item[3]
|
||||
if expr_type not in type_groups:
|
||||
type_groups[expr_type] = []
|
||||
type_groups[expr_type].append(item)
|
||||
|
||||
# 计算时间跨度,自动调整显示格式
|
||||
all_dates = [datetime.fromtimestamp(item[0]) for item in data]
|
||||
time_span = max(all_dates) - min(all_dates)
|
||||
if time_span.days > 30: # 超过30天,按月显示
|
||||
date_format = '%Y-%m-%d'
|
||||
major_locator = mdates.MonthLocator()
|
||||
minor_locator = mdates.DayLocator(interval=7)
|
||||
elif time_span.days > 7: # 超过7天,按天显示
|
||||
date_format = '%Y-%m-%d'
|
||||
major_locator = mdates.DayLocator(interval=1)
|
||||
minor_locator = mdates.HourLocator(interval=12)
|
||||
else: # 7天内,按小时显示
|
||||
date_format = '%Y-%m-%d %H:%M'
|
||||
major_locator = mdates.HourLocator(interval=6)
|
||||
minor_locator = mdates.HourLocator(interval=1)
|
||||
|
||||
# 创建图形
|
||||
fig, ax = plt.subplots(figsize=(12, 8))
|
||||
|
||||
# 为每个类型分配不同颜色
|
||||
colors = plt.cm.tab10(np.linspace(0, 1, len(type_groups)))
|
||||
|
||||
for i, (expr_type, type_data) in enumerate(type_groups.items()):
|
||||
create_dates = [item[0] for item in type_data]
|
||||
counts = [item[1] for item in type_data]
|
||||
dates = [datetime.fromtimestamp(ts) for ts in create_dates]
|
||||
|
||||
ax.scatter(dates, counts, alpha=0.7, s=40,
|
||||
c=[colors[i]], label=f"{expr_type} ({len(type_data)}个)",
|
||||
edgecolors='black', linewidth=0.5)
|
||||
|
||||
# 设置标签和标题
|
||||
ax.set_xlabel('创建日期 (Create Date)', fontsize=12)
|
||||
ax.set_ylabel('使用次数 (Count)', fontsize=12)
|
||||
ax.set_title('按表达式类型分组的散点图', fontsize=14, fontweight='bold')
|
||||
|
||||
# 设置x轴日期格式 - 根据时间跨度自动调整
|
||||
ax.xaxis.set_major_formatter(mdates.DateFormatter(date_format))
|
||||
ax.xaxis.set_major_locator(major_locator)
|
||||
ax.xaxis.set_minor_locator(minor_locator)
|
||||
plt.xticks(rotation=45)
|
||||
|
||||
# 添加图例
|
||||
ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
|
||||
|
||||
# 添加网格
|
||||
ax.grid(True, alpha=0.3)
|
||||
|
||||
# 调整布局
|
||||
plt.tight_layout()
|
||||
|
||||
# 显示统计信息
|
||||
print(f"\n=== 类型统计 ===")
|
||||
for expr_type, type_data in type_groups.items():
|
||||
counts = [item[1] for item in type_data]
|
||||
print(f"{expr_type}: {len(type_data)}个表达式, 平均使用次数: {np.mean(counts):.2f}")
|
||||
|
||||
# 保存图片
|
||||
if save_path:
|
||||
plt.savefig(save_path, dpi=300, bbox_inches='tight')
|
||||
print(f"\n类型散点图已保存到: {save_path}")
|
||||
|
||||
# 显示图片
|
||||
plt.show()
|
||||
|
||||
|
||||
def main():
|
||||
"""主函数"""
|
||||
print("开始分析表达式数据...")
|
||||
|
||||
# 获取数据
|
||||
data = get_expression_data()
|
||||
|
||||
if not data:
|
||||
print("没有找到有效的表达式数据(create_date不为空的数据)")
|
||||
return
|
||||
|
||||
print(f"找到 {len(data)} 条有效数据")
|
||||
|
||||
# 创建输出目录
|
||||
output_dir = os.path.join(project_root, "data", "temp")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
# 生成时间戳用于文件名
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
|
||||
# 1. 创建基础散点图
|
||||
print("\n1. 创建基础散点图...")
|
||||
create_scatter_plot(data, os.path.join(output_dir, f"expression_scatter_{timestamp}.png"))
|
||||
|
||||
# 2. 创建按聊天分组的散点图
|
||||
print("\n2. 创建按聊天分组的散点图...")
|
||||
create_grouped_scatter_plot(data, os.path.join(output_dir, f"expression_scatter_by_chat_{timestamp}.png"))
|
||||
|
||||
# 3. 创建按类型分组的散点图
|
||||
print("\n3. 创建按类型分组的散点图...")
|
||||
create_type_scatter_plot(data, os.path.join(output_dir, f"expression_scatter_by_type_{timestamp}.png"))
|
||||
|
||||
print("\n分析完成!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -5,12 +5,11 @@ from typing import Dict, List
|
||||
|
||||
# Add project root to Python path
|
||||
from src.common.database.database_model import Expression, ChatStreams
|
||||
|
||||
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
sys.path.insert(0, project_root)
|
||||
|
||||
|
||||
|
||||
|
||||
def get_chat_name(chat_id: str) -> str:
|
||||
"""Get chat name from chat_id by querying ChatStreams table directly"""
|
||||
try:
|
||||
@@ -18,7 +17,7 @@ def get_chat_name(chat_id: str) -> str:
|
||||
chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id)
|
||||
if chat_stream is None:
|
||||
return f"未知聊天 ({chat_id})"
|
||||
|
||||
|
||||
# 如果有群组信息,显示群组名称
|
||||
if chat_stream.group_name:
|
||||
return f"{chat_stream.group_name} ({chat_id})"
|
||||
@@ -35,117 +34,106 @@ def calculate_time_distribution(expressions) -> Dict[str, int]:
|
||||
"""Calculate distribution of last active time in days"""
|
||||
now = time.time()
|
||||
distribution = {
|
||||
'0-1天': 0,
|
||||
'1-3天': 0,
|
||||
'3-7天': 0,
|
||||
'7-14天': 0,
|
||||
'14-30天': 0,
|
||||
'30-60天': 0,
|
||||
'60-90天': 0,
|
||||
'90+天': 0
|
||||
"0-1天": 0,
|
||||
"1-3天": 0,
|
||||
"3-7天": 0,
|
||||
"7-14天": 0,
|
||||
"14-30天": 0,
|
||||
"30-60天": 0,
|
||||
"60-90天": 0,
|
||||
"90+天": 0,
|
||||
}
|
||||
for expr in expressions:
|
||||
diff_days = (now - expr.last_active_time) / (24*3600)
|
||||
diff_days = (now - expr.last_active_time) / (24 * 3600)
|
||||
if diff_days < 1:
|
||||
distribution['0-1天'] += 1
|
||||
distribution["0-1天"] += 1
|
||||
elif diff_days < 3:
|
||||
distribution['1-3天'] += 1
|
||||
distribution["1-3天"] += 1
|
||||
elif diff_days < 7:
|
||||
distribution['3-7天'] += 1
|
||||
distribution["3-7天"] += 1
|
||||
elif diff_days < 14:
|
||||
distribution['7-14天'] += 1
|
||||
distribution["7-14天"] += 1
|
||||
elif diff_days < 30:
|
||||
distribution['14-30天'] += 1
|
||||
distribution["14-30天"] += 1
|
||||
elif diff_days < 60:
|
||||
distribution['30-60天'] += 1
|
||||
distribution["30-60天"] += 1
|
||||
elif diff_days < 90:
|
||||
distribution['60-90天'] += 1
|
||||
distribution["60-90天"] += 1
|
||||
else:
|
||||
distribution['90+天'] += 1
|
||||
distribution["90+天"] += 1
|
||||
return distribution
|
||||
|
||||
|
||||
def calculate_count_distribution(expressions) -> Dict[str, int]:
|
||||
"""Calculate distribution of count values"""
|
||||
distribution = {
|
||||
'0-1': 0,
|
||||
'1-2': 0,
|
||||
'2-3': 0,
|
||||
'3-4': 0,
|
||||
'4-5': 0,
|
||||
'5-10': 0,
|
||||
'10+': 0
|
||||
}
|
||||
distribution = {"0-1": 0, "1-2": 0, "2-3": 0, "3-4": 0, "4-5": 0, "5-10": 0, "10+": 0}
|
||||
for expr in expressions:
|
||||
cnt = expr.count
|
||||
if cnt < 1:
|
||||
distribution['0-1'] += 1
|
||||
distribution["0-1"] += 1
|
||||
elif cnt < 2:
|
||||
distribution['1-2'] += 1
|
||||
distribution["1-2"] += 1
|
||||
elif cnt < 3:
|
||||
distribution['2-3'] += 1
|
||||
distribution["2-3"] += 1
|
||||
elif cnt < 4:
|
||||
distribution['3-4'] += 1
|
||||
distribution["3-4"] += 1
|
||||
elif cnt < 5:
|
||||
distribution['4-5'] += 1
|
||||
distribution["4-5"] += 1
|
||||
elif cnt < 10:
|
||||
distribution['5-10'] += 1
|
||||
distribution["5-10"] += 1
|
||||
else:
|
||||
distribution['10+'] += 1
|
||||
distribution["10+"] += 1
|
||||
return distribution
|
||||
|
||||
|
||||
def get_top_expressions_by_chat(chat_id: str, top_n: int = 5) -> List[Expression]:
|
||||
"""Get top N most used expressions for a specific chat_id"""
|
||||
return (Expression.select()
|
||||
.where(Expression.chat_id == chat_id)
|
||||
.order_by(Expression.count.desc())
|
||||
.limit(top_n))
|
||||
return Expression.select().where(Expression.chat_id == chat_id).order_by(Expression.count.desc()).limit(top_n)
|
||||
|
||||
|
||||
def show_overall_statistics(expressions, total: int) -> None:
|
||||
"""Show overall statistics"""
|
||||
time_dist = calculate_time_distribution(expressions)
|
||||
count_dist = calculate_count_distribution(expressions)
|
||||
|
||||
|
||||
print("\n=== 总体统计 ===")
|
||||
print(f"总表达式数量: {total}")
|
||||
|
||||
|
||||
print("\n上次激活时间分布:")
|
||||
for period, count in time_dist.items():
|
||||
print(f"{period}: {count} ({count/total*100:.2f}%)")
|
||||
|
||||
print(f"{period}: {count} ({count / total * 100:.2f}%)")
|
||||
|
||||
print("\ncount分布:")
|
||||
for range_, count in count_dist.items():
|
||||
print(f"{range_}: {count} ({count/total*100:.2f}%)")
|
||||
print(f"{range_}: {count} ({count / total * 100:.2f}%)")
|
||||
|
||||
|
||||
def show_chat_statistics(chat_id: str, chat_name: str) -> None:
|
||||
"""Show statistics for a specific chat"""
|
||||
chat_exprs = list(Expression.select().where(Expression.chat_id == chat_id))
|
||||
chat_total = len(chat_exprs)
|
||||
|
||||
|
||||
print(f"\n=== {chat_name} ===")
|
||||
print(f"表达式数量: {chat_total}")
|
||||
|
||||
|
||||
if chat_total == 0:
|
||||
print("该聊天没有表达式数据")
|
||||
return
|
||||
|
||||
|
||||
# Time distribution for this chat
|
||||
time_dist = calculate_time_distribution(chat_exprs)
|
||||
print("\n上次激活时间分布:")
|
||||
for period, count in time_dist.items():
|
||||
if count > 0:
|
||||
print(f"{period}: {count} ({count/chat_total*100:.2f}%)")
|
||||
|
||||
print(f"{period}: {count} ({count / chat_total * 100:.2f}%)")
|
||||
|
||||
# Count distribution for this chat
|
||||
count_dist = calculate_count_distribution(chat_exprs)
|
||||
print("\ncount分布:")
|
||||
for range_, count in count_dist.items():
|
||||
if count > 0:
|
||||
print(f"{range_}: {count} ({count/chat_total*100:.2f}%)")
|
||||
|
||||
print(f"{range_}: {count} ({count / chat_total * 100:.2f}%)")
|
||||
|
||||
# Top expressions
|
||||
print("\nTop 10使用最多的表达式:")
|
||||
top_exprs = get_top_expressions_by_chat(chat_id, 10)
|
||||
@@ -163,32 +151,32 @@ def interactive_menu() -> None:
|
||||
if not expressions:
|
||||
print("数据库中没有找到表达式")
|
||||
return
|
||||
|
||||
|
||||
total = len(expressions)
|
||||
|
||||
|
||||
# Get unique chat_ids and their names
|
||||
chat_ids = list(set(expr.chat_id for expr in expressions))
|
||||
chat_info = [(chat_id, get_chat_name(chat_id)) for chat_id in chat_ids]
|
||||
chat_info.sort(key=lambda x: x[1]) # Sort by chat name
|
||||
|
||||
|
||||
while True:
|
||||
print("\n" + "="*50)
|
||||
print("\n" + "=" * 50)
|
||||
print("表达式统计分析")
|
||||
print("="*50)
|
||||
print("=" * 50)
|
||||
print("0. 显示总体统计")
|
||||
|
||||
|
||||
for i, (chat_id, chat_name) in enumerate(chat_info, 1):
|
||||
chat_count = sum(1 for expr in expressions if expr.chat_id == chat_id)
|
||||
print(f"{i}. {chat_name} ({chat_count}个表达式)")
|
||||
|
||||
|
||||
print("q. 退出")
|
||||
|
||||
|
||||
choice = input("\n请选择要查看的统计 (输入序号): ").strip()
|
||||
|
||||
if choice.lower() == 'q':
|
||||
|
||||
if choice.lower() == "q":
|
||||
print("再见!")
|
||||
break
|
||||
|
||||
|
||||
try:
|
||||
choice_num = int(choice)
|
||||
if choice_num == 0:
|
||||
@@ -200,9 +188,9 @@ def interactive_menu() -> None:
|
||||
print("无效的选择,请重新输入")
|
||||
except ValueError:
|
||||
print("请输入有效的数字")
|
||||
|
||||
|
||||
input("\n按回车键继续...")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
interactive_menu()
|
||||
interactive_menu()
|
||||
|
||||
@@ -23,6 +23,7 @@ OPENIE_DIR = os.path.join(ROOT_PATH, "data", "openie")
|
||||
|
||||
logger = get_logger("OpenIE导入")
|
||||
|
||||
|
||||
def ensure_openie_dir():
|
||||
"""确保OpenIE数据目录存在"""
|
||||
if not os.path.exists(OPENIE_DIR):
|
||||
@@ -253,7 +254,7 @@ def main():
|
||||
# 没有运行的事件循环,创建新的
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
|
||||
try:
|
||||
# 在新的事件循环中运行异步主函数
|
||||
loop.run_until_complete(main_async())
|
||||
|
||||
@@ -12,6 +12,7 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
from rich.progress import Progress # 替换为 rich 进度条
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
# from src.chat.knowledge.lpmmconfig import global_config
|
||||
from src.chat.knowledge.ie_process import info_extract_from_str
|
||||
from src.chat.knowledge.open_ie import OpenIE
|
||||
@@ -36,6 +37,7 @@ TEMP_DIR = os.path.join(ROOT_PATH, "temp")
|
||||
# IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, "data", "imported_lpmm_data")
|
||||
OPENIE_OUTPUT_DIR = os.path.join(ROOT_PATH, "data", "openie")
|
||||
|
||||
|
||||
def ensure_dirs():
|
||||
"""确保临时目录和输出目录存在"""
|
||||
if not os.path.exists(TEMP_DIR):
|
||||
@@ -48,6 +50,7 @@ def ensure_dirs():
|
||||
os.makedirs(RAW_DATA_PATH)
|
||||
logger.info(f"已创建原始数据目录: {RAW_DATA_PATH}")
|
||||
|
||||
|
||||
# 创建一个线程安全的锁,用于保护文件操作和共享数据
|
||||
file_lock = Lock()
|
||||
open_ie_doc_lock = Lock()
|
||||
@@ -56,13 +59,11 @@ open_ie_doc_lock = Lock()
|
||||
shutdown_event = Event()
|
||||
|
||||
lpmm_entity_extract_llm = LLMRequest(
|
||||
model_set=model_config.model_task_config.lpmm_entity_extract,
|
||||
request_type="lpmm.entity_extract"
|
||||
)
|
||||
lpmm_rdf_build_llm = LLMRequest(
|
||||
model_set=model_config.model_task_config.lpmm_rdf_build,
|
||||
request_type="lpmm.rdf_build"
|
||||
model_set=model_config.model_task_config.lpmm_entity_extract, request_type="lpmm.entity_extract"
|
||||
)
|
||||
lpmm_rdf_build_llm = LLMRequest(model_set=model_config.model_task_config.lpmm_rdf_build, request_type="lpmm.rdf_build")
|
||||
|
||||
|
||||
def process_single_text(pg_hash, raw_data):
|
||||
"""处理单个文本的函数,用于线程池"""
|
||||
temp_file_path = f"{TEMP_DIR}/{pg_hash}.json"
|
||||
|
||||
@@ -1,287 +0,0 @@
|
||||
import time
|
||||
import sys
|
||||
import os
|
||||
from typing import Dict, List, Tuple, Optional
|
||||
from datetime import datetime
|
||||
# Add project root to Python path
|
||||
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
sys.path.insert(0, project_root)
|
||||
from src.common.database.database_model import Messages, ChatStreams #noqa
|
||||
|
||||
|
||||
|
||||
|
||||
def get_chat_name(chat_id: str) -> str:
|
||||
"""Get chat name from chat_id by querying ChatStreams table directly"""
|
||||
try:
|
||||
chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id)
|
||||
if chat_stream is None:
|
||||
return f"未知聊天 ({chat_id})"
|
||||
|
||||
if chat_stream.group_name:
|
||||
return f"{chat_stream.group_name} ({chat_id})"
|
||||
elif chat_stream.user_nickname:
|
||||
return f"{chat_stream.user_nickname}的私聊 ({chat_id})"
|
||||
else:
|
||||
return f"未知聊天 ({chat_id})"
|
||||
except Exception:
|
||||
return f"查询失败 ({chat_id})"
|
||||
|
||||
|
||||
def format_timestamp(timestamp: float) -> str:
|
||||
"""Format timestamp to readable date string"""
|
||||
try:
|
||||
return datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %H:%M:%S")
|
||||
except (ValueError, OSError):
|
||||
return "未知时间"
|
||||
|
||||
|
||||
def calculate_interest_value_distribution(messages) -> Dict[str, int]:
|
||||
"""Calculate distribution of interest_value"""
|
||||
distribution = {
|
||||
'0.000-0.010': 0,
|
||||
'0.010-0.050': 0,
|
||||
'0.050-0.100': 0,
|
||||
'0.100-0.500': 0,
|
||||
'0.500-1.000': 0,
|
||||
'1.000-2.000': 0,
|
||||
'2.000-5.000': 0,
|
||||
'5.000-10.000': 0,
|
||||
'10.000+': 0
|
||||
}
|
||||
|
||||
for msg in messages:
|
||||
if msg.interest_value is None or msg.interest_value == 0.0:
|
||||
continue
|
||||
|
||||
value = float(msg.interest_value)
|
||||
if value < 0.010:
|
||||
distribution['0.000-0.010'] += 1
|
||||
elif value < 0.050:
|
||||
distribution['0.010-0.050'] += 1
|
||||
elif value < 0.100:
|
||||
distribution['0.050-0.100'] += 1
|
||||
elif value < 0.500:
|
||||
distribution['0.100-0.500'] += 1
|
||||
elif value < 1.000:
|
||||
distribution['0.500-1.000'] += 1
|
||||
elif value < 2.000:
|
||||
distribution['1.000-2.000'] += 1
|
||||
elif value < 5.000:
|
||||
distribution['2.000-5.000'] += 1
|
||||
elif value < 10.000:
|
||||
distribution['5.000-10.000'] += 1
|
||||
else:
|
||||
distribution['10.000+'] += 1
|
||||
|
||||
return distribution
|
||||
|
||||
|
||||
def get_interest_value_stats(messages) -> Dict[str, float]:
|
||||
"""Calculate basic statistics for interest_value"""
|
||||
values = [float(msg.interest_value) for msg in messages if msg.interest_value is not None and msg.interest_value != 0.0]
|
||||
|
||||
if not values:
|
||||
return {
|
||||
'count': 0,
|
||||
'min': 0,
|
||||
'max': 0,
|
||||
'avg': 0,
|
||||
'median': 0
|
||||
}
|
||||
|
||||
values.sort()
|
||||
count = len(values)
|
||||
|
||||
return {
|
||||
'count': count,
|
||||
'min': min(values),
|
||||
'max': max(values),
|
||||
'avg': sum(values) / count,
|
||||
'median': values[count // 2] if count % 2 == 1 else (values[count // 2 - 1] + values[count // 2]) / 2
|
||||
}
|
||||
|
||||
|
||||
def get_available_chats() -> List[Tuple[str, str, int]]:
|
||||
"""Get all available chats with message counts"""
|
||||
try:
|
||||
# 获取所有有消息的chat_id
|
||||
chat_counts = {}
|
||||
for msg in Messages.select(Messages.chat_id).distinct():
|
||||
chat_id = msg.chat_id
|
||||
count = Messages.select().where(
|
||||
(Messages.chat_id == chat_id) &
|
||||
(Messages.interest_value.is_null(False)) &
|
||||
(Messages.interest_value != 0.0)
|
||||
).count()
|
||||
if count > 0:
|
||||
chat_counts[chat_id] = count
|
||||
|
||||
# 获取聊天名称
|
||||
result = []
|
||||
for chat_id, count in chat_counts.items():
|
||||
chat_name = get_chat_name(chat_id)
|
||||
result.append((chat_id, chat_name, count))
|
||||
|
||||
# 按消息数量排序
|
||||
result.sort(key=lambda x: x[2], reverse=True)
|
||||
return result
|
||||
except Exception as e:
|
||||
print(f"获取聊天列表失败: {e}")
|
||||
return []
|
||||
|
||||
|
||||
def get_time_range_input() -> Tuple[Optional[float], Optional[float]]:
|
||||
"""Get time range input from user"""
|
||||
print("\n时间范围选择:")
|
||||
print("1. 最近1天")
|
||||
print("2. 最近3天")
|
||||
print("3. 最近7天")
|
||||
print("4. 最近30天")
|
||||
print("5. 自定义时间范围")
|
||||
print("6. 不限制时间")
|
||||
|
||||
choice = input("请选择时间范围 (1-6): ").strip()
|
||||
|
||||
now = time.time()
|
||||
|
||||
if choice == "1":
|
||||
return now - 24*3600, now
|
||||
elif choice == "2":
|
||||
return now - 3*24*3600, now
|
||||
elif choice == "3":
|
||||
return now - 7*24*3600, now
|
||||
elif choice == "4":
|
||||
return now - 30*24*3600, now
|
||||
elif choice == "5":
|
||||
print("请输入开始时间 (格式: YYYY-MM-DD HH:MM:SS):")
|
||||
start_str = input().strip()
|
||||
print("请输入结束时间 (格式: YYYY-MM-DD HH:MM:SS):")
|
||||
end_str = input().strip()
|
||||
|
||||
try:
|
||||
start_time = datetime.strptime(start_str, "%Y-%m-%d %H:%M:%S").timestamp()
|
||||
end_time = datetime.strptime(end_str, "%Y-%m-%d %H:%M:%S").timestamp()
|
||||
return start_time, end_time
|
||||
except ValueError:
|
||||
print("时间格式错误,将不限制时间范围")
|
||||
return None, None
|
||||
else:
|
||||
return None, None
|
||||
|
||||
|
||||
def analyze_interest_values(chat_id: Optional[str] = None, start_time: Optional[float] = None, end_time: Optional[float] = None) -> None:
|
||||
"""Analyze interest values with optional filters"""
|
||||
|
||||
# 构建查询条件
|
||||
query = Messages.select().where(
|
||||
(Messages.interest_value.is_null(False)) &
|
||||
(Messages.interest_value != 0.0)
|
||||
)
|
||||
|
||||
if chat_id:
|
||||
query = query.where(Messages.chat_id == chat_id)
|
||||
|
||||
if start_time:
|
||||
query = query.where(Messages.time >= start_time)
|
||||
|
||||
if end_time:
|
||||
query = query.where(Messages.time <= end_time)
|
||||
|
||||
messages = list(query)
|
||||
|
||||
if not messages:
|
||||
print("没有找到符合条件的消息")
|
||||
return
|
||||
|
||||
# 计算统计信息
|
||||
distribution = calculate_interest_value_distribution(messages)
|
||||
stats = get_interest_value_stats(messages)
|
||||
|
||||
# 显示结果
|
||||
print("\n=== Interest Value 分析结果 ===")
|
||||
if chat_id:
|
||||
print(f"聊天: {get_chat_name(chat_id)}")
|
||||
else:
|
||||
print("聊天: 全部聊天")
|
||||
|
||||
if start_time and end_time:
|
||||
print(f"时间范围: {format_timestamp(start_time)} 到 {format_timestamp(end_time)}")
|
||||
elif start_time:
|
||||
print(f"时间范围: {format_timestamp(start_time)} 之后")
|
||||
elif end_time:
|
||||
print(f"时间范围: {format_timestamp(end_time)} 之前")
|
||||
else:
|
||||
print("时间范围: 不限制")
|
||||
|
||||
print("\n基本统计:")
|
||||
print(f"有效消息数量: {stats['count']} (排除null和0值)")
|
||||
print(f"最小值: {stats['min']:.3f}")
|
||||
print(f"最大值: {stats['max']:.3f}")
|
||||
print(f"平均值: {stats['avg']:.3f}")
|
||||
print(f"中位数: {stats['median']:.3f}")
|
||||
|
||||
print("\nInterest Value 分布:")
|
||||
total = stats['count']
|
||||
for range_name, count in distribution.items():
|
||||
if count > 0:
|
||||
percentage = count / total * 100
|
||||
print(f"{range_name}: {count} ({percentage:.2f}%)")
|
||||
|
||||
|
||||
def interactive_menu() -> None:
|
||||
"""Interactive menu for interest value analysis"""
|
||||
|
||||
while True:
|
||||
print("\n" + "="*50)
|
||||
print("Interest Value 分析工具")
|
||||
print("="*50)
|
||||
print("1. 分析全部聊天")
|
||||
print("2. 选择特定聊天分析")
|
||||
print("q. 退出")
|
||||
|
||||
choice = input("\n请选择分析模式 (1-2, q): ").strip()
|
||||
|
||||
if choice.lower() == 'q':
|
||||
print("再见!")
|
||||
break
|
||||
|
||||
chat_id = None
|
||||
|
||||
if choice == "2":
|
||||
# 显示可用的聊天列表
|
||||
chats = get_available_chats()
|
||||
if not chats:
|
||||
print("没有找到有interest_value数据的聊天")
|
||||
continue
|
||||
|
||||
print(f"\n可用的聊天 (共{len(chats)}个):")
|
||||
for i, (_cid, name, count) in enumerate(chats, 1):
|
||||
print(f"{i}. {name} ({count}条有效消息)")
|
||||
|
||||
try:
|
||||
chat_choice = int(input(f"\n请选择聊天 (1-{len(chats)}): ").strip())
|
||||
if 1 <= chat_choice <= len(chats):
|
||||
chat_id = chats[chat_choice - 1][0]
|
||||
else:
|
||||
print("无效选择")
|
||||
continue
|
||||
except ValueError:
|
||||
print("请输入有效数字")
|
||||
continue
|
||||
|
||||
elif choice != "1":
|
||||
print("无效选择")
|
||||
continue
|
||||
|
||||
# 获取时间范围
|
||||
start_time, end_time = get_time_range_input()
|
||||
|
||||
# 执行分析
|
||||
analyze_interest_values(chat_id, start_time, end_time)
|
||||
|
||||
input("\n按回车键继续...")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
interactive_menu()
|
||||
@@ -2,6 +2,7 @@ import os
|
||||
from pathlib import Path
|
||||
import sys # 新增系统模块导入
|
||||
from src.chat.knowledge.utils.hash import get_sha256
|
||||
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
from src.common.logger import get_logger
|
||||
|
||||
@@ -10,6 +11,7 @@ ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
RAW_DATA_PATH = os.path.join(ROOT_PATH, "data/lpmm_raw_data")
|
||||
# IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, "data/imported_lpmm_data")
|
||||
|
||||
|
||||
def _process_text_file(file_path):
|
||||
"""处理单个文本文件,返回段落列表"""
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
@@ -44,6 +46,7 @@ def _process_multi_files() -> list:
|
||||
all_paragraphs.extend(paragraphs)
|
||||
return all_paragraphs
|
||||
|
||||
|
||||
def load_raw_data() -> tuple[list[str], list[str]]:
|
||||
"""加载原始数据文件
|
||||
|
||||
@@ -72,4 +75,4 @@ def load_raw_data() -> tuple[list[str], list[str]]:
|
||||
raw_data.append(item)
|
||||
logger.info(f"共读取到{len(raw_data)}条数据")
|
||||
|
||||
return sha256_list, raw_data
|
||||
return sha256_list, raw_data
|
||||
|
||||
@@ -1,394 +0,0 @@
|
||||
import time
|
||||
import sys
|
||||
import os
|
||||
import re
|
||||
from typing import Dict, List, Tuple, Optional
|
||||
from datetime import datetime
|
||||
# Add project root to Python path
|
||||
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
sys.path.insert(0, project_root)
|
||||
from src.common.database.database_model import Messages, ChatStreams #noqa
|
||||
|
||||
|
||||
def contains_emoji_or_image_tags(text: str) -> bool:
|
||||
"""Check if text contains [表情包xxxxx] or [图片xxxxx] tags"""
|
||||
if not text:
|
||||
return False
|
||||
|
||||
# 检查是否包含 [表情包] 或 [图片] 标记
|
||||
emoji_pattern = r'\[表情包[^\]]*\]'
|
||||
image_pattern = r'\[图片[^\]]*\]'
|
||||
|
||||
return bool(re.search(emoji_pattern, text) or re.search(image_pattern, text))
|
||||
|
||||
|
||||
def clean_reply_text(text: str) -> str:
|
||||
"""Remove reply references like [回复 xxxx...] from text"""
|
||||
if not text:
|
||||
return text
|
||||
|
||||
# 匹配 [回复 xxxx...] 格式的内容
|
||||
# 使用非贪婪匹配,匹配到第一个 ] 就停止
|
||||
cleaned_text = re.sub(r'\[回复[^\]]*\]', '', text)
|
||||
|
||||
# 去除多余的空白字符
|
||||
cleaned_text = cleaned_text.strip()
|
||||
|
||||
return cleaned_text
|
||||
|
||||
|
||||
def get_chat_name(chat_id: str) -> str:
|
||||
"""Get chat name from chat_id by querying ChatStreams table directly"""
|
||||
try:
|
||||
chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id)
|
||||
if chat_stream is None:
|
||||
return f"未知聊天 ({chat_id})"
|
||||
|
||||
if chat_stream.group_name:
|
||||
return f"{chat_stream.group_name} ({chat_id})"
|
||||
elif chat_stream.user_nickname:
|
||||
return f"{chat_stream.user_nickname}的私聊 ({chat_id})"
|
||||
else:
|
||||
return f"未知聊天 ({chat_id})"
|
||||
except Exception:
|
||||
return f"查询失败 ({chat_id})"
|
||||
|
||||
|
||||
def format_timestamp(timestamp: float) -> str:
|
||||
"""Format timestamp to readable date string"""
|
||||
try:
|
||||
return datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %H:%M:%S")
|
||||
except (ValueError, OSError):
|
||||
return "未知时间"
|
||||
|
||||
|
||||
def calculate_text_length_distribution(messages) -> Dict[str, int]:
|
||||
"""Calculate distribution of processed_plain_text length"""
|
||||
distribution = {
|
||||
'0': 0, # 空文本
|
||||
'1-5': 0, # 极短文本
|
||||
'6-10': 0, # 很短文本
|
||||
'11-20': 0, # 短文本
|
||||
'21-30': 0, # 较短文本
|
||||
'31-50': 0, # 中短文本
|
||||
'51-70': 0, # 中等文本
|
||||
'71-100': 0, # 较长文本
|
||||
'101-150': 0, # 长文本
|
||||
'151-200': 0, # 很长文本
|
||||
'201-300': 0, # 超长文本
|
||||
'301-500': 0, # 极长文本
|
||||
'501-1000': 0, # 巨长文本
|
||||
'1000+': 0 # 超巨长文本
|
||||
}
|
||||
|
||||
for msg in messages:
|
||||
if msg.processed_plain_text is None:
|
||||
continue
|
||||
|
||||
# 排除包含表情包或图片标记的消息
|
||||
if contains_emoji_or_image_tags(msg.processed_plain_text):
|
||||
continue
|
||||
|
||||
# 清理文本中的回复引用
|
||||
cleaned_text = clean_reply_text(msg.processed_plain_text)
|
||||
length = len(cleaned_text)
|
||||
|
||||
if length == 0:
|
||||
distribution['0'] += 1
|
||||
elif length <= 5:
|
||||
distribution['1-5'] += 1
|
||||
elif length <= 10:
|
||||
distribution['6-10'] += 1
|
||||
elif length <= 20:
|
||||
distribution['11-20'] += 1
|
||||
elif length <= 30:
|
||||
distribution['21-30'] += 1
|
||||
elif length <= 50:
|
||||
distribution['31-50'] += 1
|
||||
elif length <= 70:
|
||||
distribution['51-70'] += 1
|
||||
elif length <= 100:
|
||||
distribution['71-100'] += 1
|
||||
elif length <= 150:
|
||||
distribution['101-150'] += 1
|
||||
elif length <= 200:
|
||||
distribution['151-200'] += 1
|
||||
elif length <= 300:
|
||||
distribution['201-300'] += 1
|
||||
elif length <= 500:
|
||||
distribution['301-500'] += 1
|
||||
elif length <= 1000:
|
||||
distribution['501-1000'] += 1
|
||||
else:
|
||||
distribution['1000+'] += 1
|
||||
|
||||
return distribution
|
||||
|
||||
|
||||
def get_text_length_stats(messages) -> Dict[str, float]:
|
||||
"""Calculate basic statistics for processed_plain_text length"""
|
||||
lengths = []
|
||||
null_count = 0
|
||||
excluded_count = 0 # 被排除的消息数量
|
||||
|
||||
for msg in messages:
|
||||
if msg.processed_plain_text is None:
|
||||
null_count += 1
|
||||
elif contains_emoji_or_image_tags(msg.processed_plain_text):
|
||||
# 排除包含表情包或图片标记的消息
|
||||
excluded_count += 1
|
||||
else:
|
||||
# 清理文本中的回复引用
|
||||
cleaned_text = clean_reply_text(msg.processed_plain_text)
|
||||
lengths.append(len(cleaned_text))
|
||||
|
||||
if not lengths:
|
||||
return {
|
||||
'count': 0,
|
||||
'null_count': null_count,
|
||||
'excluded_count': excluded_count,
|
||||
'min': 0,
|
||||
'max': 0,
|
||||
'avg': 0,
|
||||
'median': 0
|
||||
}
|
||||
|
||||
lengths.sort()
|
||||
count = len(lengths)
|
||||
|
||||
return {
|
||||
'count': count,
|
||||
'null_count': null_count,
|
||||
'excluded_count': excluded_count,
|
||||
'min': min(lengths),
|
||||
'max': max(lengths),
|
||||
'avg': sum(lengths) / count,
|
||||
'median': lengths[count // 2] if count % 2 == 1 else (lengths[count // 2 - 1] + lengths[count // 2]) / 2
|
||||
}
|
||||
|
||||
|
||||
def get_available_chats() -> List[Tuple[str, str, int]]:
|
||||
"""Get all available chats with message counts"""
|
||||
try:
|
||||
# 获取所有有消息的chat_id,排除特殊类型消息
|
||||
chat_counts = {}
|
||||
for msg in Messages.select(Messages.chat_id).distinct():
|
||||
chat_id = msg.chat_id
|
||||
count = Messages.select().where(
|
||||
(Messages.chat_id == chat_id) &
|
||||
(Messages.is_emoji != 1) &
|
||||
(Messages.is_picid != 1) &
|
||||
(Messages.is_command != 1)
|
||||
).count()
|
||||
if count > 0:
|
||||
chat_counts[chat_id] = count
|
||||
|
||||
# 获取聊天名称
|
||||
result = []
|
||||
for chat_id, count in chat_counts.items():
|
||||
chat_name = get_chat_name(chat_id)
|
||||
result.append((chat_id, chat_name, count))
|
||||
|
||||
# 按消息数量排序
|
||||
result.sort(key=lambda x: x[2], reverse=True)
|
||||
return result
|
||||
except Exception as e:
|
||||
print(f"获取聊天列表失败: {e}")
|
||||
return []
|
||||
|
||||
|
||||
def get_time_range_input() -> Tuple[Optional[float], Optional[float]]:
|
||||
"""Get time range input from user"""
|
||||
print("\n时间范围选择:")
|
||||
print("1. 最近1天")
|
||||
print("2. 最近3天")
|
||||
print("3. 最近7天")
|
||||
print("4. 最近30天")
|
||||
print("5. 自定义时间范围")
|
||||
print("6. 不限制时间")
|
||||
|
||||
choice = input("请选择时间范围 (1-6): ").strip()
|
||||
|
||||
now = time.time()
|
||||
|
||||
if choice == "1":
|
||||
return now - 24*3600, now
|
||||
elif choice == "2":
|
||||
return now - 3*24*3600, now
|
||||
elif choice == "3":
|
||||
return now - 7*24*3600, now
|
||||
elif choice == "4":
|
||||
return now - 30*24*3600, now
|
||||
elif choice == "5":
|
||||
print("请输入开始时间 (格式: YYYY-MM-DD HH:MM:SS):")
|
||||
start_str = input().strip()
|
||||
print("请输入结束时间 (格式: YYYY-MM-DD HH:MM:SS):")
|
||||
end_str = input().strip()
|
||||
|
||||
try:
|
||||
start_time = datetime.strptime(start_str, "%Y-%m-%d %H:%M:%S").timestamp()
|
||||
end_time = datetime.strptime(end_str, "%Y-%m-%d %H:%M:%S").timestamp()
|
||||
return start_time, end_time
|
||||
except ValueError:
|
||||
print("时间格式错误,将不限制时间范围")
|
||||
return None, None
|
||||
else:
|
||||
return None, None
|
||||
|
||||
|
||||
def get_top_longest_messages(messages, top_n: int = 10) -> List[Tuple[str, int, str, str]]:
|
||||
"""Get top N longest messages"""
|
||||
message_lengths = []
|
||||
|
||||
for msg in messages:
|
||||
if msg.processed_plain_text is not None:
|
||||
# 排除包含表情包或图片标记的消息
|
||||
if contains_emoji_or_image_tags(msg.processed_plain_text):
|
||||
continue
|
||||
|
||||
# 清理文本中的回复引用
|
||||
cleaned_text = clean_reply_text(msg.processed_plain_text)
|
||||
length = len(cleaned_text)
|
||||
chat_name = get_chat_name(msg.chat_id)
|
||||
time_str = format_timestamp(msg.time)
|
||||
# 截取前100个字符作为预览
|
||||
preview = cleaned_text[:100] + "..." if len(cleaned_text) > 100 else cleaned_text
|
||||
message_lengths.append((chat_name, length, time_str, preview))
|
||||
|
||||
# 按长度排序,取前N个
|
||||
message_lengths.sort(key=lambda x: x[1], reverse=True)
|
||||
return message_lengths[:top_n]
|
||||
|
||||
|
||||
def analyze_text_lengths(chat_id: Optional[str] = None, start_time: Optional[float] = None, end_time: Optional[float] = None) -> None:
|
||||
"""Analyze processed_plain_text lengths with optional filters"""
|
||||
|
||||
# 构建查询条件,排除特殊类型的消息
|
||||
query = Messages.select().where(
|
||||
(Messages.is_emoji != 1) &
|
||||
(Messages.is_picid != 1) &
|
||||
(Messages.is_command != 1)
|
||||
)
|
||||
|
||||
if chat_id:
|
||||
query = query.where(Messages.chat_id == chat_id)
|
||||
|
||||
if start_time:
|
||||
query = query.where(Messages.time >= start_time)
|
||||
|
||||
if end_time:
|
||||
query = query.where(Messages.time <= end_time)
|
||||
|
||||
messages = list(query)
|
||||
|
||||
if not messages:
|
||||
print("没有找到符合条件的消息")
|
||||
return
|
||||
|
||||
# 计算统计信息
|
||||
distribution = calculate_text_length_distribution(messages)
|
||||
stats = get_text_length_stats(messages)
|
||||
top_longest = get_top_longest_messages(messages, 10)
|
||||
|
||||
# 显示结果
|
||||
print("\n=== Processed Plain Text 长度分析结果 ===")
|
||||
print("(已排除表情、图片ID、命令类型消息,已排除[表情包]和[图片]标记消息,已清理回复引用)")
|
||||
if chat_id:
|
||||
print(f"聊天: {get_chat_name(chat_id)}")
|
||||
else:
|
||||
print("聊天: 全部聊天")
|
||||
|
||||
if start_time and end_time:
|
||||
print(f"时间范围: {format_timestamp(start_time)} 到 {format_timestamp(end_time)}")
|
||||
elif start_time:
|
||||
print(f"时间范围: {format_timestamp(start_time)} 之后")
|
||||
elif end_time:
|
||||
print(f"时间范围: {format_timestamp(end_time)} 之前")
|
||||
else:
|
||||
print("时间范围: 不限制")
|
||||
|
||||
print("\n基本统计:")
|
||||
print(f"总消息数量: {len(messages)}")
|
||||
print(f"有文本消息数量: {stats['count']}")
|
||||
print(f"空文本消息数量: {stats['null_count']}")
|
||||
print(f"被排除的消息数量: {stats['excluded_count']}")
|
||||
if stats['count'] > 0:
|
||||
print(f"最短长度: {stats['min']} 字符")
|
||||
print(f"最长长度: {stats['max']} 字符")
|
||||
print(f"平均长度: {stats['avg']:.2f} 字符")
|
||||
print(f"中位数长度: {stats['median']:.2f} 字符")
|
||||
|
||||
print("\n文本长度分布:")
|
||||
total = stats['count']
|
||||
if total > 0:
|
||||
for range_name, count in distribution.items():
|
||||
if count > 0:
|
||||
percentage = count / total * 100
|
||||
print(f"{range_name} 字符: {count} ({percentage:.2f}%)")
|
||||
|
||||
# 显示最长的消息
|
||||
if top_longest:
|
||||
print(f"\n最长的 {len(top_longest)} 条消息:")
|
||||
for i, (chat_name, length, time_str, preview) in enumerate(top_longest, 1):
|
||||
print(f"{i}. [{chat_name}] {time_str}")
|
||||
print(f" 长度: {length} 字符")
|
||||
print(f" 预览: {preview}")
|
||||
print()
|
||||
|
||||
|
||||
def interactive_menu() -> None:
|
||||
"""Interactive menu for text length analysis"""
|
||||
|
||||
while True:
|
||||
print("\n" + "="*50)
|
||||
print("Processed Plain Text 长度分析工具")
|
||||
print("="*50)
|
||||
print("1. 分析全部聊天")
|
||||
print("2. 选择特定聊天分析")
|
||||
print("q. 退出")
|
||||
|
||||
choice = input("\n请选择分析模式 (1-2, q): ").strip()
|
||||
|
||||
if choice.lower() == 'q':
|
||||
print("再见!")
|
||||
break
|
||||
|
||||
chat_id = None
|
||||
|
||||
if choice == "2":
|
||||
# 显示可用的聊天列表
|
||||
chats = get_available_chats()
|
||||
if not chats:
|
||||
print("没有找到聊天数据")
|
||||
continue
|
||||
|
||||
print(f"\n可用的聊天 (共{len(chats)}个):")
|
||||
for i, (_cid, name, count) in enumerate(chats, 1):
|
||||
print(f"{i}. {name} ({count}条消息)")
|
||||
|
||||
try:
|
||||
chat_choice = int(input(f"\n请选择聊天 (1-{len(chats)}): ").strip())
|
||||
if 1 <= chat_choice <= len(chats):
|
||||
chat_id = chats[chat_choice - 1][0]
|
||||
else:
|
||||
print("无效选择")
|
||||
continue
|
||||
except ValueError:
|
||||
print("请输入有效数字")
|
||||
continue
|
||||
|
||||
elif choice != "1":
|
||||
print("无效选择")
|
||||
continue
|
||||
|
||||
# 获取时间范围
|
||||
start_time, end_time = get_time_range_input()
|
||||
|
||||
# 执行分析
|
||||
analyze_text_lengths(chat_id, start_time, end_time)
|
||||
|
||||
input("\n按回车键继续...")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
interactive_menu()
|
||||
570
src/chat/brain_chat/brain_chat.py
Normal file
570
src/chat/brain_chat/brain_chat.py
Normal file
@@ -0,0 +1,570 @@
|
||||
import asyncio
|
||||
import time
|
||||
import traceback
|
||||
import random
|
||||
from typing import List, Optional, Dict, Any, Tuple, TYPE_CHECKING
|
||||
from rich.traceback import install
|
||||
|
||||
from src.config.config import global_config
|
||||
from src.common.logger import get_logger
|
||||
from src.common.data_models.info_data_model import ActionPlannerInfo
|
||||
from src.common.data_models.message_data_model import ReplyContentType
|
||||
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
||||
from src.chat.utils.prompt_builder import global_prompt_manager
|
||||
from src.chat.utils.timer_calculator import Timer
|
||||
from src.chat.brain_chat.brain_planner import BrainPlanner
|
||||
from src.chat.planner_actions.action_modifier import ActionModifier
|
||||
from src.chat.planner_actions.action_manager import ActionManager
|
||||
from src.chat.heart_flow.hfc_utils import CycleDetail
|
||||
from src.chat.express.expression_learner import expression_learner_manager
|
||||
from src.person_info.person_info import Person
|
||||
from src.plugin_system.base.component_types import EventType, ActionInfo
|
||||
from src.plugin_system.core import events_manager
|
||||
from src.plugin_system.apis import generator_api, send_api, message_api, database_api
|
||||
from src.chat.utils.chat_message_builder import (
|
||||
build_readable_messages_with_id,
|
||||
get_raw_msg_before_timestamp_with_chat,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.common.data_models.message_data_model import ReplySetModel
|
||||
|
||||
|
||||
ERROR_LOOP_INFO = {
|
||||
"loop_plan_info": {
|
||||
"action_result": {
|
||||
"action_type": "error",
|
||||
"action_data": {},
|
||||
"reasoning": "循环处理失败",
|
||||
},
|
||||
},
|
||||
"loop_action_info": {
|
||||
"action_taken": False,
|
||||
"reply_text": "",
|
||||
"command": "",
|
||||
"taken_time": time.time(),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
# 注释:原来的动作修改超时常量已移除,因为改为顺序执行
|
||||
|
||||
logger = get_logger("bc") # Logger Name Changed
|
||||
|
||||
|
||||
class BrainChatting:
|
||||
"""
|
||||
管理一个连续的私聊Brain Chat循环
|
||||
用于在特定聊天流中生成回复。
|
||||
"""
|
||||
|
||||
def __init__(self, chat_id: str):
|
||||
"""
|
||||
BrainChatting 初始化函数
|
||||
|
||||
参数:
|
||||
chat_id: 聊天流唯一标识符(如stream_id)
|
||||
on_stop_focus_chat: 当收到stop_focus_chat命令时调用的回调函数
|
||||
performance_version: 性能记录版本号,用于区分不同启动版本
|
||||
"""
|
||||
# 基础属性
|
||||
self.stream_id: str = chat_id # 聊天流ID
|
||||
self.chat_stream: ChatStream = get_chat_manager().get_stream(self.stream_id) # type: ignore
|
||||
if not self.chat_stream:
|
||||
raise ValueError(f"无法找到聊天流: {self.stream_id}")
|
||||
self.log_prefix = f"[{get_chat_manager().get_stream_name(self.stream_id) or self.stream_id}]"
|
||||
|
||||
self.expression_learner = expression_learner_manager.get_expression_learner(self.stream_id)
|
||||
|
||||
self.action_manager = ActionManager()
|
||||
self.action_planner = BrainPlanner(chat_id=self.stream_id, action_manager=self.action_manager)
|
||||
self.action_modifier = ActionModifier(action_manager=self.action_manager, chat_id=self.stream_id)
|
||||
|
||||
# 循环控制内部状态
|
||||
self.running: bool = False
|
||||
self._loop_task: Optional[asyncio.Task] = None # 主循环任务
|
||||
|
||||
# 添加循环信息管理相关的属性
|
||||
self.history_loop: List[CycleDetail] = []
|
||||
self._cycle_counter = 0
|
||||
self._current_cycle_detail: CycleDetail = None # type: ignore
|
||||
|
||||
self.last_read_time = time.time() - 2
|
||||
|
||||
self.more_plan = False
|
||||
|
||||
async def start(self):
|
||||
"""检查是否需要启动主循环,如果未激活则启动。"""
|
||||
|
||||
# 如果循环已经激活,直接返回
|
||||
if self.running:
|
||||
logger.debug(f"{self.log_prefix} BrainChatting 已激活,无需重复启动")
|
||||
return
|
||||
|
||||
try:
|
||||
# 标记为活动状态,防止重复启动
|
||||
self.running = True
|
||||
|
||||
self._loop_task = asyncio.create_task(self._main_chat_loop())
|
||||
self._loop_task.add_done_callback(self._handle_loop_completion)
|
||||
logger.info(f"{self.log_prefix} BrainChatting 启动完成")
|
||||
|
||||
except Exception as e:
|
||||
# 启动失败时重置状态
|
||||
self.running = False
|
||||
self._loop_task = None
|
||||
logger.error(f"{self.log_prefix} BrainChatting 启动失败: {e}")
|
||||
raise
|
||||
|
||||
def _handle_loop_completion(self, task: asyncio.Task):
|
||||
"""当 _hfc_loop 任务完成时执行的回调。"""
|
||||
try:
|
||||
if exception := task.exception():
|
||||
logger.error(f"{self.log_prefix} BrainChatting: 脱离了聊天(异常): {exception}")
|
||||
logger.error(traceback.format_exc()) # Log full traceback for exceptions
|
||||
else:
|
||||
logger.info(f"{self.log_prefix} BrainChatting: 脱离了聊天 (外部停止)")
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"{self.log_prefix} BrainChatting: 结束了聊天")
|
||||
|
||||
def start_cycle(self) -> Tuple[Dict[str, float], str]:
|
||||
self._cycle_counter += 1
|
||||
self._current_cycle_detail = CycleDetail(self._cycle_counter)
|
||||
self._current_cycle_detail.thinking_id = f"tid{str(round(time.time(), 2))}"
|
||||
cycle_timers = {}
|
||||
return cycle_timers, self._current_cycle_detail.thinking_id
|
||||
|
||||
def end_cycle(self, loop_info, cycle_timers):
|
||||
self._current_cycle_detail.set_loop_info(loop_info)
|
||||
self.history_loop.append(self._current_cycle_detail)
|
||||
self._current_cycle_detail.timers = cycle_timers
|
||||
self._current_cycle_detail.end_time = time.time()
|
||||
|
||||
def print_cycle_info(self, cycle_timers):
|
||||
# 记录循环信息和计时器结果
|
||||
timer_strings = []
|
||||
for name, elapsed in cycle_timers.items():
|
||||
formatted_time = f"{elapsed * 1000:.2f}毫秒" if elapsed < 1 else f"{elapsed:.2f}秒"
|
||||
timer_strings.append(f"{name}: {formatted_time}")
|
||||
|
||||
logger.info(
|
||||
f"{self.log_prefix} 第{self._current_cycle_detail.cycle_id}次思考,"
|
||||
f"耗时: {self._current_cycle_detail.end_time - self._current_cycle_detail.start_time:.1f}秒" # type: ignore
|
||||
+ (f"\n详情: {'; '.join(timer_strings)}" if timer_strings else "")
|
||||
)
|
||||
|
||||
async def _loopbody(self): # sourcery skip: hoist-if-from-if
|
||||
recent_messages_list = message_api.get_messages_by_time_in_chat(
|
||||
chat_id=self.stream_id,
|
||||
start_time=self.last_read_time,
|
||||
end_time=time.time(),
|
||||
limit=20,
|
||||
limit_mode="latest",
|
||||
filter_mai=True,
|
||||
filter_command=True,
|
||||
)
|
||||
|
||||
if len(recent_messages_list) >= 1:
|
||||
self.last_read_time = time.time()
|
||||
await self._observe(recent_messages_list=recent_messages_list)
|
||||
|
||||
else:
|
||||
# Normal模式:消息数量不足,等待
|
||||
await asyncio.sleep(0.2)
|
||||
return True
|
||||
return True
|
||||
|
||||
async def _send_and_store_reply(
|
||||
self,
|
||||
response_set: "ReplySetModel",
|
||||
action_message: "DatabaseMessages",
|
||||
cycle_timers: Dict[str, float],
|
||||
thinking_id,
|
||||
actions,
|
||||
selected_expressions: Optional[List[int]] = None,
|
||||
) -> Tuple[Dict[str, Any], str, Dict[str, float]]:
|
||||
with Timer("回复发送", cycle_timers):
|
||||
reply_text = await self._send_response(
|
||||
reply_set=response_set,
|
||||
message_data=action_message,
|
||||
selected_expressions=selected_expressions,
|
||||
)
|
||||
|
||||
# 获取 platform,如果不存在则从 chat_stream 获取,如果还是 None 则使用默认值
|
||||
platform = action_message.chat_info.platform
|
||||
if platform is None:
|
||||
platform = getattr(self.chat_stream, "platform", "unknown")
|
||||
|
||||
person = Person(platform=platform, user_id=action_message.user_info.user_id)
|
||||
person_name = person.person_name
|
||||
action_prompt_display = f"你对{person_name}进行了回复:{reply_text}"
|
||||
|
||||
await database_api.store_action_info(
|
||||
chat_stream=self.chat_stream,
|
||||
action_build_into_prompt=False,
|
||||
action_prompt_display=action_prompt_display,
|
||||
action_done=True,
|
||||
thinking_id=thinking_id,
|
||||
action_data={"reply_text": reply_text},
|
||||
action_name="reply",
|
||||
)
|
||||
|
||||
# 构建循环信息
|
||||
loop_info: Dict[str, Any] = {
|
||||
"loop_plan_info": {
|
||||
"action_result": actions,
|
||||
},
|
||||
"loop_action_info": {
|
||||
"action_taken": True,
|
||||
"reply_text": reply_text,
|
||||
"command": "",
|
||||
"taken_time": time.time(),
|
||||
},
|
||||
}
|
||||
|
||||
return loop_info, reply_text, cycle_timers
|
||||
|
||||
async def _observe(
|
||||
self, # interest_value: float = 0.0,
|
||||
recent_messages_list: Optional[List["DatabaseMessages"]] = None,
|
||||
) -> bool: # sourcery skip: merge-else-if-into-elif, remove-redundant-if
|
||||
if recent_messages_list is None:
|
||||
recent_messages_list = []
|
||||
_reply_text = "" # 初始化reply_text变量,避免UnboundLocalError
|
||||
|
||||
async with global_prompt_manager.async_message_scope(self.chat_stream.context.get_template_name()):
|
||||
await self.expression_learner.trigger_learning_for_chat()
|
||||
|
||||
cycle_timers, thinking_id = self.start_cycle()
|
||||
logger.info(f"{self.log_prefix} 开始第{self._cycle_counter}次思考")
|
||||
|
||||
# 第一步:动作检查
|
||||
available_actions: Dict[str, ActionInfo] = {}
|
||||
try:
|
||||
await self.action_modifier.modify_actions()
|
||||
available_actions = self.action_manager.get_using_actions()
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 动作修改失败: {e}")
|
||||
|
||||
# 执行planner
|
||||
is_group_chat, chat_target_info, _ = self.action_planner.get_necessary_info()
|
||||
|
||||
message_list_before_now = get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=self.stream_id,
|
||||
timestamp=time.time(),
|
||||
limit=int(global_config.chat.max_context_size * 0.6),
|
||||
)
|
||||
chat_content_block, message_id_list = build_readable_messages_with_id(
|
||||
messages=message_list_before_now,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
read_mark=self.action_planner.last_obs_time_mark,
|
||||
truncate=True,
|
||||
show_actions=True,
|
||||
)
|
||||
|
||||
prompt_info = await self.action_planner.build_planner_prompt(
|
||||
is_group_chat=is_group_chat,
|
||||
chat_target_info=chat_target_info,
|
||||
current_available_actions=available_actions,
|
||||
chat_content_block=chat_content_block,
|
||||
message_id_list=message_id_list,
|
||||
interest=global_config.personality.interest,
|
||||
)
|
||||
continue_flag, modified_message = await events_manager.handle_mai_events(
|
||||
EventType.ON_PLAN, None, prompt_info[0], None, self.chat_stream.stream_id
|
||||
)
|
||||
if not continue_flag:
|
||||
return False
|
||||
if modified_message and modified_message._modify_flags.modify_llm_prompt:
|
||||
prompt_info = (modified_message.llm_prompt, prompt_info[1])
|
||||
|
||||
with Timer("规划器", cycle_timers):
|
||||
action_to_use_info = await self.action_planner.plan(
|
||||
loop_start_time=self.last_read_time,
|
||||
available_actions=available_actions,
|
||||
)
|
||||
|
||||
# 3. 并行执行所有动作
|
||||
action_tasks = [
|
||||
asyncio.create_task(
|
||||
self._execute_action(action, action_to_use_info, thinking_id, available_actions, cycle_timers)
|
||||
)
|
||||
for action in action_to_use_info
|
||||
]
|
||||
|
||||
# 并行执行所有任务
|
||||
results = await asyncio.gather(*action_tasks, return_exceptions=True)
|
||||
|
||||
# 处理执行结果
|
||||
reply_loop_info = None
|
||||
reply_text_from_reply = ""
|
||||
action_success = False
|
||||
action_reply_text = ""
|
||||
|
||||
for result in results:
|
||||
if isinstance(result, BaseException):
|
||||
logger.error(f"{self.log_prefix} 动作执行异常: {result}")
|
||||
continue
|
||||
|
||||
if result["action_type"] != "reply":
|
||||
action_success = result["success"]
|
||||
action_reply_text = result["reply_text"]
|
||||
elif result["action_type"] == "reply":
|
||||
if result["success"]:
|
||||
reply_loop_info = result["loop_info"]
|
||||
reply_text_from_reply = result["reply_text"]
|
||||
else:
|
||||
logger.warning(f"{self.log_prefix} 回复动作执行失败")
|
||||
|
||||
# 构建最终的循环信息
|
||||
if reply_loop_info:
|
||||
# 如果有回复信息,使用回复的loop_info作为基础
|
||||
loop_info = reply_loop_info
|
||||
# 更新动作执行信息
|
||||
loop_info["loop_action_info"].update(
|
||||
{
|
||||
"action_taken": action_success,
|
||||
"taken_time": time.time(),
|
||||
}
|
||||
)
|
||||
_reply_text = reply_text_from_reply
|
||||
else:
|
||||
# 没有回复信息,构建纯动作的loop_info
|
||||
loop_info = {
|
||||
"loop_plan_info": {
|
||||
"action_result": action_to_use_info,
|
||||
},
|
||||
"loop_action_info": {
|
||||
"action_taken": action_success,
|
||||
"reply_text": action_reply_text,
|
||||
"taken_time": time.time(),
|
||||
},
|
||||
}
|
||||
_reply_text = action_reply_text
|
||||
|
||||
self.end_cycle(loop_info, cycle_timers)
|
||||
self.print_cycle_info(cycle_timers)
|
||||
|
||||
return True
|
||||
|
||||
async def _main_chat_loop(self):
|
||||
"""主循环,持续进行计划并可能回复消息,直到被外部取消。"""
|
||||
try:
|
||||
while self.running:
|
||||
# 主循环
|
||||
success = await self._loopbody()
|
||||
await asyncio.sleep(0.1)
|
||||
if not success:
|
||||
break
|
||||
except asyncio.CancelledError:
|
||||
# 设置了关闭标志位后被取消是正常流程
|
||||
logger.info(f"{self.log_prefix} 麦麦已关闭聊天")
|
||||
except Exception:
|
||||
logger.error(f"{self.log_prefix} 麦麦聊天意外错误,将于3s后尝试重新启动")
|
||||
print(traceback.format_exc())
|
||||
await asyncio.sleep(3)
|
||||
self._loop_task = asyncio.create_task(self._main_chat_loop())
|
||||
logger.error(f"{self.log_prefix} 结束了当前聊天循环")
|
||||
|
||||
async def _handle_action(
|
||||
self,
|
||||
action: str,
|
||||
reasoning: str,
|
||||
action_data: dict,
|
||||
cycle_timers: Dict[str, float],
|
||||
thinking_id: str,
|
||||
action_message: Optional["DatabaseMessages"] = None,
|
||||
) -> tuple[bool, str, str]:
|
||||
"""
|
||||
处理规划动作,使用动作工厂创建相应的动作处理器
|
||||
|
||||
参数:
|
||||
action: 动作类型
|
||||
reasoning: 决策理由
|
||||
action_data: 动作数据,包含不同动作需要的参数
|
||||
cycle_timers: 计时器字典
|
||||
thinking_id: 思考ID
|
||||
|
||||
返回:
|
||||
tuple[bool, str, str]: (是否执行了动作, 思考消息ID, 命令)
|
||||
"""
|
||||
try:
|
||||
# 使用工厂创建动作处理器实例
|
||||
try:
|
||||
action_handler = self.action_manager.create_action(
|
||||
action_name=action,
|
||||
action_data=action_data,
|
||||
action_reasoning=reasoning,
|
||||
cycle_timers=cycle_timers,
|
||||
thinking_id=thinking_id,
|
||||
chat_stream=self.chat_stream,
|
||||
log_prefix=self.log_prefix,
|
||||
action_message=action_message,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 创建动作处理器时出错: {e}")
|
||||
traceback.print_exc()
|
||||
return False, "", ""
|
||||
|
||||
if not action_handler:
|
||||
logger.warning(f"{self.log_prefix} 未能创建动作处理器: {action}")
|
||||
return False, "", ""
|
||||
|
||||
# 处理动作并获取结果(固定记录一次动作信息)
|
||||
result = await action_handler.run()
|
||||
success, action_text = result
|
||||
command = ""
|
||||
|
||||
return success, action_text, command
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 处理{action}时出错: {e}")
|
||||
traceback.print_exc()
|
||||
return False, "", ""
|
||||
|
||||
async def _send_response(
|
||||
self,
|
||||
reply_set: "ReplySetModel",
|
||||
message_data: "DatabaseMessages",
|
||||
selected_expressions: Optional[List[int]] = None,
|
||||
) -> str:
|
||||
new_message_count = message_api.count_new_messages(
|
||||
chat_id=self.chat_stream.stream_id, start_time=self.last_read_time, end_time=time.time()
|
||||
)
|
||||
|
||||
need_reply = new_message_count >= random.randint(2, 4)
|
||||
|
||||
if need_reply:
|
||||
logger.info(f"{self.log_prefix} 从思考到回复,共有{new_message_count}条新消息,使用引用回复")
|
||||
|
||||
reply_text = ""
|
||||
first_replied = False
|
||||
for reply_content in reply_set.reply_data:
|
||||
if reply_content.content_type != ReplyContentType.TEXT:
|
||||
continue
|
||||
data: str = reply_content.content # type: ignore
|
||||
if not first_replied:
|
||||
await send_api.text_to_stream(
|
||||
text=data,
|
||||
stream_id=self.chat_stream.stream_id,
|
||||
reply_message=message_data,
|
||||
set_reply=need_reply,
|
||||
typing=False,
|
||||
selected_expressions=selected_expressions,
|
||||
)
|
||||
first_replied = True
|
||||
else:
|
||||
await send_api.text_to_stream(
|
||||
text=data,
|
||||
stream_id=self.chat_stream.stream_id,
|
||||
reply_message=message_data,
|
||||
set_reply=False,
|
||||
typing=True,
|
||||
selected_expressions=selected_expressions,
|
||||
)
|
||||
reply_text += data
|
||||
|
||||
return reply_text
|
||||
|
||||
async def _execute_action(
|
||||
self,
|
||||
action_planner_info: ActionPlannerInfo,
|
||||
chosen_action_plan_infos: List[ActionPlannerInfo],
|
||||
thinking_id: str,
|
||||
available_actions: Dict[str, ActionInfo],
|
||||
cycle_timers: Dict[str, float],
|
||||
):
|
||||
"""执行单个动作的通用函数"""
|
||||
try:
|
||||
with Timer(f"动作{action_planner_info.action_type}", cycle_timers):
|
||||
if action_planner_info.action_type == "no_reply":
|
||||
# 直接处理no_reply逻辑,不再通过动作系统
|
||||
reason = action_planner_info.reasoning or "选择不回复"
|
||||
# logger.info(f"{self.log_prefix} 选择不回复,原因: {reason}")
|
||||
|
||||
# 存储no_reply信息到数据库
|
||||
await database_api.store_action_info(
|
||||
chat_stream=self.chat_stream,
|
||||
action_build_into_prompt=False,
|
||||
action_prompt_display=reason,
|
||||
action_done=True,
|
||||
thinking_id=thinking_id,
|
||||
action_data={"reason": reason},
|
||||
action_name="no_reply",
|
||||
)
|
||||
return {"action_type": "no_reply", "success": True, "reply_text": "", "command": ""}
|
||||
|
||||
elif action_planner_info.action_type == "reply":
|
||||
try:
|
||||
success, llm_response = await generator_api.generate_reply(
|
||||
chat_stream=self.chat_stream,
|
||||
reply_message=action_planner_info.action_message,
|
||||
available_actions=available_actions,
|
||||
chosen_actions=chosen_action_plan_infos,
|
||||
reply_reason=action_planner_info.reasoning or "",
|
||||
enable_tool=global_config.tool.enable_tool,
|
||||
request_type="replyer",
|
||||
from_plugin=False,
|
||||
)
|
||||
|
||||
if not success or not llm_response or not llm_response.reply_set:
|
||||
if action_planner_info.action_message:
|
||||
logger.info(
|
||||
f"对 {action_planner_info.action_message.processed_plain_text} 的回复生成失败"
|
||||
)
|
||||
else:
|
||||
logger.info("回复生成失败")
|
||||
return {"action_type": "reply", "success": False, "reply_text": "", "loop_info": None}
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.debug(f"{self.log_prefix} 并行执行:回复生成任务已被取消")
|
||||
return {"action_type": "reply", "success": False, "reply_text": "", "loop_info": None}
|
||||
response_set = llm_response.reply_set
|
||||
selected_expressions = llm_response.selected_expressions
|
||||
loop_info, reply_text, _ = await self._send_and_store_reply(
|
||||
response_set=response_set,
|
||||
action_message=action_planner_info.action_message, # type: ignore
|
||||
cycle_timers=cycle_timers,
|
||||
thinking_id=thinking_id,
|
||||
actions=chosen_action_plan_infos,
|
||||
selected_expressions=selected_expressions,
|
||||
)
|
||||
return {
|
||||
"action_type": "reply",
|
||||
"success": True,
|
||||
"reply_text": reply_text,
|
||||
"loop_info": loop_info,
|
||||
}
|
||||
|
||||
# 其他动作
|
||||
else:
|
||||
# 执行普通动作
|
||||
with Timer("动作执行", cycle_timers):
|
||||
success, reply_text, command = await self._handle_action(
|
||||
action_planner_info.action_type,
|
||||
action_planner_info.reasoning or "",
|
||||
action_planner_info.action_data or {},
|
||||
cycle_timers,
|
||||
thinking_id,
|
||||
action_planner_info.action_message,
|
||||
)
|
||||
return {
|
||||
"action_type": action_planner_info.action_type,
|
||||
"success": success,
|
||||
"reply_text": reply_text,
|
||||
"command": command,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 执行动作时出错: {e}")
|
||||
logger.error(f"{self.log_prefix} 错误信息: {traceback.format_exc()}")
|
||||
return {
|
||||
"action_type": action_planner_info.action_type,
|
||||
"success": False,
|
||||
"reply_text": "",
|
||||
"loop_info": None,
|
||||
"error": str(e),
|
||||
}
|
||||
538
src/chat/brain_chat/brain_planner.py
Normal file
538
src/chat/brain_chat/brain_planner.py
Normal file
@@ -0,0 +1,538 @@
|
||||
import json
|
||||
import time
|
||||
import traceback
|
||||
import random
|
||||
import re
|
||||
from typing import Dict, Optional, Tuple, List, TYPE_CHECKING
|
||||
from rich.traceback import install
|
||||
from datetime import datetime
|
||||
from json_repair import repair_json
|
||||
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config, model_config
|
||||
from src.common.logger import get_logger
|
||||
from src.common.data_models.info_data_model import ActionPlannerInfo
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.chat.utils.chat_message_builder import (
|
||||
build_readable_actions,
|
||||
get_actions_by_timestamp_with_chat,
|
||||
build_readable_messages_with_id,
|
||||
get_raw_msg_before_timestamp_with_chat,
|
||||
)
|
||||
from src.chat.utils.utils import get_chat_type_and_target_info
|
||||
from src.chat.planner_actions.action_manager import ActionManager
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.plugin_system.base.component_types import ActionInfo, ComponentType, ActionActivationType
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.common.data_models.info_data_model import TargetPersonInfo
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
|
||||
logger = get_logger("planner")
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
|
||||
def init_prompt():
|
||||
Prompt(
|
||||
"""
|
||||
{time_block}
|
||||
{name_block}
|
||||
你的兴趣是:{interest}
|
||||
{chat_context_description},以下是具体的聊天内容
|
||||
**聊天内容**
|
||||
{chat_content_block}
|
||||
|
||||
**动作记录**
|
||||
{actions_before_now_block}
|
||||
|
||||
**可用的action**
|
||||
reply
|
||||
动作描述:
|
||||
进行回复,你可以自然的顺着正在进行的聊天内容进行回复或自然的提出一个问题
|
||||
{{
|
||||
"action": "reply",
|
||||
"target_message_id":"想要回复的消息id",
|
||||
"reason":"回复的原因"
|
||||
}}
|
||||
|
||||
no_reply
|
||||
动作描述:
|
||||
等待,保持沉默,等待对方发言
|
||||
{{
|
||||
"action": "no_reply",
|
||||
}}
|
||||
|
||||
{action_options_text}
|
||||
|
||||
请选择合适的action,并说明触发action的消息id和选择该action的原因。消息id格式:m+数字
|
||||
先输出你的选择思考理由,再输出你选择的action,理由是一段平文本,不要分点,精简。
|
||||
**动作选择要求**
|
||||
请你根据聊天内容,用户的最新消息和以下标准选择合适的动作:
|
||||
{plan_style}
|
||||
{moderation_prompt}
|
||||
|
||||
请选择所有符合使用要求的action,动作用json格式输出,如果输出多个json,每个json都要单独用```json包裹,你可以重复使用同一个动作或不同动作:
|
||||
**示例**
|
||||
// 理由文本
|
||||
```json
|
||||
{{
|
||||
"action":"动作名",
|
||||
"target_message_id":"触发动作的消息id",
|
||||
//对应参数
|
||||
}}
|
||||
```
|
||||
```json
|
||||
{{
|
||||
"action":"动作名",
|
||||
"target_message_id":"触发动作的消息id",
|
||||
//对应参数
|
||||
}}
|
||||
```
|
||||
|
||||
""",
|
||||
"brain_planner_prompt",
|
||||
)
|
||||
|
||||
Prompt(
|
||||
"""
|
||||
{action_name}
|
||||
动作描述:{action_description}
|
||||
使用条件:
|
||||
{action_require}
|
||||
{{
|
||||
"action": "{action_name}",{action_parameters},
|
||||
"target_message_id":"触发action的消息id",
|
||||
"reason":"触发action的原因"
|
||||
}}
|
||||
""",
|
||||
"brain_action_prompt",
|
||||
)
|
||||
|
||||
|
||||
class BrainPlanner:
|
||||
def __init__(self, chat_id: str, action_manager: ActionManager):
|
||||
self.chat_id = chat_id
|
||||
self.log_prefix = f"[{get_chat_manager().get_stream_name(chat_id) or chat_id}]"
|
||||
self.action_manager = action_manager
|
||||
# LLM规划器配置
|
||||
self.planner_llm = LLMRequest(
|
||||
model_set=model_config.model_task_config.planner, request_type="planner"
|
||||
) # 用于动作规划
|
||||
|
||||
self.last_obs_time_mark = 0.0
|
||||
|
||||
def find_message_by_id(
|
||||
self, message_id: str, message_id_list: List[Tuple[str, "DatabaseMessages"]]
|
||||
) -> Optional["DatabaseMessages"]:
|
||||
# sourcery skip: use-next
|
||||
"""
|
||||
根据message_id从message_id_list中查找对应的原始消息
|
||||
|
||||
Args:
|
||||
message_id: 要查找的消息ID
|
||||
message_id_list: 消息ID列表,格式为[{'id': str, 'message': dict}, ...]
|
||||
|
||||
Returns:
|
||||
找到的原始消息字典,如果未找到则返回None
|
||||
"""
|
||||
for item in message_id_list:
|
||||
if item[0] == message_id:
|
||||
return item[1]
|
||||
return None
|
||||
|
||||
def _parse_single_action(
|
||||
self,
|
||||
action_json: dict,
|
||||
message_id_list: List[Tuple[str, "DatabaseMessages"]],
|
||||
current_available_actions: List[Tuple[str, ActionInfo]],
|
||||
) -> List[ActionPlannerInfo]:
|
||||
"""解析单个action JSON并返回ActionPlannerInfo列表"""
|
||||
action_planner_infos = []
|
||||
|
||||
try:
|
||||
action = action_json.get("action", "no_reply")
|
||||
reasoning = action_json.get("reason", "未提供原因")
|
||||
action_data = {key: value for key, value in action_json.items() if key not in ["action", "reason"]}
|
||||
# 非no_reply动作需要target_message_id
|
||||
target_message = None
|
||||
|
||||
if target_message_id := action_json.get("target_message_id"):
|
||||
# 根据target_message_id查找原始消息
|
||||
target_message = self.find_message_by_id(target_message_id, message_id_list)
|
||||
if target_message is None:
|
||||
logger.warning(f"{self.log_prefix}无法找到target_message_id '{target_message_id}' 对应的消息")
|
||||
# 选择最新消息作为target_message
|
||||
target_message = message_id_list[-1][1]
|
||||
else:
|
||||
target_message = message_id_list[-1][1]
|
||||
logger.debug(f"{self.log_prefix}动作'{action}'缺少target_message_id,使用最新消息作为target_message")
|
||||
|
||||
# 验证action是否可用
|
||||
available_action_names = [action_name for action_name, _ in current_available_actions]
|
||||
internal_action_names = ["no_reply", "reply", "wait_time"]
|
||||
|
||||
if action not in internal_action_names and action not in available_action_names:
|
||||
logger.warning(
|
||||
f"{self.log_prefix}LLM 返回了当前不可用或无效的动作: '{action}' (可用: {available_action_names}),将强制使用 'no_reply'"
|
||||
)
|
||||
reasoning = (
|
||||
f"LLM 返回了当前不可用的动作 '{action}' (可用: {available_action_names})。原始理由: {reasoning}"
|
||||
)
|
||||
action = "no_reply"
|
||||
|
||||
# 创建ActionPlannerInfo对象
|
||||
# 将列表转换为字典格式
|
||||
available_actions_dict = dict(current_available_actions)
|
||||
action_planner_infos.append(
|
||||
ActionPlannerInfo(
|
||||
action_type=action,
|
||||
reasoning=reasoning,
|
||||
action_data=action_data,
|
||||
action_message=target_message,
|
||||
available_actions=available_actions_dict,
|
||||
)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix}解析单个action时出错: {e}")
|
||||
# 将列表转换为字典格式
|
||||
available_actions_dict = dict(current_available_actions)
|
||||
action_planner_infos.append(
|
||||
ActionPlannerInfo(
|
||||
action_type="no_reply",
|
||||
reasoning=f"解析单个action时出错: {e}",
|
||||
action_data={},
|
||||
action_message=None,
|
||||
available_actions=available_actions_dict,
|
||||
)
|
||||
)
|
||||
|
||||
return action_planner_infos
|
||||
|
||||
async def plan(
|
||||
self,
|
||||
available_actions: Dict[str, ActionInfo],
|
||||
loop_start_time: float = 0.0,
|
||||
) -> List[ActionPlannerInfo]:
|
||||
# sourcery skip: use-named-expression
|
||||
"""
|
||||
规划器 (Planner): 使用LLM根据上下文决定做出什么动作。
|
||||
"""
|
||||
|
||||
# 获取聊天上下文
|
||||
message_list_before_now = get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=self.chat_id,
|
||||
timestamp=time.time(),
|
||||
limit=int(global_config.chat.max_context_size * 0.6),
|
||||
)
|
||||
message_id_list: list[Tuple[str, "DatabaseMessages"]] = []
|
||||
chat_content_block, message_id_list = build_readable_messages_with_id(
|
||||
messages=message_list_before_now,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
read_mark=self.last_obs_time_mark,
|
||||
truncate=True,
|
||||
show_actions=True,
|
||||
)
|
||||
|
||||
message_list_before_now_short = message_list_before_now[-int(global_config.chat.max_context_size * 0.3) :]
|
||||
chat_content_block_short, message_id_list_short = build_readable_messages_with_id(
|
||||
messages=message_list_before_now_short,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
truncate=False,
|
||||
show_actions=False,
|
||||
)
|
||||
|
||||
self.last_obs_time_mark = time.time()
|
||||
|
||||
# 获取必要信息
|
||||
is_group_chat, chat_target_info, current_available_actions = self.get_necessary_info()
|
||||
|
||||
# 应用激活类型过滤
|
||||
filtered_actions = self._filter_actions_by_activation_type(available_actions, chat_content_block_short)
|
||||
|
||||
logger.debug(f"{self.log_prefix}过滤后有{len(filtered_actions)}个可用动作")
|
||||
|
||||
# 构建包含所有动作的提示词
|
||||
prompt, message_id_list = await self.build_planner_prompt(
|
||||
is_group_chat=is_group_chat,
|
||||
chat_target_info=chat_target_info,
|
||||
current_available_actions=filtered_actions,
|
||||
chat_content_block=chat_content_block,
|
||||
message_id_list=message_id_list,
|
||||
interest=global_config.personality.interest,
|
||||
)
|
||||
|
||||
# 调用LLM获取决策
|
||||
actions = await self._execute_main_planner(
|
||||
prompt=prompt,
|
||||
message_id_list=message_id_list,
|
||||
filtered_actions=filtered_actions,
|
||||
available_actions=available_actions,
|
||||
loop_start_time=loop_start_time,
|
||||
)
|
||||
|
||||
return actions
|
||||
|
||||
async def build_planner_prompt(
|
||||
self,
|
||||
is_group_chat: bool,
|
||||
chat_target_info: Optional["TargetPersonInfo"],
|
||||
current_available_actions: Dict[str, ActionInfo],
|
||||
message_id_list: List[Tuple[str, "DatabaseMessages"]],
|
||||
chat_content_block: str = "",
|
||||
interest: str = "",
|
||||
) -> tuple[str, List[Tuple[str, "DatabaseMessages"]]]:
|
||||
"""构建 Planner LLM 的提示词 (获取模板并填充数据)"""
|
||||
try:
|
||||
# 获取最近执行过的动作
|
||||
actions_before_now = get_actions_by_timestamp_with_chat(
|
||||
chat_id=self.chat_id,
|
||||
timestamp_start=time.time() - 600,
|
||||
timestamp_end=time.time(),
|
||||
limit=6,
|
||||
)
|
||||
actions_before_now_block = build_readable_actions(actions=actions_before_now)
|
||||
if actions_before_now_block:
|
||||
actions_before_now_block = f"你刚刚选择并执行过的action是:\n{actions_before_now_block}"
|
||||
else:
|
||||
actions_before_now_block = ""
|
||||
|
||||
if chat_target_info:
|
||||
# 构建聊天上下文描述
|
||||
chat_context_description = (
|
||||
f"你正在和 {chat_target_info.person_name or chat_target_info.user_nickname or '对方'} 聊天中"
|
||||
)
|
||||
|
||||
# 构建动作选项块
|
||||
action_options_block = await self._build_action_options_block(current_available_actions)
|
||||
|
||||
# 其他信息
|
||||
moderation_prompt_block = "请不要输出违法违规内容,不要输出色情,暴力,政治相关内容,如有敏感内容,请规避。"
|
||||
time_block = f"当前时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
|
||||
bot_name = global_config.bot.nickname
|
||||
bot_nickname = (
|
||||
f",也可以叫你{','.join(global_config.bot.alias_names)}" if global_config.bot.alias_names else ""
|
||||
)
|
||||
name_block = f"你的名字是{bot_name}{bot_nickname},请注意哪些是你自己的发言。"
|
||||
|
||||
# 获取主规划器模板并填充
|
||||
planner_prompt_template = await global_prompt_manager.get_prompt_async("brain_planner_prompt")
|
||||
prompt = planner_prompt_template.format(
|
||||
time_block=time_block,
|
||||
chat_context_description=chat_context_description,
|
||||
chat_content_block=chat_content_block,
|
||||
actions_before_now_block=actions_before_now_block,
|
||||
action_options_text=action_options_block,
|
||||
moderation_prompt=moderation_prompt_block,
|
||||
name_block=name_block,
|
||||
interest=interest,
|
||||
plan_style=global_config.personality.private_plan_style,
|
||||
)
|
||||
|
||||
return prompt, message_id_list
|
||||
except Exception as e:
|
||||
logger.error(f"构建 Planner 提示词时出错: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
return "构建 Planner Prompt 时出错", []
|
||||
|
||||
def get_necessary_info(self) -> Tuple[bool, Optional["TargetPersonInfo"], Dict[str, ActionInfo]]:
|
||||
"""
|
||||
获取 Planner 需要的必要信息
|
||||
"""
|
||||
is_group_chat = True
|
||||
is_group_chat, chat_target_info = get_chat_type_and_target_info(self.chat_id)
|
||||
logger.debug(f"{self.log_prefix}获取到聊天信息 - 群聊: {is_group_chat}, 目标信息: {chat_target_info}")
|
||||
|
||||
current_available_actions_dict = self.action_manager.get_using_actions()
|
||||
|
||||
# 获取完整的动作信息
|
||||
all_registered_actions: Dict[str, ActionInfo] = component_registry.get_components_by_type( # type: ignore
|
||||
ComponentType.ACTION
|
||||
)
|
||||
current_available_actions = {}
|
||||
for action_name in current_available_actions_dict:
|
||||
if action_name in all_registered_actions:
|
||||
current_available_actions[action_name] = all_registered_actions[action_name]
|
||||
else:
|
||||
logger.warning(f"{self.log_prefix}使用中的动作 {action_name} 未在已注册动作中找到")
|
||||
|
||||
return is_group_chat, chat_target_info, current_available_actions
|
||||
|
||||
def _filter_actions_by_activation_type(
|
||||
self, available_actions: Dict[str, ActionInfo], chat_content_block: str
|
||||
) -> Dict[str, ActionInfo]:
|
||||
"""根据激活类型过滤动作"""
|
||||
filtered_actions = {}
|
||||
|
||||
for action_name, action_info in available_actions.items():
|
||||
if action_info.activation_type == ActionActivationType.NEVER:
|
||||
logger.debug(f"{self.log_prefix}动作 {action_name} 设置为 NEVER 激活类型,跳过")
|
||||
continue
|
||||
elif action_info.activation_type in [ActionActivationType.LLM_JUDGE, ActionActivationType.ALWAYS]:
|
||||
filtered_actions[action_name] = action_info
|
||||
elif action_info.activation_type == ActionActivationType.RANDOM:
|
||||
if random.random() < action_info.random_activation_probability:
|
||||
filtered_actions[action_name] = action_info
|
||||
elif action_info.activation_type == ActionActivationType.KEYWORD:
|
||||
if action_info.activation_keywords:
|
||||
for keyword in action_info.activation_keywords:
|
||||
if keyword in chat_content_block:
|
||||
filtered_actions[action_name] = action_info
|
||||
break
|
||||
else:
|
||||
logger.warning(f"{self.log_prefix}未知的激活类型: {action_info.activation_type},跳过处理")
|
||||
|
||||
return filtered_actions
|
||||
|
||||
async def _build_action_options_block(self, current_available_actions: Dict[str, ActionInfo]) -> str:
|
||||
# sourcery skip: use-join
|
||||
"""构建动作选项块"""
|
||||
if not current_available_actions:
|
||||
return ""
|
||||
|
||||
action_options_block = ""
|
||||
for action_name, action_info in current_available_actions.items():
|
||||
# 构建参数文本
|
||||
param_text = ""
|
||||
if action_info.action_parameters:
|
||||
param_text = "\n"
|
||||
for param_name, param_description in action_info.action_parameters.items():
|
||||
param_text += f' "{param_name}":"{param_description}"\n'
|
||||
param_text = param_text.rstrip("\n")
|
||||
|
||||
# 构建要求文本
|
||||
require_text = ""
|
||||
for require_item in action_info.action_require:
|
||||
require_text += f"- {require_item}\n"
|
||||
require_text = require_text.rstrip("\n")
|
||||
|
||||
# 获取动作提示模板并填充
|
||||
using_action_prompt = await global_prompt_manager.get_prompt_async("brain_action_prompt")
|
||||
using_action_prompt = using_action_prompt.format(
|
||||
action_name=action_name,
|
||||
action_description=action_info.description,
|
||||
action_parameters=param_text,
|
||||
action_require=require_text,
|
||||
)
|
||||
|
||||
action_options_block += using_action_prompt
|
||||
|
||||
return action_options_block
|
||||
|
||||
async def _execute_main_planner(
|
||||
self,
|
||||
prompt: str,
|
||||
message_id_list: List[Tuple[str, "DatabaseMessages"]],
|
||||
filtered_actions: Dict[str, ActionInfo],
|
||||
available_actions: Dict[str, ActionInfo],
|
||||
loop_start_time: float,
|
||||
) -> List[ActionPlannerInfo]:
|
||||
"""执行主规划器"""
|
||||
llm_content = None
|
||||
actions: List[ActionPlannerInfo] = []
|
||||
|
||||
try:
|
||||
# 调用LLM
|
||||
llm_content, (reasoning_content, _, _) = await self.planner_llm.generate_response_async(prompt=prompt)
|
||||
|
||||
# logger.info(f"{self.log_prefix}规划器原始提示词: {prompt}")
|
||||
# logger.info(f"{self.log_prefix}规划器原始响应: {llm_content}")
|
||||
|
||||
if global_config.debug.show_prompt:
|
||||
logger.info(f"{self.log_prefix}规划器原始提示词: {prompt}")
|
||||
logger.info(f"{self.log_prefix}规划器原始响应: {llm_content}")
|
||||
if reasoning_content:
|
||||
logger.info(f"{self.log_prefix}规划器推理: {reasoning_content}")
|
||||
else:
|
||||
logger.debug(f"{self.log_prefix}规划器原始提示词: {prompt}")
|
||||
logger.debug(f"{self.log_prefix}规划器原始响应: {llm_content}")
|
||||
if reasoning_content:
|
||||
logger.debug(f"{self.log_prefix}规划器推理: {reasoning_content}")
|
||||
|
||||
except Exception as req_e:
|
||||
logger.error(f"{self.log_prefix}LLM 请求执行失败: {req_e}")
|
||||
return [
|
||||
ActionPlannerInfo(
|
||||
action_type="no_reply",
|
||||
reasoning=f"LLM 请求失败,模型出现问题: {req_e}",
|
||||
action_data={},
|
||||
action_message=None,
|
||||
available_actions=available_actions,
|
||||
)
|
||||
]
|
||||
|
||||
# 解析LLM响应
|
||||
if llm_content:
|
||||
try:
|
||||
if json_objects := self._extract_json_from_markdown(llm_content):
|
||||
logger.debug(f"{self.log_prefix}从响应中提取到{len(json_objects)}个JSON对象")
|
||||
filtered_actions_list = list(filtered_actions.items())
|
||||
for json_obj in json_objects:
|
||||
actions.extend(self._parse_single_action(json_obj, message_id_list, filtered_actions_list))
|
||||
else:
|
||||
# 尝试解析为直接的JSON
|
||||
logger.warning(f"{self.log_prefix}LLM没有返回可用动作: {llm_content}")
|
||||
actions = self._create_no_reply("LLM没有返回可用动作", available_actions)
|
||||
|
||||
except Exception as json_e:
|
||||
logger.warning(f"{self.log_prefix}解析LLM响应JSON失败 {json_e}. LLM原始输出: '{llm_content}'")
|
||||
actions = self._create_no_reply(f"解析LLM响应JSON失败: {json_e}", available_actions)
|
||||
traceback.print_exc()
|
||||
else:
|
||||
actions = self._create_no_reply("规划器没有获得LLM响应", available_actions)
|
||||
|
||||
# 添加循环开始时间到所有非no_reply动作
|
||||
for action in actions:
|
||||
action.action_data = action.action_data or {}
|
||||
action.action_data["loop_start_time"] = loop_start_time
|
||||
|
||||
logger.debug(
|
||||
f"{self.log_prefix}规划器决定执行{len(actions)}个动作: {' '.join([a.action_type for a in actions])}"
|
||||
)
|
||||
|
||||
return actions
|
||||
|
||||
def _create_no_reply(self, reasoning: str, available_actions: Dict[str, ActionInfo]) -> List[ActionPlannerInfo]:
|
||||
"""创建no_reply"""
|
||||
return [
|
||||
ActionPlannerInfo(
|
||||
action_type="no_reply",
|
||||
reasoning=reasoning,
|
||||
action_data={},
|
||||
action_message=None,
|
||||
available_actions=available_actions,
|
||||
)
|
||||
]
|
||||
|
||||
def _extract_json_from_markdown(self, content: str) -> List[dict]:
|
||||
# sourcery skip: for-append-to-extend
|
||||
"""从Markdown格式的内容中提取JSON对象"""
|
||||
json_objects = []
|
||||
|
||||
# 使用正则表达式查找```json包裹的JSON内容
|
||||
json_pattern = r"```json\s*(.*?)\s*```"
|
||||
matches = re.findall(json_pattern, content, re.DOTALL)
|
||||
|
||||
for match in matches:
|
||||
try:
|
||||
# 清理可能的注释和格式问题
|
||||
json_str = re.sub(r"//.*?\n", "\n", match) # 移除单行注释
|
||||
json_str = re.sub(r"/\*.*?\*/", "", json_str, flags=re.DOTALL) # 移除多行注释
|
||||
if json_str := json_str.strip():
|
||||
json_obj = json.loads(repair_json(json_str))
|
||||
if isinstance(json_obj, dict):
|
||||
json_objects.append(json_obj)
|
||||
elif isinstance(json_obj, list):
|
||||
for item in json_obj:
|
||||
if isinstance(item, dict):
|
||||
json_objects.append(item)
|
||||
except Exception as e:
|
||||
logger.warning(f"解析JSON块失败: {e}, 块内容: {match[:100]}...")
|
||||
continue
|
||||
|
||||
return json_objects
|
||||
|
||||
|
||||
init_prompt()
|
||||
@@ -708,7 +708,7 @@ class EmojiManager:
|
||||
if not emoji.is_deleted and emoji.hash == emoji_hash:
|
||||
return emoji
|
||||
return None # 如果循环结束还没找到,则返回 None
|
||||
|
||||
|
||||
async def get_emoji_tag_by_hash(self, emoji_hash: str) -> Optional[List[str]]:
|
||||
"""根据哈希值获取已注册表情包的情感标签列表
|
||||
|
||||
@@ -731,7 +731,7 @@ class EmojiManager:
|
||||
emoji_record = Emoji.get_or_none(Emoji.emoji_hash == emoji_hash)
|
||||
if emoji_record and emoji_record.emotion:
|
||||
logger.info(f"[缓存命中] 从数据库获取表情包情感标签: {emoji_record.emotion[:50]}...")
|
||||
return emoji_record.emotion.split(',')
|
||||
return emoji_record.emotion.split(",")
|
||||
except Exception as e:
|
||||
logger.error(f"从数据库查询表情包情感标签时出错: {e}")
|
||||
|
||||
|
||||
@@ -3,21 +3,25 @@ import random
|
||||
import json
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
import jieba
|
||||
from typing import List, Dict, Optional, Any, Tuple
|
||||
|
||||
import traceback
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import Expression
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config, global_config
|
||||
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat_inclusive, build_anonymous_messages
|
||||
from src.chat.utils.chat_message_builder import (
|
||||
get_raw_msg_by_timestamp_with_chat_inclusive,
|
||||
build_anonymous_messages,
|
||||
build_bare_messages,
|
||||
)
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from json_repair import repair_json
|
||||
|
||||
|
||||
MAX_EXPRESSION_COUNT = 300
|
||||
DECAY_DAYS = 30 # 30天衰减到0.01
|
||||
DECAY_DAYS = 15 # 30天衰减到0.01
|
||||
DECAY_MIN = 0.01 # 最小衰减值
|
||||
|
||||
logger = get_logger("expressor")
|
||||
@@ -46,48 +50,63 @@ def init_prompt() -> None:
|
||||
例如:当"AAAAA"时,可以"BBBBB", AAAAA代表某个具体的场景,不超过20个字。BBBBB代表对应的语言风格,特定句式或表达方式,不超过20个字。
|
||||
|
||||
例如:
|
||||
当"对某件事表示十分惊叹,有些意外"时,使用"我嘞个xxxx"
|
||||
当"表示讽刺的赞同,不想讲道理"时,使用"对对对"
|
||||
当"想说明某个具体的事实观点,但懒得明说,或者不便明说,或表达一种默契",使用"懂的都懂"
|
||||
当"当涉及游戏相关时,表示意外的夸赞,略带戏谑意味"时,使用"这么强!"
|
||||
当"对某件事表示十分惊叹"时,使用"我嘞个xxxx"
|
||||
当"表示讽刺的赞同,不讲道理"时,使用"对对对"
|
||||
当"想说明某个具体的事实观点,但懒得明说,使用"懂的都懂"
|
||||
当"当涉及游戏相关时,夸赞,略带戏谑意味"时,使用"这么强!"
|
||||
|
||||
请注意:不要总结你自己(SELF)的发言,尽量保证总结内容的逻辑性
|
||||
现在请你概括
|
||||
"""
|
||||
Prompt(learn_style_prompt, "learn_style_prompt")
|
||||
|
||||
match_expression_context_prompt = """
|
||||
**聊天内容**
|
||||
{chat_str}
|
||||
|
||||
**从聊天内容总结的表达方式pairs**
|
||||
{expression_pairs}
|
||||
|
||||
请你为上面的每一条表达方式,找到该表达方式的原文句子,并输出匹配结果,expression_pair不能有重复,每个expression_pair仅输出一个最合适的context。
|
||||
如果找不到原句,就不输出该句的匹配结果。
|
||||
以json格式输出:
|
||||
格式如下:
|
||||
{{
|
||||
"expression_pair": "表达方式pair的序号(数字)",
|
||||
"context": "与表达方式对应的原文句子的原始内容,不要修改原文句子的内容",
|
||||
}},
|
||||
{{
|
||||
"expression_pair": "表达方式pair的序号(数字)",
|
||||
"context": "与表达方式对应的原文句子的原始内容,不要修改原文句子的内容",
|
||||
}},
|
||||
...
|
||||
|
||||
现在请你输出匹配结果:
|
||||
"""
|
||||
Prompt(match_expression_context_prompt, "match_expression_context_prompt")
|
||||
|
||||
|
||||
class ExpressionLearner:
|
||||
def __init__(self, chat_id: str) -> None:
|
||||
self.express_learn_model: LLMRequest = LLMRequest(
|
||||
model_set=model_config.model_task_config.replyer, request_type="expression.learner"
|
||||
model_set=model_config.model_task_config.utils, request_type="expression.learner"
|
||||
)
|
||||
self.embedding_model: LLMRequest = LLMRequest(
|
||||
model_set=model_config.model_task_config.embedding, request_type="expression.embedding"
|
||||
)
|
||||
self.chat_id = chat_id
|
||||
self.chat_stream = get_chat_manager().get_stream(chat_id)
|
||||
self.chat_name = get_chat_manager().get_stream_name(chat_id) or chat_id
|
||||
|
||||
# 维护每个chat的上次学习时间
|
||||
self.last_learning_time: float = time.time()
|
||||
|
||||
# 学习参数
|
||||
self.min_messages_for_learning = 25 # 触发学习所需的最少消息数
|
||||
self.min_learning_interval = 300 # 最短学习时间间隔(秒)
|
||||
|
||||
def can_learn_for_chat(self) -> bool:
|
||||
"""
|
||||
检查指定聊天流是否允许学习表达
|
||||
|
||||
Args:
|
||||
chat_id: 聊天流ID
|
||||
|
||||
Returns:
|
||||
bool: 是否允许学习
|
||||
"""
|
||||
try:
|
||||
use_expression, enable_learning, _ = global_config.expression.get_expression_config_for_chat(self.chat_id)
|
||||
return enable_learning
|
||||
except Exception as e:
|
||||
logger.error(f"检查学习权限失败: {e}")
|
||||
return False
|
||||
_, self.enable_learning, self.learning_intensity = global_config.expression.get_expression_config_for_chat(
|
||||
self.chat_id
|
||||
)
|
||||
self.min_messages_for_learning = 15 / self.learning_intensity # 触发学习所需的最少消息数
|
||||
self.min_learning_interval = 150 / self.learning_intensity
|
||||
|
||||
def should_trigger_learning(self) -> bool:
|
||||
"""
|
||||
@@ -99,27 +118,13 @@ class ExpressionLearner:
|
||||
Returns:
|
||||
bool: 是否应该触发学习
|
||||
"""
|
||||
current_time = time.time()
|
||||
|
||||
# 获取该聊天流的学习强度
|
||||
try:
|
||||
_, enable_learning, learning_intensity = global_config.expression.get_expression_config_for_chat(
|
||||
self.chat_id
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"获取聊天流 {self.chat_id} 的学习配置失败: {e}")
|
||||
return False
|
||||
|
||||
# 检查是否允许学习
|
||||
if not enable_learning:
|
||||
if not self.enable_learning:
|
||||
return False
|
||||
|
||||
# 根据学习强度计算最短学习时间间隔
|
||||
min_interval = self.min_learning_interval / learning_intensity
|
||||
|
||||
# 检查时间间隔
|
||||
time_diff = current_time - self.last_learning_time
|
||||
if time_diff < min_interval:
|
||||
time_diff = time.time() - self.last_learning_time
|
||||
if time_diff < self.min_learning_interval:
|
||||
return False
|
||||
|
||||
# 检查消息数量(只检查指定聊天流的消息)
|
||||
@@ -165,6 +170,7 @@ class ExpressionLearner:
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"为聊天流 {self.chat_name} 触发学习失败: {e}")
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
def _apply_global_decay_to_database(self, current_time: float) -> None:
|
||||
@@ -229,39 +235,43 @@ class ExpressionLearner:
|
||||
"""
|
||||
学习并存储表达方式
|
||||
"""
|
||||
# 检查是否允许在此聊天流中学习(在函数最前面检查)
|
||||
if not self.can_learn_for_chat():
|
||||
logger.debug(f"聊天流 {self.chat_name} 不允许学习表达,跳过学习")
|
||||
return []
|
||||
|
||||
res = await self.learn_expression(num)
|
||||
|
||||
if res is None:
|
||||
return []
|
||||
learnt_expressions, chat_id = res
|
||||
|
||||
chat_stream = get_chat_manager().get_stream(chat_id)
|
||||
if chat_stream is None:
|
||||
group_name = f"聊天流 {chat_id}"
|
||||
elif chat_stream.group_info:
|
||||
group_name = chat_stream.group_info.group_name
|
||||
else:
|
||||
group_name = f"{chat_stream.user_info.user_nickname}的私聊"
|
||||
learnt_expressions_str = ""
|
||||
for _chat_id, situation, style in learnt_expressions:
|
||||
learnt_expressions_str += f"{situation}->{style}\n"
|
||||
logger.info(f"在 {group_name} 学习到表达风格:\n{learnt_expressions_str}")
|
||||
|
||||
if not learnt_expressions:
|
||||
logger.info("没有学习到表达风格")
|
||||
return []
|
||||
learnt_expressions = res
|
||||
learnt_expressions_str = ""
|
||||
for (
|
||||
_chat_id,
|
||||
situation,
|
||||
style,
|
||||
_context,
|
||||
_context_words,
|
||||
) in learnt_expressions:
|
||||
learnt_expressions_str += f"{situation}->{style}\n"
|
||||
|
||||
logger.info(f"在 {self.chat_name} 学习到表达风格:\n{learnt_expressions_str}")
|
||||
|
||||
# 按chat_id分组
|
||||
chat_dict: Dict[str, List[Dict[str, Any]]] = {}
|
||||
for chat_id, situation, style in learnt_expressions:
|
||||
for (
|
||||
chat_id,
|
||||
situation,
|
||||
style,
|
||||
context,
|
||||
context_words,
|
||||
) in learnt_expressions:
|
||||
if chat_id not in chat_dict:
|
||||
chat_dict[chat_id] = []
|
||||
chat_dict[chat_id].append({"situation": situation, "style": style})
|
||||
chat_dict[chat_id].append(
|
||||
{
|
||||
"situation": situation,
|
||||
"style": style,
|
||||
"context": context,
|
||||
"context_words": context_words,
|
||||
}
|
||||
)
|
||||
|
||||
current_time = time.time()
|
||||
|
||||
@@ -281,6 +291,8 @@ class ExpressionLearner:
|
||||
if random.random() < 0.5:
|
||||
expr_obj.situation = new_expr["situation"]
|
||||
expr_obj.style = new_expr["style"]
|
||||
expr_obj.context = new_expr["context"]
|
||||
expr_obj.context_words = new_expr["context_words"]
|
||||
expr_obj.count = expr_obj.count + 1
|
||||
expr_obj.last_active_time = current_time
|
||||
expr_obj.save()
|
||||
@@ -293,6 +305,8 @@ class ExpressionLearner:
|
||||
chat_id=chat_id,
|
||||
type="style",
|
||||
create_date=current_time, # 手动设置创建日期
|
||||
context=new_expr["context"],
|
||||
context_words=new_expr["context_words"],
|
||||
)
|
||||
# 限制最大数量
|
||||
exprs = list(
|
||||
@@ -306,7 +320,105 @@ class ExpressionLearner:
|
||||
expr.delete_instance()
|
||||
return learnt_expressions
|
||||
|
||||
async def learn_expression(self, num: int = 10) -> Optional[Tuple[List[Tuple[str, str, str]], str]]:
|
||||
async def match_expression_context(
|
||||
self, expression_pairs: List[Tuple[str, str]], random_msg_match_str: str
|
||||
) -> List[Tuple[str, str, str]]:
|
||||
# 为expression_pairs逐个条目赋予编号,并构建成字符串
|
||||
numbered_pairs = []
|
||||
for i, (situation, style) in enumerate(expression_pairs, 1):
|
||||
numbered_pairs.append(f'{i}. 当"{situation}"时,使用"{style}"')
|
||||
|
||||
expression_pairs_str = "\n".join(numbered_pairs)
|
||||
|
||||
prompt = "match_expression_context_prompt"
|
||||
prompt = await global_prompt_manager.format_prompt(
|
||||
prompt,
|
||||
expression_pairs=expression_pairs_str,
|
||||
chat_str=random_msg_match_str,
|
||||
)
|
||||
|
||||
response, _ = await self.express_learn_model.generate_response_async(prompt, temperature=0.3)
|
||||
|
||||
print(f"match_expression_context_prompt: {prompt}")
|
||||
print(f"random_msg_match_str: {response}")
|
||||
|
||||
# 解析JSON响应
|
||||
match_responses = []
|
||||
try:
|
||||
response = response.strip()
|
||||
# 检查是否已经是标准JSON数组格式
|
||||
if response.startswith("[") and response.endswith("]"):
|
||||
match_responses = json.loads(response)
|
||||
else:
|
||||
# 尝试直接解析多个JSON对象
|
||||
try:
|
||||
# 如果是多个JSON对象用逗号分隔,包装成数组
|
||||
if response.startswith("{") and not response.startswith("["):
|
||||
response = "[" + response + "]"
|
||||
match_responses = json.loads(response)
|
||||
else:
|
||||
# 使用repair_json处理响应
|
||||
repaired_content = repair_json(response)
|
||||
|
||||
# 确保repaired_content是列表格式
|
||||
if isinstance(repaired_content, str):
|
||||
try:
|
||||
parsed_data = json.loads(repaired_content)
|
||||
if isinstance(parsed_data, dict):
|
||||
# 如果是字典,包装成列表
|
||||
match_responses = [parsed_data]
|
||||
elif isinstance(parsed_data, list):
|
||||
match_responses = parsed_data
|
||||
else:
|
||||
match_responses = []
|
||||
except json.JSONDecodeError:
|
||||
match_responses = []
|
||||
elif isinstance(repaired_content, dict):
|
||||
# 如果是字典,包装成列表
|
||||
match_responses = [repaired_content]
|
||||
elif isinstance(repaired_content, list):
|
||||
match_responses = repaired_content
|
||||
else:
|
||||
match_responses = []
|
||||
except json.JSONDecodeError:
|
||||
# 如果还是失败,尝试repair_json
|
||||
repaired_content = repair_json(response)
|
||||
if isinstance(repaired_content, str):
|
||||
parsed_data = json.loads(repaired_content)
|
||||
match_responses = parsed_data if isinstance(parsed_data, list) else [parsed_data]
|
||||
else:
|
||||
match_responses = repaired_content if isinstance(repaired_content, list) else [repaired_content]
|
||||
|
||||
except (json.JSONDecodeError, Exception) as e:
|
||||
logger.error(f"解析匹配响应JSON失败: {e}, 响应内容: \n{response}")
|
||||
return []
|
||||
|
||||
matched_expressions = []
|
||||
used_pair_indices = set() # 用于跟踪已经使用的expression_pair索引
|
||||
|
||||
for match_response in match_responses:
|
||||
try:
|
||||
# 获取表达方式序号
|
||||
pair_index = int(match_response["expression_pair"]) - 1 # 转换为0-based索引
|
||||
|
||||
# 检查索引是否有效且未被使用过
|
||||
if 0 <= pair_index < len(expression_pairs) and pair_index not in used_pair_indices:
|
||||
situation, style = expression_pairs[pair_index]
|
||||
context = match_response["context"]
|
||||
matched_expressions.append((situation, style, context))
|
||||
used_pair_indices.add(pair_index) # 标记该索引已使用
|
||||
logger.debug(f"成功匹配表达方式 {pair_index + 1}: {situation} -> {style}")
|
||||
elif pair_index in used_pair_indices:
|
||||
logger.debug(f"跳过重复的表达方式 {pair_index + 1}")
|
||||
except (ValueError, KeyError, IndexError) as e:
|
||||
logger.error(f"解析匹配条目失败: {e}, 条目: {match_response}")
|
||||
continue
|
||||
|
||||
return matched_expressions
|
||||
|
||||
async def learn_expression(
|
||||
self, num: int = 10
|
||||
) -> Optional[List[Tuple[str, str, str, List[str]]]]:
|
||||
"""从指定聊天流学习表达方式
|
||||
|
||||
Args:
|
||||
@@ -317,7 +429,7 @@ class ExpressionLearner:
|
||||
|
||||
current_time = time.time()
|
||||
|
||||
# 获取上次学习时间
|
||||
# 获取上次学习之后的消息
|
||||
random_msg = get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
chat_id=self.chat_id,
|
||||
timestamp_start=self.last_learning_time,
|
||||
@@ -328,17 +440,18 @@ class ExpressionLearner:
|
||||
if not random_msg or random_msg == []:
|
||||
return None
|
||||
# 转化成str
|
||||
chat_id: str = random_msg[0].chat_id
|
||||
_chat_id: str = random_msg[0].chat_id
|
||||
# random_msg_str: str = build_readable_messages(random_msg, timestamp_mode="normal")
|
||||
random_msg_str: str = await build_anonymous_messages(random_msg)
|
||||
# print(f"random_msg_str:{random_msg_str}")
|
||||
random_msg_match_str: str = await build_bare_messages(random_msg)
|
||||
|
||||
prompt: str = await global_prompt_manager.format_prompt(
|
||||
prompt,
|
||||
chat_str=random_msg_str,
|
||||
)
|
||||
|
||||
logger.debug(f"学习{type_str}的prompt: {prompt}")
|
||||
# print(f"random_msg_str:{random_msg_str}")
|
||||
# logger.info(f"学习{type_str}的prompt: {prompt}")
|
||||
|
||||
try:
|
||||
response, _ = await self.express_learn_model.generate_response_async(prompt, temperature=0.3)
|
||||
@@ -346,13 +459,48 @@ class ExpressionLearner:
|
||||
logger.error(f"学习{type_str}失败: {e}")
|
||||
return None
|
||||
|
||||
logger.debug(f"学习{type_str}的response: {response}")
|
||||
# logger.debug(f"学习{type_str}的response: {response}")
|
||||
|
||||
expressions: List[Tuple[str, str, str]] = self.parse_expression_response(response, chat_id)
|
||||
expressions: List[Tuple[str, str]] = self.parse_expression_response(response)
|
||||
|
||||
return expressions, chat_id
|
||||
matched_expressions: List[Tuple[str, str, str]] = await self.match_expression_context(
|
||||
expressions, random_msg_match_str
|
||||
)
|
||||
|
||||
def parse_expression_response(self, response: str, chat_id: str) -> List[Tuple[str, str, str]]:
|
||||
split_matched_expressions: List[Tuple[str, str, str, List[str]]] = self.split_expression_context(
|
||||
matched_expressions
|
||||
)
|
||||
|
||||
split_matched_expressions_w_emb = []
|
||||
|
||||
for situation, style, context, context_words in split_matched_expressions:
|
||||
split_matched_expressions_w_emb.append(
|
||||
(self.chat_id, situation, style, context, context_words)
|
||||
)
|
||||
|
||||
return split_matched_expressions_w_emb
|
||||
|
||||
def split_expression_context(
|
||||
self, matched_expressions: List[Tuple[str, str, str]]
|
||||
) -> List[Tuple[str, str, str, List[str]]]:
|
||||
"""
|
||||
对matched_expressions中的context部分进行jieba分词
|
||||
|
||||
Args:
|
||||
matched_expressions: 匹配到的表达方式列表,每个元素为(situation, style, context)
|
||||
|
||||
Returns:
|
||||
添加了分词结果的表达方式列表,每个元素为(situation, style, context, context_words)
|
||||
"""
|
||||
result = []
|
||||
for situation, style, context in matched_expressions:
|
||||
# 使用jieba进行分词
|
||||
context_words = list(jieba.cut(context))
|
||||
result.append((situation, style, context, context_words))
|
||||
|
||||
return result
|
||||
|
||||
def parse_expression_response(self, response: str) -> List[Tuple[str, str, str]]:
|
||||
"""
|
||||
解析LLM返回的表达风格总结,每一行提取"当"和"使用"之间的内容,存储为(situation, style)元组
|
||||
"""
|
||||
@@ -379,7 +527,7 @@ class ExpressionLearner:
|
||||
if idx_quote4 == -1:
|
||||
continue
|
||||
style = line[idx_quote3 + 1 : idx_quote4]
|
||||
expressions.append((chat_id, situation, style))
|
||||
expressions.append((situation, style))
|
||||
return expressions
|
||||
|
||||
|
||||
@@ -391,8 +539,6 @@ class ExpressionLearnerManager:
|
||||
self.expression_learners = {}
|
||||
|
||||
self._ensure_expression_directories()
|
||||
self._auto_migrate_json_to_db()
|
||||
self._migrate_old_data_create_date()
|
||||
|
||||
def get_expression_learner(self, chat_id: str) -> ExpressionLearner:
|
||||
if chat_id not in self.expression_learners:
|
||||
@@ -417,189 +563,5 @@ class ExpressionLearnerManager:
|
||||
except Exception as e:
|
||||
logger.error(f"创建目录失败 {directory}: {e}")
|
||||
|
||||
def _auto_migrate_json_to_db(self):
|
||||
"""
|
||||
自动将/data/expression/learnt_style 和 learnt_grammar 下所有expressions.json迁移到数据库。
|
||||
迁移完成后在/data/expression/done.done写入标记文件,存在则跳过。
|
||||
然后检查done.done2,如果没有就删除所有grammar表达并创建该标记文件。
|
||||
"""
|
||||
base_dir = os.path.join("data", "expression")
|
||||
done_flag = os.path.join(base_dir, "done.done")
|
||||
done_flag2 = os.path.join(base_dir, "done.done2")
|
||||
|
||||
# 确保基础目录存在
|
||||
try:
|
||||
os.makedirs(base_dir, exist_ok=True)
|
||||
logger.debug(f"确保目录存在: {base_dir}")
|
||||
except Exception as e:
|
||||
logger.error(f"创建表达方式目录失败: {e}")
|
||||
return
|
||||
|
||||
if os.path.exists(done_flag):
|
||||
logger.info("表达方式JSON已迁移,无需重复迁移。")
|
||||
else:
|
||||
logger.info("开始迁移表达方式JSON到数据库...")
|
||||
migrated_count = 0
|
||||
|
||||
for type in ["learnt_style", "learnt_grammar"]:
|
||||
type_str = "style" if type == "learnt_style" else "grammar"
|
||||
type_dir = os.path.join(base_dir, type)
|
||||
if not os.path.exists(type_dir):
|
||||
logger.debug(f"目录不存在,跳过: {type_dir}")
|
||||
continue
|
||||
|
||||
try:
|
||||
chat_ids = os.listdir(type_dir)
|
||||
logger.debug(f"在 {type_dir} 中找到 {len(chat_ids)} 个聊天ID目录")
|
||||
except Exception as e:
|
||||
logger.error(f"读取目录失败 {type_dir}: {e}")
|
||||
continue
|
||||
|
||||
for chat_id in chat_ids:
|
||||
expr_file = os.path.join(type_dir, chat_id, "expressions.json")
|
||||
if not os.path.exists(expr_file):
|
||||
continue
|
||||
try:
|
||||
with open(expr_file, "r", encoding="utf-8") as f:
|
||||
expressions = json.load(f)
|
||||
|
||||
if not isinstance(expressions, list):
|
||||
logger.warning(f"表达方式文件格式错误,跳过: {expr_file}")
|
||||
continue
|
||||
|
||||
for expr in expressions:
|
||||
if not isinstance(expr, dict):
|
||||
continue
|
||||
|
||||
situation = expr.get("situation")
|
||||
style_val = expr.get("style")
|
||||
count = expr.get("count", 1)
|
||||
last_active_time = expr.get("last_active_time", time.time())
|
||||
|
||||
if not situation or not style_val:
|
||||
logger.warning(f"表达方式缺少必要字段,跳过: {expr}")
|
||||
continue
|
||||
|
||||
# 查重:同chat_id+type+situation+style
|
||||
from src.common.database.database_model import Expression
|
||||
|
||||
query = Expression.select().where(
|
||||
(Expression.chat_id == chat_id)
|
||||
& (Expression.type == type_str)
|
||||
& (Expression.situation == situation)
|
||||
& (Expression.style == style_val)
|
||||
)
|
||||
if query.exists():
|
||||
expr_obj = query.get()
|
||||
expr_obj.count = max(expr_obj.count, count)
|
||||
expr_obj.last_active_time = max(expr_obj.last_active_time, last_active_time)
|
||||
expr_obj.save()
|
||||
else:
|
||||
Expression.create(
|
||||
situation=situation,
|
||||
style=style_val,
|
||||
count=count,
|
||||
last_active_time=last_active_time,
|
||||
chat_id=chat_id,
|
||||
type=type_str,
|
||||
create_date=last_active_time, # 迁移时使用last_active_time作为创建时间
|
||||
)
|
||||
migrated_count += 1
|
||||
logger.info(f"已迁移 {expr_file} 到数据库,包含 {len(expressions)} 个表达方式")
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"JSON解析失败 {expr_file}: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"迁移表达方式 {expr_file} 失败: {e}")
|
||||
|
||||
# 标记迁移完成
|
||||
try:
|
||||
# 确保done.done文件的父目录存在
|
||||
done_parent_dir = os.path.dirname(done_flag)
|
||||
if not os.path.exists(done_parent_dir):
|
||||
os.makedirs(done_parent_dir, exist_ok=True)
|
||||
logger.debug(f"为done.done创建父目录: {done_parent_dir}")
|
||||
|
||||
with open(done_flag, "w", encoding="utf-8") as f:
|
||||
f.write("done\n")
|
||||
logger.info(f"表达方式JSON迁移已完成,共迁移 {migrated_count} 个表达方式,已写入done.done标记文件")
|
||||
except PermissionError as e:
|
||||
logger.error(f"权限不足,无法写入done.done标记文件: {e}")
|
||||
except OSError as e:
|
||||
logger.error(f"文件系统错误,无法写入done.done标记文件: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"写入done.done标记文件失败: {e}")
|
||||
|
||||
# 检查并处理grammar表达删除
|
||||
if not os.path.exists(done_flag2):
|
||||
logger.info("开始删除所有grammar类型的表达...")
|
||||
try:
|
||||
deleted_count = self.delete_all_grammar_expressions()
|
||||
logger.info(f"grammar表达删除完成,共删除 {deleted_count} 个表达")
|
||||
|
||||
# 创建done.done2标记文件
|
||||
with open(done_flag2, "w", encoding="utf-8") as f:
|
||||
f.write("done\n")
|
||||
logger.info("已创建done.done2标记文件,grammar表达删除标记完成")
|
||||
except Exception as e:
|
||||
logger.error(f"删除grammar表达或创建标记文件失败: {e}")
|
||||
else:
|
||||
logger.info("grammar表达已删除,跳过重复删除")
|
||||
|
||||
def _migrate_old_data_create_date(self):
|
||||
"""
|
||||
为没有create_date的老数据设置创建日期
|
||||
使用last_active_time作为create_date的默认值
|
||||
"""
|
||||
try:
|
||||
# 查找所有create_date为空的表达方式
|
||||
old_expressions = Expression.select().where(Expression.create_date.is_null())
|
||||
updated_count = 0
|
||||
|
||||
for expr in old_expressions:
|
||||
# 使用last_active_time作为create_date
|
||||
expr.create_date = expr.last_active_time
|
||||
expr.save()
|
||||
updated_count += 1
|
||||
|
||||
if updated_count > 0:
|
||||
logger.info(f"已为 {updated_count} 个老的表达方式设置创建日期")
|
||||
except Exception as e:
|
||||
logger.error(f"迁移老数据创建日期失败: {e}")
|
||||
|
||||
def delete_all_grammar_expressions(self) -> int:
|
||||
"""
|
||||
检查expression库中所有type为"grammar"的表达并全部删除
|
||||
|
||||
Returns:
|
||||
int: 删除的grammar表达数量
|
||||
"""
|
||||
try:
|
||||
# 查询所有type为"grammar"的表达
|
||||
grammar_expressions = Expression.select().where(Expression.type == "grammar")
|
||||
grammar_count = grammar_expressions.count()
|
||||
|
||||
if grammar_count == 0:
|
||||
logger.info("expression库中没有找到grammar类型的表达")
|
||||
return 0
|
||||
|
||||
logger.info(f"找到 {grammar_count} 个grammar类型的表达,开始删除...")
|
||||
|
||||
# 删除所有grammar类型的表达
|
||||
deleted_count = 0
|
||||
for expr in grammar_expressions:
|
||||
try:
|
||||
expr.delete_instance()
|
||||
deleted_count += 1
|
||||
except Exception as e:
|
||||
logger.error(f"删除grammar表达失败: {e}")
|
||||
continue
|
||||
|
||||
logger.info(f"成功删除 {deleted_count} 个grammar类型的表达")
|
||||
return deleted_count
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"删除grammar表达过程中发生错误: {e}")
|
||||
return 0
|
||||
|
||||
|
||||
expression_learner_manager = ExpressionLearnerManager()
|
||||
|
||||
@@ -77,10 +77,10 @@ class ExpressionSelector:
|
||||
def can_use_expression_for_chat(self, chat_id: str) -> bool:
|
||||
"""
|
||||
检查指定聊天流是否允许使用表达
|
||||
|
||||
|
||||
Args:
|
||||
chat_id: 聊天流ID
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 是否允许使用表达
|
||||
"""
|
||||
@@ -114,6 +114,20 @@ class ExpressionSelector:
|
||||
def get_related_chat_ids(self, chat_id: str) -> List[str]:
|
||||
"""根据expression_groups配置,获取与当前chat_id相关的所有chat_id(包括自身)"""
|
||||
groups = global_config.expression.expression_groups
|
||||
|
||||
# 检查是否存在全局共享组(包含"*"的组)
|
||||
global_group_exists = any("*" in group for group in groups)
|
||||
|
||||
if global_group_exists:
|
||||
# 如果存在全局共享组,则返回所有可用的chat_id
|
||||
all_chat_ids = set()
|
||||
for group in groups:
|
||||
for stream_config_str in group:
|
||||
if chat_id_candidate := self._parse_stream_config_to_chat_id(stream_config_str):
|
||||
all_chat_ids.add(chat_id_candidate)
|
||||
return list(all_chat_ids) if all_chat_ids else [chat_id]
|
||||
|
||||
# 否则使用现有的组逻辑
|
||||
for group in groups:
|
||||
group_chat_ids = []
|
||||
for stream_config_str in group:
|
||||
@@ -123,9 +137,7 @@ class ExpressionSelector:
|
||||
return group_chat_ids
|
||||
return [chat_id]
|
||||
|
||||
def get_random_expressions(
|
||||
self, chat_id: str, total_num: int
|
||||
) -> List[Dict[str, Any]]:
|
||||
def get_random_expressions(self, chat_id: str, total_num: int) -> List[Dict[str, Any]]:
|
||||
# sourcery skip: extract-duplicate-method, move-assign
|
||||
# 支持多chat_id合并抽选
|
||||
related_chat_ids = self.get_related_chat_ids(chat_id)
|
||||
@@ -200,15 +212,15 @@ class ExpressionSelector:
|
||||
) -> Tuple[List[Dict[str, Any]], List[int]]:
|
||||
# sourcery skip: inline-variable, list-comprehension
|
||||
"""使用LLM选择适合的表达方式"""
|
||||
|
||||
|
||||
# 检查是否允许在此聊天流中使用表达
|
||||
if not self.can_use_expression_for_chat(chat_id):
|
||||
logger.debug(f"聊天流 {chat_id} 不允许使用表达,返回空列表")
|
||||
return [], []
|
||||
|
||||
# 1. 获取20个随机表达方式(现在按权重抽取)
|
||||
style_exprs = self.get_random_expressions(chat_id, 10)
|
||||
|
||||
style_exprs = self.get_random_expressions(chat_id, 20)
|
||||
|
||||
if len(style_exprs) < 10:
|
||||
logger.info(f"聊天流 {chat_id} 表达方式正在积累中")
|
||||
return [], []
|
||||
@@ -248,7 +260,6 @@ class ExpressionSelector:
|
||||
|
||||
# 4. 调用LLM
|
||||
try:
|
||||
|
||||
# start_time = time.time()
|
||||
content, (reasoning_content, model_name, _) = await self.llm_model.generate_response_async(prompt=prompt)
|
||||
# logger.info(f"LLM请求时间: {model_name} {time.time() - start_time} \n{prompt}")
|
||||
@@ -295,7 +306,6 @@ class ExpressionSelector:
|
||||
except Exception as e:
|
||||
logger.error(f"LLM处理表达方式选择时出错: {e}")
|
||||
return [], []
|
||||
|
||||
|
||||
|
||||
init_prompt()
|
||||
|
||||
@@ -1,122 +0,0 @@
|
||||
from typing import Optional
|
||||
from src.config.config import global_config
|
||||
from src.chat.frequency_control.utils import parse_stream_config_to_chat_id
|
||||
|
||||
|
||||
def get_config_base_focus_value(chat_id: Optional[str] = None) -> float:
|
||||
"""
|
||||
根据当前时间和聊天流获取对应的 focus_value
|
||||
"""
|
||||
if not global_config.chat.focus_value_adjust:
|
||||
return global_config.chat.focus_value
|
||||
|
||||
if chat_id:
|
||||
stream_focus_value = get_stream_specific_focus_value(chat_id)
|
||||
if stream_focus_value is not None:
|
||||
return stream_focus_value
|
||||
|
||||
global_focus_value = get_global_focus_value()
|
||||
if global_focus_value is not None:
|
||||
return global_focus_value
|
||||
|
||||
return global_config.chat.focus_value
|
||||
|
||||
|
||||
def get_stream_specific_focus_value(chat_id: str) -> Optional[float]:
|
||||
"""
|
||||
获取特定聊天流在当前时间的专注度
|
||||
|
||||
Args:
|
||||
chat_stream_id: 聊天流ID(哈希值)
|
||||
|
||||
Returns:
|
||||
float: 专注度值,如果没有配置则返回 None
|
||||
"""
|
||||
# 查找匹配的聊天流配置
|
||||
for config_item in global_config.chat.focus_value_adjust:
|
||||
if not config_item or len(config_item) < 2:
|
||||
continue
|
||||
|
||||
stream_config_str = config_item[0] # 例如 "qq:1026294844:group"
|
||||
|
||||
# 解析配置字符串并生成对应的 chat_id
|
||||
config_chat_id = parse_stream_config_to_chat_id(stream_config_str)
|
||||
if config_chat_id is None:
|
||||
continue
|
||||
|
||||
# 比较生成的 chat_id
|
||||
if config_chat_id != chat_id:
|
||||
continue
|
||||
|
||||
# 使用通用的时间专注度解析方法
|
||||
return get_time_based_focus_value(config_item[1:])
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_time_based_focus_value(time_focus_list: list[str]) -> Optional[float]:
|
||||
"""
|
||||
根据时间配置列表获取当前时段的专注度
|
||||
|
||||
Args:
|
||||
time_focus_list: 时间专注度配置列表,格式为 ["HH:MM,focus_value", ...]
|
||||
|
||||
Returns:
|
||||
float: 专注度值,如果没有配置则返回 None
|
||||
"""
|
||||
from datetime import datetime
|
||||
|
||||
current_time = datetime.now().strftime("%H:%M")
|
||||
current_hour, current_minute = map(int, current_time.split(":"))
|
||||
current_minutes = current_hour * 60 + current_minute
|
||||
|
||||
# 解析时间专注度配置
|
||||
time_focus_pairs = []
|
||||
for time_focus_str in time_focus_list:
|
||||
try:
|
||||
time_str, focus_str = time_focus_str.split(",")
|
||||
hour, minute = map(int, time_str.split(":"))
|
||||
focus_value = float(focus_str)
|
||||
minutes = hour * 60 + minute
|
||||
time_focus_pairs.append((minutes, focus_value))
|
||||
except (ValueError, IndexError):
|
||||
continue
|
||||
|
||||
if not time_focus_pairs:
|
||||
return None
|
||||
|
||||
# 按时间排序
|
||||
time_focus_pairs.sort(key=lambda x: x[0])
|
||||
|
||||
# 查找当前时间对应的专注度
|
||||
current_focus_value = None
|
||||
for minutes, focus_value in time_focus_pairs:
|
||||
if current_minutes >= minutes:
|
||||
current_focus_value = focus_value
|
||||
else:
|
||||
break
|
||||
|
||||
# 如果当前时间在所有配置时间之前,使用最后一个时间段的专注度(跨天逻辑)
|
||||
if current_focus_value is None and time_focus_pairs:
|
||||
current_focus_value = time_focus_pairs[-1][1]
|
||||
|
||||
return current_focus_value
|
||||
|
||||
|
||||
def get_global_focus_value() -> Optional[float]:
|
||||
"""
|
||||
获取全局默认专注度配置
|
||||
|
||||
Returns:
|
||||
float: 专注度值,如果没有配置则返回 None
|
||||
"""
|
||||
for config_item in global_config.chat.focus_value_adjust:
|
||||
if not config_item or len(config_item) < 2:
|
||||
continue
|
||||
|
||||
# 检查是否为全局默认配置(第一个元素为空字符串)
|
||||
if config_item[0] == "":
|
||||
return get_time_based_focus_value(config_item[1:])
|
||||
|
||||
return None
|
||||
|
||||
@@ -1,501 +1,46 @@
|
||||
import time
|
||||
from typing import Optional, Dict, List
|
||||
from src.plugin_system.apis import message_api
|
||||
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.chat.frequency_control.talk_frequency_control import get_config_base_talk_frequency
|
||||
from src.chat.frequency_control.focus_value_control import get_config_base_focus_value
|
||||
|
||||
logger = get_logger("frequency_control")
|
||||
from typing import Dict
|
||||
|
||||
|
||||
class FrequencyControl:
|
||||
"""
|
||||
频率控制类,可以根据最近时间段的发言数量和发言人数动态调整频率
|
||||
|
||||
特点:
|
||||
- 发言频率调整:基于最近10分钟的数据,评估单位为"消息数/10分钟"
|
||||
- 专注度调整:基于最近10分钟的数据,评估单位为"消息数/10分钟"
|
||||
- 历史基准值:基于最近一周的数据,按小时统计,每小时都有独立的基准值(需要至少50条历史消息)
|
||||
- 统一标准:两个调整都使用10分钟窗口,确保逻辑一致性和响应速度
|
||||
- 双向调整:根据活跃度高低,既能提高也能降低频率和专注度
|
||||
- 数据充足性检查:当历史数据不足50条时,不更新基准值;当基准值为默认值时,不进行动态调整
|
||||
- 基准值更新:直接使用新计算的周均值,无平滑更新
|
||||
"""
|
||||
|
||||
"""简化的频率控制类,仅管理不同chat_id的频率值"""
|
||||
|
||||
def __init__(self, chat_id: str):
|
||||
self.chat_id = chat_id
|
||||
self.chat_stream: ChatStream = get_chat_manager().get_stream(self.chat_id)
|
||||
if not self.chat_stream:
|
||||
raise ValueError(f"无法找到聊天流: {chat_id}")
|
||||
self.log_prefix = f"[{get_chat_manager().get_stream_name(self.chat_id) or self.chat_id}]"
|
||||
# 发言频率调整值
|
||||
self.talk_frequency_adjust: float = 1.0
|
||||
self.talk_frequency_external_adjust: float = 1.0
|
||||
# 专注度调整值
|
||||
self.focus_value_adjust: float = 1.0
|
||||
self.focus_value_external_adjust: float = 1.0
|
||||
|
||||
# 动态调整相关参数
|
||||
self.last_update_time = time.time()
|
||||
self.update_interval = 60 # 每60秒更新一次
|
||||
|
||||
# 历史数据缓存
|
||||
self._message_count_cache = 0
|
||||
self._user_count_cache = 0
|
||||
self._last_cache_time = 0
|
||||
self._cache_duration = 30 # 缓存30秒
|
||||
|
||||
# 调整参数
|
||||
self.min_adjust = 0.3 # 最小调整值
|
||||
self.max_adjust = 2.0 # 最大调整值
|
||||
|
||||
# 动态基准值(将根据历史数据计算)
|
||||
self.base_message_count = 5 # 默认基准消息数量,将被动态更新
|
||||
self.base_user_count = 3 # 默认基准用户数量,将被动态更新
|
||||
|
||||
# 平滑因子
|
||||
self.smoothing_factor = 0.3
|
||||
|
||||
# 历史数据相关参数
|
||||
self._last_historical_update = 0
|
||||
self._historical_update_interval = 600 # 每十分钟更新一次历史基准值
|
||||
self._historical_days = 7 # 使用最近7天的数据计算基准值
|
||||
|
||||
# 按小时统计的历史基准值
|
||||
self._hourly_baseline = {
|
||||
'messages': {}, # {0-23: 平均消息数}
|
||||
'users': {} # {0-23: 平均用户数}
|
||||
}
|
||||
|
||||
# 初始化24小时的默认基准值
|
||||
for hour in range(24):
|
||||
self._hourly_baseline['messages'][hour] = 0.0
|
||||
self._hourly_baseline['users'][hour] = 0.0
|
||||
|
||||
def _update_historical_baseline(self):
|
||||
"""
|
||||
更新基于历史数据的基准值
|
||||
使用最近一周的数据,按小时统计平均消息数量和用户数量
|
||||
"""
|
||||
current_time = time.time()
|
||||
|
||||
# 检查是否需要更新历史基准值
|
||||
if current_time - self._last_historical_update < self._historical_update_interval:
|
||||
return
|
||||
|
||||
try:
|
||||
# 计算一周前的时间戳
|
||||
week_ago = current_time - (self._historical_days * 24 * 3600)
|
||||
|
||||
# 获取最近一周的消息数据
|
||||
historical_messages = message_api.get_messages_by_time_in_chat(
|
||||
chat_id=self.chat_stream.stream_id,
|
||||
start_time=week_ago,
|
||||
end_time=current_time,
|
||||
filter_mai=True,
|
||||
filter_command=True
|
||||
)
|
||||
|
||||
if historical_messages and len(historical_messages) >= 50:
|
||||
# 按小时统计消息数和用户数
|
||||
hourly_stats = {hour: {'messages': [], 'users': set()} for hour in range(24)}
|
||||
|
||||
for msg in historical_messages:
|
||||
# 获取消息的小时(UTC时间)
|
||||
msg_time = time.localtime(msg.time)
|
||||
msg_hour = msg_time.tm_hour
|
||||
|
||||
# 统计消息数
|
||||
hourly_stats[msg_hour]['messages'].append(msg)
|
||||
|
||||
# 统计用户数
|
||||
if msg.user_info and msg.user_info.user_id:
|
||||
hourly_stats[msg_hour]['users'].add(msg.user_info.user_id)
|
||||
|
||||
# 计算每个小时的平均值(基于一周的数据)
|
||||
for hour in range(24):
|
||||
# 计算该小时的平均消息数(一周内该小时的总消息数 / 7天)
|
||||
total_messages = len(hourly_stats[hour]['messages'])
|
||||
total_users = len(hourly_stats[hour]['users'])
|
||||
|
||||
# 只计算有消息的时段,没有消息的时段设为0
|
||||
if total_messages > 0:
|
||||
avg_messages = total_messages / self._historical_days
|
||||
avg_users = total_users / self._historical_days
|
||||
self._hourly_baseline['messages'][hour] = avg_messages
|
||||
self._hourly_baseline['users'][hour] = avg_users
|
||||
else:
|
||||
# 没有消息的时段设为0,表示该时段不活跃
|
||||
self._hourly_baseline['messages'][hour] = 0.0
|
||||
self._hourly_baseline['users'][hour] = 0.0
|
||||
|
||||
# 更新整体基准值(用于兼容性)- 基于原始数据计算,不受max(1.0)限制影响
|
||||
overall_avg_messages = sum(len(hourly_stats[hour]['messages']) for hour in range(24)) / (24 * self._historical_days)
|
||||
overall_avg_users = sum(len(hourly_stats[hour]['users']) for hour in range(24)) / (24 * self._historical_days)
|
||||
|
||||
self.base_message_count = overall_avg_messages
|
||||
self.base_user_count = overall_avg_users
|
||||
|
||||
logger.info(
|
||||
f"{self.log_prefix} 历史基准值更新完成: "
|
||||
f"整体平均消息数={overall_avg_messages:.2f}, 整体平均用户数={overall_avg_users:.2f}"
|
||||
)
|
||||
|
||||
# 记录几个关键时段的基准值
|
||||
key_hours = [8, 12, 18, 22] # 早、中、晚、夜
|
||||
for hour in key_hours:
|
||||
# 计算该小时平均每10分钟的消息数和用户数
|
||||
hourly_10min_messages = self._hourly_baseline['messages'][hour] / 6 # 1小时 = 6个10分钟
|
||||
hourly_10min_users = self._hourly_baseline['users'][hour] / 6
|
||||
logger.info(
|
||||
f"{self.log_prefix} {hour}时基准值: "
|
||||
f"消息数={self._hourly_baseline['messages'][hour]:.2f}/小时 "
|
||||
f"({hourly_10min_messages:.2f}/10分钟), "
|
||||
f"用户数={self._hourly_baseline['users'][hour]:.2f}/小时 "
|
||||
f"({hourly_10min_users:.2f}/10分钟)"
|
||||
)
|
||||
|
||||
elif historical_messages and len(historical_messages) < 50:
|
||||
# 历史数据不足50条,不更新基准值
|
||||
logger.info(f"{self.log_prefix} 历史数据不足50条({len(historical_messages)}条),不更新基准值")
|
||||
else:
|
||||
# 如果没有历史数据,不更新基准值
|
||||
logger.info(f"{self.log_prefix} 无历史数据,不更新基准值")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 更新历史基准值时出错: {e}")
|
||||
# 出错时保持原有基准值不变
|
||||
|
||||
self._last_historical_update = current_time
|
||||
|
||||
def _get_current_hour_baseline(self) -> tuple[float, float]:
|
||||
"""
|
||||
获取当前小时的基准值
|
||||
|
||||
Returns:
|
||||
tuple: (基准消息数, 基准用户数)
|
||||
"""
|
||||
current_hour = time.localtime().tm_hour
|
||||
return (
|
||||
self._hourly_baseline['messages'][current_hour],
|
||||
self._hourly_baseline['users'][current_hour]
|
||||
)
|
||||
|
||||
def get_dynamic_talk_frequency_adjust(self) -> float:
|
||||
"""
|
||||
获取纯动态调整值(不包含配置文件基础值)
|
||||
|
||||
Returns:
|
||||
float: 动态调整值
|
||||
"""
|
||||
self._update_talk_frequency_adjust()
|
||||
def get_talk_frequency_adjust(self) -> float:
|
||||
"""获取发言频率调整值"""
|
||||
return self.talk_frequency_adjust
|
||||
|
||||
def get_dynamic_focus_value_adjust(self) -> float:
|
||||
"""
|
||||
获取纯动态调整值(不包含配置文件基础值)
|
||||
|
||||
Returns:
|
||||
float: 动态调整值
|
||||
"""
|
||||
self._update_focus_value_adjust()
|
||||
return self.focus_value_adjust
|
||||
|
||||
def _update_talk_frequency_adjust(self):
|
||||
"""
|
||||
更新发言频率调整值
|
||||
适合人少话多的时候:人少但消息多,提高回复频率
|
||||
"""
|
||||
current_time = time.time()
|
||||
|
||||
# 检查是否需要更新
|
||||
if current_time - self.last_update_time < self.update_interval:
|
||||
return
|
||||
|
||||
# 先更新历史基准值
|
||||
self._update_historical_baseline()
|
||||
|
||||
try:
|
||||
# 获取最近10分钟的数据(发言频率更敏感)
|
||||
recent_messages = message_api.get_messages_by_time_in_chat(
|
||||
chat_id=self.chat_stream.stream_id,
|
||||
start_time=current_time - 600, # 10分钟前
|
||||
end_time=current_time,
|
||||
filter_mai=True,
|
||||
filter_command=True
|
||||
)
|
||||
|
||||
# 计算消息数量和用户数量
|
||||
message_count = len(recent_messages)
|
||||
user_ids = set()
|
||||
for msg in recent_messages:
|
||||
if msg.user_info and msg.user_info.user_id:
|
||||
user_ids.add(msg.user_info.user_id)
|
||||
user_count = len(user_ids)
|
||||
|
||||
# 获取当前小时的基准值
|
||||
current_hour_base_messages, current_hour_base_users = self._get_current_hour_baseline()
|
||||
|
||||
# 计算当前小时平均每10分钟的基准值
|
||||
current_hour_10min_messages = current_hour_base_messages / 6 # 1小时 = 6个10分钟
|
||||
current_hour_10min_users = current_hour_base_users / 6
|
||||
|
||||
# 发言频率调整逻辑:根据活跃度双向调整
|
||||
# 检查是否有足够的数据进行分析
|
||||
if user_count > 0 and message_count >= 2: # 至少需要2条消息才能进行有意义的分析
|
||||
# 检查历史基准值是否有效(该时段有活跃度)
|
||||
if current_hour_base_messages > 0.0 and current_hour_base_users > 0.0:
|
||||
# 计算人均消息数(10分钟窗口)
|
||||
messages_per_user = message_count / user_count
|
||||
# 使用当前小时每10分钟的基准人均消息数
|
||||
base_messages_per_user = current_hour_10min_messages / current_hour_10min_users if current_hour_10min_users > 0 else 1.0
|
||||
|
||||
# 双向调整逻辑
|
||||
if messages_per_user > base_messages_per_user * 1.2:
|
||||
# 活跃度很高:提高回复频率
|
||||
target_talk_adjust = min(self.max_adjust, messages_per_user / base_messages_per_user)
|
||||
elif messages_per_user < base_messages_per_user * 0.8:
|
||||
# 活跃度很低:降低回复频率
|
||||
target_talk_adjust = max(self.min_adjust, messages_per_user / base_messages_per_user)
|
||||
else:
|
||||
# 活跃度正常:保持正常
|
||||
target_talk_adjust = 1.0
|
||||
else:
|
||||
# 历史基准值不足,不调整
|
||||
target_talk_adjust = 1.0
|
||||
else:
|
||||
# 数据不足:不调整
|
||||
target_talk_adjust = 1.0
|
||||
|
||||
# 限制调整范围
|
||||
target_talk_adjust = max(self.min_adjust, min(self.max_adjust, target_talk_adjust))
|
||||
|
||||
# 记录调整前的值
|
||||
old_adjust = self.talk_frequency_adjust
|
||||
|
||||
# 平滑调整
|
||||
self.talk_frequency_adjust = (
|
||||
self.talk_frequency_adjust * (1 - self.smoothing_factor) +
|
||||
target_talk_adjust * self.smoothing_factor
|
||||
)
|
||||
|
||||
# 判断调整方向
|
||||
if target_talk_adjust > 1.0:
|
||||
adjust_direction = "提高"
|
||||
elif target_talk_adjust < 1.0:
|
||||
adjust_direction = "降低"
|
||||
else:
|
||||
if current_hour_base_messages <= 0.0 or current_hour_base_users <= 0.0:
|
||||
adjust_direction = "不调整(该时段无活跃度)"
|
||||
else:
|
||||
adjust_direction = "保持"
|
||||
|
||||
# 计算实际变化方向
|
||||
actual_change = ""
|
||||
if self.talk_frequency_adjust > old_adjust:
|
||||
actual_change = f"(实际提高: {old_adjust:.2f}→{self.talk_frequency_adjust:.2f})"
|
||||
elif self.talk_frequency_adjust < old_adjust:
|
||||
actual_change = f"(实际降低: {old_adjust:.2f}→{self.talk_frequency_adjust:.2f})"
|
||||
else:
|
||||
actual_change = f"(无变化: {self.talk_frequency_adjust:.2f})"
|
||||
|
||||
logger.info(
|
||||
f"{self.log_prefix} 发言频率调整: "
|
||||
f"当前: {message_count}消息/{user_count}用户, 人均: {message_count/user_count if user_count > 0 else 0:.2f}消息|"
|
||||
f"基准: {current_hour_10min_messages:.2f}消息/{current_hour_10min_users:.2f}用户,人均:{current_hour_10min_messages/current_hour_10min_users if current_hour_10min_users > 0 else 0:.2f}消息|"
|
||||
f"目标调整: {adjust_direction}到{target_talk_adjust:.2f}, 实际结果: {self.talk_frequency_adjust:.2f} {actual_change}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 更新发言频率调整值时出错: {e}")
|
||||
|
||||
def _update_focus_value_adjust(self):
|
||||
"""
|
||||
更新专注度调整值
|
||||
适合人多话多的时候:人多且消息多,提高专注度(LLM消耗更多,但回复更精准)
|
||||
"""
|
||||
current_time = time.time()
|
||||
|
||||
# 检查是否需要更新
|
||||
if current_time - self.last_update_time < self.update_interval:
|
||||
return
|
||||
|
||||
try:
|
||||
# 获取最近10分钟的数据(与发言频率保持一致)
|
||||
recent_messages = message_api.get_messages_by_time_in_chat(
|
||||
chat_id=self.chat_stream.stream_id,
|
||||
start_time=current_time - 600, # 10分钟前
|
||||
end_time=current_time,
|
||||
filter_mai=True,
|
||||
filter_command=True
|
||||
)
|
||||
|
||||
# 计算消息数量和用户数量
|
||||
message_count = len(recent_messages)
|
||||
user_ids = set()
|
||||
for msg in recent_messages:
|
||||
if msg.user_info and msg.user_info.user_id:
|
||||
user_ids.add(msg.user_info.user_id)
|
||||
user_count = len(user_ids)
|
||||
|
||||
# 获取当前小时的基准值
|
||||
current_hour_base_messages, current_hour_base_users = self._get_current_hour_baseline()
|
||||
|
||||
# 计算当前小时平均每10分钟的基准值
|
||||
current_hour_10min_messages = current_hour_base_messages / 6 # 1小时 = 6个10分钟
|
||||
current_hour_10min_users = current_hour_base_users / 6
|
||||
|
||||
# 专注度调整逻辑:根据活跃度双向调整
|
||||
# 检查是否有足够的数据进行分析
|
||||
if user_count > 0 and current_hour_10min_users > 0 and message_count >= 2:
|
||||
# 检查历史基准值是否有效(该时段有活跃度)
|
||||
if current_hour_base_messages > 0.0 and current_hour_base_users > 0.0:
|
||||
# 计算用户活跃度比率(基于10分钟数据)
|
||||
user_ratio = user_count / current_hour_10min_users
|
||||
# 计算消息活跃度比率(基于10分钟数据)
|
||||
message_ratio = message_count / current_hour_10min_messages if current_hour_10min_messages > 0 else 1.0
|
||||
|
||||
# 双向调整逻辑
|
||||
if user_ratio > 1.3 and message_ratio > 1.3:
|
||||
# 活跃度很高:提高专注度,消耗更多LLM资源但回复更精准
|
||||
target_focus_adjust = min(self.max_adjust, (user_ratio + message_ratio) / 2)
|
||||
elif user_ratio > 1.1 and message_ratio > 1.1:
|
||||
# 活跃度较高:适度提高专注度
|
||||
target_focus_adjust = min(self.max_adjust, 1.0 + (user_ratio + message_ratio - 2.0) * 0.2)
|
||||
elif user_ratio < 0.7 or message_ratio < 0.7:
|
||||
# 活跃度很低:降低专注度,节省LLM资源
|
||||
target_focus_adjust = max(self.min_adjust, min(user_ratio, message_ratio))
|
||||
else:
|
||||
# 正常情况:保持默认专注度
|
||||
target_focus_adjust = 1.0
|
||||
else:
|
||||
# 历史基准值不足,不调整
|
||||
target_focus_adjust = 1.0
|
||||
else:
|
||||
# 数据不足:不调整
|
||||
target_focus_adjust = 1.0
|
||||
|
||||
# 限制调整范围
|
||||
target_focus_adjust = max(self.min_adjust, min(self.max_adjust, target_focus_adjust))
|
||||
|
||||
# 记录调整前的值
|
||||
old_focus_adjust = self.focus_value_adjust
|
||||
|
||||
# 平滑调整
|
||||
self.focus_value_adjust = (
|
||||
self.focus_value_adjust * (1 - self.smoothing_factor) +
|
||||
target_focus_adjust * self.smoothing_factor
|
||||
)
|
||||
|
||||
# 计算当前小时平均每10分钟的基准值
|
||||
current_hour_10min_messages = current_hour_base_messages / 6 # 1小时 = 6个10分钟
|
||||
current_hour_10min_users = current_hour_base_users / 6
|
||||
|
||||
# 判断调整方向
|
||||
if target_focus_adjust > 1.0:
|
||||
adjust_direction = "提高"
|
||||
elif target_focus_adjust < 1.0:
|
||||
adjust_direction = "降低"
|
||||
else:
|
||||
if current_hour_base_messages <= 0.0 or current_hour_base_users <= 0.0:
|
||||
adjust_direction = "不调整(该时段无活跃度)"
|
||||
else:
|
||||
adjust_direction = "保持"
|
||||
|
||||
# 计算实际变化方向
|
||||
actual_change = ""
|
||||
if self.focus_value_adjust > old_focus_adjust:
|
||||
actual_change = f"(实际提高: {old_focus_adjust:.2f}→{self.focus_value_adjust:.2f})"
|
||||
elif self.focus_value_adjust < old_focus_adjust:
|
||||
actual_change = f"(实际降低: {old_focus_adjust:.2f}→{self.focus_value_adjust:.2f})"
|
||||
else:
|
||||
actual_change = f"(无变化: {self.focus_value_adjust:.2f})"
|
||||
|
||||
logger.info(
|
||||
f"{self.log_prefix} 专注度调整(10分钟): "
|
||||
f"当前: {message_count}消息/{user_count}用户,人均:{message_count/user_count if user_count > 0 else 0:.2f}消息|"
|
||||
f"基准: {current_hour_10min_messages:.2f}消息/{current_hour_10min_users:.2f}用户,人均:{current_hour_10min_messages/current_hour_10min_users if current_hour_10min_users > 0 else 0:.2f}消息|"
|
||||
f"比率: 用户{user_count/current_hour_10min_users if current_hour_10min_users > 0 else 0:.2f}x, 消息{message_count/current_hour_10min_messages if current_hour_10min_messages > 0 else 0:.2f}x, "
|
||||
f"目标调整: {adjust_direction}到{target_focus_adjust:.2f}, 实际结果: {self.focus_value_adjust:.2f} {actual_change}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 更新专注度调整值时出错: {e}")
|
||||
|
||||
def get_final_talk_frequency(self) -> float:
|
||||
return get_config_base_talk_frequency(self.chat_stream.stream_id) * self.get_dynamic_talk_frequency_adjust() * self.talk_frequency_external_adjust
|
||||
|
||||
def get_final_focus_value(self) -> float:
|
||||
return get_config_base_focus_value(self.chat_stream.stream_id) * self.get_dynamic_focus_value_adjust() * self.focus_value_external_adjust
|
||||
|
||||
|
||||
def set_adjustment_parameters(
|
||||
self,
|
||||
min_adjust: Optional[float] = None,
|
||||
max_adjust: Optional[float] = None,
|
||||
base_message_count: Optional[int] = None,
|
||||
base_user_count: Optional[int] = None,
|
||||
smoothing_factor: Optional[float] = None,
|
||||
update_interval: Optional[int] = None,
|
||||
historical_update_interval: Optional[int] = None,
|
||||
historical_days: Optional[int] = None
|
||||
):
|
||||
"""
|
||||
设置调整参数
|
||||
|
||||
Args:
|
||||
min_adjust: 最小调整值
|
||||
max_adjust: 最大调整值
|
||||
base_message_count: 基准消息数量
|
||||
base_user_count: 基准用户数量
|
||||
smoothing_factor: 平滑因子
|
||||
update_interval: 更新间隔(秒)
|
||||
"""
|
||||
if min_adjust is not None:
|
||||
self.min_adjust = max(0.1, min_adjust)
|
||||
if max_adjust is not None:
|
||||
self.max_adjust = max(1.0, max_adjust)
|
||||
if base_message_count is not None:
|
||||
self.base_message_count = max(1, base_message_count)
|
||||
if base_user_count is not None:
|
||||
self.base_user_count = max(1, base_user_count)
|
||||
if smoothing_factor is not None:
|
||||
self.smoothing_factor = max(0.0, min(1.0, smoothing_factor))
|
||||
if update_interval is not None:
|
||||
self.update_interval = max(10, update_interval)
|
||||
if historical_update_interval is not None:
|
||||
self._historical_update_interval = max(300, historical_update_interval) # 最少5分钟
|
||||
if historical_days is not None:
|
||||
self._historical_days = max(1, min(30, historical_days)) # 1-30天之间
|
||||
def set_talk_frequency_adjust(self, value: float) -> None:
|
||||
"""设置发言频率调整值"""
|
||||
self.talk_frequency_adjust = max(0.1, min(5.0, value))
|
||||
|
||||
|
||||
class FrequencyControlManager:
|
||||
"""
|
||||
频率控制管理器,管理多个聊天流的频率控制实例
|
||||
"""
|
||||
|
||||
"""频率控制管理器,管理多个聊天流的频率控制实例"""
|
||||
|
||||
def __init__(self):
|
||||
self.frequency_control_dict: Dict[str, FrequencyControl] = {}
|
||||
|
||||
def get_or_create_frequency_control(self, chat_id: str) -> FrequencyControl:
|
||||
"""
|
||||
获取或创建指定聊天流的频率控制实例
|
||||
|
||||
Args:
|
||||
chat_id: 聊天流ID
|
||||
|
||||
Returns:
|
||||
FrequencyControl: 频率控制实例
|
||||
"""
|
||||
"""获取或创建指定聊天流的频率控制实例"""
|
||||
if chat_id not in self.frequency_control_dict:
|
||||
self.frequency_control_dict[chat_id] = FrequencyControl(chat_id)
|
||||
return self.frequency_control_dict[chat_id]
|
||||
|
||||
def remove_frequency_control(self, chat_id: str) -> bool:
|
||||
"""移除指定聊天流的频率控制实例"""
|
||||
if chat_id in self.frequency_control_dict:
|
||||
del self.frequency_control_dict[chat_id]
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_all_chat_ids(self) -> list[str]:
|
||||
"""获取所有有频率控制的聊天ID"""
|
||||
return list(self.frequency_control_dict.keys())
|
||||
|
||||
|
||||
# 创建全局实例
|
||||
frequency_control_manager = FrequencyControlManager()
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -1,128 +0,0 @@
|
||||
from typing import Optional
|
||||
from src.config.config import global_config
|
||||
from src.chat.frequency_control.utils import parse_stream_config_to_chat_id
|
||||
|
||||
|
||||
def get_config_base_talk_frequency(chat_id: Optional[str] = None) -> float:
|
||||
"""
|
||||
根据当前时间和聊天流获取对应的 talk_frequency
|
||||
|
||||
Args:
|
||||
chat_stream_id: 聊天流ID,格式为 "platform:chat_id:type"
|
||||
|
||||
Returns:
|
||||
float: 对应的频率值
|
||||
"""
|
||||
if not global_config.chat.talk_frequency_adjust:
|
||||
return global_config.chat.talk_frequency
|
||||
|
||||
# 优先检查聊天流特定的配置
|
||||
if chat_id:
|
||||
stream_frequency = get_stream_specific_frequency(chat_id)
|
||||
if stream_frequency is not None:
|
||||
return stream_frequency
|
||||
|
||||
# 检查全局时段配置(第一个元素为空字符串的配置)
|
||||
global_frequency = get_global_frequency()
|
||||
return global_config.chat.talk_frequency if global_frequency is None else global_frequency
|
||||
|
||||
|
||||
def get_time_based_frequency(time_freq_list: list[str]) -> Optional[float]:
|
||||
"""
|
||||
根据时间配置列表获取当前时段的频率
|
||||
|
||||
Args:
|
||||
time_freq_list: 时间频率配置列表,格式为 ["HH:MM,frequency", ...]
|
||||
|
||||
Returns:
|
||||
float: 频率值,如果没有配置则返回 None
|
||||
"""
|
||||
from datetime import datetime
|
||||
|
||||
current_time = datetime.now().strftime("%H:%M")
|
||||
current_hour, current_minute = map(int, current_time.split(":"))
|
||||
current_minutes = current_hour * 60 + current_minute
|
||||
|
||||
# 解析时间频率配置
|
||||
time_freq_pairs = []
|
||||
for time_freq_str in time_freq_list:
|
||||
try:
|
||||
time_str, freq_str = time_freq_str.split(",")
|
||||
hour, minute = map(int, time_str.split(":"))
|
||||
frequency = float(freq_str)
|
||||
minutes = hour * 60 + minute
|
||||
time_freq_pairs.append((minutes, frequency))
|
||||
except (ValueError, IndexError):
|
||||
continue
|
||||
|
||||
if not time_freq_pairs:
|
||||
return None
|
||||
|
||||
# 按时间排序
|
||||
time_freq_pairs.sort(key=lambda x: x[0])
|
||||
|
||||
# 查找当前时间对应的频率
|
||||
current_frequency = None
|
||||
for minutes, frequency in time_freq_pairs:
|
||||
if current_minutes >= minutes:
|
||||
current_frequency = frequency
|
||||
else:
|
||||
break
|
||||
|
||||
# 如果当前时间在所有配置时间之前,使用最后一个时间段的频率(跨天逻辑)
|
||||
if current_frequency is None and time_freq_pairs:
|
||||
current_frequency = time_freq_pairs[-1][1]
|
||||
|
||||
return current_frequency
|
||||
|
||||
|
||||
def get_stream_specific_frequency(chat_stream_id: str):
|
||||
"""
|
||||
获取特定聊天流在当前时间的频率
|
||||
|
||||
Args:
|
||||
chat_stream_id: 聊天流ID(哈希值)
|
||||
|
||||
Returns:
|
||||
float: 频率值,如果没有配置则返回 None
|
||||
"""
|
||||
# 查找匹配的聊天流配置
|
||||
for config_item in global_config.chat.talk_frequency_adjust:
|
||||
if not config_item or len(config_item) < 2:
|
||||
continue
|
||||
|
||||
stream_config_str = config_item[0] # 例如 "qq:1026294844:group"
|
||||
|
||||
# 解析配置字符串并生成对应的 chat_id
|
||||
config_chat_id = parse_stream_config_to_chat_id(stream_config_str)
|
||||
if config_chat_id is None:
|
||||
continue
|
||||
|
||||
# 比较生成的 chat_id
|
||||
if config_chat_id != chat_stream_id:
|
||||
continue
|
||||
|
||||
# 使用通用的时间频率解析方法
|
||||
return get_time_based_frequency(config_item[1:])
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_global_frequency() -> Optional[float]:
|
||||
"""
|
||||
获取全局默认频率配置
|
||||
|
||||
Returns:
|
||||
float: 频率值,如果没有配置则返回 None
|
||||
"""
|
||||
for config_item in global_config.chat.talk_frequency_adjust:
|
||||
if not config_item or len(config_item) < 2:
|
||||
continue
|
||||
|
||||
# 检查是否为全局默认配置(第一个元素为空字符串)
|
||||
if config_item[0] == "":
|
||||
return get_time_based_frequency(config_item[1:])
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@@ -1,37 +0,0 @@
|
||||
from typing import Optional
|
||||
import hashlib
|
||||
|
||||
|
||||
def parse_stream_config_to_chat_id(stream_config_str: str) -> Optional[str]:
|
||||
"""
|
||||
解析流配置字符串并生成对应的 chat_id
|
||||
|
||||
Args:
|
||||
stream_config_str: 格式为 "platform:id:type" 的字符串
|
||||
|
||||
Returns:
|
||||
str: 生成的 chat_id,如果解析失败则返回 None
|
||||
"""
|
||||
try:
|
||||
parts = stream_config_str.split(":")
|
||||
if len(parts) != 3:
|
||||
return None
|
||||
|
||||
platform = parts[0]
|
||||
id_str = parts[1]
|
||||
stream_type = parts[2]
|
||||
|
||||
# 判断是否为群聊
|
||||
is_group = stream_type == "group"
|
||||
|
||||
# 使用与 ChatStream.get_stream_id 相同的逻辑生成 chat_id
|
||||
|
||||
if is_group:
|
||||
components = [platform, str(id_str)]
|
||||
else:
|
||||
components = [platform, str(id_str), "private"]
|
||||
key = "_".join(components)
|
||||
return hashlib.md5(key.encode()).hexdigest()
|
||||
|
||||
except (ValueError, IndexError):
|
||||
return None
|
||||
@@ -1,15 +1,15 @@
|
||||
import asyncio
|
||||
from multiprocessing import context
|
||||
import time
|
||||
import traceback
|
||||
import math
|
||||
import random
|
||||
from typing import List, Optional, Dict, Any, Tuple, TYPE_CHECKING
|
||||
from rich.traceback import install
|
||||
from collections import deque
|
||||
|
||||
from src.config.config import global_config
|
||||
from src.common.logger import get_logger
|
||||
from src.common.data_models.info_data_model import ActionPlannerInfo
|
||||
from src.common.data_models.message_data_model import ReplyContentType
|
||||
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
||||
from src.chat.utils.prompt_builder import global_prompt_manager
|
||||
from src.chat.utils.timer_calculator import Timer
|
||||
@@ -18,14 +18,15 @@ from src.chat.planner_actions.action_modifier import ActionModifier
|
||||
from src.chat.planner_actions.action_manager import ActionManager
|
||||
from src.chat.heart_flow.hfc_utils import CycleDetail
|
||||
from src.chat.heart_flow.hfc_utils import send_typing, stop_typing
|
||||
from src.chat.frequency_control.frequency_control import frequency_control_manager
|
||||
from src.chat.express.expression_learner import expression_learner_manager
|
||||
from src.chat.frequency_control.frequency_control import frequency_control_manager
|
||||
from src.memory_system.question_maker import QuestionMaker
|
||||
from src.memory_system.questions import global_conflict_tracker
|
||||
from src.person_info.person_info import Person
|
||||
from src.plugin_system.base.component_types import ChatMode, EventType, ActionInfo
|
||||
from src.plugin_system.base.component_types import EventType, ActionInfo
|
||||
from src.plugin_system.core import events_manager
|
||||
from src.plugin_system.apis import generator_api, send_api, message_api, database_api
|
||||
from src.mais4u.mai_think import mai_thinking_manager
|
||||
from src.mais4u.s4u_config import s4u_config
|
||||
from src.memory_system.Memory_chest import global_memory_chest
|
||||
from src.chat.utils.chat_message_builder import (
|
||||
build_readable_messages_with_id,
|
||||
get_raw_msg_before_timestamp_with_chat,
|
||||
@@ -33,6 +34,7 @@ from src.chat.utils.chat_message_builder import (
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.common.data_models.message_data_model import ReplySetModel
|
||||
|
||||
|
||||
ERROR_LOOP_INFO = {
|
||||
@@ -84,8 +86,6 @@ class HeartFChatting:
|
||||
|
||||
self.expression_learner = expression_learner_manager.get_expression_learner(self.stream_id)
|
||||
|
||||
self.frequency_control = frequency_control_manager.get_or_create_frequency_control(self.stream_id)
|
||||
|
||||
self.action_manager = ActionManager()
|
||||
self.action_planner = ActionPlanner(chat_id=self.stream_id, action_manager=self.action_manager)
|
||||
self.action_modifier = ActionModifier(action_manager=self.action_manager, chat_id=self.stream_id)
|
||||
@@ -99,8 +99,15 @@ class HeartFChatting:
|
||||
self._cycle_counter = 0
|
||||
self._current_cycle_detail: CycleDetail = None # type: ignore
|
||||
|
||||
self.last_read_time = time.time() - 10
|
||||
self.last_read_time = time.time() - 2
|
||||
self.no_reply_until_call = False
|
||||
|
||||
self.is_mute = False
|
||||
|
||||
self.last_active_time = time.time() # 记录上一次非noreply时间
|
||||
|
||||
self.questioned = False
|
||||
|
||||
|
||||
async def start(self):
|
||||
"""检查是否需要启动主循环,如果未激活则启动。"""
|
||||
@@ -153,63 +160,117 @@ class HeartFChatting:
|
||||
# 记录循环信息和计时器结果
|
||||
timer_strings = []
|
||||
for name, elapsed in cycle_timers.items():
|
||||
formatted_time = f"{elapsed * 1000:.2f}毫秒" if elapsed < 1 else f"{elapsed:.2f}秒"
|
||||
if elapsed < 0.1:
|
||||
# 不显示小于0.1秒的计时器
|
||||
continue
|
||||
formatted_time = f"{elapsed:.2f}秒"
|
||||
timer_strings.append(f"{name}: {formatted_time}")
|
||||
|
||||
# 获取动作类型,兼容新旧格式
|
||||
action_type = "未知动作"
|
||||
if hasattr(self, "_current_cycle_detail") and self._current_cycle_detail:
|
||||
loop_plan_info = self._current_cycle_detail.loop_plan_info
|
||||
if isinstance(loop_plan_info, dict):
|
||||
action_result = loop_plan_info.get("action_result", {})
|
||||
if isinstance(action_result, dict):
|
||||
# 旧格式:action_result是字典
|
||||
action_type = action_result.get("action_type", "未知动作")
|
||||
elif isinstance(action_result, list) and action_result:
|
||||
# 新格式:action_result是actions列表
|
||||
# TODO: 把这里写明白
|
||||
action_type = action_result[0].action_type or "未知动作"
|
||||
elif isinstance(loop_plan_info, list) and loop_plan_info:
|
||||
# 直接是actions列表的情况
|
||||
action_type = loop_plan_info[0].get("action_type", "未知动作")
|
||||
|
||||
logger.info(
|
||||
f"{self.log_prefix} 第{self._current_cycle_detail.cycle_id}次思考,"
|
||||
f"耗时: {self._current_cycle_detail.end_time - self._current_cycle_detail.start_time:.1f}秒, " # type: ignore
|
||||
f"选择动作: {action_type}" + (f"\n详情: {'; '.join(timer_strings)}" if timer_strings else "")
|
||||
f"耗时: {self._current_cycle_detail.end_time - self._current_cycle_detail.start_time:.1f}秒;" # type: ignore
|
||||
+ (f"详情: {'; '.join(timer_strings)}" if timer_strings else "")
|
||||
)
|
||||
|
||||
async def caculate_interest_value(self, recent_messages_list: List["DatabaseMessages"]) -> float:
|
||||
total_interest = 0.0
|
||||
for msg in recent_messages_list:
|
||||
interest_value = msg.interest_value
|
||||
if interest_value is not None and msg.processed_plain_text:
|
||||
total_interest += float(interest_value)
|
||||
return total_interest / len(recent_messages_list)
|
||||
|
||||
async def _loopbody(self):
|
||||
async def _loopbody(self):
|
||||
recent_messages_list = message_api.get_messages_by_time_in_chat(
|
||||
chat_id=self.stream_id,
|
||||
start_time=self.last_read_time,
|
||||
end_time=time.time(),
|
||||
limit=10,
|
||||
limit=20,
|
||||
limit_mode="latest",
|
||||
filter_mai=True,
|
||||
filter_command=True,
|
||||
)
|
||||
|
||||
if recent_messages_list:
|
||||
self.last_read_time = time.time()
|
||||
await self._observe(interest_value=await self.caculate_interest_value(recent_messages_list),recent_messages_list=recent_messages_list)
|
||||
|
||||
question_probability = 0
|
||||
if time.time() - self.last_active_time > 3600:
|
||||
question_probability = 0.01
|
||||
elif time.time() - self.last_active_time > 1200:
|
||||
question_probability = 0.005
|
||||
elif time.time() - self.last_active_time > 600:
|
||||
question_probability = 0.001
|
||||
else:
|
||||
question_probability = 0.0003
|
||||
|
||||
question_probability = question_probability * global_config.chat.get_auto_chat_value(self.stream_id)
|
||||
|
||||
# print(f"{self.log_prefix} questioned: {self.questioned},len: {len(global_conflict_tracker.get_questions_by_chat_id(self.stream_id))}")
|
||||
if question_probability > 0 and not self.questioned and len(global_conflict_tracker.get_questions_by_chat_id(self.stream_id)) == 0: #长久没有回复,可以试试主动发言,提问概率随着时间增加
|
||||
# logger.info(f"{self.log_prefix} 长久没有回复,可以试试主动发言,概率: {question_probability}")
|
||||
if random.random() < question_probability: # 30%概率主动发言
|
||||
try:
|
||||
self.questioned = True
|
||||
self.last_active_time = time.time()
|
||||
# print(f"{self.log_prefix} 长久没有回复,可以试试主动发言,开始生成问题")
|
||||
logger.info(f"{self.log_prefix} 长久没有回复,可以试试主动发言,开始生成问题")
|
||||
cycle_timers, thinking_id = self.start_cycle()
|
||||
question_maker = QuestionMaker(self.stream_id)
|
||||
question, context,conflict_context = await question_maker.make_question()
|
||||
if question:
|
||||
logger.info(f"{self.log_prefix} 问题: {question}")
|
||||
await global_conflict_tracker.track_conflict(question, conflict_context, True, self.stream_id)
|
||||
await self._lift_question_reply(question,context,cycle_timers,thinking_id)
|
||||
else:
|
||||
logger.info(f"{self.log_prefix} 无问题")
|
||||
# self.end_cycle(cycle_timers, thinking_id)
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 主动提问失败: {e}")
|
||||
print(traceback.format_exc())
|
||||
|
||||
|
||||
if len(recent_messages_list) >= 1:
|
||||
# for message in recent_messages_list:
|
||||
# print(message.processed_plain_text)
|
||||
# !处理no_reply_until_call逻辑
|
||||
if self.no_reply_until_call:
|
||||
for message in recent_messages_list:
|
||||
if (
|
||||
message.is_mentioned
|
||||
or message.is_at
|
||||
or len(recent_messages_list) >= 8
|
||||
or time.time() - self.last_read_time > 600
|
||||
):
|
||||
self.no_reply_until_call = False
|
||||
self.last_read_time = time.time()
|
||||
break
|
||||
# 没有提到,继续保持沉默
|
||||
if self.no_reply_until_call:
|
||||
# logger.info(f"{self.log_prefix} 没有提到,继续保持沉默")
|
||||
await asyncio.sleep(1)
|
||||
return True
|
||||
|
||||
self.last_read_time = time.time()
|
||||
|
||||
# !此处使at或者提及必定回复
|
||||
mentioned_message = None
|
||||
for message in recent_messages_list:
|
||||
if (message.is_mentioned or message.is_at) and global_config.chat.mentioned_bot_reply:
|
||||
mentioned_message = message
|
||||
|
||||
logger.info(f"{self.log_prefix} 当前talk_value: {global_config.chat.get_talk_value(self.stream_id)}")
|
||||
|
||||
# *控制频率用
|
||||
if mentioned_message:
|
||||
await self._observe(recent_messages_list=recent_messages_list, force_reply_message=mentioned_message)
|
||||
elif (
|
||||
random.random()
|
||||
< global_config.chat.get_talk_value(self.stream_id)
|
||||
* frequency_control_manager.get_or_create_frequency_control(self.stream_id).get_talk_frequency_adjust()
|
||||
):
|
||||
await self._observe(recent_messages_list=recent_messages_list)
|
||||
else:
|
||||
# 没有提到,继续保持沉默,等待5秒防止频繁触发
|
||||
await asyncio.sleep(10)
|
||||
return True
|
||||
else:
|
||||
# Normal模式:消息数量不足,等待
|
||||
await asyncio.sleep(0.2)
|
||||
return True
|
||||
return True
|
||||
|
||||
async def _send_and_store_reply(
|
||||
self,
|
||||
response_set,
|
||||
response_set: "ReplySetModel",
|
||||
action_message: "DatabaseMessages",
|
||||
cycle_timers: Dict[str, float],
|
||||
thinking_id,
|
||||
@@ -257,191 +318,166 @@ class HeartFChatting:
|
||||
|
||||
return loop_info, reply_text, cycle_timers
|
||||
|
||||
async def _observe(self, interest_value: float = 0.0,recent_messages_list: List["DatabaseMessages"] = []) -> bool:
|
||||
async def _observe(
|
||||
self, # interest_value: float = 0.0,
|
||||
recent_messages_list: Optional[List["DatabaseMessages"]] = None,
|
||||
force_reply_message: Optional["DatabaseMessages"] = None,
|
||||
) -> bool: # sourcery skip: merge-else-if-into-elif, remove-redundant-if
|
||||
if recent_messages_list is None:
|
||||
recent_messages_list = []
|
||||
reply_text = "" # 初始化reply_text变量,避免UnboundLocalError
|
||||
|
||||
# 使用sigmoid函数将interest_value转换为概率
|
||||
# 当interest_value为0时,概率接近0(使用Focus模式)
|
||||
# 当interest_value很高时,概率接近1(使用Normal模式)
|
||||
def calculate_normal_mode_probability(interest_val: float) -> float:
|
||||
# 使用sigmoid函数,调整参数使概率分布更合理
|
||||
# 当interest_value = 0时,概率约为0.1
|
||||
# 当interest_value = 1时,概率约为0.5
|
||||
# 当interest_value = 2时,概率约为0.8
|
||||
# 当interest_value = 3时,概率约为0.95
|
||||
k = 2.0 # 控制曲线陡峭程度
|
||||
x0 = 1.0 # 控制曲线中心点
|
||||
return 1.0 / (1.0 + math.exp(-k * (interest_val - x0)))
|
||||
start_time = time.time()
|
||||
|
||||
normal_mode_probability = (
|
||||
calculate_normal_mode_probability(interest_value)
|
||||
* 2
|
||||
* self.frequency_control.get_final_talk_frequency()
|
||||
)
|
||||
|
||||
#对呼唤名字进行增幅
|
||||
for msg in recent_messages_list:
|
||||
if msg.reply_probability_boost is not None and msg.reply_probability_boost > 0.0:
|
||||
normal_mode_probability += msg.reply_probability_boost
|
||||
if global_config.chat.mentioned_bot_reply and msg.is_mentioned:
|
||||
normal_mode_probability += global_config.chat.mentioned_bot_reply
|
||||
if global_config.chat.at_bot_inevitable_reply and msg.is_at:
|
||||
normal_mode_probability += global_config.chat.at_bot_inevitable_reply
|
||||
|
||||
|
||||
# 根据概率决定使用直接回复
|
||||
interest_triggerd = False
|
||||
focus_triggerd = False
|
||||
|
||||
if random.random() < normal_mode_probability:
|
||||
interest_triggerd = True
|
||||
logger.info(
|
||||
f"{self.log_prefix} 有新消息,在{normal_mode_probability * 100:.0f}%概率下选择回复"
|
||||
)
|
||||
|
||||
if s4u_config.enable_s4u:
|
||||
await send_typing()
|
||||
|
||||
async with global_prompt_manager.async_message_scope(self.chat_stream.context.get_template_name()):
|
||||
await self.expression_learner.trigger_learning_for_chat()
|
||||
|
||||
await global_memory_chest.build_running_content(chat_id=self.stream_id)
|
||||
|
||||
|
||||
cycle_timers, thinking_id = self.start_cycle()
|
||||
logger.info(f"{self.log_prefix} 开始第{self._cycle_counter}次思考")
|
||||
|
||||
# 第一步:动作检查
|
||||
available_actions: Dict[str, ActionInfo] = {}
|
||||
|
||||
#如果兴趣度不足以激活
|
||||
if not interest_triggerd:
|
||||
#看看专注值够不够
|
||||
if random.random() < self.frequency_control.get_final_focus_value():
|
||||
#专注值足够,仍然进入正式思考
|
||||
focus_triggerd = True #都没触发,路边
|
||||
try:
|
||||
await self.action_modifier.modify_actions()
|
||||
available_actions = self.action_manager.get_using_actions()
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 动作修改失败: {e}")
|
||||
|
||||
|
||||
# 任意一种触发都行
|
||||
if interest_triggerd or focus_triggerd:
|
||||
# 进入正式思考模式
|
||||
cycle_timers, thinking_id = self.start_cycle()
|
||||
logger.info(f"{self.log_prefix} 开始第{self._cycle_counter}次思考")
|
||||
|
||||
# 第一步:动作检查
|
||||
try:
|
||||
await self.action_modifier.modify_actions()
|
||||
available_actions = self.action_manager.get_using_actions()
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 动作修改失败: {e}")
|
||||
# 执行planner
|
||||
is_group_chat, chat_target_info, _ = self.action_planner.get_necessary_info()
|
||||
|
||||
# 执行planner
|
||||
is_group_chat, chat_target_info, _ = self.action_planner.get_necessary_info()
|
||||
message_list_before_now = get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=self.stream_id,
|
||||
timestamp=time.time(),
|
||||
limit=int(global_config.chat.max_context_size * 0.6),
|
||||
)
|
||||
chat_content_block, message_id_list = build_readable_messages_with_id(
|
||||
messages=message_list_before_now,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
read_mark=self.action_planner.last_obs_time_mark,
|
||||
truncate=True,
|
||||
show_actions=True,
|
||||
)
|
||||
|
||||
message_list_before_now = get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=self.stream_id,
|
||||
timestamp=time.time(),
|
||||
limit=int(global_config.chat.max_context_size * 0.6),
|
||||
)
|
||||
chat_content_block, message_id_list = build_readable_messages_with_id(
|
||||
messages=message_list_before_now,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
read_mark=self.action_planner.last_obs_time_mark,
|
||||
truncate=True,
|
||||
show_actions=True,
|
||||
prompt_info = await self.action_planner.build_planner_prompt(
|
||||
is_group_chat=is_group_chat,
|
||||
chat_target_info=chat_target_info,
|
||||
current_available_actions=available_actions,
|
||||
chat_content_block=chat_content_block,
|
||||
message_id_list=message_id_list,
|
||||
interest=global_config.personality.interest,
|
||||
)
|
||||
continue_flag, modified_message = await events_manager.handle_mai_events(
|
||||
EventType.ON_PLAN, None, prompt_info[0], None, self.chat_stream.stream_id
|
||||
)
|
||||
if not continue_flag:
|
||||
return False
|
||||
if modified_message and modified_message._modify_flags.modify_llm_prompt:
|
||||
prompt_info = (modified_message.llm_prompt, prompt_info[1])
|
||||
|
||||
with Timer("规划器", cycle_timers):
|
||||
action_to_use_info = await self.action_planner.plan(
|
||||
loop_start_time=self.last_read_time,
|
||||
available_actions=available_actions,
|
||||
)
|
||||
|
||||
prompt_info = await self.action_planner.build_planner_prompt(
|
||||
is_group_chat=is_group_chat,
|
||||
chat_target_info=chat_target_info,
|
||||
# current_available_actions=planner_info[2],
|
||||
chat_content_block=chat_content_block,
|
||||
# actions_before_now_block=actions_before_now_block,
|
||||
message_id_list=message_id_list,
|
||||
)
|
||||
if not await events_manager.handle_mai_events(
|
||||
EventType.ON_PLAN, None, prompt_info[0], None, self.chat_stream.stream_id
|
||||
):
|
||||
return False
|
||||
with Timer("规划器", cycle_timers):
|
||||
# 根据不同触发,进入不同plan
|
||||
if focus_triggerd:
|
||||
mode = ChatMode.FOCUS
|
||||
else:
|
||||
mode = ChatMode.NORMAL
|
||||
|
||||
action_to_use_info, _ = await self.action_planner.plan(
|
||||
mode=mode,
|
||||
loop_start_time=self.last_read_time,
|
||||
has_reply = False
|
||||
for action in action_to_use_info:
|
||||
if action.action_type == "reply":
|
||||
has_reply = True
|
||||
break
|
||||
|
||||
if not has_reply and force_reply_message:
|
||||
action_to_use_info.append(
|
||||
ActionPlannerInfo(
|
||||
action_type="reply",
|
||||
reasoning="有人提到了你,进行回复",
|
||||
action_data={},
|
||||
action_message=force_reply_message,
|
||||
available_actions=available_actions,
|
||||
)
|
||||
)
|
||||
|
||||
# 3. 并行执行所有动作
|
||||
action_tasks = [
|
||||
asyncio.create_task(
|
||||
self._execute_action(action, action_to_use_info, thinking_id, available_actions, cycle_timers)
|
||||
)
|
||||
for action in action_to_use_info
|
||||
]
|
||||
logger.info(
|
||||
f"{self.log_prefix} 决定执行{len(action_to_use_info)}个动作: {' '.join([a.action_type for a in action_to_use_info])}"
|
||||
)
|
||||
|
||||
# 并行执行所有任务
|
||||
results = await asyncio.gather(*action_tasks, return_exceptions=True)
|
||||
# 3. 并行执行所有动作
|
||||
action_tasks = [
|
||||
asyncio.create_task(
|
||||
self._execute_action(action, action_to_use_info, thinking_id, available_actions, cycle_timers)
|
||||
)
|
||||
for action in action_to_use_info
|
||||
]
|
||||
|
||||
# 处理执行结果
|
||||
reply_loop_info = None
|
||||
reply_text_from_reply = ""
|
||||
action_success = False
|
||||
action_reply_text = ""
|
||||
action_command = ""
|
||||
# 并行执行所有任务
|
||||
results = await asyncio.gather(*action_tasks, return_exceptions=True)
|
||||
|
||||
for i, result in enumerate(results):
|
||||
if isinstance(result, BaseException):
|
||||
logger.error(f"{self.log_prefix} 动作执行异常: {result}")
|
||||
continue
|
||||
# 处理执行结果
|
||||
reply_loop_info = None
|
||||
reply_text_from_reply = ""
|
||||
action_success = False
|
||||
action_reply_text = ""
|
||||
|
||||
_cur_action = action_to_use_info[i]
|
||||
if result["action_type"] != "reply":
|
||||
action_success = result["success"]
|
||||
action_reply_text = result["reply_text"]
|
||||
action_command = result.get("command", "")
|
||||
elif result["action_type"] == "reply":
|
||||
if result["success"]:
|
||||
reply_loop_info = result["loop_info"]
|
||||
reply_text_from_reply = result["reply_text"]
|
||||
else:
|
||||
logger.warning(f"{self.log_prefix} 回复动作执行失败")
|
||||
excute_result_str = ""
|
||||
for result in results:
|
||||
excute_result_str += f"{result['action_type']} 执行结果:{result['result']}\n"
|
||||
|
||||
# 构建最终的循环信息
|
||||
if reply_loop_info:
|
||||
# 如果有回复信息,使用回复的loop_info作为基础
|
||||
loop_info = reply_loop_info
|
||||
# 更新动作执行信息
|
||||
loop_info["loop_action_info"].update(
|
||||
{
|
||||
"action_taken": action_success,
|
||||
"command": action_command,
|
||||
"taken_time": time.time(),
|
||||
}
|
||||
)
|
||||
reply_text = reply_text_from_reply
|
||||
else:
|
||||
# 没有回复信息,构建纯动作的loop_info
|
||||
loop_info = {
|
||||
"loop_plan_info": {
|
||||
"action_result": action_to_use_info,
|
||||
},
|
||||
"loop_action_info": {
|
||||
"action_taken": action_success,
|
||||
"reply_text": action_reply_text,
|
||||
"command": action_command,
|
||||
"taken_time": time.time(),
|
||||
},
|
||||
if isinstance(result, BaseException):
|
||||
logger.error(f"{self.log_prefix} 动作执行异常: {result}")
|
||||
continue
|
||||
|
||||
if result["action_type"] != "reply":
|
||||
action_success = result["success"]
|
||||
action_reply_text = result["result"]
|
||||
elif result["action_type"] == "reply":
|
||||
if result["success"]:
|
||||
reply_loop_info = result["loop_info"]
|
||||
reply_text_from_reply = result["result"]
|
||||
else:
|
||||
logger.warning(f"{self.log_prefix} 回复动作执行失败")
|
||||
|
||||
self.action_planner.add_plan_excute_log(result=excute_result_str)
|
||||
|
||||
# 构建最终的循环信息
|
||||
if reply_loop_info:
|
||||
# 如果有回复信息,使用回复的loop_info作为基础
|
||||
loop_info = reply_loop_info
|
||||
# 更新动作执行信息
|
||||
loop_info["loop_action_info"].update(
|
||||
{
|
||||
"action_taken": action_success,
|
||||
"taken_time": time.time(),
|
||||
}
|
||||
reply_text = action_reply_text
|
||||
|
||||
|
||||
self.end_cycle(loop_info, cycle_timers)
|
||||
self.print_cycle_info(cycle_timers)
|
||||
)
|
||||
reply_text = reply_text_from_reply
|
||||
else:
|
||||
# 没有回复信息,构建纯动作的loop_info
|
||||
loop_info = {
|
||||
"loop_plan_info": {
|
||||
"action_result": action_to_use_info,
|
||||
},
|
||||
"loop_action_info": {
|
||||
"action_taken": action_success,
|
||||
"reply_text": action_reply_text,
|
||||
"taken_time": time.time(),
|
||||
},
|
||||
}
|
||||
reply_text = action_reply_text
|
||||
|
||||
"""S4U内容,暂时保留"""
|
||||
if s4u_config.enable_s4u:
|
||||
await stop_typing()
|
||||
await mai_thinking_manager.get_mai_think(self.stream_id).do_think_after_response(reply_text)
|
||||
"""S4U内容,暂时保留"""
|
||||
self.end_cycle(loop_info, cycle_timers)
|
||||
self.print_cycle_info(cycle_timers)
|
||||
|
||||
end_time = time.time()
|
||||
if end_time - start_time < global_config.chat.planner_smooth:
|
||||
wait_time = global_config.chat.planner_smooth - (end_time - start_time)
|
||||
await asyncio.sleep(wait_time)
|
||||
else:
|
||||
await asyncio.sleep(0.1)
|
||||
return True
|
||||
|
||||
async def _main_chat_loop(self):
|
||||
@@ -466,7 +502,7 @@ class HeartFChatting:
|
||||
async def _handle_action(
|
||||
self,
|
||||
action: str,
|
||||
reasoning: str,
|
||||
action_reasoning: str,
|
||||
action_data: dict,
|
||||
cycle_timers: Dict[str, float],
|
||||
thinking_id: str,
|
||||
@@ -477,11 +513,11 @@ class HeartFChatting:
|
||||
|
||||
参数:
|
||||
action: 动作类型
|
||||
reasoning: 决策理由
|
||||
action_reasoning: 决策理由
|
||||
action_data: 动作数据,包含不同动作需要的参数
|
||||
cycle_timers: 计时器字典
|
||||
thinking_id: 思考ID
|
||||
|
||||
action_message: 消息数据
|
||||
返回:
|
||||
tuple[bool, str, str]: (是否执行了动作, 思考消息ID, 命令)
|
||||
"""
|
||||
@@ -491,37 +527,105 @@ class HeartFChatting:
|
||||
action_handler = self.action_manager.create_action(
|
||||
action_name=action,
|
||||
action_data=action_data,
|
||||
reasoning=reasoning,
|
||||
cycle_timers=cycle_timers,
|
||||
thinking_id=thinking_id,
|
||||
chat_stream=self.chat_stream,
|
||||
log_prefix=self.log_prefix,
|
||||
action_reasoning=action_reasoning,
|
||||
action_message=action_message,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 创建动作处理器时出错: {e}")
|
||||
traceback.print_exc()
|
||||
return False, "", ""
|
||||
return False, ""
|
||||
|
||||
if not action_handler:
|
||||
logger.warning(f"{self.log_prefix} 未能创建动作处理器: {action}")
|
||||
return False, "", ""
|
||||
|
||||
# 处理动作并获取结果
|
||||
result = await action_handler.handle_action()
|
||||
# 处理动作并获取结果(固定记录一次动作信息)
|
||||
result = await action_handler.execute()
|
||||
success, action_text = result
|
||||
command = ""
|
||||
|
||||
return success, action_text, command
|
||||
|
||||
return success, action_text
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 处理{action}时出错: {e}")
|
||||
traceback.print_exc()
|
||||
return False, "", ""
|
||||
return False, ""
|
||||
|
||||
async def _lift_question_reply(self, question: str, context: str, cycle_timers: Dict[str, float], thinking_id: str):
|
||||
reason = f"在聊天中:\n{context}\n你对问题\"{question}\"感到好奇,想要和群友讨论"
|
||||
new_msg = get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=self.stream_id,
|
||||
timestamp=time.time(),
|
||||
limit=1,
|
||||
)
|
||||
|
||||
reply_action_info = ActionPlannerInfo(
|
||||
action_type="reply",
|
||||
reasoning= "",
|
||||
action_data={},
|
||||
action_message=new_msg[0],
|
||||
available_actions=None,
|
||||
loop_start_time=time.time(),
|
||||
action_reasoning=reason)
|
||||
self.action_planner.add_plan_log(reasoning=f"你对问题\"{question}\"感到好奇,想要和群友讨论", actions=[reply_action_info])
|
||||
|
||||
success, llm_response = await generator_api.rewrite_reply(
|
||||
chat_stream=self.chat_stream,
|
||||
reply_data={
|
||||
"raw_reply": f"我对这个问题感到好奇:{question}",
|
||||
"reason": reason,
|
||||
},
|
||||
)
|
||||
|
||||
if not success or not llm_response or not llm_response.reply_set:
|
||||
logger.info("主动提问发言失败")
|
||||
self.action_planner.add_plan_excute_log(result="主动回复生成失败")
|
||||
return {"action_type": "reply", "success": False, "result": "主动回复生成失败", "loop_info": None}
|
||||
|
||||
if success:
|
||||
for reply_seg in llm_response.reply_set.reply_data:
|
||||
send_data = reply_seg.content
|
||||
await send_api.text_to_stream(
|
||||
text=send_data,
|
||||
stream_id=self.stream_id,
|
||||
)
|
||||
|
||||
await database_api.store_action_info(
|
||||
chat_stream=self.chat_stream,
|
||||
action_build_into_prompt=False,
|
||||
action_prompt_display=reason,
|
||||
action_done=True,
|
||||
thinking_id=thinking_id,
|
||||
action_data={"reply_text": llm_response.reply_set.reply_data[0].content},
|
||||
action_name="reply",
|
||||
)
|
||||
|
||||
# 构建循环信息
|
||||
loop_info: Dict[str, Any] = {
|
||||
"loop_plan_info": {
|
||||
"action_result": [reply_action_info],
|
||||
},
|
||||
"loop_action_info": {
|
||||
"action_taken": True,
|
||||
"reply_text": llm_response.reply_set.reply_data[0].content,
|
||||
"command": "",
|
||||
"taken_time": time.time(),
|
||||
},
|
||||
}
|
||||
self.last_active_time = time.time()
|
||||
self.action_planner.add_plan_excute_log(result=f"你提问:{question}")
|
||||
|
||||
return {
|
||||
"action_type": "reply",
|
||||
"success": True,
|
||||
"result": f"你提问:{question}",
|
||||
"loop_info": loop_info,
|
||||
}
|
||||
|
||||
|
||||
async def _send_response(
|
||||
self,
|
||||
reply_set,
|
||||
reply_set: "ReplySetModel",
|
||||
message_data: "DatabaseMessages",
|
||||
selected_expressions: Optional[List[int]] = None,
|
||||
) -> str:
|
||||
@@ -529,15 +633,17 @@ class HeartFChatting:
|
||||
chat_id=self.chat_stream.stream_id, start_time=self.last_read_time, end_time=time.time()
|
||||
)
|
||||
|
||||
need_reply = new_message_count >= random.randint(2, 4)
|
||||
need_reply = new_message_count >= random.randint(2, 3)
|
||||
|
||||
if need_reply:
|
||||
logger.info(f"{self.log_prefix} 从思考到回复,共有{new_message_count}条新消息,使用引用回复")
|
||||
|
||||
reply_text = ""
|
||||
first_replied = False
|
||||
for reply_seg in reply_set:
|
||||
data = reply_seg[1]
|
||||
for reply_content in reply_set.reply_data:
|
||||
if reply_content.content_type != ReplyContentType.TEXT:
|
||||
continue
|
||||
data: str = reply_content.content # type: ignore
|
||||
if not first_replied:
|
||||
await send_api.text_to_stream(
|
||||
text=data,
|
||||
@@ -571,86 +677,127 @@ class HeartFChatting:
|
||||
):
|
||||
"""执行单个动作的通用函数"""
|
||||
try:
|
||||
if action_planner_info.action_type == "no_action":
|
||||
# 直接处理no_action逻辑,不再通过动作系统
|
||||
reason = action_planner_info.reasoning or "选择不回复"
|
||||
logger.info(f"{self.log_prefix} 选择不回复,原因: {reason}")
|
||||
with Timer(f"动作{action_planner_info.action_type}", cycle_timers):
|
||||
# 直接当场执行no_reply逻辑
|
||||
if action_planner_info.action_type == "no_reply":
|
||||
# 直接处理no_reply逻辑,不再通过动作系统
|
||||
reason = action_planner_info.reasoning or "选择不回复"
|
||||
# logger.info(f"{self.log_prefix} 选择不回复,原因: {reason}")
|
||||
|
||||
# 存储no_action信息到数据库
|
||||
await database_api.store_action_info(
|
||||
chat_stream=self.chat_stream,
|
||||
action_build_into_prompt=False,
|
||||
action_prompt_display=reason,
|
||||
action_done=True,
|
||||
thinking_id=thinking_id,
|
||||
action_data={"reason": reason},
|
||||
action_name="no_action",
|
||||
)
|
||||
|
||||
return {"action_type": "no_action", "success": True, "reply_text": "", "command": ""}
|
||||
elif action_planner_info.action_type != "reply":
|
||||
# 执行普通动作
|
||||
with Timer("动作执行", cycle_timers):
|
||||
success, reply_text, command = await self._handle_action(
|
||||
action_planner_info.action_type,
|
||||
action_planner_info.reasoning or "",
|
||||
action_planner_info.action_data or {},
|
||||
cycle_timers,
|
||||
thinking_id,
|
||||
action_planner_info.action_message,
|
||||
await database_api.store_action_info(
|
||||
chat_stream=self.chat_stream,
|
||||
action_build_into_prompt=False,
|
||||
action_prompt_display=reason,
|
||||
action_done=True,
|
||||
thinking_id=thinking_id,
|
||||
action_data={},
|
||||
action_name="no_reply",
|
||||
action_reasoning=reason,
|
||||
)
|
||||
return {
|
||||
"action_type": action_planner_info.action_type,
|
||||
"success": success,
|
||||
"reply_text": reply_text,
|
||||
"command": command,
|
||||
}
|
||||
else:
|
||||
try:
|
||||
|
||||
|
||||
return {"action_type": "no_reply", "success": True, "result": "选择不回复", "command": ""}
|
||||
|
||||
elif action_planner_info.action_type == "no_reply_until_call":
|
||||
# 直接当场执行no_reply_until_call逻辑
|
||||
logger.info(f"{self.log_prefix} 保持沉默,直到有人直接叫的名字")
|
||||
reason = action_planner_info.reasoning or "选择不回复"
|
||||
|
||||
self.no_reply_until_call = True
|
||||
await database_api.store_action_info(
|
||||
chat_stream=self.chat_stream,
|
||||
action_build_into_prompt=False,
|
||||
action_prompt_display=reason,
|
||||
action_done=True,
|
||||
thinking_id=thinking_id,
|
||||
action_data={},
|
||||
action_name="no_reply_until_call",
|
||||
action_reasoning=reason,
|
||||
)
|
||||
return {"action_type": "no_reply_until_call", "success": True, "result": "保持沉默,直到有人直接叫的名字", "command": ""}
|
||||
|
||||
elif action_planner_info.action_type == "reply":
|
||||
# 直接当场执行reply逻辑
|
||||
self.questioned = False
|
||||
# 刷新主动发言状态
|
||||
|
||||
reason = action_planner_info.reasoning or "选择回复"
|
||||
await database_api.store_action_info(
|
||||
chat_stream=self.chat_stream,
|
||||
action_build_into_prompt=False,
|
||||
action_prompt_display=reason,
|
||||
action_done=True,
|
||||
thinking_id=thinking_id,
|
||||
action_data={},
|
||||
action_name="reply",
|
||||
action_reasoning=reason,
|
||||
)
|
||||
|
||||
success, llm_response = await generator_api.generate_reply(
|
||||
chat_stream=self.chat_stream,
|
||||
reply_message=action_planner_info.action_message,
|
||||
available_actions=available_actions,
|
||||
chosen_actions=chosen_action_plan_infos,
|
||||
reply_reason=action_planner_info.reasoning or "",
|
||||
reply_reason=reason,
|
||||
enable_tool=global_config.tool.enable_tool,
|
||||
request_type="replyer",
|
||||
from_plugin=False,
|
||||
reply_time_point = action_planner_info.action_data.get("loop_start_time", time.time()),
|
||||
)
|
||||
|
||||
if not success or not llm_response or not llm_response.reply_set:
|
||||
if action_planner_info.action_message:
|
||||
logger.info(f"对 {action_planner_info.action_message.processed_plain_text} 的回复生成失败")
|
||||
logger.info(
|
||||
f"对 {action_planner_info.action_message.processed_plain_text} 的回复生成失败"
|
||||
)
|
||||
else:
|
||||
logger.info("回复生成失败")
|
||||
return {"action_type": "reply", "success": False, "reply_text": "", "loop_info": None}
|
||||
return {"action_type": "reply", "success": False, "result": "回复生成失败", "loop_info": None}
|
||||
|
||||
|
||||
response_set = llm_response.reply_set
|
||||
selected_expressions = llm_response.selected_expressions
|
||||
loop_info, reply_text, _ = await self._send_and_store_reply(
|
||||
response_set=response_set,
|
||||
action_message=action_planner_info.action_message, # type: ignore
|
||||
cycle_timers=cycle_timers,
|
||||
thinking_id=thinking_id,
|
||||
actions=chosen_action_plan_infos,
|
||||
selected_expressions=selected_expressions,
|
||||
)
|
||||
self.last_active_time = time.time()
|
||||
return {
|
||||
"action_type": "reply",
|
||||
"success": True,
|
||||
"result": f"你回复内容{reply_text}",
|
||||
"loop_info": loop_info,
|
||||
}
|
||||
else:
|
||||
# 执行普通动作
|
||||
with Timer("动作执行", cycle_timers):
|
||||
success, result = await self._handle_action(
|
||||
action = action_planner_info.action_type,
|
||||
action_reasoning = action_planner_info.action_reasoning or "",
|
||||
action_data = action_planner_info.action_data or {},
|
||||
cycle_timers = cycle_timers,
|
||||
thinking_id = thinking_id,
|
||||
action_message= action_planner_info.action_message,
|
||||
)
|
||||
|
||||
self.last_active_time = time.time()
|
||||
return {
|
||||
"action_type": action_planner_info.action_type,
|
||||
"success": success,
|
||||
"result": result,
|
||||
}
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.debug(f"{self.log_prefix} 并行执行:回复生成任务已被取消")
|
||||
return {"action_type": "reply", "success": False, "reply_text": "", "loop_info": None}
|
||||
response_set = llm_response.reply_set
|
||||
selected_expressions = llm_response.selected_expressions
|
||||
loop_info, reply_text, _ = await self._send_and_store_reply(
|
||||
response_set=response_set,
|
||||
action_message=action_planner_info.action_message, # type: ignore
|
||||
cycle_timers=cycle_timers,
|
||||
thinking_id=thinking_id,
|
||||
actions=chosen_action_plan_infos,
|
||||
selected_expressions=selected_expressions,
|
||||
)
|
||||
return {
|
||||
"action_type": "reply",
|
||||
"success": True,
|
||||
"reply_text": reply_text,
|
||||
"loop_info": loop_info,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 执行动作时出错: {e}")
|
||||
logger.error(f"{self.log_prefix} 错误信息: {traceback.format_exc()}")
|
||||
return {
|
||||
"action_type": action_planner_info.action_type,
|
||||
"success": False,
|
||||
"reply_text": "",
|
||||
"result": "",
|
||||
"loop_info": None,
|
||||
"error": str(e),
|
||||
}
|
||||
|
||||
@@ -1,24 +1,35 @@
|
||||
import traceback
|
||||
from typing import Any, Optional, Dict
|
||||
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.heart_flow.heartFC_chat import HeartFChatting
|
||||
from src.chat.brain_chat.brain_chat import BrainChatting
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
|
||||
logger = get_logger("heartflow")
|
||||
|
||||
|
||||
class Heartflow:
|
||||
"""主心流协调器,负责初始化并协调聊天"""
|
||||
|
||||
def __init__(self):
|
||||
self.heartflow_chat_list: Dict[Any, HeartFChatting] = {}
|
||||
|
||||
async def get_or_create_heartflow_chat(self, chat_id: Any) -> Optional[HeartFChatting]:
|
||||
self.heartflow_chat_list: Dict[Any, HeartFChatting | BrainChatting] = {}
|
||||
|
||||
async def get_or_create_heartflow_chat(self, chat_id: Any) -> Optional[HeartFChatting | BrainChatting]:
|
||||
"""获取或创建一个新的HeartFChatting实例"""
|
||||
try:
|
||||
if chat_id in self.heartflow_chat_list:
|
||||
if chat := self.heartflow_chat_list.get(chat_id):
|
||||
return chat
|
||||
else:
|
||||
new_chat = HeartFChatting(chat_id = chat_id)
|
||||
chat_stream: ChatStream | None = get_chat_manager().get_stream(chat_id)
|
||||
if not chat_stream:
|
||||
raise ValueError(f"未找到 chat_id={chat_id} 的聊天流")
|
||||
if chat_stream.group_info:
|
||||
new_chat = HeartFChatting(chat_id=chat_id)
|
||||
else:
|
||||
new_chat = BrainChatting(chat_id=chat_id)
|
||||
await new_chat.start()
|
||||
self.heartflow_chat_list[chat_id] = new_chat
|
||||
return new_chat
|
||||
@@ -27,4 +38,5 @@ class Heartflow:
|
||||
traceback.print_exc()
|
||||
return None
|
||||
|
||||
|
||||
heartflow = Heartflow()
|
||||
|
||||
@@ -1,20 +1,14 @@
|
||||
import asyncio
|
||||
import re
|
||||
import math
|
||||
import traceback
|
||||
|
||||
from typing import Tuple, TYPE_CHECKING
|
||||
|
||||
from src.config.config import global_config
|
||||
from src.chat.memory_system.Hippocampus import hippocampus_manager
|
||||
from src.chat.message_receive.message import MessageRecv
|
||||
from src.chat.message_receive.storage import MessageStorage
|
||||
from src.chat.heart_flow.heartflow import heartflow
|
||||
from src.chat.utils.utils import is_mentioned_bot_in_message
|
||||
from src.chat.utils.timer_calculator import Timer
|
||||
from src.chat.utils.chat_message_builder import replace_user_references
|
||||
from src.common.logger import get_logger
|
||||
from src.mood.mood_manager import mood_manager
|
||||
from src.person_info.person_info import Person
|
||||
from src.common.database.database_model import Images
|
||||
|
||||
@@ -23,6 +17,7 @@ if TYPE_CHECKING:
|
||||
|
||||
logger = get_logger("chat")
|
||||
|
||||
|
||||
async def _calculate_interest(message: MessageRecv) -> Tuple[float, list[str]]:
|
||||
"""计算消息的兴趣度
|
||||
|
||||
@@ -34,58 +29,17 @@ async def _calculate_interest(message: MessageRecv) -> Tuple[float, list[str]]:
|
||||
"""
|
||||
if message.is_picid or message.is_emoji:
|
||||
return 0.0, []
|
||||
|
||||
is_mentioned,is_at,reply_probability_boost = is_mentioned_bot_in_message(message)
|
||||
interested_rate = 0.0
|
||||
|
||||
with Timer("记忆激活"):
|
||||
interested_rate, keywords,keywords_lite = await hippocampus_manager.get_activate_from_text(
|
||||
message.processed_plain_text,
|
||||
max_depth= 4,
|
||||
fast_retrieval=global_config.chat.interest_rate_mode == "fast",
|
||||
)
|
||||
message.key_words = keywords
|
||||
message.key_words_lite = keywords_lite
|
||||
logger.debug(f"记忆激活率: {interested_rate:.2f}, 关键词: {keywords}")
|
||||
is_mentioned, is_at, reply_probability_boost = is_mentioned_bot_in_message(message)
|
||||
# interested_rate = 0.0
|
||||
keywords = []
|
||||
|
||||
text_len = len(message.processed_plain_text)
|
||||
# 根据文本长度分布调整兴趣度,采用分段函数实现更精确的兴趣度计算
|
||||
# 基于实际分布:0-5字符(26.57%), 6-10字符(27.18%), 11-20字符(22.76%), 21-30字符(10.33%), 31+字符(13.86%)
|
||||
|
||||
if text_len == 0:
|
||||
base_interest = 0.01 # 空消息最低兴趣度
|
||||
elif text_len <= 5:
|
||||
# 1-5字符:线性增长 0.01 -> 0.03
|
||||
base_interest = 0.01 + (text_len - 1) * (0.03 - 0.01) / 4
|
||||
elif text_len <= 10:
|
||||
# 6-10字符:线性增长 0.03 -> 0.06
|
||||
base_interest = 0.03 + (text_len - 5) * (0.06 - 0.03) / 5
|
||||
elif text_len <= 20:
|
||||
# 11-20字符:线性增长 0.06 -> 0.12
|
||||
base_interest = 0.06 + (text_len - 10) * (0.12 - 0.06) / 10
|
||||
elif text_len <= 30:
|
||||
# 21-30字符:线性增长 0.12 -> 0.18
|
||||
base_interest = 0.12 + (text_len - 20) * (0.18 - 0.12) / 10
|
||||
elif text_len <= 50:
|
||||
# 31-50字符:线性增长 0.18 -> 0.22
|
||||
base_interest = 0.18 + (text_len - 30) * (0.22 - 0.18) / 20
|
||||
elif text_len <= 100:
|
||||
# 51-100字符:线性增长 0.22 -> 0.26
|
||||
base_interest = 0.22 + (text_len - 50) * (0.26 - 0.22) / 50
|
||||
else:
|
||||
# 100+字符:对数增长 0.26 -> 0.3,增长率递减
|
||||
base_interest = 0.26 + (0.3 - 0.26) * (math.log10(text_len - 99) / math.log10(901)) # 1000-99=901
|
||||
|
||||
# 确保在范围内
|
||||
base_interest = min(max(base_interest, 0.01), 0.3)
|
||||
|
||||
|
||||
message.interest_value = base_interest
|
||||
message.interest_value = 1
|
||||
message.is_mentioned = is_mentioned
|
||||
message.is_at = is_at
|
||||
message.reply_probability_boost = reply_probability_boost
|
||||
|
||||
return base_interest, keywords
|
||||
|
||||
return 1, keywords
|
||||
|
||||
|
||||
class HeartFCMessageReceiver:
|
||||
@@ -114,17 +68,11 @@ class HeartFCMessageReceiver:
|
||||
chat = message.chat_stream
|
||||
|
||||
# 2. 兴趣度计算与更新
|
||||
interested_rate, keywords = await _calculate_interest(message)
|
||||
|
||||
_, keywords = await _calculate_interest(message)
|
||||
|
||||
await self.storage.store_message(message, chat)
|
||||
|
||||
heartflow_chat: HeartFChatting = await heartflow.get_or_create_heartflow_chat(chat.stream_id) # type: ignore
|
||||
|
||||
# subheartflow.add_message_to_normal_chat_cache(message, interested_rate, is_mentioned)
|
||||
if global_config.mood.enable_mood:
|
||||
chat_mood = mood_manager.get_mood_by_chat_id(heartflow_chat.stream_id)
|
||||
asyncio.create_task(chat_mood.update_mood_by_message(message, interested_rate))
|
||||
_heartflow_chat: HeartFChatting = await heartflow.get_or_create_heartflow_chat(chat.stream_id) # type: ignore
|
||||
|
||||
# 3. 日志记录
|
||||
mes_name = chat.group_info.group_name if chat.group_info else "私聊"
|
||||
@@ -132,7 +80,7 @@ class HeartFCMessageReceiver:
|
||||
# 用这个pattern截取出id部分,picid是一个list,并替换成对应的图片描述
|
||||
picid_pattern = r"\[picid:([^\]]+)\]"
|
||||
picid_list = re.findall(picid_pattern, message.processed_plain_text)
|
||||
|
||||
|
||||
# 创建替换后的文本
|
||||
processed_text = message.processed_plain_text
|
||||
if picid_list:
|
||||
@@ -145,18 +93,22 @@ class HeartFCMessageReceiver:
|
||||
# 如果没有找到图片描述,则移除[picid:xxxx]标记
|
||||
processed_text = processed_text.replace(f"[picid:{picid}]", "[图片:网络不好,图片无法加载]")
|
||||
|
||||
|
||||
# 应用用户引用格式替换,将回复<aaa:bbb>和@<aaa:bbb>格式转换为可读格式
|
||||
processed_plain_text = replace_user_references(
|
||||
processed_text,
|
||||
message.message_info.platform, # type: ignore
|
||||
replace_bot_name=True
|
||||
message.message_info.platform, # type: ignore
|
||||
replace_bot_name=True,
|
||||
)
|
||||
# if not processed_plain_text:
|
||||
# print(message)
|
||||
|
||||
logger.info(f"[{mes_name}]{userinfo.user_nickname}:{processed_plain_text}") # type: ignore
|
||||
|
||||
logger.info(f"[{mes_name}]{userinfo.user_nickname}:{processed_plain_text}[{interested_rate:.2f}]") # type: ignore
|
||||
|
||||
_ = Person.register_person(platform=message.message_info.platform, user_id=message.message_info.user_info.user_id,nickname=userinfo.user_nickname) # type: ignore
|
||||
_ = Person.register_person(
|
||||
platform=message.message_info.platform, # type: ignore
|
||||
user_id=message.message_info.user_info.user_id, # type: ignore
|
||||
nickname=userinfo.user_nickname, # type: ignore
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"消息处理失败: {e}")
|
||||
|
||||
@@ -124,6 +124,7 @@ async def send_typing():
|
||||
message_type="state", content="typing", stream_id=chat.stream_id, storage_message=False
|
||||
)
|
||||
|
||||
|
||||
async def stop_typing():
|
||||
group_info = GroupInfo(platform="amaidesu_default", group_id="114514", group_name="内心")
|
||||
|
||||
@@ -135,4 +136,4 @@ async def stop_typing():
|
||||
|
||||
await send_api.custom_to_stream(
|
||||
message_type="state", content="stop_typing", stream_id=chat.stream_id, storage_message=False
|
||||
)
|
||||
)
|
||||
|
||||
@@ -30,6 +30,7 @@ DATA_PATH = os.path.join(ROOT_PATH, "data")
|
||||
qa_manager = None
|
||||
inspire_manager = None
|
||||
|
||||
|
||||
def lpmm_start_up(): # sourcery skip: extract-duplicate-method
|
||||
# 检查LPMM知识库是否启用
|
||||
if global_config.lpmm_knowledge.enable:
|
||||
|
||||
@@ -25,7 +25,6 @@ from rich.progress import (
|
||||
SpinnerColumn,
|
||||
TextColumn,
|
||||
)
|
||||
from src.chat.utils.utils import get_embedding
|
||||
from src.config.config import global_config
|
||||
|
||||
|
||||
@@ -33,11 +32,11 @@ install(extra_lines=3)
|
||||
|
||||
# 多线程embedding配置常量
|
||||
DEFAULT_MAX_WORKERS = 10 # 默认最大线程数
|
||||
DEFAULT_CHUNK_SIZE = 10 # 默认每个线程处理的数据块大小
|
||||
MIN_CHUNK_SIZE = 1 # 最小分块大小
|
||||
MAX_CHUNK_SIZE = 50 # 最大分块大小
|
||||
MIN_WORKERS = 1 # 最小线程数
|
||||
MAX_WORKERS = 20 # 最大线程数
|
||||
DEFAULT_CHUNK_SIZE = 10 # 默认每个线程处理的数据块大小
|
||||
MIN_CHUNK_SIZE = 1 # 最小分块大小
|
||||
MAX_CHUNK_SIZE = 50 # 最大分块大小
|
||||
MIN_WORKERS = 1 # 最小线程数
|
||||
MAX_WORKERS = 20 # 最大线程数
|
||||
|
||||
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
|
||||
EMBEDDING_DATA_DIR = os.path.join(ROOT_PATH, "data", "embedding")
|
||||
@@ -94,7 +93,13 @@ class EmbeddingStoreItem:
|
||||
|
||||
|
||||
class EmbeddingStore:
|
||||
def __init__(self, namespace: str, dir_path: str, max_workers: int = DEFAULT_MAX_WORKERS, chunk_size: int = DEFAULT_CHUNK_SIZE):
|
||||
def __init__(
|
||||
self,
|
||||
namespace: str,
|
||||
dir_path: str,
|
||||
max_workers: int = DEFAULT_MAX_WORKERS,
|
||||
chunk_size: int = DEFAULT_CHUNK_SIZE,
|
||||
):
|
||||
self.namespace = namespace
|
||||
self.dir = dir_path
|
||||
self.embedding_file_path = f"{dir_path}/{namespace}.parquet"
|
||||
@@ -104,12 +109,16 @@ class EmbeddingStore:
|
||||
# 多线程配置参数验证和设置
|
||||
self.max_workers = max(MIN_WORKERS, min(MAX_WORKERS, max_workers))
|
||||
self.chunk_size = max(MIN_CHUNK_SIZE, min(MAX_CHUNK_SIZE, chunk_size))
|
||||
|
||||
|
||||
# 如果配置值被调整,记录日志
|
||||
if self.max_workers != max_workers:
|
||||
logger.warning(f"max_workers 已从 {max_workers} 调整为 {self.max_workers} (范围: {MIN_WORKERS}-{MAX_WORKERS})")
|
||||
logger.warning(
|
||||
f"max_workers 已从 {max_workers} 调整为 {self.max_workers} (范围: {MIN_WORKERS}-{MAX_WORKERS})"
|
||||
)
|
||||
if self.chunk_size != chunk_size:
|
||||
logger.warning(f"chunk_size 已从 {chunk_size} 调整为 {self.chunk_size} (范围: {MIN_CHUNK_SIZE}-{MAX_CHUNK_SIZE})")
|
||||
logger.warning(
|
||||
f"chunk_size 已从 {chunk_size} 调整为 {self.chunk_size} (范围: {MIN_CHUNK_SIZE}-{MAX_CHUNK_SIZE})"
|
||||
)
|
||||
|
||||
self.store = {}
|
||||
|
||||
@@ -121,23 +130,23 @@ class EmbeddingStore:
|
||||
# 创建新的事件循环并在完成后立即关闭
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
|
||||
try:
|
||||
# 创建新的LLMRequest实例
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config
|
||||
|
||||
|
||||
llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding")
|
||||
|
||||
|
||||
# 使用新的事件循环运行异步方法
|
||||
embedding, _ = loop.run_until_complete(llm.get_embedding(s))
|
||||
|
||||
|
||||
if embedding and len(embedding) > 0:
|
||||
return embedding
|
||||
else:
|
||||
logger.error(f"获取嵌入失败: {s}")
|
||||
return []
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取嵌入时发生异常: {s}, 错误: {e}")
|
||||
return []
|
||||
@@ -148,43 +157,45 @@ class EmbeddingStore:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _get_embeddings_batch_threaded(self, strs: List[str], chunk_size: int = 10, max_workers: int = 10, progress_callback=None) -> List[Tuple[str, List[float]]]:
|
||||
def _get_embeddings_batch_threaded(
|
||||
self, strs: List[str], chunk_size: int = 10, max_workers: int = 10, progress_callback=None
|
||||
) -> List[Tuple[str, List[float]]]:
|
||||
"""使用多线程批量获取嵌入向量
|
||||
|
||||
|
||||
Args:
|
||||
strs: 要获取嵌入的字符串列表
|
||||
chunk_size: 每个线程处理的数据块大小
|
||||
max_workers: 最大线程数
|
||||
progress_callback: 进度回调函数,接收一个参数表示完成的数量
|
||||
|
||||
|
||||
Returns:
|
||||
包含(原始字符串, 嵌入向量)的元组列表,保持与输入顺序一致
|
||||
"""
|
||||
if not strs:
|
||||
return []
|
||||
|
||||
|
||||
# 分块
|
||||
chunks = []
|
||||
for i in range(0, len(strs), chunk_size):
|
||||
chunk = strs[i:i + chunk_size]
|
||||
chunk = strs[i : i + chunk_size]
|
||||
chunks.append((i, chunk)) # 保存起始索引以维持顺序
|
||||
|
||||
|
||||
# 结果存储,使用字典按索引存储以保证顺序
|
||||
results = {}
|
||||
|
||||
|
||||
def process_chunk(chunk_data):
|
||||
"""处理单个数据块的函数"""
|
||||
start_idx, chunk_strs = chunk_data
|
||||
chunk_results = []
|
||||
|
||||
|
||||
# 为每个线程创建独立的LLMRequest实例
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config
|
||||
|
||||
|
||||
try:
|
||||
# 创建线程专用的LLM实例
|
||||
llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding")
|
||||
|
||||
|
||||
for i, s in enumerate(chunk_strs):
|
||||
try:
|
||||
# 在线程中创建独立的事件循环
|
||||
@@ -194,25 +205,25 @@ class EmbeddingStore:
|
||||
embedding = loop.run_until_complete(llm.get_embedding(s))
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
if embedding and len(embedding) > 0:
|
||||
chunk_results.append((start_idx + i, s, embedding[0])) # embedding[0] 是实际的向量
|
||||
else:
|
||||
logger.error(f"获取嵌入失败: {s}")
|
||||
chunk_results.append((start_idx + i, s, []))
|
||||
|
||||
|
||||
# 每完成一个嵌入立即更新进度
|
||||
if progress_callback:
|
||||
progress_callback(1)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取嵌入时发生异常: {s}, 错误: {e}")
|
||||
chunk_results.append((start_idx + i, s, []))
|
||||
|
||||
|
||||
# 即使失败也要更新进度
|
||||
if progress_callback:
|
||||
progress_callback(1)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"创建LLM实例失败: {e}")
|
||||
# 如果创建LLM实例失败,返回空结果
|
||||
@@ -221,14 +232,14 @@ class EmbeddingStore:
|
||||
# 即使失败也要更新进度
|
||||
if progress_callback:
|
||||
progress_callback(1)
|
||||
|
||||
|
||||
return chunk_results
|
||||
|
||||
|
||||
# 使用线程池处理
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
# 提交所有任务
|
||||
future_to_chunk = {executor.submit(process_chunk, chunk): chunk for chunk in chunks}
|
||||
|
||||
|
||||
# 收集结果(进度已在process_chunk中实时更新)
|
||||
for future in as_completed(future_to_chunk):
|
||||
try:
|
||||
@@ -242,7 +253,7 @@ class EmbeddingStore:
|
||||
start_idx, chunk_strs = chunk
|
||||
for i, s in enumerate(chunk_strs):
|
||||
results[start_idx + i] = (s, [])
|
||||
|
||||
|
||||
# 按原始顺序返回结果
|
||||
ordered_results = []
|
||||
for i in range(len(strs)):
|
||||
@@ -251,7 +262,7 @@ class EmbeddingStore:
|
||||
else:
|
||||
# 防止遗漏
|
||||
ordered_results.append((strs[i], []))
|
||||
|
||||
|
||||
return ordered_results
|
||||
|
||||
def get_test_file_path(self):
|
||||
@@ -260,14 +271,14 @@ class EmbeddingStore:
|
||||
def save_embedding_test_vectors(self):
|
||||
"""保存测试字符串的嵌入到本地(使用多线程优化)"""
|
||||
logger.info("开始保存测试字符串的嵌入向量...")
|
||||
|
||||
|
||||
# 使用多线程批量获取测试字符串的嵌入
|
||||
embedding_results = self._get_embeddings_batch_threaded(
|
||||
EMBEDDING_TEST_STRINGS,
|
||||
chunk_size=min(self.chunk_size, len(EMBEDDING_TEST_STRINGS)),
|
||||
max_workers=min(self.max_workers, len(EMBEDDING_TEST_STRINGS))
|
||||
max_workers=min(self.max_workers, len(EMBEDDING_TEST_STRINGS)),
|
||||
)
|
||||
|
||||
|
||||
# 构建测试向量字典
|
||||
test_vectors = {}
|
||||
for idx, (s, embedding) in enumerate(embedding_results):
|
||||
@@ -277,10 +288,10 @@ class EmbeddingStore:
|
||||
logger.error(f"获取测试字符串嵌入失败: {s}")
|
||||
# 使用原始单线程方法作为后备
|
||||
test_vectors[str(idx)] = self._get_embedding(s)
|
||||
|
||||
|
||||
with open(self.get_test_file_path(), "w", encoding="utf-8") as f:
|
||||
json.dump(test_vectors, f, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
logger.info("测试字符串嵌入向量保存完成")
|
||||
|
||||
def load_embedding_test_vectors(self):
|
||||
@@ -298,35 +309,35 @@ class EmbeddingStore:
|
||||
logger.warning("未检测到本地嵌入模型测试文件,将保存当前模型的测试嵌入。")
|
||||
self.save_embedding_test_vectors()
|
||||
return True
|
||||
|
||||
|
||||
# 检查本地向量完整性
|
||||
for idx in range(len(EMBEDDING_TEST_STRINGS)):
|
||||
if local_vectors.get(str(idx)) is None:
|
||||
logger.warning("本地嵌入模型测试文件缺失部分测试字符串,将重新保存。")
|
||||
self.save_embedding_test_vectors()
|
||||
return True
|
||||
|
||||
|
||||
logger.info("开始检验嵌入模型一致性...")
|
||||
|
||||
|
||||
# 使用多线程批量获取当前模型的嵌入
|
||||
embedding_results = self._get_embeddings_batch_threaded(
|
||||
EMBEDDING_TEST_STRINGS,
|
||||
chunk_size=min(self.chunk_size, len(EMBEDDING_TEST_STRINGS)),
|
||||
max_workers=min(self.max_workers, len(EMBEDDING_TEST_STRINGS))
|
||||
max_workers=min(self.max_workers, len(EMBEDDING_TEST_STRINGS)),
|
||||
)
|
||||
|
||||
|
||||
# 检查一致性
|
||||
for idx, (s, new_emb) in enumerate(embedding_results):
|
||||
local_emb = local_vectors.get(str(idx))
|
||||
if not new_emb:
|
||||
logger.error(f"获取测试字符串嵌入失败: {s}")
|
||||
return False
|
||||
|
||||
|
||||
sim = cosine_similarity(local_emb, new_emb)
|
||||
if sim < EMBEDDING_SIM_THRESHOLD:
|
||||
logger.error(f"嵌入模型一致性校验失败,字符串: {s}, 相似度: {sim:.4f}")
|
||||
return False
|
||||
|
||||
|
||||
logger.info("嵌入模型一致性校验通过。")
|
||||
return True
|
||||
|
||||
@@ -334,22 +345,22 @@ class EmbeddingStore:
|
||||
"""向库中存入字符串(使用多线程优化)"""
|
||||
if not strs:
|
||||
return
|
||||
|
||||
|
||||
total = len(strs)
|
||||
|
||||
|
||||
# 过滤已存在的字符串
|
||||
new_strs = []
|
||||
for s in strs:
|
||||
item_hash = self.namespace + "-" + get_sha256(s)
|
||||
if item_hash not in self.store:
|
||||
new_strs.append(s)
|
||||
|
||||
|
||||
if not new_strs:
|
||||
logger.info(f"所有字符串已存在于{self.namespace}嵌入库中,跳过处理")
|
||||
return
|
||||
|
||||
|
||||
logger.info(f"需要处理 {len(new_strs)}/{total} 个新字符串")
|
||||
|
||||
|
||||
with Progress(
|
||||
SpinnerColumn(),
|
||||
TextColumn("[progress.description]{task.description}"),
|
||||
@@ -363,31 +374,39 @@ class EmbeddingStore:
|
||||
transient=False,
|
||||
) as progress:
|
||||
task = progress.add_task(f"存入嵌入库:({times}/{TOTAL_EMBEDDING_TIMES})", total=total)
|
||||
|
||||
|
||||
# 首先更新已存在项的进度
|
||||
already_processed = total - len(new_strs)
|
||||
if already_processed > 0:
|
||||
progress.update(task, advance=already_processed)
|
||||
|
||||
|
||||
if new_strs:
|
||||
# 使用实例配置的参数,智能调整分块和线程数
|
||||
optimal_chunk_size = max(MIN_CHUNK_SIZE, min(self.chunk_size, len(new_strs) // self.max_workers if self.max_workers > 0 else self.chunk_size))
|
||||
optimal_max_workers = min(self.max_workers, max(MIN_WORKERS, len(new_strs) // optimal_chunk_size if optimal_chunk_size > 0 else 1))
|
||||
|
||||
optimal_chunk_size = max(
|
||||
MIN_CHUNK_SIZE,
|
||||
min(
|
||||
self.chunk_size, len(new_strs) // self.max_workers if self.max_workers > 0 else self.chunk_size
|
||||
),
|
||||
)
|
||||
optimal_max_workers = min(
|
||||
self.max_workers,
|
||||
max(MIN_WORKERS, len(new_strs) // optimal_chunk_size if optimal_chunk_size > 0 else 1),
|
||||
)
|
||||
|
||||
logger.debug(f"使用多线程处理: chunk_size={optimal_chunk_size}, max_workers={optimal_max_workers}")
|
||||
|
||||
|
||||
# 定义进度更新回调函数
|
||||
def update_progress(count):
|
||||
progress.update(task, advance=count)
|
||||
|
||||
|
||||
# 批量获取嵌入,并实时更新进度
|
||||
embedding_results = self._get_embeddings_batch_threaded(
|
||||
new_strs,
|
||||
chunk_size=optimal_chunk_size,
|
||||
new_strs,
|
||||
chunk_size=optimal_chunk_size,
|
||||
max_workers=optimal_max_workers,
|
||||
progress_callback=update_progress
|
||||
progress_callback=update_progress,
|
||||
)
|
||||
|
||||
|
||||
# 存入结果(不再需要在这里更新进度,因为已经在回调中更新了)
|
||||
for s, embedding in embedding_results:
|
||||
item_hash = self.namespace + "-" + get_sha256(s)
|
||||
@@ -520,7 +539,7 @@ class EmbeddingManager:
|
||||
def __init__(self, max_workers: int = DEFAULT_MAX_WORKERS, chunk_size: int = DEFAULT_CHUNK_SIZE):
|
||||
"""
|
||||
初始化EmbeddingManager
|
||||
|
||||
|
||||
Args:
|
||||
max_workers: 最大线程数
|
||||
chunk_size: 每个线程处理的数据块大小
|
||||
|
||||
@@ -426,9 +426,7 @@ class KGManager:
|
||||
# 获取最终结果
|
||||
# 从搜索结果中提取文段节点的结果
|
||||
passage_node_res = [
|
||||
(node_key, score)
|
||||
for node_key, score in ppr_res.items()
|
||||
if node_key.startswith("paragraph")
|
||||
(node_key, score) for node_key, score in ppr_res.items() if node_key.startswith("paragraph")
|
||||
]
|
||||
del ppr_res
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
raise DeprecationWarning("MemoryActiveManager is not used yet, please do not import it")
|
||||
from .lpmmconfig import global_config
|
||||
from .embedding_store import EmbeddingManager
|
||||
from .llm_client import LLMClient
|
||||
from .utils.dyn_topk import dyn_select_top_k
|
||||
from .lpmmconfig import global_config # noqa
|
||||
from .embedding_store import EmbeddingManager # noqa
|
||||
from .llm_client import LLMClient # noqa
|
||||
from .utils.dyn_topk import dyn_select_top_k # noqa
|
||||
|
||||
|
||||
class MemoryActiveManager:
|
||||
|
||||
@@ -8,7 +8,7 @@ def dyn_select_top_k(
|
||||
# 检查输入列表是否为空
|
||||
if not score:
|
||||
return []
|
||||
|
||||
|
||||
# 按照分数排序(降序)
|
||||
sorted_score = sorted(score, key=lambda x: x[1], reverse=True)
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,241 +0,0 @@
|
||||
import json
|
||||
import random
|
||||
|
||||
from json_repair import repair_json
|
||||
from typing import List, Tuple
|
||||
|
||||
from src.config.config import global_config, model_config
|
||||
from src.common.logger import get_logger
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.chat.utils.utils import parse_keywords_string
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages
|
||||
from src.chat.memory_system.Hippocampus import hippocampus_manager
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
|
||||
logger = get_logger("memory_activator")
|
||||
|
||||
|
||||
def get_keywords_from_json(json_str) -> List:
|
||||
"""
|
||||
从JSON字符串中提取关键词列表
|
||||
|
||||
Args:
|
||||
json_str: JSON格式的字符串
|
||||
|
||||
Returns:
|
||||
List[str]: 关键词列表
|
||||
"""
|
||||
try:
|
||||
# 使用repair_json修复JSON格式
|
||||
fixed_json = repair_json(json_str)
|
||||
|
||||
# 如果repair_json返回的是字符串,需要解析为Python对象
|
||||
result = json.loads(fixed_json) if isinstance(fixed_json, str) else fixed_json
|
||||
return result.get("keywords", [])
|
||||
except Exception as e:
|
||||
logger.error(f"解析关键词JSON失败: {e}")
|
||||
return []
|
||||
|
||||
|
||||
def init_prompt():
|
||||
# --- Group Chat Prompt ---
|
||||
memory_activator_prompt = """
|
||||
你需要根据以下信息来挑选合适的记忆编号
|
||||
以下是一段聊天记录,请根据这些信息,和下方的记忆,挑选和群聊内容有关的记忆编号
|
||||
|
||||
聊天记录:
|
||||
{obs_info_text}
|
||||
你想要回复的消息:
|
||||
{target_message}
|
||||
|
||||
记忆:
|
||||
{memory_info}
|
||||
|
||||
请输出一个json格式,包含以下字段:
|
||||
{{
|
||||
"memory_ids": "记忆1编号,记忆2编号,记忆3编号,......"
|
||||
}}
|
||||
不要输出其他多余内容,只输出json格式就好
|
||||
"""
|
||||
|
||||
Prompt(memory_activator_prompt, "memory_activator_prompt")
|
||||
|
||||
|
||||
class MemoryActivator:
|
||||
def __init__(self):
|
||||
self.key_words_model = LLMRequest(
|
||||
model_set=model_config.model_task_config.utils_small,
|
||||
request_type="memory.activator",
|
||||
)
|
||||
# 用于记忆选择的 LLM 模型
|
||||
self.memory_selection_model = LLMRequest(
|
||||
model_set=model_config.model_task_config.utils_small,
|
||||
request_type="memory.selection",
|
||||
)
|
||||
|
||||
async def activate_memory_with_chat_history(
|
||||
self, target_message, chat_history: List[DatabaseMessages]
|
||||
) -> List[Tuple[str, str]]:
|
||||
"""
|
||||
激活记忆
|
||||
"""
|
||||
# 如果记忆系统被禁用,直接返回空列表
|
||||
if not global_config.memory.enable_memory:
|
||||
return []
|
||||
|
||||
keywords_list = set()
|
||||
|
||||
for msg in chat_history:
|
||||
keywords = parse_keywords_string(msg.key_words)
|
||||
if keywords:
|
||||
if len(keywords_list) < 30:
|
||||
# 最多容纳30个关键词
|
||||
keywords_list.update(keywords)
|
||||
logger.debug(f"提取关键词: {keywords_list}")
|
||||
else:
|
||||
break
|
||||
|
||||
if not keywords_list:
|
||||
logger.debug("没有提取到关键词,返回空记忆列表")
|
||||
return []
|
||||
|
||||
# 从海马体获取相关记忆
|
||||
related_memory = await hippocampus_manager.get_memory_from_topic(
|
||||
valid_keywords=list(keywords_list), max_memory_num=5, max_memory_length=3, max_depth=3
|
||||
)
|
||||
|
||||
# logger.info(f"当前记忆关键词: {keywords_list}")
|
||||
logger.debug(f"获取到的记忆: {related_memory}")
|
||||
|
||||
if not related_memory:
|
||||
logger.debug("海马体没有返回相关记忆")
|
||||
return []
|
||||
|
||||
used_ids = set()
|
||||
candidate_memories = []
|
||||
|
||||
# 为每个记忆分配随机ID并过滤相关记忆
|
||||
for memory in related_memory:
|
||||
keyword, content = memory
|
||||
found = any(kw in content for kw in keywords_list)
|
||||
if found:
|
||||
# 随机分配一个不重复的2位数id
|
||||
while True:
|
||||
random_id = "{:02d}".format(random.randint(0, 99))
|
||||
if random_id not in used_ids:
|
||||
used_ids.add(random_id)
|
||||
break
|
||||
candidate_memories.append({"memory_id": random_id, "keyword": keyword, "content": content})
|
||||
|
||||
if not candidate_memories:
|
||||
logger.info("没有找到相关的候选记忆")
|
||||
return []
|
||||
|
||||
# 如果只有少量记忆,直接返回
|
||||
if len(candidate_memories) <= 2:
|
||||
logger.debug(f"候选记忆较少({len(candidate_memories)}个),直接返回")
|
||||
# 转换为 (keyword, content) 格式
|
||||
return [(mem["keyword"], mem["content"]) for mem in candidate_memories]
|
||||
|
||||
return await self._select_memories_with_llm(target_message, chat_history, candidate_memories)
|
||||
|
||||
async def _select_memories_with_llm(
|
||||
self, target_message, chat_history: List[DatabaseMessages], candidate_memories
|
||||
) -> List[Tuple[str, str]]:
|
||||
"""
|
||||
使用 LLM 选择合适的记忆
|
||||
|
||||
Args:
|
||||
target_message: 目标消息
|
||||
chat_history_prompt: 聊天历史
|
||||
candidate_memories: 候选记忆列表,每个记忆包含 memory_id、keyword、content
|
||||
|
||||
Returns:
|
||||
List[Tuple[str, str]]: 选择的记忆列表,格式为 (keyword, content)
|
||||
"""
|
||||
try:
|
||||
# 构建聊天历史字符串
|
||||
obs_info_text = build_readable_messages(
|
||||
chat_history,
|
||||
replace_bot_name=True,
|
||||
timestamp_mode="relative",
|
||||
read_mark=0.0,
|
||||
show_actions=True,
|
||||
)
|
||||
|
||||
# 构建记忆信息字符串
|
||||
memory_lines = []
|
||||
for memory in candidate_memories:
|
||||
memory_id = memory["memory_id"]
|
||||
keyword = memory["keyword"]
|
||||
content = memory["content"]
|
||||
|
||||
# 将 content 列表转换为字符串
|
||||
if isinstance(content, list):
|
||||
content_str = " | ".join(str(item) for item in content)
|
||||
else:
|
||||
content_str = str(content)
|
||||
|
||||
memory_lines.append(f"记忆编号 {memory_id}: [关键词: {keyword}] {content_str}")
|
||||
|
||||
memory_info = "\n".join(memory_lines)
|
||||
|
||||
# 获取并格式化 prompt
|
||||
prompt_template = await global_prompt_manager.get_prompt_async("memory_activator_prompt")
|
||||
formatted_prompt = prompt_template.format(
|
||||
obs_info_text=obs_info_text, target_message=target_message, memory_info=memory_info
|
||||
)
|
||||
|
||||
# 调用 LLM
|
||||
response, (reasoning_content, model_name, _) = await self.memory_selection_model.generate_response_async(
|
||||
formatted_prompt, temperature=0.3, max_tokens=150
|
||||
)
|
||||
|
||||
if global_config.debug.show_prompt:
|
||||
logger.info(f"记忆选择 prompt: {formatted_prompt}")
|
||||
logger.info(f"LLM 记忆选择响应: {response}")
|
||||
else:
|
||||
logger.debug(f"记忆选择 prompt: {formatted_prompt}")
|
||||
logger.debug(f"LLM 记忆选择响应: {response}")
|
||||
|
||||
# 解析响应获取选择的记忆编号
|
||||
try:
|
||||
fixed_json = repair_json(response)
|
||||
|
||||
# 解析为 Python 对象
|
||||
result = json.loads(fixed_json) if isinstance(fixed_json, str) else fixed_json
|
||||
|
||||
# 提取 memory_ids 字段并解析逗号分隔的编号
|
||||
if memory_ids_str := result.get("memory_ids", ""):
|
||||
memory_ids = [mid.strip() for mid in str(memory_ids_str).split(",") if mid.strip()]
|
||||
# 过滤掉空字符串和无效编号
|
||||
valid_memory_ids = [mid for mid in memory_ids if mid and len(mid) <= 3]
|
||||
selected_memory_ids = valid_memory_ids
|
||||
else:
|
||||
selected_memory_ids = []
|
||||
except Exception as e:
|
||||
logger.error(f"解析记忆选择响应失败: {e}", exc_info=True)
|
||||
selected_memory_ids = []
|
||||
|
||||
# 根据编号筛选记忆
|
||||
selected_memories = []
|
||||
memory_id_to_memory = {mem["memory_id"]: mem for mem in candidate_memories}
|
||||
|
||||
selected_memories = [
|
||||
memory_id_to_memory[memory_id] for memory_id in selected_memory_ids if memory_id in memory_id_to_memory
|
||||
]
|
||||
logger.info(f"LLM 选择的记忆编号: {selected_memory_ids}")
|
||||
logger.info(f"最终选择的记忆数量: {len(selected_memories)}")
|
||||
|
||||
# 转换为 (keyword, content) 格式
|
||||
return [(mem["keyword"], mem["content"]) for mem in selected_memories]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"LLM 选择记忆时出错: {e}", exc_info=True)
|
||||
# 出错时返回前3个候选记忆作为备选,转换为 (keyword, content) 格式
|
||||
return [(mem["keyword"], mem["content"]) for mem in candidate_memories[:3]]
|
||||
|
||||
|
||||
init_prompt()
|
||||
@@ -3,19 +3,18 @@ import os
|
||||
import re
|
||||
|
||||
from typing import Dict, Any, Optional
|
||||
from maim_message import UserInfo
|
||||
from maim_message import UserInfo, Seg, GroupInfo
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.mood.mood_manager import mood_manager # 导入情绪管理器
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager, ChatStream
|
||||
from src.chat.message_receive.message import MessageRecv, MessageRecvS4U
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.message_receive.message import MessageRecv
|
||||
from src.chat.message_receive.storage import MessageStorage
|
||||
from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiver
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.plugin_system.core import component_registry, events_manager, global_announcement_manager
|
||||
from src.plugin_system.base import BaseCommand, EventType
|
||||
from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor
|
||||
from src.person_info.person_info import Person
|
||||
|
||||
# 定义日志配置
|
||||
@@ -27,7 +26,7 @@ PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../..
|
||||
logger = get_logger("chat")
|
||||
|
||||
|
||||
def _check_ban_words(text: str, chat: ChatStream, userinfo: UserInfo) -> bool:
|
||||
def _check_ban_words(text: str, userinfo: UserInfo, group_info: Optional[GroupInfo] = None) -> bool:
|
||||
"""检查消息是否包含过滤词
|
||||
|
||||
Args:
|
||||
@@ -40,14 +39,14 @@ def _check_ban_words(text: str, chat: ChatStream, userinfo: UserInfo) -> bool:
|
||||
"""
|
||||
for word in global_config.message_receive.ban_words:
|
||||
if word in text:
|
||||
chat_name = chat.group_info.group_name if chat.group_info else "私聊"
|
||||
chat_name = group_info.group_name if group_info else "私聊"
|
||||
logger.info(f"[{chat_name}]{userinfo.user_nickname}:{text}")
|
||||
logger.info(f"[过滤词识别]消息中含有{word},filtered")
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _check_ban_regex(text: str, chat: ChatStream, userinfo: UserInfo) -> bool:
|
||||
def _check_ban_regex(text: str, userinfo: UserInfo, group_info: Optional[GroupInfo] = None) -> bool:
|
||||
"""检查消息是否匹配过滤正则表达式
|
||||
|
||||
Args:
|
||||
@@ -58,9 +57,13 @@ def _check_ban_regex(text: str, chat: ChatStream, userinfo: UserInfo) -> bool:
|
||||
Returns:
|
||||
bool: 是否匹配过滤正则
|
||||
"""
|
||||
# 检查text是否为None或空字符串
|
||||
if text is None or not text:
|
||||
return False
|
||||
|
||||
for pattern in global_config.message_receive.ban_msgs_regex:
|
||||
if re.search(pattern, text):
|
||||
chat_name = chat.group_info.group_name if chat.group_info else "私聊"
|
||||
chat_name = group_info.group_name if group_info else "私聊"
|
||||
logger.info(f"[{chat_name}]{userinfo.user_nickname}:{text}")
|
||||
logger.info(f"[正则表达式过滤]消息匹配到{pattern},filtered")
|
||||
return True
|
||||
@@ -74,8 +77,6 @@ class ChatBot:
|
||||
self.mood_manager = mood_manager # 获取情绪管理器单例
|
||||
self.heartflow_message_receiver = HeartFCMessageReceiver() # 新增
|
||||
|
||||
self.s4u_message_processor = S4UMessageProcessor()
|
||||
|
||||
async def _ensure_started(self):
|
||||
"""确保所有任务已启动"""
|
||||
if not self._started:
|
||||
@@ -149,33 +150,29 @@ class ChatBot:
|
||||
if message.message_info.message_id == "notice":
|
||||
message.is_notify = True
|
||||
logger.info("notice消息")
|
||||
# print(message)
|
||||
print(message)
|
||||
|
||||
return True
|
||||
|
||||
async def do_s4u(self, message_data: Dict[str, Any]):
|
||||
message = MessageRecvS4U(message_data)
|
||||
group_info = message.message_info.group_info
|
||||
user_info = message.message_info.user_info
|
||||
|
||||
get_chat_manager().register_message(message)
|
||||
chat = await get_chat_manager().get_or_create_stream(
|
||||
platform=message.message_info.platform, # type: ignore
|
||||
user_info=user_info, # type: ignore
|
||||
group_info=group_info,
|
||||
)
|
||||
|
||||
message.update_chat_stream(chat)
|
||||
|
||||
# 处理消息内容
|
||||
await message.process()
|
||||
|
||||
_ = Person.register_person(platform=message.message_info.platform, user_id=message.message_info.user_info.user_id,nickname=user_info.user_nickname) # type: ignore
|
||||
|
||||
await self.s4u_message_processor.process_message(message)
|
||||
|
||||
return
|
||||
|
||||
async def echo_message_process(self, raw_data: Dict[str, Any]) -> None:
|
||||
"""
|
||||
用于专门处理回送消息ID的函数
|
||||
"""
|
||||
message_data: Dict[str, Any] = raw_data.get("content", {})
|
||||
if not message_data:
|
||||
return
|
||||
message_type = message_data.get("type")
|
||||
if message_type != "echo":
|
||||
return
|
||||
mmc_message_id = message_data.get("echo")
|
||||
actual_message_id = message_data.get("actual_id")
|
||||
if MessageStorage.update_message(mmc_message_id, actual_message_id):
|
||||
logger.debug(f"更新消息ID成功: {mmc_message_id} -> {actual_message_id}")
|
||||
else:
|
||||
logger.warning(f"更新消息ID失败: {mmc_message_id} -> {actual_message_id}")
|
||||
|
||||
async def message_process(self, message_data: Dict[str, Any]) -> None:
|
||||
"""处理转化后的统一格式消息
|
||||
这个函数本质是预处理一些数据,根据配置信息和消息内容,预处理消息,并分发到合适的消息处理器中
|
||||
@@ -194,11 +191,6 @@ class ChatBot:
|
||||
# 确保所有任务已启动
|
||||
await self._ensure_started()
|
||||
|
||||
platform = message_data["message_info"].get("platform")
|
||||
|
||||
if platform == "amaidesu_default":
|
||||
await self.do_s4u(message_data)
|
||||
return
|
||||
|
||||
if message_data["message_info"].get("group_info") is not None:
|
||||
message_data["message_info"]["group_info"]["group_id"] = str(
|
||||
@@ -211,18 +203,35 @@ class ChatBot:
|
||||
# print(message_data)
|
||||
# logger.debug(str(message_data))
|
||||
message = MessageRecv(message_data)
|
||||
group_info = message.message_info.group_info
|
||||
user_info = message.message_info.user_info
|
||||
|
||||
continue_flag, modified_message = await events_manager.handle_mai_events(
|
||||
EventType.ON_MESSAGE_PRE_PROCESS, message
|
||||
)
|
||||
if not continue_flag:
|
||||
return
|
||||
if modified_message and modified_message._modify_flags.modify_message_segments:
|
||||
message.message_segment = Seg(type="seglist", data=modified_message.message_segments)
|
||||
|
||||
if await self.handle_notice_message(message):
|
||||
# return
|
||||
pass
|
||||
|
||||
group_info = message.message_info.group_info
|
||||
user_info = message.message_info.user_info
|
||||
if message.message_info.additional_config:
|
||||
sent_message = message.message_info.additional_config.get("echo", False)
|
||||
if sent_message: # 这一段只是为了在一切处理前劫持上报的自身消息,用于更新message_id,需要ada支持上报事件,实际测试中不会对正常使用造成任何问题
|
||||
await MessageStorage.update_message(message)
|
||||
return
|
||||
# 处理消息内容,生成纯文本
|
||||
await message.process()
|
||||
|
||||
# 过滤检查
|
||||
if _check_ban_words(
|
||||
message.processed_plain_text,
|
||||
user_info, # type: ignore
|
||||
group_info,
|
||||
) or _check_ban_regex(
|
||||
message.raw_message, # type: ignore
|
||||
user_info, # type: ignore
|
||||
group_info,
|
||||
):
|
||||
return
|
||||
|
||||
get_chat_manager().register_message(message)
|
||||
|
||||
@@ -234,21 +243,10 @@ class ChatBot:
|
||||
|
||||
message.update_chat_stream(chat)
|
||||
|
||||
# 处理消息内容,生成纯文本
|
||||
await message.process()
|
||||
|
||||
# if await self.check_ban_content(message):
|
||||
# logger.warning(f"检测到消息中含有违法,色情,暴力,反动,敏感内容,消息内容:{message.processed_plain_text},发送者:{message.message_info.user_info.user_nickname}")
|
||||
# return
|
||||
|
||||
# 过滤检查
|
||||
if _check_ban_words(message.processed_plain_text, chat, user_info) or _check_ban_regex( # type: ignore
|
||||
message.raw_message, # type: ignore
|
||||
chat,
|
||||
user_info, # type: ignore
|
||||
):
|
||||
return
|
||||
|
||||
# 命令处理 - 使用新插件系统检查并处理命令
|
||||
is_command, cmd_result, continue_process = await self._process_commands_with_new_system(message)
|
||||
|
||||
@@ -258,8 +256,11 @@ class ChatBot:
|
||||
logger.info(f"命令处理完成,跳过后续消息处理: {cmd_result}")
|
||||
return
|
||||
|
||||
if not await events_manager.handle_mai_events(EventType.ON_MESSAGE, message):
|
||||
continue_flag, modified_message = await events_manager.handle_mai_events(EventType.ON_MESSAGE, message)
|
||||
if not continue_flag:
|
||||
return
|
||||
if modified_message and modified_message._modify_flags.modify_plain_text:
|
||||
message.processed_plain_text = modified_message.plain_text
|
||||
|
||||
# 确认从接口发来的message是否有自定义的prompt模板信息
|
||||
if message.message_info.template_info and not message.message_info.template_info.template_default:
|
||||
|
||||
@@ -8,6 +8,7 @@ from typing import Optional, Any, List
|
||||
from maim_message import Seg, UserInfo, BaseMessageInfo, MessageBase
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.chat.utils.utils_image import get_image_manager
|
||||
from src.chat.utils.utils_voice import get_voice_text
|
||||
from .chat_stream import ChatStream
|
||||
@@ -79,6 +80,14 @@ class Message(MessageBase):
|
||||
if processed:
|
||||
segments_text.append(processed)
|
||||
return " ".join(segments_text)
|
||||
elif segment.type == "forward":
|
||||
segments_text = []
|
||||
for node_dict in segment.data:
|
||||
message = MessageBase.from_dict(node_dict) # type: ignore
|
||||
processed_text = await self._process_message_segments(message.message_segment)
|
||||
if processed_text:
|
||||
segments_text.append(f"{global_config.bot.nickname}: {processed_text}")
|
||||
return "[合并消息]: " + "\n-- ".join(segments_text)
|
||||
else:
|
||||
# 处理单个消息段
|
||||
return await self._process_single_segment(segment) # type: ignore
|
||||
@@ -129,6 +138,7 @@ class MessageRecv(Message):
|
||||
|
||||
这个方法必须在创建实例后显式调用,因为它包含异步操作。
|
||||
"""
|
||||
# print(f"self.message_segment: {self.message_segment}")
|
||||
self.processed_plain_text = await self._process_message_segments(self.message_segment)
|
||||
|
||||
async def _process_single_segment(self, segment: Seg) -> str:
|
||||
@@ -199,129 +209,6 @@ class MessageRecv(Message):
|
||||
return f"[处理失败的{segment.type}消息]"
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageRecvS4U(MessageRecv):
|
||||
def __init__(self, message_dict: dict[str, Any]):
|
||||
super().__init__(message_dict)
|
||||
self.is_gift = False
|
||||
self.is_fake_gift = False
|
||||
self.is_superchat = False
|
||||
self.gift_info = None
|
||||
self.gift_name = None
|
||||
self.gift_count: Optional[str] = None
|
||||
self.superchat_info = None
|
||||
self.superchat_price = None
|
||||
self.superchat_message_text = None
|
||||
self.is_screen = False
|
||||
self.is_internal = False
|
||||
self.voice_done = None
|
||||
|
||||
self.chat_info = None
|
||||
|
||||
async def process(self) -> None:
|
||||
self.processed_plain_text = await self._process_message_segments(self.message_segment)
|
||||
|
||||
async def _process_single_segment(self, segment: Seg) -> str:
|
||||
"""处理单个消息段
|
||||
|
||||
Args:
|
||||
segment: 消息段
|
||||
|
||||
Returns:
|
||||
str: 处理后的文本
|
||||
"""
|
||||
try:
|
||||
if segment.type == "text":
|
||||
self.is_voice = False
|
||||
self.is_picid = False
|
||||
self.is_emoji = False
|
||||
return segment.data # type: ignore
|
||||
elif segment.type == "image":
|
||||
self.is_voice = False
|
||||
# 如果是base64图片数据
|
||||
if isinstance(segment.data, str):
|
||||
self.has_picid = True
|
||||
self.is_picid = True
|
||||
self.is_emoji = False
|
||||
image_manager = get_image_manager()
|
||||
# print(f"segment.data: {segment.data}")
|
||||
_, processed_text = await image_manager.process_image(segment.data)
|
||||
return processed_text
|
||||
return "[发了一张图片,网卡了加载不出来]"
|
||||
elif segment.type == "emoji":
|
||||
self.has_emoji = True
|
||||
self.is_emoji = True
|
||||
self.is_picid = False
|
||||
if isinstance(segment.data, str):
|
||||
return await get_image_manager().get_emoji_description(segment.data)
|
||||
return "[发了一个表情包,网卡了加载不出来]"
|
||||
elif segment.type == "voice":
|
||||
self.has_picid = False
|
||||
self.is_picid = False
|
||||
self.is_emoji = False
|
||||
self.is_voice = True
|
||||
if isinstance(segment.data, str):
|
||||
return await get_voice_text(segment.data)
|
||||
return "[发了一段语音,网卡了加载不出来]"
|
||||
elif segment.type == "mention_bot":
|
||||
self.is_voice = False
|
||||
self.is_picid = False
|
||||
self.is_emoji = False
|
||||
self.is_mentioned = float(segment.data) # type: ignore
|
||||
return ""
|
||||
elif segment.type == "priority_info":
|
||||
self.is_voice = False
|
||||
self.is_picid = False
|
||||
self.is_emoji = False
|
||||
if isinstance(segment.data, dict):
|
||||
# 处理优先级信息
|
||||
self.priority_mode = "priority"
|
||||
self.priority_info = segment.data
|
||||
"""
|
||||
{
|
||||
'message_type': 'vip', # vip or normal
|
||||
'message_priority': 1.0, # 优先级,大为优先,float
|
||||
}
|
||||
"""
|
||||
return ""
|
||||
elif segment.type == "gift":
|
||||
self.is_voice = False
|
||||
self.is_gift = True
|
||||
# 解析gift_info,格式为"名称:数量"
|
||||
name, count = segment.data.split(":", 1) # type: ignore
|
||||
self.gift_info = segment.data
|
||||
self.gift_name = name.strip()
|
||||
self.gift_count = int(count.strip())
|
||||
return ""
|
||||
elif segment.type == "voice_done":
|
||||
msg_id = segment.data
|
||||
logger.info(f"voice_done: {msg_id}")
|
||||
self.voice_done = msg_id
|
||||
return ""
|
||||
elif segment.type == "superchat":
|
||||
self.is_superchat = True
|
||||
self.superchat_info = segment.data
|
||||
price, message_text = segment.data.split(":", 1) # type: ignore
|
||||
self.superchat_price = price.strip()
|
||||
self.superchat_message_text = message_text.strip()
|
||||
|
||||
self.processed_plain_text = str(self.superchat_message_text)
|
||||
self.processed_plain_text += (
|
||||
f"(注意:这是一条超级弹幕信息,价值{self.superchat_price}元,请你认真回复)"
|
||||
)
|
||||
|
||||
return self.processed_plain_text
|
||||
elif segment.type == "screen":
|
||||
self.is_screen = True
|
||||
self.screen_info = segment.data
|
||||
return "屏幕信息"
|
||||
else:
|
||||
return ""
|
||||
except Exception as e:
|
||||
logger.error(f"处理消息段失败: {str(e)}, 类型: {segment.type}, 数据: {segment.data}")
|
||||
return f"[处理失败的{segment.type}消息]"
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageProcessBase(Message):
|
||||
"""消息处理基类,用于处理中和发送中的消息"""
|
||||
|
||||
@@ -18,7 +18,7 @@ class MessageStorage:
|
||||
if isinstance(keywords, list):
|
||||
return json.dumps(keywords, ensure_ascii=False)
|
||||
return "[]"
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _deserialize_keywords(keywords_str: str) -> list:
|
||||
"""将JSON字符串反序列化为关键词列表"""
|
||||
@@ -33,7 +33,6 @@ class MessageStorage:
|
||||
async def store_message(message: Union[MessageSending, MessageRecv], chat_stream: ChatStream) -> None:
|
||||
"""存储消息到数据库"""
|
||||
try:
|
||||
# 莫越权 救世啊
|
||||
pattern = r"<MainRule>.*?</MainRule>|<schedule>.*?</schedule>|<UserMessage>.*?</UserMessage>"
|
||||
|
||||
# print(message)
|
||||
@@ -85,7 +84,7 @@ class MessageStorage:
|
||||
key_words = MessageStorage._serialize_keywords(message.key_words)
|
||||
key_words_lite = MessageStorage._serialize_keywords(message.key_words_lite)
|
||||
selected_expressions = ""
|
||||
|
||||
|
||||
chat_info_dict = chat_stream.to_dict()
|
||||
user_info_dict = message.message_info.user_info.to_dict() # type: ignore
|
||||
|
||||
@@ -143,31 +142,26 @@ class MessageStorage:
|
||||
|
||||
# 如果需要其他存储相关的函数,可以在这里添加
|
||||
@staticmethod
|
||||
async def update_message(
|
||||
message: MessageRecv,
|
||||
) -> None: # 用于实时更新数据库的自身发送消息ID,目前能处理text,reply,image和emoji
|
||||
"""更新最新一条匹配消息的message_id"""
|
||||
def update_message(mmc_message_id: str | None, qq_message_id: str | None) -> bool:
|
||||
"""实时更新数据库的自身发送消息ID"""
|
||||
try:
|
||||
if message.message_segment.type == "notify":
|
||||
mmc_message_id = message.message_segment.data.get("echo") # type: ignore
|
||||
qq_message_id = message.message_segment.data.get("actual_id") # type: ignore
|
||||
else:
|
||||
logger.info(f"更新消息ID错误,seg类型为{message.message_segment.type}")
|
||||
return
|
||||
if not qq_message_id:
|
||||
logger.info("消息不存在message_id,无法更新")
|
||||
return
|
||||
return False
|
||||
if matched_message := (
|
||||
Messages.select().where((Messages.message_id == mmc_message_id)).order_by(Messages.time.desc()).first()
|
||||
):
|
||||
# 更新找到的消息记录
|
||||
Messages.update(message_id=qq_message_id).where(Messages.id == matched_message.id).execute() # type: ignore
|
||||
logger.debug(f"更新消息ID成功: {matched_message.message_id} -> {qq_message_id}")
|
||||
return True
|
||||
else:
|
||||
logger.debug("未找到匹配的消息")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"更新消息ID失败: {e}")
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def replace_image_descriptions(text: str) -> str:
|
||||
|
||||
@@ -2,6 +2,7 @@ import asyncio
|
||||
import traceback
|
||||
|
||||
from rich.traceback import install
|
||||
from maim_message import Seg
|
||||
|
||||
from src.common.message.api import get_global_api
|
||||
from src.common.logger import get_logger
|
||||
@@ -15,7 +16,7 @@ install(extra_lines=3)
|
||||
logger = get_logger("sender")
|
||||
|
||||
|
||||
async def send_message(message: MessageSending, show_log=True) -> bool:
|
||||
async def _send_message(message: MessageSending, show_log=True) -> bool:
|
||||
"""合并后的消息发送函数,包含WS发送和日志记录"""
|
||||
message_preview = truncate_message(message.processed_plain_text, max_length=200)
|
||||
|
||||
@@ -32,7 +33,7 @@ async def send_message(message: MessageSending, show_log=True) -> bool:
|
||||
raise e # 重新抛出其他异常
|
||||
|
||||
|
||||
class HeartFCSender:
|
||||
class UniversalMessageSender:
|
||||
"""管理消息的注册、即时处理、发送和存储,并跟踪思考状态。"""
|
||||
|
||||
def __init__(self):
|
||||
@@ -66,8 +67,36 @@ class HeartFCSender:
|
||||
message.build_reply()
|
||||
logger.debug(f"[{chat_id}] 选择回复引用消息: {message.processed_plain_text[:20]}...")
|
||||
|
||||
from src.plugin_system.core.events_manager import events_manager
|
||||
from src.plugin_system.base.component_types import EventType
|
||||
|
||||
continue_flag, modified_message = await events_manager.handle_mai_events(
|
||||
EventType.POST_SEND_PRE_PROCESS, message=message, stream_id=chat_id
|
||||
)
|
||||
if not continue_flag:
|
||||
logger.info(f"[{chat_id}] 消息发送被插件取消: {str(message.message_segment)[:100]}...")
|
||||
return False
|
||||
if modified_message:
|
||||
if modified_message._modify_flags.modify_message_segments:
|
||||
message.message_segment = Seg(type="seglist", data=modified_message.message_segments)
|
||||
if modified_message._modify_flags.modify_plain_text:
|
||||
logger.warning(f"[{chat_id}] 插件修改了消息的纯文本内容,可能导致此内容被覆盖。")
|
||||
message.processed_plain_text = modified_message.plain_text
|
||||
|
||||
await message.process()
|
||||
|
||||
continue_flag, modified_message = await events_manager.handle_mai_events(
|
||||
EventType.POST_SEND, message=message, stream_id=chat_id
|
||||
)
|
||||
if not continue_flag:
|
||||
logger.info(f"[{chat_id}] 消息发送被插件取消: {str(message.message_segment)[:100]}...")
|
||||
return False
|
||||
if modified_message:
|
||||
if modified_message._modify_flags.modify_message_segments:
|
||||
message.message_segment = Seg(type="seglist", data=modified_message.message_segments)
|
||||
if modified_message._modify_flags.modify_plain_text:
|
||||
message.processed_plain_text = modified_message.plain_text
|
||||
|
||||
if typing:
|
||||
typing_time = calculate_typing_time(
|
||||
input_string=message.processed_plain_text,
|
||||
@@ -76,10 +105,22 @@ class HeartFCSender:
|
||||
)
|
||||
await asyncio.sleep(typing_time)
|
||||
|
||||
sent_msg = await send_message(message, show_log=show_log)
|
||||
sent_msg = await _send_message(message, show_log=show_log)
|
||||
if not sent_msg:
|
||||
return False
|
||||
|
||||
continue_flag, modified_message = await events_manager.handle_mai_events(
|
||||
EventType.AFTER_SEND, message=message, stream_id=chat_id
|
||||
)
|
||||
if not continue_flag:
|
||||
logger.info(f"[{chat_id}] 消息发送后续处理被插件取消: {str(message.message_segment)[:100]}...")
|
||||
return True
|
||||
if modified_message:
|
||||
if modified_message._modify_flags.modify_message_segments:
|
||||
message.message_segment = Seg(type="seglist", data=modified_message.message_segments)
|
||||
if modified_message._modify_flags.modify_plain_text:
|
||||
message.processed_plain_text = modified_message.plain_text
|
||||
|
||||
if storage_message:
|
||||
await self.storage.store_message(message, message.chat_stream)
|
||||
|
||||
|
||||
@@ -32,7 +32,7 @@ class ActionManager:
|
||||
self,
|
||||
action_name: str,
|
||||
action_data: dict,
|
||||
reasoning: str,
|
||||
action_reasoning: str,
|
||||
cycle_timers: dict,
|
||||
thinking_id: str,
|
||||
chat_stream: ChatStream,
|
||||
@@ -46,7 +46,7 @@ class ActionManager:
|
||||
Args:
|
||||
action_name: 动作名称
|
||||
action_data: 动作数据
|
||||
reasoning: 执行理由
|
||||
action_reasoning: 执行理由
|
||||
cycle_timers: 计时器字典
|
||||
thinking_id: 思考ID
|
||||
chat_stream: 聊天流
|
||||
@@ -77,7 +77,7 @@ class ActionManager:
|
||||
# 创建动作实例
|
||||
instance = component_class(
|
||||
action_data=action_data,
|
||||
reasoning=reasoning,
|
||||
action_reasoning=action_reasoning,
|
||||
cycle_timers=cycle_timers,
|
||||
thinking_id=thinking_id,
|
||||
chat_stream=chat_stream,
|
||||
@@ -124,4 +124,4 @@ class ActionManager:
|
||||
"""恢复到默认动作集"""
|
||||
actions_to_restore = list(self._using_actions.keys())
|
||||
self._using_actions = component_registry.get_default_actions()
|
||||
logger.debug(f"恢复动作集: 从 {actions_to_restore} 恢复到默认动作集 {list(self._using_actions.keys())}")
|
||||
logger.debug(f"恢复动作集: 从 {actions_to_restore} 恢复到默认动作集 {list(self._using_actions.keys())}")
|
||||
|
||||
@@ -103,25 +103,23 @@ class ActionModifier:
|
||||
self.action_manager.remove_action_from_using(action_name)
|
||||
logger.debug(f"{self.log_prefix}阶段二移除动作: {action_name},原因: {reason}")
|
||||
|
||||
|
||||
|
||||
# === 第三阶段:激活类型判定 ===
|
||||
# if chat_content is not None:
|
||||
# logger.debug(f"{self.log_prefix}开始激活类型判定阶段")
|
||||
# logger.debug(f"{self.log_prefix}开始激活类型判定阶段")
|
||||
|
||||
# 获取当前使用的动作集(经过第一阶段处理)
|
||||
# current_using_actions = self.action_manager.get_using_actions()
|
||||
# 获取当前使用的动作集(经过第一阶段处理)
|
||||
# current_using_actions = self.action_manager.get_using_actions()
|
||||
|
||||
# 获取因激活类型判定而需要移除的动作
|
||||
# removals_s3 = await self._get_deactivated_actions_by_type(
|
||||
# current_using_actions,
|
||||
# chat_content,
|
||||
# )
|
||||
# 获取因激活类型判定而需要移除的动作
|
||||
# removals_s3 = await self._get_deactivated_actions_by_type(
|
||||
# current_using_actions,
|
||||
# chat_content,
|
||||
# )
|
||||
|
||||
# 应用第三阶段的移除
|
||||
# for action_name, reason in removals_s3:
|
||||
# self.action_manager.remove_action_from_using(action_name)
|
||||
# logger.debug(f"{self.log_prefix}阶段三移除动作: {action_name},原因: {reason}")
|
||||
# 应用第三阶段的移除
|
||||
# for action_name, reason in removals_s3:
|
||||
# self.action_manager.remove_action_from_using(action_name)
|
||||
# logger.debug(f"{self.log_prefix}阶段三移除动作: {action_name},原因: {reason}")
|
||||
|
||||
# === 统一日志记录 ===
|
||||
all_removals = removals_s1 + removals_s2
|
||||
@@ -131,9 +129,7 @@ class ActionModifier:
|
||||
|
||||
available_actions = list(self.action_manager.get_using_actions().keys())
|
||||
available_actions_text = "、".join(available_actions) if available_actions else "无"
|
||||
logger.debug(
|
||||
f"{self.log_prefix} 当前可用动作: {available_actions_text}||移除: {removals_summary}"
|
||||
)
|
||||
logger.debug(f"{self.log_prefix} 当前可用动作: {available_actions_text}||移除: {removals_summary}")
|
||||
|
||||
def _check_action_associated_types(self, all_actions: Dict[str, ActionInfo], chat_context: ChatMessageContext):
|
||||
type_mismatched_actions: List[Tuple[str, str]] = []
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -6,7 +6,8 @@ import re
|
||||
|
||||
from typing import List, Optional, Dict, Any, Tuple
|
||||
from datetime import datetime
|
||||
from src.mais4u.mai_think import mai_thinking_manager
|
||||
from src.memory_system.Memory_chest import global_memory_chest
|
||||
from src.memory_system.questions import global_conflict_tracker
|
||||
from src.common.logger import get_logger
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.common.data_models.info_data_model import ActionPlannerInfo
|
||||
@@ -15,124 +16,36 @@ from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.chat.message_receive.message import UserInfo, Seg, MessageRecv, MessageSending
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
from src.chat.message_receive.uni_message_sender import HeartFCSender
|
||||
from src.chat.message_receive.uni_message_sender import UniversalMessageSender
|
||||
from src.chat.utils.timer_calculator import Timer # <--- Import Timer
|
||||
from src.chat.utils.utils import get_chat_type_and_target_info
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.chat.utils.prompt_builder import global_prompt_manager
|
||||
from src.mood.mood_manager import mood_manager
|
||||
from src.chat.utils.chat_message_builder import (
|
||||
build_readable_messages,
|
||||
get_raw_msg_before_timestamp_with_chat,
|
||||
replace_user_references,
|
||||
)
|
||||
from src.chat.express.expression_selector import expression_selector
|
||||
from src.chat.memory_system.memory_activator import MemoryActivator
|
||||
from src.mood.mood_manager import mood_manager
|
||||
from src.person_info.person_info import Person, is_person_known
|
||||
from src.plugin_system.apis.message_api import translate_pid_to_description
|
||||
|
||||
# from src.memory_system.memory_activator import MemoryActivator
|
||||
from src.person_info.person_info import Person
|
||||
from src.plugin_system.base.component_types import ActionInfo, EventType
|
||||
from src.plugin_system.apis import llm_api
|
||||
|
||||
from src.chat.replyer.prompt.lpmm_prompt import init_lpmm_prompt
|
||||
from src.chat.replyer.prompt.replyer_prompt import init_replyer_prompt
|
||||
from src.chat.replyer.prompt.rewrite_prompt import init_rewrite_prompt
|
||||
|
||||
init_lpmm_prompt()
|
||||
init_replyer_prompt()
|
||||
init_rewrite_prompt()
|
||||
|
||||
|
||||
logger = get_logger("replyer")
|
||||
|
||||
|
||||
def init_prompt():
|
||||
Prompt("你正在qq群里聊天,下面是群里在聊的内容:", "chat_target_group1")
|
||||
Prompt("你正在和{sender_name}聊天,这是你们之前聊的内容:", "chat_target_private1")
|
||||
Prompt("在群里聊天", "chat_target_group2")
|
||||
Prompt("和{sender_name}聊天", "chat_target_private2")
|
||||
|
||||
Prompt(
|
||||
"""
|
||||
{expression_habits_block}
|
||||
{relation_info_block}
|
||||
|
||||
{chat_target}
|
||||
{time_block}
|
||||
{chat_info}
|
||||
{identity}
|
||||
|
||||
你现在的心情是:{mood_state}
|
||||
你正在{chat_target_2},{reply_target_block}
|
||||
你想要对上述的发言进行回复,回复的具体内容(原句)是:{raw_reply}
|
||||
原因是:{reason}
|
||||
现在请你将这条具体内容改写成一条适合在群聊中发送的回复消息。
|
||||
你需要使用合适的语法和句法,参考聊天内容,组织一条日常且口语化的回复。请你修改你想表达的原句,符合你的表达风格和语言习惯
|
||||
{reply_style}
|
||||
你可以完全重组回复,保留最基本的表达含义就好,但重组后保持语意通顺。
|
||||
{keywords_reaction_prompt}
|
||||
{moderation_prompt}
|
||||
不要输出多余内容(包括前后缀,冒号和引号,括号,表情包,emoji,at或 @等 ),只输出一条回复就好。
|
||||
现在,你说:
|
||||
""",
|
||||
"default_expressor_prompt",
|
||||
)
|
||||
|
||||
# s4u 风格的 prompt 模板
|
||||
Prompt(
|
||||
"""{identity}
|
||||
你正在群聊中聊天,你想要回复 {sender_name} 的发言。同时,也有其他用户会参与聊天,你可以参考他们的回复内容,但是你现在想回复{sender_name}的发言。
|
||||
|
||||
{time_block}
|
||||
{background_dialogue_prompt}
|
||||
{core_dialogue_prompt}
|
||||
|
||||
{expression_habits_block}{tool_info_block}
|
||||
{knowledge_prompt}{memory_block}{relation_info_block}
|
||||
{extra_info_block}
|
||||
|
||||
{reply_target_block}
|
||||
你的心情:{mood_state}
|
||||
{reply_style}
|
||||
注意不要复读你说过的话
|
||||
{keywords_reaction_prompt}
|
||||
请注意不要输出多余内容(包括前后缀,冒号和引号,at或 @等 )。只输出回复内容。
|
||||
{moderation_prompt}
|
||||
不要输出多余内容(包括前后缀,冒号和引号,括号(),表情包,emoji,at或 @等 )。只输出一条回复就好
|
||||
现在,你说:""",
|
||||
"replyer_prompt",
|
||||
)
|
||||
|
||||
Prompt(
|
||||
"""{identity}
|
||||
{time_block}
|
||||
你现在正在一个QQ群里聊天,以下是正在进行的聊天内容:
|
||||
{background_dialogue_prompt}
|
||||
|
||||
{expression_habits_block}{tool_info_block}
|
||||
{knowledge_prompt}{memory_block}{relation_info_block}
|
||||
{extra_info_block}
|
||||
|
||||
你现在想补充说明你刚刚自己的发言内容:{target},原因是{reason}
|
||||
请你根据聊天内容,组织一条新回复。注意,{target} 是刚刚你自己的发言,你要在这基础上进一步发言,请按照你自己的角度来继续进行回复。
|
||||
注意保持上下文的连贯性。
|
||||
你现在的心情是:{mood_state}
|
||||
{reply_style}
|
||||
{keywords_reaction_prompt}
|
||||
请注意不要输出多余内容(包括前后缀,冒号和引号,at或 @等 )。只输出回复内容。
|
||||
{moderation_prompt}
|
||||
不要输出多余内容(包括前后缀,冒号和引号,括号(),表情包,emoji,at或 @等 )。只输出一条回复就好
|
||||
现在,你说:
|
||||
""",
|
||||
"replyer_self_prompt",
|
||||
)
|
||||
|
||||
Prompt(
|
||||
"""
|
||||
你是一个专门获取知识的助手。你的名字是{bot_name}。现在是{time_now}。
|
||||
群里正在进行的聊天内容:
|
||||
{chat_history}
|
||||
|
||||
现在,{sender}发送了内容:{target_message},你想要回复ta。
|
||||
请仔细分析聊天内容,考虑以下几点:
|
||||
1. 内容中是否包含需要查询信息的问题
|
||||
2. 是否有明确的知识获取指令
|
||||
|
||||
If you need to use the search tool, please directly call the function "lpmm_search_knowledge". If you do not need to use any tool, simply output "No tool needed".
|
||||
""",
|
||||
name="lpmm_get_knowledge_prompt",
|
||||
)
|
||||
|
||||
|
||||
class DefaultReplyer:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -142,8 +55,8 @@ class DefaultReplyer:
|
||||
self.express_model = LLMRequest(model_set=model_config.model_task_config.replyer, request_type=request_type)
|
||||
self.chat_stream = chat_stream
|
||||
self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_stream.stream_id)
|
||||
self.heart_fc_sender = HeartFCSender()
|
||||
self.memory_activator = MemoryActivator()
|
||||
self.heart_fc_sender = UniversalMessageSender()
|
||||
# self.memory_activator = MemoryActivator()
|
||||
|
||||
from src.plugin_system.core.tool_use import ToolExecutor # 延迟导入ToolExecutor,不然会循环依赖
|
||||
|
||||
@@ -159,6 +72,7 @@ class DefaultReplyer:
|
||||
from_plugin: bool = True,
|
||||
stream_id: Optional[str] = None,
|
||||
reply_message: Optional[DatabaseMessages] = None,
|
||||
reply_time_point: Optional[float] = time.time(),
|
||||
) -> Tuple[bool, LLMGenerationDataModel]:
|
||||
# sourcery skip: merge-nested-ifs
|
||||
"""
|
||||
@@ -192,6 +106,7 @@ class DefaultReplyer:
|
||||
enable_tool=enable_tool,
|
||||
reply_message=reply_message,
|
||||
reply_reason=reply_reason,
|
||||
reply_time_point=reply_time_point,
|
||||
)
|
||||
llm_response.prompt = prompt
|
||||
llm_response.selected_expressions = selected_expressions
|
||||
@@ -202,10 +117,14 @@ class DefaultReplyer:
|
||||
from src.plugin_system.core.events_manager import events_manager
|
||||
|
||||
if not from_plugin:
|
||||
if not await events_manager.handle_mai_events(
|
||||
continue_flag, modified_message = await events_manager.handle_mai_events(
|
||||
EventType.POST_LLM, None, prompt, None, stream_id=stream_id
|
||||
):
|
||||
)
|
||||
if not continue_flag:
|
||||
raise UserWarning("插件于请求前中断了内容生成")
|
||||
if modified_message and modified_message._modify_flags.modify_llm_prompt:
|
||||
llm_response.prompt = modified_message.llm_prompt
|
||||
prompt = str(modified_message.llm_prompt)
|
||||
|
||||
# 4. 调用 LLM 生成回复
|
||||
content = None
|
||||
@@ -219,10 +138,19 @@ class DefaultReplyer:
|
||||
llm_response.reasoning = reasoning_content
|
||||
llm_response.model = model_name
|
||||
llm_response.tool_calls = tool_call
|
||||
if not from_plugin and not await events_manager.handle_mai_events(
|
||||
continue_flag, modified_message = await events_manager.handle_mai_events(
|
||||
EventType.AFTER_LLM, None, prompt, llm_response, stream_id=stream_id
|
||||
):
|
||||
)
|
||||
if not from_plugin and not continue_flag:
|
||||
raise UserWarning("插件于请求后取消了内容生成")
|
||||
if modified_message:
|
||||
if modified_message._modify_flags.modify_llm_prompt:
|
||||
logger.warning("警告:插件在内容生成后才修改了prompt,此修改不会生效")
|
||||
llm_response.prompt = modified_message.llm_prompt # 虽然我不知道为什么在这里需要改prompt
|
||||
if modified_message._modify_flags.modify_llm_response_content:
|
||||
llm_response.content = modified_message.llm_response_content
|
||||
if modified_message._modify_flags.modify_llm_response_reasoning:
|
||||
llm_response.reasoning = modified_message.llm_response_reasoning
|
||||
except UserWarning as e:
|
||||
raise e
|
||||
except Exception as llm_e:
|
||||
@@ -293,24 +221,6 @@ class DefaultReplyer:
|
||||
traceback.print_exc()
|
||||
return False, llm_response
|
||||
|
||||
async def build_relation_info(self, sender: str, target: str):
|
||||
if not global_config.relationship.enable_relationship:
|
||||
return ""
|
||||
|
||||
if not sender:
|
||||
return ""
|
||||
|
||||
if sender == global_config.bot.nickname:
|
||||
return ""
|
||||
|
||||
# 获取用户ID
|
||||
person = Person(person_name=sender)
|
||||
if not is_person_known(person_name=sender):
|
||||
logger.warning(f"未找到用户 {sender} 的ID,跳过信息提取")
|
||||
return f"你完全不认识{sender},不理解ta的相关信息。"
|
||||
|
||||
return person.build_relationship()
|
||||
|
||||
async def build_expression_habits(self, chat_history: str, target: str) -> Tuple[str, List[int]]:
|
||||
# sourcery skip: for-append-to-extend
|
||||
"""构建表达习惯块
|
||||
@@ -348,46 +258,42 @@ class DefaultReplyer:
|
||||
expression_habits_block = ""
|
||||
expression_habits_title = ""
|
||||
if style_habits_str.strip():
|
||||
expression_habits_title = (
|
||||
"你可以参考以下的语言习惯,当情景合适就使用,但不要生硬使用,以合理的方式结合到你的回复中:"
|
||||
)
|
||||
expression_habits_title = "在回复时,你可以参考以下的语言习惯,不要生硬使用:"
|
||||
expression_habits_block += f"{style_habits_str}\n"
|
||||
|
||||
return f"{expression_habits_title}\n{expression_habits_block}", selected_ids
|
||||
|
||||
async def build_memory_block(self, chat_history: List[DatabaseMessages], target: str) -> str:
|
||||
|
||||
async def build_mood_state_prompt(self) -> str:
|
||||
"""构建情绪状态提示"""
|
||||
if not global_config.mood.enable_mood:
|
||||
return ""
|
||||
mood_state = await mood_manager.get_mood_by_chat_id(self.chat_stream.stream_id).get_mood()
|
||||
return f"你现在的心情是:{mood_state}"
|
||||
|
||||
async def build_memory_block(self) -> str:
|
||||
"""构建记忆块
|
||||
|
||||
Args:
|
||||
chat_history: 聊天历史记录
|
||||
target: 目标消息内容
|
||||
|
||||
Returns:
|
||||
str: 记忆信息字符串
|
||||
"""
|
||||
# if not global_config.memory.enable_memory:
|
||||
# return ""
|
||||
|
||||
if not global_config.memory.enable_memory:
|
||||
if global_memory_chest.get_chat_memories_as_string(self.chat_stream.stream_id):
|
||||
return f"你有以下记忆:\n{global_memory_chest.get_chat_memories_as_string(self.chat_stream.stream_id)}"
|
||||
else:
|
||||
return ""
|
||||
|
||||
instant_memory = None
|
||||
|
||||
running_memories = await self.memory_activator.activate_memory_with_chat_history(
|
||||
target_message=target, chat_history=chat_history
|
||||
)
|
||||
running_memories = None
|
||||
|
||||
if not running_memories:
|
||||
async def build_question_block(self) -> str:
|
||||
"""构建问题块"""
|
||||
# if not global_config.question.enable_question:
|
||||
# return ""
|
||||
questions = global_conflict_tracker.get_questions_by_chat_id(self.chat_stream.stream_id)
|
||||
questions_str = ""
|
||||
for question in questions:
|
||||
questions_str += f"- {question.question}\n"
|
||||
if questions_str:
|
||||
return f"你在聊天中,有以下问题想要得到解答:\n{questions_str}"
|
||||
else:
|
||||
return ""
|
||||
|
||||
memory_str = "以下是当前在聊天中,你回忆起的记忆:\n"
|
||||
for running_memory in running_memories:
|
||||
keywords, content = running_memory
|
||||
memory_str += f"- {keywords}:{content}\n"
|
||||
|
||||
if instant_memory:
|
||||
memory_str += f"- {instant_memory}\n"
|
||||
|
||||
return memory_str
|
||||
|
||||
|
||||
async def build_tool_info(self, chat_history: str, sender: str, target: str, enable_tool: bool = True) -> str:
|
||||
"""构建工具信息块
|
||||
@@ -417,7 +323,7 @@ class DefaultReplyer:
|
||||
content = tool_result.get("content", "")
|
||||
result_type = tool_result.get("type", "tool_result")
|
||||
|
||||
tool_info_str += f"- 【{tool_name}】{result_type}: {content}\n"
|
||||
tool_info_str += f"- 【{tool_name}】: {content}\n"
|
||||
|
||||
tool_info_str += "以上是你获取到的实时信息,请在回复时参考这些信息。"
|
||||
logger.info(f"获取到 {len(tool_results)} 个工具结果")
|
||||
@@ -453,6 +359,64 @@ class DefaultReplyer:
|
||||
target = parts[1].strip()
|
||||
return sender, target
|
||||
|
||||
def _replace_picids_with_descriptions(self, text: str) -> str:
|
||||
"""将文本中的[picid:xxx]替换为具体的图片描述
|
||||
|
||||
Args:
|
||||
text: 包含picid标记的文本
|
||||
|
||||
Returns:
|
||||
替换后的文本
|
||||
"""
|
||||
# 匹配 [picid:xxxxx] 格式
|
||||
pic_pattern = r"\[picid:([^\]]+)\]"
|
||||
|
||||
def replace_pic_id(match: re.Match) -> str:
|
||||
pic_id = match.group(1)
|
||||
description = translate_pid_to_description(pic_id)
|
||||
return f"[图片:{description}]"
|
||||
|
||||
return re.sub(pic_pattern, replace_pic_id, text)
|
||||
|
||||
def _analyze_target_content(self, target: str) -> Tuple[bool, bool, str, str]:
|
||||
"""分析target内容类型(基于原始picid格式)
|
||||
|
||||
Args:
|
||||
target: 目标消息内容(包含[picid:xxx]格式)
|
||||
|
||||
Returns:
|
||||
Tuple[bool, bool, str, str]: (是否只包含图片, 是否包含文字, 图片部分, 文字部分)
|
||||
"""
|
||||
if not target or not target.strip():
|
||||
return False, False, "", ""
|
||||
|
||||
# 检查是否只包含picid标记
|
||||
picid_pattern = r"\[picid:[^\]]+\]"
|
||||
picid_matches = re.findall(picid_pattern, target)
|
||||
|
||||
# 移除所有picid标记后检查是否还有文字内容
|
||||
text_without_picids = re.sub(picid_pattern, "", target).strip()
|
||||
|
||||
has_only_pics = len(picid_matches) > 0 and not text_without_picids
|
||||
has_text = bool(text_without_picids)
|
||||
|
||||
# 提取图片部分(转换为[图片:描述]格式)
|
||||
pic_part = ""
|
||||
if picid_matches:
|
||||
pic_descriptions = []
|
||||
for picid_match in picid_matches:
|
||||
pic_id = picid_match[7:-1] # 提取picid:xxx中的xxx部分(从第7个字符开始)
|
||||
description = translate_pid_to_description(pic_id)
|
||||
logger.info(f"图片ID: {pic_id}, 描述: {description}")
|
||||
# 如果description已经是[图片]格式,直接使用;否则包装为[图片:描述]格式
|
||||
if description == "[图片]":
|
||||
pic_descriptions.append(description)
|
||||
else:
|
||||
pic_descriptions.append(f"[图片:{description}]")
|
||||
pic_part = "".join(pic_descriptions)
|
||||
|
||||
return has_only_pics, has_text, pic_part, text_without_picids
|
||||
|
||||
async def build_keywords_reaction_prompt(self, target: Optional[str]) -> str:
|
||||
"""构建关键词反应提示
|
||||
|
||||
@@ -511,11 +475,10 @@ class DefaultReplyer:
|
||||
duration = end_time - start_time
|
||||
return name, result, duration
|
||||
|
||||
def build_s4u_chat_history_prompts(
|
||||
def build_chat_history_prompts(
|
||||
self, message_list_before_now: List[DatabaseMessages], target_user_id: str, sender: str
|
||||
) -> Tuple[str, str]:
|
||||
"""
|
||||
构建 s4u 风格的分离对话 prompt
|
||||
|
||||
Args:
|
||||
message_list_before_now: 历史消息列表
|
||||
@@ -539,18 +502,6 @@ class DefaultReplyer:
|
||||
except Exception as e:
|
||||
logger.error(f"处理消息记录时出错: {msg}, 错误: {e}")
|
||||
|
||||
# 构建背景对话 prompt
|
||||
all_dialogue_prompt = ""
|
||||
if message_list_before_now:
|
||||
latest_25_msgs = message_list_before_now[-int(global_config.chat.max_context_size) :]
|
||||
all_dialogue_prompt_str = build_readable_messages(
|
||||
latest_25_msgs,
|
||||
replace_bot_name=True,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
truncate=True,
|
||||
)
|
||||
all_dialogue_prompt = f"所有用户的发言:\n{all_dialogue_prompt_str}"
|
||||
|
||||
# 构建核心对话 prompt
|
||||
core_dialogue_prompt = ""
|
||||
if core_dialogue_list:
|
||||
@@ -583,60 +534,30 @@ class DefaultReplyer:
|
||||
--------------------------------
|
||||
"""
|
||||
|
||||
# 构建背景对话 prompt
|
||||
all_dialogue_prompt = ""
|
||||
if message_list_before_now:
|
||||
latest_25_msgs = message_list_before_now[-int(global_config.chat.max_context_size) :]
|
||||
all_dialogue_prompt_str = build_readable_messages(
|
||||
latest_25_msgs,
|
||||
replace_bot_name=True,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
truncate=True,
|
||||
)
|
||||
if core_dialogue_prompt:
|
||||
all_dialogue_prompt = f"所有用户的发言:\n{all_dialogue_prompt_str}"
|
||||
else:
|
||||
all_dialogue_prompt = f"{all_dialogue_prompt_str}"
|
||||
|
||||
return core_dialogue_prompt, all_dialogue_prompt
|
||||
|
||||
def build_mai_think_context(
|
||||
self,
|
||||
chat_id: str,
|
||||
memory_block: str,
|
||||
relation_info: str,
|
||||
time_block: str,
|
||||
chat_target_1: str,
|
||||
chat_target_2: str,
|
||||
mood_prompt: str,
|
||||
identity_block: str,
|
||||
sender: str,
|
||||
target: str,
|
||||
chat_info: str,
|
||||
) -> Any:
|
||||
"""构建 mai_think 上下文信息
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
memory_block: 记忆块内容
|
||||
relation_info: 关系信息
|
||||
time_block: 时间块内容
|
||||
chat_target_1: 聊天目标1
|
||||
chat_target_2: 聊天目标2
|
||||
mood_prompt: 情绪提示
|
||||
identity_block: 身份块内容
|
||||
sender: 发送者名称
|
||||
target: 目标消息内容
|
||||
chat_info: 聊天信息
|
||||
|
||||
Returns:
|
||||
Any: mai_think 实例
|
||||
"""
|
||||
mai_think = mai_thinking_manager.get_mai_think(chat_id)
|
||||
mai_think.memory_block = memory_block
|
||||
mai_think.relation_info_block = relation_info
|
||||
mai_think.time_block = time_block
|
||||
mai_think.chat_target = chat_target_1
|
||||
mai_think.chat_target_2 = chat_target_2
|
||||
mai_think.chat_info = chat_info
|
||||
mai_think.mood_state = mood_prompt
|
||||
mai_think.identity = identity_block
|
||||
mai_think.sender = sender
|
||||
mai_think.target = target
|
||||
return mai_think
|
||||
|
||||
async def build_actions_prompt(
|
||||
self, available_actions: Dict[str, ActionInfo], chosen_actions_info: Optional[List[ActionPlannerInfo]] = None
|
||||
) -> str:
|
||||
"""构建动作提示"""
|
||||
|
||||
action_descriptions = ""
|
||||
skip_names = ["emoji","build_memory","build_relation","reply"]
|
||||
skip_names = ["emoji", "build_memory", "build_relation", "reply"]
|
||||
if available_actions:
|
||||
action_descriptions = "除了进行回复之外,你可以做以下这些动作,不过这些动作由另一个模型决定,:\n"
|
||||
for action_name, action_info in available_actions.items():
|
||||
@@ -673,19 +594,18 @@ class DefaultReplyer:
|
||||
else:
|
||||
bot_nickname = ""
|
||||
|
||||
prompt_personality = (
|
||||
f"{global_config.personality.personality};"
|
||||
)
|
||||
prompt_personality = f"{global_config.personality.personality};"
|
||||
return f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}"
|
||||
|
||||
async def build_prompt_reply_context(
|
||||
self,
|
||||
reply_message: DatabaseMessages,
|
||||
reply_message: Optional[DatabaseMessages] = None,
|
||||
extra_info: str = "",
|
||||
reply_reason: str = "",
|
||||
available_actions: Optional[Dict[str, ActionInfo]] = None,
|
||||
chosen_actions: Optional[List[ActionPlannerInfo]] = None,
|
||||
enable_tool: bool = True,
|
||||
reply_time_point: Optional[float] = time.time(),
|
||||
) -> Tuple[str, List[int]]:
|
||||
"""
|
||||
构建回复器上下文
|
||||
@@ -720,26 +640,46 @@ class DefaultReplyer:
|
||||
sender = person_name
|
||||
target = reply_message.processed_plain_text
|
||||
|
||||
mood_prompt: str = ""
|
||||
if global_config.mood.enable_mood:
|
||||
chat_mood = mood_manager.get_mood_by_chat_id(chat_id)
|
||||
mood_prompt = chat_mood.mood_state
|
||||
|
||||
target = replace_user_references(target, chat_stream.platform, replace_bot_name=True)
|
||||
target = re.sub(r"\\[picid:[^\\]]+\\]", "[图片]", target)
|
||||
|
||||
# 在picid替换之前分析内容类型(防止prompt注入)
|
||||
has_only_pics, has_text, pic_part, text_part = self._analyze_target_content(target)
|
||||
|
||||
# 将[picid:xxx]替换为具体的图片描述
|
||||
target = self._replace_picids_with_descriptions(target)
|
||||
|
||||
message_list_before_now_long = get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=chat_id,
|
||||
timestamp=time.time(),
|
||||
timestamp=reply_time_point,
|
||||
limit=global_config.chat.max_context_size * 1,
|
||||
)
|
||||
|
||||
message_list_before_short = get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=chat_id,
|
||||
timestamp=time.time(),
|
||||
timestamp=reply_time_point,
|
||||
limit=int(global_config.chat.max_context_size * 0.33),
|
||||
)
|
||||
|
||||
person_list_short: List[Person] = []
|
||||
for msg in message_list_before_short:
|
||||
if (
|
||||
global_config.bot.qq_account == msg.user_info.user_id
|
||||
and global_config.bot.platform == msg.user_info.platform
|
||||
):
|
||||
continue
|
||||
if (
|
||||
reply_message
|
||||
and reply_message.user_info.user_id == msg.user_info.user_id
|
||||
and reply_message.user_info.platform == msg.user_info.platform
|
||||
):
|
||||
continue
|
||||
person = Person(platform=msg.user_info.platform, user_id=msg.user_info.user_id)
|
||||
if person.is_known:
|
||||
person_list_short.append(person)
|
||||
|
||||
# for person in person_list_short:
|
||||
# print(person.person_name)
|
||||
|
||||
chat_talking_prompt_short = build_readable_messages(
|
||||
message_list_before_short,
|
||||
replace_bot_name=True,
|
||||
@@ -753,25 +693,29 @@ class DefaultReplyer:
|
||||
self._time_and_run_task(
|
||||
self.build_expression_habits(chat_talking_prompt_short, target), "expression_habits"
|
||||
),
|
||||
self._time_and_run_task(self.build_relation_info(sender, target), "relation_info"),
|
||||
self._time_and_run_task(self.build_memory_block(message_list_before_short, target), "memory_block"),
|
||||
self._time_and_run_task(self.build_memory_block(), "memory_block"),
|
||||
self._time_and_run_task(
|
||||
self.build_tool_info(chat_talking_prompt_short, sender, target, enable_tool=enable_tool), "tool_info"
|
||||
),
|
||||
self._time_and_run_task(self.get_prompt_info(chat_talking_prompt_short, sender, target), "prompt_info"),
|
||||
self._time_and_run_task(self.build_actions_prompt(available_actions, chosen_actions), "actions_info"),
|
||||
self._time_and_run_task(self.build_personality_prompt(), "personality_prompt"),
|
||||
self._time_and_run_task(self.build_mood_state_prompt(), "mood_state_prompt"),
|
||||
self._time_and_run_task(self.build_question_block(), "question_block"),
|
||||
)
|
||||
|
||||
# 任务名称中英文映射
|
||||
task_name_mapping = {
|
||||
"expression_habits": "选取表达方式",
|
||||
"relation_info": "感受关系",
|
||||
"memory_block": "回忆",
|
||||
# "memory_block": "回忆",
|
||||
"memory_block": "记忆",
|
||||
"tool_info": "使用工具",
|
||||
"prompt_info": "获取知识",
|
||||
"actions_info": "动作信息",
|
||||
"personality_prompt": "人格信息",
|
||||
"mood_state_prompt": "情绪状态",
|
||||
"question_block": "问题",
|
||||
}
|
||||
|
||||
# 处理结果
|
||||
@@ -794,13 +738,16 @@ class DefaultReplyer:
|
||||
expression_habits_block, selected_expressions = results_dict["expression_habits"]
|
||||
expression_habits_block: str
|
||||
selected_expressions: List[int]
|
||||
relation_info: str = results_dict["relation_info"]
|
||||
# relation_info: str = results_dict["relation_info"]
|
||||
# memory_block: str = results_dict["memory_block"]
|
||||
memory_block: str = results_dict["memory_block"]
|
||||
tool_info: str = results_dict["tool_info"]
|
||||
prompt_info: str = results_dict["prompt_info"] # 直接使用格式化后的结果
|
||||
actions_info: str = results_dict["actions_info"]
|
||||
personality_prompt: str = results_dict["personality_prompt"]
|
||||
question_block: str = results_dict["question_block"]
|
||||
keywords_reaction_prompt = await self.build_keywords_reaction_prompt(target)
|
||||
mood_state_prompt: str = results_dict["mood_state_prompt"]
|
||||
|
||||
if extra_info:
|
||||
extra_info_block = f"以下是你在回复时需要参考的信息,现在请你阅读以下内容,进行决策\n{extra_info}\n以上是你在回复时需要参考的信息,现在请你阅读以下内容,进行决策"
|
||||
@@ -811,69 +758,50 @@ class DefaultReplyer:
|
||||
|
||||
moderation_prompt_block = "请不要输出违法违规内容,不要输出色情,暴力,政治相关内容,如有敏感内容,请规避。"
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
if sender:
|
||||
if is_group_chat:
|
||||
reply_target_block = (
|
||||
f"现在{sender}说的:{target}。引起了你的注意,你想要在群里发言或者回复这条消息。原因是{reply_reason}"
|
||||
)
|
||||
else: # private chat
|
||||
reply_target_block = (
|
||||
f"现在{sender}说的:{target}。引起了你的注意,针对这条消息回复。原因是{reply_reason}"
|
||||
)
|
||||
# 使用预先分析的内容类型结果
|
||||
if has_only_pics and not has_text:
|
||||
# 只包含图片
|
||||
reply_target_block = f"现在{sender}发送的图片:{pic_part}。引起了你的注意"
|
||||
elif has_text and pic_part:
|
||||
# 既有图片又有文字
|
||||
reply_target_block = f"现在{sender}发送了图片:{pic_part},并说:{text_part}。引起了你的注意"
|
||||
elif has_text:
|
||||
# 只包含文字
|
||||
reply_target_block = f"现在{sender}说的:{text_part}。引起了你的注意"
|
||||
else:
|
||||
# 其他情况(空内容等)
|
||||
reply_target_block = f"现在{sender}说的:{target}。引起了你的注意"
|
||||
else:
|
||||
reply_target_block = ""
|
||||
|
||||
# 构建分离的对话 prompt
|
||||
core_dialogue_prompt, background_dialogue_prompt = self.build_s4u_chat_history_prompts(
|
||||
core_dialogue_prompt, background_dialogue_prompt = self.build_chat_history_prompts(
|
||||
message_list_before_now_long, user_id, sender
|
||||
)
|
||||
|
||||
if global_config.bot.qq_account == user_id and platform == global_config.bot.platform:
|
||||
return await global_prompt_manager.format_prompt(
|
||||
"replyer_self_prompt",
|
||||
expression_habits_block=expression_habits_block,
|
||||
tool_info_block=tool_info,
|
||||
knowledge_prompt=prompt_info,
|
||||
memory_block=memory_block,
|
||||
relation_info_block=relation_info,
|
||||
extra_info_block=extra_info_block,
|
||||
identity=personality_prompt,
|
||||
action_descriptions=actions_info,
|
||||
mood_state=mood_prompt,
|
||||
background_dialogue_prompt=background_dialogue_prompt,
|
||||
time_block=time_block,
|
||||
target=target,
|
||||
reason=reply_reason,
|
||||
reply_style=global_config.personality.reply_style,
|
||||
keywords_reaction_prompt=keywords_reaction_prompt,
|
||||
moderation_prompt=moderation_prompt_block,
|
||||
), selected_expressions
|
||||
else:
|
||||
return await global_prompt_manager.format_prompt(
|
||||
"replyer_prompt",
|
||||
expression_habits_block=expression_habits_block,
|
||||
tool_info_block=tool_info,
|
||||
knowledge_prompt=prompt_info,
|
||||
memory_block=memory_block,
|
||||
relation_info_block=relation_info,
|
||||
extra_info_block=extra_info_block,
|
||||
identity=personality_prompt,
|
||||
action_descriptions=actions_info,
|
||||
sender_name=sender,
|
||||
mood_state=mood_prompt,
|
||||
background_dialogue_prompt=background_dialogue_prompt,
|
||||
time_block=time_block,
|
||||
core_dialogue_prompt=core_dialogue_prompt,
|
||||
reply_target_block=reply_target_block,
|
||||
reply_style=global_config.personality.reply_style,
|
||||
keywords_reaction_prompt=keywords_reaction_prompt,
|
||||
moderation_prompt=moderation_prompt_block,
|
||||
), selected_expressions
|
||||
return await global_prompt_manager.format_prompt(
|
||||
"replyer_prompt",
|
||||
expression_habits_block=expression_habits_block,
|
||||
tool_info_block=tool_info,
|
||||
memory_block=memory_block,
|
||||
knowledge_prompt=prompt_info,
|
||||
mood_state=mood_state_prompt,
|
||||
# memory_block=memory_block,
|
||||
# relation_info_block=relation_info,
|
||||
extra_info_block=extra_info_block,
|
||||
identity=personality_prompt,
|
||||
action_descriptions=actions_info,
|
||||
sender_name=sender,
|
||||
background_dialogue_prompt=background_dialogue_prompt,
|
||||
time_block=time_block,
|
||||
core_dialogue_prompt=core_dialogue_prompt,
|
||||
reply_target_block=reply_target_block,
|
||||
reply_style=global_config.personality.reply_style,
|
||||
keywords_reaction_prompt=keywords_reaction_prompt,
|
||||
moderation_prompt=moderation_prompt_block,
|
||||
question_block=question_block,
|
||||
), selected_expressions
|
||||
|
||||
async def build_prompt_rewrite_context(
|
||||
self,
|
||||
@@ -887,14 +815,12 @@ class DefaultReplyer:
|
||||
|
||||
sender, target = self._parse_reply_target(reply_to)
|
||||
target = replace_user_references(target, chat_stream.platform, replace_bot_name=True)
|
||||
target = re.sub(r"\\[picid:[^\\]]+\\]", "[图片]", target)
|
||||
|
||||
# 添加情绪状态获取
|
||||
if global_config.mood.enable_mood:
|
||||
chat_mood = mood_manager.get_mood_by_chat_id(chat_id)
|
||||
mood_prompt = chat_mood.mood_state
|
||||
else:
|
||||
mood_prompt = ""
|
||||
|
||||
# 在picid替换之前分析内容类型(防止prompt注入)
|
||||
has_only_pics, has_text, pic_part, text_part = self._analyze_target_content(target)
|
||||
|
||||
# 将[picid:xxx]替换为具体的图片描述
|
||||
target = self._replace_picids_with_descriptions(target)
|
||||
|
||||
message_list_before_now_half = get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=chat_id,
|
||||
@@ -910,9 +836,8 @@ class DefaultReplyer:
|
||||
)
|
||||
|
||||
# 并行执行2个构建任务
|
||||
(expression_habits_block, _), relation_info, personality_prompt = await asyncio.gather(
|
||||
(expression_habits_block, _), personality_prompt = await asyncio.gather(
|
||||
self.build_expression_habits(chat_talking_prompt_half, target),
|
||||
self.build_relation_info(sender, target),
|
||||
self.build_personality_prompt(),
|
||||
)
|
||||
|
||||
@@ -925,18 +850,39 @@ class DefaultReplyer:
|
||||
)
|
||||
|
||||
if sender and target:
|
||||
# 使用预先分析的内容类型结果
|
||||
if is_group_chat:
|
||||
if sender:
|
||||
reply_target_block = (
|
||||
f"现在{sender}说的:{target}。引起了你的注意,你想要在群里发言或者回复这条消息。"
|
||||
)
|
||||
if has_only_pics and not has_text:
|
||||
# 只包含图片
|
||||
reply_target_block = (
|
||||
f"现在{sender}发送的图片:{pic_part}。引起了你的注意,你想要在群里发言或者回复这条消息。"
|
||||
)
|
||||
elif has_text and pic_part:
|
||||
# 既有图片又有文字
|
||||
reply_target_block = (
|
||||
f"现在{sender}发送了图片:{pic_part},并说:{text_part}。引起了你的注意,你想要在群里发言或者回复这条消息。"
|
||||
)
|
||||
else:
|
||||
# 只包含文字
|
||||
reply_target_block = (
|
||||
f"现在{sender}说的:{text_part}。引起了你的注意,你想要在群里发言或者回复这条消息。"
|
||||
)
|
||||
elif target:
|
||||
reply_target_block = f"现在{target}引起了你的注意,你想要在群里发言或者回复这条消息。"
|
||||
else:
|
||||
reply_target_block = "现在,你想要在群里发言或者回复消息。"
|
||||
else: # private chat
|
||||
if sender:
|
||||
reply_target_block = f"现在{sender}说的:{target}。引起了你的注意,针对这条消息回复。"
|
||||
if has_only_pics and not has_text:
|
||||
# 只包含图片
|
||||
reply_target_block = f"现在{sender}发送的图片:{pic_part}。引起了你的注意,针对这条消息回复。"
|
||||
elif has_text and pic_part:
|
||||
# 既有图片又有文字
|
||||
reply_target_block = f"现在{sender}发送了图片:{pic_part},并说:{text_part}。引起了你的注意,针对这条消息回复。"
|
||||
else:
|
||||
# 只包含文字
|
||||
reply_target_block = f"现在{sender}说的:{text_part}。引起了你的注意,针对这条消息回复。"
|
||||
elif target:
|
||||
reply_target_block = f"现在{target}引起了你的注意,针对这条消息回复。"
|
||||
else:
|
||||
@@ -963,7 +909,7 @@ class DefaultReplyer:
|
||||
return await global_prompt_manager.format_prompt(
|
||||
template_name,
|
||||
expression_habits_block=expression_habits_block,
|
||||
relation_info_block=relation_info,
|
||||
# relation_info_block=relation_info,
|
||||
chat_target=chat_target_1,
|
||||
time_block=time_block,
|
||||
chat_info=chat_talking_prompt_half,
|
||||
@@ -972,7 +918,6 @@ class DefaultReplyer:
|
||||
reply_target_block=reply_target_block,
|
||||
raw_reply=raw_reply,
|
||||
reason=reason,
|
||||
mood_state=mood_prompt, # 添加情绪状态参数
|
||||
reply_style=global_config.personality.reply_style,
|
||||
keywords_reaction_prompt=keywords_reaction_prompt,
|
||||
moderation_prompt=moderation_prompt_block,
|
||||
@@ -1015,20 +960,18 @@ class DefaultReplyer:
|
||||
async def llm_generate_content(self, prompt: str):
|
||||
with Timer("LLM生成", {}): # 内部计时器,可选保留
|
||||
# 直接使用已初始化的模型实例
|
||||
logger.info(f"使用模型集生成回复: {', '.join(map(str, self.express_model.model_for_task.model_list))}")
|
||||
|
||||
logger.info(f"\n{prompt}\n")
|
||||
|
||||
|
||||
if global_config.debug.show_prompt:
|
||||
logger.info(f"\n{prompt}\n")
|
||||
else:
|
||||
logger.debug(f"\n{prompt}\n")
|
||||
logger.debug(f"\nreplyer_Prompt:{prompt}\n")
|
||||
|
||||
content, (reasoning_content, model_name, tool_calls) = await self.express_model.generate_response_async(
|
||||
prompt
|
||||
)
|
||||
|
||||
logger.debug(f"replyer生成内容: {content}")
|
||||
logger.info(f"使用 {model_name} 生成回复内容: {content}")
|
||||
return content, reasoning_content, model_name, tool_calls
|
||||
|
||||
async def get_prompt_info(self, message: str, sender: str, target: str):
|
||||
@@ -1115,6 +1058,3 @@ def weighted_sample_no_replacement(items, weights, k) -> list:
|
||||
pool.pop(idx)
|
||||
break
|
||||
return selected
|
||||
|
||||
|
||||
init_prompt()
|
||||
1012
src/chat/replyer/private_generator.py
Normal file
1012
src/chat/replyer/private_generator.py
Normal file
File diff suppressed because it is too large
Load Diff
20
src/chat/replyer/prompt/lpmm_prompt.py
Normal file
20
src/chat/replyer/prompt/lpmm_prompt.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from src.chat.utils.prompt_builder import Prompt
|
||||
# from src.memory_system.memory_activator import MemoryActivator
|
||||
|
||||
|
||||
def init_lpmm_prompt():
|
||||
Prompt(
|
||||
"""
|
||||
你是一个专门获取知识的助手。你的名字是{bot_name}。现在是{time_now}。
|
||||
群里正在进行的聊天内容:
|
||||
{chat_history}
|
||||
|
||||
现在,{sender}发送了内容:{target_message},你想要回复ta。
|
||||
请仔细分析聊天内容,考虑以下几点:
|
||||
1. 内容中是否包含需要查询信息的问题
|
||||
2. 是否有明确的知识获取指令
|
||||
|
||||
If you need to use the search tool, please directly call the function "lpmm_search_knowledge". If you do not need to use any tool, simply output "No tool needed".
|
||||
""",
|
||||
name="lpmm_get_knowledge_prompt",
|
||||
)
|
||||
71
src/chat/replyer/prompt/replyer_prompt.py
Normal file
71
src/chat/replyer/prompt/replyer_prompt.py
Normal file
@@ -0,0 +1,71 @@
|
||||
|
||||
from src.chat.utils.prompt_builder import Prompt
|
||||
# from src.chat.memory_system.memory_activator import MemoryActivator
|
||||
|
||||
|
||||
|
||||
def init_replyer_prompt():
|
||||
Prompt("你正在qq群里聊天,下面是群里正在聊的内容:", "chat_target_group1")
|
||||
Prompt("你正在和{sender_name}聊天,这是你们之前聊的内容:", "chat_target_private1")
|
||||
Prompt("正在群里聊天", "chat_target_group2")
|
||||
Prompt("和{sender_name}聊天", "chat_target_private2")
|
||||
|
||||
|
||||
Prompt(
|
||||
"""{knowledge_prompt}{tool_info_block}{extra_info_block}
|
||||
{expression_habits_block}{memory_block}{question_block}
|
||||
|
||||
你正在qq群里聊天,下面是群里正在聊的内容:
|
||||
{time_block}
|
||||
{background_dialogue_prompt}
|
||||
{core_dialogue_prompt}
|
||||
|
||||
{reply_target_block}。
|
||||
{identity}
|
||||
你正在群里聊天,现在请你读读之前的聊天记录,然后给出日常且口语化的回复,平淡一些,{mood_state}
|
||||
尽量简短一些。{keywords_reaction_prompt}请注意把握聊天内容,不要回复的太有条理,可以有个性。
|
||||
{reply_style}
|
||||
请注意不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只输出一句回复内容就好。
|
||||
{moderation_prompt}不要输出多余内容(包括前后缀,冒号和引号,括号,表情包,at或 @等 )。
|
||||
现在,你说:""",
|
||||
"replyer_prompt",
|
||||
)
|
||||
|
||||
|
||||
Prompt(
|
||||
"""{knowledge_prompt}{tool_info_block}{extra_info_block}
|
||||
{expression_habits_block}{memory_block}
|
||||
|
||||
你正在和{sender_name}聊天,这是你们之前聊的内容:
|
||||
{time_block}
|
||||
{dialogue_prompt}
|
||||
|
||||
{reply_target_block}。
|
||||
{identity}
|
||||
你正在和{sender_name}聊天,现在请你读读之前的聊天记录,然后给出日常且口语化的回复,平淡一些,{mood_state}
|
||||
尽量简短一些。{keywords_reaction_prompt}请注意把握聊天内容,不要回复的太有条理,可以有个性。
|
||||
{reply_style}
|
||||
请注意不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只输出回复内容。
|
||||
{moderation_prompt}不要输出多余内容(包括前后缀,冒号和引号,括号,表情包,at或 @等 )。""",
|
||||
"private_replyer_prompt",
|
||||
)
|
||||
|
||||
|
||||
Prompt(
|
||||
"""{knowledge_prompt}{tool_info_block}{extra_info_block}
|
||||
{expression_habits_block}{memory_block}
|
||||
|
||||
你正在和{sender_name}聊天,这是你们之前聊的内容:
|
||||
{time_block}
|
||||
{dialogue_prompt}
|
||||
|
||||
你现在想补充说明你刚刚自己的发言内容:{target},原因是{reason}
|
||||
请你根据聊天内容,组织一条新回复。注意,{target} 是刚刚你自己的发言,你要在这基础上进一步发言,请按照你自己的角度来继续进行回复。注意保持上下文的连贯性。{mood_state}
|
||||
{identity}
|
||||
尽量简短一些。{keywords_reaction_prompt}请注意把握聊天内容,不要回复的太有条理,可以有个性。
|
||||
{reply_style}
|
||||
请注意不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只输出回复内容。
|
||||
{moderation_prompt}不要输出多余内容(包括冒号和引号,括号,表情包,at或 @等 )。
|
||||
""",
|
||||
"private_replyer_self_prompt",
|
||||
)
|
||||
31
src/chat/replyer/prompt/rewrite_prompt.py
Normal file
31
src/chat/replyer/prompt/rewrite_prompt.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from src.chat.utils.prompt_builder import Prompt
|
||||
# from src.memory_system.memory_activator import MemoryActivator
|
||||
|
||||
|
||||
def init_rewrite_prompt():
|
||||
Prompt("你正在qq群里聊天,下面是群里正在聊的内容:", "chat_target_group1")
|
||||
Prompt("你正在和{sender_name}聊天,这是你们之前聊的内容:", "chat_target_private1")
|
||||
Prompt("正在群里聊天", "chat_target_group2")
|
||||
Prompt("和{sender_name}聊天", "chat_target_private2")
|
||||
|
||||
Prompt(
|
||||
"""
|
||||
{expression_habits_block}
|
||||
{chat_target}
|
||||
{chat_info}
|
||||
{identity}
|
||||
|
||||
你正在{chat_target_2},{reply_target_block}
|
||||
现在请你对这句内容进行改写,请你参考上述内容进行改写,原句是:{raw_reply}:
|
||||
原因是:{reason}
|
||||
现在请你将这条具体内容改写成一条适合在群聊中发送的回复消息。
|
||||
你需要使用合适的语法和句法,参考聊天内容,组织一条日常且口语化的回复。请你修改你想表达的原句,符合你的表达风格和语言习惯
|
||||
{reply_style}
|
||||
你可以完全重组回复,保留最基本的表达含义就好,但重组后保持语意通顺。
|
||||
{keywords_reaction_prompt}
|
||||
{moderation_prompt}
|
||||
不要输出多余内容(包括冒号和引号,表情包,emoji,at或 @等 ),只输出一条回复就好。
|
||||
改写后的回复:
|
||||
""",
|
||||
"default_expressor_prompt",
|
||||
)
|
||||
@@ -2,21 +2,22 @@ from typing import Dict, Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
||||
from src.chat.replyer.default_generator import DefaultReplyer
|
||||
from src.chat.replyer.group_generator import DefaultReplyer
|
||||
from src.chat.replyer.private_generator import PrivateReplyer
|
||||
|
||||
logger = get_logger("ReplyerManager")
|
||||
|
||||
|
||||
class ReplyerManager:
|
||||
def __init__(self):
|
||||
self._repliers: Dict[str, DefaultReplyer] = {}
|
||||
self._repliers: Dict[str, DefaultReplyer | PrivateReplyer] = {}
|
||||
|
||||
def get_replyer(
|
||||
self,
|
||||
chat_stream: Optional[ChatStream] = None,
|
||||
chat_id: Optional[str] = None,
|
||||
request_type: str = "replyer",
|
||||
) -> Optional[DefaultReplyer]:
|
||||
) -> Optional[DefaultReplyer | PrivateReplyer]:
|
||||
"""
|
||||
获取或创建回复器实例。
|
||||
|
||||
@@ -46,10 +47,17 @@ class ReplyerManager:
|
||||
return None
|
||||
|
||||
# model_configs 只在此时(初始化时)生效
|
||||
replyer = DefaultReplyer(
|
||||
chat_stream=target_stream,
|
||||
request_type=request_type,
|
||||
)
|
||||
if target_stream.group_info:
|
||||
replyer = DefaultReplyer(
|
||||
chat_stream=target_stream,
|
||||
request_type=request_type,
|
||||
)
|
||||
else:
|
||||
replyer = PrivateReplyer(
|
||||
chat_stream=target_stream,
|
||||
request_type=request_type,
|
||||
)
|
||||
|
||||
self._repliers[stream_id] = replyer
|
||||
return replyer
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ import time
|
||||
import random
|
||||
import re
|
||||
|
||||
from typing import List, Dict, Any, Tuple, Optional, Callable
|
||||
from typing import List, Dict, Any, Tuple, Optional, Callable, Iterable
|
||||
from rich.traceback import install
|
||||
|
||||
from src.config.config import global_config
|
||||
@@ -124,6 +124,7 @@ def get_raw_msg_by_timestamp_with_chat(
|
||||
# 只有当 limit 为 0 时才应用外部 sort
|
||||
sort_order = [("time", 1)] if limit == 0 else None
|
||||
# 直接将 limit_mode 传递给 find_messages
|
||||
# print(f"get_raw_msg_by_timestamp_with_chat: {chat_id}, {timestamp_start}, {timestamp_end}, {limit}, {limit_mode}, {filter_bot}, {filter_command}")
|
||||
return find_messages(
|
||||
message_filter=filter_query,
|
||||
sort=sort_order,
|
||||
@@ -215,6 +216,7 @@ def get_actions_by_timestamp_with_chat(
|
||||
chat_id=action.chat_id,
|
||||
chat_info_stream_id=action.chat_info_stream_id,
|
||||
chat_info_platform=action.chat_info_platform,
|
||||
action_reasoning=action.action_reasoning,
|
||||
)
|
||||
for action in actions
|
||||
]
|
||||
@@ -417,12 +419,6 @@ def _build_readable_messages_internal(
|
||||
timestamp = message.time
|
||||
content = message.display_message or message.processed_plain_text or ""
|
||||
|
||||
# 向下兼容
|
||||
if "ᶠ" in content:
|
||||
content = content.replace("ᶠ", "")
|
||||
if "ⁿ" in content:
|
||||
content = content.replace("ⁿ", "")
|
||||
|
||||
# 处理图片ID
|
||||
if show_pic:
|
||||
content = process_pic_ids(content)
|
||||
@@ -564,14 +560,12 @@ def build_readable_actions(actions: List[DatabaseActionRecords], mode: str = "re
|
||||
output_lines = []
|
||||
current_time = time.time()
|
||||
|
||||
# The get functions return actions sorted ascending by time. Let's reverse it to show newest first.
|
||||
# sorted_actions = sorted(actions, key=lambda x: x.get("time", 0), reverse=True)
|
||||
|
||||
for action in actions:
|
||||
action_time = action.time or current_time
|
||||
action_name = action.action_name or "未知动作"
|
||||
# action_reason = action.get(action_data")
|
||||
if action_name in ["no_action", "no_action"]:
|
||||
if action_name in ["no_reply", "no_reply"]:
|
||||
continue
|
||||
|
||||
action_prompt_display = action.action_prompt_display or "无具体内容"
|
||||
@@ -593,6 +587,7 @@ def build_readable_actions(actions: List[DatabaseActionRecords], mode: str = "re
|
||||
|
||||
line = f"{time_ago_str},你使用了“{action_name}”,具体内容是:“{action_prompt_display}”"
|
||||
output_lines.append(line)
|
||||
|
||||
|
||||
return "\n".join(output_lines)
|
||||
|
||||
@@ -628,6 +623,7 @@ def build_readable_messages_with_id(
|
||||
truncate: bool = False,
|
||||
show_actions: bool = False,
|
||||
show_pic: bool = True,
|
||||
remove_emoji_stickers: bool = False,
|
||||
) -> Tuple[str, List[Tuple[str, DatabaseMessages]]]:
|
||||
"""
|
||||
将消息列表转换为可读的文本格式,并返回原始(时间戳, 昵称, 内容)列表。
|
||||
@@ -644,6 +640,7 @@ def build_readable_messages_with_id(
|
||||
show_pic=show_pic,
|
||||
read_mark=read_mark,
|
||||
message_id_list=message_id_list,
|
||||
remove_emoji_stickers=remove_emoji_stickers,
|
||||
)
|
||||
|
||||
return formatted_string, message_id_list
|
||||
@@ -658,6 +655,7 @@ def build_readable_messages(
|
||||
show_actions: bool = False,
|
||||
show_pic: bool = True,
|
||||
message_id_list: Optional[List[Tuple[str, DatabaseMessages]]] = None,
|
||||
remove_emoji_stickers: bool = False,
|
||||
) -> str: # sourcery skip: extract-method
|
||||
"""
|
||||
将消息列表转换为可读的文本格式。
|
||||
@@ -672,13 +670,40 @@ def build_readable_messages(
|
||||
read_mark: 已读标记时间戳
|
||||
truncate: 是否截断长消息
|
||||
show_actions: 是否显示动作记录
|
||||
remove_emoji_stickers: 是否移除表情包并过滤空消息
|
||||
"""
|
||||
# WIP HERE and BELOW ----------------------------------------------
|
||||
# 创建messages的深拷贝,避免修改原始列表
|
||||
if not messages:
|
||||
return ""
|
||||
|
||||
copy_messages: List[MessageAndActionModel] = [MessageAndActionModel.from_DatabaseMessages(msg) for msg in messages]
|
||||
# 如果启用移除表情包,先过滤消息
|
||||
if remove_emoji_stickers:
|
||||
filtered_messages = []
|
||||
for msg in messages:
|
||||
# 获取消息内容
|
||||
content = msg.processed_plain_text
|
||||
# 移除表情包
|
||||
emoji_pattern = r"\[表情包:[^\]]+\]"
|
||||
content = re.sub(emoji_pattern, "", content)
|
||||
|
||||
# 如果移除表情包后内容不为空,则保留消息
|
||||
if content.strip():
|
||||
filtered_messages.append(msg)
|
||||
|
||||
messages = filtered_messages
|
||||
|
||||
copy_messages: List[MessageAndActionModel] = []
|
||||
for msg in messages:
|
||||
if remove_emoji_stickers:
|
||||
# 创建 MessageAndActionModel 但移除表情包
|
||||
model = MessageAndActionModel.from_DatabaseMessages(msg)
|
||||
# 移除表情包
|
||||
if model.processed_plain_text:
|
||||
model.processed_plain_text = re.sub(r"\[表情包:[^\]]+\]", "", model.processed_plain_text)
|
||||
copy_messages.append(model)
|
||||
else:
|
||||
copy_messages.append(MessageAndActionModel.from_DatabaseMessages(msg))
|
||||
|
||||
if show_actions and copy_messages:
|
||||
# 获取所有消息的时间范围
|
||||
@@ -862,17 +887,9 @@ async def build_anonymous_messages(messages: List[DatabaseMessages]) -> str:
|
||||
user_id = msg.user_info.user_id
|
||||
content = msg.display_message or msg.processed_plain_text or ""
|
||||
|
||||
if "ᶠ" in content:
|
||||
content = content.replace("ᶠ", "")
|
||||
if "ⁿ" in content:
|
||||
content = content.replace("ⁿ", "")
|
||||
|
||||
# 处理图片ID
|
||||
content = process_pic_ids(content)
|
||||
|
||||
# if not all([platform, user_id, timestamp is not None]):
|
||||
# continue
|
||||
|
||||
anon_name = get_anon_name(platform, user_id)
|
||||
# print(f"anon_name:{anon_name}")
|
||||
|
||||
@@ -909,6 +926,7 @@ async def build_anonymous_messages(messages: List[DatabaseMessages]) -> str:
|
||||
return formatted_string
|
||||
|
||||
|
||||
|
||||
async def get_person_id_list(messages: List[Dict[str, Any]]) -> List[str]:
|
||||
"""
|
||||
从消息列表中提取不重复的 person_id 列表 (忽略机器人自身)。
|
||||
@@ -937,3 +955,45 @@ async def get_person_id_list(messages: List[Dict[str, Any]]) -> List[str]:
|
||||
person_ids_set.add(person_id)
|
||||
|
||||
return list(person_ids_set) # 将集合转换为列表返回
|
||||
|
||||
|
||||
async def build_bare_messages(messages: List[DatabaseMessages]) -> str:
|
||||
"""
|
||||
构建简化版消息字符串,只包含processed_plain_text内容,不考虑用户名和时间戳
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
|
||||
Returns:
|
||||
只包含消息内容的字符串
|
||||
"""
|
||||
if not messages:
|
||||
return ""
|
||||
|
||||
output_lines = []
|
||||
|
||||
for msg in messages:
|
||||
# 获取纯文本内容
|
||||
content = msg.processed_plain_text or ""
|
||||
|
||||
# 处理图片ID
|
||||
pic_pattern = r"\[picid:[^\]]+\]"
|
||||
|
||||
def replace_pic_id(match):
|
||||
return "[图片]"
|
||||
|
||||
content = re.sub(pic_pattern, replace_pic_id, content)
|
||||
|
||||
# 处理用户引用格式,移除回复和@标记
|
||||
reply_pattern = r"回复<[^:<>]+:[^:<>]+>"
|
||||
content = re.sub(reply_pattern, "回复[某人]", content)
|
||||
|
||||
at_pattern = r"@<[^:<>]+:[^:<>]+>"
|
||||
content = re.sub(at_pattern, "@[某人]", content)
|
||||
|
||||
# 清理并添加到输出
|
||||
content = content.strip()
|
||||
if content:
|
||||
output_lines.append(content)
|
||||
|
||||
return "\n".join(output_lines)
|
||||
|
||||
@@ -151,7 +151,7 @@ class Prompt(str):
|
||||
|
||||
@staticmethod
|
||||
def _process_escaped_braces(template) -> str:
|
||||
"""处理模板中的转义花括号,将 \{ 和 \} 替换为临时标记""" # type: ignore
|
||||
"""处理模板中的转义花括号,替换为临时标记""" # type: ignore
|
||||
# 如果传入的是列表,将其转换为字符串
|
||||
if isinstance(template, list):
|
||||
template = "\n".join(str(item) for item in template)
|
||||
|
||||
@@ -385,18 +385,18 @@ class StatisticOutputTask(AsyncTask):
|
||||
time_cost_key = f"time_costs_by_{category.split('_')[-1]}"
|
||||
avg_key = f"avg_time_costs_by_{category.split('_')[-1]}"
|
||||
std_key = f"std_time_costs_by_{category.split('_')[-1]}"
|
||||
|
||||
|
||||
for item_name in stats[period_key][category]:
|
||||
time_costs = stats[period_key][time_cost_key].get(item_name, [])
|
||||
if time_costs:
|
||||
# 计算平均耗时
|
||||
avg_time_cost = sum(time_costs) / len(time_costs)
|
||||
stats[period_key][avg_key][item_name] = round(avg_time_cost, 3)
|
||||
|
||||
|
||||
# 计算标准差
|
||||
if len(time_costs) > 1:
|
||||
variance = sum((x - avg_time_cost) ** 2 for x in time_costs) / len(time_costs)
|
||||
std_time_cost = variance ** 0.5
|
||||
std_time_cost = variance**0.5
|
||||
stats[period_key][std_key][item_name] = round(std_time_cost, 3)
|
||||
else:
|
||||
stats[period_key][std_key][item_name] = 0.0
|
||||
@@ -506,8 +506,6 @@ class StatisticOutputTask(AsyncTask):
|
||||
break
|
||||
return stats
|
||||
|
||||
|
||||
|
||||
def _collect_all_statistics(self, now: datetime) -> Dict[str, Dict[str, Any]]:
|
||||
"""
|
||||
收集各时间段的统计数据
|
||||
@@ -639,7 +637,9 @@ class StatisticOutputTask(AsyncTask):
|
||||
cost = stats[COST_BY_MODEL][model_name]
|
||||
avg_time_cost = stats[AVG_TIME_COST_BY_MODEL][model_name]
|
||||
std_time_cost = stats[STD_TIME_COST_BY_MODEL][model_name]
|
||||
output.append(data_fmt.format(name, count, in_tokens, out_tokens, tokens, cost, avg_time_cost, std_time_cost))
|
||||
output.append(
|
||||
data_fmt.format(name, count, in_tokens, out_tokens, tokens, cost, avg_time_cost, std_time_cost)
|
||||
)
|
||||
|
||||
output.append("")
|
||||
return "\n".join(output)
|
||||
@@ -728,7 +728,9 @@ class StatisticOutputTask(AsyncTask):
|
||||
f"<td>{stat_data[STD_TIME_COST_BY_MODEL][model_name]:.1f} 秒</td>"
|
||||
f"</tr>"
|
||||
for model_name, count in sorted(stat_data[REQ_CNT_BY_MODEL].items())
|
||||
] if stat_data[REQ_CNT_BY_MODEL] else ["<tr><td colspan='8' style='text-align: center; color: #999;'>暂无数据</td></tr>"]
|
||||
]
|
||||
if stat_data[REQ_CNT_BY_MODEL]
|
||||
else ["<tr><td colspan='8' style='text-align: center; color: #999;'>暂无数据</td></tr>"]
|
||||
)
|
||||
# 按请求类型分类统计
|
||||
type_rows = "\n".join(
|
||||
@@ -744,7 +746,9 @@ class StatisticOutputTask(AsyncTask):
|
||||
f"<td>{stat_data[STD_TIME_COST_BY_TYPE][req_type]:.1f} 秒</td>"
|
||||
f"</tr>"
|
||||
for req_type, count in sorted(stat_data[REQ_CNT_BY_TYPE].items())
|
||||
] if stat_data[REQ_CNT_BY_TYPE] else ["<tr><td colspan='8' style='text-align: center; color: #999;'>暂无数据</td></tr>"]
|
||||
]
|
||||
if stat_data[REQ_CNT_BY_TYPE]
|
||||
else ["<tr><td colspan='8' style='text-align: center; color: #999;'>暂无数据</td></tr>"]
|
||||
)
|
||||
# 按模块分类统计
|
||||
module_rows = "\n".join(
|
||||
@@ -760,7 +764,9 @@ class StatisticOutputTask(AsyncTask):
|
||||
f"<td>{stat_data[STD_TIME_COST_BY_MODULE][module_name]:.1f} 秒</td>"
|
||||
f"</tr>"
|
||||
for module_name, count in sorted(stat_data[REQ_CNT_BY_MODULE].items())
|
||||
] if stat_data[REQ_CNT_BY_MODULE] else ["<tr><td colspan='8' style='text-align: center; color: #999;'>暂无数据</td></tr>"]
|
||||
]
|
||||
if stat_data[REQ_CNT_BY_MODULE]
|
||||
else ["<tr><td colspan='8' style='text-align: center; color: #999;'>暂无数据</td></tr>"]
|
||||
)
|
||||
|
||||
# 聊天消息统计
|
||||
@@ -768,7 +774,9 @@ class StatisticOutputTask(AsyncTask):
|
||||
[
|
||||
f"<tr><td>{self.name_mapping[chat_id][0]}</td><td>{count}</td></tr>"
|
||||
for chat_id, count in sorted(stat_data[MSG_CNT_BY_CHAT].items())
|
||||
] if stat_data[MSG_CNT_BY_CHAT] else ["<tr><td colspan='2' style='text-align: center; color: #999;'>暂无数据</td></tr>"]
|
||||
]
|
||||
if stat_data[MSG_CNT_BY_CHAT]
|
||||
else ["<tr><td colspan='2' style='text-align: center; color: #999;'>暂无数据</td></tr>"]
|
||||
)
|
||||
# 生成HTML
|
||||
return f"""
|
||||
|
||||
@@ -49,9 +49,9 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, bool, float
|
||||
reply_probability = 0.0
|
||||
is_at = False
|
||||
is_mentioned = False
|
||||
|
||||
|
||||
# 这部分怎么处理啊啊啊啊
|
||||
#我觉得可以给消息加一个 reply_probability_boost字段
|
||||
# 我觉得可以给消息加一个 reply_probability_boost字段
|
||||
if (
|
||||
message.message_info.additional_config is not None
|
||||
and message.message_info.additional_config.get("is_mentioned") is not None
|
||||
@@ -339,7 +339,7 @@ def process_llm_response(text: str, enable_splitter: bool = True, enable_chinese
|
||||
else:
|
||||
split_sentences = [cleaned_text]
|
||||
|
||||
sentences = []
|
||||
sentences: List[str] = []
|
||||
for sentence in split_sentences:
|
||||
if global_config.chinese_typo.enable and enable_chinese_typo:
|
||||
typoed_text, typo_corrections = typo_generator.create_typo_sentence(sentence)
|
||||
@@ -383,10 +383,6 @@ def calculate_typing_time(
|
||||
- 在所有输入结束后,额外加上回车时间0.3秒
|
||||
- 如果is_emoji为True,将使用固定1秒的输入时间
|
||||
"""
|
||||
# # 将0-1的唤醒度映射到-1到1
|
||||
# mood_arousal = mood_manager.current_mood.arousal
|
||||
# # 映射到0.5到2倍的速度系数
|
||||
# typing_speed_multiplier = 1.5**mood_arousal # 唤醒度为1时速度翻倍,为-1时速度减半
|
||||
# chinese_time *= 1 / typing_speed_multiplier
|
||||
# english_time *= 1 / typing_speed_multiplier
|
||||
# 计算中文字符数
|
||||
@@ -826,20 +822,48 @@ def parse_keywords_string(keywords_input) -> list[str]:
|
||||
return [keywords_str] if keywords_str else []
|
||||
|
||||
|
||||
|
||||
|
||||
def cut_key_words(concept_name: str) -> list[str]:
|
||||
"""对概念名称进行jieba分词,并过滤掉关键词列表中的关键词"""
|
||||
concept_name_tokens = list(jieba.cut(concept_name))
|
||||
|
||||
# 定义常见连词、停用词与标点
|
||||
conjunctions = {
|
||||
"和", "与", "及", "跟", "以及", "并且", "而且", "或", "或者", "并"
|
||||
}
|
||||
conjunctions = {"和", "与", "及", "跟", "以及", "并且", "而且", "或", "或者", "并"}
|
||||
stop_words = {
|
||||
"的", "了", "呢", "吗", "吧", "啊", "哦", "恩", "嗯", "呀", "嘛", "哇",
|
||||
"在", "是", "很", "也", "又", "就", "都", "还", "更", "最", "被", "把",
|
||||
"给", "对", "和", "与", "及", "跟", "并", "而且", "或者", "或", "以及"
|
||||
"的",
|
||||
"了",
|
||||
"呢",
|
||||
"吗",
|
||||
"吧",
|
||||
"啊",
|
||||
"哦",
|
||||
"恩",
|
||||
"嗯",
|
||||
"呀",
|
||||
"嘛",
|
||||
"哇",
|
||||
"在",
|
||||
"是",
|
||||
"很",
|
||||
"也",
|
||||
"又",
|
||||
"就",
|
||||
"都",
|
||||
"还",
|
||||
"更",
|
||||
"最",
|
||||
"被",
|
||||
"把",
|
||||
"给",
|
||||
"对",
|
||||
"和",
|
||||
"与",
|
||||
"及",
|
||||
"跟",
|
||||
"并",
|
||||
"而且",
|
||||
"或者",
|
||||
"或",
|
||||
"以及",
|
||||
}
|
||||
chinese_punctuations = set(",。!?、;:()【】《》“”‘’—…·-——,.!?;:()[]<>'\"/\\")
|
||||
|
||||
@@ -864,11 +888,16 @@ def cut_key_words(concept_name: str) -> list[str]:
|
||||
left = merged_tokens[-1]
|
||||
right = cleaned_tokens[i + 1]
|
||||
# 左右都需要是有效词
|
||||
if left and right \
|
||||
and left not in conjunctions and right not in conjunctions \
|
||||
and left not in stop_words and right not in stop_words \
|
||||
and not all(ch in chinese_punctuations for ch in left) \
|
||||
and not all(ch in chinese_punctuations for ch in right):
|
||||
if (
|
||||
left
|
||||
and right
|
||||
and left not in conjunctions
|
||||
and right not in conjunctions
|
||||
and left not in stop_words
|
||||
and right not in stop_words
|
||||
and not all(ch in chinese_punctuations for ch in left)
|
||||
and not all(ch in chinese_punctuations for ch in right)
|
||||
):
|
||||
# 合并为一个新词,并替换掉左侧与跳过右侧
|
||||
combined = f"{left}{tok}{right}"
|
||||
merged_tokens[-1] = combined
|
||||
@@ -889,7 +918,7 @@ def cut_key_words(concept_name: str) -> list[str]:
|
||||
if tok in stop_words:
|
||||
continue
|
||||
# if tok in ban_words:
|
||||
# continue
|
||||
# continue
|
||||
if all(ch in chinese_punctuations for ch in tok):
|
||||
continue
|
||||
if tok.strip() == "":
|
||||
@@ -899,4 +928,4 @@ def cut_key_words(concept_name: str) -> list[str]:
|
||||
result_tokens.append(tok)
|
||||
|
||||
filtered_concept_name_tokens = result_tokens
|
||||
return filtered_concept_name_tokens
|
||||
return filtered_concept_name_tokens
|
||||
|
||||
@@ -91,9 +91,10 @@ class ImageManager:
|
||||
desc_obj.save()
|
||||
except Exception as e:
|
||||
logger.error(f"保存描述到数据库失败 (Peewee): {str(e)}")
|
||||
|
||||
|
||||
async def get_emoji_tag(self, image_base64: str) -> str:
|
||||
from src.chat.emoji_system.emoji_manager import get_emoji_manager
|
||||
|
||||
emoji_manager = get_emoji_manager()
|
||||
if isinstance(image_base64, str):
|
||||
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
|
||||
@@ -120,6 +121,7 @@ class ImageManager:
|
||||
# 优先使用EmojiManager查询已注册表情包的描述
|
||||
try:
|
||||
from src.chat.emoji_system.emoji_manager import get_emoji_manager
|
||||
|
||||
emoji_manager = get_emoji_manager()
|
||||
tags = await emoji_manager.get_emoji_tag_by_hash(image_hash)
|
||||
if tags:
|
||||
@@ -144,14 +146,14 @@ class ImageManager:
|
||||
return "[表情包(GIF处理失败)]"
|
||||
vlm_prompt = "这是一个动态图表情包,每一张图代表了动态图的某一帧,黑色背景代表透明,描述一下表情包表达的情感和内容,描述细节,从互联网梗,meme的角度去分析"
|
||||
detailed_description, _ = await self.vlm.generate_response_for_image(
|
||||
vlm_prompt, image_base64_processed, "jpg", temperature=0.4, max_tokens=300
|
||||
vlm_prompt, image_base64_processed, "jpg", temperature=0.4
|
||||
)
|
||||
else:
|
||||
vlm_prompt = (
|
||||
"这是一个表情包,请详细描述一下表情包所表达的情感和内容,描述细节,从互联网梗,meme的角度去分析"
|
||||
)
|
||||
detailed_description, _ = await self.vlm.generate_response_for_image(
|
||||
vlm_prompt, image_base64, image_format, temperature=0.4, max_tokens=300
|
||||
vlm_prompt, image_base64, image_format, temperature=0.4
|
||||
)
|
||||
|
||||
if detailed_description is None:
|
||||
@@ -172,9 +174,7 @@ class ImageManager:
|
||||
|
||||
# 使用较低温度确保输出稳定
|
||||
emotion_llm = LLMRequest(model_set=model_config.model_task_config.utils, request_type="emoji")
|
||||
emotion_result, _ = await emotion_llm.generate_response_async(
|
||||
emotion_prompt, temperature=0.3, max_tokens=50
|
||||
)
|
||||
emotion_result, _ = await emotion_llm.generate_response_async(emotion_prompt, temperature=0.3)
|
||||
|
||||
if not emotion_result:
|
||||
logger.warning("LLM未能生成情感标签,使用详细描述的前几个词")
|
||||
@@ -220,11 +220,13 @@ class ImageManager:
|
||||
img_obj.save()
|
||||
except Images.DoesNotExist: # type: ignore
|
||||
Images.create(
|
||||
image_id=str(uuid.uuid4()),
|
||||
emoji_hash=image_hash,
|
||||
path=file_path,
|
||||
type="emoji",
|
||||
description=detailed_description, # 保存详细描述
|
||||
timestamp=current_timestamp,
|
||||
vlm_processed=True,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"保存表情包文件或元数据失败: {str(e)}")
|
||||
@@ -268,7 +270,7 @@ class ImageManager:
|
||||
|
||||
# 调用AI获取描述
|
||||
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
|
||||
prompt = global_config.custom_prompt.image_prompt
|
||||
prompt = global_config.personality.visual_style
|
||||
logger.info(f"[VLM调用] 为图片生成新描述 (Hash: {image_hash[:8]}...)")
|
||||
description, _ = await self.vlm.generate_response_for_image(
|
||||
prompt, image_base64, image_format, temperature=0.4, max_tokens=300
|
||||
@@ -564,7 +566,7 @@ class ImageManager:
|
||||
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
|
||||
|
||||
# 构建prompt
|
||||
prompt = global_config.custom_prompt.image_prompt
|
||||
prompt = global_config.personality.visual_style
|
||||
|
||||
# 获取VLM描述
|
||||
description, _ = await self.vlm.generate_response_for_image(
|
||||
@@ -621,3 +623,41 @@ def image_path_to_base64(image_path: str) -> str:
|
||||
return base64.b64encode(image_data).decode("utf-8")
|
||||
else:
|
||||
raise IOError(f"读取图片文件失败: {image_path}")
|
||||
|
||||
|
||||
def base64_to_image(image_base64: str, output_path: str) -> bool:
|
||||
"""将base64编码的图片保存为文件
|
||||
|
||||
Args:
|
||||
image_base64: 图片的base64编码
|
||||
output_path: 输出文件路径
|
||||
|
||||
Returns:
|
||||
bool: 是否成功保存
|
||||
|
||||
Raises:
|
||||
ValueError: 当base64编码无效时
|
||||
IOError: 当保存文件失败时
|
||||
"""
|
||||
try:
|
||||
# 确保base64字符串只包含ASCII字符
|
||||
if isinstance(image_base64, str):
|
||||
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
|
||||
|
||||
# 解码base64
|
||||
image_bytes = base64.b64decode(image_base64)
|
||||
|
||||
# 确保输出目录存在
|
||||
output_dir = os.path.dirname(output_path)
|
||||
if output_dir:
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
# 保存文件
|
||||
with open(output_path, "wb") as f:
|
||||
f.write(image_bytes)
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"保存base64图片失败: {e}")
|
||||
return False
|
||||
|
||||
@@ -6,7 +6,8 @@ class BaseDataModel:
|
||||
def deepcopy(self):
|
||||
return copy.deepcopy(self)
|
||||
|
||||
def temporarily_transform_class_to_dict(obj: Any) -> Any:
|
||||
|
||||
def transform_class_to_dict(obj: Any) -> Any:
|
||||
# sourcery skip: assign-if-exp, reintroduce-else
|
||||
"""
|
||||
将对象或容器中的 BaseDataModel 子类(类对象)或 BaseDataModel 实例
|
||||
|
||||
@@ -205,6 +205,7 @@ class DatabaseMessages(BaseDataModel):
|
||||
"chat_info_user_cardname": self.chat_info.user_info.user_cardname,
|
||||
}
|
||||
|
||||
|
||||
@dataclass(init=False)
|
||||
class DatabaseActionRecords(BaseDataModel):
|
||||
def __init__(
|
||||
@@ -219,6 +220,7 @@ class DatabaseActionRecords(BaseDataModel):
|
||||
chat_id: str,
|
||||
chat_info_stream_id: str,
|
||||
chat_info_platform: str,
|
||||
action_reasoning:str
|
||||
):
|
||||
self.action_id = action_id
|
||||
self.time = time
|
||||
@@ -232,4 +234,5 @@ class DatabaseActionRecords(BaseDataModel):
|
||||
self.action_prompt_display = action_prompt_display
|
||||
self.chat_id = chat_id
|
||||
self.chat_info_stream_id = chat_info_stream_id
|
||||
self.chat_info_platform = chat_info_platform
|
||||
self.chat_info_platform = chat_info_platform
|
||||
self.action_reasoning = action_reasoning
|
||||
@@ -23,3 +23,5 @@ class ActionPlannerInfo(BaseDataModel):
|
||||
action_data: Optional[Dict] = None
|
||||
action_message: Optional["DatabaseMessages"] = None
|
||||
available_actions: Optional[Dict[str, "ActionInfo"]] = None
|
||||
loop_start_time: Optional[float] = None
|
||||
action_reasoning: Optional[str] = None
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, List, Tuple, TYPE_CHECKING, Any
|
||||
from typing import Optional, List, TYPE_CHECKING
|
||||
|
||||
from . import BaseDataModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.common.data_models.message_data_model import ReplySetModel
|
||||
from src.llm_models.payload_content.tool_option import ToolCall
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMGenerationDataModel(BaseDataModel):
|
||||
content: Optional[str] = None
|
||||
@@ -13,4 +16,4 @@ class LLMGenerationDataModel(BaseDataModel):
|
||||
tool_calls: Optional[List["ToolCall"]] = None
|
||||
prompt: Optional[str] = None
|
||||
selected_expressions: Optional[List[int]] = None
|
||||
reply_set: Optional[List[Tuple[str, Any]]] = None
|
||||
reply_set: Optional["ReplySetModel"] = None
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
from typing import Optional, TYPE_CHECKING, List, Tuple, Union, Dict, Any
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
|
||||
from . import BaseDataModel
|
||||
|
||||
@@ -34,3 +35,172 @@ class MessageAndActionModel(BaseDataModel):
|
||||
display_message=message.display_message,
|
||||
chat_info_platform=message.chat_info.platform,
|
||||
)
|
||||
|
||||
|
||||
class ReplyContentType(Enum):
|
||||
TEXT = "text"
|
||||
IMAGE = "image"
|
||||
EMOJI = "emoji"
|
||||
COMMAND = "command"
|
||||
VOICE = "voice"
|
||||
FORWARD = "forward"
|
||||
HYBRID = "hybrid" # 混合类型,包含多种内容
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return self.value
|
||||
|
||||
|
||||
@dataclass
|
||||
class ForwardNode(BaseDataModel):
|
||||
user_id: Optional[str] = None
|
||||
user_nickname: Optional[str] = None
|
||||
content: Union[List["ReplyContent"], str] = field(default_factory=list)
|
||||
|
||||
@classmethod
|
||||
def construct_as_id_reference(cls, message_id: str) -> "ForwardNode":
|
||||
return cls(user_id="", user_nickname="", content=message_id)
|
||||
|
||||
@classmethod
|
||||
def construct_as_created_node(
|
||||
cls, user_id: str, user_nickname: str, content: List["ReplyContent"]
|
||||
) -> "ForwardNode":
|
||||
return cls(user_id=user_id, user_nickname=user_nickname, content=content)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ReplyContent(BaseDataModel):
|
||||
content_type: ReplyContentType | str
|
||||
content: Union[str, Dict, List[ForwardNode], List["ReplyContent"]] # 支持嵌套的 ReplyContent
|
||||
|
||||
@classmethod
|
||||
def construct_as_text(cls, text: str):
|
||||
return cls(content_type=ReplyContentType.TEXT, content=text)
|
||||
|
||||
@classmethod
|
||||
def construct_as_image(cls, image_base64: str):
|
||||
return cls(content_type=ReplyContentType.IMAGE, content=image_base64)
|
||||
|
||||
@classmethod
|
||||
def construct_as_voice(cls, voice_base64: str):
|
||||
return cls(content_type=ReplyContentType.VOICE, content=voice_base64)
|
||||
|
||||
@classmethod
|
||||
def construct_as_emoji(cls, emoji_str: str):
|
||||
return cls(content_type=ReplyContentType.EMOJI, content=emoji_str)
|
||||
|
||||
@classmethod
|
||||
def construct_as_command(cls, command_arg: Dict):
|
||||
return cls(content_type=ReplyContentType.COMMAND, content=command_arg)
|
||||
|
||||
@classmethod
|
||||
def construct_as_hybrid(cls, hybrid_content: List[Tuple[ReplyContentType | str, str]]):
|
||||
hybrid_content_list: List[ReplyContent] = []
|
||||
for content_type, content in hybrid_content:
|
||||
assert content_type not in [
|
||||
ReplyContentType.HYBRID,
|
||||
ReplyContentType.FORWARD,
|
||||
ReplyContentType.VOICE,
|
||||
ReplyContentType.COMMAND,
|
||||
], "混合内容的每个项不能是混合、转发、语音或命令类型"
|
||||
assert isinstance(content, str), "混合内容的每个项必须是字符串"
|
||||
hybrid_content_list.append(ReplyContent(content_type=content_type, content=content))
|
||||
return cls(content_type=ReplyContentType.HYBRID, content=hybrid_content_list)
|
||||
|
||||
@classmethod
|
||||
def construct_as_forward(cls, forward_nodes: List[ForwardNode]):
|
||||
return cls(content_type=ReplyContentType.FORWARD, content=forward_nodes)
|
||||
|
||||
def __post_init__(self):
|
||||
if isinstance(self.content_type, ReplyContentType):
|
||||
if self.content_type not in [ReplyContentType.HYBRID, ReplyContentType.FORWARD] and isinstance(
|
||||
self.content, List
|
||||
):
|
||||
raise ValueError(
|
||||
f"非混合类型/转发类型的内容不能是列表,content_type: {self.content_type}, content: {self.content}"
|
||||
)
|
||||
elif self.content_type in [ReplyContentType.HYBRID, ReplyContentType.FORWARD]:
|
||||
if not isinstance(self.content, List):
|
||||
raise ValueError(
|
||||
f"混合类型/转发类型的内容必须是列表,content_type: {self.content_type}, content: {self.content}"
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ReplySetModel(BaseDataModel):
|
||||
"""
|
||||
回复集数据模型,用于多种回复类型的返回
|
||||
"""
|
||||
|
||||
reply_data: List[ReplyContent] = field(default_factory=list)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.reply_data)
|
||||
|
||||
def add_text_content(self, text: str):
|
||||
"""
|
||||
添加文本内容
|
||||
Args:
|
||||
text: 文本内容
|
||||
"""
|
||||
self.reply_data.append(ReplyContent(content_type=ReplyContentType.TEXT, content=text))
|
||||
|
||||
def add_image_content(self, image_base64: str):
|
||||
"""
|
||||
添加图片内容,base64编码的图片数据
|
||||
Args:
|
||||
image_base64: base64编码的图片数据
|
||||
"""
|
||||
self.reply_data.append(ReplyContent(content_type=ReplyContentType.IMAGE, content=image_base64))
|
||||
|
||||
def add_voice_content(self, voice_base64: str):
|
||||
"""
|
||||
添加语音内容,base64编码的音频数据
|
||||
Args:
|
||||
voice_base64: base64编码的音频数据
|
||||
"""
|
||||
self.reply_data.append(ReplyContent(content_type=ReplyContentType.VOICE, content=voice_base64))
|
||||
|
||||
def add_hybrid_content_by_raw(self, hybrid_content: List[Tuple[ReplyContentType | str, str]]):
|
||||
"""
|
||||
添加混合型内容,可以包含text, image, emoji的任意组合
|
||||
Args:
|
||||
hybrid_content: 元组 (类型, 消息内容) 构成的列表,如[(ReplyContentType.TEXT, "Hello"), (ReplyContentType.IMAGE, "<base64")]
|
||||
"""
|
||||
hybrid_content_list: List[ReplyContent] = []
|
||||
for content_type, content in hybrid_content:
|
||||
assert content_type not in [
|
||||
ReplyContentType.HYBRID,
|
||||
ReplyContentType.FORWARD,
|
||||
ReplyContentType.VOICE,
|
||||
ReplyContentType.COMMAND,
|
||||
], "混合内容的每个项不能是混合、转发、语音或命令类型"
|
||||
assert isinstance(content, str), "混合内容的每个项必须是字符串"
|
||||
hybrid_content_list.append(ReplyContent(content_type=content_type, content=content))
|
||||
|
||||
self.reply_data.append(ReplyContent(content_type=ReplyContentType.HYBRID, content=hybrid_content_list))
|
||||
|
||||
def add_hybrid_content(self, hybrid_content: List[ReplyContent]):
|
||||
"""
|
||||
添加混合型内容,使用已经构造好的 ReplyContent 列表
|
||||
Args:
|
||||
hybrid_content: ReplyContent 构成的列表,如[ReplyContent(ReplyContentType.TEXT, "Hello"), ReplyContent(ReplyContentType.IMAGE, "<base64")]
|
||||
"""
|
||||
for content in hybrid_content:
|
||||
assert content.content_type not in [
|
||||
ReplyContentType.HYBRID,
|
||||
ReplyContentType.FORWARD,
|
||||
ReplyContentType.VOICE,
|
||||
ReplyContentType.COMMAND,
|
||||
], "混合内容的每个项不能是混合、转发、语音或命令类型"
|
||||
assert isinstance(content.content, str), "混合内容的每个项必须是字符串"
|
||||
|
||||
self.reply_data.append(ReplyContent(content_type=ReplyContentType.HYBRID, content=hybrid_content))
|
||||
|
||||
def add_custom_content(self, content_type: str, content: Any):
|
||||
"""
|
||||
添加自定义类型的内容"""
|
||||
self.reply_data.append(ReplyContent(content_type=content_type, content=content))
|
||||
|
||||
def add_forward_content(self, forward_content: List[ForwardNode]):
|
||||
"""添加转发内容,可以是字符串或ReplyContent,嵌套的转发内容需要自己构造放入"""
|
||||
self.reply_data.append(ReplyContent(content_type=ReplyContentType.FORWARD, content=forward_content))
|
||||
|
||||
36
src/common/data_models/message_data_model_ref.md
Normal file
36
src/common/data_models/message_data_model_ref.md
Normal file
@@ -0,0 +1,36 @@
|
||||
# 对于`message_data_model.py`中`class ReplyContent`的规划解读
|
||||
|
||||
分类讨论如下:
|
||||
- `ReplyContent.TEXT`: 单独的文本,`_level = 0`,`content`为`str`类型。
|
||||
- `ReplyContent.IMAGE`: 单独的图片,`_level = 0`,`content`为`str`类型(图片base64)。
|
||||
- `ReplyContent.EMOJI`: 单独的表情包,`_level = 0`,`content`为`str`类型(图片base64)。
|
||||
- `ReplyContent.VOICE`: 单独的语音,`_level = 0`,`content`为`str`类型(语音base64)。
|
||||
- `ReplyContent.HYBRID`: 混合内容,`_level = 0`
|
||||
- 其应该是一个列表,列表内应该只接受`str`类型的内容(图片和文本混合体)
|
||||
- `ReplyContent.FORWARD`: 转发消息,`_level = n`
|
||||
- 其应该是一个列表,列表接受`str`类型(图片/文本),`ReplyContent`类型(嵌套转发,嵌套有最高层数限制)
|
||||
- `ReplyContent.COMMAND`: 指令消息,`_level = 0`
|
||||
- 其应该是一个列表,列表内应该只接受`Dict`类型的内容
|
||||
|
||||
未来规划:
|
||||
- `ReplyContent.AT`: 单独的艾特,`_level = 0`,`content`为`str`类型(用户ID)。
|
||||
|
||||
内容构造方式:
|
||||
- 对于`TEXT`, `IMAGE`, `EMOJI`, `VOICE`,直接传入对应类型的内容,且`content`应该为`str`。
|
||||
- 对于`COMMAND`,传入一个字典,字典内的内容类型应符合上述规定。
|
||||
- 对于`HYBRID`, `FORWARD`,传入一个列表,列表内的内容类型应符合上述规定。
|
||||
|
||||
因此,我们的类型注解应该是:
|
||||
```python
|
||||
from typing import Union, List, Dict
|
||||
|
||||
ReplyContentType = Union[
|
||||
str, # TEXT, IMAGE, EMOJI, VOICE
|
||||
List[Union[str, 'ReplyContent']], # HYBRID, FORWARD
|
||||
Dict # COMMAND
|
||||
]
|
||||
```
|
||||
|
||||
现在`_level`被移除了,在解析的时候显式地检查内容的类型和结构即可。
|
||||
|
||||
`send_api`的custom_reply_set_to_stream仅在特定的类型下提供reply)message
|
||||
57
src/common/data_models/reply_set_doc.md
Normal file
57
src/common/data_models/reply_set_doc.md
Normal file
@@ -0,0 +1,57 @@
|
||||
# 有关转发消息和其他消息的构建类型说明
|
||||
```mermaid
|
||||
graph LR;
|
||||
direction TB;
|
||||
A[ReplySet] --- B[ReplyContent];
|
||||
A --- C["ReplyContent"];
|
||||
A --- K["ReplyContent"];
|
||||
A --- L["ReplyContent"];
|
||||
A --- N["ReplyContent"];
|
||||
A --- D[...];
|
||||
B --- E["Text (in str)"];
|
||||
B --- F["Image (in base64)"];
|
||||
C --- G["Voice (in base64)"];
|
||||
B --- I["Emoji (in base64)"];
|
||||
subgraph "可行内容(以下的任意组合)";
|
||||
subgraph "转发消息(Forward)"
|
||||
M["List[ForwardNode]"]
|
||||
end
|
||||
subgraph "混合消息(Hybrid)"
|
||||
J["List[ReplyContent] (要求只能包含普通消息)"]
|
||||
end
|
||||
subgraph "命令消息(Command)"
|
||||
H["Command (in Dict)"]
|
||||
end
|
||||
subgraph "语音消息"
|
||||
G
|
||||
end
|
||||
subgraph "普通消息"
|
||||
E
|
||||
F
|
||||
I
|
||||
end
|
||||
end
|
||||
N --- H
|
||||
K --- J
|
||||
L --- M
|
||||
subgraph ForwardNodes
|
||||
O["ForwardNode"]
|
||||
P["ForwardNode"]
|
||||
Q["ForwardNode"]
|
||||
end
|
||||
M --- O
|
||||
M --- P
|
||||
M --- Q
|
||||
subgraph "内容 (message_id引用法)"
|
||||
P --- U["content: str, 引用已有消息的有效ID"];
|
||||
end
|
||||
subgraph "内容 (生成法)"
|
||||
O --- R["user_id: str"];
|
||||
O --- S["user_nickname: str"];
|
||||
O --- T["content: List[ReplyContent], 为这个转发节点的消息内容"];
|
||||
end
|
||||
```
|
||||
|
||||
另外,自定义消息类型我们在这里不做讨论。
|
||||
|
||||
以上列出了所有可能的ReplySet构建方式,下面我们来解释一下各个类型的含义。
|
||||
@@ -1,64 +1,9 @@
|
||||
import os
|
||||
from pymongo import MongoClient
|
||||
from peewee import SqliteDatabase
|
||||
from pymongo.database import Database
|
||||
from rich.traceback import install
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
_client = None
|
||||
_db = None
|
||||
|
||||
|
||||
def __create_database_instance():
|
||||
uri = os.getenv("MONGODB_URI")
|
||||
host = os.getenv("MONGODB_HOST", "127.0.0.1")
|
||||
port = int(os.getenv("MONGODB_PORT", "27017"))
|
||||
# db_name 变量在创建连接时不需要,在获取数据库实例时才使用
|
||||
username = os.getenv("MONGODB_USERNAME")
|
||||
password = os.getenv("MONGODB_PASSWORD")
|
||||
auth_source = os.getenv("MONGODB_AUTH_SOURCE")
|
||||
|
||||
if uri:
|
||||
# 支持标准mongodb://和mongodb+srv://连接字符串
|
||||
if uri.startswith(("mongodb://", "mongodb+srv://")):
|
||||
return MongoClient(uri)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Invalid MongoDB URI format. URI must start with 'mongodb://' or 'mongodb+srv://'. "
|
||||
"For MongoDB Atlas, use 'mongodb+srv://' format. "
|
||||
"See: https://www.mongodb.com/docs/manual/reference/connection-string/"
|
||||
)
|
||||
|
||||
if username and password:
|
||||
# 如果有用户名和密码,使用认证连接
|
||||
return MongoClient(host, port, username=username, password=password, authSource=auth_source)
|
||||
|
||||
# 否则使用无认证连接
|
||||
return MongoClient(host, port)
|
||||
|
||||
|
||||
def get_db():
|
||||
"""获取数据库连接实例,延迟初始化。"""
|
||||
global _client, _db
|
||||
if _client is None:
|
||||
_client = __create_database_instance()
|
||||
_db = _client[os.getenv("DATABASE_NAME", "MegBot")]
|
||||
return _db
|
||||
|
||||
|
||||
class DBWrapper:
|
||||
"""数据库代理类,保持接口兼容性同时实现懒加载。"""
|
||||
|
||||
def __getattr__(self, name):
|
||||
return getattr(get_db(), name)
|
||||
|
||||
def __getitem__(self, key):
|
||||
return get_db()[key] # type: ignore
|
||||
|
||||
|
||||
# 全局数据库访问点
|
||||
memory_db: Database = DBWrapper() # type: ignore
|
||||
|
||||
# 定义数据库文件路径
|
||||
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
|
||||
|
||||
@@ -135,7 +135,7 @@ class Messages(BaseModel):
|
||||
interest_value = DoubleField(null=True)
|
||||
key_words = TextField(null=True)
|
||||
key_words_lite = TextField(null=True)
|
||||
|
||||
|
||||
is_mentioned = BooleanField(null=True)
|
||||
is_at = BooleanField(null=True)
|
||||
reply_probability_boost = DoubleField(null=True)
|
||||
@@ -169,7 +169,7 @@ class Messages(BaseModel):
|
||||
is_picid = BooleanField(default=False)
|
||||
is_command = BooleanField(default=False)
|
||||
is_notify = BooleanField(default=False)
|
||||
|
||||
|
||||
selected_expressions = TextField(null=True)
|
||||
|
||||
class Meta:
|
||||
@@ -185,6 +185,8 @@ class ActionRecords(BaseModel):
|
||||
action_id = TextField(index=True) # 消息 ID (更改自 IntegerField)
|
||||
time = DoubleField() # 消息时间戳
|
||||
|
||||
action_reasoning = TextField(null=True)
|
||||
|
||||
action_name = TextField()
|
||||
action_data = TextField()
|
||||
action_done = BooleanField(default=False)
|
||||
@@ -267,12 +269,6 @@ class PersonInfo(BaseModel):
|
||||
know_times = FloatField(null=True) # 认识时间 (时间戳)
|
||||
know_since = FloatField(null=True) # 首次印象总结时间
|
||||
last_know = FloatField(null=True) # 最后一次印象总结时间
|
||||
|
||||
attitude_to_me = TextField(null=True) # 对bot的态度
|
||||
attitude_to_me_confidence = FloatField(null=True) # 对bot的态度置信度
|
||||
|
||||
|
||||
|
||||
|
||||
class Meta:
|
||||
# database = db # 继承自 BaseModel
|
||||
@@ -299,6 +295,7 @@ class GroupInfo(BaseModel):
|
||||
# database = db # 继承自 BaseModel
|
||||
table_name = "group_info"
|
||||
|
||||
|
||||
class Expression(BaseModel):
|
||||
"""
|
||||
用于存储表达风格的模型。
|
||||
@@ -307,6 +304,11 @@ class Expression(BaseModel):
|
||||
situation = TextField()
|
||||
style = TextField()
|
||||
count = FloatField()
|
||||
|
||||
# new mode fields
|
||||
context = TextField(null=True)
|
||||
context_words = TextField(null=True)
|
||||
|
||||
last_active_time = FloatField()
|
||||
chat_id = TextField(index=True)
|
||||
type = TextField()
|
||||
@@ -315,36 +317,35 @@ class Expression(BaseModel):
|
||||
class Meta:
|
||||
table_name = "expression"
|
||||
|
||||
class GraphNodes(BaseModel):
|
||||
class MemoryChest(BaseModel):
|
||||
"""
|
||||
用于存储记忆图节点的模型
|
||||
用于存储记忆仓库的模型
|
||||
"""
|
||||
|
||||
concept = TextField(unique=True, index=True) # 节点概念
|
||||
memory_items = TextField() # JSON格式存储的记忆列表
|
||||
weight = FloatField(default=0.0) # 节点权重
|
||||
hash = TextField() # 节点哈希值
|
||||
created_time = FloatField() # 创建时间戳
|
||||
last_modified = FloatField() # 最后修改时间戳
|
||||
title = TextField() # 标题
|
||||
content = TextField() # 内容
|
||||
chat_id = TextField(null=True) # 聊天ID
|
||||
locked = BooleanField(default=False) # 是否锁定
|
||||
|
||||
class Meta:
|
||||
table_name = "graph_nodes"
|
||||
table_name = "memory_chest"
|
||||
|
||||
|
||||
class GraphEdges(BaseModel):
|
||||
class MemoryConflict(BaseModel):
|
||||
"""
|
||||
用于存储记忆图边的模型
|
||||
用于存储记忆整合过程中冲突内容的模型
|
||||
"""
|
||||
|
||||
source = TextField(index=True) # 源节点
|
||||
target = TextField(index=True) # 目标节点
|
||||
strength = IntegerField() # 连接强度
|
||||
hash = TextField() # 边哈希值
|
||||
created_time = FloatField() # 创建时间戳
|
||||
last_modified = FloatField() # 最后修改时间戳
|
||||
conflict_content = TextField() # 冲突内容
|
||||
answer = TextField(null=True) # 回答内容
|
||||
create_time = FloatField() # 创建时间
|
||||
update_time = FloatField() # 更新时间
|
||||
context = TextField(null=True) # 上下文
|
||||
chat_id = TextField(null=True) # 聊天ID
|
||||
raise_time = FloatField(null=True) # 触发次数
|
||||
|
||||
class Meta:
|
||||
table_name = "graph_edges"
|
||||
table_name = "memory_conflicts"
|
||||
|
||||
|
||||
|
||||
def create_tables():
|
||||
@@ -363,9 +364,9 @@ def create_tables():
|
||||
OnlineTime,
|
||||
PersonInfo,
|
||||
Expression,
|
||||
GraphNodes, # 添加图节点表
|
||||
GraphEdges, # 添加图边表
|
||||
ActionRecords, # 添加 ActionRecords 到初始化列表
|
||||
MemoryChest,
|
||||
MemoryConflict, # 添加记忆冲突表
|
||||
]
|
||||
)
|
||||
|
||||
@@ -374,7 +375,7 @@ def initialize_database(sync_constraints=False):
|
||||
"""
|
||||
检查所有定义的表是否存在,如果不存在则创建它们。
|
||||
检查所有表的所有字段是否存在,如果缺失则自动添加。
|
||||
|
||||
|
||||
Args:
|
||||
sync_constraints (bool): 是否同步字段约束。默认为 False。
|
||||
如果为 True,会检查并修复字段的 NULL 约束不一致问题。
|
||||
@@ -390,9 +391,9 @@ def initialize_database(sync_constraints=False):
|
||||
OnlineTime,
|
||||
PersonInfo,
|
||||
Expression,
|
||||
GraphNodes,
|
||||
GraphEdges,
|
||||
ActionRecords, # 添加 ActionRecords 到初始化列表
|
||||
MemoryChest,
|
||||
MemoryConflict,
|
||||
]
|
||||
|
||||
try:
|
||||
@@ -456,13 +457,13 @@ def initialize_database(sync_constraints=False):
|
||||
logger.info(f"字段 '{field_name}' 删除成功")
|
||||
except Exception as e:
|
||||
logger.error(f"删除字段 '{field_name}' 失败: {e}")
|
||||
|
||||
|
||||
# 如果启用了约束同步,执行约束检查和修复
|
||||
if sync_constraints:
|
||||
logger.debug("开始同步数据库字段约束...")
|
||||
sync_field_constraints()
|
||||
logger.debug("数据库字段约束同步完成")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"检查表或字段是否存在时出错: {e}")
|
||||
# 如果检查失败(例如数据库不可用),则退出
|
||||
@@ -476,7 +477,7 @@ def sync_field_constraints():
|
||||
同步数据库字段约束,确保现有数据库字段的 NULL 约束与模型定义一致。
|
||||
如果发现不一致,会自动修复字段约束。
|
||||
"""
|
||||
|
||||
|
||||
models = [
|
||||
ChatStreams,
|
||||
LLMUsage,
|
||||
@@ -487,9 +488,9 @@ def sync_field_constraints():
|
||||
OnlineTime,
|
||||
PersonInfo,
|
||||
Expression,
|
||||
GraphNodes,
|
||||
GraphEdges,
|
||||
ActionRecords,
|
||||
MemoryChest,
|
||||
MemoryConflict,
|
||||
]
|
||||
|
||||
try:
|
||||
@@ -501,50 +502,55 @@ def sync_field_constraints():
|
||||
continue
|
||||
|
||||
logger.debug(f"检查表 '{table_name}' 的字段约束...")
|
||||
|
||||
|
||||
# 获取当前表结构信息
|
||||
cursor = db.execute_sql(f"PRAGMA table_info('{table_name}')")
|
||||
current_schema = {row[1]: {'type': row[2], 'notnull': bool(row[3]), 'default': row[4]}
|
||||
for row in cursor.fetchall()}
|
||||
|
||||
current_schema = {
|
||||
row[1]: {"type": row[2], "notnull": bool(row[3]), "default": row[4]} for row in cursor.fetchall()
|
||||
}
|
||||
|
||||
# 检查每个模型字段的约束
|
||||
constraints_to_fix = []
|
||||
for field_name, field_obj in model._meta.fields.items():
|
||||
if field_name not in current_schema:
|
||||
continue # 字段不存在,跳过
|
||||
|
||||
current_notnull = current_schema[field_name]['notnull']
|
||||
|
||||
current_notnull = current_schema[field_name]["notnull"]
|
||||
model_allows_null = field_obj.null
|
||||
|
||||
|
||||
# 如果模型允许 null 但数据库字段不允许 null,需要修复
|
||||
if model_allows_null and current_notnull:
|
||||
constraints_to_fix.append({
|
||||
'field_name': field_name,
|
||||
'field_obj': field_obj,
|
||||
'action': 'allow_null',
|
||||
'current_constraint': 'NOT NULL',
|
||||
'target_constraint': 'NULL'
|
||||
})
|
||||
constraints_to_fix.append(
|
||||
{
|
||||
"field_name": field_name,
|
||||
"field_obj": field_obj,
|
||||
"action": "allow_null",
|
||||
"current_constraint": "NOT NULL",
|
||||
"target_constraint": "NULL",
|
||||
}
|
||||
)
|
||||
logger.warning(f"字段 '{field_name}' 约束不一致: 模型允许NULL,但数据库为NOT NULL")
|
||||
|
||||
|
||||
# 如果模型不允许 null 但数据库字段允许 null,也需要修复(但要小心)
|
||||
elif not model_allows_null and not current_notnull:
|
||||
constraints_to_fix.append({
|
||||
'field_name': field_name,
|
||||
'field_obj': field_obj,
|
||||
'action': 'disallow_null',
|
||||
'current_constraint': 'NULL',
|
||||
'target_constraint': 'NOT NULL'
|
||||
})
|
||||
constraints_to_fix.append(
|
||||
{
|
||||
"field_name": field_name,
|
||||
"field_obj": field_obj,
|
||||
"action": "disallow_null",
|
||||
"current_constraint": "NULL",
|
||||
"target_constraint": "NOT NULL",
|
||||
}
|
||||
)
|
||||
logger.warning(f"字段 '{field_name}' 约束不一致: 模型不允许NULL,但数据库允许NULL")
|
||||
|
||||
|
||||
# 修复约束不一致的字段
|
||||
if constraints_to_fix:
|
||||
logger.info(f"表 '{table_name}' 需要修复 {len(constraints_to_fix)} 个字段约束")
|
||||
_fix_table_constraints(table_name, model, constraints_to_fix)
|
||||
else:
|
||||
logger.debug(f"表 '{table_name}' 的字段约束已同步")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"同步字段约束时出错: {e}")
|
||||
|
||||
@@ -557,40 +563,39 @@ def _fix_table_constraints(table_name, model, constraints_to_fix):
|
||||
try:
|
||||
# 备份表名
|
||||
backup_table = f"{table_name}_backup_{int(datetime.datetime.now().timestamp())}"
|
||||
|
||||
|
||||
logger.info(f"开始修复表 '{table_name}' 的字段约束...")
|
||||
|
||||
|
||||
# 1. 创建备份表
|
||||
db.execute_sql(f"CREATE TABLE {backup_table} AS SELECT * FROM {table_name}")
|
||||
logger.info(f"已创建备份表 '{backup_table}'")
|
||||
|
||||
|
||||
# 2. 删除原表
|
||||
db.execute_sql(f"DROP TABLE {table_name}")
|
||||
logger.info(f"已删除原表 '{table_name}'")
|
||||
|
||||
|
||||
# 3. 重新创建表(使用当前模型定义)
|
||||
db.create_tables([model])
|
||||
logger.info(f"已重新创建表 '{table_name}' 使用新的约束")
|
||||
|
||||
|
||||
# 4. 从备份表恢复数据
|
||||
# 获取字段列表
|
||||
fields = list(model._meta.fields.keys())
|
||||
fields_str = ', '.join(fields)
|
||||
|
||||
fields_str = ", ".join(fields)
|
||||
|
||||
# 对于需要从 NOT NULL 改为 NULL 的字段,直接复制数据
|
||||
# 对于需要从 NULL 改为 NOT NULL 的字段,需要处理 NULL 值
|
||||
insert_sql = f"INSERT INTO {table_name} ({fields_str}) SELECT {fields_str} FROM {backup_table}"
|
||||
|
||||
|
||||
# 检查是否有字段需要从 NULL 改为 NOT NULL
|
||||
null_to_notnull_fields = [
|
||||
constraint['field_name'] for constraint in constraints_to_fix
|
||||
if constraint['action'] == 'disallow_null'
|
||||
constraint["field_name"] for constraint in constraints_to_fix if constraint["action"] == "disallow_null"
|
||||
]
|
||||
|
||||
|
||||
if null_to_notnull_fields:
|
||||
# 需要处理 NULL 值,为这些字段设置默认值
|
||||
logger.warning(f"字段 {null_to_notnull_fields} 将从允许NULL改为不允许NULL,需要处理现有的NULL值")
|
||||
|
||||
|
||||
# 构建更复杂的 SELECT 语句来处理 NULL 值
|
||||
select_fields = []
|
||||
for field_name in fields:
|
||||
@@ -607,21 +612,21 @@ def _fix_table_constraints(table_name, model, constraints_to_fix):
|
||||
default_value = f"'{datetime.datetime.now()}'"
|
||||
else:
|
||||
default_value = "''"
|
||||
|
||||
|
||||
select_fields.append(f"COALESCE({field_name}, {default_value}) as {field_name}")
|
||||
else:
|
||||
select_fields.append(field_name)
|
||||
|
||||
select_str = ', '.join(select_fields)
|
||||
|
||||
select_str = ", ".join(select_fields)
|
||||
insert_sql = f"INSERT INTO {table_name} ({fields_str}) SELECT {select_str} FROM {backup_table}"
|
||||
|
||||
|
||||
db.execute_sql(insert_sql)
|
||||
logger.info(f"已从备份表恢复数据到 '{table_name}'")
|
||||
|
||||
|
||||
# 5. 验证数据完整性
|
||||
original_count = db.execute_sql(f"SELECT COUNT(*) FROM {backup_table}").fetchone()[0]
|
||||
new_count = db.execute_sql(f"SELECT COUNT(*) FROM {table_name}").fetchone()[0]
|
||||
|
||||
|
||||
if original_count == new_count:
|
||||
logger.info(f"数据完整性验证通过: {original_count} 行数据")
|
||||
# 删除备份表
|
||||
@@ -630,12 +635,14 @@ def _fix_table_constraints(table_name, model, constraints_to_fix):
|
||||
else:
|
||||
logger.error(f"数据完整性验证失败: 原始 {original_count} 行,新表 {new_count} 行")
|
||||
logger.error(f"备份表 '{backup_table}' 已保留,请手动检查")
|
||||
|
||||
|
||||
# 记录修复的约束
|
||||
for constraint in constraints_to_fix:
|
||||
logger.info(f"已修复字段 '{constraint['field_name']}': "
|
||||
f"{constraint['current_constraint']} -> {constraint['target_constraint']}")
|
||||
|
||||
logger.info(
|
||||
f"已修复字段 '{constraint['field_name']}': "
|
||||
f"{constraint['current_constraint']} -> {constraint['target_constraint']}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"修复表 '{table_name}' 约束时出错: {e}")
|
||||
# 尝试恢复
|
||||
@@ -654,7 +661,7 @@ def check_field_constraints():
|
||||
检查但不修复字段约束,返回不一致的字段信息。
|
||||
用于在修复前预览需要修复的内容。
|
||||
"""
|
||||
|
||||
|
||||
models = [
|
||||
ChatStreams,
|
||||
LLMUsage,
|
||||
@@ -665,13 +672,13 @@ def check_field_constraints():
|
||||
OnlineTime,
|
||||
PersonInfo,
|
||||
Expression,
|
||||
GraphNodes,
|
||||
GraphEdges,
|
||||
ActionRecords,
|
||||
MemoryChest,
|
||||
MemoryConflict,
|
||||
]
|
||||
|
||||
|
||||
inconsistencies = {}
|
||||
|
||||
|
||||
try:
|
||||
with db:
|
||||
for model in models:
|
||||
@@ -681,49 +688,67 @@ def check_field_constraints():
|
||||
|
||||
# 获取当前表结构信息
|
||||
cursor = db.execute_sql(f"PRAGMA table_info('{table_name}')")
|
||||
current_schema = {row[1]: {'type': row[2], 'notnull': bool(row[3]), 'default': row[4]}
|
||||
for row in cursor.fetchall()}
|
||||
|
||||
current_schema = {
|
||||
row[1]: {"type": row[2], "notnull": bool(row[3]), "default": row[4]} for row in cursor.fetchall()
|
||||
}
|
||||
|
||||
table_inconsistencies = []
|
||||
|
||||
|
||||
# 检查每个模型字段的约束
|
||||
for field_name, field_obj in model._meta.fields.items():
|
||||
if field_name not in current_schema:
|
||||
continue
|
||||
|
||||
current_notnull = current_schema[field_name]['notnull']
|
||||
|
||||
current_notnull = current_schema[field_name]["notnull"]
|
||||
model_allows_null = field_obj.null
|
||||
|
||||
|
||||
if model_allows_null and current_notnull:
|
||||
table_inconsistencies.append({
|
||||
'field_name': field_name,
|
||||
'issue': 'model_allows_null_but_db_not_null',
|
||||
'model_constraint': 'NULL',
|
||||
'db_constraint': 'NOT NULL',
|
||||
'recommended_action': 'allow_null'
|
||||
})
|
||||
table_inconsistencies.append(
|
||||
{
|
||||
"field_name": field_name,
|
||||
"issue": "model_allows_null_but_db_not_null",
|
||||
"model_constraint": "NULL",
|
||||
"db_constraint": "NOT NULL",
|
||||
"recommended_action": "allow_null",
|
||||
}
|
||||
)
|
||||
elif not model_allows_null and not current_notnull:
|
||||
table_inconsistencies.append({
|
||||
'field_name': field_name,
|
||||
'issue': 'model_not_null_but_db_allows_null',
|
||||
'model_constraint': 'NOT NULL',
|
||||
'db_constraint': 'NULL',
|
||||
'recommended_action': 'disallow_null'
|
||||
})
|
||||
|
||||
table_inconsistencies.append(
|
||||
{
|
||||
"field_name": field_name,
|
||||
"issue": "model_not_null_but_db_allows_null",
|
||||
"model_constraint": "NOT NULL",
|
||||
"db_constraint": "NULL",
|
||||
"recommended_action": "disallow_null",
|
||||
}
|
||||
)
|
||||
|
||||
if table_inconsistencies:
|
||||
inconsistencies[table_name] = table_inconsistencies
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"检查字段约束时出错: {e}")
|
||||
|
||||
|
||||
return inconsistencies
|
||||
|
||||
|
||||
def fix_image_id():
|
||||
"""
|
||||
修复表情包的 image_id 字段
|
||||
"""
|
||||
import uuid
|
||||
|
||||
try:
|
||||
with db:
|
||||
for img in Images.select():
|
||||
if not img.image_id:
|
||||
img.image_id = str(uuid.uuid4())
|
||||
img.save()
|
||||
logger.info(f"已为表情包 {img.id} 生成新的 image_id: {img.image_id}")
|
||||
except Exception as e:
|
||||
logger.exception(f"修复 image_id 时出错: {e}")
|
||||
|
||||
|
||||
# 模块加载时调用初始化函数
|
||||
initialize_database(sync_constraints=True)
|
||||
|
||||
|
||||
|
||||
|
||||
fix_image_id()
|
||||
|
||||
@@ -339,24 +339,18 @@ MODULE_COLORS = {
|
||||
# 67 :具体的颜色编号(0-255),这里是较暗的蓝色
|
||||
"sender": "\033[38;5;24m", # 67号色,较暗的蓝色,适合不显眼的日志
|
||||
"send_api": "\033[38;5;24m", # 208号色,橙色,适合突出显示
|
||||
|
||||
# 生成
|
||||
"replyer": "\033[38;5;208m", # 橙色
|
||||
"llm_api": "\033[38;5;208m", # 橙色
|
||||
|
||||
# 消息处理
|
||||
"chat": "\033[38;5;82m", # 亮蓝色
|
||||
"chat_image": "\033[38;5;68m", # 浅蓝色
|
||||
|
||||
#emoji
|
||||
# emoji
|
||||
"emoji": "\033[38;5;214m", # 橙黄色,偏向橙色
|
||||
"emoji_api": "\033[38;5;214m", # 橙黄色,偏向橙色
|
||||
|
||||
# 核心模块
|
||||
"main": "\033[1;97m", # 亮白色+粗体 (主程序)
|
||||
|
||||
"memory": "\033[38;5;34m", # 天蓝色
|
||||
|
||||
"config": "\033[93m", # 亮黄色
|
||||
"common": "\033[95m", # 亮紫色
|
||||
"tools": "\033[96m", # 亮青色
|
||||
@@ -367,24 +361,17 @@ MODULE_COLORS = {
|
||||
"llm_models": "\033[36m", # 青色
|
||||
"remote": "\033[38;5;242m", # 深灰色,更不显眼
|
||||
"planner": "\033[36m",
|
||||
|
||||
|
||||
|
||||
"relation": "\033[38;5;139m", # 柔和的紫色,不刺眼
|
||||
# 聊天相关模块
|
||||
"normal_chat": "\033[38;5;81m", # 亮蓝绿色
|
||||
"heartflow": "\033[38;5;175m", # 柔和的粉色,不显眼但保持粉色系
|
||||
"hfc": "\033[38;5;175m", # 柔和的粉色,不显眼但保持粉色系
|
||||
"bc": "\033[38;5;175m", # 柔和的粉色,不显眼但保持粉色系
|
||||
"sub_heartflow": "\033[38;5;207m", # 粉紫色
|
||||
"subheartflow_manager": "\033[38;5;201m", # 深粉色
|
||||
"background_tasks": "\033[38;5;240m", # 灰色
|
||||
"chat_message": "\033[38;5;45m", # 青色
|
||||
"chat_stream": "\033[38;5;51m", # 亮青色
|
||||
|
||||
"message_storage": "\033[38;5;33m", # 深蓝色
|
||||
"expressor": "\033[38;5;166m", # 橙色
|
||||
# 专注聊天模块
|
||||
|
||||
"memory_activator": "\033[38;5;117m", # 天蓝色
|
||||
# 插件系统
|
||||
"plugins": "\033[31m", # 红色
|
||||
"plugin_api": "\033[33m", # 黄色
|
||||
@@ -412,7 +399,6 @@ MODULE_COLORS = {
|
||||
# 工具和实用模块
|
||||
"prompt_build": "\033[38;5;105m", # 紫色
|
||||
"chat_utils": "\033[38;5;111m", # 蓝色
|
||||
|
||||
"maibot_statistic": "\033[38;5;129m", # 紫色
|
||||
# 特殊功能插件
|
||||
"mute_plugin": "\033[38;5;240m", # 灰色
|
||||
@@ -420,7 +406,7 @@ MODULE_COLORS = {
|
||||
"tts_action": "\033[38;5;58m", # 深黄色
|
||||
"doubao_pic_plugin": "\033[38;5;64m", # 深绿色
|
||||
# Action组件
|
||||
"no_action_action": "\033[38;5;214m", # 亮橙色,显眼但不像警告
|
||||
"no_reply_action": "\033[38;5;214m", # 亮橙色,显眼但不像警告
|
||||
"reply_action": "\033[38;5;46m", # 亮绿色
|
||||
"base_action": "\033[38;5;250m", # 浅灰色
|
||||
# 数据库和消息
|
||||
@@ -433,9 +419,7 @@ MODULE_COLORS = {
|
||||
"model_utils": "\033[38;5;164m", # 紫红色
|
||||
"relationship_fetcher": "\033[38;5;170m", # 浅紫色
|
||||
"relationship_builder": "\033[38;5;93m", # 浅蓝色
|
||||
# s4u
|
||||
"context_web_api": "\033[38;5;240m", # 深灰色
|
||||
"S4U_chat": "\033[92m", # 深灰色
|
||||
"conflict_tracker": "\033[38;5;82m", # 柔和的粉色,不显眼但保持粉色系
|
||||
}
|
||||
|
||||
# 定义模块别名映射 - 将真实的logger名称映射到显示的别名
|
||||
@@ -447,10 +431,8 @@ MODULE_ALIASES = {
|
||||
"llm_api": "生成API",
|
||||
"emoji": "表情包",
|
||||
"emoji_api": "表情包API",
|
||||
|
||||
"chat": "所见",
|
||||
"chat_image": "识图",
|
||||
|
||||
"action_manager": "动作",
|
||||
"memory_activator": "记忆",
|
||||
"tool_use": "工具",
|
||||
@@ -460,7 +442,6 @@ MODULE_ALIASES = {
|
||||
"memory": "记忆",
|
||||
"tool_executor": "工具",
|
||||
"hfc": "聊天节奏",
|
||||
|
||||
"plugin_manager": "插件",
|
||||
"relationship_builder": "关系",
|
||||
"llm_models": "模型",
|
||||
|
||||
@@ -81,7 +81,8 @@ def find_messages(
|
||||
query = query.where(Messages.user_id != global_config.bot.qq_account)
|
||||
|
||||
if filter_command:
|
||||
query = query.where(not Messages.is_command)
|
||||
# 使用按位取反构造 Peewee 的 NOT 条件,避免直接与 False 比较
|
||||
query = query.where(~Messages.is_command)
|
||||
|
||||
if limit > 0:
|
||||
if limit_mode == "earliest":
|
||||
|
||||
@@ -102,9 +102,6 @@ class ModelTaskConfig(ConfigBase):
|
||||
replyer: TaskConfig
|
||||
"""normal_chat首要回复模型模型配置"""
|
||||
|
||||
emotion: TaskConfig
|
||||
"""情绪模型配置"""
|
||||
|
||||
vlm: TaskConfig
|
||||
"""视觉语言模型配置"""
|
||||
|
||||
@@ -117,9 +114,6 @@ class ModelTaskConfig(ConfigBase):
|
||||
planner: TaskConfig
|
||||
"""规划模型配置"""
|
||||
|
||||
planner_small: TaskConfig
|
||||
"""副规划模型配置"""
|
||||
|
||||
embedding: TaskConfig
|
||||
"""嵌入模型配置"""
|
||||
|
||||
|
||||
@@ -18,8 +18,6 @@ from src.config.official_configs import (
|
||||
ExpressionConfig,
|
||||
ChatConfig,
|
||||
EmojiConfig,
|
||||
MemoryConfig,
|
||||
MoodConfig,
|
||||
KeywordReactionConfig,
|
||||
ChineseTypoConfig,
|
||||
ResponsePostProcessConfig,
|
||||
@@ -32,8 +30,9 @@ from src.config.official_configs import (
|
||||
RelationshipConfig,
|
||||
ToolConfig,
|
||||
VoiceConfig,
|
||||
MoodConfig,
|
||||
MemoryConfig,
|
||||
DebugConfig,
|
||||
CustomPromptConfig,
|
||||
)
|
||||
|
||||
from .api_ada_configs import (
|
||||
@@ -56,7 +55,7 @@ TEMPLATE_DIR = os.path.join(PROJECT_ROOT, "template")
|
||||
|
||||
# 考虑到,实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码
|
||||
# 对该字段的更新,请严格参照语义化版本规范:https://semver.org/lang/zh-CN/
|
||||
MMC_VERSION = "0.10.2-snapshot.3"
|
||||
MMC_VERSION = "0.11.0-snapshot.3"
|
||||
|
||||
|
||||
def get_key_comment(toml_table, key):
|
||||
@@ -114,7 +113,7 @@ def set_value_by_path(d, path, value):
|
||||
if k not in d or not isinstance(d[k], dict):
|
||||
d[k] = {}
|
||||
d = d[k]
|
||||
|
||||
|
||||
# 使用 tomlkit.item 来保持 TOML 格式
|
||||
try:
|
||||
d[path[-1]] = tomlkit.item(value)
|
||||
@@ -175,13 +174,8 @@ def _update_dict(target: TOMLDocument | dict | Table, source: TOMLDocument | dic
|
||||
_update_dict(target_value, value)
|
||||
else:
|
||||
try:
|
||||
# 对数组类型进行特殊处理
|
||||
if isinstance(value, list):
|
||||
# 如果是空数组,确保它保持为空数组
|
||||
target[key] = tomlkit.array(str(value)) if value else tomlkit.array()
|
||||
else:
|
||||
# 其他类型使用item方法创建新值
|
||||
target[key] = tomlkit.item(value)
|
||||
# 统一使用 tomlkit.item 来保持原生类型与转义,不对列表做字符串化处理
|
||||
target[key] = tomlkit.item(value)
|
||||
except (TypeError, ValueError):
|
||||
# 如果转换失败,直接赋值
|
||||
target[key] = value
|
||||
@@ -253,7 +247,7 @@ def _update_config_generic(config_name: str, template_name: str):
|
||||
f"已自动将{config_name}配置 {'.'.join(path)} 的值从旧默认值 {old_default} 更新为新默认值 {new_default}"
|
||||
)
|
||||
config_updated = True
|
||||
|
||||
|
||||
# 如果配置有更新,立即保存到文件
|
||||
if config_updated:
|
||||
with open(old_config_path, "w", encoding="utf-8") as f:
|
||||
@@ -347,8 +341,6 @@ class Config(ConfigBase):
|
||||
message_receive: MessageReceiveConfig
|
||||
emoji: EmojiConfig
|
||||
expression: ExpressionConfig
|
||||
memory: MemoryConfig
|
||||
mood: MoodConfig
|
||||
keyword_reaction: KeywordReactionConfig
|
||||
chinese_typo: ChineseTypoConfig
|
||||
response_post_process: ResponsePostProcessConfig
|
||||
@@ -358,8 +350,9 @@ class Config(ConfigBase):
|
||||
maim_message: MaimMessageConfig
|
||||
lpmm_knowledge: LPMMKnowledgeConfig
|
||||
tool: ToolConfig
|
||||
memory: MemoryConfig
|
||||
debug: DebugConfig
|
||||
custom_prompt: CustomPromptConfig
|
||||
mood: MoodConfig
|
||||
voice: VoiceConfig
|
||||
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ import re
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Literal, Optional
|
||||
import time
|
||||
|
||||
from src.config.config_base import ConfigBase
|
||||
|
||||
@@ -38,15 +39,22 @@ class PersonalityConfig(ConfigBase):
|
||||
personality: str
|
||||
"""人格"""
|
||||
|
||||
emotion_style: str
|
||||
"""情感特征"""
|
||||
|
||||
reply_style: str = ""
|
||||
"""表达风格"""
|
||||
|
||||
|
||||
interest: str = ""
|
||||
"""兴趣"""
|
||||
|
||||
plan_style: str = ""
|
||||
"""说话规则,行为风格"""
|
||||
|
||||
visual_style: str = ""
|
||||
"""图片提示词"""
|
||||
|
||||
private_plan_style: str = ""
|
||||
"""私聊说话规则,行为风格"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class RelationshipConfig(ConfigBase):
|
||||
"""关系配置类"""
|
||||
@@ -61,56 +69,221 @@ class ChatConfig(ConfigBase):
|
||||
|
||||
max_context_size: int = 18
|
||||
"""上下文长度"""
|
||||
|
||||
|
||||
interest_rate_mode: Literal["fast", "accurate"] = "fast"
|
||||
"""兴趣值计算模式,fast为快速计算,accurate为精确计算"""
|
||||
|
||||
mentioned_bot_reply: float = 1
|
||||
"""提及 bot 必然回复,1为100%回复,0为不额外增幅"""
|
||||
|
||||
planner_size: float = 1.5
|
||||
"""副规划器大小,越小,麦麦的动作执行能力越精细,但是消耗更多token,调大可以缓解429类错误"""
|
||||
|
||||
mentioned_bot_reply: bool = True
|
||||
"""是否启用提及必回复"""
|
||||
|
||||
auto_chat_value: float = 1
|
||||
"""自动聊天,越小,麦麦主动聊天的概率越低"""
|
||||
|
||||
at_bot_inevitable_reply: float = 1
|
||||
"""@bot 必然回复,1为100%回复,0为不额外增幅"""
|
||||
|
||||
talk_frequency: float = 0.5
|
||||
"""回复频率阈值"""
|
||||
|
||||
# 合并后的时段频率配置
|
||||
talk_frequency_adjust: list[list[str]] = field(default_factory=lambda: [])
|
||||
planner_smooth: float = 3
|
||||
"""规划器平滑,增大数值会减小planner负荷,略微降低反应速度,推荐2-5,0为关闭,必须大于等于0"""
|
||||
|
||||
talk_value: float = 1
|
||||
"""思考频率"""
|
||||
|
||||
focus_value: float = 0.5
|
||||
"""麦麦的专注思考能力,越低越容易专注,消耗token也越多"""
|
||||
|
||||
focus_value_adjust: list[list[str]] = field(default_factory=lambda: [])
|
||||
|
||||
talk_value_rules: list[dict] = field(default_factory=lambda: [])
|
||||
"""
|
||||
统一的活跃度和专注度配置
|
||||
格式:[["platform:chat_id:type", "HH:MM,frequency", "HH:MM,frequency", ...], ...]
|
||||
|
||||
全局配置示例:
|
||||
[["", "8:00,1", "12:00,2", "18:00,1.5", "00:00,0.5"]]
|
||||
|
||||
特定聊天流配置示例:
|
||||
思考频率规则列表,支持按聊天流/按日内时段配置。
|
||||
规则格式:{ target="platform:id:type" 或 "", time="HH:MM-HH:MM", value=0.5 }
|
||||
|
||||
示例:
|
||||
[
|
||||
["", "8:00,1", "12:00,1.2", "18:00,1.5", "01:00,0.6"], # 全局默认配置
|
||||
["qq:1026294844:group", "12:20,1", "16:10,2", "20:10,1", "00:10,0.3"], # 特定群聊配置
|
||||
["qq:729957033:private", "8:20,1", "12:10,2", "20:10,1.5", "00:10,0.2"] # 特定私聊配置
|
||||
["", "00:00-08:59", 0.2], # 全局规则:凌晨到早上更安静
|
||||
["", "09:00-22:59", 1.0], # 全局规则:白天正常
|
||||
["qq:1919810:group", "20:00-23:59", 0.6], # 指定群在晚高峰降低发言
|
||||
["qq:114514:private", "00:00-23:59", 0.3],# 指定私聊全时段较安静
|
||||
]
|
||||
|
||||
说明:
|
||||
- 当第一个元素为空字符串""时,表示全局默认配置
|
||||
- 当第一个元素为"platform:id:type"格式时,表示特定聊天流配置
|
||||
- 后续元素是"时间,频率"格式,表示从该时间开始使用该频率,直到下一个时间点
|
||||
- 优先级:特定聊天流配置 > 全局配置 > 默认值
|
||||
|
||||
注意:
|
||||
- talk_frequency_adjust 控制回复频率,数值越高回复越频繁
|
||||
- focus_value_adjust 控制专注思考能力,数值越低越容易专注,消耗token也越多
|
||||
|
||||
匹配优先级: 先匹配指定 chat 流规则,再匹配全局规则(\"\").
|
||||
时间区间支持跨夜,例如 "23:00-02:00"。
|
||||
"""
|
||||
|
||||
|
||||
auto_chat_value_rules: list[dict] = field(default_factory=lambda: [])
|
||||
"""
|
||||
自动聊天频率规则列表,支持按聊天流/按日内时段配置。
|
||||
规则格式:{ target="platform:id:type" 或 "", time="HH:MM-HH:MM", value=0.5 }
|
||||
|
||||
示例:
|
||||
[
|
||||
["", "00:00-08:59", 0.2], # 全局规则:凌晨到早上更安静
|
||||
["", "09:00-22:59", 1.0], # 全局规则:白天正常
|
||||
["qq:1919810:group", "20:00-23:59", 0.6], # 指定群在晚高峰降低发言
|
||||
["qq:114514:private", "00:00-23:59", 0.3],# 指定私聊全时段较安静
|
||||
]
|
||||
|
||||
匹配优先级: 先匹配指定 chat 流规则,再匹配全局规则(\"\").
|
||||
时间区间支持跨夜,例如 "23:00-02:00"。
|
||||
"""
|
||||
|
||||
def _parse_stream_config_to_chat_id(self, stream_config_str: str) -> Optional[str]:
|
||||
"""与 ChatStream.get_stream_id 一致地从 "platform:id:type" 生成 chat_id。"""
|
||||
try:
|
||||
parts = stream_config_str.split(":")
|
||||
if len(parts) != 3:
|
||||
return None
|
||||
|
||||
platform = parts[0]
|
||||
id_str = parts[1]
|
||||
stream_type = parts[2]
|
||||
|
||||
is_group = stream_type == "group"
|
||||
|
||||
import hashlib
|
||||
|
||||
if is_group:
|
||||
components = [platform, str(id_str)]
|
||||
else:
|
||||
components = [platform, str(id_str), "private"]
|
||||
key = "_".join(components)
|
||||
return hashlib.md5(key.encode()).hexdigest()
|
||||
|
||||
except (ValueError, IndexError):
|
||||
return None
|
||||
|
||||
def _now_minutes(self) -> int:
|
||||
"""返回本地时间的分钟数(0-1439)。"""
|
||||
lt = time.localtime()
|
||||
return lt.tm_hour * 60 + lt.tm_min
|
||||
|
||||
def _parse_range(self, range_str: str) -> Optional[tuple[int, int]]:
|
||||
"""解析 "HH:MM-HH:MM" 到 (start_min, end_min)。"""
|
||||
try:
|
||||
start_str, end_str = [s.strip() for s in range_str.split("-")]
|
||||
sh, sm = [int(x) for x in start_str.split(":")]
|
||||
eh, em = [int(x) for x in end_str.split(":")]
|
||||
return sh * 60 + sm, eh * 60 + em
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def _in_range(self, now_min: int, start_min: int, end_min: int) -> bool:
|
||||
"""
|
||||
判断 now_min 是否在 [start_min, end_min] 区间内。
|
||||
支持跨夜:如果 start > end,则表示跨越午夜。
|
||||
"""
|
||||
if start_min <= end_min:
|
||||
return start_min <= now_min <= end_min
|
||||
# 跨夜:例如 23:00-02:00
|
||||
return now_min >= start_min or now_min <= end_min
|
||||
|
||||
def get_talk_value(self, chat_id: Optional[str]) -> float:
|
||||
"""根据规则返回当前 chat 的动态 talk_value,未匹配则回退到基础值。"""
|
||||
if not self.talk_value_rules:
|
||||
return self.talk_value
|
||||
|
||||
now_min = self._now_minutes()
|
||||
|
||||
# 1) 先尝试匹配指定 chat 的规则
|
||||
if chat_id:
|
||||
for rule in self.talk_value_rules:
|
||||
if not isinstance(rule, dict):
|
||||
continue
|
||||
target = rule.get("target", "")
|
||||
time_range = rule.get("time", "")
|
||||
value = rule.get("value", None)
|
||||
if not isinstance(time_range, str):
|
||||
continue
|
||||
# 跳过全局
|
||||
if target == "":
|
||||
continue
|
||||
config_chat_id = self._parse_stream_config_to_chat_id(str(target))
|
||||
if config_chat_id is None or config_chat_id != chat_id:
|
||||
continue
|
||||
parsed = self._parse_range(time_range)
|
||||
if not parsed:
|
||||
continue
|
||||
start_min, end_min = parsed
|
||||
if self._in_range(now_min, start_min, end_min):
|
||||
try:
|
||||
return float(value)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
# 2) 再匹配全局规则("")
|
||||
for rule in self.talk_value_rules:
|
||||
if not isinstance(rule, dict):
|
||||
continue
|
||||
target = rule.get("target", None)
|
||||
time_range = rule.get("time", "")
|
||||
value = rule.get("value", None)
|
||||
if target != "" or not isinstance(time_range, str):
|
||||
continue
|
||||
parsed = self._parse_range(time_range)
|
||||
if not parsed:
|
||||
continue
|
||||
start_min, end_min = parsed
|
||||
if self._in_range(now_min, start_min, end_min):
|
||||
try:
|
||||
return float(value)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
# 3) 未命中规则返回基础值
|
||||
return self.talk_value
|
||||
|
||||
def get_auto_chat_value(self, chat_id: Optional[str]) -> float:
|
||||
"""根据规则返回当前 chat 的动态 auto_chat_value,未匹配则回退到基础值。"""
|
||||
if not self.auto_chat_value_rules:
|
||||
return self.auto_chat_value
|
||||
|
||||
now_min = self._now_minutes()
|
||||
|
||||
# 1) 先尝试匹配指定 chat 的规则
|
||||
if chat_id:
|
||||
for rule in self.auto_chat_value_rules:
|
||||
if not isinstance(rule, dict):
|
||||
continue
|
||||
target = rule.get("target", "")
|
||||
time_range = rule.get("time", "")
|
||||
value = rule.get("value", None)
|
||||
if not isinstance(time_range, str):
|
||||
continue
|
||||
# 跳过全局
|
||||
if target == "":
|
||||
continue
|
||||
config_chat_id = self._parse_stream_config_to_chat_id(str(target))
|
||||
if config_chat_id is None or config_chat_id != chat_id:
|
||||
continue
|
||||
parsed = self._parse_range(time_range)
|
||||
if not parsed:
|
||||
continue
|
||||
start_min, end_min = parsed
|
||||
if self._in_range(now_min, start_min, end_min):
|
||||
try:
|
||||
return float(value)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
# 2) 再匹配全局规则("")
|
||||
for rule in self.auto_chat_value_rules:
|
||||
if not isinstance(rule, dict):
|
||||
continue
|
||||
target = rule.get("target", None)
|
||||
time_range = rule.get("time", "")
|
||||
value = rule.get("value", None)
|
||||
if target != "" or not isinstance(time_range, str):
|
||||
continue
|
||||
parsed = self._parse_range(time_range)
|
||||
if not parsed:
|
||||
continue
|
||||
start_min, end_min = parsed
|
||||
if self._in_range(now_min, start_min, end_min):
|
||||
try:
|
||||
return float(value)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
# 3) 未命中规则返回基础值
|
||||
return self.auto_chat_value
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -123,10 +296,23 @@ class MessageReceiveConfig(ConfigBase):
|
||||
ban_msgs_regex: set[str] = field(default_factory=lambda: set())
|
||||
"""过滤正则表达式列表"""
|
||||
|
||||
@dataclass
|
||||
class MemoryConfig(ConfigBase):
|
||||
"""记忆配置类"""
|
||||
|
||||
max_memory_number: int = 100
|
||||
"""记忆最大数量"""
|
||||
|
||||
memory_build_frequency: int = 1
|
||||
"""记忆构建频率"""
|
||||
|
||||
@dataclass
|
||||
class ExpressionConfig(ConfigBase):
|
||||
"""表达配置类"""
|
||||
|
||||
mode: Literal["llm", "context", "full-context"] = "context"
|
||||
"""表达方式模式,可选:llm模式,context上下文模式,full-context 完整上下文嵌入模式"""
|
||||
|
||||
learning_list: list[list] = field(default_factory=lambda: [])
|
||||
"""
|
||||
表达学习配置列表,支持按聊天流配置
|
||||
@@ -287,6 +473,19 @@ class ToolConfig(ConfigBase):
|
||||
"""是否在聊天中启用工具"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class MoodConfig(ConfigBase):
|
||||
"""情绪配置类"""
|
||||
|
||||
enable_mood: bool = True
|
||||
"""是否启用情绪系统"""
|
||||
|
||||
mood_update_threshold: float = 1
|
||||
"""情绪更新阈值,越高,更新越慢"""
|
||||
|
||||
emotion_style: str = "情绪较为稳定,但遭遇特定事件的时候起伏较大"
|
||||
"""情感特征,影响情绪的变化情况"""
|
||||
|
||||
@dataclass
|
||||
class VoiceConfig(ConfigBase):
|
||||
"""语音识别配置类"""
|
||||
@@ -321,37 +520,6 @@ class EmojiConfig(ConfigBase):
|
||||
"""表情包过滤要求"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class MemoryConfig(ConfigBase):
|
||||
"""记忆配置类"""
|
||||
|
||||
enable_memory: bool = True
|
||||
"""是否启用记忆系统"""
|
||||
|
||||
forget_memory_interval: int = 1500
|
||||
"""记忆遗忘间隔(秒)"""
|
||||
|
||||
memory_forget_time: int = 24
|
||||
"""记忆遗忘时间(小时)"""
|
||||
|
||||
memory_forget_percentage: float = 0.01
|
||||
"""记忆遗忘比例"""
|
||||
|
||||
memory_ban_words: list[str] = field(default_factory=lambda: ["表情包", "图片", "回复", "聊天记录"])
|
||||
"""不允许记忆的词列表"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class MoodConfig(ConfigBase):
|
||||
"""情绪配置类"""
|
||||
|
||||
enable_mood: bool = False
|
||||
"""是否启用情绪系统"""
|
||||
|
||||
mood_update_threshold: float = 1.0
|
||||
"""情绪更新阈值,越高,更新越慢"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class KeywordRuleConfig(ConfigBase):
|
||||
"""关键词规则配置类"""
|
||||
@@ -399,14 +567,6 @@ class KeywordReactionConfig(ConfigBase):
|
||||
raise ValueError(f"规则必须是KeywordRuleConfig类型,而不是{type(rule).__name__}")
|
||||
|
||||
|
||||
@dataclass
|
||||
class CustomPromptConfig(ConfigBase):
|
||||
"""自定义提示词配置类"""
|
||||
|
||||
image_prompt: str = ""
|
||||
"""图片提示词"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class ResponsePostProcessConfig(ConfigBase):
|
||||
"""回复后处理配置类"""
|
||||
@@ -475,9 +635,6 @@ class ExperimentalConfig(ConfigBase):
|
||||
enable_friend_chat: bool = False
|
||||
"""是否启用好友聊天"""
|
||||
|
||||
pfc_chatting: bool = False
|
||||
"""是否启用PFC"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class MaimMessageConfig(ConfigBase):
|
||||
|
||||
@@ -65,39 +65,6 @@ class RespParseException(Exception):
|
||||
return self.message or "解析响应内容时发生未知错误,请检查是否配置了正确的解析方法"
|
||||
|
||||
|
||||
class PayLoadTooLargeError(Exception):
|
||||
"""自定义异常类,用于处理请求体过大错误"""
|
||||
|
||||
def __init__(self, message: str):
|
||||
super().__init__(message)
|
||||
self.message = message
|
||||
|
||||
def __str__(self):
|
||||
return "请求体过大,请尝试压缩图片或减少输入内容。"
|
||||
|
||||
|
||||
class RequestAbortException(Exception):
|
||||
"""自定义异常类,用于处理请求中断异常"""
|
||||
|
||||
def __init__(self, message: str):
|
||||
super().__init__(message)
|
||||
self.message = message
|
||||
|
||||
def __str__(self):
|
||||
return self.message
|
||||
|
||||
|
||||
class PermissionDeniedException(Exception):
|
||||
"""自定义异常类,用于处理访问拒绝的异常"""
|
||||
|
||||
def __init__(self, message: str):
|
||||
super().__init__(message)
|
||||
self.message = message
|
||||
|
||||
def __str__(self):
|
||||
return self.message
|
||||
|
||||
|
||||
class EmptyResponseException(Exception):
|
||||
"""响应内容为空"""
|
||||
|
||||
@@ -107,3 +74,15 @@ class EmptyResponseException(Exception):
|
||||
|
||||
def __str__(self):
|
||||
return self.message
|
||||
|
||||
|
||||
class ModelAttemptFailed(Exception):
|
||||
"""当在单个模型上的所有重试都失败后,由“执行者”函数抛出,以通知“调度器”切换模型。"""
|
||||
|
||||
def __init__(self, message: str, original_exception: Exception | None = None):
|
||||
super().__init__(message)
|
||||
self.message = message
|
||||
self.original_exception = original_exception
|
||||
|
||||
def __str__(self):
|
||||
return self.message
|
||||
|
||||
@@ -72,8 +72,8 @@ class BaseClient(ABC):
|
||||
model_info: ModelInfo,
|
||||
message_list: list[Message],
|
||||
tool_options: list[ToolOption] | None = None,
|
||||
max_tokens: int = 1024,
|
||||
temperature: float = 0.7,
|
||||
max_tokens: Optional[int] = None,
|
||||
temperature: Optional[float] = None,
|
||||
response_format: RespFormat | None = None,
|
||||
stream_response_handler: Optional[
|
||||
Callable[[Any, asyncio.Event | None], tuple[APIResponse, tuple[int, int, int]]]
|
||||
@@ -117,6 +117,7 @@ class BaseClient(ABC):
|
||||
self,
|
||||
model_info: ModelInfo,
|
||||
audio_base64: str,
|
||||
max_tokens: Optional[int] = None,
|
||||
extra_params: dict[str, Any] | None = None,
|
||||
) -> APIResponse:
|
||||
"""
|
||||
@@ -174,7 +175,7 @@ class ClientRegistry:
|
||||
return client_class(api_provider)
|
||||
else:
|
||||
raise KeyError(f"'{api_provider.client_type}' 类型的 Client 未注册")
|
||||
|
||||
|
||||
# 正常的缓存逻辑
|
||||
if api_provider.name not in self.client_instance_cache:
|
||||
if client_class := self.client_registry.get(api_provider.client_type):
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import asyncio
|
||||
import io
|
||||
import base64
|
||||
from typing import Callable, AsyncIterator, Optional, Coroutine, Any, List
|
||||
from typing import Callable, AsyncIterator, Optional, Coroutine, Any, List, Dict
|
||||
|
||||
from google import genai
|
||||
from google.genai.types import (
|
||||
@@ -17,6 +17,7 @@ from google.genai.types import (
|
||||
EmbedContentResponse,
|
||||
EmbedContentConfig,
|
||||
SafetySetting,
|
||||
HttpOptions,
|
||||
HarmCategory,
|
||||
HarmBlockThreshold,
|
||||
)
|
||||
@@ -182,6 +183,14 @@ def _process_delta(
|
||||
if delta.text:
|
||||
fc_delta_buffer.write(delta.text)
|
||||
|
||||
# 处理 thought(Gemini 的特殊字段)
|
||||
for c in getattr(delta, "candidates", []):
|
||||
if c.content and getattr(c.content, "parts", None):
|
||||
for p in c.content.parts:
|
||||
if getattr(p, "thought", False) and getattr(p, "text", None):
|
||||
# 把 thought 写入 buffer,避免 resp.content 永远为空
|
||||
fc_delta_buffer.write(p.text)
|
||||
|
||||
if delta.function_calls: # 为什么不用hasattr呢,是因为这个属性一定有,即使是个空的
|
||||
for call in delta.function_calls:
|
||||
try:
|
||||
@@ -203,6 +212,7 @@ def _process_delta(
|
||||
def _build_stream_api_resp(
|
||||
_fc_delta_buffer: io.StringIO,
|
||||
_tool_calls_buffer: list[tuple[str, str, dict]],
|
||||
last_resp: GenerateContentResponse | None = None, # 传入 last_resp
|
||||
) -> APIResponse:
|
||||
# sourcery skip: simplify-len-comparison, use-assigned-variable
|
||||
resp = APIResponse()
|
||||
@@ -227,6 +237,21 @@ def _build_stream_api_resp(
|
||||
|
||||
resp.tool_calls.append(ToolCall(call_id, function_name, arguments))
|
||||
|
||||
# 检查是否因为 max_tokens 截断
|
||||
reason = None
|
||||
if last_resp and getattr(last_resp, "candidates", None):
|
||||
c0 = last_resp.candidates[0]
|
||||
reason = getattr(c0, "finish_reason", None) or getattr(c0, "finishReason", None)
|
||||
|
||||
if str(reason).endswith("MAX_TOKENS"):
|
||||
if resp.content and resp.content.strip():
|
||||
logger.warning(
|
||||
"⚠ Gemini 响应因达到 max_tokens 限制被部分截断,\n"
|
||||
" 可能会对回复内容造成影响,建议修改模型 max_tokens 配置!"
|
||||
)
|
||||
else:
|
||||
logger.warning("⚠ Gemini 响应因达到 max_tokens 限制被截断,\n 请修改模型 max_tokens 配置!")
|
||||
|
||||
if not resp.content and not resp.tool_calls:
|
||||
raise EmptyResponseException()
|
||||
|
||||
@@ -245,12 +270,14 @@ async def _default_stream_response_handler(
|
||||
_fc_delta_buffer = io.StringIO() # 正式内容缓冲区,用于存储接收到的正式内容
|
||||
_tool_calls_buffer: list[tuple[str, str, dict]] = [] # 工具调用缓冲区,用于存储接收到的工具调用
|
||||
_usage_record = None # 使用情况记录
|
||||
last_resp: GenerateContentResponse | None = None # 保存最后一个 chunk
|
||||
|
||||
def _insure_buffer_closed():
|
||||
if _fc_delta_buffer and not _fc_delta_buffer.closed:
|
||||
_fc_delta_buffer.close()
|
||||
|
||||
async for chunk in resp_stream:
|
||||
last_resp = chunk # 保存最后一个响应
|
||||
# 检查是否有中断量
|
||||
if interrupt_flag and interrupt_flag.is_set():
|
||||
# 如果中断量被设置,则抛出ReqAbortException
|
||||
@@ -269,10 +296,12 @@ async def _default_stream_response_handler(
|
||||
(chunk.usage_metadata.candidates_token_count or 0) + (chunk.usage_metadata.thoughts_token_count or 0),
|
||||
chunk.usage_metadata.total_token_count or 0,
|
||||
)
|
||||
|
||||
try:
|
||||
return _build_stream_api_resp(
|
||||
_fc_delta_buffer,
|
||||
_tool_calls_buffer,
|
||||
last_resp=last_resp,
|
||||
), _usage_record
|
||||
except Exception:
|
||||
# 确保缓冲区被关闭
|
||||
@@ -332,6 +361,35 @@ def _default_normal_response_parser(
|
||||
|
||||
api_response.raw_data = resp
|
||||
|
||||
# 检查是否因为 max_tokens 截断
|
||||
try:
|
||||
if resp.candidates:
|
||||
c0 = resp.candidates[0]
|
||||
reason = getattr(c0, "finish_reason", None) or getattr(c0, "finishReason", None)
|
||||
if reason and "MAX_TOKENS" in str(reason):
|
||||
# 检查第二个及之后的 parts 是否有内容
|
||||
has_real_output = False
|
||||
if getattr(c0, "content", None) and getattr(c0.content, "parts", None):
|
||||
for p in c0.content.parts[1:]: # 跳过第一个 thought
|
||||
if getattr(p, "text", None) and p.text.strip():
|
||||
has_real_output = True
|
||||
break
|
||||
|
||||
if not has_real_output and getattr(resp, "text", None):
|
||||
has_real_output = True
|
||||
|
||||
if has_real_output:
|
||||
logger.warning(
|
||||
"⚠ Gemini 响应因达到 max_tokens 限制被部分截断,\n"
|
||||
" 可能会对回复内容造成影响,建议修改模型 max_tokens 配置!"
|
||||
)
|
||||
else:
|
||||
logger.warning("⚠ Gemini 响应因达到 max_tokens 限制被截断,\n 请修改模型 max_tokens 配置!")
|
||||
|
||||
return api_response, _usage_record
|
||||
except Exception as e:
|
||||
logger.debug(f"检查 MAX_TOKENS 截断时异常: {e}")
|
||||
|
||||
# 最终的、唯一的空响应检查
|
||||
if not api_response.content and not api_response.tool_calls:
|
||||
raise EmptyResponseException("响应中既无文本内容也无工具调用")
|
||||
@@ -345,17 +403,45 @@ class GeminiClient(BaseClient):
|
||||
|
||||
def __init__(self, api_provider: APIProvider):
|
||||
super().__init__(api_provider)
|
||||
|
||||
# 增加传入参数处理
|
||||
http_options_kwargs: Dict[str, Any] = {}
|
||||
|
||||
# 秒转换为毫秒传入
|
||||
if api_provider.timeout is not None:
|
||||
http_options_kwargs["timeout"] = int(api_provider.timeout * 1000)
|
||||
|
||||
# 传入并处理地址和版本(必须为Gemini格式)
|
||||
if api_provider.base_url:
|
||||
parts = api_provider.base_url.rstrip("/").rsplit("/", 1)
|
||||
if len(parts) == 2 and parts[1].startswith("v"):
|
||||
http_options_kwargs["base_url"] = f"{parts[0]}/"
|
||||
http_options_kwargs["api_version"] = parts[1]
|
||||
else:
|
||||
http_options_kwargs["base_url"] = api_provider.base_url
|
||||
http_options_kwargs["api_version"] = None
|
||||
self.client = genai.Client(
|
||||
http_options=HttpOptions(**http_options_kwargs),
|
||||
api_key=api_provider.api_key,
|
||||
) # 这里和openai不一样,gemini会自己决定自己是否需要retry
|
||||
|
||||
@staticmethod
|
||||
def clamp_thinking_budget(tb: int, model_id: str) -> int:
|
||||
def clamp_thinking_budget(extra_params: dict[str, Any] | None, model_id: str) -> int:
|
||||
"""
|
||||
按模型限制思考预算范围,仅支持指定的模型(支持带数字后缀的新版本)
|
||||
"""
|
||||
limits = None
|
||||
|
||||
# 参数传入处理
|
||||
tb = THINKING_BUDGET_AUTO
|
||||
if extra_params and "thinking_budget" in extra_params:
|
||||
try:
|
||||
tb = int(extra_params["thinking_budget"])
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(
|
||||
f"无效的 thinking_budget 值 {extra_params['thinking_budget']},将使用模型自动预算模式 {tb}"
|
||||
)
|
||||
|
||||
# 优先尝试精确匹配
|
||||
if model_id in THINKING_BUDGET_LIMITS:
|
||||
limits = THINKING_BUDGET_LIMITS[model_id]
|
||||
@@ -368,20 +454,29 @@ class GeminiClient(BaseClient):
|
||||
limits = THINKING_BUDGET_LIMITS[key]
|
||||
break
|
||||
|
||||
# 特殊值处理
|
||||
# 预算值处理
|
||||
if tb == THINKING_BUDGET_AUTO:
|
||||
return THINKING_BUDGET_AUTO
|
||||
if tb == THINKING_BUDGET_DISABLED:
|
||||
if limits and limits.get("can_disable", False):
|
||||
return THINKING_BUDGET_DISABLED
|
||||
return limits["min"] if limits else THINKING_BUDGET_AUTO
|
||||
if limits:
|
||||
logger.warning(f"模型 {model_id} 不支持禁用思考预算,已回退到最小值 {limits['min']}")
|
||||
return limits["min"]
|
||||
return THINKING_BUDGET_AUTO
|
||||
|
||||
# 已知模型裁剪到范围
|
||||
# 已知模型范围裁剪 + 提示
|
||||
if limits:
|
||||
return max(limits["min"], min(tb, limits["max"]))
|
||||
if tb < limits["min"]:
|
||||
logger.warning(f"模型 {model_id} 的 thinking_budget={tb} 过小,已调整为最小值 {limits['min']}")
|
||||
return limits["min"]
|
||||
if tb > limits["max"]:
|
||||
logger.warning(f"模型 {model_id} 的 thinking_budget={tb} 过大,已调整为最大值 {limits['max']}")
|
||||
return limits["max"]
|
||||
return tb
|
||||
|
||||
# 未知模型,返回动态模式
|
||||
logger.warning(f"模型 {model_id} 未在 THINKING_BUDGET_LIMITS 中定义,将使用动态模式 tb=-1 兼容。")
|
||||
# 未知模型 → 默认自动模式
|
||||
logger.warning(f"模型 {model_id} 未在 THINKING_BUDGET_LIMITS 中定义,已启用模型自动预算兼容")
|
||||
return THINKING_BUDGET_AUTO
|
||||
|
||||
async def get_response(
|
||||
@@ -389,8 +484,8 @@ class GeminiClient(BaseClient):
|
||||
model_info: ModelInfo,
|
||||
message_list: list[Message],
|
||||
tool_options: list[ToolOption] | None = None,
|
||||
max_tokens: int = 1024,
|
||||
temperature: float = 0.4,
|
||||
max_tokens: Optional[int] = 1024,
|
||||
temperature: Optional[float] = 0.4,
|
||||
response_format: RespFormat | None = None,
|
||||
stream_response_handler: Optional[
|
||||
Callable[
|
||||
@@ -429,16 +524,8 @@ class GeminiClient(BaseClient):
|
||||
messages = _convert_messages(message_list)
|
||||
# 将tool_options转换为Gemini API所需的格式
|
||||
tools = _convert_tool_options(tool_options) if tool_options else None
|
||||
|
||||
tb = THINKING_BUDGET_AUTO
|
||||
# 空处理
|
||||
if extra_params and "thinking_budget" in extra_params:
|
||||
try:
|
||||
tb = int(extra_params["thinking_budget"])
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(f"无效的 thinking_budget 值 {extra_params['thinking_budget']},将使用默认动态模式 {tb}")
|
||||
# 裁剪到模型支持的范围
|
||||
tb = self.clamp_thinking_budget(tb, model_info.model_identifier)
|
||||
# 解析并裁剪 thinking_budget
|
||||
tb = self.clamp_thinking_budget(extra_params, model_info.model_identifier)
|
||||
|
||||
# 将response_format转换为Gemini API所需的格式
|
||||
generation_config_dict = {
|
||||
@@ -497,15 +584,20 @@ class GeminiClient(BaseClient):
|
||||
|
||||
resp, usage_record = async_response_parser(req_task.result())
|
||||
except (ClientError, ServerError) as e:
|
||||
# 重封装ClientError和ServerError为RespNotOkException
|
||||
# 重封装 ClientError 和 ServerError 为 RespNotOkException
|
||||
raise RespNotOkException(e.code, e.message) from None
|
||||
except (
|
||||
UnknownFunctionCallArgumentError,
|
||||
UnsupportedFunctionError,
|
||||
FunctionInvocationError,
|
||||
) as e:
|
||||
raise ValueError(f"工具类型错误:请检查工具选项和参数:{str(e)}") from None
|
||||
# 工具调用相关错误
|
||||
raise RespParseException(None, f"工具调用参数错误: {str(e)}") from None
|
||||
except EmptyResponseException as e:
|
||||
# 保持原始异常,便于区分“空响应”和网络异常
|
||||
raise e
|
||||
except Exception as e:
|
||||
# 其他未预料的错误,才归为网络连接类
|
||||
raise NetworkConnectionError() from e
|
||||
|
||||
if usage_record:
|
||||
@@ -561,41 +653,51 @@ class GeminiClient(BaseClient):
|
||||
|
||||
return response
|
||||
|
||||
def get_audio_transcriptions(
|
||||
self, model_info: ModelInfo, audio_base64: str, extra_params: dict[str, Any] | None = None
|
||||
async def get_audio_transcriptions(
|
||||
self,
|
||||
model_info: ModelInfo,
|
||||
audio_base64: str,
|
||||
max_tokens: Optional[int] = 2048,
|
||||
extra_params: dict[str, Any] | None = None,
|
||||
) -> APIResponse:
|
||||
"""
|
||||
获取音频转录
|
||||
:param model_info: 模型信息
|
||||
:param audio_base64: 音频文件的Base64编码字符串
|
||||
:param max_tokens: 最大输出token数(默认2048)
|
||||
:param extra_params: 额外参数(可选)
|
||||
:return: 转录响应
|
||||
"""
|
||||
# 解析并裁剪 thinking_budget
|
||||
tb = self.clamp_thinking_budget(extra_params, model_info.model_identifier)
|
||||
|
||||
# 构造 prompt + 音频输入
|
||||
prompt = "Generate a transcript of the speech. The language of the transcript should **match the language of the speech**."
|
||||
contents = [
|
||||
Content(
|
||||
role="user",
|
||||
parts=[
|
||||
Part.from_text(text=prompt),
|
||||
Part.from_bytes(data=base64.b64decode(audio_base64), mime_type="audio/wav"),
|
||||
],
|
||||
)
|
||||
]
|
||||
|
||||
generation_config_dict = {
|
||||
"max_output_tokens": 2048,
|
||||
"max_output_tokens": max_tokens,
|
||||
"response_modalities": ["TEXT"],
|
||||
"thinking_config": ThinkingConfig(
|
||||
include_thoughts=True,
|
||||
thinking_budget=(
|
||||
extra_params["thinking_budget"] if extra_params and "thinking_budget" in extra_params else 1024
|
||||
),
|
||||
thinking_budget=tb,
|
||||
),
|
||||
"safety_settings": gemini_safe_settings,
|
||||
}
|
||||
generate_content_config = GenerateContentConfig(**generation_config_dict)
|
||||
prompt = "Generate a transcript of the speech. The language of the transcript should **match the language of the speech**."
|
||||
|
||||
try:
|
||||
raw_response: GenerateContentResponse = self.client.models.generate_content(
|
||||
raw_response: GenerateContentResponse = await self.client.aio.models.generate_content(
|
||||
model=model_info.model_identifier,
|
||||
contents=[
|
||||
Content(
|
||||
role="user",
|
||||
parts=[
|
||||
Part.from_text(text=prompt),
|
||||
Part.from_bytes(data=base64.b64decode(audio_base64), mime_type="audio/wav"),
|
||||
],
|
||||
)
|
||||
],
|
||||
contents=contents,
|
||||
config=generate_content_config,
|
||||
)
|
||||
resp, usage_record = _default_normal_response_parser(raw_response)
|
||||
|
||||
@@ -403,8 +403,8 @@ class OpenaiClient(BaseClient):
|
||||
model_info: ModelInfo,
|
||||
message_list: list[Message],
|
||||
tool_options: list[ToolOption] | None = None,
|
||||
max_tokens: int = 1024,
|
||||
temperature: float = 0.7,
|
||||
max_tokens: Optional[int] = 1024,
|
||||
temperature: Optional[float] = 0.7,
|
||||
response_format: RespFormat | None = None,
|
||||
stream_response_handler: Optional[
|
||||
Callable[
|
||||
@@ -488,6 +488,9 @@ class OpenaiClient(BaseClient):
|
||||
raise ReqAbortException("请求被外部信号中断")
|
||||
await asyncio.sleep(0.1) # 等待0.5秒后再次检查任务&中断信号量状态
|
||||
|
||||
# logger.
|
||||
logger.debug(f"OpenAI API响应(非流式): {req_task.result()}")
|
||||
|
||||
# logger.info(f"OpenAI请求时间: {model_info.model_identifier} {time.time() - start_time} \n{messages}")
|
||||
|
||||
resp, usage_record = async_response_parser(req_task.result())
|
||||
@@ -507,6 +510,8 @@ class OpenaiClient(BaseClient):
|
||||
total_tokens=usage_record[2],
|
||||
)
|
||||
|
||||
# logger.debug(f"OpenAI API响应: {resp}")
|
||||
|
||||
return resp
|
||||
|
||||
async def get_embedding(
|
||||
@@ -531,7 +536,7 @@ class OpenaiClient(BaseClient):
|
||||
# 添加详细的错误信息以便调试
|
||||
logger.error(f"OpenAI API连接错误(嵌入模型): {str(e)}")
|
||||
logger.error(f"错误类型: {type(e)}")
|
||||
if hasattr(e, '__cause__') and e.__cause__:
|
||||
if hasattr(e, "__cause__") and e.__cause__:
|
||||
logger.error(f"底层错误: {str(e.__cause__)}")
|
||||
raise NetworkConnectionError() from e
|
||||
except APIStatusError as e:
|
||||
@@ -555,7 +560,7 @@ class OpenaiClient(BaseClient):
|
||||
model_name=model_info.name,
|
||||
provider_name=model_info.api_provider,
|
||||
prompt_tokens=raw_response.usage.prompt_tokens or 0,
|
||||
completion_tokens=raw_response.usage.completion_tokens or 0, # type: ignore
|
||||
completion_tokens=getattr(raw_response.usage, "completion_tokens", 0),
|
||||
total_tokens=raw_response.usage.total_tokens or 0,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
from .tool_option import ToolCall
|
||||
|
||||
__all__ = ["ToolCall"]
|
||||
__all__ = ["ToolCall"]
|
||||
|
||||
@@ -48,8 +48,7 @@ def _json_schema_type_check(instance) -> str | None:
|
||||
elif not isinstance(instance["name"], str) or instance["name"].strip() == "":
|
||||
return "schema的'name'字段必须是非空字符串"
|
||||
if "description" in instance and (
|
||||
not isinstance(instance["description"], str)
|
||||
or instance["description"].strip() == ""
|
||||
not isinstance(instance["description"], str) or instance["description"].strip() == ""
|
||||
):
|
||||
return "schema的'description'字段只能填入非空字符串"
|
||||
if "schema" not in instance:
|
||||
@@ -101,9 +100,7 @@ def _link_definitions(schema: dict[str, Any]) -> dict[str, Any]:
|
||||
# 如果当前Schema是列表,则遍历每个元素
|
||||
for i in range(len(sub_schema)):
|
||||
if isinstance(sub_schema[i], dict):
|
||||
sub_schema[i] = link_definitions_recursive(
|
||||
f"{path}/{str(i)}", sub_schema[i], defs
|
||||
)
|
||||
sub_schema[i] = link_definitions_recursive(f"{path}/{str(i)}", sub_schema[i], defs)
|
||||
else:
|
||||
# 否则为字典
|
||||
if "$defs" in sub_schema:
|
||||
@@ -125,9 +122,7 @@ def _link_definitions(schema: dict[str, Any]) -> dict[str, Any]:
|
||||
for key, value in sub_schema.items():
|
||||
if isinstance(value, (dict, list)):
|
||||
# 如果当前值是字典或列表,则递归调用
|
||||
sub_schema[key] = link_definitions_recursive(
|
||||
f"{path}/{key}", value, defs
|
||||
)
|
||||
sub_schema[key] = link_definitions_recursive(f"{path}/{key}", value, defs)
|
||||
|
||||
return sub_schema
|
||||
|
||||
@@ -163,9 +158,7 @@ class RespFormat:
|
||||
def _generate_schema_from_model(schema):
|
||||
json_schema = {
|
||||
"name": schema.__name__,
|
||||
"schema": _remove_defs(
|
||||
_link_definitions(_remove_title(schema.model_json_schema()))
|
||||
),
|
||||
"schema": _remove_defs(_link_definitions(_remove_title(schema.model_json_schema()))),
|
||||
"strict": False,
|
||||
}
|
||||
if schema.__doc__:
|
||||
|
||||
@@ -155,7 +155,13 @@ class LLMUsageRecorder:
|
||||
logger.error(f"创建 LLMUsage 表失败: {str(e)}")
|
||||
|
||||
def record_usage_to_database(
|
||||
self, model_info: ModelInfo, model_usage: UsageRecord, user_id: str, request_type: str, endpoint: str, time_cost: float = 0.0
|
||||
self,
|
||||
model_info: ModelInfo,
|
||||
model_usage: UsageRecord,
|
||||
user_id: str,
|
||||
request_type: str,
|
||||
endpoint: str,
|
||||
time_cost: float = 0.0,
|
||||
):
|
||||
input_cost = (model_usage.prompt_tokens / 1000000) * model_info.price_in
|
||||
output_cost = (model_usage.completion_tokens / 1000000) * model_info.price_out
|
||||
@@ -173,7 +179,7 @@ class LLMUsageRecorder:
|
||||
completion_tokens=model_usage.completion_tokens or 0,
|
||||
total_tokens=model_usage.total_tokens or 0,
|
||||
cost=total_cost or 0.0,
|
||||
time_cost = round(time_cost or 0.0, 3),
|
||||
time_cost=round(time_cost or 0.0, 3),
|
||||
status="success",
|
||||
timestamp=datetime.now(), # Peewee 会处理 DateTimeField
|
||||
)
|
||||
@@ -186,4 +192,5 @@ class LLMUsageRecorder:
|
||||
except Exception as e:
|
||||
logger.error(f"记录token使用情况失败: {str(e)}")
|
||||
|
||||
llm_usage_recorder = LLMUsageRecorder()
|
||||
|
||||
llm_usage_recorder = LLMUsageRecorder()
|
||||
|
||||
@@ -4,7 +4,8 @@ import time
|
||||
|
||||
from enum import Enum
|
||||
from rich.traceback import install
|
||||
from typing import Tuple, List, Dict, Optional, Callable, Any
|
||||
from typing import Tuple, List, Dict, Optional, Callable, Any, Set
|
||||
import traceback
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import model_config
|
||||
@@ -16,28 +17,15 @@ from .model_client.base_client import BaseClient, APIResponse, client_registry
|
||||
from .utils import compress_messages, llm_usage_recorder
|
||||
from .exceptions import (
|
||||
NetworkConnectionError,
|
||||
ReqAbortException,
|
||||
RespNotOkException,
|
||||
RespParseException,
|
||||
EmptyResponseException,
|
||||
ModelAttemptFailed,
|
||||
)
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
logger = get_logger("model_utils")
|
||||
|
||||
# 常见Error Code Mapping
|
||||
error_code_mapping = {
|
||||
400: "参数不正确",
|
||||
401: "API key 错误,认证失败,请检查 config/model_config.toml 中的配置是否正确",
|
||||
402: "账号余额不足",
|
||||
403: "需要实名,或余额不足",
|
||||
404: "Not Found",
|
||||
429: "请求过于频繁,请稍后再试",
|
||||
500: "服务器内部故障",
|
||||
503: "服务器负载过高",
|
||||
}
|
||||
|
||||
|
||||
class RequestType(Enum):
|
||||
"""请求类型枚举"""
|
||||
@@ -76,32 +64,25 @@ class LLMRequest:
|
||||
Returns:
|
||||
(Tuple[str, str, str, Optional[List[ToolCall]]]): 响应内容、推理内容、模型名称、工具调用列表
|
||||
"""
|
||||
# 模型选择
|
||||
start_time = time.time()
|
||||
model_info, api_provider, client = self._select_model()
|
||||
|
||||
# 请求体构建
|
||||
message_builder = MessageBuilder()
|
||||
message_builder.add_text_content(prompt)
|
||||
message_builder.add_image_content(
|
||||
image_base64=image_base64, image_format=image_format, support_formats=client.get_support_image_formats()
|
||||
)
|
||||
messages = [message_builder.build()]
|
||||
def message_factory(client: BaseClient) -> List[Message]:
|
||||
message_builder = MessageBuilder()
|
||||
message_builder.add_text_content(prompt)
|
||||
message_builder.add_image_content(
|
||||
image_base64=image_base64, image_format=image_format, support_formats=client.get_support_image_formats()
|
||||
)
|
||||
return [message_builder.build()]
|
||||
|
||||
# 请求并处理返回值
|
||||
response = await self._execute_request(
|
||||
api_provider=api_provider,
|
||||
client=client,
|
||||
response, model_info = await self._execute_request(
|
||||
request_type=RequestType.RESPONSE,
|
||||
model_info=model_info,
|
||||
message_list=messages,
|
||||
message_factory=message_factory,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
content = response.content or ""
|
||||
reasoning_content = response.reasoning_content or ""
|
||||
tool_calls = response.tool_calls
|
||||
# 从内容中提取<think>标签的推理内容(向后兼容)
|
||||
if not reasoning_content and content:
|
||||
content, extracted_reasoning = self._extract_reasoning(content)
|
||||
reasoning_content = extracted_reasoning
|
||||
@@ -124,15 +105,8 @@ class LLMRequest:
|
||||
Returns:
|
||||
(Optional[str]): 生成的文本描述或None
|
||||
"""
|
||||
# 模型选择
|
||||
model_info, api_provider, client = self._select_model()
|
||||
|
||||
# 请求并处理返回值
|
||||
response = await self._execute_request(
|
||||
api_provider=api_provider,
|
||||
client=client,
|
||||
response, _ = await self._execute_request(
|
||||
request_type=RequestType.AUDIO,
|
||||
model_info=model_info,
|
||||
audio_base64=voice_base64,
|
||||
)
|
||||
return response.content or None
|
||||
@@ -151,43 +125,37 @@ class LLMRequest:
|
||||
prompt (str): 提示词
|
||||
temperature (float, optional): 温度参数
|
||||
max_tokens (int, optional): 最大token数
|
||||
tools (Optional[List[Dict[str, Any]]]): 工具列表
|
||||
raise_when_empty (bool): 当响应为空时是否抛出异常
|
||||
Returns:
|
||||
(Tuple[str, str, str, Optional[List[ToolCall]]]): 响应内容、推理内容、模型名称、工具调用列表
|
||||
"""
|
||||
# 请求体构建
|
||||
start_time = time.time()
|
||||
|
||||
message_builder = MessageBuilder()
|
||||
message_builder.add_text_content(prompt)
|
||||
messages = [message_builder.build()]
|
||||
def message_factory(client: BaseClient) -> List[Message]:
|
||||
message_builder = MessageBuilder()
|
||||
message_builder.add_text_content(prompt)
|
||||
return [message_builder.build()]
|
||||
|
||||
tool_built = self._build_tool_options(tools)
|
||||
|
||||
# 模型选择
|
||||
model_info, api_provider, client = self._select_model()
|
||||
|
||||
# 请求并处理返回值
|
||||
logger.debug(f"LLM选择耗时: {model_info.name} {time.time() - start_time}")
|
||||
|
||||
response = await self._execute_request(
|
||||
api_provider=api_provider,
|
||||
client=client,
|
||||
response, model_info = await self._execute_request(
|
||||
request_type=RequestType.RESPONSE,
|
||||
model_info=model_info,
|
||||
message_list=messages,
|
||||
message_factory=message_factory,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
tool_options=tool_built,
|
||||
)
|
||||
|
||||
logger.debug(f"LLM请求总耗时: {time.time() - start_time}")
|
||||
logger.debug(f"LLM生成内容: {response}")
|
||||
|
||||
content = response.content
|
||||
reasoning_content = response.reasoning_content or ""
|
||||
tool_calls = response.tool_calls
|
||||
# 从内容中提取<think>标签的推理内容(向后兼容)
|
||||
if not reasoning_content and content:
|
||||
content, extracted_reasoning = self._extract_reasoning(content)
|
||||
reasoning_content = extracted_reasoning
|
||||
|
||||
if usage := response.usage:
|
||||
llm_usage_recorder.record_usage_to_database(
|
||||
model_info=model_info,
|
||||
@@ -197,31 +165,22 @@ class LLMRequest:
|
||||
endpoint="/chat/completions",
|
||||
time_cost=time.time() - start_time,
|
||||
)
|
||||
|
||||
return content or "", (reasoning_content, model_info.name, tool_calls)
|
||||
|
||||
async def get_embedding(self, embedding_input: str) -> Tuple[List[float], str]:
|
||||
"""获取嵌入向量
|
||||
"""
|
||||
获取嵌入向量
|
||||
Args:
|
||||
embedding_input (str): 获取嵌入的目标
|
||||
Returns:
|
||||
(Tuple[List[float], str]): (嵌入向量,使用的模型名称)
|
||||
"""
|
||||
# 无需构建消息体,直接使用输入文本
|
||||
start_time = time.time()
|
||||
model_info, api_provider, client = self._select_model()
|
||||
|
||||
# 请求并处理返回值
|
||||
response = await self._execute_request(
|
||||
api_provider=api_provider,
|
||||
client=client,
|
||||
response, model_info = await self._execute_request(
|
||||
request_type=RequestType.EMBEDDING,
|
||||
model_info=model_info,
|
||||
embedding_input=embedding_input,
|
||||
)
|
||||
|
||||
embedding = response.embedding
|
||||
|
||||
if usage := response.usage:
|
||||
llm_usage_recorder.record_usage_to_database(
|
||||
model_info=model_info,
|
||||
@@ -231,59 +190,61 @@ class LLMRequest:
|
||||
endpoint="/embeddings",
|
||||
time_cost=time.time() - start_time,
|
||||
)
|
||||
|
||||
if not embedding:
|
||||
raise RuntimeError("获取embedding失败")
|
||||
|
||||
return embedding, model_info.name
|
||||
|
||||
def _select_model(self) -> Tuple[ModelInfo, APIProvider, BaseClient]:
|
||||
def _select_model(self, exclude_models: Optional[Set[str]] = None) -> Tuple[ModelInfo, APIProvider, BaseClient]:
|
||||
"""
|
||||
根据总tokens和惩罚值选择的模型
|
||||
"""
|
||||
available_models = {
|
||||
model: scores
|
||||
for model, scores in self.model_usage.items()
|
||||
if not exclude_models or model not in exclude_models
|
||||
}
|
||||
if not available_models:
|
||||
raise RuntimeError("没有可用的模型可供选择。所有模型均已尝试失败。")
|
||||
|
||||
least_used_model_name = min(
|
||||
self.model_usage,
|
||||
key=lambda k: self.model_usage[k][0] + self.model_usage[k][1] * 300 + self.model_usage[k][2] * 1000,
|
||||
available_models,
|
||||
key=lambda k: available_models[k][0] + available_models[k][1] * 300 + available_models[k][2] * 1000,
|
||||
)
|
||||
model_info = model_config.get_model_info(least_used_model_name)
|
||||
api_provider = model_config.get_provider(model_info.api_provider)
|
||||
|
||||
# 对于嵌入任务,强制创建新的客户端实例以避免事件循环问题
|
||||
force_new_client = self.request_type == "embedding"
|
||||
client = client_registry.get_client_class_instance(api_provider, force_new=force_new_client)
|
||||
|
||||
logger.debug(f"选择请求模型: {model_info.name}")
|
||||
total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
|
||||
self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty + 1) # 增加使用惩罚值防止连续使用
|
||||
self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty + 1)
|
||||
return model_info, api_provider, client
|
||||
|
||||
async def _execute_request(
|
||||
async def _attempt_request_on_model(
|
||||
self,
|
||||
model_info: ModelInfo,
|
||||
api_provider: APIProvider,
|
||||
client: BaseClient,
|
||||
request_type: RequestType,
|
||||
model_info: ModelInfo,
|
||||
message_list: List[Message] | None = None,
|
||||
tool_options: list[ToolOption] | None = None,
|
||||
response_format: RespFormat | None = None,
|
||||
stream_response_handler: Optional[Callable] = None,
|
||||
async_response_parser: Optional[Callable] = None,
|
||||
temperature: Optional[float] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
embedding_input: str = "",
|
||||
audio_base64: str = "",
|
||||
message_list: List[Message],
|
||||
tool_options: list[ToolOption] | None,
|
||||
response_format: RespFormat | None,
|
||||
stream_response_handler: Optional[Callable],
|
||||
async_response_parser: Optional[Callable],
|
||||
temperature: Optional[float],
|
||||
max_tokens: Optional[int],
|
||||
embedding_input: str | None,
|
||||
audio_base64: str | None,
|
||||
) -> APIResponse:
|
||||
"""
|
||||
实际执行请求的方法
|
||||
|
||||
包含了重试和异常处理逻辑
|
||||
在单个模型上执行请求,包含针对临时错误的重试逻辑。
|
||||
如果成功,返回APIResponse。如果失败(重试耗尽或硬错误),则抛出ModelAttemptFailed异常。
|
||||
"""
|
||||
retry_remain = api_provider.max_retry
|
||||
compressed_messages: Optional[List[Message]] = None
|
||||
|
||||
while retry_remain > 0:
|
||||
try:
|
||||
if request_type == RequestType.RESPONSE:
|
||||
assert message_list is not None, "message_list cannot be None for response requests"
|
||||
return await client.get_response(
|
||||
model_info=model_info,
|
||||
message_list=(compressed_messages or message_list),
|
||||
@@ -296,201 +257,125 @@ class LLMRequest:
|
||||
extra_params=model_info.extra_params,
|
||||
)
|
||||
elif request_type == RequestType.EMBEDDING:
|
||||
assert embedding_input, "embedding_input cannot be empty for embedding requests"
|
||||
assert embedding_input is not None, "嵌入输入不能为空"
|
||||
return await client.get_embedding(
|
||||
model_info=model_info,
|
||||
embedding_input=embedding_input,
|
||||
extra_params=model_info.extra_params,
|
||||
)
|
||||
elif request_type == RequestType.AUDIO:
|
||||
assert audio_base64 is not None, "audio_base64 cannot be None for audio requests"
|
||||
assert audio_base64 is not None, "音频Base64不能为空"
|
||||
return await client.get_audio_transcriptions(
|
||||
model_info=model_info,
|
||||
audio_base64=audio_base64,
|
||||
extra_params=model_info.extra_params,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug(f"请求失败: {str(e)}")
|
||||
# 处理异常
|
||||
total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
|
||||
self.model_usage[model_info.name] = (total_tokens, penalty + 1, usage_penalty)
|
||||
|
||||
wait_interval, compressed_messages = self._default_exception_handler(
|
||||
e,
|
||||
self.task_name,
|
||||
model_name=model_info.name,
|
||||
remain_try=retry_remain,
|
||||
retry_interval=api_provider.retry_interval,
|
||||
messages=(message_list, compressed_messages is not None) if message_list else None,
|
||||
)
|
||||
|
||||
if wait_interval == -1:
|
||||
retry_remain = 0 # 不再重试
|
||||
elif wait_interval > 0:
|
||||
logger.info(f"等待 {wait_interval} 秒后重试...")
|
||||
await asyncio.sleep(wait_interval)
|
||||
finally:
|
||||
# 放在finally防止死循环
|
||||
except (EmptyResponseException, NetworkConnectionError) as e:
|
||||
retry_remain -= 1
|
||||
total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
|
||||
self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty - 1) # 使用结束,减少使用惩罚值
|
||||
logger.error(f"模型 '{model_info.name}' 请求失败,达到最大重试次数 {api_provider.max_retry} 次")
|
||||
raise RuntimeError("请求失败,已达到最大重试次数")
|
||||
if retry_remain <= 0:
|
||||
logger.error(f"模型 '{model_info.name}' 在用尽对临时错误的重试次数后仍然失败。")
|
||||
raise ModelAttemptFailed(f"模型 '{model_info.name}' 重试耗尽", original_exception=e) from e
|
||||
|
||||
def _default_exception_handler(
|
||||
logger.warning(f"模型 '{model_info.name}' 遇到可重试错误: {str(e)}。剩余重试次数: {retry_remain}")
|
||||
await asyncio.sleep(api_provider.retry_interval)
|
||||
|
||||
except RespNotOkException as e:
|
||||
# 可重试的HTTP错误
|
||||
if e.status_code == 429 or e.status_code >= 500:
|
||||
retry_remain -= 1
|
||||
if retry_remain <= 0:
|
||||
logger.error(f"模型 '{model_info.name}' 在遇到 {e.status_code} 错误并用尽重试次数后仍然失败。")
|
||||
raise ModelAttemptFailed(f"模型 '{model_info.name}' 重试耗尽", original_exception=e) from e
|
||||
|
||||
logger.warning(
|
||||
f"模型 '{model_info.name}' 遇到可重试的HTTP错误: {str(e)}。剩余重试次数: {retry_remain}"
|
||||
)
|
||||
await asyncio.sleep(api_provider.retry_interval)
|
||||
continue
|
||||
|
||||
# 特殊处理413,尝试压缩
|
||||
if e.status_code == 413 and message_list and not compressed_messages:
|
||||
logger.warning(f"模型 '{model_info.name}' 返回413请求体过大,尝试压缩后重试...")
|
||||
# 压缩消息本身不消耗重试次数
|
||||
compressed_messages = compress_messages(message_list)
|
||||
continue
|
||||
|
||||
# 不可重试的HTTP错误
|
||||
logger.warning(f"模型 '{model_info.name}' 遇到不可重试的HTTP错误: {str(e)}")
|
||||
raise ModelAttemptFailed(f"模型 '{model_info.name}' 遇到硬错误", original_exception=e) from e
|
||||
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
logger.warning(f"模型 '{model_info.name}' 遇到未知的不可重试错误: {str(e)}")
|
||||
raise ModelAttemptFailed(f"模型 '{model_info.name}' 遇到硬错误", original_exception=e) from e
|
||||
|
||||
raise ModelAttemptFailed(f"模型 '{model_info.name}' 未被尝试,因为重试次数已配置为0或更少。")
|
||||
|
||||
async def _execute_request(
|
||||
self,
|
||||
e: Exception,
|
||||
task_name: str,
|
||||
model_name: str,
|
||||
remain_try: int,
|
||||
retry_interval: int = 10,
|
||||
messages: Tuple[List[Message], bool] | None = None,
|
||||
) -> Tuple[int, List[Message] | None]:
|
||||
request_type: RequestType,
|
||||
message_factory: Optional[Callable[[BaseClient], List[Message]]] = None,
|
||||
tool_options: list[ToolOption] | None = None,
|
||||
response_format: RespFormat | None = None,
|
||||
stream_response_handler: Optional[Callable] = None,
|
||||
async_response_parser: Optional[Callable] = None,
|
||||
temperature: Optional[float] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
embedding_input: str | None = None,
|
||||
audio_base64: str | None = None,
|
||||
) -> Tuple[APIResponse, ModelInfo]:
|
||||
"""
|
||||
默认异常处理函数
|
||||
Args:
|
||||
e (Exception): 异常对象
|
||||
task_name (str): 任务名称
|
||||
model_name (str): 模型名称
|
||||
remain_try (int): 剩余尝试次数
|
||||
retry_interval (int): 重试间隔
|
||||
messages (tuple[list[Message], bool] | None): (消息列表, 是否已压缩过)
|
||||
Returns:
|
||||
(等待间隔(如果为0则不等待,为-1则不再请求该模型), 新的消息列表(适用于压缩消息))
|
||||
调度器函数,负责模型选择、故障切换。
|
||||
"""
|
||||
failed_models_this_request: Set[str] = set()
|
||||
max_attempts = len(self.model_for_task.model_list)
|
||||
last_exception: Optional[Exception] = None
|
||||
|
||||
if isinstance(e, NetworkConnectionError): # 网络连接错误
|
||||
return self._check_retry(
|
||||
remain_try,
|
||||
retry_interval,
|
||||
can_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 连接异常,将于{retry_interval}秒后重试",
|
||||
cannot_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 连接异常,超过最大重试次数,请检查网络连接状态或URL是否正确",
|
||||
)
|
||||
elif isinstance(e, EmptyResponseException): # 空响应错误
|
||||
return self._check_retry(
|
||||
remain_try,
|
||||
retry_interval,
|
||||
can_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 收到空响应,将于{retry_interval}秒后重试。原因: {e}",
|
||||
cannot_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 收到空响应,超过最大重试次数,放弃请求",
|
||||
)
|
||||
elif isinstance(e, ReqAbortException):
|
||||
logger.warning(f"任务-'{task_name}' 模型-'{model_name}': 请求被中断,详细信息-{str(e.message)}")
|
||||
return -1, None # 不再重试请求该模型
|
||||
elif isinstance(e, RespNotOkException):
|
||||
return self._handle_resp_not_ok(
|
||||
e,
|
||||
task_name,
|
||||
model_name,
|
||||
remain_try,
|
||||
retry_interval,
|
||||
messages,
|
||||
)
|
||||
elif isinstance(e, RespParseException):
|
||||
# 响应解析错误
|
||||
logger.error(f"任务-'{task_name}' 模型-'{model_name}': 响应解析错误,错误信息-{e.message}")
|
||||
logger.debug(f"附加内容: {str(e.ext_info)}")
|
||||
return -1, None # 不再重试请求该模型
|
||||
else:
|
||||
logger.error(f"任务-'{task_name}' 模型-'{model_name}': 未知异常,错误信息-{str(e)}")
|
||||
return -1, None # 不再重试请求该模型
|
||||
for _ in range(max_attempts):
|
||||
model_info, api_provider, client = self._select_model(exclude_models=failed_models_this_request)
|
||||
|
||||
def _check_retry(
|
||||
self,
|
||||
remain_try: int,
|
||||
retry_interval: int,
|
||||
can_retry_msg: str,
|
||||
cannot_retry_msg: str,
|
||||
can_retry_callable: Callable | None = None,
|
||||
**kwargs,
|
||||
) -> Tuple[int, List[Message] | None]:
|
||||
"""辅助函数:检查是否可以重试
|
||||
Args:
|
||||
remain_try (int): 剩余尝试次数
|
||||
retry_interval (int): 重试间隔
|
||||
can_retry_msg (str): 可以重试时的提示信息
|
||||
cannot_retry_msg (str): 不可以重试时的提示信息
|
||||
can_retry_callable (Callable | None): 可以重试时调用的函数(如果有)
|
||||
**kwargs: 其他参数
|
||||
message_list = []
|
||||
if message_factory:
|
||||
message_list = message_factory(client)
|
||||
|
||||
Returns:
|
||||
(Tuple[int, List[Message] | None]): (等待间隔(如果为0则不等待,为-1则不再请求该模型), 新的消息列表(适用于压缩消息))
|
||||
"""
|
||||
if remain_try > 0:
|
||||
# 还有重试机会
|
||||
logger.warning(f"{can_retry_msg}")
|
||||
if can_retry_callable is not None:
|
||||
return retry_interval, can_retry_callable(**kwargs)
|
||||
else:
|
||||
return retry_interval, None
|
||||
else:
|
||||
# 达到最大重试次数
|
||||
logger.warning(f"{cannot_retry_msg}")
|
||||
return -1, None # 不再重试请求该模型
|
||||
|
||||
def _handle_resp_not_ok(
|
||||
self,
|
||||
e: RespNotOkException,
|
||||
task_name: str,
|
||||
model_name: str,
|
||||
remain_try: int,
|
||||
retry_interval: int = 10,
|
||||
messages: tuple[list[Message], bool] | None = None,
|
||||
):
|
||||
"""
|
||||
处理响应错误异常
|
||||
Args:
|
||||
e (RespNotOkException): 响应错误异常对象
|
||||
task_name (str): 任务名称
|
||||
model_name (str): 模型名称
|
||||
remain_try (int): 剩余尝试次数
|
||||
retry_interval (int): 重试间隔
|
||||
messages (tuple[list[Message], bool] | None): (消息列表, 是否已压缩过)
|
||||
Returns:
|
||||
(等待间隔(如果为0则不等待,为-1则不再请求该模型), 新的消息列表(适用于压缩消息))
|
||||
"""
|
||||
# 响应错误
|
||||
if e.status_code in [400, 401, 402, 403, 404]:
|
||||
# 客户端错误
|
||||
logger.warning(
|
||||
f"任务-'{task_name}' 模型-'{model_name}': 请求失败,错误代码-{e.status_code},错误信息-{e.message}"
|
||||
)
|
||||
return -1, None # 不再重试请求该模型
|
||||
elif e.status_code == 413:
|
||||
if messages and not messages[1]:
|
||||
# 消息列表不为空且未压缩,尝试压缩消息
|
||||
return self._check_retry(
|
||||
remain_try,
|
||||
0,
|
||||
can_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 请求体过大,尝试压缩消息后重试",
|
||||
cannot_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 请求体过大,压缩消息后仍然过大,放弃请求",
|
||||
can_retry_callable=compress_messages,
|
||||
messages=messages[0],
|
||||
try:
|
||||
response = await self._attempt_request_on_model(
|
||||
model_info,
|
||||
api_provider,
|
||||
client,
|
||||
request_type,
|
||||
message_list=message_list,
|
||||
tool_options=tool_options,
|
||||
response_format=response_format,
|
||||
stream_response_handler=stream_response_handler,
|
||||
async_response_parser=async_response_parser,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
embedding_input=embedding_input,
|
||||
audio_base64=audio_base64,
|
||||
)
|
||||
# 没有消息可压缩
|
||||
logger.warning(f"任务-'{task_name}' 模型-'{model_name}': 请求体过大,无法压缩消息,放弃请求。")
|
||||
return -1, None
|
||||
elif e.status_code == 429:
|
||||
# 请求过于频繁
|
||||
return self._check_retry(
|
||||
remain_try,
|
||||
retry_interval,
|
||||
can_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 请求过于频繁,将于{retry_interval}秒后重试",
|
||||
cannot_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 请求过于频繁,超过最大重试次数,放弃请求",
|
||||
)
|
||||
elif e.status_code >= 500:
|
||||
# 服务器错误
|
||||
return self._check_retry(
|
||||
remain_try,
|
||||
retry_interval,
|
||||
can_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 服务器错误,将于{retry_interval}秒后重试",
|
||||
cannot_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 服务器错误,超过最大重试次数,请稍后再试",
|
||||
)
|
||||
else:
|
||||
# 未知错误
|
||||
logger.warning(
|
||||
f"任务-'{task_name}' 模型-'{model_name}': 未知错误,错误代码-{e.status_code},错误信息-{e.message}"
|
||||
)
|
||||
return -1, None
|
||||
total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
|
||||
if response_usage := response.usage:
|
||||
total_tokens += response_usage.total_tokens
|
||||
self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty - 1)
|
||||
return response, model_info
|
||||
|
||||
except ModelAttemptFailed as e:
|
||||
last_exception = e.original_exception or e
|
||||
logger.warning(f"模型 '{model_info.name}' 尝试失败,切换到下一个模型。原因: {e}")
|
||||
total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
|
||||
self.model_usage[model_info.name] = (total_tokens, penalty + 1, usage_penalty - 1)
|
||||
failed_models_this_request.add(model_info.name)
|
||||
|
||||
if isinstance(last_exception, RespNotOkException) and last_exception.status_code == 400:
|
||||
logger.error("收到不可恢复的客户端错误 (400),中止所有尝试。")
|
||||
raise last_exception from e
|
||||
|
||||
logger.error(f"所有 {max_attempts} 个模型均尝试失败。")
|
||||
if last_exception:
|
||||
raise last_exception
|
||||
raise RuntimeError("请求失败,所有可用模型均已尝试失败。")
|
||||
|
||||
def _build_tool_options(self, tools: Optional[List[Dict[str, Any]]]) -> Optional[List[ToolOption]]:
|
||||
# sourcery skip: extract-method
|
||||
|
||||
54
src/main.py
54
src/main.py
@@ -13,8 +13,8 @@ from src.common.logger import get_logger
|
||||
from src.common.server import get_global_server, Server
|
||||
from src.mood.mood_manager import mood_manager
|
||||
from src.chat.knowledge import lpmm_start_up
|
||||
from src.memory_system.memory_management_task import MemoryManagementTask
|
||||
from rich.traceback import install
|
||||
from src.migrate_helper.migrate import check_and_run_migrations
|
||||
# from src.api.main import start_api_server
|
||||
|
||||
# 导入新的插件管理器
|
||||
@@ -23,10 +23,6 @@ from src.plugin_system.core.plugin_manager import plugin_manager
|
||||
# 导入消息API和traceback模块
|
||||
from src.common.message import get_global_api
|
||||
|
||||
# 条件导入记忆系统
|
||||
if global_config.memory.enable_memory:
|
||||
from src.chat.memory_system.Hippocampus import hippocampus_manager
|
||||
|
||||
# 插件系统现在使用统一的插件加载器
|
||||
|
||||
install(extra_lines=3)
|
||||
@@ -36,11 +32,6 @@ logger = get_logger("main")
|
||||
|
||||
class MainSystem:
|
||||
def __init__(self):
|
||||
# 根据配置条件性地初始化记忆系统
|
||||
self.hippocampus_manager = None
|
||||
if global_config.memory.enable_memory:
|
||||
self.hippocampus_manager = hippocampus_manager
|
||||
|
||||
# 使用消息API替代直接的FastAPI实例
|
||||
self.app: MessageServer = get_global_api()
|
||||
self.server: Server = get_global_server()
|
||||
@@ -92,29 +83,26 @@ class MainSystem:
|
||||
logger.info("表情包管理器初始化成功")
|
||||
|
||||
# 启动情绪管理器
|
||||
await mood_manager.start()
|
||||
logger.info("情绪管理器初始化成功")
|
||||
if global_config.mood.enable_mood:
|
||||
await mood_manager.start()
|
||||
logger.info("情绪管理器初始化成功")
|
||||
|
||||
# 初始化聊天管理器
|
||||
await get_chat_manager()._initialize()
|
||||
asyncio.create_task(get_chat_manager()._auto_save_task())
|
||||
|
||||
logger.info("聊天管理器初始化成功")
|
||||
|
||||
# 根据配置条件性地初始化记忆系统
|
||||
if global_config.memory.enable_memory:
|
||||
if self.hippocampus_manager:
|
||||
self.hippocampus_manager.initialize()
|
||||
logger.info("记忆系统初始化成功")
|
||||
else:
|
||||
logger.info("记忆系统已禁用,跳过初始化")
|
||||
|
||||
# 添加记忆管理任务
|
||||
await async_task_manager.add_task(MemoryManagementTask())
|
||||
logger.info("记忆管理任务已启动")
|
||||
|
||||
# await asyncio.sleep(0.5) #防止logger输出飞了
|
||||
|
||||
# 将bot.py中的chat_bot.message_process消息处理函数注册到api.py的消息处理基类中
|
||||
self.app.register_message_handler(chat_bot.message_process)
|
||||
self.app.register_custom_message_handler("message_id_echo", chat_bot.echo_message_process)
|
||||
|
||||
await check_and_run_migrations()
|
||||
|
||||
# 触发 ON_START 事件
|
||||
from src.plugin_system.core.events_manager import events_manager
|
||||
@@ -138,25 +126,15 @@ class MainSystem:
|
||||
self.server.run(),
|
||||
]
|
||||
|
||||
# 根据配置条件性地添加记忆系统相关任务
|
||||
if global_config.memory.enable_memory and self.hippocampus_manager:
|
||||
tasks.extend(
|
||||
[
|
||||
# 移除记忆构建的定期调用,改为在heartFC_chat.py中调用
|
||||
# self.build_memory_task(),
|
||||
self.forget_memory_task(),
|
||||
]
|
||||
)
|
||||
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
async def forget_memory_task(self):
|
||||
"""记忆遗忘任务"""
|
||||
while True:
|
||||
await asyncio.sleep(global_config.memory.forget_memory_interval)
|
||||
logger.info("[记忆遗忘] 开始遗忘记忆...")
|
||||
await self.hippocampus_manager.forget_memory(percentage=global_config.memory.memory_forget_percentage) # type: ignore
|
||||
logger.info("[记忆遗忘] 记忆遗忘完成")
|
||||
# async def forget_memory_task(self):
|
||||
# """记忆遗忘任务"""
|
||||
# while True:
|
||||
# await asyncio.sleep(global_config.memory.forget_memory_interval)
|
||||
# logger.info("[记忆遗忘] 开始遗忘记忆...")
|
||||
# await self.hippocampus_manager.forget_memory(percentage=global_config.memory.memory_forget_percentage) # type: ignore
|
||||
# logger.info("[记忆遗忘] 记忆遗忘完成")
|
||||
|
||||
|
||||
async def main():
|
||||
|
||||
@@ -1,36 +0,0 @@
|
||||
[inner]
|
||||
version = "1.0.0"
|
||||
|
||||
#----以下是S4U聊天系统配置文件----
|
||||
# S4U (Smart 4 U) 聊天系统是MaiBot的核心对话模块
|
||||
# 支持优先级队列、消息中断、VIP用户等高级功能
|
||||
#
|
||||
# 如果你想要修改配置文件,请在修改后将version的值进行变更
|
||||
# 如果新增项目,请参考src/mais4u/s4u_config.py中的S4UConfig类
|
||||
#
|
||||
# 版本格式:主版本号.次版本号.修订号
|
||||
#----S4U配置说明结束----
|
||||
|
||||
[s4u]
|
||||
# 消息管理配置
|
||||
message_timeout_seconds = 120 # 普通消息存活时间(秒),超过此时间的消息将被丢弃
|
||||
recent_message_keep_count = 6 # 保留最近N条消息,超出范围的普通消息将被移除
|
||||
|
||||
# 优先级系统配置
|
||||
at_bot_priority_bonus = 100.0 # @机器人时的优先级加成分数
|
||||
vip_queue_priority = true # 是否启用VIP队列优先级系统
|
||||
enable_message_interruption = true # 是否允许高优先级消息中断当前回复
|
||||
|
||||
# 打字效果配置
|
||||
typing_delay = 0.1 # 打字延迟时间(秒),模拟真实打字速度
|
||||
enable_dynamic_typing_delay = false # 是否启用基于文本长度的动态打字延迟
|
||||
|
||||
# 动态打字延迟参数(仅在enable_dynamic_typing_delay=true时生效)
|
||||
chars_per_second = 15.0 # 每秒字符数,用于计算动态打字延迟
|
||||
min_typing_delay = 0.2 # 最小打字延迟(秒)
|
||||
max_typing_delay = 2.0 # 最大打字延迟(秒)
|
||||
|
||||
# 系统功能开关
|
||||
enable_old_message_cleanup = true # 是否自动清理过旧的普通消息
|
||||
enable_loading_indicator = true # 是否显示加载提示
|
||||
|
||||
@@ -1,68 +0,0 @@
|
||||
[inner]
|
||||
version = "1.2.0"
|
||||
|
||||
#----以下是S4U聊天系统配置文件----
|
||||
# S4U (Smart 4 U) 聊天系统是MaiBot的核心对话模块
|
||||
# 支持优先级队列、消息中断、VIP用户等高级功能
|
||||
#
|
||||
# 如果你想要修改配置文件,请在修改后将version的值进行变更
|
||||
# 如果新增项目,请参考src/mais4u/s4u_config.py中的S4UConfig类
|
||||
#
|
||||
# 版本格式:主版本号.次版本号.修订号
|
||||
#----S4U配置说明结束----
|
||||
|
||||
[s4u]
|
||||
enable_s4u = false
|
||||
# 消息管理配置
|
||||
message_timeout_seconds = 120 # 普通消息存活时间(秒),超过此时间的消息将被丢弃
|
||||
recent_message_keep_count = 6 # 保留最近N条消息,超出范围的普通消息将被移除
|
||||
|
||||
# 优先级系统配置
|
||||
at_bot_priority_bonus = 100.0 # @机器人时的优先级加成分数
|
||||
vip_queue_priority = true # 是否启用VIP队列优先级系统
|
||||
enable_message_interruption = true # 是否允许高优先级消息中断当前回复
|
||||
|
||||
# 打字效果配置
|
||||
typing_delay = 0.1 # 打字延迟时间(秒),模拟真实打字速度
|
||||
enable_dynamic_typing_delay = false # 是否启用基于文本长度的动态打字延迟
|
||||
|
||||
# 动态打字延迟参数(仅在enable_dynamic_typing_delay=true时生效)
|
||||
chars_per_second = 15.0 # 每秒字符数,用于计算动态打字延迟
|
||||
min_typing_delay = 0.2 # 最小打字延迟(秒)
|
||||
max_typing_delay = 2.0 # 最大打字延迟(秒)
|
||||
|
||||
# 系统功能开关
|
||||
enable_old_message_cleanup = true # 是否自动清理过旧的普通消息
|
||||
|
||||
enable_streaming_output = true # 是否启用流式输出,false时全部生成后一次性发送
|
||||
|
||||
max_context_message_length = 20
|
||||
max_core_message_length = 30
|
||||
|
||||
# 模型配置
|
||||
[models]
|
||||
# 主要对话模型配置
|
||||
[models.chat]
|
||||
name = "qwen3-8b"
|
||||
provider = "BAILIAN"
|
||||
pri_in = 0.5
|
||||
pri_out = 2
|
||||
temp = 0.7
|
||||
enable_thinking = false
|
||||
|
||||
# 规划模型配置
|
||||
[models.motion]
|
||||
name = "qwen3-32b"
|
||||
provider = "BAILIAN"
|
||||
pri_in = 0.5
|
||||
pri_out = 2
|
||||
temp = 0.7
|
||||
enable_thinking = false
|
||||
|
||||
# 情感分析模型配置
|
||||
[models.emotion]
|
||||
name = "qwen3-8b"
|
||||
provider = "BAILIAN"
|
||||
pri_in = 0.5
|
||||
pri_out = 2
|
||||
temp = 0.7
|
||||
@@ -1,167 +0,0 @@
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
import time
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config
|
||||
from src.chat.message_receive.message import MessageRecvS4U
|
||||
from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor
|
||||
from src.mais4u.mais4u_chat.internal_manager import internal_manager
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def init_prompt():
|
||||
Prompt(
|
||||
"""
|
||||
你之前的内心想法是:{mind}
|
||||
|
||||
{memory_block}
|
||||
{relation_info_block}
|
||||
|
||||
{chat_target}
|
||||
{time_block}
|
||||
{chat_info}
|
||||
{identity}
|
||||
|
||||
你刚刚在{chat_target_2},你你刚刚的心情是:{mood_state}
|
||||
---------------------
|
||||
在这样的情况下,你对上面的内容,你对 {sender} 发送的 消息 “{target}” 进行了回复
|
||||
你刚刚选择回复的内容是:{reponse}
|
||||
现在,根据你之前的想法和回复的内容,推测你现在的想法,思考你现在的想法是什么,为什么做出上面的回复内容
|
||||
请不要浮夸和夸张修辞,不要输出多余内容(包括前后缀,冒号和引号,括号(),表情包,at或 @等 )。只输出想法:""",
|
||||
"after_response_think_prompt",
|
||||
)
|
||||
|
||||
|
||||
class MaiThinking:
|
||||
def __init__(self, chat_id):
|
||||
self.chat_id = chat_id
|
||||
self.chat_stream = get_chat_manager().get_stream(chat_id)
|
||||
self.platform = self.chat_stream.platform
|
||||
|
||||
if self.chat_stream.group_info:
|
||||
self.is_group = True
|
||||
else:
|
||||
self.is_group = False
|
||||
|
||||
self.s4u_message_processor = S4UMessageProcessor()
|
||||
|
||||
self.mind = ""
|
||||
|
||||
self.memory_block = ""
|
||||
self.relation_info_block = ""
|
||||
self.time_block = ""
|
||||
self.chat_target = ""
|
||||
self.chat_target_2 = ""
|
||||
self.chat_info = ""
|
||||
self.mood_state = ""
|
||||
self.identity = ""
|
||||
self.sender = ""
|
||||
self.target = ""
|
||||
|
||||
self.thinking_model = LLMRequest(model_set=model_config.model_task_config.replyer, request_type="thinking")
|
||||
|
||||
async def do_think_before_response(self):
|
||||
pass
|
||||
|
||||
async def do_think_after_response(self, reponse: str):
|
||||
prompt = await global_prompt_manager.format_prompt(
|
||||
"after_response_think_prompt",
|
||||
mind=self.mind,
|
||||
reponse=reponse,
|
||||
memory_block=self.memory_block,
|
||||
relation_info_block=self.relation_info_block,
|
||||
time_block=self.time_block,
|
||||
chat_target=self.chat_target,
|
||||
chat_target_2=self.chat_target_2,
|
||||
chat_info=self.chat_info,
|
||||
mood_state=self.mood_state,
|
||||
identity=self.identity,
|
||||
sender=self.sender,
|
||||
target=self.target,
|
||||
)
|
||||
|
||||
result, _ = await self.thinking_model.generate_response_async(prompt)
|
||||
self.mind = result
|
||||
|
||||
logger.info(f"[{self.chat_id}] 思考前想法:{self.mind}")
|
||||
# logger.info(f"[{self.chat_id}] 思考前prompt:{prompt}")
|
||||
logger.info(f"[{self.chat_id}] 思考后想法:{self.mind}")
|
||||
|
||||
msg_recv = await self.build_internal_message_recv(self.mind)
|
||||
await self.s4u_message_processor.process_message(msg_recv)
|
||||
internal_manager.set_internal_state(self.mind)
|
||||
|
||||
async def do_think_when_receive_message(self):
|
||||
pass
|
||||
|
||||
async def build_internal_message_recv(self, message_text: str):
|
||||
msg_id = f"internal_{time.time()}"
|
||||
|
||||
message_dict = {
|
||||
"message_info": {
|
||||
"message_id": msg_id,
|
||||
"time": time.time(),
|
||||
"user_info": {
|
||||
"user_id": "internal", # 内部用户ID
|
||||
"user_nickname": "内心", # 内部昵称
|
||||
"platform": self.platform, # 平台标记为 internal
|
||||
# 其他 user_info 字段按需补充
|
||||
},
|
||||
"platform": self.platform, # 平台
|
||||
# 其他 message_info 字段按需补充
|
||||
},
|
||||
"message_segment": {
|
||||
"type": "text", # 消息类型
|
||||
"data": message_text, # 消息内容
|
||||
# 其他 segment 字段按需补充
|
||||
},
|
||||
"raw_message": message_text, # 原始消息内容
|
||||
"processed_plain_text": message_text, # 处理后的纯文本
|
||||
# 下面这些字段可选,根据 MessageRecv 需要
|
||||
"is_emoji": False,
|
||||
"has_emoji": False,
|
||||
"is_picid": False,
|
||||
"has_picid": False,
|
||||
"is_voice": False,
|
||||
"is_mentioned": False,
|
||||
"is_command": False,
|
||||
"is_internal": True,
|
||||
"priority_mode": "interest",
|
||||
"priority_info": {"message_priority": 10.0}, # 内部消息可设高优先级
|
||||
"interest_value": 1.0,
|
||||
}
|
||||
|
||||
if self.is_group:
|
||||
message_dict["message_info"]["group_info"] = {
|
||||
"platform": self.platform,
|
||||
"group_id": self.chat_stream.group_info.group_id,
|
||||
"group_name": self.chat_stream.group_info.group_name,
|
||||
}
|
||||
|
||||
msg_recv = MessageRecvS4U(message_dict)
|
||||
msg_recv.chat_info = self.chat_info
|
||||
msg_recv.chat_stream = self.chat_stream
|
||||
msg_recv.is_internal = True
|
||||
|
||||
return msg_recv
|
||||
|
||||
|
||||
class MaiThinkingManager:
|
||||
def __init__(self):
|
||||
self.mai_think_list = []
|
||||
|
||||
def get_mai_think(self, chat_id):
|
||||
for mai_think in self.mai_think_list:
|
||||
if mai_think.chat_id == chat_id:
|
||||
return mai_think
|
||||
mai_think = MaiThinking(chat_id)
|
||||
self.mai_think_list.append(mai_think)
|
||||
return mai_think
|
||||
|
||||
|
||||
mai_thinking_manager = MaiThinkingManager()
|
||||
|
||||
|
||||
init_prompt()
|
||||
@@ -1,342 +0,0 @@
|
||||
import json
|
||||
import time
|
||||
|
||||
from json_repair import repair_json
|
||||
from src.chat.message_receive.message import MessageRecv
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive
|
||||
from src.config.config import global_config, model_config
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.manager.async_task_manager import AsyncTask, async_task_manager
|
||||
from src.plugin_system.apis import send_api
|
||||
|
||||
from src.mais4u.s4u_config import s4u_config
|
||||
|
||||
logger = get_logger("action")
|
||||
|
||||
# 使用字典作为默认值,但通过Prompt来注册以便外部重载
|
||||
DEFAULT_HEAD_CODE = {
|
||||
"看向上方": "(0,0.5,0)",
|
||||
"看向下方": "(0,-0.5,0)",
|
||||
"看向左边": "(-1,0,0)",
|
||||
"看向右边": "(1,0,0)",
|
||||
"随意朝向": "random",
|
||||
"看向摄像机": "camera",
|
||||
"注视对方": "(0,0,0)",
|
||||
"看向正前方": "(0,0,0)",
|
||||
}
|
||||
|
||||
DEFAULT_BODY_CODE = {
|
||||
"双手背后向前弯腰": "010_0070",
|
||||
"歪头双手合十": "010_0100",
|
||||
"标准文静站立": "010_0101",
|
||||
"双手交叠腹部站立": "010_0150",
|
||||
"帅气的姿势": "010_0190",
|
||||
"另一个帅气的姿势": "010_0191",
|
||||
"手掌朝前可爱": "010_0210",
|
||||
"平静,双手后放": "平静,双手后放",
|
||||
"思考": "思考",
|
||||
"优雅,左手放在腰上": "优雅,左手放在腰上",
|
||||
"一般": "一般",
|
||||
"可爱,双手前放": "可爱,双手前放",
|
||||
}
|
||||
|
||||
|
||||
async def get_head_code() -> dict:
|
||||
"""获取头部动作代码字典"""
|
||||
head_code_str = await global_prompt_manager.format_prompt("head_code_prompt")
|
||||
if not head_code_str:
|
||||
return DEFAULT_HEAD_CODE
|
||||
try:
|
||||
return json.loads(head_code_str)
|
||||
except Exception as e:
|
||||
logger.error(f"解析head_code_prompt失败,使用默认值: {e}")
|
||||
return DEFAULT_HEAD_CODE
|
||||
|
||||
|
||||
async def get_body_code() -> dict:
|
||||
"""获取身体动作代码字典"""
|
||||
body_code_str = await global_prompt_manager.format_prompt("body_code_prompt")
|
||||
if not body_code_str:
|
||||
return DEFAULT_BODY_CODE
|
||||
try:
|
||||
return json.loads(body_code_str)
|
||||
except Exception as e:
|
||||
logger.error(f"解析body_code_prompt失败,使用默认值: {e}")
|
||||
return DEFAULT_BODY_CODE
|
||||
|
||||
|
||||
def init_prompt():
|
||||
# 注册头部动作代码
|
||||
Prompt(
|
||||
json.dumps(DEFAULT_HEAD_CODE, ensure_ascii=False, indent=2),
|
||||
"head_code_prompt",
|
||||
)
|
||||
|
||||
# 注册身体动作代码
|
||||
Prompt(
|
||||
json.dumps(DEFAULT_BODY_CODE, ensure_ascii=False, indent=2),
|
||||
"body_code_prompt",
|
||||
)
|
||||
|
||||
# 注册原有提示模板
|
||||
Prompt(
|
||||
"""
|
||||
{chat_talking_prompt}
|
||||
以上是群里正在进行的聊天记录
|
||||
|
||||
{indentify_block}
|
||||
你现在的动作状态是:
|
||||
- 身体动作:{body_action}
|
||||
|
||||
现在,因为你发送了消息,或者群里其他人发送了消息,引起了你的注意,你对其进行了阅读和思考,请你更新你的动作状态。
|
||||
身体动作可选:
|
||||
{all_actions}
|
||||
|
||||
请只按照以下json格式输出,描述你新的动作状态,确保每个字段都存在:
|
||||
{{
|
||||
"body_action": "..."
|
||||
}}
|
||||
""",
|
||||
"change_action_prompt",
|
||||
)
|
||||
Prompt(
|
||||
"""
|
||||
{chat_talking_prompt}
|
||||
以上是群里最近的聊天记录
|
||||
|
||||
{indentify_block}
|
||||
你之前的动作状态是
|
||||
- 身体动作:{body_action}
|
||||
|
||||
身体动作可选:
|
||||
{all_actions}
|
||||
|
||||
距离你上次关注群里消息已经过去了一段时间,你冷静了下来,你的动作会趋于平缓或静止,请你输出你现在新的动作状态,用中文。
|
||||
请只按照以下json格式输出,描述你新的动作状态,确保每个字段都存在:
|
||||
{{
|
||||
"body_action": "..."
|
||||
}}
|
||||
""",
|
||||
"regress_action_prompt",
|
||||
)
|
||||
|
||||
|
||||
class ChatAction:
|
||||
def __init__(self, chat_id: str):
|
||||
self.chat_id: str = chat_id
|
||||
self.body_action: str = "一般"
|
||||
self.head_action: str = "注视摄像机"
|
||||
|
||||
self.regression_count: int = 0
|
||||
# 新增:body_action冷却池,key为动作名,value为剩余冷却次数
|
||||
self.body_action_cooldown: dict[str, int] = {}
|
||||
|
||||
print(s4u_config.models.motion)
|
||||
print(model_config.model_task_config.emotion)
|
||||
|
||||
self.action_model = LLMRequest(model_set=model_config.model_task_config.emotion, request_type="motion")
|
||||
|
||||
self.last_change_time: float = 0
|
||||
|
||||
async def send_action_update(self):
|
||||
"""发送动作更新到前端"""
|
||||
|
||||
body_code = (await get_body_code()).get(self.body_action, "")
|
||||
await send_api.custom_to_stream(
|
||||
message_type="body_action",
|
||||
content=body_code,
|
||||
stream_id=self.chat_id,
|
||||
storage_message=False,
|
||||
show_log=True,
|
||||
)
|
||||
|
||||
async def update_action_by_message(self, message: MessageRecv):
|
||||
self.regression_count = 0
|
||||
|
||||
message_time: float = message.message_info.time # type: ignore
|
||||
message_list_before_now = get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
chat_id=self.chat_id,
|
||||
timestamp_start=self.last_change_time,
|
||||
timestamp_end=message_time,
|
||||
limit=15,
|
||||
limit_mode="last",
|
||||
)
|
||||
chat_talking_prompt = build_readable_messages(
|
||||
message_list_before_now,
|
||||
replace_bot_name=True,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
read_mark=0.0,
|
||||
truncate=True,
|
||||
show_actions=True,
|
||||
)
|
||||
|
||||
bot_name = global_config.bot.nickname
|
||||
if global_config.bot.alias_names:
|
||||
bot_nickname = f",也有人叫你{','.join(global_config.bot.alias_names)}"
|
||||
else:
|
||||
bot_nickname = ""
|
||||
|
||||
prompt_personality = global_config.personality.personality
|
||||
indentify_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}:"
|
||||
|
||||
try:
|
||||
# 冷却池处理:过滤掉冷却中的动作
|
||||
self._update_body_action_cooldown()
|
||||
available_actions = [k for k in (await get_body_code()).keys() if k not in self.body_action_cooldown]
|
||||
all_actions = "\n".join(available_actions)
|
||||
|
||||
prompt = await global_prompt_manager.format_prompt(
|
||||
"change_action_prompt",
|
||||
chat_talking_prompt=chat_talking_prompt,
|
||||
indentify_block=indentify_block,
|
||||
body_action=self.body_action,
|
||||
all_actions=all_actions,
|
||||
)
|
||||
|
||||
logger.info(f"prompt: {prompt}")
|
||||
response, (reasoning_content, _, _) = await self.action_model.generate_response_async(
|
||||
prompt=prompt, temperature=0.7
|
||||
)
|
||||
logger.info(f"response: {response}")
|
||||
logger.info(f"reasoning_content: {reasoning_content}")
|
||||
|
||||
if action_data := json.loads(repair_json(response)):
|
||||
# 记录原动作,切换后进入冷却
|
||||
prev_body_action = self.body_action
|
||||
new_body_action = action_data.get("body_action", self.body_action)
|
||||
if new_body_action != prev_body_action and prev_body_action:
|
||||
self.body_action_cooldown[prev_body_action] = 3
|
||||
self.body_action = new_body_action
|
||||
self.head_action = action_data.get("head_action", self.head_action)
|
||||
# 发送动作更新
|
||||
await self.send_action_update()
|
||||
|
||||
self.last_change_time = message_time
|
||||
except Exception as e:
|
||||
logger.error(f"update_action_by_message error: {e}")
|
||||
|
||||
async def regress_action(self):
|
||||
message_time = time.time()
|
||||
message_list_before_now = get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
chat_id=self.chat_id,
|
||||
timestamp_start=self.last_change_time,
|
||||
timestamp_end=message_time,
|
||||
limit=10,
|
||||
limit_mode="last",
|
||||
)
|
||||
chat_talking_prompt = build_readable_messages(
|
||||
message_list_before_now,
|
||||
replace_bot_name=True,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
read_mark=0.0,
|
||||
truncate=True,
|
||||
show_actions=True,
|
||||
)
|
||||
|
||||
bot_name = global_config.bot.nickname
|
||||
if global_config.bot.alias_names:
|
||||
bot_nickname = f",也有人叫你{','.join(global_config.bot.alias_names)}"
|
||||
else:
|
||||
bot_nickname = ""
|
||||
|
||||
prompt_personality = global_config.personality.personality
|
||||
indentify_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}:"
|
||||
try:
|
||||
# 冷却池处理:过滤掉冷却中的动作
|
||||
self._update_body_action_cooldown()
|
||||
available_actions = [k for k in (await get_body_code()).keys() if k not in self.body_action_cooldown]
|
||||
all_actions = "\n".join(available_actions)
|
||||
|
||||
prompt = await global_prompt_manager.format_prompt(
|
||||
"regress_action_prompt",
|
||||
chat_talking_prompt=chat_talking_prompt,
|
||||
indentify_block=indentify_block,
|
||||
body_action=self.body_action,
|
||||
all_actions=all_actions,
|
||||
)
|
||||
|
||||
logger.info(f"prompt: {prompt}")
|
||||
response, (reasoning_content, _, _) = await self.action_model.generate_response_async(
|
||||
prompt=prompt, temperature=0.7
|
||||
)
|
||||
logger.info(f"response: {response}")
|
||||
logger.info(f"reasoning_content: {reasoning_content}")
|
||||
|
||||
if action_data := json.loads(repair_json(response)):
|
||||
prev_body_action = self.body_action
|
||||
new_body_action = action_data.get("body_action", self.body_action)
|
||||
if new_body_action != prev_body_action and prev_body_action:
|
||||
self.body_action_cooldown[prev_body_action] = 6
|
||||
self.body_action = new_body_action
|
||||
# 发送动作更新
|
||||
await self.send_action_update()
|
||||
|
||||
self.regression_count += 1
|
||||
self.last_change_time = message_time
|
||||
except Exception as e:
|
||||
logger.error(f"regress_action error: {e}")
|
||||
|
||||
# 新增:冷却池维护方法
|
||||
def _update_body_action_cooldown(self):
|
||||
remove_keys = []
|
||||
for k in self.body_action_cooldown:
|
||||
self.body_action_cooldown[k] -= 1
|
||||
if self.body_action_cooldown[k] <= 0:
|
||||
remove_keys.append(k)
|
||||
for k in remove_keys:
|
||||
del self.body_action_cooldown[k]
|
||||
|
||||
|
||||
class ActionRegressionTask(AsyncTask):
|
||||
def __init__(self, action_manager: "ActionManager"):
|
||||
super().__init__(task_name="ActionRegressionTask", run_interval=3)
|
||||
self.action_manager = action_manager
|
||||
|
||||
async def run(self):
|
||||
logger.debug("Running action regression task...")
|
||||
now = time.time()
|
||||
for action_state in self.action_manager.action_state_list:
|
||||
if action_state.last_change_time == 0:
|
||||
continue
|
||||
|
||||
if now - action_state.last_change_time > 10:
|
||||
if action_state.regression_count >= 3:
|
||||
continue
|
||||
|
||||
logger.info(f"chat {action_state.chat_id} 开始动作回归, 这是第 {action_state.regression_count + 1} 次")
|
||||
await action_state.regress_action()
|
||||
|
||||
|
||||
class ActionManager:
|
||||
def __init__(self):
|
||||
self.action_state_list: list[ChatAction] = []
|
||||
"""当前动作状态"""
|
||||
self.task_started: bool = False
|
||||
|
||||
async def start(self):
|
||||
"""启动动作回归后台任务"""
|
||||
if self.task_started:
|
||||
return
|
||||
|
||||
logger.info("启动动作回归任务...")
|
||||
task = ActionRegressionTask(self)
|
||||
await async_task_manager.add_task(task)
|
||||
self.task_started = True
|
||||
logger.info("动作回归任务已启动")
|
||||
|
||||
def get_action_state_by_chat_id(self, chat_id: str) -> ChatAction:
|
||||
for action_state in self.action_state_list:
|
||||
if action_state.chat_id == chat_id:
|
||||
return action_state
|
||||
|
||||
new_action_state = ChatAction(chat_id)
|
||||
self.action_state_list.append(new_action_state)
|
||||
return new_action_state
|
||||
|
||||
|
||||
init_prompt()
|
||||
|
||||
action_manager = ActionManager()
|
||||
"""全局动作管理器"""
|
||||
@@ -1,685 +0,0 @@
|
||||
import asyncio
|
||||
import json
|
||||
from collections import deque
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional
|
||||
from aiohttp import web, WSMsgType
|
||||
import aiohttp_cors
|
||||
|
||||
from src.chat.message_receive.message import MessageRecv
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("context_web")
|
||||
|
||||
|
||||
class ContextMessage:
|
||||
"""上下文消息类"""
|
||||
|
||||
def __init__(self, message: MessageRecv):
|
||||
self.user_name = message.message_info.user_info.user_nickname
|
||||
self.user_id = message.message_info.user_info.user_id
|
||||
self.content = message.processed_plain_text
|
||||
self.timestamp = datetime.now()
|
||||
self.group_name = message.message_info.group_info.group_name if message.message_info.group_info else "私聊"
|
||||
|
||||
# 识别消息类型
|
||||
self.is_gift = getattr(message, 'is_gift', False)
|
||||
self.is_superchat = getattr(message, 'is_superchat', False)
|
||||
|
||||
# 添加礼物和SC相关信息
|
||||
if self.is_gift:
|
||||
self.gift_name = getattr(message, 'gift_name', '')
|
||||
self.gift_count = getattr(message, 'gift_count', '1')
|
||||
self.content = f"送出了 {self.gift_name} x{self.gift_count}"
|
||||
elif self.is_superchat:
|
||||
self.superchat_price = getattr(message, 'superchat_price', '0')
|
||||
self.superchat_message = getattr(message, 'superchat_message_text', '')
|
||||
if self.superchat_message:
|
||||
self.content = f"[¥{self.superchat_price}] {self.superchat_message}"
|
||||
else:
|
||||
self.content = f"[¥{self.superchat_price}] {self.content}"
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"user_name": self.user_name,
|
||||
"user_id": self.user_id,
|
||||
"content": self.content,
|
||||
"timestamp": self.timestamp.strftime("%m-%d %H:%M:%S"),
|
||||
"group_name": self.group_name,
|
||||
"is_gift": self.is_gift,
|
||||
"is_superchat": self.is_superchat
|
||||
}
|
||||
|
||||
|
||||
class ContextWebManager:
|
||||
"""上下文网页管理器"""
|
||||
|
||||
def __init__(self, max_messages: int = 10, port: int = 8765):
|
||||
self.max_messages = max_messages
|
||||
self.port = port
|
||||
self.contexts: Dict[str, deque] = {} # chat_id -> deque of ContextMessage
|
||||
self.websockets: List[web.WebSocketResponse] = []
|
||||
self.app = None
|
||||
self.runner = None
|
||||
self.site = None
|
||||
self._server_starting = False # 添加启动标志防止并发
|
||||
|
||||
async def start_server(self):
|
||||
"""启动web服务器"""
|
||||
if self.site is not None:
|
||||
logger.debug("Web服务器已经启动,跳过重复启动")
|
||||
return
|
||||
|
||||
if self._server_starting:
|
||||
logger.debug("Web服务器正在启动中,等待启动完成...")
|
||||
# 等待启动完成
|
||||
while self._server_starting and self.site is None:
|
||||
await asyncio.sleep(0.1)
|
||||
return
|
||||
|
||||
self._server_starting = True
|
||||
|
||||
try:
|
||||
self.app = web.Application()
|
||||
|
||||
# 设置CORS
|
||||
cors = aiohttp_cors.setup(self.app, defaults={
|
||||
"*": aiohttp_cors.ResourceOptions(
|
||||
allow_credentials=True,
|
||||
expose_headers="*",
|
||||
allow_headers="*",
|
||||
allow_methods="*"
|
||||
)
|
||||
})
|
||||
|
||||
# 添加路由
|
||||
self.app.router.add_get('/', self.index_handler)
|
||||
self.app.router.add_get('/ws', self.websocket_handler)
|
||||
self.app.router.add_get('/api/contexts', self.get_contexts_handler)
|
||||
self.app.router.add_get('/debug', self.debug_handler)
|
||||
|
||||
# 为所有路由添加CORS
|
||||
for route in list(self.app.router.routes()):
|
||||
cors.add(route)
|
||||
|
||||
self.runner = web.AppRunner(self.app)
|
||||
await self.runner.setup()
|
||||
|
||||
self.site = web.TCPSite(self.runner, 'localhost', self.port)
|
||||
await self.site.start()
|
||||
|
||||
logger.info(f"🌐 上下文网页服务器启动成功在 http://localhost:{self.port}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 启动Web服务器失败: {e}")
|
||||
# 清理部分启动的资源
|
||||
if self.runner:
|
||||
await self.runner.cleanup()
|
||||
self.app = None
|
||||
self.runner = None
|
||||
self.site = None
|
||||
raise
|
||||
finally:
|
||||
self._server_starting = False
|
||||
|
||||
async def stop_server(self):
|
||||
"""停止web服务器"""
|
||||
if self.site:
|
||||
await self.site.stop()
|
||||
if self.runner:
|
||||
await self.runner.cleanup()
|
||||
self.app = None
|
||||
self.runner = None
|
||||
self.site = None
|
||||
self._server_starting = False
|
||||
|
||||
async def index_handler(self, request):
|
||||
"""主页处理器"""
|
||||
html_content = '''
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<title>聊天上下文</title>
|
||||
<style>
|
||||
html, body {
|
||||
background: transparent !important;
|
||||
background-color: transparent !important;
|
||||
margin: 0;
|
||||
padding: 20px;
|
||||
font-family: 'Microsoft YaHei', Arial, sans-serif;
|
||||
color: #ffffff;
|
||||
text-shadow: 2px 2px 4px rgba(0,0,0,0.8);
|
||||
}
|
||||
.container {
|
||||
max-width: 800px;
|
||||
margin: 0 auto;
|
||||
background: transparent !important;
|
||||
}
|
||||
.message {
|
||||
background: rgba(0, 0, 0, 0.3);
|
||||
margin: 10px 0;
|
||||
padding: 15px;
|
||||
border-radius: 10px;
|
||||
border-left: 4px solid #00ff88;
|
||||
backdrop-filter: blur(5px);
|
||||
animation: slideIn 0.3s ease-out;
|
||||
transform: translateY(0);
|
||||
transition: transform 0.5s ease, opacity 0.5s ease;
|
||||
}
|
||||
.message:hover {
|
||||
background: rgba(0, 0, 0, 0.5);
|
||||
transform: translateX(5px);
|
||||
transition: all 0.3s ease;
|
||||
}
|
||||
.message.gift {
|
||||
border-left: 4px solid #ff8800;
|
||||
background: rgba(255, 136, 0, 0.2);
|
||||
}
|
||||
.message.gift:hover {
|
||||
background: rgba(255, 136, 0, 0.3);
|
||||
}
|
||||
.message.gift .username {
|
||||
color: #ff8800;
|
||||
}
|
||||
.message.superchat {
|
||||
border-left: 4px solid #ff6b6b;
|
||||
background: linear-gradient(135deg, rgba(255, 107, 107, 0.2), rgba(107, 255, 107, 0.2), rgba(107, 107, 255, 0.2));
|
||||
background-size: 200% 200%;
|
||||
animation: rainbow 3s ease infinite;
|
||||
}
|
||||
.message.superchat:hover {
|
||||
background: linear-gradient(135deg, rgba(255, 107, 107, 0.4), rgba(107, 255, 107, 0.4), rgba(107, 107, 255, 0.4));
|
||||
background-size: 200% 200%;
|
||||
}
|
||||
.message.superchat .username {
|
||||
background: linear-gradient(45deg, #ff6b6b, #4ecdc4, #45b7d1, #96ceb4, #feca57);
|
||||
background-size: 300% 300%;
|
||||
animation: rainbow-text 2s ease infinite;
|
||||
-webkit-background-clip: text;
|
||||
-webkit-text-fill-color: transparent;
|
||||
background-clip: text;
|
||||
}
|
||||
@keyframes rainbow {
|
||||
0% { background-position: 0% 50%; }
|
||||
50% { background-position: 100% 50%; }
|
||||
100% { background-position: 0% 50%; }
|
||||
}
|
||||
@keyframes rainbow-text {
|
||||
0% { background-position: 0% 50%; }
|
||||
50% { background-position: 100% 50%; }
|
||||
100% { background-position: 0% 50%; }
|
||||
}
|
||||
.message-line {
|
||||
line-height: 1.4;
|
||||
word-wrap: break-word;
|
||||
font-size: 24px;
|
||||
}
|
||||
.username {
|
||||
color: #00ff88;
|
||||
}
|
||||
.content {
|
||||
color: #ffffff;
|
||||
}
|
||||
|
||||
.new-message {
|
||||
animation: slideInNew 0.6s ease-out;
|
||||
}
|
||||
|
||||
.debug-btn {
|
||||
position: fixed;
|
||||
bottom: 20px;
|
||||
right: 20px;
|
||||
background: rgba(0, 0, 0, 0.7);
|
||||
color: #00ff88;
|
||||
font-size: 12px;
|
||||
padding: 8px 12px;
|
||||
border-radius: 20px;
|
||||
backdrop-filter: blur(10px);
|
||||
z-index: 1000;
|
||||
text-decoration: none;
|
||||
border: 1px solid #00ff88;
|
||||
}
|
||||
.debug-btn:hover {
|
||||
background: rgba(0, 255, 136, 0.2);
|
||||
}
|
||||
@keyframes slideIn {
|
||||
from {
|
||||
opacity: 0;
|
||||
transform: translateY(-20px);
|
||||
}
|
||||
to {
|
||||
opacity: 1;
|
||||
transform: translateY(0);
|
||||
}
|
||||
}
|
||||
@keyframes slideInNew {
|
||||
from {
|
||||
opacity: 0;
|
||||
transform: translateY(50px) scale(0.95);
|
||||
}
|
||||
to {
|
||||
opacity: 1;
|
||||
transform: translateY(0) scale(1);
|
||||
}
|
||||
}
|
||||
.no-messages {
|
||||
text-align: center;
|
||||
color: #666;
|
||||
font-style: italic;
|
||||
margin-top: 50px;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<a href="/debug" class="debug-btn">🔧 调试</a>
|
||||
<div id="messages">
|
||||
<div class="no-messages">暂无消息</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
let ws;
|
||||
let reconnectInterval;
|
||||
let currentMessages = []; // 存储当前显示的消息
|
||||
|
||||
function connectWebSocket() {
|
||||
console.log('正在连接WebSocket...');
|
||||
ws = new WebSocket('ws://localhost:''' + str(self.port) + '''/ws');
|
||||
|
||||
ws.onopen = function() {
|
||||
console.log('WebSocket连接已建立');
|
||||
if (reconnectInterval) {
|
||||
clearInterval(reconnectInterval);
|
||||
reconnectInterval = null;
|
||||
}
|
||||
};
|
||||
|
||||
ws.onmessage = function(event) {
|
||||
console.log('收到WebSocket消息:', event.data);
|
||||
try {
|
||||
const data = JSON.parse(event.data);
|
||||
updateMessages(data.contexts);
|
||||
} catch (e) {
|
||||
console.error('解析消息失败:', e, event.data);
|
||||
}
|
||||
};
|
||||
|
||||
ws.onclose = function(event) {
|
||||
console.log('WebSocket连接关闭:', event.code, event.reason);
|
||||
|
||||
if (!reconnectInterval) {
|
||||
reconnectInterval = setInterval(connectWebSocket, 3000);
|
||||
}
|
||||
};
|
||||
|
||||
ws.onerror = function(error) {
|
||||
console.error('WebSocket错误:', error);
|
||||
};
|
||||
}
|
||||
|
||||
function updateMessages(contexts) {
|
||||
const messagesDiv = document.getElementById('messages');
|
||||
|
||||
if (!contexts || contexts.length === 0) {
|
||||
messagesDiv.innerHTML = '<div class="no-messages">暂无消息</div>';
|
||||
currentMessages = [];
|
||||
return;
|
||||
}
|
||||
|
||||
// 如果是第一次加载或者消息完全不同,进行完全重新渲染
|
||||
if (currentMessages.length === 0) {
|
||||
console.log('首次加载消息,数量:', contexts.length);
|
||||
messagesDiv.innerHTML = '';
|
||||
|
||||
contexts.forEach(function(msg) {
|
||||
const messageDiv = createMessageElement(msg);
|
||||
messagesDiv.appendChild(messageDiv);
|
||||
});
|
||||
|
||||
currentMessages = [...contexts];
|
||||
window.scrollTo(0, document.body.scrollHeight);
|
||||
return;
|
||||
}
|
||||
|
||||
// 检测新消息 - 使用更可靠的方法
|
||||
const newMessages = findNewMessages(contexts, currentMessages);
|
||||
|
||||
if (newMessages.length > 0) {
|
||||
console.log('添加新消息,数量:', newMessages.length);
|
||||
|
||||
// 先检查是否需要移除老消息(保持DOM清洁)
|
||||
const maxDisplayMessages = 15; // 比服务器端稍多一些,确保流畅性
|
||||
const currentMessageElements = messagesDiv.querySelectorAll('.message');
|
||||
const willExceedLimit = currentMessageElements.length + newMessages.length > maxDisplayMessages;
|
||||
|
||||
if (willExceedLimit) {
|
||||
const removeCount = (currentMessageElements.length + newMessages.length) - maxDisplayMessages;
|
||||
console.log('需要移除老消息数量:', removeCount);
|
||||
|
||||
for (let i = 0; i < removeCount && i < currentMessageElements.length; i++) {
|
||||
const oldMessage = currentMessageElements[i];
|
||||
oldMessage.style.transition = 'opacity 0.3s ease, transform 0.3s ease';
|
||||
oldMessage.style.opacity = '0';
|
||||
oldMessage.style.transform = 'translateY(-20px)';
|
||||
|
||||
setTimeout(() => {
|
||||
if (oldMessage.parentNode) {
|
||||
oldMessage.parentNode.removeChild(oldMessage);
|
||||
}
|
||||
}, 300);
|
||||
}
|
||||
}
|
||||
|
||||
// 添加新消息
|
||||
newMessages.forEach(function(msg) {
|
||||
const messageDiv = createMessageElement(msg, true); // true表示是新消息
|
||||
messagesDiv.appendChild(messageDiv);
|
||||
|
||||
// 移除动画类,避免重复动画
|
||||
setTimeout(() => {
|
||||
messageDiv.classList.remove('new-message');
|
||||
}, 600);
|
||||
});
|
||||
|
||||
// 更新当前消息列表
|
||||
currentMessages = [...contexts];
|
||||
|
||||
// 平滑滚动到底部
|
||||
setTimeout(() => {
|
||||
window.scrollTo({
|
||||
top: document.body.scrollHeight,
|
||||
behavior: 'smooth'
|
||||
});
|
||||
}, 100);
|
||||
}
|
||||
}
|
||||
|
||||
function findNewMessages(contexts, currentMessages) {
|
||||
// 如果当前消息为空,所有消息都是新的
|
||||
if (currentMessages.length === 0) {
|
||||
return contexts;
|
||||
}
|
||||
|
||||
// 找到最后一条当前消息在新消息列表中的位置
|
||||
const lastCurrentMsg = currentMessages[currentMessages.length - 1];
|
||||
let lastIndex = -1;
|
||||
|
||||
// 从后往前找,因为新消息通常在末尾
|
||||
for (let i = contexts.length - 1; i >= 0; i--) {
|
||||
const msg = contexts[i];
|
||||
if (msg.user_id === lastCurrentMsg.user_id &&
|
||||
msg.content === lastCurrentMsg.content &&
|
||||
msg.timestamp === lastCurrentMsg.timestamp) {
|
||||
lastIndex = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// 如果找到了,返回之后的消息;否则返回所有消息(可能是完全刷新)
|
||||
if (lastIndex >= 0) {
|
||||
return contexts.slice(lastIndex + 1);
|
||||
} else {
|
||||
console.log('未找到匹配的最后消息,可能需要完全刷新');
|
||||
return contexts.slice(Math.max(0, contexts.length - (currentMessages.length + 1)));
|
||||
}
|
||||
}
|
||||
|
||||
function createMessageElement(msg, isNew = false) {
|
||||
const messageDiv = document.createElement('div');
|
||||
let className = 'message';
|
||||
|
||||
// 根据消息类型添加对应的CSS类
|
||||
if (msg.is_gift) {
|
||||
className += ' gift';
|
||||
} else if (msg.is_superchat) {
|
||||
className += ' superchat';
|
||||
}
|
||||
|
||||
if (isNew) {
|
||||
className += ' new-message';
|
||||
}
|
||||
|
||||
messageDiv.className = className;
|
||||
messageDiv.innerHTML = `
|
||||
<div class="message-line">
|
||||
<span class="username">${escapeHtml(msg.user_name)}:</span><span class="content">${escapeHtml(msg.content)}</span>
|
||||
</div>
|
||||
`;
|
||||
return messageDiv;
|
||||
}
|
||||
|
||||
function escapeHtml(text) {
|
||||
const div = document.createElement('div');
|
||||
div.textContent = text;
|
||||
return div.innerHTML;
|
||||
}
|
||||
|
||||
// 初始加载数据
|
||||
fetch('/api/contexts')
|
||||
.then(response => response.json())
|
||||
.then(data => {
|
||||
console.log('初始数据加载成功:', data);
|
||||
updateMessages(data.contexts);
|
||||
})
|
||||
.catch(err => console.error('加载初始数据失败:', err));
|
||||
|
||||
// 连接WebSocket
|
||||
connectWebSocket();
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
'''
|
||||
return web.Response(text=html_content, content_type='text/html')
|
||||
|
||||
async def websocket_handler(self, request):
|
||||
"""WebSocket处理器"""
|
||||
ws = web.WebSocketResponse()
|
||||
await ws.prepare(request)
|
||||
|
||||
self.websockets.append(ws)
|
||||
logger.debug(f"WebSocket连接建立,当前连接数: {len(self.websockets)}")
|
||||
|
||||
# 发送初始数据
|
||||
await self.send_contexts_to_websocket(ws)
|
||||
|
||||
async for msg in ws:
|
||||
if msg.type == WSMsgType.ERROR:
|
||||
logger.error(f'WebSocket错误: {ws.exception()}')
|
||||
break
|
||||
|
||||
# 清理断开的连接
|
||||
if ws in self.websockets:
|
||||
self.websockets.remove(ws)
|
||||
logger.debug(f"WebSocket连接断开,当前连接数: {len(self.websockets)}")
|
||||
|
||||
return ws
|
||||
|
||||
async def get_contexts_handler(self, request):
|
||||
"""获取上下文API"""
|
||||
all_context_msgs = []
|
||||
for _chat_id, contexts in self.contexts.items():
|
||||
all_context_msgs.extend(list(contexts))
|
||||
|
||||
# 按时间排序,最新的在最后
|
||||
all_context_msgs.sort(key=lambda x: x.timestamp)
|
||||
|
||||
# 转换为字典格式
|
||||
contexts_data = [msg.to_dict() for msg in all_context_msgs[-self.max_messages:]]
|
||||
|
||||
logger.debug(f"返回上下文数据,共 {len(contexts_data)} 条消息")
|
||||
return web.json_response({"contexts": contexts_data})
|
||||
|
||||
async def debug_handler(self, request):
|
||||
"""调试信息处理器"""
|
||||
debug_info = {
|
||||
"server_status": "running",
|
||||
"websocket_connections": len(self.websockets),
|
||||
"total_chats": len(self.contexts),
|
||||
"total_messages": sum(len(contexts) for contexts in self.contexts.values()),
|
||||
}
|
||||
|
||||
# 构建聊天详情HTML
|
||||
chats_html = ""
|
||||
for chat_id, contexts in self.contexts.items():
|
||||
messages_html = ""
|
||||
for msg in contexts:
|
||||
timestamp = msg.timestamp.strftime("%H:%M:%S")
|
||||
content = msg.content[:50] + "..." if len(msg.content) > 50 else msg.content
|
||||
messages_html += f'<div class="message">[{timestamp}] {msg.user_name}: {content}</div>'
|
||||
|
||||
chats_html += f'''
|
||||
<div class="chat">
|
||||
<h3>聊天 {chat_id} ({len(contexts)} 条消息)</h3>
|
||||
{messages_html}
|
||||
</div>
|
||||
'''
|
||||
|
||||
html_content = f'''
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<title>调试信息</title>
|
||||
<style>
|
||||
body {{ font-family: monospace; margin: 20px; }}
|
||||
.section {{ margin: 20px 0; padding: 10px; border: 1px solid #ccc; }}
|
||||
.chat {{ margin: 10px 0; padding: 10px; background: #f5f5f5; }}
|
||||
.message {{ margin: 5px 0; padding: 5px; background: white; }}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<h1>上下文网页管理器调试信息</h1>
|
||||
|
||||
<div class="section">
|
||||
<h2>服务器状态</h2>
|
||||
<p>状态: {debug_info["server_status"]}</p>
|
||||
<p>WebSocket连接数: {debug_info["websocket_connections"]}</p>
|
||||
<p>聊天总数: {debug_info["total_chats"]}</p>
|
||||
<p>消息总数: {debug_info["total_messages"]}</p>
|
||||
</div>
|
||||
|
||||
<div class="section">
|
||||
<h2>聊天详情</h2>
|
||||
{chats_html}
|
||||
</div>
|
||||
|
||||
<div class="section">
|
||||
<h2>操作</h2>
|
||||
<button onclick="location.reload()">刷新页面</button>
|
||||
<button onclick="window.location.href='/'">返回主页</button>
|
||||
<button onclick="window.location.href='/api/contexts'">查看API数据</button>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
console.log('调试信息:', {json.dumps(debug_info, ensure_ascii=False, indent=2)});
|
||||
setTimeout(() => location.reload(), 5000); // 5秒自动刷新
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
'''
|
||||
|
||||
return web.Response(text=html_content, content_type='text/html')
|
||||
|
||||
async def add_message(self, chat_id: str, message: MessageRecv):
|
||||
"""添加新消息到上下文"""
|
||||
if chat_id not in self.contexts:
|
||||
self.contexts[chat_id] = deque(maxlen=self.max_messages)
|
||||
logger.debug(f"为聊天 {chat_id} 创建新的上下文队列")
|
||||
|
||||
context_msg = ContextMessage(message)
|
||||
self.contexts[chat_id].append(context_msg)
|
||||
|
||||
# 统计当前总消息数
|
||||
total_messages = sum(len(contexts) for contexts in self.contexts.values())
|
||||
|
||||
logger.info(f"✅ 添加消息到上下文 [总数: {total_messages}]: [{context_msg.group_name}] {context_msg.user_name}: {context_msg.content}")
|
||||
|
||||
# 调试:打印当前所有消息
|
||||
logger.info("📝 当前上下文中的所有消息:")
|
||||
for cid, contexts in self.contexts.items():
|
||||
logger.info(f" 聊天 {cid}: {len(contexts)} 条消息")
|
||||
for i, msg in enumerate(contexts):
|
||||
logger.info(f" {i+1}. [{msg.timestamp.strftime('%H:%M:%S')}] {msg.user_name}: {msg.content[:30]}...")
|
||||
|
||||
# 广播更新给所有WebSocket连接
|
||||
await self.broadcast_contexts()
|
||||
|
||||
async def send_contexts_to_websocket(self, ws: web.WebSocketResponse):
|
||||
"""向单个WebSocket发送上下文数据"""
|
||||
all_context_msgs = []
|
||||
for _chat_id, contexts in self.contexts.items():
|
||||
all_context_msgs.extend(list(contexts))
|
||||
|
||||
# 按时间排序,最新的在最后
|
||||
all_context_msgs.sort(key=lambda x: x.timestamp)
|
||||
|
||||
# 转换为字典格式
|
||||
contexts_data = [msg.to_dict() for msg in all_context_msgs[-self.max_messages:]]
|
||||
|
||||
data = {"contexts": contexts_data}
|
||||
await ws.send_str(json.dumps(data, ensure_ascii=False))
|
||||
|
||||
async def broadcast_contexts(self):
|
||||
"""向所有WebSocket连接广播上下文更新"""
|
||||
if not self.websockets:
|
||||
logger.debug("没有WebSocket连接,跳过广播")
|
||||
return
|
||||
|
||||
all_context_msgs = []
|
||||
for _chat_id, contexts in self.contexts.items():
|
||||
all_context_msgs.extend(list(contexts))
|
||||
|
||||
# 按时间排序,最新的在最后
|
||||
all_context_msgs.sort(key=lambda x: x.timestamp)
|
||||
|
||||
# 转换为字典格式
|
||||
contexts_data = [msg.to_dict() for msg in all_context_msgs[-self.max_messages:]]
|
||||
|
||||
data = {"contexts": contexts_data}
|
||||
message = json.dumps(data, ensure_ascii=False)
|
||||
|
||||
logger.info(f"广播 {len(contexts_data)} 条消息到 {len(self.websockets)} 个WebSocket连接")
|
||||
|
||||
# 创建WebSocket列表的副本,避免在遍历时修改
|
||||
websockets_copy = self.websockets.copy()
|
||||
removed_count = 0
|
||||
|
||||
for ws in websockets_copy:
|
||||
if ws.closed:
|
||||
if ws in self.websockets:
|
||||
self.websockets.remove(ws)
|
||||
removed_count += 1
|
||||
else:
|
||||
try:
|
||||
await ws.send_str(message)
|
||||
logger.debug("消息发送成功")
|
||||
except Exception as e:
|
||||
logger.error(f"发送WebSocket消息失败: {e}")
|
||||
if ws in self.websockets:
|
||||
self.websockets.remove(ws)
|
||||
removed_count += 1
|
||||
|
||||
if removed_count > 0:
|
||||
logger.debug(f"清理了 {removed_count} 个断开的WebSocket连接")
|
||||
|
||||
|
||||
# 全局实例
|
||||
_context_web_manager: Optional[ContextWebManager] = None
|
||||
|
||||
|
||||
def get_context_web_manager() -> ContextWebManager:
|
||||
"""获取上下文网页管理器实例"""
|
||||
global _context_web_manager
|
||||
if _context_web_manager is None:
|
||||
_context_web_manager = ContextWebManager()
|
||||
return _context_web_manager
|
||||
|
||||
|
||||
async def init_context_web_manager():
|
||||
"""初始化上下文网页管理器"""
|
||||
manager = get_context_web_manager()
|
||||
await manager.start_server()
|
||||
return manager
|
||||
|
||||
@@ -1,155 +0,0 @@
|
||||
import asyncio
|
||||
from typing import Dict, Tuple, Callable, Optional
|
||||
from dataclasses import dataclass
|
||||
|
||||
from src.chat.message_receive.message import MessageRecvS4U
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("gift_manager")
|
||||
|
||||
|
||||
@dataclass
|
||||
class PendingGift:
|
||||
"""等待中的礼物消息"""
|
||||
message: MessageRecvS4U
|
||||
total_count: int
|
||||
timer_task: asyncio.Task
|
||||
callback: Callable[[MessageRecvS4U], None]
|
||||
|
||||
|
||||
class GiftManager:
|
||||
"""礼物管理器,提供防抖功能"""
|
||||
|
||||
def __init__(self):
|
||||
"""初始化礼物管理器"""
|
||||
self.pending_gifts: Dict[Tuple[str, str], PendingGift] = {}
|
||||
self.debounce_timeout = 5.0 # 3秒防抖时间
|
||||
|
||||
async def handle_gift(self, message: MessageRecvS4U, callback: Optional[Callable[[MessageRecvS4U], None]] = None) -> bool:
|
||||
"""处理礼物消息,返回是否应该立即处理
|
||||
|
||||
Args:
|
||||
message: 礼物消息
|
||||
callback: 防抖完成后的回调函数
|
||||
|
||||
Returns:
|
||||
bool: False表示消息被暂存等待防抖,True表示应该立即处理
|
||||
"""
|
||||
if not message.is_gift:
|
||||
return True
|
||||
|
||||
# 构建礼物的唯一键:(发送人ID, 礼物名称)
|
||||
gift_key = (message.message_info.user_info.user_id, message.gift_name)
|
||||
|
||||
# 如果已经有相同的礼物在等待中,则合并
|
||||
if gift_key in self.pending_gifts:
|
||||
await self._merge_gift(gift_key, message)
|
||||
return False
|
||||
|
||||
# 创建新的等待礼物
|
||||
await self._create_pending_gift(gift_key, message, callback)
|
||||
return False
|
||||
|
||||
async def _merge_gift(self, gift_key: Tuple[str, str], new_message: MessageRecvS4U) -> None:
|
||||
"""合并礼物消息"""
|
||||
pending_gift = self.pending_gifts[gift_key]
|
||||
|
||||
# 取消之前的定时器
|
||||
if not pending_gift.timer_task.cancelled():
|
||||
pending_gift.timer_task.cancel()
|
||||
|
||||
# 累加礼物数量
|
||||
try:
|
||||
new_count = int(new_message.gift_count)
|
||||
pending_gift.total_count += new_count
|
||||
|
||||
# 更新消息为最新的(保留最新的消息,但累加数量)
|
||||
pending_gift.message = new_message
|
||||
pending_gift.message.gift_count = str(pending_gift.total_count)
|
||||
pending_gift.message.gift_info = f"{pending_gift.message.gift_name}:{pending_gift.total_count}"
|
||||
|
||||
except ValueError:
|
||||
logger.warning(f"无法解析礼物数量: {new_message.gift_count}")
|
||||
# 如果无法解析数量,保持原有数量不变
|
||||
|
||||
# 重新创建定时器
|
||||
pending_gift.timer_task = asyncio.create_task(
|
||||
self._gift_timeout(gift_key)
|
||||
)
|
||||
|
||||
logger.debug(f"合并礼物: {gift_key}, 总数量: {pending_gift.total_count}")
|
||||
|
||||
async def _create_pending_gift(
|
||||
self,
|
||||
gift_key: Tuple[str, str],
|
||||
message: MessageRecvS4U,
|
||||
callback: Optional[Callable[[MessageRecvS4U], None]]
|
||||
) -> None:
|
||||
"""创建新的等待礼物"""
|
||||
try:
|
||||
initial_count = int(message.gift_count)
|
||||
except ValueError:
|
||||
initial_count = 1
|
||||
logger.warning(f"无法解析礼物数量: {message.gift_count},默认设为1")
|
||||
|
||||
# 创建定时器任务
|
||||
timer_task = asyncio.create_task(self._gift_timeout(gift_key))
|
||||
|
||||
# 创建等待礼物对象
|
||||
pending_gift = PendingGift(
|
||||
message=message,
|
||||
total_count=initial_count,
|
||||
timer_task=timer_task,
|
||||
callback=callback
|
||||
)
|
||||
|
||||
self.pending_gifts[gift_key] = pending_gift
|
||||
|
||||
logger.debug(f"创建等待礼物: {gift_key}, 初始数量: {initial_count}")
|
||||
|
||||
async def _gift_timeout(self, gift_key: Tuple[str, str]) -> None:
|
||||
"""礼物防抖超时处理"""
|
||||
try:
|
||||
# 等待防抖时间
|
||||
await asyncio.sleep(self.debounce_timeout)
|
||||
|
||||
# 获取等待中的礼物
|
||||
if gift_key not in self.pending_gifts:
|
||||
return
|
||||
|
||||
pending_gift = self.pending_gifts.pop(gift_key)
|
||||
|
||||
logger.info(f"礼物防抖完成: {gift_key}, 最终数量: {pending_gift.total_count}")
|
||||
|
||||
message = pending_gift.message
|
||||
message.processed_plain_text = f"用户{message.message_info.user_info.user_nickname}送出了礼物{message.gift_name} x{pending_gift.total_count}"
|
||||
|
||||
# 执行回调
|
||||
if pending_gift.callback:
|
||||
try:
|
||||
pending_gift.callback(message)
|
||||
except Exception as e:
|
||||
logger.error(f"礼物回调执行失败: {e}", exc_info=True)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
# 定时器被取消,不需要处理
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"礼物防抖处理异常: {e}", exc_info=True)
|
||||
|
||||
def get_pending_count(self) -> int:
|
||||
"""获取当前等待中的礼物数量"""
|
||||
return len(self.pending_gifts)
|
||||
|
||||
async def flush_all(self) -> None:
|
||||
"""立即处理所有等待中的礼物"""
|
||||
for gift_key in list(self.pending_gifts.keys()):
|
||||
pending_gift = self.pending_gifts.get(gift_key)
|
||||
if pending_gift and not pending_gift.timer_task.cancelled():
|
||||
pending_gift.timer_task.cancel()
|
||||
await self._gift_timeout(gift_key)
|
||||
|
||||
|
||||
# 创建全局礼物管理器实例
|
||||
gift_manager = GiftManager()
|
||||
|
||||
@@ -1,14 +0,0 @@
|
||||
class InternalManager:
|
||||
def __init__(self):
|
||||
self.now_internal_state = str()
|
||||
|
||||
def set_internal_state(self,internal_state:str):
|
||||
self.now_internal_state = internal_state
|
||||
|
||||
def get_internal_state(self):
|
||||
return self.now_internal_state
|
||||
|
||||
def get_internal_state_str(self):
|
||||
return f"你今天的直播内容是直播QQ水群,你正在一边回复弹幕,一边在QQ群聊天,你在QQ群聊天中产生的想法是:{self.now_internal_state}"
|
||||
|
||||
internal_manager = InternalManager()
|
||||
@@ -1,579 +0,0 @@
|
||||
import asyncio
|
||||
import traceback
|
||||
import time
|
||||
import random
|
||||
from typing import Optional, Dict, Tuple, List # 导入类型提示
|
||||
from maim_message import UserInfo, Seg
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
||||
from .s4u_stream_generator import S4UStreamGenerator
|
||||
from src.chat.message_receive.message import MessageSending, MessageRecv, MessageRecvS4U
|
||||
from src.config.config import global_config
|
||||
from src.common.message.api import get_global_api
|
||||
from src.chat.message_receive.storage import MessageStorage
|
||||
from .s4u_watching_manager import watching_manager
|
||||
import json
|
||||
from .s4u_mood_manager import mood_manager
|
||||
from src.mais4u.s4u_config import s4u_config
|
||||
from src.person_info.person_info import get_person_id
|
||||
from .super_chat_manager import get_super_chat_manager
|
||||
from .yes_or_no import yes_or_no_head
|
||||
|
||||
logger = get_logger("S4U_chat")
|
||||
|
||||
|
||||
class MessageSenderContainer:
|
||||
"""一个简单的容器,用于按顺序发送消息并模拟打字效果。"""
|
||||
|
||||
def __init__(self, chat_stream: ChatStream, original_message: MessageRecv):
|
||||
self.chat_stream = chat_stream
|
||||
self.original_message = original_message
|
||||
self.queue = asyncio.Queue()
|
||||
self.storage = MessageStorage()
|
||||
self._task: Optional[asyncio.Task] = None
|
||||
self._paused_event = asyncio.Event()
|
||||
self._paused_event.set() # 默认设置为非暂停状态
|
||||
|
||||
self.msg_id = ""
|
||||
|
||||
self.last_msg_id = ""
|
||||
|
||||
self.voice_done = ""
|
||||
|
||||
|
||||
|
||||
|
||||
async def add_message(self, chunk: str):
|
||||
"""向队列中添加一个消息块。"""
|
||||
await self.queue.put(chunk)
|
||||
|
||||
async def close(self):
|
||||
"""表示没有更多消息了,关闭队列。"""
|
||||
await self.queue.put(None) # Sentinel
|
||||
|
||||
def pause(self):
|
||||
"""暂停发送。"""
|
||||
self._paused_event.clear()
|
||||
|
||||
def resume(self):
|
||||
"""恢复发送。"""
|
||||
self._paused_event.set()
|
||||
|
||||
def _calculate_typing_delay(self, text: str) -> float:
|
||||
"""根据文本长度计算模拟打字延迟。"""
|
||||
chars_per_second = s4u_config.chars_per_second
|
||||
min_delay = s4u_config.min_typing_delay
|
||||
max_delay = s4u_config.max_typing_delay
|
||||
|
||||
delay = len(text) / chars_per_second
|
||||
return max(min_delay, min(delay, max_delay))
|
||||
|
||||
async def _send_worker(self):
|
||||
"""从队列中取出消息并发送。"""
|
||||
while True:
|
||||
try:
|
||||
# This structure ensures that task_done() is called for every item retrieved,
|
||||
# even if the worker is cancelled while processing the item.
|
||||
chunk = await self.queue.get()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
|
||||
try:
|
||||
if chunk is None:
|
||||
break
|
||||
|
||||
# Check for pause signal *after* getting an item.
|
||||
await self._paused_event.wait()
|
||||
|
||||
# 根据配置选择延迟模式
|
||||
if s4u_config.enable_dynamic_typing_delay:
|
||||
delay = self._calculate_typing_delay(chunk)
|
||||
else:
|
||||
delay = s4u_config.typing_delay
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
message_segment = Seg(type="tts_text", data=f"{self.msg_id}:{chunk}")
|
||||
bot_message = MessageSending(
|
||||
message_id=self.msg_id,
|
||||
chat_stream=self.chat_stream,
|
||||
bot_user_info=UserInfo(
|
||||
user_id=global_config.bot.qq_account,
|
||||
user_nickname=global_config.bot.nickname,
|
||||
platform=self.original_message.message_info.platform,
|
||||
),
|
||||
sender_info=self.original_message.message_info.user_info,
|
||||
message_segment=message_segment,
|
||||
reply=self.original_message,
|
||||
is_emoji=False,
|
||||
apply_set_reply_logic=True,
|
||||
reply_to=f"{self.original_message.message_info.user_info.platform}:{self.original_message.message_info.user_info.user_id}",
|
||||
)
|
||||
|
||||
await bot_message.process()
|
||||
|
||||
await get_global_api().send_message(bot_message)
|
||||
logger.info(f"已将消息 '{self.msg_id}:{chunk}' 发往平台 '{bot_message.message_info.platform}'")
|
||||
|
||||
message_segment = Seg(type="text", data=chunk)
|
||||
bot_message = MessageSending(
|
||||
message_id=self.msg_id,
|
||||
chat_stream=self.chat_stream,
|
||||
bot_user_info=UserInfo(
|
||||
user_id=global_config.bot.qq_account,
|
||||
user_nickname=global_config.bot.nickname,
|
||||
platform=self.original_message.message_info.platform,
|
||||
),
|
||||
sender_info=self.original_message.message_info.user_info,
|
||||
message_segment=message_segment,
|
||||
reply=self.original_message,
|
||||
is_emoji=False,
|
||||
apply_set_reply_logic=True,
|
||||
reply_to=f"{self.original_message.message_info.user_info.platform}:{self.original_message.message_info.user_info.user_id}",
|
||||
)
|
||||
await bot_message.process()
|
||||
|
||||
await self.storage.store_message(bot_message, self.chat_stream)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[消息流: {self.chat_stream.stream_id}] 消息发送或存储时出现错误: {e}", exc_info=True)
|
||||
|
||||
finally:
|
||||
# CRUCIAL: Always call task_done() for any item that was successfully retrieved.
|
||||
self.queue.task_done()
|
||||
|
||||
def start(self):
|
||||
"""启动发送任务。"""
|
||||
if self._task is None:
|
||||
self._task = asyncio.create_task(self._send_worker())
|
||||
|
||||
async def join(self):
|
||||
"""等待所有消息发送完毕。"""
|
||||
if self._task:
|
||||
await self._task
|
||||
|
||||
|
||||
class S4UChatManager:
|
||||
def __init__(self):
|
||||
self.s4u_chats: Dict[str, "S4UChat"] = {}
|
||||
|
||||
def get_or_create_chat(self, chat_stream: ChatStream) -> "S4UChat":
|
||||
if chat_stream.stream_id not in self.s4u_chats:
|
||||
stream_name = get_chat_manager().get_stream_name(chat_stream.stream_id) or chat_stream.stream_id
|
||||
logger.info(f"Creating new S4UChat for stream: {stream_name}")
|
||||
self.s4u_chats[chat_stream.stream_id] = S4UChat(chat_stream)
|
||||
return self.s4u_chats[chat_stream.stream_id]
|
||||
|
||||
|
||||
if not s4u_config.enable_s4u:
|
||||
s4u_chat_manager = None
|
||||
else:
|
||||
s4u_chat_manager = S4UChatManager()
|
||||
|
||||
|
||||
def get_s4u_chat_manager() -> S4UChatManager:
|
||||
return s4u_chat_manager
|
||||
|
||||
|
||||
class S4UChat:
|
||||
def __init__(self, chat_stream: ChatStream):
|
||||
"""初始化 S4UChat 实例。"""
|
||||
|
||||
self.chat_stream = chat_stream
|
||||
self.stream_id = chat_stream.stream_id
|
||||
self.stream_name = get_chat_manager().get_stream_name(self.stream_id) or self.stream_id
|
||||
|
||||
# 两个消息队列
|
||||
self._vip_queue = asyncio.PriorityQueue()
|
||||
self._normal_queue = asyncio.PriorityQueue()
|
||||
|
||||
self._entry_counter = 0 # 保证FIFO的全局计数器
|
||||
self._new_message_event = asyncio.Event() # 用于唤醒处理器
|
||||
|
||||
self._processing_task = asyncio.create_task(self._message_processor())
|
||||
self._current_generation_task: Optional[asyncio.Task] = None
|
||||
# 当前消息的元数据:(队列类型, 优先级分数, 计数器, 消息对象)
|
||||
self._current_message_being_replied: Optional[Tuple[str, float, int, MessageRecv]] = None
|
||||
|
||||
self._is_replying = False
|
||||
self.gpt = S4UStreamGenerator()
|
||||
self.gpt.chat_stream = self.chat_stream
|
||||
self.interest_dict: Dict[str, float] = {} # 用户兴趣分
|
||||
|
||||
self.internal_message :List[MessageRecvS4U] = []
|
||||
|
||||
self.msg_id = ""
|
||||
self.voice_done = ""
|
||||
|
||||
logger.info(f"[{self.stream_name}] S4UChat with two-queue system initialized.")
|
||||
|
||||
def _get_priority_info(self, message: MessageRecv) -> dict:
|
||||
"""安全地从消息中提取和解析 priority_info"""
|
||||
priority_info_raw = message.priority_info
|
||||
priority_info = {}
|
||||
if isinstance(priority_info_raw, str):
|
||||
try:
|
||||
priority_info = json.loads(priority_info_raw)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Failed to parse priority_info JSON: {priority_info_raw}")
|
||||
elif isinstance(priority_info_raw, dict):
|
||||
priority_info = priority_info_raw
|
||||
return priority_info
|
||||
|
||||
def _is_vip(self, priority_info: dict) -> bool:
|
||||
"""检查消息是否来自VIP用户。"""
|
||||
return priority_info.get("message_type") == "vip"
|
||||
|
||||
def _get_interest_score(self, user_id: str) -> float:
|
||||
"""获取用户的兴趣分,默认为1.0"""
|
||||
return self.interest_dict.get(user_id, 1.0)
|
||||
|
||||
def go_processing(self):
|
||||
if self.voice_done == self.last_msg_id:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _calculate_base_priority_score(self, message: MessageRecv, priority_info: dict) -> float:
|
||||
"""
|
||||
为消息计算基础优先级分数。分数越高,优先级越高。
|
||||
"""
|
||||
score = 0.0
|
||||
|
||||
# 加上消息自带的优先级
|
||||
score += priority_info.get("message_priority", 0.0)
|
||||
|
||||
# 加上用户的固有兴趣分
|
||||
score += self._get_interest_score(message.message_info.user_info.user_id)
|
||||
return score
|
||||
|
||||
def decay_interest_score(self):
|
||||
for person_id, score in self.interest_dict.items():
|
||||
if score > 0:
|
||||
self.interest_dict[person_id] = score * 0.95
|
||||
else:
|
||||
self.interest_dict[person_id] = 0
|
||||
|
||||
async def add_message(self, message: MessageRecvS4U|MessageRecv) -> None:
|
||||
|
||||
self.decay_interest_score()
|
||||
|
||||
"""根据VIP状态和中断逻辑将消息放入相应队列。"""
|
||||
user_id = message.message_info.user_info.user_id
|
||||
platform = message.message_info.platform
|
||||
person_id = get_person_id(platform, user_id)
|
||||
|
||||
# try:
|
||||
# is_gift = message.is_gift
|
||||
# is_superchat = message.is_superchat
|
||||
# # print(is_gift)
|
||||
# # print(is_superchat)
|
||||
# if is_gift:
|
||||
# await self.relationship_builder.build_relation(immediate_build=person_id)
|
||||
# # 安全地增加兴趣分,如果person_id不存在则先初始化为1.0
|
||||
# current_score = self.interest_dict.get(person_id, 1.0)
|
||||
# self.interest_dict[person_id] = current_score + 0.1 * message.gift_count
|
||||
# elif is_superchat:
|
||||
# await self.relationship_builder.build_relation(immediate_build=person_id)
|
||||
# # 安全地增加兴趣分,如果person_id不存在则先初始化为1.0
|
||||
# current_score = self.interest_dict.get(person_id, 1.0)
|
||||
# self.interest_dict[person_id] = current_score + 0.1 * float(message.superchat_price)
|
||||
|
||||
# # 添加SuperChat到管理器
|
||||
# super_chat_manager = get_super_chat_manager()
|
||||
# await super_chat_manager.add_superchat(message)
|
||||
# else:
|
||||
# await self.relationship_builder.build_relation(20)
|
||||
# except Exception:
|
||||
# traceback.print_exc()
|
||||
|
||||
logger.info(f"[{self.stream_name}] 消息处理完毕,消息内容:{message.processed_plain_text}")
|
||||
|
||||
priority_info = self._get_priority_info(message)
|
||||
is_vip = self._is_vip(priority_info)
|
||||
new_priority_score = self._calculate_base_priority_score(message, priority_info)
|
||||
|
||||
should_interrupt = False
|
||||
if (s4u_config.enable_message_interruption and
|
||||
self._current_generation_task and not self._current_generation_task.done()):
|
||||
if self._current_message_being_replied:
|
||||
current_queue, current_priority, _, current_msg = self._current_message_being_replied
|
||||
|
||||
# 规则:VIP从不被打断
|
||||
if current_queue == "vip":
|
||||
pass # Do nothing
|
||||
|
||||
# 规则:普通消息可以被打断
|
||||
elif current_queue == "normal":
|
||||
# VIP消息可以打断普通消息
|
||||
if is_vip:
|
||||
should_interrupt = True
|
||||
logger.info(f"[{self.stream_name}] VIP message received, interrupting current normal task.")
|
||||
# 普通消息的内部打断逻辑
|
||||
else:
|
||||
new_sender_id = message.message_info.user_info.user_id
|
||||
current_sender_id = current_msg.message_info.user_info.user_id
|
||||
# 新消息优先级更高
|
||||
if new_priority_score > current_priority:
|
||||
should_interrupt = True
|
||||
logger.info(f"[{self.stream_name}] New normal message has higher priority, interrupting.")
|
||||
# 同用户,新消息的优先级不能更低
|
||||
elif new_sender_id == current_sender_id and new_priority_score >= current_priority:
|
||||
should_interrupt = True
|
||||
logger.info(f"[{self.stream_name}] Same user sent new message, interrupting.")
|
||||
|
||||
if should_interrupt:
|
||||
if self.gpt.partial_response:
|
||||
logger.warning(
|
||||
f"[{self.stream_name}] Interrupting reply. Already generated: '{self.gpt.partial_response}'"
|
||||
)
|
||||
self._current_generation_task.cancel()
|
||||
|
||||
# asyncio.PriorityQueue 是最小堆,所以我们存入分数的相反数
|
||||
# 这样,原始分数越高的消息,在队列中的优先级数字越小,越靠前
|
||||
item = (-new_priority_score, self._entry_counter, time.time(), message)
|
||||
|
||||
if is_vip and s4u_config.vip_queue_priority:
|
||||
await self._vip_queue.put(item)
|
||||
logger.info(f"[{self.stream_name}] VIP message added to queue.")
|
||||
else:
|
||||
await self._normal_queue.put(item)
|
||||
|
||||
self._entry_counter += 1
|
||||
self._new_message_event.set() # 唤醒处理器
|
||||
|
||||
def _cleanup_old_normal_messages(self):
|
||||
"""清理普通队列中不在最近N条消息范围内的消息"""
|
||||
if not s4u_config.enable_old_message_cleanup or self._normal_queue.empty():
|
||||
return
|
||||
|
||||
# 计算阈值:保留最近 recent_message_keep_count 条消息
|
||||
cutoff_counter = max(0, self._entry_counter - s4u_config.recent_message_keep_count)
|
||||
|
||||
# 临时存储需要保留的消息
|
||||
temp_messages = []
|
||||
removed_count = 0
|
||||
|
||||
# 取出所有普通队列中的消息
|
||||
while not self._normal_queue.empty():
|
||||
try:
|
||||
item = self._normal_queue.get_nowait()
|
||||
neg_priority, entry_count, timestamp, message = item
|
||||
|
||||
# 如果消息在最近N条消息范围内,保留它
|
||||
logger.info(f"检查消息:{message.processed_plain_text},entry_count:{entry_count} cutoff_counter:{cutoff_counter}")
|
||||
|
||||
if entry_count >= cutoff_counter:
|
||||
temp_messages.append(item)
|
||||
else:
|
||||
removed_count += 1
|
||||
self._normal_queue.task_done() # 标记被移除的任务为完成
|
||||
|
||||
except asyncio.QueueEmpty:
|
||||
break
|
||||
|
||||
# 将保留的消息重新放入队列
|
||||
for item in temp_messages:
|
||||
self._normal_queue.put_nowait(item)
|
||||
|
||||
if removed_count > 0:
|
||||
logger.info(f"消息{message.processed_plain_text}超过{s4u_config.recent_message_keep_count}条,现在counter:{self._entry_counter}被移除")
|
||||
logger.info(f"[{self.stream_name}] Cleaned up {removed_count} old normal messages outside recent {s4u_config.recent_message_keep_count} range.")
|
||||
|
||||
async def _message_processor(self):
|
||||
"""调度器:优先处理VIP队列,然后处理普通队列。"""
|
||||
while True:
|
||||
try:
|
||||
# 等待有新消息的信号,避免空转
|
||||
await self._new_message_event.wait()
|
||||
self._new_message_event.clear()
|
||||
|
||||
# 清理普通队列中的过旧消息
|
||||
self._cleanup_old_normal_messages()
|
||||
|
||||
# 优先处理VIP队列
|
||||
if not self._vip_queue.empty():
|
||||
neg_priority, entry_count, _, message = self._vip_queue.get_nowait()
|
||||
priority = -neg_priority
|
||||
queue_name = "vip"
|
||||
# 其次处理普通队列
|
||||
elif not self._normal_queue.empty():
|
||||
|
||||
neg_priority, entry_count, timestamp, message = self._normal_queue.get_nowait()
|
||||
priority = -neg_priority
|
||||
# 检查普通消息是否超时
|
||||
if time.time() - timestamp > s4u_config.message_timeout_seconds:
|
||||
logger.info(
|
||||
f"[{self.stream_name}] Discarding stale normal message: {message.processed_plain_text[:20]}..."
|
||||
)
|
||||
self._normal_queue.task_done()
|
||||
continue # 处理下一条
|
||||
queue_name = "normal"
|
||||
else:
|
||||
if self.internal_message:
|
||||
message = self.internal_message[-1]
|
||||
self.internal_message = []
|
||||
|
||||
priority = 0
|
||||
neg_priority = 0
|
||||
entry_count = 0
|
||||
queue_name = "internal"
|
||||
|
||||
logger.info(f"[{self.stream_name}] normal/vip 队列都空,触发 internal_message 回复: {getattr(message, 'processed_plain_text', str(message))[:20]}...")
|
||||
else:
|
||||
continue # 没有消息了,回去等事件
|
||||
|
||||
self._current_message_being_replied = (queue_name, priority, entry_count, message)
|
||||
self._current_generation_task = asyncio.create_task(self._generate_and_send(message))
|
||||
|
||||
try:
|
||||
await self._current_generation_task
|
||||
except asyncio.CancelledError:
|
||||
logger.info(
|
||||
f"[{self.stream_name}] Reply generation was interrupted externally for {queue_name} message. The message will be discarded."
|
||||
)
|
||||
# 被中断的消息应该被丢弃,而不是重新排队,以响应最新的用户输入。
|
||||
# 旧的重新入队逻辑会导致所有中断的消息最终都被回复。
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[{self.stream_name}] _generate_and_send task error: {e}", exc_info=True)
|
||||
finally:
|
||||
self._current_generation_task = None
|
||||
self._current_message_being_replied = None
|
||||
# 标记任务完成
|
||||
if queue_name == "vip":
|
||||
self._vip_queue.task_done()
|
||||
elif queue_name == "internal":
|
||||
# 如果使用 internal_message 生成回复,则不从 normal 队列中移除
|
||||
pass
|
||||
else:
|
||||
self._normal_queue.task_done()
|
||||
|
||||
# 检查是否还有任务,有则立即再次触发事件
|
||||
if not self._vip_queue.empty() or not self._normal_queue.empty():
|
||||
self._new_message_event.set()
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"[{self.stream_name}] Message processor is shutting down.")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"[{self.stream_name}] Message processor main loop error: {e}", exc_info=True)
|
||||
await asyncio.sleep(1)
|
||||
|
||||
|
||||
def get_processing_message_id(self):
|
||||
self.last_msg_id = self.msg_id
|
||||
self.msg_id = f"{time.time()}_{random.randint(1000, 9999)}"
|
||||
|
||||
|
||||
async def _generate_and_send(self, message: MessageRecv):
|
||||
"""为单个消息生成文本回复。整个过程可以被中断。"""
|
||||
self._is_replying = True
|
||||
total_chars_sent = 0 # 跟踪发送的总字符数
|
||||
|
||||
self.get_processing_message_id()
|
||||
|
||||
# 视线管理:开始生成回复时切换视线状态
|
||||
chat_watching = watching_manager.get_watching_by_chat_id(self.stream_id)
|
||||
|
||||
if message.is_internal:
|
||||
await chat_watching.on_internal_message_start()
|
||||
else:
|
||||
await chat_watching.on_reply_start()
|
||||
|
||||
sender_container = MessageSenderContainer(self.chat_stream, message)
|
||||
sender_container.start()
|
||||
|
||||
async def generate_and_send_inner():
|
||||
nonlocal total_chars_sent
|
||||
logger.info(f"[S4U] 开始为消息生成文本和音频流: '{message.processed_plain_text[:30]}...'")
|
||||
|
||||
if s4u_config.enable_streaming_output:
|
||||
logger.info("[S4U] 开始流式输出")
|
||||
# 流式输出,边生成边发送
|
||||
gen = self.gpt.generate_response(message, "")
|
||||
async for chunk in gen:
|
||||
sender_container.msg_id = self.msg_id
|
||||
await sender_container.add_message(chunk)
|
||||
total_chars_sent += len(chunk)
|
||||
else:
|
||||
logger.info("[S4U] 开始一次性输出")
|
||||
# 一次性输出,先收集所有chunk
|
||||
all_chunks = []
|
||||
gen = self.gpt.generate_response(message, "")
|
||||
async for chunk in gen:
|
||||
all_chunks.append(chunk)
|
||||
total_chars_sent += len(chunk)
|
||||
# 一次性发送
|
||||
sender_container.msg_id = self.msg_id
|
||||
await sender_container.add_message("".join(all_chunks))
|
||||
|
||||
try:
|
||||
try:
|
||||
await asyncio.wait_for(generate_and_send_inner(), timeout=10)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"[{self.stream_name}] 回复生成超时,发送默认回复。")
|
||||
sender_container.msg_id = self.msg_id
|
||||
await sender_container.add_message("麦麦不知道哦")
|
||||
total_chars_sent = len("麦麦不知道哦")
|
||||
|
||||
mood = mood_manager.get_mood_by_chat_id(self.stream_id)
|
||||
await yes_or_no_head(text = total_chars_sent,emotion = mood.mood_state,chat_history=message.processed_plain_text,chat_id=self.stream_id)
|
||||
|
||||
# 等待所有文本消息发送完成
|
||||
await sender_container.close()
|
||||
await sender_container.join()
|
||||
|
||||
await chat_watching.on_thinking_finished()
|
||||
|
||||
|
||||
|
||||
start_time = time.time()
|
||||
logged = False
|
||||
while not self.go_processing():
|
||||
if time.time() - start_time > 60:
|
||||
logger.warning(f"[{self.stream_name}] 等待消息发送超时(60秒),强制跳出循环。")
|
||||
break
|
||||
if not logged:
|
||||
logger.info(f"[{self.stream_name}] 等待消息发送完成...")
|
||||
logged = True
|
||||
await asyncio.sleep(0.2)
|
||||
|
||||
logger.info(f"[{self.stream_name}] 所有文本块处理完毕。")
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"[{self.stream_name}] 回复流程(文本)被中断。")
|
||||
raise # 将取消异常向上传播
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
logger.error(f"[{self.stream_name}] 回复生成过程中出现错误: {e}", exc_info=True)
|
||||
# 回复生成实时展示:清空内容(出错时)
|
||||
finally:
|
||||
self._is_replying = False
|
||||
|
||||
# 视线管理:回复结束时切换视线状态
|
||||
chat_watching = watching_manager.get_watching_by_chat_id(self.stream_id)
|
||||
await chat_watching.on_reply_finished()
|
||||
|
||||
# 确保发送器被妥善关闭(即使已关闭,再次调用也是安全的)
|
||||
sender_container.resume()
|
||||
if not sender_container._task.done():
|
||||
await sender_container.close()
|
||||
await sender_container.join()
|
||||
logger.info(f"[{self.stream_name}] _generate_and_send 任务结束,资源已清理。")
|
||||
|
||||
async def shutdown(self):
|
||||
"""平滑关闭处理任务。"""
|
||||
logger.info(f"正在关闭 S4UChat: {self.stream_name}")
|
||||
|
||||
# 取消正在运行的任务
|
||||
if self._current_generation_task and not self._current_generation_task.done():
|
||||
self._current_generation_task.cancel()
|
||||
|
||||
if self._processing_task and not self._processing_task.done():
|
||||
self._processing_task.cancel()
|
||||
|
||||
# 等待任务响应取消
|
||||
try:
|
||||
await self._processing_task
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"处理任务已成功取消: {self.stream_name}")
|
||||
|
||||
@@ -1,456 +0,0 @@
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
|
||||
from src.chat.message_receive.message import MessageRecv
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive
|
||||
from src.config.config import global_config, model_config
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.manager.async_task_manager import AsyncTask, async_task_manager
|
||||
from src.plugin_system.apis import send_api
|
||||
from src.mais4u.s4u_config import s4u_config
|
||||
|
||||
"""
|
||||
情绪管理系统使用说明:
|
||||
|
||||
1. 情绪数值系统:
|
||||
- 情绪包含四个维度:joy(喜), anger(怒), sorrow(哀), fear(惧)
|
||||
- 每个维度的取值范围为1-10
|
||||
- 当情绪发生变化时,会自动发送到ws端处理
|
||||
|
||||
2. 情绪更新机制:
|
||||
- 接收到新消息时会更新情绪状态
|
||||
- 定期进行情绪回归(冷静下来)
|
||||
- 每次情绪变化都会发送到ws端,格式为:
|
||||
type: "emotion"
|
||||
data: {"joy": 5, "anger": 1, "sorrow": 1, "fear": 1}
|
||||
|
||||
3. ws端处理:
|
||||
- 本地只负责情绪计算和发送情绪数值
|
||||
- 表情渲染和动作由ws端根据情绪数值处理
|
||||
"""
|
||||
|
||||
logger = get_logger("mood")
|
||||
|
||||
|
||||
def init_prompt():
|
||||
Prompt(
|
||||
"""
|
||||
{chat_talking_prompt}
|
||||
以上是直播间里正在进行的对话
|
||||
|
||||
{indentify_block}
|
||||
你刚刚的情绪状态是:{mood_state}
|
||||
|
||||
现在,发送了消息,引起了你的注意,你对其进行了阅读和思考,请你输出一句话描述你新的情绪状态,不要输出任何其他内容
|
||||
请只输出情绪状态,不要输出其他内容:
|
||||
""",
|
||||
"change_mood_prompt_vtb",
|
||||
)
|
||||
Prompt(
|
||||
"""
|
||||
{chat_talking_prompt}
|
||||
以上是直播间里最近的对话
|
||||
|
||||
{indentify_block}
|
||||
你之前的情绪状态是:{mood_state}
|
||||
|
||||
距离你上次关注直播间消息已经过去了一段时间,你冷静了下来,请你输出一句话描述你现在的情绪状态
|
||||
请只输出情绪状态,不要输出其他内容:
|
||||
""",
|
||||
"regress_mood_prompt_vtb",
|
||||
)
|
||||
Prompt(
|
||||
"""
|
||||
{chat_talking_prompt}
|
||||
以上是直播间里正在进行的对话
|
||||
|
||||
{indentify_block}
|
||||
你刚刚的情绪状态是:{mood_state}
|
||||
具体来说,从1-10分,你的情绪状态是:
|
||||
喜(Joy): {joy}
|
||||
怒(Anger): {anger}
|
||||
哀(Sorrow): {sorrow}
|
||||
惧(Fear): {fear}
|
||||
|
||||
现在,发送了消息,引起了你的注意,你对其进行了阅读和思考。请基于对话内容,评估你新的情绪状态。
|
||||
请以JSON格式输出你新的情绪状态,包含"喜怒哀惧"四个维度,每个维度的取值范围为1-10。
|
||||
键值请使用英文: "joy", "anger", "sorrow", "fear".
|
||||
例如: {{"joy": 5, "anger": 1, "sorrow": 1, "fear": 1}}
|
||||
不要输出任何其他内容,只输出JSON。
|
||||
""",
|
||||
"change_mood_numerical_prompt",
|
||||
)
|
||||
Prompt(
|
||||
"""
|
||||
{chat_talking_prompt}
|
||||
以上是直播间里最近的对话
|
||||
|
||||
{indentify_block}
|
||||
你之前的情绪状态是:{mood_state}
|
||||
具体来说,从1-10分,你的情绪状态是:
|
||||
喜(Joy): {joy}
|
||||
怒(Anger): {anger}
|
||||
哀(Sorrow): {sorrow}
|
||||
惧(Fear): {fear}
|
||||
|
||||
距离你上次关注直播间消息已经过去了一段时间,你冷静了下来。请基于此,评估你现在的情绪状态。
|
||||
请以JSON格式输出你新的情绪状态,包含"喜怒哀惧"四个维度,每个维度的取值范围为1-10。
|
||||
键值请使用英文: "joy", "anger", "sorrow", "fear".
|
||||
例如: {{"joy": 5, "anger": 1, "sorrow": 1, "fear": 1}}
|
||||
不要输出任何其他内容,只输出JSON。
|
||||
""",
|
||||
"regress_mood_numerical_prompt",
|
||||
)
|
||||
|
||||
|
||||
class ChatMood:
|
||||
def __init__(self, chat_id: str):
|
||||
self.chat_id: str = chat_id
|
||||
self.mood_state: str = "感觉很平静"
|
||||
self.mood_values: dict[str, int] = {"joy": 5, "anger": 1, "sorrow": 1, "fear": 1}
|
||||
|
||||
self.regression_count: int = 0
|
||||
|
||||
self.mood_model = LLMRequest(model_set=model_config.model_task_config.emotion, request_type="mood_text")
|
||||
self.mood_model_numerical = LLMRequest(
|
||||
model_set=model_config.model_task_config.emotion, request_type="mood_numerical"
|
||||
)
|
||||
|
||||
self.last_change_time: float = 0
|
||||
|
||||
# 发送初始情绪状态到ws端
|
||||
asyncio.create_task(self.send_emotion_update(self.mood_values))
|
||||
|
||||
def _parse_numerical_mood(self, response: str) -> dict[str, int] | None:
|
||||
try:
|
||||
# The LLM might output markdown with json inside
|
||||
if "```json" in response:
|
||||
response = response.split("```json")[1].split("```")[0]
|
||||
elif "```" in response:
|
||||
response = response.split("```")[1].split("```")[0]
|
||||
|
||||
data = json.loads(response)
|
||||
|
||||
# Validate
|
||||
required_keys = {"joy", "anger", "sorrow", "fear"}
|
||||
if not required_keys.issubset(data.keys()):
|
||||
logger.warning(f"Numerical mood response missing keys: {response}")
|
||||
return None
|
||||
|
||||
for key in required_keys:
|
||||
value = data[key]
|
||||
if not isinstance(value, int) or not (1 <= value <= 10):
|
||||
logger.warning(f"Numerical mood response invalid value for {key}: {value} in {response}")
|
||||
return None
|
||||
|
||||
return {key: data[key] for key in required_keys}
|
||||
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Failed to parse numerical mood JSON: {response}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing numerical mood: {e}, response: {response}")
|
||||
return None
|
||||
|
||||
async def update_mood_by_message(self, message: MessageRecv):
|
||||
self.regression_count = 0
|
||||
|
||||
message_time: float = message.message_info.time # type: ignore
|
||||
message_list_before_now = get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
chat_id=self.chat_id,
|
||||
timestamp_start=self.last_change_time,
|
||||
timestamp_end=message_time,
|
||||
limit=10,
|
||||
limit_mode="last",
|
||||
)
|
||||
|
||||
chat_talking_prompt = build_readable_messages(
|
||||
message_list_before_now,
|
||||
replace_bot_name=True,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
read_mark=0.0,
|
||||
truncate=True,
|
||||
show_actions=True,
|
||||
)
|
||||
|
||||
bot_name = global_config.bot.nickname
|
||||
if global_config.bot.alias_names:
|
||||
bot_nickname = f",也有人叫你{','.join(global_config.bot.alias_names)}"
|
||||
else:
|
||||
bot_nickname = ""
|
||||
|
||||
prompt_personality = global_config.personality.personality
|
||||
indentify_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}:"
|
||||
|
||||
async def _update_text_mood():
|
||||
prompt = await global_prompt_manager.format_prompt(
|
||||
"change_mood_prompt_vtb",
|
||||
chat_talking_prompt=chat_talking_prompt,
|
||||
indentify_block=indentify_block,
|
||||
mood_state=self.mood_state,
|
||||
)
|
||||
logger.debug(f"text mood prompt: {prompt}")
|
||||
response, (reasoning_content, _, _) = await self.mood_model.generate_response_async(
|
||||
prompt=prompt, temperature=0.7
|
||||
)
|
||||
logger.info(f"text mood response: {response}")
|
||||
logger.debug(f"text mood reasoning_content: {reasoning_content}")
|
||||
return response
|
||||
|
||||
async def _update_numerical_mood():
|
||||
prompt = await global_prompt_manager.format_prompt(
|
||||
"change_mood_numerical_prompt",
|
||||
chat_talking_prompt=chat_talking_prompt,
|
||||
indentify_block=indentify_block,
|
||||
mood_state=self.mood_state,
|
||||
joy=self.mood_values["joy"],
|
||||
anger=self.mood_values["anger"],
|
||||
sorrow=self.mood_values["sorrow"],
|
||||
fear=self.mood_values["fear"],
|
||||
)
|
||||
logger.debug(f"numerical mood prompt: {prompt}")
|
||||
response, (reasoning_content, _, _) = await self.mood_model_numerical.generate_response_async(
|
||||
prompt=prompt, temperature=0.4
|
||||
)
|
||||
logger.info(f"numerical mood response: {response}")
|
||||
logger.debug(f"numerical mood reasoning_content: {reasoning_content}")
|
||||
return self._parse_numerical_mood(response)
|
||||
|
||||
results = await asyncio.gather(_update_text_mood(), _update_numerical_mood())
|
||||
text_mood_response, numerical_mood_response = results
|
||||
|
||||
if text_mood_response:
|
||||
self.mood_state = text_mood_response
|
||||
|
||||
if numerical_mood_response:
|
||||
_old_mood_values = self.mood_values.copy()
|
||||
self.mood_values = numerical_mood_response
|
||||
|
||||
# 发送情绪更新到ws端
|
||||
await self.send_emotion_update(self.mood_values)
|
||||
|
||||
logger.info(f"[{self.chat_id}] 情绪变化: {_old_mood_values} -> {self.mood_values}")
|
||||
|
||||
self.last_change_time = message_time
|
||||
|
||||
async def regress_mood(self):
|
||||
message_time = time.time()
|
||||
message_list_before_now = get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
chat_id=self.chat_id,
|
||||
timestamp_start=self.last_change_time,
|
||||
timestamp_end=message_time,
|
||||
limit=5,
|
||||
limit_mode="last",
|
||||
)
|
||||
|
||||
chat_talking_prompt = build_readable_messages(
|
||||
message_list_before_now,
|
||||
replace_bot_name=True,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
read_mark=0.0,
|
||||
truncate=True,
|
||||
show_actions=True,
|
||||
)
|
||||
|
||||
bot_name = global_config.bot.nickname
|
||||
if global_config.bot.alias_names:
|
||||
bot_nickname = f",也有人叫你{','.join(global_config.bot.alias_names)}"
|
||||
else:
|
||||
bot_nickname = ""
|
||||
|
||||
prompt_personality = global_config.personality.personality
|
||||
indentify_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}:"
|
||||
|
||||
async def _regress_text_mood():
|
||||
prompt = await global_prompt_manager.format_prompt(
|
||||
"regress_mood_prompt_vtb",
|
||||
chat_talking_prompt=chat_talking_prompt,
|
||||
indentify_block=indentify_block,
|
||||
mood_state=self.mood_state,
|
||||
)
|
||||
logger.debug(f"text regress prompt: {prompt}")
|
||||
response, (reasoning_content, _, _) = await self.mood_model.generate_response_async(
|
||||
prompt=prompt, temperature=0.7
|
||||
)
|
||||
logger.info(f"text regress response: {response}")
|
||||
logger.debug(f"text regress reasoning_content: {reasoning_content}")
|
||||
return response
|
||||
|
||||
async def _regress_numerical_mood():
|
||||
prompt = await global_prompt_manager.format_prompt(
|
||||
"regress_mood_numerical_prompt",
|
||||
chat_talking_prompt=chat_talking_prompt,
|
||||
indentify_block=indentify_block,
|
||||
mood_state=self.mood_state,
|
||||
joy=self.mood_values["joy"],
|
||||
anger=self.mood_values["anger"],
|
||||
sorrow=self.mood_values["sorrow"],
|
||||
fear=self.mood_values["fear"],
|
||||
)
|
||||
logger.debug(f"numerical regress prompt: {prompt}")
|
||||
response, (reasoning_content, _, _) = await self.mood_model_numerical.generate_response_async(
|
||||
prompt=prompt,
|
||||
temperature=0.4,
|
||||
)
|
||||
logger.info(f"numerical regress response: {response}")
|
||||
logger.debug(f"numerical regress reasoning_content: {reasoning_content}")
|
||||
return self._parse_numerical_mood(response)
|
||||
|
||||
results = await asyncio.gather(_regress_text_mood(), _regress_numerical_mood())
|
||||
text_mood_response, numerical_mood_response = results
|
||||
|
||||
if text_mood_response:
|
||||
self.mood_state = text_mood_response
|
||||
|
||||
if numerical_mood_response:
|
||||
_old_mood_values = self.mood_values.copy()
|
||||
self.mood_values = numerical_mood_response
|
||||
|
||||
# 发送情绪更新到ws端
|
||||
await self.send_emotion_update(self.mood_values)
|
||||
|
||||
logger.info(f"[{self.chat_id}] 情绪回归: {_old_mood_values} -> {self.mood_values}")
|
||||
|
||||
self.regression_count += 1
|
||||
|
||||
async def send_emotion_update(self, mood_values: dict[str, int]):
|
||||
"""发送情绪更新到ws端"""
|
||||
emotion_data = {
|
||||
"joy": mood_values.get("joy", 5),
|
||||
"anger": mood_values.get("anger", 1),
|
||||
"sorrow": mood_values.get("sorrow", 1),
|
||||
"fear": mood_values.get("fear", 1),
|
||||
}
|
||||
|
||||
await send_api.custom_to_stream(
|
||||
message_type="emotion",
|
||||
content=emotion_data,
|
||||
stream_id=self.chat_id,
|
||||
storage_message=False,
|
||||
show_log=True,
|
||||
)
|
||||
|
||||
logger.info(f"[{self.chat_id}] 发送情绪更新: {emotion_data}")
|
||||
|
||||
|
||||
class MoodRegressionTask(AsyncTask):
|
||||
def __init__(self, mood_manager: "MoodManager"):
|
||||
super().__init__(task_name="MoodRegressionTask", run_interval=30)
|
||||
self.mood_manager = mood_manager
|
||||
self.run_count = 0
|
||||
|
||||
async def run(self):
|
||||
self.run_count += 1
|
||||
logger.info(f"[回归任务] 第{self.run_count}次检查,当前管理{len(self.mood_manager.mood_list)}个聊天的情绪状态")
|
||||
|
||||
now = time.time()
|
||||
regression_executed = 0
|
||||
|
||||
for mood in self.mood_manager.mood_list:
|
||||
chat_info = f"chat {mood.chat_id}"
|
||||
|
||||
if mood.last_change_time == 0:
|
||||
logger.debug(f"[回归任务] {chat_info} 尚未有情绪变化,跳过回归")
|
||||
continue
|
||||
|
||||
time_since_last_change = now - mood.last_change_time
|
||||
|
||||
# 检查是否有极端情绪需要快速回归
|
||||
high_emotions = {k: v for k, v in mood.mood_values.items() if v >= 8}
|
||||
has_extreme_emotion = len(high_emotions) > 0
|
||||
|
||||
# 回归条件:1. 正常时间间隔(120s) 或 2. 有极端情绪且距上次变化>=30s
|
||||
should_regress = False
|
||||
regress_reason = ""
|
||||
|
||||
if time_since_last_change > 120:
|
||||
should_regress = True
|
||||
regress_reason = f"常规回归(距上次变化{int(time_since_last_change)}秒)"
|
||||
elif has_extreme_emotion and time_since_last_change > 30:
|
||||
should_regress = True
|
||||
high_emotion_str = ", ".join([f"{k}={v}" for k, v in high_emotions.items()])
|
||||
regress_reason = f"极端情绪快速回归({high_emotion_str}, 距上次变化{int(time_since_last_change)}秒)"
|
||||
|
||||
if should_regress:
|
||||
if mood.regression_count >= 3:
|
||||
logger.debug(f"[回归任务] {chat_info} 已达到最大回归次数(3次),停止回归")
|
||||
continue
|
||||
|
||||
logger.info(
|
||||
f"[回归任务] {chat_info} 开始情绪回归 ({regress_reason},第{mood.regression_count + 1}次回归)"
|
||||
)
|
||||
await mood.regress_mood()
|
||||
regression_executed += 1
|
||||
else:
|
||||
if has_extreme_emotion:
|
||||
remaining_time = 5 - time_since_last_change
|
||||
high_emotion_str = ", ".join([f"{k}={v}" for k, v in high_emotions.items()])
|
||||
logger.debug(
|
||||
f"[回归任务] {chat_info} 存在极端情绪({high_emotion_str}),距离快速回归还需等待{int(remaining_time)}秒"
|
||||
)
|
||||
else:
|
||||
remaining_time = 120 - time_since_last_change
|
||||
logger.debug(f"[回归任务] {chat_info} 距离回归还需等待{int(remaining_time)}秒")
|
||||
|
||||
if regression_executed > 0:
|
||||
logger.info(f"[回归任务] 本次执行了{regression_executed}个聊天的情绪回归")
|
||||
else:
|
||||
logger.debug("[回归任务] 本次没有符合回归条件的聊天")
|
||||
|
||||
|
||||
class MoodManager:
|
||||
def __init__(self):
|
||||
self.mood_list: list[ChatMood] = []
|
||||
"""当前情绪状态"""
|
||||
self.task_started: bool = False
|
||||
|
||||
async def start(self):
|
||||
"""启动情绪回归后台任务"""
|
||||
if self.task_started:
|
||||
return
|
||||
|
||||
logger.info("启动情绪管理任务...")
|
||||
|
||||
# 启动情绪回归任务
|
||||
regression_task = MoodRegressionTask(self)
|
||||
await async_task_manager.add_task(regression_task)
|
||||
|
||||
self.task_started = True
|
||||
logger.info("情绪管理任务已启动(情绪回归)")
|
||||
|
||||
def get_mood_by_chat_id(self, chat_id: str) -> ChatMood:
|
||||
for mood in self.mood_list:
|
||||
if mood.chat_id == chat_id:
|
||||
return mood
|
||||
|
||||
new_mood = ChatMood(chat_id)
|
||||
self.mood_list.append(new_mood)
|
||||
return new_mood
|
||||
|
||||
def reset_mood_by_chat_id(self, chat_id: str):
|
||||
for mood in self.mood_list:
|
||||
if mood.chat_id == chat_id:
|
||||
mood.mood_state = "感觉很平静"
|
||||
mood.mood_values = {"joy": 5, "anger": 1, "sorrow": 1, "fear": 1}
|
||||
mood.regression_count = 0
|
||||
# 发送重置后的情绪状态到ws端
|
||||
asyncio.create_task(mood.send_emotion_update(mood.mood_values))
|
||||
return
|
||||
|
||||
# 如果没有找到现有的mood,创建新的
|
||||
new_mood = ChatMood(chat_id)
|
||||
self.mood_list.append(new_mood)
|
||||
# 发送初始情绪状态到ws端
|
||||
asyncio.create_task(new_mood.send_emotion_update(new_mood.mood_values))
|
||||
|
||||
|
||||
if s4u_config.enable_s4u:
|
||||
init_prompt()
|
||||
mood_manager = MoodManager()
|
||||
else:
|
||||
mood_manager = None
|
||||
|
||||
"""全局情绪管理器"""
|
||||
@@ -1,264 +0,0 @@
|
||||
import asyncio
|
||||
import math
|
||||
from typing import Tuple
|
||||
|
||||
from src.chat.memory_system.Hippocampus import hippocampus_manager
|
||||
from src.chat.message_receive.message import MessageRecv, MessageRecvS4U
|
||||
from maim_message.message_base import GroupInfo
|
||||
from src.chat.message_receive.storage import MessageStorage
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.utils.timer_calculator import Timer
|
||||
from src.chat.utils.utils import is_mentioned_bot_in_message
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.mais4u.mais4u_chat.body_emotion_action_manager import action_manager
|
||||
from src.mais4u.mais4u_chat.s4u_mood_manager import mood_manager
|
||||
from src.mais4u.mais4u_chat.s4u_watching_manager import watching_manager
|
||||
from src.mais4u.mais4u_chat.context_web_manager import get_context_web_manager
|
||||
from src.mais4u.mais4u_chat.gift_manager import gift_manager
|
||||
from src.mais4u.mais4u_chat.screen_manager import screen_manager
|
||||
|
||||
from .s4u_chat import get_s4u_chat_manager
|
||||
|
||||
|
||||
# from ..message_receive.message_buffer import message_buffer
|
||||
|
||||
logger = get_logger("chat")
|
||||
|
||||
|
||||
async def _calculate_interest(message: MessageRecv) -> Tuple[float, bool]:
|
||||
"""计算消息的兴趣度
|
||||
|
||||
Args:
|
||||
message: 待处理的消息对象
|
||||
|
||||
Returns:
|
||||
Tuple[float, bool]: (兴趣度, 是否被提及)
|
||||
"""
|
||||
is_mentioned, _ = is_mentioned_bot_in_message(message)
|
||||
interested_rate = 0.0
|
||||
|
||||
if global_config.memory.enable_memory:
|
||||
with Timer("记忆激活"):
|
||||
interested_rate,_ ,_= await hippocampus_manager.get_activate_from_text(
|
||||
message.processed_plain_text,
|
||||
fast_retrieval=True,
|
||||
)
|
||||
logger.debug(f"记忆激活率: {interested_rate:.2f}")
|
||||
|
||||
text_len = len(message.processed_plain_text)
|
||||
# 根据文本长度分布调整兴趣度,采用分段函数实现更精确的兴趣度计算
|
||||
# 基于实际分布:0-5字符(26.57%), 6-10字符(27.18%), 11-20字符(22.76%), 21-30字符(10.33%), 31+字符(13.86%)
|
||||
|
||||
if text_len == 0:
|
||||
base_interest = 0.01 # 空消息最低兴趣度
|
||||
elif text_len <= 5:
|
||||
# 1-5字符:线性增长 0.01 -> 0.03
|
||||
base_interest = 0.01 + (text_len - 1) * (0.03 - 0.01) / 4
|
||||
elif text_len <= 10:
|
||||
# 6-10字符:线性增长 0.03 -> 0.06
|
||||
base_interest = 0.03 + (text_len - 5) * (0.06 - 0.03) / 5
|
||||
elif text_len <= 20:
|
||||
# 11-20字符:线性增长 0.06 -> 0.12
|
||||
base_interest = 0.06 + (text_len - 10) * (0.12 - 0.06) / 10
|
||||
elif text_len <= 30:
|
||||
# 21-30字符:线性增长 0.12 -> 0.18
|
||||
base_interest = 0.12 + (text_len - 20) * (0.18 - 0.12) / 10
|
||||
elif text_len <= 50:
|
||||
# 31-50字符:线性增长 0.18 -> 0.22
|
||||
base_interest = 0.18 + (text_len - 30) * (0.22 - 0.18) / 20
|
||||
elif text_len <= 100:
|
||||
# 51-100字符:线性增长 0.22 -> 0.26
|
||||
base_interest = 0.22 + (text_len - 50) * (0.26 - 0.22) / 50
|
||||
else:
|
||||
# 100+字符:对数增长 0.26 -> 0.3,增长率递减
|
||||
base_interest = 0.26 + (0.3 - 0.26) * (math.log10(text_len - 99) / math.log10(901)) # 1000-99=901
|
||||
|
||||
# 确保在范围内
|
||||
base_interest = min(max(base_interest, 0.01), 0.3)
|
||||
|
||||
interested_rate += base_interest
|
||||
|
||||
if is_mentioned:
|
||||
interest_increase_on_mention = 1
|
||||
interested_rate += interest_increase_on_mention
|
||||
|
||||
return interested_rate, is_mentioned
|
||||
|
||||
|
||||
class S4UMessageProcessor:
|
||||
"""心流处理器,负责处理接收到的消息并计算兴趣度"""
|
||||
|
||||
def __init__(self):
|
||||
"""初始化心流处理器,创建消息存储实例"""
|
||||
self.storage = MessageStorage()
|
||||
|
||||
async def process_message(self, message: MessageRecvS4U, skip_gift_debounce: bool = False) -> None:
|
||||
"""处理接收到的原始消息数据
|
||||
|
||||
主要流程:
|
||||
1. 消息解析与初始化
|
||||
2. 消息缓冲处理
|
||||
3. 过滤检查
|
||||
4. 兴趣度计算
|
||||
5. 关系处理
|
||||
|
||||
Args:
|
||||
message_data: 原始消息字符串
|
||||
"""
|
||||
|
||||
# 1. 消息解析与初始化
|
||||
groupinfo = message.message_info.group_info
|
||||
userinfo = message.message_info.user_info
|
||||
message_info = message.message_info
|
||||
|
||||
chat = await get_chat_manager().get_or_create_stream(
|
||||
platform=message_info.platform,
|
||||
user_info=userinfo,
|
||||
group_info=groupinfo,
|
||||
)
|
||||
|
||||
if await self.handle_internal_message(message):
|
||||
return
|
||||
|
||||
if await self.hadle_if_voice_done(message):
|
||||
return
|
||||
|
||||
# 处理礼物消息,如果消息被暂存则停止当前处理流程
|
||||
if not skip_gift_debounce and not await self.handle_if_gift(message):
|
||||
return
|
||||
await self.check_if_fake_gift(message)
|
||||
|
||||
# 处理屏幕消息
|
||||
if await self.handle_screen_message(message):
|
||||
return
|
||||
|
||||
|
||||
await self.storage.store_message(message, chat)
|
||||
|
||||
s4u_chat = get_s4u_chat_manager().get_or_create_chat(chat)
|
||||
|
||||
|
||||
await s4u_chat.add_message(message)
|
||||
|
||||
_interested_rate, _ = await _calculate_interest(message)
|
||||
|
||||
await mood_manager.start()
|
||||
|
||||
|
||||
|
||||
# 一系列llm驱动的前处理
|
||||
chat_mood = mood_manager.get_mood_by_chat_id(chat.stream_id)
|
||||
asyncio.create_task(chat_mood.update_mood_by_message(message))
|
||||
chat_action = action_manager.get_action_state_by_chat_id(chat.stream_id)
|
||||
asyncio.create_task(chat_action.update_action_by_message(message))
|
||||
# 视线管理:收到消息时切换视线状态
|
||||
chat_watching = watching_manager.get_watching_by_chat_id(chat.stream_id)
|
||||
await chat_watching.on_message_received()
|
||||
|
||||
# 上下文网页管理:启动独立task处理消息上下文
|
||||
asyncio.create_task(self._handle_context_web_update(chat.stream_id, message))
|
||||
|
||||
# 日志记录
|
||||
if message.is_gift:
|
||||
logger.info(f"[S4U-礼物] {userinfo.user_nickname} 送出了 {message.gift_name} x{message.gift_count}")
|
||||
else:
|
||||
logger.info(f"[S4U]{userinfo.user_nickname}:{message.processed_plain_text}")
|
||||
|
||||
async def handle_internal_message(self, message: MessageRecvS4U):
|
||||
if message.is_internal:
|
||||
|
||||
group_info = GroupInfo(platform = "amaidesu_default",group_id = 660154,group_name = "内心")
|
||||
|
||||
chat = await get_chat_manager().get_or_create_stream(
|
||||
platform = "amaidesu_default",
|
||||
user_info = message.message_info.user_info,
|
||||
group_info = group_info
|
||||
)
|
||||
s4u_chat = get_s4u_chat_manager().get_or_create_chat(chat)
|
||||
message.message_info.group_info = s4u_chat.chat_stream.group_info
|
||||
message.message_info.platform = s4u_chat.chat_stream.platform
|
||||
|
||||
|
||||
s4u_chat.internal_message.append(message)
|
||||
s4u_chat._new_message_event.set()
|
||||
|
||||
|
||||
logger.info(f"[{s4u_chat.stream_name}] 添加内部消息-------------------------------------------------------: {message.processed_plain_text}")
|
||||
|
||||
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
async def handle_screen_message(self, message: MessageRecvS4U):
|
||||
if message.is_screen:
|
||||
screen_manager.set_screen(message.screen_info)
|
||||
return True
|
||||
return False
|
||||
|
||||
async def hadle_if_voice_done(self, message: MessageRecvS4U):
|
||||
if message.voice_done:
|
||||
s4u_chat = get_s4u_chat_manager().get_or_create_chat(message.chat_stream)
|
||||
s4u_chat.voice_done = message.voice_done
|
||||
return True
|
||||
return False
|
||||
|
||||
async def check_if_fake_gift(self, message: MessageRecvS4U) -> bool:
|
||||
"""检查消息是否为假礼物"""
|
||||
if message.is_gift:
|
||||
return False
|
||||
|
||||
gift_keywords = ["送出了礼物", "礼物", "送出了","投喂"]
|
||||
if any(keyword in message.processed_plain_text for keyword in gift_keywords):
|
||||
message.is_fake_gift = True
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def handle_if_gift(self, message: MessageRecvS4U) -> bool:
|
||||
"""处理礼物消息
|
||||
|
||||
Returns:
|
||||
bool: True表示应该继续处理消息,False表示消息已被暂存不需要继续处理
|
||||
"""
|
||||
if message.is_gift:
|
||||
# 定义防抖完成后的回调函数
|
||||
def gift_callback(merged_message: MessageRecvS4U):
|
||||
"""礼物防抖完成后的回调"""
|
||||
# 创建异步任务来处理合并后的礼物消息,跳过防抖处理
|
||||
asyncio.create_task(self.process_message(merged_message, skip_gift_debounce=True))
|
||||
|
||||
# 交给礼物管理器处理,并传入回调函数
|
||||
# 对于礼物消息,handle_gift 总是返回 False(消息被暂存)
|
||||
await gift_manager.handle_gift(message, gift_callback)
|
||||
return False # 消息被暂存,不继续处理
|
||||
|
||||
return True # 非礼物消息,继续正常处理
|
||||
|
||||
async def _handle_context_web_update(self, chat_id: str, message: MessageRecv):
|
||||
"""处理上下文网页更新的独立task
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
message: 消息对象
|
||||
"""
|
||||
try:
|
||||
logger.debug(f"🔄 开始处理上下文网页更新: {message.message_info.user_info.user_nickname}")
|
||||
|
||||
context_manager = get_context_web_manager()
|
||||
|
||||
# 只在服务器未启动时启动(避免重复启动)
|
||||
if context_manager.site is None:
|
||||
logger.info("🚀 首次启动上下文网页服务器...")
|
||||
await context_manager.start_server()
|
||||
|
||||
# 添加消息到上下文并更新网页
|
||||
await asyncio.sleep(1.5)
|
||||
|
||||
await context_manager.add_message(chat_id, message)
|
||||
|
||||
logger.debug(f"✅ 上下文网页更新完成: {message.message_info.user_info.user_nickname}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 处理上下文网页更新失败: {e}", exc_info=True)
|
||||
@@ -1,396 +0,0 @@
|
||||
from src.config.config import global_config
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat
|
||||
import time
|
||||
from src.chat.utils.utils import get_recent_group_speaker
|
||||
from src.chat.memory_system.Hippocampus import hippocampus_manager
|
||||
import random
|
||||
from datetime import datetime
|
||||
import asyncio
|
||||
from src.mais4u.s4u_config import s4u_config
|
||||
from src.chat.message_receive.message import MessageRecvS4U
|
||||
from src.person_info.person_info import Person, get_person_id
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
from src.mais4u.mais4u_chat.super_chat_manager import get_super_chat_manager
|
||||
from src.mais4u.mais4u_chat.screen_manager import screen_manager
|
||||
from src.chat.express.expression_selector import expression_selector
|
||||
from .s4u_mood_manager import mood_manager
|
||||
from src.mais4u.mais4u_chat.internal_manager import internal_manager
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
|
||||
from typing import List
|
||||
|
||||
logger = get_logger("prompt")
|
||||
|
||||
|
||||
def init_prompt():
|
||||
Prompt("\n你有以下这些**知识**:\n{prompt_info}\n请你**记住上面的知识**,之后可能会用到。\n", "knowledge_prompt")
|
||||
Prompt("\n关于你们的关系,你需要知道:\n{relation_info}\n", "relation_prompt")
|
||||
Prompt("你回想起了一些事情:\n{memory_info}\n", "memory_prompt")
|
||||
|
||||
Prompt(
|
||||
"""
|
||||
你的名字是麦麦, 是千石可乐开发的程序,可以在QQ,微信等平台发言,你现在正在哔哩哔哩作为虚拟主播进行直播
|
||||
虽然是个AI,但你在努力扮演真人, 人设的配色是橙色和紫色和绿色
|
||||
你有一头法式橙色卷发,你很可爱,穿戴英式侦探风格学院风裙子和帽子,你正在一个教室场景里进行虚拟线上直播。不过不要刻意提及这些内容,自然回复
|
||||
你可以看见用户发送的弹幕,礼物和superchat
|
||||
{screen_info}
|
||||
{internal_state}
|
||||
|
||||
{relation_info_block}
|
||||
{memory_block}
|
||||
{expression_habits_block}
|
||||
|
||||
你现在的主要任务是和 {sender_name} 发送的弹幕聊天。同时,也有其他用户会参与你们的聊天,你可以参考他们的回复内容,但是你主要还是关注你和{sender_name}的聊天内容。
|
||||
|
||||
{sc_info}
|
||||
|
||||
{background_dialogue_prompt}
|
||||
--------------------------------
|
||||
{time_block}
|
||||
这是你和{sender_name}的对话,你们正在交流中:
|
||||
{core_dialogue_prompt}
|
||||
|
||||
对方最新发送的内容:{message_txt}
|
||||
{gift_info}
|
||||
回复简短一些,平淡一些,可以参考贴吧,知乎和微博的回复风格,回复不要浮夸,不要用夸张修辞。
|
||||
表现的有个性,不要随意服从他人要求,积极互动。你现在的心情是:{mood_state}
|
||||
不要输出多余内容(包括前后缀,冒号和引号,括号(),表情包,at或 @等 )。只输出回复内容,现在{sender_name}正在等待你的回复。
|
||||
你的回复风格不要浮夸,有逻辑和条理,请你继续回复{sender_name}。
|
||||
你的发言:
|
||||
""",
|
||||
"s4u_prompt", # New template for private CHAT chat
|
||||
)
|
||||
|
||||
Prompt(
|
||||
"""
|
||||
你的名字是麦麦, 是千石可乐开发的程序,可以在QQ,微信等平台发言,你现在正在哔哩哔哩作为虚拟主播进行直播
|
||||
虽然是个AI,但你在努力扮演真人, 人设的配色是橙色和紫色和绿色
|
||||
你有一头法式橙色卷发,你很可爱,穿戴英式侦探风格学院风裙子和帽子,你正在一个教室场景里进行虚拟线上直播。不过不要刻意提及这些内容,自然回复
|
||||
你可以看见用户发送的弹幕,礼物和superchat
|
||||
你可以看见面前的屏幕,目前屏幕的内容是:
|
||||
{screen_info}
|
||||
|
||||
{memory_block}
|
||||
{expression_habits_block}
|
||||
|
||||
{sc_info}
|
||||
|
||||
{time_block}
|
||||
{chat_info_danmu}
|
||||
--------------------------------
|
||||
以上是你和弹幕的对话,与此同时,你在与QQ群友聊天,聊天记录如下:
|
||||
{chat_info_qq}
|
||||
--------------------------------
|
||||
你刚刚回复了QQ群,你内心的想法是:{mind}
|
||||
请根据你内心的想法,组织一条回复,在直播间进行发言,可以点名吐槽对象,让观众知道你在说谁
|
||||
{gift_info}
|
||||
回复简短一些,平淡一些,可以参考贴吧,知乎和微博的回复风格。不要浮夸,有逻辑和条理。
|
||||
表现的有个性,不要随意服从他人要求,积极互动。你现在的心情是:{mood_state}
|
||||
不要输出多余内容(包括前后缀,冒号和引号,括号(),表情包,at或 @等 )。
|
||||
你的发言:
|
||||
""",
|
||||
"s4u_prompt_internal", # New template for private CHAT chat
|
||||
)
|
||||
|
||||
|
||||
class PromptBuilder:
|
||||
def __init__(self):
|
||||
self.prompt_built = ""
|
||||
self.activate_messages = ""
|
||||
|
||||
async def build_expression_habits(self, chat_stream: ChatStream, chat_history, target):
|
||||
style_habits = []
|
||||
|
||||
# 使用从处理器传来的选中表达方式
|
||||
# LLM模式:调用LLM选择5-10个,然后随机选5个
|
||||
selected_expressions, _ = await expression_selector.select_suitable_expressions_llm(
|
||||
chat_stream.stream_id, chat_history, max_num=12, target_message=target
|
||||
)
|
||||
|
||||
if selected_expressions:
|
||||
logger.debug(f" 使用处理器选中的{len(selected_expressions)}个表达方式")
|
||||
for expr in selected_expressions:
|
||||
if isinstance(expr, dict) and "situation" in expr and "style" in expr:
|
||||
style_habits.append(f"当{expr['situation']}时,使用 {expr['style']}")
|
||||
else:
|
||||
logger.debug("没有从处理器获得表达方式,将使用空的表达方式")
|
||||
# 不再在replyer中进行随机选择,全部交给处理器处理
|
||||
|
||||
style_habits_str = "\n".join(style_habits)
|
||||
|
||||
# 动态构建expression habits块
|
||||
expression_habits_block = ""
|
||||
if style_habits_str.strip():
|
||||
expression_habits_block += f"你可以参考以下的语言习惯,如果情景合适就使用,不要盲目使用,不要生硬使用,而是结合到表达中:\n{style_habits_str}\n\n"
|
||||
|
||||
return expression_habits_block
|
||||
|
||||
async def build_relation_info(self, chat_stream) -> str:
|
||||
is_group_chat = bool(chat_stream.group_info)
|
||||
who_chat_in_group = []
|
||||
if is_group_chat:
|
||||
who_chat_in_group = get_recent_group_speaker(
|
||||
chat_stream.stream_id,
|
||||
(chat_stream.user_info.platform, chat_stream.user_info.user_id) if chat_stream.user_info else None,
|
||||
limit=global_config.chat.max_context_size,
|
||||
)
|
||||
elif chat_stream.user_info:
|
||||
who_chat_in_group.append(
|
||||
(chat_stream.user_info.platform, chat_stream.user_info.user_id, chat_stream.user_info.user_nickname)
|
||||
)
|
||||
|
||||
relation_prompt = ""
|
||||
if global_config.relationship.enable_relationship and who_chat_in_group:
|
||||
# 将 (platform, user_id, nickname) 转换为 person_id
|
||||
person_ids = []
|
||||
for person in who_chat_in_group:
|
||||
person_id = get_person_id(person[0], person[1])
|
||||
person_ids.append(person_id)
|
||||
|
||||
# 使用 Person 的 build_relationship 方法,设置 points_num=3 保持与原来相同的行为
|
||||
relation_info_list = [Person(person_id=person_id).build_relationship() for person_id in person_ids]
|
||||
if relation_info := "".join(relation_info_list):
|
||||
relation_prompt = await global_prompt_manager.format_prompt(
|
||||
"relation_prompt", relation_info=relation_info
|
||||
)
|
||||
return relation_prompt
|
||||
|
||||
async def build_memory_block(self, text: str) -> str:
|
||||
# 待更新记忆系统
|
||||
return ""
|
||||
|
||||
related_memory = await hippocampus_manager.get_memory_from_text(
|
||||
text=text, max_memory_num=2, max_memory_length=2, max_depth=3, fast_retrieval=False
|
||||
)
|
||||
|
||||
related_memory_info = ""
|
||||
if related_memory:
|
||||
for memory in related_memory:
|
||||
related_memory_info += memory[1]
|
||||
return await global_prompt_manager.format_prompt("memory_prompt", memory_info=related_memory_info)
|
||||
return ""
|
||||
|
||||
def build_chat_history_prompts(self, chat_stream: ChatStream, message: MessageRecvS4U):
|
||||
message_list_before_now = get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=chat_stream.stream_id,
|
||||
timestamp=time.time(),
|
||||
# sourcery skip: lift-duplicated-conditional, merge-duplicate-blocks, remove-redundant-if
|
||||
limit=300,
|
||||
)
|
||||
|
||||
talk_type = f"{message.message_info.platform}:{str(message.chat_stream.user_info.user_id)}"
|
||||
|
||||
core_dialogue_list: List[DatabaseMessages] = []
|
||||
background_dialogue_list: List[DatabaseMessages] = []
|
||||
bot_id = str(global_config.bot.qq_account)
|
||||
target_user_id = str(message.chat_stream.user_info.user_id)
|
||||
|
||||
for msg in message_list_before_now:
|
||||
try:
|
||||
msg_user_id = str(msg.user_info.user_id)
|
||||
if msg_user_id == bot_id:
|
||||
if msg.reply_to and talk_type == msg.reply_to:
|
||||
core_dialogue_list.append(msg)
|
||||
elif msg.reply_to and talk_type != msg.reply_to:
|
||||
background_dialogue_list.append(msg)
|
||||
# else:
|
||||
# background_dialogue_list.append(msg_dict)
|
||||
elif msg_user_id == target_user_id:
|
||||
core_dialogue_list.append(msg)
|
||||
else:
|
||||
background_dialogue_list.append(msg)
|
||||
except Exception as e:
|
||||
logger.error(f"无法处理历史消息记录: {msg.__dict__}, 错误: {e}")
|
||||
|
||||
background_dialogue_prompt = ""
|
||||
if background_dialogue_list:
|
||||
context_msgs = background_dialogue_list[-s4u_config.max_context_message_length :]
|
||||
background_dialogue_prompt_str = build_readable_messages(
|
||||
context_msgs,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
show_pic=False,
|
||||
)
|
||||
background_dialogue_prompt = f"这是其他用户的发言:\n{background_dialogue_prompt_str}"
|
||||
|
||||
core_msg_str = ""
|
||||
if core_dialogue_list:
|
||||
core_dialogue_list = core_dialogue_list[-s4u_config.max_core_message_length :]
|
||||
|
||||
first_msg = core_dialogue_list[0]
|
||||
start_speaking_user_id = first_msg.user_info.user_id
|
||||
if start_speaking_user_id == bot_id:
|
||||
last_speaking_user_id = bot_id
|
||||
msg_seg_str = "你的发言:\n"
|
||||
else:
|
||||
start_speaking_user_id = target_user_id
|
||||
last_speaking_user_id = start_speaking_user_id
|
||||
msg_seg_str = "对方的发言:\n"
|
||||
|
||||
msg_seg_str += f"{time.strftime('%H:%M:%S', time.localtime(first_msg.time))}: {first_msg.processed_plain_text}\n"
|
||||
|
||||
all_msg_seg_list = []
|
||||
for msg in core_dialogue_list[1:]:
|
||||
speaker = msg.user_info.user_id
|
||||
if speaker == last_speaking_user_id:
|
||||
msg_seg_str += f"{time.strftime('%H:%M:%S', time.localtime(msg.time))}: {msg.processed_plain_text}\n"
|
||||
else:
|
||||
msg_seg_str = f"{msg_seg_str}\n"
|
||||
all_msg_seg_list.append(msg_seg_str)
|
||||
|
||||
if speaker == bot_id:
|
||||
msg_seg_str = "你的发言:\n"
|
||||
else:
|
||||
msg_seg_str = "对方的发言:\n"
|
||||
|
||||
msg_seg_str += f"{time.strftime('%H:%M:%S', time.localtime(msg.get('time')))}: {msg.get('processed_plain_text')}\n"
|
||||
last_speaking_user_id = speaker
|
||||
|
||||
all_msg_seg_list.append(msg_seg_str)
|
||||
for msg in all_msg_seg_list:
|
||||
core_msg_str += msg
|
||||
|
||||
all_dialogue_history = get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=chat_stream.stream_id,
|
||||
timestamp=time.time(),
|
||||
limit=20,
|
||||
)
|
||||
|
||||
all_dialogue_prompt_str = build_readable_messages(
|
||||
all_dialogue_history,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
show_pic=False,
|
||||
)
|
||||
|
||||
return core_msg_str, background_dialogue_prompt, all_dialogue_prompt_str
|
||||
|
||||
def build_gift_info(self, message: MessageRecvS4U):
|
||||
if message.is_gift:
|
||||
return f"这是一条礼物信息,{message.gift_name} x{message.gift_count},请注意这位用户"
|
||||
else:
|
||||
if message.is_fake_gift:
|
||||
return f"{message.processed_plain_text}(注意:这是一条普通弹幕信息,对方没有真的发送礼物,不是礼物信息,注意区分,如果对方在发假的礼物骗你,请反击)"
|
||||
|
||||
return ""
|
||||
|
||||
def build_sc_info(self, message: MessageRecvS4U):
|
||||
super_chat_manager = get_super_chat_manager()
|
||||
return super_chat_manager.build_superchat_summary_string(message.chat_stream.stream_id)
|
||||
|
||||
async def build_prompt_normal(
|
||||
self,
|
||||
message: MessageRecvS4U,
|
||||
message_txt: str,
|
||||
) -> str:
|
||||
chat_stream = message.chat_stream
|
||||
|
||||
person = Person(platform=message.chat_stream.user_info.platform, user_id=message.chat_stream.user_info.user_id)
|
||||
person_name = person.person_name
|
||||
|
||||
if message.chat_stream.user_info.user_nickname:
|
||||
if person_name:
|
||||
sender_name = f"[{message.chat_stream.user_info.user_nickname}](你叫ta{person_name})"
|
||||
else:
|
||||
sender_name = f"[{message.chat_stream.user_info.user_nickname}]"
|
||||
else:
|
||||
sender_name = f"用户({message.chat_stream.user_info.user_id})"
|
||||
|
||||
relation_info_block, memory_block, expression_habits_block = await asyncio.gather(
|
||||
self.build_relation_info(chat_stream),
|
||||
self.build_memory_block(message_txt),
|
||||
self.build_expression_habits(chat_stream, message_txt, sender_name),
|
||||
)
|
||||
|
||||
core_dialogue_prompt, background_dialogue_prompt, all_dialogue_prompt = self.build_chat_history_prompts(
|
||||
chat_stream, message
|
||||
)
|
||||
|
||||
gift_info = self.build_gift_info(message)
|
||||
|
||||
sc_info = self.build_sc_info(message)
|
||||
|
||||
screen_info = screen_manager.get_screen_str()
|
||||
|
||||
internal_state = internal_manager.get_internal_state_str()
|
||||
|
||||
time_block = f"当前时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
|
||||
|
||||
mood = mood_manager.get_mood_by_chat_id(chat_stream.stream_id)
|
||||
|
||||
template_name = "s4u_prompt"
|
||||
|
||||
if not message.is_internal:
|
||||
prompt = await global_prompt_manager.format_prompt(
|
||||
template_name,
|
||||
time_block=time_block,
|
||||
expression_habits_block=expression_habits_block,
|
||||
relation_info_block=relation_info_block,
|
||||
memory_block=memory_block,
|
||||
screen_info=screen_info,
|
||||
internal_state=internal_state,
|
||||
gift_info=gift_info,
|
||||
sc_info=sc_info,
|
||||
sender_name=sender_name,
|
||||
core_dialogue_prompt=core_dialogue_prompt,
|
||||
background_dialogue_prompt=background_dialogue_prompt,
|
||||
message_txt=message_txt,
|
||||
mood_state=mood.mood_state,
|
||||
)
|
||||
else:
|
||||
prompt = await global_prompt_manager.format_prompt(
|
||||
"s4u_prompt_internal",
|
||||
time_block=time_block,
|
||||
expression_habits_block=expression_habits_block,
|
||||
relation_info_block=relation_info_block,
|
||||
memory_block=memory_block,
|
||||
screen_info=screen_info,
|
||||
gift_info=gift_info,
|
||||
sc_info=sc_info,
|
||||
chat_info_danmu=all_dialogue_prompt,
|
||||
chat_info_qq=message.chat_info,
|
||||
mind=message.processed_plain_text,
|
||||
mood_state=mood.mood_state,
|
||||
)
|
||||
|
||||
# print(prompt)
|
||||
|
||||
return prompt
|
||||
|
||||
|
||||
def weighted_sample_no_replacement(items, weights, k) -> list:
|
||||
"""
|
||||
加权且不放回地随机抽取k个元素。
|
||||
|
||||
参数:
|
||||
items: 待抽取的元素列表
|
||||
weights: 每个元素对应的权重(与items等长,且为正数)
|
||||
k: 需要抽取的元素个数
|
||||
返回:
|
||||
selected: 按权重加权且不重复抽取的k个元素组成的列表
|
||||
|
||||
如果 items 中的元素不足 k 个,就只会返回所有可用的元素
|
||||
|
||||
实现思路:
|
||||
每次从当前池中按权重加权随机选出一个元素,选中后将其从池中移除,重复k次。
|
||||
这样保证了:
|
||||
1. count越大被选中概率越高
|
||||
2. 不会重复选中同一个元素
|
||||
"""
|
||||
selected = []
|
||||
pool = list(zip(items, weights, strict=False))
|
||||
for _ in range(min(k, len(pool))):
|
||||
total = sum(w for _, w in pool)
|
||||
r = random.uniform(0, total)
|
||||
upto = 0
|
||||
for idx, (item, weight) in enumerate(pool):
|
||||
upto += weight
|
||||
if upto >= r:
|
||||
selected.append(item)
|
||||
pool.pop(idx)
|
||||
break
|
||||
return selected
|
||||
|
||||
|
||||
init_prompt()
|
||||
prompt_builder = PromptBuilder()
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user