This commit is contained in:
墨梓柒
2025-11-13 13:24:55 +08:00
parent e78a070fbd
commit 7839acd25d
52 changed files with 1322 additions and 1408 deletions

View File

@@ -2,6 +2,7 @@
聊天内容概括器
用于累积、打包和压缩聊天记录
"""
import asyncio
import json
import time
@@ -23,6 +24,7 @@ logger = get_logger("chat_history_summarizer")
@dataclass
class MessageBatch:
"""消息批次"""
messages: List[DatabaseMessages]
start_time: float
end_time: float
@@ -31,11 +33,11 @@ class MessageBatch:
class ChatHistorySummarizer:
"""聊天内容概括器"""
def __init__(self, chat_id: str, check_interval: int = 60):
"""
初始化聊天内容概括器
Args:
chat_id: 聊天ID
check_interval: 定期检查间隔默认60秒
@@ -43,24 +45,23 @@ class ChatHistorySummarizer:
self.chat_id = chat_id
self._chat_display_name = self._get_chat_display_name()
self.log_prefix = f"[{self._chat_display_name}]"
# 记录时间点,用于计算新消息
self.last_check_time = time.time()
# 当前累积的消息批次
self.current_batch: Optional[MessageBatch] = None
# LLM请求器用于压缩聊天内容
self.summarizer_llm = LLMRequest(
model_set=model_config.model_task_config.utils,
request_type="chat_history_summarizer"
model_set=model_config.model_task_config.utils, request_type="chat_history_summarizer"
)
# 后台循环相关
self.check_interval = check_interval # 检查间隔(秒)
self._periodic_task: Optional[asyncio.Task] = None
self._running = False
def _get_chat_display_name(self) -> str:
"""获取聊天显示名称"""
try:
@@ -76,17 +77,17 @@ class ChatHistorySummarizer:
if len(self.chat_id) > 20:
return f"{self.chat_id[:8]}..."
return self.chat_id
async def process(self, current_time: Optional[float] = None):
"""
处理聊天内容概括
Args:
current_time: 当前时间戳如果为None则使用time.time()
"""
if current_time is None:
current_time = time.time()
try:
logger.info(
f"{self.log_prefix} 开始处理聊天概括,时间窗口: {self.last_check_time:.2f} -> {current_time:.2f}"
@@ -101,25 +102,23 @@ class ChatHistorySummarizer:
filter_mai=False, # 不过滤bot消息因为需要检查bot是否发言
filter_command=False,
)
if not new_messages:
# 没有新消息,检查是否需要打包
if self.current_batch and self.current_batch.messages:
await self._check_and_package(current_time)
self.last_check_time = current_time
return
# 有新消息,更新最后检查时间
self.last_check_time = current_time
# 如果有当前批次,添加新消息
if self.current_batch:
before_count = len(self.current_batch.messages)
self.current_batch.messages.extend(new_messages)
self.current_batch.end_time = current_time
logger.info(
f"{self.log_prefix} 批次更新: {before_count} -> {len(self.current_batch.messages)} 条消息"
)
logger.info(f"{self.log_prefix} 批次更新: {before_count} -> {len(self.current_batch.messages)} 条消息")
else:
# 创建新批次
self.current_batch = MessageBatch(
@@ -127,23 +126,22 @@ class ChatHistorySummarizer:
start_time=new_messages[0].time if new_messages else current_time,
end_time=current_time,
)
logger.info(
f"{self.log_prefix} 新建批次: {len(new_messages)} 条消息"
)
logger.info(f"{self.log_prefix} 新建批次: {len(new_messages)} 条消息")
# 检查是否需要打包
await self._check_and_package(current_time)
except Exception as e:
logger.error(f"{self.log_prefix} 处理聊天内容概括时出错: {e}")
import traceback
traceback.print_exc()
async def _check_and_package(self, current_time: float):
"""检查是否需要打包"""
if not self.current_batch or not self.current_batch.messages:
return
messages = self.current_batch.messages
message_count = len(messages)
last_message_time = messages[-1].time if messages else current_time
@@ -153,48 +151,48 @@ class ChatHistorySummarizer:
if time_since_last_message < 60:
time_str = f"{time_since_last_message:.1f}"
elif time_since_last_message < 3600:
time_str = f"{time_since_last_message/60:.1f}分钟"
time_str = f"{time_since_last_message / 60:.1f}分钟"
else:
time_str = f"{time_since_last_message/3600:.1f}小时"
time_str = f"{time_since_last_message / 3600:.1f}小时"
preparing_status = "" if self.current_batch.is_preparing else ""
logger.info(
f"{self.log_prefix} 批次状态检查 | 消息数: {message_count} | 距最后消息: {time_str} | 准备结束模式: {preparing_status}"
)
# 检查打包条件
should_package = False
# 条件1: 消息长度超过120直接打包
if message_count >= 120:
should_package = True
logger.info(f"{self.log_prefix} 触发打包条件: 消息数量达到 {message_count} 条(阈值: 120条")
# 条件2: 最后一条消息的时间和当前时间差>600秒直接打包
elif time_since_last_message > 600:
should_package = True
logger.info(f"{self.log_prefix} 触发打包条件: 距最后消息 {time_str}(阈值: 10分钟")
# 条件3: 消息长度超过100进入准备结束模式
elif message_count > 100:
if not self.current_batch.is_preparing:
self.current_batch.is_preparing = True
logger.info(f"{self.log_prefix} 消息数量 {message_count} 条超过阈值100条进入准备结束模式")
# 在准备结束模式下,如果最后一条消息的时间和当前时间差>10秒就打包
if time_since_last_message > 10:
should_package = True
logger.info(f"{self.log_prefix} 触发打包条件: 准备结束模式下,距最后消息 {time_str}(阈值: 10秒")
if should_package:
await self._package_and_store()
async def _package_and_store(self):
"""打包并存储聊天记录"""
if not self.current_batch or not self.current_batch.messages:
return
messages = self.current_batch.messages
start_time = self.current_batch.start_time
end_time = self.current_batch.end_time
@@ -202,12 +200,12 @@ class ChatHistorySummarizer:
logger.info(
f"{self.log_prefix} 开始打包批次 | 消息数: {len(messages)} | 时间范围: {start_time:.2f} - {end_time:.2f}"
)
# 检查是否有bot发言
# 第一条消息前推600s到最后一条消息的时间内
check_start_time = max(start_time - 600, 0)
check_end_time = end_time
# 使用包含边界的时间范围查询
bot_messages = message_api.get_messages_by_time_in_chat_inclusive(
chat_id=self.chat_id,
@@ -218,7 +216,7 @@ class ChatHistorySummarizer:
filter_mai=False,
filter_command=False,
)
# 检查是否有bot的发言
has_bot_message = False
bot_user_id = str(global_config.bot.qq_account)
@@ -226,14 +224,14 @@ class ChatHistorySummarizer:
if msg.user_info.user_id == bot_user_id:
has_bot_message = True
break
if not has_bot_message:
logger.info(
f"{self.log_prefix} 批次内无Bot发言丢弃批次 | 检查时间范围: {check_start_time:.2f} - {check_end_time:.2f}"
)
self.current_batch = None
return
# 有bot发言进行压缩和存储
try:
# 构建对话原文
@@ -245,39 +243,36 @@ class ChatHistorySummarizer:
truncate=False,
show_actions=False,
)
# 获取参与的所有人的昵称
participants_set: Set[str] = set()
for msg in messages:
# 使用 msg.user_platform扁平化字段或 msg.user_info.platform
platform = getattr(msg, 'user_platform', None) or (msg.user_info.platform if msg.user_info else None) or msg.chat_info.platform
person = Person(
platform=platform,
user_id=msg.user_info.user_id
platform = (
getattr(msg, "user_platform", None)
or (msg.user_info.platform if msg.user_info else None)
or msg.chat_info.platform
)
person = Person(platform=platform, user_id=msg.user_info.user_id)
person_name = person.person_name
if person_name:
participants_set.add(person_name)
participants = list(participants_set)
logger.info(
f"{self.log_prefix} 批次参与者: {', '.join(participants) if participants else '未知'}"
)
logger.info(f"{self.log_prefix} 批次参与者: {', '.join(participants) if participants else '未知'}")
# 使用LLM压缩聊天内容
success, theme, keywords, summary = await self._compress_with_llm(original_text)
if not success:
logger.warning(
f"{self.log_prefix} LLM压缩失败不存储到数据库 | 消息数: {len(messages)}"
)
logger.warning(f"{self.log_prefix} LLM压缩失败不存储到数据库 | 消息数: {len(messages)}")
# 清空当前批次,避免重复处理
self.current_batch = None
return
logger.info(
f"{self.log_prefix} LLM压缩完成 | 主题: {theme} | 关键词数: {len(keywords)} | 概括长度: {len(summary)}"
)
# 存储到数据库
await self._store_to_database(
start_time=start_time,
@@ -288,23 +283,24 @@ class ChatHistorySummarizer:
keywords=keywords,
summary=summary,
)
logger.info(f"{self.log_prefix} 成功打包并存储聊天记录 | 消息数: {len(messages)} | 主题: {theme}")
# 清空当前批次
self.current_batch = None
except Exception as e:
logger.error(f"{self.log_prefix} 打包和存储聊天记录时出错: {e}")
import traceback
traceback.print_exc()
# 出错时也清空批次,避免重复处理
self.current_batch = None
async def _compress_with_llm(self, original_text: str) -> tuple[bool, str, List[str], str]:
"""
使用LLM压缩聊天内容
Returns:
tuple[bool, str, List[str], str]: (是否成功, 主题, 关键词列表, 概括)
"""
@@ -325,37 +321,37 @@ class ChatHistorySummarizer:
{original_text}
请直接返回JSON不要包含其他内容。"""
try:
response, _ = await self.summarizer_llm.generate_response_async(
prompt=prompt,
temperature=0.3,
max_tokens=500,
)
# 解析JSON响应
import re
# 移除可能的markdown代码块标记
json_str = response.strip()
json_str = re.sub(r'^```json\s*', '', json_str, flags=re.MULTILINE)
json_str = re.sub(r'^```\s*', '', json_str, flags=re.MULTILINE)
json_str = re.sub(r"^```json\s*", "", json_str, flags=re.MULTILINE)
json_str = re.sub(r"^```\s*", "", json_str, flags=re.MULTILINE)
json_str = json_str.strip()
# 尝试找到JSON对象的开始和结束位置
# 查找第一个 { 和最后一个匹配的 }
start_idx = json_str.find('{')
start_idx = json_str.find("{")
if start_idx == -1:
raise ValueError("未找到JSON对象开始标记")
# 从后往前查找最后一个 }
end_idx = json_str.rfind('}')
end_idx = json_str.rfind("}")
if end_idx == -1 or end_idx <= start_idx:
raise ValueError("未找到JSON对象结束标记")
# 提取JSON字符串
json_str = json_str[start_idx:end_idx + 1]
json_str = json_str[start_idx : end_idx + 1]
# 尝试解析JSON
try:
result = json.loads(json_str)
@@ -372,7 +368,7 @@ class ChatHistorySummarizer:
if escape_next:
fixed_chars.append(char)
escape_next = False
elif char == '\\':
elif char == "\\":
fixed_chars.append(char)
escape_next = True
elif char == '"' and not escape_next:
@@ -384,27 +380,27 @@ class ChatHistorySummarizer:
else:
fixed_chars.append(char)
i += 1
json_str = ''.join(fixed_chars)
json_str = "".join(fixed_chars)
# 再次尝试解析
result = json.loads(json_str)
theme = result.get("theme", "未命名对话")
keywords = result.get("keywords", [])
summary = result.get("summary", "无概括")
# 确保keywords是列表
if isinstance(keywords, str):
keywords = [keywords]
return True, theme, keywords, summary
except Exception as e:
logger.error(f"{self.log_prefix} LLM压缩聊天内容时出错: {e}")
logger.error(f"{self.log_prefix} LLM响应: {response if 'response' in locals() else 'N/A'}")
# 返回失败标志和默认值
return False, "未命名对话", [], "压缩失败,无法生成概括"
async def _store_to_database(
self,
start_time: float,
@@ -419,7 +415,7 @@ class ChatHistorySummarizer:
try:
from src.common.database.database_model import ChatHistory
from src.plugin_system.apis import database_api
# 准备数据
data = {
"chat_id": self.chat_id,
@@ -432,7 +428,7 @@ class ChatHistorySummarizer:
"summary": summary,
"count": 0,
}
# 使用db_save存储使用start_time和chat_id作为唯一标识
# 由于可能有多条记录我们使用组合键但peewee不支持所以使用start_time作为唯一标识
# 但为了避免冲突我们使用组合键chat_id + start_time
@@ -441,28 +437,29 @@ class ChatHistorySummarizer:
ChatHistory,
data=data,
)
if saved_record:
logger.debug(f"{self.log_prefix} 成功存储聊天历史记录到数据库")
else:
logger.warning(f"{self.log_prefix} 存储聊天历史记录到数据库失败")
except Exception as e:
logger.error(f"{self.log_prefix} 存储到数据库时出错: {e}")
import traceback
traceback.print_exc()
raise
async def start(self):
"""启动后台定期检查循环"""
if self._running:
logger.warning(f"{self.log_prefix} 后台循环已在运行,无需重复启动")
return
self._running = True
self._periodic_task = asyncio.create_task(self._periodic_check_loop())
logger.info(f"{self.log_prefix} 已启动后台定期检查循环 | 检查间隔: {self.check_interval}")
async def stop(self):
"""停止后台定期检查循环"""
self._running = False
@@ -474,14 +471,14 @@ class ChatHistorySummarizer:
pass
self._periodic_task = None
logger.info(f"{self.log_prefix} 已停止后台定期检查循环")
async def _periodic_check_loop(self):
"""后台定期检查循环"""
try:
while self._running:
# 执行一次检查
await self.process()
# 等待指定间隔后再次检查
await asyncio.sleep(self.check_interval)
except asyncio.CancelledError:
@@ -490,6 +487,6 @@ class ChatHistorySummarizer:
except Exception as e:
logger.error(f"{self.log_prefix} 后台检查循环出错: {e}")
import traceback
traceback.print_exc()
self._running = False

View File

@@ -2,7 +2,7 @@ import time
import random
import re
from typing import List, Dict, Any, Tuple, Optional, Callable, Iterable
from typing import List, Dict, Any, Tuple, Optional, Callable
from rich.traceback import install
from src.config.config import global_config
@@ -568,7 +568,6 @@ def build_readable_actions(actions: List[DatabaseActionRecords], mode: str = "re
output_lines = []
current_time = time.time()
for action in actions:
action_time = action.time or current_time
action_name = action.action_name or "未知动作"
@@ -595,7 +594,6 @@ def build_readable_actions(actions: List[DatabaseActionRecords], mode: str = "re
line = f"{time_ago_str},你使用了“{action_name}”,具体内容是:“{action_prompt_display}"
output_lines.append(line)
return "\n".join(output_lines)
@@ -936,7 +934,6 @@ async def build_anonymous_messages(messages: List[DatabaseMessages]) -> str:
return formatted_string
async def get_person_id_list(messages: List[Dict[str, Any]]) -> List[str]:
"""
从消息列表中提取不重复的 person_id 列表 (忽略机器人自身)。

View File

@@ -2,6 +2,7 @@
记忆遗忘任务
每5分钟进行一次遗忘检查根据不同的遗忘阶段删除记忆
"""
import time
import random
from typing import List
@@ -15,27 +16,27 @@ logger = get_logger("memory_forget_task")
class MemoryForgetTask(AsyncTask):
"""记忆遗忘任务每5分钟执行一次"""
def __init__(self):
# 每5分钟执行一次300秒
super().__init__(task_name="Memory Forget Task", wait_before_start=0, run_interval=300)
async def run(self):
"""执行遗忘检查"""
try:
current_time = time.time()
logger.info("[记忆遗忘] 开始遗忘检查...")
# 执行4个阶段的遗忘检查
await self._forget_stage_1(current_time)
await self._forget_stage_2(current_time)
await self._forget_stage_3(current_time)
await self._forget_stage_4(current_time)
logger.info("[记忆遗忘] 遗忘检查完成")
except Exception as e:
logger.error(f"[记忆遗忘] 执行遗忘检查时出错: {e}", exc_info=True)
async def _forget_stage_1(self, current_time: float):
"""
第一次遗忘检查:
@@ -45,38 +46,34 @@ class MemoryForgetTask(AsyncTask):
try:
# 30分钟 = 1800秒
time_threshold = current_time - 1800
# 查询符合条件的记忆forget_times=0 且 end_time < time_threshold
candidates = list(
ChatHistory.select()
.where(
(ChatHistory.forget_times == 0) &
(ChatHistory.end_time < time_threshold)
)
ChatHistory.select().where((ChatHistory.forget_times == 0) & (ChatHistory.end_time < time_threshold))
)
if not candidates:
logger.debug("[记忆遗忘-阶段1] 没有符合条件的记忆")
return
logger.info(f"[记忆遗忘-阶段1] 找到 {len(candidates)} 条符合条件的记忆")
# 按count排序
candidates.sort(key=lambda x: x.count, reverse=True)
# 计算要删除的数量最高25%和最低25%
total_count = len(candidates)
delete_count = int(total_count * 0.25) # 25%
if delete_count == 0:
logger.debug("[记忆遗忘-阶段1] 删除数量为0跳过")
return
# 选择要删除的记录处理count相同的情况随机选择
to_delete = []
to_delete.extend(self._handle_same_count_random(candidates, delete_count, "high"))
to_delete.extend(self._handle_same_count_random(candidates, delete_count, "low"))
# 去重避免重复删除使用id去重
seen_ids = set()
unique_to_delete = []
@@ -85,7 +82,7 @@ class MemoryForgetTask(AsyncTask):
seen_ids.add(record.id)
unique_to_delete.append(record)
to_delete = unique_to_delete
# 删除记录并更新forget_times
deleted_count = 0
for record in to_delete:
@@ -94,22 +91,22 @@ class MemoryForgetTask(AsyncTask):
deleted_count += 1
except Exception as e:
logger.error(f"[记忆遗忘-阶段1] 删除记录失败: {e}")
# 更新剩余记录的forget_times为1
to_delete_ids = {r.id for r in to_delete}
remaining = [r for r in candidates if r.id not in to_delete_ids]
if remaining:
# 批量更新
ids_to_update = [r.id for r in remaining]
ChatHistory.update(forget_times=1).where(
ChatHistory.id.in_(ids_to_update)
).execute()
logger.info(f"[记忆遗忘-阶段1] 完成:删除了 {deleted_count} 条记忆,更新了 {len(remaining)} 条记忆的forget_times为1")
ChatHistory.update(forget_times=1).where(ChatHistory.id.in_(ids_to_update)).execute()
logger.info(
f"[记忆遗忘-阶段1] 完成:删除了 {deleted_count} 条记忆,更新了 {len(remaining)} 条记忆的forget_times为1"
)
except Exception as e:
logger.error(f"[记忆遗忘-阶段1] 执行失败: {e}", exc_info=True)
async def _forget_stage_2(self, current_time: float):
"""
第二次遗忘检查:
@@ -119,41 +116,37 @@ class MemoryForgetTask(AsyncTask):
try:
# 8小时 = 28800秒
time_threshold = current_time - 28800
# 查询符合条件的记忆forget_times=1 且 end_time < time_threshold
candidates = list(
ChatHistory.select()
.where(
(ChatHistory.forget_times == 1) &
(ChatHistory.end_time < time_threshold)
)
ChatHistory.select().where((ChatHistory.forget_times == 1) & (ChatHistory.end_time < time_threshold))
)
if not candidates:
logger.debug("[记忆遗忘-阶段2] 没有符合条件的记忆")
return
logger.info(f"[记忆遗忘-阶段2] 找到 {len(candidates)} 条符合条件的记忆")
# 按count排序
candidates.sort(key=lambda x: x.count, reverse=True)
# 计算要删除的数量最高7%和最低7%
total_count = len(candidates)
delete_count = int(total_count * 0.07) # 7%
if delete_count == 0:
logger.debug("[记忆遗忘-阶段2] 删除数量为0跳过")
return
# 选择要删除的记录
to_delete = []
to_delete.extend(self._handle_same_count_random(candidates, delete_count, "high"))
to_delete.extend(self._handle_same_count_random(candidates, delete_count, "low"))
# 去重
to_delete = list(set(to_delete))
# 删除记录
deleted_count = 0
for record in to_delete:
@@ -162,21 +155,21 @@ class MemoryForgetTask(AsyncTask):
deleted_count += 1
except Exception as e:
logger.error(f"[记忆遗忘-阶段2] 删除记录失败: {e}")
# 更新剩余记录的forget_times为2
to_delete_ids = {r.id for r in to_delete}
remaining = [r for r in candidates if r.id not in to_delete_ids]
if remaining:
ids_to_update = [r.id for r in remaining]
ChatHistory.update(forget_times=2).where(
ChatHistory.id.in_(ids_to_update)
).execute()
logger.info(f"[记忆遗忘-阶段2] 完成:删除了 {deleted_count} 条记忆,更新了 {len(remaining)} 条记忆的forget_times为2")
ChatHistory.update(forget_times=2).where(ChatHistory.id.in_(ids_to_update)).execute()
logger.info(
f"[记忆遗忘-阶段2] 完成:删除了 {deleted_count} 条记忆,更新了 {len(remaining)} 条记忆的forget_times为2"
)
except Exception as e:
logger.error(f"[记忆遗忘-阶段2] 执行失败: {e}", exc_info=True)
async def _forget_stage_3(self, current_time: float):
"""
第三次遗忘检查:
@@ -186,41 +179,37 @@ class MemoryForgetTask(AsyncTask):
try:
# 48小时 = 172800秒
time_threshold = current_time - 172800
# 查询符合条件的记忆forget_times=2 且 end_time < time_threshold
candidates = list(
ChatHistory.select()
.where(
(ChatHistory.forget_times == 2) &
(ChatHistory.end_time < time_threshold)
)
ChatHistory.select().where((ChatHistory.forget_times == 2) & (ChatHistory.end_time < time_threshold))
)
if not candidates:
logger.debug("[记忆遗忘-阶段3] 没有符合条件的记忆")
return
logger.info(f"[记忆遗忘-阶段3] 找到 {len(candidates)} 条符合条件的记忆")
# 按count排序
candidates.sort(key=lambda x: x.count, reverse=True)
# 计算要删除的数量最高5%和最低5%
total_count = len(candidates)
delete_count = int(total_count * 0.05) # 5%
if delete_count == 0:
logger.debug("[记忆遗忘-阶段3] 删除数量为0跳过")
return
# 选择要删除的记录
to_delete = []
to_delete.extend(self._handle_same_count_random(candidates, delete_count, "high"))
to_delete.extend(self._handle_same_count_random(candidates, delete_count, "low"))
# 去重
to_delete = list(set(to_delete))
# 删除记录
deleted_count = 0
for record in to_delete:
@@ -229,21 +218,21 @@ class MemoryForgetTask(AsyncTask):
deleted_count += 1
except Exception as e:
logger.error(f"[记忆遗忘-阶段3] 删除记录失败: {e}")
# 更新剩余记录的forget_times为3
to_delete_ids = {r.id for r in to_delete}
remaining = [r for r in candidates if r.id not in to_delete_ids]
if remaining:
ids_to_update = [r.id for r in remaining]
ChatHistory.update(forget_times=3).where(
ChatHistory.id.in_(ids_to_update)
).execute()
logger.info(f"[记忆遗忘-阶段3] 完成:删除了 {deleted_count} 条记忆,更新了 {len(remaining)} 条记忆的forget_times为3")
ChatHistory.update(forget_times=3).where(ChatHistory.id.in_(ids_to_update)).execute()
logger.info(
f"[记忆遗忘-阶段3] 完成:删除了 {deleted_count} 条记忆,更新了 {len(remaining)} 条记忆的forget_times为3"
)
except Exception as e:
logger.error(f"[记忆遗忘-阶段3] 执行失败: {e}", exc_info=True)
async def _forget_stage_4(self, current_time: float):
"""
第四次遗忘检查:
@@ -253,41 +242,37 @@ class MemoryForgetTask(AsyncTask):
try:
# 7天 = 604800秒
time_threshold = current_time - 604800
# 查询符合条件的记忆forget_times=3 且 end_time < time_threshold
candidates = list(
ChatHistory.select()
.where(
(ChatHistory.forget_times == 3) &
(ChatHistory.end_time < time_threshold)
)
ChatHistory.select().where((ChatHistory.forget_times == 3) & (ChatHistory.end_time < time_threshold))
)
if not candidates:
logger.debug("[记忆遗忘-阶段4] 没有符合条件的记忆")
return
logger.info(f"[记忆遗忘-阶段4] 找到 {len(candidates)} 条符合条件的记忆")
# 按count排序
candidates.sort(key=lambda x: x.count, reverse=True)
# 计算要删除的数量最高2%和最低2%
total_count = len(candidates)
delete_count = int(total_count * 0.02) # 2%
if delete_count == 0:
logger.debug("[记忆遗忘-阶段4] 删除数量为0跳过")
return
# 选择要删除的记录
to_delete = []
to_delete.extend(self._handle_same_count_random(candidates, delete_count, "high"))
to_delete.extend(self._handle_same_count_random(candidates, delete_count, "low"))
# 去重
to_delete = list(set(to_delete))
# 删除记录
deleted_count = 0
for record in to_delete:
@@ -296,38 +281,40 @@ class MemoryForgetTask(AsyncTask):
deleted_count += 1
except Exception as e:
logger.error(f"[记忆遗忘-阶段4] 删除记录失败: {e}")
# 更新剩余记录的forget_times为4
to_delete_ids = {r.id for r in to_delete}
remaining = [r for r in candidates if r.id not in to_delete_ids]
if remaining:
ids_to_update = [r.id for r in remaining]
ChatHistory.update(forget_times=4).where(
ChatHistory.id.in_(ids_to_update)
).execute()
logger.info(f"[记忆遗忘-阶段4] 完成:删除了 {deleted_count} 条记忆,更新了 {len(remaining)} 条记忆的forget_times为4")
ChatHistory.update(forget_times=4).where(ChatHistory.id.in_(ids_to_update)).execute()
logger.info(
f"[记忆遗忘-阶段4] 完成:删除了 {deleted_count} 条记忆,更新了 {len(remaining)} 条记忆的forget_times为4"
)
except Exception as e:
logger.error(f"[记忆遗忘-阶段4] 执行失败: {e}", exc_info=True)
def _handle_same_count_random(self, candidates: List[ChatHistory], delete_count: int, mode: str) -> List[ChatHistory]:
def _handle_same_count_random(
self, candidates: List[ChatHistory], delete_count: int, mode: str
) -> List[ChatHistory]:
"""
处理count相同的情况随机选择要删除的记录
Args:
candidates: 候选记录列表已按count排序
delete_count: 要删除的数量
mode: "high" 表示选择最高count的记录"low" 表示选择最低count的记录
Returns:
要删除的记录列表
"""
if not candidates or delete_count == 0:
return []
to_delete = []
if mode == "high":
# 从最高count开始选择
start_idx = 0
@@ -339,7 +326,7 @@ class MemoryForgetTask(AsyncTask):
while idx < len(candidates) and candidates[idx].count == current_count:
same_count_records.append(candidates[idx])
idx += 1
# 如果相同count的记录数量 <= 还需要删除的数量,全部选择
needed = delete_count - len(to_delete)
if len(same_count_records) <= needed:
@@ -347,9 +334,9 @@ class MemoryForgetTask(AsyncTask):
else:
# 随机选择需要的数量
to_delete.extend(random.sample(same_count_records, needed))
start_idx = idx
else: # mode == "low"
# 从最低count开始选择
start_idx = len(candidates) - 1
@@ -361,7 +348,7 @@ class MemoryForgetTask(AsyncTask):
while idx >= 0 and candidates[idx].count == current_count:
same_count_records.append(candidates[idx])
idx -= 1
# 如果相同count的记录数量 <= 还需要删除的数量,全部选择
needed = delete_count - len(to_delete)
if len(same_count_records) <= needed:
@@ -369,8 +356,7 @@ class MemoryForgetTask(AsyncTask):
else:
# 随机选择需要的数量
to_delete.extend(random.sample(same_count_records, needed))
start_idx = idx
return to_delete
start_idx = idx
return to_delete

View File

@@ -153,7 +153,7 @@ def _format_large_number(num: float | int, html: bool = False) -> str:
else:
number_part = f"{value:.1f}"
k_suffix = "K"
if html:
# HTML输出K着色为主题色并加粗大写
return f"{number_part}<span style='color: #8b5cf6; font-weight: bold;'>K</span>"
@@ -502,9 +502,13 @@ class StatisticOutputTask(AsyncTask):
}
for period_key, _ in collect_period
}
# 获取bot的QQ账号
bot_qq_account = str(global_config.bot.qq_account) if hasattr(global_config, 'bot') and hasattr(global_config.bot, 'qq_account') else ""
bot_qq_account = (
str(global_config.bot.qq_account)
if hasattr(global_config, "bot") and hasattr(global_config.bot, "qq_account")
else ""
)
query_start_timestamp = collect_period[-1][1].timestamp() # Messages.time is a DoubleField (timestamp)
for message in Messages.select().where(Messages.time >= query_start_timestamp): # type: ignore
@@ -547,7 +551,7 @@ class StatisticOutputTask(AsyncTask):
is_bot_reply = False
if bot_qq_account and message.user_id == bot_qq_account:
is_bot_reply = True
for idx, (_, period_start_dt) in enumerate(collect_period):
if message_time_ts >= period_start_dt.timestamp():
for period_key, _ in collect_period[idx:]:
@@ -588,7 +592,9 @@ class StatisticOutputTask(AsyncTask):
continue
last_all_time_stat = last_stat["stat_data"] # 上次完整统计的统计数据
last_stat_timestamp = datetime.fromtimestamp(last_stat["timestamp"]) # 上次完整统计数据的时间戳
self.stat_period = [item for item in self.stat_period if item[0] != "all_time"] # 删除"所有时间"的统计时段
self.stat_period = [
item for item in self.stat_period if item[0] != "all_time"
] # 删除"所有时间"的统计时段
self.stat_period.append(("all_time", now - last_stat_timestamp, "自部署以来的"))
except Exception as e:
logger.warning(f"加载上次完整统计数据失败,进行全量统计,错误信息:{e}")
@@ -640,12 +646,12 @@ class StatisticOutputTask(AsyncTask):
# 更新上次完整统计数据的时间戳
# 将所有defaultdict转换为普通dict以避免类型冲突
clean_stat_data = self._convert_defaultdict_to_dict(stat["all_time"])
# 将 name_mapping 中的元组转换为列表因为JSON不支持元组
json_safe_name_mapping = {}
for chat_id, (chat_name, timestamp) in self.name_mapping.items():
json_safe_name_mapping[chat_id] = [chat_name, timestamp]
local_storage["last_full_statistics"] = {
"name_mapping": json_safe_name_mapping,
"stat_data": clean_stat_data,
@@ -682,24 +688,28 @@ class StatisticOutputTask(AsyncTask):
"""
# 计算总token数从所有模型的token数中累加
total_tokens = sum(stats[TOTAL_TOK_BY_MODEL].values()) if stats[TOTAL_TOK_BY_MODEL] else 0
# 计算花费/消息数量指标每100条
cost_per_100_messages = (stats[TOTAL_COST] / stats[TOTAL_MSG_CNT] * 100) if stats[TOTAL_MSG_CNT] > 0 else 0.0
# 计算花费/时间指标(花费/小时)
online_hours = stats[ONLINE_TIME] / 3600.0 if stats[ONLINE_TIME] > 0 else 0.0
cost_per_hour = stats[TOTAL_COST] / online_hours if online_hours > 0 else 0.0
# 计算token/时间指标token/小时)
tokens_per_hour = (total_tokens / online_hours) if online_hours > 0 else 0.0
# 计算花费/回复数量指标每100条
total_replies = stats.get(TOTAL_REPLY_CNT, 0)
cost_per_100_replies = (stats[TOTAL_COST] / total_replies * 100) if total_replies > 0 else 0.0
# 计算花费/消息数量排除自己回复指标每100条
total_messages_excluding_replies = stats[TOTAL_MSG_CNT] - total_replies
cost_per_100_messages_excluding_replies = (stats[TOTAL_COST] / total_messages_excluding_replies * 100) if total_messages_excluding_replies > 0 else 0.0
cost_per_100_messages_excluding_replies = (
(stats[TOTAL_COST] / total_messages_excluding_replies * 100)
if total_messages_excluding_replies > 0
else 0.0
)
output = [
f"总在线时间: {_format_online_time(stats[ONLINE_TIME])}",
@@ -709,7 +719,9 @@ class StatisticOutputTask(AsyncTask):
f"总Token数: {_format_large_number(total_tokens)}",
f"总花费: {stats[TOTAL_COST]:.2f}¥",
f"花费/消息数量: {cost_per_100_messages:.4f}¥/100条" if stats[TOTAL_MSG_CNT] > 0 else "花费/消息数量: N/A",
f"花费/接受消息数量: {cost_per_100_messages_excluding_replies:.4f}¥/100条" if total_messages_excluding_replies > 0 else "花费/消息数量(排除回复): N/A",
f"花费/接受消息数量: {cost_per_100_messages_excluding_replies:.4f}¥/100条"
if total_messages_excluding_replies > 0
else "花费/消息数量(排除回复): N/A",
f"花费/回复消息数量: {cost_per_100_replies:.4f}¥/100条" if total_replies > 0 else "花费/回复数量: N/A",
f"花费/时间: {cost_per_hour:.2f}¥/小时" if online_hours > 0 else "花费/时间: N/A",
f"Token/时间: {_format_large_number(tokens_per_hour)}/小时" if online_hours > 0 else "Token/时间: N/A",
@@ -745,7 +757,16 @@ class StatisticOutputTask(AsyncTask):
formatted_out_tokens = _format_large_number(out_tokens)
formatted_tokens = _format_large_number(tokens)
output.append(
data_fmt.format(name, formatted_count, formatted_in_tokens, formatted_out_tokens, formatted_tokens, cost, avg_time_cost, std_time_cost)
data_fmt.format(
name,
formatted_count,
formatted_in_tokens,
formatted_out_tokens,
formatted_tokens,
cost,
avg_time_cost,
std_time_cost,
)
)
output.append("")
@@ -891,8 +912,12 @@ class StatisticOutputTask(AsyncTask):
except (IndexError, TypeError) as e:
logger.warning(f"生成HTML聊天统计时发生错误chat_id: {chat_id}, 错误: {e}")
chat_rows.append(f"<tr><td>未知聊天</td><td>{_format_large_number(count, html=True)}</td></tr>")
chat_rows_html = "\n".join(chat_rows) if chat_rows else "<tr><td colspan='2' style='text-align: center; color: #999;'>暂无数据</td></tr>"
chat_rows_html = (
"\n".join(chat_rows)
if chat_rows
else "<tr><td colspan='2' style='text-align: center; color: #999;'>暂无数据</td></tr>"
)
# 生成HTML
return f"""
<div id=\"{div_id}\" class=\"tab-content\">
@@ -1197,7 +1222,7 @@ class StatisticOutputTask(AsyncTask):
# 添加图表内容
chart_data = self._generate_chart_data(stat)
tab_content_list.append(self._generate_chart_tab(chart_data))
# 添加指标趋势图表
metrics_data = self._generate_metrics_data(now)
tab_content_list.append(self._generate_metrics_tab(metrics_data))
@@ -1772,121 +1797,125 @@ class StatisticOutputTask(AsyncTask):
def _generate_metrics_data(self, now: datetime) -> dict:
"""生成指标趋势数据"""
metrics_data = {}
# 24小时尺度1小时为单位
metrics_data["24h"] = self._collect_metrics_interval_data(now, hours=24, interval_hours=1)
# 7天尺度1天为单位
metrics_data["7d"] = self._collect_metrics_interval_data(now, hours=24*7, interval_hours=24)
metrics_data["7d"] = self._collect_metrics_interval_data(now, hours=24 * 7, interval_hours=24)
# 30天尺度1天为单位
metrics_data["30d"] = self._collect_metrics_interval_data(now, hours=24*30, interval_hours=24)
metrics_data["30d"] = self._collect_metrics_interval_data(now, hours=24 * 30, interval_hours=24)
return metrics_data
def _collect_metrics_interval_data(self, now: datetime, hours: int, interval_hours: int) -> dict:
"""收集指定时间范围内每个间隔的指标数据"""
start_time = now - timedelta(hours=hours)
time_points = []
current_time = start_time
# 生成时间点
while current_time <= now:
time_points.append(current_time)
current_time += timedelta(hours=interval_hours)
# 初始化数据结构
cost_per_100_messages = [0.0] * len(time_points) # 花费/消息数量每100条
cost_per_hour = [0.0] * len(time_points) # 花费/时间(每小时)
tokens_per_hour = [0.0] * len(time_points) # Token/时间(每小时)
cost_per_100_replies = [0.0] * len(time_points) # 花费/回复数量每100条
# 每个时间点的累计数据
total_costs = [0.0] * len(time_points)
total_tokens = [0] * len(time_points)
total_messages = [0] * len(time_points)
total_replies = [0] * len(time_points)
total_online_hours = [0.0] * len(time_points)
# 获取bot的QQ账号
bot_qq_account = str(global_config.bot.qq_account) if hasattr(global_config, 'bot') and hasattr(global_config.bot, 'qq_account') else ""
bot_qq_account = (
str(global_config.bot.qq_account)
if hasattr(global_config, "bot") and hasattr(global_config.bot, "qq_account")
else ""
)
interval_seconds = interval_hours * 3600
# 查询LLM使用记录
query_start_time = start_time
for record in LLMUsage.select().where(LLMUsage.timestamp >= query_start_time): # type: ignore
record_time = record.timestamp
# 找到对应的时间间隔索引
time_diff = (record_time - start_time).total_seconds()
interval_index = int(time_diff // interval_seconds)
if 0 <= interval_index < len(time_points):
cost = record.cost or 0.0
prompt_tokens = record.prompt_tokens or 0
completion_tokens = record.completion_tokens or 0
total_token = prompt_tokens + completion_tokens
total_costs[interval_index] += cost
total_tokens[interval_index] += total_token
# 查询消息记录
query_start_timestamp = start_time.timestamp()
for message in Messages.select().where(Messages.time >= query_start_timestamp): # type: ignore
message_time_ts = message.time
time_diff = message_time_ts - query_start_timestamp
interval_index = int(time_diff // interval_seconds)
if 0 <= interval_index < len(time_points):
total_messages[interval_index] += 1
# 检查是否是bot发送的消息回复
if bot_qq_account and message.user_id == bot_qq_account:
total_replies[interval_index] += 1
# 查询在线时间记录
for record in OnlineTime.select().where(OnlineTime.end_timestamp >= start_time): # type: ignore
record_start = record.start_timestamp
record_end = record.end_timestamp
# 找到记录覆盖的所有时间间隔
for idx, time_point in enumerate(time_points):
interval_start = time_point
interval_end = time_point + timedelta(hours=interval_hours)
# 计算重叠部分
overlap_start = max(record_start, interval_start)
overlap_end = min(record_end, interval_end)
if overlap_end > overlap_start:
overlap_hours = (overlap_end - overlap_start).total_seconds() / 3600.0
total_online_hours[idx] += overlap_hours
# 计算指标
for idx in range(len(time_points)):
# 花费/消息数量每100条
if total_messages[idx] > 0:
cost_per_100_messages[idx] = (total_costs[idx] / total_messages[idx] * 100)
cost_per_100_messages[idx] = total_costs[idx] / total_messages[idx] * 100
# 花费/时间(每小时)
if total_online_hours[idx] > 0:
cost_per_hour[idx] = (total_costs[idx] / total_online_hours[idx])
cost_per_hour[idx] = total_costs[idx] / total_online_hours[idx]
# Token/时间(每小时)
if total_online_hours[idx] > 0:
tokens_per_hour[idx] = (total_tokens[idx] / total_online_hours[idx])
tokens_per_hour[idx] = total_tokens[idx] / total_online_hours[idx]
# 花费/回复数量每100条
if total_replies[idx] > 0:
cost_per_100_replies[idx] = (total_costs[idx] / total_replies[idx] * 100)
cost_per_100_replies[idx] = total_costs[idx] / total_replies[idx] * 100
# 生成时间标签
if interval_hours == 1:
time_labels = [t.strftime("%H:%M") for t in time_points]
else:
time_labels = [t.strftime("%m-%d") for t in time_points]
return {
"time_labels": time_labels,
"cost_per_100_messages": cost_per_100_messages,
@@ -1894,7 +1923,7 @@ class StatisticOutputTask(AsyncTask):
"tokens_per_hour": tokens_per_hour,
"cost_per_100_replies": cost_per_100_replies,
}
def _generate_metrics_tab(self, metrics_data: dict) -> str:
"""生成指标趋势图表选项卡HTML内容"""
colors = {
@@ -1903,7 +1932,7 @@ class StatisticOutputTask(AsyncTask):
"tokens_per_hour": "#c7bbff",
"cost_per_100_replies": "#d9ceff",
}
return f"""
<div id="metrics" class="tab-content">
<h2>指标趋势图表</h2>

View File

@@ -4,14 +4,11 @@ import time
import jieba
import json
import ast
import numpy as np
from collections import Counter
from typing import Optional, Tuple, List, TYPE_CHECKING
from src.common.logger import get_logger
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.message_repository import find_messages, count_messages
from src.config.config import global_config, model_config
from src.chat.message_receive.message import MessageRecv
from src.chat.message_receive.chat_stream import get_chat_manager
@@ -32,10 +29,10 @@ def is_english_letter(char: str) -> bool:
def parse_platform_accounts(platforms: list[str]) -> dict[str, str]:
"""解析 platforms 列表,返回平台到账号的映射
Args:
platforms: 格式为 ["platform:account"] 的列表,如 ["tg:123456789", "wx:wxid123"]
Returns:
字典,键为平台名,值为账号
"""
@@ -49,12 +46,12 @@ def parse_platform_accounts(platforms: list[str]) -> dict[str, str]:
def get_current_platform_account(platform: str, platform_accounts: dict[str, str], qq_account: str) -> str:
"""根据当前平台获取对应的账号
Args:
platform: 当前消息的平台
platform_accounts: 从 platforms 列表解析的平台账号映射
qq_account: QQ 账号(兼容旧配置)
Returns:
当前平台对应的账号
"""
@@ -72,12 +69,12 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, bool, float
"""检查消息是否提到了机器人(统一多平台实现)"""
text = message.processed_plain_text or ""
platform = getattr(message.message_info, "platform", "") or ""
# 获取各平台账号
platforms_list = getattr(global_config.bot, "platforms", []) or []
platform_accounts = parse_platform_accounts(platforms_list)
qq_account = str(getattr(global_config.bot, "qq_account", "") or "")
# 获取当前平台对应的账号
current_account = get_current_platform_account(platform, platform_accounts, qq_account)
@@ -146,7 +143,9 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, bool, float
elif current_account:
if re.search(rf"\[回复 (.+?)\({re.escape(current_account)}\)(.+?)\],说:", text):
is_mentioned = True
elif re.search(rf"\[回复<(.+?)(?=:{re.escape(current_account)}>)\:{re.escape(current_account)}>(.+?)\],说:", text):
elif re.search(
rf"\[回复<(.+?)(?=:{re.escape(current_account)}>)\:{re.escape(current_account)}>(.+?)\],说:", text
):
is_mentioned = True
# 6) 名称/别名 提及(去除 @/回复标记后再匹配)
@@ -185,7 +184,6 @@ async def get_embedding(text, request_type="embedding") -> Optional[List[float]]
return embedding
def split_into_sentences_w_remove_punctuation(text: str) -> list[str]:
"""将文本分割成句子,并根据概率合并
1. 识别分割点(, 。 ; 空格),但如果分割点左右都是英文字母则不分割。
@@ -227,7 +225,7 @@ def split_into_sentences_w_remove_punctuation(text: str) -> list[str]:
prev_char = text[i - 1]
next_char = text[i + 1]
# 只对空格应用"不分割数字和数字、数字和英文、英文和数字、英文和英文之间的空格"规则
if char == ' ':
if char == " ":
prev_is_alnum = prev_char.isdigit() or is_english_letter(prev_char)
next_is_alnum = next_char.isdigit() or is_english_letter(next_char)
if prev_is_alnum and next_is_alnum:
@@ -340,7 +338,7 @@ def _get_random_default_reply() -> str:
"不知道",
"不晓得",
"懒得说",
"()"
"()",
]
return random.choice(default_replies)
@@ -469,7 +467,6 @@ def calculate_typing_time(
return total_time # 加上回车时间
def truncate_message(message: str, max_length=20) -> str:
"""截断消息,使其不超过指定长度"""
return f"{message[:max_length]}..." if len(message) > max_length else message
@@ -546,7 +543,6 @@ def get_western_ratio(paragraph):
return western_count / len(alnum_chars)
def translate_timestamp_to_human_readable(timestamp: float, mode: str = "normal") -> str:
# sourcery skip: merge-comparisons, merge-duplicate-blocks, switch
"""将时间戳转换为人类可读的时间格式

View File

@@ -103,14 +103,16 @@ class ImageManager:
invalid_values = ["", "None"]
# 清理 Images 表
deleted_images = Images.delete().where(
(Images.description >> None) | (Images.description << invalid_values)
).execute()
deleted_images = (
Images.delete().where((Images.description >> None) | (Images.description << invalid_values)).execute()
)
# 清理 ImageDescriptions 表
deleted_descriptions = ImageDescriptions.delete().where(
(ImageDescriptions.description >> None) | (ImageDescriptions.description << invalid_values)
).execute()
deleted_descriptions = (
ImageDescriptions.delete()
.where((ImageDescriptions.description >> None) | (ImageDescriptions.description << invalid_values))
.execute()
)
if deleted_images or deleted_descriptions:
logger.info(f"[清理完成] 删除 Images: {deleted_images} 条, ImageDescriptions: {deleted_descriptions}")