Merge pull request #1187 from MaiM-with-u/v0.10.0

V0.10.0 更新
This commit is contained in:
SengokuCola
2025-08-18 14:59:25 +08:00
committed by GitHub
162 changed files with 11382 additions and 12414 deletions

7
.gitignore vendored
View File

@@ -7,6 +7,7 @@ logs/
out/ out/
tool_call_benchmark.py tool_call_benchmark.py
run_maibot_core.bat run_maibot_core.bat
run_voice.bat
run_napcat_adapter.bat run_napcat_adapter.bat
run_ad.bat run_ad.bat
s4u.s4u s4u.s4u
@@ -40,7 +41,10 @@ config/bot_config.toml
config/bot_config.toml.bak config/bot_config.toml.bak
config/lpmm_config.toml config/lpmm_config.toml
config/lpmm_config.toml.bak config/lpmm_config.toml.bak
src/mais4u/config/s4u_config.toml
src/mais4u/config/old
template/compare/bot_config_template.toml template/compare/bot_config_template.toml
template/compare/model_config_template.toml
(测试版)麦麦生成人格.bat (测试版)麦麦生成人格.bat
(临时版)麦麦开始学习.bat (临时版)麦麦开始学习.bat
src/plugins/utils/statistic.py src/plugins/utils/statistic.py
@@ -321,4 +325,5 @@ run_pet.bat
config.toml config.toml
interested_rates.txt interested_rates.txt
MaiBot.code-workspace

View File

@@ -25,8 +25,8 @@
**🍔MaiCore 是一个基于大语言模型的可交互智能体** **🍔MaiCore 是一个基于大语言模型的可交互智能体**
- 💭 **智能对话系统**:基于 LLM 的自然语言交互,支持normal和focus统一化处理 - 💭 **智能对话系统**:基于 LLM 的自然语言交互,聊天时机控制
- 🔌 **强大插件系统**:全面重构的插件架构,支持完整的管理API和权限控制 - 🔌 **强大插件系统**:全面重构的插件架构,更多API
- 🤔 **实时思维系统**:模拟人类思考过程。 - 🤔 **实时思维系统**:模拟人类思考过程。
- 🧠 **表达学习功能**:学习群友的说话风格和表达方式 - 🧠 **表达学习功能**:学习群友的说话风格和表达方式
- 💝 **情感表达系统**:情绪系统和表情包系统。 - 💝 **情感表达系统**:情绪系统和表情包系统。
@@ -46,7 +46,7 @@
## 🔥 更新和安装 ## 🔥 更新和安装
**最新版本: v0.9.1** ([更新日志](changelogs/changelog.md)) **最新版本: v0.10.0** ([更新日志](changelogs/changelog.md))
可前往 [Release](https://github.com/MaiM-with-u/MaiBot/releases/) 页面下载最新版本 可前往 [Release](https://github.com/MaiM-with-u/MaiBot/releases/) 页面下载最新版本
可前往 [启动器发布页面](https://github.com/MaiM-with-u/mailauncher/releases/)下载最新启动器 可前往 [启动器发布页面](https://github.com/MaiM-with-u/mailauncher/releases/)下载最新启动器
@@ -56,7 +56,6 @@
- `classical`: 旧版本(停止维护) - `classical`: 旧版本(停止维护)
### 最新版本部署教程 ### 最新版本部署教程
- [从0.6/0.7升级须知](https://docs.mai-mai.org/faq/maibot/update_to_07.html)
- [🚀 最新版本部署教程](https://docs.mai-mai.org/manual/deployment/mmc_deploy_windows.html) - 基于 MaiCore 的新版本部署方式(与旧版本不兼容) - [🚀 最新版本部署教程](https://docs.mai-mai.org/manual/deployment/mmc_deploy_windows.html) - 基于 MaiCore 的新版本部署方式(与旧版本不兼容)
> [!WARNING] > [!WARNING]
@@ -64,7 +63,6 @@
> - 项目处于活跃开发阶段,功能和 API 可能随时调整。 > - 项目处于活跃开发阶段,功能和 API 可能随时调整。
> - 文档未完善,有问题可以提交 Issue 或者 Discussion。 > - 文档未完善,有问题可以提交 Issue 或者 Discussion。
> - QQ 机器人存在被限制风险,请自行了解,谨慎使用。 > - QQ 机器人存在被限制风险,请自行了解,谨慎使用。
> - 由于持续迭代,可能存在一些已知或未知的 bug。
> - 由于程序处于开发中,可能消耗较多 token。 > - 由于程序处于开发中,可能消耗较多 token。
## 💬 讨论 ## 💬 讨论

41
bot.py
View File

@@ -20,11 +20,13 @@ from rich.traceback import install
# 最早期初始化日志系统,确保所有后续模块都使用正确的日志格式 # 最早期初始化日志系统,确保所有后续模块都使用正确的日志格式
from src.common.logger import initialize_logging, get_logger, shutdown_logging from src.common.logger import initialize_logging, get_logger, shutdown_logging
from src.main import MainSystem
from src.manager.async_task_manager import async_task_manager
initialize_logging() initialize_logging()
from src.main import MainSystem #noqa
from src.manager.async_task_manager import async_task_manager #noqa
logger = get_logger("main") logger = get_logger("main")
@@ -74,36 +76,6 @@ def easter_egg():
print(rainbow_text) print(rainbow_text)
def scan_provider(env_config: dict):
provider = {}
# 利用未初始化 env 时获取的 env_mask 来对新的环境变量集去重
# 避免 GPG_KEY 这样的变量干扰检查
env_config = dict(filter(lambda item: item[0] not in env_mask, env_config.items()))
# 遍历 env_config 的所有键
for key in env_config:
# 检查键是否符合 {provider}_BASE_URL 或 {provider}_KEY 的格式
if key.endswith("_BASE_URL") or key.endswith("_KEY"):
# 提取 provider 名称
provider_name = key.split("_", 1)[0] # 从左分割一次,取第一部分
# 初始化 provider 的字典(如果尚未初始化)
if provider_name not in provider:
provider[provider_name] = {"url": None, "key": None}
# 根据键的类型填充 url 或 key
if key.endswith("_BASE_URL"):
provider[provider_name]["url"] = env_config[key]
elif key.endswith("_KEY"):
provider[provider_name]["key"] = env_config[key]
# 检查每个 provider 是否同时存在 url 和 key
for provider_name, config in provider.items():
if config["url"] is None or config["key"] is None:
logger.error(f"provider 内容:{config}\nenv_config 内容:{env_config}")
raise ValueError(f"请检查 '{provider_name}' 提供商配置是否丢失 BASE_URL 或 KEY 环境变量")
async def graceful_shutdown(): async def graceful_shutdown():
try: try:
@@ -229,9 +201,6 @@ def raw_main():
easter_egg() easter_egg()
env_config = {key: os.getenv(key) for key in os.environ}
scan_provider(env_config)
# 返回MainSystem实例 # 返回MainSystem实例
return MainSystem() return MainSystem()

View File

@@ -1,5 +1,76 @@
# Changelog # Changelog
## [0.10.0] - 2025-7-1
### 🌟 主要功能更改
- 优化的回复生成,现在的回复对上下文把控更加精准
- 新的回复逻辑控制,现在合并了normal和focus模式,更加统一
- 优化表达方式系统,现在学习和使用更加精准
- 新的关系系统,现在的关系构建更精准也更克制
- 工具系统重构,现在合并到了插件系统中
- 彻底重构了整个LLM Request,现在支持模型轮询和更多灵活的参数
- 同时重构了整个模型配置系统,升级需要重新配置llm配置文件
- 随着LLM Request的重构,插件系统彻底重构完成。插件系统进入稳定状态,仅增加新的API
- 具体相比于之前的更改可以查看[changes.md](./changes.md)
#### 🔧 工具系统重构
- **工具系统整合**: 工具系统现在完全合并到插件系统中,提供统一的扩展能力
- **工具启用控制**: 支持配置是否启用特定工具,提供更人性化的直接调用方式
- **配置文件读取**: 工具现在支持读取配置文件,增强配置灵活性
#### 🚀 LLM系统全面重构
- **LLM Request重构**: 彻底重构了整个LLM Request系统,现在支持模型轮询和更多灵活的参数
- **模型配置升级**: 同时重构了整个模型配置系统,升级需要重新配置llm配置文件
- **任务类型支持**: 新增任务类型和能力字段至模型配置,增强模型初始化逻辑
- **异常处理增强**: 增强LLMRequest类的异常处理,添加统一的模型异常处理方法
#### 🔌 插件系统稳定化
- **插件系统重构完成**: 随着LLM Request的重构,插件系统彻底重构完成,进入稳定状态
- **API扩展**: 仅增加新的API,保持向后兼容性
- **插件管理优化**: 让插件管理配置真正有用,提升管理体验
#### 💾 记忆系统优化
- **及时构建**: 记忆系统再优化,现在及时构建,并且不会重复构建
- **精确提取**: 记忆提取更精确,提升记忆质量
#### 🎭 表达方式系统
- **表达方式记录**: 记录使用的表达方式,提供更好的学习追踪
- **学习优化**: 优化表达方式提取,修复表达学习出错问题
- **配置优化**: 优化表达方式配置和逻辑,提升系统稳定性
#### 🔄 聊天系统统一
- **normal和focus合并**: 彻底合并normal和focus,完全基于planner决定target message
- **no_reply内置**: 将no_reply功能移动到主循环中,简化系统架构
- **回复优化**: 优化reply,填补缺失值,让麦麦可以回复自己的消息
- **频率控制API**: 加入聊天频率控制相关API,提供更精细的控制
#### 日志系统改进
- **日志颜色优化**: 修改了log的颜色,更加护眼
- **日志清理优化**: 修复了日志清理先等24h的问题,提升系统性能
- **计时定位**: 通过计时定位LLM异常延时,提升问题排查效率
### 🐛 问题修复
#### 代码质量提升
- **lint问题修复**: 修复了lint爆炸的问题代码更加规范了
- **导入优化**: 修复导入爆炸和文档错误,优化代码结构
#### 系统稳定性
- **循环导入**: 修复了import时循环导入的问题
- **并行动作**: 修复并行动作炸裂问题,提升并发处理能力
- **空响应处理**: 空响应就raise避免系统异常
#### 功能修复
- **API问题**: 修复api问题,提升系统可用性
- **notice问题**: 为组件方法提供新参数,暂时解决notice问题
- **关系构建**: 修复不认识的用户构建关系问题
- **流式解析**: 修复流式解析越界问题,避免空choices的SSE帧错误
#### 配置和兼容性
- **默认值**: 添加默认值,提升配置灵活性
- **类型问题**: 修复类型问题,提升代码健壮性
- **配置加载**: 优化配置加载逻辑,提升系统启动稳定性
## [0.9.1] - 2025-7-26 ## [0.9.1] - 2025-7-26
### 主要修复和优化 ### 主要修复和优化
@@ -81,7 +152,7 @@ MaiBot 0.9.0 重磅升级!本版本带来两大核心突破:**全面重构
#### 问题修复与优化 #### 问题修复与优化
- 修复normal planner没有超时退出问题添加回复超时检查 - 修复normal planner没有超时退出问题添加回复超时检查
- 重构no_reply逻辑,不再使用小模型,采用激活度决定 - 重构no_action逻辑,不再使用小模型,采用激活度决定
- 修复图片与文字混合兴趣值为0的情况 - 修复图片与文字混合兴趣值为0的情况
- 适配无兴趣度消息处理 - 适配无兴趣度消息处理
- 优化Docker镜像构建流程合并AMD64和ARM64构建步骤 - 优化Docker镜像构建流程合并AMD64和ARM64构建步骤
@@ -149,7 +220,7 @@ MMC启动速度加快
- 移除冗余处理器 - 移除冗余处理器
- 精简处理器上下文,减少不必要的处理 - 精简处理器上下文,减少不必要的处理
- 后置工具处理器大大减少token消耗 - 后置工具处理器大大减少token消耗
- **统计系统**: 提供focus统计功能可查看详细的no_reply统计信息 - **统计系统**: 提供focus统计功能可查看详细的no_action统计信息
### ⏰ 聊天频率精细控制 ### ⏰ 聊天频率精细控制

View File

@@ -25,6 +25,7 @@
- 这意味着你终于可以动态控制是否继续后续消息的处理了。 - 这意味着你终于可以动态控制是否继续后续消息的处理了。
8. 移除了dependency_manager但是依然保留了`python_dependencies`属性,等待后续重构。 8. 移除了dependency_manager但是依然保留了`python_dependencies`属性,等待后续重构。
- 一并移除了文档有关manager的内容。 - 一并移除了文档有关manager的内容。
9. 增加了工具的有关api
# 插件系统修改 # 插件系统修改
1. 现在所有的匹配模式不再是关键字了,而是枚举类。**(可能有遗漏)** 1. 现在所有的匹配模式不再是关键字了,而是枚举类。**(可能有遗漏)**
@@ -57,30 +58,12 @@
15. 实现了组件的局部禁用,也就是针对某一个聊天禁用的功能。 15. 实现了组件的局部禁用,也就是针对某一个聊天禁用的功能。
- 通过`disable_specific_chat_action``enable_specific_chat_action``disable_specific_chat_command``enable_specific_chat_command``disable_specific_chat_event_handler``enable_specific_chat_event_handler`来操作 - 通过`disable_specific_chat_action``enable_specific_chat_action``disable_specific_chat_command``enable_specific_chat_command``disable_specific_chat_event_handler``enable_specific_chat_event_handler`来操作
- 同样不保存到配置文件~ - 同样不保存到配置文件~
16. 把`BaseTool`一并合并进入了插件系统
# 官方插件修改 # 官方插件修改
1. `HelloWorld`插件现在有一个样例的`EventHandler` 1. `HelloWorld`插件现在有一个样例的`EventHandler`
2. 内置插件增加了一个通过`Command`来管理插件的功能。具体是使用`/pm`命令唤起。 2. 内置插件增加了一个通过`Command`来管理插件的功能。具体是使用`/pm`命令唤起。(需要自行启用)
3. `HelloWorld`插件现在有一个样例的`CompareNumbersTool`
### TODO
把这个看起来就很别扭的config获取方式改一下
# 吐槽
```python
plugin_path = Path(plugin_file)
if plugin_path.parent.name != "plugins":
# 插件包格式parent_dir.plugin
module_name = f"plugins.{plugin_path.parent.name}.plugin"
else:
# 单文件格式plugins.filename
module_name = f"plugins.{plugin_path.stem}"
```
```python
plugin_path = Path(plugin_file)
module_name = ".".join(plugin_path.parent.parts)
```
这两个区别很大的。
### 执笔BGM ### 执笔BGM
塞壬唱片! 塞壬唱片!

BIN
docs/image-1.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 21 KiB

BIN
docs/image.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 4.9 KiB

View File

@@ -0,0 +1,325 @@
# 模型配置指南
本文档将指导您如何配置 `model_config.toml` 文件,该文件用于配置 MaiBot 的各种AI模型和API服务提供商。
## 配置文件结构
配置文件主要包含以下几个部分:
- 版本信息
- API服务提供商配置
- 模型配置
- 模型任务配置
## 1. 版本信息
```toml
[inner]
version = "1.1.1"
```
用于标识配置文件的版本,遵循语义化版本规则。
## 2. API服务提供商配置
### 2.1 基本配置
使用 `[[api_providers]]` 数组配置多个API服务提供商
```toml
[[api_providers]]
name = "DeepSeek" # 服务商名称(自定义)
base_url = "https://api.deepseek.com/v1" # API服务的基础URL
api_key = "your-api-key-here" # API密钥
client_type = "openai" # 客户端类型
max_retry = 2 # 最大重试次数
timeout = 30 # 超时时间(秒)
retry_interval = 10 # 重试间隔(秒)
```
### 2.2 配置参数说明
| 参数 | 必填 | 说明 | 默认值 |
|------|------|------|--------|
| `name` | ✅ | 服务商名称,需要在模型配置中引用 | - |
| `base_url` | ✅ | API服务的基础URL | - |
| `api_key` | ✅ | API密钥,请替换为实际密钥 | - |
| `client_type` | ❌ | 客户端类型:`openai`(OpenAI格式)、`gemini`(Gemini格式,目前支持尚不完善) | `openai` |
| `max_retry` | ❌ | API调用失败时的最大重试次数 | 2 |
| `timeout` | ❌ | API请求超时时间(秒) | 30 |
| `retry_interval` | ❌ | 重试间隔时间(秒) | 10 |
**请注意,对于`client_type`为`gemini`的模型,`base_url`字段无效。**
### 2.3 支持的服务商示例
#### DeepSeek
```toml
[[api_providers]]
name = "DeepSeek"
base_url = "https://api.deepseek.com/v1"
api_key = "your-deepseek-api-key"
client_type = "openai"
```
#### SiliconFlow
```toml
[[api_providers]]
name = "SiliconFlow"
base_url = "https://api.siliconflow.cn/v1"
api_key = "your-siliconflow-api-key"
client_type = "openai"
```
#### Google Gemini
```toml
[[api_providers]]
name = "Google"
base_url = "https://api.google.com/v1"
api_key = "your-google-api-key"
client_type = "gemini" # 注意:Gemini需要使用特殊客户端
```
## 3. 模型配置
### 3.1 基本模型配置
使用 `[[models]]` 数组配置多个模型:
```toml
[[models]]
model_identifier = "deepseek-chat" # 模型在API服务商中的标识符
name = "deepseek-v3" # 自定义模型名称
api_provider = "DeepSeek" # 引用的API服务商名称
price_in = 2.0 # 输入价格(元/M token)
price_out = 8.0 # 输出价格(元/M token)
```
### 3.2 高级模型配置
#### 强制流式输出
对于不支持非流式输出的模型:
```toml
[[models]]
model_identifier = "some-model"
name = "custom-name"
api_provider = "Provider"
force_stream_mode = true # 启用强制流式输出
```
#### 额外参数配置`extra_params`
```toml
[[models]]
model_identifier = "Qwen/Qwen3-8B"
name = "qwen3-8b"
api_provider = "SiliconFlow"
[models.extra_params]
enable_thinking = false # 禁用思考
```
这里的 `extra_params` 可以包含任何API服务商支持的额外参数配置,**配置时应参考相应的API文档**。
比如上面就是参考SiliconFlow的文档配置的`Qwen3`禁用思考参数。
![SiliconFlow文档截图](image-1.png)
以豆包文档为另一个例子
![豆包文档截图](image.png)
得到豆包`"doubao-seed-1-6-250615"`的禁用思考配置方法为
```toml
[[models]]
# 你的模型
[models.extra_params]
thinking = {type = "disabled"} # 禁用思考
```
请注意,`extra_params` 的配置应该构成一个合法的TOML字典结构,具体内容取决于API服务商的要求。
**请注意,对于`client_type`为`gemini`的模型,此字段无效。**
### 3.3 配置参数说明
| 参数 | 必填 | 说明 |
|------|------|------|
| `model_identifier` | ✅ | API服务商提供的模型标识符 |
| `name` | ✅ | 自定义模型名称,用于在任务配置中引用 |
| `api_provider` | ✅ | 对应的API服务商名称 |
| `price_in` | ❌ | 输入价格(元/M token),用于成本统计 |
| `price_out` | ❌ | 输出价格(元/M token),用于成本统计 |
| `force_stream_mode` | ❌ | 是否强制使用流式输出 |
| `extra_params` | ❌ | 额外的模型参数配置 |
## 4. 模型任务配置
### utils - 工具模型
用于表情包模块、取名模块、关系模块等核心功能:
```toml
[model_task_config.utils]
model_list = ["siliconflow-deepseek-v3"]
temperature = 0.2
max_tokens = 800
```
### utils_small - 小型工具模型
用于高频率调用的场景,建议使用速度快的小模型:
```toml
[model_task_config.utils_small]
model_list = ["qwen3-8b"]
temperature = 0.7
max_tokens = 800
```
### replyer - 主要回复模型
首要回复模型,也用于表达器和表达方式学习:
```toml
[model_task_config.replyer]
model_list = ["siliconflow-deepseek-v3"]
temperature = 0.2
max_tokens = 800
```
### planner - 决策模型
负责决定MaiBot该做什么
```toml
[model_task_config.planner]
model_list = ["siliconflow-deepseek-v3"]
temperature = 0.3
max_tokens = 800
```
### emotion - 情绪模型
负责MaiBot的情绪变化
```toml
[model_task_config.emotion]
model_list = ["siliconflow-deepseek-v3"]
temperature = 0.3
max_tokens = 800
```
### memory - 记忆模型
```toml
[model_task_config.memory]
model_list = ["qwen3-30b"]
temperature = 0.7
max_tokens = 800
```
### vlm - 视觉语言模型
用于图像识别:
```toml
[model_task_config.vlm]
model_list = ["qwen2.5-vl-72b"]
max_tokens = 800
```
### voice - 语音识别模型
```toml
[model_task_config.voice]
model_list = ["sensevoice-small"]
```
### embedding - 嵌入模型
```toml
[model_task_config.embedding]
model_list = ["bge-m3"]
```
### tool_use - 工具调用模型
需要使用支持工具调用的模型:
```toml
[model_task_config.tool_use]
model_list = ["qwen3-14b"]
temperature = 0.7
max_tokens = 800
```
### lpmm_entity_extract - 实体提取模型
```toml
[model_task_config.lpmm_entity_extract]
model_list = ["siliconflow-deepseek-v3"]
temperature = 0.2
max_tokens = 800
```
### lpmm_rdf_build - RDF构建模型
```toml
[model_task_config.lpmm_rdf_build]
model_list = ["siliconflow-deepseek-v3"]
temperature = 0.2
max_tokens = 800
```
### lpmm_qa - 问答模型
```toml
[model_task_config.lpmm_qa]
model_list = ["deepseek-r1-distill-qwen-32b"]
temperature = 0.7
max_tokens = 800
```
## 5. 配置建议
### 5.1 Temperature 参数选择
| 任务类型 | 推荐温度 | 说明 |
|----------|----------|------|
| 精确任务(工具调用、实体提取) | 0.1-0.3 | 需要准确性和一致性 |
| 创意任务(对话、记忆) | 0.5-0.8 | 需要多样性和创造性 |
| 平衡任务(决策、情绪) | 0.3-0.5 | 平衡准确性和灵活性 |
### 5.2 模型选择建议
| 任务类型 | 推荐模型类型 | 示例 |
|----------|--------------|------|
| 高精度任务 | 大模型 | DeepSeek-V3, GPT-4 |
| 高频率任务 | 小模型 | Qwen3-8B |
| 多模态任务 | 专用模型 | Qwen2.5-VL, SenseVoice |
| 工具调用 | 支持Function Call的模型 | Qwen3-14B |
### 5.3 成本优化
1. **分层使用**:核心功能使用高质量模型,辅助功能使用经济模型
2. **合理配置max_tokens**:根据实际需求设置,避免浪费
3. **选择免费模型**:对于测试环境,优先使用price为0的模型
## 6. 配置验证
### 6.1 必要检查项
1. ✅ API密钥是否正确配置
2. ✅ 模型标识符是否与API服务商提供的一致
3. ✅ 任务配置中引用的模型名称是否在models中定义
4. ✅ 多模态任务是否配置了对应的专用模型
### 6.2 测试配置
建议在正式使用前:
1. 使用少量测试数据验证配置
2. 检查API调用是否正常
3. 确认成本统计功能正常工作
## 7. 故障排除
### 7.1 常见问题
**问题1**: API调用失败
- 检查API密钥是否正确
- 确认base_url是否可访问
- 检查模型标识符是否正确
**问题2**: 模型未找到
- 确认模型名称在任务配置和模型定义中一致
- 检查api_provider名称是否匹配
**问题3**: 响应异常
- 检查温度参数是否合理(0-1之间)
- 确认max_tokens设置是否合适
- 验证模型是否支持所需功能
### 7.2 日志查看
查看 `logs/` 目录下的日志文件,寻找相关错误信息。
## 8. 更新和维护
1. **定期更新**: 关注API服务商的模型更新,及时调整配置
2. **性能监控**: 监控模型调用的成本和性能
3. **备份配置**: 在修改前备份当前配置文件

View File

@@ -22,7 +22,6 @@ class ExampleAction(BaseAction):
action_name = "example_action" # 动作的唯一标识符 action_name = "example_action" # 动作的唯一标识符
action_description = "这是一个示例动作" # 动作描述 action_description = "这是一个示例动作" # 动作描述
activation_type = ActionActivationType.ALWAYS # 这里以 ALWAYS 为例 activation_type = ActionActivationType.ALWAYS # 这里以 ALWAYS 为例
mode_enable = ChatMode.ALL # 一般取ALL表示在所有聊天模式下都可用
associated_types = ["text", "emoji", ...] # 关联类型 associated_types = ["text", "emoji", ...] # 关联类型
parallel_action = False # 是否允许与其他Action并行执行 parallel_action = False # 是否允许与其他Action并行执行
action_parameters = {"param1": "参数1的说明", "param2": "参数2的说明", ...} action_parameters = {"param1": "参数1的说明", "param2": "参数2的说明", ...}

View File

@@ -0,0 +1,194 @@
# 组件管理API
组件管理API模块提供了对插件组件的查询和管理功能,使得插件能够获取和使用组件相关的信息。
## 导入方式
```python
from src.plugin_system.apis import component_manage_api
# 或者
from src.plugin_system import component_manage_api
```
## 功能概述
组件管理API主要提供以下功能
- **插件信息查询** - 获取所有插件或指定插件的信息。
- **组件查询** - 按名称或类型查询组件信息。
- **组件管理** - 启用或禁用组件,支持全局和局部操作。
## 主要功能
### 1. 获取所有插件信息
```python
def get_all_plugin_info() -> Dict[str, PluginInfo]:
```
获取所有插件的信息。
**Returns:**
- `Dict[str, PluginInfo]` - 包含所有插件信息的字典,键为插件名称,值为 `PluginInfo` 对象。
### 2. 获取指定插件信息
```python
def get_plugin_info(plugin_name: str) -> Optional[PluginInfo]:
```
获取指定插件的信息。
**Args:**
- `plugin_name` (str): 插件名称。
**Returns:**
- `Optional[PluginInfo]`: 插件信息对象,如果插件不存在则返回 `None`
### 3. 获取指定组件信息
```python
def get_component_info(component_name: str, component_type: ComponentType) -> Optional[Union[CommandInfo, ActionInfo, EventHandlerInfo]]:
```
获取指定组件的信息。
**Args:**
- `component_name` (str): 组件名称。
- `component_type` (ComponentType): 组件类型。
**Returns:**
- `Optional[Union[CommandInfo, ActionInfo, EventHandlerInfo]]`: 组件信息对象,如果组件不存在则返回 `None`
### 4. 获取指定类型的所有组件信息
```python
def get_components_info_by_type(component_type: ComponentType) -> Dict[str, Union[CommandInfo, ActionInfo, EventHandlerInfo]]:
```
获取指定类型的所有组件信息。
**Args:**
- `component_type` (ComponentType): 组件类型。
**Returns:**
- `Dict[str, Union[CommandInfo, ActionInfo, EventHandlerInfo]]`: 包含指定类型组件信息的字典,键为组件名称,值为对应的组件信息对象。
### 5. 获取指定类型的所有启用的组件信息
```python
def get_enabled_components_info_by_type(component_type: ComponentType) -> Dict[str, Union[CommandInfo, ActionInfo, EventHandlerInfo]]:
```
获取指定类型的所有启用的组件信息。
**Args:**
- `component_type` (ComponentType): 组件类型。
**Returns:**
- `Dict[str, Union[CommandInfo, ActionInfo, EventHandlerInfo]]`: 包含指定类型启用组件信息的字典,键为组件名称,值为对应的组件信息对象。
### 6. 获取指定 Action 的注册信息
```python
def get_registered_action_info(action_name: str) -> Optional[ActionInfo]:
```
获取指定 Action 的注册信息。
**Args:**
- `action_name` (str): Action 名称。
**Returns:**
- `Optional[ActionInfo]` - Action 信息对象,如果 Action 不存在则返回 `None`
### 7. 获取指定 Command 的注册信息
```python
def get_registered_command_info(command_name: str) -> Optional[CommandInfo]:
```
获取指定 Command 的注册信息。
**Args:**
- `command_name` (str): Command 名称。
**Returns:**
- `Optional[CommandInfo]` - Command 信息对象,如果 Command 不存在则返回 `None`
### 8. 获取指定 Tool 的注册信息
```python
def get_registered_tool_info(tool_name: str) -> Optional[ToolInfo]:
```
获取指定 Tool 的注册信息。
**Args:**
- `tool_name` (str): Tool 名称。
**Returns:**
- `Optional[ToolInfo]` - Tool 信息对象,如果 Tool 不存在则返回 `None`
### 9. 获取指定 EventHandler 的注册信息
```python
def get_registered_event_handler_info(event_handler_name: str) -> Optional[EventHandlerInfo]:
```
获取指定 EventHandler 的注册信息。
**Args:**
- `event_handler_name` (str): EventHandler 名称。
**Returns:**
- `Optional[EventHandlerInfo]` - EventHandler 信息对象,如果 EventHandler 不存在则返回 `None`
### 10. 全局启用指定组件
```python
def globally_enable_component(component_name: str, component_type: ComponentType) -> bool:
```
全局启用指定组件。
**Args:**
- `component_name` (str): 组件名称。
- `component_type` (ComponentType): 组件类型。
**Returns:**
- `bool` - 启用成功返回 `True`,否则返回 `False`
### 11. 全局禁用指定组件
```python
async def globally_disable_component(component_name: str, component_type: ComponentType) -> bool:
```
全局禁用指定组件。
**此函数是异步的,确保在异步环境中调用。**
**Args:**
- `component_name` (str): 组件名称。
- `component_type` (ComponentType): 组件类型。
**Returns:**
- `bool` - 禁用成功返回 `True`,否则返回 `False`
### 12. 局部启用指定组件
```python
def locally_enable_component(component_name: str, component_type: ComponentType, stream_id: str) -> bool:
```
局部启用指定组件。
**Args:**
- `component_name` (str): 组件名称。
- `component_type` (ComponentType): 组件类型。
- `stream_id` (str): 消息流 ID。
**Returns:**
- `bool` - 启用成功返回 `True`,否则返回 `False`
### 13. 局部禁用指定组件
```python
def locally_disable_component(component_name: str, component_type: ComponentType, stream_id: str) -> bool:
```
局部禁用指定组件。
**Args:**
- `component_name` (str): 组件名称。
- `component_type` (ComponentType): 组件类型。
- `stream_id` (str): 消息流 ID。
**Returns:**
- `bool` - 禁用成功返回 `True`,否则返回 `False`
### 14. 获取指定消息流中禁用的组件列表
```python
def get_locally_disabled_components(stream_id: str, component_type: ComponentType) -> list[str]:
```
获取指定消息流中禁用的组件列表。
**Args:**
- `stream_id` (str): 消息流 ID。
- `component_type` (ComponentType): 组件类型。
**Returns:**
- `list[str]` - 禁用的组件名称列表。

View File

@@ -1,6 +1,6 @@
# 配置API # 配置API
配置API模块提供了配置读取和用户信息获取等功能,让插件能够安全地访问全局配置和用户信息 配置API模块提供了配置读取功能让插件能够安全地访问全局配置和插件配置
## 导入方式 ## 导入方式

View File

@@ -6,72 +6,51 @@
```python ```python
from src.plugin_system.apis import database_api from src.plugin_system.apis import database_api
# 或者
from src.plugin_system import database_api
``` ```
## 主要功能 ## 主要功能
### 1. 通用数据库查询 ### 1. 通用数据库操作
#### `db_query(model_class, query_type="get", filters=None, data=None, limit=None, order_by=None, single_result=False)`
执行数据库查询操作的通用接口
**参数:**
- `model_class`Peewee模型类如ActionRecords、Messages等
- `query_type`:查询类型,可选值: "get", "create", "update", "delete", "count"
- `filters`:过滤条件字典,键为字段名,值为要匹配的值
- `data`:用于创建或更新的数据字典
- `limit`:限制结果数量
- `order_by`:排序字段列表,使用字段名,前缀'-'表示降序
- `single_result`:是否只返回单个结果
**返回:**
根据查询类型返回不同的结果:
- "get":返回查询结果列表或单个结果
- "create":返回创建的记录
- "update":返回受影响的行数
- "delete":返回受影响的行数
- "count":返回记录数量
### 2. 便捷查询函数
#### `db_save(model_class, data, key_field=None, key_value=None)`
保存数据到数据库(创建或更新)
**参数:**
- `model_class`Peewee模型类
- `data`:要保存的数据字典
- `key_field`:用于查找现有记录的字段名
- `key_value`:用于查找现有记录的字段值
**返回:**
- `Dict[str, Any]`保存后的记录数据失败时返回None
#### `db_get(model_class, filters=None, order_by=None, limit=None)`
简化的查询函数
**参数:**
- `model_class`Peewee模型类
- `filters`:过滤条件字典
- `order_by`:排序字段
- `limit`:限制结果数量
**返回:**
- `Union[List[Dict], Dict, None]`:查询结果
### 3. 专用函数
#### `store_action_info(...)`
存储动作信息的专用函数
## 使用示例
### 1. 基本查询操作
```python ```python
from src.plugin_system.apis import database_api async def db_query(
from src.common.database.database_model import Messages, ActionRecords model_class: Type[Model],
data: Optional[Dict[str, Any]] = None,
query_type: Optional[str] = "get",
filters: Optional[Dict[str, Any]] = None,
limit: Optional[int] = None,
order_by: Optional[List[str]] = None,
single_result: Optional[bool] = False,
) -> Union[List[Dict[str, Any]], Dict[str, Any], None]:
```
执行数据库查询操作的通用接口。
# 查询最近10条消息 **Args:**
- `model_class`: Peewee模型类。
- Peewee模型类可以在`src.common.database.database_model`模块中找到,如`ActionRecords`、`Messages`等。
- `data`: 用于创建或更新的数据
- `query_type`: 查询类型
- 可选值: `get`, `create`, `update`, `delete`, `count`
- `filters`: 过滤条件字典,键为字段名,值为要匹配的值。
- `limit`: 限制结果数量。
- `order_by`: 排序字段列表,使用字段名,前缀'-'表示降序。
- 排序字段,前缀`-`表示降序,例如`-time`表示按时间字段(即`time`字段)降序
- `single_result`: 是否只返回单个结果。
**Returns:**
- 根据查询类型返回不同的结果:
- `get`: 返回查询结果列表,或单个结果(如果 `single_result=True`)。
- `create`: 返回创建的记录。
- `update`: 返回受影响的行数。
- `delete`: 返回受影响的行数。
- `count`: 返回记录数量。
#### 示例
1. 查询最近10条消息
```python
messages = await database_api.db_query( messages = await database_api.db_query(
Messages, Messages,
query_type="get", query_type="get",
@@ -79,180 +58,159 @@ messages = await database_api.db_query(
limit=10, limit=10,
order_by=["-time"] order_by=["-time"]
) )
# 查询单条记录
message = await database_api.db_query(
Messages,
query_type="get",
filters={"message_id": "msg_123"},
single_result=True
)
``` ```
2. 创建一条记录
### 2. 创建记录
```python ```python
# 创建新的动作记录
new_record = await database_api.db_query( new_record = await database_api.db_query(
ActionRecords, ActionRecords,
data={"action_id": "123", "time": time.time(), "action_name": "TestAction"},
query_type="create", query_type="create",
data={
"action_id": "action_123",
"time": time.time(),
"action_name": "TestAction",
"action_done": True
}
) )
print(f"创建了记录: {new_record['id']}")
``` ```
3. 更新记录
### 3. 更新记录
```python ```python
# 更新动作状态
updated_count = await database_api.db_query( updated_count = await database_api.db_query(
ActionRecords, ActionRecords,
data={"action_done": True},
query_type="update", query_type="update",
filters={"action_id": "action_123"}, filters={"action_id": "123"},
data={"action_done": True, "completion_time": time.time()}
) )
print(f"更新了 {updated_count} 条记录")
``` ```
4. 删除记录
### 4. 删除记录
```python ```python
# 删除过期记录
deleted_count = await database_api.db_query( deleted_count = await database_api.db_query(
ActionRecords, ActionRecords,
query_type="delete", query_type="delete",
filters={"time__lt": time.time() - 86400} # 删除24小时前的记录 filters={"action_id": "123"}
) )
print(f"删除了 {deleted_count} 条过期记录")
``` ```
5. 计数
### 5. 统计查询
```python ```python
# 统计消息数量 count = await database_api.db_query(
message_count = await database_api.db_query(
Messages, Messages,
query_type="count", query_type="count",
filters={"chat_id": chat_stream.stream_id} filters={"chat_id": chat_stream.stream_id}
) )
print(f"该聊天有 {message_count} 条消息")
``` ```
### 6. 使用便捷函 ### 2. 数据库保存
```python
async def db_save(
model_class: Type[Model], data: Dict[str, Any], key_field: Optional[str] = None, key_value: Optional[Any] = None
) -> Optional[Dict[str, Any]]:
```
保存数据到数据库(创建或更新)
如果提供了key_field和key_value,会先尝试查找匹配的记录进行更新;
如果没有找到匹配记录,或未提供key_field和key_value,则创建新记录。
**Args:**
- `model_class`: Peewee模型类。
- `data`: 要保存的数据字典。
- `key_field`: 用于查找现有记录的字段名,例如"action_id"。
- `key_value`: 用于查找现有记录的字段值。
**Returns:**
- `Optional[Dict[str, Any]]`: 保存后的记录数据,失败时返回None。
#### 示例
创建或更新一条记录
```python ```python
# 使用db_save进行创建或更新
record = await database_api.db_save( record = await database_api.db_save(
ActionRecords, ActionRecords,
{ {
"action_id": "action_123", "action_id": "123",
"time": time.time(), "time": time.time(),
"action_name": "TestAction", "action_name": "TestAction",
"action_done": True "action_done": True
}, },
key_field="action_id", key_field="action_id",
key_value="action_123" key_value="123"
) )
```
# 使用db_get进行简单查询 ### 3. 数据库获取
recent_messages = await database_api.db_get( ```python
async def db_get(
model_class: Type[Model],
filters: Optional[Dict[str, Any]] = None,
limit: Optional[int] = None,
order_by: Optional[str] = None,
single_result: Optional[bool] = False,
) -> Union[List[Dict[str, Any]], Dict[str, Any], None]:
```
从数据库获取记录
这是db_query方法的简化版本专注于数据检索操作。
**Args:**
- `model_class`: Peewee模型类。
- `filters`: 过滤条件字典,键为字段名,值为要匹配的值。
- `limit`: 限制结果数量。
- `order_by`: 排序字段,使用字段名,前缀'-'表示降序。
- `single_result`: 是否只返回单个结果。如果为True,则返回单个记录字典或None;否则返回记录字典列表或空列表。
**Returns:**
- `Union[List[Dict], Dict, None]`: 查询结果列表或单个结果(如果`single_result=True`),失败时返回None。
#### 示例
1. 获取单个记录
```python
record = await database_api.db_get(
ActionRecords,
filters={"action_id": "123"},
limit=1,
single_result=True
)
```
2. 获取最近10条记录
```python
records = await database_api.db_get(
Messages, Messages,
filters={"chat_id": chat_stream.stream_id}, filters={"chat_id": chat_stream.stream_id},
limit=10,
order_by="-time", order_by="-time",
limit=5
) )
``` ```
## 高级用法 ### 4. 动作信息存储
### 复杂查询示例
```python ```python
# 查询特定用户在特定时间段的消息 async def store_action_info(
user_messages = await database_api.db_query( chat_stream=None,
Messages, action_build_into_prompt: bool = False,
query_type="get", action_prompt_display: str = "",
filters={ action_done: bool = True,
"user_id": "123456", thinking_id: str = "",
"time__gte": start_time, # 大于等于开始时间 action_data: Optional[dict] = None,
"time__lt": end_time # 小于结束时间 action_name: str = "",
}, ) -> Optional[Dict[str, Any]]:
order_by=["-time"], ```
limit=50 存储动作信息到数据库,是一种针对 Action 的 `db_save()` 的封装函数。
将Action执行的相关信息保存到ActionRecords表中,用于后续的记忆和上下文构建。
**Args:**
- `chat_stream`: 聊天流对象包含聊天ID等信息。
- `action_build_into_prompt`: 是否将动作信息构建到提示中。
- `action_prompt_display`: 动作提示的显示文本。
- `action_done`: 动作是否完成。
- `thinking_id`: 思考过程的ID。
- `action_data`: 动作的数据字典。
- `action_name`: 动作的名称。
**Returns:**
- `Optional[Dict[str, Any]]`: 存储后的记录数据失败时返回None。
#### 示例
```python
record = await database_api.store_action_info(
chat_stream=chat_stream,
action_build_into_prompt=True,
action_prompt_display="执行了回复动作",
action_done=True,
thinking_id="thinking_123",
action_data={"content": "Hello"},
action_name="reply_action"
) )
```
# 批量处理
for message in user_messages:
print(f"消息内容: {message['plain_text']}")
print(f"发送时间: {message['time']}")
```
### 插件中的数据持久化
```python
from src.plugin_system.base import BasePlugin
from src.plugin_system.apis import database_api
class DataPlugin(BasePlugin):
async def handle_action(self, action_data, chat_stream):
# 保存插件数据
plugin_data = {
"plugin_name": self.plugin_name,
"chat_id": chat_stream.stream_id,
"data": json.dumps(action_data),
"created_time": time.time()
}
# 使用自定义表模型(需要先定义)
record = await database_api.db_save(
PluginData, # 假设的插件数据模型
plugin_data,
key_field="plugin_name",
key_value=self.plugin_name
)
return {"success": True, "record_id": record["id"]}
```
## 数据模型
### 常用模型类
系统提供了以下常用的数据模型:
- `Messages`:消息记录
- `ActionRecords`:动作记录
- `UserInfo`:用户信息
- `GroupInfo`:群组信息
### 字段说明
#### Messages模型主要字段
- `message_id`消息ID
- `chat_id`聊天ID
- `user_id`用户ID
- `plain_text`:纯文本内容
- `time`:时间戳
#### ActionRecords模型主要字段
- `action_id`动作ID
- `action_name`:动作名称
- `action_done`:是否完成
- `time`:创建时间
## 注意事项
1. **异步操作**所有数据库API都是异步的必须使用`await`
2. **错误处理**函数内置错误处理失败时返回None或空列表
3. **数据类型**:返回的都是字典格式的数据,不是模型对象
4. **性能考虑**:使用`limit`参数避免查询大量数据
5. **过滤条件**支持简单的等值过滤复杂查询需要使用原生Peewee语法
6. **事务**如需事务支持建议直接使用Peewee的事务功能

View File

@@ -6,11 +6,13 @@
```python ```python
from src.plugin_system.apis import emoji_api from src.plugin_system.apis import emoji_api
# 或者
from src.plugin_system import emoji_api
``` ```
## 🆕 **二步走识别优化** ## 二步走识别优化
新版本开始,表情包识别系统采用了**二步走识别 + 智能缓存**的优化方案: 从新版本开始,表情包识别系统采用了**二步走识别 + 智能缓存**的优化方案:
### **收到表情包时的识别流程** ### **收到表情包时的识别流程**
1. **第一步**VLM视觉分析 - 生成详细描述 1. **第一步**VLM视觉分析 - 生成详细描述
@@ -30,217 +32,84 @@ from src.plugin_system.apis import emoji_api
## 主要功能 ## 主要功能
### 1. 表情包获取 ### 1. 表情包获取
```python
#### `get_by_description(description: str) -> Optional[Tuple[str, str, str]]` async def get_by_description(description: str) -> Optional[Tuple[str, str, str]]:
```
根据场景描述选择表情包 根据场景描述选择表情包
**参数** **Args**
- `description`场景描述文本,例如"开心的大笑"、"轻微的讽刺"、"表示无奈和沮丧"等 - `description`表情包的描述文本,例如"开心"、"难过"、"愤怒"等
**返回** **Returns**
- `Optional[Tuple[str, str, str]]`(base64编码, 表情包描述, 匹配的场景) 或 None - `Optional[Tuple[str, str, str]]`一个元组: (表情包的base64编码, 描述, 情感标签),如果未找到匹配的表情包则返回None
**示例:** #### 示例
```python ```python
emoji_result = await emoji_api.get_by_description("开心的大笑") emoji_result = await emoji_api.get_by_description("大笑")
if emoji_result: if emoji_result:
emoji_base64, description, matched_scene = emoji_result emoji_base64, description, matched_scene = emoji_result
print(f"获取到表情包: {description}, 场景: {matched_scene}") print(f"获取到表情包: {description}, 场景: {matched_scene}")
# 可以将emoji_base64用于发送表情包 # 可以将emoji_base64用于发送表情包
``` ```
#### `get_random() -> Optional[Tuple[str, str, str]]` ### 2. 随机获取表情包
随机获取表情包
**返回:**
- `Optional[Tuple[str, str, str]]`(base64编码, 表情包描述, 随机场景) 或 None
**示例:**
```python ```python
random_emoji = await emoji_api.get_random() async def get_random(count: Optional[int] = 1) -> List[Tuple[str, str, str]]:
if random_emoji:
emoji_base64, description, scene = random_emoji
print(f"随机表情包: {description}")
``` ```
随机获取指定数量的表情包
#### `get_by_emotion(emotion: str) -> Optional[Tuple[str, str, str]]` **Args**
根据场景关键词获取表情包 - `count`:要获取表情包数量默认为1
**参数** **Returns**
- `emotion`:场景关键词,如"大笑"、"讽刺"、"无奈"等 - `List[Tuple[str, str, str]]`:一个包含多个表情包的列表,每个元素是一个元组: (表情包的base64编码, 描述, 情感标签),如果未找到或出错则返回空列表
**返回:** ### 3. 根据情感获取表情包
- `Optional[Tuple[str, str, str]]`(base64编码, 表情包描述, 匹配的场景) 或 None
**示例:**
```python ```python
emoji_result = await emoji_api.get_by_emotion("讽刺") async def get_by_emotion(emotion: str) -> Optional[Tuple[str, str, str]]:
if emoji_result:
emoji_base64, description, scene = emoji_result
# 发送讽刺表情包
``` ```
根据情感标签获取表情包
### 2. 表情包信息查询 **Args**
- `emotion`:情感标签,例如"开心"、"悲伤"、"愤怒"等
#### `get_count() -> int` **Returns**
获取表情包数量 - `Optional[Tuple[str, str, str]]`:一个元组: (表情包的base64编码, 描述, 情感标签)如果未找到则返回None
**返回:** ### 4. 获取表情包数量
- `int`:当前可用的表情包数量 ```python
def get_count() -> int:
```
获取当前可用表情包的数量
#### `get_info() -> dict` ### 5. 获取表情包系统信息
获取表情包系统信息 ```python
def get_info() -> Dict[str, Any]:
```
获取表情包系统的基本信息
**返回** **Returns**
- `dict`:包含表情包数量、最大数量等信息 - `Dict[str, Any]`:包含表情包数量、描述等信息的字典,包含以下键:
- `current_count`:当前表情包数量
- `max_count`:最大表情包数量
- `available_emojis`:当前可用的表情包数量
**返回字典包含:** ### 6. 获取所有可用的情感标签
- `current_count`:当前表情包数量 ```python
- `max_count`:最大表情包数量 def get_emotions() -> List[str]:
- `available_emojis`:可用表情包数量 ```
获取所有可用的情感标签 **(已经去重)**
#### `get_emotions() -> list` ### 7. 获取所有表情包描述
获取所有可用的场景关键词 ```python
def get_descriptions() -> List[str]:
**返回:** ```
- `list`:所有表情包的场景关键词列表(去重)
#### `get_descriptions() -> list`
获取所有表情包的描述列表 获取所有表情包的描述列表
**返回:**
- `list`:所有表情包的描述文本列表
## 使用示例
### 1. 智能表情包选择
```python
from src.plugin_system.apis import emoji_api
async def send_emotion_response(message_text: str, chat_stream):
"""根据消息内容智能选择表情包回复"""
# 分析消息场景
if "哈哈" in message_text or "好笑" in message_text:
emoji_result = await emoji_api.get_by_description("开心的大笑")
elif "无语" in message_text or "算了" in message_text:
emoji_result = await emoji_api.get_by_description("表示无奈和沮丧")
elif "呵呵" in message_text or "是吗" in message_text:
emoji_result = await emoji_api.get_by_description("轻微的讽刺")
elif "生气" in message_text or "愤怒" in message_text:
emoji_result = await emoji_api.get_by_description("愤怒和不满")
else:
# 随机选择一个表情包
emoji_result = await emoji_api.get_random()
if emoji_result:
emoji_base64, description, scene = emoji_result
# 使用send_api发送表情包
from src.plugin_system.apis import send_api
success = await send_api.emoji_to_group(emoji_base64, chat_stream.group_info.group_id)
return success
return False
```
### 2. 表情包管理功能
```python
async def show_emoji_stats():
"""显示表情包统计信息"""
# 获取基本信息
count = emoji_api.get_count()
info = emoji_api.get_info()
scenes = emoji_api.get_emotions() # 实际返回的是场景关键词
stats = f"""
📊 表情包统计信息:
- 总数量: {count}
- 可用数量: {info['available_emojis']}
- 最大容量: {info['max_count']}
- 支持场景: {len(scenes)}
🎭 支持的场景关键词: {', '.join(scenes[:10])}{'...' if len(scenes) > 10 else ''}
"""
return stats
```
### 3. 表情包测试功能
```python
async def test_emoji_system():
"""测试表情包系统的各种功能"""
print("=== 表情包系统测试 ===")
# 测试场景描述查找
test_descriptions = ["开心的大笑", "轻微的讽刺", "表示无奈和沮丧", "愤怒和不满"]
for desc in test_descriptions:
result = await emoji_api.get_by_description(desc)
if result:
_, description, scene = result
print(f"✅ 场景'{desc}' -> {description} ({scene})")
else:
print(f"❌ 场景'{desc}' -> 未找到")
# 测试关键词查找
scenes = emoji_api.get_emotions()
if scenes:
test_scene = scenes[0]
result = await emoji_api.get_by_emotion(test_scene)
if result:
print(f"✅ 关键词'{test_scene}' -> 找到匹配表情包")
# 测试随机获取
random_result = await emoji_api.get_random()
if random_result:
print("✅ 随机获取 -> 成功")
print(f"📊 系统信息: {emoji_api.get_info()}")
```
### 4. 在Action中使用表情包
```python
from src.plugin_system.base import BaseAction
class EmojiAction(BaseAction):
async def execute(self, action_data, chat_stream):
# 从action_data获取场景描述或关键词
scene_keyword = action_data.get("scene", "")
scene_description = action_data.get("description", "")
emoji_result = None
# 优先使用具体的场景描述
if scene_description:
emoji_result = await emoji_api.get_by_description(scene_description)
# 其次使用场景关键词
elif scene_keyword:
emoji_result = await emoji_api.get_by_emotion(scene_keyword)
# 最后随机选择
else:
emoji_result = await emoji_api.get_random()
if emoji_result:
emoji_base64, description, scene = emoji_result
return {
"success": True,
"emoji_base64": emoji_base64,
"description": description,
"scene": scene
}
return {"success": False, "message": "未找到合适的表情包"}
```
## 场景描述说明 ## 场景描述说明
### 常用场景描述 ### 常用场景描述
表情包系统支持多种具体的场景描述,常见的包括 表情包系统支持多种具体的场景描述,举例如下
- **开心类场景**:开心的大笑、满意的微笑、兴奋的手舞足蹈 - **开心类场景**:开心的大笑、满意的微笑、兴奋的手舞足蹈
- **无奈类场景**:表示无奈和沮丧、轻微的讽刺、无语的摇头 - **无奈类场景**:表示无奈和沮丧、轻微的讽刺、无语的摇头
@@ -248,8 +117,8 @@ class EmojiAction(BaseAction):
- **惊讶类场景**:震惊的表情、意外的发现、困惑的思考 - **惊讶类场景**:震惊的表情、意外的发现、困惑的思考
- **可爱类场景**:卖萌的表情、撒娇的动作、害羞的样子 - **可爱类场景**:卖萌的表情、撒娇的动作、害羞的样子
### 场景关键词示例 ### 情感关键词示例
系统支持的场景关键词包括 系统支持的情感关键词举例如下
- 大笑、微笑、兴奋、手舞足蹈 - 大笑、微笑、兴奋、手舞足蹈
- 无奈、沮丧、讽刺、无语、摇头 - 无奈、沮丧、讽刺、无语、摇头
- 愤怒、不满、生气、瞪视、抓狂 - 愤怒、不满、生气、瞪视、抓狂
@@ -263,9 +132,9 @@ class EmojiAction(BaseAction):
## 注意事项 ## 注意事项
1. **异步函数**获取表情包的函数是异步的,需要使用 `await` 1. **异步函数**部分函数是异步的,需要使用 `await`
2. **返回格式**表情包以base64编码返回可直接用于发送 2. **返回格式**表情包以base64编码返回可直接用于发送
3. **错误处理**：所有函数都有错误处理，失败时返回None或默认值 3. **错误处理**：所有函数都有错误处理，失败时返回None,空列表或默认值
4. **使用统计**:系统会记录表情包的使用次数 4. **使用统计**:系统会记录表情包的使用次数
5. **文件依赖**:表情包依赖于本地文件,确保表情包文件存在 5. **文件依赖**:表情包依赖于本地文件,确保表情包文件存在
6. **编码格式**返回的是base64编码的图片数据可直接用于网络传输 6. **编码格式**返回的是base64编码的图片数据可直接用于网络传输

View File

@@ -6,241 +6,151 @@
```python ```python
from src.plugin_system.apis import generator_api from src.plugin_system.apis import generator_api
# 或者
from src.plugin_system import generator_api
``` ```
## 主要功能 ## 主要功能
### 1. 回复器获取 ### 1. 回复器获取
```python
#### `get_replyer(chat_stream=None, platform=None, chat_id=None, is_group=True)` def get_replyer(
chat_stream: Optional[ChatStream] = None,
chat_id: Optional[str] = None,
model_set_with_weight: Optional[List[Tuple[TaskConfig, float]]] = None,
request_type: str = "replyer",
) -> Optional[DefaultReplyer]:
```
获取回复器对象 获取回复器对象
**参数:** 优先使用chat_stream，如果没有则使用chat_id直接查找。
- `chat_stream`:聊天流对象(优先)
- `platform`:平台名称,如"qq"
- `chat_id`聊天ID群ID或用户ID
- `is_group`:是否为群聊
**返回:** 使用 ReplyerManager 来管理实例,避免重复创建。
- `DefaultReplyer`回复器对象如果获取失败则返回None
**示例:** **Args:**
- `chat_stream`: 聊天流对象
- `chat_id`: 聊天ID，实际上就是`stream_id`
- `model_set_with_weight`: 模型配置列表,每个元素为 `(TaskConfig, weight)` 元组
- `request_type`: 请求类型用于记录LLM使用情况可以不写
**Returns:**
- `DefaultReplyer`: 回复器对象如果获取失败则返回None
#### 示例
```python ```python
# 使用聊天流获取回复器 # 使用聊天流获取回复器
replyer = generator_api.get_replyer(chat_stream=chat_stream) replyer = generator_api.get_replyer(chat_stream=chat_stream)
# 使用平台和ID获取回复器 # 使用平台和ID获取回复器
replyer = generator_api.get_replyer( replyer = generator_api.get_replyer(chat_id="123456789")
platform="qq",
chat_id="123456789",
is_group=True
)
``` ```
### 2. 回复生成 ### 2. 回复生成
```python
#### `generate_reply(chat_stream=None, action_data=None, platform=None, chat_id=None, is_group=True)` async def generate_reply(
chat_stream: Optional[ChatStream] = None,
chat_id: Optional[str] = None,
action_data: Optional[Dict[str, Any]] = None,
reply_to: str = "",
extra_info: str = "",
available_actions: Optional[Dict[str, ActionInfo]] = None,
enable_tool: bool = False,
enable_splitter: bool = True,
enable_chinese_typo: bool = True,
return_prompt: bool = False,
model_set_with_weight: Optional[List[Tuple[TaskConfig, float]]] = None,
request_type: str = "generator_api",
) -> Tuple[bool, List[Tuple[str, Any]], Optional[str]]:
```
生成回复 生成回复
**参数:** 优先使用chat_stream如果没有则使用chat_id直接查找。
- `chat_stream`:聊天流对象(优先)
- `action_data`:动作数据
- `platform`:平台名称(备用)
- `chat_id`聊天ID备用
- `is_group`:是否为群聊(备用)
**返回:** **Args:**
- `Tuple[bool, List[Tuple[str, Any]]]`(是否成功, 回复集合) - `chat_stream`: 聊天流对象
- `chat_id`: 聊天ID实际上就是`stream_id`
- `action_data`: 动作数据(向下兼容,包含`reply_to`和`extra_info`)
- `reply_to`: 回复目标,格式为 `{发送者的person_name:消息内容}`
- `extra_info`: 附加信息
- `available_actions`: 可用动作字典,格式为 `{"action_name": ActionInfo}`
- `enable_tool`: 是否启用工具
- `enable_splitter`: 是否启用分割器
- `enable_chinese_typo`: 是否启用中文错别字
- `return_prompt`: 是否返回提示词
- `model_set_with_weight`: 模型配置列表,每个元素为 `(TaskConfig, weight)` 元组
- `request_type`: 请求类型(可选),用于记录LLM使用情况
**示例:** **Returns:**
- `Tuple[bool, List[Tuple[str, Any]], Optional[str]]`: (是否成功, 回复集合, 提示词)
#### 示例
```python ```python
success, reply_set = await generator_api.generate_reply( success, reply_set, prompt = await generator_api.generate_reply(
chat_stream=chat_stream, chat_stream=chat_stream,
action_data={"message": "你好", "intent": "greeting"} action_data=action_data,
reply_to="麦麦:你好",
available_actions=action_info,
enable_tool=True,
return_prompt=True
) )
if success: if success:
for reply_type, reply_content in reply_set: for reply_type, reply_content in reply_set:
print(f"回复类型: {reply_type}, 内容: {reply_content}") print(f"回复类型: {reply_type}, 内容: {reply_content}")
if prompt:
print(f"使用的提示词: {prompt}")
``` ```
#### `rewrite_reply(chat_stream=None, reply_data=None, platform=None, chat_id=None, is_group=True)` ### 3. 回复重写
重写回复
**参数:**
- `chat_stream`:聊天流对象(优先)
- `reply_data`:回复数据
- `platform`:平台名称(备用)
- `chat_id`聊天ID备用
- `is_group`:是否为群聊(备用)
**返回:**
- `Tuple[bool, List[Tuple[str, Any]]]`(是否成功, 回复集合)
**示例:**
```python ```python
success, reply_set = await generator_api.rewrite_reply( async def rewrite_reply(
chat_stream: Optional[ChatStream] = None,
reply_data: Optional[Dict[str, Any]] = None,
chat_id: Optional[str] = None,
enable_splitter: bool = True,
enable_chinese_typo: bool = True,
model_set_with_weight: Optional[List[Tuple[TaskConfig, float]]] = None,
raw_reply: str = "",
reason: str = "",
reply_to: str = "",
return_prompt: bool = False,
) -> Tuple[bool, List[Tuple[str, Any]], Optional[str]]:
```
重写回复,使用新的内容替换旧的回复内容。
优先使用chat_stream如果没有则使用chat_id直接查找。
**Args:**
- `chat_stream`: 聊天流对象
- `reply_data`: 回复数据,包含`raw_reply`, `reason`和`reply_to`**(向下兼容备用,当其他参数缺失时从此获取)**
- `chat_id`: 聊天ID实际上就是`stream_id`
- `enable_splitter`: 是否启用分割器
- `enable_chinese_typo`: 是否启用中文错别字
- `model_set_with_weight`: 模型配置列表,每个元素为 (TaskConfig, weight) 元组
- `raw_reply`: 原始回复内容
- `reason`: 重写原因
- `reply_to`: 回复目标,格式为 `{发送者的person_name:消息内容}`
**Returns:**
- `Tuple[bool, List[Tuple[str, Any]], Optional[str]]`: (是否成功, 回复集合, 提示词)
#### 示例
```python
success, reply_set, prompt = await generator_api.rewrite_reply(
chat_stream=chat_stream, chat_stream=chat_stream,
reply_data={"original_text": "原始回复", "style": "more_friendly"} raw_reply="原始回复内容",
reason="重写原因",
reply_to="麦麦:你好",
return_prompt=True
) )
if success:
for reply_type, reply_content in reply_set:
print(f"回复类型: {reply_type}, 内容: {reply_content}")
if prompt:
print(f"使用的提示词: {prompt}")
``` ```
## 使用示例 ## 回复集合`reply_set`格式
### 1. 基础回复生成
```python
from src.plugin_system.apis import generator_api
async def generate_greeting_reply(chat_stream, user_name):
"""生成问候回复"""
action_data = {
"intent": "greeting",
"user_name": user_name,
"context": "morning_greeting"
}
success, reply_set = await generator_api.generate_reply(
chat_stream=chat_stream,
action_data=action_data
)
if success and reply_set:
# 获取第一个回复
reply_type, reply_content = reply_set[0]
return reply_content
return "你好!" # 默认回复
```
### 2. 在Action中使用回复生成器
```python
from src.plugin_system.base import BaseAction
class ChatAction(BaseAction):
async def execute(self, action_data, chat_stream):
# 准备回复数据
reply_context = {
"message_type": "response",
"user_input": action_data.get("user_message", ""),
"intent": action_data.get("intent", ""),
"entities": action_data.get("entities", {}),
"context": self.get_conversation_context(chat_stream)
}
# 生成回复
success, reply_set = await generator_api.generate_reply(
chat_stream=chat_stream,
action_data=reply_context
)
if success:
return {
"success": True,
"replies": reply_set,
"generated_count": len(reply_set)
}
return {
"success": False,
"error": "回复生成失败",
"fallback_reply": "抱歉,我现在无法理解您的消息。"
}
```
### 3. 多样化回复生成
```python
async def generate_diverse_replies(chat_stream, topic, count=3):
"""生成多个不同风格的回复"""
styles = ["formal", "casual", "humorous"]
all_replies = []
for i, style in enumerate(styles[:count]):
action_data = {
"topic": topic,
"style": style,
"variation": i
}
success, reply_set = await generator_api.generate_reply(
chat_stream=chat_stream,
action_data=action_data
)
if success and reply_set:
all_replies.extend(reply_set)
return all_replies
```
### 4. 回复重写功能
```python
async def improve_reply(chat_stream, original_reply, improvement_type="more_friendly"):
"""改进原始回复"""
reply_data = {
"original_text": original_reply,
"improvement_type": improvement_type,
"target_audience": "young_users",
"tone": "positive"
}
success, improved_replies = await generator_api.rewrite_reply(
chat_stream=chat_stream,
reply_data=reply_data
)
if success and improved_replies:
# 返回改进后的第一个回复
_, improved_content = improved_replies[0]
return improved_content
return original_reply # 如果改进失败,返回原始回复
```
### 5. 条件回复生成
```python
async def conditional_reply_generation(chat_stream, user_message, user_emotion):
"""根据用户情感生成条件回复"""
# 根据情感调整回复策略
if user_emotion == "sad":
action_data = {
"intent": "comfort",
"tone": "empathetic",
"style": "supportive"
}
elif user_emotion == "angry":
action_data = {
"intent": "calm",
"tone": "peaceful",
"style": "understanding"
}
else:
action_data = {
"intent": "respond",
"tone": "neutral",
"style": "helpful"
}
action_data["user_message"] = user_message
action_data["user_emotion"] = user_emotion
success, reply_set = await generator_api.generate_reply(
chat_stream=chat_stream,
action_data=action_data
)
return reply_set if success else []
```
## 回复集合格式
### 回复类型 ### 回复类型
生成的回复集合包含多种类型的回复: 生成的回复集合包含多种类型的回复:
@@ -260,82 +170,32 @@ reply_set = [
] ]
``` ```
## 高级用法 ### 4. 自定义提示词回复
### 1. 自定义回复器配置
```python ```python
async def generate_with_custom_config(chat_stream, action_data): async def generate_response_custom(
"""使用自定义配置生成回复""" chat_stream: Optional[ChatStream] = None,
chat_id: Optional[str] = None,
# 获取回复器 model_set_with_weight: Optional[List[Tuple[TaskConfig, float]]] = None,
replyer = generator_api.get_replyer(chat_stream=chat_stream) prompt: str = "",
) -> Optional[str]:
if replyer:
# 可以访问回复器的内部方法
success, reply_set = await replyer.generate_reply_with_context(
reply_data=action_data,
# 可以传递额外的配置参数
)
return success, reply_set
return False, []
``` ```
生成自定义提示词回复
### 2. 回复质量评估 优先使用chat_stream如果没有则使用chat_id直接查找。
```python **Args:**
async def generate_and_evaluate_replies(chat_stream, action_data): - `chat_stream`: 聊天流对象
"""生成回复并评估质量""" - `chat_id`: 聊天ID备用
- `model_set_with_weight`: 模型集合配置列表
success, reply_set = await generator_api.generate_reply( - `prompt`: 自定义提示词
chat_stream=chat_stream,
action_data=action_data
)
if success:
evaluated_replies = []
for reply_type, reply_content in reply_set:
# 简单的质量评估
quality_score = evaluate_reply_quality(reply_content)
evaluated_replies.append({
"type": reply_type,
"content": reply_content,
"quality": quality_score
})
# 按质量排序
evaluated_replies.sort(key=lambda x: x["quality"], reverse=True)
return evaluated_replies
return []
def evaluate_reply_quality(reply_content): **Returns:**
"""简单的回复质量评估""" - `Optional[str]`: 生成的自定义回复内容如果生成失败则返回None
if not reply_content:
return 0
score = 50 # 基础分
# 长度适中加分
if 5 <= len(reply_content) <= 100:
score += 20
# 包含积极词汇加分
positive_words = ["", "", "不错", "感谢", "开心"]
for word in positive_words:
if word in reply_content:
score += 10
break
return min(score, 100)
```
## 注意事项 ## 注意事项
1. **异步操作**所有生成函数是异步的,须使用`await` 1. **异步操作**部分函数是异步的,须使用`await`
2. **错误处理**函数内置错误处理失败时返回False和空列表 2. **聊天流依赖**需要有效的聊天流对象才能正常工作
3. **聊天流依赖**需要有效的聊天流对象才能正常工作 3. **性能考虑**回复生成可能需要一些时间特别是使用LLM时
4. **性能考虑**回复生成可能需要一些时间特别是使用LLM时 4. **回复格式**返回的回复集合是元组列表,包含类型和内容
5. **回复格式**返回的回复集合是元组列表,包含类型和内容 5. **上下文感知**生成器会考虑聊天上下文和历史消息,除非你用的是自定义提示词。
6. **上下文感知**:生成器会考虑聊天上下文和历史消息

View File

@@ -6,239 +6,60 @@ LLM API模块提供与大语言模型交互的功能让插件能够使用系
```python ```python
from src.plugin_system.apis import llm_api from src.plugin_system.apis import llm_api
# 或者
from src.plugin_system import llm_api
``` ```
## 主要功能 ## 主要功能
### 1. 模型管理 ### 1. 查询可用模型
#### `get_available_models() -> Dict[str, Any]`
获取所有可用的模型配置
**返回:**
- `Dict[str, Any]`模型配置字典key为模型名称value为模型配置
**示例:**
```python ```python
models = llm_api.get_available_models() def get_available_models() -> Dict[str, TaskConfig]:
for model_name, model_config in models.items():
print(f"模型: {model_name}")
print(f"配置: {model_config}")
``` ```
获取所有可用的模型配置。
### 2. 内容生成 **Returns:**
- `Dict[str, TaskConfig]`模型配置字典key为模型名称value为模型配置对象。
#### `generate_with_model(prompt, model_config, request_type="plugin.generate", **kwargs)` ### 2. 使用模型生成内容
使用指定模型生成内容
**参数:**
- `prompt`:提示词
- `model_config`:模型配置(从 get_available_models 获取)
- `request_type`:请求类型标识
- `**kwargs`其他模型特定参数如temperature、max_tokens等
**返回:**
- `Tuple[bool, str, str, str]`(是否成功, 生成的内容, 推理过程, 模型名称)
**示例:**
```python ```python
models = llm_api.get_available_models() async def generate_with_model(
default_model = models.get("default") prompt: str,
model_config: TaskConfig,
if default_model: request_type: str = "plugin.generate",
success, response, reasoning, model_name = await llm_api.generate_with_model( temperature: Optional[float] = None,
prompt="请写一首关于春天的诗", max_tokens: Optional[int] = None,
model_config=default_model, ) -> Tuple[bool, str, str, str]:
temperature=0.7,
max_tokens=200
)
if success:
print(f"生成内容: {response}")
print(f"使用模型: {model_name}")
``` ```
使用指定模型生成内容。
## 使用示例 **Args:**
- `prompt`:提示词。
- `model_config`:模型配置对象(从 `get_available_models` 获取)。
- `request_type`:请求类型标识,默认为 `"plugin.generate"`
- `temperature`:生成内容的温度设置,影响输出的随机性。
- `max_tokens`生成内容的最大token数。
### 1. 基础文本生成 **Returns:**
- `Tuple[bool, str, str, str]`:返回一个元组,包含(是否成功, 生成的内容, 推理过程, 模型名称)。
### 3. 有Tool情况下使用模型生成内容
```python ```python
from src.plugin_system.apis import llm_api async def generate_with_model_with_tools(
prompt: str,
async def generate_story(topic: str): model_config: TaskConfig,
"""生成故事""" tool_options: List[Dict[str, Any]] | None = None,
models = llm_api.get_available_models() request_type: str = "plugin.generate",
model = models.get("default") temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
if not model: ) -> Tuple[bool, str, str, str, List[ToolCall] | None]:
return "未找到可用模型"
prompt = f"请写一个关于{topic}的短故事大约100字左右。"
success, story, reasoning, model_name = await llm_api.generate_with_model(
prompt=prompt,
model_config=model,
request_type="story.generate",
temperature=0.8,
max_tokens=150
)
return story if success else "故事生成失败"
``` ```
使用指定模型生成内容,并支持工具调用。
### 2. 在Action中使用LLM **Args:**
- `prompt`:提示词。
```python - `model_config`:模型配置对象(从 `get_available_models` 获取)。
from src.plugin_system.base import BaseAction - `tool_options`:工具选项列表,包含可用工具的配置,字典为每一个工具的定义,参见[tool-components.md](../tool-components.md#属性说明),可用`tool_api.get_llm_available_tool_definitions()`获取并选择。
- `request_type`:请求类型标识,默认为 `"plugin.generate"`
class LLMAction(BaseAction): - `temperature`:生成内容的温度设置,影响输出的随机性。
async def execute(self, action_data, chat_stream): - `max_tokens`生成内容的最大token数。
# 获取用户输入
user_input = action_data.get("user_message", "")
intent = action_data.get("intent", "chat")
# 获取模型配置
models = llm_api.get_available_models()
model = models.get("default")
if not model:
return {"success": False, "error": "未配置LLM模型"}
# 构建提示词
prompt = self.build_prompt(user_input, intent)
# 生成回复
success, response, reasoning, model_name = await llm_api.generate_with_model(
prompt=prompt,
model_config=model,
request_type=f"plugin.{self.plugin_name}",
temperature=0.7
)
if success:
return {
"success": True,
"response": response,
"model_used": model_name,
"reasoning": reasoning
}
return {"success": False, "error": response}
def build_prompt(self, user_input: str, intent: str) -> str:
"""构建提示词"""
base_prompt = "你是一个友善的AI助手。"
if intent == "question":
return f"{base_prompt}\n\n用户问题:{user_input}\n\n请提供准确、有用的回答:"
elif intent == "chat":
return f"{base_prompt}\n\n用户说:{user_input}\n\n请进行自然的对话:"
else:
return f"{base_prompt}\n\n用户输入:{user_input}\n\n请回复:"
```
### 3. 多模型对比
```python
async def compare_models(prompt: str):
"""使用多个模型生成内容并对比"""
models = llm_api.get_available_models()
results = {}
for model_name, model_config in models.items():
success, response, reasoning, actual_model = await llm_api.generate_with_model(
prompt=prompt,
model_config=model_config,
request_type="comparison.test"
)
results[model_name] = {
"success": success,
"response": response,
"model": actual_model,
"reasoning": reasoning
}
return results
```
### 4. 智能对话插件
```python
class ChatbotPlugin(BasePlugin):
async def handle_action(self, action_data, chat_stream):
user_message = action_data.get("message", "")
# 获取历史对话上下文
context = self.get_conversation_context(chat_stream)
# 构建对话提示词
prompt = self.build_conversation_prompt(user_message, context)
# 获取模型配置
models = llm_api.get_available_models()
chat_model = models.get("chat", models.get("default"))
if not chat_model:
return {"success": False, "message": "聊天模型未配置"}
# 生成回复
success, response, reasoning, model_name = await llm_api.generate_with_model(
prompt=prompt,
model_config=chat_model,
request_type="chat.conversation",
temperature=0.8,
max_tokens=500
)
if success:
# 保存对话历史
self.save_conversation(chat_stream, user_message, response)
return {
"success": True,
"reply": response,
"model": model_name
}
return {"success": False, "message": "回复生成失败"}
def build_conversation_prompt(self, user_message: str, context: list) -> str:
"""构建对话提示词"""
prompt = "你是一个有趣、友善的聊天机器人。请自然地回复用户的消息。\n\n"
# 添加历史对话
if context:
prompt += "对话历史:\n"
for msg in context[-5:]: # 只保留最近5条
prompt += f"用户: {msg['user']}\n机器人: {msg['bot']}\n"
prompt += "\n"
prompt += f"用户: {user_message}\n机器人: "
return prompt
```
## 模型配置说明
### 常用模型类型
- `default`:默认模型
- `chat`:聊天专用模型
- `creative`:创意生成模型
- `code`:代码生成模型
### 配置参数
LLM模型支持的常用参数
- `temperature`控制输出随机性0.0-1.0
- `max_tokens`:最大生成长度
- `top_p`:核采样参数
- `frequency_penalty`:频率惩罚
- `presence_penalty`:存在惩罚
## 注意事项
1. **异步操作**LLM生成是异步的必须使用`await`
2. **错误处理**生成失败时返回False和错误信息
3. **配置依赖**:需要正确配置模型才能使用
4. **请求类型**建议为不同用途设置不同的request_type
5. **性能考虑**LLM调用可能较慢考虑超时和缓存
6. **成本控制**注意控制max_tokens以控制成本

View File

@@ -0,0 +1,29 @@
# Logging API
Logging API模块提供了获取本体logger的功能允许插件记录日志信息。
## 导入方式
```python
from src.plugin_system.apis import get_logger
# 或者
from src.plugin_system import get_logger
```
## 主要功能
### 1. 获取本体logger
```python
def get_logger(name: str) -> structlog.stdlib.BoundLogger:
```
获取本体logger实例。
**Args:**
- `name` (str): 日志记录器的名称。
**Returns:**
- 一个logger实例有以下方法:
- `debug`
- `info`
- `warning`
- `error`
- `critical`

View File

@@ -1,11 +1,13 @@
# 消息API # 消息API
> 消息API提供了强大的消息查询、计数和格式化功能让你轻松处理聊天消息数据。 消息API提供了强大的消息查询、计数和格式化功能让你轻松处理聊天消息数据。
## 导入方式 ## 导入方式
```python ```python
from src.plugin_system.apis import message_api from src.plugin_system.apis import message_api
# 或者
from src.plugin_system import message_api
``` ```
## 功能概述 ## 功能概述
@@ -15,297 +17,356 @@ from src.plugin_system.apis import message_api
- **消息计数** - 统计新消息数量 - **消息计数** - 统计新消息数量
- **消息格式化** - 将消息转换为可读格式 - **消息格式化** - 将消息转换为可读格式
--- ## 主要功能
## 消息查询API ### 1. 按照时间查询消息
```python
def get_messages_by_time(
start_time: float, end_time: float, limit: int = 0, limit_mode: str = "latest", filter_mai: bool = False
) -> List[Dict[str, Any]]:
```
获取指定时间范围内的消息。
### 按时间查询消息 **Args:**
#### `get_messages_by_time(start_time, end_time, limit=0, limit_mode="latest")`
获取指定时间范围内的消息
**参数:**
- `start_time` (float): 开始时间戳 - `start_time` (float): 开始时间戳
- `end_time` (float): 结束时间戳 - `end_time` (float): 结束时间戳
- `limit` (int): 限制返回消息数量0为不限制 - `limit` (int): 限制返回消息数量0为不限制
- `limit_mode` (str): 限制模式,`"earliest"`获取最早记录,`"latest"`获取最新记录 - `limit_mode` (str): 限制模式,`"earliest"`获取最早记录,`"latest"`获取最新记录
- `filter_mai` (bool): 是否过滤掉机器人的消息默认False
**返回:** `List[Dict[str, Any]]` - 消息列表 **Returns:**
- `List[Dict[str, Any]]` - 消息列表
**示例:** 消息列表中包含的键与`Messages`类的属性一致。(位于`src.common.database.database_model`
### 2. 获取指定聊天中指定时间范围内的信息
```python ```python
import time def get_messages_by_time_in_chat(
chat_id: str,
# 获取最近24小时的消息 start_time: float,
now = time.time() end_time: float,
yesterday = now - 24 * 3600 limit: int = 0,
messages = message_api.get_messages_by_time(yesterday, now, limit=50) limit_mode: str = "latest",
filter_mai: bool = False,
) -> List[Dict[str, Any]]:
``` ```
获取指定聊天中指定时间范围内的消息。
### 按聊天查询消息 **Args:**
#### `get_messages_by_time_in_chat(chat_id, start_time, end_time, limit=0, limit_mode="latest")`
获取指定聊天中指定时间范围内的消息
**参数:**
- `chat_id` (str): 聊天ID
- 其他参数同上
**示例:**
```python
# 获取某个群聊最近的100条消息
messages = message_api.get_messages_by_time_in_chat(
chat_id="123456789",
start_time=yesterday,
end_time=now,
limit=100
)
```
#### `get_messages_by_time_in_chat_inclusive(chat_id, start_time, end_time, limit=0, limit_mode="latest")`
获取指定聊天中指定时间范围内的消息(包含边界时间点)
与 `get_messages_by_time_in_chat` 类似,但包含边界时间戳的消息。
#### `get_recent_messages(chat_id, hours=24.0, limit=100, limit_mode="latest")`
获取指定聊天中最近一段时间的消息(便捷方法)
**参数:**
- `chat_id` (str): 聊天ID
- `hours` (float): 最近多少小时默认24小时
- `limit` (int): 限制返回消息数量默认100条
- `limit_mode` (str): 限制模式
**示例:**
```python
# 获取最近6小时的消息
recent_messages = message_api.get_recent_messages(
chat_id="123456789",
hours=6.0,
limit=50
)
```
### 按用户查询消息
#### `get_messages_by_time_in_chat_for_users(chat_id, start_time, end_time, person_ids, limit=0, limit_mode="latest")`
获取指定聊天中指定用户在指定时间范围内的消息
**参数:**
- `chat_id` (str): 聊天ID - `chat_id` (str): 聊天ID
- `start_time` (float): 开始时间戳 - `start_time` (float): 开始时间戳
- `end_time` (float): 结束时间戳 - `end_time` (float): 结束时间戳
- `person_ids` (list): 用户ID列表 - `limit` (int): 限制返回消息数量0为不限制
- `limit` (int): 限制返回消息数量 - `limit_mode` (str): 限制模式,`"earliest"`获取最早记录,`"latest"`获取最新记录
- `limit_mode` (str): 限制模式 - `filter_mai` (bool): 是否过滤掉机器人的消息默认False
**示例:** **Returns:**
- `List[Dict[str, Any]]` - 消息列表
### 3. 获取指定聊天中指定时间范围内的信息(包含边界)
```python ```python
# 获取特定用户的消息 def get_messages_by_time_in_chat_inclusive(
user_messages = message_api.get_messages_by_time_in_chat_for_users( chat_id: str,
chat_id="123456789", start_time: float,
start_time=yesterday, end_time: float,
end_time=now, limit: int = 0,
person_ids=["user1", "user2"] limit_mode: str = "latest",
) filter_mai: bool = False,
filter_command: bool = False,
) -> List[Dict[str, Any]]:
``` ```
获取指定聊天中指定时间范围内的消息(包含边界)。
#### `get_messages_by_time_for_users(start_time, end_time, person_ids, limit=0, limit_mode="latest")` **Args:**
- `chat_id` (str): 聊天ID
- `start_time` (float): 开始时间戳(包含)
- `end_time` (float): 结束时间戳(包含)
- `limit` (int): 限制返回消息数量0为不限制
- `limit_mode` (str): 限制模式,`"earliest"`获取最早记录,`"latest"`获取最新记录
- `filter_mai` (bool): 是否过滤掉机器人的消息默认False
- `filter_command` (bool): 是否过滤命令消息默认False
获取指定用户在所有聊天中指定时间范围内的消息 **Returns:**
- `List[Dict[str, Any]]` - 消息列表
### 其他查询方法
#### `get_random_chat_messages(start_time, end_time, limit=0, limit_mode="latest")` ### 4. 获取指定聊天中指定用户在指定时间范围内的消息
```python
def get_messages_by_time_in_chat_for_users(
chat_id: str,
start_time: float,
end_time: float,
person_ids: List[str],
limit: int = 0,
limit_mode: str = "latest",
) -> List[Dict[str, Any]]:
```
获取指定聊天中指定用户在指定时间范围内的消息。
随机选择一个聊天,返回该聊天在指定时间范围内的消息 **Args:**
#### `get_messages_before_time(timestamp, limit=0)`
获取指定时间戳之前的消息
#### `get_messages_before_time_in_chat(chat_id, timestamp, limit=0)`
获取指定聊天中指定时间戳之前的消息
#### `get_messages_before_time_for_users(timestamp, person_ids, limit=0)`
获取指定用户在指定时间戳之前的消息
---
## 消息计数API
### `count_new_messages(chat_id, start_time=0.0, end_time=None)`
计算指定聊天中从开始时间到结束时间的新消息数量
**参数:**
- `chat_id` (str): 聊天ID - `chat_id` (str): 聊天ID
- `start_time` (float): 开始时间戳 - `start_time` (float): 开始时间戳
- `end_time` (float): 结束时间戳如果为None则使用当前时间 - `end_time` (float): 结束时间戳
- `person_ids` (List[str]): 用户ID列表
- `limit` (int): 限制返回消息数量0为不限制
- `limit_mode` (str): 限制模式,`"earliest"`获取最早记录,`"latest"`获取最新记录
**返回:** `int` - 新消息数量 **Returns:**
- `List[Dict[str, Any]]` - 消息列表
**示例:**
### 5. 随机选择一个聊天,返回该聊天在指定时间范围内的消息
```python ```python
# 计算最近1小时的新消息数 def get_random_chat_messages(
import time start_time: float,
now = time.time() end_time: float,
hour_ago = now - 3600 limit: int = 0,
new_count = message_api.count_new_messages("123456789", hour_ago, now) limit_mode: str = "latest",
print(f"最近1小时有{new_count}条新消息") filter_mai: bool = False,
) -> List[Dict[str, Any]]:
``` ```
随机选择一个聊天,返回该聊天在指定时间范围内的消息。
### `count_new_messages_for_users(chat_id, start_time, end_time, person_ids)` **Args:**
- `start_time` (float): 开始时间戳
- `end_time` (float): 结束时间戳
- `limit` (int): 限制返回消息数量0为不限制
- `limit_mode` (str): 限制模式,`"earliest"`获取最早记录,`"latest"`获取最新记录
- `filter_mai` (bool): 是否过滤掉机器人的消息默认False
计算指定聊天中指定用户从开始时间到结束时间的新消息数量 **Returns:**
- `List[Dict[str, Any]]` - 消息列表
---
## 消息格式化API ### 6. 获取指定用户在所有聊天中指定时间范围内的消息
```python
def get_messages_by_time_for_users(
start_time: float,
end_time: float,
person_ids: List[str],
limit: int = 0,
limit_mode: str = "latest",
) -> List[Dict[str, Any]]:
```
获取指定用户在所有聊天中指定时间范围内的消息。
### `build_readable_messages_to_str(messages, **options)` **Args:**
- `start_time` (float): 开始时间戳
- `end_time` (float): 结束时间戳
- `person_ids` (List[str]): 用户ID列表
- `limit` (int): 限制返回消息数量0为不限制
- `limit_mode` (str): 限制模式,`"earliest"`获取最早记录,`"latest"`获取最新记录
将消息列表构建成可读的字符串 **Returns:**
- `List[Dict[str, Any]]` - 消息列表
**参数:**
### 7. 获取指定时间戳之前的消息
```python
def get_messages_before_time(
timestamp: float,
limit: int = 0,
filter_mai: bool = False,
) -> List[Dict[str, Any]]:
```
获取指定时间戳之前的消息。
**Args:**
- `timestamp` (float): 时间戳
- `limit` (int): 限制返回消息数量0为不限制
- `filter_mai` (bool): 是否过滤掉机器人的消息默认False
**Returns:**
- `List[Dict[str, Any]]` - 消息列表
### 8. 获取指定聊天中指定时间戳之前的消息
```python
def get_messages_before_time_in_chat(
chat_id: str,
timestamp: float,
limit: int = 0,
filter_mai: bool = False,
) -> List[Dict[str, Any]]:
```
获取指定聊天中指定时间戳之前的消息。
**Args:**
- `chat_id` (str): 聊天ID
- `timestamp` (float): 时间戳
- `limit` (int): 限制返回消息数量0为不限制
- `filter_mai` (bool): 是否过滤掉机器人的消息默认False
**Returns:**
- `List[Dict[str, Any]]` - 消息列表
### 9. 获取指定用户在指定时间戳之前的消息
```python
def get_messages_before_time_for_users(
timestamp: float,
person_ids: List[str],
limit: int = 0,
) -> List[Dict[str, Any]]:
```
获取指定用户在指定时间戳之前的消息。
**Args:**
- `timestamp` (float): 时间戳
- `person_ids` (List[str]): 用户ID列表
- `limit` (int): 限制返回消息数量0为不限制
**Returns:**
- `List[Dict[str, Any]]` - 消息列表
### 10. 获取指定聊天中最近一段时间的消息
```python
def get_recent_messages(
chat_id: str,
hours: float = 24.0,
limit: int = 100,
limit_mode: str = "latest",
filter_mai: bool = False,
) -> List[Dict[str, Any]]:
```
获取指定聊天中最近一段时间的消息。
**Args:**
- `chat_id` (str): 聊天ID
- `hours` (float): 最近多少小时默认24小时
- `limit` (int): 限制返回消息数量默认100条
- `limit_mode` (str): 限制模式,`"earliest"`获取最早记录,`"latest"`获取最新记录
- `filter_mai` (bool): 是否过滤掉机器人的消息默认False
**Returns:**
- `List[Dict[str, Any]]` - 消息列表
### 11. 计算指定聊天中从开始时间到结束时间的新消息数量
```python
def count_new_messages(
chat_id: str,
start_time: float = 0.0,
end_time: Optional[float] = None,
) -> int:
```
计算指定聊天中从开始时间到结束时间的新消息数量。
**Args:**
- `chat_id` (str): 聊天ID
- `start_time` (float): 开始时间戳
- `end_time` (Optional[float]): 结束时间戳如果为None则使用当前时间
**Returns:**
- `int` - 新消息数量
### 12. 计算指定聊天中指定用户从开始时间到结束时间的新消息数量
```python
def count_new_messages_for_users(
chat_id: str,
start_time: float,
end_time: float,
person_ids: List[str],
) -> int:
```
计算指定聊天中指定用户从开始时间到结束时间的新消息数量。
**Args:**
- `chat_id` (str): 聊天ID
- `start_time` (float): 开始时间戳
- `end_time` (float): 结束时间戳
- `person_ids` (List[str]): 用户ID列表
**Returns:**
- `int` - 新消息数量
### 13. 将消息列表构建成可读的字符串
```python
def build_readable_messages_to_str(
messages: List[Dict[str, Any]],
replace_bot_name: bool = True,
merge_messages: bool = False,
timestamp_mode: str = "relative",
read_mark: float = 0.0,
truncate: bool = False,
show_actions: bool = False,
) -> str:
```
将消息列表构建成可读的字符串。
**Args:**
- `messages` (List[Dict[str, Any]]): 消息列表 - `messages` (List[Dict[str, Any]]): 消息列表
- `replace_bot_name` (bool): 是否将机器人的名称替换为"你"默认True - `replace_bot_name` (bool): 是否将机器人的名称替换为"你"
- `merge_messages` (bool): 是否合并连续消息默认False - `merge_messages` (bool): 是否合并连续消息
- `timestamp_mode` (str): 时间戳显示模式,`"relative"`或`"absolute"`,默认`"relative"` - `timestamp_mode` (str): 时间戳显示模式,`"relative"`或`"absolute"`
- `read_mark` (float): 已读标记时间戳,用于分割已读和未读消息默认0.0 - `read_mark` (float): 已读标记时间戳,用于分割已读和未读消息
- `truncate` (bool): 是否截断长消息默认False - `truncate` (bool): 是否截断长消息
- `show_actions` (bool): 是否显示动作记录默认False - `show_actions` (bool): 是否显示动作记录
**返回:** `str` - 格式化后的可读字符串 **Returns:**
- `str` - 格式化后的可读字符串
**示例:**
### 14. 将消息列表构建成可读的字符串,并返回详细信息
```python ```python
# 获取消息并格式化为可读文本 async def build_readable_messages_with_details(
messages = message_api.get_recent_messages("123456789", hours=2) messages: List[Dict[str, Any]],
readable_text = message_api.build_readable_messages_to_str( replace_bot_name: bool = True,
messages, merge_messages: bool = False,
replace_bot_name=True, timestamp_mode: str = "relative",
merge_messages=True, truncate: bool = False,
timestamp_mode="relative" ) -> Tuple[str, List[Tuple[float, str, str]]]:
)
print(readable_text)
``` ```
将消息列表构建成可读的字符串,并返回详细信息。
### `build_readable_messages_with_details(messages, **options)` 异步 **Args:**
- `messages` (List[Dict[str, Any]]): 消息列表
- `replace_bot_name` (bool): 是否将机器人的名称替换为"你"
- `merge_messages` (bool): 是否合并连续消息
- `timestamp_mode` (str): 时间戳显示模式,`"relative"`或`"absolute"`
- `truncate` (bool): 是否截断长消息
将消息列表构建成可读的字符串,并返回详细信息 **Returns:**
- `Tuple[str, List[Tuple[float, str, str]]]` - 格式化后的可读字符串和详细信息元组列表(时间戳, 昵称, 内容)
**参数:** 与 `build_readable_messages_to_str` 类似,但不包含 `read_mark` 和 `show_actions`
**返回:** `Tuple[str, List[Tuple[float, str, str]]]` - 格式化字符串和详细信息元组列表(时间戳, 昵称, 内容) ### 15. 从消息列表中提取不重复的用户ID列表
**示例:**
```python ```python
# 异步获取详细格式化信息 async def get_person_ids_from_messages(
readable_text, details = await message_api.build_readable_messages_with_details( messages: List[Dict[str, Any]],
messages, ) -> List[str]:
timestamp_mode="absolute"
)
for timestamp, nickname, content in details:
print(f"{timestamp}: {nickname} 说: {content}")
``` ```
从消息列表中提取不重复的用户ID列表。
### `get_person_ids_from_messages(messages)` 异步 **Args:**
从消息列表中提取不重复的用户ID列表
**参数:**
- `messages` (List[Dict[str, Any]]): 消息列表 - `messages` (List[Dict[str, Any]]): 消息列表
**返回:** `List[str]` - 用户ID列表 **Returns:**
- `List[str]` - 用户ID列表
**示例:**
### 16. 从消息列表中移除机器人的消息
```python ```python
# 获取参与对话的所有用户ID def filter_mai_messages(
messages = message_api.get_recent_messages("123456789") messages: List[Dict[str, Any]],
person_ids = await message_api.get_person_ids_from_messages(messages) ) -> List[Dict[str, Any]]:
print(f"参与对话的用户: {person_ids}")
``` ```
从消息列表中移除机器人的消息。
--- **Args:**
- `messages` (List[Dict[str, Any]]): 消息列表,每个元素是消息字典
## 完整使用示例 **Returns:**
- `List[Dict[str, Any]]` - 过滤后的消息列表
### 场景1统计活跃度
```python
import time
from src.plugin_system.apis import message_api
async def analyze_chat_activity(chat_id: str):
"""分析聊天活跃度"""
now = time.time()
day_ago = now - 24 * 3600
# 获取最近24小时的消息
messages = message_api.get_recent_messages(chat_id, hours=24)
# 统计消息数量
total_count = len(messages)
# 获取参与用户
person_ids = await message_api.get_person_ids_from_messages(messages)
# 格式化消息内容
readable_text = message_api.build_readable_messages_to_str(
messages[-10:], # 最后10条消息
merge_messages=True,
timestamp_mode="relative"
)
return {
"total_messages": total_count,
"active_users": len(person_ids),
"recent_chat": readable_text
}
```
### 场景2查看特定用户的历史消息
```python
def get_user_history(chat_id: str, user_id: str, days: int = 7):
"""获取用户最近N天的消息历史"""
now = time.time()
start_time = now - days * 24 * 3600
# 获取特定用户的消息
user_messages = message_api.get_messages_by_time_in_chat_for_users(
chat_id=chat_id,
start_time=start_time,
end_time=now,
person_ids=[user_id],
limit=100
)
# 格式化为可读文本
readable_history = message_api.build_readable_messages_to_str(
user_messages,
replace_bot_name=False,
timestamp_mode="absolute"
)
return readable_history
```
---
## 注意事项 ## 注意事项
1. **时间戳格式**:所有时间参数都使用Unix时间戳(float类型) 1. **时间戳格式**:所有时间参数都使用Unix时间戳(float类型)
2. **异步函数**:`build_readable_messages_with_details` 和 `get_person_ids_from_messages` 是异步函数,需要使用 `await` 2. **异步函数**:部分函数是异步函数,需要使用 `await`
3. **性能考虑**:查询大量消息时建议设置合理的 `limit` 参数 3. **性能考虑**:查询大量消息时建议设置合理的 `limit` 参数
4. **消息格式**:返回的消息是字典格式,包含时间戳、发送者、内容等信息 4. **消息格式**:返回的消息是字典格式,包含时间戳、发送者、内容等信息
5. **用户ID**`person_ids` 参数接受字符串列表,用于筛选特定用户的消息 5. **用户ID**`person_ids` 参数接受字符串列表,用于筛选特定用户的消息

View File

@@ -6,59 +6,65 @@
```python ```python
from src.plugin_system.apis import person_api from src.plugin_system.apis import person_api
# 或者
from src.plugin_system import person_api
``` ```
## 主要功能 ## 主要功能
### 1. Person ID管理 ### 1. Person ID 获取
```python
#### `get_person_id(platform: str, user_id: int) -> str` def get_person_id(platform: str, user_id: int) -> str:
```
根据平台和用户ID获取person_id 根据平台和用户ID获取person_id
**参数:** **Args:**
- `platform`:平台名称,如 "qq", "telegram" 等 - `platform`:平台名称,如 "qq", "telegram" 等
- `user_id`:用户ID - `user_id`:用户ID
**返回:** **Returns:**
- `str`:唯一的person_id(MD5哈希值) - `str`:唯一的person_id(MD5哈希值)
**示例:** #### 示例
```python ```python
person_id = person_api.get_person_id("qq", 123456) person_id = person_api.get_person_id("qq", 123456)
print(f"Person ID: {person_id}")
``` ```
### 2. 用户信息查询 ### 2. 用户信息查询
```python
async def get_person_value(person_id: str, field_name: str, default: Any = None) -> Any:
```
查询单个用户信息字段值
#### `get_person_value(person_id: str, field_name: str, default: Any = None) -> Any` **Args:**
根据person_id和字段名获取某个值
**参数:**
- `person_id`:用户的唯一标识ID - `person_id`:用户的唯一标识ID
- `field_name`:要获取的字段名,如 "nickname", "impression" 等 - `field_name`:要获取的字段名
- `default`:字段不存在或获取失败时返回的默认值 - `default`:字段不存在时的默认值
**返回:** **Returns:**
- `Any`:字段值或默认值 - `Any`:字段值或默认值
**示例:** #### 示例
```python ```python
nickname = await person_api.get_person_value(person_id, "nickname", "未知用户") nickname = await person_api.get_person_value(person_id, "nickname", "未知用户")
impression = await person_api.get_person_value(person_id, "impression") impression = await person_api.get_person_value(person_id, "impression")
``` ```
#### `get_person_values(person_id: str, field_names: list, default_dict: dict = None) -> dict` ### 3. 批量用户信息查询
```python
async def get_person_values(person_id: str, field_names: list, default_dict: Optional[dict] = None) -> dict:
```
批量获取用户信息字段值 批量获取用户信息字段值
**参数:** **Args:**
- `person_id`:用户的唯一标识ID - `person_id`:用户的唯一标识ID
- `field_names`:要获取的字段名列表 - `field_names`:要获取的字段名列表
- `default_dict`:默认值字典,键为字段名,值为默认值 - `default_dict`:默认值字典,键为字段名,值为默认值
**返回:** **Returns:**
- `dict`:字段名到值的映射字典 - `dict`:字段名到值的映射字典
**示例:** #### 示例
```python ```python
values = await person_api.get_person_values( values = await person_api.get_person_values(
person_id, person_id,
@@ -67,204 +73,31 @@ values = await person_api.get_person_values(
) )
``` ```
### 3. 用户状态查询 ### 4. 判断用户是否已知
```python
#### `is_person_known(platform: str, user_id: int) -> bool` async def is_person_known(platform: str, user_id: int) -> bool:
```
判断是否认识某个用户 判断是否认识某个用户
**参数:** **Args:**
- `platform`:平台名称 - `platform`:平台名称
- `user_id`:用户ID - `user_id`:用户ID
**返回:** **Returns:**
- `bool`:是否认识该用户 - `bool`:是否认识该用户
**示例:** ### 5. 根据用户名获取Person ID
```python ```python
known = await person_api.is_person_known("qq", 123456) def get_person_id_by_name(person_name: str) -> str:
if known:
print("这个用户我认识")
``` ```
### 4. 用户名查询
#### `get_person_id_by_name(person_name: str) -> str`
根据用户名获取person_id 根据用户名获取person_id
**参数:** **Args:**
- `person_name`:用户名 - `person_name`:用户名
**返回:** **Returns:**
- `str`:person_id,如果未找到返回空字符串 - `str`:person_id,如果未找到返回空字符串
**示例:**
```python
person_id = person_api.get_person_id_by_name("张三")
if person_id:
print(f"找到用户: {person_id}")
```
## 使用示例
### 1. 基础用户信息获取
```python
from src.plugin_system.apis import person_api
async def get_user_info(platform: str, user_id: int):
"""获取用户基本信息"""
# 获取person_id
person_id = person_api.get_person_id(platform, user_id)
# 获取用户信息
user_info = await person_api.get_person_values(
person_id,
["nickname", "impression", "know_times", "last_seen"],
{
"nickname": "未知用户",
"impression": "",
"know_times": 0,
"last_seen": 0
}
)
return {
"person_id": person_id,
"nickname": user_info["nickname"],
"impression": user_info["impression"],
"know_times": user_info["know_times"],
"last_seen": user_info["last_seen"]
}
```
### 2. 在Action中使用用户信息
```python
from src.plugin_system.base import BaseAction
class PersonalizedAction(BaseAction):
async def execute(self, action_data, chat_stream):
# 获取发送者信息
user_id = chat_stream.user_info.user_id
platform = chat_stream.platform
# 获取person_id
person_id = person_api.get_person_id(platform, user_id)
# 获取用户昵称和印象
nickname = await person_api.get_person_value(person_id, "nickname", "朋友")
impression = await person_api.get_person_value(person_id, "impression", "")
# 根据用户信息个性化回复
if impression:
response = f"你好 {nickname}!根据我对你的了解:{impression}"
else:
response = f"你好 {nickname}!很高兴见到你。"
return {
"success": True,
"response": response,
"user_info": {
"nickname": nickname,
"impression": impression
}
}
```
### 3. 用户识别和欢迎
```python
async def welcome_user(chat_stream):
"""欢迎用户,区分新老用户"""
user_id = chat_stream.user_info.user_id
platform = chat_stream.platform
# 检查是否认识这个用户
is_known = await person_api.is_person_known(platform, user_id)
if is_known:
# 老用户,获取详细信息
person_id = person_api.get_person_id(platform, user_id)
nickname = await person_api.get_person_value(person_id, "nickname", "老朋友")
know_times = await person_api.get_person_value(person_id, "know_times", 0)
welcome_msg = f"欢迎回来,{nickname}!我们已经聊过 {know_times} 次了。"
else:
# 新用户
welcome_msg = "你好很高兴认识你我是MaiBot。"
return welcome_msg
```
### 4. 用户搜索功能
```python
async def find_user_by_name(name: str):
"""根据名字查找用户"""
person_id = person_api.get_person_id_by_name(name)
if not person_id:
return {"found": False, "message": f"未找到名为 '{name}' 的用户"}
# 获取用户详细信息
user_info = await person_api.get_person_values(
person_id,
["nickname", "platform", "user_id", "impression", "know_times"],
{}
)
return {
"found": True,
"person_id": person_id,
"info": user_info
}
```
### 5. 用户印象分析
```python
async def analyze_user_relationship(chat_stream):
"""分析用户关系"""
user_id = chat_stream.user_info.user_id
platform = chat_stream.platform
person_id = person_api.get_person_id(platform, user_id)
# 获取关系相关信息
relationship_info = await person_api.get_person_values(
person_id,
["nickname", "impression", "know_times", "relationship_level", "last_interaction"],
{
"nickname": "未知",
"impression": "",
"know_times": 0,
"relationship_level": "stranger",
"last_interaction": 0
}
)
# 分析关系程度
know_times = relationship_info["know_times"]
if know_times == 0:
relationship = "陌生人"
elif know_times < 5:
relationship = "新朋友"
elif know_times < 20:
relationship = "熟人"
else:
relationship = "老朋友"
return {
"nickname": relationship_info["nickname"],
"relationship": relationship,
"impression": relationship_info["impression"],
"interaction_count": know_times
}
```
## 常用字段说明 ## 常用字段说明
### 基础信息字段 ### 基础信息字段
@@ -274,69 +107,13 @@ async def analyze_user_relationship(chat_stream):
### 关系信息字段 ### 关系信息字段
- `impression`:对用户的印象 - `impression`:对用户的印象
- `know_times`:交互次数 - `points`: 用户特征点
- `relationship_level`:关系等级
- `last_seen`:最后见面时间
- `last_interaction`:最后交互时间
### 个性化字段 其他字段可以参考`PersonInfo`类的属性(位于`src.common.database.database_model`)
- `preferences`:用户偏好
- `interests`:兴趣爱好
- `mood_history`:情绪历史
- `topic_interests`:话题兴趣
## 最佳实践
### 1. 错误处理
```python
async def safe_get_user_info(person_id: str, field: str):
"""安全获取用户信息"""
try:
value = await person_api.get_person_value(person_id, field)
return value if value is not None else "未设置"
except Exception as e:
logger.error(f"获取用户信息失败: {e}")
return "获取失败"
```
### 2. 批量操作
```python
async def get_complete_user_profile(person_id: str):
"""获取完整用户档案"""
# 一次性获取所有需要的字段
fields = [
"nickname", "impression", "know_times",
"preferences", "interests", "relationship_level"
]
defaults = {
"nickname": "用户",
"impression": "",
"know_times": 0,
"preferences": "{}",
"interests": "[]",
"relationship_level": "stranger"
}
profile = await person_api.get_person_values(person_id, fields, defaults)
# 处理JSON字段
try:
profile["preferences"] = json.loads(profile["preferences"])
profile["interests"] = json.loads(profile["interests"])
except:
profile["preferences"] = {}
profile["interests"] = []
return profile
```
## 注意事项 ## 注意事项
1. **异步操作**:部分查询函数是异步的,需要使用`await` 1. **异步操作**:部分查询函数是异步的,需要使用`await`
2. **错误处理**:所有函数都有错误处理,失败时记录日志并返回默认值 2. **性能考虑**:批量查询优于单个查询
3. **数据类型**:返回的数据可能是字符串、数字或JSON,需要适当处理 3. **隐私保护**:确保用户信息的使用符合隐私政策
4. **性能考虑**:批量查询优于单个查询 4. **数据一致性**:person_id是用户的唯一标识,应妥善保存和使用
5. **隐私保护**:确保用户信息的使用符合隐私政策
6. **数据一致性**:person_id是用户的唯一标识,应妥善保存和使用

View File

@@ -0,0 +1,105 @@
# 插件管理API
插件管理API模块提供了对插件的加载、卸载、重新加载以及目录管理功能。
## 导入方式
```python
from src.plugin_system.apis import plugin_manage_api
# 或者
from src.plugin_system import plugin_manage_api
```
## 功能概述
插件管理API主要提供以下功能:
- **插件查询** - 列出当前加载的插件或已注册的插件。
- **插件管理** - 加载、卸载、重新加载插件。
- **插件目录管理** - 添加插件目录并重新扫描。
## 主要功能
### 1. 列出当前加载的插件
```python
def list_loaded_plugins() -> List[str]:
```
列出所有当前加载的插件。
**Returns:**
- `List[str]` - 当前加载的插件名称列表。
### 2. 列出所有已注册的插件
```python
def list_registered_plugins() -> List[str]:
```
列出所有已注册的插件。
**Returns:**
- `List[str]` - 已注册的插件名称列表。
### 3. 获取插件路径
```python
def get_plugin_path(plugin_name: str) -> str:
```
获取指定插件的路径。
**Args:**
- `plugin_name` (str): 要查询的插件名称。
**Returns:**
- `str` - 插件的路径;如果插件不存在则抛出 `ValueError`。
### 4. 卸载指定的插件
```python
async def remove_plugin(plugin_name: str) -> bool:
```
卸载指定的插件。
**Args:**
- `plugin_name` (str): 要卸载的插件名称。
**Returns:**
- `bool` - 卸载是否成功。
### 5. 重新加载指定的插件
```python
async def reload_plugin(plugin_name: str) -> bool:
```
重新加载指定的插件。
**Args:**
- `plugin_name` (str): 要重新加载的插件名称。
**Returns:**
- `bool` - 重新加载是否成功。
### 6. 加载指定的插件
```python
def load_plugin(plugin_name: str) -> Tuple[bool, int]:
```
加载指定的插件。
**Args:**
- `plugin_name` (str): 要加载的插件名称。
**Returns:**
- `Tuple[bool, int]` - 加载是否成功,成功或失败的个数。
### 7. 添加插件目录
```python
def add_plugin_directory(plugin_directory: str) -> bool:
```
添加插件目录。
**Args:**
- `plugin_directory` (str): 要添加的插件目录路径。
**Returns:**
- `bool` - 添加是否成功。
### 8. 重新扫描插件目录
```python
def rescan_plugin_directory() -> Tuple[int, int]:
```
重新扫描插件目录,加载新插件。
**Returns:**
- `Tuple[int, int]` - 成功加载的插件数量和失败的插件数量。

View File

@@ -6,86 +6,108 @@
```python ```python
from src.plugin_system.apis import send_api from src.plugin_system.apis import send_api
# 或者
from src.plugin_system import send_api
``` ```
## 主要功能 ## 主要功能
### 1. 文本消息发送 ### 1. 发送文本消息
```python
async def text_to_stream(
text: str,
stream_id: str,
typing: bool = False,
reply_to: str = "",
storage_message: bool = True,
) -> bool:
```
发送文本消息到指定的流
#### `text_to_group(text, group_id, platform="qq", typing=False, reply_to="", storage_message=True)` **Args:**
向群聊发送文本消息 - `text` (str): 要发送的文本内容
- `stream_id` (str): 聊天流ID
- `typing` (bool): 是否显示正在输入
- `reply_to` (str): 回复消息,格式为"发送者:消息内容"
- `storage_message` (bool): 是否存储消息到数据库
**参数:** **Returns:**
- `text`:要发送的文本内容 - `bool` - 是否发送成功
- `group_id`群聊ID
- `platform`:平台,默认为"qq"
- `typing`:是否显示正在输入
- `reply_to`:回复消息的格式,如"发送者:消息内容"
- `storage_message`:是否存储到数据库
**返回:** ### 2. 发送表情包
- `bool`:是否发送成功 ```python
async def emoji_to_stream(emoji_base64: str, stream_id: str, storage_message: bool = True) -> bool:
```
向指定流发送表情包。
#### `text_to_user(text, user_id, platform="qq", typing=False, reply_to="", storage_message=True)` **Args:**
向用户发送私聊文本消息 - `emoji_base64` (str): 表情包的base64编码
- `stream_id` (str): 聊天流ID
- `storage_message` (bool): 是否存储消息到数据库
**参数与返回值同上** **Returns:**
- `bool` - 是否发送成功
### 2. 表情包发送 ### 3. 发送图片
```python
async def image_to_stream(image_base64: str, stream_id: str, storage_message: bool = True) -> bool:
```
向指定流发送图片。
#### `emoji_to_group(emoji_base64, group_id, platform="qq", storage_message=True)` **Args:**
向群聊发送表情包 - `image_base64` (str): 图片的base64编码
- `stream_id` (str): 聊天流ID
- `storage_message` (bool): 是否存储消息到数据库
**参数:** **Returns:**
- `emoji_base64`表情包的base64编码 - `bool` - 是否发送成功
- `group_id`群聊ID
- `platform`:平台,默认为"qq"
- `storage_message`:是否存储到数据库
#### `emoji_to_user(emoji_base64, user_id, platform="qq", storage_message=True)` ### 4. 发送命令
向用户发送表情包 ```python
async def command_to_stream(command: Union[str, dict], stream_id: str, storage_message: bool = True, display_message: str = "") -> bool:
```
向指定流发送命令。
### 3. 图片发送 **Args:**
- `command` (Union[str, dict]): 命令内容
- `stream_id` (str): 聊天流ID
- `storage_message` (bool): 是否存储消息到数据库
- `display_message` (str): 显示消息
#### `image_to_group(image_base64, group_id, platform="qq", storage_message=True)` **Returns:**
向群聊发送图片 - `bool` - 是否发送成功
#### `image_to_user(image_base64, user_id, platform="qq", storage_message=True)` ### 5. 发送自定义类型消息
向用户发送图片 ```python
async def custom_to_stream(
message_type: str,
content: str,
stream_id: str,
display_message: str = "",
typing: bool = False,
reply_to: str = "",
storage_message: bool = True,
show_log: bool = True,
) -> bool:
```
向指定流发送自定义类型消息。
### 4. 命令发送 **Args:**
- `message_type` (str): 消息类型,如"text"、"image"、"emoji"、"video"、"file"等
- `content` (str): 消息内容,通常是base64编码或文本
- `stream_id` (str): 聊天流ID
- `display_message` (str): 显示消息
- `typing` (bool): 是否显示正在输入
- `reply_to` (str): 回复消息,格式为"发送者:消息内容"
- `storage_message` (bool): 是否存储消息到数据库
- `show_log` (bool): 是否显示日志
#### `command_to_group(command, group_id, platform="qq", storage_message=True)` **Returns:**
向群聊发送命令 - `bool` - 是否发送成功
#### `command_to_user(command, user_id, platform="qq", storage_message=True)`
向用户发送命令
### 5. 自定义消息发送
#### `custom_to_group(message_type, content, group_id, platform="qq", display_message="", typing=False, reply_to="", storage_message=True)`
向群聊发送自定义类型消息
#### `custom_to_user(message_type, content, user_id, platform="qq", display_message="", typing=False, reply_to="", storage_message=True)`
向用户发送自定义类型消息
#### `custom_message(message_type, content, target_id, is_group=True, platform="qq", display_message="", typing=False, reply_to="", storage_message=True)`
通用的自定义消息发送
**参数:**
- `message_type`:消息类型,如"text"、"image"、"emoji"等
- `content`:消息内容
- `target_id`目标ID群ID或用户ID
- `is_group`:是否为群聊
- `platform`:平台
- `display_message`:显示消息
- `typing`:是否显示正在输入
- `reply_to`:回复消息
- `storage_message`:是否存储
## 使用示例 ## 使用示例
### 1. 基础文本发送 ### 1. 基础文本发送,并回复消息
```python ```python
from src.plugin_system.apis import send_api from src.plugin_system.apis import send_api
@@ -93,57 +115,23 @@ from src.plugin_system.apis import send_api
async def send_hello(chat_stream): async def send_hello(chat_stream):
"""发送问候消息""" """发送问候消息"""
if chat_stream.group_info: success = await send_api.text_to_stream(
# 群聊 text="Hello, world!",
success = await send_api.text_to_group( stream_id=chat_stream.stream_id,
text="大家好!", typing=True,
group_id=chat_stream.group_info.group_id, reply_to="User:How are you?",
typing=True storage_message=True
) )
else:
# 私聊
success = await send_api.text_to_user(
text="你好!",
user_id=chat_stream.user_info.user_id,
typing=True
)
return success return success
``` ```
### 2. 回复特定消息 ### 2. 发送表情包
```python
async def reply_to_message(chat_stream, reply_text, original_sender, original_message):
"""回复特定消息"""
# 构建回复格式
reply_to = f"{original_sender}:{original_message}"
if chat_stream.group_info:
success = await send_api.text_to_group(
text=reply_text,
group_id=chat_stream.group_info.group_id,
reply_to=reply_to
)
else:
success = await send_api.text_to_user(
text=reply_text,
user_id=chat_stream.user_info.user_id,
reply_to=reply_to
)
return success
```
### 3. 发送表情包
```python ```python
from src.plugin_system.apis import emoji_api
async def send_emoji_reaction(chat_stream, emotion): async def send_emoji_reaction(chat_stream, emotion):
"""根据情感发送表情包""" """根据情感发送表情包"""
from src.plugin_system.apis import emoji_api
# 获取表情包 # 获取表情包
emoji_result = await emoji_api.get_by_emotion(emotion) emoji_result = await emoji_api.get_by_emotion(emotion)
if not emoji_result: if not emoji_result:
@@ -152,107 +140,10 @@ async def send_emoji_reaction(chat_stream, emotion):
emoji_base64, description, matched_emotion = emoji_result emoji_base64, description, matched_emotion = emoji_result
# 发送表情包 # 发送表情包
if chat_stream.group_info: success = await send_api.emoji_to_stream(
success = await send_api.emoji_to_group( emoji_base64=emoji_base64,
emoji_base64=emoji_base64, stream_id=chat_stream.stream_id,
group_id=chat_stream.group_info.group_id storage_message=False # 不存储到数据库
)
else:
success = await send_api.emoji_to_user(
emoji_base64=emoji_base64,
user_id=chat_stream.user_info.user_id
)
return success
```
### 4. 在Action中发送消息
```python
from src.plugin_system.base import BaseAction
class MessageAction(BaseAction):
async def execute(self, action_data, chat_stream):
message_type = action_data.get("type", "text")
content = action_data.get("content", "")
if message_type == "text":
success = await self.send_text(chat_stream, content)
elif message_type == "emoji":
success = await self.send_emoji(chat_stream, content)
elif message_type == "image":
success = await self.send_image(chat_stream, content)
else:
success = False
return {"success": success}
async def send_text(self, chat_stream, text):
if chat_stream.group_info:
return await send_api.text_to_group(text, chat_stream.group_info.group_id)
else:
return await send_api.text_to_user(text, chat_stream.user_info.user_id)
async def send_emoji(self, chat_stream, emoji_base64):
if chat_stream.group_info:
return await send_api.emoji_to_group(emoji_base64, chat_stream.group_info.group_id)
else:
return await send_api.emoji_to_user(emoji_base64, chat_stream.user_info.user_id)
async def send_image(self, chat_stream, image_base64):
if chat_stream.group_info:
return await send_api.image_to_group(image_base64, chat_stream.group_info.group_id)
else:
return await send_api.image_to_user(image_base64, chat_stream.user_info.user_id)
```
### 5. 批量发送消息
```python
async def broadcast_message(message: str, target_groups: list):
"""向多个群组广播消息"""
results = {}
for group_id in target_groups:
try:
success = await send_api.text_to_group(
text=message,
group_id=group_id,
typing=True
)
results[group_id] = success
except Exception as e:
results[group_id] = False
print(f"发送到群 {group_id} 失败: {e}")
return results
```
### 6. 智能消息发送
```python
async def smart_send(chat_stream, message_data):
"""智能发送不同类型的消息"""
message_type = message_data.get("type", "text")
content = message_data.get("content", "")
options = message_data.get("options", {})
# 根据聊天流类型选择发送方法
target_id = (chat_stream.group_info.group_id if chat_stream.group_info
else chat_stream.user_info.user_id)
is_group = chat_stream.group_info is not None
# 使用通用发送方法
success = await send_api.custom_message(
message_type=message_type,
content=content,
target_id=target_id,
is_group=is_group,
typing=options.get("typing", False),
reply_to=options.get("reply_to", ""),
display_message=options.get("display_message", "")
) )
return success return success
@@ -273,90 +164,6 @@ async def smart_send(chat_stream, message_data):
系统会自动查找匹配的原始消息并进行回复。 系统会自动查找匹配的原始消息并进行回复。
## 高级用法
### 1. 消息发送队列
```python
import asyncio
class MessageQueue:
def __init__(self):
self.queue = asyncio.Queue()
self.running = False
async def add_message(self, chat_stream, message_type, content, options=None):
"""添加消息到队列"""
message_item = {
"chat_stream": chat_stream,
"type": message_type,
"content": content,
"options": options or {}
}
await self.queue.put(message_item)
async def process_queue(self):
"""处理消息队列"""
self.running = True
while self.running:
try:
message_item = await asyncio.wait_for(self.queue.get(), timeout=1.0)
# 发送消息
success = await smart_send(
message_item["chat_stream"],
{
"type": message_item["type"],
"content": message_item["content"],
"options": message_item["options"]
}
)
# 标记任务完成
self.queue.task_done()
# 发送间隔
await asyncio.sleep(0.5)
except asyncio.TimeoutError:
continue
except Exception as e:
print(f"处理消息队列出错: {e}")
```
### 2. 消息模板系统
```python
class MessageTemplate:
def __init__(self):
self.templates = {
"welcome": "欢迎 {nickname} 加入群聊!",
"goodbye": "{nickname} 离开了群聊。",
"notification": "🔔 通知:{message}",
"error": "❌ 错误:{error_message}",
"success": "✅ 成功:{message}"
}
def format_message(self, template_name: str, **kwargs) -> str:
"""格式化消息模板"""
template = self.templates.get(template_name, "{message}")
return template.format(**kwargs)
async def send_template(self, chat_stream, template_name: str, **kwargs):
"""发送模板消息"""
message = self.format_message(template_name, **kwargs)
if chat_stream.group_info:
return await send_api.text_to_group(message, chat_stream.group_info.group_id)
else:
return await send_api.text_to_user(message, chat_stream.user_info.user_id)
# 使用示例
template_system = MessageTemplate()
await template_system.send_template(chat_stream, "welcome", nickname="张三")
```
## 注意事项 ## 注意事项
1. **异步操作**:所有发送函数都是异步的,必须使用`await` 1. **异步操作**:所有发送函数都是异步的,必须使用`await`

View File

@@ -0,0 +1,55 @@
# 工具API
工具API模块提供了获取和管理工具实例的功能,让插件能够访问系统中注册的工具。
## 导入方式
```python
from src.plugin_system.apis import tool_api
# 或者
from src.plugin_system import tool_api
```
## 主要功能
### 1. 获取工具实例
```python
def get_tool_instance(tool_name: str) -> Optional[BaseTool]:
```
获取指定名称的工具实例。
**Args**:
- `tool_name`: 工具名称字符串
**Returns**:
- `Optional[BaseTool]`: 工具实例,如果工具不存在则返回 None
### 2. 获取LLM可用的工具定义
```python
def get_llm_available_tool_definitions():
```
获取所有LLM可用的工具定义列表。
**Returns**:
- `List[Tuple[str, Dict[str, Any]]]`: 工具定义列表,每个元素为 `(工具名称, 工具定义字典)` 的元组
- 其具体定义请参照[tool-components.md](../tool-components.md#属性说明)中的工具定义格式。
#### 示例:
```python
# 获取所有LLM可用的工具定义
tools = tool_api.get_llm_available_tool_definitions()
for tool_name, tool_definition in tools:
print(f"工具: {tool_name}")
print(f"定义: {tool_definition}")
```
## 注意事项
1. **工具存在性检查**:使用前请检查工具实例是否为 None
2. **权限控制**:某些工具可能有使用权限限制
3. **异步调用**:大多数工具方法是异步的,需要使用 await
4. **错误处理**:调用工具时请做好异常处理

View File

@@ -1,435 +0,0 @@
# 工具API
工具API模块提供了各种辅助功能包括文件操作、时间处理、唯一ID生成等常用工具函数。
## 导入方式
```python
from src.plugin_system.apis import utils_api
```
## 主要功能
### 1. 文件操作
#### `get_plugin_path(caller_frame=None) -> str`
获取调用者插件的路径
**参数:**
- `caller_frame`调用者的栈帧默认为None自动获取
**返回:**
- `str`:插件目录的绝对路径
**示例:**
```python
plugin_path = utils_api.get_plugin_path()
print(f"插件路径: {plugin_path}")
```
#### `read_json_file(file_path: str, default: Any = None) -> Any`
读取JSON文件
**参数:**
- `file_path`:文件路径,可以是相对于插件目录的路径
- `default`:如果文件不存在或读取失败时返回的默认值
**返回:**
- `Any`JSON数据或默认值
**示例:**
```python
# 读取插件配置文件
config = utils_api.read_json_file("config.json", {})
settings = utils_api.read_json_file("data/settings.json", {"enabled": True})
```
#### `write_json_file(file_path: str, data: Any, indent: int = 2) -> bool`
写入JSON文件
**参数:**
- `file_path`:文件路径,可以是相对于插件目录的路径
- `data`:要写入的数据
- `indent`JSON缩进
**返回:**
- `bool`:是否写入成功
**示例:**
```python
data = {"name": "test", "value": 123}
success = utils_api.write_json_file("output.json", data)
```
### 2. 时间相关
#### `get_timestamp() -> int`
获取当前时间戳
**返回:**
- `int`:当前时间戳(秒)
#### `format_time(timestamp: Optional[int] = None, format_str: str = "%Y-%m-%d %H:%M:%S") -> str`
格式化时间
**参数:**
- `timestamp`时间戳如果为None则使用当前时间
- `format_str`:时间格式字符串
**返回:**
- `str`:格式化后的时间字符串
#### `parse_time(time_str: str, format_str: str = "%Y-%m-%d %H:%M:%S") -> int`
解析时间字符串为时间戳
**参数:**
- `time_str`:时间字符串
- `format_str`:时间格式字符串
**返回:**
- `int`:时间戳(秒)
### 3. 其他工具
#### `generate_unique_id() -> str`
生成唯一ID
**返回:**
- `str`唯一ID
## 使用示例
### 1. 插件数据管理
```python
from src.plugin_system.apis import utils_api
class DataPlugin(BasePlugin):
def __init__(self):
self.plugin_path = utils_api.get_plugin_path()
self.data_file = "plugin_data.json"
self.load_data()
def load_data(self):
"""加载插件数据"""
default_data = {
"users": {},
"settings": {"enabled": True},
"stats": {"message_count": 0}
}
self.data = utils_api.read_json_file(self.data_file, default_data)
def save_data(self):
"""保存插件数据"""
return utils_api.write_json_file(self.data_file, self.data)
async def handle_action(self, action_data, chat_stream):
# 更新统计信息
self.data["stats"]["message_count"] += 1
self.data["stats"]["last_update"] = utils_api.get_timestamp()
# 保存数据
if self.save_data():
return {"success": True, "message": "数据已保存"}
else:
return {"success": False, "message": "数据保存失败"}
```
### 2. 日志记录系统
```python
class PluginLogger:
def __init__(self, plugin_name: str):
self.plugin_name = plugin_name
self.log_file = f"{plugin_name}_log.json"
self.logs = utils_api.read_json_file(self.log_file, [])
def log_event(self, event_type: str, message: str, data: dict = None):
"""记录事件"""
log_entry = {
"id": utils_api.generate_unique_id(),
"timestamp": utils_api.get_timestamp(),
"formatted_time": utils_api.format_time(),
"event_type": event_type,
"message": message,
"data": data or {}
}
self.logs.append(log_entry)
# 保持最新的100条记录
if len(self.logs) > 100:
self.logs = self.logs[-100:]
# 保存到文件
utils_api.write_json_file(self.log_file, self.logs)
def get_logs_by_type(self, event_type: str) -> list:
"""获取指定类型的日志"""
return [log for log in self.logs if log["event_type"] == event_type]
def get_recent_logs(self, count: int = 10) -> list:
"""获取最近的日志"""
return self.logs[-count:]
# 使用示例
logger = PluginLogger("my_plugin")
logger.log_event("user_action", "用户发送了消息", {"user_id": "123", "message": "hello"})
```
### 3. 配置管理系统
```python
class ConfigManager:
def __init__(self, config_file: str = "plugin_config.json"):
self.config_file = config_file
self.default_config = {
"enabled": True,
"debug": False,
"max_users": 100,
"response_delay": 1.0,
"features": {
"auto_reply": True,
"logging": True
}
}
self.config = self.load_config()
def load_config(self) -> dict:
"""加载配置"""
return utils_api.read_json_file(self.config_file, self.default_config)
def save_config(self) -> bool:
"""保存配置"""
return utils_api.write_json_file(self.config_file, self.config, indent=4)
def get(self, key: str, default=None):
"""获取配置值,支持嵌套访问"""
keys = key.split('.')
value = self.config
for k in keys:
if isinstance(value, dict) and k in value:
value = value[k]
else:
return default
return value
def set(self, key: str, value):
"""设置配置值,支持嵌套设置"""
keys = key.split('.')
config = self.config
for k in keys[:-1]:
if k not in config:
config[k] = {}
config = config[k]
config[keys[-1]] = value
def update_config(self, updates: dict):
"""批量更新配置"""
def deep_update(base, updates):
for key, value in updates.items():
if isinstance(value, dict) and key in base and isinstance(base[key], dict):
deep_update(base[key], value)
else:
base[key] = value
deep_update(self.config, updates)
# 使用示例
config = ConfigManager()
print(f"调试模式: {config.get('debug', False)}")
print(f"自动回复: {config.get('features.auto_reply', True)}")
config.set('features.new_feature', True)
config.save_config()
```
### 4. 缓存系统
```python
class PluginCache:
def __init__(self, cache_file: str = "plugin_cache.json", ttl: int = 3600):
self.cache_file = cache_file
self.ttl = ttl # 缓存过期时间(秒)
self.cache = self.load_cache()
def load_cache(self) -> dict:
"""加载缓存"""
return utils_api.read_json_file(self.cache_file, {})
def save_cache(self):
"""保存缓存"""
return utils_api.write_json_file(self.cache_file, self.cache)
def get(self, key: str):
"""获取缓存值"""
if key not in self.cache:
return None
item = self.cache[key]
current_time = utils_api.get_timestamp()
# 检查是否过期
if current_time - item["timestamp"] > self.ttl:
del self.cache[key]
return None
return item["value"]
def set(self, key: str, value):
"""设置缓存值"""
self.cache[key] = {
"value": value,
"timestamp": utils_api.get_timestamp()
}
self.save_cache()
def clear_expired(self):
"""清理过期缓存"""
current_time = utils_api.get_timestamp()
expired_keys = []
for key, item in self.cache.items():
if current_time - item["timestamp"] > self.ttl:
expired_keys.append(key)
for key in expired_keys:
del self.cache[key]
if expired_keys:
self.save_cache()
return len(expired_keys)
# 使用示例
cache = PluginCache(ttl=1800) # 30分钟过期
cache.set("user_data_123", {"name": "张三", "score": 100})
user_data = cache.get("user_data_123")
```
### 5. 时间处理工具
```python
class TimeHelper:
@staticmethod
def get_time_info():
"""获取当前时间的详细信息"""
timestamp = utils_api.get_timestamp()
return {
"timestamp": timestamp,
"datetime": utils_api.format_time(timestamp),
"date": utils_api.format_time(timestamp, "%Y-%m-%d"),
"time": utils_api.format_time(timestamp, "%H:%M:%S"),
"year": utils_api.format_time(timestamp, "%Y"),
"month": utils_api.format_time(timestamp, "%m"),
"day": utils_api.format_time(timestamp, "%d"),
"weekday": utils_api.format_time(timestamp, "%A")
}
@staticmethod
def time_ago(timestamp: int) -> str:
"""计算时间差"""
current = utils_api.get_timestamp()
diff = current - timestamp
if diff < 60:
return f"{diff}秒前"
elif diff < 3600:
return f"{diff // 60}分钟前"
elif diff < 86400:
return f"{diff // 3600}小时前"
else:
return f"{diff // 86400}天前"
@staticmethod
def parse_duration(duration_str: str) -> int:
"""解析时间段字符串,返回秒数"""
import re
pattern = r'(\d+)([smhd])'
matches = re.findall(pattern, duration_str.lower())
total_seconds = 0
for value, unit in matches:
value = int(value)
if unit == 's':
total_seconds += value
elif unit == 'm':
total_seconds += value * 60
elif unit == 'h':
total_seconds += value * 3600
elif unit == 'd':
total_seconds += value * 86400
return total_seconds
# 使用示例
time_info = TimeHelper.get_time_info()
print(f"当前时间: {time_info['datetime']}")
last_seen = 1699000000
print(f"最后见面: {TimeHelper.time_ago(last_seen)}")
duration = TimeHelper.parse_duration("1h30m") # 1小时30分钟 = 5400秒
```
## 最佳实践
### 1. 错误处理
```python
def safe_file_operation(file_path: str, data: dict):
"""安全的文件操作"""
try:
success = utils_api.write_json_file(file_path, data)
if not success:
logger.warning(f"文件写入失败: {file_path}")
return success
except Exception as e:
logger.error(f"文件操作出错: {e}")
return False
```
### 2. 路径处理
```python
import os
def get_data_path(filename: str) -> str:
"""获取数据文件的完整路径"""
plugin_path = utils_api.get_plugin_path()
data_dir = os.path.join(plugin_path, "data")
# 确保数据目录存在
os.makedirs(data_dir, exist_ok=True)
return os.path.join(data_dir, filename)
```
### 3. 定期清理
```python
async def cleanup_old_files():
"""清理旧文件"""
plugin_path = utils_api.get_plugin_path()
current_time = utils_api.get_timestamp()
for filename in os.listdir(plugin_path):
if filename.endswith('.tmp'):
file_path = os.path.join(plugin_path, filename)
file_time = os.path.getmtime(file_path)
# 删除超过24小时的临时文件
if current_time - file_time > 86400:
os.remove(file_path)
```
## 注意事项
1. **相对路径**:文件路径支持相对于插件目录的路径
2. **自动创建目录**:写入文件时会自动创建必要的目录
3. **错误处理**:所有函数都有错误处理,失败时返回默认值
4. **编码格式**文件读写使用UTF-8编码
5. **时间格式**:时间戳使用秒为单位
6. **JSON格式**JSON文件使用可读性好的缩进格式

View File

@@ -10,6 +10,7 @@
- [🧱 Action组件详解](action-components.md) - 掌握最核心的Action组件 - [🧱 Action组件详解](action-components.md) - 掌握最核心的Action组件
- [💻 Command组件详解](command-components.md) - 学习直接响应命令的组件 - [💻 Command组件详解](command-components.md) - 学习直接响应命令的组件
- [🔧 Tool组件详解](tool-components.md) - 了解如何扩展信息获取能力
- [⚙️ 配置文件系统指南](configuration-guide.md) - 学会使用自动生成的插件配置文件 - [⚙️ 配置文件系统指南](configuration-guide.md) - 学会使用自动生成的插件配置文件
- [📄 Manifest系统指南](manifest-guide.md) - 了解插件元数据管理和配置架构 - [📄 Manifest系统指南](manifest-guide.md) - 了解插件元数据管理和配置架构
@@ -43,24 +44,24 @@ Command vs Action 选择指南
- [LLM API](api/llm-api.md) - 大语言模型交互接口可以使用内置LLM生成内容 - [LLM API](api/llm-api.md) - 大语言模型交互接口可以使用内置LLM生成内容
- [✨ 回复生成器API](api/generator-api.md) - 智能回复生成接口,可以使用内置风格化生成器 - [✨ 回复生成器API](api/generator-api.md) - 智能回复生成接口,可以使用内置风格化生成器
### 表情包api ### 表情包API
- [😊 表情包API](api/emoji-api.md) - 表情包选择和管理接口 - [😊 表情包API](api/emoji-api.md) - 表情包选择和管理接口
### 关系系统api ### 关系系统API
- [人物信息API](api/person-api.md) - 用户信息,处理麦麦认识的人和关系的接口 - [人物信息API](api/person-api.md) - 用户信息,处理麦麦认识的人和关系的接口
### 数据与配置API ### 数据与配置API
- [🗄️ 数据库API](api/database-api.md) - 数据库操作接口 - [🗄️ 数据库API](api/database-api.md) - 数据库操作接口
- [⚙️ 配置API](api/config-api.md) - 配置读取和用户信息接口 - [⚙️ 配置API](api/config-api.md) - 配置读取和用户信息接口
### 插件和组件管理API
- [🔌 插件API](api/plugin-manage-api.md) - 插件加载和管理接口
- [🧩 组件API](api/component-manage-api.md) - 组件注册和管理接口
### 日志API
- [📜 日志API](api/logging-api.md) - logger实例获取接口
### 工具API ### 工具API
- [工具API](api/utils-api.md) - 文件操作、时间处理等工具函数 - [🔧 工具API](api/tool-api.md) - tool获取接口
## 实验性
这些功能将在未来重构或移除
- [🔧 工具系统详解](tool-system.md) - 工具系统的使用和开发

View File

@@ -1,10 +1,10 @@
# 🔧 工具系统详解 # 🔧 工具组件详解
## 📖 什么是工具系统 ## 📖 什么是工具
工具系统是MaiBot的信息获取能力扩展组件。如果说Action组件功能五花八门可以拓展麦麦能做的事情那么Tool就是在某个过程中拓宽了麦麦能够获得的信息量。 工具是MaiBot的信息获取能力扩展组件。如果说Action组件功能五花八门可以拓展麦麦能做的事情那么Tool就是在某个过程中拓宽了麦麦能够获得的信息量。
### 🎯 工具系统的特点 ### 🎯 工具的特点
- 🔍 **信息获取增强**:扩展麦麦获取外部信息的能力 - 🔍 **信息获取增强**:扩展麦麦获取外部信息的能力
- 📊 **数据丰富**:帮助麦麦获得更多背景信息和实时数据 - 📊 **数据丰富**:帮助麦麦获得更多背景信息和实时数据
@@ -20,14 +20,11 @@
| **目标** | 让麦麦做更多事情 | 提供具体功能 | 让麦麦知道更多信息 | | **目标** | 让麦麦做更多事情 | 提供具体功能 | 让麦麦知道更多信息 |
| **使用场景** | 增强交互体验 | 功能服务 | 信息查询和分析 | | **使用场景** | 增强交互体验 | 功能服务 | 信息查询和分析 |
## 🏗️ 工具基本结构 ## 🏗️ Tool组件的基本结构
### 必要组件
每个工具必须继承 `BaseTool` 基类并实现以下属性和方法: 每个工具必须继承 `BaseTool` 基类并实现以下属性和方法:
```python ```python
from src.tools.tool_can_use.base_tool import BaseTool, register_tool from src.plugin_system import BaseTool, ToolParamType
class MyTool(BaseTool): class MyTool(BaseTool):
# 工具名称,必须唯一 # 工具名称,必须唯一
@@ -36,21 +33,29 @@ class MyTool(BaseTool):
# 工具描述告诉LLM这个工具的用途 # 工具描述告诉LLM这个工具的用途
description = "这个工具用于获取特定类型的信息" description = "这个工具用于获取特定类型的信息"
# 参数定义,遵循JSONSchema格式 # 参数定义,仅定义参数
parameters = { # 比如想要定义一个类似下面的openai格式的参数表则可以这么定义:
"type": "object", # {
"properties": { # "type": "object",
"query": { # "properties": {
"type": "string", # "query": {
"description": "查询参数" # "type": "string",
}, # "description": "查询参数"
"limit": { # },
"type": "integer", # "limit": {
"description": "结果数量限制" # "type": "integer",
} # "description": "结果数量限制"
}, # "enum": [10, 20, 50] # 可选值
"required": ["query"] # }
} # },
# "required": ["query"]
# }
parameters = [
("query", ToolParamType.STRING, "查询参数", True, None), # 必填参数
("limit", ToolParamType.INTEGER, "结果数量限制", False, ["10", "20", "50"]) # 可选参数
]
available_for_llm = True # 是否对LLM可用
async def execute(self, function_args: Dict[str, Any]): async def execute(self, function_args: Dict[str, Any]):
"""执行工具逻辑""" """执行工具逻辑"""
@@ -69,7 +74,12 @@ class MyTool(BaseTool):
|-----|------|------| |-----|------|------|
| `name` | str | 工具的唯一标识名称 | | `name` | str | 工具的唯一标识名称 |
| `description` | str | 工具功能描述帮助LLM理解用途 | | `description` | str | 工具功能描述帮助LLM理解用途 |
| `parameters` | dict | JSONSchema格式的参数定义 | | `parameters` | list[tuple] | 参数定义 |
其构造而成的工具定义为:
```python
definition: Dict[str, Any] = {"name": cls.name, "description": cls.description, "parameters": cls.parameters}
```
### 方法说明 ### 方法说明
@@ -77,15 +87,6 @@ class MyTool(BaseTool):
|-----|------|--------|------| |-----|------|--------|------|
| `execute` | `function_args` | `dict` | 执行工具核心逻辑 | | `execute` | `function_args` | `dict` | 执行工具核心逻辑 |
## 🔄 自动注册机制
工具系统采用自动发现和注册机制:
1. **文件扫描**:系统自动遍历 `tool_can_use` 目录中的所有Python文件
2. **类识别**:寻找继承自 `BaseTool` 的工具类
3. **自动注册**:只需要实现对应的类并把文件放在正确文件夹中就可自动注册
4. **即用即加载**:工具在需要时被实例化和调用
--- ---
## 🎨 完整工具示例 ## 🎨 完整工具示例
@@ -93,7 +94,7 @@ class MyTool(BaseTool):
完成一个天气查询工具 完成一个天气查询工具
```python ```python
from src.tools.tool_can_use.base_tool import BaseTool, register_tool from src.plugin_system import BaseTool
import aiohttp import aiohttp
import json import json
@@ -102,23 +103,13 @@ class WeatherTool(BaseTool):
name = "weather_query" name = "weather_query"
description = "查询指定城市的实时天气信息,包括温度、湿度、天气状况等" description = "查询指定城市的实时天气信息,包括温度、湿度、天气状况等"
available_for_llm = True # 允许LLM调用此工具
parameters = [
("city", ToolParamType.STRING, "要查询天气的城市名称,如:北京、上海、纽约", True, None),
("country", ToolParamType.STRING, "国家代码CN、US可选参数", False, None)
]
parameters = { async def execute(self, function_args: dict):
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "要查询天气的城市名称,如:北京、上海、纽约"
},
"country": {
"type": "string",
"description": "国家代码CN、US可选参数"
}
},
"required": ["city"]
}
async def execute(self, function_args, message_txt=""):
"""执行天气查询""" """执行天气查询"""
try: try:
city = function_args.get("city") city = function_args.get("city")
@@ -177,55 +168,12 @@ class WeatherTool(BaseTool):
--- ---
## 📊 工具开发步骤
### 1. 创建工具文件
在 `src/tools/tool_can_use/` 目录下创建新的Python文件:
```bash
# 例如创建 my_new_tool.py
touch src/tools/tool_can_use/my_new_tool.py
```
### 2. 实现工具类
```python
from src.tools.tool_can_use.base_tool import BaseTool, register_tool
class MyNewTool(BaseTool):
name = "my_new_tool"
description = "新工具的功能描述"
parameters = {
"type": "object",
"properties": {
# 定义参数
},
"required": []
}
async def execute(self, function_args, message_txt=""):
# 实现工具逻辑
return {
"name": self.name,
"content": "执行结果"
}
```
### 3. 系统集成
工具创建完成后,系统会自动发现和注册,无需额外配置。
---
## 🚨 注意事项和限制 ## 🚨 注意事项和限制
### 当前限制 ### 当前限制
1. **独立开发**需要单独编写,暂未完全融入插件系统 1. **适用范围**主要适用于信息获取场景
2. **适用范围**主要适用于信息获取场景 2. **配置要求**必须开启工具处理器
3. **配置要求**:必须开启工具处理器
### 开发建议 ### 开发建议
@@ -238,66 +186,49 @@ class MyNewTool(BaseTool):
## 🎯 最佳实践 ## 🎯 最佳实践
### 1. 工具命名规范 ### 1. 工具命名规范
#### ✅ 好的命名
```python ```python
# ✅ 好的命名
name = "weather_query" # 清晰表达功能 name = "weather_query" # 清晰表达功能
name = "knowledge_search" # 描述性强 name = "knowledge_search" # 描述性强
name = "stock_price_check" # 功能明确 name = "stock_price_check" # 功能明确
```
# ❌ 避免的命名 #### ❌ 避免的命名
```python
name = "tool1" # 无意义 name = "tool1" # 无意义
name = "wq" # 过于简短 name = "wq" # 过于简短
name = "weather_and_news" # 功能过于复杂 name = "weather_and_news" # 功能过于复杂
``` ```
### 2. 描述规范 ### 2. 描述规范
#### ✅ 良好的描述
```python ```python
# ✅ 好的描述
description = "查询指定城市的实时天气信息,包括温度、湿度、天气状况" description = "查询指定城市的实时天气信息,包括温度、湿度、天气状况"
```
# ❌ 避免的描述 #### ❌ 避免的描述
```python
description = "天气" # 过于简单 description = "天气" # 过于简单
description = "获取信息" # 不够具体 description = "获取信息" # 不够具体
``` ```
### 3. 参数设计 ### 3. 参数设计
#### ✅ 合理的参数设计
```python ```python
# ✅ 合理的参数设计 parameters = [
parameters = { ("city", ToolParamType.STRING, "城市名称,如:北京、上海", True, None),
"type": "object", ("unit", ToolParamType.STRING, "温度单位celsius 或 fahrenheit", False, ["celsius", "fahrenheit"])
"properties": { ]
"city": { ```
"type": "string", #### ❌ 避免的参数设计
"description": "城市名称,如:北京、上海" ```python
}, parameters = [
"unit": { ("data", "string", "数据", True) # 参数过于模糊
"type": "string", ]
"description": "温度单位celsius(摄氏度) 或 fahrenheit(华氏度)",
"enum": ["celsius", "fahrenheit"]
}
},
"required": ["city"]
}
# ❌ 避免的参数设计
parameters = {
"type": "object",
"properties": {
"data": {
"type": "string",
"description": "数据" # 描述不清晰
}
}
}
``` ```
### 4. 结果格式化 ### 4. 结果格式化
#### ✅ 良好的结果格式
```python ```python
# ✅ 良好的结果格式
def _format_result(self, data): def _format_result(self, data):
return f""" return f"""
🔍 查询结果 🔍 查询结果
@@ -307,12 +238,9 @@ def _format_result(self, data):
📝 说明: {data['description']} 📝 说明: {data['description']}
━━━━━━━━━━━━ ━━━━━━━━━━━━
""".strip() """.strip()
```
# ❌ 避免的结果格式 #### ❌ 避免的结果格式
```python
def _format_result(self, data): def _format_result(self, data):
return str(data) # 直接返回原始数据 return str(data) # 直接返回原始数据
``` ```
---
🎉 **工具系统为麦麦提供了强大的信息获取能力!合理使用工具可以让麦麦变得更加智能和博学。**

View File

@@ -1,18 +1,55 @@
from typing import List, Tuple, Type from typing import List, Tuple, Type, Any
from src.plugin_system import ( from src.plugin_system import (
BasePlugin, BasePlugin,
register_plugin, register_plugin,
BaseAction, BaseAction,
BaseCommand, BaseCommand,
BaseTool,
ComponentInfo, ComponentInfo,
ActionActivationType, ActionActivationType,
ConfigField, ConfigField,
BaseEventHandler, BaseEventHandler,
EventType, EventType,
MaiMessages, MaiMessages,
ToolParamType
) )
class CompareNumbersTool(BaseTool):
    """Tool that compares two numbers and reports how they relate."""

    name = "compare_numbers"
    description = "使用工具 比较两个数的大小,返回较大的数"
    # Each entry: (name, type, description, required flag, allowed values).
    parameters = [
        ("num1", ToolParamType.FLOAT, "第一个数字", True, None),
        ("num2", ToolParamType.FLOAT, "第二个数字", True, None),
    ]

    async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
        """Compare ``num1`` and ``num2`` taken from ``function_args``.

        Args:
            function_args: Tool arguments; read with ``.get`` so a missing
                key yields ``None``.

        Returns:
            dict: ``{"name": ..., "content": ...}`` where ``content`` is the
            comparison sentence, or a failure message if comparing raised.
        """
        first: int | float = function_args.get("num1")  # type: ignore
        second: int | float = function_args.get("num2")  # type: ignore

        try:
            if first > second:
                relation = "大于"
            elif first < second:
                relation = "小于"
            else:
                relation = "等于"
            content = f"{first} {relation} {second}"
        except Exception as e:
            # e.g. a missing argument makes the comparison raise TypeError.
            content = f"比较数字失败,炸了: {str(e)}"
        return {"name": self.name, "content": content}
# ===== Action组件 ===== # ===== Action组件 =====
class HelloAction(BaseAction): class HelloAction(BaseAction):
"""问候Action - 简单的问候动作""" """问候Action - 简单的问候动作"""
@@ -132,7 +169,9 @@ class HelloWorldPlugin(BasePlugin):
"enabled": ConfigField(type=bool, default=False, description="是否启用插件"), "enabled": ConfigField(type=bool, default=False, description="是否启用插件"),
}, },
"greeting": { "greeting": {
"message": ConfigField(type=str, default="嗨!很开心见到你!😊", description="默认问候消息"), "message": ConfigField(
type=list, default=["嗨!很开心见到你!😊", "Ciallo(∠・ω< )⌒★"], description="默认问候消息"
),
"enable_emoji": ConfigField(type=bool, default=True, description="是否启用表情符号"), "enable_emoji": ConfigField(type=bool, default=True, description="是否启用表情符号"),
}, },
"time": {"format": ConfigField(type=str, default="%Y-%m-%d %H:%M:%S", description="时间显示格式")}, "time": {"format": ConfigField(type=str, default="%Y-%m-%d %H:%M:%S", description="时间显示格式")},
@@ -142,6 +181,7 @@ class HelloWorldPlugin(BasePlugin):
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]: def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
return [ return [
(HelloAction.get_action_info(), HelloAction), (HelloAction.get_action_info(), HelloAction),
(CompareNumbersTool.get_tool_info(), CompareNumbersTool), # 添加比较数字工具
(ByeAction.get_action_info(), ByeAction), # 添加告别Action (ByeAction.get_action_info(), ByeAction), # 添加告别Action
(TimeCommand.get_command_info(), TimeCommand), (TimeCommand.get_command_info(), TimeCommand),
(PrintMessage.get_handler_info(), PrintMessage), (PrintMessage.get_handler_info(), PrintMessage),

View File

@@ -15,6 +15,7 @@ matplotlib
networkx networkx
numpy numpy
openai openai
google-genai
pandas pandas
peewee peewee
pyarrow pyarrow

View File

@@ -1,2 +0,0 @@
@echo off
start "Voice Adapter" cmd /k "call conda activate maicore && cd /d C:\GitHub\maimbot_tts_adapter && echo Running Napcat Adapter... && python maimbot_pipeline.py"

View File

@@ -14,8 +14,6 @@ from src.chat.knowledge.open_ie import OpenIE
from src.chat.knowledge.kg_manager import KGManager from src.chat.knowledge.kg_manager import KGManager
from src.common.logger import get_logger from src.common.logger import get_logger
from src.chat.knowledge.utils.hash import get_sha256 from src.chat.knowledge.utils.hash import get_sha256
from src.manager.local_store_manager import local_storage
from dotenv import load_dotenv
# 添加项目根目录到 sys.path # 添加项目根目录到 sys.path
@@ -24,46 +22,6 @@ OPENIE_DIR = os.path.join(ROOT_PATH, "data", "openie")
logger = get_logger("OpenIE导入") logger = get_logger("OpenIE导入")
ENV_FILE = os.path.join(ROOT_PATH, ".env")
if os.path.exists(".env"):
load_dotenv(".env", override=True)
print("成功加载环境变量配置")
else:
print("未找到.env文件请确保程序所需的环境变量被正确设置")
raise FileNotFoundError(".env 文件不存在,请创建并配置所需的环境变量")
env_mask = {key: os.getenv(key) for key in os.environ}
def scan_provider(env_config: dict):
    """Check that every provider found in ``env_config`` has both a URL and a key.

    Variables whose names appear in the module-level ``env_mask`` are ignored.
    Of the rest, names ending in ``_BASE_URL`` or ``_KEY`` are grouped by the
    text before their first underscore, which is treated as the provider name.

    Args:
        env_config: Mapping of environment variable names to values.

    Raises:
        ValueError: If a provider has only one of the two settings.
    """
    # Filter out anything already listed in env_mask (e.g. GPG_KEY) so
    # unrelated variables do not look like provider settings.
    env_config = {name: value for name, value in env_config.items() if name not in env_mask}

    providers = {}
    for name, value in env_config.items():
        if name.endswith("_BASE_URL"):
            slot = "url"
        elif name.endswith("_KEY"):
            slot = "key"
        else:
            continue
        # Provider name is everything before the first underscore.
        provider_name = name.split("_", 1)[0]
        providers.setdefault(provider_name, {"url": None, "key": None})[slot] = value

    for provider_name, config in providers.items():
        if config["url"] is None or config["key"] is None:
            logger.error(f"provider 内容:{config}\nenv_config 内容:{env_config}")
            raise ValueError(f"请检查 '{provider_name}' 提供商配置是否丢失 BASE_URL 或 KEY 环境变量")
def ensure_openie_dir(): def ensure_openie_dir():
"""确保OpenIE数据目录存在""" """确保OpenIE数据目录存在"""
if not os.path.exists(OPENIE_DIR): if not os.path.exists(OPENIE_DIR):
@@ -101,7 +59,9 @@ def hash_deduplicate(
): ):
# 段落hash # 段落hash
paragraph_hash = get_sha256(raw_paragraph) paragraph_hash = get_sha256(raw_paragraph)
if f"{local_storage['pg_namespace']}-{paragraph_hash}" in stored_pg_hashes and paragraph_hash in stored_paragraph_hashes: # 使用与EmbeddingStore中一致的命名空间格式namespace-hash
paragraph_key = f"paragraph-{paragraph_hash}"
if paragraph_key in stored_pg_hashes and paragraph_hash in stored_paragraph_hashes:
continue continue
new_raw_paragraphs[paragraph_hash] = raw_paragraph new_raw_paragraphs[paragraph_hash] = raw_paragraph
new_triple_list_data[paragraph_hash] = triple_list new_triple_list_data[paragraph_hash] = triple_list
@@ -214,8 +174,6 @@ def handle_import_openie(openie_data: OpenIE, embed_manager: EmbeddingManager, k
def main(): # sourcery skip: dict-comprehension def main(): # sourcery skip: dict-comprehension
# 新增确认提示 # 新增确认提示
env_config = {key: os.getenv(key) for key in os.environ}
scan_provider(env_config)
print("=== 重要操作确认 ===") print("=== 重要操作确认 ===")
print("OpenIE导入时会大量发送请求可能会撞到请求速度上限请注意选用的模型") print("OpenIE导入时会大量发送请求可能会撞到请求速度上限请注意选用的模型")
print("同之前样例在本地模型下在70分钟内我们发送了约8万条请求在网络允许下速度会更快") print("同之前样例在本地模型下在70分钟内我们发送了约8万条请求在网络允许下速度会更快")
@@ -264,7 +222,8 @@ def main(): # sourcery skip: dict-comprehension
# 数据比对Embedding库与KG的段落hash集合 # 数据比对Embedding库与KG的段落hash集合
for pg_hash in kg_manager.stored_paragraph_hashes: for pg_hash in kg_manager.stored_paragraph_hashes:
key = f"{local_storage['pg_namespace']}-{pg_hash}" # 使用与EmbeddingStore中一致的命名空间格式namespace-hash
key = f"paragraph-{pg_hash}"
if key not in embed_manager.stored_pg_hashes: if key not in embed_manager.stored_pg_hashes:
logger.warning(f"KG中存在Embedding库中不存在的段落{key}") logger.warning(f"KG中存在Embedding库中不存在的段落{key}")

View File

@@ -25,9 +25,8 @@ from rich.progress import (
TextColumn, TextColumn,
) )
from raw_data_preprocessor import RAW_DATA_PATH, load_raw_data from raw_data_preprocessor import RAW_DATA_PATH, load_raw_data
from src.config.config import global_config from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from dotenv import load_dotenv
logger = get_logger("LPMM知识库-信息提取") logger = get_logger("LPMM知识库-信息提取")
@@ -36,45 +35,6 @@ ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
TEMP_DIR = os.path.join(ROOT_PATH, "temp") TEMP_DIR = os.path.join(ROOT_PATH, "temp")
# IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, "data", "imported_lpmm_data") # IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, "data", "imported_lpmm_data")
OPENIE_OUTPUT_DIR = os.path.join(ROOT_PATH, "data", "openie") OPENIE_OUTPUT_DIR = os.path.join(ROOT_PATH, "data", "openie")
ENV_FILE = os.path.join(ROOT_PATH, ".env")
if os.path.exists(".env"):
load_dotenv(".env", override=True)
print("成功加载环境变量配置")
else:
print("未找到.env文件请确保程序所需的环境变量被正确设置")
raise FileNotFoundError(".env 文件不存在,请创建并配置所需的环境变量")
env_mask = {key: os.getenv(key) for key in os.environ}
def scan_provider(env_config: dict):
    """Check that every provider found in ``env_config`` has both a URL and a key.

    Variables whose names appear in the module-level ``env_mask`` are ignored.
    Of the rest, names ending in ``_BASE_URL`` or ``_KEY`` are grouped by the
    text before their first underscore, which is treated as the provider name.

    Args:
        env_config: Mapping of environment variable names to values.

    Raises:
        ValueError: If a provider has only one of the two settings.
    """
    # Filter out anything already listed in env_mask (e.g. GPG_KEY) so
    # unrelated variables do not look like provider settings.
    env_config = {name: value for name, value in env_config.items() if name not in env_mask}

    providers = {}
    for name, value in env_config.items():
        if name.endswith("_BASE_URL"):
            slot = "url"
        elif name.endswith("_KEY"):
            slot = "key"
        else:
            continue
        # Provider name is everything before the first underscore.
        provider_name = name.split("_", 1)[0]
        providers.setdefault(provider_name, {"url": None, "key": None})[slot] = value

    for provider_name, config in providers.items():
        if config["url"] is None or config["key"] is None:
            logger.error(f"provider 内容:{config}\nenv_config 内容:{env_config}")
            raise ValueError(f"请检查 '{provider_name}' 提供商配置是否丢失 BASE_URL 或 KEY 环境变量")
def ensure_dirs(): def ensure_dirs():
"""确保临时目录和输出目录存在""" """确保临时目录和输出目录存在"""
@@ -96,11 +56,11 @@ open_ie_doc_lock = Lock()
shutdown_event = Event() shutdown_event = Event()
lpmm_entity_extract_llm = LLMRequest( lpmm_entity_extract_llm = LLMRequest(
model=global_config.model.lpmm_entity_extract, model_set=model_config.model_task_config.lpmm_entity_extract,
request_type="lpmm.entity_extract" request_type="lpmm.entity_extract"
) )
lpmm_rdf_build_llm = LLMRequest( lpmm_rdf_build_llm = LLMRequest(
model=global_config.model.lpmm_rdf_build, model_set=model_config.model_task_config.lpmm_rdf_build,
request_type="lpmm.rdf_build" request_type="lpmm.rdf_build"
) )
def process_single_text(pg_hash, raw_data): def process_single_text(pg_hash, raw_data):
@@ -158,8 +118,6 @@ def main(): # sourcery skip: comprehension-to-generator, extract-method
# 设置信号处理器 # 设置信号处理器
signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGINT, signal_handler)
ensure_dirs() # 确保目录存在 ensure_dirs() # 确保目录存在
env_config = {key: os.getenv(key) for key in os.environ}
scan_provider(env_config)
# 新增用户确认提示 # 新增用户确认提示
print("=== 重要操作确认,请认真阅读以下内容哦 ===") print("=== 重要操作确认,请认真阅读以下内容哦 ===")
print("实体提取操作将会花费较多api余额和时间建议在空闲时段执行。") print("实体提取操作将会花费较多api余额和时间建议在空闲时段执行。")

File diff suppressed because it is too large Load Diff

View File

@@ -8,15 +8,15 @@ import traceback
import io import io
import re import re
import binascii import binascii
from typing import Optional, Tuple, List, Any from typing import Optional, Tuple, List, Any
from PIL import Image from PIL import Image
from rich.traceback import install from rich.traceback import install
from src.common.database.database_model import Emoji from src.common.database.database_model import Emoji
from src.common.database.database import db as peewee_db from src.common.database.database import db as peewee_db
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.config import global_config from src.config.config import global_config, model_config
from src.chat.utils.utils_image import image_path_to_base64, get_image_manager from src.chat.utils.utils_image import image_path_to_base64, get_image_manager
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
@@ -379,9 +379,9 @@ class EmojiManager:
self._scan_task = None self._scan_task = None
self.vlm = LLMRequest(model=global_config.model.vlm, temperature=0.3, max_tokens=1000, request_type="emoji") self.vlm = LLMRequest(model_set=model_config.model_task_config.vlm, request_type="emoji")
self.llm_emotion_judge = LLMRequest( self.llm_emotion_judge = LLMRequest(
model=global_config.model.utils, max_tokens=600, request_type="emoji" model_set=model_config.model_task_config.utils, request_type="emoji"
) # 更高的温度更少的token后续可以根据情绪来调整温度 ) # 更高的温度更少的token后续可以根据情绪来调整温度
self.emoji_num = 0 self.emoji_num = 0
@@ -492,6 +492,7 @@ class EmojiManager:
return None return None
def _levenshtein_distance(self, s1: str, s2: str) -> int: def _levenshtein_distance(self, s1: str, s2: str) -> int:
# sourcery skip: simplify-empty-collection-comparison, simplify-len-comparison, simplify-str-len-comparison
"""计算两个字符串的编辑距离 """计算两个字符串的编辑距离
Args: Args:
@@ -629,11 +630,11 @@ class EmojiManager:
if success: if success:
# 注册成功则跳出循环 # 注册成功则跳出循环
break break
else:
# 注册失败则删除对应文件 # 注册失败则删除对应文件
file_path = os.path.join(EMOJI_DIR, filename) file_path = os.path.join(EMOJI_DIR, filename)
os.remove(file_path) os.remove(file_path)
logger.warning(f"[清理] 删除注册失败的表情包文件: {filename}") logger.warning(f"[清理] 删除注册失败的表情包文件: {filename}")
except Exception as e: except Exception as e:
logger.error(f"[错误] 扫描表情包目录失败: {str(e)}") logger.error(f"[错误] 扫描表情包目录失败: {str(e)}")
@@ -694,6 +695,7 @@ class EmojiManager:
return [] return []
async def get_emoji_from_manager(self, emoji_hash: str) -> Optional["MaiEmoji"]: async def get_emoji_from_manager(self, emoji_hash: str) -> Optional["MaiEmoji"]:
# sourcery skip: use-next
"""从内存中的 emoji_objects 列表获取表情包 """从内存中的 emoji_objects 列表获取表情包
参数: 参数:
@@ -706,13 +708,45 @@ class EmojiManager:
if not emoji.is_deleted and emoji.hash == emoji_hash: if not emoji.is_deleted and emoji.hash == emoji_hash:
return emoji return emoji
return None # 如果循环结束还没找到,则返回 None return None # 如果循环结束还没找到,则返回 None
async def get_emoji_tag_by_hash(self, emoji_hash: str) -> Optional[str]:
    """Look up the emotion tags of a registered emoji by its hash.

    The in-memory emoji list is consulted first; the database is queried
    only when that yields nothing.

    Args:
        emoji_hash: Hash of the emoji to look up.

    Returns:
        Optional[str]: The in-memory ``emotion`` entries joined with ``","``,
        or the ``emotion`` value of the database record; ``None`` when
        nothing is found or an error occurs.
    """
    try:
        # In-memory lookup.
        cached = await self.get_emoji_from_manager(emoji_hash)
        if cached and cached.emotion:
            logger.info(f"[缓存命中] 从内存获取表情包描述: {cached.emotion}...")
            return ",".join(cached.emotion)

        # Fall back to the database.
        self._ensure_db()
        try:
            record = Emoji.get_or_none(Emoji.emoji_hash == emoji_hash)
            if record and record.emotion:
                logger.info(f"[缓存命中] 从数据库获取表情包描述: {record.emotion[:50]}...")
                return record.emotion
        except Exception as db_err:
            # A failed query is logged and treated as "not found".
            logger.error(f"从数据库查询表情包描述时出错: {db_err}")

        return None

    except Exception as err:
        logger.error(f"获取表情包描述失败 (Hash: {emoji_hash}): {str(err)}")
        return None
async def get_emoji_description_by_hash(self, emoji_hash: str) -> Optional[str]: async def get_emoji_description_by_hash(self, emoji_hash: str) -> Optional[str]:
"""根据哈希值获取已注册表情包的描述 """根据哈希值获取已注册表情包的描述
Args: Args:
emoji_hash: 表情包的哈希值 emoji_hash: 表情包的哈希值
Returns: Returns:
Optional[str]: 表情包描述如果未找到则返回None Optional[str]: 表情包描述如果未找到则返回None
""" """
@@ -722,7 +756,7 @@ class EmojiManager:
if emoji and emoji.description: if emoji and emoji.description:
logger.info(f"[缓存命中] 从内存获取表情包描述: {emoji.description[:50]}...") logger.info(f"[缓存命中] 从内存获取表情包描述: {emoji.description[:50]}...")
return emoji.description return emoji.description
# 如果内存中没有,从数据库查找 # 如果内存中没有,从数据库查找
self._ensure_db() self._ensure_db()
try: try:
@@ -732,9 +766,9 @@ class EmojiManager:
return emoji_record.description return emoji_record.description
except Exception as e: except Exception as e:
logger.error(f"从数据库查询表情包描述时出错: {e}") logger.error(f"从数据库查询表情包描述时出错: {e}")
return None return None
except Exception as e: except Exception as e:
logger.error(f"获取表情包描述失败 (Hash: {emoji_hash}): {str(e)}") logger.error(f"获取表情包描述失败 (Hash: {emoji_hash}): {str(e)}")
return None return None
@@ -779,6 +813,7 @@ class EmojiManager:
return False return False
async def replace_a_emoji(self, new_emoji: "MaiEmoji") -> bool: async def replace_a_emoji(self, new_emoji: "MaiEmoji") -> bool:
# sourcery skip: use-getitem-for-re-match-groups
"""替换一个表情包 """替换一个表情包
Args: Args:
@@ -820,7 +855,7 @@ class EmojiManager:
) )
# 调用大模型进行决策 # 调用大模型进行决策
decision, _ = await self.llm_emotion_judge.generate_response_async(prompt, temperature=0.8) decision, _ = await self.llm_emotion_judge.generate_response_async(prompt, temperature=0.8, max_tokens=600)
logger.info(f"[决策] 结果: {decision}") logger.info(f"[决策] 结果: {decision}")
# 解析决策结果 # 解析决策结果
@@ -828,9 +863,7 @@ class EmojiManager:
logger.info("[决策] 不删除任何表情包") logger.info("[决策] 不删除任何表情包")
return False return False
# 尝试从决策中提取表情包编号 if match := re.search(r"删除编号(\d+)", decision):
match = re.search(r"删除编号(\d+)", decision)
if match:
emoji_index = int(match.group(1)) - 1 # 转换为0-based索引 emoji_index = int(match.group(1)) - 1 # 转换为0-based索引
# 检查索引是否有效 # 检查索引是否有效
@@ -889,6 +922,7 @@ class EmojiManager:
existing_description = None existing_description = None
try: try:
from src.common.database.database_model import Images from src.common.database.database_model import Images
existing_image = Images.get_or_none((Images.emoji_hash == image_hash) & (Images.type == "emoji")) existing_image = Images.get_or_none((Images.emoji_hash == image_hash) & (Images.type == "emoji"))
if existing_image and existing_image.description: if existing_image and existing_image.description:
existing_description = existing_image.description existing_description = existing_image.description
@@ -902,15 +936,21 @@ class EmojiManager:
logger.info("[优化] 复用已有的详细描述跳过VLM调用") logger.info("[优化] 复用已有的详细描述跳过VLM调用")
else: else:
logger.info("[VLM分析] 生成新的详细描述") logger.info("[VLM分析] 生成新的详细描述")
if image_format == "gif" or image_format == "GIF": if image_format in ["gif", "GIF"]:
image_base64 = get_image_manager().transform_gif(image_base64) # type: ignore image_base64 = get_image_manager().transform_gif(image_base64) # type: ignore
if not image_base64: if not image_base64:
raise RuntimeError("GIF表情包转换失败") raise RuntimeError("GIF表情包转换失败")
prompt = "这是一个动态图表情包,每一张图代表了动态图的某一帧,黑色背景代表透明,描述一下表情包表达的情感和内容,描述细节,从互联网梗,meme的角度去分析" prompt = "这是一个动态图表情包,每一张图代表了动态图的某一帧,黑色背景代表透明,描述一下表情包表达的情感和内容,描述细节,从互联网梗,meme的角度去分析"
description, _ = await self.vlm.generate_response_for_image(prompt, image_base64, "jpg") description, _ = await self.vlm.generate_response_for_image(
prompt, image_base64, "jpg", temperature=0.3, max_tokens=1000
)
else: else:
prompt = "这是一个表情包,请详细描述一下表情包所表达的情感和内容,描述细节,从互联网梗,meme的角度去分析" prompt = (
description, _ = await self.vlm.generate_response_for_image(prompt, image_base64, image_format) "这是一个表情包,请详细描述一下表情包所表达的情感和内容,描述细节,从互联网梗,meme的角度去分析"
)
description, _ = await self.vlm.generate_response_for_image(
prompt, image_base64, image_format, temperature=0.3, max_tokens=1000
)
# 审核表情包 # 审核表情包
if global_config.emoji.content_filtration: if global_config.emoji.content_filtration:
@@ -922,7 +962,9 @@ class EmojiManager:
4. 不要出现5个以上文字 4. 不要出现5个以上文字
请回答这个表情包是否满足上述要求,是则回答是,否则回答否,不要出现任何其他内容 请回答这个表情包是否满足上述要求,是则回答是,否则回答否,不要出现任何其他内容
''' '''
content, _ = await self.vlm.generate_response_for_image(prompt, image_base64, image_format) content, _ = await self.vlm.generate_response_for_image(
prompt, image_base64, image_format, temperature=0.3, max_tokens=1000
)
if content == "": if content == "":
return "", [] return "", []
@@ -933,7 +975,9 @@ class EmojiManager:
你可以关注其幽默和讽刺意味,动用贴吧,微博,小红书的知识,必须从互联网梗,meme的角度去分析 你可以关注其幽默和讽刺意味,动用贴吧,微博,小红书的知识,必须从互联网梗,meme的角度去分析
请直接输出描述,不要出现任何其他内容,如果有多个描述,可以用逗号分隔 请直接输出描述,不要出现任何其他内容,如果有多个描述,可以用逗号分隔
""" """
emotions_text, _ = await self.llm_emotion_judge.generate_response_async(emotion_prompt, temperature=0.7) emotions_text, _ = await self.llm_emotion_judge.generate_response_async(
emotion_prompt, temperature=0.7, max_tokens=600
)
# 处理情感列表 # 处理情感列表
emotions = [e.strip() for e in emotions_text.split(",") if e.strip()] emotions = [e.strip() for e in emotions_text.split(",") if e.strip()]

View File

@@ -7,12 +7,12 @@ from datetime import datetime
from typing import List, Dict, Optional, Any, Tuple from typing import List, Dict, Optional, Any, Tuple
from src.common.logger import get_logger from src.common.logger import get_logger
from src.common.database.database_model import Expression
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config from src.config.config import model_config, global_config
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_random, build_anonymous_messages from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat_inclusive, build_anonymous_messages
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.message_receive.chat_stream import get_chat_manager
from src.common.database.database_model import Expression
MAX_EXPRESSION_COUNT = 300 MAX_EXPRESSION_COUNT = 300
@@ -38,10 +38,9 @@ def init_prompt() -> None:
请从上面这段群聊中概括除了人名为"SELF"之外的人的语言风格 请从上面这段群聊中概括除了人名为"SELF"之外的人的语言风格
1. 只考虑文字,不要考虑表情包和图片 1. 只考虑文字,不要考虑表情包和图片
2. 不要涉及具体的人名,只考虑语言风格 2. 不要涉及具体的人名,但是可以涉及具体名词
3. 语言风格包含特殊内容和情感 3. 思考有没有特殊的梗,一并总结成语言风格
4. 思考有没有特殊的梗,一并总结成语言风格 4. 例子仅供参考,请严格根据群聊内容总结!!!
5. 例子仅供参考,请严格根据群聊内容总结!!!
注意:总结成如下格式的规律,总结的内容要详细,但具有概括性: 注意:总结成如下格式的规律,总结的内容要详细,但具有概括性:
例如:当"AAAAA"时,可以"BBBBB", AAAAA代表某个具体的场景不超过20个字。BBBBB代表对应的语言风格特定句式或表达方式不超过20个字。 例如:当"AAAAA"时,可以"BBBBB", AAAAA代表某个具体的场景不超过20个字。BBBBB代表对应的语言风格特定句式或表达方式不超过20个字。
@@ -51,302 +50,150 @@ def init_prompt() -> None:
"想说明某个具体的事实观点,但懒得明说,或者不便明说,或表达一种默契",使用"懂的都懂" "想说明某个具体的事实观点,但懒得明说,或者不便明说,或表达一种默契",使用"懂的都懂"
"当涉及游戏相关时,表示意外的夸赞,略带戏谑意味"时,使用"这么强!" "当涉及游戏相关时,表示意外的夸赞,略带戏谑意味"时,使用"这么强!"
请注意不要总结你自己SELF的发言 请注意不要总结你自己SELF的发言,尽量保证总结内容的逻辑性
现在请你概括 现在请你概括
""" """
Prompt(learn_style_prompt, "learn_style_prompt") Prompt(learn_style_prompt, "learn_style_prompt")
learn_grammar_prompt = """
{chat_str}
请从上面这段群聊中概括除了人名为"SELF"之外的人的语法和句法特点,只考虑纯文字,不要考虑表情包和图片
1.不要总结【图片】,【动画表情】,[图片][动画表情],不总结 表情符号 at @ 回复 和[回复]
2.不要涉及具体的人名,只考虑语法和句法特点,
3.语法和句法特点要包括,句子长短(具体字数),有何种语病,如何拆分句子。
4. 例子仅供参考,请严格根据群聊内容总结!!!
总结成如下格式的规律,总结的内容要简洁,不浮夸:
"xxx"时,可以"xxx"
例如:
"表达观点较复杂"时,使用"省略主语(3-6个字)"的句法
"不用详细说明的一般表达"时,使用"非常简洁的句子"的句法
"需要单纯简单的确认"时,使用"单字或几个字的肯定(1-2个字)"的句法
注意不要总结你自己SELF的发言
现在请你概括
"""
Prompt(learn_grammar_prompt, "learn_grammar_prompt")
class ExpressionLearner: class ExpressionLearner:
def __init__(self) -> None: def __init__(self, chat_id: str) -> None:
# TODO: API-Adapter修改标记
self.express_learn_model: LLMRequest = LLMRequest( self.express_learn_model: LLMRequest = LLMRequest(
model=global_config.model.replyer_1, model_set=model_config.model_task_config.replyer, request_type="expression.learner"
temperature=0.3,
request_type="expressor.learner",
) )
self.llm_model = None self.chat_id = chat_id
self._ensure_expression_directories() self.chat_name = get_chat_manager().get_stream_name(chat_id) or chat_id
self._auto_migrate_json_to_db()
self._migrate_old_data_create_date()
def _ensure_expression_directories(self):
"""
确保表达方式相关的目录结构存在
"""
base_dir = os.path.join("data", "expression")
directories_to_create = [
base_dir,
os.path.join(base_dir, "learnt_style"),
os.path.join(base_dir, "learnt_grammar"),
]
for directory in directories_to_create: # 维护每个chat的上次学习时间
try: self.last_learning_time: float = time.time()
os.makedirs(directory, exist_ok=True)
logger.debug(f"确保目录存在: {directory}") # 学习参数
except Exception as e: self.min_messages_for_learning = 25 # 触发学习所需的最少消息数
logger.error(f"创建目录失败 {directory}: {e}") self.min_learning_interval = 300 # 最短学习时间间隔(秒)
def _auto_migrate_json_to_db(self):
def can_learn_for_chat(self) -> bool:
""" """
自动将/data/expression/learnt_style 和 learnt_grammar 下所有expressions.json迁移到数据库。 检查指定聊天流是否允许学习表达
迁移完成后在/data/expression/done.done写入标记文件存在则跳过。
"""
base_dir = os.path.join("data", "expression")
done_flag = os.path.join(base_dir, "done.done")
# 确保基础目录存在 Args:
try: chat_id: 聊天流ID
os.makedirs(base_dir, exist_ok=True)
logger.debug(f"确保目录存在: {base_dir}")
except Exception as e:
logger.error(f"创建表达方式目录失败: {e}")
return
if os.path.exists(done_flag):
logger.info("表达方式JSON已迁移无需重复迁移。")
return
logger.info("开始迁移表达方式JSON到数据库...") Returns:
migrated_count = 0 bool: 是否允许学习
for type in ["learnt_style", "learnt_grammar"]:
type_str = "style" if type == "learnt_style" else "grammar"
type_dir = os.path.join(base_dir, type)
if not os.path.exists(type_dir):
logger.debug(f"目录不存在,跳过: {type_dir}")
continue
try:
chat_ids = os.listdir(type_dir)
logger.debug(f"{type_dir} 中找到 {len(chat_ids)} 个聊天ID目录")
except Exception as e:
logger.error(f"读取目录失败 {type_dir}: {e}")
continue
for chat_id in chat_ids:
expr_file = os.path.join(type_dir, chat_id, "expressions.json")
if not os.path.exists(expr_file):
continue
try:
with open(expr_file, "r", encoding="utf-8") as f:
expressions = json.load(f)
if not isinstance(expressions, list):
logger.warning(f"表达方式文件格式错误,跳过: {expr_file}")
continue
for expr in expressions:
if not isinstance(expr, dict):
continue
situation = expr.get("situation")
style_val = expr.get("style")
count = expr.get("count", 1)
last_active_time = expr.get("last_active_time", time.time())
if not situation or not style_val:
logger.warning(f"表达方式缺少必要字段,跳过: {expr}")
continue
# 查重同chat_id+type+situation+style
from src.common.database.database_model import Expression
query = Expression.select().where(
(Expression.chat_id == chat_id)
& (Expression.type == type_str)
& (Expression.situation == situation)
& (Expression.style == style_val)
)
if query.exists():
expr_obj = query.get()
expr_obj.count = max(expr_obj.count, count)
expr_obj.last_active_time = max(expr_obj.last_active_time, last_active_time)
expr_obj.save()
else:
Expression.create(
situation=situation,
style=style_val,
count=count,
last_active_time=last_active_time,
chat_id=chat_id,
type=type_str,
create_date=last_active_time, # 迁移时使用last_active_time作为创建时间
)
migrated_count += 1
logger.info(f"已迁移 {expr_file} 到数据库,包含 {len(expressions)} 个表达方式")
except json.JSONDecodeError as e:
logger.error(f"JSON解析失败 {expr_file}: {e}")
except Exception as e:
logger.error(f"迁移表达方式 {expr_file} 失败: {e}")
# 标记迁移完成
try:
# 确保done.done文件的父目录存在
done_parent_dir = os.path.dirname(done_flag)
if not os.path.exists(done_parent_dir):
os.makedirs(done_parent_dir, exist_ok=True)
logger.debug(f"为done.done创建父目录: {done_parent_dir}")
with open(done_flag, "w", encoding="utf-8") as f:
f.write("done\n")
logger.info(f"表达方式JSON迁移已完成共迁移 {migrated_count} 个表达方式已写入done.done标记文件")
except PermissionError as e:
logger.error(f"权限不足无法写入done.done标记文件: {e}")
except OSError as e:
logger.error(f"文件系统错误无法写入done.done标记文件: {e}")
except Exception as e:
logger.error(f"写入done.done标记文件失败: {e}")
def _migrate_old_data_create_date(self):
"""
为没有create_date的老数据设置创建日期
使用last_active_time作为create_date的默认值
""" """
try: try:
# 查找所有create_date为空的表达方式 use_expression, enable_learning, _ = global_config.expression.get_expression_config_for_chat(self.chat_id)
old_expressions = Expression.select().where(Expression.create_date.is_null()) return enable_learning
updated_count = 0
for expr in old_expressions:
# 使用last_active_time作为create_date
expr.create_date = expr.last_active_time
expr.save()
updated_count += 1
if updated_count > 0:
logger.info(f"已为 {updated_count} 个老的表达方式设置创建日期")
except Exception as e: except Exception as e:
logger.error(f"迁移老数据创建日期失败: {e}") logger.error(f"检查学习权限失败: {e}")
def get_expression_by_chat_id(self, chat_id: str) -> Tuple[List[Dict[str, float]], List[Dict[str, float]]]:
    """
    Load the stored "style" and "grammar" expressions of one chat from the database.

    Args:
        chat_id: chat stream id used to filter ``Expression`` rows.

    Returns:
        A ``(style_expressions, grammar_expressions)`` pair of dict lists. Every dict
        carries ``source_id`` (set to ``chat_id``); per the original note this is
        what later update operations key on.
    """

    def _load(expr_type: str) -> List[Dict[str, Any]]:
        # One query per expression type, filtered by chat and type.
        rows = Expression.select().where((Expression.chat_id == chat_id) & (Expression.type == expr_type))
        return [
            {
                "situation": row.situation,
                "style": row.style,
                "count": row.count,
                "last_active_time": row.last_active_time,
                "source_id": chat_id,
                "type": expr_type,
                # Rows without a create_date fall back to their last_active_time.
                "create_date": row.last_active_time if row.create_date is None else row.create_date,
            }
            for row in rows
        ]

    # Style rows are read first, grammar rows second, as in the original.
    learnt_style_expressions = _load("style")
    learnt_grammar_expressions = _load("grammar")
    return learnt_style_expressions, learnt_grammar_expressions
def get_expression_create_info(self, chat_id: str, limit: int = 10) -> List[Dict[str, Any]]:
    """
    Return creation info for a chat's expressions, newest ``create_date`` first.

    Args:
        chat_id: chat stream id used to filter ``Expression`` rows.
        limit: maximum number of rows to return.

    Returns:
        One dict per row with the raw and formatted creation / last-active times.
        An empty list is returned when anything in the lookup raises.
    """

    def _describe(row) -> Dict[str, Any]:
        # Rows without a create_date fall back to their last_active_time.
        created_at = row.last_active_time if row.create_date is None else row.create_date
        return {
            "situation": row.situation,
            "style": row.style,
            "type": row.type,
            "count": row.count,
            "create_date": created_at,
            "create_date_formatted": format_create_date(created_at),
            "last_active_time": row.last_active_time,
            "last_active_formatted": format_create_date(row.last_active_time),
        }

    try:
        newest_first = Expression.select().where(Expression.chat_id == chat_id)
        newest_first = newest_first.order_by(Expression.create_date.desc()).limit(limit)
        return [_describe(row) for row in newest_first]
    except Exception as e:
        logger.error(f"获取表达方式创建信息失败: {e}")
        return []
def is_similar(self, s1: str, s2: str) -> bool:
"""
判断两个字符串是否相似只考虑长度大于5且有80%以上重合,不考虑子串)
"""
if not s1 or not s2:
return False return False
min_len = min(len(s1), len(s2))
if min_len < 5:
return False
same = sum(a == b for a, b in zip(s1, s2, strict=False))
return same / min_len > 0.8
async def learn_and_store_expression(self) -> Tuple[List[Tuple[str, str, str]], List[Tuple[str, str, str]]]: def should_trigger_learning(self) -> bool:
""" """
学习并存储表达方式,分别学习语言风格和句法特点 检查是否应该触发学习
同时对所有已存储的表达方式进行全局衰减
Args:
chat_id: 聊天流ID
Returns:
bool: 是否应该触发学习
""" """
current_time = time.time() current_time = time.time()
# 全局衰减所有已存储的表达方式(直接操作数据库) # 获取该聊天流的学习强度
self._apply_global_decay_to_database(current_time) try:
_, enable_learning, learning_intensity = global_config.expression.get_expression_config_for_chat(self.chat_id)
except Exception as e:
logger.error(f"获取聊天流 {self.chat_id} 的学习配置失败: {e}")
return False
# 检查是否允许学习
if not enable_learning:
return False
# 根据学习强度计算最短学习时间间隔
min_interval = self.min_learning_interval / learning_intensity
# 检查时间间隔
time_diff = current_time - self.last_learning_time
if time_diff < min_interval:
return False
# 检查消息数量(只检查指定聊天流的消息)
recent_messages = get_raw_msg_by_timestamp_with_chat_inclusive(
chat_id=self.chat_id,
timestamp_start=self.last_learning_time,
timestamp_end=time.time(),
)
learnt_style: Optional[List[Tuple[str, str, str]]] = [] if not recent_messages or len(recent_messages) < self.min_messages_for_learning:
learnt_grammar: Optional[List[Tuple[str, str, str]]] = [] return False
# 学习新的表达方式(这里会进行局部衰减)
for _ in range(3): return True
learnt_style = await self.learn_and_store(type="style", num=25)
if not learnt_style: async def trigger_learning_for_chat(self) -> bool:
return [], [] """
为指定聊天流触发学习
Args:
chat_id: 聊天流ID
Returns:
bool: 是否成功触发学习
"""
if not self.should_trigger_learning():
return False
try:
logger.info(f"为聊天流 {self.chat_name} 触发表达学习")
# 学习语言风格
learnt_style = await self.learn_and_store(num=25)
# 更新学习时间
self.last_learning_time = time.time()
if learnt_style:
logger.info(f"聊天流 {self.chat_name} 表达学习完成")
return True
else:
logger.warning(f"聊天流 {self.chat_name} 表达学习未获得有效结果")
return False
except Exception as e:
logger.error(f"为聊天流 {self.chat_name} 触发学习失败: {e}")
return False
# def get_expression_by_chat_id(self) -> Tuple[List[Dict[str, float]], List[Dict[str, float]]]:
# """
# 获取指定chat_id的style表达方式已禁用grammar的获取
# 返回的每个表达方式字典中都包含了source_id, 用于后续的更新操作
# """
# learnt_style_expressions = []
# # 直接从数据库查询
# style_query = Expression.select().where((Expression.chat_id == self.chat_id) & (Expression.type == "style"))
# for expr in style_query:
# # 确保create_date存在如果不存在则使用last_active_time
# create_date = expr.create_date if expr.create_date is not None else expr.last_active_time
# learnt_style_expressions.append(
# {
# "situation": expr.situation,
# "style": expr.style,
# "count": expr.count,
# "last_active_time": expr.last_active_time,
# "source_id": self.chat_id,
# "type": "style",
# "create_date": create_date,
# }
# )
# return learnt_style_expressions
for _ in range(1):
learnt_grammar = await self.learn_and_store(type="grammar", num=10)
if not learnt_grammar:
return [], []
return learnt_style, learnt_grammar
def _apply_global_decay_to_database(self, current_time: float) -> None: def _apply_global_decay_to_database(self, current_time: float) -> None:
""" """
@@ -355,19 +202,19 @@ class ExpressionLearner:
try: try:
# 获取所有表达方式 # 获取所有表达方式
all_expressions = Expression.select() all_expressions = Expression.select()
updated_count = 0 updated_count = 0
deleted_count = 0 deleted_count = 0
for expr in all_expressions: for expr in all_expressions:
# 计算时间差 # 计算时间差
last_active = expr.last_active_time last_active = expr.last_active_time
time_diff_days = (current_time - last_active) / (24 * 3600) # 转换为天 time_diff_days = (current_time - last_active) / (24 * 3600) # 转换为天
# 计算衰减值 # 计算衰减值
decay_value = self.calculate_decay_factor(time_diff_days) decay_value = self.calculate_decay_factor(time_diff_days)
new_count = max(0.01, expr.count - decay_value) new_count = max(0.01, expr.count - decay_value)
if new_count <= 0.01: if new_count <= 0.01:
# 如果count太小删除这个表达方式 # 如果count太小删除这个表达方式
expr.delete_instance() expr.delete_instance()
@@ -377,10 +224,10 @@ class ExpressionLearner:
expr.count = new_count expr.count = new_count
expr.save() expr.save()
updated_count += 1 updated_count += 1
if updated_count > 0 or deleted_count > 0: if updated_count > 0 or deleted_count > 0:
logger.info(f"全局衰减完成:更新了 {updated_count} 个表达方式,删除了 {deleted_count} 个表达方式") logger.info(f"全局衰减完成:更新了 {updated_count} 个表达方式,删除了 {deleted_count} 个表达方式")
except Exception as e: except Exception as e:
logger.error(f"数据库全局衰减失败: {e}") logger.error(f"数据库全局衰减失败: {e}")
@@ -406,20 +253,16 @@ class ExpressionLearner:
return min(0.01, decay) return min(0.01, decay)
async def learn_and_store(self, type: str, num: int = 10) -> List[Tuple[str, str, str]]: async def learn_and_store(self, num: int = 10) -> List[Tuple[str, str, str]]:
# sourcery skip: use-join
""" """
选择从当前到最近1小时内的随机num条消息然后学习这些消息的表达方式 学习并存储表达方式
type: "style" or "grammar"
""" """
if type == "style": # 检查是否允许在此聊天流中学习(在函数最前面检查)
type_str = "语言风格" if not self.can_learn_for_chat():
elif type == "grammar": logger.debug(f"聊天流 {self.chat_name} 不允许学习表达,跳过学习")
type_str = "句法特点" return []
else:
raise ValueError(f"Invalid type: {type}")
res = await self.learn_expression(type, num) res = await self.learn_expression(num)
if res is None: if res is None:
return [] return []
@@ -435,10 +278,10 @@ class ExpressionLearner:
learnt_expressions_str = "" learnt_expressions_str = ""
for _chat_id, situation, style in learnt_expressions: for _chat_id, situation, style in learnt_expressions:
learnt_expressions_str += f"{situation}->{style}\n" learnt_expressions_str += f"{situation}->{style}\n"
logger.info(f"{group_name} 学习到{type_str}:\n{learnt_expressions_str}") logger.info(f"{group_name} 学习到表达风格:\n{learnt_expressions_str}")
if not learnt_expressions: if not learnt_expressions:
logger.info(f"没有学习到{type_str}") logger.info("没有学习到表达风格")
return [] return []
# 按chat_id分组 # 按chat_id分组
@@ -456,7 +299,7 @@ class ExpressionLearner:
# 查找是否已存在相似表达方式 # 查找是否已存在相似表达方式
query = Expression.select().where( query = Expression.select().where(
(Expression.chat_id == chat_id) (Expression.chat_id == chat_id)
& (Expression.type == type) & (Expression.type == "style")
& (Expression.situation == new_expr["situation"]) & (Expression.situation == new_expr["situation"])
& (Expression.style == new_expr["style"]) & (Expression.style == new_expr["style"])
) )
@@ -476,13 +319,13 @@ class ExpressionLearner:
count=1, count=1,
last_active_time=current_time, last_active_time=current_time,
chat_id=chat_id, chat_id=chat_id,
type=type, type="style",
create_date=current_time, # 手动设置创建日期 create_date=current_time, # 手动设置创建日期
) )
# 限制最大数量 # 限制最大数量
exprs = list( exprs = list(
Expression.select() Expression.select()
.where((Expression.chat_id == chat_id) & (Expression.type == type)) .where((Expression.chat_id == chat_id) & (Expression.type == "style"))
.order_by(Expression.count.asc()) .order_by(Expression.count.asc())
) )
if len(exprs) > MAX_EXPRESSION_COUNT: if len(exprs) > MAX_EXPRESSION_COUNT:
@@ -491,25 +334,25 @@ class ExpressionLearner:
expr.delete_instance() expr.delete_instance()
return learnt_expressions return learnt_expressions
async def learn_expression(self, type: str, num: int = 10) -> Optional[Tuple[List[Tuple[str, str, str]], str]]: async def learn_expression(self, num: int = 10) -> Optional[Tuple[List[Tuple[str, str, str]], str]]:
"""选择从当前到最近1小时内的随机num条消息然后学习这些消息的表达方式 """从指定聊天流学习表达方式
Args: Args:
type: "style" or "grammar" num: 学习数量
""" """
if type == "style": type_str = "语言风格"
type_str = "语言风格" prompt = "learn_style_prompt"
prompt = "learn_style_prompt"
elif type == "grammar":
type_str = "句法特点"
prompt = "learn_grammar_prompt"
else:
raise ValueError(f"Invalid type: {type}")
current_time = time.time() current_time = time.time()
random_msg: Optional[List[Dict[str, Any]]] = get_raw_msg_by_timestamp_random(
current_time - 3600 * 24, current_time, limit=num # 获取上次学习时间
random_msg: Optional[List[Dict[str, Any]]] = get_raw_msg_by_timestamp_with_chat_inclusive(
chat_id=self.chat_id,
timestamp_start=self.last_learning_time,
timestamp_end=current_time,
limit=num,
) )
# print(random_msg) # print(random_msg)
if not random_msg or random_msg == []: if not random_msg or random_msg == []:
return None return None
@@ -527,7 +370,7 @@ class ExpressionLearner:
logger.debug(f"学习{type_str}的prompt: {prompt}") logger.debug(f"学习{type_str}的prompt: {prompt}")
try: try:
response, _ = await self.express_learn_model.generate_response_async(prompt) response, _ = await self.express_learn_model.generate_response_async(prompt, temperature=0.3)
except Exception as e: except Exception as e:
logger.error(f"学习{type_str}失败: {e}") logger.error(f"学习{type_str}失败: {e}")
return None return None
@@ -571,12 +414,221 @@ class ExpressionLearner:
init_prompt() init_prompt()
class ExpressionLearnerManager:
def __init__(self):
    """Create the per-chat learner cache and run the startup housekeeping steps."""
    # chat_id -> ExpressionLearner; filled lazily by get_expression_learner().
    self.expression_learners: Dict[str, "ExpressionLearner"] = {}

    # Housekeeping runs in this order: directories first, then the JSON -> DB
    # migration, then the create_date backfill for older rows.
    self._ensure_expression_directories()
    self._auto_migrate_json_to_db()
    self._migrate_old_data_create_date()
def get_expression_learner(self, chat_id: str) -> ExpressionLearner:
    """
    Return the learner bound to ``chat_id``, creating and caching it on first use.

    Args:
        chat_id: chat stream id used as the cache key.

    Returns:
        The cached ``ExpressionLearner`` for that chat.
    """
    cache = self.expression_learners
    if chat_id in cache:
        return cache[chat_id]

    # First request for this chat: build the learner and remember it.
    learner = ExpressionLearner(chat_id)
    cache[chat_id] = learner
    return learner
def _ensure_expression_directories(self):
"""
确保表达方式相关的目录结构存在
"""
base_dir = os.path.join("data", "expression")
directories_to_create = [
base_dir,
os.path.join(base_dir, "learnt_style"),
os.path.join(base_dir, "learnt_grammar"),
]
expression_learner = None for directory in directories_to_create:
try:
os.makedirs(directory, exist_ok=True)
logger.debug(f"确保目录存在: {directory}")
except Exception as e:
logger.error(f"创建目录失败 {directory}: {e}")
def get_expression_learner(): def _auto_migrate_json_to_db(self):
global expression_learner """
if expression_learner is None: 自动将/data/expression/learnt_style 和 learnt_grammar 下所有expressions.json迁移到数据库。
expression_learner = ExpressionLearner() 迁移完成后在/data/expression/done.done写入标记文件存在则跳过。
return expression_learner 然后检查done.done2如果没有就删除所有grammar表达并创建该标记文件。
"""
base_dir = os.path.join("data", "expression")
done_flag = os.path.join(base_dir, "done.done")
done_flag2 = os.path.join(base_dir, "done.done2")
# 确保基础目录存在
try:
os.makedirs(base_dir, exist_ok=True)
logger.debug(f"确保目录存在: {base_dir}")
except Exception as e:
logger.error(f"创建表达方式目录失败: {e}")
return
if os.path.exists(done_flag):
logger.info("表达方式JSON已迁移无需重复迁移。")
else:
logger.info("开始迁移表达方式JSON到数据库...")
migrated_count = 0
for type in ["learnt_style", "learnt_grammar"]:
type_str = "style" if type == "learnt_style" else "grammar"
type_dir = os.path.join(base_dir, type)
if not os.path.exists(type_dir):
logger.debug(f"目录不存在,跳过: {type_dir}")
continue
try:
chat_ids = os.listdir(type_dir)
logger.debug(f"{type_dir} 中找到 {len(chat_ids)} 个聊天ID目录")
except Exception as e:
logger.error(f"读取目录失败 {type_dir}: {e}")
continue
for chat_id in chat_ids:
expr_file = os.path.join(type_dir, chat_id, "expressions.json")
if not os.path.exists(expr_file):
continue
try:
with open(expr_file, "r", encoding="utf-8") as f:
expressions = json.load(f)
if not isinstance(expressions, list):
logger.warning(f"表达方式文件格式错误,跳过: {expr_file}")
continue
for expr in expressions:
if not isinstance(expr, dict):
continue
situation = expr.get("situation")
style_val = expr.get("style")
count = expr.get("count", 1)
last_active_time = expr.get("last_active_time", time.time())
if not situation or not style_val:
logger.warning(f"表达方式缺少必要字段,跳过: {expr}")
continue
# 查重同chat_id+type+situation+style
from src.common.database.database_model import Expression
query = Expression.select().where(
(Expression.chat_id == chat_id)
& (Expression.type == type_str)
& (Expression.situation == situation)
& (Expression.style == style_val)
)
if query.exists():
expr_obj = query.get()
expr_obj.count = max(expr_obj.count, count)
expr_obj.last_active_time = max(expr_obj.last_active_time, last_active_time)
expr_obj.save()
else:
Expression.create(
situation=situation,
style=style_val,
count=count,
last_active_time=last_active_time,
chat_id=chat_id,
type=type_str,
create_date=last_active_time, # 迁移时使用last_active_time作为创建时间
)
migrated_count += 1
logger.info(f"已迁移 {expr_file} 到数据库,包含 {len(expressions)} 个表达方式")
except json.JSONDecodeError as e:
logger.error(f"JSON解析失败 {expr_file}: {e}")
except Exception as e:
logger.error(f"迁移表达方式 {expr_file} 失败: {e}")
# 标记迁移完成
try:
# 确保done.done文件的父目录存在
done_parent_dir = os.path.dirname(done_flag)
if not os.path.exists(done_parent_dir):
os.makedirs(done_parent_dir, exist_ok=True)
logger.debug(f"为done.done创建父目录: {done_parent_dir}")
with open(done_flag, "w", encoding="utf-8") as f:
f.write("done\n")
logger.info(f"表达方式JSON迁移已完成共迁移 {migrated_count} 个表达方式已写入done.done标记文件")
except PermissionError as e:
logger.error(f"权限不足无法写入done.done标记文件: {e}")
except OSError as e:
logger.error(f"文件系统错误无法写入done.done标记文件: {e}")
except Exception as e:
logger.error(f"写入done.done标记文件失败: {e}")
# 检查并处理grammar表达删除
if not os.path.exists(done_flag2):
logger.info("开始删除所有grammar类型的表达...")
try:
deleted_count = self.delete_all_grammar_expressions()
logger.info(f"grammar表达删除完成共删除 {deleted_count} 个表达")
# 创建done.done2标记文件
with open(done_flag2, "w", encoding="utf-8") as f:
f.write("done\n")
logger.info("已创建done.done2标记文件grammar表达删除标记完成")
except Exception as e:
logger.error(f"删除grammar表达或创建标记文件失败: {e}")
else:
logger.info("grammar表达已删除跳过重复删除")
def _migrate_old_data_create_date(self):
    """
    Backfill ``create_date`` on rows that were stored without one.

    Each affected row gets its own ``last_active_time`` as ``create_date``.
    Any exception is logged and swallowed.
    """
    try:
        # Only rows whose create_date is NULL need the backfill.
        rows_missing_date = Expression.select().where(Expression.create_date.is_null())

        updated_count = 0
        for updated_count, row in enumerate(rows_missing_date, start=1):
            row.create_date = row.last_active_time
            row.save()

        # Stay quiet when nothing had to be touched.
        if updated_count:
            logger.info(f"已为 {updated_count} 个老的表达方式设置创建日期")
    except Exception as e:
        logger.error(f"迁移老数据创建日期失败: {e}")
def delete_all_grammar_expressions(self) -> int:
    """
    Delete every expression row whose type is "grammar".

    Rows are removed one by one, so a failure on a single row is logged and the
    loop moves on to the next row.

    Returns:
        int: number of rows actually deleted; 0 when there were none or when the
        query itself failed.
    """
    try:
        grammar_rows = Expression.select().where(Expression.type == "grammar")
        grammar_count = grammar_rows.count()

        if not grammar_count:
            logger.info("expression库中没有找到grammar类型的表达")
            return 0

        logger.info(f"找到 {grammar_count} 个grammar类型的表达开始删除...")

        deleted_count = 0
        for row in grammar_rows:
            try:
                row.delete_instance()
            except Exception as e:
                # Per-row failure: report it and keep going.
                logger.error(f"删除grammar表达失败: {e}")
            else:
                deleted_count += 1

        logger.info(f"成功删除 {deleted_count} 个grammar类型的表达")
        return deleted_count
    except Exception as e:
        logger.error(f"删除grammar表达过程中发生错误: {e}")
        return 0
expression_learner_manager = ExpressionLearnerManager()

View File

@@ -1,16 +1,16 @@
import json import json
import time import time
import random import random
import hashlib
from typing import List, Dict, Tuple, Optional, Any from typing import List, Dict, Optional, Any, Tuple
from json_repair import repair_json from json_repair import repair_json
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config from src.config.config import global_config, model_config
from src.common.logger import get_logger from src.common.logger import get_logger
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from .expression_learner import get_expression_learner
from src.common.database.database_model import Expression from src.common.database.database_model import Expression
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
logger = get_logger("expression_selector") logger = get_logger("expression_selector")
@@ -25,7 +25,7 @@ def init_prompt():
以下是可选的表达情境: 以下是可选的表达情境:
{all_situations} {all_situations}
请你分析聊天内容的语境、情绪、话题类型,从上述情境中选择最适合当前聊天情境的{min_num}-{max_num}个情境。 请你分析聊天内容的语境、情绪、话题类型,从上述情境中选择最适合当前聊天情境的,最多{max_num}个情境。
考虑因素包括: 考虑因素包括:
1. 聊天的情绪氛围(轻松、严肃、幽默等) 1. 聊天的情绪氛围(轻松、严肃、幽默等)
2. 话题类型(日常、技术、游戏、情感等) 2. 话题类型(日常、技术、游戏、情感等)
@@ -35,11 +35,7 @@ def init_prompt():
请以JSON格式输出只需要输出选中的情境编号 请以JSON格式输出只需要输出选中的情境编号
例如: 例如:
{{ {{
"selected_situations": [2, 3, 5, 7, 19, 22, 25, 38, 39, 45, 48 , 64] "selected_situations": [2, 3, 5, 7, 19]
}}
例如:
{{
"selected_situations": [1, 4, 7, 9, 23, 38, 44]
}} }}
请严格按照JSON格式输出不要包含其他内容 请严格按照JSON格式输出不要包含其他内容
@@ -74,13 +70,27 @@ def weighted_sample(population: List[Dict], weights: List[float], k: int) -> Lis
class ExpressionSelector: class ExpressionSelector:
def __init__(self): def __init__(self):
self.expression_learner = get_expression_learner()
# TODO: API-Adapter修改标记
self.llm_model = LLMRequest( self.llm_model = LLMRequest(
model=global_config.model.utils_small, model_set=model_config.model_task_config.utils_small, request_type="expression.selector"
request_type="expression.selector",
) )
def can_use_expression_for_chat(self, chat_id: str) -> bool:
"""
检查指定聊天流是否允许使用表达
Args:
chat_id: 聊天流ID
Returns:
bool: 是否允许使用表达
"""
try:
use_expression, _, _ = global_config.expression.get_expression_config_for_chat(chat_id)
return use_expression
except Exception as e:
logger.error(f"检查表达使用权限失败: {e}")
return False
@staticmethod @staticmethod
def _parse_stream_config_to_chat_id(stream_config_str: str) -> Optional[str]: def _parse_stream_config_to_chat_id(stream_config_str: str) -> Optional[str]:
"""解析'platform:id:type'为chat_id与get_stream_id一致""" """解析'platform:id:type'为chat_id与get_stream_id一致"""
@@ -92,7 +102,6 @@ class ExpressionSelector:
id_str = parts[1] id_str = parts[1]
stream_type = parts[2] stream_type = parts[2]
is_group = stream_type == "group" is_group = stream_type == "group"
import hashlib
if is_group: if is_group:
components = [platform, str(id_str)] components = [platform, str(id_str)]
else: else:
@@ -108,29 +117,27 @@ class ExpressionSelector:
for group in groups: for group in groups:
group_chat_ids = [] group_chat_ids = []
for stream_config_str in group: for stream_config_str in group:
chat_id_candidate = self._parse_stream_config_to_chat_id(stream_config_str) if chat_id_candidate := self._parse_stream_config_to_chat_id(stream_config_str):
if chat_id_candidate:
group_chat_ids.append(chat_id_candidate) group_chat_ids.append(chat_id_candidate)
if chat_id in group_chat_ids: if chat_id in group_chat_ids:
return group_chat_ids return group_chat_ids
return [chat_id] return [chat_id]
def get_random_expressions( def get_random_expressions(
self, chat_id: str, total_num: int, style_percentage: float, grammar_percentage: float self, chat_id: str, total_num: int
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]: ) -> List[Dict[str, Any]]:
# sourcery skip: extract-duplicate-method, move-assign
# 支持多chat_id合并抽选 # 支持多chat_id合并抽选
related_chat_ids = self.get_related_chat_ids(chat_id) related_chat_ids = self.get_related_chat_ids(chat_id)
# 优化一次性查询所有相关chat_id的表达方式 # 优化一次性查询所有相关chat_id的表达方式
style_query = Expression.select().where( style_query = Expression.select().where(
(Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "style") (Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "style")
) )
grammar_query = Expression.select().where(
(Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "grammar")
)
style_exprs = [ style_exprs = [
{ {
"id": expr.id,
"situation": expr.situation, "situation": expr.situation,
"style": expr.style, "style": expr.style,
"count": expr.count, "count": expr.count,
@@ -138,35 +145,17 @@ class ExpressionSelector:
"source_id": expr.chat_id, "source_id": expr.chat_id,
"type": "style", "type": "style",
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time, "create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
} for expr in style_query }
for expr in style_query
] ]
grammar_exprs = [
{
"situation": expr.situation,
"style": expr.style,
"count": expr.count,
"last_active_time": expr.last_active_time,
"source_id": expr.chat_id,
"type": "grammar",
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
} for expr in grammar_query
]
style_num = int(total_num * style_percentage)
grammar_num = int(total_num * grammar_percentage)
# 按权重抽样使用count作为权重 # 按权重抽样使用count作为权重
if style_exprs: if style_exprs:
style_weights = [expr.get("count", 1) for expr in style_exprs] style_weights = [expr.get("count", 1) for expr in style_exprs]
selected_style = weighted_sample(style_exprs, style_weights, style_num) selected_style = weighted_sample(style_exprs, style_weights, total_num)
else: else:
selected_style = [] selected_style = []
if grammar_exprs: return selected_style
grammar_weights = [expr.get("count", 1) for expr in grammar_exprs]
selected_grammar = weighted_sample(grammar_exprs, grammar_weights, grammar_num)
else:
selected_grammar = []
return selected_style, selected_grammar
def update_expressions_count_batch(self, expressions_to_update: List[Dict[str, Any]], increment: float = 0.1): def update_expressions_count_batch(self, expressions_to_update: List[Dict[str, Any]], increment: float = 0.1):
"""对一批表达方式更新count值按chat_id+type分组后一次性写入数据库""" """对一批表达方式更新count值按chat_id+type分组后一次性写入数据库"""
@@ -174,22 +163,22 @@ class ExpressionSelector:
return return
updates_by_key = {} updates_by_key = {}
for expr in expressions_to_update: for expr in expressions_to_update:
source_id = expr.get("source_id") source_id: str = expr.get("source_id") # type: ignore
expr_type = expr.get("type", "style") expr_type: str = expr.get("type", "style")
situation = expr.get("situation") situation: str = expr.get("situation") # type: ignore
style = expr.get("style") style: str = expr.get("style") # type: ignore
if not source_id or not situation or not style: if not source_id or not situation or not style:
logger.warning(f"表达方式缺少必要字段,无法更新: {expr}") logger.warning(f"表达方式缺少必要字段,无法更新: {expr}")
continue continue
key = (source_id, expr_type, situation, style) key = (source_id, expr_type, situation, style)
if key not in updates_by_key: if key not in updates_by_key:
updates_by_key[key] = expr updates_by_key[key] = expr
for (chat_id, expr_type, situation, style), _expr in updates_by_key.items(): for chat_id, expr_type, situation, style in updates_by_key:
query = Expression.select().where( query = Expression.select().where(
(Expression.chat_id == chat_id) & (Expression.chat_id == chat_id)
(Expression.type == expr_type) & & (Expression.type == expr_type)
(Expression.situation == situation) & & (Expression.situation == situation)
(Expression.style == style) & (Expression.style == style)
) )
if query.exists(): if query.exists():
expr_obj = query.get() expr_obj = query.get()
@@ -207,38 +196,36 @@ class ExpressionSelector:
chat_id: str, chat_id: str,
chat_info: str, chat_info: str,
max_num: int = 10, max_num: int = 10,
min_num: int = 5,
target_message: Optional[str] = None, target_message: Optional[str] = None,
) -> List[Dict[str, Any]]: ) -> Tuple[List[Dict[str, Any]], List[int]]:
# sourcery skip: inline-variable, list-comprehension # sourcery skip: inline-variable, list-comprehension
"""使用LLM选择适合的表达方式""" """使用LLM选择适合的表达方式"""
# 检查是否允许在此聊天流中使用表达
if not self.can_use_expression_for_chat(chat_id):
logger.debug(f"聊天流 {chat_id} 不允许使用表达,返回空列表")
return [], []
# 1. 获取35个随机表达方式(现在按权重抽取) # 1. 获取20个随机表达方式(现在按权重抽取)
style_exprs, grammar_exprs = self.get_random_expressions(chat_id, 50, 0.5, 0.5) style_exprs = self.get_random_expressions(chat_id, 10)
if len(style_exprs) < 10:
logger.info(f"聊天流 {chat_id} 表达方式正在积累中")
return [], []
# 2. 构建所有表达方式的索引和情境列表 # 2. 构建所有表达方式的索引和情境列表
all_expressions = [] all_expressions: List[Dict[str, Any]] = []
all_situations = [] all_situations: List[str] = []
# 添加style表达方式 # 添加style表达方式
for expr in style_exprs: for expr in style_exprs:
if isinstance(expr, dict) and "situation" in expr and "style" in expr: expr = expr.copy()
expr_with_type = expr.copy() all_expressions.append(expr)
expr_with_type["type"] = "style" all_situations.append(f"{len(all_expressions)}.当 {expr['situation']} 时,使用 {expr['style']}")
all_expressions.append(expr_with_type)
all_situations.append(f"{len(all_expressions)}.{expr['situation']}")
# 添加grammar表达方式
for expr in grammar_exprs:
if isinstance(expr, dict) and "situation" in expr and "style" in expr:
expr_with_type = expr.copy()
expr_with_type["type"] = "grammar"
all_expressions.append(expr_with_type)
all_situations.append(f"{len(all_expressions)}.{expr['situation']}")
if not all_expressions: if not all_expressions:
logger.warning("没有找到可用的表达方式") logger.warning("没有找到可用的表达方式")
return [] return [], []
all_situations_str = "\n".join(all_situations) all_situations_str = "\n".join(all_situations)
@@ -254,23 +241,28 @@ class ExpressionSelector:
bot_name=global_config.bot.nickname, bot_name=global_config.bot.nickname,
chat_observe_info=chat_info, chat_observe_info=chat_info,
all_situations=all_situations_str, all_situations=all_situations_str,
min_num=min_num,
max_num=max_num, max_num=max_num,
target_message=target_message_str, target_message=target_message_str,
target_message_extra_block=target_message_extra_block, target_message_extra_block=target_message_extra_block,
) )
# print(prompt)
# 4. 调用LLM # 4. 调用LLM
try: try:
content, (_, _) = await self.llm_model.generate_response_async(prompt=prompt)
# start_time = time.time()
content, (reasoning_content, model_name, _) = await self.llm_model.generate_response_async(prompt=prompt)
# logger.info(f"LLM请求时间: {model_name} {time.time() - start_time} \n{prompt}")
# logger.info(f"{self.log_prefix} LLM返回结果: {content}") # logger.info(f"模型名称: {model_name}")
# logger.info(f"LLM返回结果: {content}")
# if reasoning_content:
# logger.info(f"LLM推理: {reasoning_content}")
# else:
# logger.info(f"LLM推理: 无")
if not content: if not content:
logger.warning("LLM返回空结果") logger.warning("LLM返回空结果")
return [] return [], []
# 5. 解析结果 # 5. 解析结果
result = repair_json(content) result = repair_json(content)
@@ -280,15 +272,17 @@ class ExpressionSelector:
if not isinstance(result, dict) or "selected_situations" not in result: if not isinstance(result, dict) or "selected_situations" not in result:
logger.error("LLM返回格式错误") logger.error("LLM返回格式错误")
logger.info(f"LLM返回结果: \n{content}") logger.info(f"LLM返回结果: \n{content}")
return [] return [], []
selected_indices = result["selected_situations"] selected_indices = result["selected_situations"]
# 根据索引获取完整的表达方式 # 根据索引获取完整的表达方式
valid_expressions = [] valid_expressions: List[Dict[str, Any]] = []
selected_ids = []
for idx in selected_indices: for idx in selected_indices:
if isinstance(idx, int) and 1 <= idx <= len(all_expressions): if isinstance(idx, int) and 1 <= idx <= len(all_expressions):
expression = all_expressions[idx - 1] # 索引从1开始 expression = all_expressions[idx - 1] # 索引从1开始
selected_ids.append(expression["id"])
valid_expressions.append(expression) valid_expressions.append(expression)
# 对选中的所有表达方式一次性更新count数 # 对选中的所有表达方式一次性更新count数
@@ -296,11 +290,12 @@ class ExpressionSelector:
self.update_expressions_count_batch(valid_expressions, 0.006) self.update_expressions_count_batch(valid_expressions, 0.006)
# logger.info(f"LLM从{len(all_expressions)}个情境中选择了{len(valid_expressions)}个") # logger.info(f"LLM从{len(all_expressions)}个情境中选择了{len(valid_expressions)}个")
return valid_expressions return valid_expressions, selected_ids
except Exception as e: except Exception as e:
logger.error(f"LLM处理表达方式选择时出错: {e}") logger.error(f"LLM处理表达方式选择时出错: {e}")
return [] return [], []
init_prompt() init_prompt()

View File

@@ -0,0 +1,143 @@
from typing import Optional
from src.config.config import global_config
from src.chat.frequency_control.utils import parse_stream_config_to_chat_id
class FocusValueControl:
    """Per-chat handle pairing a chat id with a multiplier for its focus value."""

    def __init__(self, chat_id: str):
        self.chat_id = chat_id
        # Multiplier applied on top of the configured value; 1 leaves it unchanged.
        self.focus_value_adjust = 1

    def get_current_focus_value(self) -> float:
        """Return this chat's configured focus value scaled by ``focus_value_adjust``."""
        # The bare name resolves to the module-level function, not to this method.
        configured_value = get_current_focus_value(self.chat_id)
        return configured_value * self.focus_value_adjust
class FocusValueControlManager:
    """Registry that hands out one ``FocusValueControl`` per chat id."""

    def __init__(self):
        # chat_id -> FocusValueControl, created on first request.
        self.focus_value_controls = {}

    def get_focus_value_control(self, chat_id: str) -> FocusValueControl:
        """Return the control for ``chat_id``, creating and caching it if missing."""
        registry = self.focus_value_controls
        if chat_id in registry:
            return registry[chat_id]

        created = FocusValueControl(chat_id)
        registry[chat_id] = created
        return created
def get_current_focus_value(chat_id: Optional[str] = None) -> float:
    """Resolve the focus value that applies right now for a chat.

    Lookup order:
        1. If ``focus_value_adjust`` is empty, the plain ``focus_value`` setting.
        2. The schedule configured for this specific chat stream (only when
           ``chat_id`` is truthy).
        3. The global schedule (the entry whose stream key is ``""``).
        4. The plain ``focus_value`` setting as the final fallback.

    Args:
        chat_id: Id of the chat stream, or None to skip the per-stream lookup.

    Returns:
        The focus value selected by the first rule above that yields a value.
    """
    if global_config.chat.focus_value_adjust:
        resolved = get_stream_specific_focus_value(chat_id) if chat_id else None
        if resolved is None:
            resolved = get_global_focus_value()
        if resolved is not None:
            return resolved
    return global_config.chat.focus_value
def get_stream_specific_focus_value(chat_id: str) -> Optional[float]:
    """Return the current focus value configured for one specific chat stream.

    Each ``focus_value_adjust`` entry is a sequence whose first element is a
    stream key such as ``"qq:1026294844:group"`` and whose remaining elements
    are schedule strings. The stream key is converted to a chat id and compared
    with ``chat_id``; the first matching entry decides the result.

    Args:
        chat_id: Chat id to look up, in the form produced by
            ``parse_stream_config_to_chat_id``.

    Returns:
        The value from the matching entry's schedule, or None when no entry
        matches (or the matching schedule yields nothing).
    """
    for entry in global_config.chat.focus_value_adjust:
        # An entry needs a stream key plus at least one schedule string.
        if not entry or len(entry) < 2:
            continue

        # entry[0] is the stream key; unparsable keys (including the global "")
        # come back as None and can never match.
        candidate_id = parse_stream_config_to_chat_id(entry[0])
        if candidate_id is None or candidate_id != chat_id:
            continue

        # First match wins, even if its schedule resolves to None.
        return get_time_based_focus_value(entry[1:])

    return None
def get_time_based_focus_value(time_focus_list: list[str]) -> Optional[float]:
    """Return the focus value scheduled for the current local time.

    Each entry has the form ``"HH:MM,focus_value"`` and marks the time of day
    from which that value applies. The value of the latest entry whose time is
    not after the current time is returned. When the current time is earlier
    than every entry, the latest entry of the day is used instead (it carries
    over from the previous day).

    Args:
        time_focus_list: Schedule entries. Entries that cannot be parsed,
            including non-string items, are skipped.

    Returns:
        The focus value for the current time, or None when no entry is usable.
    """
    from datetime import datetime

    now = datetime.now()
    current_minutes = now.hour * 60 + now.minute

    # Parse every entry into (minutes since midnight, focus value).
    schedule = []
    for entry in time_focus_list:
        try:
            time_str, focus_str = entry.split(",")
            hour, minute = map(int, time_str.split(":"))
            schedule.append((hour * 60 + minute, float(focus_str)))
        except (ValueError, IndexError, AttributeError, TypeError):
            # Malformed text, or an entry that is not a string at all.
            continue

    if not schedule:
        return None

    schedule.sort(key=lambda pair: pair[0])

    # Default to the last slot of the day: it is still in effect when the
    # current time precedes every configured start time.
    current_focus_value = schedule[-1][1]
    for start_minutes, focus_value in schedule:
        if start_minutes > current_minutes:
            break
        current_focus_value = focus_value

    return current_focus_value
def get_global_focus_value() -> Optional[float]:
    """Return the current focus value from the global default schedule.

    The global schedule is the first ``focus_value_adjust`` entry whose stream
    key (first element) is the empty string.

    Returns:
        The value from that schedule, or None when no such entry exists (or
        its schedule yields nothing).
    """
    for entry in global_config.chat.focus_value_adjust:
        # Needs the "" key plus at least one schedule string.
        if entry and len(entry) >= 2 and entry[0] == "":
            return get_time_based_focus_value(entry[1:])
    return None
# Module-level manager instance; it caches one FocusValueControl per chat id.
focus_value_control = FocusValueControlManager()

View File

@@ -0,0 +1,144 @@
from typing import Optional
from src.config.config import global_config
from src.chat.frequency_control.utils import parse_stream_config_to_chat_id
class TalkFrequencyControl:
    """Per-chat wrapper that scales the configured talk frequency by a multiplier.

    ``chat_id`` is the id handed to the module-level ``get_current_talk_frequency``
    lookup; ``talk_frequency_adjust`` is the factor applied to its result.
    """

    def __init__(self, chat_id: str):
        self.chat_id = chat_id
        # Multiplier of 1 means the configured frequency is returned unchanged.
        self.talk_frequency_adjust = 1

    def get_current_talk_frequency(self) -> float:
        """Return the configured talk frequency for this chat times the multiplier."""
        # Resolves to the module-level function of the same name, not this method.
        configured_frequency = get_current_talk_frequency(self.chat_id)
        return configured_frequency * self.talk_frequency_adjust
class TalkFrequencyControlManager:
    """Registry that hands out one ``TalkFrequencyControl`` per chat id."""

    def __init__(self):
        # Maps chat_id -> TalkFrequencyControl; entries are created on first lookup.
        self.talk_frequency_controls = {}

    def get_talk_frequency_control(self, chat_id: str) -> TalkFrequencyControl:
        """Return the control for ``chat_id``, creating and caching it if absent."""
        try:
            return self.talk_frequency_controls[chat_id]
        except KeyError:
            created = TalkFrequencyControl(chat_id)
            self.talk_frequency_controls[chat_id] = created
            return created
def get_current_talk_frequency(chat_id: Optional[str] = None) -> float:
    """Resolve the talk frequency that applies right now for a chat.

    Lookup order:
        1. If ``talk_frequency_adjust`` is empty, the plain ``talk_frequency``
           setting.
        2. The schedule configured for this specific chat stream (only when
           ``chat_id`` is truthy).
        3. The global schedule (the entry whose stream key is ``""``).
        4. The plain ``talk_frequency`` setting as the final fallback.

    Args:
        chat_id: Id of the chat stream, in the form produced by
            ``parse_stream_config_to_chat_id`` (not the raw
            ``"platform:id:type"`` string), or None to skip the per-stream
            lookup.

    Returns:
        The frequency selected by the first rule above that yields a value.
    """
    if global_config.chat.talk_frequency_adjust:
        # Stream-specific schedule takes priority over the global one.
        frequency = get_stream_specific_frequency(chat_id) if chat_id else None
        if frequency is None:
            frequency = get_global_frequency()
        if frequency is not None:
            return frequency
    return global_config.chat.talk_frequency
def get_time_based_frequency(time_freq_list: list[str]) -> Optional[float]:
    """Return the talk frequency scheduled for the current local time.

    Each entry has the form ``"HH:MM,frequency"`` and marks the time of day
    from which that frequency applies. The frequency of the latest entry whose
    time is not after the current time is returned. When the current time is
    earlier than every entry, the latest entry of the day is used instead (it
    carries over from the previous day).

    Args:
        time_freq_list: Schedule entries. Entries that cannot be parsed,
            including non-string items, are skipped.

    Returns:
        The frequency for the current time, or None when no entry is usable.
    """
    from datetime import datetime

    now = datetime.now()
    current_minutes = now.hour * 60 + now.minute

    # Parse every entry into (minutes since midnight, frequency).
    time_freq_pairs = []
    for time_freq_str in time_freq_list:
        try:
            time_str, freq_str = time_freq_str.split(",")
            hour, minute = map(int, time_str.split(":"))
            time_freq_pairs.append((hour * 60 + minute, float(freq_str)))
        except (ValueError, IndexError, AttributeError, TypeError):
            # Malformed text, or an entry that is not a string at all.
            continue

    if not time_freq_pairs:
        return None

    time_freq_pairs.sort(key=lambda pair: pair[0])

    # Start from the last slot of the day: it is still in effect when the
    # current time precedes every configured start time.
    current_frequency = time_freq_pairs[-1][1]
    for start_minutes, frequency in time_freq_pairs:
        if start_minutes > current_minutes:
            break
        current_frequency = frequency

    return current_frequency
def get_stream_specific_frequency(chat_stream_id: str) -> Optional[float]:
    """Return the current talk frequency configured for one specific chat stream.

    Each ``talk_frequency_adjust`` entry is a sequence whose first element is a
    stream key such as ``"qq:1026294844:group"`` and whose remaining elements
    are schedule strings. The stream key is converted to a chat id and compared
    with ``chat_stream_id``; the first matching entry decides the result.

    Args:
        chat_stream_id: Chat id to look up, in the form produced by
            ``parse_stream_config_to_chat_id``.

    Returns:
        The frequency from the matching entry's schedule, or None when no
        entry matches (or the matching schedule yields nothing).
    """
    for config_item in global_config.chat.talk_frequency_adjust:
        # An entry needs a stream key plus at least one schedule string.
        if not config_item or len(config_item) < 2:
            continue

        # config_item[0] is the stream key; unparsable keys (including the
        # global "") come back as None and can never match.
        config_chat_id = parse_stream_config_to_chat_id(config_item[0])
        if config_chat_id is None or config_chat_id != chat_stream_id:
            continue

        # First match wins, even if its schedule resolves to None.
        return get_time_based_frequency(config_item[1:])

    return None
def get_global_frequency() -> Optional[float]:
    """Return the current talk frequency from the global default schedule.

    The global schedule is the first ``talk_frequency_adjust`` entry whose
    stream key (first element) is the empty string.

    Returns:
        The frequency from that schedule, or None when no such entry exists
        (or its schedule yields nothing).
    """
    for entry in global_config.chat.talk_frequency_adjust:
        # Needs the "" key plus at least one schedule string.
        if entry and len(entry) >= 2 and entry[0] == "":
            return get_time_based_frequency(entry[1:])
    return None
# Module-level manager instance; it caches one TalkFrequencyControl per chat id.
talk_frequency_control = TalkFrequencyControlManager()

View File

@@ -0,0 +1,37 @@
from typing import Optional
import hashlib
def parse_stream_config_to_chat_id(stream_config_str: str) -> Optional[str]:
    """Convert a ``"platform:id:type"`` stream key into its hashed chat id.

    The key is split on ``":"`` and must have exactly three parts. A type of
    ``"group"`` hashes ``"platform_id"``; any other type is treated as a
    private chat and hashes ``"platform_id_private"``.

    Args:
        stream_config_str: Stream key, e.g. ``"qq:1026294844:group"``.

    Returns:
        The md5 hex digest of the joined components, or None when the input is
        not a string or does not have exactly three ``":"``-separated parts.
    """
    # A non-string cannot be split; report it as a parse failure instead of
    # letting AttributeError escape to the caller.
    if not isinstance(stream_config_str, str):
        return None

    parts = stream_config_str.split(":")
    if len(parts) != 3:
        return None
    platform, id_str, stream_type = parts

    # NOTE(review): the original author states this mirrors
    # ChatStream.get_stream_id; that method is not visible here - confirm the
    # two stay in sync.
    components = [platform, id_str]
    if stream_type != "group":
        components.append("private")

    key = "_".join(components)
    return hashlib.md5(key.encode()).hexdigest()

View File

@@ -3,7 +3,6 @@ from typing import Any, Optional, Dict
from src.common.logger import get_logger from src.common.logger import get_logger
from src.chat.heart_flow.sub_heartflow import SubHeartflow from src.chat.heart_flow.sub_heartflow import SubHeartflow
from src.chat.message_receive.chat_stream import get_chat_manager
logger = get_logger("heartflow") logger = get_logger("heartflow")
@@ -27,8 +26,6 @@ class Heartflow:
# 注册子心流 # 注册子心流
self.subheartflows[subheartflow_id] = new_subflow self.subheartflows[subheartflow_id] = new_subflow
heartflow_name = get_chat_manager().get_stream_name(subheartflow_id) or subheartflow_id
logger.info(f"[{heartflow_name}] 开始接收消息")
return new_subflow return new_subflow
except Exception as e: except Exception as e:

View File

@@ -14,53 +14,35 @@ from src.chat.utils.utils import is_mentioned_bot_in_message
from src.chat.utils.timer_calculator import Timer from src.chat.utils.timer_calculator import Timer
from src.chat.utils.chat_message_builder import replace_user_references_sync from src.chat.utils.chat_message_builder import replace_user_references_sync
from src.common.logger import get_logger from src.common.logger import get_logger
from src.person_info.relationship_manager import get_relationship_manager
from src.mood.mood_manager import mood_manager from src.mood.mood_manager import mood_manager
from src.person_info.person_info import Person
if TYPE_CHECKING: if TYPE_CHECKING:
from src.chat.heart_flow.sub_heartflow import SubHeartflow from src.chat.heart_flow.sub_heartflow import SubHeartflow
logger = get_logger("chat") logger = get_logger("chat")
async def _calculate_interest(message: MessageRecv) -> Tuple[float, bool, list[str]]:
async def _process_relationship(message: MessageRecv) -> None:
"""处理用户关系逻辑
Args:
message: 消息对象,包含用户信息
"""
platform = message.message_info.platform
user_id = message.message_info.user_info.user_id # type: ignore
nickname = message.message_info.user_info.user_nickname # type: ignore
cardname = message.message_info.user_info.user_cardname or nickname # type: ignore
relationship_manager = get_relationship_manager()
is_known = await relationship_manager.is_known_some_one(platform, user_id)
if not is_known:
logger.info(f"首次认识用户: {nickname}")
await relationship_manager.first_knowing_some_one(platform, user_id, nickname, cardname) # type: ignore
async def _calculate_interest(message: MessageRecv) -> Tuple[float, bool]:
"""计算消息的兴趣度 """计算消息的兴趣度
Args: Args:
message: 待处理的消息对象 message: 待处理的消息对象
Returns: Returns:
Tuple[float, bool]: (兴趣度, 是否被提及) Tuple[float, bool, list[str]]: (兴趣度, 是否被提及, 关键词)
""" """
is_mentioned, _ = is_mentioned_bot_in_message(message) is_mentioned, _ = is_mentioned_bot_in_message(message)
interested_rate = 0.0 interested_rate = 0.0
with Timer("记忆激活"): with Timer("记忆激活"):
interested_rate = await hippocampus_manager.get_activate_from_text( interested_rate, keywords,keywords_lite = await hippocampus_manager.get_activate_from_text(
message.processed_plain_text, message.processed_plain_text,
max_depth= 5, max_depth= 4,
fast_retrieval=False, fast_retrieval=False,
) )
logger.debug(f"记忆激活率: {interested_rate:.2f}") message.key_words = keywords
message.key_words_lite = keywords_lite
logger.debug(f"记忆激活率: {interested_rate:.2f}, 关键词: {keywords}")
text_len = len(message.processed_plain_text) text_len = len(message.processed_plain_text)
# 根据文本长度分布调整兴趣度,采用分段函数实现更精确的兴趣度计算 # 根据文本长度分布调整兴趣度,采用分段函数实现更精确的兴趣度计算
@@ -99,7 +81,7 @@ async def _calculate_interest(message: MessageRecv) -> Tuple[float, bool]:
interest_increase_on_mention = 1 interest_increase_on_mention = 1
interested_rate += interest_increase_on_mention interested_rate += interest_increase_on_mention
return interested_rate, is_mentioned return interested_rate, is_mentioned, keywords
class HeartFCMessageReceiver: class HeartFCMessageReceiver:
@@ -128,7 +110,7 @@ class HeartFCMessageReceiver:
chat = message.chat_stream chat = message.chat_stream
# 2. 兴趣度计算与更新 # 2. 兴趣度计算与更新
interested_rate, is_mentioned = await _calculate_interest(message) interested_rate, is_mentioned, keywords = await _calculate_interest(message)
message.interest_value = interested_rate message.interest_value = interested_rate
message.is_mentioned = is_mentioned message.is_mentioned = is_mentioned
@@ -143,8 +125,6 @@ class HeartFCMessageReceiver:
# 3. 日志记录 # 3. 日志记录
mes_name = chat.group_info.group_name if chat.group_info else "私聊" mes_name = chat.group_info.group_name if chat.group_info else "私聊"
# current_time = time.strftime("%H:%M:%S", time.localtime(message.message_info.time))
current_talk_frequency = global_config.chat.get_current_talk_frequency(chat.stream_id)
# 如果消息中包含图片标识,则将 [picid:...] 替换为 [图片] # 如果消息中包含图片标识,则将 [picid:...] 替换为 [图片]
picid_pattern = r"\[picid:([^\]]+)\]" picid_pattern = r"\[picid:([^\]]+)\]"
@@ -157,13 +137,12 @@ class HeartFCMessageReceiver:
replace_bot_name=True replace_bot_name=True
) )
logger.info(f"[{mes_name}]{userinfo.user_nickname}:{processed_plain_text}[兴趣度:{interested_rate:.2f}]") # type: ignore if keywords:
logger.info(f"[{mes_name}]{userinfo.user_nickname}:{processed_plain_text}[兴趣度:{interested_rate:.2f}][关键词:{keywords}]") # type: ignore
else:
logger.info(f"[{mes_name}]{userinfo.user_nickname}:{processed_plain_text}[兴趣度:{interested_rate:.2f}]") # type: ignore
logger.debug(f"[{mes_name}][当前时段回复频率: {current_talk_frequency}]") _ = Person.register_person(platform=message.message_info.platform, user_id=message.message_info.user_info.user_id,nickname=userinfo.user_nickname) # type: ignore
# 4. 关系处理
if global_config.relationship.enable_relationship:
await _process_relationship(message)
except Exception as e: except Exception as e:
logger.error(f"消息处理失败: {e}") logger.error(f"消息处理失败: {e}")

View File

@@ -3,6 +3,7 @@ import json
import os import os
import math import math
import asyncio import asyncio
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Dict, List, Tuple from typing import Dict, List, Tuple
import numpy as np import numpy as np
@@ -11,8 +12,6 @@ import pandas as pd
# import tqdm # import tqdm
import faiss import faiss
# from .llm_client import LLMClient
# from .lpmmconfig import global_config
from .utils.hash import get_sha256 from .utils.hash import get_sha256
from .global_logger import logger from .global_logger import logger
from rich.traceback import install from rich.traceback import install
@@ -26,12 +25,20 @@ from rich.progress import (
SpinnerColumn, SpinnerColumn,
TextColumn, TextColumn,
) )
from src.manager.local_store_manager import local_storage
from src.chat.utils.utils import get_embedding from src.chat.utils.utils import get_embedding
from src.config.config import global_config from src.config.config import global_config
install(extra_lines=3) install(extra_lines=3)
# 多线程embedding配置常量
DEFAULT_MAX_WORKERS = 10 # 默认最大线程数
DEFAULT_CHUNK_SIZE = 10 # 默认每个线程处理的数据块大小
MIN_CHUNK_SIZE = 1 # 最小分块大小
MAX_CHUNK_SIZE = 50 # 最大分块大小
MIN_WORKERS = 1 # 最小线程数
MAX_WORKERS = 20 # 最大线程数
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..")) ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
EMBEDDING_DATA_DIR = os.path.join(ROOT_PATH, "data", "embedding") EMBEDDING_DATA_DIR = os.path.join(ROOT_PATH, "data", "embedding")
EMBEDDING_DATA_DIR_STR = str(EMBEDDING_DATA_DIR).replace("\\", "/") EMBEDDING_DATA_DIR_STR = str(EMBEDDING_DATA_DIR).replace("\\", "/")
@@ -87,13 +94,23 @@ class EmbeddingStoreItem:
class EmbeddingStore: class EmbeddingStore:
def __init__(self, namespace: str, dir_path: str): def __init__(self, namespace: str, dir_path: str, max_workers: int = DEFAULT_MAX_WORKERS, chunk_size: int = DEFAULT_CHUNK_SIZE):
self.namespace = namespace self.namespace = namespace
self.dir = dir_path self.dir = dir_path
self.embedding_file_path = f"{dir_path}/{namespace}.parquet" self.embedding_file_path = f"{dir_path}/{namespace}.parquet"
self.index_file_path = f"{dir_path}/{namespace}.index" self.index_file_path = f"{dir_path}/{namespace}.index"
self.idx2hash_file_path = dir_path + "/" + namespace + "_i2h.json" self.idx2hash_file_path = dir_path + "/" + namespace + "_i2h.json"
# 多线程配置参数验证和设置
self.max_workers = max(MIN_WORKERS, min(MAX_WORKERS, max_workers))
self.chunk_size = max(MIN_CHUNK_SIZE, min(MAX_CHUNK_SIZE, chunk_size))
# 如果配置值被调整,记录日志
if self.max_workers != max_workers:
logger.warning(f"max_workers 已从 {max_workers} 调整为 {self.max_workers} (范围: {MIN_WORKERS}-{MAX_WORKERS})")
if self.chunk_size != chunk_size:
logger.warning(f"chunk_size 已从 {chunk_size} 调整为 {self.chunk_size} (范围: {MIN_CHUNK_SIZE}-{MAX_CHUNK_SIZE})")
self.store = {} self.store = {}
self.faiss_index = None self.faiss_index = None
@@ -125,16 +142,134 @@ class EmbeddingStore:
return [] return []
return result return result
def _get_embeddings_batch_threaded(self, strs: List[str], chunk_size: int = 10, max_workers: int = 10, progress_callback=None) -> List[Tuple[str, List[float]]]:
    """Fetch embeddings for many strings using a thread pool.

    The input is cut into chunks of ``chunk_size`` strings; each chunk is
    submitted to a ``ThreadPoolExecutor`` and processed sequentially inside
    its worker. A string whose embedding could not be obtained is reported
    with an empty list instead of raising.

    Args:
        strs: Strings to embed.
        chunk_size: Number of strings handled by one submitted task.
        max_workers: Maximum number of worker threads.
        progress_callback: Optional callable invoked with ``1`` from the
            worker each time one string has been handled (success or failure).

    Returns:
        A list of ``(original string, embedding)`` tuples in the same order as
        ``strs``; the embedding is ``[]`` for failed items.
    """
    if not strs:
        return []

    # Split into chunks
    chunks = []
    for i in range(0, len(strs), chunk_size):
        chunk = strs[i:i + chunk_size]
        chunks.append((i, chunk))  # keep the start index so order can be restored

    # Results keyed by original index so the output order can be rebuilt
    results = {}

    def process_chunk(chunk_data):
        """Embed every string of one chunk; returns (index, string, embedding) triples."""
        start_idx, chunk_strs = chunk_data
        chunk_results = []

        # Each task builds its own LLMRequest instance.
        # NOTE(review): these imports sit outside the try below, so an import
        # failure propagates out of the task and is handled by the caller's
        # future.result() handler rather than here.
        from src.llm_models.utils_model import LLMRequest
        from src.config.config import model_config

        try:
            # Create the task-local LLM instance
            llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding")

            for i, s in enumerate(chunk_strs):
                try:
                    # Run the async embedding call to completion in this thread
                    embedding = asyncio.run(llm.get_embedding(s))
                    if embedding and len(embedding) > 0:
                        chunk_results.append((start_idx + i, s, embedding[0]))  # per the original author, embedding[0] is the actual vector
                    else:
                        logger.error(f"获取嵌入失败: {s}")
                        chunk_results.append((start_idx + i, s, []))

                    # Report progress as soon as each embedding is finished
                    if progress_callback:
                        progress_callback(1)

                except Exception as e:
                    logger.error(f"获取嵌入时发生异常: {s}, 错误: {e}")
                    chunk_results.append((start_idx + i, s, []))

                    # Progress is reported even on failure
                    if progress_callback:
                        progress_callback(1)

        except Exception as e:
            logger.error(f"创建LLM实例失败: {e}")
            # If the LLM instance could not be created, return empty results
            for i, s in enumerate(chunk_strs):
                chunk_results.append((start_idx + i, s, []))
                # Progress is reported even on failure
                if progress_callback:
                    progress_callback(1)

        return chunk_results

    # Process the chunks with a thread pool
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Submit all tasks
        future_to_chunk = {executor.submit(process_chunk, chunk): chunk for chunk in chunks}

        # Collect results (progress was already reported inside process_chunk)
        for future in as_completed(future_to_chunk):
            try:
                chunk_results = future.result()
                for idx, s, embedding in chunk_results:
                    results[idx] = (s, embedding)
            except Exception as e:
                chunk = future_to_chunk[future]
                logger.error(f"处理数据块时发生异常: {chunk}, 错误: {e}")
                # Fill in empty results for the failed chunk
                # (no progress_callback call is made on this path)
                start_idx, chunk_strs = chunk
                for i, s in enumerate(chunk_strs):
                    results[start_idx + i] = (s, [])

    # Return the results in the original input order
    ordered_results = []
    for i in range(len(strs)):
        if i in results:
            ordered_results.append(results[i])
        else:
            # Guard against a missing index
            ordered_results.append((strs[i], []))

    return ordered_results
def get_test_file_path(self): def get_test_file_path(self):
return EMBEDDING_TEST_FILE return EMBEDDING_TEST_FILE
def save_embedding_test_vectors(self): def save_embedding_test_vectors(self):
"""保存测试字符串的嵌入到本地""" """保存测试字符串的嵌入到本地(使用多线程优化)"""
logger.info("开始保存测试字符串的嵌入向量...")
# 使用多线程批量获取测试字符串的嵌入
embedding_results = self._get_embeddings_batch_threaded(
EMBEDDING_TEST_STRINGS,
chunk_size=min(self.chunk_size, len(EMBEDDING_TEST_STRINGS)),
max_workers=min(self.max_workers, len(EMBEDDING_TEST_STRINGS))
)
# 构建测试向量字典
test_vectors = {} test_vectors = {}
for idx, s in enumerate(EMBEDDING_TEST_STRINGS): for idx, (s, embedding) in enumerate(embedding_results):
test_vectors[str(idx)] = self._get_embedding(s) if embedding:
test_vectors[str(idx)] = embedding
else:
logger.error(f"获取测试字符串嵌入失败: {s}")
# 使用原始单线程方法作为后备
test_vectors[str(idx)] = self._get_embedding(s)
with open(self.get_test_file_path(), "w", encoding="utf-8") as f: with open(self.get_test_file_path(), "w", encoding="utf-8") as f:
json.dump(test_vectors, f, ensure_ascii=False, indent=2) json.dump(test_vectors, f, ensure_ascii=False, indent=2)
logger.info("测试字符串嵌入向量保存完成")
def load_embedding_test_vectors(self): def load_embedding_test_vectors(self):
"""加载本地保存的测试字符串嵌入""" """加载本地保存的测试字符串嵌入"""
@@ -145,29 +280,64 @@ class EmbeddingStore:
return json.load(f) return json.load(f)
def check_embedding_model_consistency(self): def check_embedding_model_consistency(self):
"""校验当前模型与本地嵌入模型是否一致""" """校验当前模型与本地嵌入模型是否一致(使用多线程优化)"""
local_vectors = self.load_embedding_test_vectors() local_vectors = self.load_embedding_test_vectors()
if local_vectors is None: if local_vectors is None:
logger.warning("未检测到本地嵌入模型测试文件,将保存当前模型的测试嵌入。") logger.warning("未检测到本地嵌入模型测试文件,将保存当前模型的测试嵌入。")
self.save_embedding_test_vectors() self.save_embedding_test_vectors()
return True return True
for idx, s in enumerate(EMBEDDING_TEST_STRINGS):
local_emb = local_vectors.get(str(idx)) # 检查本地向量完整性
if local_emb is None: for idx in range(len(EMBEDDING_TEST_STRINGS)):
if local_vectors.get(str(idx)) is None:
logger.warning("本地嵌入模型测试文件缺失部分测试字符串,将重新保存。") logger.warning("本地嵌入模型测试文件缺失部分测试字符串,将重新保存。")
self.save_embedding_test_vectors() self.save_embedding_test_vectors()
return True return True
new_emb = self._get_embedding(s)
logger.info("开始检验嵌入模型一致性...")
# 使用多线程批量获取当前模型的嵌入
embedding_results = self._get_embeddings_batch_threaded(
EMBEDDING_TEST_STRINGS,
chunk_size=min(self.chunk_size, len(EMBEDDING_TEST_STRINGS)),
max_workers=min(self.max_workers, len(EMBEDDING_TEST_STRINGS))
)
# 检查一致性
for idx, (s, new_emb) in enumerate(embedding_results):
local_emb = local_vectors.get(str(idx))
if not new_emb:
logger.error(f"获取测试字符串嵌入失败: {s}")
return False
sim = cosine_similarity(local_emb, new_emb) sim = cosine_similarity(local_emb, new_emb)
if sim < EMBEDDING_SIM_THRESHOLD: if sim < EMBEDDING_SIM_THRESHOLD:
logger.error("嵌入模型一致性校验失败") logger.error(f"嵌入模型一致性校验失败,字符串: {s}, 相似度: {sim:.4f}")
return False return False
logger.info("嵌入模型一致性校验通过。") logger.info("嵌入模型一致性校验通过。")
return True return True
def batch_insert_strs(self, strs: List[str], times: int) -> None: def batch_insert_strs(self, strs: List[str], times: int) -> None:
"""向库中存入字符串""" """向库中存入字符串(使用多线程优化)"""
if not strs:
return
total = len(strs) total = len(strs)
# 过滤已存在的字符串
new_strs = []
for s in strs:
item_hash = self.namespace + "-" + get_sha256(s)
if item_hash not in self.store:
new_strs.append(s)
if not new_strs:
logger.info(f"所有字符串已存在于{self.namespace}嵌入库中,跳过处理")
return
logger.info(f"需要处理 {len(new_strs)}/{total} 个新字符串")
with Progress( with Progress(
SpinnerColumn(), SpinnerColumn(),
TextColumn("[progress.description]{task.description}"), TextColumn("[progress.description]{task.description}"),
@@ -181,19 +351,38 @@ class EmbeddingStore:
transient=False, transient=False,
) as progress: ) as progress:
task = progress.add_task(f"存入嵌入库:({times}/{TOTAL_EMBEDDING_TIMES})", total=total) task = progress.add_task(f"存入嵌入库:({times}/{TOTAL_EMBEDDING_TIMES})", total=total)
for s in strs:
# 计算hash去重 # 首先更新已存在项的进度
item_hash = self.namespace + "-" + get_sha256(s) already_processed = total - len(new_strs)
if item_hash in self.store: if already_processed > 0:
progress.update(task, advance=1) progress.update(task, advance=already_processed)
continue
if new_strs:
# 获取embedding # 使用实例配置的参数,智能调整分块和线程数
embedding = self._get_embedding(s) optimal_chunk_size = max(MIN_CHUNK_SIZE, min(self.chunk_size, len(new_strs) // self.max_workers if self.max_workers > 0 else self.chunk_size))
optimal_max_workers = min(self.max_workers, max(MIN_WORKERS, len(new_strs) // optimal_chunk_size if optimal_chunk_size > 0 else 1))
# 存入
self.store[item_hash] = EmbeddingStoreItem(item_hash, embedding, s) logger.debug(f"使用多线程处理: chunk_size={optimal_chunk_size}, max_workers={optimal_max_workers}")
progress.update(task, advance=1)
# 定义进度更新回调函数
def update_progress(count):
progress.update(task, advance=count)
# 批量获取嵌入,并实时更新进度
embedding_results = self._get_embeddings_batch_threaded(
new_strs,
chunk_size=optimal_chunk_size,
max_workers=optimal_max_workers,
progress_callback=update_progress
)
# 存入结果(不再需要在这里更新进度,因为已经在回调中更新了)
for s, embedding in embedding_results:
item_hash = self.namespace + "-" + get_sha256(s)
if embedding: # 只有成功获取到嵌入才存入
self.store[item_hash] = EmbeddingStoreItem(item_hash, embedding, s)
else:
logger.warning(f"跳过存储失败的嵌入: {s[:50]}...")
def save_to_file(self) -> None: def save_to_file(self) -> None:
"""保存到文件""" """保存到文件"""
@@ -316,31 +505,37 @@ class EmbeddingStore:
class EmbeddingManager: class EmbeddingManager:
def __init__(self): def __init__(self, max_workers: int = DEFAULT_MAX_WORKERS, chunk_size: int = DEFAULT_CHUNK_SIZE):
"""
初始化EmbeddingManager
Args:
max_workers: 最大线程数
chunk_size: 每个线程处理的数据块大小
"""
self.paragraphs_embedding_store = EmbeddingStore( self.paragraphs_embedding_store = EmbeddingStore(
local_storage["pg_namespace"], # type: ignore "paragraph", # type: ignore
EMBEDDING_DATA_DIR_STR, EMBEDDING_DATA_DIR_STR,
max_workers=max_workers,
chunk_size=chunk_size,
) )
self.entities_embedding_store = EmbeddingStore( self.entities_embedding_store = EmbeddingStore(
local_storage["pg_namespace"], # type: ignore "entity", # type: ignore
EMBEDDING_DATA_DIR_STR, EMBEDDING_DATA_DIR_STR,
max_workers=max_workers,
chunk_size=chunk_size,
) )
self.relation_embedding_store = EmbeddingStore( self.relation_embedding_store = EmbeddingStore(
local_storage["pg_namespace"], # type: ignore "relation", # type: ignore
EMBEDDING_DATA_DIR_STR, EMBEDDING_DATA_DIR_STR,
max_workers=max_workers,
chunk_size=chunk_size,
) )
self.stored_pg_hashes = set() self.stored_pg_hashes = set()
def check_all_embedding_model_consistency(self): def check_all_embedding_model_consistency(self):
"""对所有嵌入库做模型一致性校验""" """对所有嵌入库做模型一致性校验"""
for store in [ return self.paragraphs_embedding_store.check_embedding_model_consistency()
self.paragraphs_embedding_store,
self.entities_embedding_store,
self.relation_embedding_store,
]:
if not store.check_embedding_model_consistency():
return False
return True
def _store_pg_into_embedding(self, raw_paragraphs: Dict[str, str]): def _store_pg_into_embedding(self, raw_paragraphs: Dict[str, str]):
"""将段落编码存入Embedding库""" """将段落编码存入Embedding库"""

View File

@@ -8,12 +8,15 @@ from . import prompt_template
from .knowledge_lib import INVALID_ENTITY from .knowledge_lib import INVALID_ENTITY
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from json_repair import repair_json from json_repair import repair_json
def _extract_json_from_text(text: str): def _extract_json_from_text(text: str):
# sourcery skip: assign-if-exp, extract-method
"""从文本中提取JSON数据的高容错方法""" """从文本中提取JSON数据的高容错方法"""
if text is None: if text is None:
logger.error("输入文本为None") logger.error("输入文本为None")
return [] return []
try: try:
fixed_json = repair_json(text) fixed_json = repair_json(text)
if isinstance(fixed_json, str): if isinstance(fixed_json, str):
@@ -24,7 +27,7 @@ def _extract_json_from_text(text: str):
# 如果是列表,直接返回 # 如果是列表,直接返回
if isinstance(parsed_json, list): if isinstance(parsed_json, list):
return parsed_json return parsed_json
# 如果是字典且只有一个项目,可能包装了列表 # 如果是字典且只有一个项目,可能包装了列表
if isinstance(parsed_json, dict): if isinstance(parsed_json, dict):
# 如果字典只有一个键,并且值是列表,返回那个列表 # 如果字典只有一个键,并且值是列表,返回那个列表
@@ -33,7 +36,7 @@ def _extract_json_from_text(text: str):
if isinstance(value, list): if isinstance(value, list):
return value return value
return parsed_json return parsed_json
# 其他情况,尝试转换为列表 # 其他情况,尝试转换为列表
logger.warning(f"解析的JSON不是预期格式: {type(parsed_json)}, 内容: {parsed_json}") logger.warning(f"解析的JSON不是预期格式: {type(parsed_json)}, 内容: {parsed_json}")
return [] return []
@@ -42,44 +45,40 @@ def _extract_json_from_text(text: str):
logger.error(f"JSON提取失败: {e}, 原始文本: {text[:100] if text else 'None'}...") logger.error(f"JSON提取失败: {e}, 原始文本: {text[:100] if text else 'None'}...")
return [] return []
def _entity_extract(llm_req: LLMRequest, paragraph: str) -> List[str]: def _entity_extract(llm_req: LLMRequest, paragraph: str) -> List[str]:
# sourcery skip: reintroduce-else, swap-if-else-branches, use-named-expression
"""对段落进行实体提取返回提取出的实体列表JSON格式""" """对段落进行实体提取返回提取出的实体列表JSON格式"""
entity_extract_context = prompt_template.build_entity_extract_context(paragraph) entity_extract_context = prompt_template.build_entity_extract_context(paragraph)
# 使用 asyncio.run 来运行异步方法 # 使用 asyncio.run 来运行异步方法
try: try:
# 如果当前已有事件循环在运行,使用它 # 如果当前已有事件循环在运行,使用它
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
future = asyncio.run_coroutine_threadsafe( future = asyncio.run_coroutine_threadsafe(llm_req.generate_response_async(entity_extract_context), loop)
llm_req.generate_response_async(entity_extract_context), loop response, _ = future.result()
)
response, (reasoning_content, model_name) = future.result()
except RuntimeError: except RuntimeError:
# 如果没有运行中的事件循环,直接使用 asyncio.run # 如果没有运行中的事件循环,直接使用 asyncio.run
response, (reasoning_content, model_name) = asyncio.run( response, _ = asyncio.run(llm_req.generate_response_async(entity_extract_context))
llm_req.generate_response_async(entity_extract_context)
)
# 添加调试日志 # 添加调试日志
logger.debug(f"LLM返回的原始响应: {response}") logger.debug(f"LLM返回的原始响应: {response}")
entity_extract_result = _extract_json_from_text(response) entity_extract_result = _extract_json_from_text(response)
# 检查返回的是否为有效的实体列表 # 检查返回的是否为有效的实体列表
if not isinstance(entity_extract_result, list): if not isinstance(entity_extract_result, list):
# 如果不是列表,可能是字典格式,尝试从中提取列表 if not isinstance(entity_extract_result, dict):
if isinstance(entity_extract_result, dict): raise ValueError(f"实体提取结果格式错误,期望列表但得到: {type(entity_extract_result)}")
# 尝试常见的键名
for key in ['entities', 'result', 'data', 'items']: # 尝试常见的键名
if key in entity_extract_result and isinstance(entity_extract_result[key], list): for key in ["entities", "result", "data", "items"]:
entity_extract_result = entity_extract_result[key] if key in entity_extract_result and isinstance(entity_extract_result[key], list):
break entity_extract_result = entity_extract_result[key]
else: break
# 如果找不到合适的列表,抛出异常
raise Exception(f"实体提取结果格式错误,期望列表但得到: {type(entity_extract_result)}")
else: else:
raise Exception(f"实体提取结果格式错误,期望列表但得到: {type(entity_extract_result)}") # 如果找不到合适的列表,抛出异常
raise ValueError(f"实体提取结果格式错误,期望列表但得到: {type(entity_extract_result)}")
# 过滤无效实体 # 过滤无效实体
entity_extract_result = [ entity_extract_result = [
entity entity
@@ -87,8 +86,8 @@ def _entity_extract(llm_req: LLMRequest, paragraph: str) -> List[str]:
if (entity is not None) and (entity != "") and (entity not in INVALID_ENTITY) if (entity is not None) and (entity != "") and (entity not in INVALID_ENTITY)
] ]
if len(entity_extract_result) == 0: if not entity_extract_result:
raise Exception("实体提取结果为空") raise ValueError("实体提取结果为空")
return entity_extract_result return entity_extract_result
@@ -98,45 +97,44 @@ def _rdf_triple_extract(llm_req: LLMRequest, paragraph: str, entities: list) ->
rdf_extract_context = prompt_template.build_rdf_triple_extract_context( rdf_extract_context = prompt_template.build_rdf_triple_extract_context(
paragraph, entities=json.dumps(entities, ensure_ascii=False) paragraph, entities=json.dumps(entities, ensure_ascii=False)
) )
# 使用 asyncio.run 来运行异步方法 # 使用 asyncio.run 来运行异步方法
try: try:
# 如果当前已有事件循环在运行,使用它 # 如果当前已有事件循环在运行,使用它
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
future = asyncio.run_coroutine_threadsafe( future = asyncio.run_coroutine_threadsafe(llm_req.generate_response_async(rdf_extract_context), loop)
llm_req.generate_response_async(rdf_extract_context), loop response, _ = future.result()
)
response, (reasoning_content, model_name) = future.result()
except RuntimeError: except RuntimeError:
# 如果没有运行中的事件循环,直接使用 asyncio.run # 如果没有运行中的事件循环,直接使用 asyncio.run
response, (reasoning_content, model_name) = asyncio.run( response, _ = asyncio.run(llm_req.generate_response_async(rdf_extract_context))
llm_req.generate_response_async(rdf_extract_context)
)
# 添加调试日志 # 添加调试日志
logger.debug(f"RDF LLM返回的原始响应: {response}") logger.debug(f"RDF LLM返回的原始响应: {response}")
rdf_triple_result = _extract_json_from_text(response) rdf_triple_result = _extract_json_from_text(response)
# 检查返回的是否为有效的三元组列表 # 检查返回的是否为有效的三元组列表
if not isinstance(rdf_triple_result, list): if not isinstance(rdf_triple_result, list):
# 如果不是列表,可能是字典格式,尝试从中提取列表 if not isinstance(rdf_triple_result, dict):
if isinstance(rdf_triple_result, dict): raise ValueError(f"RDF三元组提取结果格式错误期望列表但得到: {type(rdf_triple_result)}")
# 尝试常见的键名
for key in ['triples', 'result', 'data', 'items']: # 尝试常见的键名
if key in rdf_triple_result and isinstance(rdf_triple_result[key], list): for key in ["triples", "result", "data", "items"]:
rdf_triple_result = rdf_triple_result[key] if key in rdf_triple_result and isinstance(rdf_triple_result[key], list):
break rdf_triple_result = rdf_triple_result[key]
else: break
# 如果找不到合适的列表,抛出异常
raise Exception(f"RDF三元组提取结果格式错误期望列表但得到: {type(rdf_triple_result)}")
else: else:
raise Exception(f"RDF三元组提取结果格式错误期望列表但得到: {type(rdf_triple_result)}") # 如果找不到合适的列表,抛出异常
raise ValueError(f"RDF三元组提取结果格式错误期望列表但得到: {type(rdf_triple_result)}")
# 验证三元组格式 # 验证三元组格式
for triple in rdf_triple_result: for triple in rdf_triple_result:
if not isinstance(triple, list) or len(triple) != 3 or (triple[0] is None or triple[1] is None or triple[2] is None) or "" in triple: if (
raise Exception("RDF提取结果格式错误") not isinstance(triple, list)
or len(triple) != 3
or (triple[0] is None or triple[1] is None or triple[2] is None)
or "" in triple
):
raise ValueError("RDF提取结果格式错误")
return rdf_triple_result return rdf_triple_result

View File

@@ -20,8 +20,7 @@ from quick_algo import di_graph, pagerank
from .utils.hash import get_sha256 from .utils.hash import get_sha256
from .embedding_store import EmbeddingManager, EmbeddingStoreItem from .embedding_store import EmbeddingManager, EmbeddingStoreItem
from .lpmmconfig import global_config from src.config.config import global_config
from src.manager.local_store_manager import local_storage
from .global_logger import logger from .global_logger import logger
@@ -30,19 +29,9 @@ def _get_kg_dir():
""" """
安全地获取KG数据目录路径 安全地获取KG数据目录路径
""" """
root_path: str = local_storage["root_path"] current_dir = os.path.dirname(os.path.abspath(__file__))
if root_path is None: root_path: str = os.path.abspath(os.path.join(current_dir, "..", "..", ".."))
# 如果 local_storage 中没有 root_path使用当前文件的相对路径作为备用 kg_dir = os.path.join(root_path, "data/rag")
current_dir = os.path.dirname(os.path.abspath(__file__))
root_path = os.path.abspath(os.path.join(current_dir, "..", "..", ".."))
logger.warning(f"local_storage 中未找到 root_path使用备用路径: {root_path}")
# 获取RAG数据目录
rag_data_dir: str = global_config["persistence"]["rag_data_dir"]
if rag_data_dir is None:
kg_dir = os.path.join(root_path, "data/rag")
else:
kg_dir = os.path.join(root_path, rag_data_dir)
return str(kg_dir).replace("\\", "/") return str(kg_dir).replace("\\", "/")
@@ -65,9 +54,9 @@ class KGManager:
# 持久化相关 - 使用延迟初始化的路径 # 持久化相关 - 使用延迟初始化的路径
self.dir_path = get_kg_dir_str() self.dir_path = get_kg_dir_str()
self.graph_data_path = self.dir_path + "/" + local_storage["rag_graph_namespace"] + ".graphml" self.graph_data_path = self.dir_path + "/" + "rag-graph" + ".graphml"
self.ent_cnt_data_path = self.dir_path + "/" + local_storage["rag_ent_cnt_namespace"] + ".parquet" self.ent_cnt_data_path = self.dir_path + "/" + "rag-ent-cnt" + ".parquet"
self.pg_hash_file_path = self.dir_path + "/" + local_storage["rag_pg_hash_namespace"] + ".json" self.pg_hash_file_path = self.dir_path + "/" + "rag-pg-hash" + ".json"
def save_to_file(self): def save_to_file(self):
"""将KG数据保存到文件""" """将KG数据保存到文件"""
@@ -122,8 +111,8 @@ class KGManager:
# 避免自连接 # 避免自连接
continue continue
# 一个triple就是一条边同时构建双向联系 # 一个triple就是一条边同时构建双向联系
hash_key1 = local_storage["ent_namespace"] + "-" + get_sha256(triple[0]) hash_key1 = "entity" + "-" + get_sha256(triple[0])
hash_key2 = local_storage["ent_namespace"] + "-" + get_sha256(triple[2]) hash_key2 = "entity" + "-" + get_sha256(triple[2])
node_to_node[(hash_key1, hash_key2)] = node_to_node.get((hash_key1, hash_key2), 0) + 1.0 node_to_node[(hash_key1, hash_key2)] = node_to_node.get((hash_key1, hash_key2), 0) + 1.0
node_to_node[(hash_key2, hash_key1)] = node_to_node.get((hash_key2, hash_key1), 0) + 1.0 node_to_node[(hash_key2, hash_key1)] = node_to_node.get((hash_key2, hash_key1), 0) + 1.0
entity_set.add(hash_key1) entity_set.add(hash_key1)
@@ -141,8 +130,8 @@ class KGManager:
"""构建实体节点与文段节点之间的关系""" """构建实体节点与文段节点之间的关系"""
for idx in triple_list_data: for idx in triple_list_data:
for triple in triple_list_data[idx]: for triple in triple_list_data[idx]:
ent_hash_key = local_storage["ent_namespace"] + "-" + get_sha256(triple[0]) ent_hash_key = "entity" + "-" + get_sha256(triple[0])
pg_hash_key = local_storage["pg_namespace"] + "-" + str(idx) pg_hash_key = "paragraph" + "-" + str(idx)
node_to_node[(ent_hash_key, pg_hash_key)] = node_to_node.get((ent_hash_key, pg_hash_key), 0) + 1.0 node_to_node[(ent_hash_key, pg_hash_key)] = node_to_node.get((ent_hash_key, pg_hash_key), 0) + 1.0
@staticmethod @staticmethod
@@ -157,12 +146,12 @@ class KGManager:
ent_hash_list = set() ent_hash_list = set()
for triple_list in triple_list_data.values(): for triple_list in triple_list_data.values():
for triple in triple_list: for triple in triple_list:
ent_hash_list.add(local_storage["ent_namespace"] + "-" + get_sha256(triple[0])) ent_hash_list.add("entity" + "-" + get_sha256(triple[0]))
ent_hash_list.add(local_storage["ent_namespace"] + "-" + get_sha256(triple[2])) ent_hash_list.add("entity" + "-" + get_sha256(triple[2]))
ent_hash_list = list(ent_hash_list) ent_hash_list = list(ent_hash_list)
synonym_hash_set = set() synonym_hash_set = set()
synonym_result = dict() synonym_result = {}
# rich 进度条 # rich 进度条
total = len(ent_hash_list) total = len(ent_hash_list)
@@ -190,14 +179,14 @@ class KGManager:
assert isinstance(ent, EmbeddingStoreItem) assert isinstance(ent, EmbeddingStoreItem)
# 查询相似实体 # 查询相似实体
similar_ents = embedding_manager.entities_embedding_store.search_top_k( similar_ents = embedding_manager.entities_embedding_store.search_top_k(
ent.embedding, global_config["rag"]["params"]["synonym_search_top_k"] ent.embedding, global_config.lpmm_knowledge.rag_synonym_search_top_k
) )
res_ent = [] # Debug res_ent = [] # Debug
for res_ent_hash, similarity in similar_ents: for res_ent_hash, similarity in similar_ents:
if res_ent_hash == ent_hash: if res_ent_hash == ent_hash:
# 避免自连接 # 避免自连接
continue continue
if similarity < global_config["rag"]["params"]["synonym_threshold"]: if similarity < global_config.lpmm_knowledge.rag_synonym_threshold:
# 相似度阈值 # 相似度阈值
continue continue
node_to_node[(res_ent_hash, ent_hash)] = similarity node_to_node[(res_ent_hash, ent_hash)] = similarity
@@ -263,7 +252,7 @@ class KGManager:
for src_tgt in node_to_node.keys(): for src_tgt in node_to_node.keys():
for node_hash in src_tgt: for node_hash in src_tgt:
if node_hash not in existed_nodes: if node_hash not in existed_nodes:
if node_hash.startswith(local_storage["ent_namespace"]): if node_hash.startswith("entity"):
# 新增实体节点 # 新增实体节点
node = embedding_manager.entities_embedding_store.store.get(node_hash) node = embedding_manager.entities_embedding_store.store.get(node_hash)
if node is None: if node is None:
@@ -275,7 +264,7 @@ class KGManager:
node_item["type"] = "ent" node_item["type"] = "ent"
node_item["create_time"] = now_time node_item["create_time"] = now_time
self.graph.update_node(node_item) self.graph.update_node(node_item)
elif node_hash.startswith(local_storage["pg_namespace"]): elif node_hash.startswith("paragraph"):
# 新增文段节点 # 新增文段节点
node = embedding_manager.paragraphs_embedding_store.store.get(node_hash) node = embedding_manager.paragraphs_embedding_store.store.get(node_hash)
if node is None: if node is None:
@@ -359,7 +348,7 @@ class KGManager:
# 关系三元组 # 关系三元组
triple = relation[2:-2].split("', '") triple = relation[2:-2].split("', '")
for ent in [(triple[0]), (triple[2])]: for ent in [(triple[0]), (triple[2])]:
ent_hash = local_storage["ent_namespace"] + "-" + get_sha256(ent) ent_hash = "entity" + "-" + get_sha256(ent)
if ent_hash in existed_nodes: # 该实体需在KG中存在 if ent_hash in existed_nodes: # 该实体需在KG中存在
if ent_hash not in ent_sim_scores: # 尚未记录的实体 if ent_hash not in ent_sim_scores: # 尚未记录的实体
ent_sim_scores[ent_hash] = [] ent_sim_scores[ent_hash] = []
@@ -380,7 +369,7 @@ class KGManager:
for ent_hash in ent_weights.keys(): for ent_hash in ent_weights.keys():
ent_weights[ent_hash] = 1.0 ent_weights[ent_hash] = 1.0
else: else:
down_edge = global_config["qa"]["params"]["paragraph_node_weight"] down_edge = global_config.lpmm_knowledge.qa_paragraph_node_weight
# 缩放取值区间至[down_edge, 1] # 缩放取值区间至[down_edge, 1]
for ent_hash, score in ent_weights.items(): for ent_hash, score in ent_weights.items():
# 缩放相似度 # 缩放相似度
@@ -389,7 +378,7 @@ class KGManager:
) + down_edge ) + down_edge
# 取平均相似度的top_k实体 # 取平均相似度的top_k实体
top_k = global_config["qa"]["params"]["ent_filter_top_k"] top_k = global_config.lpmm_knowledge.qa_ent_filter_top_k
if len(ent_mean_scores) > top_k: if len(ent_mean_scores) > top_k:
# 从大到小排序取后len - k个 # 从大到小排序取后len - k个
ent_mean_scores = {k: v for k, v in sorted(ent_mean_scores.items(), key=lambda item: item[1], reverse=True)} ent_mean_scores = {k: v for k, v in sorted(ent_mean_scores.items(), key=lambda item: item[1], reverse=True)}
@@ -418,7 +407,7 @@ class KGManager:
for pg_hash, score in pg_sim_scores.items(): for pg_hash, score in pg_sim_scores.items():
pg_weights[pg_hash] = ( pg_weights[pg_hash] = (
score * global_config["qa"]["params"]["paragraph_node_weight"] score * global_config.lpmm_knowledge.qa_paragraph_node_weight
) # 文段权重 = 归一化相似度 * 文段节点权重参数 ) # 文段权重 = 归一化相似度 * 文段节点权重参数
del pg_sim_scores del pg_sim_scores
@@ -431,7 +420,7 @@ class KGManager:
self.graph, self.graph,
personalization=ppr_node_weights, personalization=ppr_node_weights,
max_iter=100, max_iter=100,
alpha=global_config["qa"]["params"]["ppr_damping"], alpha=global_config.lpmm_knowledge.qa_ppr_damping,
) )
# 获取最终结果 # 获取最终结果
@@ -439,7 +428,7 @@ class KGManager:
passage_node_res = [ passage_node_res = [
(node_key, score) (node_key, score)
for node_key, score in ppr_res.items() for node_key, score in ppr_res.items()
if node_key.startswith(local_storage["pg_namespace"]) if node_key.startswith("paragraph")
] ]
del ppr_res del ppr_res

View File

@@ -1,12 +1,8 @@
from src.chat.knowledge.lpmmconfig import global_config
from src.chat.knowledge.embedding_store import EmbeddingManager from src.chat.knowledge.embedding_store import EmbeddingManager
from src.chat.knowledge.llm_client import LLMClient
from src.chat.knowledge.mem_active_manager import MemoryActiveManager
from src.chat.knowledge.qa_manager import QAManager from src.chat.knowledge.qa_manager import QAManager
from src.chat.knowledge.kg_manager import KGManager from src.chat.knowledge.kg_manager import KGManager
from src.chat.knowledge.global_logger import logger from src.chat.knowledge.global_logger import logger
from src.config.config import global_config as bot_global_config from src.config.config import global_config
from src.manager.local_store_manager import local_storage
import os import os
INVALID_ENTITY = [ INVALID_ENTITY = [
@@ -21,9 +17,6 @@ INVALID_ENTITY = [
"她们", "她们",
"它们", "它们",
] ]
PG_NAMESPACE = "paragraph"
ENT_NAMESPACE = "entity"
REL_NAMESPACE = "relation"
RAG_GRAPH_NAMESPACE = "rag-graph" RAG_GRAPH_NAMESPACE = "rag-graph"
RAG_ENT_CNT_NAMESPACE = "rag-ent-cnt" RAG_ENT_CNT_NAMESPACE = "rag-ent-cnt"
@@ -34,67 +27,13 @@ ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..",
DATA_PATH = os.path.join(ROOT_PATH, "data") DATA_PATH = os.path.join(ROOT_PATH, "data")
def _initialize_knowledge_local_storage():
"""
初始化知识库相关的本地存储配置
使用字典批量设置避免重复的if判断
"""
# 定义所有需要初始化的配置项
default_configs = {
# 路径配置
"root_path": ROOT_PATH,
"data_path": f"{ROOT_PATH}/data",
# 实体和命名空间配置
"lpmm_invalid_entity": INVALID_ENTITY,
"pg_namespace": PG_NAMESPACE,
"ent_namespace": ENT_NAMESPACE,
"rel_namespace": REL_NAMESPACE,
# RAG相关命名空间配置
"rag_graph_namespace": RAG_GRAPH_NAMESPACE,
"rag_ent_cnt_namespace": RAG_ENT_CNT_NAMESPACE,
"rag_pg_hash_namespace": RAG_PG_HASH_NAMESPACE,
}
# 日志级别映射重要配置用info其他用debug
important_configs = {"root_path", "data_path"}
# 批量设置配置项
initialized_count = 0
for key, default_value in default_configs.items():
if local_storage[key] is None:
local_storage[key] = default_value
# 根据重要性选择日志级别
if key in important_configs:
logger.info(f"设置{key}: {default_value}")
else:
logger.debug(f"设置{key}: {default_value}")
initialized_count += 1
if initialized_count > 0:
logger.info(f"知识库本地存储初始化完成,共设置 {initialized_count} 项配置")
else:
logger.debug("知识库本地存储配置已存在,跳过初始化")
# 初始化本地存储路径
# sourcery skip: dict-comprehension
_initialize_knowledge_local_storage()
qa_manager = None qa_manager = None
inspire_manager = None inspire_manager = None
# 检查LPMM知识库是否启用 # 检查LPMM知识库是否启用
if bot_global_config.lpmm_knowledge.enable: if global_config.lpmm_knowledge.enable:
logger.info("正在初始化Mai-LPMM") logger.info("正在初始化Mai-LPMM")
logger.info("创建LLM客户端") logger.info("创建LLM客户端")
llm_client_list = {}
for key in global_config["llm_providers"]:
llm_client_list[key] = LLMClient(
global_config["llm_providers"][key]["base_url"], # type: ignore
global_config["llm_providers"][key]["api_key"], # type: ignore
)
# 初始化Embedding库 # 初始化Embedding库
embed_manager = EmbeddingManager() embed_manager = EmbeddingManager()
@@ -120,7 +59,8 @@ if bot_global_config.lpmm_knowledge.enable:
# 数据比对Embedding库与KG的段落hash集合 # 数据比对Embedding库与KG的段落hash集合
for pg_hash in kg_manager.stored_paragraph_hashes: for pg_hash in kg_manager.stored_paragraph_hashes:
key = f"{PG_NAMESPACE}-{pg_hash}" # 使用与EmbeddingStore中一致的命名空间格式
key = f"paragraph-{pg_hash}"
if key not in embed_manager.stored_pg_hashes: if key not in embed_manager.stored_pg_hashes:
logger.warning(f"KG中存在Embedding库中不存在的段落{key}") logger.warning(f"KG中存在Embedding库中不存在的段落{key}")
@@ -130,11 +70,11 @@ if bot_global_config.lpmm_knowledge.enable:
kg_manager, kg_manager,
) )
# 记忆激活(用于记忆库) # # 记忆激活(用于记忆库)
inspire_manager = MemoryActiveManager( # inspire_manager = MemoryActiveManager(
embed_manager, # embed_manager,
llm_client_list[global_config["embedding"]["provider"]], # llm_client_list[global_config["embedding"]["provider"]],
) # )
else: else:
logger.info("LPMM知识库已禁用跳过初始化") logger.info("LPMM知识库已禁用跳过初始化")
# 创建空的占位符对象,避免导入错误 # 创建空的占位符对象,避免导入错误

View File

@@ -1,45 +0,0 @@
from openai import OpenAI
class LLMMessage:
    """A single chat message: a speaker role together with its content."""

    def __init__(self, role, content):
        self.role = role
        self.content = content

    def to_dict(self):
        """Return the message as a ``{"role": ..., "content": ...}`` dict."""
        return {"role": self.role, "content": self.content}
class LLMClient:
    """LLM client bound to one API provider (one base URL plus API key)."""

    def __init__(self, url, api_key):
        self.client = OpenAI(
            base_url=url,
            api_key=api_key,
        )

    @staticmethod
    def _split_reasoning(text):
        """Separate inline ``<think>...</think>`` reasoning from the answer.

        Only the text after the last ``<think>`` is considered. Returns
        ``(reasoning_content, content)``; the reasoning is ``None`` when there
        is no closing tag, and both are ``None`` when ``text`` is ``None``.
        """
        if text is None:
            return None, None
        tail = text.split("<think>")[-1]
        # Split on the FIRST closing tag only, so a later "</think>" inside the
        # answer can never cause the reasoning to be returned as the answer.
        reasoning, closing_tag, answer = tail.partition("</think>")
        if not closing_tag:
            return None, tail
        return reasoning, answer

    def send_chat_request(self, model, messages):
        """Send a non-streaming chat request and wait for the result.

        Returns:
            A ``(reasoning_content, content)`` tuple.
        """
        response = self.client.chat.completions.create(model=model, messages=messages, stream=False)
        message = response.choices[0].message
        if hasattr(message, "reasoning_content"):
            # The reply carries its reasoning in a dedicated field.
            return message.reasoning_content, message.content
        # No dedicated field: any reasoning is embedded in the text itself.
        return self._split_reasoning(message.content)

    def send_embedding_request(self, model, text):
        """Send an embedding request and return the embedding of ``text``."""
        # Newlines are replaced by spaces before the text is sent.
        text = text.replace("\n", " ")
        return self.client.embeddings.create(input=[text], model=model).data[0].embedding

View File

@@ -1,137 +0,0 @@
import os
import toml
import sys
# import argparse
from .global_logger import logger
# Namespace names for paragraphs, entities and relations.
PG_NAMESPACE = "paragraph"
ENT_NAMESPACE = "entity"
REL_NAMESPACE = "relation"
# Namespace names for the RAG graph, entity-count and paragraph-hash data.
RAG_GRAPH_NAMESPACE = "rag-graph"
RAG_ENT_CNT_NAMESPACE = "rag-ent-cnt"
RAG_PG_HASH_NAMESPACE = "rag-pg-hash"
# Invalid entities
# NOTE(review): the five identical "" entries below look like single-character
# pronouns lost in transcription (the remaining entries are plural pronouns) —
# confirm against the original file.
INVALID_ENTITY = [
"",
"",
"",
"",
"",
"我们",
"你们",
"他们",
"她们",
"它们",
]
def _load_config(config, config_file_path):
"""读取TOML格式的配置文件"""
if not os.path.exists(config_file_path):
return
with open(config_file_path, "r", encoding="utf-8") as f:
file_config = toml.load(f)
# Check if all top-level keys from default config exist in the file config
for key in config.keys():
if key not in file_config:
logger.critical(f"警告: 配置文件 '{config_file_path}' 缺少必需的顶级键: '{key}'。请检查配置文件。")
logger.critical("请通过template/lpmm_config_template.toml文件进行更新")
sys.exit(1)
if "llm_providers" in file_config:
for provider in file_config["llm_providers"]:
if provider["name"] not in config["llm_providers"]:
config["llm_providers"][provider["name"]] = {}
config["llm_providers"][provider["name"]]["base_url"] = provider["base_url"]
config["llm_providers"][provider["name"]]["api_key"] = provider["api_key"]
if "entity_extract" in file_config:
config["entity_extract"] = file_config["entity_extract"]
if "rdf_build" in file_config:
config["rdf_build"] = file_config["rdf_build"]
if "embedding" in file_config:
config["embedding"] = file_config["embedding"]
if "rag" in file_config:
config["rag"] = file_config["rag"]
if "qa" in file_config:
config["qa"] = file_config["qa"]
if "persistence" in file_config:
config["persistence"] = file_config["persistence"]
# print(config)
logger.info(f"从文件中读取配置: {config_file_path}")
# Built-in default configuration.
global_config = {
    "lpmm": {
        "version": "0.1.0",
    },
    # Providers keyed by name; each entry holds an endpoint and a key.
    # NOTE(review): the api_key looks like a placeholder — confirm that no real
    # credential is committed here.
    "llm_providers": {
        "localhost": {
            "base_url": "https://api.siliconflow.cn/v1",
            "api_key": "sk-ospynxadyorf",
        }
    },
    "entity_extract": {
        "llm": {
            "provider": "localhost",
            "model": "Pro/deepseek-ai/DeepSeek-V3",
        }
    },
    "rdf_build": {
        "llm": {
            "provider": "localhost",
            "model": "Pro/deepseek-ai/DeepSeek-V3",
        }
    },
    "embedding": {
        "provider": "localhost",
        "model": "Pro/BAAI/bge-m3",
        "dimension": 1024,
    },
    "rag": {
        "params": {
            "synonym_search_top_k": 10,
            "synonym_threshold": 0.75,
        }
    },
    "qa": {
        "params": {
            "relation_search_top_k": 10,
            "relation_threshold": 0.75,
            "paragraph_search_top_k": 10,
            "paragraph_node_weight": 0.05,
            "ent_filter_top_k": 10,
            "ppr_damping": 0.8,
            "res_top_k": 10,
        },
        "llm": {
            "provider": "localhost",
            "model": "qa",
        },
    },
    "persistence": {
        "data_root_path": "data",
        "raw_data_path": "data/raw.json",
        "openie_data_path": "data/openie.json",
        "embedding_data_dir": "data/embedding",
        "rag_data_dir": "data/rag",
    },
    "info_extraction": {
        "workers": 10,
    },
}
# Absolute path three directory levels above the directory containing this file.
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
# Location of the user's TOML configuration file under ROOT_PATH/config.
config_path = os.path.join(ROOT_PATH, "config", "lpmm_config.toml")
# Runs at import time: merges the file (if it exists) into global_config in place.
_load_config(global_config, config_path)

View File

@@ -1,3 +1,4 @@
raise DeprecationWarning("MemoryActiveManager is not used yet, please do not import it")
from .lpmmconfig import global_config from .lpmmconfig import global_config
from .embedding_store import EmbeddingManager from .embedding_store import EmbeddingManager
from .llm_client import LLMClient from .llm_client import LLMClient

View File

@@ -2,16 +2,14 @@ import time
from typing import Tuple, List, Dict, Optional from typing import Tuple, List, Dict, Optional
from .global_logger import logger from .global_logger import logger
# from . import prompt_template
from .embedding_store import EmbeddingManager from .embedding_store import EmbeddingManager
# from .llm_client import LLMClient
from .kg_manager import KGManager from .kg_manager import KGManager
# from .lpmmconfig import global_config # from .lpmmconfig import global_config
from .utils.dyn_topk import dyn_select_top_k from .utils.dyn_topk import dyn_select_top_k
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.chat.utils.utils import get_embedding from src.chat.utils.utils import get_embedding
from src.config.config import global_config from src.config.config import global_config, model_config
MAX_KNOWLEDGE_LENGTH = 10000 # 最大知识长度 MAX_KNOWLEDGE_LENGTH = 10000 # 最大知识长度
@@ -21,17 +19,14 @@ class QAManager:
self, self,
embed_manager: EmbeddingManager, embed_manager: EmbeddingManager,
kg_manager: KGManager, kg_manager: KGManager,
): ):
self.embed_manager = embed_manager self.embed_manager = embed_manager
self.kg_manager = kg_manager self.kg_manager = kg_manager
# TODO: API-Adapter修改标记 self.qa_model = LLMRequest(model_set=model_config.model_task_config.lpmm_qa, request_type="lpmm.qa")
self.qa_model = LLMRequest(
model=global_config.model.lpmm_qa,
request_type="lpmm.qa"
)
async def process_query(self, question: str) -> Tuple[List[Tuple[str, float, float]], Optional[Dict[str, float]]]: async def process_query(
self, question: str
) -> Optional[Tuple[List[Tuple[str, float, float]], Optional[Dict[str, float]]]]:
"""处理查询""" """处理查询"""
# 生成问题的Embedding # 生成问题的Embedding
@@ -49,66 +44,71 @@ class QAManager:
question_embedding, question_embedding,
global_config.lpmm_knowledge.qa_relation_search_top_k, global_config.lpmm_knowledge.qa_relation_search_top_k,
) )
if relation_search_res is not None: if relation_search_res is None:
# 过滤阈值 return None
# 考虑动态阈值:当存在显著数值差异的结果时,保留显著结果;否则,保留所有结果 # 过滤阈值
relation_search_res = dyn_select_top_k(relation_search_res, 0.5, 1.0) # 考虑动态阈值:当存在显著数值差异的结果时,保留显著结果;否则,保留所有结果
if relation_search_res[0][1] < global_config.lpmm_knowledge.qa_relation_threshold: relation_search_res = dyn_select_top_k(relation_search_res, 0.5, 1.0)
# 未找到相关关系 if not relation_search_res or relation_search_res[0][1] < global_config.lpmm_knowledge.qa_relation_threshold:
logger.debug("未找到相关关系,跳过关系检索") # 未找到相关关系
relation_search_res = [] logger.debug("未找到相关关系,跳过关系检索")
relation_search_res = []
part_end_time = time.perf_counter() part_end_time = time.perf_counter()
logger.debug(f"关系检索用时:{part_end_time - part_start_time:.5f}s") logger.debug(f"关系检索用时:{part_end_time - part_start_time:.5f}s")
for res in relation_search_res: for res in relation_search_res:
rel_str = self.embed_manager.relation_embedding_store.store.get(res[0]).str if store_item := self.embed_manager.relation_embedding_store.store.get(res[0]):
print(f"找到相关关系,相似度:{(res[1] * 100):.2f}% - {rel_str}") rel_str = store_item.str
print(f"找到相关关系,相似度:{(res[1] * 100):.2f}% - {rel_str}")
# TODO: 使用LLM过滤三元组结果 # TODO: 使用LLM过滤三元组结果
# logger.info(f"LLM过滤三元组用时{time.time() - part_start_time:.2f}s") # logger.info(f"LLM过滤三元组用时{time.time() - part_start_time:.2f}s")
# part_start_time = time.time() # part_start_time = time.time()
# 根据问题Embedding查询Paragraph Embedding库 # 根据问题Embedding查询Paragraph Embedding库
part_start_time = time.perf_counter()
paragraph_search_res = self.embed_manager.paragraphs_embedding_store.search_top_k(
question_embedding,
global_config.lpmm_knowledge.qa_paragraph_search_top_k,
)
part_end_time = time.perf_counter()
logger.debug(f"文段检索用时:{part_end_time - part_start_time:.5f}s")
if len(relation_search_res) != 0:
logger.info("找到相关关系将使用RAG进行检索")
# 使用KG检索
part_start_time = time.perf_counter() part_start_time = time.perf_counter()
paragraph_search_res = self.embed_manager.paragraphs_embedding_store.search_top_k( result, ppr_node_weights = self.kg_manager.kg_search(
question_embedding, relation_search_res, paragraph_search_res, self.embed_manager
global_config.lpmm_knowledge.qa_paragraph_search_top_k,
) )
part_end_time = time.perf_counter() part_end_time = time.perf_counter()
logger.debug(f"文段检索用时:{part_end_time - part_start_time:.5f}s") logger.info(f"RAG检索用时:{part_end_time - part_start_time:.5f}s")
if len(relation_search_res) != 0:
logger.info("找到相关关系将使用RAG进行检索")
# 使用KG检索
part_start_time = time.perf_counter()
result, ppr_node_weights = self.kg_manager.kg_search(
relation_search_res, paragraph_search_res, self.embed_manager
)
part_end_time = time.perf_counter()
logger.info(f"RAG检索用时{part_end_time - part_start_time:.5f}s")
else:
logger.info("未找到相关关系,将使用文段检索结果")
result = paragraph_search_res
ppr_node_weights = None
# 过滤阈值
result = dyn_select_top_k(result, 0.5, 1.0)
for res in result:
raw_paragraph = self.embed_manager.paragraphs_embedding_store.store[res[0]].str
print(f"找到相关文段,相关系数:{res[1]:.8f}\n{raw_paragraph}\n\n")
return result, ppr_node_weights
else: else:
return None logger.info("未找到相关关系,将使用文段检索结果")
result = paragraph_search_res
ppr_node_weights = None
async def get_knowledge(self, question: str) -> str: # 过滤阈值
result = dyn_select_top_k(result, 0.5, 1.0)
for res in result:
raw_paragraph = self.embed_manager.paragraphs_embedding_store.store[res[0]].str
print(f"找到相关文段,相关系数:{res[1]:.8f}\n{raw_paragraph}\n\n")
return result, ppr_node_weights
async def get_knowledge(self, question: str) -> Optional[str]:
"""获取知识""" """获取知识"""
# 处理查询 # 处理查询
processed_result = await self.process_query(question) processed_result = await self.process_query(question)
if processed_result is not None: if processed_result is not None:
query_res = processed_result[0] query_res = processed_result[0]
# 检查查询结果是否为空
if not query_res:
logger.debug("知识库查询结果为空,可能是知识库中没有相关内容")
return None
knowledge = [ knowledge = [
( (
self.embed_manager.paragraphs_embedding_store.store[res[0]].str, self.embed_manager.paragraphs_embedding_store.store[res[0]].str,

View File

@@ -1,48 +0,0 @@
import json
import os
from .global_logger import logger
from .lpmmconfig import global_config
from src.chat.knowledge.utils.hash import get_sha256
def load_raw_data(path: str = None) -> tuple[list[str], list[str]]:
    """Load the raw data file into memory.

    Non-string items and items whose SHA256 was already seen are skipped with
    a warning.

    Args:
        path: Optional path of the JSON file to read. When empty or ``None``,
            the configured ``persistence.raw_data_path`` is used.

    Returns:
        A ``(sha256_list, raw_data)`` tuple, in that order:
        - sha256_list: SHA256 hash of each kept item.
        - raw_data: the kept items, index-aligned with ``sha256_list``.

    Raises:
        FileNotFoundError: If the JSON file does not exist.
    """
    # Read the JSON file at the given path, or at the default path.
    json_path = path if path else global_config["persistence"]["raw_data_path"]
    if not os.path.exists(json_path):
        raise FileNotFoundError(f"原始数据文件读取失败: {json_path}")
    with open(json_path, "r", encoding="utf-8") as f:
        import_json = json.load(f)

    # Example of import_json:
    # ["The capital of China is Beijing. The capital of France is Paris.",]
    raw_data = []
    sha256_list = []
    seen_hashes = set()  # O(1) duplicate check; sha256_list keeps the order.
    for item in import_json:
        if not isinstance(item, str):
            logger.warning(f"数据类型错误:{item}")
            continue
        pg_hash = get_sha256(item)
        if pg_hash in seen_hashes:
            logger.warning(f"重复数据:{item}")
            continue
        seen_hashes.add(pg_hash)
        sha256_list.append(pg_hash)
        raw_data.append(item)
    logger.info(f"共读取到{len(raw_data)}条数据")

    return sha256_list, raw_data

View File

@@ -5,6 +5,10 @@ def dyn_select_top_k(
score: List[Tuple[Any, float]], jmp_factor: float, var_factor: float score: List[Tuple[Any, float]], jmp_factor: float, var_factor: float
) -> List[Tuple[Any, float, float]]: ) -> List[Tuple[Any, float, float]]:
"""动态TopK选择""" """动态TopK选择"""
# 检查输入列表是否为空
if not score:
return []
# 按照分数排序(降序) # 按照分数排序(降序)
sorted_score = sorted(score, key=lambda x: x[1], reverse=True) sorted_score = sorted(score, key=lambda x: x[1], reverse=True)

View File

@@ -1,17 +0,0 @@
import networkx as nx
from matplotlib import pyplot as plt
def draw_graph_and_show(graph):
    """Draw ``graph`` and show it on a 1280x1280 pixel canvas.

    Each node is labelled with its ``"content"`` attribute.
    """
    # 12.8 in x 12.8 in at 100 dpi gives the 1280x1280 canvas.
    fig = plt.figure(1, figsize=(12.8, 12.8), dpi=100)
    nx.draw_networkx(
        graph,
        node_size=100,
        width=0.5,
        with_labels=True,
        labels=nx.get_node_attributes(graph, "content"),
        font_family="Sarasa Mono SC",
        font_size=8,
    )
    fig.show()

File diff suppressed because it is too large Load Diff

View File

@@ -3,13 +3,16 @@ import time
import re import re
import json import json
import ast import ast
from json_repair import repair_json
from src.llm_models.utils_model import LLMRequest
from src.common.logger import get_logger
import traceback import traceback
from src.config.config import global_config from json_repair import repair_json
from datetime import datetime, timedelta
from src.llm_models.utils_model import LLMRequest
from src.common.logger import get_logger
from src.common.database.database_model import Memory # Peewee Models导入 from src.common.database.database_model import Memory # Peewee Models导入
from src.config.config import model_config
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -35,8 +38,7 @@ class InstantMemory:
self.chat_id = chat_id self.chat_id = chat_id
self.last_view_time = time.time() self.last_view_time = time.time()
self.summary_model = LLMRequest( self.summary_model = LLMRequest(
model=global_config.model.memory, model_set=model_config.model_task_config.utils,
temperature=0.5,
request_type="memory.summary", request_type="memory.summary",
) )
@@ -48,14 +50,11 @@ class InstantMemory:
""" """
try: try:
response, _ = await self.summary_model.generate_response_async(prompt) response, _ = await self.summary_model.generate_response_async(prompt, temperature=0.5)
print(prompt) print(prompt)
print(response) print(response)
if "1" in response: return "1" in response
return True
else:
return False
except Exception as e: except Exception as e:
logger.error(f"判断是否需要记忆出现错误:{str(e)} {traceback.format_exc()}") logger.error(f"判断是否需要记忆出现错误:{str(e)} {traceback.format_exc()}")
return False return False
@@ -71,9 +70,9 @@ class InstantMemory:
}} }}
""" """
try: try:
response, _ = await self.summary_model.generate_response_async(prompt) response, _ = await self.summary_model.generate_response_async(prompt, temperature=0.5)
print(prompt) # print(prompt)
print(response) # print(response)
if not response: if not response:
return None return None
try: try:
@@ -142,7 +141,7 @@ class InstantMemory:
请只输出json格式不要输出其他多余内容 请只输出json格式不要输出其他多余内容
""" """
try: try:
response, _ = await self.summary_model.generate_response_async(prompt) response, _ = await self.summary_model.generate_response_async(prompt, temperature=0.5)
print(prompt) print(prompt)
print(response) print(response)
if not response: if not response:
@@ -177,7 +176,7 @@ class InstantMemory:
for mem in query: for mem in query:
# 对每条记忆 # 对每条记忆
mem_keywords = mem.keywords or [] mem_keywords = mem.keywords or ""
parsed = ast.literal_eval(mem_keywords) parsed = ast.literal_eval(mem_keywords)
if isinstance(parsed, list): if isinstance(parsed, list):
mem_keywords = [str(k).strip() for k in parsed if str(k).strip()] mem_keywords = [str(k).strip() for k in parsed if str(k).strip()]
@@ -201,6 +200,7 @@ class InstantMemory:
return None return None
def _parse_time_range(self, time_str): def _parse_time_range(self, time_str):
# sourcery skip: extract-duplicate-method, use-contextlib-suppress
""" """
支持解析如下格式: 支持解析如下格式:
- 具体日期时间YYYY-MM-DD HH:MM:SS - 具体日期时间YYYY-MM-DD HH:MM:SS
@@ -208,8 +208,6 @@ class InstantMemory:
- 相对时间今天昨天前天N天前N个月前 - 相对时间今天昨天前天N天前N个月前
- 空字符串:返回(None, None) - 空字符串:返回(None, None)
""" """
from datetime import datetime, timedelta
now = datetime.now() now = datetime.now()
if not time_str: if not time_str:
return 0, now return 0, now
@@ -239,14 +237,12 @@ class InstantMemory:
start = (now - timedelta(days=2)).replace(hour=0, minute=0, second=0, microsecond=0) start = (now - timedelta(days=2)).replace(hour=0, minute=0, second=0, microsecond=0)
end = start + timedelta(days=1) end = start + timedelta(days=1)
return start, end return start, end
m = re.match(r"(\d+)天前", time_str) if m := re.match(r"(\d+)天前", time_str):
if m:
days = int(m.group(1)) days = int(m.group(1))
start = (now - timedelta(days=days)).replace(hour=0, minute=0, second=0, microsecond=0) start = (now - timedelta(days=days)).replace(hour=0, minute=0, second=0, microsecond=0)
end = start + timedelta(days=1) end = start + timedelta(days=1)
return start, end return start, end
m = re.match(r"(\d+)个月前", time_str) if m := re.match(r"(\d+)个月前", time_str):
if m:
months = int(m.group(1)) months = int(m.group(1))
# 近似每月30天 # 近似每月30天
start = (now - timedelta(days=months * 30)).replace(hour=0, minute=0, second=0, microsecond=0) start = (now - timedelta(days=months * 30)).replace(hour=0, minute=0, second=0, microsecond=0)

View File

@@ -1,13 +1,17 @@
import json
from json_repair import repair_json
from typing import List, Tuple
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config from src.config.config import global_config, model_config
from src.common.logger import get_logger from src.common.logger import get_logger
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from datetime import datetime
from src.chat.memory_system.Hippocampus import hippocampus_manager from src.chat.memory_system.Hippocampus import hippocampus_manager
from typing import List, Dict from src.chat.utils.utils import parse_keywords_string
import difflib from src.chat.utils.chat_message_builder import build_readable_messages
import json import random
from json_repair import repair_json
logger = get_logger("memory_activator") logger = get_logger("memory_activator")
@@ -38,20 +42,20 @@ def get_keywords_from_json(json_str) -> List:
def init_prompt(): def init_prompt():
# --- Group Chat Prompt --- # --- Group Chat Prompt ---
memory_activator_prompt = """ memory_activator_prompt = """
是一个记忆分析器,你需要根据以下信息来进行回忆 需要根据以下信息来挑选合适的记忆编号
以下是一段聊天记录,请根据这些信息,总结出几个关键词作为记忆回忆的触发词 以下是一段聊天记录,请根据这些信息,和下方的记忆,挑选和群聊内容有关的记忆编号
聊天记录: 聊天记录:
{obs_info_text} {obs_info_text}
你想要回复的消息: 你想要回复的消息:
{target_message} {target_message}
历史关键词(请避免重复提取这些关键词) 记忆
{cached_keywords} {memory_info}
请输出一个json格式包含以下字段 请输出一个json格式包含以下字段
{{ {{
"keywords": ["关键词1", "关键词2", "关键词3",......] "memory_ids": "记忆1编号,记忆2编号,记忆3编号,......"
}} }}
不要输出其他多余内容只输出json格式就好 不要输出其他多余内容只输出json格式就好
""" """
@@ -61,83 +65,197 @@ def init_prompt():
class MemoryActivator: class MemoryActivator:
def __init__(self): def __init__(self):
# TODO: API-Adapter修改标记
self.key_words_model = LLMRequest( self.key_words_model = LLMRequest(
model=global_config.model.utils_small, model_set=model_config.model_task_config.utils_small,
temperature=0.5,
request_type="memory.activator", request_type="memory.activator",
) )
# 用于记忆选择的 LLM 模型
self.memory_selection_model = LLMRequest(
model_set=model_config.model_task_config.utils_small,
request_type="memory.selection",
)
self.running_memory = []
self.cached_keywords = set() # 用于缓存历史关键词
async def activate_memory_with_chat_history(self, target_message, chat_history_prompt) -> List[Dict]: async def activate_memory_with_chat_history(self, target_message, chat_history_prompt) -> List[Tuple[str, str]]:
""" """
激活记忆 激活记忆
""" """
# 如果记忆系统被禁用,直接返回空列表 # 如果记忆系统被禁用,直接返回空列表
if not global_config.memory.enable_memory: if not global_config.memory.enable_memory:
return [] return []
# 将缓存的关键词转换为字符串用于prompt keywords_list = set()
cached_keywords_str = ", ".join(self.cached_keywords) if self.cached_keywords else "暂无历史关键词"
for msg in chat_history_prompt:
prompt = await global_prompt_manager.format_prompt( keywords = parse_keywords_string(msg.get("key_words", ""))
"memory_activator_prompt", if keywords:
obs_info_text=chat_history_prompt, if len(keywords_list) < 30:
target_message=target_message, # 最多容纳30个关键词
cached_keywords=cached_keywords_str, keywords_list.update(keywords)
) logger.debug(f"提取关键词: {keywords_list}")
else:
# logger.debug(f"prompt: {prompt}") break
response, (reasoning_content, model_name) = await self.key_words_model.generate_response_async(prompt) if not keywords_list:
logger.debug("没有提取到关键词,返回空记忆列表")
keywords = list(get_keywords_from_json(response)) return []
# 更新关键词缓存 # 从海马体获取相关记忆
if keywords:
# 限制缓存大小最多保留10个关键词
if len(self.cached_keywords) > 10:
# 转换为列表,移除最早的关键词
cached_list = list(self.cached_keywords)
self.cached_keywords = set(cached_list[-8:])
# 添加新的关键词到缓存
self.cached_keywords.update(keywords)
# 调用记忆系统获取相关记忆
related_memory = await hippocampus_manager.get_memory_from_topic( related_memory = await hippocampus_manager.get_memory_from_topic(
valid_keywords=keywords, max_memory_num=3, max_memory_length=2, max_depth=3 valid_keywords=list(keywords_list), max_memory_num=5, max_memory_length=3, max_depth=3
) )
logger.debug(f"当前记忆关键词: {self.cached_keywords} ") # logger.info(f"当前记忆关键词: {keywords_list}")
logger.debug(f"获取到的记忆: {related_memory}") logger.debug(f"获取到的记忆: {related_memory}")
if not related_memory:
logger.debug("海马体没有返回相关记忆")
return []
# 激活时所有已有记忆的duration+1达到3则移除 used_ids = set()
for m in self.running_memory[:]: candidate_memories = []
m["duration"] = m.get("duration", 1) + 1
self.running_memory = [m for m in self.running_memory if m["duration"] < 3]
if related_memory: # 为每个记忆分配随机ID并过滤相关记忆
for topic, memory in related_memory: for memory in related_memory:
# 检查是否已存在相同topic或相似内容相似度>=0.7)的记忆 keyword, content = memory
exists = any( found = False
m["topic"] == topic or difflib.SequenceMatcher(None, m["content"], memory).ratio() >= 0.7 for kw in keywords_list:
for m in self.running_memory if kw in content:
) found = True
if not exists: break
self.running_memory.append(
{"topic": topic, "content": memory, "timestamp": datetime.now().isoformat(), "duration": 1} if found:
) # 随机分配一个不重复的2位数id
logger.debug(f"添加新记忆: {topic} - {memory}") while True:
random_id = "{:02d}".format(random.randint(0, 99))
if random_id not in used_ids:
used_ids.add(random_id)
break
candidate_memories.append({"memory_id": random_id, "keyword": keyword, "content": content})
# 限制同时加载的记忆条数最多保留最后3条 if not candidate_memories:
if len(self.running_memory) > 3: logger.info("没有找到相关的候选记忆")
self.running_memory = self.running_memory[-3:] return []
# 如果只有少量记忆,直接返回
if len(candidate_memories) <= 2:
logger.debug(f"候选记忆较少({len(candidate_memories)}个),直接返回")
# 转换为 (keyword, content) 格式
return [(mem["keyword"], mem["content"]) for mem in candidate_memories]
# 使用 LLM 选择合适的记忆
selected_memories = await self._select_memories_with_llm(target_message, chat_history_prompt, candidate_memories)
return selected_memories
async def _select_memories_with_llm(self, target_message, chat_history_prompt, candidate_memories) -> List[Tuple[str, str]]:
    """
    Ask the LLM which candidate memories are relevant and return those memories.

    Args:
        target_message: The message the bot wants to reply to.
        chat_history_prompt: Chat history messages; rendered to text with
            ``build_readable_messages``.
        candidate_memories: Candidate memories. Each item is a dict with the keys
            ``memory_id``, ``keyword`` and ``content``.

    Returns:
        List[Tuple[str, str]]: The chosen memories as ``(keyword, content)`` pairs,
        in the order the LLM listed them and without duplicates. The list is empty
        when the reply cannot be parsed or names no known id. If building the prompt
        or calling the LLM raises, the first three candidates are returned instead.
    """
    try:
        # Render the chat history as readable text
        obs_info_text = build_readable_messages(
            chat_history_prompt,
            replace_bot_name=True,
            merge_messages=False,
            timestamp_mode="relative",
            read_mark=0.0,
            show_actions=True,
        )

        # Render one line per candidate memory
        memory_lines = []
        for memory in candidate_memories:
            memory_id = memory["memory_id"]
            keyword = memory["keyword"]
            content = memory["content"]

            # A list-valued content is flattened into a single string
            if isinstance(content, list):
                content_str = " | ".join(str(item) for item in content)
            else:
                content_str = str(content)

            memory_lines.append(f"记忆编号 {memory_id}: [关键词: {keyword}] {content_str}")

        memory_info = "\n".join(memory_lines)

        # Fetch and fill in the prompt template
        prompt_template = await global_prompt_manager.get_prompt_async("memory_activator_prompt")
        formatted_prompt = prompt_template.format(
            obs_info_text=obs_info_text,
            target_message=target_message,
            memory_info=memory_info,
        )

        # Call the LLM; only the text response is used
        response, (_reasoning_content, _model_name, _) = await self.memory_selection_model.generate_response_async(
            formatted_prompt,
            temperature=0.3,
            max_tokens=150,
        )

        if global_config.debug.show_prompt:
            logger.info(f"记忆选择 prompt: {formatted_prompt}")
            logger.info(f"LLM 记忆选择响应: {response}")
        else:
            logger.debug(f"记忆选择 prompt: {formatted_prompt}")
            logger.debug(f"LLM 记忆选择响应: {response}")

        # Parse the reply into a list of memory ids
        selected_memory_ids = []
        try:
            fixed_json = repair_json(response)
            # repair_json may hand back either a JSON string or an already-parsed object
            result = json.loads(fixed_json) if isinstance(fixed_json, str) else fixed_json

            raw_ids = result.get("memory_ids", "")
            if isinstance(raw_ids, (list, tuple)):
                # The model answered with a JSON array instead of a comma-separated string
                raw_tokens = [str(item) for item in raw_ids]
            elif raw_ids:
                # Accept full-width separators as well as the ASCII comma
                raw_tokens = str(raw_ids).replace(",", ",").replace("、", ",").split(",")
            else:
                raw_tokens = []

            for token in raw_tokens:
                mid = token.strip()
                # Drop empty and implausibly long ids
                if not mid or len(mid) > 3:
                    continue
                # Ids are handed out zero-padded to two digits; restore padding the model dropped
                if mid.isdigit():
                    mid = mid.zfill(2)
                # Keep the first occurrence only so a memory is never returned twice
                if mid not in selected_memory_ids:
                    selected_memory_ids.append(mid)
        except Exception as e:
            logger.error(f"解析记忆选择响应失败: {e}", exc_info=True)
            selected_memory_ids = []

        # Map ids back to candidates; ids the model invented are ignored
        memory_id_to_memory = {mem["memory_id"]: mem for mem in candidate_memories}
        selected_memories = [
            memory_id_to_memory[memory_id]
            for memory_id in selected_memory_ids
            if memory_id in memory_id_to_memory
        ]

        logger.info(f"LLM 选择的记忆编号: {selected_memory_ids}")
        logger.info(f"最终选择的记忆数量: {len(selected_memories)}")

        # Convert to (keyword, content) pairs
        return [(mem["keyword"], mem["content"]) for mem in selected_memories]

    except Exception as e:
        logger.error(f"LLM 选择记忆时出错: {e}", exc_info=True)
        # Best-effort fallback: the first three candidates as (keyword, content) pairs
        return [(mem["keyword"], mem["content"]) for mem in candidate_memories[:3]]
return self.running_memory
init_prompt() init_prompt()

View File

@@ -1,126 +0,0 @@
import numpy as np
from datetime import datetime, timedelta
from rich.traceback import install
install(extra_lines=3)
class MemoryBuildScheduler:
    """Generate past timestamps drawn from a mixture of two normal distributions.

    Each distribution describes "hours before now"; the weights decide how many
    of the ``total_samples`` points come from each one.
    """

    def __init__(self, n_hours1, std_hours1, weight1, n_hours2, std_hours2, weight2, total_samples=50):
        """
        Initialise the memory build scheduler.

        Args:
            n_hours1 (float): Mean of the first distribution (hours before now).
            std_hours1 (float): Standard deviation of the first distribution (hours).
            weight1 (float): Weight of the first distribution.
            n_hours2 (float): Mean of the second distribution (hours before now).
            std_hours2 (float): Standard deviation of the second distribution (hours).
            weight2 (float): Weight of the second distribution.
            total_samples (int): Total number of time points to generate.

        Raises:
            ValueError: If ``total_samples`` is not positive, a weight or standard
                deviation is negative, or both weights are zero.
        """
        # Validate arguments
        if total_samples <= 0:
            raise ValueError("total_samples 必须大于0")
        if weight1 < 0 or weight2 < 0:
            raise ValueError("权重必须为非负数")
        if std_hours1 < 0 or std_hours2 < 0:
            raise ValueError("标准差必须为非负数")

        # Normalise the weights so they sum to 1
        total_weight = weight1 + weight2
        if total_weight == 0:
            raise ValueError("权重总和不能为0")
        self.weight1 = weight1 / total_weight
        self.weight2 = weight2 / total_weight

        self.n_hours1 = n_hours1
        self.std_hours1 = std_hours1
        self.n_hours2 = n_hours2
        self.std_hours2 = std_hours2
        self.total_samples = total_samples
        # Reference "now"; fixed once at construction time
        self.base_time = datetime.now()

    def _split_sample_counts(self):
        """Return ``(samples1, samples2)``, which always sum to ``total_samples``."""
        total = self.total_samples
        # A zero-weight distribution contributes nothing
        if self.weight1 == 0:
            return 0, total
        if self.weight2 == 0:
            return total, 0
        # A single sample cannot be shared; give it to the heavier distribution
        if total == 1:
            return (1, 0) if self.weight1 >= self.weight2 else (0, 1)
        # Both distributions are active: each gets at least one sample
        samples1 = min(max(1, int(total * self.weight1)), total - 1)
        return samples1, total - samples1

    def generate_time_samples(self):
        """Generate the mixture's time points, sorted from earliest to most recent."""
        samples1, samples2 = self._split_sample_counts()

        # Draw the hour offsets of both normal distributions
        hours_offset1 = np.random.normal(loc=self.n_hours1, scale=self.std_hours1, size=samples1)
        hours_offset2 = np.random.normal(loc=self.n_hours2, scale=self.std_hours2, size=samples2)

        # Merge the offsets of both distributions
        hours_offset = np.concatenate([hours_offset1, hours_offset2])

        # abs() keeps every point in the past even when a draw is negative
        timestamps = [self.base_time - timedelta(hours=abs(offset)) for offset in hours_offset]

        # Sort by time (earliest first)
        return sorted(timestamps)

    def get_timestamp_array(self):
        """Return the generated time points as integer Unix timestamps."""
        timestamps = self.generate_time_samples()
        return [int(t.timestamp()) for t in timestamps]
# def print_time_samples(timestamps, show_distribution=True):
# """打印时间样本和分布信息"""
# print(f"\n生成的{len(timestamps)}个时间点分布:")
# print("序号".ljust(5), "时间戳".ljust(25), "距现在(小时)")
# print("-" * 50)
# now = datetime.now()
# time_diffs = []
# for i, timestamp in enumerate(timestamps, 1):
# hours_diff = (now - timestamp).total_seconds() / 3600
# time_diffs.append(hours_diff)
# print(f"{str(i).ljust(5)} {timestamp.strftime('%Y-%m-%d %H:%M:%S').ljust(25)} {hours_diff:.2f}")
# # 打印统计信息
# print("\n统计信息")
# print(f"平均时间偏移:{np.mean(time_diffs):.2f}小时")
# print(f"标准差:{np.std(time_diffs):.2f}小时")
# print(f"最早时间:{min(timestamps).strftime('%Y-%m-%d %H:%M:%S')} ({max(time_diffs):.2f}小时前)")
# print(f"最近时间:{max(timestamps).strftime('%Y-%m-%d %H:%M:%S')} ({min(time_diffs):.2f}小时前)")
# if show_distribution:
# # 计算时间分布的直方图
# hist, bins = np.histogram(time_diffs, bins=40)
# print("\n时间分布每个*代表一个时间点):")
# for i in range(len(hist)):
# if hist[i] > 0:
# print(f"{bins[i]:6.1f}-{bins[i + 1]:6.1f}小时: {'*' * int(hist[i])}")
# # 使用示例
# if __name__ == "__main__":
# # 创建一个双峰分布的记忆调度器
# scheduler = MemoryBuildScheduler(
# n_hours1=12, # 第一个分布均值12小时前
# std_hours1=8, # 第一个分布标准差
# weight1=0.7, # 第一个分布权重 70%
# n_hours2=36, # 第二个分布均值36小时前
# std_hours2=24, # 第二个分布标准差
# weight2=0.3, # 第二个分布权重 30%
# total_samples=50, # 总共生成50个时间点
# )
# # 生成时间分布
# timestamps = scheduler.generate_time_samples()
# # 打印结果,包含分布可视化
# print_time_samples(timestamps, show_distribution=True)
# # 打印时间戳数组
# timestamp_array = scheduler.get_timestamp_array()
# print("\n时间戳数组Unix时间戳")
# print("[", end="")
# for i, ts in enumerate(timestamp_array):
# if i > 0:
# print(", ", end="")
# print(ts, end="")
# print("]")

View File

@@ -16,6 +16,7 @@ from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.plugin_system.core import component_registry, events_manager, global_announcement_manager from src.plugin_system.core import component_registry, events_manager, global_announcement_manager
from src.plugin_system.base import BaseCommand, EventType from src.plugin_system.base import BaseCommand, EventType
from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor
from src.person_info.person_info import Person
# 定义日志配置 # 定义日志配置
@@ -146,7 +147,10 @@ class ChatBot:
async def hanle_notice_message(self, message: MessageRecv): async def hanle_notice_message(self, message: MessageRecv):
if message.message_info.message_id == "notice": if message.message_info.message_id == "notice":
logger.info("收到notice消息暂时不支持处理") message.is_notify = True
logger.info("notice消息")
# print(message)
return True return True
async def do_s4u(self, message_data: Dict[str, Any]): async def do_s4u(self, message_data: Dict[str, Any]):
@@ -165,6 +169,8 @@ class ChatBot:
# 处理消息内容 # 处理消息内容
await message.process() await message.process()
_ = Person.register_person(platform=message.message_info.platform, user_id=message.message_info.user_info.user_id,nickname=user_info.user_nickname) # type: ignore
await self.s4u_message_processor.process_message(message) await self.s4u_message_processor.process_message(message)
@@ -207,7 +213,8 @@ class ChatBot:
message = MessageRecv(message_data) message = MessageRecv(message_data)
if await self.hanle_notice_message(message): if await self.hanle_notice_message(message):
return # return
pass
group_info = message.message_info.group_info group_info = message.message_info.group_info
user_info = message.message_info.user_info user_info = message.message_info.user_info

View File

@@ -217,7 +217,8 @@ class ChatManager:
# 更新用户信息和群组信息 # 更新用户信息和群组信息
stream.update_active_time() stream.update_active_time()
stream = copy.deepcopy(stream) # 返回副本以避免外部修改影响缓存 stream = copy.deepcopy(stream) # 返回副本以避免外部修改影响缓存
stream.user_info = user_info if user_info and user_info.platform and user_info.user_id:
stream.user_info = user_info
if group_info: if group_info:
stream.group_info = group_info stream.group_info = group_info
from .message import MessageRecv # 延迟导入,避免循环引用 from .message import MessageRecv # 延迟导入,避免循环引用

View File

@@ -4,7 +4,7 @@ import urllib3
from abc import abstractmethod from abc import abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from rich.traceback import install from rich.traceback import install
from typing import Optional, Any from typing import Optional, Any, List
from maim_message import Seg, UserInfo, BaseMessageInfo, MessageBase from maim_message import Seg, UserInfo, BaseMessageInfo, MessageBase
from src.common.logger import get_logger from src.common.logger import get_logger
@@ -29,7 +29,6 @@ class Message(MessageBase):
chat_stream: "ChatStream" = None # type: ignore chat_stream: "ChatStream" = None # type: ignore
reply: Optional["Message"] = None reply: Optional["Message"] = None
processed_plain_text: str = "" processed_plain_text: str = ""
memorized_times: int = 0
def __init__( def __init__(
self, self,
@@ -109,12 +108,16 @@ class MessageRecv(Message):
self.has_picid = False self.has_picid = False
self.is_voice = False self.is_voice = False
self.is_mentioned = None self.is_mentioned = None
self.is_notify = False
self.is_command = False self.is_command = False
self.priority_mode = "interest" self.priority_mode = "interest"
self.priority_info = None self.priority_info = None
self.interest_value: float = None # type: ignore self.interest_value: float = None # type: ignore
self.key_words = []
self.key_words_lite = []
def update_chat_stream(self, chat_stream: "ChatStream"): def update_chat_stream(self, chat_stream: "ChatStream"):
self.chat_stream = chat_stream self.chat_stream = chat_stream
@@ -203,7 +206,7 @@ class MessageRecvS4U(MessageRecv):
self.is_superchat = False self.is_superchat = False
self.gift_info = None self.gift_info = None
self.gift_name = None self.gift_name = None
self.gift_count = None self.gift_count: Optional[str] = None
self.superchat_info = None self.superchat_info = None
self.superchat_price = None self.superchat_price = None
self.superchat_message_text = None self.superchat_message_text = None
@@ -369,7 +372,7 @@ class MessageProcessBase(Message):
return "[图片,网卡了加载不出来]" return "[图片,网卡了加载不出来]"
elif seg.type == "emoji": elif seg.type == "emoji":
if isinstance(seg.data, str): if isinstance(seg.data, str):
return await get_image_manager().get_emoji_description(seg.data) return await get_image_manager().get_emoji_tag(seg.data)
return "[表情,网卡了加载不出来]" return "[表情,网卡了加载不出来]"
elif seg.type == "voice": elif seg.type == "voice":
if isinstance(seg.data, str): if isinstance(seg.data, str):
@@ -399,34 +402,6 @@ class MessageProcessBase(Message):
return f"[{timestamp}]{name} 说:{self.processed_plain_text}\n" return f"[{timestamp}]{name} 说:{self.processed_plain_text}\n"
@dataclass
class MessageThinking(MessageProcessBase):
    """Message in the "thinking" state; it is built without a message segment."""

    def __init__(
        self,
        message_id: str,
        chat_stream: "ChatStream",
        bot_user_info: UserInfo,
        reply: Optional["MessageRecv"] = None,
        thinking_start_time: float = 0,
        timestamp: Optional[float] = None,
    ):
        # Delegate to the parent initialiser, forwarding the timestamp
        super().__init__(
            message_id=message_id,
            chat_stream=chat_stream,
            bot_user_info=bot_user_info,
            message_segment=None,  # the thinking state needs no message segment
            reply=reply,
            thinking_start_time=thinking_start_time,
            timestamp=timestamp,
        )

        # Attribute specific to the thinking state
        self.interrupt = False
@dataclass @dataclass
class MessageSending(MessageProcessBase): class MessageSending(MessageProcessBase):
"""发送状态的消息类""" """发送状态的消息类"""
@@ -444,7 +419,8 @@ class MessageSending(MessageProcessBase):
is_emoji: bool = False, is_emoji: bool = False,
thinking_start_time: float = 0, thinking_start_time: float = 0,
apply_set_reply_logic: bool = False, apply_set_reply_logic: bool = False,
reply_to: str = None, # type: ignore reply_to: Optional[str] = None,
selected_expressions:List[int] = None,
): ):
# 调用父类初始化 # 调用父类初始化
super().__init__( super().__init__(
@@ -469,6 +445,8 @@ class MessageSending(MessageProcessBase):
self.display_message = display_message self.display_message = display_message
self.interest_value = 0.0 self.interest_value = 0.0
self.selected_expressions = selected_expressions
def build_reply(self): def build_reply(self):
"""设置回复消息""" """设置回复消息"""
@@ -487,26 +465,6 @@ class MessageSending(MessageProcessBase):
if self.message_segment: if self.message_segment:
self.processed_plain_text = await self._process_message_segments(self.message_segment) self.processed_plain_text = await self._process_message_segments(self.message_segment)
# @classmethod
# def from_thinking(
# cls,
# thinking: MessageThinking,
# message_segment: Seg,
# is_head: bool = False,
# is_emoji: bool = False,
# ) -> "MessageSending":
# """从思考状态消息创建发送状态消息"""
# return cls(
# message_id=thinking.message_info.message_id,
# chat_stream=thinking.chat_stream,
# message_segment=message_segment,
# bot_user_info=thinking.message_info.user_info,
# reply=thinking.reply,
# is_head=is_head,
# is_emoji=is_emoji,
# sender_info=None,
# )
def to_dict(self): def to_dict(self):
ret = super().to_dict() ret = super().to_dict()
ret["message_info"]["user_info"] = self.chat_stream.user_info.to_dict() ret["message_info"]["user_info"] = self.chat_stream.user_info.to_dict()

View File

@@ -1,4 +1,5 @@
import re import re
import json
import traceback import traceback
from typing import Union from typing import Union
@@ -11,6 +12,23 @@ logger = get_logger("message_storage")
class MessageStorage: class MessageStorage:
@staticmethod
def _serialize_keywords(keywords) -> str:
"""将关键词列表序列化为JSON字符串"""
if isinstance(keywords, list):
return json.dumps(keywords, ensure_ascii=False)
return "[]"
@staticmethod
def _deserialize_keywords(keywords_str: str) -> list:
"""将JSON字符串反序列化为关键词列表"""
if not keywords_str:
return []
try:
return json.loads(keywords_str)
except (json.JSONDecodeError, TypeError):
return []
@staticmethod @staticmethod
async def store_message(message: Union[MessageSending, MessageRecv], chat_stream: ChatStream) -> None: async def store_message(message: Union[MessageSending, MessageRecv], chat_stream: ChatStream) -> None:
"""存储消息到数据库""" """存储消息到数据库"""
@@ -43,7 +61,11 @@ class MessageStorage:
priority_info = {} priority_info = {}
is_emoji = False is_emoji = False
is_picid = False is_picid = False
is_notify = False
is_command = False is_command = False
key_words = ""
key_words_lite = ""
selected_expressions = message.selected_expressions
else: else:
filtered_display_message = "" filtered_display_message = ""
interest_value = message.interest_value interest_value = message.interest_value
@@ -53,8 +75,13 @@ class MessageStorage:
priority_info = message.priority_info priority_info = message.priority_info
is_emoji = message.is_emoji is_emoji = message.is_emoji
is_picid = message.is_picid is_picid = message.is_picid
is_notify = message.is_notify
is_command = message.is_command is_command = message.is_command
# 序列化关键词列表为JSON字符串
key_words = MessageStorage._serialize_keywords(message.key_words)
key_words_lite = MessageStorage._serialize_keywords(message.key_words_lite)
selected_expressions = ""
chat_info_dict = chat_stream.to_dict() chat_info_dict = chat_stream.to_dict()
user_info_dict = message.message_info.user_info.to_dict() # type: ignore user_info_dict = message.message_info.user_info.to_dict() # type: ignore
@@ -92,13 +119,16 @@ class MessageStorage:
# Text content # Text content
processed_plain_text=filtered_processed_plain_text, processed_plain_text=filtered_processed_plain_text,
display_message=filtered_display_message, display_message=filtered_display_message,
memorized_times=message.memorized_times,
interest_value=interest_value, interest_value=interest_value,
priority_mode=priority_mode, priority_mode=priority_mode,
priority_info=priority_info, priority_info=priority_info,
is_emoji=is_emoji, is_emoji=is_emoji,
is_picid=is_picid, is_picid=is_picid,
is_notify=is_notify,
is_command=is_command, is_command=is_command,
key_words=key_words,
key_words_lite=key_words_lite,
selected_expressions=selected_expressions,
) )
except Exception: except Exception:
logger.exception("存储消息失败") logger.exception("存储消息失败")

View File

@@ -1,9 +1,10 @@
from typing import Dict, Optional, Type from typing import Dict, Optional, Type
from src.plugin_system.base.base_action import BaseAction
from src.chat.message_receive.chat_stream import ChatStream from src.chat.message_receive.chat_stream import ChatStream
from src.common.logger import get_logger from src.common.logger import get_logger
from src.plugin_system.core.component_registry import component_registry from src.plugin_system.core.component_registry import component_registry
from src.plugin_system.base.component_types import ComponentType, ActionInfo from src.plugin_system.base.component_types import ComponentType, ActionInfo
from src.plugin_system.base.base_action import BaseAction
logger = get_logger("action_manager") logger = get_logger("action_manager")

View File

@@ -5,7 +5,7 @@ import time
from typing import List, Any, Dict, TYPE_CHECKING, Tuple from typing import List, Any, Dict, TYPE_CHECKING, Tuple
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.config import global_config from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.chat.message_receive.chat_stream import get_chat_manager, ChatMessageContext from src.chat.message_receive.chat_stream import get_chat_manager, ChatMessageContext
from src.chat.planner_actions.action_manager import ActionManager from src.chat.planner_actions.action_manager import ActionManager
@@ -36,10 +36,7 @@ class ActionModifier:
self.action_manager = action_manager self.action_manager = action_manager
# 用于LLM判定的小模型 # 用于LLM判定的小模型
self.llm_judge = LLMRequest( self.llm_judge = LLMRequest(model_set=model_config.model_task_config.utils_small, request_type="action.judge")
model=global_config.model.utils_small,
request_type="action.judge",
)
# 缓存相关属性 # 缓存相关属性
self._llm_judge_cache = {} # 缓存LLM判定结果 self._llm_judge_cache = {} # 缓存LLM判定结果
@@ -130,8 +127,10 @@ class ActionModifier:
if all_removals: if all_removals:
removals_summary = " | ".join([f"{name}({reason})" for name, reason in all_removals]) removals_summary = " | ".join([f"{name}({reason})" for name, reason in all_removals])
available_actions = list(self.action_manager.get_using_actions().keys())
available_actions_text = "".join(available_actions) if available_actions else ""
logger.info( logger.info(
f"{self.log_prefix} 动作修改流程结束,最终可用动作: {list(self.action_manager.get_using_actions().keys())}||移除记录: {removals_summary}" f"{self.log_prefix} 当前可用动作: {available_actions_text}||移除: {removals_summary}"
) )
def _check_action_associated_types(self, all_actions: Dict[str, ActionInfo], chat_context: ChatMessageContext): def _check_action_associated_types(self, all_actions: Dict[str, ActionInfo], chat_context: ChatMessageContext):
@@ -438,4 +437,4 @@ class ActionModifier:
return True return True
else: else:
logger.debug(f"{self.log_prefix}动作 {action_name} 未匹配到任何关键词: {activation_keywords}") logger.debug(f"{self.log_prefix}动作 {action_name} 未匹配到任何关键词: {activation_keywords}")
return False return False

View File

@@ -1,13 +1,13 @@
import json import json
import time import time
import traceback import traceback
from typing import Dict, Any, Optional, Tuple from typing import Dict, Any, Optional, Tuple, List
from rich.traceback import install from rich.traceback import install
from datetime import datetime from datetime import datetime
from json_repair import repair_json from json_repair import repair_json
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config from src.config.config import global_config, model_config
from src.common.logger import get_logger from src.common.logger import get_logger
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.chat_message_builder import ( from src.chat.utils.chat_message_builder import (
@@ -36,17 +36,28 @@ def init_prompt():
{chat_context_description},以下是具体的聊天内容 {chat_context_description},以下是具体的聊天内容
{chat_content_block} {chat_content_block}
{moderation_prompt} {moderation_prompt}
现在请你根据{by_what}选择合适的action和触发action的消息: 现在请你根据聊天内容和用户的最新消息选择合适的action和触发action的消息:
{actions_before_now_block} {actions_before_now_block}
{no_action_block} {no_action_block}
动作reply
动作描述:参与聊天回复,发送文本进行表达
- 你想要闲聊或者随便附和
- 有人提到了你,但是你还没有回应
- {mentioned_bonus}
- 如果你刚刚进行了回复,不要对同一个话题重复回应
{{
"action": "reply",
"target_message_id":"想要回复的消息id",
"reason":"回复的原因"
}}
{action_options_text} {action_options_text}
你必须从上面列出的可用action中选择一个并说明触发action的消息id不是消息原文和选择该action的原因。 你必须从上面列出的可用action中选择一个并说明触发action的消息id不是消息原文和选择该action的原因。消息id格式:m+数字
请根据动作示例,以严格的 JSON 格式输出,且仅包含 JSON 内容: 请根据动作示例,以严格的 JSON 格式输出,且仅包含 JSON 内容:
""", """,
@@ -59,7 +70,8 @@ def init_prompt():
动作描述:{action_description} 动作描述:{action_description}
{action_require} {action_require}
{{ {{
"action": "{action_name}",{action_parameters}{target_prompt} "action": "{action_name}",{action_parameters},
"target_message_id":"触发action的消息id",
"reason":"触发action的原因" "reason":"触发action的原因"
}} }}
""", """,
@@ -74,14 +86,15 @@ class ActionPlanner:
self.action_manager = action_manager self.action_manager = action_manager
# LLM规划器配置 # LLM规划器配置
self.planner_llm = LLMRequest( self.planner_llm = LLMRequest(
model=global_config.model.planner, model_set=model_config.model_task_config.planner, request_type="planner"
request_type="planner", # 用于动作规划 ) # 用于动作规划
)
self.last_obs_time_mark = 0.0 self.last_obs_time_mark = 0.0
# 添加重试计数器
self.plan_retry_count = 0
self.max_plan_retries = 3
def find_message_by_id(self, message_id: str, message_id_list: list) -> Optional[Dict[str, Any]]: def find_message_by_id(self, message_id: str, message_id_list: list) -> Optional[Dict[str, Any]]:
# sourcery skip: use-next
""" """
根据message_id从message_id_list中查找对应的原始消息 根据message_id从message_id_list中查找对应的原始消息
@@ -97,37 +110,41 @@ class ActionPlanner:
return item.get("message") return item.get("message")
return None return None
def get_latest_message(self, message_id_list: list) -> Optional[Dict[str, Any]]:
    """
    Return the newest message held in a message-id list.

    Args:
        message_id_list: Entries shaped like ``[{'id': str, 'message': dict}, ...]``.

    Returns:
        The ``"message"`` value of the last entry, or None when the list is empty.
    """
    # The list is assumed to be in chronological order, so the tail entry is the newest
    return message_id_list[-1].get("message") if message_id_list else None
async def plan( async def plan(
self, mode: ChatMode = ChatMode.FOCUS self,
) -> Tuple[Dict[str, Dict[str, Any] | str], Optional[Dict[str, Any]]]: mode: ChatMode = ChatMode.FOCUS,
loop_start_time:float = 0.0,
available_actions: Optional[Dict[str, ActionInfo]] = None,
) -> Tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]:
""" """
规划器 (Planner): 使用LLM根据上下文决定做出什么动作。 规划器 (Planner): 使用LLM根据上下文决定做出什么动作。
""" """
action = "no_reply" # 默认动作 action = "no_action" # 默认动作
reasoning = "规划器初始化默认" reasoning = "规划器初始化默认"
action_data = {} action_data = {}
current_available_actions: Dict[str, ActionInfo] = {} current_available_actions: Dict[str, ActionInfo] = {}
target_message: Optional[Dict[str, Any]] = None # 初始化target_message变量 target_message: Optional[Dict[str, Any]] = None # 初始化target_message变量
prompt: str = "" prompt: str = ""
message_id_list: list = []
try: try:
is_group_chat = True is_group_chat, chat_target_info, current_available_actions = self.get_necessary_info()
is_group_chat, chat_target_info = get_chat_type_and_target_info(self.chat_id)
logger.debug(f"{self.log_prefix}获取到聊天信息 - 群聊: {is_group_chat}, 目标信息: {chat_target_info}")
current_available_actions_dict = self.action_manager.get_using_actions()
# 获取完整的动作信息
all_registered_actions: Dict[str, ActionInfo] = component_registry.get_components_by_type( # type: ignore
ComponentType.ACTION
)
current_available_actions = {}
for action_name in current_available_actions_dict:
if action_name in all_registered_actions:
current_available_actions[action_name] = all_registered_actions[action_name]
else:
logger.warning(f"{self.log_prefix}使用中的动作 {action_name} 未在已注册动作中找到")
# --- 构建提示词 (调用修改后的 PromptBuilder 方法) --- # --- 构建提示词 (调用修改后的 PromptBuilder 方法) ---
prompt, message_id_list = await self.build_planner_prompt( prompt, message_id_list = await self.build_planner_prompt(
@@ -135,12 +152,13 @@ class ActionPlanner:
chat_target_info=chat_target_info, # <-- 传递获取到的聊天目标信息 chat_target_info=chat_target_info, # <-- 传递获取到的聊天目标信息
current_available_actions=current_available_actions, # <-- Pass determined actions current_available_actions=current_available_actions, # <-- Pass determined actions
mode=mode, mode=mode,
refresh_time=True,
) )
# --- 调用 LLM (普通文本生成) --- # --- 调用 LLM (普通文本生成) ---
llm_content = None llm_content = None
try: try:
llm_content, (reasoning_content, _) = await self.planner_llm.generate_response_async(prompt=prompt) llm_content, (reasoning_content, _, _) = await self.planner_llm.generate_response_async(prompt=prompt)
if global_config.debug.show_prompt: if global_config.debug.show_prompt:
logger.info(f"{self.log_prefix}规划器原始提示词: {prompt}") logger.info(f"{self.log_prefix}规划器原始提示词: {prompt}")
@@ -156,7 +174,7 @@ class ActionPlanner:
except Exception as req_e: except Exception as req_e:
logger.error(f"{self.log_prefix}LLM 请求执行失败: {req_e}") logger.error(f"{self.log_prefix}LLM 请求执行失败: {req_e}")
reasoning = f"LLM 请求失败,模型出现问题: {req_e}" reasoning = f"LLM 请求失败,模型出现问题: {req_e}"
action = "no_reply" action = "no_action"
if llm_content: if llm_content:
try: try:
@@ -173,68 +191,94 @@ class ActionPlanner:
logger.error(f"{self.log_prefix}解析后的JSON不是字典类型: {type(parsed_json)}") logger.error(f"{self.log_prefix}解析后的JSON不是字典类型: {type(parsed_json)}")
parsed_json = {} parsed_json = {}
action = parsed_json.get("action", "no_reply") action = parsed_json.get("action", "no_action")
reasoning = parsed_json.get("reasoning", "未提供原因") reasoning = parsed_json.get("reason", "未提供原因")
# 将所有其他属性添加到action_data # 将所有其他属性添加到action_data
for key, value in parsed_json.items(): for key, value in parsed_json.items():
if key not in ["action", "reasoning"]: if key not in ["action", "reasoning"]:
action_data[key] = value action_data[key] = value
# 在FOCUS模式下非no_reply动作需要target_message_id # 非no_action动作需要target_message_id
if mode == ChatMode.FOCUS and action != "no_reply": if action != "no_action":
if target_message_id := parsed_json.get("target_message_id"): if target_message_id := parsed_json.get("target_message_id"):
# 根据target_message_id查找原始消息 # 根据target_message_id查找原始消息
target_message = self.find_message_by_id(target_message_id, message_id_list) target_message = self.find_message_by_id(target_message_id, message_id_list)
# 如果获取的target_message为None输出warning并重新plan
if target_message is None:
self.plan_retry_count += 1
logger.warning(f"{self.log_prefix}无法找到target_message_id '{target_message_id}' 对应的消息,重试次数: {self.plan_retry_count}/{self.max_plan_retries}")
# 如果连续三次plan均为None输出error并选取最新消息
if self.plan_retry_count >= self.max_plan_retries:
logger.error(f"{self.log_prefix}连续{self.max_plan_retries}次plan获取target_message失败选择最新消息作为target_message")
target_message = self.get_latest_message(message_id_list)
self.plan_retry_count = 0 # 重置计数器
else:
# 递归重新plan
return await self.plan(mode, loop_start_time, available_actions)
else:
# 成功获取到target_message重置计数器
self.plan_retry_count = 0
else: else:
logger.warning(f"{self.log_prefix}FOCUS模式下动作'{action}'缺少target_message_id") logger.warning(f"{self.log_prefix}动作'{action}'缺少target_message_id")
if action == "no_action": if action != "no_action" and action != "reply" and action not in current_available_actions:
reasoning = "normal决定不使用额外动作"
elif action != "no_reply" and action != "reply" and action not in current_available_actions:
logger.warning( logger.warning(
f"{self.log_prefix}LLM 返回了当前不可用或无效的动作: '{action}' (可用: {list(current_available_actions.keys())}),将强制使用 'no_reply'" f"{self.log_prefix}LLM 返回了当前不可用或无效的动作: '{action}' (可用: {list(current_available_actions.keys())}),将强制使用 'no_action'"
) )
reasoning = f"LLM 返回了当前不可用的动作 '{action}' (可用: {list(current_available_actions.keys())})。原始理由: {reasoning}" reasoning = f"LLM 返回了当前不可用的动作 '{action}' (可用: {list(current_available_actions.keys())})。原始理由: {reasoning}"
action = "no_reply" action = "no_action"
except Exception as json_e: except Exception as json_e:
logger.warning(f"{self.log_prefix}解析LLM响应JSON失败 {json_e}. LLM原始输出: '{llm_content}'") logger.warning(f"{self.log_prefix}解析LLM响应JSON失败 {json_e}. LLM原始输出: '{llm_content}'")
traceback.print_exc() traceback.print_exc()
reasoning = f"解析LLM响应JSON失败: {json_e}. 将使用默认动作 'no_reply'." reasoning = f"解析LLM响应JSON失败: {json_e}. 将使用默认动作 'no_action'."
action = "no_reply" action = "no_action"
except Exception as outer_e: except Exception as outer_e:
logger.error(f"{self.log_prefix}Planner 处理过程中发生意外错误,规划失败,将执行 no_reply: {outer_e}") logger.error(f"{self.log_prefix}Planner 处理过程中发生意外错误,规划失败,将执行 no_action: {outer_e}")
traceback.print_exc() traceback.print_exc()
action = "no_reply" action = "no_action"
reasoning = f"Planner 内部处理错误: {outer_e}" reasoning = f"Planner 内部处理错误: {outer_e}"
is_parallel = False is_parallel = False
if mode == ChatMode.NORMAL and action in current_available_actions: if mode == ChatMode.NORMAL and action in current_available_actions:
is_parallel = current_available_actions[action].parallel_action is_parallel = current_available_actions[action].parallel_action
action_result = {
action_data["loop_start_time"] = loop_start_time
actions = []
# 1. 添加Planner取得的动作
actions.append({
"action_type": action, "action_type": action,
"action_data": action_data,
"reasoning": reasoning, "reasoning": reasoning,
"timestamp": time.time(), "action_data": action_data,
"is_parallel": is_parallel, "action_message": target_message,
} "available_actions": available_actions # 添加这个字段
})
return (
{ if action != "reply" and is_parallel:
"action_result": action_result, actions.append({
"action_prompt": prompt, "action_type": "reply",
}, "action_message": target_message,
target_message, "available_actions": available_actions
) })
return actions,target_message
async def build_planner_prompt( async def build_planner_prompt(
self, self,
is_group_chat: bool, # Now passed as argument is_group_chat: bool, # Now passed as argument
chat_target_info: Optional[dict], # Now passed as argument chat_target_info: Optional[dict], # Now passed as argument
current_available_actions: Dict[str, ActionInfo], current_available_actions: Dict[str, ActionInfo],
refresh_time :bool = False,
mode: ChatMode = ChatMode.FOCUS, mode: ChatMode = ChatMode.FOCUS,
) -> tuple[str, list]: # sourcery skip: use-join ) -> tuple[str, list]: # sourcery skip: use-join
"""构建 Planner LLM 的提示词 (获取模板并填充数据)""" """构建 Planner LLM 的提示词 (获取模板并填充数据)"""
@@ -265,43 +309,36 @@ class ActionPlanner:
) )
actions_before_now_block = f"你刚刚选择并执行过的action是\n{actions_before_now_block}" actions_before_now_block = f"你刚刚选择并执行过的action是\n{actions_before_now_block}"
if refresh_time:
self.last_obs_time_mark = time.time() self.last_obs_time_mark = time.time()
mentioned_bonus = ""
if global_config.chat.mentioned_bot_inevitable_reply:
mentioned_bonus = "\n- 有人提到你"
if global_config.chat.at_bot_inevitable_reply:
mentioned_bonus = "\n- 有人提到你或者at你"
if mode == ChatMode.FOCUS: if mode == ChatMode.FOCUS:
mentioned_bonus = "" no_action_block = """
if global_config.chat.mentioned_bot_inevitable_reply: 动作no_action
mentioned_bonus = "\n- 有人提到你" 动作描述:不进行动作,等待合适的时机
if global_config.chat.at_bot_inevitable_reply: - 当你刚刚发送了消息没有人回复时选择no_action
mentioned_bonus = "\n- 有人提到你或者at你" - 如果有别的动作非回复满足条件可以不用no_action
- 当你一次发送了太多消息为了避免打扰聊天节奏选择no_action
by_what = "聊天内容" {
target_prompt = '\n "target_message_id":"触发action的消息id"' "action": "no_action",
no_action_block = f"""重要说明: "reason":"不动作的原因"
- 'no_reply' 表示只进行不进行回复,等待合适的回复时机 }
- 当你刚刚发送了消息没有人回复时选择no_reply
- 当你一次发送了太多消息为了避免打扰聊天节奏选择no_reply
动作reply
动作描述:参与聊天回复,发送文本进行表达
- 你想要闲聊或者随便附和{mentioned_bonus}
- 如果你刚刚进行了回复,不要对同一个话题重复回应
{{
"action": "reply",
"target_message_id":"触发action的消息id",
"reason":"回复的原因"
}}
""" """
else: else:
by_what = "聊天内容和用户的最新消息"
target_prompt = ""
no_action_block = """重要说明: no_action_block = """重要说明:
- 'reply' 表示只进行普通聊天回复,不执行任何额外动作 - 'reply' 表示只进行普通聊天回复,不执行任何额外动作
- 其他action表示在普通回复的基础上执行相应的额外动作""" - 其他action表示在普通回复的基础上执行相应的额外动作
"""
chat_context_description = "你现在正在一个群聊中" chat_context_description = "你现在正在一个群聊中"
chat_target_name = None # Only relevant for private chat_target_name = None
if not is_group_chat and chat_target_info: if not is_group_chat and chat_target_info:
chat_target_name = ( chat_target_name = (
chat_target_info.get("person_name") or chat_target_info.get("user_nickname") or "对方" chat_target_info.get("person_name") or chat_target_info.get("user_nickname") or "对方"
@@ -330,7 +367,6 @@ class ActionPlanner:
action_description=using_actions_info.description, action_description=using_actions_info.description,
action_parameters=param_text, action_parameters=param_text,
action_require=require_text, action_require=require_text,
target_prompt=target_prompt,
) )
action_options_block += using_action_prompt action_options_block += using_action_prompt
@@ -350,11 +386,11 @@ class ActionPlanner:
planner_prompt_template = await global_prompt_manager.get_prompt_async("planner_prompt") planner_prompt_template = await global_prompt_manager.get_prompt_async("planner_prompt")
prompt = planner_prompt_template.format( prompt = planner_prompt_template.format(
time_block=time_block, time_block=time_block,
by_what=by_what,
chat_context_description=chat_context_description, chat_context_description=chat_context_description,
chat_content_block=chat_content_block, chat_content_block=chat_content_block,
actions_before_now_block=actions_before_now_block, actions_before_now_block=actions_before_now_block,
no_action_block=no_action_block, no_action_block=no_action_block,
mentioned_bonus=mentioned_bonus,
action_options_text=action_options_block, action_options_text=action_options_block,
moderation_prompt=moderation_prompt_block, moderation_prompt=moderation_prompt_block,
identity_block=identity_block, identity_block=identity_block,
@@ -365,5 +401,28 @@ class ActionPlanner:
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
return "构建 Planner Prompt 时出错", [] return "构建 Planner Prompt 时出错", []
def get_necessary_info(self) -> Tuple[bool, Optional[dict], Dict[str, ActionInfo]]:
    """
    Collect the chat context and the action set the planner works from.

    Returns:
        A 3-tuple ``(is_group_chat, chat_target_info, current_available_actions)``:
        the two values returned by ``get_chat_type_and_target_info`` for
        ``self.chat_id``, followed by a mapping from every in-use action name
        to its registered ``ActionInfo``.
    """
    is_group_chat, chat_target_info = get_chat_type_and_target_info(self.chat_id)
    logger.debug(f"{self.log_prefix}获取到聊天信息 - 群聊: {is_group_chat}, 目标信息: {chat_target_info}")

    using_action_names = self.action_manager.get_using_actions()
    # Full ActionInfo records for every ACTION component known to the registry.
    registered_actions: Dict[str, ActionInfo] = component_registry.get_components_by_type(  # type: ignore
        ComponentType.ACTION
    )

    usable_actions = {}
    for name in using_action_names:
        if name not in registered_actions:
            # In use but not registered: report it and leave it out of the result.
            logger.warning(f"{self.log_prefix}使用中的动作 {name} 未在已注册动作中找到")
            continue
        usable_actions[name] = registered_actions[name]

    return is_group_chat, chat_target_info, usable_actions
init_prompt() init_prompt()

File diff suppressed because it is too large Load Diff

View File

@@ -1,4 +1,4 @@
from typing import Dict, Any, Optional, List from typing import Dict, Optional
from src.common.logger import get_logger from src.common.logger import get_logger
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
@@ -15,7 +15,6 @@ class ReplyerManager:
self, self,
chat_stream: Optional[ChatStream] = None, chat_stream: Optional[ChatStream] = None,
chat_id: Optional[str] = None, chat_id: Optional[str] = None,
model_configs: Optional[List[Dict[str, Any]]] = None,
request_type: str = "replyer", request_type: str = "replyer",
) -> Optional[DefaultReplyer]: ) -> Optional[DefaultReplyer]:
""" """
@@ -49,7 +48,6 @@ class ReplyerManager:
# model_configs 只在此时(初始化时)生效 # model_configs 只在此时(初始化时)生效
replyer = DefaultReplyer( replyer = DefaultReplyer(
chat_stream=target_stream, chat_stream=target_stream,
model_configs=model_configs, # 可以是None此时使用默认模型
request_type=request_type, request_type=request_type,
) )
self._repliers[stream_id] = replyer self._repliers[stream_id] = replyer

View File

@@ -9,7 +9,7 @@ from src.config.config import global_config
from src.common.message_repository import find_messages, count_messages from src.common.message_repository import find_messages, count_messages
from src.common.database.database_model import ActionRecords from src.common.database.database_model import ActionRecords
from src.common.database.database_model import Images from src.common.database.database_model import Images
from src.person_info.person_info import PersonInfoManager, get_person_info_manager from src.person_info.person_info import Person,get_person_id
from src.chat.utils.utils import translate_timestamp_to_human_readable, assign_message_ids from src.chat.utils.utils import translate_timestamp_to_human_readable, assign_message_ids
install(extra_lines=3) install(extra_lines=3)
@@ -35,14 +35,12 @@ def replace_user_references_sync(
str: 处理后的内容字符串 str: 处理后的内容字符串
""" """
if name_resolver is None: if name_resolver is None:
person_info_manager = get_person_info_manager()
def default_resolver(platform: str, user_id: str) -> str: def default_resolver(platform: str, user_id: str) -> str:
# 检查是否是机器人自己 # 检查是否是机器人自己
if replace_bot_name and user_id == global_config.bot.qq_account: if replace_bot_name and user_id == global_config.bot.qq_account:
return f"{global_config.bot.nickname}(你)" return f"{global_config.bot.nickname}(你)"
person_id = PersonInfoManager.get_person_id(platform, user_id) person = Person(platform=platform, user_id=user_id)
return person_info_manager.get_value_sync(person_id, "person_name") or user_id # type: ignore return person.person_name or user_id # type: ignore
name_resolver = default_resolver name_resolver = default_resolver
@@ -110,14 +108,12 @@ async def replace_user_references_async(
str: 处理后的内容字符串 str: 处理后的内容字符串
""" """
if name_resolver is None: if name_resolver is None:
person_info_manager = get_person_info_manager()
async def default_resolver(platform: str, user_id: str) -> str: async def default_resolver(platform: str, user_id: str) -> str:
# 检查是否是机器人自己 # 检查是否是机器人自己
if replace_bot_name and user_id == global_config.bot.qq_account: if replace_bot_name and user_id == global_config.bot.qq_account:
return f"{global_config.bot.nickname}(你)" return f"{global_config.bot.nickname}(你)"
person_id = PersonInfoManager.get_person_id(platform, user_id) person = Person(platform=platform, user_id=user_id)
return await person_info_manager.get_value(person_id, "person_name") or user_id # type: ignore return person.person_name or user_id # type: ignore
name_resolver = default_resolver name_resolver = default_resolver
@@ -506,14 +502,13 @@ def _build_readable_messages_internal(
if not all([platform, user_id, timestamp is not None]): if not all([platform, user_id, timestamp is not None]):
continue continue
person_id = PersonInfoManager.get_person_id(platform, user_id) person = Person(platform=platform, user_id=user_id)
person_info_manager = get_person_info_manager()
# 根据 replace_bot_name 参数决定是否替换机器人名称 # 根据 replace_bot_name 参数决定是否替换机器人名称
person_name: str person_name: str
if replace_bot_name and user_id == global_config.bot.qq_account: if replace_bot_name and user_id == global_config.bot.qq_account:
person_name = f"{global_config.bot.nickname}(你)" person_name = f"{global_config.bot.nickname}(你)"
else: else:
person_name = person_info_manager.get_value_sync(person_id, "person_name") # type: ignore person_name = person.person_name or user_id # type: ignore
# 如果 person_name 未设置,则使用消息中的 nickname 或默认名称 # 如果 person_name 未设置,则使用消息中的 nickname 或默认名称
if not person_name: if not person_name:
@@ -740,7 +735,7 @@ def build_readable_actions(actions: List[Dict[str, Any]]) -> str:
for action in actions: for action in actions:
action_time = action.get("time", current_time) action_time = action.get("time", current_time)
action_name = action.get("action_name", "未知动作") action_name = action.get("action_name", "未知动作")
if action_name in ["no_action", "no_reply"]: if action_name in ["no_action", "no_action"]:
continue continue
action_prompt_display = action.get("action_prompt_display", "无具体内容") action_prompt_display = action.get("action_prompt_display", "无具体内容")
@@ -1009,7 +1004,7 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str:
# print("SELF11111111111111") # print("SELF11111111111111")
return "SELF" return "SELF"
try: try:
person_id = PersonInfoManager.get_person_id(platform, user_id) person_id = get_person_id(platform, user_id)
except Exception as _e: except Exception as _e:
person_id = None person_id = None
if not person_id: if not person_id:
@@ -1098,7 +1093,11 @@ async def get_person_id_list(messages: List[Dict[str, Any]]) -> List[str]:
if not all([platform, user_id]) or user_id == global_config.bot.qq_account: if not all([platform, user_id]) or user_id == global_config.bot.qq_account:
continue continue
if person_id := PersonInfoManager.get_person_id(platform, user_id): # 添加空值检查,防止 platform 为 None 时出错
if platform is None:
platform = "unknown"
if person_id := get_person_id(platform, user_id):
person_ids_set.add(person_id) person_ids_set.add(person_id)
return list(person_ids_set) # 将集合转换为列表返回 return list(person_ids_set) # 将集合转换为列表返回

View File

@@ -1,223 +0,0 @@
import ast
import json
import logging
from typing import Any, Dict, TypeVar, List, Union, Tuple, Optional
# Type variable used for generic type hints (ties safe_json_loads' return type to its default value).
T = TypeVar("T")
# Module-level logger, registered under the name "json_utils".
logger = logging.getLogger("json_utils")
def safe_json_loads(json_str: str, default_value: T = None) -> Union[Any, T]:
    """
    Parse a JSON string without raising, falling back to ``default_value``.

    Standard JSON is tried first. If that fails with a decode error, the text
    is re-read as a Python literal (which tolerates single-quoted strings);
    the result of that fallback is accepted only when it is a ``dict``.

    Args:
        json_str: Text to parse.
        default_value: Value returned whenever parsing does not succeed.

    Returns:
        The parsed object, or ``default_value`` if the input is empty, is not
        a string, or cannot be parsed.
    """
    if not (json_str and isinstance(json_str, str)):
        logger.warning(f"safe_json_loads 接收到非字符串输入: {type(json_str)}, 值: {json_str}")
        return default_value

    # Stage 1: strict JSON.
    try:
        return json.loads(json_str)
    except json.JSONDecodeError:
        # Not strict JSON; fall through to the Python-literal stage below.
        pass
    except Exception as e:
        logger.error(f"JSON解析过程中发生意外错误: {e}, 字符串: {json_str[:100]}...")
        return default_value

    # Stage 2: Python literal syntax, dict results only.
    try:
        literal = ast.literal_eval(json_str)
        if not isinstance(literal, dict):
            logger.warning(f"ast.literal_eval 解析成功但结果不是字典: {type(literal)}, 内容: {literal}")
            return default_value
        return literal
    except (ValueError, SyntaxError, MemoryError, RecursionError) as ast_e:
        logger.error(f"使用 ast.literal_eval 解析失败: {ast_e}, 字符串: {json_str[:100]}...")
    except Exception as e:
        logger.error(f"使用 ast.literal_eval 解析时发生意外错误: {e}, 字符串: {json_str[:100]}...")
    return default_value
def extract_tool_call_arguments(
    tool_call: Dict[str, Any], default_value: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
    """
    Extract the argument dictionary from an LLM tool-call object.

    The arguments are read from ``tool_call["function"]["arguments"]``. They
    are normally a JSON-encoded string; an already-decoded ``dict`` is accepted
    as well and returned as-is.

    Args:
        tool_call: Tool-call dictionary containing a ``"function"`` dictionary.
        default_value: Returned when the arguments cannot be extracted. An
            empty dict is used when this is ``None`` or empty.

    Returns:
        The decoded arguments. Always a ``dict``: input that decodes to any
        other type (list, string, number, ...) yields the default instead.
    """
    default_result = default_value or {}

    if not tool_call or not isinstance(tool_call, dict):
        logger.error(f"无效的工具调用对象: {tool_call}")
        return default_result

    try:
        function_data = tool_call.get("function", {})
        if not function_data or not isinstance(function_data, dict):
            logger.error(f"工具调用缺少function字段或格式不正确: {tool_call}")
            return default_result

        arguments = function_data.get("arguments", "{}")
        if not arguments:
            return default_result
        if isinstance(arguments, dict):
            # Arguments supplied as an object rather than a JSON string:
            # nothing left to parse.
            return arguments

        parsed = safe_json_loads(arguments, default_result)
        if isinstance(parsed, dict):
            return parsed
        # Valid JSON but not an object; callers are promised a dict.
        logger.error(f"工具调用参数不是字典类型: {type(parsed).__name__}")
        return default_result
    except Exception as e:
        logger.error(f"提取工具调用参数时出错: {e}")
        return default_result
def safe_json_dumps(obj: Any, default_value: str = "{}", ensure_ascii: bool = False, pretty: bool = False) -> str:
    """
    Serialize a Python object to a JSON string without raising.

    Args:
        obj: Object to serialize.
        default_value: String returned when serialization fails.
        ensure_ascii: Escape non-ASCII characters when true. Off by default so
            text such as Chinese is written verbatim.
        pretty: Indent the output by two spaces per level when true.

    Returns:
        The JSON text, or ``default_value`` if serialization fails.
    """
    try:
        if pretty:
            indent = 2
        else:
            indent = None
        serialized = json.dumps(obj, ensure_ascii=ensure_ascii, indent=indent)
    except TypeError as e:
        # Raised by json for objects it has no encoding for.
        logger.error(f"JSON序列化失败(类型错误): {e}")
        return default_value
    except Exception as e:
        logger.error(f"JSON序列化过程中发生意外错误: {e}")
        return default_value
    return serialized
def normalize_llm_response(response: Any, log_prefix: str = "") -> Tuple[bool, List[Any], str]:
    """
    Normalize an LLM response into a list.

    Tuples are converted to lists. When the response has exactly three items
    and the third (the tool calls) is a tuple, that item is converted to a
    list too.

    Args:
        response: Raw LLM response; expected to be a tuple or a list.
        log_prefix: Prefix prepended to every debug log line.

    Returns:
        ``(ok, items, error_message)``. On failure ``ok`` is ``False``,
        ``items`` is empty and ``error_message`` says why; on success
        ``error_message`` is the empty string.
    """
    logger.debug(f"{log_prefix}原始人 LLM响应: {response}")

    if response is None:
        return False, [], "LLM响应为None"

    logger.debug(f"{log_prefix}LLM响应原始类型: {type(response).__name__}")

    if isinstance(response, tuple):
        logger.debug(f"{log_prefix}将元组响应转换为列表")
        items = list(response)
    elif isinstance(response, list):
        # A list input is reused as-is, so the tool-call conversion below
        # updates the caller's list in place.
        items = response
    else:
        return False, [], f"无法处理的LLM响应类型: {type(response).__name__}"

    # Three items means (content, reasoning, tool_calls).
    if len(items) == 3 and isinstance(items[2], tuple):
        logger.debug(f"{log_prefix}将工具调用元组转换为列表")
        items[2] = list(items[2])

    return True, items, ""
def process_llm_tool_calls(
    tool_calls: List[Dict[str, Any]], log_prefix: str = ""
) -> Tuple[bool, List[Dict[str, Any]], str]:
    """
    Validate the tool-call list taken from an LLM response.

    A tool call is kept when it is a dict whose ``"type"`` is ``"function"``,
    whose ``"function"`` is a dict with a string ``"name"``, and whose
    ``"arguments"`` is a string that ``safe_json_loads`` parses to a dict.
    Every rejected entry is logged as a warning and dropped.

    Args:
        tool_calls: Tool calls as received from the LLM response.
        log_prefix: Prefix prepended to every warning.

    Returns:
        ``(ok, valid_tool_calls, message)``. An empty input is not an error
        (``ok`` is ``True``); ``ok`` is ``False`` only when the input was
        non-empty and no entry passed validation.
    """
    if not tool_calls:
        return True, [], "工具调用列表为空"

    accepted = []
    for idx, call in enumerate(tool_calls):
        # The first failing check sets the warning text; None means the call passed.
        rejection = None

        if not isinstance(call, dict):
            rejection = f"{log_prefix}工具调用[{idx}]不是字典: {type(call).__name__}, 内容: {call}"
        elif call.get("type") != "function":
            rejection = (
                f"{log_prefix}工具调用[{idx}]不是function类型: type={call.get('type', '未定义')}, 内容: {call}"
            )
        elif not ("function" in call and isinstance(call.get("function"), dict)):
            rejection = f"{log_prefix}工具调用[{idx}]缺少'function'字段或其类型不正确: {call}"
        else:
            details = call["function"]
            if not ("name" in details and isinstance(details.get("name"), str)):
                rejection = f"{log_prefix}工具调用[{idx}]的'function'字段缺少'name'或类型不正确: {details}"
            else:
                raw_args = details.get("arguments")
                if not isinstance(raw_args, str):
                    # Covers both a missing value (None) and a non-string one.
                    rejection = f"{log_prefix}工具调用[{idx}]的'function'字段缺少'arguments'字符串: {details}"
                else:
                    parsed = safe_json_loads(raw_args, None)
                    if not isinstance(parsed, dict):
                        rejection = (
                            f"{log_prefix}工具调用[{idx}]的'arguments'无法解析为有效的JSON字典, "
                            f"原始字符串: {raw_args[:100]}..., 解析结果类型: {type(parsed).__name__}"
                        )

        if rejection is not None:
            logger.warning(rejection)
            continue

        # Keep the original tool_call object, not the parsed arguments.
        accepted.append(call)

    # The input was non-empty here, so an empty result means every entry was rejected.
    if not accepted:
        return False, [], "所有工具调用格式均无效"

    return True, accepted, ""

View File

@@ -36,6 +36,18 @@ COST_BY_TYPE = "costs_by_type"
COST_BY_USER = "costs_by_user" COST_BY_USER = "costs_by_user"
COST_BY_MODEL = "costs_by_model" COST_BY_MODEL = "costs_by_model"
COST_BY_MODULE = "costs_by_module" COST_BY_MODULE = "costs_by_module"
TIME_COST_BY_TYPE = "time_costs_by_type"
TIME_COST_BY_USER = "time_costs_by_user"
TIME_COST_BY_MODEL = "time_costs_by_model"
TIME_COST_BY_MODULE = "time_costs_by_module"
AVG_TIME_COST_BY_TYPE = "avg_time_costs_by_type"
AVG_TIME_COST_BY_USER = "avg_time_costs_by_user"
AVG_TIME_COST_BY_MODEL = "avg_time_costs_by_model"
AVG_TIME_COST_BY_MODULE = "avg_time_costs_by_module"
STD_TIME_COST_BY_TYPE = "std_time_costs_by_type"
STD_TIME_COST_BY_USER = "std_time_costs_by_user"
STD_TIME_COST_BY_MODEL = "std_time_costs_by_model"
STD_TIME_COST_BY_MODULE = "std_time_costs_by_module"
ONLINE_TIME = "online_time" ONLINE_TIME = "online_time"
TOTAL_MSG_CNT = "total_messages" TOTAL_MSG_CNT = "total_messages"
MSG_CNT_BY_CHAT = "messages_by_chat" MSG_CNT_BY_CHAT = "messages_by_chat"
@@ -293,6 +305,18 @@ class StatisticOutputTask(AsyncTask):
COST_BY_USER: defaultdict(float), COST_BY_USER: defaultdict(float),
COST_BY_MODEL: defaultdict(float), COST_BY_MODEL: defaultdict(float),
COST_BY_MODULE: defaultdict(float), COST_BY_MODULE: defaultdict(float),
TIME_COST_BY_TYPE: defaultdict(list),
TIME_COST_BY_USER: defaultdict(list),
TIME_COST_BY_MODEL: defaultdict(list),
TIME_COST_BY_MODULE: defaultdict(list),
AVG_TIME_COST_BY_TYPE: defaultdict(float),
AVG_TIME_COST_BY_USER: defaultdict(float),
AVG_TIME_COST_BY_MODEL: defaultdict(float),
AVG_TIME_COST_BY_MODULE: defaultdict(float),
STD_TIME_COST_BY_TYPE: defaultdict(float),
STD_TIME_COST_BY_USER: defaultdict(float),
STD_TIME_COST_BY_MODEL: defaultdict(float),
STD_TIME_COST_BY_MODULE: defaultdict(float),
} }
for period_key, _ in collect_period for period_key, _ in collect_period
} }
@@ -344,7 +368,41 @@ class StatisticOutputTask(AsyncTask):
stats[period_key][COST_BY_USER][user_id] += cost stats[period_key][COST_BY_USER][user_id] += cost
stats[period_key][COST_BY_MODEL][model_name] += cost stats[period_key][COST_BY_MODEL][model_name] += cost
stats[period_key][COST_BY_MODULE][module_name] += cost stats[period_key][COST_BY_MODULE][module_name] += cost
# 收集time_cost数据
time_cost = record.time_cost or 0.0
if time_cost > 0: # 只记录有效的time_cost
stats[period_key][TIME_COST_BY_TYPE][request_type].append(time_cost)
stats[period_key][TIME_COST_BY_USER][user_id].append(time_cost)
stats[period_key][TIME_COST_BY_MODEL][model_name].append(time_cost)
stats[period_key][TIME_COST_BY_MODULE][module_name].append(time_cost)
break break
# 计算平均耗时和标准差
for period_key in stats:
for category in [REQ_CNT_BY_TYPE, REQ_CNT_BY_USER, REQ_CNT_BY_MODEL, REQ_CNT_BY_MODULE]:
time_cost_key = f"time_costs_by_{category.split('_')[-1]}"
avg_key = f"avg_time_costs_by_{category.split('_')[-1]}"
std_key = f"std_time_costs_by_{category.split('_')[-1]}"
for item_name in stats[period_key][category]:
time_costs = stats[period_key][time_cost_key].get(item_name, [])
if time_costs:
# 计算平均耗时
avg_time_cost = sum(time_costs) / len(time_costs)
stats[period_key][avg_key][item_name] = round(avg_time_cost, 3)
# 计算标准差
if len(time_costs) > 1:
variance = sum((x - avg_time_cost) ** 2 for x in time_costs) / len(time_costs)
std_time_cost = variance ** 0.5
stats[period_key][std_key][item_name] = round(std_time_cost, 3)
else:
stats[period_key][std_key][item_name] = 0.0
else:
stats[period_key][avg_key][item_name] = 0.0
stats[period_key][std_key][item_name] = 0.0
return stats return stats
@staticmethod @staticmethod
@@ -566,11 +624,11 @@ class StatisticOutputTask(AsyncTask):
""" """
if stats[TOTAL_REQ_CNT] <= 0: if stats[TOTAL_REQ_CNT] <= 0:
return "" return ""
data_fmt = "{:<32} {:>10} {:>12} {:>12} {:>12} {:>9.4f}¥" data_fmt = "{:<32} {:>10} {:>12} {:>12} {:>12} {:>9.4f}¥ {:>10} {:>10}"
output = [ output = [
"按模型分类统计:", "按模型分类统计:",
" 模型名称 调用次数 输入Token 输出Token Token总量 累计花费", " 模型名称 调用次数 输入Token 输出Token Token总量 累计花费 平均耗时(秒) 标准差(秒)",
] ]
for model_name, count in sorted(stats[REQ_CNT_BY_MODEL].items()): for model_name, count in sorted(stats[REQ_CNT_BY_MODEL].items()):
name = f"{model_name[:29]}..." if len(model_name) > 32 else model_name name = f"{model_name[:29]}..." if len(model_name) > 32 else model_name
@@ -578,7 +636,9 @@ class StatisticOutputTask(AsyncTask):
out_tokens = stats[OUT_TOK_BY_MODEL][model_name] out_tokens = stats[OUT_TOK_BY_MODEL][model_name]
tokens = stats[TOTAL_TOK_BY_MODEL][model_name] tokens = stats[TOTAL_TOK_BY_MODEL][model_name]
cost = stats[COST_BY_MODEL][model_name] cost = stats[COST_BY_MODEL][model_name]
output.append(data_fmt.format(name, count, in_tokens, out_tokens, tokens, cost)) avg_time_cost = stats[AVG_TIME_COST_BY_MODEL][model_name]
std_time_cost = stats[STD_TIME_COST_BY_MODEL][model_name]
output.append(data_fmt.format(name, count, in_tokens, out_tokens, tokens, cost, avg_time_cost, std_time_cost))
output.append("") output.append("")
return "\n".join(output) return "\n".join(output)
@@ -663,6 +723,8 @@ class StatisticOutputTask(AsyncTask):
f"<td>{stat_data[OUT_TOK_BY_MODEL][model_name]}</td>" f"<td>{stat_data[OUT_TOK_BY_MODEL][model_name]}</td>"
f"<td>{stat_data[TOTAL_TOK_BY_MODEL][model_name]}</td>" f"<td>{stat_data[TOTAL_TOK_BY_MODEL][model_name]}</td>"
f"<td>{stat_data[COST_BY_MODEL][model_name]:.4f} ¥</td>" f"<td>{stat_data[COST_BY_MODEL][model_name]:.4f} ¥</td>"
f"<td>{stat_data[AVG_TIME_COST_BY_MODEL][model_name]:.3f} 秒</td>"
f"<td>{stat_data[STD_TIME_COST_BY_MODEL][model_name]:.3f} 秒</td>"
f"</tr>" f"</tr>"
for model_name, count in sorted(stat_data[REQ_CNT_BY_MODEL].items()) for model_name, count in sorted(stat_data[REQ_CNT_BY_MODEL].items())
] ]
@@ -677,6 +739,8 @@ class StatisticOutputTask(AsyncTask):
f"<td>{stat_data[OUT_TOK_BY_TYPE][req_type]}</td>" f"<td>{stat_data[OUT_TOK_BY_TYPE][req_type]}</td>"
f"<td>{stat_data[TOTAL_TOK_BY_TYPE][req_type]}</td>" f"<td>{stat_data[TOTAL_TOK_BY_TYPE][req_type]}</td>"
f"<td>{stat_data[COST_BY_TYPE][req_type]:.4f} ¥</td>" f"<td>{stat_data[COST_BY_TYPE][req_type]:.4f} ¥</td>"
f"<td>{stat_data[AVG_TIME_COST_BY_TYPE][req_type]:.3f} 秒</td>"
f"<td>{stat_data[STD_TIME_COST_BY_TYPE][req_type]:.3f} 秒</td>"
f"</tr>" f"</tr>"
for req_type, count in sorted(stat_data[REQ_CNT_BY_TYPE].items()) for req_type, count in sorted(stat_data[REQ_CNT_BY_TYPE].items())
] ]
@@ -691,6 +755,8 @@ class StatisticOutputTask(AsyncTask):
f"<td>{stat_data[OUT_TOK_BY_MODULE][module_name]}</td>" f"<td>{stat_data[OUT_TOK_BY_MODULE][module_name]}</td>"
f"<td>{stat_data[TOTAL_TOK_BY_MODULE][module_name]}</td>" f"<td>{stat_data[TOTAL_TOK_BY_MODULE][module_name]}</td>"
f"<td>{stat_data[COST_BY_MODULE][module_name]:.4f} ¥</td>" f"<td>{stat_data[COST_BY_MODULE][module_name]:.4f} ¥</td>"
f"<td>{stat_data[AVG_TIME_COST_BY_MODULE][module_name]:.3f} 秒</td>"
f"<td>{stat_data[STD_TIME_COST_BY_MODULE][module_name]:.3f} 秒</td>"
f"</tr>" f"</tr>"
for module_name, count in sorted(stat_data[REQ_CNT_BY_MODULE].items()) for module_name, count in sorted(stat_data[REQ_CNT_BY_MODULE].items())
] ]
@@ -717,7 +783,7 @@ class StatisticOutputTask(AsyncTask):
<h2>按模型分类统计</h2> <h2>按模型分类统计</h2>
<table> <table>
<thead><tr><th>模型名称</th><th>调用次数</th><th>输入Token</th><th>输出Token</th><th>Token总量</th><th>累计花费</th></tr></thead> <thead><tr><th>模型名称</th><th>调用次数</th><th>输入Token</th><th>输出Token</th><th>Token总量</th><th>累计花费</th><th>平均耗时(秒)</th><th>标准差(秒)</th></tr></thead>
<tbody> <tbody>
{model_rows} {model_rows}
</tbody> </tbody>
@@ -726,7 +792,7 @@ class StatisticOutputTask(AsyncTask):
<h2>按模块分类统计</h2> <h2>按模块分类统计</h2>
<table> <table>
<thead> <thead>
<tr><th>模块名称</th><th>调用次数</th><th>输入Token</th><th>输出Token</th><th>Token总量</th><th>累计花费</th></tr> <tr><th>模块名称</th><th>调用次数</th><th>输入Token</th><th>输出Token</th><th>Token总量</th><th>累计花费</th><th>平均耗时(秒)</th><th>标准差(秒)</th></tr>
</thead> </thead>
<tbody> <tbody>
{module_rows} {module_rows}
@@ -736,7 +802,7 @@ class StatisticOutputTask(AsyncTask):
<h2>按请求类型分类统计</h2> <h2>按请求类型分类统计</h2>
<table> <table>
<thead> <thead>
<tr><th>请求类型</th><th>调用次数</th><th>输入Token</th><th>输出Token</th><th>Token总量</th><th>累计花费</th></tr> <tr><th>请求类型</th><th>调用次数</th><th>输入Token</th><th>输出Token</th><th>Token总量</th><th>累计花费</th><th>平均耗时(秒)</th><th>标准差(秒)</th></tr>
</thead> </thead>
<tbody> <tbody>
{type_rows} {type_rows}

View File

@@ -11,11 +11,11 @@ from typing import Optional, Tuple, Dict, List, Any
from src.common.logger import get_logger from src.common.logger import get_logger
from src.common.message_repository import find_messages, count_messages from src.common.message_repository import find_messages, count_messages
from src.config.config import global_config from src.config.config import global_config, model_config
from src.chat.message_receive.message import MessageRecv from src.chat.message_receive.message import MessageRecv
from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.message_receive.chat_stream import get_chat_manager
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.person_info.person_info import PersonInfoManager, get_person_info_manager from src.person_info.person_info import Person
from .typo_generator import ChineseTypoGenerator from .typo_generator import ChineseTypoGenerator
logger = get_logger("chat_utils") logger = get_logger("chat_utils")
@@ -109,13 +109,11 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]:
return is_mentioned, reply_probability return is_mentioned, reply_probability
async def get_embedding(text, request_type="embedding"): async def get_embedding(text, request_type="embedding") -> Optional[List[float]]:
"""获取文本的embedding向量""" """获取文本的embedding向量"""
# TODO: API-Adapter修改标记 llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type=request_type)
llm = LLMRequest(model=global_config.model.embedding, request_type=request_type)
# return llm.get_embedding_sync(text)
try: try:
embedding = await llm.get_embedding(text) embedding, _ = await llm.get_embedding(text)
except Exception as e: except Exception as e:
logger.error(f"获取embedding失败: {str(e)}") logger.error(f"获取embedding失败: {str(e)}")
embedding = None embedding = None
@@ -641,12 +639,16 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]:
# Try to fetch person info # Try to fetch person info
try: try:
# Assume get_person_id is sync (as per original code), keep using to_thread # Assume get_person_id is sync (as per original code), keep using to_thread
person_id = PersonInfoManager.get_person_id(platform, user_id) person = Person(platform=platform, user_id=user_id)
if not person.is_known:
logger.warning(f"用户 {user_info.user_nickname} 尚未认识")
# 如果用户尚未认识则返回False和None
return False, None
person_id = person.person_id
person_name = None person_name = None
if person_id: if person_id:
# get_value is async, so await it directly # get_value is async, so await it directly
person_info_manager = get_person_info_manager() person_name = person.person_name
person_name = person_info_manager.get_value_sync(person_id, "person_name")
target_info["person_id"] = person_id target_info["person_id"] = person_id
target_info["person_name"] = person_name target_info["person_name"] = person_name
@@ -767,3 +769,68 @@ def assign_message_ids_flexible(
# # 增强版本 - 使用时间戳 # # 增强版本 - 使用时间戳
# result3 = assign_message_ids_flexible(messages, prefix="ts", use_timestamp=True) # result3 = assign_message_ids_flexible(messages, prefix="ts", use_timestamp=True)
# # 结果: [{'id': 'ts123a1b', 'message': 'Hello'}, {'id': 'ts123c2d', 'message': 'World'}, {'id': 'ts123e3f', 'message': 'Test message'}] # # 结果: [{'id': 'ts123a1b', 'message': 'Hello'}, {'id': 'ts123c2d', 'message': 'World'}, {'id': 'ts123e3f', 'message': 'Test message'}]
def parse_keywords_string(keywords_input) -> list[str]:
    """Normalize keywords given in any supported format into a list of strings.

    Supported inputs:
        1. A list literal inside a string: '["utils.py", "edit", "code"]'
        2. Slash-separated text: 'utils.py/edit/code'
        3. Comma-separated text: 'utils.py,edit,code'
        4. Space-separated text: 'utils.py edit code'
        5. An actual list or tuple: ["utils.py", "edit", "code"]
        6. A JSON object string: '{"keywords": ["utils.py", "edit", "code"]}'

    Args:
        keywords_input: Keywords as a string, a list/tuple, or any other value
            (which is converted with ``str``). Falsy input yields ``[]``.

    Returns:
        list[str]: The parsed keywords, stripped, with blank entries removed.
    """
    import ast
    import json

    def _clean(items) -> list[str]:
        # Stringify every entry, trim it and drop the ones that end up empty.
        stripped = (str(item).strip() for item in items)
        return [item for item in stripped if item]

    if not keywords_input:
        return []

    # Already a sequence of keywords: only normalisation is needed.
    if isinstance(keywords_input, (list, tuple)):
        return _clean(keywords_input)

    keywords_str = str(keywords_input).strip()
    if not keywords_str:
        return []

    # First attempt: JSON, either {"keywords": [...]} or a bare JSON array.
    try:
        json_data = json.loads(keywords_str)
    except (ValueError, RecursionError):
        # json.JSONDecodeError is a ValueError subclass.
        json_data = None
    if isinstance(json_data, dict) and "keywords" in json_data:
        if isinstance(json_data["keywords"], list):
            return _clean(json_data["keywords"])
    elif isinstance(json_data, list):
        return _clean(json_data)

    # Second attempt: a Python list literal (e.g. single-quoted strings).
    # literal_eval may raise any of these on malformed input, so all of them
    # mean "not a literal" and must fall through to separator splitting.
    try:
        parsed = ast.literal_eval(keywords_str)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        parsed = None
    if isinstance(parsed, list):
        return _clean(parsed)

    # Last attempt: split on the first separator that yields several keywords.
    for separator in ("/", ",", " ", "|", ";"):
        if separator in keywords_str:
            keywords_list = [k.strip() for k in keywords_str.split(separator) if k.strip()]
            if len(keywords_list) > 1:  # a split producing one item is not a real split
                return keywords_list

    # No usable separator: the whole string is a single keyword.
    return [keywords_str]

View File

@@ -14,7 +14,7 @@ from rich.traceback import install
from src.common.logger import get_logger from src.common.logger import get_logger
from src.common.database.database import db from src.common.database.database import db
from src.common.database.database_model import Images, ImageDescriptions from src.common.database.database_model import Images, ImageDescriptions
from src.config.config import global_config from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
install(extra_lines=3) install(extra_lines=3)
@@ -37,7 +37,7 @@ class ImageManager:
self._ensure_image_dir() self._ensure_image_dir()
self._initialized = True self._initialized = True
self.vlm = LLMRequest(model=global_config.model.vlm, temperature=0.4, max_tokens=300, request_type="image") self.vlm = LLMRequest(model_set=model_config.model_task_config.vlm, request_type="image")
try: try:
db.connect(reuse_if_open=True) db.connect(reuse_if_open=True)
@@ -92,6 +92,20 @@ class ImageManager:
desc_obj.save() desc_obj.save()
except Exception as e: except Exception as e:
logger.error(f"保存描述到数据库失败 (Peewee): {str(e)}") logger.error(f"保存描述到数据库失败 (Peewee): {str(e)}")
async def get_emoji_tag(self, image_base64: str) -> str:
    """Build the emoji tag text for a base64-encoded image.

    The image bytes are hashed with MD5 and the hash is looked up through the
    emoji manager. When an emoji is found, its ``emotion`` entries are joined
    into the tag; otherwise the "unknown" tag is returned.

    Args:
        image_base64: Base64 text of the image.

    Returns:
        str: The bracketed emoji tag string.
    """
    from src.chat.emoji_system.emoji_manager import get_emoji_manager

    manager = get_emoji_manager()

    # Strip anything that is not ASCII before handing the text to the decoder.
    payload = image_base64
    if isinstance(payload, str):
        payload = payload.encode("ascii", errors="ignore").decode("ascii")

    digest = hashlib.md5(base64.b64decode(payload)).hexdigest()
    registered = await manager.get_emoji_from_manager(digest)
    if registered:
        tag_str = ",".join(registered.emotion)
        return f"[表情包:{tag_str}]"
    return "[表情包:未知]"
async def get_emoji_description(self, image_base64: str) -> str: async def get_emoji_description(self, image_base64: str) -> str:
"""获取表情包描述优先使用Emoji表中的缓存数据""" """获取表情包描述优先使用Emoji表中的缓存数据"""
@@ -108,21 +122,21 @@ class ImageManager:
try: try:
from src.chat.emoji_system.emoji_manager import get_emoji_manager from src.chat.emoji_system.emoji_manager import get_emoji_manager
emoji_manager = get_emoji_manager() emoji_manager = get_emoji_manager()
cached_emoji_description = await emoji_manager.get_emoji_description_by_hash(image_hash) tags = await emoji_manager.get_emoji_tag_by_hash(image_hash)
if cached_emoji_description: if tags:
logger.info(f"[缓存命中] 使用已注册表情包描述: {cached_emoji_description[:50]}...") tag_str = ",".join(tags)
return cached_emoji_description logger.info(f"[缓存命中] 使用已注册表情包描述: {tag_str}...")
return f"[表情包:{tag_str}]"
except Exception as e: except Exception as e:
logger.debug(f"查询EmojiManager时出错: {e}") logger.debug(f"查询EmojiManager时出错: {e}")
# 查询ImageDescriptions表的缓存描述 # 查询ImageDescriptions表的缓存描述
cached_description = self._get_description_from_db(image_hash, "emoji") if cached_description := self._get_description_from_db(image_hash, "emoji"):
if cached_description:
logger.info(f"[缓存命中] 使用ImageDescriptions表中的描述: {cached_description[:50]}...") logger.info(f"[缓存命中] 使用ImageDescriptions表中的描述: {cached_description[:50]}...")
return f"[表情包:{cached_description}]" return f"[表情包:{cached_description}]"
# === 二步走识别流程 === # === 二步走识别流程 ===
# 第一步VLM视觉分析 - 生成详细描述 # 第一步VLM视觉分析 - 生成详细描述
if image_format in ["gif", "GIF"]: if image_format in ["gif", "GIF"]:
image_base64_processed = self.transform_gif(image_base64) image_base64_processed = self.transform_gif(image_base64)
@@ -130,10 +144,16 @@ class ImageManager:
logger.warning("GIF转换失败无法获取描述") logger.warning("GIF转换失败无法获取描述")
return "[表情包(GIF处理失败)]" return "[表情包(GIF处理失败)]"
vlm_prompt = "这是一个动态图表情包,每一张图代表了动态图的某一帧,黑色背景代表透明,描述一下表情包表达的情感和内容,描述细节,从互联网梗,meme的角度去分析" vlm_prompt = "这是一个动态图表情包,每一张图代表了动态图的某一帧,黑色背景代表透明,描述一下表情包表达的情感和内容,描述细节,从互联网梗,meme的角度去分析"
detailed_description, _ = await self.vlm.generate_response_for_image(vlm_prompt, image_base64_processed, "jpg") detailed_description, _ = await self.vlm.generate_response_for_image(
vlm_prompt, image_base64_processed, "jpg", temperature=0.4, max_tokens=300
)
else: else:
vlm_prompt = "这是一个表情包,请详细描述一下表情包所表达的情感和内容,描述细节,从互联网梗,meme的角度去分析" vlm_prompt = (
detailed_description, _ = await self.vlm.generate_response_for_image(vlm_prompt, image_base64, image_format) "这是一个表情包,请详细描述一下表情包所表达的情感和内容,描述细节,从互联网梗,meme的角度去分析"
)
detailed_description, _ = await self.vlm.generate_response_for_image(
vlm_prompt, image_base64, image_format, temperature=0.4, max_tokens=300
)
if detailed_description is None: if detailed_description is None:
logger.warning("VLM未能生成表情包详细描述") logger.warning("VLM未能生成表情包详细描述")
@@ -150,31 +170,32 @@ class ImageManager:
3. 输出简短精准,不要解释 3. 输出简短精准,不要解释
4. 如果有多个词用逗号分隔 4. 如果有多个词用逗号分隔
""" """
# 使用较低温度确保输出稳定 # 使用较低温度确保输出稳定
emotion_llm = LLMRequest(model=global_config.model.utils, temperature=0.3, max_tokens=50, request_type="emoji") emotion_llm = LLMRequest(model_set=model_config.model_task_config.utils, request_type="emoji")
emotion_result, _ = await emotion_llm.generate_response_async(emotion_prompt) emotion_result, _ = await emotion_llm.generate_response_async(
emotion_prompt, temperature=0.3, max_tokens=50
)
if emotion_result is None: if emotion_result is None:
logger.warning("LLM未能生成情感标签使用详细描述的前几个词") logger.warning("LLM未能生成情感标签使用详细描述的前几个词")
# 降级处理:从详细描述中提取关键词 # 降级处理:从详细描述中提取关键词
import jieba import jieba
words = list(jieba.cut(detailed_description)) words = list(jieba.cut(detailed_description))
emotion_result = "".join(words[:2]) if len(words) >= 2 else (words[0] if words else "表情") emotion_result = "".join(words[:2]) if len(words) >= 2 else (words[0] if words else "表情")
# 处理情感结果取前1-2个最重要的标签 # 处理情感结果取前1-2个最重要的标签
emotions = [e.strip() for e in emotion_result.replace("", ",").split(",") if e.strip()] emotions = [e.strip() for e in emotion_result.replace("", ",").split(",") if e.strip()]
final_emotion = emotions[0] if emotions else "表情" final_emotion = emotions[0] if emotions else "表情"
# 如果有第二个情感且不重复,也包含进来 # 如果有第二个情感且不重复,也包含进来
if len(emotions) > 1 and emotions[1] != emotions[0]: if len(emotions) > 1 and emotions[1] != emotions[0]:
final_emotion = f"{emotions[0]}{emotions[1]}" final_emotion = f"{emotions[0]}{emotions[1]}"
logger.info(f"[emoji识别] 详细描述: {detailed_description[:50]}... -> 情感标签: {final_emotion}") logger.info(f"[emoji识别] 详细描述: {detailed_description[:50]}... -> 情感标签: {final_emotion}")
# 再次检查缓存,防止并发写入时重复生成 if cached_description := self._get_description_from_db(image_hash, "emoji"):
cached_description = self._get_description_from_db(image_hash, "emoji")
if cached_description:
logger.warning(f"虽然生成了描述,但是找到缓存表情包描述: {cached_description}") logger.warning(f"虽然生成了描述,但是找到缓存表情包描述: {cached_description}")
return f"[表情包:{cached_description}]" return f"[表情包:{cached_description}]"
@@ -242,9 +263,7 @@ class ImageManager:
logger.debug(f"[缓存命中] 使用Images表中的图片描述: {existing_image.description[:50]}...") logger.debug(f"[缓存命中] 使用Images表中的图片描述: {existing_image.description[:50]}...")
return f"[图片:{existing_image.description}]" return f"[图片:{existing_image.description}]"
# 查询ImageDescriptions表的缓存描述 if cached_description := self._get_description_from_db(image_hash, "image"):
cached_description = self._get_description_from_db(image_hash, "image")
if cached_description:
logger.debug(f"[缓存命中] 使用ImageDescriptions表中的描述: {cached_description[:50]}...") logger.debug(f"[缓存命中] 使用ImageDescriptions表中的描述: {cached_description[:50]}...")
return f"[图片:{cached_description}]" return f"[图片:{cached_description}]"
@@ -252,7 +271,9 @@ class ImageManager:
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
prompt = global_config.custom_prompt.image_prompt prompt = global_config.custom_prompt.image_prompt
logger.info(f"[VLM调用] 为图片生成新描述 (Hash: {image_hash[:8]}...)") logger.info(f"[VLM调用] 为图片生成新描述 (Hash: {image_hash[:8]}...)")
description, _ = await self.vlm.generate_response_for_image(prompt, image_base64, image_format) description, _ = await self.vlm.generate_response_for_image(
prompt, image_base64, image_format, temperature=0.4, max_tokens=300
)
if description is None: if description is None:
logger.warning("AI未能生成图片描述") logger.warning("AI未能生成图片描述")
@@ -445,10 +466,7 @@ class ImageManager:
image_bytes = base64.b64decode(image_base64) image_bytes = base64.b64decode(image_base64)
image_hash = hashlib.md5(image_bytes).hexdigest() image_hash = hashlib.md5(image_bytes).hexdigest()
# 检查图片是否已存在 if existing_image := Images.get_or_none(Images.emoji_hash == image_hash):
existing_image = Images.get_or_none(Images.emoji_hash == image_hash)
if existing_image:
# 检查是否缺少必要字段,如果缺少则创建新记录 # 检查是否缺少必要字段,如果缺少则创建新记录
if ( if (
not hasattr(existing_image, "image_id") not hasattr(existing_image, "image_id")
@@ -524,9 +542,7 @@ class ImageManager:
# 优先检查是否已有其他相同哈希的图片记录包含描述 # 优先检查是否已有其他相同哈希的图片记录包含描述
existing_with_description = Images.get_or_none( existing_with_description = Images.get_or_none(
(Images.emoji_hash == image_hash) & (Images.emoji_hash == image_hash) & (Images.description.is_null(False)) & (Images.description != "")
(Images.description.is_null(False)) &
(Images.description != "")
) )
if existing_with_description and existing_with_description.id != image.id: if existing_with_description and existing_with_description.id != image.id:
logger.debug(f"[缓存复用] 从其他相同图片记录复用描述: {existing_with_description.description[:50]}...") logger.debug(f"[缓存复用] 从其他相同图片记录复用描述: {existing_with_description.description[:50]}...")
@@ -538,8 +554,7 @@ class ImageManager:
return return
# 检查ImageDescriptions表的缓存描述 # 检查ImageDescriptions表的缓存描述
cached_description = self._get_description_from_db(image_hash, "image") if cached_description := self._get_description_from_db(image_hash, "image"):
if cached_description:
logger.debug(f"[缓存复用] 从ImageDescriptions表复用描述: {cached_description[:50]}...") logger.debug(f"[缓存复用] 从ImageDescriptions表复用描述: {cached_description[:50]}...")
image.description = cached_description image.description = cached_description
image.vlm_processed = True image.vlm_processed = True
@@ -554,15 +569,15 @@ class ImageManager:
# 获取VLM描述 # 获取VLM描述
logger.info(f"[VLM异步调用] 为图片生成描述 (ID: {image_id}, Hash: {image_hash[:8]}...)") logger.info(f"[VLM异步调用] 为图片生成描述 (ID: {image_id}, Hash: {image_hash[:8]}...)")
description, _ = await self.vlm.generate_response_for_image(prompt, image_base64, image_format) description, _ = await self.vlm.generate_response_for_image(
prompt, image_base64, image_format, temperature=0.4, max_tokens=300
)
if description is None: if description is None:
logger.warning("VLM未能生成图片描述") logger.warning("VLM未能生成图片描述")
description = "无法生成描述" description = "无法生成描述"
# 再次检查缓存,防止并发写入时重复生成 if cached_description := self._get_description_from_db(image_hash, "image"):
cached_description = self._get_description_from_db(image_hash, "image")
if cached_description:
logger.warning(f"虽然生成了描述,但是找到缓存图片描述: {cached_description}") logger.warning(f"虽然生成了描述,但是找到缓存图片描述: {cached_description}")
description = cached_description description = cached_description
@@ -606,7 +621,7 @@ def image_path_to_base64(image_path: str) -> str:
raise FileNotFoundError(f"图片文件不存在: {image_path}") raise FileNotFoundError(f"图片文件不存在: {image_path}")
with open(image_path, "rb") as f: with open(image_path, "rb") as f:
image_data = f.read() if image_data := f.read():
if not image_data: return base64.b64encode(image_data).decode("utf-8")
else:
raise IOError(f"读取图片文件失败: {image_path}") raise IOError(f"读取图片文件失败: {image_path}")
return base64.b64encode(image_data).decode("utf-8")

View File

@@ -1,35 +1,29 @@
import base64 from src.config.config import global_config, model_config
from src.config.config import global_config
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.common.logger import get_logger from src.common.logger import get_logger
from rich.traceback import install from rich.traceback import install
install(extra_lines=3) install(extra_lines=3)
logger = get_logger("chat_voice") logger = get_logger("chat_voice")
async def get_voice_text(voice_base64: str) -> str: async def get_voice_text(voice_base64: str) -> str:
"""获取音频文件描述""" """获取音频文件转录文本"""
if not global_config.voice.enable_asr: if not global_config.voice.enable_asr:
logger.warning("语音识别未启用,无法处理语音消息") logger.warning("语音识别未启用,无法处理语音消息")
return "[语音]" return "[语音]"
try: try:
# 解码base64音频数据 _llm = LLMRequest(model_set=model_config.model_task_config.voice, request_type="audio")
# 确保base64字符串只包含ASCII字符 text = await _llm.generate_response_for_voice(voice_base64)
if isinstance(voice_base64, str):
voice_base64 = voice_base64.encode("ascii", errors="ignore").decode("ascii")
voice_bytes = base64.b64decode(voice_base64)
_llm = LLMRequest(model=global_config.model.voice, request_type="voice")
text = await _llm.generate_response_for_voice(voice_bytes)
if text is None: if text is None:
logger.warning("未能生成语音文本") logger.warning("未能生成语音文本")
return "[语音(文本生成失败)]" return "[语音(文本生成失败)]"
logger.debug(f"描述是{text}") logger.debug(f"描述是{text}")
return f"[语音:{text}]" return f"[语音:{text}]"
except Exception as e: except Exception as e:
logger.error(f"语音转文字失败: {str(e)}") logger.error(f"语音转文字失败: {str(e)}")
return "[语音]" return "[语音]"

View File

@@ -1,62 +0,0 @@
import asyncio
from src.config.config import global_config
from .willing_manager import BaseWillingManager
class ClassicalWillingManager(BaseWillingManager):
    """Willingness manager that keeps one decaying willingness value per chat."""

    def __init__(self):
        super().__init__()
        # Handle of the background decay loop; None until it has been started.
        self._decay_task: asyncio.Task | None = None

    async def _decay_reply_willing(self):
        """Once per second, scale every chat's willingness by 0.9 (never below 0)."""
        while True:
            await asyncio.sleep(1)
            for stream_id, level in self.chat_reply_willing.items():
                self.chat_reply_willing[stream_id] = max(0.0, level * 0.9)

    async def async_task_starter(self):
        """Start the decay loop unless a task handle is already stored."""
        if self._decay_task is not None:
            return
        self._decay_task = asyncio.create_task(self._decay_reply_willing())

    async def get_reply_probability(self, message_id):
        """Update the chat's willingness for this message and return a reply probability.

        The stored willingness is capped at 1.0, while the returned probability
        is derived from the uncapped value and lies between 0.02 and 1.5.
        """
        info = self.ongoing_messages[message_id]
        stream_id = info.chat_id
        willing = self.chat_reply_willing.get(stream_id, 0)

        # Only the part of the interest value above 0.2 raises the willingness.
        rate = info.interested_rate
        if rate > 0.2:
            willing += rate - 0.2

        # Mention bonus: a full point while below 1.0, a small nudge afterwards.
        if info.is_mentioned_bot and global_config.chat.mentioned_bot_inevitable_reply and willing < 2:
            if willing < 1.0:
                willing += 1
            else:
                willing += 0.2

        self.chat_reply_willing[stream_id] = min(willing, 1.0)
        return min(max(willing - 0.5, 0.01) * 2, 1.5)

    async def before_generate_reply_handle(self, message_id):
        """Nothing to do before a reply is generated."""

    async def after_generate_reply_handle(self, message_id):
        """After replying, raise the chat's willingness by 0.3, capped at 1.0."""
        info = self.ongoing_messages.get(message_id)
        if info is None:
            return
        stream_id = info.chat_id
        willing = self.chat_reply_willing.get(stream_id, 0)
        if willing < 1:
            self.chat_reply_willing[stream_id] = min(1.0, willing + 0.3)

    async def not_reply_handle(self, message_id):
        """Delegate to the base class implementation."""
        return await super().not_reply_handle(message_id)

View File

@@ -1,23 +0,0 @@
from .willing_manager import BaseWillingManager
NOT_IMPLEMENTED_MESSAGE = "\ncustom模式你实现了吗没自行实现不要选custom。给你退了快点给你麦爹配置\n以上内容由gemini生成如有不满请投诉gemini"
class CustomWillingManager(BaseWillingManager):
    """Placeholder for a user-written willingness mode.

    Every entry point, construction included, raises ``NotImplementedError``
    carrying ``NOT_IMPLEMENTED_MESSAGE``.
    """

    def __init__(self):
        super().__init__()
        raise NotImplementedError(NOT_IMPLEMENTED_MESSAGE)

    async def async_task_starter(self) -> None:
        """Not implemented for the custom mode."""
        raise NotImplementedError(NOT_IMPLEMENTED_MESSAGE)

    async def get_reply_probability(self, message_id: str):
        """Not implemented for the custom mode."""
        raise NotImplementedError(NOT_IMPLEMENTED_MESSAGE)

    async def before_generate_reply_handle(self, message_id: str):
        """Not implemented for the custom mode."""
        raise NotImplementedError(NOT_IMPLEMENTED_MESSAGE)

    async def after_generate_reply_handle(self, message_id: str):
        """Not implemented for the custom mode."""
        raise NotImplementedError(NOT_IMPLEMENTED_MESSAGE)

    async def not_reply_handle(self, message_id: str):
        """Not implemented for the custom mode."""
        raise NotImplementedError(NOT_IMPLEMENTED_MESSAGE)

View File

@@ -1,296 +0,0 @@
"""
Mxp 模式:梦溪畔独家赞助
此模式的一些参数不会在配置文件中显示,要修改请在可变参数下修改
同时一些全局设置对此模式无效
此模式的可变参数暂时比较草率,需要调参仙人的大手
此模式的特点:
1.每个聊天流的每个用户的意愿是独立的
2.接入关系系统,关系会影响意愿值(已移除,因为关系系统重构)
3.会根据群聊的热度来调整基础意愿值
4.限制同时思考的消息数量,防止喷射
5.拥有单聊增益无论在群里还是私聊只要bot一直和你聊就会增加意愿值
6.意愿分为衰减意愿+临时意愿
7.疲劳机制
如果你发现本模式出现了bug
上上策是询问智慧的小草神()
上策是询问万能的千石可乐
中策是发issue
下下策是询问一个菜鸟(@梦溪畔)
"""
from .willing_manager import BaseWillingManager
from typing import Dict
import asyncio
import time
import math
from src.chat.message_receive.chat_stream import ChatStream
class MxpWillingManager(BaseWillingManager):
    """Mxp willingness manager.

    Tracks a willingness value per (chat, person), adjusts each chat's base
    willingness from its recent message rate, applies a fatigue penalty after
    bursts of bot replies, and damps the result while other messages of the
    same chat are still being processed.
    """

    def __init__(self):
        super().__init__()
        self.chat_person_reply_willing: Dict[str, Dict[str, float]] = {}  # chat_id -> {person_id: willingness}
        self.chat_new_message_time: Dict[str, list[float]] = {}  # chat_id -> timestamps of recent messages
        self.last_response_person: Dict[str, tuple[str, int]] = {}  # chat_id -> (person_id, streak counter)
        self.temporary_willing: float = 0  # willingness computed by the latest probability query
        self.chat_bot_message_time: Dict[str, list[float]] = {}  # chat_id -> timestamps of bot replies
        self.chat_fatigue_punishment_list: Dict[
            str, list[tuple[float, float]]
        ] = {}  # chat_id -> fatigue penalties as (start time, duration)
        self.chat_fatigue_willing_attenuation: Dict[str, float] = {}  # chat_id -> current fatigue offset
        # Strong references to the background loops; asyncio itself only keeps
        # weak references to tasks, so they must be held here.
        self._background_tasks: set[asyncio.Task] = set()

        # Tunable parameters
        self.intention_decay_rate = 0.93  # decay rate of the willingness
        self.number_of_message_storage = 12  # number of message timestamps kept per chat
        self.expected_replies_per_min = 3  # expected replies per minute
        self.basic_maximum_willing = 0.5  # maximum base willingness
        self.mention_willing_gain = 0.6  # gain when the bot is mentioned
        self.interest_willing_gain = 0.3  # gain from the interest value
        self.single_chat_gain = 0.12  # gain for talking with the same person repeatedly
        self.fatigue_messages_triggered_num = self.expected_replies_per_min  # replies that trigger fatigue (int)
        self.fatigue_coefficient = 1.0  # fatigue coefficient

        self.is_debug = False  # whether debug logging is enabled

    async def async_task_starter(self) -> None:
        """Start the three background loops, once.

        A repeated call while the loops are alive is a no-op, so the periodic
        updates are never duplicated.
        """
        if self._background_tasks:
            return
        for job in (
            self._return_to_basic_willing,
            self._chat_new_message_to_change_basic_willing,
            self._fatigue_attenuation,
        ):
            task = asyncio.create_task(job())
            self._background_tasks.add(task)
            task.add_done_callback(self._background_tasks.discard)

    async def before_generate_reply_handle(self, message_id: str):
        """Record a bot reply and schedule a fatigue penalty when replies pile up."""
        current_time = time.time()
        async with self.lock:
            w_info = self.ongoing_messages[message_id]
            if w_info.chat_id not in self.chat_bot_message_time:
                self.chat_bot_message_time[w_info.chat_id] = []
            # Keep only replies from the last 60 seconds, then add this one.
            self.chat_bot_message_time[w_info.chat_id] = [
                t for t in self.chat_bot_message_time[w_info.chat_id] if current_time - t < 60
            ]
            self.chat_bot_message_time[w_info.chat_id].append(current_time)
            if len(self.chat_bot_message_time[w_info.chat_id]) == int(self.fatigue_messages_triggered_num):
                # The penalty lasts twice the time left in the oldest reply's minute.
                time_interval = 60 - (current_time - self.chat_bot_message_time[w_info.chat_id].pop(0))
                self.chat_fatigue_punishment_list[w_info.chat_id].append((current_time, time_interval * 2))

    async def after_generate_reply_handle(self, message_id: str):
        """Update the per-chat streak of replies to the same person (capped at 3)."""
        async with self.lock:
            w_info = self.ongoing_messages[message_id]
            # NOTE: the relationship-based bonus that used to live here was removed.
            person, streak = self.last_response_person.get(w_info.chat_id, (w_info.person_id, 0))
            if person == w_info.person_id:
                if streak < 3:
                    self.last_response_person[w_info.chat_id] = (person, streak + 1)
            else:
                self.last_response_person[w_info.chat_id] = (w_info.person_id, 0)

    async def not_reply_handle(self, message_id: str):
        """Raise the sender's willingness after a skipped message and track the sender."""
        async with self.lock:
            w_info = self.ongoing_messages[message_id]
            if w_info.is_mentioned_bot:
                self.chat_person_reply_willing[w_info.chat_id][w_info.person_id] += self.mention_willing_gain / 2.5
            if (
                w_info.chat_id in self.last_response_person
                and self.last_response_person[w_info.chat_id][0] == w_info.person_id
                and self.last_response_person[w_info.chat_id][1]
            ):
                self.chat_person_reply_willing[w_info.chat_id][w_info.person_id] += self.single_chat_gain * (
                    2 * self.last_response_person[w_info.chat_id][1] - 1
                )
            now_chat_new_person = self.last_response_person.get(w_info.chat_id, ("", 0))
            if now_chat_new_person[0] != w_info.person_id:
                self.last_response_person[w_info.chat_id] = (w_info.person_id, 0)

    async def get_reply_probability(self, message_id: str):
        """Compute the reply probability for a message.

        Mention and interest gains are stored back on the sender's willingness;
        the same-person gain, fatigue offset and in-flight penalties only
        affect this query. The final willingness is kept in
        ``temporary_willing``.
        """
        async with self.lock:
            w_info = self.ongoing_messages[message_id]
            current_willing = self.chat_person_reply_willing[w_info.chat_id][w_info.person_id]
            if self.is_debug:
                self.logger.debug(f"基础意愿值:{current_willing}")

            if w_info.is_mentioned_bot:
                willing_gain = self.mention_willing_gain / (int(current_willing) + 1)
                current_willing += willing_gain
                if self.is_debug:
                    self.logger.debug(f"提及增益:{willing_gain}")

            if w_info.interested_rate > 0:
                willing_gain = math.atan(w_info.interested_rate / 2) / math.pi * 2 * self.interest_willing_gain
                current_willing += willing_gain
                if self.is_debug:
                    self.logger.debug(f"兴趣增益:{willing_gain}")

            self.chat_person_reply_willing[w_info.chat_id][w_info.person_id] = current_willing

            # Gain for repeatedly talking with the same person.
            if (
                w_info.chat_id in self.last_response_person
                and self.last_response_person[w_info.chat_id][0] == w_info.person_id
                and self.last_response_person[w_info.chat_id][1]
            ):
                single_gain = self.single_chat_gain * (2 * self.last_response_person[w_info.chat_id][1] + 1)
                current_willing += single_gain
                if self.is_debug:
                    self.logger.debug(f"单聊增益:{single_gain}")

            current_willing += self.chat_fatigue_willing_attenuation.get(w_info.chat_id, 0)
            if self.is_debug:
                self.logger.debug(f"疲劳衰减:{self.chat_fatigue_willing_attenuation.get(w_info.chat_id, 0)}")

            # Penalise chats that already have several messages in flight.
            chat_ongoing_messages = [msg for msg in self.ongoing_messages.values() if msg.chat_id == w_info.chat_id]
            chat_person_ongoing_messages = [msg for msg in chat_ongoing_messages if msg.person_id == w_info.person_id]
            if len(chat_person_ongoing_messages) >= 2:
                current_willing = 0
                if self.is_debug:
                    self.logger.debug("进行中消息惩罚归0")
            elif len(chat_ongoing_messages) == 2:
                current_willing -= 0.5
                if self.is_debug:
                    self.logger.debug("进行中消息惩罚:-0.5")
            elif len(chat_ongoing_messages) == 3:
                current_willing -= 1.5
                if self.is_debug:
                    self.logger.debug("进行中消息惩罚:-1.5")
            elif len(chat_ongoing_messages) >= 4:
                current_willing = 0
                if self.is_debug:
                    self.logger.debug("进行中消息惩罚归0")

            probability = self._willing_to_probability(current_willing)

            self.temporary_willing = current_willing

            return probability

    async def _return_to_basic_willing(self):
        """Every 3 seconds, pull each person's willingness toward the chat's base value."""
        while True:
            await asyncio.sleep(3)
            async with self.lock:
                for chat_id, person_willing in self.chat_person_reply_willing.items():
                    for person_id, willing in person_willing.items():
                        if chat_id not in self.chat_reply_willing:
                            self.logger.debug(f"聊天流{chat_id}不存在,错误")
                            continue
                        basic_willing = self.chat_reply_willing[chat_id]
                        person_willing[person_id] = (
                            basic_willing + (willing - basic_willing) * self.intention_decay_rate
                        )

    def setup(self, message: dict, chat_stream: ChatStream):
        """Register a message and initialise the per-chat and per-person state it needs."""
        super().setup(message, chat_stream)
        stream_id = chat_stream.stream_id
        person_id = self.ongoing_messages[message.get("message_id", "")].person_id

        self.chat_reply_willing[stream_id] = self.chat_reply_willing.get(stream_id, self.basic_maximum_willing)
        self.chat_person_reply_willing[stream_id] = self.chat_person_reply_willing.get(stream_id, {})
        # A person seen for the first time starts at the chat's base willingness.
        self.chat_person_reply_willing[stream_id][person_id] = self.chat_person_reply_willing[stream_id].get(
            person_id,
            self.chat_reply_willing[stream_id],
        )

        current_time = time.time()
        if stream_id not in self.chat_new_message_time:
            self.chat_new_message_time[stream_id] = []
        self.chat_new_message_time[stream_id].append(current_time)
        if len(self.chat_new_message_time[stream_id]) > self.number_of_message_storage:
            self.chat_new_message_time[stream_id].pop(0)

        # A chat seen for the first time starts with one fatigue penalty and
        # the matching initial offset.
        if stream_id not in self.chat_fatigue_punishment_list:
            self.chat_fatigue_punishment_list[stream_id] = [
                (
                    current_time,
                    self.number_of_message_storage * self.basic_maximum_willing / self.expected_replies_per_min * 60,
                )
            ]
            self.chat_fatigue_willing_attenuation[stream_id] = (
                -2 * self.basic_maximum_willing * self.fatigue_coefficient
            )

    @staticmethod
    def _willing_to_probability(willing: float) -> float:
        """Map a willingness value to a probability in [0, 1]."""
        willing = max(0, willing)
        if willing < 2:
            return math.atan(willing * 2) / math.pi * 2
        elif willing < 2.5:
            return math.atan(willing * 4) / math.pi * 2
        else:
            return 1

    async def _chat_new_message_to_change_basic_willing(self):
        """Periodically recompute each chat's base willingness from its message rate."""
        update_time = 20
        while True:
            await asyncio.sleep(update_time)
            async with self.lock:
                for chat_id, message_times in self.chat_new_message_time.items():
                    # Drop timestamps that fell out of the observation window.
                    current_time = time.time()
                    message_times = [
                        msg_time
                        for msg_time in message_times
                        if current_time - msg_time
                        < self.number_of_message_storage
                        * self.basic_maximum_willing
                        / self.expected_replies_per_min
                        * 60
                    ]
                    self.chat_new_message_time[chat_id] = message_times

                    if len(message_times) < self.number_of_message_storage:
                        self.chat_reply_willing[chat_id] = self.basic_maximum_willing
                        update_time = 20
                    elif len(message_times) == self.number_of_message_storage:
                        time_interval = current_time - message_times[0]
                        basic_willing = self._basic_willing_calculate(time_interval)
                        self.chat_reply_willing[chat_id] = basic_willing
                        # Busier chats (lower base willingness) are refreshed sooner.
                        update_time = 17 * basic_willing / self.basic_maximum_willing + 3
                    else:
                        self.logger.debug(f"聊天流{chat_id}消息时间数量异常,数量:{len(message_times)}")
                        self.chat_reply_willing[chat_id] = 0
                if self.is_debug:
                    self.logger.debug(f"聊天流意愿值更新:{self.chat_reply_willing}")

    def _basic_willing_calculate(self, t: float) -> float:
        """Compute the base willingness from the time span ``t`` of the stored messages."""
        return math.tan(t * self.expected_replies_per_min * math.pi / 120 / self.number_of_message_storage) / 2

    async def _fatigue_attenuation(self):
        """Once per second, recompute every chat's fatigue offset from its active penalties."""
        while True:
            await asyncio.sleep(1)
            current_time = time.time()
            async with self.lock:
                for chat_id, fatigue_list in self.chat_fatigue_punishment_list.items():
                    fatigue_list = [z for z in fatigue_list if current_time - z[0] < z[1]]
                    # Store the pruned list back so expired penalties do not
                    # accumulate forever (replacing a value is safe mid-iteration).
                    self.chat_fatigue_punishment_list[chat_id] = fatigue_list
                    self.chat_fatigue_willing_attenuation[chat_id] = 0
                    for start_time, duration in fatigue_list:
                        self.chat_fatigue_willing_attenuation[chat_id] += (
                            self.chat_reply_willing[chat_id]
                            * 2
                            / math.pi
                            * math.asin(2 * (current_time - start_time) / duration - 1)
                            - self.chat_reply_willing[chat_id]
                        ) * self.fatigue_coefficient

    async def get_willing(self, chat_id):
        """Return the willingness of the latest probability query; ``chat_id`` is not used."""
        return self.temporary_willing

View File

@@ -1,180 +0,0 @@
import importlib
import asyncio
from abc import ABC, abstractmethod
from typing import Dict, Optional, Any
from rich.traceback import install
from dataclasses import dataclass
from src.common.logger import get_logger
from src.config.config import global_config
from src.chat.message_receive.chat_stream import ChatStream, GroupInfo
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
install(extra_lines=3)
"""
基类方法概览:
以下8个方法是你必须在子类重写的哪怕什么都不干
async_task_starter 在程序启动时执行在其中用asyncio.create_task启动你想要执行的异步任务
before_generate_reply_handle 确定要回复后,在生成回复前的处理
after_generate_reply_handle 确定要回复后,在生成回复后的处理
not_reply_handle 确定不回复后的处理
get_reply_probability 获取回复概率
get_variable_parameters 暂不确定
set_variable_parameters 暂不确定
以下2个方法根据你的实现可以做调整
get_willing 获取某聊天流意愿
set_willing 设置某聊天流意愿
规范说明:
模块文件命名: `mode_{manager_type}.py`
示例: 若 `manager_type="aggressive"`,则模块文件应为 `mode_aggressive.py`
类命名: `{manager_type}WillingManager` (首字母大写)
示例: 在 `mode_aggressive.py` 中,类名应为 `AggressiveWillingManager`
"""
logger = get_logger("willing")
@dataclass
class WillingInfo:
    """Parameters commonly used by the willingness module for one message.

    Attributes:
        message (Dict[str, Any]): raw message data
        chat (ChatStream): chat stream object
        person_info_manager (PersonInfoManager): person info manager object
        chat_id (str): identifier of the current chat stream
        person_id (str): identifier of the sender's person info
        group_info (Optional[GroupInfo]): group information (presumably None
            for private chats - TODO confirm)
        is_mentioned_bot (bool): whether the bot was mentioned
        is_emoji (bool): whether the message is an emoji/sticker
        is_picid (bool): presumably whether the message carries a picture id
            - TODO confirm
        interested_rate (float): interest value
    """

    message: Dict[str, Any]  # raw message data
    chat: ChatStream
    person_info_manager: PersonInfoManager
    chat_id: str
    person_id: str
    group_info: Optional[GroupInfo]
    is_mentioned_bot: bool
    is_emoji: bool
    is_picid: bool
    interested_rate: float
    # current_mood: float  (current mood? - not implemented)
class BaseWillingManager(ABC):
    """Abstract base class of the reply-willingness managers."""

    @classmethod
    def create(cls, manager_type: str) -> "BaseWillingManager":
        """Instantiate the manager registered for ``manager_type``.

        The class ``{Manager_type}WillingManager`` is looked up in the sibling
        module ``mode_{manager_type}``. If the module or class is missing, or
        the class is not a subclass of ``cls``, the classical manager is used.
        """
        try:
            mode_module = importlib.import_module(f".mode_{manager_type}", __package__)
            chosen = getattr(mode_module, f"{manager_type.capitalize()}WillingManager")
            if issubclass(chosen, cls):
                logger.info(f"普通回复模式:{manager_type}")
            else:
                raise TypeError(f"Manager class {chosen.__name__} is not a subclass of {cls.__name__}")
        except (ImportError, AttributeError, TypeError) as exc:
            chosen = importlib.import_module(".mode_classical", __package__).ClassicalWillingManager
            logger.info(f"载入当前意愿模式{manager_type}失败,使用经典配方~~~~")
            logger.debug(f"加载willing模式{manager_type}失败,原因: {str(exc)}")
        return chosen()

    def __init__(self):
        self.chat_reply_willing: Dict[str, float] = {}  # reply willingness per chat stream, keyed by chat_id
        self.ongoing_messages: Dict[str, WillingInfo] = {}  # messages currently in progress, keyed by message_id
        self.lock = asyncio.Lock()
        self.logger = logger

    def setup(self, message: dict, chat: ChatStream):
        """Register ``message`` as in progress by storing a ``WillingInfo`` for it."""
        sender_id = PersonInfoManager.get_person_id(chat.platform, chat.user_info.user_id)  # type: ignore
        info = WillingInfo(
            message=message,
            chat=chat,
            person_info_manager=get_person_info_manager(),
            chat_id=chat.stream_id,
            person_id=sender_id,
            group_info=chat.group_info,
            is_mentioned_bot=message.get("is_mentioned", False),
            is_emoji=message.get("is_emoji", False),
            is_picid=message.get("is_picid", False),
            # A missing or falsy interest value counts as 0.0.
            interested_rate=message.get("interest_value") or 0.0,
        )
        self.ongoing_messages[message.get("message_id", "")] = info

    def delete(self, message_id: str):
        """Forget an in-progress message; an unknown id is only logged."""
        removed = self.ongoing_messages.pop(message_id, None)
        if not removed:
            logger.debug(f"尝试删除不存在的消息 ID: {message_id},可能已被其他流程处理,喵~")

    @abstractmethod
    async def async_task_starter(self) -> None:
        """Abstract: start the manager's asynchronous background tasks."""

    @abstractmethod
    async def before_generate_reply_handle(self, message_id: str):
        """Abstract: hook run before a reply is generated."""

    @abstractmethod
    async def after_generate_reply_handle(self, message_id: str):
        """Abstract: hook run after a reply has been generated."""

    @abstractmethod
    async def not_reply_handle(self, message_id: str):
        """Abstract: hook run when the bot does not reply."""

    @abstractmethod
    async def get_reply_probability(self, message_id: str):
        """Abstract: return the reply probability for a message."""
        raise NotImplementedError

    async def get_willing(self, chat_id: str):
        """Return the reply willingness of a chat stream (0 when unknown)."""
        async with self.lock:
            return self.chat_reply_willing.get(chat_id, 0)

    async def set_willing(self, chat_id: str, willing: float):
        """Set the reply willingness of a chat stream."""
        async with self.lock:
            self.chat_reply_willing[chat_id] = willing

    # @abstractmethod
    # async def get_variable_parameters(self) -> Dict[str, str]:
    #     """Abstract: get the tunable parameters"""
    #     pass

    # @abstractmethod
    # async def set_variable_parameters(self, parameters: Dict[str, any]):
    #     """Abstract: set the tunable parameters"""
    #     pass
def init_willing_manager() -> BaseWillingManager:
    """
    Build the WillingManager selected by the configuration.

    Returns:
        The WillingManager instance for the configured (lower-cased) mode.
    """
    configured_mode = global_config.normal_chat.willing_mode
    return BaseWillingManager.create(configured_mode.lower())
# Global willing_manager object, created on first access.
willing_manager = None


def get_willing_manager():
    """Return the shared willing manager, initialising it on the first call."""
    global willing_manager
    if willing_manager is not None:
        return willing_manager
    willing_manager = init_willing_manager()
    return willing_manager

View File

@@ -79,6 +79,8 @@ class LLMUsage(BaseModel):
""" """
model_name = TextField(index=True) # 添加索引 model_name = TextField(index=True) # 添加索引
model_assign_name = TextField(null=True) # 添加索引
model_api_provider = TextField(null=True) # 添加索引
user_id = TextField(index=True) # 添加索引 user_id = TextField(index=True) # 添加索引
request_type = TextField(index=True) # 添加索引 request_type = TextField(index=True) # 添加索引
endpoint = TextField() endpoint = TextField()
@@ -86,6 +88,7 @@ class LLMUsage(BaseModel):
completion_tokens = IntegerField() completion_tokens = IntegerField()
total_tokens = IntegerField() total_tokens = IntegerField()
cost = DoubleField() cost = DoubleField()
time_cost = DoubleField(null=True)
status = TextField() status = TextField()
timestamp = DateTimeField(index=True) # 更改为 DateTimeField 并添加索引 timestamp = DateTimeField(index=True) # 更改为 DateTimeField 并添加索引
@@ -130,6 +133,9 @@ class Messages(BaseModel):
reply_to = TextField(null=True) reply_to = TextField(null=True)
interest_value = DoubleField(null=True) interest_value = DoubleField(null=True)
key_words = TextField(null=True)
key_words_lite = TextField(null=True)
is_mentioned = BooleanField(null=True) is_mentioned = BooleanField(null=True)
# 从 chat_info 扁平化而来的字段 # 从 chat_info 扁平化而来的字段
@@ -146,14 +152,13 @@ class Messages(BaseModel):
chat_info_last_active_time = DoubleField() chat_info_last_active_time = DoubleField()
# 从顶层 user_info 扁平化而来的字段 (消息发送者信息) # 从顶层 user_info 扁平化而来的字段 (消息发送者信息)
user_platform = TextField() user_platform = TextField(null=True)
user_id = TextField() user_id = TextField(null=True)
user_nickname = TextField() user_nickname = TextField(null=True)
user_cardname = TextField(null=True) user_cardname = TextField(null=True)
processed_plain_text = TextField(null=True) # 处理后的纯文本消息 processed_plain_text = TextField(null=True) # 处理后的纯文本消息
display_message = TextField(null=True) # 显示的消息 display_message = TextField(null=True) # 显示的消息
memorized_times = IntegerField(default=0) # 被记忆的次数
priority_mode = TextField(null=True) priority_mode = TextField(null=True)
priority_info = TextField(null=True) priority_info = TextField(null=True)
@@ -162,6 +167,9 @@ class Messages(BaseModel):
is_emoji = BooleanField(default=False) is_emoji = BooleanField(default=False)
is_picid = BooleanField(default=False) is_picid = BooleanField(default=False)
is_command = BooleanField(default=False) is_command = BooleanField(default=False)
is_notify = BooleanField(default=False)
selected_expressions = TextField(null=True)
class Meta: class Meta:
# database = db # 继承自 BaseModel # database = db # 继承自 BaseModel
@@ -247,28 +255,60 @@ class PersonInfo(BaseModel):
用于存储个人信息数据的模型。 用于存储个人信息数据的模型。
""" """
is_known = BooleanField(default=False) # 是否已认识
person_id = TextField(unique=True, index=True) # 个人唯一ID person_id = TextField(unique=True, index=True) # 个人唯一ID
person_name = TextField(null=True) # 个人名称 (允许为空) person_name = TextField(null=True) # 个人名称 (允许为空)
name_reason = TextField(null=True) # 名称设定的原因 name_reason = TextField(null=True) # 名称设定的原因
platform = TextField() # 平台 platform = TextField() # 平台
user_id = TextField(index=True) # 用户ID user_id = TextField(index=True) # 用户ID
nickname = TextField() # 用户昵称 nickname = TextField(null=True) # 用户昵称
impression = TextField(null=True) # 个人印象 memory_points = TextField(null=True) # 个人印象的点
short_impression = TextField(null=True) # 个人印象的简短描述
points = TextField(null=True) # 个人印象的点
forgotten_points = TextField(null=True) # 被遗忘的点
info_list = TextField(null=True) # 与Bot的互动
know_times = FloatField(null=True) # 认识时间 (时间戳) know_times = FloatField(null=True) # 认识时间 (时间戳)
know_since = FloatField(null=True) # 首次印象总结时间 know_since = FloatField(null=True) # 首次印象总结时间
last_know = FloatField(null=True) # 最后一次印象总结时间 last_know = FloatField(null=True) # 最后一次印象总结时间
attitude = IntegerField(null=True, default=50) # 态度0-100从非常厌恶到十分喜欢
attitude_to_me = TextField(null=True) # 对bot的态度
attitude_to_me_confidence = FloatField(null=True) # 对bot的态度置信度
friendly_value = FloatField(null=True) # 对bot的友好程度
friendly_value_confidence = FloatField(null=True) # 对bot的友好程度置信度
rudeness = TextField(null=True) # 对bot的冒犯程度
rudeness_confidence = FloatField(null=True) # 对bot的冒犯程度置信度
neuroticism = TextField(null=True) # 对bot的神经质程度
neuroticism_confidence = FloatField(null=True) # 对bot的神经质程度置信度
conscientiousness = TextField(null=True) # 对bot的尽责程度
conscientiousness_confidence = FloatField(null=True) # 对bot的尽责程度置信度
likeness = TextField(null=True) # 对bot的相似程度
likeness_confidence = FloatField(null=True) # 对bot的相似程度置信度
class Meta: class Meta:
# database = db # 继承自 BaseModel # database = db # 继承自 BaseModel
table_name = "person_info" table_name = "person_info"
class GroupInfo(BaseModel):
    """
    Model storing group information.

    NOTE(review): this model does not appear in the model lists of
    sync_field_constraints / check_field_constraints in this file — confirm
    whether that is intentional.
    """

    group_id = TextField(unique=True, index=True)  # unique group ID
    group_name = TextField(null=True)  # group name (nullable)
    platform = TextField()  # platform
    group_impression = TextField(null=True)  # impression of the group
    member_list = TextField(null=True)  # member list (JSON format)
    topic = TextField(null=True)  # basic group information
    create_time = FloatField(null=True)  # creation time (timestamp)
    last_active = FloatField(null=True)  # last active time
    member_count = IntegerField(null=True, default=0)  # number of members

    class Meta:
        # database = db  # inherited from BaseModel
        table_name = "group_info"
class Memory(BaseModel): class Memory(BaseModel):
memory_id = TextField(index=True) memory_id = TextField(index=True)
chat_id = TextField(null=True) chat_id = TextField(null=True)
@@ -281,20 +321,6 @@ class Memory(BaseModel):
table_name = "memory" table_name = "memory"
class Knowledges(BaseModel):
    """
    Model storing knowledge-base entries.
    """

    content = TextField()  # text of the knowledge entry
    embedding = TextField()  # embedding vector, stored as a JSON string holding a list of floats
    # Further metadata fields (source, create_time, ...) could be added here.

    class Meta:
        # database = db  # inherited from BaseModel
        table_name = "knowledges"
class Expression(BaseModel): class Expression(BaseModel):
""" """
用于存储表达风格的模型。 用于存储表达风格的模型。
@@ -311,31 +337,6 @@ class Expression(BaseModel):
class Meta: class Meta:
table_name = "expression" table_name = "expression"
class ThinkingLog(BaseModel):
    """Model storing one thinking-log entry per row (table ``thinking_logs``)."""

    chat_id = TextField(index=True)
    trigger_text = TextField(null=True)
    response_text = TextField(null=True)

    # Store complex dicts/lists as JSON strings
    trigger_info_json = TextField(null=True)
    response_info_json = TextField(null=True)
    timing_results_json = TextField(null=True)
    chat_history_json = TextField(null=True)
    chat_history_in_thinking_json = TextField(null=True)
    chat_history_after_response_json = TextField(null=True)
    heartflow_data_json = TextField(null=True)
    reasoning_data_json = TextField(null=True)

    # Add a timestamp for the log entry itself
    # Ensure you have: from peewee import DateTimeField
    # And: import datetime
    # The callable (not its result) is passed, so the default is evaluated per row.
    created_at = DateTimeField(default=datetime.datetime.now)

    class Meta:
        table_name = "thinking_logs"
class GraphNodes(BaseModel): class GraphNodes(BaseModel):
""" """
用于存储记忆图节点的模型 用于存储记忆图节点的模型
@@ -343,6 +344,7 @@ class GraphNodes(BaseModel):
concept = TextField(unique=True, index=True) # 节点概念 concept = TextField(unique=True, index=True) # 节点概念
memory_items = TextField() # JSON格式存储的记忆列表 memory_items = TextField() # JSON格式存储的记忆列表
weight = FloatField(default=0.0) # 节点权重
hash = TextField() # 节点哈希值 hash = TextField() # 节点哈希值
created_time = FloatField() # 创建时间戳 created_time = FloatField() # 创建时间戳
last_modified = FloatField() # 最后修改时间戳 last_modified = FloatField() # 最后修改时间戳
@@ -382,9 +384,7 @@ def create_tables():
ImageDescriptions, ImageDescriptions,
OnlineTime, OnlineTime,
PersonInfo, PersonInfo,
Knowledges,
Expression, Expression,
ThinkingLog,
GraphNodes, # 添加图节点表 GraphNodes, # 添加图节点表
GraphEdges, # 添加图边表 GraphEdges, # 添加图边表
Memory, Memory,
@@ -393,10 +393,14 @@ def create_tables():
) )
def initialize_database(): def initialize_database(sync_constraints=False):
""" """
检查所有定义的表是否存在,如果不存在则创建它们。 检查所有定义的表是否存在,如果不存在则创建它们。
检查所有表的所有字段是否存在,如果缺失则自动添加。 检查所有表的所有字段是否存在,如果缺失则自动添加。
Args:
sync_constraints (bool): 是否同步字段约束。默认为 False。
如果为 True会检查并修复字段的 NULL 约束不一致问题。
""" """
models = [ models = [
@@ -408,10 +412,8 @@ def initialize_database():
ImageDescriptions, ImageDescriptions,
OnlineTime, OnlineTime,
PersonInfo, PersonInfo,
Knowledges,
Expression, Expression,
Memory, Memory,
ThinkingLog,
GraphNodes, GraphNodes,
GraphEdges, GraphEdges,
ActionRecords, # 添加 ActionRecords 到初始化列表 ActionRecords, # 添加 ActionRecords 到初始化列表
@@ -478,6 +480,13 @@ def initialize_database():
logger.info(f"字段 '{field_name}' 删除成功") logger.info(f"字段 '{field_name}' 删除成功")
except Exception as e: except Exception as e:
logger.error(f"删除字段 '{field_name}' 失败: {e}") logger.error(f"删除字段 '{field_name}' 失败: {e}")
# 如果启用了约束同步,执行约束检查和修复
if sync_constraints:
logger.debug("开始同步数据库字段约束...")
sync_field_constraints()
logger.debug("数据库字段约束同步完成")
except Exception as e: except Exception as e:
logger.exception(f"检查表或字段是否存在时出错: {e}") logger.exception(f"检查表或字段是否存在时出错: {e}")
# 如果检查失败(例如数据库不可用),则退出 # 如果检查失败(例如数据库不可用),则退出
@@ -486,5 +495,261 @@ def initialize_database():
logger.info("数据库初始化完成") logger.info("数据库初始化完成")
def sync_field_constraints():
    """
    Align the NULL / NOT NULL constraints of existing columns with the models.

    For every known table the live schema is read with ``PRAGMA table_info``
    and compared with each model field's ``null`` flag. Tables with at least
    one mismatch are handed to ``_fix_table_constraints``. Any error is logged
    and swallowed.
    """
    models = [
        ChatStreams,
        LLMUsage,
        Emoji,
        Messages,
        Images,
        ImageDescriptions,
        OnlineTime,
        PersonInfo,
        Expression,
        Memory,
        GraphNodes,
        GraphEdges,
        ActionRecords,
    ]

    try:
        with db:
            for model in models:
                table_name = model._meta.table_name
                if not db.table_exists(model):
                    logger.warning(f"'{table_name}' 不存在,跳过约束检查")
                    continue

                logger.debug(f"检查表 '{table_name}' 的字段约束...")

                # Snapshot of the current table layout, keyed by column name.
                pragma_rows = db.execute_sql(f"PRAGMA table_info('{table_name}')").fetchall()
                current_schema = {}
                for row in pragma_rows:
                    current_schema[row[1]] = {"type": row[2], "notnull": bool(row[3]), "default": row[4]}

                constraints_to_fix = []
                for field_name, field_obj in model._meta.fields.items():
                    column = current_schema.get(field_name)
                    if column is None:
                        continue  # column does not exist, nothing to compare

                    db_not_null = column["notnull"]
                    model_allows_null = field_obj.null

                    if model_allows_null and db_not_null:
                        # Model is looser than the database: relax the column.
                        constraints_to_fix.append(
                            {
                                "field_name": field_name,
                                "field_obj": field_obj,
                                "action": "allow_null",
                                "current_constraint": "NOT NULL",
                                "target_constraint": "NULL",
                            }
                        )
                        logger.warning(f"字段 '{field_name}' 约束不一致: 模型允许NULL但数据库为NOT NULL")
                    elif not model_allows_null and not db_not_null:
                        # Model is stricter than the database: tighten the column (with care).
                        constraints_to_fix.append(
                            {
                                "field_name": field_name,
                                "field_obj": field_obj,
                                "action": "disallow_null",
                                "current_constraint": "NULL",
                                "target_constraint": "NOT NULL",
                            }
                        )
                        logger.warning(f"字段 '{field_name}' 约束不一致: 模型不允许NULL但数据库允许NULL")

                if not constraints_to_fix:
                    logger.debug(f"'{table_name}' 的字段约束已同步")
                    continue

                logger.info(f"'{table_name}' 需要修复 {len(constraints_to_fix)} 个字段约束")
                _fix_table_constraints(table_name, model, constraints_to_fix)

    except Exception as e:
        logger.exception(f"同步字段约束时出错: {e}")
def _null_replacement_sql(field_obj):
    """Return the SQL literal that replaces NULL when a column becomes NOT NULL."""
    if isinstance(field_obj, TextField):
        return "''"
    if isinstance(field_obj, (IntegerField, FloatField, DoubleField, BooleanField)):
        return "0"
    if isinstance(field_obj, DateTimeField):
        return f"'{datetime.datetime.now()}'"
    return "''"


def _restore_select_columns(model, fields, null_to_notnull_fields):
    """Build the SELECT column list, wrapping newly NOT NULL columns in COALESCE."""
    select_fields = []
    for field_name in fields:
        if field_name in null_to_notnull_fields:
            default_value = _null_replacement_sql(model._meta.fields[field_name])
            select_fields.append(f"COALESCE({field_name}, {default_value}) as {field_name}")
        else:
            select_fields.append(field_name)
    return ", ".join(select_fields)


def _fix_table_constraints(table_name, model, constraints_to_fix):
    """
    Repair the column constraints of one table.

    SQLite cannot alter a column constraint in place, so the table is rebuilt:
    copy it to a backup table, drop it, recreate it from the model, copy the
    rows back, then compare row counts. The backup is dropped only when the
    counts match.

    Args:
        table_name: Name of the table to rebuild.
        model: Model class whose current definition is used to recreate the table.
        constraints_to_fix: Dicts with the keys ``field_name``, ``action``,
            ``current_constraint`` and ``target_constraint``.

    Errors are logged and not re-raised. On failure a best-effort restore from
    the backup table is attempted.
    """
    # Bound before the try block so the recovery path below can always refer to it.
    backup_table = f"{table_name}_backup_{int(datetime.datetime.now().timestamp())}"

    try:
        logger.info(f"开始修复表 '{table_name}' 的字段约束...")

        # 1. Create the backup table.
        db.execute_sql(f"CREATE TABLE {backup_table} AS SELECT * FROM {table_name}")
        logger.info(f"已创建备份表 '{backup_table}'")

        # 2. Drop the original table.
        db.execute_sql(f"DROP TABLE {table_name}")
        logger.info(f"已删除原表 '{table_name}'")

        # 3. Recreate the table from the current model definition.
        db.create_tables([model])
        logger.info(f"已重新创建表 '{table_name}' 使用新的约束")

        # 4. Copy the rows back from the backup.
        fields = list(model._meta.fields.keys())
        fields_str = ", ".join(fields)

        # NOT NULL -> NULL columns are copied verbatim; NULL -> NOT NULL columns
        # need their existing NULLs replaced by a type-based default.
        null_to_notnull_fields = [
            constraint["field_name"] for constraint in constraints_to_fix if constraint["action"] == "disallow_null"
        ]

        select_str = fields_str
        if null_to_notnull_fields:
            logger.warning(f"字段 {null_to_notnull_fields} 将从允许NULL改为不允许NULL需要处理现有的NULL值")
            select_str = _restore_select_columns(model, fields, null_to_notnull_fields)

        insert_sql = f"INSERT INTO {table_name} ({fields_str}) SELECT {select_str} FROM {backup_table}"
        db.execute_sql(insert_sql)
        logger.info(f"已从备份表恢复数据到 '{table_name}'")

        # 5. Verify that no rows were lost.
        original_count = db.execute_sql(f"SELECT COUNT(*) FROM {backup_table}").fetchone()[0]
        new_count = db.execute_sql(f"SELECT COUNT(*) FROM {table_name}").fetchone()[0]

        if original_count == new_count:
            logger.info(f"数据完整性验证通过: {original_count} 行数据")
            db.execute_sql(f"DROP TABLE {backup_table}")
            logger.info(f"已删除备份表 '{backup_table}'")
        else:
            # Keep the backup so the data can be inspected manually.
            logger.error(f"数据完整性验证失败: 原始 {original_count} 行,新表 {new_count}")
            logger.error(f"备份表 '{backup_table}' 已保留,请手动检查")

        # Report which constraints were changed.
        for constraint in constraints_to_fix:
            logger.info(
                f"已修复字段 '{constraint['field_name']}': "
                f"{constraint['current_constraint']} -> {constraint['target_constraint']}"
            )

    except Exception as e:
        logger.exception(f"修复表 '{table_name}' 约束时出错: {e}")
        # Best-effort recovery from the backup table.
        # NOTE(review): the backup comes from CREATE TABLE ... AS SELECT, which
        # presumably carries over data only, not keys or indexes — confirm
        # against the SQLite documentation.
        try:
            if db.table_exists(backup_table):
                logger.info(f"尝试从备份表 '{backup_table}' 恢复...")
                db.execute_sql(f"DROP TABLE IF EXISTS {table_name}")
                db.execute_sql(f"ALTER TABLE {backup_table} RENAME TO {table_name}")
                logger.info(f"已从备份恢复表 '{table_name}'")
        except Exception as restore_error:
            logger.exception(f"恢复表失败: {restore_error}")
def check_field_constraints():
    """
    Report, without repairing, columns whose NULL constraint differs from the model.

    Intended as a preview of what a repair would change.

    Returns:
        A dict mapping table name to a list of inconsistency dicts (keys
        ``field_name``, ``issue``, ``model_constraint``, ``db_constraint``,
        ``recommended_action``). Tables without mismatches are omitted. If an
        error occurs it is logged and whatever was collected so far is returned.
    """
    models = [
        ChatStreams,
        LLMUsage,
        Emoji,
        Messages,
        Images,
        ImageDescriptions,
        OnlineTime,
        PersonInfo,
        Expression,
        Memory,
        GraphNodes,
        GraphEdges,
        ActionRecords,
    ]

    inconsistencies = {}

    try:
        with db:
            for model in models:
                table_name = model._meta.table_name
                if not db.table_exists(model):
                    continue

                # Current table layout, keyed by column name.
                rows = db.execute_sql(f"PRAGMA table_info('{table_name}')").fetchall()
                current_schema = {
                    row[1]: {"type": row[2], "notnull": bool(row[3]), "default": row[4]} for row in rows
                }

                table_inconsistencies = []
                for field_name, field_obj in model._meta.fields.items():
                    if field_name not in current_schema:
                        continue

                    db_not_null = current_schema[field_name]["notnull"]
                    if bool(field_obj.null) == (not db_not_null):
                        continue  # model and database agree

                    if db_not_null:
                        # The model allows NULL but the column is NOT NULL.
                        finding = {
                            "field_name": field_name,
                            "issue": "model_allows_null_but_db_not_null",
                            "model_constraint": "NULL",
                            "db_constraint": "NOT NULL",
                            "recommended_action": "allow_null",
                        }
                    else:
                        # The model forbids NULL but the column accepts it.
                        finding = {
                            "field_name": field_name,
                            "issue": "model_not_null_but_db_allows_null",
                            "model_constraint": "NOT NULL",
                            "db_constraint": "NULL",
                            "recommended_action": "disallow_null",
                        }
                    table_inconsistencies.append(finding)

                if table_inconsistencies:
                    inconsistencies[table_name] = table_inconsistencies

    except Exception as e:
        logger.exception(f"检查字段约束时出错: {e}")

    return inconsistencies
# 模块加载时调用初始化函数 # 模块加载时调用初始化函数
initialize_database() initialize_database(sync_constraints=True)

View File

@@ -5,7 +5,7 @@ import json
import threading import threading
import time import time
import structlog import structlog
import toml import tomlkit
from pathlib import Path from pathlib import Path
from typing import Callable, Optional from typing import Callable, Optional
@@ -188,24 +188,35 @@ def load_log_config(): # sourcery skip: use-contextlib-suppress
"""从配置文件加载日志设置""" """从配置文件加载日志设置"""
config_path = Path("config/bot_config.toml") config_path = Path("config/bot_config.toml")
default_config = { default_config = {
"date_style": "Y-m-d H:i:s", "date_style": "m-d H:i:s",
"log_level_style": "lite", "log_level_style": "lite",
"color_text": "title", "color_text": "full",
"log_level": "INFO", # 全局日志级别(向下兼容) "log_level": "INFO", # 全局日志级别(向下兼容)
"console_log_level": "INFO", # 控制台日志级别 "console_log_level": "INFO", # 控制台日志级别
"file_log_level": "DEBUG", # 文件日志级别 "file_log_level": "DEBUG", # 文件日志级别
"suppress_libraries": [], "suppress_libraries": [
"library_log_levels": {}, "faiss",
"httpx",
"urllib3",
"asyncio",
"websockets",
"httpcore",
"requests",
"peewee",
"openai",
"uvicorn",
"jieba",
],
"library_log_levels": {"aiohttp": "WARNING"},
} }
try: try:
if config_path.exists(): if config_path.exists():
with open(config_path, "r", encoding="utf-8") as f: with open(config_path, "r", encoding="utf-8") as f:
config = toml.load(f) config = tomlkit.load(f)
return config.get("log", default_config) return config.get("log", default_config)
except Exception: except Exception as e:
pass print(f"[日志系统] 加载日志配置失败: {e}")
return default_config return default_config
@@ -334,7 +345,7 @@ MODULE_COLORS = {
"llm_models": "\033[36m", # 青色 "llm_models": "\033[36m", # 青色
"remote": "\033[38;5;242m", # 深灰色,更不显眼 "remote": "\033[38;5;242m", # 深灰色,更不显眼
"planner": "\033[36m", "planner": "\033[36m",
"memory": "\033[34m", "memory": "\033[38;5;117m", # 天蓝色
"hfc": "\033[38;5;81m", # 稍微暗一些的青色,保持可读 "hfc": "\033[38;5;81m", # 稍微暗一些的青色,保持可读
"action_manager": "\033[38;5;208m", # 橙色不与replyer重复 "action_manager": "\033[38;5;208m", # 橙色不与replyer重复
# 关系系统 # 关系系统
@@ -352,7 +363,7 @@ MODULE_COLORS = {
"expressor": "\033[38;5;166m", # 橙色 "expressor": "\033[38;5;166m", # 橙色
# 专注聊天模块 # 专注聊天模块
"replyer": "\033[38;5;166m", # 橙色 "replyer": "\033[38;5;166m", # 橙色
"memory_activator": "\033[34m", # 绿 "memory_activator": "\033[38;5;117m", # 天蓝
# 插件系统 # 插件系统
"plugins": "\033[31m", # 红色 "plugins": "\033[31m", # 红色
"plugin_api": "\033[33m", # 黄色 "plugin_api": "\033[33m", # 黄色
@@ -390,7 +401,7 @@ MODULE_COLORS = {
"tts_action": "\033[38;5;58m", # 深黄色 "tts_action": "\033[38;5;58m", # 深黄色
"doubao_pic_plugin": "\033[38;5;64m", # 深绿色 "doubao_pic_plugin": "\033[38;5;64m", # 深绿色
# Action组件 # Action组件
"no_reply_action": "\033[38;5;214m", # 亮橙色,显眼但不像警告 "no_action_action": "\033[38;5;214m", # 亮橙色,显眼但不像警告
"reply_action": "\033[38;5;46m", # 亮绿色 "reply_action": "\033[38;5;46m", # 亮绿色
"base_action": "\033[38;5;250m", # 浅灰色 "base_action": "\033[38;5;250m", # 浅灰色
# 数据库和消息 # 数据库和消息
@@ -403,8 +414,7 @@ MODULE_COLORS = {
"model_utils": "\033[38;5;164m", # 紫红色 "model_utils": "\033[38;5;164m", # 紫红色
"relationship_fetcher": "\033[38;5;170m", # 浅紫色 "relationship_fetcher": "\033[38;5;170m", # 浅紫色
"relationship_builder": "\033[38;5;93m", # 浅蓝色 "relationship_builder": "\033[38;5;93m", # 浅蓝色
# s4u
#s4u
"context_web_api": "\033[38;5;240m", # 深灰色 "context_web_api": "\033[38;5;240m", # 深灰色
"S4U_chat": "\033[92m", # 深灰色 "S4U_chat": "\033[92m", # 深灰色
} }
@@ -414,7 +424,7 @@ MODULE_ALIASES = {
# 示例映射 # 示例映射
"individuality": "人格特质", "individuality": "人格特质",
"emoji": "表情包", "emoji": "表情包",
"no_reply_action": "摸鱼", "no_action_action": "摸鱼",
"reply_action": "回复", "reply_action": "回复",
"action_manager": "动作", "action_manager": "动作",
"memory_activator": "记忆", "memory_activator": "记忆",
@@ -440,6 +450,37 @@ MODULE_ALIASES = {
RESET_COLOR = "\033[0m" RESET_COLOR = "\033[0m"
def convert_pathname_to_module(logger, method_name, event_dict):
    # sourcery skip: extract-method, use-string-remove-affix
    """Replace the ``pathname`` entry of an event with a dotted ``module`` path.

    Args:
        logger: Unused; part of the processor call signature.
        method_name: Unused; part of the processor call signature.
        event_dict: The event dictionary; it is modified in place.

    Returns:
        The same ``event_dict``. When it has a ``pathname`` key, that key is
        removed and ``module`` is set to the path relative to the project root
        (three levels above this file), with separators turned into dots and a
        trailing ``.py`` dropped. If the path cannot be made relative, an
        existing ``module`` is kept; otherwise the file stem is used.
    """
    if "pathname" not in event_dict:
        return event_dict

    pathname = event_dict.pop("pathname")
    try:
        # Resolve both sides to absolute paths so relative_to compares like with like.
        project_root = Path(__file__).resolve().parent.parent.parent
        rel_path = Path(pathname).resolve().relative_to(project_root)
    except Exception:
        # A processor must never break the log call, so fall back quietly.
        if "module" not in event_dict:
            try:
                event_dict["module"] = Path(pathname).stem
            except TypeError:
                # pathname is not path-like (e.g. None); the fallback itself must not raise.
                event_dict["module"] = str(pathname)
        return event_dict

    # Module style: separators become dots, the .py extension is dropped.
    module_path = rel_path.as_posix().replace("\\", ".").replace("/", ".")
    if module_path.endswith(".py"):
        module_path = module_path[:-3]
    event_dict["module"] = module_path
    return event_dict
class ModuleColoredConsoleRenderer: class ModuleColoredConsoleRenderer:
"""自定义控制台渲染器,为不同模块提供不同颜色""" """自定义控制台渲染器,为不同模块提供不同颜色"""
@@ -451,7 +492,7 @@ class ModuleColoredConsoleRenderer:
# 日志级别颜色 # 日志级别颜色
self._level_colors = { self._level_colors = {
"debug": "\033[38;5;208m", # 橙色 "debug": "\033[38;5;208m", # 橙色
"info": "\033[34m", # 蓝色 "info": "\033[38;5;117m", # 蓝色
"success": "\033[32m", # 绿色 "success": "\033[32m", # 绿色
"warning": "\033[33m", # 黄色 "warning": "\033[33m", # 黄色
"error": "\033[31m", # 红色 "error": "\033[31m", # 红色
@@ -529,7 +570,7 @@ class ModuleColoredConsoleRenderer:
if logger_name: if logger_name:
# 获取别名,如果没有别名则使用原名称 # 获取别名,如果没有别名则使用原名称
display_name = MODULE_ALIASES.get(logger_name, logger_name) display_name = MODULE_ALIASES.get(logger_name, logger_name)
if self._colors and self._enable_module_colors: if self._colors and self._enable_module_colors:
if module_color: if module_color:
module_part = f"{module_color}[{display_name}]{RESET_COLOR}" module_part = f"{module_color}[{display_name}]{RESET_COLOR}"
@@ -562,7 +603,7 @@ class ModuleColoredConsoleRenderer:
# 处理其他字段 # 处理其他字段
extras = [] extras = []
for key, value in event_dict.items(): for key, value in event_dict.items():
if key not in ("timestamp", "level", "logger_name", "event"): if key not in ("timestamp", "level", "logger_name", "event", "module", "lineno", "pathname"):
# 确保值也转换为字符串 # 确保值也转换为字符串
if isinstance(value, (dict, list)): if isinstance(value, (dict, list)):
try: try:
@@ -603,6 +644,13 @@ def configure_structlog():
processors=[ processors=[
structlog.contextvars.merge_contextvars, structlog.contextvars.merge_contextvars,
structlog.processors.add_log_level, structlog.processors.add_log_level,
structlog.processors.CallsiteParameterAdder(
parameters=[
structlog.processors.CallsiteParameter.MODULE,
structlog.processors.CallsiteParameter.LINENO,
]
),
convert_pathname_to_module,
structlog.processors.StackInfoRenderer(), structlog.processors.StackInfoRenderer(),
structlog.dev.set_exc_info, structlog.dev.set_exc_info,
structlog.processors.TimeStamper(fmt=get_timestamp_format(), utc=False), structlog.processors.TimeStamper(fmt=get_timestamp_format(), utc=False),
@@ -627,6 +675,10 @@ file_formatter = structlog.stdlib.ProcessorFormatter(
structlog.stdlib.add_log_level, structlog.stdlib.add_log_level,
structlog.stdlib.PositionalArgumentsFormatter(), structlog.stdlib.PositionalArgumentsFormatter(),
structlog.processors.TimeStamper(fmt="iso"), structlog.processors.TimeStamper(fmt="iso"),
structlog.processors.CallsiteParameterAdder(
parameters=[structlog.processors.CallsiteParameter.MODULE, structlog.processors.CallsiteParameter.LINENO]
),
convert_pathname_to_module,
structlog.processors.StackInfoRenderer(), structlog.processors.StackInfoRenderer(),
structlog.processors.format_exc_info, structlog.processors.format_exc_info,
], ],
@@ -706,181 +758,6 @@ def get_logger(name: Optional[str]) -> structlog.stdlib.BoundLogger:
return logger return logger
def configure_logging(
    level: str = "INFO",
    console_level: Optional[str] = None,
    file_level: Optional[str] = None,
    max_bytes: int = 5 * 1024 * 1024,
    backup_count: int = 30,
    log_dir: str = "logs",
):
    """Adjust the logging parameters at runtime.

    Args:
        level: Fallback level name used when a handler-specific level is not given.
        console_level: Level name for the console handler; None leaves it unchanged.
        file_level: Level name for the file handler; None leaves it unchanged.
        max_bytes: Value assigned to the file handler's ``max_bytes``.
        backup_count: Value assigned to the file handler's ``backup_count``.
        log_dir: Directory for log files; created (with parents) if missing.

    Unknown level names fall back to ``logging.INFO`` everywhere.
    """
    log_path = Path(log_dir)
    # parents=True so a nested directory such as "var/logs" can be created too.
    log_path.mkdir(parents=True, exist_ok=True)

    # Update the file handler settings.
    file_handler = get_file_handler()
    if file_handler and isinstance(file_handler, TimestampedFileHandler):
        file_handler.max_bytes = max_bytes
        file_handler.backup_count = backup_count
        file_handler.log_dir = log_path

        if file_level:
            file_handler.setLevel(getattr(logging, file_level.upper(), logging.INFO))

    # Update the console handler level.
    console_handler = get_console_handler()
    if console_handler and console_level:
        console_handler.setLevel(getattr(logging, console_level.upper(), logging.INFO))

    # The root logger must let through everything the most verbose handler
    # wants. With no handler-specific level both lookups resolve to `level`.
    console_level_num = getattr(logging, (console_level or level).upper(), logging.INFO)
    file_level_num = getattr(logging, (file_level or level).upper(), logging.INFO)
    logging.getLogger().setLevel(min(console_level_num, file_level_num))
def reload_log_config():
    """Reload the log configuration and re-apply it to the live handlers."""
    global LOG_CONFIG
    LOG_CONFIG = load_log_config()

    file_handler = get_file_handler()
    if file_handler:
        file_level = LOG_CONFIG.get("file_log_level", LOG_CONFIG.get("log_level", "INFO"))
        file_handler.setLevel(getattr(logging, file_level.upper(), logging.INFO))

    console_handler = get_console_handler()
    if console_handler:
        console_level = LOG_CONFIG.get("console_log_level", LOG_CONFIG.get("log_level", "INFO"))
        console_handler.setLevel(getattr(logging, console_level.upper(), logging.INFO))

    # Give every stream handler on the root logger a freshly built console formatter.
    # NOTE(review): logging.FileHandler is a StreamHandler subclass, so a stdlib
    # file handler on the root logger would match as well — confirm the type of
    # the file handler used here.
    for handler in logging.getLogger().handlers:
        if not isinstance(handler, logging.StreamHandler):
            continue
        renderer = ModuleColoredConsoleRenderer(colors=True)
        pre_chain = [
            structlog.stdlib.add_logger_name,
            structlog.stdlib.add_log_level,
            structlog.stdlib.PositionalArgumentsFormatter(),
            structlog.processors.TimeStamper(fmt=get_timestamp_format(), utc=False),
            structlog.processors.StackInfoRenderer(),
            structlog.processors.format_exc_info,
        ]
        handler.setFormatter(structlog.stdlib.ProcessorFormatter(processor=renderer, foreign_pre_chain=pre_chain))

    # Re-apply the third-party library settings, then the existing loggers.
    configure_third_party_loggers()
    reconfigure_existing_loggers()
def get_log_config():
    """Return a copy of the current log configuration."""
    snapshot = LOG_CONFIG.copy()
    return snapshot
def set_console_log_level(level: str):
    """Set the console log level.

    Args:
        level: Log level ("DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL").
            Unknown names make the handler fall back to INFO.
    """
    global LOG_CONFIG
    level_name = level.upper()
    LOG_CONFIG["console_log_level"] = level_name

    console_handler = get_console_handler()
    if console_handler:
        console_handler.setLevel(getattr(logging, level_name, logging.INFO))

    # Re-apply the third-party logger settings after the change.
    configure_third_party_loggers()

    get_logger("logger").info(f"控制台日志级别已设置为: {level_name}")
def set_file_log_level(level: str):
    """Set the file log level.

    Args:
        level: Log level ("DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL").
            Unknown names make the handler fall back to INFO.
    """
    global LOG_CONFIG
    level_name = level.upper()
    LOG_CONFIG["file_log_level"] = level_name

    file_handler = get_file_handler()
    if file_handler:
        file_handler.setLevel(getattr(logging, level_name, logging.INFO))

    # Re-apply the third-party logger settings after the change.
    configure_third_party_loggers()

    get_logger("logger").info(f"文件日志级别已设置为: {level_name}")
def get_current_log_levels():
    """Return the level names of the console handler, file handler and root logger."""

    def _level_name(handler):
        # A missing handler is reported as "UNKNOWN".
        return logging.getLevelName(handler.level) if handler else "UNKNOWN"

    file_handler = get_file_handler()
    console_handler = get_console_handler()
    root_level = logging.getLevelName(logging.getLogger().level)

    return {
        "console_level": _level_name(console_handler),
        "file_level": _level_name(file_handler),
        "root_level": root_level,
    }
def force_reset_all_loggers():
    """Forcibly reset every logger to fix inconsistent output formats.

    The steps run in a fixed order: close the current handlers, clear the
    logger registry, detach the root logger's handlers, then re-attach the
    file and console handlers with their formatters and set the root level.
    """
    # Close the existing handlers first.
    close_handlers()

    # Drop every logger registered so far.
    logging.getLogger().manager.loggerDict.clear()

    # Reconfigure the root logger.
    root_logger = logging.getLogger()
    root_logger.handlers.clear()

    # Use the singleton handlers to avoid creating duplicates.
    file_handler = get_file_handler()
    console_handler = get_console_handler()

    # Re-attach our handlers.
    root_logger.addHandler(file_handler)
    root_logger.addHandler(console_handler)

    # Set the formatters.
    file_handler.setFormatter(file_formatter)
    console_handler.setFormatter(console_formatter)

    # The root level is the lowest (most verbose) of the handler levels.
    # Each handler-specific key falls back to "log_level", then to "INFO".
    console_level = LOG_CONFIG.get("console_log_level", LOG_CONFIG.get("log_level", "INFO"))
    file_level = LOG_CONFIG.get("file_log_level", LOG_CONFIG.get("log_level", "INFO"))

    console_level_num = getattr(logging, console_level.upper(), logging.INFO)
    file_level_num = getattr(logging, file_level.upper(), logging.INFO)
    min_level = min(console_level_num, file_level_num)

    root_logger.setLevel(min_level)
def initialize_logging(): def initialize_logging():
"""手动初始化日志系统确保所有logger都使用正确的配置 """手动初始化日志系统确保所有logger都使用正确的配置
@@ -888,6 +765,7 @@ def initialize_logging():
""" """
global LOG_CONFIG global LOG_CONFIG
LOG_CONFIG = load_log_config() LOG_CONFIG = load_log_config()
# print(LOG_CONFIG)
configure_third_party_loggers() configure_third_party_loggers()
reconfigure_existing_loggers() reconfigure_existing_loggers()
@@ -899,77 +777,10 @@ def initialize_logging():
console_level = LOG_CONFIG.get("console_log_level", LOG_CONFIG.get("log_level", "INFO")) console_level = LOG_CONFIG.get("console_log_level", LOG_CONFIG.get("log_level", "INFO"))
file_level = LOG_CONFIG.get("file_log_level", LOG_CONFIG.get("log_level", "INFO")) file_level = LOG_CONFIG.get("file_log_level", LOG_CONFIG.get("log_level", "INFO"))
logger.info("日志系统已重新初始化:") logger.info("日志系统已初始化:")
logger.info(f" - 控制台级别: {console_level}") logger.info(f" - 控制台级别: {console_level}")
logger.info(f" - 文件级别: {file_level}") logger.info(f" - 文件级别: {file_level}")
logger.info(" - 轮转份数: 30个文件") logger.info(" - 轮转份数: 30个文件|自动清理: 30天前的日志")
logger.info(" - 自动清理: 30天前的日志")
def force_initialize_logging():
    """Forcibly re-initialise the whole logging system to fix inconsistent formats."""
    global LOG_CONFIG
    LOG_CONFIG = load_log_config()

    # Reset all loggers, then rebuild structlog, then the third-party loggers.
    for step in (force_reset_all_loggers, configure_structlog, configure_third_party_loggers):
        step()

    # Report the resulting configuration.
    logger = get_logger("logger")
    fallback_level = LOG_CONFIG.get("log_level", "INFO")
    console_level = LOG_CONFIG.get("console_log_level", fallback_level)
    file_level = LOG_CONFIG.get("file_log_level", fallback_level)
    logger.info(
        f"日志系统已强制重新初始化,控制台级别: {console_level},文件级别: {file_level},轮转份数: 30个文件所有logger格式已统一"
    )
def show_module_colors():
    """Emit one sample log line per module in MODULE_COLORS, then print the alias table."""
    get_logger("demo")

    print("\n=== 模块颜色展示 ===")
    for module_name in MODULE_COLORS:
        # Temporary logger bound to the module name so its colour is shown.
        demo_logger = structlog.get_logger(module_name).bind(logger_name=module_name)
        alias = MODULE_ALIASES.get(module_name, module_name)
        if alias == module_name:
            demo_logger.info(f"这是 {module_name} 模块的颜色效果")
        else:
            demo_logger.info(f"这是 {module_name} 模块的颜色效果 (显示为: {alias})")
    print("=== 颜色展示结束 ===\n")

    # Print the alias table only when there is something to show.
    if not MODULE_ALIASES:
        return
    print("=== 当前别名映射 ===")
    for module_name, alias in MODULE_ALIASES.items():
        print(f" {module_name} -> {alias}")
    print("=== 别名映射结束 ===\n")
def format_json_for_logging(data, indent=2, ensure_ascii=False):
    """Render *data* as pretty-printed JSON text for log output.

    Args:
        data: A JSON-serialisable object (dict, list, ...) or a string that
            contains JSON.
        indent: Number of spaces per indentation level.
        ensure_ascii: Escape non-ASCII characters when True.

    Returns:
        str: The formatted JSON string. A string that is not valid JSON is
        returned unchanged, so passing ordinary log text never raises.
    """
    if not isinstance(data, str):
        # Already an object: serialise it directly.
        return json.dumps(data, indent=indent, ensure_ascii=ensure_ascii)

    # A string is expected to hold JSON: parse it first, then re-serialise.
    try:
        parsed_data = json.loads(data)
    except json.JSONDecodeError:
        # Plain text (or an empty string): a log formatter should not crash on it.
        return data
    return json.dumps(parsed_data, indent=indent, ensure_ascii=ensure_ascii)
def cleanup_old_logs(): def cleanup_old_logs():
@@ -1007,8 +818,8 @@ def start_log_cleanup_task():
def cleanup_task(): def cleanup_task():
while True: while True:
time.sleep(24 * 60 * 60) # 每24小时执行一次
cleanup_old_logs() cleanup_old_logs()
time.sleep(24 * 60 * 60) # 每24小时执行一次
cleanup_thread = threading.Thread(target=cleanup_task, daemon=True) cleanup_thread = threading.Thread(target=cleanup_task, daemon=True)
cleanup_thread.start() cleanup_thread.start()
@@ -1017,35 +828,6 @@ def start_log_cleanup_task():
logger.info("已启动日志清理任务将自动清理30天前的日志文件轮转份数限制: 30个文件") logger.info("已启动日志清理任务将自动清理30天前的日志文件轮转份数限制: 30个文件")
def get_log_stats():
    """Collect name, size and modification time of every log file in LOG_DIR.

    Returns:
        dict: ``{"total_files": int, "total_size": int, "files": list}``.
        Each entry of ``files`` is ``{"name", "size", "modified"}`` with
        ``modified`` formatted as ``%Y-%m-%d %H:%M:%S``; entries are ordered
        newest first. If LOG_DIR does not exist the zeroed structure is
        returned. Any unexpected error is logged and whatever was gathered
        so far is returned.
    """
    stats = {"total_files": 0, "total_size": 0, "files": []}

    try:
        if not LOG_DIR.exists():
            return stats

        for log_file in LOG_DIR.glob("*.log*"):
            try:
                # Single stat() call so size and mtime describe the same snapshot.
                file_stat = log_file.stat()
            except FileNotFoundError:
                # The file was removed between glob() and stat(); skip it
                # instead of aborting the whole scan.
                continue

            file_info = {
                "name": log_file.name,
                "size": file_stat.st_size,
                "modified": datetime.fromtimestamp(file_stat.st_mtime).strftime("%Y-%m-%d %H:%M:%S"),
            }
            stats["files"].append(file_info)
            stats["total_files"] += 1
            stats["total_size"] += file_info["size"]

        # The timestamp format sorts lexicographically in chronological order.
        stats["files"].sort(key=lambda x: x["modified"], reverse=True)

    except Exception as e:
        logger = get_logger("logger")
        logger.error(f"获取日志统计信息时出错: {e}")

    return stats
def shutdown_logging(): def shutdown_logging():
"""优雅关闭日志系统,释放所有文件句柄""" """优雅关闭日志系统,释放所有文件句柄"""
logger = get_logger("logger") logger = get_logger("logger")

View File

@@ -73,6 +73,9 @@ def find_messages(
if conditions: if conditions:
query = query.where(*conditions) query = query.where(*conditions)
# 排除 id 为 "notice" 的消息
query = query.where(Messages.message_id != "notice")
if filter_bot: if filter_bot:
query = query.where(Messages.user_id != global_config.bot.qq_account) query = query.where(Messages.user_id != global_config.bot.qq_account)
@@ -167,6 +170,9 @@ def count_messages(message_filter: dict[str, Any]) -> int:
if conditions: if conditions:
query = query.where(*conditions) query = query.where(*conditions)
# 排除 id 为 "notice" 的消息
query = query.where(Messages.message_id != "notice")
count = query.count() count = query.count()
return count return count
except Exception as e: except Exception as e:

View File

@@ -0,0 +1,136 @@
from dataclasses import dataclass, field
from .config_base import ConfigBase
@dataclass
class APIProvider(ConfigBase):
    """Connection settings for a single API provider."""

    name: str
    """Provider name."""

    base_url: str
    """Base URL of the API."""

    api_key: str = field(default_factory=str, repr=False)
    """API key (one string). ``repr=False`` keeps it out of the dataclass repr."""

    client_type: str = field(default="openai")
    """Client type (e.g. openai/google); defaults to "openai"."""

    max_retry: int = 2
    """Maximum number of retries after a failed API call for a single model."""

    timeout: int = 10
    """Time limit for one API call; a longer call counts as a request timeout.
    NOTE(review): the unit is not stated in this file, presumably seconds."""

    retry_interval: int = 10
    """Wait between retries after a failed API call.
    NOTE(review): the unit is not stated in this file, presumably seconds."""

    def get_api_key(self) -> str:
        """Return the configured API key."""
        return self.api_key

    def __post_init__(self):
        """Validate the required fields.

        Raises:
            ValueError: If ``api_key`` is empty, if ``base_url`` is empty while
                ``client_type`` is not "gemini", or if ``name`` is empty.
        """
        if not self.api_key:
            raise ValueError("API密钥不能为空请在配置中设置有效的API密钥。")
        # Only the "gemini" client type is allowed to omit the base URL.
        if not self.base_url and self.client_type != "gemini":
            raise ValueError("API基础URL不能为空请在配置中设置有效的基础URL。")
        if not self.name:
            raise ValueError("API提供商名称不能为空请在配置中设置有效的名称。")
@dataclass
class ModelInfo(ConfigBase):
    """Settings describing one model."""

    model_identifier: str
    """Model identifier (used when calling the API)."""

    name: str
    """Model name (used by modules to refer to the model)."""

    api_provider: str
    """Name of the API provider, e.g. OpenAI or Azure."""

    price_in: float = field(default=0.0)
    """Input price per million tokens."""

    price_out: float = field(default=0.0)
    """Output price per million tokens."""

    force_stream_mode: bool = field(default=False)
    """Whether streaming output is forced."""

    extra_params: dict = field(default_factory=dict)
    """Extra parameters passed along with API calls."""

    def __post_init__(self):
        """Reject a model whose identifier, name or provider is empty.

        Raises:
            ValueError: For the first empty field, checked in the order
                identifier, name, provider.
        """
        required_fields = (
            (self.model_identifier, "模型标识符不能为空,请在配置中设置有效的模型标识符。"),
            (self.name, "模型名称不能为空,请在配置中设置有效的模型名称。"),
            (self.api_provider, "API提供商不能为空请在配置中设置有效的API提供商。"),
        )
        for field_value, error_message in required_fields:
            if not field_value:
                raise ValueError(error_message)
@dataclass
class TaskConfig(ConfigBase):
    """Settings for one task: which models it may use and how they are sampled."""

    model_list: list[str] = field(default_factory=list)
    """Names of the models this task uses."""

    max_tokens: int = field(default=1024)
    """Maximum number of output tokens for the task."""

    temperature: float = field(default=0.3)
    """Sampling temperature."""
@dataclass
class ModelTaskConfig(ConfigBase):
    """Per-task model settings; every field is a ``TaskConfig``."""

    utils: TaskConfig
    """Model settings for utility components."""

    utils_small: TaskConfig
    """Small-model settings for utility components."""

    replyer: TaskConfig
    """Settings for the primary reply model."""

    emotion: TaskConfig
    """Settings for the emotion model."""

    vlm: TaskConfig
    """Settings for the vision-language model."""

    voice: TaskConfig
    """Settings for the speech-recognition model."""

    tool_use: TaskConfig
    """Settings for the tool-use model."""

    planner: TaskConfig
    """Settings for the planner model."""

    embedding: TaskConfig
    """Settings for the embedding model."""

    lpmm_entity_extract: TaskConfig
    """Settings for the LPMM entity-extraction model."""

    lpmm_rdf_build: TaskConfig
    """Settings for the LPMM RDF-building model."""

    lpmm_qa: TaskConfig
    """Settings for the LPMM question-answering model."""

    def get_task(self, task_name: str) -> TaskConfig:
        """Return the configuration of the task called *task_name*.

        Args:
            task_name: Name of one of the task fields declared on this class.

        Returns:
            TaskConfig: The configuration stored under that name.

        Raises:
            ValueError: If *task_name* does not name a task configuration.
        """
        # A bare hasattr() check would also accept methods and dunder
        # attributes (e.g. "get_task", "__class__") and hand back something
        # that is not a TaskConfig, so the value itself is checked.
        task = getattr(self, task_name, None)
        if isinstance(task, TaskConfig):
            return task
        raise ValueError(f"任务 '{task_name}' 未找到对应的配置")

View File

@@ -1,162 +0,0 @@
import shutil
import tomlkit
from tomlkit.items import Table, KeyType
from pathlib import Path
from datetime import datetime
def get_key_comment(toml_table, key):
    """Return the comment attached to *key* in *toml_table*, or None.

    Three lookups are tried in order:
      1. the comment on the table's own ``trivia`` (returned regardless of *key*);
      2. the ``trivia`` of the item stored under *key* in ``toml_table.value``;
      3. the ``trivia`` of a key object yielded by ``toml_table.keys()``.
    """
    table_trivia = getattr(toml_table, "trivia", None)
    if hasattr(table_trivia, "comment"):
        return table_trivia.comment

    raw_value = getattr(toml_table, "value", None)
    if isinstance(raw_value, dict):
        entry = raw_value.get(key)
        if entry is not None and hasattr(entry, "trivia"):
            return entry.trivia.comment

    if hasattr(toml_table, "keys"):
        # NOTE(review): KeyType looks like the key-style enum rather than the
        # key class, so this match may never succeed. Confirm against tomlkit.
        matches = (k for k in toml_table.keys() if isinstance(k, KeyType) and k.key == key)
        hit = next(matches, None)
        if hit is not None:
            return hit.trivia.comment
    return None
def compare_dicts(new, old, path=None, new_comments=None, old_comments=None, logs=None):
    """Recursively list keys added to or removed from *new* relative to *old*.

    Args:
        new: The newer mapping (template).
        old: The older mapping (existing config).
        path: Keys leading to the current nesting level.
        new_comments: Accepted and passed down the recursion; not otherwise used.
        old_comments: Accepted and passed down the recursion; not otherwise used.
        logs: Accumulator shared across recursive calls.

    Returns:
        list[str]: One line per added or removed key, including its comment.
        Keys named "version" are ignored.
    """
    path = [] if path is None else path
    logs = [] if logs is None else logs
    new_comments = {} if new_comments is None else new_comments
    old_comments = {} if old_comments is None else old_comments

    # Added keys; tables present on both sides are descended into.
    for key in new:
        if key == "version":
            continue
        if key in old:
            if isinstance(new[key], (dict, Table)) and isinstance(old.get(key), (dict, Table)):
                compare_dicts(new[key], old[key], path + [str(key)], new_comments, old_comments, logs)
            continue
        comment = get_key_comment(new, key)
        dotted = ".".join(path + [str(key)])
        logs.append(f"新增: {dotted} 注释: {comment or ''}")

    # Removed keys at this level.
    for key in old:
        if key == "version" or key in new:
            continue
        comment = get_key_comment(old, key)
        dotted = ".".join(path + [str(key)])
        logs.append(f"删减: {dotted} 注释: {comment or ''}")
    return logs
def update_config():
    """Regenerate config/bot_config.toml from the template, keeping old values.

    Steps, in order:
      1. Move any existing bot_config.toml into config/old/ under a
         timestamped name.
      2. Copy template/bot_config_template.toml to config/bot_config.toml.
      3. If both files carry the same ``inner.version``, move the backup back
         over the fresh copy and return.
      4. Otherwise print added/removed keys and merge the old values into the
         new document, then write it back.

    All progress is reported with print().
    """
    print("开始更新配置文件...")
    # Project root: four directory levels above this file.
    root_dir = Path(__file__).parent.parent.parent.parent
    template_dir = root_dir / "template"
    config_dir = root_dir / "config"
    old_config_dir = config_dir / "old"
    # Create the backup directory if needed (parents are not created here).
    old_config_dir.mkdir(exist_ok=True)
    # File paths. Note that old_config_path and new_config_path are the same file.
    template_path = template_dir / "bot_config_template.toml"
    old_config_path = config_dir / "bot_config.toml"
    new_config_path = config_dir / "bot_config.toml"
    # Read the existing config, if there is one.
    old_config = {}
    if old_config_path.exists():
        print(f"发现旧配置文件: {old_config_path}")
        with open(old_config_path, "r", encoding="utf-8") as f:
            old_config = tomlkit.load(f)
        # Timestamped name for the backup copy.
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        old_backup_path = old_config_dir / f"bot_config_{timestamp}.toml"
        # Move the existing config into the backup directory.
        shutil.move(old_config_path, old_backup_path)
        print(f"已备份旧配置文件到: {old_backup_path}")
    # Copy the template into the config directory.
    print(f"从模板文件创建新配置: {template_path}")
    shutil.copy2(template_path, new_config_path)
    # Load the fresh copy.
    with open(new_config_path, "r", encoding="utf-8") as f:
        new_config = tomlkit.load(f)
    # Compare versions; old_backup_path is only bound when an old config existed,
    # which is also the only case in which old_config is truthy here.
    if old_config and "inner" in old_config and "inner" in new_config:
        old_version = old_config["inner"].get("version")  # type: ignore
        new_version = new_config["inner"].get("version")  # type: ignore
        if old_version and new_version and old_version == new_version:
            print(f"检测到版本号相同 (v{old_version}),跳过更新")
            # Same version: put the backup back in place and stop.
            shutil.move(old_backup_path, old_config_path)  # type: ignore
            return
        else:
            print(f"检测到版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}")
    # Report added and removed keys together with their comments.
    if old_config:
        print("配置项变动如下:")
        logs = compare_dicts(new_config, old_config)
        if logs:
            for log in logs:
                print(log)
        else:
            print("无新增或删减项")

    # Recursive merge: copy values from source into keys that exist in target.
    def update_dict(target, source):
        for key, value in source.items():
            # Never carry the version field over.
            if key == "version":
                continue
            if key in target:
                if isinstance(value, dict) and isinstance(target[key], (dict, Table)):
                    update_dict(target[key], value)
                else:
                    try:
                        # Lists get special handling.
                        if isinstance(value, list):
                            # An empty list stays an empty TOML array.
                            if not value:
                                target[key] = tomlkit.array()
                            else:
                                # Regex lists and structures containing regexes are assigned as-is.
                                if key == "ban_msgs_regex":
                                    # Use the original value without further processing.
                                    target[key] = value
                                elif key == "regex_rules":
                                    # regex_rules entries carry a regex field; assigned as-is as well.
                                    target[key] = value
                                else:
                                    # Does the list hold dict items with a "regex" key?
                                    contains_regex = False
                                    if value and isinstance(value[0], dict) and "regex" in value[0]:
                                        contains_regex = True
                                    # NOTE(review): tomlkit.array() is given str(value), i.e. the
                                    # Python rendering of the list - presumably parsed as TOML;
                                    # confirm this round-trips for all element types.
                                    target[key] = value if contains_regex else tomlkit.array(str(value))
                        else:
                            # Every other type is wrapped with tomlkit.item().
                            target[key] = tomlkit.item(value)
                    except (TypeError, ValueError):
                        # Conversion failed: fall back to plain assignment.
                        target[key] = value

    # Merge the old values into the new document.
    print("开始合并新旧配置...")
    update_dict(new_config, old_config)
    # Write the merged document back (tomlkit keeps comments and layout).
    with open(new_config_path, "w", encoding="utf-8") as f:
        f.write(tomlkit.dumps(new_config))
    print("配置文件更新完成")
# Run the update only when this file is executed as a script.
if __name__ == "__main__":
    update_config()

View File

@@ -1,12 +1,14 @@
import os import os
import tomlkit import tomlkit
import shutil import shutil
import sys
from datetime import datetime from datetime import datetime
from tomlkit import TOMLDocument from tomlkit import TOMLDocument
from tomlkit.items import Table, KeyType from tomlkit.items import Table, KeyType
from dataclasses import field, dataclass from dataclasses import field, dataclass
from rich.traceback import install from rich.traceback import install
from typing import List, Optional
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.config_base import ConfigBase from src.config.config_base import ConfigBase
@@ -15,7 +17,6 @@ from src.config.official_configs import (
PersonalityConfig, PersonalityConfig,
ExpressionConfig, ExpressionConfig,
ChatConfig, ChatConfig,
NormalChatConfig,
EmojiConfig, EmojiConfig,
MemoryConfig, MemoryConfig,
MoodConfig, MoodConfig,
@@ -25,7 +26,6 @@ from src.config.official_configs import (
ResponseSplitterConfig, ResponseSplitterConfig,
TelemetryConfig, TelemetryConfig,
ExperimentalConfig, ExperimentalConfig,
ModelConfig,
MessageReceiveConfig, MessageReceiveConfig,
MaimMessageConfig, MaimMessageConfig,
LPMMKnowledgeConfig, LPMMKnowledgeConfig,
@@ -36,6 +36,13 @@ from src.config.official_configs import (
CustomPromptConfig, CustomPromptConfig,
) )
from .api_ada_configs import (
ModelTaskConfig,
ModelInfo,
APIProvider,
)
install(extra_lines=3) install(extra_lines=3)
@@ -49,7 +56,7 @@ TEMPLATE_DIR = os.path.join(PROJECT_ROOT, "template")
# 考虑到实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码 # 考虑到实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码
# 对该字段的更新请严格参照语义化版本规范https://semver.org/lang/zh-CN/ # 对该字段的更新请严格参照语义化版本规范https://semver.org/lang/zh-CN/
MMC_VERSION = "0.9.1" MMC_VERSION = "0.10.0"
def get_key_comment(toml_table, key): def get_key_comment(toml_table, key):
@@ -62,8 +69,8 @@ def get_key_comment(toml_table, key):
return item.trivia.comment return item.trivia.comment
if hasattr(toml_table, "keys"): if hasattr(toml_table, "keys"):
for k in toml_table.keys(): for k in toml_table.keys():
if isinstance(k, KeyType) and k.key == key: if isinstance(k, KeyType) and k.key == key: # type: ignore
return k.trivia.comment return k.trivia.comment # type: ignore
return None return None
@@ -79,7 +86,7 @@ def compare_dicts(new, old, path=None, logs=None):
continue continue
if key not in old: if key not in old:
comment = get_key_comment(new, key) comment = get_key_comment(new, key)
logs.append(f"新增: {'.'.join(path + [str(key)])} 注释: {comment if comment else ''}") logs.append(f"新增: {'.'.join(path + [str(key)])} 注释: {comment or ''}")
elif isinstance(new[key], (dict, Table)) and isinstance(old.get(key), (dict, Table)): elif isinstance(new[key], (dict, Table)) and isinstance(old.get(key), (dict, Table)):
compare_dicts(new[key], old[key], path + [str(key)], logs) compare_dicts(new[key], old[key], path + [str(key)], logs)
# 删减项 # 删减项
@@ -88,7 +95,7 @@ def compare_dicts(new, old, path=None, logs=None):
continue continue
if key not in new: if key not in new:
comment = get_key_comment(old, key) comment = get_key_comment(old, key)
logs.append(f"删减: {'.'.join(path + [str(key)])} 注释: {comment if comment else ''}") logs.append(f"删减: {'.'.join(path + [str(key)])} 注释: {comment or ''}")
return logs return logs
@@ -102,11 +109,18 @@ def get_value_by_path(d, path):
def set_value_by_path(d, path, value): def set_value_by_path(d, path, value):
"""设置嵌套字典中指定路径的值"""
for k in path[:-1]: for k in path[:-1]:
if k not in d or not isinstance(d[k], dict): if k not in d or not isinstance(d[k], dict):
d[k] = {} d[k] = {}
d = d[k] d = d[k]
d[path[-1]] = value
# 使用 tomlkit.item 来保持 TOML 格式
try:
d[path[-1]] = tomlkit.item(value)
except (TypeError, ValueError):
# 如果转换失败,直接赋值
d[path[-1]] = value
def compare_default_values(new, old, path=None, logs=None, changes=None): def compare_default_values(new, old, path=None, logs=None, changes=None):
@@ -123,102 +137,140 @@ def compare_default_values(new, old, path=None, logs=None, changes=None):
if key in old: if key in old:
if isinstance(new[key], (dict, Table)) and isinstance(old[key], (dict, Table)): if isinstance(new[key], (dict, Table)) and isinstance(old[key], (dict, Table)):
compare_default_values(new[key], old[key], path + [str(key)], logs, changes) compare_default_values(new[key], old[key], path + [str(key)], logs, changes)
else: elif new[key] != old[key]:
# 只要值发生变化就记录 logs.append(f"默认值变化: {'.'.join(path + [str(key)])} 旧默认值: {old[key]} 新默认值: {new[key]}")
if new[key] != old[key]: changes.append((path + [str(key)], old[key], new[key]))
logs.append(
f"默认值变化: {'.'.join(path + [str(key)])} 旧默认值: {old[key]} 新默认值: {new[key]}"
)
changes.append((path + [str(key)], old[key], new[key]))
return logs, changes return logs, changes
def update_config(): def _get_version_from_toml(toml_path) -> Optional[str]:
"""从TOML文件中获取版本号"""
if not os.path.exists(toml_path):
return None
with open(toml_path, "r", encoding="utf-8") as f:
doc = tomlkit.load(f)
if "inner" in doc and "version" in doc["inner"]: # type: ignore
return doc["inner"]["version"] # type: ignore
return None
def _version_tuple(v):
"""将版本字符串转换为元组以便比较"""
if v is None:
return (0,)
return tuple(int(x) if x.isdigit() else 0 for x in str(v).replace("v", "").split("-")[0].split("."))
def _update_dict(target: TOMLDocument | dict | Table, source: TOMLDocument | dict):
    """Copy values from *source* into *target* for keys that *target* already has.

    Keys named "version" are never copied and keys missing from *target* are
    ignored. When both sides hold a table the merge recurses; otherwise the
    value is converted with tomlkit and, if that conversion raises TypeError
    or ValueError, assigned unchanged.
    """
    for key, value in source.items():
        # Skip the version field and anything the target does not define.
        if key == "version" or key not in target:
            continue

        existing = target[key]
        if isinstance(value, dict) and isinstance(existing, (dict, Table)):
            _update_dict(existing, value)
            continue

        try:
            if isinstance(value, list):
                # A non-empty list is rebuilt from its string form; an empty
                # list stays an empty TOML array.
                converted = tomlkit.array(str(value)) if value else tomlkit.array()
            else:
                converted = tomlkit.item(value)
            target[key] = converted
        except (TypeError, ValueError):
            # Conversion failed: store the raw value.
            target[key] = value
def _update_config_generic(config_name: str, template_name: str):
"""
通用的配置文件更新函数
Args:
config_name: 配置文件名(不含扩展名),如 'bot_config''model_config'
template_name: 模板文件名(不含扩展名),如 'bot_config_template''model_config_template'
"""
# 获取根目录路径 # 获取根目录路径
old_config_dir = os.path.join(CONFIG_DIR, "old") old_config_dir = os.path.join(CONFIG_DIR, "old")
compare_dir = os.path.join(TEMPLATE_DIR, "compare") compare_dir = os.path.join(TEMPLATE_DIR, "compare")
# 定义文件路径 # 定义文件路径
template_path = os.path.join(TEMPLATE_DIR, "bot_config_template.toml") template_path = os.path.join(TEMPLATE_DIR, f"{template_name}.toml")
old_config_path = os.path.join(CONFIG_DIR, "bot_config.toml") old_config_path = os.path.join(CONFIG_DIR, f"{config_name}.toml")
new_config_path = os.path.join(CONFIG_DIR, "bot_config.toml") new_config_path = os.path.join(CONFIG_DIR, f"{config_name}.toml")
compare_path = os.path.join(compare_dir, "bot_config_template.toml") compare_path = os.path.join(compare_dir, f"{template_name}.toml")
# 创建compare目录如果不存在 # 创建compare目录如果不存在
os.makedirs(compare_dir, exist_ok=True) os.makedirs(compare_dir, exist_ok=True)
# 处理compare下的模板文件 template_version = _get_version_from_toml(template_path)
def get_version_from_toml(toml_path): compare_version = _get_version_from_toml(compare_path)
if not os.path.exists(toml_path):
return None
with open(toml_path, "r", encoding="utf-8") as f:
doc = tomlkit.load(f)
if "inner" in doc and "version" in doc["inner"]: # type: ignore
return doc["inner"]["version"] # type: ignore
return None
template_version = get_version_from_toml(template_path) # 检查配置文件是否存在
compare_version = get_version_from_toml(compare_path) if not os.path.exists(old_config_path):
logger.info(f"{config_name}.toml配置文件不存在从模板创建新配置")
os.makedirs(CONFIG_DIR, exist_ok=True) # 创建文件夹
shutil.copy2(template_path, old_config_path) # 复制模板文件
logger.info(f"已创建新{config_name}配置文件,请填写后重新运行: {old_config_path}")
# 新创建配置文件,退出
sys.exit(0)
def version_tuple(v): compare_config = None
if v is None: new_config = None
return (0,) old_config = None
return tuple(int(x) if x.isdigit() else 0 for x in str(v).replace("v", "").split("-")[0].split("."))
# 先读取 compare 下的模板(如果有),用于默认值变动检测 # 先读取 compare 下的模板(如果有),用于默认值变动检测
if os.path.exists(compare_path): if os.path.exists(compare_path):
with open(compare_path, "r", encoding="utf-8") as f: with open(compare_path, "r", encoding="utf-8") as f:
compare_config = tomlkit.load(f) compare_config = tomlkit.load(f)
else:
compare_config = None
# 读取当前模板 # 读取当前模板
with open(template_path, "r", encoding="utf-8") as f: with open(template_path, "r", encoding="utf-8") as f:
new_config = tomlkit.load(f) new_config = tomlkit.load(f)
# 检查默认值变化并处理(只有 compare_config 存在时才做) # 检查默认值变化并处理(只有 compare_config 存在时才做)
if compare_config is not None: if compare_config:
# 读取旧配置 # 读取旧配置
with open(old_config_path, "r", encoding="utf-8") as f: with open(old_config_path, "r", encoding="utf-8") as f:
old_config = tomlkit.load(f) old_config = tomlkit.load(f)
logs, changes = compare_default_values(new_config, compare_config) logs, changes = compare_default_values(new_config, compare_config)
if logs: if logs:
logger.info("检测到模板默认值变动如下:") logger.info(f"检测到{config_name}模板默认值变动如下:")
for log in logs: for log in logs:
logger.info(log) logger.info(log)
# 检查旧配置是否等于旧默认值,如果是则更新为新默认值 # 检查旧配置是否等于旧默认值,如果是则更新为新默认值
config_updated = False
for path, old_default, new_default in changes: for path, old_default, new_default in changes:
old_value = get_value_by_path(old_config, path) old_value = get_value_by_path(old_config, path)
if old_value == old_default: if old_value == old_default:
set_value_by_path(old_config, path, new_default) set_value_by_path(old_config, path, new_default)
logger.info( logger.info(
f"已自动将配置 {'.'.join(path)} 的值从旧默认值 {old_default} 更新为新默认值 {new_default}" f"已自动将{config_name}配置 {'.'.join(path)} 的值从旧默认值 {old_default} 更新为新默认值 {new_default}"
) )
config_updated = True
# 如果配置有更新,立即保存到文件
if config_updated:
with open(old_config_path, "w", encoding="utf-8") as f:
f.write(tomlkit.dumps(old_config))
logger.info(f"已保存更新后的{config_name}配置文件")
else: else:
logger.info("未检测到模板默认值变动") logger.info(f"未检测到{config_name}模板默认值变动")
# 保存旧配置的变更(后续合并逻辑会用到 old_config
else:
old_config = None
# 检查 compare 下没有模板,或新模板版本更高,则复制 # 检查 compare 下没有模板,或新模板版本更高,则复制
if not os.path.exists(compare_path): if not os.path.exists(compare_path):
shutil.copy2(template_path, compare_path) shutil.copy2(template_path, compare_path)
logger.info(f"已将模板文件复制到: {compare_path}") logger.info(f"已将{config_name}模板文件复制到: {compare_path}")
elif _version_tuple(template_version) > _version_tuple(compare_version):
shutil.copy2(template_path, compare_path)
logger.info(f"{config_name}模板版本较新已替换compare下的模板: {compare_path}")
else: else:
if version_tuple(template_version) > version_tuple(compare_version): logger.debug(f"compare下的{config_name}模板版本不低于当前模板,无需替换: {compare_path}")
shutil.copy2(template_path, compare_path)
logger.info(f"模板版本较新已替换compare下的模板: {compare_path}")
else:
logger.debug(f"compare下的模板版本不低于当前模板无需替换: {compare_path}")
# 检查配置文件是否存在
if not os.path.exists(old_config_path):
logger.info("配置文件不存在,从模板创建新配置")
os.makedirs(CONFIG_DIR, exist_ok=True) # 创建文件夹
shutil.copy2(template_path, old_config_path) # 复制模板文件
logger.info(f"已创建新配置文件,请填写后重新运行: {old_config_path}")
# 如果是新创建的配置文件,直接返回
quit()
# 读取旧配置文件和模板文件(如果前面没读过 old_config这里再读一次 # 读取旧配置文件和模板文件(如果前面没读过 old_config这里再读一次
if old_config is None: if old_config is None:
@@ -226,79 +278,60 @@ def update_config():
old_config = tomlkit.load(f) old_config = tomlkit.load(f)
# new_config 已经读取 # new_config 已经读取
# 读取 compare_config 只用于默认值变动检测,后续合并逻辑不再用
# 检查version是否相同 # 检查version是否相同
if old_config and "inner" in old_config and "inner" in new_config: if old_config and "inner" in old_config and "inner" in new_config:
old_version = old_config["inner"].get("version") # type: ignore old_version = old_config["inner"].get("version") # type: ignore
new_version = new_config["inner"].get("version") # type: ignore new_version = new_config["inner"].get("version") # type: ignore
if old_version and new_version and old_version == new_version: if old_version and new_version and old_version == new_version:
logger.info(f"检测到配置文件版本号相同 (v{old_version}),跳过更新") logger.info(f"检测到{config_name}配置文件版本号相同 (v{old_version}),跳过更新")
return return
else: else:
logger.info( logger.info(
f"\n----------------------------------------\n检测到版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}\n----------------------------------------" f"\n----------------------------------------\n检测到{config_name}版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}\n----------------------------------------"
) )
else: else:
logger.info("已有配置文件未检测到版本号,可能是旧版本。将进行更新") logger.info(f"已有{config_name}配置文件未检测到版本号,可能是旧版本。将进行更新")
# 创建old目录如果不存在 # 创建old目录如果不存在
os.makedirs(old_config_dir, exist_ok=True) # 生成带时间戳的新文件名 os.makedirs(old_config_dir, exist_ok=True) # 生成带时间戳的新文件名
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
old_backup_path = os.path.join(old_config_dir, f"bot_config_{timestamp}.toml") old_backup_path = os.path.join(old_config_dir, f"{config_name}_{timestamp}.toml")
# 移动旧配置文件到old目录 # 移动旧配置文件到old目录
shutil.move(old_config_path, old_backup_path) shutil.move(old_config_path, old_backup_path)
logger.info(f"已备份旧配置文件到: {old_backup_path}") logger.info(f"已备份旧{config_name}配置文件到: {old_backup_path}")
# 复制模板文件到配置目录 # 复制模板文件到配置目录
shutil.copy2(template_path, new_config_path) shutil.copy2(template_path, new_config_path)
logger.info(f"已创建新配置文件: {new_config_path}") logger.info(f"已创建新{config_name}配置文件: {new_config_path}")
# 输出新增和删减项及注释 # 输出新增和删减项及注释
if old_config: if old_config:
logger.info("配置项变动如下:\n----------------------------------------") logger.info(f"{config_name}配置项变动如下:\n----------------------------------------")
logs = compare_dicts(new_config, old_config) if logs := compare_dicts(new_config, old_config):
if logs:
for log in logs: for log in logs:
logger.info(log) logger.info(log)
else: else:
logger.info("无新增或删减项") logger.info("无新增或删减项")
def update_dict(target: TOMLDocument | dict | Table, source: TOMLDocument | dict):
"""
将source字典的值更新到target字典中如果target中存在相同的键
"""
for key, value in source.items():
# 跳过version字段的更新
if key == "version":
continue
if key in target:
target_value = target[key]
if isinstance(value, dict) and isinstance(target_value, (dict, Table)):
update_dict(target_value, value)
else:
try:
# 对数组类型进行特殊处理
if isinstance(value, list):
# 如果是空数组,确保它保持为空数组
target[key] = tomlkit.array(str(value)) if value else tomlkit.array()
else:
# 其他类型使用item方法创建新值
target[key] = tomlkit.item(value)
except (TypeError, ValueError):
# 如果转换失败,直接赋值
target[key] = value
# 将旧配置的值更新到新配置中 # 将旧配置的值更新到新配置中
logger.info("开始合并新旧配置...") logger.info(f"开始合并{config_name}新旧配置...")
update_dict(new_config, old_config) _update_dict(new_config, old_config)
# 保存更新后的配置(保留注释和格式) # 保存更新后的配置(保留注释和格式)
with open(new_config_path, "w", encoding="utf-8") as f: with open(new_config_path, "w", encoding="utf-8") as f:
f.write(tomlkit.dumps(new_config)) f.write(tomlkit.dumps(new_config))
logger.info("配置文件更新完成,建议检查新配置文件中的内容,以免丢失重要信息") logger.info(f"{config_name}配置文件更新完成,建议检查新配置文件中的内容,以免丢失重要信息")
quit()
def update_config():
    """Update bot_config.toml from bot_config_template via the generic updater."""
    _update_config_generic("bot_config", "bot_config_template")
def update_model_config():
    """Update model_config.toml from model_config_template via the generic updater."""
    _update_config_generic("model_config", "model_config_template")
@dataclass @dataclass
@@ -312,7 +345,6 @@ class Config(ConfigBase):
relationship: RelationshipConfig relationship: RelationshipConfig
chat: ChatConfig chat: ChatConfig
message_receive: MessageReceiveConfig message_receive: MessageReceiveConfig
normal_chat: NormalChatConfig
emoji: EmojiConfig emoji: EmojiConfig
expression: ExpressionConfig expression: ExpressionConfig
memory: MemoryConfig memory: MemoryConfig
@@ -323,7 +355,6 @@ class Config(ConfigBase):
response_splitter: ResponseSplitterConfig response_splitter: ResponseSplitterConfig
telemetry: TelemetryConfig telemetry: TelemetryConfig
experimental: ExperimentalConfig experimental: ExperimentalConfig
model: ModelConfig
maim_message: MaimMessageConfig maim_message: MaimMessageConfig
lpmm_knowledge: LPMMKnowledgeConfig lpmm_knowledge: LPMMKnowledgeConfig
tool: ToolConfig tool: ToolConfig
@@ -331,11 +362,69 @@ class Config(ConfigBase):
custom_prompt: CustomPromptConfig custom_prompt: CustomPromptConfig
voice: VoiceConfig voice: VoiceConfig
@dataclass
class APIAdapterConfig(ConfigBase):
    """Configuration of the API adapter: models, task settings and providers."""

    models: List[ModelInfo]
    """All configured models."""

    model_task_config: ModelTaskConfig
    """Per-task model settings."""

    api_providers: List[APIProvider] = field(default_factory=list)
    """All configured API providers."""

    def __post_init__(self):
        """Validate the lists and build the name lookup tables.

        Sets ``api_providers_dict`` and ``models_dict`` (name -> object).

        Raises:
            ValueError: If either list is empty, a provider or model name is
                duplicated, a model has no identifier, or a model refers to a
                provider that is not configured.
        """
        if not self.models:
            raise ValueError("模型列表不能为空,请在配置中设置有效的模型列表。")
        if not self.api_providers:
            raise ValueError("API提供商列表不能为空请在配置中设置有效的API提供商列表。")

        # A set smaller than the list means a provider name occurs twice.
        unique_provider_names = {provider.name for provider in self.api_providers}
        if len(unique_provider_names) < len(self.api_providers):
            raise ValueError("API提供商名称存在重复请检查配置文件。")

        # Same duplicate check for model names.
        unique_model_names = {model.name for model in self.models}
        if len(unique_model_names) < len(self.models):
            raise ValueError("模型名称存在重复,请检查配置文件。")

        self.api_providers_dict = {provider.name: provider for provider in self.api_providers}
        self.models_dict = {model.name: model for model in self.models}

        for model in self.models:
            if not model.model_identifier:
                raise ValueError(f"模型 '{model.name}' 的 model_identifier 不能为空")
            provider_known = bool(model.api_provider) and model.api_provider in self.api_providers_dict
            if not provider_known:
                raise ValueError(f"模型 '{model.name}' 的 api_provider '{model.api_provider}' 不存在")

    def get_model_info(self, model_name: str) -> ModelInfo:
        """Look up a model by name.

        Raises:
            ValueError: If *model_name* is empty.
            KeyError: If no model with that name is configured.
        """
        if not model_name:
            raise ValueError("模型名称不能为空")
        if model_name in self.models_dict:
            return self.models_dict[model_name]
        raise KeyError(f"模型 '{model_name}' 不存在")

    def get_provider(self, provider_name: str) -> APIProvider:
        """Look up an API provider by name.

        Raises:
            ValueError: If *provider_name* is empty.
            KeyError: If no provider with that name is configured.
        """
        if not provider_name:
            raise ValueError("API提供商名称不能为空")
        if provider_name in self.api_providers_dict:
            return self.api_providers_dict[provider_name]
        raise KeyError(f"API提供商 '{provider_name}' 不存在")
def load_config(config_path: str) -> Config: def load_config(config_path: str) -> Config:
""" """
加载配置文件 加载配置文件
:param config_path: 配置文件路径 Args:
:return: Config对象 config_path: 配置文件路径
Returns:
Config对象
""" """
# 读取配置文件 # 读取配置文件
with open(config_path, "r", encoding="utf-8") as f: with open(config_path, "r", encoding="utf-8") as f:
@@ -349,18 +438,32 @@ def load_config(config_path: str) -> Config:
raise e raise e
def get_config_dir() -> str: def api_ada_load_config(config_path: str) -> APIAdapterConfig:
""" """
获取配置目录 加载API适配器配置文件
:return: 配置目录路径 Args:
config_path: 配置文件路径
Returns:
APIAdapterConfig对象
""" """
return CONFIG_DIR # 读取配置文件
with open(config_path, "r", encoding="utf-8") as f:
config_data = tomlkit.load(f)
# 创建APIAdapterConfig对象
try:
return APIAdapterConfig.from_dict(config_data)
except Exception as e:
logger.critical("API适配器配置文件解析失败")
raise e
# 获取配置文件路径 # 获取配置文件路径
logger.info(f"MaiCore当前版本: {MMC_VERSION}") logger.info(f"MaiCore当前版本: {MMC_VERSION}")
update_config() update_config()
update_model_config()
logger.info("正在品鉴配置文件...") logger.info("正在品鉴配置文件...")
global_config = load_config(config_path=os.path.join(CONFIG_DIR, "bot_config.toml")) global_config = load_config(config_path=os.path.join(CONFIG_DIR, "bot_config.toml"))
model_config = api_ada_load_config(config_path=os.path.join(CONFIG_DIR, "model_config.toml"))
logger.info("非常的新鲜,非常的美味!") logger.info("非常的新鲜,非常的美味!")

View File

@@ -1,7 +1,7 @@
import re import re
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, Literal, Optional from typing import Literal, Optional
from src.config.config_base import ConfigBase from src.config.config_base import ConfigBase
@@ -17,7 +17,7 @@ from src.config.config_base import ConfigBase
@dataclass @dataclass
class BotConfig(ConfigBase): class BotConfig(ConfigBase):
"""QQ机器人配置类""" """QQ机器人配置类"""
platform: str platform: str
"""平台""" """平台"""
@@ -44,6 +44,9 @@ class PersonalityConfig(ConfigBase):
identity: str = "" identity: str = ""
"""身份特征""" """身份特征"""
reply_style: str = ""
"""表达风格"""
compress_personality: bool = True compress_personality: bool = True
"""是否压缩人格压缩后会精简人格信息节省token消耗并提高回复性能但是会丢失一些信息如果人设不长可以关闭""" """是否压缩人格压缩后会精简人格信息节省token消耗并提高回复性能但是会丢失一些信息如果人设不长可以关闭"""
@@ -68,155 +71,90 @@ class ChatConfig(ConfigBase):
max_context_size: int = 18 max_context_size: int = 18
"""上下文长度""" """上下文长度"""
willing_amplifier: float = 1.0
replyer_random_probability: float = 0.5
"""
发言时选择推理模型的概率0-1之间
选择普通模型的概率为 1 - reasoning_normal_model_probability
"""
thinking_timeout: int = 40
"""麦麦最长思考规划时间超过这个时间的思考会放弃往往是api反应太慢"""
talk_frequency: float = 1
"""回复频率阈值"""
mentioned_bot_inevitable_reply: bool = False mentioned_bot_inevitable_reply: bool = False
"""提及 bot 必然回复""" """提及 bot 必然回复"""
at_bot_inevitable_reply: bool = False at_bot_inevitable_reply: bool = False
"""@bot 必然回复""" """@bot 必然回复"""
talk_frequency: float = 0.5
"""回复频率阈值"""
# 修改:基于时段的回复频率配置,改为数组格式 # 合并后的时段频率配置
time_based_talk_frequency: list[str] = field(default_factory=lambda: [])
"""
基于时段的回复频率配置(全局)
格式:["HH:MM,frequency", "HH:MM,frequency", ...]
示例:["8:00,1", "12:00,2", "18:00,1.5", "00:00,0.5"]
表示从该时间开始使用该频率,直到下一个时间点
"""
# 新增:基于聊天流的个性化时段频率配置
talk_frequency_adjust: list[list[str]] = field(default_factory=lambda: []) talk_frequency_adjust: list[list[str]] = field(default_factory=lambda: [])
"""
基于聊天流的个性化时段频率配置
格式:[["platform:chat_id:type", "HH:MM,frequency", "HH:MM,frequency", ...], ...]
示例:[
["qq:1026294844:group", "12:20,1", "16:10,2", "20:10,1", "00:10,0.3"],
["qq:729957033:group", "8:20,1", "12:10,2", "20:10,1.5", "00:10,0.2"]
]
每个子列表的第一个元素是聊天流标识符,后续元素是"时间,频率"格式
表示从该时间开始使用该频率,直到下一个时间点
"""
focus_value: float = 1.0
focus_value: float = 0.5
"""麦麦的专注思考能力越低越容易专注消耗token也越多""" """麦麦的专注思考能力越低越容易专注消耗token也越多"""
focus_value_adjust: list[list[str]] = field(default_factory=lambda: [])
"""
统一的活跃度和专注度配置
格式:[["platform:chat_id:type", "HH:MM,frequency", "HH:MM,frequency", ...], ...]
全局配置示例:
[["", "8:00,1", "12:00,2", "18:00,1.5", "00:00,0.5"]]
特定聊天流配置示例:
[
["", "8:00,1", "12:00,1.2", "18:00,1.5", "01:00,0.6"], # 全局默认配置
["qq:1026294844:group", "12:20,1", "16:10,2", "20:10,1", "00:10,0.3"], # 特定群聊配置
["qq:729957033:private", "8:20,1", "12:10,2", "20:10,1.5", "00:10,0.2"] # 特定私聊配置
]
说明:
- 当第一个元素为空字符串""时,表示全局默认配置
- 当第一个元素为"platform:id:type"格式时,表示特定聊天流配置
- 后续元素是"时间,频率"格式,表示从该时间开始使用该频率,直到下一个时间点
- 优先级:特定聊天流配置 > 全局配置 > 默认值
注意:
- talk_frequency_adjust 控制回复频率,数值越高回复越频繁
- focus_value_adjust 控制专注思考能力数值越低越容易专注消耗token也越多
"""
def get_current_talk_frequency(self, chat_stream_id: Optional[str] = None) -> float:
"""
根据当前时间和聊天流获取对应的 talk_frequency
Args: @dataclass
chat_stream_id: 聊天流ID格式为 "platform:chat_id:type" class MessageReceiveConfig(ConfigBase):
"""消息接收配置类"""
Returns: ban_words: set[str] = field(default_factory=lambda: set())
float: 对应的频率值 """过滤词列表"""
"""
# 优先检查聊天流特定的配置
if chat_stream_id and self.talk_frequency_adjust:
stream_frequency = self._get_stream_specific_frequency(chat_stream_id)
if stream_frequency is not None:
return stream_frequency
# 如果没有聊天流特定配置,检查全局时段配置 ban_msgs_regex: set[str] = field(default_factory=lambda: set())
if self.time_based_talk_frequency: """过滤正则表达式列表"""
global_frequency = self._get_time_based_frequency(self.time_based_talk_frequency)
if global_frequency is not None:
return global_frequency
# 如果都没有匹配,返回默认值 @dataclass
return self.talk_frequency class ExpressionConfig(ConfigBase):
"""表达配置类"""
def _get_time_based_frequency(self, time_freq_list: list[str]) -> Optional[float]: learning_list: list[list] = field(default_factory=lambda: [])
""" """
根据时间配置列表获取当前时段的频率 表达学习配置列表,支持按聊天流配置
格式: [["chat_stream_id", "use_expression", "enable_learning", learning_intensity], ...]
示例:
[
["", "enable", "enable", 1.0], # 全局配置使用表达启用学习学习强度1.0
["qq:1919810:private", "enable", "enable", 1.5], # 特定私聊配置使用表达启用学习学习强度1.5
["qq:114514:private", "enable", "disable", 0.5], # 特定私聊配置使用表达禁用学习学习强度0.5
]
说明:
- 第一位: chat_stream_id空字符串表示全局配置
- 第二位: 是否使用学到的表达 ("enable"/"disable")
- 第三位: 是否学习表达 ("enable"/"disable")
- 第四位: 学习强度(浮点数),影响学习频率,最短学习时间间隔 = 300/学习强度(秒)
"""
Args: expression_groups: list[list[str]] = field(default_factory=list)
time_freq_list: 时间频率配置列表,格式为 ["HH:MM,frequency", ...] """
表达学习互通组
Returns: 格式: [["qq:12345:group", "qq:67890:private"]]
float: 频率值,如果没有配置则返回 None """
"""
from datetime import datetime
current_time = datetime.now().strftime("%H:%M")
current_hour, current_minute = map(int, current_time.split(":"))
current_minutes = current_hour * 60 + current_minute
# 解析时间频率配置
time_freq_pairs = []
for time_freq_str in time_freq_list:
try:
time_str, freq_str = time_freq_str.split(",")
hour, minute = map(int, time_str.split(":"))
frequency = float(freq_str)
minutes = hour * 60 + minute
time_freq_pairs.append((minutes, frequency))
except (ValueError, IndexError):
continue
if not time_freq_pairs:
return None
# 按时间排序
time_freq_pairs.sort(key=lambda x: x[0])
# 查找当前时间对应的频率
current_frequency = None
for minutes, frequency in time_freq_pairs:
if current_minutes >= minutes:
current_frequency = frequency
else:
break
# 如果当前时间在所有配置时间之前,使用最后一个时间段的频率(跨天逻辑)
if current_frequency is None and time_freq_pairs:
current_frequency = time_freq_pairs[-1][1]
return current_frequency
def _get_stream_specific_frequency(self, chat_stream_id: str):
"""
获取特定聊天流在当前时间的频率
Args:
chat_stream_id: 聊天流ID哈希值
Returns:
float: 频率值,如果没有配置则返回 None
"""
# 查找匹配的聊天流配置
for config_item in self.talk_frequency_adjust:
if not config_item or len(config_item) < 2:
continue
stream_config_str = config_item[0] # 例如 "qq:1026294844:group"
# 解析配置字符串并生成对应的 chat_id
config_chat_id = self._parse_stream_config_to_chat_id(stream_config_str)
if config_chat_id is None:
continue
# 比较生成的 chat_id
if config_chat_id != chat_stream_id:
continue
# 使用通用的时间频率解析方法
return self._get_time_based_frequency(config_item[1:])
return None
def _parse_stream_config_to_chat_id(self, stream_config_str: str) -> Optional[str]: def _parse_stream_config_to_chat_id(self, stream_config_str: str) -> Optional[str]:
""" """
@@ -253,46 +191,96 @@ class ChatConfig(ConfigBase):
except (ValueError, IndexError): except (ValueError, IndexError):
return None return None
def get_expression_config_for_chat(self, chat_stream_id: Optional[str] = None) -> tuple[bool, bool, int]:
"""
根据聊天流ID获取表达配置
@dataclass Args:
class MessageReceiveConfig(ConfigBase): chat_stream_id: 聊天流ID格式为哈希值
"""消息接收配置类"""
ban_words: set[str] = field(default_factory=lambda: set()) Returns:
"""过滤词列表""" tuple: (是否使用表达, 是否学习表达, 学习间隔)
"""
if not self.learning_list:
# 如果没有配置使用默认值启用表达启用学习300秒间隔
return True, True, 300
ban_msgs_regex: set[str] = field(default_factory=lambda: set()) # 优先检查聊天流特定的配置
"""过滤正则表达式列表""" if chat_stream_id:
specific_expression_config = self._get_stream_specific_config(chat_stream_id)
if specific_expression_config is not None:
return specific_expression_config
# 检查全局配置(第一个元素为空字符串的配置)
global_expression_config = self._get_global_config()
if global_expression_config is not None:
return global_expression_config
@dataclass # 如果都没有匹配,返回默认值
class NormalChatConfig(ConfigBase): return True, True, 300
"""普通聊天配置类"""
willing_mode: str = "classical" def _get_stream_specific_config(self, chat_stream_id: str) -> Optional[tuple[bool, bool, int]]:
"""意愿模式""" """
获取特定聊天流的表达配置
@dataclass Args:
class ExpressionConfig(ConfigBase): chat_stream_id: 聊天流ID哈希值
"""表达配置类"""
enable_expression: bool = True Returns:
"""是否用表达方式""" tuple: (是否使用表达, 是否学习表达, 学习间隔),如果没有配置则返回 None
"""
for config_item in self.learning_list:
if not config_item or len(config_item) < 4:
continue
expression_style: str = "" stream_config_str = config_item[0] # 例如 "qq:1026294844:group"
"""表达风格"""
learning_interval: int = 300 # 如果是空字符串,跳过(这是全局配置)
"""学习间隔(秒)""" if stream_config_str == "":
continue
enable_expression_learning: bool = True # 解析配置字符串并生成对应的 chat_id
"""是否启用表达学习""" config_chat_id = self._parse_stream_config_to_chat_id(stream_config_str)
if config_chat_id is None:
continue
expression_groups: list[list[str]] = field(default_factory=list) # 比较生成的 chat_id
""" if config_chat_id != chat_stream_id:
表达学习互通组 continue
格式: [["qq:12345:group", "qq:67890:private"]]
""" # 解析配置
try:
use_expression: bool = config_item[1].lower() == "enable"
enable_learning: bool = config_item[2].lower() == "enable"
learning_intensity: float = float(config_item[3])
return use_expression, enable_learning, learning_intensity # type: ignore
except (ValueError, IndexError):
continue
return None
def _get_global_config(self) -> Optional[tuple[bool, bool, int]]:
"""
获取全局表达配置
Returns:
tuple: (是否使用表达, 是否学习表达, 学习间隔),如果没有配置则返回 None
"""
for config_item in self.learning_list:
if not config_item or len(config_item) < 4:
continue
# 检查是否为全局配置(第一个元素为空字符串)
if config_item[0] == "":
try:
use_expression: bool = config_item[1].lower() == "enable"
enable_learning: bool = config_item[2].lower() == "enable"
learning_intensity = float(config_item[3])
return use_expression, enable_learning, learning_intensity # type: ignore
except (ValueError, IndexError):
continue
return None
@dataclass @dataclass
@@ -301,7 +289,8 @@ class ToolConfig(ConfigBase):
enable_tool: bool = False enable_tool: bool = False
"""是否在聊天中启用工具""" """是否在聊天中启用工具"""
@dataclass @dataclass
class VoiceConfig(ConfigBase): class VoiceConfig(ConfigBase):
"""语音识别配置类""" """语音识别配置类"""
@@ -317,9 +306,6 @@ class EmojiConfig(ConfigBase):
emoji_chance: float = 0.6 emoji_chance: float = 0.6
"""发送表情包的基础概率""" """发送表情包的基础概率"""
emoji_activate_type: str = "random"
"""表情包激活类型可选randomllmrandom下表情包动作随机启用llm下表情包动作根据llm判断是否启用"""
max_reg_num: int = 200 max_reg_num: int = 200
"""表情包最大注册数量""" """表情包最大注册数量"""
@@ -344,25 +330,10 @@ class MemoryConfig(ConfigBase):
"""记忆配置类""" """记忆配置类"""
enable_memory: bool = True enable_memory: bool = True
"""是否启用记忆系统"""
memory_build_interval: int = 600
"""记忆构建间隔(秒)""" memory_build_frequency: int = 1
"""记忆构建频率(秒)"""
memory_build_distribution: tuple[
float,
float,
float,
float,
float,
float,
] = field(default_factory=lambda: (6.0, 3.0, 0.6, 32.0, 12.0, 0.4))
"""记忆构建分布参数分布1均值标准差权重分布2均值标准差权重"""
memory_build_sample_num: int = 8
"""记忆构建采样数量"""
memory_build_sample_length: int = 40
"""记忆构建采样长度"""
memory_compress_rate: float = 0.1 memory_compress_rate: float = 0.1
"""记忆压缩率""" """记忆压缩率"""
@@ -376,18 +347,9 @@ class MemoryConfig(ConfigBase):
memory_forget_percentage: float = 0.01 memory_forget_percentage: float = 0.01
"""记忆遗忘比例""" """记忆遗忘比例"""
consolidate_memory_interval: int = 1000
"""记忆整合间隔(秒)"""
consolidation_similarity_threshold: float = 0.7
"""整合相似度阈值"""
consolidate_memory_percentage: float = 0.01
"""整合检查节点比例"""
memory_ban_words: list[str] = field(default_factory=lambda: ["表情包", "图片", "回复", "聊天记录"]) memory_ban_words: list[str] = field(default_factory=lambda: ["表情包", "图片", "回复", "聊天记录"])
"""不允许记忆的词列表""" """不允许记忆的词列表"""
enable_instant_memory: bool = True enable_instant_memory: bool = True
"""是否启用即时记忆""" """是否启用即时记忆"""
@@ -398,7 +360,7 @@ class MoodConfig(ConfigBase):
enable_mood: bool = False enable_mood: bool = False
"""是否启用情绪系统""" """是否启用情绪系统"""
mood_update_threshold: float = 1.0 mood_update_threshold: float = 1.0
"""情绪更新阈值,越高,更新越慢""" """情绪更新阈值,越高,更新越慢"""
@@ -449,6 +411,7 @@ class KeywordReactionConfig(ConfigBase):
if not isinstance(rule, KeywordRuleConfig): if not isinstance(rule, KeywordRuleConfig):
raise ValueError(f"规则必须是KeywordRuleConfig类型而不是{type(rule).__name__}") raise ValueError(f"规则必须是KeywordRuleConfig类型而不是{type(rule).__name__}")
@dataclass @dataclass
class CustomPromptConfig(ConfigBase): class CustomPromptConfig(ConfigBase):
"""自定义提示词配置类""" """自定义提示词配置类"""
@@ -597,52 +560,3 @@ class LPMMKnowledgeConfig(ConfigBase):
embedding_dimension: int = 1024 embedding_dimension: int = 1024
"""嵌入向量维度,应该与模型的输出维度一致""" """嵌入向量维度,应该与模型的输出维度一致"""
@dataclass
class ModelConfig(ConfigBase):
"""模型配置类"""
model_max_output_length: int = 800 # 最大回复长度
utils: dict[str, Any] = field(default_factory=lambda: {})
"""组件模型配置"""
utils_small: dict[str, Any] = field(default_factory=lambda: {})
"""组件小模型配置"""
replyer_1: dict[str, Any] = field(default_factory=lambda: {})
"""normal_chat首要回复模型模型配置"""
replyer_2: dict[str, Any] = field(default_factory=lambda: {})
"""normal_chat次要回复模型配置"""
memory: dict[str, Any] = field(default_factory=lambda: {})
"""记忆模型配置"""
emotion: dict[str, Any] = field(default_factory=lambda: {})
"""情绪模型配置"""
vlm: dict[str, Any] = field(default_factory=lambda: {})
"""视觉语言模型配置"""
voice: dict[str, Any] = field(default_factory=lambda: {})
"""语音识别模型配置"""
tool_use: dict[str, Any] = field(default_factory=lambda: {})
"""专注工具使用模型配置"""
planner: dict[str, Any] = field(default_factory=lambda: {})
"""规划模型配置"""
embedding: dict[str, Any] = field(default_factory=lambda: {})
"""嵌入模型配置"""
lpmm_entity_extract: dict[str, Any] = field(default_factory=lambda: {})
"""LPMM实体提取模型配置"""
lpmm_rdf_build: dict[str, Any] = field(default_factory=lambda: {})
"""LPMM RDF构建模型配置"""
lpmm_qa: dict[str, Any] = field(default_factory=lambda: {})
"""LPMM问答模型配置"""

View File

@@ -4,9 +4,8 @@ import hashlib
import time import time
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.config import global_config from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.person_info.person_info import get_person_info_manager
from rich.traceback import install from rich.traceback import install
install(extra_lines=3) install(extra_lines=3)
@@ -19,14 +18,10 @@ class Individuality:
def __init__(self): def __init__(self):
self.name = "" self.name = ""
self.bot_person_id = ""
self.meta_info_file_path = "data/personality/meta.json" self.meta_info_file_path = "data/personality/meta.json"
self.personality_data_file_path = "data/personality/personality_data.json" self.personality_data_file_path = "data/personality/personality_data.json"
self.model = LLMRequest( self.model = LLMRequest(model_set=model_config.model_task_config.utils, request_type="individuality.compress")
model=global_config.model.utils,
request_type="individuality.compress",
)
async def initialize(self) -> None: async def initialize(self) -> None:
"""初始化个体特征""" """初始化个体特征"""
@@ -35,9 +30,6 @@ class Individuality:
personality_side = global_config.personality.personality_side personality_side = global_config.personality.personality_side
identity = global_config.personality.identity identity = global_config.personality.identity
person_info_manager = get_person_info_manager()
self.bot_person_id = person_info_manager.get_person_id("system", "bot_id")
self.name = bot_nickname self.name = bot_nickname
# 检查配置变化,如果变化则清空 # 检查配置变化,如果变化则清空
@@ -68,16 +60,6 @@ class Individuality:
else: else:
logger.error("人设构建失败") logger.error("人设构建失败")
# 如果任何一个发生变化都需要清空数据库中的info_list因为这影响整体人设
if personality_changed or identity_changed:
logger.info("将清空数据库中原有的关键词缓存")
update_data = {
"platform": "system",
"user_id": "bot_id",
"person_name": self.name,
"nickname": self.name,
}
await person_info_manager.update_one_field(self.bot_person_id, "info_list", [], data=update_data)
async def get_personality_block(self) -> str: async def get_personality_block(self) -> str:
bot_name = global_config.bot.nickname bot_name = global_config.bot.nickname
@@ -85,16 +67,16 @@ class Individuality:
bot_nickname = f",也有人叫你{','.join(global_config.bot.alias_names)}" bot_nickname = f",也有人叫你{','.join(global_config.bot.alias_names)}"
else: else:
bot_nickname = "" bot_nickname = ""
# 从文件获取 short_impression # 从文件获取 short_impression
personality, identity = self._get_personality_from_file() personality, identity = self._get_personality_from_file()
# 确保short_impression是列表格式且有足够的元素 # 确保short_impression是列表格式且有足够的元素
if not personality or not identity: if not personality or not identity:
logger.warning(f"personality或identity为空: {personality}, {identity}, 使用默认值") logger.warning(f"personality或identity为空: {personality}, {identity}, 使用默认值")
personality = "友好活泼" personality = "友好活泼"
identity = "人类" identity = "人类"
prompt_personality = f"{personality}\n{identity}" prompt_personality = f"{personality}\n{identity}"
return f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}" return f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}"
@@ -134,7 +116,6 @@ class Individuality:
Returns: Returns:
tuple: (personality_changed, identity_changed) tuple: (personality_changed, identity_changed)
""" """
person_info_manager = get_person_info_manager()
current_personality_hash, current_identity_hash = self._get_config_hash( current_personality_hash, current_identity_hash = self._get_config_hash(
bot_nickname, personality_core, personality_side, identity bot_nickname, personality_core, personality_side, identity
) )
@@ -152,17 +133,6 @@ class Individuality:
if identity_changed: if identity_changed:
logger.info("检测到身份配置发生变化") logger.info("检测到身份配置发生变化")
# 如果任何一个发生变化都需要清空info_list因为这影响整体人设
if personality_changed or identity_changed:
logger.info("将清空原有的关键词缓存")
update_data = {
"platform": "system",
"user_id": "bot_id",
"person_name": self.name,
"nickname": self.name,
}
await person_info_manager.update_one_field(self.bot_person_id, "info_list", [], data=update_data)
# 更新元信息文件 # 更新元信息文件
new_meta_info = { new_meta_info = {
"personality_hash": current_personality_hash, "personality_hash": current_personality_hash,
@@ -215,7 +185,7 @@ class Individuality:
def _get_personality_from_file(self) -> tuple[str, str]: def _get_personality_from_file(self) -> tuple[str, str]:
"""从文件获取personality数据 """从文件获取personality数据
Returns: Returns:
tuple: (personality, identity) tuple: (personality, identity)
""" """
@@ -226,7 +196,7 @@ class Individuality:
def _save_personality_to_file(self, personality: str, identity: str): def _save_personality_to_file(self, personality: str, identity: str):
"""保存personality数据到文件 """保存personality数据到文件
Args: Args:
personality: 压缩后的人格描述 personality: 压缩后的人格描述
identity: 压缩后的身份描述 identity: 压缩后的身份描述
@@ -235,7 +205,7 @@ class Individuality:
"personality": personality, "personality": personality,
"identity": identity, "identity": identity,
"bot_nickname": self.name, "bot_nickname": self.name,
"last_updated": int(time.time()) "last_updated": int(time.time()),
} }
self._save_personality_data(personality_data) self._save_personality_data(personality_data)
@@ -269,7 +239,7 @@ class Individuality:
2. 尽量简洁不超过30字 2. 尽量简洁不超过30字
3. 直接输出压缩后的内容,不要解释""" 3. 直接输出压缩后的内容,不要解释"""
response, (_, _) = await self.model.generate_response_async( response, _ = await self.model.generate_response_async(
prompt=prompt, prompt=prompt,
) )
@@ -281,7 +251,7 @@ class Individuality:
# 压缩失败时使用原始内容 # 压缩失败时使用原始内容
if personality_side: if personality_side:
personality_parts.append(personality_side) personality_parts.append(personality_side)
if personality_parts: if personality_parts:
personality_result = "".join(personality_parts) personality_result = "".join(personality_parts)
else: else:
@@ -308,7 +278,7 @@ class Individuality:
2. 尽量简洁不超过30字 2. 尽量简洁不超过30字
3. 直接输出压缩后的内容,不要解释""" 3. 直接输出压缩后的内容,不要解释"""
response, (_, _) = await self.model.generate_response_async( response, _ = await self.model.generate_response_async(
prompt=prompt, prompt=prompt,
) )

21
src/llm_models/LICENSE Normal file
View File

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2025 Mai.To.The.Gate
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

View File

@@ -0,0 +1,98 @@
from typing import Any
# Human-readable messages for common HTTP error status codes (modelled on the OpenAI API).
# RespNotOkException.__str__ looks the response status code up in this table.
error_code_mapping: dict[int, str] = {
    400: "参数不正确",
    401: "API-Key错误认证失败请检查/config/model_list.toml中的配置是否正确",
    402: "账号余额不足",
    403: "模型拒绝访问,可能需要实名或余额不足",
    404: "Not Found",
    413: "请求体过大,请尝试压缩图片或减少输入内容",
    429: "请求过于频繁,请稍后再试",
    500: "服务器内部故障",
    503: "服务器负载过高",
}
class NetworkConnectionError(Exception):
    """Connection failure, typically a network problem or an unavailable server.

    The exception takes no arguments; its text is a fixed hint produced by ``__str__``.
    """

    def __init__(self):
        # No message is forwarded, so ``args`` stays empty.
        super().__init__()

    def __str__(self):
        hint = "连接异常请检查网络连接状态或URL是否正确"
        return hint
class ReqAbortException(Exception):
"""请求异常退出,常见于请求被中断或取消"""
def __init__(self, message: str | None = None):
super().__init__(message)
self.message = message
def __str__(self):
return self.message or "请求因未知原因异常终止"
class RespNotOkException(Exception):
"""请求响应异常,见于请求未能成功响应(非 '200 OK'"""
def __init__(self, status_code: int, message: str | None = None):
super().__init__(message)
self.status_code = status_code
self.message = message
def __str__(self):
if self.status_code in error_code_mapping:
return error_code_mapping[self.status_code]
elif self.message:
return self.message
else:
return f"未知的异常响应代码:{self.status_code}"
class RespParseException(Exception):
"""响应解析错误,常见于响应格式不正确或解析方法不匹配"""
def __init__(self, ext_info: Any, message: str | None = None):
super().__init__(message)
self.ext_info = ext_info
self.message = message
def __str__(self):
return self.message or "解析响应内容时发生未知错误,请检查是否配置了正确的解析方法"
class PayLoadTooLargeError(Exception):
    """The request body is too large.

    :param message: detail kept on ``self.message``; it is not used by ``__str__``
    """

    def __init__(self, message: str):
        self.message = message
        super().__init__(message)

    def __str__(self):
        # The text is a fixed hint regardless of the stored message.
        fixed_hint = "请求体过大,请尝试压缩图片或减少输入内容。"
        return fixed_hint
class RequestAbortException(Exception):
    """The request was interrupted.

    :param message: description of the interruption, returned verbatim by ``__str__``
    """

    def __init__(self, message: str):
        self.message = message
        super().__init__(message)

    def __str__(self):
        text = self.message
        return text
class PermissionDeniedException(Exception):
    """Access was denied.

    :param message: description of the denial, returned verbatim by ``__str__``
    """

    def __init__(self, message: str):
        self.message = message
        super().__init__(message)

    def __str__(self):
        text = self.message
        return text

View File

@@ -0,0 +1,8 @@
from src.config.config import model_config
# Client types referenced by at least one configured API provider.
used_client_types = {api_provider.client_type for api_provider in model_config.api_providers}

# Import a client module only when some provider uses its client type.
if "openai" in used_client_types:
    from . import openai_client  # noqa: F401

if "gemini" in used_client_types:
    from . import gemini_client  # noqa: F401

View File

@@ -0,0 +1,178 @@
import asyncio
from dataclasses import dataclass
from abc import ABC, abstractmethod
from typing import Callable, Any, Optional
from src.config.api_ada_configs import ModelInfo, APIProvider
from ..payload_content.message import Message
from ..payload_content.resp_format import RespFormat
from ..payload_content.tool_option import ToolOption, ToolCall
@dataclass
class UsageRecord:
    """Token usage record for one model call."""

    # Name of the model
    model_name: str
    # Name of the provider
    provider_name: str
    # Number of prompt tokens
    prompt_tokens: int
    # Number of completion tokens
    completion_tokens: int
    # Total number of tokens
    total_tokens: int
@dataclass
class APIResponse:
    """Container for the result of an API call; every field defaults to ``None``."""

    # Response text
    content: str | None = None
    # Reasoning ("thinking") text
    reasoning_content: str | None = None
    # Tool calls requested by the model
    tool_calls: list[ToolCall] | None = None
    # Embedding vector
    embedding: list[float] | None = None
    # Token usage of the call
    usage: UsageRecord | None = None
    # Raw response data
    raw_data: Any = None
class BaseClient(ABC):
    """Abstract base class for API clients.

    A concrete client is constructed from an ``APIProvider`` and must implement every
    abstract method below.
    """

    api_provider: APIProvider

    def __init__(self, api_provider: APIProvider):
        self.api_provider = api_provider

    @abstractmethod
    async def get_response(
        self,
        model_info: ModelInfo,
        message_list: list[Message],
        tool_options: list[ToolOption] | None = None,
        max_tokens: int = 1024,
        temperature: float = 0.7,
        response_format: RespFormat | None = None,
        stream_response_handler: Optional[
            Callable[[Any, asyncio.Event | None], tuple[APIResponse, tuple[int, int, int]]]
        ] = None,
        async_response_parser: Callable[[Any], tuple[APIResponse, tuple[int, int, int]]] | None = None,
        interrupt_flag: asyncio.Event | None = None,
        extra_params: dict[str, Any] | None = None,
    ) -> APIResponse:
        """
        Get a chat response.

        :param model_info: model information
        :param message_list: the conversation messages
        :param tool_options: tool options (optional, defaults to None)
        :param max_tokens: maximum number of tokens (optional, defaults to 1024)
        :param temperature: sampling temperature (optional, defaults to 0.7)
        :param response_format: response format (optional, defaults to None)
        :param stream_response_handler: handler for streamed responses (optional)
        :param async_response_parser: parser for non-streamed responses (optional)
        :param interrupt_flag: interrupt signal (optional, defaults to None)
        :param extra_params: additional request parameters (optional)
        :return: the response as an ``APIResponse``
        """
        raise NotImplementedError("'get_response' method should be overridden in subclasses")

    @abstractmethod
    async def get_embedding(
        self,
        model_info: ModelInfo,
        embedding_input: str,
        extra_params: dict[str, Any] | None = None,
    ) -> APIResponse:
        """
        Get a text embedding.

        :param model_info: model information
        :param embedding_input: text to embed
        :param extra_params: additional request parameters (optional)
        :return: the embedding response
        """
        raise NotImplementedError("'get_embedding' method should be overridden in subclasses")

    @abstractmethod
    async def get_audio_transcriptions(
        self,
        model_info: ModelInfo,
        audio_base64: str,
        extra_params: dict[str, Any] | None = None,
    ) -> APIResponse:
        """
        Get an audio transcription.

        :param model_info: model information
        :param audio_base64: base64-encoded audio data
        :param extra_params: additional request parameters (optional)
        :return: the transcription response
        """
        raise NotImplementedError("'get_audio_transcriptions' method should be overridden in subclasses")

    @abstractmethod
    def get_support_image_formats(self) -> list[str]:
        """
        Get the supported image formats.

        :return: list of supported image formats
        """
        raise NotImplementedError("'get_support_image_formats' method should be overridden in subclasses")
class ClientRegistry:
    """Registry of API client classes plus a per-provider cache of client instances."""

    def __init__(self) -> None:
        # Maps APIProvider.client_type -> registered BaseClient subclass
        self.client_registry: dict[str, type[BaseClient]] = {}
        # Maps APIProvider.name -> client instance created for that provider
        self.client_instance_cache: dict[str, BaseClient] = {}

    def register_client_class(self, client_type: str):
        """
        Build a class decorator that registers an API client class.

        Args:
            client_type: the client type key the decorated class is registered under

        Returns:
            A decorator that stores the class in the registry and returns it unchanged;
            it raises TypeError when the class is not a subclass of BaseClient.
        """

        def _register(cls: type[BaseClient]) -> type[BaseClient]:
            if issubclass(cls, BaseClient):
                self.client_registry[client_type] = cls
                return cls
            raise TypeError(f"{cls.__name__} is not a subclass of BaseClient")

        return _register

    def get_client_class_instance(self, api_provider: APIProvider) -> BaseClient:
        """
        Get the client instance for a provider, creating and caching it on first use.

        Args:
            api_provider: the APIProvider to get a client for

        Returns:
            BaseClient: the client instance cached under the provider's name

        Raises:
            KeyError: no client class is registered for the provider's client type
        """
        provider_name = api_provider.name
        if provider_name in self.client_instance_cache:
            return self.client_instance_cache[provider_name]

        client_class = self.client_registry.get(api_provider.client_type)
        if not client_class:
            raise KeyError(f"'{api_provider.client_type}' 类型的 Client 未注册")

        instance = client_class(api_provider)
        self.client_instance_cache[provider_name] = instance
        return instance


# Module-level registry instance used by the client modules to register themselves.
client_registry = ClientRegistry()

View File

@@ -0,0 +1,561 @@
import asyncio
import io
import base64
from typing import Callable, AsyncIterator, Optional, Coroutine, Any, List
from google import genai
from google.genai.types import (
Content,
Part,
FunctionDeclaration,
GenerateContentResponse,
ContentListUnion,
ContentUnion,
ThinkingConfig,
Tool,
GenerateContentConfig,
EmbedContentResponse,
EmbedContentConfig,
SafetySetting,
HarmCategory,
HarmBlockThreshold,
)
from google.genai.errors import (
ClientError,
ServerError,
UnknownFunctionCallArgumentError,
UnsupportedFunctionError,
FunctionInvocationError,
)
from src.config.api_ada_configs import ModelInfo, APIProvider
from src.common.logger import get_logger
from .base_client import APIResponse, UsageRecord, BaseClient, client_registry
from ..exceptions import (
RespParseException,
NetworkConnectionError,
RespNotOkException,
ReqAbortException,
)
from ..payload_content.message import Message, RoleType
from ..payload_content.resp_format import RespFormat, RespFormatType
from ..payload_content.tool_option import ToolOption, ToolParam, ToolCall
logger = get_logger("Gemini客户端")

# One SafetySetting per harm category, each with the BLOCK_NONE threshold.
# NOTE(review): presumably this is meant to stop Gemini's safety filters from blocking replies — confirm this is intended.
gemini_safe_settings = [
    SafetySetting(category=harm_category, threshold=HarmBlockThreshold.BLOCK_NONE)
    for harm_category in (
        HarmCategory.HARM_CATEGORY_HATE_SPEECH,
        HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
        HarmCategory.HARM_CATEGORY_HARASSMENT,
        HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
        HarmCategory.HARM_CATEGORY_CIVIC_INTEGRITY,
    )
]
def _convert_messages(
    messages: list[Message],
) -> tuple[ContentListUnion, list[str] | None]:
    """
    Convert messages into the format required by the Gemini API.

    System messages are collected separately as plain strings; user and assistant
    messages are converted to ``Content`` objects.

    :param messages: list of messages to convert
    :return: ``(contents, system_instructions)``; ``system_instructions`` is None when
        there is no system message
    :raises ValueError: a system message has non-string content, or a tool message has no tool_call_id
    :raises RuntimeError: a message has an unsupported role or content type
    """

    def _convert_message_item(message: Message) -> Content:
        """
        Convert a single message (anything except system and tool messages).

        :param message: the message object
        :return: the converted ``Content``
        """
        # Rename the OpenAI-style role to the Gemini-style role.
        if message.role == RoleType.Assistant:
            role = "model"
        elif message.role == RoleType.User:
            role = "user"
        else:
            # Previously `role` stayed unbound here and surfaced as an UnboundLocalError.
            raise RuntimeError("无法触及的代码请使用MessageBuilder类构建消息对象")

        # Build the parts of the message.
        if isinstance(message.content, str):
            parts: List[Part] = [Part.from_text(text=message.content)]
        elif isinstance(message.content, list):
            parts = []
            for item in message.content:
                if isinstance(item, tuple):
                    # The tuple is read as (image format, base64 data); "jpg" is sent as "jpeg".
                    image_format = item[0].lower()
                    if image_format == "jpg":
                        image_format = "jpeg"
                    parts.append(
                        Part.from_bytes(data=base64.b64decode(item[1]), mime_type=f"image/{image_format}")
                    )
                elif isinstance(item, str):
                    parts.append(Part.from_text(text=item))
        else:
            raise RuntimeError("无法触及的代码请使用MessageBuilder类构建消息对象")

        return Content(role=role, parts=parts)

    contents: list[ContentUnion] = []
    system_instructions: list[str] = []

    for message in messages:
        if message.role == RoleType.System:
            if not isinstance(message.content, str):
                raise ValueError("你tm怎么往system里面塞图片base64")
            system_instructions.append(message.content)
        elif message.role == RoleType.Tool:
            # NOTE(review): tool messages are validated but never added to the output — confirm this is intended.
            if not message.tool_call_id:
                raise ValueError("无法触及的代码请使用MessageBuilder类构建消息对象")
        else:
            contents.append(_convert_message_item(message))

    # Report None rather than an empty list when there is no system message.
    return contents, (system_instructions or None)
def _convert_tool_options(tool_options: list[ToolOption]) -> list[FunctionDeclaration]:
    """
    Convert tool options into the format required by the Gemini API.

    :param tool_options: list of tool options
    :return: one ``FunctionDeclaration`` per tool option, in the same order
    """
    declarations: list[FunctionDeclaration] = []
    for option in tool_options:
        declaration_kwargs: dict[str, Any] = {
            "name": option.name,
            "description": option.description,
        }
        if option.params:
            # Describe every parameter as a property of an "object" schema.
            properties: dict[str, dict[str, Any]] = {}
            for param in option.params:
                schema: dict[str, Any] = {
                    "type": param.param_type.value,
                    "description": param.description,
                }
                if param.enum_values:
                    schema["enum"] = param.enum_values
                properties[param.name] = schema
            required_names = [param.name for param in option.params if param.required]
            declaration_kwargs["parameters"] = {
                "type": "object",
                "properties": properties,
                "required": required_names,
            }
        declarations.append(FunctionDeclaration(**declaration_kwargs))
    return declarations
def _process_delta(
    delta: GenerateContentResponse,
    fc_delta_buffer: io.StringIO,
    tool_calls_buffer: list[tuple[str, str, dict[str, Any]]],
):
    """
    Fold one streamed chunk into the content and tool-call buffers.

    :param delta: the streamed response chunk
    :param fc_delta_buffer: buffer the chunk's text is appended to
    :param tool_calls_buffer: list that receives one ``(id, name, args)`` tuple per function call
    :raises RespParseException: the chunk has no candidates, or a function call cannot be parsed
    """
    if not getattr(delta, "candidates", None):
        raise RespParseException(delta, "响应解析失败缺失candidates字段")

    if delta.text:
        fc_delta_buffer.write(delta.text)

    # function_calls is read directly rather than via hasattr; a falsy value means no calls.
    for call in delta.function_calls or []:
        try:
            # Missing args mean "no arguments"; normalise before the type check so None is accepted.
            args = {} if call.args is None else call.args
            if not isinstance(args, dict):
                raise RespParseException(delta, "响应解析失败,工具调用参数无法解析为字典类型")
            if not call.name:
                raise RespParseException(delta, "响应解析失败工具调用缺失name字段")
            # Same fallback id as the non-stream parser, so a call without an id is not fatal.
            tool_calls_buffer.append((call.id or "gemini-tool_call", call.name, args))
        except RespParseException:
            # Keep the specific message instead of re-wrapping it in the generic one.
            raise
        except Exception as e:
            raise RespParseException(delta, "响应解析失败,无法解析工具调用参数") from e
def _build_stream_api_resp(
    _fc_delta_buffer: io.StringIO,
    _tool_calls_buffer: list[tuple[str, str, dict]],
) -> APIResponse:
    """
    Assemble the final ``APIResponse`` from the buffers filled while streaming.

    The content buffer is closed by this function.

    :param _fc_delta_buffer: buffer holding the accumulated response text
    :param _tool_calls_buffer: accumulated ``(id, name, args)`` tuples
    :return: the assembled response
    :raises RespParseException: a tool call's arguments are neither None nor a dict
    """
    api_resp = APIResponse()

    # Only a non-empty content buffer is copied into the response.
    if _fc_delta_buffer.tell() > 0:
        api_resp.content = _fc_delta_buffer.getvalue()
    _fc_delta_buffer.close()

    if not _tool_calls_buffer:
        return api_resp

    # Turn every buffered tuple into a ToolCall object.
    parsed_calls = []
    for call_id, function_name, raw_args in _tool_calls_buffer:
        if raw_args is not None and not isinstance(raw_args, dict):
            raise RespParseException(
                None,
                f"响应解析失败,工具调用参数无法解析为字典类型。工具调用参数原始响应:\n{raw_args}",
            )
        parsed_calls.append(ToolCall(call_id, function_name, raw_args))
    api_resp.tool_calls = parsed_calls

    return api_resp
async def _default_stream_response_handler(
    resp_stream: AsyncIterator[GenerateContentResponse],
    interrupt_flag: asyncio.Event | None,
) -> tuple[APIResponse, Optional[tuple[int, int, int]]]:
    """
    Default handler for streamed Gemini API responses.

    :param resp_stream: async iterator of response chunks; it is consumed by this function
    :param interrupt_flag: optional event; when it is set the stream is abandoned
    :return: ``(response, usage)`` where ``usage`` is ``(prompt, completion, total)`` token
        counts taken from the last chunk that carried usage metadata, or None if no chunk did
    :raises ReqAbortException: ``interrupt_flag`` was set while streaming
    """
    content_buffer = io.StringIO()  # accumulates the response text
    tool_calls_buffer: list[tuple[str, str, dict]] = []  # accumulates the tool calls
    usage_record = None

    try:
        async for chunk in resp_stream:
            # Stop as soon as the caller signals an interrupt.
            if interrupt_flag and interrupt_flag.is_set():
                raise ReqAbortException("请求被外部信号中断")

            _process_delta(chunk, content_buffer, tool_calls_buffer)

            if chunk.usage_metadata:
                usage = chunk.usage_metadata
                # Thinking tokens are counted together with the completion tokens.
                usage_record = (
                    usage.prompt_token_count or 0,
                    (usage.candidates_token_count or 0) + (usage.thoughts_token_count or 0),
                    usage.total_token_count or 0,
                )

        return _build_stream_api_resp(content_buffer, tool_calls_buffer), usage_record
    finally:
        # Close the buffer on every exit path, including an interrupt or a parse error mid-stream.
        if not content_buffer.closed:
            content_buffer.close()
def _default_normal_response_parser(
    resp: GenerateContentResponse,
) -> tuple[APIResponse, Optional[tuple[int, int, int]]]:
    """
    Parse a non-streamed Gemini API response into an ``APIResponse``.

    :param resp: the response object
    :return: ``(response, usage)`` where ``usage`` is ``(prompt, completion, total)`` token
        counts, or None when the response carries no usage metadata
    :raises RespParseException: the response has no candidates, or a function call cannot be parsed
    """
    api_response = APIResponse()

    if not hasattr(resp, "candidates") or not resp.candidates:
        raise RespParseException(resp, "响应解析失败缺失candidates字段")

    # Collect the text of the first candidate's parts that are flagged as thoughts.
    # This is best effort: a failure is logged and the parts are skipped.
    try:
        first_content = resp.candidates[0].content
        if first_content and first_content.parts:
            for part in first_content.parts:
                if not part.text or not part.thought:
                    continue
                if api_response.reasoning_content:
                    api_response.reasoning_content = api_response.reasoning_content + part.text
                else:
                    api_response.reasoning_content = part.text
    except Exception as e:
        logger.warning(f"解析思考内容时发生错误: {e},跳过解析")

    if resp.text:
        api_response.content = resp.text

    if resp.function_calls:
        api_response.tool_calls = []
        for call in resp.function_calls:
            try:
                # Missing args mean "no arguments"; normalise before the type check so None is accepted.
                args = {} if call.args is None else call.args
                if not isinstance(args, dict):
                    raise RespParseException(resp, "响应解析失败,工具调用参数无法解析为字典类型")
                if not call.name:
                    raise RespParseException(resp, "响应解析失败工具调用缺失name字段")
                # A call without an id gets a fixed placeholder id.
                api_response.tool_calls.append(ToolCall(call.id or "gemini-tool_call", call.name, args))
            except RespParseException:
                # Keep the specific message instead of re-wrapping it in the generic one.
                raise
            except Exception as e:
                raise RespParseException(resp, "响应解析失败,无法解析工具调用参数") from e

    if resp.usage_metadata:
        usage = resp.usage_metadata
        # Thinking tokens are counted together with the completion tokens.
        _usage_record = (
            usage.prompt_token_count or 0,
            (usage.candidates_token_count or 0) + (usage.thoughts_token_count or 0),
            usage.total_token_count or 0,
        )
    else:
        _usage_record = None

    api_response.raw_data = resp

    return api_response, _usage_record
@client_registry.register_client_class("gemini")
class GeminiClient(BaseClient):
    """Model client for the Google Gemini API, registered under the name ``"gemini"``."""

    client: genai.Client

    def __init__(self, api_provider: APIProvider):
        """
        Create the underlying ``genai.Client`` from the provider configuration.

        :param api_provider: Provider configuration supplying ``api_key``
        """
        super().__init__(api_provider)
        self.client = genai.Client(
            api_key=api_provider.api_key,
        )  # Unlike the OpenAI client, Gemini decides on its own whether to retry

    async def get_response(
        self,
        model_info: ModelInfo,
        message_list: list[Message],
        tool_options: list[ToolOption] | None = None,
        max_tokens: int = 1024,
        temperature: float = 0.4,
        response_format: RespFormat | None = None,
        stream_response_handler: Optional[
            Callable[
                [AsyncIterator[GenerateContentResponse], asyncio.Event | None],
                Coroutine[Any, Any, tuple[APIResponse, Optional[tuple[int, int, int]]]],
            ]
        ] = None,
        async_response_parser: Optional[
            Callable[[GenerateContentResponse], tuple[APIResponse, Optional[tuple[int, int, int]]]]
        ] = None,
        interrupt_flag: asyncio.Event | None = None,
        extra_params: dict[str, Any] | None = None,
    ) -> APIResponse:
        """
        Request a chat response from Gemini.

        Args:
            model_info: Model information
            message_list: Conversation messages
            tool_options: Tool options (optional, default None)
            max_tokens: Maximum number of output tokens (optional, default 1024)
            temperature: Sampling temperature (optional, default 0.4)
            response_format: Response format (optional). TEXT maps to ``text/plain``;
                JSON_OBJ / JSON_SCHEMA map to ``application/json``
            stream_response_handler: Stream handler (optional, defaults to
                ``_default_stream_response_handler``)
            async_response_parser: Non-stream parser (optional, defaults to
                ``_default_normal_response_parser``)
            interrupt_flag: Event that aborts the request when set (optional)
            extra_params: Extra options; only ``thinking_budget`` is read here
        Returns:
            APIResponse holding content, reasoning content, tool calls and usage
        Raises:
            ReqAbortException: The interrupt flag was set while waiting
            RespNotOkException: Gemini returned a client or server error
            ValueError: The SDK rejected the tool declarations
            NetworkConnectionError: Any other failure while performing the request
        """
        if stream_response_handler is None:
            stream_response_handler = _default_stream_response_handler

        if async_response_parser is None:
            async_response_parser = _default_normal_response_parser

        # Convert messages to the Gemini format: index 0 is the contents,
        # index 1 the system message(s), if any
        messages = _convert_messages(message_list)
        # Convert tool options to the Gemini format
        tools = _convert_tool_options(tool_options) if tool_options else None
        # Build the generation config
        generation_config_dict = {
            "max_output_tokens": max_tokens,
            "temperature": temperature,
            "response_modalities": ["TEXT"],
            "thinking_config": ThinkingConfig(
                include_thoughts=True,
                thinking_budget=(
                    extra_params["thinking_budget"]
                    if extra_params and "thinking_budget" in extra_params
                    else int(max_tokens / 2)  # default: half of max_tokens, to avoid empty replies
                ),
            ),
            "safety_settings": gemini_safe_settings,  # guards against empty replies
        }
        if tools:
            # GenerateContentConfig.tools is a list of Tool objects
            generation_config_dict["tools"] = [Tool(function_declarations=tools)]
        if messages[1]:
            # The config field is named `system_instruction` (singular)
            generation_config_dict["system_instruction"] = messages[1]
        if response_format and response_format.format_type == RespFormatType.TEXT:
            generation_config_dict["response_mime_type"] = "text/plain"
        elif response_format and response_format.format_type in (RespFormatType.JSON_OBJ, RespFormatType.JSON_SCHEMA):
            generation_config_dict["response_mime_type"] = "application/json"
            # NOTE(review): to_dict() returns a {"format_type", "schema"} wrapper, not a
            # bare schema — confirm this is what response_schema should receive.
            generation_config_dict["response_schema"] = response_format.to_dict()

        generation_config = GenerateContentConfig(**generation_config_dict)

        try:
            if model_info.force_stream_mode:
                req_task = asyncio.create_task(
                    self.client.aio.models.generate_content_stream(
                        model=model_info.model_identifier,
                        contents=messages[0],
                        config=generation_config,
                    )
                )
                while not req_task.done():
                    if interrupt_flag and interrupt_flag.is_set():
                        # Interrupt requested: cancel the request and abort
                        req_task.cancel()
                        raise ReqAbortException("请求被外部信号中断")
                    await asyncio.sleep(0.1)  # poll task / interrupt state every 0.1s
                resp, usage_record = await stream_response_handler(req_task.result(), interrupt_flag)
            else:
                req_task = asyncio.create_task(
                    self.client.aio.models.generate_content(
                        model=model_info.model_identifier,
                        contents=messages[0],
                        config=generation_config,
                    )
                )
                while not req_task.done():
                    if interrupt_flag and interrupt_flag.is_set():
                        # Interrupt requested: cancel the request and abort
                        req_task.cancel()
                        raise ReqAbortException("请求被外部信号中断")
                    await asyncio.sleep(0.5)  # poll task / interrupt state every 0.5s

                resp, usage_record = async_response_parser(req_task.result())
        except (ClientError, ServerError) as e:
            # Re-wrap ClientError / ServerError as RespNotOkException
            raise RespNotOkException(e.code, e.message) from None
        except (
            UnknownFunctionCallArgumentError,
            UnsupportedFunctionError,
            FunctionInvocationError,
        ) as e:
            raise ValueError(f"工具类型错误:请检查工具选项和参数:{str(e)}") from None
        except (ReqAbortException, RespParseException):
            # A deliberate abort or a parse failure is not a connectivity problem;
            # let it through unchanged, as OpenaiClient does.
            raise
        except Exception as e:
            raise NetworkConnectionError() from e

        if usage_record:
            resp.usage = UsageRecord(
                model_name=model_info.name,
                provider_name=model_info.api_provider,
                prompt_tokens=usage_record[0],
                completion_tokens=usage_record[1],
                total_tokens=usage_record[2],
            )

        return resp

    async def get_embedding(
        self,
        model_info: ModelInfo,
        embedding_input: str,
        extra_params: dict[str, Any] | None = None,
    ) -> APIResponse:
        """
        Request a text embedding.

        :param model_info: Model information
        :param embedding_input: Text to embed
        :param extra_params: Accepted for interface compatibility; not used here
        :return: APIResponse with ``embedding`` and ``usage`` set
        :raises RespNotOkException: Gemini returned a client or server error
        :raises NetworkConnectionError: Any other failure while performing the request
        :raises RespParseException: The response contains no embeddings
        """
        try:
            raw_response: EmbedContentResponse = await self.client.aio.models.embed_content(
                model=model_info.model_identifier,
                contents=embedding_input,
                config=EmbedContentConfig(task_type="SEMANTIC_SIMILARITY"),
            )
        except (ClientError, ServerError) as e:
            # Re-wrap ClientError / ServerError as RespNotOkException
            raise RespNotOkException(e.code) from None
        except Exception as e:
            raise NetworkConnectionError() from e

        response = APIResponse()

        # Parse the embedding and record usage
        if hasattr(raw_response, "embeddings") and raw_response.embeddings:
            response.embedding = raw_response.embeddings[0].values
        else:
            raise RespParseException(raw_response, "响应解析失败缺失embeddings字段")

        # The input's character count is recorded as the token count
        response.usage = UsageRecord(
            model_name=model_info.name,
            provider_name=model_info.api_provider,
            prompt_tokens=len(embedding_input),
            completion_tokens=0,
            total_tokens=len(embedding_input),
        )

        return response

    async def get_audio_transcriptions(
        self, model_info: ModelInfo, audio_base64: str, extra_params: dict[str, Any] | None = None
    ) -> APIResponse:
        """
        Transcribe audio by asking the model for a transcript.

        This is a coroutine, matching ``OpenaiClient.get_audio_transcriptions``, so
        the request does not block the event loop.

        :param model_info: Model information
        :param audio_base64: Base64-encoded audio, sent with MIME type ``audio/wav``
        :param extra_params: Extra options (optional); only ``thinking_budget`` is read
        :return: APIResponse with the transcript as content
        :raises RespNotOkException: Gemini returned a client or server error
        :raises RespParseException: The response could not be parsed
        :raises NetworkConnectionError: Any other failure while performing the request
        """
        generation_config_dict = {
            "max_output_tokens": 2048,
            "response_modalities": ["TEXT"],
            "thinking_config": ThinkingConfig(
                include_thoughts=True,
                thinking_budget=(
                    extra_params["thinking_budget"] if extra_params and "thinking_budget" in extra_params else 1024
                ),
            ),
            "safety_settings": gemini_safe_settings,
        }
        generate_content_config = GenerateContentConfig(**generation_config_dict)
        prompt = "Generate a transcript of the speech. The language of the transcript should **match the language of the speech**."

        try:
            raw_response: GenerateContentResponse = await self.client.aio.models.generate_content(
                model=model_info.model_identifier,
                contents=[
                    Content(
                        role="user",
                        parts=[
                            Part.from_text(text=prompt),
                            Part.from_bytes(data=base64.b64decode(audio_base64), mime_type="audio/wav"),
                        ],
                    )
                ],
                config=generate_content_config,
            )
            resp, usage_record = _default_normal_response_parser(raw_response)
        except (ClientError, ServerError) as e:
            # Re-wrap ClientError / ServerError as RespNotOkException
            raise RespNotOkException(e.code) from None
        except RespParseException:
            # A parse failure is not a connectivity problem
            raise
        except Exception as e:
            raise NetworkConnectionError() from e

        if usage_record:
            resp.usage = UsageRecord(
                model_name=model_info.name,
                provider_name=model_info.api_provider,
                prompt_tokens=usage_record[0],
                completion_tokens=usage_record[1],
                total_tokens=usage_record[2],
            )

        return resp

    def get_support_image_formats(self) -> list[str]:
        """
        Return the supported image formats.

        :return: List of supported image format names
        """
        return ["png", "jpg", "jpeg", "webp", "heic", "heif"]

View File

@@ -0,0 +1,591 @@
import asyncio
import io
import json
import re
import base64
from collections.abc import Iterable
from typing import Callable, Any, Coroutine, Optional
from json_repair import repair_json
from openai import (
AsyncOpenAI,
APIConnectionError,
APIStatusError,
NOT_GIVEN,
AsyncStream,
)
from openai.types.chat import (
ChatCompletion,
ChatCompletionChunk,
ChatCompletionMessageParam,
ChatCompletionToolParam,
)
from openai.types.chat.chat_completion_chunk import ChoiceDelta
from src.config.api_ada_configs import ModelInfo, APIProvider
from src.common.logger import get_logger
from .base_client import APIResponse, UsageRecord, BaseClient, client_registry
from ..exceptions import (
RespParseException,
NetworkConnectionError,
RespNotOkException,
ReqAbortException,
)
from ..payload_content.message import Message, RoleType
from ..payload_content.resp_format import RespFormat
from ..payload_content.tool_option import ToolOption, ToolParam, ToolCall
logger = get_logger("OpenAI客户端")
def _convert_messages(messages: list[Message]) -> list[ChatCompletionMessageParam]:
    """
    Translate internal ``Message`` objects into OpenAI chat message dicts.

    :param messages: Messages to convert
    :return: Converted messages, in the same order
    """

    def _convert_part(part: Any) -> dict[str, Any] | None:
        """Convert one element of a multi-part content list; unknown kinds give ``None``."""
        if isinstance(part, tuple):
            # (image_format, base64_data) becomes an inline data URL
            return {
                "type": "image_url",
                "image_url": {"url": f"data:image/{part[0].lower()};base64,{part[1]}"},
            }
        if isinstance(part, str):
            return {"type": "text", "text": part}
        return None

    def _convert_message_item(message: Message) -> ChatCompletionMessageParam:
        """
        Convert a single message.

        :param message: Message object
        :return: Message dict in OpenAI format
        """
        raw_content = message.content
        content: str | list[dict[str, Any]]
        if isinstance(raw_content, str):
            content = raw_content
        elif isinstance(raw_content, list):
            converted_parts = (_convert_part(part) for part in raw_content)
            content = [part for part in converted_parts if part is not None]
        else:
            raise RuntimeError("无法触及的代码请使用MessageBuilder类构建消息对象")

        converted = {
            "role": message.role.value,
            "content": content,
        }

        # Tool results must carry the id of the call they answer
        if message.role == RoleType.Tool:
            if not message.tool_call_id:
                raise ValueError("无法触及的代码请使用MessageBuilder类构建消息对象")
            converted["tool_call_id"] = message.tool_call_id

        return converted  # type: ignore

    return [_convert_message_item(item) for item in messages]
def _convert_tool_options(tool_options: list[ToolOption]) -> list[dict[str, Any]]:
    """
    Translate ``ToolOption`` objects into OpenAI ``tools`` entries.

    :param tool_options: Tool options to convert
    :return: One ``{"type": "function", "function": {...}}`` dict per option
    """

    def _convert_tool_param(tool_option_param: ToolParam) -> dict[str, Any]:
        """
        Convert a single tool parameter.

        :param tool_option_param: Tool parameter object
        :return: Property dict; ``enum`` is added only when enum values are given
        """
        converted: dict[str, Any] = {
            "type": tool_option_param.param_type.value,
            "description": tool_option_param.description,
        }
        if tool_option_param.enum_values:
            converted["enum"] = tool_option_param.enum_values
        return converted

    def _convert_tool_option_item(tool_option: ToolOption) -> dict[str, Any]:
        """
        Convert a single tool option.

        :param tool_option: Tool option object
        :return: Function dict; ``parameters`` is added only when the tool has params
        """
        function_spec: dict[str, Any] = {
            "name": tool_option.name,
            "description": tool_option.description,
        }
        if not tool_option.params:
            return function_spec

        properties: dict[str, Any] = {}
        required_names: list[str] = []
        for param in tool_option.params:
            properties[param.name] = _convert_tool_param(param)
        for param in tool_option.params:
            if param.required:
                required_names.append(param.name)
        function_spec["parameters"] = {
            "type": "object",
            "properties": properties,
            "required": required_names,
        }
        return function_spec

    converted_tools: list[dict[str, Any]] = []
    for option in tool_options:
        converted_tools.append({"type": "function", "function": _convert_tool_option_item(option)})
    return converted_tools
def _process_delta(
    delta: ChoiceDelta,
    has_rc_attr_flag: bool,
    in_rc_flag: bool,
    rc_delta_buffer: io.StringIO,
    fc_delta_buffer: io.StringIO,
    tool_calls_buffer: list[tuple[str, str, io.StringIO]],
) -> bool:
    """
    Fold one streamed delta into the reasoning, content and tool-call buffers.

    :param delta: Delta of the current chunk
    :param has_rc_attr_flag: True when the stream has a separate ``reasoning_content`` field
    :param in_rc_flag: True when currently inside a ``<think>`` ... ``</think>`` span
    :param rc_delta_buffer: Buffer receiving reasoning text
    :param fc_delta_buffer: Buffer receiving answer text
    :param tool_calls_buffer: ``(call_id, function_name, arguments_buffer)`` entries,
        indexed by tool-call index; extended in place
    :return: The updated ``in_rc_flag``
    """
    # Text content
    if has_rc_attr_flag:
        # Reasoning arrives in its own field, so content never needs tag parsing
        if hasattr(delta, "reasoning_content") and delta.reasoning_content:  # type: ignore
            assert isinstance(delta.reasoning_content, str)  # type: ignore
            rc_delta_buffer.write(delta.reasoning_content)  # type: ignore
        elif delta.content:
            fc_delta_buffer.write(delta.content)
    elif hasattr(delta, "content") and delta.content is not None:
        # No separate reasoning field: reasoning may be inlined between think tags
        if in_rc_flag:
            if delta.content == "</think>":
                # Closing tag ends the reasoning span
                in_rc_flag = False
            else:
                rc_delta_buffer.write(delta.content)
        elif delta.content == "<think>" and not fc_delta_buffer.getvalue():
            # "<think>" as the very first output token opens a reasoning span
            in_rc_flag = True
        else:
            fc_delta_buffer.write(delta.content)

    # Tool calls: a chunk may carry several tool-call deltas, so handle each one
    if hasattr(delta, "tool_calls") and delta.tool_calls:
        for tool_call_delta in delta.tool_calls:
            if tool_call_delta.index >= len(tool_calls_buffer):
                # An index at or past the buffer length announces a new tool call
                if tool_call_delta.id and tool_call_delta.function and tool_call_delta.function.name:
                    tool_calls_buffer.append(
                        (
                            tool_call_delta.id,
                            tool_call_delta.function.name,
                            io.StringIO(),
                        )
                    )
                else:
                    logger.warning("工具调用索引号大于等于缓冲区长度但缺少ID或函数信息。")
                    # No buffer exists for this call, so there is nowhere to put its
                    # arguments; skip it instead of raising IndexError below.
                    continue

            if tool_call_delta.function and tool_call_delta.function.arguments:
                # Append the argument fragment to the matching call's buffer
                tool_calls_buffer[tool_call_delta.index][2].write(tool_call_delta.function.arguments)

    return in_rc_flag
def _build_stream_api_resp(
    _fc_delta_buffer: io.StringIO,
    _rc_delta_buffer: io.StringIO,
    _tool_calls_buffer: list[tuple[str, str, io.StringIO]],
) -> APIResponse:
    """
    Assemble the final ``APIResponse`` from the streaming buffers.

    Each buffer is closed once it has been read.

    :param _fc_delta_buffer: Buffer holding the answer text
    :param _rc_delta_buffer: Buffer holding the reasoning text
    :param _tool_calls_buffer: ``(call_id, function_name, arguments_buffer)`` entries
    :return: Response with content, reasoning content and tool calls filled in
    :raises RespParseException: If a tool call's arguments are not a JSON object
    """
    api_resp = APIResponse()

    # Reasoning text, if any was received
    has_reasoning = _rc_delta_buffer.tell() > 0
    if has_reasoning:
        api_resp.reasoning_content = _rc_delta_buffer.getvalue()
    _rc_delta_buffer.close()

    # Answer text, if any was received
    has_content = _fc_delta_buffer.tell() > 0
    if has_content:
        api_resp.content = _fc_delta_buffer.getvalue()
    _fc_delta_buffer.close()

    if not _tool_calls_buffer:
        return api_resp

    # Decode each buffered tool call into a ToolCall object
    api_resp.tool_calls = []
    for call_id, function_name, arguments_buffer in _tool_calls_buffer:
        if arguments_buffer.tell() > 0:
            raw_arg_data = arguments_buffer.getvalue()
            arguments_buffer.close()
            try:
                # repair_json first, so slightly malformed JSON from the model still loads
                arguments = json.loads(repair_json(raw_arg_data))
                if not isinstance(arguments, dict):
                    raise RespParseException(
                        None,
                        f"响应解析失败,工具调用参数无法解析为字典类型。工具调用参数原始响应:\n{raw_arg_data}",
                    )
            except json.JSONDecodeError as e:
                raise RespParseException(
                    None,
                    f"响应解析失败,无法解析工具调用参数。工具调用参数原始响应:{raw_arg_data}",
                ) from e
        else:
            # No argument text was streamed for this call
            arguments_buffer.close()
            arguments = None

        api_resp.tool_calls.append(ToolCall(call_id, function_name, arguments))

    return api_resp
async def _default_stream_response_handler(
    resp_stream: AsyncStream[ChatCompletionChunk],
    interrupt_flag: asyncio.Event | None,
) -> tuple[APIResponse, Optional[tuple[int, int, int]]]:
    """
    Default handler for OpenAI streaming responses.

    :param resp_stream: Stream of chat completion chunks
    :param interrupt_flag: Event that aborts processing when set (may be None)
    :return: ``(api_response, usage)``; ``usage`` is
        ``(prompt_tokens, completion_tokens, total_tokens)`` or ``None``
    :raises ReqAbortException: If the interrupt flag is set while streaming
    """
    has_separate_reasoning = False  # the stream has a dedicated reasoning_content field
    inside_reasoning = False  # currently inside an inline reasoning span
    reasoning_buffer = io.StringIO()  # received reasoning text
    content_buffer = io.StringIO()  # received answer text
    tool_call_entries: list[tuple[str, str, io.StringIO]] = []  # received tool calls
    usage_record = None  # latest usage figures seen

    def _close_buffers():
        # Close every buffer that is still open
        if reasoning_buffer and not reasoning_buffer.closed:
            reasoning_buffer.close()
        if content_buffer and not content_buffer.closed:
            content_buffer.close()
        for entry in tool_call_entries:
            arguments_buffer = entry[2]
            if arguments_buffer and not arguments_buffer.closed:
                arguments_buffer.close()

    def _usage_tuple(usage):
        # Normalise a usage object into (prompt, completion, total)
        return (
            usage.prompt_tokens or 0,
            usage.completion_tokens or 0,
            usage.total_tokens or 0,
        )

    async for chunk in resp_stream:
        if interrupt_flag and interrupt_flag.is_set():
            # Interrupt requested: release the buffers and abort
            _close_buffers()
            raise ReqAbortException("请求被外部信号中断")

        # Frames with no choices (e.g. usage-only frames) carry no delta
        if not (hasattr(chunk, "choices") and chunk.choices):
            if hasattr(chunk, "usage") and chunk.usage:
                usage_record = _usage_tuple(chunk.usage)
            continue  # skip, so choices[0] is never touched on such frames

        delta = chunk.choices[0].delta

        if hasattr(delta, "reasoning_content") and delta.reasoning_content:  # type: ignore
            # Once seen, the stream is treated as having a separate reasoning field
            has_separate_reasoning = True

        inside_reasoning = _process_delta(
            delta,
            has_separate_reasoning,
            inside_reasoning,
            reasoning_buffer,
            content_buffer,
            tool_call_entries,
        )

        if chunk.usage:
            usage_record = _usage_tuple(chunk.usage)

    try:
        api_resp = _build_stream_api_resp(
            content_buffer,
            reasoning_buffer,
            tool_call_entries,
        )
        return api_resp, usage_record
    except Exception:
        # Make sure no buffer stays open when building the response fails
        _close_buffers()
        raise
# Regular expression that splits a completion into reasoning and answer text.
# Alternatives, in order:
#   1. "<think>…</think>rest"   -> groups `think` and `content`
#   2. "<think>…" (never closed) -> group `think_unclosed`
#   3. anything else             -> group `content_only`
# re.DOTALL lets "." span newlines.
pattern = re.compile(
    r"<think>(?P<think>.*?)</think>(?P<content>.*)|<think>(?P<think_unclosed>.*)|(?P<content_only>.+)",
    re.DOTALL,
)
"""用于解析推理内容的正则表达式"""
def _default_normal_response_parser(
    resp: ChatCompletion,
) -> tuple[APIResponse, Optional[tuple[int, int, int]]]:
    """
    Parse a non-streaming chat completion into an ``APIResponse``.

    :param resp: Response object returned by the OpenAI API
    :return: ``(api_response, usage)``; ``usage`` is
        ``(prompt_tokens, completion_tokens, total_tokens)`` or ``None`` when the
        response carries no usage
    :raises RespParseException: If ``choices`` is missing, ``None`` or empty, if the
        content cannot be split, or if tool-call arguments are not a JSON object
    """
    api_response = APIResponse()

    # Truthiness check so a None `choices` is reported as a parse failure instead
    # of raising TypeError from len()
    if not hasattr(resp, "choices") or not resp.choices:
        raise RespParseException(resp, "响应解析失败缺失choices字段")
    message_part = resp.choices[0].message

    if hasattr(message_part, "reasoning_content") and message_part.reasoning_content:  # type: ignore
        # A dedicated, non-empty reasoning field is present
        api_response.content = message_part.content
        api_response.reasoning_content = message_part.reasoning_content  # type: ignore
    elif message_part.content:
        # Reasoning may be inlined in think tags; split it from the answer
        match = pattern.match(message_part.content)
        if not match:
            raise RespParseException(resp, "响应解析失败,无法捕获推理内容和输出内容")
        if match.group("think") is not None:
            result = match.group("think").strip(), match.group("content").strip()
        elif match.group("think_unclosed") is not None:
            result = match.group("think_unclosed").strip(), None
        else:
            result = None, match.group("content_only").strip()
        api_response.reasoning_content, api_response.content = result

    # Tool calls
    if message_part.tool_calls:
        api_response.tool_calls = []
        for call in message_part.tool_calls:
            try:
                arguments = json.loads(repair_json(call.function.arguments))
                if not isinstance(arguments, dict):
                    raise RespParseException(resp, "响应解析失败,工具调用参数无法解析为字典类型")
                api_response.tool_calls.append(ToolCall(call.id, call.function.name, arguments))
            except json.JSONDecodeError as e:
                raise RespParseException(resp, "响应解析失败,无法解析工具调用参数") from e

    # Usage
    if resp.usage:
        _usage_record = (
            resp.usage.prompt_tokens or 0,
            resp.usage.completion_tokens or 0,
            resp.usage.total_tokens or 0,
        )
    else:
        _usage_record = None

    # Keep the raw response available to callers
    api_response.raw_data = resp

    return api_response, _usage_record
@client_registry.register_client_class("openai")
class OpenaiClient(BaseClient):
    """Model client for OpenAI-compatible APIs, registered under the name ``"openai"``."""

    def __init__(self, api_provider: APIProvider):
        """
        Create the underlying ``AsyncOpenAI`` client with SDK-level retries disabled.

        :param api_provider: Provider configuration supplying ``base_url`` and ``api_key``
        """
        super().__init__(api_provider)
        self.client: AsyncOpenAI = AsyncOpenAI(
            base_url=api_provider.base_url,
            api_key=api_provider.api_key,
            max_retries=0,
        )

    async def get_response(
        self,
        model_info: ModelInfo,
        message_list: list[Message],
        tool_options: list[ToolOption] | None = None,
        max_tokens: int = 1024,
        temperature: float = 0.7,
        response_format: RespFormat | None = None,
        stream_response_handler: Optional[
            Callable[
                [AsyncStream[ChatCompletionChunk], asyncio.Event | None],
                Coroutine[Any, Any, tuple[APIResponse, Optional[tuple[int, int, int]]]],
            ]
        ] = None,
        async_response_parser: Optional[
            Callable[[ChatCompletion], tuple[APIResponse, Optional[tuple[int, int, int]]]]
        ] = None,
        interrupt_flag: asyncio.Event | None = None,
        extra_params: dict[str, Any] | None = None,
    ) -> APIResponse:
        """
        Request a chat response.

        Args:
            model_info: Model information
            message_list: Conversation messages
            tool_options: Tool options (optional, default None)
            max_tokens: Maximum number of tokens (optional, default 1024)
            temperature: Sampling temperature (optional, default 0.7)
            response_format: Response format (optional). Currently not forwarded:
                the request is always sent with ``response_format=NOT_GIVEN``
            stream_response_handler: Stream handler (optional, defaults to
                ``_default_stream_response_handler``)
            async_response_parser: Non-stream parser (optional, defaults to
                ``_default_normal_response_parser``)
            interrupt_flag: Event that aborts the request when set (optional)
            extra_params: Extra request fields, passed as ``extra_body``
        Returns:
            APIResponse holding content, reasoning content, tool calls and usage
        Raises:
            ReqAbortException: The interrupt flag was set while waiting
            NetworkConnectionError: The API could not be reached
            RespNotOkException: The API answered with an error status
        """
        if stream_response_handler is None:
            stream_response_handler = _default_stream_response_handler

        if async_response_parser is None:
            async_response_parser = _default_normal_response_parser

        # Convert messages to the OpenAI format
        messages: Iterable[ChatCompletionMessageParam] = _convert_messages(message_list)
        # Convert tool options to the OpenAI format
        tools: Iterable[ChatCompletionToolParam] = _convert_tool_options(tool_options) if tool_options else NOT_GIVEN  # type: ignore

        try:
            if model_info.force_stream_mode:
                req_task = asyncio.create_task(
                    self.client.chat.completions.create(
                        model=model_info.model_identifier,
                        messages=messages,
                        tools=tools,
                        temperature=temperature,
                        max_tokens=max_tokens,
                        stream=True,
                        response_format=NOT_GIVEN,
                        extra_body=extra_params,
                    )
                )
                while not req_task.done():
                    if interrupt_flag and interrupt_flag.is_set():
                        # Interrupt requested: cancel the request and abort
                        req_task.cancel()
                        raise ReqAbortException("请求被外部信号中断")
                    await asyncio.sleep(0.1)  # poll task / interrupt state every 0.1s

                resp, usage_record = await stream_response_handler(req_task.result(), interrupt_flag)
            else:
                # Send the request and wait for the complete response
                req_task = asyncio.create_task(
                    self.client.chat.completions.create(
                        model=model_info.model_identifier,
                        messages=messages,
                        tools=tools,
                        temperature=temperature,
                        max_tokens=max_tokens,
                        stream=False,
                        response_format=NOT_GIVEN,
                        extra_body=extra_params,
                    )
                )
                while not req_task.done():
                    if interrupt_flag and interrupt_flag.is_set():
                        # Interrupt requested: cancel the request and abort
                        req_task.cancel()
                        raise ReqAbortException("请求被外部信号中断")
                    await asyncio.sleep(0.1)  # poll task / interrupt state every 0.1s

                resp, usage_record = async_response_parser(req_task.result())
        except APIConnectionError as e:
            # Re-wrap APIConnectionError as NetworkConnectionError
            raise NetworkConnectionError() from e
        except APIStatusError as e:
            # Re-wrap APIStatusError as RespNotOkException
            raise RespNotOkException(e.status_code, e.message) from e

        if usage_record:
            resp.usage = UsageRecord(
                model_name=model_info.name,
                provider_name=model_info.api_provider,
                prompt_tokens=usage_record[0],
                completion_tokens=usage_record[1],
                total_tokens=usage_record[2],
            )

        return resp

    async def get_embedding(
        self,
        model_info: ModelInfo,
        embedding_input: str,
        extra_params: dict[str, Any] | None = None,
    ) -> APIResponse:
        """
        Request a text embedding.

        :param model_info: Model information
        :param embedding_input: Text to embed
        :param extra_params: Extra request fields, passed as ``extra_body``
        :return: APIResponse with ``embedding`` set, and ``usage`` when reported
        :raises NetworkConnectionError: The API could not be reached
        :raises RespNotOkException: The API answered with an error status
        :raises RespParseException: The response contains no embedding data
        """
        try:
            raw_response = await self.client.embeddings.create(
                model=model_info.model_identifier,
                input=embedding_input,
                extra_body=extra_params,
            )
        except APIConnectionError as e:
            raise NetworkConnectionError() from e
        except APIStatusError as e:
            # Re-wrap APIStatusError as RespNotOkException
            raise RespNotOkException(e.status_code) from e

        response = APIResponse()

        # Embedding
        if len(raw_response.data) > 0:
            response.embedding = raw_response.data[0].embedding
        else:
            raise RespParseException(
                raw_response,
                "响应解析失败,缺失嵌入数据。",
            )

        # Usage: the object may be absent or None, and an embedding usage object is
        # not guaranteed to have `completion_tokens`, so read every field defensively.
        usage = getattr(raw_response, "usage", None)
        if usage:
            response.usage = UsageRecord(
                model_name=model_info.name,
                provider_name=model_info.api_provider,
                prompt_tokens=getattr(usage, "prompt_tokens", 0) or 0,
                completion_tokens=getattr(usage, "completion_tokens", 0) or 0,
                total_tokens=getattr(usage, "total_tokens", 0) or 0,
            )

        return response

    async def get_audio_transcriptions(
        self,
        model_info: ModelInfo,
        audio_base64: str,
        extra_params: dict[str, Any] | None = None,
    ) -> APIResponse:
        """
        Request an audio transcription.

        :param model_info: Model information
        :param audio_base64: Base64-encoded audio data, uploaded as ``audio.wav``
        :param extra_params: Extra request fields, passed as ``extra_body``
        :return: APIResponse with the transcript as content
        :raises NetworkConnectionError: The API could not be reached
        :raises RespNotOkException: The API answered with an error status
        :raises RespParseException: The response has no ``text`` attribute
        """
        try:
            raw_response = await self.client.audio.transcriptions.create(
                model=model_info.model_identifier,
                file=("audio.wav", io.BytesIO(base64.b64decode(audio_base64))),
                extra_body=extra_params,
            )
        except APIConnectionError as e:
            raise NetworkConnectionError() from e
        except APIStatusError as e:
            # Re-wrap APIStatusError as RespNotOkException
            raise RespNotOkException(e.status_code) from e

        response = APIResponse()
        # Transcript
        if hasattr(raw_response, "text"):
            response.content = raw_response.text
        else:
            raise RespParseException(
                raw_response,
                "响应解析失败,缺失转录文本。",
            )
        return response

    def get_support_image_formats(self) -> list[str]:
        """
        Return the supported image formats.

        :return: List of supported image format names
        """
        return ["jpg", "jpeg", "png", "webp", "gif"]

View File

@@ -0,0 +1,3 @@
from .tool_option import ToolCall
__all__ = ["ToolCall"]

View File

@@ -0,0 +1,107 @@
from enum import Enum
# 设计这系列类的目的是为未来可能的扩展做准备
class RoleType(Enum):
    """Role of a message author; the values are the role strings sent to the API."""

    System = "system"
    User = "user"
    Assistant = "assistant"
    Tool = "tool"
SUPPORTED_IMAGE_FORMATS = ["jpg", "jpeg", "png", "webp", "gif"]  # image formats supported by OpenAI
class Message:
    """A single conversation message: a role, its content and an optional tool-call id."""

    def __init__(
        self,
        role: RoleType,
        content: str | list[tuple[str, str] | str],
        tool_call_id: str | None = None,
    ):
        """
        Initialise the message.

        (Build instances through MessageBuilder rather than constructing Message directly.)

        :param role: Role of the author
        :param content: Either plain text, or a list mixing text strings and
            ``(image_format, image_base64)`` tuples
        :param tool_call_id: Id of the tool call this message answers, if any
        """
        self.tool_call_id: str | None = tool_call_id
        self.content: str | list[tuple[str, str] | str] = content
        self.role: RoleType = role
class MessageBuilder:
    """Fluent builder for Message objects; every setter returns the builder itself."""

    def __init__(self):
        self.__role: RoleType = RoleType.User
        self.__content: list[tuple[str, str] | str] = []
        self.__tool_call_id: str | None = None

    def set_role(self, role: RoleType = RoleType.User) -> "MessageBuilder":
        """
        Set the role (User when omitted).

        :param role: Role to use
        :return: This builder
        """
        self.__role = role
        return self

    def add_text_content(self, text: str) -> "MessageBuilder":
        """
        Append a text fragment.

        :param text: Text to append
        :return: This builder
        """
        self.__content.append(text)
        return self

    def add_image_content(
        self,
        image_format: str,
        image_base64: str,
        support_formats: list[str] = SUPPORTED_IMAGE_FORMATS,  # formats accepted by default
    ) -> "MessageBuilder":
        """
        Append an image.

        :param image_format: Image format name, matched case-insensitively
        :param image_base64: Base64-encoded image data
        :param support_formats: Accepted formats (defaults to SUPPORTED_IMAGE_FORMATS)
        :return: This builder
        :raises ValueError: If the format is not accepted or the data is empty
        """
        format_supported = image_format.lower() in support_formats
        if not format_supported:
            raise ValueError("不受支持的图片格式")
        if not image_base64:
            raise ValueError("图片的base64编码不能为空")
        image_entry = (image_format, image_base64)
        self.__content.append(image_entry)
        return self

    def add_tool_call(self, tool_call_id: str) -> "MessageBuilder":
        """
        Attach a tool-call id. The role must already be set to Tool.

        :param tool_call_id: Id of the tool call
        :return: This builder
        :raises ValueError: If the role is not Tool or the id is empty
        """
        if self.__role != RoleType.Tool:
            raise ValueError("仅当角色为Tool时才能添加工具调用ID")
        if not tool_call_id:
            raise ValueError("工具调用ID不能为空")
        self.__tool_call_id = tool_call_id
        return self

    def build(self) -> Message:
        """
        Build the message.

        :return: The Message
        :raises ValueError: If no content was added, or the role is Tool and no
            tool-call id was set
        """
        if not self.__content:
            raise ValueError("内容不能为空")
        if self.__role == RoleType.Tool and self.__tool_call_id is None:
            raise ValueError("Tool角色的工具调用ID不能为空")

        # A single text fragment is handed over as a plain string; anything else
        # stays a list.
        is_single_text = len(self.__content) == 1 and isinstance(self.__content[0], str)
        payload = self.__content[0] if is_single_text else self.__content

        return Message(
            role=self.__role,
            content=payload,
            tool_call_id=self.__tool_call_id,
        )

View File

@@ -0,0 +1,223 @@
from enum import Enum
from typing import Optional, Any
from pydantic import BaseModel
from typing_extensions import TypedDict, Required
class RespFormatType(Enum):
    """Kinds of response format that can be requested from a model."""

    TEXT = "text"  # plain text
    JSON_OBJ = "json_object"  # JSON
    JSON_SCHEMA = "json_schema"  # JSON Schema
class JsonSchema(TypedDict, total=False):
name: Required[str]
"""
The name of the response format.
Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length
of 64.
"""
description: Optional[str]
"""
A description of what the response format is for, used by the model to determine
how to respond in the format.
"""
schema: dict[str, object]
"""
The schema for the response format, described as a JSON Schema object. Learn how
to build JSON schemas [here](https://json-schema.org/).
"""
strict: Optional[bool]
"""
Whether to enable strict schema adherence when generating the output. If set to
true, the model will always follow the exact schema defined in the `schema`
field. Only a subset of JSON Schema is supported when `strict` is `true`. To
learn more, read the
[Structured Outputs guide](https://platform.openai.com/docs/guides/structured-outputs).
"""
def _json_schema_type_check(instance) -> str | None:
if "name" not in instance:
return "schema必须包含'name'字段"
elif not isinstance(instance["name"], str) or instance["name"].strip() == "":
return "schema的'name'字段必须是非空字符串"
if "description" in instance and (
not isinstance(instance["description"], str)
or instance["description"].strip() == ""
):
return "schema的'description'字段只能填入非空字符串"
if "schema" not in instance:
return "schema必须包含'schema'字段"
elif not isinstance(instance["schema"], dict):
return "schema的'schema'字段必须是字典详见https://json-schema.org/"
if "strict" in instance and not isinstance(instance["strict"], bool):
return "schema的'strict'字段只能填入布尔值"
return None
def _remove_title(schema: dict[str, Any] | list[Any]) -> dict[str, Any] | list[Any]:
"""
递归移除JSON Schema中的title字段
"""
if isinstance(schema, list):
# 如果当前Schema是列表则对所有dict/list子元素递归调用
for idx, item in enumerate(schema):
if isinstance(item, (dict, list)):
schema[idx] = _remove_title(item)
elif isinstance(schema, dict):
# 是字典移除title字段并对所有dict/list子元素递归调用
if "title" in schema:
del schema["title"]
for key, value in schema.items():
if isinstance(value, (dict, list)):
schema[key] = _remove_title(value)
return schema
def _link_definitions(schema: dict[str, Any]) -> dict[str, Any]:
"""
链接JSON Schema中的definitions字段
"""
def link_definitions_recursive(
path: str, sub_schema: list[Any] | dict[str, Any], defs: dict[str, Any]
) -> dict[str, Any]:
"""
递归链接JSON Schema中的definitions字段
:param path: 当前路径
:param sub_schema: 子Schema
:param defs: Schema定义集
:return:
"""
if isinstance(sub_schema, list):
# 如果当前Schema是列表则遍历每个元素
for i in range(len(sub_schema)):
if isinstance(sub_schema[i], dict):
sub_schema[i] = link_definitions_recursive(
f"{path}/{str(i)}", sub_schema[i], defs
)
else:
# 否则为字典
if "$defs" in sub_schema:
# 如果当前Schema有$def字段则将其添加到defs中
key_prefix = f"{path}/$defs/"
for key, value in sub_schema["$defs"].items():
def_key = key_prefix + key
if def_key not in defs:
defs[def_key] = value
del sub_schema["$defs"]
if "$ref" in sub_schema:
# 如果当前Schema有$ref字段则将其替换为defs中的定义
def_key = sub_schema["$ref"]
if def_key in defs:
sub_schema = defs[def_key]
else:
raise ValueError(f"Schema中引用的定义'{def_key}'不存在")
# 遍历键值对
for key, value in sub_schema.items():
if isinstance(value, (dict, list)):
# 如果当前值是字典或列表,则递归调用
sub_schema[key] = link_definitions_recursive(
f"{path}/{key}", value, defs
)
return sub_schema
return link_definitions_recursive("#", schema, {})
def _remove_defs(schema: dict[str, Any]) -> dict[str, Any]:
    """
    Recursively strip every ``$defs`` section from a JSON Schema, in place.

    The recursion calls this function itself, so nested ``$defs`` are removed at
    every depth and no other key is touched.

    :param schema: Schema dict, or a list of sub-schemas
    :return: The same object that was passed in
    """
    if isinstance(schema, list):
        # List: recurse into every dict / list element
        for idx, item in enumerate(schema):
            if isinstance(item, (dict, list)):
                schema[idx] = _remove_defs(item)
    elif isinstance(schema, dict):
        # Dict: drop $defs, then recurse into dict / list values
        if "$defs" in schema:
            del schema["$defs"]
        for key, value in schema.items():
            if isinstance(value, (dict, list)):
                schema[key] = _remove_defs(value)

    return schema
class RespFormat:
    """
    Response format requested from a model.
    """

    @staticmethod
    def _generate_schema_from_model(schema):
        """
        Build a JsonSchema-shaped dict from a model class.

        The class's ``model_json_schema()`` output has its titles removed, its
        ``$ref`` references inlined and its ``$defs`` dropped.

        :param schema: Model class providing ``model_json_schema()``
        :return: Dict with ``name``, ``schema``, ``strict`` and, when the class has
            a docstring, ``description``
        """
        json_schema = {
            "name": schema.__name__,
            "schema": _remove_defs(
                _link_definitions(_remove_title(schema.model_json_schema()))
            ),
            "strict": False,
        }
        if schema.__doc__:
            json_schema["description"] = schema.__doc__
        return json_schema

    def __init__(
        self,
        format_type: RespFormatType = RespFormatType.TEXT,
        schema: type | JsonSchema | None = None,
    ):
        """
        Create a response format.

        :param format_type: Response format type (text by default)
        :param schema: Model class or JsonSchema dict; only used when
            ``format_type`` is JSON_SCHEMA
        :raises ValueError: If JSON_SCHEMA is requested and ``schema`` is missing,
            malformed, of an unsupported kind, or cannot be converted
        """
        self.format_type: RespFormatType = format_type

        if format_type == RespFormatType.JSON_SCHEMA:
            if schema is None:
                raise ValueError("当format_type为'JSON_SCHEMA'schema不能为空")
            if isinstance(schema, dict):
                if check_msg := _json_schema_type_check(schema):
                    raise ValueError(f"schema格式不正确{check_msg}")

                self.schema = schema
            elif isinstance(schema, type) and issubclass(schema, BaseModel):
                # isinstance(schema, type) first: issubclass() raises TypeError for a
                # non-class argument, which would bypass the ValueError below
                try:
                    json_schema = self._generate_schema_from_model(schema)

                    self.schema = json_schema
                except Exception as e:
                    # Include the underlying error after the "detailed information" header
                    raise ValueError(
                        f"自动生成JSON Schema时发生异常请检查模型类{schema.__name__}的定义,详细信息:\n"
                        f"{schema.__name__}:\n{e}"
                    ) from e
            else:
                raise ValueError("schema必须是BaseModel的子类或JsonSchema")
        else:
            self.schema = None

    def to_dict(self):
        """
        Convert the response format to a dict.

        :return: ``{"format_type": ...}``, plus ``"schema"`` when a schema is set
        """
        if self.schema:
            return {
                "format_type": self.format_type.value,
                "schema": self.schema,
            }
        else:
            return {
                "format_type": self.format_type.value,
            }

View File

@@ -0,0 +1,163 @@
from enum import Enum
class ToolParamType(Enum):
    """
    Types a tool-call parameter can have.
    """

    # NOTE(review): "float" and "bool" are not JSON Schema type names ("number",
    # "boolean"); confirm how consumers of `.value` map them.
    STRING = "string"  # string
    INTEGER = "integer"  # integer
    FLOAT = "float"  # floating point
    BOOLEAN = "bool"  # boolean
class ToolParam:
    """
    A single parameter of a tool.
    """

    def __init__(
        self,
        name: str,
        param_type: ToolParamType,
        description: str,
        required: bool,
        enum_values: list[str] | None = None,
    ):
        """
        Initialise the tool parameter.

        (Build instances through ToolOptionBuilder rather than constructing ToolParam directly.)

        :param name: Parameter name
        :param param_type: Parameter type
        :param description: Parameter description
        :param required: Whether the parameter is mandatory
        :param enum_values: Allowed values, or None when unrestricted
        """
        self.enum_values: list[str] | None = enum_values
        self.required: bool = required
        self.description: str = description
        self.param_type: ToolParamType = param_type
        self.name: str = name
class ToolOption:
    """
    A tool the model may call.
    """

    def __init__(
        self,
        name: str,
        description: str,
        params: list[ToolParam] | None = None,
    ):
        """
        Initialise the tool option.

        (Build instances through ToolOptionBuilder rather than constructing ToolOption directly.)

        :param name: Tool name
        :param description: Tool description
        :param params: Tool parameters, or None when the tool takes none
        """
        self.params: list[ToolParam] | None = params
        self.description: str = description
        self.name: str = name
class ToolOptionBuilder:
    """
    Fluent builder for ToolOption objects; every setter returns the builder itself.
    """

    def __init__(self):
        self.__name: str = ""
        self.__description: str = ""
        self.__params: list[ToolParam] = []

    def set_name(self, name: str) -> "ToolOptionBuilder":
        """
        Set the tool name.

        :param name: Tool name
        :return: This builder
        :raises ValueError: If the name is empty
        """
        if name:
            self.__name = name
            return self
        raise ValueError("工具名称不能为空")

    def set_description(self, description: str) -> "ToolOptionBuilder":
        """
        Set the tool description.

        :param description: Tool description
        :return: This builder
        :raises ValueError: If the description is empty
        """
        if description:
            self.__description = description
            return self
        raise ValueError("工具描述不能为空")

    def add_param(
        self,
        name: str,
        param_type: ToolParamType,
        description: str,
        required: bool = False,
        enum_values: list[str] | None = None,
    ) -> "ToolOptionBuilder":
        """
        Add a tool parameter.

        :param name: Parameter name
        :param param_type: Parameter type
        :param description: Parameter description
        :param required: Whether the parameter is mandatory (False by default)
        :param enum_values: Allowed values, or None when unrestricted
        :return: This builder
        :raises ValueError: If the name or description is empty
        """
        if not (name and description):
            raise ValueError("参数名称/描述不能为空")

        new_param = ToolParam(
            name=name,
            param_type=param_type,
            description=description,
            required=required,
            enum_values=enum_values,
        )
        self.__params.append(new_param)
        return self

    def build(self):
        """
        Build the tool option.

        :return: The ToolOption; its ``params`` is None when no parameter was added
        :raises ValueError: If the name or description was never set
        """
        if self.__name == "" or self.__description == "":
            raise ValueError("工具名称/描述不能为空")

        collected_params = self.__params if len(self.__params) != 0 else None
        return ToolOption(
            name=self.__name,
            description=self.__description,
            params=collected_params,
        )
class ToolCall:
"""
来自模型反馈的工具调用
"""
def __init__(
self,
call_id: str,
func_name: str,
args: dict | None = None,
):
"""
初始化工具调用
:param call_id: 工具调用ID
:param func_name: 要调用的函数名称
:param args: 工具调用参数
"""
self.call_id: str = call_id
self.func_name: str = func_name
self.args: dict | None = args

189
src/llm_models/utils.py Normal file
View File

@@ -0,0 +1,189 @@
import base64
import io
from PIL import Image
from datetime import datetime
from src.common.logger import get_logger
from src.common.database.database import db # 确保 db 被导入用于 create_tables
from src.common.database.database_model import LLMUsage
from src.config.api_ada_configs import ModelInfo
from .payload_content.message import Message, MessageBuilder
from .model_client.base_client import UsageRecord
logger = get_logger("消息压缩工具")
def compress_messages(messages: list[Message], img_target_size: int = 1 * 1024 * 1024) -> list[Message]:
    """
    Compress the images contained in a list of messages.

    Messages whose ``content`` is a list are rebuilt through MessageBuilder:
    tuple items are treated as images (``item[0]`` is passed through unchanged,
    ``item[1]`` is base64 image data that gets compressed), every other item is
    re-added as text. Messages whose content is not a list are passed through
    as they are.

    :param messages: list of messages
    :param img_target_size: target size of the base64-encoded image, defaults to 1MB
    :return: list of messages with compressed images
    """

    def reformat_static_image(image_data: bytes) -> bytes:
        """
        Re-encode a static JPEG/PNG/WEBP image as JPEG.

        :param image_data: encoded image data
        :return: JPEG data, or the unchanged input if the image is animated,
            has another format, or the conversion fails
        """
        try:
            # Image.open() interprets a raw ``bytes`` argument as a file name,
            # so the encoded data must be wrapped in a file-like object.
            image = Image.open(io.BytesIO(image_data))
            # Animated images are left alone here; flattening them to JPEG
            # would drop every frame but the first.
            is_static = not getattr(image, "is_animated", False)
            if is_static and image.format and image.format.upper() in ("JPEG", "JPG", "PNG", "WEBP"):
                # JPEG cannot store alpha or palette data.
                if image.mode not in ("RGB", "L"):
                    image = image.convert("RGB")
                reformated_image_data = io.BytesIO()
                image.save(reformated_image_data, format="JPEG", quality=95, optimize=True)
                image_data = reformated_image_data.getvalue()

            return image_data
        except Exception as e:
            # Best effort: on failure the caller continues with the original data.
            logger.error(f"图片转换格式失败: {str(e)}")
            return image_data

    def rescale_image(image_data: bytes, scale: float) -> tuple[bytes, tuple[int, int] | None, tuple[int, int] | None]:
        """
        Scale an image down.

        :param image_data: encoded image data
        :param scale: scale factor applied to width and height
        :return: ``(data, original_size, new_size)``; on failure the unchanged
            input data and ``None`` for both sizes
        """
        try:
            image = Image.open(io.BytesIO(image_data))

            original_size = (image.width, image.height)

            # Never let a dimension collapse to 0, which resize() rejects.
            new_size = (max(1, int(original_size[0] * scale)), max(1, int(original_size[1] * scale)))

            output_buffer = io.BytesIO()

            if getattr(image, "is_animated", False):
                # Animated image: every frame is resized, and the target size
                # is halved once more.
                new_size = (max(1, new_size[0] // 2), max(1, new_size[1] // 2))
                frames = []
                for frame_idx in range(getattr(image, "n_frames", 1)):
                    image.seek(frame_idx)
                    frames.append(image.copy().resize(new_size, Image.Resampling.LANCZOS))

                frames[0].save(
                    output_buffer,
                    format="GIF",
                    save_all=True,
                    append_images=frames[1:],
                    optimize=True,
                    duration=image.info.get("duration", 100),
                    loop=image.info.get("loop", 0),
                )
            else:
                # Static image: resize and store as JPEG.
                resized_image = image.resize(new_size, Image.Resampling.LANCZOS)
                if resized_image.mode not in ("RGB", "L"):
                    resized_image = resized_image.convert("RGB")
                resized_image.save(output_buffer, format="JPEG", quality=95, optimize=True)

            return output_buffer.getvalue(), original_size, new_size

        except Exception as e:
            logger.error(f"图片缩放失败: {str(e)}")
            import traceback

            logger.error(traceback.format_exc())
            return image_data, None, None

    def compress_base64_image(base64_data: str, target_size: int = 1 * 1024 * 1024) -> str:
        """
        Compress one base64-encoded image towards ``target_size``.

        :param base64_data: base64-encoded image
        :param target_size: target length of the base64 string
        :return: base64-encoded image after JPEG conversion and, if still too
            large, one down-scaling pass
        """
        original_b64_data_size = len(base64_data)

        image_data = base64.b64decode(base64_data)

        # First try a plain JPEG re-encode.
        image_data = reformat_static_image(image_data)
        base64_data = base64.b64encode(image_data).decode("utf-8")
        if len(base64_data) <= target_size:
            logger.info(f"成功将图片转为JPEG格式编码后大小: {len(base64_data) / 1024:.1f}KB")
            return base64_data

        # Still too large: shrink the dimensions by the size ratio.
        scale = min(1.0, target_size / len(base64_data))
        image_data, original_size, new_size = rescale_image(image_data, scale)
        base64_data = base64.b64encode(image_data).decode("utf-8")

        if original_size and new_size:
            logger.info(
                f"压缩图片: {original_size[0]}x{original_size[1]} -> {new_size[0]}x{new_size[1]}\n"
                f"压缩前大小: {original_b64_data_size / 1024:.1f}KB, 压缩后大小: {len(base64_data) / 1024:.1f}KB"
            )

        return base64_data

    compressed_messages = []
    for message in messages:
        if isinstance(message.content, list):
            # Rebuild the message, compressing every image item on the way.
            message_builder = MessageBuilder()
            for content_item in message.content:
                if isinstance(content_item, tuple):
                    # Image item.
                    message_builder.add_image_content(
                        content_item[0],
                        compress_base64_image(content_item[1], target_size=img_target_size),
                    )
                else:
                    message_builder.add_text_content(content_item)
            compressed_messages.append(message_builder.build())
        else:
            compressed_messages.append(message)

    return compressed_messages
class LLMUsageRecorder:
    """
    Records LLM usage (token counts, cost, latency) to the database.
    """

    def __init__(self):
        try:
            # Create the table through Peewee; safe=True means an existing
            # table does not raise an error.
            db.create_tables([LLMUsage], safe=True)
            # logger.debug("LLMUsage 表已初始化/确保存在。")
        except Exception as e:
            logger.error(f"创建 LLMUsage 表失败: {str(e)}")

    def record_usage_to_database(
        self,
        model_info: ModelInfo,
        model_usage: UsageRecord,
        user_id: str,
        request_type: str,
        endpoint: str,
        time_cost: float = 0.0,
    ):
        """
        Store one usage record in the LLMUsage table.

        The cost is computed from the token counts and the model's per-million
        prices (``price_in`` / ``price_out``). Failures while writing the record
        are logged and not propagated.

        :param model_info: configuration of the model that served the request
        :param model_usage: token usage reported for the request
        :param user_id: identifier of the user the request is attributed to
        :param request_type: request type label stored with the record
        :param endpoint: endpoint label stored with the record
        :param time_cost: time spent on the request, stored rounded to 3 decimals
        """
        # Normalise missing counts once, so that the cost computation is guarded
        # the same way as the stored columns instead of raising on None.
        prompt_tokens = model_usage.prompt_tokens or 0
        completion_tokens = model_usage.completion_tokens or 0
        total_tokens = model_usage.total_tokens or 0

        input_cost = (prompt_tokens / 1000000) * model_info.price_in
        output_cost = (completion_tokens / 1000000) * model_info.price_out
        total_cost = round(input_cost + output_cost, 6)
        try:
            # Create the record through the Peewee model.
            LLMUsage.create(
                model_name=model_info.model_identifier,
                model_assign_name=model_info.name,
                model_api_provider=model_info.api_provider,
                user_id=user_id,
                request_type=request_type,
                endpoint=endpoint,
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=total_tokens,
                cost=total_cost,
                time_cost=round(time_cost or 0.0, 3),
                status="success",
                timestamp=datetime.now(),
            )
            logger.debug(
                f"Token使用情况 - 模型: {model_usage.model_name}, "
                f"用户: {user_id}, 类型: {request_type}, "
                f"提示词: {model_usage.prompt_tokens}, 完成: {model_usage.completion_tokens}, "
                f"总计: {model_usage.total_tokens}"
            )
        except Exception as e:
            logger.error(f"记录token使用情况失败: {str(e)}")


# Module-level recorder instance.
llm_usage_recorder = LLMUsageRecorder()

File diff suppressed because it is too large Load Diff

View File

@@ -2,12 +2,10 @@ import asyncio
import time import time
from maim_message import MessageServer from maim_message import MessageServer
from src.chat.express.expression_learner import get_expression_learner
from src.common.remote import TelemetryHeartBeatTask from src.common.remote import TelemetryHeartBeatTask
from src.manager.async_task_manager import async_task_manager from src.manager.async_task_manager import async_task_manager
from src.chat.utils.statistic import OnlineTimeRecordTask, StatisticOutputTask from src.chat.utils.statistic import OnlineTimeRecordTask, StatisticOutputTask
from src.chat.emoji_system.emoji_manager import get_emoji_manager from src.chat.emoji_system.emoji_manager import get_emoji_manager
from src.chat.willing.willing_manager import get_willing_manager
from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.message_receive.chat_stream import get_chat_manager
from src.config.config import global_config from src.config.config import global_config
from src.chat.message_receive.bot import chat_bot from src.chat.message_receive.bot import chat_bot
@@ -16,6 +14,7 @@ from src.individuality.individuality import get_individuality, Individuality
from src.common.server import get_global_server, Server from src.common.server import get_global_server, Server
from src.mood.mood_manager import mood_manager from src.mood.mood_manager import mood_manager
from rich.traceback import install from rich.traceback import install
from src.migrate_helper.migrate import check_and_run_migrations
# from src.api.main import start_api_server # from src.api.main import start_api_server
# 导入新的插件管理器 # 导入新的插件管理器
@@ -32,8 +31,6 @@ if global_config.memory.enable_memory:
install(extra_lines=3) install(extra_lines=3)
willing_manager = get_willing_manager()
logger = get_logger("main") logger = get_logger("main")
@@ -53,12 +50,22 @@ class MainSystem:
async def initialize(self): async def initialize(self):
"""初始化系统组件""" """初始化系统组件"""
logger.debug(f"正在唤醒{global_config.bot.nickname}......") logger.info(f"正在唤醒{global_config.bot.nickname}......")
# 其他初始化任务 # 其他初始化任务
await asyncio.gather(self._init_components()) await asyncio.gather(self._init_components())
logger.debug("系统初始化完成") logger.info(f"""
--------------------------------
全部系统初始化完成,{global_config.bot.nickname}已成功唤醒
--------------------------------
如果想要自定义{global_config.bot.nickname}的功能,请查阅https://docs.mai-mai.org/manual/usage/
或者遇到了问题,请访问我们的文档:https://docs.mai-mai.org/
--------------------------------
如果你想要编写或了解插件相关内容请访问开发文档https://docs.mai-mai.org/develop/
--------------------------------
如果你需要查阅模型的消耗以及麦麦的统计数据请访问根目录的maibot_statistics.html文件
""")
async def _init_components(self): async def _init_components(self):
"""初始化其他组件""" """初始化其他组件"""
@@ -84,11 +91,6 @@ class MainSystem:
get_emoji_manager().initialize() get_emoji_manager().initialize()
logger.info("表情包管理器初始化成功") logger.info("表情包管理器初始化成功")
# 启动愿望管理器
await willing_manager.async_task_starter()
logger.info("willing管理器初始化成功")
# 启动情绪管理器 # 启动情绪管理器
await mood_manager.start() await mood_manager.start()
logger.info("情绪管理器初始化成功") logger.info("情绪管理器初始化成功")
@@ -115,6 +117,9 @@ class MainSystem:
# 初始化个体特征 # 初始化个体特征
await self.individuality.initialize() await self.individuality.initialize()
await check_and_run_migrations()
try: try:
init_time = int(1000 * (time.time() - init_start_time)) init_time = int(1000 * (time.time() - init_start_time))
@@ -136,23 +141,14 @@ class MainSystem:
if global_config.memory.enable_memory and self.hippocampus_manager: if global_config.memory.enable_memory and self.hippocampus_manager:
tasks.extend( tasks.extend(
[ [
self.build_memory_task(), # 移除记忆构建的定期调用改为在heartFC_chat.py中调用
# self.build_memory_task(),
self.forget_memory_task(), self.forget_memory_task(),
self.consolidate_memory_task(),
] ]
) )
tasks.append(self.learn_and_store_expression_task())
await asyncio.gather(*tasks) await asyncio.gather(*tasks)
async def build_memory_task(self):
"""记忆构建任务"""
while True:
await asyncio.sleep(global_config.memory.memory_build_interval)
logger.info("正在进行记忆构建")
await self.hippocampus_manager.build_memory() # type: ignore
async def forget_memory_task(self): async def forget_memory_task(self):
"""记忆遗忘任务""" """记忆遗忘任务"""
while True: while True:
@@ -161,24 +157,7 @@ class MainSystem:
await self.hippocampus_manager.forget_memory(percentage=global_config.memory.memory_forget_percentage) # type: ignore await self.hippocampus_manager.forget_memory(percentage=global_config.memory.memory_forget_percentage) # type: ignore
logger.info("[记忆遗忘] 记忆遗忘完成") logger.info("[记忆遗忘] 记忆遗忘完成")
async def consolidate_memory_task(self):
"""记忆整合任务"""
while True:
await asyncio.sleep(global_config.memory.consolidate_memory_interval)
logger.info("[记忆整合] 开始整合记忆...")
await self.hippocampus_manager.consolidate_memory() # type: ignore
logger.info("[记忆整合] 记忆整合完成")
@staticmethod
async def learn_and_store_expression_task():
"""学习并存储表达方式任务"""
expression_learner = get_expression_learner()
while True:
await asyncio.sleep(global_config.expression.learning_interval)
if global_config.expression.enable_expression_learning and global_config.expression.enable_expression:
logger.info("[表达方式学习] 开始学习表达方式...")
await expression_learner.learn_and_store_expression()
logger.info("[表达方式学习] 表达方式学习完成")
async def main(): async def main():
@@ -192,3 +171,5 @@ async def main():
if __name__ == "__main__": if __name__ == "__main__":
asyncio.run(main()) asyncio.run(main())

View File

@@ -1,132 +0,0 @@
[inner]
version = "1.1.0"
#----以下是S4U聊天系统配置文件----
# S4U (Smart 4 U) 聊天系统是MaiBot的核心对话模块
# 支持优先级队列、消息中断、VIP用户等高级功能
#
# 如果你想要修改配置文件请在修改后将version的值进行变更
# 如果新增项目请参考src/mais4u/s4u_config.py中的S4UConfig类
#
# 版本格式:主版本号.次版本号.修订号
#----S4U配置说明结束----
[s4u]
# 消息管理配置
message_timeout_seconds = 80 # 普通消息存活时间(秒),超过此时间的消息将被丢弃
recent_message_keep_count = 8 # 保留最近N条消息超出范围的普通消息将被移除
# 优先级系统配置
at_bot_priority_bonus = 100.0 # @机器人时的优先级加成分数
vip_queue_priority = true # 是否启用VIP队列优先级系统
enable_message_interruption = true # 是否允许高优先级消息中断当前回复
# 打字效果配置
typing_delay = 0.1 # 打字延迟时间(秒),模拟真实打字速度
enable_dynamic_typing_delay = false # 是否启用基于文本长度的动态打字延迟
# 动态打字延迟参数仅在enable_dynamic_typing_delay=true时生效
chars_per_second = 15.0 # 每秒字符数,用于计算动态打字延迟
min_typing_delay = 0.2 # 最小打字延迟(秒)
max_typing_delay = 2.0 # 最大打字延迟(秒)
# 系统功能开关
enable_old_message_cleanup = true # 是否自动清理过旧的普通消息
enable_loading_indicator = true # 是否显示加载提示
enable_streaming_output = false # 是否启用流式输出false时全部生成后一次性发送
max_context_message_length = 30
max_core_message_length = 20
# 模型配置
[models]
# 主要对话模型配置
[models.chat]
name = "qwen3-8b"
provider = "BAILIAN"
pri_in = 0.5
pri_out = 2
temp = 0.7
enable_thinking = false
# 规划模型配置
[models.motion]
name = "qwen3-8b"
provider = "BAILIAN"
pri_in = 0.5
pri_out = 2
temp = 0.7
enable_thinking = false
# 情感分析模型配置
[models.emotion]
name = "qwen3-8b"
provider = "BAILIAN"
pri_in = 0.5
pri_out = 2
temp = 0.7
# 记忆模型配置
[models.memory]
name = "qwen3-8b"
provider = "BAILIAN"
pri_in = 0.5
pri_out = 2
temp = 0.7
# 工具使用模型配置
[models.tool_use]
name = "qwen3-8b"
provider = "BAILIAN"
pri_in = 0.5
pri_out = 2
temp = 0.7
# 嵌入模型配置
[models.embedding]
name = "text-embedding-v1"
provider = "OPENAI"
dimension = 1024
# 视觉语言模型配置
[models.vlm]
name = "qwen-vl-plus"
provider = "BAILIAN"
pri_in = 0.5
pri_out = 2
temp = 0.7
# 知识库模型配置
[models.knowledge]
name = "qwen3-8b"
provider = "BAILIAN"
pri_in = 0.5
pri_out = 2
temp = 0.7
# 实体提取模型配置
[models.entity_extract]
name = "qwen3-8b"
provider = "BAILIAN"
pri_in = 0.5
pri_out = 2
temp = 0.7
# 问答模型配置
[models.qa]
name = "qwen3-8b"
provider = "BAILIAN"
pri_in = 0.5
pri_out = 2
temp = 0.7
# 兼容性配置已废弃请使用models.motion
[model_motion] # 在麦麦的一些组件中使用的小模型,消耗量较大,建议使用速度较快的小模型
# 强烈建议使用免费的小模型
name = "qwen3-8b"
provider = "BAILIAN"
pri_in = 0.5
pri_out = 2
temp = 0.7
enable_thinking = false # 是否启用思考

View File

@@ -1,5 +1,5 @@
[inner] [inner]
version = "1.1.0" version = "1.2.0"
#----以下是S4U聊天系统配置文件---- #----以下是S4U聊天系统配置文件----
# S4U (Smart 4 U) 聊天系统是MaiBot的核心对话模块 # S4U (Smart 4 U) 聊天系统是MaiBot的核心对话模块
@@ -12,6 +12,7 @@ version = "1.1.0"
#----S4U配置说明结束---- #----S4U配置说明结束----
[s4u] [s4u]
enable_s4u = false
# 消息管理配置 # 消息管理配置
message_timeout_seconds = 120 # 普通消息存活时间(秒),超过此时间的消息将被丢弃 message_timeout_seconds = 120 # 普通消息存活时间(秒),超过此时间的消息将被丢弃
recent_message_keep_count = 6 # 保留最近N条消息超出范围的普通消息将被移除 recent_message_keep_count = 6 # 保留最近N条消息超出范围的普通消息将被移除

View File

@@ -1 +0,0 @@
ENABLE_S4U = False

View File

@@ -2,13 +2,15 @@ from src.chat.message_receive.chat_stream import get_chat_manager
import time import time
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config from src.config.config import model_config
from src.chat.message_receive.message import MessageRecvS4U from src.chat.message_receive.message import MessageRecvS4U
from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor
from src.mais4u.mais4u_chat.internal_manager import internal_manager from src.mais4u.mais4u_chat.internal_manager import internal_manager
from src.common.logger import get_logger from src.common.logger import get_logger
logger = get_logger(__name__) logger = get_logger(__name__)
def init_prompt(): def init_prompt():
Prompt( Prompt(
""" """
@@ -32,10 +34,8 @@ def init_prompt():
) )
class MaiThinking: class MaiThinking:
def __init__(self,chat_id): def __init__(self, chat_id):
self.chat_id = chat_id self.chat_id = chat_id
self.chat_stream = get_chat_manager().get_stream(chat_id) self.chat_stream = get_chat_manager().get_stream(chat_id)
self.platform = self.chat_stream.platform self.platform = self.chat_stream.platform
@@ -44,11 +44,11 @@ class MaiThinking:
self.is_group = True self.is_group = True
else: else:
self.is_group = False self.is_group = False
self.s4u_message_processor = S4UMessageProcessor() self.s4u_message_processor = S4UMessageProcessor()
self.mind = "" self.mind = ""
self.memory_block = "" self.memory_block = ""
self.relation_info_block = "" self.relation_info_block = ""
self.time_block = "" self.time_block = ""
@@ -59,17 +59,13 @@ class MaiThinking:
self.identity = "" self.identity = ""
self.sender = "" self.sender = ""
self.target = "" self.target = ""
self.thinking_model = LLMRequest( self.thinking_model = LLMRequest(model_set=model_config.model_task_config.replyer, request_type="thinking")
model=global_config.model.replyer_1,
request_type="thinking",
)
async def do_think_before_response(self): async def do_think_before_response(self):
pass pass
async def do_think_after_response(self,reponse:str): async def do_think_after_response(self, reponse: str):
prompt = await global_prompt_manager.format_prompt( prompt = await global_prompt_manager.format_prompt(
"after_response_think_prompt", "after_response_think_prompt",
mind=self.mind, mind=self.mind,
@@ -85,47 +81,44 @@ class MaiThinking:
sender=self.sender, sender=self.sender,
target=self.target, target=self.target,
) )
result, _ = await self.thinking_model.generate_response_async(prompt) result, _ = await self.thinking_model.generate_response_async(prompt)
self.mind = result self.mind = result
logger.info(f"[{self.chat_id}] 思考前想法:{self.mind}") logger.info(f"[{self.chat_id}] 思考前想法:{self.mind}")
# logger.info(f"[{self.chat_id}] 思考前prompt{prompt}") # logger.info(f"[{self.chat_id}] 思考前prompt{prompt}")
logger.info(f"[{self.chat_id}] 思考后想法:{self.mind}") logger.info(f"[{self.chat_id}] 思考后想法:{self.mind}")
msg_recv = await self.build_internal_message_recv(self.mind) msg_recv = await self.build_internal_message_recv(self.mind)
await self.s4u_message_processor.process_message(msg_recv) await self.s4u_message_processor.process_message(msg_recv)
internal_manager.set_internal_state(self.mind) internal_manager.set_internal_state(self.mind)
async def do_think_when_receive_message(self): async def do_think_when_receive_message(self):
pass pass
async def build_internal_message_recv(self,message_text:str): async def build_internal_message_recv(self, message_text: str):
msg_id = f"internal_{time.time()}" msg_id = f"internal_{time.time()}"
message_dict = { message_dict = {
"message_info": { "message_info": {
"message_id": msg_id, "message_id": msg_id,
"time": time.time(), "time": time.time(),
"user_info": { "user_info": {
"user_id": "internal", # 内部用户ID "user_id": "internal", # 内部用户ID
"user_nickname": "内心", # 内部昵称 "user_nickname": "内心", # 内部昵称
"platform": self.platform, # 平台标记为 internal "platform": self.platform, # 平台标记为 internal
# 其他 user_info 字段按需补充 # 其他 user_info 字段按需补充
}, },
"platform": self.platform, # 平台 "platform": self.platform, # 平台
# 其他 message_info 字段按需补充 # 其他 message_info 字段按需补充
}, },
"message_segment": { "message_segment": {
"type": "text", # 消息类型 "type": "text", # 消息类型
"data": message_text, # 消息内容 "data": message_text, # 消息内容
# 其他 segment 字段按需补充 # 其他 segment 字段按需补充
}, },
"raw_message": message_text, # 原始消息内容 "raw_message": message_text, # 原始消息内容
"processed_plain_text": message_text, # 处理后的纯文本 "processed_plain_text": message_text, # 处理后的纯文本
# 下面这些字段可选,根据 MessageRecv 需要 # 下面这些字段可选,根据 MessageRecv 需要
"is_emoji": False, "is_emoji": False,
"has_emoji": False, "has_emoji": False,
@@ -139,45 +132,36 @@ class MaiThinking:
"priority_info": {"message_priority": 10.0}, # 内部消息可设高优先级 "priority_info": {"message_priority": 10.0}, # 内部消息可设高优先级
"interest_value": 1.0, "interest_value": 1.0,
} }
if self.is_group: if self.is_group:
message_dict["message_info"]["group_info"] = { message_dict["message_info"]["group_info"] = {
"platform": self.platform, "platform": self.platform,
"group_id": self.chat_stream.group_info.group_id, "group_id": self.chat_stream.group_info.group_id,
"group_name": self.chat_stream.group_info.group_name, "group_name": self.chat_stream.group_info.group_name,
} }
msg_recv = MessageRecvS4U(message_dict) msg_recv = MessageRecvS4U(message_dict)
msg_recv.chat_info = self.chat_info msg_recv.chat_info = self.chat_info
msg_recv.chat_stream = self.chat_stream msg_recv.chat_stream = self.chat_stream
msg_recv.is_internal = True msg_recv.is_internal = True
return msg_recv return msg_recv
class MaiThinkingManager: class MaiThinkingManager:
def __init__(self): def __init__(self):
self.mai_think_list = [] self.mai_think_list = []
def get_mai_think(self,chat_id): def get_mai_think(self, chat_id):
for mai_think in self.mai_think_list: for mai_think in self.mai_think_list:
if mai_think.chat_id == chat_id: if mai_think.chat_id == chat_id:
return mai_think return mai_think
mai_think = MaiThinking(chat_id) mai_think = MaiThinking(chat_id)
self.mai_think_list.append(mai_think) self.mai_think_list.append(mai_think)
return mai_think return mai_think
mai_thinking_manager = MaiThinkingManager() mai_thinking_manager = MaiThinkingManager()
init_prompt() init_prompt()

View File

@@ -1,19 +1,22 @@
import json import json
import time import time
from json_repair import repair_json
from src.chat.message_receive.message import MessageRecv from src.chat.message_receive.message import MessageRecv
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.common.logger import get_logger from src.common.logger import get_logger
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive
from src.config.config import global_config from src.config.config import global_config, model_config
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.manager.async_task_manager import AsyncTask, async_task_manager from src.manager.async_task_manager import AsyncTask, async_task_manager
from src.plugin_system.apis import send_api from src.plugin_system.apis import send_api
from json_repair import repair_json
from src.mais4u.s4u_config import s4u_config from src.mais4u.s4u_config import s4u_config
logger = get_logger("action") logger = get_logger("action")
HEAD_CODE = { # 使用字典作为默认值但通过Prompt来注册以便外部重载
DEFAULT_HEAD_CODE = {
"看向上方": "(0,0.5,0)", "看向上方": "(0,0.5,0)",
"看向下方": "(0,-0.5,0)", "看向下方": "(0,-0.5,0)",
"看向左边": "(-1,0,0)", "看向左边": "(-1,0,0)",
@@ -24,7 +27,7 @@ HEAD_CODE = {
"看向正前方": "(0,0,0)", "看向正前方": "(0,0,0)",
} }
BODY_CODE = { DEFAULT_BODY_CODE = {
"双手背后向前弯腰": "010_0070", "双手背后向前弯腰": "010_0070",
"歪头双手合十": "010_0100", "歪头双手合十": "010_0100",
"标准文静站立": "010_0101", "标准文静站立": "010_0101",
@@ -32,7 +35,7 @@ BODY_CODE = {
"帅气的姿势": "010_0190", "帅气的姿势": "010_0190",
"另一个帅气的姿势": "010_0191", "另一个帅气的姿势": "010_0191",
"手掌朝前可爱": "010_0210", "手掌朝前可爱": "010_0210",
"平静,双手后放":"平静,双手后放", "平静,双手后放": "平静,双手后放",
"思考": "思考", "思考": "思考",
"优雅,左手放在腰上": "优雅,左手放在腰上", "优雅,左手放在腰上": "优雅,左手放在腰上",
"一般": "一般", "一般": "一般",
@@ -40,7 +43,44 @@ BODY_CODE = {
} }
def get_head_code() -> dict:
"""获取头部动作代码字典"""
head_code_str = global_prompt_manager.get_prompt("head_code_prompt")
if not head_code_str:
return DEFAULT_HEAD_CODE
try:
return json.loads(head_code_str)
except Exception as e:
logger.error(f"解析head_code_prompt失败使用默认值: {e}")
return DEFAULT_HEAD_CODE
def get_body_code() -> dict:
"""获取身体动作代码字典"""
body_code_str = global_prompt_manager.get_prompt("body_code_prompt")
if not body_code_str:
return DEFAULT_BODY_CODE
try:
return json.loads(body_code_str)
except Exception as e:
logger.error(f"解析body_code_prompt失败使用默认值: {e}")
return DEFAULT_BODY_CODE
def init_prompt(): def init_prompt():
# 注册头部动作代码
Prompt(
json.dumps(DEFAULT_HEAD_CODE, ensure_ascii=False, indent=2),
"head_code_prompt",
)
# 注册身体动作代码
Prompt(
json.dumps(DEFAULT_BODY_CODE, ensure_ascii=False, indent=2),
"body_code_prompt",
)
# 注册原有提示模板
Prompt( Prompt(
""" """
{chat_talking_prompt} {chat_talking_prompt}
@@ -94,20 +134,16 @@ class ChatAction:
self.body_action_cooldown: dict[str, int] = {} self.body_action_cooldown: dict[str, int] = {}
print(s4u_config.models.motion) print(s4u_config.models.motion)
print(global_config.model.emotion) print(model_config.model_task_config.emotion)
self.action_model = LLMRequest(
model=global_config.model.emotion,
temperature=0.7,
request_type="motion",
)
self.last_change_time = 0 self.action_model = LLMRequest(model_set=model_config.model_task_config.emotion, request_type="motion")
self.last_change_time: float = 0
async def send_action_update(self): async def send_action_update(self):
"""发送动作更新到前端""" """发送动作更新到前端"""
body_code = BODY_CODE.get(self.body_action, "") body_code = get_body_code().get(self.body_action, "")
await send_api.custom_to_stream( await send_api.custom_to_stream(
message_type="body_action", message_type="body_action",
content=body_code, content=body_code,
@@ -115,13 +151,11 @@ class ChatAction:
storage_message=False, storage_message=False,
show_log=True, show_log=True,
) )
async def update_action_by_message(self, message: MessageRecv): async def update_action_by_message(self, message: MessageRecv):
self.regression_count = 0 self.regression_count = 0
message_time = message.message_info.time message_time: float = message.message_info.time # type: ignore
message_list_before_now = get_raw_msg_by_timestamp_with_chat_inclusive( message_list_before_now = get_raw_msg_by_timestamp_with_chat_inclusive(
chat_id=self.chat_id, chat_id=self.chat_id,
timestamp_start=self.last_change_time, timestamp_start=self.last_change_time,
@@ -147,13 +181,13 @@ class ChatAction:
prompt_personality = global_config.personality.personality_core prompt_personality = global_config.personality.personality_core
indentify_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}" indentify_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}"
try: try:
# 冷却池处理:过滤掉冷却中的动作 # 冷却池处理:过滤掉冷却中的动作
self._update_body_action_cooldown() self._update_body_action_cooldown()
available_actions = [k for k in BODY_CODE.keys() if k not in self.body_action_cooldown] available_actions = [k for k in get_body_code().keys() if k not in self.body_action_cooldown]
all_actions = "\n".join(available_actions) all_actions = "\n".join(available_actions)
prompt = await global_prompt_manager.format_prompt( prompt = await global_prompt_manager.format_prompt(
"change_action_prompt", "change_action_prompt",
chat_talking_prompt=chat_talking_prompt, chat_talking_prompt=chat_talking_prompt,
@@ -163,19 +197,18 @@ class ChatAction:
) )
logger.info(f"prompt: {prompt}") logger.info(f"prompt: {prompt}")
response, (reasoning_content, model_name) = await self.action_model.generate_response_async(prompt=prompt) response, (reasoning_content, _, _) = await self.action_model.generate_response_async(
prompt=prompt, temperature=0.7
)
logger.info(f"response: {response}") logger.info(f"response: {response}")
logger.info(f"reasoning_content: {reasoning_content}") logger.info(f"reasoning_content: {reasoning_content}")
action_data = json.loads(repair_json(response)) if action_data := json.loads(repair_json(response)):
if action_data:
# 记录原动作,切换后进入冷却 # 记录原动作,切换后进入冷却
prev_body_action = self.body_action prev_body_action = self.body_action
new_body_action = action_data.get("body_action", self.body_action) new_body_action = action_data.get("body_action", self.body_action)
if new_body_action != prev_body_action: if new_body_action != prev_body_action and prev_body_action:
if prev_body_action: self.body_action_cooldown[prev_body_action] = 3
self.body_action_cooldown[prev_body_action] = 3
self.body_action = new_body_action self.body_action = new_body_action
self.head_action = action_data.get("head_action", self.head_action) self.head_action = action_data.get("head_action", self.head_action)
# 发送动作更新 # 发送动作更新
@@ -213,10 +246,9 @@ class ChatAction:
prompt_personality = global_config.personality.personality_core prompt_personality = global_config.personality.personality_core
indentify_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}" indentify_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}"
try: try:
# 冷却池处理:过滤掉冷却中的动作 # 冷却池处理:过滤掉冷却中的动作
self._update_body_action_cooldown() self._update_body_action_cooldown()
available_actions = [k for k in BODY_CODE.keys() if k not in self.body_action_cooldown] available_actions = [k for k in get_body_code().keys() if k not in self.body_action_cooldown]
all_actions = "\n".join(available_actions) all_actions = "\n".join(available_actions)
prompt = await global_prompt_manager.format_prompt( prompt = await global_prompt_manager.format_prompt(
@@ -228,17 +260,17 @@ class ChatAction:
) )
logger.info(f"prompt: {prompt}") logger.info(f"prompt: {prompt}")
response, (reasoning_content, model_name) = await self.action_model.generate_response_async(prompt=prompt) response, (reasoning_content, _, _) = await self.action_model.generate_response_async(
prompt=prompt, temperature=0.7
)
logger.info(f"response: {response}") logger.info(f"response: {response}")
logger.info(f"reasoning_content: {reasoning_content}") logger.info(f"reasoning_content: {reasoning_content}")
action_data = json.loads(repair_json(response)) if action_data := json.loads(repair_json(response)):
if action_data:
prev_body_action = self.body_action prev_body_action = self.body_action
new_body_action = action_data.get("body_action", self.body_action) new_body_action = action_data.get("body_action", self.body_action)
if new_body_action != prev_body_action: if new_body_action != prev_body_action and prev_body_action:
if prev_body_action: self.body_action_cooldown[prev_body_action] = 6
self.body_action_cooldown[prev_body_action] = 6
self.body_action = new_body_action self.body_action = new_body_action
# 发送动作更新 # 发送动作更新
await self.send_action_update() await self.send_action_update()
@@ -306,9 +338,6 @@ class ActionManager:
return new_action_state return new_action_state
init_prompt() init_prompt()
action_manager = ActionManager() action_manager = ActionManager()

View File

@@ -16,10 +16,9 @@ import json
from .s4u_mood_manager import mood_manager from .s4u_mood_manager import mood_manager
from src.person_info.relationship_builder_manager import relationship_builder_manager from src.person_info.relationship_builder_manager import relationship_builder_manager
from src.mais4u.s4u_config import s4u_config from src.mais4u.s4u_config import s4u_config
from src.person_info.person_info import PersonInfoManager from src.person_info.person_info import get_person_id
from .super_chat_manager import get_super_chat_manager from .super_chat_manager import get_super_chat_manager
from .yes_or_no import yes_or_no_head from .yes_or_no import yes_or_no_head
from src.mais4u.constant_s4u import ENABLE_S4U
logger = get_logger("S4U_chat") logger = get_logger("S4U_chat")
@@ -137,7 +136,7 @@ class MessageSenderContainer:
await self.storage.store_message(bot_message, self.chat_stream) await self.storage.store_message(bot_message, self.chat_stream)
except Exception as e: except Exception as e:
logger.error(f"[{self.chat_stream.get_stream_name()}] 消息发送或存储时出现错误: {e}", exc_info=True) logger.error(f"[消息流: {self.chat_stream.stream_id}] 消息发送或存储时出现错误: {e}", exc_info=True)
finally: finally:
# CRUCIAL: Always call task_done() for any item that was successfully retrieved. # CRUCIAL: Always call task_done() for any item that was successfully retrieved.
@@ -166,7 +165,7 @@ class S4UChatManager:
return self.s4u_chats[chat_stream.stream_id] return self.s4u_chats[chat_stream.stream_id]
if not ENABLE_S4U: if not s4u_config.enable_s4u:
s4u_chat_manager = None s4u_chat_manager = None
else: else:
s4u_chat_manager = S4UChatManager() s4u_chat_manager = S4UChatManager()
@@ -262,7 +261,7 @@ class S4UChat:
"""根据VIP状态和中断逻辑将消息放入相应队列。""" """根据VIP状态和中断逻辑将消息放入相应队列。"""
user_id = message.message_info.user_info.user_id user_id = message.message_info.user_info.user_id
platform = message.message_info.platform platform = message.message_info.platform
person_id = PersonInfoManager.get_person_id(platform, user_id) person_id = get_person_id(platform, user_id)
try: try:
is_gift = message.is_gift is_gift = message.is_gift

Some files were not shown because too many files have changed in this diff Show More