Ruff format
This commit is contained in:
@@ -13,7 +13,12 @@ from src.chat.utils.chat_message_builder import (
|
||||
)
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.bw_learner.learner_utils import filter_message_content, is_bot_message, build_context_paragraph, contains_bot_self_name
|
||||
from src.bw_learner.learner_utils import (
|
||||
filter_message_content,
|
||||
is_bot_message,
|
||||
build_context_paragraph,
|
||||
contains_bot_self_name,
|
||||
)
|
||||
from src.bw_learner.jargon_miner import miner_manager
|
||||
from json_repair import repair_json
|
||||
|
||||
@@ -77,8 +82,6 @@ def init_prompt() -> None:
|
||||
Prompt(learn_style_prompt, "learn_style_prompt")
|
||||
|
||||
|
||||
|
||||
|
||||
class ExpressionLearner:
|
||||
def __init__(self, chat_id: str) -> None:
|
||||
self.express_learn_model: LLMRequest = LLMRequest(
|
||||
@@ -95,12 +98,12 @@ class ExpressionLearner:
|
||||
self._learning_lock = asyncio.Lock()
|
||||
|
||||
async def learn_and_store(
|
||||
self,
|
||||
self,
|
||||
messages: List[Any],
|
||||
) -> List[Tuple[str, str, str]]:
|
||||
"""
|
||||
学习并存储表达方式
|
||||
|
||||
|
||||
Args:
|
||||
messages: 外部传入的消息列表(必需)
|
||||
num: 学习数量
|
||||
@@ -108,7 +111,7 @@ class ExpressionLearner:
|
||||
"""
|
||||
if not messages:
|
||||
return None
|
||||
|
||||
|
||||
random_msg = messages
|
||||
|
||||
# 学习用(开启行编号,便于溯源)
|
||||
@@ -134,26 +137,26 @@ class ExpressionLearner:
|
||||
jargon_entries: List[Tuple[str, str]] # (content, source_id)
|
||||
expressions, jargon_entries = self.parse_expression_response(response)
|
||||
expressions = self._filter_self_reference_styles(expressions)
|
||||
|
||||
|
||||
# 检查表达方式数量,如果超过10个则放弃本次表达学习
|
||||
if len(expressions) > 10:
|
||||
logger.info(f"表达方式提取数量超过10个(实际{len(expressions)}个),放弃本次表达学习")
|
||||
expressions = []
|
||||
|
||||
|
||||
# 检查黑话数量,如果超过30个则放弃本次黑话学习
|
||||
if len(jargon_entries) > 30:
|
||||
logger.info(f"黑话提取数量超过30个(实际{len(jargon_entries)}个),放弃本次黑话学习")
|
||||
jargon_entries = []
|
||||
|
||||
|
||||
# 处理黑话条目,路由到 jargon_miner(即使没有表达方式也要处理黑话)
|
||||
if jargon_entries:
|
||||
await self._process_jargon_entries(jargon_entries, random_msg)
|
||||
|
||||
|
||||
# 如果没有表达方式,直接返回
|
||||
if not expressions:
|
||||
logger.info("过滤后没有可用的表达方式(style 与机器人名称重复)")
|
||||
return []
|
||||
|
||||
|
||||
logger.info(f"学习的prompt: {prompt}")
|
||||
logger.info(f"学习的expressions: {expressions}")
|
||||
logger.info(f"学习的jargon_entries: {jargon_entries}")
|
||||
@@ -175,18 +178,17 @@ class ExpressionLearner:
|
||||
|
||||
# 当前行的原始内容
|
||||
current_msg = random_msg[line_index]
|
||||
|
||||
|
||||
# 过滤掉从bot自己发言中提取到的表达方式
|
||||
if is_bot_message(current_msg):
|
||||
continue
|
||||
|
||||
|
||||
context = filter_message_content(current_msg.processed_plain_text or "")
|
||||
if not context:
|
||||
continue
|
||||
|
||||
filtered_expressions.append((situation, style, context))
|
||||
|
||||
|
||||
|
||||
learnt_expressions = filtered_expressions
|
||||
|
||||
if learnt_expressions is None:
|
||||
@@ -270,37 +272,38 @@ class ExpressionLearner:
|
||||
# 如果解析失败,尝试修复中文引号问题
|
||||
# 使用状态机方法,在 JSON 字符串值内部将中文引号替换为转义的英文引号
|
||||
try:
|
||||
|
||||
def fix_chinese_quotes_in_json(text):
|
||||
"""使用状态机修复 JSON 字符串值中的中文引号"""
|
||||
result = []
|
||||
i = 0
|
||||
in_string = False
|
||||
escape_next = False
|
||||
|
||||
|
||||
while i < len(text):
|
||||
char = text[i]
|
||||
|
||||
|
||||
if escape_next:
|
||||
# 当前字符是转义字符后的字符,直接添加
|
||||
result.append(char)
|
||||
escape_next = False
|
||||
i += 1
|
||||
continue
|
||||
|
||||
if char == '\\':
|
||||
|
||||
if char == "\\":
|
||||
# 转义字符
|
||||
result.append(char)
|
||||
escape_next = True
|
||||
i += 1
|
||||
continue
|
||||
|
||||
|
||||
if char == '"' and not escape_next:
|
||||
# 遇到英文引号,切换字符串状态
|
||||
in_string = not in_string
|
||||
result.append(char)
|
||||
i += 1
|
||||
continue
|
||||
|
||||
|
||||
if in_string:
|
||||
# 在字符串值内部,将中文引号替换为转义的英文引号
|
||||
if char == '"': # 中文左引号 U+201C
|
||||
@@ -312,13 +315,13 @@ class ExpressionLearner:
|
||||
else:
|
||||
# 不在字符串内,直接添加
|
||||
result.append(char)
|
||||
|
||||
|
||||
i += 1
|
||||
|
||||
return ''.join(result)
|
||||
|
||||
|
||||
return "".join(result)
|
||||
|
||||
fixed_raw = fix_chinese_quotes_in_json(raw)
|
||||
|
||||
|
||||
# 再次尝试解析
|
||||
if fixed_raw.startswith("[") and fixed_raw.endswith("]"):
|
||||
parsed = json.loads(fixed_raw)
|
||||
@@ -346,12 +349,12 @@ class ExpressionLearner:
|
||||
for item in parsed_list:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
|
||||
|
||||
# 检查是否是表达方式条目(有 situation 和 style)
|
||||
situation = str(item.get("situation", "")).strip()
|
||||
style = str(item.get("style", "")).strip()
|
||||
source_id = str(item.get("source_id", "")).strip()
|
||||
|
||||
|
||||
if situation and style and source_id:
|
||||
# 表达方式条目
|
||||
expressions.append((situation, style, source_id))
|
||||
@@ -503,59 +506,59 @@ class ExpressionLearner:
|
||||
async def _process_jargon_entries(self, jargon_entries: List[Tuple[str, str]], messages: List[Any]) -> None:
|
||||
"""
|
||||
处理从 expression learner 提取的黑话条目,路由到 jargon_miner
|
||||
|
||||
|
||||
Args:
|
||||
jargon_entries: 黑话条目列表,每个元素是 (content, source_id)
|
||||
messages: 消息列表,用于构建上下文
|
||||
"""
|
||||
if not jargon_entries or not messages:
|
||||
return
|
||||
|
||||
|
||||
# 获取 jargon_miner 实例
|
||||
jargon_miner = miner_manager.get_miner(self.chat_id)
|
||||
|
||||
|
||||
# 构建黑话条目格式,与 jargon_miner.run_once 中的格式一致
|
||||
entries: List[Dict[str, List[str]]] = []
|
||||
|
||||
|
||||
for content, source_id in jargon_entries:
|
||||
content = content.strip()
|
||||
if not content:
|
||||
continue
|
||||
|
||||
|
||||
# 检查是否包含机器人名称
|
||||
if contains_bot_self_name(content):
|
||||
logger.info(f"跳过包含机器人昵称/别名的黑话: {content}")
|
||||
continue
|
||||
|
||||
|
||||
# 解析 source_id
|
||||
source_id_str = (source_id or "").strip()
|
||||
if not source_id_str.isdigit():
|
||||
logger.warning(f"黑话条目 source_id 无效: content={content}, source_id={source_id_str}")
|
||||
continue
|
||||
|
||||
|
||||
# build_anonymous_messages 的编号从 1 开始
|
||||
line_index = int(source_id_str) - 1
|
||||
if line_index < 0 or line_index >= len(messages):
|
||||
logger.warning(f"黑话条目 source_id 超出范围: content={content}, source_id={source_id_str}")
|
||||
continue
|
||||
|
||||
|
||||
# 检查是否是机器人自己的消息
|
||||
target_msg = messages[line_index]
|
||||
if is_bot_message(target_msg):
|
||||
logger.info(f"跳过引用机器人自身消息的黑话: content={content}, source_id={source_id_str}")
|
||||
continue
|
||||
|
||||
|
||||
# 构建上下文段落
|
||||
context_paragraph = build_context_paragraph(messages, line_index)
|
||||
if not context_paragraph:
|
||||
logger.warning(f"黑话条目上下文为空: content={content}, source_id={source_id_str}")
|
||||
continue
|
||||
|
||||
|
||||
entries.append({"content": content, "raw_content": [context_paragraph]})
|
||||
|
||||
|
||||
if not entries:
|
||||
return
|
||||
|
||||
|
||||
# 调用 jargon_miner 处理这些条目
|
||||
await jargon_miner.process_extracted_entries(entries)
|
||||
|
||||
|
||||
@@ -82,9 +82,7 @@ class ExpressionReflector:
|
||||
# 获取未检查的表达
|
||||
try:
|
||||
logger.info("[Expression Reflection] 查询未检查且未拒绝的表达")
|
||||
expressions = (
|
||||
Expression.select().where((~Expression.checked) & (~Expression.rejected)).limit(50)
|
||||
)
|
||||
expressions = Expression.select().where((~Expression.checked) & (~Expression.rejected)).limit(50)
|
||||
|
||||
expr_list = list(expressions)
|
||||
logger.info(f"[Expression Reflection] 找到 {len(expr_list)} 个候选表达")
|
||||
|
||||
@@ -128,9 +128,7 @@ class ExpressionSelector:
|
||||
|
||||
# 查询所有相关chat_id的表达方式,排除 rejected=1 的,且只选择 count > 1 的
|
||||
style_query = Expression.select().where(
|
||||
(Expression.chat_id.in_(related_chat_ids))
|
||||
& (~Expression.rejected)
|
||||
& (Expression.count > 1)
|
||||
(Expression.chat_id.in_(related_chat_ids)) & (~Expression.rejected) & (Expression.count > 1)
|
||||
)
|
||||
|
||||
style_exprs = [
|
||||
@@ -150,12 +148,15 @@ class ExpressionSelector:
|
||||
# 要求至少有10个 count > 1 的表达方式才进行选择
|
||||
min_required = 10
|
||||
if len(style_exprs) < min_required:
|
||||
logger.info(f"聊天流 {chat_id} count > 1 的表达方式不足 {min_required} 个(实际 {len(style_exprs)} 个),不进行选择")
|
||||
logger.info(
|
||||
f"聊天流 {chat_id} count > 1 的表达方式不足 {min_required} 个(实际 {len(style_exprs)} 个),不进行选择"
|
||||
)
|
||||
return [], []
|
||||
|
||||
# 固定选择5个
|
||||
select_count = 5
|
||||
import random
|
||||
|
||||
selected_style = random.sample(style_exprs, select_count)
|
||||
|
||||
# 更新last_active_time
|
||||
@@ -163,7 +164,9 @@ class ExpressionSelector:
|
||||
self.update_expressions_last_active_time(selected_style)
|
||||
|
||||
selected_ids = [expr["id"] for expr in selected_style]
|
||||
logger.debug(f"think_level=0: 从 {len(style_exprs)} 个 count>1 的表达方式中随机选择了 {len(selected_style)} 个")
|
||||
logger.debug(
|
||||
f"think_level=0: 从 {len(style_exprs)} 个 count>1 的表达方式中随机选择了 {len(selected_style)} 个"
|
||||
)
|
||||
return selected_style, selected_ids
|
||||
|
||||
except Exception as e:
|
||||
@@ -186,9 +189,7 @@ class ExpressionSelector:
|
||||
related_chat_ids = self.get_related_chat_ids(chat_id)
|
||||
|
||||
# 优化:一次性查询所有相关chat_id的表达方式,排除 rejected=1 的表达
|
||||
style_query = Expression.select().where(
|
||||
(Expression.chat_id.in_(related_chat_ids)) & (~Expression.rejected)
|
||||
)
|
||||
style_query = Expression.select().where((Expression.chat_id.in_(related_chat_ids)) & (~Expression.rejected))
|
||||
|
||||
style_exprs = [
|
||||
{
|
||||
@@ -246,7 +247,9 @@ class ExpressionSelector:
|
||||
|
||||
# 使用classic模式(随机选择+LLM选择)
|
||||
logger.debug(f"使用classic模式为聊天流 {chat_id} 选择表达方式,think_level={think_level}")
|
||||
return await self._select_expressions_classic(chat_id, chat_info, max_num, target_message, reply_reason, think_level)
|
||||
return await self._select_expressions_classic(
|
||||
chat_id, chat_info, max_num, target_message, reply_reason, think_level
|
||||
)
|
||||
|
||||
async def _select_expressions_classic(
|
||||
self,
|
||||
@@ -275,14 +278,12 @@ class ExpressionSelector:
|
||||
# think_level == 0: 只选择 count > 1 的项目,随机选10个,不进行LLM选择
|
||||
if think_level == 0:
|
||||
return self._select_expressions_simple(chat_id, max_num)
|
||||
|
||||
|
||||
# think_level == 1: 先选高count,再从所有表达方式中随机抽样
|
||||
# 1. 获取所有表达方式并分离 count > 1 和 count <= 1 的
|
||||
related_chat_ids = self.get_related_chat_ids(chat_id)
|
||||
style_query = Expression.select().where(
|
||||
(Expression.chat_id.in_(related_chat_ids)) & (~Expression.rejected)
|
||||
)
|
||||
|
||||
style_query = Expression.select().where((Expression.chat_id.in_(related_chat_ids)) & (~Expression.rejected))
|
||||
|
||||
all_style_exprs = [
|
||||
{
|
||||
"id": expr.id,
|
||||
@@ -299,29 +300,33 @@ class ExpressionSelector:
|
||||
|
||||
# 分离 count > 1 和 count <= 1 的表达方式
|
||||
high_count_exprs = [expr for expr in all_style_exprs if (expr.get("count", 1) or 1) > 1]
|
||||
|
||||
|
||||
# 根据 think_level 设置要求(仅支持 0/1,0 已在上方返回)
|
||||
min_high_count = 10
|
||||
min_total_count = 10
|
||||
select_high_count = 5
|
||||
select_random_count = 5
|
||||
|
||||
|
||||
# 检查数量要求
|
||||
if len(high_count_exprs) < min_high_count:
|
||||
logger.info(f"聊天流 {chat_id} count > 1 的表达方式不足 {min_high_count} 个(实际 {len(high_count_exprs)} 个),不进行选择")
|
||||
logger.info(
|
||||
f"聊天流 {chat_id} count > 1 的表达方式不足 {min_high_count} 个(实际 {len(high_count_exprs)} 个),不进行选择"
|
||||
)
|
||||
return [], []
|
||||
|
||||
|
||||
if len(all_style_exprs) < min_total_count:
|
||||
logger.info(f"聊天流 {chat_id} 总表达方式不足 {min_total_count} 个(实际 {len(all_style_exprs)} 个),不进行选择")
|
||||
logger.info(
|
||||
f"聊天流 {chat_id} 总表达方式不足 {min_total_count} 个(实际 {len(all_style_exprs)} 个),不进行选择"
|
||||
)
|
||||
return [], []
|
||||
|
||||
|
||||
# 先选取高count的表达方式
|
||||
selected_high = weighted_sample(high_count_exprs, min(len(high_count_exprs), select_high_count))
|
||||
|
||||
|
||||
# 然后从所有表达方式中随机抽样(使用加权抽样)
|
||||
remaining_num = select_random_count
|
||||
selected_random = weighted_sample(all_style_exprs, min(len(all_style_exprs), remaining_num))
|
||||
|
||||
|
||||
# 合并候选池(去重,避免重复)
|
||||
candidate_exprs = selected_high.copy()
|
||||
candidate_ids = {expr["id"] for expr in candidate_exprs}
|
||||
@@ -329,9 +334,10 @@ class ExpressionSelector:
|
||||
if expr["id"] not in candidate_ids:
|
||||
candidate_exprs.append(expr)
|
||||
candidate_ids.add(expr["id"])
|
||||
|
||||
|
||||
# 打乱顺序,避免高count的都在前面
|
||||
import random
|
||||
|
||||
random.shuffle(candidate_exprs)
|
||||
|
||||
# 2. 构建所有表达方式的索引和情境列表
|
||||
@@ -351,7 +357,7 @@ class ExpressionSelector:
|
||||
all_situations_str = "\n".join(all_situations)
|
||||
|
||||
if target_message:
|
||||
target_message_str = f",现在你想要对这条消息进行回复:\"{target_message}\""
|
||||
target_message_str = f',现在你想要对这条消息进行回复:"{target_message}"'
|
||||
target_message_extra_block = "4.考虑你要回复的目标消息"
|
||||
else:
|
||||
target_message_str = ""
|
||||
|
||||
@@ -8,7 +8,12 @@ from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config, global_config
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.bw_learner.jargon_miner import search_jargon
|
||||
from src.bw_learner.learner_utils import is_bot_message, contains_bot_self_name, parse_chat_id_list, chat_id_list_contains
|
||||
from src.bw_learner.learner_utils import (
|
||||
is_bot_message,
|
||||
contains_bot_self_name,
|
||||
parse_chat_id_list,
|
||||
chat_id_list_contains,
|
||||
)
|
||||
|
||||
logger = get_logger("jargon")
|
||||
|
||||
@@ -357,4 +362,4 @@ async def retrieve_concepts_with_jargon(concepts: List[str], chat_id: str) -> st
|
||||
|
||||
if results:
|
||||
return "【概念检索结果】\n" + "\n".join(results) + "\n"
|
||||
return ""
|
||||
return ""
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import time
|
||||
import json
|
||||
import asyncio
|
||||
import random
|
||||
@@ -14,7 +13,6 @@ from src.config.config import model_config, global_config
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.utils.chat_message_builder import (
|
||||
build_readable_messages_with_id,
|
||||
get_raw_msg_by_timestamp_with_chat_inclusive,
|
||||
)
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.bw_learner.learner_utils import (
|
||||
@@ -33,23 +31,23 @@ logger = get_logger("jargon")
|
||||
def _is_single_char_jargon(content: str) -> bool:
|
||||
"""
|
||||
判断是否是单字黑话(单个汉字、英文或数字)
|
||||
|
||||
|
||||
Args:
|
||||
content: 词条内容
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 如果是单字黑话返回True,否则返回False
|
||||
"""
|
||||
if not content or len(content) != 1:
|
||||
return False
|
||||
|
||||
|
||||
char = content[0]
|
||||
# 判断是否是单个汉字、单个英文字母或单个数字
|
||||
return (
|
||||
'\u4e00' <= char <= '\u9fff' or # 汉字
|
||||
'a' <= char <= 'z' or # 小写字母
|
||||
'A' <= char <= 'Z' or # 大写字母
|
||||
'0' <= char <= '9' # 数字
|
||||
"\u4e00" <= char <= "\u9fff" # 汉字
|
||||
or "a" <= char <= "z" # 小写字母
|
||||
or "A" <= char <= "Z" # 大写字母
|
||||
or "0" <= char <= "9" # 数字
|
||||
)
|
||||
|
||||
|
||||
@@ -195,7 +193,7 @@ class JargonMiner:
|
||||
model_set=model_config.model_task_config.utils,
|
||||
request_type="jargon.extract",
|
||||
)
|
||||
|
||||
|
||||
self.llm_inference = LLMRequest(
|
||||
model_set=model_config.model_task_config.utils,
|
||||
request_type="jargon.inference",
|
||||
@@ -207,7 +205,7 @@ class JargonMiner:
|
||||
self.stream_name = stream_name if stream_name else self.chat_id
|
||||
self.cache_limit = 50
|
||||
self.cache: OrderedDict[str, None] = OrderedDict()
|
||||
|
||||
|
||||
# 黑话提取锁,防止并发执行
|
||||
self._extraction_lock = asyncio.Lock()
|
||||
|
||||
@@ -299,17 +297,19 @@ class JargonMiner:
|
||||
# 获取当前count和上一次的meaning
|
||||
current_count = jargon_obj.count or 0
|
||||
previous_meaning = jargon_obj.meaning or ""
|
||||
|
||||
|
||||
# 当count为24, 60时,随机移除一半的raw_content项目
|
||||
if current_count in [24, 60] and len(raw_content_list) > 1:
|
||||
# 计算要保留的数量(至少保留1个)
|
||||
keep_count = max(1, len(raw_content_list) // 2)
|
||||
raw_content_list = random.sample(raw_content_list, keep_count)
|
||||
logger.info(f"jargon {content} count={current_count},随机移除后剩余 {len(raw_content_list)} 个raw_content项目")
|
||||
logger.info(
|
||||
f"jargon {content} count={current_count},随机移除后剩余 {len(raw_content_list)} 个raw_content项目"
|
||||
)
|
||||
|
||||
# 步骤1: 基于raw_content和content推断
|
||||
raw_content_text = "\n".join(raw_content_list)
|
||||
|
||||
|
||||
# 当count为24, 60, 100时,在prompt中放入上一次推断出的meaning作为参考
|
||||
previous_meaning_section = ""
|
||||
previous_meaning_instruction = ""
|
||||
@@ -318,8 +318,10 @@ class JargonMiner:
|
||||
**上一次推断的含义(仅供参考)**
|
||||
{previous_meaning}
|
||||
"""
|
||||
previous_meaning_instruction = "- 请参考上一次推断的含义,结合新的上下文信息,给出更准确或更新的推断结果"
|
||||
|
||||
previous_meaning_instruction = (
|
||||
"- 请参考上一次推断的含义,结合新的上下文信息,给出更准确或更新的推断结果"
|
||||
)
|
||||
|
||||
prompt1 = await global_prompt_manager.format_prompt(
|
||||
"jargon_inference_with_context_prompt",
|
||||
content=content,
|
||||
@@ -481,7 +483,7 @@ class JargonMiner:
|
||||
async def run_once(self, messages: List[Any]) -> None:
|
||||
"""
|
||||
运行一次黑话提取
|
||||
|
||||
|
||||
Args:
|
||||
messages: 外部传入的消息列表(必需)
|
||||
"""
|
||||
@@ -650,7 +652,9 @@ class JargonMiner:
|
||||
if obj.raw_content:
|
||||
try:
|
||||
existing_raw_content = (
|
||||
json.loads(obj.raw_content) if isinstance(obj.raw_content, str) else obj.raw_content
|
||||
json.loads(obj.raw_content)
|
||||
if isinstance(obj.raw_content, str)
|
||||
else obj.raw_content
|
||||
)
|
||||
if not isinstance(existing_raw_content, list):
|
||||
existing_raw_content = [existing_raw_content] if existing_raw_content else []
|
||||
@@ -726,13 +730,13 @@ class JargonMiner:
|
||||
async def process_extracted_entries(self, entries: List[Dict[str, List[str]]]) -> None:
|
||||
"""
|
||||
处理已提取的黑话条目(从 expression_learner 路由过来的)
|
||||
|
||||
|
||||
Args:
|
||||
entries: 黑话条目列表,每个元素格式为 {"content": "...", "raw_content": [...]}
|
||||
"""
|
||||
if not entries:
|
||||
return
|
||||
|
||||
|
||||
try:
|
||||
# 去重并合并raw_content(按 content 聚合)
|
||||
merged_entries: OrderedDict[str, Dict[str, List[str]]] = OrderedDict()
|
||||
@@ -876,8 +880,6 @@ class JargonMinerManager:
|
||||
miner_manager = JargonMinerManager()
|
||||
|
||||
|
||||
|
||||
|
||||
def search_jargon(
|
||||
keyword: str, chat_id: Optional[str] = None, limit: int = 10, case_sensitive: bool = False, fuzzy: bool = True
|
||||
) -> List[Dict[str, str]]:
|
||||
|
||||
@@ -15,25 +15,25 @@ class MessageRecorder:
|
||||
"""
|
||||
统一的消息记录器,负责管理时间窗口和消息提取,并将消息分发给 expression_learner 和 jargon_miner
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, chat_id: str) -> None:
|
||||
self.chat_id = chat_id
|
||||
self.chat_stream = get_chat_manager().get_stream(chat_id)
|
||||
self.chat_name = get_chat_manager().get_stream_name(chat_id) or chat_id
|
||||
|
||||
|
||||
# 维护每个chat的上次提取时间
|
||||
self.last_extraction_time: float = time.time()
|
||||
|
||||
|
||||
# 提取锁,防止并发执行
|
||||
self._extraction_lock = asyncio.Lock()
|
||||
|
||||
|
||||
# 获取 expression 和 jargon 的配置参数
|
||||
self._init_parameters()
|
||||
|
||||
|
||||
# 获取 expression_learner 和 jargon_miner 实例
|
||||
self.expression_learner = expression_learner_manager.get_expression_learner(chat_id)
|
||||
self.jargon_miner = miner_manager.get_miner(chat_id)
|
||||
|
||||
|
||||
def _init_parameters(self) -> None:
|
||||
"""初始化提取参数"""
|
||||
# 获取 expression 配置
|
||||
@@ -42,17 +42,17 @@ class MessageRecorder:
|
||||
)
|
||||
self.min_messages_for_extraction = 30
|
||||
self.min_extraction_interval = 60
|
||||
|
||||
|
||||
logger.debug(
|
||||
f"MessageRecorder 初始化: chat_id={self.chat_id}, "
|
||||
f"min_messages={self.min_messages_for_extraction}, "
|
||||
f"min_interval={self.min_extraction_interval}"
|
||||
)
|
||||
|
||||
|
||||
def should_trigger_extraction(self) -> bool:
|
||||
"""
|
||||
检查是否应该触发消息提取
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 是否应该触发提取
|
||||
"""
|
||||
@@ -60,19 +60,19 @@ class MessageRecorder:
|
||||
time_diff = time.time() - self.last_extraction_time
|
||||
if time_diff < self.min_extraction_interval:
|
||||
return False
|
||||
|
||||
|
||||
# 检查消息数量
|
||||
recent_messages = get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
chat_id=self.chat_id,
|
||||
timestamp_start=self.last_extraction_time,
|
||||
timestamp_end=time.time(),
|
||||
)
|
||||
|
||||
|
||||
if not recent_messages or len(recent_messages) < self.min_messages_for_extraction:
|
||||
return False
|
||||
|
||||
|
||||
return True
|
||||
|
||||
|
||||
async def extract_and_distribute(self) -> None:
|
||||
"""
|
||||
提取消息并分发给 expression_learner 和 jargon_miner
|
||||
@@ -82,41 +82,40 @@ class MessageRecorder:
|
||||
# 在锁内检查,避免并发触发
|
||||
if not self.should_trigger_extraction():
|
||||
return
|
||||
|
||||
|
||||
# 检查 chat_stream 是否存在
|
||||
if not self.chat_stream:
|
||||
return
|
||||
|
||||
|
||||
# 记录本次提取的时间窗口,避免重复提取
|
||||
extraction_start_time = self.last_extraction_time
|
||||
extraction_end_time = time.time()
|
||||
|
||||
|
||||
# 立即更新提取时间,防止并发触发
|
||||
self.last_extraction_time = extraction_end_time
|
||||
|
||||
|
||||
try:
|
||||
logger.info(f"在聊天流 {self.chat_name} 开始统一消息提取和分发")
|
||||
|
||||
|
||||
# 拉取提取窗口内的消息
|
||||
messages = get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
chat_id=self.chat_id,
|
||||
timestamp_start=extraction_start_time,
|
||||
timestamp_end=extraction_end_time,
|
||||
)
|
||||
|
||||
|
||||
if not messages:
|
||||
logger.debug(f"聊天流 {self.chat_name} 没有新消息,跳过提取")
|
||||
return
|
||||
|
||||
|
||||
# 按时间排序,确保顺序一致
|
||||
messages = sorted(messages, key=lambda msg: msg.time or 0)
|
||||
|
||||
|
||||
logger.info(
|
||||
f"聊天流 {self.chat_name} 提取到 {len(messages)} 条消息,"
|
||||
f"时间窗口: {extraction_start_time:.2f} - {extraction_end_time:.2f}"
|
||||
)
|
||||
|
||||
|
||||
|
||||
# 分别触发 expression_learner 和 jargon_miner 的处理
|
||||
# 传递提取的消息,避免它们重复获取
|
||||
# 触发 expression 学习(如果启用)
|
||||
@@ -124,28 +123,26 @@ class MessageRecorder:
|
||||
asyncio.create_task(
|
||||
self._trigger_expression_learning(extraction_start_time, extraction_end_time, messages)
|
||||
)
|
||||
|
||||
|
||||
# 触发 jargon 提取(如果启用),传递消息
|
||||
# if self.enable_jargon_learning:
|
||||
# asyncio.create_task(
|
||||
# self._trigger_jargon_extraction(extraction_start_time, extraction_end_time, messages)
|
||||
# )
|
||||
|
||||
# asyncio.create_task(
|
||||
# self._trigger_jargon_extraction(extraction_start_time, extraction_end_time, messages)
|
||||
# )
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"为聊天流 {self.chat_name} 提取和分发消息失败: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
# 即使失败也保持时间戳更新,避免频繁重试
|
||||
|
||||
|
||||
async def _trigger_expression_learning(
|
||||
self,
|
||||
timestamp_start: float,
|
||||
timestamp_end: float,
|
||||
messages: List[Any]
|
||||
self, timestamp_start: float, timestamp_end: float, messages: List[Any]
|
||||
) -> None:
|
||||
"""
|
||||
触发 expression 学习,使用指定的消息列表
|
||||
|
||||
|
||||
Args:
|
||||
timestamp_start: 开始时间戳
|
||||
timestamp_end: 结束时间戳
|
||||
@@ -154,7 +151,7 @@ class MessageRecorder:
|
||||
try:
|
||||
# 传递消息给 ExpressionLearner(必需参数)
|
||||
learnt_style = await self.expression_learner.learn_and_store(messages=messages)
|
||||
|
||||
|
||||
if learnt_style:
|
||||
logger.info(f"聊天流 {self.chat_name} 表达学习完成")
|
||||
else:
|
||||
@@ -162,17 +159,15 @@ class MessageRecorder:
|
||||
except Exception as e:
|
||||
logger.error(f"为聊天流 {self.chat_name} 触发表达学习失败: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
async def _trigger_jargon_extraction(
|
||||
self,
|
||||
timestamp_start: float,
|
||||
timestamp_end: float,
|
||||
messages: List[Any]
|
||||
self, timestamp_start: float, timestamp_end: float, messages: List[Any]
|
||||
) -> None:
|
||||
"""
|
||||
触发 jargon 提取,使用指定的消息列表
|
||||
|
||||
|
||||
Args:
|
||||
timestamp_start: 开始时间戳
|
||||
timestamp_end: 结束时间戳
|
||||
@@ -181,19 +176,20 @@ class MessageRecorder:
|
||||
try:
|
||||
# 传递消息给 JargonMiner,避免它重复获取
|
||||
await self.jargon_miner.run_once(messages=messages)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"为聊天流 {self.chat_name} 触发黑话提取失败: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
class MessageRecorderManager:
|
||||
"""MessageRecorder 管理器"""
|
||||
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._recorders: dict[str, MessageRecorder] = {}
|
||||
|
||||
|
||||
def get_recorder(self, chat_id: str) -> MessageRecorder:
|
||||
"""获取或创建指定 chat_id 的 MessageRecorder"""
|
||||
if chat_id not in self._recorders:
|
||||
@@ -208,10 +204,9 @@ recorder_manager = MessageRecorderManager()
|
||||
async def extract_and_distribute_messages(chat_id: str) -> None:
|
||||
"""
|
||||
统一的消息提取和分发入口函数
|
||||
|
||||
|
||||
Args:
|
||||
chat_id: 聊天流ID
|
||||
"""
|
||||
recorder = recorder_manager.get_recorder(chat_id)
|
||||
await recorder.extract_and_distribute()
|
||||
|
||||
|
||||
@@ -176,19 +176,19 @@ class BrainChatting:
|
||||
# 如果有新消息,更新 last_read_time
|
||||
if len(recent_messages_list) >= 1:
|
||||
self.last_read_time = time.time()
|
||||
|
||||
|
||||
# 总是执行一次思考迭代(不管有没有新消息)
|
||||
# wait 动作会在其内部等待,不需要在这里处理
|
||||
should_continue = await self._observe(recent_messages_list=recent_messages_list)
|
||||
|
||||
|
||||
if not should_continue:
|
||||
# 选择了 complete_talk,返回 False 表示需要等待新消息
|
||||
return False
|
||||
|
||||
|
||||
# 继续下一次迭代(除非选择了 complete_talk)
|
||||
# 短暂等待后再继续,避免过于频繁的循环
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
|
||||
return True
|
||||
|
||||
async def _send_and_store_reply(
|
||||
@@ -328,9 +328,7 @@ class BrainChatting:
|
||||
)
|
||||
|
||||
# 检查是否有 complete_talk 动作(会停止后续迭代)
|
||||
has_complete_talk = any(
|
||||
action.action_type == "complete_talk" for action in action_to_use_info
|
||||
)
|
||||
has_complete_talk = any(action.action_type == "complete_talk" for action in action_to_use_info)
|
||||
|
||||
# 并行执行所有动作
|
||||
action_tasks = [
|
||||
@@ -430,12 +428,12 @@ class BrainChatting:
|
||||
await asyncio.sleep(3)
|
||||
self._loop_task = asyncio.create_task(self._main_chat_loop())
|
||||
logger.error(f"{self.log_prefix} 结束了当前聊天循环")
|
||||
|
||||
|
||||
async def _wait_for_new_message(self):
|
||||
"""等待新消息到达"""
|
||||
last_check_time = self.last_read_time
|
||||
check_interval = 1.0 # 每秒检查一次
|
||||
|
||||
|
||||
while self.running:
|
||||
# 检查是否有新消息
|
||||
recent_messages_list = message_api.get_messages_by_time_in_chat(
|
||||
@@ -448,13 +446,13 @@ class BrainChatting:
|
||||
filter_command=False,
|
||||
filter_intercept_message_level=1,
|
||||
)
|
||||
|
||||
|
||||
# 如果有新消息,更新 last_read_time 并返回
|
||||
if len(recent_messages_list) >= 1:
|
||||
self.last_read_time = time.time()
|
||||
logger.info(f"{self.log_prefix} 检测到新消息,恢复循环")
|
||||
return
|
||||
|
||||
|
||||
# 等待一段时间后再次检查
|
||||
await asyncio.sleep(check_interval)
|
||||
|
||||
@@ -660,9 +658,9 @@ class BrainChatting:
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(f"{self.log_prefix} wait_seconds 参数格式错误,使用默认值 5 秒")
|
||||
wait_seconds = 5
|
||||
|
||||
|
||||
logger.info(f"{self.log_prefix} 执行 wait 动作,等待 {wait_seconds} 秒")
|
||||
|
||||
|
||||
# 记录动作信息
|
||||
await database_api.store_action_info(
|
||||
chat_stream=self.chat_stream,
|
||||
@@ -673,12 +671,12 @@ class BrainChatting:
|
||||
action_data={"reason": reason, "wait_seconds": wait_seconds},
|
||||
action_name="wait",
|
||||
)
|
||||
|
||||
|
||||
# 等待指定时间
|
||||
await asyncio.sleep(wait_seconds)
|
||||
|
||||
|
||||
logger.info(f"{self.log_prefix} wait 动作完成,继续下一次思考")
|
||||
|
||||
|
||||
# 这些动作本身不产生文本回复
|
||||
self._last_successful_reply = False
|
||||
return {
|
||||
@@ -693,9 +691,9 @@ class BrainChatting:
|
||||
logger.debug(f"{self.log_prefix} 检测到 listening 动作,已合并到 wait,自动转换")
|
||||
# 使用默认等待时间
|
||||
wait_seconds = 3
|
||||
|
||||
|
||||
logger.info(f"{self.log_prefix} 执行 listening(转换为 wait)动作,等待 {wait_seconds} 秒")
|
||||
|
||||
|
||||
# 记录动作信息
|
||||
await database_api.store_action_info(
|
||||
chat_stream=self.chat_stream,
|
||||
@@ -706,12 +704,12 @@ class BrainChatting:
|
||||
action_data={"reason": reason, "wait_seconds": wait_seconds},
|
||||
action_name="listening",
|
||||
)
|
||||
|
||||
|
||||
# 等待指定时间
|
||||
await asyncio.sleep(wait_seconds)
|
||||
|
||||
|
||||
logger.info(f"{self.log_prefix} listening 动作完成,继续下一次思考")
|
||||
|
||||
|
||||
# 这些动作本身不产生文本回复
|
||||
self._last_successful_reply = False
|
||||
return {
|
||||
|
||||
@@ -147,7 +147,7 @@ class BrainPlanner:
|
||||
) # 用于动作规划
|
||||
|
||||
self.last_obs_time_mark = 0.0
|
||||
|
||||
|
||||
# 计划日志记录
|
||||
self.plan_log: List[Tuple[str, float, List[ActionPlannerInfo]]] = []
|
||||
|
||||
@@ -203,9 +203,11 @@ class BrainPlanner:
|
||||
# 内部保留动作(不依赖插件系统)
|
||||
# 注意:listening 已合并到 wait 中,如果遇到 listening 则转换为 wait
|
||||
internal_action_names = ["complete_talk", "reply", "wait_time", "wait", "listening"]
|
||||
|
||||
logger.debug(f"{self.log_prefix}动作验证: action={action}, internal={internal_action_names}, available={available_action_names}")
|
||||
|
||||
|
||||
logger.debug(
|
||||
f"{self.log_prefix}动作验证: action={action}, internal={internal_action_names}, available={available_action_names}"
|
||||
)
|
||||
|
||||
# 将 listening 转换为 wait(向后兼容)
|
||||
if action == "listening":
|
||||
logger.debug(f"{self.log_prefix}检测到 listening 动作,已合并到 wait,自动转换")
|
||||
@@ -521,7 +523,7 @@ class BrainPlanner:
|
||||
if json_objects:
|
||||
logger.info(f"{self.log_prefix}从响应中提取到{len(json_objects)}个JSON对象")
|
||||
for i, json_obj in enumerate(json_objects):
|
||||
logger.info(f"{self.log_prefix}解析第{i+1}个JSON对象: {json_obj}")
|
||||
logger.info(f"{self.log_prefix}解析第{i + 1}个JSON对象: {json_obj}")
|
||||
filtered_actions_list = list(filtered_actions.items())
|
||||
for json_obj in json_objects:
|
||||
parsed_actions = self._parse_single_action(json_obj, message_id_list, filtered_actions_list)
|
||||
@@ -553,7 +555,9 @@ class BrainPlanner:
|
||||
|
||||
return extracted_reasoning, actions
|
||||
|
||||
def _create_complete_talk(self, reasoning: str, available_actions: Dict[str, ActionInfo]) -> List[ActionPlannerInfo]:
|
||||
def _create_complete_talk(
|
||||
self, reasoning: str, available_actions: Dict[str, ActionInfo]
|
||||
) -> List[ActionPlannerInfo]:
|
||||
"""创建complete_talk"""
|
||||
return [
|
||||
ActionPlannerInfo(
|
||||
@@ -564,7 +568,7 @@ class BrainPlanner:
|
||||
available_actions=available_actions,
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
def add_plan_log(self, reasoning: str, actions: List[ActionPlannerInfo]):
|
||||
"""添加计划日志"""
|
||||
self.plan_log.append((reasoning, time.time(), actions))
|
||||
|
||||
@@ -271,7 +271,7 @@ def _to_emoji_objects(data: Any) -> Tuple[List["MaiEmoji"], int]:
|
||||
|
||||
emoji.description = emoji_data.description
|
||||
# Deserialize emotion string from DB to list
|
||||
emoji.emotion = emoji_data.emotion.replace(",",",").split(",") if emoji_data.emotion else []
|
||||
emoji.emotion = emoji_data.emotion.replace(",", ",").split(",") if emoji_data.emotion else []
|
||||
emoji.usage_count = emoji_data.usage_count
|
||||
|
||||
db_last_used_time = emoji_data.last_used_time
|
||||
@@ -732,7 +732,7 @@ class EmojiManager:
|
||||
emoji_record = Emoji.get_or_none(Emoji.emoji_hash == emoji_hash)
|
||||
if emoji_record and emoji_record.emotion:
|
||||
logger.info(f"[缓存命中] 从数据库获取表情包情感标签: {emoji_record.emotion[:50]}...")
|
||||
return emoji_record.emotion.replace(",",",").split(",")
|
||||
return emoji_record.emotion.replace(",", ",").split(",")
|
||||
except Exception as e:
|
||||
logger.error(f"从数据库查询表情包情感标签时出错: {e}")
|
||||
|
||||
@@ -993,7 +993,7 @@ class EmojiManager:
|
||||
)
|
||||
|
||||
# 处理情感列表
|
||||
emotions = [e.strip() for e in emotions_text.replace(",",",").split(",") if e.strip()]
|
||||
emotions = [e.strip() for e in emotions_text.replace(",", ",").split(",") if e.strip()]
|
||||
|
||||
# 根据情感标签数量随机选择 - 超过5个选3个,超过2个选2个
|
||||
if len(emotions) > 5:
|
||||
|
||||
@@ -619,13 +619,13 @@ class HeartFChatting:
|
||||
think_level = 0
|
||||
# 使用 action_reasoning(planner 的整体思考理由)作为 reply_reason
|
||||
planner_reasoning = action_planner_info.action_reasoning or reason
|
||||
|
||||
|
||||
record_replyer_action_temp(
|
||||
chat_id=self.stream_id,
|
||||
reason=reason,
|
||||
think_level=think_level,
|
||||
)
|
||||
|
||||
|
||||
await database_api.store_action_info(
|
||||
chat_stream=self.chat_stream,
|
||||
action_build_into_prompt=False,
|
||||
|
||||
@@ -123,7 +123,11 @@ class ChatBot:
|
||||
logger.warning(f"命令执行失败: {command_class.__name__} - {response}")
|
||||
|
||||
# 根据命令的拦截设置决定是否继续处理消息
|
||||
return True, response, not bool(intercept_message_level) # 找到命令,根据intercept_message决定是否继续
|
||||
return (
|
||||
True,
|
||||
response,
|
||||
not bool(intercept_message_level),
|
||||
) # 找到命令,根据intercept_message决定是否继续
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"执行命令时出错: {command_class.__name__} - {e}")
|
||||
|
||||
@@ -263,7 +263,7 @@ class MessageRecv(Message):
|
||||
desc = segment.data.get("desc", "") # 内容描述
|
||||
source_url = segment.data.get("source_url", "") # 原始链接
|
||||
url = segment.data.get("url", "") # 小程序链接
|
||||
text = f"[小程序分享"
|
||||
text = "[小程序分享"
|
||||
if title:
|
||||
text += f" - {title}"
|
||||
text += "]"
|
||||
|
||||
@@ -42,22 +42,21 @@ def is_webui_virtual_group(group_id: str) -> bool:
|
||||
|
||||
def parse_message_segments(segment) -> list:
|
||||
"""解析消息段,转换为 WebUI 可用的格式
|
||||
|
||||
|
||||
参考 NapCat 适配器的消息解析逻辑
|
||||
|
||||
|
||||
Args:
|
||||
segment: Seg 消息段对象
|
||||
|
||||
|
||||
Returns:
|
||||
list: 消息段列表,每个元素为 {"type": "...", "data": ...}
|
||||
"""
|
||||
from maim_message import Seg
|
||||
|
||||
|
||||
result = []
|
||||
|
||||
|
||||
if segment is None:
|
||||
return result
|
||||
|
||||
|
||||
if segment.type == "seglist":
|
||||
# 处理消息段列表
|
||||
if segment.data:
|
||||
@@ -112,15 +111,19 @@ def parse_message_segments(segment) -> list:
|
||||
forward_items = []
|
||||
if segment.data:
|
||||
for item in segment.data:
|
||||
forward_items.append({
|
||||
"content": parse_message_segments(item.get("message_segment", {})) if isinstance(item, dict) else []
|
||||
})
|
||||
forward_items.append(
|
||||
{
|
||||
"content": parse_message_segments(item.get("message_segment", {}))
|
||||
if isinstance(item, dict)
|
||||
else []
|
||||
}
|
||||
)
|
||||
result.append({"type": "forward", "data": forward_items})
|
||||
else:
|
||||
# 未知类型,尝试作为文本处理
|
||||
if segment.data:
|
||||
result.append({"type": "unknown", "original_type": segment.type, "data": str(segment.data)})
|
||||
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@@ -134,7 +137,7 @@ async def _send_message(message: MessageSending, show_log=True) -> bool:
|
||||
# 检查是否是 WebUI 平台的消息,或者是 WebUI 虚拟群的消息
|
||||
chat_manager, webui_platform = get_webui_chat_broadcaster()
|
||||
is_webui_message = (platform == webui_platform) or is_webui_virtual_group(group_id)
|
||||
|
||||
|
||||
if is_webui_message and chat_manager is not None:
|
||||
# WebUI 聊天室消息(包括虚拟身份模式),通过 WebSocket 广播
|
||||
import time
|
||||
@@ -142,7 +145,7 @@ async def _send_message(message: MessageSending, show_log=True) -> bool:
|
||||
|
||||
# 解析消息段,获取富文本内容
|
||||
message_segments = parse_message_segments(message.message_segment)
|
||||
|
||||
|
||||
# 判断消息类型
|
||||
# 如果只有一个文本段,使用简单的 text 类型
|
||||
# 否则使用 rich 类型,包含完整的消息段
|
||||
|
||||
@@ -77,8 +77,7 @@ target_message_id为必填,表示触发消息的id
|
||||
```""",
|
||||
"planner_prompt",
|
||||
)
|
||||
|
||||
|
||||
|
||||
Prompt(
|
||||
"""
|
||||
{action_name}
|
||||
|
||||
@@ -250,7 +250,12 @@ class DefaultReplyer:
|
||||
# 使用从处理器传来的选中表达方式
|
||||
# 使用模型预测选择表达方式
|
||||
selected_expressions, selected_ids = await expression_selector.select_suitable_expressions(
|
||||
self.chat_stream.stream_id, chat_history, max_num=8, target_message=target, reply_reason=reply_reason, think_level=think_level
|
||||
self.chat_stream.stream_id,
|
||||
chat_history,
|
||||
max_num=8,
|
||||
target_message=target,
|
||||
reply_reason=reply_reason,
|
||||
think_level=think_level,
|
||||
)
|
||||
|
||||
if selected_expressions:
|
||||
@@ -273,7 +278,6 @@ class DefaultReplyer:
|
||||
|
||||
return f"{expression_habits_title}\n{expression_habits_block}", selected_ids
|
||||
|
||||
|
||||
async def build_tool_info(self, chat_history: str, sender: str, target: str, enable_tool: bool = True) -> str:
|
||||
"""构建工具信息块
|
||||
|
||||
@@ -788,7 +792,8 @@ class DefaultReplyer:
|
||||
# 并行执行八个构建任务(包括黑话解释)
|
||||
task_results = await asyncio.gather(
|
||||
self._time_and_run_task(
|
||||
self.build_expression_habits(chat_talking_prompt_short, target, reply_reason, think_level=think_level), "expression_habits"
|
||||
self.build_expression_habits(chat_talking_prompt_short, target, reply_reason, think_level=think_level),
|
||||
"expression_habits",
|
||||
),
|
||||
self._time_and_run_task(
|
||||
self.build_tool_info(chat_talking_prompt_short, sender, target, enable_tool=enable_tool), "tool_info"
|
||||
@@ -980,7 +985,6 @@ class DefaultReplyer:
|
||||
else:
|
||||
reply_target_block = ""
|
||||
|
||||
|
||||
chat_target_1 = await global_prompt_manager.get_prompt_async("chat_target_group1")
|
||||
chat_target_2 = await global_prompt_manager.get_prompt_async("chat_target_group2")
|
||||
|
||||
|
||||
@@ -287,7 +287,6 @@ class PrivateReplyer:
|
||||
|
||||
return f"{expression_habits_title}\n{expression_habits_block}", selected_ids
|
||||
|
||||
|
||||
async def build_tool_info(self, chat_history: str, sender: str, target: str, enable_tool: bool = True) -> str:
|
||||
"""构建工具信息块
|
||||
|
||||
@@ -907,16 +906,11 @@ class PrivateReplyer:
|
||||
else:
|
||||
reply_target_block = ""
|
||||
|
||||
|
||||
chat_target_name = "对方"
|
||||
if self.chat_target_info:
|
||||
chat_target_name = self.chat_target_info.person_name or self.chat_target_info.user_nickname or "对方"
|
||||
chat_target_1 = await global_prompt_manager.format_prompt(
|
||||
"chat_target_private1", sender_name=chat_target_name
|
||||
)
|
||||
chat_target_2 = await global_prompt_manager.format_prompt(
|
||||
"chat_target_private2", sender_name=chat_target_name
|
||||
)
|
||||
chat_target_1 = await global_prompt_manager.format_prompt("chat_target_private1", sender_name=chat_target_name)
|
||||
chat_target_2 = await global_prompt_manager.format_prompt("chat_target_private2", sender_name=chat_target_name)
|
||||
|
||||
template_name = "default_expressor_prompt"
|
||||
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
from src.chat.utils.prompt_builder import Prompt
|
||||
|
||||
|
||||
def init_replyer_private_prompt():
|
||||
Prompt(
|
||||
"""{knowledge_prompt}{tool_info_block}{extra_info_block}
|
||||
"""{knowledge_prompt}{tool_info_block}{extra_info_block}
|
||||
{expression_habits_block}{memory_retrieval}{jargon_explanation}
|
||||
|
||||
你正在和{sender_name}聊天,这是你们之前聊的内容:
|
||||
@@ -17,9 +18,9 @@ def init_replyer_private_prompt():
|
||||
{reply_style}
|
||||
请注意不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只输出回复内容。
|
||||
{moderation_prompt}不要输出多余内容(包括前后缀,冒号和引号,括号,表情包,at或 @等 )。""",
|
||||
"private_replyer_prompt",
|
||||
)
|
||||
|
||||
"private_replyer_prompt",
|
||||
)
|
||||
|
||||
Prompt(
|
||||
"""{knowledge_prompt}{tool_info_block}{extra_info_block}
|
||||
{expression_habits_block}{memory_retrieval}{jargon_explanation}
|
||||
@@ -37,4 +38,4 @@ def init_replyer_private_prompt():
|
||||
{moderation_prompt}不要输出多余内容(包括冒号和引号,括号,表情包,at或 @等 )。
|
||||
""",
|
||||
"private_replyer_self_prompt",
|
||||
)
|
||||
)
|
||||
|
||||
@@ -23,7 +23,7 @@ def init_replyer_prompt():
|
||||
现在,你说:""",
|
||||
"replyer_prompt_0",
|
||||
)
|
||||
|
||||
|
||||
Prompt(
|
||||
"""{knowledge_prompt}{tool_info_block}{extra_info_block}
|
||||
{expression_habits_block}{memory_retrieval}{jargon_explanation}
|
||||
@@ -44,4 +44,3 @@ def init_replyer_prompt():
|
||||
现在,你说:""",
|
||||
"replyer_prompt",
|
||||
)
|
||||
|
||||
|
||||
@@ -311,7 +311,10 @@ def get_raw_msg_before_timestamp_with_chat(
|
||||
filter_query = {"chat_id": chat_id, "time": {"$lt": timestamp}}
|
||||
sort_order = [("time", 1)]
|
||||
return find_messages(
|
||||
message_filter=filter_query, sort=sort_order, limit=limit, filter_intercept_message_level=filter_intercept_message_level
|
||||
message_filter=filter_query,
|
||||
sort=sort_order,
|
||||
limit=limit,
|
||||
filter_intercept_message_level=filter_intercept_message_level,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -746,7 +746,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
data_fmt = "{:<32} {:>10} {:>12} {:>12} {:>12} {:>9.2f}¥ {:>10.1f} {:>10.1f} {:>12} {:>12}"
|
||||
|
||||
total_replies = stats.get(TOTAL_REPLY_CNT, 0)
|
||||
|
||||
|
||||
output = [
|
||||
"按模型分类统计:",
|
||||
" 模型名称 调用次数 输入Token 输出Token Token总量 累计花费 平均耗时(秒) 标准差(秒) 每次回复平均调用次数 每次回复平均Token数",
|
||||
@@ -759,11 +759,11 @@ class StatisticOutputTask(AsyncTask):
|
||||
cost = stats[COST_BY_MODEL][model_name]
|
||||
avg_time_cost = stats[AVG_TIME_COST_BY_MODEL][model_name]
|
||||
std_time_cost = stats[STD_TIME_COST_BY_MODEL][model_name]
|
||||
|
||||
|
||||
# 计算每次回复平均值
|
||||
avg_count_per_reply = count / total_replies if total_replies > 0 else 0.0
|
||||
avg_tokens_per_reply = tokens / total_replies if total_replies > 0 else 0.0
|
||||
|
||||
|
||||
# 格式化大数字
|
||||
formatted_count = _format_large_number(count)
|
||||
formatted_in_tokens = _format_large_number(in_tokens)
|
||||
@@ -771,7 +771,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
formatted_tokens = _format_large_number(tokens)
|
||||
formatted_avg_count = _format_large_number(avg_count_per_reply) if total_replies > 0 else "N/A"
|
||||
formatted_avg_tokens = _format_large_number(avg_tokens_per_reply) if total_replies > 0 else "N/A"
|
||||
|
||||
|
||||
output.append(
|
||||
data_fmt.format(
|
||||
name,
|
||||
@@ -800,7 +800,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
data_fmt = "{:<32} {:>10} {:>12} {:>12} {:>12} {:>9.2f}¥ {:>10.1f} {:>10.1f} {:>12} {:>12}"
|
||||
|
||||
total_replies = stats.get(TOTAL_REPLY_CNT, 0)
|
||||
|
||||
|
||||
output = [
|
||||
"按模块分类统计:",
|
||||
" 模块名称 调用次数 输入Token 输出Token Token总量 累计花费 平均耗时(秒) 标准差(秒) 每次回复平均调用次数 每次回复平均Token数",
|
||||
@@ -813,11 +813,11 @@ class StatisticOutputTask(AsyncTask):
|
||||
cost = stats[COST_BY_MODULE][module_name]
|
||||
avg_time_cost = stats[AVG_TIME_COST_BY_MODULE][module_name]
|
||||
std_time_cost = stats[STD_TIME_COST_BY_MODULE][module_name]
|
||||
|
||||
|
||||
# 计算每次回复平均值
|
||||
avg_count_per_reply = count / total_replies if total_replies > 0 else 0.0
|
||||
avg_tokens_per_reply = tokens / total_replies if total_replies > 0 else 0.0
|
||||
|
||||
|
||||
# 格式化大数字
|
||||
formatted_count = _format_large_number(count)
|
||||
formatted_in_tokens = _format_large_number(in_tokens)
|
||||
@@ -825,7 +825,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
formatted_tokens = _format_large_number(tokens)
|
||||
formatted_avg_count = _format_large_number(avg_count_per_reply) if total_replies > 0 else "N/A"
|
||||
formatted_avg_tokens = _format_large_number(avg_tokens_per_reply) if total_replies > 0 else "N/A"
|
||||
|
||||
|
||||
output.append(
|
||||
data_fmt.format(
|
||||
name,
|
||||
|
||||
@@ -646,7 +646,7 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional["TargetP
|
||||
def record_replyer_action_temp(chat_id: str, reason: str, think_level: int) -> None:
|
||||
"""
|
||||
临时记录replyer动作被选择的信息(仅群聊)
|
||||
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
reason: 选择理由
|
||||
@@ -656,7 +656,7 @@ def record_replyer_action_temp(chat_id: str, reason: str, think_level: int) -> N
|
||||
# 确保data/temp目录存在
|
||||
temp_dir = "data/temp"
|
||||
os.makedirs(temp_dir, exist_ok=True)
|
||||
|
||||
|
||||
# 创建记录数据
|
||||
record_data = {
|
||||
"chat_id": chat_id,
|
||||
@@ -664,16 +664,16 @@ def record_replyer_action_temp(chat_id: str, reason: str, think_level: int) -> N
|
||||
"think_level": think_level,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
}
|
||||
|
||||
|
||||
# 生成文件名(使用时间戳避免冲突)
|
||||
timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
|
||||
filename = f"replyer_action_{timestamp_str}.json"
|
||||
filepath = os.path.join(temp_dir, filename)
|
||||
|
||||
|
||||
# 写入文件
|
||||
with open(filepath, "w", encoding="utf-8") as f:
|
||||
json.dump(record_data, f, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
logger.debug(f"已记录replyer动作选择: chat_id={chat_id}, think_level={think_level}")
|
||||
except Exception as e:
|
||||
logger.warning(f"记录replyer动作选择失败: {e}")
|
||||
|
||||
@@ -130,12 +130,10 @@ class ImageManager:
|
||||
try:
|
||||
# 清理Images表中type为emoji的记录
|
||||
deleted_images = Images.delete().where(Images.type == "emoji").execute()
|
||||
|
||||
|
||||
# 清理ImageDescriptions表中type为emoji的记录
|
||||
deleted_descriptions = (
|
||||
ImageDescriptions.delete().where(ImageDescriptions.type == "emoji").execute()
|
||||
)
|
||||
|
||||
deleted_descriptions = ImageDescriptions.delete().where(ImageDescriptions.type == "emoji").execute()
|
||||
|
||||
total_deleted = deleted_images + deleted_descriptions
|
||||
if total_deleted > 0:
|
||||
logger.info(
|
||||
@@ -166,7 +164,7 @@ class ImageManager:
|
||||
|
||||
async def _save_emoji_file_if_needed(self, image_base64: str, image_hash: str, image_format: str) -> None:
|
||||
"""如果启用了steal_emoji且表情包未注册,保存文件到data/emoji目录
|
||||
|
||||
|
||||
Args:
|
||||
image_base64: 图片的base64编码
|
||||
image_hash: 图片的MD5哈希值
|
||||
@@ -174,7 +172,7 @@ class ImageManager:
|
||||
"""
|
||||
if not global_config.emoji.steal_emoji:
|
||||
return
|
||||
|
||||
|
||||
try:
|
||||
from src.chat.emoji_system.emoji_manager import EMOJI_DIR
|
||||
from src.chat.emoji_system.emoji_manager import get_emoji_manager
|
||||
@@ -236,12 +234,16 @@ class ImageManager:
|
||||
# 优先使用情感标签,如果没有则使用详细描述
|
||||
result_text = ""
|
||||
if cache_record.emotion_tags:
|
||||
logger.info(f"[缓存命中] 使用EmojiDescriptionCache表中的情感标签: {cache_record.emotion_tags[:50]}...")
|
||||
logger.info(
|
||||
f"[缓存命中] 使用EmojiDescriptionCache表中的情感标签: {cache_record.emotion_tags[:50]}..."
|
||||
)
|
||||
result_text = f"[表情包:{cache_record.emotion_tags}]"
|
||||
elif cache_record.description:
|
||||
logger.info(f"[缓存命中] 使用EmojiDescriptionCache表中的描述: {cache_record.description[:50]}...")
|
||||
logger.info(
|
||||
f"[缓存命中] 使用EmojiDescriptionCache表中的描述: {cache_record.description[:50]}..."
|
||||
)
|
||||
result_text = f"[表情包:{cache_record.description}]"
|
||||
|
||||
|
||||
# 即使缓存命中,如果启用了steal_emoji,也检查是否需要保存文件
|
||||
if result_text:
|
||||
await self._save_emoji_file_if_needed(image_base64, image_hash, image_format)
|
||||
|
||||
@@ -609,23 +609,23 @@ def _fix_table_constraints(table_name, model, constraints_to_fix):
|
||||
fields = list(model._meta.fields.keys())
|
||||
# Peewee 默认使用 'id' 作为主键字段名
|
||||
# 尝试获取主键字段名,如果获取失败则默认使用 'id'
|
||||
primary_key_name = 'id' # 默认值
|
||||
primary_key_name = "id" # 默认值
|
||||
try:
|
||||
if hasattr(model._meta, 'primary_key') and model._meta.primary_key:
|
||||
if hasattr(model._meta.primary_key, 'name'):
|
||||
if hasattr(model._meta, "primary_key") and model._meta.primary_key:
|
||||
if hasattr(model._meta.primary_key, "name"):
|
||||
primary_key_name = model._meta.primary_key.name
|
||||
elif isinstance(model._meta.primary_key, str):
|
||||
primary_key_name = model._meta.primary_key
|
||||
except Exception:
|
||||
pass # 如果获取失败,使用默认值 'id'
|
||||
|
||||
|
||||
# 如果字段列表包含主键,则排除它
|
||||
if primary_key_name in fields:
|
||||
fields_without_pk = [f for f in fields if f != primary_key_name]
|
||||
logger.info(f"排除主键字段 '{primary_key_name}',让数据库自动生成新的主键")
|
||||
else:
|
||||
fields_without_pk = fields
|
||||
|
||||
|
||||
fields_str = ", ".join(fields_without_pk)
|
||||
|
||||
# 检查是否有字段需要从 NULL 改为 NOT NULL
|
||||
|
||||
@@ -34,7 +34,7 @@ def _format_toml_value(obj: Any, threshold: int, depth: int = 0) -> Any:
|
||||
return obj
|
||||
|
||||
# 决定是否多行:仅在顶层且长度超过阈值时
|
||||
should_multiline = (depth == 0 and len(obj) > threshold)
|
||||
should_multiline = depth == 0 and len(obj) > threshold
|
||||
|
||||
# 如果已经是 tomlkit Array,原地修改以保留注释
|
||||
if isinstance(obj, Array):
|
||||
@@ -46,7 +46,7 @@ def _format_toml_value(obj: Any, threshold: int, depth: int = 0) -> Any:
|
||||
# 普通 list:转换为 tomlkit 数组
|
||||
arr = tomlkit.array()
|
||||
arr.multiline(should_multiline)
|
||||
|
||||
|
||||
for item in obj:
|
||||
arr.append(_format_toml_value(item, threshold, depth + 1))
|
||||
return arr
|
||||
@@ -112,7 +112,7 @@ def save_toml_with_format(
|
||||
formatted = _format_toml_value(data, multiline_threshold) if multiline_threshold >= 0 else data
|
||||
output = tomlkit.dumps(formatted)
|
||||
# 规范化:将 3+ 连续空行压缩为 1 个空行,防止空行累积
|
||||
output = re.sub(r'\n{3,}', '\n\n', output)
|
||||
output = re.sub(r"\n{3,}", "\n\n", output)
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
f.write(output)
|
||||
|
||||
@@ -122,4 +122,4 @@ def format_toml_string(data: Any, multiline_threshold: int = 1) -> str:
|
||||
formatted = _format_toml_value(data, multiline_threshold) if multiline_threshold >= 0 else data
|
||||
output = tomlkit.dumps(formatted)
|
||||
# 规范化:将 3+ 连续空行压缩为 1 个空行,防止空行累积
|
||||
return re.sub(r'\n{3,}', '\n\n', output)
|
||||
return re.sub(r"\n{3,}", "\n\n", output)
|
||||
|
||||
@@ -778,9 +778,9 @@ class DreamConfig(ConfigBase):
|
||||
"""
|
||||
if not self.dream_time_ranges:
|
||||
return True
|
||||
|
||||
|
||||
now_min = self._now_minutes()
|
||||
|
||||
|
||||
for time_range in self.dream_time_ranges:
|
||||
if not isinstance(time_range, str):
|
||||
continue
|
||||
@@ -790,7 +790,7 @@ class DreamConfig(ConfigBase):
|
||||
start_min, end_min = parsed
|
||||
if self._in_range(now_min, start_min, end_min):
|
||||
return True
|
||||
|
||||
|
||||
return False
|
||||
|
||||
def __post_init__(self):
|
||||
@@ -800,4 +800,4 @@ class DreamConfig(ConfigBase):
|
||||
if self.max_iterations < 1:
|
||||
raise ValueError(f"max_iterations 必须至少为1,当前值: {self.max_iterations}")
|
||||
if self.first_delay_seconds < 0:
|
||||
raise ValueError(f"first_delay_seconds 不能为负数,当前值: {self.first_delay_seconds}")
|
||||
raise ValueError(f"first_delay_seconds 不能为负数,当前值: {self.first_delay_seconds}")
|
||||
|
||||
@@ -1,14 +1,13 @@
|
||||
import asyncio
|
||||
import random
|
||||
import time
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from peewee import fn
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.common.database.database_model import ChatHistory, Jargon
|
||||
from src.common.database.database_model import ChatHistory
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.llm_models.payload_content.message import MessageBuilder, RoleType, Message
|
||||
from src.plugin_system.apis import llm_api
|
||||
@@ -82,7 +81,6 @@ def init_dream_prompts() -> None:
|
||||
)
|
||||
|
||||
|
||||
|
||||
class DreamTool:
|
||||
"""dream 模块内部使用的简易工具封装"""
|
||||
|
||||
@@ -150,7 +148,13 @@ def init_dream_tools(chat_id: str) -> None:
|
||||
"search_chat_history",
|
||||
"根据关键词或参与人查询当前 chat_id 下的 ChatHistory 概览,便于快速定位相关记忆。",
|
||||
[
|
||||
("keyword", ToolParamType.STRING, "关键词(可选,支持多个关键词,可用空格、逗号等分隔)。", False, None),
|
||||
(
|
||||
"keyword",
|
||||
ToolParamType.STRING,
|
||||
"关键词(可选,支持多个关键词,可用空格、逗号等分隔)。",
|
||||
False,
|
||||
None,
|
||||
),
|
||||
("participant", ToolParamType.STRING, "参与人昵称(可选)。", False, None),
|
||||
],
|
||||
search_chat_history,
|
||||
@@ -201,8 +205,20 @@ def init_dream_tools(chat_id: str) -> None:
|
||||
[
|
||||
("theme", ToolParamType.STRING, "新的主题标题(必填)。", True, None),
|
||||
("summary", ToolParamType.STRING, "新的概括内容(必填)。", True, None),
|
||||
("keywords", ToolParamType.STRING, "新的关键词 JSON 字符串,如 ['关键词1','关键词2'](必填)。", True, None),
|
||||
("key_point", ToolParamType.STRING, "新的关键信息 JSON 字符串,如 ['要点1','要点2'](必填)。", True, None),
|
||||
(
|
||||
"keywords",
|
||||
ToolParamType.STRING,
|
||||
"新的关键词 JSON 字符串,如 ['关键词1','关键词2'](必填)。",
|
||||
True,
|
||||
None,
|
||||
),
|
||||
(
|
||||
"key_point",
|
||||
ToolParamType.STRING,
|
||||
"新的关键信息 JSON 字符串,如 ['要点1','要点2'](必填)。",
|
||||
True,
|
||||
None,
|
||||
),
|
||||
("start_time", ToolParamType.STRING, "起始时间戳(秒,Unix 时间,必填)。", True, None),
|
||||
("end_time", ToolParamType.STRING, "结束时间戳(秒,Unix 时间,必填)。", True, None),
|
||||
],
|
||||
@@ -215,7 +231,13 @@ def init_dream_tools(chat_id: str) -> None:
|
||||
"finish_maintenance",
|
||||
"结束本次 dream 维护任务。当你认为当前 chat_id 下的维护工作已经完成,没有更多需要整理、合并或修改的内容时,调用此工具来主动结束本次运行。",
|
||||
[
|
||||
("reason", ToolParamType.STRING, "结束维护的原因说明(可选),例如 '已完成所有记录的整理' 或 '当前记录质量良好,无需进一步维护'。", False, None),
|
||||
(
|
||||
"reason",
|
||||
ToolParamType.STRING,
|
||||
"结束维护的原因说明(可选),例如 '已完成所有记录的整理' 或 '当前记录质量良好,无需进一步维护'。",
|
||||
False,
|
||||
None,
|
||||
),
|
||||
],
|
||||
finish_maintenance,
|
||||
)
|
||||
@@ -246,7 +268,7 @@ async def run_dream_agent_once(
|
||||
"""
|
||||
if max_iterations is None:
|
||||
max_iterations = global_config.dream.max_iterations
|
||||
|
||||
|
||||
start_ts = time.time()
|
||||
logger.info(f"[dream] 开始对 chat_id={chat_id} 进行 dream 维护,最多迭代 {max_iterations} 轮")
|
||||
|
||||
@@ -282,9 +304,7 @@ async def run_dream_agent_once(
|
||||
else "未知"
|
||||
)
|
||||
end_time_str = (
|
||||
time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(record.end_time))
|
||||
if record.end_time
|
||||
else "未知"
|
||||
time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(record.end_time)) if record.end_time else "未知"
|
||||
)
|
||||
detail_text = (
|
||||
f"ID={record.id}\n"
|
||||
@@ -305,8 +325,7 @@ async def run_dream_agent_once(
|
||||
start_detail_builder = MessageBuilder()
|
||||
start_detail_builder.set_role(RoleType.User)
|
||||
start_detail_builder.add_text_content(
|
||||
"【起始记忆详情】以下是本轮随机/指定的起始记忆的详细信息,供你在整理时优先参考:\n\n"
|
||||
+ detail_text
|
||||
"【起始记忆详情】以下是本轮随机/指定的起始记忆的详细信息,供你在整理时优先参考:\n\n" + detail_text
|
||||
)
|
||||
conversation_messages.append(start_detail_builder.build())
|
||||
else:
|
||||
@@ -343,13 +362,17 @@ async def run_dream_agent_once(
|
||||
conversation_messages.append(round_info_builder.build())
|
||||
|
||||
# 调用 LLM 让其决定是否要使用工具
|
||||
success, response, reasoning_content, model_name, tool_calls = (
|
||||
await llm_api.generate_with_model_with_tools_by_message_factory(
|
||||
message_factory,
|
||||
model_config=model_config.model_task_config.tool_use,
|
||||
tool_options=tool_defs,
|
||||
request_type="dream.react",
|
||||
)
|
||||
(
|
||||
success,
|
||||
response,
|
||||
reasoning_content,
|
||||
model_name,
|
||||
tool_calls,
|
||||
) = await llm_api.generate_with_model_with_tools_by_message_factory(
|
||||
message_factory,
|
||||
model_config=model_config.model_task_config.tool_use,
|
||||
tool_options=tool_defs,
|
||||
request_type="dream.react",
|
||||
)
|
||||
|
||||
if not success:
|
||||
@@ -522,7 +545,7 @@ async def start_dream_scheduler(
|
||||
|
||||
if interval_seconds is None:
|
||||
interval_seconds = global_config.dream.interval_minutes * 60
|
||||
|
||||
|
||||
logger.info(
|
||||
f"[dream] dream 调度器启动:首次延迟 {first_delay_seconds}s,之后每隔 {interval_seconds}s ({interval_seconds // 60} 分钟) 运行一次 dream agent"
|
||||
)
|
||||
@@ -555,4 +578,3 @@ async def start_dream_scheduler(
|
||||
|
||||
# 初始化提示词
|
||||
init_dream_prompts()
|
||||
|
||||
|
||||
@@ -86,7 +86,7 @@ async def generate_dream_summary(
|
||||
try:
|
||||
import json
|
||||
from src.chat.utils.prompt_builder import global_prompt_manager
|
||||
|
||||
|
||||
# 第一步:建立工具调用结果映射 (call_id -> result)
|
||||
tool_results_map: dict[str, str] = {}
|
||||
for msg in conversation_messages:
|
||||
@@ -98,11 +98,11 @@ async def generate_dream_summary(
|
||||
else:
|
||||
content = str(msg.content)
|
||||
tool_results_map[msg.tool_call_id] = content
|
||||
|
||||
|
||||
# 第二步:详细记录所有工具调用操作和结果到日志
|
||||
tool_call_count = 0
|
||||
logger.info(f"[dream][工具调用详情] 开始记录 chat_id={chat_id} 的所有工具调用操作:")
|
||||
|
||||
|
||||
for msg in conversation_messages:
|
||||
if msg.role == RoleType.Assistant and msg.tool_calls:
|
||||
tool_call_count += 1
|
||||
@@ -110,34 +110,38 @@ async def generate_dream_summary(
|
||||
thought_content = ""
|
||||
if msg.content:
|
||||
if isinstance(msg.content, list) and msg.content:
|
||||
thought_content = msg.content[0].text if hasattr(msg.content[0], "text") else str(msg.content[0])
|
||||
thought_content = (
|
||||
msg.content[0].text if hasattr(msg.content[0], "text") else str(msg.content[0])
|
||||
)
|
||||
else:
|
||||
thought_content = str(msg.content)
|
||||
|
||||
|
||||
logger.info(f"[dream][工具调用详情] === 第 {tool_call_count} 组工具调用 ===")
|
||||
if thought_content:
|
||||
logger.info(f"[dream][工具调用详情] 思考内容:{thought_content[:500]}{'...' if len(thought_content) > 500 else ''}")
|
||||
|
||||
logger.info(
|
||||
f"[dream][工具调用详情] 思考内容:{thought_content[:500]}{'...' if len(thought_content) > 500 else ''}"
|
||||
)
|
||||
|
||||
# 记录每个工具调用的详细信息
|
||||
for idx, tool_call in enumerate(msg.tool_calls, 1):
|
||||
tool_name = tool_call.func_name
|
||||
tool_args = tool_call.args or {}
|
||||
tool_call_id = tool_call.call_id
|
||||
tool_result = tool_results_map.get(tool_call_id, "未找到执行结果")
|
||||
|
||||
|
||||
# 格式化参数
|
||||
try:
|
||||
args_str = json.dumps(tool_args, ensure_ascii=False, indent=2) if tool_args else "无参数"
|
||||
except Exception:
|
||||
args_str = str(tool_args)
|
||||
|
||||
|
||||
logger.info(f"[dream][工具调用详情] --- 工具 {idx}: {tool_name} ---")
|
||||
logger.info(f"[dream][工具调用详情] 调用参数:\n{args_str}")
|
||||
logger.info(f"[dream][工具调用详情] 执行结果:\n{tool_result}")
|
||||
logger.info(f"[dream][工具调用详情] {'-' * 60}")
|
||||
|
||||
|
||||
logger.info(f"[dream][工具调用详情] 共记录了 {tool_call_count} 组工具调用操作")
|
||||
|
||||
|
||||
# 第三步:构建对话历史摘要(用于生成梦境)
|
||||
conversation_summary = []
|
||||
for msg in conversation_messages:
|
||||
@@ -145,11 +149,11 @@ async def generate_dream_summary(
|
||||
content = ""
|
||||
if msg.content:
|
||||
content = msg.content[0].text if isinstance(msg.content, list) and msg.content else str(msg.content)
|
||||
|
||||
|
||||
if role == "user" and "轮次信息" in content:
|
||||
# 跳过轮次信息消息
|
||||
continue
|
||||
|
||||
|
||||
if role == "assistant":
|
||||
# 只保留思考内容,简化工具调用信息
|
||||
if content:
|
||||
@@ -162,13 +166,13 @@ async def generate_dream_summary(
|
||||
# 截取前300字符
|
||||
content_preview = content[:300] + ("..." if len(content) > 300 else "")
|
||||
conversation_summary.append(f"[工具执行] {content_preview}")
|
||||
|
||||
|
||||
conversation_text = "\n".join(conversation_summary[-20:]) # 只保留最后20条消息
|
||||
|
||||
|
||||
# 随机选择2个梦境风格
|
||||
selected_styles = get_random_dream_styles(2)
|
||||
dream_styles_text = "\n".join([f"{i+1}. {style}" for i, style in enumerate(selected_styles)])
|
||||
|
||||
dream_styles_text = "\n".join([f"{i + 1}. {style}" for i, style in enumerate(selected_styles)])
|
||||
|
||||
# 使用 Prompt 管理器格式化梦境生成 prompt
|
||||
dream_prompt = await global_prompt_manager.format_prompt(
|
||||
"dream_summary_prompt",
|
||||
@@ -186,13 +190,14 @@ async def generate_dream_summary(
|
||||
max_tokens=512,
|
||||
temperature=0.8,
|
||||
)
|
||||
|
||||
|
||||
if dream_content:
|
||||
logger.info(f"[dream][梦境总结] 对 chat_id={chat_id} 的整理过程梦境:\n{dream_content}")
|
||||
else:
|
||||
logger.warning("[dream][梦境总结] 未能生成梦境总结")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[dream][梦境总结] 生成梦境总结失败: {e}", exc_info=True)
|
||||
|
||||
init_dream_summary_prompt()
|
||||
|
||||
init_dream_summary_prompt()
|
||||
|
||||
@@ -4,8 +4,3 @@ dream agent 工具实现模块。
|
||||
每个工具的具体实现放在独立文件中,通过 make_xxx(chat_id) 工厂函数
|
||||
生成绑定到特定 chat_id 的协程函数,由 dream_agent.init_dream_tools 统一注册。
|
||||
"""
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -60,8 +60,3 @@ def make_create_chat_history(chat_id: str):
|
||||
return f"create_chat_history 执行失败: {e}"
|
||||
|
||||
return create_chat_history
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -23,8 +23,3 @@ def make_delete_chat_history(chat_id: str): # chat_id 目前未直接使用,
|
||||
return f"delete_chat_history 执行失败: {e}"
|
||||
|
||||
return delete_chat_history
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -23,8 +23,3 @@ def make_delete_jargon(chat_id: str): # chat_id 目前未直接使用,预留
|
||||
return f"delete_jargon 执行失败: {e}"
|
||||
|
||||
return delete_jargon
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -14,8 +14,3 @@ def make_finish_maintenance(chat_id: str): # chat_id 目前未直接使用,
|
||||
return msg
|
||||
|
||||
return finish_maintenance
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import ChatHistory
|
||||
@@ -20,14 +19,10 @@ def make_get_chat_history_detail(chat_id: str): # chat_id 目前未直接使用
|
||||
|
||||
# 将时间戳转换为可读时间格式
|
||||
start_time_str = (
|
||||
time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(record.start_time))
|
||||
if record.start_time
|
||||
else "未知"
|
||||
time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(record.start_time)) if record.start_time else "未知"
|
||||
)
|
||||
end_time_str = (
|
||||
time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(record.end_time))
|
||||
if record.end_time
|
||||
else "未知"
|
||||
time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(record.end_time)) if record.end_time else "未知"
|
||||
)
|
||||
|
||||
result = (
|
||||
@@ -40,17 +35,10 @@ def make_get_chat_history_detail(chat_id: str): # chat_id 目前未直接使用
|
||||
f"概括={record.summary or '无'}\n"
|
||||
f"关键信息={record.key_point or '无'}"
|
||||
)
|
||||
logger.debug(
|
||||
f"[dream][tool] get_chat_history_detail 成功,预览: {result[:200].replace(chr(10), ' ')}"
|
||||
)
|
||||
logger.debug(f"[dream][tool] get_chat_history_detail 成功,预览: {result[:200].replace(chr(10), ' ')}")
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"get_chat_history_detail 失败: {e}")
|
||||
return f"get_chat_history_detail 执行失败: {e}"
|
||||
|
||||
return get_chat_history_detail
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -78,9 +78,7 @@ def make_search_chat_history(chat_id: str):
|
||||
if record.keywords:
|
||||
try:
|
||||
keywords_data = (
|
||||
json.loads(record.keywords)
|
||||
if isinstance(record.keywords, str)
|
||||
else record.keywords
|
||||
json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
|
||||
)
|
||||
if isinstance(keywords_data, list):
|
||||
record_keywords_list = [str(k).lower() for k in keywords_data]
|
||||
@@ -125,9 +123,7 @@ def make_search_chat_history(chat_id: str):
|
||||
keywords_str = "、".join(keywords_list)
|
||||
if len(keywords_list) > 2:
|
||||
required_count = len(keywords_list) - 1
|
||||
return (
|
||||
f"未找到包含至少{required_count}个关键词(共{len(keywords_list)}个)'{keywords_str}'的聊天记录"
|
||||
)
|
||||
return f"未找到包含至少{required_count}个关键词(共{len(keywords_list)}个)'{keywords_str}'的聊天记录"
|
||||
else:
|
||||
return f"未找到包含所有关键词'{keywords_str}'的聊天记录"
|
||||
elif participant:
|
||||
@@ -142,9 +138,7 @@ def make_search_chat_history(chat_id: str):
|
||||
if record.keywords:
|
||||
try:
|
||||
keywords_data = (
|
||||
json.loads(record.keywords)
|
||||
if isinstance(record.keywords, str)
|
||||
else record.keywords
|
||||
json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
|
||||
)
|
||||
if isinstance(keywords_data, list):
|
||||
for k in keywords_data:
|
||||
@@ -160,13 +154,13 @@ def make_search_chat_history(chat_id: str):
|
||||
keywords_str = "、".join(sorted(all_keywords_set))
|
||||
response_text = (
|
||||
f"包含“{search_label}”的结果过多,请尝试更多关键词精确查找\n\n"
|
||||
f"有关\"{search_label}\"的关键词:\n"
|
||||
f'有关"{search_label}"的关键词:\n'
|
||||
f"{keywords_str}"
|
||||
)
|
||||
else:
|
||||
response_text = (
|
||||
f"包含“{search_label}”的结果过多,请尝试更多关键词精确查找\n\n"
|
||||
f"有关\"{search_label}\"的关键词信息为空"
|
||||
f'有关"{search_label}"的关键词信息为空'
|
||||
)
|
||||
|
||||
logger.info(
|
||||
@@ -192,9 +186,7 @@ def make_search_chat_history(chat_id: str):
|
||||
if record.keywords:
|
||||
try:
|
||||
keywords_data = (
|
||||
json.loads(record.keywords)
|
||||
if isinstance(record.keywords, str)
|
||||
else record.keywords
|
||||
json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
|
||||
)
|
||||
if isinstance(keywords_data, list) and keywords_data:
|
||||
keywords_str = "、".join([str(k) for k in keywords_data])
|
||||
@@ -220,8 +212,3 @@ def make_search_chat_history(chat_id: str):
|
||||
return f"search_chat_history 执行失败: {e}"
|
||||
|
||||
return search_chat_history
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -16,9 +16,7 @@ def make_search_jargon(chat_id: str):
|
||||
if not keyword or not keyword.strip():
|
||||
return "未指定查询关键词(参数 keyword 为必填,且不能为空)"
|
||||
|
||||
logger.info(
|
||||
f"[dream][tool] 调用 search_jargon(keyword={keyword}) (作用域 chat_id={chat_id})"
|
||||
)
|
||||
logger.info(f"[dream][tool] 调用 search_jargon(keyword={keyword}) (作用域 chat_id={chat_id})")
|
||||
|
||||
# 基础条件:只查 is_jargon=True 的记录
|
||||
query = Jargon.select().where(Jargon.is_jargon)
|
||||
@@ -102,5 +100,3 @@ def make_search_jargon(chat_id: str):
|
||||
return f"search_jargon 执行失败: {e}"
|
||||
|
||||
return search_jargon
|
||||
|
||||
|
||||
|
||||
@@ -49,8 +49,3 @@ def make_update_chat_history(chat_id: str): # chat_id 目前未直接使用,
|
||||
return f"update_chat_history 执行失败: {e}"
|
||||
|
||||
return update_chat_history
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -49,8 +49,3 @@ def make_update_jargon(chat_id: str): # chat_id 目前未直接使用,预留
|
||||
return f"update_jargon 执行失败: {e}"
|
||||
|
||||
return update_jargon
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -316,7 +316,9 @@ class ChatHistorySummarizer:
|
||||
before_count = len(self.current_batch.messages)
|
||||
self.current_batch.messages.extend(new_messages)
|
||||
self.current_batch.end_time = current_time
|
||||
logger.info(f"{self.log_prefix} 更新聊天检查批次: {before_count} -> {len(self.current_batch.messages)} 条消息")
|
||||
logger.info(
|
||||
f"{self.log_prefix} 更新聊天检查批次: {before_count} -> {len(self.current_batch.messages)} 条消息"
|
||||
)
|
||||
# 更新批次后持久化
|
||||
self._persist_topic_cache()
|
||||
else:
|
||||
@@ -362,9 +364,7 @@ class ChatHistorySummarizer:
|
||||
else:
|
||||
time_str = f"{time_since_last_check / 3600:.1f}小时"
|
||||
|
||||
logger.debug(
|
||||
f"{self.log_prefix} 批次状态检查 | 消息数: {message_count} | 距上次检查: {time_str}"
|
||||
)
|
||||
logger.debug(f"{self.log_prefix} 批次状态检查 | 消息数: {message_count} | 距上次检查: {time_str}")
|
||||
|
||||
# 检查“话题检查”触发条件
|
||||
should_check = False
|
||||
@@ -414,7 +414,7 @@ class ChatHistorySummarizer:
|
||||
# 说明 bot 没有参与这段对话,不应该记录
|
||||
bot_user_id = str(global_config.bot.qq_account)
|
||||
has_bot_message = False
|
||||
|
||||
|
||||
for msg in messages:
|
||||
if msg.user_info.user_id == bot_user_id:
|
||||
has_bot_message = True
|
||||
@@ -427,7 +427,9 @@ class ChatHistorySummarizer:
|
||||
return
|
||||
|
||||
# 2. 构造编号后的消息字符串和参与者信息
|
||||
numbered_lines, index_to_msg_str, index_to_msg_text, index_to_participants = self._build_numbered_messages_for_llm(messages)
|
||||
numbered_lines, index_to_msg_str, index_to_msg_text, index_to_participants = (
|
||||
self._build_numbered_messages_for_llm(messages)
|
||||
)
|
||||
|
||||
# 3. 调用 LLM 识别话题,并得到 topic -> indices(失败时最多重试 3 次)
|
||||
existing_topics = list(self.topic_cache.keys())
|
||||
@@ -456,9 +458,7 @@ class ChatHistorySummarizer:
|
||||
)
|
||||
|
||||
if not success or not topic_to_indices:
|
||||
logger.error(
|
||||
f"{self.log_prefix} 话题识别连续 {max_retries} 次失败或始终无有效话题,本次检查放弃"
|
||||
)
|
||||
logger.error(f"{self.log_prefix} 话题识别连续 {max_retries} 次失败或始终无有效话题,本次检查放弃")
|
||||
# 即使识别失败,也认为是一次“检查”,但不更新 no_update_checks(保持原状)
|
||||
return
|
||||
|
||||
@@ -610,9 +610,7 @@ class ChatHistorySummarizer:
|
||||
if not numbered_lines:
|
||||
return False, {}
|
||||
|
||||
history_topics_block = (
|
||||
"\n".join(f"- {t}" for t in existing_topics) if existing_topics else "(当前无历史话题)"
|
||||
)
|
||||
history_topics_block = "\n".join(f"- {t}" for t in existing_topics) if existing_topics else "(当前无历史话题)"
|
||||
messages_block = "\n".join(numbered_lines)
|
||||
|
||||
prompt = await global_prompt_manager.format_prompt(
|
||||
@@ -635,17 +633,17 @@ class ChatHistorySummarizer:
|
||||
json_str = None
|
||||
json_pattern = r"```json\s*(.*?)\s*```"
|
||||
matches = re.findall(json_pattern, response, re.DOTALL)
|
||||
|
||||
|
||||
if matches:
|
||||
# 找到JSON代码块,使用第一个匹配
|
||||
json_str = matches[0].strip()
|
||||
else:
|
||||
# 如果没有找到代码块,尝试查找JSON数组的开始和结束位置
|
||||
# 查找第一个 [ 和最后一个 ]
|
||||
start_idx = response.find('[')
|
||||
end_idx = response.rfind(']')
|
||||
start_idx = response.find("[")
|
||||
end_idx = response.rfind("]")
|
||||
if start_idx != -1 and end_idx != -1 and end_idx > start_idx:
|
||||
json_str = response[start_idx:end_idx + 1].strip()
|
||||
json_str = response[start_idx : end_idx + 1].strip()
|
||||
else:
|
||||
# 如果还是找不到,尝试直接使用整个响应(移除可能的markdown标记)
|
||||
json_str = response.strip()
|
||||
@@ -942,4 +940,3 @@ class ChatHistorySummarizer:
|
||||
|
||||
|
||||
init_prompt()
|
||||
|
||||
|
||||
@@ -49,7 +49,7 @@ class LLMRequest:
|
||||
|
||||
def _check_slow_request(self, time_cost: float, model_name: str) -> None:
|
||||
"""检查请求是否过慢并输出警告日志
|
||||
|
||||
|
||||
Args:
|
||||
time_cost: 请求耗时(秒)
|
||||
model_name: 使用的模型名称
|
||||
@@ -323,7 +323,7 @@ class LLMRequest:
|
||||
effective_temperature = (model_info.extra_params or {}).get("temperature")
|
||||
if effective_temperature is None:
|
||||
effective_temperature = self.model_for_task.temperature
|
||||
|
||||
|
||||
# max_tokens 优先级:参数传入 > 模型级别配置 > extra_params > 任务配置
|
||||
effective_max_tokens = max_tokens
|
||||
if effective_max_tokens is None:
|
||||
@@ -332,7 +332,7 @@ class LLMRequest:
|
||||
effective_max_tokens = (model_info.extra_params or {}).get("max_tokens")
|
||||
if effective_max_tokens is None:
|
||||
effective_max_tokens = self.model_for_task.max_tokens
|
||||
|
||||
|
||||
return await client.get_response(
|
||||
model_info=model_info,
|
||||
message_list=(compressed_messages or message_list),
|
||||
@@ -366,7 +366,9 @@ class LLMRequest:
|
||||
logger.error(f"模型 '{model_info.name}' 在多次出现空回复后仍然失败。{original_error_info}")
|
||||
raise ModelAttemptFailed(f"模型 '{model_info.name}' 重试耗尽", original_exception=e) from e
|
||||
|
||||
logger.warning(f"模型 '{model_info.name}' 返回空回复(可重试){original_error_info}。剩余重试次数: {retry_remain}")
|
||||
logger.warning(
|
||||
f"模型 '{model_info.name}' 返回空回复(可重试){original_error_info}。剩余重试次数: {retry_remain}"
|
||||
)
|
||||
await asyncio.sleep(api_provider.retry_interval)
|
||||
|
||||
except NetworkConnectionError as e:
|
||||
@@ -394,7 +396,9 @@ class LLMRequest:
|
||||
if e.status_code == 429 or e.status_code >= 500:
|
||||
retry_remain -= 1
|
||||
if retry_remain <= 0:
|
||||
logger.error(f"模型 '{model_info.name}' 在遇到 {e.status_code} 错误并用尽重试次数后仍然失败。{original_error_info}")
|
||||
logger.error(
|
||||
f"模型 '{model_info.name}' 在遇到 {e.status_code} 错误并用尽重试次数后仍然失败。{original_error_info}"
|
||||
)
|
||||
raise ModelAttemptFailed(f"模型 '{model_info.name}' 重试耗尽", original_exception=e) from e
|
||||
|
||||
logger.warning(
|
||||
@@ -540,7 +544,5 @@ class LLMRequest:
|
||||
if e.__cause__:
|
||||
original_error_type = type(e.__cause__).__name__
|
||||
original_error_msg = str(e.__cause__)
|
||||
return (
|
||||
f"\n 底层异常类型: {original_error_type}\n 底层异常信息: {original_error_msg}"
|
||||
)
|
||||
return f"\n 底层异常类型: {original_error_type}\n 底层异常信息: {original_error_msg}"
|
||||
return ""
|
||||
|
||||
@@ -113,7 +113,6 @@ class MainSystem:
|
||||
get_emoji_manager().initialize()
|
||||
logger.info("表情包管理器初始化成功")
|
||||
|
||||
|
||||
# 初始化聊天管理器
|
||||
await get_chat_manager()._initialize()
|
||||
asyncio.create_task(get_chat_manager()._auto_save_task())
|
||||
|
||||
@@ -136,8 +136,6 @@ def init_memory_retrieval_prompt():
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
def _log_conversation_messages(
|
||||
conversation_messages: List[Message],
|
||||
head_prompt: Optional[str] = None,
|
||||
@@ -172,7 +170,9 @@ def _log_conversation_messages(
|
||||
|
||||
# 构建单条消息的日志信息
|
||||
# msg_info = f"\n========================================\n[消息 {idx}] 角色: {role_name} 内容类型: {content_type}\n-----------------------------"
|
||||
msg_info = f"\n========================================\n[消息 {idx}] 角色: {role_name}\n-----------------------------"
|
||||
msg_info = (
|
||||
f"\n========================================\n[消息 {idx}] 角色: {role_name}\n-----------------------------"
|
||||
)
|
||||
|
||||
# if full_content:
|
||||
# msg_info += f"\n{full_content}"
|
||||
@@ -185,8 +185,7 @@ def _log_conversation_messages(
|
||||
msg_info += f"\n - {tool_call.func_name}: {json.dumps(tool_call.args, ensure_ascii=False)}"
|
||||
|
||||
# if msg.tool_call_id:
|
||||
# msg_info += f"\n 工具调用ID: {msg.tool_call_id}"
|
||||
|
||||
# msg_info += f"\n 工具调用ID: {msg.tool_call_id}"
|
||||
|
||||
log_lines.append(msg_info)
|
||||
|
||||
@@ -330,7 +329,7 @@ async def _react_agent_solve_question(
|
||||
remaining_iterations=remaining_iterations,
|
||||
max_iterations=max_iterations,
|
||||
)
|
||||
|
||||
|
||||
# 后续迭代都复用第一次构建的head_prompt
|
||||
head_prompt = first_head_prompt
|
||||
|
||||
@@ -365,7 +364,7 @@ async def _react_agent_solve_question(
|
||||
)
|
||||
|
||||
# logger.info(
|
||||
# f"ReAct Agent 第 {iteration + 1} 次迭代 模型: {model_name} ,调用工具数量: {len(tool_calls) if tool_calls else 0} ,调用工具响应: {response}"
|
||||
# f"ReAct Agent 第 {iteration + 1} 次迭代 模型: {model_name} ,调用工具数量: {len(tool_calls) if tool_calls else 0} ,调用工具响应: {response}"
|
||||
# )
|
||||
|
||||
if not success:
|
||||
@@ -409,20 +408,20 @@ async def _react_agent_solve_question(
|
||||
"""从文本中解析finish_search函数调用,返回(found_answer, answer)元组,如果未找到则返回(None, None)"""
|
||||
if not text:
|
||||
return None, None
|
||||
|
||||
|
||||
# 查找finish_search函数调用位置(不区分大小写)
|
||||
func_pattern = "finish_search"
|
||||
text_lower = text.lower()
|
||||
func_pos = text_lower.find(func_pattern)
|
||||
if func_pos == -1:
|
||||
return None, None
|
||||
|
||||
|
||||
# 查找函数调用的开始和结束位置
|
||||
# 从func_pos开始向后查找左括号
|
||||
start_pos = text.find("(", func_pos)
|
||||
if start_pos == -1:
|
||||
return None, None
|
||||
|
||||
|
||||
# 查找匹配的右括号(考虑嵌套)
|
||||
paren_count = 0
|
||||
end_pos = start_pos
|
||||
@@ -437,10 +436,10 @@ async def _react_agent_solve_question(
|
||||
else:
|
||||
# 没有找到匹配的右括号
|
||||
return None, None
|
||||
|
||||
|
||||
# 提取函数参数部分
|
||||
params_text = text[start_pos + 1 : end_pos]
|
||||
|
||||
|
||||
# 解析found_answer参数(布尔值,可能是true/false/True/False)
|
||||
found_answer = None
|
||||
found_answer_patterns = [
|
||||
@@ -454,49 +453,60 @@ async def _react_agent_solve_question(
|
||||
if match:
|
||||
found_answer = "true" in match.group(0).lower()
|
||||
break
|
||||
|
||||
|
||||
# 解析answer参数(字符串,使用extract_quoted_content)
|
||||
answer = extract_quoted_content(text, "finish_search", "answer")
|
||||
|
||||
|
||||
return found_answer, answer
|
||||
|
||||
|
||||
parsed_found_answer, parsed_answer = parse_finish_search_from_text(response)
|
||||
|
||||
|
||||
if parsed_found_answer is not None:
|
||||
# 检测到finish_search函数调用格式
|
||||
if parsed_found_answer:
|
||||
# 找到了答案
|
||||
if parsed_answer:
|
||||
step["actions"].append({"action_type": "finish_search", "action_params": {"found_answer": True, "answer": parsed_answer}})
|
||||
step["actions"].append(
|
||||
{
|
||||
"action_type": "finish_search",
|
||||
"action_params": {"found_answer": True, "answer": parsed_answer},
|
||||
}
|
||||
)
|
||||
step["observations"] = ["检测到finish_search文本格式调用,找到答案"]
|
||||
thinking_steps.append(step)
|
||||
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 通过finish_search文本格式找到关于问题{question}的答案: {parsed_answer}")
|
||||
|
||||
logger.info(
|
||||
f"ReAct Agent 第 {iteration + 1} 次迭代 通过finish_search文本格式找到关于问题{question}的答案: {parsed_answer}"
|
||||
)
|
||||
|
||||
_log_conversation_messages(
|
||||
conversation_messages,
|
||||
head_prompt=first_head_prompt,
|
||||
final_status=f"找到答案:{parsed_answer}",
|
||||
)
|
||||
|
||||
|
||||
return True, parsed_answer, thinking_steps, False
|
||||
else:
|
||||
# found_answer为True但没有提供answer,视为错误,继续迭代
|
||||
logger.warning(f"ReAct Agent 第 {iteration + 1} 次迭代 finish_search文本格式found_answer为True但未提供answer")
|
||||
logger.warning(
|
||||
f"ReAct Agent 第 {iteration + 1} 次迭代 finish_search文本格式found_answer为True但未提供answer"
|
||||
)
|
||||
else:
|
||||
# 未找到答案,直接退出查询
|
||||
step["actions"].append({"action_type": "finish_search", "action_params": {"found_answer": False}})
|
||||
step["actions"].append(
|
||||
{"action_type": "finish_search", "action_params": {"found_answer": False}}
|
||||
)
|
||||
step["observations"] = ["检测到finish_search文本格式调用,未找到答案"]
|
||||
thinking_steps.append(step)
|
||||
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 通过finish_search文本格式判断未找到答案")
|
||||
|
||||
|
||||
_log_conversation_messages(
|
||||
conversation_messages,
|
||||
head_prompt=first_head_prompt,
|
||||
final_status="未找到答案:通过finish_search文本格式判断未找到答案",
|
||||
)
|
||||
|
||||
|
||||
return False, "", thinking_steps, False
|
||||
|
||||
|
||||
# 如果没有检测到finish_search格式,记录思考过程,继续下一轮迭代
|
||||
step["observations"] = [f"思考完成,但未调用工具。响应: {response}"]
|
||||
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 思考完成但未调用工具: {response}")
|
||||
@@ -514,44 +524,53 @@ async def _react_agent_solve_question(
|
||||
for tool_call in tool_calls:
|
||||
tool_name = tool_call.func_name
|
||||
tool_args = tool_call.args or {}
|
||||
|
||||
|
||||
if tool_name == "finish_search":
|
||||
finish_search_found = tool_args.get("found_answer", False)
|
||||
finish_search_answer = tool_args.get("answer", "")
|
||||
|
||||
|
||||
if finish_search_found:
|
||||
# 找到了答案
|
||||
if finish_search_answer:
|
||||
step["actions"].append({"action_type": "finish_search", "action_params": {"found_answer": True, "answer": finish_search_answer}})
|
||||
step["actions"].append(
|
||||
{
|
||||
"action_type": "finish_search",
|
||||
"action_params": {"found_answer": True, "answer": finish_search_answer},
|
||||
}
|
||||
)
|
||||
step["observations"] = ["检测到finish_search工具调用,找到答案"]
|
||||
thinking_steps.append(step)
|
||||
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 通过finish_search工具找到关于问题{question}的答案: {finish_search_answer}")
|
||||
|
||||
logger.info(
|
||||
f"ReAct Agent 第 {iteration + 1} 次迭代 通过finish_search工具找到关于问题{question}的答案: {finish_search_answer}"
|
||||
)
|
||||
|
||||
_log_conversation_messages(
|
||||
conversation_messages,
|
||||
head_prompt=first_head_prompt,
|
||||
final_status=f"找到答案:{finish_search_answer}",
|
||||
)
|
||||
|
||||
|
||||
return True, finish_search_answer, thinking_steps, False
|
||||
else:
|
||||
# found_answer为True但没有提供answer,视为错误
|
||||
logger.warning(f"ReAct Agent 第 {iteration + 1} 次迭代 finish_search工具found_answer为True但未提供answer")
|
||||
logger.warning(
|
||||
f"ReAct Agent 第 {iteration + 1} 次迭代 finish_search工具found_answer为True但未提供answer"
|
||||
)
|
||||
else:
|
||||
# 未找到答案,直接退出查询
|
||||
step["actions"].append({"action_type": "finish_search", "action_params": {"found_answer": False}})
|
||||
step["observations"] = ["检测到finish_search工具调用,未找到答案"]
|
||||
thinking_steps.append(step)
|
||||
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 通过finish_search工具判断未找到答案")
|
||||
|
||||
|
||||
_log_conversation_messages(
|
||||
conversation_messages,
|
||||
head_prompt=first_head_prompt,
|
||||
final_status="未找到答案:通过finish_search工具判断未找到答案",
|
||||
)
|
||||
|
||||
|
||||
return False, "", thinking_steps, False
|
||||
|
||||
|
||||
# 如果没有finish_search工具调用,继续处理其他工具
|
||||
tool_tasks = []
|
||||
for i, tool_call in enumerate(tool_calls):
|
||||
@@ -627,7 +646,7 @@ async def _react_agent_solve_question(
|
||||
observation_text += f"\n\n{jargon_info}"
|
||||
collected_info += f"\n{jargon_info}\n"
|
||||
logger.info(f"工具输出触发黑话解析: {new_concepts}")
|
||||
|
||||
|
||||
tool_builder = MessageBuilder()
|
||||
tool_builder.set_role(RoleType.Tool)
|
||||
tool_builder.add_text_content(observation_text)
|
||||
@@ -645,7 +664,7 @@ async def _react_agent_solve_question(
|
||||
elif iteration + 1 >= max_iterations:
|
||||
should_do_final_evaluation = True
|
||||
logger.info(f"ReAct Agent达到最大迭代次数(已迭代{iteration + 1}次),进入最终评估")
|
||||
|
||||
|
||||
if should_do_final_evaluation:
|
||||
# 获取必要变量用于最终评估
|
||||
tool_registry = get_tool_registry()
|
||||
@@ -653,7 +672,7 @@ async def _react_agent_solve_question(
|
||||
time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
|
||||
current_iteration = iteration + 1
|
||||
remaining_iterations = 0
|
||||
|
||||
|
||||
# 提取函数调用中参数的值,支持单引号和双引号
|
||||
def extract_quoted_content(text, func_name, param_name):
|
||||
"""从文本中提取函数调用中参数的值,支持单引号和双引号
|
||||
@@ -724,7 +743,13 @@ async def _react_agent_solve_question(
|
||||
max_iterations=max_iterations,
|
||||
)
|
||||
|
||||
eval_success, eval_response, eval_reasoning_content, eval_model_name, eval_tool_calls = await llm_api.generate_with_model_with_tools(
|
||||
(
|
||||
eval_success,
|
||||
eval_response,
|
||||
eval_reasoning_content,
|
||||
eval_model_name,
|
||||
eval_tool_calls,
|
||||
) = await llm_api.generate_with_model_with_tools(
|
||||
evaluation_prompt,
|
||||
model_config=model_config.model_task_config.tool_use,
|
||||
tool_options=[], # 最终评估阶段不提供工具
|
||||
@@ -739,7 +764,7 @@ async def _react_agent_solve_question(
|
||||
final_status="未找到答案:最终评估阶段LLM调用失败",
|
||||
)
|
||||
return False, "最终评估阶段LLM调用失败", thinking_steps, is_timeout
|
||||
|
||||
|
||||
if global_config.debug.show_memory_prompt:
|
||||
logger.info(f"ReAct Agent 最终评估Prompt: {evaluation_prompt}")
|
||||
logger.info(f"ReAct Agent 最终评估响应: {eval_response}")
|
||||
@@ -759,17 +784,17 @@ async def _react_agent_solve_question(
|
||||
"iteration": current_iteration,
|
||||
"thought": f"[最终评估] {eval_response}",
|
||||
"actions": [{"action_type": "found_answer", "action_params": {"answer": found_answer_content}}],
|
||||
"observations": ["最终评估阶段检测到found_answer"]
|
||||
"observations": ["最终评估阶段检测到found_answer"],
|
||||
}
|
||||
thinking_steps.append(eval_step)
|
||||
logger.info(f"ReAct Agent 最终评估阶段找到关于问题{question}的答案: {found_answer_content}")
|
||||
|
||||
|
||||
_log_conversation_messages(
|
||||
conversation_messages,
|
||||
head_prompt=first_head_prompt,
|
||||
final_status=f"找到答案:{found_answer_content}",
|
||||
)
|
||||
|
||||
|
||||
return True, found_answer_content, thinking_steps, False
|
||||
|
||||
# 如果评估为not_enough_info,返回空字符串(不返回任何信息)
|
||||
@@ -778,35 +803,37 @@ async def _react_agent_solve_question(
|
||||
"iteration": current_iteration,
|
||||
"thought": f"[最终评估] {eval_response}",
|
||||
"actions": [{"action_type": "not_enough_info", "action_params": {"reason": not_enough_info_reason}}],
|
||||
"observations": ["最终评估阶段检测到not_enough_info"]
|
||||
"observations": ["最终评估阶段检测到not_enough_info"],
|
||||
}
|
||||
thinking_steps.append(eval_step)
|
||||
logger.info(f"ReAct Agent 最终评估阶段判断信息不足: {not_enough_info_reason}")
|
||||
|
||||
|
||||
_log_conversation_messages(
|
||||
conversation_messages,
|
||||
head_prompt=first_head_prompt,
|
||||
final_status=f"未找到答案:{not_enough_info_reason}",
|
||||
)
|
||||
|
||||
|
||||
return False, "", thinking_steps, is_timeout
|
||||
|
||||
# 如果没有明确判断,视为not_enough_info,返回空字符串(不返回任何信息)
|
||||
eval_step = {
|
||||
"iteration": current_iteration,
|
||||
"thought": f"[最终评估] {eval_response}",
|
||||
"actions": [{"action_type": "not_enough_info", "action_params": {"reason": "已到达最大迭代次数,无法找到答案"}}],
|
||||
"observations": ["已到达最大迭代次数,无法找到答案"]
|
||||
"actions": [
|
||||
{"action_type": "not_enough_info", "action_params": {"reason": "已到达最大迭代次数,无法找到答案"}}
|
||||
],
|
||||
"observations": ["已到达最大迭代次数,无法找到答案"],
|
||||
}
|
||||
thinking_steps.append(eval_step)
|
||||
logger.info("ReAct Agent 已到达最大迭代次数,无法找到答案")
|
||||
|
||||
|
||||
_log_conversation_messages(
|
||||
conversation_messages,
|
||||
head_prompt=first_head_prompt,
|
||||
final_status="未找到答案:已到达最大迭代次数,无法找到答案",
|
||||
)
|
||||
|
||||
|
||||
return False, "", thinking_steps, is_timeout
|
||||
|
||||
# 如果正常迭代过程中提前找到答案返回,不会到达这里
|
||||
@@ -817,7 +844,7 @@ async def _react_agent_solve_question(
|
||||
head_prompt=first_head_prompt,
|
||||
final_status="未找到答案:正常迭代结束",
|
||||
)
|
||||
|
||||
|
||||
return False, "", thinking_steps, is_timeout
|
||||
|
||||
|
||||
@@ -1129,7 +1156,9 @@ async def build_memory_retrieval_prompt(
|
||||
else:
|
||||
max_iterations = base_max_iterations
|
||||
timeout_seconds = global_config.memory.agent_timeout_seconds
|
||||
logger.debug(f"问题数量: {len(questions)},think_level={think_level},设置最大迭代次数: {max_iterations}(基础值: {base_max_iterations}),超时时间: {timeout_seconds}秒")
|
||||
logger.debug(
|
||||
f"问题数量: {len(questions)},think_level={think_level},设置最大迭代次数: {max_iterations}(基础值: {base_max_iterations}),超时时间: {timeout_seconds}秒"
|
||||
)
|
||||
|
||||
# 并行处理所有问题,将概念检索结果作为初始信息传递
|
||||
question_tasks = [
|
||||
@@ -1157,10 +1186,10 @@ async def build_memory_retrieval_prompt(
|
||||
|
||||
# 获取最近10分钟内已找到答案的缓存记录
|
||||
cached_answers = _get_recent_found_answers(chat_id, time_window_seconds=600.0)
|
||||
|
||||
|
||||
# 合并当前查询结果和缓存答案(去重:如果当前查询的问题在缓存中已存在,优先使用当前结果)
|
||||
all_results = []
|
||||
|
||||
|
||||
# 先添加当前查询的结果
|
||||
current_questions = set()
|
||||
for result in question_results:
|
||||
@@ -1170,7 +1199,7 @@ async def build_memory_retrieval_prompt(
|
||||
if question_end != -1:
|
||||
current_questions.add(result[4:question_end])
|
||||
all_results.append(result)
|
||||
|
||||
|
||||
# 添加缓存答案(排除当前查询中已存在的问题)
|
||||
for cached_answer in cached_answers:
|
||||
if cached_answer.startswith("问题:"):
|
||||
@@ -1198,4 +1227,3 @@ async def build_memory_retrieval_prompt(
|
||||
except Exception as e:
|
||||
logger.error(f"记忆检索时发生异常: {str(e)}")
|
||||
return ""
|
||||
|
||||
|
||||
@@ -17,7 +17,6 @@ from src.common.logger import get_logger
|
||||
logger = get_logger("memory_utils")
|
||||
|
||||
|
||||
|
||||
def parse_questions_json(response: str) -> Tuple[List[str], List[str]]:
|
||||
"""解析问题JSON,返回概念列表和问题列表
|
||||
|
||||
@@ -68,6 +67,7 @@ def parse_questions_json(response: str) -> Tuple[List[str], List[str]]:
|
||||
logger.error(f"解析问题JSON失败: {e}, 响应内容: {response[:200]}...")
|
||||
return [], []
|
||||
|
||||
|
||||
def parse_datetime_to_timestamp(value: str) -> float:
|
||||
"""
|
||||
接受多种常见格式并转换为时间戳(秒)
|
||||
|
||||
@@ -47,4 +47,3 @@ def register_tool():
|
||||
],
|
||||
execute_func=finish_search,
|
||||
)
|
||||
|
||||
|
||||
@@ -16,9 +16,7 @@ from .tool_registry import register_memory_retrieval_tool
|
||||
logger = get_logger("memory_retrieval_tools")
|
||||
|
||||
|
||||
async def search_chat_history(
|
||||
chat_id: str, keyword: Optional[str] = None, participant: Optional[str] = None
|
||||
) -> str:
|
||||
async def search_chat_history(chat_id: str, keyword: Optional[str] = None, participant: Optional[str] = None) -> str:
|
||||
"""根据关键词或参与人查询记忆,返回匹配的记忆id、记忆标题theme和关键词keywords
|
||||
|
||||
Args:
|
||||
@@ -117,7 +115,7 @@ async def search_chat_history(
|
||||
)
|
||||
if kw_matched:
|
||||
matched_count += 1
|
||||
|
||||
|
||||
# 计算需要匹配的关键词数量
|
||||
total_keywords = len(keywords_lower)
|
||||
if total_keywords > 2:
|
||||
@@ -126,7 +124,7 @@ async def search_chat_history(
|
||||
else:
|
||||
# 关键词数量<=2,必须全部匹配
|
||||
required_matches = total_keywords
|
||||
|
||||
|
||||
keyword_matched = matched_count >= required_matches
|
||||
|
||||
# 两者都匹配(如果同时有participant和keyword,需要两者都匹配;如果只有一个条件,只需要该条件匹配)
|
||||
@@ -144,7 +142,9 @@ async def search_chat_history(
|
||||
keywords_list = parse_keywords_string(keyword)
|
||||
if len(keywords_list) > 2:
|
||||
required_count = len(keywords_list) - 1
|
||||
return f"未找到包含至少{required_count}个关键词(共{len(keywords_list)}个)'{keywords_str}'的聊天记录"
|
||||
return (
|
||||
f"未找到包含至少{required_count}个关键词(共{len(keywords_list)}个)'{keywords_str}'的聊天记录"
|
||||
)
|
||||
else:
|
||||
return f"未找到包含所有关键词'{keywords_str}'的聊天记录"
|
||||
elif participant:
|
||||
@@ -160,9 +160,7 @@ async def search_chat_history(
|
||||
if record.keywords:
|
||||
try:
|
||||
keywords_data = (
|
||||
json.loads(record.keywords)
|
||||
if isinstance(record.keywords, str)
|
||||
else record.keywords
|
||||
json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
|
||||
)
|
||||
if isinstance(keywords_data, list):
|
||||
for k in keywords_data:
|
||||
@@ -179,13 +177,12 @@ async def search_chat_history(
|
||||
keywords_str = "、".join(sorted(all_keywords_set))
|
||||
return (
|
||||
f"包含“{search_label}”的结果过多,请尝试更多关键词精确查找\n\n"
|
||||
f"有关\"{search_label}\"的关键词:\n"
|
||||
f'有关"{search_label}"的关键词:\n'
|
||||
f"{keywords_str}"
|
||||
)
|
||||
else:
|
||||
return (
|
||||
f"包含“{search_label}”的结果过多,请尝试更多关键词精确查找\n\n"
|
||||
f"有关\"{search_label}\"的关键词信息为空"
|
||||
f'包含“{search_label}”的结果过多,请尝试更多关键词精确查找\n\n有关"{search_label}"的关键词信息为空'
|
||||
)
|
||||
|
||||
# 构建结果文本,返回id、theme和keywords(最多20条)
|
||||
|
||||
@@ -22,42 +22,42 @@ def get_current_token(
|
||||
) -> str:
|
||||
"""
|
||||
获取当前请求的 token,优先从 Cookie 获取,其次从 Header 获取
|
||||
|
||||
|
||||
Args:
|
||||
request: FastAPI Request 对象
|
||||
maibot_session: Cookie 中的 token
|
||||
authorization: Authorization Header (Bearer token)
|
||||
|
||||
|
||||
Returns:
|
||||
验证通过的 token
|
||||
|
||||
|
||||
Raises:
|
||||
HTTPException: 认证失败时抛出 401 错误
|
||||
"""
|
||||
token = None
|
||||
|
||||
|
||||
# 优先从 Cookie 获取
|
||||
if maibot_session:
|
||||
token = maibot_session
|
||||
# 其次从 Header 获取(兼容旧版本)
|
||||
elif authorization and authorization.startswith("Bearer "):
|
||||
token = authorization.replace("Bearer ", "")
|
||||
|
||||
|
||||
if not token:
|
||||
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
|
||||
|
||||
|
||||
# 验证 token
|
||||
token_manager = get_token_manager()
|
||||
if not token_manager.verify_token(token):
|
||||
raise HTTPException(status_code=401, detail="Token 无效或已过期")
|
||||
|
||||
|
||||
return token
|
||||
|
||||
|
||||
def set_auth_cookie(response: Response, token: str) -> None:
|
||||
"""
|
||||
设置认证 Cookie
|
||||
|
||||
|
||||
Args:
|
||||
response: FastAPI Response 对象
|
||||
token: 要设置的 token
|
||||
@@ -77,7 +77,7 @@ def set_auth_cookie(response: Response, token: str) -> None:
|
||||
def clear_auth_cookie(response: Response) -> None:
|
||||
"""
|
||||
清除认证 Cookie
|
||||
|
||||
|
||||
Args:
|
||||
response: FastAPI Response 对象
|
||||
"""
|
||||
@@ -96,32 +96,32 @@ def verify_auth_token_from_cookie_or_header(
|
||||
) -> bool:
|
||||
"""
|
||||
验证认证 Token,支持从 Cookie 或 Header 获取
|
||||
|
||||
|
||||
Args:
|
||||
maibot_session: Cookie 中的 token
|
||||
authorization: Authorization header (Bearer token)
|
||||
|
||||
|
||||
Returns:
|
||||
验证成功返回 True
|
||||
|
||||
|
||||
Raises:
|
||||
HTTPException: 认证失败时抛出 401 错误
|
||||
"""
|
||||
token = None
|
||||
|
||||
|
||||
# 优先从 Cookie 获取
|
||||
if maibot_session:
|
||||
token = maibot_session
|
||||
# 其次从 Header 获取(兼容旧版本)
|
||||
elif authorization and authorization.startswith("Bearer "):
|
||||
token = authorization.replace("Bearer ", "")
|
||||
|
||||
|
||||
if not token:
|
||||
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
|
||||
|
||||
|
||||
# 验证 token
|
||||
token_manager = get_token_manager()
|
||||
if not token_manager.verify_token(token):
|
||||
raise HTTPException(status_code=401, detail="Token 无效或已过期")
|
||||
|
||||
|
||||
return True
|
||||
|
||||
@@ -63,14 +63,14 @@ class ChatHistoryManager:
|
||||
|
||||
def _message_to_dict(self, msg: Messages, group_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""将数据库消息转换为前端格式
|
||||
|
||||
|
||||
Args:
|
||||
msg: 数据库消息对象
|
||||
group_id: 群 ID,用于判断是否是虚拟群
|
||||
"""
|
||||
# 判断是否是机器人消息
|
||||
user_id = msg.user_id or ""
|
||||
|
||||
|
||||
# 对于虚拟群,通过比较机器人 QQ 账号来判断
|
||||
# 对于普通 WebUI 群,检查 user_id 是否以 webui_ 开头
|
||||
if group_id and group_id.startswith(VIRTUAL_GROUP_ID_PREFIX):
|
||||
@@ -414,7 +414,9 @@ async def websocket_chat(
|
||||
group_id=virtual_group_id,
|
||||
group_name=group_name or "WebUI虚拟群聊",
|
||||
)
|
||||
logger.info(f"虚拟身份模式已通过 URL 参数激活: {current_virtual_config.user_nickname} @ {current_virtual_config.platform}, group_id={virtual_group_id}")
|
||||
logger.info(
|
||||
f"虚拟身份模式已通过 URL 参数激活: {current_virtual_config.user_nickname} @ {current_virtual_config.platform}, group_id={virtual_group_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"通过 URL 参数配置虚拟身份失败: {e}")
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
""" 表情包管理 API 路由"""
|
||||
"""表情包管理 API 路由"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Header, Query, UploadFile, File, Form, Cookie
|
||||
from fastapi.responses import FileResponse, JSONResponse
|
||||
@@ -48,7 +48,7 @@ def _get_thumbnail_lock(file_hash: str) -> threading.Lock:
|
||||
def _background_generate_thumbnail(source_path: str, file_hash: str) -> None:
|
||||
"""
|
||||
后台生成缩略图(在线程池中执行)
|
||||
|
||||
|
||||
生成完成后自动从 generating 集合中移除
|
||||
"""
|
||||
try:
|
||||
@@ -74,14 +74,14 @@ def _get_thumbnail_cache_path(file_hash: str) -> Path:
|
||||
def _generate_thumbnail(source_path: str, file_hash: str) -> Path:
|
||||
"""
|
||||
生成缩略图并保存到缓存目录
|
||||
|
||||
|
||||
Args:
|
||||
source_path: 原图路径
|
||||
file_hash: 文件哈希值,用作缓存文件名
|
||||
|
||||
|
||||
Returns:
|
||||
缩略图路径
|
||||
|
||||
|
||||
Features:
|
||||
- GIF: 提取第一帧作为缩略图
|
||||
- 所有格式统一转为 WebP
|
||||
@@ -89,63 +89,63 @@ def _generate_thumbnail(source_path: str, file_hash: str) -> Path:
|
||||
"""
|
||||
_ensure_thumbnail_cache_dir()
|
||||
cache_path = _get_thumbnail_cache_path(file_hash)
|
||||
|
||||
|
||||
# 使用锁防止并发生成同一缩略图
|
||||
lock = _get_thumbnail_lock(file_hash)
|
||||
with lock:
|
||||
# 双重检查,可能在等待锁时已被其他线程生成
|
||||
if cache_path.exists():
|
||||
return cache_path
|
||||
|
||||
|
||||
try:
|
||||
with Image.open(source_path) as img:
|
||||
# GIF 处理:提取第一帧
|
||||
if hasattr(img, 'n_frames') and img.n_frames > 1:
|
||||
if hasattr(img, "n_frames") and img.n_frames > 1:
|
||||
img.seek(0) # 确保在第一帧
|
||||
|
||||
|
||||
# 转换为 RGB/RGBA(WebP 支持透明度)
|
||||
if img.mode in ('P', 'PA'):
|
||||
if img.mode in ("P", "PA"):
|
||||
# 调色板模式转换为 RGBA 以保留透明度
|
||||
img = img.convert('RGBA')
|
||||
elif img.mode == 'LA':
|
||||
img = img.convert('RGBA')
|
||||
elif img.mode not in ('RGB', 'RGBA'):
|
||||
img = img.convert('RGB')
|
||||
|
||||
img = img.convert("RGBA")
|
||||
elif img.mode == "LA":
|
||||
img = img.convert("RGBA")
|
||||
elif img.mode not in ("RGB", "RGBA"):
|
||||
img = img.convert("RGB")
|
||||
|
||||
# 创建缩略图(保持宽高比)
|
||||
img.thumbnail(THUMBNAIL_SIZE, Image.Resampling.LANCZOS)
|
||||
|
||||
|
||||
# 保存为 WebP 格式
|
||||
img.save(cache_path, 'WEBP', quality=THUMBNAIL_QUALITY, method=6)
|
||||
|
||||
img.save(cache_path, "WEBP", quality=THUMBNAIL_QUALITY, method=6)
|
||||
|
||||
logger.debug(f"生成缩略图: {file_hash} -> {cache_path}")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"生成缩略图失败 {file_hash}: {e},将返回原图")
|
||||
# 生成失败时不创建缓存文件,下次会重试
|
||||
raise
|
||||
|
||||
|
||||
return cache_path
|
||||
|
||||
|
||||
def cleanup_orphaned_thumbnails() -> tuple[int, int]:
|
||||
"""
|
||||
清理孤立的缩略图缓存(原图已不存在的缩略图)
|
||||
|
||||
|
||||
Returns:
|
||||
(清理数量, 保留数量)
|
||||
"""
|
||||
if not THUMBNAIL_CACHE_DIR.exists():
|
||||
return 0, 0
|
||||
|
||||
|
||||
# 获取所有表情包的哈希值
|
||||
valid_hashes = set()
|
||||
for emoji in Emoji.select(Emoji.emoji_hash):
|
||||
valid_hashes.add(emoji.emoji_hash)
|
||||
|
||||
|
||||
cleaned = 0
|
||||
kept = 0
|
||||
|
||||
|
||||
for cache_file in THUMBNAIL_CACHE_DIR.glob("*.webp"):
|
||||
file_hash = cache_file.stem
|
||||
if file_hash not in valid_hashes:
|
||||
@@ -157,12 +157,13 @@ def cleanup_orphaned_thumbnails() -> tuple[int, int]:
|
||||
logger.warning(f"清理缩略图失败 {cache_file.name}: {e}")
|
||||
else:
|
||||
kept += 1
|
||||
|
||||
|
||||
if cleaned > 0:
|
||||
logger.info(f"清理孤立缩略图: 删除 {cleaned} 个,保留 {kept} 个")
|
||||
|
||||
|
||||
return cleaned, kept
|
||||
|
||||
|
||||
# 模块级别的类型别名(解决 B008 ruff 错误)
|
||||
EmojiFile = Annotated[UploadFile, File(description="表情包图片文件")]
|
||||
EmojiFiles = Annotated[List[UploadFile], File(description="多个表情包图片文件")]
|
||||
@@ -365,7 +366,9 @@ async def get_emoji_list(
|
||||
|
||||
|
||||
@router.get("/{emoji_id}", response_model=EmojiDetailResponse)
|
||||
async def get_emoji_detail(emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
||||
async def get_emoji_detail(
|
||||
emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||
):
|
||||
"""
|
||||
获取表情包详细信息
|
||||
|
||||
@@ -394,7 +397,12 @@ async def get_emoji_detail(emoji_id: int, maibot_session: Optional[str] = Cookie
|
||||
|
||||
|
||||
@router.patch("/{emoji_id}", response_model=EmojiUpdateResponse)
|
||||
async def update_emoji(emoji_id: int, request: EmojiUpdateRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
||||
async def update_emoji(
|
||||
emoji_id: int,
|
||||
request: EmojiUpdateRequest,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
增量更新表情包(只更新提供的字段)
|
||||
|
||||
@@ -446,7 +454,9 @@ async def update_emoji(emoji_id: int, request: EmojiUpdateRequest, maibot_sessio
|
||||
|
||||
|
||||
@router.delete("/{emoji_id}", response_model=EmojiDeleteResponse)
|
||||
async def delete_emoji(emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
||||
async def delete_emoji(
|
||||
emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||
):
|
||||
"""
|
||||
删除表情包
|
||||
|
||||
@@ -538,7 +548,9 @@ async def get_emoji_stats(maibot_session: Optional[str] = Cookie(None), authoriz
|
||||
|
||||
|
||||
@router.post("/{emoji_id}/register", response_model=EmojiUpdateResponse)
|
||||
async def register_emoji(emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
||||
async def register_emoji(
|
||||
emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||
):
|
||||
"""
|
||||
注册表情包(快捷操作)
|
||||
|
||||
@@ -578,7 +590,9 @@ async def register_emoji(emoji_id: int, maibot_session: Optional[str] = Cookie(N
|
||||
|
||||
|
||||
@router.post("/{emoji_id}/ban", response_model=EmojiUpdateResponse)
|
||||
async def ban_emoji(emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
||||
async def ban_emoji(
|
||||
emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||
):
|
||||
"""
|
||||
禁用表情包(快捷操作)
|
||||
|
||||
@@ -633,7 +647,7 @@ async def get_emoji_thumbnail(
|
||||
|
||||
Returns:
|
||||
表情包缩略图(WebP 格式)或原图
|
||||
|
||||
|
||||
Features:
|
||||
- 懒加载:首次请求时生成缩略图
|
||||
- 缓存:后续请求直接返回缓存
|
||||
@@ -643,7 +657,7 @@ async def get_emoji_thumbnail(
|
||||
try:
|
||||
token_manager = get_token_manager()
|
||||
is_valid = False
|
||||
|
||||
|
||||
# 1. 优先使用 Cookie
|
||||
if maibot_session and token_manager.verify_token(maibot_session):
|
||||
is_valid = True
|
||||
@@ -655,7 +669,7 @@ async def get_emoji_thumbnail(
|
||||
auth_token = authorization.replace("Bearer ", "")
|
||||
if token_manager.verify_token(auth_token):
|
||||
is_valid = True
|
||||
|
||||
|
||||
if not is_valid:
|
||||
raise HTTPException(status_code=401, detail="Token 无效或已过期")
|
||||
|
||||
@@ -680,35 +694,27 @@ async def get_emoji_thumbnail(
|
||||
}
|
||||
media_type = mime_types.get(emoji.format.lower(), "application/octet-stream")
|
||||
return FileResponse(
|
||||
path=emoji.full_path,
|
||||
media_type=media_type,
|
||||
filename=f"{emoji.emoji_hash}.{emoji.format}"
|
||||
path=emoji.full_path, media_type=media_type, filename=f"{emoji.emoji_hash}.{emoji.format}"
|
||||
)
|
||||
|
||||
# 尝试获取或生成缩略图
|
||||
cache_path = _get_thumbnail_cache_path(emoji.emoji_hash)
|
||||
|
||||
|
||||
# 检查缓存是否存在
|
||||
if cache_path.exists():
|
||||
# 缓存命中,直接返回
|
||||
return FileResponse(
|
||||
path=str(cache_path),
|
||||
media_type="image/webp",
|
||||
filename=f"{emoji.emoji_hash}_thumb.webp"
|
||||
path=str(cache_path), media_type="image/webp", filename=f"{emoji.emoji_hash}_thumb.webp"
|
||||
)
|
||||
|
||||
|
||||
# 缓存未命中,触发后台生成并返回 202
|
||||
with _generating_lock:
|
||||
if emoji.emoji_hash not in _generating_thumbnails:
|
||||
# 标记为正在生成
|
||||
_generating_thumbnails.add(emoji.emoji_hash)
|
||||
# 提交到线程池后台生成
|
||||
_thumbnail_executor.submit(
|
||||
_background_generate_thumbnail,
|
||||
emoji.full_path,
|
||||
emoji.emoji_hash
|
||||
)
|
||||
|
||||
_thumbnail_executor.submit(_background_generate_thumbnail, emoji.full_path, emoji.emoji_hash)
|
||||
|
||||
# 返回 202 Accepted,告诉前端缩略图正在生成中
|
||||
return JSONResponse(
|
||||
status_code=202,
|
||||
@@ -719,7 +725,7 @@ async def get_emoji_thumbnail(
|
||||
},
|
||||
headers={
|
||||
"Retry-After": "1", # 建议 1 秒后重试
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
@@ -730,7 +736,11 @@ async def get_emoji_thumbnail(
|
||||
|
||||
|
||||
@router.post("/batch/delete", response_model=BatchDeleteResponse)
|
||||
async def batch_delete_emojis(request: BatchDeleteRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
||||
async def batch_delete_emojis(
|
||||
request: BatchDeleteRequest,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
批量删除表情包
|
||||
|
||||
@@ -1079,7 +1089,7 @@ async def batch_upload_emoji(
|
||||
|
||||
class ThumbnailCacheStatsResponse(BaseModel):
|
||||
"""缩略图缓存统计响应"""
|
||||
|
||||
|
||||
success: bool
|
||||
cache_dir: str
|
||||
total_count: int
|
||||
@@ -1090,7 +1100,7 @@ class ThumbnailCacheStatsResponse(BaseModel):
|
||||
|
||||
class ThumbnailCleanupResponse(BaseModel):
|
||||
"""缩略图清理响应"""
|
||||
|
||||
|
||||
success: bool
|
||||
message: str
|
||||
cleaned_count: int
|
||||
@@ -1099,7 +1109,7 @@ class ThumbnailCleanupResponse(BaseModel):
|
||||
|
||||
class ThumbnailPreheatResponse(BaseModel):
|
||||
"""缩略图预热响应"""
|
||||
|
||||
|
||||
success: bool
|
||||
message: str
|
||||
generated_count: int
|
||||
@@ -1114,27 +1124,27 @@ async def get_thumbnail_cache_stats(
|
||||
):
|
||||
"""
|
||||
获取缩略图缓存统计信息
|
||||
|
||||
|
||||
Returns:
|
||||
缓存目录、缓存数量、总大小、覆盖率等统计信息
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
|
||||
_ensure_thumbnail_cache_dir()
|
||||
|
||||
|
||||
# 统计缓存文件
|
||||
cache_files = list(THUMBNAIL_CACHE_DIR.glob("*.webp"))
|
||||
total_count = len(cache_files)
|
||||
total_size = sum(f.stat().st_size for f in cache_files)
|
||||
total_size_mb = round(total_size / (1024 * 1024), 2)
|
||||
|
||||
|
||||
# 统计表情包总数
|
||||
emoji_count = Emoji.select().count()
|
||||
|
||||
|
||||
# 计算覆盖率
|
||||
coverage_percent = round((total_count / emoji_count * 100) if emoji_count > 0 else 0, 1)
|
||||
|
||||
|
||||
return ThumbnailCacheStatsResponse(
|
||||
success=True,
|
||||
cache_dir=str(THUMBNAIL_CACHE_DIR.absolute()),
|
||||
@@ -1143,7 +1153,7 @@ async def get_thumbnail_cache_stats(
|
||||
emoji_count=emoji_count,
|
||||
coverage_percent=coverage_percent,
|
||||
)
|
||||
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -1158,22 +1168,22 @@ async def cleanup_thumbnail_cache(
|
||||
):
|
||||
"""
|
||||
清理孤立的缩略图缓存(原图已删除的表情包对应的缩略图)
|
||||
|
||||
|
||||
Returns:
|
||||
清理结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
|
||||
cleaned, kept = cleanup_orphaned_thumbnails()
|
||||
|
||||
|
||||
return ThumbnailCleanupResponse(
|
||||
success=True,
|
||||
message=f"清理完成:删除 {cleaned} 个孤立缓存,保留 {kept} 个有效缓存",
|
||||
cleaned_count=cleaned,
|
||||
kept_count=kept,
|
||||
)
|
||||
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -1189,20 +1199,20 @@ async def preheat_thumbnail_cache(
|
||||
):
|
||||
"""
|
||||
预热缩略图缓存(提前生成未缓存的缩略图)
|
||||
|
||||
|
||||
优先处理使用次数高的表情包
|
||||
|
||||
|
||||
Args:
|
||||
limit: 最多预热数量 (1-1000)
|
||||
|
||||
|
||||
Returns:
|
||||
预热结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
|
||||
_ensure_thumbnail_cache_dir()
|
||||
|
||||
|
||||
# 获取使用次数最高的表情包(未缓存的优先)
|
||||
emojis = (
|
||||
Emoji.select()
|
||||
@@ -1210,41 +1220,36 @@ async def preheat_thumbnail_cache(
|
||||
.order_by(Emoji.usage_count.desc())
|
||||
.limit(limit * 2) # 多查一些,因为有些可能已缓存
|
||||
)
|
||||
|
||||
|
||||
generated = 0
|
||||
skipped = 0
|
||||
failed = 0
|
||||
|
||||
|
||||
for emoji in emojis:
|
||||
if generated >= limit:
|
||||
break
|
||||
|
||||
|
||||
cache_path = _get_thumbnail_cache_path(emoji.emoji_hash)
|
||||
|
||||
|
||||
# 已缓存,跳过
|
||||
if cache_path.exists():
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
|
||||
# 原文件不存在,跳过
|
||||
if not os.path.exists(emoji.full_path):
|
||||
failed += 1
|
||||
continue
|
||||
|
||||
|
||||
try:
|
||||
# 使用线程池异步生成缩略图,避免阻塞事件循环
|
||||
loop = asyncio.get_event_loop()
|
||||
await loop.run_in_executor(
|
||||
_thumbnail_executor,
|
||||
_generate_thumbnail,
|
||||
emoji.full_path,
|
||||
emoji.emoji_hash
|
||||
)
|
||||
await loop.run_in_executor(_thumbnail_executor, _generate_thumbnail, emoji.full_path, emoji.emoji_hash)
|
||||
generated += 1
|
||||
except Exception as e:
|
||||
logger.warning(f"预热缩略图失败 {emoji.emoji_hash}: {e}")
|
||||
failed += 1
|
||||
|
||||
|
||||
return ThumbnailPreheatResponse(
|
||||
success=True,
|
||||
message=f"预热完成:生成 {generated} 个,跳过 {skipped} 个已缓存,失败 {failed} 个",
|
||||
@@ -1252,7 +1257,7 @@ async def preheat_thumbnail_cache(
|
||||
skipped_count=skipped,
|
||||
failed_count=failed,
|
||||
)
|
||||
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -1267,13 +1272,13 @@ async def clear_all_thumbnail_cache(
|
||||
):
|
||||
"""
|
||||
清空所有缩略图缓存(下次访问时会重新生成)
|
||||
|
||||
|
||||
Returns:
|
||||
清理结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
|
||||
if not THUMBNAIL_CACHE_DIR.exists():
|
||||
return ThumbnailCleanupResponse(
|
||||
success=True,
|
||||
@@ -1281,7 +1286,7 @@ async def clear_all_thumbnail_cache(
|
||||
cleaned_count=0,
|
||||
kept_count=0,
|
||||
)
|
||||
|
||||
|
||||
cleaned = 0
|
||||
for cache_file in THUMBNAIL_CACHE_DIR.glob("*.webp"):
|
||||
try:
|
||||
@@ -1289,16 +1294,16 @@ async def clear_all_thumbnail_cache(
|
||||
cleaned += 1
|
||||
except Exception as e:
|
||||
logger.warning(f"删除缓存文件失败 {cache_file.name}: {e}")
|
||||
|
||||
|
||||
logger.info(f"已清空缩略图缓存: 删除 {cleaned} 个文件")
|
||||
|
||||
|
||||
return ThumbnailCleanupResponse(
|
||||
success=True,
|
||||
message=f"已清空所有缩略图缓存:删除 {cleaned} 个文件",
|
||||
cleaned_count=cleaned,
|
||||
kept_count=0,
|
||||
)
|
||||
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
|
||||
@@ -256,7 +256,9 @@ async def get_expression_list(
|
||||
|
||||
|
||||
@router.get("/{expression_id}", response_model=ExpressionDetailResponse)
|
||||
async def get_expression_detail(expression_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
||||
async def get_expression_detail(
|
||||
expression_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||
):
|
||||
"""
|
||||
获取表达方式详细信息
|
||||
|
||||
@@ -285,7 +287,11 @@ async def get_expression_detail(expression_id: int, maibot_session: Optional[str
|
||||
|
||||
|
||||
@router.post("/", response_model=ExpressionCreateResponse)
|
||||
async def create_expression(request: ExpressionCreateRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
||||
async def create_expression(
|
||||
request: ExpressionCreateRequest,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
创建新的表达方式
|
||||
|
||||
@@ -326,7 +332,10 @@ async def create_expression(request: ExpressionCreateRequest, maibot_session: Op
|
||||
|
||||
@router.patch("/{expression_id}", response_model=ExpressionUpdateResponse)
|
||||
async def update_expression(
|
||||
expression_id: int, request: ExpressionUpdateRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||
expression_id: int,
|
||||
request: ExpressionUpdateRequest,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
增量更新表达方式(只更新提供的字段)
|
||||
@@ -376,7 +385,9 @@ async def update_expression(
|
||||
|
||||
|
||||
@router.delete("/{expression_id}", response_model=ExpressionDeleteResponse)
|
||||
async def delete_expression(expression_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
||||
async def delete_expression(
|
||||
expression_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||
):
|
||||
"""
|
||||
删除表达方式
|
||||
|
||||
@@ -419,7 +430,11 @@ class BatchDeleteRequest(BaseModel):
|
||||
|
||||
|
||||
@router.post("/batch/delete", response_model=ExpressionDeleteResponse)
|
||||
async def batch_delete_expressions(request: BatchDeleteRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
||||
async def batch_delete_expressions(
|
||||
request: BatchDeleteRequest,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
批量删除表达方式
|
||||
|
||||
@@ -460,7 +475,9 @@ async def batch_delete_expressions(request: BatchDeleteRequest, maibot_session:
|
||||
|
||||
|
||||
@router.get("/stats/summary")
|
||||
async def get_expression_stats(maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
||||
async def get_expression_stats(
|
||||
maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||
):
|
||||
"""
|
||||
获取表达方式统计数据
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ def parse_chat_id_to_stream_ids(chat_id_str: str) -> List[str]:
|
||||
"""
|
||||
if not chat_id_str:
|
||||
return []
|
||||
|
||||
|
||||
try:
|
||||
# 尝试解析为 JSON
|
||||
parsed = json.loads(chat_id_str)
|
||||
@@ -49,10 +49,10 @@ def get_display_name_for_chat_id(chat_id_str: str) -> str:
|
||||
尝试解析 JSON 并查询 ChatStreams 表获取群聊名称
|
||||
"""
|
||||
stream_ids = parse_chat_id_to_stream_ids(chat_id_str)
|
||||
|
||||
|
||||
if not stream_ids:
|
||||
return chat_id_str
|
||||
|
||||
|
||||
# 查询所有 stream_id 对应的名称
|
||||
names = []
|
||||
for stream_id in stream_ids:
|
||||
@@ -62,7 +62,7 @@ def get_display_name_for_chat_id(chat_id_str: str) -> str:
|
||||
else:
|
||||
# 如果没找到,显示截断的 stream_id
|
||||
names.append(stream_id[:8] + "..." if len(stream_id) > 8 else stream_id)
|
||||
|
||||
|
||||
return ", ".join(names) if names else chat_id_str
|
||||
|
||||
|
||||
@@ -187,7 +187,7 @@ def jargon_to_dict(jargon: Jargon) -> dict:
|
||||
chat_name = get_display_name_for_chat_id(jargon.chat_id) if jargon.chat_id else None
|
||||
stream_ids = parse_chat_id_to_stream_ids(jargon.chat_id) if jargon.chat_id else []
|
||||
stream_id = stream_ids[0] if stream_ids else None
|
||||
|
||||
|
||||
return {
|
||||
"id": jargon.id,
|
||||
"content": jargon.content,
|
||||
@@ -277,17 +277,13 @@ async def get_chat_list():
|
||||
"""获取所有有黑话记录的聊天列表"""
|
||||
try:
|
||||
# 获取所有不同的 chat_id
|
||||
chat_ids = (
|
||||
Jargon.select(Jargon.chat_id)
|
||||
.distinct()
|
||||
.where(Jargon.chat_id.is_null(False))
|
||||
)
|
||||
chat_ids = Jargon.select(Jargon.chat_id).distinct().where(Jargon.chat_id.is_null(False))
|
||||
|
||||
chat_id_list = [j.chat_id for j in chat_ids if j.chat_id]
|
||||
|
||||
# 用于按 stream_id 去重
|
||||
seen_stream_ids: set[str] = set()
|
||||
|
||||
|
||||
for chat_id in chat_id_list:
|
||||
stream_ids = parse_chat_id_to_stream_ids(chat_id)
|
||||
if stream_ids:
|
||||
@@ -346,12 +342,7 @@ async def get_jargon_stats():
|
||||
complete_count = Jargon.select().where(Jargon.is_complete).count()
|
||||
|
||||
# 关联的聊天数量
|
||||
chat_count = (
|
||||
Jargon.select(Jargon.chat_id)
|
||||
.distinct()
|
||||
.where(Jargon.chat_id.is_null(False))
|
||||
.count()
|
||||
)
|
||||
chat_count = Jargon.select(Jargon.chat_id).distinct().where(Jargon.chat_id.is_null(False)).count()
|
||||
|
||||
# 按聊天统计 TOP 5
|
||||
top_chats = (
|
||||
@@ -403,9 +394,7 @@ async def create_jargon(request: JargonCreateRequest):
|
||||
"""创建黑话"""
|
||||
try:
|
||||
# 检查是否已存在相同内容的黑话
|
||||
existing = Jargon.get_or_none(
|
||||
(Jargon.content == request.content) & (Jargon.chat_id == request.chat_id)
|
||||
)
|
||||
existing = Jargon.get_or_none((Jargon.content == request.content) & (Jargon.chat_id == request.chat_id))
|
||||
if existing:
|
||||
raise HTTPException(status_code=400, detail="该聊天中已存在相同内容的黑话")
|
||||
|
||||
@@ -527,11 +516,7 @@ async def batch_set_jargon_status(
|
||||
if not ids:
|
||||
raise HTTPException(status_code=400, detail="ID列表不能为空")
|
||||
|
||||
updated_count = (
|
||||
Jargon.update(is_jargon=is_jargon)
|
||||
.where(Jargon.id.in_(ids))
|
||||
.execute()
|
||||
)
|
||||
updated_count = Jargon.update(is_jargon=is_jargon).where(Jargon.id.in_(ids)).execute()
|
||||
|
||||
logger.info(f"批量更新黑话状态成功: 更新了 {updated_count} 条记录,is_jargon={is_jargon}")
|
||||
|
||||
|
||||
@@ -200,7 +200,9 @@ async def get_person_list(
|
||||
|
||||
|
||||
@router.get("/{person_id}", response_model=PersonDetailResponse)
|
||||
async def get_person_detail(person_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
||||
async def get_person_detail(
|
||||
person_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||
):
|
||||
"""
|
||||
获取人物详细信息
|
||||
|
||||
@@ -229,7 +231,12 @@ async def get_person_detail(person_id: str, maibot_session: Optional[str] = Cook
|
||||
|
||||
|
||||
@router.patch("/{person_id}", response_model=PersonUpdateResponse)
|
||||
async def update_person(person_id: str, request: PersonUpdateRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
||||
async def update_person(
|
||||
person_id: str,
|
||||
request: PersonUpdateRequest,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
增量更新人物信息(只更新提供的字段)
|
||||
|
||||
@@ -278,7 +285,9 @@ async def update_person(person_id: str, request: PersonUpdateRequest, maibot_ses
|
||||
|
||||
|
||||
@router.delete("/{person_id}", response_model=PersonDeleteResponse)
|
||||
async def delete_person(person_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
||||
async def delete_person(
|
||||
person_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||
):
|
||||
"""
|
||||
删除人物信息
|
||||
|
||||
@@ -348,7 +357,11 @@ async def get_person_stats(maibot_session: Optional[str] = Cookie(None), authori
|
||||
|
||||
|
||||
@router.post("/batch/delete", response_model=BatchDeleteResponse)
|
||||
async def batch_delete_persons(request: BatchDeleteRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
||||
async def batch_delete_persons(
|
||||
request: BatchDeleteRequest,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
批量删除人物信息
|
||||
|
||||
|
||||
@@ -125,6 +125,7 @@ def coerce_types(schema_part: Dict[str, Any], config_part: Dict[str, Any]) -> No
|
||||
"""
|
||||
根据 schema 将配置中的类型纠正(目前只纠正 list-from-str)。
|
||||
"""
|
||||
|
||||
def _is_list_type(tp: Any) -> bool:
|
||||
origin = get_origin(tp)
|
||||
return tp is list or origin is list
|
||||
@@ -313,7 +314,9 @@ async def check_git_status() -> GitStatusResponse:
|
||||
|
||||
|
||||
@router.get("/mirrors", response_model=AvailableMirrorsResponse)
|
||||
async def get_available_mirrors(maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> AvailableMirrorsResponse:
|
||||
async def get_available_mirrors(
|
||||
maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||
) -> AvailableMirrorsResponse:
|
||||
"""
|
||||
获取所有可用的镜像源配置
|
||||
"""
|
||||
@@ -343,7 +346,9 @@ async def get_available_mirrors(maibot_session: Optional[str] = Cookie(None), au
|
||||
|
||||
|
||||
@router.post("/mirrors", response_model=MirrorConfigResponse)
|
||||
async def add_mirror(request: AddMirrorRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> MirrorConfigResponse:
|
||||
async def add_mirror(
|
||||
request: AddMirrorRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||
) -> MirrorConfigResponse:
|
||||
"""
|
||||
添加新的镜像源
|
||||
"""
|
||||
@@ -383,7 +388,10 @@ async def add_mirror(request: AddMirrorRequest, maibot_session: Optional[str] =
|
||||
|
||||
@router.put("/mirrors/{mirror_id}", response_model=MirrorConfigResponse)
|
||||
async def update_mirror(
|
||||
mirror_id: str, request: UpdateMirrorRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||
mirror_id: str,
|
||||
request: UpdateMirrorRequest,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
) -> MirrorConfigResponse:
|
||||
"""
|
||||
更新镜像源配置
|
||||
@@ -426,7 +434,9 @@ async def update_mirror(
|
||||
|
||||
|
||||
@router.delete("/mirrors/{mirror_id}")
|
||||
async def delete_mirror(mirror_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
|
||||
async def delete_mirror(
|
||||
mirror_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
删除镜像源
|
||||
"""
|
||||
@@ -449,7 +459,9 @@ async def delete_mirror(mirror_id: str, maibot_session: Optional[str] = Cookie(N
|
||||
|
||||
@router.post("/fetch-raw", response_model=FetchRawFileResponse)
|
||||
async def fetch_raw_file(
|
||||
request: FetchRawFileRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||
request: FetchRawFileRequest,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
) -> FetchRawFileResponse:
|
||||
"""
|
||||
获取 GitHub 仓库的 Raw 文件内容
|
||||
@@ -534,7 +546,9 @@ async def fetch_raw_file(
|
||||
|
||||
@router.post("/clone", response_model=CloneRepositoryResponse)
|
||||
async def clone_repository(
|
||||
request: CloneRepositoryRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||
request: CloneRepositoryRequest,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
) -> CloneRepositoryResponse:
|
||||
"""
|
||||
克隆 GitHub 仓库到本地
|
||||
@@ -574,7 +588,11 @@ async def clone_repository(
|
||||
|
||||
|
||||
@router.post("/install")
|
||||
async def install_plugin(request: InstallPluginRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
|
||||
async def install_plugin(
|
||||
request: InstallPluginRequest,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
安装插件
|
||||
|
||||
@@ -778,7 +796,9 @@ async def install_plugin(request: InstallPluginRequest, maibot_session: Optional
|
||||
|
||||
@router.post("/uninstall")
|
||||
async def uninstall_plugin(
|
||||
request: UninstallPluginRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||
request: UninstallPluginRequest,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
卸载插件
|
||||
@@ -913,7 +933,11 @@ async def uninstall_plugin(
|
||||
|
||||
|
||||
@router.post("/update")
|
||||
async def update_plugin(request: UpdatePluginRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
|
||||
async def update_plugin(
|
||||
request: UpdatePluginRequest,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
更新插件
|
||||
|
||||
@@ -1132,7 +1156,9 @@ async def update_plugin(request: UpdatePluginRequest, maibot_session: Optional[s
|
||||
|
||||
|
||||
@router.get("/installed")
|
||||
async def get_installed_plugins(maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
|
||||
async def get_installed_plugins(
|
||||
maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
获取已安装的插件列表
|
||||
|
||||
@@ -1272,7 +1298,9 @@ class UpdatePluginConfigRequest(BaseModel):
|
||||
|
||||
|
||||
@router.get("/config/{plugin_id}/schema")
|
||||
async def get_plugin_config_schema(plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
|
||||
async def get_plugin_config_schema(
|
||||
plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
获取插件配置 Schema
|
||||
|
||||
@@ -1405,7 +1433,9 @@ async def get_plugin_config_schema(plugin_id: str, maibot_session: Optional[str]
|
||||
|
||||
|
||||
@router.get("/config/{plugin_id}")
|
||||
async def get_plugin_config(plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
|
||||
async def get_plugin_config(
|
||||
plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
获取插件当前配置值
|
||||
|
||||
@@ -1461,7 +1491,10 @@ async def get_plugin_config(plugin_id: str, maibot_session: Optional[str] = Cook
|
||||
|
||||
@router.put("/config/{plugin_id}")
|
||||
async def update_plugin_config(
|
||||
plugin_id: str, request: UpdatePluginConfigRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||
plugin_id: str,
|
||||
request: UpdatePluginConfigRequest,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
更新插件配置
|
||||
@@ -1532,7 +1565,9 @@ async def update_plugin_config(
|
||||
|
||||
|
||||
@router.post("/config/{plugin_id}/reset")
|
||||
async def reset_plugin_config(plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
|
||||
async def reset_plugin_config(
|
||||
plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
重置插件配置为默认值
|
||||
|
||||
@@ -1592,7 +1627,9 @@ async def reset_plugin_config(plugin_id: str, maibot_session: Optional[str] = Co
|
||||
|
||||
|
||||
@router.post("/config/{plugin_id}/toggle")
|
||||
async def toggle_plugin(plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
|
||||
async def toggle_plugin(
|
||||
plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
切换插件启用状态
|
||||
|
||||
|
||||
@@ -139,10 +139,10 @@ async def verify_token(request: TokenVerifyRequest, response: Response):
|
||||
async def logout(response: Response):
|
||||
"""
|
||||
登出并清除认证 Cookie
|
||||
|
||||
|
||||
Args:
|
||||
response: FastAPI Response 对象
|
||||
|
||||
|
||||
Returns:
|
||||
登出结果
|
||||
"""
|
||||
@@ -158,23 +158,23 @@ async def check_auth_status(
|
||||
):
|
||||
"""
|
||||
检查当前认证状态(用于前端判断是否已登录)
|
||||
|
||||
|
||||
Returns:
|
||||
认证状态
|
||||
"""
|
||||
try:
|
||||
token = None
|
||||
|
||||
|
||||
# 优先从 Cookie 获取
|
||||
if maibot_session:
|
||||
token = maibot_session
|
||||
# 其次从 Header 获取
|
||||
elif authorization and authorization.startswith("Bearer "):
|
||||
token = authorization.replace("Bearer ", "")
|
||||
|
||||
|
||||
if not token:
|
||||
return {"authenticated": False}
|
||||
|
||||
|
||||
token_manager = get_token_manager()
|
||||
if token_manager.verify_token(token):
|
||||
return {"authenticated": True}
|
||||
@@ -211,7 +211,7 @@ async def update_token(
|
||||
current_token = maibot_session
|
||||
elif authorization and authorization.startswith("Bearer "):
|
||||
current_token = authorization.replace("Bearer ", "")
|
||||
|
||||
|
||||
if not current_token:
|
||||
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
|
||||
|
||||
@@ -222,7 +222,7 @@ async def update_token(
|
||||
|
||||
# 更新 token
|
||||
success, message = token_manager.update_token(request.new_token)
|
||||
|
||||
|
||||
# 如果更新成功,清除 Cookie,要求用户重新登录
|
||||
if success:
|
||||
clear_auth_cookie(response)
|
||||
@@ -263,7 +263,7 @@ async def regenerate_token(
|
||||
|
||||
if not current_token:
|
||||
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
|
||||
|
||||
|
||||
token_manager = get_token_manager()
|
||||
|
||||
if not token_manager.verify_token(current_token):
|
||||
@@ -271,7 +271,7 @@ async def regenerate_token(
|
||||
|
||||
# 重新生成 token
|
||||
new_token = token_manager.regenerate_token()
|
||||
|
||||
|
||||
# 清除 Cookie,要求用户重新登录
|
||||
clear_auth_cookie(response)
|
||||
|
||||
@@ -306,7 +306,7 @@ async def get_setup_status(
|
||||
current_token = maibot_session
|
||||
elif authorization and authorization.startswith("Bearer "):
|
||||
current_token = authorization.replace("Bearer ", "")
|
||||
|
||||
|
||||
if not current_token:
|
||||
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
|
||||
|
||||
@@ -349,7 +349,7 @@ async def complete_setup(
|
||||
current_token = maibot_session
|
||||
elif authorization and authorization.startswith("Bearer "):
|
||||
current_token = authorization.replace("Bearer ", "")
|
||||
|
||||
|
||||
if not current_token:
|
||||
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
|
||||
|
||||
@@ -392,7 +392,7 @@ async def reset_setup(
|
||||
current_token = maibot_session
|
||||
elif authorization and authorization.startswith("Bearer "):
|
||||
current_token = authorization.replace("Bearer ", "")
|
||||
|
||||
|
||||
if not current_token:
|
||||
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
|
||||
|
||||
|
||||
@@ -166,22 +166,22 @@ class TokenManager:
|
||||
str: 新生成的 token
|
||||
"""
|
||||
logger.info("正在重新生成 WebUI Token...")
|
||||
|
||||
|
||||
# 生成新的 64 位十六进制字符串
|
||||
new_token = secrets.token_hex(32)
|
||||
|
||||
|
||||
# 加载现有配置,保留 first_setup_completed 状态
|
||||
config = self._load_config()
|
||||
old_token = config.get("access_token", "")[:8] if config.get("access_token") else "无"
|
||||
first_setup_completed = config.get("first_setup_completed", True) # 默认为 True,表示已完成配置
|
||||
|
||||
|
||||
config["access_token"] = new_token
|
||||
config["updated_at"] = self._get_current_timestamp()
|
||||
config["first_setup_completed"] = first_setup_completed # 保留原来的状态
|
||||
|
||||
|
||||
self._save_config(config)
|
||||
logger.info(f"WebUI Token 已重新生成: {old_token}... -> {new_token[:8]}...")
|
||||
|
||||
|
||||
return new_token
|
||||
|
||||
def _validate_token_format(self, token: str) -> bool:
|
||||
|
||||
@@ -110,8 +110,10 @@ class WebUIServer:
|
||||
from src.webui.routes import router as webui_router
|
||||
from src.webui.logs_ws import router as logs_router
|
||||
from src.webui.knowledge_routes import router as knowledge_router
|
||||
|
||||
# 导入本地聊天室路由
|
||||
from src.webui.chat_routes import router as chat_router
|
||||
|
||||
# 注册路由
|
||||
self.app.include_router(webui_router)
|
||||
self.app.include_router(logs_router)
|
||||
@@ -166,6 +168,7 @@ class WebUIServer:
|
||||
def _check_port_available(self) -> bool:
|
||||
"""检查端口是否可用"""
|
||||
import socket
|
||||
|
||||
try:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.settimeout(1)
|
||||
|
||||
Reference in New Issue
Block a user