Merge branch 'MaiM-with-u:dev' into dev
This commit is contained in:
27
.github/workflows/precheck.yml
vendored
27
.github/workflows/precheck.yml
vendored
@@ -4,21 +4,32 @@ on: [pull_request]
|
||||
|
||||
jobs:
|
||||
conflict-check:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: [self-hosted, Windows, X64]
|
||||
outputs:
|
||||
conflict: ${{ steps.check-conflicts.outputs.conflict }}
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: Check Conflicts
|
||||
id: check-conflicts
|
||||
run: |
|
||||
git fetch origin main
|
||||
if git diff --name-only --diff-filter=U origin/main...HEAD | grep .; then
|
||||
echo "CONFLICT=true" >> $GITHUB_ENV
|
||||
fi
|
||||
$conflicts = git diff --name-only --diff-filter=U origin/main...HEAD
|
||||
if ($conflicts) {
|
||||
echo "conflict=true" >> $env:GITHUB_OUTPUT
|
||||
Write-Host "Conflicts detected in files: $conflicts"
|
||||
} else {
|
||||
echo "conflict=false" >> $env:GITHUB_OUTPUT
|
||||
Write-Host "No conflicts detected"
|
||||
}
|
||||
shell: pwsh
|
||||
labeler:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: [self-hosted, Windows, X64]
|
||||
needs: conflict-check
|
||||
if: needs.conflict-check.outputs.conflict == 'true'
|
||||
steps:
|
||||
- uses: actions/github-script@v6
|
||||
if: env.CONFLICT == 'true'
|
||||
- uses: actions/github-script@v7
|
||||
with:
|
||||
script: |
|
||||
github.rest.issues.addLabels({
|
||||
|
||||
18
.github/workflows/ruff-pr.yml
vendored
18
.github/workflows/ruff-pr.yml
vendored
@@ -1,9 +1,21 @@
|
||||
name: Ruff
|
||||
name: Ruff PR Check
|
||||
on: [ pull_request ]
|
||||
jobs:
|
||||
ruff:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: [self-hosted, Windows, X64]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: astral-sh/ruff-action@v3
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: Install Ruff and Run Checks
|
||||
uses: astral-sh/ruff-action@v3
|
||||
with:
|
||||
args: "--version"
|
||||
version: "latest"
|
||||
- name: Run Ruff Check (No Fix)
|
||||
run: ruff check --output-format=github
|
||||
shell: pwsh
|
||||
- name: Run Ruff Format Check
|
||||
run: ruff format --check --diff
|
||||
shell: pwsh
|
||||
|
||||
|
||||
21
.github/workflows/ruff.yml
vendored
21
.github/workflows/ruff.yml
vendored
@@ -7,13 +7,18 @@ on:
|
||||
- dev
|
||||
- dev-refactor # 例如:匹配所有以 feature/ 开头的分支
|
||||
# 添加你希望触发此 workflow 的其他分支
|
||||
workflow_dispatch: # 允许手动触发工作流
|
||||
branches:
|
||||
- main
|
||||
- dev
|
||||
- dev-refactor
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
|
||||
jobs:
|
||||
ruff:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: [self-hosted, Windows, X64]
|
||||
# 关键修改:添加条件判断
|
||||
# 确保只有在 event_name 是 'push' 且不是由 Pull Request 引起的 push 时才运行
|
||||
if: github.event_name == 'push' && !startsWith(github.ref, 'refs/pull/')
|
||||
@@ -29,14 +34,20 @@ jobs:
|
||||
args: "--version"
|
||||
version: "latest"
|
||||
- name: Run Ruff Fix
|
||||
run: ruff check --fix --unsafe-fixes || true
|
||||
run: ruff check --fix --unsafe-fixes; if ($LASTEXITCODE -ne 0) { Write-Host "Ruff check completed with warnings" }
|
||||
shell: pwsh
|
||||
- name: Run Ruff Format
|
||||
run: ruff format || true
|
||||
run: ruff format; if ($LASTEXITCODE -ne 0) { Write-Host "Ruff format completed with warnings" }
|
||||
shell: pwsh
|
||||
- name: 提交更改
|
||||
if: success()
|
||||
run: |
|
||||
git config --local user.email "github-actions[bot]@users.noreply.github.com"
|
||||
git config --local user.name "github-actions[bot]"
|
||||
git add -A
|
||||
git diff --quiet && git diff --staged --quiet || git commit -m "🤖 自动格式化代码 [skip ci]"
|
||||
git push
|
||||
$changes = git diff --quiet; $staged = git diff --staged --quiet
|
||||
if (-not ($changes -and $staged)) {
|
||||
git commit -m "🤖 自动格式化代码 [skip ci]"
|
||||
git push
|
||||
}
|
||||
shell: pwsh
|
||||
|
||||
7
.gitignore
vendored
7
.gitignore
vendored
@@ -40,10 +40,13 @@ config/bot_config.toml
|
||||
config/bot_config.toml.bak
|
||||
config/lpmm_config.toml
|
||||
config/lpmm_config.toml.bak
|
||||
template/compare/bot_config_template.toml
|
||||
(测试版)麦麦生成人格.bat
|
||||
(临时版)麦麦开始学习.bat
|
||||
src/plugins/utils/statistic.py
|
||||
CLAUDE.md
|
||||
s4u.s4u
|
||||
s4u.s4u1
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
@@ -316,4 +319,6 @@ run_pet.bat
|
||||
!/plugins/hello_world_plugin
|
||||
!/plugins/take_picture_plugin
|
||||
|
||||
config.toml
|
||||
config.toml
|
||||
|
||||
interested_rates.txt
|
||||
50
EULA.md
50
EULA.md
@@ -1,6 +1,6 @@
|
||||
# **MaiBot最终用户许可协议**
|
||||
**版本:V1.0**
|
||||
**更新日期:2025年5月9日**
|
||||
**版本:V1.1**
|
||||
**更新日期:2025年7月10日**
|
||||
**生效日期:2025年3月18日**
|
||||
**适用的MaiBot版本号:所有版本**
|
||||
|
||||
@@ -37,6 +37,22 @@
|
||||
**2.5** 项目团队**不对**第三方API的服务质量、稳定性、准确性、安全性负责,亦**不对**第三方API的服务变更、终止、限制等行为负责。
|
||||
|
||||
|
||||
### 插件系统授权和责任免责
|
||||
|
||||
**2.6** 您**了解**本项目包含插件系统功能,允许加载和使用由第三方开发者(非MaiBot核心开发组成员)开发的插件。这些第三方插件可能具有独立的许可证条款和使用协议。
|
||||
|
||||
**2.7** 您**了解并同意**:
|
||||
- 第三方插件的开发、维护、分发由其各自的开发者负责,**与MaiBot项目团队无关**;
|
||||
- 第三方插件的功能、质量、安全性、合规性**完全由插件开发者负责**;
|
||||
- MaiBot项目团队**仅提供**插件系统的技术框架,**不对**任何第三方插件的内容、行为或后果承担责任;
|
||||
- 您使用任何第三方插件的风险**完全由您自行承担**;
|
||||
|
||||
**2.8** 在使用第三方插件前,您**应当**:
|
||||
- 仔细阅读并遵守插件开发者提供的许可证条款和使用协议;
|
||||
- 自行评估插件的安全性、合规性和适用性;
|
||||
- 确保插件的使用符合您所在地区的法律法规要求;
|
||||
|
||||
|
||||
## 三、用户行为
|
||||
|
||||
**3.1** 您**了解**本项目会将您的配置信息、输入指令和生成内容发送到第三方API,您**不应**在输入指令和生成内容中包含以下内容:
|
||||
@@ -50,6 +66,13 @@
|
||||
|
||||
**3.3** 您**应当**自行确保您被存储在本项目的知识库、记忆库和日志中的输入和输出内容的合法性与合规性以及存储行为的合法性与合规性。您需**自行承担**由此产生的任何法律责任。
|
||||
|
||||
**3.4** 对于第三方插件的使用,您**不应**:
|
||||
- 使用可能存在安全漏洞、恶意代码或违法内容的插件;
|
||||
- 通过插件进行任何违反法律法规的行为;
|
||||
- 将插件用于侵犯他人权益或危害系统安全的用途;
|
||||
|
||||
**3.5** 您**承诺**对使用第三方插件的行为及其后果承担**完全责任**,包括但不限于因插件缺陷、恶意行为或不当使用造成的任何损失或法律纠纷。
|
||||
|
||||
|
||||
|
||||
## 四、免责条款
|
||||
@@ -58,6 +81,12 @@
|
||||
|
||||
**4.2** 除本协议条目2.4提到的隐私政策之外,项目团队**不会**对您提供任何形式的担保,亦**不对**使用本项目的造成的任何后果负责。
|
||||
|
||||
**4.3** 关于第三方插件,项目团队**明确声明**:
|
||||
- 项目团队**不对**任何第三方插件的功能、安全性、稳定性、合规性或适用性提供任何形式的保证或担保;
|
||||
- 项目团队**不对**因使用第三方插件而产生的任何直接或间接损失、数据丢失、系统故障、安全漏洞、法律纠纷或其他后果承担责任;
|
||||
- 第三方插件的质量问题、技术支持、bug修复等事宜应**直接联系插件开发者**,与项目团队无关;
|
||||
- 项目团队**保留**在不另行通知的情况下,对插件系统功能进行修改、限制或移除的权利;
|
||||
|
||||
## 五、其他条款
|
||||
|
||||
**5.1** 项目团队有权**随时修改本协议的条款**,但**没有**义务通知您。修改后的协议将在本项目的新版本中生效,您应定期检查本协议的最新版本。
|
||||
@@ -91,6 +120,23 @@
|
||||
- 如感到心理不适,请及时寻求专业心理咨询服务。
|
||||
- 如遇心理困扰,请寻求专业帮助(全国心理援助热线:12355)。
|
||||
|
||||
**2.3 第三方插件风险**
|
||||
|
||||
本项目的插件系统允许加载第三方开发的插件,这可能带来以下风险:
|
||||
- **安全风险**:第三方插件可能包含恶意代码、安全漏洞或未知的安全威胁;
|
||||
- **稳定性风险**:插件可能导致系统崩溃、性能下降或功能异常;
|
||||
- **隐私风险**:插件可能收集、传输或泄露您的个人信息和数据;
|
||||
- **合规风险**:插件的功能或行为可能违反相关法律法规或平台规则;
|
||||
- **兼容性风险**:插件可能与主程序或其他插件产生冲突;
|
||||
|
||||
**因此,在使用第三方插件时,请务必:**
|
||||
|
||||
- 仅从可信来源获取和安装插件;
|
||||
- 在安装前仔细了解插件的功能、权限和开发者信息;
|
||||
- 定期检查和更新已安装的插件;
|
||||
- 如发现插件异常行为,请立即停止使用并卸载;
|
||||
- 对插件的使用后果承担完全责任;
|
||||
|
||||
### 三、其他
|
||||
**3.1 争议解决**
|
||||
- 本协议适用中国法律,争议提交相关地区法院管辖;
|
||||
|
||||
15
PRIVACY.md
15
PRIVACY.md
@@ -1,6 +1,6 @@
|
||||
### MaiBot用户隐私条款
|
||||
**版本:V1.0**
|
||||
**更新日期:2025年5月9日**
|
||||
**版本:V1.1**
|
||||
**更新日期:2025年7月10日**
|
||||
**生效日期:2025年3月18日**
|
||||
**适用的MaiBot版本号:所有版本**
|
||||
|
||||
@@ -16,6 +16,13 @@ MaiBot项目团队(以下简称项目团队)**尊重并保护**用户(以
|
||||
|
||||
**1.4** 本项目可能**会**收集部分统计信息(如使用频率、基础指令类型)以改进服务,您可在[bot_config.toml]中随时关闭此功能**。
|
||||
|
||||
**1.5** 由于您的自身行为或不可抗力等情形,导致上述可能涉及您隐私或您认为是私人信息的内容发生被泄露、批漏,或被第三方获取、使用、转让等情形的,均由您**自行承担**不利后果,我们对此**不承担**任何责任。
|
||||
**1.5** 关于第三方插件的隐私处理:
|
||||
- 本项目包含插件系统,允许加载第三方开发者开发的插件;
|
||||
- **第三方插件可能会**收集、处理、存储或传输您的数据,这些行为**完全由插件开发者控制**,与项目团队无关;
|
||||
- 项目团队**无法监控或控制**第三方插件的数据处理行为,亦**无法保证**第三方插件的隐私安全性;
|
||||
- 第三方插件的隐私政策**由插件开发者负责制定和执行**,您应直接向插件开发者了解其隐私处理方式;
|
||||
- 您使用第三方插件时,**需自行评估**插件的隐私风险并**自行承担**相关后果;
|
||||
|
||||
**1.6** 项目团队保留在未来更新隐私条款的权利,但没有义务通知您。若您不同意更新后的隐私条款,您应立即停止使用本项目。
|
||||
**1.6** 由于您的自身行为或不可抗力等情形,导致上述可能涉及您隐私或您认为是私人信息的内容发生被泄露、批漏,或被第三方获取、使用、转让等情形的,均由您**自行承担**不利后果,我们对此**不承担**任何责任。**特别地,因使用第三方插件而导致的任何隐私泄露或数据安全问题,项目团队概不负责。**
|
||||
|
||||
**1.7** 项目团队保留在未来更新隐私条款的权利,但没有义务通知您。若您不同意更新后的隐私条款,您应立即停止使用本项目。
|
||||
144
bot.py
144
bot.py
@@ -8,6 +8,7 @@ if os.path.exists(".env"):
|
||||
print("成功加载环境变量配置")
|
||||
else:
|
||||
print("未找到.env文件,请确保程序所需的环境变量被正确设置")
|
||||
raise FileNotFoundError(".env 文件不存在,请创建并配置所需的环境变量")
|
||||
import sys
|
||||
import time
|
||||
import platform
|
||||
@@ -140,87 +141,88 @@ async def graceful_shutdown():
|
||||
logger.error(f"麦麦关闭失败: {e}", exc_info=True)
|
||||
|
||||
|
||||
def _calculate_file_hash(file_path: Path, file_type: str) -> str:
|
||||
"""计算文件的MD5哈希值"""
|
||||
if not file_path.exists():
|
||||
logger.error(f"{file_type} 文件不存在")
|
||||
raise FileNotFoundError(f"{file_type} 文件不存在")
|
||||
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
return hashlib.md5(content.encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
def _check_agreement_status(file_hash: str, confirm_file: Path, env_var: str) -> tuple[bool, bool]:
|
||||
"""检查协议确认状态
|
||||
|
||||
Returns:
|
||||
tuple[bool, bool]: (已确认, 未更新)
|
||||
"""
|
||||
# 检查环境变量确认
|
||||
if file_hash == os.getenv(env_var):
|
||||
return True, False
|
||||
|
||||
# 检查确认文件
|
||||
if confirm_file.exists():
|
||||
with open(confirm_file, "r", encoding="utf-8") as f:
|
||||
confirmed_content = f.read()
|
||||
if file_hash == confirmed_content:
|
||||
return True, False
|
||||
|
||||
return False, True
|
||||
|
||||
|
||||
def _prompt_user_confirmation(eula_hash: str, privacy_hash: str) -> None:
|
||||
"""提示用户确认协议"""
|
||||
confirm_logger.critical("EULA或隐私条款内容已更新,请在阅读后重新确认,继续运行视为同意更新后的以上两款协议")
|
||||
confirm_logger.critical(
|
||||
f'输入"同意"或"confirmed"或设置环境变量"EULA_AGREE={eula_hash}"和"PRIVACY_AGREE={privacy_hash}"继续运行'
|
||||
)
|
||||
|
||||
while True:
|
||||
user_input = input().strip().lower()
|
||||
if user_input in ["同意", "confirmed"]:
|
||||
return
|
||||
confirm_logger.critical('请输入"同意"或"confirmed"以继续运行')
|
||||
|
||||
|
||||
def _save_confirmations(eula_updated: bool, privacy_updated: bool, eula_hash: str, privacy_hash: str) -> None:
|
||||
"""保存用户确认结果"""
|
||||
if eula_updated:
|
||||
logger.info(f"更新EULA确认文件{eula_hash}")
|
||||
Path("eula.confirmed").write_text(eula_hash, encoding="utf-8")
|
||||
|
||||
if privacy_updated:
|
||||
logger.info(f"更新隐私条款确认文件{privacy_hash}")
|
||||
Path("privacy.confirmed").write_text(privacy_hash, encoding="utf-8")
|
||||
|
||||
|
||||
def check_eula():
|
||||
eula_confirm_file = Path("eula.confirmed")
|
||||
privacy_confirm_file = Path("privacy.confirmed")
|
||||
eula_file = Path("EULA.md")
|
||||
privacy_file = Path("PRIVACY.md")
|
||||
"""检查EULA和隐私条款确认状态"""
|
||||
# 计算文件哈希值
|
||||
eula_hash = _calculate_file_hash(Path("EULA.md"), "EULA.md")
|
||||
privacy_hash = _calculate_file_hash(Path("PRIVACY.md"), "PRIVACY.md")
|
||||
|
||||
eula_updated = True
|
||||
privacy_updated = True
|
||||
# 检查确认状态
|
||||
eula_confirmed, eula_updated = _check_agreement_status(eula_hash, Path("eula.confirmed"), "EULA_AGREE")
|
||||
privacy_confirmed, privacy_updated = _check_agreement_status(
|
||||
privacy_hash, Path("privacy.confirmed"), "PRIVACY_AGREE"
|
||||
)
|
||||
|
||||
eula_confirmed = False
|
||||
privacy_confirmed = False
|
||||
# 早期返回:如果都已确认且未更新
|
||||
if eula_confirmed and privacy_confirmed:
|
||||
return
|
||||
|
||||
# 首先计算当前EULA文件的哈希值
|
||||
if eula_file.exists():
|
||||
with open(eula_file, "r", encoding="utf-8") as f:
|
||||
eula_content = f.read()
|
||||
eula_new_hash = hashlib.md5(eula_content.encode("utf-8")).hexdigest()
|
||||
else:
|
||||
logger.error("EULA.md 文件不存在")
|
||||
raise FileNotFoundError("EULA.md 文件不存在")
|
||||
|
||||
# 首先计算当前隐私条款文件的哈希值
|
||||
if privacy_file.exists():
|
||||
with open(privacy_file, "r", encoding="utf-8") as f:
|
||||
privacy_content = f.read()
|
||||
privacy_new_hash = hashlib.md5(privacy_content.encode("utf-8")).hexdigest()
|
||||
else:
|
||||
logger.error("PRIVACY.md 文件不存在")
|
||||
raise FileNotFoundError("PRIVACY.md 文件不存在")
|
||||
|
||||
# 检查EULA确认文件是否存在
|
||||
if eula_confirm_file.exists():
|
||||
with open(eula_confirm_file, "r", encoding="utf-8") as f:
|
||||
confirmed_content = f.read()
|
||||
if eula_new_hash == confirmed_content:
|
||||
eula_confirmed = True
|
||||
eula_updated = False
|
||||
if eula_new_hash == os.getenv("EULA_AGREE"):
|
||||
eula_confirmed = True
|
||||
eula_updated = False
|
||||
|
||||
# 检查隐私条款确认文件是否存在
|
||||
if privacy_confirm_file.exists():
|
||||
with open(privacy_confirm_file, "r", encoding="utf-8") as f:
|
||||
confirmed_content = f.read()
|
||||
if privacy_new_hash == confirmed_content:
|
||||
privacy_confirmed = True
|
||||
privacy_updated = False
|
||||
if privacy_new_hash == os.getenv("PRIVACY_AGREE"):
|
||||
privacy_confirmed = True
|
||||
privacy_updated = False
|
||||
|
||||
# 如果EULA或隐私条款有更新,提示用户重新确认
|
||||
# 如果有更新,需要重新确认
|
||||
if eula_updated or privacy_updated:
|
||||
confirm_logger.critical("EULA或隐私条款内容已更新,请在阅读后重新确认,继续运行视为同意更新后的以上两款协议")
|
||||
confirm_logger.critical(
|
||||
f'输入"同意"或"confirmed"或设置环境变量"EULA_AGREE={eula_new_hash}"和"PRIVACY_AGREE={privacy_new_hash}"继续运行'
|
||||
)
|
||||
while True:
|
||||
user_input = input().strip().lower()
|
||||
if user_input in ["同意", "confirmed"]:
|
||||
# print("确认成功,继续运行")
|
||||
# print(f"确认成功,继续运行{eula_updated} {privacy_updated}")
|
||||
if eula_updated:
|
||||
logger.info(f"更新EULA确认文件{eula_new_hash}")
|
||||
eula_confirm_file.write_text(eula_new_hash, encoding="utf-8")
|
||||
if privacy_updated:
|
||||
logger.info(f"更新隐私条款确认文件{privacy_new_hash}")
|
||||
privacy_confirm_file.write_text(privacy_new_hash, encoding="utf-8")
|
||||
break
|
||||
else:
|
||||
confirm_logger.critical('请输入"同意"或"confirmed"以继续运行')
|
||||
return
|
||||
elif eula_confirmed and privacy_confirmed:
|
||||
return
|
||||
_prompt_user_confirmation(eula_hash, privacy_hash)
|
||||
_save_confirmations(eula_updated, privacy_updated, eula_hash, privacy_hash)
|
||||
|
||||
|
||||
def raw_main():
|
||||
# 利用 TZ 环境变量设定程序工作的时区
|
||||
if platform.system().lower() != "windows":
|
||||
time.tzset()
|
||||
time.tzset() # type: ignore
|
||||
|
||||
check_eula()
|
||||
logger.info("检查EULA和隐私条款完成")
|
||||
|
||||
26
changes.md
Normal file
26
changes.md
Normal file
@@ -0,0 +1,26 @@
|
||||
# 插件API与规范修改
|
||||
|
||||
1. 现在`plugin_system`的`__init__.py`文件中包含了所有插件API的导入,用户可以直接使用`from plugin_system import *`来导入所有API。
|
||||
|
||||
2. register_plugin函数现在转移到了`plugin_system.apis.plugin_register_api`模块中,用户可以通过`from plugin_system.apis.plugin_register_api import register_plugin`来导入。
|
||||
|
||||
3. 现在强制要求的property如下:
|
||||
- `plugin_name`: 插件名称,必须是唯一的。(与文件夹相同)
|
||||
- `enable_plugin`: 是否启用插件,默认为`True`。
|
||||
- `dependencies`: 插件依赖的其他插件列表,默认为空。**现在并不检查(也许)**
|
||||
- `python_dependencies`: 插件依赖的Python包列表,默认为空。**现在并不检查**
|
||||
- `config_file_name`: 插件配置文件名,默认为`config.toml`。
|
||||
- `config_schema`: 插件配置文件的schema,用于自动生成配置文件。
|
||||
|
||||
# 插件系统修改
|
||||
1. 现在所有的匹配模式不再是关键字了,而是枚举类。**(可能有遗漏)**
|
||||
2. 修复了一下显示插件信息不显示的问题。同时精简了一下显示内容
|
||||
3. 修复了插件系统混用了`plugin_name`和`display_name`的问题。现在所有的插件信息都使用`display_name`来显示,而内部标识仍然使用`plugin_name`。**(可能有遗漏)**
|
||||
3. 部分API的参数类型和返回值进行了调整
|
||||
- `chat_api.py`中获取流的参数中可以使用一个特殊的枚举类型来获得所有平台的 ChatStream 了。
|
||||
- `config_api.py`中的`get_global_config`和`get_plugin_config`方法现在支持嵌套访问的配置键名。
|
||||
- `database_api.py`中的`db_query`方法调整了参数顺序以增强参数限制的同时,保证了typing正确;`db_get`方法增加了`single_result`参数,与`db_query`保持一致。
|
||||
4. 现在增加了参数类型检查,完善了对应注释
|
||||
5. 现在插件抽象出了总基类 `PluginBase`
|
||||
- 基于`Action`和`Command`的插件基类现在为`BasePlugin`,它继承自`PluginBase`,由`register_plugin`装饰器注册。
|
||||
- 基于`Event`的插件基类现在为`BaseEventPlugin`,它也继承自`PluginBase`,由`register_event_plugin`装饰器注册。
|
||||
@@ -27,8 +27,8 @@ services:
|
||||
# image: infinitycat/maibot:dev
|
||||
environment:
|
||||
- TZ=Asia/Shanghai
|
||||
# - EULA_AGREE=bda99dca873f5d8044e9987eac417e01 # 同意EULA
|
||||
# - PRIVACY_AGREE=42dddb3cbe2b784b45a2781407b298a1 # 同意EULA
|
||||
# - EULA_AGREE=99f08e0cab0190de853cb6af7d64d4de # 同意EULA
|
||||
# - PRIVACY_AGREE=9943b855e72199d0f5016ea39052f1b6 # 同意EULA
|
||||
# ports:
|
||||
# - "8000:8000"
|
||||
volumes:
|
||||
|
||||
@@ -1,88 +0,0 @@
|
||||
0.02388322700338219
|
||||
0.02789637960584667
|
||||
6.1002656551513885
|
||||
6.1002656551513885
|
||||
6.1171064375469255
|
||||
6.106626351535966
|
||||
6.112541462320276
|
||||
0.04230527065567247
|
||||
9.04004621778353
|
||||
6.104278807753853
|
||||
6.106626351535966
|
||||
6.198517524266092
|
||||
0.020373848987042205
|
||||
6.106626351535966
|
||||
6.104278807753853
|
||||
0.03203964454588806
|
||||
6.104278807753853
|
||||
6.104278807753853
|
||||
6.104278807753853
|
||||
6.104278807753853
|
||||
6.1002656551513885
|
||||
6.1002656551513885
|
||||
6.1002656551513885
|
||||
0.02605261040985793
|
||||
1.0273445569816615
|
||||
0.02203945780739345
|
||||
0.03203964454588806
|
||||
0.014013152602464482
|
||||
0.03203964454588806
|
||||
1.018026305204929
|
||||
4.183876948487736
|
||||
0.020373848987042205
|
||||
0.19241219083184483
|
||||
6.103223210543543
|
||||
6.1002656551513885
|
||||
6.103223210543543
|
||||
6.103223210543543
|
||||
1.021266343711497
|
||||
6.103223210543543
|
||||
0.018026305204928966
|
||||
0.020373848987042205
|
||||
6.106626351535966
|
||||
6.089034714923968
|
||||
0.03203964454588806
|
||||
6.089034714923968
|
||||
0.027344556981661584
|
||||
6.0950644780757655
|
||||
1.0360527971483526
|
||||
0.02126634371149695
|
||||
6.100437294458919
|
||||
6.181947292804878
|
||||
6.108429840061738
|
||||
6.107935292179331
|
||||
6.099721599895046
|
||||
6.091382258706081
|
||||
6.747791924069589
|
||||
0.016360696384577725
|
||||
0.016360696384577725
|
||||
0.016360696384577725
|
||||
0.014013152602464482
|
||||
0.019318251776732617
|
||||
6.093511295222046
|
||||
0.019318251776732617
|
||||
0.019318251776732617
|
||||
0.019318251776732617
|
||||
6.093511295222046
|
||||
0.019318251776732617
|
||||
7.515984058229312
|
||||
1.6068256002855255
|
||||
6.093940362250887
|
||||
1.6170212888969302
|
||||
6.179882232137178
|
||||
6.179882232137178
|
||||
6.087979117713658
|
||||
6.089034714923968
|
||||
1.200467605219352
|
||||
6.0899272096484225
|
||||
6.091382258706081
|
||||
6.087979117713658
|
||||
6.089034714923968
|
||||
6.091382258706081
|
||||
6.087979117713658
|
||||
6.087979117713658
|
||||
1.7348177649966143
|
||||
6.093940362250887
|
||||
8.65717782684436
|
||||
8.65717782684436
|
||||
0.020373848987042205
|
||||
@@ -10,8 +10,7 @@
|
||||
"license": "GPL-v3.0-or-later",
|
||||
|
||||
"host_application": {
|
||||
"min_version": "0.8.0",
|
||||
"max_version": "0.8.0"
|
||||
"min_version": "0.8.0"
|
||||
},
|
||||
"homepage_url": "https://github.com/MaiM-with-u/maibot",
|
||||
"repository_url": "https://github.com/MaiM-with-u/maibot",
|
||||
|
||||
@@ -103,6 +103,8 @@ class HelloWorldPlugin(BasePlugin):
|
||||
# 插件基本信息
|
||||
plugin_name = "hello_world_plugin" # 内部标识符
|
||||
enable_plugin = True
|
||||
dependencies = [] # 插件依赖列表
|
||||
python_dependencies = [] # Python包依赖列表
|
||||
config_file_name = "config.toml" # 配置文件名
|
||||
|
||||
# 配置节描述
|
||||
|
||||
@@ -10,8 +10,7 @@
|
||||
"license": "GPL-v3.0-or-later",
|
||||
|
||||
"host_application": {
|
||||
"min_version": "0.8.0",
|
||||
"max_version": "0.8.0"
|
||||
"min_version": "0.9.0"
|
||||
},
|
||||
"homepage_url": "https://github.com/MaiM-with-u/maibot",
|
||||
"repository_url": "https://github.com/MaiM-with-u/maibot",
|
||||
|
||||
@@ -36,11 +36,12 @@ import urllib.error
|
||||
import base64
|
||||
import traceback
|
||||
|
||||
from src.plugin_system.base.base_plugin import BasePlugin, register_plugin
|
||||
from src.plugin_system.base.base_plugin import BasePlugin
|
||||
from src.plugin_system.base.base_action import BaseAction
|
||||
from src.plugin_system.base.base_command import BaseCommand
|
||||
from src.plugin_system.base.component_types import ComponentInfo, ActionActivationType, ChatMode
|
||||
from src.plugin_system.base.config_types import ConfigField
|
||||
from src.plugin_system import register_plugin
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("take_picture_plugin")
|
||||
@@ -105,9 +106,9 @@ class TakePictureAction(BaseAction):
|
||||
bot_nickname = self.api.get_global_config("bot.nickname", "麦麦")
|
||||
bot_personality = self.api.get_global_config("personality.personality_core", "")
|
||||
|
||||
personality_sides = self.api.get_global_config("personality.personality_sides", [])
|
||||
if personality_sides:
|
||||
bot_personality += random.choice(personality_sides)
|
||||
personality_side = self.api.get_global_config("personality.personality_side", [])
|
||||
if personality_side:
|
||||
bot_personality += random.choice(personality_side)
|
||||
|
||||
# 准备模板变量
|
||||
template_vars = {"name": bot_nickname, "personality": bot_personality}
|
||||
@@ -441,7 +442,9 @@ class TakePicturePlugin(BasePlugin):
|
||||
"""拍照插件"""
|
||||
|
||||
plugin_name = "take_picture_plugin" # 内部标识符
|
||||
enable_plugin = True
|
||||
enable_plugin = False
|
||||
dependencies = [] # 插件依赖列表
|
||||
python_dependencies = [] # Python包依赖列表
|
||||
config_file_name = "config.toml"
|
||||
|
||||
# 配置节描述
|
||||
|
||||
@@ -1,7 +1,58 @@
|
||||
[project]
|
||||
name = "MaiMaiBot"
|
||||
version = "0.1.0"
|
||||
description = "MaiMaiBot"
|
||||
name = "MaiBot"
|
||||
version = "0.8.1"
|
||||
description = "MaiCore 是一个基于大语言模型的可交互智能体"
|
||||
requires-python = ">=3.10"
|
||||
dependencies = [
|
||||
"aiohttp>=3.12.14",
|
||||
"apscheduler>=3.11.0",
|
||||
"colorama>=0.4.6",
|
||||
"cryptography>=45.0.5",
|
||||
"customtkinter>=5.2.2",
|
||||
"dotenv>=0.9.9",
|
||||
"faiss-cpu>=1.11.0",
|
||||
"fastapi>=0.116.0",
|
||||
"jieba>=0.42.1",
|
||||
"json-repair>=0.47.6",
|
||||
"jsonlines>=4.0.0",
|
||||
"maim-message>=0.3.8",
|
||||
"matplotlib>=3.10.3",
|
||||
"networkx>=3.4.2",
|
||||
"numpy>=2.2.6",
|
||||
"openai>=1.95.0",
|
||||
"packaging>=25.0",
|
||||
"pandas>=2.3.1",
|
||||
"peewee>=3.18.2",
|
||||
"pillow>=11.3.0",
|
||||
"psutil>=7.0.0",
|
||||
"pyarrow>=20.0.0",
|
||||
"pydantic>=2.11.7",
|
||||
"pymongo>=4.13.2",
|
||||
"pypinyin>=0.54.0",
|
||||
"python-dateutil>=2.9.0.post0",
|
||||
"python-dotenv>=1.1.1",
|
||||
"python-igraph>=0.11.9",
|
||||
"quick-algo>=0.1.3",
|
||||
"reportportal-client>=5.6.5",
|
||||
"requests>=2.32.4",
|
||||
"rich>=14.0.0",
|
||||
"ruff>=0.12.2",
|
||||
"scikit-learn>=1.7.0",
|
||||
"scipy>=1.15.3",
|
||||
"seaborn>=0.13.2",
|
||||
"setuptools>=80.9.0",
|
||||
"strawberry-graphql[fastapi]>=0.275.5",
|
||||
"structlog>=25.4.0",
|
||||
"toml>=0.10.2",
|
||||
"tomli>=2.2.1",
|
||||
"tomli-w>=1.2.0",
|
||||
"tomlkit>=0.13.3",
|
||||
"tqdm>=4.67.1",
|
||||
"urllib3>=2.5.0",
|
||||
"uvicorn>=0.35.0",
|
||||
"websockets>=15.0.1",
|
||||
]
|
||||
|
||||
|
||||
[tool.ruff]
|
||||
|
||||
|
||||
271
requirements.lock
Normal file
271
requirements.lock
Normal file
@@ -0,0 +1,271 @@
|
||||
# This file was autogenerated by uv via the following command:
|
||||
# uv pip compile requirements.txt -o requirements.lock
|
||||
aenum==3.1.16
|
||||
# via reportportal-client
|
||||
aiohappyeyeballs==2.6.1
|
||||
# via aiohttp
|
||||
aiohttp==3.12.14
|
||||
# via
|
||||
# -r requirements.txt
|
||||
# maim-message
|
||||
# reportportal-client
|
||||
aiosignal==1.4.0
|
||||
# via aiohttp
|
||||
annotated-types==0.7.0
|
||||
# via pydantic
|
||||
anyio==4.9.0
|
||||
# via
|
||||
# httpx
|
||||
# openai
|
||||
# starlette
|
||||
apscheduler==3.11.0
|
||||
# via -r requirements.txt
|
||||
attrs==25.3.0
|
||||
# via
|
||||
# aiohttp
|
||||
# jsonlines
|
||||
certifi==2025.7.9
|
||||
# via
|
||||
# httpcore
|
||||
# httpx
|
||||
# reportportal-client
|
||||
# requests
|
||||
cffi==1.17.1
|
||||
# via cryptography
|
||||
charset-normalizer==3.4.2
|
||||
# via requests
|
||||
click==8.2.1
|
||||
# via uvicorn
|
||||
colorama==0.4.6
|
||||
# via
|
||||
# -r requirements.txt
|
||||
# click
|
||||
# tqdm
|
||||
contourpy==1.3.2
|
||||
# via matplotlib
|
||||
cryptography==45.0.5
|
||||
# via
|
||||
# -r requirements.txt
|
||||
# maim-message
|
||||
customtkinter==5.2.2
|
||||
# via -r requirements.txt
|
||||
cycler==0.12.1
|
||||
# via matplotlib
|
||||
darkdetect==0.8.0
|
||||
# via customtkinter
|
||||
distro==1.9.0
|
||||
# via openai
|
||||
dnspython==2.7.0
|
||||
# via pymongo
|
||||
dotenv==0.9.9
|
||||
# via -r requirements.txt
|
||||
faiss-cpu==1.11.0
|
||||
# via -r requirements.txt
|
||||
fastapi==0.116.0
|
||||
# via
|
||||
# -r requirements.txt
|
||||
# maim-message
|
||||
# strawberry-graphql
|
||||
fonttools==4.58.5
|
||||
# via matplotlib
|
||||
frozenlist==1.7.0
|
||||
# via
|
||||
# aiohttp
|
||||
# aiosignal
|
||||
graphql-core==3.2.6
|
||||
# via strawberry-graphql
|
||||
h11==0.16.0
|
||||
# via
|
||||
# httpcore
|
||||
# uvicorn
|
||||
httpcore==1.0.9
|
||||
# via httpx
|
||||
httpx==0.28.1
|
||||
# via openai
|
||||
idna==3.10
|
||||
# via
|
||||
# anyio
|
||||
# httpx
|
||||
# requests
|
||||
# yarl
|
||||
igraph==0.11.9
|
||||
# via python-igraph
|
||||
jieba==0.42.1
|
||||
# via -r requirements.txt
|
||||
jiter==0.10.0
|
||||
# via openai
|
||||
joblib==1.5.1
|
||||
# via scikit-learn
|
||||
json-repair==0.47.6
|
||||
# via -r requirements.txt
|
||||
jsonlines==4.0.0
|
||||
# via -r requirements.txt
|
||||
kiwisolver==1.4.8
|
||||
# via matplotlib
|
||||
maim-message==0.3.8
|
||||
# via -r requirements.txt
|
||||
markdown-it-py==3.0.0
|
||||
# via rich
|
||||
matplotlib==3.10.3
|
||||
# via
|
||||
# -r requirements.txt
|
||||
# seaborn
|
||||
mdurl==0.1.2
|
||||
# via markdown-it-py
|
||||
multidict==6.6.3
|
||||
# via
|
||||
# aiohttp
|
||||
# yarl
|
||||
networkx==3.5
|
||||
# via -r requirements.txt
|
||||
numpy==2.3.1
|
||||
# via
|
||||
# -r requirements.txt
|
||||
# contourpy
|
||||
# faiss-cpu
|
||||
# matplotlib
|
||||
# pandas
|
||||
# scikit-learn
|
||||
# scipy
|
||||
# seaborn
|
||||
openai==1.95.0
|
||||
# via -r requirements.txt
|
||||
packaging==25.0
|
||||
# via
|
||||
# -r requirements.txt
|
||||
# customtkinter
|
||||
# faiss-cpu
|
||||
# matplotlib
|
||||
# strawberry-graphql
|
||||
pandas==2.3.1
|
||||
# via
|
||||
# -r requirements.txt
|
||||
# seaborn
|
||||
peewee==3.18.2
|
||||
# via -r requirements.txt
|
||||
pillow==11.3.0
|
||||
# via
|
||||
# -r requirements.txt
|
||||
# matplotlib
|
||||
propcache==0.3.2
|
||||
# via
|
||||
# aiohttp
|
||||
# yarl
|
||||
psutil==7.0.0
|
||||
# via -r requirements.txt
|
||||
pyarrow==20.0.0
|
||||
# via -r requirements.txt
|
||||
pycparser==2.22
|
||||
# via cffi
|
||||
pydantic==2.11.7
|
||||
# via
|
||||
# -r requirements.txt
|
||||
# fastapi
|
||||
# maim-message
|
||||
# openai
|
||||
pydantic-core==2.33.2
|
||||
# via pydantic
|
||||
pygments==2.19.2
|
||||
# via rich
|
||||
pymongo==4.13.2
|
||||
# via -r requirements.txt
|
||||
pyparsing==3.2.3
|
||||
# via matplotlib
|
||||
pypinyin==0.54.0
|
||||
# via -r requirements.txt
|
||||
python-dateutil==2.9.0.post0
|
||||
# via
|
||||
# -r requirements.txt
|
||||
# matplotlib
|
||||
# pandas
|
||||
# strawberry-graphql
|
||||
python-dotenv==1.1.1
|
||||
# via
|
||||
# -r requirements.txt
|
||||
# dotenv
|
||||
python-igraph==0.11.9
|
||||
# via -r requirements.txt
|
||||
python-multipart==0.0.20
|
||||
# via strawberry-graphql
|
||||
pytz==2025.2
|
||||
# via pandas
|
||||
quick-algo==0.1.3
|
||||
# via -r requirements.txt
|
||||
reportportal-client==5.6.5
|
||||
# via -r requirements.txt
|
||||
requests==2.32.4
|
||||
# via
|
||||
# -r requirements.txt
|
||||
# reportportal-client
|
||||
rich==14.0.0
|
||||
# via -r requirements.txt
|
||||
ruff==0.12.2
|
||||
# via -r requirements.txt
|
||||
scikit-learn==1.7.0
|
||||
# via -r requirements.txt
|
||||
scipy==1.16.0
|
||||
# via
|
||||
# -r requirements.txt
|
||||
# scikit-learn
|
||||
seaborn==0.13.2
|
||||
# via -r requirements.txt
|
||||
setuptools==80.9.0
|
||||
# via -r requirements.txt
|
||||
six==1.17.0
|
||||
# via python-dateutil
|
||||
sniffio==1.3.1
|
||||
# via
|
||||
# anyio
|
||||
# openai
|
||||
starlette==0.46.2
|
||||
# via fastapi
|
||||
strawberry-graphql==0.275.5
|
||||
# via -r requirements.txt
|
||||
structlog==25.4.0
|
||||
# via -r requirements.txt
|
||||
texttable==1.7.0
|
||||
# via igraph
|
||||
threadpoolctl==3.6.0
|
||||
# via scikit-learn
|
||||
toml==0.10.2
|
||||
# via -r requirements.txt
|
||||
tomli==2.2.1
|
||||
# via -r requirements.txt
|
||||
tomli-w==1.2.0
|
||||
# via -r requirements.txt
|
||||
tomlkit==0.13.3
|
||||
# via -r requirements.txt
|
||||
tqdm==4.67.1
|
||||
# via
|
||||
# -r requirements.txt
|
||||
# openai
|
||||
typing-extensions==4.14.1
|
||||
# via
|
||||
# fastapi
|
||||
# openai
|
||||
# pydantic
|
||||
# pydantic-core
|
||||
# strawberry-graphql
|
||||
# typing-inspection
|
||||
typing-inspection==0.4.1
|
||||
# via pydantic
|
||||
tzdata==2025.2
|
||||
# via
|
||||
# pandas
|
||||
# tzlocal
|
||||
tzlocal==5.3.1
|
||||
# via apscheduler
|
||||
urllib3==2.5.0
|
||||
# via
|
||||
# -r requirements.txt
|
||||
# requests
|
||||
uvicorn==0.35.0
|
||||
# via
|
||||
# -r requirements.txt
|
||||
# maim-message
|
||||
websockets==15.0.1
|
||||
# via
|
||||
# -r requirements.txt
|
||||
# maim-message
|
||||
yarl==1.20.1
|
||||
# via aiohttp
|
||||
@@ -1,6 +1,7 @@
|
||||
APScheduler
|
||||
Pillow
|
||||
aiohttp
|
||||
aiohttp-cors
|
||||
colorama
|
||||
customtkinter
|
||||
dotenv
|
||||
|
||||
@@ -9,22 +9,60 @@ import os
|
||||
from time import sleep
|
||||
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
|
||||
from src.chat.knowledge.lpmmconfig import PG_NAMESPACE, global_config
|
||||
from src.chat.knowledge.embedding_store import EmbeddingManager
|
||||
from src.chat.knowledge.llm_client import LLMClient
|
||||
from src.chat.knowledge.open_ie import OpenIE
|
||||
from src.chat.knowledge.kg_manager import KGManager
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.knowledge.utils.hash import get_sha256
|
||||
from src.manager.local_store_manager import local_storage
|
||||
from dotenv import load_dotenv
|
||||
|
||||
|
||||
# 添加项目根目录到 sys.path
|
||||
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
OPENIE_DIR = global_config["persistence"]["openie_data_path"] or os.path.join(ROOT_PATH, "data", "openie")
|
||||
OPENIE_DIR = os.path.join(ROOT_PATH, "data", "openie")
|
||||
|
||||
logger = get_logger("OpenIE导入")
|
||||
|
||||
ENV_FILE = os.path.join(ROOT_PATH, ".env")
|
||||
|
||||
if os.path.exists(".env"):
|
||||
load_dotenv(".env", override=True)
|
||||
print("成功加载环境变量配置")
|
||||
else:
|
||||
print("未找到.env文件,请确保程序所需的环境变量被正确设置")
|
||||
raise FileNotFoundError(".env 文件不存在,请创建并配置所需的环境变量")
|
||||
|
||||
env_mask = {key: os.getenv(key) for key in os.environ}
|
||||
def scan_provider(env_config: dict):
|
||||
provider = {}
|
||||
|
||||
# 利用未初始化 env 时获取的 env_mask 来对新的环境变量集去重
|
||||
# 避免 GPG_KEY 这样的变量干扰检查
|
||||
env_config = dict(filter(lambda item: item[0] not in env_mask, env_config.items()))
|
||||
|
||||
# 遍历 env_config 的所有键
|
||||
for key in env_config:
|
||||
# 检查键是否符合 {provider}_BASE_URL 或 {provider}_KEY 的格式
|
||||
if key.endswith("_BASE_URL") or key.endswith("_KEY"):
|
||||
# 提取 provider 名称
|
||||
provider_name = key.split("_", 1)[0] # 从左分割一次,取第一部分
|
||||
|
||||
# 初始化 provider 的字典(如果尚未初始化)
|
||||
if provider_name not in provider:
|
||||
provider[provider_name] = {"url": None, "key": None}
|
||||
|
||||
# 根据键的类型填充 url 或 key
|
||||
if key.endswith("_BASE_URL"):
|
||||
provider[provider_name]["url"] = env_config[key]
|
||||
elif key.endswith("_KEY"):
|
||||
provider[provider_name]["key"] = env_config[key]
|
||||
|
||||
# 检查每个 provider 是否同时存在 url 和 key
|
||||
for provider_name, config in provider.items():
|
||||
if config["url"] is None or config["key"] is None:
|
||||
logger.error(f"provider 内容:{config}\nenv_config 内容:{env_config}")
|
||||
raise ValueError(f"请检查 '{provider_name}' 提供商配置是否丢失 BASE_URL 或 KEY 环境变量")
|
||||
|
||||
def ensure_openie_dir():
|
||||
"""确保OpenIE数据目录存在"""
|
||||
@@ -58,10 +96,12 @@ def hash_deduplicate(
|
||||
# 保存去重后的三元组
|
||||
new_triple_list_data = {}
|
||||
|
||||
for _, (raw_paragraph, triple_list) in enumerate(zip(raw_paragraphs.values(), triple_list_data.values())):
|
||||
for _, (raw_paragraph, triple_list) in enumerate(
|
||||
zip(raw_paragraphs.values(), triple_list_data.values(), strict=False)
|
||||
):
|
||||
# 段落hash
|
||||
paragraph_hash = get_sha256(raw_paragraph)
|
||||
if f"{PG_NAMESPACE}-{paragraph_hash}" in stored_pg_hashes and paragraph_hash in stored_paragraph_hashes:
|
||||
if f"{local_storage['pg_namespace']}-{paragraph_hash}" in stored_pg_hashes and paragraph_hash in stored_paragraph_hashes:
|
||||
continue
|
||||
new_raw_paragraphs[paragraph_hash] = raw_paragraph
|
||||
new_triple_list_data[paragraph_hash] = triple_list
|
||||
@@ -174,6 +214,8 @@ def handle_import_openie(openie_data: OpenIE, embed_manager: EmbeddingManager, k
|
||||
|
||||
def main(): # sourcery skip: dict-comprehension
|
||||
# 新增确认提示
|
||||
env_config = {key: os.getenv(key) for key in os.environ}
|
||||
scan_provider(env_config)
|
||||
print("=== 重要操作确认 ===")
|
||||
print("OpenIE导入时会大量发送请求,可能会撞到请求速度上限,请注意选用的模型")
|
||||
print("同之前样例:在本地模型下,在70分钟内我们发送了约8万条请求,在网络允许下,速度会更快")
|
||||
@@ -191,15 +233,9 @@ def main(): # sourcery skip: dict-comprehension
|
||||
logger.info("----开始导入openie数据----\n")
|
||||
|
||||
logger.info("创建LLM客户端")
|
||||
llm_client_list = {}
|
||||
for key in global_config["llm_providers"]:
|
||||
llm_client_list[key] = LLMClient(
|
||||
global_config["llm_providers"][key]["base_url"],
|
||||
global_config["llm_providers"][key]["api_key"],
|
||||
)
|
||||
|
||||
# 初始化Embedding库
|
||||
embed_manager = EmbeddingManager(llm_client_list[global_config["embedding"]["provider"]])
|
||||
embed_manager = EmbeddingManager()
|
||||
logger.info("正在从文件加载Embedding库")
|
||||
try:
|
||||
embed_manager.load_from_file()
|
||||
@@ -228,7 +264,7 @@ def main(): # sourcery skip: dict-comprehension
|
||||
|
||||
# 数据比对:Embedding库与KG的段落hash集合
|
||||
for pg_hash in kg_manager.stored_paragraph_hashes:
|
||||
key = f"{PG_NAMESPACE}-{pg_hash}"
|
||||
key = f"{local_storage['pg_namespace']}-{pg_hash}"
|
||||
if key not in embed_manager.stored_pg_hashes:
|
||||
logger.warning(f"KG中存在Embedding库中不存在的段落:{key}")
|
||||
|
||||
|
||||
@@ -4,7 +4,6 @@ import signal
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from threading import Lock, Event
|
||||
import sys
|
||||
import glob
|
||||
import datetime
|
||||
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
@@ -13,11 +12,9 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
from rich.progress import Progress # 替换为 rich 进度条
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.knowledge.lpmmconfig import global_config
|
||||
# from src.chat.knowledge.lpmmconfig import global_config
|
||||
from src.chat.knowledge.ie_process import info_extract_from_str
|
||||
from src.chat.knowledge.llm_client import LLMClient
|
||||
from src.chat.knowledge.open_ie import OpenIE
|
||||
from src.chat.knowledge.raw_processing import load_raw_data
|
||||
from rich.progress import (
|
||||
BarColumn,
|
||||
TimeElapsedColumn,
|
||||
@@ -27,24 +24,57 @@ from rich.progress import (
|
||||
SpinnerColumn,
|
||||
TextColumn,
|
||||
)
|
||||
from raw_data_preprocessor import RAW_DATA_PATH, load_raw_data
|
||||
from src.config.config import global_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from dotenv import load_dotenv
|
||||
|
||||
logger = get_logger("LPMM知识库-信息提取")
|
||||
|
||||
|
||||
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
TEMP_DIR = os.path.join(ROOT_PATH, "temp")
|
||||
IMPORTED_DATA_PATH = global_config["persistence"]["imported_data_path"] or os.path.join(
|
||||
ROOT_PATH, "data", "imported_lpmm_data"
|
||||
)
|
||||
OPENIE_OUTPUT_DIR = global_config["persistence"]["openie_data_path"] or os.path.join(ROOT_PATH, "data", "openie")
|
||||
# IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, "data", "imported_lpmm_data")
|
||||
OPENIE_OUTPUT_DIR = os.path.join(ROOT_PATH, "data", "openie")
|
||||
ENV_FILE = os.path.join(ROOT_PATH, ".env")
|
||||
|
||||
# 创建一个线程安全的锁,用于保护文件操作和共享数据
|
||||
file_lock = Lock()
|
||||
open_ie_doc_lock = Lock()
|
||||
if os.path.exists(".env"):
|
||||
load_dotenv(".env", override=True)
|
||||
print("成功加载环境变量配置")
|
||||
else:
|
||||
print("未找到.env文件,请确保程序所需的环境变量被正确设置")
|
||||
raise FileNotFoundError(".env 文件不存在,请创建并配置所需的环境变量")
|
||||
|
||||
# 创建一个事件标志,用于控制程序终止
|
||||
shutdown_event = Event()
|
||||
env_mask = {key: os.getenv(key) for key in os.environ}
|
||||
def scan_provider(env_config: dict):
|
||||
provider = {}
|
||||
|
||||
# 利用未初始化 env 时获取的 env_mask 来对新的环境变量集去重
|
||||
# 避免 GPG_KEY 这样的变量干扰检查
|
||||
env_config = dict(filter(lambda item: item[0] not in env_mask, env_config.items()))
|
||||
|
||||
# 遍历 env_config 的所有键
|
||||
for key in env_config:
|
||||
# 检查键是否符合 {provider}_BASE_URL 或 {provider}_KEY 的格式
|
||||
if key.endswith("_BASE_URL") or key.endswith("_KEY"):
|
||||
# 提取 provider 名称
|
||||
provider_name = key.split("_", 1)[0] # 从左分割一次,取第一部分
|
||||
|
||||
# 初始化 provider 的字典(如果尚未初始化)
|
||||
if provider_name not in provider:
|
||||
provider[provider_name] = {"url": None, "key": None}
|
||||
|
||||
# 根据键的类型填充 url 或 key
|
||||
if key.endswith("_BASE_URL"):
|
||||
provider[provider_name]["url"] = env_config[key]
|
||||
elif key.endswith("_KEY"):
|
||||
provider[provider_name]["key"] = env_config[key]
|
||||
|
||||
# 检查每个 provider 是否同时存在 url 和 key
|
||||
for provider_name, config in provider.items():
|
||||
if config["url"] is None or config["key"] is None:
|
||||
logger.error(f"provider 内容:{config}\nenv_config 内容:{env_config}")
|
||||
raise ValueError(f"请检查 '{provider_name}' 提供商配置是否丢失 BASE_URL 或 KEY 环境变量")
|
||||
|
||||
def ensure_dirs():
|
||||
"""确保临时目录和输出目录存在"""
|
||||
@@ -54,12 +84,26 @@ def ensure_dirs():
|
||||
if not os.path.exists(OPENIE_OUTPUT_DIR):
|
||||
os.makedirs(OPENIE_OUTPUT_DIR)
|
||||
logger.info(f"已创建输出目录: {OPENIE_OUTPUT_DIR}")
|
||||
if not os.path.exists(IMPORTED_DATA_PATH):
|
||||
os.makedirs(IMPORTED_DATA_PATH)
|
||||
logger.info(f"已创建导入数据目录: {IMPORTED_DATA_PATH}")
|
||||
if not os.path.exists(RAW_DATA_PATH):
|
||||
os.makedirs(RAW_DATA_PATH)
|
||||
logger.info(f"已创建原始数据目录: {RAW_DATA_PATH}")
|
||||
|
||||
# 创建一个线程安全的锁,用于保护文件操作和共享数据
|
||||
file_lock = Lock()
|
||||
open_ie_doc_lock = Lock()
|
||||
|
||||
def process_single_text(pg_hash, raw_data, llm_client_list):
|
||||
# 创建一个事件标志,用于控制程序终止
|
||||
shutdown_event = Event()
|
||||
|
||||
lpmm_entity_extract_llm = LLMRequest(
|
||||
model=global_config.model.lpmm_entity_extract,
|
||||
request_type="lpmm.entity_extract"
|
||||
)
|
||||
lpmm_rdf_build_llm = LLMRequest(
|
||||
model=global_config.model.lpmm_rdf_build,
|
||||
request_type="lpmm.rdf_build"
|
||||
)
|
||||
def process_single_text(pg_hash, raw_data):
|
||||
"""处理单个文本的函数,用于线程池"""
|
||||
temp_file_path = f"{TEMP_DIR}/{pg_hash}.json"
|
||||
|
||||
@@ -77,8 +121,8 @@ def process_single_text(pg_hash, raw_data, llm_client_list):
|
||||
os.remove(temp_file_path)
|
||||
|
||||
entity_list, rdf_triple_list = info_extract_from_str(
|
||||
llm_client_list[global_config["entity_extract"]["llm"]["provider"]],
|
||||
llm_client_list[global_config["rdf_build"]["llm"]["provider"]],
|
||||
lpmm_entity_extract_llm,
|
||||
lpmm_rdf_build_llm,
|
||||
raw_data,
|
||||
)
|
||||
if entity_list is None or rdf_triple_list is None:
|
||||
@@ -113,7 +157,9 @@ def signal_handler(_signum, _frame):
|
||||
def main(): # sourcery skip: comprehension-to-generator, extract-method
|
||||
# 设置信号处理器
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
|
||||
ensure_dirs() # 确保目录存在
|
||||
env_config = {key: os.getenv(key) for key in os.environ}
|
||||
scan_provider(env_config)
|
||||
# 新增用户确认提示
|
||||
print("=== 重要操作确认,请认真阅读以下内容哦 ===")
|
||||
print("实体提取操作将会花费较多api余额和时间,建议在空闲时段执行。")
|
||||
@@ -130,51 +176,18 @@ def main(): # sourcery skip: comprehension-to-generator, extract-method
|
||||
ensure_dirs() # 确保目录存在
|
||||
logger.info("--------进行信息提取--------\n")
|
||||
|
||||
logger.info("创建LLM客户端")
|
||||
llm_client_list = {
|
||||
key: LLMClient(
|
||||
global_config["llm_providers"][key]["base_url"],
|
||||
global_config["llm_providers"][key]["api_key"],
|
||||
)
|
||||
for key in global_config["llm_providers"]
|
||||
}
|
||||
# 检查 openie 输出目录
|
||||
if not os.path.exists(OPENIE_OUTPUT_DIR):
|
||||
os.makedirs(OPENIE_OUTPUT_DIR)
|
||||
logger.info(f"已创建输出目录: {OPENIE_OUTPUT_DIR}")
|
||||
|
||||
# 确保 TEMP_DIR 目录存在
|
||||
if not os.path.exists(TEMP_DIR):
|
||||
os.makedirs(TEMP_DIR)
|
||||
logger.info(f"已创建缓存目录: {TEMP_DIR}")
|
||||
|
||||
# 遍历IMPORTED_DATA_PATH下所有json文件
|
||||
imported_files = sorted(glob.glob(os.path.join(IMPORTED_DATA_PATH, "*.json")))
|
||||
if not imported_files:
|
||||
logger.error(f"未在 {IMPORTED_DATA_PATH} 下找到任何json文件")
|
||||
sys.exit(1)
|
||||
|
||||
all_sha256_list = []
|
||||
all_raw_datas = []
|
||||
|
||||
for imported_file in imported_files:
|
||||
logger.info(f"正在处理文件: {imported_file}")
|
||||
try:
|
||||
sha256_list, raw_datas = load_raw_data(imported_file)
|
||||
except Exception as e:
|
||||
logger.error(f"读取文件失败: {imported_file}, 错误: {e}")
|
||||
continue
|
||||
all_sha256_list.extend(sha256_list)
|
||||
all_raw_datas.extend(raw_datas)
|
||||
# 加载原始数据
|
||||
logger.info("正在加载原始数据")
|
||||
all_sha256_list, all_raw_datas = load_raw_data()
|
||||
|
||||
failed_sha256 = []
|
||||
open_ie_doc = []
|
||||
|
||||
workers = global_config["info_extraction"]["workers"]
|
||||
workers = global_config.lpmm_knowledge.info_extraction_workers
|
||||
with ThreadPoolExecutor(max_workers=workers) as executor:
|
||||
future_to_hash = {
|
||||
executor.submit(process_single_text, pg_hash, raw_data, llm_client_list): pg_hash
|
||||
for pg_hash, raw_data in zip(all_sha256_list, all_raw_datas)
|
||||
executor.submit(process_single_text, pg_hash, raw_data): pg_hash
|
||||
for pg_hash, raw_data in zip(all_sha256_list, all_raw_datas, strict=False)
|
||||
}
|
||||
|
||||
with Progress(
|
||||
|
||||
@@ -354,7 +354,7 @@ class VirtualLogDisplay:
|
||||
|
||||
# 为每个部分应用正确的标签
|
||||
current_len = 0
|
||||
for part, tag_name in zip(parts, tags):
|
||||
for part, tag_name in zip(parts, tags, strict=False):
|
||||
start_index = f"{start_pos}+{current_len}c"
|
||||
end_index = f"{start_pos}+{current_len + len(part)}c"
|
||||
self.text_widget.tag_add(tag_name, start_index, end_index)
|
||||
|
||||
@@ -205,7 +205,6 @@ class MongoToSQLiteMigrator:
|
||||
"user_info.user_nickname": "user_nickname",
|
||||
"user_info.user_cardname": "user_cardname",
|
||||
"processed_plain_text": "processed_plain_text",
|
||||
"detailed_plain_text": "detailed_plain_text",
|
||||
"memorized_times": "memorized_times",
|
||||
},
|
||||
enable_validation=False, # 禁用数据验证
|
||||
|
||||
@@ -1,40 +1,16 @@
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
import sys # 新增系统模块导入
|
||||
import datetime # 新增导入
|
||||
|
||||
from src.chat.knowledge.utils.hash import get_sha256
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.knowledge.lpmmconfig import global_config
|
||||
|
||||
logger = get_logger("lpmm")
|
||||
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
RAW_DATA_PATH = os.path.join(ROOT_PATH, "data/lpmm_raw_data")
|
||||
# 新增:确保 RAW_DATA_PATH 存在
|
||||
if not os.path.exists(RAW_DATA_PATH):
|
||||
os.makedirs(RAW_DATA_PATH, exist_ok=True)
|
||||
logger.info(f"已创建目录: {RAW_DATA_PATH}")
|
||||
# IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, "data/imported_lpmm_data")
|
||||
|
||||
if global_config.get("persistence", {}).get("raw_data_path") is not None:
|
||||
IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, global_config["persistence"]["raw_data_path"])
|
||||
else:
|
||||
IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, "data/imported_lpmm_data")
|
||||
|
||||
# 添加项目根目录到 sys.path
|
||||
|
||||
|
||||
def check_and_create_dirs():
|
||||
"""检查并创建必要的目录"""
|
||||
required_dirs = [RAW_DATA_PATH, IMPORTED_DATA_PATH]
|
||||
|
||||
for dir_path in required_dirs:
|
||||
if not os.path.exists(dir_path):
|
||||
os.makedirs(dir_path)
|
||||
logger.info(f"已创建目录: {dir_path}")
|
||||
|
||||
|
||||
def process_text_file(file_path):
|
||||
def _process_text_file(file_path):
|
||||
"""处理单个文本文件,返回段落列表"""
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
raw = f.read()
|
||||
@@ -55,54 +31,45 @@ def process_text_file(file_path):
|
||||
return paragraphs
|
||||
|
||||
|
||||
def main():
|
||||
# 新增用户确认提示
|
||||
print("=== 数据预处理脚本 ===")
|
||||
print(f"本脚本将处理 '{RAW_DATA_PATH}' 目录下的所有 .txt 文件。")
|
||||
print(f"处理后的段落数据将合并,并以 MM-DD-HH-SS-imported-data.json 的格式保存在 '{IMPORTED_DATA_PATH}' 目录中。")
|
||||
print("请确保原始数据已放置在正确的目录中。")
|
||||
confirm = input("确认继续执行?(y/n): ").strip().lower()
|
||||
if confirm != "y":
|
||||
logger.info("操作已取消")
|
||||
sys.exit(1)
|
||||
print("\n" + "=" * 40 + "\n")
|
||||
|
||||
# 检查并创建必要的目录
|
||||
check_and_create_dirs()
|
||||
|
||||
# # 检查输出文件是否存在
|
||||
# if os.path.exists(RAW_DATA_PATH):
|
||||
# logger.error("错误: data/import.json 已存在,请先处理或删除该文件")
|
||||
# sys.exit(1)
|
||||
|
||||
# if os.path.exists(RAW_DATA_PATH):
|
||||
# logger.error("错误: data/openie.json 已存在,请先处理或删除该文件")
|
||||
# sys.exit(1)
|
||||
|
||||
# 获取所有原始文本文件
|
||||
def _process_multi_files() -> list:
|
||||
raw_files = list(Path(RAW_DATA_PATH).glob("*.txt"))
|
||||
if not raw_files:
|
||||
logger.warning("警告: data/lpmm_raw_data 中没有找到任何 .txt 文件")
|
||||
sys.exit(1)
|
||||
|
||||
# 处理所有文件
|
||||
all_paragraphs = []
|
||||
for file in raw_files:
|
||||
logger.info(f"正在处理文件: {file.name}")
|
||||
paragraphs = process_text_file(file)
|
||||
paragraphs = _process_text_file(file)
|
||||
all_paragraphs.extend(paragraphs)
|
||||
return all_paragraphs
|
||||
|
||||
# 保存合并后的结果到 IMPORTED_DATA_PATH,文件名格式为 MM-DD-HH-ss-imported-data.json
|
||||
now = datetime.datetime.now()
|
||||
filename = now.strftime("%m-%d-%H-%S-imported-data.json")
|
||||
output_path = os.path.join(IMPORTED_DATA_PATH, filename)
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
json.dump(all_paragraphs, f, ensure_ascii=False, indent=4)
|
||||
def load_raw_data() -> tuple[list[str], list[str]]:
|
||||
"""加载原始数据文件
|
||||
|
||||
logger.info(f"处理完成,结果已保存到: {output_path}")
|
||||
读取原始数据文件,将原始数据加载到内存中
|
||||
|
||||
Args:
|
||||
path: 可选,指定要读取的json文件绝对路径
|
||||
|
||||
if __name__ == "__main__":
|
||||
logger.info(f"原始数据路径: {RAW_DATA_PATH}")
|
||||
logger.info(f"处理后的数据路径: {IMPORTED_DATA_PATH}")
|
||||
main()
|
||||
Returns:
|
||||
- raw_data: 原始数据列表
|
||||
- sha256_list: 原始数据的SHA256集合
|
||||
"""
|
||||
raw_data = _process_multi_files()
|
||||
sha256_list = []
|
||||
sha256_set = set()
|
||||
for item in raw_data:
|
||||
if not isinstance(item, str):
|
||||
logger.warning(f"数据类型错误:{item}")
|
||||
continue
|
||||
pg_hash = get_sha256(item)
|
||||
if pg_hash in sha256_set:
|
||||
logger.warning(f"重复数据:{item}")
|
||||
continue
|
||||
sha256_set.add(pg_hash)
|
||||
sha256_list.append(pg_hash)
|
||||
raw_data.append(item)
|
||||
logger.info(f"共读取到{len(raw_data)}条数据")
|
||||
|
||||
return sha256_list, raw_data
|
||||
@@ -1,26 +0,0 @@
|
||||
from src.chat.heart_flow.heartflow import heartflow
|
||||
from src.chat.heart_flow.sub_heartflow import ChatState
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("api")
|
||||
|
||||
|
||||
async def get_all_subheartflow_ids() -> list:
|
||||
"""获取所有子心流的ID列表"""
|
||||
all_subheartflows = heartflow.subheartflow_manager.get_all_subheartflows()
|
||||
return [subheartflow.subheartflow_id for subheartflow in all_subheartflows]
|
||||
|
||||
|
||||
async def forced_change_subheartflow_status(subheartflow_id: str, status: ChatState) -> bool:
|
||||
"""强制改变子心流的状态"""
|
||||
subheartflow = await heartflow.get_or_create_subheartflow(subheartflow_id)
|
||||
if subheartflow:
|
||||
return await heartflow.force_change_subheartflow_status(subheartflow_id, status)
|
||||
return False
|
||||
|
||||
|
||||
async def get_all_states():
|
||||
"""获取所有状态"""
|
||||
all_states = await heartflow.api_get_all_states()
|
||||
logger.debug(f"所有状态: {all_states}")
|
||||
return all_states
|
||||
@@ -1,169 +0,0 @@
|
||||
import platform
|
||||
import psutil
|
||||
import sys
|
||||
import os
|
||||
|
||||
|
||||
def get_system_info():
|
||||
"""获取操作系统信息"""
|
||||
return {
|
||||
"system": platform.system(),
|
||||
"release": platform.release(),
|
||||
"version": platform.version(),
|
||||
"machine": platform.machine(),
|
||||
"processor": platform.processor(),
|
||||
}
|
||||
|
||||
|
||||
def get_python_version():
|
||||
"""获取 Python 版本信息"""
|
||||
return sys.version
|
||||
|
||||
|
||||
def get_cpu_usage():
|
||||
"""获取系统总CPU使用率"""
|
||||
return psutil.cpu_percent(interval=1)
|
||||
|
||||
|
||||
def get_process_cpu_usage():
|
||||
"""获取当前进程CPU使用率"""
|
||||
process = psutil.Process(os.getpid())
|
||||
return process.cpu_percent(interval=1)
|
||||
|
||||
|
||||
def get_memory_usage():
|
||||
"""获取系统内存使用情况 (单位 MB)"""
|
||||
mem = psutil.virtual_memory()
|
||||
bytes_to_mb = lambda x: round(x / (1024 * 1024), 2) # noqa
|
||||
return {
|
||||
"total_mb": bytes_to_mb(mem.total),
|
||||
"available_mb": bytes_to_mb(mem.available),
|
||||
"percent": mem.percent,
|
||||
"used_mb": bytes_to_mb(mem.used),
|
||||
"free_mb": bytes_to_mb(mem.free),
|
||||
}
|
||||
|
||||
|
||||
def get_process_memory_usage():
|
||||
"""获取当前进程内存使用情况 (单位 MB)"""
|
||||
process = psutil.Process(os.getpid())
|
||||
mem_info = process.memory_info()
|
||||
bytes_to_mb = lambda x: round(x / (1024 * 1024), 2) # noqa
|
||||
return {
|
||||
"rss_mb": bytes_to_mb(mem_info.rss), # Resident Set Size: 实际使用物理内存
|
||||
"vms_mb": bytes_to_mb(mem_info.vms), # Virtual Memory Size: 虚拟内存大小
|
||||
"percent": process.memory_percent(), # 进程内存使用百分比
|
||||
}
|
||||
|
||||
|
||||
def get_disk_usage(path="/"):
|
||||
"""获取指定路径磁盘使用情况 (单位 GB)"""
|
||||
disk = psutil.disk_usage(path)
|
||||
bytes_to_gb = lambda x: round(x / (1024 * 1024 * 1024), 2) # noqa
|
||||
return {
|
||||
"total_gb": bytes_to_gb(disk.total),
|
||||
"used_gb": bytes_to_gb(disk.used),
|
||||
"free_gb": bytes_to_gb(disk.free),
|
||||
"percent": disk.percent,
|
||||
}
|
||||
|
||||
|
||||
def get_all_basic_info():
|
||||
"""获取所有基本信息并封装返回"""
|
||||
# 对于进程CPU使用率,需要先初始化
|
||||
process = psutil.Process(os.getpid())
|
||||
process.cpu_percent(interval=None) # 初始化调用
|
||||
process_cpu = process.cpu_percent(interval=0.1) # 短暂间隔获取
|
||||
|
||||
return {
|
||||
"system_info": get_system_info(),
|
||||
"python_version": get_python_version(),
|
||||
"cpu_usage_percent": get_cpu_usage(),
|
||||
"process_cpu_usage_percent": process_cpu,
|
||||
"memory_usage": get_memory_usage(),
|
||||
"process_memory_usage": get_process_memory_usage(),
|
||||
"disk_usage_root": get_disk_usage("/"),
|
||||
}
|
||||
|
||||
|
||||
def get_all_basic_info_string() -> str:
|
||||
"""获取所有基本信息并以带解释的字符串形式返回"""
|
||||
info = get_all_basic_info()
|
||||
|
||||
sys_info = info["system_info"]
|
||||
mem_usage = info["memory_usage"]
|
||||
proc_mem_usage = info["process_memory_usage"]
|
||||
disk_usage = info["disk_usage_root"]
|
||||
|
||||
# 对进程内存使用百分比进行格式化,保留两位小数
|
||||
proc_mem_percent = round(proc_mem_usage["percent"], 2)
|
||||
|
||||
output_string = f"""[系统信息]
|
||||
- 操作系统: {sys_info["system"]} (例如: Windows, Linux)
|
||||
- 发行版本: {sys_info["release"]} (例如: 11, Ubuntu 20.04)
|
||||
- 详细版本: {sys_info["version"]}
|
||||
- 硬件架构: {sys_info["machine"]} (例如: AMD64)
|
||||
- 处理器信息: {sys_info["processor"]}
|
||||
|
||||
[Python 环境]
|
||||
- Python 版本: {info["python_version"]}
|
||||
|
||||
[CPU 状态]
|
||||
- 系统总 CPU 使用率: {info["cpu_usage_percent"]}%
|
||||
- 当前进程 CPU 使用率: {info["process_cpu_usage_percent"]}%
|
||||
|
||||
[系统内存使用情况]
|
||||
- 总物理内存: {mem_usage["total_mb"]} MB
|
||||
- 可用物理内存: {mem_usage["available_mb"]} MB
|
||||
- 物理内存使用率: {mem_usage["percent"]}%
|
||||
- 已用物理内存: {mem_usage["used_mb"]} MB
|
||||
- 空闲物理内存: {mem_usage["free_mb"]} MB
|
||||
|
||||
[当前进程内存使用情况]
|
||||
- 实际使用物理内存 (RSS): {proc_mem_usage["rss_mb"]} MB
|
||||
- 占用虚拟内存 (VMS): {proc_mem_usage["vms_mb"]} MB
|
||||
- 进程内存使用率: {proc_mem_percent}%
|
||||
|
||||
[磁盘使用情况 (根目录)]
|
||||
- 总空间: {disk_usage["total_gb"]} GB
|
||||
- 已用空间: {disk_usage["used_gb"]} GB
|
||||
- 可用空间: {disk_usage["free_gb"]} GB
|
||||
- 磁盘使用率: {disk_usage["percent"]}%
|
||||
"""
|
||||
return output_string
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print(f"System Info: {get_system_info()}")
|
||||
print(f"Python Version: {get_python_version()}")
|
||||
print(f"CPU Usage: {get_cpu_usage()}%")
|
||||
# 第一次调用 process.cpu_percent() 会返回0.0或一个无意义的值,需要间隔一段时间再调用
|
||||
# 或者在初始化Process对象后,先调用一次cpu_percent(interval=None),然后再调用cpu_percent(interval=1)
|
||||
current_process = psutil.Process(os.getpid())
|
||||
current_process.cpu_percent(interval=None) # 初始化
|
||||
print(f"Process CPU Usage: {current_process.cpu_percent(interval=1)}%") # 实际获取
|
||||
|
||||
memory_usage_info = get_memory_usage()
|
||||
print(
|
||||
f"Memory Usage: Total={memory_usage_info['total_mb']}MB, Used={memory_usage_info['used_mb']}MB, Percent={memory_usage_info['percent']}%"
|
||||
)
|
||||
|
||||
process_memory_info = get_process_memory_usage()
|
||||
print(
|
||||
f"Process Memory Usage: RSS={process_memory_info['rss_mb']}MB, VMS={process_memory_info['vms_mb']}MB, Percent={process_memory_info['percent']}%"
|
||||
)
|
||||
|
||||
disk_usage_info = get_disk_usage("/")
|
||||
print(
|
||||
f"Disk Usage (Root): Total={disk_usage_info['total_gb']}GB, Used={disk_usage_info['used_gb']}GB, Percent={disk_usage_info['percent']}%"
|
||||
)
|
||||
|
||||
print("\n--- All Basic Info (JSON) ---")
|
||||
all_info = get_all_basic_info()
|
||||
import json
|
||||
|
||||
print(json.dumps(all_info, indent=4, ensure_ascii=False))
|
||||
|
||||
print("\n--- All Basic Info (String with Explanations) ---")
|
||||
info_string = get_all_basic_info_string()
|
||||
print(info_string)
|
||||
@@ -1,317 +0,0 @@
|
||||
from typing import List, Optional, Dict, Any
|
||||
import strawberry
|
||||
|
||||
# from packaging.version import Version
|
||||
import os
|
||||
|
||||
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
|
||||
|
||||
|
||||
@strawberry.type
|
||||
class APIBotConfig:
|
||||
"""机器人配置类"""
|
||||
|
||||
INNER_VERSION: str # 配置文件内部版本号(toml为字符串)
|
||||
MAI_VERSION: str # 硬编码的版本信息
|
||||
|
||||
# bot
|
||||
BOT_QQ: Optional[int] # 机器人QQ号
|
||||
BOT_NICKNAME: Optional[str] # 机器人昵称
|
||||
BOT_ALIAS_NAMES: List[str] # 机器人别名列表
|
||||
|
||||
# group
|
||||
talk_allowed_groups: List[int] # 允许回复消息的群号列表
|
||||
talk_frequency_down_groups: List[int] # 降低回复频率的群号列表
|
||||
ban_user_id: List[int] # 禁止回复和读取消息的QQ号列表
|
||||
|
||||
# personality
|
||||
personality_core: str # 人格核心特点描述
|
||||
personality_sides: List[str] # 人格细节描述列表
|
||||
|
||||
# identity
|
||||
identity_detail: List[str] # 身份特点列表
|
||||
age: int # 年龄(岁)
|
||||
gender: str # 性别
|
||||
appearance: str # 外貌特征描述
|
||||
|
||||
# platforms
|
||||
platforms: Dict[str, str] # 平台信息
|
||||
|
||||
# chat
|
||||
allow_focus_mode: bool # 是否允许专注聊天状态
|
||||
base_normal_chat_num: int # 最多允许多少个群进行普通聊天
|
||||
base_focused_chat_num: int # 最多允许多少个群进行专注聊天
|
||||
observation_context_size: int # 观察到的最长上下文大小
|
||||
message_buffer: bool # 是否启用消息缓冲
|
||||
ban_words: List[str] # 禁止词列表
|
||||
ban_msgs_regex: List[str] # 禁止消息的正则表达式列表
|
||||
|
||||
# normal_chat
|
||||
model_reasoning_probability: float # 推理模型概率
|
||||
model_normal_probability: float # 普通模型概率
|
||||
emoji_chance: float # 表情符号出现概率
|
||||
thinking_timeout: int # 思考超时时间
|
||||
willing_mode: str # 意愿模式
|
||||
response_interested_rate_amplifier: float # 回复兴趣率放大器
|
||||
emoji_response_penalty: float # 表情回复惩罚
|
||||
mentioned_bot_inevitable_reply: bool # 提及 bot 必然回复
|
||||
at_bot_inevitable_reply: bool # @bot 必然回复
|
||||
|
||||
# focus_chat
|
||||
reply_trigger_threshold: float # 回复触发阈值
|
||||
default_decay_rate_per_second: float # 默认每秒衰减率
|
||||
|
||||
# compressed
|
||||
compressed_length: int # 压缩长度
|
||||
compress_length_limit: int # 压缩长度限制
|
||||
|
||||
# emoji
|
||||
max_emoji_num: int # 最大表情符号数量
|
||||
max_reach_deletion: bool # 达到最大数量时是否删除
|
||||
check_interval: int # 检查表情包的时间间隔(分钟)
|
||||
save_emoji: bool # 是否保存表情包
|
||||
steal_emoji: bool # 是否偷取表情包
|
||||
enable_check: bool # 是否启用表情包过滤
|
||||
check_prompt: str # 表情包过滤要求
|
||||
|
||||
# memory
|
||||
build_memory_interval: int # 记忆构建间隔
|
||||
build_memory_distribution: List[float] # 记忆构建分布
|
||||
build_memory_sample_num: int # 采样数量
|
||||
build_memory_sample_length: int # 采样长度
|
||||
memory_compress_rate: float # 记忆压缩率
|
||||
forget_memory_interval: int # 记忆遗忘间隔
|
||||
memory_forget_time: int # 记忆遗忘时间(小时)
|
||||
memory_forget_percentage: float # 记忆遗忘比例
|
||||
consolidate_memory_interval: int # 记忆整合间隔
|
||||
consolidation_similarity_threshold: float # 相似度阈值
|
||||
consolidation_check_percentage: float # 检查节点比例
|
||||
memory_ban_words: List[str] # 记忆禁止词列表
|
||||
|
||||
# mood
|
||||
mood_update_interval: float # 情绪更新间隔
|
||||
mood_decay_rate: float # 情绪衰减率
|
||||
mood_intensity_factor: float # 情绪强度因子
|
||||
|
||||
# keywords_reaction
|
||||
keywords_reaction_enable: bool # 是否启用关键词反应
|
||||
keywords_reaction_rules: List[Dict[str, Any]] # 关键词反应规则
|
||||
|
||||
# chinese_typo
|
||||
chinese_typo_enable: bool # 是否启用中文错别字
|
||||
chinese_typo_error_rate: float # 中文错别字错误率
|
||||
chinese_typo_min_freq: int # 中文错别字最小频率
|
||||
chinese_typo_tone_error_rate: float # 中文错别字声调错误率
|
||||
chinese_typo_word_replace_rate: float # 中文错别字单词替换率
|
||||
|
||||
# response_splitter
|
||||
enable_response_splitter: bool # 是否启用回复分割器
|
||||
response_max_length: int # 回复最大长度
|
||||
response_max_sentence_num: int # 回复最大句子数
|
||||
enable_kaomoji_protection: bool # 是否启用颜文字保护
|
||||
|
||||
model_max_output_length: int # 模型最大输出长度
|
||||
|
||||
# remote
|
||||
remote_enable: bool # 是否启用远程功能
|
||||
|
||||
# experimental
|
||||
enable_friend_chat: bool # 是否启用好友聊天
|
||||
talk_allowed_private: List[int] # 允许私聊的QQ号列表
|
||||
pfc_chatting: bool # 是否启用PFC聊天
|
||||
|
||||
# 模型配置
|
||||
llm_reasoning: Dict[str, Any] # 推理模型配置
|
||||
llm_normal: Dict[str, Any] # 普通模型配置
|
||||
llm_topic_judge: Dict[str, Any] # 主题判断模型配置
|
||||
summary: Dict[str, Any] # 总结模型配置
|
||||
vlm: Dict[str, Any] # VLM模型配置
|
||||
llm_heartflow: Dict[str, Any] # 心流模型配置
|
||||
llm_observation: Dict[str, Any] # 观察模型配置
|
||||
llm_sub_heartflow: Dict[str, Any] # 子心流模型配置
|
||||
llm_plan: Optional[Dict[str, Any]] # 计划模型配置
|
||||
embedding: Dict[str, Any] # 嵌入模型配置
|
||||
llm_PFC_action_planner: Optional[Dict[str, Any]] # PFC行动计划模型配置
|
||||
llm_PFC_chat: Optional[Dict[str, Any]] # PFC聊天模型配置
|
||||
llm_PFC_reply_checker: Optional[Dict[str, Any]] # PFC回复检查模型配置
|
||||
llm_tool_use: Optional[Dict[str, Any]] # 工具使用模型配置
|
||||
|
||||
api_urls: Optional[Dict[str, str]] # API地址配置
|
||||
|
||||
@staticmethod
|
||||
def validate_config(config: dict):
|
||||
"""
|
||||
校验传入的 toml 配置字典是否合法。
|
||||
:param config: toml库load后的配置字典
|
||||
:raises: ValueError, KeyError, TypeError
|
||||
"""
|
||||
# 检查主层级
|
||||
required_sections = [
|
||||
"inner",
|
||||
"bot",
|
||||
"groups",
|
||||
"personality",
|
||||
"identity",
|
||||
"platforms",
|
||||
"chat",
|
||||
"normal_chat",
|
||||
"focus_chat",
|
||||
"emoji",
|
||||
"memory",
|
||||
"mood",
|
||||
"keywords_reaction",
|
||||
"chinese_typo",
|
||||
"response_splitter",
|
||||
"remote",
|
||||
"experimental",
|
||||
"model",
|
||||
]
|
||||
for section in required_sections:
|
||||
if section not in config:
|
||||
raise KeyError(f"缺少配置段: [{section}]")
|
||||
|
||||
# 检查部分关键字段
|
||||
if "version" not in config["inner"]:
|
||||
raise KeyError("缺少 inner.version 字段")
|
||||
if not isinstance(config["inner"]["version"], str):
|
||||
raise TypeError("inner.version 必须为字符串")
|
||||
|
||||
if "qq" not in config["bot"]:
|
||||
raise KeyError("缺少 bot.qq 字段")
|
||||
if not isinstance(config["bot"]["qq"], int):
|
||||
raise TypeError("bot.qq 必须为整数")
|
||||
|
||||
if "personality_core" not in config["personality"]:
|
||||
raise KeyError("缺少 personality.personality_core 字段")
|
||||
if not isinstance(config["personality"]["personality_core"], str):
|
||||
raise TypeError("personality.personality_core 必须为字符串")
|
||||
|
||||
if "identity_detail" not in config["identity"]:
|
||||
raise KeyError("缺少 identity.identity_detail 字段")
|
||||
if not isinstance(config["identity"]["identity_detail"], list):
|
||||
raise TypeError("identity.identity_detail 必须为列表")
|
||||
|
||||
# 可继续添加更多字段的类型和值检查
|
||||
# ...
|
||||
|
||||
# 检查模型配置
|
||||
model_keys = [
|
||||
"llm_reasoning",
|
||||
"llm_normal",
|
||||
"llm_topic_judge",
|
||||
"summary",
|
||||
"vlm",
|
||||
"llm_heartflow",
|
||||
"llm_observation",
|
||||
"llm_sub_heartflow",
|
||||
"embedding",
|
||||
]
|
||||
if "model" not in config:
|
||||
raise KeyError("缺少 [model] 配置段")
|
||||
for key in model_keys:
|
||||
if key not in config["model"]:
|
||||
raise KeyError(f"缺少 model.{key} 配置")
|
||||
|
||||
# 检查通过
|
||||
return True
|
||||
|
||||
|
||||
@strawberry.type
|
||||
class APIEnvConfig:
|
||||
"""环境变量配置"""
|
||||
|
||||
HOST: str # 服务主机地址
|
||||
PORT: int # 服务端口
|
||||
|
||||
PLUGINS: List[str] # 插件列表
|
||||
|
||||
MONGODB_HOST: str # MongoDB 主机地址
|
||||
MONGODB_PORT: int # MongoDB 端口
|
||||
DATABASE_NAME: str # 数据库名称
|
||||
|
||||
CHAT_ANY_WHERE_BASE_URL: str # ChatAnywhere 基础URL
|
||||
SILICONFLOW_BASE_URL: str # SiliconFlow 基础URL
|
||||
DEEP_SEEK_BASE_URL: str # DeepSeek 基础URL
|
||||
|
||||
DEEP_SEEK_KEY: Optional[str] # DeepSeek API Key
|
||||
CHAT_ANY_WHERE_KEY: Optional[str] # ChatAnywhere API Key
|
||||
SILICONFLOW_KEY: Optional[str] # SiliconFlow API Key
|
||||
|
||||
SIMPLE_OUTPUT: Optional[bool] # 是否简化输出
|
||||
CONSOLE_LOG_LEVEL: Optional[str] # 控制台日志等级
|
||||
FILE_LOG_LEVEL: Optional[str] # 文件日志等级
|
||||
DEFAULT_CONSOLE_LOG_LEVEL: Optional[str] # 默认控制台日志等级
|
||||
DEFAULT_FILE_LOG_LEVEL: Optional[str] # 默认文件日志等级
|
||||
|
||||
@strawberry.field
|
||||
def get_env(self) -> str:
|
||||
return "env"
|
||||
|
||||
@staticmethod
|
||||
def validate_config(config: dict):
|
||||
"""
|
||||
校验环境变量配置字典是否合法。
|
||||
:param config: 环境变量配置字典
|
||||
:raises: KeyError, TypeError
|
||||
"""
|
||||
required_fields = [
|
||||
"HOST",
|
||||
"PORT",
|
||||
"PLUGINS",
|
||||
"MONGODB_HOST",
|
||||
"MONGODB_PORT",
|
||||
"DATABASE_NAME",
|
||||
"CHAT_ANY_WHERE_BASE_URL",
|
||||
"SILICONFLOW_BASE_URL",
|
||||
"DEEP_SEEK_BASE_URL",
|
||||
]
|
||||
for field in required_fields:
|
||||
if field not in config:
|
||||
raise KeyError(f"缺少环境变量配置字段: {field}")
|
||||
|
||||
if not isinstance(config["HOST"], str):
|
||||
raise TypeError("HOST 必须为字符串")
|
||||
if not isinstance(config["PORT"], int):
|
||||
raise TypeError("PORT 必须为整数")
|
||||
if not isinstance(config["PLUGINS"], list):
|
||||
raise TypeError("PLUGINS 必须为列表")
|
||||
if not isinstance(config["MONGODB_HOST"], str):
|
||||
raise TypeError("MONGODB_HOST 必须为字符串")
|
||||
if not isinstance(config["MONGODB_PORT"], int):
|
||||
raise TypeError("MONGODB_PORT 必须为整数")
|
||||
if not isinstance(config["DATABASE_NAME"], str):
|
||||
raise TypeError("DATABASE_NAME 必须为字符串")
|
||||
if not isinstance(config["CHAT_ANY_WHERE_BASE_URL"], str):
|
||||
raise TypeError("CHAT_ANY_WHERE_BASE_URL 必须为字符串")
|
||||
if not isinstance(config["SILICONFLOW_BASE_URL"], str):
|
||||
raise TypeError("SILICONFLOW_BASE_URL 必须为字符串")
|
||||
if not isinstance(config["DEEP_SEEK_BASE_URL"], str):
|
||||
raise TypeError("DEEP_SEEK_BASE_URL 必须为字符串")
|
||||
|
||||
# 可选字段类型检查
|
||||
optional_str_fields = [
|
||||
"DEEP_SEEK_KEY",
|
||||
"CHAT_ANY_WHERE_KEY",
|
||||
"SILICONFLOW_KEY",
|
||||
"CONSOLE_LOG_LEVEL",
|
||||
"FILE_LOG_LEVEL",
|
||||
"DEFAULT_CONSOLE_LOG_LEVEL",
|
||||
"DEFAULT_FILE_LOG_LEVEL",
|
||||
]
|
||||
for field in optional_str_fields:
|
||||
if field in config and config[field] is not None and not isinstance(config[field], str):
|
||||
raise TypeError(f"{field} 必须为字符串或None")
|
||||
|
||||
if (
|
||||
"SIMPLE_OUTPUT" in config
|
||||
and config["SIMPLE_OUTPUT"] is not None
|
||||
and not isinstance(config["SIMPLE_OUTPUT"], bool)
|
||||
):
|
||||
raise TypeError("SIMPLE_OUTPUT 必须为布尔值或None")
|
||||
|
||||
# 检查通过
|
||||
return True
|
||||
|
||||
|
||||
print("当前路径:")
|
||||
print(ROOT_PATH)
|
||||
@@ -1,22 +0,0 @@
|
||||
import strawberry
|
||||
|
||||
from fastapi import FastAPI
|
||||
from strawberry.fastapi import GraphQLRouter
|
||||
|
||||
from src.common.server import get_global_server
|
||||
|
||||
|
||||
@strawberry.type
|
||||
class Query:
|
||||
@strawberry.field
|
||||
def hello(self) -> str:
|
||||
return "Hello World"
|
||||
|
||||
|
||||
schema = strawberry.Schema(Query)
|
||||
|
||||
graphql_app = GraphQLRouter(schema)
|
||||
|
||||
fast_api_app: FastAPI = get_global_server().get_app()
|
||||
|
||||
fast_api_app.include_router(graphql_app, prefix="/graphql")
|
||||
@@ -1 +0,0 @@
|
||||
pass
|
||||
112
src/api/main.py
112
src/api/main.py
@@ -1,112 +0,0 @@
|
||||
from fastapi import APIRouter
|
||||
from strawberry.fastapi import GraphQLRouter
|
||||
import os
|
||||
import sys
|
||||
|
||||
# from src.chat.heart_flow.heartflow import heartflow
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
|
||||
# from src.config.config import BotConfig
|
||||
from src.common.logger import get_logger
|
||||
from src.api.reload_config import reload_config as reload_config_func
|
||||
from src.common.server import get_global_server
|
||||
from src.api.apiforgui import (
|
||||
get_all_subheartflow_ids,
|
||||
forced_change_subheartflow_status,
|
||||
get_subheartflow_cycle_info,
|
||||
get_all_states,
|
||||
)
|
||||
from src.chat.heart_flow.sub_heartflow import ChatState
|
||||
from src.api.basic_info_api import get_all_basic_info # 新增导入
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
logger = get_logger("api")
|
||||
|
||||
logger.info("麦麦API服务器已启动")
|
||||
graphql_router = GraphQLRouter(schema=None, path="/") # Replace `None` with your actual schema
|
||||
|
||||
router.include_router(graphql_router, prefix="/graphql", tags=["GraphQL"])
|
||||
|
||||
|
||||
@router.post("/config/reload")
|
||||
async def reload_config():
|
||||
return await reload_config_func()
|
||||
|
||||
|
||||
@router.get("/gui/subheartflow/get/all")
|
||||
async def get_subheartflow_ids():
|
||||
"""获取所有子心流的ID列表"""
|
||||
return await get_all_subheartflow_ids()
|
||||
|
||||
|
||||
@router.post("/gui/subheartflow/forced_change_status")
|
||||
async def forced_change_subheartflow_status_api(subheartflow_id: str, status: ChatState): # noqa
|
||||
"""强制改变子心流的状态"""
|
||||
# 参数检查
|
||||
if not isinstance(status, ChatState):
|
||||
logger.warning(f"无效的状态参数: {status}")
|
||||
return {"status": "failed", "reason": "invalid status"}
|
||||
logger.info(f"尝试将子心流 {subheartflow_id} 状态更改为 {status.value}")
|
||||
success = await forced_change_subheartflow_status(subheartflow_id, status)
|
||||
if success:
|
||||
logger.info(f"子心流 {subheartflow_id} 状态更改为 {status.value} 成功")
|
||||
return {"status": "success"}
|
||||
else:
|
||||
logger.error(f"子心流 {subheartflow_id} 状态更改为 {status.value} 失败")
|
||||
return {"status": "failed"}
|
||||
|
||||
|
||||
@router.get("/stop")
|
||||
async def force_stop_maibot():
|
||||
"""强制停止MAI Bot"""
|
||||
from bot import request_shutdown
|
||||
|
||||
success = await request_shutdown()
|
||||
if success:
|
||||
logger.info("MAI Bot已强制停止")
|
||||
return {"status": "success"}
|
||||
else:
|
||||
logger.error("MAI Bot强制停止失败")
|
||||
return {"status": "failed"}
|
||||
|
||||
|
||||
@router.get("/gui/subheartflow/cycleinfo")
|
||||
async def get_subheartflow_cycle_info_api(subheartflow_id: str, history_len: int):
|
||||
"""获取子心流的循环信息"""
|
||||
cycle_info = await get_subheartflow_cycle_info(subheartflow_id, history_len)
|
||||
if cycle_info:
|
||||
return {"status": "success", "data": cycle_info}
|
||||
else:
|
||||
logger.warning(f"子心流 {subheartflow_id} 循环信息未找到")
|
||||
return {"status": "failed", "reason": "subheartflow not found"}
|
||||
|
||||
|
||||
@router.get("/gui/get_all_states")
|
||||
async def get_all_states_api():
|
||||
"""获取所有状态"""
|
||||
all_states = await get_all_states()
|
||||
if all_states:
|
||||
return {"status": "success", "data": all_states}
|
||||
else:
|
||||
logger.warning("获取所有状态失败")
|
||||
return {"status": "failed", "reason": "failed to get all states"}
|
||||
|
||||
|
||||
@router.get("/info")
|
||||
async def get_system_basic_info():
|
||||
"""获取系统基本信息"""
|
||||
logger.info("请求系统基本信息")
|
||||
try:
|
||||
info = get_all_basic_info()
|
||||
return {"status": "success", "data": info}
|
||||
except Exception as e:
|
||||
logger.error(f"获取系统基本信息失败: {e}")
|
||||
return {"status": "failed", "reason": str(e)}
|
||||
|
||||
|
||||
def start_api_server():
|
||||
"""启动API服务器"""
|
||||
get_global_server().register_router(router, prefix="/api/v1")
|
||||
# pass
|
||||
@@ -1,24 +0,0 @@
|
||||
from fastapi import HTTPException
|
||||
from rich.traceback import install
|
||||
from src.config.config import get_config_dir, load_config
|
||||
from src.common.logger import get_logger
|
||||
import os
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
logger = get_logger("api")
|
||||
|
||||
|
||||
async def reload_config():
|
||||
try:
|
||||
from src.config import config as config_module
|
||||
|
||||
logger.debug("正在重载配置文件...")
|
||||
bot_config_path = os.path.join(get_config_dir(), "bot_config.toml")
|
||||
config_module.global_config = load_config(config_path=bot_config_path)
|
||||
logger.debug("配置文件重载成功")
|
||||
return {"status": "reloaded"}
|
||||
except FileNotFoundError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e)) from e
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"重载配置时发生错误: {str(e)}") from e
|
||||
@@ -1,62 +0,0 @@
|
||||
import asyncio
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("MockAudio")
|
||||
|
||||
|
||||
class MockAudioPlayer:
|
||||
"""
|
||||
一个模拟的音频播放器,它会根据音频数据的"长度"来模拟播放时间。
|
||||
"""
|
||||
|
||||
def __init__(self, audio_data: bytes):
|
||||
self._audio_data = audio_data
|
||||
# 模拟音频时长:假设每 1024 字节代表 0.5 秒的音频
|
||||
self._duration = (len(audio_data) / 1024.0) * 0.5
|
||||
|
||||
async def play(self):
|
||||
"""模拟播放音频。该过程可以被中断。"""
|
||||
if self._duration <= 0:
|
||||
return
|
||||
logger.info(f"开始播放模拟音频,预计时长: {self._duration:.2f} 秒...")
|
||||
try:
|
||||
await asyncio.sleep(self._duration)
|
||||
logger.info("模拟音频播放完毕。")
|
||||
except asyncio.CancelledError:
|
||||
logger.info("音频播放被中断。")
|
||||
raise # 重新抛出异常,以便上层逻辑可以捕获它
|
||||
|
||||
|
||||
class MockAudioGenerator:
|
||||
"""
|
||||
一个模拟的文本到语音(TTS)生成器。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
# 模拟生成速度:每秒生成的字符数
|
||||
self.chars_per_second = 25.0
|
||||
|
||||
async def generate(self, text: str) -> bytes:
|
||||
"""
|
||||
模拟从文本生成音频数据。该过程可以被中断。
|
||||
|
||||
Args:
|
||||
text: 需要转换为音频的文本。
|
||||
|
||||
Returns:
|
||||
模拟的音频数据(bytes)。
|
||||
"""
|
||||
if not text:
|
||||
return b""
|
||||
|
||||
generation_time = len(text) / self.chars_per_second
|
||||
logger.info(f"模拟生成音频... 文本长度: {len(text)}, 预计耗时: {generation_time:.2f} 秒...")
|
||||
try:
|
||||
await asyncio.sleep(generation_time)
|
||||
# 生成虚拟的音频数据,其长度与文本长度成正比
|
||||
mock_audio_data = b"\x01\x02\x03" * (len(text) * 40)
|
||||
logger.info(f"模拟音频生成完毕,数据大小: {len(mock_audio_data) / 1024:.2f} KB。")
|
||||
return mock_audio_data
|
||||
except asyncio.CancelledError:
|
||||
logger.info("音频生成被中断。")
|
||||
raise # 重新抛出异常
|
||||
@@ -5,11 +5,9 @@ MaiBot模块系统
|
||||
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.emoji_system.emoji_manager import get_emoji_manager
|
||||
from src.chat.normal_chat.willing.willing_manager import get_willing_manager
|
||||
|
||||
# 导出主要组件供外部使用
|
||||
__all__ = [
|
||||
"get_chat_manager",
|
||||
"get_emoji_manager",
|
||||
"get_willing_manager",
|
||||
]
|
||||
|
||||
@@ -5,20 +5,20 @@ import os
|
||||
import random
|
||||
import time
|
||||
import traceback
|
||||
from typing import Optional, Tuple, List, Any
|
||||
from PIL import Image
|
||||
import io
|
||||
import re
|
||||
import binascii
|
||||
from typing import Optional, Tuple, List, Any
|
||||
from PIL import Image
|
||||
from rich.traceback import install
|
||||
|
||||
# from gradio_client import file
|
||||
|
||||
from src.common.database.database_model import Emoji
|
||||
from src.common.database.database import db as peewee_db
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.chat.utils.utils_image import image_path_to_base64, get_image_manager
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.common.logger import get_logger
|
||||
from rich.traceback import install
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
@@ -26,7 +26,7 @@ logger = get_logger("emoji")
|
||||
|
||||
BASE_DIR = os.path.join("data")
|
||||
EMOJI_DIR = os.path.join(BASE_DIR, "emoji") # 表情包存储目录
|
||||
EMOJI_REGISTED_DIR = os.path.join(BASE_DIR, "emoji_registed") # 已注册的表情包注册目录
|
||||
EMOJI_REGISTERED_DIR = os.path.join(BASE_DIR, "emoji_registed") # 已注册的表情包注册目录
|
||||
MAX_EMOJI_FOR_PROMPT = 20 # 最大允许的表情包描述数量于图片替换的 prompt 中
|
||||
|
||||
"""
|
||||
@@ -47,7 +47,7 @@ class MaiEmoji:
|
||||
self.embedding = []
|
||||
self.hash = "" # 初始为空,在创建实例时会计算
|
||||
self.description = ""
|
||||
self.emotion = []
|
||||
self.emotion: List[str] = []
|
||||
self.usage_count = 0
|
||||
self.last_used_time = time.time()
|
||||
self.register_time = time.time()
|
||||
@@ -85,7 +85,7 @@ class MaiEmoji:
|
||||
logger.debug(f"[初始化] 正在使用Pillow获取格式: {self.filename}")
|
||||
try:
|
||||
with Image.open(io.BytesIO(image_bytes)) as img:
|
||||
self.format = img.format.lower()
|
||||
self.format = img.format.lower() # type: ignore
|
||||
logger.debug(f"[初始化] 格式获取成功: {self.format}")
|
||||
except Exception as pil_error:
|
||||
logger.error(f"[初始化错误] Pillow无法处理图片 ({self.filename}): {pil_error}")
|
||||
@@ -100,7 +100,7 @@ class MaiEmoji:
|
||||
logger.error(f"[初始化错误] 文件在处理过程中丢失: {self.full_path}")
|
||||
self.is_deleted = True
|
||||
return None
|
||||
except base64.binascii.Error as b64_error:
|
||||
except (binascii.Error, ValueError) as b64_error:
|
||||
logger.error(f"[初始化错误] Base64解码失败 ({self.filename}): {b64_error}")
|
||||
self.is_deleted = True
|
||||
return None
|
||||
@@ -113,7 +113,7 @@ class MaiEmoji:
|
||||
async def register_to_db(self) -> bool:
|
||||
"""
|
||||
注册表情包
|
||||
将表情包对应的文件,从当前路径移动到EMOJI_REGISTED_DIR目录下
|
||||
将表情包对应的文件,从当前路径移动到EMOJI_REGISTERED_DIR目录下
|
||||
并修改对应的实例属性,然后将表情包信息保存到数据库中
|
||||
"""
|
||||
try:
|
||||
@@ -122,7 +122,7 @@ class MaiEmoji:
|
||||
# 源路径是当前实例的完整路径 self.full_path
|
||||
source_full_path = self.full_path
|
||||
# 目标完整路径
|
||||
destination_full_path = os.path.join(EMOJI_REGISTED_DIR, self.filename)
|
||||
destination_full_path = os.path.join(EMOJI_REGISTERED_DIR, self.filename)
|
||||
|
||||
# 检查源文件是否存在
|
||||
if not os.path.exists(source_full_path):
|
||||
@@ -139,7 +139,7 @@ class MaiEmoji:
|
||||
logger.debug(f"[移动] 文件从 {source_full_path} 移动到 {destination_full_path}")
|
||||
# 更新实例的路径属性为新路径
|
||||
self.full_path = destination_full_path
|
||||
self.path = EMOJI_REGISTED_DIR
|
||||
self.path = EMOJI_REGISTERED_DIR
|
||||
# self.filename 保持不变
|
||||
except Exception as move_error:
|
||||
logger.error(f"[错误] 移动文件失败: {str(move_error)}")
|
||||
@@ -202,7 +202,7 @@ class MaiEmoji:
|
||||
try:
|
||||
will_delete_emoji = Emoji.get(Emoji.emoji_hash == self.hash)
|
||||
result = will_delete_emoji.delete_instance() # Returns the number of rows deleted.
|
||||
except Emoji.DoesNotExist:
|
||||
except Emoji.DoesNotExist: # type: ignore
|
||||
logger.warning(f"[删除] 数据库中未找到哈希值为 {self.hash} 的表情包记录。")
|
||||
result = 0 # Indicate no DB record was deleted
|
||||
|
||||
@@ -298,7 +298,7 @@ def _to_emoji_objects(data: Any) -> Tuple[List["MaiEmoji"], int]:
|
||||
def _ensure_emoji_dir() -> None:
|
||||
"""确保表情存储目录存在"""
|
||||
os.makedirs(EMOJI_DIR, exist_ok=True)
|
||||
os.makedirs(EMOJI_REGISTED_DIR, exist_ok=True)
|
||||
os.makedirs(EMOJI_REGISTERED_DIR, exist_ok=True)
|
||||
|
||||
|
||||
async def clear_temp_emoji() -> None:
|
||||
@@ -331,10 +331,10 @@ async def clean_unused_emojis(emoji_dir: str, emoji_objects: List["MaiEmoji"], r
|
||||
logger.warning(f"[清理] 目标目录不存在,跳过清理: {emoji_dir}")
|
||||
return removed_count
|
||||
|
||||
cleaned_count = 0
|
||||
try:
|
||||
# 获取内存中所有有效表情包的完整路径集合
|
||||
tracked_full_paths = {emoji.full_path for emoji in emoji_objects if not emoji.is_deleted}
|
||||
cleaned_count = 0
|
||||
|
||||
# 遍历指定目录中的所有文件
|
||||
for file_name in os.listdir(emoji_dir):
|
||||
@@ -358,11 +358,11 @@ async def clean_unused_emojis(emoji_dir: str, emoji_objects: List["MaiEmoji"], r
|
||||
else:
|
||||
logger.info(f"[清理] 目录 {emoji_dir} 中没有需要清理的。")
|
||||
|
||||
return removed_count + cleaned_count
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[错误] 清理未使用表情包文件时出错 ({emoji_dir}): {str(e)}")
|
||||
|
||||
return removed_count + cleaned_count
|
||||
|
||||
|
||||
class EmojiManager:
|
||||
_instance = None
|
||||
@@ -414,7 +414,7 @@ class EmojiManager:
|
||||
emoji_update.usage_count += 1
|
||||
emoji_update.last_used_time = time.time() # Update last used time
|
||||
emoji_update.save() # Persist changes to DB
|
||||
except Emoji.DoesNotExist:
|
||||
except Emoji.DoesNotExist: # type: ignore
|
||||
logger.error(f"记录表情使用失败: 未找到 hash 为 {emoji_hash} 的表情包")
|
||||
except Exception as e:
|
||||
logger.error(f"记录表情使用失败: {str(e)}")
|
||||
@@ -570,8 +570,8 @@ class EmojiManager:
|
||||
if objects_to_remove:
|
||||
self.emoji_objects = [e for e in self.emoji_objects if e not in objects_to_remove]
|
||||
|
||||
# 清理 EMOJI_REGISTED_DIR 目录中未被追踪的文件
|
||||
removed_count = await clean_unused_emojis(EMOJI_REGISTED_DIR, self.emoji_objects, removed_count)
|
||||
# 清理 EMOJI_REGISTERED_DIR 目录中未被追踪的文件
|
||||
removed_count = await clean_unused_emojis(EMOJI_REGISTERED_DIR, self.emoji_objects, removed_count)
|
||||
|
||||
# 输出清理结果
|
||||
if removed_count > 0:
|
||||
@@ -850,11 +850,13 @@ class EmojiManager:
|
||||
if isinstance(image_base64, str):
|
||||
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
|
||||
image_bytes = base64.b64decode(image_base64)
|
||||
image_format = Image.open(io.BytesIO(image_bytes)).format.lower()
|
||||
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
|
||||
|
||||
# 调用AI获取描述
|
||||
if image_format == "gif" or image_format == "GIF":
|
||||
image_base64 = get_image_manager().transform_gif(image_base64)
|
||||
image_base64 = get_image_manager().transform_gif(image_base64) # type: ignore
|
||||
if not image_base64:
|
||||
raise RuntimeError("GIF表情包转换失败")
|
||||
prompt = "这是一个动态图表情包,每一张图代表了动态图的某一帧,黑色背景代表透明,描述一下表情包表达的情感和内容,描述细节,从互联网梗,meme的角度去分析"
|
||||
description, _ = await self.vlm.generate_response_for_image(prompt, image_base64, "jpg")
|
||||
else:
|
||||
|
||||
@@ -1,14 +1,17 @@
|
||||
import time
|
||||
import random
|
||||
import json
|
||||
import os
|
||||
|
||||
from typing import List, Dict, Optional, Any, Tuple
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config
|
||||
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_random, build_anonymous_messages
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
import os
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
import json
|
||||
from src.common.database.database_model import Expression
|
||||
|
||||
|
||||
MAX_EXPRESSION_COUNT = 300
|
||||
@@ -73,8 +76,70 @@ class ExpressionLearner:
|
||||
request_type="expressor.learner",
|
||||
)
|
||||
self.llm_model = None
|
||||
self._auto_migrate_json_to_db()
|
||||
|
||||
def get_expression_by_chat_id(self, chat_id: str) -> Tuple[List[Dict[str, str]], List[Dict[str, str]]]:
|
||||
def _auto_migrate_json_to_db(self):
|
||||
"""
|
||||
自动将/data/expression/learnt_style 和 learnt_grammar 下所有expressions.json迁移到数据库。
|
||||
迁移完成后在/data/expression/done.done写入标记文件,存在则跳过。
|
||||
"""
|
||||
done_flag = os.path.join("data", "expression", "done.done")
|
||||
if os.path.exists(done_flag):
|
||||
logger.info("表达方式JSON已迁移,无需重复迁移。")
|
||||
return
|
||||
base_dir = os.path.join("data", "expression")
|
||||
for type in ["learnt_style", "learnt_grammar"]:
|
||||
type_str = "style" if type == "learnt_style" else "grammar"
|
||||
type_dir = os.path.join(base_dir, type)
|
||||
if not os.path.exists(type_dir):
|
||||
continue
|
||||
for chat_id in os.listdir(type_dir):
|
||||
expr_file = os.path.join(type_dir, chat_id, "expressions.json")
|
||||
if not os.path.exists(expr_file):
|
||||
continue
|
||||
try:
|
||||
with open(expr_file, "r", encoding="utf-8") as f:
|
||||
expressions = json.load(f)
|
||||
for expr in expressions:
|
||||
situation = expr.get("situation")
|
||||
style_val = expr.get("style")
|
||||
count = expr.get("count", 1)
|
||||
last_active_time = expr.get("last_active_time", time.time())
|
||||
# 查重:同chat_id+type+situation+style
|
||||
from src.common.database.database_model import Expression
|
||||
|
||||
query = Expression.select().where(
|
||||
(Expression.chat_id == chat_id)
|
||||
& (Expression.type == type_str)
|
||||
& (Expression.situation == situation)
|
||||
& (Expression.style == style_val)
|
||||
)
|
||||
if query.exists():
|
||||
expr_obj = query.get()
|
||||
expr_obj.count = max(expr_obj.count, count)
|
||||
expr_obj.last_active_time = max(expr_obj.last_active_time, last_active_time)
|
||||
expr_obj.save()
|
||||
else:
|
||||
Expression.create(
|
||||
situation=situation,
|
||||
style=style_val,
|
||||
count=count,
|
||||
last_active_time=last_active_time,
|
||||
chat_id=chat_id,
|
||||
type=type_str,
|
||||
)
|
||||
logger.info(f"已迁移 {expr_file} 到数据库")
|
||||
except Exception as e:
|
||||
logger.error(f"迁移表达方式 {expr_file} 失败: {e}")
|
||||
# 标记迁移完成
|
||||
try:
|
||||
with open(done_flag, "w", encoding="utf-8") as f:
|
||||
f.write("done\n")
|
||||
logger.info("表达方式JSON迁移已完成,已写入done.done标记文件")
|
||||
except Exception as e:
|
||||
logger.error(f"写入done.done标记文件失败: {e}")
|
||||
|
||||
def get_expression_by_chat_id(self, chat_id: str) -> Tuple[List[Dict[str, float]], List[Dict[str, float]]]:
|
||||
"""
|
||||
获取指定chat_id的style和grammar表达方式
|
||||
返回的每个表达方式字典中都包含了source_id, 用于后续的更新操作
|
||||
@@ -82,32 +147,31 @@ class ExpressionLearner:
|
||||
learnt_style_expressions = []
|
||||
learnt_grammar_expressions = []
|
||||
|
||||
# 获取style表达方式
|
||||
style_dir = os.path.join("data", "expression", "learnt_style", str(chat_id))
|
||||
style_file = os.path.join(style_dir, "expressions.json")
|
||||
if os.path.exists(style_file):
|
||||
try:
|
||||
with open(style_file, "r", encoding="utf-8") as f:
|
||||
expressions = json.load(f)
|
||||
for expr in expressions:
|
||||
expr["source_id"] = chat_id # 添加来源ID
|
||||
learnt_style_expressions.append(expr)
|
||||
except Exception as e:
|
||||
logger.error(f"读取style表达方式失败: {e}")
|
||||
|
||||
# 获取grammar表达方式
|
||||
grammar_dir = os.path.join("data", "expression", "learnt_grammar", str(chat_id))
|
||||
grammar_file = os.path.join(grammar_dir, "expressions.json")
|
||||
if os.path.exists(grammar_file):
|
||||
try:
|
||||
with open(grammar_file, "r", encoding="utf-8") as f:
|
||||
expressions = json.load(f)
|
||||
for expr in expressions:
|
||||
expr["source_id"] = chat_id # 添加来源ID
|
||||
learnt_grammar_expressions.append(expr)
|
||||
except Exception as e:
|
||||
logger.error(f"读取grammar表达方式失败: {e}")
|
||||
|
||||
# 直接从数据库查询
|
||||
style_query = Expression.select().where((Expression.chat_id == chat_id) & (Expression.type == "style"))
|
||||
for expr in style_query:
|
||||
learnt_style_expressions.append(
|
||||
{
|
||||
"situation": expr.situation,
|
||||
"style": expr.style,
|
||||
"count": expr.count,
|
||||
"last_active_time": expr.last_active_time,
|
||||
"source_id": chat_id,
|
||||
"type": "style",
|
||||
}
|
||||
)
|
||||
grammar_query = Expression.select().where((Expression.chat_id == chat_id) & (Expression.type == "grammar"))
|
||||
for expr in grammar_query:
|
||||
learnt_grammar_expressions.append(
|
||||
{
|
||||
"situation": expr.situation,
|
||||
"style": expr.style,
|
||||
"count": expr.count,
|
||||
"last_active_time": expr.last_active_time,
|
||||
"source_id": chat_id,
|
||||
"type": "grammar",
|
||||
}
|
||||
)
|
||||
return learnt_style_expressions, learnt_grammar_expressions
|
||||
|
||||
def is_similar(self, s1: str, s2: str) -> bool:
|
||||
@@ -119,10 +183,10 @@ class ExpressionLearner:
|
||||
min_len = min(len(s1), len(s2))
|
||||
if min_len < 5:
|
||||
return False
|
||||
same = sum(1 for a, b in zip(s1, s2) if a == b)
|
||||
same = sum(a == b for a, b in zip(s1, s2, strict=False))
|
||||
return same / min_len > 0.8
|
||||
|
||||
async def learn_and_store_expression(self) -> List[Tuple[str, str, str]]:
|
||||
async def learn_and_store_expression(self) -> Tuple[List[Tuple[str, str, str]], List[Tuple[str, str, str]]]:
|
||||
"""
|
||||
学习并存储表达方式,分别学习语言风格和句法特点
|
||||
同时对所有已存储的表达方式进行全局衰减
|
||||
@@ -154,16 +218,18 @@ class ExpressionLearner:
|
||||
logger.error(f"全局衰减{type}表达方式失败: {e}")
|
||||
continue
|
||||
|
||||
learnt_style: Optional[List[Tuple[str, str, str]]] = []
|
||||
learnt_grammar: Optional[List[Tuple[str, str, str]]] = []
|
||||
# 学习新的表达方式(这里会进行局部衰减)
|
||||
for _ in range(3):
|
||||
learnt_style: Optional[List[Tuple[str, str, str]]] = await self.learn_and_store(type="style", num=25)
|
||||
learnt_style = await self.learn_and_store(type="style", num=25)
|
||||
if not learnt_style:
|
||||
return []
|
||||
return [], []
|
||||
|
||||
for _ in range(1):
|
||||
learnt_grammar: Optional[List[Tuple[str, str, str]]] = await self.learn_and_store(type="grammar", num=10)
|
||||
learnt_grammar = await self.learn_and_store(type="grammar", num=10)
|
||||
if not learnt_grammar:
|
||||
return []
|
||||
return [], []
|
||||
|
||||
return learnt_style, learnt_grammar
|
||||
|
||||
@@ -214,6 +280,7 @@ class ExpressionLearner:
|
||||
return result
|
||||
|
||||
async def learn_and_store(self, type: str, num: int = 10) -> List[Tuple[str, str, str]]:
|
||||
# sourcery skip: use-join
|
||||
"""
|
||||
选择从当前到最近1小时内的随机num条消息,然后学习这些消息的表达方式
|
||||
type: "style" or "grammar"
|
||||
@@ -233,7 +300,6 @@ class ExpressionLearner:
|
||||
|
||||
chat_stream = get_chat_manager().get_stream(chat_id)
|
||||
if chat_stream is None:
|
||||
# 如果聊天流不在内存中,使用chat_id作为默认名称
|
||||
group_name = f"聊天流 {chat_id}"
|
||||
elif chat_stream.group_info:
|
||||
group_name = chat_stream.group_info.group_name
|
||||
@@ -249,7 +315,7 @@ class ExpressionLearner:
|
||||
return []
|
||||
|
||||
# 按chat_id分组
|
||||
chat_dict: Dict[str, List[Dict[str, str]]] = {}
|
||||
chat_dict: Dict[str, List[Dict[str, Any]]] = {}
|
||||
for chat_id, situation, style in learnt_expressions:
|
||||
if chat_id not in chat_dict:
|
||||
chat_dict[chat_id] = []
|
||||
@@ -257,80 +323,44 @@ class ExpressionLearner:
|
||||
|
||||
current_time = time.time()
|
||||
|
||||
# 存储到/data/expression/对应chat_id/expressions.json
|
||||
# 存储到数据库 Expression 表
|
||||
for chat_id, expr_list in chat_dict.items():
|
||||
dir_path = os.path.join("data", "expression", f"learnt_{type}", str(chat_id))
|
||||
os.makedirs(dir_path, exist_ok=True)
|
||||
file_path = os.path.join(dir_path, "expressions.json")
|
||||
|
||||
# 若已存在,先读出合并
|
||||
old_data: List[Dict[str, Any]] = []
|
||||
if os.path.exists(file_path):
|
||||
try:
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
old_data = json.load(f)
|
||||
except Exception:
|
||||
old_data = []
|
||||
|
||||
# 应用衰减
|
||||
# old_data = self.apply_decay_to_expressions(old_data, current_time)
|
||||
|
||||
# 合并逻辑
|
||||
for new_expr in expr_list:
|
||||
found = False
|
||||
for old_expr in old_data:
|
||||
if self.is_similar(new_expr["situation"], old_expr.get("situation", "")) and self.is_similar(
|
||||
new_expr["style"], old_expr.get("style", "")
|
||||
):
|
||||
found = True
|
||||
# 50%概率替换
|
||||
if random.random() < 0.5:
|
||||
old_expr["situation"] = new_expr["situation"]
|
||||
old_expr["style"] = new_expr["style"]
|
||||
old_expr["count"] = old_expr.get("count", 1) + 1
|
||||
old_expr["last_active_time"] = current_time
|
||||
break
|
||||
if not found:
|
||||
new_expr["count"] = 1
|
||||
new_expr["last_active_time"] = current_time
|
||||
old_data.append(new_expr)
|
||||
|
||||
# 处理超限问题
|
||||
if len(old_data) > MAX_EXPRESSION_COUNT:
|
||||
# 计算每个表达方式的权重(count的倒数,这样count越小的越容易被选中)
|
||||
weights = [1 / (expr.get("count", 1) + 0.1) for expr in old_data]
|
||||
|
||||
# 随机选择要移除的表达方式,避免重复索引
|
||||
remove_count = len(old_data) - MAX_EXPRESSION_COUNT
|
||||
|
||||
# 使用一种不会选到重复索引的方法
|
||||
indices = list(range(len(old_data)))
|
||||
|
||||
# 方法1:使用numpy.random.choice
|
||||
# 把列表转成一个映射字典,保证不会有重复
|
||||
remove_set = set()
|
||||
total_attempts = 0
|
||||
|
||||
# 尝试按权重随机选择,直到选够数量
|
||||
while len(remove_set) < remove_count and total_attempts < len(old_data) * 2:
|
||||
idx = random.choices(indices, weights=weights, k=1)[0]
|
||||
remove_set.add(idx)
|
||||
total_attempts += 1
|
||||
|
||||
# 如果没选够,随机补充
|
||||
if len(remove_set) < remove_count:
|
||||
remaining = set(indices) - remove_set
|
||||
remove_set.update(random.sample(list(remaining), remove_count - len(remove_set)))
|
||||
|
||||
remove_indices = list(remove_set)
|
||||
|
||||
# 从后往前删除,避免索引变化
|
||||
for idx in sorted(remove_indices, reverse=True):
|
||||
old_data.pop(idx)
|
||||
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
json.dump(old_data, f, ensure_ascii=False, indent=2)
|
||||
|
||||
# 查找是否已存在相似表达方式
|
||||
query = Expression.select().where(
|
||||
(Expression.chat_id == chat_id)
|
||||
& (Expression.type == type)
|
||||
& (Expression.situation == new_expr["situation"])
|
||||
& (Expression.style == new_expr["style"])
|
||||
)
|
||||
if query.exists():
|
||||
expr_obj = query.get()
|
||||
# 50%概率替换内容
|
||||
if random.random() < 0.5:
|
||||
expr_obj.situation = new_expr["situation"]
|
||||
expr_obj.style = new_expr["style"]
|
||||
expr_obj.count = expr_obj.count + 1
|
||||
expr_obj.last_active_time = current_time
|
||||
expr_obj.save()
|
||||
else:
|
||||
Expression.create(
|
||||
situation=new_expr["situation"],
|
||||
style=new_expr["style"],
|
||||
count=1,
|
||||
last_active_time=current_time,
|
||||
chat_id=chat_id,
|
||||
type=type,
|
||||
)
|
||||
# 限制最大数量
|
||||
exprs = list(
|
||||
Expression.select()
|
||||
.where((Expression.chat_id == chat_id) & (Expression.type == type))
|
||||
.order_by(Expression.count.asc())
|
||||
)
|
||||
if len(exprs) > MAX_EXPRESSION_COUNT:
|
||||
# 删除count最小的多余表达方式
|
||||
for expr in exprs[: len(exprs) - MAX_EXPRESSION_COUNT]:
|
||||
expr.delete_instance()
|
||||
return learnt_expressions
|
||||
|
||||
async def learn_expression(self, type: str, num: int = 10) -> Optional[Tuple[List[Tuple[str, str, str]], str]]:
|
||||
@@ -1,14 +1,16 @@
|
||||
from .exprssion_learner import get_expression_learner
|
||||
import random
|
||||
from typing import List, Dict, Tuple
|
||||
from json_repair import repair_json
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import random
|
||||
|
||||
from typing import List, Dict, Tuple, Optional
|
||||
from json_repair import repair_json
|
||||
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from .expression_learner import get_expression_learner
|
||||
from src.common.database.database_model import Expression
|
||||
|
||||
logger = get_logger("expression_selector")
|
||||
|
||||
@@ -79,94 +81,128 @@ class ExpressionSelector:
|
||||
request_type="expression.selector",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _parse_stream_config_to_chat_id(stream_config_str: str) -> Optional[str]:
|
||||
"""解析'platform:id:type'为chat_id(与get_stream_id一致)"""
|
||||
try:
|
||||
parts = stream_config_str.split(":")
|
||||
if len(parts) != 3:
|
||||
return None
|
||||
platform = parts[0]
|
||||
id_str = parts[1]
|
||||
stream_type = parts[2]
|
||||
is_group = stream_type == "group"
|
||||
import hashlib
|
||||
if is_group:
|
||||
components = [platform, str(id_str)]
|
||||
else:
|
||||
components = [platform, str(id_str), "private"]
|
||||
key = "_".join(components)
|
||||
return hashlib.md5(key.encode()).hexdigest()
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_related_chat_ids(self, chat_id: str) -> List[str]:
|
||||
"""根据expression_groups配置,获取与当前chat_id相关的所有chat_id(包括自身)"""
|
||||
groups = global_config.expression.expression_groups
|
||||
for group in groups:
|
||||
group_chat_ids = []
|
||||
for stream_config_str in group:
|
||||
chat_id_candidate = self._parse_stream_config_to_chat_id(stream_config_str)
|
||||
if chat_id_candidate:
|
||||
group_chat_ids.append(chat_id_candidate)
|
||||
if chat_id in group_chat_ids:
|
||||
return group_chat_ids
|
||||
return [chat_id]
|
||||
|
||||
def get_random_expressions(
|
||||
self, chat_id: str, total_num: int, style_percentage: float, grammar_percentage: float
|
||||
) -> Tuple[List[Dict[str, str]], List[Dict[str, str]]]:
|
||||
(
|
||||
learnt_style_expressions,
|
||||
learnt_grammar_expressions,
|
||||
) = self.expression_learner.get_expression_by_chat_id(chat_id)
|
||||
|
||||
# 支持多chat_id合并抽选
|
||||
related_chat_ids = self.get_related_chat_ids(chat_id)
|
||||
style_exprs = []
|
||||
grammar_exprs = []
|
||||
for cid in related_chat_ids:
|
||||
style_query = Expression.select().where((Expression.chat_id == cid) & (Expression.type == "style"))
|
||||
grammar_query = Expression.select().where((Expression.chat_id == cid) & (Expression.type == "grammar"))
|
||||
style_exprs.extend([
|
||||
{
|
||||
"situation": expr.situation,
|
||||
"style": expr.style,
|
||||
"count": expr.count,
|
||||
"last_active_time": expr.last_active_time,
|
||||
"source_id": cid,
|
||||
"type": "style"
|
||||
} for expr in style_query
|
||||
])
|
||||
grammar_exprs.extend([
|
||||
{
|
||||
"situation": expr.situation,
|
||||
"style": expr.style,
|
||||
"count": expr.count,
|
||||
"last_active_time": expr.last_active_time,
|
||||
"source_id": cid,
|
||||
"type": "grammar"
|
||||
} for expr in grammar_query
|
||||
])
|
||||
style_num = int(total_num * style_percentage)
|
||||
grammar_num = int(total_num * grammar_percentage)
|
||||
|
||||
# 按权重抽样(使用count作为权重)
|
||||
if learnt_style_expressions:
|
||||
style_weights = [expr.get("count", 1) for expr in learnt_style_expressions]
|
||||
selected_style = weighted_sample(learnt_style_expressions, style_weights, style_num)
|
||||
if style_exprs:
|
||||
style_weights = [expr.get("count", 1) for expr in style_exprs]
|
||||
selected_style = weighted_sample(style_exprs, style_weights, style_num)
|
||||
else:
|
||||
selected_style = []
|
||||
|
||||
if learnt_grammar_expressions:
|
||||
grammar_weights = [expr.get("count", 1) for expr in learnt_grammar_expressions]
|
||||
selected_grammar = weighted_sample(learnt_grammar_expressions, grammar_weights, grammar_num)
|
||||
if grammar_exprs:
|
||||
grammar_weights = [expr.get("count", 1) for expr in grammar_exprs]
|
||||
selected_grammar = weighted_sample(grammar_exprs, grammar_weights, grammar_num)
|
||||
else:
|
||||
selected_grammar = []
|
||||
|
||||
return selected_style, selected_grammar
|
||||
|
||||
def update_expressions_count_batch(self, expressions_to_update: List[Dict[str, str]], increment: float = 0.1):
|
||||
"""对一批表达方式更新count值,按文件分组后一次性写入"""
|
||||
"""对一批表达方式更新count值,按chat_id+type分组后一次性写入数据库"""
|
||||
if not expressions_to_update:
|
||||
return
|
||||
|
||||
updates_by_file = {}
|
||||
updates_by_key = {}
|
||||
for expr in expressions_to_update:
|
||||
source_id = expr.get("source_id")
|
||||
if not source_id:
|
||||
logger.warning(f"表达方式缺少source_id,无法更新: {expr}")
|
||||
expr_type = expr.get("type", "style")
|
||||
situation = expr.get("situation")
|
||||
style = expr.get("style")
|
||||
if not source_id or not situation or not style:
|
||||
logger.warning(f"表达方式缺少必要字段,无法更新: {expr}")
|
||||
continue
|
||||
|
||||
file_path = ""
|
||||
if source_id == "personality":
|
||||
file_path = os.path.join("data", "expression", "personality", "expressions.json")
|
||||
else:
|
||||
chat_id = source_id
|
||||
expr_type = expr.get("type", "style")
|
||||
if expr_type == "style":
|
||||
file_path = os.path.join("data", "expression", "learnt_style", str(chat_id), "expressions.json")
|
||||
elif expr_type == "grammar":
|
||||
file_path = os.path.join("data", "expression", "learnt_grammar", str(chat_id), "expressions.json")
|
||||
|
||||
if file_path:
|
||||
if file_path not in updates_by_file:
|
||||
updates_by_file[file_path] = []
|
||||
updates_by_file[file_path].append(expr)
|
||||
|
||||
for file_path, updates in updates_by_file.items():
|
||||
if not os.path.exists(file_path):
|
||||
continue
|
||||
|
||||
try:
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
all_expressions = json.load(f)
|
||||
|
||||
# Create a dictionary for quick lookup
|
||||
expr_map = {(e.get("situation"), e.get("style")): e for e in all_expressions}
|
||||
|
||||
# Update counts in memory
|
||||
for expr_to_update in updates:
|
||||
key = (expr_to_update.get("situation"), expr_to_update.get("style"))
|
||||
if key in expr_map:
|
||||
expr_in_map = expr_map[key]
|
||||
current_count = expr_in_map.get("count", 1)
|
||||
new_count = min(current_count + increment, 5.0)
|
||||
expr_in_map["count"] = new_count
|
||||
expr_in_map["last_active_time"] = time.time()
|
||||
logger.debug(
|
||||
f"表达方式激活: 原count={current_count:.3f}, 增量={increment}, 新count={new_count:.3f} in {file_path}"
|
||||
)
|
||||
|
||||
# Save the updated list once for this file
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
json.dump(all_expressions, f, ensure_ascii=False, indent=2)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"批量更新表达方式count失败 for {file_path}: {e}")
|
||||
key = (source_id, expr_type, situation, style)
|
||||
if key not in updates_by_key:
|
||||
updates_by_key[key] = expr
|
||||
for (chat_id, expr_type, situation, style), _expr in updates_by_key.items():
|
||||
query = Expression.select().where(
|
||||
(Expression.chat_id == chat_id) &
|
||||
(Expression.type == expr_type) &
|
||||
(Expression.situation == situation) &
|
||||
(Expression.style == style)
|
||||
)
|
||||
if query.exists():
|
||||
expr_obj = query.get()
|
||||
current_count = expr_obj.count
|
||||
new_count = min(current_count + increment, 5.0)
|
||||
expr_obj.count = new_count
|
||||
expr_obj.last_active_time = time.time()
|
||||
expr_obj.save()
|
||||
logger.debug(
|
||||
f"表达方式激活: 原count={current_count:.3f}, 增量={increment}, 新count={new_count:.3f} in db"
|
||||
)
|
||||
|
||||
async def select_suitable_expressions_llm(
|
||||
self, chat_id: str, chat_info: str, max_num: int = 10, min_num: int = 5, target_message: str = None
|
||||
self,
|
||||
chat_id: str,
|
||||
chat_info: str,
|
||||
max_num: int = 10,
|
||||
min_num: int = 5,
|
||||
target_message: Optional[str] = None,
|
||||
) -> List[Dict[str, str]]:
|
||||
# sourcery skip: inline-variable, list-comprehension
|
||||
"""使用LLM选择适合的表达方式"""
|
||||
|
||||
# 1. 获取35个随机表达方式(现在按权重抽取)
|
||||
|
||||
@@ -1,91 +0,0 @@
|
||||
# 定义了来自外部世界的信息
|
||||
# 外部世界可以是某个聊天 不同平台的聊天 也可以是任意媒体
|
||||
from datetime import datetime
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.focus_chat.hfc_utils import CycleDetail
|
||||
from typing import List
|
||||
# Import the new utility function
|
||||
|
||||
logger = get_logger("loop_info")
|
||||
|
||||
|
||||
# 所有观察的基类
|
||||
class FocusLoopInfo:
|
||||
def __init__(self, observe_id):
|
||||
self.observe_id = observe_id
|
||||
self.last_observe_time = datetime.now().timestamp() # 初始化为当前时间
|
||||
self.history_loop: List[CycleDetail] = []
|
||||
|
||||
def add_loop_info(self, loop_info: CycleDetail):
|
||||
self.history_loop.append(loop_info)
|
||||
|
||||
async def observe(self):
|
||||
recent_active_cycles: List[CycleDetail] = []
|
||||
for cycle in reversed(self.history_loop):
|
||||
# 只关心实际执行了动作的循环
|
||||
# action_taken = cycle.loop_action_info["action_taken"]
|
||||
# if action_taken:
|
||||
recent_active_cycles.append(cycle)
|
||||
if len(recent_active_cycles) == 5:
|
||||
break
|
||||
|
||||
cycle_info_block = ""
|
||||
action_detailed_str = ""
|
||||
consecutive_text_replies = 0
|
||||
responses_for_prompt = []
|
||||
|
||||
cycle_last_reason = ""
|
||||
|
||||
# 检查这最近的活动循环中有多少是连续的文本回复 (从最近的开始看)
|
||||
for cycle in recent_active_cycles:
|
||||
action_result = cycle.loop_plan_info.get("action_result", {})
|
||||
action_type = action_result.get("action_type", "unknown")
|
||||
action_reasoning = action_result.get("reasoning", "未提供理由")
|
||||
is_taken = cycle.loop_action_info.get("action_taken", False)
|
||||
action_taken_time = cycle.loop_action_info.get("taken_time", 0)
|
||||
action_taken_time_str = (
|
||||
datetime.fromtimestamp(action_taken_time).strftime("%H:%M:%S") if action_taken_time > 0 else "未知时间"
|
||||
)
|
||||
if action_reasoning != cycle_last_reason:
|
||||
cycle_last_reason = action_reasoning
|
||||
action_reasoning_str = f"你选择这个action的原因是:{action_reasoning}"
|
||||
else:
|
||||
action_reasoning_str = ""
|
||||
|
||||
if action_type == "reply":
|
||||
consecutive_text_replies += 1
|
||||
response_text = cycle.loop_action_info.get("reply_text", "")
|
||||
responses_for_prompt.append(response_text)
|
||||
|
||||
if is_taken:
|
||||
action_detailed_str += f"{action_taken_time_str}时,你选择回复(action:{action_type},内容是:'{response_text}')。{action_reasoning_str}\n"
|
||||
else:
|
||||
action_detailed_str += f"{action_taken_time_str}时,你选择回复(action:{action_type},内容是:'{response_text}'),但是动作失败了。{action_reasoning_str}\n"
|
||||
elif action_type == "no_reply":
|
||||
pass
|
||||
else:
|
||||
if is_taken:
|
||||
action_detailed_str += (
|
||||
f"{action_taken_time_str}时,你选择执行了(action:{action_type}),{action_reasoning_str}\n"
|
||||
)
|
||||
else:
|
||||
action_detailed_str += f"{action_taken_time_str}时,你选择执行了(action:{action_type}),但是动作失败了。{action_reasoning_str}\n"
|
||||
|
||||
if action_detailed_str:
|
||||
cycle_info_block = f"\n你最近做的事:\n{action_detailed_str}\n"
|
||||
else:
|
||||
cycle_info_block = "\n"
|
||||
|
||||
# 获取history_loop中最新添加的
|
||||
if self.history_loop:
|
||||
last_loop = self.history_loop[0]
|
||||
start_time = last_loop.start_time
|
||||
end_time = last_loop.end_time
|
||||
if start_time is not None and end_time is not None:
|
||||
time_diff = int(end_time - start_time)
|
||||
if time_diff > 60:
|
||||
cycle_info_block += f"距离你上一次阅读消息并思考和规划,已经过去了{int(time_diff / 60)}分钟\n"
|
||||
else:
|
||||
cycle_info_block += f"距离你上一次阅读消息并思考和规划,已经过去了{time_diff}秒\n"
|
||||
else:
|
||||
cycle_info_block += "你还没看过消息\n"
|
||||
@@ -1,24 +1,54 @@
|
||||
import asyncio
|
||||
import contextlib
|
||||
import time
|
||||
import traceback
|
||||
from collections import deque
|
||||
from typing import List, Optional, Dict, Any, Deque, Callable, Awaitable
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
import random
|
||||
from typing import List, Optional, Dict, Any
|
||||
from rich.traceback import install
|
||||
from src.chat.utils.prompt_builder import global_prompt_manager
|
||||
|
||||
from src.config.config import global_config
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
||||
from src.chat.utils.prompt_builder import global_prompt_manager
|
||||
from src.chat.utils.timer_calculator import Timer
|
||||
from src.chat.focus_chat.focus_loop_info import FocusLoopInfo
|
||||
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat
|
||||
from src.chat.planner_actions.planner import ActionPlanner
|
||||
from src.chat.planner_actions.action_modifier import ActionModifier
|
||||
from src.chat.planner_actions.action_manager import ActionManager
|
||||
from src.config.config import global_config
|
||||
from src.chat.focus_chat.hfc_performance_logger import HFCPerformanceLogger
|
||||
from src.person_info.relationship_builder_manager import relationship_builder_manager
|
||||
from src.chat.focus_chat.hfc_utils import CycleDetail
|
||||
from src.person_info.relationship_builder_manager import relationship_builder_manager
|
||||
from src.person_info.person_info import get_person_info_manager
|
||||
from src.plugin_system.base.component_types import ActionInfo, ChatMode
|
||||
from src.plugin_system.apis import generator_api, send_api, message_api
|
||||
from src.chat.willing.willing_manager import get_willing_manager
|
||||
|
||||
|
||||
ERROR_LOOP_INFO = {
|
||||
"loop_plan_info": {
|
||||
"action_result": {
|
||||
"action_type": "error",
|
||||
"action_data": {},
|
||||
"reasoning": "循环处理失败",
|
||||
},
|
||||
},
|
||||
"loop_action_info": {
|
||||
"action_taken": False,
|
||||
"reply_text": "",
|
||||
"command": "",
|
||||
"taken_time": time.time(),
|
||||
},
|
||||
}
|
||||
|
||||
NO_ACTION = {
|
||||
"action_result": {
|
||||
"action_type": "no_action",
|
||||
"action_data": {},
|
||||
"reasoning": "规划器初始化默认",
|
||||
"is_parallel": True,
|
||||
},
|
||||
"chat_context": "",
|
||||
"action_prompt": "",
|
||||
}
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
# 注释:原来的动作修改超时常量已移除,因为改为顺序执行
|
||||
@@ -36,7 +66,6 @@ class HeartFChatting:
|
||||
def __init__(
|
||||
self,
|
||||
chat_id: str,
|
||||
on_stop_focus_chat: Optional[Callable[[], Awaitable[None]]] = None,
|
||||
):
|
||||
"""
|
||||
HeartFChatting 初始化函数
|
||||
@@ -48,85 +77,68 @@ class HeartFChatting:
|
||||
"""
|
||||
# 基础属性
|
||||
self.stream_id: str = chat_id # 聊天流ID
|
||||
self.chat_stream = get_chat_manager().get_stream(self.stream_id)
|
||||
self.chat_stream: ChatStream = get_chat_manager().get_stream(self.stream_id) # type: ignore
|
||||
if not self.chat_stream:
|
||||
raise ValueError(f"无法找到聊天流: {self.stream_id}")
|
||||
self.log_prefix = f"[{get_chat_manager().get_stream_name(self.stream_id) or self.stream_id}]"
|
||||
|
||||
self.relationship_builder = relationship_builder_manager.get_or_create_builder(self.stream_id)
|
||||
|
||||
self.loop_mode = ChatMode.NORMAL # 初始循环模式为普通模式
|
||||
|
||||
# 新增:消息计数器和疲惫阈值
|
||||
self._message_count = 0 # 发送的消息计数
|
||||
# 基于exit_focus_threshold动态计算疲惫阈值
|
||||
# 基础值30条,通过exit_focus_threshold调节:threshold越小,越容易疲惫
|
||||
self._message_threshold = max(10, int(30 * global_config.chat.exit_focus_threshold))
|
||||
self._message_threshold = max(10, int(30 * global_config.chat.focus_value))
|
||||
self._fatigue_triggered = False # 是否已触发疲惫退出
|
||||
|
||||
self.loop_info: FocusLoopInfo = FocusLoopInfo(observe_id=self.stream_id)
|
||||
|
||||
self.action_manager = ActionManager()
|
||||
self.action_planner = ActionPlanner(chat_id=self.stream_id, action_manager=self.action_manager)
|
||||
self.action_modifier = ActionModifier(action_manager=self.action_manager, chat_id=self.stream_id)
|
||||
|
||||
self._processing_lock = asyncio.Lock()
|
||||
|
||||
# 循环控制内部状态
|
||||
self._loop_active: bool = False # 循环是否正在运行
|
||||
self.running: bool = False
|
||||
self._loop_task: Optional[asyncio.Task] = None # 主循环任务
|
||||
self._energy_task: Optional[asyncio.Task] = None
|
||||
|
||||
# 添加循环信息管理相关的属性
|
||||
self.history_loop: List[CycleDetail] = []
|
||||
self._cycle_counter = 0
|
||||
self._cycle_history: Deque[CycleDetail] = deque(maxlen=10) # 保留最近10个循环的信息
|
||||
self._current_cycle_detail: Optional[CycleDetail] = None
|
||||
self._shutting_down: bool = False # 关闭标志位
|
||||
|
||||
# 存储回调函数
|
||||
self.on_stop_focus_chat = on_stop_focus_chat
|
||||
self._current_cycle_detail: CycleDetail = None # type: ignore
|
||||
|
||||
self.reply_timeout_count = 0
|
||||
self.plan_timeout_count = 0
|
||||
|
||||
# 初始化性能记录器
|
||||
# 如果没有指定版本号,则使用全局版本管理器的版本号
|
||||
self.last_read_time = time.time() - 1
|
||||
|
||||
self.performance_logger = HFCPerformanceLogger(chat_id)
|
||||
self.willing_amplifier = 1
|
||||
self.willing_manager = get_willing_manager()
|
||||
|
||||
logger.info(
|
||||
f"{self.log_prefix} HeartFChatting 初始化完成,消息疲惫阈值: {self._message_threshold}条(基于exit_focus_threshold={global_config.chat.exit_focus_threshold}计算,仅在auto模式下生效)"
|
||||
)
|
||||
logger.info(f"{self.log_prefix} HeartFChatting 初始化完成")
|
||||
|
||||
self.energy_value = 5
|
||||
|
||||
async def start(self):
|
||||
"""检查是否需要启动主循环,如果未激活则启动。"""
|
||||
|
||||
# 如果循环已经激活,直接返回
|
||||
if self._loop_active:
|
||||
if self.running:
|
||||
logger.debug(f"{self.log_prefix} HeartFChatting 已激活,无需重复启动")
|
||||
return
|
||||
|
||||
try:
|
||||
# 重置消息计数器,开始新的focus会话
|
||||
self.reset_message_count()
|
||||
|
||||
# 标记为活动状态,防止重复启动
|
||||
self._loop_active = True
|
||||
self.running = True
|
||||
|
||||
# 检查是否已有任务在运行(理论上不应该,因为 _loop_active=False)
|
||||
if self._loop_task and not self._loop_task.done():
|
||||
logger.warning(f"{self.log_prefix} 发现之前的循环任务仍在运行(不符合预期)。取消旧任务。")
|
||||
self._loop_task.cancel()
|
||||
try:
|
||||
# 等待旧任务确实被取消
|
||||
await asyncio.wait_for(self._loop_task, timeout=5.0)
|
||||
except Exception as e:
|
||||
logger.warning(f"{self.log_prefix} 等待旧任务取消时出错: {e}")
|
||||
self._loop_task = None # 清理旧任务引用
|
||||
self._energy_task = asyncio.create_task(self._energy_loop())
|
||||
self._energy_task.add_done_callback(self._handle_energy_completion)
|
||||
|
||||
logger.debug(f"{self.log_prefix} 创建新的 HeartFChatting 主循环任务")
|
||||
self._loop_task = asyncio.create_task(self._run_focus_chat())
|
||||
self._loop_task = asyncio.create_task(self._main_chat_loop())
|
||||
self._loop_task.add_done_callback(self._handle_loop_completion)
|
||||
logger.debug(f"{self.log_prefix} HeartFChatting 启动完成")
|
||||
logger.info(f"{self.log_prefix} HeartFChatting 启动完成")
|
||||
|
||||
except Exception as e:
|
||||
# 启动失败时重置状态
|
||||
self._loop_active = False
|
||||
self.running = False
|
||||
self._loop_task = None
|
||||
logger.error(f"{self.log_prefix} HeartFChatting 启动失败: {e}")
|
||||
raise
|
||||
@@ -134,273 +146,250 @@ class HeartFChatting:
|
||||
def _handle_loop_completion(self, task: asyncio.Task):
|
||||
"""当 _hfc_loop 任务完成时执行的回调。"""
|
||||
try:
|
||||
exception = task.exception()
|
||||
if exception:
|
||||
if exception := task.exception():
|
||||
logger.error(f"{self.log_prefix} HeartFChatting: 脱离了聊天(异常): {exception}")
|
||||
logger.error(traceback.format_exc()) # Log full traceback for exceptions
|
||||
else:
|
||||
logger.info(f"{self.log_prefix} HeartFChatting: 脱离了聊天 (外部停止)")
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"{self.log_prefix} HeartFChatting: 脱离了聊天(任务取消)")
|
||||
finally:
|
||||
self._loop_active = False
|
||||
self._loop_task = None
|
||||
if self._processing_lock.locked():
|
||||
logger.warning(f"{self.log_prefix} HeartFChatting: 处理锁在循环结束时仍被锁定,强制释放。")
|
||||
self._processing_lock.release()
|
||||
logger.info(f"{self.log_prefix} HeartFChatting: 结束了聊天")
|
||||
|
||||
async def _run_focus_chat(self):
|
||||
"""主循环,持续进行计划并可能回复消息,直到被外部取消。"""
|
||||
try:
|
||||
while True: # 主循环
|
||||
logger.debug(f"{self.log_prefix} 开始第{self._cycle_counter}次循环")
|
||||
def start_cycle(self):
|
||||
self._cycle_counter += 1
|
||||
self._current_cycle_detail = CycleDetail(self._cycle_counter)
|
||||
self._current_cycle_detail.thinking_id = f"tid{str(round(time.time(), 2))}"
|
||||
cycle_timers = {}
|
||||
return cycle_timers, self._current_cycle_detail.thinking_id
|
||||
|
||||
# 检查关闭标志
|
||||
if self._shutting_down:
|
||||
logger.info(f"{self.log_prefix} 检测到关闭标志,退出 Focus Chat 循环。")
|
||||
break
|
||||
def end_cycle(self, loop_info, cycle_timers):
|
||||
self._current_cycle_detail.set_loop_info(loop_info)
|
||||
self.history_loop.append(self._current_cycle_detail)
|
||||
self._current_cycle_detail.timers = cycle_timers
|
||||
self._current_cycle_detail.end_time = time.time()
|
||||
|
||||
# 创建新的循环信息
|
||||
self._cycle_counter += 1
|
||||
self._current_cycle_detail = CycleDetail(self._cycle_counter)
|
||||
self._current_cycle_detail.prefix = self.log_prefix
|
||||
def _handle_energy_completion(self, task: asyncio.Task):
|
||||
if exception := task.exception():
|
||||
logger.error(f"{self.log_prefix} HeartFChatting: 能量循环异常: {exception}")
|
||||
logger.error(traceback.format_exc())
|
||||
else:
|
||||
logger.info(f"{self.log_prefix} HeartFChatting: 能量循环完成")
|
||||
|
||||
# 初始化周期状态
|
||||
cycle_timers = {}
|
||||
async def _energy_loop(self):
|
||||
while self.running:
|
||||
await asyncio.sleep(10)
|
||||
if self.loop_mode == ChatMode.NORMAL:
|
||||
self.energy_value -= 0.3
|
||||
if self.energy_value <= 0.3:
|
||||
self.energy_value = 0.3
|
||||
|
||||
# 执行规划和处理阶段
|
||||
try:
|
||||
async with self._get_cycle_context():
|
||||
thinking_id = "tid" + str(round(time.time(), 2))
|
||||
self._current_cycle_detail.set_thinking_id(thinking_id)
|
||||
def print_cycle_info(self, cycle_timers):
|
||||
# 记录循环信息和计时器结果
|
||||
timer_strings = []
|
||||
for name, elapsed in cycle_timers.items():
|
||||
formatted_time = f"{elapsed * 1000:.2f}毫秒" if elapsed < 1 else f"{elapsed:.2f}秒"
|
||||
timer_strings.append(f"{name}: {formatted_time}")
|
||||
|
||||
# 使用异步上下文管理器处理消息
|
||||
try:
|
||||
async with global_prompt_manager.async_message_scope(
|
||||
self.chat_stream.context.get_template_name()
|
||||
):
|
||||
# 在上下文内部检查关闭状态
|
||||
if self._shutting_down:
|
||||
logger.info(f"{self.log_prefix} 在处理上下文中检测到关闭信号,退出")
|
||||
break
|
||||
logger.info(
|
||||
f"{self.log_prefix} 第{self._current_cycle_detail.cycle_id}次思考,"
|
||||
f"耗时: {self._current_cycle_detail.end_time - self._current_cycle_detail.start_time:.1f}秒, " # type: ignore
|
||||
f"选择动作: {self._current_cycle_detail.loop_plan_info.get('action_result', {}).get('action_type', '未知动作')}"
|
||||
+ (f"\n详情: {'; '.join(timer_strings)}" if timer_strings else "")
|
||||
)
|
||||
|
||||
logger.debug(f"模板 {self.chat_stream.context.get_template_name()}")
|
||||
loop_info = await self._observe_process_plan_action_loop(cycle_timers, thinking_id)
|
||||
|
||||
if loop_info["loop_action_info"]["command"] == "stop_focus_chat":
|
||||
logger.info(f"{self.log_prefix} 麦麦决定停止专注聊天")
|
||||
|
||||
# 如果设置了回调函数,则调用它
|
||||
if self.on_stop_focus_chat:
|
||||
try:
|
||||
await self.on_stop_focus_chat()
|
||||
logger.info(f"{self.log_prefix} 成功调用回调函数处理停止专注聊天")
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 调用停止专注聊天回调函数时出错: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
break
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"{self.log_prefix} 处理上下文时任务被取消")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 处理上下文时出错: {e}")
|
||||
# 为当前循环设置错误状态,防止后续重复报错
|
||||
error_loop_info = {
|
||||
"loop_plan_info": {
|
||||
"action_result": {
|
||||
"action_type": "error",
|
||||
"action_data": {},
|
||||
},
|
||||
},
|
||||
"loop_action_info": {
|
||||
"action_taken": False,
|
||||
"reply_text": "",
|
||||
"command": "",
|
||||
"taken_time": time.time(),
|
||||
},
|
||||
}
|
||||
self._current_cycle_detail.set_loop_info(error_loop_info)
|
||||
self._current_cycle_detail.complete_cycle()
|
||||
|
||||
# 上下文处理失败,跳过当前循环
|
||||
await asyncio.sleep(1)
|
||||
continue
|
||||
|
||||
self._current_cycle_detail.set_loop_info(loop_info)
|
||||
|
||||
self.loop_info.add_loop_info(self._current_cycle_detail)
|
||||
|
||||
self._current_cycle_detail.timers = cycle_timers
|
||||
|
||||
# 完成当前循环并保存历史
|
||||
self._current_cycle_detail.complete_cycle()
|
||||
self._cycle_history.append(self._current_cycle_detail)
|
||||
|
||||
# 记录循环信息和计时器结果
|
||||
timer_strings = []
|
||||
for name, elapsed in cycle_timers.items():
|
||||
formatted_time = f"{elapsed * 1000:.2f}毫秒" if elapsed < 1 else f"{elapsed:.2f}秒"
|
||||
timer_strings.append(f"{name}: {formatted_time}")
|
||||
|
||||
logger.info(
|
||||
f"{self.log_prefix} 第{self._current_cycle_detail.cycle_id}次思考,"
|
||||
f"耗时: {self._current_cycle_detail.end_time - self._current_cycle_detail.start_time:.1f}秒, "
|
||||
f"选择动作: {self._current_cycle_detail.loop_plan_info.get('action_result', {}).get('action_type', '未知动作')}"
|
||||
+ (f"\n详情: {'; '.join(timer_strings)}" if timer_strings else "")
|
||||
)
|
||||
|
||||
# 记录性能数据
|
||||
try:
|
||||
action_result = self._current_cycle_detail.loop_plan_info.get("action_result", {})
|
||||
cycle_performance_data = {
|
||||
"cycle_id": self._current_cycle_detail.cycle_id,
|
||||
"action_type": action_result.get("action_type", "unknown"),
|
||||
"total_time": self._current_cycle_detail.end_time - self._current_cycle_detail.start_time,
|
||||
"step_times": cycle_timers.copy(),
|
||||
"reasoning": action_result.get("reasoning", ""),
|
||||
"success": self._current_cycle_detail.loop_action_info.get("action_taken", False),
|
||||
}
|
||||
self.performance_logger.record_cycle(cycle_performance_data)
|
||||
except Exception as perf_e:
|
||||
logger.warning(f"{self.log_prefix} 记录性能数据失败: {perf_e}")
|
||||
|
||||
await asyncio.sleep(global_config.focus_chat.think_interval)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"{self.log_prefix} 循环处理时任务被取消")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 循环处理时出错: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
# 如果_current_cycle_detail存在但未完成,为其设置错误状态
|
||||
if self._current_cycle_detail and not hasattr(self._current_cycle_detail, "end_time"):
|
||||
error_loop_info = {
|
||||
"loop_plan_info": {
|
||||
"action_result": {
|
||||
"action_type": "error",
|
||||
"action_data": {},
|
||||
"reasoning": f"循环处理失败: {e}",
|
||||
},
|
||||
},
|
||||
"loop_action_info": {
|
||||
"action_taken": False,
|
||||
"reply_text": "",
|
||||
"command": "",
|
||||
"taken_time": time.time(),
|
||||
},
|
||||
}
|
||||
try:
|
||||
self._current_cycle_detail.set_loop_info(error_loop_info)
|
||||
self._current_cycle_detail.complete_cycle()
|
||||
except Exception as inner_e:
|
||||
logger.error(f"{self.log_prefix} 设置错误状态时出错: {inner_e}")
|
||||
|
||||
await asyncio.sleep(1) # 出错后等待一秒再继续
|
||||
|
||||
except asyncio.CancelledError:
|
||||
# 设置了关闭标志位后被取消是正常流程
|
||||
if not self._shutting_down:
|
||||
logger.warning(f"{self.log_prefix} 麦麦Focus聊天模式意外被取消")
|
||||
async def _loopbody(self):
|
||||
if self.loop_mode == ChatMode.FOCUS:
|
||||
if await self._observe():
|
||||
self.energy_value -= 1 * global_config.chat.focus_value
|
||||
else:
|
||||
logger.info(f"{self.log_prefix} 麦麦已离开Focus聊天模式")
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 麦麦Focus聊天模式意外错误: {e}")
|
||||
print(traceback.format_exc())
|
||||
self.energy_value -= 3 * global_config.chat.focus_value
|
||||
if self.energy_value <= 1:
|
||||
self.energy_value = 1
|
||||
self.loop_mode = ChatMode.NORMAL
|
||||
return True
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def _get_cycle_context(self):
|
||||
"""
|
||||
循环周期的上下文管理器
|
||||
return True
|
||||
elif self.loop_mode == ChatMode.NORMAL:
|
||||
new_messages_data = get_raw_msg_by_timestamp_with_chat(
|
||||
chat_id=self.stream_id,
|
||||
timestamp_start=self.last_read_time,
|
||||
timestamp_end=time.time(),
|
||||
limit=10,
|
||||
limit_mode="earliest",
|
||||
filter_bot=True,
|
||||
)
|
||||
|
||||
用于确保资源的正确获取和释放:
|
||||
1. 获取处理锁
|
||||
2. 执行操作
|
||||
3. 释放锁
|
||||
"""
|
||||
acquired = False
|
||||
try:
|
||||
await self._processing_lock.acquire()
|
||||
acquired = True
|
||||
yield acquired
|
||||
finally:
|
||||
if acquired and self._processing_lock.locked():
|
||||
self._processing_lock.release()
|
||||
if len(new_messages_data) > 3 * global_config.chat.focus_value:
|
||||
self.loop_mode = ChatMode.FOCUS
|
||||
self.energy_value = 10 + (len(new_messages_data) / (3 * global_config.chat.focus_value)) * 10
|
||||
return True
|
||||
|
||||
async def _observe_process_plan_action_loop(self, cycle_timers: dict, thinking_id: str) -> dict:
|
||||
try:
|
||||
if self.energy_value >= 30 * global_config.chat.focus_value:
|
||||
self.loop_mode = ChatMode.FOCUS
|
||||
return True
|
||||
|
||||
if new_messages_data:
|
||||
earliest_messages_data = new_messages_data[0]
|
||||
self.last_read_time = earliest_messages_data.get("time")
|
||||
|
||||
if_think = await self.normal_response(earliest_messages_data)
|
||||
if if_think:
|
||||
factor = max(global_config.chat.focus_value, 0.1)
|
||||
self.energy_value *= 1.1 / factor
|
||||
logger.info(f"{self.log_prefix} 麦麦进行了思考,能量值按倍数增加,当前能量值:{self.energy_value}")
|
||||
else:
|
||||
self.energy_value += 0.1 / global_config.chat.focus_value
|
||||
logger.info(f"{self.log_prefix} 麦麦没有进行思考,能量值线性增加,当前能量值:{self.energy_value}")
|
||||
|
||||
logger.debug(f"{self.log_prefix} 当前能量值:{self.energy_value}")
|
||||
return True
|
||||
|
||||
await asyncio.sleep(1)
|
||||
|
||||
return True
|
||||
|
||||
async def build_reply_to_str(self, message_data: dict):
|
||||
person_info_manager = get_person_info_manager()
|
||||
person_id = person_info_manager.get_person_id(
|
||||
message_data.get("chat_info_platform"), # type: ignore
|
||||
message_data.get("user_id"), # type: ignore
|
||||
)
|
||||
person_name = await person_info_manager.get_value(person_id, "person_name")
|
||||
return f"{person_name}:{message_data.get('processed_plain_text')}"
|
||||
|
||||
async def _observe(self, message_data: Optional[Dict[str, Any]] = None):
|
||||
if not message_data:
|
||||
message_data = {}
|
||||
action_type = "no_action"
|
||||
# 创建新的循环信息
|
||||
cycle_timers, thinking_id = self.start_cycle()
|
||||
|
||||
logger.info(f"{self.log_prefix} 开始第{self._cycle_counter}次思考[模式:{self.loop_mode}]")
|
||||
|
||||
async with global_prompt_manager.async_message_scope(self.chat_stream.context.get_template_name()):
|
||||
loop_start_time = time.time()
|
||||
await self.loop_info.observe()
|
||||
|
||||
await self.relationship_builder.build_relation()
|
||||
|
||||
# 顺序执行调整动作和处理器阶段
|
||||
available_actions = {}
|
||||
|
||||
# 第一步:动作修改
|
||||
with Timer("动作修改", cycle_timers):
|
||||
try:
|
||||
# 调用完整的动作修改流程
|
||||
await self.action_modifier.modify_actions(
|
||||
loop_info=self.loop_info,
|
||||
mode="focus",
|
||||
)
|
||||
await self.action_modifier.modify_actions()
|
||||
available_actions = self.action_manager.get_using_actions()
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 动作修改失败: {e}")
|
||||
# 继续执行,不中断流程
|
||||
|
||||
# 如果normal,开始一个回复生成进程,先准备好回复(其实是和planer同时进行的)
|
||||
if self.loop_mode == ChatMode.NORMAL:
|
||||
reply_to_str = await self.build_reply_to_str(message_data)
|
||||
gen_task = asyncio.create_task(self._generate_response(message_data, available_actions, reply_to_str))
|
||||
|
||||
with Timer("规划器", cycle_timers):
|
||||
plan_result = await self.action_planner.plan()
|
||||
plan_result, target_message = await self.action_planner.plan(mode=self.loop_mode)
|
||||
|
||||
loop_plan_info = {
|
||||
"action_result": plan_result.get("action_result", {}),
|
||||
}
|
||||
|
||||
action_type, action_data, reasoning = (
|
||||
plan_result.get("action_result", {}).get("action_type", "error"),
|
||||
plan_result.get("action_result", {}).get("action_data", {}),
|
||||
plan_result.get("action_result", {}).get("reasoning", "未提供理由"),
|
||||
action_result: dict = plan_result.get("action_result", {}) # type: ignore
|
||||
action_type, action_data, reasoning, is_parallel = (
|
||||
action_result.get("action_type", "error"),
|
||||
action_result.get("action_data", {}),
|
||||
action_result.get("reasoning", "未提供理由"),
|
||||
action_result.get("is_parallel", True),
|
||||
)
|
||||
|
||||
action_data["loop_start_time"] = loop_start_time
|
||||
|
||||
if action_type == "reply":
|
||||
action_str = "回复"
|
||||
elif action_type == "no_reply":
|
||||
action_str = "不回复"
|
||||
if self.loop_mode == ChatMode.NORMAL:
|
||||
if action_type == "no_action":
|
||||
logger.info(f"[{self.log_prefix}] {global_config.bot.nickname} 决定进行回复")
|
||||
elif is_parallel:
|
||||
logger.info(
|
||||
f"[{self.log_prefix}] {global_config.bot.nickname} 决定进行回复, 同时执行{action_type}动作"
|
||||
)
|
||||
else:
|
||||
logger.info(f"[{self.log_prefix}] {global_config.bot.nickname} 决定执行{action_type}动作")
|
||||
|
||||
if action_type == "no_action":
|
||||
# 等待回复生成完毕
|
||||
gather_timeout = global_config.chat.thinking_timeout
|
||||
try:
|
||||
response_set = await asyncio.wait_for(gen_task, timeout=gather_timeout)
|
||||
except asyncio.TimeoutError:
|
||||
response_set = None
|
||||
|
||||
if response_set:
|
||||
content = " ".join([item[1] for item in response_set if item[0] == "text"])
|
||||
|
||||
# 模型炸了,没有回复内容生成
|
||||
if not response_set:
|
||||
logger.warning(f"[{self.log_prefix}] 模型未生成回复内容")
|
||||
return False
|
||||
elif action_type not in ["no_action"] and not is_parallel:
|
||||
logger.info(
|
||||
f"[{self.log_prefix}] {global_config.bot.nickname} 原本想要回复:{content},但选择执行{action_type},不发表回复"
|
||||
)
|
||||
return False
|
||||
|
||||
logger.info(f"[{self.log_prefix}] {global_config.bot.nickname} 决定的回复内容: {content}")
|
||||
|
||||
# 发送回复 (不再需要传入 chat)
|
||||
await self._send_response(response_set, reply_to_str, loop_start_time,message_data)
|
||||
|
||||
return True
|
||||
|
||||
else:
|
||||
action_str = action_type
|
||||
action_message: Dict[str, Any] = message_data or target_message # type: ignore
|
||||
|
||||
logger.debug(f"{self.log_prefix} 麦麦想要:'{action_str}',理由是:{reasoning}")
|
||||
# 动作执行计时
|
||||
with Timer("动作执行", cycle_timers):
|
||||
success, reply_text, command = await self._handle_action(
|
||||
action_type, reasoning, action_data, cycle_timers, thinking_id, action_message
|
||||
)
|
||||
|
||||
# 动作执行计时
|
||||
with Timer("动作执行", cycle_timers):
|
||||
success, reply_text, command = await self._handle_action(
|
||||
action_type, reasoning, action_data, cycle_timers, thinking_id
|
||||
)
|
||||
|
||||
loop_action_info = {
|
||||
"action_taken": success,
|
||||
"reply_text": reply_text,
|
||||
"command": command,
|
||||
"taken_time": time.time(),
|
||||
loop_info = {
|
||||
"loop_plan_info": {
|
||||
"action_result": plan_result.get("action_result", {}),
|
||||
},
|
||||
"loop_action_info": {
|
||||
"action_taken": success,
|
||||
"reply_text": reply_text,
|
||||
"command": command,
|
||||
"taken_time": time.time(),
|
||||
},
|
||||
}
|
||||
|
||||
loop_info = {
|
||||
"loop_plan_info": loop_plan_info,
|
||||
"loop_action_info": loop_action_info,
|
||||
}
|
||||
if loop_info["loop_action_info"]["command"] == "stop_focus_chat":
|
||||
logger.info(f"{self.log_prefix} 麦麦决定停止专注聊天")
|
||||
return False
|
||||
# 停止该聊天模式的循环
|
||||
|
||||
return loop_info
|
||||
self.end_cycle(loop_info, cycle_timers)
|
||||
self.print_cycle_info(cycle_timers)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} FOCUS聊天处理失败: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
return {
|
||||
"loop_plan_info": {
|
||||
"action_result": {"action_type": "error", "action_data": {}, "reasoning": f"处理失败: {e}"},
|
||||
},
|
||||
"loop_action_info": {"action_taken": False, "reply_text": "", "command": "", "taken_time": time.time()},
|
||||
}
|
||||
if self.loop_mode == ChatMode.NORMAL:
|
||||
await self.willing_manager.after_generate_reply_handle(message_data.get("message_id", ""))
|
||||
|
||||
if action_type != "no_reply" and action_type != "no_action":
|
||||
return True
|
||||
|
||||
return True
|
||||
|
||||
async def _main_chat_loop(self):
|
||||
"""主循环,持续进行计划并可能回复消息,直到被外部取消。"""
|
||||
try:
|
||||
while self.running: # 主循环
|
||||
success = await self._loopbody()
|
||||
await asyncio.sleep(0.1)
|
||||
if not success:
|
||||
break
|
||||
|
||||
logger.info(f"{self.log_prefix} 麦麦已强制离开聊天")
|
||||
except asyncio.CancelledError:
|
||||
# 设置了关闭标志位后被取消是正常流程
|
||||
logger.info(f"{self.log_prefix} 麦麦已关闭聊天")
|
||||
except Exception:
|
||||
logger.error(f"{self.log_prefix} 麦麦聊天意外错误")
|
||||
print(traceback.format_exc())
|
||||
# 理论上不能到这里
|
||||
logger.error(f"{self.log_prefix} 麦麦聊天意外错误,结束了聊天循环")
|
||||
|
||||
async def _handle_action(
|
||||
self,
|
||||
@@ -409,6 +398,7 @@ class HeartFChatting:
|
||||
action_data: dict,
|
||||
cycle_timers: dict,
|
||||
thinking_id: str,
|
||||
action_message: dict,
|
||||
) -> tuple[bool, str, str]:
|
||||
"""
|
||||
处理规划动作,使用动作工厂创建相应的动作处理器
|
||||
@@ -434,7 +424,7 @@ class HeartFChatting:
|
||||
thinking_id=thinking_id,
|
||||
chat_stream=self.chat_stream,
|
||||
log_prefix=self.log_prefix,
|
||||
shutting_down=self._shutting_down,
|
||||
action_message=action_message,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 创建动作处理器时出错: {e}")
|
||||
@@ -447,46 +437,17 @@ class HeartFChatting:
|
||||
|
||||
# 处理动作并获取结果
|
||||
result = await action_handler.handle_action()
|
||||
if len(result) == 3:
|
||||
success, reply_text, command = result
|
||||
else:
|
||||
success, reply_text = result
|
||||
command = ""
|
||||
success, reply_text = result
|
||||
command = ""
|
||||
|
||||
# 检查action_data中是否有系统命令,优先使用系统命令
|
||||
if "_system_command" in action_data:
|
||||
command = action_data["_system_command"]
|
||||
logger.debug(f"{self.log_prefix} 从action_data中获取系统命令: {command}")
|
||||
|
||||
# 新增:消息计数和疲惫检查
|
||||
if action == "reply" and success:
|
||||
self._message_count += 1
|
||||
current_threshold = self._get_current_fatigue_threshold()
|
||||
logger.info(
|
||||
f"{self.log_prefix} 已发送第 {self._message_count} 条消息(动态阈值: {current_threshold}, exit_focus_threshold: {global_config.chat.exit_focus_threshold})"
|
||||
)
|
||||
|
||||
# 检查是否达到疲惫阈值(只有在auto模式下才会自动退出)
|
||||
if (
|
||||
global_config.chat.chat_mode == "auto"
|
||||
and self._message_count >= current_threshold
|
||||
and not self._fatigue_triggered
|
||||
):
|
||||
self._fatigue_triggered = True
|
||||
logger.info(
|
||||
f"{self.log_prefix} [auto模式] 已发送 {self._message_count} 条消息,达到疲惫阈值 {current_threshold},麦麦感到疲惫了,准备退出专注聊天模式"
|
||||
if reply_text == "timeout":
|
||||
self.reply_timeout_count += 1
|
||||
if self.reply_timeout_count > 5:
|
||||
logger.warning(
|
||||
f"[{self.log_prefix} ] 连续回复超时次数过多,{global_config.chat.thinking_timeout}秒 内大模型没有返回有效内容,请检查你的api是否速度过慢或配置错误。建议不要使用推理模型,推理模型生成速度过慢。或者尝试拉高thinking_timeout参数,这可能导致回复时间过长。"
|
||||
)
|
||||
# 设置系统命令,在下次循环检查时触发退出
|
||||
command = "stop_focus_chat"
|
||||
else:
|
||||
if reply_text == "timeout":
|
||||
self.reply_timeout_count += 1
|
||||
if self.reply_timeout_count > 5:
|
||||
logger.warning(
|
||||
f"[{self.log_prefix} ] 连续回复超时次数过多,{global_config.chat.thinking_timeout}秒 内大模型没有返回有效内容,请检查你的api是否速度过慢或配置错误。建议不要使用推理模型,推理模型生成速度过慢。或者尝试拉高thinking_timeout参数,这可能导致回复时间过长。"
|
||||
)
|
||||
logger.warning(f"{self.log_prefix} 回复生成超时{global_config.chat.thinking_timeout}s,已跳过")
|
||||
return False, "", ""
|
||||
logger.warning(f"{self.log_prefix} 回复生成超时{global_config.chat.thinking_timeout}s,已跳过")
|
||||
return False, "", ""
|
||||
|
||||
return success, reply_text, command
|
||||
|
||||
@@ -495,88 +456,108 @@ class HeartFChatting:
|
||||
traceback.print_exc()
|
||||
return False, "", ""
|
||||
|
||||
def _get_current_fatigue_threshold(self) -> int:
|
||||
"""动态获取当前的疲惫阈值,基于exit_focus_threshold配置
|
||||
|
||||
Returns:
|
||||
int: 当前的疲惫阈值
|
||||
async def normal_response(self, message_data: dict) -> bool:
|
||||
"""
|
||||
return max(10, int(30 / global_config.chat.exit_focus_threshold))
|
||||
|
||||
def get_message_count_info(self) -> dict:
|
||||
"""获取消息计数信息
|
||||
|
||||
Returns:
|
||||
dict: 包含消息计数信息的字典
|
||||
处理接收到的消息。
|
||||
在"兴趣"模式下,判断是否回复并生成内容。
|
||||
"""
|
||||
current_threshold = self._get_current_fatigue_threshold()
|
||||
return {
|
||||
"current_count": self._message_count,
|
||||
"threshold": current_threshold,
|
||||
"fatigue_triggered": self._fatigue_triggered,
|
||||
"remaining": max(0, current_threshold - self._message_count),
|
||||
}
|
||||
|
||||
def reset_message_count(self):
|
||||
"""重置消息计数器(用于重新启动focus模式时)"""
|
||||
self._message_count = 0
|
||||
self._fatigue_triggered = False
|
||||
logger.info(f"{self.log_prefix} 消息计数器已重置")
|
||||
interested_rate = message_data.get("interest_value", 0.0) * self.willing_amplifier
|
||||
|
||||
self.willing_manager.setup(message_data, self.chat_stream)
|
||||
|
||||
|
||||
reply_probability = await self.willing_manager.get_reply_probability(message_data.get("message_id", ""))
|
||||
|
||||
async def shutdown(self):
|
||||
"""优雅关闭HeartFChatting实例,取消活动循环任务"""
|
||||
logger.info(f"{self.log_prefix} 正在关闭HeartFChatting...")
|
||||
self._shutting_down = True # <-- 在开始关闭时设置标志位
|
||||
talk_frequency = -1.00
|
||||
|
||||
# 记录最终的消息统计
|
||||
if self._message_count > 0:
|
||||
logger.info(f"{self.log_prefix} 本次focus会话共发送了 {self._message_count} 条消息")
|
||||
if self._fatigue_triggered:
|
||||
logger.info(f"{self.log_prefix} 因疲惫而退出focus模式")
|
||||
if reply_probability < 1: # 简化逻辑,如果未提及 (reply_probability 为 0),则获取意愿概率
|
||||
additional_config = message_data.get("additional_config", {})
|
||||
if additional_config and "maimcore_reply_probability_gain" in additional_config:
|
||||
reply_probability += additional_config["maimcore_reply_probability_gain"]
|
||||
reply_probability = min(max(reply_probability, 0), 1) # 确保概率在 0-1 之间
|
||||
|
||||
talk_frequency = global_config.chat.get_current_talk_frequency(self.stream_id)
|
||||
reply_probability = talk_frequency * reply_probability
|
||||
|
||||
# 取消循环任务
|
||||
if self._loop_task and not self._loop_task.done():
|
||||
logger.info(f"{self.log_prefix} 正在取消HeartFChatting循环任务")
|
||||
self._loop_task.cancel()
|
||||
try:
|
||||
await asyncio.wait_for(self._loop_task, timeout=1.0)
|
||||
logger.info(f"{self.log_prefix} HeartFChatting循环任务已取消")
|
||||
except (asyncio.CancelledError, asyncio.TimeoutError):
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 取消循环任务出错: {e}")
|
||||
else:
|
||||
logger.info(f"{self.log_prefix} 没有活动的HeartFChatting循环任务")
|
||||
# 处理表情包
|
||||
if message_data.get("is_emoji") or message_data.get("is_picid"):
|
||||
reply_probability = 0
|
||||
|
||||
# 清理状态
|
||||
self._loop_active = False
|
||||
self._loop_task = None
|
||||
if self._processing_lock.locked():
|
||||
self._processing_lock.release()
|
||||
logger.warning(f"{self.log_prefix} 已释放处理锁")
|
||||
# 打印消息信息
|
||||
mes_name = self.chat_stream.group_info.group_name if self.chat_stream.group_info else "私聊"
|
||||
|
||||
# logger.info(f"[{mes_name}] 当前聊天频率: {talk_frequency:.2f},兴趣值: {interested_rate:.2f},回复概率: {reply_probability * 100:.1f}%")
|
||||
|
||||
if reply_probability > 0.05:
|
||||
logger.info(
|
||||
f"[{mes_name}]"
|
||||
f"{message_data.get('user_nickname')}:"
|
||||
f"{message_data.get('processed_plain_text')}[兴趣:{interested_rate:.2f}][回复概率:{reply_probability * 100:.1f}%]"
|
||||
)
|
||||
|
||||
# 完成性能统计
|
||||
if random.random() < reply_probability:
|
||||
await self.willing_manager.before_generate_reply_handle(message_data.get("message_id", ""))
|
||||
await self._observe(message_data=message_data)
|
||||
return True
|
||||
|
||||
# 意愿管理器:注销当前message信息 (无论是否回复,只要处理过就删除)
|
||||
self.willing_manager.delete(message_data.get("message_id", ""))
|
||||
return False
|
||||
|
||||
|
||||
async def _generate_response(
|
||||
self, message_data: dict, available_actions: Optional[Dict[str, ActionInfo]], reply_to: str
|
||||
) -> Optional[list]:
|
||||
"""生成普通回复"""
|
||||
try:
|
||||
self.performance_logger.finalize_session()
|
||||
logger.info(f"{self.log_prefix} 性能统计已完成")
|
||||
success, reply_set, _ = await generator_api.generate_reply(
|
||||
chat_stream=self.chat_stream,
|
||||
reply_to=reply_to,
|
||||
available_actions=available_actions,
|
||||
enable_tool=global_config.tool.enable_in_normal_chat,
|
||||
request_type="chat.replyer.normal",
|
||||
)
|
||||
|
||||
if not success or not reply_set:
|
||||
logger.info(f"对 {message_data.get('processed_plain_text')} 的回复生成失败")
|
||||
return None
|
||||
|
||||
return reply_set
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"{self.log_prefix} 完成性能统计时出错: {e}")
|
||||
logger.error(f"[{self.log_prefix}] 回复生成出现错误:{str(e)} {traceback.format_exc()}")
|
||||
return None
|
||||
|
||||
# 重置消息计数器,为下次启动做准备
|
||||
self.reset_message_count()
|
||||
async def _send_response(self, reply_set, reply_to, thinking_start_time,message_data):
|
||||
current_time = time.time()
|
||||
new_message_count = message_api.count_new_messages(
|
||||
chat_id=self.chat_stream.stream_id, start_time=thinking_start_time, end_time=current_time
|
||||
)
|
||||
platform = message_data.get("user_platform", "")
|
||||
user_id = message_data.get("user_id", "")
|
||||
reply_to_platform_id = f"{platform}:{user_id}"
|
||||
|
||||
logger.info(f"{self.log_prefix} HeartFChatting关闭完成")
|
||||
need_reply = new_message_count >= random.randint(2, 4)
|
||||
|
||||
def get_cycle_history(self, last_n: Optional[int] = None) -> List[Dict[str, Any]]:
|
||||
"""获取循环历史记录
|
||||
logger.info(
|
||||
f"{self.log_prefix} 从思考到回复,共有{new_message_count}条新消息,{'使用' if need_reply else '不使用'}引用回复"
|
||||
)
|
||||
|
||||
参数:
|
||||
last_n: 获取最近n个循环的信息,如果为None则获取所有历史记录
|
||||
reply_text = ""
|
||||
first_replied = False
|
||||
for reply_seg in reply_set:
|
||||
data = reply_seg[1]
|
||||
if not first_replied:
|
||||
if need_reply:
|
||||
await send_api.text_to_stream(
|
||||
text=data, stream_id=self.chat_stream.stream_id, reply_to=reply_to, reply_to_platform_id=reply_to_platform_id, typing=False
|
||||
)
|
||||
else:
|
||||
await send_api.text_to_stream(text=data, stream_id=self.chat_stream.stream_id, reply_to_platform_id=reply_to_platform_id, typing=False)
|
||||
first_replied = True
|
||||
else:
|
||||
await send_api.text_to_stream(text=data, stream_id=self.chat_stream.stream_id, reply_to_platform_id=reply_to_platform_id, typing=True)
|
||||
reply_text += data
|
||||
|
||||
返回:
|
||||
List[Dict[str, Any]]: 循环历史记录列表
|
||||
"""
|
||||
history = list(self._cycle_history)
|
||||
if last_n is not None:
|
||||
history = history[-last_n:]
|
||||
return [cycle.to_dict() for cycle in history]
|
||||
return reply_text
|
||||
|
||||
@@ -1,161 +0,0 @@
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any
|
||||
from pathlib import Path
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("hfc_performance")
|
||||
|
||||
|
||||
class HFCPerformanceLogger:
|
||||
"""HFC性能记录管理器"""
|
||||
|
||||
# 版本号常量,可在启动时修改
|
||||
INTERNAL_VERSION = "v7.0.0"
|
||||
|
||||
def __init__(self, chat_id: str):
|
||||
self.chat_id = chat_id
|
||||
self.version = self.INTERNAL_VERSION
|
||||
self.log_dir = Path("log/hfc_loop")
|
||||
self.session_start_time = datetime.now()
|
||||
|
||||
# 确保目录存在
|
||||
self.log_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 当前会话的日志文件,包含版本号
|
||||
version_suffix = self.version.replace(".", "_")
|
||||
self.session_file = (
|
||||
self.log_dir / f"{chat_id}_{version_suffix}_{self.session_start_time.strftime('%Y%m%d_%H%M%S')}.json"
|
||||
)
|
||||
self.current_session_data = []
|
||||
|
||||
def record_cycle(self, cycle_data: Dict[str, Any]):
|
||||
"""记录单次循环数据"""
|
||||
try:
|
||||
# 构建记录数据
|
||||
record = {
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"version": self.version,
|
||||
"cycle_id": cycle_data.get("cycle_id"),
|
||||
"chat_id": self.chat_id,
|
||||
"action_type": cycle_data.get("action_type", "unknown"),
|
||||
"total_time": cycle_data.get("total_time", 0),
|
||||
"step_times": cycle_data.get("step_times", {}),
|
||||
"reasoning": cycle_data.get("reasoning", ""),
|
||||
"success": cycle_data.get("success", False),
|
||||
}
|
||||
|
||||
# 添加到当前会话数据
|
||||
self.current_session_data.append(record)
|
||||
|
||||
# 立即写入文件(防止数据丢失)
|
||||
self._write_session_data()
|
||||
|
||||
# 构建详细的日志信息
|
||||
log_parts = [
|
||||
f"cycle_id={record['cycle_id']}",
|
||||
f"action={record['action_type']}",
|
||||
f"time={record['total_time']:.2f}s",
|
||||
]
|
||||
|
||||
logger.debug(f"记录HFC循环数据: {', '.join(log_parts)}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"记录HFC循环数据失败: {e}")
|
||||
|
||||
def _write_session_data(self):
|
||||
"""写入当前会话数据到文件"""
|
||||
try:
|
||||
with open(self.session_file, "w", encoding="utf-8") as f:
|
||||
json.dump(self.current_session_data, f, ensure_ascii=False, indent=2)
|
||||
except Exception as e:
|
||||
logger.error(f"写入会话数据失败: {e}")
|
||||
|
||||
def get_current_session_stats(self) -> Dict[str, Any]:
|
||||
"""获取当前会话的基本信息"""
|
||||
if not self.current_session_data:
|
||||
return {}
|
||||
|
||||
return {
|
||||
"chat_id": self.chat_id,
|
||||
"version": self.version,
|
||||
"session_file": str(self.session_file),
|
||||
"record_count": len(self.current_session_data),
|
||||
"start_time": self.session_start_time.isoformat(),
|
||||
}
|
||||
|
||||
def finalize_session(self):
|
||||
"""结束会话"""
|
||||
try:
|
||||
if self.current_session_data:
|
||||
logger.info(f"完成会话,当前会话 {len(self.current_session_data)} 条记录")
|
||||
except Exception as e:
|
||||
logger.error(f"结束会话失败: {e}")
|
||||
|
||||
@classmethod
|
||||
def cleanup_old_logs(cls, max_size_mb: float = 50.0):
|
||||
"""
|
||||
清理旧的HFC日志文件,保持目录大小在指定限制内
|
||||
|
||||
Args:
|
||||
max_size_mb: 最大目录大小限制(MB)
|
||||
"""
|
||||
log_dir = Path("log/hfc_loop")
|
||||
if not log_dir.exists():
|
||||
logger.info("HFC日志目录不存在,跳过日志清理")
|
||||
return
|
||||
|
||||
# 获取所有日志文件及其信息
|
||||
log_files = []
|
||||
total_size = 0
|
||||
|
||||
for log_file in log_dir.glob("*.json"):
|
||||
try:
|
||||
file_stat = log_file.stat()
|
||||
log_files.append({"path": log_file, "size": file_stat.st_size, "mtime": file_stat.st_mtime})
|
||||
total_size += file_stat.st_size
|
||||
except Exception as e:
|
||||
logger.warning(f"无法获取文件信息 {log_file}: {e}")
|
||||
|
||||
if not log_files:
|
||||
logger.info("没有找到HFC日志文件")
|
||||
return
|
||||
|
||||
max_size_bytes = max_size_mb * 1024 * 1024
|
||||
current_size_mb = total_size / (1024 * 1024)
|
||||
|
||||
logger.info(f"HFC日志目录当前大小: {current_size_mb:.2f}MB,限制: {max_size_mb}MB")
|
||||
|
||||
if total_size <= max_size_bytes:
|
||||
logger.info("HFC日志目录大小在限制范围内,无需清理")
|
||||
return
|
||||
|
||||
# 按修改时间排序(最早的在前面)
|
||||
log_files.sort(key=lambda x: x["mtime"])
|
||||
|
||||
deleted_count = 0
|
||||
deleted_size = 0
|
||||
|
||||
for file_info in log_files:
|
||||
if total_size <= max_size_bytes:
|
||||
break
|
||||
|
||||
try:
|
||||
file_size = file_info["size"]
|
||||
file_path = file_info["path"]
|
||||
|
||||
file_path.unlink()
|
||||
total_size -= file_size
|
||||
deleted_size += file_size
|
||||
deleted_count += 1
|
||||
|
||||
logger.info(f"删除旧日志文件: {file_path.name} ({file_size / 1024:.1f}KB)")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"删除日志文件失败 {file_info['path']}: {e}")
|
||||
|
||||
final_size_mb = total_size / (1024 * 1024)
|
||||
deleted_size_mb = deleted_size / (1024 * 1024)
|
||||
|
||||
logger.info(f"HFC日志清理完成: 删除了{deleted_count}个文件,释放{deleted_size_mb:.2f}MB空间")
|
||||
logger.info(f"清理后目录大小: {final_size_mb:.2f}MB")
|
||||
@@ -1,23 +1,19 @@
|
||||
import time
|
||||
from typing import Optional
|
||||
from src.chat.message_receive.message import MessageRecv, BaseMessageInfo
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
from src.chat.message_receive.message import UserInfo
|
||||
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
from src.config.config import global_config
|
||||
from src.common.message_repository import count_messages
|
||||
from src.common.logger import get_logger
|
||||
import json
|
||||
from typing import Dict, Any
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
log_dir = "log/log_cycle_debug/"
|
||||
|
||||
|
||||
class CycleDetail:
|
||||
"""循环信息记录类"""
|
||||
|
||||
def __init__(self, cycle_id: int):
|
||||
self.cycle_id = cycle_id
|
||||
self.prefix = ""
|
||||
self.thinking_id = ""
|
||||
self.start_time = time.time()
|
||||
self.end_time: Optional[float] = None
|
||||
@@ -79,85 +75,34 @@ class CycleDetail:
|
||||
"loop_action_info": convert_to_serializable(self.loop_action_info),
|
||||
}
|
||||
|
||||
def complete_cycle(self):
|
||||
"""完成循环,记录结束时间"""
|
||||
self.end_time = time.time()
|
||||
|
||||
# 处理 prefix,只保留中英文字符和基本标点
|
||||
if not self.prefix:
|
||||
self.prefix = "group"
|
||||
else:
|
||||
# 只保留中文、英文字母、数字和基本标点
|
||||
allowed_chars = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-_")
|
||||
self.prefix = (
|
||||
"".join(char for char in self.prefix if "\u4e00" <= char <= "\u9fff" or char in allowed_chars)
|
||||
or "group"
|
||||
)
|
||||
|
||||
def set_thinking_id(self, thinking_id: str):
|
||||
"""设置思考消息ID"""
|
||||
self.thinking_id = thinking_id
|
||||
|
||||
def set_loop_info(self, loop_info: Dict[str, Any]):
|
||||
"""设置循环信息"""
|
||||
self.loop_plan_info = loop_info["loop_plan_info"]
|
||||
self.loop_action_info = loop_info["loop_action_info"]
|
||||
|
||||
|
||||
async def create_empty_anchor_message(
|
||||
platform: str, group_info: dict, chat_stream: ChatStream
|
||||
) -> Optional[MessageRecv]:
|
||||
def get_recent_message_stats(minutes: float = 30, chat_id: Optional[str] = None) -> dict:
|
||||
"""
|
||||
重构观察到的最后一条消息作为回复的锚点,
|
||||
如果重构失败或观察为空,则创建一个占位符。
|
||||
Args:
|
||||
minutes (float): 检索的分钟数,默认30分钟
|
||||
chat_id (str, optional): 指定的chat_id,仅统计该chat下的消息。为None时统计全部。
|
||||
Returns:
|
||||
dict: {"bot_reply_count": int, "total_message_count": int}
|
||||
"""
|
||||
|
||||
placeholder_id = f"mid_pf_{int(time.time() * 1000)}"
|
||||
placeholder_user = UserInfo(user_id="system_trigger", user_nickname="System Trigger", platform=platform)
|
||||
placeholder_msg_info = BaseMessageInfo(
|
||||
message_id=placeholder_id,
|
||||
platform=platform,
|
||||
group_info=group_info,
|
||||
user_info=placeholder_user,
|
||||
time=time.time(),
|
||||
)
|
||||
placeholder_msg_dict = {
|
||||
"message_info": placeholder_msg_info.to_dict(),
|
||||
"processed_plain_text": "[System Trigger Context]",
|
||||
"raw_message": "",
|
||||
"time": placeholder_msg_info.time,
|
||||
}
|
||||
anchor_message = MessageRecv(placeholder_msg_dict)
|
||||
anchor_message.update_chat_stream(chat_stream)
|
||||
now = time.time()
|
||||
start_time = now - minutes * 60
|
||||
bot_id = global_config.bot.qq_account
|
||||
|
||||
return anchor_message
|
||||
filter_base: Dict[str, Any] = {"time": {"$gte": start_time}}
|
||||
if chat_id is not None:
|
||||
filter_base["chat_id"] = chat_id
|
||||
|
||||
# 总消息数
|
||||
total_message_count = count_messages(filter_base)
|
||||
# bot自身回复数
|
||||
bot_filter = filter_base.copy()
|
||||
bot_filter["user_id"] = bot_id
|
||||
bot_reply_count = count_messages(bot_filter)
|
||||
|
||||
def parse_thinking_id_to_timestamp(thinking_id: str) -> float:
|
||||
"""
|
||||
将形如 'tid<timestamp>' 的 thinking_id 解析回 float 时间戳
|
||||
例如: 'tid1718251234.56' -> 1718251234.56
|
||||
"""
|
||||
if not thinking_id.startswith("tid"):
|
||||
raise ValueError("thinking_id 格式不正确")
|
||||
ts_str = thinking_id[3:]
|
||||
return float(ts_str)
|
||||
|
||||
|
||||
def get_keywords_from_json(json_str: str) -> list[str]:
|
||||
# 提取JSON内容
|
||||
start = json_str.find("{")
|
||||
end = json_str.rfind("}") + 1
|
||||
if start == -1 or end == 0:
|
||||
logger.error("未找到有效的JSON内容")
|
||||
return []
|
||||
|
||||
json_content = json_str[start:end]
|
||||
|
||||
# 解析JSON
|
||||
try:
|
||||
json_data = json.loads(json_content)
|
||||
return json_data.get("keywords", [])
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"JSON解析失败: {e}")
|
||||
return []
|
||||
return {"bot_reply_count": bot_reply_count, "total_message_count": total_message_count}
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
import enum
|
||||
|
||||
|
||||
class ChatState(enum.Enum):
|
||||
ABSENT = "没在看群"
|
||||
NORMAL = "随便水群"
|
||||
FOCUSED = "认真水群"
|
||||
|
||||
|
||||
class ChatStateInfo:
|
||||
def __init__(self):
|
||||
self.chat_status: ChatState = ChatState.NORMAL
|
||||
self.current_state_time = 120
|
||||
@@ -1,7 +1,8 @@
|
||||
from src.chat.heart_flow.sub_heartflow import SubHeartflow, ChatState
|
||||
import traceback
|
||||
from typing import Any, Optional, Dict
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from typing import Any, Optional
|
||||
from typing import Dict
|
||||
from src.chat.heart_flow.sub_heartflow import SubHeartflow
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
|
||||
logger = get_logger("heartflow")
|
||||
@@ -16,41 +17,24 @@ class Heartflow:
|
||||
async def get_or_create_subheartflow(self, subheartflow_id: Any) -> Optional["SubHeartflow"]:
|
||||
"""获取或创建一个新的SubHeartflow实例"""
|
||||
if subheartflow_id in self.subheartflows:
|
||||
subflow = self.subheartflows.get(subheartflow_id)
|
||||
if subflow:
|
||||
if subflow := self.subheartflows.get(subheartflow_id):
|
||||
return subflow
|
||||
|
||||
try:
|
||||
new_subflow = SubHeartflow(
|
||||
subheartflow_id,
|
||||
)
|
||||
new_subflow = SubHeartflow(subheartflow_id)
|
||||
|
||||
await new_subflow.initialize()
|
||||
|
||||
# 注册子心流
|
||||
self.subheartflows[subheartflow_id] = new_subflow
|
||||
heartflow_name = get_chat_manager().get_stream_name(subheartflow_id) or subheartflow_id
|
||||
logger.debug(f"[{heartflow_name}] 开始接收消息")
|
||||
logger.info(f"[{heartflow_name}] 开始接收消息")
|
||||
|
||||
return new_subflow
|
||||
except Exception as e:
|
||||
logger.error(f"创建子心流 {subheartflow_id} 失败: {e}", exc_info=True)
|
||||
traceback.print_exc()
|
||||
return None
|
||||
|
||||
async def force_change_subheartflow_status(self, subheartflow_id: str, status: ChatState) -> None:
|
||||
"""强制改变子心流的状态"""
|
||||
# 这里的 message 是可选的,可能是一个消息对象,也可能是其他类型的数据
|
||||
return await self.force_change_state(subheartflow_id, status)
|
||||
|
||||
async def force_change_state(self, subflow_id: Any, target_state: ChatState) -> bool:
|
||||
"""强制改变指定子心流的状态"""
|
||||
subflow = self.subheartflows.get(subflow_id)
|
||||
if not subflow:
|
||||
logger.warning(f"[强制状态转换]尝试转换不存在的子心流{subflow_id} 到 {target_state.value}")
|
||||
return False
|
||||
await subflow.change_chat_state(target_state)
|
||||
logger.info(f"[强制状态转换]子心流 {subflow_id} 已转换到 {target_state.value}")
|
||||
return True
|
||||
|
||||
|
||||
heartflow = Heartflow()
|
||||
|
||||
@@ -1,21 +1,23 @@
|
||||
from src.chat.memory_system.Hippocampus import hippocampus_manager
|
||||
from src.config.config import global_config
|
||||
import asyncio
|
||||
from src.chat.message_receive.message import MessageRecv
|
||||
from src.chat.message_receive.storage import MessageStorage
|
||||
from src.chat.heart_flow.heartflow import heartflow
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.utils.utils import is_mentioned_bot_in_message
|
||||
from src.chat.utils.timer_calculator import Timer
|
||||
from src.common.logger import get_logger
|
||||
import re
|
||||
import math
|
||||
import traceback
|
||||
from typing import Tuple
|
||||
|
||||
from typing import Tuple, TYPE_CHECKING
|
||||
|
||||
from src.config.config import global_config
|
||||
from src.chat.memory_system.Hippocampus import hippocampus_manager
|
||||
from src.chat.message_receive.message import MessageRecv
|
||||
from src.chat.message_receive.storage import MessageStorage
|
||||
from src.chat.heart_flow.heartflow import heartflow
|
||||
from src.chat.utils.utils import is_mentioned_bot_in_message
|
||||
from src.chat.utils.timer_calculator import Timer
|
||||
from src.common.logger import get_logger
|
||||
from src.person_info.relationship_manager import get_relationship_manager
|
||||
from src.mood.mood_manager import mood_manager
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.chat.heart_flow.sub_heartflow import SubHeartflow
|
||||
|
||||
logger = get_logger("chat")
|
||||
|
||||
@@ -27,16 +29,16 @@ async def _process_relationship(message: MessageRecv) -> None:
|
||||
message: 消息对象,包含用户信息
|
||||
"""
|
||||
platform = message.message_info.platform
|
||||
user_id = message.message_info.user_info.user_id
|
||||
nickname = message.message_info.user_info.user_nickname
|
||||
cardname = message.message_info.user_info.user_cardname or nickname
|
||||
user_id = message.message_info.user_info.user_id # type: ignore
|
||||
nickname = message.message_info.user_info.user_nickname # type: ignore
|
||||
cardname = message.message_info.user_info.user_cardname or nickname # type: ignore
|
||||
|
||||
relationship_manager = get_relationship_manager()
|
||||
is_known = await relationship_manager.is_known_some_one(platform, user_id)
|
||||
|
||||
if not is_known:
|
||||
logger.info(f"首次认识用户: {nickname}")
|
||||
await relationship_manager.first_knowing_some_one(platform, user_id, nickname, cardname)
|
||||
await relationship_manager.first_knowing_some_one(platform, user_id, nickname, cardname) # type: ignore
|
||||
|
||||
|
||||
async def _calculate_interest(message: MessageRecv) -> Tuple[float, bool]:
|
||||
@@ -51,13 +53,12 @@ async def _calculate_interest(message: MessageRecv) -> Tuple[float, bool]:
|
||||
is_mentioned, _ = is_mentioned_bot_in_message(message)
|
||||
interested_rate = 0.0
|
||||
|
||||
if global_config.memory.enable_memory:
|
||||
with Timer("记忆激活"):
|
||||
interested_rate = await hippocampus_manager.get_activate_from_text(
|
||||
message.processed_plain_text,
|
||||
fast_retrieval=True,
|
||||
)
|
||||
logger.debug(f"记忆激活率: {interested_rate:.2f}")
|
||||
with Timer("记忆激活"):
|
||||
interested_rate = await hippocampus_manager.get_activate_from_text(
|
||||
message.processed_plain_text,
|
||||
fast_retrieval=False,
|
||||
)
|
||||
logger.debug(f"记忆激活率: {interested_rate:.2f}")
|
||||
|
||||
text_len = len(message.processed_plain_text)
|
||||
# 根据文本长度调整兴趣度,长度越大兴趣度越高,但增长率递减,最低0.01,最高0.05
|
||||
@@ -97,32 +98,24 @@ class HeartFCMessageReceiver:
|
||||
"""
|
||||
try:
|
||||
# 1. 消息解析与初始化
|
||||
groupinfo = message.message_info.group_info
|
||||
userinfo = message.message_info.user_info
|
||||
messageinfo = message.message_info
|
||||
chat = message.chat_stream
|
||||
|
||||
chat = await get_chat_manager().get_or_create_stream(
|
||||
platform=messageinfo.platform,
|
||||
user_info=userinfo,
|
||||
group_info=groupinfo,
|
||||
)
|
||||
# 2. 兴趣度计算与更新
|
||||
interested_rate, is_mentioned = await _calculate_interest(message)
|
||||
message.interest_value = interested_rate
|
||||
message.is_mentioned = is_mentioned
|
||||
|
||||
await self.storage.store_message(message, chat)
|
||||
|
||||
subheartflow = await heartflow.get_or_create_subheartflow(chat.stream_id)
|
||||
message.update_chat_stream(chat)
|
||||
subheartflow: SubHeartflow = await heartflow.get_or_create_subheartflow(chat.stream_id) # type: ignore
|
||||
|
||||
# 6. 兴趣度计算与更新
|
||||
interested_rate, is_mentioned = await _calculate_interest(message)
|
||||
subheartflow.add_message_to_normal_chat_cache(message, interested_rate, is_mentioned)
|
||||
# subheartflow.add_message_to_normal_chat_cache(message, interested_rate, is_mentioned)
|
||||
|
||||
chat_mood = mood_manager.get_mood_by_chat_id(subheartflow.chat_id)
|
||||
chat_mood = mood_manager.get_mood_by_chat_id(subheartflow.chat_id) # type: ignore
|
||||
asyncio.create_task(chat_mood.update_mood_by_message(message, interested_rate))
|
||||
|
||||
with open("interested_rates.txt", "a", encoding="utf-8") as f:
|
||||
f.write(f"{interested_rate}\n")
|
||||
|
||||
# 7. 日志记录
|
||||
# 3. 日志记录
|
||||
mes_name = chat.group_info.group_name if chat.group_info else "私聊"
|
||||
# current_time = time.strftime("%H:%M:%S", time.localtime(message.message_info.time))
|
||||
current_talk_frequency = global_config.chat.get_current_talk_frequency(chat.stream_id)
|
||||
@@ -131,11 +124,11 @@ class HeartFCMessageReceiver:
|
||||
picid_pattern = r"\[picid:([^\]]+)\]"
|
||||
processed_plain_text = re.sub(picid_pattern, "[图片]", message.processed_plain_text)
|
||||
|
||||
logger.info(f"[{mes_name}]{userinfo.user_nickname}:{processed_plain_text}")
|
||||
logger.info(f"[{mes_name}]{userinfo.user_nickname}:{processed_plain_text}") # type: ignore
|
||||
|
||||
logger.debug(f"[{mes_name}][当前时段回复频率: {current_talk_frequency}]")
|
||||
|
||||
# 8. 关系处理
|
||||
# 4. 关系处理
|
||||
if global_config.relationship.enable_relationship:
|
||||
await _process_relationship(message)
|
||||
|
||||
|
||||
@@ -1,16 +1,9 @@
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Optional, List, Dict, Tuple
|
||||
import traceback
|
||||
from rich.traceback import install
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.message_receive.message import MessageRecv
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.focus_chat.heartFC_chat import HeartFChatting
|
||||
from src.chat.normal_chat.normal_chat import NormalChat
|
||||
from src.chat.heart_flow.chat_state_info import ChatState, ChatStateInfo
|
||||
from src.chat.utils.utils import get_chat_type_and_target_info
|
||||
from src.config.config import global_config
|
||||
from rich.traceback import install
|
||||
|
||||
logger = get_logger("sub_heartflow")
|
||||
|
||||
@@ -31,323 +24,18 @@ class SubHeartflow:
|
||||
self.subheartflow_id = subheartflow_id
|
||||
self.chat_id = subheartflow_id
|
||||
|
||||
# 这个聊天流的状态
|
||||
self.chat_state: ChatStateInfo = ChatStateInfo()
|
||||
self.chat_state_changed_time: float = time.time()
|
||||
self.chat_state_last_time: float = 0
|
||||
self.history_chat_state: List[Tuple[ChatState, float]] = []
|
||||
|
||||
self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_id)
|
||||
self.log_prefix = get_chat_manager().get_stream_name(self.subheartflow_id) or self.subheartflow_id
|
||||
# 兴趣消息集合
|
||||
self.interest_dict: Dict[str, tuple[MessageRecv, float, bool]] = {}
|
||||
|
||||
# focus模式退出冷却时间管理
|
||||
self.last_focus_exit_time: float = 0 # 上次退出focus模式的时间
|
||||
|
||||
# 随便水群 normal_chat 和 认真水群 focus_chat 实例
|
||||
# CHAT模式激活 随便水群 FOCUS模式激活 认真水群
|
||||
self.heart_fc_instance: Optional[HeartFChatting] = None # 该sub_heartflow的HeartFChatting实例
|
||||
self.normal_chat_instance: Optional[NormalChat] = None # 该sub_heartflow的NormalChat实例
|
||||
self.heart_fc_instance: HeartFChatting = HeartFChatting(
|
||||
chat_id=self.subheartflow_id,
|
||||
) # 该sub_heartflow的HeartFChatting实例
|
||||
|
||||
async def initialize(self):
|
||||
"""异步初始化方法,创建兴趣流并确定聊天类型"""
|
||||
|
||||
# 根据配置决定初始状态
|
||||
if not self.is_group_chat:
|
||||
logger.debug(f"{self.log_prefix} 检测到是私聊,将直接尝试进入 FOCUSED 状态。")
|
||||
await self.change_chat_state(ChatState.FOCUSED)
|
||||
elif global_config.chat.chat_mode == "focus":
|
||||
logger.debug(f"{self.log_prefix} 配置为 focus 模式,将直接尝试进入 FOCUSED 状态。")
|
||||
await self.change_chat_state(ChatState.FOCUSED)
|
||||
else: # "auto" 或其他模式保持原有逻辑或默认为 NORMAL
|
||||
logger.debug(f"{self.log_prefix} 配置为 auto 或其他模式,将尝试进入 NORMAL 状态。")
|
||||
await self.change_chat_state(ChatState.NORMAL)
|
||||
|
||||
def update_last_chat_state_time(self):
|
||||
self.chat_state_last_time = time.time() - self.chat_state_changed_time
|
||||
|
||||
async def _stop_normal_chat(self):
|
||||
"""
|
||||
停止 NormalChat 实例
|
||||
切出 CHAT 状态时使用
|
||||
"""
|
||||
if self.normal_chat_instance:
|
||||
logger.info(f"{self.log_prefix} 离开normal模式")
|
||||
try:
|
||||
logger.debug(f"{self.log_prefix} 开始调用 stop_chat()")
|
||||
# 使用更短的超时时间,强制快速停止
|
||||
await asyncio.wait_for(self.normal_chat_instance.stop_chat(), timeout=3.0)
|
||||
logger.debug(f"{self.log_prefix} stop_chat() 调用完成")
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"{self.log_prefix} 停止 NormalChat 超时,强制清理")
|
||||
# 超时时强制清理实例
|
||||
self.normal_chat_instance = None
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 停止 NormalChat 监控任务时出错: {e}")
|
||||
# 出错时也要清理实例,避免状态不一致
|
||||
self.normal_chat_instance = None
|
||||
finally:
|
||||
# 确保实例被清理
|
||||
if self.normal_chat_instance:
|
||||
logger.warning(f"{self.log_prefix} 强制清理 NormalChat 实例")
|
||||
self.normal_chat_instance = None
|
||||
logger.debug(f"{self.log_prefix} _stop_normal_chat 完成")
|
||||
|
||||
async def _start_normal_chat(self, rewind=False) -> bool:
|
||||
"""
|
||||
启动 NormalChat 实例,并进行异步初始化。
|
||||
进入 CHAT 状态时使用。
|
||||
确保 HeartFChatting 已停止。
|
||||
"""
|
||||
await self._stop_heart_fc_chat() # 确保 专注聊天已停止
|
||||
|
||||
self.interest_dict.clear()
|
||||
|
||||
log_prefix = self.log_prefix
|
||||
try:
|
||||
# 获取聊天流并创建 NormalChat 实例 (同步部分)
|
||||
chat_stream = get_chat_manager().get_stream(self.chat_id)
|
||||
if not chat_stream:
|
||||
logger.error(f"{log_prefix} 无法获取 chat_stream,无法启动 NormalChat。")
|
||||
return False
|
||||
# 在 rewind 为 True 或 NormalChat 实例尚未创建时,创建新实例
|
||||
if rewind or not self.normal_chat_instance:
|
||||
# 提供回调函数,用于接收需要切换到focus模式的通知
|
||||
self.normal_chat_instance = NormalChat(
|
||||
chat_stream=chat_stream,
|
||||
interest_dict=self.interest_dict,
|
||||
on_switch_to_focus_callback=self._handle_switch_to_focus_request,
|
||||
get_cooldown_progress_callback=self.get_cooldown_progress,
|
||||
)
|
||||
|
||||
logger.info(f"{log_prefix} 开始普通聊天,随便水群...")
|
||||
await self.normal_chat_instance.start_chat() # start_chat now ensures init is called again if needed
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"{log_prefix} 启动 NormalChat 或其初始化时出错: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
self.normal_chat_instance = None # 启动/初始化失败,清理实例
|
||||
return False
|
||||
|
||||
async def _handle_switch_to_focus_request(self) -> bool:
|
||||
"""
|
||||
处理来自NormalChat的切换到focus模式的请求
|
||||
|
||||
Args:
|
||||
stream_id: 请求切换的stream_id
|
||||
Returns:
|
||||
bool: 切换成功返回True,失败返回False
|
||||
"""
|
||||
logger.info(f"{self.log_prefix} 收到NormalChat请求切换到focus模式")
|
||||
|
||||
# 检查是否在focus冷却期内
|
||||
if self.is_in_focus_cooldown():
|
||||
logger.info(f"{self.log_prefix} 正在focus冷却期内,忽略切换到focus模式的请求")
|
||||
return False
|
||||
|
||||
# 切换到focus模式
|
||||
current_state = self.chat_state.chat_status
|
||||
if current_state == ChatState.NORMAL:
|
||||
await self.change_chat_state(ChatState.FOCUSED)
|
||||
logger.info(f"{self.log_prefix} 已根据NormalChat请求从NORMAL切换到FOCUSED状态")
|
||||
return True
|
||||
else:
|
||||
logger.warning(f"{self.log_prefix} 当前状态为{current_state.value},无法切换到FOCUSED状态")
|
||||
return False
|
||||
|
||||
async def _handle_stop_focus_chat_request(self) -> None:
|
||||
"""
|
||||
处理来自HeartFChatting的停止focus模式的请求
|
||||
当收到stop_focus_chat命令时被调用
|
||||
"""
|
||||
logger.info(f"{self.log_prefix} 收到HeartFChatting请求停止focus模式")
|
||||
|
||||
# 切换到normal模式
|
||||
current_state = self.chat_state.chat_status
|
||||
if current_state == ChatState.FOCUSED:
|
||||
await self.change_chat_state(ChatState.NORMAL)
|
||||
logger.info(f"{self.log_prefix} 已根据HeartFChatting请求从FOCUSED切换到NORMAL状态")
|
||||
else:
|
||||
logger.warning(f"{self.log_prefix} 当前状态为{current_state.value},无法切换到NORMAL状态")
|
||||
|
||||
async def _stop_heart_fc_chat(self):
|
||||
"""停止并清理 HeartFChatting 实例"""
|
||||
if self.heart_fc_instance:
|
||||
logger.debug(f"{self.log_prefix} 结束专注聊天...")
|
||||
try:
|
||||
await self.heart_fc_instance.shutdown()
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 关闭 HeartFChatting 实例时出错: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
finally:
|
||||
# 无论是否成功关闭,都清理引用
|
||||
self.heart_fc_instance = None
|
||||
|
||||
async def _start_heart_fc_chat(self) -> bool:
|
||||
"""启动 HeartFChatting 实例,确保 NormalChat 已停止"""
|
||||
logger.debug(f"{self.log_prefix} 开始启动 HeartFChatting")
|
||||
|
||||
try:
|
||||
# 确保普通聊天监控已停止
|
||||
await self._stop_normal_chat()
|
||||
self.interest_dict.clear()
|
||||
|
||||
log_prefix = self.log_prefix
|
||||
# 如果实例已存在,检查其循环任务状态
|
||||
if self.heart_fc_instance:
|
||||
logger.debug(f"{log_prefix} HeartFChatting 实例已存在,检查状态")
|
||||
# 如果任务已完成或不存在,则尝试重新启动
|
||||
if self.heart_fc_instance._loop_task is None or self.heart_fc_instance._loop_task.done():
|
||||
logger.info(f"{log_prefix} HeartFChatting 实例存在但循环未运行,尝试启动...")
|
||||
try:
|
||||
# 添加超时保护
|
||||
await asyncio.wait_for(self.heart_fc_instance.start(), timeout=15.0)
|
||||
logger.info(f"{log_prefix} HeartFChatting 循环已启动。")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"{log_prefix} 尝试启动现有 HeartFChatting 循环时出错: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
# 出错时清理实例,准备重新创建
|
||||
self.heart_fc_instance = None
|
||||
else:
|
||||
# 任务正在运行
|
||||
logger.debug(f"{log_prefix} HeartFChatting 已在运行中。")
|
||||
return True # 已经在运行
|
||||
|
||||
# 如果实例不存在,则创建并启动
|
||||
logger.info(f"{log_prefix} 麦麦准备开始专注聊天...")
|
||||
try:
|
||||
logger.debug(f"{log_prefix} 创建新的 HeartFChatting 实例")
|
||||
self.heart_fc_instance = HeartFChatting(
|
||||
chat_id=self.subheartflow_id,
|
||||
on_stop_focus_chat=self._handle_stop_focus_chat_request,
|
||||
)
|
||||
|
||||
logger.debug(f"{log_prefix} 启动 HeartFChatting 实例")
|
||||
# 添加超时保护
|
||||
await asyncio.wait_for(self.heart_fc_instance.start(), timeout=15.0)
|
||||
logger.debug(f"{log_prefix} 麦麦已成功进入专注聊天模式 (新实例已启动)。")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{log_prefix} 创建或启动 HeartFChatting 实例时出错: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
self.heart_fc_instance = None # 创建或初始化异常,清理实例
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} _start_heart_fc_chat 执行时出错: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
return False
|
||||
|
||||
async def change_chat_state(self, new_state: ChatState) -> None:
|
||||
"""
|
||||
改变聊天状态。
|
||||
如果转换到CHAT或FOCUSED状态时超过限制,会保持当前状态。
|
||||
"""
|
||||
current_state = self.chat_state.chat_status
|
||||
state_changed = False
|
||||
log_prefix = f"[{self.log_prefix}]"
|
||||
|
||||
if new_state == ChatState.NORMAL:
|
||||
logger.debug(f"{log_prefix} 准备进入 normal聊天 状态")
|
||||
if await self._start_normal_chat():
|
||||
logger.debug(f"{log_prefix} 成功进入或保持 NormalChat 状态。")
|
||||
state_changed = True
|
||||
else:
|
||||
logger.error(f"{log_prefix} 启动 NormalChat 失败,无法进入 CHAT 状态。")
|
||||
# 启动失败时,保持当前状态
|
||||
return
|
||||
|
||||
elif new_state == ChatState.FOCUSED:
|
||||
logger.debug(f"{log_prefix} 准备进入 focus聊天 状态")
|
||||
if await self._start_heart_fc_chat():
|
||||
logger.debug(f"{log_prefix} 成功进入或保持 HeartFChatting 状态。")
|
||||
state_changed = True
|
||||
else:
|
||||
logger.error(f"{log_prefix} 启动 HeartFChatting 失败,无法进入 FOCUSED 状态。")
|
||||
# 启动失败时,保持当前状态
|
||||
return
|
||||
|
||||
elif new_state == ChatState.ABSENT:
|
||||
logger.info(f"{log_prefix} 进入 ABSENT 状态,停止所有聊天活动...")
|
||||
self.interest_dict.clear()
|
||||
await self._stop_normal_chat()
|
||||
await self._stop_heart_fc_chat()
|
||||
state_changed = True
|
||||
|
||||
# --- 记录focus模式退出时间 ---
|
||||
if state_changed and current_state == ChatState.FOCUSED and new_state != ChatState.FOCUSED:
|
||||
self.last_focus_exit_time = time.time()
|
||||
logger.debug(f"{log_prefix} 记录focus模式退出时间: {self.last_focus_exit_time}")
|
||||
|
||||
# --- 更新状态和最后活动时间 ---
|
||||
if state_changed:
|
||||
self.update_last_chat_state_time()
|
||||
self.history_chat_state.append((current_state, self.chat_state_last_time))
|
||||
|
||||
self.chat_state.chat_status = new_state
|
||||
self.chat_state_last_time = 0
|
||||
self.chat_state_changed_time = time.time()
|
||||
else:
|
||||
logger.debug(
|
||||
f"{log_prefix} 尝试将状态从 {current_state.value} 变为 {new_state.value},但未成功或未执行更改。"
|
||||
)
|
||||
|
||||
def add_message_to_normal_chat_cache(self, message: MessageRecv, interest_value: float, is_mentioned: bool):
|
||||
self.interest_dict[message.message_info.message_id] = (message, interest_value, is_mentioned)
|
||||
# 如果字典长度超过10,删除最旧的消息
|
||||
if len(self.interest_dict) > 30:
|
||||
oldest_key = next(iter(self.interest_dict))
|
||||
self.interest_dict.pop(oldest_key)
|
||||
|
||||
def is_in_focus_cooldown(self) -> bool:
|
||||
"""检查是否在focus模式的冷却期内
|
||||
|
||||
Returns:
|
||||
bool: 如果在冷却期内返回True,否则返回False
|
||||
"""
|
||||
if self.last_focus_exit_time == 0:
|
||||
return False
|
||||
|
||||
# 基础冷却时间10分钟,受auto_focus_threshold调控
|
||||
base_cooldown = 10 * 60 # 10分钟转换为秒
|
||||
cooldown_duration = base_cooldown / global_config.chat.auto_focus_threshold
|
||||
|
||||
current_time = time.time()
|
||||
elapsed_since_exit = current_time - self.last_focus_exit_time
|
||||
|
||||
is_cooling = elapsed_since_exit < cooldown_duration
|
||||
|
||||
if is_cooling:
|
||||
remaining_time = cooldown_duration - elapsed_since_exit
|
||||
remaining_minutes = remaining_time / 60
|
||||
logger.debug(
|
||||
f"[{self.log_prefix}] focus冷却中,剩余时间: {remaining_minutes:.1f}分钟 (阈值: {global_config.chat.auto_focus_threshold})"
|
||||
)
|
||||
|
||||
return is_cooling
|
||||
|
||||
def get_cooldown_progress(self) -> float:
|
||||
"""获取冷却进度,返回0-1之间的值
|
||||
|
||||
Returns:
|
||||
float: 0表示刚开始冷却,1表示冷却完成
|
||||
"""
|
||||
if self.last_focus_exit_time == 0:
|
||||
return 1.0 # 没有冷却,返回1表示完全恢复
|
||||
|
||||
# 基础冷却时间10分钟,受auto_focus_threshold调控
|
||||
base_cooldown = 10 * 60 # 10分钟转换为秒
|
||||
cooldown_duration = base_cooldown / global_config.chat.auto_focus_threshold
|
||||
|
||||
current_time = time.time()
|
||||
elapsed_since_exit = current_time - self.last_focus_exit_time
|
||||
|
||||
if elapsed_since_exit >= cooldown_duration:
|
||||
return 1.0 # 冷却完成
|
||||
|
||||
# 计算进度:0表示刚开始冷却,1表示冷却完成
|
||||
progress = elapsed_since_exit / cooldown_duration
|
||||
return progress
|
||||
await self.heart_fc_instance.start()
|
||||
|
||||
@@ -2,6 +2,7 @@ from dataclasses import dataclass
|
||||
import json
|
||||
import os
|
||||
import math
|
||||
import asyncio
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
@@ -10,8 +11,8 @@ import pandas as pd
|
||||
# import tqdm
|
||||
import faiss
|
||||
|
||||
from .llm_client import LLMClient
|
||||
from .lpmmconfig import ENT_NAMESPACE, PG_NAMESPACE, REL_NAMESPACE, global_config
|
||||
# from .llm_client import LLMClient
|
||||
# from .lpmmconfig import global_config
|
||||
from .utils.hash import get_sha256
|
||||
from .global_logger import logger
|
||||
from rich.traceback import install
|
||||
@@ -25,14 +26,14 @@ from rich.progress import (
|
||||
SpinnerColumn,
|
||||
TextColumn,
|
||||
)
|
||||
from src.manager.local_store_manager import local_storage
|
||||
from src.chat.utils.utils import get_embedding
|
||||
from src.config.config import global_config
|
||||
|
||||
|
||||
install(extra_lines=3)
|
||||
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
|
||||
EMBEDDING_DATA_DIR = (
|
||||
os.path.join(ROOT_PATH, "data", "embedding")
|
||||
if global_config["persistence"]["embedding_data_dir"] is None
|
||||
else os.path.join(ROOT_PATH, global_config["persistence"]["embedding_data_dir"])
|
||||
)
|
||||
EMBEDDING_DATA_DIR = os.path.join(ROOT_PATH, "data", "embedding")
|
||||
EMBEDDING_DATA_DIR_STR = str(EMBEDDING_DATA_DIR).replace("\\", "/")
|
||||
TOTAL_EMBEDDING_TIMES = 3 # 统计嵌入次数
|
||||
|
||||
@@ -59,7 +60,7 @@ EMBEDDING_SIM_THRESHOLD = 0.99
|
||||
|
||||
def cosine_similarity(a, b):
|
||||
# 计算余弦相似度
|
||||
dot = sum(x * y for x, y in zip(a, b))
|
||||
dot = sum(x * y for x, y in zip(a, b, strict=False))
|
||||
norm_a = math.sqrt(sum(x * x for x in a))
|
||||
norm_b = math.sqrt(sum(x * x for x in b))
|
||||
if norm_a == 0 or norm_b == 0:
|
||||
@@ -86,21 +87,43 @@ class EmbeddingStoreItem:
|
||||
|
||||
|
||||
class EmbeddingStore:
|
||||
def __init__(self, llm_client: LLMClient, namespace: str, dir_path: str):
|
||||
def __init__(self, namespace: str, dir_path: str):
|
||||
self.namespace = namespace
|
||||
self.llm_client = llm_client
|
||||
self.dir = dir_path
|
||||
self.embedding_file_path = dir_path + "/" + namespace + ".parquet"
|
||||
self.index_file_path = dir_path + "/" + namespace + ".index"
|
||||
self.embedding_file_path = f"{dir_path}/{namespace}.parquet"
|
||||
self.index_file_path = f"{dir_path}/{namespace}.index"
|
||||
self.idx2hash_file_path = dir_path + "/" + namespace + "_i2h.json"
|
||||
|
||||
self.store = dict()
|
||||
self.store = {}
|
||||
|
||||
self.faiss_index = None
|
||||
self.idx2hash = None
|
||||
|
||||
def _get_embedding(self, s: str) -> List[float]:
|
||||
return self.llm_client.send_embedding_request(global_config["embedding"]["model"], s)
|
||||
"""获取字符串的嵌入向量,处理异步调用"""
|
||||
try:
|
||||
# 尝试获取当前事件循环
|
||||
asyncio.get_running_loop()
|
||||
# 如果在事件循环中,使用线程池执行
|
||||
import concurrent.futures
|
||||
|
||||
def run_in_thread():
|
||||
return asyncio.run(get_embedding(s))
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
future = executor.submit(run_in_thread)
|
||||
result = future.result()
|
||||
if result is None:
|
||||
logger.error(f"获取嵌入失败: {s}")
|
||||
return []
|
||||
return result
|
||||
except RuntimeError:
|
||||
# 没有运行的事件循环,直接运行
|
||||
result = asyncio.run(get_embedding(s))
|
||||
if result is None:
|
||||
logger.error(f"获取嵌入失败: {s}")
|
||||
return []
|
||||
return result
|
||||
|
||||
def get_test_file_path(self):
|
||||
return EMBEDDING_TEST_FILE
|
||||
@@ -258,7 +281,7 @@ class EmbeddingStore:
|
||||
# L2归一化
|
||||
faiss.normalize_L2(embeddings)
|
||||
# 构建索引
|
||||
self.faiss_index = faiss.IndexFlatIP(global_config["embedding"]["dimension"])
|
||||
self.faiss_index = faiss.IndexFlatIP(global_config.lpmm_knowledge.embedding_dimension)
|
||||
self.faiss_index.add(embeddings)
|
||||
|
||||
def search_top_k(self, query: List[float], k: int) -> List[Tuple[str, float]]:
|
||||
@@ -285,7 +308,7 @@ class EmbeddingStore:
|
||||
distances = list(distances.flatten())
|
||||
result = [
|
||||
(self.idx2hash[str(int(idx))], float(sim))
|
||||
for (idx, sim) in zip(indices, distances)
|
||||
for (idx, sim) in zip(indices, distances, strict=False)
|
||||
if idx in range(len(self.idx2hash))
|
||||
]
|
||||
|
||||
@@ -293,20 +316,17 @@ class EmbeddingStore:
|
||||
|
||||
|
||||
class EmbeddingManager:
|
||||
def __init__(self, llm_client: LLMClient):
|
||||
def __init__(self):
|
||||
self.paragraphs_embedding_store = EmbeddingStore(
|
||||
llm_client,
|
||||
PG_NAMESPACE,
|
||||
local_storage['pg_namespace'],
|
||||
EMBEDDING_DATA_DIR_STR,
|
||||
)
|
||||
self.entities_embedding_store = EmbeddingStore(
|
||||
llm_client,
|
||||
ENT_NAMESPACE,
|
||||
local_storage['pg_namespace'],
|
||||
EMBEDDING_DATA_DIR_STR,
|
||||
)
|
||||
self.relation_embedding_store = EmbeddingStore(
|
||||
llm_client,
|
||||
REL_NAMESPACE,
|
||||
local_storage['pg_namespace'],
|
||||
EMBEDDING_DATA_DIR_STR,
|
||||
)
|
||||
self.stored_pg_hashes = set()
|
||||
|
||||
@@ -1,31 +1,86 @@
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from typing import List, Union
|
||||
|
||||
from .global_logger import logger
|
||||
from . import prompt_template
|
||||
from .lpmmconfig import global_config, INVALID_ENTITY
|
||||
from .llm_client import LLMClient
|
||||
from src.chat.knowledge.utils.json_fix import new_fix_broken_generated_json
|
||||
from .knowledge_lib import INVALID_ENTITY
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from json_repair import repair_json
|
||||
def _extract_json_from_text(text: str):
|
||||
"""从文本中提取JSON数据的高容错方法"""
|
||||
if text is None:
|
||||
logger.error("输入文本为None")
|
||||
return []
|
||||
|
||||
try:
|
||||
fixed_json = repair_json(text)
|
||||
if isinstance(fixed_json, str):
|
||||
parsed_json = json.loads(fixed_json)
|
||||
else:
|
||||
parsed_json = fixed_json
|
||||
|
||||
# 如果是列表,直接返回
|
||||
if isinstance(parsed_json, list):
|
||||
return parsed_json
|
||||
|
||||
# 如果是字典且只有一个项目,可能包装了列表
|
||||
if isinstance(parsed_json, dict):
|
||||
# 如果字典只有一个键,并且值是列表,返回那个列表
|
||||
if len(parsed_json) == 1:
|
||||
value = list(parsed_json.values())[0]
|
||||
if isinstance(value, list):
|
||||
return value
|
||||
return parsed_json
|
||||
|
||||
# 其他情况,尝试转换为列表
|
||||
logger.warning(f"解析的JSON不是预期格式: {type(parsed_json)}, 内容: {parsed_json}")
|
||||
return []
|
||||
|
||||
def _entity_extract(llm_client: LLMClient, paragraph: str) -> List[str]:
|
||||
except Exception as e:
|
||||
logger.error(f"JSON提取失败: {e}, 原始文本: {text[:100] if text else 'None'}...")
|
||||
return []
|
||||
|
||||
def _entity_extract(llm_req: LLMRequest, paragraph: str) -> List[str]:
|
||||
"""对段落进行实体提取,返回提取出的实体列表(JSON格式)"""
|
||||
entity_extract_context = prompt_template.build_entity_extract_context(paragraph)
|
||||
_, request_result = llm_client.send_chat_request(
|
||||
global_config["entity_extract"]["llm"]["model"], entity_extract_context
|
||||
)
|
||||
|
||||
# 去除‘{’前的内容(结果中可能有多个‘{’)
|
||||
if "[" in request_result:
|
||||
request_result = request_result[request_result.index("[") :]
|
||||
|
||||
# 去除最后一个‘}’后的内容(结果中可能有多个‘}’)
|
||||
if "]" in request_result:
|
||||
request_result = request_result[: request_result.rindex("]") + 1]
|
||||
|
||||
entity_extract_result = json.loads(new_fix_broken_generated_json(request_result))
|
||||
|
||||
# 使用 asyncio.run 来运行异步方法
|
||||
try:
|
||||
# 如果当前已有事件循环在运行,使用它
|
||||
loop = asyncio.get_running_loop()
|
||||
future = asyncio.run_coroutine_threadsafe(
|
||||
llm_req.generate_response_async(entity_extract_context), loop
|
||||
)
|
||||
response, (reasoning_content, model_name) = future.result()
|
||||
except RuntimeError:
|
||||
# 如果没有运行中的事件循环,直接使用 asyncio.run
|
||||
response, (reasoning_content, model_name) = asyncio.run(
|
||||
llm_req.generate_response_async(entity_extract_context)
|
||||
)
|
||||
|
||||
# 添加调试日志
|
||||
logger.debug(f"LLM返回的原始响应: {response}")
|
||||
|
||||
entity_extract_result = _extract_json_from_text(response)
|
||||
|
||||
# 检查返回的是否为有效的实体列表
|
||||
if not isinstance(entity_extract_result, list):
|
||||
# 如果不是列表,可能是字典格式,尝试从中提取列表
|
||||
if isinstance(entity_extract_result, dict):
|
||||
# 尝试常见的键名
|
||||
for key in ['entities', 'result', 'data', 'items']:
|
||||
if key in entity_extract_result and isinstance(entity_extract_result[key], list):
|
||||
entity_extract_result = entity_extract_result[key]
|
||||
break
|
||||
else:
|
||||
# 如果找不到合适的列表,抛出异常
|
||||
raise Exception(f"实体提取结果格式错误,期望列表但得到: {type(entity_extract_result)}")
|
||||
else:
|
||||
raise Exception(f"实体提取结果格式错误,期望列表但得到: {type(entity_extract_result)}")
|
||||
|
||||
# 过滤无效实体
|
||||
entity_extract_result = [
|
||||
entity
|
||||
for entity in entity_extract_result
|
||||
@@ -38,32 +93,56 @@ def _entity_extract(llm_client: LLMClient, paragraph: str) -> List[str]:
|
||||
return entity_extract_result
|
||||
|
||||
|
||||
def _rdf_triple_extract(llm_client: LLMClient, paragraph: str, entities: list) -> List[List[str]]:
|
||||
def _rdf_triple_extract(llm_req: LLMRequest, paragraph: str, entities: list) -> List[List[str]]:
|
||||
"""对段落进行实体提取,返回提取出的实体列表(JSON格式)"""
|
||||
entity_extract_context = prompt_template.build_rdf_triple_extract_context(
|
||||
rdf_extract_context = prompt_template.build_rdf_triple_extract_context(
|
||||
paragraph, entities=json.dumps(entities, ensure_ascii=False)
|
||||
)
|
||||
_, request_result = llm_client.send_chat_request(global_config["rdf_build"]["llm"]["model"], entity_extract_context)
|
||||
|
||||
# 使用 asyncio.run 来运行异步方法
|
||||
try:
|
||||
# 如果当前已有事件循环在运行,使用它
|
||||
loop = asyncio.get_running_loop()
|
||||
future = asyncio.run_coroutine_threadsafe(
|
||||
llm_req.generate_response_async(rdf_extract_context), loop
|
||||
)
|
||||
response, (reasoning_content, model_name) = future.result()
|
||||
except RuntimeError:
|
||||
# 如果没有运行中的事件循环,直接使用 asyncio.run
|
||||
response, (reasoning_content, model_name) = asyncio.run(
|
||||
llm_req.generate_response_async(rdf_extract_context)
|
||||
)
|
||||
|
||||
# 去除‘{’前的内容(结果中可能有多个‘{’)
|
||||
if "[" in request_result:
|
||||
request_result = request_result[request_result.index("[") :]
|
||||
|
||||
# 去除最后一个‘}’后的内容(结果中可能有多个‘}’)
|
||||
if "]" in request_result:
|
||||
request_result = request_result[: request_result.rindex("]") + 1]
|
||||
|
||||
entity_extract_result = json.loads(new_fix_broken_generated_json(request_result))
|
||||
|
||||
for triple in entity_extract_result:
|
||||
if len(triple) != 3 or (triple[0] is None or triple[1] is None or triple[2] is None) or "" in triple:
|
||||
# 添加调试日志
|
||||
logger.debug(f"RDF LLM返回的原始响应: {response}")
|
||||
|
||||
rdf_triple_result = _extract_json_from_text(response)
|
||||
|
||||
# 检查返回的是否为有效的三元组列表
|
||||
if not isinstance(rdf_triple_result, list):
|
||||
# 如果不是列表,可能是字典格式,尝试从中提取列表
|
||||
if isinstance(rdf_triple_result, dict):
|
||||
# 尝试常见的键名
|
||||
for key in ['triples', 'result', 'data', 'items']:
|
||||
if key in rdf_triple_result and isinstance(rdf_triple_result[key], list):
|
||||
rdf_triple_result = rdf_triple_result[key]
|
||||
break
|
||||
else:
|
||||
# 如果找不到合适的列表,抛出异常
|
||||
raise Exception(f"RDF三元组提取结果格式错误,期望列表但得到: {type(rdf_triple_result)}")
|
||||
else:
|
||||
raise Exception(f"RDF三元组提取结果格式错误,期望列表但得到: {type(rdf_triple_result)}")
|
||||
|
||||
# 验证三元组格式
|
||||
for triple in rdf_triple_result:
|
||||
if not isinstance(triple, list) or len(triple) != 3 or (triple[0] is None or triple[1] is None or triple[2] is None) or "" in triple:
|
||||
raise Exception("RDF提取结果格式错误")
|
||||
|
||||
return entity_extract_result
|
||||
return rdf_triple_result
|
||||
|
||||
|
||||
def info_extract_from_str(
|
||||
llm_client_for_ner: LLMClient, llm_client_for_rdf: LLMClient, paragraph: str
|
||||
llm_client_for_ner: LLMRequest, llm_client_for_rdf: LLMRequest, paragraph: str
|
||||
) -> Union[tuple[None, None], tuple[list[str], list[list[str]]]]:
|
||||
try_count = 0
|
||||
while True:
|
||||
|
||||
@@ -20,24 +20,37 @@ from quick_algo import di_graph, pagerank
|
||||
|
||||
from .utils.hash import get_sha256
|
||||
from .embedding_store import EmbeddingManager, EmbeddingStoreItem
|
||||
from .lpmmconfig import (
|
||||
ENT_NAMESPACE,
|
||||
PG_NAMESPACE,
|
||||
RAG_ENT_CNT_NAMESPACE,
|
||||
RAG_GRAPH_NAMESPACE,
|
||||
RAG_PG_HASH_NAMESPACE,
|
||||
global_config,
|
||||
)
|
||||
from .lpmmconfig import global_config
|
||||
from src.manager.local_store_manager import local_storage
|
||||
|
||||
from .global_logger import logger
|
||||
|
||||
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
|
||||
KG_DIR = (
|
||||
os.path.join(ROOT_PATH, "data/rag")
|
||||
if global_config["persistence"]["rag_data_dir"] is None
|
||||
else os.path.join(ROOT_PATH, global_config["persistence"]["rag_data_dir"])
|
||||
)
|
||||
KG_DIR_STR = str(KG_DIR).replace("\\", "/")
|
||||
|
||||
def _get_kg_dir():
|
||||
"""
|
||||
安全地获取KG数据目录路径
|
||||
"""
|
||||
root_path = local_storage['root_path']
|
||||
if root_path is None:
|
||||
# 如果 local_storage 中没有 root_path,使用当前文件的相对路径作为备用
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
root_path = os.path.abspath(os.path.join(current_dir, "..", "..", ".."))
|
||||
logger.warning(f"local_storage 中未找到 root_path,使用备用路径: {root_path}")
|
||||
|
||||
# 获取RAG数据目录
|
||||
rag_data_dir = global_config["persistence"]["rag_data_dir"]
|
||||
if rag_data_dir is None:
|
||||
kg_dir = os.path.join(root_path, "data/rag")
|
||||
else:
|
||||
kg_dir = os.path.join(root_path, rag_data_dir)
|
||||
|
||||
return str(kg_dir).replace("\\", "/")
|
||||
|
||||
|
||||
# 延迟初始化,避免在模块加载时就访问可能未初始化的 local_storage
|
||||
def get_kg_dir_str():
|
||||
"""获取KG目录字符串"""
|
||||
return _get_kg_dir()
|
||||
|
||||
|
||||
class KGManager:
|
||||
@@ -46,15 +59,15 @@ class KGManager:
|
||||
# 存储段落的hash值,用于去重
|
||||
self.stored_paragraph_hashes = set()
|
||||
# 实体出现次数
|
||||
self.ent_appear_cnt = dict()
|
||||
self.ent_appear_cnt = {}
|
||||
# KG
|
||||
self.graph = di_graph.DiGraph()
|
||||
|
||||
# 持久化相关
|
||||
self.dir_path = KG_DIR_STR
|
||||
self.graph_data_path = self.dir_path + "/" + RAG_GRAPH_NAMESPACE + ".graphml"
|
||||
self.ent_cnt_data_path = self.dir_path + "/" + RAG_ENT_CNT_NAMESPACE + ".parquet"
|
||||
self.pg_hash_file_path = self.dir_path + "/" + RAG_PG_HASH_NAMESPACE + ".json"
|
||||
# 持久化相关 - 使用延迟初始化的路径
|
||||
self.dir_path = get_kg_dir_str()
|
||||
self.graph_data_path = self.dir_path + "/" + local_storage['rag_graph_namespace'] + ".graphml"
|
||||
self.ent_cnt_data_path = self.dir_path + "/" + local_storage['rag_ent_cnt_namespace'] + ".parquet"
|
||||
self.pg_hash_file_path = self.dir_path + "/" + local_storage['rag_pg_hash_namespace'] + ".json"
|
||||
|
||||
def save_to_file(self):
|
||||
"""将KG数据保存到文件"""
|
||||
@@ -109,8 +122,8 @@ class KGManager:
|
||||
# 避免自连接
|
||||
continue
|
||||
# 一个triple就是一条边(同时构建双向联系)
|
||||
hash_key1 = ENT_NAMESPACE + "-" + get_sha256(triple[0])
|
||||
hash_key2 = ENT_NAMESPACE + "-" + get_sha256(triple[2])
|
||||
hash_key1 = local_storage['ent_namespace'] + "-" + get_sha256(triple[0])
|
||||
hash_key2 = local_storage['ent_namespace'] + "-" + get_sha256(triple[2])
|
||||
node_to_node[(hash_key1, hash_key2)] = node_to_node.get((hash_key1, hash_key2), 0) + 1.0
|
||||
node_to_node[(hash_key2, hash_key1)] = node_to_node.get((hash_key2, hash_key1), 0) + 1.0
|
||||
entity_set.add(hash_key1)
|
||||
@@ -128,8 +141,8 @@ class KGManager:
|
||||
"""构建实体节点与文段节点之间的关系"""
|
||||
for idx in triple_list_data:
|
||||
for triple in triple_list_data[idx]:
|
||||
ent_hash_key = ENT_NAMESPACE + "-" + get_sha256(triple[0])
|
||||
pg_hash_key = PG_NAMESPACE + "-" + str(idx)
|
||||
ent_hash_key = local_storage['ent_namespace'] + "-" + get_sha256(triple[0])
|
||||
pg_hash_key = local_storage['pg_namespace'] + "-" + str(idx)
|
||||
node_to_node[(ent_hash_key, pg_hash_key)] = node_to_node.get((ent_hash_key, pg_hash_key), 0) + 1.0
|
||||
|
||||
@staticmethod
|
||||
@@ -144,8 +157,8 @@ class KGManager:
|
||||
ent_hash_list = set()
|
||||
for triple_list in triple_list_data.values():
|
||||
for triple in triple_list:
|
||||
ent_hash_list.add(ENT_NAMESPACE + "-" + get_sha256(triple[0]))
|
||||
ent_hash_list.add(ENT_NAMESPACE + "-" + get_sha256(triple[2]))
|
||||
ent_hash_list.add(local_storage['ent_namespace'] + "-" + get_sha256(triple[0]))
|
||||
ent_hash_list.add(local_storage['ent_namespace'] + "-" + get_sha256(triple[2]))
|
||||
ent_hash_list = list(ent_hash_list)
|
||||
|
||||
synonym_hash_set = set()
|
||||
@@ -171,10 +184,10 @@ class KGManager:
|
||||
progress.update(task, advance=1)
|
||||
continue
|
||||
ent = embedding_manager.entities_embedding_store.store.get(ent_hash)
|
||||
assert isinstance(ent, EmbeddingStoreItem)
|
||||
if ent is None:
|
||||
progress.update(task, advance=1)
|
||||
continue
|
||||
assert isinstance(ent, EmbeddingStoreItem)
|
||||
# 查询相似实体
|
||||
similar_ents = embedding_manager.entities_embedding_store.search_top_k(
|
||||
ent.embedding, global_config["rag"]["params"]["synonym_search_top_k"]
|
||||
@@ -250,18 +263,24 @@ class KGManager:
|
||||
for src_tgt in node_to_node.keys():
|
||||
for node_hash in src_tgt:
|
||||
if node_hash not in existed_nodes:
|
||||
if node_hash.startswith(ENT_NAMESPACE):
|
||||
if node_hash.startswith(local_storage['ent_namespace']):
|
||||
# 新增实体节点
|
||||
node = embedding_manager.entities_embedding_store.store[node_hash]
|
||||
node = embedding_manager.entities_embedding_store.store.get(node_hash)
|
||||
if node is None:
|
||||
logger.warning(f"实体节点 {node_hash} 在嵌入库中不存在,跳过")
|
||||
continue
|
||||
assert isinstance(node, EmbeddingStoreItem)
|
||||
node_item = self.graph[node_hash]
|
||||
node_item["content"] = node.str
|
||||
node_item["type"] = "ent"
|
||||
node_item["create_time"] = now_time
|
||||
self.graph.update_node(node_item)
|
||||
elif node_hash.startswith(PG_NAMESPACE):
|
||||
elif node_hash.startswith(local_storage['pg_namespace']):
|
||||
# 新增文段节点
|
||||
node = embedding_manager.paragraphs_embedding_store.store[node_hash]
|
||||
node = embedding_manager.paragraphs_embedding_store.store.get(node_hash)
|
||||
if node is None:
|
||||
logger.warning(f"段落节点 {node_hash} 在嵌入库中不存在,跳过")
|
||||
continue
|
||||
assert isinstance(node, EmbeddingStoreItem)
|
||||
content = node.str.replace("\n", " ")
|
||||
node_item = self.graph[node_hash]
|
||||
@@ -340,7 +359,7 @@ class KGManager:
|
||||
# 关系三元组
|
||||
triple = relation[2:-2].split("', '")
|
||||
for ent in [(triple[0]), (triple[2])]:
|
||||
ent_hash = ENT_NAMESPACE + "-" + get_sha256(ent)
|
||||
ent_hash = local_storage['ent_namespace'] + "-" + get_sha256(ent)
|
||||
if ent_hash in existed_nodes: # 该实体需在KG中存在
|
||||
if ent_hash not in ent_sim_scores: # 尚未记录的实体
|
||||
ent_sim_scores[ent_hash] = []
|
||||
@@ -418,7 +437,7 @@ class KGManager:
|
||||
# 获取最终结果
|
||||
# 从搜索结果中提取文段节点的结果
|
||||
passage_node_res = [
|
||||
(node_key, score) for node_key, score in ppr_res.items() if node_key.startswith(PG_NAMESPACE)
|
||||
(node_key, score) for node_key, score in ppr_res.items() if node_key.startswith(local_storage['pg_namespace'])
|
||||
]
|
||||
del ppr_res
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from src.chat.knowledge.lpmmconfig import PG_NAMESPACE, global_config
|
||||
from src.chat.knowledge.lpmmconfig import global_config
|
||||
from src.chat.knowledge.embedding_store import EmbeddingManager
|
||||
from src.chat.knowledge.llm_client import LLMClient
|
||||
from src.chat.knowledge.mem_active_manager import MemoryActiveManager
|
||||
@@ -6,14 +6,84 @@ from src.chat.knowledge.qa_manager import QAManager
|
||||
from src.chat.knowledge.kg_manager import KGManager
|
||||
from src.chat.knowledge.global_logger import logger
|
||||
from src.config.config import global_config as bot_global_config
|
||||
# try:
|
||||
# import quick_algo
|
||||
# except ImportError:
|
||||
# print("quick_algo not found, please install it first")
|
||||
from src.manager.local_store_manager import local_storage
|
||||
import os
|
||||
|
||||
INVALID_ENTITY = [
|
||||
"",
|
||||
"你",
|
||||
"他",
|
||||
"她",
|
||||
"它",
|
||||
"我们",
|
||||
"你们",
|
||||
"他们",
|
||||
"她们",
|
||||
"它们",
|
||||
]
|
||||
PG_NAMESPACE = "paragraph"
|
||||
ENT_NAMESPACE = "entity"
|
||||
REL_NAMESPACE = "relation"
|
||||
|
||||
RAG_GRAPH_NAMESPACE = "rag-graph"
|
||||
RAG_ENT_CNT_NAMESPACE = "rag-ent-cnt"
|
||||
RAG_PG_HASH_NAMESPACE = "rag-pg-hash"
|
||||
|
||||
|
||||
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
|
||||
DATA_PATH = os.path.join(ROOT_PATH, "data")
|
||||
|
||||
def _initialize_knowledge_local_storage():
|
||||
"""
|
||||
初始化知识库相关的本地存储配置
|
||||
使用字典批量设置,避免重复的if判断
|
||||
"""
|
||||
# 定义所有需要初始化的配置项
|
||||
default_configs = {
|
||||
# 路径配置
|
||||
'root_path': ROOT_PATH,
|
||||
'data_path': f"{ROOT_PATH}/data",
|
||||
|
||||
# 实体和命名空间配置
|
||||
'lpmm_invalid_entity': INVALID_ENTITY,
|
||||
'pg_namespace': PG_NAMESPACE,
|
||||
'ent_namespace': ENT_NAMESPACE,
|
||||
'rel_namespace': REL_NAMESPACE,
|
||||
|
||||
# RAG相关命名空间配置
|
||||
'rag_graph_namespace': RAG_GRAPH_NAMESPACE,
|
||||
'rag_ent_cnt_namespace': RAG_ENT_CNT_NAMESPACE,
|
||||
'rag_pg_hash_namespace': RAG_PG_HASH_NAMESPACE
|
||||
}
|
||||
|
||||
# 日志级别映射:重要配置用info,其他用debug
|
||||
important_configs = {'root_path', 'data_path'}
|
||||
|
||||
# 批量设置配置项
|
||||
initialized_count = 0
|
||||
for key, default_value in default_configs.items():
|
||||
if local_storage[key] is None:
|
||||
local_storage[key] = default_value
|
||||
|
||||
# 根据重要性选择日志级别
|
||||
if key in important_configs:
|
||||
logger.info(f"设置{key}: {default_value}")
|
||||
else:
|
||||
logger.debug(f"设置{key}: {default_value}")
|
||||
|
||||
initialized_count += 1
|
||||
|
||||
if initialized_count > 0:
|
||||
logger.info(f"知识库本地存储初始化完成,共设置 {initialized_count} 项配置")
|
||||
else:
|
||||
logger.debug("知识库本地存储配置已存在,跳过初始化")
|
||||
|
||||
# 初始化本地存储路径
|
||||
_initialize_knowledge_local_storage()
|
||||
|
||||
# 检查LPMM知识库是否启用
|
||||
if bot_global_config.lpmm_knowledge.enable:
|
||||
logger.info("正在初始化Mai-LPMM\n")
|
||||
logger.info("正在初始化Mai-LPMM")
|
||||
logger.info("创建LLM客户端")
|
||||
llm_client_list = dict()
|
||||
for key in global_config["llm_providers"]:
|
||||
@@ -23,7 +93,7 @@ if bot_global_config.lpmm_knowledge.enable:
|
||||
)
|
||||
|
||||
# 初始化Embedding库
|
||||
embed_manager = EmbeddingManager(llm_client_list[global_config["embedding"]["provider"]])
|
||||
embed_manager = EmbeddingManager()
|
||||
logger.info("正在从文件加载Embedding库")
|
||||
try:
|
||||
embed_manager.load_from_file()
|
||||
@@ -54,9 +124,6 @@ if bot_global_config.lpmm_knowledge.enable:
|
||||
qa_manager = QAManager(
|
||||
embed_manager,
|
||||
kg_manager,
|
||||
llm_client_list[global_config["embedding"]["provider"]],
|
||||
llm_client_list[global_config["qa"]["llm"]["provider"]],
|
||||
llm_client_list[global_config["qa"]["llm"]["provider"]],
|
||||
)
|
||||
|
||||
# 记忆激活(用于记忆库)
|
||||
|
||||
@@ -4,9 +4,8 @@ import glob
|
||||
from typing import Any, Dict, List
|
||||
|
||||
|
||||
from .lpmmconfig import INVALID_ENTITY, global_config
|
||||
|
||||
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
|
||||
from .knowledge_lib import INVALID_ENTITY, ROOT_PATH, DATA_PATH
|
||||
# from src.manager.local_store_manager import local_storage
|
||||
|
||||
|
||||
def _filter_invalid_entities(entities: List[str]) -> List[str]:
|
||||
@@ -107,7 +106,7 @@ class OpenIE:
|
||||
@staticmethod
|
||||
def load() -> "OpenIE":
|
||||
"""从OPENIE_DIR下所有json文件合并加载OpenIE数据"""
|
||||
openie_dir = os.path.join(ROOT_PATH, global_config["persistence"]["openie_data_path"])
|
||||
openie_dir = os.path.join(DATA_PATH, "openie")
|
||||
if not os.path.exists(openie_dir):
|
||||
raise Exception(f"OpenIE数据目录不存在: {openie_dir}")
|
||||
json_files = sorted(glob.glob(os.path.join(openie_dir, "*.json")))
|
||||
@@ -122,12 +121,6 @@ class OpenIE:
|
||||
openie_data = OpenIE._from_dict(data_list)
|
||||
return openie_data
|
||||
|
||||
@staticmethod
|
||||
def save(openie_data: "OpenIE"):
|
||||
"""保存OpenIE数据到文件"""
|
||||
with open(global_config["persistence"]["openie_data_path"], "w", encoding="utf-8") as f:
|
||||
f.write(json.dumps(openie_data._to_dict(), ensure_ascii=False, indent=4))
|
||||
|
||||
def extract_entity_dict(self):
|
||||
"""提取实体列表"""
|
||||
ner_output_dict = dict(
|
||||
|
||||
@@ -11,12 +11,14 @@ entity_extract_system_prompt = """你是一个性能优异的实体提取系统
|
||||
"""
|
||||
|
||||
|
||||
def build_entity_extract_context(paragraph: str) -> list[LLMMessage]:
|
||||
messages = [
|
||||
LLMMessage("system", entity_extract_system_prompt).to_dict(),
|
||||
LLMMessage("user", f"""段落:\n```\n{paragraph}```""").to_dict(),
|
||||
]
|
||||
return messages
|
||||
def build_entity_extract_context(paragraph: str) -> str:
|
||||
"""构建实体提取的完整提示文本"""
|
||||
return f"""{entity_extract_system_prompt}
|
||||
|
||||
段落:
|
||||
```
|
||||
{paragraph}
|
||||
```"""
|
||||
|
||||
|
||||
rdf_triple_extract_system_prompt = """你是一个性能优异的RDF(资源描述框架,由节点和边组成,节点表示实体/资源、属性,边则表示了实体和实体之间的关系以及实体和属性的关系。)构造系统。你的任务是根据给定的段落和实体列表构建RDF图。
|
||||
@@ -36,12 +38,19 @@ rdf_triple_extract_system_prompt = """你是一个性能优异的RDF(资源描
|
||||
"""
|
||||
|
||||
|
||||
def build_rdf_triple_extract_context(paragraph: str, entities: str) -> list[LLMMessage]:
|
||||
messages = [
|
||||
LLMMessage("system", rdf_triple_extract_system_prompt).to_dict(),
|
||||
LLMMessage("user", f"""段落:\n```\n{paragraph}```\n\n实体列表:\n```\n{entities}```""").to_dict(),
|
||||
]
|
||||
return messages
|
||||
def build_rdf_triple_extract_context(paragraph: str, entities: str) -> str:
|
||||
"""构建RDF三元组提取的完整提示文本"""
|
||||
return f"""{rdf_triple_extract_system_prompt}
|
||||
|
||||
段落:
|
||||
```
|
||||
{paragraph}
|
||||
```
|
||||
|
||||
实体列表:
|
||||
```
|
||||
{entities}
|
||||
```"""
|
||||
|
||||
|
||||
qa_system_prompt = """
|
||||
|
||||
@@ -5,11 +5,13 @@ from .global_logger import logger
|
||||
|
||||
# from . import prompt_template
|
||||
from .embedding_store import EmbeddingManager
|
||||
from .llm_client import LLMClient
|
||||
# from .llm_client import LLMClient
|
||||
from .kg_manager import KGManager
|
||||
from .lpmmconfig import global_config
|
||||
# from .lpmmconfig import global_config
|
||||
from .utils.dyn_topk import dyn_select_top_k
|
||||
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.chat.utils.utils import get_embedding
|
||||
from src.config.config import global_config
|
||||
|
||||
MAX_KNOWLEDGE_LENGTH = 10000 # 最大知识长度
|
||||
|
||||
@@ -19,26 +21,25 @@ class QAManager:
|
||||
self,
|
||||
embed_manager: EmbeddingManager,
|
||||
kg_manager: KGManager,
|
||||
llm_client_embedding: LLMClient,
|
||||
llm_client_filter: LLMClient,
|
||||
llm_client_qa: LLMClient,
|
||||
|
||||
):
|
||||
self.embed_manager = embed_manager
|
||||
self.kg_manager = kg_manager
|
||||
self.llm_client_list = {
|
||||
"embedding": llm_client_embedding,
|
||||
"message_filter": llm_client_filter,
|
||||
"qa": llm_client_qa,
|
||||
}
|
||||
# TODO: API-Adapter修改标记
|
||||
self.qa_model = LLMRequest(
|
||||
model=global_config.model.lpmm_qa,
|
||||
request_type="lpmm.qa"
|
||||
)
|
||||
|
||||
def process_query(self, question: str) -> Tuple[List[Tuple[str, float, float]], Optional[Dict[str, float]]]:
|
||||
async def process_query(self, question: str) -> Tuple[List[Tuple[str, float, float]], Optional[Dict[str, float]]]:
|
||||
"""处理查询"""
|
||||
|
||||
# 生成问题的Embedding
|
||||
part_start_time = time.perf_counter()
|
||||
question_embedding = self.llm_client_list["embedding"].send_embedding_request(
|
||||
global_config["embedding"]["model"], question
|
||||
)
|
||||
question_embedding = await get_embedding(question)
|
||||
if question_embedding is None:
|
||||
logger.error("生成问题Embedding失败")
|
||||
return None
|
||||
part_end_time = time.perf_counter()
|
||||
logger.debug(f"Embedding用时:{part_end_time - part_start_time:.5f}s")
|
||||
|
||||
@@ -46,14 +47,15 @@ class QAManager:
|
||||
part_start_time = time.perf_counter()
|
||||
relation_search_res = self.embed_manager.relation_embedding_store.search_top_k(
|
||||
question_embedding,
|
||||
global_config["qa"]["params"]["relation_search_top_k"],
|
||||
global_config.lpmm_knowledge.qa_relation_search_top_k,
|
||||
)
|
||||
if relation_search_res is not None:
|
||||
# 过滤阈值
|
||||
# 考虑动态阈值:当存在显著数值差异的结果时,保留显著结果;否则,保留所有结果
|
||||
relation_search_res = dyn_select_top_k(relation_search_res, 0.5, 1.0)
|
||||
if relation_search_res[0][1] < global_config["qa"]["params"]["relation_threshold"]:
|
||||
if relation_search_res[0][1] < global_config.lpmm_knowledge.qa_relation_threshold:
|
||||
# 未找到相关关系
|
||||
logger.debug("未找到相关关系,跳过关系检索")
|
||||
relation_search_res = []
|
||||
|
||||
part_end_time = time.perf_counter()
|
||||
@@ -71,7 +73,7 @@ class QAManager:
|
||||
part_start_time = time.perf_counter()
|
||||
paragraph_search_res = self.embed_manager.paragraphs_embedding_store.search_top_k(
|
||||
question_embedding,
|
||||
global_config["qa"]["params"]["paragraph_search_top_k"],
|
||||
global_config.lpmm_knowledge.qa_paragraph_search_top_k,
|
||||
)
|
||||
part_end_time = time.perf_counter()
|
||||
logger.debug(f"文段检索用时:{part_end_time - part_start_time:.5f}s")
|
||||
@@ -101,10 +103,10 @@ class QAManager:
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_knowledge(self, question: str) -> str:
|
||||
async def get_knowledge(self, question: str) -> str:
|
||||
"""获取知识"""
|
||||
# 处理查询
|
||||
processed_result = self.process_query(question)
|
||||
processed_result = await self.process_query(question)
|
||||
if processed_result is not None:
|
||||
query_res = processed_result[0]
|
||||
knowledge = [
|
||||
|
||||
@@ -42,7 +42,7 @@ def calculate_information_content(text):
|
||||
return entropy
|
||||
|
||||
|
||||
def cosine_similarity(v1, v2):
|
||||
def cosine_similarity(v1, v2): # sourcery skip: assign-if-exp, reintroduce-else
|
||||
"""计算余弦相似度"""
|
||||
dot_product = np.dot(v1, v2)
|
||||
norm1 = np.linalg.norm(v1)
|
||||
@@ -89,14 +89,13 @@ class MemoryGraph:
|
||||
if not isinstance(self.G.nodes[concept]["memory_items"], list):
|
||||
self.G.nodes[concept]["memory_items"] = [self.G.nodes[concept]["memory_items"]]
|
||||
self.G.nodes[concept]["memory_items"].append(memory)
|
||||
# 更新最后修改时间
|
||||
self.G.nodes[concept]["last_modified"] = current_time
|
||||
else:
|
||||
self.G.nodes[concept]["memory_items"] = [memory]
|
||||
# 如果节点存在但没有memory_items,说明是第一次添加memory,设置created_time
|
||||
if "created_time" not in self.G.nodes[concept]:
|
||||
self.G.nodes[concept]["created_time"] = current_time
|
||||
self.G.nodes[concept]["last_modified"] = current_time
|
||||
# 更新最后修改时间
|
||||
self.G.nodes[concept]["last_modified"] = current_time
|
||||
else:
|
||||
# 如果是新节点,创建新的记忆列表
|
||||
self.G.add_node(
|
||||
@@ -108,11 +107,7 @@ class MemoryGraph:
|
||||
|
||||
def get_dot(self, concept):
|
||||
# 检查节点是否存在于图中
|
||||
if concept in self.G:
|
||||
# 从图中获取节点数据
|
||||
node_data = self.G.nodes[concept]
|
||||
return concept, node_data
|
||||
return None
|
||||
return (concept, self.G.nodes[concept]) if concept in self.G else None
|
||||
|
||||
def get_related_item(self, topic, depth=1):
|
||||
if topic not in self.G:
|
||||
@@ -139,8 +134,7 @@ class MemoryGraph:
|
||||
if depth >= 2:
|
||||
# 获取相邻节点的记忆项
|
||||
for neighbor in neighbors:
|
||||
node_data = self.get_dot(neighbor)
|
||||
if node_data:
|
||||
if node_data := self.get_dot(neighbor):
|
||||
concept, data = node_data
|
||||
if "memory_items" in data:
|
||||
memory_items = data["memory_items"]
|
||||
@@ -194,9 +188,9 @@ class MemoryGraph:
|
||||
class Hippocampus:
|
||||
def __init__(self):
|
||||
self.memory_graph = MemoryGraph()
|
||||
self.model_summary = None
|
||||
self.entorhinal_cortex = None
|
||||
self.parahippocampal_gyrus = None
|
||||
self.model_summary: LLMRequest = None # type: ignore
|
||||
self.entorhinal_cortex: EntorhinalCortex = None # type: ignore
|
||||
self.parahippocampal_gyrus: ParahippocampalGyrus = None # type: ignore
|
||||
|
||||
def initialize(self):
|
||||
# 初始化子组件
|
||||
@@ -205,7 +199,7 @@ class Hippocampus:
|
||||
# 从数据库加载记忆图
|
||||
self.entorhinal_cortex.sync_memory_from_db()
|
||||
# TODO: API-Adapter修改标记
|
||||
self.model_summary = LLMRequest(global_config.model.memory, request_type="memory")
|
||||
self.model_summary = LLMRequest(global_config.model.memory, request_type="memory.builder")
|
||||
|
||||
def get_all_node_names(self) -> list:
|
||||
"""获取记忆图中所有节点的名字列表"""
|
||||
@@ -218,7 +212,7 @@ class Hippocampus:
|
||||
memory_items = [memory_items] if memory_items else []
|
||||
|
||||
# 使用集合来去重,避免排序
|
||||
unique_items = set(str(item) for item in memory_items)
|
||||
unique_items = {str(item) for item in memory_items}
|
||||
# 使用frozenset来保证顺序一致性
|
||||
content = f"{concept}:{frozenset(unique_items)}"
|
||||
return hash(content)
|
||||
@@ -231,6 +225,7 @@ class Hippocampus:
|
||||
|
||||
@staticmethod
|
||||
def find_topic_llm(text, topic_num):
|
||||
# sourcery skip: inline-immediately-returned-variable
|
||||
prompt = (
|
||||
f"这是一段文字:\n{text}\n\n请你从这段话中总结出最多{topic_num}个关键的概念,可以是名词,动词,或者特定人物,帮我列出来,"
|
||||
f"将主题用逗号隔开,并加上<>,例如<主题1>,<主题2>......尽可能精简。只需要列举最多{topic_num}个话题就好,不要有序号,不要告诉我其他内容。"
|
||||
@@ -240,6 +235,7 @@ class Hippocampus:
|
||||
|
||||
@staticmethod
|
||||
def topic_what(text, topic):
|
||||
# sourcery skip: inline-immediately-returned-variable
|
||||
# 不再需要 time_info 参数
|
||||
prompt = (
|
||||
f'这是一段文字:\n{text}\n\n我想让你基于这段文字来概括"{topic}"这个概念,帮我总结成一句自然的话,'
|
||||
@@ -480,9 +476,7 @@ class Hippocampus:
|
||||
top_memories = memory_similarities[:max_memory_length]
|
||||
|
||||
# 添加到结果中
|
||||
for memory, similarity in top_memories:
|
||||
all_memories.append((node, [memory], similarity))
|
||||
# logger.info(f"选中记忆: {memory} (相似度: {similarity:.2f})")
|
||||
all_memories.extend((node, [memory], similarity) for memory, similarity in top_memories)
|
||||
else:
|
||||
logger.info("节点没有记忆")
|
||||
|
||||
@@ -646,9 +640,7 @@ class Hippocampus:
|
||||
top_memories = memory_similarities[:max_memory_length]
|
||||
|
||||
# 添加到结果中
|
||||
for memory, similarity in top_memories:
|
||||
all_memories.append((node, [memory], similarity))
|
||||
# logger.info(f"选中记忆: {memory} (相似度: {similarity:.2f})")
|
||||
all_memories.extend((node, [memory], similarity) for memory, similarity in top_memories)
|
||||
else:
|
||||
logger.info("节点没有记忆")
|
||||
|
||||
@@ -819,15 +811,15 @@ class EntorhinalCortex:
|
||||
timestamps = sample_scheduler.get_timestamp_array()
|
||||
# 使用 translate_timestamp_to_human_readable 并指定 mode="normal"
|
||||
readable_timestamps = [translate_timestamp_to_human_readable(ts, mode="normal") for ts in timestamps]
|
||||
for _, readable_timestamp in zip(timestamps, readable_timestamps):
|
||||
for _, readable_timestamp in zip(timestamps, readable_timestamps, strict=False):
|
||||
logger.debug(f"回忆往事: {readable_timestamp}")
|
||||
chat_samples = []
|
||||
for timestamp in timestamps:
|
||||
# 调用修改后的 random_get_msg_snippet
|
||||
messages = self.random_get_msg_snippet(
|
||||
timestamp, global_config.memory.memory_build_sample_length, max_memorized_time_per_msg
|
||||
)
|
||||
if messages:
|
||||
if messages := self.random_get_msg_snippet(
|
||||
timestamp,
|
||||
global_config.memory.memory_build_sample_length,
|
||||
max_memorized_time_per_msg,
|
||||
):
|
||||
time_diff = (datetime.datetime.now().timestamp() - timestamp) / 3600
|
||||
logger.info(f"成功抽取 {time_diff:.1f} 小时前的消息样本,共{len(messages)}条")
|
||||
chat_samples.append(messages)
|
||||
@@ -838,31 +830,30 @@ class EntorhinalCortex:
|
||||
|
||||
@staticmethod
|
||||
def random_get_msg_snippet(target_timestamp: float, chat_size: int, max_memorized_time_per_msg: int) -> list | None:
|
||||
# sourcery skip: invert-any-all, use-any, use-named-expression, use-next
|
||||
"""从数据库中随机获取指定时间戳附近的消息片段 (使用 chat_message_builder)"""
|
||||
try_count = 0
|
||||
time_window_seconds = random.randint(300, 1800) # 随机时间窗口,5到30分钟
|
||||
|
||||
while try_count < 3:
|
||||
for _ in range(3):
|
||||
# 定义时间范围:从目标时间戳开始,向后推移 time_window_seconds
|
||||
timestamp_start = target_timestamp
|
||||
timestamp_end = target_timestamp + time_window_seconds
|
||||
|
||||
chosen_message = get_raw_msg_by_timestamp(
|
||||
timestamp_start=timestamp_start, timestamp_end=timestamp_end, limit=1, limit_mode="earliest"
|
||||
)
|
||||
if chosen_message := get_raw_msg_by_timestamp(
|
||||
timestamp_start=timestamp_start,
|
||||
timestamp_end=timestamp_end,
|
||||
limit=1,
|
||||
limit_mode="earliest",
|
||||
):
|
||||
chat_id: str = chosen_message[0].get("chat_id") # type: ignore
|
||||
|
||||
if chosen_message:
|
||||
chat_id = chosen_message[0].get("chat_id")
|
||||
|
||||
messages = get_raw_msg_by_timestamp_with_chat(
|
||||
if messages := get_raw_msg_by_timestamp_with_chat(
|
||||
timestamp_start=timestamp_start,
|
||||
timestamp_end=timestamp_end,
|
||||
limit=chat_size,
|
||||
limit_mode="earliest",
|
||||
chat_id=chat_id,
|
||||
)
|
||||
|
||||
if messages:
|
||||
):
|
||||
# 检查获取到的所有消息是否都未达到最大记忆次数
|
||||
all_valid = True
|
||||
for message in messages:
|
||||
@@ -882,8 +873,6 @@ class EntorhinalCortex:
|
||||
).execute()
|
||||
return messages # 直接返回原始的消息列表
|
||||
|
||||
# 如果获取失败或消息无效,增加尝试次数
|
||||
try_count += 1
|
||||
target_timestamp -= 120 # 如果第一次尝试失败,稍微向前调整时间戳再试
|
||||
|
||||
# 三次尝试都失败,返回 None
|
||||
@@ -975,7 +964,7 @@ class EntorhinalCortex:
|
||||
).execute()
|
||||
|
||||
if nodes_to_delete:
|
||||
GraphNodes.delete().where(GraphNodes.concept.in_(nodes_to_delete)).execute()
|
||||
GraphNodes.delete().where(GraphNodes.concept.in_(nodes_to_delete)).execute() # type: ignore
|
||||
|
||||
# 处理边的信息
|
||||
db_edges = list(GraphEdges.select())
|
||||
@@ -1075,19 +1064,17 @@ class EntorhinalCortex:
|
||||
|
||||
try:
|
||||
memory_items = [str(item) for item in memory_items]
|
||||
memory_items_json = json.dumps(memory_items, ensure_ascii=False)
|
||||
if not memory_items_json:
|
||||
continue
|
||||
if memory_items_json := json.dumps(memory_items, ensure_ascii=False):
|
||||
nodes_data.append(
|
||||
{
|
||||
"concept": concept,
|
||||
"memory_items": memory_items_json,
|
||||
"hash": self.hippocampus.calculate_node_hash(concept, memory_items),
|
||||
"created_time": data.get("created_time", current_time),
|
||||
"last_modified": data.get("last_modified", current_time),
|
||||
}
|
||||
)
|
||||
|
||||
nodes_data.append(
|
||||
{
|
||||
"concept": concept,
|
||||
"memory_items": memory_items_json,
|
||||
"hash": self.hippocampus.calculate_node_hash(concept, memory_items),
|
||||
"created_time": data.get("created_time", current_time),
|
||||
"last_modified": data.get("last_modified", current_time),
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"准备节点 {concept} 数据时发生错误: {e}")
|
||||
continue
|
||||
@@ -1114,7 +1101,7 @@ class EntorhinalCortex:
|
||||
node_start = time.time()
|
||||
if nodes_data:
|
||||
batch_size = 500 # 增加批量大小
|
||||
with GraphNodes._meta.database.atomic():
|
||||
with GraphNodes._meta.database.atomic(): # type: ignore
|
||||
for i in range(0, len(nodes_data), batch_size):
|
||||
batch = nodes_data[i : i + batch_size]
|
||||
GraphNodes.insert_many(batch).execute()
|
||||
@@ -1125,7 +1112,7 @@ class EntorhinalCortex:
|
||||
edge_start = time.time()
|
||||
if edges_data:
|
||||
batch_size = 500 # 增加批量大小
|
||||
with GraphEdges._meta.database.atomic():
|
||||
with GraphEdges._meta.database.atomic(): # type: ignore
|
||||
for i in range(0, len(edges_data), batch_size):
|
||||
batch = edges_data[i : i + batch_size]
|
||||
GraphEdges.insert_many(batch).execute()
|
||||
@@ -1279,7 +1266,7 @@ class ParahippocampalGyrus:
|
||||
|
||||
# 3. 过滤掉包含禁用关键词的topic
|
||||
filtered_topics = [
|
||||
topic for topic in topics if not any(keyword in topic for keyword in global_config.memory.memory_ban_words)
|
||||
topic for topic in topics if all(keyword not in topic for keyword in global_config.memory.memory_ban_words)
|
||||
]
|
||||
|
||||
logger.debug(f"过滤后话题: {filtered_topics}")
|
||||
@@ -1489,32 +1476,30 @@ class ParahippocampalGyrus:
|
||||
# --- 如果节点不为空,则执行原来的不活跃检查和随机移除逻辑 ---
|
||||
last_modified = node_data.get("last_modified", current_time)
|
||||
# 条件1:检查是否长时间未修改 (超过24小时)
|
||||
if current_time - last_modified > 3600 * 24:
|
||||
# 条件2:再次确认节点包含记忆项(理论上已确认,但作为保险)
|
||||
if memory_items:
|
||||
current_count = len(memory_items)
|
||||
# 如果列表非空,才进行随机选择
|
||||
if current_count > 0:
|
||||
removed_item = random.choice(memory_items)
|
||||
try:
|
||||
memory_items.remove(removed_item)
|
||||
if current_time - last_modified > 3600 * 24 and memory_items:
|
||||
current_count = len(memory_items)
|
||||
# 如果列表非空,才进行随机选择
|
||||
if current_count > 0:
|
||||
removed_item = random.choice(memory_items)
|
||||
try:
|
||||
memory_items.remove(removed_item)
|
||||
|
||||
# 条件3:检查移除后 memory_items 是否变空
|
||||
if memory_items: # 如果移除后列表不为空
|
||||
# self.memory_graph.G.nodes[node]["memory_items"] = memory_items # 直接修改列表即可
|
||||
self.memory_graph.G.nodes[node]["last_modified"] = current_time # 更新修改时间
|
||||
node_changes["reduced"].append(f"{node} (数量: {current_count} -> {len(memory_items)})")
|
||||
else: # 如果移除后列表为空
|
||||
# 尝试移除节点,处理可能的错误
|
||||
try:
|
||||
self.memory_graph.G.remove_node(node)
|
||||
node_changes["removed"].append(f"{node}(遗忘清空)") # 标记为遗忘清空
|
||||
logger.debug(f"[遗忘] 节点 {node} 因移除最后一项而被清空。")
|
||||
except nx.NetworkXError as e:
|
||||
logger.warning(f"[遗忘] 尝试移除节点 {node} 时发生错误(可能已被移除):{e}")
|
||||
except ValueError:
|
||||
# 这个错误理论上不应发生,因为 removed_item 来自 memory_items
|
||||
logger.warning(f"[遗忘] 尝试从节点 '{node}' 移除不存在的项目 '{removed_item[:30]}...'")
|
||||
# 条件3:检查移除后 memory_items 是否变空
|
||||
if memory_items: # 如果移除后列表不为空
|
||||
# self.memory_graph.G.nodes[node]["memory_items"] = memory_items # 直接修改列表即可
|
||||
self.memory_graph.G.nodes[node]["last_modified"] = current_time # 更新修改时间
|
||||
node_changes["reduced"].append(f"{node} (数量: {current_count} -> {len(memory_items)})")
|
||||
else: # 如果移除后列表为空
|
||||
# 尝试移除节点,处理可能的错误
|
||||
try:
|
||||
self.memory_graph.G.remove_node(node)
|
||||
node_changes["removed"].append(f"{node}(遗忘清空)") # 标记为遗忘清空
|
||||
logger.debug(f"[遗忘] 节点 {node} 因移除最后一项而被清空。")
|
||||
except nx.NetworkXError as e:
|
||||
logger.warning(f"[遗忘] 尝试移除节点 {node} 时发生错误(可能已被移除):{e}")
|
||||
except ValueError:
|
||||
# 这个错误理论上不应发生,因为 removed_item 来自 memory_items
|
||||
logger.warning(f"[遗忘] 尝试从节点 '{node}' 移除不存在的项目 '{removed_item[:30]}...'")
|
||||
node_check_end = time.time()
|
||||
logger.info(f"[遗忘] 节点检查耗时: {node_check_end - node_check_start:.2f}秒")
|
||||
|
||||
@@ -1669,7 +1654,7 @@ class ParahippocampalGyrus:
|
||||
|
||||
class HippocampusManager:
|
||||
def __init__(self):
|
||||
self._hippocampus = None
|
||||
self._hippocampus: Hippocampus = None # type: ignore
|
||||
self._initialized = False
|
||||
|
||||
def initialize(self):
|
||||
@@ -1686,7 +1671,8 @@ class HippocampusManager:
|
||||
node_count = len(memory_graph.nodes())
|
||||
edge_count = len(memory_graph.edges())
|
||||
|
||||
logger.info(f"""--------------------------------
|
||||
logger.info(f"""
|
||||
--------------------------------
|
||||
记忆系统参数配置:
|
||||
构建间隔: {global_config.memory.memory_build_interval}秒|样本数: {global_config.memory.memory_build_sample_num},长度: {global_config.memory.memory_build_sample_length}|压缩率: {global_config.memory.memory_compress_rate}
|
||||
记忆构建分布: {global_config.memory.memory_build_distribution}
|
||||
|
||||
256
src/chat/memory_system/instant_memory.py
Normal file
256
src/chat/memory_system/instant_memory.py
Normal file
@@ -0,0 +1,256 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import time
|
||||
import re
|
||||
import json
|
||||
import ast
|
||||
from json_repair import repair_json
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.common.logger import get_logger
|
||||
import traceback
|
||||
|
||||
from src.config.config import global_config
|
||||
from src.common.database.database_model import Memory # Peewee Models导入
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class MemoryItem:
|
||||
def __init__(self, memory_id: str, chat_id: str, memory_text: str, keywords: list[str]):
|
||||
self.memory_id = memory_id
|
||||
self.chat_id = chat_id
|
||||
self.memory_text: str = memory_text
|
||||
self.keywords: list[str] = keywords
|
||||
self.create_time: float = time.time()
|
||||
self.last_view_time: float = time.time()
|
||||
|
||||
|
||||
class MemoryManager:
|
||||
def __init__(self):
|
||||
# self.memory_items:list[MemoryItem] = []
|
||||
pass
|
||||
|
||||
|
||||
class InstantMemory:
|
||||
def __init__(self, chat_id):
|
||||
self.chat_id = chat_id
|
||||
self.last_view_time = time.time()
|
||||
self.summary_model = LLMRequest(
|
||||
model=global_config.model.memory,
|
||||
temperature=0.5,
|
||||
request_type="memory.summary",
|
||||
)
|
||||
|
||||
async def if_need_build(self, text):
|
||||
prompt = f"""
|
||||
请判断以下内容中是否有值得记忆的信息,如果有,请输出1,否则输出0
|
||||
{text}
|
||||
请只输出1或0就好
|
||||
"""
|
||||
|
||||
try:
|
||||
response, _ = await self.summary_model.generate_response_async(prompt)
|
||||
print(prompt)
|
||||
print(response)
|
||||
|
||||
if "1" in response:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"判断是否需要记忆出现错误:{str(e)} {traceback.format_exc()}")
|
||||
return False
|
||||
|
||||
async def build_memory(self, text):
|
||||
prompt = f"""
|
||||
以下内容中存在值得记忆的信息,请你从中总结出一段值得记忆的信息,并输出
|
||||
{text}
|
||||
请以json格式输出一段概括的记忆内容和关键词
|
||||
{{
|
||||
"memory_text": "记忆内容",
|
||||
"keywords": "关键词,用/划分"
|
||||
}}
|
||||
"""
|
||||
try:
|
||||
response, _ = await self.summary_model.generate_response_async(prompt)
|
||||
print(prompt)
|
||||
print(response)
|
||||
if not response:
|
||||
return None
|
||||
try:
|
||||
repaired = repair_json(response)
|
||||
result = json.loads(repaired)
|
||||
memory_text = result.get("memory_text", "")
|
||||
keywords = result.get("keywords", "")
|
||||
if isinstance(keywords, str):
|
||||
keywords_list = [k.strip() for k in keywords.split("/") if k.strip()]
|
||||
elif isinstance(keywords, list):
|
||||
keywords_list = keywords
|
||||
else:
|
||||
keywords_list = []
|
||||
return {"memory_text": memory_text, "keywords": keywords_list}
|
||||
except Exception as parse_e:
|
||||
logger.error(f"解析记忆json失败:{str(parse_e)} {traceback.format_exc()}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"构建记忆出现错误:{str(e)} {traceback.format_exc()}")
|
||||
return None
|
||||
|
||||
async def create_and_store_memory(self, text):
|
||||
if_need = await self.if_need_build(text)
|
||||
if if_need:
|
||||
logger.info(f"需要记忆:{text}")
|
||||
memory = await self.build_memory(text)
|
||||
if memory and memory.get("memory_text"):
|
||||
memory_id = f"{self.chat_id}_{time.time()}"
|
||||
memory_item = MemoryItem(
|
||||
memory_id=memory_id,
|
||||
chat_id=self.chat_id,
|
||||
memory_text=memory["memory_text"],
|
||||
keywords=memory.get("keywords", []),
|
||||
)
|
||||
await self.store_memory(memory_item)
|
||||
else:
|
||||
logger.info(f"不需要记忆:{text}")
|
||||
|
||||
async def store_memory(self, memory_item: MemoryItem):
|
||||
memory = Memory(
|
||||
memory_id=memory_item.memory_id,
|
||||
chat_id=memory_item.chat_id,
|
||||
memory_text=memory_item.memory_text,
|
||||
keywords=memory_item.keywords,
|
||||
create_time=memory_item.create_time,
|
||||
last_view_time=memory_item.last_view_time,
|
||||
)
|
||||
memory.save()
|
||||
|
||||
async def get_memory(self, target: str):
|
||||
from json_repair import repair_json
|
||||
|
||||
prompt = f"""
|
||||
请根据以下发言内容,判断是否需要提取记忆
|
||||
{target}
|
||||
请用json格式输出,包含以下字段:
|
||||
其中,time的要求是:
|
||||
可以选择具体日期时间,格式为YYYY-MM-DD HH:MM:SS,或者大致时间,格式为YYYY-MM-DD
|
||||
可以选择相对时间,例如:今天,昨天,前天,5天前,1个月前
|
||||
可以选择留空进行模糊搜索
|
||||
{{
|
||||
"need_memory": 1,
|
||||
"keywords": "希望获取的记忆关键词,用/划分",
|
||||
"time": "希望获取的记忆大致时间"
|
||||
}}
|
||||
请只输出json格式,不要输出其他多余内容
|
||||
"""
|
||||
try:
|
||||
response, _ = await self.summary_model.generate_response_async(prompt)
|
||||
print(prompt)
|
||||
print(response)
|
||||
if not response:
|
||||
return None
|
||||
try:
|
||||
repaired = repair_json(response)
|
||||
result = json.loads(repaired)
|
||||
# 解析keywords
|
||||
keywords = result.get("keywords", "")
|
||||
if isinstance(keywords, str):
|
||||
keywords_list = [k.strip() for k in keywords.split("/") if k.strip()]
|
||||
elif isinstance(keywords, list):
|
||||
keywords_list = keywords
|
||||
else:
|
||||
keywords_list = []
|
||||
# 解析time为时间段
|
||||
time_str = result.get("time", "").strip()
|
||||
start_time, end_time = self._parse_time_range(time_str)
|
||||
logger.info(f"start_time: {start_time}, end_time: {end_time}")
|
||||
# 检索包含关键词的记忆
|
||||
memories_set = set()
|
||||
if start_time and end_time:
|
||||
start_ts = start_time.timestamp()
|
||||
end_ts = end_time.timestamp()
|
||||
query = Memory.select().where(
|
||||
(Memory.chat_id == self.chat_id)
|
||||
& (Memory.create_time >= start_ts) # type: ignore
|
||||
& (Memory.create_time < end_ts) # type: ignore
|
||||
)
|
||||
else:
|
||||
query = Memory.select().where(Memory.chat_id == self.chat_id)
|
||||
|
||||
for mem in query:
|
||||
# 对每条记忆
|
||||
mem_keywords = mem.keywords or []
|
||||
parsed = ast.literal_eval(mem_keywords)
|
||||
if isinstance(parsed, list):
|
||||
mem_keywords = [str(k).strip() for k in parsed if str(k).strip()]
|
||||
else:
|
||||
mem_keywords = []
|
||||
# logger.info(f"mem_keywords: {mem_keywords}")
|
||||
# logger.info(f"keywords_list: {keywords_list}")
|
||||
for kw in keywords_list:
|
||||
# logger.info(f"kw: {kw}")
|
||||
# logger.info(f"kw in mem_keywords: {kw in mem_keywords}")
|
||||
if kw in mem_keywords:
|
||||
# logger.info(f"mem.memory_text: {mem.memory_text}")
|
||||
memories_set.add(mem.memory_text)
|
||||
break
|
||||
return list(memories_set)
|
||||
except Exception as parse_e:
|
||||
logger.error(f"解析记忆json失败:{str(parse_e)} {traceback.format_exc()}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"获取记忆出现错误:{str(e)} {traceback.format_exc()}")
|
||||
return None
|
||||
|
||||
def _parse_time_range(self, time_str):
|
||||
"""
|
||||
支持解析如下格式:
|
||||
- 具体日期时间:YYYY-MM-DD HH:MM:SS
|
||||
- 具体日期:YYYY-MM-DD
|
||||
- 相对时间:今天,昨天,前天,N天前,N个月前
|
||||
- 空字符串:返回(None, None)
|
||||
"""
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
now = datetime.now()
|
||||
if not time_str:
|
||||
return 0, now
|
||||
time_str = time_str.strip()
|
||||
# 具体日期时间
|
||||
try:
|
||||
dt = datetime.strptime(time_str, "%Y-%m-%d %H:%M:%S")
|
||||
return dt, dt + timedelta(hours=1)
|
||||
except Exception:
|
||||
pass
|
||||
# 具体日期
|
||||
try:
|
||||
dt = datetime.strptime(time_str, "%Y-%m-%d")
|
||||
return dt, dt + timedelta(days=1)
|
||||
except Exception:
|
||||
pass
|
||||
# 相对时间
|
||||
if time_str == "今天":
|
||||
start = now.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
end = start + timedelta(days=1)
|
||||
return start, end
|
||||
if time_str == "昨天":
|
||||
start = (now - timedelta(days=1)).replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
end = start + timedelta(days=1)
|
||||
return start, end
|
||||
if time_str == "前天":
|
||||
start = (now - timedelta(days=2)).replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
end = start + timedelta(days=1)
|
||||
return start, end
|
||||
m = re.match(r"(\d+)天前", time_str)
|
||||
if m:
|
||||
days = int(m.group(1))
|
||||
start = (now - timedelta(days=days)).replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
end = start + timedelta(days=1)
|
||||
return start, end
|
||||
m = re.match(r"(\d+)个月前", time_str)
|
||||
if m:
|
||||
months = int(m.group(1))
|
||||
# 近似每月30天
|
||||
start = (now - timedelta(days=months * 30)).replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
end = start + timedelta(days=1)
|
||||
return start, end
|
||||
# 其他无法解析
|
||||
return 0, now
|
||||
@@ -13,7 +13,7 @@ from json_repair import repair_json
|
||||
logger = get_logger("memory_activator")
|
||||
|
||||
|
||||
def get_keywords_from_json(json_str):
|
||||
def get_keywords_from_json(json_str) -> List:
|
||||
"""
|
||||
从JSON字符串中提取关键词列表
|
||||
|
||||
@@ -28,15 +28,8 @@ def get_keywords_from_json(json_str):
|
||||
fixed_json = repair_json(json_str)
|
||||
|
||||
# 如果repair_json返回的是字符串,需要解析为Python对象
|
||||
if isinstance(fixed_json, str):
|
||||
result = json.loads(fixed_json)
|
||||
else:
|
||||
# 如果repair_json直接返回了字典对象,直接使用
|
||||
result = fixed_json
|
||||
|
||||
# 提取关键词
|
||||
keywords = result.get("keywords", [])
|
||||
return keywords
|
||||
result = json.loads(fixed_json) if isinstance(fixed_json, str) else fixed_json
|
||||
return result.get("keywords", [])
|
||||
except Exception as e:
|
||||
logger.error(f"解析关键词JSON失败: {e}")
|
||||
return []
|
||||
@@ -73,7 +66,7 @@ class MemoryActivator:
|
||||
self.key_words_model = LLMRequest(
|
||||
model=global_config.model.utils_small,
|
||||
temperature=0.5,
|
||||
request_type="memory_activator",
|
||||
request_type="memory.activator",
|
||||
)
|
||||
|
||||
self.running_memory = []
|
||||
|
||||
@@ -1,52 +1,10 @@
|
||||
import numpy as np
|
||||
from scipy import stats
|
||||
from datetime import datetime, timedelta
|
||||
from rich.traceback import install
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
|
||||
class DistributionVisualizer:
|
||||
def __init__(self, mean=0, std=1, skewness=0, sample_size=10):
|
||||
"""
|
||||
初始化分布可视化器
|
||||
|
||||
参数:
|
||||
mean (float): 期望均值
|
||||
std (float): 标准差
|
||||
skewness (float): 偏度
|
||||
sample_size (int): 样本大小
|
||||
"""
|
||||
self.mean = mean
|
||||
self.std = std
|
||||
self.skewness = skewness
|
||||
self.sample_size = sample_size
|
||||
self.samples = None
|
||||
|
||||
def generate_samples(self):
|
||||
"""生成具有指定参数的样本"""
|
||||
if self.skewness == 0:
|
||||
# 对于无偏度的情况,直接使用正态分布
|
||||
self.samples = np.random.normal(loc=self.mean, scale=self.std, size=self.sample_size)
|
||||
else:
|
||||
# 使用 scipy.stats 生成具有偏度的分布
|
||||
self.samples = stats.skewnorm.rvs(a=self.skewness, loc=self.mean, scale=self.std, size=self.sample_size)
|
||||
|
||||
def get_weighted_samples(self):
|
||||
"""获取加权后的样本数列"""
|
||||
if self.samples is None:
|
||||
self.generate_samples()
|
||||
# 将样本值乘以样本大小
|
||||
return self.samples * self.sample_size
|
||||
|
||||
def get_statistics(self):
|
||||
"""获取分布的统计信息"""
|
||||
if self.samples is None:
|
||||
self.generate_samples()
|
||||
|
||||
return {"均值": np.mean(self.samples), "标准差": np.std(self.samples), "实际偏度": stats.skew(self.samples)}
|
||||
|
||||
|
||||
class MemoryBuildScheduler:
|
||||
def __init__(self, n_hours1, std_hours1, weight1, n_hours2, std_hours2, weight2, total_samples=50):
|
||||
"""
|
||||
@@ -108,61 +66,61 @@ class MemoryBuildScheduler:
|
||||
return [int(t.timestamp()) for t in timestamps]
|
||||
|
||||
|
||||
def print_time_samples(timestamps, show_distribution=True):
|
||||
"""打印时间样本和分布信息"""
|
||||
print(f"\n生成的{len(timestamps)}个时间点分布:")
|
||||
print("序号".ljust(5), "时间戳".ljust(25), "距现在(小时)")
|
||||
print("-" * 50)
|
||||
# def print_time_samples(timestamps, show_distribution=True):
|
||||
# """打印时间样本和分布信息"""
|
||||
# print(f"\n生成的{len(timestamps)}个时间点分布:")
|
||||
# print("序号".ljust(5), "时间戳".ljust(25), "距现在(小时)")
|
||||
# print("-" * 50)
|
||||
|
||||
now = datetime.now()
|
||||
time_diffs = []
|
||||
# now = datetime.now()
|
||||
# time_diffs = []
|
||||
|
||||
for i, timestamp in enumerate(timestamps, 1):
|
||||
hours_diff = (now - timestamp).total_seconds() / 3600
|
||||
time_diffs.append(hours_diff)
|
||||
print(f"{str(i).ljust(5)} {timestamp.strftime('%Y-%m-%d %H:%M:%S').ljust(25)} {hours_diff:.2f}")
|
||||
# for i, timestamp in enumerate(timestamps, 1):
|
||||
# hours_diff = (now - timestamp).total_seconds() / 3600
|
||||
# time_diffs.append(hours_diff)
|
||||
# print(f"{str(i).ljust(5)} {timestamp.strftime('%Y-%m-%d %H:%M:%S').ljust(25)} {hours_diff:.2f}")
|
||||
|
||||
# 打印统计信息
|
||||
print("\n统计信息:")
|
||||
print(f"平均时间偏移:{np.mean(time_diffs):.2f}小时")
|
||||
print(f"标准差:{np.std(time_diffs):.2f}小时")
|
||||
print(f"最早时间:{min(timestamps).strftime('%Y-%m-%d %H:%M:%S')} ({max(time_diffs):.2f}小时前)")
|
||||
print(f"最近时间:{max(timestamps).strftime('%Y-%m-%d %H:%M:%S')} ({min(time_diffs):.2f}小时前)")
|
||||
# # 打印统计信息
|
||||
# print("\n统计信息:")
|
||||
# print(f"平均时间偏移:{np.mean(time_diffs):.2f}小时")
|
||||
# print(f"标准差:{np.std(time_diffs):.2f}小时")
|
||||
# print(f"最早时间:{min(timestamps).strftime('%Y-%m-%d %H:%M:%S')} ({max(time_diffs):.2f}小时前)")
|
||||
# print(f"最近时间:{max(timestamps).strftime('%Y-%m-%d %H:%M:%S')} ({min(time_diffs):.2f}小时前)")
|
||||
|
||||
if show_distribution:
|
||||
# 计算时间分布的直方图
|
||||
hist, bins = np.histogram(time_diffs, bins=40)
|
||||
print("\n时间分布(每个*代表一个时间点):")
|
||||
for i in range(len(hist)):
|
||||
if hist[i] > 0:
|
||||
print(f"{bins[i]:6.1f}-{bins[i + 1]:6.1f}小时: {'*' * int(hist[i])}")
|
||||
# if show_distribution:
|
||||
# # 计算时间分布的直方图
|
||||
# hist, bins = np.histogram(time_diffs, bins=40)
|
||||
# print("\n时间分布(每个*代表一个时间点):")
|
||||
# for i in range(len(hist)):
|
||||
# if hist[i] > 0:
|
||||
# print(f"{bins[i]:6.1f}-{bins[i + 1]:6.1f}小时: {'*' * int(hist[i])}")
|
||||
|
||||
|
||||
# 使用示例
|
||||
if __name__ == "__main__":
|
||||
# 创建一个双峰分布的记忆调度器
|
||||
scheduler = MemoryBuildScheduler(
|
||||
n_hours1=12, # 第一个分布均值(12小时前)
|
||||
std_hours1=8, # 第一个分布标准差
|
||||
weight1=0.7, # 第一个分布权重 70%
|
||||
n_hours2=36, # 第二个分布均值(36小时前)
|
||||
std_hours2=24, # 第二个分布标准差
|
||||
weight2=0.3, # 第二个分布权重 30%
|
||||
total_samples=50, # 总共生成50个时间点
|
||||
)
|
||||
# # 使用示例
|
||||
# if __name__ == "__main__":
|
||||
# # 创建一个双峰分布的记忆调度器
|
||||
# scheduler = MemoryBuildScheduler(
|
||||
# n_hours1=12, # 第一个分布均值(12小时前)
|
||||
# std_hours1=8, # 第一个分布标准差
|
||||
# weight1=0.7, # 第一个分布权重 70%
|
||||
# n_hours2=36, # 第二个分布均值(36小时前)
|
||||
# std_hours2=24, # 第二个分布标准差
|
||||
# weight2=0.3, # 第二个分布权重 30%
|
||||
# total_samples=50, # 总共生成50个时间点
|
||||
# )
|
||||
|
||||
# 生成时间分布
|
||||
timestamps = scheduler.generate_time_samples()
|
||||
# # 生成时间分布
|
||||
# timestamps = scheduler.generate_time_samples()
|
||||
|
||||
# 打印结果,包含分布可视化
|
||||
print_time_samples(timestamps, show_distribution=True)
|
||||
# # 打印结果,包含分布可视化
|
||||
# print_time_samples(timestamps, show_distribution=True)
|
||||
|
||||
# 打印时间戳数组
|
||||
timestamp_array = scheduler.get_timestamp_array()
|
||||
print("\n时间戳数组(Unix时间戳):")
|
||||
print("[", end="")
|
||||
for i, ts in enumerate(timestamp_array):
|
||||
if i > 0:
|
||||
print(", ", end="")
|
||||
print(ts, end="")
|
||||
print("]")
|
||||
# # 打印时间戳数组
|
||||
# timestamp_array = scheduler.get_timestamp_array()
|
||||
# print("\n时间戳数组(Unix时间戳):")
|
||||
# print("[", end="")
|
||||
# for i, ts in enumerate(timestamp_array):
|
||||
# if i > 0:
|
||||
# print(", ", end="")
|
||||
# print(ts, end="")
|
||||
# print("]")
|
||||
|
||||
@@ -1,12 +1,10 @@
|
||||
from src.chat.emoji_system.emoji_manager import get_emoji_manager
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.message_receive.normal_message_sender import message_manager
|
||||
from src.chat.message_receive.storage import MessageStorage
|
||||
|
||||
|
||||
__all__ = [
|
||||
"get_emoji_manager",
|
||||
"get_chat_manager",
|
||||
"message_manager",
|
||||
"MessageStorage",
|
||||
]
|
||||
|
||||
@@ -1,23 +1,23 @@
|
||||
import traceback
|
||||
import os
|
||||
from typing import Dict, Any
|
||||
import re
|
||||
|
||||
from typing import Dict, Any, Optional
|
||||
from maim_message import UserInfo
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.mood.mood_manager import mood_manager # 导入情绪管理器
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.message_receive.message import MessageRecv
|
||||
from src.experimental.only_message_process import MessageProcessor
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager, ChatStream
|
||||
from src.chat.message_receive.message import MessageRecv, MessageRecvS4U
|
||||
from src.chat.message_receive.storage import MessageStorage
|
||||
from src.experimental.PFC.pfc_manager import PFCManager
|
||||
from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiver
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.config.config import global_config
|
||||
from src.plugin_system.core.component_registry import component_registry # 导入新插件系统
|
||||
from src.plugin_system.base.base_command import BaseCommand
|
||||
from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor
|
||||
from maim_message import UserInfo
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
import re
|
||||
|
||||
|
||||
# 定义日志配置
|
||||
|
||||
# 获取项目根目录(假设本文件在src/chat/message_receive/下,根目录为上上上级目录)
|
||||
@@ -80,9 +80,6 @@ class ChatBot:
|
||||
self.mood_manager = mood_manager # 获取情绪管理器单例
|
||||
self.heartflow_message_receiver = HeartFCMessageReceiver() # 新增
|
||||
|
||||
# 创建初始化PFC管理器的任务,会在_ensure_started时执行
|
||||
self.only_process_chat = MessageProcessor()
|
||||
self.pfc_manager = PFCManager.get_instance()
|
||||
self.s4u_message_processor = S4UMessageProcessor()
|
||||
|
||||
async def _ensure_started(self):
|
||||
@@ -101,6 +98,7 @@ class ChatBot:
|
||||
# 使用新的组件注册中心查找命令
|
||||
command_result = component_registry.find_command_by_text(text)
|
||||
if command_result:
|
||||
message.is_command = True
|
||||
command_class, matched_groups, intercept_message, plugin_name = command_result
|
||||
|
||||
# 获取插件配置
|
||||
@@ -144,6 +142,29 @@ class ChatBot:
|
||||
logger.error(f"处理命令时出错: {e}")
|
||||
return False, None, True # 出错时继续处理消息
|
||||
|
||||
async def do_s4u(self, message_data: Dict[str, Any]):
|
||||
message = MessageRecvS4U(message_data)
|
||||
group_info = message.message_info.group_info
|
||||
user_info = message.message_info.user_info
|
||||
|
||||
|
||||
get_chat_manager().register_message(message)
|
||||
chat = await get_chat_manager().get_or_create_stream(
|
||||
platform=message.message_info.platform, # type: ignore
|
||||
user_info=user_info, # type: ignore
|
||||
group_info=group_info,
|
||||
)
|
||||
|
||||
message.update_chat_stream(chat)
|
||||
|
||||
# 处理消息内容
|
||||
await message.process()
|
||||
|
||||
await self.s4u_message_processor.process_message(message)
|
||||
|
||||
return
|
||||
|
||||
|
||||
async def message_process(self, message_data: Dict[str, Any]) -> None:
|
||||
"""处理转化后的统一格式消息
|
||||
这个函数本质是预处理一些数据,根据配置信息和消息内容,预处理消息,并分发到合适的消息处理器中
|
||||
@@ -161,6 +182,10 @@ class ChatBot:
|
||||
try:
|
||||
# 确保所有任务已启动
|
||||
await self._ensure_started()
|
||||
|
||||
if ENABLE_S4U_CHAT:
|
||||
await self.do_s4u(message_data)
|
||||
return
|
||||
|
||||
if message_data["message_info"].get("group_info") is not None:
|
||||
message_data["message_info"]["group_info"]["group_id"] = str(
|
||||
@@ -184,8 +209,8 @@ class ChatBot:
|
||||
get_chat_manager().register_message(message)
|
||||
|
||||
chat = await get_chat_manager().get_or_create_stream(
|
||||
platform=message.message_info.platform,
|
||||
user_info=user_info,
|
||||
platform=message.message_info.platform, # type: ignore
|
||||
user_info=user_info, # type: ignore
|
||||
group_info=group_info,
|
||||
)
|
||||
|
||||
@@ -195,8 +220,10 @@ class ChatBot:
|
||||
await message.process()
|
||||
|
||||
# 过滤检查
|
||||
if _check_ban_words(message.processed_plain_text, chat, user_info) or _check_ban_regex(
|
||||
message.raw_message, chat, user_info
|
||||
if _check_ban_words(message.processed_plain_text, chat, user_info) or _check_ban_regex( # type: ignore
|
||||
message.raw_message, # type: ignore
|
||||
chat,
|
||||
user_info, # type: ignore
|
||||
):
|
||||
return
|
||||
|
||||
@@ -211,7 +238,7 @@ class ChatBot:
|
||||
|
||||
# 确认从接口发来的message是否有自定义的prompt模板信息
|
||||
if message.message_info.template_info and not message.message_info.template_info.template_default:
|
||||
template_group_name = message.message_info.template_info.template_name
|
||||
template_group_name: Optional[str] = message.message_info.template_info.template_name # type: ignore
|
||||
template_items = message.message_info.template_info.template_items
|
||||
async with global_prompt_manager.async_message_scope(template_group_name):
|
||||
if isinstance(template_items, dict):
|
||||
@@ -222,11 +249,6 @@ class ChatBot:
|
||||
template_group_name = None
|
||||
|
||||
async def preprocess():
|
||||
if ENABLE_S4U_CHAT:
|
||||
logger.info("进入S4U流程")
|
||||
await self.s4u_message_processor.process_message(message)
|
||||
return
|
||||
|
||||
await self.heartflow_message_receiver.process_message(message)
|
||||
|
||||
if template_group_name:
|
||||
|
||||
@@ -3,18 +3,17 @@ import hashlib
|
||||
import time
|
||||
import copy
|
||||
from typing import Dict, Optional, TYPE_CHECKING
|
||||
|
||||
|
||||
from ...common.database.database import db
|
||||
from ...common.database.database_model import ChatStreams # 新增导入
|
||||
from rich.traceback import install
|
||||
from maim_message import GroupInfo, UserInfo
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database import db
|
||||
from src.common.database.database_model import ChatStreams # 新增导入
|
||||
|
||||
# 避免循环导入,使用TYPE_CHECKING进行类型提示
|
||||
if TYPE_CHECKING:
|
||||
from .message import MessageRecv
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from rich.traceback import install
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
@@ -28,10 +27,10 @@ class ChatMessageContext:
|
||||
def __init__(self, message: "MessageRecv"):
|
||||
self.message = message
|
||||
|
||||
def get_template_name(self) -> str:
|
||||
def get_template_name(self) -> Optional[str]:
|
||||
"""获取模板名称"""
|
||||
if self.message.message_info.template_info and not self.message.message_info.template_info.template_default:
|
||||
return self.message.message_info.template_info.template_name
|
||||
return self.message.message_info.template_info.template_name # type: ignore
|
||||
return None
|
||||
|
||||
def get_last_message(self) -> "MessageRecv":
|
||||
@@ -39,11 +38,12 @@ class ChatMessageContext:
|
||||
return self.message
|
||||
|
||||
def check_types(self, types: list) -> bool:
|
||||
# sourcery skip: invert-any-all, use-any, use-next
|
||||
"""检查消息类型"""
|
||||
if not self.message.message_info.format_info.accept_format:
|
||||
if not self.message.message_info.format_info.accept_format: # type: ignore
|
||||
return False
|
||||
for t in types:
|
||||
if t not in self.message.message_info.format_info.accept_format:
|
||||
if t not in self.message.message_info.format_info.accept_format: # type: ignore
|
||||
return False
|
||||
return True
|
||||
|
||||
@@ -67,7 +67,7 @@ class ChatStream:
|
||||
platform: str,
|
||||
user_info: UserInfo,
|
||||
group_info: Optional[GroupInfo] = None,
|
||||
data: dict = None,
|
||||
data: Optional[dict] = None,
|
||||
):
|
||||
self.stream_id = stream_id
|
||||
self.platform = platform
|
||||
@@ -76,7 +76,7 @@ class ChatStream:
|
||||
self.create_time = data.get("create_time", time.time()) if data else time.time()
|
||||
self.last_active_time = data.get("last_active_time", self.create_time) if data else self.create_time
|
||||
self.saved = False
|
||||
self.context: ChatMessageContext = None # 用于存储该聊天的上下文信息
|
||||
self.context: ChatMessageContext = None # type: ignore # 用于存储该聊天的上下文信息
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""转换为字典格式"""
|
||||
@@ -98,7 +98,7 @@ class ChatStream:
|
||||
return cls(
|
||||
stream_id=data["stream_id"],
|
||||
platform=data["platform"],
|
||||
user_info=user_info,
|
||||
user_info=user_info, # type: ignore
|
||||
group_info=group_info,
|
||||
data=data,
|
||||
)
|
||||
@@ -162,8 +162,8 @@ class ChatManager:
|
||||
def register_message(self, message: "MessageRecv"):
|
||||
"""注册消息到聊天流"""
|
||||
stream_id = self._generate_stream_id(
|
||||
message.message_info.platform,
|
||||
message.message_info.user_info,
|
||||
message.message_info.platform, # type: ignore
|
||||
message.message_info.user_info, # type: ignore
|
||||
message.message_info.group_info,
|
||||
)
|
||||
self.last_messages[stream_id] = message
|
||||
@@ -184,10 +184,7 @@ class ChatManager:
|
||||
|
||||
def get_stream_id(self, platform: str, id: str, is_group: bool = True) -> str:
|
||||
"""获取聊天流ID"""
|
||||
if is_group:
|
||||
components = [platform, str(id)]
|
||||
else:
|
||||
components = [platform, str(id), "private"]
|
||||
components = [platform, id] if is_group else [platform, id, "private"]
|
||||
key = "_".join(components)
|
||||
return hashlib.md5(key.encode()).hexdigest()
|
||||
|
||||
|
||||
@@ -1,17 +1,15 @@
|
||||
import time
|
||||
from abc import abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Any, TYPE_CHECKING
|
||||
|
||||
import urllib3
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .chat_stream import ChatStream
|
||||
from ..utils.utils_image import get_image_manager
|
||||
from maim_message import Seg, UserInfo, BaseMessageInfo, MessageBase
|
||||
from abc import abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from rich.traceback import install
|
||||
from typing import Optional, Any
|
||||
from maim_message import Seg, UserInfo, BaseMessageInfo, MessageBase
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.utils.utils_image import get_image_manager
|
||||
from .chat_stream import ChatStream
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
@@ -27,7 +25,7 @@ urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
||||
|
||||
@dataclass
|
||||
class Message(MessageBase):
|
||||
chat_stream: "ChatStream" = None
|
||||
chat_stream: "ChatStream" = None # type: ignore
|
||||
reply: Optional["Message"] = None
|
||||
processed_plain_text: str = ""
|
||||
memorized_times: int = 0
|
||||
@@ -40,7 +38,6 @@ class Message(MessageBase):
|
||||
message_segment: Optional[Seg] = None,
|
||||
timestamp: Optional[float] = None,
|
||||
reply: Optional["MessageRecv"] = None,
|
||||
detailed_plain_text: str = "",
|
||||
processed_plain_text: str = "",
|
||||
):
|
||||
# 使用传入的时间戳或当前时间
|
||||
@@ -55,17 +52,17 @@ class Message(MessageBase):
|
||||
)
|
||||
|
||||
# 调用父类初始化
|
||||
super().__init__(message_info=message_info, message_segment=message_segment, raw_message=None)
|
||||
super().__init__(message_info=message_info, message_segment=message_segment, raw_message=None) # type: ignore
|
||||
|
||||
self.chat_stream = chat_stream
|
||||
# 文本处理相关属性
|
||||
self.processed_plain_text = processed_plain_text
|
||||
self.detailed_plain_text = detailed_plain_text
|
||||
|
||||
# 回复消息
|
||||
self.reply = reply
|
||||
|
||||
async def _process_message_segments(self, segment: Seg) -> str:
|
||||
# sourcery skip: remove-unnecessary-else, swap-if-else-branches
|
||||
"""递归处理消息段,转换为文字描述
|
||||
|
||||
Args:
|
||||
@@ -78,13 +75,13 @@ class Message(MessageBase):
|
||||
# 处理消息段列表
|
||||
segments_text = []
|
||||
for seg in segment.data:
|
||||
processed = await self._process_message_segments(seg)
|
||||
processed = await self._process_message_segments(seg) # type: ignore
|
||||
if processed:
|
||||
segments_text.append(processed)
|
||||
return " ".join(segments_text)
|
||||
else:
|
||||
# 处理单个消息段
|
||||
return await self._process_single_segment(segment)
|
||||
return await self._process_single_segment(segment) # type: ignore
|
||||
|
||||
@abstractmethod
|
||||
async def _process_single_segment(self, segment):
|
||||
@@ -105,14 +102,17 @@ class MessageRecv(Message):
|
||||
self.message_segment = Seg.from_dict(message_dict.get("message_segment", {}))
|
||||
self.raw_message = message_dict.get("raw_message")
|
||||
self.processed_plain_text = message_dict.get("processed_plain_text", "")
|
||||
self.detailed_plain_text = message_dict.get("detailed_plain_text", "")
|
||||
self.is_emoji = False
|
||||
self.has_emoji = False
|
||||
self.is_picid = False
|
||||
self.has_picid = False
|
||||
self.is_mentioned = None
|
||||
|
||||
self.is_command = False
|
||||
|
||||
self.priority_mode = "interest"
|
||||
self.priority_info = None
|
||||
self.interest_value: float = None # type: ignore
|
||||
|
||||
def update_chat_stream(self, chat_stream: "ChatStream"):
|
||||
self.chat_stream = chat_stream
|
||||
@@ -123,7 +123,6 @@ class MessageRecv(Message):
|
||||
这个方法必须在创建实例后显式调用,因为它包含异步操作。
|
||||
"""
|
||||
self.processed_plain_text = await self._process_message_segments(self.message_segment)
|
||||
self.detailed_plain_text = self._generate_detailed_text()
|
||||
|
||||
async def _process_single_segment(self, segment: Seg) -> str:
|
||||
"""处理单个消息段
|
||||
@@ -138,7 +137,7 @@ class MessageRecv(Message):
|
||||
if segment.type == "text":
|
||||
self.is_picid = False
|
||||
self.is_emoji = False
|
||||
return segment.data
|
||||
return segment.data # type: ignore
|
||||
elif segment.type == "image":
|
||||
# 如果是base64图片数据
|
||||
if isinstance(segment.data, str):
|
||||
@@ -160,7 +159,7 @@ class MessageRecv(Message):
|
||||
elif segment.type == "mention_bot":
|
||||
self.is_picid = False
|
||||
self.is_emoji = False
|
||||
self.is_mentioned = float(segment.data)
|
||||
self.is_mentioned = float(segment.data) # type: ignore
|
||||
return ""
|
||||
elif segment.type == "priority_info":
|
||||
self.is_picid = False
|
||||
@@ -182,12 +181,112 @@ class MessageRecv(Message):
|
||||
logger.error(f"处理消息段失败: {str(e)}, 类型: {segment.type}, 数据: {segment.data}")
|
||||
return f"[处理失败的{segment.type}消息]"
|
||||
|
||||
def _generate_detailed_text(self) -> str:
|
||||
"""生成详细文本,包含时间和用户信息"""
|
||||
timestamp = self.message_info.time
|
||||
user_info = self.message_info.user_info
|
||||
name = f"<{self.message_info.platform}:{user_info.user_id}:{user_info.user_nickname}:{user_info.user_cardname}>"
|
||||
return f"[{timestamp}] {name}: {self.processed_plain_text}\n"
|
||||
|
||||
@dataclass
|
||||
class MessageRecvS4U(MessageRecv):
|
||||
def __init__(self, message_dict: dict[str, Any]):
|
||||
super().__init__(message_dict)
|
||||
self.is_gift = False
|
||||
self.is_fake_gift = False
|
||||
self.is_superchat = False
|
||||
self.gift_info = None
|
||||
self.gift_name = None
|
||||
self.gift_count = None
|
||||
self.superchat_info = None
|
||||
self.superchat_price = None
|
||||
self.superchat_message_text = None
|
||||
self.is_screen = False
|
||||
self.voice_done = None
|
||||
|
||||
async def process(self) -> None:
|
||||
self.processed_plain_text = await self._process_message_segments(self.message_segment)
|
||||
|
||||
async def _process_single_segment(self, segment: Seg) -> str:
|
||||
"""处理单个消息段
|
||||
|
||||
Args:
|
||||
segment: 消息段
|
||||
|
||||
Returns:
|
||||
str: 处理后的文本
|
||||
"""
|
||||
try:
|
||||
if segment.type == "text":
|
||||
self.is_picid = False
|
||||
self.is_emoji = False
|
||||
return segment.data # type: ignore
|
||||
elif segment.type == "image":
|
||||
# 如果是base64图片数据
|
||||
if isinstance(segment.data, str):
|
||||
self.has_picid = True
|
||||
self.is_picid = True
|
||||
self.is_emoji = False
|
||||
image_manager = get_image_manager()
|
||||
# print(f"segment.data: {segment.data}")
|
||||
_, processed_text = await image_manager.process_image(segment.data)
|
||||
return processed_text
|
||||
return "[发了一张图片,网卡了加载不出来]"
|
||||
elif segment.type == "emoji":
|
||||
self.has_emoji = True
|
||||
self.is_emoji = True
|
||||
self.is_picid = False
|
||||
if isinstance(segment.data, str):
|
||||
return await get_image_manager().get_emoji_description(segment.data)
|
||||
return "[发了一个表情包,网卡了加载不出来]"
|
||||
elif segment.type == "mention_bot":
|
||||
self.is_picid = False
|
||||
self.is_emoji = False
|
||||
self.is_mentioned = float(segment.data) # type: ignore
|
||||
return ""
|
||||
elif segment.type == "priority_info":
|
||||
self.is_picid = False
|
||||
self.is_emoji = False
|
||||
if isinstance(segment.data, dict):
|
||||
# 处理优先级信息
|
||||
self.priority_mode = "priority"
|
||||
self.priority_info = segment.data
|
||||
"""
|
||||
{
|
||||
'message_type': 'vip', # vip or normal
|
||||
'message_priority': 1.0, # 优先级,大为优先,float
|
||||
}
|
||||
"""
|
||||
return ""
|
||||
elif segment.type == "gift":
|
||||
self.is_gift = True
|
||||
# 解析gift_info,格式为"名称:数量"
|
||||
name, count = segment.data.split(":", 1) # type: ignore
|
||||
self.gift_info = segment.data
|
||||
self.gift_name = name.strip()
|
||||
self.gift_count = int(count.strip())
|
||||
return ""
|
||||
elif segment.type == "voice_done":
|
||||
msg_id = segment.data
|
||||
logger.info(f"voice_done: {msg_id}")
|
||||
self.voice_done = msg_id
|
||||
return ""
|
||||
elif segment.type == "superchat":
|
||||
self.is_superchat = True
|
||||
self.superchat_info = segment.data
|
||||
price, message_text = segment.data.split(":", 1) # type: ignore
|
||||
self.superchat_price = price.strip()
|
||||
self.superchat_message_text = message_text.strip()
|
||||
|
||||
self.processed_plain_text = str(self.superchat_message_text)
|
||||
self.processed_plain_text += (
|
||||
f"(注意:这是一条超级弹幕信息,价值{self.superchat_price}元,请你认真回复)"
|
||||
)
|
||||
|
||||
return self.processed_plain_text
|
||||
elif segment.type == "screen":
|
||||
self.is_screen = True
|
||||
self.screen_info = segment.data
|
||||
return "屏幕信息"
|
||||
else:
|
||||
return ""
|
||||
except Exception as e:
|
||||
logger.error(f"处理消息段失败: {str(e)}, 类型: {segment.type}, 数据: {segment.data}")
|
||||
return f"[处理失败的{segment.type}消息]"
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -234,7 +333,7 @@ class MessageProcessBase(Message):
|
||||
"""
|
||||
try:
|
||||
if seg.type == "text":
|
||||
return seg.data
|
||||
return seg.data # type: ignore
|
||||
elif seg.type == "image":
|
||||
# 如果是base64图片数据
|
||||
if isinstance(seg.data, str):
|
||||
@@ -250,7 +349,7 @@ class MessageProcessBase(Message):
|
||||
if self.reply and hasattr(self.reply, "processed_plain_text"):
|
||||
# print(f"self.reply.processed_plain_text: {self.reply.processed_plain_text}")
|
||||
# print(f"reply: {self.reply}")
|
||||
return f"[回复<{self.reply.message_info.user_info.user_nickname}:{self.reply.message_info.user_info.user_id}> 的消息:{self.reply.processed_plain_text}]"
|
||||
return f"[回复<{self.reply.message_info.user_info.user_nickname}:{self.reply.message_info.user_info.user_id}> 的消息:{self.reply.processed_plain_text}]" # type: ignore
|
||||
return None
|
||||
else:
|
||||
return f"[{seg.type}:{str(seg.data)}]"
|
||||
@@ -264,7 +363,7 @@ class MessageProcessBase(Message):
|
||||
timestamp = self.message_info.time
|
||||
user_info = self.message_info.user_info
|
||||
|
||||
name = f"<{self.message_info.platform}:{user_info.user_id}:{user_info.user_nickname}:{user_info.user_cardname}>"
|
||||
name = f"<{self.message_info.platform}:{user_info.user_id}:{user_info.user_nickname}:{user_info.user_cardname}>" # type: ignore
|
||||
return f"[{timestamp}],{name} 说:{self.processed_plain_text}\n"
|
||||
|
||||
|
||||
@@ -313,7 +412,7 @@ class MessageSending(MessageProcessBase):
|
||||
is_emoji: bool = False,
|
||||
thinking_start_time: float = 0,
|
||||
apply_set_reply_logic: bool = False,
|
||||
reply_to: str = None,
|
||||
reply_to: str = None, # type: ignore
|
||||
):
|
||||
# 调用父类初始化
|
||||
super().__init__(
|
||||
@@ -337,6 +436,8 @@ class MessageSending(MessageProcessBase):
|
||||
# 用于显示发送内容与显示不一致的情况
|
||||
self.display_message = display_message
|
||||
|
||||
self.interest_value = 0.0
|
||||
|
||||
def build_reply(self):
|
||||
"""设置回复消息"""
|
||||
if self.reply:
|
||||
@@ -344,7 +445,7 @@ class MessageSending(MessageProcessBase):
|
||||
self.message_segment = Seg(
|
||||
type="seglist",
|
||||
data=[
|
||||
Seg(type="reply", data=self.reply.message_info.message_id),
|
||||
Seg(type="reply", data=self.reply.message_info.message_id), # type: ignore
|
||||
self.message_segment,
|
||||
],
|
||||
)
|
||||
@@ -364,10 +465,10 @@ class MessageSending(MessageProcessBase):
|
||||
) -> "MessageSending":
|
||||
"""从思考状态消息创建发送状态消息"""
|
||||
return cls(
|
||||
message_id=thinking.message_info.message_id,
|
||||
message_id=thinking.message_info.message_id, # type: ignore
|
||||
chat_stream=thinking.chat_stream,
|
||||
message_segment=message_segment,
|
||||
bot_user_info=thinking.message_info.user_info,
|
||||
bot_user_info=thinking.message_info.user_info, # type: ignore
|
||||
reply=thinking.reply,
|
||||
is_head=is_head,
|
||||
is_emoji=is_emoji,
|
||||
@@ -399,13 +500,11 @@ class MessageSet:
|
||||
if not isinstance(message, MessageSending):
|
||||
raise TypeError("MessageSet只能添加MessageSending类型的消息")
|
||||
self.messages.append(message)
|
||||
self.messages.sort(key=lambda x: x.message_info.time)
|
||||
self.messages.sort(key=lambda x: x.message_info.time) # type: ignore
|
||||
|
||||
def get_message_by_index(self, index: int) -> Optional[MessageSending]:
|
||||
"""通过索引获取消息"""
|
||||
if 0 <= index < len(self.messages):
|
||||
return self.messages[index]
|
||||
return None
|
||||
return self.messages[index] if 0 <= index < len(self.messages) else None
|
||||
|
||||
def get_message_by_time(self, target_time: float) -> Optional[MessageSending]:
|
||||
"""获取最接近指定时间的消息"""
|
||||
@@ -415,7 +514,7 @@ class MessageSet:
|
||||
left, right = 0, len(self.messages) - 1
|
||||
while left < right:
|
||||
mid = (left + right) // 2
|
||||
if self.messages[mid].message_info.time < target_time:
|
||||
if self.messages[mid].message_info.time < target_time: # type: ignore
|
||||
left = mid + 1
|
||||
else:
|
||||
right = mid
|
||||
@@ -438,3 +537,51 @@ class MessageSet:
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.messages)
|
||||
|
||||
|
||||
def message_recv_from_dict(message_dict: dict) -> MessageRecv:
|
||||
return MessageRecv(message_dict)
|
||||
|
||||
|
||||
def message_from_db_dict(db_dict: dict) -> MessageRecv:
|
||||
"""从数据库字典创建MessageRecv实例"""
|
||||
# 转换扁平的数据库字典为嵌套结构
|
||||
message_info_dict = {
|
||||
"platform": db_dict.get("chat_info_platform"),
|
||||
"message_id": db_dict.get("message_id"),
|
||||
"time": db_dict.get("time"),
|
||||
"group_info": {
|
||||
"platform": db_dict.get("chat_info_group_platform"),
|
||||
"group_id": db_dict.get("chat_info_group_id"),
|
||||
"group_name": db_dict.get("chat_info_group_name"),
|
||||
},
|
||||
"user_info": {
|
||||
"platform": db_dict.get("user_platform"),
|
||||
"user_id": db_dict.get("user_id"),
|
||||
"user_nickname": db_dict.get("user_nickname"),
|
||||
"user_cardname": db_dict.get("user_cardname"),
|
||||
},
|
||||
}
|
||||
|
||||
processed_text = db_dict.get("processed_plain_text", "")
|
||||
|
||||
# 构建 MessageRecv 需要的字典
|
||||
recv_dict = {
|
||||
"message_info": message_info_dict,
|
||||
"message_segment": {"type": "text", "data": processed_text}, # 从纯文本重建消息段
|
||||
"raw_message": None, # 数据库中未存储原始消息
|
||||
"processed_plain_text": processed_text,
|
||||
}
|
||||
|
||||
# 创建 MessageRecv 实例
|
||||
msg = MessageRecv(recv_dict)
|
||||
|
||||
# 从数据库字典中填充其他可选字段
|
||||
msg.interest_value = db_dict.get("interest_value", 0.0)
|
||||
msg.is_mentioned = db_dict.get("is_mentioned")
|
||||
msg.priority_mode = db_dict.get("priority_mode", "interest")
|
||||
msg.priority_info = db_dict.get("priority_info")
|
||||
msg.is_emoji = db_dict.get("is_emoji", False)
|
||||
msg.is_picid = db_dict.get("is_picid", False)
|
||||
|
||||
return msg
|
||||
|
||||
@@ -1,308 +0,0 @@
|
||||
# src/plugins/chat/message_sender.py
|
||||
import asyncio
|
||||
import time
|
||||
from asyncio import Task
|
||||
from typing import Union
|
||||
from src.common.message.api import get_global_api
|
||||
|
||||
# from ...common.database import db # 数据库依赖似乎不需要了,注释掉
|
||||
from .message import MessageSending, MessageThinking, MessageSet
|
||||
|
||||
from src.chat.message_receive.storage import MessageStorage
|
||||
from ..utils.utils import truncate_message, calculate_typing_time, count_messages_between
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from rich.traceback import install
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
|
||||
logger = get_logger("sender")
|
||||
|
||||
|
||||
async def send_via_ws(message: MessageSending) -> None:
|
||||
"""通过 WebSocket 发送消息"""
|
||||
try:
|
||||
await get_global_api().send_message(message)
|
||||
except Exception as e:
|
||||
logger.error(f"WS发送失败: {e}")
|
||||
raise ValueError(f"未找到平台:{message.message_info.platform} 的url配置,请检查配置文件") from e
|
||||
|
||||
|
||||
async def send_message(
|
||||
message: MessageSending,
|
||||
) -> None:
|
||||
"""发送消息(核心发送逻辑)"""
|
||||
|
||||
# --- 添加计算打字和延迟的逻辑 (从 heartflow_message_sender 移动并调整) ---
|
||||
typing_time = calculate_typing_time(
|
||||
input_string=message.processed_plain_text,
|
||||
thinking_start_time=message.thinking_start_time,
|
||||
is_emoji=message.is_emoji,
|
||||
)
|
||||
# logger.debug(f"{message.processed_plain_text},{typing_time},计算输入时间结束") # 减少日志
|
||||
await asyncio.sleep(typing_time)
|
||||
# logger.debug(f"{message.processed_plain_text},{typing_time},等待输入时间结束") # 减少日志
|
||||
# --- 结束打字延迟 ---
|
||||
|
||||
message_preview = truncate_message(message.processed_plain_text)
|
||||
|
||||
try:
|
||||
await send_via_ws(message)
|
||||
logger.info(f"发送消息 '{message_preview}' 成功") # 调整日志格式
|
||||
except Exception as e:
|
||||
logger.error(f"发送消息 '{message_preview}' 失败: {str(e)}")
|
||||
|
||||
|
||||
class MessageSender:
|
||||
"""发送器 (不再是单例)"""
|
||||
|
||||
def __init__(self):
|
||||
self.message_interval = (0.5, 1) # 消息间隔时间范围(秒)
|
||||
self.last_send_time = 0
|
||||
self._current_bot = None
|
||||
|
||||
def set_bot(self, bot):
|
||||
"""设置当前bot实例"""
|
||||
pass
|
||||
|
||||
|
||||
class MessageContainer:
|
||||
"""单个聊天流的发送/思考消息容器"""
|
||||
|
||||
def __init__(self, chat_id: str, max_size: int = 100):
|
||||
self.chat_id = chat_id
|
||||
self.max_size = max_size
|
||||
self.messages: list[MessageThinking | MessageSending] = [] # 明确类型
|
||||
self.last_send_time = 0
|
||||
self.thinking_wait_timeout = 20 # 思考等待超时时间(秒) - 从旧 sender 合并
|
||||
|
||||
def count_thinking_messages(self) -> int:
|
||||
"""计算当前容器中思考消息的数量"""
|
||||
return sum(1 for msg in self.messages if isinstance(msg, MessageThinking))
|
||||
|
||||
def get_timeout_sending_messages(self) -> list[MessageSending]:
|
||||
"""获取所有超时的MessageSending对象(思考时间超过20秒),按thinking_start_time排序 - 从旧 sender 合并"""
|
||||
current_time = time.time()
|
||||
timeout_messages = []
|
||||
|
||||
for msg in self.messages:
|
||||
# 只检查 MessageSending 类型
|
||||
if isinstance(msg, MessageSending):
|
||||
# 确保 thinking_start_time 有效
|
||||
if msg.thinking_start_time and current_time - msg.thinking_start_time > self.thinking_wait_timeout:
|
||||
timeout_messages.append(msg)
|
||||
|
||||
# 按thinking_start_time排序,时间早的在前面
|
||||
timeout_messages.sort(key=lambda x: x.thinking_start_time)
|
||||
return timeout_messages
|
||||
|
||||
def get_earliest_message(self):
|
||||
"""获取thinking_start_time最早的消息对象"""
|
||||
if not self.messages:
|
||||
return None
|
||||
earliest_time = float("inf")
|
||||
earliest_message = None
|
||||
for msg in self.messages:
|
||||
# 确保消息有 thinking_start_time 属性
|
||||
msg_time = getattr(msg, "thinking_start_time", float("inf"))
|
||||
if msg_time < earliest_time:
|
||||
earliest_time = msg_time
|
||||
earliest_message = msg
|
||||
return earliest_message
|
||||
|
||||
def add_message(self, message: Union[MessageThinking, MessageSending, MessageSet]):
|
||||
"""添加消息到队列"""
|
||||
if isinstance(message, MessageSet):
|
||||
for single_message in message.messages:
|
||||
self.messages.append(single_message)
|
||||
else:
|
||||
self.messages.append(message)
|
||||
|
||||
def remove_message(self, message_to_remove: Union[MessageThinking, MessageSending]):
|
||||
"""移除指定的消息对象,如果消息存在则返回True,否则返回False"""
|
||||
try:
|
||||
_initial_len = len(self.messages)
|
||||
# 使用列表推导式或 message_filter 创建新列表,排除要删除的元素
|
||||
# self.messages = [msg for msg in self.messages if msg is not message_to_remove]
|
||||
# 或者直接 remove (如果确定对象唯一性)
|
||||
if message_to_remove in self.messages:
|
||||
self.messages.remove(message_to_remove)
|
||||
return True
|
||||
# logger.debug(f"Removed message {getattr(message_to_remove, 'message_info', {}).get('message_id', 'UNKNOWN')}. Old len: {initial_len}, New len: {len(self.messages)}")
|
||||
# return len(self.messages) < initial_len
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"移除消息时发生错误: {e}")
|
||||
return False
|
||||
|
||||
def has_messages(self) -> bool:
|
||||
"""检查是否有待发送的消息"""
|
||||
return bool(self.messages)
|
||||
|
||||
def get_all_messages(self) -> list[MessageThinking | MessageSending]:
|
||||
"""获取所有消息"""
|
||||
return list(self.messages) # 返回副本
|
||||
|
||||
|
||||
class MessageManager:
|
||||
"""管理所有聊天流的消息容器 (不再是单例)"""
|
||||
|
||||
def __init__(self):
|
||||
self._processor_task: Task | None = None
|
||||
self.containers: dict[str, MessageContainer] = {}
|
||||
self.storage = MessageStorage() # 添加 storage 实例
|
||||
self._running = True # 处理器运行状态
|
||||
self._container_lock = asyncio.Lock() # 保护 containers 字典的锁
|
||||
# self.message_sender = MessageSender() # 创建发送器实例 (改为全局实例)
|
||||
|
||||
async def start(self):
|
||||
"""启动后台处理器任务。"""
|
||||
# 检查是否已有任务在运行,避免重复启动
|
||||
if self._processor_task is not None and not self._processor_task.done():
|
||||
logger.warning("Processor task already running.")
|
||||
return
|
||||
self._processor_task = asyncio.create_task(self._start_processor_loop())
|
||||
logger.debug("MessageManager processor task started.")
|
||||
|
||||
def stop(self):
|
||||
"""停止后台处理器任务。"""
|
||||
self._running = False
|
||||
if self._processor_task is not None and not self._processor_task.done():
|
||||
self._processor_task.cancel()
|
||||
logger.debug("MessageManager processor task stopping.")
|
||||
else:
|
||||
logger.debug("MessageManager processor task not running or already stopped.")
|
||||
|
||||
async def get_container(self, chat_id: str) -> MessageContainer:
|
||||
"""获取或创建聊天流的消息容器 (异步,使用锁)"""
|
||||
async with self._container_lock:
|
||||
if chat_id not in self.containers:
|
||||
self.containers[chat_id] = MessageContainer(chat_id)
|
||||
return self.containers[chat_id]
|
||||
|
||||
async def add_message(self, message: Union[MessageThinking, MessageSending, MessageSet]) -> None:
|
||||
"""添加消息到对应容器"""
|
||||
chat_stream = message.chat_stream
|
||||
if not chat_stream:
|
||||
logger.error("消息缺少 chat_stream,无法添加到容器")
|
||||
return # 或者抛出异常
|
||||
container = await self.get_container(chat_stream.stream_id)
|
||||
container.add_message(message)
|
||||
|
||||
async def _handle_sending_message(self, container: MessageContainer, message: MessageSending):
|
||||
"""处理单个 MessageSending 消息 (包含 set_reply 逻辑)"""
|
||||
try:
|
||||
_ = message.update_thinking_time() # 更新思考时间
|
||||
thinking_start_time = message.thinking_start_time
|
||||
now_time = time.time()
|
||||
# logger.debug(f"thinking_start_time:{thinking_start_time},now_time:{now_time}")
|
||||
thinking_messages_count, thinking_messages_length = count_messages_between(
|
||||
start_time=thinking_start_time, end_time=now_time, stream_id=message.chat_stream.stream_id
|
||||
)
|
||||
|
||||
if (
|
||||
message.is_head
|
||||
and (thinking_messages_count > 3 or thinking_messages_length > 200)
|
||||
and not message.is_private_message()
|
||||
):
|
||||
logger.debug(
|
||||
f"[{message.chat_stream.stream_id}] 应用 set_reply 逻辑: {message.processed_plain_text[:20]}..."
|
||||
)
|
||||
message.build_reply()
|
||||
# --- 结束条件 set_reply ---
|
||||
|
||||
await message.process() # 预处理消息内容
|
||||
|
||||
# logger.debug(f"{message}")
|
||||
|
||||
# 使用全局 message_sender 实例
|
||||
await send_message(message)
|
||||
await self.storage.store_message(message, message.chat_stream)
|
||||
|
||||
# 移除消息要在发送 *之后*
|
||||
container.remove_message(message)
|
||||
# logger.debug(f"[{message.chat_stream.stream_id}] Sent and removed message: {message.message_info.message_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[{message.chat_stream.stream_id}] 处理发送消息 {getattr(message.message_info, 'message_id', 'N/A')} 时出错: {e}"
|
||||
)
|
||||
logger.exception("详细错误信息:")
|
||||
# 考虑是否移除出错的消息,防止无限循环
|
||||
removed = container.remove_message(message)
|
||||
if removed:
|
||||
logger.warning(f"[{message.chat_stream.stream_id}] 已移除处理出错的消息。")
|
||||
|
||||
async def _process_chat_messages(self, chat_id: str):
|
||||
"""处理单个聊天流消息 (合并后的逻辑)"""
|
||||
container = await self.get_container(chat_id) # 获取容器是异步的了
|
||||
|
||||
if container.has_messages():
|
||||
message_earliest = container.get_earliest_message()
|
||||
|
||||
if not message_earliest: # 如果最早消息为空,则退出
|
||||
return
|
||||
|
||||
if isinstance(message_earliest, MessageThinking):
|
||||
# --- 处理思考消息 (来自旧 sender) ---
|
||||
message_earliest.update_thinking_time()
|
||||
thinking_time = message_earliest.thinking_time
|
||||
# 减少控制台刷新频率或只在时间显著变化时打印
|
||||
if int(thinking_time) % 5 == 0: # 每5秒打印一次
|
||||
print(
|
||||
f"消息 {message_earliest.message_info.message_id} 正在思考中,已思考 {int(thinking_time)} 秒\r",
|
||||
end="",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
elif isinstance(message_earliest, MessageSending):
|
||||
# --- 处理发送消息 ---
|
||||
await self._handle_sending_message(container, message_earliest)
|
||||
|
||||
# --- 处理超时发送消息 (来自旧 sender) ---
|
||||
# 在处理完最早的消息后,检查是否有超时的发送消息
|
||||
timeout_sending_messages = container.get_timeout_sending_messages()
|
||||
if timeout_sending_messages:
|
||||
logger.debug(f"[{chat_id}] 发现 {len(timeout_sending_messages)} 条超时的发送消息")
|
||||
for msg in timeout_sending_messages:
|
||||
# 确保不是刚刚处理过的最早消息 (虽然理论上应该已被移除,但以防万一)
|
||||
if msg is message_earliest:
|
||||
continue
|
||||
logger.info(f"[{chat_id}] 处理超时发送消息: {msg.message_info.message_id}")
|
||||
await self._handle_sending_message(container, msg) # 复用处理逻辑
|
||||
|
||||
async def _start_processor_loop(self):
|
||||
"""消息处理器主循环"""
|
||||
while self._running:
|
||||
tasks = []
|
||||
# 使用异步锁保护迭代器创建过程
|
||||
async with self._container_lock:
|
||||
# 创建 keys 的快照以安全迭代
|
||||
chat_ids = list(self.containers.keys())
|
||||
|
||||
for chat_id in chat_ids:
|
||||
# 为每个 chat_id 创建一个处理任务
|
||||
tasks.append(asyncio.create_task(self._process_chat_messages(chat_id)))
|
||||
|
||||
if tasks:
|
||||
try:
|
||||
# 等待当前批次的所有任务完成
|
||||
await asyncio.gather(*tasks)
|
||||
except Exception as e:
|
||||
logger.error(f"消息处理循环 gather 出错: {e}")
|
||||
|
||||
# 等待一小段时间,避免CPU空转
|
||||
try:
|
||||
await asyncio.sleep(0.1) # 稍微降低轮询频率
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Processor loop sleep cancelled.")
|
||||
break # 退出循环
|
||||
logger.info("MessageManager processor loop finished.")
|
||||
|
||||
|
||||
# --- 创建全局实例 ---
|
||||
message_manager = MessageManager()
|
||||
message_sender = MessageSender()
|
||||
# --- 结束全局实例 ---
|
||||
@@ -1,11 +1,11 @@
|
||||
import re
|
||||
import traceback
|
||||
from typing import Union
|
||||
|
||||
# from ...common.database.database import db # db is now Peewee's SqliteDatabase instance
|
||||
from .message import MessageSending, MessageRecv
|
||||
from .chat_stream import ChatStream
|
||||
from ...common.database.database_model import Messages, RecalledMessages, Images # Import Peewee models
|
||||
from src.common.database.database_model import Messages, Images
|
||||
from src.common.logger import get_logger
|
||||
from .chat_stream import ChatStream
|
||||
from .message import MessageSending, MessageRecv
|
||||
|
||||
logger = get_logger("message_storage")
|
||||
|
||||
@@ -36,15 +36,27 @@ class MessageStorage:
|
||||
filtered_display_message = re.sub(pattern, "", display_message, flags=re.DOTALL)
|
||||
else:
|
||||
filtered_display_message = ""
|
||||
|
||||
interest_value = 0
|
||||
is_mentioned = False
|
||||
reply_to = message.reply_to
|
||||
priority_mode = ""
|
||||
priority_info = {}
|
||||
is_emoji = False
|
||||
is_picid = False
|
||||
is_command = False
|
||||
else:
|
||||
filtered_display_message = ""
|
||||
|
||||
interest_value = message.interest_value
|
||||
is_mentioned = message.is_mentioned
|
||||
reply_to = ""
|
||||
priority_mode = message.priority_mode
|
||||
priority_info = message.priority_info
|
||||
is_emoji = message.is_emoji
|
||||
is_picid = message.is_picid
|
||||
is_command = message.is_command
|
||||
|
||||
chat_info_dict = chat_stream.to_dict()
|
||||
user_info_dict = message.message_info.user_info.to_dict()
|
||||
user_info_dict = message.message_info.user_info.to_dict() # type: ignore
|
||||
|
||||
# message_id 现在是 TextField,直接使用字符串值
|
||||
msg_id = message.message_info.message_id
|
||||
@@ -56,10 +68,11 @@ class MessageStorage:
|
||||
|
||||
Messages.create(
|
||||
message_id=msg_id,
|
||||
time=float(message.message_info.time),
|
||||
time=float(message.message_info.time), # type: ignore
|
||||
chat_id=chat_stream.stream_id,
|
||||
# Flattened chat_info
|
||||
reply_to=reply_to,
|
||||
is_mentioned=is_mentioned,
|
||||
chat_info_stream_id=chat_info_dict.get("stream_id"),
|
||||
chat_info_platform=chat_info_dict.get("platform"),
|
||||
chat_info_user_platform=user_info_from_chat.get("platform"),
|
||||
@@ -80,32 +93,16 @@ class MessageStorage:
|
||||
processed_plain_text=filtered_processed_plain_text,
|
||||
display_message=filtered_display_message,
|
||||
memorized_times=message.memorized_times,
|
||||
interest_value=interest_value,
|
||||
priority_mode=priority_mode,
|
||||
priority_info=priority_info,
|
||||
is_emoji=is_emoji,
|
||||
is_picid=is_picid,
|
||||
is_command=is_command,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("存储消息失败")
|
||||
|
||||
@staticmethod
|
||||
async def store_recalled_message(message_id: str, time: str, chat_stream: ChatStream) -> None:
|
||||
"""存储撤回消息到数据库"""
|
||||
# Table creation is handled by initialize_database in database_model.py
|
||||
try:
|
||||
RecalledMessages.create(
|
||||
message_id=message_id,
|
||||
time=float(time), # Assuming time is a string representing a float timestamp
|
||||
stream_id=chat_stream.stream_id,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("存储撤回消息失败")
|
||||
|
||||
@staticmethod
|
||||
async def remove_recalled_message(time: str) -> None:
|
||||
"""删除撤回消息"""
|
||||
try:
|
||||
# Assuming input 'time' is a string timestamp that can be converted to float
|
||||
current_time_float = float(time)
|
||||
RecalledMessages.delete().where(RecalledMessages.time < (current_time_float - 300)).execute()
|
||||
except Exception:
|
||||
logger.exception("删除撤回消息失败")
|
||||
traceback.print_exc()
|
||||
|
||||
# 如果需要其他存储相关的函数,可以在这里添加
|
||||
@staticmethod
|
||||
@@ -115,22 +112,19 @@ class MessageStorage:
|
||||
"""更新最新一条匹配消息的message_id"""
|
||||
try:
|
||||
if message.message_segment.type == "notify":
|
||||
mmc_message_id = message.message_segment.data.get("echo")
|
||||
qq_message_id = message.message_segment.data.get("actual_id")
|
||||
mmc_message_id = message.message_segment.data.get("echo") # type: ignore
|
||||
qq_message_id = message.message_segment.data.get("actual_id") # type: ignore
|
||||
else:
|
||||
logger.info(f"更新消息ID错误,seg类型为{message.message_segment.type}")
|
||||
return
|
||||
if not qq_message_id:
|
||||
logger.info("消息不存在message_id,无法更新")
|
||||
return
|
||||
# 查询最新一条匹配消息
|
||||
matched_message = (
|
||||
if matched_message := (
|
||||
Messages.select().where((Messages.message_id == mmc_message_id)).order_by(Messages.time.desc()).first()
|
||||
)
|
||||
|
||||
if matched_message:
|
||||
):
|
||||
# 更新找到的消息记录
|
||||
Messages.update(message_id=qq_message_id).where(Messages.id == matched_message.id).execute()
|
||||
Messages.update(message_id=qq_message_id).where(Messages.id == matched_message.id).execute() # type: ignore
|
||||
logger.debug(f"更新消息ID成功: {matched_message.message_id} -> {qq_message_id}")
|
||||
else:
|
||||
logger.debug("未找到匹配的消息")
|
||||
@@ -155,10 +149,7 @@ class MessageStorage:
|
||||
image_record = (
|
||||
Images.select().where(Images.description == description).order_by(Images.timestamp.desc()).first()
|
||||
)
|
||||
if image_record:
|
||||
return f"[picid:{image_record.image_id}]"
|
||||
else:
|
||||
return match.group(0) # 保持原样
|
||||
return f"[picid:{image_record.image_id}]" if image_record else match.group(0)
|
||||
except Exception:
|
||||
return match.group(0)
|
||||
|
||||
|
||||
@@ -1,28 +1,29 @@
|
||||
import asyncio
|
||||
from typing import Dict, Optional # 重新导入类型
|
||||
from src.chat.message_receive.message import MessageSending, MessageThinking
|
||||
from src.common.message.api import get_global_api
|
||||
from src.chat.message_receive.storage import MessageStorage
|
||||
from src.chat.utils.utils import truncate_message
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.utils.utils import calculate_typing_time
|
||||
from rich.traceback import install
|
||||
import traceback
|
||||
|
||||
install(extra_lines=3)
|
||||
from rich.traceback import install
|
||||
|
||||
from src.common.message.api import get_global_api
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.message_receive.message import MessageSending
|
||||
from src.chat.message_receive.storage import MessageStorage
|
||||
from src.chat.utils.utils import truncate_message
|
||||
from src.chat.utils.utils import calculate_typing_time
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
logger = get_logger("sender")
|
||||
|
||||
|
||||
async def send_message(message: MessageSending) -> bool:
|
||||
async def send_message(message: MessageSending, show_log=True) -> bool:
|
||||
"""合并后的消息发送函数,包含WS发送和日志记录"""
|
||||
message_preview = truncate_message(message.processed_plain_text, max_length=40)
|
||||
message_preview = truncate_message(message.processed_plain_text, max_length=120)
|
||||
|
||||
try:
|
||||
# 直接调用API发送消息
|
||||
await get_global_api().send_message(message)
|
||||
logger.info(f"已将消息 '{message_preview}' 发往平台'{message.message_info.platform}'")
|
||||
if show_log:
|
||||
logger.info(f"已将消息 '{message_preview}' 发往平台'{message.message_info.platform}'")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
@@ -36,44 +37,10 @@ class HeartFCSender:
|
||||
|
||||
def __init__(self):
|
||||
self.storage = MessageStorage()
|
||||
# 用于存储活跃的思考消息
|
||||
self.thinking_messages: Dict[str, Dict[str, MessageThinking]] = {}
|
||||
self._thinking_lock = asyncio.Lock() # 保护 thinking_messages 的锁
|
||||
|
||||
async def register_thinking(self, thinking_message: MessageThinking):
|
||||
"""注册一个思考中的消息。"""
|
||||
if not thinking_message.chat_stream or not thinking_message.message_info.message_id:
|
||||
logger.error("无法注册缺少 chat_stream 或 message_id 的思考消息")
|
||||
return
|
||||
|
||||
chat_id = thinking_message.chat_stream.stream_id
|
||||
message_id = thinking_message.message_info.message_id
|
||||
|
||||
async with self._thinking_lock:
|
||||
if chat_id not in self.thinking_messages:
|
||||
self.thinking_messages[chat_id] = {}
|
||||
if message_id in self.thinking_messages[chat_id]:
|
||||
logger.warning(f"[{chat_id}] 尝试注册已存在的思考消息 ID: {message_id}")
|
||||
self.thinking_messages[chat_id][message_id] = thinking_message
|
||||
logger.debug(f"[{chat_id}] Registered thinking message: {message_id}")
|
||||
|
||||
async def complete_thinking(self, chat_id: str, message_id: str):
|
||||
"""完成并移除一个思考中的消息记录。"""
|
||||
async with self._thinking_lock:
|
||||
if chat_id in self.thinking_messages and message_id in self.thinking_messages[chat_id]:
|
||||
del self.thinking_messages[chat_id][message_id]
|
||||
logger.debug(f"[{chat_id}] Completed thinking message: {message_id}")
|
||||
if not self.thinking_messages[chat_id]:
|
||||
del self.thinking_messages[chat_id]
|
||||
logger.debug(f"[{chat_id}] Removed empty thinking message container.")
|
||||
|
||||
async def get_thinking_start_time(self, chat_id: str, message_id: str) -> Optional[float]:
|
||||
"""获取已注册思考消息的开始时间。"""
|
||||
async with self._thinking_lock:
|
||||
thinking_message = self.thinking_messages.get(chat_id, {}).get(message_id)
|
||||
return thinking_message.thinking_start_time if thinking_message else None
|
||||
|
||||
async def send_message(self, message: MessageSending, typing=False, set_reply=False, storage_message=True):
|
||||
async def send_message(
|
||||
self, message: MessageSending, typing=False, set_reply=False, storage_message=True, show_log=True
|
||||
):
|
||||
"""
|
||||
处理、发送并存储一条消息。
|
||||
|
||||
@@ -86,10 +53,10 @@ class HeartFCSender:
|
||||
"""
|
||||
if not message.chat_stream:
|
||||
logger.error("消息缺少 chat_stream,无法发送")
|
||||
raise Exception("消息缺少 chat_stream,无法发送")
|
||||
raise ValueError("消息缺少 chat_stream,无法发送")
|
||||
if not message.message_info or not message.message_info.message_id:
|
||||
logger.error("消息缺少 message_info 或 message_id,无法发送")
|
||||
raise Exception("消息缺少 message_info 或 message_id,无法发送")
|
||||
raise ValueError("消息缺少 message_info 或 message_id,无法发送")
|
||||
|
||||
chat_id = message.chat_stream.stream_id
|
||||
message_id = message.message_info.message_id
|
||||
@@ -109,7 +76,7 @@ class HeartFCSender:
|
||||
)
|
||||
await asyncio.sleep(typing_time)
|
||||
|
||||
sent_msg = await send_message(message)
|
||||
sent_msg = await send_message(message, show_log=show_log)
|
||||
if not sent_msg:
|
||||
return False
|
||||
|
||||
@@ -121,5 +88,3 @@ class HeartFCSender:
|
||||
except Exception as e:
|
||||
logger.error(f"[{chat_id}] 处理或存储消息 {message_id} 时出错: {e}")
|
||||
raise e
|
||||
finally:
|
||||
await self.complete_thinking(chat_id, message_id)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,15 +1,12 @@
|
||||
from typing import Dict, List, Optional, Type, Any
|
||||
from typing import Dict, List, Optional, Type
|
||||
from src.plugin_system.base.base_action import BaseAction
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
from src.plugin_system.base.component_types import ComponentType
|
||||
from src.plugin_system.base.component_types import ComponentType, ActionActivationType, ChatMode, ActionInfo
|
||||
|
||||
logger = get_logger("action_manager")
|
||||
|
||||
# 定义动作信息类型
|
||||
ActionInfo = Dict[str, Any]
|
||||
|
||||
|
||||
class ActionManager:
|
||||
"""
|
||||
@@ -20,8 +17,8 @@ class ActionManager:
|
||||
|
||||
# 类常量
|
||||
DEFAULT_RANDOM_PROBABILITY = 0.3
|
||||
DEFAULT_MODE = "all"
|
||||
DEFAULT_ACTIVATION_TYPE = "always"
|
||||
DEFAULT_MODE = ChatMode.ALL
|
||||
DEFAULT_ACTIVATION_TYPE = ActionActivationType.ALWAYS
|
||||
|
||||
def __init__(self):
|
||||
"""初始化动作管理器"""
|
||||
@@ -30,14 +27,11 @@ class ActionManager:
|
||||
# 当前正在使用的动作集合,默认加载默认动作
|
||||
self._using_actions: Dict[str, ActionInfo] = {}
|
||||
|
||||
# 默认动作集,仅作为快照,用于恢复默认
|
||||
self._default_actions: Dict[str, ActionInfo] = {}
|
||||
|
||||
# 加载插件动作
|
||||
self._load_plugin_actions()
|
||||
|
||||
# 初始化时将默认动作加载到使用中的动作
|
||||
self._using_actions = self._default_actions.copy()
|
||||
self._using_actions = component_registry.get_default_actions()
|
||||
|
||||
def _load_plugin_actions(self) -> None:
|
||||
"""
|
||||
@@ -54,43 +48,15 @@ class ActionManager:
|
||||
def _load_plugin_system_actions(self) -> None:
|
||||
"""从插件系统的component_registry加载Action组件"""
|
||||
try:
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
from src.plugin_system.base.component_types import ComponentType
|
||||
|
||||
# 获取所有Action组件
|
||||
action_components = component_registry.get_components_by_type(ComponentType.ACTION)
|
||||
action_components: Dict[str, ActionInfo] = component_registry.get_components_by_type(ComponentType.ACTION) # type: ignore
|
||||
|
||||
for action_name, action_info in action_components.items():
|
||||
if action_name in self._registered_actions:
|
||||
logger.debug(f"Action组件 {action_name} 已存在,跳过")
|
||||
continue
|
||||
|
||||
# 将插件系统的ActionInfo转换为ActionManager格式
|
||||
converted_action_info = {
|
||||
"description": action_info.description,
|
||||
"parameters": getattr(action_info, "action_parameters", {}),
|
||||
"require": getattr(action_info, "action_require", []),
|
||||
"associated_types": getattr(action_info, "associated_types", []),
|
||||
"enable_plugin": action_info.enabled,
|
||||
# 激活类型相关
|
||||
"focus_activation_type": action_info.focus_activation_type.value,
|
||||
"normal_activation_type": action_info.normal_activation_type.value,
|
||||
"random_activation_probability": action_info.random_activation_probability,
|
||||
"llm_judge_prompt": action_info.llm_judge_prompt,
|
||||
"activation_keywords": action_info.activation_keywords,
|
||||
"keyword_case_sensitive": action_info.keyword_case_sensitive,
|
||||
# 模式和并行设置
|
||||
"mode_enable": action_info.mode_enable.value,
|
||||
"parallel_action": action_info.parallel_action,
|
||||
# 插件信息
|
||||
"_plugin_name": getattr(action_info, "plugin_name", ""),
|
||||
}
|
||||
|
||||
self._registered_actions[action_name] = converted_action_info
|
||||
|
||||
# 如果启用,也添加到默认动作集
|
||||
if action_info.enabled:
|
||||
self._default_actions[action_name] = converted_action_info
|
||||
self._registered_actions[action_name] = action_info
|
||||
|
||||
logger.debug(
|
||||
f"从插件系统加载Action组件: {action_name} (插件: {getattr(action_info, 'plugin_name', 'unknown')})"
|
||||
@@ -114,6 +80,7 @@ class ActionManager:
|
||||
chat_stream: ChatStream,
|
||||
log_prefix: str,
|
||||
shutting_down: bool = False,
|
||||
action_message: Optional[dict] = None,
|
||||
) -> Optional[BaseAction]:
|
||||
"""
|
||||
创建动作处理器实例
|
||||
@@ -133,7 +100,9 @@ class ActionManager:
|
||||
"""
|
||||
try:
|
||||
# 获取组件类 - 明确指定查询Action类型
|
||||
component_class = component_registry.get_component_class(action_name, ComponentType.ACTION)
|
||||
component_class: Type[BaseAction] = component_registry.get_component_class(
|
||||
action_name, ComponentType.ACTION
|
||||
) # type: ignore
|
||||
if not component_class:
|
||||
logger.warning(f"{log_prefix} 未找到Action组件: {action_name}")
|
||||
return None
|
||||
@@ -157,6 +126,7 @@ class ActionManager:
|
||||
log_prefix=log_prefix,
|
||||
shutting_down=shutting_down,
|
||||
plugin_config=plugin_config,
|
||||
action_message=action_message,
|
||||
)
|
||||
|
||||
logger.debug(f"创建Action实例成功: {action_name}")
|
||||
@@ -173,37 +143,10 @@ class ActionManager:
|
||||
"""获取所有已注册的动作集"""
|
||||
return self._registered_actions.copy()
|
||||
|
||||
def get_default_actions(self) -> Dict[str, ActionInfo]:
|
||||
"""获取默认动作集"""
|
||||
return self._default_actions.copy()
|
||||
|
||||
def get_using_actions(self) -> Dict[str, ActionInfo]:
|
||||
"""获取当前正在使用的动作集合"""
|
||||
return self._using_actions.copy()
|
||||
|
||||
def get_using_actions_for_mode(self, mode: str) -> Dict[str, ActionInfo]:
|
||||
"""
|
||||
根据聊天模式获取可用的动作集合
|
||||
|
||||
Args:
|
||||
mode: 聊天模式 ("focus", "normal", "all")
|
||||
|
||||
Returns:
|
||||
Dict[str, ActionInfo]: 在指定模式下可用的动作集合
|
||||
"""
|
||||
filtered_actions = {}
|
||||
|
||||
for action_name, action_info in self._using_actions.items():
|
||||
action_mode = action_info.get("mode_enable", "all")
|
||||
|
||||
# 检查动作是否在当前模式下启用
|
||||
if action_mode == "all" or action_mode == mode:
|
||||
filtered_actions[action_name] = action_info
|
||||
logger.debug(f"动作 {action_name} 在模式 {mode} 下可用 (mode_enable: {action_mode})")
|
||||
|
||||
logger.debug(f"模式 {mode} 下可用动作: {list(filtered_actions.keys())}")
|
||||
return filtered_actions
|
||||
|
||||
def add_action_to_using(self, action_name: str) -> bool:
|
||||
"""
|
||||
添加已注册的动作到当前使用的动作集
|
||||
@@ -244,31 +187,31 @@ class ActionManager:
|
||||
logger.debug(f"已从使用集中移除动作 {action_name}")
|
||||
return True
|
||||
|
||||
def add_action(self, action_name: str, description: str, parameters: Dict = None, require: List = None) -> bool:
|
||||
"""
|
||||
添加新的动作到注册集
|
||||
# def add_action(self, action_name: str, description: str, parameters: Dict = None, require: List = None) -> bool:
|
||||
# """
|
||||
# 添加新的动作到注册集
|
||||
|
||||
Args:
|
||||
action_name: 动作名称
|
||||
description: 动作描述
|
||||
parameters: 动作参数定义,默认为空字典
|
||||
require: 动作依赖项,默认为空列表
|
||||
# Args:
|
||||
# action_name: 动作名称
|
||||
# description: 动作描述
|
||||
# parameters: 动作参数定义,默认为空字典
|
||||
# require: 动作依赖项,默认为空列表
|
||||
|
||||
Returns:
|
||||
bool: 添加是否成功
|
||||
"""
|
||||
if action_name in self._registered_actions:
|
||||
return False
|
||||
# Returns:
|
||||
# bool: 添加是否成功
|
||||
# """
|
||||
# if action_name in self._registered_actions:
|
||||
# return False
|
||||
|
||||
if parameters is None:
|
||||
parameters = {}
|
||||
if require is None:
|
||||
require = []
|
||||
# if parameters is None:
|
||||
# parameters = {}
|
||||
# if require is None:
|
||||
# require = []
|
||||
|
||||
action_info = {"description": description, "parameters": parameters, "require": require}
|
||||
# action_info = {"description": description, "parameters": parameters, "require": require}
|
||||
|
||||
self._registered_actions[action_name] = action_info
|
||||
return True
|
||||
# self._registered_actions[action_name] = action_info
|
||||
# return True
|
||||
|
||||
def remove_action(self, action_name: str) -> bool:
|
||||
"""从注册集移除指定动作"""
|
||||
@@ -287,10 +230,9 @@ class ActionManager:
|
||||
|
||||
def restore_actions(self) -> None:
|
||||
"""恢复到默认动作集"""
|
||||
logger.debug(
|
||||
f"恢复动作集: 从 {list(self._using_actions.keys())} 恢复到默认动作集 {list(self._default_actions.keys())}"
|
||||
)
|
||||
self._using_actions = self._default_actions.copy()
|
||||
actions_to_restore = list(self._using_actions.keys())
|
||||
self._using_actions = component_registry.get_default_actions()
|
||||
logger.debug(f"恢复动作集: 从 {actions_to_restore} 恢复到默认动作集 {list(self._using_actions.keys())}")
|
||||
|
||||
def add_system_action_if_needed(self, action_name: str) -> bool:
|
||||
"""
|
||||
@@ -320,4 +262,4 @@ class ActionManager:
|
||||
"""
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
|
||||
return component_registry.get_component_class(action_name)
|
||||
return component_registry.get_component_class(action_name) # type: ignore
|
||||
|
||||
@@ -1,15 +1,19 @@
|
||||
from typing import List, Optional, Any, Dict
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.focus_chat.focus_loop_info import FocusLoopInfo
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.config.config import global_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
import random
|
||||
import asyncio
|
||||
import hashlib
|
||||
import time
|
||||
from typing import List, Any, Dict, TYPE_CHECKING
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager, ChatMessageContext
|
||||
from src.chat.planner_actions.action_manager import ActionManager
|
||||
from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat, build_readable_messages
|
||||
from src.plugin_system.base.component_types import ActionInfo, ActionActivationType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
|
||||
logger = get_logger("action_manager")
|
||||
|
||||
@@ -25,7 +29,7 @@ class ActionModifier:
|
||||
def __init__(self, action_manager: ActionManager, chat_id: str):
|
||||
"""初始化动作处理器"""
|
||||
self.chat_id = chat_id
|
||||
self.chat_stream = get_chat_manager().get_stream(self.chat_id)
|
||||
self.chat_stream: ChatStream = get_chat_manager().get_stream(self.chat_id) # type: ignore
|
||||
self.log_prefix = f"[{get_chat_manager().get_stream_name(self.chat_id) or self.chat_id}]"
|
||||
|
||||
self.action_manager = action_manager
|
||||
@@ -43,10 +47,9 @@ class ActionModifier:
|
||||
|
||||
async def modify_actions(
|
||||
self,
|
||||
loop_info=None,
|
||||
mode: str = "focus",
|
||||
history_loop=None,
|
||||
message_content: str = "",
|
||||
):
|
||||
): # sourcery skip: use-named-expression
|
||||
"""
|
||||
动作修改流程,整合传统观察处理和新的激活类型判定
|
||||
|
||||
@@ -62,12 +65,12 @@ class ActionModifier:
|
||||
removals_s2 = []
|
||||
|
||||
self.action_manager.restore_actions()
|
||||
all_actions = self.action_manager.get_using_actions_for_mode(mode)
|
||||
all_actions = self.action_manager.get_using_actions()
|
||||
|
||||
message_list_before_now_half = get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=self.chat_stream.stream_id,
|
||||
timestamp=time.time(),
|
||||
limit=int(global_config.chat.max_context_size * 0.5),
|
||||
limit=min(int(global_config.chat.max_context_size * 0.33), 10),
|
||||
)
|
||||
chat_content = build_readable_messages(
|
||||
message_list_before_now_half,
|
||||
@@ -82,10 +85,10 @@ class ActionModifier:
|
||||
chat_content = chat_content + "\n" + f"现在,最新的消息是:{message_content}"
|
||||
|
||||
# === 第一阶段:传统观察处理 ===
|
||||
if loop_info:
|
||||
removals_from_loop = await self.analyze_loop_actions(loop_info)
|
||||
if removals_from_loop:
|
||||
removals_s1.extend(removals_from_loop)
|
||||
# if history_loop:
|
||||
# removals_from_loop = await self.analyze_loop_actions(history_loop)
|
||||
# if removals_from_loop:
|
||||
# removals_s1.extend(removals_from_loop)
|
||||
|
||||
# 检查动作的关联类型
|
||||
chat_context = self.chat_stream.context
|
||||
@@ -104,12 +107,11 @@ class ActionModifier:
|
||||
logger.debug(f"{self.log_prefix}开始激活类型判定阶段")
|
||||
|
||||
# 获取当前使用的动作集(经过第一阶段处理)
|
||||
current_using_actions = self.action_manager.get_using_actions_for_mode(mode)
|
||||
current_using_actions = self.action_manager.get_using_actions()
|
||||
|
||||
# 获取因激活类型判定而需要移除的动作
|
||||
removals_s2 = await self._get_deactivated_actions_by_type(
|
||||
current_using_actions,
|
||||
mode,
|
||||
chat_content,
|
||||
)
|
||||
|
||||
@@ -120,28 +122,27 @@ class ActionModifier:
|
||||
|
||||
# === 统一日志记录 ===
|
||||
all_removals = removals_s1 + removals_s2
|
||||
removals_summary: str = ""
|
||||
if all_removals:
|
||||
removals_summary = " | ".join([f"{name}({reason})" for name, reason in all_removals])
|
||||
|
||||
logger.info(
|
||||
f"{self.log_prefix}{mode}模式动作修改流程结束,最终可用动作: {list(self.action_manager.get_using_actions_for_mode(mode).keys())}||移除记录: {removals_summary}"
|
||||
f"{self.log_prefix} 动作修改流程结束,最终可用动作: {list(self.action_manager.get_using_actions().keys())}||移除记录: {removals_summary}"
|
||||
)
|
||||
|
||||
def _check_action_associated_types(self, all_actions, chat_context):
|
||||
def _check_action_associated_types(self, all_actions: Dict[str, ActionInfo], chat_context: ChatMessageContext):
|
||||
type_mismatched_actions = []
|
||||
for action_name, data in all_actions.items():
|
||||
if data.get("associated_types"):
|
||||
if not chat_context.check_types(data["associated_types"]):
|
||||
associated_types_str = ", ".join(data["associated_types"])
|
||||
reason = f"适配器不支持(需要: {associated_types_str})"
|
||||
type_mismatched_actions.append((action_name, reason))
|
||||
logger.debug(f"{self.log_prefix}决定移除动作: {action_name},原因: {reason}")
|
||||
for action_name, action_info in all_actions.items():
|
||||
if action_info.associated_types and not chat_context.check_types(action_info.associated_types):
|
||||
associated_types_str = ", ".join(action_info.associated_types)
|
||||
reason = f"适配器不支持(需要: {associated_types_str})"
|
||||
type_mismatched_actions.append((action_name, reason))
|
||||
logger.debug(f"{self.log_prefix}决定移除动作: {action_name},原因: {reason}")
|
||||
return type_mismatched_actions
|
||||
|
||||
async def _get_deactivated_actions_by_type(
|
||||
self,
|
||||
actions_with_info: Dict[str, Any],
|
||||
mode: str = "focus",
|
||||
actions_with_info: Dict[str, ActionInfo],
|
||||
chat_content: str = "",
|
||||
) -> List[tuple[str, str]]:
|
||||
"""
|
||||
@@ -163,29 +164,33 @@ class ActionModifier:
|
||||
random.shuffle(actions_to_check)
|
||||
|
||||
for action_name, action_info in actions_to_check:
|
||||
activation_type = f"{mode}_activation_type"
|
||||
activation_type = action_info.get(activation_type, "always")
|
||||
activation_type = action_info.activation_type or action_info.focus_activation_type
|
||||
|
||||
if activation_type == "always":
|
||||
if activation_type == ActionActivationType.ALWAYS:
|
||||
continue # 总是激活,无需处理
|
||||
|
||||
elif activation_type == "random":
|
||||
probability = action_info.get("random_activation_probability", ActionManager.DEFAULT_RANDOM_PROBABILITY)
|
||||
if not (random.random() < probability):
|
||||
elif activation_type == ActionActivationType.RANDOM:
|
||||
probability = action_info.random_activation_probability or ActionManager.DEFAULT_RANDOM_PROBABILITY
|
||||
if random.random() >= probability:
|
||||
reason = f"RANDOM类型未触发(概率{probability})"
|
||||
deactivated_actions.append((action_name, reason))
|
||||
logger.debug(f"{self.log_prefix}未激活动作: {action_name},原因: {reason}")
|
||||
|
||||
elif activation_type == "keyword":
|
||||
elif activation_type == ActionActivationType.KEYWORD:
|
||||
if not self._check_keyword_activation(action_name, action_info, chat_content):
|
||||
keywords = action_info.get("activation_keywords", [])
|
||||
keywords = action_info.activation_keywords
|
||||
reason = f"关键词未匹配(关键词: {keywords})"
|
||||
deactivated_actions.append((action_name, reason))
|
||||
logger.debug(f"{self.log_prefix}未激活动作: {action_name},原因: {reason}")
|
||||
|
||||
elif activation_type == "llm_judge":
|
||||
elif activation_type == ActionActivationType.LLM_JUDGE:
|
||||
llm_judge_actions[action_name] = action_info
|
||||
|
||||
elif activation_type == ActionActivationType.NEVER:
|
||||
reason = "激活类型为never"
|
||||
deactivated_actions.append((action_name, reason))
|
||||
logger.debug(f"{self.log_prefix}未激活动作: {action_name},原因: 激活类型为never")
|
||||
|
||||
else:
|
||||
logger.warning(f"{self.log_prefix}未知的激活类型: {activation_type},跳过处理")
|
||||
|
||||
@@ -203,35 +208,6 @@ class ActionModifier:
|
||||
|
||||
return deactivated_actions
|
||||
|
||||
async def process_actions_for_planner(
|
||||
self, observed_messages_str: str = "", chat_context: Optional[str] = None, extra_context: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
[已废弃] 此方法现在已被整合到 modify_actions() 中
|
||||
|
||||
为了保持向后兼容性而保留,但建议直接使用 ActionManager.get_using_actions()
|
||||
规划器应该直接从 ActionManager 获取最终的可用动作集,而不是调用此方法
|
||||
|
||||
新的架构:
|
||||
1. 主循环调用 modify_actions() 处理完整的动作管理流程
|
||||
2. 规划器直接使用 ActionManager.get_using_actions() 获取最终动作集
|
||||
"""
|
||||
logger.warning(
|
||||
f"{self.log_prefix}process_actions_for_planner() 已废弃,建议规划器直接使用 ActionManager.get_using_actions()"
|
||||
)
|
||||
|
||||
# 为了向后兼容,仍然返回当前使用的动作集
|
||||
current_using_actions = self.action_manager.get_using_actions()
|
||||
all_registered_actions = self.action_manager.get_registered_actions()
|
||||
|
||||
# 构建完整的动作信息
|
||||
result = {}
|
||||
for action_name in current_using_actions.keys():
|
||||
if action_name in all_registered_actions:
|
||||
result[action_name] = all_registered_actions[action_name]
|
||||
|
||||
return result
|
||||
|
||||
def _generate_context_hash(self, chat_content: str) -> str:
|
||||
"""生成上下文的哈希值用于缓存"""
|
||||
context_content = f"{chat_content}"
|
||||
@@ -299,7 +275,7 @@ class ActionModifier:
|
||||
task_results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# 处理结果并更新缓存
|
||||
for _, (action_name, result) in enumerate(zip(task_names, task_results)):
|
||||
for action_name, result in zip(task_names, task_results, strict=False):
|
||||
if isinstance(result, Exception):
|
||||
logger.error(f"{self.log_prefix}LLM判定action {action_name} 时出错: {result}")
|
||||
results[action_name] = False
|
||||
@@ -315,7 +291,7 @@ class ActionModifier:
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix}并行LLM判定失败: {e}")
|
||||
# 如果并行执行失败,为所有任务返回False
|
||||
for action_name in tasks_to_run.keys():
|
||||
for action_name in tasks_to_run:
|
||||
results[action_name] = False
|
||||
|
||||
# 清理过期缓存
|
||||
@@ -326,10 +302,11 @@ class ActionModifier:
|
||||
def _cleanup_expired_cache(self, current_time: float):
|
||||
"""清理过期的缓存条目"""
|
||||
expired_keys = []
|
||||
for cache_key, cache_data in self._llm_judge_cache.items():
|
||||
if current_time - cache_data["timestamp"] > self._cache_expiry_time:
|
||||
expired_keys.append(cache_key)
|
||||
|
||||
expired_keys.extend(
|
||||
cache_key
|
||||
for cache_key, cache_data in self._llm_judge_cache.items()
|
||||
if current_time - cache_data["timestamp"] > self._cache_expiry_time
|
||||
)
|
||||
for key in expired_keys:
|
||||
del self._llm_judge_cache[key]
|
||||
|
||||
@@ -339,7 +316,7 @@ class ActionModifier:
|
||||
async def _llm_judge_action(
|
||||
self,
|
||||
action_name: str,
|
||||
action_info: Dict[str, Any],
|
||||
action_info: ActionInfo,
|
||||
chat_content: str = "",
|
||||
) -> bool:
|
||||
"""
|
||||
@@ -358,9 +335,9 @@ class ActionModifier:
|
||||
|
||||
try:
|
||||
# 构建判定提示词
|
||||
action_description = action_info.get("description", "")
|
||||
action_require = action_info.get("require", [])
|
||||
custom_prompt = action_info.get("llm_judge_prompt", "")
|
||||
action_description = action_info.description
|
||||
action_require = action_info.action_require
|
||||
custom_prompt = action_info.llm_judge_prompt
|
||||
|
||||
# 构建基础判定提示词
|
||||
base_prompt = f"""
|
||||
@@ -408,7 +385,7 @@ class ActionModifier:
|
||||
def _check_keyword_activation(
|
||||
self,
|
||||
action_name: str,
|
||||
action_info: Dict[str, Any],
|
||||
action_info: ActionInfo,
|
||||
chat_content: str = "",
|
||||
) -> bool:
|
||||
"""
|
||||
@@ -425,8 +402,8 @@ class ActionModifier:
|
||||
bool: 是否应该激活此action
|
||||
"""
|
||||
|
||||
activation_keywords = action_info.get("activation_keywords", [])
|
||||
case_sensitive = action_info.get("keyword_case_sensitive", False)
|
||||
activation_keywords = action_info.activation_keywords
|
||||
case_sensitive = action_info.keyword_case_sensitive
|
||||
|
||||
if not activation_keywords:
|
||||
logger.warning(f"{self.log_prefix}动作 {action_name} 设置为关键词触发但未配置关键词")
|
||||
@@ -459,84 +436,92 @@ class ActionModifier:
|
||||
logger.debug(f"{self.log_prefix}动作 {action_name} 未匹配到任何关键词: {activation_keywords}")
|
||||
return False
|
||||
|
||||
async def analyze_loop_actions(self, obs: FocusLoopInfo) -> List[tuple[str, str]]:
|
||||
"""分析最近的循环内容并决定动作的移除
|
||||
# async def analyze_loop_actions(self, history_loop: List[CycleDetail]) -> List[tuple[str, str]]:
|
||||
# """分析最近的循环内容并决定动作的移除
|
||||
|
||||
Returns:
|
||||
List[Tuple[str, str]]: 包含要删除的动作及原因的元组列表
|
||||
[("action3", "some reason")]
|
||||
"""
|
||||
removals = []
|
||||
# Returns:
|
||||
# List[Tuple[str, str]]: 包含要删除的动作及原因的元组列表
|
||||
# [("action3", "some reason")]
|
||||
# """
|
||||
# removals = []
|
||||
|
||||
# 获取最近10次循环
|
||||
recent_cycles = obs.history_loop[-10:] if len(obs.history_loop) > 10 else obs.history_loop
|
||||
if not recent_cycles:
|
||||
return removals
|
||||
# # 获取最近10次循环
|
||||
# recent_cycles = history_loop[-10:] if len(history_loop) > 10 else history_loop
|
||||
# if not recent_cycles:
|
||||
# return removals
|
||||
|
||||
reply_sequence = [] # 记录最近的动作序列
|
||||
# reply_sequence = [] # 记录最近的动作序列
|
||||
|
||||
for cycle in recent_cycles:
|
||||
action_result = cycle.loop_plan_info.get("action_result", {})
|
||||
action_type = action_result.get("action_type", "unknown")
|
||||
reply_sequence.append(action_type == "reply")
|
||||
# for cycle in recent_cycles:
|
||||
# action_result = cycle.loop_plan_info.get("action_result", {})
|
||||
# action_type = action_result.get("action_type", "unknown")
|
||||
# reply_sequence.append(action_type == "reply")
|
||||
|
||||
# 计算连续回复的相关阈值
|
||||
# # 计算连续回复的相关阈值
|
||||
|
||||
max_reply_num = int(global_config.focus_chat.consecutive_replies * 3.2)
|
||||
sec_thres_reply_num = int(global_config.focus_chat.consecutive_replies * 2)
|
||||
one_thres_reply_num = int(global_config.focus_chat.consecutive_replies * 1.5)
|
||||
# max_reply_num = int(global_config.focus_chat.consecutive_replies * 3.2)
|
||||
# sec_thres_reply_num = int(global_config.focus_chat.consecutive_replies * 2)
|
||||
# one_thres_reply_num = int(global_config.focus_chat.consecutive_replies * 1.5)
|
||||
|
||||
# 获取最近max_reply_num次的reply状态
|
||||
if len(reply_sequence) >= max_reply_num:
|
||||
last_max_reply_num = reply_sequence[-max_reply_num:]
|
||||
else:
|
||||
last_max_reply_num = reply_sequence[:]
|
||||
# # 获取最近max_reply_num次的reply状态
|
||||
# if len(reply_sequence) >= max_reply_num:
|
||||
# last_max_reply_num = reply_sequence[-max_reply_num:]
|
||||
# else:
|
||||
# last_max_reply_num = reply_sequence[:]
|
||||
|
||||
# 详细打印阈值和序列信息,便于调试
|
||||
logger.info(
|
||||
f"连续回复阈值: max={max_reply_num}, sec={sec_thres_reply_num}, one={one_thres_reply_num},"
|
||||
f"最近reply序列: {last_max_reply_num}"
|
||||
)
|
||||
# print(f"consecutive_replies: {consecutive_replies}")
|
||||
# # 详细打印阈值和序列信息,便于调试
|
||||
# logger.info(
|
||||
# f"连续回复阈值: max={max_reply_num}, sec={sec_thres_reply_num}, one={one_thres_reply_num},"
|
||||
# f"最近reply序列: {last_max_reply_num}"
|
||||
# )
|
||||
# # print(f"consecutive_replies: {consecutive_replies}")
|
||||
|
||||
# 根据最近的reply情况决定是否移除reply动作
|
||||
if len(last_max_reply_num) >= max_reply_num and all(last_max_reply_num):
|
||||
# 如果最近max_reply_num次都是reply,直接移除
|
||||
reason = f"连续回复过多(最近{len(last_max_reply_num)}次全是reply,超过阈值{max_reply_num})"
|
||||
removals.append(("reply", reason))
|
||||
# reply_count = len(last_max_reply_num) - no_reply_count
|
||||
elif len(last_max_reply_num) >= sec_thres_reply_num and all(last_max_reply_num[-sec_thres_reply_num:]):
|
||||
# 如果最近sec_thres_reply_num次都是reply,40%概率移除
|
||||
removal_probability = 0.4 / global_config.focus_chat.consecutive_replies
|
||||
if random.random() < removal_probability:
|
||||
reason = (
|
||||
f"连续回复较多(最近{sec_thres_reply_num}次全是reply,{removal_probability:.2f}概率移除,触发移除)"
|
||||
)
|
||||
removals.append(("reply", reason))
|
||||
elif len(last_max_reply_num) >= one_thres_reply_num and all(last_max_reply_num[-one_thres_reply_num:]):
|
||||
# 如果最近one_thres_reply_num次都是reply,20%概率移除
|
||||
removal_probability = 0.2 / global_config.focus_chat.consecutive_replies
|
||||
if random.random() < removal_probability:
|
||||
reason = (
|
||||
f"连续回复检测(最近{one_thres_reply_num}次全是reply,{removal_probability:.2f}概率移除,触发移除)"
|
||||
)
|
||||
removals.append(("reply", reason))
|
||||
else:
|
||||
logger.debug(f"{self.log_prefix}连续回复检测:无需移除reply动作,最近回复模式正常")
|
||||
# # 根据最近的reply情况决定是否移除reply动作
|
||||
# if len(last_max_reply_num) >= max_reply_num and all(last_max_reply_num):
|
||||
# # 如果最近max_reply_num次都是reply,直接移除
|
||||
# reason = f"连续回复过多(最近{len(last_max_reply_num)}次全是reply,超过阈值{max_reply_num})"
|
||||
# removals.append(("reply", reason))
|
||||
# # reply_count = len(last_max_reply_num) - no_reply_count
|
||||
# elif len(last_max_reply_num) >= sec_thres_reply_num and all(last_max_reply_num[-sec_thres_reply_num:]):
|
||||
# # 如果最近sec_thres_reply_num次都是reply,40%概率移除
|
||||
# removal_probability = 0.4 / global_config.focus_chat.consecutive_replies
|
||||
# if random.random() < removal_probability:
|
||||
# reason = (
|
||||
# f"连续回复较多(最近{sec_thres_reply_num}次全是reply,{removal_probability:.2f}概率移除,触发移除)"
|
||||
# )
|
||||
# removals.append(("reply", reason))
|
||||
# elif len(last_max_reply_num) >= one_thres_reply_num and all(last_max_reply_num[-one_thres_reply_num:]):
|
||||
# # 如果最近one_thres_reply_num次都是reply,20%概率移除
|
||||
# removal_probability = 0.2 / global_config.focus_chat.consecutive_replies
|
||||
# if random.random() < removal_probability:
|
||||
# reason = (
|
||||
# f"连续回复检测(最近{one_thres_reply_num}次全是reply,{removal_probability:.2f}概率移除,触发移除)"
|
||||
# )
|
||||
# removals.append(("reply", reason))
|
||||
# else:
|
||||
# logger.debug(f"{self.log_prefix}连续回复检测:无需移除reply动作,最近回复模式正常")
|
||||
|
||||
return removals
|
||||
# return removals
|
||||
|
||||
def get_available_actions_count(self) -> int:
|
||||
"""获取当前可用动作数量(排除默认的no_action)"""
|
||||
current_actions = self.action_manager.get_using_actions_for_mode("normal")
|
||||
# 排除no_action(如果存在)
|
||||
filtered_actions = {k: v for k, v in current_actions.items() if k != "no_action"}
|
||||
return len(filtered_actions)
|
||||
# def get_available_actions_count(self, mode: str = "focus") -> int:
|
||||
# """获取当前可用动作数量(排除默认的no_action)"""
|
||||
# current_actions = self.action_manager.get_using_actions_for_mode(mode)
|
||||
# # 排除no_action(如果存在)
|
||||
# filtered_actions = {k: v for k, v in current_actions.items() if k != "no_action"}
|
||||
# return len(filtered_actions)
|
||||
|
||||
def should_skip_planning(self) -> bool:
|
||||
"""判断是否应该跳过规划过程"""
|
||||
available_count = self.get_available_actions_count()
|
||||
if available_count == 0:
|
||||
logger.debug(f"{self.log_prefix} 没有可用动作,跳过规划")
|
||||
return True
|
||||
return False
|
||||
# def should_skip_planning_for_no_reply(self) -> bool:
|
||||
# """判断是否应该跳过规划过程"""
|
||||
# current_actions = self.action_manager.get_using_actions_for_mode("focus")
|
||||
# # 排除no_action(如果存在)
|
||||
# if len(current_actions) == 1 and "no_reply" in current_actions:
|
||||
# return True
|
||||
# return False
|
||||
|
||||
# def should_skip_planning_for_no_action(self) -> bool:
|
||||
# """判断是否应该跳过规划过程"""
|
||||
# available_count = self.action_manager.get_using_actions_for_mode("normal")
|
||||
# if available_count == 0:
|
||||
# logger.debug(f"{self.log_prefix} 没有可用动作,跳过规划")
|
||||
# return True
|
||||
# return False
|
||||
|
||||
@@ -1,18 +1,26 @@
|
||||
import json # <--- 确保导入 json
|
||||
import json
|
||||
import time
|
||||
import traceback
|
||||
from typing import Dict, Any, Optional
|
||||
from typing import Dict, Any, Optional, Tuple
|
||||
from rich.traceback import install
|
||||
from datetime import datetime
|
||||
from json_repair import repair_json
|
||||
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.chat.planner_actions.action_manager import ActionManager
|
||||
from json_repair import repair_json
|
||||
from src.chat.utils.chat_message_builder import (
|
||||
build_readable_actions,
|
||||
get_actions_by_timestamp_with_chat,
|
||||
build_readable_messages_with_id,
|
||||
get_raw_msg_before_timestamp_with_chat,
|
||||
)
|
||||
from src.chat.utils.utils import get_chat_type_and_target_info
|
||||
from datetime import datetime
|
||||
from src.chat.planner_actions.action_manager import ActionManager
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat
|
||||
import time
|
||||
from src.plugin_system.base.component_types import ActionInfo, ChatMode
|
||||
|
||||
|
||||
logger = get_logger("planner")
|
||||
|
||||
@@ -23,17 +31,22 @@ def init_prompt():
|
||||
Prompt(
|
||||
"""
|
||||
{time_block}
|
||||
{indentify_block}
|
||||
{identity_block}
|
||||
你现在需要根据聊天内容,选择的合适的action来参与聊天。
|
||||
{chat_context_description},以下是具体的聊天内容:
|
||||
{chat_content_block}
|
||||
|
||||
|
||||
{moderation_prompt}
|
||||
|
||||
现在请你根据{by_what}选择合适的action:
|
||||
现在请你根据{by_what}选择合适的action和触发action的消息:
|
||||
你刚刚选择并执行过的action是:
|
||||
{actions_before_now_block}
|
||||
|
||||
{no_action_block}
|
||||
{action_options_text}
|
||||
|
||||
你必须从上面列出的可用action中选择一个,并说明原因。
|
||||
你必须从上面列出的可用action中选择一个,并说明触发action的消息id和原因。
|
||||
|
||||
请根据动作示例,以严格的 JSON 格式输出,且仅包含 JSON 内容:
|
||||
""",
|
||||
@@ -46,7 +59,8 @@ def init_prompt():
|
||||
动作描述:{action_description}
|
||||
{action_require}
|
||||
{{
|
||||
"action": "{action_name}",{action_parameters}
|
||||
"action": "{action_name}",{action_parameters}{target_prompt}
|
||||
"reason":"触发action的原因"
|
||||
}}
|
||||
""",
|
||||
"action_prompt",
|
||||
@@ -54,20 +68,38 @@ def init_prompt():
|
||||
|
||||
|
||||
class ActionPlanner:
|
||||
def __init__(self, chat_id: str, action_manager: ActionManager, mode: str = "focus"):
|
||||
def __init__(self, chat_id: str, action_manager: ActionManager):
|
||||
self.chat_id = chat_id
|
||||
self.log_prefix = f"[{get_chat_manager().get_stream_name(chat_id) or chat_id}]"
|
||||
self.mode = mode
|
||||
self.action_manager = action_manager
|
||||
# LLM规划器配置
|
||||
self.planner_llm = LLMRequest(
|
||||
model=global_config.model.planner,
|
||||
request_type=f"{self.mode}.planner", # 用于动作规划
|
||||
request_type="planner", # 用于动作规划
|
||||
)
|
||||
|
||||
self.last_obs_time_mark = 0.0
|
||||
|
||||
async def plan(self) -> Dict[str, Any]:
|
||||
def find_message_by_id(self, message_id: str, message_id_list: list) -> Optional[Dict[str, Any]]:
|
||||
# sourcery skip: use-next
|
||||
"""
|
||||
根据message_id从message_id_list中查找对应的原始消息
|
||||
|
||||
Args:
|
||||
message_id: 要查找的消息ID
|
||||
message_id_list: 消息ID列表,格式为[{'id': str, 'message': dict}, ...]
|
||||
|
||||
Returns:
|
||||
找到的原始消息字典,如果未找到则返回None
|
||||
"""
|
||||
for item in message_id_list:
|
||||
if item.get("id") == message_id:
|
||||
return item.get("message")
|
||||
return None
|
||||
|
||||
async def plan(
|
||||
self, mode: ChatMode = ChatMode.FOCUS
|
||||
) -> Tuple[Dict[str, Dict[str, Any] | str], Optional[Dict[str, Any]]]: # sourcery skip: dict-comprehension
|
||||
"""
|
||||
规划器 (Planner): 使用LLM根据上下文决定做出什么动作。
|
||||
"""
|
||||
@@ -75,6 +107,9 @@ class ActionPlanner:
|
||||
action = "no_reply" # 默认动作
|
||||
reasoning = "规划器初始化默认"
|
||||
action_data = {}
|
||||
current_available_actions: Dict[str, ActionInfo] = {}
|
||||
target_message: Optional[Dict[str, Any]] = None # 初始化target_message变量
|
||||
prompt: str = ""
|
||||
|
||||
try:
|
||||
is_group_chat = True
|
||||
@@ -82,11 +117,11 @@ class ActionPlanner:
|
||||
is_group_chat, chat_target_info = get_chat_type_and_target_info(self.chat_id)
|
||||
logger.debug(f"{self.log_prefix}获取到聊天信息 - 群聊: {is_group_chat}, 目标信息: {chat_target_info}")
|
||||
|
||||
current_available_actions_dict = self.action_manager.get_using_actions_for_mode(self.mode)
|
||||
current_available_actions_dict = self.action_manager.get_using_actions()
|
||||
|
||||
# 获取完整的动作信息
|
||||
all_registered_actions = self.action_manager.get_registered_actions()
|
||||
current_available_actions = {}
|
||||
|
||||
for action_name in current_available_actions_dict.keys():
|
||||
if action_name in all_registered_actions:
|
||||
current_available_actions[action_name] = all_registered_actions[action_name]
|
||||
@@ -94,24 +129,24 @@ class ActionPlanner:
|
||||
logger.warning(f"{self.log_prefix}使用中的动作 {action_name} 未在已注册动作中找到")
|
||||
|
||||
# 如果没有可用动作或只有no_reply动作,直接返回no_reply
|
||||
if not current_available_actions or (
|
||||
len(current_available_actions) == 1 and "no_reply" in current_available_actions
|
||||
):
|
||||
action = "no_reply"
|
||||
reasoning = "没有可用的动作" if not current_available_actions else "只有no_reply动作可用,跳过规划"
|
||||
if not current_available_actions:
|
||||
action = "no_reply" if mode == ChatMode.FOCUS else "no_action"
|
||||
reasoning = "没有可用的动作"
|
||||
logger.info(f"{self.log_prefix}{reasoning}")
|
||||
logger.debug(
|
||||
f"{self.log_prefix}[focus]沉默后恢复到默认动作集, 当前可用: {list(self.action_manager.get_using_actions().keys())}"
|
||||
)
|
||||
return {
|
||||
"action_result": {"action_type": action, "action_data": action_data, "reasoning": reasoning},
|
||||
}
|
||||
"action_result": {
|
||||
"action_type": action,
|
||||
"action_data": action_data,
|
||||
"reasoning": reasoning,
|
||||
},
|
||||
}, None
|
||||
|
||||
# --- 构建提示词 (调用修改后的 PromptBuilder 方法) ---
|
||||
prompt = await self.build_planner_prompt(
|
||||
prompt, message_id_list = await self.build_planner_prompt(
|
||||
is_group_chat=is_group_chat, # <-- Pass HFC state
|
||||
chat_target_info=chat_target_info, # <-- 传递获取到的聊天目标信息
|
||||
current_available_actions=current_available_actions, # <-- Pass determined actions
|
||||
mode=mode,
|
||||
)
|
||||
|
||||
# --- 调用 LLM (普通文本生成) ---
|
||||
@@ -132,7 +167,7 @@ class ActionPlanner:
|
||||
|
||||
except Exception as req_e:
|
||||
logger.error(f"{self.log_prefix}LLM 请求执行失败: {req_e}")
|
||||
reasoning = f"LLM 请求失败,你的模型出现问题: {req_e}"
|
||||
reasoning = f"LLM 请求失败,模型出现问题: {req_e}"
|
||||
action = "no_reply"
|
||||
|
||||
if llm_content:
|
||||
@@ -154,19 +189,26 @@ class ActionPlanner:
|
||||
reasoning = parsed_json.get("reasoning", "未提供原因")
|
||||
|
||||
# 将所有其他属性添加到action_data
|
||||
action_data = {}
|
||||
for key, value in parsed_json.items():
|
||||
if key not in ["action", "reasoning"]:
|
||||
action_data[key] = value
|
||||
|
||||
# 在FOCUS模式下,非no_reply动作需要target_message_id
|
||||
if mode == ChatMode.FOCUS and action != "no_reply":
|
||||
if target_message_id := parsed_json.get("target_message_id"):
|
||||
# 根据target_message_id查找原始消息
|
||||
target_message = self.find_message_by_id(target_message_id, message_id_list)
|
||||
else:
|
||||
logger.warning(f"{self.log_prefix}FOCUS模式下动作'{action}'缺少target_message_id")
|
||||
|
||||
if action == "no_action":
|
||||
reasoning = "normal决定不使用额外动作"
|
||||
elif action not in current_available_actions:
|
||||
elif action != "no_reply" and action != "reply" and action not in current_available_actions:
|
||||
logger.warning(
|
||||
f"{self.log_prefix}LLM 返回了当前不可用或无效的动作: '{action}' (可用: {list(current_available_actions.keys())}),将强制使用 'no_reply'"
|
||||
)
|
||||
action = "no_reply"
|
||||
reasoning = f"LLM 返回了当前不可用的动作 '{action}' (可用: {list(current_available_actions.keys())})。原始理由: {reasoning}"
|
||||
action = "no_reply"
|
||||
|
||||
except Exception as json_e:
|
||||
logger.warning(f"{self.log_prefix}解析LLM响应JSON失败 {json_e}. LLM原始输出: '{llm_content}'")
|
||||
@@ -182,8 +224,7 @@ class ActionPlanner:
|
||||
|
||||
is_parallel = False
|
||||
if action in current_available_actions:
|
||||
action_info = current_available_actions[action]
|
||||
is_parallel = action_info.get("parallel_action", False)
|
||||
is_parallel = current_available_actions[action].parallel_action
|
||||
|
||||
action_result = {
|
||||
"action_type": action,
|
||||
@@ -193,28 +234,27 @@ class ActionPlanner:
|
||||
"is_parallel": is_parallel,
|
||||
}
|
||||
|
||||
plan_result = {
|
||||
return {
|
||||
"action_result": action_result,
|
||||
"action_prompt": prompt,
|
||||
}
|
||||
|
||||
return plan_result
|
||||
}, target_message
|
||||
|
||||
async def build_planner_prompt(
|
||||
self,
|
||||
is_group_chat: bool, # Now passed as argument
|
||||
chat_target_info: Optional[dict], # Now passed as argument
|
||||
current_available_actions,
|
||||
) -> str:
|
||||
current_available_actions: Dict[str, ActionInfo],
|
||||
mode: ChatMode = ChatMode.FOCUS,
|
||||
) -> tuple[str, list]: # sourcery skip: use-join
|
||||
"""构建 Planner LLM 的提示词 (获取模板并填充数据)"""
|
||||
try:
|
||||
message_list_before_now = get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=self.chat_id,
|
||||
timestamp=time.time(),
|
||||
limit=global_config.chat.max_context_size,
|
||||
limit=int(global_config.chat.max_context_size * 0.6),
|
||||
)
|
||||
|
||||
chat_content_block = build_readable_messages(
|
||||
chat_content_block, message_id_list = build_readable_messages_with_id(
|
||||
messages=message_list_before_now,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
read_mark=self.last_obs_time_mark,
|
||||
@@ -222,13 +262,41 @@ class ActionPlanner:
|
||||
show_actions=True,
|
||||
)
|
||||
|
||||
actions_before_now = get_actions_by_timestamp_with_chat(
|
||||
chat_id=self.chat_id,
|
||||
timestamp_end=time.time(),
|
||||
limit=5,
|
||||
)
|
||||
|
||||
actions_before_now_block = build_readable_actions(
|
||||
actions=actions_before_now,
|
||||
)
|
||||
|
||||
self.last_obs_time_mark = time.time()
|
||||
|
||||
if self.mode == "focus":
|
||||
if mode == ChatMode.FOCUS:
|
||||
by_what = "聊天内容"
|
||||
no_action_block = ""
|
||||
target_prompt = '\n "target_message_id":"触发action的消息id"'
|
||||
no_action_block = """重要说明1:
|
||||
- 'no_reply' 表示只进行不进行回复,等待合适的回复时机
|
||||
- 当你刚刚发送了消息,没有人回复时,选择no_reply
|
||||
- 当你一次发送了太多消息,为了避免打扰聊天节奏,选择no_reply
|
||||
|
||||
动作:reply
|
||||
动作描述:参与聊天回复,发送文本进行表达
|
||||
- 你想要闲聊或者随便附和
|
||||
- 有人提到你
|
||||
- 如果你刚刚进行了回复,不要对同一个话题重复回应
|
||||
{
|
||||
"action": "reply",
|
||||
"target_message_id":"触发action的消息id",
|
||||
"reason":"回复的原因"
|
||||
}
|
||||
|
||||
"""
|
||||
else:
|
||||
by_what = "聊天内容和用户的最新消息"
|
||||
target_prompt = ""
|
||||
no_action_block = """重要说明:
|
||||
- 'no_action' 表示只进行普通聊天回复,不执行任何额外动作
|
||||
- 其他action表示在普通回复的基础上,执行相应的额外动作"""
|
||||
@@ -244,25 +312,26 @@ class ActionPlanner:
|
||||
action_options_block = ""
|
||||
|
||||
for using_actions_name, using_actions_info in current_available_actions.items():
|
||||
if using_actions_info["parameters"]:
|
||||
if using_actions_info.action_parameters:
|
||||
param_text = "\n"
|
||||
for param_name, param_description in using_actions_info["parameters"].items():
|
||||
for param_name, param_description in using_actions_info.action_parameters.items():
|
||||
param_text += f' "{param_name}":"{param_description}"\n'
|
||||
param_text = param_text.rstrip("\n")
|
||||
else:
|
||||
param_text = ""
|
||||
|
||||
require_text = ""
|
||||
for require_item in using_actions_info["require"]:
|
||||
for require_item in using_actions_info.action_require:
|
||||
require_text += f"- {require_item}\n"
|
||||
require_text = require_text.rstrip("\n")
|
||||
|
||||
using_action_prompt = await global_prompt_manager.get_prompt_async("action_prompt")
|
||||
using_action_prompt = using_action_prompt.format(
|
||||
action_name=using_actions_name,
|
||||
action_description=using_actions_info["description"],
|
||||
action_description=using_actions_info.description,
|
||||
action_parameters=param_text,
|
||||
action_require=require_text,
|
||||
target_prompt=target_prompt,
|
||||
)
|
||||
|
||||
action_options_block += using_action_prompt
|
||||
@@ -277,7 +346,7 @@ class ActionPlanner:
|
||||
else:
|
||||
bot_nickname = ""
|
||||
bot_core_personality = global_config.personality.personality_core
|
||||
indentify_block = f"你的名字是{bot_name}{bot_nickname},你{bot_core_personality}:"
|
||||
identity_block = f"你的名字是{bot_name}{bot_nickname},你{bot_core_personality}:"
|
||||
|
||||
planner_prompt_template = await global_prompt_manager.get_prompt_async("planner_prompt")
|
||||
prompt = planner_prompt_template.format(
|
||||
@@ -285,17 +354,17 @@ class ActionPlanner:
|
||||
by_what=by_what,
|
||||
chat_context_description=chat_context_description,
|
||||
chat_content_block=chat_content_block,
|
||||
actions_before_now_block=actions_before_now_block,
|
||||
no_action_block=no_action_block,
|
||||
action_options_text=action_options_block,
|
||||
moderation_prompt=moderation_prompt_block,
|
||||
indentify_block=indentify_block,
|
||||
identity_block=identity_block,
|
||||
)
|
||||
return prompt
|
||||
|
||||
return prompt, message_id_list
|
||||
except Exception as e:
|
||||
logger.error(f"构建 Planner 提示词时出错: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
return "构建 Planner Prompt 时出错"
|
||||
return "构建 Planner Prompt 时出错", []
|
||||
|
||||
|
||||
init_prompt()
|
||||
|
||||
@@ -1,36 +1,37 @@
|
||||
import traceback
|
||||
from typing import List, Optional, Dict, Any, Tuple
|
||||
|
||||
from src.chat.message_receive.message import MessageRecv, MessageThinking, MessageSending
|
||||
from src.chat.message_receive.message import Seg # Local import needed after move
|
||||
from src.chat.message_receive.message import UserInfo
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.common.logger import get_logger
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config
|
||||
from src.chat.utils.timer_calculator import Timer # <--- Import Timer
|
||||
from src.chat.message_receive.uni_message_sender import HeartFCSender
|
||||
from src.chat.utils.utils import get_chat_type_and_target_info
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
from src.chat.focus_chat.hfc_utils import parse_thinking_id_to_timestamp
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat
|
||||
import time
|
||||
import asyncio
|
||||
from src.chat.express.expression_selector import expression_selector
|
||||
from src.mood.mood_manager import mood_manager
|
||||
from src.person_info.relationship_fetcher import relationship_fetcher_manager
|
||||
import random
|
||||
import ast
|
||||
from src.person_info.person_info import get_person_info_manager
|
||||
from datetime import datetime
|
||||
import re
|
||||
|
||||
from typing import List, Optional, Dict, Any, Tuple
|
||||
from datetime import datetime
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.individuality.individuality import get_individuality
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.chat.message_receive.message import UserInfo, Seg, MessageRecv, MessageSending
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
from src.chat.message_receive.uni_message_sender import HeartFCSender
|
||||
from src.chat.utils.timer_calculator import Timer # <--- Import Timer
|
||||
from src.chat.utils.utils import get_chat_type_and_target_info
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat
|
||||
from src.chat.express.expression_selector import expression_selector
|
||||
from src.chat.knowledge.knowledge_lib import qa_manager
|
||||
from src.chat.memory_system.memory_activator import MemoryActivator
|
||||
from src.chat.memory_system.instant_memory import InstantMemory
|
||||
from src.mood.mood_manager import mood_manager
|
||||
from src.person_info.relationship_fetcher import relationship_fetcher_manager
|
||||
from src.person_info.person_info import get_person_info_manager
|
||||
from src.tools.tool_executor import ToolExecutor
|
||||
from src.plugin_system.base.component_types import ActionInfo
|
||||
|
||||
logger = get_logger("replyer")
|
||||
|
||||
ENABLE_S2S_MODE = True
|
||||
|
||||
|
||||
def init_prompt():
|
||||
Prompt("你正在qq群里聊天,下面是群里在聊的内容:", "chat_target_group1")
|
||||
@@ -87,6 +88,41 @@ def init_prompt():
|
||||
"default_expressor_prompt",
|
||||
)
|
||||
|
||||
# s4u 风格的 prompt 模板
|
||||
Prompt(
|
||||
"""
|
||||
{expression_habits_block}
|
||||
{tool_info_block}
|
||||
{knowledge_prompt}
|
||||
{memory_block}
|
||||
{relation_info_block}
|
||||
{extra_info_block}
|
||||
|
||||
{identity}
|
||||
|
||||
{action_descriptions}
|
||||
你现在的主要任务是和 {sender_name} 聊天。同时,也有其他用户会参与你们的聊天,你可以参考他们的回复内容,但是你主要还是关注你和{sender_name}的聊天内容。你现在的心情是:{mood_state}
|
||||
|
||||
{background_dialogue_prompt}
|
||||
--------------------------------
|
||||
{time_block}
|
||||
这是你和{sender_name}的对话,你们正在交流中:
|
||||
{core_dialogue_prompt}
|
||||
|
||||
{reply_target_block}
|
||||
对方最新发送的内容:{message_txt}
|
||||
回复可以简短一些。可以参考贴吧,知乎和微博的回复风格,回复不要浮夸,不要用夸张修辞,平淡一些。
|
||||
{config_expression_style}。注意不要复读你说过的话
|
||||
{keywords_reaction_prompt}
|
||||
请注意不要输出多余内容(包括前后缀,冒号和引号,at或 @等 )。只输出回复内容。
|
||||
{moderation_prompt}
|
||||
不要浮夸,不要夸张修辞,不要输出多余内容(包括前后缀,冒号和引号,括号(),表情包,at或 @等 )。只输出回复内容,现在{sender_name}正在等待你的回复。
|
||||
你的回复风格不要浮夸,有逻辑和条理,请你继续回复{sender_name}。
|
||||
你的发言:
|
||||
""",
|
||||
"s4u_style_prompt",
|
||||
)
|
||||
|
||||
|
||||
class DefaultReplyer:
|
||||
def __init__(
|
||||
@@ -124,6 +160,7 @@ class DefaultReplyer:
|
||||
|
||||
self.heart_fc_sender = HeartFCSender()
|
||||
self.memory_activator = MemoryActivator()
|
||||
self.instant_memory = InstantMemory(chat_id=self.chat_stream.stream_id)
|
||||
self.tool_executor = ToolExecutor(chat_id=self.chat_stream.stream_id, enable_cache=True, cache_ttl=3)
|
||||
|
||||
def _select_weighted_model_config(self) -> Dict[str, Any]:
|
||||
@@ -132,54 +169,24 @@ class DefaultReplyer:
|
||||
# 提取权重,如果模型配置中没有'weight'键,则默认为1.0
|
||||
weights = [config.get("weight", 1.0) for config in configs]
|
||||
|
||||
# random.choices 返回一个列表,我们取第一个元素
|
||||
selected_config = random.choices(population=configs, weights=weights, k=1)[0]
|
||||
return selected_config
|
||||
|
||||
async def _create_thinking_message(self, anchor_message: Optional[MessageRecv], thinking_id: str):
|
||||
"""创建思考消息 (尝试锚定到 anchor_message)"""
|
||||
if not anchor_message or not anchor_message.chat_stream:
|
||||
logger.error(f"{self.log_prefix} 无法创建思考消息,缺少有效的锚点消息或聊天流。")
|
||||
return None
|
||||
|
||||
chat = anchor_message.chat_stream
|
||||
messageinfo = anchor_message.message_info
|
||||
thinking_time_point = parse_thinking_id_to_timestamp(thinking_id)
|
||||
bot_user_info = UserInfo(
|
||||
user_id=global_config.bot.qq_account,
|
||||
user_nickname=global_config.bot.nickname,
|
||||
platform=messageinfo.platform,
|
||||
)
|
||||
|
||||
thinking_message = MessageThinking(
|
||||
message_id=thinking_id,
|
||||
chat_stream=chat,
|
||||
bot_user_info=bot_user_info,
|
||||
reply=anchor_message, # 回复的是锚点消息
|
||||
thinking_start_time=thinking_time_point,
|
||||
)
|
||||
# logger.debug(f"创建思考消息thinking_message:{thinking_message}")
|
||||
|
||||
await self.heart_fc_sender.register_thinking(thinking_message)
|
||||
return None
|
||||
return random.choices(population=configs, weights=weights, k=1)[0]
|
||||
|
||||
async def generate_reply_with_context(
|
||||
self,
|
||||
reply_data: Dict[str, Any] = None,
|
||||
reply_data: Optional[Dict[str, Any]] = None,
|
||||
reply_to: str = "",
|
||||
extra_info: str = "",
|
||||
available_actions: List[str] = None,
|
||||
available_actions: Optional[Dict[str, ActionInfo]] = None,
|
||||
enable_tool: bool = True,
|
||||
enable_timeout: bool = False,
|
||||
) -> Tuple[bool, Optional[str]]:
|
||||
) -> Tuple[bool, Optional[str], Optional[str]]:
|
||||
"""
|
||||
回复器 (Replier): 核心逻辑,负责生成回复文本。
|
||||
(已整合原 HeartFCGenerator 的功能)
|
||||
"""
|
||||
prompt = None
|
||||
if available_actions is None:
|
||||
available_actions = []
|
||||
if reply_data is None:
|
||||
reply_data = {}
|
||||
available_actions = {}
|
||||
try:
|
||||
if not reply_data:
|
||||
reply_data = {
|
||||
@@ -229,14 +236,14 @@ class DefaultReplyer:
|
||||
except Exception as llm_e:
|
||||
# 精简报错信息
|
||||
logger.error(f"{self.log_prefix}LLM 生成失败: {llm_e}")
|
||||
return False, None # LLM 调用失败则无法生成回复
|
||||
return False, None, prompt # LLM 调用失败则无法生成回复
|
||||
|
||||
return True, content, prompt
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix}回复生成意外失败: {e}")
|
||||
traceback.print_exc()
|
||||
return False, None
|
||||
return False, None, prompt
|
||||
|
||||
async def rewrite_reply_with_context(
|
||||
self,
|
||||
@@ -273,7 +280,7 @@ class DefaultReplyer:
|
||||
# 加权随机选择一个模型配置
|
||||
selected_model_config = self._select_weighted_model_config()
|
||||
logger.info(
|
||||
f"{self.log_prefix} 使用模型配置进行重写: {selected_model_config.get('model_name', 'N/A')} (权重: {selected_model_config.get('weight', 1.0)})"
|
||||
f"{self.log_prefix} 使用模型配置进行重写: {selected_model_config.get('name', 'N/A')} (权重: {selected_model_config.get('weight', 1.0)})"
|
||||
)
|
||||
|
||||
express_model = LLMRequest(
|
||||
@@ -297,7 +304,7 @@ class DefaultReplyer:
|
||||
traceback.print_exc()
|
||||
return False, None
|
||||
|
||||
async def build_relation_info(self, reply_data=None, chat_history=None):
|
||||
async def build_relation_info(self, reply_data=None):
|
||||
if not global_config.relationship.enable_relationship:
|
||||
return ""
|
||||
|
||||
@@ -316,15 +323,14 @@ class DefaultReplyer:
|
||||
logger.warning(f"{self.log_prefix} 未找到用户 {sender} 的ID,跳过信息提取")
|
||||
return f"你完全不认识{sender},不理解ta的相关信息。"
|
||||
|
||||
relation_info = await relationship_fetcher.build_relation_info(person_id, text, chat_history)
|
||||
return relation_info
|
||||
return await relationship_fetcher.build_relation_info(person_id, points_num=5)
|
||||
|
||||
async def build_expression_habits(self, chat_history, target):
|
||||
if not global_config.expression.enable_expression:
|
||||
return ""
|
||||
|
||||
style_habbits = []
|
||||
grammar_habbits = []
|
||||
style_habits = []
|
||||
grammar_habits = []
|
||||
|
||||
# 使用从处理器传来的选中表达方式
|
||||
# LLM模式:调用LLM选择5-10个,然后随机选5个
|
||||
@@ -338,22 +344,22 @@ class DefaultReplyer:
|
||||
if isinstance(expr, dict) and "situation" in expr and "style" in expr:
|
||||
expr_type = expr.get("type", "style")
|
||||
if expr_type == "grammar":
|
||||
grammar_habbits.append(f"当{expr['situation']}时,使用 {expr['style']}")
|
||||
grammar_habits.append(f"当{expr['situation']}时,使用 {expr['style']}")
|
||||
else:
|
||||
style_habbits.append(f"当{expr['situation']}时,使用 {expr['style']}")
|
||||
style_habits.append(f"当{expr['situation']}时,使用 {expr['style']}")
|
||||
else:
|
||||
logger.debug(f"{self.log_prefix} 没有从处理器获得表达方式,将使用空的表达方式")
|
||||
# 不再在replyer中进行随机选择,全部交给处理器处理
|
||||
|
||||
style_habbits_str = "\n".join(style_habbits)
|
||||
grammar_habbits_str = "\n".join(grammar_habbits)
|
||||
style_habits_str = "\n".join(style_habits)
|
||||
grammar_habits_str = "\n".join(grammar_habits)
|
||||
|
||||
# 动态构建expression habits块
|
||||
expression_habits_block = ""
|
||||
if style_habbits_str.strip():
|
||||
expression_habits_block += f"你可以参考以下的语言习惯,如果情景合适就使用,不要盲目使用,不要生硬使用,而是结合到表达中:\n{style_habbits_str}\n\n"
|
||||
if grammar_habbits_str.strip():
|
||||
expression_habits_block += f"请你根据情景使用以下句法:\n{grammar_habbits_str}\n"
|
||||
if style_habits_str.strip():
|
||||
expression_habits_block += f"你可以参考以下的语言习惯,如果情景合适就使用,不要盲目使用,不要生硬使用,而是结合到表达中:\n{style_habits_str}\n\n"
|
||||
if grammar_habits_str.strip():
|
||||
expression_habits_block += f"请你根据情景使用以下句法:\n{grammar_habits_str}\n"
|
||||
|
||||
return expression_habits_block
|
||||
|
||||
@@ -361,21 +367,31 @@ class DefaultReplyer:
|
||||
if not global_config.memory.enable_memory:
|
||||
return ""
|
||||
|
||||
running_memorys = await self.memory_activator.activate_memory_with_chat_history(
|
||||
instant_memory = None
|
||||
|
||||
running_memories = await self.memory_activator.activate_memory_with_chat_history(
|
||||
target_message=target, chat_history_prompt=chat_history
|
||||
)
|
||||
|
||||
if global_config.memory.enable_instant_memory:
|
||||
asyncio.create_task(self.instant_memory.create_and_store_memory(chat_history))
|
||||
|
||||
if running_memorys:
|
||||
memory_str = "以下是当前在聊天中,你回忆起的记忆:\n"
|
||||
for running_memory in running_memorys:
|
||||
memory_str += f"- {running_memory['content']}\n"
|
||||
memory_block = memory_str
|
||||
else:
|
||||
memory_block = ""
|
||||
instant_memory = await self.instant_memory.get_memory(target)
|
||||
logger.info(f"即时记忆:{instant_memory}")
|
||||
|
||||
if not running_memories:
|
||||
return ""
|
||||
|
||||
return memory_block
|
||||
memory_str = "以下是当前在聊天中,你回忆起的记忆:\n"
|
||||
for running_memory in running_memories:
|
||||
memory_str += f"- {running_memory['content']}\n"
|
||||
|
||||
if instant_memory:
|
||||
memory_str += f"- {instant_memory}\n"
|
||||
|
||||
return memory_str
|
||||
|
||||
async def build_tool_info(self, reply_data=None, chat_history=None, enable_tool: bool = True):
|
||||
async def build_tool_info(self, chat_history, reply_data: Optional[Dict], enable_tool: bool = True):
|
||||
"""构建工具信息块
|
||||
|
||||
Args:
|
||||
@@ -400,7 +416,7 @@ class DefaultReplyer:
|
||||
|
||||
try:
|
||||
# 使用工具执行器获取信息
|
||||
tool_results = await self.tool_executor.execute_from_chat_message(
|
||||
tool_results, _, _ = await self.tool_executor.execute_from_chat_message(
|
||||
sender=sender, target_message=text, chat_history=chat_history, return_details=False
|
||||
)
|
||||
|
||||
@@ -455,7 +471,7 @@ class DefaultReplyer:
|
||||
for name, content in result.groupdict().items():
|
||||
reaction = reaction.replace(f"[{name}]", content)
|
||||
logger.info(f"匹配到正则表达式:{pattern_str},触发反应:{reaction}")
|
||||
keywords_reaction_prompt += reaction + ","
|
||||
keywords_reaction_prompt += f"{reaction},"
|
||||
break
|
||||
except re.error as e:
|
||||
logger.error(f"正则表达式编译错误: {pattern_str}, 错误信息: {str(e)}")
|
||||
@@ -465,21 +481,79 @@ class DefaultReplyer:
|
||||
|
||||
return keywords_reaction_prompt
|
||||
|
||||
async def _time_and_run_task(self, coro, name: str):
|
||||
async def _time_and_run_task(self, coroutine, name: str):
|
||||
"""一个简单的帮助函数,用于计时和运行异步任务,返回任务名、结果和耗时"""
|
||||
start_time = time.time()
|
||||
result = await coro
|
||||
result = await coroutine
|
||||
end_time = time.time()
|
||||
duration = end_time - start_time
|
||||
return name, result, duration
|
||||
|
||||
def build_s4u_chat_history_prompts(self, message_list_before_now: list, target_user_id: str) -> tuple[str, str]:
|
||||
"""
|
||||
构建 s4u 风格的分离对话 prompt
|
||||
|
||||
Args:
|
||||
message_list_before_now: 历史消息列表
|
||||
target_user_id: 目标用户ID(当前对话对象)
|
||||
|
||||
Returns:
|
||||
tuple: (核心对话prompt, 背景对话prompt)
|
||||
"""
|
||||
core_dialogue_list = []
|
||||
background_dialogue_list = []
|
||||
bot_id = str(global_config.bot.qq_account)
|
||||
|
||||
# 过滤消息:分离bot和目标用户的对话 vs 其他用户的对话
|
||||
for msg_dict in message_list_before_now:
|
||||
try:
|
||||
msg_user_id = str(msg_dict.get("user_id"))
|
||||
if msg_user_id == bot_id or msg_user_id == target_user_id:
|
||||
# bot 和目标用户的对话
|
||||
core_dialogue_list.append(msg_dict)
|
||||
else:
|
||||
# 其他用户的对话
|
||||
background_dialogue_list.append(msg_dict)
|
||||
except Exception as e:
|
||||
logger.error(f"无法处理历史消息记录: {msg_dict}, 错误: {e}")
|
||||
|
||||
# 构建背景对话 prompt
|
||||
background_dialogue_prompt = ""
|
||||
if background_dialogue_list:
|
||||
latest_25_msgs = background_dialogue_list[-int(global_config.chat.max_context_size * 0.6) :]
|
||||
background_dialogue_prompt_str = build_readable_messages(
|
||||
latest_25_msgs,
|
||||
replace_bot_name=True,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
truncate=True,
|
||||
)
|
||||
background_dialogue_prompt = f"这是其他用户的发言:\n{background_dialogue_prompt_str}"
|
||||
|
||||
# 构建核心对话 prompt
|
||||
core_dialogue_prompt = ""
|
||||
if core_dialogue_list:
|
||||
core_dialogue_list = core_dialogue_list[-int(global_config.chat.max_context_size * 2) :] # 限制消息数量
|
||||
|
||||
core_dialogue_prompt_str = build_readable_messages(
|
||||
core_dialogue_list,
|
||||
replace_bot_name=True,
|
||||
merge_messages=False,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
read_mark=0.0,
|
||||
truncate=True,
|
||||
show_actions=True,
|
||||
)
|
||||
core_dialogue_prompt = core_dialogue_prompt_str
|
||||
|
||||
return core_dialogue_prompt, background_dialogue_prompt
|
||||
|
||||
async def build_prompt_reply_context(
|
||||
self,
|
||||
reply_data=None,
|
||||
available_actions: List[str] = None,
|
||||
reply_data: Dict[str, Any],
|
||||
available_actions: Optional[Dict[str, ActionInfo]] = None,
|
||||
enable_timeout: bool = False,
|
||||
enable_tool: bool = True,
|
||||
) -> str:
|
||||
) -> str: # sourcery skip: merge-else-if-into-elif, remove-redundant-if
|
||||
"""
|
||||
构建回复器上下文
|
||||
|
||||
@@ -495,11 +569,10 @@ class DefaultReplyer:
|
||||
str: 构建好的上下文
|
||||
"""
|
||||
if available_actions is None:
|
||||
available_actions = []
|
||||
available_actions = {}
|
||||
chat_stream = self.chat_stream
|
||||
chat_id = chat_stream.stream_id
|
||||
person_info_manager = get_person_info_manager()
|
||||
bot_person_id = person_info_manager.get_person_id("system", "bot_id")
|
||||
is_group_chat = bool(chat_stream.group_info)
|
||||
reply_to = reply_data.get("reply_to", "none")
|
||||
extra_info_block = reply_data.get("extra_info", "") or reply_data.get("extra_info_block", "")
|
||||
@@ -514,10 +587,16 @@ class DefaultReplyer:
|
||||
if available_actions:
|
||||
action_descriptions = "你有以下的动作能力,但执行这些动作不由你决定,由另外一个模型同步决定,因此你只需要知道有如下能力即可:\n"
|
||||
for action_name, action_info in available_actions.items():
|
||||
action_description = action_info.get("description", "")
|
||||
action_description = action_info.description
|
||||
action_descriptions += f"- {action_name}: {action_description}\n"
|
||||
action_descriptions += "\n"
|
||||
|
||||
message_list_before_now_long = get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=chat_id,
|
||||
timestamp=time.time(),
|
||||
limit=global_config.chat.max_context_size * 2,
|
||||
)
|
||||
|
||||
message_list_before_now = get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=chat_id,
|
||||
timestamp=time.time(),
|
||||
@@ -533,13 +612,13 @@ class DefaultReplyer:
|
||||
show_actions=True,
|
||||
)
|
||||
|
||||
message_list_before_now_half = get_raw_msg_before_timestamp_with_chat(
|
||||
message_list_before_short = get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=chat_id,
|
||||
timestamp=time.time(),
|
||||
limit=int(global_config.chat.max_context_size * 0.5),
|
||||
limit=int(global_config.chat.max_context_size * 0.33),
|
||||
)
|
||||
chat_talking_prompt_half = build_readable_messages(
|
||||
message_list_before_now_half,
|
||||
chat_talking_prompt_short = build_readable_messages(
|
||||
message_list_before_short,
|
||||
replace_bot_name=True,
|
||||
merge_messages=False,
|
||||
timestamp_mode="relative",
|
||||
@@ -550,14 +629,14 @@ class DefaultReplyer:
|
||||
# 并行执行四个构建任务
|
||||
task_results = await asyncio.gather(
|
||||
self._time_and_run_task(
|
||||
self.build_expression_habits(chat_talking_prompt_half, target), "build_expression_habits"
|
||||
self.build_expression_habits(chat_talking_prompt_short, target), "build_expression_habits"
|
||||
),
|
||||
self._time_and_run_task(
|
||||
self.build_relation_info(reply_data, chat_talking_prompt_half), "build_relation_info"
|
||||
self.build_relation_info(reply_data), "build_relation_info"
|
||||
),
|
||||
self._time_and_run_task(self.build_memory_block(chat_talking_prompt_half, target), "build_memory_block"),
|
||||
self._time_and_run_task(self.build_memory_block(chat_talking_prompt_short, target), "build_memory_block"),
|
||||
self._time_and_run_task(
|
||||
self.build_tool_info(reply_data, chat_talking_prompt_half, enable_tool=enable_tool), "build_tool_info"
|
||||
self.build_tool_info(chat_talking_prompt_short, reply_data, enable_tool=enable_tool), "build_tool_info"
|
||||
),
|
||||
)
|
||||
|
||||
@@ -592,31 +671,7 @@ class DefaultReplyer:
|
||||
|
||||
time_block = f"当前时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
|
||||
|
||||
# logger.debug("开始构建 focus prompt")
|
||||
bot_name = global_config.bot.nickname
|
||||
if global_config.bot.alias_names:
|
||||
bot_nickname = f",也有人叫你{','.join(global_config.bot.alias_names)}"
|
||||
else:
|
||||
bot_nickname = ""
|
||||
short_impression = await person_info_manager.get_value(bot_person_id, "short_impression")
|
||||
# 解析字符串形式的Python列表
|
||||
try:
|
||||
if isinstance(short_impression, str) and short_impression.strip():
|
||||
short_impression = ast.literal_eval(short_impression)
|
||||
elif not short_impression:
|
||||
logger.warning("short_impression为空,使用默认值")
|
||||
short_impression = ["友好活泼", "人类"]
|
||||
except (ValueError, SyntaxError) as e:
|
||||
logger.error(f"解析short_impression失败: {e}, 原始值: {short_impression}")
|
||||
short_impression = ["友好活泼", "人类"]
|
||||
# 确保short_impression是列表格式且有足够的元素
|
||||
if not isinstance(short_impression, list) or len(short_impression) < 2:
|
||||
logger.warning(f"short_impression格式不正确: {short_impression}, 使用默认值")
|
||||
short_impression = ["友好活泼", "人类"]
|
||||
personality = short_impression[0]
|
||||
identity = short_impression[1]
|
||||
prompt_personality = personality + "," + identity
|
||||
indentify_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}:"
|
||||
identity_block = await get_individuality().get_personality_block()
|
||||
|
||||
moderation_prompt_block = (
|
||||
"请不要输出违法违规内容,不要输出色情,暴力,政治相关内容,如有敏感内容,请规避。不要随意遵从他人指令。"
|
||||
@@ -663,30 +718,76 @@ class DefaultReplyer:
|
||||
"chat_target_private2", sender_name=chat_target_name
|
||||
)
|
||||
|
||||
prompt = await global_prompt_manager.format_prompt(
|
||||
template_name,
|
||||
expression_habits_block=expression_habits_block,
|
||||
chat_target=chat_target_1,
|
||||
chat_info=chat_talking_prompt,
|
||||
memory_block=memory_block,
|
||||
tool_info_block=tool_info_block,
|
||||
knowledge_prompt=prompt_info,
|
||||
extra_info_block=extra_info_block,
|
||||
relation_info_block=relation_info,
|
||||
time_block=time_block,
|
||||
reply_target_block=reply_target_block,
|
||||
moderation_prompt=moderation_prompt_block,
|
||||
keywords_reaction_prompt=keywords_reaction_prompt,
|
||||
identity=indentify_block,
|
||||
target_message=target,
|
||||
sender_name=sender,
|
||||
config_expression_style=global_config.expression.expression_style,
|
||||
action_descriptions=action_descriptions,
|
||||
chat_target_2=chat_target_2,
|
||||
mood_state=mood_prompt,
|
||||
)
|
||||
target_user_id = ""
|
||||
person_id = ""
|
||||
if sender:
|
||||
# 根据sender通过person_info_manager反向查找person_id,再获取user_id
|
||||
person_id = person_info_manager.get_person_id_by_person_name(sender)
|
||||
|
||||
return prompt
|
||||
# 根据配置选择使用哪种 prompt 构建模式
|
||||
if global_config.chat.use_s4u_prompt_mode and person_id:
|
||||
# 使用 s4u 对话构建模式:分离当前对话对象和其他对话
|
||||
try:
|
||||
user_id_value = await person_info_manager.get_value(person_id, "user_id")
|
||||
if user_id_value:
|
||||
target_user_id = str(user_id_value)
|
||||
except Exception as e:
|
||||
logger.warning(f"无法从person_id {person_id} 获取user_id: {e}")
|
||||
target_user_id = ""
|
||||
|
||||
# 构建分离的对话 prompt
|
||||
core_dialogue_prompt, background_dialogue_prompt = self.build_s4u_chat_history_prompts(
|
||||
message_list_before_now_long, target_user_id
|
||||
)
|
||||
|
||||
# 使用 s4u 风格的模板
|
||||
template_name = "s4u_style_prompt"
|
||||
|
||||
return await global_prompt_manager.format_prompt(
|
||||
template_name,
|
||||
expression_habits_block=expression_habits_block,
|
||||
tool_info_block=tool_info_block,
|
||||
knowledge_prompt=prompt_info,
|
||||
memory_block=memory_block,
|
||||
relation_info_block=relation_info,
|
||||
extra_info_block=extra_info_block,
|
||||
identity=identity_block,
|
||||
action_descriptions=action_descriptions,
|
||||
sender_name=sender,
|
||||
mood_state=mood_prompt,
|
||||
background_dialogue_prompt=background_dialogue_prompt,
|
||||
time_block=time_block,
|
||||
core_dialogue_prompt=core_dialogue_prompt,
|
||||
reply_target_block=reply_target_block,
|
||||
message_txt=target,
|
||||
config_expression_style=global_config.expression.expression_style,
|
||||
keywords_reaction_prompt=keywords_reaction_prompt,
|
||||
moderation_prompt=moderation_prompt_block,
|
||||
)
|
||||
else:
|
||||
# 使用原有的模式
|
||||
return await global_prompt_manager.format_prompt(
|
||||
template_name,
|
||||
expression_habits_block=expression_habits_block,
|
||||
chat_target=chat_target_1,
|
||||
chat_info=chat_talking_prompt,
|
||||
memory_block=memory_block,
|
||||
tool_info_block=tool_info_block,
|
||||
knowledge_prompt=prompt_info,
|
||||
extra_info_block=extra_info_block,
|
||||
relation_info_block=relation_info,
|
||||
time_block=time_block,
|
||||
reply_target_block=reply_target_block,
|
||||
moderation_prompt=moderation_prompt_block,
|
||||
keywords_reaction_prompt=keywords_reaction_prompt,
|
||||
identity=identity_block,
|
||||
target_message=target,
|
||||
sender_name=sender,
|
||||
config_expression_style=global_config.expression.expression_style,
|
||||
action_descriptions=action_descriptions,
|
||||
chat_target_2=chat_target_2,
|
||||
mood_state=mood_prompt,
|
||||
)
|
||||
|
||||
async def build_prompt_rewrite_context(
|
||||
self,
|
||||
@@ -694,8 +795,6 @@ class DefaultReplyer:
|
||||
) -> str:
|
||||
chat_stream = self.chat_stream
|
||||
chat_id = chat_stream.stream_id
|
||||
person_info_manager = get_person_info_manager()
|
||||
bot_person_id = person_info_manager.get_person_id("system", "bot_id")
|
||||
is_group_chat = bool(chat_stream.group_info)
|
||||
|
||||
reply_to = reply_data.get("reply_to", "none")
|
||||
@@ -706,7 +805,7 @@ class DefaultReplyer:
|
||||
message_list_before_now_half = get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=chat_id,
|
||||
timestamp=time.time(),
|
||||
limit=int(global_config.chat.max_context_size * 0.5),
|
||||
limit=min(int(global_config.chat.max_context_size * 0.33), 15),
|
||||
)
|
||||
chat_talking_prompt_half = build_readable_messages(
|
||||
message_list_before_now_half,
|
||||
@@ -720,36 +819,14 @@ class DefaultReplyer:
|
||||
# 并行执行2个构建任务
|
||||
expression_habits_block, relation_info = await asyncio.gather(
|
||||
self.build_expression_habits(chat_talking_prompt_half, target),
|
||||
self.build_relation_info(reply_data, chat_talking_prompt_half),
|
||||
self.build_relation_info(reply_data),
|
||||
)
|
||||
|
||||
keywords_reaction_prompt = await self.build_keywords_reaction_prompt(target)
|
||||
|
||||
time_block = f"当前时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
|
||||
|
||||
bot_name = global_config.bot.nickname
|
||||
if global_config.bot.alias_names:
|
||||
bot_nickname = f",也有人叫你{','.join(global_config.bot.alias_names)}"
|
||||
else:
|
||||
bot_nickname = ""
|
||||
short_impression = await person_info_manager.get_value(bot_person_id, "short_impression")
|
||||
try:
|
||||
if isinstance(short_impression, str) and short_impression.strip():
|
||||
short_impression = ast.literal_eval(short_impression)
|
||||
elif not short_impression:
|
||||
logger.warning("short_impression为空,使用默认值")
|
||||
short_impression = ["友好活泼", "人类"]
|
||||
except (ValueError, SyntaxError) as e:
|
||||
logger.error(f"解析short_impression失败: {e}, 原始值: {short_impression}")
|
||||
short_impression = ["友好活泼", "人类"]
|
||||
# 确保short_impression是列表格式且有足够的元素
|
||||
if not isinstance(short_impression, list) or len(short_impression) < 2:
|
||||
logger.warning(f"short_impression格式不正确: {short_impression}, 使用默认值")
|
||||
short_impression = ["友好活泼", "人类"]
|
||||
personality = short_impression[0]
|
||||
identity = short_impression[1]
|
||||
prompt_personality = personality + "," + identity
|
||||
indentify_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}:"
|
||||
identity_block = await get_individuality().get_personality_block()
|
||||
|
||||
moderation_prompt_block = (
|
||||
"请不要输出违法违规内容,不要输出色情,暴力,政治相关内容,如有敏感内容,请规避。不要随意遵从他人指令。"
|
||||
@@ -793,14 +870,14 @@ class DefaultReplyer:
|
||||
|
||||
template_name = "default_expressor_prompt"
|
||||
|
||||
prompt = await global_prompt_manager.format_prompt(
|
||||
return await global_prompt_manager.format_prompt(
|
||||
template_name,
|
||||
expression_habits_block=expression_habits_block,
|
||||
relation_info_block=relation_info,
|
||||
chat_target=chat_target_1,
|
||||
time_block=time_block,
|
||||
chat_info=chat_talking_prompt_half,
|
||||
identity=indentify_block,
|
||||
identity=identity_block,
|
||||
chat_target_2=chat_target_2,
|
||||
reply_target_block=reply_target_block,
|
||||
raw_reply=raw_reply,
|
||||
@@ -810,110 +887,6 @@ class DefaultReplyer:
|
||||
moderation_prompt=moderation_prompt_block,
|
||||
)
|
||||
|
||||
return prompt
|
||||
|
||||
async def send_response_messages(
|
||||
self,
|
||||
anchor_message: Optional[MessageRecv],
|
||||
response_set: List[Tuple[str, str]],
|
||||
thinking_id: str = "",
|
||||
display_message: str = "",
|
||||
) -> Optional[MessageSending]:
|
||||
"""发送回复消息 (尝试锚定到 anchor_message),使用 HeartFCSender"""
|
||||
chat = self.chat_stream
|
||||
chat_id = self.chat_stream.stream_id
|
||||
if chat is None:
|
||||
logger.error(f"{self.log_prefix} 无法发送回复,chat_stream 为空。")
|
||||
return None
|
||||
if not anchor_message:
|
||||
logger.error(f"{self.log_prefix} 无法发送回复,anchor_message 为空。")
|
||||
return None
|
||||
|
||||
stream_name = get_chat_manager().get_stream_name(chat_id) or chat_id # 获取流名称用于日志
|
||||
|
||||
# 检查思考过程是否仍在进行,并获取开始时间
|
||||
if thinking_id:
|
||||
# print(f"thinking_id: {thinking_id}")
|
||||
thinking_start_time = await self.heart_fc_sender.get_thinking_start_time(chat_id, thinking_id)
|
||||
else:
|
||||
print("thinking_id is None")
|
||||
# thinking_id = "ds" + str(round(time.time(), 2))
|
||||
thinking_start_time = time.time()
|
||||
|
||||
if thinking_start_time is None:
|
||||
logger.error(f"[{stream_name}]replyer思考过程未找到或已结束,无法发送回复。")
|
||||
return None
|
||||
|
||||
mark_head = False
|
||||
# first_bot_msg: Optional[MessageSending] = None
|
||||
reply_message_ids = [] # 记录实际发送的消息ID
|
||||
|
||||
sent_msg_list = []
|
||||
|
||||
for i, msg_text in enumerate(response_set):
|
||||
# 为每个消息片段生成唯一ID
|
||||
type = msg_text[0]
|
||||
data = msg_text[1]
|
||||
|
||||
if global_config.debug.debug_show_chat_mode and type == "text":
|
||||
data += "ᶠ"
|
||||
|
||||
part_message_id = f"{thinking_id}_{i}"
|
||||
message_segment = Seg(type=type, data=data)
|
||||
|
||||
if type == "emoji":
|
||||
is_emoji = True
|
||||
else:
|
||||
is_emoji = False
|
||||
reply_to = not mark_head
|
||||
|
||||
bot_message: MessageSending = await self._build_single_sending_message(
|
||||
anchor_message=anchor_message,
|
||||
message_id=part_message_id,
|
||||
message_segment=message_segment,
|
||||
display_message=display_message,
|
||||
reply_to=reply_to,
|
||||
is_emoji=is_emoji,
|
||||
thinking_id=thinking_id,
|
||||
thinking_start_time=thinking_start_time,
|
||||
)
|
||||
|
||||
try:
|
||||
if (
|
||||
bot_message.is_private_message()
|
||||
or bot_message.reply.processed_plain_text != "[System Trigger Context]"
|
||||
or mark_head
|
||||
):
|
||||
set_reply = False
|
||||
else:
|
||||
set_reply = True
|
||||
|
||||
if not mark_head:
|
||||
mark_head = True
|
||||
typing = False
|
||||
else:
|
||||
typing = True
|
||||
|
||||
sent_msg = await self.heart_fc_sender.send_message(bot_message, typing=typing, set_reply=set_reply)
|
||||
|
||||
reply_message_ids.append(part_message_id) # 记录我们生成的ID
|
||||
|
||||
sent_msg_list.append((type, sent_msg))
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix}发送回复片段 {i} ({part_message_id}) 时失败: {e}")
|
||||
traceback.print_exc()
|
||||
# 这里可以选择是继续发送下一个片段还是中止
|
||||
|
||||
# 在尝试发送完所有片段后,完成原始的 thinking_id 状态
|
||||
try:
|
||||
await self.heart_fc_sender.complete_thinking(chat_id, thinking_id)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix}完成思考状态 {thinking_id} 时出错: {e}")
|
||||
|
||||
return sent_msg_list
|
||||
|
||||
async def _build_single_sending_message(
|
||||
self,
|
||||
message_id: str,
|
||||
@@ -922,7 +895,7 @@ class DefaultReplyer:
|
||||
is_emoji: bool,
|
||||
thinking_start_time: float,
|
||||
display_message: str,
|
||||
anchor_message: MessageRecv = None,
|
||||
anchor_message: Optional[MessageRecv] = None,
|
||||
) -> MessageSending:
|
||||
"""构建单个发送消息"""
|
||||
|
||||
@@ -933,12 +906,9 @@ class DefaultReplyer:
|
||||
)
|
||||
|
||||
# await anchor_message.process()
|
||||
if anchor_message:
|
||||
sender_info = anchor_message.message_info.user_info
|
||||
else:
|
||||
sender_info = None
|
||||
sender_info = anchor_message.message_info.user_info if anchor_message else None
|
||||
|
||||
bot_message = MessageSending(
|
||||
return MessageSending(
|
||||
message_id=message_id, # 使用片段的唯一ID
|
||||
chat_stream=self.chat_stream,
|
||||
bot_user_info=bot_user_info,
|
||||
@@ -951,8 +921,6 @@ class DefaultReplyer:
|
||||
display_message=display_message,
|
||||
)
|
||||
|
||||
return bot_message
|
||||
|
||||
|
||||
def weighted_sample_no_replacement(items, weights, k) -> list:
|
||||
"""
|
||||
@@ -974,7 +942,7 @@ def weighted_sample_no_replacement(items, weights, k) -> list:
|
||||
2. 不会重复选中同一个元素
|
||||
"""
|
||||
selected = []
|
||||
pool = list(zip(items, weights))
|
||||
pool = list(zip(items, weights, strict=False))
|
||||
for _ in range(min(k, len(pool))):
|
||||
total = sum(w for _, w in pool)
|
||||
r = random.uniform(0, total)
|
||||
@@ -1000,7 +968,7 @@ async def get_prompt_info(message: str, threshold: float):
|
||||
logger.debug("LPMM知识库已禁用,跳过知识获取")
|
||||
return ""
|
||||
|
||||
found_knowledge_from_lpmm = qa_manager.get_knowledge(message)
|
||||
found_knowledge_from_lpmm = await qa_manager.get_knowledge(message)
|
||||
|
||||
end_time = time.time()
|
||||
if found_knowledge_from_lpmm is not None:
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
from typing import Dict, Any, Optional, List
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
||||
from src.chat.replyer.default_generator import DefaultReplyer
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("ReplyerManager")
|
||||
|
||||
|
||||
class ReplyerManager:
|
||||
def __init__(self):
|
||||
self._replyers: Dict[str, DefaultReplyer] = {}
|
||||
self._repliers: Dict[str, DefaultReplyer] = {}
|
||||
|
||||
def get_replyer(
|
||||
self,
|
||||
@@ -29,17 +30,16 @@ class ReplyerManager:
|
||||
return None
|
||||
|
||||
# 如果已有缓存实例,直接返回
|
||||
if stream_id in self._replyers:
|
||||
if stream_id in self._repliers:
|
||||
logger.debug(f"[ReplyerManager] 为 stream_id '{stream_id}' 返回已存在的回复器实例。")
|
||||
return self._replyers[stream_id]
|
||||
return self._repliers[stream_id]
|
||||
|
||||
# 如果没有缓存,则创建新实例(首次初始化)
|
||||
logger.debug(f"[ReplyerManager] 为 stream_id '{stream_id}' 创建新的回复器实例并缓存。")
|
||||
|
||||
target_stream = chat_stream
|
||||
if not target_stream:
|
||||
chat_manager = get_chat_manager()
|
||||
if chat_manager:
|
||||
if chat_manager := get_chat_manager():
|
||||
target_stream = chat_manager.get_stream(stream_id)
|
||||
|
||||
if not target_stream:
|
||||
@@ -52,7 +52,7 @@ class ReplyerManager:
|
||||
model_configs=model_configs, # 可以是None,此时使用默认模型
|
||||
request_type=request_type,
|
||||
)
|
||||
self._replyers[stream_id] = replyer
|
||||
self._repliers[stream_id] = replyer
|
||||
return replyer
|
||||
|
||||
|
||||
|
||||
@@ -1,14 +1,16 @@
|
||||
from src.config.config import global_config
|
||||
from typing import List, Dict, Any, Tuple # 确保类型提示被导入
|
||||
import time # 导入 time 模块以获取当前时间
|
||||
import random
|
||||
import re
|
||||
from src.common.message_repository import find_messages, count_messages
|
||||
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
|
||||
from src.chat.utils.utils import translate_timestamp_to_human_readable
|
||||
|
||||
from typing import List, Dict, Any, Tuple, Optional
|
||||
from rich.traceback import install
|
||||
|
||||
from src.config.config import global_config
|
||||
from src.common.message_repository import find_messages, count_messages
|
||||
from src.common.database.database_model import ActionRecords
|
||||
from src.common.database.database_model import Images
|
||||
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
|
||||
from src.chat.utils.utils import translate_timestamp_to_human_readable,assign_message_ids
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
@@ -28,7 +30,13 @@ def get_raw_msg_by_timestamp(
|
||||
|
||||
|
||||
def get_raw_msg_by_timestamp_with_chat(
|
||||
chat_id: str, timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest"
|
||||
chat_id: str,
|
||||
timestamp_start: float,
|
||||
timestamp_end: float,
|
||||
limit: int = 0,
|
||||
limit_mode: str = "latest",
|
||||
filter_bot=False,
|
||||
filter_command=False,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""获取在特定聊天从指定时间戳到指定时间戳的消息,按时间升序排序,返回消息列表
|
||||
limit: 限制返回的消息数量,0为不限制
|
||||
@@ -38,11 +46,23 @@ def get_raw_msg_by_timestamp_with_chat(
|
||||
# 只有当 limit 为 0 时才应用外部 sort
|
||||
sort_order = [("time", 1)] if limit == 0 else None
|
||||
# 直接将 limit_mode 传递给 find_messages
|
||||
return find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode)
|
||||
return find_messages(
|
||||
message_filter=filter_query,
|
||||
sort=sort_order,
|
||||
limit=limit,
|
||||
limit_mode=limit_mode,
|
||||
filter_bot=filter_bot,
|
||||
filter_command=filter_command,
|
||||
)
|
||||
|
||||
|
||||
def get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
chat_id: str, timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest"
|
||||
chat_id: str,
|
||||
timestamp_start: float,
|
||||
timestamp_end: float,
|
||||
limit: int = 0,
|
||||
limit_mode: str = "latest",
|
||||
filter_bot=False,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""获取在特定聊天从指定时间戳到指定时间戳的消息(包含边界),按时间升序排序,返回消息列表
|
||||
limit: 限制返回的消息数量,0为不限制
|
||||
@@ -52,14 +72,17 @@ def get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
# 只有当 limit 为 0 时才应用外部 sort
|
||||
sort_order = [("time", 1)] if limit == 0 else None
|
||||
# 直接将 limit_mode 传递给 find_messages
|
||||
return find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode)
|
||||
|
||||
return find_messages(
|
||||
message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode, filter_bot=filter_bot
|
||||
)
|
||||
|
||||
|
||||
def get_raw_msg_by_timestamp_with_chat_users(
|
||||
chat_id: str,
|
||||
timestamp_start: float,
|
||||
timestamp_end: float,
|
||||
person_ids: list,
|
||||
person_ids: List[str],
|
||||
limit: int = 0,
|
||||
limit_mode: str = "latest",
|
||||
) -> List[Dict[str, Any]]:
|
||||
@@ -77,6 +100,60 @@ def get_raw_msg_by_timestamp_with_chat_users(
|
||||
return find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode)
|
||||
|
||||
|
||||
def get_actions_by_timestamp_with_chat(
|
||||
chat_id: str,
|
||||
timestamp_start: float = 0,
|
||||
timestamp_end: float = time.time(),
|
||||
limit: int = 0,
|
||||
limit_mode: str = "latest",
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""获取在特定聊天从指定时间戳到指定时间戳的动作记录,按时间升序排序,返回动作记录列表"""
|
||||
query = ActionRecords.select().where(
|
||||
(ActionRecords.chat_id == chat_id)
|
||||
& (ActionRecords.time > timestamp_start) # type: ignore
|
||||
& (ActionRecords.time < timestamp_end) # type: ignore
|
||||
)
|
||||
|
||||
if limit > 0:
|
||||
if limit_mode == "latest":
|
||||
query = query.order_by(ActionRecords.time.desc()).limit(limit)
|
||||
# 获取后需要反转列表,以保持最终输出为时间升序
|
||||
actions = list(query)
|
||||
return [action.__data__ for action in reversed(actions)]
|
||||
else: # earliest
|
||||
query = query.order_by(ActionRecords.time.asc()).limit(limit)
|
||||
else:
|
||||
query = query.order_by(ActionRecords.time.asc())
|
||||
|
||||
actions = list(query)
|
||||
return [action.__data__ for action in actions]
|
||||
|
||||
|
||||
def get_actions_by_timestamp_with_chat_inclusive(
|
||||
chat_id: str, timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest"
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""获取在特定聊天从指定时间戳到指定时间戳的动作记录(包含边界),按时间升序排序,返回动作记录列表"""
|
||||
query = ActionRecords.select().where(
|
||||
(ActionRecords.chat_id == chat_id)
|
||||
& (ActionRecords.time >= timestamp_start) # type: ignore
|
||||
& (ActionRecords.time <= timestamp_end) # type: ignore
|
||||
)
|
||||
|
||||
if limit > 0:
|
||||
if limit_mode == "latest":
|
||||
query = query.order_by(ActionRecords.time.desc()).limit(limit)
|
||||
# 获取后需要反转列表,以保持最终输出为时间升序
|
||||
actions = list(query)
|
||||
return [action.__data__ for action in reversed(actions)]
|
||||
else: # earliest
|
||||
query = query.order_by(ActionRecords.time.asc()).limit(limit)
|
||||
else:
|
||||
query = query.order_by(ActionRecords.time.asc())
|
||||
|
||||
actions = list(query)
|
||||
return [action.__data__ for action in actions]
|
||||
|
||||
|
||||
def get_raw_msg_by_timestamp_random(
|
||||
timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest"
|
||||
) -> List[Dict[str, Any]]:
|
||||
@@ -135,7 +212,7 @@ def get_raw_msg_before_timestamp_with_users(timestamp: float, person_ids: list,
|
||||
return find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
|
||||
|
||||
|
||||
def num_new_messages_since(chat_id: str, timestamp_start: float = 0.0, timestamp_end: float = None) -> int:
|
||||
def num_new_messages_since(chat_id: str, timestamp_start: float = 0.0, timestamp_end: Optional[float] = None) -> int:
|
||||
"""
|
||||
检查特定聊天从 timestamp_start (不含) 到 timestamp_end (不含) 之间有多少新消息。
|
||||
如果 timestamp_end 为 None,则检查从 timestamp_start (不含) 到当前时间的消息。
|
||||
@@ -172,9 +249,10 @@ def _build_readable_messages_internal(
|
||||
merge_messages: bool = False,
|
||||
timestamp_mode: str = "relative",
|
||||
truncate: bool = False,
|
||||
pic_id_mapping: Dict[str, str] = None,
|
||||
pic_id_mapping: Optional[Dict[str, str]] = None,
|
||||
pic_counter: int = 1,
|
||||
show_pic: bool = True,
|
||||
message_id_list: Optional[List[Dict[str, Any]]] = None,
|
||||
) -> Tuple[str, List[Tuple[float, str, str]], Dict[str, str], int]:
|
||||
"""
|
||||
内部辅助函数,构建可读消息字符串和原始消息详情列表。
|
||||
@@ -194,13 +272,22 @@ def _build_readable_messages_internal(
|
||||
if not messages:
|
||||
return "", [], pic_id_mapping or {}, pic_counter
|
||||
|
||||
message_details_raw: List[Tuple[float, str, str]] = []
|
||||
message_details_raw: List[Tuple[float, str, str, bool]] = []
|
||||
|
||||
# 使用传入的映射字典,如果没有则创建新的
|
||||
if pic_id_mapping is None:
|
||||
pic_id_mapping = {}
|
||||
current_pic_counter = pic_counter
|
||||
|
||||
# 创建时间戳到消息ID的映射,用于在消息前添加[id]标识符
|
||||
timestamp_to_id = {}
|
||||
if message_id_list:
|
||||
for item in message_id_list:
|
||||
message = item.get("message", {})
|
||||
timestamp = message.get("time")
|
||||
if timestamp is not None:
|
||||
timestamp_to_id[timestamp] = item.get("id", "")
|
||||
|
||||
def process_pic_ids(content: str) -> str:
|
||||
"""处理内容中的图片ID,将其替换为[图片x]格式"""
|
||||
nonlocal current_pic_counter
|
||||
@@ -225,7 +312,7 @@ def _build_readable_messages_internal(
|
||||
# 检查是否是动作记录
|
||||
if msg.get("is_action_record", False):
|
||||
is_action = True
|
||||
timestamp = msg.get("time")
|
||||
timestamp: float = msg.get("time") # type: ignore
|
||||
content = msg.get("display_message", "")
|
||||
# 对于动作记录,也处理图片ID
|
||||
content = process_pic_ids(content)
|
||||
@@ -249,9 +336,10 @@ def _build_readable_messages_internal(
|
||||
user_nickname = user_info.get("user_nickname")
|
||||
user_cardname = user_info.get("user_cardname")
|
||||
|
||||
timestamp = msg.get("time")
|
||||
timestamp: float = msg.get("time") # type: ignore
|
||||
content: str
|
||||
if msg.get("display_message"):
|
||||
content = msg.get("display_message")
|
||||
content = msg.get("display_message", "")
|
||||
else:
|
||||
content = msg.get("processed_plain_text", "") # 默认空字符串
|
||||
|
||||
@@ -271,10 +359,11 @@ def _build_readable_messages_internal(
|
||||
person_id = PersonInfoManager.get_person_id(platform, user_id)
|
||||
person_info_manager = get_person_info_manager()
|
||||
# 根据 replace_bot_name 参数决定是否替换机器人名称
|
||||
person_name: str
|
||||
if replace_bot_name and user_id == global_config.bot.qq_account:
|
||||
person_name = f"{global_config.bot.nickname}(你)"
|
||||
else:
|
||||
person_name = person_info_manager.get_value_sync(person_id, "person_name")
|
||||
person_name = person_info_manager.get_value_sync(person_id, "person_name") # type: ignore
|
||||
|
||||
# 如果 person_name 未设置,则使用消息中的 nickname 或默认名称
|
||||
if not person_name:
|
||||
@@ -289,12 +378,10 @@ def _build_readable_messages_internal(
|
||||
reply_pattern = r"回复<([^:<>]+):([^:<>]+)>"
|
||||
match = re.search(reply_pattern, content)
|
||||
if match:
|
||||
aaa = match.group(1)
|
||||
bbb = match.group(2)
|
||||
aaa: str = match[1]
|
||||
bbb: str = match[2]
|
||||
reply_person_id = PersonInfoManager.get_person_id(platform, bbb)
|
||||
reply_person_name = person_info_manager.get_value_sync(reply_person_id, "person_name")
|
||||
if not reply_person_name:
|
||||
reply_person_name = aaa
|
||||
reply_person_name = person_info_manager.get_value_sync(reply_person_id, "person_name") or aaa
|
||||
# 在内容前加上回复信息
|
||||
content = re.sub(reply_pattern, lambda m, name=reply_person_name: f"回复 {name}", content, count=1)
|
||||
|
||||
@@ -309,18 +396,15 @@ def _build_readable_messages_internal(
|
||||
aaa = m.group(1)
|
||||
bbb = m.group(2)
|
||||
at_person_id = PersonInfoManager.get_person_id(platform, bbb)
|
||||
at_person_name = person_info_manager.get_value_sync(at_person_id, "person_name")
|
||||
if not at_person_name:
|
||||
at_person_name = aaa
|
||||
at_person_name = person_info_manager.get_value_sync(at_person_id, "person_name") or aaa
|
||||
new_content += f"@{at_person_name}"
|
||||
last_end = m.end()
|
||||
new_content += content[last_end:]
|
||||
content = new_content
|
||||
|
||||
target_str = "这是QQ的一个功能,用于提及某人,但没那么明显"
|
||||
if target_str in content:
|
||||
if random.random() < 0.6:
|
||||
content = content.replace(target_str, "")
|
||||
if target_str in content and random.random() < 0.6:
|
||||
content = content.replace(target_str, "")
|
||||
|
||||
if content != "":
|
||||
message_details_raw.append((timestamp, person_name, content, False))
|
||||
@@ -436,12 +520,16 @@ def _build_readable_messages_internal(
|
||||
# 使用指定的 timestamp_mode 格式化时间
|
||||
readable_time = translate_timestamp_to_human_readable(merged["start_time"], mode=timestamp_mode)
|
||||
|
||||
# 查找对应的消息ID
|
||||
message_id = timestamp_to_id.get(merged["start_time"], "")
|
||||
id_prefix = f"[{message_id}] " if message_id else ""
|
||||
|
||||
# 检查是否是动作记录
|
||||
if merged["is_action"]:
|
||||
# 对于动作记录,使用特殊格式
|
||||
output_lines.append(f"{readable_time}, {merged['content'][0]}")
|
||||
output_lines.append(f"{id_prefix}{readable_time}, {merged['content'][0]}")
|
||||
else:
|
||||
header = f"{readable_time}, {merged['name']} :"
|
||||
header = f"{id_prefix}{readable_time}, {merged['name']} :"
|
||||
output_lines.append(header)
|
||||
# 将内容合并,并添加缩进
|
||||
for line in merged["content"]:
|
||||
@@ -470,6 +558,7 @@ def _build_readable_messages_internal(
|
||||
|
||||
|
||||
def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str:
|
||||
# sourcery skip: use-contextlib-suppress
|
||||
"""
|
||||
构建图片映射信息字符串,显示图片的具体描述内容
|
||||
|
||||
@@ -503,6 +592,48 @@ def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str:
|
||||
return "\n".join(mapping_lines)
|
||||
|
||||
|
||||
def build_readable_actions(actions: List[Dict[str, Any]]) -> str:
|
||||
"""
|
||||
将动作列表转换为可读的文本格式。
|
||||
格式: 在()分钟前,你使用了(action_name),具体内容是:(action_prompt_display)
|
||||
|
||||
Args:
|
||||
actions: 动作记录字典列表。
|
||||
|
||||
Returns:
|
||||
格式化的动作字符串。
|
||||
"""
|
||||
if not actions:
|
||||
return ""
|
||||
|
||||
output_lines = []
|
||||
current_time = time.time()
|
||||
|
||||
# The get functions return actions sorted ascending by time. Let's reverse it to show newest first.
|
||||
# sorted_actions = sorted(actions, key=lambda x: x.get("time", 0), reverse=True)
|
||||
|
||||
for action in actions:
|
||||
action_time = action.get("time", current_time)
|
||||
action_name = action.get("action_name", "未知动作")
|
||||
if action_name in ["no_action", "no_reply"]:
|
||||
continue
|
||||
|
||||
action_prompt_display = action.get("action_prompt_display", "无具体内容")
|
||||
|
||||
time_diff_seconds = current_time - action_time
|
||||
|
||||
if time_diff_seconds < 60:
|
||||
time_ago_str = f"在{int(time_diff_seconds)}秒前"
|
||||
else:
|
||||
time_diff_minutes = round(time_diff_seconds / 60)
|
||||
time_ago_str = f"在{int(time_diff_minutes)}分钟前"
|
||||
|
||||
line = f"{time_ago_str},你使用了“{action_name}”,具体内容是:“{action_prompt_display}”"
|
||||
output_lines.append(line)
|
||||
|
||||
return "\n".join(output_lines)
|
||||
|
||||
|
||||
async def build_readable_messages_with_list(
|
||||
messages: List[Dict[str, Any]],
|
||||
replace_bot_name: bool = True,
|
||||
@@ -518,13 +649,44 @@ async def build_readable_messages_with_list(
|
||||
messages, replace_bot_name, merge_messages, timestamp_mode, truncate
|
||||
)
|
||||
|
||||
# 生成图片映射信息并添加到最前面
|
||||
pic_mapping_info = build_pic_mapping_info(pic_id_mapping)
|
||||
if pic_mapping_info:
|
||||
if pic_mapping_info := build_pic_mapping_info(pic_id_mapping):
|
||||
formatted_string = f"{pic_mapping_info}\n\n{formatted_string}"
|
||||
|
||||
return formatted_string, details_list
|
||||
|
||||
def build_readable_messages_with_id(
|
||||
messages: List[Dict[str, Any]],
|
||||
replace_bot_name: bool = True,
|
||||
merge_messages: bool = False,
|
||||
timestamp_mode: str = "relative",
|
||||
read_mark: float = 0.0,
|
||||
truncate: bool = False,
|
||||
show_actions: bool = False,
|
||||
show_pic: bool = True,
|
||||
) -> Tuple[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
将消息列表转换为可读的文本格式,并返回原始(时间戳, 昵称, 内容)列表。
|
||||
允许通过参数控制格式化行为。
|
||||
"""
|
||||
message_id_list = assign_message_ids(messages)
|
||||
|
||||
formatted_string = build_readable_messages(
|
||||
messages = messages,
|
||||
replace_bot_name=replace_bot_name,
|
||||
merge_messages=merge_messages,
|
||||
timestamp_mode=timestamp_mode,
|
||||
truncate=truncate,
|
||||
show_actions=show_actions,
|
||||
show_pic=show_pic,
|
||||
read_mark=read_mark,
|
||||
message_id_list=message_id_list,
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
return formatted_string , message_id_list
|
||||
|
||||
|
||||
def build_readable_messages(
|
||||
messages: List[Dict[str, Any]],
|
||||
@@ -535,7 +697,8 @@ def build_readable_messages(
|
||||
truncate: bool = False,
|
||||
show_actions: bool = False,
|
||||
show_pic: bool = True,
|
||||
) -> str:
|
||||
message_id_list: Optional[List[Dict[str, Any]]] = None,
|
||||
) -> str: # sourcery skip: extract-method
|
||||
"""
|
||||
将消息列表转换为可读的文本格式。
|
||||
如果提供了 read_mark,则在相应位置插入已读标记。
|
||||
@@ -607,7 +770,7 @@ def build_readable_messages(
|
||||
if read_mark <= 0:
|
||||
# 没有有效的 read_mark,直接格式化所有消息
|
||||
formatted_string, _, pic_id_mapping, _ = _build_readable_messages_internal(
|
||||
copy_messages, replace_bot_name, merge_messages, timestamp_mode, truncate, show_pic=show_pic
|
||||
copy_messages, replace_bot_name, merge_messages, timestamp_mode, truncate, show_pic=show_pic, message_id_list=message_id_list
|
||||
)
|
||||
|
||||
# 生成图片映射信息并添加到最前面
|
||||
@@ -635,6 +798,7 @@ def build_readable_messages(
|
||||
pic_id_mapping,
|
||||
pic_counter,
|
||||
show_pic=show_pic,
|
||||
message_id_list=message_id_list,
|
||||
)
|
||||
formatted_after, _, pic_id_mapping, _ = _build_readable_messages_internal(
|
||||
messages_after_mark,
|
||||
@@ -645,6 +809,7 @@ def build_readable_messages(
|
||||
pic_id_mapping,
|
||||
pic_counter,
|
||||
show_pic=show_pic,
|
||||
message_id_list=message_id_list,
|
||||
)
|
||||
|
||||
read_mark_line = "\n--- 以上消息是你已经看过,请关注以下未读的新消息---\n"
|
||||
@@ -658,9 +823,7 @@ def build_readable_messages(
|
||||
# 组合结果
|
||||
result_parts = []
|
||||
if pic_mapping_info:
|
||||
result_parts.append(pic_mapping_info)
|
||||
result_parts.append("\n")
|
||||
|
||||
result_parts.extend((pic_mapping_info, "\n"))
|
||||
if formatted_before and formatted_after:
|
||||
result_parts.extend([formatted_before, read_mark_line, formatted_after])
|
||||
elif formatted_before:
|
||||
@@ -733,8 +896,9 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str:
|
||||
platform = msg.get("chat_info_platform")
|
||||
user_id = msg.get("user_id")
|
||||
_timestamp = msg.get("time")
|
||||
content: str = ""
|
||||
if msg.get("display_message"):
|
||||
content = msg.get("display_message")
|
||||
content = msg.get("display_message", "")
|
||||
else:
|
||||
content = msg.get("processed_plain_text", "")
|
||||
|
||||
@@ -822,17 +986,14 @@ async def get_person_id_list(messages: List[Dict[str, Any]]) -> List[str]:
|
||||
person_ids_set = set() # 使用集合来自动去重
|
||||
|
||||
for msg in messages:
|
||||
platform = msg.get("user_platform")
|
||||
user_id = msg.get("user_id")
|
||||
platform: str = msg.get("user_platform") # type: ignore
|
||||
user_id: str = msg.get("user_id") # type: ignore
|
||||
|
||||
# 检查必要信息是否存在 且 不是机器人自己
|
||||
if not all([platform, user_id]) or user_id == global_config.bot.qq_account:
|
||||
continue
|
||||
|
||||
person_id = PersonInfoManager.get_person_id(platform, user_id)
|
||||
|
||||
# 只有当获取到有效 person_id 时才添加
|
||||
if person_id:
|
||||
if person_id := PersonInfoManager.get_person_id(platform, user_id):
|
||||
person_ids_set.add(person_id)
|
||||
|
||||
return list(person_ids_set) # 将集合转换为列表返回
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import ast
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Dict, TypeVar, List, Union, Tuple
|
||||
import ast
|
||||
|
||||
from typing import Any, Dict, TypeVar, List, Union, Tuple, Optional
|
||||
|
||||
# 定义类型变量用于泛型类型提示
|
||||
T = TypeVar("T")
|
||||
@@ -30,18 +31,14 @@ def safe_json_loads(json_str: str, default_value: T = None) -> Union[Any, T]:
|
||||
# 尝试标准的 JSON 解析
|
||||
return json.loads(json_str)
|
||||
except json.JSONDecodeError:
|
||||
# 如果标准解析失败,尝试将单引号替换为双引号再解析
|
||||
# (注意:这种替换可能不安全,如果字符串内容本身包含引号)
|
||||
# 更安全的方式是用 ast.literal_eval
|
||||
# 如果标准解析失败,尝试用 ast.literal_eval 解析
|
||||
try:
|
||||
# logger.debug(f"标准JSON解析失败,尝试用 ast.literal_eval 解析: {json_str[:100]}...")
|
||||
result = ast.literal_eval(json_str)
|
||||
# 确保结果是字典(因为我们通常期望参数是字典)
|
||||
if isinstance(result, dict):
|
||||
return result
|
||||
else:
|
||||
logger.warning(f"ast.literal_eval 解析成功但结果不是字典: {type(result)}, 内容: {result}")
|
||||
return default_value
|
||||
logger.warning(f"ast.literal_eval 解析成功但结果不是字典: {type(result)}, 内容: {result}")
|
||||
return default_value
|
||||
except (ValueError, SyntaxError, MemoryError, RecursionError) as ast_e:
|
||||
logger.error(f"使用 ast.literal_eval 解析失败: {ast_e}, 字符串: {json_str[:100]}...")
|
||||
return default_value
|
||||
@@ -53,7 +50,9 @@ def safe_json_loads(json_str: str, default_value: T = None) -> Union[Any, T]:
|
||||
return default_value
|
||||
|
||||
|
||||
def extract_tool_call_arguments(tool_call: Dict[str, Any], default_value: Dict[str, Any] = None) -> Dict[str, Any]:
|
||||
def extract_tool_call_arguments(
|
||||
tool_call: Dict[str, Any], default_value: Optional[Dict[str, Any]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
从LLM工具调用对象中提取参数
|
||||
|
||||
@@ -77,14 +76,12 @@ def extract_tool_call_arguments(tool_call: Dict[str, Any], default_value: Dict[s
|
||||
logger.error(f"工具调用缺少function字段或格式不正确: {tool_call}")
|
||||
return default_result
|
||||
|
||||
# 提取arguments
|
||||
arguments_str = function_data.get("arguments", "{}")
|
||||
if not arguments_str:
|
||||
if arguments_str := function_data.get("arguments", "{}"):
|
||||
# 解析JSON
|
||||
return safe_json_loads(arguments_str, default_result)
|
||||
else:
|
||||
return default_result
|
||||
|
||||
# 解析JSON
|
||||
return safe_json_loads(arguments_str, default_result)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"提取工具调用参数时出错: {e}")
|
||||
return default_result
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
from typing import Dict, Any, Optional, List, Union
|
||||
import re
|
||||
from contextlib import asynccontextmanager
|
||||
import asyncio
|
||||
import contextvars
|
||||
from src.common.logger import get_logger
|
||||
|
||||
# import traceback
|
||||
from rich.traceback import install
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Dict, Any, Optional, List, Union
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
@@ -32,6 +32,7 @@ class PromptContext:
|
||||
|
||||
@asynccontextmanager
|
||||
async def async_scope(self, context_id: Optional[str] = None):
|
||||
# sourcery skip: hoist-statement-from-if, use-contextlib-suppress
|
||||
"""创建一个异步的临时提示模板作用域"""
|
||||
# 保存当前上下文并设置新上下文
|
||||
if context_id is not None:
|
||||
@@ -88,8 +89,7 @@ class PromptContext:
|
||||
async def register_async(self, prompt: "Prompt", context_id: Optional[str] = None) -> None:
|
||||
"""异步注册提示模板到指定作用域"""
|
||||
async with self._context_lock:
|
||||
target_context = context_id or self._current_context
|
||||
if target_context:
|
||||
if target_context := context_id or self._current_context:
|
||||
self._context_prompts.setdefault(target_context, {})[prompt.name] = prompt
|
||||
|
||||
|
||||
@@ -151,7 +151,7 @@ class Prompt(str):
|
||||
|
||||
@staticmethod
|
||||
def _process_escaped_braces(template) -> str:
|
||||
"""处理模板中的转义花括号,将 \{ 和 \} 替换为临时标记"""
|
||||
"""处理模板中的转义花括号,将 \{ 和 \} 替换为临时标记""" # type: ignore
|
||||
# 如果传入的是列表,将其转换为字符串
|
||||
if isinstance(template, list):
|
||||
template = "\n".join(str(item) for item in template)
|
||||
@@ -195,14 +195,8 @@ class Prompt(str):
|
||||
obj._kwargs = kwargs
|
||||
|
||||
# 修改自动注册逻辑
|
||||
if should_register:
|
||||
if global_prompt_manager._context._current_context:
|
||||
# 如果存在当前上下文,则注册到上下文中
|
||||
# asyncio.create_task(global_prompt_manager._context.register_async(obj))
|
||||
pass
|
||||
else:
|
||||
# 否则注册到全局管理器
|
||||
global_prompt_manager.register(obj)
|
||||
if should_register and not global_prompt_manager._context._current_context:
|
||||
global_prompt_manager.register(obj)
|
||||
return obj
|
||||
|
||||
@classmethod
|
||||
@@ -276,15 +270,13 @@ class Prompt(str):
|
||||
self.name,
|
||||
args=list(args) if args else self._args,
|
||||
_should_register=False,
|
||||
**kwargs if kwargs else self._kwargs,
|
||||
**kwargs or self._kwargs,
|
||||
)
|
||||
# print(f"prompt build result: {ret} name: {ret.name} ")
|
||||
return str(ret)
|
||||
|
||||
def __str__(self) -> str:
|
||||
if self._kwargs or self._args:
|
||||
return super().__str__()
|
||||
return self.template
|
||||
return super().__str__() if self._kwargs or self._args else self.template
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"Prompt(template='{self.template}', name='{self.name}')"
|
||||
|
||||
@@ -1,18 +1,17 @@
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, Tuple, List
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import json
|
||||
import os
|
||||
import glob
|
||||
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, Tuple, List
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database import db
|
||||
from src.common.database.database_model import OnlineTime, LLMUsage, Messages
|
||||
from src.manager.async_task_manager import AsyncTask
|
||||
|
||||
from ...common.database.database import db # This db is the Peewee database instance
|
||||
from ...common.database.database_model import OnlineTime, LLMUsage, Messages # Import the Peewee model
|
||||
from src.manager.local_store_manager import local_storage
|
||||
|
||||
logger = get_logger("maibot_statistic")
|
||||
@@ -76,14 +75,14 @@ class OnlineTimeRecordTask(AsyncTask):
|
||||
with db.atomic(): # Use atomic operations for schema changes
|
||||
OnlineTime.create_table(safe=True) # Creates table if it doesn't exist, Peewee handles indexes from model
|
||||
|
||||
async def run(self):
|
||||
async def run(self): # sourcery skip: use-named-expression
|
||||
try:
|
||||
current_time = datetime.now()
|
||||
extended_end_time = current_time + timedelta(minutes=1)
|
||||
|
||||
if self.record_id:
|
||||
# 如果有记录,则更新结束时间
|
||||
query = OnlineTime.update(end_timestamp=extended_end_time).where(OnlineTime.id == self.record_id)
|
||||
query = OnlineTime.update(end_timestamp=extended_end_time).where(OnlineTime.id == self.record_id) # type: ignore
|
||||
updated_rows = query.execute()
|
||||
if updated_rows == 0:
|
||||
# Record might have been deleted or ID is stale, try to find/create
|
||||
@@ -94,7 +93,7 @@ class OnlineTimeRecordTask(AsyncTask):
|
||||
# Look for a record whose end_timestamp is recent enough to be considered ongoing
|
||||
recent_record = (
|
||||
OnlineTime.select()
|
||||
.where(OnlineTime.end_timestamp >= (current_time - timedelta(minutes=1)))
|
||||
.where(OnlineTime.end_timestamp >= (current_time - timedelta(minutes=1))) # type: ignore
|
||||
.order_by(OnlineTime.end_timestamp.desc())
|
||||
.first()
|
||||
)
|
||||
@@ -123,15 +122,15 @@ def _format_online_time(online_seconds: int) -> str:
|
||||
:param online_seconds: 在线时间(秒)
|
||||
:return: 格式化后的在线时间字符串
|
||||
"""
|
||||
total_oneline_time = timedelta(seconds=online_seconds)
|
||||
total_online_time = timedelta(seconds=online_seconds)
|
||||
|
||||
days = total_oneline_time.days
|
||||
hours = total_oneline_time.seconds // 3600
|
||||
minutes = (total_oneline_time.seconds // 60) % 60
|
||||
seconds = total_oneline_time.seconds % 60
|
||||
days = total_online_time.days
|
||||
hours = total_online_time.seconds // 3600
|
||||
minutes = (total_online_time.seconds // 60) % 60
|
||||
seconds = total_online_time.seconds % 60
|
||||
if days > 0:
|
||||
# 如果在线时间超过1天,则格式化为"X天X小时X分钟"
|
||||
return f"{total_oneline_time.days}天{hours}小时{minutes}分钟{seconds}秒"
|
||||
return f"{total_online_time.days}天{hours}小时{minutes}分钟{seconds}秒"
|
||||
elif hours > 0:
|
||||
# 如果在线时间超过1小时,则格式化为"X小时X分钟X秒"
|
||||
return f"{hours}小时{minutes}分钟{seconds}秒"
|
||||
@@ -163,7 +162,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
now = datetime.now()
|
||||
if "deploy_time" in local_storage:
|
||||
# 如果存在部署时间,则使用该时间作为全量统计的起始时间
|
||||
deploy_time = datetime.fromtimestamp(local_storage["deploy_time"])
|
||||
deploy_time = datetime.fromtimestamp(local_storage["deploy_time"]) # type: ignore
|
||||
else:
|
||||
# 否则,使用最大时间范围,并记录部署时间为当前时间
|
||||
deploy_time = datetime(2000, 1, 1)
|
||||
@@ -252,7 +251,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
|
||||
# 创建后台任务,不等待完成
|
||||
collect_task = asyncio.create_task(
|
||||
loop.run_in_executor(executor, self._collect_all_statistics, now)
|
||||
loop.run_in_executor(executor, self._collect_all_statistics, now) # type: ignore
|
||||
)
|
||||
|
||||
stats = await collect_task
|
||||
@@ -260,8 +259,8 @@ class StatisticOutputTask(AsyncTask):
|
||||
|
||||
# 创建并发的输出任务
|
||||
output_tasks = [
|
||||
asyncio.create_task(loop.run_in_executor(executor, self._statistic_console_output, stats, now)),
|
||||
asyncio.create_task(loop.run_in_executor(executor, self._generate_html_report, stats, now)),
|
||||
asyncio.create_task(loop.run_in_executor(executor, self._statistic_console_output, stats, now)), # type: ignore
|
||||
asyncio.create_task(loop.run_in_executor(executor, self._generate_html_report, stats, now)), # type: ignore
|
||||
]
|
||||
|
||||
# 等待所有输出任务完成
|
||||
@@ -320,7 +319,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
# 以最早的时间戳为起始时间获取记录
|
||||
# Assuming LLMUsage.timestamp is a DateTimeField
|
||||
query_start_time = collect_period[-1][1]
|
||||
for record in LLMUsage.select().where(LLMUsage.timestamp >= query_start_time):
|
||||
for record in LLMUsage.select().where(LLMUsage.timestamp >= query_start_time): # type: ignore
|
||||
record_timestamp = record.timestamp # This is already a datetime object
|
||||
for idx, (_, period_start) in enumerate(collect_period):
|
||||
if record_timestamp >= period_start:
|
||||
@@ -388,7 +387,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
|
||||
query_start_time = collect_period[-1][1]
|
||||
# Assuming OnlineTime.end_timestamp is a DateTimeField
|
||||
for record in OnlineTime.select().where(OnlineTime.end_timestamp >= query_start_time):
|
||||
for record in OnlineTime.select().where(OnlineTime.end_timestamp >= query_start_time): # type: ignore
|
||||
# record.end_timestamp and record.start_timestamp are datetime objects
|
||||
record_end_timestamp = record.end_timestamp
|
||||
record_start_timestamp = record.start_timestamp
|
||||
@@ -428,7 +427,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
}
|
||||
|
||||
query_start_timestamp = collect_period[-1][1].timestamp() # Messages.time is a DoubleField (timestamp)
|
||||
for message in Messages.select().where(Messages.time >= query_start_timestamp):
|
||||
for message in Messages.select().where(Messages.time >= query_start_timestamp): # type: ignore
|
||||
message_time_ts = message.time # This is a float timestamp
|
||||
|
||||
chat_id = None
|
||||
@@ -661,7 +660,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
|
||||
if "last_full_statistics" in local_storage:
|
||||
# 如果存在上次完整统计数据,则使用该数据进行增量统计
|
||||
last_stat = local_storage["last_full_statistics"] # 上次完整统计数据
|
||||
last_stat: Dict[str, Any] = local_storage["last_full_statistics"] # 上次完整统计数据 # type: ignore
|
||||
|
||||
self.name_mapping = last_stat["name_mapping"] # 上次完整统计数据的名称映射
|
||||
last_all_time_stat = last_stat["stat_data"] # 上次完整统计的统计数据
|
||||
@@ -727,6 +726,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
return stat
|
||||
|
||||
def _convert_defaultdict_to_dict(self, data):
|
||||
# sourcery skip: dict-comprehension, extract-duplicate-method, inline-immediately-returned-variable, merge-duplicate-blocks
|
||||
"""递归转换defaultdict为普通dict"""
|
||||
if isinstance(data, defaultdict):
|
||||
# 转换defaultdict为普通dict
|
||||
@@ -812,8 +812,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
# 全局阶段平均时间
|
||||
if stats[FOCUS_AVG_TIMES_BY_STAGE]:
|
||||
output.append("全局阶段平均时间:")
|
||||
for stage, avg_time in stats[FOCUS_AVG_TIMES_BY_STAGE].items():
|
||||
output.append(f" {stage}: {avg_time:.3f}秒")
|
||||
output.extend(f" {stage}: {avg_time:.3f}秒" for stage, avg_time in stats[FOCUS_AVG_TIMES_BY_STAGE].items())
|
||||
output.append("")
|
||||
|
||||
# Action类型比例
|
||||
@@ -1050,7 +1049,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
]
|
||||
|
||||
tab_content_list.append(
|
||||
_format_stat_data(stat["all_time"], "all_time", datetime.fromtimestamp(local_storage["deploy_time"]))
|
||||
_format_stat_data(stat["all_time"], "all_time", datetime.fromtimestamp(local_storage["deploy_time"])) # type: ignore
|
||||
)
|
||||
|
||||
# 添加Focus统计内容
|
||||
@@ -1212,6 +1211,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
f.write(html_template)
|
||||
|
||||
def _generate_focus_tab(self, stat: dict[str, Any]) -> str:
|
||||
# sourcery skip: for-append-to-extend, list-comprehension, use-any, use-named-expression, use-next
|
||||
"""生成Focus统计独立分页的HTML内容"""
|
||||
|
||||
# 为每个时间段准备Focus数据
|
||||
@@ -1313,12 +1313,11 @@ class StatisticOutputTask(AsyncTask):
|
||||
# 聊天流Action选择比例对比表(横向表格)
|
||||
focus_chat_action_ratios_rows = ""
|
||||
if stat_data.get("focus_action_ratios_by_chat"):
|
||||
# 获取所有action类型(按全局频率排序)
|
||||
all_action_types_for_ratio = sorted(
|
||||
stat_data[FOCUS_ACTION_RATIOS].keys(), key=lambda x: stat_data[FOCUS_ACTION_RATIOS][x], reverse=True
|
||||
)
|
||||
|
||||
if all_action_types_for_ratio:
|
||||
if all_action_types_for_ratio := sorted(
|
||||
stat_data[FOCUS_ACTION_RATIOS].keys(),
|
||||
key=lambda x: stat_data[FOCUS_ACTION_RATIOS][x],
|
||||
reverse=True,
|
||||
):
|
||||
# 为每个聊天流生成数据行(按循环数排序)
|
||||
chat_ratio_rows = []
|
||||
for chat_id in sorted(
|
||||
@@ -1379,16 +1378,11 @@ class StatisticOutputTask(AsyncTask):
|
||||
if period_name == "all_time":
|
||||
from src.manager.local_store_manager import local_storage
|
||||
|
||||
start_time = datetime.fromtimestamp(local_storage["deploy_time"])
|
||||
time_range = (
|
||||
f"{start_time.strftime('%Y-%m-%d %H:%M:%S')} ~ {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
|
||||
)
|
||||
start_time = datetime.fromtimestamp(local_storage["deploy_time"]) # type: ignore
|
||||
else:
|
||||
start_time = datetime.now() - period_delta
|
||||
time_range = (
|
||||
f"{start_time.strftime('%Y-%m-%d %H:%M:%S')} ~ {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
|
||||
)
|
||||
|
||||
time_range = f"{start_time.strftime('%Y-%m-%d %H:%M:%S')} ~ {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
|
||||
# 生成该时间段的Focus统计HTML
|
||||
section_html = f"""
|
||||
<div class="focus-period-section">
|
||||
@@ -1565,6 +1559,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
"""
|
||||
|
||||
def _generate_versions_tab(self, stat: dict[str, Any]) -> str:
|
||||
# sourcery skip: use-named-expression, use-next
|
||||
"""生成版本对比独立分页的HTML内容"""
|
||||
|
||||
# 为每个时间段准备版本对比数据
|
||||
@@ -1681,16 +1676,10 @@ class StatisticOutputTask(AsyncTask):
|
||||
if period_name == "all_time":
|
||||
from src.manager.local_store_manager import local_storage
|
||||
|
||||
start_time = datetime.fromtimestamp(local_storage["deploy_time"])
|
||||
time_range = (
|
||||
f"{start_time.strftime('%Y-%m-%d %H:%M:%S')} ~ {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
|
||||
)
|
||||
start_time = datetime.fromtimestamp(local_storage["deploy_time"]) # type: ignore
|
||||
else:
|
||||
start_time = datetime.now() - period_delta
|
||||
time_range = (
|
||||
f"{start_time.strftime('%Y-%m-%d %H:%M:%S')} ~ {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
|
||||
)
|
||||
|
||||
time_range = f"{start_time.strftime('%Y-%m-%d %H:%M:%S')} ~ {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
|
||||
# 生成该时间段的版本对比HTML
|
||||
section_html = f"""
|
||||
<div class="version-period-section">
|
||||
@@ -1865,7 +1854,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
|
||||
# 查询LLM使用记录
|
||||
query_start_time = start_time
|
||||
for record in LLMUsage.select().where(LLMUsage.timestamp >= query_start_time):
|
||||
for record in LLMUsage.select().where(LLMUsage.timestamp >= query_start_time): # type: ignore
|
||||
record_time = record.timestamp
|
||||
|
||||
# 找到对应的时间间隔索引
|
||||
@@ -1875,7 +1864,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
if 0 <= interval_index < len(time_points):
|
||||
# 累加总花费数据
|
||||
cost = record.cost or 0.0
|
||||
total_cost_data[interval_index] += cost
|
||||
total_cost_data[interval_index] += cost # type: ignore
|
||||
|
||||
# 累加按模型分类的花费
|
||||
model_name = record.model_name or "unknown"
|
||||
@@ -1892,7 +1881,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
|
||||
# 查询消息记录
|
||||
query_start_timestamp = start_time.timestamp()
|
||||
for message in Messages.select().where(Messages.time >= query_start_timestamp):
|
||||
for message in Messages.select().where(Messages.time >= query_start_timestamp): # type: ignore
|
||||
message_time_ts = message.time
|
||||
|
||||
# 找到对应的时间间隔索引
|
||||
@@ -1982,6 +1971,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
}
|
||||
|
||||
def _generate_chart_tab(self, chart_data: dict) -> str:
|
||||
# sourcery skip: extract-duplicate-method, move-assign-in-block
|
||||
"""生成图表选项卡HTML内容"""
|
||||
|
||||
# 生成不同颜色的调色板
|
||||
@@ -2293,7 +2283,7 @@ class AsyncStatisticOutputTask(AsyncTask):
|
||||
|
||||
# 数据收集任务
|
||||
collect_task = asyncio.create_task(
|
||||
loop.run_in_executor(executor, self._collect_all_statistics, now)
|
||||
loop.run_in_executor(executor, self._collect_all_statistics, now) # type: ignore
|
||||
)
|
||||
|
||||
stats = await collect_task
|
||||
@@ -2301,8 +2291,8 @@ class AsyncStatisticOutputTask(AsyncTask):
|
||||
|
||||
# 创建并发的输出任务
|
||||
output_tasks = [
|
||||
asyncio.create_task(loop.run_in_executor(executor, self._statistic_console_output, stats, now)),
|
||||
asyncio.create_task(loop.run_in_executor(executor, self._generate_html_report, stats, now)),
|
||||
asyncio.create_task(loop.run_in_executor(executor, self._statistic_console_output, stats, now)), # type: ignore
|
||||
asyncio.create_task(loop.run_in_executor(executor, self._generate_html_report, stats, now)), # type: ignore
|
||||
]
|
||||
|
||||
# 等待所有输出任务完成
|
||||
@@ -2317,13 +2307,13 @@ class AsyncStatisticOutputTask(AsyncTask):
|
||||
|
||||
# 复用 StatisticOutputTask 的所有方法
|
||||
def _collect_all_statistics(self, now: datetime):
|
||||
return StatisticOutputTask._collect_all_statistics(self, now)
|
||||
return StatisticOutputTask._collect_all_statistics(self, now) # type: ignore
|
||||
|
||||
def _statistic_console_output(self, stats: Dict[str, Any], now: datetime):
|
||||
return StatisticOutputTask._statistic_console_output(self, stats, now)
|
||||
return StatisticOutputTask._statistic_console_output(self, stats, now) # type: ignore
|
||||
|
||||
def _generate_html_report(self, stats: dict[str, Any], now: datetime):
|
||||
return StatisticOutputTask._generate_html_report(self, stats, now)
|
||||
return StatisticOutputTask._generate_html_report(self, stats, now) # type: ignore
|
||||
|
||||
# 其他需要的方法也可以类似复用...
|
||||
@staticmethod
|
||||
@@ -2335,10 +2325,10 @@ class AsyncStatisticOutputTask(AsyncTask):
|
||||
return StatisticOutputTask._collect_online_time_for_period(collect_period, now)
|
||||
|
||||
def _collect_message_count_for_period(self, collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]:
|
||||
return StatisticOutputTask._collect_message_count_for_period(self, collect_period)
|
||||
return StatisticOutputTask._collect_message_count_for_period(self, collect_period) # type: ignore
|
||||
|
||||
def _collect_focus_statistics_for_period(self, collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]:
|
||||
return StatisticOutputTask._collect_focus_statistics_for_period(self, collect_period)
|
||||
return StatisticOutputTask._collect_focus_statistics_for_period(self, collect_period) # type: ignore
|
||||
|
||||
def _process_focus_file_data(
|
||||
self,
|
||||
@@ -2347,10 +2337,10 @@ class AsyncStatisticOutputTask(AsyncTask):
|
||||
collect_period: List[Tuple[str, datetime]],
|
||||
file_time: datetime,
|
||||
):
|
||||
return StatisticOutputTask._process_focus_file_data(self, cycles_data, stats, collect_period, file_time)
|
||||
return StatisticOutputTask._process_focus_file_data(self, cycles_data, stats, collect_period, file_time) # type: ignore
|
||||
|
||||
def _calculate_focus_averages(self, stats: Dict[str, Any]):
|
||||
return StatisticOutputTask._calculate_focus_averages(self, stats)
|
||||
return StatisticOutputTask._calculate_focus_averages(self, stats) # type: ignore
|
||||
|
||||
@staticmethod
|
||||
def _format_total_stat(stats: Dict[str, Any]) -> str:
|
||||
@@ -2358,31 +2348,31 @@ class AsyncStatisticOutputTask(AsyncTask):
|
||||
|
||||
@staticmethod
|
||||
def _format_model_classified_stat(stats: Dict[str, Any]) -> str:
|
||||
return StatisticOutputTask._format_model_classified_stat(stats)
|
||||
return StatisticOutputTask._format_model_classified_stat(stats) # type: ignore
|
||||
|
||||
def _format_chat_stat(self, stats: Dict[str, Any]) -> str:
|
||||
return StatisticOutputTask._format_chat_stat(self, stats)
|
||||
return StatisticOutputTask._format_chat_stat(self, stats) # type: ignore
|
||||
|
||||
def _format_focus_stat(self, stats: Dict[str, Any]) -> str:
|
||||
return StatisticOutputTask._format_focus_stat(self, stats)
|
||||
return StatisticOutputTask._format_focus_stat(self, stats) # type: ignore
|
||||
|
||||
def _generate_chart_data(self, stat: dict[str, Any]) -> dict:
|
||||
return StatisticOutputTask._generate_chart_data(self, stat)
|
||||
return StatisticOutputTask._generate_chart_data(self, stat) # type: ignore
|
||||
|
||||
def _collect_interval_data(self, now: datetime, hours: int, interval_minutes: int) -> dict:
|
||||
return StatisticOutputTask._collect_interval_data(self, now, hours, interval_minutes)
|
||||
return StatisticOutputTask._collect_interval_data(self, now, hours, interval_minutes) # type: ignore
|
||||
|
||||
def _generate_chart_tab(self, chart_data: dict) -> str:
|
||||
return StatisticOutputTask._generate_chart_tab(self, chart_data)
|
||||
return StatisticOutputTask._generate_chart_tab(self, chart_data) # type: ignore
|
||||
|
||||
def _get_chat_display_name_from_id(self, chat_id: str) -> str:
|
||||
return StatisticOutputTask._get_chat_display_name_from_id(self, chat_id)
|
||||
return StatisticOutputTask._get_chat_display_name_from_id(self, chat_id) # type: ignore
|
||||
|
||||
def _generate_focus_tab(self, stat: dict[str, Any]) -> str:
|
||||
return StatisticOutputTask._generate_focus_tab(self, stat)
|
||||
return StatisticOutputTask._generate_focus_tab(self, stat) # type: ignore
|
||||
|
||||
def _generate_versions_tab(self, stat: dict[str, Any]) -> str:
|
||||
return StatisticOutputTask._generate_versions_tab(self, stat)
|
||||
return StatisticOutputTask._generate_versions_tab(self, stat) # type: ignore
|
||||
|
||||
def _convert_defaultdict_to_dict(self, data):
|
||||
return StatisticOutputTask._convert_defaultdict_to_dict(self, data)
|
||||
return StatisticOutputTask._convert_defaultdict_to_dict(self, data) # type: ignore
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import asyncio
|
||||
|
||||
from time import perf_counter
|
||||
from functools import wraps
|
||||
from typing import Optional, Dict, Callable
|
||||
import asyncio
|
||||
from rich.traceback import install
|
||||
|
||||
install(extra_lines=3)
|
||||
@@ -88,10 +89,10 @@ class Timer:
|
||||
|
||||
self.name = name
|
||||
self.storage = storage
|
||||
self.elapsed = None
|
||||
self.elapsed: float = None # type: ignore
|
||||
|
||||
self.auto_unit = auto_unit
|
||||
self.start = None
|
||||
self.start: float = None # type: ignore
|
||||
|
||||
@staticmethod
|
||||
def _validate_types(name, storage):
|
||||
@@ -120,7 +121,7 @@ class Timer:
|
||||
return None
|
||||
|
||||
wrapper = async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
|
||||
wrapper.__timer__ = self # 保留计时器引用
|
||||
wrapper.__timer__ = self # 保留计时器引用 # type: ignore
|
||||
return wrapper
|
||||
|
||||
def __enter__(self):
|
||||
|
||||
@@ -7,10 +7,10 @@ import math
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
import jieba
|
||||
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
import jieba
|
||||
from pypinyin import Style, pinyin
|
||||
|
||||
from src.common.logger import get_logger
|
||||
@@ -104,7 +104,7 @@ class ChineseTypoGenerator:
|
||||
try:
|
||||
return "\u4e00" <= char <= "\u9fff"
|
||||
except Exception as e:
|
||||
logger.debug(e)
|
||||
logger.debug(str(e))
|
||||
return False
|
||||
|
||||
def _get_pinyin(self, sentence):
|
||||
@@ -138,7 +138,7 @@ class ChineseTypoGenerator:
|
||||
# 如果最后一个字符不是数字,说明可能是轻声或其他特殊情况
|
||||
if not py[-1].isdigit():
|
||||
# 为非数字结尾的拼音添加数字声调1
|
||||
return py + "1"
|
||||
return f"{py}1"
|
||||
|
||||
base = py[:-1] # 去掉声调
|
||||
tone = int(py[-1]) # 获取声调
|
||||
@@ -363,7 +363,7 @@ class ChineseTypoGenerator:
|
||||
else:
|
||||
# 处理多字词的单字替换
|
||||
word_result = []
|
||||
for _, (char, py) in enumerate(zip(word, word_pinyin)):
|
||||
for _, (char, py) in enumerate(zip(word, word_pinyin, strict=False)):
|
||||
# 词中的字替换概率降低
|
||||
word_error_rate = self.error_rate * (0.7 ** (len(word) - 1))
|
||||
|
||||
|
||||
@@ -1,23 +1,22 @@
|
||||
import random
|
||||
import re
|
||||
import string
|
||||
import time
|
||||
from collections import Counter
|
||||
|
||||
import jieba
|
||||
import numpy as np
|
||||
|
||||
from collections import Counter
|
||||
from maim_message import UserInfo
|
||||
from typing import Optional, Tuple, Dict, List, Any
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
# from src.mood.mood_manager import mood_manager
|
||||
from ..message_receive.message import MessageRecv
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from .typo_generator import ChineseTypoGenerator
|
||||
from ...config.config import global_config
|
||||
from ...common.message_repository import find_messages, count_messages
|
||||
from typing import Optional, Tuple, Dict
|
||||
from src.common.message_repository import find_messages, count_messages
|
||||
from src.config.config import global_config
|
||||
from src.chat.message_receive.message import MessageRecv
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
|
||||
from .typo_generator import ChineseTypoGenerator
|
||||
|
||||
logger = get_logger("chat_utils")
|
||||
|
||||
@@ -31,11 +30,7 @@ def db_message_to_str(message_dict: dict) -> str:
|
||||
logger.debug(f"message_dict: {message_dict}")
|
||||
time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(message_dict["time"]))
|
||||
try:
|
||||
name = "[(%s)%s]%s" % (
|
||||
message_dict["user_id"],
|
||||
message_dict.get("user_nickname", ""),
|
||||
message_dict.get("user_cardname", ""),
|
||||
)
|
||||
name = f"[({message_dict['user_id']}){message_dict.get('user_nickname', '')}]{message_dict.get('user_cardname', '')}"
|
||||
except Exception:
|
||||
name = message_dict.get("user_nickname", "") or f"用户{message_dict['user_id']}"
|
||||
content = message_dict.get("processed_plain_text", "")
|
||||
@@ -58,11 +53,11 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]:
|
||||
and message.message_info.additional_config.get("is_mentioned") is not None
|
||||
):
|
||||
try:
|
||||
reply_probability = float(message.message_info.additional_config.get("is_mentioned"))
|
||||
reply_probability = float(message.message_info.additional_config.get("is_mentioned")) # type: ignore
|
||||
is_mentioned = True
|
||||
return is_mentioned, reply_probability
|
||||
except Exception as e:
|
||||
logger.warning(e)
|
||||
logger.warning(str(e))
|
||||
logger.warning(
|
||||
f"消息中包含不合理的设置 is_mentioned: {message.message_info.additional_config.get('is_mentioned')}"
|
||||
)
|
||||
@@ -127,30 +122,6 @@ async def get_embedding(text, request_type="embedding"):
|
||||
return embedding
|
||||
|
||||
|
||||
def get_recent_group_detailed_plain_text(chat_stream_id: str, limit: int = 12, combine=False):
|
||||
filter_query = {"chat_id": chat_stream_id}
|
||||
sort_order = [("time", -1)]
|
||||
recent_messages = find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
|
||||
|
||||
if not recent_messages:
|
||||
return []
|
||||
|
||||
message_detailed_plain_text = ""
|
||||
message_detailed_plain_text_list = []
|
||||
|
||||
# 反转消息列表,使最新的消息在最后
|
||||
recent_messages.reverse()
|
||||
|
||||
if combine:
|
||||
for msg_db_data in recent_messages:
|
||||
message_detailed_plain_text += str(msg_db_data["detailed_plain_text"])
|
||||
return message_detailed_plain_text
|
||||
else:
|
||||
for msg_db_data in recent_messages:
|
||||
message_detailed_plain_text_list.append(msg_db_data["detailed_plain_text"])
|
||||
return message_detailed_plain_text_list
|
||||
|
||||
|
||||
def get_recent_group_speaker(chat_stream_id: str, sender, limit: int = 12) -> list:
|
||||
# 获取当前群聊记录内发言的人
|
||||
filter_query = {"chat_id": chat_stream_id}
|
||||
@@ -204,10 +175,7 @@ def split_into_sentences_w_remove_punctuation(text: str) -> list[str]:
|
||||
|
||||
len_text = len(text)
|
||||
if len_text < 3:
|
||||
if random.random() < 0.01:
|
||||
return list(text) # 如果文本很短且触发随机条件,直接按字符分割
|
||||
else:
|
||||
return [text]
|
||||
return list(text) if random.random() < 0.01 else [text]
|
||||
|
||||
# 定义分隔符
|
||||
separators = {",", ",", " ", "。", ";"}
|
||||
@@ -352,10 +320,9 @@ def process_llm_response(text: str, enable_splitter: bool = True, enable_chinese
|
||||
max_length = global_config.response_splitter.max_length * 2
|
||||
max_sentence_num = global_config.response_splitter.max_sentence_num
|
||||
# 如果基本上是中文,则进行长度过滤
|
||||
if get_western_ratio(cleaned_text) < 0.1:
|
||||
if len(cleaned_text) > max_length:
|
||||
logger.warning(f"回复过长 ({len(cleaned_text)} 字符),返回默认回复")
|
||||
return ["懒得说"]
|
||||
if get_western_ratio(cleaned_text) < 0.1 and len(cleaned_text) > max_length:
|
||||
logger.warning(f"回复过长 ({len(cleaned_text)} 字符),返回默认回复")
|
||||
return ["懒得说"]
|
||||
|
||||
typo_generator = ChineseTypoGenerator(
|
||||
error_rate=global_config.chinese_typo.error_rate,
|
||||
@@ -420,7 +387,7 @@ def calculate_typing_time(
|
||||
# chinese_time *= 1 / typing_speed_multiplier
|
||||
# english_time *= 1 / typing_speed_multiplier
|
||||
# 计算中文字符数
|
||||
chinese_chars = sum(1 for char in input_string if "\u4e00" <= char <= "\u9fff")
|
||||
chinese_chars = sum("\u4e00" <= char <= "\u9fff" for char in input_string)
|
||||
|
||||
# 如果只有一个中文字符,使用3倍时间
|
||||
if chinese_chars == 1 and len(input_string.strip()) == 1:
|
||||
@@ -429,11 +396,7 @@ def calculate_typing_time(
|
||||
# 正常计算所有字符的输入时间
|
||||
total_time = 0.0
|
||||
for char in input_string:
|
||||
if "\u4e00" <= char <= "\u9fff": # 判断是否为中文字符
|
||||
total_time += chinese_time
|
||||
else: # 其他字符(如英文)
|
||||
total_time += english_time
|
||||
|
||||
total_time += chinese_time if "\u4e00" <= char <= "\u9fff" else english_time
|
||||
if is_emoji:
|
||||
total_time = 1
|
||||
|
||||
@@ -453,18 +416,14 @@ def cosine_similarity(v1, v2):
|
||||
dot_product = np.dot(v1, v2)
|
||||
norm1 = np.linalg.norm(v1)
|
||||
norm2 = np.linalg.norm(v2)
|
||||
if norm1 == 0 or norm2 == 0:
|
||||
return 0
|
||||
return dot_product / (norm1 * norm2)
|
||||
return 0 if norm1 == 0 or norm2 == 0 else dot_product / (norm1 * norm2)
|
||||
|
||||
|
||||
def text_to_vector(text):
|
||||
"""将文本转换为词频向量"""
|
||||
# 分词
|
||||
words = jieba.lcut(text)
|
||||
# 统计词频
|
||||
word_freq = Counter(words)
|
||||
return word_freq
|
||||
return Counter(words)
|
||||
|
||||
|
||||
def find_similar_topics_simple(text: str, topics: list, top_k: int = 5) -> list:
|
||||
@@ -491,9 +450,7 @@ def find_similar_topics_simple(text: str, topics: list, top_k: int = 5) -> list:
|
||||
|
||||
def truncate_message(message: str, max_length=20) -> str:
|
||||
"""截断消息,使其不超过指定长度"""
|
||||
if len(message) > max_length:
|
||||
return message[:max_length] + "..."
|
||||
return message
|
||||
return f"{message[:max_length]}..." if len(message) > max_length else message
|
||||
|
||||
|
||||
def protect_kaomoji(sentence):
|
||||
@@ -522,7 +479,7 @@ def protect_kaomoji(sentence):
|
||||
placeholder_to_kaomoji = {}
|
||||
|
||||
for idx, match in enumerate(kaomoji_matches):
|
||||
kaomoji = match[0] if match[0] else match[1]
|
||||
kaomoji = match[0] or match[1]
|
||||
placeholder = f"__KAOMOJI_{idx}__"
|
||||
sentence = sentence.replace(kaomoji, placeholder, 1)
|
||||
placeholder_to_kaomoji[placeholder] = kaomoji
|
||||
@@ -563,7 +520,7 @@ def get_western_ratio(paragraph):
|
||||
if not alnum_chars:
|
||||
return 0.0
|
||||
|
||||
western_count = sum(1 for char in alnum_chars if is_english_letter(char))
|
||||
western_count = sum(bool(is_english_letter(char)) for char in alnum_chars)
|
||||
return western_count / len(alnum_chars)
|
||||
|
||||
|
||||
@@ -610,6 +567,7 @@ def count_messages_between(start_time: float, end_time: float, stream_id: str) -
|
||||
|
||||
|
||||
def translate_timestamp_to_human_readable(timestamp: float, mode: str = "normal") -> str:
|
||||
# sourcery skip: merge-comparisons, merge-duplicate-blocks, switch
|
||||
"""将时间戳转换为人类可读的时间格式
|
||||
|
||||
Args:
|
||||
@@ -621,7 +579,7 @@ def translate_timestamp_to_human_readable(timestamp: float, mode: str = "normal"
|
||||
"""
|
||||
if mode == "normal":
|
||||
return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(timestamp))
|
||||
if mode == "normal_no_YMD":
|
||||
elif mode == "normal_no_YMD":
|
||||
return time.strftime("%H:%M:%S", time.localtime(timestamp))
|
||||
elif mode == "relative":
|
||||
now = time.time()
|
||||
@@ -640,7 +598,7 @@ def translate_timestamp_to_human_readable(timestamp: float, mode: str = "normal"
|
||||
else:
|
||||
return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(timestamp)) + ":"
|
||||
else: # mode = "lite" or unknown
|
||||
# 只返回时分秒格式,喵~
|
||||
# 只返回时分秒格式
|
||||
return time.strftime("%H:%M:%S", time.localtime(timestamp))
|
||||
|
||||
|
||||
@@ -670,8 +628,8 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]:
|
||||
elif chat_stream.user_info: # It's a private chat
|
||||
is_group_chat = False
|
||||
user_info = chat_stream.user_info
|
||||
platform = chat_stream.platform
|
||||
user_id = user_info.user_id
|
||||
platform: str = chat_stream.platform # type: ignore
|
||||
user_id: str = user_info.user_id # type: ignore
|
||||
|
||||
# Initialize target_info with basic info
|
||||
target_info = {
|
||||
@@ -709,3 +667,107 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]:
|
||||
# Keep defaults on error
|
||||
|
||||
return is_group_chat, chat_target_info
|
||||
|
||||
|
||||
def assign_message_ids(messages: List[Any]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
为消息列表中的每个消息分配唯一的简短随机ID
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
|
||||
Returns:
|
||||
包含 {'id': str, 'message': any} 格式的字典列表
|
||||
"""
|
||||
result = []
|
||||
used_ids = set()
|
||||
len_i = len(messages)
|
||||
if len_i > 100:
|
||||
a = 10
|
||||
b = 99
|
||||
else:
|
||||
a = 1
|
||||
b = 9
|
||||
|
||||
for i, message in enumerate(messages):
|
||||
# 生成唯一的简短ID
|
||||
while True:
|
||||
# 使用索引+随机数生成简短ID
|
||||
random_suffix = random.randint(a, b)
|
||||
message_id = f"m{i+1}{random_suffix}"
|
||||
|
||||
if message_id not in used_ids:
|
||||
used_ids.add(message_id)
|
||||
break
|
||||
|
||||
result.append({
|
||||
'id': message_id,
|
||||
'message': message
|
||||
})
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def assign_message_ids_flexible(
|
||||
messages: list,
|
||||
prefix: str = "msg",
|
||||
id_length: int = 6,
|
||||
use_timestamp: bool = False
|
||||
) -> list:
|
||||
"""
|
||||
为消息列表中的每个消息分配唯一的简短随机ID(增强版)
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
prefix: ID前缀,默认为"msg"
|
||||
id_length: ID的总长度(不包括前缀),默认为6
|
||||
use_timestamp: 是否在ID中包含时间戳,默认为False
|
||||
|
||||
Returns:
|
||||
包含 {'id': str, 'message': any} 格式的字典列表
|
||||
"""
|
||||
result = []
|
||||
used_ids = set()
|
||||
|
||||
for i, message in enumerate(messages):
|
||||
# 生成唯一的ID
|
||||
while True:
|
||||
if use_timestamp:
|
||||
# 使用时间戳的后几位 + 随机字符
|
||||
timestamp_suffix = str(int(time.time() * 1000))[-3:]
|
||||
remaining_length = id_length - 3
|
||||
random_chars = ''.join(random.choices(string.ascii_lowercase + string.digits, k=remaining_length))
|
||||
message_id = f"{prefix}{timestamp_suffix}{random_chars}"
|
||||
else:
|
||||
# 使用索引 + 随机字符
|
||||
index_str = str(i + 1)
|
||||
remaining_length = max(1, id_length - len(index_str))
|
||||
random_chars = ''.join(random.choices(string.ascii_lowercase + string.digits, k=remaining_length))
|
||||
message_id = f"{prefix}{index_str}{random_chars}"
|
||||
|
||||
if message_id not in used_ids:
|
||||
used_ids.add(message_id)
|
||||
break
|
||||
|
||||
result.append({
|
||||
'id': message_id,
|
||||
'message': message
|
||||
})
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# 使用示例:
|
||||
# messages = ["Hello", "World", "Test message"]
|
||||
#
|
||||
# # 基础版本
|
||||
# result1 = assign_message_ids(messages)
|
||||
# # 结果: [{'id': 'm1123', 'message': 'Hello'}, {'id': 'm2456', 'message': 'World'}, {'id': 'm3789', 'message': 'Test message'}]
|
||||
#
|
||||
# # 增强版本 - 自定义前缀和长度
|
||||
# result2 = assign_message_ids_flexible(messages, prefix="chat", id_length=8)
|
||||
# # 结果: [{'id': 'chat1abc2', 'message': 'Hello'}, {'id': 'chat2def3', 'message': 'World'}, {'id': 'chat3ghi4', 'message': 'Test message'}]
|
||||
#
|
||||
# # 增强版本 - 使用时间戳
|
||||
# result3 = assign_message_ids_flexible(messages, prefix="ts", use_timestamp=True)
|
||||
# # 结果: [{'id': 'ts123a1b', 'message': 'Hello'}, {'id': 'ts123c2d', 'message': 'World'}, {'id': 'ts123e3f', 'message': 'Test message'}]
|
||||
|
||||
@@ -3,21 +3,20 @@ import os
|
||||
import time
|
||||
import hashlib
|
||||
import uuid
|
||||
import io
|
||||
import asyncio
|
||||
import numpy as np
|
||||
|
||||
from typing import Optional, Tuple
|
||||
from PIL import Image
|
||||
import io
|
||||
import numpy as np
|
||||
import asyncio
|
||||
|
||||
from rich.traceback import install
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database import db
|
||||
from src.common.database.database_model import Images, ImageDescriptions
|
||||
from src.config.config import global_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from rich.traceback import install
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
logger = get_logger("chat_image")
|
||||
@@ -103,7 +102,7 @@ class ImageManager:
|
||||
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
|
||||
image_bytes = base64.b64decode(image_base64)
|
||||
image_hash = hashlib.md5(image_bytes).hexdigest()
|
||||
image_format = Image.open(io.BytesIO(image_bytes)).format.lower()
|
||||
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
|
||||
|
||||
# 查询缓存的描述
|
||||
cached_description = self._get_description_from_db(image_hash, "emoji")
|
||||
@@ -111,15 +110,15 @@ class ImageManager:
|
||||
return f"[表情包,含义看起来是:{cached_description}]"
|
||||
|
||||
# 调用AI获取描述
|
||||
if image_format == "gif" or image_format == "GIF":
|
||||
if image_format in ["gif", "GIF"]:
|
||||
image_base64_processed = self.transform_gif(image_base64)
|
||||
if image_base64_processed is None:
|
||||
logger.warning("GIF转换失败,无法获取描述")
|
||||
return "[表情包(GIF处理失败)]"
|
||||
prompt = "这是一个动态图表情包,每一张图代表了动态图的某一帧,黑色背景代表透明,使用1-2个词描述一下表情包表达的情感和内容,简短一些,输出一段平文本,不超过15个字"
|
||||
prompt = "这是一个动态图表情包,每一张图代表了动态图的某一帧,黑色背景代表透明,使用1-2个词描述一下表情包表达的情感和内容,简短一些,输出一段平文本,只输出1-2个词就好,不要输出其他内容"
|
||||
description, _ = await self._llm.generate_response_for_image(prompt, image_base64_processed, "jpg")
|
||||
else:
|
||||
prompt = "图片是一个表情包,请用使用1-2个词描述一下表情包所表达的情感和内容,简短一些,输出一段平文本,不超过15个字"
|
||||
prompt = "图片是一个表情包,请用使用1-2个词描述一下表情包所表达的情感和内容,简短一些,输出一段平文本,只输出1-2个词就好,不要输出其他内容"
|
||||
description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format)
|
||||
|
||||
if description is None:
|
||||
@@ -154,7 +153,7 @@ class ImageManager:
|
||||
img_obj.description = description
|
||||
img_obj.timestamp = current_timestamp
|
||||
img_obj.save()
|
||||
except Images.DoesNotExist:
|
||||
except Images.DoesNotExist: # type: ignore
|
||||
Images.create(
|
||||
emoji_hash=image_hash,
|
||||
path=file_path,
|
||||
@@ -204,8 +203,8 @@ class ImageManager:
|
||||
return f"[图片:{cached_description}]"
|
||||
|
||||
# 调用AI获取描述
|
||||
image_format = Image.open(io.BytesIO(image_bytes)).format.lower()
|
||||
prompt = "请用中文描述这张图片的内容。如果有文字,请把文字都描述出来,请留意其主题,直观感受,输出为一段平文本,最多50字"
|
||||
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
|
||||
prompt = global_config.custom_prompt.image_prompt
|
||||
description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format)
|
||||
|
||||
if description is None:
|
||||
@@ -258,6 +257,7 @@ class ImageManager:
|
||||
|
||||
@staticmethod
|
||||
def transform_gif(gif_base64: str, similarity_threshold: float = 1000.0, max_frames: int = 15) -> Optional[str]:
|
||||
# sourcery skip: use-contextlib-suppress
|
||||
"""将GIF转换为水平拼接的静态图像, 跳过相似的帧
|
||||
|
||||
Args:
|
||||
@@ -351,7 +351,7 @@ class ImageManager:
|
||||
# 创建拼接图像
|
||||
total_width = target_width * len(resized_frames)
|
||||
# 防止总宽度为0
|
||||
if total_width == 0 and len(resized_frames) > 0:
|
||||
if total_width == 0 and resized_frames:
|
||||
logger.warning("计算出的总宽度为0,但有选中帧,可能目标宽度太小")
|
||||
# 至少给点宽度吧
|
||||
total_width = len(resized_frames)
|
||||
@@ -368,10 +368,7 @@ class ImageManager:
|
||||
# 转换为base64
|
||||
buffer = io.BytesIO()
|
||||
combined_image.save(buffer, format="JPEG", quality=85) # 保存为JPEG
|
||||
result_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||
|
||||
return result_base64
|
||||
|
||||
return base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||
except MemoryError:
|
||||
logger.error("GIF转换失败: 内存不足,可能是GIF太大或帧数太多")
|
||||
return None # 内存不够啦
|
||||
@@ -380,6 +377,7 @@ class ImageManager:
|
||||
return None # 其他错误也返回None
|
||||
|
||||
async def process_image(self, image_base64: str) -> Tuple[str, str]:
|
||||
# sourcery skip: hoist-if-from-if
|
||||
"""处理图片并返回图片ID和描述
|
||||
|
||||
Args:
|
||||
@@ -418,17 +416,9 @@ class ImageManager:
|
||||
if existing_image.vlm_processed is None:
|
||||
existing_image.vlm_processed = False
|
||||
|
||||
existing_image.count += 1
|
||||
existing_image.save()
|
||||
return existing_image.image_id, f"[picid:{existing_image.image_id}]"
|
||||
else:
|
||||
# print(f"图片已存在: {existing_image.image_id}")
|
||||
# print(f"图片描述: {existing_image.description}")
|
||||
# print(f"图片计数: {existing_image.count}")
|
||||
# 更新计数
|
||||
existing_image.count += 1
|
||||
existing_image.save()
|
||||
return existing_image.image_id, f"[picid:{existing_image.image_id}]"
|
||||
existing_image.count += 1
|
||||
existing_image.save()
|
||||
return existing_image.image_id, f"[picid:{existing_image.image_id}]"
|
||||
else:
|
||||
# print(f"图片不存在: {image_hash}")
|
||||
image_id = str(uuid.uuid4())
|
||||
@@ -491,10 +481,10 @@ class ImageManager:
|
||||
return
|
||||
|
||||
# 获取图片格式
|
||||
image_format = Image.open(io.BytesIO(image_bytes)).format.lower()
|
||||
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
|
||||
|
||||
# 构建prompt
|
||||
prompt = """请用中文描述这张图片的内容。如果有文字,请把文字描述概括出来,请留意其主题,直观感受,输出为一段平文本,最多30字,请注意不要分点,就输出一段文本"""
|
||||
prompt = global_config.custom_prompt.image_prompt
|
||||
|
||||
# 获取VLM描述
|
||||
description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format)
|
||||
|
||||
@@ -24,19 +24,25 @@ class ClassicalWillingManager(BaseWillingManager):
|
||||
willing_info = self.ongoing_messages[message_id]
|
||||
chat_id = willing_info.chat_id
|
||||
current_willing = self.chat_reply_willing.get(chat_id, 0)
|
||||
|
||||
# print(f"[{chat_id}] 回复意愿: {current_willing}")
|
||||
|
||||
interested_rate = willing_info.interested_rate * global_config.normal_chat.response_interested_rate_amplifier
|
||||
|
||||
# print(f"[{chat_id}] 兴趣值: {interested_rate}")
|
||||
|
||||
if interested_rate > 0.4:
|
||||
current_willing += interested_rate - 0.3
|
||||
if interested_rate > 0.2:
|
||||
current_willing += interested_rate - 0.2
|
||||
|
||||
if willing_info.is_mentioned_bot:
|
||||
if willing_info.is_mentioned_bot and global_config.normal_chat.mentioned_bot_inevitable_reply and current_willing < 2:
|
||||
current_willing += 1 if current_willing < 1.0 else 0.05
|
||||
|
||||
self.chat_reply_willing[chat_id] = min(current_willing, 3.0)
|
||||
|
||||
self.chat_reply_willing[chat_id] = min(current_willing, 1.0)
|
||||
|
||||
reply_probability = min(max((current_willing - 0.5), 0.01) * 2, 1)
|
||||
|
||||
|
||||
# print(f"[{chat_id}] 回复概率: {reply_probability}")
|
||||
|
||||
return reply_probability
|
||||
|
||||
async def before_generate_reply_handle(self, message_id):
|
||||
@@ -1,14 +1,15 @@
|
||||
from src.common.logger import get_logger
|
||||
import importlib
|
||||
import asyncio
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Optional, Any
|
||||
from rich.traceback import install
|
||||
from dataclasses import dataclass
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.chat.message_receive.chat_stream import ChatStream, GroupInfo
|
||||
from src.chat.message_receive.message import MessageRecv
|
||||
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
|
||||
from abc import ABC, abstractmethod
|
||||
import importlib
|
||||
from typing import Dict, Optional
|
||||
import asyncio
|
||||
from rich.traceback import install
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
@@ -52,7 +53,7 @@ class WillingInfo:
|
||||
interested_rate (float): 兴趣度
|
||||
"""
|
||||
|
||||
message: MessageRecv
|
||||
message: Dict[str, Any] # 原始消息数据
|
||||
chat: ChatStream
|
||||
person_info_manager: PersonInfoManager
|
||||
chat_id: str
|
||||
@@ -91,19 +92,19 @@ class BaseWillingManager(ABC):
|
||||
self.lock = asyncio.Lock()
|
||||
self.logger = logger
|
||||
|
||||
def setup(self, message: MessageRecv, chat: ChatStream, is_mentioned_bot: bool, interested_rate: float):
|
||||
person_id = PersonInfoManager.get_person_id(chat.platform, chat.user_info.user_id)
|
||||
self.ongoing_messages[message.message_info.message_id] = WillingInfo(
|
||||
def setup(self, message: dict, chat: ChatStream):
|
||||
person_id = PersonInfoManager.get_person_id(chat.platform, chat.user_info.user_id) # type: ignore
|
||||
self.ongoing_messages[message.get("message_id", "")] = WillingInfo( # type: ignore
|
||||
message=message,
|
||||
chat=chat,
|
||||
person_info_manager=get_person_info_manager(),
|
||||
chat_id=chat.stream_id,
|
||||
person_id=person_id,
|
||||
group_info=chat.group_info,
|
||||
is_mentioned_bot=is_mentioned_bot,
|
||||
is_emoji=message.is_emoji,
|
||||
is_picid=message.is_picid,
|
||||
interested_rate=interested_rate,
|
||||
is_mentioned_bot=message.get("is_mentioned", False),
|
||||
is_emoji=message.get("is_emoji", False),
|
||||
is_picid=message.get("is_picid", False),
|
||||
interested_rate=message.get("interest_value", 0),
|
||||
)
|
||||
|
||||
def delete(self, message_id: str):
|
||||
@@ -1,84 +0,0 @@
|
||||
from typing import Tuple
|
||||
import time
|
||||
import random
|
||||
import string
|
||||
|
||||
|
||||
class MemoryItem:
|
||||
"""记忆项类,用于存储单个记忆的所有相关信息"""
|
||||
|
||||
def __init__(self, summary: str, from_source: str = "", brief: str = ""):
|
||||
"""
|
||||
初始化记忆项
|
||||
|
||||
Args:
|
||||
summary: 记忆内容概括
|
||||
from_source: 数据来源
|
||||
brief: 记忆内容主题
|
||||
"""
|
||||
# 生成可读ID:时间戳_随机字符串
|
||||
timestamp = int(time.time())
|
||||
random_str = "".join(random.choices(string.ascii_lowercase + string.digits, k=2))
|
||||
self.id = f"{timestamp}_{random_str}"
|
||||
self.from_source = from_source
|
||||
self.brief = brief
|
||||
self.timestamp = time.time()
|
||||
|
||||
# 记忆内容概括
|
||||
self.summary = summary
|
||||
|
||||
# 记忆精简次数
|
||||
self.compress_count = 0
|
||||
|
||||
# 记忆提取次数
|
||||
self.retrieval_count = 0
|
||||
|
||||
# 记忆强度 (初始为10)
|
||||
self.memory_strength = 10.0
|
||||
|
||||
# 记忆操作历史记录
|
||||
# 格式: [(操作类型, 时间戳, 当时精简次数, 当时强度), ...]
|
||||
self.history = [("create", self.timestamp, self.compress_count, self.memory_strength)]
|
||||
|
||||
def matches_source(self, source: str) -> bool:
|
||||
"""检查来源是否匹配"""
|
||||
return self.from_source == source
|
||||
|
||||
def increase_strength(self, amount: float) -> None:
|
||||
"""增加记忆强度"""
|
||||
self.memory_strength = min(10.0, self.memory_strength + amount)
|
||||
# 记录操作历史
|
||||
self.record_operation("strengthen")
|
||||
|
||||
def decrease_strength(self, amount: float) -> None:
|
||||
"""减少记忆强度"""
|
||||
self.memory_strength = max(0.1, self.memory_strength - amount)
|
||||
# 记录操作历史
|
||||
self.record_operation("weaken")
|
||||
|
||||
def increase_compress_count(self) -> None:
|
||||
"""增加精简次数并减弱记忆强度"""
|
||||
self.compress_count += 1
|
||||
# 记录操作历史
|
||||
self.record_operation("compress")
|
||||
|
||||
def record_retrieval(self) -> None:
|
||||
"""记录记忆被提取的情况"""
|
||||
self.retrieval_count += 1
|
||||
# 提取后强度翻倍
|
||||
self.memory_strength = min(10.0, self.memory_strength * 2)
|
||||
# 记录操作历史
|
||||
self.record_operation("retrieval")
|
||||
|
||||
def record_operation(self, operation_type: str) -> None:
|
||||
"""记录操作历史"""
|
||||
current_time = time.time()
|
||||
self.history.append((operation_type, current_time, self.compress_count, self.memory_strength))
|
||||
|
||||
def to_tuple(self) -> Tuple[str, str, float, str]:
|
||||
"""转换为元组格式(为了兼容性)"""
|
||||
return (self.summary, self.from_source, self.timestamp, self.id)
|
||||
|
||||
def is_memory_valid(self) -> bool:
|
||||
"""检查记忆是否有效(强度是否大于等于1)"""
|
||||
return self.memory_strength >= 1.0
|
||||
@@ -1,413 +0,0 @@
|
||||
from typing import Dict, TypeVar, List, Optional
|
||||
import traceback
|
||||
from json_repair import repair_json
|
||||
from rich.traceback import install
|
||||
from src.common.logger import get_logger
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config
|
||||
from src.chat.focus_chat.working_memory.memory_item import MemoryItem
|
||||
import json # 添加json模块导入
|
||||
|
||||
|
||||
install(extra_lines=3)
|
||||
logger = get_logger("working_memory")
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class MemoryManager:
|
||||
def __init__(self, chat_id: str):
|
||||
"""
|
||||
初始化工作记忆
|
||||
|
||||
Args:
|
||||
chat_id: 关联的聊天ID,用于标识该工作记忆属于哪个聊天
|
||||
"""
|
||||
# 关联的聊天ID
|
||||
self._chat_id = chat_id
|
||||
|
||||
# 记忆项列表
|
||||
self._memories: List[MemoryItem] = []
|
||||
|
||||
# ID到记忆项的映射
|
||||
self._id_map: Dict[str, MemoryItem] = {}
|
||||
|
||||
self.llm_summarizer = LLMRequest(
|
||||
model=global_config.model.memory,
|
||||
temperature=0.3,
|
||||
request_type="working_memory",
|
||||
)
|
||||
|
||||
@property
|
||||
def chat_id(self) -> str:
|
||||
"""获取关联的聊天ID"""
|
||||
return self._chat_id
|
||||
|
||||
@chat_id.setter
|
||||
def chat_id(self, value: str):
|
||||
"""设置关联的聊天ID"""
|
||||
self._chat_id = value
|
||||
|
||||
def push_item(self, memory_item: MemoryItem) -> str:
|
||||
"""
|
||||
推送一个已创建的记忆项到工作记忆中
|
||||
|
||||
Args:
|
||||
memory_item: 要存储的记忆项
|
||||
|
||||
Returns:
|
||||
记忆项的ID
|
||||
"""
|
||||
# 添加到内存和ID映射
|
||||
self._memories.append(memory_item)
|
||||
self._id_map[memory_item.id] = memory_item
|
||||
|
||||
return memory_item.id
|
||||
|
||||
def get_by_id(self, memory_id: str) -> Optional[MemoryItem]:
|
||||
"""
|
||||
通过ID获取记忆项
|
||||
|
||||
Args:
|
||||
memory_id: 记忆项ID
|
||||
|
||||
Returns:
|
||||
找到的记忆项,如果不存在则返回None
|
||||
"""
|
||||
memory_item = self._id_map.get(memory_id)
|
||||
if memory_item:
|
||||
# 检查记忆强度,如果小于1则删除
|
||||
if not memory_item.is_memory_valid():
|
||||
print(f"记忆 {memory_id} 强度过低 ({memory_item.memory_strength}),已自动移除")
|
||||
self.delete(memory_id)
|
||||
return None
|
||||
|
||||
return memory_item
|
||||
|
||||
def get_all_items(self) -> List[MemoryItem]:
|
||||
"""获取所有记忆项"""
|
||||
return list(self._id_map.values())
|
||||
|
||||
def find_items(
|
||||
self,
|
||||
source: Optional[str] = None,
|
||||
start_time: Optional[float] = None,
|
||||
end_time: Optional[float] = None,
|
||||
memory_id: Optional[str] = None,
|
||||
limit: Optional[int] = None,
|
||||
newest_first: bool = False,
|
||||
min_strength: float = 0.0,
|
||||
) -> List[MemoryItem]:
|
||||
"""
|
||||
按条件查找记忆项
|
||||
|
||||
Args:
|
||||
source: 数据来源
|
||||
start_time: 开始时间戳
|
||||
end_time: 结束时间戳
|
||||
memory_id: 特定记忆项ID
|
||||
limit: 返回结果的最大数量
|
||||
newest_first: 是否按最新优先排序
|
||||
min_strength: 最小记忆强度
|
||||
|
||||
Returns:
|
||||
符合条件的记忆项列表
|
||||
"""
|
||||
# 如果提供了特定ID,直接查找
|
||||
if memory_id:
|
||||
item = self.get_by_id(memory_id)
|
||||
return [item] if item else []
|
||||
|
||||
results = []
|
||||
|
||||
# 获取所有项目
|
||||
items = self._memories
|
||||
|
||||
# 如果需要最新优先,则反转遍历顺序
|
||||
if newest_first:
|
||||
items_to_check = list(reversed(items))
|
||||
else:
|
||||
items_to_check = items
|
||||
|
||||
# 遍历项目
|
||||
for item in items_to_check:
|
||||
# 检查来源是否匹配
|
||||
if source is not None and not item.matches_source(source):
|
||||
continue
|
||||
|
||||
# 检查时间范围
|
||||
if start_time is not None and item.timestamp < start_time:
|
||||
continue
|
||||
if end_time is not None and item.timestamp > end_time:
|
||||
continue
|
||||
|
||||
# 检查记忆强度
|
||||
if min_strength > 0 and item.memory_strength < min_strength:
|
||||
continue
|
||||
|
||||
# 所有条件都满足,添加到结果中
|
||||
results.append(item)
|
||||
|
||||
# 如果达到限制数量,提前返回
|
||||
if limit is not None and len(results) >= limit:
|
||||
return results
|
||||
|
||||
return results
|
||||
|
||||
async def summarize_memory_item(self, content: str) -> Dict[str, str]:
|
||||
"""
|
||||
使用LLM总结记忆项
|
||||
|
||||
Args:
|
||||
content: 需要总结的内容
|
||||
|
||||
Returns:
|
||||
包含brief和summary的字典
|
||||
"""
|
||||
prompt = f"""请对以下内容进行总结,总结成记忆,输出两部分:
|
||||
1. 记忆内容主题(精简,20字以内):让用户可以一眼看出记忆内容是什么
|
||||
2. 记忆内容概括:对内容进行概括,保留重要信息,200字以内
|
||||
|
||||
内容:
|
||||
{content}
|
||||
|
||||
请按以下JSON格式输出:
|
||||
{{
|
||||
"brief": "记忆内容主题",
|
||||
"summary": "记忆内容概括"
|
||||
}}
|
||||
请确保输出是有效的JSON格式,不要添加任何额外的说明或解释。
|
||||
"""
|
||||
default_summary = {
|
||||
"brief": "主题未知的记忆",
|
||||
"summary": "无法概括的记忆内容",
|
||||
}
|
||||
|
||||
try:
|
||||
# 调用LLM生成总结
|
||||
response, _ = await self.llm_summarizer.generate_response_async(prompt)
|
||||
|
||||
# 使用repair_json解析响应
|
||||
try:
|
||||
# 使用repair_json修复JSON格式
|
||||
fixed_json_string = repair_json(response)
|
||||
|
||||
# 如果repair_json返回的是字符串,需要解析为Python对象
|
||||
if isinstance(fixed_json_string, str):
|
||||
try:
|
||||
json_result = json.loads(fixed_json_string)
|
||||
except json.JSONDecodeError as decode_error:
|
||||
logger.error(f"JSON解析错误: {str(decode_error)}")
|
||||
return default_summary
|
||||
else:
|
||||
# 如果repair_json直接返回了字典对象,直接使用
|
||||
json_result = fixed_json_string
|
||||
|
||||
# 进行额外的类型检查
|
||||
if not isinstance(json_result, dict):
|
||||
logger.error(f"修复后的JSON不是字典类型: {type(json_result)}")
|
||||
return default_summary
|
||||
|
||||
# 确保所有必要字段都存在且类型正确
|
||||
if "brief" not in json_result or not isinstance(json_result["brief"], str):
|
||||
json_result["brief"] = "主题未知的记忆"
|
||||
|
||||
if "summary" not in json_result or not isinstance(json_result["summary"], str):
|
||||
json_result["summary"] = "无法概括的记忆内容"
|
||||
|
||||
return json_result
|
||||
|
||||
except Exception as json_error:
|
||||
logger.error(f"JSON处理失败: {str(json_error)},将使用默认摘要")
|
||||
return default_summary
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"生成总结时出错: {str(e)}")
|
||||
return default_summary
|
||||
|
||||
def decay_memory(self, memory_id: str, decay_factor: float = 0.8) -> bool:
|
||||
"""
|
||||
使单个记忆衰减
|
||||
|
||||
Args:
|
||||
memory_id: 记忆ID
|
||||
decay_factor: 衰减因子(0-1之间)
|
||||
|
||||
Returns:
|
||||
是否成功衰减
|
||||
"""
|
||||
memory_item = self.get_by_id(memory_id)
|
||||
if not memory_item:
|
||||
return False
|
||||
|
||||
# 计算衰减量(当前强度 * (1-衰减因子))
|
||||
old_strength = memory_item.memory_strength
|
||||
decay_amount = old_strength * (1 - decay_factor)
|
||||
|
||||
# 更新强度
|
||||
memory_item.memory_strength = decay_amount
|
||||
|
||||
return True
|
||||
|
||||
def delete(self, memory_id: str) -> bool:
|
||||
"""
|
||||
删除指定ID的记忆项
|
||||
|
||||
Args:
|
||||
memory_id: 要删除的记忆项ID
|
||||
|
||||
Returns:
|
||||
是否成功删除
|
||||
"""
|
||||
if memory_id not in self._id_map:
|
||||
return False
|
||||
|
||||
# 获取要删除的项
|
||||
self._id_map[memory_id]
|
||||
|
||||
# 从内存中删除
|
||||
self._memories = [i for i in self._memories if i.id != memory_id]
|
||||
|
||||
# 从ID映射中删除
|
||||
del self._id_map[memory_id]
|
||||
|
||||
return True
|
||||
|
||||
def clear(self) -> None:
|
||||
"""清除所有记忆"""
|
||||
self._memories.clear()
|
||||
self._id_map.clear()
|
||||
|
||||
async def merge_memories(
|
||||
self, memory_id1: str, memory_id2: str, reason: str, delete_originals: bool = True
|
||||
) -> MemoryItem:
|
||||
"""
|
||||
合并两个记忆项
|
||||
|
||||
Args:
|
||||
memory_id1: 第一个记忆项ID
|
||||
memory_id2: 第二个记忆项ID
|
||||
reason: 合并原因
|
||||
delete_originals: 是否删除原始记忆,默认为True
|
||||
|
||||
Returns:
|
||||
合并后的记忆项
|
||||
"""
|
||||
# 获取两个记忆项
|
||||
memory_item1 = self.get_by_id(memory_id1)
|
||||
memory_item2 = self.get_by_id(memory_id2)
|
||||
|
||||
if not memory_item1 or not memory_item2:
|
||||
raise ValueError("无法找到指定的记忆项")
|
||||
|
||||
# 构建合并提示
|
||||
prompt = f"""
|
||||
请根据以下原因,将两段记忆内容有机合并成一段新的记忆内容。
|
||||
合并时保留两段记忆的重要信息,避免重复,确保生成的内容连贯、自然。
|
||||
|
||||
合并原因:{reason}
|
||||
|
||||
记忆1主题:{memory_item1.brief}
|
||||
记忆1内容:{memory_item1.summary}
|
||||
|
||||
记忆2主题:{memory_item2.brief}
|
||||
记忆2内容:{memory_item2.summary}
|
||||
|
||||
请按以下JSON格式输出合并结果:
|
||||
{{
|
||||
"brief": "合并后的主题(20字以内)",
|
||||
"summary": "合并后的内容概括(200字以内)"
|
||||
}}
|
||||
请确保输出是有效的JSON格式,不要添加任何额外的说明或解释。
|
||||
"""
|
||||
|
||||
# 默认合并结果
|
||||
default_merged = {
|
||||
"brief": f"合并:{memory_item1.brief} + {memory_item2.brief}",
|
||||
"summary": f"合并的记忆:{memory_item1.summary}\n{memory_item2.summary}",
|
||||
}
|
||||
|
||||
try:
|
||||
# 调用LLM合并记忆
|
||||
response, _ = await self.llm_summarizer.generate_response_async(prompt)
|
||||
|
||||
# 处理LLM返回的合并结果
|
||||
try:
|
||||
# 修复JSON格式
|
||||
fixed_json_string = repair_json(response)
|
||||
|
||||
# 将修复后的字符串解析为Python对象
|
||||
if isinstance(fixed_json_string, str):
|
||||
try:
|
||||
merged_data = json.loads(fixed_json_string)
|
||||
except json.JSONDecodeError as decode_error:
|
||||
logger.error(f"JSON解析错误: {str(decode_error)}")
|
||||
merged_data = default_merged
|
||||
else:
|
||||
# 如果repair_json直接返回了字典对象,直接使用
|
||||
merged_data = fixed_json_string
|
||||
|
||||
# 确保是字典类型
|
||||
if not isinstance(merged_data, dict):
|
||||
logger.error(f"修复后的JSON不是字典类型: {type(merged_data)}")
|
||||
merged_data = default_merged
|
||||
|
||||
if "brief" not in merged_data or not isinstance(merged_data["brief"], str):
|
||||
merged_data["brief"] = default_merged["brief"]
|
||||
|
||||
if "summary" not in merged_data or not isinstance(merged_data["summary"], str):
|
||||
merged_data["summary"] = default_merged["summary"]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"合并记忆时处理JSON出错: {str(e)}")
|
||||
traceback.print_exc()
|
||||
merged_data = default_merged
|
||||
except Exception as e:
|
||||
logger.error(f"合并记忆调用LLM出错: {str(e)}")
|
||||
traceback.print_exc()
|
||||
merged_data = default_merged
|
||||
|
||||
# 创建新的记忆项
|
||||
# 取两个记忆项中更强的来源
|
||||
merged_source = (
|
||||
memory_item1.from_source
|
||||
if memory_item1.memory_strength >= memory_item2.memory_strength
|
||||
else memory_item2.from_source
|
||||
)
|
||||
|
||||
# 创建新的记忆项
|
||||
merged_memory = MemoryItem(
|
||||
summary=merged_data["summary"], from_source=merged_source, brief=merged_data["brief"]
|
||||
)
|
||||
|
||||
# 记忆强度取两者最大值
|
||||
merged_memory.memory_strength = max(memory_item1.memory_strength, memory_item2.memory_strength)
|
||||
|
||||
# 添加到存储中
|
||||
self.push_item(merged_memory)
|
||||
|
||||
# 如果需要,删除原始记忆
|
||||
if delete_originals:
|
||||
self.delete(memory_id1)
|
||||
self.delete(memory_id2)
|
||||
|
||||
return merged_memory
|
||||
|
||||
def delete_earliest_memory(self) -> bool:
|
||||
"""
|
||||
删除最早的记忆项
|
||||
|
||||
Returns:
|
||||
是否成功删除
|
||||
"""
|
||||
# 获取所有记忆项
|
||||
all_memories = self.get_all_items()
|
||||
|
||||
if not all_memories:
|
||||
return False
|
||||
|
||||
# 按时间戳排序,找到最早的记忆项
|
||||
earliest_memory = min(all_memories, key=lambda item: item.timestamp)
|
||||
|
||||
# 删除最早的记忆项
|
||||
return self.delete(earliest_memory.id)
|
||||
@@ -1,156 +0,0 @@
|
||||
from typing import List, Any, Optional
|
||||
import asyncio
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.focus_chat.working_memory.memory_manager import MemoryManager, MemoryItem
|
||||
from src.config.config import global_config
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# 问题是我不知道这个manager是不是需要和其他manager统一管理,因为这个manager是从属于每一个聊天流,都有自己的定时任务
|
||||
|
||||
|
||||
class WorkingMemory:
|
||||
"""
|
||||
工作记忆,负责协调和运作记忆
|
||||
从属于特定的流,用chat_id来标识
|
||||
"""
|
||||
|
||||
def __init__(self, chat_id: str, max_memories_per_chat: int = 10, auto_decay_interval: int = 60):
|
||||
"""
|
||||
初始化工作记忆管理器
|
||||
|
||||
Args:
|
||||
max_memories_per_chat: 每个聊天的最大记忆数量
|
||||
auto_decay_interval: 自动衰减记忆的时间间隔(秒)
|
||||
"""
|
||||
self.memory_manager = MemoryManager(chat_id)
|
||||
|
||||
# 记忆容量上限
|
||||
self.max_memories_per_chat = max_memories_per_chat
|
||||
|
||||
# 自动衰减间隔
|
||||
self.auto_decay_interval = auto_decay_interval
|
||||
|
||||
# 衰减任务
|
||||
self.decay_task = None
|
||||
|
||||
# 只有在工作记忆处理器启用时才启动自动衰减任务
|
||||
if global_config.focus_chat_processor.working_memory_processor:
|
||||
self._start_auto_decay()
|
||||
else:
|
||||
logger.debug(f"工作记忆处理器已禁用,跳过启动自动衰减任务 (chat_id: {chat_id})")
|
||||
|
||||
def _start_auto_decay(self):
|
||||
"""启动自动衰减任务"""
|
||||
if self.decay_task is None:
|
||||
self.decay_task = asyncio.create_task(self._auto_decay_loop())
|
||||
|
||||
async def _auto_decay_loop(self):
|
||||
"""自动衰减循环"""
|
||||
while True:
|
||||
await asyncio.sleep(self.auto_decay_interval)
|
||||
try:
|
||||
await self.decay_all_memories()
|
||||
except Exception as e:
|
||||
print(f"自动衰减记忆时出错: {str(e)}")
|
||||
|
||||
async def add_memory(self, summary: Any, from_source: str = "", brief: str = ""):
|
||||
"""
|
||||
添加一段记忆到指定聊天
|
||||
|
||||
Args:
|
||||
summary: 记忆内容
|
||||
from_source: 数据来源
|
||||
|
||||
Returns:
|
||||
记忆项
|
||||
"""
|
||||
# 如果是字符串类型,生成总结
|
||||
|
||||
memory = MemoryItem(summary, from_source, brief)
|
||||
|
||||
# 添加到管理器
|
||||
self.memory_manager.push_item(memory)
|
||||
|
||||
# 如果超过最大记忆数量,删除最早的记忆
|
||||
if len(self.memory_manager.get_all_items()) > self.max_memories_per_chat:
|
||||
self.remove_earliest_memory()
|
||||
|
||||
return memory
|
||||
|
||||
def remove_earliest_memory(self):
|
||||
"""
|
||||
删除最早的记忆
|
||||
"""
|
||||
return self.memory_manager.delete_earliest_memory()
|
||||
|
||||
async def retrieve_memory(self, memory_id: str) -> Optional[MemoryItem]:
|
||||
"""
|
||||
检索记忆
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
memory_id: 记忆ID
|
||||
|
||||
Returns:
|
||||
检索到的记忆项,如果不存在则返回None
|
||||
"""
|
||||
memory_item = self.memory_manager.get_by_id(memory_id)
|
||||
if memory_item:
|
||||
memory_item.retrieval_count += 1
|
||||
memory_item.increase_strength(5)
|
||||
return memory_item
|
||||
return None
|
||||
|
||||
async def decay_all_memories(self, decay_factor: float = 0.5):
|
||||
"""
|
||||
对所有聊天的所有记忆进行衰减
|
||||
衰减:对记忆进行refine压缩,强度会变为原先的0.5
|
||||
|
||||
Args:
|
||||
decay_factor: 衰减因子(0-1之间)
|
||||
"""
|
||||
logger.debug(f"开始对所有记忆进行衰减,衰减因子: {decay_factor}")
|
||||
|
||||
all_memories = self.memory_manager.get_all_items()
|
||||
|
||||
for memory_item in all_memories:
|
||||
# 如果压缩完小于1会被删除
|
||||
memory_id = memory_item.id
|
||||
self.memory_manager.decay_memory(memory_id, decay_factor)
|
||||
if memory_item.memory_strength < 1:
|
||||
self.memory_manager.delete(memory_id)
|
||||
continue
|
||||
# 计算衰减量
|
||||
# if memory_item.memory_strength < 5:
|
||||
# await self.memory_manager.refine_memory(
|
||||
# memory_id, f"由于时间过去了{self.auto_decay_interval}秒,记忆变的模糊,所以需要压缩"
|
||||
# )
|
||||
|
||||
async def merge_memory(self, memory_id1: str, memory_id2: str) -> MemoryItem:
|
||||
"""合并记忆
|
||||
|
||||
Args:
|
||||
memory_str: 记忆内容
|
||||
"""
|
||||
return await self.memory_manager.merge_memories(
|
||||
memory_id1=memory_id1, memory_id2=memory_id2, reason="两端记忆有重复的内容"
|
||||
)
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
"""关闭管理器,停止所有任务"""
|
||||
if self.decay_task and not self.decay_task.done():
|
||||
self.decay_task.cancel()
|
||||
try:
|
||||
await self.decay_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
def get_all_memories(self) -> List[MemoryItem]:
|
||||
"""
|
||||
获取所有记忆项目
|
||||
|
||||
Returns:
|
||||
List[MemoryItem]: 当前工作记忆中的所有记忆项目列表
|
||||
"""
|
||||
return self.memory_manager.get_all_items()
|
||||
@@ -1,261 +0,0 @@
|
||||
from src.chat.focus_chat.observation.chatting_observation import ChattingObservation
|
||||
from src.chat.focus_chat.observation.observation import Observation
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config
|
||||
import time
|
||||
import traceback
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from typing import List
|
||||
from src.chat.focus_chat.observation.working_observation import WorkingMemoryObservation
|
||||
from src.chat.focus_chat.working_memory.working_memory import WorkingMemory
|
||||
from src.chat.focus_chat.info.info_base import InfoBase
|
||||
from json_repair import repair_json
|
||||
from src.chat.focus_chat.info.workingmemory_info import WorkingMemoryInfo
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
logger = get_logger("processor")
|
||||
|
||||
|
||||
def init_prompt():
|
||||
memory_proces_prompt = """
|
||||
你的名字是{bot_name}
|
||||
|
||||
现在是{time_now},你正在上网,和qq群里的网友们聊天,以下是正在进行的聊天内容:
|
||||
{chat_observe_info}
|
||||
|
||||
以下是你已经总结的记忆摘要,你可以调取这些记忆查看内容来帮助你聊天,不要一次调取太多记忆,最多调取3个左右记忆:
|
||||
{memory_str}
|
||||
|
||||
观察聊天内容和已经总结的记忆,思考如果有相近的记忆,请合并记忆,输出merge_memory,
|
||||
合并记忆的格式为[["id1", "id2"], ["id3", "id4"],...],你可以进行多组合并,但是每组合并只能有两个记忆id,不要输出其他内容
|
||||
|
||||
请根据聊天内容选择你需要调取的记忆并考虑是否添加新记忆,以JSON格式输出,格式如下:
|
||||
```json
|
||||
{{
|
||||
"selected_memory_ids": ["id1", "id2", ...]
|
||||
"merge_memory": [["id1", "id2"], ["id3", "id4"],...]
|
||||
}}
|
||||
```
|
||||
"""
|
||||
Prompt(memory_proces_prompt, "prompt_memory_proces")
|
||||
|
||||
|
||||
class WorkingMemoryProcessor:
|
||||
log_prefix = "工作记忆"
|
||||
|
||||
def __init__(self, subheartflow_id: str):
|
||||
self.subheartflow_id = subheartflow_id
|
||||
|
||||
self.llm_model = LLMRequest(
|
||||
model=global_config.model.planner,
|
||||
request_type="focus.processor.working_memory",
|
||||
)
|
||||
|
||||
name = get_chat_manager().get_stream_name(self.subheartflow_id)
|
||||
self.log_prefix = f"[{name}] "
|
||||
|
||||
async def process_info(self, observations: List[Observation] = None, *infos) -> List[InfoBase]:
|
||||
"""处理信息对象
|
||||
|
||||
Args:
|
||||
*infos: 可变数量的InfoBase类型的信息对象
|
||||
|
||||
Returns:
|
||||
List[InfoBase]: 处理后的结构化信息列表
|
||||
"""
|
||||
working_memory = None
|
||||
chat_info = ""
|
||||
chat_obs = None
|
||||
try:
|
||||
for observation in observations:
|
||||
if isinstance(observation, WorkingMemoryObservation):
|
||||
working_memory = observation.get_observe_info()
|
||||
if isinstance(observation, ChattingObservation):
|
||||
chat_info = observation.get_observe_info()
|
||||
chat_obs = observation
|
||||
# 检查是否有待压缩内容
|
||||
if chat_obs and chat_obs.compressor_prompt:
|
||||
logger.debug(f"{self.log_prefix} 压缩聊天记忆")
|
||||
await self.compress_chat_memory(working_memory, chat_obs)
|
||||
|
||||
# 检查working_memory是否为None
|
||||
if working_memory is None:
|
||||
logger.debug(f"{self.log_prefix} 没有找到工作记忆观察,跳过处理")
|
||||
return []
|
||||
|
||||
all_memory = working_memory.get_all_memories()
|
||||
if not all_memory:
|
||||
logger.debug(f"{self.log_prefix} 目前没有工作记忆,跳过提取")
|
||||
return []
|
||||
|
||||
memory_prompts = []
|
||||
for memory in all_memory:
|
||||
memory_id = memory.id
|
||||
memory_brief = memory.brief
|
||||
memory_single_prompt = f"记忆id:{memory_id},记忆摘要:{memory_brief}\n"
|
||||
memory_prompts.append(memory_single_prompt)
|
||||
|
||||
memory_choose_str = "".join(memory_prompts)
|
||||
|
||||
# 使用提示模板进行处理
|
||||
prompt = (await global_prompt_manager.get_prompt_async("prompt_memory_proces")).format(
|
||||
bot_name=global_config.bot.nickname,
|
||||
time_now=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
|
||||
chat_observe_info=chat_info,
|
||||
memory_str=memory_choose_str,
|
||||
)
|
||||
|
||||
# 调用LLM处理记忆
|
||||
content = ""
|
||||
try:
|
||||
content, _ = await self.llm_model.generate_response_async(prompt=prompt)
|
||||
|
||||
# print(f"prompt: {prompt}---------------------------------")
|
||||
# print(f"content: {content}---------------------------------")
|
||||
|
||||
if not content:
|
||||
logger.warning(f"{self.log_prefix} LLM返回空结果,处理工作记忆失败。")
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 执行LLM请求或处理响应时出错: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
return []
|
||||
|
||||
# 解析LLM返回的JSON
|
||||
try:
|
||||
result = repair_json(content)
|
||||
if isinstance(result, str):
|
||||
result = json.loads(result)
|
||||
if not isinstance(result, dict):
|
||||
logger.error(f"{self.log_prefix} 解析LLM返回的JSON失败,结果不是字典类型: {type(result)}")
|
||||
return []
|
||||
|
||||
selected_memory_ids = result.get("selected_memory_ids", [])
|
||||
merge_memory = result.get("merge_memory", [])
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 解析LLM返回的JSON失败: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
return []
|
||||
|
||||
logger.debug(
|
||||
f"{self.log_prefix} 解析LLM返回的JSON,selected_memory_ids: {selected_memory_ids}, merge_memory: {merge_memory}"
|
||||
)
|
||||
|
||||
# 根据selected_memory_ids,调取记忆
|
||||
memory_str = ""
|
||||
selected_ids = set(selected_memory_ids) # 转换为集合以便快速查找
|
||||
|
||||
# 遍历所有记忆
|
||||
for memory in all_memory:
|
||||
if memory.id in selected_ids:
|
||||
# 选中的记忆显示详细内容
|
||||
memory = await working_memory.retrieve_memory(memory.id)
|
||||
if memory:
|
||||
memory_str += f"{memory.summary}\n"
|
||||
else:
|
||||
# 未选中的记忆显示梗概
|
||||
memory_str += f"{memory.brief}\n"
|
||||
|
||||
working_memory_info = WorkingMemoryInfo()
|
||||
if memory_str:
|
||||
working_memory_info.add_working_memory(memory_str)
|
||||
logger.debug(f"{self.log_prefix} 取得工作记忆: {memory_str}")
|
||||
else:
|
||||
logger.debug(f"{self.log_prefix} 没有找到工作记忆")
|
||||
|
||||
if merge_memory:
|
||||
for merge_pairs in merge_memory:
|
||||
memory1 = await working_memory.retrieve_memory(merge_pairs[0])
|
||||
memory2 = await working_memory.retrieve_memory(merge_pairs[1])
|
||||
if memory1 and memory2:
|
||||
asyncio.create_task(self.merge_memory_async(working_memory, merge_pairs[0], merge_pairs[1]))
|
||||
|
||||
return [working_memory_info]
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 处理观察时出错: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
return []
|
||||
|
||||
async def compress_chat_memory(self, working_memory: WorkingMemory, obs: ChattingObservation):
|
||||
"""压缩聊天记忆
|
||||
|
||||
Args:
|
||||
working_memory: 工作记忆对象
|
||||
obs: 聊天观察对象
|
||||
"""
|
||||
# 检查working_memory是否为None
|
||||
if working_memory is None:
|
||||
logger.warning(f"{self.log_prefix} 工作记忆对象为None,无法压缩聊天记忆")
|
||||
return
|
||||
|
||||
try:
|
||||
summary_result, _ = await self.llm_model.generate_response_async(obs.compressor_prompt)
|
||||
if not summary_result:
|
||||
logger.debug(f"{self.log_prefix} 压缩聊天记忆失败: 没有生成摘要")
|
||||
return
|
||||
|
||||
print(f"compressor_prompt: {obs.compressor_prompt}")
|
||||
print(f"summary_result: {summary_result}")
|
||||
|
||||
# 修复并解析JSON
|
||||
try:
|
||||
fixed_json = repair_json(summary_result)
|
||||
summary_data = json.loads(fixed_json)
|
||||
|
||||
if not isinstance(summary_data, dict):
|
||||
logger.error(f"{self.log_prefix} 解析压缩结果失败: 不是有效的JSON对象")
|
||||
return
|
||||
|
||||
theme = summary_data.get("theme", "")
|
||||
content = summary_data.get("content", "")
|
||||
|
||||
if not theme or not content:
|
||||
logger.error(f"{self.log_prefix} 解析压缩结果失败: 缺少必要字段")
|
||||
return
|
||||
|
||||
# 创建新记忆
|
||||
await working_memory.add_memory(from_source="chat_compress", summary=content, brief=theme)
|
||||
|
||||
logger.debug(f"{self.log_prefix} 压缩聊天记忆成功: {theme} - {content}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 解析压缩结果失败: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
return
|
||||
|
||||
# 清理压缩状态
|
||||
obs.compressor_prompt = ""
|
||||
obs.oldest_messages = []
|
||||
obs.oldest_messages_str = ""
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 压缩聊天记忆失败: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
async def merge_memory_async(self, working_memory: WorkingMemory, memory_id1: str, memory_id2: str):
|
||||
"""异步合并记忆,不阻塞主流程
|
||||
|
||||
Args:
|
||||
working_memory: 工作记忆对象
|
||||
memory_id1: 第一个记忆ID
|
||||
memory_id2: 第二个记忆ID
|
||||
"""
|
||||
# 检查working_memory是否为None
|
||||
if working_memory is None:
|
||||
logger.warning(f"{self.log_prefix} 工作记忆对象为None,无法合并记忆")
|
||||
return
|
||||
|
||||
try:
|
||||
merged_memory = await working_memory.merge_memory(memory_id1, memory_id2)
|
||||
logger.debug(f"{self.log_prefix} 合并后的记忆梗概: {merged_memory.brief}")
|
||||
logger.debug(f"{self.log_prefix} 合并后的记忆内容: {merged_memory.summary}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 异步合并记忆失败: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
|
||||
init_prompt()
|
||||
@@ -54,11 +54,11 @@ class DBWrapper:
|
||||
return getattr(get_db(), name)
|
||||
|
||||
def __getitem__(self, key):
|
||||
return get_db()[key]
|
||||
return get_db()[key] # type: ignore
|
||||
|
||||
|
||||
# 全局数据库访问点
|
||||
memory_db: Database = DBWrapper()
|
||||
memory_db: Database = DBWrapper() # type: ignore
|
||||
|
||||
# 定义数据库文件路径
|
||||
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
|
||||
|
||||
@@ -65,7 +65,7 @@ class ChatStreams(BaseModel):
|
||||
# user_cardname 可能为空字符串或不存在,设置 null=True 更具灵活性。
|
||||
user_cardname = TextField(null=True)
|
||||
|
||||
class Meta:
|
||||
class Meta: # type: ignore
|
||||
# 如果 BaseModel.Meta.database 已设置,则此模型将继承该数据库配置。
|
||||
# 如果不使用带有数据库实例的 BaseModel,或者想覆盖它,
|
||||
# 请取消注释并在下面设置数据库实例:
|
||||
@@ -89,7 +89,7 @@ class LLMUsage(BaseModel):
|
||||
status = TextField()
|
||||
timestamp = DateTimeField(index=True) # 更改为 DateTimeField 并添加索引
|
||||
|
||||
class Meta:
|
||||
class Meta: # type: ignore
|
||||
# 如果 BaseModel.Meta.database 已设置,则此模型将继承该数据库配置。
|
||||
# database = db
|
||||
table_name = "llm_usage"
|
||||
@@ -112,7 +112,7 @@ class Emoji(BaseModel):
|
||||
usage_count = IntegerField(default=0) # 使用次数(被使用的次数)
|
||||
last_used_time = FloatField(null=True) # 上次使用时间
|
||||
|
||||
class Meta:
|
||||
class Meta: # type: ignore
|
||||
# database = db # 继承自 BaseModel
|
||||
table_name = "emoji"
|
||||
|
||||
@@ -129,6 +129,9 @@ class Messages(BaseModel):
|
||||
|
||||
reply_to = TextField(null=True)
|
||||
|
||||
interest_value = DoubleField(null=True)
|
||||
is_mentioned = BooleanField(null=True)
|
||||
|
||||
# 从 chat_info 扁平化而来的字段
|
||||
chat_info_stream_id = TextField()
|
||||
chat_info_platform = TextField()
|
||||
@@ -150,10 +153,17 @@ class Messages(BaseModel):
|
||||
|
||||
processed_plain_text = TextField(null=True) # 处理后的纯文本消息
|
||||
display_message = TextField(null=True) # 显示的消息
|
||||
detailed_plain_text = TextField(null=True) # 详细的纯文本消息
|
||||
memorized_times = IntegerField(default=0) # 被记忆的次数
|
||||
|
||||
class Meta:
|
||||
priority_mode = TextField(null=True)
|
||||
priority_info = TextField(null=True)
|
||||
|
||||
additional_config = TextField(null=True)
|
||||
is_emoji = BooleanField(default=False)
|
||||
is_picid = BooleanField(default=False)
|
||||
is_command = BooleanField(default=False)
|
||||
|
||||
class Meta: # type: ignore
|
||||
# database = db # 继承自 BaseModel
|
||||
table_name = "messages"
|
||||
|
||||
@@ -177,7 +187,7 @@ class ActionRecords(BaseModel):
|
||||
chat_info_stream_id = TextField()
|
||||
chat_info_platform = TextField()
|
||||
|
||||
class Meta:
|
||||
class Meta: # type: ignore
|
||||
# database = db # 继承自 BaseModel
|
||||
table_name = "action_records"
|
||||
|
||||
@@ -197,7 +207,7 @@ class Images(BaseModel):
|
||||
type = TextField() # 图像类型,例如 "emoji"
|
||||
vlm_processed = BooleanField(default=False) # 是否已经过VLM处理
|
||||
|
||||
class Meta:
|
||||
class Meta: # type: ignore
|
||||
table_name = "images"
|
||||
|
||||
|
||||
@@ -211,7 +221,7 @@ class ImageDescriptions(BaseModel):
|
||||
description = TextField() # 图像的描述
|
||||
timestamp = FloatField() # 时间戳
|
||||
|
||||
class Meta:
|
||||
class Meta: # type: ignore
|
||||
# database = db # 继承自 BaseModel
|
||||
table_name = "image_descriptions"
|
||||
|
||||
@@ -227,7 +237,7 @@ class OnlineTime(BaseModel):
|
||||
start_timestamp = DateTimeField(default=datetime.datetime.now)
|
||||
end_timestamp = DateTimeField(index=True)
|
||||
|
||||
class Meta:
|
||||
class Meta: # type: ignore
|
||||
# database = db # 继承自 BaseModel
|
||||
table_name = "online_time"
|
||||
|
||||
@@ -254,11 +264,23 @@ class PersonInfo(BaseModel):
|
||||
last_know = FloatField(null=True) # 最后一次印象总结时间
|
||||
attitude = IntegerField(null=True, default=50) # 态度,0-100,从非常厌恶到十分喜欢
|
||||
|
||||
class Meta:
|
||||
class Meta: # type: ignore
|
||||
# database = db # 继承自 BaseModel
|
||||
table_name = "person_info"
|
||||
|
||||
|
||||
class Memory(BaseModel):
|
||||
memory_id = TextField(index=True)
|
||||
chat_id = TextField(null=True)
|
||||
memory_text = TextField(null=True)
|
||||
keywords = TextField(null=True)
|
||||
create_time = FloatField(null=True)
|
||||
last_view_time = FloatField(null=True)
|
||||
|
||||
class Meta: # type: ignore
|
||||
table_name = "memory"
|
||||
|
||||
|
||||
class Knowledges(BaseModel):
|
||||
"""
|
||||
用于存储知识库条目的模型。
|
||||
@@ -268,11 +290,27 @@ class Knowledges(BaseModel):
|
||||
embedding = TextField() # 知识内容的嵌入向量,存储为 JSON 字符串的浮点数列表
|
||||
# 可以添加其他元数据字段,如 source, create_time 等
|
||||
|
||||
class Meta:
|
||||
class Meta: # type: ignore
|
||||
# database = db # 继承自 BaseModel
|
||||
table_name = "knowledges"
|
||||
|
||||
|
||||
class Expression(BaseModel):
|
||||
"""
|
||||
用于存储表达风格的模型。
|
||||
"""
|
||||
|
||||
situation = TextField()
|
||||
style = TextField()
|
||||
count = FloatField()
|
||||
last_active_time = FloatField()
|
||||
chat_id = TextField(index=True)
|
||||
type = TextField()
|
||||
|
||||
class Meta: # type: ignore
|
||||
table_name = "expression"
|
||||
|
||||
|
||||
class ThinkingLog(BaseModel):
|
||||
chat_id = TextField(index=True)
|
||||
trigger_text = TextField(null=True)
|
||||
@@ -293,23 +331,10 @@ class ThinkingLog(BaseModel):
|
||||
# And: import datetime
|
||||
created_at = DateTimeField(default=datetime.datetime.now)
|
||||
|
||||
class Meta:
|
||||
class Meta: # type: ignore
|
||||
table_name = "thinking_logs"
|
||||
|
||||
|
||||
class RecalledMessages(BaseModel):
|
||||
"""
|
||||
用于存储撤回消息记录的模型。
|
||||
"""
|
||||
|
||||
message_id = TextField(index=True) # 被撤回的消息 ID
|
||||
time = DoubleField() # 撤回操作发生的时间戳
|
||||
stream_id = TextField() # 对应的 ChatStreams stream_id
|
||||
|
||||
class Meta:
|
||||
table_name = "recalled_messages"
|
||||
|
||||
|
||||
class GraphNodes(BaseModel):
|
||||
"""
|
||||
用于存储记忆图节点的模型
|
||||
@@ -321,7 +346,7 @@ class GraphNodes(BaseModel):
|
||||
created_time = FloatField() # 创建时间戳
|
||||
last_modified = FloatField() # 最后修改时间戳
|
||||
|
||||
class Meta:
|
||||
class Meta: # type: ignore
|
||||
table_name = "graph_nodes"
|
||||
|
||||
|
||||
@@ -337,7 +362,7 @@ class GraphEdges(BaseModel):
|
||||
created_time = FloatField() # 创建时间戳
|
||||
last_modified = FloatField() # 最后修改时间戳
|
||||
|
||||
class Meta:
|
||||
class Meta: # type: ignore
|
||||
table_name = "graph_edges"
|
||||
|
||||
|
||||
@@ -357,10 +382,11 @@ def create_tables():
|
||||
OnlineTime,
|
||||
PersonInfo,
|
||||
Knowledges,
|
||||
Expression,
|
||||
ThinkingLog,
|
||||
RecalledMessages, # 添加新模型
|
||||
GraphNodes, # 添加图节点表
|
||||
GraphEdges, # 添加图边表
|
||||
Memory,
|
||||
ActionRecords, # 添加 ActionRecords 到初始化列表
|
||||
]
|
||||
)
|
||||
@@ -382,8 +408,9 @@ def initialize_database():
|
||||
OnlineTime,
|
||||
PersonInfo,
|
||||
Knowledges,
|
||||
Expression,
|
||||
Memory,
|
||||
ThinkingLog,
|
||||
RecalledMessages,
|
||||
GraphNodes,
|
||||
GraphEdges,
|
||||
ActionRecords, # 添加 ActionRecords 到初始化列表
|
||||
@@ -404,9 +431,7 @@ def initialize_database():
|
||||
existing_columns = {row[1] for row in cursor.fetchall()}
|
||||
model_fields = set(model._meta.fields.keys())
|
||||
|
||||
# 检查并添加缺失字段(原有逻辑)
|
||||
missing_fields = model_fields - existing_columns
|
||||
if missing_fields:
|
||||
if missing_fields := model_fields - existing_columns:
|
||||
logger.warning(f"表 '{table_name}' 缺失字段: {missing_fields}")
|
||||
|
||||
for field_name, field_obj in model._meta.fields.items():
|
||||
@@ -422,10 +447,7 @@ def initialize_database():
|
||||
"DateTimeField": "DATETIME",
|
||||
}.get(field_type, "TEXT")
|
||||
alter_sql = f"ALTER TABLE {table_name} ADD COLUMN {field_name} {sql_type}"
|
||||
if field_obj.null:
|
||||
alter_sql += " NULL"
|
||||
else:
|
||||
alter_sql += " NOT NULL"
|
||||
alter_sql += " NULL" if field_obj.null else " NOT NULL"
|
||||
if hasattr(field_obj, "default") and field_obj.default is not None:
|
||||
# 正确处理不同类型的默认值
|
||||
default_value = field_obj.default
|
||||
|
||||
@@ -1,16 +1,16 @@
|
||||
import logging
|
||||
|
||||
# 使用基于时间戳的文件处理器,简单的轮转份数限制
|
||||
from pathlib import Path
|
||||
from typing import Callable, Optional
|
||||
|
||||
import logging
|
||||
import json
|
||||
import threading
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import structlog
|
||||
import toml
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Callable, Optional
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
# 创建logs目录
|
||||
LOG_DIR = Path("logs")
|
||||
LOG_DIR.mkdir(exist_ok=True)
|
||||
@@ -160,7 +160,7 @@ def close_handlers():
|
||||
_console_handler = None
|
||||
|
||||
|
||||
def remove_duplicate_handlers():
|
||||
def remove_duplicate_handlers(): # sourcery skip: for-append-to-extend, list-comprehension
|
||||
"""移除重复的handler,特别是文件handler"""
|
||||
root_logger = logging.getLogger()
|
||||
|
||||
@@ -184,7 +184,7 @@ def remove_duplicate_handlers():
|
||||
|
||||
|
||||
# 读取日志配置
|
||||
def load_log_config():
|
||||
def load_log_config(): # sourcery skip: use-contextlib-suppress
|
||||
"""从配置文件加载日志设置"""
|
||||
config_path = Path("config/bot_config.toml")
|
||||
default_config = {
|
||||
@@ -365,7 +365,7 @@ MODULE_COLORS = {
|
||||
"component_registry": "\033[38;5;214m", # 橙黄色
|
||||
"stream_api": "\033[38;5;220m", # 黄色
|
||||
"config_api": "\033[38;5;226m", # 亮黄色
|
||||
"hearflow_api": "\033[38;5;154m", # 黄绿色
|
||||
"heartflow_api": "\033[38;5;154m", # 黄绿色
|
||||
"action_apis": "\033[38;5;118m", # 绿色
|
||||
"independent_apis": "\033[38;5;82m", # 绿色
|
||||
"llm_api": "\033[38;5;46m", # 亮绿色
|
||||
@@ -403,6 +403,10 @@ MODULE_COLORS = {
|
||||
"model_utils": "\033[38;5;164m", # 紫红色
|
||||
"relationship_fetcher": "\033[38;5;170m", # 浅紫色
|
||||
"relationship_builder": "\033[38;5;93m", # 浅蓝色
|
||||
|
||||
#s4u
|
||||
"context_web_api": "\033[38;5;240m", # 深灰色
|
||||
"S4U_chat": "\033[92m", # 深灰色
|
||||
}
|
||||
|
||||
RESET_COLOR = "\033[0m"
|
||||
@@ -412,6 +416,7 @@ class ModuleColoredConsoleRenderer:
|
||||
"""自定义控制台渲染器,为不同模块提供不同颜色"""
|
||||
|
||||
def __init__(self, colors=True):
|
||||
# sourcery skip: merge-duplicate-blocks, remove-redundant-if
|
||||
self._colors = colors
|
||||
self._config = LOG_CONFIG
|
||||
|
||||
@@ -443,6 +448,7 @@ class ModuleColoredConsoleRenderer:
|
||||
self._enable_full_content_colors = False
|
||||
|
||||
def __call__(self, logger, method_name, event_dict):
|
||||
# sourcery skip: merge-duplicate-blocks
|
||||
"""渲染日志消息"""
|
||||
# 获取基本信息
|
||||
timestamp = event_dict.get("timestamp", "")
|
||||
@@ -662,7 +668,7 @@ def get_logger(name: Optional[str]) -> structlog.stdlib.BoundLogger:
|
||||
"""获取logger实例,支持按名称绑定"""
|
||||
if name is None:
|
||||
return raw_logger
|
||||
logger = binds.get(name)
|
||||
logger = binds.get(name) # type: ignore
|
||||
if logger is None:
|
||||
logger: structlog.stdlib.BoundLogger = structlog.get_logger(name).bind(logger_name=name)
|
||||
binds[name] = logger
|
||||
@@ -671,8 +677,8 @@ def get_logger(name: Optional[str]) -> structlog.stdlib.BoundLogger:
|
||||
|
||||
def configure_logging(
|
||||
level: str = "INFO",
|
||||
console_level: str = None,
|
||||
file_level: str = None,
|
||||
console_level: Optional[str] = None,
|
||||
file_level: Optional[str] = None,
|
||||
max_bytes: int = 5 * 1024 * 1024,
|
||||
backup_count: int = 30,
|
||||
log_dir: str = "logs",
|
||||
@@ -729,14 +735,11 @@ def reload_log_config():
|
||||
global LOG_CONFIG
|
||||
LOG_CONFIG = load_log_config()
|
||||
|
||||
# 重新设置handler的日志级别
|
||||
file_handler = get_file_handler()
|
||||
if file_handler:
|
||||
if file_handler := get_file_handler():
|
||||
file_level = LOG_CONFIG.get("file_log_level", LOG_CONFIG.get("log_level", "INFO"))
|
||||
file_handler.setLevel(getattr(logging, file_level.upper(), logging.INFO))
|
||||
|
||||
console_handler = get_console_handler()
|
||||
if console_handler:
|
||||
if console_handler := get_console_handler():
|
||||
console_level = LOG_CONFIG.get("console_log_level", LOG_CONFIG.get("log_level", "INFO"))
|
||||
console_handler.setLevel(getattr(logging, console_level.upper(), logging.INFO))
|
||||
|
||||
@@ -780,8 +783,7 @@ def set_console_log_level(level: str):
|
||||
global LOG_CONFIG
|
||||
LOG_CONFIG["console_log_level"] = level.upper()
|
||||
|
||||
console_handler = get_console_handler()
|
||||
if console_handler:
|
||||
if console_handler := get_console_handler():
|
||||
console_handler.setLevel(getattr(logging, level.upper(), logging.INFO))
|
||||
|
||||
# 重新设置root logger级别
|
||||
@@ -800,8 +802,7 @@ def set_file_log_level(level: str):
|
||||
global LOG_CONFIG
|
||||
LOG_CONFIG["file_log_level"] = level.upper()
|
||||
|
||||
file_handler = get_file_handler()
|
||||
if file_handler:
|
||||
if file_handler := get_file_handler():
|
||||
file_handler.setLevel(getattr(logging, level.upper(), logging.INFO))
|
||||
|
||||
# 重新设置root logger级别
|
||||
@@ -933,13 +934,12 @@ def format_json_for_logging(data, indent=2, ensure_ascii=False):
|
||||
Returns:
|
||||
str: 格式化后的JSON字符串
|
||||
"""
|
||||
if isinstance(data, str):
|
||||
# 如果是JSON字符串,先解析再格式化
|
||||
parsed_data = json.loads(data)
|
||||
return json.dumps(parsed_data, indent=indent, ensure_ascii=ensure_ascii)
|
||||
else:
|
||||
if not isinstance(data, str):
|
||||
# 如果是对象,直接格式化
|
||||
return json.dumps(data, indent=indent, ensure_ascii=ensure_ascii)
|
||||
# 如果是JSON字符串,先解析再格式化
|
||||
parsed_data = json.loads(data)
|
||||
return json.dumps(parsed_data, indent=indent, ensure_ascii=ensure_ascii)
|
||||
|
||||
|
||||
def cleanup_old_logs():
|
||||
|
||||
@@ -8,7 +8,7 @@ from src.config.config import global_config
|
||||
global_api = None
|
||||
|
||||
|
||||
def get_global_api() -> MessageServer:
|
||||
def get_global_api() -> MessageServer: # sourcery skip: extract-method
|
||||
"""获取全局MessageServer实例"""
|
||||
global global_api
|
||||
if global_api is None:
|
||||
@@ -36,9 +36,8 @@ def get_global_api() -> MessageServer:
|
||||
kwargs["custom_logger"] = maim_message_logger
|
||||
|
||||
# 添加token认证
|
||||
if maim_message_config.auth_token:
|
||||
if len(maim_message_config.auth_token) > 0:
|
||||
kwargs["enable_token"] = True
|
||||
if maim_message_config.auth_token and len(maim_message_config.auth_token) > 0:
|
||||
kwargs["enable_token"] = True
|
||||
|
||||
if maim_message_config.use_custom:
|
||||
# 添加WSS模式支持
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
from src.common.database.database_model import Messages # 更改导入
|
||||
from src.common.logger import get_logger
|
||||
import traceback
|
||||
|
||||
from typing import List, Any, Optional
|
||||
from peewee import Model # 添加 Peewee Model 导入
|
||||
from src.config.config import global_config
|
||||
|
||||
from src.common.database.database_model import Messages
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -19,6 +22,8 @@ def find_messages(
|
||||
sort: Optional[List[tuple[str, int]]] = None,
|
||||
limit: int = 0,
|
||||
limit_mode: str = "latest",
|
||||
filter_bot=False,
|
||||
filter_command=False,
|
||||
) -> List[dict[str, Any]]:
|
||||
"""
|
||||
根据提供的过滤器、排序和限制条件查找消息。
|
||||
@@ -68,6 +73,12 @@ def find_messages(
|
||||
if conditions:
|
||||
query = query.where(*conditions)
|
||||
|
||||
if filter_bot:
|
||||
query = query.where(Messages.user_id != global_config.bot.qq_account)
|
||||
|
||||
if filter_command:
|
||||
query = query.where(not Messages.is_command)
|
||||
|
||||
if limit > 0:
|
||||
if limit_mode == "earliest":
|
||||
# 获取时间最早的 limit 条记录,已经是正序
|
||||
|
||||
@@ -23,7 +23,7 @@ class TelemetryHeartBeatTask(AsyncTask):
|
||||
self.server_url = TELEMETRY_SERVER_URL
|
||||
"""遥测服务地址"""
|
||||
|
||||
self.client_uuid = local_storage["mmc_uuid"] if "mmc_uuid" in local_storage else None
|
||||
self.client_uuid: str | None = local_storage["mmc_uuid"] if "mmc_uuid" in local_storage else None # type: ignore
|
||||
"""客户端UUID"""
|
||||
|
||||
self.info_dict = self._get_sys_info()
|
||||
@@ -72,7 +72,7 @@ class TelemetryHeartBeatTask(AsyncTask):
|
||||
timeout=aiohttp.ClientTimeout(total=5), # 设置超时时间为5秒
|
||||
) as response:
|
||||
logger.debug(f"{TELEMETRY_SERVER_URL}/stat/reg_client")
|
||||
logger.debug(local_storage["deploy_time"])
|
||||
logger.debug(local_storage["deploy_time"]) # type: ignore
|
||||
logger.debug(f"Response status: {response.status}")
|
||||
|
||||
if response.status == 200:
|
||||
@@ -93,7 +93,7 @@ class TelemetryHeartBeatTask(AsyncTask):
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
error_msg = str(e) if str(e) else "未知错误"
|
||||
error_msg = str(e) or "未知错误"
|
||||
logger.warning(
|
||||
f"请求UUID出错,不过你还是可以正常使用麦麦: {type(e).__name__}: {error_msg}"
|
||||
) # 可能是网络问题
|
||||
@@ -114,11 +114,11 @@ class TelemetryHeartBeatTask(AsyncTask):
|
||||
"""向服务器发送心跳"""
|
||||
headers = {
|
||||
"Client-UUID": self.client_uuid,
|
||||
"User-Agent": f"HeartbeatClient/{self.client_uuid[:8]}",
|
||||
"User-Agent": f"HeartbeatClient/{self.client_uuid[:8]}", # type: ignore
|
||||
}
|
||||
|
||||
logger.debug(f"正在发送心跳到服务器: {self.server_url}")
|
||||
logger.debug(headers)
|
||||
logger.debug(str(headers))
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession(connector=await get_tcp_connector()) as session:
|
||||
@@ -151,7 +151,7 @@ class TelemetryHeartBeatTask(AsyncTask):
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
error_msg = str(e) if str(e) else "未知错误"
|
||||
error_msg = str(e) or "未知错误"
|
||||
logger.warning(f"(此消息不会影响正常使用)状态未发生: {type(e).__name__}: {error_msg}")
|
||||
logger.debug(f"完整错误信息: {traceback.format_exc()}")
|
||||
|
||||
|
||||
@@ -1,9 +1,54 @@
|
||||
import shutil
|
||||
import tomlkit
|
||||
from tomlkit.items import Table, KeyType
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
def get_key_comment(toml_table, key):
|
||||
# 获取key的注释(如果有)
|
||||
if hasattr(toml_table, "trivia") and hasattr(toml_table.trivia, "comment"):
|
||||
return toml_table.trivia.comment
|
||||
if hasattr(toml_table, "value") and isinstance(toml_table.value, dict):
|
||||
item = toml_table.value.get(key)
|
||||
if item is not None and hasattr(item, "trivia"):
|
||||
return item.trivia.comment
|
||||
if hasattr(toml_table, "keys"):
|
||||
for k in toml_table.keys():
|
||||
if isinstance(k, KeyType) and k.key == key:
|
||||
return k.trivia.comment
|
||||
return None
|
||||
|
||||
|
||||
def compare_dicts(new, old, path=None, new_comments=None, old_comments=None, logs=None):
|
||||
# 递归比较两个dict,找出新增和删减项,收集注释
|
||||
if path is None:
|
||||
path = []
|
||||
if logs is None:
|
||||
logs = []
|
||||
if new_comments is None:
|
||||
new_comments = {}
|
||||
if old_comments is None:
|
||||
old_comments = {}
|
||||
# 新增项
|
||||
for key in new:
|
||||
if key == "version":
|
||||
continue
|
||||
if key not in old:
|
||||
comment = get_key_comment(new, key)
|
||||
logs.append(f"新增: {'.'.join(path + [str(key)])} 注释: {comment if comment else '无'}")
|
||||
elif isinstance(new[key], (dict, Table)) and isinstance(old.get(key), (dict, Table)):
|
||||
compare_dicts(new[key], old[key], path + [str(key)], new_comments, old_comments, logs)
|
||||
# 删减项
|
||||
for key in old:
|
||||
if key == "version":
|
||||
continue
|
||||
if key not in new:
|
||||
comment = get_key_comment(old, key)
|
||||
logs.append(f"删减: {'.'.join(path + [str(key)])} 注释: {comment if comment else '无'}")
|
||||
return logs
|
||||
|
||||
|
||||
def update_config():
|
||||
print("开始更新配置文件...")
|
||||
# 获取根目录路径
|
||||
@@ -45,16 +90,26 @@ def update_config():
|
||||
|
||||
# 检查version是否相同
|
||||
if old_config and "inner" in old_config and "inner" in new_config:
|
||||
old_version = old_config["inner"].get("version")
|
||||
new_version = new_config["inner"].get("version")
|
||||
old_version = old_config["inner"].get("version") # type: ignore
|
||||
new_version = new_config["inner"].get("version") # type: ignore
|
||||
if old_version and new_version and old_version == new_version:
|
||||
print(f"检测到版本号相同 (v{old_version}),跳过更新")
|
||||
# 如果version相同,恢复旧配置文件并返回
|
||||
shutil.move(old_backup_path, old_config_path)
|
||||
shutil.move(old_backup_path, old_config_path) # type: ignore
|
||||
return
|
||||
else:
|
||||
print(f"检测到版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}")
|
||||
|
||||
# 输出新增和删减项及注释
|
||||
if old_config:
|
||||
print("配置项变动如下:")
|
||||
logs = compare_dicts(new_config, old_config)
|
||||
if logs:
|
||||
for log in logs:
|
||||
print(log)
|
||||
else:
|
||||
print("无新增或删减项")
|
||||
|
||||
# 递归更新配置
|
||||
def update_dict(target, source):
|
||||
for key, value in source.items():
|
||||
@@ -62,7 +117,7 @@ def update_config():
|
||||
if key == "version":
|
||||
continue
|
||||
if key in target:
|
||||
if isinstance(value, dict) and isinstance(target[key], (dict, tomlkit.items.Table)):
|
||||
if isinstance(value, dict) and isinstance(target[key], (dict, Table)):
|
||||
update_dict(target[key], value)
|
||||
else:
|
||||
try:
|
||||
@@ -85,10 +140,7 @@ def update_config():
|
||||
if value and isinstance(value[0], dict) and "regex" in value[0]:
|
||||
contains_regex = True
|
||||
|
||||
if contains_regex:
|
||||
target[key] = value
|
||||
else:
|
||||
target[key] = tomlkit.array(value)
|
||||
target[key] = value if contains_regex else tomlkit.array(str(value))
|
||||
else:
|
||||
# 其他类型使用item方法创建新值
|
||||
target[key] = tomlkit.item(value)
|
||||
|
||||
@@ -1,25 +1,21 @@
|
||||
import os
|
||||
from dataclasses import field, dataclass
|
||||
|
||||
import tomlkit
|
||||
import shutil
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from tomlkit import TOMLDocument
|
||||
from tomlkit.items import Table
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from tomlkit.items import Table, KeyType
|
||||
from dataclasses import field, dataclass
|
||||
from rich.traceback import install
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config_base import ConfigBase
|
||||
from src.config.official_configs import (
|
||||
BotConfig,
|
||||
PersonalityConfig,
|
||||
IdentityConfig,
|
||||
ExpressionConfig,
|
||||
ChatConfig,
|
||||
NormalChatConfig,
|
||||
FocusChatConfig,
|
||||
EmojiConfig,
|
||||
MemoryConfig,
|
||||
MoodConfig,
|
||||
@@ -36,6 +32,7 @@ from src.config.official_configs import (
|
||||
RelationshipConfig,
|
||||
ToolConfig,
|
||||
DebugConfig,
|
||||
CustomPromptConfig,
|
||||
)
|
||||
|
||||
install(extra_lines=3)
|
||||
@@ -51,17 +48,167 @@ TEMPLATE_DIR = os.path.join(PROJECT_ROOT, "template")
|
||||
|
||||
# 考虑到,实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码
|
||||
# 对该字段的更新,请严格参照语义化版本规范:https://semver.org/lang/zh-CN/
|
||||
MMC_VERSION = "0.8.2-snapshot.1"
|
||||
MMC_VERSION = "0.9.0-snapshot.2"
|
||||
|
||||
|
||||
def get_key_comment(toml_table, key):
|
||||
# 获取key的注释(如果有)
|
||||
if hasattr(toml_table, "trivia") and hasattr(toml_table.trivia, "comment"):
|
||||
return toml_table.trivia.comment
|
||||
if hasattr(toml_table, "value") and isinstance(toml_table.value, dict):
|
||||
item = toml_table.value.get(key)
|
||||
if item is not None and hasattr(item, "trivia"):
|
||||
return item.trivia.comment
|
||||
if hasattr(toml_table, "keys"):
|
||||
for k in toml_table.keys():
|
||||
if isinstance(k, KeyType) and k.key == key:
|
||||
return k.trivia.comment
|
||||
return None
|
||||
|
||||
|
||||
def compare_dicts(new, old, path=None, logs=None):
|
||||
# 递归比较两个dict,找出新增和删减项,收集注释
|
||||
if path is None:
|
||||
path = []
|
||||
if logs is None:
|
||||
logs = []
|
||||
# 新增项
|
||||
for key in new:
|
||||
if key == "version":
|
||||
continue
|
||||
if key not in old:
|
||||
comment = get_key_comment(new, key)
|
||||
logs.append(f"新增: {'.'.join(path + [str(key)])} 注释: {comment if comment else '无'}")
|
||||
elif isinstance(new[key], (dict, Table)) and isinstance(old.get(key), (dict, Table)):
|
||||
compare_dicts(new[key], old[key], path + [str(key)], logs)
|
||||
# 删减项
|
||||
for key in old:
|
||||
if key == "version":
|
||||
continue
|
||||
if key not in new:
|
||||
comment = get_key_comment(old, key)
|
||||
logs.append(f"删减: {'.'.join(path + [str(key)])} 注释: {comment if comment else '无'}")
|
||||
return logs
|
||||
|
||||
|
||||
def get_value_by_path(d, path):
|
||||
for k in path:
|
||||
if isinstance(d, dict) and k in d:
|
||||
d = d[k]
|
||||
else:
|
||||
return None
|
||||
return d
|
||||
|
||||
|
||||
def set_value_by_path(d, path, value):
|
||||
for k in path[:-1]:
|
||||
if k not in d or not isinstance(d[k], dict):
|
||||
d[k] = {}
|
||||
d = d[k]
|
||||
d[path[-1]] = value
|
||||
|
||||
|
||||
def compare_default_values(new, old, path=None, logs=None, changes=None):
|
||||
# 递归比较两个dict,找出默认值变化项
|
||||
if path is None:
|
||||
path = []
|
||||
if logs is None:
|
||||
logs = []
|
||||
if changes is None:
|
||||
changes = []
|
||||
for key in new:
|
||||
if key == "version":
|
||||
continue
|
||||
if key in old:
|
||||
if isinstance(new[key], (dict, Table)) and isinstance(old[key], (dict, Table)):
|
||||
compare_default_values(new[key], old[key], path + [str(key)], logs, changes)
|
||||
else:
|
||||
# 只要值发生变化就记录
|
||||
if new[key] != old[key]:
|
||||
logs.append(
|
||||
f"默认值变化: {'.'.join(path + [str(key)])} 旧默认值: {old[key]} 新默认值: {new[key]}"
|
||||
)
|
||||
changes.append((path + [str(key)], old[key], new[key]))
|
||||
return logs, changes
|
||||
|
||||
|
||||
def update_config():
|
||||
# 获取根目录路径
|
||||
old_config_dir = os.path.join(CONFIG_DIR, "old")
|
||||
compare_dir = os.path.join(TEMPLATE_DIR, "compare")
|
||||
|
||||
# 定义文件路径
|
||||
template_path = os.path.join(TEMPLATE_DIR, "bot_config_template.toml")
|
||||
old_config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
|
||||
new_config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
|
||||
compare_path = os.path.join(compare_dir, "bot_config_template.toml")
|
||||
|
||||
# 创建compare目录(如果不存在)
|
||||
os.makedirs(compare_dir, exist_ok=True)
|
||||
|
||||
# 处理compare下的模板文件
|
||||
def get_version_from_toml(toml_path):
|
||||
if not os.path.exists(toml_path):
|
||||
return None
|
||||
with open(toml_path, "r", encoding="utf-8") as f:
|
||||
doc = tomlkit.load(f)
|
||||
if "inner" in doc and "version" in doc["inner"]: # type: ignore
|
||||
return doc["inner"]["version"] # type: ignore
|
||||
return None
|
||||
|
||||
template_version = get_version_from_toml(template_path)
|
||||
compare_version = get_version_from_toml(compare_path)
|
||||
|
||||
def version_tuple(v):
|
||||
if v is None:
|
||||
return (0,)
|
||||
return tuple(int(x) if x.isdigit() else 0 for x in str(v).replace("v", "").split("-")[0].split("."))
|
||||
|
||||
# 先读取 compare 下的模板(如果有),用于默认值变动检测
|
||||
if os.path.exists(compare_path):
|
||||
with open(compare_path, "r", encoding="utf-8") as f:
|
||||
compare_config = tomlkit.load(f)
|
||||
else:
|
||||
compare_config = None
|
||||
|
||||
# 读取当前模板
|
||||
with open(template_path, "r", encoding="utf-8") as f:
|
||||
new_config = tomlkit.load(f)
|
||||
|
||||
# 检查默认值变化并处理(只有 compare_config 存在时才做)
|
||||
if compare_config is not None:
|
||||
# 读取旧配置
|
||||
with open(old_config_path, "r", encoding="utf-8") as f:
|
||||
old_config = tomlkit.load(f)
|
||||
logs, changes = compare_default_values(new_config, compare_config)
|
||||
if logs:
|
||||
logger.info("检测到模板默认值变动如下:")
|
||||
for log in logs:
|
||||
logger.info(log)
|
||||
# 检查旧配置是否等于旧默认值,如果是则更新为新默认值
|
||||
for path, old_default, new_default in changes:
|
||||
old_value = get_value_by_path(old_config, path)
|
||||
if old_value == old_default:
|
||||
set_value_by_path(old_config, path, new_default)
|
||||
logger.info(
|
||||
f"已自动将配置 {'.'.join(path)} 的值从旧默认值 {old_default} 更新为新默认值 {new_default}"
|
||||
)
|
||||
else:
|
||||
logger.info("未检测到模板默认值变动")
|
||||
# 保存旧配置的变更(后续合并逻辑会用到 old_config)
|
||||
else:
|
||||
old_config = None
|
||||
|
||||
# 检查 compare 下没有模板,或新模板版本更高,则复制
|
||||
if not os.path.exists(compare_path):
|
||||
shutil.copy2(template_path, compare_path)
|
||||
logger.info(f"已将模板文件复制到: {compare_path}")
|
||||
else:
|
||||
if version_tuple(template_version) > version_tuple(compare_version):
|
||||
shutil.copy2(template_path, compare_path)
|
||||
logger.info(f"模板版本较新,已替换compare下的模板: {compare_path}")
|
||||
else:
|
||||
logger.debug(f"compare下的模板版本不低于当前模板,无需替换: {compare_path}")
|
||||
|
||||
# 检查配置文件是否存在
|
||||
if not os.path.exists(old_config_path):
|
||||
@@ -72,21 +219,25 @@ def update_config():
|
||||
# 如果是新创建的配置文件,直接返回
|
||||
quit()
|
||||
|
||||
# 读取旧配置文件和模板文件
|
||||
with open(old_config_path, "r", encoding="utf-8") as f:
|
||||
old_config = tomlkit.load(f)
|
||||
with open(template_path, "r", encoding="utf-8") as f:
|
||||
new_config = tomlkit.load(f)
|
||||
# 读取旧配置文件和模板文件(如果前面没读过 old_config,这里再读一次)
|
||||
if old_config is None:
|
||||
with open(old_config_path, "r", encoding="utf-8") as f:
|
||||
old_config = tomlkit.load(f)
|
||||
# new_config 已经读取
|
||||
|
||||
# 读取 compare_config 只用于默认值变动检测,后续合并逻辑不再用
|
||||
|
||||
# 检查version是否相同
|
||||
if old_config and "inner" in old_config and "inner" in new_config:
|
||||
old_version = old_config["inner"].get("version")
|
||||
new_version = new_config["inner"].get("version")
|
||||
old_version = old_config["inner"].get("version") # type: ignore
|
||||
new_version = new_config["inner"].get("version") # type: ignore
|
||||
if old_version and new_version and old_version == new_version:
|
||||
logger.info(f"检测到配置文件版本号相同 (v{old_version}),跳过更新")
|
||||
return
|
||||
else:
|
||||
logger.info(f"检测到版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}")
|
||||
logger.info(
|
||||
f"\n----------------------------------------\n检测到版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}\n----------------------------------------"
|
||||
)
|
||||
else:
|
||||
logger.info("已有配置文件未检测到版本号,可能是旧版本。将进行更新")
|
||||
|
||||
@@ -103,7 +254,17 @@ def update_config():
|
||||
shutil.copy2(template_path, new_config_path)
|
||||
logger.info(f"已创建新配置文件: {new_config_path}")
|
||||
|
||||
def update_dict(target: TOMLDocument | dict, source: TOMLDocument | dict):
|
||||
# 输出新增和删减项及注释
|
||||
if old_config:
|
||||
logger.info("配置项变动如下:\n----------------------------------------")
|
||||
logs = compare_dicts(new_config, old_config)
|
||||
if logs:
|
||||
for log in logs:
|
||||
logger.info(log)
|
||||
else:
|
||||
logger.info("无新增或删减项")
|
||||
|
||||
def update_dict(target: TOMLDocument | dict | Table, source: TOMLDocument | dict):
|
||||
"""
|
||||
将source字典的值更新到target字典中(如果target中存在相同的键)
|
||||
"""
|
||||
@@ -112,8 +273,9 @@ def update_config():
|
||||
if key == "version":
|
||||
continue
|
||||
if key in target:
|
||||
if isinstance(value, dict) and isinstance(target[key], (dict, Table)):
|
||||
update_dict(target[key], value)
|
||||
target_value = target[key]
|
||||
if isinstance(value, dict) and isinstance(target_value, (dict, Table)):
|
||||
update_dict(target_value, value)
|
||||
else:
|
||||
try:
|
||||
# 对数组类型进行特殊处理
|
||||
@@ -146,12 +308,10 @@ class Config(ConfigBase):
|
||||
|
||||
bot: BotConfig
|
||||
personality: PersonalityConfig
|
||||
identity: IdentityConfig
|
||||
relationship: RelationshipConfig
|
||||
chat: ChatConfig
|
||||
message_receive: MessageReceiveConfig
|
||||
normal_chat: NormalChatConfig
|
||||
focus_chat: FocusChatConfig
|
||||
emoji: EmojiConfig
|
||||
expression: ExpressionConfig
|
||||
memory: MemoryConfig
|
||||
@@ -167,6 +327,7 @@ class Config(ConfigBase):
|
||||
lpmm_knowledge: LPMMKnowledgeConfig
|
||||
tool: ToolConfig
|
||||
debug: DebugConfig
|
||||
custom_prompt: CustomPromptConfig
|
||||
|
||||
|
||||
def load_config(config_path: str) -> Config:
|
||||
|
||||
@@ -43,7 +43,7 @@ class ConfigBase:
|
||||
field_type = f.type
|
||||
|
||||
try:
|
||||
init_args[field_name] = cls._convert_field(value, field_type)
|
||||
init_args[field_name] = cls._convert_field(value, field_type) # type: ignore
|
||||
except TypeError as e:
|
||||
raise TypeError(f"Field '{field_name}' has a type error: {e}") from e
|
||||
except Exception as e:
|
||||
@@ -94,7 +94,7 @@ class ConfigBase:
|
||||
raise TypeError(
|
||||
f"Expected {len(field_type_args)} items for {field_type.__name__}, got {len(value)}"
|
||||
)
|
||||
return tuple(cls._convert_field(item, arg) for item, arg in zip(value, field_type_args))
|
||||
return tuple(cls._convert_field(item, arg) for item, arg in zip(value, field_type_args, strict=False))
|
||||
|
||||
if field_origin_type is dict:
|
||||
# 检查提供的value是否为dict
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Literal
|
||||
import re
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
from src.config.config_base import ConfigBase
|
||||
|
||||
"""
|
||||
@@ -34,21 +35,16 @@ class PersonalityConfig(ConfigBase):
|
||||
personality_core: str
|
||||
"""核心人格"""
|
||||
|
||||
personality_sides: list[str] = field(default_factory=lambda: [])
|
||||
personality_side: str
|
||||
"""人格侧写"""
|
||||
|
||||
identity: str = ""
|
||||
"""身份特征"""
|
||||
|
||||
compress_personality: bool = True
|
||||
"""是否压缩人格,压缩后会精简人格信息,节省token消耗并提高回复性能,但是会丢失一些信息,如果人设不长,可以关闭"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class IdentityConfig(ConfigBase):
|
||||
"""个体特征配置类"""
|
||||
|
||||
identity_detail: list[str] = field(default_factory=lambda: [])
|
||||
"""身份特征"""
|
||||
|
||||
compress_indentity: bool = True
|
||||
compress_identity: bool = True
|
||||
"""是否压缩身份,压缩后会精简身份信息,节省token消耗并提高回复性能,但是会丢失一些信息,如果不长,可以关闭"""
|
||||
|
||||
|
||||
@@ -67,9 +63,6 @@ class RelationshipConfig(ConfigBase):
|
||||
class ChatConfig(ConfigBase):
|
||||
"""聊天配置类"""
|
||||
|
||||
chat_mode: str = "normal"
|
||||
"""聊天模式"""
|
||||
|
||||
max_context_size: int = 18
|
||||
"""上下文长度"""
|
||||
|
||||
@@ -85,6 +78,9 @@ class ChatConfig(ConfigBase):
|
||||
talk_frequency: float = 1
|
||||
"""回复频率阈值"""
|
||||
|
||||
use_s4u_prompt_mode: bool = False
|
||||
"""是否使用 s4u 对话构建模式,该模式会分开处理当前对话对象和其他所有对话的内容进行 prompt 构建"""
|
||||
|
||||
# 修改:基于时段的回复频率配置,改为数组格式
|
||||
time_based_talk_frequency: list[str] = field(default_factory=lambda: [])
|
||||
"""
|
||||
@@ -107,13 +103,10 @@ class ChatConfig(ConfigBase):
|
||||
表示从该时间开始使用该频率,直到下一个时间点
|
||||
"""
|
||||
|
||||
auto_focus_threshold: float = 1.0
|
||||
"""自动切换到专注聊天的阈值,越低越容易进入专注聊天"""
|
||||
focus_value: float = 1.0
|
||||
"""麦麦的专注思考能力,越低越容易专注,消耗token也越多"""
|
||||
|
||||
exit_focus_threshold: float = 1.0
|
||||
"""自动退出专注聊天的阈值,越低越容易退出专注聊天"""
|
||||
|
||||
def get_current_talk_frequency(self, chat_stream_id: str = None) -> float:
|
||||
def get_current_talk_frequency(self, chat_stream_id: Optional[str] = None) -> float:
|
||||
"""
|
||||
根据当前时间和聊天流获取对应的 talk_frequency
|
||||
|
||||
@@ -138,7 +131,7 @@ class ChatConfig(ConfigBase):
|
||||
# 如果都没有匹配,返回默认值
|
||||
return self.talk_frequency
|
||||
|
||||
def _get_time_based_frequency(self, time_freq_list: list[str]) -> float:
|
||||
def _get_time_based_frequency(self, time_freq_list: list[str]) -> Optional[float]:
|
||||
"""
|
||||
根据时间配置列表获取当前时段的频率
|
||||
|
||||
@@ -186,7 +179,7 @@ class ChatConfig(ConfigBase):
|
||||
|
||||
return current_frequency
|
||||
|
||||
def _get_stream_specific_frequency(self, chat_stream_id: str) -> float:
|
||||
def _get_stream_specific_frequency(self, chat_stream_id: str):
|
||||
"""
|
||||
获取特定聊天流在当前时间的频率
|
||||
|
||||
@@ -217,7 +210,7 @@ class ChatConfig(ConfigBase):
|
||||
|
||||
return None
|
||||
|
||||
def _parse_stream_config_to_chat_id(self, stream_config_str: str) -> str:
|
||||
def _parse_stream_config_to_chat_id(self, stream_config_str: str) -> Optional[str]:
|
||||
"""
|
||||
解析流配置字符串并生成对应的 chat_id
|
||||
|
||||
@@ -280,20 +273,6 @@ class NormalChatConfig(ConfigBase):
|
||||
at_bot_inevitable_reply: bool = False
|
||||
"""@bot 必然回复"""
|
||||
|
||||
enable_planner: bool = False
|
||||
"""是否启用动作规划器"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class FocusChatConfig(ConfigBase):
|
||||
"""专注聊天配置类"""
|
||||
|
||||
think_interval: float = 1
|
||||
"""思考间隔(秒)"""
|
||||
|
||||
consecutive_replies: float = 1
|
||||
"""连续回复能力,值越高,麦麦连续回复的概率越高"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExpressionConfig(ConfigBase):
|
||||
@@ -406,6 +385,9 @@ class MemoryConfig(ConfigBase):
|
||||
|
||||
memory_ban_words: list[str] = field(default_factory=lambda: ["表情包", "图片", "回复", "聊天记录"])
|
||||
"""不允许记忆的词列表"""
|
||||
|
||||
enable_instant_memory: bool = True
|
||||
"""是否启用即时记忆"""
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -471,6 +453,13 @@ class KeywordReactionConfig(ConfigBase):
|
||||
if not isinstance(rule, KeywordRuleConfig):
|
||||
raise ValueError(f"规则必须是KeywordRuleConfig类型,而不是{type(rule).__name__}")
|
||||
|
||||
@dataclass
|
||||
class CustomPromptConfig(ConfigBase):
|
||||
"""自定义提示词配置类"""
|
||||
|
||||
image_prompt: str = ""
|
||||
"""图片提示词"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class ResponsePostProcessConfig(ConfigBase):
|
||||
@@ -529,9 +518,6 @@ class TelemetryConfig(ConfigBase):
|
||||
class DebugConfig(ConfigBase):
|
||||
"""调试配置类"""
|
||||
|
||||
debug_show_chat_mode: bool = False
|
||||
"""是否在回复后显示当前聊天模式"""
|
||||
|
||||
show_prompt: bool = False
|
||||
"""是否显示prompt"""
|
||||
|
||||
@@ -613,6 +599,9 @@ class LPMMKnowledgeConfig(ConfigBase):
|
||||
qa_res_top_k: int = 10
|
||||
"""QA最终结果的Top K数量"""
|
||||
|
||||
embedding_dimension: int = 1024
|
||||
"""嵌入向量维度,应该与模型的输出维度一致"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelConfig(ConfigBase):
|
||||
@@ -649,3 +638,12 @@ class ModelConfig(ConfigBase):
|
||||
|
||||
embedding: dict[str, Any] = field(default_factory=lambda: {})
|
||||
"""嵌入模型配置"""
|
||||
|
||||
lpmm_entity_extract: dict[str, Any] = field(default_factory=lambda: {})
|
||||
"""LPMM实体提取模型配置"""
|
||||
|
||||
lpmm_rdf_build: dict[str, Any] = field(default_factory=lambda: {})
|
||||
"""LPMM RDF构建模型配置"""
|
||||
|
||||
lpmm_qa: dict[str, Any] = field(default_factory=lambda: {})
|
||||
"""LPMM问答模型配置"""
|
||||
|
||||
@@ -1,490 +0,0 @@
|
||||
import time
|
||||
from typing import Tuple, Optional # 增加了 Optional
|
||||
from src.common.logger import get_logger
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config
|
||||
from src.experimental.PFC.chat_observer import ChatObserver
|
||||
from src.experimental.PFC.pfc_utils import get_items_from_json
|
||||
from src.individuality.individuality import get_individuality
|
||||
from src.experimental.PFC.observation_info import ObservationInfo
|
||||
from src.experimental.PFC.conversation_info import ConversationInfo
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages
|
||||
|
||||
|
||||
logger = get_logger("pfc_action_planner")
|
||||
|
||||
|
||||
# --- 定义 Prompt 模板 ---
|
||||
|
||||
# Prompt(1): 首次回复或非连续回复时的决策 Prompt
|
||||
PROMPT_INITIAL_REPLY = """{persona_text}。现在你在参与一场QQ私聊,请根据以下【所有信息】审慎且灵活的决策下一步行动,可以回复,可以倾听,可以调取知识,甚至可以屏蔽对方:
|
||||
|
||||
【当前对话目标】
|
||||
{goals_str}
|
||||
{knowledge_info_str}
|
||||
|
||||
【最近行动历史概要】
|
||||
{action_history_summary}
|
||||
【上一次行动的详细情况和结果】
|
||||
{last_action_context}
|
||||
【时间和超时提示】
|
||||
{time_since_last_bot_message_info}{timeout_context}
|
||||
【最近的对话记录】(包括你已成功发送的消息 和 新收到的消息)
|
||||
{chat_history_text}
|
||||
|
||||
------
|
||||
可选行动类型以及解释:
|
||||
fetch_knowledge: 需要调取知识或记忆,当需要专业知识或特定信息时选择,对方若提到你不太认识的人名或实体也可以尝试选择
|
||||
listening: 倾听对方发言,当你认为对方话才说到一半,发言明显未结束时选择
|
||||
direct_reply: 直接回复对方
|
||||
rethink_goal: 思考一个对话目标,当你觉得目前对话需要目标,或当前目标不再适用,或话题卡住时选择。注意私聊的环境是灵活的,有可能需要经常选择
|
||||
end_conversation: 结束对话,对方长时间没回复或者当你觉得对话告一段落时可以选择
|
||||
block_and_ignore: 更加极端的结束对话方式,直接结束对话并在一段时间内无视对方所有发言(屏蔽),当对话让你感到十分不适,或你遭到各类骚扰时选择
|
||||
|
||||
请以JSON格式输出你的决策:
|
||||
{{
|
||||
"action": "选择的行动类型 (必须是上面列表中的一个)",
|
||||
"reason": "选择该行动的详细原因 (必须有解释你是如何根据“上一次行动结果”、“对话记录”和自身设定人设做出合理判断的)"
|
||||
}}
|
||||
|
||||
注意:请严格按照JSON格式输出,不要包含任何其他内容。"""
|
||||
|
||||
# Prompt(2): 上一次成功回复后,决定继续发言时的决策 Prompt
|
||||
PROMPT_FOLLOW_UP = """{persona_text}。现在你在参与一场QQ私聊,刚刚你已经回复了对方,请根据以下【所有信息】审慎且灵活的决策下一步行动,可以继续发送新消息,可以等待,可以倾听,可以调取知识,甚至可以屏蔽对方:
|
||||
|
||||
【当前对话目标】
|
||||
{goals_str}
|
||||
{knowledge_info_str}
|
||||
|
||||
【最近行动历史概要】
|
||||
{action_history_summary}
|
||||
【上一次行动的详细情况和结果】
|
||||
{last_action_context}
|
||||
【时间和超时提示】
|
||||
{time_since_last_bot_message_info}{timeout_context}
|
||||
【最近的对话记录】(包括你已成功发送的消息 和 新收到的消息)
|
||||
{chat_history_text}
|
||||
|
||||
------
|
||||
可选行动类型以及解释:
|
||||
fetch_knowledge: 需要调取知识,当需要专业知识或特定信息时选择,对方若提到你不太认识的人名或实体也可以尝试选择
|
||||
wait: 暂时不说话,留给对方交互空间,等待对方回复(尤其是在你刚发言后、或上次发言因重复、发言过多被拒时、或不确定做什么时,这是不错的选择)
|
||||
listening: 倾听对方发言(虽然你刚发过言,但如果对方立刻回复且明显话没说完,可以选择这个)
|
||||
send_new_message: 发送一条新消息继续对话,允许适当的追问、补充、深入话题,或开启相关新话题。**但是避免在因重复被拒后立即使用,也不要在对方没有回复的情况下过多的“消息轰炸”或重复发言**
|
||||
rethink_goal: 思考一个对话目标,当你觉得目前对话需要目标,或当前目标不再适用,或话题卡住时选择。注意私聊的环境是灵活的,有可能需要经常选择
|
||||
end_conversation: 结束对话,对方长时间没回复或者当你觉得对话告一段落时可以选择
|
||||
block_and_ignore: 更加极端的结束对话方式,直接结束对话并在一段时间内无视对方所有发言(屏蔽),当对话让你感到十分不适,或你遭到各类骚扰时选择
|
||||
|
||||
请以JSON格式输出你的决策:
|
||||
{{
|
||||
"action": "选择的行动类型 (必须是上面列表中的一个)",
|
||||
"reason": "选择该行动的详细原因 (必须有解释你是如何根据“上一次行动结果”、“对话记录”和自身设定人设做出合理判断的。请说明你为什么选择继续发言而不是等待,以及打算发送什么类型的新消息连续发言,必须记录已经发言了几次)"
|
||||
}}
|
||||
|
||||
注意:请严格按照JSON格式输出,不要包含任何其他内容。"""
|
||||
|
||||
# 新增:Prompt(3): 决定是否在结束对话前发送告别语
|
||||
PROMPT_END_DECISION = """{persona_text}。刚刚你决定结束一场 QQ 私聊。
|
||||
|
||||
【你们之前的聊天记录】
|
||||
{chat_history_text}
|
||||
|
||||
你觉得你们的对话已经完整结束了吗?有时候,在对话自然结束后再说点什么可能会有点奇怪,但有时也可能需要一条简短的消息来圆满结束。
|
||||
如果觉得确实有必要再发一条简短、自然、符合你人设的告别消息(比如 "好,下次再聊~" 或 "嗯,先这样吧"),就输出 "yes"。
|
||||
如果觉得当前状态下直接结束对话更好,没有必要再发消息,就输出 "no"。
|
||||
|
||||
请以 JSON 格式输出你的选择:
|
||||
{{
|
||||
"say_bye": "yes/no",
|
||||
"reason": "选择 yes 或 no 的原因和内心想法 (简要说明)"
|
||||
}}
|
||||
|
||||
注意:请严格按照 JSON 格式输出,不要包含任何其他内容。"""
|
||||
|
||||
|
||||
# ActionPlanner 类定义,顶格
|
||||
class ActionPlanner:
|
||||
"""行动规划器"""
|
||||
|
||||
def __init__(self, stream_id: str, private_name: str):
|
||||
self.llm = LLMRequest(
|
||||
model=global_config.llm_PFC_action_planner,
|
||||
temperature=global_config.llm_PFC_action_planner["temp"],
|
||||
request_type="action_planning",
|
||||
)
|
||||
self.personality_info = get_individuality().get_prompt(x_person=2, level=3)
|
||||
self.name = global_config.bot.nickname
|
||||
self.private_name = private_name
|
||||
self.chat_observer = ChatObserver.get_instance(stream_id, private_name)
|
||||
# self.action_planner_info = ActionPlannerInfo() # 移除未使用的变量
|
||||
|
||||
# 修改 plan 方法签名,增加 last_successful_reply_action 参数
|
||||
async def plan(
|
||||
self,
|
||||
observation_info: ObservationInfo,
|
||||
conversation_info: ConversationInfo,
|
||||
last_successful_reply_action: Optional[str],
|
||||
) -> Tuple[str, str]:
|
||||
"""规划下一步行动
|
||||
|
||||
Args:
|
||||
observation_info: 决策信息
|
||||
conversation_info: 对话信息
|
||||
last_successful_reply_action: 上一次成功的回复动作类型 ('direct_reply' 或 'send_new_message' 或 None)
|
||||
|
||||
Returns:
|
||||
Tuple[str, str]: (行动类型, 行动原因)
|
||||
"""
|
||||
# --- 获取 Bot 上次发言时间信息 ---
|
||||
# (这部分逻辑不变)
|
||||
time_since_last_bot_message_info = ""
|
||||
try:
|
||||
bot_id = str(global_config.bot.qq_account)
|
||||
if hasattr(observation_info, "chat_history") and observation_info.chat_history:
|
||||
for i in range(len(observation_info.chat_history) - 1, -1, -1):
|
||||
msg = observation_info.chat_history[i]
|
||||
if not isinstance(msg, dict):
|
||||
continue
|
||||
sender_info = msg.get("user_info", {})
|
||||
sender_id = str(sender_info.get("user_id")) if isinstance(sender_info, dict) else None
|
||||
msg_time = msg.get("time")
|
||||
if sender_id == bot_id and msg_time:
|
||||
time_diff = time.time() - msg_time
|
||||
if time_diff < 60.0:
|
||||
time_since_last_bot_message_info = (
|
||||
f"提示:你上一条成功发送的消息是在 {time_diff:.1f} 秒前。\n"
|
||||
)
|
||||
break
|
||||
else:
|
||||
logger.debug(
|
||||
f"[私聊][{self.private_name}]Observation info chat history is empty or not available for bot time check."
|
||||
)
|
||||
except AttributeError:
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]ObservationInfo object might not have chat_history attribute yet for bot time check."
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"[私聊][{self.private_name}]获取 Bot 上次发言时间时出错: {e}")
|
||||
|
||||
# --- 获取超时提示信息 ---
|
||||
# (这部分逻辑不变)
|
||||
timeout_context = ""
|
||||
try:
|
||||
if hasattr(conversation_info, "goal_list") and conversation_info.goal_list:
|
||||
last_goal_dict = conversation_info.goal_list[-1]
|
||||
if isinstance(last_goal_dict, dict) and "goal" in last_goal_dict:
|
||||
last_goal_text = last_goal_dict["goal"]
|
||||
if isinstance(last_goal_text, str) and "分钟,思考接下来要做什么" in last_goal_text:
|
||||
try:
|
||||
timeout_minutes_text = last_goal_text.split(",")[0].replace("你等待了", "")
|
||||
timeout_context = f"重要提示:对方已经长时间({timeout_minutes_text})没有回复你的消息了(这可能代表对方繁忙/不想回复/没注意到你的消息等情况,或在对方看来本次聊天已告一段落),请基于此情况规划下一步。\n"
|
||||
except Exception:
|
||||
timeout_context = "重要提示:对方已经长时间没有回复你的消息了(这可能代表对方繁忙/不想回复/没注意到你的消息等情况,或在对方看来本次聊天已告一段落),请基于此情况规划下一步。\n"
|
||||
else:
|
||||
logger.debug(
|
||||
f"[私聊][{self.private_name}]Conversation info goal_list is empty or not available for timeout check."
|
||||
)
|
||||
except AttributeError:
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]ConversationInfo object might not have goal_list attribute yet for timeout check."
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"[私聊][{self.private_name}]检查超时目标时出错: {e}")
|
||||
|
||||
# --- 构建通用 Prompt 参数 ---
|
||||
logger.debug(
|
||||
f"[私聊][{self.private_name}]开始规划行动:当前目标: {getattr(conversation_info, 'goal_list', '不可用')}"
|
||||
)
|
||||
|
||||
# 构建对话目标 (goals_str)
|
||||
goals_str = ""
|
||||
try:
|
||||
if hasattr(conversation_info, "goal_list") and conversation_info.goal_list:
|
||||
for goal_reason in conversation_info.goal_list:
|
||||
if isinstance(goal_reason, dict):
|
||||
goal = goal_reason.get("goal", "目标内容缺失")
|
||||
reasoning = goal_reason.get("reasoning", "没有明确原因")
|
||||
else:
|
||||
goal = str(goal_reason)
|
||||
reasoning = "没有明确原因"
|
||||
|
||||
goal = str(goal) if goal is not None else "目标内容缺失"
|
||||
reasoning = str(reasoning) if reasoning is not None else "没有明确原因"
|
||||
goals_str += f"- 目标:{goal}\n 原因:{reasoning}\n"
|
||||
|
||||
if not goals_str:
|
||||
goals_str = "- 目前没有明确对话目标,请考虑设定一个。\n"
|
||||
else:
|
||||
goals_str = "- 目前没有明确对话目标,请考虑设定一个。\n"
|
||||
except AttributeError:
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]ConversationInfo object might not have goal_list attribute yet."
|
||||
)
|
||||
goals_str = "- 获取对话目标时出错。\n"
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]构建对话目标字符串时出错: {e}")
|
||||
goals_str = "- 构建对话目标时出错。\n"
|
||||
|
||||
# --- 知识信息字符串构建开始 ---
|
||||
knowledge_info_str = "【已获取的相关知识和记忆】\n"
|
||||
try:
|
||||
# 检查 conversation_info 是否有 knowledge_list 并且不为空
|
||||
if hasattr(conversation_info, "knowledge_list") and conversation_info.knowledge_list:
|
||||
# 最多只显示最近的 5 条知识,防止 Prompt 过长
|
||||
recent_knowledge = conversation_info.knowledge_list[-5:]
|
||||
for i, knowledge_item in enumerate(recent_knowledge):
|
||||
if isinstance(knowledge_item, dict):
|
||||
query = knowledge_item.get("query", "未知查询")
|
||||
knowledge = knowledge_item.get("knowledge", "无知识内容")
|
||||
source = knowledge_item.get("source", "未知来源")
|
||||
# 只取知识内容的前 2000 个字,避免太长
|
||||
knowledge_snippet = knowledge[:2000] + "..." if len(knowledge) > 2000 else knowledge
|
||||
knowledge_info_str += (
|
||||
f"{i + 1}. 关于 '{query}' 的知识 (来源: {source}):\n {knowledge_snippet}\n"
|
||||
)
|
||||
else:
|
||||
# 处理列表里不是字典的异常情况
|
||||
knowledge_info_str += f"{i + 1}. 发现一条格式不正确的知识记录。\n"
|
||||
|
||||
if not recent_knowledge: # 如果 knowledge_list 存在但为空
|
||||
knowledge_info_str += "- 暂无相关知识和记忆。\n"
|
||||
|
||||
else:
|
||||
# 如果 conversation_info 没有 knowledge_list 属性,或者列表为空
|
||||
knowledge_info_str += "- 暂无相关知识记忆。\n"
|
||||
except AttributeError:
|
||||
logger.warning(f"[私聊][{self.private_name}]ConversationInfo 对象可能缺少 knowledge_list 属性。")
|
||||
knowledge_info_str += "- 获取知识列表时出错。\n"
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]构建知识信息字符串时出错: {e}")
|
||||
knowledge_info_str += "- 处理知识列表时出错。\n"
|
||||
# --- 知识信息字符串构建结束 ---
|
||||
|
||||
# 获取聊天历史记录 (chat_history_text)
|
||||
try:
|
||||
if hasattr(observation_info, "chat_history") and observation_info.chat_history:
|
||||
chat_history_text = observation_info.chat_history_str
|
||||
if not chat_history_text:
|
||||
chat_history_text = "还没有聊天记录。\n"
|
||||
else:
|
||||
chat_history_text = "还没有聊天记录。\n"
|
||||
|
||||
if hasattr(observation_info, "new_messages_count") and observation_info.new_messages_count > 0:
|
||||
if hasattr(observation_info, "unprocessed_messages") and observation_info.unprocessed_messages:
|
||||
new_messages_list = observation_info.unprocessed_messages
|
||||
new_messages_str = build_readable_messages(
|
||||
new_messages_list,
|
||||
replace_bot_name=True,
|
||||
merge_messages=False,
|
||||
timestamp_mode="relative",
|
||||
read_mark=0.0,
|
||||
)
|
||||
chat_history_text += (
|
||||
f"\n--- 以下是 {observation_info.new_messages_count} 条新消息 ---\n{new_messages_str}"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]ObservationInfo has new_messages_count > 0 but unprocessed_messages is empty or missing."
|
||||
)
|
||||
except AttributeError:
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]ObservationInfo object might be missing expected attributes for chat history."
|
||||
)
|
||||
chat_history_text = "获取聊天记录时出错。\n"
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]处理聊天记录时发生未知错误: {e}")
|
||||
chat_history_text = "处理聊天记录时出错。\n"
|
||||
|
||||
# 构建 Persona 文本 (persona_text)
|
||||
persona_text = f"你的名字是{self.name},{self.personality_info}。"
|
||||
|
||||
# 构建行动历史和上一次行动结果 (action_history_summary, last_action_context)
|
||||
# (这部分逻辑不变)
|
||||
action_history_summary = "你最近执行的行动历史:\n"
|
||||
last_action_context = "关于你【上一次尝试】的行动:\n"
|
||||
action_history_list = []
|
||||
try:
|
||||
if hasattr(conversation_info, "done_action") and conversation_info.done_action:
|
||||
action_history_list = conversation_info.done_action[-5:]
|
||||
else:
|
||||
logger.debug(f"[私聊][{self.private_name}]Conversation info done_action is empty or not available.")
|
||||
except AttributeError:
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]ConversationInfo object might not have done_action attribute yet."
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]访问行动历史时出错: {e}")
|
||||
|
||||
if not action_history_list:
|
||||
action_history_summary += "- 还没有执行过行动。\n"
|
||||
last_action_context += "- 这是你规划的第一个行动。\n"
|
||||
else:
|
||||
for i, action_data in enumerate(action_history_list):
|
||||
action_type = "未知"
|
||||
plan_reason = "未知"
|
||||
status = "未知"
|
||||
final_reason = ""
|
||||
action_time = ""
|
||||
|
||||
if isinstance(action_data, dict):
|
||||
action_type = action_data.get("action", "未知")
|
||||
plan_reason = action_data.get("plan_reason", "未知规划原因")
|
||||
status = action_data.get("status", "未知")
|
||||
final_reason = action_data.get("final_reason", "")
|
||||
action_time = action_data.get("time", "")
|
||||
elif isinstance(action_data, tuple):
|
||||
# 假设旧格式兼容
|
||||
if len(action_data) > 0:
|
||||
action_type = action_data[0]
|
||||
if len(action_data) > 1:
|
||||
plan_reason = action_data[1] # 可能是规划原因或最终原因
|
||||
if len(action_data) > 2:
|
||||
status = action_data[2]
|
||||
if status == "recall" and len(action_data) > 3:
|
||||
final_reason = action_data[3]
|
||||
elif status == "done" and action_type in ["direct_reply", "send_new_message"]:
|
||||
plan_reason = "成功发送" # 简化显示
|
||||
|
||||
reason_text = f", 失败/取消原因: {final_reason}" if final_reason else ""
|
||||
summary_line = f"- 时间:{action_time}, 尝试行动:'{action_type}', 状态:{status}{reason_text}"
|
||||
action_history_summary += summary_line + "\n"
|
||||
|
||||
if i == len(action_history_list) - 1:
|
||||
last_action_context += f"- 上次【规划】的行动是: '{action_type}'\n"
|
||||
last_action_context += f"- 当时规划的【原因】是: {plan_reason}\n"
|
||||
if status == "done":
|
||||
last_action_context += "- 该行动已【成功执行】。\n"
|
||||
# 记录这次成功的行动类型,供下次决策
|
||||
# self.last_successful_action_type = action_type # 不在这里记录,由 conversation 控制
|
||||
elif status == "recall":
|
||||
last_action_context += "- 但该行动最终【未能执行/被取消】。\n"
|
||||
if final_reason:
|
||||
last_action_context += f"- 【重要】失败/取消的具体原因是: “{final_reason}”\n"
|
||||
else:
|
||||
last_action_context += "- 【重要】失败/取消原因未明确记录。\n"
|
||||
# self.last_successful_action_type = None # 行动失败,清除记录
|
||||
else:
|
||||
last_action_context += f"- 该行动当前状态: {status}\n"
|
||||
# self.last_successful_action_type = None # 非完成状态,清除记录
|
||||
|
||||
# --- 选择 Prompt ---
|
||||
if last_successful_reply_action in ["direct_reply", "send_new_message"]:
|
||||
prompt_template = PROMPT_FOLLOW_UP
|
||||
logger.debug(f"[私聊][{self.private_name}]使用 PROMPT_FOLLOW_UP (追问决策)")
|
||||
else:
|
||||
prompt_template = PROMPT_INITIAL_REPLY
|
||||
logger.debug(f"[私聊][{self.private_name}]使用 PROMPT_INITIAL_REPLY (首次/非连续回复决策)")
|
||||
|
||||
# --- 格式化最终的 Prompt ---
|
||||
prompt = prompt_template.format(
|
||||
persona_text=persona_text,
|
||||
goals_str=goals_str if goals_str.strip() else "- 目前没有明确对话目标,请考虑设定一个。",
|
||||
action_history_summary=action_history_summary,
|
||||
last_action_context=last_action_context,
|
||||
time_since_last_bot_message_info=time_since_last_bot_message_info,
|
||||
timeout_context=timeout_context,
|
||||
chat_history_text=chat_history_text if chat_history_text.strip() else "还没有聊天记录。",
|
||||
knowledge_info_str=knowledge_info_str,
|
||||
)
|
||||
|
||||
logger.debug(f"[私聊][{self.private_name}]发送到LLM的最终提示词:\n------\n{prompt}\n------")
|
||||
try:
|
||||
content, _ = await self.llm.generate_response_async(prompt)
|
||||
logger.debug(f"[私聊][{self.private_name}]LLM (行动规划) 原始返回内容: {content}")
|
||||
|
||||
# --- 初始行动规划解析 ---
|
||||
success, initial_result = get_items_from_json(
|
||||
content,
|
||||
self.private_name,
|
||||
"action",
|
||||
"reason",
|
||||
default_values={"action": "wait", "reason": "LLM返回格式错误或未提供原因,默认等待"},
|
||||
)
|
||||
|
||||
initial_action = initial_result.get("action", "wait")
|
||||
initial_reason = initial_result.get("reason", "LLM未提供原因,默认等待")
|
||||
|
||||
# 检查是否需要进行结束对话决策 ---
|
||||
if initial_action == "end_conversation":
|
||||
logger.info(f"[私聊][{self.private_name}]初步规划结束对话,进入告别决策...")
|
||||
|
||||
# 使用新的 PROMPT_END_DECISION
|
||||
end_decision_prompt = PROMPT_END_DECISION.format(
|
||||
persona_text=persona_text, # 复用之前的 persona_text
|
||||
chat_history_text=chat_history_text, # 复用之前的 chat_history_text
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"[私聊][{self.private_name}]发送到LLM的结束决策提示词:\n------\n{end_decision_prompt}\n------"
|
||||
)
|
||||
try:
|
||||
end_content, _ = await self.llm.generate_response_async(end_decision_prompt) # 再次调用LLM
|
||||
logger.debug(f"[私聊][{self.private_name}]LLM (结束决策) 原始返回内容: {end_content}")
|
||||
|
||||
# 解析结束决策的JSON
|
||||
end_success, end_result = get_items_from_json(
|
||||
end_content,
|
||||
self.private_name,
|
||||
"say_bye",
|
||||
"reason",
|
||||
default_values={"say_bye": "no", "reason": "结束决策LLM返回格式错误,默认不告别"},
|
||||
required_types={"say_bye": str, "reason": str}, # 明确类型
|
||||
)
|
||||
|
||||
say_bye_decision = end_result.get("say_bye", "no").lower() # 转小写方便比较
|
||||
end_decision_reason = end_result.get("reason", "未提供原因")
|
||||
|
||||
if end_success and say_bye_decision == "yes":
|
||||
# 决定要告别,返回新的 'say_goodbye' 动作
|
||||
logger.info(
|
||||
f"[私聊][{self.private_name}]结束决策: yes, 准备生成告别语. 原因: {end_decision_reason}"
|
||||
)
|
||||
# 注意:这里的 reason 可以考虑拼接初始原因和结束决策原因,或者只用结束决策原因
|
||||
final_action = "say_goodbye"
|
||||
final_reason = f"决定发送告别语。决策原因: {end_decision_reason} (原结束理由: {initial_reason})"
|
||||
return final_action, final_reason
|
||||
else:
|
||||
# 决定不告别 (包括解析失败或明确说no)
|
||||
logger.info(
|
||||
f"[私聊][{self.private_name}]结束决策: no, 直接结束对话. 原因: {end_decision_reason}"
|
||||
)
|
||||
# 返回原始的 'end_conversation' 动作
|
||||
final_action = "end_conversation"
|
||||
final_reason = initial_reason # 保持原始的结束理由
|
||||
return final_action, final_reason
|
||||
|
||||
except Exception as end_e:
|
||||
logger.error(f"[私聊][{self.private_name}]调用结束决策LLM或处理结果时出错: {str(end_e)}")
|
||||
# 出错时,默认执行原始的结束对话
|
||||
logger.warning(f"[私聊][{self.private_name}]结束决策出错,将按原计划执行 end_conversation")
|
||||
return "end_conversation", initial_reason # 返回原始动作和原因
|
||||
|
||||
else:
|
||||
action = initial_action
|
||||
reason = initial_reason
|
||||
|
||||
# 验证action类型 (保持不变)
|
||||
valid_actions = [
|
||||
"direct_reply",
|
||||
"send_new_message",
|
||||
"fetch_knowledge",
|
||||
"wait",
|
||||
"listening",
|
||||
"rethink_goal",
|
||||
"end_conversation", # 仍然需要验证,因为可能从上面决策后返回
|
||||
"block_and_ignore",
|
||||
"say_goodbye", # 也要验证这个新动作
|
||||
]
|
||||
if action not in valid_actions:
|
||||
logger.warning(f"[私聊][{self.private_name}]LLM返回了未知的行动类型: '{action}',强制改为 wait")
|
||||
reason = f"(原始行动'{action}'无效,已强制改为wait) {reason}"
|
||||
action = "wait"
|
||||
|
||||
logger.info(f"[私聊][{self.private_name}]规划的行动: {action}")
|
||||
logger.info(f"[私聊][{self.private_name}]行动原因: {reason}")
|
||||
return action, reason
|
||||
|
||||
except Exception as e:
|
||||
# 外层异常处理保持不变
|
||||
logger.error(f"[私聊][{self.private_name}]规划行动时调用 LLM 或处理结果出错: {str(e)}")
|
||||
return "wait", f"行动规划处理中发生错误,暂时等待: {str(e)}"
|
||||
@@ -1,383 +0,0 @@
|
||||
import time
|
||||
import asyncio
|
||||
import traceback
|
||||
from typing import Optional, Dict, Any, List
|
||||
from src.common.logger import get_logger
|
||||
from maim_message import UserInfo
|
||||
from src.config.config import global_config
|
||||
from src.experimental.PFC.chat_states import (
|
||||
NotificationManager,
|
||||
create_new_message_notification,
|
||||
create_cold_chat_notification,
|
||||
)
|
||||
from src.experimental.PFC.message_storage import PeeweeMessageStorage
|
||||
from rich.traceback import install
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
logger = get_logger("chat_observer")
|
||||
|
||||
|
||||
class ChatObserver:
|
||||
"""聊天状态观察器"""
|
||||
|
||||
# 类级别的实例管理
|
||||
_instances: Dict[str, "ChatObserver"] = {}
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls, stream_id: str, private_name: str) -> "ChatObserver":
|
||||
"""获取或创建观察器实例
|
||||
|
||||
Args:
|
||||
stream_id: 聊天流ID
|
||||
private_name: 私聊名称
|
||||
|
||||
Returns:
|
||||
ChatObserver: 观察器实例
|
||||
"""
|
||||
if stream_id not in cls._instances:
|
||||
cls._instances[stream_id] = cls(stream_id, private_name)
|
||||
return cls._instances[stream_id]
|
||||
|
||||
def __init__(self, stream_id: str, private_name: str):
|
||||
"""初始化观察器
|
||||
|
||||
Args:
|
||||
stream_id: 聊天流ID
|
||||
"""
|
||||
self.last_check_time = None
|
||||
self.last_bot_speak_time = None
|
||||
self.last_user_speak_time = None
|
||||
if stream_id in self._instances:
|
||||
raise RuntimeError(f"ChatObserver for {stream_id} already exists. Use get_instance() instead.")
|
||||
|
||||
self.stream_id = stream_id
|
||||
self.private_name = private_name
|
||||
self.message_storage = PeeweeMessageStorage()
|
||||
|
||||
# self.last_user_speak_time: Optional[float] = None # 对方上次发言时间
|
||||
# self.last_bot_speak_time: Optional[float] = None # 机器人上次发言时间
|
||||
# self.last_check_time: float = time.time() # 上次查看聊天记录时间
|
||||
self.last_message_read: Optional[Dict[str, Any]] = None # 最后读取的消息ID
|
||||
self.last_message_time: float = time.time()
|
||||
|
||||
self.waiting_start_time: float = time.time() # 等待开始时间,初始化为当前时间
|
||||
|
||||
# 运行状态
|
||||
self._running: bool = False
|
||||
self._task: Optional[asyncio.Task] = None
|
||||
self._update_event = asyncio.Event() # 触发更新的事件
|
||||
self._update_complete = asyncio.Event() # 更新完成的事件
|
||||
|
||||
# 通知管理器
|
||||
self.notification_manager = NotificationManager()
|
||||
|
||||
# 冷场检查配置
|
||||
self.cold_chat_threshold: float = 60.0 # 60秒无消息判定为冷场
|
||||
self.last_cold_chat_check: float = time.time()
|
||||
self.is_cold_chat_state: bool = False
|
||||
|
||||
self.update_event = asyncio.Event()
|
||||
self.update_interval = 2 # 更新间隔(秒)
|
||||
self.message_cache = []
|
||||
self.update_running = False
|
||||
|
||||
async def check(self) -> bool:
|
||||
"""检查距离上一次观察之后是否有了新消息
|
||||
|
||||
Returns:
|
||||
bool: 是否有新消息
|
||||
"""
|
||||
logger.debug(f"[私聊][{self.private_name}]检查距离上一次观察之后是否有了新消息: {self.last_check_time}")
|
||||
|
||||
new_message_exists = await self.message_storage.has_new_messages(self.stream_id, self.last_check_time)
|
||||
|
||||
if new_message_exists:
|
||||
logger.debug(f"[私聊][{self.private_name}]发现新消息")
|
||||
self.last_check_time = time.time()
|
||||
|
||||
return new_message_exists
|
||||
|
||||
async def _add_message_to_history(self, message: Dict[str, Any]):
|
||||
"""添加消息到历史记录并发送通知
|
||||
|
||||
Args:
|
||||
message: 消息数据
|
||||
"""
|
||||
try:
|
||||
# 发送新消息通知
|
||||
notification = create_new_message_notification(
|
||||
sender="chat_observer", target="observation_info", message=message
|
||||
)
|
||||
# print(self.notification_manager)
|
||||
await self.notification_manager.send_notification(notification)
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]添加消息到历史记录时出错: {e}")
|
||||
print(traceback.format_exc())
|
||||
|
||||
# 检查并更新冷场状态
|
||||
await self._check_cold_chat()
|
||||
|
||||
async def _check_cold_chat(self):
|
||||
"""检查是否处于冷场状态并发送通知"""
|
||||
current_time = time.time()
|
||||
|
||||
# 每10秒检查一次冷场状态
|
||||
if current_time - self.last_cold_chat_check < 10:
|
||||
return
|
||||
|
||||
self.last_cold_chat_check = current_time
|
||||
|
||||
# 判断是否冷场
|
||||
is_cold = (
|
||||
True
|
||||
if self.last_message_time is None
|
||||
else (current_time - self.last_message_time) > self.cold_chat_threshold
|
||||
)
|
||||
|
||||
# 如果冷场状态发生变化,发送通知
|
||||
if is_cold != self.is_cold_chat_state:
|
||||
self.is_cold_chat_state = is_cold
|
||||
notification = create_cold_chat_notification(sender="chat_observer", target="pfc", is_cold=is_cold)
|
||||
await self.notification_manager.send_notification(notification)
|
||||
|
||||
def new_message_after(self, time_point: float) -> bool:
|
||||
"""判断是否在指定时间点后有新消息
|
||||
|
||||
Args:
|
||||
time_point: 时间戳
|
||||
|
||||
Returns:
|
||||
bool: 是否有新消息
|
||||
"""
|
||||
|
||||
if self.last_message_time is None:
|
||||
logger.debug(f"[私聊][{self.private_name}]没有最后消息时间,返回 False")
|
||||
return False
|
||||
|
||||
has_new = self.last_message_time > time_point
|
||||
logger.debug(
|
||||
f"[私聊][{self.private_name}]判断是否在指定时间点后有新消息: {self.last_message_time} > {time_point} = {has_new}"
|
||||
)
|
||||
return has_new
|
||||
|
||||
def get_message_history(
|
||||
self,
|
||||
start_time: Optional[float] = None,
|
||||
end_time: Optional[float] = None,
|
||||
limit: Optional[int] = None,
|
||||
user_id: Optional[str] = None,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""获取消息历史
|
||||
|
||||
Args:
|
||||
start_time: 开始时间戳
|
||||
end_time: 结束时间戳
|
||||
limit: 限制返回消息数量
|
||||
user_id: 指定用户ID
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 消息列表
|
||||
"""
|
||||
filtered_messages = self.message_history
|
||||
|
||||
if start_time is not None:
|
||||
filtered_messages = [m for m in filtered_messages if m["time"] >= start_time]
|
||||
|
||||
if end_time is not None:
|
||||
filtered_messages = [m for m in filtered_messages if m["time"] <= end_time]
|
||||
|
||||
if user_id is not None:
|
||||
filtered_messages = [
|
||||
m for m in filtered_messages if UserInfo.from_dict(m.get("user_info", {})).user_id == user_id
|
||||
]
|
||||
|
||||
if limit is not None:
|
||||
filtered_messages = filtered_messages[-limit:]
|
||||
|
||||
return filtered_messages
|
||||
|
||||
async def _fetch_new_messages(self) -> List[Dict[str, Any]]:
|
||||
"""获取新消息
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 新消息列表
|
||||
"""
|
||||
new_messages = await self.message_storage.get_messages_after(self.stream_id, self.last_message_time)
|
||||
|
||||
if new_messages:
|
||||
self.last_message_read = new_messages[-1]
|
||||
self.last_message_time = new_messages[-1]["time"]
|
||||
|
||||
# print(f"获取数据库中找到的新消息: {new_messages}")
|
||||
|
||||
return new_messages
|
||||
|
||||
async def _fetch_new_messages_before(self, time_point: float) -> List[Dict[str, Any]]:
|
||||
"""获取指定时间点之前的消息
|
||||
|
||||
Args:
|
||||
time_point: 时间戳
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 最多5条消息
|
||||
"""
|
||||
new_messages = await self.message_storage.get_messages_before(self.stream_id, time_point)
|
||||
|
||||
if new_messages:
|
||||
self.last_message_read = new_messages[-1]["message_id"]
|
||||
|
||||
logger.debug(f"[私聊][{self.private_name}]获取指定时间点111之前的消息: {new_messages}")
|
||||
|
||||
return new_messages
|
||||
|
||||
"""主要观察循环"""
|
||||
|
||||
async def _update_loop(self):
|
||||
"""更新循环"""
|
||||
# try:
|
||||
# start_time = time.time()
|
||||
# messages = await self._fetch_new_messages_before(start_time)
|
||||
# for message in messages:
|
||||
# await self._add_message_to_history(message)
|
||||
# logger.debug(f"[私聊][{self.private_name}]缓冲消息: {messages}")
|
||||
# except Exception as e:
|
||||
# logger.error(f"[私聊][{self.private_name}]缓冲消息出错: {e}")
|
||||
|
||||
while self._running:
|
||||
try:
|
||||
# 等待事件或超时(1秒)
|
||||
try:
|
||||
# print("等待事件")
|
||||
await asyncio.wait_for(self._update_event.wait(), timeout=1)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
# print("超时")
|
||||
pass # 超时后也执行一次检查
|
||||
|
||||
self._update_event.clear() # 重置触发事件
|
||||
self._update_complete.clear() # 重置完成事件
|
||||
|
||||
# 获取新消息
|
||||
new_messages = await self._fetch_new_messages()
|
||||
|
||||
if new_messages:
|
||||
# 处理新消息
|
||||
for message in new_messages:
|
||||
await self._add_message_to_history(message)
|
||||
|
||||
# 设置完成事件
|
||||
self._update_complete.set()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]更新循环出错: {e}")
|
||||
logger.error(f"[私聊][{self.private_name}]{traceback.format_exc()}")
|
||||
self._update_complete.set() # 即使出错也要设置完成事件
|
||||
|
||||
def trigger_update(self):
|
||||
"""触发一次立即更新"""
|
||||
self._update_event.set()
|
||||
|
||||
async def wait_for_update(self, timeout: float = 5.0) -> bool:
|
||||
"""等待更新完成
|
||||
|
||||
Args:
|
||||
timeout: 超时时间(秒)
|
||||
|
||||
Returns:
|
||||
bool: 是否成功完成更新(False表示超时)
|
||||
"""
|
||||
try:
|
||||
await asyncio.wait_for(self._update_complete.wait(), timeout=timeout)
|
||||
return True
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"[私聊][{self.private_name}]等待更新完成超时({timeout}秒)")
|
||||
return False
|
||||
|
||||
def start(self):
|
||||
"""启动观察器"""
|
||||
if self._running:
|
||||
return
|
||||
|
||||
self._running = True
|
||||
self._task = asyncio.create_task(self._update_loop())
|
||||
logger.debug(f"[私聊][{self.private_name}]ChatObserver for {self.stream_id} started")
|
||||
|
||||
def stop(self):
|
||||
"""停止观察器"""
|
||||
self._running = False
|
||||
self._update_event.set() # 设置事件以解除等待
|
||||
self._update_complete.set() # 设置完成事件以解除等待
|
||||
if self._task:
|
||||
self._task.cancel()
|
||||
logger.debug(f"[私聊][{self.private_name}]ChatObserver for {self.stream_id} stopped")
|
||||
|
||||
async def process_chat_history(self, messages: list):
|
||||
"""处理聊天历史
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
"""
|
||||
self.update_check_time()
|
||||
|
||||
for msg in messages:
|
||||
try:
|
||||
user_info = UserInfo.from_dict(msg.get("user_info", {}))
|
||||
if user_info.user_id == global_config.bot.qq_account:
|
||||
self.update_bot_speak_time(msg["time"])
|
||||
else:
|
||||
self.update_user_speak_time(msg["time"])
|
||||
except Exception as e:
|
||||
logger.warning(f"[私聊][{self.private_name}]处理消息时间时出错: {e}")
|
||||
continue
|
||||
|
||||
def update_check_time(self):
|
||||
"""更新查看时间"""
|
||||
self.last_check_time = time.time()
|
||||
|
||||
def update_bot_speak_time(self, speak_time: Optional[float] = None):
|
||||
"""更新机器人说话时间"""
|
||||
self.last_bot_speak_time = speak_time or time.time()
|
||||
|
||||
def update_user_speak_time(self, speak_time: Optional[float] = None):
|
||||
"""更新用户说话时间"""
|
||||
self.last_user_speak_time = speak_time or time.time()
|
||||
|
||||
def get_time_info(self) -> str:
|
||||
"""获取时间信息文本"""
|
||||
current_time = time.time()
|
||||
time_info = ""
|
||||
|
||||
if self.last_bot_speak_time:
|
||||
bot_speak_ago = current_time - self.last_bot_speak_time
|
||||
time_info += f"\n距离你上次发言已经过去了{int(bot_speak_ago)}秒"
|
||||
|
||||
if self.last_user_speak_time:
|
||||
user_speak_ago = current_time - self.last_user_speak_time
|
||||
time_info += f"\n距离对方上次发言已经过去了{int(user_speak_ago)}秒"
|
||||
|
||||
return time_info
|
||||
|
||||
def get_cached_messages(self, limit: int = 50) -> List[Dict[str, Any]]:
|
||||
"""获取缓存的消息历史
|
||||
|
||||
Args:
|
||||
limit: 获取的最大消息数量,默认50
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 缓存的消息历史列表
|
||||
"""
|
||||
return self.message_cache[-limit:]
|
||||
|
||||
def get_last_message(self) -> Optional[Dict[str, Any]]:
|
||||
"""获取最后一条消息
|
||||
|
||||
Returns:
|
||||
Optional[Dict[str, Any]]: 最后一条消息,如果没有则返回None
|
||||
"""
|
||||
if not self.message_cache:
|
||||
return None
|
||||
return self.message_cache[-1]
|
||||
|
||||
def __str__(self):
|
||||
return f"ChatObserver for {self.stream_id}"
|
||||
@@ -1,290 +0,0 @@
|
||||
from enum import Enum, auto
|
||||
from typing import Optional, Dict, Any, List, Set
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class ChatState(Enum):
|
||||
"""聊天状态枚举"""
|
||||
|
||||
NORMAL = auto() # 正常状态
|
||||
NEW_MESSAGE = auto() # 有新消息
|
||||
COLD_CHAT = auto() # 冷场状态
|
||||
ACTIVE_CHAT = auto() # 活跃状态
|
||||
BOT_SPEAKING = auto() # 机器人正在说话
|
||||
USER_SPEAKING = auto() # 用户正在说话
|
||||
SILENT = auto() # 沉默状态
|
||||
ERROR = auto() # 错误状态
|
||||
|
||||
|
||||
class NotificationType(Enum):
|
||||
"""通知类型枚举"""
|
||||
|
||||
NEW_MESSAGE = auto() # 新消息通知
|
||||
COLD_CHAT = auto() # 冷场通知
|
||||
ACTIVE_CHAT = auto() # 活跃通知
|
||||
BOT_SPEAKING = auto() # 机器人说话通知
|
||||
USER_SPEAKING = auto() # 用户说话通知
|
||||
MESSAGE_DELETED = auto() # 消息删除通知
|
||||
USER_JOINED = auto() # 用户加入通知
|
||||
USER_LEFT = auto() # 用户离开通知
|
||||
ERROR = auto() # 错误通知
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatStateInfo:
|
||||
"""聊天状态信息"""
|
||||
|
||||
state: ChatState
|
||||
last_message_time: Optional[float] = None
|
||||
last_message_content: Optional[str] = None
|
||||
last_speaker: Optional[str] = None
|
||||
message_count: int = 0
|
||||
cold_duration: float = 0.0 # 冷场持续时间(秒)
|
||||
active_duration: float = 0.0 # 活跃持续时间(秒)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Notification:
|
||||
"""通知基类"""
|
||||
|
||||
type: NotificationType
|
||||
timestamp: float
|
||||
sender: str # 发送者标识
|
||||
target: str # 接收者标识
|
||||
data: Dict[str, Any]
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""转换为字典格式"""
|
||||
return {"type": self.type.name, "timestamp": self.timestamp, "data": self.data}
|
||||
|
||||
|
||||
@dataclass
|
||||
class StateNotification(Notification):
|
||||
"""持续状态通知"""
|
||||
|
||||
is_active: bool = True
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
base_dict = super().to_dict()
|
||||
base_dict["is_active"] = self.is_active
|
||||
return base_dict
|
||||
|
||||
|
||||
class NotificationHandler(ABC):
|
||||
"""通知处理器接口"""
|
||||
|
||||
@abstractmethod
|
||||
async def handle_notification(self, notification: Notification):
|
||||
"""处理通知"""
|
||||
pass
|
||||
|
||||
|
||||
class NotificationManager:
|
||||
"""通知管理器"""
|
||||
|
||||
def __init__(self):
|
||||
# 按接收者和通知类型存储处理器
|
||||
self._handlers: Dict[str, Dict[NotificationType, List[NotificationHandler]]] = {}
|
||||
self._active_states: Set[NotificationType] = set()
|
||||
self._notification_history: List[Notification] = []
|
||||
|
||||
def register_handler(self, target: str, notification_type: NotificationType, handler: NotificationHandler):
|
||||
"""注册通知处理器
|
||||
|
||||
Args:
|
||||
target: 接收者标识(例如:"pfc")
|
||||
notification_type: 要处理的通知类型
|
||||
handler: 处理器实例
|
||||
"""
|
||||
if target not in self._handlers:
|
||||
self._handlers[target] = {}
|
||||
if notification_type not in self._handlers[target]:
|
||||
self._handlers[target][notification_type] = []
|
||||
# print(self._handlers[target][notification_type])
|
||||
self._handlers[target][notification_type].append(handler)
|
||||
# print(self._handlers[target][notification_type])
|
||||
|
||||
def unregister_handler(self, target: str, notification_type: NotificationType, handler: NotificationHandler):
|
||||
"""注销通知处理器
|
||||
|
||||
Args:
|
||||
target: 接收者标识
|
||||
notification_type: 通知类型
|
||||
handler: 要注销的处理器实例
|
||||
"""
|
||||
if target in self._handlers and notification_type in self._handlers[target]:
|
||||
handlers = self._handlers[target][notification_type]
|
||||
if handler in handlers:
|
||||
handlers.remove(handler)
|
||||
# 如果该类型的处理器列表为空,删除该类型
|
||||
if not handlers:
|
||||
del self._handlers[target][notification_type]
|
||||
# 如果该目标没有任何处理器,删除该目标
|
||||
if not self._handlers[target]:
|
||||
del self._handlers[target]
|
||||
|
||||
async def send_notification(self, notification: Notification):
|
||||
"""发送通知"""
|
||||
self._notification_history.append(notification)
|
||||
|
||||
# 如果是状态通知,更新活跃状态
|
||||
if isinstance(notification, StateNotification):
|
||||
if notification.is_active:
|
||||
self._active_states.add(notification.type)
|
||||
else:
|
||||
self._active_states.discard(notification.type)
|
||||
|
||||
# 调用目标接收者的处理器
|
||||
target = notification.target
|
||||
if target in self._handlers:
|
||||
handlers = self._handlers[target].get(notification.type, [])
|
||||
# print(handlers)
|
||||
for handler in handlers:
|
||||
# print(f"调用处理器: {handler}")
|
||||
await handler.handle_notification(notification)
|
||||
|
||||
def get_active_states(self) -> Set[NotificationType]:
|
||||
"""获取当前活跃的状态"""
|
||||
return self._active_states.copy()
|
||||
|
||||
def is_state_active(self, state_type: NotificationType) -> bool:
|
||||
"""检查特定状态是否活跃"""
|
||||
return state_type in self._active_states
|
||||
|
||||
def get_notification_history(
|
||||
self, sender: Optional[str] = None, target: Optional[str] = None, limit: Optional[int] = None
|
||||
) -> List[Notification]:
|
||||
"""获取通知历史
|
||||
|
||||
Args:
|
||||
sender: 过滤特定发送者的通知
|
||||
target: 过滤特定接收者的通知
|
||||
limit: 限制返回数量
|
||||
"""
|
||||
history = self._notification_history
|
||||
|
||||
if sender:
|
||||
history = [n for n in history if n.sender == sender]
|
||||
if target:
|
||||
history = [n for n in history if n.target == target]
|
||||
|
||||
if limit is not None:
|
||||
history = history[-limit:]
|
||||
|
||||
return history
|
||||
|
||||
def __str__(self):
|
||||
str = ""
|
||||
for target, handlers in self._handlers.items():
|
||||
for notification_type, handler_list in handlers.items():
|
||||
str += f"NotificationManager for {target} {notification_type} {handler_list}"
|
||||
return str
|
||||
|
||||
|
||||
# 一些常用的通知创建函数
|
||||
def create_new_message_notification(sender: str, target: str, message: Dict[str, Any]) -> Notification:
|
||||
"""创建新消息通知"""
|
||||
return Notification(
|
||||
type=NotificationType.NEW_MESSAGE,
|
||||
timestamp=datetime.now().timestamp(),
|
||||
sender=sender,
|
||||
target=target,
|
||||
data={
|
||||
"message_id": message.get("message_id"),
|
||||
"processed_plain_text": message.get("processed_plain_text"),
|
||||
"detailed_plain_text": message.get("detailed_plain_text"),
|
||||
"user_info": message.get("user_info"),
|
||||
"time": message.get("time"),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def create_cold_chat_notification(sender: str, target: str, is_cold: bool) -> StateNotification:
|
||||
"""创建冷场状态通知"""
|
||||
return StateNotification(
|
||||
type=NotificationType.COLD_CHAT,
|
||||
timestamp=datetime.now().timestamp(),
|
||||
sender=sender,
|
||||
target=target,
|
||||
data={"is_cold": is_cold},
|
||||
is_active=is_cold,
|
||||
)
|
||||
|
||||
|
||||
def create_active_chat_notification(sender: str, target: str, is_active: bool) -> StateNotification:
|
||||
"""创建活跃状态通知"""
|
||||
return StateNotification(
|
||||
type=NotificationType.ACTIVE_CHAT,
|
||||
timestamp=datetime.now().timestamp(),
|
||||
sender=sender,
|
||||
target=target,
|
||||
data={"is_active": is_active},
|
||||
is_active=is_active,
|
||||
)
|
||||
|
||||
|
||||
class ChatStateManager:
|
||||
"""聊天状态管理器"""
|
||||
|
||||
def __init__(self):
|
||||
self.current_state = ChatState.NORMAL
|
||||
self.state_info = ChatStateInfo(state=ChatState.NORMAL)
|
||||
self.state_history: list[ChatStateInfo] = []
|
||||
|
||||
def update_state(self, new_state: ChatState, **kwargs):
|
||||
"""更新聊天状态
|
||||
|
||||
Args:
|
||||
new_state: 新的状态
|
||||
**kwargs: 其他状态信息
|
||||
"""
|
||||
self.current_state = new_state
|
||||
self.state_info.state = new_state
|
||||
|
||||
# 更新其他状态信息
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(self.state_info, key):
|
||||
setattr(self.state_info, key, value)
|
||||
|
||||
# 记录状态历史
|
||||
self.state_history.append(self.state_info)
|
||||
|
||||
def get_current_state_info(self) -> ChatStateInfo:
|
||||
"""获取当前状态信息"""
|
||||
return self.state_info
|
||||
|
||||
def get_state_history(self) -> list[ChatStateInfo]:
|
||||
"""获取状态历史"""
|
||||
return self.state_history
|
||||
|
||||
def is_cold_chat(self, threshold: float = 60.0) -> bool:
|
||||
"""判断是否处于冷场状态
|
||||
|
||||
Args:
|
||||
threshold: 冷场阈值(秒)
|
||||
|
||||
Returns:
|
||||
bool: 是否冷场
|
||||
"""
|
||||
if not self.state_info.last_message_time:
|
||||
return True
|
||||
|
||||
current_time = datetime.now().timestamp()
|
||||
return (current_time - self.state_info.last_message_time) > threshold
|
||||
|
||||
def is_active_chat(self, threshold: float = 5.0) -> bool:
|
||||
"""判断是否处于活跃状态
|
||||
|
||||
Args:
|
||||
threshold: 活跃阈值(秒)
|
||||
|
||||
Returns:
|
||||
bool: 是否活跃
|
||||
"""
|
||||
if not self.state_info.last_message_time:
|
||||
return False
|
||||
|
||||
current_time = datetime.now().timestamp()
|
||||
return (current_time - self.state_info.last_message_time) <= threshold
|
||||
@@ -1,701 +0,0 @@
|
||||
import time
|
||||
import asyncio
|
||||
import datetime
|
||||
|
||||
# from .message_storage import MongoDBMessageStorage
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat
|
||||
|
||||
# from ...config.config import global_config
|
||||
from typing import Dict, Any, Optional
|
||||
from src.chat.message_receive.message import Message
|
||||
from .pfc_types import ConversationState
|
||||
from .pfc import ChatObserver, GoalAnalyzer
|
||||
from .message_sender import DirectMessageSender
|
||||
from src.common.logger import get_logger
|
||||
from .action_planner import ActionPlanner
|
||||
from .observation_info import ObservationInfo
|
||||
from .conversation_info import ConversationInfo # 确保导入 ConversationInfo
|
||||
from .reply_generator import ReplyGenerator
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
from src.chat.message_receive.message import UserInfo
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from .pfc_KnowledgeFetcher import KnowledgeFetcher
|
||||
from .waiter import Waiter
|
||||
|
||||
import traceback
|
||||
from rich.traceback import install
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
logger = get_logger("pfc")
|
||||
|
||||
|
||||
class Conversation:
|
||||
"""对话类,负责管理单个对话的状态和行为"""
|
||||
|
||||
def __init__(self, stream_id: str, private_name: str):
|
||||
"""初始化对话实例
|
||||
|
||||
Args:
|
||||
stream_id: 聊天流ID
|
||||
"""
|
||||
self.stream_id = stream_id
|
||||
self.private_name = private_name
|
||||
self.state = ConversationState.INIT
|
||||
self.should_continue = False
|
||||
self.ignore_until_timestamp: Optional[float] = None
|
||||
|
||||
# 回复相关
|
||||
self.generated_reply = ""
|
||||
|
||||
async def _initialize(self):
|
||||
"""初始化实例,注册所有组件"""
|
||||
|
||||
try:
|
||||
self.action_planner = ActionPlanner(self.stream_id, self.private_name)
|
||||
self.goal_analyzer = GoalAnalyzer(self.stream_id, self.private_name)
|
||||
self.reply_generator = ReplyGenerator(self.stream_id, self.private_name)
|
||||
self.knowledge_fetcher = KnowledgeFetcher(self.private_name)
|
||||
self.waiter = Waiter(self.stream_id, self.private_name)
|
||||
self.direct_sender = DirectMessageSender(self.private_name)
|
||||
|
||||
# 获取聊天流信息
|
||||
self.chat_stream = get_chat_manager().get_stream(self.stream_id)
|
||||
|
||||
self.stop_action_planner = False
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]初始化对话实例:注册运行组件失败: {e}")
|
||||
logger.error(f"[私聊][{self.private_name}]{traceback.format_exc()}")
|
||||
raise
|
||||
|
||||
try:
|
||||
# 决策所需要的信息,包括自身自信和观察信息两部分
|
||||
# 注册观察器和观测信息
|
||||
self.chat_observer = ChatObserver.get_instance(self.stream_id, self.private_name)
|
||||
self.chat_observer.start()
|
||||
self.observation_info = ObservationInfo(self.private_name)
|
||||
self.observation_info.bind_to_chat_observer(self.chat_observer)
|
||||
# print(self.chat_observer.get_cached_messages(limit=)
|
||||
|
||||
self.conversation_info = ConversationInfo()
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]初始化对话实例:注册信息组件失败: {e}")
|
||||
logger.error(f"[私聊][{self.private_name}]{traceback.format_exc()}")
|
||||
raise
|
||||
try:
|
||||
logger.info(f"[私聊][{self.private_name}]为 {self.stream_id} 加载初始聊天记录...")
|
||||
initial_messages = get_raw_msg_before_timestamp_with_chat( #
|
||||
chat_id=self.stream_id,
|
||||
timestamp=time.time(),
|
||||
limit=30, # 加载最近30条作为初始上下文,可以调整
|
||||
)
|
||||
chat_talking_prompt = build_readable_messages(
|
||||
initial_messages,
|
||||
replace_bot_name=True,
|
||||
merge_messages=False,
|
||||
timestamp_mode="relative",
|
||||
read_mark=0.0,
|
||||
)
|
||||
if initial_messages:
|
||||
# 将加载的消息填充到 ObservationInfo 的 chat_history
|
||||
self.observation_info.chat_history = initial_messages
|
||||
self.observation_info.chat_history_str = chat_talking_prompt + "\n"
|
||||
self.observation_info.chat_history_count = len(initial_messages)
|
||||
|
||||
# 更新 ObservationInfo 中的时间戳等信息
|
||||
last_msg = initial_messages[-1]
|
||||
self.observation_info.last_message_time = last_msg.get("time")
|
||||
last_user_info = UserInfo.from_dict(last_msg.get("user_info", {}))
|
||||
self.observation_info.last_message_sender = last_user_info.user_id
|
||||
self.observation_info.last_message_content = last_msg.get("processed_plain_text", "")
|
||||
|
||||
logger.info(
|
||||
f"[私聊][{self.private_name}]成功加载 {len(initial_messages)} 条初始聊天记录。最后一条消息时间: {self.observation_info.last_message_time}"
|
||||
)
|
||||
|
||||
# 让 ChatObserver 从加载的最后一条消息之后开始同步
|
||||
self.chat_observer.last_message_time = self.observation_info.last_message_time
|
||||
self.chat_observer.last_message_read = last_msg # 更新 observer 的最后读取记录
|
||||
else:
|
||||
logger.info(f"[私聊][{self.private_name}]没有找到初始聊天记录。")
|
||||
|
||||
except Exception as load_err:
|
||||
logger.error(f"[私聊][{self.private_name}]加载初始聊天记录时出错: {load_err}")
|
||||
# 出错也要继续,只是没有历史记录而已
|
||||
# 组件准备完成,启动该论对话
|
||||
self.should_continue = True
|
||||
asyncio.create_task(self.start())
|
||||
|
||||
async def start(self):
|
||||
"""开始对话流程"""
|
||||
try:
|
||||
logger.info(f"[私聊][{self.private_name}]对话系统启动中...")
|
||||
asyncio.create_task(self._plan_and_action_loop())
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]启动对话系统失败: {e}")
|
||||
raise
|
||||
|
||||
async def _plan_and_action_loop(self):
|
||||
"""思考步,PFC核心循环模块"""
|
||||
while self.should_continue:
|
||||
# 忽略逻辑
|
||||
if self.ignore_until_timestamp and time.time() < self.ignore_until_timestamp:
|
||||
await asyncio.sleep(30)
|
||||
continue
|
||||
elif self.ignore_until_timestamp and time.time() >= self.ignore_until_timestamp:
|
||||
logger.info(f"[私聊][{self.private_name}]忽略时间已到 {self.stream_id},准备结束对话。")
|
||||
self.ignore_until_timestamp = None
|
||||
self.should_continue = False
|
||||
continue
|
||||
try:
|
||||
# --- 在规划前记录当前新消息数量 ---
|
||||
initial_new_message_count = 0
|
||||
if hasattr(self.observation_info, "new_messages_count"):
|
||||
initial_new_message_count = self.observation_info.new_messages_count + 1 # 算上麦麦自己发的那一条
|
||||
else:
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]ObservationInfo missing 'new_messages_count' before planning."
|
||||
)
|
||||
|
||||
# --- 调用 Action Planner ---
|
||||
# 传递 self.conversation_info.last_successful_reply_action
|
||||
action, reason = await self.action_planner.plan(
|
||||
self.observation_info, self.conversation_info, self.conversation_info.last_successful_reply_action
|
||||
)
|
||||
|
||||
# --- 规划后检查是否有 *更多* 新消息到达 ---
|
||||
current_new_message_count = 0
|
||||
if hasattr(self.observation_info, "new_messages_count"):
|
||||
current_new_message_count = self.observation_info.new_messages_count
|
||||
else:
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]ObservationInfo missing 'new_messages_count' after planning."
|
||||
)
|
||||
|
||||
if current_new_message_count > initial_new_message_count + 2:
|
||||
logger.info(
|
||||
f"[私聊][{self.private_name}]规划期间发现新增消息 ({initial_new_message_count} -> {current_new_message_count}),跳过本次行动,重新规划"
|
||||
)
|
||||
# 如果规划期间有新消息,也应该重置上次回复状态,因为现在要响应新消息了
|
||||
self.conversation_info.last_successful_reply_action = None
|
||||
await asyncio.sleep(0.1)
|
||||
continue
|
||||
|
||||
# 包含 send_new_message
|
||||
if initial_new_message_count > 0 and action in ["direct_reply", "send_new_message"]:
|
||||
if hasattr(self.observation_info, "clear_unprocessed_messages"):
|
||||
logger.debug(
|
||||
f"[私聊][{self.private_name}]准备执行 {action},清理 {initial_new_message_count} 条规划时已知的新消息。"
|
||||
)
|
||||
await self.observation_info.clear_unprocessed_messages()
|
||||
if hasattr(self.observation_info, "new_messages_count"):
|
||||
self.observation_info.new_messages_count = 0
|
||||
else:
|
||||
logger.error(
|
||||
f"[私聊][{self.private_name}]无法清理未处理消息: ObservationInfo 缺少 clear_unprocessed_messages 方法!"
|
||||
)
|
||||
|
||||
await self._handle_action(action, reason, self.observation_info, self.conversation_info)
|
||||
|
||||
# 检查是否需要结束对话 (逻辑不变)
|
||||
goal_ended = False
|
||||
if hasattr(self.conversation_info, "goal_list") and self.conversation_info.goal_list:
|
||||
for goal_item in self.conversation_info.goal_list:
|
||||
if isinstance(goal_item, dict):
|
||||
current_goal = goal_item.get("goal")
|
||||
|
||||
if current_goal == "结束对话":
|
||||
goal_ended = True
|
||||
break
|
||||
|
||||
if goal_ended:
|
||||
self.should_continue = False
|
||||
logger.info(f"[私聊][{self.private_name}]检测到'结束对话'目标,停止循环。")
|
||||
|
||||
except Exception as loop_err:
|
||||
logger.error(f"[私聊][{self.private_name}]PFC主循环出错: {loop_err}")
|
||||
logger.error(f"[私聊][{self.private_name}]{traceback.format_exc()}")
|
||||
await asyncio.sleep(1)
|
||||
|
||||
if self.should_continue:
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
logger.info(f"[私聊][{self.private_name}]PFC 循环结束 for stream_id: {self.stream_id}")
|
||||
|
||||
def _check_new_messages_after_planning(self):
|
||||
"""检查在规划后是否有新消息"""
|
||||
# 检查 ObservationInfo 是否已初始化并且有 new_messages_count 属性
|
||||
if not hasattr(self, "observation_info") or not hasattr(self.observation_info, "new_messages_count"):
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]ObservationInfo 未初始化或缺少 'new_messages_count' 属性,无法检查新消息。"
|
||||
)
|
||||
return False # 或者根据需要抛出错误
|
||||
|
||||
if self.observation_info.new_messages_count > 2:
|
||||
logger.info(
|
||||
f"[私聊][{self.private_name}]生成/执行动作期间收到 {self.observation_info.new_messages_count} 条新消息,取消当前动作并重新规划"
|
||||
)
|
||||
# 如果有新消息,也应该重置上次回复状态
|
||||
if hasattr(self, "conversation_info"): # 确保 conversation_info 已初始化
|
||||
self.conversation_info.last_successful_reply_action = None
|
||||
else:
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]ConversationInfo 未初始化,无法重置 last_successful_reply_action。"
|
||||
)
|
||||
return True
|
||||
return False
|
||||
|
||||
def _convert_to_message(self, msg_dict: Dict[str, Any]) -> Message:
|
||||
"""将消息字典转换为Message对象"""
|
||||
try:
|
||||
# 尝试从 msg_dict 直接获取 chat_stream,如果失败则从全局 get_chat_manager 获取
|
||||
chat_info = msg_dict.get("chat_info")
|
||||
if chat_info and isinstance(chat_info, dict):
|
||||
chat_stream = ChatStream.from_dict(chat_info)
|
||||
elif self.chat_stream: # 使用实例变量中的 chat_stream
|
||||
chat_stream = self.chat_stream
|
||||
else: # Fallback: 尝试从 manager 获取 (可能需要 stream_id)
|
||||
chat_stream = get_chat_manager().get_stream(self.stream_id)
|
||||
if not chat_stream:
|
||||
raise ValueError(f"无法确定 ChatStream for stream_id {self.stream_id}")
|
||||
|
||||
user_info = UserInfo.from_dict(msg_dict.get("user_info", {}))
|
||||
|
||||
return Message(
|
||||
message_id=msg_dict.get("message_id", f"gen_{time.time()}"), # 提供默认 ID
|
||||
chat_stream=chat_stream, # 使用确定的 chat_stream
|
||||
time=msg_dict.get("time", time.time()), # 提供默认时间
|
||||
user_info=user_info,
|
||||
processed_plain_text=msg_dict.get("processed_plain_text", ""),
|
||||
detailed_plain_text=msg_dict.get("detailed_plain_text", ""),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"[私聊][{self.private_name}]转换消息时出错: {e}")
|
||||
# 可以选择返回 None 或重新抛出异常,这里选择重新抛出以指示问题
|
||||
raise ValueError(f"无法将字典转换为 Message 对象: {e}") from e
|
||||
|
||||
async def _handle_action(
|
||||
self, action: str, reason: str, observation_info: ObservationInfo, conversation_info: ConversationInfo
|
||||
):
|
||||
"""处理规划的行动"""
|
||||
|
||||
logger.debug(f"[私聊][{self.private_name}]执行行动: {action}, 原因: {reason}")
|
||||
|
||||
# 记录action历史 (逻辑不变)
|
||||
current_action_record = {
|
||||
"action": action,
|
||||
"plan_reason": reason,
|
||||
"status": "start",
|
||||
"time": datetime.datetime.now().strftime("%H:%M:%S"),
|
||||
"final_reason": None,
|
||||
}
|
||||
# 确保 done_action 列表存在
|
||||
if not hasattr(conversation_info, "done_action"):
|
||||
conversation_info.done_action = []
|
||||
conversation_info.done_action.append(current_action_record)
|
||||
action_index = len(conversation_info.done_action) - 1
|
||||
|
||||
action_successful = False # 用于标记动作是否成功完成
|
||||
|
||||
# --- 根据不同的 action 执行 ---
|
||||
|
||||
# send_new_message 失败后执行 wait
|
||||
if action == "send_new_message":
|
||||
max_reply_attempts = 3
|
||||
reply_attempt_count = 0
|
||||
is_suitable = False
|
||||
need_replan = False
|
||||
check_reason = "未进行尝试"
|
||||
final_reply_to_send = ""
|
||||
|
||||
while reply_attempt_count < max_reply_attempts and not is_suitable:
|
||||
reply_attempt_count += 1
|
||||
logger.info(
|
||||
f"[私聊][{self.private_name}]尝试生成追问回复 (第 {reply_attempt_count}/{max_reply_attempts} 次)..."
|
||||
)
|
||||
self.state = ConversationState.GENERATING
|
||||
|
||||
# 1. 生成回复 (调用 generate 时传入 action_type)
|
||||
self.generated_reply = await self.reply_generator.generate(
|
||||
observation_info, conversation_info, action_type="send_new_message"
|
||||
)
|
||||
logger.info(
|
||||
f"[私聊][{self.private_name}]第 {reply_attempt_count} 次生成的追问回复: {self.generated_reply}"
|
||||
)
|
||||
|
||||
# 2. 检查回复 (逻辑不变)
|
||||
self.state = ConversationState.CHECKING
|
||||
try:
|
||||
current_goal_str = conversation_info.goal_list[0]["goal"] if conversation_info.goal_list else ""
|
||||
is_suitable, check_reason, need_replan = await self.reply_generator.check_reply(
|
||||
reply=self.generated_reply,
|
||||
goal=current_goal_str,
|
||||
chat_history=observation_info.chat_history,
|
||||
chat_history_str=observation_info.chat_history_str,
|
||||
retry_count=reply_attempt_count - 1,
|
||||
)
|
||||
logger.info(
|
||||
f"[私聊][{self.private_name}]第 {reply_attempt_count} 次追问检查结果: 合适={is_suitable}, 原因='{check_reason}', 需重新规划={need_replan}"
|
||||
)
|
||||
if is_suitable:
|
||||
final_reply_to_send = self.generated_reply
|
||||
break
|
||||
elif need_replan:
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]第 {reply_attempt_count} 次追问检查建议重新规划,停止尝试。原因: {check_reason}"
|
||||
)
|
||||
break
|
||||
except Exception as check_err:
|
||||
logger.error(
|
||||
f"[私聊][{self.private_name}]第 {reply_attempt_count} 次调用 ReplyChecker (追问) 时出错: {check_err}"
|
||||
)
|
||||
check_reason = f"第 {reply_attempt_count} 次检查过程出错: {check_err}"
|
||||
break
|
||||
|
||||
# 循环结束,处理最终结果
|
||||
if is_suitable:
|
||||
# 检查是否有新消息
|
||||
if self._check_new_messages_after_planning():
|
||||
logger.info(f"[私聊][{self.private_name}]生成追问回复期间收到新消息,取消发送,重新规划行动")
|
||||
conversation_info.done_action[action_index].update(
|
||||
{"status": "recall", "final_reason": f"有新消息,取消发送追问: {final_reply_to_send}"}
|
||||
)
|
||||
return # 直接返回,重新规划
|
||||
|
||||
# 发送合适的回复
|
||||
self.generated_reply = final_reply_to_send
|
||||
# --- 在这里调用 _send_reply ---
|
||||
await self._send_reply() # <--- 调用恢复后的函数
|
||||
|
||||
# 更新状态: 标记上次成功是 send_new_message
|
||||
self.conversation_info.last_successful_reply_action = "send_new_message"
|
||||
action_successful = True # 标记动作成功
|
||||
|
||||
elif need_replan:
|
||||
# 打回动作决策
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]经过 {reply_attempt_count} 次尝试,追问回复决定打回动作决策。打回原因: {check_reason}"
|
||||
)
|
||||
conversation_info.done_action[action_index].update(
|
||||
{"status": "recall", "final_reason": f"追问尝试{reply_attempt_count}次后打回: {check_reason}"}
|
||||
)
|
||||
|
||||
else:
|
||||
# 追问失败
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]经过 {reply_attempt_count} 次尝试,未能生成合适的追问回复。最终原因: {check_reason}"
|
||||
)
|
||||
conversation_info.done_action[action_index].update(
|
||||
{"status": "recall", "final_reason": f"追问尝试{reply_attempt_count}次后失败: {check_reason}"}
|
||||
)
|
||||
# 重置状态: 追问失败,下次用初始 prompt
|
||||
self.conversation_info.last_successful_reply_action = None
|
||||
|
||||
# 执行 Wait 操作
|
||||
logger.info(f"[私聊][{self.private_name}]由于无法生成合适追问回复,执行 'wait' 操作...")
|
||||
self.state = ConversationState.WAITING
|
||||
await self.waiter.wait(self.conversation_info)
|
||||
wait_action_record = {
|
||||
"action": "wait",
|
||||
"plan_reason": "因 send_new_message 多次尝试失败而执行的后备等待",
|
||||
"status": "done",
|
||||
"time": datetime.datetime.now().strftime("%H:%M:%S"),
|
||||
"final_reason": None,
|
||||
}
|
||||
conversation_info.done_action.append(wait_action_record)
|
||||
|
||||
elif action == "direct_reply":
|
||||
max_reply_attempts = 3
|
||||
reply_attempt_count = 0
|
||||
is_suitable = False
|
||||
need_replan = False
|
||||
check_reason = "未进行尝试"
|
||||
final_reply_to_send = ""
|
||||
|
||||
while reply_attempt_count < max_reply_attempts and not is_suitable:
|
||||
reply_attempt_count += 1
|
||||
logger.info(
|
||||
f"[私聊][{self.private_name}]尝试生成首次回复 (第 {reply_attempt_count}/{max_reply_attempts} 次)..."
|
||||
)
|
||||
self.state = ConversationState.GENERATING
|
||||
|
||||
# 1. 生成回复
|
||||
self.generated_reply = await self.reply_generator.generate(
|
||||
observation_info, conversation_info, action_type="direct_reply"
|
||||
)
|
||||
logger.info(
|
||||
f"[私聊][{self.private_name}]第 {reply_attempt_count} 次生成的首次回复: {self.generated_reply}"
|
||||
)
|
||||
|
||||
# 2. 检查回复
|
||||
self.state = ConversationState.CHECKING
|
||||
try:
|
||||
current_goal_str = conversation_info.goal_list[0]["goal"] if conversation_info.goal_list else ""
|
||||
is_suitable, check_reason, need_replan = await self.reply_generator.check_reply(
|
||||
reply=self.generated_reply,
|
||||
goal=current_goal_str,
|
||||
chat_history=observation_info.chat_history,
|
||||
chat_history_str=observation_info.chat_history_str,
|
||||
retry_count=reply_attempt_count - 1,
|
||||
)
|
||||
logger.info(
|
||||
f"[私聊][{self.private_name}]第 {reply_attempt_count} 次首次回复检查结果: 合适={is_suitable}, 原因='{check_reason}', 需重新规划={need_replan}"
|
||||
)
|
||||
if is_suitable:
|
||||
final_reply_to_send = self.generated_reply
|
||||
break
|
||||
elif need_replan:
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]第 {reply_attempt_count} 次首次回复检查建议重新规划,停止尝试。原因: {check_reason}"
|
||||
)
|
||||
break
|
||||
except Exception as check_err:
|
||||
logger.error(
|
||||
f"[私聊][{self.private_name}]第 {reply_attempt_count} 次调用 ReplyChecker (首次回复) 时出错: {check_err}"
|
||||
)
|
||||
check_reason = f"第 {reply_attempt_count} 次检查过程出错: {check_err}"
|
||||
break
|
||||
|
||||
# 循环结束,处理最终结果
|
||||
if is_suitable:
|
||||
# 检查是否有新消息
|
||||
if self._check_new_messages_after_planning():
|
||||
logger.info(f"[私聊][{self.private_name}]生成首次回复期间收到新消息,取消发送,重新规划行动")
|
||||
conversation_info.done_action[action_index].update(
|
||||
{"status": "recall", "final_reason": f"有新消息,取消发送首次回复: {final_reply_to_send}"}
|
||||
)
|
||||
return # 直接返回,重新规划
|
||||
|
||||
# 发送合适的回复
|
||||
self.generated_reply = final_reply_to_send
|
||||
# --- 在这里调用 _send_reply ---
|
||||
await self._send_reply() # <--- 调用恢复后的函数
|
||||
|
||||
# 更新状态: 标记上次成功是 direct_reply
|
||||
self.conversation_info.last_successful_reply_action = "direct_reply"
|
||||
action_successful = True # 标记动作成功
|
||||
|
||||
elif need_replan:
|
||||
# 打回动作决策
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]经过 {reply_attempt_count} 次尝试,首次回复决定打回动作决策。打回原因: {check_reason}"
|
||||
)
|
||||
conversation_info.done_action[action_index].update(
|
||||
{"status": "recall", "final_reason": f"首次回复尝试{reply_attempt_count}次后打回: {check_reason}"}
|
||||
)
|
||||
|
||||
else:
|
||||
# 首次回复失败
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]经过 {reply_attempt_count} 次尝试,未能生成合适的首次回复。最终原因: {check_reason}"
|
||||
)
|
||||
conversation_info.done_action[action_index].update(
|
||||
{"status": "recall", "final_reason": f"首次回复尝试{reply_attempt_count}次后失败: {check_reason}"}
|
||||
)
|
||||
# 重置状态: 首次回复失败,下次还是用初始 prompt
|
||||
self.conversation_info.last_successful_reply_action = None
|
||||
|
||||
# 执行 Wait 操作 (保持原有逻辑)
|
||||
logger.info(f"[私聊][{self.private_name}]由于无法生成合适首次回复,执行 'wait' 操作...")
|
||||
self.state = ConversationState.WAITING
|
||||
await self.waiter.wait(self.conversation_info)
|
||||
wait_action_record = {
|
||||
"action": "wait",
|
||||
"plan_reason": "因 direct_reply 多次尝试失败而执行的后备等待",
|
||||
"status": "done",
|
||||
"time": datetime.datetime.now().strftime("%H:%M:%S"),
|
||||
"final_reason": None,
|
||||
}
|
||||
conversation_info.done_action.append(wait_action_record)
|
||||
|
||||
elif action == "fetch_knowledge":
|
||||
self.state = ConversationState.FETCHING
|
||||
knowledge_query = reason
|
||||
try:
|
||||
# 检查 knowledge_fetcher 是否存在
|
||||
if not hasattr(self, "knowledge_fetcher"):
|
||||
logger.error(f"[私聊][{self.private_name}]KnowledgeFetcher 未初始化,无法获取知识。")
|
||||
raise AttributeError("KnowledgeFetcher not initialized")
|
||||
|
||||
knowledge, source = await self.knowledge_fetcher.fetch(knowledge_query, observation_info.chat_history)
|
||||
logger.info(f"[私聊][{self.private_name}]获取到知识: {knowledge[:100]}..., 来源: {source}")
|
||||
if knowledge:
|
||||
# 确保 knowledge_list 存在
|
||||
if not hasattr(conversation_info, "knowledge_list"):
|
||||
conversation_info.knowledge_list = []
|
||||
conversation_info.knowledge_list.append(
|
||||
{"query": knowledge_query, "knowledge": knowledge, "source": source}
|
||||
)
|
||||
action_successful = True
|
||||
except Exception as fetch_err:
|
||||
logger.error(f"[私聊][{self.private_name}]获取知识时出错: {str(fetch_err)}")
|
||||
conversation_info.done_action[action_index].update(
|
||||
{"status": "recall", "final_reason": f"获取知识失败: {str(fetch_err)}"}
|
||||
)
|
||||
self.conversation_info.last_successful_reply_action = None # 重置状态
|
||||
|
||||
elif action == "rethink_goal":
|
||||
self.state = ConversationState.RETHINKING
|
||||
try:
|
||||
# 检查 goal_analyzer 是否存在
|
||||
if not hasattr(self, "goal_analyzer"):
|
||||
logger.error(f"[私聊][{self.private_name}]GoalAnalyzer 未初始化,无法重新思考目标。")
|
||||
raise AttributeError("GoalAnalyzer not initialized")
|
||||
await self.goal_analyzer.analyze_goal(conversation_info, observation_info)
|
||||
action_successful = True
|
||||
except Exception as rethink_err:
|
||||
logger.error(f"[私聊][{self.private_name}]重新思考目标时出错: {rethink_err}")
|
||||
conversation_info.done_action[action_index].update(
|
||||
{"status": "recall", "final_reason": f"重新思考目标失败: {rethink_err}"}
|
||||
)
|
||||
self.conversation_info.last_successful_reply_action = None # 重置状态
|
||||
|
||||
elif action == "listening":
|
||||
self.state = ConversationState.LISTENING
|
||||
logger.info(f"[私聊][{self.private_name}]倾听对方发言...")
|
||||
try:
|
||||
# 检查 waiter 是否存在
|
||||
if not hasattr(self, "waiter"):
|
||||
logger.error(f"[私聊][{self.private_name}]Waiter 未初始化,无法倾听。")
|
||||
raise AttributeError("Waiter not initialized")
|
||||
await self.waiter.wait_listening(conversation_info)
|
||||
action_successful = True # Listening 完成就算成功
|
||||
except Exception as listen_err:
|
||||
logger.error(f"[私聊][{self.private_name}]倾听时出错: {listen_err}")
|
||||
conversation_info.done_action[action_index].update(
|
||||
{"status": "recall", "final_reason": f"倾听失败: {listen_err}"}
|
||||
)
|
||||
self.conversation_info.last_successful_reply_action = None # 重置状态
|
||||
|
||||
elif action == "say_goodbye":
|
||||
self.state = ConversationState.GENERATING # 也可以定义一个新的状态,如 ENDING
|
||||
logger.info(f"[私聊][{self.private_name}]执行行动: 生成并发送告别语...")
|
||||
try:
|
||||
# 1. 生成告别语 (使用 'say_goodbye' action_type)
|
||||
self.generated_reply = await self.reply_generator.generate(
|
||||
observation_info, conversation_info, action_type="say_goodbye"
|
||||
)
|
||||
logger.info(f"[私聊][{self.private_name}]生成的告别语: {self.generated_reply}")
|
||||
|
||||
# 2. 直接发送告别语 (不经过检查)
|
||||
if self.generated_reply: # 确保生成了内容
|
||||
await self._send_reply() # 调用发送方法
|
||||
# 发送成功后,标记动作成功
|
||||
action_successful = True
|
||||
logger.info(f"[私聊][{self.private_name}]告别语已发送。")
|
||||
else:
|
||||
logger.warning(f"[私聊][{self.private_name}]未能生成告别语内容,无法发送。")
|
||||
action_successful = False # 标记动作失败
|
||||
conversation_info.done_action[action_index].update(
|
||||
{"status": "recall", "final_reason": "未能生成告别语内容"}
|
||||
)
|
||||
|
||||
# 3. 无论是否发送成功,都准备结束对话
|
||||
self.should_continue = False
|
||||
logger.info(f"[私聊][{self.private_name}]发送告别语流程结束,即将停止对话实例。")
|
||||
|
||||
except Exception as goodbye_err:
|
||||
logger.error(f"[私聊][{self.private_name}]生成或发送告别语时出错: {goodbye_err}")
|
||||
logger.error(f"[私聊][{self.private_name}]{traceback.format_exc()}")
|
||||
# 即使出错,也结束对话
|
||||
self.should_continue = False
|
||||
action_successful = False # 标记动作失败
|
||||
conversation_info.done_action[action_index].update(
|
||||
{"status": "recall", "final_reason": f"生成或发送告别语时出错: {goodbye_err}"}
|
||||
)
|
||||
|
||||
elif action == "end_conversation":
|
||||
# 这个分支现在只会在 action_planner 最终决定不告别时被调用
|
||||
self.should_continue = False
|
||||
logger.info(f"[私聊][{self.private_name}]收到最终结束指令,停止对话...")
|
||||
action_successful = True # 标记这个指令本身是成功的
|
||||
|
||||
elif action == "block_and_ignore":
|
||||
logger.info(f"[私聊][{self.private_name}]不想再理你了...")
|
||||
ignore_duration_seconds = 10 * 60
|
||||
self.ignore_until_timestamp = time.time() + ignore_duration_seconds
|
||||
logger.info(
|
||||
f"[私聊][{self.private_name}]将忽略此对话直到: {datetime.datetime.fromtimestamp(self.ignore_until_timestamp)}"
|
||||
)
|
||||
self.state = ConversationState.IGNORED
|
||||
action_successful = True # 标记动作成功
|
||||
|
||||
else: # 对应 'wait' 动作
|
||||
self.state = ConversationState.WAITING
|
||||
logger.info(f"[私聊][{self.private_name}]等待更多信息...")
|
||||
try:
|
||||
# 检查 waiter 是否存在
|
||||
if not hasattr(self, "waiter"):
|
||||
logger.error(f"[私聊][{self.private_name}]Waiter 未初始化,无法等待。")
|
||||
raise AttributeError("Waiter not initialized")
|
||||
_timeout_occurred = await self.waiter.wait(self.conversation_info)
|
||||
action_successful = True # Wait 完成就算成功
|
||||
except Exception as wait_err:
|
||||
logger.error(f"[私聊][{self.private_name}]等待时出错: {wait_err}")
|
||||
conversation_info.done_action[action_index].update(
|
||||
{"status": "recall", "final_reason": f"等待失败: {wait_err}"}
|
||||
)
|
||||
self.conversation_info.last_successful_reply_action = None # 重置状态
|
||||
|
||||
# --- 更新 Action History 状态 ---
|
||||
# 只有当动作本身成功时,才更新状态为 done
|
||||
if action_successful:
|
||||
conversation_info.done_action[action_index].update(
|
||||
{
|
||||
"status": "done",
|
||||
"time": datetime.datetime.now().strftime("%H:%M:%S"),
|
||||
}
|
||||
)
|
||||
# 重置状态: 对于非回复类动作的成功,清除上次回复状态
|
||||
if action not in ["direct_reply", "send_new_message"]:
|
||||
self.conversation_info.last_successful_reply_action = None
|
||||
logger.debug(f"[私聊][{self.private_name}]动作 {action} 成功完成,重置 last_successful_reply_action")
|
||||
# 如果动作是 recall 状态,在各自的处理逻辑中已经更新了 done_action
|
||||
|
||||
async def _send_reply(self):
|
||||
"""发送回复"""
|
||||
if not self.generated_reply:
|
||||
logger.warning(f"[私聊][{self.private_name}]没有生成回复内容,无法发送。")
|
||||
return
|
||||
|
||||
try:
|
||||
_current_time = time.time()
|
||||
reply_content = self.generated_reply
|
||||
|
||||
# 发送消息 (确保 direct_sender 和 chat_stream 有效)
|
||||
if not hasattr(self, "direct_sender") or not self.direct_sender:
|
||||
logger.error(f"[私聊][{self.private_name}]DirectMessageSender 未初始化,无法发送回复。")
|
||||
return
|
||||
if not self.chat_stream:
|
||||
logger.error(f"[私聊][{self.private_name}]ChatStream 未初始化,无法发送回复。")
|
||||
return
|
||||
|
||||
await self.direct_sender.send_message(chat_stream=self.chat_stream, content=reply_content)
|
||||
|
||||
# 发送成功后,手动触发 observer 更新可能导致重复处理自己发送的消息
|
||||
# 更好的做法是依赖 observer 的自动轮询或数据库触发器(如果支持)
|
||||
# 暂时注释掉,观察是否影响 ObservationInfo 的更新
|
||||
# self.chat_observer.trigger_update()
|
||||
# if not await self.chat_observer.wait_for_update():
|
||||
# logger.warning(f"[私聊][{self.private_name}]等待 ChatObserver 更新完成超时")
|
||||
|
||||
self.state = ConversationState.ANALYZING # 更新状态
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]发送消息或更新状态时失败: {str(e)}")
|
||||
logger.error(f"[私聊][{self.private_name}]{traceback.format_exc()}")
|
||||
self.state = ConversationState.ANALYZING
|
||||
|
||||
async def _send_timeout_message(self):
|
||||
"""发送超时结束消息"""
|
||||
try:
|
||||
messages = self.chat_observer.get_cached_messages(limit=1)
|
||||
if not messages:
|
||||
return
|
||||
|
||||
latest_message = self._convert_to_message(messages[0])
|
||||
await self.direct_sender.send_message(
|
||||
chat_stream=self.chat_stream, content="TODO:超时消息", reply_to_message=latest_message
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]发送超时消息失败: {str(e)}")
|
||||
@@ -1,10 +0,0 @@
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class ConversationInfo:
|
||||
def __init__(self):
|
||||
self.done_action = []
|
||||
self.goal_list = []
|
||||
self.knowledge_list = []
|
||||
self.memory_list = []
|
||||
self.last_successful_reply_action: Optional[str] = None
|
||||
@@ -1,81 +0,0 @@
|
||||
import time
|
||||
from typing import Optional
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
from src.chat.message_receive.message import Message
|
||||
from maim_message import UserInfo, Seg
|
||||
from src.chat.message_receive.message import MessageSending, MessageSet
|
||||
from src.chat.message_receive.normal_message_sender import message_manager
|
||||
from src.chat.message_receive.storage import MessageStorage
|
||||
from src.config.config import global_config
|
||||
from rich.traceback import install
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
|
||||
logger = get_logger("message_sender")
|
||||
|
||||
|
||||
class DirectMessageSender:
|
||||
"""直接消息发送器"""
|
||||
|
||||
def __init__(self, private_name: str):
|
||||
self.private_name = private_name
|
||||
self.storage = MessageStorage()
|
||||
|
||||
async def send_message(
|
||||
self,
|
||||
chat_stream: ChatStream,
|
||||
content: str,
|
||||
reply_to_message: Optional[Message] = None,
|
||||
) -> None:
|
||||
"""发送消息到聊天流
|
||||
|
||||
Args:
|
||||
chat_stream: 聊天流
|
||||
content: 消息内容
|
||||
reply_to_message: 要回复的消息(可选)
|
||||
"""
|
||||
try:
|
||||
# 创建消息内容
|
||||
segments = Seg(type="seglist", data=[Seg(type="text", data=content)])
|
||||
|
||||
# 获取麦麦的信息
|
||||
bot_user_info = UserInfo(
|
||||
user_id=global_config.bot.qq_account,
|
||||
user_nickname=global_config.bot.nickname,
|
||||
platform=chat_stream.platform,
|
||||
)
|
||||
|
||||
# 用当前时间作为message_id,和之前那套sender一样
|
||||
message_id = f"dm{round(time.time(), 2)}"
|
||||
|
||||
# 构建消息对象
|
||||
message = MessageSending(
|
||||
message_id=message_id,
|
||||
chat_stream=chat_stream,
|
||||
bot_user_info=bot_user_info,
|
||||
sender_info=reply_to_message.message_info.user_info if reply_to_message else None,
|
||||
message_segment=segments,
|
||||
reply=reply_to_message,
|
||||
is_head=True,
|
||||
is_emoji=False,
|
||||
thinking_start_time=time.time(),
|
||||
)
|
||||
|
||||
# 处理消息
|
||||
await message.process()
|
||||
|
||||
# 不知道有什么用,先留下来了,和之前那套sender一样
|
||||
_message_json = message.to_dict()
|
||||
|
||||
# 发送消息
|
||||
message_set = MessageSet(chat_stream, message_id)
|
||||
message_set.add_message(message)
|
||||
await message_manager.add_message(message_set)
|
||||
await self.storage.store_message(message, chat_stream)
|
||||
logger.info(f"[私聊][{self.private_name}]PFC消息已发送: {content}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]PFC消息发送失败: {str(e)}")
|
||||
raise
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user