Ruff format

This commit is contained in:
墨梓柒
2025-12-13 17:14:09 +08:00
parent ef377bb0cd
commit e680a4d1f5
60 changed files with 1546 additions and 1532 deletions

View File

@@ -13,7 +13,12 @@ from src.chat.utils.chat_message_builder import (
)
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.message_receive.chat_stream import get_chat_manager
from src.bw_learner.learner_utils import filter_message_content, is_bot_message, build_context_paragraph, contains_bot_self_name
from src.bw_learner.learner_utils import (
filter_message_content,
is_bot_message,
build_context_paragraph,
contains_bot_self_name,
)
from src.bw_learner.jargon_miner import miner_manager
from json_repair import repair_json
@@ -77,8 +82,6 @@ def init_prompt() -> None:
Prompt(learn_style_prompt, "learn_style_prompt")
class ExpressionLearner:
def __init__(self, chat_id: str) -> None:
self.express_learn_model: LLMRequest = LLMRequest(
@@ -95,12 +98,12 @@ class ExpressionLearner:
self._learning_lock = asyncio.Lock()
async def learn_and_store(
self,
self,
messages: List[Any],
) -> List[Tuple[str, str, str]]:
"""
学习并存储表达方式
Args:
messages: 外部传入的消息列表(必需)
num: 学习数量
@@ -108,7 +111,7 @@ class ExpressionLearner:
"""
if not messages:
return None
random_msg = messages
# 学习用(开启行编号,便于溯源)
@@ -134,26 +137,26 @@ class ExpressionLearner:
jargon_entries: List[Tuple[str, str]] # (content, source_id)
expressions, jargon_entries = self.parse_expression_response(response)
expressions = self._filter_self_reference_styles(expressions)
# 检查表达方式数量如果超过10个则放弃本次表达学习
if len(expressions) > 10:
logger.info(f"表达方式提取数量超过10个实际{len(expressions)}个),放弃本次表达学习")
expressions = []
# 检查黑话数量如果超过30个则放弃本次黑话学习
if len(jargon_entries) > 30:
logger.info(f"黑话提取数量超过30个实际{len(jargon_entries)}个),放弃本次黑话学习")
jargon_entries = []
# 处理黑话条目,路由到 jargon_miner即使没有表达方式也要处理黑话
if jargon_entries:
await self._process_jargon_entries(jargon_entries, random_msg)
# 如果没有表达方式,直接返回
if not expressions:
logger.info("过滤后没有可用的表达方式style 与机器人名称重复)")
return []
logger.info(f"学习的prompt: {prompt}")
logger.info(f"学习的expressions: {expressions}")
logger.info(f"学习的jargon_entries: {jargon_entries}")
@@ -175,18 +178,17 @@ class ExpressionLearner:
# 当前行的原始内容
current_msg = random_msg[line_index]
# 过滤掉从bot自己发言中提取到的表达方式
if is_bot_message(current_msg):
continue
context = filter_message_content(current_msg.processed_plain_text or "")
if not context:
continue
filtered_expressions.append((situation, style, context))
learnt_expressions = filtered_expressions
if learnt_expressions is None:
@@ -270,37 +272,38 @@ class ExpressionLearner:
# 如果解析失败,尝试修复中文引号问题
# 使用状态机方法,在 JSON 字符串值内部将中文引号替换为转义的英文引号
try:
def fix_chinese_quotes_in_json(text):
"""使用状态机修复 JSON 字符串值中的中文引号"""
result = []
i = 0
in_string = False
escape_next = False
while i < len(text):
char = text[i]
if escape_next:
# 当前字符是转义字符后的字符,直接添加
result.append(char)
escape_next = False
i += 1
continue
if char == '\\':
if char == "\\":
# 转义字符
result.append(char)
escape_next = True
i += 1
continue
if char == '"' and not escape_next:
# 遇到英文引号,切换字符串状态
in_string = not in_string
result.append(char)
i += 1
continue
if in_string:
# 在字符串值内部,将中文引号替换为转义的英文引号
if char == '"': # 中文左引号 U+201C
@@ -312,13 +315,13 @@ class ExpressionLearner:
else:
# 不在字符串内,直接添加
result.append(char)
i += 1
return ''.join(result)
return "".join(result)
fixed_raw = fix_chinese_quotes_in_json(raw)
# 再次尝试解析
if fixed_raw.startswith("[") and fixed_raw.endswith("]"):
parsed = json.loads(fixed_raw)
@@ -346,12 +349,12 @@ class ExpressionLearner:
for item in parsed_list:
if not isinstance(item, dict):
continue
# 检查是否是表达方式条目(有 situation 和 style
situation = str(item.get("situation", "")).strip()
style = str(item.get("style", "")).strip()
source_id = str(item.get("source_id", "")).strip()
if situation and style and source_id:
# 表达方式条目
expressions.append((situation, style, source_id))
@@ -503,59 +506,59 @@ class ExpressionLearner:
async def _process_jargon_entries(self, jargon_entries: List[Tuple[str, str]], messages: List[Any]) -> None:
"""
处理从 expression learner 提取的黑话条目,路由到 jargon_miner
Args:
jargon_entries: 黑话条目列表,每个元素是 (content, source_id)
messages: 消息列表,用于构建上下文
"""
if not jargon_entries or not messages:
return
# 获取 jargon_miner 实例
jargon_miner = miner_manager.get_miner(self.chat_id)
# 构建黑话条目格式,与 jargon_miner.run_once 中的格式一致
entries: List[Dict[str, List[str]]] = []
for content, source_id in jargon_entries:
content = content.strip()
if not content:
continue
# 检查是否包含机器人名称
if contains_bot_self_name(content):
logger.info(f"跳过包含机器人昵称/别名的黑话: {content}")
continue
# 解析 source_id
source_id_str = (source_id or "").strip()
if not source_id_str.isdigit():
logger.warning(f"黑话条目 source_id 无效: content={content}, source_id={source_id_str}")
continue
# build_anonymous_messages 的编号从 1 开始
line_index = int(source_id_str) - 1
if line_index < 0 or line_index >= len(messages):
logger.warning(f"黑话条目 source_id 超出范围: content={content}, source_id={source_id_str}")
continue
# 检查是否是机器人自己的消息
target_msg = messages[line_index]
if is_bot_message(target_msg):
logger.info(f"跳过引用机器人自身消息的黑话: content={content}, source_id={source_id_str}")
continue
# 构建上下文段落
context_paragraph = build_context_paragraph(messages, line_index)
if not context_paragraph:
logger.warning(f"黑话条目上下文为空: content={content}, source_id={source_id_str}")
continue
entries.append({"content": content, "raw_content": [context_paragraph]})
if not entries:
return
# 调用 jargon_miner 处理这些条目
await jargon_miner.process_extracted_entries(entries)

View File

@@ -82,9 +82,7 @@ class ExpressionReflector:
# 获取未检查的表达
try:
logger.info("[Expression Reflection] 查询未检查且未拒绝的表达")
expressions = (
Expression.select().where((~Expression.checked) & (~Expression.rejected)).limit(50)
)
expressions = Expression.select().where((~Expression.checked) & (~Expression.rejected)).limit(50)
expr_list = list(expressions)
logger.info(f"[Expression Reflection] 找到 {len(expr_list)} 个候选表达")

View File

@@ -128,9 +128,7 @@ class ExpressionSelector:
# 查询所有相关chat_id的表达方式排除 rejected=1 的,且只选择 count > 1 的
style_query = Expression.select().where(
(Expression.chat_id.in_(related_chat_ids))
& (~Expression.rejected)
& (Expression.count > 1)
(Expression.chat_id.in_(related_chat_ids)) & (~Expression.rejected) & (Expression.count > 1)
)
style_exprs = [
@@ -150,12 +148,15 @@ class ExpressionSelector:
# 要求至少有10个 count > 1 的表达方式才进行选择
min_required = 10
if len(style_exprs) < min_required:
logger.info(f"聊天流 {chat_id} count > 1 的表达方式不足 {min_required} 个(实际 {len(style_exprs)} 个),不进行选择")
logger.info(
f"聊天流 {chat_id} count > 1 的表达方式不足 {min_required} 个(实际 {len(style_exprs)} 个),不进行选择"
)
return [], []
# 固定选择5个
select_count = 5
import random
selected_style = random.sample(style_exprs, select_count)
# 更新last_active_time
@@ -163,7 +164,9 @@ class ExpressionSelector:
self.update_expressions_last_active_time(selected_style)
selected_ids = [expr["id"] for expr in selected_style]
logger.debug(f"think_level=0: 从 {len(style_exprs)} 个 count>1 的表达方式中随机选择了 {len(selected_style)}")
logger.debug(
f"think_level=0: 从 {len(style_exprs)} 个 count>1 的表达方式中随机选择了 {len(selected_style)}"
)
return selected_style, selected_ids
except Exception as e:
@@ -186,9 +189,7 @@ class ExpressionSelector:
related_chat_ids = self.get_related_chat_ids(chat_id)
# 优化一次性查询所有相关chat_id的表达方式排除 rejected=1 的表达
style_query = Expression.select().where(
(Expression.chat_id.in_(related_chat_ids)) & (~Expression.rejected)
)
style_query = Expression.select().where((Expression.chat_id.in_(related_chat_ids)) & (~Expression.rejected))
style_exprs = [
{
@@ -246,7 +247,9 @@ class ExpressionSelector:
# 使用classic模式随机选择+LLM选择
logger.debug(f"使用classic模式为聊天流 {chat_id} 选择表达方式think_level={think_level}")
return await self._select_expressions_classic(chat_id, chat_info, max_num, target_message, reply_reason, think_level)
return await self._select_expressions_classic(
chat_id, chat_info, max_num, target_message, reply_reason, think_level
)
async def _select_expressions_classic(
self,
@@ -275,14 +278,12 @@ class ExpressionSelector:
# think_level == 0: 只选择 count > 1 的项目随机选10个不进行LLM选择
if think_level == 0:
return self._select_expressions_simple(chat_id, max_num)
# think_level == 1: 先选高count再从所有表达方式中随机抽样
# 1. 获取所有表达方式并分离 count > 1 和 count <= 1 的
related_chat_ids = self.get_related_chat_ids(chat_id)
style_query = Expression.select().where(
(Expression.chat_id.in_(related_chat_ids)) & (~Expression.rejected)
)
style_query = Expression.select().where((Expression.chat_id.in_(related_chat_ids)) & (~Expression.rejected))
all_style_exprs = [
{
"id": expr.id,
@@ -299,29 +300,33 @@ class ExpressionSelector:
# 分离 count > 1 和 count <= 1 的表达方式
high_count_exprs = [expr for expr in all_style_exprs if (expr.get("count", 1) or 1) > 1]
# 根据 think_level 设置要求(仅支持 0/10 已在上方返回)
min_high_count = 10
min_total_count = 10
select_high_count = 5
select_random_count = 5
# 检查数量要求
if len(high_count_exprs) < min_high_count:
logger.info(f"聊天流 {chat_id} count > 1 的表达方式不足 {min_high_count} 个(实际 {len(high_count_exprs)} 个),不进行选择")
logger.info(
f"聊天流 {chat_id} count > 1 的表达方式不足 {min_high_count} 个(实际 {len(high_count_exprs)} 个),不进行选择"
)
return [], []
if len(all_style_exprs) < min_total_count:
logger.info(f"聊天流 {chat_id} 总表达方式不足 {min_total_count} 个(实际 {len(all_style_exprs)} 个),不进行选择")
logger.info(
f"聊天流 {chat_id} 总表达方式不足 {min_total_count} 个(实际 {len(all_style_exprs)} 个),不进行选择"
)
return [], []
# 先选取高count的表达方式
selected_high = weighted_sample(high_count_exprs, min(len(high_count_exprs), select_high_count))
# 然后从所有表达方式中随机抽样(使用加权抽样)
remaining_num = select_random_count
selected_random = weighted_sample(all_style_exprs, min(len(all_style_exprs), remaining_num))
# 合并候选池(去重,避免重复)
candidate_exprs = selected_high.copy()
candidate_ids = {expr["id"] for expr in candidate_exprs}
@@ -329,9 +334,10 @@ class ExpressionSelector:
if expr["id"] not in candidate_ids:
candidate_exprs.append(expr)
candidate_ids.add(expr["id"])
# 打乱顺序避免高count的都在前面
import random
random.shuffle(candidate_exprs)
# 2. 构建所有表达方式的索引和情境列表
@@ -351,7 +357,7 @@ class ExpressionSelector:
all_situations_str = "\n".join(all_situations)
if target_message:
target_message_str = f",现在你想要对这条消息进行回复:\"{target_message}\""
target_message_str = f',现在你想要对这条消息进行回复:"{target_message}"'
target_message_extra_block = "4.考虑你要回复的目标消息"
else:
target_message_str = ""

View File

@@ -8,7 +8,12 @@ from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config, global_config
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.bw_learner.jargon_miner import search_jargon
from src.bw_learner.learner_utils import is_bot_message, contains_bot_self_name, parse_chat_id_list, chat_id_list_contains
from src.bw_learner.learner_utils import (
is_bot_message,
contains_bot_self_name,
parse_chat_id_list,
chat_id_list_contains,
)
logger = get_logger("jargon")
@@ -357,4 +362,4 @@ async def retrieve_concepts_with_jargon(concepts: List[str], chat_id: str) -> st
if results:
return "【概念检索结果】\n" + "\n".join(results) + "\n"
return ""
return ""

View File

@@ -1,4 +1,3 @@
import time
import json
import asyncio
import random
@@ -14,7 +13,6 @@ from src.config.config import model_config, global_config
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.utils.chat_message_builder import (
build_readable_messages_with_id,
get_raw_msg_by_timestamp_with_chat_inclusive,
)
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.bw_learner.learner_utils import (
@@ -33,23 +31,23 @@ logger = get_logger("jargon")
def _is_single_char_jargon(content: str) -> bool:
"""
判断是否是单字黑话(单个汉字、英文或数字)
Args:
content: 词条内容
Returns:
bool: 如果是单字黑话返回True否则返回False
"""
if not content or len(content) != 1:
return False
char = content[0]
# 判断是否是单个汉字、单个英文字母或单个数字
return (
'\u4e00' <= char <= '\u9fff' or # 汉字
'a' <= char <= 'z' or # 小写字母
'A' <= char <= 'Z' or # 大写字母
'0' <= char <= '9' # 数字
"\u4e00" <= char <= "\u9fff" # 汉字
or "a" <= char <= "z" # 小写字母
or "A" <= char <= "Z" # 大写字母
or "0" <= char <= "9" # 数字
)
@@ -195,7 +193,7 @@ class JargonMiner:
model_set=model_config.model_task_config.utils,
request_type="jargon.extract",
)
self.llm_inference = LLMRequest(
model_set=model_config.model_task_config.utils,
request_type="jargon.inference",
@@ -207,7 +205,7 @@ class JargonMiner:
self.stream_name = stream_name if stream_name else self.chat_id
self.cache_limit = 50
self.cache: OrderedDict[str, None] = OrderedDict()
# 黑话提取锁,防止并发执行
self._extraction_lock = asyncio.Lock()
@@ -299,17 +297,19 @@ class JargonMiner:
# 获取当前count和上一次的meaning
current_count = jargon_obj.count or 0
previous_meaning = jargon_obj.meaning or ""
# 当count为24, 60时随机移除一半的raw_content项目
if current_count in [24, 60] and len(raw_content_list) > 1:
# 计算要保留的数量至少保留1个
keep_count = max(1, len(raw_content_list) // 2)
raw_content_list = random.sample(raw_content_list, keep_count)
logger.info(f"jargon {content} count={current_count},随机移除后剩余 {len(raw_content_list)} 个raw_content项目")
logger.info(
f"jargon {content} count={current_count},随机移除后剩余 {len(raw_content_list)} 个raw_content项目"
)
# 步骤1: 基于raw_content和content推断
raw_content_text = "\n".join(raw_content_list)
# 当count为24, 60, 100时在prompt中放入上一次推断出的meaning作为参考
previous_meaning_section = ""
previous_meaning_instruction = ""
@@ -318,8 +318,10 @@ class JargonMiner:
**上一次推断的含义(仅供参考)**
{previous_meaning}
"""
previous_meaning_instruction = "- 请参考上一次推断的含义,结合新的上下文信息,给出更准确或更新的推断结果"
previous_meaning_instruction = (
"- 请参考上一次推断的含义,结合新的上下文信息,给出更准确或更新的推断结果"
)
prompt1 = await global_prompt_manager.format_prompt(
"jargon_inference_with_context_prompt",
content=content,
@@ -481,7 +483,7 @@ class JargonMiner:
async def run_once(self, messages: List[Any]) -> None:
"""
运行一次黑话提取
Args:
messages: 外部传入的消息列表(必需)
"""
@@ -650,7 +652,9 @@ class JargonMiner:
if obj.raw_content:
try:
existing_raw_content = (
json.loads(obj.raw_content) if isinstance(obj.raw_content, str) else obj.raw_content
json.loads(obj.raw_content)
if isinstance(obj.raw_content, str)
else obj.raw_content
)
if not isinstance(existing_raw_content, list):
existing_raw_content = [existing_raw_content] if existing_raw_content else []
@@ -726,13 +730,13 @@ class JargonMiner:
async def process_extracted_entries(self, entries: List[Dict[str, List[str]]]) -> None:
"""
处理已提取的黑话条目(从 expression_learner 路由过来的)
Args:
entries: 黑话条目列表,每个元素格式为 {"content": "...", "raw_content": [...]}
"""
if not entries:
return
try:
# 去重并合并raw_content按 content 聚合)
merged_entries: OrderedDict[str, Dict[str, List[str]]] = OrderedDict()
@@ -876,8 +880,6 @@ class JargonMinerManager:
miner_manager = JargonMinerManager()
def search_jargon(
keyword: str, chat_id: Optional[str] = None, limit: int = 10, case_sensitive: bool = False, fuzzy: bool = True
) -> List[Dict[str, str]]:

View File

@@ -15,25 +15,25 @@ class MessageRecorder:
"""
统一的消息记录器,负责管理时间窗口和消息提取,并将消息分发给 expression_learner 和 jargon_miner
"""
def __init__(self, chat_id: str) -> None:
self.chat_id = chat_id
self.chat_stream = get_chat_manager().get_stream(chat_id)
self.chat_name = get_chat_manager().get_stream_name(chat_id) or chat_id
# 维护每个chat的上次提取时间
self.last_extraction_time: float = time.time()
# 提取锁,防止并发执行
self._extraction_lock = asyncio.Lock()
# 获取 expression 和 jargon 的配置参数
self._init_parameters()
# 获取 expression_learner 和 jargon_miner 实例
self.expression_learner = expression_learner_manager.get_expression_learner(chat_id)
self.jargon_miner = miner_manager.get_miner(chat_id)
def _init_parameters(self) -> None:
"""初始化提取参数"""
# 获取 expression 配置
@@ -42,17 +42,17 @@ class MessageRecorder:
)
self.min_messages_for_extraction = 30
self.min_extraction_interval = 60
logger.debug(
f"MessageRecorder 初始化: chat_id={self.chat_id}, "
f"min_messages={self.min_messages_for_extraction}, "
f"min_interval={self.min_extraction_interval}"
)
def should_trigger_extraction(self) -> bool:
"""
检查是否应该触发消息提取
Returns:
bool: 是否应该触发提取
"""
@@ -60,19 +60,19 @@ class MessageRecorder:
time_diff = time.time() - self.last_extraction_time
if time_diff < self.min_extraction_interval:
return False
# 检查消息数量
recent_messages = get_raw_msg_by_timestamp_with_chat_inclusive(
chat_id=self.chat_id,
timestamp_start=self.last_extraction_time,
timestamp_end=time.time(),
)
if not recent_messages or len(recent_messages) < self.min_messages_for_extraction:
return False
return True
async def extract_and_distribute(self) -> None:
"""
提取消息并分发给 expression_learner 和 jargon_miner
@@ -82,41 +82,40 @@ class MessageRecorder:
# 在锁内检查,避免并发触发
if not self.should_trigger_extraction():
return
# 检查 chat_stream 是否存在
if not self.chat_stream:
return
# 记录本次提取的时间窗口,避免重复提取
extraction_start_time = self.last_extraction_time
extraction_end_time = time.time()
# 立即更新提取时间,防止并发触发
self.last_extraction_time = extraction_end_time
try:
logger.info(f"在聊天流 {self.chat_name} 开始统一消息提取和分发")
# 拉取提取窗口内的消息
messages = get_raw_msg_by_timestamp_with_chat_inclusive(
chat_id=self.chat_id,
timestamp_start=extraction_start_time,
timestamp_end=extraction_end_time,
)
if not messages:
logger.debug(f"聊天流 {self.chat_name} 没有新消息,跳过提取")
return
# 按时间排序,确保顺序一致
messages = sorted(messages, key=lambda msg: msg.time or 0)
logger.info(
f"聊天流 {self.chat_name} 提取到 {len(messages)} 条消息,"
f"时间窗口: {extraction_start_time:.2f} - {extraction_end_time:.2f}"
)
# 分别触发 expression_learner 和 jargon_miner 的处理
# 传递提取的消息,避免它们重复获取
# 触发 expression 学习(如果启用)
@@ -124,28 +123,26 @@ class MessageRecorder:
asyncio.create_task(
self._trigger_expression_learning(extraction_start_time, extraction_end_time, messages)
)
# 触发 jargon 提取(如果启用),传递消息
# if self.enable_jargon_learning:
# asyncio.create_task(
# self._trigger_jargon_extraction(extraction_start_time, extraction_end_time, messages)
# )
# asyncio.create_task(
# self._trigger_jargon_extraction(extraction_start_time, extraction_end_time, messages)
# )
except Exception as e:
logger.error(f"为聊天流 {self.chat_name} 提取和分发消息失败: {e}")
import traceback
traceback.print_exc()
# 即使失败也保持时间戳更新,避免频繁重试
async def _trigger_expression_learning(
self,
timestamp_start: float,
timestamp_end: float,
messages: List[Any]
self, timestamp_start: float, timestamp_end: float, messages: List[Any]
) -> None:
"""
触发 expression 学习,使用指定的消息列表
Args:
timestamp_start: 开始时间戳
timestamp_end: 结束时间戳
@@ -154,7 +151,7 @@ class MessageRecorder:
try:
# 传递消息给 ExpressionLearner必需参数
learnt_style = await self.expression_learner.learn_and_store(messages=messages)
if learnt_style:
logger.info(f"聊天流 {self.chat_name} 表达学习完成")
else:
@@ -162,17 +159,15 @@ class MessageRecorder:
except Exception as e:
logger.error(f"为聊天流 {self.chat_name} 触发表达学习失败: {e}")
import traceback
traceback.print_exc()
async def _trigger_jargon_extraction(
self,
timestamp_start: float,
timestamp_end: float,
messages: List[Any]
self, timestamp_start: float, timestamp_end: float, messages: List[Any]
) -> None:
"""
触发 jargon 提取,使用指定的消息列表
Args:
timestamp_start: 开始时间戳
timestamp_end: 结束时间戳
@@ -181,19 +176,20 @@ class MessageRecorder:
try:
# 传递消息给 JargonMiner避免它重复获取
await self.jargon_miner.run_once(messages=messages)
except Exception as e:
logger.error(f"为聊天流 {self.chat_name} 触发黑话提取失败: {e}")
import traceback
traceback.print_exc()
class MessageRecorderManager:
"""MessageRecorder 管理器"""
def __init__(self) -> None:
self._recorders: dict[str, MessageRecorder] = {}
def get_recorder(self, chat_id: str) -> MessageRecorder:
"""获取或创建指定 chat_id 的 MessageRecorder"""
if chat_id not in self._recorders:
@@ -208,10 +204,9 @@ recorder_manager = MessageRecorderManager()
async def extract_and_distribute_messages(chat_id: str) -> None:
"""
统一的消息提取和分发入口函数
Args:
chat_id: 聊天流ID
"""
recorder = recorder_manager.get_recorder(chat_id)
await recorder.extract_and_distribute()

View File

@@ -176,19 +176,19 @@ class BrainChatting:
# 如果有新消息,更新 last_read_time
if len(recent_messages_list) >= 1:
self.last_read_time = time.time()
# 总是执行一次思考迭代(不管有没有新消息)
# wait 动作会在其内部等待,不需要在这里处理
should_continue = await self._observe(recent_messages_list=recent_messages_list)
if not should_continue:
# 选择了 complete_talk返回 False 表示需要等待新消息
return False
# 继续下一次迭代(除非选择了 complete_talk
# 短暂等待后再继续,避免过于频繁的循环
await asyncio.sleep(0.1)
return True
async def _send_and_store_reply(
@@ -328,9 +328,7 @@ class BrainChatting:
)
# 检查是否有 complete_talk 动作(会停止后续迭代)
has_complete_talk = any(
action.action_type == "complete_talk" for action in action_to_use_info
)
has_complete_talk = any(action.action_type == "complete_talk" for action in action_to_use_info)
# 并行执行所有动作
action_tasks = [
@@ -430,12 +428,12 @@ class BrainChatting:
await asyncio.sleep(3)
self._loop_task = asyncio.create_task(self._main_chat_loop())
logger.error(f"{self.log_prefix} 结束了当前聊天循环")
async def _wait_for_new_message(self):
"""等待新消息到达"""
last_check_time = self.last_read_time
check_interval = 1.0 # 每秒检查一次
while self.running:
# 检查是否有新消息
recent_messages_list = message_api.get_messages_by_time_in_chat(
@@ -448,13 +446,13 @@ class BrainChatting:
filter_command=False,
filter_intercept_message_level=1,
)
# 如果有新消息,更新 last_read_time 并返回
if len(recent_messages_list) >= 1:
self.last_read_time = time.time()
logger.info(f"{self.log_prefix} 检测到新消息,恢复循环")
return
# 等待一段时间后再次检查
await asyncio.sleep(check_interval)
@@ -660,9 +658,9 @@ class BrainChatting:
except (ValueError, TypeError):
logger.warning(f"{self.log_prefix} wait_seconds 参数格式错误,使用默认值 5 秒")
wait_seconds = 5
logger.info(f"{self.log_prefix} 执行 wait 动作,等待 {wait_seconds}")
# 记录动作信息
await database_api.store_action_info(
chat_stream=self.chat_stream,
@@ -673,12 +671,12 @@ class BrainChatting:
action_data={"reason": reason, "wait_seconds": wait_seconds},
action_name="wait",
)
# 等待指定时间
await asyncio.sleep(wait_seconds)
logger.info(f"{self.log_prefix} wait 动作完成,继续下一次思考")
# 这些动作本身不产生文本回复
self._last_successful_reply = False
return {
@@ -693,9 +691,9 @@ class BrainChatting:
logger.debug(f"{self.log_prefix} 检测到 listening 动作,已合并到 wait自动转换")
# 使用默认等待时间
wait_seconds = 3
logger.info(f"{self.log_prefix} 执行 listening转换为 wait动作等待 {wait_seconds}")
# 记录动作信息
await database_api.store_action_info(
chat_stream=self.chat_stream,
@@ -706,12 +704,12 @@ class BrainChatting:
action_data={"reason": reason, "wait_seconds": wait_seconds},
action_name="listening",
)
# 等待指定时间
await asyncio.sleep(wait_seconds)
logger.info(f"{self.log_prefix} listening 动作完成,继续下一次思考")
# 这些动作本身不产生文本回复
self._last_successful_reply = False
return {

View File

@@ -147,7 +147,7 @@ class BrainPlanner:
) # 用于动作规划
self.last_obs_time_mark = 0.0
# 计划日志记录
self.plan_log: List[Tuple[str, float, List[ActionPlannerInfo]]] = []
@@ -203,9 +203,11 @@ class BrainPlanner:
# 内部保留动作(不依赖插件系统)
# 注意listening 已合并到 wait 中,如果遇到 listening 则转换为 wait
internal_action_names = ["complete_talk", "reply", "wait_time", "wait", "listening"]
logger.debug(f"{self.log_prefix}动作验证: action={action}, internal={internal_action_names}, available={available_action_names}")
logger.debug(
f"{self.log_prefix}动作验证: action={action}, internal={internal_action_names}, available={available_action_names}"
)
# 将 listening 转换为 wait向后兼容
if action == "listening":
logger.debug(f"{self.log_prefix}检测到 listening 动作,已合并到 wait自动转换")
@@ -521,7 +523,7 @@ class BrainPlanner:
if json_objects:
logger.info(f"{self.log_prefix}从响应中提取到{len(json_objects)}个JSON对象")
for i, json_obj in enumerate(json_objects):
logger.info(f"{self.log_prefix}解析第{i+1}个JSON对象: {json_obj}")
logger.info(f"{self.log_prefix}解析第{i + 1}个JSON对象: {json_obj}")
filtered_actions_list = list(filtered_actions.items())
for json_obj in json_objects:
parsed_actions = self._parse_single_action(json_obj, message_id_list, filtered_actions_list)
@@ -553,7 +555,9 @@ class BrainPlanner:
return extracted_reasoning, actions
def _create_complete_talk(self, reasoning: str, available_actions: Dict[str, ActionInfo]) -> List[ActionPlannerInfo]:
def _create_complete_talk(
self, reasoning: str, available_actions: Dict[str, ActionInfo]
) -> List[ActionPlannerInfo]:
"""创建complete_talk"""
return [
ActionPlannerInfo(
@@ -564,7 +568,7 @@ class BrainPlanner:
available_actions=available_actions,
)
]
def add_plan_log(self, reasoning: str, actions: List[ActionPlannerInfo]):
"""添加计划日志"""
self.plan_log.append((reasoning, time.time(), actions))

View File

@@ -271,7 +271,7 @@ def _to_emoji_objects(data: Any) -> Tuple[List["MaiEmoji"], int]:
emoji.description = emoji_data.description
# Deserialize emotion string from DB to list
emoji.emotion = emoji_data.emotion.replace("",",").split(",") if emoji_data.emotion else []
emoji.emotion = emoji_data.emotion.replace("", ",").split(",") if emoji_data.emotion else []
emoji.usage_count = emoji_data.usage_count
db_last_used_time = emoji_data.last_used_time
@@ -732,7 +732,7 @@ class EmojiManager:
emoji_record = Emoji.get_or_none(Emoji.emoji_hash == emoji_hash)
if emoji_record and emoji_record.emotion:
logger.info(f"[缓存命中] 从数据库获取表情包情感标签: {emoji_record.emotion[:50]}...")
return emoji_record.emotion.replace("",",").split(",")
return emoji_record.emotion.replace("", ",").split(",")
except Exception as e:
logger.error(f"从数据库查询表情包情感标签时出错: {e}")
@@ -993,7 +993,7 @@ class EmojiManager:
)
# 处理情感列表
emotions = [e.strip() for e in emotions_text.replace("",",").split(",") if e.strip()]
emotions = [e.strip() for e in emotions_text.replace("", ",").split(",") if e.strip()]
# 根据情感标签数量随机选择 - 超过5个选3个超过2个选2个
if len(emotions) > 5:

View File

@@ -619,13 +619,13 @@ class HeartFChatting:
think_level = 0
# 使用 action_reasoningplanner 的整体思考理由)作为 reply_reason
planner_reasoning = action_planner_info.action_reasoning or reason
record_replyer_action_temp(
chat_id=self.stream_id,
reason=reason,
think_level=think_level,
)
await database_api.store_action_info(
chat_stream=self.chat_stream,
action_build_into_prompt=False,

View File

@@ -123,7 +123,11 @@ class ChatBot:
logger.warning(f"命令执行失败: {command_class.__name__} - {response}")
# 根据命令的拦截设置决定是否继续处理消息
return True, response, not bool(intercept_message_level) # 找到命令根据intercept_message决定是否继续
return (
True,
response,
not bool(intercept_message_level),
) # 找到命令根据intercept_message决定是否继续
except Exception as e:
logger.error(f"执行命令时出错: {command_class.__name__} - {e}")

View File

@@ -263,7 +263,7 @@ class MessageRecv(Message):
desc = segment.data.get("desc", "") # 内容描述
source_url = segment.data.get("source_url", "") # 原始链接
url = segment.data.get("url", "") # 小程序链接
text = f"[小程序分享"
text = "[小程序分享"
if title:
text += f" - {title}"
text += "]"

View File

@@ -42,22 +42,21 @@ def is_webui_virtual_group(group_id: str) -> bool:
def parse_message_segments(segment) -> list:
"""解析消息段,转换为 WebUI 可用的格式
参考 NapCat 适配器的消息解析逻辑
Args:
segment: Seg 消息段对象
Returns:
list: 消息段列表,每个元素为 {"type": "...", "data": ...}
"""
from maim_message import Seg
result = []
if segment is None:
return result
if segment.type == "seglist":
# 处理消息段列表
if segment.data:
@@ -112,15 +111,19 @@ def parse_message_segments(segment) -> list:
forward_items = []
if segment.data:
for item in segment.data:
forward_items.append({
"content": parse_message_segments(item.get("message_segment", {})) if isinstance(item, dict) else []
})
forward_items.append(
{
"content": parse_message_segments(item.get("message_segment", {}))
if isinstance(item, dict)
else []
}
)
result.append({"type": "forward", "data": forward_items})
else:
# 未知类型,尝试作为文本处理
if segment.data:
result.append({"type": "unknown", "original_type": segment.type, "data": str(segment.data)})
return result
@@ -134,7 +137,7 @@ async def _send_message(message: MessageSending, show_log=True) -> bool:
# 检查是否是 WebUI 平台的消息,或者是 WebUI 虚拟群的消息
chat_manager, webui_platform = get_webui_chat_broadcaster()
is_webui_message = (platform == webui_platform) or is_webui_virtual_group(group_id)
if is_webui_message and chat_manager is not None:
# WebUI 聊天室消息(包括虚拟身份模式),通过 WebSocket 广播
import time
@@ -142,7 +145,7 @@ async def _send_message(message: MessageSending, show_log=True) -> bool:
# 解析消息段,获取富文本内容
message_segments = parse_message_segments(message.message_segment)
# 判断消息类型
# 如果只有一个文本段,使用简单的 text 类型
# 否则使用 rich 类型,包含完整的消息段

View File

@@ -77,8 +77,7 @@ target_message_id为必填表示触发消息的id
```""",
"planner_prompt",
)
Prompt(
"""
{action_name}

View File

@@ -250,7 +250,12 @@ class DefaultReplyer:
# 使用从处理器传来的选中表达方式
# 使用模型预测选择表达方式
selected_expressions, selected_ids = await expression_selector.select_suitable_expressions(
self.chat_stream.stream_id, chat_history, max_num=8, target_message=target, reply_reason=reply_reason, think_level=think_level
self.chat_stream.stream_id,
chat_history,
max_num=8,
target_message=target,
reply_reason=reply_reason,
think_level=think_level,
)
if selected_expressions:
@@ -273,7 +278,6 @@ class DefaultReplyer:
return f"{expression_habits_title}\n{expression_habits_block}", selected_ids
async def build_tool_info(self, chat_history: str, sender: str, target: str, enable_tool: bool = True) -> str:
"""构建工具信息块
@@ -788,7 +792,8 @@ class DefaultReplyer:
# 并行执行八个构建任务(包括黑话解释)
task_results = await asyncio.gather(
self._time_and_run_task(
self.build_expression_habits(chat_talking_prompt_short, target, reply_reason, think_level=think_level), "expression_habits"
self.build_expression_habits(chat_talking_prompt_short, target, reply_reason, think_level=think_level),
"expression_habits",
),
self._time_and_run_task(
self.build_tool_info(chat_talking_prompt_short, sender, target, enable_tool=enable_tool), "tool_info"
@@ -980,7 +985,6 @@ class DefaultReplyer:
else:
reply_target_block = ""
chat_target_1 = await global_prompt_manager.get_prompt_async("chat_target_group1")
chat_target_2 = await global_prompt_manager.get_prompt_async("chat_target_group2")

View File

@@ -287,7 +287,6 @@ class PrivateReplyer:
return f"{expression_habits_title}\n{expression_habits_block}", selected_ids
async def build_tool_info(self, chat_history: str, sender: str, target: str, enable_tool: bool = True) -> str:
"""构建工具信息块
@@ -907,16 +906,11 @@ class PrivateReplyer:
else:
reply_target_block = ""
chat_target_name = "对方"
if self.chat_target_info:
chat_target_name = self.chat_target_info.person_name or self.chat_target_info.user_nickname or "对方"
chat_target_1 = await global_prompt_manager.format_prompt(
"chat_target_private1", sender_name=chat_target_name
)
chat_target_2 = await global_prompt_manager.format_prompt(
"chat_target_private2", sender_name=chat_target_name
)
chat_target_1 = await global_prompt_manager.format_prompt("chat_target_private1", sender_name=chat_target_name)
chat_target_2 = await global_prompt_manager.format_prompt("chat_target_private2", sender_name=chat_target_name)
template_name = "default_expressor_prompt"

View File

@@ -1,8 +1,9 @@
from src.chat.utils.prompt_builder import Prompt
def init_replyer_private_prompt():
Prompt(
"""{knowledge_prompt}{tool_info_block}{extra_info_block}
"""{knowledge_prompt}{tool_info_block}{extra_info_block}
{expression_habits_block}{memory_retrieval}{jargon_explanation}
你正在和{sender_name}聊天,这是你们之前聊的内容:
@@ -17,9 +18,9 @@ def init_replyer_private_prompt():
{reply_style}
请注意不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只输出回复内容。
{moderation_prompt}不要输出多余内容(包括前后缀冒号和引号括号表情包at或 @等 )。""",
"private_replyer_prompt",
)
"private_replyer_prompt",
)
Prompt(
"""{knowledge_prompt}{tool_info_block}{extra_info_block}
{expression_habits_block}{memory_retrieval}{jargon_explanation}
@@ -37,4 +38,4 @@ def init_replyer_private_prompt():
{moderation_prompt}不要输出多余内容(包括冒号和引号括号表情包at或 @等 )。
""",
"private_replyer_self_prompt",
)
)

View File

@@ -23,7 +23,7 @@ def init_replyer_prompt():
现在,你说:""",
"replyer_prompt_0",
)
Prompt(
"""{knowledge_prompt}{tool_info_block}{extra_info_block}
{expression_habits_block}{memory_retrieval}{jargon_explanation}
@@ -44,4 +44,3 @@ def init_replyer_prompt():
现在,你说:""",
"replyer_prompt",
)

View File

@@ -311,7 +311,10 @@ def get_raw_msg_before_timestamp_with_chat(
filter_query = {"chat_id": chat_id, "time": {"$lt": timestamp}}
sort_order = [("time", 1)]
return find_messages(
message_filter=filter_query, sort=sort_order, limit=limit, filter_intercept_message_level=filter_intercept_message_level
message_filter=filter_query,
sort=sort_order,
limit=limit,
filter_intercept_message_level=filter_intercept_message_level,
)

View File

@@ -746,7 +746,7 @@ class StatisticOutputTask(AsyncTask):
data_fmt = "{:<32} {:>10} {:>12} {:>12} {:>12} {:>9.2f}¥ {:>10.1f} {:>10.1f} {:>12} {:>12}"
total_replies = stats.get(TOTAL_REPLY_CNT, 0)
output = [
"按模型分类统计:",
" 模型名称 调用次数 输入Token 输出Token Token总量 累计花费 平均耗时(秒) 标准差(秒) 每次回复平均调用次数 每次回复平均Token数",
@@ -759,11 +759,11 @@ class StatisticOutputTask(AsyncTask):
cost = stats[COST_BY_MODEL][model_name]
avg_time_cost = stats[AVG_TIME_COST_BY_MODEL][model_name]
std_time_cost = stats[STD_TIME_COST_BY_MODEL][model_name]
# 计算每次回复平均值
avg_count_per_reply = count / total_replies if total_replies > 0 else 0.0
avg_tokens_per_reply = tokens / total_replies if total_replies > 0 else 0.0
# 格式化大数字
formatted_count = _format_large_number(count)
formatted_in_tokens = _format_large_number(in_tokens)
@@ -771,7 +771,7 @@ class StatisticOutputTask(AsyncTask):
formatted_tokens = _format_large_number(tokens)
formatted_avg_count = _format_large_number(avg_count_per_reply) if total_replies > 0 else "N/A"
formatted_avg_tokens = _format_large_number(avg_tokens_per_reply) if total_replies > 0 else "N/A"
output.append(
data_fmt.format(
name,
@@ -800,7 +800,7 @@ class StatisticOutputTask(AsyncTask):
data_fmt = "{:<32} {:>10} {:>12} {:>12} {:>12} {:>9.2f}¥ {:>10.1f} {:>10.1f} {:>12} {:>12}"
total_replies = stats.get(TOTAL_REPLY_CNT, 0)
output = [
"按模块分类统计:",
" 模块名称 调用次数 输入Token 输出Token Token总量 累计花费 平均耗时(秒) 标准差(秒) 每次回复平均调用次数 每次回复平均Token数",
@@ -813,11 +813,11 @@ class StatisticOutputTask(AsyncTask):
cost = stats[COST_BY_MODULE][module_name]
avg_time_cost = stats[AVG_TIME_COST_BY_MODULE][module_name]
std_time_cost = stats[STD_TIME_COST_BY_MODULE][module_name]
# 计算每次回复平均值
avg_count_per_reply = count / total_replies if total_replies > 0 else 0.0
avg_tokens_per_reply = tokens / total_replies if total_replies > 0 else 0.0
# 格式化大数字
formatted_count = _format_large_number(count)
formatted_in_tokens = _format_large_number(in_tokens)
@@ -825,7 +825,7 @@ class StatisticOutputTask(AsyncTask):
formatted_tokens = _format_large_number(tokens)
formatted_avg_count = _format_large_number(avg_count_per_reply) if total_replies > 0 else "N/A"
formatted_avg_tokens = _format_large_number(avg_tokens_per_reply) if total_replies > 0 else "N/A"
output.append(
data_fmt.format(
name,

View File

@@ -646,7 +646,7 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional["TargetP
def record_replyer_action_temp(chat_id: str, reason: str, think_level: int) -> None:
"""
临时记录replyer动作被选择的信息仅群聊
Args:
chat_id: 聊天ID
reason: 选择理由
@@ -656,7 +656,7 @@ def record_replyer_action_temp(chat_id: str, reason: str, think_level: int) -> N
# 确保data/temp目录存在
temp_dir = "data/temp"
os.makedirs(temp_dir, exist_ok=True)
# 创建记录数据
record_data = {
"chat_id": chat_id,
@@ -664,16 +664,16 @@ def record_replyer_action_temp(chat_id: str, reason: str, think_level: int) -> N
"think_level": think_level,
"timestamp": datetime.now().isoformat(),
}
# 生成文件名(使用时间戳避免冲突)
timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
filename = f"replyer_action_{timestamp_str}.json"
filepath = os.path.join(temp_dir, filename)
# 写入文件
with open(filepath, "w", encoding="utf-8") as f:
json.dump(record_data, f, ensure_ascii=False, indent=2)
logger.debug(f"已记录replyer动作选择: chat_id={chat_id}, think_level={think_level}")
except Exception as e:
logger.warning(f"记录replyer动作选择失败: {e}")

View File

@@ -130,12 +130,10 @@ class ImageManager:
try:
# 清理Images表中type为emoji的记录
deleted_images = Images.delete().where(Images.type == "emoji").execute()
# 清理ImageDescriptions表中type为emoji的记录
deleted_descriptions = (
ImageDescriptions.delete().where(ImageDescriptions.type == "emoji").execute()
)
deleted_descriptions = ImageDescriptions.delete().where(ImageDescriptions.type == "emoji").execute()
total_deleted = deleted_images + deleted_descriptions
if total_deleted > 0:
logger.info(
@@ -166,7 +164,7 @@ class ImageManager:
async def _save_emoji_file_if_needed(self, image_base64: str, image_hash: str, image_format: str) -> None:
"""如果启用了steal_emoji且表情包未注册保存文件到data/emoji目录
Args:
image_base64: 图片的base64编码
image_hash: 图片的MD5哈希值
@@ -174,7 +172,7 @@ class ImageManager:
"""
if not global_config.emoji.steal_emoji:
return
try:
from src.chat.emoji_system.emoji_manager import EMOJI_DIR
from src.chat.emoji_system.emoji_manager import get_emoji_manager
@@ -236,12 +234,16 @@ class ImageManager:
# 优先使用情感标签,如果没有则使用详细描述
result_text = ""
if cache_record.emotion_tags:
logger.info(f"[缓存命中] 使用EmojiDescriptionCache表中的情感标签: {cache_record.emotion_tags[:50]}...")
logger.info(
f"[缓存命中] 使用EmojiDescriptionCache表中的情感标签: {cache_record.emotion_tags[:50]}..."
)
result_text = f"[表情包:{cache_record.emotion_tags}]"
elif cache_record.description:
logger.info(f"[缓存命中] 使用EmojiDescriptionCache表中的描述: {cache_record.description[:50]}...")
logger.info(
f"[缓存命中] 使用EmojiDescriptionCache表中的描述: {cache_record.description[:50]}..."
)
result_text = f"[表情包:{cache_record.description}]"
# 即使缓存命中如果启用了steal_emoji也检查是否需要保存文件
if result_text:
await self._save_emoji_file_if_needed(image_base64, image_hash, image_format)

View File

@@ -609,23 +609,23 @@ def _fix_table_constraints(table_name, model, constraints_to_fix):
fields = list(model._meta.fields.keys())
# Peewee 默认使用 'id' 作为主键字段名
# 尝试获取主键字段名,如果获取失败则默认使用 'id'
primary_key_name = 'id' # 默认值
primary_key_name = "id" # 默认值
try:
if hasattr(model._meta, 'primary_key') and model._meta.primary_key:
if hasattr(model._meta.primary_key, 'name'):
if hasattr(model._meta, "primary_key") and model._meta.primary_key:
if hasattr(model._meta.primary_key, "name"):
primary_key_name = model._meta.primary_key.name
elif isinstance(model._meta.primary_key, str):
primary_key_name = model._meta.primary_key
except Exception:
pass # 如果获取失败,使用默认值 'id'
# 如果字段列表包含主键,则排除它
if primary_key_name in fields:
fields_without_pk = [f for f in fields if f != primary_key_name]
logger.info(f"排除主键字段 '{primary_key_name}',让数据库自动生成新的主键")
else:
fields_without_pk = fields
fields_str = ", ".join(fields_without_pk)
# 检查是否有字段需要从 NULL 改为 NOT NULL

View File

@@ -34,7 +34,7 @@ def _format_toml_value(obj: Any, threshold: int, depth: int = 0) -> Any:
return obj
# 决定是否多行:仅在顶层且长度超过阈值时
should_multiline = (depth == 0 and len(obj) > threshold)
should_multiline = depth == 0 and len(obj) > threshold
# 如果已经是 tomlkit Array原地修改以保留注释
if isinstance(obj, Array):
@@ -46,7 +46,7 @@ def _format_toml_value(obj: Any, threshold: int, depth: int = 0) -> Any:
# 普通 list转换为 tomlkit 数组
arr = tomlkit.array()
arr.multiline(should_multiline)
for item in obj:
arr.append(_format_toml_value(item, threshold, depth + 1))
return arr
@@ -112,7 +112,7 @@ def save_toml_with_format(
formatted = _format_toml_value(data, multiline_threshold) if multiline_threshold >= 0 else data
output = tomlkit.dumps(formatted)
# 规范化:将 3+ 连续空行压缩为 1 个空行,防止空行累积
output = re.sub(r'\n{3,}', '\n\n', output)
output = re.sub(r"\n{3,}", "\n\n", output)
with open(file_path, "w", encoding="utf-8") as f:
f.write(output)
@@ -122,4 +122,4 @@ def format_toml_string(data: Any, multiline_threshold: int = 1) -> str:
formatted = _format_toml_value(data, multiline_threshold) if multiline_threshold >= 0 else data
output = tomlkit.dumps(formatted)
# 规范化:将 3+ 连续空行压缩为 1 个空行,防止空行累积
return re.sub(r'\n{3,}', '\n\n', output)
return re.sub(r"\n{3,}", "\n\n", output)

View File

@@ -778,9 +778,9 @@ class DreamConfig(ConfigBase):
"""
if not self.dream_time_ranges:
return True
now_min = self._now_minutes()
for time_range in self.dream_time_ranges:
if not isinstance(time_range, str):
continue
@@ -790,7 +790,7 @@ class DreamConfig(ConfigBase):
start_min, end_min = parsed
if self._in_range(now_min, start_min, end_min):
return True
return False
def __post_init__(self):
@@ -800,4 +800,4 @@ class DreamConfig(ConfigBase):
if self.max_iterations < 1:
raise ValueError(f"max_iterations 必须至少为1当前值: {self.max_iterations}")
if self.first_delay_seconds < 0:
raise ValueError(f"first_delay_seconds 不能为负数,当前值: {self.first_delay_seconds}")
raise ValueError(f"first_delay_seconds 不能为负数,当前值: {self.first_delay_seconds}")

View File

@@ -1,14 +1,13 @@
import asyncio
import random
import time
import json
from typing import Any, Dict, List, Optional, Tuple
from peewee import fn
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.common.database.database_model import ChatHistory, Jargon
from src.common.database.database_model import ChatHistory
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.llm_models.payload_content.message import MessageBuilder, RoleType, Message
from src.plugin_system.apis import llm_api
@@ -82,7 +81,6 @@ def init_dream_prompts() -> None:
)
class DreamTool:
"""dream 模块内部使用的简易工具封装"""
@@ -150,7 +148,13 @@ def init_dream_tools(chat_id: str) -> None:
"search_chat_history",
"根据关键词或参与人查询当前 chat_id 下的 ChatHistory 概览,便于快速定位相关记忆。",
[
("keyword", ToolParamType.STRING, "关键词(可选,支持多个关键词,可用空格、逗号等分隔)。", False, None),
(
"keyword",
ToolParamType.STRING,
"关键词(可选,支持多个关键词,可用空格、逗号等分隔)。",
False,
None,
),
("participant", ToolParamType.STRING, "参与人昵称(可选)。", False, None),
],
search_chat_history,
@@ -201,8 +205,20 @@ def init_dream_tools(chat_id: str) -> None:
[
("theme", ToolParamType.STRING, "新的主题标题(必填)。", True, None),
("summary", ToolParamType.STRING, "新的概括内容(必填)。", True, None),
("keywords", ToolParamType.STRING, "新的关键词 JSON 字符串,如 ['关键词1','关键词2'](必填)。", True, None),
("key_point", ToolParamType.STRING, "新的关键信息 JSON 字符串,如 ['要点1','要点2'](必填)。", True, None),
(
"keywords",
ToolParamType.STRING,
"新的关键词 JSON 字符串,如 ['关键词1','关键词2'](必填)。",
True,
None,
),
(
"key_point",
ToolParamType.STRING,
"新的关键信息 JSON 字符串,如 ['要点1','要点2'](必填)。",
True,
None,
),
("start_time", ToolParamType.STRING, "起始时间戳Unix 时间,必填)。", True, None),
("end_time", ToolParamType.STRING, "结束时间戳Unix 时间,必填)。", True, None),
],
@@ -215,7 +231,13 @@ def init_dream_tools(chat_id: str) -> None:
"finish_maintenance",
"结束本次 dream 维护任务。当你认为当前 chat_id 下的维护工作已经完成,没有更多需要整理、合并或修改的内容时,调用此工具来主动结束本次运行。",
[
("reason", ToolParamType.STRING, "结束维护的原因说明(可选),例如 '已完成所有记录的整理''当前记录质量良好,无需进一步维护'", False, None),
(
"reason",
ToolParamType.STRING,
"结束维护的原因说明(可选),例如 '已完成所有记录的整理''当前记录质量良好,无需进一步维护'",
False,
None,
),
],
finish_maintenance,
)
@@ -246,7 +268,7 @@ async def run_dream_agent_once(
"""
if max_iterations is None:
max_iterations = global_config.dream.max_iterations
start_ts = time.time()
logger.info(f"[dream] 开始对 chat_id={chat_id} 进行 dream 维护,最多迭代 {max_iterations}")
@@ -282,9 +304,7 @@ async def run_dream_agent_once(
else "未知"
)
end_time_str = (
time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(record.end_time))
if record.end_time
else "未知"
time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(record.end_time)) if record.end_time else "未知"
)
detail_text = (
f"ID={record.id}\n"
@@ -305,8 +325,7 @@ async def run_dream_agent_once(
start_detail_builder = MessageBuilder()
start_detail_builder.set_role(RoleType.User)
start_detail_builder.add_text_content(
"【起始记忆详情】以下是本轮随机/指定的起始记忆的详细信息,供你在整理时优先参考:\n\n"
+ detail_text
"【起始记忆详情】以下是本轮随机/指定的起始记忆的详细信息,供你在整理时优先参考:\n\n" + detail_text
)
conversation_messages.append(start_detail_builder.build())
else:
@@ -343,13 +362,17 @@ async def run_dream_agent_once(
conversation_messages.append(round_info_builder.build())
# 调用 LLM 让其决定是否要使用工具
success, response, reasoning_content, model_name, tool_calls = (
await llm_api.generate_with_model_with_tools_by_message_factory(
message_factory,
model_config=model_config.model_task_config.tool_use,
tool_options=tool_defs,
request_type="dream.react",
)
(
success,
response,
reasoning_content,
model_name,
tool_calls,
) = await llm_api.generate_with_model_with_tools_by_message_factory(
message_factory,
model_config=model_config.model_task_config.tool_use,
tool_options=tool_defs,
request_type="dream.react",
)
if not success:
@@ -522,7 +545,7 @@ async def start_dream_scheduler(
if interval_seconds is None:
interval_seconds = global_config.dream.interval_minutes * 60
logger.info(
f"[dream] dream 调度器启动:首次延迟 {first_delay_seconds}s之后每隔 {interval_seconds}s ({interval_seconds // 60} 分钟) 运行一次 dream agent"
)
@@ -555,4 +578,3 @@ async def start_dream_scheduler(
# 初始化提示词
init_dream_prompts()

View File

@@ -86,7 +86,7 @@ async def generate_dream_summary(
try:
import json
from src.chat.utils.prompt_builder import global_prompt_manager
# 第一步:建立工具调用结果映射 (call_id -> result)
tool_results_map: dict[str, str] = {}
for msg in conversation_messages:
@@ -98,11 +98,11 @@ async def generate_dream_summary(
else:
content = str(msg.content)
tool_results_map[msg.tool_call_id] = content
# 第二步:详细记录所有工具调用操作和结果到日志
tool_call_count = 0
logger.info(f"[dream][工具调用详情] 开始记录 chat_id={chat_id} 的所有工具调用操作:")
for msg in conversation_messages:
if msg.role == RoleType.Assistant and msg.tool_calls:
tool_call_count += 1
@@ -110,34 +110,38 @@ async def generate_dream_summary(
thought_content = ""
if msg.content:
if isinstance(msg.content, list) and msg.content:
thought_content = msg.content[0].text if hasattr(msg.content[0], "text") else str(msg.content[0])
thought_content = (
msg.content[0].text if hasattr(msg.content[0], "text") else str(msg.content[0])
)
else:
thought_content = str(msg.content)
logger.info(f"[dream][工具调用详情] === 第 {tool_call_count} 组工具调用 ===")
if thought_content:
logger.info(f"[dream][工具调用详情] 思考内容:{thought_content[:500]}{'...' if len(thought_content) > 500 else ''}")
logger.info(
f"[dream][工具调用详情] 思考内容:{thought_content[:500]}{'...' if len(thought_content) > 500 else ''}"
)
# 记录每个工具调用的详细信息
for idx, tool_call in enumerate(msg.tool_calls, 1):
tool_name = tool_call.func_name
tool_args = tool_call.args or {}
tool_call_id = tool_call.call_id
tool_result = tool_results_map.get(tool_call_id, "未找到执行结果")
# 格式化参数
try:
args_str = json.dumps(tool_args, ensure_ascii=False, indent=2) if tool_args else "无参数"
except Exception:
args_str = str(tool_args)
logger.info(f"[dream][工具调用详情] --- 工具 {idx}: {tool_name} ---")
logger.info(f"[dream][工具调用详情] 调用参数:\n{args_str}")
logger.info(f"[dream][工具调用详情] 执行结果:\n{tool_result}")
logger.info(f"[dream][工具调用详情] {'-' * 60}")
logger.info(f"[dream][工具调用详情] 共记录了 {tool_call_count} 组工具调用操作")
# 第三步:构建对话历史摘要(用于生成梦境)
conversation_summary = []
for msg in conversation_messages:
@@ -145,11 +149,11 @@ async def generate_dream_summary(
content = ""
if msg.content:
content = msg.content[0].text if isinstance(msg.content, list) and msg.content else str(msg.content)
if role == "user" and "轮次信息" in content:
# 跳过轮次信息消息
continue
if role == "assistant":
# 只保留思考内容,简化工具调用信息
if content:
@@ -162,13 +166,13 @@ async def generate_dream_summary(
# 截取前300字符
content_preview = content[:300] + ("..." if len(content) > 300 else "")
conversation_summary.append(f"[工具执行] {content_preview}")
conversation_text = "\n".join(conversation_summary[-20:]) # 只保留最后20条消息
# 随机选择2个梦境风格
selected_styles = get_random_dream_styles(2)
dream_styles_text = "\n".join([f"{i+1}. {style}" for i, style in enumerate(selected_styles)])
dream_styles_text = "\n".join([f"{i + 1}. {style}" for i, style in enumerate(selected_styles)])
# 使用 Prompt 管理器格式化梦境生成 prompt
dream_prompt = await global_prompt_manager.format_prompt(
"dream_summary_prompt",
@@ -186,13 +190,14 @@ async def generate_dream_summary(
max_tokens=512,
temperature=0.8,
)
if dream_content:
logger.info(f"[dream][梦境总结] 对 chat_id={chat_id} 的整理过程梦境:\n{dream_content}")
else:
logger.warning("[dream][梦境总结] 未能生成梦境总结")
except Exception as e:
logger.error(f"[dream][梦境总结] 生成梦境总结失败: {e}", exc_info=True)
init_dream_summary_prompt()
init_dream_summary_prompt()

View File

@@ -4,8 +4,3 @@ dream agent 工具实现模块。
每个工具的具体实现放在独立文件中,通过 make_xxx(chat_id) 工厂函数
生成绑定到特定 chat_id 的协程函数,由 dream_agent.init_dream_tools 统一注册。
"""

View File

@@ -60,8 +60,3 @@ def make_create_chat_history(chat_id: str):
return f"create_chat_history 执行失败: {e}"
return create_chat_history

View File

@@ -23,8 +23,3 @@ def make_delete_chat_history(chat_id: str): # chat_id 目前未直接使用,
return f"delete_chat_history 执行失败: {e}"
return delete_chat_history

View File

@@ -23,8 +23,3 @@ def make_delete_jargon(chat_id: str): # chat_id 目前未直接使用,预留
return f"delete_jargon 执行失败: {e}"
return delete_jargon

View File

@@ -14,8 +14,3 @@ def make_finish_maintenance(chat_id: str): # chat_id 目前未直接使用,
return msg
return finish_maintenance

View File

@@ -1,5 +1,4 @@
import time
from typing import Optional
from src.common.logger import get_logger
from src.common.database.database_model import ChatHistory
@@ -20,14 +19,10 @@ def make_get_chat_history_detail(chat_id: str): # chat_id 目前未直接使用
# 将时间戳转换为可读时间格式
start_time_str = (
time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(record.start_time))
if record.start_time
else "未知"
time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(record.start_time)) if record.start_time else "未知"
)
end_time_str = (
time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(record.end_time))
if record.end_time
else "未知"
time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(record.end_time)) if record.end_time else "未知"
)
result = (
@@ -40,17 +35,10 @@ def make_get_chat_history_detail(chat_id: str): # chat_id 目前未直接使用
f"概括={record.summary or ''}\n"
f"关键信息={record.key_point or ''}"
)
logger.debug(
f"[dream][tool] get_chat_history_detail 成功,预览: {result[:200].replace(chr(10), ' ')}"
)
logger.debug(f"[dream][tool] get_chat_history_detail 成功,预览: {result[:200].replace(chr(10), ' ')}")
return result
except Exception as e:
logger.error(f"get_chat_history_detail 失败: {e}")
return f"get_chat_history_detail 执行失败: {e}"
return get_chat_history_detail

View File

@@ -78,9 +78,7 @@ def make_search_chat_history(chat_id: str):
if record.keywords:
try:
keywords_data = (
json.loads(record.keywords)
if isinstance(record.keywords, str)
else record.keywords
json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
)
if isinstance(keywords_data, list):
record_keywords_list = [str(k).lower() for k in keywords_data]
@@ -125,9 +123,7 @@ def make_search_chat_history(chat_id: str):
keywords_str = "".join(keywords_list)
if len(keywords_list) > 2:
required_count = len(keywords_list) - 1
return (
f"未找到包含至少{required_count}个关键词(共{len(keywords_list)}个)'{keywords_str}'的聊天记录"
)
return f"未找到包含至少{required_count}个关键词(共{len(keywords_list)}个)'{keywords_str}'的聊天记录"
else:
return f"未找到包含所有关键词'{keywords_str}'的聊天记录"
elif participant:
@@ -142,9 +138,7 @@ def make_search_chat_history(chat_id: str):
if record.keywords:
try:
keywords_data = (
json.loads(record.keywords)
if isinstance(record.keywords, str)
else record.keywords
json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
)
if isinstance(keywords_data, list):
for k in keywords_data:
@@ -160,13 +154,13 @@ def make_search_chat_history(chat_id: str):
keywords_str = "".join(sorted(all_keywords_set))
response_text = (
f"包含“{search_label}”的结果过多,请尝试更多关键词精确查找\n\n"
f"有关\"{search_label}\"的关键词:\n"
f'有关"{search_label}"的关键词:\n'
f"{keywords_str}"
)
else:
response_text = (
f"包含“{search_label}”的结果过多,请尝试更多关键词精确查找\n\n"
f"有关\"{search_label}\"的关键词信息为空"
f'有关"{search_label}"的关键词信息为空'
)
logger.info(
@@ -192,9 +186,7 @@ def make_search_chat_history(chat_id: str):
if record.keywords:
try:
keywords_data = (
json.loads(record.keywords)
if isinstance(record.keywords, str)
else record.keywords
json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
)
if isinstance(keywords_data, list) and keywords_data:
keywords_str = "".join([str(k) for k in keywords_data])
@@ -220,8 +212,3 @@ def make_search_chat_history(chat_id: str):
return f"search_chat_history 执行失败: {e}"
return search_chat_history

View File

@@ -16,9 +16,7 @@ def make_search_jargon(chat_id: str):
if not keyword or not keyword.strip():
return "未指定查询关键词(参数 keyword 为必填,且不能为空)"
logger.info(
f"[dream][tool] 调用 search_jargon(keyword={keyword}) (作用域 chat_id={chat_id})"
)
logger.info(f"[dream][tool] 调用 search_jargon(keyword={keyword}) (作用域 chat_id={chat_id})")
# 基础条件:只查 is_jargon=True 的记录
query = Jargon.select().where(Jargon.is_jargon)
@@ -102,5 +100,3 @@ def make_search_jargon(chat_id: str):
return f"search_jargon 执行失败: {e}"
return search_jargon

View File

@@ -49,8 +49,3 @@ def make_update_chat_history(chat_id: str): # chat_id 目前未直接使用,
return f"update_chat_history 执行失败: {e}"
return update_chat_history

View File

@@ -49,8 +49,3 @@ def make_update_jargon(chat_id: str): # chat_id 目前未直接使用,预留
return f"update_jargon 执行失败: {e}"
return update_jargon

View File

@@ -316,7 +316,9 @@ class ChatHistorySummarizer:
before_count = len(self.current_batch.messages)
self.current_batch.messages.extend(new_messages)
self.current_batch.end_time = current_time
logger.info(f"{self.log_prefix} 更新聊天检查批次: {before_count} -> {len(self.current_batch.messages)} 条消息")
logger.info(
f"{self.log_prefix} 更新聊天检查批次: {before_count} -> {len(self.current_batch.messages)} 条消息"
)
# 更新批次后持久化
self._persist_topic_cache()
else:
@@ -362,9 +364,7 @@ class ChatHistorySummarizer:
else:
time_str = f"{time_since_last_check / 3600:.1f}小时"
logger.debug(
f"{self.log_prefix} 批次状态检查 | 消息数: {message_count} | 距上次检查: {time_str}"
)
logger.debug(f"{self.log_prefix} 批次状态检查 | 消息数: {message_count} | 距上次检查: {time_str}")
# 检查“话题检查”触发条件
should_check = False
@@ -414,7 +414,7 @@ class ChatHistorySummarizer:
# 说明 bot 没有参与这段对话,不应该记录
bot_user_id = str(global_config.bot.qq_account)
has_bot_message = False
for msg in messages:
if msg.user_info.user_id == bot_user_id:
has_bot_message = True
@@ -427,7 +427,9 @@ class ChatHistorySummarizer:
return
# 2. 构造编号后的消息字符串和参与者信息
numbered_lines, index_to_msg_str, index_to_msg_text, index_to_participants = self._build_numbered_messages_for_llm(messages)
numbered_lines, index_to_msg_str, index_to_msg_text, index_to_participants = (
self._build_numbered_messages_for_llm(messages)
)
# 3. 调用 LLM 识别话题,并得到 topic -> indices失败时最多重试 3 次)
existing_topics = list(self.topic_cache.keys())
@@ -456,9 +458,7 @@ class ChatHistorySummarizer:
)
if not success or not topic_to_indices:
logger.error(
f"{self.log_prefix} 话题识别连续 {max_retries} 次失败或始终无有效话题,本次检查放弃"
)
logger.error(f"{self.log_prefix} 话题识别连续 {max_retries} 次失败或始终无有效话题,本次检查放弃")
# 即使识别失败,也认为是一次“检查”,但不更新 no_update_checks保持原状
return
@@ -610,9 +610,7 @@ class ChatHistorySummarizer:
if not numbered_lines:
return False, {}
history_topics_block = (
"\n".join(f"- {t}" for t in existing_topics) if existing_topics else "(当前无历史话题)"
)
history_topics_block = "\n".join(f"- {t}" for t in existing_topics) if existing_topics else "(当前无历史话题)"
messages_block = "\n".join(numbered_lines)
prompt = await global_prompt_manager.format_prompt(
@@ -635,17 +633,17 @@ class ChatHistorySummarizer:
json_str = None
json_pattern = r"```json\s*(.*?)\s*```"
matches = re.findall(json_pattern, response, re.DOTALL)
if matches:
# 找到JSON代码块使用第一个匹配
json_str = matches[0].strip()
else:
# 如果没有找到代码块尝试查找JSON数组的开始和结束位置
# 查找第一个 [ 和最后一个 ]
start_idx = response.find('[')
end_idx = response.rfind(']')
start_idx = response.find("[")
end_idx = response.rfind("]")
if start_idx != -1 and end_idx != -1 and end_idx > start_idx:
json_str = response[start_idx:end_idx + 1].strip()
json_str = response[start_idx : end_idx + 1].strip()
else:
# 如果还是找不到尝试直接使用整个响应移除可能的markdown标记
json_str = response.strip()
@@ -942,4 +940,3 @@ class ChatHistorySummarizer:
init_prompt()

View File

@@ -49,7 +49,7 @@ class LLMRequest:
def _check_slow_request(self, time_cost: float, model_name: str) -> None:
"""检查请求是否过慢并输出警告日志
Args:
time_cost: 请求耗时(秒)
model_name: 使用的模型名称
@@ -323,7 +323,7 @@ class LLMRequest:
effective_temperature = (model_info.extra_params or {}).get("temperature")
if effective_temperature is None:
effective_temperature = self.model_for_task.temperature
# max_tokens 优先级:参数传入 > 模型级别配置 > extra_params > 任务配置
effective_max_tokens = max_tokens
if effective_max_tokens is None:
@@ -332,7 +332,7 @@ class LLMRequest:
effective_max_tokens = (model_info.extra_params or {}).get("max_tokens")
if effective_max_tokens is None:
effective_max_tokens = self.model_for_task.max_tokens
return await client.get_response(
model_info=model_info,
message_list=(compressed_messages or message_list),
@@ -366,7 +366,9 @@ class LLMRequest:
logger.error(f"模型 '{model_info.name}' 在多次出现空回复后仍然失败。{original_error_info}")
raise ModelAttemptFailed(f"模型 '{model_info.name}' 重试耗尽", original_exception=e) from e
logger.warning(f"模型 '{model_info.name}' 返回空回复(可重试){original_error_info}。剩余重试次数: {retry_remain}")
logger.warning(
f"模型 '{model_info.name}' 返回空回复(可重试){original_error_info}。剩余重试次数: {retry_remain}"
)
await asyncio.sleep(api_provider.retry_interval)
except NetworkConnectionError as e:
@@ -394,7 +396,9 @@ class LLMRequest:
if e.status_code == 429 or e.status_code >= 500:
retry_remain -= 1
if retry_remain <= 0:
logger.error(f"模型 '{model_info.name}' 在遇到 {e.status_code} 错误并用尽重试次数后仍然失败。{original_error_info}")
logger.error(
f"模型 '{model_info.name}' 在遇到 {e.status_code} 错误并用尽重试次数后仍然失败。{original_error_info}"
)
raise ModelAttemptFailed(f"模型 '{model_info.name}' 重试耗尽", original_exception=e) from e
logger.warning(
@@ -540,7 +544,5 @@ class LLMRequest:
if e.__cause__:
original_error_type = type(e.__cause__).__name__
original_error_msg = str(e.__cause__)
return (
f"\n 底层异常类型: {original_error_type}\n 底层异常信息: {original_error_msg}"
)
return f"\n 底层异常类型: {original_error_type}\n 底层异常信息: {original_error_msg}"
return ""

View File

@@ -113,7 +113,6 @@ class MainSystem:
get_emoji_manager().initialize()
logger.info("表情包管理器初始化成功")
# 初始化聊天管理器
await get_chat_manager()._initialize()
asyncio.create_task(get_chat_manager()._auto_save_task())

View File

@@ -136,8 +136,6 @@ def init_memory_retrieval_prompt():
)
def _log_conversation_messages(
conversation_messages: List[Message],
head_prompt: Optional[str] = None,
@@ -172,7 +170,9 @@ def _log_conversation_messages(
# 构建单条消息的日志信息
# msg_info = f"\n========================================\n[消息 {idx}] 角色: {role_name} 内容类型: {content_type}\n-----------------------------"
msg_info = f"\n========================================\n[消息 {idx}] 角色: {role_name}\n-----------------------------"
msg_info = (
f"\n========================================\n[消息 {idx}] 角色: {role_name}\n-----------------------------"
)
# if full_content:
# msg_info += f"\n{full_content}"
@@ -185,8 +185,7 @@ def _log_conversation_messages(
msg_info += f"\n - {tool_call.func_name}: {json.dumps(tool_call.args, ensure_ascii=False)}"
# if msg.tool_call_id:
# msg_info += f"\n 工具调用ID: {msg.tool_call_id}"
# msg_info += f"\n 工具调用ID: {msg.tool_call_id}"
log_lines.append(msg_info)
@@ -330,7 +329,7 @@ async def _react_agent_solve_question(
remaining_iterations=remaining_iterations,
max_iterations=max_iterations,
)
# 后续迭代都复用第一次构建的head_prompt
head_prompt = first_head_prompt
@@ -365,7 +364,7 @@ async def _react_agent_solve_question(
)
# logger.info(
# f"ReAct Agent 第 {iteration + 1} 次迭代 模型: {model_name} ,调用工具数量: {len(tool_calls) if tool_calls else 0} ,调用工具响应: {response}"
# f"ReAct Agent 第 {iteration + 1} 次迭代 模型: {model_name} ,调用工具数量: {len(tool_calls) if tool_calls else 0} ,调用工具响应: {response}"
# )
if not success:
@@ -409,20 +408,20 @@ async def _react_agent_solve_question(
"""从文本中解析finish_search函数调用返回(found_answer, answer)元组,如果未找到则返回(None, None)"""
if not text:
return None, None
# 查找finish_search函数调用位置不区分大小写
func_pattern = "finish_search"
text_lower = text.lower()
func_pos = text_lower.find(func_pattern)
if func_pos == -1:
return None, None
# 查找函数调用的开始和结束位置
# 从func_pos开始向后查找左括号
start_pos = text.find("(", func_pos)
if start_pos == -1:
return None, None
# 查找匹配的右括号(考虑嵌套)
paren_count = 0
end_pos = start_pos
@@ -437,10 +436,10 @@ async def _react_agent_solve_question(
else:
# 没有找到匹配的右括号
return None, None
# 提取函数参数部分
params_text = text[start_pos + 1 : end_pos]
# 解析found_answer参数布尔值可能是true/false/True/False
found_answer = None
found_answer_patterns = [
@@ -454,49 +453,60 @@ async def _react_agent_solve_question(
if match:
found_answer = "true" in match.group(0).lower()
break
# 解析answer参数字符串使用extract_quoted_content
answer = extract_quoted_content(text, "finish_search", "answer")
return found_answer, answer
parsed_found_answer, parsed_answer = parse_finish_search_from_text(response)
if parsed_found_answer is not None:
# 检测到finish_search函数调用格式
if parsed_found_answer:
# 找到了答案
if parsed_answer:
step["actions"].append({"action_type": "finish_search", "action_params": {"found_answer": True, "answer": parsed_answer}})
step["actions"].append(
{
"action_type": "finish_search",
"action_params": {"found_answer": True, "answer": parsed_answer},
}
)
step["observations"] = ["检测到finish_search文本格式调用找到答案"]
thinking_steps.append(step)
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 通过finish_search文本格式找到关于问题{question}的答案: {parsed_answer}")
logger.info(
f"ReAct Agent 第 {iteration + 1} 次迭代 通过finish_search文本格式找到关于问题{question}的答案: {parsed_answer}"
)
_log_conversation_messages(
conversation_messages,
head_prompt=first_head_prompt,
final_status=f"找到答案:{parsed_answer}",
)
return True, parsed_answer, thinking_steps, False
else:
# found_answer为True但没有提供answer视为错误继续迭代
logger.warning(f"ReAct Agent 第 {iteration + 1} 次迭代 finish_search文本格式found_answer为True但未提供answer")
logger.warning(
f"ReAct Agent 第 {iteration + 1} 次迭代 finish_search文本格式found_answer为True但未提供answer"
)
else:
# 未找到答案,直接退出查询
step["actions"].append({"action_type": "finish_search", "action_params": {"found_answer": False}})
step["actions"].append(
{"action_type": "finish_search", "action_params": {"found_answer": False}}
)
step["observations"] = ["检测到finish_search文本格式调用未找到答案"]
thinking_steps.append(step)
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 通过finish_search文本格式判断未找到答案")
_log_conversation_messages(
conversation_messages,
head_prompt=first_head_prompt,
final_status="未找到答案通过finish_search文本格式判断未找到答案",
)
return False, "", thinking_steps, False
# 如果没有检测到finish_search格式记录思考过程继续下一轮迭代
step["observations"] = [f"思考完成,但未调用工具。响应: {response}"]
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 思考完成但未调用工具: {response}")
@@ -514,44 +524,53 @@ async def _react_agent_solve_question(
for tool_call in tool_calls:
tool_name = tool_call.func_name
tool_args = tool_call.args or {}
if tool_name == "finish_search":
finish_search_found = tool_args.get("found_answer", False)
finish_search_answer = tool_args.get("answer", "")
if finish_search_found:
# 找到了答案
if finish_search_answer:
step["actions"].append({"action_type": "finish_search", "action_params": {"found_answer": True, "answer": finish_search_answer}})
step["actions"].append(
{
"action_type": "finish_search",
"action_params": {"found_answer": True, "answer": finish_search_answer},
}
)
step["observations"] = ["检测到finish_search工具调用找到答案"]
thinking_steps.append(step)
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 通过finish_search工具找到关于问题{question}的答案: {finish_search_answer}")
logger.info(
f"ReAct Agent 第 {iteration + 1} 次迭代 通过finish_search工具找到关于问题{question}的答案: {finish_search_answer}"
)
_log_conversation_messages(
conversation_messages,
head_prompt=first_head_prompt,
final_status=f"找到答案:{finish_search_answer}",
)
return True, finish_search_answer, thinking_steps, False
else:
# found_answer为True但没有提供answer视为错误
logger.warning(f"ReAct Agent 第 {iteration + 1} 次迭代 finish_search工具found_answer为True但未提供answer")
logger.warning(
f"ReAct Agent 第 {iteration + 1} 次迭代 finish_search工具found_answer为True但未提供answer"
)
else:
# 未找到答案,直接退出查询
step["actions"].append({"action_type": "finish_search", "action_params": {"found_answer": False}})
step["observations"] = ["检测到finish_search工具调用未找到答案"]
thinking_steps.append(step)
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 通过finish_search工具判断未找到答案")
_log_conversation_messages(
conversation_messages,
head_prompt=first_head_prompt,
final_status="未找到答案通过finish_search工具判断未找到答案",
)
return False, "", thinking_steps, False
# 如果没有finish_search工具调用继续处理其他工具
tool_tasks = []
for i, tool_call in enumerate(tool_calls):
@@ -627,7 +646,7 @@ async def _react_agent_solve_question(
observation_text += f"\n\n{jargon_info}"
collected_info += f"\n{jargon_info}\n"
logger.info(f"工具输出触发黑话解析: {new_concepts}")
tool_builder = MessageBuilder()
tool_builder.set_role(RoleType.Tool)
tool_builder.add_text_content(observation_text)
@@ -645,7 +664,7 @@ async def _react_agent_solve_question(
elif iteration + 1 >= max_iterations:
should_do_final_evaluation = True
logger.info(f"ReAct Agent达到最大迭代次数已迭代{iteration + 1}次),进入最终评估")
if should_do_final_evaluation:
# 获取必要变量用于最终评估
tool_registry = get_tool_registry()
@@ -653,7 +672,7 @@ async def _react_agent_solve_question(
time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
current_iteration = iteration + 1
remaining_iterations = 0
# 提取函数调用中参数的值,支持单引号和双引号
def extract_quoted_content(text, func_name, param_name):
"""从文本中提取函数调用中参数的值,支持单引号和双引号
@@ -724,7 +743,13 @@ async def _react_agent_solve_question(
max_iterations=max_iterations,
)
eval_success, eval_response, eval_reasoning_content, eval_model_name, eval_tool_calls = await llm_api.generate_with_model_with_tools(
(
eval_success,
eval_response,
eval_reasoning_content,
eval_model_name,
eval_tool_calls,
) = await llm_api.generate_with_model_with_tools(
evaluation_prompt,
model_config=model_config.model_task_config.tool_use,
tool_options=[], # 最终评估阶段不提供工具
@@ -739,7 +764,7 @@ async def _react_agent_solve_question(
final_status="未找到答案最终评估阶段LLM调用失败",
)
return False, "最终评估阶段LLM调用失败", thinking_steps, is_timeout
if global_config.debug.show_memory_prompt:
logger.info(f"ReAct Agent 最终评估Prompt: {evaluation_prompt}")
logger.info(f"ReAct Agent 最终评估响应: {eval_response}")
@@ -759,17 +784,17 @@ async def _react_agent_solve_question(
"iteration": current_iteration,
"thought": f"[最终评估] {eval_response}",
"actions": [{"action_type": "found_answer", "action_params": {"answer": found_answer_content}}],
"observations": ["最终评估阶段检测到found_answer"]
"observations": ["最终评估阶段检测到found_answer"],
}
thinking_steps.append(eval_step)
logger.info(f"ReAct Agent 最终评估阶段找到关于问题{question}的答案: {found_answer_content}")
_log_conversation_messages(
conversation_messages,
head_prompt=first_head_prompt,
final_status=f"找到答案:{found_answer_content}",
)
return True, found_answer_content, thinking_steps, False
# 如果评估为not_enough_info返回空字符串不返回任何信息
@@ -778,35 +803,37 @@ async def _react_agent_solve_question(
"iteration": current_iteration,
"thought": f"[最终评估] {eval_response}",
"actions": [{"action_type": "not_enough_info", "action_params": {"reason": not_enough_info_reason}}],
"observations": ["最终评估阶段检测到not_enough_info"]
"observations": ["最终评估阶段检测到not_enough_info"],
}
thinking_steps.append(eval_step)
logger.info(f"ReAct Agent 最终评估阶段判断信息不足: {not_enough_info_reason}")
_log_conversation_messages(
conversation_messages,
head_prompt=first_head_prompt,
final_status=f"未找到答案:{not_enough_info_reason}",
)
return False, "", thinking_steps, is_timeout
# 如果没有明确判断视为not_enough_info返回空字符串不返回任何信息
eval_step = {
"iteration": current_iteration,
"thought": f"[最终评估] {eval_response}",
"actions": [{"action_type": "not_enough_info", "action_params": {"reason": "已到达最大迭代次数,无法找到答案"}}],
"observations": ["已到达最大迭代次数,无法找到答案"]
"actions": [
{"action_type": "not_enough_info", "action_params": {"reason": "已到达最大迭代次数,无法找到答案"}}
],
"observations": ["已到达最大迭代次数,无法找到答案"],
}
thinking_steps.append(eval_step)
logger.info("ReAct Agent 已到达最大迭代次数,无法找到答案")
_log_conversation_messages(
conversation_messages,
head_prompt=first_head_prompt,
final_status="未找到答案:已到达最大迭代次数,无法找到答案",
)
return False, "", thinking_steps, is_timeout
# 如果正常迭代过程中提前找到答案返回,不会到达这里
@@ -817,7 +844,7 @@ async def _react_agent_solve_question(
head_prompt=first_head_prompt,
final_status="未找到答案:正常迭代结束",
)
return False, "", thinking_steps, is_timeout
@@ -1129,7 +1156,9 @@ async def build_memory_retrieval_prompt(
else:
max_iterations = base_max_iterations
timeout_seconds = global_config.memory.agent_timeout_seconds
logger.debug(f"问题数量: {len(questions)}think_level={think_level},设置最大迭代次数: {max_iterations}(基础值: {base_max_iterations}),超时时间: {timeout_seconds}")
logger.debug(
f"问题数量: {len(questions)}think_level={think_level},设置最大迭代次数: {max_iterations}(基础值: {base_max_iterations}),超时时间: {timeout_seconds}"
)
# 并行处理所有问题,将概念检索结果作为初始信息传递
question_tasks = [
@@ -1157,10 +1186,10 @@ async def build_memory_retrieval_prompt(
# 获取最近10分钟内已找到答案的缓存记录
cached_answers = _get_recent_found_answers(chat_id, time_window_seconds=600.0)
# 合并当前查询结果和缓存答案(去重:如果当前查询的问题在缓存中已存在,优先使用当前结果)
all_results = []
# 先添加当前查询的结果
current_questions = set()
for result in question_results:
@@ -1170,7 +1199,7 @@ async def build_memory_retrieval_prompt(
if question_end != -1:
current_questions.add(result[4:question_end])
all_results.append(result)
# 添加缓存答案(排除当前查询中已存在的问题)
for cached_answer in cached_answers:
if cached_answer.startswith("问题:"):
@@ -1198,4 +1227,3 @@ async def build_memory_retrieval_prompt(
except Exception as e:
logger.error(f"记忆检索时发生异常: {str(e)}")
return ""

View File

@@ -17,7 +17,6 @@ from src.common.logger import get_logger
logger = get_logger("memory_utils")
def parse_questions_json(response: str) -> Tuple[List[str], List[str]]:
"""解析问题JSON返回概念列表和问题列表
@@ -68,6 +67,7 @@ def parse_questions_json(response: str) -> Tuple[List[str], List[str]]:
logger.error(f"解析问题JSON失败: {e}, 响应内容: {response[:200]}...")
return [], []
def parse_datetime_to_timestamp(value: str) -> float:
"""
接受多种常见格式并转换为时间戳(秒)

View File

@@ -47,4 +47,3 @@ def register_tool():
],
execute_func=finish_search,
)

View File

@@ -16,9 +16,7 @@ from .tool_registry import register_memory_retrieval_tool
logger = get_logger("memory_retrieval_tools")
async def search_chat_history(
chat_id: str, keyword: Optional[str] = None, participant: Optional[str] = None
) -> str:
async def search_chat_history(chat_id: str, keyword: Optional[str] = None, participant: Optional[str] = None) -> str:
"""根据关键词或参与人查询记忆返回匹配的记忆id、记忆标题theme和关键词keywords
Args:
@@ -117,7 +115,7 @@ async def search_chat_history(
)
if kw_matched:
matched_count += 1
# 计算需要匹配的关键词数量
total_keywords = len(keywords_lower)
if total_keywords > 2:
@@ -126,7 +124,7 @@ async def search_chat_history(
else:
# 关键词数量<=2必须全部匹配
required_matches = total_keywords
keyword_matched = matched_count >= required_matches
# 两者都匹配如果同时有participant和keyword需要两者都匹配如果只有一个条件只需要该条件匹配
@@ -144,7 +142,9 @@ async def search_chat_history(
keywords_list = parse_keywords_string(keyword)
if len(keywords_list) > 2:
required_count = len(keywords_list) - 1
return f"未找到包含至少{required_count}个关键词(共{len(keywords_list)}个)'{keywords_str}'的聊天记录"
return (
f"未找到包含至少{required_count}个关键词(共{len(keywords_list)}个)'{keywords_str}'的聊天记录"
)
else:
return f"未找到包含所有关键词'{keywords_str}'的聊天记录"
elif participant:
@@ -160,9 +160,7 @@ async def search_chat_history(
if record.keywords:
try:
keywords_data = (
json.loads(record.keywords)
if isinstance(record.keywords, str)
else record.keywords
json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
)
if isinstance(keywords_data, list):
for k in keywords_data:
@@ -179,13 +177,12 @@ async def search_chat_history(
keywords_str = "".join(sorted(all_keywords_set))
return (
f"包含“{search_label}”的结果过多,请尝试更多关键词精确查找\n\n"
f"有关\"{search_label}\"的关键词:\n"
f'有关"{search_label}"的关键词:\n'
f"{keywords_str}"
)
else:
return (
f"包含“{search_label}”的结果过多,请尝试更多关键词精确查找\n\n"
f"有关\"{search_label}\"的关键词信息为空"
f'包含“{search_label}”的结果过多,请尝试更多关键词精确查找\n\n有关"{search_label}"的关键词信息为空'
)
# 构建结果文本返回id、theme和keywords最多20条

View File

@@ -22,42 +22,42 @@ def get_current_token(
) -> str:
"""
获取当前请求的 token优先从 Cookie 获取,其次从 Header 获取
Args:
request: FastAPI Request 对象
maibot_session: Cookie 中的 token
authorization: Authorization Header (Bearer token)
Returns:
验证通过的 token
Raises:
HTTPException: 认证失败时抛出 401 错误
"""
token = None
# 优先从 Cookie 获取
if maibot_session:
token = maibot_session
# 其次从 Header 获取(兼容旧版本)
elif authorization and authorization.startswith("Bearer "):
token = authorization.replace("Bearer ", "")
if not token:
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
# 验证 token
token_manager = get_token_manager()
if not token_manager.verify_token(token):
raise HTTPException(status_code=401, detail="Token 无效或已过期")
return token
def set_auth_cookie(response: Response, token: str) -> None:
"""
设置认证 Cookie
Args:
response: FastAPI Response 对象
token: 要设置的 token
@@ -77,7 +77,7 @@ def set_auth_cookie(response: Response, token: str) -> None:
def clear_auth_cookie(response: Response) -> None:
"""
清除认证 Cookie
Args:
response: FastAPI Response 对象
"""
@@ -96,32 +96,32 @@ def verify_auth_token_from_cookie_or_header(
) -> bool:
"""
验证认证 Token支持从 Cookie 或 Header 获取
Args:
maibot_session: Cookie 中的 token
authorization: Authorization header (Bearer token)
Returns:
验证成功返回 True
Raises:
HTTPException: 认证失败时抛出 401 错误
"""
token = None
# 优先从 Cookie 获取
if maibot_session:
token = maibot_session
# 其次从 Header 获取(兼容旧版本)
elif authorization and authorization.startswith("Bearer "):
token = authorization.replace("Bearer ", "")
if not token:
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
# 验证 token
token_manager = get_token_manager()
if not token_manager.verify_token(token):
raise HTTPException(status_code=401, detail="Token 无效或已过期")
return True

View File

@@ -63,14 +63,14 @@ class ChatHistoryManager:
def _message_to_dict(self, msg: Messages, group_id: Optional[str] = None) -> Dict[str, Any]:
"""将数据库消息转换为前端格式
Args:
msg: 数据库消息对象
group_id: 群 ID用于判断是否是虚拟群
"""
# 判断是否是机器人消息
user_id = msg.user_id or ""
# 对于虚拟群,通过比较机器人 QQ 账号来判断
# 对于普通 WebUI 群,检查 user_id 是否以 webui_ 开头
if group_id and group_id.startswith(VIRTUAL_GROUP_ID_PREFIX):
@@ -414,7 +414,9 @@ async def websocket_chat(
group_id=virtual_group_id,
group_name=group_name or "WebUI虚拟群聊",
)
logger.info(f"虚拟身份模式已通过 URL 参数激活: {current_virtual_config.user_nickname} @ {current_virtual_config.platform}, group_id={virtual_group_id}")
logger.info(
f"虚拟身份模式已通过 URL 参数激活: {current_virtual_config.user_nickname} @ {current_virtual_config.platform}, group_id={virtual_group_id}"
)
except Exception as e:
logger.warning(f"通过 URL 参数配置虚拟身份失败: {e}")

View File

@@ -1,4 +1,4 @@
""" 表情包管理 API 路由"""
"""表情包管理 API 路由"""
from fastapi import APIRouter, HTTPException, Header, Query, UploadFile, File, Form, Cookie
from fastapi.responses import FileResponse, JSONResponse
@@ -48,7 +48,7 @@ def _get_thumbnail_lock(file_hash: str) -> threading.Lock:
def _background_generate_thumbnail(source_path: str, file_hash: str) -> None:
"""
后台生成缩略图(在线程池中执行)
生成完成后自动从 generating 集合中移除
"""
try:
@@ -74,14 +74,14 @@ def _get_thumbnail_cache_path(file_hash: str) -> Path:
def _generate_thumbnail(source_path: str, file_hash: str) -> Path:
"""
生成缩略图并保存到缓存目录
Args:
source_path: 原图路径
file_hash: 文件哈希值,用作缓存文件名
Returns:
缩略图路径
Features:
- GIF: 提取第一帧作为缩略图
- 所有格式统一转为 WebP
@@ -89,63 +89,63 @@ def _generate_thumbnail(source_path: str, file_hash: str) -> Path:
"""
_ensure_thumbnail_cache_dir()
cache_path = _get_thumbnail_cache_path(file_hash)
# 使用锁防止并发生成同一缩略图
lock = _get_thumbnail_lock(file_hash)
with lock:
# 双重检查,可能在等待锁时已被其他线程生成
if cache_path.exists():
return cache_path
try:
with Image.open(source_path) as img:
# GIF 处理:提取第一帧
if hasattr(img, 'n_frames') and img.n_frames > 1:
if hasattr(img, "n_frames") and img.n_frames > 1:
img.seek(0) # 确保在第一帧
# 转换为 RGB/RGBAWebP 支持透明度)
if img.mode in ('P', 'PA'):
if img.mode in ("P", "PA"):
# 调色板模式转换为 RGBA 以保留透明度
img = img.convert('RGBA')
elif img.mode == 'LA':
img = img.convert('RGBA')
elif img.mode not in ('RGB', 'RGBA'):
img = img.convert('RGB')
img = img.convert("RGBA")
elif img.mode == "LA":
img = img.convert("RGBA")
elif img.mode not in ("RGB", "RGBA"):
img = img.convert("RGB")
# 创建缩略图(保持宽高比)
img.thumbnail(THUMBNAIL_SIZE, Image.Resampling.LANCZOS)
# 保存为 WebP 格式
img.save(cache_path, 'WEBP', quality=THUMBNAIL_QUALITY, method=6)
img.save(cache_path, "WEBP", quality=THUMBNAIL_QUALITY, method=6)
logger.debug(f"生成缩略图: {file_hash} -> {cache_path}")
except Exception as e:
logger.warning(f"生成缩略图失败 {file_hash}: {e},将返回原图")
# 生成失败时不创建缓存文件,下次会重试
raise
return cache_path
def cleanup_orphaned_thumbnails() -> tuple[int, int]:
"""
清理孤立的缩略图缓存(原图已不存在的缩略图)
Returns:
(清理数量, 保留数量)
"""
if not THUMBNAIL_CACHE_DIR.exists():
return 0, 0
# 获取所有表情包的哈希值
valid_hashes = set()
for emoji in Emoji.select(Emoji.emoji_hash):
valid_hashes.add(emoji.emoji_hash)
cleaned = 0
kept = 0
for cache_file in THUMBNAIL_CACHE_DIR.glob("*.webp"):
file_hash = cache_file.stem
if file_hash not in valid_hashes:
@@ -157,12 +157,13 @@ def cleanup_orphaned_thumbnails() -> tuple[int, int]:
logger.warning(f"清理缩略图失败 {cache_file.name}: {e}")
else:
kept += 1
if cleaned > 0:
logger.info(f"清理孤立缩略图: 删除 {cleaned} 个,保留 {kept}")
return cleaned, kept
# 模块级别的类型别名(解决 B008 ruff 错误)
EmojiFile = Annotated[UploadFile, File(description="表情包图片文件")]
EmojiFiles = Annotated[List[UploadFile], File(description="多个表情包图片文件")]
@@ -365,7 +366,9 @@ async def get_emoji_list(
@router.get("/{emoji_id}", response_model=EmojiDetailResponse)
async def get_emoji_detail(emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
async def get_emoji_detail(
emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
):
"""
获取表情包详细信息
@@ -394,7 +397,12 @@ async def get_emoji_detail(emoji_id: int, maibot_session: Optional[str] = Cookie
@router.patch("/{emoji_id}", response_model=EmojiUpdateResponse)
async def update_emoji(emoji_id: int, request: EmojiUpdateRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
async def update_emoji(
emoji_id: int,
request: EmojiUpdateRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
"""
增量更新表情包(只更新提供的字段)
@@ -446,7 +454,9 @@ async def update_emoji(emoji_id: int, request: EmojiUpdateRequest, maibot_sessio
@router.delete("/{emoji_id}", response_model=EmojiDeleteResponse)
async def delete_emoji(emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
async def delete_emoji(
emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
):
"""
删除表情包
@@ -538,7 +548,9 @@ async def get_emoji_stats(maibot_session: Optional[str] = Cookie(None), authoriz
@router.post("/{emoji_id}/register", response_model=EmojiUpdateResponse)
async def register_emoji(emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
async def register_emoji(
emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
):
"""
注册表情包(快捷操作)
@@ -578,7 +590,9 @@ async def register_emoji(emoji_id: int, maibot_session: Optional[str] = Cookie(N
@router.post("/{emoji_id}/ban", response_model=EmojiUpdateResponse)
async def ban_emoji(emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
async def ban_emoji(
emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
):
"""
禁用表情包(快捷操作)
@@ -633,7 +647,7 @@ async def get_emoji_thumbnail(
Returns:
表情包缩略图WebP 格式)或原图
Features:
- 懒加载:首次请求时生成缩略图
- 缓存:后续请求直接返回缓存
@@ -643,7 +657,7 @@ async def get_emoji_thumbnail(
try:
token_manager = get_token_manager()
is_valid = False
# 1. 优先使用 Cookie
if maibot_session and token_manager.verify_token(maibot_session):
is_valid = True
@@ -655,7 +669,7 @@ async def get_emoji_thumbnail(
auth_token = authorization.replace("Bearer ", "")
if token_manager.verify_token(auth_token):
is_valid = True
if not is_valid:
raise HTTPException(status_code=401, detail="Token 无效或已过期")
@@ -680,35 +694,27 @@ async def get_emoji_thumbnail(
}
media_type = mime_types.get(emoji.format.lower(), "application/octet-stream")
return FileResponse(
path=emoji.full_path,
media_type=media_type,
filename=f"{emoji.emoji_hash}.{emoji.format}"
path=emoji.full_path, media_type=media_type, filename=f"{emoji.emoji_hash}.{emoji.format}"
)
# 尝试获取或生成缩略图
cache_path = _get_thumbnail_cache_path(emoji.emoji_hash)
# 检查缓存是否存在
if cache_path.exists():
# 缓存命中,直接返回
return FileResponse(
path=str(cache_path),
media_type="image/webp",
filename=f"{emoji.emoji_hash}_thumb.webp"
path=str(cache_path), media_type="image/webp", filename=f"{emoji.emoji_hash}_thumb.webp"
)
# 缓存未命中,触发后台生成并返回 202
with _generating_lock:
if emoji.emoji_hash not in _generating_thumbnails:
# 标记为正在生成
_generating_thumbnails.add(emoji.emoji_hash)
# 提交到线程池后台生成
_thumbnail_executor.submit(
_background_generate_thumbnail,
emoji.full_path,
emoji.emoji_hash
)
_thumbnail_executor.submit(_background_generate_thumbnail, emoji.full_path, emoji.emoji_hash)
# 返回 202 Accepted告诉前端缩略图正在生成中
return JSONResponse(
status_code=202,
@@ -719,7 +725,7 @@ async def get_emoji_thumbnail(
},
headers={
"Retry-After": "1", # 建议 1 秒后重试
}
},
)
except HTTPException:
@@ -730,7 +736,11 @@ async def get_emoji_thumbnail(
@router.post("/batch/delete", response_model=BatchDeleteResponse)
async def batch_delete_emojis(request: BatchDeleteRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
async def batch_delete_emojis(
request: BatchDeleteRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
"""
批量删除表情包
@@ -1079,7 +1089,7 @@ async def batch_upload_emoji(
class ThumbnailCacheStatsResponse(BaseModel):
"""缩略图缓存统计响应"""
success: bool
cache_dir: str
total_count: int
@@ -1090,7 +1100,7 @@ class ThumbnailCacheStatsResponse(BaseModel):
class ThumbnailCleanupResponse(BaseModel):
"""缩略图清理响应"""
success: bool
message: str
cleaned_count: int
@@ -1099,7 +1109,7 @@ class ThumbnailCleanupResponse(BaseModel):
class ThumbnailPreheatResponse(BaseModel):
"""缩略图预热响应"""
success: bool
message: str
generated_count: int
@@ -1114,27 +1124,27 @@ async def get_thumbnail_cache_stats(
):
"""
获取缩略图缓存统计信息
Returns:
缓存目录、缓存数量、总大小、覆盖率等统计信息
"""
try:
verify_auth_token(maibot_session, authorization)
_ensure_thumbnail_cache_dir()
# 统计缓存文件
cache_files = list(THUMBNAIL_CACHE_DIR.glob("*.webp"))
total_count = len(cache_files)
total_size = sum(f.stat().st_size for f in cache_files)
total_size_mb = round(total_size / (1024 * 1024), 2)
# 统计表情包总数
emoji_count = Emoji.select().count()
# 计算覆盖率
coverage_percent = round((total_count / emoji_count * 100) if emoji_count > 0 else 0, 1)
return ThumbnailCacheStatsResponse(
success=True,
cache_dir=str(THUMBNAIL_CACHE_DIR.absolute()),
@@ -1143,7 +1153,7 @@ async def get_thumbnail_cache_stats(
emoji_count=emoji_count,
coverage_percent=coverage_percent,
)
except HTTPException:
raise
except Exception as e:
@@ -1158,22 +1168,22 @@ async def cleanup_thumbnail_cache(
):
"""
清理孤立的缩略图缓存(原图已删除的表情包对应的缩略图)
Returns:
清理结果
"""
try:
verify_auth_token(maibot_session, authorization)
cleaned, kept = cleanup_orphaned_thumbnails()
return ThumbnailCleanupResponse(
success=True,
message=f"清理完成:删除 {cleaned} 个孤立缓存,保留 {kept} 个有效缓存",
cleaned_count=cleaned,
kept_count=kept,
)
except HTTPException:
raise
except Exception as e:
@@ -1189,20 +1199,20 @@ async def preheat_thumbnail_cache(
):
"""
预热缩略图缓存(提前生成未缓存的缩略图)
优先处理使用次数高的表情包
Args:
limit: 最多预热数量 (1-1000)
Returns:
预热结果
"""
try:
verify_auth_token(maibot_session, authorization)
_ensure_thumbnail_cache_dir()
# 获取使用次数最高的表情包(未缓存的优先)
emojis = (
Emoji.select()
@@ -1210,41 +1220,36 @@ async def preheat_thumbnail_cache(
.order_by(Emoji.usage_count.desc())
.limit(limit * 2) # 多查一些,因为有些可能已缓存
)
generated = 0
skipped = 0
failed = 0
for emoji in emojis:
if generated >= limit:
break
cache_path = _get_thumbnail_cache_path(emoji.emoji_hash)
# 已缓存,跳过
if cache_path.exists():
skipped += 1
continue
# 原文件不存在,跳过
if not os.path.exists(emoji.full_path):
failed += 1
continue
try:
# 使用线程池异步生成缩略图,避免阻塞事件循环
loop = asyncio.get_event_loop()
await loop.run_in_executor(
_thumbnail_executor,
_generate_thumbnail,
emoji.full_path,
emoji.emoji_hash
)
await loop.run_in_executor(_thumbnail_executor, _generate_thumbnail, emoji.full_path, emoji.emoji_hash)
generated += 1
except Exception as e:
logger.warning(f"预热缩略图失败 {emoji.emoji_hash}: {e}")
failed += 1
return ThumbnailPreheatResponse(
success=True,
message=f"预热完成:生成 {generated} 个,跳过 {skipped} 个已缓存,失败 {failed}",
@@ -1252,7 +1257,7 @@ async def preheat_thumbnail_cache(
skipped_count=skipped,
failed_count=failed,
)
except HTTPException:
raise
except Exception as e:
@@ -1267,13 +1272,13 @@ async def clear_all_thumbnail_cache(
):
"""
清空所有缩略图缓存(下次访问时会重新生成)
Returns:
清理结果
"""
try:
verify_auth_token(maibot_session, authorization)
if not THUMBNAIL_CACHE_DIR.exists():
return ThumbnailCleanupResponse(
success=True,
@@ -1281,7 +1286,7 @@ async def clear_all_thumbnail_cache(
cleaned_count=0,
kept_count=0,
)
cleaned = 0
for cache_file in THUMBNAIL_CACHE_DIR.glob("*.webp"):
try:
@@ -1289,16 +1294,16 @@ async def clear_all_thumbnail_cache(
cleaned += 1
except Exception as e:
logger.warning(f"删除缓存文件失败 {cache_file.name}: {e}")
logger.info(f"已清空缩略图缓存: 删除 {cleaned} 个文件")
return ThumbnailCleanupResponse(
success=True,
message=f"已清空所有缩略图缓存:删除 {cleaned} 个文件",
cleaned_count=cleaned,
kept_count=0,
)
except HTTPException:
raise
except Exception as e:

View File

@@ -256,7 +256,9 @@ async def get_expression_list(
@router.get("/{expression_id}", response_model=ExpressionDetailResponse)
async def get_expression_detail(expression_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
async def get_expression_detail(
expression_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
):
"""
获取表达方式详细信息
@@ -285,7 +287,11 @@ async def get_expression_detail(expression_id: int, maibot_session: Optional[str
@router.post("/", response_model=ExpressionCreateResponse)
async def create_expression(request: ExpressionCreateRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
async def create_expression(
request: ExpressionCreateRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
"""
创建新的表达方式
@@ -326,7 +332,10 @@ async def create_expression(request: ExpressionCreateRequest, maibot_session: Op
@router.patch("/{expression_id}", response_model=ExpressionUpdateResponse)
async def update_expression(
expression_id: int, request: ExpressionUpdateRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
expression_id: int,
request: ExpressionUpdateRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
"""
增量更新表达方式(只更新提供的字段)
@@ -376,7 +385,9 @@ async def update_expression(
@router.delete("/{expression_id}", response_model=ExpressionDeleteResponse)
async def delete_expression(expression_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
async def delete_expression(
expression_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
):
"""
删除表达方式
@@ -419,7 +430,11 @@ class BatchDeleteRequest(BaseModel):
@router.post("/batch/delete", response_model=ExpressionDeleteResponse)
async def batch_delete_expressions(request: BatchDeleteRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
async def batch_delete_expressions(
request: BatchDeleteRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
"""
批量删除表达方式
@@ -460,7 +475,9 @@ async def batch_delete_expressions(request: BatchDeleteRequest, maibot_session:
@router.get("/stats/summary")
async def get_expression_stats(maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
async def get_expression_stats(
maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
):
"""
获取表达方式统计数据

View File

@@ -24,7 +24,7 @@ def parse_chat_id_to_stream_ids(chat_id_str: str) -> List[str]:
"""
if not chat_id_str:
return []
try:
# 尝试解析为 JSON
parsed = json.loads(chat_id_str)
@@ -49,10 +49,10 @@ def get_display_name_for_chat_id(chat_id_str: str) -> str:
尝试解析 JSON 并查询 ChatStreams 表获取群聊名称
"""
stream_ids = parse_chat_id_to_stream_ids(chat_id_str)
if not stream_ids:
return chat_id_str
# 查询所有 stream_id 对应的名称
names = []
for stream_id in stream_ids:
@@ -62,7 +62,7 @@ def get_display_name_for_chat_id(chat_id_str: str) -> str:
else:
# 如果没找到,显示截断的 stream_id
names.append(stream_id[:8] + "..." if len(stream_id) > 8 else stream_id)
return ", ".join(names) if names else chat_id_str
@@ -187,7 +187,7 @@ def jargon_to_dict(jargon: Jargon) -> dict:
chat_name = get_display_name_for_chat_id(jargon.chat_id) if jargon.chat_id else None
stream_ids = parse_chat_id_to_stream_ids(jargon.chat_id) if jargon.chat_id else []
stream_id = stream_ids[0] if stream_ids else None
return {
"id": jargon.id,
"content": jargon.content,
@@ -277,17 +277,13 @@ async def get_chat_list():
"""获取所有有黑话记录的聊天列表"""
try:
# 获取所有不同的 chat_id
chat_ids = (
Jargon.select(Jargon.chat_id)
.distinct()
.where(Jargon.chat_id.is_null(False))
)
chat_ids = Jargon.select(Jargon.chat_id).distinct().where(Jargon.chat_id.is_null(False))
chat_id_list = [j.chat_id for j in chat_ids if j.chat_id]
# 用于按 stream_id 去重
seen_stream_ids: set[str] = set()
for chat_id in chat_id_list:
stream_ids = parse_chat_id_to_stream_ids(chat_id)
if stream_ids:
@@ -346,12 +342,7 @@ async def get_jargon_stats():
complete_count = Jargon.select().where(Jargon.is_complete).count()
# 关联的聊天数量
chat_count = (
Jargon.select(Jargon.chat_id)
.distinct()
.where(Jargon.chat_id.is_null(False))
.count()
)
chat_count = Jargon.select(Jargon.chat_id).distinct().where(Jargon.chat_id.is_null(False)).count()
# 按聊天统计 TOP 5
top_chats = (
@@ -403,9 +394,7 @@ async def create_jargon(request: JargonCreateRequest):
"""创建黑话"""
try:
# 检查是否已存在相同内容的黑话
existing = Jargon.get_or_none(
(Jargon.content == request.content) & (Jargon.chat_id == request.chat_id)
)
existing = Jargon.get_or_none((Jargon.content == request.content) & (Jargon.chat_id == request.chat_id))
if existing:
raise HTTPException(status_code=400, detail="该聊天中已存在相同内容的黑话")
@@ -527,11 +516,7 @@ async def batch_set_jargon_status(
if not ids:
raise HTTPException(status_code=400, detail="ID列表不能为空")
updated_count = (
Jargon.update(is_jargon=is_jargon)
.where(Jargon.id.in_(ids))
.execute()
)
updated_count = Jargon.update(is_jargon=is_jargon).where(Jargon.id.in_(ids)).execute()
logger.info(f"批量更新黑话状态成功: 更新了 {updated_count} 条记录is_jargon={is_jargon}")

View File

@@ -200,7 +200,9 @@ async def get_person_list(
@router.get("/{person_id}", response_model=PersonDetailResponse)
async def get_person_detail(person_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
async def get_person_detail(
person_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
):
"""
获取人物详细信息
@@ -229,7 +231,12 @@ async def get_person_detail(person_id: str, maibot_session: Optional[str] = Cook
@router.patch("/{person_id}", response_model=PersonUpdateResponse)
async def update_person(person_id: str, request: PersonUpdateRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
async def update_person(
person_id: str,
request: PersonUpdateRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
"""
增量更新人物信息(只更新提供的字段)
@@ -278,7 +285,9 @@ async def update_person(person_id: str, request: PersonUpdateRequest, maibot_ses
@router.delete("/{person_id}", response_model=PersonDeleteResponse)
async def delete_person(person_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
async def delete_person(
person_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
):
"""
删除人物信息
@@ -348,7 +357,11 @@ async def get_person_stats(maibot_session: Optional[str] = Cookie(None), authori
@router.post("/batch/delete", response_model=BatchDeleteResponse)
async def batch_delete_persons(request: BatchDeleteRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
async def batch_delete_persons(
request: BatchDeleteRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
"""
批量删除人物信息

View File

@@ -125,6 +125,7 @@ def coerce_types(schema_part: Dict[str, Any], config_part: Dict[str, Any]) -> No
"""
根据 schema 将配置中的类型纠正(目前只纠正 list-from-str
"""
def _is_list_type(tp: Any) -> bool:
origin = get_origin(tp)
return tp is list or origin is list
@@ -313,7 +314,9 @@ async def check_git_status() -> GitStatusResponse:
@router.get("/mirrors", response_model=AvailableMirrorsResponse)
async def get_available_mirrors(maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> AvailableMirrorsResponse:
async def get_available_mirrors(
maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
) -> AvailableMirrorsResponse:
"""
获取所有可用的镜像源配置
"""
@@ -343,7 +346,9 @@ async def get_available_mirrors(maibot_session: Optional[str] = Cookie(None), au
@router.post("/mirrors", response_model=MirrorConfigResponse)
async def add_mirror(request: AddMirrorRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> MirrorConfigResponse:
async def add_mirror(
request: AddMirrorRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
) -> MirrorConfigResponse:
"""
添加新的镜像源
"""
@@ -383,7 +388,10 @@ async def add_mirror(request: AddMirrorRequest, maibot_session: Optional[str] =
@router.put("/mirrors/{mirror_id}", response_model=MirrorConfigResponse)
async def update_mirror(
mirror_id: str, request: UpdateMirrorRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
mirror_id: str,
request: UpdateMirrorRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> MirrorConfigResponse:
"""
更新镜像源配置
@@ -426,7 +434,9 @@ async def update_mirror(
@router.delete("/mirrors/{mirror_id}")
async def delete_mirror(mirror_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
async def delete_mirror(
mirror_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
) -> Dict[str, Any]:
"""
删除镜像源
"""
@@ -449,7 +459,9 @@ async def delete_mirror(mirror_id: str, maibot_session: Optional[str] = Cookie(N
@router.post("/fetch-raw", response_model=FetchRawFileResponse)
async def fetch_raw_file(
request: FetchRawFileRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
request: FetchRawFileRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> FetchRawFileResponse:
"""
获取 GitHub 仓库的 Raw 文件内容
@@ -534,7 +546,9 @@ async def fetch_raw_file(
@router.post("/clone", response_model=CloneRepositoryResponse)
async def clone_repository(
request: CloneRepositoryRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
request: CloneRepositoryRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> CloneRepositoryResponse:
"""
克隆 GitHub 仓库到本地
@@ -574,7 +588,11 @@ async def clone_repository(
@router.post("/install")
async def install_plugin(request: InstallPluginRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
async def install_plugin(
request: InstallPluginRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> Dict[str, Any]:
"""
安装插件
@@ -778,7 +796,9 @@ async def install_plugin(request: InstallPluginRequest, maibot_session: Optional
@router.post("/uninstall")
async def uninstall_plugin(
request: UninstallPluginRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
request: UninstallPluginRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> Dict[str, Any]:
"""
卸载插件
@@ -913,7 +933,11 @@ async def uninstall_plugin(
@router.post("/update")
async def update_plugin(request: UpdatePluginRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
async def update_plugin(
request: UpdatePluginRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> Dict[str, Any]:
"""
更新插件
@@ -1132,7 +1156,9 @@ async def update_plugin(request: UpdatePluginRequest, maibot_session: Optional[s
@router.get("/installed")
async def get_installed_plugins(maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
async def get_installed_plugins(
maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
) -> Dict[str, Any]:
"""
获取已安装的插件列表
@@ -1272,7 +1298,9 @@ class UpdatePluginConfigRequest(BaseModel):
@router.get("/config/{plugin_id}/schema")
async def get_plugin_config_schema(plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
async def get_plugin_config_schema(
plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
) -> Dict[str, Any]:
"""
获取插件配置 Schema
@@ -1405,7 +1433,9 @@ async def get_plugin_config_schema(plugin_id: str, maibot_session: Optional[str]
@router.get("/config/{plugin_id}")
async def get_plugin_config(plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
async def get_plugin_config(
plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
) -> Dict[str, Any]:
"""
获取插件当前配置值
@@ -1461,7 +1491,10 @@ async def get_plugin_config(plugin_id: str, maibot_session: Optional[str] = Cook
@router.put("/config/{plugin_id}")
async def update_plugin_config(
plugin_id: str, request: UpdatePluginConfigRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
plugin_id: str,
request: UpdatePluginConfigRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> Dict[str, Any]:
"""
更新插件配置
@@ -1532,7 +1565,9 @@ async def update_plugin_config(
@router.post("/config/{plugin_id}/reset")
async def reset_plugin_config(plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
async def reset_plugin_config(
plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
) -> Dict[str, Any]:
"""
重置插件配置为默认值
@@ -1592,7 +1627,9 @@ async def reset_plugin_config(plugin_id: str, maibot_session: Optional[str] = Co
@router.post("/config/{plugin_id}/toggle")
async def toggle_plugin(plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
async def toggle_plugin(
plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
) -> Dict[str, Any]:
"""
切换插件启用状态

View File

@@ -139,10 +139,10 @@ async def verify_token(request: TokenVerifyRequest, response: Response):
async def logout(response: Response):
"""
登出并清除认证 Cookie
Args:
response: FastAPI Response 对象
Returns:
登出结果
"""
@@ -158,23 +158,23 @@ async def check_auth_status(
):
"""
检查当前认证状态(用于前端判断是否已登录)
Returns:
认证状态
"""
try:
token = None
# 优先从 Cookie 获取
if maibot_session:
token = maibot_session
# 其次从 Header 获取
elif authorization and authorization.startswith("Bearer "):
token = authorization.replace("Bearer ", "")
if not token:
return {"authenticated": False}
token_manager = get_token_manager()
if token_manager.verify_token(token):
return {"authenticated": True}
@@ -211,7 +211,7 @@ async def update_token(
current_token = maibot_session
elif authorization and authorization.startswith("Bearer "):
current_token = authorization.replace("Bearer ", "")
if not current_token:
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
@@ -222,7 +222,7 @@ async def update_token(
# 更新 token
success, message = token_manager.update_token(request.new_token)
# 如果更新成功,清除 Cookie要求用户重新登录
if success:
clear_auth_cookie(response)
@@ -263,7 +263,7 @@ async def regenerate_token(
if not current_token:
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
token_manager = get_token_manager()
if not token_manager.verify_token(current_token):
@@ -271,7 +271,7 @@ async def regenerate_token(
# 重新生成 token
new_token = token_manager.regenerate_token()
# 清除 Cookie要求用户重新登录
clear_auth_cookie(response)
@@ -306,7 +306,7 @@ async def get_setup_status(
current_token = maibot_session
elif authorization and authorization.startswith("Bearer "):
current_token = authorization.replace("Bearer ", "")
if not current_token:
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
@@ -349,7 +349,7 @@ async def complete_setup(
current_token = maibot_session
elif authorization and authorization.startswith("Bearer "):
current_token = authorization.replace("Bearer ", "")
if not current_token:
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
@@ -392,7 +392,7 @@ async def reset_setup(
current_token = maibot_session
elif authorization and authorization.startswith("Bearer "):
current_token = authorization.replace("Bearer ", "")
if not current_token:
raise HTTPException(status_code=401, detail="未提供有效的认证信息")

View File

@@ -166,22 +166,22 @@ class TokenManager:
str: 新生成的 token
"""
logger.info("正在重新生成 WebUI Token...")
# 生成新的 64 位十六进制字符串
new_token = secrets.token_hex(32)
# 加载现有配置,保留 first_setup_completed 状态
config = self._load_config()
old_token = config.get("access_token", "")[:8] if config.get("access_token") else ""
first_setup_completed = config.get("first_setup_completed", True) # 默认为 True表示已完成配置
config["access_token"] = new_token
config["updated_at"] = self._get_current_timestamp()
config["first_setup_completed"] = first_setup_completed # 保留原来的状态
self._save_config(config)
logger.info(f"WebUI Token 已重新生成: {old_token}... -> {new_token[:8]}...")
return new_token
def _validate_token_format(self, token: str) -> bool:

View File

@@ -110,8 +110,10 @@ class WebUIServer:
from src.webui.routes import router as webui_router
from src.webui.logs_ws import router as logs_router
from src.webui.knowledge_routes import router as knowledge_router
# 导入本地聊天室路由
from src.webui.chat_routes import router as chat_router
# 注册路由
self.app.include_router(webui_router)
self.app.include_router(logs_router)
@@ -166,6 +168,7 @@ class WebUIServer:
def _check_port_available(self) -> bool:
"""检查端口是否可用"""
import socket
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.settimeout(1)