7
.gitignore
vendored
7
.gitignore
vendored
@@ -7,6 +7,7 @@ logs/
|
|||||||
out/
|
out/
|
||||||
tool_call_benchmark.py
|
tool_call_benchmark.py
|
||||||
run_maibot_core.bat
|
run_maibot_core.bat
|
||||||
|
run_voice.bat
|
||||||
run_napcat_adapter.bat
|
run_napcat_adapter.bat
|
||||||
run_ad.bat
|
run_ad.bat
|
||||||
s4u.s4u
|
s4u.s4u
|
||||||
@@ -40,7 +41,10 @@ config/bot_config.toml
|
|||||||
config/bot_config.toml.bak
|
config/bot_config.toml.bak
|
||||||
config/lpmm_config.toml
|
config/lpmm_config.toml
|
||||||
config/lpmm_config.toml.bak
|
config/lpmm_config.toml.bak
|
||||||
|
src/mais4u/config/s4u_config.toml
|
||||||
|
src/mais4u/config/old
|
||||||
template/compare/bot_config_template.toml
|
template/compare/bot_config_template.toml
|
||||||
|
template/compare/model_config_template.toml
|
||||||
(测试版)麦麦生成人格.bat
|
(测试版)麦麦生成人格.bat
|
||||||
(临时版)麦麦开始学习.bat
|
(临时版)麦麦开始学习.bat
|
||||||
src/plugins/utils/statistic.py
|
src/plugins/utils/statistic.py
|
||||||
@@ -321,4 +325,5 @@ run_pet.bat
|
|||||||
|
|
||||||
config.toml
|
config.toml
|
||||||
|
|
||||||
interested_rates.txt
|
interested_rates.txt
|
||||||
|
MaiBot.code-workspace
|
||||||
|
|||||||
@@ -25,8 +25,8 @@
|
|||||||
|
|
||||||
**🍔MaiCore 是一个基于大语言模型的可交互智能体**
|
**🍔MaiCore 是一个基于大语言模型的可交互智能体**
|
||||||
|
|
||||||
- 💭 **智能对话系统**:基于 LLM 的自然语言交互,支持normal和focus统一化处理。
|
- 💭 **智能对话系统**:基于 LLM 的自然语言交互,支持聊天时机控制。
|
||||||
- 🔌 **强大插件系统**:全面重构的插件架构,支持完整的管理API和权限控制。
|
- 🔌 **强大插件系统**:全面重构的插件架构,提供更多API。
|
||||||
- 🤔 **实时思维系统**:模拟人类思考过程。
|
- 🤔 **实时思维系统**:模拟人类思考过程。
|
||||||
- 🧠 **表达学习功能**:学习群友的说话风格和表达方式
|
- 🧠 **表达学习功能**:学习群友的说话风格和表达方式
|
||||||
- 💝 **情感表达系统**:情绪系统和表情包系统。
|
- 💝 **情感表达系统**:情绪系统和表情包系统。
|
||||||
@@ -46,7 +46,7 @@
|
|||||||
|
|
||||||
## 🔥 更新和安装
|
## 🔥 更新和安装
|
||||||
|
|
||||||
**最新版本: v0.9.1** ([更新日志](changelogs/changelog.md))
|
**最新版本: v0.10.0** ([更新日志](changelogs/changelog.md))
|
||||||
|
|
||||||
可前往 [Release](https://github.com/MaiM-with-u/MaiBot/releases/) 页面下载最新版本
|
可前往 [Release](https://github.com/MaiM-with-u/MaiBot/releases/) 页面下载最新版本
|
||||||
可前往 [启动器发布页面](https://github.com/MaiM-with-u/mailauncher/releases/)下载最新启动器
|
可前往 [启动器发布页面](https://github.com/MaiM-with-u/mailauncher/releases/)下载最新启动器
|
||||||
@@ -56,7 +56,6 @@
|
|||||||
- `classical`: 旧版本(停止维护)
|
- `classical`: 旧版本(停止维护)
|
||||||
|
|
||||||
### 最新版本部署教程
|
### 最新版本部署教程
|
||||||
- [从0.6/0.7升级须知](https://docs.mai-mai.org/faq/maibot/update_to_07.html)
|
|
||||||
- [🚀 最新版本部署教程](https://docs.mai-mai.org/manual/deployment/mmc_deploy_windows.html) - 基于 MaiCore 的新版本部署方式(与旧版本不兼容)
|
- [🚀 最新版本部署教程](https://docs.mai-mai.org/manual/deployment/mmc_deploy_windows.html) - 基于 MaiCore 的新版本部署方式(与旧版本不兼容)
|
||||||
|
|
||||||
> [!WARNING]
|
> [!WARNING]
|
||||||
@@ -64,7 +63,6 @@
|
|||||||
> - 项目处于活跃开发阶段,功能和 API 可能随时调整。
|
> - 项目处于活跃开发阶段,功能和 API 可能随时调整。
|
||||||
> - 文档未完善,有问题可以提交 Issue 或者 Discussion。
|
> - 文档未完善,有问题可以提交 Issue 或者 Discussion。
|
||||||
> - QQ 机器人存在被限制风险,请自行了解,谨慎使用。
|
> - QQ 机器人存在被限制风险,请自行了解,谨慎使用。
|
||||||
> - 由于持续迭代,可能存在一些已知或未知的 bug。
|
|
||||||
> - 由于程序处于开发中,可能消耗较多 token。
|
> - 由于程序处于开发中,可能消耗较多 token。
|
||||||
|
|
||||||
## 💬 讨论
|
## 💬 讨论
|
||||||
|
|||||||
41
bot.py
41
bot.py
@@ -20,11 +20,13 @@ from rich.traceback import install
|
|||||||
|
|
||||||
# 最早期初始化日志系统,确保所有后续模块都使用正确的日志格式
|
# 最早期初始化日志系统,确保所有后续模块都使用正确的日志格式
|
||||||
from src.common.logger import initialize_logging, get_logger, shutdown_logging
|
from src.common.logger import initialize_logging, get_logger, shutdown_logging
|
||||||
from src.main import MainSystem
|
|
||||||
from src.manager.async_task_manager import async_task_manager
|
|
||||||
|
|
||||||
initialize_logging()
|
initialize_logging()
|
||||||
|
|
||||||
|
from src.main import MainSystem #noqa
|
||||||
|
from src.manager.async_task_manager import async_task_manager #noqa
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger("main")
|
logger = get_logger("main")
|
||||||
|
|
||||||
|
|
||||||
@@ -74,36 +76,6 @@ def easter_egg():
|
|||||||
print(rainbow_text)
|
print(rainbow_text)
|
||||||
|
|
||||||
|
|
||||||
def scan_provider(env_config: dict):
|
|
||||||
provider = {}
|
|
||||||
|
|
||||||
# 利用未初始化 env 时获取的 env_mask 来对新的环境变量集去重
|
|
||||||
# 避免 GPG_KEY 这样的变量干扰检查
|
|
||||||
env_config = dict(filter(lambda item: item[0] not in env_mask, env_config.items()))
|
|
||||||
|
|
||||||
# 遍历 env_config 的所有键
|
|
||||||
for key in env_config:
|
|
||||||
# 检查键是否符合 {provider}_BASE_URL 或 {provider}_KEY 的格式
|
|
||||||
if key.endswith("_BASE_URL") or key.endswith("_KEY"):
|
|
||||||
# 提取 provider 名称
|
|
||||||
provider_name = key.split("_", 1)[0] # 从左分割一次,取第一部分
|
|
||||||
|
|
||||||
# 初始化 provider 的字典(如果尚未初始化)
|
|
||||||
if provider_name not in provider:
|
|
||||||
provider[provider_name] = {"url": None, "key": None}
|
|
||||||
|
|
||||||
# 根据键的类型填充 url 或 key
|
|
||||||
if key.endswith("_BASE_URL"):
|
|
||||||
provider[provider_name]["url"] = env_config[key]
|
|
||||||
elif key.endswith("_KEY"):
|
|
||||||
provider[provider_name]["key"] = env_config[key]
|
|
||||||
|
|
||||||
# 检查每个 provider 是否同时存在 url 和 key
|
|
||||||
for provider_name, config in provider.items():
|
|
||||||
if config["url"] is None or config["key"] is None:
|
|
||||||
logger.error(f"provider 内容:{config}\nenv_config 内容:{env_config}")
|
|
||||||
raise ValueError(f"请检查 '{provider_name}' 提供商配置是否丢失 BASE_URL 或 KEY 环境变量")
|
|
||||||
|
|
||||||
|
|
||||||
async def graceful_shutdown():
|
async def graceful_shutdown():
|
||||||
try:
|
try:
|
||||||
@@ -229,9 +201,6 @@ def raw_main():
|
|||||||
|
|
||||||
easter_egg()
|
easter_egg()
|
||||||
|
|
||||||
env_config = {key: os.getenv(key) for key in os.environ}
|
|
||||||
scan_provider(env_config)
|
|
||||||
|
|
||||||
# 返回MainSystem实例
|
# 返回MainSystem实例
|
||||||
return MainSystem()
|
return MainSystem()
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,76 @@
|
|||||||
# Changelog
|
# Changelog
|
||||||
|
|
||||||
|
## [0.10.0] - 2025-7-1
|
||||||
|
### 🌟 主要功能更改
|
||||||
|
- 优化的回复生成,现在的回复对上下文把控更加精准
|
||||||
|
- 新的回复逻辑控制,现在合并了normal和focus模式,更加统一
|
||||||
|
- 优化表达方式系统,现在学习和使用更加精准
|
||||||
|
- 新的关系系统,现在的关系构建更精准也更克制
|
||||||
|
- 工具系统重构,现在合并到了插件系统中
|
||||||
|
- 彻底重构了整个LLM Request,现在支持模型轮询和更多灵活的参数
|
||||||
|
- 同时重构了整个模型配置系统,升级需要重新配置llm配置文件
|
||||||
|
- 随着LLM Request的重构,插件系统彻底重构完成。插件系统进入稳定状态,仅增加新的API
|
||||||
|
- 具体相比于之前的更改可以查看[changes.md](./changes.md)
|
||||||
|
|
||||||
|
#### 🔧 工具系统重构
|
||||||
|
- **工具系统整合**: 工具系统现在完全合并到插件系统中,提供统一的扩展能力
|
||||||
|
- **工具启用控制**: 支持配置是否启用特定工具,提供更人性化的直接调用方式
|
||||||
|
- **配置文件读取**: 工具现在支持读取配置文件,增强配置灵活性
|
||||||
|
|
||||||
|
#### 🚀 LLM系统全面重构
|
||||||
|
- **LLM Request重构**: 彻底重构了整个LLM Request系统,现在支持模型轮询和更多灵活的参数
|
||||||
|
- **模型配置升级**: 同时重构了整个模型配置系统,升级需要重新配置llm配置文件
|
||||||
|
- **任务类型支持**: 新增任务类型和能力字段至模型配置,增强模型初始化逻辑
|
||||||
|
- **异常处理增强**: 增强LLMRequest类的异常处理,添加统一的模型异常处理方法
|
||||||
|
|
||||||
|
#### 🔌 插件系统稳定化
|
||||||
|
- **插件系统重构完成**: 随着LLM Request的重构,插件系统彻底重构完成,进入稳定状态
|
||||||
|
- **API扩展**: 仅增加新的API,保持向后兼容性
|
||||||
|
- **插件管理优化**: 让插件管理配置真正有用,提升管理体验
|
||||||
|
|
||||||
|
#### 💾 记忆系统优化
|
||||||
|
- **及时构建**: 记忆系统再优化,现在及时构建,并且不会重复构建
|
||||||
|
- **精确提取**: 记忆提取更精确,提升记忆质量
|
||||||
|
|
||||||
|
#### 🎭 表达方式系统
|
||||||
|
- **表达方式记录**: 记录使用的表达方式,提供更好的学习追踪
|
||||||
|
- **学习优化**: 优化表达方式提取,修复表达学习出错问题
|
||||||
|
- **配置优化**: 优化表达方式配置和逻辑,提升系统稳定性
|
||||||
|
|
||||||
|
#### 🔄 聊天系统统一
|
||||||
|
- **normal和focus合并**: 彻底合并normal和focus,完全基于planner决定target message
|
||||||
|
- **no_reply内置**: 将no_reply功能移动到主循环中,简化系统架构
|
||||||
|
- **回复优化**: 优化reply,填补缺失值,让麦麦可以回复自己的消息
|
||||||
|
- **频率控制API**: 加入聊天频率控制相关API,提供更精细的控制
|
||||||
|
|
||||||
|
#### 日志系统改进
|
||||||
|
- **日志颜色优化**: 修改了log的颜色,更加护眼
|
||||||
|
- **日志清理优化**: 修复了日志清理先等24h的问题,提升系统性能
|
||||||
|
- **计时定位**: 通过计时定位LLM异常延时,提升问题排查效率
|
||||||
|
|
||||||
|
### 🐛 问题修复
|
||||||
|
|
||||||
|
#### 代码质量提升
|
||||||
|
- **lint问题修复**: 修复了lint爆炸的问题,代码更加规范了
|
||||||
|
- **导入优化**: 修复导入爆炸和文档错误,优化代码结构
|
||||||
|
|
||||||
|
#### 系统稳定性
|
||||||
|
- **循环导入**: 修复了import时循环导入的问题
|
||||||
|
- **并行动作**: 修复并行动作炸裂问题,提升并发处理能力
|
||||||
|
- **空响应处理**: 收到空响应时直接抛出异常(raise),避免后续处理出现系统异常
|
||||||
|
|
||||||
|
#### 功能修复
|
||||||
|
- **API问题**: 修复api问题,提升系统可用性
|
||||||
|
- **notice问题**: 为组件方法提供新参数,暂时解决notice问题
|
||||||
|
- **关系构建**: 修复不认识的用户构建关系问题
|
||||||
|
- **流式解析**: 修复流式解析越界问题,避免空choices的SSE帧错误
|
||||||
|
|
||||||
|
#### 配置和兼容性
|
||||||
|
- **默认值**: 添加默认值,提升配置灵活性
|
||||||
|
- **类型问题**: 修复类型问题,提升代码健壮性
|
||||||
|
- **配置加载**: 优化配置加载逻辑,提升系统启动稳定性
|
||||||
|
|
||||||
|
|
||||||
## [0.9.1] - 2025-7-26
|
## [0.9.1] - 2025-7-26
|
||||||
|
|
||||||
### 主要修复和优化
|
### 主要修复和优化
|
||||||
@@ -81,7 +152,7 @@ MaiBot 0.9.0 重磅升级!本版本带来两大核心突破:**全面重构
|
|||||||
#### 问题修复与优化
|
#### 问题修复与优化
|
||||||
|
|
||||||
- 修复normal planner没有超时退出问题,添加回复超时检查
|
- 修复normal planner没有超时退出问题,添加回复超时检查
|
||||||
- 重构no_reply逻辑,不再使用小模型,采用激活度决定
|
- 重构no_action逻辑,不再使用小模型,采用激活度决定
|
||||||
- 修复图片与文字混合兴趣值为0的情况
|
- 修复图片与文字混合兴趣值为0的情况
|
||||||
- 适配无兴趣度消息处理
|
- 适配无兴趣度消息处理
|
||||||
- 优化Docker镜像构建流程,合并AMD64和ARM64构建步骤
|
- 优化Docker镜像构建流程,合并AMD64和ARM64构建步骤
|
||||||
@@ -149,7 +220,7 @@ MMC启动速度加快
|
|||||||
- 移除冗余处理器
|
- 移除冗余处理器
|
||||||
- 精简处理器上下文,减少不必要的处理
|
- 精简处理器上下文,减少不必要的处理
|
||||||
- 后置工具处理器,大大减少token消耗
|
- 后置工具处理器,大大减少token消耗
|
||||||
- **统计系统**: 提供focus统计功能,可查看详细的no_reply统计信息
|
- **统计系统**: 提供focus统计功能,可查看详细的no_action统计信息
|
||||||
|
|
||||||
|
|
||||||
### ⏰ 聊天频率精细控制
|
### ⏰ 聊天频率精细控制
|
||||||
|
|||||||
@@ -25,6 +25,7 @@
|
|||||||
- 这意味着你终于可以动态控制是否继续后续消息的处理了。
|
- 这意味着你终于可以动态控制是否继续后续消息的处理了。
|
||||||
8. 移除了dependency_manager,但是依然保留了`python_dependencies`属性,等待后续重构。
|
8. 移除了dependency_manager,但是依然保留了`python_dependencies`属性,等待后续重构。
|
||||||
- 一并移除了文档有关manager的内容。
|
- 一并移除了文档有关manager的内容。
|
||||||
|
9. 增加了工具的有关api
|
||||||
|
|
||||||
# 插件系统修改
|
# 插件系统修改
|
||||||
1. 现在所有的匹配模式不再是关键字了,而是枚举类。**(可能有遗漏)**
|
1. 现在所有的匹配模式不再是关键字了,而是枚举类。**(可能有遗漏)**
|
||||||
@@ -57,30 +58,12 @@
|
|||||||
15. 实现了组件的局部禁用,也就是针对某一个聊天禁用的功能。
|
15. 实现了组件的局部禁用,也就是针对某一个聊天禁用的功能。
|
||||||
- 通过`disable_specific_chat_action`,`enable_specific_chat_action`,`disable_specific_chat_command`,`enable_specific_chat_command`,`disable_specific_chat_event_handler`,`enable_specific_chat_event_handler`来操作
|
- 通过`disable_specific_chat_action`,`enable_specific_chat_action`,`disable_specific_chat_command`,`enable_specific_chat_command`,`disable_specific_chat_event_handler`,`enable_specific_chat_event_handler`来操作
|
||||||
- 同样不保存到配置文件~
|
- 同样不保存到配置文件~
|
||||||
|
16. 把`BaseTool`一并合并进入了插件系统
|
||||||
|
|
||||||
# 官方插件修改
|
# 官方插件修改
|
||||||
1. `HelloWorld`插件现在有一个样例的`EventHandler`。
|
1. `HelloWorld`插件现在有一个样例的`EventHandler`。
|
||||||
2. 内置插件增加了一个通过`Command`来管理插件的功能。具体是使用`/pm`命令唤起。
|
2. 内置插件增加了一个通过`Command`来管理插件的功能。具体是使用`/pm`命令唤起。(需要自行启用)
|
||||||
|
3. `HelloWorld`插件现在有一个样例的`CompareNumbersTool`。
|
||||||
### TODO
|
|
||||||
把这个看起来就很别扭的config获取方式改一下
|
|
||||||
|
|
||||||
|
|
||||||
# 吐槽
|
|
||||||
```python
|
|
||||||
plugin_path = Path(plugin_file)
|
|
||||||
if plugin_path.parent.name != "plugins":
|
|
||||||
# 插件包格式:parent_dir.plugin
|
|
||||||
module_name = f"plugins.{plugin_path.parent.name}.plugin"
|
|
||||||
else:
|
|
||||||
# 单文件格式:plugins.filename
|
|
||||||
module_name = f"plugins.{plugin_path.stem}"
|
|
||||||
```
|
|
||||||
```python
|
|
||||||
plugin_path = Path(plugin_file)
|
|
||||||
module_name = ".".join(plugin_path.parent.parts)
|
|
||||||
```
|
|
||||||
这两个区别很大的。
|
|
||||||
|
|
||||||
### 执笔BGM
|
### 执笔BGM
|
||||||
塞壬唱片!
|
塞壬唱片!
|
||||||
BIN
docs/image-1.png
Normal file
BIN
docs/image-1.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 21 KiB |
BIN
docs/image.png
Normal file
BIN
docs/image.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 4.9 KiB |
325
docs/model_configuration_guide.md
Normal file
325
docs/model_configuration_guide.md
Normal file
@@ -0,0 +1,325 @@
|
|||||||
|
# 模型配置指南
|
||||||
|
|
||||||
|
本文档将指导您如何配置 `model_config.toml` 文件,该文件用于配置 MaiBot 的各种AI模型和API服务提供商。
|
||||||
|
|
||||||
|
## 配置文件结构
|
||||||
|
|
||||||
|
配置文件主要包含以下几个部分:
|
||||||
|
- 版本信息
|
||||||
|
- API服务提供商配置
|
||||||
|
- 模型配置
|
||||||
|
- 模型任务配置
|
||||||
|
|
||||||
|
## 1. 版本信息
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[inner]
|
||||||
|
version = "1.1.1"
|
||||||
|
```
|
||||||
|
|
||||||
|
用于标识配置文件的版本,遵循语义化版本规则。
|
||||||
|
|
||||||
|
## 2. API服务提供商配置
|
||||||
|
|
||||||
|
### 2.1 基本配置
|
||||||
|
|
||||||
|
使用 `[[api_providers]]` 数组配置多个API服务提供商:
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[[api_providers]]
|
||||||
|
name = "DeepSeek" # 服务商名称(自定义)
|
||||||
|
base_url = "https://api.deepseek.com/v1" # API服务的基础URL
|
||||||
|
api_key = "your-api-key-here" # API密钥
|
||||||
|
client_type = "openai" # 客户端类型
|
||||||
|
max_retry = 2 # 最大重试次数
|
||||||
|
timeout = 30 # 超时时间(秒)
|
||||||
|
retry_interval = 10 # 重试间隔(秒)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2.2 配置参数说明
|
||||||
|
|
||||||
|
| 参数 | 必填 | 说明 | 默认值 |
|
||||||
|
|------|------|------|--------|
|
||||||
|
| `name` | ✅ | 服务商名称,需要在模型配置中引用 | - |
|
||||||
|
| `base_url` | ✅ | API服务的基础URL | - |
|
||||||
|
| `api_key` | ✅ | API密钥,请替换为实际密钥 | - |
|
||||||
|
| `client_type` | ❌ | 客户端类型:`openai`(OpenAI格式)或 `gemini`(Gemini格式,目前支持尚不完善) | `openai` |
|
||||||
|
| `max_retry` | ❌ | API调用失败时的最大重试次数 | 2 |
|
||||||
|
| `timeout` | ❌ | API请求超时时间(秒) | 30 |
|
||||||
|
| `retry_interval` | ❌ | 重试间隔时间(秒) | 10 |
|
||||||
|
|
||||||
|
**请注意,对于`client_type`为`gemini`的服务商,`base_url`字段无效。**
|
||||||
|
### 2.3 支持的服务商示例
|
||||||
|
|
||||||
|
#### DeepSeek
|
||||||
|
```toml
|
||||||
|
[[api_providers]]
|
||||||
|
name = "DeepSeek"
|
||||||
|
base_url = "https://api.deepseek.com/v1"
|
||||||
|
api_key = "your-deepseek-api-key"
|
||||||
|
client_type = "openai"
|
||||||
|
```
|
||||||
|
|
||||||
|
#### SiliconFlow
|
||||||
|
```toml
|
||||||
|
[[api_providers]]
|
||||||
|
name = "SiliconFlow"
|
||||||
|
base_url = "https://api.siliconflow.cn/v1"
|
||||||
|
api_key = "your-siliconflow-api-key"
|
||||||
|
client_type = "openai"
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Google Gemini
|
||||||
|
```toml
|
||||||
|
[[api_providers]]
|
||||||
|
name = "Google"
|
||||||
|
base_url = "https://api.google.com/v1" # 注意:client_type为gemini时,此字段无效
|
||||||
|
api_key = "your-google-api-key"
|
||||||
|
client_type = "gemini" # 注意:Gemini需要使用特殊客户端
|
||||||
|
```
|
||||||
|
|
||||||
|
## 3. 模型配置
|
||||||
|
|
||||||
|
### 3.1 基本模型配置
|
||||||
|
|
||||||
|
使用 `[[models]]` 数组配置多个模型:
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[[models]]
|
||||||
|
model_identifier = "deepseek-chat" # 模型在API服务商中的标识符
|
||||||
|
name = "deepseek-v3" # 自定义模型名称
|
||||||
|
api_provider = "DeepSeek" # 引用的API服务商名称
|
||||||
|
price_in = 2.0 # 输入价格(元/M token)
|
||||||
|
price_out = 8.0 # 输出价格(元/M token)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3.2 高级模型配置
|
||||||
|
|
||||||
|
#### 强制流式输出
|
||||||
|
对于不支持非流式输出的模型:
|
||||||
|
```toml
|
||||||
|
[[models]]
|
||||||
|
model_identifier = "some-model"
|
||||||
|
name = "custom-name"
|
||||||
|
api_provider = "Provider"
|
||||||
|
force_stream_mode = true # 启用强制流式输出
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 额外参数配置`extra_params`
|
||||||
|
```toml
|
||||||
|
[[models]]
|
||||||
|
model_identifier = "Qwen/Qwen3-8B"
|
||||||
|
name = "qwen3-8b"
|
||||||
|
api_provider = "SiliconFlow"
|
||||||
|
[models.extra_params]
|
||||||
|
enable_thinking = false # 禁用思考
|
||||||
|
```
|
||||||
|
这里的 `extra_params` 可以包含任何API服务商支持的额外参数配置,**配置时应参考相应的API文档**。
|
||||||
|
|
||||||
|
比如上面就是参考SiliconFlow的文档配置的`Qwen3`禁用思考参数。
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|
以豆包文档为另一个例子
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|
得到豆包`"doubao-seed-1-6-250615"`的禁用思考配置方法为
|
||||||
|
```toml
|
||||||
|
[[models]]
|
||||||
|
# 你的模型
|
||||||
|
[models.extra_params]
|
||||||
|
thinking = {type = "disabled"} # 禁用思考
|
||||||
|
```
|
||||||
|
请注意,`extra_params` 的配置应该构成一个合法的TOML字典结构,具体内容取决于API服务商的要求。
|
||||||
|
|
||||||
|
**请注意,对于`client_type`为`gemini`的模型,此字段无效。**
|
||||||
|
### 3.3 配置参数说明
|
||||||
|
|
||||||
|
| 参数 | 必填 | 说明 |
|
||||||
|
|------|------|------|
|
||||||
|
| `model_identifier` | ✅ | API服务商提供的模型标识符 |
|
||||||
|
| `name` | ✅ | 自定义模型名称,用于在任务配置中引用 |
|
||||||
|
| `api_provider` | ✅ | 对应的API服务商名称 |
|
||||||
|
| `price_in` | ❌ | 输入价格(元/M token),用于成本统计 |
|
||||||
|
| `price_out` | ❌ | 输出价格(元/M token),用于成本统计 |
|
||||||
|
| `force_stream_mode` | ❌ | 是否强制使用流式输出 |
|
||||||
|
| `extra_params` | ❌ | 额外的模型参数配置 |
|
||||||
|
|
||||||
|
## 4. 模型任务配置
|
||||||
|
|
||||||
|
### utils - 工具模型
|
||||||
|
用于表情包模块、取名模块、关系模块等核心功能:
|
||||||
|
```toml
|
||||||
|
[model_task_config.utils]
|
||||||
|
model_list = ["siliconflow-deepseek-v3"]
|
||||||
|
temperature = 0.2
|
||||||
|
max_tokens = 800
|
||||||
|
```
|
||||||
|
|
||||||
|
### utils_small - 小型工具模型
|
||||||
|
用于高频率调用的场景,建议使用速度快的小模型:
|
||||||
|
```toml
|
||||||
|
[model_task_config.utils_small]
|
||||||
|
model_list = ["qwen3-8b"]
|
||||||
|
temperature = 0.7
|
||||||
|
max_tokens = 800
|
||||||
|
```
|
||||||
|
|
||||||
|
### replyer - 主要回复模型
|
||||||
|
首要回复模型,也用于表达器和表达方式学习:
|
||||||
|
```toml
|
||||||
|
[model_task_config.replyer]
|
||||||
|
model_list = ["siliconflow-deepseek-v3"]
|
||||||
|
temperature = 0.2
|
||||||
|
max_tokens = 800
|
||||||
|
```
|
||||||
|
|
||||||
|
### planner - 决策模型
|
||||||
|
负责决定MaiBot该做什么:
|
||||||
|
```toml
|
||||||
|
[model_task_config.planner]
|
||||||
|
model_list = ["siliconflow-deepseek-v3"]
|
||||||
|
temperature = 0.3
|
||||||
|
max_tokens = 800
|
||||||
|
```
|
||||||
|
|
||||||
|
### emotion - 情绪模型
|
||||||
|
负责MaiBot的情绪变化:
|
||||||
|
```toml
|
||||||
|
[model_task_config.emotion]
|
||||||
|
model_list = ["siliconflow-deepseek-v3"]
|
||||||
|
temperature = 0.3
|
||||||
|
max_tokens = 800
|
||||||
|
```
|
||||||
|
|
||||||
|
### memory - 记忆模型
|
||||||
|
```toml
|
||||||
|
[model_task_config.memory]
|
||||||
|
model_list = ["qwen3-30b"]
|
||||||
|
temperature = 0.7
|
||||||
|
max_tokens = 800
|
||||||
|
```
|
||||||
|
|
||||||
|
### vlm - 视觉语言模型
|
||||||
|
用于图像识别:
|
||||||
|
```toml
|
||||||
|
[model_task_config.vlm]
|
||||||
|
model_list = ["qwen2.5-vl-72b"]
|
||||||
|
max_tokens = 800
|
||||||
|
```
|
||||||
|
|
||||||
|
### voice - 语音识别模型
|
||||||
|
```toml
|
||||||
|
[model_task_config.voice]
|
||||||
|
model_list = ["sensevoice-small"]
|
||||||
|
```
|
||||||
|
|
||||||
|
### embedding - 嵌入模型
|
||||||
|
```toml
|
||||||
|
[model_task_config.embedding]
|
||||||
|
model_list = ["bge-m3"]
|
||||||
|
```
|
||||||
|
|
||||||
|
### tool_use - 工具调用模型
|
||||||
|
需要使用支持工具调用的模型:
|
||||||
|
```toml
|
||||||
|
[model_task_config.tool_use]
|
||||||
|
model_list = ["qwen3-14b"]
|
||||||
|
temperature = 0.7
|
||||||
|
max_tokens = 800
|
||||||
|
```
|
||||||
|
|
||||||
|
### lpmm_entity_extract - 实体提取模型
|
||||||
|
```toml
|
||||||
|
[model_task_config.lpmm_entity_extract]
|
||||||
|
model_list = ["siliconflow-deepseek-v3"]
|
||||||
|
temperature = 0.2
|
||||||
|
max_tokens = 800
|
||||||
|
```
|
||||||
|
|
||||||
|
### lpmm_rdf_build - RDF构建模型
|
||||||
|
```toml
|
||||||
|
[model_task_config.lpmm_rdf_build]
|
||||||
|
model_list = ["siliconflow-deepseek-v3"]
|
||||||
|
temperature = 0.2
|
||||||
|
max_tokens = 800
|
||||||
|
```
|
||||||
|
|
||||||
|
### lpmm_qa - 问答模型
|
||||||
|
```toml
|
||||||
|
[model_task_config.lpmm_qa]
|
||||||
|
model_list = ["deepseek-r1-distill-qwen-32b"]
|
||||||
|
temperature = 0.7
|
||||||
|
max_tokens = 800
|
||||||
|
```
|
||||||
|
|
||||||
|
## 5. 配置建议
|
||||||
|
|
||||||
|
### 5.1 Temperature 参数选择
|
||||||
|
|
||||||
|
| 任务类型 | 推荐温度 | 说明 |
|
||||||
|
|----------|----------|------|
|
||||||
|
| 精确任务(工具调用、实体提取) | 0.1-0.3 | 需要准确性和一致性 |
|
||||||
|
| 创意任务(对话、记忆) | 0.5-0.8 | 需要多样性和创造性 |
|
||||||
|
| 平衡任务(决策、情绪) | 0.3-0.5 | 平衡准确性和灵活性 |
|
||||||
|
|
||||||
|
### 5.2 模型选择建议
|
||||||
|
|
||||||
|
| 任务类型 | 推荐模型类型 | 示例 |
|
||||||
|
|----------|--------------|------|
|
||||||
|
| 高精度任务 | 大模型 | DeepSeek-V3, GPT-4 |
|
||||||
|
| 高频率任务 | 小模型 | Qwen3-8B |
|
||||||
|
| 多模态任务 | 专用模型 | Qwen2.5-VL, SenseVoice |
|
||||||
|
| 工具调用 | 支持Function Call的模型 | Qwen3-14B |
|
||||||
|
|
||||||
|
### 5.3 成本优化
|
||||||
|
|
||||||
|
1. **分层使用**:核心功能使用高质量模型,辅助功能使用经济模型
|
||||||
|
2. **合理配置max_tokens**:根据实际需求设置,避免浪费
|
||||||
|
3. **选择免费模型**:对于测试环境,优先使用price为0的模型
|
||||||
|
|
||||||
|
## 6. 配置验证
|
||||||
|
|
||||||
|
### 6.1 必要检查项
|
||||||
|
|
||||||
|
1. ✅ API密钥是否正确配置
|
||||||
|
2. ✅ 模型标识符是否与API服务商提供的一致
|
||||||
|
3. ✅ 任务配置中引用的模型名称是否在models中定义
|
||||||
|
4. ✅ 多模态任务是否配置了对应的专用模型
|
||||||
|
|
||||||
|
### 6.2 测试配置
|
||||||
|
|
||||||
|
建议在正式使用前:
|
||||||
|
1. 使用少量测试数据验证配置
|
||||||
|
2. 检查API调用是否正常
|
||||||
|
3. 确认成本统计功能正常工作
|
||||||
|
|
||||||
|
## 7. 故障排除
|
||||||
|
|
||||||
|
### 7.1 常见问题
|
||||||
|
|
||||||
|
**问题1**: API调用失败
|
||||||
|
- 检查API密钥是否正确
|
||||||
|
- 确认base_url是否可访问
|
||||||
|
- 检查模型标识符是否正确
|
||||||
|
|
||||||
|
**问题2**: 模型未找到
|
||||||
|
- 确认模型名称在任务配置和模型定义中一致
|
||||||
|
- 检查api_provider名称是否匹配
|
||||||
|
|
||||||
|
**问题3**: 响应异常
|
||||||
|
- 检查温度参数是否合理(0-1之间)
|
||||||
|
- 确认max_tokens设置是否合适
|
||||||
|
- 验证模型是否支持所需功能
|
||||||
|
|
||||||
|
### 7.2 日志查看
|
||||||
|
|
||||||
|
查看 `logs/` 目录下的日志文件,寻找相关错误信息。
|
||||||
|
|
||||||
|
## 8. 更新和维护
|
||||||
|
|
||||||
|
1. **定期更新**: 关注API服务商的模型更新,及时调整配置
|
||||||
|
2. **性能监控**: 监控模型调用的成本和性能
|
||||||
|
3. **备份配置**: 在修改前备份当前配置文件
|
||||||
|
|
||||||
@@ -22,7 +22,6 @@ class ExampleAction(BaseAction):
|
|||||||
action_name = "example_action" # 动作的唯一标识符
|
action_name = "example_action" # 动作的唯一标识符
|
||||||
action_description = "这是一个示例动作" # 动作描述
|
action_description = "这是一个示例动作" # 动作描述
|
||||||
activation_type = ActionActivationType.ALWAYS # 这里以 ALWAYS 为例
|
activation_type = ActionActivationType.ALWAYS # 这里以 ALWAYS 为例
|
||||||
mode_enable = ChatMode.ALL # 一般取ALL,表示在所有聊天模式下都可用
|
|
||||||
associated_types = ["text", "emoji", ...] # 关联类型
|
associated_types = ["text", "emoji", ...] # 关联类型
|
||||||
parallel_action = False # 是否允许与其他Action并行执行
|
parallel_action = False # 是否允许与其他Action并行执行
|
||||||
action_parameters = {"param1": "参数1的说明", "param2": "参数2的说明", ...}
|
action_parameters = {"param1": "参数1的说明", "param2": "参数2的说明", ...}
|
||||||
|
|||||||
194
docs/plugins/api/component-manage-api.md
Normal file
194
docs/plugins/api/component-manage-api.md
Normal file
@@ -0,0 +1,194 @@
|
|||||||
|
# 组件管理API
|
||||||
|
|
||||||
|
组件管理API模块提供了对插件组件的查询和管理功能,使得插件能够获取和使用组件相关的信息。
|
||||||
|
|
||||||
|
## 导入方式
|
||||||
|
```python
|
||||||
|
from src.plugin_system.apis import component_manage_api
|
||||||
|
# 或者
|
||||||
|
from src.plugin_system import component_manage_api
|
||||||
|
```
|
||||||
|
|
||||||
|
## 功能概述
|
||||||
|
|
||||||
|
组件管理API主要提供以下功能:
|
||||||
|
- **插件信息查询** - 获取所有插件或指定插件的信息。
|
||||||
|
- **组件查询** - 按名称或类型查询组件信息。
|
||||||
|
- **组件管理** - 启用或禁用组件,支持全局和局部操作。
|
||||||
|
|
||||||
|
## 主要功能
|
||||||
|
|
||||||
|
### 1. 获取所有插件信息
|
||||||
|
```python
|
||||||
|
def get_all_plugin_info() -> Dict[str, PluginInfo]:
|
||||||
|
```
|
||||||
|
获取所有插件的信息。
|
||||||
|
|
||||||
|
**Returns:**
|
||||||
|
- `Dict[str, PluginInfo]` - 包含所有插件信息的字典,键为插件名称,值为 `PluginInfo` 对象。
|
||||||
|
|
||||||
|
### 2. 获取指定插件信息
|
||||||
|
```python
|
||||||
|
def get_plugin_info(plugin_name: str) -> Optional[PluginInfo]:
|
||||||
|
```
|
||||||
|
获取指定插件的信息。
|
||||||
|
|
||||||
|
**Args:**
|
||||||
|
- `plugin_name` (str): 插件名称。
|
||||||
|
|
||||||
|
**Returns:**
|
||||||
|
- `Optional[PluginInfo]`: 插件信息对象,如果插件不存在则返回 `None`。
|
||||||
|
|
||||||
|
### 3. 获取指定组件信息
|
||||||
|
```python
|
||||||
|
def get_component_info(component_name: str, component_type: ComponentType) -> Optional[Union[CommandInfo, ActionInfo, EventHandlerInfo]]:
|
||||||
|
```
|
||||||
|
获取指定组件的信息。
|
||||||
|
|
||||||
|
**Args:**
|
||||||
|
- `component_name` (str): 组件名称。
|
||||||
|
- `component_type` (ComponentType): 组件类型。
|
||||||
|
|
||||||
|
**Returns:**
|
||||||
|
- `Optional[Union[CommandInfo, ActionInfo, EventHandlerInfo]]`: 组件信息对象,如果组件不存在则返回 `None`。
|
||||||
|
|
||||||
|
### 4. 获取指定类型的所有组件信息
|
||||||
|
```python
|
||||||
|
def get_components_info_by_type(component_type: ComponentType) -> Dict[str, Union[CommandInfo, ActionInfo, EventHandlerInfo]]:
|
||||||
|
```
|
||||||
|
获取指定类型的所有组件信息。
|
||||||
|
|
||||||
|
**Args:**
|
||||||
|
- `component_type` (ComponentType): 组件类型。
|
||||||
|
|
||||||
|
**Returns:**
|
||||||
|
- `Dict[str, Union[CommandInfo, ActionInfo, EventHandlerInfo]]`: 包含指定类型组件信息的字典,键为组件名称,值为对应的组件信息对象。
|
||||||
|
|
||||||
|
### 5. 获取指定类型的所有启用的组件信息
|
||||||
|
```python
|
||||||
|
def get_enabled_components_info_by_type(component_type: ComponentType) -> Dict[str, Union[CommandInfo, ActionInfo, EventHandlerInfo]]:
|
||||||
|
```
|
||||||
|
获取指定类型的所有启用的组件信息。
|
||||||
|
|
||||||
|
**Args:**
|
||||||
|
- `component_type` (ComponentType): 组件类型。
|
||||||
|
|
||||||
|
**Returns:**
|
||||||
|
- `Dict[str, Union[CommandInfo, ActionInfo, EventHandlerInfo]]`: 包含指定类型启用组件信息的字典,键为组件名称,值为对应的组件信息对象。
|
||||||
|
|
||||||
|
### 6. 获取指定 Action 的注册信息
|
||||||
|
```python
|
||||||
|
def get_registered_action_info(action_name: str) -> Optional[ActionInfo]:
|
||||||
|
```
|
||||||
|
获取指定 Action 的注册信息。
|
||||||
|
|
||||||
|
**Args:**
|
||||||
|
- `action_name` (str): Action 名称。
|
||||||
|
|
||||||
|
**Returns:**
|
||||||
|
- `Optional[ActionInfo]` - Action 信息对象,如果 Action 不存在则返回 `None`。
|
||||||
|
|
||||||
|
### 7. 获取指定 Command 的注册信息
|
||||||
|
```python
|
||||||
|
def get_registered_command_info(command_name: str) -> Optional[CommandInfo]:
|
||||||
|
```
|
||||||
|
获取指定 Command 的注册信息。
|
||||||
|
|
||||||
|
**Args:**
|
||||||
|
- `command_name` (str): Command 名称。
|
||||||
|
|
||||||
|
**Returns:**
|
||||||
|
- `Optional[CommandInfo]` - Command 信息对象,如果 Command 不存在则返回 `None`。
|
||||||
|
|
||||||
|
### 8. 获取指定 Tool 的注册信息
|
||||||
|
```python
|
||||||
|
def get_registered_tool_info(tool_name: str) -> Optional[ToolInfo]:
|
||||||
|
```
|
||||||
|
获取指定 Tool 的注册信息。
|
||||||
|
|
||||||
|
**Args:**
|
||||||
|
- `tool_name` (str): Tool 名称。
|
||||||
|
|
||||||
|
**Returns:**
|
||||||
|
- `Optional[ToolInfo]` - Tool 信息对象,如果 Tool 不存在则返回 `None`。
|
||||||
|
|
||||||
|
### 9. 获取指定 EventHandler 的注册信息
|
||||||
|
```python
|
||||||
|
def get_registered_event_handler_info(event_handler_name: str) -> Optional[EventHandlerInfo]:
|
||||||
|
```
|
||||||
|
获取指定 EventHandler 的注册信息。
|
||||||
|
|
||||||
|
**Args:**
|
||||||
|
- `event_handler_name` (str): EventHandler 名称。
|
||||||
|
|
||||||
|
**Returns:**
|
||||||
|
- `Optional[EventHandlerInfo]` - EventHandler 信息对象,如果 EventHandler 不存在则返回 `None`。
|
||||||
|
|
||||||
|
### 10. 全局启用指定组件
|
||||||
|
```python
|
||||||
|
def globally_enable_component(component_name: str, component_type: ComponentType) -> bool:
|
||||||
|
```
|
||||||
|
全局启用指定组件。
|
||||||
|
|
||||||
|
**Args:**
|
||||||
|
- `component_name` (str): 组件名称。
|
||||||
|
- `component_type` (ComponentType): 组件类型。
|
||||||
|
|
||||||
|
**Returns:**
|
||||||
|
- `bool` - 启用成功返回 `True`,否则返回 `False`。
|
||||||
|
|
||||||
|
### 11. 全局禁用指定组件
|
||||||
|
```python
|
||||||
|
async def globally_disable_component(component_name: str, component_type: ComponentType) -> bool:
|
||||||
|
```
|
||||||
|
全局禁用指定组件。
|
||||||
|
|
||||||
|
**此函数是异步的,确保在异步环境中调用。**
|
||||||
|
|
||||||
|
**Args:**
|
||||||
|
- `component_name` (str): 组件名称。
|
||||||
|
- `component_type` (ComponentType): 组件类型。
|
||||||
|
|
||||||
|
**Returns:**
|
||||||
|
- `bool` - 禁用成功返回 `True`,否则返回 `False`。
|
||||||
|
|
||||||
|
### 12. 局部启用指定组件
|
||||||
|
```python
|
||||||
|
def locally_enable_component(component_name: str, component_type: ComponentType, stream_id: str) -> bool:
|
||||||
|
```
|
||||||
|
局部启用指定组件。
|
||||||
|
|
||||||
|
**Args:**
|
||||||
|
- `component_name` (str): 组件名称。
|
||||||
|
- `component_type` (ComponentType): 组件类型。
|
||||||
|
- `stream_id` (str): 消息流 ID。
|
||||||
|
|
||||||
|
**Returns:**
|
||||||
|
- `bool` - 启用成功返回 `True`,否则返回 `False`。
|
||||||
|
|
||||||
|
### 13. 局部禁用指定组件
|
||||||
|
```python
|
||||||
|
def locally_disable_component(component_name: str, component_type: ComponentType, stream_id: str) -> bool:
|
||||||
|
```
|
||||||
|
局部禁用指定组件。
|
||||||
|
|
||||||
|
**Args:**
|
||||||
|
- `component_name` (str): 组件名称。
|
||||||
|
- `component_type` (ComponentType): 组件类型。
|
||||||
|
- `stream_id` (str): 消息流 ID。
|
||||||
|
|
||||||
|
**Returns:**
|
||||||
|
- `bool` - 禁用成功返回 `True`,否则返回 `False`。
|
||||||
|
|
||||||
|
### 14. 获取指定消息流中禁用的组件列表
|
||||||
|
```python
|
||||||
|
def get_locally_disabled_components(stream_id: str, component_type: ComponentType) -> list[str]:
|
||||||
|
```
|
||||||
|
获取指定消息流中禁用的组件列表。
|
||||||
|
|
||||||
|
**Args:**
|
||||||
|
- `stream_id` (str): 消息流 ID。
|
||||||
|
- `component_type` (ComponentType): 组件类型。
|
||||||
|
|
||||||
|
**Returns:**
|
||||||
|
- `list[str]` - 禁用的组件名称列表。
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
# 配置API
|
# 配置API
|
||||||
|
|
||||||
配置API模块提供了配置读取和用户信息获取等功能,让插件能够安全地访问全局配置和用户信息。
|
配置API模块提供了配置读取功能,让插件能够安全地访问全局配置和插件配置。
|
||||||
|
|
||||||
## 导入方式
|
## 导入方式
|
||||||
|
|
||||||
|
|||||||
@@ -6,72 +6,51 @@
|
|||||||
|
|
||||||
```python
|
```python
|
||||||
from src.plugin_system.apis import database_api
|
from src.plugin_system.apis import database_api
|
||||||
|
# 或者
|
||||||
|
from src.plugin_system import database_api
|
||||||
```
|
```
|
||||||
|
|
||||||
## 主要功能
|
## 主要功能
|
||||||
|
|
||||||
### 1. 通用数据库查询
|
### 1. 通用数据库操作
|
||||||
|
|
||||||
#### `db_query(model_class, query_type="get", filters=None, data=None, limit=None, order_by=None, single_result=False)`
|
|
||||||
执行数据库查询操作的通用接口
|
|
||||||
|
|
||||||
**参数:**
|
|
||||||
- `model_class`:Peewee模型类,如ActionRecords、Messages等
|
|
||||||
- `query_type`:查询类型,可选值: "get", "create", "update", "delete", "count"
|
|
||||||
- `filters`:过滤条件字典,键为字段名,值为要匹配的值
|
|
||||||
- `data`:用于创建或更新的数据字典
|
|
||||||
- `limit`:限制结果数量
|
|
||||||
- `order_by`:排序字段列表,使用字段名,前缀'-'表示降序
|
|
||||||
- `single_result`:是否只返回单个结果
|
|
||||||
|
|
||||||
**返回:**
|
|
||||||
根据查询类型返回不同的结果:
|
|
||||||
- "get":返回查询结果列表或单个结果
|
|
||||||
- "create":返回创建的记录
|
|
||||||
- "update":返回受影响的行数
|
|
||||||
- "delete":返回受影响的行数
|
|
||||||
- "count":返回记录数量
|
|
||||||
|
|
||||||
### 2. 便捷查询函数
|
|
||||||
|
|
||||||
#### `db_save(model_class, data, key_field=None, key_value=None)`
|
|
||||||
保存数据到数据库(创建或更新)
|
|
||||||
|
|
||||||
**参数:**
|
|
||||||
- `model_class`:Peewee模型类
|
|
||||||
- `data`:要保存的数据字典
|
|
||||||
- `key_field`:用于查找现有记录的字段名
|
|
||||||
- `key_value`:用于查找现有记录的字段值
|
|
||||||
|
|
||||||
**返回:**
|
|
||||||
- `Dict[str, Any]`:保存后的记录数据,失败时返回None
|
|
||||||
|
|
||||||
#### `db_get(model_class, filters=None, order_by=None, limit=None)`
|
|
||||||
简化的查询函数
|
|
||||||
|
|
||||||
**参数:**
|
|
||||||
- `model_class`:Peewee模型类
|
|
||||||
- `filters`:过滤条件字典
|
|
||||||
- `order_by`:排序字段
|
|
||||||
- `limit`:限制结果数量
|
|
||||||
|
|
||||||
**返回:**
|
|
||||||
- `Union[List[Dict], Dict, None]`:查询结果
|
|
||||||
|
|
||||||
### 3. 专用函数
|
|
||||||
|
|
||||||
#### `store_action_info(...)`
|
|
||||||
存储动作信息的专用函数
|
|
||||||
|
|
||||||
## 使用示例
|
|
||||||
|
|
||||||
### 1. 基本查询操作
|
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from src.plugin_system.apis import database_api
|
async def db_query(
|
||||||
from src.common.database.database_model import Messages, ActionRecords
|
model_class: Type[Model],
|
||||||
|
data: Optional[Dict[str, Any]] = None,
|
||||||
|
query_type: Optional[str] = "get",
|
||||||
|
filters: Optional[Dict[str, Any]] = None,
|
||||||
|
limit: Optional[int] = None,
|
||||||
|
order_by: Optional[List[str]] = None,
|
||||||
|
single_result: Optional[bool] = False,
|
||||||
|
) -> Union[List[Dict[str, Any]], Dict[str, Any], None]:
|
||||||
|
```
|
||||||
|
执行数据库查询操作的通用接口。
|
||||||
|
|
||||||
# 查询最近10条消息
|
**Args:**
|
||||||
|
- `model_class`: Peewee模型类。
|
||||||
|
- Peewee模型类可以在`src.common.database.database_model`模块中找到,如`ActionRecords`、`Messages`等。
|
||||||
|
- `data`: 用于创建或更新的数据
|
||||||
|
- `query_type`: 查询类型
|
||||||
|
- 可选值: `get`, `create`, `update`, `delete`, `count`。
|
||||||
|
- `filters`: 过滤条件字典,键为字段名,值为要匹配的值。
|
||||||
|
- `limit`: 限制结果数量。
|
||||||
|
- `order_by`: 排序字段列表,使用字段名,前缀'-'表示降序。
|
||||||
|
- 排序字段,前缀`-`表示降序,例如`-time`表示按时间字段(即`time`字段)降序
|
||||||
|
- `single_result`: 是否只返回单个结果。
|
||||||
|
|
||||||
|
**Returns:**
|
||||||
|
- 根据查询类型返回不同的结果:
|
||||||
|
- `get`: 返回查询结果列表;如果 `single_result=True`,则返回单个结果。
|
||||||
|
- `create`: 返回创建的记录。
|
||||||
|
- `update`: 返回受影响的行数。
|
||||||
|
- `delete`: 返回受影响的行数。
|
||||||
|
- `count`: 返回记录数量。
|
||||||
|
|
||||||
|
#### 示例
|
||||||
|
|
||||||
|
1. 查询最近10条消息
|
||||||
|
```python
|
||||||
messages = await database_api.db_query(
|
messages = await database_api.db_query(
|
||||||
Messages,
|
Messages,
|
||||||
query_type="get",
|
query_type="get",
|
||||||
@@ -79,180 +58,159 @@ messages = await database_api.db_query(
|
|||||||
limit=10,
|
limit=10,
|
||||||
order_by=["-time"]
|
order_by=["-time"]
|
||||||
)
|
)
|
||||||
|
|
||||||
# 查询单条记录
|
|
||||||
message = await database_api.db_query(
|
|
||||||
Messages,
|
|
||||||
query_type="get",
|
|
||||||
filters={"message_id": "msg_123"},
|
|
||||||
single_result=True
|
|
||||||
)
|
|
||||||
```
|
```
|
||||||
|
2. 创建一条记录
|
||||||
### 2. 创建记录
|
|
||||||
|
|
||||||
```python
|
```python
|
||||||
# 创建新的动作记录
|
|
||||||
new_record = await database_api.db_query(
|
new_record = await database_api.db_query(
|
||||||
ActionRecords,
|
ActionRecords,
|
||||||
|
data={"action_id": "123", "time": time.time(), "action_name": "TestAction"},
|
||||||
query_type="create",
|
query_type="create",
|
||||||
data={
|
|
||||||
"action_id": "action_123",
|
|
||||||
"time": time.time(),
|
|
||||||
"action_name": "TestAction",
|
|
||||||
"action_done": True
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"创建了记录: {new_record['id']}")
|
|
||||||
```
|
```
|
||||||
|
3. 更新记录
|
||||||
### 3. 更新记录
|
|
||||||
|
|
||||||
```python
|
```python
|
||||||
# 更新动作状态
|
|
||||||
updated_count = await database_api.db_query(
|
updated_count = await database_api.db_query(
|
||||||
ActionRecords,
|
ActionRecords,
|
||||||
|
data={"action_done": True},
|
||||||
query_type="update",
|
query_type="update",
|
||||||
filters={"action_id": "action_123"},
|
filters={"action_id": "123"},
|
||||||
data={"action_done": True, "completion_time": time.time()}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"更新了 {updated_count} 条记录")
|
|
||||||
```
|
```
|
||||||
|
4. 删除记录
|
||||||
### 4. 删除记录
|
|
||||||
|
|
||||||
```python
|
```python
|
||||||
# 删除过期记录
|
|
||||||
deleted_count = await database_api.db_query(
|
deleted_count = await database_api.db_query(
|
||||||
ActionRecords,
|
ActionRecords,
|
||||||
query_type="delete",
|
query_type="delete",
|
||||||
filters={"time__lt": time.time() - 86400} # 删除24小时前的记录
|
filters={"action_id": "123"}
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"删除了 {deleted_count} 条过期记录")
|
|
||||||
```
|
```
|
||||||
|
5. 计数
|
||||||
### 5. 统计查询
|
|
||||||
|
|
||||||
```python
|
```python
|
||||||
# 统计消息数量
|
count = await database_api.db_query(
|
||||||
message_count = await database_api.db_query(
|
|
||||||
Messages,
|
Messages,
|
||||||
query_type="count",
|
query_type="count",
|
||||||
filters={"chat_id": chat_stream.stream_id}
|
filters={"chat_id": chat_stream.stream_id}
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"该聊天有 {message_count} 条消息")
|
|
||||||
```
|
```
|
||||||
|
|
||||||
### 6. 使用便捷函数
|
### 2. 数据库保存
|
||||||
|
```python
|
||||||
|
async def db_save(
|
||||||
|
model_class: Type[Model], data: Dict[str, Any], key_field: Optional[str] = None, key_value: Optional[Any] = None
|
||||||
|
) -> Optional[Dict[str, Any]]:
|
||||||
|
```
|
||||||
|
保存数据到数据库(创建或更新)
|
||||||
|
|
||||||
|
如果提供了key_field和key_value,会先尝试查找匹配的记录进行更新;
|
||||||
|
|
||||||
|
如果没有找到匹配记录,或未提供key_field和key_value,则创建新记录。
|
||||||
|
|
||||||
|
**Args:**
|
||||||
|
- `model_class`: Peewee模型类。
|
||||||
|
- `data`: 要保存的数据字典。
|
||||||
|
- `key_field`: 用于查找现有记录的字段名,例如"action_id"。
|
||||||
|
- `key_value`: 用于查找现有记录的字段值。
|
||||||
|
|
||||||
|
**Returns:**
|
||||||
|
- `Optional[Dict[str, Any]]`: 保存后的记录数据,失败时返回None。
|
||||||
|
|
||||||
|
#### 示例
|
||||||
|
创建或更新一条记录
|
||||||
```python
|
```python
|
||||||
# 使用db_save进行创建或更新
|
|
||||||
record = await database_api.db_save(
|
record = await database_api.db_save(
|
||||||
ActionRecords,
|
ActionRecords,
|
||||||
{
|
{
|
||||||
"action_id": "action_123",
|
"action_id": "123",
|
||||||
"time": time.time(),
|
"time": time.time(),
|
||||||
"action_name": "TestAction",
|
"action_name": "TestAction",
|
||||||
"action_done": True
|
"action_done": True
|
||||||
},
|
},
|
||||||
key_field="action_id",
|
key_field="action_id",
|
||||||
key_value="action_123"
|
key_value="123"
|
||||||
)
|
)
|
||||||
|
```
|
||||||
|
|
||||||
# 使用db_get进行简单查询
|
### 3. 数据库获取
|
||||||
recent_messages = await database_api.db_get(
|
```python
|
||||||
|
async def db_get(
|
||||||
|
model_class: Type[Model],
|
||||||
|
filters: Optional[Dict[str, Any]] = None,
|
||||||
|
limit: Optional[int] = None,
|
||||||
|
order_by: Optional[str] = None,
|
||||||
|
single_result: Optional[bool] = False,
|
||||||
|
) -> Union[List[Dict[str, Any]], Dict[str, Any], None]:
|
||||||
|
```
|
||||||
|
|
||||||
|
从数据库获取记录
|
||||||
|
|
||||||
|
这是db_query方法的简化版本,专注于数据检索操作。
|
||||||
|
|
||||||
|
**Args:**
|
||||||
|
- `model_class`: Peewee模型类。
|
||||||
|
- `filters`: 过滤条件字典,键为字段名,值为要匹配的值。
|
||||||
|
- `limit`: 限制结果数量。
|
||||||
|
- `order_by`: 排序字段,使用字段名,前缀'-'表示降序。
|
||||||
|
- `single_result`: 是否只返回单个结果,如果为True,则返回单个记录字典或None;否则返回记录字典列表或空列表
|
||||||
|
|
||||||
|
**Returns:**
|
||||||
|
- `Union[List[Dict], Dict, None]`: 查询结果列表或单个结果(如果`single_result=True`),失败时返回None。
|
||||||
|
|
||||||
|
#### 示例
|
||||||
|
1. 获取单个记录
|
||||||
|
```python
|
||||||
|
record = await database_api.db_get(
|
||||||
|
ActionRecords,
|
||||||
|
filters={"action_id": "123"},
|
||||||
|
single_result=True
|
||||||
|
)
|
||||||
|
```
|
||||||
|
2. 获取最近10条记录
|
||||||
|
```python
|
||||||
|
records = await database_api.db_get(
|
||||||
Messages,
|
Messages,
|
||||||
filters={"chat_id": chat_stream.stream_id},
|
filters={"chat_id": chat_stream.stream_id},
|
||||||
|
limit=10,
|
||||||
order_by="-time",
|
order_by="-time",
|
||||||
limit=5
|
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
## 高级用法
|
### 4. 动作信息存储
|
||||||
|
|
||||||
### 复杂查询示例
|
|
||||||
|
|
||||||
```python
|
```python
|
||||||
# 查询特定用户在特定时间段的消息
|
async def store_action_info(
|
||||||
user_messages = await database_api.db_query(
|
chat_stream=None,
|
||||||
Messages,
|
action_build_into_prompt: bool = False,
|
||||||
query_type="get",
|
action_prompt_display: str = "",
|
||||||
filters={
|
action_done: bool = True,
|
||||||
"user_id": "123456",
|
thinking_id: str = "",
|
||||||
"time__gte": start_time, # 大于等于开始时间
|
action_data: Optional[dict] = None,
|
||||||
"time__lt": end_time # 小于结束时间
|
action_name: str = "",
|
||||||
},
|
) -> Optional[Dict[str, Any]]:
|
||||||
order_by=["-time"],
|
```
|
||||||
limit=50
|
存储动作信息到数据库,是一种针对 Action 的 `db_save()` 的封装函数。
|
||||||
|
|
||||||
|
将Action执行的相关信息保存到ActionRecords表中,用于后续的记忆和上下文构建。
|
||||||
|
|
||||||
|
**Args:**
|
||||||
|
- `chat_stream`: 聊天流对象,包含聊天ID等信息。
|
||||||
|
- `action_build_into_prompt`: 是否将动作信息构建到提示中。
|
||||||
|
- `action_prompt_display`: 动作提示的显示文本。
|
||||||
|
- `action_done`: 动作是否完成。
|
||||||
|
- `thinking_id`: 思考过程的ID。
|
||||||
|
- `action_data`: 动作的数据字典。
|
||||||
|
- `action_name`: 动作的名称。
|
||||||
|
|
||||||
|
**Returns:**
|
||||||
|
- `Optional[Dict[str, Any]]`: 存储后的记录数据,失败时返回None。
|
||||||
|
|
||||||
|
#### 示例
|
||||||
|
```python
|
||||||
|
record = await database_api.store_action_info(
|
||||||
|
chat_stream=chat_stream,
|
||||||
|
action_build_into_prompt=True,
|
||||||
|
action_prompt_display="执行了回复动作",
|
||||||
|
action_done=True,
|
||||||
|
thinking_id="thinking_123",
|
||||||
|
action_data={"content": "Hello"},
|
||||||
|
action_name="reply_action"
|
||||||
)
|
)
|
||||||
|
```
|
||||||
# 批量处理
|
|
||||||
for message in user_messages:
|
|
||||||
print(f"消息内容: {message['plain_text']}")
|
|
||||||
print(f"发送时间: {message['time']}")
|
|
||||||
```
|
|
||||||
|
|
||||||
### 插件中的数据持久化
|
|
||||||
|
|
||||||
```python
|
|
||||||
from src.plugin_system.base import BasePlugin
|
|
||||||
from src.plugin_system.apis import database_api
|
|
||||||
|
|
||||||
class DataPlugin(BasePlugin):
|
|
||||||
async def handle_action(self, action_data, chat_stream):
|
|
||||||
# 保存插件数据
|
|
||||||
plugin_data = {
|
|
||||||
"plugin_name": self.plugin_name,
|
|
||||||
"chat_id": chat_stream.stream_id,
|
|
||||||
"data": json.dumps(action_data),
|
|
||||||
"created_time": time.time()
|
|
||||||
}
|
|
||||||
|
|
||||||
# 使用自定义表模型(需要先定义)
|
|
||||||
record = await database_api.db_save(
|
|
||||||
PluginData, # 假设的插件数据模型
|
|
||||||
plugin_data,
|
|
||||||
key_field="plugin_name",
|
|
||||||
key_value=self.plugin_name
|
|
||||||
)
|
|
||||||
|
|
||||||
return {"success": True, "record_id": record["id"]}
|
|
||||||
```
|
|
||||||
|
|
||||||
## 数据模型
|
|
||||||
|
|
||||||
### 常用模型类
|
|
||||||
系统提供了以下常用的数据模型:
|
|
||||||
|
|
||||||
- `Messages`:消息记录
|
|
||||||
- `ActionRecords`:动作记录
|
|
||||||
- `UserInfo`:用户信息
|
|
||||||
- `GroupInfo`:群组信息
|
|
||||||
|
|
||||||
### 字段说明
|
|
||||||
|
|
||||||
#### Messages模型主要字段
|
|
||||||
- `message_id`:消息ID
|
|
||||||
- `chat_id`:聊天ID
|
|
||||||
- `user_id`:用户ID
|
|
||||||
- `plain_text`:纯文本内容
|
|
||||||
- `time`:时间戳
|
|
||||||
|
|
||||||
#### ActionRecords模型主要字段
|
|
||||||
- `action_id`:动作ID
|
|
||||||
- `action_name`:动作名称
|
|
||||||
- `action_done`:是否完成
|
|
||||||
- `time`:创建时间
|
|
||||||
|
|
||||||
## 注意事项
|
|
||||||
|
|
||||||
1. **异步操作**:所有数据库API都是异步的,必须使用`await`
|
|
||||||
2. **错误处理**:函数内置错误处理,失败时返回None或空列表
|
|
||||||
3. **数据类型**:返回的都是字典格式的数据,不是模型对象
|
|
||||||
4. **性能考虑**:使用`limit`参数避免查询大量数据
|
|
||||||
5. **过滤条件**:支持简单的等值过滤,复杂查询需要使用原生Peewee语法
|
|
||||||
6. **事务**:如需事务支持,建议直接使用Peewee的事务功能
|
|
||||||
@@ -6,11 +6,13 @@
|
|||||||
|
|
||||||
```python
|
```python
|
||||||
from src.plugin_system.apis import emoji_api
|
from src.plugin_system.apis import emoji_api
|
||||||
|
# 或者
|
||||||
|
from src.plugin_system import emoji_api
|
||||||
```
|
```
|
||||||
|
|
||||||
## 🆕 **二步走识别优化**
|
## 二步走识别优化
|
||||||
|
|
||||||
从最新版本开始,表情包识别系统采用了**二步走识别 + 智能缓存**的优化方案:
|
从新版本开始,表情包识别系统采用了**二步走识别 + 智能缓存**的优化方案:
|
||||||
|
|
||||||
### **收到表情包时的识别流程**
|
### **收到表情包时的识别流程**
|
||||||
1. **第一步**:VLM视觉分析 - 生成详细描述
|
1. **第一步**:VLM视觉分析 - 生成详细描述
|
||||||
@@ -30,217 +32,84 @@ from src.plugin_system.apis import emoji_api
|
|||||||
## 主要功能
|
## 主要功能
|
||||||
|
|
||||||
### 1. 表情包获取
|
### 1. 表情包获取
|
||||||
|
```python
|
||||||
#### `get_by_description(description: str) -> Optional[Tuple[str, str, str]]`
|
async def get_by_description(description: str) -> Optional[Tuple[str, str, str]]:
|
||||||
|
```
|
||||||
根据场景描述选择表情包
|
根据场景描述选择表情包
|
||||||
|
|
||||||
**参数:**
|
**Args:**
|
||||||
- `description`:场景描述文本,例如"开心的大笑"、"轻微的讽刺"、"表示无奈和沮丧"等
|
- `description`:表情包的描述文本,例如"开心"、"难过"、"愤怒"等
|
||||||
|
|
||||||
**返回:**
|
**Returns:**
|
||||||
- `Optional[Tuple[str, str, str]]`:(base64编码, 表情包描述, 匹配的场景) 或 None
|
- `Optional[Tuple[str, str, str]]`:一个元组: (表情包的base64编码, 描述, 情感标签),如果未找到匹配的表情包则返回None
|
||||||
|
|
||||||
**示例:**
|
#### 示例
|
||||||
```python
|
```python
|
||||||
emoji_result = await emoji_api.get_by_description("开心的大笑")
|
emoji_result = await emoji_api.get_by_description("大笑")
|
||||||
if emoji_result:
|
if emoji_result:
|
||||||
emoji_base64, description, matched_scene = emoji_result
|
emoji_base64, description, matched_scene = emoji_result
|
||||||
print(f"获取到表情包: {description}, 场景: {matched_scene}")
|
print(f"获取到表情包: {description}, 场景: {matched_scene}")
|
||||||
# 可以将emoji_base64用于发送表情包
|
# 可以将emoji_base64用于发送表情包
|
||||||
```
|
```
|
||||||
|
|
||||||
#### `get_random() -> Optional[Tuple[str, str, str]]`
|
### 2. 随机获取表情包
|
||||||
随机获取表情包
|
|
||||||
|
|
||||||
**返回:**
|
|
||||||
- `Optional[Tuple[str, str, str]]`:(base64编码, 表情包描述, 随机场景) 或 None
|
|
||||||
|
|
||||||
**示例:**
|
|
||||||
```python
|
```python
|
||||||
random_emoji = await emoji_api.get_random()
|
async def get_random(count: Optional[int] = 1) -> List[Tuple[str, str, str]]:
|
||||||
if random_emoji:
|
|
||||||
emoji_base64, description, scene = random_emoji
|
|
||||||
print(f"随机表情包: {description}")
|
|
||||||
```
|
```
|
||||||
|
随机获取指定数量的表情包
|
||||||
|
|
||||||
#### `get_by_emotion(emotion: str) -> Optional[Tuple[str, str, str]]`
|
**Args:**
|
||||||
根据场景关键词获取表情包
|
- `count`:要获取的表情包数量,默认为1
|
||||||
|
|
||||||
**参数:**
|
**Returns:**
|
||||||
- `emotion`:场景关键词,如"大笑"、"讽刺"、"无奈"等
|
- `List[Tuple[str, str, str]]`:一个包含多个表情包的列表,每个元素是一个元组: (表情包的base64编码, 描述, 情感标签),如果未找到或出错则返回空列表
|
||||||
|
|
||||||
**返回:**
|
### 3. 根据情感获取表情包
|
||||||
- `Optional[Tuple[str, str, str]]`:(base64编码, 表情包描述, 匹配的场景) 或 None
|
|
||||||
|
|
||||||
**示例:**
|
|
||||||
```python
|
```python
|
||||||
emoji_result = await emoji_api.get_by_emotion("讽刺")
|
async def get_by_emotion(emotion: str) -> Optional[Tuple[str, str, str]]:
|
||||||
if emoji_result:
|
|
||||||
emoji_base64, description, scene = emoji_result
|
|
||||||
# 发送讽刺表情包
|
|
||||||
```
|
```
|
||||||
|
根据情感标签获取表情包
|
||||||
|
|
||||||
### 2. 表情包信息查询
|
**Args:**
|
||||||
|
- `emotion`:情感标签,例如"开心"、"悲伤"、"愤怒"等
|
||||||
|
|
||||||
#### `get_count() -> int`
|
**Returns:**
|
||||||
获取表情包数量
|
- `Optional[Tuple[str, str, str]]`:一个元组: (表情包的base64编码, 描述, 情感标签),如果未找到则返回None
|
||||||
|
|
||||||
**返回:**
|
### 4. 获取表情包数量
|
||||||
- `int`:当前可用的表情包数量
|
```python
|
||||||
|
def get_count() -> int:
|
||||||
|
```
|
||||||
|
获取当前可用表情包的数量
|
||||||
|
|
||||||
#### `get_info() -> dict`
|
### 5. 获取表情包系统信息
|
||||||
获取表情包系统信息
|
```python
|
||||||
|
def get_info() -> Dict[str, Any]:
|
||||||
|
```
|
||||||
|
获取表情包系统的基本信息
|
||||||
|
|
||||||
**返回:**
|
**Returns:**
|
||||||
- `dict`:包含表情包数量、最大数量等信息
|
- `Dict[str, Any]`:包含表情包数量、描述等信息的字典,包含以下键:
|
||||||
|
- `current_count`:当前表情包数量
|
||||||
|
- `max_count`:最大表情包数量
|
||||||
|
- `available_emojis`:当前可用的表情包数量
|
||||||
|
|
||||||
**返回字典包含:**
|
### 6. 获取所有可用的情感标签
|
||||||
- `current_count`:当前表情包数量
|
```python
|
||||||
- `max_count`:最大表情包数量
|
def get_emotions() -> List[str]:
|
||||||
- `available_emojis`:可用表情包数量
|
```
|
||||||
|
获取所有可用的情感标签 **(已经去重)**
|
||||||
|
|
||||||
#### `get_emotions() -> list`
|
### 7. 获取所有表情包描述
|
||||||
获取所有可用的场景关键词
|
```python
|
||||||
|
def get_descriptions() -> List[str]:
|
||||||
**返回:**
|
```
|
||||||
- `list`:所有表情包的场景关键词列表(去重)
|
|
||||||
|
|
||||||
#### `get_descriptions() -> list`
|
|
||||||
获取所有表情包的描述列表
|
获取所有表情包的描述列表
|
||||||
|
|
||||||
**返回:**
|
|
||||||
- `list`:所有表情包的描述文本列表
|
|
||||||
|
|
||||||
## 使用示例
|
|
||||||
|
|
||||||
### 1. 智能表情包选择
|
|
||||||
|
|
||||||
```python
|
|
||||||
from src.plugin_system.apis import emoji_api
|
|
||||||
|
|
||||||
async def send_emotion_response(message_text: str, chat_stream):
|
|
||||||
"""根据消息内容智能选择表情包回复"""
|
|
||||||
|
|
||||||
# 分析消息场景
|
|
||||||
if "哈哈" in message_text or "好笑" in message_text:
|
|
||||||
emoji_result = await emoji_api.get_by_description("开心的大笑")
|
|
||||||
elif "无语" in message_text or "算了" in message_text:
|
|
||||||
emoji_result = await emoji_api.get_by_description("表示无奈和沮丧")
|
|
||||||
elif "呵呵" in message_text or "是吗" in message_text:
|
|
||||||
emoji_result = await emoji_api.get_by_description("轻微的讽刺")
|
|
||||||
elif "生气" in message_text or "愤怒" in message_text:
|
|
||||||
emoji_result = await emoji_api.get_by_description("愤怒和不满")
|
|
||||||
else:
|
|
||||||
# 随机选择一个表情包
|
|
||||||
emoji_result = await emoji_api.get_random()
|
|
||||||
|
|
||||||
if emoji_result:
|
|
||||||
emoji_base64, description, scene = emoji_result
|
|
||||||
# 使用send_api发送表情包
|
|
||||||
from src.plugin_system.apis import send_api
|
|
||||||
success = await send_api.emoji_to_group(emoji_base64, chat_stream.group_info.group_id)
|
|
||||||
return success
|
|
||||||
|
|
||||||
return False
|
|
||||||
```
|
|
||||||
|
|
||||||
### 2. 表情包管理功能
|
|
||||||
|
|
||||||
```python
|
|
||||||
async def show_emoji_stats():
|
|
||||||
"""显示表情包统计信息"""
|
|
||||||
|
|
||||||
# 获取基本信息
|
|
||||||
count = emoji_api.get_count()
|
|
||||||
info = emoji_api.get_info()
|
|
||||||
scenes = emoji_api.get_emotions() # 实际返回的是场景关键词
|
|
||||||
|
|
||||||
stats = f"""
|
|
||||||
📊 表情包统计信息:
|
|
||||||
- 总数量: {count}
|
|
||||||
- 可用数量: {info['available_emojis']}
|
|
||||||
- 最大容量: {info['max_count']}
|
|
||||||
- 支持场景: {len(scenes)}种
|
|
||||||
|
|
||||||
🎭 支持的场景关键词: {', '.join(scenes[:10])}{'...' if len(scenes) > 10 else ''}
|
|
||||||
"""
|
|
||||||
|
|
||||||
return stats
|
|
||||||
```
|
|
||||||
|
|
||||||
### 3. 表情包测试功能
|
|
||||||
|
|
||||||
```python
|
|
||||||
async def test_emoji_system():
|
|
||||||
"""测试表情包系统的各种功能"""
|
|
||||||
|
|
||||||
print("=== 表情包系统测试 ===")
|
|
||||||
|
|
||||||
# 测试场景描述查找
|
|
||||||
test_descriptions = ["开心的大笑", "轻微的讽刺", "表示无奈和沮丧", "愤怒和不满"]
|
|
||||||
for desc in test_descriptions:
|
|
||||||
result = await emoji_api.get_by_description(desc)
|
|
||||||
if result:
|
|
||||||
_, description, scene = result
|
|
||||||
print(f"✅ 场景'{desc}' -> {description} ({scene})")
|
|
||||||
else:
|
|
||||||
print(f"❌ 场景'{desc}' -> 未找到")
|
|
||||||
|
|
||||||
# 测试关键词查找
|
|
||||||
scenes = emoji_api.get_emotions()
|
|
||||||
if scenes:
|
|
||||||
test_scene = scenes[0]
|
|
||||||
result = await emoji_api.get_by_emotion(test_scene)
|
|
||||||
if result:
|
|
||||||
print(f"✅ 关键词'{test_scene}' -> 找到匹配表情包")
|
|
||||||
|
|
||||||
# 测试随机获取
|
|
||||||
random_result = await emoji_api.get_random()
|
|
||||||
if random_result:
|
|
||||||
print("✅ 随机获取 -> 成功")
|
|
||||||
|
|
||||||
print(f"📊 系统信息: {emoji_api.get_info()}")
|
|
||||||
```
|
|
||||||
|
|
||||||
### 4. 在Action中使用表情包
|
|
||||||
|
|
||||||
```python
|
|
||||||
from src.plugin_system.base import BaseAction
|
|
||||||
|
|
||||||
class EmojiAction(BaseAction):
|
|
||||||
async def execute(self, action_data, chat_stream):
|
|
||||||
# 从action_data获取场景描述或关键词
|
|
||||||
scene_keyword = action_data.get("scene", "")
|
|
||||||
scene_description = action_data.get("description", "")
|
|
||||||
|
|
||||||
emoji_result = None
|
|
||||||
|
|
||||||
# 优先使用具体的场景描述
|
|
||||||
if scene_description:
|
|
||||||
emoji_result = await emoji_api.get_by_description(scene_description)
|
|
||||||
# 其次使用场景关键词
|
|
||||||
elif scene_keyword:
|
|
||||||
emoji_result = await emoji_api.get_by_emotion(scene_keyword)
|
|
||||||
# 最后随机选择
|
|
||||||
else:
|
|
||||||
emoji_result = await emoji_api.get_random()
|
|
||||||
|
|
||||||
if emoji_result:
|
|
||||||
emoji_base64, description, scene = emoji_result
|
|
||||||
return {
|
|
||||||
"success": True,
|
|
||||||
"emoji_base64": emoji_base64,
|
|
||||||
"description": description,
|
|
||||||
"scene": scene
|
|
||||||
}
|
|
||||||
|
|
||||||
return {"success": False, "message": "未找到合适的表情包"}
|
|
||||||
```
|
|
||||||
|
|
||||||
## 场景描述说明
|
## 场景描述说明
|
||||||
|
|
||||||
### 常用场景描述
|
### 常用场景描述
|
||||||
表情包系统支持多种具体的场景描述,常见的包括:
|
表情包系统支持多种具体的场景描述,举例如下:
|
||||||
|
|
||||||
- **开心类场景**:开心的大笑、满意的微笑、兴奋的手舞足蹈
|
- **开心类场景**:开心的大笑、满意的微笑、兴奋的手舞足蹈
|
||||||
- **无奈类场景**:表示无奈和沮丧、轻微的讽刺、无语的摇头
|
- **无奈类场景**:表示无奈和沮丧、轻微的讽刺、无语的摇头
|
||||||
@@ -248,8 +117,8 @@ class EmojiAction(BaseAction):
|
|||||||
- **惊讶类场景**:震惊的表情、意外的发现、困惑的思考
|
- **惊讶类场景**:震惊的表情、意外的发现、困惑的思考
|
||||||
- **可爱类场景**:卖萌的表情、撒娇的动作、害羞的样子
|
- **可爱类场景**:卖萌的表情、撒娇的动作、害羞的样子
|
||||||
|
|
||||||
### 场景关键词示例
|
### 情感关键词示例
|
||||||
系统支持的场景关键词包括:
|
系统支持的情感关键词举例如下:
|
||||||
- 大笑、微笑、兴奋、手舞足蹈
|
- 大笑、微笑、兴奋、手舞足蹈
|
||||||
- 无奈、沮丧、讽刺、无语、摇头
|
- 无奈、沮丧、讽刺、无语、摇头
|
||||||
- 愤怒、不满、生气、瞪视、抓狂
|
- 愤怒、不满、生气、瞪视、抓狂
|
||||||
@@ -263,9 +132,9 @@ class EmojiAction(BaseAction):
|
|||||||
|
|
||||||
## 注意事项
|
## 注意事项
|
||||||
|
|
||||||
1. **异步函数**:获取表情包的函数都是异步的,需要使用 `await`
|
1. **异步函数**:部分函数是异步的,需要使用 `await`
|
||||||
2. **返回格式**:表情包以base64编码返回,可直接用于发送
|
2. **返回格式**:表情包以base64编码返回,可直接用于发送
|
||||||
3. **错误处理**:所有函数都有错误处理,失败时返回None或默认值
|
3. **错误处理**:所有函数都有错误处理,失败时返回None、空列表或默认值
|
||||||
4. **使用统计**:系统会记录表情包的使用次数
|
4. **使用统计**:系统会记录表情包的使用次数
|
||||||
5. **文件依赖**:表情包依赖于本地文件,确保表情包文件存在
|
5. **文件依赖**:表情包依赖于本地文件,确保表情包文件存在
|
||||||
6. **编码格式**:返回的是base64编码的图片数据,可直接用于网络传输
|
6. **编码格式**:返回的是base64编码的图片数据,可直接用于网络传输
|
||||||
|
|||||||
@@ -6,241 +6,151 @@
|
|||||||
|
|
||||||
```python
|
```python
|
||||||
from src.plugin_system.apis import generator_api
|
from src.plugin_system.apis import generator_api
|
||||||
|
# 或者
|
||||||
|
from src.plugin_system import generator_api
|
||||||
```
|
```
|
||||||
|
|
||||||
## 主要功能
|
## 主要功能
|
||||||
|
|
||||||
### 1. 回复器获取
|
### 1. 回复器获取
|
||||||
|
```python
|
||||||
#### `get_replyer(chat_stream=None, platform=None, chat_id=None, is_group=True)`
|
def get_replyer(
|
||||||
|
chat_stream: Optional[ChatStream] = None,
|
||||||
|
chat_id: Optional[str] = None,
|
||||||
|
model_set_with_weight: Optional[List[Tuple[TaskConfig, float]]] = None,
|
||||||
|
request_type: str = "replyer",
|
||||||
|
) -> Optional[DefaultReplyer]:
|
||||||
|
```
|
||||||
获取回复器对象
|
获取回复器对象
|
||||||
|
|
||||||
**参数:**
|
优先使用chat_stream,如果没有则使用chat_id直接查找。
|
||||||
- `chat_stream`:聊天流对象(优先)
|
|
||||||
- `platform`:平台名称,如"qq"
|
|
||||||
- `chat_id`:聊天ID(群ID或用户ID)
|
|
||||||
- `is_group`:是否为群聊
|
|
||||||
|
|
||||||
**返回:**
|
使用 ReplyerManager 来管理实例,避免重复创建。
|
||||||
- `DefaultReplyer`:回复器对象,如果获取失败则返回None
|
|
||||||
|
|
||||||
**示例:**
|
**Args:**
|
||||||
|
- `chat_stream`: 聊天流对象
|
||||||
|
- `chat_id`: 聊天ID(实际上就是`stream_id`)
|
||||||
|
- `model_set_with_weight`: 模型配置列表,每个元素为 `(TaskConfig, weight)` 元组
|
||||||
|
- `request_type`: 请求类型,用于记录LLM使用情况,可以不写
|
||||||
|
|
||||||
|
**Returns:**
|
||||||
|
- `DefaultReplyer`: 回复器对象,如果获取失败则返回None
|
||||||
|
|
||||||
|
#### 示例
|
||||||
```python
|
```python
|
||||||
# 使用聊天流获取回复器
|
# 使用聊天流获取回复器
|
||||||
replyer = generator_api.get_replyer(chat_stream=chat_stream)
|
replyer = generator_api.get_replyer(chat_stream=chat_stream)
|
||||||
|
|
||||||
# 使用平台和ID获取回复器
|
# 使用chat_id获取回复器
|
||||||
replyer = generator_api.get_replyer(
|
replyer = generator_api.get_replyer(chat_id="123456789")
|
||||||
platform="qq",
|
|
||||||
chat_id="123456789",
|
|
||||||
is_group=True
|
|
||||||
)
|
|
||||||
```
|
```
|
||||||
|
|
||||||
### 2. 回复生成
|
### 2. 回复生成
|
||||||
|
```python
|
||||||
#### `generate_reply(chat_stream=None, action_data=None, platform=None, chat_id=None, is_group=True)`
|
async def generate_reply(
|
||||||
|
chat_stream: Optional[ChatStream] = None,
|
||||||
|
chat_id: Optional[str] = None,
|
||||||
|
action_data: Optional[Dict[str, Any]] = None,
|
||||||
|
reply_to: str = "",
|
||||||
|
extra_info: str = "",
|
||||||
|
available_actions: Optional[Dict[str, ActionInfo]] = None,
|
||||||
|
enable_tool: bool = False,
|
||||||
|
enable_splitter: bool = True,
|
||||||
|
enable_chinese_typo: bool = True,
|
||||||
|
return_prompt: bool = False,
|
||||||
|
model_set_with_weight: Optional[List[Tuple[TaskConfig, float]]] = None,
|
||||||
|
request_type: str = "generator_api",
|
||||||
|
) -> Tuple[bool, List[Tuple[str, Any]], Optional[str]]:
|
||||||
|
```
|
||||||
生成回复
|
生成回复
|
||||||
|
|
||||||
**参数:**
|
优先使用chat_stream,如果没有则使用chat_id直接查找。
|
||||||
- `chat_stream`:聊天流对象(优先)
|
|
||||||
- `action_data`:动作数据
|
|
||||||
- `platform`:平台名称(备用)
|
|
||||||
- `chat_id`:聊天ID(备用)
|
|
||||||
- `is_group`:是否为群聊(备用)
|
|
||||||
|
|
||||||
**返回:**
|
**Args:**
|
||||||
- `Tuple[bool, List[Tuple[str, Any]]]`:(是否成功, 回复集合)
|
- `chat_stream`: 聊天流对象
|
||||||
|
- `chat_id`: 聊天ID(实际上就是`stream_id`)
|
||||||
|
- `action_data`: 动作数据(向下兼容,包含`reply_to`和`extra_info`)
|
||||||
|
- `reply_to`: 回复目标,格式为 `{发送者的person_name:消息内容}`
|
||||||
|
- `extra_info`: 附加信息
|
||||||
|
- `available_actions`: 可用动作字典,格式为 `{"action_name": ActionInfo}`
|
||||||
|
- `enable_tool`: 是否启用工具
|
||||||
|
- `enable_splitter`: 是否启用分割器
|
||||||
|
- `enable_chinese_typo`: 是否启用中文错别字
|
||||||
|
- `return_prompt`: 是否返回提示词
|
||||||
|
- `model_set_with_weight`: 模型配置列表,每个元素为 `(TaskConfig, weight)` 元组
|
||||||
|
- `request_type`: 请求类型(可选,记录LLM使用)
|
||||||
|
- `request_type`: 请求类型,用于记录LLM使用情况
|
||||||
|
|
||||||
**示例:**
|
**Returns:**
|
||||||
|
- `Tuple[bool, List[Tuple[str, Any]], Optional[str]]`: (是否成功, 回复集合, 提示词)
|
||||||
|
|
||||||
|
#### 示例
|
||||||
```python
|
```python
|
||||||
success, reply_set = await generator_api.generate_reply(
|
success, reply_set, prompt = await generator_api.generate_reply(
|
||||||
chat_stream=chat_stream,
|
chat_stream=chat_stream,
|
||||||
action_data={"message": "你好", "intent": "greeting"}
|
action_data=action_data,
|
||||||
|
reply_to="麦麦:你好",
|
||||||
|
available_actions=action_info,
|
||||||
|
enable_tool=True,
|
||||||
|
return_prompt=True
|
||||||
)
|
)
|
||||||
|
|
||||||
if success:
|
if success:
|
||||||
for reply_type, reply_content in reply_set:
|
for reply_type, reply_content in reply_set:
|
||||||
print(f"回复类型: {reply_type}, 内容: {reply_content}")
|
print(f"回复类型: {reply_type}, 内容: {reply_content}")
|
||||||
|
if prompt:
|
||||||
|
print(f"使用的提示词: {prompt}")
|
||||||
```
|
```
|
||||||
|
|
||||||
#### `rewrite_reply(chat_stream=None, reply_data=None, platform=None, chat_id=None, is_group=True)`
|
### 3. 回复重写
|
||||||
重写回复
|
|
||||||
|
|
||||||
**参数:**
|
|
||||||
- `chat_stream`:聊天流对象(优先)
|
|
||||||
- `reply_data`:回复数据
|
|
||||||
- `platform`:平台名称(备用)
|
|
||||||
- `chat_id`:聊天ID(备用)
|
|
||||||
- `is_group`:是否为群聊(备用)
|
|
||||||
|
|
||||||
**返回:**
|
|
||||||
- `Tuple[bool, List[Tuple[str, Any]]]`:(是否成功, 回复集合)
|
|
||||||
|
|
||||||
**示例:**
|
|
||||||
```python
|
```python
|
||||||
success, reply_set = await generator_api.rewrite_reply(
|
async def rewrite_reply(
|
||||||
|
chat_stream: Optional[ChatStream] = None,
|
||||||
|
reply_data: Optional[Dict[str, Any]] = None,
|
||||||
|
chat_id: Optional[str] = None,
|
||||||
|
enable_splitter: bool = True,
|
||||||
|
enable_chinese_typo: bool = True,
|
||||||
|
model_set_with_weight: Optional[List[Tuple[TaskConfig, float]]] = None,
|
||||||
|
raw_reply: str = "",
|
||||||
|
reason: str = "",
|
||||||
|
reply_to: str = "",
|
||||||
|
return_prompt: bool = False,
|
||||||
|
) -> Tuple[bool, List[Tuple[str, Any]], Optional[str]]:
|
||||||
|
```
|
||||||
|
重写回复,使用新的内容替换旧的回复内容。
|
||||||
|
|
||||||
|
优先使用chat_stream,如果没有则使用chat_id直接查找。
|
||||||
|
|
||||||
|
**Args:**
|
||||||
|
- `chat_stream`: 聊天流对象
|
||||||
|
- `reply_data`: 回复数据,包含`raw_reply`, `reason`和`reply_to`,**(向下兼容备用,当其他参数缺失时从此获取)**
|
||||||
|
- `chat_id`: 聊天ID(实际上就是`stream_id`)
|
||||||
|
- `enable_splitter`: 是否启用分割器
|
||||||
|
- `enable_chinese_typo`: 是否启用中文错别字
|
||||||
|
- `model_set_with_weight`: 模型配置列表,每个元素为 (TaskConfig, weight) 元组
|
||||||
|
- `raw_reply`: 原始回复内容
|
||||||
|
- `reason`: 重写原因
|
||||||
|
- `reply_to`: 回复目标,格式为 `{发送者的person_name:消息内容}`
|
||||||
|
|
||||||
|
**Returns:**
|
||||||
|
- `Tuple[bool, List[Tuple[str, Any]], Optional[str]]`: (是否成功, 回复集合, 提示词)
|
||||||
|
|
||||||
|
#### 示例
|
||||||
|
```python
|
||||||
|
success, reply_set, prompt = await generator_api.rewrite_reply(
|
||||||
chat_stream=chat_stream,
|
chat_stream=chat_stream,
|
||||||
reply_data={"original_text": "原始回复", "style": "more_friendly"}
|
raw_reply="原始回复内容",
|
||||||
|
reason="重写原因",
|
||||||
|
reply_to="麦麦:你好",
|
||||||
|
return_prompt=True
|
||||||
)
|
)
|
||||||
|
if success:
|
||||||
|
for reply_type, reply_content in reply_set:
|
||||||
|
print(f"回复类型: {reply_type}, 内容: {reply_content}")
|
||||||
|
if prompt:
|
||||||
|
print(f"使用的提示词: {prompt}")
|
||||||
```
|
```
|
||||||
|
|
||||||
## 使用示例
|
## 回复集合`reply_set`格式
|
||||||
|
|
||||||
### 1. 基础回复生成
|
|
||||||
|
|
||||||
```python
|
|
||||||
from src.plugin_system.apis import generator_api
|
|
||||||
|
|
||||||
async def generate_greeting_reply(chat_stream, user_name):
|
|
||||||
"""生成问候回复"""
|
|
||||||
|
|
||||||
action_data = {
|
|
||||||
"intent": "greeting",
|
|
||||||
"user_name": user_name,
|
|
||||||
"context": "morning_greeting"
|
|
||||||
}
|
|
||||||
|
|
||||||
success, reply_set = await generator_api.generate_reply(
|
|
||||||
chat_stream=chat_stream,
|
|
||||||
action_data=action_data
|
|
||||||
)
|
|
||||||
|
|
||||||
if success and reply_set:
|
|
||||||
# 获取第一个回复
|
|
||||||
reply_type, reply_content = reply_set[0]
|
|
||||||
return reply_content
|
|
||||||
|
|
||||||
return "你好!" # 默认回复
|
|
||||||
```
|
|
||||||
|
|
||||||
### 2. 在Action中使用回复生成器
|
|
||||||
|
|
||||||
```python
|
|
||||||
from src.plugin_system.base import BaseAction
|
|
||||||
|
|
||||||
class ChatAction(BaseAction):
|
|
||||||
async def execute(self, action_data, chat_stream):
|
|
||||||
# 准备回复数据
|
|
||||||
reply_context = {
|
|
||||||
"message_type": "response",
|
|
||||||
"user_input": action_data.get("user_message", ""),
|
|
||||||
"intent": action_data.get("intent", ""),
|
|
||||||
"entities": action_data.get("entities", {}),
|
|
||||||
"context": self.get_conversation_context(chat_stream)
|
|
||||||
}
|
|
||||||
|
|
||||||
# 生成回复
|
|
||||||
success, reply_set = await generator_api.generate_reply(
|
|
||||||
chat_stream=chat_stream,
|
|
||||||
action_data=reply_context
|
|
||||||
)
|
|
||||||
|
|
||||||
if success:
|
|
||||||
return {
|
|
||||||
"success": True,
|
|
||||||
"replies": reply_set,
|
|
||||||
"generated_count": len(reply_set)
|
|
||||||
}
|
|
||||||
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"error": "回复生成失败",
|
|
||||||
"fallback_reply": "抱歉,我现在无法理解您的消息。"
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
### 3. 多样化回复生成
|
|
||||||
|
|
||||||
```python
|
|
||||||
async def generate_diverse_replies(chat_stream, topic, count=3):
|
|
||||||
"""生成多个不同风格的回复"""
|
|
||||||
|
|
||||||
styles = ["formal", "casual", "humorous"]
|
|
||||||
all_replies = []
|
|
||||||
|
|
||||||
for i, style in enumerate(styles[:count]):
|
|
||||||
action_data = {
|
|
||||||
"topic": topic,
|
|
||||||
"style": style,
|
|
||||||
"variation": i
|
|
||||||
}
|
|
||||||
|
|
||||||
success, reply_set = await generator_api.generate_reply(
|
|
||||||
chat_stream=chat_stream,
|
|
||||||
action_data=action_data
|
|
||||||
)
|
|
||||||
|
|
||||||
if success and reply_set:
|
|
||||||
all_replies.extend(reply_set)
|
|
||||||
|
|
||||||
return all_replies
|
|
||||||
```
|
|
||||||
|
|
||||||
### 4. 回复重写功能
|
|
||||||
|
|
||||||
```python
|
|
||||||
async def improve_reply(chat_stream, original_reply, improvement_type="more_friendly"):
|
|
||||||
"""改进原始回复"""
|
|
||||||
|
|
||||||
reply_data = {
|
|
||||||
"original_text": original_reply,
|
|
||||||
"improvement_type": improvement_type,
|
|
||||||
"target_audience": "young_users",
|
|
||||||
"tone": "positive"
|
|
||||||
}
|
|
||||||
|
|
||||||
success, improved_replies = await generator_api.rewrite_reply(
|
|
||||||
chat_stream=chat_stream,
|
|
||||||
reply_data=reply_data
|
|
||||||
)
|
|
||||||
|
|
||||||
if success and improved_replies:
|
|
||||||
# 返回改进后的第一个回复
|
|
||||||
_, improved_content = improved_replies[0]
|
|
||||||
return improved_content
|
|
||||||
|
|
||||||
return original_reply # 如果改进失败,返回原始回复
|
|
||||||
```
|
|
||||||
|
|
||||||
### 5. 条件回复生成
|
|
||||||
|
|
||||||
```python
|
|
||||||
async def conditional_reply_generation(chat_stream, user_message, user_emotion):
|
|
||||||
"""根据用户情感生成条件回复"""
|
|
||||||
|
|
||||||
# 根据情感调整回复策略
|
|
||||||
if user_emotion == "sad":
|
|
||||||
action_data = {
|
|
||||||
"intent": "comfort",
|
|
||||||
"tone": "empathetic",
|
|
||||||
"style": "supportive"
|
|
||||||
}
|
|
||||||
elif user_emotion == "angry":
|
|
||||||
action_data = {
|
|
||||||
"intent": "calm",
|
|
||||||
"tone": "peaceful",
|
|
||||||
"style": "understanding"
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
action_data = {
|
|
||||||
"intent": "respond",
|
|
||||||
"tone": "neutral",
|
|
||||||
"style": "helpful"
|
|
||||||
}
|
|
||||||
|
|
||||||
action_data["user_message"] = user_message
|
|
||||||
action_data["user_emotion"] = user_emotion
|
|
||||||
|
|
||||||
success, reply_set = await generator_api.generate_reply(
|
|
||||||
chat_stream=chat_stream,
|
|
||||||
action_data=action_data
|
|
||||||
)
|
|
||||||
|
|
||||||
return reply_set if success else []
|
|
||||||
```
|
|
||||||
|
|
||||||
## 回复集合格式
|
|
||||||
|
|
||||||
### 回复类型
|
### 回复类型
|
||||||
生成的回复集合包含多种类型的回复:
|
生成的回复集合包含多种类型的回复:
|
||||||
@@ -260,82 +170,32 @@ reply_set = [
|
|||||||
]
|
]
|
||||||
```
|
```
|
||||||
|
|
||||||
## 高级用法
|
### 4. 自定义提示词回复
|
||||||
|
|
||||||
### 1. 自定义回复器配置
|
|
||||||
|
|
||||||
```python
|
```python
|
||||||
async def generate_with_custom_config(chat_stream, action_data):
|
async def generate_response_custom(
|
||||||
"""使用自定义配置生成回复"""
|
chat_stream: Optional[ChatStream] = None,
|
||||||
|
chat_id: Optional[str] = None,
|
||||||
# 获取回复器
|
model_set_with_weight: Optional[List[Tuple[TaskConfig, float]]] = None,
|
||||||
replyer = generator_api.get_replyer(chat_stream=chat_stream)
|
prompt: str = "",
|
||||||
|
) -> Optional[str]:
|
||||||
if replyer:
|
|
||||||
# 可以访问回复器的内部方法
|
|
||||||
success, reply_set = await replyer.generate_reply_with_context(
|
|
||||||
reply_data=action_data,
|
|
||||||
# 可以传递额外的配置参数
|
|
||||||
)
|
|
||||||
return success, reply_set
|
|
||||||
|
|
||||||
return False, []
|
|
||||||
```
|
```
|
||||||
|
生成自定义提示词回复
|
||||||
|
|
||||||
### 2. 回复质量评估
|
优先使用chat_stream,如果没有则使用chat_id直接查找。
|
||||||
|
|
||||||
```python
|
**Args:**
|
||||||
async def generate_and_evaluate_replies(chat_stream, action_data):
|
- `chat_stream`: 聊天流对象
|
||||||
"""生成回复并评估质量"""
|
- `chat_id`: 聊天ID(备用)
|
||||||
|
- `model_set_with_weight`: 模型集合配置列表
|
||||||
success, reply_set = await generator_api.generate_reply(
|
- `prompt`: 自定义提示词
|
||||||
chat_stream=chat_stream,
|
|
||||||
action_data=action_data
|
|
||||||
)
|
|
||||||
|
|
||||||
if success:
|
|
||||||
evaluated_replies = []
|
|
||||||
for reply_type, reply_content in reply_set:
|
|
||||||
# 简单的质量评估
|
|
||||||
quality_score = evaluate_reply_quality(reply_content)
|
|
||||||
evaluated_replies.append({
|
|
||||||
"type": reply_type,
|
|
||||||
"content": reply_content,
|
|
||||||
"quality": quality_score
|
|
||||||
})
|
|
||||||
|
|
||||||
# 按质量排序
|
|
||||||
evaluated_replies.sort(key=lambda x: x["quality"], reverse=True)
|
|
||||||
return evaluated_replies
|
|
||||||
|
|
||||||
return []
|
|
||||||
|
|
||||||
def evaluate_reply_quality(reply_content):
|
**Returns:**
|
||||||
"""简单的回复质量评估"""
|
- `Optional[str]`: 生成的自定义回复内容,如果生成失败则返回None
|
||||||
if not reply_content:
|
|
||||||
return 0
|
|
||||||
|
|
||||||
score = 50 # 基础分
|
|
||||||
|
|
||||||
# 长度适中加分
|
|
||||||
if 5 <= len(reply_content) <= 100:
|
|
||||||
score += 20
|
|
||||||
|
|
||||||
# 包含积极词汇加分
|
|
||||||
positive_words = ["好", "棒", "不错", "感谢", "开心"]
|
|
||||||
for word in positive_words:
|
|
||||||
if word in reply_content:
|
|
||||||
score += 10
|
|
||||||
break
|
|
||||||
|
|
||||||
return min(score, 100)
|
|
||||||
```
|
|
||||||
|
|
||||||
## 注意事项
|
## 注意事项
|
||||||
|
|
||||||
1. **异步操作**:所有生成函数都是异步的,必须使用`await`
|
1. **异步操作**:部分函数是异步的,须使用`await`
|
||||||
2. **错误处理**:函数内置错误处理,失败时返回False和空列表
|
2. **聊天流依赖**:需要有效的聊天流对象才能正常工作
|
||||||
3. **聊天流依赖**:需要有效的聊天流对象才能正常工作
|
3. **性能考虑**:回复生成可能需要一些时间,特别是使用LLM时
|
||||||
4. **性能考虑**:回复生成可能需要一些时间,特别是使用LLM时
|
4. **回复格式**:返回的回复集合是元组列表,包含类型和内容
|
||||||
5. **回复格式**:返回的回复集合是元组列表,包含类型和内容
|
5. **上下文感知**:生成器会考虑聊天上下文和历史消息,除非你用的是自定义提示词。
|
||||||
6. **上下文感知**:生成器会考虑聊天上下文和历史消息
|
|
||||||
@@ -6,239 +6,60 @@ LLM API模块提供与大语言模型交互的功能,让插件能够使用系
|
|||||||
|
|
||||||
```python
|
```python
|
||||||
from src.plugin_system.apis import llm_api
|
from src.plugin_system.apis import llm_api
|
||||||
|
# 或者
|
||||||
|
from src.plugin_system import llm_api
|
||||||
```
|
```
|
||||||
|
|
||||||
## 主要功能
|
## 主要功能
|
||||||
|
|
||||||
### 1. 模型管理
|
### 1. 查询可用模型
|
||||||
|
|
||||||
#### `get_available_models() -> Dict[str, Any]`
|
|
||||||
获取所有可用的模型配置
|
|
||||||
|
|
||||||
**返回:**
|
|
||||||
- `Dict[str, Any]`:模型配置字典,key为模型名称,value为模型配置
|
|
||||||
|
|
||||||
**示例:**
|
|
||||||
```python
|
```python
|
||||||
models = llm_api.get_available_models()
|
def get_available_models() -> Dict[str, TaskConfig]:
|
||||||
for model_name, model_config in models.items():
|
|
||||||
print(f"模型: {model_name}")
|
|
||||||
print(f"配置: {model_config}")
|
|
||||||
```
|
```
|
||||||
|
获取所有可用的模型配置。
|
||||||
|
|
||||||
### 2. 内容生成
|
**Returns:**
|
||||||
|
- `Dict[str, TaskConfig]`:模型配置字典,key为模型名称,value为模型配置对象。
|
||||||
|
|
||||||
#### `generate_with_model(prompt, model_config, request_type="plugin.generate", **kwargs)`
|
### 2. 使用模型生成内容
|
||||||
使用指定模型生成内容
|
|
||||||
|
|
||||||
**参数:**
|
|
||||||
- `prompt`:提示词
|
|
||||||
- `model_config`:模型配置(从 get_available_models 获取)
|
|
||||||
- `request_type`:请求类型标识
|
|
||||||
- `**kwargs`:其他模型特定参数,如temperature、max_tokens等
|
|
||||||
|
|
||||||
**返回:**
|
|
||||||
- `Tuple[bool, str, str, str]`:(是否成功, 生成的内容, 推理过程, 模型名称)
|
|
||||||
|
|
||||||
**示例:**
|
|
||||||
```python
|
```python
|
||||||
models = llm_api.get_available_models()
|
async def generate_with_model(
|
||||||
default_model = models.get("default")
|
prompt: str,
|
||||||
|
model_config: TaskConfig,
|
||||||
if default_model:
|
request_type: str = "plugin.generate",
|
||||||
success, response, reasoning, model_name = await llm_api.generate_with_model(
|
temperature: Optional[float] = None,
|
||||||
prompt="请写一首关于春天的诗",
|
max_tokens: Optional[int] = None,
|
||||||
model_config=default_model,
|
) -> Tuple[bool, str, str, str]:
|
||||||
temperature=0.7,
|
|
||||||
max_tokens=200
|
|
||||||
)
|
|
||||||
|
|
||||||
if success:
|
|
||||||
print(f"生成内容: {response}")
|
|
||||||
print(f"使用模型: {model_name}")
|
|
||||||
```
|
```
|
||||||
|
使用指定模型生成内容。
|
||||||
|
|
||||||
## 使用示例
|
**Args:**
|
||||||
|
- `prompt`:提示词。
|
||||||
|
- `model_config`:模型配置对象(从 `get_available_models` 获取)。
|
||||||
|
- `request_type`:请求类型标识,默认为 `"plugin.generate"`。
|
||||||
|
- `temperature`:生成内容的温度设置,影响输出的随机性。
|
||||||
|
- `max_tokens`:生成内容的最大token数。
|
||||||
|
|
||||||
### 1. 基础文本生成
|
**Returns:**
|
||||||
|
- `Tuple[bool, str, str, str]`:返回一个元组,包含(是否成功, 生成的内容, 推理过程, 模型名称)。
|
||||||
|
|
||||||
|
### 3. 有Tool情况下使用模型生成内容
|
||||||
```python
|
```python
|
||||||
from src.plugin_system.apis import llm_api
|
async def generate_with_model_with_tools(
|
||||||
|
prompt: str,
|
||||||
async def generate_story(topic: str):
|
model_config: TaskConfig,
|
||||||
"""生成故事"""
|
tool_options: List[Dict[str, Any]] | None = None,
|
||||||
models = llm_api.get_available_models()
|
request_type: str = "plugin.generate",
|
||||||
model = models.get("default")
|
temperature: Optional[float] = None,
|
||||||
|
max_tokens: Optional[int] = None,
|
||||||
if not model:
|
) -> Tuple[bool, str, str, str, List[ToolCall] | None]:
|
||||||
return "未找到可用模型"
|
|
||||||
|
|
||||||
prompt = f"请写一个关于{topic}的短故事,大约100字左右。"
|
|
||||||
|
|
||||||
success, story, reasoning, model_name = await llm_api.generate_with_model(
|
|
||||||
prompt=prompt,
|
|
||||||
model_config=model,
|
|
||||||
request_type="story.generate",
|
|
||||||
temperature=0.8,
|
|
||||||
max_tokens=150
|
|
||||||
)
|
|
||||||
|
|
||||||
return story if success else "故事生成失败"
|
|
||||||
```
|
```
|
||||||
|
使用指定模型生成内容,并支持工具调用。
|
||||||
|
|
||||||
### 2. 在Action中使用LLM
|
**Args:**
|
||||||
|
- `prompt`:提示词。
|
||||||
```python
|
- `model_config`:模型配置对象(从 `get_available_models` 获取)。
|
||||||
from src.plugin_system.base import BaseAction
|
- `tool_options`:工具选项列表,包含可用工具的配置,字典为每一个工具的定义,参见[tool-components.md](../tool-components.md#属性说明),可用`tool_api.get_llm_available_tool_definitions()`获取并选择。
|
||||||
|
- `request_type`:请求类型标识,默认为 `"plugin.generate"`。
|
||||||
class LLMAction(BaseAction):
|
- `temperature`:生成内容的温度设置,影响输出的随机性。
|
||||||
async def execute(self, action_data, chat_stream):
|
- `max_tokens`:生成内容的最大token数。
|
||||||
# 获取用户输入
|
|
||||||
user_input = action_data.get("user_message", "")
|
|
||||||
intent = action_data.get("intent", "chat")
|
|
||||||
|
|
||||||
# 获取模型配置
|
|
||||||
models = llm_api.get_available_models()
|
|
||||||
model = models.get("default")
|
|
||||||
|
|
||||||
if not model:
|
|
||||||
return {"success": False, "error": "未配置LLM模型"}
|
|
||||||
|
|
||||||
# 构建提示词
|
|
||||||
prompt = self.build_prompt(user_input, intent)
|
|
||||||
|
|
||||||
# 生成回复
|
|
||||||
success, response, reasoning, model_name = await llm_api.generate_with_model(
|
|
||||||
prompt=prompt,
|
|
||||||
model_config=model,
|
|
||||||
request_type=f"plugin.{self.plugin_name}",
|
|
||||||
temperature=0.7
|
|
||||||
)
|
|
||||||
|
|
||||||
if success:
|
|
||||||
return {
|
|
||||||
"success": True,
|
|
||||||
"response": response,
|
|
||||||
"model_used": model_name,
|
|
||||||
"reasoning": reasoning
|
|
||||||
}
|
|
||||||
|
|
||||||
return {"success": False, "error": response}
|
|
||||||
|
|
||||||
def build_prompt(self, user_input: str, intent: str) -> str:
|
|
||||||
"""构建提示词"""
|
|
||||||
base_prompt = "你是一个友善的AI助手。"
|
|
||||||
|
|
||||||
if intent == "question":
|
|
||||||
return f"{base_prompt}\n\n用户问题:{user_input}\n\n请提供准确、有用的回答:"
|
|
||||||
elif intent == "chat":
|
|
||||||
return f"{base_prompt}\n\n用户说:{user_input}\n\n请进行自然的对话:"
|
|
||||||
else:
|
|
||||||
return f"{base_prompt}\n\n用户输入:{user_input}\n\n请回复:"
|
|
||||||
```
|
|
||||||
|
|
||||||
### 3. 多模型对比
|
|
||||||
|
|
||||||
```python
|
|
||||||
async def compare_models(prompt: str):
|
|
||||||
"""使用多个模型生成内容并对比"""
|
|
||||||
models = llm_api.get_available_models()
|
|
||||||
results = {}
|
|
||||||
|
|
||||||
for model_name, model_config in models.items():
|
|
||||||
success, response, reasoning, actual_model = await llm_api.generate_with_model(
|
|
||||||
prompt=prompt,
|
|
||||||
model_config=model_config,
|
|
||||||
request_type="comparison.test"
|
|
||||||
)
|
|
||||||
|
|
||||||
results[model_name] = {
|
|
||||||
"success": success,
|
|
||||||
"response": response,
|
|
||||||
"model": actual_model,
|
|
||||||
"reasoning": reasoning
|
|
||||||
}
|
|
||||||
|
|
||||||
return results
|
|
||||||
```
|
|
||||||
|
|
||||||
### 4. 智能对话插件
|
|
||||||
|
|
||||||
```python
|
|
||||||
class ChatbotPlugin(BasePlugin):
|
|
||||||
async def handle_action(self, action_data, chat_stream):
|
|
||||||
user_message = action_data.get("message", "")
|
|
||||||
|
|
||||||
# 获取历史对话上下文
|
|
||||||
context = self.get_conversation_context(chat_stream)
|
|
||||||
|
|
||||||
# 构建对话提示词
|
|
||||||
prompt = self.build_conversation_prompt(user_message, context)
|
|
||||||
|
|
||||||
# 获取模型配置
|
|
||||||
models = llm_api.get_available_models()
|
|
||||||
chat_model = models.get("chat", models.get("default"))
|
|
||||||
|
|
||||||
if not chat_model:
|
|
||||||
return {"success": False, "message": "聊天模型未配置"}
|
|
||||||
|
|
||||||
# 生成回复
|
|
||||||
success, response, reasoning, model_name = await llm_api.generate_with_model(
|
|
||||||
prompt=prompt,
|
|
||||||
model_config=chat_model,
|
|
||||||
request_type="chat.conversation",
|
|
||||||
temperature=0.8,
|
|
||||||
max_tokens=500
|
|
||||||
)
|
|
||||||
|
|
||||||
if success:
|
|
||||||
# 保存对话历史
|
|
||||||
self.save_conversation(chat_stream, user_message, response)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"success": True,
|
|
||||||
"reply": response,
|
|
||||||
"model": model_name
|
|
||||||
}
|
|
||||||
|
|
||||||
return {"success": False, "message": "回复生成失败"}
|
|
||||||
|
|
||||||
def build_conversation_prompt(self, user_message: str, context: list) -> str:
|
|
||||||
"""构建对话提示词"""
|
|
||||||
prompt = "你是一个有趣、友善的聊天机器人。请自然地回复用户的消息。\n\n"
|
|
||||||
|
|
||||||
# 添加历史对话
|
|
||||||
if context:
|
|
||||||
prompt += "对话历史:\n"
|
|
||||||
for msg in context[-5:]: # 只保留最近5条
|
|
||||||
prompt += f"用户: {msg['user']}\n机器人: {msg['bot']}\n"
|
|
||||||
prompt += "\n"
|
|
||||||
|
|
||||||
prompt += f"用户: {user_message}\n机器人: "
|
|
||||||
return prompt
|
|
||||||
```
|
|
||||||
|
|
||||||
## 模型配置说明
|
|
||||||
|
|
||||||
### 常用模型类型
|
|
||||||
- `default`:默认模型
|
|
||||||
- `chat`:聊天专用模型
|
|
||||||
- `creative`:创意生成模型
|
|
||||||
- `code`:代码生成模型
|
|
||||||
|
|
||||||
### 配置参数
|
|
||||||
LLM模型支持的常用参数:
|
|
||||||
- `temperature`:控制输出随机性(0.0-1.0)
|
|
||||||
- `max_tokens`:最大生成长度
|
|
||||||
- `top_p`:核采样参数
|
|
||||||
- `frequency_penalty`:频率惩罚
|
|
||||||
- `presence_penalty`:存在惩罚
|
|
||||||
|
|
||||||
## 注意事项
|
|
||||||
|
|
||||||
1. **异步操作**:LLM生成是异步的,必须使用`await`
|
|
||||||
2. **错误处理**:生成失败时返回False和错误信息
|
|
||||||
3. **配置依赖**:需要正确配置模型才能使用
|
|
||||||
4. **请求类型**:建议为不同用途设置不同的request_type
|
|
||||||
5. **性能考虑**:LLM调用可能较慢,考虑超时和缓存
|
|
||||||
6. **成本控制**:注意控制max_tokens以控制成本
|
|
||||||
29
docs/plugins/api/logging-api.md
Normal file
29
docs/plugins/api/logging-api.md
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
# Logging API
|
||||||
|
|
||||||
|
Logging API模块提供了获取本体logger的功能,允许插件记录日志信息。
|
||||||
|
|
||||||
|
## 导入方式
|
||||||
|
|
||||||
|
```python
|
||||||
|
from src.plugin_system.apis import get_logger
|
||||||
|
# 或者
|
||||||
|
from src.plugin_system import get_logger
|
||||||
|
```
|
||||||
|
|
||||||
|
## 主要功能
|
||||||
|
### 1. 获取本体logger
|
||||||
|
```python
|
||||||
|
def get_logger(name: str) -> structlog.stdlib.BoundLogger:
|
||||||
|
```
|
||||||
|
获取本体logger实例。
|
||||||
|
|
||||||
|
**Args:**
|
||||||
|
- `name` (str): 日志记录器的名称。
|
||||||
|
|
||||||
|
**Returns:**
|
||||||
|
- 一个logger实例,有以下方法:
|
||||||
|
- `debug`
|
||||||
|
- `info`
|
||||||
|
- `warning`
|
||||||
|
- `error`
|
||||||
|
- `critical`
|
||||||
@@ -1,11 +1,13 @@
|
|||||||
# 消息API
|
# 消息API
|
||||||
|
|
||||||
> 消息API提供了强大的消息查询、计数和格式化功能,让你轻松处理聊天消息数据。
|
消息API提供了强大的消息查询、计数和格式化功能,让你轻松处理聊天消息数据。
|
||||||
|
|
||||||
## 导入方式
|
## 导入方式
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from src.plugin_system.apis import message_api
|
from src.plugin_system.apis import message_api
|
||||||
|
# 或者
|
||||||
|
from src.plugin_system import message_api
|
||||||
```
|
```
|
||||||
|
|
||||||
## 功能概述
|
## 功能概述
|
||||||
@@ -15,297 +17,356 @@ from src.plugin_system.apis import message_api
|
|||||||
- **消息计数** - 统计新消息数量
|
- **消息计数** - 统计新消息数量
|
||||||
- **消息格式化** - 将消息转换为可读格式
|
- **消息格式化** - 将消息转换为可读格式
|
||||||
|
|
||||||
---
|
## 主要功能
|
||||||
|
|
||||||
## 消息查询API
|
### 1. 按照时间查询消息
|
||||||
|
```python
|
||||||
|
def get_messages_by_time(
|
||||||
|
start_time: float, end_time: float, limit: int = 0, limit_mode: str = "latest", filter_mai: bool = False
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
|
```
|
||||||
|
获取指定时间范围内的消息。
|
||||||
|
|
||||||
### 按时间查询消息
|
**Args:**
|
||||||
|
|
||||||
#### `get_messages_by_time(start_time, end_time, limit=0, limit_mode="latest")`
|
|
||||||
|
|
||||||
获取指定时间范围内的消息
|
|
||||||
|
|
||||||
**参数:**
|
|
||||||
- `start_time` (float): 开始时间戳
|
- `start_time` (float): 开始时间戳
|
||||||
- `end_time` (float): 结束时间戳
|
- `end_time` (float): 结束时间戳
|
||||||
- `limit` (int): 限制返回消息数量,0为不限制
|
- `limit` (int): 限制返回消息数量,0为不限制
|
||||||
- `limit_mode` (str): 限制模式,`"earliest"`获取最早记录,`"latest"`获取最新记录
|
- `limit_mode` (str): 限制模式,`"earliest"`获取最早记录,`"latest"`获取最新记录
|
||||||
|
- `filter_mai` (bool): 是否过滤掉机器人的消息,默认False
|
||||||
|
|
||||||
**返回:** `List[Dict[str, Any]]` - 消息列表
|
**Returns:**
|
||||||
|
- `List[Dict[str, Any]]` - 消息列表
|
||||||
|
|
||||||
**示例:**
|
消息列表中包含的键与`Messages`类的属性一致。(位于`src.common.database.database_model`)
|
||||||
|
|
||||||
|
### 2. 获取指定聊天中指定时间范围内的消息
|
||||||
```python
|
```python
|
||||||
import time
|
def get_messages_by_time_in_chat(
|
||||||
|
chat_id: str,
|
||||||
# 获取最近24小时的消息
|
start_time: float,
|
||||||
now = time.time()
|
end_time: float,
|
||||||
yesterday = now - 24 * 3600
|
limit: int = 0,
|
||||||
messages = message_api.get_messages_by_time(yesterday, now, limit=50)
|
limit_mode: str = "latest",
|
||||||
|
filter_mai: bool = False,
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
```
|
```
|
||||||
|
获取指定聊天中指定时间范围内的消息。
|
||||||
|
|
||||||
### 按聊天查询消息
|
**Args:**
|
||||||
|
|
||||||
#### `get_messages_by_time_in_chat(chat_id, start_time, end_time, limit=0, limit_mode="latest")`
|
|
||||||
|
|
||||||
获取指定聊天中指定时间范围内的消息
|
|
||||||
|
|
||||||
**参数:**
|
|
||||||
- `chat_id` (str): 聊天ID
|
|
||||||
- 其他参数同上
|
|
||||||
|
|
||||||
**示例:**
|
|
||||||
```python
|
|
||||||
# 获取某个群聊最近的100条消息
|
|
||||||
messages = message_api.get_messages_by_time_in_chat(
|
|
||||||
chat_id="123456789",
|
|
||||||
start_time=yesterday,
|
|
||||||
end_time=now,
|
|
||||||
limit=100
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
#### `get_messages_by_time_in_chat_inclusive(chat_id, start_time, end_time, limit=0, limit_mode="latest")`
|
|
||||||
|
|
||||||
获取指定聊天中指定时间范围内的消息(包含边界时间点)
|
|
||||||
|
|
||||||
与 `get_messages_by_time_in_chat` 类似,但包含边界时间戳的消息。
|
|
||||||
|
|
||||||
#### `get_recent_messages(chat_id, hours=24.0, limit=100, limit_mode="latest")`
|
|
||||||
|
|
||||||
获取指定聊天中最近一段时间的消息(便捷方法)
|
|
||||||
|
|
||||||
**参数:**
|
|
||||||
- `chat_id` (str): 聊天ID
|
|
||||||
- `hours` (float): 最近多少小时,默认24小时
|
|
||||||
- `limit` (int): 限制返回消息数量,默认100条
|
|
||||||
- `limit_mode` (str): 限制模式
|
|
||||||
|
|
||||||
**示例:**
|
|
||||||
```python
|
|
||||||
# 获取最近6小时的消息
|
|
||||||
recent_messages = message_api.get_recent_messages(
|
|
||||||
chat_id="123456789",
|
|
||||||
hours=6.0,
|
|
||||||
limit=50
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
### 按用户查询消息
|
|
||||||
|
|
||||||
#### `get_messages_by_time_in_chat_for_users(chat_id, start_time, end_time, person_ids, limit=0, limit_mode="latest")`
|
|
||||||
|
|
||||||
获取指定聊天中指定用户在指定时间范围内的消息
|
|
||||||
|
|
||||||
**参数:**
|
|
||||||
- `chat_id` (str): 聊天ID
|
- `chat_id` (str): 聊天ID
|
||||||
- `start_time` (float): 开始时间戳
|
- `start_time` (float): 开始时间戳
|
||||||
- `end_time` (float): 结束时间戳
|
- `end_time` (float): 结束时间戳
|
||||||
- `person_ids` (list): 用户ID列表
|
- `limit` (int): 限制返回消息数量,0为不限制
|
||||||
- `limit` (int): 限制返回消息数量
|
- `limit_mode` (str): 限制模式,`"earliest"`获取最早记录,`"latest"`获取最新记录
|
||||||
- `limit_mode` (str): 限制模式
|
- `filter_mai` (bool): 是否过滤掉机器人的消息,默认False
|
||||||
|
|
||||||
**示例:**
|
**Returns:**
|
||||||
|
- `List[Dict[str, Any]]` - 消息列表
|
||||||
|
|
||||||
|
|
||||||
|
### 3. 获取指定聊天中指定时间范围内的消息(包含边界)
|
||||||
```python
|
```python
|
||||||
# 获取特定用户的消息
|
def get_messages_by_time_in_chat_inclusive(
|
||||||
user_messages = message_api.get_messages_by_time_in_chat_for_users(
|
chat_id: str,
|
||||||
chat_id="123456789",
|
start_time: float,
|
||||||
start_time=yesterday,
|
end_time: float,
|
||||||
end_time=now,
|
limit: int = 0,
|
||||||
person_ids=["user1", "user2"]
|
limit_mode: str = "latest",
|
||||||
)
|
filter_mai: bool = False,
|
||||||
|
filter_command: bool = False,
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
```
|
```
|
||||||
|
获取指定聊天中指定时间范围内的消息(包含边界)。
|
||||||
|
|
||||||
#### `get_messages_by_time_for_users(start_time, end_time, person_ids, limit=0, limit_mode="latest")`
|
**Args:**
|
||||||
|
- `chat_id` (str): 聊天ID
|
||||||
|
- `start_time` (float): 开始时间戳(包含)
|
||||||
|
- `end_time` (float): 结束时间戳(包含)
|
||||||
|
- `limit` (int): 限制返回消息数量,0为不限制
|
||||||
|
- `limit_mode` (str): 限制模式,`"earliest"`获取最早记录,`"latest"`获取最新记录
|
||||||
|
- `filter_mai` (bool): 是否过滤掉机器人的消息,默认False
|
||||||
|
- `filter_command` (bool): 是否过滤命令消息,默认False
|
||||||
|
|
||||||
获取指定用户在所有聊天中指定时间范围内的消息
|
**Returns:**
|
||||||
|
- `List[Dict[str, Any]]` - 消息列表
|
||||||
|
|
||||||
### 其他查询方法
|
|
||||||
|
|
||||||
#### `get_random_chat_messages(start_time, end_time, limit=0, limit_mode="latest")`
|
### 4. 获取指定聊天中指定用户在指定时间范围内的消息
|
||||||
|
```python
|
||||||
|
def get_messages_by_time_in_chat_for_users(
|
||||||
|
chat_id: str,
|
||||||
|
start_time: float,
|
||||||
|
end_time: float,
|
||||||
|
person_ids: List[str],
|
||||||
|
limit: int = 0,
|
||||||
|
limit_mode: str = "latest",
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
|
```
|
||||||
|
获取指定聊天中指定用户在指定时间范围内的消息。
|
||||||
|
|
||||||
随机选择一个聊天,返回该聊天在指定时间范围内的消息
|
**Args:**
|
||||||
|
|
||||||
#### `get_messages_before_time(timestamp, limit=0)`
|
|
||||||
|
|
||||||
获取指定时间戳之前的消息
|
|
||||||
|
|
||||||
#### `get_messages_before_time_in_chat(chat_id, timestamp, limit=0)`
|
|
||||||
|
|
||||||
获取指定聊天中指定时间戳之前的消息
|
|
||||||
|
|
||||||
#### `get_messages_before_time_for_users(timestamp, person_ids, limit=0)`
|
|
||||||
|
|
||||||
获取指定用户在指定时间戳之前的消息
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 消息计数API
|
|
||||||
|
|
||||||
### `count_new_messages(chat_id, start_time=0.0, end_time=None)`
|
|
||||||
|
|
||||||
计算指定聊天中从开始时间到结束时间的新消息数量
|
|
||||||
|
|
||||||
**参数:**
|
|
||||||
- `chat_id` (str): 聊天ID
|
- `chat_id` (str): 聊天ID
|
||||||
- `start_time` (float): 开始时间戳
|
- `start_time` (float): 开始时间戳
|
||||||
- `end_time` (float): 结束时间戳,如果为None则使用当前时间
|
- `end_time` (float): 结束时间戳
|
||||||
|
- `person_ids` (List[str]): 用户ID列表
|
||||||
|
- `limit` (int): 限制返回消息数量,0为不限制
|
||||||
|
- `limit_mode` (str): 限制模式,`"earliest"`获取最早记录,`"latest"`获取最新记录
|
||||||
|
|
||||||
**返回:** `int` - 新消息数量
|
**Returns:**
|
||||||
|
- `List[Dict[str, Any]]` - 消息列表
|
||||||
|
|
||||||
**示例:**
|
|
||||||
|
### 5. 随机选择一个聊天,返回该聊天在指定时间范围内的消息
|
||||||
```python
|
```python
|
||||||
# 计算最近1小时的新消息数
|
def get_random_chat_messages(
|
||||||
import time
|
start_time: float,
|
||||||
now = time.time()
|
end_time: float,
|
||||||
hour_ago = now - 3600
|
limit: int = 0,
|
||||||
new_count = message_api.count_new_messages("123456789", hour_ago, now)
|
limit_mode: str = "latest",
|
||||||
print(f"最近1小时有{new_count}条新消息")
|
filter_mai: bool = False,
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
```
|
```
|
||||||
|
随机选择一个聊天,返回该聊天在指定时间范围内的消息。
|
||||||
|
|
||||||
### `count_new_messages_for_users(chat_id, start_time, end_time, person_ids)`
|
**Args:**
|
||||||
|
- `start_time` (float): 开始时间戳
|
||||||
|
- `end_time` (float): 结束时间戳
|
||||||
|
- `limit` (int): 限制返回消息数量,0为不限制
|
||||||
|
- `limit_mode` (str): 限制模式,`"earliest"`获取最早记录,`"latest"`获取最新记录
|
||||||
|
- `filter_mai` (bool): 是否过滤掉机器人的消息,默认False
|
||||||
|
|
||||||
计算指定聊天中指定用户从开始时间到结束时间的新消息数量
|
**Returns:**
|
||||||
|
- `List[Dict[str, Any]]` - 消息列表
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 消息格式化API
|
### 6. 获取指定用户在所有聊天中指定时间范围内的消息
|
||||||
|
```python
|
||||||
|
def get_messages_by_time_for_users(
|
||||||
|
start_time: float,
|
||||||
|
end_time: float,
|
||||||
|
person_ids: List[str],
|
||||||
|
limit: int = 0,
|
||||||
|
limit_mode: str = "latest",
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
|
```
|
||||||
|
获取指定用户在所有聊天中指定时间范围内的消息。
|
||||||
|
|
||||||
### `build_readable_messages_to_str(messages, **options)`
|
**Args:**
|
||||||
|
- `start_time` (float): 开始时间戳
|
||||||
|
- `end_time` (float): 结束时间戳
|
||||||
|
- `person_ids` (List[str]): 用户ID列表
|
||||||
|
- `limit` (int): 限制返回消息数量,0为不限制
|
||||||
|
- `limit_mode` (str): 限制模式,`"earliest"`获取最早记录,`"latest"`获取最新记录
|
||||||
|
|
||||||
将消息列表构建成可读的字符串
|
**Returns:**
|
||||||
|
- `List[Dict[str, Any]]` - 消息列表
|
||||||
|
|
||||||
**参数:**
|
|
||||||
|
### 7. 获取指定时间戳之前的消息
|
||||||
|
```python
|
||||||
|
def get_messages_before_time(
|
||||||
|
timestamp: float,
|
||||||
|
limit: int = 0,
|
||||||
|
filter_mai: bool = False,
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
|
```
|
||||||
|
获取指定时间戳之前的消息。
|
||||||
|
|
||||||
|
**Args:**
|
||||||
|
- `timestamp` (float): 时间戳
|
||||||
|
- `limit` (int): 限制返回消息数量,0为不限制
|
||||||
|
- `filter_mai` (bool): 是否过滤掉机器人的消息,默认False
|
||||||
|
|
||||||
|
**Returns:**
|
||||||
|
- `List[Dict[str, Any]]` - 消息列表
|
||||||
|
|
||||||
|
|
||||||
|
### 8. 获取指定聊天中指定时间戳之前的消息
|
||||||
|
```python
|
||||||
|
def get_messages_before_time_in_chat(
|
||||||
|
chat_id: str,
|
||||||
|
timestamp: float,
|
||||||
|
limit: int = 0,
|
||||||
|
filter_mai: bool = False,
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
|
```
|
||||||
|
获取指定聊天中指定时间戳之前的消息。
|
||||||
|
|
||||||
|
**Args:**
|
||||||
|
- `chat_id` (str): 聊天ID
|
||||||
|
- `timestamp` (float): 时间戳
|
||||||
|
- `limit` (int): 限制返回消息数量,0为不限制
|
||||||
|
- `filter_mai` (bool): 是否过滤掉机器人的消息,默认False
|
||||||
|
|
||||||
|
**Returns:**
|
||||||
|
- `List[Dict[str, Any]]` - 消息列表
|
||||||
|
|
||||||
|
|
||||||
|
### 9. 获取指定用户在指定时间戳之前的消息
|
||||||
|
```python
|
||||||
|
def get_messages_before_time_for_users(
|
||||||
|
timestamp: float,
|
||||||
|
person_ids: List[str],
|
||||||
|
limit: int = 0,
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
|
```
|
||||||
|
获取指定用户在指定时间戳之前的消息。
|
||||||
|
|
||||||
|
**Args:**
|
||||||
|
- `timestamp` (float): 时间戳
|
||||||
|
- `person_ids` (List[str]): 用户ID列表
|
||||||
|
- `limit` (int): 限制返回消息数量,0为不限制
|
||||||
|
|
||||||
|
**Returns:**
|
||||||
|
- `List[Dict[str, Any]]` - 消息列表
|
||||||
|
|
||||||
|
|
||||||
|
### 10. 获取指定聊天中最近一段时间的消息
|
||||||
|
```python
|
||||||
|
def get_recent_messages(
|
||||||
|
chat_id: str,
|
||||||
|
hours: float = 24.0,
|
||||||
|
limit: int = 100,
|
||||||
|
limit_mode: str = "latest",
|
||||||
|
filter_mai: bool = False,
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
|
```
|
||||||
|
获取指定聊天中最近一段时间的消息。
|
||||||
|
|
||||||
|
**Args:**
|
||||||
|
- `chat_id` (str): 聊天ID
|
||||||
|
- `hours` (float): 最近多少小时,默认24小时
|
||||||
|
- `limit` (int): 限制返回消息数量,默认100条
|
||||||
|
- `limit_mode` (str): 限制模式,`"earliest"`获取最早记录,`"latest"`获取最新记录
|
||||||
|
- `filter_mai` (bool): 是否过滤掉机器人的消息,默认False
|
||||||
|
|
||||||
|
**Returns:**
|
||||||
|
- `List[Dict[str, Any]]` - 消息列表
|
||||||
|
|
||||||
|
|
||||||
|
### 11. 计算指定聊天中从开始时间到结束时间的新消息数量
|
||||||
|
```python
|
||||||
|
def count_new_messages(
|
||||||
|
chat_id: str,
|
||||||
|
start_time: float = 0.0,
|
||||||
|
end_time: Optional[float] = None,
|
||||||
|
) -> int:
|
||||||
|
```
|
||||||
|
计算指定聊天中从开始时间到结束时间的新消息数量。
|
||||||
|
|
||||||
|
**Args:**
|
||||||
|
- `chat_id` (str): 聊天ID
|
||||||
|
- `start_time` (float): 开始时间戳
|
||||||
|
- `end_time` (Optional[float]): 结束时间戳,如果为None则使用当前时间
|
||||||
|
|
||||||
|
**Returns:**
|
||||||
|
- `int` - 新消息数量
|
||||||
|
|
||||||
|
|
||||||
|
### 12. 计算指定聊天中指定用户从开始时间到结束时间的新消息数量
|
||||||
|
```python
|
||||||
|
def count_new_messages_for_users(
|
||||||
|
chat_id: str,
|
||||||
|
start_time: float,
|
||||||
|
end_time: float,
|
||||||
|
person_ids: List[str],
|
||||||
|
) -> int:
|
||||||
|
```
|
||||||
|
计算指定聊天中指定用户从开始时间到结束时间的新消息数量。
|
||||||
|
|
||||||
|
**Args:**
|
||||||
|
- `chat_id` (str): 聊天ID
|
||||||
|
- `start_time` (float): 开始时间戳
|
||||||
|
- `end_time` (float): 结束时间戳
|
||||||
|
- `person_ids` (List[str]): 用户ID列表
|
||||||
|
|
||||||
|
**Returns:**
|
||||||
|
- `int` - 新消息数量
|
||||||
|
|
||||||
|
|
||||||
|
### 13. 将消息列表构建成可读的字符串
|
||||||
|
```python
|
||||||
|
def build_readable_messages_to_str(
|
||||||
|
messages: List[Dict[str, Any]],
|
||||||
|
replace_bot_name: bool = True,
|
||||||
|
merge_messages: bool = False,
|
||||||
|
timestamp_mode: str = "relative",
|
||||||
|
read_mark: float = 0.0,
|
||||||
|
truncate: bool = False,
|
||||||
|
show_actions: bool = False,
|
||||||
|
) -> str:
|
||||||
|
```
|
||||||
|
将消息列表构建成可读的字符串。
|
||||||
|
|
||||||
|
**Args:**
|
||||||
- `messages` (List[Dict[str, Any]]): 消息列表
|
- `messages` (List[Dict[str, Any]]): 消息列表
|
||||||
- `replace_bot_name` (bool): 是否将机器人的名称替换为"你",默认True
|
- `replace_bot_name` (bool): 是否将机器人的名称替换为"你"
|
||||||
- `merge_messages` (bool): 是否合并连续消息,默认False
|
- `merge_messages` (bool): 是否合并连续消息
|
||||||
- `timestamp_mode` (str): 时间戳显示模式,`"relative"`或`"absolute"`,默认`"relative"`
|
- `timestamp_mode` (str): 时间戳显示模式,`"relative"`或`"absolute"`
|
||||||
- `read_mark` (float): 已读标记时间戳,用于分割已读和未读消息,默认0.0
|
- `read_mark` (float): 已读标记时间戳,用于分割已读和未读消息
|
||||||
- `truncate` (bool): 是否截断长消息,默认False
|
- `truncate` (bool): 是否截断长消息
|
||||||
- `show_actions` (bool): 是否显示动作记录,默认False
|
- `show_actions` (bool): 是否显示动作记录
|
||||||
|
|
||||||
**返回:** `str` - 格式化后的可读字符串
|
**Returns:**
|
||||||
|
- `str` - 格式化后的可读字符串
|
||||||
|
|
||||||
**示例:**
|
|
||||||
|
### 14. 将消息列表构建成可读的字符串,并返回详细信息
|
||||||
```python
|
```python
|
||||||
# 获取消息并格式化为可读文本
|
async def build_readable_messages_with_details(
|
||||||
messages = message_api.get_recent_messages("123456789", hours=2)
|
messages: List[Dict[str, Any]],
|
||||||
readable_text = message_api.build_readable_messages_to_str(
|
replace_bot_name: bool = True,
|
||||||
messages,
|
merge_messages: bool = False,
|
||||||
replace_bot_name=True,
|
timestamp_mode: str = "relative",
|
||||||
merge_messages=True,
|
truncate: bool = False,
|
||||||
timestamp_mode="relative"
|
) -> Tuple[str, List[Tuple[float, str, str]]]:
|
||||||
)
|
|
||||||
print(readable_text)
|
|
||||||
```
|
```
|
||||||
|
将消息列表构建成可读的字符串,并返回详细信息。
|
||||||
|
|
||||||
### `build_readable_messages_with_details(messages, **options)` 异步
|
**Args:**
|
||||||
|
- `messages` (List[Dict[str, Any]]): 消息列表
|
||||||
|
- `replace_bot_name` (bool): 是否将机器人的名称替换为"你"
|
||||||
|
- `merge_messages` (bool): 是否合并连续消息
|
||||||
|
- `timestamp_mode` (str): 时间戳显示模式,`"relative"`或`"absolute"`
|
||||||
|
- `truncate` (bool): 是否截断长消息
|
||||||
|
|
||||||
将消息列表构建成可读的字符串,并返回详细信息
|
**Returns:**
|
||||||
|
- `Tuple[str, List[Tuple[float, str, str]]]` - 格式化后的可读字符串和详细信息元组列表(时间戳, 昵称, 内容)
|
||||||
|
|
||||||
**参数:** 与 `build_readable_messages_to_str` 类似,但不包含 `read_mark` 和 `show_actions`
|
|
||||||
|
|
||||||
**返回:** `Tuple[str, List[Tuple[float, str, str]]]` - 格式化字符串和详细信息元组列表(时间戳, 昵称, 内容)
|
### 15. 从消息列表中提取不重复的用户ID列表
|
||||||
|
|
||||||
**示例:**
|
|
||||||
```python
|
```python
|
||||||
# 异步获取详细格式化信息
|
async def get_person_ids_from_messages(
|
||||||
readable_text, details = await message_api.build_readable_messages_with_details(
|
messages: List[Dict[str, Any]],
|
||||||
messages,
|
) -> List[str]:
|
||||||
timestamp_mode="absolute"
|
|
||||||
)
|
|
||||||
|
|
||||||
for timestamp, nickname, content in details:
|
|
||||||
print(f"{timestamp}: {nickname} 说: {content}")
|
|
||||||
```
|
```
|
||||||
|
从消息列表中提取不重复的用户ID列表。
|
||||||
|
|
||||||
### `get_person_ids_from_messages(messages)` 异步
|
**Args:**
|
||||||
|
|
||||||
从消息列表中提取不重复的用户ID列表
|
|
||||||
|
|
||||||
**参数:**
|
|
||||||
- `messages` (List[Dict[str, Any]]): 消息列表
|
- `messages` (List[Dict[str, Any]]): 消息列表
|
||||||
|
|
||||||
**返回:** `List[str]` - 用户ID列表
|
**Returns:**
|
||||||
|
- `List[str]` - 用户ID列表
|
||||||
|
|
||||||
**示例:**
|
|
||||||
|
### 16. 从消息列表中移除机器人的消息
|
||||||
```python
|
```python
|
||||||
# 获取参与对话的所有用户ID
|
def filter_mai_messages(
|
||||||
messages = message_api.get_recent_messages("123456789")
|
messages: List[Dict[str, Any]],
|
||||||
person_ids = await message_api.get_person_ids_from_messages(messages)
|
) -> List[Dict[str, Any]]:
|
||||||
print(f"参与对话的用户: {person_ids}")
|
|
||||||
```
|
```
|
||||||
|
从消息列表中移除机器人的消息。
|
||||||
|
|
||||||
---
|
**Args:**
|
||||||
|
- `messages` (List[Dict[str, Any]]): 消息列表,每个元素是消息字典
|
||||||
|
|
||||||
## 完整使用示例
|
**Returns:**
|
||||||
|
- `List[Dict[str, Any]]` - 过滤后的消息列表
|
||||||
### 场景1:统计活跃度
|
|
||||||
|
|
||||||
```python
|
|
||||||
import time
|
|
||||||
from src.plugin_system.apis import message_api
|
|
||||||
|
|
||||||
async def analyze_chat_activity(chat_id: str):
|
|
||||||
"""分析聊天活跃度"""
|
|
||||||
now = time.time()
|
|
||||||
day_ago = now - 24 * 3600
|
|
||||||
|
|
||||||
# 获取最近24小时的消息
|
|
||||||
messages = message_api.get_recent_messages(chat_id, hours=24)
|
|
||||||
|
|
||||||
# 统计消息数量
|
|
||||||
total_count = len(messages)
|
|
||||||
|
|
||||||
# 获取参与用户
|
|
||||||
person_ids = await message_api.get_person_ids_from_messages(messages)
|
|
||||||
|
|
||||||
# 格式化消息内容
|
|
||||||
readable_text = message_api.build_readable_messages_to_str(
|
|
||||||
messages[-10:], # 最后10条消息
|
|
||||||
merge_messages=True,
|
|
||||||
timestamp_mode="relative"
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"total_messages": total_count,
|
|
||||||
"active_users": len(person_ids),
|
|
||||||
"recent_chat": readable_text
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
### 场景2:查看特定用户的历史消息
|
|
||||||
|
|
||||||
```python
|
|
||||||
def get_user_history(chat_id: str, user_id: str, days: int = 7):
|
|
||||||
"""获取用户最近N天的消息历史"""
|
|
||||||
now = time.time()
|
|
||||||
start_time = now - days * 24 * 3600
|
|
||||||
|
|
||||||
# 获取特定用户的消息
|
|
||||||
user_messages = message_api.get_messages_by_time_in_chat_for_users(
|
|
||||||
chat_id=chat_id,
|
|
||||||
start_time=start_time,
|
|
||||||
end_time=now,
|
|
||||||
person_ids=[user_id],
|
|
||||||
limit=100
|
|
||||||
)
|
|
||||||
|
|
||||||
# 格式化为可读文本
|
|
||||||
readable_history = message_api.build_readable_messages_to_str(
|
|
||||||
user_messages,
|
|
||||||
replace_bot_name=False,
|
|
||||||
timestamp_mode="absolute"
|
|
||||||
)
|
|
||||||
|
|
||||||
return readable_history
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 注意事项
|
## 注意事项
|
||||||
|
|
||||||
1. **时间戳格式**:所有时间参数都使用Unix时间戳(float类型)
|
1. **时间戳格式**:所有时间参数都使用Unix时间戳(float类型)
|
||||||
2. **异步函数**:`build_readable_messages_with_details` 和 `get_person_ids_from_messages` 是异步函数,需要使用 `await`
|
2. **异步函数**:部分函数是异步函数,需要使用 `await`
|
||||||
3. **性能考虑**:查询大量消息时建议设置合理的 `limit` 参数
|
3. **性能考虑**:查询大量消息时建议设置合理的 `limit` 参数
|
||||||
4. **消息格式**:返回的消息是字典格式,包含时间戳、发送者、内容等信息
|
4. **消息格式**:返回的消息是字典格式,包含时间戳、发送者、内容等信息
|
||||||
5. **用户ID**:`person_ids` 参数接受字符串列表,用于筛选特定用户的消息
|
5. **用户ID**:`person_ids` 参数接受字符串列表,用于筛选特定用户的消息
|
||||||
@@ -6,59 +6,65 @@
|
|||||||
|
|
||||||
```python
|
```python
|
||||||
from src.plugin_system.apis import person_api
|
from src.plugin_system.apis import person_api
|
||||||
|
# 或者
|
||||||
|
from src.plugin_system import person_api
|
||||||
```
|
```
|
||||||
|
|
||||||
## 主要功能
|
## 主要功能
|
||||||
|
|
||||||
### 1. Person ID管理
|
### 1. Person ID 获取
|
||||||
|
```python
|
||||||
#### `get_person_id(platform: str, user_id: int) -> str`
|
def get_person_id(platform: str, user_id: int) -> str:
|
||||||
|
```
|
||||||
根据平台和用户ID获取person_id
|
根据平台和用户ID获取person_id
|
||||||
|
|
||||||
**参数:**
|
**Args:**
|
||||||
- `platform`:平台名称,如 "qq", "telegram" 等
|
- `platform`:平台名称,如 "qq", "telegram" 等
|
||||||
- `user_id`:用户ID
|
- `user_id`:用户ID
|
||||||
|
|
||||||
**返回:**
|
**Returns:**
|
||||||
- `str`:唯一的person_id(MD5哈希值)
|
- `str`:唯一的person_id(MD5哈希值)
|
||||||
|
|
||||||
**示例:**
|
#### 示例
|
||||||
```python
|
```python
|
||||||
person_id = person_api.get_person_id("qq", 123456)
|
person_id = person_api.get_person_id("qq", 123456)
|
||||||
print(f"Person ID: {person_id}")
|
|
||||||
```
|
```
|
||||||
|
|
||||||
### 2. 用户信息查询
|
### 2. 用户信息查询
|
||||||
|
```python
|
||||||
|
async def get_person_value(person_id: str, field_name: str, default: Any = None) -> Any:
|
||||||
|
```
|
||||||
|
查询单个用户信息字段值
|
||||||
|
|
||||||
#### `get_person_value(person_id: str, field_name: str, default: Any = None) -> Any`
|
**Args:**
|
||||||
根据person_id和字段名获取某个值
|
|
||||||
|
|
||||||
**参数:**
|
|
||||||
- `person_id`:用户的唯一标识ID
|
- `person_id`:用户的唯一标识ID
|
||||||
- `field_name`:要获取的字段名,如 "nickname", "impression" 等
|
- `field_name`:要获取的字段名
|
||||||
- `default`:当字段不存在或获取失败时返回的默认值
|
- `default`:字段值不存在时的默认值
|
||||||
|
|
||||||
**返回:**
|
**Returns:**
|
||||||
- `Any`:字段值或默认值
|
- `Any`:字段值或默认值
|
||||||
|
|
||||||
**示例:**
|
#### 示例
|
||||||
```python
|
```python
|
||||||
nickname = await person_api.get_person_value(person_id, "nickname", "未知用户")
|
nickname = await person_api.get_person_value(person_id, "nickname", "未知用户")
|
||||||
impression = await person_api.get_person_value(person_id, "impression")
|
impression = await person_api.get_person_value(person_id, "impression")
|
||||||
```
|
```
|
||||||
|
|
||||||
#### `get_person_values(person_id: str, field_names: list, default_dict: dict = None) -> dict`
|
### 3. 批量用户信息查询
|
||||||
|
```python
|
||||||
|
async def get_person_values(person_id: str, field_names: list, default_dict: Optional[dict] = None) -> dict:
|
||||||
|
```
|
||||||
批量获取用户信息字段值
|
批量获取用户信息字段值
|
||||||
|
|
||||||
**参数:**
|
**Args:**
|
||||||
- `person_id`:用户的唯一标识ID
|
- `person_id`:用户的唯一标识ID
|
||||||
- `field_names`:要获取的字段名列表
|
- `field_names`:要获取的字段名列表
|
||||||
- `default_dict`:默认值字典,键为字段名,值为默认值
|
- `default_dict`:默认值字典,键为字段名,值为默认值
|
||||||
|
|
||||||
**返回:**
|
**Returns:**
|
||||||
- `dict`:字段名到值的映射字典
|
- `dict`:字段名到值的映射字典
|
||||||
|
|
||||||
**示例:**
|
#### 示例
|
||||||
```python
|
```python
|
||||||
values = await person_api.get_person_values(
|
values = await person_api.get_person_values(
|
||||||
person_id,
|
person_id,
|
||||||
@@ -67,204 +73,31 @@ values = await person_api.get_person_values(
|
|||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
### 3. 用户状态查询
|
### 4. 判断用户是否已知
|
||||||
|
```python
|
||||||
#### `is_person_known(platform: str, user_id: int) -> bool`
|
async def is_person_known(platform: str, user_id: int) -> bool:
|
||||||
|
```
|
||||||
判断是否认识某个用户
|
判断是否认识某个用户
|
||||||
|
|
||||||
**参数:**
|
**Args:**
|
||||||
- `platform`:平台名称
|
- `platform`:平台名称
|
||||||
- `user_id`:用户ID
|
- `user_id`:用户ID
|
||||||
|
|
||||||
**返回:**
|
**Returns:**
|
||||||
- `bool`:是否认识该用户
|
- `bool`:是否认识该用户
|
||||||
|
|
||||||
**示例:**
|
### 5. 根据用户名获取Person ID
|
||||||
```python
|
```python
|
||||||
known = await person_api.is_person_known("qq", 123456)
|
def get_person_id_by_name(person_name: str) -> str:
|
||||||
if known:
|
|
||||||
print("这个用户我认识")
|
|
||||||
```
|
```
|
||||||
|
|
||||||
### 4. 用户名查询
|
|
||||||
|
|
||||||
#### `get_person_id_by_name(person_name: str) -> str`
|
|
||||||
根据用户名获取person_id
|
根据用户名获取person_id
|
||||||
|
|
||||||
**参数:**
|
**Args:**
|
||||||
- `person_name`:用户名
|
- `person_name`:用户名
|
||||||
|
|
||||||
**返回:**
|
**Returns:**
|
||||||
- `str`:person_id,如果未找到返回空字符串
|
- `str`:person_id,如果未找到返回空字符串
|
||||||
|
|
||||||
**示例:**
|
|
||||||
```python
|
|
||||||
person_id = person_api.get_person_id_by_name("张三")
|
|
||||||
if person_id:
|
|
||||||
print(f"找到用户: {person_id}")
|
|
||||||
```
|
|
||||||
|
|
||||||
## 使用示例
|
|
||||||
|
|
||||||
### 1. 基础用户信息获取
|
|
||||||
|
|
||||||
```python
|
|
||||||
from src.plugin_system.apis import person_api
|
|
||||||
|
|
||||||
async def get_user_info(platform: str, user_id: int):
|
|
||||||
"""获取用户基本信息"""
|
|
||||||
|
|
||||||
# 获取person_id
|
|
||||||
person_id = person_api.get_person_id(platform, user_id)
|
|
||||||
|
|
||||||
# 获取用户信息
|
|
||||||
user_info = await person_api.get_person_values(
|
|
||||||
person_id,
|
|
||||||
["nickname", "impression", "know_times", "last_seen"],
|
|
||||||
{
|
|
||||||
"nickname": "未知用户",
|
|
||||||
"impression": "",
|
|
||||||
"know_times": 0,
|
|
||||||
"last_seen": 0
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"person_id": person_id,
|
|
||||||
"nickname": user_info["nickname"],
|
|
||||||
"impression": user_info["impression"],
|
|
||||||
"know_times": user_info["know_times"],
|
|
||||||
"last_seen": user_info["last_seen"]
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
### 2. 在Action中使用用户信息
|
|
||||||
|
|
||||||
```python
|
|
||||||
from src.plugin_system.base import BaseAction
|
|
||||||
|
|
||||||
class PersonalizedAction(BaseAction):
|
|
||||||
async def execute(self, action_data, chat_stream):
|
|
||||||
# 获取发送者信息
|
|
||||||
user_id = chat_stream.user_info.user_id
|
|
||||||
platform = chat_stream.platform
|
|
||||||
|
|
||||||
# 获取person_id
|
|
||||||
person_id = person_api.get_person_id(platform, user_id)
|
|
||||||
|
|
||||||
# 获取用户昵称和印象
|
|
||||||
nickname = await person_api.get_person_value(person_id, "nickname", "朋友")
|
|
||||||
impression = await person_api.get_person_value(person_id, "impression", "")
|
|
||||||
|
|
||||||
# 根据用户信息个性化回复
|
|
||||||
if impression:
|
|
||||||
response = f"你好 {nickname}!根据我对你的了解:{impression}"
|
|
||||||
else:
|
|
||||||
response = f"你好 {nickname}!很高兴见到你。"
|
|
||||||
|
|
||||||
return {
|
|
||||||
"success": True,
|
|
||||||
"response": response,
|
|
||||||
"user_info": {
|
|
||||||
"nickname": nickname,
|
|
||||||
"impression": impression
|
|
||||||
}
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
### 3. 用户识别和欢迎
|
|
||||||
|
|
||||||
```python
|
|
||||||
async def welcome_user(chat_stream):
|
|
||||||
"""欢迎用户,区分新老用户"""
|
|
||||||
|
|
||||||
user_id = chat_stream.user_info.user_id
|
|
||||||
platform = chat_stream.platform
|
|
||||||
|
|
||||||
# 检查是否认识这个用户
|
|
||||||
is_known = await person_api.is_person_known(platform, user_id)
|
|
||||||
|
|
||||||
if is_known:
|
|
||||||
# 老用户,获取详细信息
|
|
||||||
person_id = person_api.get_person_id(platform, user_id)
|
|
||||||
nickname = await person_api.get_person_value(person_id, "nickname", "老朋友")
|
|
||||||
know_times = await person_api.get_person_value(person_id, "know_times", 0)
|
|
||||||
|
|
||||||
welcome_msg = f"欢迎回来,{nickname}!我们已经聊过 {know_times} 次了。"
|
|
||||||
else:
|
|
||||||
# 新用户
|
|
||||||
welcome_msg = "你好!很高兴认识你,我是MaiBot。"
|
|
||||||
|
|
||||||
return welcome_msg
|
|
||||||
```
|
|
||||||
|
|
||||||
### 4. 用户搜索功能
|
|
||||||
|
|
||||||
```python
|
|
||||||
async def find_user_by_name(name: str):
|
|
||||||
"""根据名字查找用户"""
|
|
||||||
|
|
||||||
person_id = person_api.get_person_id_by_name(name)
|
|
||||||
|
|
||||||
if not person_id:
|
|
||||||
return {"found": False, "message": f"未找到名为 '{name}' 的用户"}
|
|
||||||
|
|
||||||
# 获取用户详细信息
|
|
||||||
user_info = await person_api.get_person_values(
|
|
||||||
person_id,
|
|
||||||
["nickname", "platform", "user_id", "impression", "know_times"],
|
|
||||||
{}
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"found": True,
|
|
||||||
"person_id": person_id,
|
|
||||||
"info": user_info
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
### 5. 用户印象分析
|
|
||||||
|
|
||||||
```python
|
|
||||||
async def analyze_user_relationship(chat_stream):
|
|
||||||
"""分析用户关系"""
|
|
||||||
|
|
||||||
user_id = chat_stream.user_info.user_id
|
|
||||||
platform = chat_stream.platform
|
|
||||||
person_id = person_api.get_person_id(platform, user_id)
|
|
||||||
|
|
||||||
# 获取关系相关信息
|
|
||||||
relationship_info = await person_api.get_person_values(
|
|
||||||
person_id,
|
|
||||||
["nickname", "impression", "know_times", "relationship_level", "last_interaction"],
|
|
||||||
{
|
|
||||||
"nickname": "未知",
|
|
||||||
"impression": "",
|
|
||||||
"know_times": 0,
|
|
||||||
"relationship_level": "stranger",
|
|
||||||
"last_interaction": 0
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# 分析关系程度
|
|
||||||
know_times = relationship_info["know_times"]
|
|
||||||
if know_times == 0:
|
|
||||||
relationship = "陌生人"
|
|
||||||
elif know_times < 5:
|
|
||||||
relationship = "新朋友"
|
|
||||||
elif know_times < 20:
|
|
||||||
relationship = "熟人"
|
|
||||||
else:
|
|
||||||
relationship = "老朋友"
|
|
||||||
|
|
||||||
return {
|
|
||||||
"nickname": relationship_info["nickname"],
|
|
||||||
"relationship": relationship,
|
|
||||||
"impression": relationship_info["impression"],
|
|
||||||
"interaction_count": know_times
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
## 常用字段说明
|
## 常用字段说明
|
||||||
|
|
||||||
### 基础信息字段
|
### 基础信息字段
|
||||||
@@ -274,69 +107,13 @@ async def analyze_user_relationship(chat_stream):
|
|||||||
|
|
||||||
### 关系信息字段
|
### 关系信息字段
|
||||||
- `impression`:对用户的印象
|
- `impression`:对用户的印象
|
||||||
- `know_times`:交互次数
|
- `points`:用户特征点
|
||||||
- `relationship_level`:关系等级
|
|
||||||
- `last_seen`:最后见面时间
|
|
||||||
- `last_interaction`:最后交互时间
|
|
||||||
|
|
||||||
### 个性化字段
|
其他字段可以参考`PersonInfo`类的属性(位于`src.common.database.database_model`)
|
||||||
- `preferences`:用户偏好
|
|
||||||
- `interests`:兴趣爱好
|
|
||||||
- `mood_history`:情绪历史
|
|
||||||
- `topic_interests`:话题兴趣
|
|
||||||
|
|
||||||
## 最佳实践
|
|
||||||
|
|
||||||
### 1. 错误处理
|
|
||||||
```python
|
|
||||||
async def safe_get_user_info(person_id: str, field: str):
|
|
||||||
"""安全获取用户信息"""
|
|
||||||
try:
|
|
||||||
value = await person_api.get_person_value(person_id, field)
|
|
||||||
return value if value is not None else "未设置"
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"获取用户信息失败: {e}")
|
|
||||||
return "获取失败"
|
|
||||||
```
|
|
||||||
|
|
||||||
### 2. 批量操作
|
|
||||||
```python
|
|
||||||
async def get_complete_user_profile(person_id: str):
|
|
||||||
"""获取完整用户档案"""
|
|
||||||
|
|
||||||
# 一次性获取所有需要的字段
|
|
||||||
fields = [
|
|
||||||
"nickname", "impression", "know_times",
|
|
||||||
"preferences", "interests", "relationship_level"
|
|
||||||
]
|
|
||||||
|
|
||||||
defaults = {
|
|
||||||
"nickname": "用户",
|
|
||||||
"impression": "",
|
|
||||||
"know_times": 0,
|
|
||||||
"preferences": "{}",
|
|
||||||
"interests": "[]",
|
|
||||||
"relationship_level": "stranger"
|
|
||||||
}
|
|
||||||
|
|
||||||
profile = await person_api.get_person_values(person_id, fields, defaults)
|
|
||||||
|
|
||||||
# 处理JSON字段
|
|
||||||
try:
|
|
||||||
profile["preferences"] = json.loads(profile["preferences"])
|
|
||||||
profile["interests"] = json.loads(profile["interests"])
|
|
||||||
except:
|
|
||||||
profile["preferences"] = {}
|
|
||||||
profile["interests"] = []
|
|
||||||
|
|
||||||
return profile
|
|
||||||
```
|
|
||||||
|
|
||||||
## 注意事项
|
## 注意事项
|
||||||
|
|
||||||
1. **异步操作**:大部分查询函数都是异步的,需要使用`await`
|
1. **异步操作**:部分查询函数是异步的,需要使用`await`
|
||||||
2. **错误处理**:所有函数都有错误处理,失败时记录日志并返回默认值
|
2. **性能考虑**:批量查询优于单个查询
|
||||||
3. **数据类型**:返回的数据可能是字符串、数字或JSON,需要适当处理
|
3. **隐私保护**:确保用户信息的使用符合隐私政策
|
||||||
4. **性能考虑**:批量查询优于单个查询
|
4. **数据一致性**:person_id是用户的唯一标识,应妥善保存和使用
|
||||||
5. **隐私保护**:确保用户信息的使用符合隐私政策
|
|
||||||
6. **数据一致性**:person_id是用户的唯一标识,应妥善保存和使用
|
|
||||||
105
docs/plugins/api/plugin-manage-api.md
Normal file
105
docs/plugins/api/plugin-manage-api.md
Normal file
@@ -0,0 +1,105 @@
|
|||||||
|
# 插件管理API
|
||||||
|
|
||||||
|
插件管理API模块提供了对插件的加载、卸载、重新加载以及目录管理功能。
|
||||||
|
|
||||||
|
## 导入方式
|
||||||
|
```python
|
||||||
|
from src.plugin_system.apis import plugin_manage_api
|
||||||
|
# 或者
|
||||||
|
from src.plugin_system import plugin_manage_api
|
||||||
|
```
|
||||||
|
|
||||||
|
## 功能概述
|
||||||
|
|
||||||
|
插件管理API主要提供以下功能:
|
||||||
|
- **插件查询** - 列出当前加载的插件或已注册的插件。
|
||||||
|
- **插件管理** - 加载、卸载、重新加载插件。
|
||||||
|
- **插件目录管理** - 添加插件目录并重新扫描。
|
||||||
|
|
||||||
|
## 主要功能
|
||||||
|
|
||||||
|
### 1. 列出当前加载的插件
|
||||||
|
```python
|
||||||
|
def list_loaded_plugins() -> List[str]:
|
||||||
|
```
|
||||||
|
列出所有当前加载的插件。
|
||||||
|
|
||||||
|
**Returns:**
|
||||||
|
- `List[str]` - 当前加载的插件名称列表。
|
||||||
|
|
||||||
|
### 2. 列出所有已注册的插件
|
||||||
|
```python
|
||||||
|
def list_registered_plugins() -> List[str]:
|
||||||
|
```
|
||||||
|
列出所有已注册的插件。
|
||||||
|
|
||||||
|
**Returns:**
|
||||||
|
- `List[str]` - 已注册的插件名称列表。
|
||||||
|
|
||||||
|
### 3. 获取插件路径
|
||||||
|
```python
|
||||||
|
def get_plugin_path(plugin_name: str) -> str:
|
||||||
|
```
|
||||||
|
获取指定插件的路径。
|
||||||
|
|
||||||
|
**Args:**
|
||||||
|
- `plugin_name` (str): 要查询的插件名称。
|
||||||
|
**Returns:**
|
||||||
|
- `str` - 插件的路径,如果插件不存在则抛出 `ValueError`。
|
||||||
|
|
||||||
|
### 4. 卸载指定的插件
|
||||||
|
```python
|
||||||
|
async def remove_plugin(plugin_name: str) -> bool:
|
||||||
|
```
|
||||||
|
卸载指定的插件。
|
||||||
|
|
||||||
|
**Args:**
|
||||||
|
- `plugin_name` (str): 要卸载的插件名称。
|
||||||
|
|
||||||
|
**Returns:**
|
||||||
|
- `bool` - 卸载是否成功。
|
||||||
|
|
||||||
|
### 5. 重新加载指定的插件
|
||||||
|
```python
|
||||||
|
async def reload_plugin(plugin_name: str) -> bool:
|
||||||
|
```
|
||||||
|
重新加载指定的插件。
|
||||||
|
|
||||||
|
**Args:**
|
||||||
|
- `plugin_name` (str): 要重新加载的插件名称。
|
||||||
|
|
||||||
|
**Returns:**
|
||||||
|
- `bool` - 重新加载是否成功。
|
||||||
|
|
||||||
|
### 6. 加载指定的插件
|
||||||
|
```python
|
||||||
|
def load_plugin(plugin_name: str) -> Tuple[bool, int]:
|
||||||
|
```
|
||||||
|
加载指定的插件。
|
||||||
|
|
||||||
|
**Args:**
|
||||||
|
- `plugin_name` (str): 要加载的插件名称。
|
||||||
|
|
||||||
|
**Returns:**
|
||||||
|
- `Tuple[bool, int]` - 加载是否成功,成功或失败的个数。
|
||||||
|
|
||||||
|
### 7. 添加插件目录
|
||||||
|
```python
|
||||||
|
def add_plugin_directory(plugin_directory: str) -> bool:
|
||||||
|
```
|
||||||
|
添加插件目录。
|
||||||
|
|
||||||
|
**Args:**
|
||||||
|
- `plugin_directory` (str): 要添加的插件目录路径。
|
||||||
|
|
||||||
|
**Returns:**
|
||||||
|
- `bool` - 添加是否成功。
|
||||||
|
|
||||||
|
### 8. 重新扫描插件目录
|
||||||
|
```python
|
||||||
|
def rescan_plugin_directory() -> Tuple[int, int]:
|
||||||
|
```
|
||||||
|
重新扫描插件目录,加载新插件。
|
||||||
|
|
||||||
|
**Returns:**
|
||||||
|
- `Tuple[int, int]` - 成功加载的插件数量和失败的插件数量。
|
||||||
@@ -6,86 +6,108 @@
|
|||||||
|
|
||||||
```python
|
```python
|
||||||
from src.plugin_system.apis import send_api
|
from src.plugin_system.apis import send_api
|
||||||
|
# 或者
|
||||||
|
from src.plugin_system import send_api
|
||||||
```
|
```
|
||||||
|
|
||||||
## 主要功能
|
## 主要功能
|
||||||
|
|
||||||
### 1. 文本消息发送
|
### 1. 发送文本消息
|
||||||
|
```python
|
||||||
|
async def text_to_stream(
|
||||||
|
text: str,
|
||||||
|
stream_id: str,
|
||||||
|
typing: bool = False,
|
||||||
|
reply_to: str = "",
|
||||||
|
storage_message: bool = True,
|
||||||
|
) -> bool:
|
||||||
|
```
|
||||||
|
发送文本消息到指定的流
|
||||||
|
|
||||||
#### `text_to_group(text, group_id, platform="qq", typing=False, reply_to="", storage_message=True)`
|
**Args:**
|
||||||
向群聊发送文本消息
|
- `text` (str): 要发送的文本内容
|
||||||
|
- `stream_id` (str): 聊天流ID
|
||||||
|
- `typing` (bool): 是否显示正在输入
|
||||||
|
- `reply_to` (str): 回复消息,格式为"发送者:消息内容"
|
||||||
|
- `storage_message` (bool): 是否存储消息到数据库
|
||||||
|
|
||||||
**参数:**
|
**Returns:**
|
||||||
- `text`:要发送的文本内容
|
- `bool` - 是否发送成功
|
||||||
- `group_id`:群聊ID
|
|
||||||
- `platform`:平台,默认为"qq"
|
|
||||||
- `typing`:是否显示正在输入
|
|
||||||
- `reply_to`:回复消息的格式,如"发送者:消息内容"
|
|
||||||
- `storage_message`:是否存储到数据库
|
|
||||||
|
|
||||||
**返回:**
|
### 2. 发送表情包
|
||||||
- `bool`:是否发送成功
|
```python
|
||||||
|
async def emoji_to_stream(emoji_base64: str, stream_id: str, storage_message: bool = True) -> bool:
|
||||||
|
```
|
||||||
|
向指定流发送表情包。
|
||||||
|
|
||||||
#### `text_to_user(text, user_id, platform="qq", typing=False, reply_to="", storage_message=True)`
|
**Args:**
|
||||||
向用户发送私聊文本消息
|
- `emoji_base64` (str): 表情包的base64编码
|
||||||
|
- `stream_id` (str): 聊天流ID
|
||||||
|
- `storage_message` (bool): 是否存储消息到数据库
|
||||||
|
|
||||||
**参数与返回值同上**
|
**Returns:**
|
||||||
|
- `bool` - 是否发送成功
|
||||||
|
|
||||||
### 2. 表情包发送
|
### 3. 发送图片
|
||||||
|
```python
|
||||||
|
async def image_to_stream(image_base64: str, stream_id: str, storage_message: bool = True) -> bool:
|
||||||
|
```
|
||||||
|
向指定流发送图片。
|
||||||
|
|
||||||
#### `emoji_to_group(emoji_base64, group_id, platform="qq", storage_message=True)`
|
**Args:**
|
||||||
向群聊发送表情包
|
- `image_base64` (str): 图片的base64编码
|
||||||
|
- `stream_id` (str): 聊天流ID
|
||||||
|
- `storage_message` (bool): 是否存储消息到数据库
|
||||||
|
|
||||||
**参数:**
|
**Returns:**
|
||||||
- `emoji_base64`:表情包的base64编码
|
- `bool` - 是否发送成功
|
||||||
- `group_id`:群聊ID
|
|
||||||
- `platform`:平台,默认为"qq"
|
|
||||||
- `storage_message`:是否存储到数据库
|
|
||||||
|
|
||||||
#### `emoji_to_user(emoji_base64, user_id, platform="qq", storage_message=True)`
|
### 4. 发送命令
|
||||||
向用户发送表情包
|
```python
|
||||||
|
async def command_to_stream(command: Union[str, dict], stream_id: str, storage_message: bool = True, display_message: str = "") -> bool:
|
||||||
|
```
|
||||||
|
向指定流发送命令。
|
||||||
|
|
||||||
### 3. 图片发送
|
**Args:**
|
||||||
|
- `command` (Union[str, dict]): 命令内容
|
||||||
|
- `stream_id` (str): 聊天流ID
|
||||||
|
- `storage_message` (bool): 是否存储消息到数据库
|
||||||
|
- `display_message` (str): 显示消息
|
||||||
|
|
||||||
#### `image_to_group(image_base64, group_id, platform="qq", storage_message=True)`
|
**Returns:**
|
||||||
向群聊发送图片
|
- `bool` - 是否发送成功
|
||||||
|
|
||||||
#### `image_to_user(image_base64, user_id, platform="qq", storage_message=True)`
|
### 5. 发送自定义类型消息
|
||||||
向用户发送图片
|
```python
|
||||||
|
async def custom_to_stream(
|
||||||
|
message_type: str,
|
||||||
|
content: str,
|
||||||
|
stream_id: str,
|
||||||
|
display_message: str = "",
|
||||||
|
typing: bool = False,
|
||||||
|
reply_to: str = "",
|
||||||
|
storage_message: bool = True,
|
||||||
|
show_log: bool = True,
|
||||||
|
) -> bool:
|
||||||
|
```
|
||||||
|
向指定流发送自定义类型消息。
|
||||||
|
|
||||||
### 4. 命令发送
|
**Args:**
|
||||||
|
- `message_type` (str): 消息类型,如"text"、"image"、"emoji"、"video"、"file"等
|
||||||
|
- `content` (str): 消息内容(通常是base64编码或文本)
|
||||||
|
- `stream_id` (str): 聊天流ID
|
||||||
|
- `display_message` (str): 显示消息
|
||||||
|
- `typing` (bool): 是否显示正在输入
|
||||||
|
- `reply_to` (str): 回复消息,格式为"发送者:消息内容"
|
||||||
|
- `storage_message` (bool): 是否存储消息到数据库
|
||||||
|
- `show_log` (bool): 是否显示日志
|
||||||
|
|
||||||
#### `command_to_group(command, group_id, platform="qq", storage_message=True)`
|
**Returns:**
|
||||||
向群聊发送命令
|
- `bool` - 是否发送成功
|
||||||
|
|
||||||
#### `command_to_user(command, user_id, platform="qq", storage_message=True)`
|
|
||||||
向用户发送命令
|
|
||||||
|
|
||||||
### 5. 自定义消息发送
|
|
||||||
|
|
||||||
#### `custom_to_group(message_type, content, group_id, platform="qq", display_message="", typing=False, reply_to="", storage_message=True)`
|
|
||||||
向群聊发送自定义类型消息
|
|
||||||
|
|
||||||
#### `custom_to_user(message_type, content, user_id, platform="qq", display_message="", typing=False, reply_to="", storage_message=True)`
|
|
||||||
向用户发送自定义类型消息
|
|
||||||
|
|
||||||
#### `custom_message(message_type, content, target_id, is_group=True, platform="qq", display_message="", typing=False, reply_to="", storage_message=True)`
|
|
||||||
通用的自定义消息发送
|
|
||||||
|
|
||||||
**参数:**
|
|
||||||
- `message_type`:消息类型,如"text"、"image"、"emoji"等
|
|
||||||
- `content`:消息内容
|
|
||||||
- `target_id`:目标ID(群ID或用户ID)
|
|
||||||
- `is_group`:是否为群聊
|
|
||||||
- `platform`:平台
|
|
||||||
- `display_message`:显示消息
|
|
||||||
- `typing`:是否显示正在输入
|
|
||||||
- `reply_to`:回复消息
|
|
||||||
- `storage_message`:是否存储
|
|
||||||
|
|
||||||
## 使用示例
|
## 使用示例
|
||||||
|
|
||||||
### 1. 基础文本发送
|
### 1. 基础文本发送,并回复消息
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from src.plugin_system.apis import send_api
|
from src.plugin_system.apis import send_api
|
||||||
@@ -93,57 +115,23 @@ from src.plugin_system.apis import send_api
|
|||||||
async def send_hello(chat_stream):
|
async def send_hello(chat_stream):
|
||||||
"""发送问候消息"""
|
"""发送问候消息"""
|
||||||
|
|
||||||
if chat_stream.group_info:
|
success = await send_api.text_to_stream(
|
||||||
# 群聊
|
text="Hello, world!",
|
||||||
success = await send_api.text_to_group(
|
stream_id=chat_stream.stream_id,
|
||||||
text="大家好!",
|
typing=True,
|
||||||
group_id=chat_stream.group_info.group_id,
|
reply_to="User:How are you?",
|
||||||
typing=True
|
storage_message=True
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
# 私聊
|
|
||||||
success = await send_api.text_to_user(
|
|
||||||
text="你好!",
|
|
||||||
user_id=chat_stream.user_info.user_id,
|
|
||||||
typing=True
|
|
||||||
)
|
|
||||||
|
|
||||||
return success
|
return success
|
||||||
```
|
```
|
||||||
|
|
||||||
### 2. 回复特定消息
|
### 2. 发送表情包
|
||||||
|
|
||||||
```python
|
|
||||||
async def reply_to_message(chat_stream, reply_text, original_sender, original_message):
|
|
||||||
"""回复特定消息"""
|
|
||||||
|
|
||||||
# 构建回复格式
|
|
||||||
reply_to = f"{original_sender}:{original_message}"
|
|
||||||
|
|
||||||
if chat_stream.group_info:
|
|
||||||
success = await send_api.text_to_group(
|
|
||||||
text=reply_text,
|
|
||||||
group_id=chat_stream.group_info.group_id,
|
|
||||||
reply_to=reply_to
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
success = await send_api.text_to_user(
|
|
||||||
text=reply_text,
|
|
||||||
user_id=chat_stream.user_info.user_id,
|
|
||||||
reply_to=reply_to
|
|
||||||
)
|
|
||||||
|
|
||||||
return success
|
|
||||||
```
|
|
||||||
|
|
||||||
### 3. 发送表情包
|
|
||||||
|
|
||||||
```python
|
```python
|
||||||
|
from src.plugin_system.apis import emoji_api
|
||||||
async def send_emoji_reaction(chat_stream, emotion):
|
async def send_emoji_reaction(chat_stream, emotion):
|
||||||
"""根据情感发送表情包"""
|
"""根据情感发送表情包"""
|
||||||
|
|
||||||
from src.plugin_system.apis import emoji_api
|
|
||||||
|
|
||||||
# 获取表情包
|
# 获取表情包
|
||||||
emoji_result = await emoji_api.get_by_emotion(emotion)
|
emoji_result = await emoji_api.get_by_emotion(emotion)
|
||||||
if not emoji_result:
|
if not emoji_result:
|
||||||
@@ -152,107 +140,10 @@ async def send_emoji_reaction(chat_stream, emotion):
|
|||||||
emoji_base64, description, matched_emotion = emoji_result
|
emoji_base64, description, matched_emotion = emoji_result
|
||||||
|
|
||||||
# 发送表情包
|
# 发送表情包
|
||||||
if chat_stream.group_info:
|
success = await send_api.emoji_to_stream(
|
||||||
success = await send_api.emoji_to_group(
|
emoji_base64=emoji_base64,
|
||||||
emoji_base64=emoji_base64,
|
stream_id=chat_stream.stream_id,
|
||||||
group_id=chat_stream.group_info.group_id
|
storage_message=False # 不存储到数据库
|
||||||
)
|
|
||||||
else:
|
|
||||||
success = await send_api.emoji_to_user(
|
|
||||||
emoji_base64=emoji_base64,
|
|
||||||
user_id=chat_stream.user_info.user_id
|
|
||||||
)
|
|
||||||
|
|
||||||
return success
|
|
||||||
```
|
|
||||||
|
|
||||||
### 4. 在Action中发送消息
|
|
||||||
|
|
||||||
```python
|
|
||||||
from src.plugin_system.base import BaseAction
|
|
||||||
|
|
||||||
class MessageAction(BaseAction):
|
|
||||||
async def execute(self, action_data, chat_stream):
|
|
||||||
message_type = action_data.get("type", "text")
|
|
||||||
content = action_data.get("content", "")
|
|
||||||
|
|
||||||
if message_type == "text":
|
|
||||||
success = await self.send_text(chat_stream, content)
|
|
||||||
elif message_type == "emoji":
|
|
||||||
success = await self.send_emoji(chat_stream, content)
|
|
||||||
elif message_type == "image":
|
|
||||||
success = await self.send_image(chat_stream, content)
|
|
||||||
else:
|
|
||||||
success = False
|
|
||||||
|
|
||||||
return {"success": success}
|
|
||||||
|
|
||||||
async def send_text(self, chat_stream, text):
|
|
||||||
if chat_stream.group_info:
|
|
||||||
return await send_api.text_to_group(text, chat_stream.group_info.group_id)
|
|
||||||
else:
|
|
||||||
return await send_api.text_to_user(text, chat_stream.user_info.user_id)
|
|
||||||
|
|
||||||
async def send_emoji(self, chat_stream, emoji_base64):
|
|
||||||
if chat_stream.group_info:
|
|
||||||
return await send_api.emoji_to_group(emoji_base64, chat_stream.group_info.group_id)
|
|
||||||
else:
|
|
||||||
return await send_api.emoji_to_user(emoji_base64, chat_stream.user_info.user_id)
|
|
||||||
|
|
||||||
async def send_image(self, chat_stream, image_base64):
|
|
||||||
if chat_stream.group_info:
|
|
||||||
return await send_api.image_to_group(image_base64, chat_stream.group_info.group_id)
|
|
||||||
else:
|
|
||||||
return await send_api.image_to_user(image_base64, chat_stream.user_info.user_id)
|
|
||||||
```
|
|
||||||
|
|
||||||
### 5. 批量发送消息
|
|
||||||
|
|
||||||
```python
|
|
||||||
async def broadcast_message(message: str, target_groups: list):
|
|
||||||
"""向多个群组广播消息"""
|
|
||||||
|
|
||||||
results = {}
|
|
||||||
|
|
||||||
for group_id in target_groups:
|
|
||||||
try:
|
|
||||||
success = await send_api.text_to_group(
|
|
||||||
text=message,
|
|
||||||
group_id=group_id,
|
|
||||||
typing=True
|
|
||||||
)
|
|
||||||
results[group_id] = success
|
|
||||||
except Exception as e:
|
|
||||||
results[group_id] = False
|
|
||||||
print(f"发送到群 {group_id} 失败: {e}")
|
|
||||||
|
|
||||||
return results
|
|
||||||
```
|
|
||||||
|
|
||||||
### 6. 智能消息发送
|
|
||||||
|
|
||||||
```python
|
|
||||||
async def smart_send(chat_stream, message_data):
|
|
||||||
"""智能发送不同类型的消息"""
|
|
||||||
|
|
||||||
message_type = message_data.get("type", "text")
|
|
||||||
content = message_data.get("content", "")
|
|
||||||
options = message_data.get("options", {})
|
|
||||||
|
|
||||||
# 根据聊天流类型选择发送方法
|
|
||||||
target_id = (chat_stream.group_info.group_id if chat_stream.group_info
|
|
||||||
else chat_stream.user_info.user_id)
|
|
||||||
is_group = chat_stream.group_info is not None
|
|
||||||
|
|
||||||
# 使用通用发送方法
|
|
||||||
success = await send_api.custom_message(
|
|
||||||
message_type=message_type,
|
|
||||||
content=content,
|
|
||||||
target_id=target_id,
|
|
||||||
is_group=is_group,
|
|
||||||
typing=options.get("typing", False),
|
|
||||||
reply_to=options.get("reply_to", ""),
|
|
||||||
display_message=options.get("display_message", "")
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return success
|
return success
|
||||||
@@ -273,90 +164,6 @@ async def smart_send(chat_stream, message_data):
|
|||||||
|
|
||||||
系统会自动查找匹配的原始消息并进行回复。
|
系统会自动查找匹配的原始消息并进行回复。
|
||||||
|
|
||||||
## 高级用法
|
|
||||||
|
|
||||||
### 1. 消息发送队列
|
|
||||||
|
|
||||||
```python
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
class MessageQueue:
|
|
||||||
def __init__(self):
|
|
||||||
self.queue = asyncio.Queue()
|
|
||||||
self.running = False
|
|
||||||
|
|
||||||
async def add_message(self, chat_stream, message_type, content, options=None):
|
|
||||||
"""添加消息到队列"""
|
|
||||||
message_item = {
|
|
||||||
"chat_stream": chat_stream,
|
|
||||||
"type": message_type,
|
|
||||||
"content": content,
|
|
||||||
"options": options or {}
|
|
||||||
}
|
|
||||||
await self.queue.put(message_item)
|
|
||||||
|
|
||||||
async def process_queue(self):
|
|
||||||
"""处理消息队列"""
|
|
||||||
self.running = True
|
|
||||||
|
|
||||||
while self.running:
|
|
||||||
try:
|
|
||||||
message_item = await asyncio.wait_for(self.queue.get(), timeout=1.0)
|
|
||||||
|
|
||||||
# 发送消息
|
|
||||||
success = await smart_send(
|
|
||||||
message_item["chat_stream"],
|
|
||||||
{
|
|
||||||
"type": message_item["type"],
|
|
||||||
"content": message_item["content"],
|
|
||||||
"options": message_item["options"]
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# 标记任务完成
|
|
||||||
self.queue.task_done()
|
|
||||||
|
|
||||||
# 发送间隔
|
|
||||||
await asyncio.sleep(0.5)
|
|
||||||
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
continue
|
|
||||||
except Exception as e:
|
|
||||||
print(f"处理消息队列出错: {e}")
|
|
||||||
```
|
|
||||||
|
|
||||||
### 2. 消息模板系统
|
|
||||||
|
|
||||||
```python
|
|
||||||
class MessageTemplate:
|
|
||||||
def __init__(self):
|
|
||||||
self.templates = {
|
|
||||||
"welcome": "欢迎 {nickname} 加入群聊!",
|
|
||||||
"goodbye": "{nickname} 离开了群聊。",
|
|
||||||
"notification": "🔔 通知:{message}",
|
|
||||||
"error": "❌ 错误:{error_message}",
|
|
||||||
"success": "✅ 成功:{message}"
|
|
||||||
}
|
|
||||||
|
|
||||||
def format_message(self, template_name: str, **kwargs) -> str:
|
|
||||||
"""格式化消息模板"""
|
|
||||||
template = self.templates.get(template_name, "{message}")
|
|
||||||
return template.format(**kwargs)
|
|
||||||
|
|
||||||
async def send_template(self, chat_stream, template_name: str, **kwargs):
|
|
||||||
"""发送模板消息"""
|
|
||||||
message = self.format_message(template_name, **kwargs)
|
|
||||||
|
|
||||||
if chat_stream.group_info:
|
|
||||||
return await send_api.text_to_group(message, chat_stream.group_info.group_id)
|
|
||||||
else:
|
|
||||||
return await send_api.text_to_user(message, chat_stream.user_info.user_id)
|
|
||||||
|
|
||||||
# 使用示例
|
|
||||||
template_system = MessageTemplate()
|
|
||||||
await template_system.send_template(chat_stream, "welcome", nickname="张三")
|
|
||||||
```
|
|
||||||
|
|
||||||
## 注意事项
|
## 注意事项
|
||||||
|
|
||||||
1. **异步操作**:所有发送函数都是异步的,必须使用`await`
|
1. **异步操作**:所有发送函数都是异步的,必须使用`await`
|
||||||
|
|||||||
55
docs/plugins/api/tool-api.md
Normal file
55
docs/plugins/api/tool-api.md
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
# 工具API
|
||||||
|
|
||||||
|
工具API模块提供了获取和管理工具实例的功能,让插件能够访问系统中注册的工具。
|
||||||
|
|
||||||
|
## 导入方式
|
||||||
|
|
||||||
|
```python
|
||||||
|
from src.plugin_system.apis import tool_api
|
||||||
|
# 或者
|
||||||
|
from src.plugin_system import tool_api
|
||||||
|
```
|
||||||
|
|
||||||
|
## 主要功能
|
||||||
|
|
||||||
|
### 1. 获取工具实例
|
||||||
|
|
||||||
|
```python
|
||||||
|
def get_tool_instance(tool_name: str) -> Optional[BaseTool]:
|
||||||
|
```
|
||||||
|
|
||||||
|
获取指定名称的工具实例。
|
||||||
|
|
||||||
|
**Args**:
|
||||||
|
- `tool_name`: 工具名称字符串
|
||||||
|
|
||||||
|
**Returns**:
|
||||||
|
- `Optional[BaseTool]`: 工具实例,如果工具不存在则返回 None
|
||||||
|
|
||||||
|
### 2. 获取LLM可用的工具定义
|
||||||
|
|
||||||
|
```python
|
||||||
|
def get_llm_available_tool_definitions():
|
||||||
|
```
|
||||||
|
|
||||||
|
获取所有LLM可用的工具定义列表。
|
||||||
|
|
||||||
|
**Returns**:
|
||||||
|
- `List[Tuple[str, Dict[str, Any]]]`: 工具定义列表,每个元素为 `(工具名称, 工具定义字典)` 的元组
|
||||||
|
- 其具体定义请参照[tool-components.md](../tool-components.md#属性说明)中的工具定义格式。
|
||||||
|
#### 示例
|
||||||
|
|
||||||
|
```python
|
||||||
|
# 获取所有LLM可用的工具定义
|
||||||
|
tools = tool_api.get_llm_available_tool_definitions()
|
||||||
|
for tool_name, tool_definition in tools:
|
||||||
|
print(f"工具: {tool_name}")
|
||||||
|
print(f"定义: {tool_definition}")
|
||||||
|
```
|
||||||
|
|
||||||
|
## 注意事项
|
||||||
|
|
||||||
|
1. **工具存在性检查**:使用前请检查工具实例是否为 None
|
||||||
|
2. **权限控制**:某些工具可能有使用权限限制
|
||||||
|
3. **异步调用**:大多数工具方法是异步的,需要使用 `await`
|
||||||
|
4. **错误处理**:调用工具时请做好异常处理
|
||||||
@@ -1,435 +0,0 @@
|
|||||||
# 工具API
|
|
||||||
|
|
||||||
工具API模块提供了各种辅助功能,包括文件操作、时间处理、唯一ID生成等常用工具函数。
|
|
||||||
|
|
||||||
## 导入方式
|
|
||||||
|
|
||||||
```python
|
|
||||||
from src.plugin_system.apis import utils_api
|
|
||||||
```
|
|
||||||
|
|
||||||
## 主要功能
|
|
||||||
|
|
||||||
### 1. 文件操作
|
|
||||||
|
|
||||||
#### `get_plugin_path(caller_frame=None) -> str`
|
|
||||||
获取调用者插件的路径
|
|
||||||
|
|
||||||
**参数:**
|
|
||||||
- `caller_frame`:调用者的栈帧,默认为None(自动获取)
|
|
||||||
|
|
||||||
**返回:**
|
|
||||||
- `str`:插件目录的绝对路径
|
|
||||||
|
|
||||||
**示例:**
|
|
||||||
```python
|
|
||||||
plugin_path = utils_api.get_plugin_path()
|
|
||||||
print(f"插件路径: {plugin_path}")
|
|
||||||
```
|
|
||||||
|
|
||||||
#### `read_json_file(file_path: str, default: Any = None) -> Any`
|
|
||||||
读取JSON文件
|
|
||||||
|
|
||||||
**参数:**
|
|
||||||
- `file_path`:文件路径,可以是相对于插件目录的路径
|
|
||||||
- `default`:如果文件不存在或读取失败时返回的默认值
|
|
||||||
|
|
||||||
**返回:**
|
|
||||||
- `Any`:JSON数据或默认值
|
|
||||||
|
|
||||||
**示例:**
|
|
||||||
```python
|
|
||||||
# 读取插件配置文件
|
|
||||||
config = utils_api.read_json_file("config.json", {})
|
|
||||||
settings = utils_api.read_json_file("data/settings.json", {"enabled": True})
|
|
||||||
```
|
|
||||||
|
|
||||||
#### `write_json_file(file_path: str, data: Any, indent: int = 2) -> bool`
|
|
||||||
写入JSON文件
|
|
||||||
|
|
||||||
**参数:**
|
|
||||||
- `file_path`:文件路径,可以是相对于插件目录的路径
|
|
||||||
- `data`:要写入的数据
|
|
||||||
- `indent`:JSON缩进
|
|
||||||
|
|
||||||
**返回:**
|
|
||||||
- `bool`:是否写入成功
|
|
||||||
|
|
||||||
**示例:**
|
|
||||||
```python
|
|
||||||
data = {"name": "test", "value": 123}
|
|
||||||
success = utils_api.write_json_file("output.json", data)
|
|
||||||
```
|
|
||||||
|
|
||||||
### 2. 时间相关
|
|
||||||
|
|
||||||
#### `get_timestamp() -> int`
|
|
||||||
获取当前时间戳
|
|
||||||
|
|
||||||
**返回:**
|
|
||||||
- `int`:当前时间戳(秒)
|
|
||||||
|
|
||||||
#### `format_time(timestamp: Optional[int] = None, format_str: str = "%Y-%m-%d %H:%M:%S") -> str`
|
|
||||||
格式化时间
|
|
||||||
|
|
||||||
**参数:**
|
|
||||||
- `timestamp`:时间戳,如果为None则使用当前时间
|
|
||||||
- `format_str`:时间格式字符串
|
|
||||||
|
|
||||||
**返回:**
|
|
||||||
- `str`:格式化后的时间字符串
|
|
||||||
|
|
||||||
#### `parse_time(time_str: str, format_str: str = "%Y-%m-%d %H:%M:%S") -> int`
|
|
||||||
解析时间字符串为时间戳
|
|
||||||
|
|
||||||
**参数:**
|
|
||||||
- `time_str`:时间字符串
|
|
||||||
- `format_str`:时间格式字符串
|
|
||||||
|
|
||||||
**返回:**
|
|
||||||
- `int`:时间戳(秒)
|
|
||||||
|
|
||||||
### 3. 其他工具
|
|
||||||
|
|
||||||
#### `generate_unique_id() -> str`
|
|
||||||
生成唯一ID
|
|
||||||
|
|
||||||
**返回:**
|
|
||||||
- `str`:唯一ID
|
|
||||||
|
|
||||||
## 使用示例
|
|
||||||
|
|
||||||
### 1. 插件数据管理
|
|
||||||
|
|
||||||
```python
|
|
||||||
from src.plugin_system.apis import utils_api
|
|
||||||
|
|
||||||
class DataPlugin(BasePlugin):
|
|
||||||
def __init__(self):
|
|
||||||
self.plugin_path = utils_api.get_plugin_path()
|
|
||||||
self.data_file = "plugin_data.json"
|
|
||||||
self.load_data()
|
|
||||||
|
|
||||||
def load_data(self):
|
|
||||||
"""加载插件数据"""
|
|
||||||
default_data = {
|
|
||||||
"users": {},
|
|
||||||
"settings": {"enabled": True},
|
|
||||||
"stats": {"message_count": 0}
|
|
||||||
}
|
|
||||||
self.data = utils_api.read_json_file(self.data_file, default_data)
|
|
||||||
|
|
||||||
def save_data(self):
|
|
||||||
"""保存插件数据"""
|
|
||||||
return utils_api.write_json_file(self.data_file, self.data)
|
|
||||||
|
|
||||||
async def handle_action(self, action_data, chat_stream):
|
|
||||||
# 更新统计信息
|
|
||||||
self.data["stats"]["message_count"] += 1
|
|
||||||
self.data["stats"]["last_update"] = utils_api.get_timestamp()
|
|
||||||
|
|
||||||
# 保存数据
|
|
||||||
if self.save_data():
|
|
||||||
return {"success": True, "message": "数据已保存"}
|
|
||||||
else:
|
|
||||||
return {"success": False, "message": "数据保存失败"}
|
|
||||||
```
|
|
||||||
|
|
||||||
### 2. 日志记录系统
|
|
||||||
|
|
||||||
```python
|
|
||||||
class PluginLogger:
|
|
||||||
def __init__(self, plugin_name: str):
|
|
||||||
self.plugin_name = plugin_name
|
|
||||||
self.log_file = f"{plugin_name}_log.json"
|
|
||||||
self.logs = utils_api.read_json_file(self.log_file, [])
|
|
||||||
|
|
||||||
def log_event(self, event_type: str, message: str, data: dict = None):
|
|
||||||
"""记录事件"""
|
|
||||||
log_entry = {
|
|
||||||
"id": utils_api.generate_unique_id(),
|
|
||||||
"timestamp": utils_api.get_timestamp(),
|
|
||||||
"formatted_time": utils_api.format_time(),
|
|
||||||
"event_type": event_type,
|
|
||||||
"message": message,
|
|
||||||
"data": data or {}
|
|
||||||
}
|
|
||||||
|
|
||||||
self.logs.append(log_entry)
|
|
||||||
|
|
||||||
# 保持最新的100条记录
|
|
||||||
if len(self.logs) > 100:
|
|
||||||
self.logs = self.logs[-100:]
|
|
||||||
|
|
||||||
# 保存到文件
|
|
||||||
utils_api.write_json_file(self.log_file, self.logs)
|
|
||||||
|
|
||||||
def get_logs_by_type(self, event_type: str) -> list:
|
|
||||||
"""获取指定类型的日志"""
|
|
||||||
return [log for log in self.logs if log["event_type"] == event_type]
|
|
||||||
|
|
||||||
def get_recent_logs(self, count: int = 10) -> list:
|
|
||||||
"""获取最近的日志"""
|
|
||||||
return self.logs[-count:]
|
|
||||||
|
|
||||||
# 使用示例
|
|
||||||
logger = PluginLogger("my_plugin")
|
|
||||||
logger.log_event("user_action", "用户发送了消息", {"user_id": "123", "message": "hello"})
|
|
||||||
```
|
|
||||||
|
|
||||||
### 3. 配置管理系统
|
|
||||||
|
|
||||||
```python
|
|
||||||
class ConfigManager:
|
|
||||||
def __init__(self, config_file: str = "plugin_config.json"):
|
|
||||||
self.config_file = config_file
|
|
||||||
self.default_config = {
|
|
||||||
"enabled": True,
|
|
||||||
"debug": False,
|
|
||||||
"max_users": 100,
|
|
||||||
"response_delay": 1.0,
|
|
||||||
"features": {
|
|
||||||
"auto_reply": True,
|
|
||||||
"logging": True
|
|
||||||
}
|
|
||||||
}
|
|
||||||
self.config = self.load_config()
|
|
||||||
|
|
||||||
def load_config(self) -> dict:
|
|
||||||
"""加载配置"""
|
|
||||||
return utils_api.read_json_file(self.config_file, self.default_config)
|
|
||||||
|
|
||||||
def save_config(self) -> bool:
|
|
||||||
"""保存配置"""
|
|
||||||
return utils_api.write_json_file(self.config_file, self.config, indent=4)
|
|
||||||
|
|
||||||
def get(self, key: str, default=None):
|
|
||||||
"""获取配置值,支持嵌套访问"""
|
|
||||||
keys = key.split('.')
|
|
||||||
value = self.config
|
|
||||||
|
|
||||||
for k in keys:
|
|
||||||
if isinstance(value, dict) and k in value:
|
|
||||||
value = value[k]
|
|
||||||
else:
|
|
||||||
return default
|
|
||||||
|
|
||||||
return value
|
|
||||||
|
|
||||||
def set(self, key: str, value):
|
|
||||||
"""设置配置值,支持嵌套设置"""
|
|
||||||
keys = key.split('.')
|
|
||||||
config = self.config
|
|
||||||
|
|
||||||
for k in keys[:-1]:
|
|
||||||
if k not in config:
|
|
||||||
config[k] = {}
|
|
||||||
config = config[k]
|
|
||||||
|
|
||||||
config[keys[-1]] = value
|
|
||||||
|
|
||||||
def update_config(self, updates: dict):
|
|
||||||
"""批量更新配置"""
|
|
||||||
def deep_update(base, updates):
|
|
||||||
for key, value in updates.items():
|
|
||||||
if isinstance(value, dict) and key in base and isinstance(base[key], dict):
|
|
||||||
deep_update(base[key], value)
|
|
||||||
else:
|
|
||||||
base[key] = value
|
|
||||||
|
|
||||||
deep_update(self.config, updates)
|
|
||||||
|
|
||||||
# 使用示例
|
|
||||||
config = ConfigManager()
|
|
||||||
print(f"调试模式: {config.get('debug', False)}")
|
|
||||||
print(f"自动回复: {config.get('features.auto_reply', True)}")
|
|
||||||
|
|
||||||
config.set('features.new_feature', True)
|
|
||||||
config.save_config()
|
|
||||||
```
|
|
||||||
|
|
||||||
### 4. 缓存系统
|
|
||||||
|
|
||||||
```python
|
|
||||||
class PluginCache:
|
|
||||||
def __init__(self, cache_file: str = "plugin_cache.json", ttl: int = 3600):
|
|
||||||
self.cache_file = cache_file
|
|
||||||
self.ttl = ttl # 缓存过期时间(秒)
|
|
||||||
self.cache = self.load_cache()
|
|
||||||
|
|
||||||
def load_cache(self) -> dict:
|
|
||||||
"""加载缓存"""
|
|
||||||
return utils_api.read_json_file(self.cache_file, {})
|
|
||||||
|
|
||||||
def save_cache(self):
|
|
||||||
"""保存缓存"""
|
|
||||||
return utils_api.write_json_file(self.cache_file, self.cache)
|
|
||||||
|
|
||||||
def get(self, key: str):
|
|
||||||
"""获取缓存值"""
|
|
||||||
if key not in self.cache:
|
|
||||||
return None
|
|
||||||
|
|
||||||
item = self.cache[key]
|
|
||||||
current_time = utils_api.get_timestamp()
|
|
||||||
|
|
||||||
# 检查是否过期
|
|
||||||
if current_time - item["timestamp"] > self.ttl:
|
|
||||||
del self.cache[key]
|
|
||||||
return None
|
|
||||||
|
|
||||||
return item["value"]
|
|
||||||
|
|
||||||
def set(self, key: str, value):
|
|
||||||
"""设置缓存值"""
|
|
||||||
self.cache[key] = {
|
|
||||||
"value": value,
|
|
||||||
"timestamp": utils_api.get_timestamp()
|
|
||||||
}
|
|
||||||
self.save_cache()
|
|
||||||
|
|
||||||
def clear_expired(self):
|
|
||||||
"""清理过期缓存"""
|
|
||||||
current_time = utils_api.get_timestamp()
|
|
||||||
expired_keys = []
|
|
||||||
|
|
||||||
for key, item in self.cache.items():
|
|
||||||
if current_time - item["timestamp"] > self.ttl:
|
|
||||||
expired_keys.append(key)
|
|
||||||
|
|
||||||
for key in expired_keys:
|
|
||||||
del self.cache[key]
|
|
||||||
|
|
||||||
if expired_keys:
|
|
||||||
self.save_cache()
|
|
||||||
|
|
||||||
return len(expired_keys)
|
|
||||||
|
|
||||||
# 使用示例
|
|
||||||
cache = PluginCache(ttl=1800) # 30分钟过期
|
|
||||||
cache.set("user_data_123", {"name": "张三", "score": 100})
|
|
||||||
user_data = cache.get("user_data_123")
|
|
||||||
```
|
|
||||||
|
|
||||||
### 5. 时间处理工具
|
|
||||||
|
|
||||||
```python
|
|
||||||
class TimeHelper:
|
|
||||||
@staticmethod
|
|
||||||
def get_time_info():
|
|
||||||
"""获取当前时间的详细信息"""
|
|
||||||
timestamp = utils_api.get_timestamp()
|
|
||||||
return {
|
|
||||||
"timestamp": timestamp,
|
|
||||||
"datetime": utils_api.format_time(timestamp),
|
|
||||||
"date": utils_api.format_time(timestamp, "%Y-%m-%d"),
|
|
||||||
"time": utils_api.format_time(timestamp, "%H:%M:%S"),
|
|
||||||
"year": utils_api.format_time(timestamp, "%Y"),
|
|
||||||
"month": utils_api.format_time(timestamp, "%m"),
|
|
||||||
"day": utils_api.format_time(timestamp, "%d"),
|
|
||||||
"weekday": utils_api.format_time(timestamp, "%A")
|
|
||||||
}
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def time_ago(timestamp: int) -> str:
|
|
||||||
"""计算时间差"""
|
|
||||||
current = utils_api.get_timestamp()
|
|
||||||
diff = current - timestamp
|
|
||||||
|
|
||||||
if diff < 60:
|
|
||||||
return f"{diff}秒前"
|
|
||||||
elif diff < 3600:
|
|
||||||
return f"{diff // 60}分钟前"
|
|
||||||
elif diff < 86400:
|
|
||||||
return f"{diff // 3600}小时前"
|
|
||||||
else:
|
|
||||||
return f"{diff // 86400}天前"
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def parse_duration(duration_str: str) -> int:
|
|
||||||
"""解析时间段字符串,返回秒数"""
|
|
||||||
import re
|
|
||||||
|
|
||||||
pattern = r'(\d+)([smhd])'
|
|
||||||
matches = re.findall(pattern, duration_str.lower())
|
|
||||||
|
|
||||||
total_seconds = 0
|
|
||||||
for value, unit in matches:
|
|
||||||
value = int(value)
|
|
||||||
if unit == 's':
|
|
||||||
total_seconds += value
|
|
||||||
elif unit == 'm':
|
|
||||||
total_seconds += value * 60
|
|
||||||
elif unit == 'h':
|
|
||||||
total_seconds += value * 3600
|
|
||||||
elif unit == 'd':
|
|
||||||
total_seconds += value * 86400
|
|
||||||
|
|
||||||
return total_seconds
|
|
||||||
|
|
||||||
# 使用示例
|
|
||||||
time_info = TimeHelper.get_time_info()
|
|
||||||
print(f"当前时间: {time_info['datetime']}")
|
|
||||||
|
|
||||||
last_seen = 1699000000
|
|
||||||
print(f"最后见面: {TimeHelper.time_ago(last_seen)}")
|
|
||||||
|
|
||||||
duration = TimeHelper.parse_duration("1h30m") # 1小时30分钟 = 5400秒
|
|
||||||
```
|
|
||||||
|
|
||||||
## 最佳实践
|
|
||||||
|
|
||||||
### 1. 错误处理
|
|
||||||
```python
|
|
||||||
def safe_file_operation(file_path: str, data: dict):
|
|
||||||
"""安全的文件操作"""
|
|
||||||
try:
|
|
||||||
success = utils_api.write_json_file(file_path, data)
|
|
||||||
if not success:
|
|
||||||
logger.warning(f"文件写入失败: {file_path}")
|
|
||||||
return success
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"文件操作出错: {e}")
|
|
||||||
return False
|
|
||||||
```
|
|
||||||
|
|
||||||
### 2. 路径处理
|
|
||||||
```python
|
|
||||||
import os
|
|
||||||
|
|
||||||
def get_data_path(filename: str) -> str:
|
|
||||||
"""获取数据文件的完整路径"""
|
|
||||||
plugin_path = utils_api.get_plugin_path()
|
|
||||||
data_dir = os.path.join(plugin_path, "data")
|
|
||||||
|
|
||||||
# 确保数据目录存在
|
|
||||||
os.makedirs(data_dir, exist_ok=True)
|
|
||||||
|
|
||||||
return os.path.join(data_dir, filename)
|
|
||||||
```
|
|
||||||
|
|
||||||
### 3. 定期清理
|
|
||||||
```python
|
|
||||||
async def cleanup_old_files():
|
|
||||||
"""清理旧文件"""
|
|
||||||
plugin_path = utils_api.get_plugin_path()
|
|
||||||
current_time = utils_api.get_timestamp()
|
|
||||||
|
|
||||||
for filename in os.listdir(plugin_path):
|
|
||||||
if filename.endswith('.tmp'):
|
|
||||||
file_path = os.path.join(plugin_path, filename)
|
|
||||||
file_time = os.path.getmtime(file_path)
|
|
||||||
|
|
||||||
# 删除超过24小时的临时文件
|
|
||||||
if current_time - file_time > 86400:
|
|
||||||
os.remove(file_path)
|
|
||||||
```
|
|
||||||
|
|
||||||
## 注意事项
|
|
||||||
|
|
||||||
1. **相对路径**:文件路径支持相对于插件目录的路径
|
|
||||||
2. **自动创建目录**:写入文件时会自动创建必要的目录
|
|
||||||
3. **错误处理**:所有函数都有错误处理,失败时返回默认值
|
|
||||||
4. **编码格式**:文件读写使用UTF-8编码
|
|
||||||
5. **时间格式**:时间戳使用秒为单位
|
|
||||||
6. **JSON格式**:JSON文件使用可读性好的缩进格式
|
|
||||||
@@ -10,6 +10,7 @@
|
|||||||
|
|
||||||
- [🧱 Action组件详解](action-components.md) - 掌握最核心的Action组件
|
- [🧱 Action组件详解](action-components.md) - 掌握最核心的Action组件
|
||||||
- [💻 Command组件详解](command-components.md) - 学习直接响应命令的组件
|
- [💻 Command组件详解](command-components.md) - 学习直接响应命令的组件
|
||||||
|
- [🔧 Tool组件详解](tool-components.md) - 了解如何扩展信息获取能力
|
||||||
- [⚙️ 配置文件系统指南](configuration-guide.md) - 学会使用自动生成的插件配置文件
|
- [⚙️ 配置文件系统指南](configuration-guide.md) - 学会使用自动生成的插件配置文件
|
||||||
- [📄 Manifest系统指南](manifest-guide.md) - 了解插件元数据管理和配置架构
|
- [📄 Manifest系统指南](manifest-guide.md) - 了解插件元数据管理和配置架构
|
||||||
|
|
||||||
@@ -43,24 +44,24 @@ Command vs Action 选择指南
|
|||||||
- [LLM API](api/llm-api.md) - 大语言模型交互接口,可以使用内置LLM生成内容
|
- [LLM API](api/llm-api.md) - 大语言模型交互接口,可以使用内置LLM生成内容
|
||||||
- [✨ 回复生成器API](api/generator-api.md) - 智能回复生成接口,可以使用内置风格化生成器
|
- [✨ 回复生成器API](api/generator-api.md) - 智能回复生成接口,可以使用内置风格化生成器
|
||||||
|
|
||||||
### 表情包api
|
### 表情包API
|
||||||
- [😊 表情包API](api/emoji-api.md) - 表情包选择和管理接口
|
- [😊 表情包API](api/emoji-api.md) - 表情包选择和管理接口
|
||||||
|
|
||||||
### 关系系统api
|
### 关系系统API
|
||||||
- [人物信息API](api/person-api.md) - 用户信息,处理麦麦认识的人和关系的接口
|
- [人物信息API](api/person-api.md) - 用户信息,处理麦麦认识的人和关系的接口
|
||||||
|
|
||||||
### 数据与配置API
|
### 数据与配置API
|
||||||
- [🗄️ 数据库API](api/database-api.md) - 数据库操作接口
|
- [🗄️ 数据库API](api/database-api.md) - 数据库操作接口
|
||||||
- [⚙️ 配置API](api/config-api.md) - 配置读取和用户信息接口
|
- [⚙️ 配置API](api/config-api.md) - 配置读取和用户信息接口
|
||||||
|
|
||||||
|
### 插件和组件管理API
|
||||||
|
- [🔌 插件API](api/plugin-manage-api.md) - 插件加载和管理接口
|
||||||
|
- [🧩 组件API](api/component-manage-api.md) - 组件注册和管理接口
|
||||||
|
|
||||||
|
### 日志API
|
||||||
|
- [📜 日志API](api/logging-api.md) - logger实例获取接口
|
||||||
### 工具API
|
### 工具API
|
||||||
- [工具API](api/utils-api.md) - 文件操作、时间处理等工具函数
|
- [🔧 工具API](api/tool-api.md) - 工具实例与LLM可用工具定义的获取接口
|
||||||
|
|
||||||
|
|
||||||
## 实验性
|
|
||||||
|
|
||||||
这些功能将在未来重构或移除
|
|
||||||
- [🔧 工具系统详解](tool-system.md) - 工具系统的使用和开发
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,10 +1,10 @@
|
|||||||
# 🔧 工具系统详解
|
# 🔧 工具组件详解
|
||||||
|
|
||||||
## 📖 什么是工具系统
|
## 📖 什么是工具
|
||||||
|
|
||||||
工具系统是MaiBot的信息获取能力扩展组件。如果说Action组件功能五花八门,可以拓展麦麦能做的事情,那么Tool就是在某个过程中拓宽了麦麦能够获得的信息量。
|
工具是MaiBot的信息获取能力扩展组件。如果说Action组件功能五花八门,可以拓展麦麦能做的事情,那么Tool就是在某个过程中拓宽了麦麦能够获得的信息量。
|
||||||
|
|
||||||
### 🎯 工具系统的特点
|
### 🎯 工具的特点
|
||||||
|
|
||||||
- 🔍 **信息获取增强**:扩展麦麦获取外部信息的能力
|
- 🔍 **信息获取增强**:扩展麦麦获取外部信息的能力
|
||||||
- 📊 **数据丰富**:帮助麦麦获得更多背景信息和实时数据
|
- 📊 **数据丰富**:帮助麦麦获得更多背景信息和实时数据
|
||||||
@@ -20,14 +20,11 @@
|
|||||||
| **目标** | 让麦麦做更多事情 | 提供具体功能 | 让麦麦知道更多信息 |
|
| **目标** | 让麦麦做更多事情 | 提供具体功能 | 让麦麦知道更多信息 |
|
||||||
| **使用场景** | 增强交互体验 | 功能服务 | 信息查询和分析 |
|
| **使用场景** | 增强交互体验 | 功能服务 | 信息查询和分析 |
|
||||||
|
|
||||||
## 🏗️ 工具基本结构
|
## 🏗️ Tool组件的基本结构
|
||||||
|
|
||||||
### 必要组件
|
|
||||||
|
|
||||||
每个工具必须继承 `BaseTool` 基类并实现以下属性和方法:
|
每个工具必须继承 `BaseTool` 基类并实现以下属性和方法:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from src.tools.tool_can_use.base_tool import BaseTool, register_tool
|
from src.plugin_system import BaseTool, ToolParamType
|
||||||
|
|
||||||
class MyTool(BaseTool):
|
class MyTool(BaseTool):
|
||||||
# 工具名称,必须唯一
|
# 工具名称,必须唯一
|
||||||
@@ -36,21 +33,29 @@ class MyTool(BaseTool):
|
|||||||
# 工具描述,告诉LLM这个工具的用途
|
# 工具描述,告诉LLM这个工具的用途
|
||||||
description = "这个工具用于获取特定类型的信息"
|
description = "这个工具用于获取特定类型的信息"
|
||||||
|
|
||||||
# 参数定义,遵循JSONSchema格式
|
# 参数定义,仅定义参数
|
||||||
parameters = {
|
# 比如想要定义一个类似下面的openai格式的参数表,则可以这么定义:
|
||||||
"type": "object",
|
# {
|
||||||
"properties": {
|
# "type": "object",
|
||||||
"query": {
|
# "properties": {
|
||||||
"type": "string",
|
# "query": {
|
||||||
"description": "查询参数"
|
# "type": "string",
|
||||||
},
|
# "description": "查询参数"
|
||||||
"limit": {
|
# },
|
||||||
"type": "integer",
|
# "limit": {
|
||||||
"description": "结果数量限制"
|
# "type": "integer",
|
||||||
}
|
# "description": "结果数量限制"
|
||||||
},
|
# "enum": [10, 20, 50] # 可选值
|
||||||
"required": ["query"]
|
# }
|
||||||
}
|
# },
|
||||||
|
# "required": ["query"]
|
||||||
|
# }
|
||||||
|
parameters = [
|
||||||
|
("query", ToolParamType.STRING, "查询参数", True, None), # 必填参数
|
||||||
|
("limit", ToolParamType.INTEGER, "结果数量限制", False, ["10", "20", "50"]) # 可选参数
|
||||||
|
]
|
||||||
|
|
||||||
|
available_for_llm = True # 是否对LLM可用
|
||||||
|
|
||||||
async def execute(self, function_args: Dict[str, Any]):
|
async def execute(self, function_args: Dict[str, Any]):
|
||||||
"""执行工具逻辑"""
|
"""执行工具逻辑"""
|
||||||
@@ -69,7 +74,12 @@ class MyTool(BaseTool):
|
|||||||
|-----|------|------|
|
|-----|------|------|
|
||||||
| `name` | str | 工具的唯一标识名称 |
|
| `name` | str | 工具的唯一标识名称 |
|
||||||
| `description` | str | 工具功能描述,帮助LLM理解用途 |
|
| `description` | str | 工具功能描述,帮助LLM理解用途 |
|
||||||
| `parameters` | dict | JSONSchema格式的参数定义 |
|
| `parameters` | list[tuple] | 参数定义 |
|
||||||
|
|
||||||
|
其构造而成的工具定义为:
|
||||||
|
```python
|
||||||
|
definition: Dict[str, Any] = {"name": cls.name, "description": cls.description, "parameters": cls.parameters}
|
||||||
|
```
|
||||||
|
|
||||||
### 方法说明
|
### 方法说明
|
||||||
|
|
||||||
@@ -77,15 +87,6 @@ class MyTool(BaseTool):
|
|||||||
|-----|------|--------|------|
|
|-----|------|--------|------|
|
||||||
| `execute` | `function_args` | `dict` | 执行工具核心逻辑 |
|
| `execute` | `function_args` | `dict` | 执行工具核心逻辑 |
|
||||||
|
|
||||||
## 🔄 自动注册机制
|
|
||||||
|
|
||||||
工具系统采用自动发现和注册机制:
|
|
||||||
|
|
||||||
1. **文件扫描**:系统自动遍历 `tool_can_use` 目录中的所有Python文件
|
|
||||||
2. **类识别**:寻找继承自 `BaseTool` 的工具类
|
|
||||||
3. **自动注册**:只需要实现对应的类并把文件放在正确文件夹中就可自动注册
|
|
||||||
4. **即用即加载**:工具在需要时被实例化和调用
|
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## 🎨 完整工具示例
|
## 🎨 完整工具示例
|
||||||
@@ -93,7 +94,7 @@ class MyTool(BaseTool):
|
|||||||
完成一个天气查询工具
|
完成一个天气查询工具
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from src.tools.tool_can_use.base_tool import BaseTool, register_tool
|
from src.plugin_system import BaseTool, ToolParamType
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import json
|
import json
|
||||||
|
|
||||||
@@ -102,23 +103,13 @@ class WeatherTool(BaseTool):
|
|||||||
|
|
||||||
name = "weather_query"
|
name = "weather_query"
|
||||||
description = "查询指定城市的实时天气信息,包括温度、湿度、天气状况等"
|
description = "查询指定城市的实时天气信息,包括温度、湿度、天气状况等"
|
||||||
|
available_for_llm = True # 允许LLM调用此工具
|
||||||
|
parameters = [
|
||||||
|
("city", ToolParamType.STRING, "要查询天气的城市名称,如:北京、上海、纽约", True, None),
|
||||||
|
("country", ToolParamType.STRING, "国家代码,如:CN、US,可选参数", False, None)
|
||||||
|
]
|
||||||
|
|
||||||
parameters = {
|
async def execute(self, function_args: dict):
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"city": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "要查询天气的城市名称,如:北京、上海、纽约"
|
|
||||||
},
|
|
||||||
"country": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "国家代码,如:CN、US,可选参数"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"required": ["city"]
|
|
||||||
}
|
|
||||||
|
|
||||||
async def execute(self, function_args, message_txt=""):
|
|
||||||
"""执行天气查询"""
|
"""执行天气查询"""
|
||||||
try:
|
try:
|
||||||
city = function_args.get("city")
|
city = function_args.get("city")
|
||||||
@@ -177,55 +168,12 @@ class WeatherTool(BaseTool):
|
|||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## 📊 工具开发步骤
|
|
||||||
|
|
||||||
### 1. 创建工具文件
|
|
||||||
|
|
||||||
在 `src/tools/tool_can_use/` 目录下创建新的Python文件:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# 例如创建 my_new_tool.py
|
|
||||||
touch src/tools/tool_can_use/my_new_tool.py
|
|
||||||
```
|
|
||||||
|
|
||||||
### 2. 实现工具类
|
|
||||||
|
|
||||||
```python
|
|
||||||
from src.tools.tool_can_use.base_tool import BaseTool, register_tool
|
|
||||||
|
|
||||||
class MyNewTool(BaseTool):
|
|
||||||
name = "my_new_tool"
|
|
||||||
description = "新工具的功能描述"
|
|
||||||
|
|
||||||
parameters = {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
# 定义参数
|
|
||||||
},
|
|
||||||
"required": []
|
|
||||||
}
|
|
||||||
|
|
||||||
async def execute(self, function_args, message_txt=""):
|
|
||||||
# 实现工具逻辑
|
|
||||||
return {
|
|
||||||
"name": self.name,
|
|
||||||
"content": "执行结果"
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
### 3. 系统集成
|
|
||||||
|
|
||||||
工具创建完成后,系统会自动发现和注册,无需额外配置。
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 🚨 注意事项和限制
|
## 🚨 注意事项和限制
|
||||||
|
|
||||||
### 当前限制
|
### 当前限制
|
||||||
|
|
||||||
1. **独立开发**:需要单独编写,暂未完全融入插件系统
|
1. **适用范围**:主要适用于信息获取场景
|
||||||
2. **适用范围**:主要适用于信息获取场景
|
2. **配置要求**:必须开启工具处理器
|
||||||
3. **配置要求**:必须开启工具处理器
|
|
||||||
|
|
||||||
### 开发建议
|
### 开发建议
|
||||||
|
|
||||||
@@ -238,66 +186,49 @@ class MyNewTool(BaseTool):
|
|||||||
## 🎯 最佳实践
|
## 🎯 最佳实践
|
||||||
|
|
||||||
### 1. 工具命名规范
|
### 1. 工具命名规范
|
||||||
|
#### ✅ 好的命名
|
||||||
```python
|
```python
|
||||||
# ✅ 好的命名
|
|
||||||
name = "weather_query" # 清晰表达功能
|
name = "weather_query" # 清晰表达功能
|
||||||
name = "knowledge_search" # 描述性强
|
name = "knowledge_search" # 描述性强
|
||||||
name = "stock_price_check" # 功能明确
|
name = "stock_price_check" # 功能明确
|
||||||
|
```
|
||||||
# ❌ 避免的命名
|
#### ❌ 避免的命名
|
||||||
|
```python
|
||||||
name = "tool1" # 无意义
|
name = "tool1" # 无意义
|
||||||
name = "wq" # 过于简短
|
name = "wq" # 过于简短
|
||||||
name = "weather_and_news" # 功能过于复杂
|
name = "weather_and_news" # 功能过于复杂
|
||||||
```
|
```
|
||||||
|
|
||||||
### 2. 描述规范
|
### 2. 描述规范
|
||||||
|
#### ✅ 良好的描述
|
||||||
```python
|
```python
|
||||||
# ✅ 好的描述
|
|
||||||
description = "查询指定城市的实时天气信息,包括温度、湿度、天气状况"
|
description = "查询指定城市的实时天气信息,包括温度、湿度、天气状况"
|
||||||
|
```
|
||||||
# ❌ 避免的描述
|
#### ❌ 避免的描述
|
||||||
|
```python
|
||||||
description = "天气" # 过于简单
|
description = "天气" # 过于简单
|
||||||
description = "获取信息" # 不够具体
|
description = "获取信息" # 不够具体
|
||||||
```
|
```
|
||||||
|
|
||||||
### 3. 参数设计
|
### 3. 参数设计
|
||||||
|
|
||||||
|
#### ✅ 合理的参数设计
|
||||||
```python
|
```python
|
||||||
# ✅ 合理的参数设计
|
parameters = [
|
||||||
parameters = {
|
("city", ToolParamType.STRING, "城市名称,如:北京、上海", True, None),
|
||||||
"type": "object",
|
("unit", ToolParamType.STRING, "温度单位:celsius 或 fahrenheit", False, ["celsius", "fahrenheit"])
|
||||||
"properties": {
|
]
|
||||||
"city": {
|
```
|
||||||
"type": "string",
|
#### ❌ 避免的参数设计
|
||||||
"description": "城市名称,如:北京、上海"
|
```python
|
||||||
},
|
parameters = [
|
||||||
"unit": {
|
("data", "string", "数据", True) # 参数过于模糊
|
||||||
"type": "string",
|
]
|
||||||
"description": "温度单位:celsius(摄氏度) 或 fahrenheit(华氏度)",
|
|
||||||
"enum": ["celsius", "fahrenheit"]
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"required": ["city"]
|
|
||||||
}
|
|
||||||
|
|
||||||
# ❌ 避免的参数设计
|
|
||||||
parameters = {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"data": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "数据" # 描述不清晰
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
```
|
```
|
||||||
|
|
||||||
### 4. 结果格式化
|
### 4. 结果格式化
|
||||||
|
#### ✅ 良好的结果格式
|
||||||
```python
|
```python
|
||||||
# ✅ 良好的结果格式
|
|
||||||
def _format_result(self, data):
|
def _format_result(self, data):
|
||||||
return f"""
|
return f"""
|
||||||
🔍 查询结果
|
🔍 查询结果
|
||||||
@@ -307,12 +238,9 @@ def _format_result(self, data):
|
|||||||
📝 说明: {data['description']}
|
📝 说明: {data['description']}
|
||||||
━━━━━━━━━━━━
|
━━━━━━━━━━━━
|
||||||
""".strip()
|
""".strip()
|
||||||
|
```
|
||||||
# ❌ 避免的结果格式
|
#### ❌ 避免的结果格式
|
||||||
|
```python
|
||||||
def _format_result(self, data):
|
def _format_result(self, data):
|
||||||
return str(data) # 直接返回原始数据
|
return str(data) # 直接返回原始数据
|
||||||
```
|
```
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
🎉 **工具系统为麦麦提供了强大的信息获取能力!合理使用工具可以让麦麦变得更加智能和博学。**
|
|
||||||
@@ -1,18 +1,55 @@
|
|||||||
from typing import List, Tuple, Type
|
from typing import List, Tuple, Type, Any
|
||||||
from src.plugin_system import (
|
from src.plugin_system import (
|
||||||
BasePlugin,
|
BasePlugin,
|
||||||
register_plugin,
|
register_plugin,
|
||||||
BaseAction,
|
BaseAction,
|
||||||
BaseCommand,
|
BaseCommand,
|
||||||
|
BaseTool,
|
||||||
ComponentInfo,
|
ComponentInfo,
|
||||||
ActionActivationType,
|
ActionActivationType,
|
||||||
ConfigField,
|
ConfigField,
|
||||||
BaseEventHandler,
|
BaseEventHandler,
|
||||||
EventType,
|
EventType,
|
||||||
MaiMessages,
|
MaiMessages,
|
||||||
|
ToolParamType
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class CompareNumbersTool(BaseTool):
|
||||||
|
"""比较两个数大小的工具"""
|
||||||
|
|
||||||
|
name = "compare_numbers"
|
||||||
|
description = "使用工具 比较两个数的大小,返回较大的数"
|
||||||
|
parameters = [
|
||||||
|
("num1", ToolParamType.FLOAT, "第一个数字", True, None),
|
||||||
|
("num2", ToolParamType.FLOAT, "第二个数字", True, None),
|
||||||
|
]
|
||||||
|
|
||||||
|
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
"""执行比较两个数的大小
|
||||||
|
|
||||||
|
Args:
|
||||||
|
function_args: 工具参数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: 工具执行结果
|
||||||
|
"""
|
||||||
|
num1: int | float = function_args.get("num1") # type: ignore
|
||||||
|
num2: int | float = function_args.get("num2") # type: ignore
|
||||||
|
|
||||||
|
try:
|
||||||
|
if num1 > num2:
|
||||||
|
result = f"{num1} 大于 {num2}"
|
||||||
|
elif num1 < num2:
|
||||||
|
result = f"{num1} 小于 {num2}"
|
||||||
|
else:
|
||||||
|
result = f"{num1} 等于 {num2}"
|
||||||
|
|
||||||
|
return {"name": self.name, "content": result}
|
||||||
|
except Exception as e:
|
||||||
|
return {"name": self.name, "content": f"比较数字失败,炸了: {str(e)}"}
|
||||||
|
|
||||||
|
|
||||||
# ===== Action组件 =====
|
# ===== Action组件 =====
|
||||||
class HelloAction(BaseAction):
|
class HelloAction(BaseAction):
|
||||||
"""问候Action - 简单的问候动作"""
|
"""问候Action - 简单的问候动作"""
|
||||||
@@ -132,7 +169,9 @@ class HelloWorldPlugin(BasePlugin):
|
|||||||
"enabled": ConfigField(type=bool, default=False, description="是否启用插件"),
|
"enabled": ConfigField(type=bool, default=False, description="是否启用插件"),
|
||||||
},
|
},
|
||||||
"greeting": {
|
"greeting": {
|
||||||
"message": ConfigField(type=str, default="嗨!很开心见到你!😊", description="默认问候消息"),
|
"message": ConfigField(
|
||||||
|
type=list, default=["嗨!很开心见到你!😊", "Ciallo~(∠・ω< )⌒★"], description="默认问候消息"
|
||||||
|
),
|
||||||
"enable_emoji": ConfigField(type=bool, default=True, description="是否启用表情符号"),
|
"enable_emoji": ConfigField(type=bool, default=True, description="是否启用表情符号"),
|
||||||
},
|
},
|
||||||
"time": {"format": ConfigField(type=str, default="%Y-%m-%d %H:%M:%S", description="时间显示格式")},
|
"time": {"format": ConfigField(type=str, default="%Y-%m-%d %H:%M:%S", description="时间显示格式")},
|
||||||
@@ -142,6 +181,7 @@ class HelloWorldPlugin(BasePlugin):
|
|||||||
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
|
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
|
||||||
return [
|
return [
|
||||||
(HelloAction.get_action_info(), HelloAction),
|
(HelloAction.get_action_info(), HelloAction),
|
||||||
|
(CompareNumbersTool.get_tool_info(), CompareNumbersTool), # 添加比较数字工具
|
||||||
(ByeAction.get_action_info(), ByeAction), # 添加告别Action
|
(ByeAction.get_action_info(), ByeAction), # 添加告别Action
|
||||||
(TimeCommand.get_command_info(), TimeCommand),
|
(TimeCommand.get_command_info(), TimeCommand),
|
||||||
(PrintMessage.get_handler_info(), PrintMessage),
|
(PrintMessage.get_handler_info(), PrintMessage),
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ matplotlib
|
|||||||
networkx
|
networkx
|
||||||
numpy
|
numpy
|
||||||
openai
|
openai
|
||||||
|
google-genai
|
||||||
pandas
|
pandas
|
||||||
peewee
|
peewee
|
||||||
pyarrow
|
pyarrow
|
||||||
|
|||||||
@@ -1,2 +0,0 @@
|
|||||||
@echo off
|
|
||||||
start "Voice Adapter" cmd /k "call conda activate maicore && cd /d C:\GitHub\maimbot_tts_adapter && echo Running Napcat Adapter... && python maimbot_pipeline.py"
|
|
||||||
@@ -14,8 +14,6 @@ from src.chat.knowledge.open_ie import OpenIE
|
|||||||
from src.chat.knowledge.kg_manager import KGManager
|
from src.chat.knowledge.kg_manager import KGManager
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.chat.knowledge.utils.hash import get_sha256
|
from src.chat.knowledge.utils.hash import get_sha256
|
||||||
from src.manager.local_store_manager import local_storage
|
|
||||||
from dotenv import load_dotenv
|
|
||||||
|
|
||||||
|
|
||||||
# 添加项目根目录到 sys.path
|
# 添加项目根目录到 sys.path
|
||||||
@@ -24,46 +22,6 @@ OPENIE_DIR = os.path.join(ROOT_PATH, "data", "openie")
|
|||||||
|
|
||||||
logger = get_logger("OpenIE导入")
|
logger = get_logger("OpenIE导入")
|
||||||
|
|
||||||
ENV_FILE = os.path.join(ROOT_PATH, ".env")
|
|
||||||
|
|
||||||
if os.path.exists(".env"):
|
|
||||||
load_dotenv(".env", override=True)
|
|
||||||
print("成功加载环境变量配置")
|
|
||||||
else:
|
|
||||||
print("未找到.env文件,请确保程序所需的环境变量被正确设置")
|
|
||||||
raise FileNotFoundError(".env 文件不存在,请创建并配置所需的环境变量")
|
|
||||||
|
|
||||||
env_mask = {key: os.getenv(key) for key in os.environ}
|
|
||||||
def scan_provider(env_config: dict):
|
|
||||||
provider = {}
|
|
||||||
|
|
||||||
# 利用未初始化 env 时获取的 env_mask 来对新的环境变量集去重
|
|
||||||
# 避免 GPG_KEY 这样的变量干扰检查
|
|
||||||
env_config = dict(filter(lambda item: item[0] not in env_mask, env_config.items()))
|
|
||||||
|
|
||||||
# 遍历 env_config 的所有键
|
|
||||||
for key in env_config:
|
|
||||||
# 检查键是否符合 {provider}_BASE_URL 或 {provider}_KEY 的格式
|
|
||||||
if key.endswith("_BASE_URL") or key.endswith("_KEY"):
|
|
||||||
# 提取 provider 名称
|
|
||||||
provider_name = key.split("_", 1)[0] # 从左分割一次,取第一部分
|
|
||||||
|
|
||||||
# 初始化 provider 的字典(如果尚未初始化)
|
|
||||||
if provider_name not in provider:
|
|
||||||
provider[provider_name] = {"url": None, "key": None}
|
|
||||||
|
|
||||||
# 根据键的类型填充 url 或 key
|
|
||||||
if key.endswith("_BASE_URL"):
|
|
||||||
provider[provider_name]["url"] = env_config[key]
|
|
||||||
elif key.endswith("_KEY"):
|
|
||||||
provider[provider_name]["key"] = env_config[key]
|
|
||||||
|
|
||||||
# 检查每个 provider 是否同时存在 url 和 key
|
|
||||||
for provider_name, config in provider.items():
|
|
||||||
if config["url"] is None or config["key"] is None:
|
|
||||||
logger.error(f"provider 内容:{config}\nenv_config 内容:{env_config}")
|
|
||||||
raise ValueError(f"请检查 '{provider_name}' 提供商配置是否丢失 BASE_URL 或 KEY 环境变量")
|
|
||||||
|
|
||||||
def ensure_openie_dir():
|
def ensure_openie_dir():
|
||||||
"""确保OpenIE数据目录存在"""
|
"""确保OpenIE数据目录存在"""
|
||||||
if not os.path.exists(OPENIE_DIR):
|
if not os.path.exists(OPENIE_DIR):
|
||||||
@@ -101,7 +59,9 @@ def hash_deduplicate(
|
|||||||
):
|
):
|
||||||
# 段落hash
|
# 段落hash
|
||||||
paragraph_hash = get_sha256(raw_paragraph)
|
paragraph_hash = get_sha256(raw_paragraph)
|
||||||
if f"{local_storage['pg_namespace']}-{paragraph_hash}" in stored_pg_hashes and paragraph_hash in stored_paragraph_hashes:
|
# 使用与EmbeddingStore中一致的命名空间格式:namespace-hash
|
||||||
|
paragraph_key = f"paragraph-{paragraph_hash}"
|
||||||
|
if paragraph_key in stored_pg_hashes and paragraph_hash in stored_paragraph_hashes:
|
||||||
continue
|
continue
|
||||||
new_raw_paragraphs[paragraph_hash] = raw_paragraph
|
new_raw_paragraphs[paragraph_hash] = raw_paragraph
|
||||||
new_triple_list_data[paragraph_hash] = triple_list
|
new_triple_list_data[paragraph_hash] = triple_list
|
||||||
@@ -214,8 +174,6 @@ def handle_import_openie(openie_data: OpenIE, embed_manager: EmbeddingManager, k
|
|||||||
|
|
||||||
def main(): # sourcery skip: dict-comprehension
|
def main(): # sourcery skip: dict-comprehension
|
||||||
# 新增确认提示
|
# 新增确认提示
|
||||||
env_config = {key: os.getenv(key) for key in os.environ}
|
|
||||||
scan_provider(env_config)
|
|
||||||
print("=== 重要操作确认 ===")
|
print("=== 重要操作确认 ===")
|
||||||
print("OpenIE导入时会大量发送请求,可能会撞到请求速度上限,请注意选用的模型")
|
print("OpenIE导入时会大量发送请求,可能会撞到请求速度上限,请注意选用的模型")
|
||||||
print("同之前样例:在本地模型下,在70分钟内我们发送了约8万条请求,在网络允许下,速度会更快")
|
print("同之前样例:在本地模型下,在70分钟内我们发送了约8万条请求,在网络允许下,速度会更快")
|
||||||
@@ -264,7 +222,8 @@ def main(): # sourcery skip: dict-comprehension
|
|||||||
|
|
||||||
# 数据比对:Embedding库与KG的段落hash集合
|
# 数据比对:Embedding库与KG的段落hash集合
|
||||||
for pg_hash in kg_manager.stored_paragraph_hashes:
|
for pg_hash in kg_manager.stored_paragraph_hashes:
|
||||||
key = f"{local_storage['pg_namespace']}-{pg_hash}"
|
# 使用与EmbeddingStore中一致的命名空间格式:namespace-hash
|
||||||
|
key = f"paragraph-{pg_hash}"
|
||||||
if key not in embed_manager.stored_pg_hashes:
|
if key not in embed_manager.stored_pg_hashes:
|
||||||
logger.warning(f"KG中存在Embedding库中不存在的段落:{key}")
|
logger.warning(f"KG中存在Embedding库中不存在的段落:{key}")
|
||||||
|
|
||||||
|
|||||||
@@ -25,9 +25,8 @@ from rich.progress import (
|
|||||||
TextColumn,
|
TextColumn,
|
||||||
)
|
)
|
||||||
from raw_data_preprocessor import RAW_DATA_PATH, load_raw_data
|
from raw_data_preprocessor import RAW_DATA_PATH, load_raw_data
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config, model_config
|
||||||
from src.llm_models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
from dotenv import load_dotenv
|
|
||||||
|
|
||||||
logger = get_logger("LPMM知识库-信息提取")
|
logger = get_logger("LPMM知识库-信息提取")
|
||||||
|
|
||||||
@@ -36,45 +35,6 @@ ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
|||||||
TEMP_DIR = os.path.join(ROOT_PATH, "temp")
|
TEMP_DIR = os.path.join(ROOT_PATH, "temp")
|
||||||
# IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, "data", "imported_lpmm_data")
|
# IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, "data", "imported_lpmm_data")
|
||||||
OPENIE_OUTPUT_DIR = os.path.join(ROOT_PATH, "data", "openie")
|
OPENIE_OUTPUT_DIR = os.path.join(ROOT_PATH, "data", "openie")
|
||||||
ENV_FILE = os.path.join(ROOT_PATH, ".env")
|
|
||||||
|
|
||||||
if os.path.exists(".env"):
|
|
||||||
load_dotenv(".env", override=True)
|
|
||||||
print("成功加载环境变量配置")
|
|
||||||
else:
|
|
||||||
print("未找到.env文件,请确保程序所需的环境变量被正确设置")
|
|
||||||
raise FileNotFoundError(".env 文件不存在,请创建并配置所需的环境变量")
|
|
||||||
|
|
||||||
env_mask = {key: os.getenv(key) for key in os.environ}
|
|
||||||
def scan_provider(env_config: dict):
|
|
||||||
provider = {}
|
|
||||||
|
|
||||||
# 利用未初始化 env 时获取的 env_mask 来对新的环境变量集去重
|
|
||||||
# 避免 GPG_KEY 这样的变量干扰检查
|
|
||||||
env_config = dict(filter(lambda item: item[0] not in env_mask, env_config.items()))
|
|
||||||
|
|
||||||
# 遍历 env_config 的所有键
|
|
||||||
for key in env_config:
|
|
||||||
# 检查键是否符合 {provider}_BASE_URL 或 {provider}_KEY 的格式
|
|
||||||
if key.endswith("_BASE_URL") or key.endswith("_KEY"):
|
|
||||||
# 提取 provider 名称
|
|
||||||
provider_name = key.split("_", 1)[0] # 从左分割一次,取第一部分
|
|
||||||
|
|
||||||
# 初始化 provider 的字典(如果尚未初始化)
|
|
||||||
if provider_name not in provider:
|
|
||||||
provider[provider_name] = {"url": None, "key": None}
|
|
||||||
|
|
||||||
# 根据键的类型填充 url 或 key
|
|
||||||
if key.endswith("_BASE_URL"):
|
|
||||||
provider[provider_name]["url"] = env_config[key]
|
|
||||||
elif key.endswith("_KEY"):
|
|
||||||
provider[provider_name]["key"] = env_config[key]
|
|
||||||
|
|
||||||
# 检查每个 provider 是否同时存在 url 和 key
|
|
||||||
for provider_name, config in provider.items():
|
|
||||||
if config["url"] is None or config["key"] is None:
|
|
||||||
logger.error(f"provider 内容:{config}\nenv_config 内容:{env_config}")
|
|
||||||
raise ValueError(f"请检查 '{provider_name}' 提供商配置是否丢失 BASE_URL 或 KEY 环境变量")
|
|
||||||
|
|
||||||
def ensure_dirs():
|
def ensure_dirs():
|
||||||
"""确保临时目录和输出目录存在"""
|
"""确保临时目录和输出目录存在"""
|
||||||
@@ -96,11 +56,11 @@ open_ie_doc_lock = Lock()
|
|||||||
shutdown_event = Event()
|
shutdown_event = Event()
|
||||||
|
|
||||||
lpmm_entity_extract_llm = LLMRequest(
|
lpmm_entity_extract_llm = LLMRequest(
|
||||||
model=global_config.model.lpmm_entity_extract,
|
model_set=model_config.model_task_config.lpmm_entity_extract,
|
||||||
request_type="lpmm.entity_extract"
|
request_type="lpmm.entity_extract"
|
||||||
)
|
)
|
||||||
lpmm_rdf_build_llm = LLMRequest(
|
lpmm_rdf_build_llm = LLMRequest(
|
||||||
model=global_config.model.lpmm_rdf_build,
|
model_set=model_config.model_task_config.lpmm_rdf_build,
|
||||||
request_type="lpmm.rdf_build"
|
request_type="lpmm.rdf_build"
|
||||||
)
|
)
|
||||||
def process_single_text(pg_hash, raw_data):
|
def process_single_text(pg_hash, raw_data):
|
||||||
@@ -158,8 +118,6 @@ def main(): # sourcery skip: comprehension-to-generator, extract-method
|
|||||||
# 设置信号处理器
|
# 设置信号处理器
|
||||||
signal.signal(signal.SIGINT, signal_handler)
|
signal.signal(signal.SIGINT, signal_handler)
|
||||||
ensure_dirs() # 确保目录存在
|
ensure_dirs() # 确保目录存在
|
||||||
env_config = {key: os.getenv(key) for key in os.environ}
|
|
||||||
scan_provider(env_config)
|
|
||||||
# 新增用户确认提示
|
# 新增用户确认提示
|
||||||
print("=== 重要操作确认,请认真阅读以下内容哦 ===")
|
print("=== 重要操作确认,请认真阅读以下内容哦 ===")
|
||||||
print("实体提取操作将会花费较多api余额和时间,建议在空闲时段执行。")
|
print("实体提取操作将会花费较多api余额和时间,建议在空闲时段执行。")
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -8,15 +8,15 @@ import traceback
|
|||||||
import io
|
import io
|
||||||
import re
|
import re
|
||||||
import binascii
|
import binascii
|
||||||
|
|
||||||
from typing import Optional, Tuple, List, Any
|
from typing import Optional, Tuple, List, Any
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
|
|
||||||
|
|
||||||
from src.common.database.database_model import Emoji
|
from src.common.database.database_model import Emoji
|
||||||
from src.common.database.database import db as peewee_db
|
from src.common.database.database import db as peewee_db
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config, model_config
|
||||||
from src.chat.utils.utils_image import image_path_to_base64, get_image_manager
|
from src.chat.utils.utils_image import image_path_to_base64, get_image_manager
|
||||||
from src.llm_models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
|
|
||||||
@@ -379,9 +379,9 @@ class EmojiManager:
|
|||||||
|
|
||||||
self._scan_task = None
|
self._scan_task = None
|
||||||
|
|
||||||
self.vlm = LLMRequest(model=global_config.model.vlm, temperature=0.3, max_tokens=1000, request_type="emoji")
|
self.vlm = LLMRequest(model_set=model_config.model_task_config.vlm, request_type="emoji")
|
||||||
self.llm_emotion_judge = LLMRequest(
|
self.llm_emotion_judge = LLMRequest(
|
||||||
model=global_config.model.utils, max_tokens=600, request_type="emoji"
|
model_set=model_config.model_task_config.utils, request_type="emoji"
|
||||||
) # 更高的温度,更少的token(后续可以根据情绪来调整温度)
|
) # 更高的温度,更少的token(后续可以根据情绪来调整温度)
|
||||||
|
|
||||||
self.emoji_num = 0
|
self.emoji_num = 0
|
||||||
@@ -492,6 +492,7 @@ class EmojiManager:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def _levenshtein_distance(self, s1: str, s2: str) -> int:
|
def _levenshtein_distance(self, s1: str, s2: str) -> int:
|
||||||
|
# sourcery skip: simplify-empty-collection-comparison, simplify-len-comparison, simplify-str-len-comparison
|
||||||
"""计算两个字符串的编辑距离
|
"""计算两个字符串的编辑距离
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -629,11 +630,11 @@ class EmojiManager:
|
|||||||
if success:
|
if success:
|
||||||
# 注册成功则跳出循环
|
# 注册成功则跳出循环
|
||||||
break
|
break
|
||||||
else:
|
|
||||||
# 注册失败则删除对应文件
|
# 注册失败则删除对应文件
|
||||||
file_path = os.path.join(EMOJI_DIR, filename)
|
file_path = os.path.join(EMOJI_DIR, filename)
|
||||||
os.remove(file_path)
|
os.remove(file_path)
|
||||||
logger.warning(f"[清理] 删除注册失败的表情包文件: {filename}")
|
logger.warning(f"[清理] 删除注册失败的表情包文件: {filename}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[错误] 扫描表情包目录失败: {str(e)}")
|
logger.error(f"[错误] 扫描表情包目录失败: {str(e)}")
|
||||||
|
|
||||||
@@ -694,6 +695,7 @@ class EmojiManager:
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
async def get_emoji_from_manager(self, emoji_hash: str) -> Optional["MaiEmoji"]:
|
async def get_emoji_from_manager(self, emoji_hash: str) -> Optional["MaiEmoji"]:
|
||||||
|
# sourcery skip: use-next
|
||||||
"""从内存中的 emoji_objects 列表获取表情包
|
"""从内存中的 emoji_objects 列表获取表情包
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
@@ -706,13 +708,45 @@ class EmojiManager:
|
|||||||
if not emoji.is_deleted and emoji.hash == emoji_hash:
|
if not emoji.is_deleted and emoji.hash == emoji_hash:
|
||||||
return emoji
|
return emoji
|
||||||
return None # 如果循环结束还没找到,则返回 None
|
return None # 如果循环结束还没找到,则返回 None
|
||||||
|
|
||||||
|
async def get_emoji_tag_by_hash(self, emoji_hash: str) -> Optional[str]:
|
||||||
|
"""根据哈希值获取已注册表情包的描述
|
||||||
|
|
||||||
|
Args:
|
||||||
|
emoji_hash: 表情包的哈希值
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[str]: 表情包描述,如果未找到则返回None
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 先从内存中查找
|
||||||
|
emoji = await self.get_emoji_from_manager(emoji_hash)
|
||||||
|
if emoji and emoji.emotion:
|
||||||
|
logger.info(f"[缓存命中] 从内存获取表情包描述: {emoji.emotion}...")
|
||||||
|
return ",".join(emoji.emotion)
|
||||||
|
|
||||||
|
# 如果内存中没有,从数据库查找
|
||||||
|
self._ensure_db()
|
||||||
|
try:
|
||||||
|
emoji_record = Emoji.get_or_none(Emoji.emoji_hash == emoji_hash)
|
||||||
|
if emoji_record and emoji_record.emotion:
|
||||||
|
logger.info(f"[缓存命中] 从数据库获取表情包描述: {emoji_record.emotion[:50]}...")
|
||||||
|
return emoji_record.emotion
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"从数据库查询表情包描述时出错: {e}")
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取表情包描述失败 (Hash: {emoji_hash}): {str(e)}")
|
||||||
|
return None
|
||||||
|
|
||||||
async def get_emoji_description_by_hash(self, emoji_hash: str) -> Optional[str]:
|
async def get_emoji_description_by_hash(self, emoji_hash: str) -> Optional[str]:
|
||||||
"""根据哈希值获取已注册表情包的描述
|
"""根据哈希值获取已注册表情包的描述
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
emoji_hash: 表情包的哈希值
|
emoji_hash: 表情包的哈希值
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Optional[str]: 表情包描述,如果未找到则返回None
|
Optional[str]: 表情包描述,如果未找到则返回None
|
||||||
"""
|
"""
|
||||||
@@ -722,7 +756,7 @@ class EmojiManager:
|
|||||||
if emoji and emoji.description:
|
if emoji and emoji.description:
|
||||||
logger.info(f"[缓存命中] 从内存获取表情包描述: {emoji.description[:50]}...")
|
logger.info(f"[缓存命中] 从内存获取表情包描述: {emoji.description[:50]}...")
|
||||||
return emoji.description
|
return emoji.description
|
||||||
|
|
||||||
# 如果内存中没有,从数据库查找
|
# 如果内存中没有,从数据库查找
|
||||||
self._ensure_db()
|
self._ensure_db()
|
||||||
try:
|
try:
|
||||||
@@ -732,9 +766,9 @@ class EmojiManager:
|
|||||||
return emoji_record.description
|
return emoji_record.description
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"从数据库查询表情包描述时出错: {e}")
|
logger.error(f"从数据库查询表情包描述时出错: {e}")
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"获取表情包描述失败 (Hash: {emoji_hash}): {str(e)}")
|
logger.error(f"获取表情包描述失败 (Hash: {emoji_hash}): {str(e)}")
|
||||||
return None
|
return None
|
||||||
@@ -779,6 +813,7 @@ class EmojiManager:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
async def replace_a_emoji(self, new_emoji: "MaiEmoji") -> bool:
|
async def replace_a_emoji(self, new_emoji: "MaiEmoji") -> bool:
|
||||||
|
# sourcery skip: use-getitem-for-re-match-groups
|
||||||
"""替换一个表情包
|
"""替换一个表情包
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -820,7 +855,7 @@ class EmojiManager:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 调用大模型进行决策
|
# 调用大模型进行决策
|
||||||
decision, _ = await self.llm_emotion_judge.generate_response_async(prompt, temperature=0.8)
|
decision, _ = await self.llm_emotion_judge.generate_response_async(prompt, temperature=0.8, max_tokens=600)
|
||||||
logger.info(f"[决策] 结果: {decision}")
|
logger.info(f"[决策] 结果: {decision}")
|
||||||
|
|
||||||
# 解析决策结果
|
# 解析决策结果
|
||||||
@@ -828,9 +863,7 @@ class EmojiManager:
|
|||||||
logger.info("[决策] 不删除任何表情包")
|
logger.info("[决策] 不删除任何表情包")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# 尝试从决策中提取表情包编号
|
if match := re.search(r"删除编号(\d+)", decision):
|
||||||
match = re.search(r"删除编号(\d+)", decision)
|
|
||||||
if match:
|
|
||||||
emoji_index = int(match.group(1)) - 1 # 转换为0-based索引
|
emoji_index = int(match.group(1)) - 1 # 转换为0-based索引
|
||||||
|
|
||||||
# 检查索引是否有效
|
# 检查索引是否有效
|
||||||
@@ -889,6 +922,7 @@ class EmojiManager:
|
|||||||
existing_description = None
|
existing_description = None
|
||||||
try:
|
try:
|
||||||
from src.common.database.database_model import Images
|
from src.common.database.database_model import Images
|
||||||
|
|
||||||
existing_image = Images.get_or_none((Images.emoji_hash == image_hash) & (Images.type == "emoji"))
|
existing_image = Images.get_or_none((Images.emoji_hash == image_hash) & (Images.type == "emoji"))
|
||||||
if existing_image and existing_image.description:
|
if existing_image and existing_image.description:
|
||||||
existing_description = existing_image.description
|
existing_description = existing_image.description
|
||||||
@@ -902,15 +936,21 @@ class EmojiManager:
|
|||||||
logger.info("[优化] 复用已有的详细描述,跳过VLM调用")
|
logger.info("[优化] 复用已有的详细描述,跳过VLM调用")
|
||||||
else:
|
else:
|
||||||
logger.info("[VLM分析] 生成新的详细描述")
|
logger.info("[VLM分析] 生成新的详细描述")
|
||||||
if image_format == "gif" or image_format == "GIF":
|
if image_format in ["gif", "GIF"]:
|
||||||
image_base64 = get_image_manager().transform_gif(image_base64) # type: ignore
|
image_base64 = get_image_manager().transform_gif(image_base64) # type: ignore
|
||||||
if not image_base64:
|
if not image_base64:
|
||||||
raise RuntimeError("GIF表情包转换失败")
|
raise RuntimeError("GIF表情包转换失败")
|
||||||
prompt = "这是一个动态图表情包,每一张图代表了动态图的某一帧,黑色背景代表透明,描述一下表情包表达的情感和内容,描述细节,从互联网梗,meme的角度去分析"
|
prompt = "这是一个动态图表情包,每一张图代表了动态图的某一帧,黑色背景代表透明,描述一下表情包表达的情感和内容,描述细节,从互联网梗,meme的角度去分析"
|
||||||
description, _ = await self.vlm.generate_response_for_image(prompt, image_base64, "jpg")
|
description, _ = await self.vlm.generate_response_for_image(
|
||||||
|
prompt, image_base64, "jpg", temperature=0.3, max_tokens=1000
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
prompt = "这是一个表情包,请详细描述一下表情包所表达的情感和内容,描述细节,从互联网梗,meme的角度去分析"
|
prompt = (
|
||||||
description, _ = await self.vlm.generate_response_for_image(prompt, image_base64, image_format)
|
"这是一个表情包,请详细描述一下表情包所表达的情感和内容,描述细节,从互联网梗,meme的角度去分析"
|
||||||
|
)
|
||||||
|
description, _ = await self.vlm.generate_response_for_image(
|
||||||
|
prompt, image_base64, image_format, temperature=0.3, max_tokens=1000
|
||||||
|
)
|
||||||
|
|
||||||
# 审核表情包
|
# 审核表情包
|
||||||
if global_config.emoji.content_filtration:
|
if global_config.emoji.content_filtration:
|
||||||
@@ -922,7 +962,9 @@ class EmojiManager:
|
|||||||
4. 不要出现5个以上文字
|
4. 不要出现5个以上文字
|
||||||
请回答这个表情包是否满足上述要求,是则回答是,否则回答否,不要出现任何其他内容
|
请回答这个表情包是否满足上述要求,是则回答是,否则回答否,不要出现任何其他内容
|
||||||
'''
|
'''
|
||||||
content, _ = await self.vlm.generate_response_for_image(prompt, image_base64, image_format)
|
content, _ = await self.vlm.generate_response_for_image(
|
||||||
|
prompt, image_base64, image_format, temperature=0.3, max_tokens=1000
|
||||||
|
)
|
||||||
if content == "否":
|
if content == "否":
|
||||||
return "", []
|
return "", []
|
||||||
|
|
||||||
@@ -933,7 +975,9 @@ class EmojiManager:
|
|||||||
你可以关注其幽默和讽刺意味,动用贴吧,微博,小红书的知识,必须从互联网梗,meme的角度去分析
|
你可以关注其幽默和讽刺意味,动用贴吧,微博,小红书的知识,必须从互联网梗,meme的角度去分析
|
||||||
请直接输出描述,不要出现任何其他内容,如果有多个描述,可以用逗号分隔
|
请直接输出描述,不要出现任何其他内容,如果有多个描述,可以用逗号分隔
|
||||||
"""
|
"""
|
||||||
emotions_text, _ = await self.llm_emotion_judge.generate_response_async(emotion_prompt, temperature=0.7)
|
emotions_text, _ = await self.llm_emotion_judge.generate_response_async(
|
||||||
|
emotion_prompt, temperature=0.7, max_tokens=600
|
||||||
|
)
|
||||||
|
|
||||||
# 处理情感列表
|
# 处理情感列表
|
||||||
emotions = [e.strip() for e in emotions_text.split(",") if e.strip()]
|
emotions = [e.strip() for e in emotions_text.split(",") if e.strip()]
|
||||||
|
|||||||
@@ -7,12 +7,12 @@ from datetime import datetime
|
|||||||
from typing import List, Dict, Optional, Any, Tuple
|
from typing import List, Dict, Optional, Any, Tuple
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
from src.common.database.database_model import Expression
|
||||||
from src.llm_models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
from src.config.config import global_config
|
from src.config.config import model_config, global_config
|
||||||
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_random, build_anonymous_messages
|
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat_inclusive, build_anonymous_messages
|
||||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||||
from src.common.database.database_model import Expression
|
|
||||||
|
|
||||||
|
|
||||||
MAX_EXPRESSION_COUNT = 300
|
MAX_EXPRESSION_COUNT = 300
|
||||||
@@ -38,10 +38,9 @@ def init_prompt() -> None:
|
|||||||
|
|
||||||
请从上面这段群聊中概括除了人名为"SELF"之外的人的语言风格
|
请从上面这段群聊中概括除了人名为"SELF"之外的人的语言风格
|
||||||
1. 只考虑文字,不要考虑表情包和图片
|
1. 只考虑文字,不要考虑表情包和图片
|
||||||
2. 不要涉及具体的人名,只考虑语言风格
|
2. 不要涉及具体的人名,但是可以涉及具体名词
|
||||||
3. 语言风格包含特殊内容和情感
|
3. 思考有没有特殊的梗,一并总结成语言风格
|
||||||
4. 思考有没有特殊的梗,一并总结成语言风格
|
4. 例子仅供参考,请严格根据群聊内容总结!!!
|
||||||
5. 例子仅供参考,请严格根据群聊内容总结!!!
|
|
||||||
注意:总结成如下格式的规律,总结的内容要详细,但具有概括性:
|
注意:总结成如下格式的规律,总结的内容要详细,但具有概括性:
|
||||||
例如:当"AAAAA"时,可以"BBBBB", AAAAA代表某个具体的场景,不超过20个字。BBBBB代表对应的语言风格,特定句式或表达方式,不超过20个字。
|
例如:当"AAAAA"时,可以"BBBBB", AAAAA代表某个具体的场景,不超过20个字。BBBBB代表对应的语言风格,特定句式或表达方式,不超过20个字。
|
||||||
|
|
||||||
@@ -51,302 +50,150 @@ def init_prompt() -> None:
|
|||||||
当"想说明某个具体的事实观点,但懒得明说,或者不便明说,或表达一种默契",使用"懂的都懂"
|
当"想说明某个具体的事实观点,但懒得明说,或者不便明说,或表达一种默契",使用"懂的都懂"
|
||||||
当"当涉及游戏相关时,表示意外的夸赞,略带戏谑意味"时,使用"这么强!"
|
当"当涉及游戏相关时,表示意外的夸赞,略带戏谑意味"时,使用"这么强!"
|
||||||
|
|
||||||
请注意:不要总结你自己(SELF)的发言
|
请注意:不要总结你自己(SELF)的发言,尽量保证总结内容的逻辑性
|
||||||
现在请你概括
|
现在请你概括
|
||||||
"""
|
"""
|
||||||
Prompt(learn_style_prompt, "learn_style_prompt")
|
Prompt(learn_style_prompt, "learn_style_prompt")
|
||||||
|
|
||||||
learn_grammar_prompt = """
|
|
||||||
{chat_str}
|
|
||||||
|
|
||||||
请从上面这段群聊中概括除了人名为"SELF"之外的人的语法和句法特点,只考虑纯文字,不要考虑表情包和图片
|
|
||||||
1.不要总结【图片】,【动画表情】,[图片],[动画表情],不总结 表情符号 at @ 回复 和[回复]
|
|
||||||
2.不要涉及具体的人名,只考虑语法和句法特点,
|
|
||||||
3.语法和句法特点要包括,句子长短(具体字数),有何种语病,如何拆分句子。
|
|
||||||
4. 例子仅供参考,请严格根据群聊内容总结!!!
|
|
||||||
总结成如下格式的规律,总结的内容要简洁,不浮夸:
|
|
||||||
当"xxx"时,可以"xxx"
|
|
||||||
|
|
||||||
例如:
|
|
||||||
当"表达观点较复杂"时,使用"省略主语(3-6个字)"的句法
|
|
||||||
当"不用详细说明的一般表达"时,使用"非常简洁的句子"的句法
|
|
||||||
当"需要单纯简单的确认"时,使用"单字或几个字的肯定(1-2个字)"的句法
|
|
||||||
|
|
||||||
注意不要总结你自己(SELF)的发言
|
|
||||||
现在请你概括
|
|
||||||
"""
|
|
||||||
Prompt(learn_grammar_prompt, "learn_grammar_prompt")
|
|
||||||
|
|
||||||
|
|
||||||
class ExpressionLearner:
|
class ExpressionLearner:
|
||||||
def __init__(self) -> None:
|
def __init__(self, chat_id: str) -> None:
|
||||||
# TODO: API-Adapter修改标记
|
|
||||||
self.express_learn_model: LLMRequest = LLMRequest(
|
self.express_learn_model: LLMRequest = LLMRequest(
|
||||||
model=global_config.model.replyer_1,
|
model_set=model_config.model_task_config.replyer, request_type="expression.learner"
|
||||||
temperature=0.3,
|
|
||||||
request_type="expressor.learner",
|
|
||||||
)
|
)
|
||||||
self.llm_model = None
|
self.chat_id = chat_id
|
||||||
self._ensure_expression_directories()
|
self.chat_name = get_chat_manager().get_stream_name(chat_id) or chat_id
|
||||||
self._auto_migrate_json_to_db()
|
|
||||||
self._migrate_old_data_create_date()
|
|
||||||
|
|
||||||
def _ensure_expression_directories(self):
|
|
||||||
"""
|
|
||||||
确保表达方式相关的目录结构存在
|
|
||||||
"""
|
|
||||||
base_dir = os.path.join("data", "expression")
|
|
||||||
directories_to_create = [
|
|
||||||
base_dir,
|
|
||||||
os.path.join(base_dir, "learnt_style"),
|
|
||||||
os.path.join(base_dir, "learnt_grammar"),
|
|
||||||
]
|
|
||||||
|
|
||||||
for directory in directories_to_create:
|
# 维护每个chat的上次学习时间
|
||||||
try:
|
self.last_learning_time: float = time.time()
|
||||||
os.makedirs(directory, exist_ok=True)
|
|
||||||
logger.debug(f"确保目录存在: {directory}")
|
# 学习参数
|
||||||
except Exception as e:
|
self.min_messages_for_learning = 25 # 触发学习所需的最少消息数
|
||||||
logger.error(f"创建目录失败 {directory}: {e}")
|
self.min_learning_interval = 300 # 最短学习时间间隔(秒)
|
||||||
|
|
||||||
def _auto_migrate_json_to_db(self):
|
|
||||||
|
|
||||||
|
|
||||||
|
def can_learn_for_chat(self) -> bool:
|
||||||
"""
|
"""
|
||||||
自动将/data/expression/learnt_style 和 learnt_grammar 下所有expressions.json迁移到数据库。
|
检查指定聊天流是否允许学习表达
|
||||||
迁移完成后在/data/expression/done.done写入标记文件,存在则跳过。
|
|
||||||
"""
|
|
||||||
base_dir = os.path.join("data", "expression")
|
|
||||||
done_flag = os.path.join(base_dir, "done.done")
|
|
||||||
|
|
||||||
# 确保基础目录存在
|
Args:
|
||||||
try:
|
chat_id: 聊天流ID
|
||||||
os.makedirs(base_dir, exist_ok=True)
|
|
||||||
logger.debug(f"确保目录存在: {base_dir}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"创建表达方式目录失败: {e}")
|
|
||||||
return
|
|
||||||
|
|
||||||
if os.path.exists(done_flag):
|
|
||||||
logger.info("表达方式JSON已迁移,无需重复迁移。")
|
|
||||||
return
|
|
||||||
|
|
||||||
logger.info("开始迁移表达方式JSON到数据库...")
|
Returns:
|
||||||
migrated_count = 0
|
bool: 是否允许学习
|
||||||
|
|
||||||
for type in ["learnt_style", "learnt_grammar"]:
|
|
||||||
type_str = "style" if type == "learnt_style" else "grammar"
|
|
||||||
type_dir = os.path.join(base_dir, type)
|
|
||||||
if not os.path.exists(type_dir):
|
|
||||||
logger.debug(f"目录不存在,跳过: {type_dir}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
try:
|
|
||||||
chat_ids = os.listdir(type_dir)
|
|
||||||
logger.debug(f"在 {type_dir} 中找到 {len(chat_ids)} 个聊天ID目录")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"读取目录失败 {type_dir}: {e}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
for chat_id in chat_ids:
|
|
||||||
expr_file = os.path.join(type_dir, chat_id, "expressions.json")
|
|
||||||
if not os.path.exists(expr_file):
|
|
||||||
continue
|
|
||||||
try:
|
|
||||||
with open(expr_file, "r", encoding="utf-8") as f:
|
|
||||||
expressions = json.load(f)
|
|
||||||
|
|
||||||
if not isinstance(expressions, list):
|
|
||||||
logger.warning(f"表达方式文件格式错误,跳过: {expr_file}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
for expr in expressions:
|
|
||||||
if not isinstance(expr, dict):
|
|
||||||
continue
|
|
||||||
|
|
||||||
situation = expr.get("situation")
|
|
||||||
style_val = expr.get("style")
|
|
||||||
count = expr.get("count", 1)
|
|
||||||
last_active_time = expr.get("last_active_time", time.time())
|
|
||||||
|
|
||||||
if not situation or not style_val:
|
|
||||||
logger.warning(f"表达方式缺少必要字段,跳过: {expr}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 查重:同chat_id+type+situation+style
|
|
||||||
from src.common.database.database_model import Expression
|
|
||||||
|
|
||||||
query = Expression.select().where(
|
|
||||||
(Expression.chat_id == chat_id)
|
|
||||||
& (Expression.type == type_str)
|
|
||||||
& (Expression.situation == situation)
|
|
||||||
& (Expression.style == style_val)
|
|
||||||
)
|
|
||||||
if query.exists():
|
|
||||||
expr_obj = query.get()
|
|
||||||
expr_obj.count = max(expr_obj.count, count)
|
|
||||||
expr_obj.last_active_time = max(expr_obj.last_active_time, last_active_time)
|
|
||||||
expr_obj.save()
|
|
||||||
else:
|
|
||||||
Expression.create(
|
|
||||||
situation=situation,
|
|
||||||
style=style_val,
|
|
||||||
count=count,
|
|
||||||
last_active_time=last_active_time,
|
|
||||||
chat_id=chat_id,
|
|
||||||
type=type_str,
|
|
||||||
create_date=last_active_time, # 迁移时使用last_active_time作为创建时间
|
|
||||||
)
|
|
||||||
migrated_count += 1
|
|
||||||
logger.info(f"已迁移 {expr_file} 到数据库,包含 {len(expressions)} 个表达方式")
|
|
||||||
except json.JSONDecodeError as e:
|
|
||||||
logger.error(f"JSON解析失败 {expr_file}: {e}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"迁移表达方式 {expr_file} 失败: {e}")
|
|
||||||
|
|
||||||
# 标记迁移完成
|
|
||||||
try:
|
|
||||||
# 确保done.done文件的父目录存在
|
|
||||||
done_parent_dir = os.path.dirname(done_flag)
|
|
||||||
if not os.path.exists(done_parent_dir):
|
|
||||||
os.makedirs(done_parent_dir, exist_ok=True)
|
|
||||||
logger.debug(f"为done.done创建父目录: {done_parent_dir}")
|
|
||||||
|
|
||||||
with open(done_flag, "w", encoding="utf-8") as f:
|
|
||||||
f.write("done\n")
|
|
||||||
logger.info(f"表达方式JSON迁移已完成,共迁移 {migrated_count} 个表达方式,已写入done.done标记文件")
|
|
||||||
except PermissionError as e:
|
|
||||||
logger.error(f"权限不足,无法写入done.done标记文件: {e}")
|
|
||||||
except OSError as e:
|
|
||||||
logger.error(f"文件系统错误,无法写入done.done标记文件: {e}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"写入done.done标记文件失败: {e}")
|
|
||||||
|
|
||||||
def _migrate_old_data_create_date(self):
|
|
||||||
"""
|
|
||||||
为没有create_date的老数据设置创建日期
|
|
||||||
使用last_active_time作为create_date的默认值
|
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# 查找所有create_date为空的表达方式
|
use_expression, enable_learning, _ = global_config.expression.get_expression_config_for_chat(self.chat_id)
|
||||||
old_expressions = Expression.select().where(Expression.create_date.is_null())
|
return enable_learning
|
||||||
updated_count = 0
|
|
||||||
|
|
||||||
for expr in old_expressions:
|
|
||||||
# 使用last_active_time作为create_date
|
|
||||||
expr.create_date = expr.last_active_time
|
|
||||||
expr.save()
|
|
||||||
updated_count += 1
|
|
||||||
|
|
||||||
if updated_count > 0:
|
|
||||||
logger.info(f"已为 {updated_count} 个老的表达方式设置创建日期")
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"迁移老数据创建日期失败: {e}")
|
logger.error(f"检查学习权限失败: {e}")
|
||||||
|
|
||||||
def get_expression_by_chat_id(self, chat_id: str) -> Tuple[List[Dict[str, float]], List[Dict[str, float]]]:
|
|
||||||
"""
|
|
||||||
获取指定chat_id的style和grammar表达方式
|
|
||||||
返回的每个表达方式字典中都包含了source_id, 用于后续的更新操作
|
|
||||||
"""
|
|
||||||
learnt_style_expressions = []
|
|
||||||
learnt_grammar_expressions = []
|
|
||||||
|
|
||||||
# 直接从数据库查询
|
|
||||||
style_query = Expression.select().where((Expression.chat_id == chat_id) & (Expression.type == "style"))
|
|
||||||
for expr in style_query:
|
|
||||||
# 确保create_date存在,如果不存在则使用last_active_time
|
|
||||||
create_date = expr.create_date if expr.create_date is not None else expr.last_active_time
|
|
||||||
learnt_style_expressions.append(
|
|
||||||
{
|
|
||||||
"situation": expr.situation,
|
|
||||||
"style": expr.style,
|
|
||||||
"count": expr.count,
|
|
||||||
"last_active_time": expr.last_active_time,
|
|
||||||
"source_id": chat_id,
|
|
||||||
"type": "style",
|
|
||||||
"create_date": create_date,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
grammar_query = Expression.select().where((Expression.chat_id == chat_id) & (Expression.type == "grammar"))
|
|
||||||
for expr in grammar_query:
|
|
||||||
# 确保create_date存在,如果不存在则使用last_active_time
|
|
||||||
create_date = expr.create_date if expr.create_date is not None else expr.last_active_time
|
|
||||||
learnt_grammar_expressions.append(
|
|
||||||
{
|
|
||||||
"situation": expr.situation,
|
|
||||||
"style": expr.style,
|
|
||||||
"count": expr.count,
|
|
||||||
"last_active_time": expr.last_active_time,
|
|
||||||
"source_id": chat_id,
|
|
||||||
"type": "grammar",
|
|
||||||
"create_date": create_date,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
return learnt_style_expressions, learnt_grammar_expressions
|
|
||||||
|
|
||||||
def get_expression_create_info(self, chat_id: str, limit: int = 10) -> List[Dict[str, Any]]:
|
|
||||||
"""
|
|
||||||
获取指定chat_id的表达方式创建信息,按创建日期排序
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
expressions = (Expression.select()
|
|
||||||
.where(Expression.chat_id == chat_id)
|
|
||||||
.order_by(Expression.create_date.desc())
|
|
||||||
.limit(limit))
|
|
||||||
|
|
||||||
result = []
|
|
||||||
for expr in expressions:
|
|
||||||
create_date = expr.create_date if expr.create_date is not None else expr.last_active_time
|
|
||||||
result.append({
|
|
||||||
"situation": expr.situation,
|
|
||||||
"style": expr.style,
|
|
||||||
"type": expr.type,
|
|
||||||
"count": expr.count,
|
|
||||||
"create_date": create_date,
|
|
||||||
"create_date_formatted": format_create_date(create_date),
|
|
||||||
"last_active_time": expr.last_active_time,
|
|
||||||
"last_active_formatted": format_create_date(expr.last_active_time),
|
|
||||||
})
|
|
||||||
|
|
||||||
return result
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"获取表达方式创建信息失败: {e}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
def is_similar(self, s1: str, s2: str) -> bool:
|
|
||||||
"""
|
|
||||||
判断两个字符串是否相似(只考虑长度大于5且有80%以上重合,不考虑子串)
|
|
||||||
"""
|
|
||||||
if not s1 or not s2:
|
|
||||||
return False
|
return False
|
||||||
min_len = min(len(s1), len(s2))
|
|
||||||
if min_len < 5:
|
|
||||||
return False
|
|
||||||
same = sum(a == b for a, b in zip(s1, s2, strict=False))
|
|
||||||
return same / min_len > 0.8
|
|
||||||
|
|
||||||
async def learn_and_store_expression(self) -> Tuple[List[Tuple[str, str, str]], List[Tuple[str, str, str]]]:
|
def should_trigger_learning(self) -> bool:
|
||||||
"""
|
"""
|
||||||
学习并存储表达方式,分别学习语言风格和句法特点
|
检查是否应该触发学习
|
||||||
同时对所有已存储的表达方式进行全局衰减
|
|
||||||
|
Args:
|
||||||
|
chat_id: 聊天流ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 是否应该触发学习
|
||||||
"""
|
"""
|
||||||
current_time = time.time()
|
current_time = time.time()
|
||||||
|
|
||||||
# 全局衰减所有已存储的表达方式(直接操作数据库)
|
# 获取该聊天流的学习强度
|
||||||
self._apply_global_decay_to_database(current_time)
|
try:
|
||||||
|
_, enable_learning, learning_intensity = global_config.expression.get_expression_config_for_chat(self.chat_id)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取聊天流 {self.chat_id} 的学习配置失败: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 检查是否允许学习
|
||||||
|
if not enable_learning:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 根据学习强度计算最短学习时间间隔
|
||||||
|
min_interval = self.min_learning_interval / learning_intensity
|
||||||
|
|
||||||
|
# 检查时间间隔
|
||||||
|
time_diff = current_time - self.last_learning_time
|
||||||
|
if time_diff < min_interval:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 检查消息数量(只检查指定聊天流的消息)
|
||||||
|
recent_messages = get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||||
|
chat_id=self.chat_id,
|
||||||
|
timestamp_start=self.last_learning_time,
|
||||||
|
timestamp_end=time.time(),
|
||||||
|
)
|
||||||
|
|
||||||
learnt_style: Optional[List[Tuple[str, str, str]]] = []
|
if not recent_messages or len(recent_messages) < self.min_messages_for_learning:
|
||||||
learnt_grammar: Optional[List[Tuple[str, str, str]]] = []
|
return False
|
||||||
# 学习新的表达方式(这里会进行局部衰减)
|
|
||||||
for _ in range(3):
|
return True
|
||||||
learnt_style = await self.learn_and_store(type="style", num=25)
|
|
||||||
if not learnt_style:
|
async def trigger_learning_for_chat(self) -> bool:
|
||||||
return [], []
|
"""
|
||||||
|
为指定聊天流触发学习
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chat_id: 聊天流ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 是否成功触发学习
|
||||||
|
"""
|
||||||
|
if not self.should_trigger_learning():
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
logger.info(f"为聊天流 {self.chat_name} 触发表达学习")
|
||||||
|
|
||||||
|
# 学习语言风格
|
||||||
|
learnt_style = await self.learn_and_store(num=25)
|
||||||
|
|
||||||
|
# 更新学习时间
|
||||||
|
self.last_learning_time = time.time()
|
||||||
|
|
||||||
|
if learnt_style:
|
||||||
|
logger.info(f"聊天流 {self.chat_name} 表达学习完成")
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
logger.warning(f"聊天流 {self.chat_name} 表达学习未获得有效结果")
|
||||||
|
return False
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"为聊天流 {self.chat_name} 触发学习失败: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# def get_expression_by_chat_id(self) -> Tuple[List[Dict[str, float]], List[Dict[str, float]]]:
|
||||||
|
# """
|
||||||
|
# 获取指定chat_id的style表达方式(已禁用grammar的获取)
|
||||||
|
# 返回的每个表达方式字典中都包含了source_id, 用于后续的更新操作
|
||||||
|
# """
|
||||||
|
# learnt_style_expressions = []
|
||||||
|
|
||||||
|
# # 直接从数据库查询
|
||||||
|
# style_query = Expression.select().where((Expression.chat_id == self.chat_id) & (Expression.type == "style"))
|
||||||
|
# for expr in style_query:
|
||||||
|
# # 确保create_date存在,如果不存在则使用last_active_time
|
||||||
|
# create_date = expr.create_date if expr.create_date is not None else expr.last_active_time
|
||||||
|
# learnt_style_expressions.append(
|
||||||
|
# {
|
||||||
|
# "situation": expr.situation,
|
||||||
|
# "style": expr.style,
|
||||||
|
# "count": expr.count,
|
||||||
|
# "last_active_time": expr.last_active_time,
|
||||||
|
# "source_id": self.chat_id,
|
||||||
|
# "type": "style",
|
||||||
|
# "create_date": create_date,
|
||||||
|
# }
|
||||||
|
# )
|
||||||
|
# return learnt_style_expressions
|
||||||
|
|
||||||
for _ in range(1):
|
|
||||||
learnt_grammar = await self.learn_and_store(type="grammar", num=10)
|
|
||||||
if not learnt_grammar:
|
|
||||||
return [], []
|
|
||||||
|
|
||||||
return learnt_style, learnt_grammar
|
|
||||||
|
|
||||||
def _apply_global_decay_to_database(self, current_time: float) -> None:
|
def _apply_global_decay_to_database(self, current_time: float) -> None:
|
||||||
"""
|
"""
|
||||||
@@ -355,19 +202,19 @@ class ExpressionLearner:
|
|||||||
try:
|
try:
|
||||||
# 获取所有表达方式
|
# 获取所有表达方式
|
||||||
all_expressions = Expression.select()
|
all_expressions = Expression.select()
|
||||||
|
|
||||||
updated_count = 0
|
updated_count = 0
|
||||||
deleted_count = 0
|
deleted_count = 0
|
||||||
|
|
||||||
for expr in all_expressions:
|
for expr in all_expressions:
|
||||||
# 计算时间差
|
# 计算时间差
|
||||||
last_active = expr.last_active_time
|
last_active = expr.last_active_time
|
||||||
time_diff_days = (current_time - last_active) / (24 * 3600) # 转换为天
|
time_diff_days = (current_time - last_active) / (24 * 3600) # 转换为天
|
||||||
|
|
||||||
# 计算衰减值
|
# 计算衰减值
|
||||||
decay_value = self.calculate_decay_factor(time_diff_days)
|
decay_value = self.calculate_decay_factor(time_diff_days)
|
||||||
new_count = max(0.01, expr.count - decay_value)
|
new_count = max(0.01, expr.count - decay_value)
|
||||||
|
|
||||||
if new_count <= 0.01:
|
if new_count <= 0.01:
|
||||||
# 如果count太小,删除这个表达方式
|
# 如果count太小,删除这个表达方式
|
||||||
expr.delete_instance()
|
expr.delete_instance()
|
||||||
@@ -377,10 +224,10 @@ class ExpressionLearner:
|
|||||||
expr.count = new_count
|
expr.count = new_count
|
||||||
expr.save()
|
expr.save()
|
||||||
updated_count += 1
|
updated_count += 1
|
||||||
|
|
||||||
if updated_count > 0 or deleted_count > 0:
|
if updated_count > 0 or deleted_count > 0:
|
||||||
logger.info(f"全局衰减完成:更新了 {updated_count} 个表达方式,删除了 {deleted_count} 个表达方式")
|
logger.info(f"全局衰减完成:更新了 {updated_count} 个表达方式,删除了 {deleted_count} 个表达方式")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"数据库全局衰减失败: {e}")
|
logger.error(f"数据库全局衰减失败: {e}")
|
||||||
|
|
||||||
@@ -406,20 +253,16 @@ class ExpressionLearner:
|
|||||||
|
|
||||||
return min(0.01, decay)
|
return min(0.01, decay)
|
||||||
|
|
||||||
async def learn_and_store(self, type: str, num: int = 10) -> List[Tuple[str, str, str]]:
|
async def learn_and_store(self, num: int = 10) -> List[Tuple[str, str, str]]:
|
||||||
# sourcery skip: use-join
|
|
||||||
"""
|
"""
|
||||||
选择从当前到最近1小时内的随机num条消息,然后学习这些消息的表达方式
|
学习并存储表达方式
|
||||||
type: "style" or "grammar"
|
|
||||||
"""
|
"""
|
||||||
if type == "style":
|
# 检查是否允许在此聊天流中学习(在函数最前面检查)
|
||||||
type_str = "语言风格"
|
if not self.can_learn_for_chat():
|
||||||
elif type == "grammar":
|
logger.debug(f"聊天流 {self.chat_name} 不允许学习表达,跳过学习")
|
||||||
type_str = "句法特点"
|
return []
|
||||||
else:
|
|
||||||
raise ValueError(f"Invalid type: {type}")
|
|
||||||
|
|
||||||
res = await self.learn_expression(type, num)
|
res = await self.learn_expression(num)
|
||||||
|
|
||||||
if res is None:
|
if res is None:
|
||||||
return []
|
return []
|
||||||
@@ -435,10 +278,10 @@ class ExpressionLearner:
|
|||||||
learnt_expressions_str = ""
|
learnt_expressions_str = ""
|
||||||
for _chat_id, situation, style in learnt_expressions:
|
for _chat_id, situation, style in learnt_expressions:
|
||||||
learnt_expressions_str += f"{situation}->{style}\n"
|
learnt_expressions_str += f"{situation}->{style}\n"
|
||||||
logger.info(f"在 {group_name} 学习到{type_str}:\n{learnt_expressions_str}")
|
logger.info(f"在 {group_name} 学习到表达风格:\n{learnt_expressions_str}")
|
||||||
|
|
||||||
if not learnt_expressions:
|
if not learnt_expressions:
|
||||||
logger.info(f"没有学习到{type_str}")
|
logger.info("没有学习到表达风格")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
# 按chat_id分组
|
# 按chat_id分组
|
||||||
@@ -456,7 +299,7 @@ class ExpressionLearner:
|
|||||||
# 查找是否已存在相似表达方式
|
# 查找是否已存在相似表达方式
|
||||||
query = Expression.select().where(
|
query = Expression.select().where(
|
||||||
(Expression.chat_id == chat_id)
|
(Expression.chat_id == chat_id)
|
||||||
& (Expression.type == type)
|
& (Expression.type == "style")
|
||||||
& (Expression.situation == new_expr["situation"])
|
& (Expression.situation == new_expr["situation"])
|
||||||
& (Expression.style == new_expr["style"])
|
& (Expression.style == new_expr["style"])
|
||||||
)
|
)
|
||||||
@@ -476,13 +319,13 @@ class ExpressionLearner:
|
|||||||
count=1,
|
count=1,
|
||||||
last_active_time=current_time,
|
last_active_time=current_time,
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
type=type,
|
type="style",
|
||||||
create_date=current_time, # 手动设置创建日期
|
create_date=current_time, # 手动设置创建日期
|
||||||
)
|
)
|
||||||
# 限制最大数量
|
# 限制最大数量
|
||||||
exprs = list(
|
exprs = list(
|
||||||
Expression.select()
|
Expression.select()
|
||||||
.where((Expression.chat_id == chat_id) & (Expression.type == type))
|
.where((Expression.chat_id == chat_id) & (Expression.type == "style"))
|
||||||
.order_by(Expression.count.asc())
|
.order_by(Expression.count.asc())
|
||||||
)
|
)
|
||||||
if len(exprs) > MAX_EXPRESSION_COUNT:
|
if len(exprs) > MAX_EXPRESSION_COUNT:
|
||||||
@@ -491,25 +334,25 @@ class ExpressionLearner:
|
|||||||
expr.delete_instance()
|
expr.delete_instance()
|
||||||
return learnt_expressions
|
return learnt_expressions
|
||||||
|
|
||||||
async def learn_expression(self, type: str, num: int = 10) -> Optional[Tuple[List[Tuple[str, str, str]], str]]:
|
async def learn_expression(self, num: int = 10) -> Optional[Tuple[List[Tuple[str, str, str]], str]]:
|
||||||
"""选择从当前到最近1小时内的随机num条消息,然后学习这些消息的表达方式
|
"""从指定聊天流学习表达方式
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
type: "style" or "grammar"
|
num: 学习数量
|
||||||
"""
|
"""
|
||||||
if type == "style":
|
type_str = "语言风格"
|
||||||
type_str = "语言风格"
|
prompt = "learn_style_prompt"
|
||||||
prompt = "learn_style_prompt"
|
|
||||||
elif type == "grammar":
|
|
||||||
type_str = "句法特点"
|
|
||||||
prompt = "learn_grammar_prompt"
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Invalid type: {type}")
|
|
||||||
|
|
||||||
current_time = time.time()
|
current_time = time.time()
|
||||||
random_msg: Optional[List[Dict[str, Any]]] = get_raw_msg_by_timestamp_random(
|
|
||||||
current_time - 3600 * 24, current_time, limit=num
|
# 获取上次学习时间
|
||||||
|
random_msg: Optional[List[Dict[str, Any]]] = get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||||
|
chat_id=self.chat_id,
|
||||||
|
timestamp_start=self.last_learning_time,
|
||||||
|
timestamp_end=current_time,
|
||||||
|
limit=num,
|
||||||
)
|
)
|
||||||
|
|
||||||
# print(random_msg)
|
# print(random_msg)
|
||||||
if not random_msg or random_msg == []:
|
if not random_msg or random_msg == []:
|
||||||
return None
|
return None
|
||||||
@@ -527,7 +370,7 @@ class ExpressionLearner:
|
|||||||
logger.debug(f"学习{type_str}的prompt: {prompt}")
|
logger.debug(f"学习{type_str}的prompt: {prompt}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response, _ = await self.express_learn_model.generate_response_async(prompt)
|
response, _ = await self.express_learn_model.generate_response_async(prompt, temperature=0.3)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"学习{type_str}失败: {e}")
|
logger.error(f"学习{type_str}失败: {e}")
|
||||||
return None
|
return None
|
||||||
@@ -571,12 +414,221 @@ class ExpressionLearner:
|
|||||||
|
|
||||||
init_prompt()
|
init_prompt()
|
||||||
|
|
||||||
|
class ExpressionLearnerManager:
|
||||||
|
def __init__(self):
|
||||||
|
self.expression_learners = {}
|
||||||
|
|
||||||
|
self._ensure_expression_directories()
|
||||||
|
self._auto_migrate_json_to_db()
|
||||||
|
self._migrate_old_data_create_date()
|
||||||
|
|
||||||
|
def get_expression_learner(self, chat_id: str) -> ExpressionLearner:
|
||||||
|
if chat_id not in self.expression_learners:
|
||||||
|
self.expression_learners[chat_id] = ExpressionLearner(chat_id)
|
||||||
|
return self.expression_learners[chat_id]
|
||||||
|
|
||||||
|
def _ensure_expression_directories(self):
|
||||||
|
"""
|
||||||
|
确保表达方式相关的目录结构存在
|
||||||
|
"""
|
||||||
|
base_dir = os.path.join("data", "expression")
|
||||||
|
directories_to_create = [
|
||||||
|
base_dir,
|
||||||
|
os.path.join(base_dir, "learnt_style"),
|
||||||
|
os.path.join(base_dir, "learnt_grammar"),
|
||||||
|
]
|
||||||
|
|
||||||
expression_learner = None
|
for directory in directories_to_create:
|
||||||
|
try:
|
||||||
|
os.makedirs(directory, exist_ok=True)
|
||||||
|
logger.debug(f"确保目录存在: {directory}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"创建目录失败 {directory}: {e}")
|
||||||
|
|
||||||
|
|
||||||
def get_expression_learner():
|
def _auto_migrate_json_to_db(self):
|
||||||
global expression_learner
|
"""
|
||||||
if expression_learner is None:
|
自动将/data/expression/learnt_style 和 learnt_grammar 下所有expressions.json迁移到数据库。
|
||||||
expression_learner = ExpressionLearner()
|
迁移完成后在/data/expression/done.done写入标记文件,存在则跳过。
|
||||||
return expression_learner
|
然后检查done.done2,如果没有就删除所有grammar表达并创建该标记文件。
|
||||||
|
"""
|
||||||
|
base_dir = os.path.join("data", "expression")
|
||||||
|
done_flag = os.path.join(base_dir, "done.done")
|
||||||
|
done_flag2 = os.path.join(base_dir, "done.done2")
|
||||||
|
|
||||||
|
# 确保基础目录存在
|
||||||
|
try:
|
||||||
|
os.makedirs(base_dir, exist_ok=True)
|
||||||
|
logger.debug(f"确保目录存在: {base_dir}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"创建表达方式目录失败: {e}")
|
||||||
|
return
|
||||||
|
|
||||||
|
if os.path.exists(done_flag):
|
||||||
|
logger.info("表达方式JSON已迁移,无需重复迁移。")
|
||||||
|
else:
|
||||||
|
logger.info("开始迁移表达方式JSON到数据库...")
|
||||||
|
migrated_count = 0
|
||||||
|
|
||||||
|
for type in ["learnt_style", "learnt_grammar"]:
|
||||||
|
type_str = "style" if type == "learnt_style" else "grammar"
|
||||||
|
type_dir = os.path.join(base_dir, type)
|
||||||
|
if not os.path.exists(type_dir):
|
||||||
|
logger.debug(f"目录不存在,跳过: {type_dir}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
chat_ids = os.listdir(type_dir)
|
||||||
|
logger.debug(f"在 {type_dir} 中找到 {len(chat_ids)} 个聊天ID目录")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"读取目录失败 {type_dir}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
for chat_id in chat_ids:
|
||||||
|
expr_file = os.path.join(type_dir, chat_id, "expressions.json")
|
||||||
|
if not os.path.exists(expr_file):
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
with open(expr_file, "r", encoding="utf-8") as f:
|
||||||
|
expressions = json.load(f)
|
||||||
|
|
||||||
|
if not isinstance(expressions, list):
|
||||||
|
logger.warning(f"表达方式文件格式错误,跳过: {expr_file}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
for expr in expressions:
|
||||||
|
if not isinstance(expr, dict):
|
||||||
|
continue
|
||||||
|
|
||||||
|
situation = expr.get("situation")
|
||||||
|
style_val = expr.get("style")
|
||||||
|
count = expr.get("count", 1)
|
||||||
|
last_active_time = expr.get("last_active_time", time.time())
|
||||||
|
|
||||||
|
if not situation or not style_val:
|
||||||
|
logger.warning(f"表达方式缺少必要字段,跳过: {expr}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 查重:同chat_id+type+situation+style
|
||||||
|
from src.common.database.database_model import Expression
|
||||||
|
|
||||||
|
query = Expression.select().where(
|
||||||
|
(Expression.chat_id == chat_id)
|
||||||
|
& (Expression.type == type_str)
|
||||||
|
& (Expression.situation == situation)
|
||||||
|
& (Expression.style == style_val)
|
||||||
|
)
|
||||||
|
if query.exists():
|
||||||
|
expr_obj = query.get()
|
||||||
|
expr_obj.count = max(expr_obj.count, count)
|
||||||
|
expr_obj.last_active_time = max(expr_obj.last_active_time, last_active_time)
|
||||||
|
expr_obj.save()
|
||||||
|
else:
|
||||||
|
Expression.create(
|
||||||
|
situation=situation,
|
||||||
|
style=style_val,
|
||||||
|
count=count,
|
||||||
|
last_active_time=last_active_time,
|
||||||
|
chat_id=chat_id,
|
||||||
|
type=type_str,
|
||||||
|
create_date=last_active_time, # 迁移时使用last_active_time作为创建时间
|
||||||
|
)
|
||||||
|
migrated_count += 1
|
||||||
|
logger.info(f"已迁移 {expr_file} 到数据库,包含 {len(expressions)} 个表达方式")
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
logger.error(f"JSON解析失败 {expr_file}: {e}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"迁移表达方式 {expr_file} 失败: {e}")
|
||||||
|
|
||||||
|
# 标记迁移完成
|
||||||
|
try:
|
||||||
|
# 确保done.done文件的父目录存在
|
||||||
|
done_parent_dir = os.path.dirname(done_flag)
|
||||||
|
if not os.path.exists(done_parent_dir):
|
||||||
|
os.makedirs(done_parent_dir, exist_ok=True)
|
||||||
|
logger.debug(f"为done.done创建父目录: {done_parent_dir}")
|
||||||
|
|
||||||
|
with open(done_flag, "w", encoding="utf-8") as f:
|
||||||
|
f.write("done\n")
|
||||||
|
logger.info(f"表达方式JSON迁移已完成,共迁移 {migrated_count} 个表达方式,已写入done.done标记文件")
|
||||||
|
except PermissionError as e:
|
||||||
|
logger.error(f"权限不足,无法写入done.done标记文件: {e}")
|
||||||
|
except OSError as e:
|
||||||
|
logger.error(f"文件系统错误,无法写入done.done标记文件: {e}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"写入done.done标记文件失败: {e}")
|
||||||
|
|
||||||
|
# 检查并处理grammar表达删除
|
||||||
|
if not os.path.exists(done_flag2):
|
||||||
|
logger.info("开始删除所有grammar类型的表达...")
|
||||||
|
try:
|
||||||
|
deleted_count = self.delete_all_grammar_expressions()
|
||||||
|
logger.info(f"grammar表达删除完成,共删除 {deleted_count} 个表达")
|
||||||
|
|
||||||
|
# 创建done.done2标记文件
|
||||||
|
with open(done_flag2, "w", encoding="utf-8") as f:
|
||||||
|
f.write("done\n")
|
||||||
|
logger.info("已创建done.done2标记文件,grammar表达删除标记完成")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"删除grammar表达或创建标记文件失败: {e}")
|
||||||
|
else:
|
||||||
|
logger.info("grammar表达已删除,跳过重复删除")
|
||||||
|
|
||||||
|
def _migrate_old_data_create_date(self):
|
||||||
|
"""
|
||||||
|
为没有create_date的老数据设置创建日期
|
||||||
|
使用last_active_time作为create_date的默认值
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 查找所有create_date为空的表达方式
|
||||||
|
old_expressions = Expression.select().where(Expression.create_date.is_null())
|
||||||
|
updated_count = 0
|
||||||
|
|
||||||
|
for expr in old_expressions:
|
||||||
|
# 使用last_active_time作为create_date
|
||||||
|
expr.create_date = expr.last_active_time
|
||||||
|
expr.save()
|
||||||
|
updated_count += 1
|
||||||
|
|
||||||
|
if updated_count > 0:
|
||||||
|
logger.info(f"已为 {updated_count} 个老的表达方式设置创建日期")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"迁移老数据创建日期失败: {e}")
|
||||||
|
|
||||||
|
def delete_all_grammar_expressions(self) -> int:
|
||||||
|
"""
|
||||||
|
检查expression库中所有type为"grammar"的表达并全部删除
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: 删除的grammar表达数量
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 查询所有type为"grammar"的表达
|
||||||
|
grammar_expressions = Expression.select().where(Expression.type == "grammar")
|
||||||
|
grammar_count = grammar_expressions.count()
|
||||||
|
|
||||||
|
if grammar_count == 0:
|
||||||
|
logger.info("expression库中没有找到grammar类型的表达")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
logger.info(f"找到 {grammar_count} 个grammar类型的表达,开始删除...")
|
||||||
|
|
||||||
|
# 删除所有grammar类型的表达
|
||||||
|
deleted_count = 0
|
||||||
|
for expr in grammar_expressions:
|
||||||
|
try:
|
||||||
|
expr.delete_instance()
|
||||||
|
deleted_count += 1
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"删除grammar表达失败: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
logger.info(f"成功删除 {deleted_count} 个grammar类型的表达")
|
||||||
|
return deleted_count
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"删除grammar表达过程中发生错误: {e}")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
expression_learner_manager = ExpressionLearnerManager()
|
||||||
|
|||||||
@@ -1,16 +1,16 @@
|
|||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
import random
|
import random
|
||||||
|
import hashlib
|
||||||
|
|
||||||
from typing import List, Dict, Tuple, Optional, Any
|
from typing import List, Dict, Optional, Any, Tuple
|
||||||
from json_repair import repair_json
|
from json_repair import repair_json
|
||||||
|
|
||||||
from src.llm_models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config, model_config
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
|
||||||
from .expression_learner import get_expression_learner
|
|
||||||
from src.common.database.database_model import Expression
|
from src.common.database.database_model import Expression
|
||||||
|
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||||
|
|
||||||
logger = get_logger("expression_selector")
|
logger = get_logger("expression_selector")
|
||||||
|
|
||||||
@@ -25,7 +25,7 @@ def init_prompt():
|
|||||||
以下是可选的表达情境:
|
以下是可选的表达情境:
|
||||||
{all_situations}
|
{all_situations}
|
||||||
|
|
||||||
请你分析聊天内容的语境、情绪、话题类型,从上述情境中选择最适合当前聊天情境的{min_num}-{max_num}个情境。
|
请你分析聊天内容的语境、情绪、话题类型,从上述情境中选择最适合当前聊天情境的,最多{max_num}个情境。
|
||||||
考虑因素包括:
|
考虑因素包括:
|
||||||
1. 聊天的情绪氛围(轻松、严肃、幽默等)
|
1. 聊天的情绪氛围(轻松、严肃、幽默等)
|
||||||
2. 话题类型(日常、技术、游戏、情感等)
|
2. 话题类型(日常、技术、游戏、情感等)
|
||||||
@@ -35,11 +35,7 @@ def init_prompt():
|
|||||||
请以JSON格式输出,只需要输出选中的情境编号:
|
请以JSON格式输出,只需要输出选中的情境编号:
|
||||||
例如:
|
例如:
|
||||||
{{
|
{{
|
||||||
"selected_situations": [2, 3, 5, 7, 19, 22, 25, 38, 39, 45, 48 , 64]
|
"selected_situations": [2, 3, 5, 7, 19]
|
||||||
}}
|
|
||||||
例如:
|
|
||||||
{{
|
|
||||||
"selected_situations": [1, 4, 7, 9, 23, 38, 44]
|
|
||||||
}}
|
}}
|
||||||
|
|
||||||
请严格按照JSON格式输出,不要包含其他内容:
|
请严格按照JSON格式输出,不要包含其他内容:
|
||||||
@@ -74,13 +70,27 @@ def weighted_sample(population: List[Dict], weights: List[float], k: int) -> Lis
|
|||||||
|
|
||||||
class ExpressionSelector:
|
class ExpressionSelector:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.expression_learner = get_expression_learner()
|
|
||||||
# TODO: API-Adapter修改标记
|
|
||||||
self.llm_model = LLMRequest(
|
self.llm_model = LLMRequest(
|
||||||
model=global_config.model.utils_small,
|
model_set=model_config.model_task_config.utils_small, request_type="expression.selector"
|
||||||
request_type="expression.selector",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def can_use_expression_for_chat(self, chat_id: str) -> bool:
|
||||||
|
"""
|
||||||
|
检查指定聊天流是否允许使用表达
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chat_id: 聊天流ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 是否允许使用表达
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
use_expression, _, _ = global_config.expression.get_expression_config_for_chat(chat_id)
|
||||||
|
return use_expression
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"检查表达使用权限失败: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _parse_stream_config_to_chat_id(stream_config_str: str) -> Optional[str]:
|
def _parse_stream_config_to_chat_id(stream_config_str: str) -> Optional[str]:
|
||||||
"""解析'platform:id:type'为chat_id(与get_stream_id一致)"""
|
"""解析'platform:id:type'为chat_id(与get_stream_id一致)"""
|
||||||
@@ -92,7 +102,6 @@ class ExpressionSelector:
|
|||||||
id_str = parts[1]
|
id_str = parts[1]
|
||||||
stream_type = parts[2]
|
stream_type = parts[2]
|
||||||
is_group = stream_type == "group"
|
is_group = stream_type == "group"
|
||||||
import hashlib
|
|
||||||
if is_group:
|
if is_group:
|
||||||
components = [platform, str(id_str)]
|
components = [platform, str(id_str)]
|
||||||
else:
|
else:
|
||||||
@@ -108,29 +117,27 @@ class ExpressionSelector:
|
|||||||
for group in groups:
|
for group in groups:
|
||||||
group_chat_ids = []
|
group_chat_ids = []
|
||||||
for stream_config_str in group:
|
for stream_config_str in group:
|
||||||
chat_id_candidate = self._parse_stream_config_to_chat_id(stream_config_str)
|
if chat_id_candidate := self._parse_stream_config_to_chat_id(stream_config_str):
|
||||||
if chat_id_candidate:
|
|
||||||
group_chat_ids.append(chat_id_candidate)
|
group_chat_ids.append(chat_id_candidate)
|
||||||
if chat_id in group_chat_ids:
|
if chat_id in group_chat_ids:
|
||||||
return group_chat_ids
|
return group_chat_ids
|
||||||
return [chat_id]
|
return [chat_id]
|
||||||
|
|
||||||
def get_random_expressions(
|
def get_random_expressions(
|
||||||
self, chat_id: str, total_num: int, style_percentage: float, grammar_percentage: float
|
self, chat_id: str, total_num: int
|
||||||
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
|
) -> List[Dict[str, Any]]:
|
||||||
|
# sourcery skip: extract-duplicate-method, move-assign
|
||||||
# 支持多chat_id合并抽选
|
# 支持多chat_id合并抽选
|
||||||
related_chat_ids = self.get_related_chat_ids(chat_id)
|
related_chat_ids = self.get_related_chat_ids(chat_id)
|
||||||
|
|
||||||
# 优化:一次性查询所有相关chat_id的表达方式
|
# 优化:一次性查询所有相关chat_id的表达方式
|
||||||
style_query = Expression.select().where(
|
style_query = Expression.select().where(
|
||||||
(Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "style")
|
(Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "style")
|
||||||
)
|
)
|
||||||
grammar_query = Expression.select().where(
|
|
||||||
(Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "grammar")
|
|
||||||
)
|
|
||||||
|
|
||||||
style_exprs = [
|
style_exprs = [
|
||||||
{
|
{
|
||||||
|
"id": expr.id,
|
||||||
"situation": expr.situation,
|
"situation": expr.situation,
|
||||||
"style": expr.style,
|
"style": expr.style,
|
||||||
"count": expr.count,
|
"count": expr.count,
|
||||||
@@ -138,35 +145,17 @@ class ExpressionSelector:
|
|||||||
"source_id": expr.chat_id,
|
"source_id": expr.chat_id,
|
||||||
"type": "style",
|
"type": "style",
|
||||||
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
|
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
|
||||||
} for expr in style_query
|
}
|
||||||
|
for expr in style_query
|
||||||
]
|
]
|
||||||
|
|
||||||
grammar_exprs = [
|
|
||||||
{
|
|
||||||
"situation": expr.situation,
|
|
||||||
"style": expr.style,
|
|
||||||
"count": expr.count,
|
|
||||||
"last_active_time": expr.last_active_time,
|
|
||||||
"source_id": expr.chat_id,
|
|
||||||
"type": "grammar",
|
|
||||||
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
|
|
||||||
} for expr in grammar_query
|
|
||||||
]
|
|
||||||
|
|
||||||
style_num = int(total_num * style_percentage)
|
|
||||||
grammar_num = int(total_num * grammar_percentage)
|
|
||||||
# 按权重抽样(使用count作为权重)
|
# 按权重抽样(使用count作为权重)
|
||||||
if style_exprs:
|
if style_exprs:
|
||||||
style_weights = [expr.get("count", 1) for expr in style_exprs]
|
style_weights = [expr.get("count", 1) for expr in style_exprs]
|
||||||
selected_style = weighted_sample(style_exprs, style_weights, style_num)
|
selected_style = weighted_sample(style_exprs, style_weights, total_num)
|
||||||
else:
|
else:
|
||||||
selected_style = []
|
selected_style = []
|
||||||
if grammar_exprs:
|
return selected_style
|
||||||
grammar_weights = [expr.get("count", 1) for expr in grammar_exprs]
|
|
||||||
selected_grammar = weighted_sample(grammar_exprs, grammar_weights, grammar_num)
|
|
||||||
else:
|
|
||||||
selected_grammar = []
|
|
||||||
return selected_style, selected_grammar
|
|
||||||
|
|
||||||
def update_expressions_count_batch(self, expressions_to_update: List[Dict[str, Any]], increment: float = 0.1):
|
def update_expressions_count_batch(self, expressions_to_update: List[Dict[str, Any]], increment: float = 0.1):
|
||||||
"""对一批表达方式更新count值,按chat_id+type分组后一次性写入数据库"""
|
"""对一批表达方式更新count值,按chat_id+type分组后一次性写入数据库"""
|
||||||
@@ -174,22 +163,22 @@ class ExpressionSelector:
|
|||||||
return
|
return
|
||||||
updates_by_key = {}
|
updates_by_key = {}
|
||||||
for expr in expressions_to_update:
|
for expr in expressions_to_update:
|
||||||
source_id = expr.get("source_id")
|
source_id: str = expr.get("source_id") # type: ignore
|
||||||
expr_type = expr.get("type", "style")
|
expr_type: str = expr.get("type", "style")
|
||||||
situation = expr.get("situation")
|
situation: str = expr.get("situation") # type: ignore
|
||||||
style = expr.get("style")
|
style: str = expr.get("style") # type: ignore
|
||||||
if not source_id or not situation or not style:
|
if not source_id or not situation or not style:
|
||||||
logger.warning(f"表达方式缺少必要字段,无法更新: {expr}")
|
logger.warning(f"表达方式缺少必要字段,无法更新: {expr}")
|
||||||
continue
|
continue
|
||||||
key = (source_id, expr_type, situation, style)
|
key = (source_id, expr_type, situation, style)
|
||||||
if key not in updates_by_key:
|
if key not in updates_by_key:
|
||||||
updates_by_key[key] = expr
|
updates_by_key[key] = expr
|
||||||
for (chat_id, expr_type, situation, style), _expr in updates_by_key.items():
|
for chat_id, expr_type, situation, style in updates_by_key:
|
||||||
query = Expression.select().where(
|
query = Expression.select().where(
|
||||||
(Expression.chat_id == chat_id) &
|
(Expression.chat_id == chat_id)
|
||||||
(Expression.type == expr_type) &
|
& (Expression.type == expr_type)
|
||||||
(Expression.situation == situation) &
|
& (Expression.situation == situation)
|
||||||
(Expression.style == style)
|
& (Expression.style == style)
|
||||||
)
|
)
|
||||||
if query.exists():
|
if query.exists():
|
||||||
expr_obj = query.get()
|
expr_obj = query.get()
|
||||||
@@ -207,38 +196,36 @@ class ExpressionSelector:
|
|||||||
chat_id: str,
|
chat_id: str,
|
||||||
chat_info: str,
|
chat_info: str,
|
||||||
max_num: int = 10,
|
max_num: int = 10,
|
||||||
min_num: int = 5,
|
|
||||||
target_message: Optional[str] = None,
|
target_message: Optional[str] = None,
|
||||||
) -> List[Dict[str, Any]]:
|
) -> Tuple[List[Dict[str, Any]], List[int]]:
|
||||||
# sourcery skip: inline-variable, list-comprehension
|
# sourcery skip: inline-variable, list-comprehension
|
||||||
"""使用LLM选择适合的表达方式"""
|
"""使用LLM选择适合的表达方式"""
|
||||||
|
|
||||||
|
# 检查是否允许在此聊天流中使用表达
|
||||||
|
if not self.can_use_expression_for_chat(chat_id):
|
||||||
|
logger.debug(f"聊天流 {chat_id} 不允许使用表达,返回空列表")
|
||||||
|
return [], []
|
||||||
|
|
||||||
# 1. 获取35个随机表达方式(现在按权重抽取)
|
# 1. 获取20个随机表达方式(现在按权重抽取)
|
||||||
style_exprs, grammar_exprs = self.get_random_expressions(chat_id, 50, 0.5, 0.5)
|
style_exprs = self.get_random_expressions(chat_id, 10)
|
||||||
|
|
||||||
|
if len(style_exprs) < 10:
|
||||||
|
logger.info(f"聊天流 {chat_id} 表达方式正在积累中")
|
||||||
|
return [], []
|
||||||
|
|
||||||
# 2. 构建所有表达方式的索引和情境列表
|
# 2. 构建所有表达方式的索引和情境列表
|
||||||
all_expressions = []
|
all_expressions: List[Dict[str, Any]] = []
|
||||||
all_situations = []
|
all_situations: List[str] = []
|
||||||
|
|
||||||
# 添加style表达方式
|
# 添加style表达方式
|
||||||
for expr in style_exprs:
|
for expr in style_exprs:
|
||||||
if isinstance(expr, dict) and "situation" in expr and "style" in expr:
|
expr = expr.copy()
|
||||||
expr_with_type = expr.copy()
|
all_expressions.append(expr)
|
||||||
expr_with_type["type"] = "style"
|
all_situations.append(f"{len(all_expressions)}.当 {expr['situation']} 时,使用 {expr['style']}")
|
||||||
all_expressions.append(expr_with_type)
|
|
||||||
all_situations.append(f"{len(all_expressions)}.{expr['situation']}")
|
|
||||||
|
|
||||||
# 添加grammar表达方式
|
|
||||||
for expr in grammar_exprs:
|
|
||||||
if isinstance(expr, dict) and "situation" in expr and "style" in expr:
|
|
||||||
expr_with_type = expr.copy()
|
|
||||||
expr_with_type["type"] = "grammar"
|
|
||||||
all_expressions.append(expr_with_type)
|
|
||||||
all_situations.append(f"{len(all_expressions)}.{expr['situation']}")
|
|
||||||
|
|
||||||
if not all_expressions:
|
if not all_expressions:
|
||||||
logger.warning("没有找到可用的表达方式")
|
logger.warning("没有找到可用的表达方式")
|
||||||
return []
|
return [], []
|
||||||
|
|
||||||
all_situations_str = "\n".join(all_situations)
|
all_situations_str = "\n".join(all_situations)
|
||||||
|
|
||||||
@@ -254,23 +241,28 @@ class ExpressionSelector:
|
|||||||
bot_name=global_config.bot.nickname,
|
bot_name=global_config.bot.nickname,
|
||||||
chat_observe_info=chat_info,
|
chat_observe_info=chat_info,
|
||||||
all_situations=all_situations_str,
|
all_situations=all_situations_str,
|
||||||
min_num=min_num,
|
|
||||||
max_num=max_num,
|
max_num=max_num,
|
||||||
target_message=target_message_str,
|
target_message=target_message_str,
|
||||||
target_message_extra_block=target_message_extra_block,
|
target_message_extra_block=target_message_extra_block,
|
||||||
)
|
)
|
||||||
|
|
||||||
# print(prompt)
|
|
||||||
|
|
||||||
# 4. 调用LLM
|
# 4. 调用LLM
|
||||||
try:
|
try:
|
||||||
content, (_, _) = await self.llm_model.generate_response_async(prompt=prompt)
|
|
||||||
|
# start_time = time.time()
|
||||||
|
content, (reasoning_content, model_name, _) = await self.llm_model.generate_response_async(prompt=prompt)
|
||||||
|
# logger.info(f"LLM请求时间: {model_name} {time.time() - start_time} \n{prompt}")
|
||||||
|
|
||||||
# logger.info(f"{self.log_prefix} LLM返回结果: {content}")
|
# logger.info(f"模型名称: {model_name}")
|
||||||
|
# logger.info(f"LLM返回结果: {content}")
|
||||||
|
# if reasoning_content:
|
||||||
|
# logger.info(f"LLM推理: {reasoning_content}")
|
||||||
|
# else:
|
||||||
|
# logger.info(f"LLM推理: 无")
|
||||||
|
|
||||||
if not content:
|
if not content:
|
||||||
logger.warning("LLM返回空结果")
|
logger.warning("LLM返回空结果")
|
||||||
return []
|
return [], []
|
||||||
|
|
||||||
# 5. 解析结果
|
# 5. 解析结果
|
||||||
result = repair_json(content)
|
result = repair_json(content)
|
||||||
@@ -280,15 +272,17 @@ class ExpressionSelector:
|
|||||||
if not isinstance(result, dict) or "selected_situations" not in result:
|
if not isinstance(result, dict) or "selected_situations" not in result:
|
||||||
logger.error("LLM返回格式错误")
|
logger.error("LLM返回格式错误")
|
||||||
logger.info(f"LLM返回结果: \n{content}")
|
logger.info(f"LLM返回结果: \n{content}")
|
||||||
return []
|
return [], []
|
||||||
|
|
||||||
selected_indices = result["selected_situations"]
|
selected_indices = result["selected_situations"]
|
||||||
|
|
||||||
# 根据索引获取完整的表达方式
|
# 根据索引获取完整的表达方式
|
||||||
valid_expressions = []
|
valid_expressions: List[Dict[str, Any]] = []
|
||||||
|
selected_ids = []
|
||||||
for idx in selected_indices:
|
for idx in selected_indices:
|
||||||
if isinstance(idx, int) and 1 <= idx <= len(all_expressions):
|
if isinstance(idx, int) and 1 <= idx <= len(all_expressions):
|
||||||
expression = all_expressions[idx - 1] # 索引从1开始
|
expression = all_expressions[idx - 1] # 索引从1开始
|
||||||
|
selected_ids.append(expression["id"])
|
||||||
valid_expressions.append(expression)
|
valid_expressions.append(expression)
|
||||||
|
|
||||||
# 对选中的所有表达方式,一次性更新count数
|
# 对选中的所有表达方式,一次性更新count数
|
||||||
@@ -296,11 +290,12 @@ class ExpressionSelector:
|
|||||||
self.update_expressions_count_batch(valid_expressions, 0.006)
|
self.update_expressions_count_batch(valid_expressions, 0.006)
|
||||||
|
|
||||||
# logger.info(f"LLM从{len(all_expressions)}个情境中选择了{len(valid_expressions)}个")
|
# logger.info(f"LLM从{len(all_expressions)}个情境中选择了{len(valid_expressions)}个")
|
||||||
return valid_expressions
|
return valid_expressions, selected_ids
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"LLM处理表达方式选择时出错: {e}")
|
logger.error(f"LLM处理表达方式选择时出错: {e}")
|
||||||
return []
|
return [], []
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
init_prompt()
|
init_prompt()
|
||||||
|
|||||||
143
src/chat/frequency_control/focus_value_control.py
Normal file
143
src/chat/frequency_control/focus_value_control.py
Normal file
@@ -0,0 +1,143 @@
|
|||||||
|
from typing import Optional
|
||||||
|
from src.config.config import global_config
|
||||||
|
from src.chat.frequency_control.utils import parse_stream_config_to_chat_id
|
||||||
|
|
||||||
|
|
||||||
|
class FocusValueControl:
    """Focus value accessor bound to a single chat.

    The value reported is the configured focus value for the chat multiplied
    by ``focus_value_adjust``.
    """

    def __init__(self, chat_id: str):
        # Multiplier applied on top of the configured value; 1 leaves it unchanged.
        self.focus_value_adjust = 1
        self.chat_id = chat_id

    def get_current_focus_value(self) -> float:
        """Return the configured focus value for this chat, scaled by ``focus_value_adjust``."""
        # The bare name resolves to the module-level function, not this method.
        configured_value = get_current_focus_value(self.chat_id)
        return configured_value * self.focus_value_adjust
|
||||||
|
|
||||||
|
|
||||||
|
class FocusValueControlManager:
    """Registry that hands out one ``FocusValueControl`` per chat id."""

    def __init__(self):
        # chat_id -> FocusValueControl, filled lazily on first request.
        self.focus_value_controls = {}

    def get_focus_value_control(self, chat_id: str) -> FocusValueControl:
        """Return the control for ``chat_id``, creating and caching it if it is new."""
        try:
            return self.focus_value_controls[chat_id]
        except KeyError:
            control = FocusValueControl(chat_id)
            self.focus_value_controls[chat_id] = control
            return control
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def get_current_focus_value(chat_id: Optional[str] = None) -> float:
    """Resolve the focus value that applies right now for a chat.

    Lookup order:
        1. If no ``focus_value_adjust`` schedule is configured, the static
           ``focus_value`` is returned.
        2. The schedule configured for ``chat_id`` (when ``chat_id`` is given).
        3. The global schedule (the entry whose first element is "").
        4. The static ``focus_value``.

    Args:
        chat_id: Chat id to look up; a falsy value skips the per-chat lookup.

    Returns:
        float: The focus value to use.
    """
    chat_config = global_config.chat

    if chat_config.focus_value_adjust:
        if chat_id:
            per_stream_value = get_stream_specific_focus_value(chat_id)
            if per_stream_value is not None:
                return per_stream_value

        scheduled_value = get_global_focus_value()
        if scheduled_value is not None:
            return scheduled_value

    return chat_config.focus_value
|
||||||
|
|
||||||
|
def get_stream_specific_focus_value(chat_id: str) -> Optional[float]:
    """Look up the time-based focus value configured for one chat.

    Each ``focus_value_adjust`` entry is expected to start with a stream
    descriptor (for example "qq:1026294844:group") followed by the schedule
    items. The descriptor is converted with ``parse_stream_config_to_chat_id``
    and compared against ``chat_id``; the first matching entry decides the result.

    Args:
        chat_id: Chat id, in the form produced by ``parse_stream_config_to_chat_id``.

    Returns:
        Optional[float]: Focus value from the first matching entry, or None
        when no entry matches (or the matching entry has no usable schedule).
    """
    for entry in global_config.chat.focus_value_adjust:
        # An entry needs the descriptor plus at least one schedule item.
        if not entry or len(entry) < 2:
            continue

        resolved_chat_id = parse_stream_config_to_chat_id(entry[0])
        if resolved_chat_id is not None and resolved_chat_id == chat_id:
            # Everything after the descriptor is the schedule.
            return get_time_based_focus_value(entry[1:])

    return None
|
||||||
|
|
||||||
|
|
||||||
|
def get_time_based_focus_value(
    time_focus_list: list[str], current_minutes: Optional[int] = None
) -> Optional[float]:
    """Pick the focus value scheduled for a given time of day.

    Every entry has the form "HH:MM,focus_value" and marks the moment from
    which that value applies; it stays in effect until the next later entry.
    Before the earliest entry of the day the latest entry is used, i.e. the
    schedule wraps around midnight.

    Args:
        time_focus_list: Schedule entries. Entries that are not well-formed
            "HH:MM,value" strings are ignored.
        current_minutes: Minutes since midnight at which to evaluate the
            schedule. Defaults to the current local time.

    Returns:
        Optional[float]: The applicable focus value, or None when no entry
        could be parsed.
    """
    from datetime import datetime

    if current_minutes is None:
        now = datetime.now()
        current_minutes = now.hour * 60 + now.minute

    schedule = []
    for entry in time_focus_list or ():
        try:
            time_part, value_part = entry.split(",")
            hour, minute = map(int, time_part.split(":"))
            schedule.append((hour * 60 + minute, float(value_part)))
        except (ValueError, IndexError, AttributeError, TypeError):
            # Skip malformed entries, including non-string items coming from config.
            continue

    if not schedule:
        return None

    schedule.sort(key=lambda pair: pair[0])

    # Start from the last entry: it is the one still in effect before the
    # first entry of the day is reached (wrap-around across midnight).
    selected = schedule[-1][1]
    for start_minutes, value in schedule:
        if start_minutes > current_minutes:
            break
        selected = value

    return selected
|
||||||
|
|
||||||
|
|
||||||
|
def get_global_focus_value() -> Optional[float]:
    """Return the focus value from the global default schedule.

    The global default is the first ``focus_value_adjust`` entry whose first
    element is the empty string; the remaining elements are its schedule.

    Returns:
        Optional[float]: The focus value for the current time, or None when
        no global entry is configured.
    """
    for entry in global_config.chat.focus_value_adjust:
        # Needs the (empty) descriptor plus at least one schedule item.
        if entry and len(entry) >= 2 and entry[0] == "":
            return get_time_based_focus_value(entry[1:])

    return None
|
||||||
|
|
||||||
|
focus_value_control = FocusValueControlManager()
|
||||||
144
src/chat/frequency_control/talk_frequency_control.py
Normal file
144
src/chat/frequency_control/talk_frequency_control.py
Normal file
@@ -0,0 +1,144 @@
|
|||||||
|
from typing import Optional
|
||||||
|
from src.config.config import global_config
|
||||||
|
from src.chat.frequency_control.utils import parse_stream_config_to_chat_id
|
||||||
|
|
||||||
|
class TalkFrequencyControl:
    """Talk frequency accessor bound to a single chat.

    The value reported is the configured talk frequency for the chat
    multiplied by ``talk_frequency_adjust``.
    """

    def __init__(self, chat_id: str):
        # Multiplier applied on top of the configured value; 1 leaves it unchanged.
        self.talk_frequency_adjust = 1
        self.chat_id = chat_id

    def get_current_talk_frequency(self) -> float:
        """Return the configured talk frequency for this chat, scaled by ``talk_frequency_adjust``."""
        # The bare name resolves to the module-level function, not this method.
        configured_frequency = get_current_talk_frequency(self.chat_id)
        return configured_frequency * self.talk_frequency_adjust
|
||||||
|
|
||||||
|
|
||||||
|
class TalkFrequencyControlManager:
    """Registry that hands out one ``TalkFrequencyControl`` per chat id."""

    def __init__(self):
        # chat_id -> TalkFrequencyControl, filled lazily on first request.
        self.talk_frequency_controls = {}

    def get_talk_frequency_control(self, chat_id: str) -> TalkFrequencyControl:
        """Return the control for ``chat_id``, creating and caching it if it is new."""
        try:
            return self.talk_frequency_controls[chat_id]
        except KeyError:
            control = TalkFrequencyControl(chat_id)
            self.talk_frequency_controls[chat_id] = control
            return control
|
||||||
|
|
||||||
|
|
||||||
|
def get_current_talk_frequency(chat_id: Optional[str] = None) -> float:
    """Resolve the talk frequency that applies right now for a chat.

    Lookup order:
        1. If no ``talk_frequency_adjust`` schedule is configured, the static
           ``talk_frequency`` is returned.
        2. The schedule configured for ``chat_id`` (when ``chat_id`` is given).
        3. The global schedule (the entry whose first element is "").
        4. The static ``talk_frequency``.

    Args:
        chat_id: Chat id to look up; a falsy value skips the per-chat lookup.

    Returns:
        float: The talk frequency to use.
    """
    chat_config = global_config.chat

    if not chat_config.talk_frequency_adjust:
        return chat_config.talk_frequency

    # A chat-specific schedule takes precedence over the global one.
    if chat_id:
        per_stream_frequency = get_stream_specific_frequency(chat_id)
        if per_stream_frequency is not None:
            return per_stream_frequency

    scheduled_frequency = get_global_frequency()
    if scheduled_frequency is None:
        return global_config.chat.talk_frequency
    return scheduled_frequency
|
||||||
|
|
||||||
|
def get_time_based_frequency(
    time_freq_list: list[str], current_minutes: Optional[int] = None
) -> Optional[float]:
    """Pick the talk frequency scheduled for a given time of day.

    Every entry has the form "HH:MM,frequency" and marks the moment from
    which that frequency applies; it stays in effect until the next later
    entry. Before the earliest entry of the day the latest entry is used,
    i.e. the schedule wraps around midnight.

    Args:
        time_freq_list: Schedule entries. Entries that are not well-formed
            "HH:MM,value" strings are ignored.
        current_minutes: Minutes since midnight at which to evaluate the
            schedule. Defaults to the current local time.

    Returns:
        Optional[float]: The applicable frequency, or None when no entry
        could be parsed.
    """
    from datetime import datetime

    if current_minutes is None:
        now = datetime.now()
        current_minutes = now.hour * 60 + now.minute

    schedule = []
    for entry in time_freq_list or ():
        try:
            time_part, value_part = entry.split(",")
            hour, minute = map(int, time_part.split(":"))
            schedule.append((hour * 60 + minute, float(value_part)))
        except (ValueError, IndexError, AttributeError, TypeError):
            # Skip malformed entries, including non-string items coming from config.
            continue

    if not schedule:
        return None

    schedule.sort(key=lambda pair: pair[0])

    # Start from the last entry: it is the one still in effect before the
    # first entry of the day is reached (wrap-around across midnight).
    selected = schedule[-1][1]
    for start_minutes, frequency in schedule:
        if start_minutes > current_minutes:
            break
        selected = frequency

    return selected
|
||||||
|
|
||||||
|
|
||||||
|
def get_stream_specific_frequency(chat_stream_id: str):
    """Look up the time-based talk frequency configured for one chat.

    Each ``talk_frequency_adjust`` entry is expected to start with a stream
    descriptor (for example "qq:1026294844:group") followed by the schedule
    items. The descriptor is converted with ``parse_stream_config_to_chat_id``
    and compared against ``chat_stream_id``; the first matching entry decides
    the result.

    Args:
        chat_stream_id: Chat id, in the form produced by
            ``parse_stream_config_to_chat_id``.

    Returns:
        The frequency from the first matching entry, or None when no entry
        matches (or the matching entry has no usable schedule).
    """
    for entry in global_config.chat.talk_frequency_adjust:
        # An entry needs the descriptor plus at least one schedule item.
        if not entry or len(entry) < 2:
            continue

        resolved_chat_id = parse_stream_config_to_chat_id(entry[0])
        if resolved_chat_id is not None and resolved_chat_id == chat_stream_id:
            # Everything after the descriptor is the schedule.
            return get_time_based_frequency(entry[1:])

    return None
|
||||||
|
|
||||||
|
def get_global_frequency() -> Optional[float]:
    """Return the talk frequency from the global default schedule.

    The global default is the first ``talk_frequency_adjust`` entry whose
    first element is the empty string; the remaining elements are its schedule.

    Returns:
        Optional[float]: The frequency for the current time, or None when no
        global entry is configured.
    """
    for entry in global_config.chat.talk_frequency_adjust:
        # Needs the (empty) descriptor plus at least one schedule item.
        if entry and len(entry) >= 2 and entry[0] == "":
            return get_time_based_frequency(entry[1:])

    return None
|
||||||
|
|
||||||
|
talk_frequency_control = TalkFrequencyControlManager()
|
||||||
37
src/chat/frequency_control/utils.py
Normal file
37
src/chat/frequency_control/utils.py
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
from typing import Optional
|
||||||
|
import hashlib
|
||||||
|
|
||||||
|
|
||||||
|
def parse_stream_config_to_chat_id(stream_config_str: str) -> Optional[str]:
    """Convert a "platform:id:type" descriptor into a chat id.

    The chat id is the MD5 hex digest of "platform_id" when ``type`` is
    "group", and of "platform_id_private" for any other type.

    Args:
        stream_config_str: Descriptor with exactly three ":"-separated parts.

    Returns:
        Optional[str]: The chat id, or None when the input is not a string or
        does not have exactly three parts.
    """
    # Config values may be missing or of the wrong type; report that as a
    # parse failure rather than raising from .split().
    if not isinstance(stream_config_str, str):
        return None

    parts = stream_config_str.split(":")
    if len(parts) != 3:
        return None

    platform, id_str, stream_type = parts

    # NOTE(review): the original comment says this mirrors
    # ChatStream.get_stream_id; that method is not visible here — keep in sync.
    components = [platform, id_str]
    if stream_type != "group":
        components.append("private")

    # MD5 is used only to derive a stable identifier, not for security.
    return hashlib.md5("_".join(components).encode()).hexdigest()
|
||||||
@@ -3,7 +3,6 @@ from typing import Any, Optional, Dict
|
|||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.chat.heart_flow.sub_heartflow import SubHeartflow
|
from src.chat.heart_flow.sub_heartflow import SubHeartflow
|
||||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
|
||||||
|
|
||||||
logger = get_logger("heartflow")
|
logger = get_logger("heartflow")
|
||||||
|
|
||||||
@@ -27,8 +26,6 @@ class Heartflow:
|
|||||||
|
|
||||||
# 注册子心流
|
# 注册子心流
|
||||||
self.subheartflows[subheartflow_id] = new_subflow
|
self.subheartflows[subheartflow_id] = new_subflow
|
||||||
heartflow_name = get_chat_manager().get_stream_name(subheartflow_id) or subheartflow_id
|
|
||||||
logger.info(f"[{heartflow_name}] 开始接收消息")
|
|
||||||
|
|
||||||
return new_subflow
|
return new_subflow
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -14,53 +14,35 @@ from src.chat.utils.utils import is_mentioned_bot_in_message
|
|||||||
from src.chat.utils.timer_calculator import Timer
|
from src.chat.utils.timer_calculator import Timer
|
||||||
from src.chat.utils.chat_message_builder import replace_user_references_sync
|
from src.chat.utils.chat_message_builder import replace_user_references_sync
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.person_info.relationship_manager import get_relationship_manager
|
|
||||||
from src.mood.mood_manager import mood_manager
|
from src.mood.mood_manager import mood_manager
|
||||||
|
from src.person_info.person_info import Person
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from src.chat.heart_flow.sub_heartflow import SubHeartflow
|
from src.chat.heart_flow.sub_heartflow import SubHeartflow
|
||||||
|
|
||||||
logger = get_logger("chat")
|
logger = get_logger("chat")
|
||||||
|
|
||||||
|
async def _calculate_interest(message: MessageRecv) -> Tuple[float, bool, list[str]]:
|
||||||
async def _process_relationship(message: MessageRecv) -> None:
|
|
||||||
"""处理用户关系逻辑
|
|
||||||
|
|
||||||
Args:
|
|
||||||
message: 消息对象,包含用户信息
|
|
||||||
"""
|
|
||||||
platform = message.message_info.platform
|
|
||||||
user_id = message.message_info.user_info.user_id # type: ignore
|
|
||||||
nickname = message.message_info.user_info.user_nickname # type: ignore
|
|
||||||
cardname = message.message_info.user_info.user_cardname or nickname # type: ignore
|
|
||||||
|
|
||||||
relationship_manager = get_relationship_manager()
|
|
||||||
is_known = await relationship_manager.is_known_some_one(platform, user_id)
|
|
||||||
|
|
||||||
if not is_known:
|
|
||||||
logger.info(f"首次认识用户: {nickname}")
|
|
||||||
await relationship_manager.first_knowing_some_one(platform, user_id, nickname, cardname) # type: ignore
|
|
||||||
|
|
||||||
|
|
||||||
async def _calculate_interest(message: MessageRecv) -> Tuple[float, bool]:
|
|
||||||
"""计算消息的兴趣度
|
"""计算消息的兴趣度
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
message: 待处理的消息对象
|
message: 待处理的消息对象
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple[float, bool]: (兴趣度, 是否被提及)
|
Tuple[float, bool, list[str]]: (兴趣度, 是否被提及, 关键词)
|
||||||
"""
|
"""
|
||||||
is_mentioned, _ = is_mentioned_bot_in_message(message)
|
is_mentioned, _ = is_mentioned_bot_in_message(message)
|
||||||
interested_rate = 0.0
|
interested_rate = 0.0
|
||||||
|
|
||||||
with Timer("记忆激活"):
|
with Timer("记忆激活"):
|
||||||
interested_rate = await hippocampus_manager.get_activate_from_text(
|
interested_rate, keywords,keywords_lite = await hippocampus_manager.get_activate_from_text(
|
||||||
message.processed_plain_text,
|
message.processed_plain_text,
|
||||||
max_depth= 5,
|
max_depth= 4,
|
||||||
fast_retrieval=False,
|
fast_retrieval=False,
|
||||||
)
|
)
|
||||||
logger.debug(f"记忆激活率: {interested_rate:.2f}")
|
message.key_words = keywords
|
||||||
|
message.key_words_lite = keywords_lite
|
||||||
|
logger.debug(f"记忆激活率: {interested_rate:.2f}, 关键词: {keywords}")
|
||||||
|
|
||||||
text_len = len(message.processed_plain_text)
|
text_len = len(message.processed_plain_text)
|
||||||
# 根据文本长度分布调整兴趣度,采用分段函数实现更精确的兴趣度计算
|
# 根据文本长度分布调整兴趣度,采用分段函数实现更精确的兴趣度计算
|
||||||
@@ -99,7 +81,7 @@ async def _calculate_interest(message: MessageRecv) -> Tuple[float, bool]:
|
|||||||
interest_increase_on_mention = 1
|
interest_increase_on_mention = 1
|
||||||
interested_rate += interest_increase_on_mention
|
interested_rate += interest_increase_on_mention
|
||||||
|
|
||||||
return interested_rate, is_mentioned
|
return interested_rate, is_mentioned, keywords
|
||||||
|
|
||||||
|
|
||||||
class HeartFCMessageReceiver:
|
class HeartFCMessageReceiver:
|
||||||
@@ -128,7 +110,7 @@ class HeartFCMessageReceiver:
|
|||||||
chat = message.chat_stream
|
chat = message.chat_stream
|
||||||
|
|
||||||
# 2. 兴趣度计算与更新
|
# 2. 兴趣度计算与更新
|
||||||
interested_rate, is_mentioned = await _calculate_interest(message)
|
interested_rate, is_mentioned, keywords = await _calculate_interest(message)
|
||||||
message.interest_value = interested_rate
|
message.interest_value = interested_rate
|
||||||
message.is_mentioned = is_mentioned
|
message.is_mentioned = is_mentioned
|
||||||
|
|
||||||
@@ -143,8 +125,6 @@ class HeartFCMessageReceiver:
|
|||||||
|
|
||||||
# 3. 日志记录
|
# 3. 日志记录
|
||||||
mes_name = chat.group_info.group_name if chat.group_info else "私聊"
|
mes_name = chat.group_info.group_name if chat.group_info else "私聊"
|
||||||
# current_time = time.strftime("%H:%M:%S", time.localtime(message.message_info.time))
|
|
||||||
current_talk_frequency = global_config.chat.get_current_talk_frequency(chat.stream_id)
|
|
||||||
|
|
||||||
# 如果消息中包含图片标识,则将 [picid:...] 替换为 [图片]
|
# 如果消息中包含图片标识,则将 [picid:...] 替换为 [图片]
|
||||||
picid_pattern = r"\[picid:([^\]]+)\]"
|
picid_pattern = r"\[picid:([^\]]+)\]"
|
||||||
@@ -157,13 +137,12 @@ class HeartFCMessageReceiver:
|
|||||||
replace_bot_name=True
|
replace_bot_name=True
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"[{mes_name}]{userinfo.user_nickname}:{processed_plain_text}[兴趣度:{interested_rate:.2f}]") # type: ignore
|
if keywords:
|
||||||
|
logger.info(f"[{mes_name}]{userinfo.user_nickname}:{processed_plain_text}[兴趣度:{interested_rate:.2f}][关键词:{keywords}]") # type: ignore
|
||||||
|
else:
|
||||||
|
logger.info(f"[{mes_name}]{userinfo.user_nickname}:{processed_plain_text}[兴趣度:{interested_rate:.2f}]") # type: ignore
|
||||||
|
|
||||||
logger.debug(f"[{mes_name}][当前时段回复频率: {current_talk_frequency}]")
|
_ = Person.register_person(platform=message.message_info.platform, user_id=message.message_info.user_info.user_id,nickname=userinfo.user_nickname) # type: ignore
|
||||||
|
|
||||||
# 4. 关系处理
|
|
||||||
if global_config.relationship.enable_relationship:
|
|
||||||
await _process_relationship(message)
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"消息处理失败: {e}")
|
logger.error(f"消息处理失败: {e}")
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ import json
|
|||||||
import os
|
import os
|
||||||
import math
|
import math
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
from typing import Dict, List, Tuple
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -11,8 +12,6 @@ import pandas as pd
|
|||||||
# import tqdm
|
# import tqdm
|
||||||
import faiss
|
import faiss
|
||||||
|
|
||||||
# from .llm_client import LLMClient
|
|
||||||
# from .lpmmconfig import global_config
|
|
||||||
from .utils.hash import get_sha256
|
from .utils.hash import get_sha256
|
||||||
from .global_logger import logger
|
from .global_logger import logger
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
@@ -26,12 +25,20 @@ from rich.progress import (
|
|||||||
SpinnerColumn,
|
SpinnerColumn,
|
||||||
TextColumn,
|
TextColumn,
|
||||||
)
|
)
|
||||||
from src.manager.local_store_manager import local_storage
|
|
||||||
from src.chat.utils.utils import get_embedding
|
from src.chat.utils.utils import get_embedding
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
|
|
||||||
|
|
||||||
install(extra_lines=3)
|
install(extra_lines=3)
|
||||||
|
|
||||||
|
# 多线程embedding配置常量
|
||||||
|
DEFAULT_MAX_WORKERS = 10 # 默认最大线程数
|
||||||
|
DEFAULT_CHUNK_SIZE = 10 # 默认每个线程处理的数据块大小
|
||||||
|
MIN_CHUNK_SIZE = 1 # 最小分块大小
|
||||||
|
MAX_CHUNK_SIZE = 50 # 最大分块大小
|
||||||
|
MIN_WORKERS = 1 # 最小线程数
|
||||||
|
MAX_WORKERS = 20 # 最大线程数
|
||||||
|
|
||||||
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
|
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
|
||||||
EMBEDDING_DATA_DIR = os.path.join(ROOT_PATH, "data", "embedding")
|
EMBEDDING_DATA_DIR = os.path.join(ROOT_PATH, "data", "embedding")
|
||||||
EMBEDDING_DATA_DIR_STR = str(EMBEDDING_DATA_DIR).replace("\\", "/")
|
EMBEDDING_DATA_DIR_STR = str(EMBEDDING_DATA_DIR).replace("\\", "/")
|
||||||
@@ -87,13 +94,23 @@ class EmbeddingStoreItem:
|
|||||||
|
|
||||||
|
|
||||||
class EmbeddingStore:
|
class EmbeddingStore:
|
||||||
def __init__(self, namespace: str, dir_path: str):
|
def __init__(self, namespace: str, dir_path: str, max_workers: int = DEFAULT_MAX_WORKERS, chunk_size: int = DEFAULT_CHUNK_SIZE):
|
||||||
self.namespace = namespace
|
self.namespace = namespace
|
||||||
self.dir = dir_path
|
self.dir = dir_path
|
||||||
self.embedding_file_path = f"{dir_path}/{namespace}.parquet"
|
self.embedding_file_path = f"{dir_path}/{namespace}.parquet"
|
||||||
self.index_file_path = f"{dir_path}/{namespace}.index"
|
self.index_file_path = f"{dir_path}/{namespace}.index"
|
||||||
self.idx2hash_file_path = dir_path + "/" + namespace + "_i2h.json"
|
self.idx2hash_file_path = dir_path + "/" + namespace + "_i2h.json"
|
||||||
|
|
||||||
|
# 多线程配置参数验证和设置
|
||||||
|
self.max_workers = max(MIN_WORKERS, min(MAX_WORKERS, max_workers))
|
||||||
|
self.chunk_size = max(MIN_CHUNK_SIZE, min(MAX_CHUNK_SIZE, chunk_size))
|
||||||
|
|
||||||
|
# 如果配置值被调整,记录日志
|
||||||
|
if self.max_workers != max_workers:
|
||||||
|
logger.warning(f"max_workers 已从 {max_workers} 调整为 {self.max_workers} (范围: {MIN_WORKERS}-{MAX_WORKERS})")
|
||||||
|
if self.chunk_size != chunk_size:
|
||||||
|
logger.warning(f"chunk_size 已从 {chunk_size} 调整为 {self.chunk_size} (范围: {MIN_CHUNK_SIZE}-{MAX_CHUNK_SIZE})")
|
||||||
|
|
||||||
self.store = {}
|
self.store = {}
|
||||||
|
|
||||||
self.faiss_index = None
|
self.faiss_index = None
|
||||||
@@ -125,16 +142,134 @@ class EmbeddingStore:
|
|||||||
return []
|
return []
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
def _get_embeddings_batch_threaded(self, strs: List[str], chunk_size: int = 10, max_workers: int = 10, progress_callback=None) -> List[Tuple[str, List[float]]]:
|
||||||
|
"""使用多线程批量获取嵌入向量
|
||||||
|
|
||||||
|
Args:
|
||||||
|
strs: 要获取嵌入的字符串列表
|
||||||
|
chunk_size: 每个线程处理的数据块大小
|
||||||
|
max_workers: 最大线程数
|
||||||
|
progress_callback: 进度回调函数,接收一个参数表示完成的数量
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
包含(原始字符串, 嵌入向量)的元组列表,保持与输入顺序一致
|
||||||
|
"""
|
||||||
|
if not strs:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# 分块
|
||||||
|
chunks = []
|
||||||
|
for i in range(0, len(strs), chunk_size):
|
||||||
|
chunk = strs[i:i + chunk_size]
|
||||||
|
chunks.append((i, chunk)) # 保存起始索引以维持顺序
|
||||||
|
|
||||||
|
# 结果存储,使用字典按索引存储以保证顺序
|
||||||
|
results = {}
|
||||||
|
|
||||||
|
def process_chunk(chunk_data):
|
||||||
|
"""处理单个数据块的函数"""
|
||||||
|
start_idx, chunk_strs = chunk_data
|
||||||
|
chunk_results = []
|
||||||
|
|
||||||
|
# 为每个线程创建独立的LLMRequest实例
|
||||||
|
from src.llm_models.utils_model import LLMRequest
|
||||||
|
from src.config.config import model_config
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 创建线程专用的LLM实例
|
||||||
|
llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding")
|
||||||
|
|
||||||
|
for i, s in enumerate(chunk_strs):
|
||||||
|
try:
|
||||||
|
# 直接使用异步函数
|
||||||
|
embedding = asyncio.run(llm.get_embedding(s))
|
||||||
|
if embedding and len(embedding) > 0:
|
||||||
|
chunk_results.append((start_idx + i, s, embedding[0])) # embedding[0] 是实际的向量
|
||||||
|
else:
|
||||||
|
logger.error(f"获取嵌入失败: {s}")
|
||||||
|
chunk_results.append((start_idx + i, s, []))
|
||||||
|
|
||||||
|
# 每完成一个嵌入立即更新进度
|
||||||
|
if progress_callback:
|
||||||
|
progress_callback(1)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取嵌入时发生异常: {s}, 错误: {e}")
|
||||||
|
chunk_results.append((start_idx + i, s, []))
|
||||||
|
|
||||||
|
# 即使失败也要更新进度
|
||||||
|
if progress_callback:
|
||||||
|
progress_callback(1)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"创建LLM实例失败: {e}")
|
||||||
|
# 如果创建LLM实例失败,返回空结果
|
||||||
|
for i, s in enumerate(chunk_strs):
|
||||||
|
chunk_results.append((start_idx + i, s, []))
|
||||||
|
# 即使失败也要更新进度
|
||||||
|
if progress_callback:
|
||||||
|
progress_callback(1)
|
||||||
|
|
||||||
|
return chunk_results
|
||||||
|
|
||||||
|
# 使用线程池处理
|
||||||
|
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||||
|
# 提交所有任务
|
||||||
|
future_to_chunk = {executor.submit(process_chunk, chunk): chunk for chunk in chunks}
|
||||||
|
|
||||||
|
# 收集结果(进度已在process_chunk中实时更新)
|
||||||
|
for future in as_completed(future_to_chunk):
|
||||||
|
try:
|
||||||
|
chunk_results = future.result()
|
||||||
|
for idx, s, embedding in chunk_results:
|
||||||
|
results[idx] = (s, embedding)
|
||||||
|
except Exception as e:
|
||||||
|
chunk = future_to_chunk[future]
|
||||||
|
logger.error(f"处理数据块时发生异常: {chunk}, 错误: {e}")
|
||||||
|
# 为失败的块添加空结果
|
||||||
|
start_idx, chunk_strs = chunk
|
||||||
|
for i, s in enumerate(chunk_strs):
|
||||||
|
results[start_idx + i] = (s, [])
|
||||||
|
|
||||||
|
# 按原始顺序返回结果
|
||||||
|
ordered_results = []
|
||||||
|
for i in range(len(strs)):
|
||||||
|
if i in results:
|
||||||
|
ordered_results.append(results[i])
|
||||||
|
else:
|
||||||
|
# 防止遗漏
|
||||||
|
ordered_results.append((strs[i], []))
|
||||||
|
|
||||||
|
return ordered_results
|
||||||
|
|
||||||
def get_test_file_path(self):
|
def get_test_file_path(self):
|
||||||
return EMBEDDING_TEST_FILE
|
return EMBEDDING_TEST_FILE
|
||||||
|
|
||||||
def save_embedding_test_vectors(self):
|
def save_embedding_test_vectors(self):
|
||||||
"""保存测试字符串的嵌入到本地"""
|
"""保存测试字符串的嵌入到本地(使用多线程优化)"""
|
||||||
|
logger.info("开始保存测试字符串的嵌入向量...")
|
||||||
|
|
||||||
|
# 使用多线程批量获取测试字符串的嵌入
|
||||||
|
embedding_results = self._get_embeddings_batch_threaded(
|
||||||
|
EMBEDDING_TEST_STRINGS,
|
||||||
|
chunk_size=min(self.chunk_size, len(EMBEDDING_TEST_STRINGS)),
|
||||||
|
max_workers=min(self.max_workers, len(EMBEDDING_TEST_STRINGS))
|
||||||
|
)
|
||||||
|
|
||||||
|
# 构建测试向量字典
|
||||||
test_vectors = {}
|
test_vectors = {}
|
||||||
for idx, s in enumerate(EMBEDDING_TEST_STRINGS):
|
for idx, (s, embedding) in enumerate(embedding_results):
|
||||||
test_vectors[str(idx)] = self._get_embedding(s)
|
if embedding:
|
||||||
|
test_vectors[str(idx)] = embedding
|
||||||
|
else:
|
||||||
|
logger.error(f"获取测试字符串嵌入失败: {s}")
|
||||||
|
# 使用原始单线程方法作为后备
|
||||||
|
test_vectors[str(idx)] = self._get_embedding(s)
|
||||||
|
|
||||||
with open(self.get_test_file_path(), "w", encoding="utf-8") as f:
|
with open(self.get_test_file_path(), "w", encoding="utf-8") as f:
|
||||||
json.dump(test_vectors, f, ensure_ascii=False, indent=2)
|
json.dump(test_vectors, f, ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
|
logger.info("测试字符串嵌入向量保存完成")
|
||||||
|
|
||||||
def load_embedding_test_vectors(self):
|
def load_embedding_test_vectors(self):
|
||||||
"""加载本地保存的测试字符串嵌入"""
|
"""加载本地保存的测试字符串嵌入"""
|
||||||
@@ -145,29 +280,64 @@ class EmbeddingStore:
|
|||||||
return json.load(f)
|
return json.load(f)
|
||||||
|
|
||||||
def check_embedding_model_consistency(self):
|
def check_embedding_model_consistency(self):
|
||||||
"""校验当前模型与本地嵌入模型是否一致"""
|
"""校验当前模型与本地嵌入模型是否一致(使用多线程优化)"""
|
||||||
local_vectors = self.load_embedding_test_vectors()
|
local_vectors = self.load_embedding_test_vectors()
|
||||||
if local_vectors is None:
|
if local_vectors is None:
|
||||||
logger.warning("未检测到本地嵌入模型测试文件,将保存当前模型的测试嵌入。")
|
logger.warning("未检测到本地嵌入模型测试文件,将保存当前模型的测试嵌入。")
|
||||||
self.save_embedding_test_vectors()
|
self.save_embedding_test_vectors()
|
||||||
return True
|
return True
|
||||||
for idx, s in enumerate(EMBEDDING_TEST_STRINGS):
|
|
||||||
local_emb = local_vectors.get(str(idx))
|
# 检查本地向量完整性
|
||||||
if local_emb is None:
|
for idx in range(len(EMBEDDING_TEST_STRINGS)):
|
||||||
|
if local_vectors.get(str(idx)) is None:
|
||||||
logger.warning("本地嵌入模型测试文件缺失部分测试字符串,将重新保存。")
|
logger.warning("本地嵌入模型测试文件缺失部分测试字符串,将重新保存。")
|
||||||
self.save_embedding_test_vectors()
|
self.save_embedding_test_vectors()
|
||||||
return True
|
return True
|
||||||
new_emb = self._get_embedding(s)
|
|
||||||
|
logger.info("开始检验嵌入模型一致性...")
|
||||||
|
|
||||||
|
# 使用多线程批量获取当前模型的嵌入
|
||||||
|
embedding_results = self._get_embeddings_batch_threaded(
|
||||||
|
EMBEDDING_TEST_STRINGS,
|
||||||
|
chunk_size=min(self.chunk_size, len(EMBEDDING_TEST_STRINGS)),
|
||||||
|
max_workers=min(self.max_workers, len(EMBEDDING_TEST_STRINGS))
|
||||||
|
)
|
||||||
|
|
||||||
|
# 检查一致性
|
||||||
|
for idx, (s, new_emb) in enumerate(embedding_results):
|
||||||
|
local_emb = local_vectors.get(str(idx))
|
||||||
|
if not new_emb:
|
||||||
|
logger.error(f"获取测试字符串嵌入失败: {s}")
|
||||||
|
return False
|
||||||
|
|
||||||
sim = cosine_similarity(local_emb, new_emb)
|
sim = cosine_similarity(local_emb, new_emb)
|
||||||
if sim < EMBEDDING_SIM_THRESHOLD:
|
if sim < EMBEDDING_SIM_THRESHOLD:
|
||||||
logger.error("嵌入模型一致性校验失败")
|
logger.error(f"嵌入模型一致性校验失败,字符串: {s}, 相似度: {sim:.4f}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
logger.info("嵌入模型一致性校验通过。")
|
logger.info("嵌入模型一致性校验通过。")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def batch_insert_strs(self, strs: List[str], times: int) -> None:
|
def batch_insert_strs(self, strs: List[str], times: int) -> None:
|
||||||
"""向库中存入字符串"""
|
"""向库中存入字符串(使用多线程优化)"""
|
||||||
|
if not strs:
|
||||||
|
return
|
||||||
|
|
||||||
total = len(strs)
|
total = len(strs)
|
||||||
|
|
||||||
|
# 过滤已存在的字符串
|
||||||
|
new_strs = []
|
||||||
|
for s in strs:
|
||||||
|
item_hash = self.namespace + "-" + get_sha256(s)
|
||||||
|
if item_hash not in self.store:
|
||||||
|
new_strs.append(s)
|
||||||
|
|
||||||
|
if not new_strs:
|
||||||
|
logger.info(f"所有字符串已存在于{self.namespace}嵌入库中,跳过处理")
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info(f"需要处理 {len(new_strs)}/{total} 个新字符串")
|
||||||
|
|
||||||
with Progress(
|
with Progress(
|
||||||
SpinnerColumn(),
|
SpinnerColumn(),
|
||||||
TextColumn("[progress.description]{task.description}"),
|
TextColumn("[progress.description]{task.description}"),
|
||||||
@@ -181,19 +351,38 @@ class EmbeddingStore:
|
|||||||
transient=False,
|
transient=False,
|
||||||
) as progress:
|
) as progress:
|
||||||
task = progress.add_task(f"存入嵌入库:({times}/{TOTAL_EMBEDDING_TIMES})", total=total)
|
task = progress.add_task(f"存入嵌入库:({times}/{TOTAL_EMBEDDING_TIMES})", total=total)
|
||||||
for s in strs:
|
|
||||||
# 计算hash去重
|
# 首先更新已存在项的进度
|
||||||
item_hash = self.namespace + "-" + get_sha256(s)
|
already_processed = total - len(new_strs)
|
||||||
if item_hash in self.store:
|
if already_processed > 0:
|
||||||
progress.update(task, advance=1)
|
progress.update(task, advance=already_processed)
|
||||||
continue
|
|
||||||
|
if new_strs:
|
||||||
# 获取embedding
|
# 使用实例配置的参数,智能调整分块和线程数
|
||||||
embedding = self._get_embedding(s)
|
optimal_chunk_size = max(MIN_CHUNK_SIZE, min(self.chunk_size, len(new_strs) // self.max_workers if self.max_workers > 0 else self.chunk_size))
|
||||||
|
optimal_max_workers = min(self.max_workers, max(MIN_WORKERS, len(new_strs) // optimal_chunk_size if optimal_chunk_size > 0 else 1))
|
||||||
# 存入
|
|
||||||
self.store[item_hash] = EmbeddingStoreItem(item_hash, embedding, s)
|
logger.debug(f"使用多线程处理: chunk_size={optimal_chunk_size}, max_workers={optimal_max_workers}")
|
||||||
progress.update(task, advance=1)
|
|
||||||
|
# 定义进度更新回调函数
|
||||||
|
def update_progress(count):
|
||||||
|
progress.update(task, advance=count)
|
||||||
|
|
||||||
|
# 批量获取嵌入,并实时更新进度
|
||||||
|
embedding_results = self._get_embeddings_batch_threaded(
|
||||||
|
new_strs,
|
||||||
|
chunk_size=optimal_chunk_size,
|
||||||
|
max_workers=optimal_max_workers,
|
||||||
|
progress_callback=update_progress
|
||||||
|
)
|
||||||
|
|
||||||
|
# 存入结果(不再需要在这里更新进度,因为已经在回调中更新了)
|
||||||
|
for s, embedding in embedding_results:
|
||||||
|
item_hash = self.namespace + "-" + get_sha256(s)
|
||||||
|
if embedding: # 只有成功获取到嵌入才存入
|
||||||
|
self.store[item_hash] = EmbeddingStoreItem(item_hash, embedding, s)
|
||||||
|
else:
|
||||||
|
logger.warning(f"跳过存储失败的嵌入: {s[:50]}...")
|
||||||
|
|
||||||
def save_to_file(self) -> None:
|
def save_to_file(self) -> None:
|
||||||
"""保存到文件"""
|
"""保存到文件"""
|
||||||
@@ -316,31 +505,37 @@ class EmbeddingStore:
|
|||||||
|
|
||||||
|
|
||||||
class EmbeddingManager:
|
class EmbeddingManager:
|
||||||
def __init__(self):
|
def __init__(self, max_workers: int = DEFAULT_MAX_WORKERS, chunk_size: int = DEFAULT_CHUNK_SIZE):
|
||||||
|
"""
|
||||||
|
初始化EmbeddingManager
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_workers: 最大线程数
|
||||||
|
chunk_size: 每个线程处理的数据块大小
|
||||||
|
"""
|
||||||
self.paragraphs_embedding_store = EmbeddingStore(
|
self.paragraphs_embedding_store = EmbeddingStore(
|
||||||
local_storage["pg_namespace"], # type: ignore
|
"paragraph", # type: ignore
|
||||||
EMBEDDING_DATA_DIR_STR,
|
EMBEDDING_DATA_DIR_STR,
|
||||||
|
max_workers=max_workers,
|
||||||
|
chunk_size=chunk_size,
|
||||||
)
|
)
|
||||||
self.entities_embedding_store = EmbeddingStore(
|
self.entities_embedding_store = EmbeddingStore(
|
||||||
local_storage["pg_namespace"], # type: ignore
|
"entity", # type: ignore
|
||||||
EMBEDDING_DATA_DIR_STR,
|
EMBEDDING_DATA_DIR_STR,
|
||||||
|
max_workers=max_workers,
|
||||||
|
chunk_size=chunk_size,
|
||||||
)
|
)
|
||||||
self.relation_embedding_store = EmbeddingStore(
|
self.relation_embedding_store = EmbeddingStore(
|
||||||
local_storage["pg_namespace"], # type: ignore
|
"relation", # type: ignore
|
||||||
EMBEDDING_DATA_DIR_STR,
|
EMBEDDING_DATA_DIR_STR,
|
||||||
|
max_workers=max_workers,
|
||||||
|
chunk_size=chunk_size,
|
||||||
)
|
)
|
||||||
self.stored_pg_hashes = set()
|
self.stored_pg_hashes = set()
|
||||||
|
|
||||||
def check_all_embedding_model_consistency(self):
|
def check_all_embedding_model_consistency(self):
|
||||||
"""对所有嵌入库做模型一致性校验"""
|
"""对所有嵌入库做模型一致性校验"""
|
||||||
for store in [
|
return self.paragraphs_embedding_store.check_embedding_model_consistency()
|
||||||
self.paragraphs_embedding_store,
|
|
||||||
self.entities_embedding_store,
|
|
||||||
self.relation_embedding_store,
|
|
||||||
]:
|
|
||||||
if not store.check_embedding_model_consistency():
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
def _store_pg_into_embedding(self, raw_paragraphs: Dict[str, str]):
|
def _store_pg_into_embedding(self, raw_paragraphs: Dict[str, str]):
|
||||||
"""将段落编码存入Embedding库"""
|
"""将段落编码存入Embedding库"""
|
||||||
|
|||||||
@@ -8,12 +8,15 @@ from . import prompt_template
|
|||||||
from .knowledge_lib import INVALID_ENTITY
|
from .knowledge_lib import INVALID_ENTITY
|
||||||
from src.llm_models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
from json_repair import repair_json
|
from json_repair import repair_json
|
||||||
|
|
||||||
|
|
||||||
def _extract_json_from_text(text: str):
|
def _extract_json_from_text(text: str):
|
||||||
|
# sourcery skip: assign-if-exp, extract-method
|
||||||
"""从文本中提取JSON数据的高容错方法"""
|
"""从文本中提取JSON数据的高容错方法"""
|
||||||
if text is None:
|
if text is None:
|
||||||
logger.error("输入文本为None")
|
logger.error("输入文本为None")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
fixed_json = repair_json(text)
|
fixed_json = repair_json(text)
|
||||||
if isinstance(fixed_json, str):
|
if isinstance(fixed_json, str):
|
||||||
@@ -24,7 +27,7 @@ def _extract_json_from_text(text: str):
|
|||||||
# 如果是列表,直接返回
|
# 如果是列表,直接返回
|
||||||
if isinstance(parsed_json, list):
|
if isinstance(parsed_json, list):
|
||||||
return parsed_json
|
return parsed_json
|
||||||
|
|
||||||
# 如果是字典且只有一个项目,可能包装了列表
|
# 如果是字典且只有一个项目,可能包装了列表
|
||||||
if isinstance(parsed_json, dict):
|
if isinstance(parsed_json, dict):
|
||||||
# 如果字典只有一个键,并且值是列表,返回那个列表
|
# 如果字典只有一个键,并且值是列表,返回那个列表
|
||||||
@@ -33,7 +36,7 @@ def _extract_json_from_text(text: str):
|
|||||||
if isinstance(value, list):
|
if isinstance(value, list):
|
||||||
return value
|
return value
|
||||||
return parsed_json
|
return parsed_json
|
||||||
|
|
||||||
# 其他情况,尝试转换为列表
|
# 其他情况,尝试转换为列表
|
||||||
logger.warning(f"解析的JSON不是预期格式: {type(parsed_json)}, 内容: {parsed_json}")
|
logger.warning(f"解析的JSON不是预期格式: {type(parsed_json)}, 内容: {parsed_json}")
|
||||||
return []
|
return []
|
||||||
@@ -42,44 +45,40 @@ def _extract_json_from_text(text: str):
|
|||||||
logger.error(f"JSON提取失败: {e}, 原始文本: {text[:100] if text else 'None'}...")
|
logger.error(f"JSON提取失败: {e}, 原始文本: {text[:100] if text else 'None'}...")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
def _entity_extract(llm_req: LLMRequest, paragraph: str) -> List[str]:
|
def _entity_extract(llm_req: LLMRequest, paragraph: str) -> List[str]:
|
||||||
|
# sourcery skip: reintroduce-else, swap-if-else-branches, use-named-expression
|
||||||
"""对段落进行实体提取,返回提取出的实体列表(JSON格式)"""
|
"""对段落进行实体提取,返回提取出的实体列表(JSON格式)"""
|
||||||
entity_extract_context = prompt_template.build_entity_extract_context(paragraph)
|
entity_extract_context = prompt_template.build_entity_extract_context(paragraph)
|
||||||
|
|
||||||
# 使用 asyncio.run 来运行异步方法
|
# 使用 asyncio.run 来运行异步方法
|
||||||
try:
|
try:
|
||||||
# 如果当前已有事件循环在运行,使用它
|
# 如果当前已有事件循环在运行,使用它
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_running_loop()
|
||||||
future = asyncio.run_coroutine_threadsafe(
|
future = asyncio.run_coroutine_threadsafe(llm_req.generate_response_async(entity_extract_context), loop)
|
||||||
llm_req.generate_response_async(entity_extract_context), loop
|
response, _ = future.result()
|
||||||
)
|
|
||||||
response, (reasoning_content, model_name) = future.result()
|
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
# 如果没有运行中的事件循环,直接使用 asyncio.run
|
# 如果没有运行中的事件循环,直接使用 asyncio.run
|
||||||
response, (reasoning_content, model_name) = asyncio.run(
|
response, _ = asyncio.run(llm_req.generate_response_async(entity_extract_context))
|
||||||
llm_req.generate_response_async(entity_extract_context)
|
|
||||||
)
|
|
||||||
|
|
||||||
# 添加调试日志
|
# 添加调试日志
|
||||||
logger.debug(f"LLM返回的原始响应: {response}")
|
logger.debug(f"LLM返回的原始响应: {response}")
|
||||||
|
|
||||||
entity_extract_result = _extract_json_from_text(response)
|
entity_extract_result = _extract_json_from_text(response)
|
||||||
|
|
||||||
# 检查返回的是否为有效的实体列表
|
# 检查返回的是否为有效的实体列表
|
||||||
if not isinstance(entity_extract_result, list):
|
if not isinstance(entity_extract_result, list):
|
||||||
# 如果不是列表,可能是字典格式,尝试从中提取列表
|
if not isinstance(entity_extract_result, dict):
|
||||||
if isinstance(entity_extract_result, dict):
|
raise ValueError(f"实体提取结果格式错误,期望列表但得到: {type(entity_extract_result)}")
|
||||||
# 尝试常见的键名
|
|
||||||
for key in ['entities', 'result', 'data', 'items']:
|
# 尝试常见的键名
|
||||||
if key in entity_extract_result and isinstance(entity_extract_result[key], list):
|
for key in ["entities", "result", "data", "items"]:
|
||||||
entity_extract_result = entity_extract_result[key]
|
if key in entity_extract_result and isinstance(entity_extract_result[key], list):
|
||||||
break
|
entity_extract_result = entity_extract_result[key]
|
||||||
else:
|
break
|
||||||
# 如果找不到合适的列表,抛出异常
|
|
||||||
raise Exception(f"实体提取结果格式错误,期望列表但得到: {type(entity_extract_result)}")
|
|
||||||
else:
|
else:
|
||||||
raise Exception(f"实体提取结果格式错误,期望列表但得到: {type(entity_extract_result)}")
|
# 如果找不到合适的列表,抛出异常
|
||||||
|
raise ValueError(f"实体提取结果格式错误,期望列表但得到: {type(entity_extract_result)}")
|
||||||
# 过滤无效实体
|
# 过滤无效实体
|
||||||
entity_extract_result = [
|
entity_extract_result = [
|
||||||
entity
|
entity
|
||||||
@@ -87,8 +86,8 @@ def _entity_extract(llm_req: LLMRequest, paragraph: str) -> List[str]:
|
|||||||
if (entity is not None) and (entity != "") and (entity not in INVALID_ENTITY)
|
if (entity is not None) and (entity != "") and (entity not in INVALID_ENTITY)
|
||||||
]
|
]
|
||||||
|
|
||||||
if len(entity_extract_result) == 0:
|
if not entity_extract_result:
|
||||||
raise Exception("实体提取结果为空")
|
raise ValueError("实体提取结果为空")
|
||||||
|
|
||||||
return entity_extract_result
|
return entity_extract_result
|
||||||
|
|
||||||
@@ -98,45 +97,44 @@ def _rdf_triple_extract(llm_req: LLMRequest, paragraph: str, entities: list) ->
|
|||||||
rdf_extract_context = prompt_template.build_rdf_triple_extract_context(
|
rdf_extract_context = prompt_template.build_rdf_triple_extract_context(
|
||||||
paragraph, entities=json.dumps(entities, ensure_ascii=False)
|
paragraph, entities=json.dumps(entities, ensure_ascii=False)
|
||||||
)
|
)
|
||||||
|
|
||||||
# 使用 asyncio.run 来运行异步方法
|
# 使用 asyncio.run 来运行异步方法
|
||||||
try:
|
try:
|
||||||
# 如果当前已有事件循环在运行,使用它
|
# 如果当前已有事件循环在运行,使用它
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_running_loop()
|
||||||
future = asyncio.run_coroutine_threadsafe(
|
future = asyncio.run_coroutine_threadsafe(llm_req.generate_response_async(rdf_extract_context), loop)
|
||||||
llm_req.generate_response_async(rdf_extract_context), loop
|
response, _ = future.result()
|
||||||
)
|
|
||||||
response, (reasoning_content, model_name) = future.result()
|
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
# 如果没有运行中的事件循环,直接使用 asyncio.run
|
# 如果没有运行中的事件循环,直接使用 asyncio.run
|
||||||
response, (reasoning_content, model_name) = asyncio.run(
|
response, _ = asyncio.run(llm_req.generate_response_async(rdf_extract_context))
|
||||||
llm_req.generate_response_async(rdf_extract_context)
|
|
||||||
)
|
|
||||||
|
|
||||||
# 添加调试日志
|
# 添加调试日志
|
||||||
logger.debug(f"RDF LLM返回的原始响应: {response}")
|
logger.debug(f"RDF LLM返回的原始响应: {response}")
|
||||||
|
|
||||||
rdf_triple_result = _extract_json_from_text(response)
|
rdf_triple_result = _extract_json_from_text(response)
|
||||||
|
|
||||||
# 检查返回的是否为有效的三元组列表
|
# 检查返回的是否为有效的三元组列表
|
||||||
if not isinstance(rdf_triple_result, list):
|
if not isinstance(rdf_triple_result, list):
|
||||||
# 如果不是列表,可能是字典格式,尝试从中提取列表
|
if not isinstance(rdf_triple_result, dict):
|
||||||
if isinstance(rdf_triple_result, dict):
|
raise ValueError(f"RDF三元组提取结果格式错误,期望列表但得到: {type(rdf_triple_result)}")
|
||||||
# 尝试常见的键名
|
|
||||||
for key in ['triples', 'result', 'data', 'items']:
|
# 尝试常见的键名
|
||||||
if key in rdf_triple_result and isinstance(rdf_triple_result[key], list):
|
for key in ["triples", "result", "data", "items"]:
|
||||||
rdf_triple_result = rdf_triple_result[key]
|
if key in rdf_triple_result and isinstance(rdf_triple_result[key], list):
|
||||||
break
|
rdf_triple_result = rdf_triple_result[key]
|
||||||
else:
|
break
|
||||||
# 如果找不到合适的列表,抛出异常
|
|
||||||
raise Exception(f"RDF三元组提取结果格式错误,期望列表但得到: {type(rdf_triple_result)}")
|
|
||||||
else:
|
else:
|
||||||
raise Exception(f"RDF三元组提取结果格式错误,期望列表但得到: {type(rdf_triple_result)}")
|
# 如果找不到合适的列表,抛出异常
|
||||||
|
raise ValueError(f"RDF三元组提取结果格式错误,期望列表但得到: {type(rdf_triple_result)}")
|
||||||
# 验证三元组格式
|
# 验证三元组格式
|
||||||
for triple in rdf_triple_result:
|
for triple in rdf_triple_result:
|
||||||
if not isinstance(triple, list) or len(triple) != 3 or (triple[0] is None or triple[1] is None or triple[2] is None) or "" in triple:
|
if (
|
||||||
raise Exception("RDF提取结果格式错误")
|
not isinstance(triple, list)
|
||||||
|
or len(triple) != 3
|
||||||
|
or (triple[0] is None or triple[1] is None or triple[2] is None)
|
||||||
|
or "" in triple
|
||||||
|
):
|
||||||
|
raise ValueError("RDF提取结果格式错误")
|
||||||
|
|
||||||
return rdf_triple_result
|
return rdf_triple_result
|
||||||
|
|
||||||
|
|||||||
@@ -20,8 +20,7 @@ from quick_algo import di_graph, pagerank
|
|||||||
|
|
||||||
from .utils.hash import get_sha256
|
from .utils.hash import get_sha256
|
||||||
from .embedding_store import EmbeddingManager, EmbeddingStoreItem
|
from .embedding_store import EmbeddingManager, EmbeddingStoreItem
|
||||||
from .lpmmconfig import global_config
|
from src.config.config import global_config
|
||||||
from src.manager.local_store_manager import local_storage
|
|
||||||
|
|
||||||
from .global_logger import logger
|
from .global_logger import logger
|
||||||
|
|
||||||
@@ -30,19 +29,9 @@ def _get_kg_dir():
|
|||||||
"""
|
"""
|
||||||
安全地获取KG数据目录路径
|
安全地获取KG数据目录路径
|
||||||
"""
|
"""
|
||||||
root_path: str = local_storage["root_path"]
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
if root_path is None:
|
root_path: str = os.path.abspath(os.path.join(current_dir, "..", "..", ".."))
|
||||||
# 如果 local_storage 中没有 root_path,使用当前文件的相对路径作为备用
|
kg_dir = os.path.join(root_path, "data/rag")
|
||||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
|
||||||
root_path = os.path.abspath(os.path.join(current_dir, "..", "..", ".."))
|
|
||||||
logger.warning(f"local_storage 中未找到 root_path,使用备用路径: {root_path}")
|
|
||||||
|
|
||||||
# 获取RAG数据目录
|
|
||||||
rag_data_dir: str = global_config["persistence"]["rag_data_dir"]
|
|
||||||
if rag_data_dir is None:
|
|
||||||
kg_dir = os.path.join(root_path, "data/rag")
|
|
||||||
else:
|
|
||||||
kg_dir = os.path.join(root_path, rag_data_dir)
|
|
||||||
|
|
||||||
return str(kg_dir).replace("\\", "/")
|
return str(kg_dir).replace("\\", "/")
|
||||||
|
|
||||||
@@ -65,9 +54,9 @@ class KGManager:
|
|||||||
|
|
||||||
# 持久化相关 - 使用延迟初始化的路径
|
# 持久化相关 - 使用延迟初始化的路径
|
||||||
self.dir_path = get_kg_dir_str()
|
self.dir_path = get_kg_dir_str()
|
||||||
self.graph_data_path = self.dir_path + "/" + local_storage["rag_graph_namespace"] + ".graphml"
|
self.graph_data_path = self.dir_path + "/" + "rag-graph" + ".graphml"
|
||||||
self.ent_cnt_data_path = self.dir_path + "/" + local_storage["rag_ent_cnt_namespace"] + ".parquet"
|
self.ent_cnt_data_path = self.dir_path + "/" + "rag-ent-cnt" + ".parquet"
|
||||||
self.pg_hash_file_path = self.dir_path + "/" + local_storage["rag_pg_hash_namespace"] + ".json"
|
self.pg_hash_file_path = self.dir_path + "/" + "rag-pg-hash" + ".json"
|
||||||
|
|
||||||
def save_to_file(self):
|
def save_to_file(self):
|
||||||
"""将KG数据保存到文件"""
|
"""将KG数据保存到文件"""
|
||||||
@@ -122,8 +111,8 @@ class KGManager:
|
|||||||
# 避免自连接
|
# 避免自连接
|
||||||
continue
|
continue
|
||||||
# 一个triple就是一条边(同时构建双向联系)
|
# 一个triple就是一条边(同时构建双向联系)
|
||||||
hash_key1 = local_storage["ent_namespace"] + "-" + get_sha256(triple[0])
|
hash_key1 = "entity" + "-" + get_sha256(triple[0])
|
||||||
hash_key2 = local_storage["ent_namespace"] + "-" + get_sha256(triple[2])
|
hash_key2 = "entity" + "-" + get_sha256(triple[2])
|
||||||
node_to_node[(hash_key1, hash_key2)] = node_to_node.get((hash_key1, hash_key2), 0) + 1.0
|
node_to_node[(hash_key1, hash_key2)] = node_to_node.get((hash_key1, hash_key2), 0) + 1.0
|
||||||
node_to_node[(hash_key2, hash_key1)] = node_to_node.get((hash_key2, hash_key1), 0) + 1.0
|
node_to_node[(hash_key2, hash_key1)] = node_to_node.get((hash_key2, hash_key1), 0) + 1.0
|
||||||
entity_set.add(hash_key1)
|
entity_set.add(hash_key1)
|
||||||
@@ -141,8 +130,8 @@ class KGManager:
|
|||||||
"""构建实体节点与文段节点之间的关系"""
|
"""构建实体节点与文段节点之间的关系"""
|
||||||
for idx in triple_list_data:
|
for idx in triple_list_data:
|
||||||
for triple in triple_list_data[idx]:
|
for triple in triple_list_data[idx]:
|
||||||
ent_hash_key = local_storage["ent_namespace"] + "-" + get_sha256(triple[0])
|
ent_hash_key = "entity" + "-" + get_sha256(triple[0])
|
||||||
pg_hash_key = local_storage["pg_namespace"] + "-" + str(idx)
|
pg_hash_key = "paragraph" + "-" + str(idx)
|
||||||
node_to_node[(ent_hash_key, pg_hash_key)] = node_to_node.get((ent_hash_key, pg_hash_key), 0) + 1.0
|
node_to_node[(ent_hash_key, pg_hash_key)] = node_to_node.get((ent_hash_key, pg_hash_key), 0) + 1.0
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -157,12 +146,12 @@ class KGManager:
|
|||||||
ent_hash_list = set()
|
ent_hash_list = set()
|
||||||
for triple_list in triple_list_data.values():
|
for triple_list in triple_list_data.values():
|
||||||
for triple in triple_list:
|
for triple in triple_list:
|
||||||
ent_hash_list.add(local_storage["ent_namespace"] + "-" + get_sha256(triple[0]))
|
ent_hash_list.add("entity" + "-" + get_sha256(triple[0]))
|
||||||
ent_hash_list.add(local_storage["ent_namespace"] + "-" + get_sha256(triple[2]))
|
ent_hash_list.add("entity" + "-" + get_sha256(triple[2]))
|
||||||
ent_hash_list = list(ent_hash_list)
|
ent_hash_list = list(ent_hash_list)
|
||||||
|
|
||||||
synonym_hash_set = set()
|
synonym_hash_set = set()
|
||||||
synonym_result = dict()
|
synonym_result = {}
|
||||||
|
|
||||||
# rich 进度条
|
# rich 进度条
|
||||||
total = len(ent_hash_list)
|
total = len(ent_hash_list)
|
||||||
@@ -190,14 +179,14 @@ class KGManager:
|
|||||||
assert isinstance(ent, EmbeddingStoreItem)
|
assert isinstance(ent, EmbeddingStoreItem)
|
||||||
# 查询相似实体
|
# 查询相似实体
|
||||||
similar_ents = embedding_manager.entities_embedding_store.search_top_k(
|
similar_ents = embedding_manager.entities_embedding_store.search_top_k(
|
||||||
ent.embedding, global_config["rag"]["params"]["synonym_search_top_k"]
|
ent.embedding, global_config.lpmm_knowledge.rag_synonym_search_top_k
|
||||||
)
|
)
|
||||||
res_ent = [] # Debug
|
res_ent = [] # Debug
|
||||||
for res_ent_hash, similarity in similar_ents:
|
for res_ent_hash, similarity in similar_ents:
|
||||||
if res_ent_hash == ent_hash:
|
if res_ent_hash == ent_hash:
|
||||||
# 避免自连接
|
# 避免自连接
|
||||||
continue
|
continue
|
||||||
if similarity < global_config["rag"]["params"]["synonym_threshold"]:
|
if similarity < global_config.lpmm_knowledge.rag_synonym_threshold:
|
||||||
# 相似度阈值
|
# 相似度阈值
|
||||||
continue
|
continue
|
||||||
node_to_node[(res_ent_hash, ent_hash)] = similarity
|
node_to_node[(res_ent_hash, ent_hash)] = similarity
|
||||||
@@ -263,7 +252,7 @@ class KGManager:
|
|||||||
for src_tgt in node_to_node.keys():
|
for src_tgt in node_to_node.keys():
|
||||||
for node_hash in src_tgt:
|
for node_hash in src_tgt:
|
||||||
if node_hash not in existed_nodes:
|
if node_hash not in existed_nodes:
|
||||||
if node_hash.startswith(local_storage["ent_namespace"]):
|
if node_hash.startswith("entity"):
|
||||||
# 新增实体节点
|
# 新增实体节点
|
||||||
node = embedding_manager.entities_embedding_store.store.get(node_hash)
|
node = embedding_manager.entities_embedding_store.store.get(node_hash)
|
||||||
if node is None:
|
if node is None:
|
||||||
@@ -275,7 +264,7 @@ class KGManager:
|
|||||||
node_item["type"] = "ent"
|
node_item["type"] = "ent"
|
||||||
node_item["create_time"] = now_time
|
node_item["create_time"] = now_time
|
||||||
self.graph.update_node(node_item)
|
self.graph.update_node(node_item)
|
||||||
elif node_hash.startswith(local_storage["pg_namespace"]):
|
elif node_hash.startswith("paragraph"):
|
||||||
# 新增文段节点
|
# 新增文段节点
|
||||||
node = embedding_manager.paragraphs_embedding_store.store.get(node_hash)
|
node = embedding_manager.paragraphs_embedding_store.store.get(node_hash)
|
||||||
if node is None:
|
if node is None:
|
||||||
@@ -359,7 +348,7 @@ class KGManager:
|
|||||||
# 关系三元组
|
# 关系三元组
|
||||||
triple = relation[2:-2].split("', '")
|
triple = relation[2:-2].split("', '")
|
||||||
for ent in [(triple[0]), (triple[2])]:
|
for ent in [(triple[0]), (triple[2])]:
|
||||||
ent_hash = local_storage["ent_namespace"] + "-" + get_sha256(ent)
|
ent_hash = "entity" + "-" + get_sha256(ent)
|
||||||
if ent_hash in existed_nodes: # 该实体需在KG中存在
|
if ent_hash in existed_nodes: # 该实体需在KG中存在
|
||||||
if ent_hash not in ent_sim_scores: # 尚未记录的实体
|
if ent_hash not in ent_sim_scores: # 尚未记录的实体
|
||||||
ent_sim_scores[ent_hash] = []
|
ent_sim_scores[ent_hash] = []
|
||||||
@@ -380,7 +369,7 @@ class KGManager:
|
|||||||
for ent_hash in ent_weights.keys():
|
for ent_hash in ent_weights.keys():
|
||||||
ent_weights[ent_hash] = 1.0
|
ent_weights[ent_hash] = 1.0
|
||||||
else:
|
else:
|
||||||
down_edge = global_config["qa"]["params"]["paragraph_node_weight"]
|
down_edge = global_config.lpmm_knowledge.qa_paragraph_node_weight
|
||||||
# 缩放取值区间至[down_edge, 1]
|
# 缩放取值区间至[down_edge, 1]
|
||||||
for ent_hash, score in ent_weights.items():
|
for ent_hash, score in ent_weights.items():
|
||||||
# 缩放相似度
|
# 缩放相似度
|
||||||
@@ -389,7 +378,7 @@ class KGManager:
|
|||||||
) + down_edge
|
) + down_edge
|
||||||
|
|
||||||
# 取平均相似度的top_k实体
|
# 取平均相似度的top_k实体
|
||||||
top_k = global_config["qa"]["params"]["ent_filter_top_k"]
|
top_k = global_config.lpmm_knowledge.qa_ent_filter_top_k
|
||||||
if len(ent_mean_scores) > top_k:
|
if len(ent_mean_scores) > top_k:
|
||||||
# 从大到小排序,取后len - k个
|
# 从大到小排序,取后len - k个
|
||||||
ent_mean_scores = {k: v for k, v in sorted(ent_mean_scores.items(), key=lambda item: item[1], reverse=True)}
|
ent_mean_scores = {k: v for k, v in sorted(ent_mean_scores.items(), key=lambda item: item[1], reverse=True)}
|
||||||
@@ -418,7 +407,7 @@ class KGManager:
|
|||||||
|
|
||||||
for pg_hash, score in pg_sim_scores.items():
|
for pg_hash, score in pg_sim_scores.items():
|
||||||
pg_weights[pg_hash] = (
|
pg_weights[pg_hash] = (
|
||||||
score * global_config["qa"]["params"]["paragraph_node_weight"]
|
score * global_config.lpmm_knowledge.qa_paragraph_node_weight
|
||||||
) # 文段权重 = 归一化相似度 * 文段节点权重参数
|
) # 文段权重 = 归一化相似度 * 文段节点权重参数
|
||||||
del pg_sim_scores
|
del pg_sim_scores
|
||||||
|
|
||||||
@@ -431,7 +420,7 @@ class KGManager:
|
|||||||
self.graph,
|
self.graph,
|
||||||
personalization=ppr_node_weights,
|
personalization=ppr_node_weights,
|
||||||
max_iter=100,
|
max_iter=100,
|
||||||
alpha=global_config["qa"]["params"]["ppr_damping"],
|
alpha=global_config.lpmm_knowledge.qa_ppr_damping,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 获取最终结果
|
# 获取最终结果
|
||||||
@@ -439,7 +428,7 @@ class KGManager:
|
|||||||
passage_node_res = [
|
passage_node_res = [
|
||||||
(node_key, score)
|
(node_key, score)
|
||||||
for node_key, score in ppr_res.items()
|
for node_key, score in ppr_res.items()
|
||||||
if node_key.startswith(local_storage["pg_namespace"])
|
if node_key.startswith("paragraph")
|
||||||
]
|
]
|
||||||
del ppr_res
|
del ppr_res
|
||||||
|
|
||||||
|
|||||||
@@ -1,12 +1,8 @@
|
|||||||
from src.chat.knowledge.lpmmconfig import global_config
|
|
||||||
from src.chat.knowledge.embedding_store import EmbeddingManager
|
from src.chat.knowledge.embedding_store import EmbeddingManager
|
||||||
from src.chat.knowledge.llm_client import LLMClient
|
|
||||||
from src.chat.knowledge.mem_active_manager import MemoryActiveManager
|
|
||||||
from src.chat.knowledge.qa_manager import QAManager
|
from src.chat.knowledge.qa_manager import QAManager
|
||||||
from src.chat.knowledge.kg_manager import KGManager
|
from src.chat.knowledge.kg_manager import KGManager
|
||||||
from src.chat.knowledge.global_logger import logger
|
from src.chat.knowledge.global_logger import logger
|
||||||
from src.config.config import global_config as bot_global_config
|
from src.config.config import global_config
|
||||||
from src.manager.local_store_manager import local_storage
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
INVALID_ENTITY = [
|
INVALID_ENTITY = [
|
||||||
@@ -21,9 +17,6 @@ INVALID_ENTITY = [
|
|||||||
"她们",
|
"她们",
|
||||||
"它们",
|
"它们",
|
||||||
]
|
]
|
||||||
PG_NAMESPACE = "paragraph"
|
|
||||||
ENT_NAMESPACE = "entity"
|
|
||||||
REL_NAMESPACE = "relation"
|
|
||||||
|
|
||||||
RAG_GRAPH_NAMESPACE = "rag-graph"
|
RAG_GRAPH_NAMESPACE = "rag-graph"
|
||||||
RAG_ENT_CNT_NAMESPACE = "rag-ent-cnt"
|
RAG_ENT_CNT_NAMESPACE = "rag-ent-cnt"
|
||||||
@@ -34,67 +27,13 @@ ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..",
|
|||||||
DATA_PATH = os.path.join(ROOT_PATH, "data")
|
DATA_PATH = os.path.join(ROOT_PATH, "data")
|
||||||
|
|
||||||
|
|
||||||
def _initialize_knowledge_local_storage():
|
|
||||||
"""
|
|
||||||
初始化知识库相关的本地存储配置
|
|
||||||
使用字典批量设置,避免重复的if判断
|
|
||||||
"""
|
|
||||||
# 定义所有需要初始化的配置项
|
|
||||||
default_configs = {
|
|
||||||
# 路径配置
|
|
||||||
"root_path": ROOT_PATH,
|
|
||||||
"data_path": f"{ROOT_PATH}/data",
|
|
||||||
# 实体和命名空间配置
|
|
||||||
"lpmm_invalid_entity": INVALID_ENTITY,
|
|
||||||
"pg_namespace": PG_NAMESPACE,
|
|
||||||
"ent_namespace": ENT_NAMESPACE,
|
|
||||||
"rel_namespace": REL_NAMESPACE,
|
|
||||||
# RAG相关命名空间配置
|
|
||||||
"rag_graph_namespace": RAG_GRAPH_NAMESPACE,
|
|
||||||
"rag_ent_cnt_namespace": RAG_ENT_CNT_NAMESPACE,
|
|
||||||
"rag_pg_hash_namespace": RAG_PG_HASH_NAMESPACE,
|
|
||||||
}
|
|
||||||
|
|
||||||
# 日志级别映射:重要配置用info,其他用debug
|
|
||||||
important_configs = {"root_path", "data_path"}
|
|
||||||
|
|
||||||
# 批量设置配置项
|
|
||||||
initialized_count = 0
|
|
||||||
for key, default_value in default_configs.items():
|
|
||||||
if local_storage[key] is None:
|
|
||||||
local_storage[key] = default_value
|
|
||||||
|
|
||||||
# 根据重要性选择日志级别
|
|
||||||
if key in important_configs:
|
|
||||||
logger.info(f"设置{key}: {default_value}")
|
|
||||||
else:
|
|
||||||
logger.debug(f"设置{key}: {default_value}")
|
|
||||||
|
|
||||||
initialized_count += 1
|
|
||||||
|
|
||||||
if initialized_count > 0:
|
|
||||||
logger.info(f"知识库本地存储初始化完成,共设置 {initialized_count} 项配置")
|
|
||||||
else:
|
|
||||||
logger.debug("知识库本地存储配置已存在,跳过初始化")
|
|
||||||
|
|
||||||
|
|
||||||
# 初始化本地存储路径
|
|
||||||
# sourcery skip: dict-comprehension
|
|
||||||
_initialize_knowledge_local_storage()
|
|
||||||
|
|
||||||
qa_manager = None
|
qa_manager = None
|
||||||
inspire_manager = None
|
inspire_manager = None
|
||||||
|
|
||||||
# 检查LPMM知识库是否启用
|
# 检查LPMM知识库是否启用
|
||||||
if bot_global_config.lpmm_knowledge.enable:
|
if global_config.lpmm_knowledge.enable:
|
||||||
logger.info("正在初始化Mai-LPMM")
|
logger.info("正在初始化Mai-LPMM")
|
||||||
logger.info("创建LLM客户端")
|
logger.info("创建LLM客户端")
|
||||||
llm_client_list = {}
|
|
||||||
for key in global_config["llm_providers"]:
|
|
||||||
llm_client_list[key] = LLMClient(
|
|
||||||
global_config["llm_providers"][key]["base_url"], # type: ignore
|
|
||||||
global_config["llm_providers"][key]["api_key"], # type: ignore
|
|
||||||
)
|
|
||||||
|
|
||||||
# 初始化Embedding库
|
# 初始化Embedding库
|
||||||
embed_manager = EmbeddingManager()
|
embed_manager = EmbeddingManager()
|
||||||
@@ -120,7 +59,8 @@ if bot_global_config.lpmm_knowledge.enable:
|
|||||||
|
|
||||||
# 数据比对:Embedding库与KG的段落hash集合
|
# 数据比对:Embedding库与KG的段落hash集合
|
||||||
for pg_hash in kg_manager.stored_paragraph_hashes:
|
for pg_hash in kg_manager.stored_paragraph_hashes:
|
||||||
key = f"{PG_NAMESPACE}-{pg_hash}"
|
# 使用与EmbeddingStore中一致的命名空间格式
|
||||||
|
key = f"paragraph-{pg_hash}"
|
||||||
if key not in embed_manager.stored_pg_hashes:
|
if key not in embed_manager.stored_pg_hashes:
|
||||||
logger.warning(f"KG中存在Embedding库中不存在的段落:{key}")
|
logger.warning(f"KG中存在Embedding库中不存在的段落:{key}")
|
||||||
|
|
||||||
@@ -130,11 +70,11 @@ if bot_global_config.lpmm_knowledge.enable:
|
|||||||
kg_manager,
|
kg_manager,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 记忆激活(用于记忆库)
|
# # 记忆激活(用于记忆库)
|
||||||
inspire_manager = MemoryActiveManager(
|
# inspire_manager = MemoryActiveManager(
|
||||||
embed_manager,
|
# embed_manager,
|
||||||
llm_client_list[global_config["embedding"]["provider"]],
|
# llm_client_list[global_config["embedding"]["provider"]],
|
||||||
)
|
# )
|
||||||
else:
|
else:
|
||||||
logger.info("LPMM知识库已禁用,跳过初始化")
|
logger.info("LPMM知识库已禁用,跳过初始化")
|
||||||
# 创建空的占位符对象,避免导入错误
|
# 创建空的占位符对象,避免导入错误
|
||||||
|
|||||||
@@ -1,45 +0,0 @@
|
|||||||
from openai import OpenAI
|
|
||||||
|
|
||||||
|
|
||||||
class LLMMessage:
|
|
||||||
def __init__(self, role, content):
|
|
||||||
self.role = role
|
|
||||||
self.content = content
|
|
||||||
|
|
||||||
def to_dict(self):
|
|
||||||
return {"role": self.role, "content": self.content}
|
|
||||||
|
|
||||||
|
|
||||||
class LLMClient:
|
|
||||||
"""LLM客户端,对应一个API服务商"""
|
|
||||||
|
|
||||||
def __init__(self, url, api_key):
|
|
||||||
self.client = OpenAI(
|
|
||||||
base_url=url,
|
|
||||||
api_key=api_key,
|
|
||||||
)
|
|
||||||
|
|
||||||
def send_chat_request(self, model, messages):
|
|
||||||
"""发送对话请求,等待返回结果"""
|
|
||||||
response = self.client.chat.completions.create(model=model, messages=messages, stream=False)
|
|
||||||
if hasattr(response.choices[0].message, "reasoning_content"):
|
|
||||||
# 有单独的推理内容块
|
|
||||||
reasoning_content = response.choices[0].message.reasoning_content
|
|
||||||
content = response.choices[0].message.content
|
|
||||||
else:
|
|
||||||
# 无单独的推理内容块
|
|
||||||
response = response.choices[0].message.content.split("<think>")[-1].split("</think>")
|
|
||||||
# 如果有推理内容,则分割推理内容和内容
|
|
||||||
if len(response) == 2:
|
|
||||||
reasoning_content = response[0]
|
|
||||||
content = response[1]
|
|
||||||
else:
|
|
||||||
reasoning_content = None
|
|
||||||
content = response[0]
|
|
||||||
|
|
||||||
return reasoning_content, content
|
|
||||||
|
|
||||||
def send_embedding_request(self, model, text):
|
|
||||||
"""发送嵌入请求,等待返回结果"""
|
|
||||||
text = text.replace("\n", " ")
|
|
||||||
return self.client.embeddings.create(input=[text], model=model).data[0].embedding
|
|
||||||
@@ -1,137 +0,0 @@
|
|||||||
import os
|
|
||||||
import toml
|
|
||||||
import sys
|
|
||||||
|
|
||||||
# import argparse
|
|
||||||
from .global_logger import logger
|
|
||||||
|
|
||||||
PG_NAMESPACE = "paragraph"
|
|
||||||
ENT_NAMESPACE = "entity"
|
|
||||||
REL_NAMESPACE = "relation"
|
|
||||||
|
|
||||||
RAG_GRAPH_NAMESPACE = "rag-graph"
|
|
||||||
RAG_ENT_CNT_NAMESPACE = "rag-ent-cnt"
|
|
||||||
RAG_PG_HASH_NAMESPACE = "rag-pg-hash"
|
|
||||||
|
|
||||||
# 无效实体
|
|
||||||
INVALID_ENTITY = [
|
|
||||||
"",
|
|
||||||
"你",
|
|
||||||
"他",
|
|
||||||
"她",
|
|
||||||
"它",
|
|
||||||
"我们",
|
|
||||||
"你们",
|
|
||||||
"他们",
|
|
||||||
"她们",
|
|
||||||
"它们",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def _load_config(config, config_file_path):
    """Merge a TOML configuration file into ``config`` in place.

    Does nothing when the file does not exist. Every top-level key of
    ``config`` must also be present in the file; otherwise the problem is
    logged and the process exits with status 1.

    Args:
        config: Default configuration dict; updated in place.
        config_file_path: Path of the TOML file to read.
    """
    if not os.path.exists(config_file_path):
        return
    with open(config_file_path, "r", encoding="utf-8") as f:
        file_config = toml.load(f)

    # Every top-level key of the defaults must also exist in the file.
    for key in config:
        if key not in file_config:
            logger.critical(f"警告: 配置文件 '{config_file_path}' 缺少必需的顶级键: '{key}'。请检查配置文件。")
            logger.critical("请通过template/lpmm_config_template.toml文件进行更新")
            sys.exit(1)

    # Providers are merged by name, so default providers that the file does
    # not mention are kept.
    for provider in file_config.get("llm_providers", []):
        entry = config["llm_providers"].setdefault(provider["name"], {})
        entry["base_url"] = provider["base_url"]
        entry["api_key"] = provider["api_key"]

    # Sections taken verbatim from the file. "info_extraction" is a required
    # section as well, so its values are applied instead of being silently
    # ignored. "lpmm" keeps its built-in value, as before.
    replaced_sections = (
        "entity_extract",
        "rdf_build",
        "embedding",
        "rag",
        "qa",
        "persistence",
        "info_extraction",
    )
    for section in replaced_sections:
        if section in file_config:
            config[section] = file_config[section]

    logger.info(f"从文件中读取配置: {config_file_path}")
|
|
||||||
|
|
||||||
|
|
||||||
# Built-in defaults; the `_load_config` call at the end of this block overlays
# values read from the TOML file.
# NOTE(review): the default api_key looks like a placeholder - confirm.
global_config = {
    "lpmm": {"version": "0.1.0"},
    "llm_providers": {
        "localhost": {
            "base_url": "https://api.siliconflow.cn/v1",
            "api_key": "sk-ospynxadyorf",
        },
    },
    "entity_extract": {
        "llm": {"provider": "localhost", "model": "Pro/deepseek-ai/DeepSeek-V3"},
    },
    "rdf_build": {
        "llm": {"provider": "localhost", "model": "Pro/deepseek-ai/DeepSeek-V3"},
    },
    "embedding": {"provider": "localhost", "model": "Pro/BAAI/bge-m3", "dimension": 1024},
    "rag": {
        "params": {"synonym_search_top_k": 10, "synonym_threshold": 0.75},
    },
    "qa": {
        "params": {
            "relation_search_top_k": 10,
            "relation_threshold": 0.75,
            "paragraph_search_top_k": 10,
            "paragraph_node_weight": 0.05,
            "ent_filter_top_k": 10,
            "ppr_damping": 0.8,
            "res_top_k": 10,
        },
        "llm": {"provider": "localhost", "model": "qa"},
    },
    "persistence": {
        "data_root_path": "data",
        "raw_data_path": "data/raw.json",
        "openie_data_path": "data/openie.json",
        "embedding_data_dir": "data/embedding",
        "rag_data_dir": "data/rag",
    },
    "info_extraction": {"workers": 10},
}

# Three directory levels above the directory containing this module.
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir, os.pardir))
config_path = os.path.join(ROOT_PATH, "config", "lpmm_config.toml")
_load_config(global_config, config_path)
|
|
||||||
@@ -1,3 +1,4 @@
|
|||||||
|
raise DeprecationWarning("MemoryActiveManager is not used yet, please do not import it")
|
||||||
from .lpmmconfig import global_config
|
from .lpmmconfig import global_config
|
||||||
from .embedding_store import EmbeddingManager
|
from .embedding_store import EmbeddingManager
|
||||||
from .llm_client import LLMClient
|
from .llm_client import LLMClient
|
||||||
|
|||||||
@@ -2,16 +2,14 @@ import time
|
|||||||
from typing import Tuple, List, Dict, Optional
|
from typing import Tuple, List, Dict, Optional
|
||||||
|
|
||||||
from .global_logger import logger
|
from .global_logger import logger
|
||||||
|
|
||||||
# from . import prompt_template
|
|
||||||
from .embedding_store import EmbeddingManager
|
from .embedding_store import EmbeddingManager
|
||||||
# from .llm_client import LLMClient
|
|
||||||
from .kg_manager import KGManager
|
from .kg_manager import KGManager
|
||||||
|
|
||||||
# from .lpmmconfig import global_config
|
# from .lpmmconfig import global_config
|
||||||
from .utils.dyn_topk import dyn_select_top_k
|
from .utils.dyn_topk import dyn_select_top_k
|
||||||
from src.llm_models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
from src.chat.utils.utils import get_embedding
|
from src.chat.utils.utils import get_embedding
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config, model_config
|
||||||
|
|
||||||
MAX_KNOWLEDGE_LENGTH = 10000 # 最大知识长度
|
MAX_KNOWLEDGE_LENGTH = 10000 # 最大知识长度
|
||||||
|
|
||||||
@@ -21,17 +19,14 @@ class QAManager:
|
|||||||
self,
|
self,
|
||||||
embed_manager: EmbeddingManager,
|
embed_manager: EmbeddingManager,
|
||||||
kg_manager: KGManager,
|
kg_manager: KGManager,
|
||||||
|
|
||||||
):
|
):
|
||||||
self.embed_manager = embed_manager
|
self.embed_manager = embed_manager
|
||||||
self.kg_manager = kg_manager
|
self.kg_manager = kg_manager
|
||||||
# TODO: API-Adapter修改标记
|
self.qa_model = LLMRequest(model_set=model_config.model_task_config.lpmm_qa, request_type="lpmm.qa")
|
||||||
self.qa_model = LLMRequest(
|
|
||||||
model=global_config.model.lpmm_qa,
|
|
||||||
request_type="lpmm.qa"
|
|
||||||
)
|
|
||||||
|
|
||||||
async def process_query(self, question: str) -> Tuple[List[Tuple[str, float, float]], Optional[Dict[str, float]]]:
|
async def process_query(
|
||||||
|
self, question: str
|
||||||
|
) -> Optional[Tuple[List[Tuple[str, float, float]], Optional[Dict[str, float]]]]:
|
||||||
"""处理查询"""
|
"""处理查询"""
|
||||||
|
|
||||||
# 生成问题的Embedding
|
# 生成问题的Embedding
|
||||||
@@ -49,66 +44,71 @@ class QAManager:
|
|||||||
question_embedding,
|
question_embedding,
|
||||||
global_config.lpmm_knowledge.qa_relation_search_top_k,
|
global_config.lpmm_knowledge.qa_relation_search_top_k,
|
||||||
)
|
)
|
||||||
if relation_search_res is not None:
|
if relation_search_res is None:
|
||||||
# 过滤阈值
|
return None
|
||||||
# 考虑动态阈值:当存在显著数值差异的结果时,保留显著结果;否则,保留所有结果
|
# 过滤阈值
|
||||||
relation_search_res = dyn_select_top_k(relation_search_res, 0.5, 1.0)
|
# 考虑动态阈值:当存在显著数值差异的结果时,保留显著结果;否则,保留所有结果
|
||||||
if relation_search_res[0][1] < global_config.lpmm_knowledge.qa_relation_threshold:
|
relation_search_res = dyn_select_top_k(relation_search_res, 0.5, 1.0)
|
||||||
# 未找到相关关系
|
if not relation_search_res or relation_search_res[0][1] < global_config.lpmm_knowledge.qa_relation_threshold:
|
||||||
logger.debug("未找到相关关系,跳过关系检索")
|
# 未找到相关关系
|
||||||
relation_search_res = []
|
logger.debug("未找到相关关系,跳过关系检索")
|
||||||
|
relation_search_res = []
|
||||||
|
|
||||||
part_end_time = time.perf_counter()
|
part_end_time = time.perf_counter()
|
||||||
logger.debug(f"关系检索用时:{part_end_time - part_start_time:.5f}s")
|
logger.debug(f"关系检索用时:{part_end_time - part_start_time:.5f}s")
|
||||||
|
|
||||||
for res in relation_search_res:
|
for res in relation_search_res:
|
||||||
rel_str = self.embed_manager.relation_embedding_store.store.get(res[0]).str
|
if store_item := self.embed_manager.relation_embedding_store.store.get(res[0]):
|
||||||
print(f"找到相关关系,相似度:{(res[1] * 100):.2f}% - {rel_str}")
|
rel_str = store_item.str
|
||||||
|
print(f"找到相关关系,相似度:{(res[1] * 100):.2f}% - {rel_str}")
|
||||||
|
|
||||||
# TODO: 使用LLM过滤三元组结果
|
# TODO: 使用LLM过滤三元组结果
|
||||||
# logger.info(f"LLM过滤三元组用时:{time.time() - part_start_time:.2f}s")
|
# logger.info(f"LLM过滤三元组用时:{time.time() - part_start_time:.2f}s")
|
||||||
# part_start_time = time.time()
|
# part_start_time = time.time()
|
||||||
|
|
||||||
# 根据问题Embedding查询Paragraph Embedding库
|
# 根据问题Embedding查询Paragraph Embedding库
|
||||||
|
part_start_time = time.perf_counter()
|
||||||
|
paragraph_search_res = self.embed_manager.paragraphs_embedding_store.search_top_k(
|
||||||
|
question_embedding,
|
||||||
|
global_config.lpmm_knowledge.qa_paragraph_search_top_k,
|
||||||
|
)
|
||||||
|
part_end_time = time.perf_counter()
|
||||||
|
logger.debug(f"文段检索用时:{part_end_time - part_start_time:.5f}s")
|
||||||
|
|
||||||
|
if len(relation_search_res) != 0:
|
||||||
|
logger.info("找到相关关系,将使用RAG进行检索")
|
||||||
|
# 使用KG检索
|
||||||
part_start_time = time.perf_counter()
|
part_start_time = time.perf_counter()
|
||||||
paragraph_search_res = self.embed_manager.paragraphs_embedding_store.search_top_k(
|
result, ppr_node_weights = self.kg_manager.kg_search(
|
||||||
question_embedding,
|
relation_search_res, paragraph_search_res, self.embed_manager
|
||||||
global_config.lpmm_knowledge.qa_paragraph_search_top_k,
|
|
||||||
)
|
)
|
||||||
part_end_time = time.perf_counter()
|
part_end_time = time.perf_counter()
|
||||||
logger.debug(f"文段检索用时:{part_end_time - part_start_time:.5f}s")
|
logger.info(f"RAG检索用时:{part_end_time - part_start_time:.5f}s")
|
||||||
|
|
||||||
if len(relation_search_res) != 0:
|
|
||||||
logger.info("找到相关关系,将使用RAG进行检索")
|
|
||||||
# 使用KG检索
|
|
||||||
part_start_time = time.perf_counter()
|
|
||||||
result, ppr_node_weights = self.kg_manager.kg_search(
|
|
||||||
relation_search_res, paragraph_search_res, self.embed_manager
|
|
||||||
)
|
|
||||||
part_end_time = time.perf_counter()
|
|
||||||
logger.info(f"RAG检索用时:{part_end_time - part_start_time:.5f}s")
|
|
||||||
else:
|
|
||||||
logger.info("未找到相关关系,将使用文段检索结果")
|
|
||||||
result = paragraph_search_res
|
|
||||||
ppr_node_weights = None
|
|
||||||
|
|
||||||
# 过滤阈值
|
|
||||||
result = dyn_select_top_k(result, 0.5, 1.0)
|
|
||||||
|
|
||||||
for res in result:
|
|
||||||
raw_paragraph = self.embed_manager.paragraphs_embedding_store.store[res[0]].str
|
|
||||||
print(f"找到相关文段,相关系数:{res[1]:.8f}\n{raw_paragraph}\n\n")
|
|
||||||
|
|
||||||
return result, ppr_node_weights
|
|
||||||
else:
|
else:
|
||||||
return None
|
logger.info("未找到相关关系,将使用文段检索结果")
|
||||||
|
result = paragraph_search_res
|
||||||
|
ppr_node_weights = None
|
||||||
|
|
||||||
async def get_knowledge(self, question: str) -> str:
|
# 过滤阈值
|
||||||
|
result = dyn_select_top_k(result, 0.5, 1.0)
|
||||||
|
|
||||||
|
for res in result:
|
||||||
|
raw_paragraph = self.embed_manager.paragraphs_embedding_store.store[res[0]].str
|
||||||
|
print(f"找到相关文段,相关系数:{res[1]:.8f}\n{raw_paragraph}\n\n")
|
||||||
|
|
||||||
|
return result, ppr_node_weights
|
||||||
|
|
||||||
|
async def get_knowledge(self, question: str) -> Optional[str]:
|
||||||
"""获取知识"""
|
"""获取知识"""
|
||||||
# 处理查询
|
# 处理查询
|
||||||
processed_result = await self.process_query(question)
|
processed_result = await self.process_query(question)
|
||||||
if processed_result is not None:
|
if processed_result is not None:
|
||||||
query_res = processed_result[0]
|
query_res = processed_result[0]
|
||||||
|
# 检查查询结果是否为空
|
||||||
|
if not query_res:
|
||||||
|
logger.debug("知识库查询结果为空,可能是知识库中没有相关内容")
|
||||||
|
return None
|
||||||
|
|
||||||
knowledge = [
|
knowledge = [
|
||||||
(
|
(
|
||||||
self.embed_manager.paragraphs_embedding_store.store[res[0]].str,
|
self.embed_manager.paragraphs_embedding_store.store[res[0]].str,
|
||||||
|
|||||||
@@ -1,48 +0,0 @@
|
|||||||
import json
|
|
||||||
import os
|
|
||||||
|
|
||||||
from .global_logger import logger
|
|
||||||
from .lpmmconfig import global_config
|
|
||||||
from src.chat.knowledge.utils.hash import get_sha256
|
|
||||||
|
|
||||||
|
|
||||||
def load_raw_data(path: str = None) -> tuple[list[str], list[str]]:
    """Load raw paragraphs from a JSON file, dropping non-strings and duplicates.

    Args:
        path: Optional path of the JSON file to read. When empty, the path is
            taken from ``global_config["persistence"]["raw_data_path"]``.

    Returns:
        ``(sha256_list, raw_data)``: the hash of every kept item followed by
        the kept items themselves. Both lists are index-aligned and keep the
        order of the file.

    Raises:
        FileNotFoundError: If the JSON file does not exist.
    """
    json_path = path if path else global_config["persistence"]["raw_data_path"]
    if not os.path.exists(json_path):
        raise FileNotFoundError(f"原始数据文件读取失败: {json_path}")
    with open(json_path, "r", encoding="utf-8") as f:
        import_json = json.load(f)

    # Example of the expected file content (a JSON array of strings):
    # ["The capital of China is Beijing. The capital of France is Paris."]
    raw_data = []
    sha256_list = []
    seen_hashes = set()  # O(1) duplicate detection alongside the ordered list
    for item in import_json:
        if not isinstance(item, str):
            logger.warning("数据类型错误:{}".format(item))
            continue
        pg_hash = get_sha256(item)
        if pg_hash in seen_hashes:
            logger.warning("重复数据:{}".format(item))
            continue
        seen_hashes.add(pg_hash)
        sha256_list.append(pg_hash)
        raw_data.append(item)
    logger.info("共读取到{}条数据".format(len(raw_data)))

    return sha256_list, raw_data
|
|
||||||
@@ -5,6 +5,10 @@ def dyn_select_top_k(
|
|||||||
score: List[Tuple[Any, float]], jmp_factor: float, var_factor: float
|
score: List[Tuple[Any, float]], jmp_factor: float, var_factor: float
|
||||||
) -> List[Tuple[Any, float, float]]:
|
) -> List[Tuple[Any, float, float]]:
|
||||||
"""动态TopK选择"""
|
"""动态TopK选择"""
|
||||||
|
# 检查输入列表是否为空
|
||||||
|
if not score:
|
||||||
|
return []
|
||||||
|
|
||||||
# 按照分数排序(降序)
|
# 按照分数排序(降序)
|
||||||
sorted_score = sorted(score, key=lambda x: x[1], reverse=True)
|
sorted_score = sorted(score, key=lambda x: x[1], reverse=True)
|
||||||
|
|
||||||
|
|||||||
@@ -1,17 +0,0 @@
|
|||||||
import networkx as nx
|
|
||||||
from matplotlib import pyplot as plt
|
|
||||||
|
|
||||||
|
|
||||||
def draw_graph_and_show(graph):
    """Draw ``graph`` on a 1280x1280 pixel canvas and show it.

    Nodes are labelled with their ``content`` attribute.
    """
    # 12.8 in x 12.8 in at 100 dpi gives the 1280x1280 canvas.
    canvas = plt.figure(1, figsize=(12.8, 12.8), dpi=100)
    node_labels = nx.get_node_attributes(graph, "content")
    draw_options = {
        "node_size": 100,
        "width": 0.5,
        "with_labels": True,
        "labels": node_labels,
        "font_family": "Sarasa Mono SC",
        "font_size": 8,
    }
    nx.draw_networkx(graph, **draw_options)
    canvas.show()
|
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -3,13 +3,16 @@ import time
|
|||||||
import re
|
import re
|
||||||
import json
|
import json
|
||||||
import ast
|
import ast
|
||||||
from json_repair import repair_json
|
|
||||||
from src.llm_models.utils_model import LLMRequest
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
from src.config.config import global_config
|
from json_repair import repair_json
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
|
from src.llm_models.utils_model import LLMRequest
|
||||||
|
from src.common.logger import get_logger
|
||||||
from src.common.database.database_model import Memory # Peewee Models导入
|
from src.common.database.database_model import Memory # Peewee Models导入
|
||||||
|
from src.config.config import model_config
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
@@ -35,8 +38,7 @@ class InstantMemory:
|
|||||||
self.chat_id = chat_id
|
self.chat_id = chat_id
|
||||||
self.last_view_time = time.time()
|
self.last_view_time = time.time()
|
||||||
self.summary_model = LLMRequest(
|
self.summary_model = LLMRequest(
|
||||||
model=global_config.model.memory,
|
model_set=model_config.model_task_config.utils,
|
||||||
temperature=0.5,
|
|
||||||
request_type="memory.summary",
|
request_type="memory.summary",
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -48,14 +50,11 @@ class InstantMemory:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response, _ = await self.summary_model.generate_response_async(prompt)
|
response, _ = await self.summary_model.generate_response_async(prompt, temperature=0.5)
|
||||||
print(prompt)
|
print(prompt)
|
||||||
print(response)
|
print(response)
|
||||||
|
|
||||||
if "1" in response:
|
return "1" in response
|
||||||
return True
|
|
||||||
else:
|
|
||||||
return False
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"判断是否需要记忆出现错误:{str(e)} {traceback.format_exc()}")
|
logger.error(f"判断是否需要记忆出现错误:{str(e)} {traceback.format_exc()}")
|
||||||
return False
|
return False
|
||||||
@@ -71,9 +70,9 @@ class InstantMemory:
|
|||||||
}}
|
}}
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
response, _ = await self.summary_model.generate_response_async(prompt)
|
response, _ = await self.summary_model.generate_response_async(prompt, temperature=0.5)
|
||||||
print(prompt)
|
# print(prompt)
|
||||||
print(response)
|
# print(response)
|
||||||
if not response:
|
if not response:
|
||||||
return None
|
return None
|
||||||
try:
|
try:
|
||||||
@@ -142,7 +141,7 @@ class InstantMemory:
|
|||||||
请只输出json格式,不要输出其他多余内容
|
请只输出json格式,不要输出其他多余内容
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
response, _ = await self.summary_model.generate_response_async(prompt)
|
response, _ = await self.summary_model.generate_response_async(prompt, temperature=0.5)
|
||||||
print(prompt)
|
print(prompt)
|
||||||
print(response)
|
print(response)
|
||||||
if not response:
|
if not response:
|
||||||
@@ -177,7 +176,7 @@ class InstantMemory:
|
|||||||
|
|
||||||
for mem in query:
|
for mem in query:
|
||||||
# 对每条记忆
|
# 对每条记忆
|
||||||
mem_keywords = mem.keywords or []
|
mem_keywords = mem.keywords or ""
|
||||||
parsed = ast.literal_eval(mem_keywords)
|
parsed = ast.literal_eval(mem_keywords)
|
||||||
if isinstance(parsed, list):
|
if isinstance(parsed, list):
|
||||||
mem_keywords = [str(k).strip() for k in parsed if str(k).strip()]
|
mem_keywords = [str(k).strip() for k in parsed if str(k).strip()]
|
||||||
@@ -201,6 +200,7 @@ class InstantMemory:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def _parse_time_range(self, time_str):
|
def _parse_time_range(self, time_str):
|
||||||
|
# sourcery skip: extract-duplicate-method, use-contextlib-suppress
|
||||||
"""
|
"""
|
||||||
支持解析如下格式:
|
支持解析如下格式:
|
||||||
- 具体日期时间:YYYY-MM-DD HH:MM:SS
|
- 具体日期时间:YYYY-MM-DD HH:MM:SS
|
||||||
@@ -208,8 +208,6 @@ class InstantMemory:
|
|||||||
- 相对时间:今天,昨天,前天,N天前,N个月前
|
- 相对时间:今天,昨天,前天,N天前,N个月前
|
||||||
- 空字符串:返回(None, None)
|
- 空字符串:返回(None, None)
|
||||||
"""
|
"""
|
||||||
from datetime import datetime, timedelta
|
|
||||||
|
|
||||||
now = datetime.now()
|
now = datetime.now()
|
||||||
if not time_str:
|
if not time_str:
|
||||||
return 0, now
|
return 0, now
|
||||||
@@ -239,14 +237,12 @@ class InstantMemory:
|
|||||||
start = (now - timedelta(days=2)).replace(hour=0, minute=0, second=0, microsecond=0)
|
start = (now - timedelta(days=2)).replace(hour=0, minute=0, second=0, microsecond=0)
|
||||||
end = start + timedelta(days=1)
|
end = start + timedelta(days=1)
|
||||||
return start, end
|
return start, end
|
||||||
m = re.match(r"(\d+)天前", time_str)
|
if m := re.match(r"(\d+)天前", time_str):
|
||||||
if m:
|
|
||||||
days = int(m.group(1))
|
days = int(m.group(1))
|
||||||
start = (now - timedelta(days=days)).replace(hour=0, minute=0, second=0, microsecond=0)
|
start = (now - timedelta(days=days)).replace(hour=0, minute=0, second=0, microsecond=0)
|
||||||
end = start + timedelta(days=1)
|
end = start + timedelta(days=1)
|
||||||
return start, end
|
return start, end
|
||||||
m = re.match(r"(\d+)个月前", time_str)
|
if m := re.match(r"(\d+)个月前", time_str):
|
||||||
if m:
|
|
||||||
months = int(m.group(1))
|
months = int(m.group(1))
|
||||||
# 近似每月30天
|
# 近似每月30天
|
||||||
start = (now - timedelta(days=months * 30)).replace(hour=0, minute=0, second=0, microsecond=0)
|
start = (now - timedelta(days=months * 30)).replace(hour=0, minute=0, second=0, microsecond=0)
|
||||||
|
|||||||
@@ -1,13 +1,17 @@
|
|||||||
|
import json
|
||||||
|
|
||||||
|
from json_repair import repair_json
|
||||||
|
from typing import List, Tuple
|
||||||
|
|
||||||
|
|
||||||
from src.llm_models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config, model_config
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||||
from datetime import datetime
|
|
||||||
from src.chat.memory_system.Hippocampus import hippocampus_manager
|
from src.chat.memory_system.Hippocampus import hippocampus_manager
|
||||||
from typing import List, Dict
|
from src.chat.utils.utils import parse_keywords_string
|
||||||
import difflib
|
from src.chat.utils.chat_message_builder import build_readable_messages
|
||||||
import json
|
import random
|
||||||
from json_repair import repair_json
|
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger("memory_activator")
|
logger = get_logger("memory_activator")
|
||||||
@@ -38,20 +42,20 @@ def get_keywords_from_json(json_str) -> List:
|
|||||||
def init_prompt():
|
def init_prompt():
|
||||||
# --- Group Chat Prompt ---
|
# --- Group Chat Prompt ---
|
||||||
memory_activator_prompt = """
|
memory_activator_prompt = """
|
||||||
你是一个记忆分析器,你需要根据以下信息来进行回忆
|
你需要根据以下信息来挑选合适的记忆编号
|
||||||
以下是一段聊天记录,请根据这些信息,总结出几个关键词作为记忆回忆的触发词
|
以下是一段聊天记录,请根据这些信息,和下方的记忆,挑选和群聊内容有关的记忆编号
|
||||||
|
|
||||||
聊天记录:
|
聊天记录:
|
||||||
{obs_info_text}
|
{obs_info_text}
|
||||||
你想要回复的消息:
|
你想要回复的消息:
|
||||||
{target_message}
|
{target_message}
|
||||||
|
|
||||||
历史关键词(请避免重复提取这些关键词):
|
记忆:
|
||||||
{cached_keywords}
|
{memory_info}
|
||||||
|
|
||||||
请输出一个json格式,包含以下字段:
|
请输出一个json格式,包含以下字段:
|
||||||
{{
|
{{
|
||||||
"keywords": ["关键词1", "关键词2", "关键词3",......]
|
"memory_ids": "记忆1编号,记忆2编号,记忆3编号,......"
|
||||||
}}
|
}}
|
||||||
不要输出其他多余内容,只输出json格式就好
|
不要输出其他多余内容,只输出json格式就好
|
||||||
"""
|
"""
|
||||||
@@ -61,83 +65,197 @@ def init_prompt():
|
|||||||
|
|
||||||
class MemoryActivator:
|
class MemoryActivator:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
# TODO: API-Adapter修改标记
|
|
||||||
|
|
||||||
self.key_words_model = LLMRequest(
|
self.key_words_model = LLMRequest(
|
||||||
model=global_config.model.utils_small,
|
model_set=model_config.model_task_config.utils_small,
|
||||||
temperature=0.5,
|
|
||||||
request_type="memory.activator",
|
request_type="memory.activator",
|
||||||
)
|
)
|
||||||
|
# 用于记忆选择的 LLM 模型
|
||||||
|
self.memory_selection_model = LLMRequest(
|
||||||
|
model_set=model_config.model_task_config.utils_small,
|
||||||
|
request_type="memory.selection",
|
||||||
|
)
|
||||||
|
|
||||||
self.running_memory = []
|
|
||||||
self.cached_keywords = set() # 用于缓存历史关键词
|
|
||||||
|
|
||||||
async def activate_memory_with_chat_history(self, target_message, chat_history_prompt) -> List[Dict]:
|
async def activate_memory_with_chat_history(self, target_message, chat_history_prompt) -> List[Tuple[str, str]]:
|
||||||
"""
|
"""
|
||||||
激活记忆
|
激活记忆
|
||||||
"""
|
"""
|
||||||
# 如果记忆系统被禁用,直接返回空列表
|
# 如果记忆系统被禁用,直接返回空列表
|
||||||
if not global_config.memory.enable_memory:
|
if not global_config.memory.enable_memory:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
# 将缓存的关键词转换为字符串,用于prompt
|
keywords_list = set()
|
||||||
cached_keywords_str = ", ".join(self.cached_keywords) if self.cached_keywords else "暂无历史关键词"
|
|
||||||
|
for msg in chat_history_prompt:
|
||||||
prompt = await global_prompt_manager.format_prompt(
|
keywords = parse_keywords_string(msg.get("key_words", ""))
|
||||||
"memory_activator_prompt",
|
if keywords:
|
||||||
obs_info_text=chat_history_prompt,
|
if len(keywords_list) < 30:
|
||||||
target_message=target_message,
|
# 最多容纳30个关键词
|
||||||
cached_keywords=cached_keywords_str,
|
keywords_list.update(keywords)
|
||||||
)
|
logger.debug(f"提取关键词: {keywords_list}")
|
||||||
|
else:
|
||||||
# logger.debug(f"prompt: {prompt}")
|
break
|
||||||
|
|
||||||
response, (reasoning_content, model_name) = await self.key_words_model.generate_response_async(prompt)
|
if not keywords_list:
|
||||||
|
logger.debug("没有提取到关键词,返回空记忆列表")
|
||||||
keywords = list(get_keywords_from_json(response))
|
return []
|
||||||
|
|
||||||
# 更新关键词缓存
|
# 从海马体获取相关记忆
|
||||||
if keywords:
|
|
||||||
# 限制缓存大小,最多保留10个关键词
|
|
||||||
if len(self.cached_keywords) > 10:
|
|
||||||
# 转换为列表,移除最早的关键词
|
|
||||||
cached_list = list(self.cached_keywords)
|
|
||||||
self.cached_keywords = set(cached_list[-8:])
|
|
||||||
|
|
||||||
# 添加新的关键词到缓存
|
|
||||||
self.cached_keywords.update(keywords)
|
|
||||||
|
|
||||||
# 调用记忆系统获取相关记忆
|
|
||||||
related_memory = await hippocampus_manager.get_memory_from_topic(
|
related_memory = await hippocampus_manager.get_memory_from_topic(
|
||||||
valid_keywords=keywords, max_memory_num=3, max_memory_length=2, max_depth=3
|
valid_keywords=list(keywords_list), max_memory_num=5, max_memory_length=3, max_depth=3
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug(f"当前记忆关键词: {self.cached_keywords} ")
|
# logger.info(f"当前记忆关键词: {keywords_list}")
|
||||||
logger.debug(f"获取到的记忆: {related_memory}")
|
logger.debug(f"获取到的记忆: {related_memory}")
|
||||||
|
|
||||||
|
if not related_memory:
|
||||||
|
logger.debug("海马体没有返回相关记忆")
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# 激活时,所有已有记忆的duration+1,达到3则移除
|
used_ids = set()
|
||||||
for m in self.running_memory[:]:
|
candidate_memories = []
|
||||||
m["duration"] = m.get("duration", 1) + 1
|
|
||||||
self.running_memory = [m for m in self.running_memory if m["duration"] < 3]
|
|
||||||
|
|
||||||
if related_memory:
|
# 为每个记忆分配随机ID并过滤相关记忆
|
||||||
for topic, memory in related_memory:
|
for memory in related_memory:
|
||||||
# 检查是否已存在相同topic或相似内容(相似度>=0.7)的记忆
|
keyword, content = memory
|
||||||
exists = any(
|
found = False
|
||||||
m["topic"] == topic or difflib.SequenceMatcher(None, m["content"], memory).ratio() >= 0.7
|
for kw in keywords_list:
|
||||||
for m in self.running_memory
|
if kw in content:
|
||||||
)
|
found = True
|
||||||
if not exists:
|
break
|
||||||
self.running_memory.append(
|
|
||||||
{"topic": topic, "content": memory, "timestamp": datetime.now().isoformat(), "duration": 1}
|
if found:
|
||||||
)
|
# 随机分配一个不重复的2位数id
|
||||||
logger.debug(f"添加新记忆: {topic} - {memory}")
|
while True:
|
||||||
|
random_id = "{:02d}".format(random.randint(0, 99))
|
||||||
|
if random_id not in used_ids:
|
||||||
|
used_ids.add(random_id)
|
||||||
|
break
|
||||||
|
candidate_memories.append({"memory_id": random_id, "keyword": keyword, "content": content})
|
||||||
|
|
||||||
# 限制同时加载的记忆条数,最多保留最后3条
|
if not candidate_memories:
|
||||||
if len(self.running_memory) > 3:
|
logger.info("没有找到相关的候选记忆")
|
||||||
self.running_memory = self.running_memory[-3:]
|
return []
|
||||||
|
|
||||||
|
# 如果只有少量记忆,直接返回
|
||||||
|
if len(candidate_memories) <= 2:
|
||||||
|
logger.debug(f"候选记忆较少({len(candidate_memories)}个),直接返回")
|
||||||
|
# 转换为 (keyword, content) 格式
|
||||||
|
return [(mem["keyword"], mem["content"]) for mem in candidate_memories]
|
||||||
|
|
||||||
|
# 使用 LLM 选择合适的记忆
|
||||||
|
selected_memories = await self._select_memories_with_llm(target_message, chat_history_prompt, candidate_memories)
|
||||||
|
|
||||||
|
return selected_memories
|
||||||
|
|
||||||
|
async def _select_memories_with_llm(self, target_message, chat_history_prompt, candidate_memories) -> List[Tuple[str, str]]:
|
||||||
|
"""
|
||||||
|
使用 LLM 选择合适的记忆
|
||||||
|
|
||||||
|
Args:
|
||||||
|
target_message: 目标消息
|
||||||
|
chat_history_prompt: 聊天历史
|
||||||
|
candidate_memories: 候选记忆列表,每个记忆包含 memory_id、keyword、content
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[Tuple[str, str]]: 选择的记忆列表,格式为 (keyword, content)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 构建聊天历史字符串
|
||||||
|
obs_info_text = build_readable_messages(
|
||||||
|
chat_history_prompt,
|
||||||
|
replace_bot_name=True,
|
||||||
|
merge_messages=False,
|
||||||
|
timestamp_mode="relative",
|
||||||
|
read_mark=0.0,
|
||||||
|
show_actions=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# 构建记忆信息字符串
|
||||||
|
memory_lines = []
|
||||||
|
for memory in candidate_memories:
|
||||||
|
memory_id = memory["memory_id"]
|
||||||
|
keyword = memory["keyword"]
|
||||||
|
content = memory["content"]
|
||||||
|
|
||||||
|
# 将 content 列表转换为字符串
|
||||||
|
if isinstance(content, list):
|
||||||
|
content_str = " | ".join(str(item) for item in content)
|
||||||
|
else:
|
||||||
|
content_str = str(content)
|
||||||
|
|
||||||
|
memory_lines.append(f"记忆编号 {memory_id}: [关键词: {keyword}] {content_str}")
|
||||||
|
|
||||||
|
memory_info = "\n".join(memory_lines)
|
||||||
|
|
||||||
|
# 获取并格式化 prompt
|
||||||
|
prompt_template = await global_prompt_manager.get_prompt_async("memory_activator_prompt")
|
||||||
|
formatted_prompt = prompt_template.format(
|
||||||
|
obs_info_text=obs_info_text,
|
||||||
|
target_message=target_message,
|
||||||
|
memory_info=memory_info
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# 调用 LLM
|
||||||
|
response, (reasoning_content, model_name, _) = await self.memory_selection_model.generate_response_async(
|
||||||
|
formatted_prompt,
|
||||||
|
temperature=0.3,
|
||||||
|
max_tokens=150
|
||||||
|
)
|
||||||
|
|
||||||
|
if global_config.debug.show_prompt:
|
||||||
|
logger.info(f"记忆选择 prompt: {formatted_prompt}")
|
||||||
|
logger.info(f"LLM 记忆选择响应: {response}")
|
||||||
|
else:
|
||||||
|
logger.debug(f"记忆选择 prompt: {formatted_prompt}")
|
||||||
|
logger.debug(f"LLM 记忆选择响应: {response}")
|
||||||
|
|
||||||
|
# 解析响应获取选择的记忆编号
|
||||||
|
try:
|
||||||
|
fixed_json = repair_json(response)
|
||||||
|
|
||||||
|
# 解析为 Python 对象
|
||||||
|
result = json.loads(fixed_json) if isinstance(fixed_json, str) else fixed_json
|
||||||
|
|
||||||
|
# 提取 memory_ids 字段
|
||||||
|
memory_ids_str = result.get("memory_ids", "")
|
||||||
|
|
||||||
|
# 解析逗号分隔的编号
|
||||||
|
if memory_ids_str:
|
||||||
|
memory_ids = [mid.strip() for mid in str(memory_ids_str).split(",") if mid.strip()]
|
||||||
|
# 过滤掉空字符串和无效编号
|
||||||
|
valid_memory_ids = [mid for mid in memory_ids if mid and len(mid) <= 3]
|
||||||
|
selected_memory_ids = valid_memory_ids
|
||||||
|
else:
|
||||||
|
selected_memory_ids = []
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"解析记忆选择响应失败: {e}", exc_info=True)
|
||||||
|
selected_memory_ids = []
|
||||||
|
|
||||||
|
# 根据编号筛选记忆
|
||||||
|
selected_memories = []
|
||||||
|
memory_id_to_memory = {mem["memory_id"]: mem for mem in candidate_memories}
|
||||||
|
|
||||||
|
for memory_id in selected_memory_ids:
|
||||||
|
if memory_id in memory_id_to_memory:
|
||||||
|
selected_memories.append(memory_id_to_memory[memory_id])
|
||||||
|
|
||||||
|
logger.info(f"LLM 选择的记忆编号: {selected_memory_ids}")
|
||||||
|
logger.info(f"最终选择的记忆数量: {len(selected_memories)}")
|
||||||
|
|
||||||
|
# 转换为 (keyword, content) 格式
|
||||||
|
return [(mem["keyword"], mem["content"]) for mem in selected_memories]
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"LLM 选择记忆时出错: {e}", exc_info=True)
|
||||||
|
# 出错时返回前3个候选记忆作为备选,转换为 (keyword, content) 格式
|
||||||
|
return [(mem["keyword"], mem["content"]) for mem in candidate_memories[:3]]
|
||||||
|
|
||||||
return self.running_memory
|
|
||||||
|
|
||||||
|
|
||||||
init_prompt()
|
init_prompt()
|
||||||
|
|||||||
@@ -1,126 +0,0 @@
|
|||||||
import numpy as np
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
from rich.traceback import install
|
|
||||||
|
|
||||||
install(extra_lines=3)
|
|
||||||
|
|
||||||
|
|
||||||
class MemoryBuildScheduler:
    """Sample past points in time from a mixture of two normal distributions.

    Each component describes "how many hours ago" as a normal distribution
    (mean and standard deviation in hours). The two components are mixed
    according to their normalized weights.
    """

    def __init__(self, n_hours1, std_hours1, weight1, n_hours2, std_hours2, weight2, total_samples=50):
        """
        Initialize the memory-build scheduler.

        Args:
            n_hours1 (float): Mean of the first distribution (hours before now).
            std_hours1 (float): Standard deviation of the first distribution (hours).
            weight1 (float): Weight of the first distribution.
            n_hours2 (float): Mean of the second distribution (hours before now).
            std_hours2 (float): Standard deviation of the second distribution (hours).
            weight2 (float): Weight of the second distribution.
            total_samples (int): Total number of time points to generate.

        Raises:
            ValueError: If ``total_samples`` is not positive, a weight or a
                standard deviation is negative, or both weights are zero.
        """
        # Validate arguments before storing anything.
        if total_samples <= 0:
            raise ValueError("total_samples 必须大于0")
        if weight1 < 0 or weight2 < 0:
            raise ValueError("权重必须为非负数")
        if std_hours1 < 0 or std_hours2 < 0:
            raise ValueError("标准差必须为非负数")

        # Normalize the weights so that they sum to 1.
        total_weight = weight1 + weight2
        if total_weight == 0:
            raise ValueError("权重总和不能为0")
        self.weight1 = weight1 / total_weight
        self.weight2 = weight2 / total_weight

        self.n_hours1 = n_hours1
        self.std_hours1 = std_hours1
        self.n_hours2 = n_hours2
        self.std_hours2 = std_hours2
        self.total_samples = total_samples
        # Reference instant: every generated timestamp is an offset back from
        # this moment, captured once at construction time.
        self.base_time = datetime.now()

    def _split_sample_counts(self):
        """Return ``(count1, count2)`` that sum to exactly ``total_samples``.

        A component with zero weight receives no samples. When both weights
        are positive and there is room for it, each component gets at least
        one sample.
        """
        total = self.total_samples
        if self.weight1 == 0:
            return 0, total
        if self.weight2 == 0:
            return total, 0
        if total == 1:
            # Only one slot available: give it to the heavier component.
            return (1, 0) if self.weight1 >= self.weight2 else (0, 1)
        first = min(max(1, round(total * self.weight1)), total - 1)
        return first, total - first

    def generate_time_samples(self):
        """Generate the mixture's time samples.

        Returns:
            list[datetime]: ``total_samples`` datetimes, none later than
            ``base_time``, sorted from earliest to most recent.
        """
        # Number of samples drawn from each component, proportional to weight.
        samples1, samples2 = self._split_sample_counts()

        # Hour offsets drawn from the two normal distributions.
        hours_offset1 = np.random.normal(loc=self.n_hours1, scale=self.std_hours1, size=samples1)
        hours_offset2 = np.random.normal(loc=self.n_hours2, scale=self.std_hours2, size=samples2)

        # Merge the offsets of both components.
        hours_offset = np.concatenate([hours_offset1, hours_offset2])

        # Convert offsets to datetimes; the absolute value folds negative
        # draws back so that every point lies in the past.
        timestamps = [self.base_time - timedelta(hours=abs(offset)) for offset in hours_offset]

        # Sort chronologically (earliest first).
        return sorted(timestamps)

    def get_timestamp_array(self):
        """Return the generated samples as integer Unix timestamps (seconds)."""
        timestamps = self.generate_time_samples()
        return [int(t.timestamp()) for t in timestamps]
|
|
||||||
|
|
||||||
|
|
||||||
# def print_time_samples(timestamps, show_distribution=True):
|
|
||||||
# """打印时间样本和分布信息"""
|
|
||||||
# print(f"\n生成的{len(timestamps)}个时间点分布:")
|
|
||||||
# print("序号".ljust(5), "时间戳".ljust(25), "距现在(小时)")
|
|
||||||
# print("-" * 50)
|
|
||||||
|
|
||||||
# now = datetime.now()
|
|
||||||
# time_diffs = []
|
|
||||||
|
|
||||||
# for i, timestamp in enumerate(timestamps, 1):
|
|
||||||
# hours_diff = (now - timestamp).total_seconds() / 3600
|
|
||||||
# time_diffs.append(hours_diff)
|
|
||||||
# print(f"{str(i).ljust(5)} {timestamp.strftime('%Y-%m-%d %H:%M:%S').ljust(25)} {hours_diff:.2f}")
|
|
||||||
|
|
||||||
# # 打印统计信息
|
|
||||||
# print("\n统计信息:")
|
|
||||||
# print(f"平均时间偏移:{np.mean(time_diffs):.2f}小时")
|
|
||||||
# print(f"标准差:{np.std(time_diffs):.2f}小时")
|
|
||||||
# print(f"最早时间:{min(timestamps).strftime('%Y-%m-%d %H:%M:%S')} ({max(time_diffs):.2f}小时前)")
|
|
||||||
# print(f"最近时间:{max(timestamps).strftime('%Y-%m-%d %H:%M:%S')} ({min(time_diffs):.2f}小时前)")
|
|
||||||
|
|
||||||
# if show_distribution:
|
|
||||||
# # 计算时间分布的直方图
|
|
||||||
# hist, bins = np.histogram(time_diffs, bins=40)
|
|
||||||
# print("\n时间分布(每个*代表一个时间点):")
|
|
||||||
# for i in range(len(hist)):
|
|
||||||
# if hist[i] > 0:
|
|
||||||
# print(f"{bins[i]:6.1f}-{bins[i + 1]:6.1f}小时: {'*' * int(hist[i])}")
|
|
||||||
|
|
||||||
|
|
||||||
# # 使用示例
|
|
||||||
# if __name__ == "__main__":
|
|
||||||
# # 创建一个双峰分布的记忆调度器
|
|
||||||
# scheduler = MemoryBuildScheduler(
|
|
||||||
# n_hours1=12, # 第一个分布均值(12小时前)
|
|
||||||
# std_hours1=8, # 第一个分布标准差
|
|
||||||
# weight1=0.7, # 第一个分布权重 70%
|
|
||||||
# n_hours2=36, # 第二个分布均值(36小时前)
|
|
||||||
# std_hours2=24, # 第二个分布标准差
|
|
||||||
# weight2=0.3, # 第二个分布权重 30%
|
|
||||||
# total_samples=50, # 总共生成50个时间点
|
|
||||||
# )
|
|
||||||
|
|
||||||
# # 生成时间分布
|
|
||||||
# timestamps = scheduler.generate_time_samples()
|
|
||||||
|
|
||||||
# # 打印结果,包含分布可视化
|
|
||||||
# print_time_samples(timestamps, show_distribution=True)
|
|
||||||
|
|
||||||
# # 打印时间戳数组
|
|
||||||
# timestamp_array = scheduler.get_timestamp_array()
|
|
||||||
# print("\n时间戳数组(Unix时间戳):")
|
|
||||||
# print("[", end="")
|
|
||||||
# for i, ts in enumerate(timestamp_array):
|
|
||||||
# if i > 0:
|
|
||||||
# print(", ", end="")
|
|
||||||
# print(ts, end="")
|
|
||||||
# print("]")
|
|
||||||
@@ -16,6 +16,7 @@ from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
|||||||
from src.plugin_system.core import component_registry, events_manager, global_announcement_manager
|
from src.plugin_system.core import component_registry, events_manager, global_announcement_manager
|
||||||
from src.plugin_system.base import BaseCommand, EventType
|
from src.plugin_system.base import BaseCommand, EventType
|
||||||
from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor
|
from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor
|
||||||
|
from src.person_info.person_info import Person
|
||||||
|
|
||||||
# 定义日志配置
|
# 定义日志配置
|
||||||
|
|
||||||
@@ -146,7 +147,10 @@ class ChatBot:
|
|||||||
|
|
||||||
async def hanle_notice_message(self, message: MessageRecv):
|
async def hanle_notice_message(self, message: MessageRecv):
|
||||||
if message.message_info.message_id == "notice":
|
if message.message_info.message_id == "notice":
|
||||||
logger.info("收到notice消息,暂时不支持处理")
|
message.is_notify = True
|
||||||
|
logger.info("notice消息")
|
||||||
|
# print(message)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
async def do_s4u(self, message_data: Dict[str, Any]):
|
async def do_s4u(self, message_data: Dict[str, Any]):
|
||||||
@@ -165,6 +169,8 @@ class ChatBot:
|
|||||||
|
|
||||||
# 处理消息内容
|
# 处理消息内容
|
||||||
await message.process()
|
await message.process()
|
||||||
|
|
||||||
|
_ = Person.register_person(platform=message.message_info.platform, user_id=message.message_info.user_info.user_id,nickname=user_info.user_nickname) # type: ignore
|
||||||
|
|
||||||
await self.s4u_message_processor.process_message(message)
|
await self.s4u_message_processor.process_message(message)
|
||||||
|
|
||||||
@@ -207,7 +213,8 @@ class ChatBot:
|
|||||||
message = MessageRecv(message_data)
|
message = MessageRecv(message_data)
|
||||||
|
|
||||||
if await self.hanle_notice_message(message):
|
if await self.hanle_notice_message(message):
|
||||||
return
|
# return
|
||||||
|
pass
|
||||||
|
|
||||||
group_info = message.message_info.group_info
|
group_info = message.message_info.group_info
|
||||||
user_info = message.message_info.user_info
|
user_info = message.message_info.user_info
|
||||||
|
|||||||
@@ -217,7 +217,8 @@ class ChatManager:
|
|||||||
# 更新用户信息和群组信息
|
# 更新用户信息和群组信息
|
||||||
stream.update_active_time()
|
stream.update_active_time()
|
||||||
stream = copy.deepcopy(stream) # 返回副本以避免外部修改影响缓存
|
stream = copy.deepcopy(stream) # 返回副本以避免外部修改影响缓存
|
||||||
stream.user_info = user_info
|
if user_info and user_info.platform and user_info.user_id:
|
||||||
|
stream.user_info = user_info
|
||||||
if group_info:
|
if group_info:
|
||||||
stream.group_info = group_info
|
stream.group_info = group_info
|
||||||
from .message import MessageRecv # 延迟导入,避免循环引用
|
from .message import MessageRecv # 延迟导入,避免循环引用
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import urllib3
|
|||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
from typing import Optional, Any
|
from typing import Optional, Any, List
|
||||||
from maim_message import Seg, UserInfo, BaseMessageInfo, MessageBase
|
from maim_message import Seg, UserInfo, BaseMessageInfo, MessageBase
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
@@ -29,7 +29,6 @@ class Message(MessageBase):
|
|||||||
chat_stream: "ChatStream" = None # type: ignore
|
chat_stream: "ChatStream" = None # type: ignore
|
||||||
reply: Optional["Message"] = None
|
reply: Optional["Message"] = None
|
||||||
processed_plain_text: str = ""
|
processed_plain_text: str = ""
|
||||||
memorized_times: int = 0
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -109,12 +108,16 @@ class MessageRecv(Message):
|
|||||||
self.has_picid = False
|
self.has_picid = False
|
||||||
self.is_voice = False
|
self.is_voice = False
|
||||||
self.is_mentioned = None
|
self.is_mentioned = None
|
||||||
|
self.is_notify = False
|
||||||
|
|
||||||
self.is_command = False
|
self.is_command = False
|
||||||
|
|
||||||
self.priority_mode = "interest"
|
self.priority_mode = "interest"
|
||||||
self.priority_info = None
|
self.priority_info = None
|
||||||
self.interest_value: float = None # type: ignore
|
self.interest_value: float = None # type: ignore
|
||||||
|
|
||||||
|
self.key_words = []
|
||||||
|
self.key_words_lite = []
|
||||||
|
|
||||||
def update_chat_stream(self, chat_stream: "ChatStream"):
|
def update_chat_stream(self, chat_stream: "ChatStream"):
|
||||||
self.chat_stream = chat_stream
|
self.chat_stream = chat_stream
|
||||||
@@ -203,7 +206,7 @@ class MessageRecvS4U(MessageRecv):
|
|||||||
self.is_superchat = False
|
self.is_superchat = False
|
||||||
self.gift_info = None
|
self.gift_info = None
|
||||||
self.gift_name = None
|
self.gift_name = None
|
||||||
self.gift_count = None
|
self.gift_count: Optional[str] = None
|
||||||
self.superchat_info = None
|
self.superchat_info = None
|
||||||
self.superchat_price = None
|
self.superchat_price = None
|
||||||
self.superchat_message_text = None
|
self.superchat_message_text = None
|
||||||
@@ -369,7 +372,7 @@ class MessageProcessBase(Message):
|
|||||||
return "[图片,网卡了加载不出来]"
|
return "[图片,网卡了加载不出来]"
|
||||||
elif seg.type == "emoji":
|
elif seg.type == "emoji":
|
||||||
if isinstance(seg.data, str):
|
if isinstance(seg.data, str):
|
||||||
return await get_image_manager().get_emoji_description(seg.data)
|
return await get_image_manager().get_emoji_tag(seg.data)
|
||||||
return "[表情,网卡了加载不出来]"
|
return "[表情,网卡了加载不出来]"
|
||||||
elif seg.type == "voice":
|
elif seg.type == "voice":
|
||||||
if isinstance(seg.data, str):
|
if isinstance(seg.data, str):
|
||||||
@@ -399,34 +402,6 @@ class MessageProcessBase(Message):
|
|||||||
return f"[{timestamp}],{name} 说:{self.processed_plain_text}\n"
|
return f"[{timestamp}],{name} 说:{self.processed_plain_text}\n"
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class MessageThinking(MessageProcessBase):
|
|
||||||
"""思考状态的消息类"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
message_id: str,
|
|
||||||
chat_stream: "ChatStream",
|
|
||||||
bot_user_info: UserInfo,
|
|
||||||
reply: Optional["MessageRecv"] = None,
|
|
||||||
thinking_start_time: float = 0,
|
|
||||||
timestamp: Optional[float] = None,
|
|
||||||
):
|
|
||||||
# 调用父类初始化,传递时间戳
|
|
||||||
super().__init__(
|
|
||||||
message_id=message_id,
|
|
||||||
chat_stream=chat_stream,
|
|
||||||
bot_user_info=bot_user_info,
|
|
||||||
message_segment=None, # 思考状态不需要消息段
|
|
||||||
reply=reply,
|
|
||||||
thinking_start_time=thinking_start_time,
|
|
||||||
timestamp=timestamp,
|
|
||||||
)
|
|
||||||
|
|
||||||
# 思考状态特有属性
|
|
||||||
self.interrupt = False
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MessageSending(MessageProcessBase):
|
class MessageSending(MessageProcessBase):
|
||||||
"""发送状态的消息类"""
|
"""发送状态的消息类"""
|
||||||
@@ -444,7 +419,8 @@ class MessageSending(MessageProcessBase):
|
|||||||
is_emoji: bool = False,
|
is_emoji: bool = False,
|
||||||
thinking_start_time: float = 0,
|
thinking_start_time: float = 0,
|
||||||
apply_set_reply_logic: bool = False,
|
apply_set_reply_logic: bool = False,
|
||||||
reply_to: str = None, # type: ignore
|
reply_to: Optional[str] = None,
|
||||||
|
selected_expressions:List[int] = None,
|
||||||
):
|
):
|
||||||
# 调用父类初始化
|
# 调用父类初始化
|
||||||
super().__init__(
|
super().__init__(
|
||||||
@@ -469,6 +445,8 @@ class MessageSending(MessageProcessBase):
|
|||||||
self.display_message = display_message
|
self.display_message = display_message
|
||||||
|
|
||||||
self.interest_value = 0.0
|
self.interest_value = 0.0
|
||||||
|
|
||||||
|
self.selected_expressions = selected_expressions
|
||||||
|
|
||||||
def build_reply(self):
|
def build_reply(self):
|
||||||
"""设置回复消息"""
|
"""设置回复消息"""
|
||||||
@@ -487,26 +465,6 @@ class MessageSending(MessageProcessBase):
|
|||||||
if self.message_segment:
|
if self.message_segment:
|
||||||
self.processed_plain_text = await self._process_message_segments(self.message_segment)
|
self.processed_plain_text = await self._process_message_segments(self.message_segment)
|
||||||
|
|
||||||
# @classmethod
|
|
||||||
# def from_thinking(
|
|
||||||
# cls,
|
|
||||||
# thinking: MessageThinking,
|
|
||||||
# message_segment: Seg,
|
|
||||||
# is_head: bool = False,
|
|
||||||
# is_emoji: bool = False,
|
|
||||||
# ) -> "MessageSending":
|
|
||||||
# """从思考状态消息创建发送状态消息"""
|
|
||||||
# return cls(
|
|
||||||
# message_id=thinking.message_info.message_id,
|
|
||||||
# chat_stream=thinking.chat_stream,
|
|
||||||
# message_segment=message_segment,
|
|
||||||
# bot_user_info=thinking.message_info.user_info,
|
|
||||||
# reply=thinking.reply,
|
|
||||||
# is_head=is_head,
|
|
||||||
# is_emoji=is_emoji,
|
|
||||||
# sender_info=None,
|
|
||||||
# )
|
|
||||||
|
|
||||||
def to_dict(self):
|
def to_dict(self):
|
||||||
ret = super().to_dict()
|
ret = super().to_dict()
|
||||||
ret["message_info"]["user_info"] = self.chat_stream.user_info.to_dict()
|
ret["message_info"]["user_info"] = self.chat_stream.user_info.to_dict()
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import re
|
import re
|
||||||
|
import json
|
||||||
import traceback
|
import traceback
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
@@ -11,6 +12,23 @@ logger = get_logger("message_storage")
|
|||||||
|
|
||||||
|
|
||||||
class MessageStorage:
|
class MessageStorage:
|
||||||
|
@staticmethod
|
||||||
|
def _serialize_keywords(keywords) -> str:
|
||||||
|
"""将关键词列表序列化为JSON字符串"""
|
||||||
|
if isinstance(keywords, list):
|
||||||
|
return json.dumps(keywords, ensure_ascii=False)
|
||||||
|
return "[]"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _deserialize_keywords(keywords_str: str) -> list:
|
||||||
|
"""将JSON字符串反序列化为关键词列表"""
|
||||||
|
if not keywords_str:
|
||||||
|
return []
|
||||||
|
try:
|
||||||
|
return json.loads(keywords_str)
|
||||||
|
except (json.JSONDecodeError, TypeError):
|
||||||
|
return []
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def store_message(message: Union[MessageSending, MessageRecv], chat_stream: ChatStream) -> None:
|
async def store_message(message: Union[MessageSending, MessageRecv], chat_stream: ChatStream) -> None:
|
||||||
"""存储消息到数据库"""
|
"""存储消息到数据库"""
|
||||||
@@ -43,7 +61,11 @@ class MessageStorage:
|
|||||||
priority_info = {}
|
priority_info = {}
|
||||||
is_emoji = False
|
is_emoji = False
|
||||||
is_picid = False
|
is_picid = False
|
||||||
|
is_notify = False
|
||||||
is_command = False
|
is_command = False
|
||||||
|
key_words = ""
|
||||||
|
key_words_lite = ""
|
||||||
|
selected_expressions = message.selected_expressions
|
||||||
else:
|
else:
|
||||||
filtered_display_message = ""
|
filtered_display_message = ""
|
||||||
interest_value = message.interest_value
|
interest_value = message.interest_value
|
||||||
@@ -53,8 +75,13 @@ class MessageStorage:
|
|||||||
priority_info = message.priority_info
|
priority_info = message.priority_info
|
||||||
is_emoji = message.is_emoji
|
is_emoji = message.is_emoji
|
||||||
is_picid = message.is_picid
|
is_picid = message.is_picid
|
||||||
|
is_notify = message.is_notify
|
||||||
is_command = message.is_command
|
is_command = message.is_command
|
||||||
|
# 序列化关键词列表为JSON字符串
|
||||||
|
key_words = MessageStorage._serialize_keywords(message.key_words)
|
||||||
|
key_words_lite = MessageStorage._serialize_keywords(message.key_words_lite)
|
||||||
|
selected_expressions = ""
|
||||||
|
|
||||||
chat_info_dict = chat_stream.to_dict()
|
chat_info_dict = chat_stream.to_dict()
|
||||||
user_info_dict = message.message_info.user_info.to_dict() # type: ignore
|
user_info_dict = message.message_info.user_info.to_dict() # type: ignore
|
||||||
|
|
||||||
@@ -92,13 +119,16 @@ class MessageStorage:
|
|||||||
# Text content
|
# Text content
|
||||||
processed_plain_text=filtered_processed_plain_text,
|
processed_plain_text=filtered_processed_plain_text,
|
||||||
display_message=filtered_display_message,
|
display_message=filtered_display_message,
|
||||||
memorized_times=message.memorized_times,
|
|
||||||
interest_value=interest_value,
|
interest_value=interest_value,
|
||||||
priority_mode=priority_mode,
|
priority_mode=priority_mode,
|
||||||
priority_info=priority_info,
|
priority_info=priority_info,
|
||||||
is_emoji=is_emoji,
|
is_emoji=is_emoji,
|
||||||
is_picid=is_picid,
|
is_picid=is_picid,
|
||||||
|
is_notify=is_notify,
|
||||||
is_command=is_command,
|
is_command=is_command,
|
||||||
|
key_words=key_words,
|
||||||
|
key_words_lite=key_words_lite,
|
||||||
|
selected_expressions=selected_expressions,
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("存储消息失败")
|
logger.exception("存储消息失败")
|
||||||
|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
from typing import Dict, Optional, Type
|
from typing import Dict, Optional, Type
|
||||||
from src.plugin_system.base.base_action import BaseAction
|
|
||||||
from src.chat.message_receive.chat_stream import ChatStream
|
from src.chat.message_receive.chat_stream import ChatStream
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.plugin_system.core.component_registry import component_registry
|
from src.plugin_system.core.component_registry import component_registry
|
||||||
from src.plugin_system.base.component_types import ComponentType, ActionInfo
|
from src.plugin_system.base.component_types import ComponentType, ActionInfo
|
||||||
|
from src.plugin_system.base.base_action import BaseAction
|
||||||
|
|
||||||
logger = get_logger("action_manager")
|
logger = get_logger("action_manager")
|
||||||
|
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import time
|
|||||||
from typing import List, Any, Dict, TYPE_CHECKING, Tuple
|
from typing import List, Any, Dict, TYPE_CHECKING, Tuple
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config, model_config
|
||||||
from src.llm_models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
from src.chat.message_receive.chat_stream import get_chat_manager, ChatMessageContext
|
from src.chat.message_receive.chat_stream import get_chat_manager, ChatMessageContext
|
||||||
from src.chat.planner_actions.action_manager import ActionManager
|
from src.chat.planner_actions.action_manager import ActionManager
|
||||||
@@ -36,10 +36,7 @@ class ActionModifier:
|
|||||||
self.action_manager = action_manager
|
self.action_manager = action_manager
|
||||||
|
|
||||||
# 用于LLM判定的小模型
|
# 用于LLM判定的小模型
|
||||||
self.llm_judge = LLMRequest(
|
self.llm_judge = LLMRequest(model_set=model_config.model_task_config.utils_small, request_type="action.judge")
|
||||||
model=global_config.model.utils_small,
|
|
||||||
request_type="action.judge",
|
|
||||||
)
|
|
||||||
|
|
||||||
# 缓存相关属性
|
# 缓存相关属性
|
||||||
self._llm_judge_cache = {} # 缓存LLM判定结果
|
self._llm_judge_cache = {} # 缓存LLM判定结果
|
||||||
@@ -130,8 +127,10 @@ class ActionModifier:
|
|||||||
if all_removals:
|
if all_removals:
|
||||||
removals_summary = " | ".join([f"{name}({reason})" for name, reason in all_removals])
|
removals_summary = " | ".join([f"{name}({reason})" for name, reason in all_removals])
|
||||||
|
|
||||||
|
available_actions = list(self.action_manager.get_using_actions().keys())
|
||||||
|
available_actions_text = "、".join(available_actions) if available_actions else "无"
|
||||||
logger.info(
|
logger.info(
|
||||||
f"{self.log_prefix} 动作修改流程结束,最终可用动作: {list(self.action_manager.get_using_actions().keys())}||移除记录: {removals_summary}"
|
f"{self.log_prefix} 当前可用动作: {available_actions_text}||移除: {removals_summary}"
|
||||||
)
|
)
|
||||||
|
|
||||||
def _check_action_associated_types(self, all_actions: Dict[str, ActionInfo], chat_context: ChatMessageContext):
|
def _check_action_associated_types(self, all_actions: Dict[str, ActionInfo], chat_context: ChatMessageContext):
|
||||||
@@ -438,4 +437,4 @@ class ActionModifier:
|
|||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
logger.debug(f"{self.log_prefix}动作 {action_name} 未匹配到任何关键词: {activation_keywords}")
|
logger.debug(f"{self.log_prefix}动作 {action_name} 未匹配到任何关键词: {activation_keywords}")
|
||||||
return False
|
return False
|
||||||
|
|||||||
@@ -1,13 +1,13 @@
|
|||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
from typing import Dict, Any, Optional, Tuple
|
from typing import Dict, Any, Optional, Tuple, List
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from json_repair import repair_json
|
from json_repair import repair_json
|
||||||
|
|
||||||
from src.llm_models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config, model_config
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||||
from src.chat.utils.chat_message_builder import (
|
from src.chat.utils.chat_message_builder import (
|
||||||
@@ -36,17 +36,28 @@ def init_prompt():
|
|||||||
{chat_context_description},以下是具体的聊天内容
|
{chat_context_description},以下是具体的聊天内容
|
||||||
{chat_content_block}
|
{chat_content_block}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
{moderation_prompt}
|
{moderation_prompt}
|
||||||
|
|
||||||
现在请你根据{by_what}选择合适的action和触发action的消息:
|
现在请你根据聊天内容和用户的最新消息选择合适的action和触发action的消息:
|
||||||
{actions_before_now_block}
|
{actions_before_now_block}
|
||||||
|
|
||||||
{no_action_block}
|
{no_action_block}
|
||||||
|
|
||||||
|
动作:reply
|
||||||
|
动作描述:参与聊天回复,发送文本进行表达
|
||||||
|
- 你想要闲聊或者随便附和
|
||||||
|
- 有人提到了你,但是你还没有回应
|
||||||
|
- {mentioned_bonus}
|
||||||
|
- 如果你刚刚进行了回复,不要对同一个话题重复回应
|
||||||
|
{{
|
||||||
|
"action": "reply",
|
||||||
|
"target_message_id":"想要回复的消息id",
|
||||||
|
"reason":"回复的原因"
|
||||||
|
}}
|
||||||
|
|
||||||
{action_options_text}
|
{action_options_text}
|
||||||
|
|
||||||
你必须从上面列出的可用action中选择一个,并说明触发action的消息id(不是消息原文)和选择该action的原因。
|
你必须从上面列出的可用action中选择一个,并说明触发action的消息id(不是消息原文)和选择该action的原因。消息id格式:m+数字
|
||||||
|
|
||||||
请根据动作示例,以严格的 JSON 格式输出,且仅包含 JSON 内容:
|
请根据动作示例,以严格的 JSON 格式输出,且仅包含 JSON 内容:
|
||||||
""",
|
""",
|
||||||
@@ -59,7 +70,8 @@ def init_prompt():
|
|||||||
动作描述:{action_description}
|
动作描述:{action_description}
|
||||||
{action_require}
|
{action_require}
|
||||||
{{
|
{{
|
||||||
"action": "{action_name}",{action_parameters}{target_prompt}
|
"action": "{action_name}",{action_parameters},
|
||||||
|
"target_message_id":"触发action的消息id",
|
||||||
"reason":"触发action的原因"
|
"reason":"触发action的原因"
|
||||||
}}
|
}}
|
||||||
""",
|
""",
|
||||||
@@ -74,14 +86,15 @@ class ActionPlanner:
|
|||||||
self.action_manager = action_manager
|
self.action_manager = action_manager
|
||||||
# LLM规划器配置
|
# LLM规划器配置
|
||||||
self.planner_llm = LLMRequest(
|
self.planner_llm = LLMRequest(
|
||||||
model=global_config.model.planner,
|
model_set=model_config.model_task_config.planner, request_type="planner"
|
||||||
request_type="planner", # 用于动作规划
|
) # 用于动作规划
|
||||||
)
|
|
||||||
|
|
||||||
self.last_obs_time_mark = 0.0
|
self.last_obs_time_mark = 0.0
|
||||||
|
# 添加重试计数器
|
||||||
|
self.plan_retry_count = 0
|
||||||
|
self.max_plan_retries = 3
|
||||||
|
|
||||||
def find_message_by_id(self, message_id: str, message_id_list: list) -> Optional[Dict[str, Any]]:
|
def find_message_by_id(self, message_id: str, message_id_list: list) -> Optional[Dict[str, Any]]:
|
||||||
# sourcery skip: use-next
|
|
||||||
"""
|
"""
|
||||||
根据message_id从message_id_list中查找对应的原始消息
|
根据message_id从message_id_list中查找对应的原始消息
|
||||||
|
|
||||||
@@ -97,37 +110,41 @@ class ActionPlanner:
|
|||||||
return item.get("message")
|
return item.get("message")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def get_latest_message(self, message_id_list: list) -> Optional[Dict[str, Any]]:
    """
    Return the most recent message from the message list.

    Args:
        message_id_list: Message entries in the form
            [{'id': str, 'message': dict}, ...]

    Returns:
        The ``"message"`` value of the last entry, or None when the list is
        empty (or the last entry has no ``"message"`` key).
    """
    # The list is assumed to be in chronological order, so the last entry is the newest.
    return message_id_list[-1].get("message") if message_id_list else None
|
||||||
|
|
||||||
async def plan(
|
async def plan(
|
||||||
self, mode: ChatMode = ChatMode.FOCUS
|
self,
|
||||||
) -> Tuple[Dict[str, Dict[str, Any] | str], Optional[Dict[str, Any]]]:
|
mode: ChatMode = ChatMode.FOCUS,
|
||||||
|
loop_start_time:float = 0.0,
|
||||||
|
available_actions: Optional[Dict[str, ActionInfo]] = None,
|
||||||
|
) -> Tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]:
|
||||||
"""
|
"""
|
||||||
规划器 (Planner): 使用LLM根据上下文决定做出什么动作。
|
规划器 (Planner): 使用LLM根据上下文决定做出什么动作。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
action = "no_reply" # 默认动作
|
action = "no_action" # 默认动作
|
||||||
reasoning = "规划器初始化默认"
|
reasoning = "规划器初始化默认"
|
||||||
action_data = {}
|
action_data = {}
|
||||||
current_available_actions: Dict[str, ActionInfo] = {}
|
current_available_actions: Dict[str, ActionInfo] = {}
|
||||||
target_message: Optional[Dict[str, Any]] = None # 初始化target_message变量
|
target_message: Optional[Dict[str, Any]] = None # 初始化target_message变量
|
||||||
prompt: str = ""
|
prompt: str = ""
|
||||||
|
message_id_list: list = []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
is_group_chat = True
|
is_group_chat, chat_target_info, current_available_actions = self.get_necessary_info()
|
||||||
is_group_chat, chat_target_info = get_chat_type_and_target_info(self.chat_id)
|
|
||||||
logger.debug(f"{self.log_prefix}获取到聊天信息 - 群聊: {is_group_chat}, 目标信息: {chat_target_info}")
|
|
||||||
|
|
||||||
current_available_actions_dict = self.action_manager.get_using_actions()
|
|
||||||
|
|
||||||
# 获取完整的动作信息
|
|
||||||
all_registered_actions: Dict[str, ActionInfo] = component_registry.get_components_by_type( # type: ignore
|
|
||||||
ComponentType.ACTION
|
|
||||||
)
|
|
||||||
current_available_actions = {}
|
|
||||||
for action_name in current_available_actions_dict:
|
|
||||||
if action_name in all_registered_actions:
|
|
||||||
current_available_actions[action_name] = all_registered_actions[action_name]
|
|
||||||
else:
|
|
||||||
logger.warning(f"{self.log_prefix}使用中的动作 {action_name} 未在已注册动作中找到")
|
|
||||||
|
|
||||||
# --- 构建提示词 (调用修改后的 PromptBuilder 方法) ---
|
# --- 构建提示词 (调用修改后的 PromptBuilder 方法) ---
|
||||||
prompt, message_id_list = await self.build_planner_prompt(
|
prompt, message_id_list = await self.build_planner_prompt(
|
||||||
@@ -135,12 +152,13 @@ class ActionPlanner:
|
|||||||
chat_target_info=chat_target_info, # <-- 传递获取到的聊天目标信息
|
chat_target_info=chat_target_info, # <-- 传递获取到的聊天目标信息
|
||||||
current_available_actions=current_available_actions, # <-- Pass determined actions
|
current_available_actions=current_available_actions, # <-- Pass determined actions
|
||||||
mode=mode,
|
mode=mode,
|
||||||
|
refresh_time=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# --- 调用 LLM (普通文本生成) ---
|
# --- 调用 LLM (普通文本生成) ---
|
||||||
llm_content = None
|
llm_content = None
|
||||||
try:
|
try:
|
||||||
llm_content, (reasoning_content, _) = await self.planner_llm.generate_response_async(prompt=prompt)
|
llm_content, (reasoning_content, _, _) = await self.planner_llm.generate_response_async(prompt=prompt)
|
||||||
|
|
||||||
if global_config.debug.show_prompt:
|
if global_config.debug.show_prompt:
|
||||||
logger.info(f"{self.log_prefix}规划器原始提示词: {prompt}")
|
logger.info(f"{self.log_prefix}规划器原始提示词: {prompt}")
|
||||||
@@ -156,7 +174,7 @@ class ActionPlanner:
|
|||||||
except Exception as req_e:
|
except Exception as req_e:
|
||||||
logger.error(f"{self.log_prefix}LLM 请求执行失败: {req_e}")
|
logger.error(f"{self.log_prefix}LLM 请求执行失败: {req_e}")
|
||||||
reasoning = f"LLM 请求失败,模型出现问题: {req_e}"
|
reasoning = f"LLM 请求失败,模型出现问题: {req_e}"
|
||||||
action = "no_reply"
|
action = "no_action"
|
||||||
|
|
||||||
if llm_content:
|
if llm_content:
|
||||||
try:
|
try:
|
||||||
@@ -173,68 +191,94 @@ class ActionPlanner:
|
|||||||
logger.error(f"{self.log_prefix}解析后的JSON不是字典类型: {type(parsed_json)}")
|
logger.error(f"{self.log_prefix}解析后的JSON不是字典类型: {type(parsed_json)}")
|
||||||
parsed_json = {}
|
parsed_json = {}
|
||||||
|
|
||||||
action = parsed_json.get("action", "no_reply")
|
action = parsed_json.get("action", "no_action")
|
||||||
reasoning = parsed_json.get("reasoning", "未提供原因")
|
reasoning = parsed_json.get("reason", "未提供原因")
|
||||||
|
|
||||||
# 将所有其他属性添加到action_data
|
# 将所有其他属性添加到action_data
|
||||||
for key, value in parsed_json.items():
|
for key, value in parsed_json.items():
|
||||||
if key not in ["action", "reasoning"]:
|
if key not in ["action", "reasoning"]:
|
||||||
action_data[key] = value
|
action_data[key] = value
|
||||||
|
|
||||||
# 在FOCUS模式下,非no_reply动作需要target_message_id
|
# 非no_action动作需要target_message_id
|
||||||
if mode == ChatMode.FOCUS and action != "no_reply":
|
if action != "no_action":
|
||||||
if target_message_id := parsed_json.get("target_message_id"):
|
if target_message_id := parsed_json.get("target_message_id"):
|
||||||
# 根据target_message_id查找原始消息
|
# 根据target_message_id查找原始消息
|
||||||
target_message = self.find_message_by_id(target_message_id, message_id_list)
|
target_message = self.find_message_by_id(target_message_id, message_id_list)
|
||||||
|
# 如果获取的target_message为None,输出warning并重新plan
|
||||||
|
if target_message is None:
|
||||||
|
self.plan_retry_count += 1
|
||||||
|
logger.warning(f"{self.log_prefix}无法找到target_message_id '{target_message_id}' 对应的消息,重试次数: {self.plan_retry_count}/{self.max_plan_retries}")
|
||||||
|
|
||||||
|
# 如果连续三次plan均为None,输出error并选取最新消息
|
||||||
|
if self.plan_retry_count >= self.max_plan_retries:
|
||||||
|
logger.error(f"{self.log_prefix}连续{self.max_plan_retries}次plan获取target_message失败,选择最新消息作为target_message")
|
||||||
|
target_message = self.get_latest_message(message_id_list)
|
||||||
|
self.plan_retry_count = 0 # 重置计数器
|
||||||
|
else:
|
||||||
|
# 递归重新plan
|
||||||
|
return await self.plan(mode, loop_start_time, available_actions)
|
||||||
|
else:
|
||||||
|
# 成功获取到target_message,重置计数器
|
||||||
|
self.plan_retry_count = 0
|
||||||
else:
|
else:
|
||||||
logger.warning(f"{self.log_prefix}FOCUS模式下动作'{action}'缺少target_message_id")
|
logger.warning(f"{self.log_prefix}动作'{action}'缺少target_message_id")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if action == "no_action":
|
if action != "no_action" and action != "reply" and action not in current_available_actions:
|
||||||
reasoning = "normal决定不使用额外动作"
|
|
||||||
elif action != "no_reply" and action != "reply" and action not in current_available_actions:
|
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"{self.log_prefix}LLM 返回了当前不可用或无效的动作: '{action}' (可用: {list(current_available_actions.keys())}),将强制使用 'no_reply'"
|
f"{self.log_prefix}LLM 返回了当前不可用或无效的动作: '{action}' (可用: {list(current_available_actions.keys())}),将强制使用 'no_action'"
|
||||||
)
|
)
|
||||||
reasoning = f"LLM 返回了当前不可用的动作 '{action}' (可用: {list(current_available_actions.keys())})。原始理由: {reasoning}"
|
reasoning = f"LLM 返回了当前不可用的动作 '{action}' (可用: {list(current_available_actions.keys())})。原始理由: {reasoning}"
|
||||||
action = "no_reply"
|
action = "no_action"
|
||||||
|
|
||||||
except Exception as json_e:
|
except Exception as json_e:
|
||||||
logger.warning(f"{self.log_prefix}解析LLM响应JSON失败 {json_e}. LLM原始输出: '{llm_content}'")
|
logger.warning(f"{self.log_prefix}解析LLM响应JSON失败 {json_e}. LLM原始输出: '{llm_content}'")
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
reasoning = f"解析LLM响应JSON失败: {json_e}. 将使用默认动作 'no_reply'."
|
reasoning = f"解析LLM响应JSON失败: {json_e}. 将使用默认动作 'no_action'."
|
||||||
action = "no_reply"
|
action = "no_action"
|
||||||
|
|
||||||
except Exception as outer_e:
|
except Exception as outer_e:
|
||||||
logger.error(f"{self.log_prefix}Planner 处理过程中发生意外错误,规划失败,将执行 no_reply: {outer_e}")
|
logger.error(f"{self.log_prefix}Planner 处理过程中发生意外错误,规划失败,将执行 no_action: {outer_e}")
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
action = "no_reply"
|
action = "no_action"
|
||||||
reasoning = f"Planner 内部处理错误: {outer_e}"
|
reasoning = f"Planner 内部处理错误: {outer_e}"
|
||||||
|
|
||||||
is_parallel = False
|
is_parallel = False
|
||||||
if mode == ChatMode.NORMAL and action in current_available_actions:
|
if mode == ChatMode.NORMAL and action in current_available_actions:
|
||||||
is_parallel = current_available_actions[action].parallel_action
|
is_parallel = current_available_actions[action].parallel_action
|
||||||
|
|
||||||
action_result = {
|
|
||||||
|
action_data["loop_start_time"] = loop_start_time
|
||||||
|
|
||||||
|
actions = []
|
||||||
|
|
||||||
|
# 1. 添加Planner取得的动作
|
||||||
|
actions.append({
|
||||||
"action_type": action,
|
"action_type": action,
|
||||||
"action_data": action_data,
|
|
||||||
"reasoning": reasoning,
|
"reasoning": reasoning,
|
||||||
"timestamp": time.time(),
|
"action_data": action_data,
|
||||||
"is_parallel": is_parallel,
|
"action_message": target_message,
|
||||||
}
|
"available_actions": available_actions # 添加这个字段
|
||||||
|
})
|
||||||
return (
|
|
||||||
{
|
if action != "reply" and is_parallel:
|
||||||
"action_result": action_result,
|
actions.append({
|
||||||
"action_prompt": prompt,
|
"action_type": "reply",
|
||||||
},
|
"action_message": target_message,
|
||||||
target_message,
|
"available_actions": available_actions
|
||||||
)
|
})
|
||||||
|
|
||||||
|
return actions,target_message
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def build_planner_prompt(
|
async def build_planner_prompt(
|
||||||
self,
|
self,
|
||||||
is_group_chat: bool, # Now passed as argument
|
is_group_chat: bool, # Now passed as argument
|
||||||
chat_target_info: Optional[dict], # Now passed as argument
|
chat_target_info: Optional[dict], # Now passed as argument
|
||||||
current_available_actions: Dict[str, ActionInfo],
|
current_available_actions: Dict[str, ActionInfo],
|
||||||
|
refresh_time :bool = False,
|
||||||
mode: ChatMode = ChatMode.FOCUS,
|
mode: ChatMode = ChatMode.FOCUS,
|
||||||
) -> tuple[str, list]: # sourcery skip: use-join
|
) -> tuple[str, list]: # sourcery skip: use-join
|
||||||
"""构建 Planner LLM 的提示词 (获取模板并填充数据)"""
|
"""构建 Planner LLM 的提示词 (获取模板并填充数据)"""
|
||||||
@@ -265,43 +309,36 @@ class ActionPlanner:
|
|||||||
)
|
)
|
||||||
|
|
||||||
actions_before_now_block = f"你刚刚选择并执行过的action是:\n{actions_before_now_block}"
|
actions_before_now_block = f"你刚刚选择并执行过的action是:\n{actions_before_now_block}"
|
||||||
|
if refresh_time:
|
||||||
self.last_obs_time_mark = time.time()
|
self.last_obs_time_mark = time.time()
|
||||||
|
|
||||||
|
mentioned_bonus = ""
|
||||||
|
if global_config.chat.mentioned_bot_inevitable_reply:
|
||||||
|
mentioned_bonus = "\n- 有人提到你"
|
||||||
|
if global_config.chat.at_bot_inevitable_reply:
|
||||||
|
mentioned_bonus = "\n- 有人提到你,或者at你"
|
||||||
|
|
||||||
|
|
||||||
if mode == ChatMode.FOCUS:
|
if mode == ChatMode.FOCUS:
|
||||||
mentioned_bonus = ""
|
no_action_block = """
|
||||||
if global_config.chat.mentioned_bot_inevitable_reply:
|
动作:no_action
|
||||||
mentioned_bonus = "\n- 有人提到你"
|
动作描述:不进行动作,等待合适的时机
|
||||||
if global_config.chat.at_bot_inevitable_reply:
|
- 当你刚刚发送了消息,没有人回复时,选择no_action
|
||||||
mentioned_bonus = "\n- 有人提到你,或者at你"
|
- 如果有别的动作(非回复)满足条件,可以不用no_action
|
||||||
|
- 当你一次发送了太多消息,为了避免打扰聊天节奏,选择no_action
|
||||||
by_what = "聊天内容"
|
{
|
||||||
target_prompt = '\n "target_message_id":"触发action的消息id"'
|
"action": "no_action",
|
||||||
no_action_block = f"""重要说明:
|
"reason":"不动作的原因"
|
||||||
- 'no_reply' 表示只进行不进行回复,等待合适的回复时机
|
}
|
||||||
- 当你刚刚发送了消息,没有人回复时,选择no_reply
|
|
||||||
- 当你一次发送了太多消息,为了避免打扰聊天节奏,选择no_reply
|
|
||||||
|
|
||||||
动作:reply
|
|
||||||
动作描述:参与聊天回复,发送文本进行表达
|
|
||||||
- 你想要闲聊或者随便附和{mentioned_bonus}
|
|
||||||
- 如果你刚刚进行了回复,不要对同一个话题重复回应
|
|
||||||
{{
|
|
||||||
"action": "reply",
|
|
||||||
"target_message_id":"触发action的消息id",
|
|
||||||
"reason":"回复的原因"
|
|
||||||
}}
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
else:
|
else:
|
||||||
by_what = "聊天内容和用户的最新消息"
|
|
||||||
target_prompt = ""
|
|
||||||
no_action_block = """重要说明:
|
no_action_block = """重要说明:
|
||||||
- 'reply' 表示只进行普通聊天回复,不执行任何额外动作
|
- 'reply' 表示只进行普通聊天回复,不执行任何额外动作
|
||||||
- 其他action表示在普通回复的基础上,执行相应的额外动作"""
|
- 其他action表示在普通回复的基础上,执行相应的额外动作
|
||||||
|
"""
|
||||||
|
|
||||||
chat_context_description = "你现在正在一个群聊中"
|
chat_context_description = "你现在正在一个群聊中"
|
||||||
chat_target_name = None # Only relevant for private
|
chat_target_name = None
|
||||||
if not is_group_chat and chat_target_info:
|
if not is_group_chat and chat_target_info:
|
||||||
chat_target_name = (
|
chat_target_name = (
|
||||||
chat_target_info.get("person_name") or chat_target_info.get("user_nickname") or "对方"
|
chat_target_info.get("person_name") or chat_target_info.get("user_nickname") or "对方"
|
||||||
@@ -330,7 +367,6 @@ class ActionPlanner:
|
|||||||
action_description=using_actions_info.description,
|
action_description=using_actions_info.description,
|
||||||
action_parameters=param_text,
|
action_parameters=param_text,
|
||||||
action_require=require_text,
|
action_require=require_text,
|
||||||
target_prompt=target_prompt,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
action_options_block += using_action_prompt
|
action_options_block += using_action_prompt
|
||||||
@@ -350,11 +386,11 @@ class ActionPlanner:
|
|||||||
planner_prompt_template = await global_prompt_manager.get_prompt_async("planner_prompt")
|
planner_prompt_template = await global_prompt_manager.get_prompt_async("planner_prompt")
|
||||||
prompt = planner_prompt_template.format(
|
prompt = planner_prompt_template.format(
|
||||||
time_block=time_block,
|
time_block=time_block,
|
||||||
by_what=by_what,
|
|
||||||
chat_context_description=chat_context_description,
|
chat_context_description=chat_context_description,
|
||||||
chat_content_block=chat_content_block,
|
chat_content_block=chat_content_block,
|
||||||
actions_before_now_block=actions_before_now_block,
|
actions_before_now_block=actions_before_now_block,
|
||||||
no_action_block=no_action_block,
|
no_action_block=no_action_block,
|
||||||
|
mentioned_bonus=mentioned_bonus,
|
||||||
action_options_text=action_options_block,
|
action_options_text=action_options_block,
|
||||||
moderation_prompt=moderation_prompt_block,
|
moderation_prompt=moderation_prompt_block,
|
||||||
identity_block=identity_block,
|
identity_block=identity_block,
|
||||||
@@ -365,5 +401,28 @@ class ActionPlanner:
|
|||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
return "构建 Planner Prompt 时出错", []
|
return "构建 Planner Prompt 时出错", []
|
||||||
|
|
||||||
|
def get_necessary_info(self) -> Tuple[bool, Optional[dict], Dict[str, ActionInfo]]:
|
||||||
|
"""
|
||||||
|
获取 Planner 需要的必要信息
|
||||||
|
"""
|
||||||
|
is_group_chat = True
|
||||||
|
is_group_chat, chat_target_info = get_chat_type_and_target_info(self.chat_id)
|
||||||
|
logger.debug(f"{self.log_prefix}获取到聊天信息 - 群聊: {is_group_chat}, 目标信息: {chat_target_info}")
|
||||||
|
|
||||||
|
current_available_actions_dict = self.action_manager.get_using_actions()
|
||||||
|
|
||||||
|
# 获取完整的动作信息
|
||||||
|
all_registered_actions: Dict[str, ActionInfo] = component_registry.get_components_by_type( # type: ignore
|
||||||
|
ComponentType.ACTION
|
||||||
|
)
|
||||||
|
current_available_actions = {}
|
||||||
|
for action_name in current_available_actions_dict:
|
||||||
|
if action_name in all_registered_actions:
|
||||||
|
current_available_actions[action_name] = all_registered_actions[action_name]
|
||||||
|
else:
|
||||||
|
logger.warning(f"{self.log_prefix}使用中的动作 {action_name} 未在已注册动作中找到")
|
||||||
|
|
||||||
|
return is_group_chat, chat_target_info, current_available_actions
|
||||||
|
|
||||||
|
|
||||||
init_prompt()
|
init_prompt()
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -1,4 +1,4 @@
|
|||||||
from typing import Dict, Any, Optional, List
|
from typing import Dict, Optional
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
||||||
@@ -15,7 +15,6 @@ class ReplyerManager:
|
|||||||
self,
|
self,
|
||||||
chat_stream: Optional[ChatStream] = None,
|
chat_stream: Optional[ChatStream] = None,
|
||||||
chat_id: Optional[str] = None,
|
chat_id: Optional[str] = None,
|
||||||
model_configs: Optional[List[Dict[str, Any]]] = None,
|
|
||||||
request_type: str = "replyer",
|
request_type: str = "replyer",
|
||||||
) -> Optional[DefaultReplyer]:
|
) -> Optional[DefaultReplyer]:
|
||||||
"""
|
"""
|
||||||
@@ -49,7 +48,6 @@ class ReplyerManager:
|
|||||||
# model_configs 只在此时(初始化时)生效
|
# model_configs 只在此时(初始化时)生效
|
||||||
replyer = DefaultReplyer(
|
replyer = DefaultReplyer(
|
||||||
chat_stream=target_stream,
|
chat_stream=target_stream,
|
||||||
model_configs=model_configs, # 可以是None,此时使用默认模型
|
|
||||||
request_type=request_type,
|
request_type=request_type,
|
||||||
)
|
)
|
||||||
self._repliers[stream_id] = replyer
|
self._repliers[stream_id] = replyer
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ from src.config.config import global_config
|
|||||||
from src.common.message_repository import find_messages, count_messages
|
from src.common.message_repository import find_messages, count_messages
|
||||||
from src.common.database.database_model import ActionRecords
|
from src.common.database.database_model import ActionRecords
|
||||||
from src.common.database.database_model import Images
|
from src.common.database.database_model import Images
|
||||||
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
|
from src.person_info.person_info import Person,get_person_id
|
||||||
from src.chat.utils.utils import translate_timestamp_to_human_readable, assign_message_ids
|
from src.chat.utils.utils import translate_timestamp_to_human_readable, assign_message_ids
|
||||||
|
|
||||||
install(extra_lines=3)
|
install(extra_lines=3)
|
||||||
@@ -35,14 +35,12 @@ def replace_user_references_sync(
|
|||||||
str: 处理后的内容字符串
|
str: 处理后的内容字符串
|
||||||
"""
|
"""
|
||||||
if name_resolver is None:
|
if name_resolver is None:
|
||||||
person_info_manager = get_person_info_manager()
|
|
||||||
|
|
||||||
def default_resolver(platform: str, user_id: str) -> str:
|
def default_resolver(platform: str, user_id: str) -> str:
|
||||||
# 检查是否是机器人自己
|
# 检查是否是机器人自己
|
||||||
if replace_bot_name and user_id == global_config.bot.qq_account:
|
if replace_bot_name and user_id == global_config.bot.qq_account:
|
||||||
return f"{global_config.bot.nickname}(你)"
|
return f"{global_config.bot.nickname}(你)"
|
||||||
person_id = PersonInfoManager.get_person_id(platform, user_id)
|
person = Person(platform=platform, user_id=user_id)
|
||||||
return person_info_manager.get_value_sync(person_id, "person_name") or user_id # type: ignore
|
return person.person_name or user_id # type: ignore
|
||||||
|
|
||||||
name_resolver = default_resolver
|
name_resolver = default_resolver
|
||||||
|
|
||||||
@@ -110,14 +108,12 @@ async def replace_user_references_async(
|
|||||||
str: 处理后的内容字符串
|
str: 处理后的内容字符串
|
||||||
"""
|
"""
|
||||||
if name_resolver is None:
|
if name_resolver is None:
|
||||||
person_info_manager = get_person_info_manager()
|
|
||||||
|
|
||||||
async def default_resolver(platform: str, user_id: str) -> str:
|
async def default_resolver(platform: str, user_id: str) -> str:
|
||||||
# 检查是否是机器人自己
|
# 检查是否是机器人自己
|
||||||
if replace_bot_name and user_id == global_config.bot.qq_account:
|
if replace_bot_name and user_id == global_config.bot.qq_account:
|
||||||
return f"{global_config.bot.nickname}(你)"
|
return f"{global_config.bot.nickname}(你)"
|
||||||
person_id = PersonInfoManager.get_person_id(platform, user_id)
|
person = Person(platform=platform, user_id=user_id)
|
||||||
return await person_info_manager.get_value(person_id, "person_name") or user_id # type: ignore
|
return person.person_name or user_id # type: ignore
|
||||||
|
|
||||||
name_resolver = default_resolver
|
name_resolver = default_resolver
|
||||||
|
|
||||||
@@ -506,14 +502,13 @@ def _build_readable_messages_internal(
|
|||||||
if not all([platform, user_id, timestamp is not None]):
|
if not all([platform, user_id, timestamp is not None]):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
person_id = PersonInfoManager.get_person_id(platform, user_id)
|
person = Person(platform=platform, user_id=user_id)
|
||||||
person_info_manager = get_person_info_manager()
|
|
||||||
# 根据 replace_bot_name 参数决定是否替换机器人名称
|
# 根据 replace_bot_name 参数决定是否替换机器人名称
|
||||||
person_name: str
|
person_name: str
|
||||||
if replace_bot_name and user_id == global_config.bot.qq_account:
|
if replace_bot_name and user_id == global_config.bot.qq_account:
|
||||||
person_name = f"{global_config.bot.nickname}(你)"
|
person_name = f"{global_config.bot.nickname}(你)"
|
||||||
else:
|
else:
|
||||||
person_name = person_info_manager.get_value_sync(person_id, "person_name") # type: ignore
|
person_name = person.person_name or user_id # type: ignore
|
||||||
|
|
||||||
# 如果 person_name 未设置,则使用消息中的 nickname 或默认名称
|
# 如果 person_name 未设置,则使用消息中的 nickname 或默认名称
|
||||||
if not person_name:
|
if not person_name:
|
||||||
@@ -740,7 +735,7 @@ def build_readable_actions(actions: List[Dict[str, Any]]) -> str:
|
|||||||
for action in actions:
|
for action in actions:
|
||||||
action_time = action.get("time", current_time)
|
action_time = action.get("time", current_time)
|
||||||
action_name = action.get("action_name", "未知动作")
|
action_name = action.get("action_name", "未知动作")
|
||||||
if action_name in ["no_action", "no_reply"]:
|
if action_name in ["no_action", "no_action"]:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
action_prompt_display = action.get("action_prompt_display", "无具体内容")
|
action_prompt_display = action.get("action_prompt_display", "无具体内容")
|
||||||
@@ -1009,7 +1004,7 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str:
|
|||||||
# print("SELF11111111111111")
|
# print("SELF11111111111111")
|
||||||
return "SELF"
|
return "SELF"
|
||||||
try:
|
try:
|
||||||
person_id = PersonInfoManager.get_person_id(platform, user_id)
|
person_id = get_person_id(platform, user_id)
|
||||||
except Exception as _e:
|
except Exception as _e:
|
||||||
person_id = None
|
person_id = None
|
||||||
if not person_id:
|
if not person_id:
|
||||||
@@ -1098,7 +1093,11 @@ async def get_person_id_list(messages: List[Dict[str, Any]]) -> List[str]:
|
|||||||
if not all([platform, user_id]) or user_id == global_config.bot.qq_account:
|
if not all([platform, user_id]) or user_id == global_config.bot.qq_account:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if person_id := PersonInfoManager.get_person_id(platform, user_id):
|
# 添加空值检查,防止 platform 为 None 时出错
|
||||||
|
if platform is None:
|
||||||
|
platform = "unknown"
|
||||||
|
|
||||||
|
if person_id := get_person_id(platform, user_id):
|
||||||
person_ids_set.add(person_id)
|
person_ids_set.add(person_id)
|
||||||
|
|
||||||
return list(person_ids_set) # 将集合转换为列表返回
|
return list(person_ids_set) # 将集合转换为列表返回
|
||||||
|
|||||||
@@ -1,223 +0,0 @@
|
|||||||
import ast
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
|
|
||||||
from typing import Any, Dict, TypeVar, List, Union, Tuple, Optional
|
|
||||||
|
|
||||||
# 定义类型变量用于泛型类型提示
|
|
||||||
T = TypeVar("T")
|
|
||||||
|
|
||||||
# 获取logger
|
|
||||||
logger = logging.getLogger("json_utils")
|
|
||||||
|
|
||||||
|
|
||||||
def safe_json_loads(json_str: str, default_value: T = None) -> Union[Any, T]:
|
|
||||||
"""
|
|
||||||
安全地解析JSON字符串,出错时返回默认值
|
|
||||||
现在尝试处理单引号和标准JSON
|
|
||||||
|
|
||||||
参数:
|
|
||||||
json_str: 要解析的JSON字符串
|
|
||||||
default_value: 解析失败时返回的默认值
|
|
||||||
|
|
||||||
返回:
|
|
||||||
解析后的Python对象,或在解析失败时返回default_value
|
|
||||||
"""
|
|
||||||
if not json_str or not isinstance(json_str, str):
|
|
||||||
logger.warning(f"safe_json_loads 接收到非字符串输入: {type(json_str)}, 值: {json_str}")
|
|
||||||
return default_value
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 尝试标准的 JSON 解析
|
|
||||||
return json.loads(json_str)
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
# 如果标准解析失败,尝试用 ast.literal_eval 解析
|
|
||||||
try:
|
|
||||||
# logger.debug(f"标准JSON解析失败,尝试用 ast.literal_eval 解析: {json_str[:100]}...")
|
|
||||||
result = ast.literal_eval(json_str)
|
|
||||||
if isinstance(result, dict):
|
|
||||||
return result
|
|
||||||
logger.warning(f"ast.literal_eval 解析成功但结果不是字典: {type(result)}, 内容: {result}")
|
|
||||||
return default_value
|
|
||||||
except (ValueError, SyntaxError, MemoryError, RecursionError) as ast_e:
|
|
||||||
logger.error(f"使用 ast.literal_eval 解析失败: {ast_e}, 字符串: {json_str[:100]}...")
|
|
||||||
return default_value
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"使用 ast.literal_eval 解析时发生意外错误: {e}, 字符串: {json_str[:100]}...")
|
|
||||||
return default_value
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"JSON解析过程中发生意外错误: {e}, 字符串: {json_str[:100]}...")
|
|
||||||
return default_value
|
|
||||||
|
|
||||||
|
|
||||||
def extract_tool_call_arguments(
|
|
||||||
tool_call: Dict[str, Any], default_value: Optional[Dict[str, Any]] = None
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
从LLM工具调用对象中提取参数
|
|
||||||
|
|
||||||
参数:
|
|
||||||
tool_call: 工具调用对象字典
|
|
||||||
default_value: 解析失败时返回的默认值
|
|
||||||
|
|
||||||
返回:
|
|
||||||
解析后的参数字典,或在解析失败时返回default_value
|
|
||||||
"""
|
|
||||||
default_result = default_value or {}
|
|
||||||
|
|
||||||
if not tool_call or not isinstance(tool_call, dict):
|
|
||||||
logger.error(f"无效的工具调用对象: {tool_call}")
|
|
||||||
return default_result
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 提取function参数
|
|
||||||
function_data = tool_call.get("function", {})
|
|
||||||
if not function_data or not isinstance(function_data, dict):
|
|
||||||
logger.error(f"工具调用缺少function字段或格式不正确: {tool_call}")
|
|
||||||
return default_result
|
|
||||||
|
|
||||||
if arguments_str := function_data.get("arguments", "{}"):
|
|
||||||
# 解析JSON
|
|
||||||
return safe_json_loads(arguments_str, default_result)
|
|
||||||
else:
|
|
||||||
return default_result
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"提取工具调用参数时出错: {e}")
|
|
||||||
return default_result
|
|
||||||
|
|
||||||
|
|
||||||
def safe_json_dumps(obj: Any, default_value: str = "{}", ensure_ascii: bool = False, pretty: bool = False) -> str:
|
|
||||||
"""
|
|
||||||
安全地将Python对象序列化为JSON字符串
|
|
||||||
|
|
||||||
参数:
|
|
||||||
obj: 要序列化的Python对象
|
|
||||||
default_value: 序列化失败时返回的默认值
|
|
||||||
ensure_ascii: 是否确保ASCII编码(默认False,允许中文等非ASCII字符)
|
|
||||||
pretty: 是否美化输出JSON
|
|
||||||
|
|
||||||
返回:
|
|
||||||
序列化后的JSON字符串,或在序列化失败时返回default_value
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
indent = 2 if pretty else None
|
|
||||||
return json.dumps(obj, ensure_ascii=ensure_ascii, indent=indent)
|
|
||||||
except TypeError as e:
|
|
||||||
logger.error(f"JSON序列化失败(类型错误): {e}")
|
|
||||||
return default_value
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"JSON序列化过程中发生意外错误: {e}")
|
|
||||||
return default_value
|
|
||||||
|
|
||||||
|
|
||||||
def normalize_llm_response(response: Any, log_prefix: str = "") -> Tuple[bool, List[Any], str]:
|
|
||||||
"""
|
|
||||||
标准化LLM响应格式,将各种格式(如元组)转换为统一的列表格式
|
|
||||||
|
|
||||||
参数:
|
|
||||||
response: 原始LLM响应
|
|
||||||
log_prefix: 日志前缀
|
|
||||||
|
|
||||||
返回:
|
|
||||||
元组 (成功标志, 标准化后的响应列表, 错误消息)
|
|
||||||
"""
|
|
||||||
|
|
||||||
logger.debug(f"{log_prefix}原始人 LLM响应: {response}")
|
|
||||||
|
|
||||||
# 检查是否为None
|
|
||||||
if response is None:
|
|
||||||
return False, [], "LLM响应为None"
|
|
||||||
|
|
||||||
# 记录原始类型
|
|
||||||
logger.debug(f"{log_prefix}LLM响应原始类型: {type(response).__name__}")
|
|
||||||
|
|
||||||
# 将元组转换为列表
|
|
||||||
if isinstance(response, tuple):
|
|
||||||
logger.debug(f"{log_prefix}将元组响应转换为列表")
|
|
||||||
response = list(response)
|
|
||||||
|
|
||||||
# 确保是列表类型
|
|
||||||
if not isinstance(response, list):
|
|
||||||
return False, [], f"无法处理的LLM响应类型: {type(response).__name__}"
|
|
||||||
|
|
||||||
# 处理工具调用部分(如果存在)
|
|
||||||
if len(response) == 3:
|
|
||||||
content, reasoning, tool_calls = response
|
|
||||||
|
|
||||||
# 将工具调用部分转换为列表(如果是元组)
|
|
||||||
if isinstance(tool_calls, tuple):
|
|
||||||
logger.debug(f"{log_prefix}将工具调用元组转换为列表")
|
|
||||||
tool_calls = list(tool_calls)
|
|
||||||
response[2] = tool_calls
|
|
||||||
|
|
||||||
return True, response, ""
|
|
||||||
|
|
||||||
|
|
||||||
def process_llm_tool_calls(
|
|
||||||
tool_calls: List[Dict[str, Any]], log_prefix: str = ""
|
|
||||||
) -> Tuple[bool, List[Dict[str, Any]], str]:
|
|
||||||
"""
|
|
||||||
处理并验证LLM响应中的工具调用列表
|
|
||||||
|
|
||||||
参数:
|
|
||||||
tool_calls: 从LLM响应中直接获取的工具调用列表
|
|
||||||
log_prefix: 日志前缀
|
|
||||||
|
|
||||||
返回:
|
|
||||||
元组 (成功标志, 验证后的工具调用列表, 错误消息)
|
|
||||||
"""
|
|
||||||
|
|
||||||
# 如果列表为空,表示没有工具调用,这不是错误
|
|
||||||
if not tool_calls:
|
|
||||||
return True, [], "工具调用列表为空"
|
|
||||||
|
|
||||||
# 验证每个工具调用的格式
|
|
||||||
valid_tool_calls = []
|
|
||||||
for i, tool_call in enumerate(tool_calls):
|
|
||||||
if not isinstance(tool_call, dict):
|
|
||||||
logger.warning(f"{log_prefix}工具调用[{i}]不是字典: {type(tool_call).__name__}, 内容: {tool_call}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 检查基本结构
|
|
||||||
if tool_call.get("type") != "function":
|
|
||||||
logger.warning(
|
|
||||||
f"{log_prefix}工具调用[{i}]不是function类型: type={tool_call.get('type', '未定义')}, 内容: {tool_call}"
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
|
|
||||||
if "function" not in tool_call or not isinstance(tool_call.get("function"), dict):
|
|
||||||
logger.warning(f"{log_prefix}工具调用[{i}]缺少'function'字段或其类型不正确: {tool_call}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
func_details = tool_call["function"]
|
|
||||||
if "name" not in func_details or not isinstance(func_details.get("name"), str):
|
|
||||||
logger.warning(f"{log_prefix}工具调用[{i}]的'function'字段缺少'name'或类型不正确: {func_details}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 验证参数 'arguments'
|
|
||||||
args_value = func_details.get("arguments")
|
|
||||||
|
|
||||||
# 1. 检查 arguments 是否存在且是字符串
|
|
||||||
if args_value is None or not isinstance(args_value, str):
|
|
||||||
logger.warning(f"{log_prefix}工具调用[{i}]的'function'字段缺少'arguments'字符串: {func_details}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 2. 尝试安全地解析 arguments 字符串
|
|
||||||
parsed_args = safe_json_loads(args_value, None)
|
|
||||||
|
|
||||||
# 3. 检查解析结果是否为字典
|
|
||||||
if parsed_args is None or not isinstance(parsed_args, dict):
|
|
||||||
logger.warning(
|
|
||||||
f"{log_prefix}工具调用[{i}]的'arguments'无法解析为有效的JSON字典, "
|
|
||||||
f"原始字符串: {args_value[:100]}..., 解析结果类型: {type(parsed_args).__name__}"
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 如果检查通过,将原始的 tool_call 加入有效列表
|
|
||||||
valid_tool_calls.append(tool_call)
|
|
||||||
|
|
||||||
if not valid_tool_calls and tool_calls: # 如果原始列表不为空,但验证后为空
|
|
||||||
return False, [], "所有工具调用格式均无效"
|
|
||||||
|
|
||||||
return True, valid_tool_calls, ""
|
|
||||||
@@ -36,6 +36,18 @@ COST_BY_TYPE = "costs_by_type"
|
|||||||
COST_BY_USER = "costs_by_user"
|
COST_BY_USER = "costs_by_user"
|
||||||
COST_BY_MODEL = "costs_by_model"
|
COST_BY_MODEL = "costs_by_model"
|
||||||
COST_BY_MODULE = "costs_by_module"
|
COST_BY_MODULE = "costs_by_module"
|
||||||
|
TIME_COST_BY_TYPE = "time_costs_by_type"
|
||||||
|
TIME_COST_BY_USER = "time_costs_by_user"
|
||||||
|
TIME_COST_BY_MODEL = "time_costs_by_model"
|
||||||
|
TIME_COST_BY_MODULE = "time_costs_by_module"
|
||||||
|
AVG_TIME_COST_BY_TYPE = "avg_time_costs_by_type"
|
||||||
|
AVG_TIME_COST_BY_USER = "avg_time_costs_by_user"
|
||||||
|
AVG_TIME_COST_BY_MODEL = "avg_time_costs_by_model"
|
||||||
|
AVG_TIME_COST_BY_MODULE = "avg_time_costs_by_module"
|
||||||
|
STD_TIME_COST_BY_TYPE = "std_time_costs_by_type"
|
||||||
|
STD_TIME_COST_BY_USER = "std_time_costs_by_user"
|
||||||
|
STD_TIME_COST_BY_MODEL = "std_time_costs_by_model"
|
||||||
|
STD_TIME_COST_BY_MODULE = "std_time_costs_by_module"
|
||||||
ONLINE_TIME = "online_time"
|
ONLINE_TIME = "online_time"
|
||||||
TOTAL_MSG_CNT = "total_messages"
|
TOTAL_MSG_CNT = "total_messages"
|
||||||
MSG_CNT_BY_CHAT = "messages_by_chat"
|
MSG_CNT_BY_CHAT = "messages_by_chat"
|
||||||
@@ -293,6 +305,18 @@ class StatisticOutputTask(AsyncTask):
|
|||||||
COST_BY_USER: defaultdict(float),
|
COST_BY_USER: defaultdict(float),
|
||||||
COST_BY_MODEL: defaultdict(float),
|
COST_BY_MODEL: defaultdict(float),
|
||||||
COST_BY_MODULE: defaultdict(float),
|
COST_BY_MODULE: defaultdict(float),
|
||||||
|
TIME_COST_BY_TYPE: defaultdict(list),
|
||||||
|
TIME_COST_BY_USER: defaultdict(list),
|
||||||
|
TIME_COST_BY_MODEL: defaultdict(list),
|
||||||
|
TIME_COST_BY_MODULE: defaultdict(list),
|
||||||
|
AVG_TIME_COST_BY_TYPE: defaultdict(float),
|
||||||
|
AVG_TIME_COST_BY_USER: defaultdict(float),
|
||||||
|
AVG_TIME_COST_BY_MODEL: defaultdict(float),
|
||||||
|
AVG_TIME_COST_BY_MODULE: defaultdict(float),
|
||||||
|
STD_TIME_COST_BY_TYPE: defaultdict(float),
|
||||||
|
STD_TIME_COST_BY_USER: defaultdict(float),
|
||||||
|
STD_TIME_COST_BY_MODEL: defaultdict(float),
|
||||||
|
STD_TIME_COST_BY_MODULE: defaultdict(float),
|
||||||
}
|
}
|
||||||
for period_key, _ in collect_period
|
for period_key, _ in collect_period
|
||||||
}
|
}
|
||||||
@@ -344,7 +368,41 @@ class StatisticOutputTask(AsyncTask):
|
|||||||
stats[period_key][COST_BY_USER][user_id] += cost
|
stats[period_key][COST_BY_USER][user_id] += cost
|
||||||
stats[period_key][COST_BY_MODEL][model_name] += cost
|
stats[period_key][COST_BY_MODEL][model_name] += cost
|
||||||
stats[period_key][COST_BY_MODULE][module_name] += cost
|
stats[period_key][COST_BY_MODULE][module_name] += cost
|
||||||
|
|
||||||
|
# 收集time_cost数据
|
||||||
|
time_cost = record.time_cost or 0.0
|
||||||
|
if time_cost > 0: # 只记录有效的time_cost
|
||||||
|
stats[period_key][TIME_COST_BY_TYPE][request_type].append(time_cost)
|
||||||
|
stats[period_key][TIME_COST_BY_USER][user_id].append(time_cost)
|
||||||
|
stats[period_key][TIME_COST_BY_MODEL][model_name].append(time_cost)
|
||||||
|
stats[period_key][TIME_COST_BY_MODULE][module_name].append(time_cost)
|
||||||
break
|
break
|
||||||
|
|
||||||
|
# 计算平均耗时和标准差
|
||||||
|
for period_key in stats:
|
||||||
|
for category in [REQ_CNT_BY_TYPE, REQ_CNT_BY_USER, REQ_CNT_BY_MODEL, REQ_CNT_BY_MODULE]:
|
||||||
|
time_cost_key = f"time_costs_by_{category.split('_')[-1]}"
|
||||||
|
avg_key = f"avg_time_costs_by_{category.split('_')[-1]}"
|
||||||
|
std_key = f"std_time_costs_by_{category.split('_')[-1]}"
|
||||||
|
|
||||||
|
for item_name in stats[period_key][category]:
|
||||||
|
time_costs = stats[period_key][time_cost_key].get(item_name, [])
|
||||||
|
if time_costs:
|
||||||
|
# 计算平均耗时
|
||||||
|
avg_time_cost = sum(time_costs) / len(time_costs)
|
||||||
|
stats[period_key][avg_key][item_name] = round(avg_time_cost, 3)
|
||||||
|
|
||||||
|
# 计算标准差
|
||||||
|
if len(time_costs) > 1:
|
||||||
|
variance = sum((x - avg_time_cost) ** 2 for x in time_costs) / len(time_costs)
|
||||||
|
std_time_cost = variance ** 0.5
|
||||||
|
stats[period_key][std_key][item_name] = round(std_time_cost, 3)
|
||||||
|
else:
|
||||||
|
stats[period_key][std_key][item_name] = 0.0
|
||||||
|
else:
|
||||||
|
stats[period_key][avg_key][item_name] = 0.0
|
||||||
|
stats[period_key][std_key][item_name] = 0.0
|
||||||
|
|
||||||
return stats
|
return stats
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -566,11 +624,11 @@ class StatisticOutputTask(AsyncTask):
|
|||||||
"""
|
"""
|
||||||
if stats[TOTAL_REQ_CNT] <= 0:
|
if stats[TOTAL_REQ_CNT] <= 0:
|
||||||
return ""
|
return ""
|
||||||
data_fmt = "{:<32} {:>10} {:>12} {:>12} {:>12} {:>9.4f}¥"
|
data_fmt = "{:<32} {:>10} {:>12} {:>12} {:>12} {:>9.4f}¥ {:>10} {:>10}"
|
||||||
|
|
||||||
output = [
|
output = [
|
||||||
"按模型分类统计:",
|
"按模型分类统计:",
|
||||||
" 模型名称 调用次数 输入Token 输出Token Token总量 累计花费",
|
" 模型名称 调用次数 输入Token 输出Token Token总量 累计花费 平均耗时(秒) 标准差(秒)",
|
||||||
]
|
]
|
||||||
for model_name, count in sorted(stats[REQ_CNT_BY_MODEL].items()):
|
for model_name, count in sorted(stats[REQ_CNT_BY_MODEL].items()):
|
||||||
name = f"{model_name[:29]}..." if len(model_name) > 32 else model_name
|
name = f"{model_name[:29]}..." if len(model_name) > 32 else model_name
|
||||||
@@ -578,7 +636,9 @@ class StatisticOutputTask(AsyncTask):
|
|||||||
out_tokens = stats[OUT_TOK_BY_MODEL][model_name]
|
out_tokens = stats[OUT_TOK_BY_MODEL][model_name]
|
||||||
tokens = stats[TOTAL_TOK_BY_MODEL][model_name]
|
tokens = stats[TOTAL_TOK_BY_MODEL][model_name]
|
||||||
cost = stats[COST_BY_MODEL][model_name]
|
cost = stats[COST_BY_MODEL][model_name]
|
||||||
output.append(data_fmt.format(name, count, in_tokens, out_tokens, tokens, cost))
|
avg_time_cost = stats[AVG_TIME_COST_BY_MODEL][model_name]
|
||||||
|
std_time_cost = stats[STD_TIME_COST_BY_MODEL][model_name]
|
||||||
|
output.append(data_fmt.format(name, count, in_tokens, out_tokens, tokens, cost, avg_time_cost, std_time_cost))
|
||||||
|
|
||||||
output.append("")
|
output.append("")
|
||||||
return "\n".join(output)
|
return "\n".join(output)
|
||||||
@@ -663,6 +723,8 @@ class StatisticOutputTask(AsyncTask):
|
|||||||
f"<td>{stat_data[OUT_TOK_BY_MODEL][model_name]}</td>"
|
f"<td>{stat_data[OUT_TOK_BY_MODEL][model_name]}</td>"
|
||||||
f"<td>{stat_data[TOTAL_TOK_BY_MODEL][model_name]}</td>"
|
f"<td>{stat_data[TOTAL_TOK_BY_MODEL][model_name]}</td>"
|
||||||
f"<td>{stat_data[COST_BY_MODEL][model_name]:.4f} ¥</td>"
|
f"<td>{stat_data[COST_BY_MODEL][model_name]:.4f} ¥</td>"
|
||||||
|
f"<td>{stat_data[AVG_TIME_COST_BY_MODEL][model_name]:.3f} 秒</td>"
|
||||||
|
f"<td>{stat_data[STD_TIME_COST_BY_MODEL][model_name]:.3f} 秒</td>"
|
||||||
f"</tr>"
|
f"</tr>"
|
||||||
for model_name, count in sorted(stat_data[REQ_CNT_BY_MODEL].items())
|
for model_name, count in sorted(stat_data[REQ_CNT_BY_MODEL].items())
|
||||||
]
|
]
|
||||||
@@ -677,6 +739,8 @@ class StatisticOutputTask(AsyncTask):
|
|||||||
f"<td>{stat_data[OUT_TOK_BY_TYPE][req_type]}</td>"
|
f"<td>{stat_data[OUT_TOK_BY_TYPE][req_type]}</td>"
|
||||||
f"<td>{stat_data[TOTAL_TOK_BY_TYPE][req_type]}</td>"
|
f"<td>{stat_data[TOTAL_TOK_BY_TYPE][req_type]}</td>"
|
||||||
f"<td>{stat_data[COST_BY_TYPE][req_type]:.4f} ¥</td>"
|
f"<td>{stat_data[COST_BY_TYPE][req_type]:.4f} ¥</td>"
|
||||||
|
f"<td>{stat_data[AVG_TIME_COST_BY_TYPE][req_type]:.3f} 秒</td>"
|
||||||
|
f"<td>{stat_data[STD_TIME_COST_BY_TYPE][req_type]:.3f} 秒</td>"
|
||||||
f"</tr>"
|
f"</tr>"
|
||||||
for req_type, count in sorted(stat_data[REQ_CNT_BY_TYPE].items())
|
for req_type, count in sorted(stat_data[REQ_CNT_BY_TYPE].items())
|
||||||
]
|
]
|
||||||
@@ -691,6 +755,8 @@ class StatisticOutputTask(AsyncTask):
|
|||||||
f"<td>{stat_data[OUT_TOK_BY_MODULE][module_name]}</td>"
|
f"<td>{stat_data[OUT_TOK_BY_MODULE][module_name]}</td>"
|
||||||
f"<td>{stat_data[TOTAL_TOK_BY_MODULE][module_name]}</td>"
|
f"<td>{stat_data[TOTAL_TOK_BY_MODULE][module_name]}</td>"
|
||||||
f"<td>{stat_data[COST_BY_MODULE][module_name]:.4f} ¥</td>"
|
f"<td>{stat_data[COST_BY_MODULE][module_name]:.4f} ¥</td>"
|
||||||
|
f"<td>{stat_data[AVG_TIME_COST_BY_MODULE][module_name]:.3f} 秒</td>"
|
||||||
|
f"<td>{stat_data[STD_TIME_COST_BY_MODULE][module_name]:.3f} 秒</td>"
|
||||||
f"</tr>"
|
f"</tr>"
|
||||||
for module_name, count in sorted(stat_data[REQ_CNT_BY_MODULE].items())
|
for module_name, count in sorted(stat_data[REQ_CNT_BY_MODULE].items())
|
||||||
]
|
]
|
||||||
@@ -717,7 +783,7 @@ class StatisticOutputTask(AsyncTask):
|
|||||||
|
|
||||||
<h2>按模型分类统计</h2>
|
<h2>按模型分类统计</h2>
|
||||||
<table>
|
<table>
|
||||||
<thead><tr><th>模型名称</th><th>调用次数</th><th>输入Token</th><th>输出Token</th><th>Token总量</th><th>累计花费</th></tr></thead>
|
<thead><tr><th>模型名称</th><th>调用次数</th><th>输入Token</th><th>输出Token</th><th>Token总量</th><th>累计花费</th><th>平均耗时(秒)</th><th>标准差(秒)</th></tr></thead>
|
||||||
<tbody>
|
<tbody>
|
||||||
{model_rows}
|
{model_rows}
|
||||||
</tbody>
|
</tbody>
|
||||||
@@ -726,7 +792,7 @@ class StatisticOutputTask(AsyncTask):
|
|||||||
<h2>按模块分类统计</h2>
|
<h2>按模块分类统计</h2>
|
||||||
<table>
|
<table>
|
||||||
<thead>
|
<thead>
|
||||||
<tr><th>模块名称</th><th>调用次数</th><th>输入Token</th><th>输出Token</th><th>Token总量</th><th>累计花费</th></tr>
|
<tr><th>模块名称</th><th>调用次数</th><th>输入Token</th><th>输出Token</th><th>Token总量</th><th>累计花费</th><th>平均耗时(秒)</th><th>标准差(秒)</th></tr>
|
||||||
</thead>
|
</thead>
|
||||||
<tbody>
|
<tbody>
|
||||||
{module_rows}
|
{module_rows}
|
||||||
@@ -736,7 +802,7 @@ class StatisticOutputTask(AsyncTask):
|
|||||||
<h2>按请求类型分类统计</h2>
|
<h2>按请求类型分类统计</h2>
|
||||||
<table>
|
<table>
|
||||||
<thead>
|
<thead>
|
||||||
<tr><th>请求类型</th><th>调用次数</th><th>输入Token</th><th>输出Token</th><th>Token总量</th><th>累计花费</th></tr>
|
<tr><th>请求类型</th><th>调用次数</th><th>输入Token</th><th>输出Token</th><th>Token总量</th><th>累计花费</th><th>平均耗时(秒)</th><th>标准差(秒)</th></tr>
|
||||||
</thead>
|
</thead>
|
||||||
<tbody>
|
<tbody>
|
||||||
{type_rows}
|
{type_rows}
|
||||||
|
|||||||
@@ -11,11 +11,11 @@ from typing import Optional, Tuple, Dict, List, Any
|
|||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.common.message_repository import find_messages, count_messages
|
from src.common.message_repository import find_messages, count_messages
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config, model_config
|
||||||
from src.chat.message_receive.message import MessageRecv
|
from src.chat.message_receive.message import MessageRecv
|
||||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||||
from src.llm_models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
|
from src.person_info.person_info import Person
|
||||||
from .typo_generator import ChineseTypoGenerator
|
from .typo_generator import ChineseTypoGenerator
|
||||||
|
|
||||||
logger = get_logger("chat_utils")
|
logger = get_logger("chat_utils")
|
||||||
@@ -109,13 +109,11 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]:
|
|||||||
return is_mentioned, reply_probability
|
return is_mentioned, reply_probability
|
||||||
|
|
||||||
|
|
||||||
async def get_embedding(text, request_type="embedding"):
|
async def get_embedding(text, request_type="embedding") -> Optional[List[float]]:
|
||||||
"""获取文本的embedding向量"""
|
"""获取文本的embedding向量"""
|
||||||
# TODO: API-Adapter修改标记
|
llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type=request_type)
|
||||||
llm = LLMRequest(model=global_config.model.embedding, request_type=request_type)
|
|
||||||
# return llm.get_embedding_sync(text)
|
|
||||||
try:
|
try:
|
||||||
embedding = await llm.get_embedding(text)
|
embedding, _ = await llm.get_embedding(text)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"获取embedding失败: {str(e)}")
|
logger.error(f"获取embedding失败: {str(e)}")
|
||||||
embedding = None
|
embedding = None
|
||||||
@@ -641,12 +639,16 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]:
|
|||||||
# Try to fetch person info
|
# Try to fetch person info
|
||||||
try:
|
try:
|
||||||
# Assume get_person_id is sync (as per original code), keep using to_thread
|
# Assume get_person_id is sync (as per original code), keep using to_thread
|
||||||
person_id = PersonInfoManager.get_person_id(platform, user_id)
|
person = Person(platform=platform, user_id=user_id)
|
||||||
|
if not person.is_known:
|
||||||
|
logger.warning(f"用户 {user_info.user_nickname} 尚未认识")
|
||||||
|
# 如果用户尚未认识,则返回False和None
|
||||||
|
return False, None
|
||||||
|
person_id = person.person_id
|
||||||
person_name = None
|
person_name = None
|
||||||
if person_id:
|
if person_id:
|
||||||
# get_value is async, so await it directly
|
# get_value is async, so await it directly
|
||||||
person_info_manager = get_person_info_manager()
|
person_name = person.person_name
|
||||||
person_name = person_info_manager.get_value_sync(person_id, "person_name")
|
|
||||||
|
|
||||||
target_info["person_id"] = person_id
|
target_info["person_id"] = person_id
|
||||||
target_info["person_name"] = person_name
|
target_info["person_name"] = person_name
|
||||||
@@ -767,3 +769,68 @@ def assign_message_ids_flexible(
|
|||||||
# # 增强版本 - 使用时间戳
|
# # 增强版本 - 使用时间戳
|
||||||
# result3 = assign_message_ids_flexible(messages, prefix="ts", use_timestamp=True)
|
# result3 = assign_message_ids_flexible(messages, prefix="ts", use_timestamp=True)
|
||||||
# # 结果: [{'id': 'ts123a1b', 'message': 'Hello'}, {'id': 'ts123c2d', 'message': 'World'}, {'id': 'ts123e3f', 'message': 'Test message'}]
|
# # 结果: [{'id': 'ts123a1b', 'message': 'Hello'}, {'id': 'ts123c2d', 'message': 'World'}, {'id': 'ts123e3f', 'message': 'Test message'}]
|
||||||
|
|
||||||
|
def parse_keywords_string(keywords_input) -> list[str]:
    """Parse keywords given in one of several loose formats into a clean list.

    Supported inputs:
        1. Stringified list:   '["utils.py", "修改", "代码", "动作"]'
        2. Slash separated:    'utils.py/修改/代码/动作'
        3. Comma separated:    'utils.py,修改,代码,动作'
        4. Space separated:    'utils.py 修改 代码 动作'
        5. A real list/tuple:  ["utils.py", "修改", "代码", "动作"]
        6. JSON object string: '{"keywords": ["utils.py", "修改", "代码", "动作"]}'

    Args:
        keywords_input: The raw keywords; a string, a list/tuple, or any
            object whose ``str()`` form follows one of the formats above.

    Returns:
        list[str]: The keywords with surrounding whitespace stripped and
        blank entries removed. Falsy input yields an empty list.
    """
    import ast
    import json

    def _clean(items) -> list[str]:
        # Stringify and strip once per item, then drop the blanks.
        stripped = (str(item).strip() for item in items)
        return [keyword for keyword in stripped if keyword]

    if not keywords_input:
        return []

    # Real sequences are used as-is. Tuples are handled here as well, because
    # round-tripping them through str() would produce "('a', 'b')", which the
    # separator fallback below would split into broken tokens.
    if isinstance(keywords_input, (list, tuple)):
        return _clean(keywords_input)

    keywords_str = str(keywords_input).strip()
    if not keywords_str:
        return []

    # 1) JSON: either a bare array or an object of the form {"keywords": [...]}.
    try:
        json_data = json.loads(keywords_str)
    except (ValueError, RecursionError):  # JSONDecodeError is a ValueError
        json_data = None
    if isinstance(json_data, dict) and isinstance(json_data.get("keywords"), list):
        return _clean(json_data["keywords"])
    if isinstance(json_data, list):
        return _clean(json_data)

    # 2) Python literal list, e.g. "['a', 'b']" (single quotes are not JSON).
    #    literal_eval may raise more than ValueError/SyntaxError on malformed
    #    input (e.g. TypeError for an unhashable dict key), so every
    #    documented failure mode is treated as "not a literal".
    try:
        parsed = ast.literal_eval(keywords_str)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        parsed = None
    if isinstance(parsed, list):
        return _clean(parsed)

    # 3) Plain delimited text: the first separator that yields at least two
    #    non-empty parts wins, in this priority order.
    for separator in ("/", ",", " ", "|", ";"):
        if separator in keywords_str:
            parts = [part.strip() for part in keywords_str.split(separator) if part.strip()]
            if len(parts) > 1:
                return parts

    # 4) No usable separator: the whole string is a single keyword.
    return [keywords_str]
|
||||||
@@ -14,7 +14,7 @@ from rich.traceback import install
|
|||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.common.database.database import db
|
from src.common.database.database import db
|
||||||
from src.common.database.database_model import Images, ImageDescriptions
|
from src.common.database.database_model import Images, ImageDescriptions
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config, model_config
|
||||||
from src.llm_models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
|
|
||||||
install(extra_lines=3)
|
install(extra_lines=3)
|
||||||
@@ -37,7 +37,7 @@ class ImageManager:
|
|||||||
self._ensure_image_dir()
|
self._ensure_image_dir()
|
||||||
|
|
||||||
self._initialized = True
|
self._initialized = True
|
||||||
self.vlm = LLMRequest(model=global_config.model.vlm, temperature=0.4, max_tokens=300, request_type="image")
|
self.vlm = LLMRequest(model_set=model_config.model_task_config.vlm, request_type="image")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
db.connect(reuse_if_open=True)
|
db.connect(reuse_if_open=True)
|
||||||
@@ -92,6 +92,20 @@ class ImageManager:
|
|||||||
desc_obj.save()
|
desc_obj.save()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"保存描述到数据库失败 (Peewee): {str(e)}")
|
logger.error(f"保存描述到数据库失败 (Peewee): {str(e)}")
|
||||||
|
|
||||||
|
async def get_emoji_tag(self, image_base64: str) -> str:
    """Return the emotion tags of a registered emoji as a display string.

    The base64 payload is decoded, hashed with MD5, and the hash is looked
    up through the emoji manager.

    Args:
        image_base64: Base64-encoded image data. When it is a ``str``, any
            non-ASCII characters are dropped before decoding.

    Returns:
        ``"[表情包:<tags>]"`` with the emoji's emotions joined by commas, or
        ``"[表情包:未知]"`` when the manager returns nothing for the hash.
    """
    # Imported inside the method, as in the original code.
    from src.chat.emoji_system.emoji_manager import get_emoji_manager

    manager = get_emoji_manager()

    payload = image_base64
    if isinstance(payload, str):
        payload = payload.encode("ascii", errors="ignore").decode("ascii")

    digest = hashlib.md5(base64.b64decode(payload)).hexdigest()
    registered = await manager.get_emoji_from_manager(digest)
    if registered:
        tag_str = ",".join(registered.emotion)
        return f"[表情包:{tag_str}]"
    return "[表情包:未知]"
|
||||||
|
|
||||||
async def get_emoji_description(self, image_base64: str) -> str:
|
async def get_emoji_description(self, image_base64: str) -> str:
|
||||||
"""获取表情包描述,优先使用Emoji表中的缓存数据"""
|
"""获取表情包描述,优先使用Emoji表中的缓存数据"""
|
||||||
@@ -108,21 +122,21 @@ class ImageManager:
|
|||||||
try:
|
try:
|
||||||
from src.chat.emoji_system.emoji_manager import get_emoji_manager
|
from src.chat.emoji_system.emoji_manager import get_emoji_manager
|
||||||
emoji_manager = get_emoji_manager()
|
emoji_manager = get_emoji_manager()
|
||||||
cached_emoji_description = await emoji_manager.get_emoji_description_by_hash(image_hash)
|
tags = await emoji_manager.get_emoji_tag_by_hash(image_hash)
|
||||||
if cached_emoji_description:
|
if tags:
|
||||||
logger.info(f"[缓存命中] 使用已注册表情包描述: {cached_emoji_description[:50]}...")
|
tag_str = ",".join(tags)
|
||||||
return cached_emoji_description
|
logger.info(f"[缓存命中] 使用已注册表情包描述: {tag_str}...")
|
||||||
|
return f"[表情包:{tag_str}]"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f"查询EmojiManager时出错: {e}")
|
logger.debug(f"查询EmojiManager时出错: {e}")
|
||||||
|
|
||||||
# 查询ImageDescriptions表的缓存描述
|
# 查询ImageDescriptions表的缓存描述
|
||||||
cached_description = self._get_description_from_db(image_hash, "emoji")
|
if cached_description := self._get_description_from_db(image_hash, "emoji"):
|
||||||
if cached_description:
|
|
||||||
logger.info(f"[缓存命中] 使用ImageDescriptions表中的描述: {cached_description[:50]}...")
|
logger.info(f"[缓存命中] 使用ImageDescriptions表中的描述: {cached_description[:50]}...")
|
||||||
return f"[表情包:{cached_description}]"
|
return f"[表情包:{cached_description}]"
|
||||||
|
|
||||||
# === 二步走识别流程 ===
|
# === 二步走识别流程 ===
|
||||||
|
|
||||||
# 第一步:VLM视觉分析 - 生成详细描述
|
# 第一步:VLM视觉分析 - 生成详细描述
|
||||||
if image_format in ["gif", "GIF"]:
|
if image_format in ["gif", "GIF"]:
|
||||||
image_base64_processed = self.transform_gif(image_base64)
|
image_base64_processed = self.transform_gif(image_base64)
|
||||||
@@ -130,10 +144,16 @@ class ImageManager:
|
|||||||
logger.warning("GIF转换失败,无法获取描述")
|
logger.warning("GIF转换失败,无法获取描述")
|
||||||
return "[表情包(GIF处理失败)]"
|
return "[表情包(GIF处理失败)]"
|
||||||
vlm_prompt = "这是一个动态图表情包,每一张图代表了动态图的某一帧,黑色背景代表透明,描述一下表情包表达的情感和内容,描述细节,从互联网梗,meme的角度去分析"
|
vlm_prompt = "这是一个动态图表情包,每一张图代表了动态图的某一帧,黑色背景代表透明,描述一下表情包表达的情感和内容,描述细节,从互联网梗,meme的角度去分析"
|
||||||
detailed_description, _ = await self.vlm.generate_response_for_image(vlm_prompt, image_base64_processed, "jpg")
|
detailed_description, _ = await self.vlm.generate_response_for_image(
|
||||||
|
vlm_prompt, image_base64_processed, "jpg", temperature=0.4, max_tokens=300
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
vlm_prompt = "这是一个表情包,请详细描述一下表情包所表达的情感和内容,描述细节,从互联网梗,meme的角度去分析"
|
vlm_prompt = (
|
||||||
detailed_description, _ = await self.vlm.generate_response_for_image(vlm_prompt, image_base64, image_format)
|
"这是一个表情包,请详细描述一下表情包所表达的情感和内容,描述细节,从互联网梗,meme的角度去分析"
|
||||||
|
)
|
||||||
|
detailed_description, _ = await self.vlm.generate_response_for_image(
|
||||||
|
vlm_prompt, image_base64, image_format, temperature=0.4, max_tokens=300
|
||||||
|
)
|
||||||
|
|
||||||
if detailed_description is None:
|
if detailed_description is None:
|
||||||
logger.warning("VLM未能生成表情包详细描述")
|
logger.warning("VLM未能生成表情包详细描述")
|
||||||
@@ -150,31 +170,32 @@ class ImageManager:
|
|||||||
3. 输出简短精准,不要解释
|
3. 输出简短精准,不要解释
|
||||||
4. 如果有多个词用逗号分隔
|
4. 如果有多个词用逗号分隔
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# 使用较低温度确保输出稳定
|
# 使用较低温度确保输出稳定
|
||||||
emotion_llm = LLMRequest(model=global_config.model.utils, temperature=0.3, max_tokens=50, request_type="emoji")
|
emotion_llm = LLMRequest(model_set=model_config.model_task_config.utils, request_type="emoji")
|
||||||
emotion_result, _ = await emotion_llm.generate_response_async(emotion_prompt)
|
emotion_result, _ = await emotion_llm.generate_response_async(
|
||||||
|
emotion_prompt, temperature=0.3, max_tokens=50
|
||||||
|
)
|
||||||
|
|
||||||
if emotion_result is None:
|
if emotion_result is None:
|
||||||
logger.warning("LLM未能生成情感标签,使用详细描述的前几个词")
|
logger.warning("LLM未能生成情感标签,使用详细描述的前几个词")
|
||||||
# 降级处理:从详细描述中提取关键词
|
# 降级处理:从详细描述中提取关键词
|
||||||
import jieba
|
import jieba
|
||||||
|
|
||||||
words = list(jieba.cut(detailed_description))
|
words = list(jieba.cut(detailed_description))
|
||||||
emotion_result = ",".join(words[:2]) if len(words) >= 2 else (words[0] if words else "表情")
|
emotion_result = ",".join(words[:2]) if len(words) >= 2 else (words[0] if words else "表情")
|
||||||
|
|
||||||
# 处理情感结果,取前1-2个最重要的标签
|
# 处理情感结果,取前1-2个最重要的标签
|
||||||
emotions = [e.strip() for e in emotion_result.replace(",", ",").split(",") if e.strip()]
|
emotions = [e.strip() for e in emotion_result.replace(",", ",").split(",") if e.strip()]
|
||||||
final_emotion = emotions[0] if emotions else "表情"
|
final_emotion = emotions[0] if emotions else "表情"
|
||||||
|
|
||||||
# 如果有第二个情感且不重复,也包含进来
|
# 如果有第二个情感且不重复,也包含进来
|
||||||
if len(emotions) > 1 and emotions[1] != emotions[0]:
|
if len(emotions) > 1 and emotions[1] != emotions[0]:
|
||||||
final_emotion = f"{emotions[0]},{emotions[1]}"
|
final_emotion = f"{emotions[0]},{emotions[1]}"
|
||||||
|
|
||||||
logger.info(f"[emoji识别] 详细描述: {detailed_description[:50]}... -> 情感标签: {final_emotion}")
|
logger.info(f"[emoji识别] 详细描述: {detailed_description[:50]}... -> 情感标签: {final_emotion}")
|
||||||
|
|
||||||
# 再次检查缓存,防止并发写入时重复生成
|
if cached_description := self._get_description_from_db(image_hash, "emoji"):
|
||||||
cached_description = self._get_description_from_db(image_hash, "emoji")
|
|
||||||
if cached_description:
|
|
||||||
logger.warning(f"虽然生成了描述,但是找到缓存表情包描述: {cached_description}")
|
logger.warning(f"虽然生成了描述,但是找到缓存表情包描述: {cached_description}")
|
||||||
return f"[表情包:{cached_description}]"
|
return f"[表情包:{cached_description}]"
|
||||||
|
|
||||||
@@ -242,9 +263,7 @@ class ImageManager:
|
|||||||
logger.debug(f"[缓存命中] 使用Images表中的图片描述: {existing_image.description[:50]}...")
|
logger.debug(f"[缓存命中] 使用Images表中的图片描述: {existing_image.description[:50]}...")
|
||||||
return f"[图片:{existing_image.description}]"
|
return f"[图片:{existing_image.description}]"
|
||||||
|
|
||||||
# 查询ImageDescriptions表的缓存描述
|
if cached_description := self._get_description_from_db(image_hash, "image"):
|
||||||
cached_description = self._get_description_from_db(image_hash, "image")
|
|
||||||
if cached_description:
|
|
||||||
logger.debug(f"[缓存命中] 使用ImageDescriptions表中的描述: {cached_description[:50]}...")
|
logger.debug(f"[缓存命中] 使用ImageDescriptions表中的描述: {cached_description[:50]}...")
|
||||||
return f"[图片:{cached_description}]"
|
return f"[图片:{cached_description}]"
|
||||||
|
|
||||||
@@ -252,7 +271,9 @@ class ImageManager:
|
|||||||
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
|
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
|
||||||
prompt = global_config.custom_prompt.image_prompt
|
prompt = global_config.custom_prompt.image_prompt
|
||||||
logger.info(f"[VLM调用] 为图片生成新描述 (Hash: {image_hash[:8]}...)")
|
logger.info(f"[VLM调用] 为图片生成新描述 (Hash: {image_hash[:8]}...)")
|
||||||
description, _ = await self.vlm.generate_response_for_image(prompt, image_base64, image_format)
|
description, _ = await self.vlm.generate_response_for_image(
|
||||||
|
prompt, image_base64, image_format, temperature=0.4, max_tokens=300
|
||||||
|
)
|
||||||
|
|
||||||
if description is None:
|
if description is None:
|
||||||
logger.warning("AI未能生成图片描述")
|
logger.warning("AI未能生成图片描述")
|
||||||
@@ -445,10 +466,7 @@ class ImageManager:
|
|||||||
image_bytes = base64.b64decode(image_base64)
|
image_bytes = base64.b64decode(image_base64)
|
||||||
image_hash = hashlib.md5(image_bytes).hexdigest()
|
image_hash = hashlib.md5(image_bytes).hexdigest()
|
||||||
|
|
||||||
# 检查图片是否已存在
|
if existing_image := Images.get_or_none(Images.emoji_hash == image_hash):
|
||||||
existing_image = Images.get_or_none(Images.emoji_hash == image_hash)
|
|
||||||
|
|
||||||
if existing_image:
|
|
||||||
# 检查是否缺少必要字段,如果缺少则创建新记录
|
# 检查是否缺少必要字段,如果缺少则创建新记录
|
||||||
if (
|
if (
|
||||||
not hasattr(existing_image, "image_id")
|
not hasattr(existing_image, "image_id")
|
||||||
@@ -524,9 +542,7 @@ class ImageManager:
|
|||||||
|
|
||||||
# 优先检查是否已有其他相同哈希的图片记录包含描述
|
# 优先检查是否已有其他相同哈希的图片记录包含描述
|
||||||
existing_with_description = Images.get_or_none(
|
existing_with_description = Images.get_or_none(
|
||||||
(Images.emoji_hash == image_hash) &
|
(Images.emoji_hash == image_hash) & (Images.description.is_null(False)) & (Images.description != "")
|
||||||
(Images.description.is_null(False)) &
|
|
||||||
(Images.description != "")
|
|
||||||
)
|
)
|
||||||
if existing_with_description and existing_with_description.id != image.id:
|
if existing_with_description and existing_with_description.id != image.id:
|
||||||
logger.debug(f"[缓存复用] 从其他相同图片记录复用描述: {existing_with_description.description[:50]}...")
|
logger.debug(f"[缓存复用] 从其他相同图片记录复用描述: {existing_with_description.description[:50]}...")
|
||||||
@@ -538,8 +554,7 @@ class ImageManager:
|
|||||||
return
|
return
|
||||||
|
|
||||||
# 检查ImageDescriptions表的缓存描述
|
# 检查ImageDescriptions表的缓存描述
|
||||||
cached_description = self._get_description_from_db(image_hash, "image")
|
if cached_description := self._get_description_from_db(image_hash, "image"):
|
||||||
if cached_description:
|
|
||||||
logger.debug(f"[缓存复用] 从ImageDescriptions表复用描述: {cached_description[:50]}...")
|
logger.debug(f"[缓存复用] 从ImageDescriptions表复用描述: {cached_description[:50]}...")
|
||||||
image.description = cached_description
|
image.description = cached_description
|
||||||
image.vlm_processed = True
|
image.vlm_processed = True
|
||||||
@@ -554,15 +569,15 @@ class ImageManager:
|
|||||||
|
|
||||||
# 获取VLM描述
|
# 获取VLM描述
|
||||||
logger.info(f"[VLM异步调用] 为图片生成描述 (ID: {image_id}, Hash: {image_hash[:8]}...)")
|
logger.info(f"[VLM异步调用] 为图片生成描述 (ID: {image_id}, Hash: {image_hash[:8]}...)")
|
||||||
description, _ = await self.vlm.generate_response_for_image(prompt, image_base64, image_format)
|
description, _ = await self.vlm.generate_response_for_image(
|
||||||
|
prompt, image_base64, image_format, temperature=0.4, max_tokens=300
|
||||||
|
)
|
||||||
|
|
||||||
if description is None:
|
if description is None:
|
||||||
logger.warning("VLM未能生成图片描述")
|
logger.warning("VLM未能生成图片描述")
|
||||||
description = "无法生成描述"
|
description = "无法生成描述"
|
||||||
|
|
||||||
# 再次检查缓存,防止并发写入时重复生成
|
if cached_description := self._get_description_from_db(image_hash, "image"):
|
||||||
cached_description = self._get_description_from_db(image_hash, "image")
|
|
||||||
if cached_description:
|
|
||||||
logger.warning(f"虽然生成了描述,但是找到缓存图片描述: {cached_description}")
|
logger.warning(f"虽然生成了描述,但是找到缓存图片描述: {cached_description}")
|
||||||
description = cached_description
|
description = cached_description
|
||||||
|
|
||||||
@@ -606,7 +621,7 @@ def image_path_to_base64(image_path: str) -> str:
|
|||||||
raise FileNotFoundError(f"图片文件不存在: {image_path}")
|
raise FileNotFoundError(f"图片文件不存在: {image_path}")
|
||||||
|
|
||||||
with open(image_path, "rb") as f:
|
with open(image_path, "rb") as f:
|
||||||
image_data = f.read()
|
if image_data := f.read():
|
||||||
if not image_data:
|
return base64.b64encode(image_data).decode("utf-8")
|
||||||
|
else:
|
||||||
raise IOError(f"读取图片文件失败: {image_path}")
|
raise IOError(f"读取图片文件失败: {image_path}")
|
||||||
return base64.b64encode(image_data).decode("utf-8")
|
|
||||||
|
|||||||
@@ -1,35 +1,29 @@
|
|||||||
import base64
|
from src.config.config import global_config, model_config
|
||||||
|
|
||||||
from src.config.config import global_config
|
|
||||||
from src.llm_models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
|
|
||||||
install(extra_lines=3)
|
install(extra_lines=3)
|
||||||
|
|
||||||
logger = get_logger("chat_voice")
|
logger = get_logger("chat_voice")
|
||||||
|
|
||||||
|
|
||||||
async def get_voice_text(voice_base64: str) -> str:
|
async def get_voice_text(voice_base64: str) -> str:
|
||||||
"""获取音频文件描述"""
|
"""获取音频文件转录文本"""
|
||||||
if not global_config.voice.enable_asr:
|
if not global_config.voice.enable_asr:
|
||||||
logger.warning("语音识别未启用,无法处理语音消息")
|
logger.warning("语音识别未启用,无法处理语音消息")
|
||||||
return "[语音]"
|
return "[语音]"
|
||||||
try:
|
try:
|
||||||
# 解码base64音频数据
|
_llm = LLMRequest(model_set=model_config.model_task_config.voice, request_type="audio")
|
||||||
# 确保base64字符串只包含ASCII字符
|
text = await _llm.generate_response_for_voice(voice_base64)
|
||||||
if isinstance(voice_base64, str):
|
|
||||||
voice_base64 = voice_base64.encode("ascii", errors="ignore").decode("ascii")
|
|
||||||
voice_bytes = base64.b64decode(voice_base64)
|
|
||||||
_llm = LLMRequest(model=global_config.model.voice, request_type="voice")
|
|
||||||
text = await _llm.generate_response_for_voice(voice_bytes)
|
|
||||||
if text is None:
|
if text is None:
|
||||||
logger.warning("未能生成语音文本")
|
logger.warning("未能生成语音文本")
|
||||||
return "[语音(文本生成失败)]"
|
return "[语音(文本生成失败)]"
|
||||||
|
|
||||||
logger.debug(f"描述是{text}")
|
logger.debug(f"描述是{text}")
|
||||||
|
|
||||||
return f"[语音:{text}]"
|
return f"[语音:{text}]"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"语音转文字失败: {str(e)}")
|
logger.error(f"语音转文字失败: {str(e)}")
|
||||||
return "[语音]"
|
return "[语音]"
|
||||||
|
|
||||||
|
|||||||
@@ -1,62 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
|
|
||||||
from src.config.config import global_config
|
|
||||||
from .willing_manager import BaseWillingManager
|
|
||||||
|
|
||||||
|
|
||||||
class ClassicalWillingManager(BaseWillingManager):
    """Willingness manager that keeps one decaying reply-willingness value per chat."""

    def __init__(self):
        super().__init__()
        # Handle of the background decay loop; created by async_task_starter.
        self._decay_task: asyncio.Task | None = None

    async def _decay_reply_willing(self):
        """Multiply every chat's willingness by 0.9 once per second, floored at 0."""
        while True:
            await asyncio.sleep(1)
            for stream_id, willing in self.chat_reply_willing.items():
                self.chat_reply_willing[stream_id] = max(0.0, willing * 0.9)

    async def async_task_starter(self):
        """Start the decay loop unless a task handle is already stored."""
        if self._decay_task is not None:
            return
        self._decay_task = asyncio.create_task(self._decay_reply_willing())

    async def get_reply_probability(self, message_id):
        """Update the chat's willingness for this message and return a reply probability.

        The stored willingness is capped at 1.0; the returned probability is
        computed from the uncapped value and lies between 0.02 and 1.5.
        """
        info = self.ongoing_messages[message_id]
        stream_id = info.chat_id
        willing = self.chat_reply_willing.get(stream_id, 0)

        # Only the part of the interest rate above 0.2 raises willingness.
        rate = info.interested_rate
        if rate > 0.2:
            willing += rate - 0.2

        # A mention adds a large boost below 1.0 and a small one otherwise.
        if info.is_mentioned_bot and global_config.chat.mentioned_bot_inevitable_reply and willing < 2:
            if willing < 1.0:
                willing += 1
            else:
                willing += 0.2

        self.chat_reply_willing[stream_id] = min(willing, 1.0)

        return min(max(willing - 0.5, 0.01) * 2, 1.5)

    async def before_generate_reply_handle(self, message_id):
        """Do nothing before a reply is generated."""
        return None

    async def after_generate_reply_handle(self, message_id):
        """Raise the chat's willingness by 0.3 (capped at 1.0) after a reply."""
        if message_id in self.ongoing_messages:
            stream_id = self.ongoing_messages[message_id].chat_id
            willing = self.chat_reply_willing.get(stream_id, 0)
            if willing < 1:
                self.chat_reply_willing[stream_id] = min(1.0, willing + 0.3)

    async def not_reply_handle(self, message_id):
        """Delegate to the base class implementation."""
        outcome = await super().not_reply_handle(message_id)
        return outcome
|
|
||||||
@@ -1,23 +0,0 @@
|
|||||||
from .willing_manager import BaseWillingManager
|
|
||||||
|
|
||||||
NOT_IMPLEMENTED_MESSAGE = "\ncustom模式你实现了吗?没自行实现不要选custom。给你退了快点给你麦爹配置\n注:以上内容由gemini生成,如有不满请投诉gemini"
|
|
||||||
|
|
||||||
class CustomWillingManager(BaseWillingManager):
    """Placeholder for a user-supplied willingness mode.

    Every entry point, including the constructor, raises NotImplementedError
    with NOT_IMPLEMENTED_MESSAGE.
    """

    def __init__(self):
        super().__init__()
        raise NotImplementedError(NOT_IMPLEMENTED_MESSAGE)

    async def async_task_starter(self) -> None:
        """Not implemented; always raises."""
        raise NotImplementedError(NOT_IMPLEMENTED_MESSAGE)

    async def get_reply_probability(self, message_id: str):
        """Not implemented; always raises."""
        raise NotImplementedError(NOT_IMPLEMENTED_MESSAGE)

    async def before_generate_reply_handle(self, message_id: str):
        """Not implemented; always raises."""
        raise NotImplementedError(NOT_IMPLEMENTED_MESSAGE)

    async def after_generate_reply_handle(self, message_id: str):
        """Not implemented; always raises."""
        raise NotImplementedError(NOT_IMPLEMENTED_MESSAGE)

    async def not_reply_handle(self, message_id: str):
        """Not implemented; always raises."""
        raise NotImplementedError(NOT_IMPLEMENTED_MESSAGE)
|
|
||||||
@@ -1,296 +0,0 @@
|
|||||||
"""
|
|
||||||
Mxp 模式:梦溪畔独家赞助
|
|
||||||
此模式的一些参数不会在配置文件中显示,要修改请在可变参数下修改
|
|
||||||
同时一些全局设置对此模式无效
|
|
||||||
此模式的可变参数暂时比较草率,需要调参仙人的大手
|
|
||||||
此模式的特点:
|
|
||||||
1.每个聊天流的每个用户的意愿是独立的
|
|
||||||
2.接入关系系统,关系会影响意愿值(已移除,因为关系系统重构)
|
|
||||||
3.会根据群聊的热度来调整基础意愿值
|
|
||||||
4.限制同时思考的消息数量,防止喷射
|
|
||||||
5.拥有单聊增益,无论在群里还是私聊,只要bot一直和你聊,就会增加意愿值
|
|
||||||
6.意愿分为衰减意愿+临时意愿
|
|
||||||
7.疲劳机制
|
|
||||||
|
|
||||||
如果你发现本模式出现了bug
|
|
||||||
上上策是询问智慧的小草神()
|
|
||||||
上策是询问万能的千石可乐
|
|
||||||
中策是发issue
|
|
||||||
下下策是询问一个菜鸟(@梦溪畔)
|
|
||||||
"""
|
|
||||||
|
|
||||||
from .willing_manager import BaseWillingManager
|
|
||||||
from typing import Dict
|
|
||||||
import asyncio
|
|
||||||
import time
|
|
||||||
import math
|
|
||||||
|
|
||||||
from src.chat.message_receive.chat_stream import ChatStream
|
|
||||||
|
|
||||||
|
|
||||||
class MxpWillingManager(BaseWillingManager):
    """Willingness manager for the "mxp" mode.

    Willingness is tracked per (chat, person). Three background tasks started
    by ``async_task_starter`` decay each person's willingness towards the
    chat's base value, recompute the base value from recent message
    timestamps, and recompute a per-chat fatigue attenuation.
    """

    def __init__(self):
        super().__init__()
        # chat_id -> {person_id: willingness}
        self.chat_person_reply_willing: Dict[str, Dict[str, float]] = {}
        # chat_id -> timestamps of recently received messages
        self.chat_new_message_time: Dict[str, list[float]] = {}
        # chat_id -> (person_id, counter incremented after each reply to that person, capped at 3)
        self.last_response_person: Dict[str, tuple[str, int]] = {}
        # willingness computed by the most recent get_reply_probability call
        self.temporary_willing: float = 0
        # chat_id -> timestamps of the bot's own replies within the last 60 s
        self.chat_bot_message_time: Dict[str, list[float]] = {}
        # chat_id -> list of fatigue punishments as (start_time, duration)
        self.chat_fatigue_punishment_list: Dict[str, list[tuple[float, float]]] = {}
        # chat_id -> current fatigue attenuation added to the willingness
        self.chat_fatigue_willing_attenuation: Dict[str, float] = {}

        # Tunable parameters
        self.intention_decay_rate = 0.93  # willingness decay rate

        self.number_of_message_storage = 12  # number of message timestamps kept per chat
        self.expected_replies_per_min = 3  # expected replies per minute
        self.basic_maximum_willing = 0.5  # maximum base willingness

        self.mention_willing_gain = 0.6  # gain when the bot is mentioned
        self.interest_willing_gain = 0.3  # gain from the interest rate
        self.single_chat_gain = 0.12  # gain for an ongoing one-to-one exchange

        self.fatigue_messages_triggered_num = self.expected_replies_per_min  # replies that trigger fatigue (int)
        self.fatigue_coefficient = 1.0  # fatigue coefficient

        self.is_debug = False  # enable verbose debug logging

    async def async_task_starter(self) -> None:
        """Start the three background maintenance tasks."""
        asyncio.create_task(self._return_to_basic_willing())
        asyncio.create_task(self._chat_new_message_to_change_basic_willing())
        asyncio.create_task(self._fatigue_attenuation())

    async def before_generate_reply_handle(self, message_id: str):
        """Record a bot reply and schedule a fatigue punishment when replies get too frequent."""
        current_time = time.time()
        async with self.lock:
            w_info = self.ongoing_messages[message_id]
            chat_id = w_info.chat_id
            # Keep only the replies of the last 60 seconds, then add this one.
            recent = [t for t in self.chat_bot_message_time.get(chat_id, []) if current_time - t < 60]
            recent.append(current_time)
            self.chat_bot_message_time[chat_id] = recent
            if len(recent) == int(self.fatigue_messages_triggered_num):
                # Time left until the oldest reply would leave the 60 s window;
                # the punishment lasts twice that long.
                time_interval = 60 - (current_time - recent.pop(0))
                self.chat_fatigue_punishment_list[chat_id].append((current_time, time_interval * 2))

    async def after_generate_reply_handle(self, message_id: str):
        """Update the "last responded person" counter after a reply was generated."""
        async with self.lock:
            w_info = self.ongoing_messages[message_id]
            # NOTE: the relationship-based willingness gain that used to live
            # here was removed together with the old relationship system.
            person_id, streak = self.last_response_person.get(w_info.chat_id, (w_info.person_id, 0))
            if person_id == w_info.person_id:
                if streak < 3:
                    self.last_response_person[w_info.chat_id] = (person_id, streak + 1)
            else:
                self.last_response_person[w_info.chat_id] = (w_info.person_id, 0)

    async def not_reply_handle(self, message_id: str):
        """Raise the sender's stored willingness when the bot decided not to reply."""
        async with self.lock:
            w_info = self.ongoing_messages[message_id]
            chat_id = w_info.chat_id
            person_id = w_info.person_id
            if w_info.is_mentioned_bot:
                self.chat_person_reply_willing[chat_id][person_id] += self.mention_willing_gain / 2.5
            last = self.last_response_person.get(chat_id)
            if last is not None and last[0] == person_id and last[1]:
                self.chat_person_reply_willing[chat_id][person_id] += self.single_chat_gain * (2 * last[1] - 1)
            now_chat_new_person = self.last_response_person.get(chat_id, ("", 0))
            if now_chat_new_person[0] != person_id:
                self.last_response_person[chat_id] = (person_id, 0)

    async def get_reply_probability(self, message_id: str):
        # sourcery skip: merge-duplicate-blocks, remove-redundant-if
        """Compute the reply probability for an ongoing message.

        The stored per-person willingness is increased by the mention and
        interest gains and written back; the single-chat gain, the fatigue
        attenuation and the "messages in progress" penalty only affect the
        value used for this call.
        """
        async with self.lock:
            w_info = self.ongoing_messages[message_id]
            chat_id = w_info.chat_id
            person_id = w_info.person_id
            current_willing = self.chat_person_reply_willing[chat_id][person_id]
            if self.is_debug:
                self.logger.debug(f"基础意愿值:{current_willing}")

            if w_info.is_mentioned_bot:
                # The gain shrinks as willingness grows. The divisor is floored
                # at 1 so a willingness <= -1 cannot cause a division by zero
                # or flip the sign of the gain.
                willing_gain = self.mention_willing_gain / (max(int(current_willing), 0) + 1)
                current_willing += willing_gain
                if self.is_debug:
                    self.logger.debug(f"提及增益:{willing_gain}")

            if w_info.interested_rate > 0:
                willing_gain = math.atan(w_info.interested_rate / 2) / math.pi * 2 * self.interest_willing_gain
                current_willing += willing_gain
                if self.is_debug:
                    self.logger.debug(f"兴趣增益:{willing_gain}")

            self.chat_person_reply_willing[chat_id][person_id] = current_willing

            # Single-chat gain
            last = self.last_response_person.get(chat_id)
            if last is not None and last[0] == person_id and last[1]:
                single_gain = self.single_chat_gain * (2 * last[1] + 1)
                current_willing += single_gain
                if self.is_debug:
                    self.logger.debug(f"单聊增益:{single_gain}")

            fatigue = self.chat_fatigue_willing_attenuation.get(chat_id, 0)
            current_willing += fatigue
            if self.is_debug:
                self.logger.debug(f"疲劳衰减:{fatigue}")

            # Penalise by the number of messages currently being processed.
            chat_ongoing_messages = [msg for msg in self.ongoing_messages.values() if msg.chat_id == chat_id]
            chat_person_ongoing_messages = [msg for msg in chat_ongoing_messages if msg.person_id == person_id]
            if len(chat_person_ongoing_messages) >= 2:
                current_willing = 0
                if self.is_debug:
                    self.logger.debug("进行中消息惩罚:归0")
            elif len(chat_ongoing_messages) == 2:
                current_willing -= 0.5
                if self.is_debug:
                    self.logger.debug("进行中消息惩罚:-0.5")
            elif len(chat_ongoing_messages) == 3:
                current_willing -= 1.5
                if self.is_debug:
                    self.logger.debug("进行中消息惩罚:-1.5")
            elif len(chat_ongoing_messages) >= 4:
                current_willing = 0
                if self.is_debug:
                    self.logger.debug("进行中消息惩罚:归0")

            probability = self._willing_to_probability(current_willing)

            self.temporary_willing = current_willing

            return probability

    async def _return_to_basic_willing(self):
        """Every 3 seconds, pull each person's willingness towards the chat's base willingness."""
        while True:
            await asyncio.sleep(3)
            async with self.lock:
                for chat_id, person_willing in self.chat_person_reply_willing.items():
                    # Checked once per chat instead of once per person.
                    if chat_id not in self.chat_reply_willing:
                        self.logger.debug(f"聊天流{chat_id}不存在,错误")
                        continue
                    basic_willing = self.chat_reply_willing[chat_id]
                    # Only values of existing keys are replaced, which is safe
                    # while iterating over the dict.
                    for person_id, willing in person_willing.items():
                        person_willing[person_id] = (
                            basic_willing + (willing - basic_willing) * self.intention_decay_rate
                        )

    def setup(self, message: dict, chat_stream: ChatStream):
        """Register a message and make sure all per-chat state for its stream exists."""
        super().setup(message, chat_stream)
        stream_id = chat_stream.stream_id
        person_id = self.ongoing_messages[message.get("message_id", "")].person_id

        self.chat_reply_willing[stream_id] = self.chat_reply_willing.get(stream_id, self.basic_maximum_willing)
        # A person seen for the first time starts at the chat's base willingness.
        person_willing = self.chat_person_reply_willing.setdefault(stream_id, {})
        person_willing.setdefault(person_id, self.chat_reply_willing[stream_id])

        current_time = time.time()
        message_times = self.chat_new_message_time.setdefault(stream_id, [])
        message_times.append(current_time)
        if len(message_times) > self.number_of_message_storage:
            message_times.pop(0)

        if stream_id not in self.chat_fatigue_punishment_list:
            # A new chat starts with one punishment entry and a negative attenuation.
            self.chat_fatigue_punishment_list[stream_id] = [
                (
                    current_time,
                    self.number_of_message_storage * self.basic_maximum_willing / self.expected_replies_per_min * 60,
                )
            ]
            self.chat_fatigue_willing_attenuation[stream_id] = (
                -2 * self.basic_maximum_willing * self.fatigue_coefficient
            )

    @staticmethod
    def _willing_to_probability(willing: float) -> float:
        """Map a willingness value to a probability in [0, 1]; negative values count as 0."""
        willing = max(0, willing)
        if willing < 2:
            return math.atan(willing * 2) / math.pi * 2
        elif willing < 2.5:
            return math.atan(willing * 4) / math.pi * 2
        else:
            return 1

    async def _chat_new_message_to_change_basic_willing(self):
        """Periodically recompute each chat's base willingness from its recent message timestamps."""
        update_time = 20
        while True:
            await asyncio.sleep(update_time)
            async with self.lock:
                for chat_id, message_times in self.chat_new_message_time.items():
                    # Drop expired timestamps.
                    current_time = time.time()
                    window = (
                        self.number_of_message_storage
                        * self.basic_maximum_willing
                        / self.expected_replies_per_min
                        * 60
                    )
                    message_times = [msg_time for msg_time in message_times if current_time - msg_time < window]
                    self.chat_new_message_time[chat_id] = message_times

                    # NOTE(review): update_time is overwritten per chat, so the
                    # last chat iterated decides the next sleep interval;
                    # confirm whether the minimum across chats was intended.
                    if len(message_times) < self.number_of_message_storage:
                        self.chat_reply_willing[chat_id] = self.basic_maximum_willing
                        update_time = 20
                    elif len(message_times) == self.number_of_message_storage:
                        time_interval = current_time - message_times[0]
                        basic_willing = self._basic_willing_calculate(time_interval)
                        self.chat_reply_willing[chat_id] = basic_willing
                        update_time = 17 * basic_willing / self.basic_maximum_willing + 3
                    else:
                        self.logger.debug(f"聊天流{chat_id}消息时间数量异常,数量:{len(message_times)}")
                        self.chat_reply_willing[chat_id] = 0
                if self.is_debug:
                    self.logger.debug(f"聊天流意愿值更新:{self.chat_reply_willing}")

    def _basic_willing_calculate(self, t: float) -> float:
        """Compute the base willingness from the time span ``t`` covered by the stored timestamps."""
        return math.tan(t * self.expected_replies_per_min * math.pi / 120 / self.number_of_message_storage) / 2

    async def _fatigue_attenuation(self):
        """Every second, recompute each chat's fatigue attenuation from its active punishments."""
        while True:
            await asyncio.sleep(1)
            current_time = time.time()
            async with self.lock:
                for chat_id, fatigue_list in self.chat_fatigue_punishment_list.items():
                    active = [z for z in fatigue_list if current_time - z[0] < z[1]]
                    # Store the pruned list; otherwise expired punishments pile
                    # up forever. Replacing the value of an existing key is safe
                    # while iterating.
                    self.chat_fatigue_punishment_list[chat_id] = active
                    attenuation = 0
                    for start_time, duration in active:
                        # Clamp into asin's domain: time.time() is not
                        # monotonic, and a backwards clock step would otherwise
                        # raise ValueError and stop this task.
                        phase = max(-1.0, min(1.0, 2 * (current_time - start_time) / duration - 1))
                        attenuation += (
                            self.chat_reply_willing[chat_id] * 2 / math.pi * math.asin(phase)
                            - self.chat_reply_willing[chat_id]
                        ) * self.fatigue_coefficient
                    self.chat_fatigue_willing_attenuation[chat_id] = attenuation

    async def get_willing(self, chat_id):
        """Return the willingness computed by the latest ``get_reply_probability`` call (``chat_id`` is ignored)."""
        return self.temporary_willing
|
|
||||||
@@ -1,180 +0,0 @@
|
|||||||
import importlib
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import Dict, Optional, Any
|
|
||||||
from rich.traceback import install
|
|
||||||
from dataclasses import dataclass
|
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
from src.config.config import global_config
|
|
||||||
from src.chat.message_receive.chat_stream import ChatStream, GroupInfo
|
|
||||||
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
|
|
||||||
|
|
||||||
install(extra_lines=3)
|
|
||||||
|
|
||||||
"""
|
|
||||||
基类方法概览:
|
|
||||||
以下方法是你必须在子类重写的(哪怕什么都不干;其中 get_variable_parameters 与 set_variable_parameters 目前在基类中被注释掉,暂不强制):
|
|
||||||
async_task_starter 在程序启动时执行,在其中用asyncio.create_task启动你想要执行的异步任务
|
|
||||||
before_generate_reply_handle 确定要回复后,在生成回复前的处理
|
|
||||||
after_generate_reply_handle 确定要回复后,在生成回复后的处理
|
|
||||||
not_reply_handle 确定不回复后的处理
|
|
||||||
get_reply_probability 获取回复概率
|
|
||||||
get_variable_parameters 暂不确定
|
|
||||||
set_variable_parameters 暂不确定
|
|
||||||
以下2个方法根据你的实现可以做调整:
|
|
||||||
get_willing 获取某聊天流意愿
|
|
||||||
set_willing 设置某聊天流意愿
|
|
||||||
规范说明:
|
|
||||||
模块文件命名: `mode_{manager_type}.py`
|
|
||||||
示例: 若 `manager_type="aggressive"`,则模块文件应为 `mode_aggressive.py`
|
|
||||||
类命名: `{manager_type}WillingManager` (首字母大写)
|
|
||||||
示例: 在 `mode_aggressive.py` 中,类名应为 `AggressiveWillingManager`
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger("willing")
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
class WillingInfo:
    """Bundle of the parameters commonly used by the willingness module.

    Attributes:
        message (Dict[str, Any]): raw message data
        chat (ChatStream): chat stream object
        person_info_manager (PersonInfoManager): person-info manager object
        chat_id (str): identifier of the current chat stream
        person_id (str): identifier of the sender's person info
        group_info (Optional[GroupInfo]): group information, may be None
            (presumably None for private chats; confirm against the caller)
        is_mentioned_bot (bool): whether the bot was mentioned
        is_emoji (bool): whether the message is an emoji/sticker
        is_picid (bool): value of the message's ``is_picid`` flag
            (presumably marks a picture-ID message; confirm)
        interested_rate (float): interest rate of the message
    """

    message: Dict[str, Any]  # raw message data
    chat: ChatStream
    person_info_manager: PersonInfoManager
    chat_id: str
    person_id: str
    group_info: Optional[GroupInfo]
    is_mentioned_bot: bool
    is_emoji: bool
    is_picid: bool
    interested_rate: float
    # current_mood: float  (current mood? not implemented)
|
|
||||||
|
|
||||||
|
|
||||||
class BaseWillingManager(ABC):
    """Abstract base class for reply-willingness managers."""

    @classmethod
    def create(cls, manager_type: str) -> "BaseWillingManager":
        """Instantiate the manager for ``manager_type``.

        The class ``{Manager_type}WillingManager`` is looked up in the sibling
        module ``mode_{manager_type}``. If the import, the lookup, the subclass
        check or the construction raises ImportError, AttributeError or
        TypeError, the classical manager is created instead.
        """
        try:
            mode_module = importlib.import_module(f".mode_{manager_type}", __package__)
            candidate = getattr(mode_module, f"{manager_type.capitalize()}WillingManager")
            if issubclass(candidate, cls):
                logger.info(f"普通回复模式:{manager_type}")
                # Constructed inside the try on purpose: matching errors raised
                # by the constructor also trigger the fallback.
                return candidate()
            raise TypeError(f"Manager class {candidate.__name__} is not a subclass of {cls.__name__}")
        except (ImportError, AttributeError, TypeError) as e:
            fallback_module = importlib.import_module(".mode_classical", __package__)
            fallback_class = fallback_module.ClassicalWillingManager
            logger.info(f"载入当前意愿模式{manager_type}失败,使用经典配方~~~~")
            logger.debug(f"加载willing模式{manager_type}失败,原因: {str(e)}。")
            return fallback_class()

    def __init__(self):
        self.logger = logger
        self.lock = asyncio.Lock()
        # message_id -> info about messages currently being processed
        self.ongoing_messages: Dict[str, WillingInfo] = {}
        # chat_id -> reply willingness of the chat stream
        self.chat_reply_willing: Dict[str, float] = {}

    def setup(self, message: dict, chat: ChatStream):
        """Build a ``WillingInfo`` for ``message`` and register it under its message id."""
        sender_id = PersonInfoManager.get_person_id(chat.platform, chat.user_info.user_id)  # type: ignore
        info = WillingInfo(
            message=message,
            chat=chat,
            person_info_manager=get_person_info_manager(),
            chat_id=chat.stream_id,
            person_id=sender_id,
            group_info=chat.group_info,
            is_mentioned_bot=message.get("is_mentioned", False),
            is_emoji=message.get("is_emoji", False),
            is_picid=message.get("is_picid", False),
            # a missing or falsy interest value counts as 0.0
            interested_rate=message.get("interest_value") or 0.0,
        )
        self.ongoing_messages[message.get("message_id", "")] = info

    def delete(self, message_id: str):
        """Forget an ongoing message; only logs when the id is unknown."""
        removed = self.ongoing_messages.pop(message_id, None)
        if removed:
            return
        logger.debug(f"尝试删除不存在的消息 ID: {message_id},可能已被其他流程处理,喵~")

    @abstractmethod
    async def async_task_starter(self) -> None:
        """Abstract: start the manager's asynchronous background tasks."""

    @abstractmethod
    async def before_generate_reply_handle(self, message_id: str):
        """Abstract: hook run before a reply is generated."""

    @abstractmethod
    async def after_generate_reply_handle(self, message_id: str):
        """Abstract: hook run after a reply was generated."""

    @abstractmethod
    async def not_reply_handle(self, message_id: str):
        """Abstract: hook run when the bot does not reply."""

    @abstractmethod
    async def get_reply_probability(self, message_id: str):
        """Abstract: return the reply probability for a message."""
        raise NotImplementedError

    async def get_willing(self, chat_id: str):
        """Return the reply willingness of a chat stream (0 when unknown)."""
        async with self.lock:
            current = self.chat_reply_willing.get(chat_id, 0)
            return current

    async def set_willing(self, chat_id: str, willing: float):
        """Set the reply willingness of a chat stream."""
        async with self.lock:
            self.chat_reply_willing[chat_id] = willing

    # NOTE: the abstract hooks get_variable_parameters() and
    # set_variable_parameters(parameters) are currently disabled (they were
    # commented out), so subclasses are not required to implement them.
|
|
||||||
|
|
||||||
|
|
||||||
def init_willing_manager() -> BaseWillingManager:
    """
    Create the WillingManager instance selected by the configuration.

    The mode name is read from ``global_config.normal_chat.willing_mode`` and
    lower-cased before being handed to ``BaseWillingManager.create``.

    Returns:
        The WillingManager instance for the configured mode.
    """
    configured_mode = global_config.normal_chat.willing_mode
    return BaseWillingManager.create(configured_mode.lower())
|
|
||||||
|
|
||||||
|
|
||||||
# Module-wide willing_manager instance, created on first use
willing_manager = None


def get_willing_manager():
    """Return the shared willing manager, creating it on the first call."""
    global willing_manager
    if willing_manager is not None:
        return willing_manager
    willing_manager = init_willing_manager()
    return willing_manager
|
|
||||||
@@ -79,6 +79,8 @@ class LLMUsage(BaseModel):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
model_name = TextField(index=True) # 添加索引
|
model_name = TextField(index=True) # 添加索引
|
||||||
|
model_assign_name = TextField(null=True) # 添加索引
|
||||||
|
model_api_provider = TextField(null=True) # 添加索引
|
||||||
user_id = TextField(index=True) # 添加索引
|
user_id = TextField(index=True) # 添加索引
|
||||||
request_type = TextField(index=True) # 添加索引
|
request_type = TextField(index=True) # 添加索引
|
||||||
endpoint = TextField()
|
endpoint = TextField()
|
||||||
@@ -86,6 +88,7 @@ class LLMUsage(BaseModel):
|
|||||||
completion_tokens = IntegerField()
|
completion_tokens = IntegerField()
|
||||||
total_tokens = IntegerField()
|
total_tokens = IntegerField()
|
||||||
cost = DoubleField()
|
cost = DoubleField()
|
||||||
|
time_cost = DoubleField(null=True)
|
||||||
status = TextField()
|
status = TextField()
|
||||||
timestamp = DateTimeField(index=True) # 更改为 DateTimeField 并添加索引
|
timestamp = DateTimeField(index=True) # 更改为 DateTimeField 并添加索引
|
||||||
|
|
||||||
@@ -130,6 +133,9 @@ class Messages(BaseModel):
|
|||||||
reply_to = TextField(null=True)
|
reply_to = TextField(null=True)
|
||||||
|
|
||||||
interest_value = DoubleField(null=True)
|
interest_value = DoubleField(null=True)
|
||||||
|
key_words = TextField(null=True)
|
||||||
|
key_words_lite = TextField(null=True)
|
||||||
|
|
||||||
is_mentioned = BooleanField(null=True)
|
is_mentioned = BooleanField(null=True)
|
||||||
|
|
||||||
# 从 chat_info 扁平化而来的字段
|
# 从 chat_info 扁平化而来的字段
|
||||||
@@ -146,14 +152,13 @@ class Messages(BaseModel):
|
|||||||
chat_info_last_active_time = DoubleField()
|
chat_info_last_active_time = DoubleField()
|
||||||
|
|
||||||
# 从顶层 user_info 扁平化而来的字段 (消息发送者信息)
|
# 从顶层 user_info 扁平化而来的字段 (消息发送者信息)
|
||||||
user_platform = TextField()
|
user_platform = TextField(null=True)
|
||||||
user_id = TextField()
|
user_id = TextField(null=True)
|
||||||
user_nickname = TextField()
|
user_nickname = TextField(null=True)
|
||||||
user_cardname = TextField(null=True)
|
user_cardname = TextField(null=True)
|
||||||
|
|
||||||
processed_plain_text = TextField(null=True) # 处理后的纯文本消息
|
processed_plain_text = TextField(null=True) # 处理后的纯文本消息
|
||||||
display_message = TextField(null=True) # 显示的消息
|
display_message = TextField(null=True) # 显示的消息
|
||||||
memorized_times = IntegerField(default=0) # 被记忆的次数
|
|
||||||
|
|
||||||
priority_mode = TextField(null=True)
|
priority_mode = TextField(null=True)
|
||||||
priority_info = TextField(null=True)
|
priority_info = TextField(null=True)
|
||||||
@@ -162,6 +167,9 @@ class Messages(BaseModel):
|
|||||||
is_emoji = BooleanField(default=False)
|
is_emoji = BooleanField(default=False)
|
||||||
is_picid = BooleanField(default=False)
|
is_picid = BooleanField(default=False)
|
||||||
is_command = BooleanField(default=False)
|
is_command = BooleanField(default=False)
|
||||||
|
is_notify = BooleanField(default=False)
|
||||||
|
|
||||||
|
selected_expressions = TextField(null=True)
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
# database = db # 继承自 BaseModel
|
# database = db # 继承自 BaseModel
|
||||||
@@ -247,28 +255,60 @@ class PersonInfo(BaseModel):
|
|||||||
用于存储个人信息数据的模型。
|
用于存储个人信息数据的模型。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
is_known = BooleanField(default=False) # 是否已认识
|
||||||
person_id = TextField(unique=True, index=True) # 个人唯一ID
|
person_id = TextField(unique=True, index=True) # 个人唯一ID
|
||||||
person_name = TextField(null=True) # 个人名称 (允许为空)
|
person_name = TextField(null=True) # 个人名称 (允许为空)
|
||||||
name_reason = TextField(null=True) # 名称设定的原因
|
name_reason = TextField(null=True) # 名称设定的原因
|
||||||
platform = TextField() # 平台
|
platform = TextField() # 平台
|
||||||
user_id = TextField(index=True) # 用户ID
|
user_id = TextField(index=True) # 用户ID
|
||||||
nickname = TextField() # 用户昵称
|
nickname = TextField(null=True) # 用户昵称
|
||||||
impression = TextField(null=True) # 个人印象
|
memory_points = TextField(null=True) # 个人印象的点
|
||||||
short_impression = TextField(null=True) # 个人印象的简短描述
|
|
||||||
points = TextField(null=True) # 个人印象的点
|
|
||||||
forgotten_points = TextField(null=True) # 被遗忘的点
|
|
||||||
info_list = TextField(null=True) # 与Bot的互动
|
|
||||||
|
|
||||||
know_times = FloatField(null=True) # 认识时间 (时间戳)
|
know_times = FloatField(null=True) # 认识时间 (时间戳)
|
||||||
know_since = FloatField(null=True) # 首次印象总结时间
|
know_since = FloatField(null=True) # 首次印象总结时间
|
||||||
last_know = FloatField(null=True) # 最后一次印象总结时间
|
last_know = FloatField(null=True) # 最后一次印象总结时间
|
||||||
attitude = IntegerField(null=True, default=50) # 态度,0-100,从非常厌恶到十分喜欢
|
|
||||||
|
|
||||||
|
attitude_to_me = TextField(null=True) # 对bot的态度
|
||||||
|
attitude_to_me_confidence = FloatField(null=True) # 对bot的态度置信度
|
||||||
|
friendly_value = FloatField(null=True) # 对bot的友好程度
|
||||||
|
friendly_value_confidence = FloatField(null=True) # 对bot的友好程度置信度
|
||||||
|
rudeness = TextField(null=True) # 对bot的冒犯程度
|
||||||
|
rudeness_confidence = FloatField(null=True) # 对bot的冒犯程度置信度
|
||||||
|
neuroticism = TextField(null=True) # 对bot的神经质程度
|
||||||
|
neuroticism_confidence = FloatField(null=True) # 对bot的神经质程度置信度
|
||||||
|
conscientiousness = TextField(null=True) # 对bot的尽责程度
|
||||||
|
conscientiousness_confidence = FloatField(null=True) # 对bot的尽责程度置信度
|
||||||
|
likeness = TextField(null=True) # 对bot的相似程度
|
||||||
|
likeness_confidence = FloatField(null=True) # 对bot的相似程度置信度
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
# database = db # 继承自 BaseModel
|
# database = db # 继承自 BaseModel
|
||||||
table_name = "person_info"
|
table_name = "person_info"
|
||||||
|
|
||||||
|
|
||||||
|
class GroupInfo(BaseModel):
    """
    Model that stores group information.
    """

    group_id = TextField(unique=True, index=True)  # unique group ID
    group_name = TextField(null=True)  # group name (nullable)
    platform = TextField()  # platform
    group_impression = TextField(null=True)  # impression of the group
    member_list = TextField(null=True)  # list of group members (JSON format)
    topic = TextField(null=True)  # basic information about the group

    create_time = FloatField(null=True)  # creation time (timestamp)
    last_active = FloatField(null=True)  # last active time
    member_count = IntegerField(null=True, default=0)  # number of members

    class Meta:
        # database = db # inherited from BaseModel
        table_name = "group_info"
|
||||||
|
|
||||||
|
|
||||||
class Memory(BaseModel):
|
class Memory(BaseModel):
|
||||||
memory_id = TextField(index=True)
|
memory_id = TextField(index=True)
|
||||||
chat_id = TextField(null=True)
|
chat_id = TextField(null=True)
|
||||||
@@ -281,20 +321,6 @@ class Memory(BaseModel):
|
|||||||
table_name = "memory"
|
table_name = "memory"
|
||||||
|
|
||||||
|
|
||||||
class Knowledges(BaseModel):
    """
    Model that stores knowledge-base entries.
    """

    content = TextField()  # text of the knowledge entry
    embedding = TextField()  # embedding vector of the entry, stored as a JSON string holding a list of floats
    # Further metadata fields such as source, create_time, etc. could be added here

    class Meta:
        # database = db # inherited from BaseModel
        table_name = "knowledges"
|
|
||||||
|
|
||||||
|
|
||||||
class Expression(BaseModel):
|
class Expression(BaseModel):
|
||||||
"""
|
"""
|
||||||
用于存储表达风格的模型。
|
用于存储表达风格的模型。
|
||||||
@@ -311,31 +337,6 @@ class Expression(BaseModel):
|
|||||||
class Meta:
|
class Meta:
|
||||||
table_name = "expression"
|
table_name = "expression"
|
||||||
|
|
||||||
|
|
||||||
class ThinkingLog(BaseModel):
    """
    Model for the "thinking_logs" table: per-chat log entries holding the
    trigger/response text plus several JSON-serialised context fields.
    """

    chat_id = TextField(index=True)
    trigger_text = TextField(null=True)
    response_text = TextField(null=True)

    # Store complex dicts/lists as JSON strings
    trigger_info_json = TextField(null=True)
    response_info_json = TextField(null=True)
    timing_results_json = TextField(null=True)
    chat_history_json = TextField(null=True)
    chat_history_in_thinking_json = TextField(null=True)
    chat_history_after_response_json = TextField(null=True)
    heartflow_data_json = TextField(null=True)
    reasoning_data_json = TextField(null=True)

    # Add a timestamp for the log entry itself
    # Ensure you have: from peewee import DateTimeField
    # And: import datetime
    # The callable (not its result) is passed, so the default is evaluated per row.
    created_at = DateTimeField(default=datetime.datetime.now)

    class Meta:
        table_name = "thinking_logs"
|
|
||||||
|
|
||||||
|
|
||||||
class GraphNodes(BaseModel):
|
class GraphNodes(BaseModel):
|
||||||
"""
|
"""
|
||||||
用于存储记忆图节点的模型
|
用于存储记忆图节点的模型
|
||||||
@@ -343,6 +344,7 @@ class GraphNodes(BaseModel):
|
|||||||
|
|
||||||
concept = TextField(unique=True, index=True) # 节点概念
|
concept = TextField(unique=True, index=True) # 节点概念
|
||||||
memory_items = TextField() # JSON格式存储的记忆列表
|
memory_items = TextField() # JSON格式存储的记忆列表
|
||||||
|
weight = FloatField(default=0.0) # 节点权重
|
||||||
hash = TextField() # 节点哈希值
|
hash = TextField() # 节点哈希值
|
||||||
created_time = FloatField() # 创建时间戳
|
created_time = FloatField() # 创建时间戳
|
||||||
last_modified = FloatField() # 最后修改时间戳
|
last_modified = FloatField() # 最后修改时间戳
|
||||||
@@ -382,9 +384,7 @@ def create_tables():
|
|||||||
ImageDescriptions,
|
ImageDescriptions,
|
||||||
OnlineTime,
|
OnlineTime,
|
||||||
PersonInfo,
|
PersonInfo,
|
||||||
Knowledges,
|
|
||||||
Expression,
|
Expression,
|
||||||
ThinkingLog,
|
|
||||||
GraphNodes, # 添加图节点表
|
GraphNodes, # 添加图节点表
|
||||||
GraphEdges, # 添加图边表
|
GraphEdges, # 添加图边表
|
||||||
Memory,
|
Memory,
|
||||||
@@ -393,10 +393,14 @@ def create_tables():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def initialize_database():
|
def initialize_database(sync_constraints=False):
|
||||||
"""
|
"""
|
||||||
检查所有定义的表是否存在,如果不存在则创建它们。
|
检查所有定义的表是否存在,如果不存在则创建它们。
|
||||||
检查所有表的所有字段是否存在,如果缺失则自动添加。
|
检查所有表的所有字段是否存在,如果缺失则自动添加。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sync_constraints (bool): 是否同步字段约束。默认为 False。
|
||||||
|
如果为 True,会检查并修复字段的 NULL 约束不一致问题。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
models = [
|
models = [
|
||||||
@@ -408,10 +412,8 @@ def initialize_database():
|
|||||||
ImageDescriptions,
|
ImageDescriptions,
|
||||||
OnlineTime,
|
OnlineTime,
|
||||||
PersonInfo,
|
PersonInfo,
|
||||||
Knowledges,
|
|
||||||
Expression,
|
Expression,
|
||||||
Memory,
|
Memory,
|
||||||
ThinkingLog,
|
|
||||||
GraphNodes,
|
GraphNodes,
|
||||||
GraphEdges,
|
GraphEdges,
|
||||||
ActionRecords, # 添加 ActionRecords 到初始化列表
|
ActionRecords, # 添加 ActionRecords 到初始化列表
|
||||||
@@ -478,6 +480,13 @@ def initialize_database():
|
|||||||
logger.info(f"字段 '{field_name}' 删除成功")
|
logger.info(f"字段 '{field_name}' 删除成功")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"删除字段 '{field_name}' 失败: {e}")
|
logger.error(f"删除字段 '{field_name}' 失败: {e}")
|
||||||
|
|
||||||
|
# 如果启用了约束同步,执行约束检查和修复
|
||||||
|
if sync_constraints:
|
||||||
|
logger.debug("开始同步数据库字段约束...")
|
||||||
|
sync_field_constraints()
|
||||||
|
logger.debug("数据库字段约束同步完成")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception(f"检查表或字段是否存在时出错: {e}")
|
logger.exception(f"检查表或字段是否存在时出错: {e}")
|
||||||
# 如果检查失败(例如数据库不可用),则退出
|
# 如果检查失败(例如数据库不可用),则退出
|
||||||
@@ -486,5 +495,261 @@ def initialize_database():
|
|||||||
logger.info("数据库初始化完成")
|
logger.info("数据库初始化完成")
|
||||||
|
|
||||||
|
|
||||||
|
def sync_field_constraints():
    """
    Bring the NULL constraints of existing database columns in line with the models.

    For every known model whose table exists, the NOT NULL flag reported by
    ``PRAGMA table_info`` is compared with the model field's ``null`` setting.
    Mismatching fields are collected and passed to ``_fix_table_constraints``.
    Any exception is logged and swallowed.
    """

    models = [
        ChatStreams,
        LLMUsage,
        Emoji,
        Messages,
        Images,
        ImageDescriptions,
        OnlineTime,
        PersonInfo,
        Expression,
        Memory,
        GraphNodes,
        GraphEdges,
        ActionRecords,
    ]

    try:
        with db:
            for model in models:
                table_name = model._meta.table_name
                if not db.table_exists(model):
                    logger.warning(f"表 '{table_name}' 不存在,跳过约束检查")
                    continue

                logger.debug(f"检查表 '{table_name}' 的字段约束...")

                # PRAGMA table_info rows: index 1 is the column name, index 3 the NOT NULL flag.
                rows = db.execute_sql(f"PRAGMA table_info('{table_name}')").fetchall()
                not_null_by_column = {row[1]: bool(row[3]) for row in rows}

                pending = []
                for field_name, field_obj in model._meta.fields.items():
                    if field_name not in not_null_by_column:
                        continue  # column missing in the database, nothing to compare

                    db_not_null = not_null_by_column[field_name]
                    # Consistent when exactly one of "model allows NULL" / "column is NOT NULL" holds.
                    if bool(field_obj.null) != db_not_null:
                        continue

                    if db_not_null:
                        # Model allows NULL but the column is NOT NULL.
                        action, current_c, target_c = 'allow_null', 'NOT NULL', 'NULL'
                        detail = "模型允许NULL,但数据库为NOT NULL"
                    else:
                        # Model forbids NULL but the column is nullable (handle with care).
                        action, current_c, target_c = 'disallow_null', 'NULL', 'NOT NULL'
                        detail = "模型不允许NULL,但数据库允许NULL"

                    pending.append({
                        'field_name': field_name,
                        'field_obj': field_obj,
                        'action': action,
                        'current_constraint': current_c,
                        'target_constraint': target_c,
                    })
                    logger.warning(f"字段 '{field_name}' 约束不一致: {detail}")

                if not pending:
                    logger.debug(f"表 '{table_name}' 的字段约束已同步")
                    continue

                logger.info(f"表 '{table_name}' 需要修复 {len(pending)} 个字段约束")
                _fix_table_constraints(table_name, model, pending)

    except Exception as e:
        logger.exception(f"同步字段约束时出错: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def _quote_identifier(identifier):
    """Return *identifier* double-quoted for SQL, with embedded quotes doubled."""
    return '"' + str(identifier).replace('"', '""') + '"'


def _null_replacement_sql(field_obj):
    """Return the SQL literal that replaces NULL when a column becomes NOT NULL."""
    if isinstance(field_obj, (IntegerField, FloatField, DoubleField, BooleanField)):
        return "0"
    if isinstance(field_obj, DateTimeField):
        return f"'{datetime.datetime.now()}'"
    # TextField and every other field type fall back to the empty string.
    return "''"


def _fix_table_constraints(table_name, model, constraints_to_fix):
    """Rebuild a table so its NULL / NOT NULL constraints match the model.

    SQLite cannot change a column constraint in place, so the table is copied
    into a timestamped backup table, dropped, recreated from the model
    definition and refilled from the backup.

    Args:
        table_name: Name of the table to rebuild.
        model: Model class describing the target schema; it is passed to
            ``db.create_tables`` and its ``_meta.fields`` drive the copy.
        constraints_to_fix: List of dicts with the keys ``field_name``,
            ``action`` (``'allow_null'`` or ``'disallow_null'``),
            ``current_constraint`` and ``target_constraint``.

    Returns:
        None. Errors are logged and never propagated to the caller.
    """
    # Bound before the try so the exception handler can always reference it.
    backup_table = None
    try:
        backup_table = f"{table_name}_backup_{int(datetime.datetime.now().timestamp())}"
        quoted_table = _quote_identifier(table_name)
        quoted_backup = _quote_identifier(backup_table)

        logger.info(f"开始修复表 '{table_name}' 的字段约束...")

        # Run the whole rebuild in one transaction: if any step fails, the
        # original table (with its keys and indexes) is left untouched instead
        # of being replaced by a constraint-less copy.
        with db.atomic():
            # 1. Create the backup table.
            db.execute_sql(f"CREATE TABLE {quoted_backup} AS SELECT * FROM {quoted_table}")
            logger.info(f"已创建备份表 '{backup_table}'")

            # 2. Drop the original table.
            db.execute_sql(f"DROP TABLE {quoted_table}")
            logger.info(f"已删除原表 '{table_name}'")

            # 3. Recreate the table from the current model definition.
            db.create_tables([model])
            logger.info(f"已重新创建表 '{table_name}' 使用新的约束")

            # 4. Copy the rows back. Columns moving from NULL to NOT NULL get a
            #    type-appropriate default in place of existing NULL values.
            null_to_notnull_fields = [
                constraint['field_name']
                for constraint in constraints_to_fix
                if constraint['action'] == 'disallow_null'
            ]
            if null_to_notnull_fields:
                logger.warning(f"字段 {null_to_notnull_fields} 将从允许NULL改为不允许NULL,需要处理现有的NULL值")
            needs_default = set(null_to_notnull_fields)

            target_columns = []
            select_exprs = []
            for field_name, field_obj in model._meta.fields.items():
                # Use the database column name, which can differ from the
                # model attribute name (custom column_name, foreign keys).
                column = _quote_identifier(field_obj.column_name)
                target_columns.append(column)
                if field_name in needs_default:
                    select_exprs.append(f"COALESCE({column}, {_null_replacement_sql(field_obj)})")
                else:
                    select_exprs.append(column)

            db.execute_sql(
                f"INSERT INTO {quoted_table} ({', '.join(target_columns)}) "
                f"SELECT {', '.join(select_exprs)} FROM {quoted_backup}"
            )
            logger.info(f"已从备份表恢复数据到 '{table_name}'")

            # 5. Verify that no rows were lost.
            original_count = db.execute_sql(f"SELECT COUNT(*) FROM {quoted_backup}").fetchone()[0]
            new_count = db.execute_sql(f"SELECT COUNT(*) FROM {quoted_table}").fetchone()[0]

            if original_count == new_count:
                logger.info(f"数据完整性验证通过: {original_count} 行数据")
                db.execute_sql(f"DROP TABLE {quoted_backup}")
                logger.info(f"已删除备份表 '{backup_table}'")
            else:
                # Keep the backup so the data can be inspected manually.
                logger.error(f"数据完整性验证失败: 原始 {original_count} 行,新表 {new_count} 行")
                logger.error(f"备份表 '{backup_table}' 已保留,请手动检查")

        # Record which constraints were changed.
        for constraint in constraints_to_fix:
            logger.info(f"已修复字段 '{constraint['field_name']}': "
                        f"{constraint['current_constraint']} -> {constraint['target_constraint']}")

    except Exception as e:
        logger.exception(f"修复表 '{table_name}' 约束时出错: {e}")
        # Last-resort recovery. After a rolled-back transaction the backup
        # table no longer exists and the original table is intact, so this
        # only runs if a backup somehow survived the failure.
        try:
            if backup_table and db.table_exists(backup_table):
                logger.info(f"尝试从备份表 '{backup_table}' 恢复...")
                db.execute_sql(f"DROP TABLE IF EXISTS {_quote_identifier(table_name)}")
                db.execute_sql(
                    f"ALTER TABLE {_quote_identifier(backup_table)} "
                    f"RENAME TO {_quote_identifier(table_name)}"
                )
                logger.info(f"已从备份恢复表 '{table_name}'")
        except Exception as restore_error:
            logger.exception(f"恢复表失败: {restore_error}")
|
||||||
|
|
||||||
|
|
||||||
|
def check_field_constraints():
    """Report NULL / NOT NULL mismatches between models and tables without fixing them.

    Intended as a preview of what a constraint repair would change.

    Returns:
        dict: Maps table name to a list of mismatch dicts with the keys
        ``field_name``, ``issue``, ``model_constraint``, ``db_constraint`` and
        ``recommended_action``. Tables without mismatches are omitted. If an
        error occurs it is logged and whatever was collected so far is returned.
    """
    tracked_models = (
        ChatStreams,
        LLMUsage,
        Emoji,
        Messages,
        Images,
        ImageDescriptions,
        OnlineTime,
        PersonInfo,
        Expression,
        Memory,
        GraphNodes,
        GraphEdges,
        ActionRecords,
    )

    report = {}

    try:
        with db:
            for model_cls in tracked_models:
                name = model_cls._meta.table_name
                if not db.table_exists(model_cls):
                    continue

                # PRAGMA table_info rows: index 1 is the column name, index 3 the NOT NULL flag.
                rows = db.execute_sql(f"PRAGMA table_info('{name}')").fetchall()
                column_not_null = {row[1]: bool(row[3]) for row in rows}

                findings = []
                for field_name, field_obj in model_cls._meta.fields.items():
                    if field_name not in column_not_null:
                        continue

                    db_not_null = column_not_null[field_name]
                    model_nullable = field_obj.null

                    if model_nullable and db_not_null:
                        findings.append({
                            'field_name': field_name,
                            'issue': 'model_allows_null_but_db_not_null',
                            'model_constraint': 'NULL',
                            'db_constraint': 'NOT NULL',
                            'recommended_action': 'allow_null',
                        })
                    elif not (model_nullable or db_not_null):
                        findings.append({
                            'field_name': field_name,
                            'issue': 'model_not_null_but_db_allows_null',
                            'model_constraint': 'NOT NULL',
                            'db_constraint': 'NULL',
                            'recommended_action': 'disallow_null',
                        })

                if findings:
                    report[name] = findings

    except Exception as e:
        logger.exception(f"检查字段约束时出错: {e}")

    return report
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# 模块加载时调用初始化函数
|
# 模块加载时调用初始化函数
|
||||||
initialize_database()
|
initialize_database(sync_constraints=True)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import json
|
|||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
import structlog
|
import structlog
|
||||||
import toml
|
import tomlkit
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable, Optional
|
from typing import Callable, Optional
|
||||||
@@ -188,24 +188,35 @@ def load_log_config(): # sourcery skip: use-contextlib-suppress
|
|||||||
"""从配置文件加载日志设置"""
|
"""从配置文件加载日志设置"""
|
||||||
config_path = Path("config/bot_config.toml")
|
config_path = Path("config/bot_config.toml")
|
||||||
default_config = {
|
default_config = {
|
||||||
"date_style": "Y-m-d H:i:s",
|
"date_style": "m-d H:i:s",
|
||||||
"log_level_style": "lite",
|
"log_level_style": "lite",
|
||||||
"color_text": "title",
|
"color_text": "full",
|
||||||
"log_level": "INFO", # 全局日志级别(向下兼容)
|
"log_level": "INFO", # 全局日志级别(向下兼容)
|
||||||
"console_log_level": "INFO", # 控制台日志级别
|
"console_log_level": "INFO", # 控制台日志级别
|
||||||
"file_log_level": "DEBUG", # 文件日志级别
|
"file_log_level": "DEBUG", # 文件日志级别
|
||||||
"suppress_libraries": [],
|
"suppress_libraries": [
|
||||||
"library_log_levels": {},
|
"faiss",
|
||||||
|
"httpx",
|
||||||
|
"urllib3",
|
||||||
|
"asyncio",
|
||||||
|
"websockets",
|
||||||
|
"httpcore",
|
||||||
|
"requests",
|
||||||
|
"peewee",
|
||||||
|
"openai",
|
||||||
|
"uvicorn",
|
||||||
|
"jieba",
|
||||||
|
],
|
||||||
|
"library_log_levels": {"aiohttp": "WARNING"},
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if config_path.exists():
|
if config_path.exists():
|
||||||
with open(config_path, "r", encoding="utf-8") as f:
|
with open(config_path, "r", encoding="utf-8") as f:
|
||||||
config = toml.load(f)
|
config = tomlkit.load(f)
|
||||||
return config.get("log", default_config)
|
return config.get("log", default_config)
|
||||||
except Exception:
|
except Exception as e:
|
||||||
pass
|
print(f"[日志系统] 加载日志配置失败: {e}")
|
||||||
|
|
||||||
return default_config
|
return default_config
|
||||||
|
|
||||||
|
|
||||||
@@ -334,7 +345,7 @@ MODULE_COLORS = {
|
|||||||
"llm_models": "\033[36m", # 青色
|
"llm_models": "\033[36m", # 青色
|
||||||
"remote": "\033[38;5;242m", # 深灰色,更不显眼
|
"remote": "\033[38;5;242m", # 深灰色,更不显眼
|
||||||
"planner": "\033[36m",
|
"planner": "\033[36m",
|
||||||
"memory": "\033[34m",
|
"memory": "\033[38;5;117m", # 天蓝色
|
||||||
"hfc": "\033[38;5;81m", # 稍微暗一些的青色,保持可读
|
"hfc": "\033[38;5;81m", # 稍微暗一些的青色,保持可读
|
||||||
"action_manager": "\033[38;5;208m", # 橙色,不与replyer重复
|
"action_manager": "\033[38;5;208m", # 橙色,不与replyer重复
|
||||||
# 关系系统
|
# 关系系统
|
||||||
@@ -352,7 +363,7 @@ MODULE_COLORS = {
|
|||||||
"expressor": "\033[38;5;166m", # 橙色
|
"expressor": "\033[38;5;166m", # 橙色
|
||||||
# 专注聊天模块
|
# 专注聊天模块
|
||||||
"replyer": "\033[38;5;166m", # 橙色
|
"replyer": "\033[38;5;166m", # 橙色
|
||||||
"memory_activator": "\033[34m", # 绿色
|
"memory_activator": "\033[38;5;117m", # 天蓝色
|
||||||
# 插件系统
|
# 插件系统
|
||||||
"plugins": "\033[31m", # 红色
|
"plugins": "\033[31m", # 红色
|
||||||
"plugin_api": "\033[33m", # 黄色
|
"plugin_api": "\033[33m", # 黄色
|
||||||
@@ -390,7 +401,7 @@ MODULE_COLORS = {
|
|||||||
"tts_action": "\033[38;5;58m", # 深黄色
|
"tts_action": "\033[38;5;58m", # 深黄色
|
||||||
"doubao_pic_plugin": "\033[38;5;64m", # 深绿色
|
"doubao_pic_plugin": "\033[38;5;64m", # 深绿色
|
||||||
# Action组件
|
# Action组件
|
||||||
"no_reply_action": "\033[38;5;214m", # 亮橙色,显眼但不像警告
|
"no_action_action": "\033[38;5;214m", # 亮橙色,显眼但不像警告
|
||||||
"reply_action": "\033[38;5;46m", # 亮绿色
|
"reply_action": "\033[38;5;46m", # 亮绿色
|
||||||
"base_action": "\033[38;5;250m", # 浅灰色
|
"base_action": "\033[38;5;250m", # 浅灰色
|
||||||
# 数据库和消息
|
# 数据库和消息
|
||||||
@@ -403,8 +414,7 @@ MODULE_COLORS = {
|
|||||||
"model_utils": "\033[38;5;164m", # 紫红色
|
"model_utils": "\033[38;5;164m", # 紫红色
|
||||||
"relationship_fetcher": "\033[38;5;170m", # 浅紫色
|
"relationship_fetcher": "\033[38;5;170m", # 浅紫色
|
||||||
"relationship_builder": "\033[38;5;93m", # 浅蓝色
|
"relationship_builder": "\033[38;5;93m", # 浅蓝色
|
||||||
|
# s4u
|
||||||
#s4u
|
|
||||||
"context_web_api": "\033[38;5;240m", # 深灰色
|
"context_web_api": "\033[38;5;240m", # 深灰色
|
||||||
"S4U_chat": "\033[92m", # 深灰色
|
"S4U_chat": "\033[92m", # 深灰色
|
||||||
}
|
}
|
||||||
@@ -414,7 +424,7 @@ MODULE_ALIASES = {
|
|||||||
# 示例映射
|
# 示例映射
|
||||||
"individuality": "人格特质",
|
"individuality": "人格特质",
|
||||||
"emoji": "表情包",
|
"emoji": "表情包",
|
||||||
"no_reply_action": "摸鱼",
|
"no_action_action": "摸鱼",
|
||||||
"reply_action": "回复",
|
"reply_action": "回复",
|
||||||
"action_manager": "动作",
|
"action_manager": "动作",
|
||||||
"memory_activator": "记忆",
|
"memory_activator": "记忆",
|
||||||
@@ -440,6 +450,37 @@ MODULE_ALIASES = {
|
|||||||
RESET_COLOR = "\033[0m"
|
RESET_COLOR = "\033[0m"
|
||||||
|
|
||||||
|
|
||||||
|
def convert_pathname_to_module(logger, method_name, event_dict):
    # sourcery skip: extract-method, use-string-remove-affix
    """Replace the ``pathname`` entry of a log event with a dotted module path.

    The path is made relative to the directory three levels above this file,
    path separators become dots and a trailing ``.py`` is dropped; the result
    is stored under ``module``. If that conversion fails, ``pathname`` is still
    removed and ``module`` is set to the file stem only when no ``module``
    entry exists yet. Events without ``pathname`` are returned untouched.

    Args:
        logger: Unused.
        method_name: Unused.
        event_dict: Event dictionary; modified in place.

    Returns:
        The same ``event_dict`` object.
    """
    if "pathname" not in event_dict:
        return event_dict

    # The raw path is removed from the event on every path through this function.
    raw_path = event_dict.pop("pathname")
    try:
        root_dir = Path(__file__).resolve().parent.parent.parent
        relative = Path(raw_path).resolve().relative_to(root_dir)

        dotted = str(relative).replace("\\", ".").replace("/", ".")
        event_dict["module"] = dotted[:-3] if dotted.endswith(".py") else dotted
    except Exception:
        # Conversion failed: keep an existing module value, else use the file stem.
        if "module" not in event_dict:
            event_dict["module"] = Path(raw_path).stem

    return event_dict
|
||||||
|
|
||||||
|
|
||||||
class ModuleColoredConsoleRenderer:
|
class ModuleColoredConsoleRenderer:
|
||||||
"""自定义控制台渲染器,为不同模块提供不同颜色"""
|
"""自定义控制台渲染器,为不同模块提供不同颜色"""
|
||||||
|
|
||||||
@@ -451,7 +492,7 @@ class ModuleColoredConsoleRenderer:
|
|||||||
# 日志级别颜色
|
# 日志级别颜色
|
||||||
self._level_colors = {
|
self._level_colors = {
|
||||||
"debug": "\033[38;5;208m", # 橙色
|
"debug": "\033[38;5;208m", # 橙色
|
||||||
"info": "\033[34m", # 蓝色
|
"info": "\033[38;5;117m", # 天蓝色
|
||||||
"success": "\033[32m", # 绿色
|
"success": "\033[32m", # 绿色
|
||||||
"warning": "\033[33m", # 黄色
|
"warning": "\033[33m", # 黄色
|
||||||
"error": "\033[31m", # 红色
|
"error": "\033[31m", # 红色
|
||||||
@@ -529,7 +570,7 @@ class ModuleColoredConsoleRenderer:
|
|||||||
if logger_name:
|
if logger_name:
|
||||||
# 获取别名,如果没有别名则使用原名称
|
# 获取别名,如果没有别名则使用原名称
|
||||||
display_name = MODULE_ALIASES.get(logger_name, logger_name)
|
display_name = MODULE_ALIASES.get(logger_name, logger_name)
|
||||||
|
|
||||||
if self._colors and self._enable_module_colors:
|
if self._colors and self._enable_module_colors:
|
||||||
if module_color:
|
if module_color:
|
||||||
module_part = f"{module_color}[{display_name}]{RESET_COLOR}"
|
module_part = f"{module_color}[{display_name}]{RESET_COLOR}"
|
||||||
@@ -562,7 +603,7 @@ class ModuleColoredConsoleRenderer:
|
|||||||
# 处理其他字段
|
# 处理其他字段
|
||||||
extras = []
|
extras = []
|
||||||
for key, value in event_dict.items():
|
for key, value in event_dict.items():
|
||||||
if key not in ("timestamp", "level", "logger_name", "event"):
|
if key not in ("timestamp", "level", "logger_name", "event", "module", "lineno", "pathname"):
|
||||||
# 确保值也转换为字符串
|
# 确保值也转换为字符串
|
||||||
if isinstance(value, (dict, list)):
|
if isinstance(value, (dict, list)):
|
||||||
try:
|
try:
|
||||||
@@ -603,6 +644,13 @@ def configure_structlog():
|
|||||||
processors=[
|
processors=[
|
||||||
structlog.contextvars.merge_contextvars,
|
structlog.contextvars.merge_contextvars,
|
||||||
structlog.processors.add_log_level,
|
structlog.processors.add_log_level,
|
||||||
|
structlog.processors.CallsiteParameterAdder(
|
||||||
|
parameters=[
|
||||||
|
structlog.processors.CallsiteParameter.MODULE,
|
||||||
|
structlog.processors.CallsiteParameter.LINENO,
|
||||||
|
]
|
||||||
|
),
|
||||||
|
convert_pathname_to_module,
|
||||||
structlog.processors.StackInfoRenderer(),
|
structlog.processors.StackInfoRenderer(),
|
||||||
structlog.dev.set_exc_info,
|
structlog.dev.set_exc_info,
|
||||||
structlog.processors.TimeStamper(fmt=get_timestamp_format(), utc=False),
|
structlog.processors.TimeStamper(fmt=get_timestamp_format(), utc=False),
|
||||||
@@ -627,6 +675,10 @@ file_formatter = structlog.stdlib.ProcessorFormatter(
|
|||||||
structlog.stdlib.add_log_level,
|
structlog.stdlib.add_log_level,
|
||||||
structlog.stdlib.PositionalArgumentsFormatter(),
|
structlog.stdlib.PositionalArgumentsFormatter(),
|
||||||
structlog.processors.TimeStamper(fmt="iso"),
|
structlog.processors.TimeStamper(fmt="iso"),
|
||||||
|
structlog.processors.CallsiteParameterAdder(
|
||||||
|
parameters=[structlog.processors.CallsiteParameter.MODULE, structlog.processors.CallsiteParameter.LINENO]
|
||||||
|
),
|
||||||
|
convert_pathname_to_module,
|
||||||
structlog.processors.StackInfoRenderer(),
|
structlog.processors.StackInfoRenderer(),
|
||||||
structlog.processors.format_exc_info,
|
structlog.processors.format_exc_info,
|
||||||
],
|
],
|
||||||
@@ -706,181 +758,6 @@ def get_logger(name: Optional[str]) -> structlog.stdlib.BoundLogger:
|
|||||||
return logger
|
return logger
|
||||||
|
|
||||||
|
|
||||||
def configure_logging(
    level: str = "INFO",
    console_level: Optional[str] = None,
    file_level: Optional[str] = None,
    max_bytes: int = 5 * 1024 * 1024,
    backup_count: int = 30,
    log_dir: str = "logs",
):
    """Adjust logging parameters at runtime.

    Args:
        level: Fallback level name used where a specific level is not given.
        console_level: Level name for the console handler, if it should change.
        file_level: Level name for the file handler, if it should change.
        max_bytes: Value assigned to the file handler's ``max_bytes``.
        backup_count: Value assigned to the file handler's ``backup_count``.
        log_dir: Log directory; created if it does not exist.
    """

    def to_levelno(level_name):
        # Unknown level names fall back to INFO.
        return getattr(logging, level_name.upper(), logging.INFO)

    Path(log_dir).mkdir(exist_ok=True)

    # Update the file handler's rotation settings and, optionally, its level.
    # NOTE(review): the level update is assumed to be nested under the handler check — confirm.
    rotating_handler = get_file_handler()
    if rotating_handler and isinstance(rotating_handler, TimestampedFileHandler):
        rotating_handler.max_bytes = max_bytes
        rotating_handler.backup_count = backup_count
        rotating_handler.log_dir = Path(log_dir)

        if file_level:
            rotating_handler.setLevel(to_levelno(file_level))

    # Update the console handler level.
    stream_handler = get_console_handler()
    if stream_handler and console_level:
        stream_handler.setLevel(to_levelno(console_level))

    root = logging.getLogger()
    if not (console_level or file_level):
        # No handler-specific level given: use the global one (no INFO fallback here).
        root.setLevel(getattr(logging, level.upper()))
        return

    # The root logger must let through the most verbose of the two handler levels.
    console_no = to_levelno(console_level or level)
    file_no = to_levelno(file_level or level)
    root.setLevel(min(console_no, file_no))
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def reload_log_config():
    """Reload the log configuration and re-apply it to the active handlers."""
    global LOG_CONFIG
    LOG_CONFIG = load_log_config()

    def build_console_formatter():
        # A fresh formatter is built for every handler it is attached to.
        return structlog.stdlib.ProcessorFormatter(
            processor=ModuleColoredConsoleRenderer(colors=True),
            foreign_pre_chain=[
                structlog.stdlib.add_logger_name,
                structlog.stdlib.add_log_level,
                structlog.stdlib.PositionalArgumentsFormatter(),
                structlog.processors.TimeStamper(fmt=get_timestamp_format(), utc=False),
                structlog.processors.StackInfoRenderer(),
                structlog.processors.format_exc_info,
            ],
        )

    file_handler = get_file_handler()
    if file_handler:
        file_level = LOG_CONFIG.get("file_log_level", LOG_CONFIG.get("log_level", "INFO"))
        file_handler.setLevel(getattr(logging, file_level.upper(), logging.INFO))

    console_handler = get_console_handler()
    if console_handler:
        console_level = LOG_CONFIG.get("console_log_level", LOG_CONFIG.get("log_level", "INFO"))
        console_handler.setLevel(getattr(logging, console_level.upper(), logging.INFO))

    # Re-attach the console renderer to every stream handler on the root logger.
    # NOTE(review): any handler deriving from logging.StreamHandler matches here,
    # not only the console handler — confirm that file handlers are not affected.
    for attached in logging.getLogger().handlers:
        if isinstance(attached, logging.StreamHandler):
            attached.setFormatter(build_console_formatter())

    # Re-apply third-party library logging settings.
    configure_third_party_loggers()

    # Reconfigure loggers that already exist.
    reconfigure_existing_loggers()
|
|
||||||
|
|
||||||
|
|
||||||
def get_log_config():
    """Return a copy of the current log configuration (via its ``copy()`` method)."""
    snapshot = LOG_CONFIG.copy()
    return snapshot
|
|
||||||
|
|
||||||
|
|
||||||
def set_console_log_level(level: str):
    """Set the console log level.

    Args:
        level: Log level name ("DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL").
            Unknown names make the handler fall back to INFO.
    """
    level_name = level.upper()
    LOG_CONFIG["console_log_level"] = level_name

    handler = get_console_handler()
    if handler:
        handler.setLevel(getattr(logging, level_name, logging.INFO))

    # Re-apply third-party logger configuration after the change.
    configure_third_party_loggers()

    get_logger("logger").info(f"控制台日志级别已设置为: {level_name}")
|
|
||||||
|
|
||||||
|
|
||||||
def set_file_log_level(level: str):
    """Set the file log level.

    Args:
        level: Log level name ("DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL").
            Unknown names make the handler fall back to INFO.
    """
    level_name = level.upper()
    LOG_CONFIG["file_log_level"] = level_name

    handler = get_file_handler()
    if handler:
        handler.setLevel(getattr(logging, level_name, logging.INFO))

    # Re-apply third-party logger configuration after the change.
    configure_third_party_loggers()

    get_logger("logger").info(f"文件日志级别已设置为: {level_name}")
|
|
||||||
|
|
||||||
|
|
||||||
def get_current_log_levels():
    """Return the level names currently set on the console handler, file handler and root logger.

    Returns:
        dict: Keys ``console_level``, ``file_level`` and ``root_level``. A
        handler that is not available is reported as ``"UNKNOWN"``.
    """

    def describe(handler):
        return logging.getLevelName(handler.level) if handler else "UNKNOWN"

    file_handler = get_file_handler()
    console_handler = get_console_handler()

    file_name = describe(file_handler)
    console_name = describe(console_handler)
    root_name = logging.getLevelName(logging.getLogger().level)

    return {
        "console_level": console_name,
        "file_level": file_name,
        "root_level": root_name,
    }
|
|
||||||
|
|
||||||
|
|
||||||
def force_reset_all_loggers():
    """Tear down and rebuild the logging setup so every logger shares one format."""
    # Close the existing handlers first.
    close_handlers()

    # Forget every logger registered with the logging manager.
    logging.getLogger().manager.loggerDict.clear()

    root = logging.getLogger()
    root.handlers.clear()

    # Both getters are fetched before anything is attached.
    file_handler = get_file_handler()
    console_handler = get_console_handler()

    for handler, formatter in (
        (file_handler, file_formatter),
        (console_handler, console_formatter),
    ):
        root.addHandler(handler)
        handler.setFormatter(formatter)

    # The root level is the more verbose of the two configured handler levels.
    fallback = LOG_CONFIG.get("log_level", "INFO")
    console_name = LOG_CONFIG.get("console_log_level", fallback)
    file_name = LOG_CONFIG.get("file_log_level", fallback)

    root.setLevel(
        min(
            getattr(logging, console_name.upper(), logging.INFO),
            getattr(logging, file_name.upper(), logging.INFO),
        )
    )
|
|
||||||
|
|
||||||
|
|
||||||
def initialize_logging():
|
def initialize_logging():
|
||||||
"""手动初始化日志系统,确保所有logger都使用正确的配置
|
"""手动初始化日志系统,确保所有logger都使用正确的配置
|
||||||
|
|
||||||
@@ -888,6 +765,7 @@ def initialize_logging():
|
|||||||
"""
|
"""
|
||||||
global LOG_CONFIG
|
global LOG_CONFIG
|
||||||
LOG_CONFIG = load_log_config()
|
LOG_CONFIG = load_log_config()
|
||||||
|
# print(LOG_CONFIG)
|
||||||
configure_third_party_loggers()
|
configure_third_party_loggers()
|
||||||
reconfigure_existing_loggers()
|
reconfigure_existing_loggers()
|
||||||
|
|
||||||
@@ -899,77 +777,10 @@ def initialize_logging():
|
|||||||
console_level = LOG_CONFIG.get("console_log_level", LOG_CONFIG.get("log_level", "INFO"))
|
console_level = LOG_CONFIG.get("console_log_level", LOG_CONFIG.get("log_level", "INFO"))
|
||||||
file_level = LOG_CONFIG.get("file_log_level", LOG_CONFIG.get("log_level", "INFO"))
|
file_level = LOG_CONFIG.get("file_log_level", LOG_CONFIG.get("log_level", "INFO"))
|
||||||
|
|
||||||
logger.info("日志系统已重新初始化:")
|
logger.info("日志系统已初始化:")
|
||||||
logger.info(f" - 控制台级别: {console_level}")
|
logger.info(f" - 控制台级别: {console_level}")
|
||||||
logger.info(f" - 文件级别: {file_level}")
|
logger.info(f" - 文件级别: {file_level}")
|
||||||
logger.info(" - 轮转份数: 30个文件")
|
logger.info(" - 轮转份数: 30个文件|自动清理: 30天前的日志")
|
||||||
logger.info(" - 自动清理: 30天前的日志")
|
|
||||||
|
|
||||||
|
|
||||||
def force_initialize_logging():
    """Reload the log configuration and rebuild the whole logging system from scratch."""
    global LOG_CONFIG
    LOG_CONFIG = load_log_config()

    # Reset loggers, then structlog, then third-party libraries, in that order.
    for step in (force_reset_all_loggers, configure_structlog, configure_third_party_loggers):
        step()

    # Report the levels now in effect.
    fallback = LOG_CONFIG.get("log_level", "INFO")
    console_level = LOG_CONFIG.get("console_log_level", fallback)
    file_level = LOG_CONFIG.get("file_log_level", fallback)

    get_logger("logger").info(
        f"日志系统已强制重新初始化,控制台级别: {console_level},文件级别: {file_level},轮转份数: 30个文件,所有logger格式已统一"
    )
|
|
||||||
|
|
||||||
|
|
||||||
def show_module_colors():
    """Log one sample line per entry in MODULE_COLORS, then print the alias table."""
    get_logger("demo")
    print("\n=== 模块颜色展示 ===")

    for module_name in MODULE_COLORS:
        # A throwaway logger bound to the module name, used only for this sample line.
        sample_logger = structlog.get_logger(module_name).bind(logger_name=module_name)
        alias = MODULE_ALIASES.get(module_name, module_name)
        message = f"这是 {module_name} 模块的颜色效果"
        if alias != module_name:
            message = f"{message} (显示为: {alias})"
        sample_logger.info(message)

    print("=== 颜色展示结束 ===\n")

    if not MODULE_ALIASES:
        return

    # Print the alias mapping table.
    print("=== 当前别名映射 ===")
    for module_name, alias in MODULE_ALIASES.items():
        print(f"  {module_name} -> {alias}")
    print("=== 别名映射结束 ===\n")
|
|
||||||
|
|
||||||
|
|
||||||
def format_json_for_logging(data, indent=2, ensure_ascii=False):
    """Format JSON data as a readable string.

    Args:
        data: Object to serialize, or a JSON string that is parsed first.
        indent: Indentation passed to ``json.dumps``.
        ensure_ascii: Whether non-ASCII characters are escaped.

    Returns:
        str: The formatted JSON text.
    """
    # Strings are treated as JSON text: parse them, then re-serialize.
    payload = json.loads(data) if isinstance(data, str) else data
    return json.dumps(payload, indent=indent, ensure_ascii=ensure_ascii)
|
|
||||||
|
|
||||||
|
|
||||||
def cleanup_old_logs():
|
def cleanup_old_logs():
|
||||||
@@ -1007,8 +818,8 @@ def start_log_cleanup_task():
|
|||||||
|
|
||||||
def cleanup_task():
|
def cleanup_task():
|
||||||
while True:
|
while True:
|
||||||
time.sleep(24 * 60 * 60) # 每24小时执行一次
|
|
||||||
cleanup_old_logs()
|
cleanup_old_logs()
|
||||||
|
time.sleep(24 * 60 * 60) # 每24小时执行一次
|
||||||
|
|
||||||
cleanup_thread = threading.Thread(target=cleanup_task, daemon=True)
|
cleanup_thread = threading.Thread(target=cleanup_task, daemon=True)
|
||||||
cleanup_thread.start()
|
cleanup_thread.start()
|
||||||
@@ -1017,35 +828,6 @@ def start_log_cleanup_task():
|
|||||||
logger.info("已启动日志清理任务,将自动清理30天前的日志文件(轮转份数限制: 30个文件)")
|
logger.info("已启动日志清理任务,将自动清理30天前的日志文件(轮转份数限制: 30个文件)")
|
||||||
|
|
||||||
|
|
||||||
def get_log_stats():
    """Collect statistics about the files in LOG_DIR matching ``*.log*``.

    Returns:
        dict: ``total_files``, ``total_size`` and ``files`` (a list of dicts
        with ``name``, ``size`` and ``modified``, newest first). If an error
        occurs it is logged and the partially filled result is returned.
    """
    summary = {"total_files": 0, "total_size": 0, "files": []}

    try:
        if not LOG_DIR.exists():
            return summary

        for path in LOG_DIR.glob("*.log*"):
            size = path.stat().st_size
            modified = datetime.fromtimestamp(path.stat().st_mtime).strftime("%Y-%m-%d %H:%M:%S")
            entry = {"name": path.name, "size": size, "modified": modified}

            summary["files"].append(entry)
            summary["total_files"] += 1
            summary["total_size"] += entry["size"]

        # Newest first; the timestamp format sorts correctly as text.
        summary["files"].sort(key=lambda item: item["modified"], reverse=True)

    except Exception as e:
        get_logger("logger").error(f"获取日志统计信息时出错: {e}")

    return summary
|
|
||||||
|
|
||||||
|
|
||||||
def shutdown_logging():
|
def shutdown_logging():
|
||||||
"""优雅关闭日志系统,释放所有文件句柄"""
|
"""优雅关闭日志系统,释放所有文件句柄"""
|
||||||
logger = get_logger("logger")
|
logger = get_logger("logger")
|
||||||
|
|||||||
@@ -73,6 +73,9 @@ def find_messages(
|
|||||||
if conditions:
|
if conditions:
|
||||||
query = query.where(*conditions)
|
query = query.where(*conditions)
|
||||||
|
|
||||||
|
# 排除 id 为 "notice" 的消息
|
||||||
|
query = query.where(Messages.message_id != "notice")
|
||||||
|
|
||||||
if filter_bot:
|
if filter_bot:
|
||||||
query = query.where(Messages.user_id != global_config.bot.qq_account)
|
query = query.where(Messages.user_id != global_config.bot.qq_account)
|
||||||
|
|
||||||
@@ -167,6 +170,9 @@ def count_messages(message_filter: dict[str, Any]) -> int:
|
|||||||
if conditions:
|
if conditions:
|
||||||
query = query.where(*conditions)
|
query = query.where(*conditions)
|
||||||
|
|
||||||
|
# 排除 id 为 "notice" 的消息
|
||||||
|
query = query.where(Messages.message_id != "notice")
|
||||||
|
|
||||||
count = query.count()
|
count = query.count()
|
||||||
return count
|
return count
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
136
src/config/api_ada_configs.py
Normal file
136
src/config/api_ada_configs.py
Normal file
@@ -0,0 +1,136 @@
|
|||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
from .config_base import ConfigBase
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class APIProvider(ConfigBase):
    """Configuration for one API provider."""

    # Provider name.
    name: str

    # Base URL of the API.
    base_url: str

    # API key (a single string); excluded from the generated repr.
    api_key: str = field(default_factory=str, repr=False)

    # Client type (e.g. openai / google); defaults to openai.
    client_type: str = "openai"

    # Maximum number of retries for a failed call to a single model.
    max_retry: int = 2

    # Request timeout in seconds; a longer request counts as timed out.
    timeout: int = 10

    # Seconds to wait between retries after a failed call.
    retry_interval: int = 10

    def get_api_key(self) -> str:
        """Return the configured API key."""
        return self.api_key

    def __post_init__(self):
        """Validate that the key, base URL and name are set.

        Raises:
            ValueError: If ``api_key`` is empty, if ``base_url`` is empty for
                a client type other than ``"gemini"``, or if ``name`` is empty.
        """
        url_optional = self.client_type == "gemini"

        if not self.api_key:
            raise ValueError("API密钥不能为空,请在配置中设置有效的API密钥。")
        if not (self.base_url or url_optional):
            raise ValueError("API基础URL不能为空,请在配置中设置有效的基础URL。")
        if not self.name:
            raise ValueError("API提供商名称不能为空,请在配置中设置有效的名称。")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ModelInfo(ConfigBase):
    """Configuration for a single model."""

    # Model identifier used when calling the API.
    model_identifier: str

    # Model name used when other modules refer to this model.
    name: str

    # Name of the API provider serving this model.
    api_provider: str

    # Input price per million tokens.
    price_in: float = 0.0

    # Output price per million tokens.
    price_out: float = 0.0

    # Whether streaming output is forced for this model.
    force_stream_mode: bool = False

    # Extra parameters passed along with API calls.
    extra_params: dict = field(default_factory=dict)

    def __post_init__(self):
        """Validate the required fields.

        Raises:
            ValueError: If ``model_identifier``, ``name`` or ``api_provider``
                is empty (checked in that order).
        """
        required = (
            (self.model_identifier, "模型标识符不能为空,请在配置中设置有效的模型标识符。"),
            (self.name, "模型名称不能为空,请在配置中设置有效的模型名称。"),
            (self.api_provider, "API提供商不能为空,请在配置中设置有效的API提供商。"),
        )
        for value, message in required:
            if not value:
                raise ValueError(message)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class TaskConfig(ConfigBase):
    """Configuration for one task."""

    # Names of the models this task may use.
    model_list: list[str] = field(default_factory=list)

    # Maximum number of output tokens for the task.
    max_tokens: int = 1024

    # Sampling temperature.
    temperature: float = 0.3
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ModelTaskConfig(ConfigBase):
    """Per-task model configuration: one ``TaskConfig`` per named task."""

    utils: TaskConfig
    """Model configuration for utility components."""

    utils_small: TaskConfig
    """Small-model configuration for utility components."""

    replyer: TaskConfig
    """Primary reply model configuration for normal_chat."""

    emotion: TaskConfig
    """Emotion model configuration."""

    vlm: TaskConfig
    """Vision-language model configuration."""

    voice: TaskConfig
    """Speech recognition model configuration."""

    tool_use: TaskConfig
    """Tool-use model configuration (focus mode)."""

    planner: TaskConfig
    """Planner model configuration."""

    embedding: TaskConfig
    """Embedding model configuration."""

    lpmm_entity_extract: TaskConfig
    """LPMM entity-extraction model configuration."""

    lpmm_rdf_build: TaskConfig
    """LPMM RDF-building model configuration."""

    lpmm_qa: TaskConfig
    """LPMM question-answering model configuration."""

    def get_task(self, task_name: str) -> TaskConfig:
        """Return the configuration stored for the task called ``task_name``.

        Only names of declared dataclass fields are accepted. A plain
        ``hasattr`` check would also match methods and inherited attributes
        (for example ``"get_task"`` itself) and hand back something that is
        not a task configuration.

        Args:
            task_name: Name of one of the task fields declared on this class.

        Returns:
            The value of the matching field.

        Raises:
            ValueError: if ``task_name`` is not a declared task field.
        """
        # Local import so this block does not depend on the module's import list.
        from dataclasses import fields

        if any(declared.name == task_name for declared in fields(self)):
            return getattr(self, task_name)
        raise ValueError(f"任务 '{task_name}' 未找到对应的配置")
|
||||||
@@ -1,162 +0,0 @@
|
|||||||
import shutil
|
|
||||||
import tomlkit
|
|
||||||
from tomlkit.items import Table, KeyType
|
|
||||||
from pathlib import Path
|
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
|
|
||||||
def get_key_comment(toml_table, key):
    """Return the comment attached to the entry stored under ``key``.

    The previous implementation returned ``toml_table.trivia.comment``
    whenever the container itself had trivia, i.e. the container's own
    comment regardless of ``key``. Its final branch compared mapping keys
    against ``KeyType`` and therefore could not match ordinary string keys.
    This version inspects the entry stored under ``key`` instead.

    Args:
        toml_table: Mapping-like TOML container (or any object supporting
            ``obj[key]``); anything else simply yields ``None``.
        key: Key whose comment is wanted.

    Returns:
        The entry's ``trivia.comment`` when it is non-empty, otherwise
        ``None``. Entries handed back as plain Python values (no ``trivia``
        attribute) have no retrievable comment and also yield ``None``.
    """
    # Preferred path: the entry itself carries its trivia.
    try:
        entry = toml_table[key]
    except (KeyError, IndexError, TypeError):
        entry = None
    comment = getattr(getattr(entry, "trivia", None), "comment", None)
    if comment:
        return comment

    # Kept from the original: some wrappers expose the underlying mapping
    # through ``.value``; look the key up there as a second chance.
    raw_mapping = getattr(toml_table, "value", None)
    if isinstance(raw_mapping, dict):
        wrapped = raw_mapping.get(key)
        wrapped_comment = getattr(getattr(wrapped, "trivia", None), "comment", None)
        if wrapped_comment:
            return wrapped_comment

    return None
|
|
||||||
|
|
||||||
|
|
||||||
def compare_dicts(new, old, path=None, new_comments=None, old_comments=None, logs=None):
    """Recursively list keys added to or removed from ``old`` relative to ``new``.

    Keys named ``"version"`` are ignored at every level. Nested mappings
    present on both sides are compared recursively.

    Args:
        new: New configuration mapping.
        old: Old configuration mapping.
        path: Keys leading to the current level (used to build dotted names).
        new_comments: Accepted and forwarded on recursion; not otherwise used here.
        old_comments: Accepted and forwarded on recursion; not otherwise used here.
        logs: List that receives the messages; a new list is created if omitted.

    Returns:
        The ``logs`` list, with one message per added or removed key.
    """
    path = [] if path is None else path
    logs = [] if logs is None else logs
    new_comments = {} if new_comments is None else new_comments
    old_comments = {} if old_comments is None else old_comments

    def dotted(name):
        # Fully qualified key name for the log line.
        return ".".join(path + [str(name)])

    # Added keys (and recursion into tables that exist on both sides).
    for key in new:
        if key == "version":
            continue
        if key in old:
            both_tables = isinstance(new[key], (dict, Table)) and isinstance(old.get(key), (dict, Table))
            if both_tables:
                compare_dicts(new[key], old[key], path + [str(key)], new_comments, old_comments, logs)
            continue
        comment = get_key_comment(new, key)
        logs.append(f"新增: {dotted(key)} 注释: {comment or '无'}")

    # Removed keys.
    for key in old:
        if key == "version" or key in new:
            continue
        comment = get_key_comment(old, key)
        logs.append(f"删减: {dotted(key)} 注释: {comment or '无'}")

    return logs
|
|
||||||
|
|
||||||
|
|
||||||
def update_config():
    """Regenerate ``config/bot_config.toml`` from the template, keeping user values.

    Steps:
      1. Load ``template/bot_config_template.toml`` and, if present, the
         current ``config/bot_config.toml``.
      2. If both carry the same ``inner.version``, stop without touching
         any file.
      3. Otherwise move the current config to ``config/old/`` under a
         timestamped name, print the added/removed keys, copy the old
         values into the template document for every key the template
         still has, and write the result as the new config.

    Unlike the previous ordering, the live config is never moved away or
    overwritten before the version check, so an up-to-date config is left
    untouched and a failure while reading the template cannot leave the
    template in place of the user's file.
    """
    print("开始更新配置文件...")

    # The project root is four levels above this file.
    root_dir = Path(__file__).parent.parent.parent.parent
    template_path = root_dir / "template" / "bot_config_template.toml"
    config_dir = root_dir / "config"
    old_config_dir = config_dir / "old"
    config_path = config_dir / "bot_config.toml"

    # parents=True: also works on a fresh checkout without a config directory.
    old_config_dir.mkdir(parents=True, exist_ok=True)

    # The template is the base document of the new config.
    with open(template_path, "r", encoding="utf-8") as f:
        new_config = tomlkit.load(f)

    had_old_file = config_path.exists()
    old_config = {}
    if had_old_file:
        print(f"发现旧配置文件: {config_path}")
        with open(config_path, "r", encoding="utf-8") as f:
            old_config = tomlkit.load(f)

    # Nothing to do when both files declare the same version.
    if old_config and "inner" in old_config and "inner" in new_config:
        old_version = old_config["inner"].get("version")  # type: ignore
        new_version = new_config["inner"].get("version")  # type: ignore
        if old_version and new_version and old_version == new_version:
            print(f"检测到版本号相同 (v{old_version}),跳过更新")
            return
        print(f"检测到版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}")

    # Back up the existing file only once we know it will be replaced.
    if had_old_file:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        old_backup_path = old_config_dir / f"bot_config_{timestamp}.toml"
        shutil.move(config_path, old_backup_path)
        print(f"已备份旧配置文件到: {old_backup_path}")

    print(f"从模板文件创建新配置: {template_path}")

    # Report added and removed keys together with their comments.
    if old_config:
        print("配置项变动如下:")
        logs = compare_dicts(new_config, old_config)
        if logs:
            for log in logs:
                print(log)
        else:
            print("无新增或删减项")

    def update_dict(target, source):
        """Copy values from ``source`` into ``target`` for keys ``target`` already has."""
        for key, value in source.items():
            # The version always comes from the template.
            if key == "version" or key not in target:
                continue
            if isinstance(value, dict) and isinstance(target[key], (dict, Table)):
                update_dict(target[key], value)
                continue
            try:
                if not isinstance(value, list):
                    target[key] = tomlkit.item(value)
                elif not value:
                    # Keep an empty array an empty array.
                    target[key] = tomlkit.array()
                elif key in ("ban_msgs_regex", "regex_rules") or (
                    isinstance(value[0], dict) and "regex" in value[0]
                ):
                    # Regex-bearing lists are stored as-is, without re-parsing.
                    target[key] = value
                else:
                    target[key] = tomlkit.array(str(value))
            except (TypeError, ValueError):
                # Conversion failed: fall back to plain assignment.
                target[key] = value

    # Carry the user's values over into the new document.
    print("开始合并新旧配置...")
    update_dict(new_config, old_config)

    # tomlkit.dumps keeps the template's comments and layout.
    with open(config_path, "w", encoding="utf-8") as f:
        f.write(tomlkit.dumps(new_config))
    print("配置文件更新完成")


if __name__ == "__main__":
    # Run the migration only when executed as a script.
    update_config()
|
|
||||||
@@ -1,12 +1,14 @@
|
|||||||
import os
|
import os
|
||||||
import tomlkit
|
import tomlkit
|
||||||
import shutil
|
import shutil
|
||||||
|
import sys
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from tomlkit import TOMLDocument
|
from tomlkit import TOMLDocument
|
||||||
from tomlkit.items import Table, KeyType
|
from tomlkit.items import Table, KeyType
|
||||||
from dataclasses import field, dataclass
|
from dataclasses import field, dataclass
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config_base import ConfigBase
|
from src.config.config_base import ConfigBase
|
||||||
@@ -15,7 +17,6 @@ from src.config.official_configs import (
|
|||||||
PersonalityConfig,
|
PersonalityConfig,
|
||||||
ExpressionConfig,
|
ExpressionConfig,
|
||||||
ChatConfig,
|
ChatConfig,
|
||||||
NormalChatConfig,
|
|
||||||
EmojiConfig,
|
EmojiConfig,
|
||||||
MemoryConfig,
|
MemoryConfig,
|
||||||
MoodConfig,
|
MoodConfig,
|
||||||
@@ -25,7 +26,6 @@ from src.config.official_configs import (
|
|||||||
ResponseSplitterConfig,
|
ResponseSplitterConfig,
|
||||||
TelemetryConfig,
|
TelemetryConfig,
|
||||||
ExperimentalConfig,
|
ExperimentalConfig,
|
||||||
ModelConfig,
|
|
||||||
MessageReceiveConfig,
|
MessageReceiveConfig,
|
||||||
MaimMessageConfig,
|
MaimMessageConfig,
|
||||||
LPMMKnowledgeConfig,
|
LPMMKnowledgeConfig,
|
||||||
@@ -36,6 +36,13 @@ from src.config.official_configs import (
|
|||||||
CustomPromptConfig,
|
CustomPromptConfig,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from .api_ada_configs import (
|
||||||
|
ModelTaskConfig,
|
||||||
|
ModelInfo,
|
||||||
|
APIProvider,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
install(extra_lines=3)
|
install(extra_lines=3)
|
||||||
|
|
||||||
|
|
||||||
@@ -49,7 +56,7 @@ TEMPLATE_DIR = os.path.join(PROJECT_ROOT, "template")
|
|||||||
|
|
||||||
# 考虑到,实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码
|
# 考虑到,实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码
|
||||||
# 对该字段的更新,请严格参照语义化版本规范:https://semver.org/lang/zh-CN/
|
# 对该字段的更新,请严格参照语义化版本规范:https://semver.org/lang/zh-CN/
|
||||||
MMC_VERSION = "0.9.1"
|
MMC_VERSION = "0.10.0"
|
||||||
|
|
||||||
|
|
||||||
def get_key_comment(toml_table, key):
|
def get_key_comment(toml_table, key):
|
||||||
@@ -62,8 +69,8 @@ def get_key_comment(toml_table, key):
|
|||||||
return item.trivia.comment
|
return item.trivia.comment
|
||||||
if hasattr(toml_table, "keys"):
|
if hasattr(toml_table, "keys"):
|
||||||
for k in toml_table.keys():
|
for k in toml_table.keys():
|
||||||
if isinstance(k, KeyType) and k.key == key:
|
if isinstance(k, KeyType) and k.key == key: # type: ignore
|
||||||
return k.trivia.comment
|
return k.trivia.comment # type: ignore
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
@@ -79,7 +86,7 @@ def compare_dicts(new, old, path=None, logs=None):
|
|||||||
continue
|
continue
|
||||||
if key not in old:
|
if key not in old:
|
||||||
comment = get_key_comment(new, key)
|
comment = get_key_comment(new, key)
|
||||||
logs.append(f"新增: {'.'.join(path + [str(key)])} 注释: {comment if comment else '无'}")
|
logs.append(f"新增: {'.'.join(path + [str(key)])} 注释: {comment or '无'}")
|
||||||
elif isinstance(new[key], (dict, Table)) and isinstance(old.get(key), (dict, Table)):
|
elif isinstance(new[key], (dict, Table)) and isinstance(old.get(key), (dict, Table)):
|
||||||
compare_dicts(new[key], old[key], path + [str(key)], logs)
|
compare_dicts(new[key], old[key], path + [str(key)], logs)
|
||||||
# 删减项
|
# 删减项
|
||||||
@@ -88,7 +95,7 @@ def compare_dicts(new, old, path=None, logs=None):
|
|||||||
continue
|
continue
|
||||||
if key not in new:
|
if key not in new:
|
||||||
comment = get_key_comment(old, key)
|
comment = get_key_comment(old, key)
|
||||||
logs.append(f"删减: {'.'.join(path + [str(key)])} 注释: {comment if comment else '无'}")
|
logs.append(f"删减: {'.'.join(path + [str(key)])} 注释: {comment or '无'}")
|
||||||
return logs
|
return logs
|
||||||
|
|
||||||
|
|
||||||
@@ -102,11 +109,18 @@ def get_value_by_path(d, path):
|
|||||||
|
|
||||||
|
|
||||||
def set_value_by_path(d, path, value):
|
def set_value_by_path(d, path, value):
|
||||||
|
"""设置嵌套字典中指定路径的值"""
|
||||||
for k in path[:-1]:
|
for k in path[:-1]:
|
||||||
if k not in d or not isinstance(d[k], dict):
|
if k not in d or not isinstance(d[k], dict):
|
||||||
d[k] = {}
|
d[k] = {}
|
||||||
d = d[k]
|
d = d[k]
|
||||||
d[path[-1]] = value
|
|
||||||
|
# 使用 tomlkit.item 来保持 TOML 格式
|
||||||
|
try:
|
||||||
|
d[path[-1]] = tomlkit.item(value)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
# 如果转换失败,直接赋值
|
||||||
|
d[path[-1]] = value
|
||||||
|
|
||||||
|
|
||||||
def compare_default_values(new, old, path=None, logs=None, changes=None):
|
def compare_default_values(new, old, path=None, logs=None, changes=None):
|
||||||
@@ -123,102 +137,140 @@ def compare_default_values(new, old, path=None, logs=None, changes=None):
|
|||||||
if key in old:
|
if key in old:
|
||||||
if isinstance(new[key], (dict, Table)) and isinstance(old[key], (dict, Table)):
|
if isinstance(new[key], (dict, Table)) and isinstance(old[key], (dict, Table)):
|
||||||
compare_default_values(new[key], old[key], path + [str(key)], logs, changes)
|
compare_default_values(new[key], old[key], path + [str(key)], logs, changes)
|
||||||
else:
|
elif new[key] != old[key]:
|
||||||
# 只要值发生变化就记录
|
logs.append(f"默认值变化: {'.'.join(path + [str(key)])} 旧默认值: {old[key]} 新默认值: {new[key]}")
|
||||||
if new[key] != old[key]:
|
changes.append((path + [str(key)], old[key], new[key]))
|
||||||
logs.append(
|
|
||||||
f"默认值变化: {'.'.join(path + [str(key)])} 旧默认值: {old[key]} 新默认值: {new[key]}"
|
|
||||||
)
|
|
||||||
changes.append((path + [str(key)], old[key], new[key]))
|
|
||||||
return logs, changes
|
return logs, changes
|
||||||
|
|
||||||
|
|
||||||
def update_config():
|
def _get_version_from_toml(toml_path) -> Optional[str]:
    """Read ``inner.version`` from the TOML file at ``toml_path``.

    Args:
        toml_path: Path of the TOML file to inspect.

    Returns:
        The value stored at ``[inner] version``, or ``None`` when the file
        does not exist or has no such entry.
    """
    if os.path.exists(toml_path):
        with open(toml_path, "r", encoding="utf-8") as handle:
            document = tomlkit.load(handle)
        has_version = "inner" in document and "version" in document["inner"]  # type: ignore
        if has_version:
            return document["inner"]["version"]  # type: ignore
    return None
|
||||||
|
|
||||||
|
|
||||||
|
def _version_tuple(v):
    """Turn a version string into a tuple of ints so versions can be compared.

    Every ``"v"`` character is removed, anything from the first ``"-"``
    onwards is dropped, and the rest is split on ``"."``. Components that
    are not purely digits become ``0``. ``None`` maps to ``(0,)``.

    Args:
        v: Version value; converted with ``str()`` unless it is ``None``.

    Returns:
        Tuple of integers, one per dot-separated component.
    """
    if v is None:
        return (0,)

    without_prefix = str(v).replace("v", "")
    core = without_prefix.split("-")[0]

    numbers = []
    for piece in core.split("."):
        numbers.append(int(piece) if piece.isdigit() else 0)
    return tuple(numbers)
|
||||||
|
|
||||||
|
|
||||||
|
def _update_dict(target: TOMLDocument | dict | Table, source: TOMLDocument | dict):
    """Copy values from ``source`` into ``target`` for keys ``target`` already has.

    Keys named ``"version"`` are never copied, and keys missing from
    ``target`` are ignored. When both sides hold a mapping under the same
    key the update recurses into it.

    Args:
        target: Mapping that is updated in place.
        source: Mapping providing the values.
    """
    for key, value in source.items():
        if key == "version" or key not in target:
            continue

        current = target[key]
        if isinstance(value, dict) and isinstance(current, (dict, Table)):
            _update_dict(current, value)
            continue

        try:
            if not isinstance(value, list):
                # Non-list values are wrapped via tomlkit.item.
                target[key] = tomlkit.item(value)
            elif value:
                # NOTE(review): the list's str() form is handed to
                # tomlkit.array; presumably forms it cannot parse end up in
                # the except branch below - confirm.
                target[key] = tomlkit.array(str(value))
            else:
                # Keep an empty array an empty array.
                target[key] = tomlkit.array()
        except (TypeError, ValueError):
            # Conversion failed: fall back to plain assignment.
            target[key] = value
|
||||||
|
|
||||||
|
|
||||||
|
def _update_config_generic(config_name: str, template_name: str):
|
||||||
|
"""
|
||||||
|
通用的配置文件更新函数
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config_name: 配置文件名(不含扩展名),如 'bot_config' 或 'model_config'
|
||||||
|
template_name: 模板文件名(不含扩展名),如 'bot_config_template' 或 'model_config_template'
|
||||||
|
"""
|
||||||
# 获取根目录路径
|
# 获取根目录路径
|
||||||
old_config_dir = os.path.join(CONFIG_DIR, "old")
|
old_config_dir = os.path.join(CONFIG_DIR, "old")
|
||||||
compare_dir = os.path.join(TEMPLATE_DIR, "compare")
|
compare_dir = os.path.join(TEMPLATE_DIR, "compare")
|
||||||
|
|
||||||
# 定义文件路径
|
# 定义文件路径
|
||||||
template_path = os.path.join(TEMPLATE_DIR, "bot_config_template.toml")
|
template_path = os.path.join(TEMPLATE_DIR, f"{template_name}.toml")
|
||||||
old_config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
|
old_config_path = os.path.join(CONFIG_DIR, f"{config_name}.toml")
|
||||||
new_config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
|
new_config_path = os.path.join(CONFIG_DIR, f"{config_name}.toml")
|
||||||
compare_path = os.path.join(compare_dir, "bot_config_template.toml")
|
compare_path = os.path.join(compare_dir, f"{template_name}.toml")
|
||||||
|
|
||||||
# 创建compare目录(如果不存在)
|
# 创建compare目录(如果不存在)
|
||||||
os.makedirs(compare_dir, exist_ok=True)
|
os.makedirs(compare_dir, exist_ok=True)
|
||||||
|
|
||||||
# 处理compare下的模板文件
|
template_version = _get_version_from_toml(template_path)
|
||||||
def get_version_from_toml(toml_path):
|
compare_version = _get_version_from_toml(compare_path)
|
||||||
if not os.path.exists(toml_path):
|
|
||||||
return None
|
|
||||||
with open(toml_path, "r", encoding="utf-8") as f:
|
|
||||||
doc = tomlkit.load(f)
|
|
||||||
if "inner" in doc and "version" in doc["inner"]: # type: ignore
|
|
||||||
return doc["inner"]["version"] # type: ignore
|
|
||||||
return None
|
|
||||||
|
|
||||||
template_version = get_version_from_toml(template_path)
|
# 检查配置文件是否存在
|
||||||
compare_version = get_version_from_toml(compare_path)
|
if not os.path.exists(old_config_path):
|
||||||
|
logger.info(f"{config_name}.toml配置文件不存在,从模板创建新配置")
|
||||||
|
os.makedirs(CONFIG_DIR, exist_ok=True) # 创建文件夹
|
||||||
|
shutil.copy2(template_path, old_config_path) # 复制模板文件
|
||||||
|
logger.info(f"已创建新{config_name}配置文件,请填写后重新运行: {old_config_path}")
|
||||||
|
# 新创建配置文件,退出
|
||||||
|
sys.exit(0)
|
||||||
|
|
||||||
def version_tuple(v):
|
compare_config = None
|
||||||
if v is None:
|
new_config = None
|
||||||
return (0,)
|
old_config = None
|
||||||
return tuple(int(x) if x.isdigit() else 0 for x in str(v).replace("v", "").split("-")[0].split("."))
|
|
||||||
|
|
||||||
# 先读取 compare 下的模板(如果有),用于默认值变动检测
|
# 先读取 compare 下的模板(如果有),用于默认值变动检测
|
||||||
if os.path.exists(compare_path):
|
if os.path.exists(compare_path):
|
||||||
with open(compare_path, "r", encoding="utf-8") as f:
|
with open(compare_path, "r", encoding="utf-8") as f:
|
||||||
compare_config = tomlkit.load(f)
|
compare_config = tomlkit.load(f)
|
||||||
else:
|
|
||||||
compare_config = None
|
|
||||||
|
|
||||||
# 读取当前模板
|
# 读取当前模板
|
||||||
with open(template_path, "r", encoding="utf-8") as f:
|
with open(template_path, "r", encoding="utf-8") as f:
|
||||||
new_config = tomlkit.load(f)
|
new_config = tomlkit.load(f)
|
||||||
|
|
||||||
# 检查默认值变化并处理(只有 compare_config 存在时才做)
|
# 检查默认值变化并处理(只有 compare_config 存在时才做)
|
||||||
if compare_config is not None:
|
if compare_config:
|
||||||
# 读取旧配置
|
# 读取旧配置
|
||||||
with open(old_config_path, "r", encoding="utf-8") as f:
|
with open(old_config_path, "r", encoding="utf-8") as f:
|
||||||
old_config = tomlkit.load(f)
|
old_config = tomlkit.load(f)
|
||||||
logs, changes = compare_default_values(new_config, compare_config)
|
logs, changes = compare_default_values(new_config, compare_config)
|
||||||
if logs:
|
if logs:
|
||||||
logger.info("检测到模板默认值变动如下:")
|
logger.info(f"检测到{config_name}模板默认值变动如下:")
|
||||||
for log in logs:
|
for log in logs:
|
||||||
logger.info(log)
|
logger.info(log)
|
||||||
# 检查旧配置是否等于旧默认值,如果是则更新为新默认值
|
# 检查旧配置是否等于旧默认值,如果是则更新为新默认值
|
||||||
|
config_updated = False
|
||||||
for path, old_default, new_default in changes:
|
for path, old_default, new_default in changes:
|
||||||
old_value = get_value_by_path(old_config, path)
|
old_value = get_value_by_path(old_config, path)
|
||||||
if old_value == old_default:
|
if old_value == old_default:
|
||||||
set_value_by_path(old_config, path, new_default)
|
set_value_by_path(old_config, path, new_default)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"已自动将配置 {'.'.join(path)} 的值从旧默认值 {old_default} 更新为新默认值 {new_default}"
|
f"已自动将{config_name}配置 {'.'.join(path)} 的值从旧默认值 {old_default} 更新为新默认值 {new_default}"
|
||||||
)
|
)
|
||||||
|
config_updated = True
|
||||||
|
|
||||||
|
# 如果配置有更新,立即保存到文件
|
||||||
|
if config_updated:
|
||||||
|
with open(old_config_path, "w", encoding="utf-8") as f:
|
||||||
|
f.write(tomlkit.dumps(old_config))
|
||||||
|
logger.info(f"已保存更新后的{config_name}配置文件")
|
||||||
else:
|
else:
|
||||||
logger.info("未检测到模板默认值变动")
|
logger.info(f"未检测到{config_name}模板默认值变动")
|
||||||
# 保存旧配置的变更(后续合并逻辑会用到 old_config)
|
|
||||||
else:
|
|
||||||
old_config = None
|
|
||||||
|
|
||||||
# 检查 compare 下没有模板,或新模板版本更高,则复制
|
# 检查 compare 下没有模板,或新模板版本更高,则复制
|
||||||
if not os.path.exists(compare_path):
|
if not os.path.exists(compare_path):
|
||||||
shutil.copy2(template_path, compare_path)
|
shutil.copy2(template_path, compare_path)
|
||||||
logger.info(f"已将模板文件复制到: {compare_path}")
|
logger.info(f"已将{config_name}模板文件复制到: {compare_path}")
|
||||||
|
elif _version_tuple(template_version) > _version_tuple(compare_version):
|
||||||
|
shutil.copy2(template_path, compare_path)
|
||||||
|
logger.info(f"{config_name}模板版本较新,已替换compare下的模板: {compare_path}")
|
||||||
else:
|
else:
|
||||||
if version_tuple(template_version) > version_tuple(compare_version):
|
logger.debug(f"compare下的{config_name}模板版本不低于当前模板,无需替换: {compare_path}")
|
||||||
shutil.copy2(template_path, compare_path)
|
|
||||||
logger.info(f"模板版本较新,已替换compare下的模板: {compare_path}")
|
|
||||||
else:
|
|
||||||
logger.debug(f"compare下的模板版本不低于当前模板,无需替换: {compare_path}")
|
|
||||||
|
|
||||||
# 检查配置文件是否存在
|
|
||||||
if not os.path.exists(old_config_path):
|
|
||||||
logger.info("配置文件不存在,从模板创建新配置")
|
|
||||||
os.makedirs(CONFIG_DIR, exist_ok=True) # 创建文件夹
|
|
||||||
shutil.copy2(template_path, old_config_path) # 复制模板文件
|
|
||||||
logger.info(f"已创建新配置文件,请填写后重新运行: {old_config_path}")
|
|
||||||
# 如果是新创建的配置文件,直接返回
|
|
||||||
quit()
|
|
||||||
|
|
||||||
# 读取旧配置文件和模板文件(如果前面没读过 old_config,这里再读一次)
|
# 读取旧配置文件和模板文件(如果前面没读过 old_config,这里再读一次)
|
||||||
if old_config is None:
|
if old_config is None:
|
||||||
@@ -226,79 +278,60 @@ def update_config():
|
|||||||
old_config = tomlkit.load(f)
|
old_config = tomlkit.load(f)
|
||||||
# new_config 已经读取
|
# new_config 已经读取
|
||||||
|
|
||||||
# 读取 compare_config 只用于默认值变动检测,后续合并逻辑不再用
|
|
||||||
|
|
||||||
# 检查version是否相同
|
# 检查version是否相同
|
||||||
if old_config and "inner" in old_config and "inner" in new_config:
|
if old_config and "inner" in old_config and "inner" in new_config:
|
||||||
old_version = old_config["inner"].get("version") # type: ignore
|
old_version = old_config["inner"].get("version") # type: ignore
|
||||||
new_version = new_config["inner"].get("version") # type: ignore
|
new_version = new_config["inner"].get("version") # type: ignore
|
||||||
if old_version and new_version and old_version == new_version:
|
if old_version and new_version and old_version == new_version:
|
||||||
logger.info(f"检测到配置文件版本号相同 (v{old_version}),跳过更新")
|
logger.info(f"检测到{config_name}配置文件版本号相同 (v{old_version}),跳过更新")
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
logger.info(
|
logger.info(
|
||||||
f"\n----------------------------------------\n检测到版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}\n----------------------------------------"
|
f"\n----------------------------------------\n检测到{config_name}版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}\n----------------------------------------"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.info("已有配置文件未检测到版本号,可能是旧版本。将进行更新")
|
logger.info(f"已有{config_name}配置文件未检测到版本号,可能是旧版本。将进行更新")
|
||||||
|
|
||||||
# 创建old目录(如果不存在)
|
# 创建old目录(如果不存在)
|
||||||
os.makedirs(old_config_dir, exist_ok=True) # 生成带时间戳的新文件名
|
os.makedirs(old_config_dir, exist_ok=True) # 生成带时间戳的新文件名
|
||||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||||
old_backup_path = os.path.join(old_config_dir, f"bot_config_{timestamp}.toml")
|
old_backup_path = os.path.join(old_config_dir, f"{config_name}_{timestamp}.toml")
|
||||||
|
|
||||||
# 移动旧配置文件到old目录
|
# 移动旧配置文件到old目录
|
||||||
shutil.move(old_config_path, old_backup_path)
|
shutil.move(old_config_path, old_backup_path)
|
||||||
logger.info(f"已备份旧配置文件到: {old_backup_path}")
|
logger.info(f"已备份旧{config_name}配置文件到: {old_backup_path}")
|
||||||
|
|
||||||
# 复制模板文件到配置目录
|
# 复制模板文件到配置目录
|
||||||
shutil.copy2(template_path, new_config_path)
|
shutil.copy2(template_path, new_config_path)
|
||||||
logger.info(f"已创建新配置文件: {new_config_path}")
|
logger.info(f"已创建新{config_name}配置文件: {new_config_path}")
|
||||||
|
|
||||||
# 输出新增和删减项及注释
|
# 输出新增和删减项及注释
|
||||||
if old_config:
|
if old_config:
|
||||||
logger.info("配置项变动如下:\n----------------------------------------")
|
logger.info(f"{config_name}配置项变动如下:\n----------------------------------------")
|
||||||
logs = compare_dicts(new_config, old_config)
|
if logs := compare_dicts(new_config, old_config):
|
||||||
if logs:
|
|
||||||
for log in logs:
|
for log in logs:
|
||||||
logger.info(log)
|
logger.info(log)
|
||||||
else:
|
else:
|
||||||
logger.info("无新增或删减项")
|
logger.info("无新增或删减项")
|
||||||
|
|
||||||
def update_dict(target: TOMLDocument | dict | Table, source: TOMLDocument | dict):
|
|
||||||
"""
|
|
||||||
将source字典的值更新到target字典中(如果target中存在相同的键)
|
|
||||||
"""
|
|
||||||
for key, value in source.items():
|
|
||||||
# 跳过version字段的更新
|
|
||||||
if key == "version":
|
|
||||||
continue
|
|
||||||
if key in target:
|
|
||||||
target_value = target[key]
|
|
||||||
if isinstance(value, dict) and isinstance(target_value, (dict, Table)):
|
|
||||||
update_dict(target_value, value)
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
# 对数组类型进行特殊处理
|
|
||||||
if isinstance(value, list):
|
|
||||||
# 如果是空数组,确保它保持为空数组
|
|
||||||
target[key] = tomlkit.array(str(value)) if value else tomlkit.array()
|
|
||||||
else:
|
|
||||||
# 其他类型使用item方法创建新值
|
|
||||||
target[key] = tomlkit.item(value)
|
|
||||||
except (TypeError, ValueError):
|
|
||||||
# 如果转换失败,直接赋值
|
|
||||||
target[key] = value
|
|
||||||
|
|
||||||
# 将旧配置的值更新到新配置中
|
# 将旧配置的值更新到新配置中
|
||||||
logger.info("开始合并新旧配置...")
|
logger.info(f"开始合并{config_name}新旧配置...")
|
||||||
update_dict(new_config, old_config)
|
_update_dict(new_config, old_config)
|
||||||
|
|
||||||
# 保存更新后的配置(保留注释和格式)
|
# 保存更新后的配置(保留注释和格式)
|
||||||
with open(new_config_path, "w", encoding="utf-8") as f:
|
with open(new_config_path, "w", encoding="utf-8") as f:
|
||||||
f.write(tomlkit.dumps(new_config))
|
f.write(tomlkit.dumps(new_config))
|
||||||
logger.info("配置文件更新完成,建议检查新配置文件中的内容,以免丢失重要信息")
|
logger.info(f"{config_name}配置文件更新完成,建议检查新配置文件中的内容,以免丢失重要信息")
|
||||||
quit()
|
|
||||||
|
|
||||||
|
def update_config():
    """Update ``bot_config.toml`` using ``bot_config_template.toml``."""
    _update_config_generic(config_name="bot_config", template_name="bot_config_template")
|
||||||
|
|
||||||
|
|
||||||
|
def update_model_config():
    """Update ``model_config.toml`` using ``model_config_template.toml``."""
    _update_config_generic(config_name="model_config", template_name="model_config_template")
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -312,7 +345,6 @@ class Config(ConfigBase):
|
|||||||
relationship: RelationshipConfig
|
relationship: RelationshipConfig
|
||||||
chat: ChatConfig
|
chat: ChatConfig
|
||||||
message_receive: MessageReceiveConfig
|
message_receive: MessageReceiveConfig
|
||||||
normal_chat: NormalChatConfig
|
|
||||||
emoji: EmojiConfig
|
emoji: EmojiConfig
|
||||||
expression: ExpressionConfig
|
expression: ExpressionConfig
|
||||||
memory: MemoryConfig
|
memory: MemoryConfig
|
||||||
@@ -323,7 +355,6 @@ class Config(ConfigBase):
|
|||||||
response_splitter: ResponseSplitterConfig
|
response_splitter: ResponseSplitterConfig
|
||||||
telemetry: TelemetryConfig
|
telemetry: TelemetryConfig
|
||||||
experimental: ExperimentalConfig
|
experimental: ExperimentalConfig
|
||||||
model: ModelConfig
|
|
||||||
maim_message: MaimMessageConfig
|
maim_message: MaimMessageConfig
|
||||||
lpmm_knowledge: LPMMKnowledgeConfig
|
lpmm_knowledge: LPMMKnowledgeConfig
|
||||||
tool: ToolConfig
|
tool: ToolConfig
|
||||||
@@ -331,11 +362,69 @@ class Config(ConfigBase):
|
|||||||
custom_prompt: CustomPromptConfig
|
custom_prompt: CustomPromptConfig
|
||||||
voice: VoiceConfig
|
voice: VoiceConfig
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class APIAdapterConfig(ConfigBase):
    """API adapter configuration: models, per-task settings and providers."""

    models: List[ModelInfo]
    """List of models."""

    model_task_config: ModelTaskConfig
    """Per-task model configuration."""

    api_providers: List[APIProvider] = field(default_factory=list)
    """List of API providers."""

    def __post_init__(self):
        """Validate the lists and build the name-indexed lookup tables.

        Sets ``api_providers_dict`` and ``models_dict`` on the instance.

        Raises:
            ValueError: if either list is empty, if provider or model names
                repeat, or if a model has an empty ``model_identifier`` or
                names a provider that is not in ``api_providers``.
        """
        if not self.models:
            raise ValueError("模型列表不能为空,请在配置中设置有效的模型列表。")
        if not self.api_providers:
            raise ValueError("API提供商列表不能为空,请在配置中设置有效的API提供商列表。")

        # A set drops duplicates, so a size mismatch means a repeated name.
        unique_provider_names = {provider.name for provider in self.api_providers}
        if len(unique_provider_names) != len(self.api_providers):
            raise ValueError("API提供商名称存在重复,请检查配置文件。")

        unique_model_names = {model.name for model in self.models}
        if len(unique_model_names) != len(self.models):
            raise ValueError("模型名称存在重复,请检查配置文件。")

        providers_by_name = {}
        for provider in self.api_providers:
            providers_by_name[provider.name] = provider
        self.api_providers_dict = providers_by_name

        models_by_name = {}
        for model in self.models:
            models_by_name[model.name] = model
        self.models_dict = models_by_name

        # Every model needs an identifier and a known provider.
        for model in self.models:
            if not model.model_identifier:
                raise ValueError(f"模型 '{model.name}' 的 model_identifier 不能为空")
            provider_known = bool(model.api_provider) and model.api_provider in self.api_providers_dict
            if not provider_known:
                raise ValueError(f"模型 '{model.name}' 的 api_provider '{model.api_provider}' 不存在")

    def get_model_info(self, model_name: str) -> ModelInfo:
        """Return the model registered under ``model_name``.

        Raises:
            ValueError: if ``model_name`` is empty.
            KeyError: if no model has that name.
        """
        if not model_name:
            raise ValueError("模型名称不能为空")
        if model_name in self.models_dict:
            return self.models_dict[model_name]
        raise KeyError(f"模型 '{model_name}' 不存在")

    def get_provider(self, provider_name: str) -> APIProvider:
        """Return the API provider registered under ``provider_name``.

        Raises:
            ValueError: if ``provider_name`` is empty.
            KeyError: if no provider has that name.
        """
        if not provider_name:
            raise ValueError("API提供商名称不能为空")
        if provider_name in self.api_providers_dict:
            return self.api_providers_dict[provider_name]
        raise KeyError(f"API提供商 '{provider_name}' 不存在")
|
||||||
|
|
||||||
|
|
||||||
def load_config(config_path: str) -> Config:
|
def load_config(config_path: str) -> Config:
|
||||||
"""
|
"""
|
||||||
加载配置文件
|
加载配置文件
|
||||||
:param config_path: 配置文件路径
|
Args:
|
||||||
:return: Config对象
|
config_path: 配置文件路径
|
||||||
|
Returns:
|
||||||
|
Config对象
|
||||||
"""
|
"""
|
||||||
# 读取配置文件
|
# 读取配置文件
|
||||||
with open(config_path, "r", encoding="utf-8") as f:
|
with open(config_path, "r", encoding="utf-8") as f:
|
||||||
@@ -349,18 +438,32 @@ def load_config(config_path: str) -> Config:
|
|||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
|
||||||
def get_config_dir() -> str:
|
def api_ada_load_config(config_path: str) -> APIAdapterConfig:
|
||||||
"""
|
"""
|
||||||
获取配置目录
|
加载API适配器配置文件
|
||||||
:return: 配置目录路径
|
Args:
|
||||||
|
config_path: 配置文件路径
|
||||||
|
Returns:
|
||||||
|
APIAdapterConfig对象
|
||||||
"""
|
"""
|
||||||
return CONFIG_DIR
|
# 读取配置文件
|
||||||
|
with open(config_path, "r", encoding="utf-8") as f:
|
||||||
|
config_data = tomlkit.load(f)
|
||||||
|
|
||||||
|
# 创建APIAdapterConfig对象
|
||||||
|
try:
|
||||||
|
return APIAdapterConfig.from_dict(config_data)
|
||||||
|
except Exception as e:
|
||||||
|
logger.critical("API适配器配置文件解析失败")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
|
||||||
# 获取配置文件路径
|
# 获取配置文件路径
|
||||||
logger.info(f"MaiCore当前版本: {MMC_VERSION}")
|
logger.info(f"MaiCore当前版本: {MMC_VERSION}")
|
||||||
update_config()
|
update_config()
|
||||||
|
update_model_config()
|
||||||
|
|
||||||
logger.info("正在品鉴配置文件...")
|
logger.info("正在品鉴配置文件...")
|
||||||
global_config = load_config(config_path=os.path.join(CONFIG_DIR, "bot_config.toml"))
|
global_config = load_config(config_path=os.path.join(CONFIG_DIR, "bot_config.toml"))
|
||||||
|
model_config = api_ada_load_config(config_path=os.path.join(CONFIG_DIR, "model_config.toml"))
|
||||||
logger.info("非常的新鲜,非常的美味!")
|
logger.info("非常的新鲜,非常的美味!")
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import re
|
import re
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any, Literal, Optional
|
from typing import Literal, Optional
|
||||||
|
|
||||||
from src.config.config_base import ConfigBase
|
from src.config.config_base import ConfigBase
|
||||||
|
|
||||||
@@ -17,7 +17,7 @@ from src.config.config_base import ConfigBase
|
|||||||
@dataclass
|
@dataclass
|
||||||
class BotConfig(ConfigBase):
|
class BotConfig(ConfigBase):
|
||||||
"""QQ机器人配置类"""
|
"""QQ机器人配置类"""
|
||||||
|
|
||||||
platform: str
|
platform: str
|
||||||
"""平台"""
|
"""平台"""
|
||||||
|
|
||||||
@@ -44,6 +44,9 @@ class PersonalityConfig(ConfigBase):
|
|||||||
identity: str = ""
|
identity: str = ""
|
||||||
"""身份特征"""
|
"""身份特征"""
|
||||||
|
|
||||||
|
reply_style: str = ""
|
||||||
|
"""表达风格"""
|
||||||
|
|
||||||
compress_personality: bool = True
|
compress_personality: bool = True
|
||||||
"""是否压缩人格,压缩后会精简人格信息,节省token消耗并提高回复性能,但是会丢失一些信息,如果人设不长,可以关闭"""
|
"""是否压缩人格,压缩后会精简人格信息,节省token消耗并提高回复性能,但是会丢失一些信息,如果人设不长,可以关闭"""
|
||||||
|
|
||||||
@@ -68,155 +71,90 @@ class ChatConfig(ConfigBase):
|
|||||||
|
|
||||||
max_context_size: int = 18
|
max_context_size: int = 18
|
||||||
"""上下文长度"""
|
"""上下文长度"""
|
||||||
|
|
||||||
willing_amplifier: float = 1.0
|
|
||||||
|
|
||||||
replyer_random_probability: float = 0.5
|
|
||||||
"""
|
|
||||||
发言时选择推理模型的概率(0-1之间)
|
|
||||||
选择普通模型的概率为 1 - reasoning_normal_model_probability
|
|
||||||
"""
|
|
||||||
|
|
||||||
thinking_timeout: int = 40
|
|
||||||
"""麦麦最长思考规划时间,超过这个时间的思考会放弃(往往是api反应太慢)"""
|
|
||||||
|
|
||||||
talk_frequency: float = 1
|
|
||||||
"""回复频率阈值"""
|
|
||||||
|
|
||||||
mentioned_bot_inevitable_reply: bool = False
|
mentioned_bot_inevitable_reply: bool = False
|
||||||
"""提及 bot 必然回复"""
|
"""提及 bot 必然回复"""
|
||||||
|
|
||||||
at_bot_inevitable_reply: bool = False
|
at_bot_inevitable_reply: bool = False
|
||||||
"""@bot 必然回复"""
|
"""@bot 必然回复"""
|
||||||
|
|
||||||
|
talk_frequency: float = 0.5
|
||||||
|
"""回复频率阈值"""
|
||||||
|
|
||||||
# 修改:基于时段的回复频率配置,改为数组格式
|
# 合并后的时段频率配置
|
||||||
time_based_talk_frequency: list[str] = field(default_factory=lambda: [])
|
|
||||||
"""
|
|
||||||
基于时段的回复频率配置(全局)
|
|
||||||
格式:["HH:MM,frequency", "HH:MM,frequency", ...]
|
|
||||||
示例:["8:00,1", "12:00,2", "18:00,1.5", "00:00,0.5"]
|
|
||||||
表示从该时间开始使用该频率,直到下一个时间点
|
|
||||||
"""
|
|
||||||
|
|
||||||
# 新增:基于聊天流的个性化时段频率配置
|
|
||||||
talk_frequency_adjust: list[list[str]] = field(default_factory=lambda: [])
|
talk_frequency_adjust: list[list[str]] = field(default_factory=lambda: [])
|
||||||
"""
|
|
||||||
基于聊天流的个性化时段频率配置
|
|
||||||
格式:[["platform:chat_id:type", "HH:MM,frequency", "HH:MM,frequency", ...], ...]
|
|
||||||
示例:[
|
|
||||||
["qq:1026294844:group", "12:20,1", "16:10,2", "20:10,1", "00:10,0.3"],
|
|
||||||
["qq:729957033:group", "8:20,1", "12:10,2", "20:10,1.5", "00:10,0.2"]
|
|
||||||
]
|
|
||||||
每个子列表的第一个元素是聊天流标识符,后续元素是"时间,频率"格式
|
|
||||||
表示从该时间开始使用该频率,直到下一个时间点
|
|
||||||
"""
|
|
||||||
|
|
||||||
focus_value: float = 1.0
|
|
||||||
|
focus_value: float = 0.5
|
||||||
"""麦麦的专注思考能力,越低越容易专注,消耗token也越多"""
|
"""麦麦的专注思考能力,越低越容易专注,消耗token也越多"""
|
||||||
|
|
||||||
|
focus_value_adjust: list[list[str]] = field(default_factory=lambda: [])
|
||||||
|
|
||||||
|
"""
|
||||||
|
统一的活跃度和专注度配置
|
||||||
|
格式:[["platform:chat_id:type", "HH:MM,frequency", "HH:MM,frequency", ...], ...]
|
||||||
|
|
||||||
|
全局配置示例:
|
||||||
|
[["", "8:00,1", "12:00,2", "18:00,1.5", "00:00,0.5"]]
|
||||||
|
|
||||||
|
特定聊天流配置示例:
|
||||||
|
[
|
||||||
|
["", "8:00,1", "12:00,1.2", "18:00,1.5", "01:00,0.6"], # 全局默认配置
|
||||||
|
["qq:1026294844:group", "12:20,1", "16:10,2", "20:10,1", "00:10,0.3"], # 特定群聊配置
|
||||||
|
["qq:729957033:private", "8:20,1", "12:10,2", "20:10,1.5", "00:10,0.2"] # 特定私聊配置
|
||||||
|
]
|
||||||
|
|
||||||
|
说明:
|
||||||
|
- 当第一个元素为空字符串""时,表示全局默认配置
|
||||||
|
- 当第一个元素为"platform:id:type"格式时,表示特定聊天流配置
|
||||||
|
- 后续元素是"时间,频率"格式,表示从该时间开始使用该频率,直到下一个时间点
|
||||||
|
- 优先级:特定聊天流配置 > 全局配置 > 默认值
|
||||||
|
|
||||||
|
注意:
|
||||||
|
- talk_frequency_adjust 控制回复频率,数值越高回复越频繁
|
||||||
|
- focus_value_adjust 控制专注思考能力,数值越低越容易专注,消耗token也越多
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
def get_current_talk_frequency(self, chat_stream_id: Optional[str] = None) -> float:
|
|
||||||
"""
|
|
||||||
根据当前时间和聊天流获取对应的 talk_frequency
|
|
||||||
|
|
||||||
Args:
|
@dataclass
|
||||||
chat_stream_id: 聊天流ID,格式为 "platform:chat_id:type"
|
class MessageReceiveConfig(ConfigBase):
|
||||||
|
"""消息接收配置类"""
|
||||||
|
|
||||||
Returns:
|
ban_words: set[str] = field(default_factory=lambda: set())
|
||||||
float: 对应的频率值
|
"""过滤词列表"""
|
||||||
"""
|
|
||||||
# 优先检查聊天流特定的配置
|
|
||||||
if chat_stream_id and self.talk_frequency_adjust:
|
|
||||||
stream_frequency = self._get_stream_specific_frequency(chat_stream_id)
|
|
||||||
if stream_frequency is not None:
|
|
||||||
return stream_frequency
|
|
||||||
|
|
||||||
# 如果没有聊天流特定配置,检查全局时段配置
|
ban_msgs_regex: set[str] = field(default_factory=lambda: set())
|
||||||
if self.time_based_talk_frequency:
|
"""过滤正则表达式列表"""
|
||||||
global_frequency = self._get_time_based_frequency(self.time_based_talk_frequency)
|
|
||||||
if global_frequency is not None:
|
|
||||||
return global_frequency
|
|
||||||
|
|
||||||
# 如果都没有匹配,返回默认值
|
@dataclass
|
||||||
return self.talk_frequency
|
class ExpressionConfig(ConfigBase):
|
||||||
|
"""表达配置类"""
|
||||||
|
|
||||||
def _get_time_based_frequency(self, time_freq_list: list[str]) -> Optional[float]:
|
learning_list: list[list] = field(default_factory=lambda: [])
|
||||||
"""
|
"""
|
||||||
根据时间配置列表获取当前时段的频率
|
表达学习配置列表,支持按聊天流配置
|
||||||
|
格式: [["chat_stream_id", "use_expression", "enable_learning", learning_intensity], ...]
|
||||||
|
|
||||||
|
示例:
|
||||||
|
[
|
||||||
|
["", "enable", "enable", 1.0], # 全局配置:使用表达,启用学习,学习强度1.0
|
||||||
|
["qq:1919810:private", "enable", "enable", 1.5], # 特定私聊配置:使用表达,启用学习,学习强度1.5
|
||||||
|
["qq:114514:private", "enable", "disable", 0.5], # 特定私聊配置:使用表达,禁用学习,学习强度0.5
|
||||||
|
]
|
||||||
|
|
||||||
|
说明:
|
||||||
|
- 第一位: chat_stream_id,空字符串表示全局配置
|
||||||
|
- 第二位: 是否使用学到的表达 ("enable"/"disable")
|
||||||
|
- 第三位: 是否学习表达 ("enable"/"disable")
|
||||||
|
- 第四位: 学习强度(浮点数),影响学习频率,最短学习时间间隔 = 300/学习强度(秒)
|
||||||
|
"""
|
||||||
|
|
||||||
Args:
|
expression_groups: list[list[str]] = field(default_factory=list)
|
||||||
time_freq_list: 时间频率配置列表,格式为 ["HH:MM,frequency", ...]
|
"""
|
||||||
|
表达学习互通组
|
||||||
Returns:
|
格式: [["qq:12345:group", "qq:67890:private"]]
|
||||||
float: 频率值,如果没有配置则返回 None
|
"""
|
||||||
"""
|
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
current_time = datetime.now().strftime("%H:%M")
|
|
||||||
current_hour, current_minute = map(int, current_time.split(":"))
|
|
||||||
current_minutes = current_hour * 60 + current_minute
|
|
||||||
|
|
||||||
# 解析时间频率配置
|
|
||||||
time_freq_pairs = []
|
|
||||||
for time_freq_str in time_freq_list:
|
|
||||||
try:
|
|
||||||
time_str, freq_str = time_freq_str.split(",")
|
|
||||||
hour, minute = map(int, time_str.split(":"))
|
|
||||||
frequency = float(freq_str)
|
|
||||||
minutes = hour * 60 + minute
|
|
||||||
time_freq_pairs.append((minutes, frequency))
|
|
||||||
except (ValueError, IndexError):
|
|
||||||
continue
|
|
||||||
|
|
||||||
if not time_freq_pairs:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# 按时间排序
|
|
||||||
time_freq_pairs.sort(key=lambda x: x[0])
|
|
||||||
|
|
||||||
# 查找当前时间对应的频率
|
|
||||||
current_frequency = None
|
|
||||||
for minutes, frequency in time_freq_pairs:
|
|
||||||
if current_minutes >= minutes:
|
|
||||||
current_frequency = frequency
|
|
||||||
else:
|
|
||||||
break
|
|
||||||
|
|
||||||
# 如果当前时间在所有配置时间之前,使用最后一个时间段的频率(跨天逻辑)
|
|
||||||
if current_frequency is None and time_freq_pairs:
|
|
||||||
current_frequency = time_freq_pairs[-1][1]
|
|
||||||
|
|
||||||
return current_frequency
|
|
||||||
|
|
||||||
def _get_stream_specific_frequency(self, chat_stream_id: str):
|
|
||||||
"""
|
|
||||||
获取特定聊天流在当前时间的频率
|
|
||||||
|
|
||||||
Args:
|
|
||||||
chat_stream_id: 聊天流ID(哈希值)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
float: 频率值,如果没有配置则返回 None
|
|
||||||
"""
|
|
||||||
# 查找匹配的聊天流配置
|
|
||||||
for config_item in self.talk_frequency_adjust:
|
|
||||||
if not config_item or len(config_item) < 2:
|
|
||||||
continue
|
|
||||||
|
|
||||||
stream_config_str = config_item[0] # 例如 "qq:1026294844:group"
|
|
||||||
|
|
||||||
# 解析配置字符串并生成对应的 chat_id
|
|
||||||
config_chat_id = self._parse_stream_config_to_chat_id(stream_config_str)
|
|
||||||
if config_chat_id is None:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 比较生成的 chat_id
|
|
||||||
if config_chat_id != chat_stream_id:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 使用通用的时间频率解析方法
|
|
||||||
return self._get_time_based_frequency(config_item[1:])
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _parse_stream_config_to_chat_id(self, stream_config_str: str) -> Optional[str]:
|
def _parse_stream_config_to_chat_id(self, stream_config_str: str) -> Optional[str]:
|
||||||
"""
|
"""
|
||||||
@@ -253,46 +191,96 @@ class ChatConfig(ConfigBase):
|
|||||||
except (ValueError, IndexError):
|
except (ValueError, IndexError):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def get_expression_config_for_chat(self, chat_stream_id: Optional[str] = None) -> tuple[bool, bool, int]:
|
||||||
|
"""
|
||||||
|
根据聊天流ID获取表达配置
|
||||||
|
|
||||||
@dataclass
|
Args:
|
||||||
class MessageReceiveConfig(ConfigBase):
|
chat_stream_id: 聊天流ID,格式为哈希值
|
||||||
"""消息接收配置类"""
|
|
||||||
|
|
||||||
ban_words: set[str] = field(default_factory=lambda: set())
|
Returns:
|
||||||
"""过滤词列表"""
|
tuple: (是否使用表达, 是否学习表达, 学习间隔)
|
||||||
|
"""
|
||||||
|
if not self.learning_list:
|
||||||
|
# 如果没有配置,使用默认值:启用表达,启用学习,300秒间隔
|
||||||
|
return True, True, 300
|
||||||
|
|
||||||
ban_msgs_regex: set[str] = field(default_factory=lambda: set())
|
# 优先检查聊天流特定的配置
|
||||||
"""过滤正则表达式列表"""
|
if chat_stream_id:
|
||||||
|
specific_expression_config = self._get_stream_specific_config(chat_stream_id)
|
||||||
|
if specific_expression_config is not None:
|
||||||
|
return specific_expression_config
|
||||||
|
|
||||||
|
# 检查全局配置(第一个元素为空字符串的配置)
|
||||||
|
global_expression_config = self._get_global_config()
|
||||||
|
if global_expression_config is not None:
|
||||||
|
return global_expression_config
|
||||||
|
|
||||||
@dataclass
|
# 如果都没有匹配,返回默认值
|
||||||
class NormalChatConfig(ConfigBase):
|
return True, True, 300
|
||||||
"""普通聊天配置类"""
|
|
||||||
|
|
||||||
willing_mode: str = "classical"
|
def _get_stream_specific_config(self, chat_stream_id: str) -> Optional[tuple[bool, bool, int]]:
|
||||||
"""意愿模式"""
|
"""
|
||||||
|
获取特定聊天流的表达配置
|
||||||
|
|
||||||
@dataclass
|
Args:
|
||||||
class ExpressionConfig(ConfigBase):
|
chat_stream_id: 聊天流ID(哈希值)
|
||||||
"""表达配置类"""
|
|
||||||
|
|
||||||
enable_expression: bool = True
|
Returns:
|
||||||
"""是否启用表达方式"""
|
tuple: (是否使用表达, 是否学习表达, 学习间隔),如果没有配置则返回 None
|
||||||
|
"""
|
||||||
|
for config_item in self.learning_list:
|
||||||
|
if not config_item or len(config_item) < 4:
|
||||||
|
continue
|
||||||
|
|
||||||
expression_style: str = ""
|
stream_config_str = config_item[0] # 例如 "qq:1026294844:group"
|
||||||
"""表达风格"""
|
|
||||||
|
|
||||||
learning_interval: int = 300
|
# 如果是空字符串,跳过(这是全局配置)
|
||||||
"""学习间隔(秒)"""
|
if stream_config_str == "":
|
||||||
|
continue
|
||||||
|
|
||||||
enable_expression_learning: bool = True
|
# 解析配置字符串并生成对应的 chat_id
|
||||||
"""是否启用表达学习"""
|
config_chat_id = self._parse_stream_config_to_chat_id(stream_config_str)
|
||||||
|
if config_chat_id is None:
|
||||||
|
continue
|
||||||
|
|
||||||
expression_groups: list[list[str]] = field(default_factory=list)
|
# 比较生成的 chat_id
|
||||||
"""
|
if config_chat_id != chat_stream_id:
|
||||||
表达学习互通组
|
continue
|
||||||
格式: [["qq:12345:group", "qq:67890:private"]]
|
|
||||||
"""
|
# 解析配置
|
||||||
|
try:
|
||||||
|
use_expression: bool = config_item[1].lower() == "enable"
|
||||||
|
enable_learning: bool = config_item[2].lower() == "enable"
|
||||||
|
learning_intensity: float = float(config_item[3])
|
||||||
|
return use_expression, enable_learning, learning_intensity # type: ignore
|
||||||
|
except (ValueError, IndexError):
|
||||||
|
continue
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _get_global_config(self) -> Optional[tuple[bool, bool, int]]:
|
||||||
|
"""
|
||||||
|
获取全局表达配置
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: (是否使用表达, 是否学习表达, 学习间隔),如果没有配置则返回 None
|
||||||
|
"""
|
||||||
|
for config_item in self.learning_list:
|
||||||
|
if not config_item or len(config_item) < 4:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 检查是否为全局配置(第一个元素为空字符串)
|
||||||
|
if config_item[0] == "":
|
||||||
|
try:
|
||||||
|
use_expression: bool = config_item[1].lower() == "enable"
|
||||||
|
enable_learning: bool = config_item[2].lower() == "enable"
|
||||||
|
learning_intensity = float(config_item[3])
|
||||||
|
return use_expression, enable_learning, learning_intensity # type: ignore
|
||||||
|
except (ValueError, IndexError):
|
||||||
|
continue
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -301,7 +289,8 @@ class ToolConfig(ConfigBase):
|
|||||||
|
|
||||||
enable_tool: bool = False
|
enable_tool: bool = False
|
||||||
"""是否在聊天中启用工具"""
|
"""是否在聊天中启用工具"""
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class VoiceConfig(ConfigBase):
|
class VoiceConfig(ConfigBase):
|
||||||
"""语音识别配置类"""
|
"""语音识别配置类"""
|
||||||
@@ -317,9 +306,6 @@ class EmojiConfig(ConfigBase):
|
|||||||
emoji_chance: float = 0.6
|
emoji_chance: float = 0.6
|
||||||
"""发送表情包的基础概率"""
|
"""发送表情包的基础概率"""
|
||||||
|
|
||||||
emoji_activate_type: str = "random"
|
|
||||||
"""表情包激活类型,可选:random,llm,random下,表情包动作随机启用,llm下,表情包动作根据llm判断是否启用"""
|
|
||||||
|
|
||||||
max_reg_num: int = 200
|
max_reg_num: int = 200
|
||||||
"""表情包最大注册数量"""
|
"""表情包最大注册数量"""
|
||||||
|
|
||||||
@@ -344,25 +330,10 @@ class MemoryConfig(ConfigBase):
|
|||||||
"""记忆配置类"""
|
"""记忆配置类"""
|
||||||
|
|
||||||
enable_memory: bool = True
|
enable_memory: bool = True
|
||||||
|
"""是否启用记忆系统"""
|
||||||
memory_build_interval: int = 600
|
|
||||||
"""记忆构建间隔(秒)"""
|
memory_build_frequency: int = 1
|
||||||
|
"""记忆构建频率(秒)"""
|
||||||
memory_build_distribution: tuple[
|
|
||||||
float,
|
|
||||||
float,
|
|
||||||
float,
|
|
||||||
float,
|
|
||||||
float,
|
|
||||||
float,
|
|
||||||
] = field(default_factory=lambda: (6.0, 3.0, 0.6, 32.0, 12.0, 0.4))
|
|
||||||
"""记忆构建分布,参数:分布1均值,标准差,权重,分布2均值,标准差,权重"""
|
|
||||||
|
|
||||||
memory_build_sample_num: int = 8
|
|
||||||
"""记忆构建采样数量"""
|
|
||||||
|
|
||||||
memory_build_sample_length: int = 40
|
|
||||||
"""记忆构建采样长度"""
|
|
||||||
|
|
||||||
memory_compress_rate: float = 0.1
|
memory_compress_rate: float = 0.1
|
||||||
"""记忆压缩率"""
|
"""记忆压缩率"""
|
||||||
@@ -376,18 +347,9 @@ class MemoryConfig(ConfigBase):
|
|||||||
memory_forget_percentage: float = 0.01
|
memory_forget_percentage: float = 0.01
|
||||||
"""记忆遗忘比例"""
|
"""记忆遗忘比例"""
|
||||||
|
|
||||||
consolidate_memory_interval: int = 1000
|
|
||||||
"""记忆整合间隔(秒)"""
|
|
||||||
|
|
||||||
consolidation_similarity_threshold: float = 0.7
|
|
||||||
"""整合相似度阈值"""
|
|
||||||
|
|
||||||
consolidate_memory_percentage: float = 0.01
|
|
||||||
"""整合检查节点比例"""
|
|
||||||
|
|
||||||
memory_ban_words: list[str] = field(default_factory=lambda: ["表情包", "图片", "回复", "聊天记录"])
|
memory_ban_words: list[str] = field(default_factory=lambda: ["表情包", "图片", "回复", "聊天记录"])
|
||||||
"""不允许记忆的词列表"""
|
"""不允许记忆的词列表"""
|
||||||
|
|
||||||
enable_instant_memory: bool = True
|
enable_instant_memory: bool = True
|
||||||
"""是否启用即时记忆"""
|
"""是否启用即时记忆"""
|
||||||
|
|
||||||
@@ -398,7 +360,7 @@ class MoodConfig(ConfigBase):
|
|||||||
|
|
||||||
enable_mood: bool = False
|
enable_mood: bool = False
|
||||||
"""是否启用情绪系统"""
|
"""是否启用情绪系统"""
|
||||||
|
|
||||||
mood_update_threshold: float = 1.0
|
mood_update_threshold: float = 1.0
|
||||||
"""情绪更新阈值,越高,更新越慢"""
|
"""情绪更新阈值,越高,更新越慢"""
|
||||||
|
|
||||||
@@ -449,6 +411,7 @@ class KeywordReactionConfig(ConfigBase):
|
|||||||
if not isinstance(rule, KeywordRuleConfig):
|
if not isinstance(rule, KeywordRuleConfig):
|
||||||
raise ValueError(f"规则必须是KeywordRuleConfig类型,而不是{type(rule).__name__}")
|
raise ValueError(f"规则必须是KeywordRuleConfig类型,而不是{type(rule).__name__}")
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class CustomPromptConfig(ConfigBase):
|
class CustomPromptConfig(ConfigBase):
|
||||||
"""自定义提示词配置类"""
|
"""自定义提示词配置类"""
|
||||||
@@ -597,52 +560,3 @@ class LPMMKnowledgeConfig(ConfigBase):
|
|||||||
|
|
||||||
embedding_dimension: int = 1024
|
embedding_dimension: int = 1024
|
||||||
"""嵌入向量维度,应该与模型的输出维度一致"""
|
"""嵌入向量维度,应该与模型的输出维度一致"""
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ModelConfig(ConfigBase):
|
|
||||||
"""模型配置类"""
|
|
||||||
|
|
||||||
model_max_output_length: int = 800 # 最大回复长度
|
|
||||||
|
|
||||||
utils: dict[str, Any] = field(default_factory=lambda: {})
|
|
||||||
"""组件模型配置"""
|
|
||||||
|
|
||||||
utils_small: dict[str, Any] = field(default_factory=lambda: {})
|
|
||||||
"""组件小模型配置"""
|
|
||||||
|
|
||||||
replyer_1: dict[str, Any] = field(default_factory=lambda: {})
|
|
||||||
"""normal_chat首要回复模型模型配置"""
|
|
||||||
|
|
||||||
replyer_2: dict[str, Any] = field(default_factory=lambda: {})
|
|
||||||
"""normal_chat次要回复模型配置"""
|
|
||||||
|
|
||||||
memory: dict[str, Any] = field(default_factory=lambda: {})
|
|
||||||
"""记忆模型配置"""
|
|
||||||
|
|
||||||
emotion: dict[str, Any] = field(default_factory=lambda: {})
|
|
||||||
"""情绪模型配置"""
|
|
||||||
|
|
||||||
vlm: dict[str, Any] = field(default_factory=lambda: {})
|
|
||||||
"""视觉语言模型配置"""
|
|
||||||
|
|
||||||
voice: dict[str, Any] = field(default_factory=lambda: {})
|
|
||||||
"""语音识别模型配置"""
|
|
||||||
|
|
||||||
tool_use: dict[str, Any] = field(default_factory=lambda: {})
|
|
||||||
"""专注工具使用模型配置"""
|
|
||||||
|
|
||||||
planner: dict[str, Any] = field(default_factory=lambda: {})
|
|
||||||
"""规划模型配置"""
|
|
||||||
|
|
||||||
embedding: dict[str, Any] = field(default_factory=lambda: {})
|
|
||||||
"""嵌入模型配置"""
|
|
||||||
|
|
||||||
lpmm_entity_extract: dict[str, Any] = field(default_factory=lambda: {})
|
|
||||||
"""LPMM实体提取模型配置"""
|
|
||||||
|
|
||||||
lpmm_rdf_build: dict[str, Any] = field(default_factory=lambda: {})
|
|
||||||
"""LPMM RDF构建模型配置"""
|
|
||||||
|
|
||||||
lpmm_qa: dict[str, Any] = field(default_factory=lambda: {})
|
|
||||||
"""LPMM问答模型配置"""
|
|
||||||
|
|||||||
@@ -4,9 +4,8 @@ import hashlib
|
|||||||
import time
|
import time
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config, model_config
|
||||||
from src.llm_models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
from src.person_info.person_info import get_person_info_manager
|
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
|
|
||||||
install(extra_lines=3)
|
install(extra_lines=3)
|
||||||
@@ -19,14 +18,10 @@ class Individuality:
|
|||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.name = ""
|
self.name = ""
|
||||||
self.bot_person_id = ""
|
|
||||||
self.meta_info_file_path = "data/personality/meta.json"
|
self.meta_info_file_path = "data/personality/meta.json"
|
||||||
self.personality_data_file_path = "data/personality/personality_data.json"
|
self.personality_data_file_path = "data/personality/personality_data.json"
|
||||||
|
|
||||||
self.model = LLMRequest(
|
self.model = LLMRequest(model_set=model_config.model_task_config.utils, request_type="individuality.compress")
|
||||||
model=global_config.model.utils,
|
|
||||||
request_type="individuality.compress",
|
|
||||||
)
|
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
"""初始化个体特征"""
|
"""初始化个体特征"""
|
||||||
@@ -35,9 +30,6 @@ class Individuality:
|
|||||||
personality_side = global_config.personality.personality_side
|
personality_side = global_config.personality.personality_side
|
||||||
identity = global_config.personality.identity
|
identity = global_config.personality.identity
|
||||||
|
|
||||||
|
|
||||||
person_info_manager = get_person_info_manager()
|
|
||||||
self.bot_person_id = person_info_manager.get_person_id("system", "bot_id")
|
|
||||||
self.name = bot_nickname
|
self.name = bot_nickname
|
||||||
|
|
||||||
# 检查配置变化,如果变化则清空
|
# 检查配置变化,如果变化则清空
|
||||||
@@ -68,16 +60,6 @@ class Individuality:
|
|||||||
else:
|
else:
|
||||||
logger.error("人设构建失败")
|
logger.error("人设构建失败")
|
||||||
|
|
||||||
# 如果任何一个发生变化,都需要清空数据库中的info_list(因为这影响整体人设)
|
|
||||||
if personality_changed or identity_changed:
|
|
||||||
logger.info("将清空数据库中原有的关键词缓存")
|
|
||||||
update_data = {
|
|
||||||
"platform": "system",
|
|
||||||
"user_id": "bot_id",
|
|
||||||
"person_name": self.name,
|
|
||||||
"nickname": self.name,
|
|
||||||
}
|
|
||||||
await person_info_manager.update_one_field(self.bot_person_id, "info_list", [], data=update_data)
|
|
||||||
|
|
||||||
async def get_personality_block(self) -> str:
|
async def get_personality_block(self) -> str:
|
||||||
bot_name = global_config.bot.nickname
|
bot_name = global_config.bot.nickname
|
||||||
@@ -85,16 +67,16 @@ class Individuality:
|
|||||||
bot_nickname = f",也有人叫你{','.join(global_config.bot.alias_names)}"
|
bot_nickname = f",也有人叫你{','.join(global_config.bot.alias_names)}"
|
||||||
else:
|
else:
|
||||||
bot_nickname = ""
|
bot_nickname = ""
|
||||||
|
|
||||||
# 从文件获取 short_impression
|
# 从文件获取 short_impression
|
||||||
personality, identity = self._get_personality_from_file()
|
personality, identity = self._get_personality_from_file()
|
||||||
|
|
||||||
# 确保short_impression是列表格式且有足够的元素
|
# 确保short_impression是列表格式且有足够的元素
|
||||||
if not personality or not identity:
|
if not personality or not identity:
|
||||||
logger.warning(f"personality或identity为空: {personality}, {identity}, 使用默认值")
|
logger.warning(f"personality或identity为空: {personality}, {identity}, 使用默认值")
|
||||||
personality = "友好活泼"
|
personality = "友好活泼"
|
||||||
identity = "人类"
|
identity = "人类"
|
||||||
|
|
||||||
prompt_personality = f"{personality}\n{identity}"
|
prompt_personality = f"{personality}\n{identity}"
|
||||||
return f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}"
|
return f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}"
|
||||||
|
|
||||||
@@ -134,7 +116,6 @@ class Individuality:
|
|||||||
Returns:
|
Returns:
|
||||||
tuple: (personality_changed, identity_changed)
|
tuple: (personality_changed, identity_changed)
|
||||||
"""
|
"""
|
||||||
person_info_manager = get_person_info_manager()
|
|
||||||
current_personality_hash, current_identity_hash = self._get_config_hash(
|
current_personality_hash, current_identity_hash = self._get_config_hash(
|
||||||
bot_nickname, personality_core, personality_side, identity
|
bot_nickname, personality_core, personality_side, identity
|
||||||
)
|
)
|
||||||
@@ -152,17 +133,6 @@ class Individuality:
|
|||||||
if identity_changed:
|
if identity_changed:
|
||||||
logger.info("检测到身份配置发生变化")
|
logger.info("检测到身份配置发生变化")
|
||||||
|
|
||||||
# 如果任何一个发生变化,都需要清空info_list(因为这影响整体人设)
|
|
||||||
if personality_changed or identity_changed:
|
|
||||||
logger.info("将清空原有的关键词缓存")
|
|
||||||
update_data = {
|
|
||||||
"platform": "system",
|
|
||||||
"user_id": "bot_id",
|
|
||||||
"person_name": self.name,
|
|
||||||
"nickname": self.name,
|
|
||||||
}
|
|
||||||
await person_info_manager.update_one_field(self.bot_person_id, "info_list", [], data=update_data)
|
|
||||||
|
|
||||||
# 更新元信息文件
|
# 更新元信息文件
|
||||||
new_meta_info = {
|
new_meta_info = {
|
||||||
"personality_hash": current_personality_hash,
|
"personality_hash": current_personality_hash,
|
||||||
@@ -215,7 +185,7 @@ class Individuality:
|
|||||||
|
|
||||||
def _get_personality_from_file(self) -> tuple[str, str]:
|
def _get_personality_from_file(self) -> tuple[str, str]:
|
||||||
"""从文件获取personality数据
|
"""从文件获取personality数据
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
tuple: (personality, identity)
|
tuple: (personality, identity)
|
||||||
"""
|
"""
|
||||||
@@ -226,7 +196,7 @@ class Individuality:
|
|||||||
|
|
||||||
def _save_personality_to_file(self, personality: str, identity: str):
|
def _save_personality_to_file(self, personality: str, identity: str):
|
||||||
"""保存personality数据到文件
|
"""保存personality数据到文件
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
personality: 压缩后的人格描述
|
personality: 压缩后的人格描述
|
||||||
identity: 压缩后的身份描述
|
identity: 压缩后的身份描述
|
||||||
@@ -235,7 +205,7 @@ class Individuality:
|
|||||||
"personality": personality,
|
"personality": personality,
|
||||||
"identity": identity,
|
"identity": identity,
|
||||||
"bot_nickname": self.name,
|
"bot_nickname": self.name,
|
||||||
"last_updated": int(time.time())
|
"last_updated": int(time.time()),
|
||||||
}
|
}
|
||||||
self._save_personality_data(personality_data)
|
self._save_personality_data(personality_data)
|
||||||
|
|
||||||
@@ -269,7 +239,7 @@ class Individuality:
|
|||||||
2. 尽量简洁,不超过30字
|
2. 尽量简洁,不超过30字
|
||||||
3. 直接输出压缩后的内容,不要解释"""
|
3. 直接输出压缩后的内容,不要解释"""
|
||||||
|
|
||||||
response, (_, _) = await self.model.generate_response_async(
|
response, _ = await self.model.generate_response_async(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -281,7 +251,7 @@ class Individuality:
|
|||||||
# 压缩失败时使用原始内容
|
# 压缩失败时使用原始内容
|
||||||
if personality_side:
|
if personality_side:
|
||||||
personality_parts.append(personality_side)
|
personality_parts.append(personality_side)
|
||||||
|
|
||||||
if personality_parts:
|
if personality_parts:
|
||||||
personality_result = "。".join(personality_parts)
|
personality_result = "。".join(personality_parts)
|
||||||
else:
|
else:
|
||||||
@@ -308,7 +278,7 @@ class Individuality:
|
|||||||
2. 尽量简洁,不超过30字
|
2. 尽量简洁,不超过30字
|
||||||
3. 直接输出压缩后的内容,不要解释"""
|
3. 直接输出压缩后的内容,不要解释"""
|
||||||
|
|
||||||
response, (_, _) = await self.model.generate_response_async(
|
response, _ = await self.model.generate_response_async(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
21
src/llm_models/LICENSE
Normal file
21
src/llm_models/LICENSE
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
MIT License
|
||||||
|
|
||||||
|
Copyright (c) 2025 Mai.To.The.Gate
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
||||||
0
src/llm_models/__init__.py
Normal file
0
src/llm_models/__init__.py
Normal file
98
src/llm_models/exceptions.py
Normal file
98
src/llm_models/exceptions.py
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
# Human-readable explanations for common HTTP error status codes
# (modelled on the OpenAI API). Keys are integer status codes.
error_code_mapping = {
    400: "参数不正确",
    # The model/provider settings are loaded from model_config.toml, so the
    # hint must point users at that file.
    401: "API-Key错误,认证失败,请检查/config/model_config.toml中的配置是否正确",
    402: "账号余额不足",
    403: "模型拒绝访问,可能需要实名或余额不足",
    404: "Not Found",
    413: "请求体过大,请尝试压缩图片或减少输入内容",
    429: "请求过于频繁,请稍后再试",
    500: "服务器内部故障",
    503: "服务器负载过高",
}
|
||||||
|
|
||||||
|
|
||||||
|
class NetworkConnectionError(Exception):
    """Connection failure, typically a network problem or an unavailable server."""

    # Fixed text returned by ``__str__``; the exception carries no arguments.
    _TEXT = "连接异常,请检查网络连接状态或URL是否正确"

    def __init__(self):
        Exception.__init__(self)

    def __str__(self):
        return self._TEXT
|
||||||
|
|
||||||
|
|
||||||
|
class ReqAbortException(Exception):
|
||||||
|
"""请求异常退出,常见于请求被中断或取消"""
|
||||||
|
|
||||||
|
def __init__(self, message: str | None = None):
|
||||||
|
super().__init__(message)
|
||||||
|
self.message = message
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return self.message or "请求因未知原因异常终止"
|
||||||
|
|
||||||
|
|
||||||
|
class RespNotOkException(Exception):
    """Unsuccessful response: the request did not come back as '200 OK'."""

    def __init__(self, status_code: int, message: str | None = None):
        self.status_code = status_code
        self.message = message
        super().__init__(message)

    def __str__(self):
        # A known status code wins over any message supplied by the caller.
        if self.status_code in error_code_mapping:
            return error_code_mapping[self.status_code]
        if self.message:
            return self.message
        return f"未知的异常响应代码:{self.status_code}"
|
||||||
|
|
||||||
|
|
||||||
|
class RespParseException(Exception):
|
||||||
|
"""响应解析错误,常见于响应格式不正确或解析方法不匹配"""
|
||||||
|
|
||||||
|
def __init__(self, ext_info: Any, message: str | None = None):
|
||||||
|
super().__init__(message)
|
||||||
|
self.ext_info = ext_info
|
||||||
|
self.message = message
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return self.message or "解析响应内容时发生未知错误,请检查是否配置了正确的解析方法"
|
||||||
|
|
||||||
|
|
||||||
|
class PayLoadTooLargeError(Exception):
    """Custom exception for a request body that is too large."""

    def __init__(self, message: str):
        # The message is kept on the instance, but ``__str__`` does not use it.
        self.message = message
        super().__init__(message)

    def __str__(self):
        # Always the same fixed advice, whatever message was passed in.
        fixed_text = "请求体过大,请尝试压缩图片或减少输入内容。"
        return fixed_text
|
||||||
|
|
||||||
|
|
||||||
|
class RequestAbortException(Exception):
    """Custom exception for an interrupted request."""

    def __init__(self, message: str):
        self.message = message
        Exception.__init__(self, message)

    def __str__(self):
        # The stored message is returned unchanged; there is no fallback text.
        text = self.message
        return text
|
||||||
|
|
||||||
|
|
||||||
|
class PermissionDeniedException(Exception):
    """Custom exception for a denied access."""

    def __init__(self, message: str):
        self.message = message
        Exception.__init__(self, message)

    def __str__(self):
        # The stored message is returned unchanged; there is no fallback text.
        text = self.message
        return text
|
||||||
8
src/llm_models/model_client/__init__.py
Normal file
8
src/llm_models/model_client/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
from src.config.config import model_config

# Distinct client types named by the configured API providers.
used_client_types = set(entry.client_type for entry in model_config.api_providers)

# Import a client module only when some provider uses its client type.
# NOTE(review): presumably each module registers its client on import — confirm.
if "openai" in used_client_types:
    from . import openai_client  # noqa: F401

if "gemini" in used_client_types:
    from . import gemini_client  # noqa: F401
|
||||||
178
src/llm_models/model_client/base_client.py
Normal file
178
src/llm_models/model_client/base_client.py
Normal file
@@ -0,0 +1,178 @@
|
|||||||
|
import asyncio
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Callable, Any, Optional
|
||||||
|
|
||||||
|
from src.config.api_ada_configs import ModelInfo, APIProvider
|
||||||
|
from ..payload_content.message import Message
|
||||||
|
from ..payload_content.resp_format import RespFormat
|
||||||
|
from ..payload_content.tool_option import ToolOption, ToolCall
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class UsageRecord:
    """Token-usage record attached to an API response."""

    # Name of the model.
    model_name: str
    # Name of the API provider.
    provider_name: str
    # Number of prompt tokens.
    prompt_tokens: int
    # Number of completion tokens.
    completion_tokens: int
    # Total number of tokens.
    total_tokens: int
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class APIResponse:
    """Container for everything a client call can return; every field defaults to None."""

    # Response text.
    content: str | None = None
    # Reasoning ("thinking") text.
    reasoning_content: str | None = None
    # Tool calls requested by the model.
    tool_calls: list[ToolCall] | None = None
    # Embedding vector.
    embedding: list[float] | None = None
    # Token usage for the call.
    usage: UsageRecord | None = None
    # Raw response data as returned by the underlying client.
    raw_data: Any = None
|
||||||
|
|
||||||
|
|
||||||
|
class BaseClient(ABC):
    """
    Abstract base class for API clients.

    Stores the ``APIProvider`` it was created for and declares the operations
    every concrete client must implement.
    """

    api_provider: APIProvider

    def __init__(self, api_provider: APIProvider):
        self.api_provider = api_provider

    @abstractmethod
    async def get_response(
        self,
        model_info: ModelInfo,
        message_list: list[Message],
        tool_options: list[ToolOption] | None = None,
        max_tokens: int = 1024,
        temperature: float = 0.7,
        response_format: RespFormat | None = None,
        stream_response_handler: Optional[
            Callable[[Any, asyncio.Event | None], tuple[APIResponse, tuple[int, int, int]]]
        ] = None,
        async_response_parser: Callable[[Any], tuple[APIResponse, tuple[int, int, int]]] | None = None,
        interrupt_flag: asyncio.Event | None = None,
        extra_params: dict[str, Any] | None = None,
    ) -> APIResponse:
        """
        Get a chat response.
        :param model_info: model information
        :param message_list: the conversation messages
        :param tool_options: tool options (optional, defaults to None)
        :param max_tokens: maximum number of tokens (optional, defaults to 1024)
        :param temperature: sampling temperature (optional, defaults to 0.7)
        :param response_format: response format (optional, defaults to None)
        :param stream_response_handler: handler for streamed responses (optional)
        :param async_response_parser: parser for non-streamed responses (optional)
        :param interrupt_flag: event used to signal interruption (optional, defaults to None)
        :param extra_params: additional request parameters (optional)
        :return: the response as an APIResponse
        """
        raise NotImplementedError("'get_response' method should be overridden in subclasses")

    @abstractmethod
    async def get_embedding(
        self,
        model_info: ModelInfo,
        embedding_input: str,
        extra_params: dict[str, Any] | None = None,
    ) -> APIResponse:
        """
        Get a text embedding.
        :param model_info: model information
        :param embedding_input: the text to embed
        :param extra_params: additional request parameters (optional)
        :return: the embedding response
        """
        raise NotImplementedError("'get_embedding' method should be overridden in subclasses")

    @abstractmethod
    async def get_audio_transcriptions(
        self,
        model_info: ModelInfo,
        audio_base64: str,
        extra_params: dict[str, Any] | None = None,
    ) -> APIResponse:
        """
        Get an audio transcription.
        :param model_info: model information
        :param audio_base64: base64-encoded audio data
        :param extra_params: additional request parameters (optional)
        :return: the transcription response
        """
        raise NotImplementedError("'get_audio_transcriptions' method should be overridden in subclasses")

    @abstractmethod
    def get_support_image_formats(self) -> list[str]:
        """
        Get the supported image formats.
        :return: list of supported image format names
        """
        raise NotImplementedError("'get_support_image_formats' method should be overridden in subclasses")
|
||||||
|
|
||||||
|
|
||||||
|
class ClientRegistry:
    """Registry of client classes keyed by client type, plus a per-provider instance cache."""

    def __init__(self) -> None:
        # APIProvider.client_type -> BaseClient subclass
        self.client_registry: dict[str, type[BaseClient]] = {}
        # APIProvider.name -> cached BaseClient instance
        self.client_instance_cache: dict[str, BaseClient] = {}

    def register_client_class(self, client_type: str):
        """
        Build a class decorator that registers the decorated class.

        Args:
            client_type: key under which the decorated client class is stored

        Returns:
            A decorator that records the class and returns it unchanged; it raises
            TypeError when the class is not a BaseClient subclass.
        """

        def decorator(cls: type[BaseClient]) -> type[BaseClient]:
            if issubclass(cls, BaseClient):
                self.client_registry[client_type] = cls
                return cls
            raise TypeError(f"{cls.__name__} is not a subclass of BaseClient")

        return decorator

    def get_client_class_instance(self, api_provider: APIProvider) -> BaseClient:
        """
        Return the client instance for a provider, creating and caching it on first use.

        Args:
            api_provider: the APIProvider to get a client for

        Returns:
            BaseClient: the cached client instance for ``api_provider.name``

        Raises:
            KeyError: no client class is registered for ``api_provider.client_type``
        """
        provider_name = api_provider.name
        cache = self.client_instance_cache
        if provider_name not in cache:
            client_class = self.client_registry.get(api_provider.client_type)
            if not client_class:
                raise KeyError(f"'{api_provider.client_type}' 类型的 Client 未注册")
            cache[provider_name] = client_class(api_provider)
        return cache[provider_name]
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level registry instance shared by the client modules that import it.
client_registry = ClientRegistry()
|
||||||
561
src/llm_models/model_client/gemini_client.py
Normal file
561
src/llm_models/model_client/gemini_client.py
Normal file
@@ -0,0 +1,561 @@
|
|||||||
|
import asyncio
|
||||||
|
import io
|
||||||
|
import base64
|
||||||
|
from typing import Callable, AsyncIterator, Optional, Coroutine, Any, List
|
||||||
|
|
||||||
|
from google import genai
|
||||||
|
from google.genai.types import (
|
||||||
|
Content,
|
||||||
|
Part,
|
||||||
|
FunctionDeclaration,
|
||||||
|
GenerateContentResponse,
|
||||||
|
ContentListUnion,
|
||||||
|
ContentUnion,
|
||||||
|
ThinkingConfig,
|
||||||
|
Tool,
|
||||||
|
GenerateContentConfig,
|
||||||
|
EmbedContentResponse,
|
||||||
|
EmbedContentConfig,
|
||||||
|
SafetySetting,
|
||||||
|
HarmCategory,
|
||||||
|
HarmBlockThreshold,
|
||||||
|
)
|
||||||
|
from google.genai.errors import (
|
||||||
|
ClientError,
|
||||||
|
ServerError,
|
||||||
|
UnknownFunctionCallArgumentError,
|
||||||
|
UnsupportedFunctionError,
|
||||||
|
FunctionInvocationError,
|
||||||
|
)
|
||||||
|
|
||||||
|
from src.config.api_ada_configs import ModelInfo, APIProvider
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
|
from .base_client import APIResponse, UsageRecord, BaseClient, client_registry
|
||||||
|
from ..exceptions import (
|
||||||
|
RespParseException,
|
||||||
|
NetworkConnectionError,
|
||||||
|
RespNotOkException,
|
||||||
|
ReqAbortException,
|
||||||
|
)
|
||||||
|
from ..payload_content.message import Message, RoleType
|
||||||
|
from ..payload_content.resp_format import RespFormat, RespFormatType
|
||||||
|
from ..payload_content.tool_option import ToolOption, ToolParam, ToolCall
|
||||||
|
|
||||||
|
logger = get_logger("Gemini客户端")

# One SafetySetting per harm category, each with the threshold set to BLOCK_NONE.
gemini_safe_settings = [
    SafetySetting(category=harm_category, threshold=HarmBlockThreshold.BLOCK_NONE)
    for harm_category in (
        HarmCategory.HARM_CATEGORY_HATE_SPEECH,
        HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
        HarmCategory.HARM_CATEGORY_HARASSMENT,
        HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
        HarmCategory.HARM_CATEGORY_CIVIC_INTEGRITY,
    )
]
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_messages(
    messages: list[Message],
) -> tuple[ContentListUnion, list[str] | None]:
    """
    Convert messages into the format required by the Gemini API.

    System messages are collected separately as plain strings. Tool-role messages
    are only validated (they must carry a ``tool_call_id``) and are not forwarded.
    Every other message becomes a ``Content`` object.

    :param messages: list of messages
    :return: ``(contents, system_instructions)``; the second item is None when there
        are no system messages
    :raises ValueError: a system message has non-string content, or a tool message
        has no ``tool_call_id``
    :raises RuntimeError: a message has a role or content shape that cannot be converted
    """

    def _convert_message_item(message: Message) -> Content:
        """
        Convert one message (any role except system and tool).
        :param message: message object
        :return: the converted ``Content``
        """

        # Rename OpenAI-style roles to Gemini-style roles.
        if message.role == RoleType.Assistant:
            role = "model"
        elif message.role == RoleType.User:
            role = "user"
        else:
            # Previously `role` stayed unbound here and failed later with UnboundLocalError.
            raise RuntimeError("无法触及的代码:请使用MessageBuilder类构建消息对象")

        # Build the parts.
        if isinstance(message.content, str):
            parts = [Part.from_text(text=message.content)]
        elif isinstance(message.content, list):
            parts = []
            for item in message.content:
                if isinstance(item, tuple):
                    # (format, base64 data); "jpg" is mapped to the MIME subtype "jpeg"
                    image_format = "jpeg" if item[0].lower() == "jpg" else item[0].lower()
                    parts.append(
                        Part.from_bytes(data=base64.b64decode(item[1]), mime_type=f"image/{image_format}")
                    )
                elif isinstance(item, str):
                    parts.append(Part.from_text(text=item))
                else:
                    raise RuntimeError("无法触及的代码:请使用MessageBuilder类构建消息对象")
        else:
            # Previously `content` stayed unbound here and failed later with UnboundLocalError.
            raise RuntimeError("无法触及的代码:请使用MessageBuilder类构建消息对象")

        return Content(role=role, parts=parts)

    temp_list: list[ContentUnion] = []
    system_instructions: list[str] = []
    for message in messages:
        if message.role == RoleType.System:
            if isinstance(message.content, str):
                system_instructions.append(message.content)
            else:
                raise ValueError("你tm怎么往system里面塞图片base64?")
        elif message.role == RoleType.Tool:
            # Tool messages are validated only; they are not added to the contents.
            if not message.tool_call_id:
                raise ValueError("无法触及的代码:请使用MessageBuilder类构建消息对象")
        else:
            temp_list.append(_convert_message_item(message))

    # Return None instead of an empty list when there are no system messages.
    return temp_list, (system_instructions or None)
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_tool_options(tool_options: list[ToolOption]) -> list[FunctionDeclaration]:
    """
    Convert tool options into the format required by the Gemini API.
    :param tool_options: list of tool options
    :return: list of converted ``FunctionDeclaration`` objects
    """

    def _param_schema(param: ToolParam) -> dict:
        """Build the schema dict for one tool parameter; ``enum`` is added only when set."""
        schema: dict[str, Any] = {
            "type": param.param_type.value,
            "description": param.description,
        }
        if param.enum_values:
            schema["enum"] = param.enum_values
        return schema

    def _declaration(option: ToolOption) -> FunctionDeclaration:
        """Build the ``FunctionDeclaration`` for one tool option."""
        fields: dict[str, Any] = {
            "name": option.name,
            "description": option.description,
        }
        if option.params:
            properties = {}
            required = []
            for param in option.params:
                properties[param.name] = _param_schema(param)
                if param.required:
                    required.append(param.name)
            fields["parameters"] = {
                "type": "object",
                "properties": properties,
                "required": required,
            }
        return FunctionDeclaration(**fields)

    declarations = []
    for option in tool_options:
        declarations.append(_declaration(option))
    return declarations
|
||||||
|
|
||||||
|
|
||||||
|
def _process_delta(
|
||||||
|
delta: GenerateContentResponse,
|
||||||
|
fc_delta_buffer: io.StringIO,
|
||||||
|
tool_calls_buffer: list[tuple[str, str, dict[str, Any]]],
|
||||||
|
):
|
||||||
|
if not hasattr(delta, "candidates") or not delta.candidates:
|
||||||
|
raise RespParseException(delta, "响应解析失败,缺失candidates字段")
|
||||||
|
|
||||||
|
if delta.text:
|
||||||
|
fc_delta_buffer.write(delta.text)
|
||||||
|
|
||||||
|
if delta.function_calls: # 为什么不用hasattr呢,是因为这个属性一定有,即使是个空的
|
||||||
|
for call in delta.function_calls:
|
||||||
|
try:
|
||||||
|
if not isinstance(call.args, dict): # gemini返回的function call参数就是dict格式的了
|
||||||
|
raise RespParseException(delta, "响应解析失败,工具调用参数无法解析为字典类型")
|
||||||
|
if not call.id or not call.name:
|
||||||
|
raise RespParseException(delta, "响应解析失败,工具调用缺失id或name字段")
|
||||||
|
tool_calls_buffer.append(
|
||||||
|
(
|
||||||
|
call.id,
|
||||||
|
call.name,
|
||||||
|
call.args or {}, # 如果args是None,则转换为一个空字典
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
raise RespParseException(delta, "响应解析失败,无法解析工具调用参数") from e
|
||||||
|
|
||||||
|
|
||||||
|
def _build_stream_api_resp(
    _fc_delta_buffer: io.StringIO,
    _tool_calls_buffer: list[tuple[str, str, dict]],
) -> APIResponse:
    """
    Assemble the final APIResponse from the buffers filled while streaming.

    The text buffer is always closed by this function.

    :param _fc_delta_buffer: buffer holding the accumulated response text
    :param _tool_calls_buffer: accumulated ``(id, name, arguments)`` tuples
    :return: the assembled APIResponse
    :raises RespParseException: a tool call's arguments are neither None nor a dict
    """
    # sourcery skip: simplify-len-comparison, use-assigned-variable
    response = APIResponse()

    # Copy the text out only when something was written, then release the buffer.
    has_text = _fc_delta_buffer.tell() > 0
    if has_text:
        response.content = _fc_delta_buffer.getvalue()
    _fc_delta_buffer.close()

    if not _tool_calls_buffer:
        return response

    # Turn the buffered tuples into ToolCall objects.
    response.tool_calls = []
    for call_id, function_name, arguments_buffer in _tool_calls_buffer:
        if arguments_buffer is not None and not isinstance(arguments_buffer, dict):
            raise RespParseException(
                None,
                f"响应解析失败,工具调用参数无法解析为字典类型。工具调用参数原始响应:\n{arguments_buffer}",
            )
        response.tool_calls.append(ToolCall(call_id, function_name, arguments_buffer))

    return response
|
||||||
|
|
||||||
|
|
||||||
|
async def _default_stream_response_handler(
    resp_stream: AsyncIterator[GenerateContentResponse],
    interrupt_flag: asyncio.Event | None,
) -> tuple[APIResponse, Optional[tuple[int, int, int]]]:
    """
    Default handler for Gemini streamed responses.

    :param resp_stream: async iterator of response chunks; it is consumed by this call
    :param interrupt_flag: optional event; when set, processing stops with ReqAbortException
    :return: ``(APIResponse, usage)`` where usage is ``(prompt, completion, total)`` token
        counts, or None when no chunk carried usage metadata
    :raises ReqAbortException: the interrupt flag was set while consuming the stream
    """
    _fc_delta_buffer = io.StringIO()  # accumulates the response text
    _tool_calls_buffer: list[tuple[str, str, dict]] = []  # accumulates the tool calls
    _usage_record = None  # latest usage counts seen

    try:
        async for chunk in resp_stream:
            # Stop as soon as the caller signals an interrupt.
            if interrupt_flag and interrupt_flag.is_set():
                raise ReqAbortException("请求被外部信号中断")

            _process_delta(
                chunk,
                _fc_delta_buffer,
                _tool_calls_buffer,
            )

            if chunk.usage_metadata:
                # Completion count is candidates + thoughts; missing counts are treated as 0.
                _usage_record = (
                    chunk.usage_metadata.prompt_token_count or 0,
                    (chunk.usage_metadata.candidates_token_count or 0)
                    + (chunk.usage_metadata.thoughts_token_count or 0),
                    chunk.usage_metadata.total_token_count or 0,
                )

        return _build_stream_api_resp(
            _fc_delta_buffer,
            _tool_calls_buffer,
        ), _usage_record
    finally:
        # Close the buffer on every exit path (interrupt, parse failure, cancellation),
        # not only when the final build step fails.
        if not _fc_delta_buffer.closed:
            _fc_delta_buffer.close()
|
||||||
|
|
||||||
|
|
||||||
|
def _default_normal_response_parser(
    resp: GenerateContentResponse,
) -> tuple[APIResponse, Optional[tuple[int, int, int]]]:
    """
    Parse a non-streamed Gemini response into an APIResponse.

    :param resp: the response object
    :return: ``(APIResponse, usage)`` where usage is ``(prompt, completion, total)`` token
        counts, or None when the response has no usage metadata
    :raises RespParseException: the response has no candidates, or a function call cannot be parsed
    """
    api_response = APIResponse()

    if not hasattr(resp, "candidates") or not resp.candidates:
        raise RespParseException(resp, "响应解析失败,缺失candidates字段")

    # Collect the text of "thought" parts as reasoning content. This is best-effort:
    # a failure here is logged and the rest of the response is still parsed.
    try:
        if resp.candidates[0].content and resp.candidates[0].content.parts:
            for part in resp.candidates[0].content.parts:
                if not part.text:
                    continue
                if part.thought:
                    api_response.reasoning_content = (api_response.reasoning_content or "") + part.text
    except Exception as e:
        logger.warning(f"解析思考内容时发生错误: {e},跳过解析")

    if resp.text:
        api_response.content = resp.text

    if resp.function_calls:
        api_response.tool_calls = []
        for call in resp.function_calls:
            try:
                # None means "no arguments" and becomes {} below; anything else must be a dict.
                if call.args is not None and not isinstance(call.args, dict):
                    raise RespParseException(resp, "响应解析失败,工具调用参数无法解析为字典类型")
                if not call.name:
                    raise RespParseException(resp, "响应解析失败,工具调用缺失name字段")
                # A placeholder id is used when the response does not provide one.
                api_response.tool_calls.append(ToolCall(call.id or "gemini-tool_call", call.name, call.args or {}))
            except Exception as e:
                raise RespParseException(resp, "响应解析失败,无法解析工具调用参数") from e

    if resp.usage_metadata:
        # Completion count is candidates + thoughts; missing counts are treated as 0.
        _usage_record = (
            resp.usage_metadata.prompt_token_count or 0,
            (resp.usage_metadata.candidates_token_count or 0) + (resp.usage_metadata.thoughts_token_count or 0),
            resp.usage_metadata.total_token_count or 0,
        )
    else:
        _usage_record = None

    api_response.raw_data = resp

    return api_response, _usage_record
|
||||||
|
|
||||||
|
|
||||||
|
@client_registry.register_client_class("gemini")
class GeminiClient(BaseClient):
    """Gemini client built on ``google.genai``; registered under the client type "gemini"."""

    client: genai.Client

    def __init__(self, api_provider: APIProvider):
        super().__init__(api_provider)
        self.client = genai.Client(
            api_key=api_provider.api_key,
        )  # Unlike the OpenAI client: per the original author, gemini decides on retries itself

    @staticmethod
    async def _wait_for_request(
        req_task: asyncio.Task,
        interrupt_flag: asyncio.Event | None,
        poll_interval: float,
    ) -> Any:
        """Poll ``req_task`` until it finishes; cancel it and raise ReqAbortException on interrupt."""
        while not req_task.done():
            if interrupt_flag and interrupt_flag.is_set():
                # The interrupt flag is set: cancel the request and report the abort.
                req_task.cancel()
                raise ReqAbortException("请求被外部信号中断")
            await asyncio.sleep(poll_interval)  # re-check task and interrupt state after a pause
        return req_task.result()

    async def get_response(
        self,
        model_info: ModelInfo,
        message_list: list[Message],
        tool_options: list[ToolOption] | None = None,
        max_tokens: int = 1024,
        temperature: float = 0.4,
        response_format: RespFormat | None = None,
        stream_response_handler: Optional[
            Callable[
                [AsyncIterator[GenerateContentResponse], asyncio.Event | None],
                Coroutine[Any, Any, tuple[APIResponse, Optional[tuple[int, int, int]]]],
            ]
        ] = None,
        async_response_parser: Optional[
            Callable[[GenerateContentResponse], tuple[APIResponse, Optional[tuple[int, int, int]]]]
        ] = None,
        interrupt_flag: asyncio.Event | None = None,
        extra_params: dict[str, Any] | None = None,
    ) -> APIResponse:
        """
        Get a chat response.

        Args:
            model_info: model information
            message_list: the conversation messages
            tool_options: tool options (optional, defaults to None)
            max_tokens: maximum number of output tokens (optional, defaults to 1024)
            temperature: sampling temperature (optional, defaults to 0.4)
            response_format: response format; TEXT maps to text/plain, JSON_OBJ and JSON_SCHEMA
                map to application/json with ``response_format.to_dict()`` as the schema
            stream_response_handler: handler for streamed responses (optional, defaults to
                ``_default_stream_response_handler``)
            async_response_parser: parser for non-streamed responses (optional, defaults to
                ``_default_normal_response_parser``)
            interrupt_flag: event used to signal interruption (optional, defaults to None)
            extra_params: additional parameters; only ``"thinking_budget"`` is read here

        Returns:
            APIResponse with the response content, reasoning content, tool calls, and usage

        Raises:
            ReqAbortException: the interrupt flag was set
            RespParseException: the response could not be parsed
            RespNotOkException: the API returned a client or server error
            ValueError: the SDK reported a function-call/tool error
            NetworkConnectionError: any other failure while performing the request
        """
        if stream_response_handler is None:
            stream_response_handler = _default_stream_response_handler

        if async_response_parser is None:
            async_response_parser = _default_normal_response_parser

        # Convert messages and tool options into the Gemini formats.
        contents, system_instructions = _convert_messages(message_list)
        tools = _convert_tool_options(tool_options) if tool_options else None

        generation_config_dict: dict[str, Any] = {
            "max_output_tokens": max_tokens,
            "temperature": temperature,
            "response_modalities": ["TEXT"],
            "thinking_config": ThinkingConfig(
                include_thoughts=True,
                thinking_budget=(
                    extra_params["thinking_budget"]
                    if extra_params and "thinking_budget" in extra_params
                    else int(max_tokens / 2)  # default: half of max_tokens, to avoid empty replies
                ),
            ),
            "safety_settings": gemini_safe_settings,  # guards against empty replies
        }
        if tools:
            # The config expects a list of Tool objects, not a single Tool.
            generation_config_dict["tools"] = [Tool(function_declarations=tools)]
        if system_instructions:
            # The config field is named `system_instruction` (singular).
            generation_config_dict["system_instruction"] = system_instructions
        if response_format and response_format.format_type == RespFormatType.TEXT:
            generation_config_dict["response_mime_type"] = "text/plain"
        elif response_format and response_format.format_type in (RespFormatType.JSON_OBJ, RespFormatType.JSON_SCHEMA):
            generation_config_dict["response_mime_type"] = "application/json"
            generation_config_dict["response_schema"] = response_format.to_dict()

        generation_config = GenerateContentConfig(**generation_config_dict)

        try:
            if model_info.force_stream_mode:
                req_task = asyncio.create_task(
                    self.client.aio.models.generate_content_stream(
                        model=model_info.model_identifier,
                        contents=contents,
                        config=generation_config,
                    )
                )
                stream = await self._wait_for_request(req_task, interrupt_flag, 0.1)
                resp, usage_record = await stream_response_handler(stream, interrupt_flag)
            else:
                req_task = asyncio.create_task(
                    self.client.aio.models.generate_content(
                        model=model_info.model_identifier,
                        contents=contents,
                        config=generation_config,
                    )
                )
                raw_response = await self._wait_for_request(req_task, interrupt_flag, 0.5)
                resp, usage_record = async_response_parser(raw_response)
        except (ReqAbortException, RespParseException):
            # Raised on purpose inside this block; do not re-label them as network errors.
            raise
        except (ClientError, ServerError) as e:
            # Re-wrap ClientError and ServerError as RespNotOkException.
            raise RespNotOkException(e.code, e.message) from None
        except (
            UnknownFunctionCallArgumentError,
            UnsupportedFunctionError,
            FunctionInvocationError,
        ) as e:
            raise ValueError(f"工具类型错误:请检查工具选项和参数:{str(e)}") from None
        except Exception as e:
            raise NetworkConnectionError() from e

        if usage_record:
            resp.usage = UsageRecord(
                model_name=model_info.name,
                provider_name=model_info.api_provider,
                prompt_tokens=usage_record[0],
                completion_tokens=usage_record[1],
                total_tokens=usage_record[2],
            )

        return resp

    async def get_embedding(
        self,
        model_info: ModelInfo,
        embedding_input: str,
        extra_params: dict[str, Any] | None = None,
    ) -> APIResponse:
        """
        Get a text embedding.
        :param model_info: model information
        :param embedding_input: the text to embed
        :param extra_params: accepted for interface compatibility; not used here
        :return: the embedding response
        :raises RespNotOkException: the API returned a client or server error
        :raises NetworkConnectionError: any other failure while performing the request
        :raises RespParseException: the response contains no embeddings
        """
        try:
            raw_response: EmbedContentResponse = await self.client.aio.models.embed_content(
                model=model_info.model_identifier,
                contents=embedding_input,
                config=EmbedContentConfig(task_type="SEMANTIC_SIMILARITY"),
            )
        except (ClientError, ServerError) as e:
            # Re-wrap ClientError and ServerError as RespNotOkException.
            raise RespNotOkException(e.code) from None
        except Exception as e:
            raise NetworkConnectionError() from e

        response = APIResponse()

        # Extract the first embedding.
        if hasattr(raw_response, "embeddings") and raw_response.embeddings:
            response.embedding = raw_response.embeddings[0].values
        else:
            raise RespParseException(raw_response, "响应解析失败,缺失embeddings字段")

        # Usage is recorded as the character length of the input, not an API-reported token count.
        response.usage = UsageRecord(
            model_name=model_info.name,
            provider_name=model_info.api_provider,
            prompt_tokens=len(embedding_input),
            completion_tokens=0,
            total_tokens=len(embedding_input),
        )

        return response

    async def get_audio_transcriptions(
        self, model_info: ModelInfo, audio_base64: str, extra_params: dict[str, Any] | None = None
    ) -> APIResponse:
        """
        Get an audio transcription.

        Declared ``async`` to match ``BaseClient.get_audio_transcriptions`` and uses the
        async client so the request does not block the event loop.

        :param model_info: model information
        :param audio_base64: base64-encoded audio; it is sent with MIME type audio/wav
        :param extra_params: additional parameters (optional); only ``"thinking_budget"`` is read
        :return: the transcription response
        :raises RespNotOkException: the API returned a client or server error
        :raises NetworkConnectionError: any other failure while performing the request
        :raises RespParseException: the response could not be parsed
        """
        generation_config_dict = {
            "max_output_tokens": 2048,
            "response_modalities": ["TEXT"],
            "thinking_config": ThinkingConfig(
                include_thoughts=True,
                thinking_budget=(
                    extra_params["thinking_budget"] if extra_params and "thinking_budget" in extra_params else 1024
                ),
            ),
            "safety_settings": gemini_safe_settings,
        }
        generate_content_config = GenerateContentConfig(**generation_config_dict)
        prompt = "Generate a transcript of the speech. The language of the transcript should **match the language of the speech**."
        try:
            raw_response: GenerateContentResponse = await self.client.aio.models.generate_content(
                model=model_info.model_identifier,
                contents=[
                    Content(
                        role="user",
                        parts=[
                            Part.from_text(text=prompt),
                            Part.from_bytes(data=base64.b64decode(audio_base64), mime_type="audio/wav"),
                        ],
                    )
                ],
                config=generate_content_config,
            )
        except (ClientError, ServerError) as e:
            # Re-wrap ClientError and ServerError as RespNotOkException.
            raise RespNotOkException(e.code) from None
        except Exception as e:
            raise NetworkConnectionError() from e

        # Parse outside the try block so a parse failure is not reported as a network error.
        resp, usage_record = _default_normal_response_parser(raw_response)

        if usage_record:
            resp.usage = UsageRecord(
                model_name=model_info.name,
                provider_name=model_info.api_provider,
                prompt_tokens=usage_record[0],
                completion_tokens=usage_record[1],
                total_tokens=usage_record[2],
            )

        return resp

    def get_support_image_formats(self) -> list[str]:
        """
        Get the supported image formats.
        :return: list of supported image format names
        """
        return ["png", "jpg", "jpeg", "webp", "heic", "heif"]
|
||||||
591
src/llm_models/model_client/openai_client.py
Normal file
591
src/llm_models/model_client/openai_client.py
Normal file
@@ -0,0 +1,591 @@
|
|||||||
|
import asyncio
|
||||||
|
import io
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
import base64
|
||||||
|
from collections.abc import Iterable
|
||||||
|
from typing import Callable, Any, Coroutine, Optional
|
||||||
|
from json_repair import repair_json
|
||||||
|
|
||||||
|
from openai import (
|
||||||
|
AsyncOpenAI,
|
||||||
|
APIConnectionError,
|
||||||
|
APIStatusError,
|
||||||
|
NOT_GIVEN,
|
||||||
|
AsyncStream,
|
||||||
|
)
|
||||||
|
from openai.types.chat import (
|
||||||
|
ChatCompletion,
|
||||||
|
ChatCompletionChunk,
|
||||||
|
ChatCompletionMessageParam,
|
||||||
|
ChatCompletionToolParam,
|
||||||
|
)
|
||||||
|
from openai.types.chat.chat_completion_chunk import ChoiceDelta
|
||||||
|
|
||||||
|
from src.config.api_ada_configs import ModelInfo, APIProvider
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
from .base_client import APIResponse, UsageRecord, BaseClient, client_registry
|
||||||
|
from ..exceptions import (
|
||||||
|
RespParseException,
|
||||||
|
NetworkConnectionError,
|
||||||
|
RespNotOkException,
|
||||||
|
ReqAbortException,
|
||||||
|
)
|
||||||
|
from ..payload_content.message import Message, RoleType
|
||||||
|
from ..payload_content.resp_format import RespFormat
|
||||||
|
from ..payload_content.tool_option import ToolOption, ToolParam, ToolCall
|
||||||
|
|
||||||
|
# Logger for this module, obtained from the project's get_logger helper.
logger = get_logger("OpenAI客户端")
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_messages(messages: list[Message]) -> list[ChatCompletionMessageParam]:
    """
    Convert messages into the format required by the OpenAI API.
    :param messages: list of messages
    :return: list of converted message dicts
    """

    def _convert_content_item(item: Any) -> dict[str, Any]:
        """Convert one element of a list-style content: an image tuple or a text string."""
        if isinstance(item, tuple):
            # (format, base64 data) becomes a data-URL image part
            return {
                "type": "image_url",
                "image_url": {"url": f"data:image/{item[0].lower()};base64,{item[1]}"},
            }
        if isinstance(item, str):
            return {"type": "text", "text": item}
        raise RuntimeError("无法触及的代码:请使用MessageBuilder类构建消息对象")

    def _convert_message_item(message: Message) -> ChatCompletionMessageParam:
        """
        Convert one message.
        :param message: message object
        :return: the converted message dict
        """
        # Content is either the plain string or a list of converted parts.
        content: str | list[dict[str, Any]]
        if isinstance(message.content, str):
            content = message.content
        elif isinstance(message.content, list):
            content = []
            for item in message.content:
                content.append(_convert_content_item(item))

        converted = {
            "role": message.role.value,
            "content": content,
        }

        # Tool messages must also carry the id of the call they answer.
        if message.role == RoleType.Tool:
            if not message.tool_call_id:
                raise ValueError("无法触及的代码:请使用MessageBuilder类构建消息对象")
            converted["tool_call_id"] = message.tool_call_id

        return converted  # type: ignore

    result = []
    for message in messages:
        result.append(_convert_message_item(message))
    return result
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_tool_options(tool_options: list[ToolOption]) -> list[dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
转换工具选项格式 - 将工具选项转换为OpenAI API所需的格式
|
||||||
|
:param tool_options: 工具选项列表
|
||||||
|
:return: 转换后的工具选项列表
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _convert_tool_param(tool_option_param: ToolParam) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
转换单个工具参数格式
|
||||||
|
:param tool_option_param: 工具参数对象
|
||||||
|
:return: 转换后的工具参数字典
|
||||||
|
"""
|
||||||
|
return_dict: dict[str, Any] = {
|
||||||
|
"type": tool_option_param.param_type.value,
|
||||||
|
"description": tool_option_param.description,
|
||||||
|
}
|
||||||
|
if tool_option_param.enum_values:
|
||||||
|
return_dict["enum"] = tool_option_param.enum_values
|
||||||
|
return return_dict
|
||||||
|
|
||||||
|
def _convert_tool_option_item(tool_option: ToolOption) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
转换单个工具项格式
|
||||||
|
:param tool_option: 工具选项对象
|
||||||
|
:return: 转换后的工具选项字典
|
||||||
|
"""
|
||||||
|
ret: dict[str, Any] = {
|
||||||
|
"name": tool_option.name,
|
||||||
|
"description": tool_option.description,
|
||||||
|
}
|
||||||
|
if tool_option.params:
|
||||||
|
ret["parameters"] = {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {param.name: _convert_tool_param(param) for param in tool_option.params},
|
||||||
|
"required": [param.name for param in tool_option.params if param.required],
|
||||||
|
}
|
||||||
|
return ret
|
||||||
|
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": _convert_tool_option_item(tool_option),
|
||||||
|
}
|
||||||
|
for tool_option in tool_options
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _process_delta(
|
||||||
|
delta: ChoiceDelta,
|
||||||
|
has_rc_attr_flag: bool,
|
||||||
|
in_rc_flag: bool,
|
||||||
|
rc_delta_buffer: io.StringIO,
|
||||||
|
fc_delta_buffer: io.StringIO,
|
||||||
|
tool_calls_buffer: list[tuple[str, str, io.StringIO]],
|
||||||
|
) -> bool:
|
||||||
|
# 接收content
|
||||||
|
if has_rc_attr_flag:
|
||||||
|
# 有独立的推理内容块,则无需考虑content内容的判读
|
||||||
|
if hasattr(delta, "reasoning_content") and delta.reasoning_content: # type: ignore
|
||||||
|
# 如果有推理内容,则将其写入推理内容缓冲区
|
||||||
|
assert isinstance(delta.reasoning_content, str) # type: ignore
|
||||||
|
rc_delta_buffer.write(delta.reasoning_content) # type: ignore
|
||||||
|
elif delta.content:
|
||||||
|
# 如果有正式内容,则将其写入正式内容缓冲区
|
||||||
|
fc_delta_buffer.write(delta.content)
|
||||||
|
elif hasattr(delta, "content") and delta.content is not None:
|
||||||
|
# 没有独立的推理内容块,但有正式内容
|
||||||
|
if in_rc_flag:
|
||||||
|
# 当前在推理内容块中
|
||||||
|
if delta.content == "</think>":
|
||||||
|
# 如果当前内容是</think>,则将其视为推理内容的结束标记,退出推理内容块
|
||||||
|
in_rc_flag = False
|
||||||
|
else:
|
||||||
|
# 其他情况视为推理内容,加入推理内容缓冲区
|
||||||
|
rc_delta_buffer.write(delta.content)
|
||||||
|
elif delta.content == "<think>" and not fc_delta_buffer.getvalue():
|
||||||
|
# 如果当前内容是<think>,且正式内容缓冲区为空,说明<think>为输出的首个token
|
||||||
|
# 则将其视为推理内容的开始标记,进入推理内容块
|
||||||
|
in_rc_flag = True
|
||||||
|
else:
|
||||||
|
# 其他情况视为正式内容,加入正式内容缓冲区
|
||||||
|
fc_delta_buffer.write(delta.content)
|
||||||
|
# 接收tool_calls
|
||||||
|
if hasattr(delta, "tool_calls") and delta.tool_calls:
|
||||||
|
tool_call_delta = delta.tool_calls[0]
|
||||||
|
|
||||||
|
if tool_call_delta.index >= len(tool_calls_buffer):
|
||||||
|
# 调用索引号大于等于缓冲区长度,说明是新的工具调用
|
||||||
|
if tool_call_delta.id and tool_call_delta.function and tool_call_delta.function.name:
|
||||||
|
tool_calls_buffer.append(
|
||||||
|
(
|
||||||
|
tool_call_delta.id,
|
||||||
|
tool_call_delta.function.name,
|
||||||
|
io.StringIO(),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warning("工具调用索引号大于等于缓冲区长度,但缺少ID或函数信息。")
|
||||||
|
|
||||||
|
if tool_call_delta.function and tool_call_delta.function.arguments:
|
||||||
|
# 如果有工具调用参数,则添加到对应的工具调用的参数串缓冲区中
|
||||||
|
tool_calls_buffer[tool_call_delta.index][2].write(tool_call_delta.function.arguments)
|
||||||
|
|
||||||
|
return in_rc_flag
|
||||||
|
|
||||||
|
|
||||||
|
def _build_stream_api_resp(
    _fc_delta_buffer: io.StringIO,
    _rc_delta_buffer: io.StringIO,
    _tool_calls_buffer: list[tuple[str, str, io.StringIO]],
) -> APIResponse:
    """
    Assemble the final APIResponse from the buffers filled while streaming.

    Non-empty reasoning / content buffers are copied into the response and
    closed; every tool-call argument buffer is closed after being read.

    :param _fc_delta_buffer: buffered answer text
    :param _rc_delta_buffer: buffered reasoning text
    :param _tool_calls_buffer: ``(call id, function name, argument buffer)`` entries
    :return: the populated APIResponse
    :raises RespParseException: if tool-call arguments cannot be parsed as JSON
        or do not decode to a dict
    """

    def _parse_arguments(raw_arg_data: str) -> dict:
        # repair_json is applied first, so slightly malformed JSON is tolerated.
        try:
            parsed = json.loads(repair_json(raw_arg_data))
        except json.JSONDecodeError as e:
            raise RespParseException(
                None,
                f"响应解析失败,无法解析工具调用参数。工具调用参数原始响应:{raw_arg_data}",
            ) from e
        if not isinstance(parsed, dict):
            raise RespParseException(
                None,
                f"响应解析失败,工具调用参数无法解析为字典类型。工具调用参数原始响应:\n{raw_arg_data}",
            )
        return parsed

    resp = APIResponse()

    # tell() > 0 means something was written to the buffer.
    if _rc_delta_buffer.tell() > 0:
        resp.reasoning_content = _rc_delta_buffer.getvalue()
        _rc_delta_buffer.close()

    if _fc_delta_buffer.tell() > 0:
        resp.content = _fc_delta_buffer.getvalue()
        _fc_delta_buffer.close()

    if _tool_calls_buffer:
        resp.tool_calls = []
        for call_id, function_name, arguments_buffer in _tool_calls_buffer:
            has_arguments = arguments_buffer.tell() > 0
            raw_arg_data = arguments_buffer.getvalue() if has_arguments else None
            arguments_buffer.close()
            # A call without any argument text is recorded with arguments=None.
            arguments = _parse_arguments(raw_arg_data) if has_arguments else None
            resp.tool_calls.append(ToolCall(call_id, function_name, arguments))

    return resp
|
||||||
|
|
||||||
|
|
||||||
|
async def _default_stream_response_handler(
    resp_stream: AsyncStream[ChatCompletionChunk],
    interrupt_flag: asyncio.Event | None,
) -> tuple[APIResponse, Optional[tuple[int, int, int]]]:
    """
    Default handler for streamed chat completions.

    Consumes the stream chunk by chunk, accumulating reasoning text, answer
    text and tool calls, then builds the final response.

    :param resp_stream: the streamed response
    :param interrupt_flag: optional event; when set, consumption stops and
        ReqAbortException is raised
    :return: ``(APIResponse, usage)`` where usage is
        ``(prompt_tokens, completion_tokens, total_tokens)`` or None when the
        stream reported no usage
    :raises ReqAbortException: if ``interrupt_flag`` is set while streaming
    """
    _has_rc_attr_flag = False  # whether a dedicated reasoning_content attribute was seen
    _in_rc_flag = False  # whether we are inside a <think> section
    _rc_delta_buffer = io.StringIO()  # reasoning text received so far
    _fc_delta_buffer = io.StringIO()  # answer text received so far
    _tool_calls_buffer: list[tuple[str, str, io.StringIO]] = []  # tool calls received so far
    _usage_record = None  # latest usage figures reported by the stream

    def _insure_buffer_closed():
        # Close every buffer that is still open; safe to call repeatedly.
        if _rc_delta_buffer and not _rc_delta_buffer.closed:
            _rc_delta_buffer.close()
        if _fc_delta_buffer and not _fc_delta_buffer.closed:
            _fc_delta_buffer.close()
        for _, _, buffer in _tool_calls_buffer:
            if buffer and not buffer.closed:
                buffer.close()

    def _usage_tuple(usage) -> tuple[int, int, int]:
        return (
            usage.prompt_tokens or 0,
            usage.completion_tokens or 0,
            usage.total_tokens or 0,
        )

    try:
        async for event in resp_stream:
            if interrupt_flag and interrupt_flag.is_set():
                raise ReqAbortException("请求被外部信号中断")

            # Frames without choices (e.g. usage-only frames) carry no delta.
            if not getattr(event, "choices", None):
                if getattr(event, "usage", None):
                    _usage_record = _usage_tuple(event.usage)
                continue

            delta = event.choices[0].delta

            if getattr(delta, "reasoning_content", None):
                # From now on reasoning is read from its own attribute.
                _has_rc_attr_flag = True

            _in_rc_flag = _process_delta(
                delta,
                _has_rc_attr_flag,
                _in_rc_flag,
                _rc_delta_buffer,
                _fc_delta_buffer,
                _tool_calls_buffer,
            )

            if event.usage:
                _usage_record = _usage_tuple(event.usage)

        return _build_stream_api_resp(
            _fc_delta_buffer,
            _rc_delta_buffer,
            _tool_calls_buffer,
        ), _usage_record
    finally:
        # Runs on success, abort and any error, so no buffer is left open.
        # The response has already copied the buffered text at this point.
        _insure_buffer_closed()
|
||||||
|
|
||||||
|
|
||||||
|
# Regular expression used to split a completion into reasoning and answer.
# Alternatives, in order:
#   1. "<think>...</think>rest" -> groups "think" and "content"
#   2. "<think>..." (never closed) -> group "think_unclosed"
#   3. anything else (non-empty)   -> group "content_only"
pattern = re.compile(
    r"<think>(?P<think>.*?)</think>(?P<content>.*)|<think>(?P<think_unclosed>.*)|(?P<content_only>.+)",
    flags=re.DOTALL,
)
|
||||||
|
|
||||||
|
|
||||||
|
def _default_normal_response_parser(
    resp: ChatCompletion,
) -> tuple[APIResponse, Optional[tuple[int, int, int]]]:
    """
    Default parser for non-streamed chat completions.

    :param resp: the completion returned by the API
    :return: ``(APIResponse, usage)`` where usage is
        ``(prompt_tokens, completion_tokens, total_tokens)`` or None when the
        response carries no usage
    :raises RespParseException: if ``choices`` is missing, None or empty, if the
        content cannot be split into reasoning and answer, or if tool-call
        arguments are not a JSON object
    """
    api_response = APIResponse()

    # Treat a missing, None or empty ``choices`` the same way, mirroring the
    # empty-choices guard used for streamed chunks.
    choices = getattr(resp, "choices", None)
    if not choices:
        raise RespParseException(resp, "响应解析失败,缺失choices字段")
    message_part = choices[0].message

    reasoning_attr = getattr(message_part, "reasoning_content", None)
    if reasoning_attr:
        # Reasoning was delivered in its own field; content is used as is.
        api_response.content = message_part.content
        api_response.reasoning_content = reasoning_attr
    elif message_part.content:
        # Otherwise look for <think> tags inside the content.
        match = pattern.match(message_part.content)
        if not match:
            raise RespParseException(resp, "响应解析失败,无法捕获推理内容和输出内容")
        if match.group("think") is not None:
            result = match.group("think").strip(), match.group("content").strip()
        elif match.group("think_unclosed") is not None:
            result = match.group("think_unclosed").strip(), None
        else:
            result = None, match.group("content_only").strip()
        api_response.reasoning_content, api_response.content = result

    # Tool calls
    if message_part.tool_calls:
        api_response.tool_calls = []
        for call in message_part.tool_calls:
            try:
                arguments = json.loads(repair_json(call.function.arguments))
                if not isinstance(arguments, dict):
                    raise RespParseException(resp, "响应解析失败,工具调用参数无法解析为字典类型")
                api_response.tool_calls.append(ToolCall(call.id, call.function.name, arguments))
            except json.JSONDecodeError as e:
                raise RespParseException(resp, "响应解析失败,无法解析工具调用参数") from e

    # Usage
    if resp.usage:
        _usage_record = (
            resp.usage.prompt_tokens or 0,
            resp.usage.completion_tokens or 0,
            resp.usage.total_tokens or 0,
        )
    else:
        _usage_record = None

    # Keep the raw response available to callers.
    api_response.raw_data = resp

    return api_response, _usage_record
|
||||||
|
|
||||||
|
|
||||||
|
@client_registry.register_client_class("openai")
class OpenaiClient(BaseClient):
    """Client for OpenAI-style APIs: chat completions, embeddings and audio transcription."""

    def __init__(self, api_provider: APIProvider):
        """
        Create the underlying AsyncOpenAI client from the provider settings.

        :param api_provider: provider configuration (base URL and API key are read from it)
        """
        super().__init__(api_provider)
        # SDK-level retries are switched off (max_retries=0).
        self.client: AsyncOpenAI = AsyncOpenAI(
            base_url=api_provider.base_url,
            api_key=api_provider.api_key,
            max_retries=0,
        )

    @staticmethod
    async def _wait_for_request(req_task: asyncio.Task, interrupt_flag: asyncio.Event | None) -> Any:
        """
        Wait for a request task, polling the interrupt flag every 0.1 seconds.

        :param req_task: task wrapping the pending API request
        :param interrupt_flag: optional event that aborts the wait when set
        :return: the task result (exceptions raised by the task propagate)
        :raises ReqAbortException: if the flag is set before the task finishes;
            the task is cancelled first
        """
        while not req_task.done():
            if interrupt_flag and interrupt_flag.is_set():
                req_task.cancel()
                raise ReqAbortException("请求被外部信号中断")
            await asyncio.sleep(0.1)
        return req_task.result()

    async def get_response(
        self,
        model_info: ModelInfo,
        message_list: list[Message],
        tool_options: list[ToolOption] | None = None,
        max_tokens: int = 1024,
        temperature: float = 0.7,
        response_format: RespFormat | None = None,
        stream_response_handler: Optional[
            Callable[
                [AsyncStream[ChatCompletionChunk], asyncio.Event | None],
                Coroutine[Any, Any, tuple[APIResponse, Optional[tuple[int, int, int]]]],
            ]
        ] = None,
        async_response_parser: Optional[
            Callable[[ChatCompletion], tuple[APIResponse, Optional[tuple[int, int, int]]]]
        ] = None,
        interrupt_flag: asyncio.Event | None = None,
        extra_params: dict[str, Any] | None = None,
    ) -> APIResponse:
        """
        Request a chat completion.

        Args:
            model_info: model to call
            message_list: conversation messages
            tool_options: tools offered to the model (optional)
            max_tokens: maximum number of tokens to generate (default 1024)
            temperature: sampling temperature (default 0.7)
            response_format: accepted for interface compatibility; this method
                always sends ``response_format=NOT_GIVEN``
            stream_response_handler: handler for streamed responses
                (default ``_default_stream_response_handler``)
            async_response_parser: parser for non-streamed responses
                (default ``_default_normal_response_parser``)
            interrupt_flag: event that aborts the request when set (optional)
            extra_params: extra fields sent in the request body (optional)
        Returns:
            The parsed APIResponse; ``usage`` is filled in when the API reported usage.
        Raises:
            ReqAbortException: the interrupt flag was set
            NetworkConnectionError: the API could not be reached
            RespNotOkException: the API answered with an error status
        """
        if stream_response_handler is None:
            stream_response_handler = _default_stream_response_handler

        if async_response_parser is None:
            async_response_parser = _default_normal_response_parser

        # Convert messages and tools into the shapes the OpenAI SDK expects.
        messages: Iterable[ChatCompletionMessageParam] = _convert_messages(message_list)
        tools: Iterable[ChatCompletionToolParam] = _convert_tool_options(tool_options) if tool_options else NOT_GIVEN  # type: ignore

        use_stream = bool(model_info.force_stream_mode)

        try:
            req_task = asyncio.create_task(
                self.client.chat.completions.create(
                    model=model_info.model_identifier,
                    messages=messages,
                    tools=tools,
                    temperature=temperature,
                    max_tokens=max_tokens,
                    stream=use_stream,
                    response_format=NOT_GIVEN,
                    extra_body=extra_params,
                )  # type: ignore
            )
            raw_result = await self._wait_for_request(req_task, interrupt_flag)

            if use_stream:
                resp, usage_record = await stream_response_handler(raw_result, interrupt_flag)
            else:
                resp, usage_record = async_response_parser(raw_result)
        except APIConnectionError as e:
            # Re-wrap as the project's network error.
            raise NetworkConnectionError() from e
        except APIStatusError as e:
            # Re-wrap as the project's non-OK response error.
            raise RespNotOkException(e.status_code, e.message) from e

        if usage_record:
            resp.usage = UsageRecord(
                model_name=model_info.name,
                provider_name=model_info.api_provider,
                prompt_tokens=usage_record[0],
                completion_tokens=usage_record[1],
                total_tokens=usage_record[2],
            )

        return resp

    async def get_embedding(
        self,
        model_info: ModelInfo,
        embedding_input: str,
        extra_params: dict[str, Any] | None = None,
    ) -> APIResponse:
        """
        Request a text embedding.

        :param model_info: model to call
        :param embedding_input: text to embed
        :param extra_params: extra fields sent in the request body (optional)
        :return: APIResponse with ``embedding`` set (and ``usage`` when reported)
        :raises NetworkConnectionError: the API could not be reached
        :raises RespNotOkException: the API answered with an error status
        :raises RespParseException: the response contains no embedding data
        """
        try:
            raw_response = await self.client.embeddings.create(
                model=model_info.model_identifier,
                input=embedding_input,
                extra_body=extra_params,
            )
        except APIConnectionError as e:
            raise NetworkConnectionError() from e
        except APIStatusError as e:
            raise RespNotOkException(e.status_code, e.message) from e

        response = APIResponse()

        # Only the first embedding is used.
        if len(raw_response.data) > 0:
            response.embedding = raw_response.data[0].embedding
        else:
            raise RespParseException(
                raw_response,
                "响应解析失败,缺失嵌入数据。",
            )

        # Usage may be absent or None, and embedding usage objects need not
        # carry completion_tokens, so read it defensively.
        usage = getattr(raw_response, "usage", None)
        if usage:
            response.usage = UsageRecord(
                model_name=model_info.name,
                provider_name=model_info.api_provider,
                prompt_tokens=usage.prompt_tokens or 0,
                completion_tokens=getattr(usage, "completion_tokens", 0) or 0,
                total_tokens=usage.total_tokens or 0,
            )

        return response

    async def get_audio_transcriptions(
        self,
        model_info: ModelInfo,
        audio_base64: str,
        extra_params: dict[str, Any] | None = None,
    ) -> APIResponse:
        """
        Request an audio transcription.

        :param model_info: model to call
        :param audio_base64: base64-encoded audio; it is uploaded under the file name ``audio.wav``
        :param extra_params: extra fields sent in the request body (optional)
        :return: APIResponse with ``content`` set to the transcribed text
        :raises NetworkConnectionError: the API could not be reached
        :raises RespNotOkException: the API answered with an error status
        :raises RespParseException: the response has no ``text`` attribute
        """
        try:
            raw_response = await self.client.audio.transcriptions.create(
                model=model_info.model_identifier,
                file=("audio.wav", io.BytesIO(base64.b64decode(audio_base64))),
                extra_body=extra_params,
            )
        except APIConnectionError as e:
            raise NetworkConnectionError() from e
        except APIStatusError as e:
            raise RespNotOkException(e.status_code, e.message) from e

        response = APIResponse()
        if hasattr(raw_response, "text"):
            response.content = raw_response.text
        else:
            raise RespParseException(
                raw_response,
                "响应解析失败,缺失转录文本。",
            )
        return response

    def get_support_image_formats(self) -> list[str]:
        """
        Return the image formats this client accepts.

        :return: list of format names
        """
        return ["jpg", "jpeg", "png", "webp", "gif"]
|
||||||
3
src/llm_models/payload_content/__init__.py
Normal file
3
src/llm_models/payload_content/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
from .tool_option import ToolCall
|
||||||
|
|
||||||
|
__all__ = ["ToolCall"]
|
||||||
107
src/llm_models/payload_content/message.py
Normal file
107
src/llm_models/payload_content/message.py
Normal file
@@ -0,0 +1,107 @@
|
|||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
|
# 设计这系列类的目的是为未来可能的扩展做准备
|
||||||
|
|
||||||
|
|
||||||
|
class RoleType(Enum):
    """Role of a message author; each value is the role string used in the request payload."""

    System = "system"
    User = "user"
    Assistant = "assistant"
    Tool = "tool"
|
||||||
|
|
||||||
|
|
||||||
|
# Image formats supported by OpenAI.
SUPPORTED_IMAGE_FORMATS = [
    "jpg",
    "jpeg",
    "png",
    "webp",
    "gif",
]
|
||||||
|
|
||||||
|
|
||||||
|
class Message:
    """A single conversation message: role, content and an optional tool call id."""

    def __init__(
        self,
        role: RoleType,
        content: str | list[tuple[str, str] | str],
        tool_call_id: str | None = None,
    ):
        """
        Initialise the message.
        (Messages should be created through MessageBuilder rather than directly.)

        :param role: role of the author
        :param content: either a plain string, or a list whose items are text
            strings or ``(image format, base64 data)`` tuples
        :param tool_call_id: id of the tool call this message answers, if any
        """
        self.tool_call_id: str | None = tool_call_id
        self.content: str | list[tuple[str, str] | str] = content
        self.role: RoleType = role
|
||||||
|
|
||||||
|
|
||||||
|
class MessageBuilder:
    """Fluent builder that validates its input and produces Message objects."""

    def __init__(self):
        self.__role: RoleType = RoleType.User
        self.__content: list[tuple[str, str] | str] = []
        self.__tool_call_id: str | None = None

    def set_role(self, role: RoleType = RoleType.User) -> "MessageBuilder":
        """
        Set the role (defaults to User).

        :param role: role of the author
        :return: this builder
        """
        self.__role = role
        return self

    def add_text_content(self, text: str) -> "MessageBuilder":
        """
        Append a text part.

        :param text: text to append
        :return: this builder
        """
        self.__content.append(text)
        return self

    def add_image_content(
        self,
        image_format: str,
        image_base64: str,
        support_formats: list[str] = SUPPORTED_IMAGE_FORMATS,  # default supported formats
    ) -> "MessageBuilder":
        """
        Append an image part.

        :param image_format: image format; compared case-insensitively against ``support_formats``
        :param image_base64: base64-encoded image data
        :param support_formats: accepted formats (only read, never modified)
        :return: this builder
        :raises ValueError: if the format is not supported or the data is empty
        """
        if image_format.lower() not in support_formats:
            raise ValueError("不受支持的图片格式")
        if not image_base64:
            raise ValueError("图片的base64编码不能为空")
        self.__content.append((image_format, image_base64))
        return self

    def add_tool_call(self, tool_call_id: str) -> "MessageBuilder":
        """
        Attach the id of the tool call being answered (the role must already be Tool).

        :param tool_call_id: tool call id
        :return: this builder
        :raises ValueError: if the role is not Tool or the id is empty
        """
        if self.__role != RoleType.Tool:
            raise ValueError("仅当角色为Tool时才能添加工具调用ID")
        if not tool_call_id:
            raise ValueError("工具调用ID不能为空")
        self.__tool_call_id = tool_call_id
        return self

    def build(self) -> Message:
        """
        Build the message.

        A single text part is stored as a plain string; anything else is
        stored as a list of parts.

        :return: the new Message
        :raises ValueError: if no content was added, or the role is Tool and
            no tool call id was set
        """
        if len(self.__content) == 0:
            raise ValueError("内容不能为空")
        if self.__role == RoleType.Tool and self.__tool_call_id is None:
            raise ValueError("Tool角色的工具调用ID不能为空")

        parts = self.__content
        if len(parts) == 1 and isinstance(parts[0], str):
            payload: str | list[tuple[str, str] | str] = parts[0]
        else:
            # Hand out a copy so that further use of this builder cannot
            # change a message that has already been built.
            payload = list(parts)

        return Message(
            role=self.__role,
            content=payload,
            tool_call_id=self.__tool_call_id,
        )
|
||||||
223
src/llm_models/payload_content/resp_format.py
Normal file
223
src/llm_models/payload_content/resp_format.py
Normal file
@@ -0,0 +1,223 @@
|
|||||||
|
from enum import Enum
|
||||||
|
from typing import Optional, Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from typing_extensions import TypedDict, Required
|
||||||
|
|
||||||
|
|
||||||
|
class RespFormatType(Enum):
    """Kinds of response format that can be requested."""

    # Plain text
    TEXT = "text"
    # JSON object
    JSON_OBJ = "json_object"
    # JSON constrained by a JSON Schema
    JSON_SCHEMA = "json_schema"
|
||||||
|
|
||||||
|
|
||||||
|
class JsonSchema(TypedDict, total=False):
|
||||||
|
name: Required[str]
|
||||||
|
"""
|
||||||
|
The name of the response format.
|
||||||
|
|
||||||
|
Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length
|
||||||
|
of 64.
|
||||||
|
"""
|
||||||
|
|
||||||
|
description: Optional[str]
|
||||||
|
"""
|
||||||
|
A description of what the response format is for, used by the model to determine
|
||||||
|
how to respond in the format.
|
||||||
|
"""
|
||||||
|
|
||||||
|
schema: dict[str, object]
|
||||||
|
"""
|
||||||
|
The schema for the response format, described as a JSON Schema object. Learn how
|
||||||
|
to build JSON schemas [here](https://json-schema.org/).
|
||||||
|
"""
|
||||||
|
|
||||||
|
strict: Optional[bool]
|
||||||
|
"""
|
||||||
|
Whether to enable strict schema adherence when generating the output. If set to
|
||||||
|
true, the model will always follow the exact schema defined in the `schema`
|
||||||
|
field. Only a subset of JSON Schema is supported when `strict` is `true`. To
|
||||||
|
learn more, read the
|
||||||
|
[Structured Outputs guide](https://platform.openai.com/docs/guides/structured-outputs).
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def _json_schema_type_check(instance) -> str | None:
|
||||||
|
if "name" not in instance:
|
||||||
|
return "schema必须包含'name'字段"
|
||||||
|
elif not isinstance(instance["name"], str) or instance["name"].strip() == "":
|
||||||
|
return "schema的'name'字段必须是非空字符串"
|
||||||
|
if "description" in instance and (
|
||||||
|
not isinstance(instance["description"], str)
|
||||||
|
or instance["description"].strip() == ""
|
||||||
|
):
|
||||||
|
return "schema的'description'字段只能填入非空字符串"
|
||||||
|
if "schema" not in instance:
|
||||||
|
return "schema必须包含'schema'字段"
|
||||||
|
elif not isinstance(instance["schema"], dict):
|
||||||
|
return "schema的'schema'字段必须是字典,详见https://json-schema.org/"
|
||||||
|
if "strict" in instance and not isinstance(instance["strict"], bool):
|
||||||
|
return "schema的'strict'字段只能填入布尔值"
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _remove_title(schema: dict[str, Any] | list[Any]) -> dict[str, Any] | list[Any]:
|
||||||
|
"""
|
||||||
|
递归移除JSON Schema中的title字段
|
||||||
|
"""
|
||||||
|
if isinstance(schema, list):
|
||||||
|
# 如果当前Schema是列表,则对所有dict/list子元素递归调用
|
||||||
|
for idx, item in enumerate(schema):
|
||||||
|
if isinstance(item, (dict, list)):
|
||||||
|
schema[idx] = _remove_title(item)
|
||||||
|
elif isinstance(schema, dict):
|
||||||
|
# 是字典,移除title字段,并对所有dict/list子元素递归调用
|
||||||
|
if "title" in schema:
|
||||||
|
del schema["title"]
|
||||||
|
for key, value in schema.items():
|
||||||
|
if isinstance(value, (dict, list)):
|
||||||
|
schema[key] = _remove_title(value)
|
||||||
|
|
||||||
|
return schema
|
||||||
|
|
||||||
|
|
||||||
|
def _link_definitions(schema: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
链接JSON Schema中的definitions字段
|
||||||
|
"""
|
||||||
|
|
||||||
|
def link_definitions_recursive(
|
||||||
|
path: str, sub_schema: list[Any] | dict[str, Any], defs: dict[str, Any]
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
递归链接JSON Schema中的definitions字段
|
||||||
|
:param path: 当前路径
|
||||||
|
:param sub_schema: 子Schema
|
||||||
|
:param defs: Schema定义集
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
if isinstance(sub_schema, list):
|
||||||
|
# 如果当前Schema是列表,则遍历每个元素
|
||||||
|
for i in range(len(sub_schema)):
|
||||||
|
if isinstance(sub_schema[i], dict):
|
||||||
|
sub_schema[i] = link_definitions_recursive(
|
||||||
|
f"{path}/{str(i)}", sub_schema[i], defs
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# 否则为字典
|
||||||
|
if "$defs" in sub_schema:
|
||||||
|
# 如果当前Schema有$def字段,则将其添加到defs中
|
||||||
|
key_prefix = f"{path}/$defs/"
|
||||||
|
for key, value in sub_schema["$defs"].items():
|
||||||
|
def_key = key_prefix + key
|
||||||
|
if def_key not in defs:
|
||||||
|
defs[def_key] = value
|
||||||
|
del sub_schema["$defs"]
|
||||||
|
if "$ref" in sub_schema:
|
||||||
|
# 如果当前Schema有$ref字段,则将其替换为defs中的定义
|
||||||
|
def_key = sub_schema["$ref"]
|
||||||
|
if def_key in defs:
|
||||||
|
sub_schema = defs[def_key]
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Schema中引用的定义'{def_key}'不存在")
|
||||||
|
# 遍历键值对
|
||||||
|
for key, value in sub_schema.items():
|
||||||
|
if isinstance(value, (dict, list)):
|
||||||
|
# 如果当前值是字典或列表,则递归调用
|
||||||
|
sub_schema[key] = link_definitions_recursive(
|
||||||
|
f"{path}/{key}", value, defs
|
||||||
|
)
|
||||||
|
|
||||||
|
return sub_schema
|
||||||
|
|
||||||
|
return link_definitions_recursive("#", schema, {})
|
||||||
|
|
||||||
|
|
||||||
|
def _remove_defs(schema: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
递归移除JSON Schema中的$defs字段
|
||||||
|
"""
|
||||||
|
if isinstance(schema, list):
|
||||||
|
# 如果当前Schema是列表,则对所有dict/list子元素递归调用
|
||||||
|
for idx, item in enumerate(schema):
|
||||||
|
if isinstance(item, (dict, list)):
|
||||||
|
schema[idx] = _remove_title(item)
|
||||||
|
elif isinstance(schema, dict):
|
||||||
|
# 是字典,移除title字段,并对所有dict/list子元素递归调用
|
||||||
|
if "$defs" in schema:
|
||||||
|
del schema["$defs"]
|
||||||
|
for key, value in schema.items():
|
||||||
|
if isinstance(value, (dict, list)):
|
||||||
|
schema[key] = _remove_title(value)
|
||||||
|
|
||||||
|
return schema
|
||||||
|
|
||||||
|
|
||||||
|
class RespFormat:
    """
    Response format requested from a model.
    """

    @staticmethod
    def _generate_schema_from_model(schema):
        """
        Build a JsonSchema dictionary from a model class.

        The class's ``model_json_schema()`` output has its titles removed and
        its ``$defs`` references inlined. ``strict`` is always False, and the
        class docstring, when present, becomes the description.

        :param schema: model class providing ``model_json_schema()``
        :return: dictionary with ``name``, ``schema``, ``strict`` and optional ``description``
        """
        json_schema = {
            "name": schema.__name__,
            "schema": _remove_defs(_link_definitions(_remove_title(schema.model_json_schema()))),
            "strict": False,
        }
        if schema.__doc__:
            json_schema["description"] = schema.__doc__
        return json_schema

    def __init__(
        self,
        format_type: RespFormatType = RespFormatType.TEXT,
        schema: type | JsonSchema | None = None,
    ):
        """
        Create a response format.

        :param format_type: kind of response format (text by default)
        :param schema: BaseModel subclass or JsonSchema dictionary; only used
            when ``format_type`` is JSON_SCHEMA
        :raises ValueError: if JSON_SCHEMA is requested without a schema, the
            schema dictionary is malformed, schema generation from the model
            fails, or the schema is neither a dict nor a BaseModel subclass
        """
        self.format_type: RespFormatType = format_type

        if format_type == RespFormatType.JSON_SCHEMA:
            if schema is None:
                raise ValueError("当format_type为'JSON_SCHEMA'时,schema不能为空")
            if isinstance(schema, dict):
                if check_msg := _json_schema_type_check(schema):
                    raise ValueError(f"schema格式不正确,{check_msg}")

                self.schema = schema
            # issubclass() raises TypeError for non-class arguments, so make
            # sure we have a class before asking.
            elif isinstance(schema, type) and issubclass(schema, BaseModel):
                try:
                    json_schema = self._generate_schema_from_model(schema)

                    self.schema = json_schema
                except Exception as e:
                    # Include the underlying error text so the announced
                    # details are actually present in the message.
                    raise ValueError(
                        f"自动生成JSON Schema时发生异常,请检查模型类{schema.__name__}的定义,详细信息:\n"
                        f"{schema.__name__}:\n"
                        f"{e}"
                    ) from e
            else:
                raise ValueError("schema必须是BaseModel的子类或JsonSchema")
        else:
            self.schema = None

    def to_dict(self):
        """
        Convert the response format into a dictionary.

        :return: ``{"format_type": ...}``, plus ``"schema"`` when a non-empty schema is set
        """
        if self.schema:
            return {
                "format_type": self.format_type.value,
                "schema": self.schema,
            }
        else:
            return {
                "format_type": self.format_type.value,
            }
|
||||||
163
src/llm_models/payload_content/tool_option.py
Normal file
163
src/llm_models/payload_content/tool_option.py
Normal file
@@ -0,0 +1,163 @@
|
|||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
|
class ToolParamType(Enum):
    """
    Types a tool-call parameter can have.
    """

    # string
    STRING = "string"
    # integer
    INTEGER = "integer"
    # floating point number
    FLOAT = "float"
    # boolean
    BOOLEAN = "bool"
|
||||||
|
|
||||||
|
|
||||||
|
class ToolParam:
    """
    A single parameter of a tool.
    """

    def __init__(
        self,
        name: str,
        param_type: ToolParamType,
        description: str,
        required: bool,
        enum_values: list[str] | None = None,
    ):
        """
        Initialise the tool parameter.
        (Parameters should be created through ToolOptionBuilder rather than directly.)

        :param name: parameter name
        :param param_type: parameter type
        :param description: parameter description
        :param required: whether the parameter is mandatory
        :param enum_values: allowed values, or None when unrestricted
        """
        self.enum_values: list[str] | None = enum_values
        self.required: bool = required
        self.description: str = description
        self.param_type: ToolParamType = param_type
        self.name: str = name
|
||||||
|
|
||||||
|
|
||||||
|
class ToolOption:
    """
    A tool that can be offered to the model.
    """

    def __init__(
        self,
        name: str,
        description: str,
        params: list[ToolParam] | None = None,
    ):
        """
        Initialise the tool option.
        (Options should be created through ToolOptionBuilder rather than directly.)

        :param name: tool name
        :param description: tool description
        :param params: tool parameters, or None when the tool takes none
        """
        self.params: list[ToolParam] | None = params
        self.description: str = description
        self.name: str = name
|
||||||
|
|
||||||
|
|
||||||
|
class ToolOptionBuilder:
    """
    Fluent builder that validates its input and produces ToolOption objects.
    """

    def __init__(self):
        self.__name: str = ""
        self.__description: str = ""
        self.__params: list[ToolParam] = []

    def set_name(self, name: str) -> "ToolOptionBuilder":
        """
        Set the tool name.

        :param name: tool name
        :return: this builder
        :raises ValueError: if the name is empty
        """
        if not name:
            raise ValueError("工具名称不能为空")
        self.__name = name
        return self

    def set_description(self, description: str) -> "ToolOptionBuilder":
        """
        Set the tool description.

        :param description: tool description
        :return: this builder
        :raises ValueError: if the description is empty
        """
        if not description:
            raise ValueError("工具描述不能为空")
        self.__description = description
        return self

    def add_param(
        self,
        name: str,
        param_type: ToolParamType,
        description: str,
        required: bool = False,
        enum_values: list[str] | None = None,
    ) -> "ToolOptionBuilder":
        """
        Add a tool parameter.

        :param name: parameter name
        :param param_type: parameter type
        :param description: parameter description
        :param required: whether the parameter is mandatory (default False)
        :param enum_values: allowed values, or None when unrestricted
        :return: this builder
        :raises ValueError: if the name or the description is empty
        """
        if not name or not description:
            raise ValueError("参数名称/描述不能为空")

        self.__params.append(
            ToolParam(
                name=name,
                param_type=param_type,
                description=description,
                required=required,
                enum_values=enum_values,
            )
        )

        return self

    def build(self) -> ToolOption:
        """
        Build the tool option.

        :return: the new ToolOption; ``params`` is None when no parameter was added
        :raises ValueError: if the name or the description has not been set
        """
        if self.__name == "" or self.__description == "":
            raise ValueError("工具名称/描述不能为空")

        return ToolOption(
            name=self.__name,
            description=self.__description,
            # Hand out a copy so that adding parameters to this builder later
            # cannot change an option that has already been built.
            params=list(self.__params) if self.__params else None,
        )
|
||||||
|
|
||||||
|
|
||||||
|
class ToolCall:
|
||||||
|
"""
|
||||||
|
来自模型反馈的工具调用
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
call_id: str,
|
||||||
|
func_name: str,
|
||||||
|
args: dict | None = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
初始化工具调用
|
||||||
|
:param call_id: 工具调用ID
|
||||||
|
:param func_name: 要调用的函数名称
|
||||||
|
:param args: 工具调用参数
|
||||||
|
"""
|
||||||
|
self.call_id: str = call_id
|
||||||
|
self.func_name: str = func_name
|
||||||
|
self.args: dict | None = args
|
||||||
189
src/llm_models/utils.py
Normal file
189
src/llm_models/utils.py
Normal file
@@ -0,0 +1,189 @@
|
|||||||
|
import base64
|
||||||
|
import io
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
from src.common.database.database import db # 确保 db 被导入用于 create_tables
|
||||||
|
from src.common.database.database_model import LLMUsage
|
||||||
|
from src.config.api_ada_configs import ModelInfo
|
||||||
|
from .payload_content.message import Message, MessageBuilder
|
||||||
|
from .model_client.base_client import UsageRecord
|
||||||
|
|
||||||
|
logger = get_logger("消息压缩工具")
|
||||||
|
|
||||||
|
|
||||||
|
def compress_messages(messages: list[Message], img_target_size: int = 1 * 1024 * 1024) -> list[Message]:
    """
    Compress the images embedded in a list of messages.

    A message whose ``content`` is a list is rebuilt through MessageBuilder:
    tuple items are treated as images (first element forwarded unchanged to
    ``add_image_content``, second element recompressed as base64 image data) and
    every other item is re-added as text. Messages whose content is not a list
    are passed through as the same object.

    :param messages: message list
    :param img_target_size: target size for each base64-encoded image, default 1MB
    :return: new list with the compressed messages
    """

    def to_jpeg_mode(image):
        """Return the image in a pixel mode the JPEG encoder can write."""
        # JPEG has no alpha channel or palette; RGBA / P / LA images must be
        # converted first or Image.save(format="JPEG") raises.
        if image.mode in ("RGB", "L", "CMYK"):
            return image
        return image.convert("RGB")

    def reformat_static_image(image_data: bytes) -> bytes:
        """
        Re-encode a static JPEG/PNG/WEBP image as JPEG.

        :param image_data: raw image bytes
        :return: JPEG bytes, or the input unchanged if the image is animated,
                 has another format, or cannot be processed
        """
        try:
            # Image.open() needs a file-like object; raw bytes would be
            # interpreted as a file name.
            image = Image.open(io.BytesIO(image_data))

            # Animated images are left alone here: flattening them to a single
            # JPEG frame would drop the animation (rescale_image handles them).
            is_static = not getattr(image, "is_animated", False)
            if is_static and image.format and (image.format.upper() in ["JPEG", "JPG", "PNG", "WEBP"]):
                reformated_image_data = io.BytesIO()
                to_jpeg_mode(image).save(reformated_image_data, format="JPEG", quality=95, optimize=True)
                image_data = reformated_image_data.getvalue()

            return image_data
        except Exception as e:
            # Best effort: on any failure keep the original bytes.
            logger.error(f"图片转换格式失败: {str(e)}")
            return image_data

    def rescale_image(image_data: bytes, scale: float) -> tuple[bytes, tuple[int, int] | None, tuple[int, int] | None]:
        """
        Scale an image down by the given factor.

        Animated images are written as GIF at half of the scaled size; static
        images are written as JPEG.

        :param image_data: raw image bytes
        :param scale: scale factor applied to width and height
        :return: (image bytes, original size, new size); on failure the input
                 bytes are returned together with (None, None)
        """
        try:
            image = Image.open(io.BytesIO(image_data))

            original_size = (image.width, image.height)

            # Clamp to 1px so that tiny images / tiny factors never produce a
            # zero dimension, which resize() rejects.
            new_size = (max(1, int(original_size[0] * scale)), max(1, int(original_size[1] * scale)))

            output_buffer = io.BytesIO()

            if getattr(image, "is_animated", False):
                # Animated image: shrink once more and process every frame.
                new_size = (max(1, new_size[0] // 2), max(1, new_size[1] // 2))
                frames = []
                for frame_idx in range(getattr(image, "n_frames", 1)):
                    image.seek(frame_idx)
                    frames.append(image.copy().resize(new_size, Image.Resampling.LANCZOS))

                frames[0].save(
                    output_buffer,
                    format="GIF",
                    save_all=True,
                    append_images=frames[1:],
                    optimize=True,
                    duration=image.info.get("duration", 100),
                    loop=image.info.get("loop", 0),
                )
            else:
                # Static image: resize and save as JPEG.
                resized_image = to_jpeg_mode(image).resize(new_size, Image.Resampling.LANCZOS)
                resized_image.save(output_buffer, format="JPEG", quality=95, optimize=True)

            return output_buffer.getvalue(), original_size, new_size

        except Exception as e:
            logger.error(f"图片缩放失败: {str(e)}")
            import traceback

            logger.error(traceback.format_exc())
            return image_data, None, None

    def compress_base64_image(base64_data: str, target_size: int = 1 * 1024 * 1024) -> str:
        """
        Bring one base64-encoded image under the target size if possible.

        :param base64_data: base64-encoded image
        :param target_size: target length of the base64 string
        :return: base64-encoded image after re-encoding and, if needed, one rescale pass
        """
        original_b64_data_size = len(base64_data)  # size before any processing, for the log line

        image_data = base64.b64decode(base64_data)

        # First try a plain re-encode as JPEG.
        image_data = reformat_static_image(image_data)
        base64_data = base64.b64encode(image_data).decode("utf-8")
        if len(base64_data) <= target_size:
            # Small enough after re-encoding: done.
            logger.info(f"成功将图片转为JPEG格式,编码后大小: {len(base64_data) / 1024:.1f}KB")
            return base64_data

        # Still too large: shrink the dimensions. The factor is the ratio
        # between target and current encoded length, capped at 1.0.
        scale = min(1.0, target_size / len(base64_data))
        image_data, original_size, new_size = rescale_image(image_data, scale)
        base64_data = base64.b64encode(image_data).decode("utf-8")

        if original_size and new_size:
            logger.info(
                f"压缩图片: {original_size[0]}x{original_size[1]} -> {new_size[0]}x{new_size[1]}\n"
                f"压缩前大小: {original_b64_data_size / 1024:.1f}KB, 压缩后大小: {len(base64_data) / 1024:.1f}KB"
            )

        return base64_data

    compressed_messages = []
    for message in messages:
        if isinstance(message.content, list):
            # List content may contain images; rebuild the message item by item.
            # NOTE(review): the rebuilt message comes from a fresh MessageBuilder;
            # confirm that no other fields of the original message (e.g. its role)
            # need to be carried over.
            message_builder = MessageBuilder()
            for content_item in message.content:
                if isinstance(content_item, tuple):
                    # Image item: compress the base64 payload.
                    message_builder.add_image_content(
                        content_item[0],
                        compress_base64_image(content_item[1], target_size=img_target_size),
                    )
                else:
                    message_builder.add_text_content(content_item)
            compressed_messages.append(message_builder.build())
        else:
            compressed_messages.append(message)

    return compressed_messages
|
||||||
|
|
||||||
|
|
||||||
|
class LLMUsageRecorder:
    """
    Recorder that persists LLM usage (tokens, cost, timing) to the database.
    """

    def __init__(self):
        try:
            # Create the table through Peewee; safe=True means an already
            # existing table does not raise.
            db.create_tables([LLMUsage], safe=True)
        except Exception as e:
            logger.error(f"创建 LLMUsage 表失败: {str(e)}")

    def record_usage_to_database(
        self,
        model_info: ModelInfo,
        model_usage: UsageRecord,
        user_id: str,
        request_type: str,
        endpoint: str,
        time_cost: float = 0.0,
    ):
        """
        Store one usage record as an LLMUsage row.

        The cost is computed as tokens / 1,000,000 multiplied by
        ``model_info.price_in`` / ``model_info.price_out`` and rounded to 6
        decimals. Database errors are logged and not re-raised.

        :param model_info: configuration of the model that served the request
        :param model_usage: token usage reported for the request
        :param user_id: identifier of the requesting user
        :param request_type: request type label
        :param endpoint: endpoint label
        :param time_cost: request duration, stored rounded to 3 decimals
        """
        # Normalize missing token counts once, so the cost arithmetic cannot
        # fail with a TypeError on None before the record is written.
        prompt_tokens = model_usage.prompt_tokens or 0
        completion_tokens = model_usage.completion_tokens or 0

        input_cost = (prompt_tokens / 1000000) * model_info.price_in
        output_cost = (completion_tokens / 1000000) * model_info.price_out
        total_cost = round(input_cost + output_cost, 6)
        try:
            # Create the record through the Peewee model.
            LLMUsage.create(
                model_name=model_info.model_identifier,
                model_assign_name=model_info.name,
                model_api_provider=model_info.api_provider,
                user_id=user_id,
                request_type=request_type,
                endpoint=endpoint,
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=model_usage.total_tokens or 0,
                cost=total_cost or 0.0,
                time_cost=round(time_cost or 0.0, 3),
                status="success",
                timestamp=datetime.now(),  # Peewee handles the DateTimeField conversion
            )
            logger.debug(
                f"Token使用情况 - 模型: {model_usage.model_name}, "
                f"用户: {user_id}, 类型: {request_type}, "
                f"提示词: {model_usage.prompt_tokens}, 完成: {model_usage.completion_tokens}, "
                f"总计: {model_usage.total_tokens}"
            )
        except Exception as e:
            # Usage accounting is best effort; never break the caller.
            logger.error(f"记录token使用情况失败: {str(e)}")
|
||||||
|
|
||||||
|
llm_usage_recorder = LLMUsageRecorder()
|
||||||
File diff suppressed because it is too large
Load Diff
59
src/main.py
59
src/main.py
@@ -2,12 +2,10 @@ import asyncio
|
|||||||
import time
|
import time
|
||||||
from maim_message import MessageServer
|
from maim_message import MessageServer
|
||||||
|
|
||||||
from src.chat.express.expression_learner import get_expression_learner
|
|
||||||
from src.common.remote import TelemetryHeartBeatTask
|
from src.common.remote import TelemetryHeartBeatTask
|
||||||
from src.manager.async_task_manager import async_task_manager
|
from src.manager.async_task_manager import async_task_manager
|
||||||
from src.chat.utils.statistic import OnlineTimeRecordTask, StatisticOutputTask
|
from src.chat.utils.statistic import OnlineTimeRecordTask, StatisticOutputTask
|
||||||
from src.chat.emoji_system.emoji_manager import get_emoji_manager
|
from src.chat.emoji_system.emoji_manager import get_emoji_manager
|
||||||
from src.chat.willing.willing_manager import get_willing_manager
|
|
||||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from src.chat.message_receive.bot import chat_bot
|
from src.chat.message_receive.bot import chat_bot
|
||||||
@@ -16,6 +14,7 @@ from src.individuality.individuality import get_individuality, Individuality
|
|||||||
from src.common.server import get_global_server, Server
|
from src.common.server import get_global_server, Server
|
||||||
from src.mood.mood_manager import mood_manager
|
from src.mood.mood_manager import mood_manager
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
|
from src.migrate_helper.migrate import check_and_run_migrations
|
||||||
# from src.api.main import start_api_server
|
# from src.api.main import start_api_server
|
||||||
|
|
||||||
# 导入新的插件管理器
|
# 导入新的插件管理器
|
||||||
@@ -32,8 +31,6 @@ if global_config.memory.enable_memory:
|
|||||||
|
|
||||||
install(extra_lines=3)
|
install(extra_lines=3)
|
||||||
|
|
||||||
willing_manager = get_willing_manager()
|
|
||||||
|
|
||||||
logger = get_logger("main")
|
logger = get_logger("main")
|
||||||
|
|
||||||
|
|
||||||
@@ -53,12 +50,22 @@ class MainSystem:
|
|||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
"""初始化系统组件"""
|
"""初始化系统组件"""
|
||||||
logger.debug(f"正在唤醒{global_config.bot.nickname}......")
|
logger.info(f"正在唤醒{global_config.bot.nickname}......")
|
||||||
|
|
||||||
# 其他初始化任务
|
# 其他初始化任务
|
||||||
await asyncio.gather(self._init_components())
|
await asyncio.gather(self._init_components())
|
||||||
|
|
||||||
logger.debug("系统初始化完成")
|
logger.info(f"""
|
||||||
|
--------------------------------
|
||||||
|
全部系统初始化完成,{global_config.bot.nickname}已成功唤醒
|
||||||
|
--------------------------------
|
||||||
|
如果想要自定义{global_config.bot.nickname}的功能,请查阅:https://docs.mai-mai.org/manual/usage/
|
||||||
|
或者遇到了问题,请访问我们的文档:https://docs.mai-mai.org/
|
||||||
|
--------------------------------
|
||||||
|
如果你想要编写或了解插件相关内容,请访问开发文档https://docs.mai-mai.org/develop/
|
||||||
|
--------------------------------
|
||||||
|
如果你需要查阅模型的消耗以及麦麦的统计数据,请访问根目录的maibot_statistics.html文件
|
||||||
|
""")
|
||||||
|
|
||||||
async def _init_components(self):
|
async def _init_components(self):
|
||||||
"""初始化其他组件"""
|
"""初始化其他组件"""
|
||||||
@@ -84,11 +91,6 @@ class MainSystem:
|
|||||||
get_emoji_manager().initialize()
|
get_emoji_manager().initialize()
|
||||||
logger.info("表情包管理器初始化成功")
|
logger.info("表情包管理器初始化成功")
|
||||||
|
|
||||||
# 启动愿望管理器
|
|
||||||
await willing_manager.async_task_starter()
|
|
||||||
|
|
||||||
logger.info("willing管理器初始化成功")
|
|
||||||
|
|
||||||
# 启动情绪管理器
|
# 启动情绪管理器
|
||||||
await mood_manager.start()
|
await mood_manager.start()
|
||||||
logger.info("情绪管理器初始化成功")
|
logger.info("情绪管理器初始化成功")
|
||||||
@@ -115,6 +117,9 @@ class MainSystem:
|
|||||||
|
|
||||||
# 初始化个体特征
|
# 初始化个体特征
|
||||||
await self.individuality.initialize()
|
await self.individuality.initialize()
|
||||||
|
|
||||||
|
await check_and_run_migrations()
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
init_time = int(1000 * (time.time() - init_start_time))
|
init_time = int(1000 * (time.time() - init_start_time))
|
||||||
@@ -136,23 +141,14 @@ class MainSystem:
|
|||||||
if global_config.memory.enable_memory and self.hippocampus_manager:
|
if global_config.memory.enable_memory and self.hippocampus_manager:
|
||||||
tasks.extend(
|
tasks.extend(
|
||||||
[
|
[
|
||||||
self.build_memory_task(),
|
# 移除记忆构建的定期调用,改为在heartFC_chat.py中调用
|
||||||
|
# self.build_memory_task(),
|
||||||
self.forget_memory_task(),
|
self.forget_memory_task(),
|
||||||
self.consolidate_memory_task(),
|
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
tasks.append(self.learn_and_store_expression_task())
|
|
||||||
|
|
||||||
await asyncio.gather(*tasks)
|
await asyncio.gather(*tasks)
|
||||||
|
|
||||||
async def build_memory_task(self):
|
|
||||||
"""记忆构建任务"""
|
|
||||||
while True:
|
|
||||||
await asyncio.sleep(global_config.memory.memory_build_interval)
|
|
||||||
logger.info("正在进行记忆构建")
|
|
||||||
await self.hippocampus_manager.build_memory() # type: ignore
|
|
||||||
|
|
||||||
async def forget_memory_task(self):
|
async def forget_memory_task(self):
|
||||||
"""记忆遗忘任务"""
|
"""记忆遗忘任务"""
|
||||||
while True:
|
while True:
|
||||||
@@ -161,24 +157,7 @@ class MainSystem:
|
|||||||
await self.hippocampus_manager.forget_memory(percentage=global_config.memory.memory_forget_percentage) # type: ignore
|
await self.hippocampus_manager.forget_memory(percentage=global_config.memory.memory_forget_percentage) # type: ignore
|
||||||
logger.info("[记忆遗忘] 记忆遗忘完成")
|
logger.info("[记忆遗忘] 记忆遗忘完成")
|
||||||
|
|
||||||
async def consolidate_memory_task(self):
|
|
||||||
"""记忆整合任务"""
|
|
||||||
while True:
|
|
||||||
await asyncio.sleep(global_config.memory.consolidate_memory_interval)
|
|
||||||
logger.info("[记忆整合] 开始整合记忆...")
|
|
||||||
await self.hippocampus_manager.consolidate_memory() # type: ignore
|
|
||||||
logger.info("[记忆整合] 记忆整合完成")
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def learn_and_store_expression_task():
|
|
||||||
"""学习并存储表达方式任务"""
|
|
||||||
expression_learner = get_expression_learner()
|
|
||||||
while True:
|
|
||||||
await asyncio.sleep(global_config.expression.learning_interval)
|
|
||||||
if global_config.expression.enable_expression_learning and global_config.expression.enable_expression:
|
|
||||||
logger.info("[表达方式学习] 开始学习表达方式...")
|
|
||||||
await expression_learner.learn_and_store_expression()
|
|
||||||
logger.info("[表达方式学习] 表达方式学习完成")
|
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
@@ -192,3 +171,5 @@ async def main():
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
asyncio.run(main())
|
asyncio.run(main())
|
||||||
|
|
||||||
|
|
||||||
@@ -1,132 +0,0 @@
|
|||||||
[inner]
|
|
||||||
version = "1.1.0"
|
|
||||||
|
|
||||||
#----以下是S4U聊天系统配置文件----
|
|
||||||
# S4U (Smart 4 U) 聊天系统是MaiBot的核心对话模块
|
|
||||||
# 支持优先级队列、消息中断、VIP用户等高级功能
|
|
||||||
#
|
|
||||||
# 如果你想要修改配置文件,请在修改后将version的值进行变更
|
|
||||||
# 如果新增项目,请参考src/mais4u/s4u_config.py中的S4UConfig类
|
|
||||||
#
|
|
||||||
# 版本格式:主版本号.次版本号.修订号
|
|
||||||
#----S4U配置说明结束----
|
|
||||||
|
|
||||||
[s4u]
|
|
||||||
# 消息管理配置
|
|
||||||
message_timeout_seconds = 80 # 普通消息存活时间(秒),超过此时间的消息将被丢弃
|
|
||||||
recent_message_keep_count = 8 # 保留最近N条消息,超出范围的普通消息将被移除
|
|
||||||
|
|
||||||
# 优先级系统配置
|
|
||||||
at_bot_priority_bonus = 100.0 # @机器人时的优先级加成分数
|
|
||||||
vip_queue_priority = true # 是否启用VIP队列优先级系统
|
|
||||||
enable_message_interruption = true # 是否允许高优先级消息中断当前回复
|
|
||||||
|
|
||||||
# 打字效果配置
|
|
||||||
typing_delay = 0.1 # 打字延迟时间(秒),模拟真实打字速度
|
|
||||||
enable_dynamic_typing_delay = false # 是否启用基于文本长度的动态打字延迟
|
|
||||||
|
|
||||||
# 动态打字延迟参数(仅在enable_dynamic_typing_delay=true时生效)
|
|
||||||
chars_per_second = 15.0 # 每秒字符数,用于计算动态打字延迟
|
|
||||||
min_typing_delay = 0.2 # 最小打字延迟(秒)
|
|
||||||
max_typing_delay = 2.0 # 最大打字延迟(秒)
|
|
||||||
|
|
||||||
# 系统功能开关
|
|
||||||
enable_old_message_cleanup = true # 是否自动清理过旧的普通消息
|
|
||||||
enable_loading_indicator = true # 是否显示加载提示
|
|
||||||
|
|
||||||
enable_streaming_output = false # 是否启用流式输出,false时全部生成后一次性发送
|
|
||||||
|
|
||||||
max_context_message_length = 30
|
|
||||||
max_core_message_length = 20
|
|
||||||
|
|
||||||
# 模型配置
|
|
||||||
[models]
|
|
||||||
# 主要对话模型配置
|
|
||||||
[models.chat]
|
|
||||||
name = "qwen3-8b"
|
|
||||||
provider = "BAILIAN"
|
|
||||||
pri_in = 0.5
|
|
||||||
pri_out = 2
|
|
||||||
temp = 0.7
|
|
||||||
enable_thinking = false
|
|
||||||
|
|
||||||
# 规划模型配置
|
|
||||||
[models.motion]
|
|
||||||
name = "qwen3-8b"
|
|
||||||
provider = "BAILIAN"
|
|
||||||
pri_in = 0.5
|
|
||||||
pri_out = 2
|
|
||||||
temp = 0.7
|
|
||||||
enable_thinking = false
|
|
||||||
|
|
||||||
# 情感分析模型配置
|
|
||||||
[models.emotion]
|
|
||||||
name = "qwen3-8b"
|
|
||||||
provider = "BAILIAN"
|
|
||||||
pri_in = 0.5
|
|
||||||
pri_out = 2
|
|
||||||
temp = 0.7
|
|
||||||
|
|
||||||
# 记忆模型配置
|
|
||||||
[models.memory]
|
|
||||||
name = "qwen3-8b"
|
|
||||||
provider = "BAILIAN"
|
|
||||||
pri_in = 0.5
|
|
||||||
pri_out = 2
|
|
||||||
temp = 0.7
|
|
||||||
|
|
||||||
# 工具使用模型配置
|
|
||||||
[models.tool_use]
|
|
||||||
name = "qwen3-8b"
|
|
||||||
provider = "BAILIAN"
|
|
||||||
pri_in = 0.5
|
|
||||||
pri_out = 2
|
|
||||||
temp = 0.7
|
|
||||||
|
|
||||||
# 嵌入模型配置
|
|
||||||
[models.embedding]
|
|
||||||
name = "text-embedding-v1"
|
|
||||||
provider = "OPENAI"
|
|
||||||
dimension = 1024
|
|
||||||
|
|
||||||
# 视觉语言模型配置
|
|
||||||
[models.vlm]
|
|
||||||
name = "qwen-vl-plus"
|
|
||||||
provider = "BAILIAN"
|
|
||||||
pri_in = 0.5
|
|
||||||
pri_out = 2
|
|
||||||
temp = 0.7
|
|
||||||
|
|
||||||
# 知识库模型配置
|
|
||||||
[models.knowledge]
|
|
||||||
name = "qwen3-8b"
|
|
||||||
provider = "BAILIAN"
|
|
||||||
pri_in = 0.5
|
|
||||||
pri_out = 2
|
|
||||||
temp = 0.7
|
|
||||||
|
|
||||||
# 实体提取模型配置
|
|
||||||
[models.entity_extract]
|
|
||||||
name = "qwen3-8b"
|
|
||||||
provider = "BAILIAN"
|
|
||||||
pri_in = 0.5
|
|
||||||
pri_out = 2
|
|
||||||
temp = 0.7
|
|
||||||
|
|
||||||
# 问答模型配置
|
|
||||||
[models.qa]
|
|
||||||
name = "qwen3-8b"
|
|
||||||
provider = "BAILIAN"
|
|
||||||
pri_in = 0.5
|
|
||||||
pri_out = 2
|
|
||||||
temp = 0.7
|
|
||||||
|
|
||||||
# 兼容性配置(已废弃,请使用models.motion)
|
|
||||||
[model_motion] # 在麦麦的一些组件中使用的小模型,消耗量较大,建议使用速度较快的小模型
|
|
||||||
# 强烈建议使用免费的小模型
|
|
||||||
name = "qwen3-8b"
|
|
||||||
provider = "BAILIAN"
|
|
||||||
pri_in = 0.5
|
|
||||||
pri_out = 2
|
|
||||||
temp = 0.7
|
|
||||||
enable_thinking = false # 是否启用思考
|
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
[inner]
|
[inner]
|
||||||
version = "1.1.0"
|
version = "1.2.0"
|
||||||
|
|
||||||
#----以下是S4U聊天系统配置文件----
|
#----以下是S4U聊天系统配置文件----
|
||||||
# S4U (Smart 4 U) 聊天系统是MaiBot的核心对话模块
|
# S4U (Smart 4 U) 聊天系统是MaiBot的核心对话模块
|
||||||
@@ -12,6 +12,7 @@ version = "1.1.0"
|
|||||||
#----S4U配置说明结束----
|
#----S4U配置说明结束----
|
||||||
|
|
||||||
[s4u]
|
[s4u]
|
||||||
|
enable_s4u = false
|
||||||
# 消息管理配置
|
# 消息管理配置
|
||||||
message_timeout_seconds = 120 # 普通消息存活时间(秒),超过此时间的消息将被丢弃
|
message_timeout_seconds = 120 # 普通消息存活时间(秒),超过此时间的消息将被丢弃
|
||||||
recent_message_keep_count = 6 # 保留最近N条消息,超出范围的普通消息将被移除
|
recent_message_keep_count = 6 # 保留最近N条消息,超出范围的普通消息将被移除
|
||||||
|
|||||||
@@ -1 +0,0 @@
|
|||||||
ENABLE_S4U = False
|
|
||||||
@@ -2,13 +2,15 @@ from src.chat.message_receive.chat_stream import get_chat_manager
|
|||||||
import time
|
import time
|
||||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||||
from src.llm_models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
from src.config.config import global_config
|
from src.config.config import model_config
|
||||||
from src.chat.message_receive.message import MessageRecvS4U
|
from src.chat.message_receive.message import MessageRecvS4U
|
||||||
from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor
|
from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor
|
||||||
from src.mais4u.mais4u_chat.internal_manager import internal_manager
|
from src.mais4u.mais4u_chat.internal_manager import internal_manager
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def init_prompt():
|
def init_prompt():
|
||||||
Prompt(
|
Prompt(
|
||||||
"""
|
"""
|
||||||
@@ -32,10 +34,8 @@ def init_prompt():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class MaiThinking:
|
class MaiThinking:
|
||||||
def __init__(self,chat_id):
|
def __init__(self, chat_id):
|
||||||
self.chat_id = chat_id
|
self.chat_id = chat_id
|
||||||
self.chat_stream = get_chat_manager().get_stream(chat_id)
|
self.chat_stream = get_chat_manager().get_stream(chat_id)
|
||||||
self.platform = self.chat_stream.platform
|
self.platform = self.chat_stream.platform
|
||||||
@@ -44,11 +44,11 @@ class MaiThinking:
|
|||||||
self.is_group = True
|
self.is_group = True
|
||||||
else:
|
else:
|
||||||
self.is_group = False
|
self.is_group = False
|
||||||
|
|
||||||
self.s4u_message_processor = S4UMessageProcessor()
|
self.s4u_message_processor = S4UMessageProcessor()
|
||||||
|
|
||||||
self.mind = ""
|
self.mind = ""
|
||||||
|
|
||||||
self.memory_block = ""
|
self.memory_block = ""
|
||||||
self.relation_info_block = ""
|
self.relation_info_block = ""
|
||||||
self.time_block = ""
|
self.time_block = ""
|
||||||
@@ -59,17 +59,13 @@ class MaiThinking:
|
|||||||
self.identity = ""
|
self.identity = ""
|
||||||
self.sender = ""
|
self.sender = ""
|
||||||
self.target = ""
|
self.target = ""
|
||||||
|
|
||||||
self.thinking_model = LLMRequest(
|
self.thinking_model = LLMRequest(model_set=model_config.model_task_config.replyer, request_type="thinking")
|
||||||
model=global_config.model.replyer_1,
|
|
||||||
request_type="thinking",
|
|
||||||
)
|
|
||||||
|
|
||||||
async def do_think_before_response(self):
|
async def do_think_before_response(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def do_think_after_response(self,reponse:str):
|
async def do_think_after_response(self, reponse: str):
|
||||||
|
|
||||||
prompt = await global_prompt_manager.format_prompt(
|
prompt = await global_prompt_manager.format_prompt(
|
||||||
"after_response_think_prompt",
|
"after_response_think_prompt",
|
||||||
mind=self.mind,
|
mind=self.mind,
|
||||||
@@ -85,47 +81,44 @@ class MaiThinking:
|
|||||||
sender=self.sender,
|
sender=self.sender,
|
||||||
target=self.target,
|
target=self.target,
|
||||||
)
|
)
|
||||||
|
|
||||||
result, _ = await self.thinking_model.generate_response_async(prompt)
|
result, _ = await self.thinking_model.generate_response_async(prompt)
|
||||||
self.mind = result
|
self.mind = result
|
||||||
|
|
||||||
logger.info(f"[{self.chat_id}] 思考前想法:{self.mind}")
|
logger.info(f"[{self.chat_id}] 思考前想法:{self.mind}")
|
||||||
# logger.info(f"[{self.chat_id}] 思考前prompt:{prompt}")
|
# logger.info(f"[{self.chat_id}] 思考前prompt:{prompt}")
|
||||||
logger.info(f"[{self.chat_id}] 思考后想法:{self.mind}")
|
logger.info(f"[{self.chat_id}] 思考后想法:{self.mind}")
|
||||||
|
|
||||||
|
|
||||||
msg_recv = await self.build_internal_message_recv(self.mind)
|
msg_recv = await self.build_internal_message_recv(self.mind)
|
||||||
await self.s4u_message_processor.process_message(msg_recv)
|
await self.s4u_message_processor.process_message(msg_recv)
|
||||||
internal_manager.set_internal_state(self.mind)
|
internal_manager.set_internal_state(self.mind)
|
||||||
|
|
||||||
|
|
||||||
async def do_think_when_receive_message(self):
|
async def do_think_when_receive_message(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def build_internal_message_recv(self,message_text:str):
|
async def build_internal_message_recv(self, message_text: str):
|
||||||
|
|
||||||
msg_id = f"internal_{time.time()}"
|
msg_id = f"internal_{time.time()}"
|
||||||
|
|
||||||
message_dict = {
|
message_dict = {
|
||||||
"message_info": {
|
"message_info": {
|
||||||
"message_id": msg_id,
|
"message_id": msg_id,
|
||||||
"time": time.time(),
|
"time": time.time(),
|
||||||
"user_info": {
|
"user_info": {
|
||||||
"user_id": "internal", # 内部用户ID
|
"user_id": "internal", # 内部用户ID
|
||||||
"user_nickname": "内心", # 内部昵称
|
"user_nickname": "内心", # 内部昵称
|
||||||
"platform": self.platform, # 平台标记为 internal
|
"platform": self.platform, # 平台标记为 internal
|
||||||
# 其他 user_info 字段按需补充
|
# 其他 user_info 字段按需补充
|
||||||
},
|
},
|
||||||
"platform": self.platform, # 平台
|
"platform": self.platform, # 平台
|
||||||
# 其他 message_info 字段按需补充
|
# 其他 message_info 字段按需补充
|
||||||
},
|
},
|
||||||
"message_segment": {
|
"message_segment": {
|
||||||
"type": "text", # 消息类型
|
"type": "text", # 消息类型
|
||||||
"data": message_text, # 消息内容
|
"data": message_text, # 消息内容
|
||||||
# 其他 segment 字段按需补充
|
# 其他 segment 字段按需补充
|
||||||
},
|
},
|
||||||
"raw_message": message_text, # 原始消息内容
|
"raw_message": message_text, # 原始消息内容
|
||||||
"processed_plain_text": message_text, # 处理后的纯文本
|
"processed_plain_text": message_text, # 处理后的纯文本
|
||||||
# 下面这些字段可选,根据 MessageRecv 需要
|
# 下面这些字段可选,根据 MessageRecv 需要
|
||||||
"is_emoji": False,
|
"is_emoji": False,
|
||||||
"has_emoji": False,
|
"has_emoji": False,
|
||||||
@@ -139,45 +132,36 @@ class MaiThinking:
|
|||||||
"priority_info": {"message_priority": 10.0}, # 内部消息可设高优先级
|
"priority_info": {"message_priority": 10.0}, # 内部消息可设高优先级
|
||||||
"interest_value": 1.0,
|
"interest_value": 1.0,
|
||||||
}
|
}
|
||||||
|
|
||||||
if self.is_group:
|
if self.is_group:
|
||||||
message_dict["message_info"]["group_info"] = {
|
message_dict["message_info"]["group_info"] = {
|
||||||
"platform": self.platform,
|
"platform": self.platform,
|
||||||
"group_id": self.chat_stream.group_info.group_id,
|
"group_id": self.chat_stream.group_info.group_id,
|
||||||
"group_name": self.chat_stream.group_info.group_name,
|
"group_name": self.chat_stream.group_info.group_name,
|
||||||
}
|
}
|
||||||
|
|
||||||
msg_recv = MessageRecvS4U(message_dict)
|
msg_recv = MessageRecvS4U(message_dict)
|
||||||
msg_recv.chat_info = self.chat_info
|
msg_recv.chat_info = self.chat_info
|
||||||
msg_recv.chat_stream = self.chat_stream
|
msg_recv.chat_stream = self.chat_stream
|
||||||
msg_recv.is_internal = True
|
msg_recv.is_internal = True
|
||||||
|
|
||||||
return msg_recv
|
return msg_recv
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class MaiThinkingManager:
|
class MaiThinkingManager:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.mai_think_list = []
|
self.mai_think_list = []
|
||||||
|
|
||||||
def get_mai_think(self,chat_id):
|
def get_mai_think(self, chat_id):
|
||||||
for mai_think in self.mai_think_list:
|
for mai_think in self.mai_think_list:
|
||||||
if mai_think.chat_id == chat_id:
|
if mai_think.chat_id == chat_id:
|
||||||
return mai_think
|
return mai_think
|
||||||
mai_think = MaiThinking(chat_id)
|
mai_think = MaiThinking(chat_id)
|
||||||
self.mai_think_list.append(mai_think)
|
self.mai_think_list.append(mai_think)
|
||||||
return mai_think
|
return mai_think
|
||||||
|
|
||||||
|
|
||||||
mai_thinking_manager = MaiThinkingManager()
|
mai_thinking_manager = MaiThinkingManager()
|
||||||
|
|
||||||
|
|
||||||
init_prompt()
|
init_prompt()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,19 +1,22 @@
|
|||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
from json_repair import repair_json
|
||||||
from src.chat.message_receive.message import MessageRecv
|
from src.chat.message_receive.message import MessageRecv
|
||||||
from src.llm_models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive
|
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config, model_config
|
||||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||||
from src.manager.async_task_manager import AsyncTask, async_task_manager
|
from src.manager.async_task_manager import AsyncTask, async_task_manager
|
||||||
from src.plugin_system.apis import send_api
|
from src.plugin_system.apis import send_api
|
||||||
from json_repair import repair_json
|
|
||||||
from src.mais4u.s4u_config import s4u_config
|
from src.mais4u.s4u_config import s4u_config
|
||||||
|
|
||||||
logger = get_logger("action")
|
logger = get_logger("action")
|
||||||
|
|
||||||
HEAD_CODE = {
|
# 使用字典作为默认值,但通过Prompt来注册以便外部重载
|
||||||
|
DEFAULT_HEAD_CODE = {
|
||||||
"看向上方": "(0,0.5,0)",
|
"看向上方": "(0,0.5,0)",
|
||||||
"看向下方": "(0,-0.5,0)",
|
"看向下方": "(0,-0.5,0)",
|
||||||
"看向左边": "(-1,0,0)",
|
"看向左边": "(-1,0,0)",
|
||||||
@@ -24,7 +27,7 @@ HEAD_CODE = {
|
|||||||
"看向正前方": "(0,0,0)",
|
"看向正前方": "(0,0,0)",
|
||||||
}
|
}
|
||||||
|
|
||||||
BODY_CODE = {
|
DEFAULT_BODY_CODE = {
|
||||||
"双手背后向前弯腰": "010_0070",
|
"双手背后向前弯腰": "010_0070",
|
||||||
"歪头双手合十": "010_0100",
|
"歪头双手合十": "010_0100",
|
||||||
"标准文静站立": "010_0101",
|
"标准文静站立": "010_0101",
|
||||||
@@ -32,7 +35,7 @@ BODY_CODE = {
|
|||||||
"帅气的姿势": "010_0190",
|
"帅气的姿势": "010_0190",
|
||||||
"另一个帅气的姿势": "010_0191",
|
"另一个帅气的姿势": "010_0191",
|
||||||
"手掌朝前可爱": "010_0210",
|
"手掌朝前可爱": "010_0210",
|
||||||
"平静,双手后放":"平静,双手后放",
|
"平静,双手后放": "平静,双手后放",
|
||||||
"思考": "思考",
|
"思考": "思考",
|
||||||
"优雅,左手放在腰上": "优雅,左手放在腰上",
|
"优雅,左手放在腰上": "优雅,左手放在腰上",
|
||||||
"一般": "一般",
|
"一般": "一般",
|
||||||
@@ -40,7 +43,44 @@ BODY_CODE = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_head_code() -> dict:
|
||||||
|
"""获取头部动作代码字典"""
|
||||||
|
head_code_str = global_prompt_manager.get_prompt("head_code_prompt")
|
||||||
|
if not head_code_str:
|
||||||
|
return DEFAULT_HEAD_CODE
|
||||||
|
try:
|
||||||
|
return json.loads(head_code_str)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"解析head_code_prompt失败,使用默认值: {e}")
|
||||||
|
return DEFAULT_HEAD_CODE
|
||||||
|
|
||||||
|
|
||||||
|
def get_body_code() -> dict:
|
||||||
|
"""获取身体动作代码字典"""
|
||||||
|
body_code_str = global_prompt_manager.get_prompt("body_code_prompt")
|
||||||
|
if not body_code_str:
|
||||||
|
return DEFAULT_BODY_CODE
|
||||||
|
try:
|
||||||
|
return json.loads(body_code_str)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"解析body_code_prompt失败,使用默认值: {e}")
|
||||||
|
return DEFAULT_BODY_CODE
|
||||||
|
|
||||||
|
|
||||||
def init_prompt():
|
def init_prompt():
|
||||||
|
# 注册头部动作代码
|
||||||
|
Prompt(
|
||||||
|
json.dumps(DEFAULT_HEAD_CODE, ensure_ascii=False, indent=2),
|
||||||
|
"head_code_prompt",
|
||||||
|
)
|
||||||
|
|
||||||
|
# 注册身体动作代码
|
||||||
|
Prompt(
|
||||||
|
json.dumps(DEFAULT_BODY_CODE, ensure_ascii=False, indent=2),
|
||||||
|
"body_code_prompt",
|
||||||
|
)
|
||||||
|
|
||||||
|
# 注册原有提示模板
|
||||||
Prompt(
|
Prompt(
|
||||||
"""
|
"""
|
||||||
{chat_talking_prompt}
|
{chat_talking_prompt}
|
||||||
@@ -94,20 +134,16 @@ class ChatAction:
|
|||||||
self.body_action_cooldown: dict[str, int] = {}
|
self.body_action_cooldown: dict[str, int] = {}
|
||||||
|
|
||||||
print(s4u_config.models.motion)
|
print(s4u_config.models.motion)
|
||||||
print(global_config.model.emotion)
|
print(model_config.model_task_config.emotion)
|
||||||
|
|
||||||
self.action_model = LLMRequest(
|
|
||||||
model=global_config.model.emotion,
|
|
||||||
temperature=0.7,
|
|
||||||
request_type="motion",
|
|
||||||
)
|
|
||||||
|
|
||||||
self.last_change_time = 0
|
self.action_model = LLMRequest(model_set=model_config.model_task_config.emotion, request_type="motion")
|
||||||
|
|
||||||
|
self.last_change_time: float = 0
|
||||||
|
|
||||||
async def send_action_update(self):
|
async def send_action_update(self):
|
||||||
"""发送动作更新到前端"""
|
"""发送动作更新到前端"""
|
||||||
|
|
||||||
body_code = BODY_CODE.get(self.body_action, "")
|
body_code = get_body_code().get(self.body_action, "")
|
||||||
await send_api.custom_to_stream(
|
await send_api.custom_to_stream(
|
||||||
message_type="body_action",
|
message_type="body_action",
|
||||||
content=body_code,
|
content=body_code,
|
||||||
@@ -115,13 +151,11 @@ class ChatAction:
|
|||||||
storage_message=False,
|
storage_message=False,
|
||||||
show_log=True,
|
show_log=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def update_action_by_message(self, message: MessageRecv):
|
async def update_action_by_message(self, message: MessageRecv):
|
||||||
self.regression_count = 0
|
self.regression_count = 0
|
||||||
|
|
||||||
message_time = message.message_info.time
|
message_time: float = message.message_info.time # type: ignore
|
||||||
message_list_before_now = get_raw_msg_by_timestamp_with_chat_inclusive(
|
message_list_before_now = get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||||
chat_id=self.chat_id,
|
chat_id=self.chat_id,
|
||||||
timestamp_start=self.last_change_time,
|
timestamp_start=self.last_change_time,
|
||||||
@@ -147,13 +181,13 @@ class ChatAction:
|
|||||||
|
|
||||||
prompt_personality = global_config.personality.personality_core
|
prompt_personality = global_config.personality.personality_core
|
||||||
indentify_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}:"
|
indentify_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}:"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 冷却池处理:过滤掉冷却中的动作
|
# 冷却池处理:过滤掉冷却中的动作
|
||||||
self._update_body_action_cooldown()
|
self._update_body_action_cooldown()
|
||||||
available_actions = [k for k in BODY_CODE.keys() if k not in self.body_action_cooldown]
|
available_actions = [k for k in get_body_code().keys() if k not in self.body_action_cooldown]
|
||||||
all_actions = "\n".join(available_actions)
|
all_actions = "\n".join(available_actions)
|
||||||
|
|
||||||
prompt = await global_prompt_manager.format_prompt(
|
prompt = await global_prompt_manager.format_prompt(
|
||||||
"change_action_prompt",
|
"change_action_prompt",
|
||||||
chat_talking_prompt=chat_talking_prompt,
|
chat_talking_prompt=chat_talking_prompt,
|
||||||
@@ -163,19 +197,18 @@ class ChatAction:
|
|||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"prompt: {prompt}")
|
logger.info(f"prompt: {prompt}")
|
||||||
response, (reasoning_content, model_name) = await self.action_model.generate_response_async(prompt=prompt)
|
response, (reasoning_content, _, _) = await self.action_model.generate_response_async(
|
||||||
|
prompt=prompt, temperature=0.7
|
||||||
|
)
|
||||||
logger.info(f"response: {response}")
|
logger.info(f"response: {response}")
|
||||||
logger.info(f"reasoning_content: {reasoning_content}")
|
logger.info(f"reasoning_content: {reasoning_content}")
|
||||||
|
|
||||||
action_data = json.loads(repair_json(response))
|
if action_data := json.loads(repair_json(response)):
|
||||||
|
|
||||||
if action_data:
|
|
||||||
# 记录原动作,切换后进入冷却
|
# 记录原动作,切换后进入冷却
|
||||||
prev_body_action = self.body_action
|
prev_body_action = self.body_action
|
||||||
new_body_action = action_data.get("body_action", self.body_action)
|
new_body_action = action_data.get("body_action", self.body_action)
|
||||||
if new_body_action != prev_body_action:
|
if new_body_action != prev_body_action and prev_body_action:
|
||||||
if prev_body_action:
|
self.body_action_cooldown[prev_body_action] = 3
|
||||||
self.body_action_cooldown[prev_body_action] = 3
|
|
||||||
self.body_action = new_body_action
|
self.body_action = new_body_action
|
||||||
self.head_action = action_data.get("head_action", self.head_action)
|
self.head_action = action_data.get("head_action", self.head_action)
|
||||||
# 发送动作更新
|
# 发送动作更新
|
||||||
@@ -213,10 +246,9 @@ class ChatAction:
|
|||||||
prompt_personality = global_config.personality.personality_core
|
prompt_personality = global_config.personality.personality_core
|
||||||
indentify_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}:"
|
indentify_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}:"
|
||||||
try:
|
try:
|
||||||
|
|
||||||
# 冷却池处理:过滤掉冷却中的动作
|
# 冷却池处理:过滤掉冷却中的动作
|
||||||
self._update_body_action_cooldown()
|
self._update_body_action_cooldown()
|
||||||
available_actions = [k for k in BODY_CODE.keys() if k not in self.body_action_cooldown]
|
available_actions = [k for k in get_body_code().keys() if k not in self.body_action_cooldown]
|
||||||
all_actions = "\n".join(available_actions)
|
all_actions = "\n".join(available_actions)
|
||||||
|
|
||||||
prompt = await global_prompt_manager.format_prompt(
|
prompt = await global_prompt_manager.format_prompt(
|
||||||
@@ -228,17 +260,17 @@ class ChatAction:
|
|||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"prompt: {prompt}")
|
logger.info(f"prompt: {prompt}")
|
||||||
response, (reasoning_content, model_name) = await self.action_model.generate_response_async(prompt=prompt)
|
response, (reasoning_content, _, _) = await self.action_model.generate_response_async(
|
||||||
|
prompt=prompt, temperature=0.7
|
||||||
|
)
|
||||||
logger.info(f"response: {response}")
|
logger.info(f"response: {response}")
|
||||||
logger.info(f"reasoning_content: {reasoning_content}")
|
logger.info(f"reasoning_content: {reasoning_content}")
|
||||||
|
|
||||||
action_data = json.loads(repair_json(response))
|
if action_data := json.loads(repair_json(response)):
|
||||||
if action_data:
|
|
||||||
prev_body_action = self.body_action
|
prev_body_action = self.body_action
|
||||||
new_body_action = action_data.get("body_action", self.body_action)
|
new_body_action = action_data.get("body_action", self.body_action)
|
||||||
if new_body_action != prev_body_action:
|
if new_body_action != prev_body_action and prev_body_action:
|
||||||
if prev_body_action:
|
self.body_action_cooldown[prev_body_action] = 6
|
||||||
self.body_action_cooldown[prev_body_action] = 6
|
|
||||||
self.body_action = new_body_action
|
self.body_action = new_body_action
|
||||||
# 发送动作更新
|
# 发送动作更新
|
||||||
await self.send_action_update()
|
await self.send_action_update()
|
||||||
@@ -306,9 +338,6 @@ class ActionManager:
|
|||||||
return new_action_state
|
return new_action_state
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
init_prompt()
|
init_prompt()
|
||||||
|
|
||||||
action_manager = ActionManager()
|
action_manager = ActionManager()
|
||||||
|
|||||||
@@ -16,10 +16,9 @@ import json
|
|||||||
from .s4u_mood_manager import mood_manager
|
from .s4u_mood_manager import mood_manager
|
||||||
from src.person_info.relationship_builder_manager import relationship_builder_manager
|
from src.person_info.relationship_builder_manager import relationship_builder_manager
|
||||||
from src.mais4u.s4u_config import s4u_config
|
from src.mais4u.s4u_config import s4u_config
|
||||||
from src.person_info.person_info import PersonInfoManager
|
from src.person_info.person_info import get_person_id
|
||||||
from .super_chat_manager import get_super_chat_manager
|
from .super_chat_manager import get_super_chat_manager
|
||||||
from .yes_or_no import yes_or_no_head
|
from .yes_or_no import yes_or_no_head
|
||||||
from src.mais4u.constant_s4u import ENABLE_S4U
|
|
||||||
|
|
||||||
logger = get_logger("S4U_chat")
|
logger = get_logger("S4U_chat")
|
||||||
|
|
||||||
@@ -137,7 +136,7 @@ class MessageSenderContainer:
|
|||||||
await self.storage.store_message(bot_message, self.chat_stream)
|
await self.storage.store_message(bot_message, self.chat_stream)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[{self.chat_stream.get_stream_name()}] 消息发送或存储时出现错误: {e}", exc_info=True)
|
logger.error(f"[消息流: {self.chat_stream.stream_id}] 消息发送或存储时出现错误: {e}", exc_info=True)
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
# CRUCIAL: Always call task_done() for any item that was successfully retrieved.
|
# CRUCIAL: Always call task_done() for any item that was successfully retrieved.
|
||||||
@@ -166,7 +165,7 @@ class S4UChatManager:
|
|||||||
return self.s4u_chats[chat_stream.stream_id]
|
return self.s4u_chats[chat_stream.stream_id]
|
||||||
|
|
||||||
|
|
||||||
if not ENABLE_S4U:
|
if not s4u_config.enable_s4u:
|
||||||
s4u_chat_manager = None
|
s4u_chat_manager = None
|
||||||
else:
|
else:
|
||||||
s4u_chat_manager = S4UChatManager()
|
s4u_chat_manager = S4UChatManager()
|
||||||
@@ -262,7 +261,7 @@ class S4UChat:
|
|||||||
"""根据VIP状态和中断逻辑将消息放入相应队列。"""
|
"""根据VIP状态和中断逻辑将消息放入相应队列。"""
|
||||||
user_id = message.message_info.user_info.user_id
|
user_id = message.message_info.user_info.user_id
|
||||||
platform = message.message_info.platform
|
platform = message.message_info.platform
|
||||||
person_id = PersonInfoManager.get_person_id(platform, user_id)
|
person_id = get_person_id(platform, user_id)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
is_gift = message.is_gift
|
is_gift = message.is_gift
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user