This commit is contained in:
SengokuCola
2025-12-15 00:07:31 +08:00
85 changed files with 3462 additions and 1901 deletions

View File

@@ -3,7 +3,7 @@ import json
import os
import re
import asyncio
from typing import List, Optional, Tuple, Any, Dict, Callable
from typing import List, Optional, Tuple, Any, Dict
from src.common.logger import get_logger
from src.common.database.database_model import Expression
from src.llm_models.utils_model import LLMRequest
@@ -13,7 +13,12 @@ from src.chat.utils.chat_message_builder import (
)
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.message_receive.chat_stream import get_chat_manager
from src.bw_learner.learner_utils import filter_message_content, is_bot_message, build_context_paragraph, contains_bot_self_name
from src.bw_learner.learner_utils import (
filter_message_content,
is_bot_message,
build_context_paragraph,
contains_bot_self_name,
)
from src.bw_learner.jargon_miner import miner_manager
from json_repair import repair_json
@@ -77,8 +82,6 @@ def init_prompt() -> None:
Prompt(learn_style_prompt, "learn_style_prompt")
class ExpressionLearner:
def __init__(self, chat_id: str) -> None:
self.express_learn_model: LLMRequest = LLMRequest(
@@ -95,20 +98,20 @@ class ExpressionLearner:
self._learning_lock = asyncio.Lock()
async def learn_and_store(
self,
self,
messages: List[Any],
person_name_filter: Optional[Callable[[str], bool]] = None,
) -> List[Tuple[str, str, str]]:
"""
学习并存储表达方式
Args:
messages: 外部传入的消息列表(必需)
person_name_filter: 可选的过滤函数,用于检查内容是否包含人物名称
num: 学习数量
timestamp_start: 学习开始的时间戳如果为None则使用self.last_learning_time
"""
if not messages:
return None
random_msg = messages
# 学习用(开启行编号,便于溯源)
@@ -134,37 +137,26 @@ class ExpressionLearner:
jargon_entries: List[Tuple[str, str]] # (content, source_id)
expressions, jargon_entries = self.parse_expression_response(response)
expressions = self._filter_self_reference_styles(expressions)
# 过滤掉包含人物名称的表达方式
if person_name_filter:
filtered_expressions = []
for situation, style, source_id in expressions:
# 检查 situation 和 style 是否包含人物名称
if person_name_filter(situation) or person_name_filter(style):
logger.info(f"跳过包含人物名称的表达方式: situation={situation}, style={style}")
continue
filtered_expressions.append((situation, style, source_id))
expressions = filtered_expressions
# 检查表达方式数量如果超过10个则放弃本次表达学习
if len(expressions) > 10:
logger.info(f"表达方式提取数量超过10个实际{len(expressions)}个),放弃本次表达学习")
expressions = []
# 检查黑话数量如果超过30个则放弃本次黑话学习
if len(jargon_entries) > 30:
logger.info(f"黑话提取数量超过30个实际{len(jargon_entries)}个),放弃本次黑话学习")
jargon_entries = []
# 处理黑话条目,路由到 jargon_miner即使没有表达方式也要处理黑话
if jargon_entries:
await self._process_jargon_entries(jargon_entries, random_msg, person_name_filter)
await self._process_jargon_entries(jargon_entries, random_msg)
# 如果没有表达方式,直接返回
if not expressions:
logger.info("过滤后没有可用的表达方式style 与机器人名称重复)")
return []
logger.info(f"学习的prompt: {prompt}")
logger.info(f"学习的expressions: {expressions}")
logger.info(f"学习的jargon_entries: {jargon_entries}")
@@ -186,18 +178,17 @@ class ExpressionLearner:
# 当前行的原始内容
current_msg = random_msg[line_index]
# 过滤掉从bot自己发言中提取到的表达方式
if is_bot_message(current_msg):
continue
context = filter_message_content(current_msg.processed_plain_text or "")
if not context:
continue
filtered_expressions.append((situation, style, context))
learnt_expressions = filtered_expressions
if learnt_expressions is None:
@@ -281,37 +272,38 @@ class ExpressionLearner:
# 如果解析失败,尝试修复中文引号问题
# 使用状态机方法,在 JSON 字符串值内部将中文引号替换为转义的英文引号
try:
def fix_chinese_quotes_in_json(text):
"""使用状态机修复 JSON 字符串值中的中文引号"""
result = []
i = 0
in_string = False
escape_next = False
while i < len(text):
char = text[i]
if escape_next:
# 当前字符是转义字符后的字符,直接添加
result.append(char)
escape_next = False
i += 1
continue
if char == '\\':
if char == "\\":
# 转义字符
result.append(char)
escape_next = True
i += 1
continue
if char == '"' and not escape_next:
# 遇到英文引号,切换字符串状态
in_string = not in_string
result.append(char)
i += 1
continue
if in_string:
# 在字符串值内部,将中文引号替换为转义的英文引号
if char == '"': # 中文左引号 U+201C
@@ -323,13 +315,13 @@ class ExpressionLearner:
else:
# 不在字符串内,直接添加
result.append(char)
i += 1
return ''.join(result)
return "".join(result)
fixed_raw = fix_chinese_quotes_in_json(raw)
# 再次尝试解析
if fixed_raw.startswith("[") and fixed_raw.endswith("]"):
parsed = json.loads(fixed_raw)
@@ -357,12 +349,12 @@ class ExpressionLearner:
for item in parsed_list:
if not isinstance(item, dict):
continue
# 检查是否是表达方式条目(有 situation 和 style
situation = str(item.get("situation", "")).strip()
style = str(item.get("style", "")).strip()
source_id = str(item.get("source_id", "")).strip()
if situation and style and source_id:
# 表达方式条目
expressions.append((situation, style, source_id))
@@ -511,75 +503,64 @@ class ExpressionLearner:
logger.error(f"概括表达情境失败: {e}")
return None
async def _process_jargon_entries(
self,
jargon_entries: List[Tuple[str, str]],
messages: List[Any],
person_name_filter: Optional[Callable[[str], bool]] = None
) -> None:
async def _process_jargon_entries(self, jargon_entries: List[Tuple[str, str]], messages: List[Any]) -> None:
"""
处理从 expression learner 提取的黑话条目,路由到 jargon_miner
Args:
jargon_entries: 黑话条目列表,每个元素是 (content, source_id)
messages: 消息列表,用于构建上下文
person_name_filter: 可选的过滤函数,用于检查内容是否包含人物名称
"""
if not jargon_entries or not messages:
return
# 获取 jargon_miner 实例
jargon_miner = miner_manager.get_miner(self.chat_id)
# 构建黑话条目格式,与 jargon_miner.run_once 中的格式一致
entries: List[Dict[str, List[str]]] = []
for content, source_id in jargon_entries:
content = content.strip()
if not content:
continue
# 检查是否包含机器人名称
if contains_bot_self_name(content):
logger.info(f"跳过包含机器人昵称/别名的黑话: {content}")
continue
# 检查是否包含人物名称
if person_name_filter and person_name_filter(content):
logger.info(f"跳过包含人物名称的黑话: {content}")
continue
# 解析 source_id
source_id_str = (source_id or "").strip()
if not source_id_str.isdigit():
logger.warning(f"黑话条目 source_id 无效: content={content}, source_id={source_id_str}")
continue
# build_anonymous_messages 的编号从 1 开始
line_index = int(source_id_str) - 1
if line_index < 0 or line_index >= len(messages):
logger.warning(f"黑话条目 source_id 超出范围: content={content}, source_id={source_id_str}")
continue
# 检查是否是机器人自己的消息
target_msg = messages[line_index]
if is_bot_message(target_msg):
logger.info(f"跳过引用机器人自身消息的黑话: content={content}, source_id={source_id_str}")
continue
# 构建上下文段落
context_paragraph = build_context_paragraph(messages, line_index)
if not context_paragraph:
logger.warning(f"黑话条目上下文为空: content={content}, source_id={source_id_str}")
continue
entries.append({"content": content, "raw_content": [context_paragraph]})
if not entries:
return
# 调用 jargon_miner 处理这些条目
await jargon_miner.process_extracted_entries(entries, person_name_filter)
await jargon_miner.process_extracted_entries(entries)
init_prompt()

View File

@@ -82,9 +82,7 @@ class ExpressionReflector:
# 获取未检查的表达
try:
logger.info("[Expression Reflection] 查询未检查且未拒绝的表达")
expressions = (
Expression.select().where((~Expression.checked) & (~Expression.rejected)).limit(50)
)
expressions = Expression.select().where((~Expression.checked) & (~Expression.rejected)).limit(50)
expr_list = list(expressions)
logger.info(f"[Expression Reflection] 找到 {len(expr_list)} 个候选表达")
@@ -147,7 +145,7 @@ expression_reflector_manager = ExpressionReflectorManager()
async def _check_tracker_exists(operator_config: str) -> bool:
"""检查指定 Operator 是否已有活跃的 Tracker"""
from src.express.reflect_tracker import reflect_tracker_manager
from src.bw_learner.reflect_tracker import reflect_tracker_manager
chat_manager = get_chat_manager()
chat_stream = None
@@ -242,7 +240,7 @@ async def _send_to_operator(operator_config: str, text: str, expr: Expression):
stream_id = chat_stream.stream_id
# 注册 Tracker
from src.express.reflect_tracker import ReflectTracker, reflect_tracker_manager
from src.bw_learner.reflect_tracker import ReflectTracker, reflect_tracker_manager
tracker = ReflectTracker(chat_stream=chat_stream, expression=expr, created_time=time.time())
reflect_tracker_manager.add_tracker(stream_id, tracker)

View File

@@ -128,9 +128,7 @@ class ExpressionSelector:
# 查询所有相关chat_id的表达方式排除 rejected=1 的,且只选择 count > 1 的
style_query = Expression.select().where(
(Expression.chat_id.in_(related_chat_ids))
& (~Expression.rejected)
& (Expression.count > 1)
(Expression.chat_id.in_(related_chat_ids)) & (~Expression.rejected) & (Expression.count > 1)
)
style_exprs = [
@@ -150,12 +148,15 @@ class ExpressionSelector:
# 要求至少有10个 count > 1 的表达方式才进行选择
min_required = 10
if len(style_exprs) < min_required:
logger.info(f"聊天流 {chat_id} count > 1 的表达方式不足 {min_required} 个(实际 {len(style_exprs)} 个),不进行选择")
logger.info(
f"聊天流 {chat_id} count > 1 的表达方式不足 {min_required} 个(实际 {len(style_exprs)} 个),不进行选择"
)
return [], []
# 固定选择5个
select_count = 5
import random
selected_style = random.sample(style_exprs, select_count)
# 更新last_active_time
@@ -163,7 +164,9 @@ class ExpressionSelector:
self.update_expressions_last_active_time(selected_style)
selected_ids = [expr["id"] for expr in selected_style]
logger.debug(f"think_level=0: 从 {len(style_exprs)} 个 count>1 的表达方式中随机选择了 {len(selected_style)}")
logger.debug(
f"think_level=0: 从 {len(style_exprs)} 个 count>1 的表达方式中随机选择了 {len(selected_style)}"
)
return selected_style, selected_ids
except Exception as e:
@@ -186,9 +189,7 @@ class ExpressionSelector:
related_chat_ids = self.get_related_chat_ids(chat_id)
# 优化一次性查询所有相关chat_id的表达方式排除 rejected=1 的表达
style_query = Expression.select().where(
(Expression.chat_id.in_(related_chat_ids)) & (~Expression.rejected)
)
style_query = Expression.select().where((Expression.chat_id.in_(related_chat_ids)) & (~Expression.rejected))
style_exprs = [
{
@@ -246,7 +247,9 @@ class ExpressionSelector:
# 使用classic模式随机选择+LLM选择
logger.debug(f"使用classic模式为聊天流 {chat_id} 选择表达方式think_level={think_level}")
return await self._select_expressions_classic(chat_id, chat_info, max_num, target_message, reply_reason, think_level)
return await self._select_expressions_classic(
chat_id, chat_info, max_num, target_message, reply_reason, think_level
)
async def _select_expressions_classic(
self,
@@ -275,14 +278,12 @@ class ExpressionSelector:
# think_level == 0: 只选择 count > 1 的项目随机选10个不进行LLM选择
if think_level == 0:
return self._select_expressions_simple(chat_id, max_num)
# think_level == 1: 先选高count再从所有表达方式中随机抽样
# 1. 获取所有表达方式并分离 count > 1 和 count <= 1 的
related_chat_ids = self.get_related_chat_ids(chat_id)
style_query = Expression.select().where(
(Expression.chat_id.in_(related_chat_ids)) & (~Expression.rejected)
)
style_query = Expression.select().where((Expression.chat_id.in_(related_chat_ids)) & (~Expression.rejected))
all_style_exprs = [
{
"id": expr.id,
@@ -299,29 +300,33 @@ class ExpressionSelector:
# 分离 count > 1 和 count <= 1 的表达方式
high_count_exprs = [expr for expr in all_style_exprs if (expr.get("count", 1) or 1) > 1]
# 根据 think_level 设置要求(仅支持 0/10 已在上方返回)
min_high_count = 10
min_total_count = 10
select_high_count = 5
select_random_count = 5
# 检查数量要求
if len(high_count_exprs) < min_high_count:
logger.info(f"聊天流 {chat_id} count > 1 的表达方式不足 {min_high_count} 个(实际 {len(high_count_exprs)} 个),不进行选择")
logger.info(
f"聊天流 {chat_id} count > 1 的表达方式不足 {min_high_count} 个(实际 {len(high_count_exprs)} 个),不进行选择"
)
return [], []
if len(all_style_exprs) < min_total_count:
logger.info(f"聊天流 {chat_id} 总表达方式不足 {min_total_count} 个(实际 {len(all_style_exprs)} 个),不进行选择")
logger.info(
f"聊天流 {chat_id} 总表达方式不足 {min_total_count} 个(实际 {len(all_style_exprs)} 个),不进行选择"
)
return [], []
# 先选取高count的表达方式
selected_high = weighted_sample(high_count_exprs, min(len(high_count_exprs), select_high_count))
# 然后从所有表达方式中随机抽样(使用加权抽样)
remaining_num = select_random_count
selected_random = weighted_sample(all_style_exprs, min(len(all_style_exprs), remaining_num))
# 合并候选池(去重,避免重复)
candidate_exprs = selected_high.copy()
candidate_ids = {expr["id"] for expr in candidate_exprs}
@@ -329,9 +334,10 @@ class ExpressionSelector:
if expr["id"] not in candidate_ids:
candidate_exprs.append(expr)
candidate_ids.add(expr["id"])
# 打乱顺序避免高count的都在前面
import random
random.shuffle(candidate_exprs)
# 2. 构建所有表达方式的索引和情境列表
@@ -351,7 +357,7 @@ class ExpressionSelector:
all_situations_str = "\n".join(all_situations)
if target_message:
target_message_str = f",现在你想要对这条消息进行回复:\"{target_message}\""
target_message_str = f',现在你想要对这条消息进行回复:"{target_message}"'
target_message_extra_block = "4.考虑你要回复的目标消息"
else:
target_message_str = ""

View File

@@ -8,7 +8,12 @@ from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config, global_config
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.bw_learner.jargon_miner import search_jargon
from src.bw_learner.learner_utils import is_bot_message, contains_bot_self_name, parse_chat_id_list, chat_id_list_contains
from src.bw_learner.learner_utils import (
is_bot_message,
contains_bot_self_name,
parse_chat_id_list,
chat_id_list_contains,
)
logger = get_logger("jargon")
@@ -357,4 +362,4 @@ async def retrieve_concepts_with_jargon(concepts: List[str], chat_id: str) -> st
if results:
return "【概念检索结果】\n" + "\n".join(results) + "\n"
return ""
return ""

View File

@@ -1,4 +1,3 @@
import time
import json
import asyncio
import random
@@ -14,7 +13,6 @@ from src.config.config import model_config, global_config
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.utils.chat_message_builder import (
build_readable_messages_with_id,
get_raw_msg_by_timestamp_with_chat_inclusive,
)
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.bw_learner.learner_utils import (
@@ -33,23 +31,23 @@ logger = get_logger("jargon")
def _is_single_char_jargon(content: str) -> bool:
"""
判断是否是单字黑话(单个汉字、英文或数字)
Args:
content: 词条内容
Returns:
bool: 如果是单字黑话返回True否则返回False
"""
if not content or len(content) != 1:
return False
char = content[0]
# 判断是否是单个汉字、单个英文字母或单个数字
return (
'\u4e00' <= char <= '\u9fff' or # 汉字
'a' <= char <= 'z' or # 小写字母
'A' <= char <= 'Z' or # 大写字母
'0' <= char <= '9' # 数字
"\u4e00" <= char <= "\u9fff" # 汉字
or "a" <= char <= "z" # 小写字母
or "A" <= char <= "Z" # 大写字母
or "0" <= char <= "9" # 数字
)
@@ -195,7 +193,7 @@ class JargonMiner:
model_set=model_config.model_task_config.utils,
request_type="jargon.extract",
)
self.llm_inference = LLMRequest(
model_set=model_config.model_task_config.utils,
request_type="jargon.inference",
@@ -207,7 +205,7 @@ class JargonMiner:
self.stream_name = stream_name if stream_name else self.chat_id
self.cache_limit = 50
self.cache: OrderedDict[str, None] = OrderedDict()
# 黑话提取锁,防止并发执行
self._extraction_lock = asyncio.Lock()
@@ -299,17 +297,19 @@ class JargonMiner:
# 获取当前count和上一次的meaning
current_count = jargon_obj.count or 0
previous_meaning = jargon_obj.meaning or ""
# 当count为24, 60时随机移除一半的raw_content项目
if current_count in [24, 60] and len(raw_content_list) > 1:
# 计算要保留的数量至少保留1个
keep_count = max(1, len(raw_content_list) // 2)
raw_content_list = random.sample(raw_content_list, keep_count)
logger.info(f"jargon {content} count={current_count},随机移除后剩余 {len(raw_content_list)} 个raw_content项目")
logger.info(
f"jargon {content} count={current_count},随机移除后剩余 {len(raw_content_list)} 个raw_content项目"
)
# 步骤1: 基于raw_content和content推断
raw_content_text = "\n".join(raw_content_list)
# 当count为24, 60, 100时在prompt中放入上一次推断出的meaning作为参考
previous_meaning_section = ""
previous_meaning_instruction = ""
@@ -318,8 +318,10 @@ class JargonMiner:
**上一次推断的含义(仅供参考)**
{previous_meaning}
"""
previous_meaning_instruction = "- 请参考上一次推断的含义,结合新的上下文信息,给出更准确或更新的推断结果"
previous_meaning_instruction = (
"- 请参考上一次推断的含义,结合新的上下文信息,给出更准确或更新的推断结果"
)
prompt1 = await global_prompt_manager.format_prompt(
"jargon_inference_with_context_prompt",
content=content,
@@ -485,7 +487,7 @@ class JargonMiner:
) -> None:
"""
运行一次黑话提取
Args:
messages: 外部传入的消息列表(必需)
person_name_filter: 可选的过滤函数,用于检查内容是否包含人物名称
@@ -660,7 +662,9 @@ class JargonMiner:
if obj.raw_content:
try:
existing_raw_content = (
json.loads(obj.raw_content) if isinstance(obj.raw_content, str) else obj.raw_content
json.loads(obj.raw_content)
if isinstance(obj.raw_content, str)
else obj.raw_content
)
if not isinstance(existing_raw_content, list):
existing_raw_content = [existing_raw_content] if existing_raw_content else []
@@ -740,14 +744,14 @@ class JargonMiner:
) -> None:
"""
处理已提取的黑话条目(从 expression_learner 路由过来的)
Args:
entries: 黑话条目列表,每个元素格式为 {"content": "...", "raw_content": [...]}
person_name_filter: 可选的过滤函数,用于检查内容是否包含人物名称
"""
if not entries:
return
try:
# 去重并合并raw_content按 content 聚合)
merged_entries: OrderedDict[str, Dict[str, List[str]]] = OrderedDict()
@@ -899,8 +903,6 @@ class JargonMinerManager:
miner_manager = JargonMinerManager()
def search_jargon(
keyword: str, chat_id: Optional[str] = None, limit: int = 10, case_sensitive: bool = False, fuzzy: bool = True
) -> List[Dict[str, str]]:

View File

@@ -1,62 +1,39 @@
import time
import asyncio
from typing import List, Any, Optional
from collections import OrderedDict
from dataclasses import dataclass
from typing import List, Any
from src.common.logger import get_logger
from src.config.config import global_config
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat_inclusive
from src.bw_learner.expression_learner import expression_learner_manager
from src.bw_learner.jargon_miner import miner_manager
from src.person_info.person_info import Person
logger = get_logger("bw_learner")
@dataclass
class PersonInfo:
"""参与聊天的人物信息"""
user_id: str
user_platform: str
user_nickname: str
user_cardname: Optional[str]
person_name: str
last_seen_time: float # 最后发言时间
def get_unique_key(self) -> str:
"""获取唯一标识(用于去重)"""
return f"{self.user_platform}:{self.user_id}"
class MessageRecorder:
"""
统一的消息记录器,负责管理时间窗口和消息提取,并将消息分发给 expression_learner 和 jargon_miner
"""
def __init__(self, chat_id: str) -> None:
self.chat_id = chat_id
self.chat_stream = get_chat_manager().get_stream(chat_id)
self.chat_name = get_chat_manager().get_stream_name(chat_id) or chat_id
# 维护每个chat的上次提取时间
self.last_extraction_time: float = time.time()
# 提取锁,防止并发执行
self._extraction_lock = asyncio.Lock()
# 维护参与该chat_id的人物列表最多30个使用OrderedDict保持插入顺序
# key: f"{platform}:{user_id}", value: PersonInfo
self._person_list: OrderedDict[str, PersonInfo] = OrderedDict()
self._max_person_count = 30
# 获取 expression 和 jargon 的配置参数
self._init_parameters()
# 获取 expression_learner 和 jargon_miner 实例
self.expression_learner = expression_learner_manager.get_expression_learner(chat_id)
self.jargon_miner = miner_manager.get_miner(chat_id)
def _init_parameters(self) -> None:
"""初始化提取参数"""
# 获取 expression 配置
@@ -65,17 +42,17 @@ class MessageRecorder:
)
self.min_messages_for_extraction = 30
self.min_extraction_interval = 60
logger.debug(
f"MessageRecorder 初始化: chat_id={self.chat_id}, "
f"min_messages={self.min_messages_for_extraction}, "
f"min_interval={self.min_extraction_interval}"
)
def should_trigger_extraction(self) -> bool:
"""
检查是否应该触发消息提取
Returns:
bool: 是否应该触发提取
"""
@@ -83,19 +60,19 @@ class MessageRecorder:
time_diff = time.time() - self.last_extraction_time
if time_diff < self.min_extraction_interval:
return False
# 检查消息数量
recent_messages = get_raw_msg_by_timestamp_with_chat_inclusive(
chat_id=self.chat_id,
timestamp_start=self.last_extraction_time,
timestamp_end=time.time(),
)
if not recent_messages or len(recent_messages) < self.min_messages_for_extraction:
return False
return True
async def extract_and_distribute(self) -> None:
"""
提取消息并分发给 expression_learner 和 jargon_miner
@@ -105,46 +82,40 @@ class MessageRecorder:
# 在锁内检查,避免并发触发
if not self.should_trigger_extraction():
return
# 检查 chat_stream 是否存在
if not self.chat_stream:
return
# 记录本次提取的时间窗口,避免重复提取
extraction_start_time = self.last_extraction_time
extraction_end_time = time.time()
# 立即更新提取时间,防止并发触发
self.last_extraction_time = extraction_end_time
try:
logger.info(f"在聊天流 {self.chat_name} 开始统一消息提取和分发")
# 拉取提取窗口内的消息
messages = get_raw_msg_by_timestamp_with_chat_inclusive(
chat_id=self.chat_id,
timestamp_start=extraction_start_time,
timestamp_end=extraction_end_time,
)
if not messages:
logger.debug(f"聊天流 {self.chat_name} 没有新消息,跳过提取")
return
# 按时间排序,确保顺序一致
messages = sorted(messages, key=lambda msg: msg.time or 0)
# 更新参与聊天的人物列表
self._update_person_list(messages)
logger.info(f"聊天流 {self.chat_name} 的人物列表: {self._person_list}")
logger.info(
f"聊天流 {self.chat_name} 提取到 {len(messages)} 条消息,"
f"时间窗口: {extraction_start_time:.2f} - {extraction_end_time:.2f}"
)
# 分别触发 expression_learner 和 jargon_miner 的处理
# 传递提取的消息,避免它们重复获取
# 触发 expression 学习(如果启用)
@@ -152,40 +123,35 @@ class MessageRecorder:
asyncio.create_task(
self._trigger_expression_learning(extraction_start_time, extraction_end_time, messages)
)
# 触发 jargon 提取(如果启用),传递消息
# if self.enable_jargon_learning:
# asyncio.create_task(
# self._trigger_jargon_extraction(extraction_start_time, extraction_end_time, messages)
# )
# asyncio.create_task(
# self._trigger_jargon_extraction(extraction_start_time, extraction_end_time, messages)
# )
except Exception as e:
logger.error(f"为聊天流 {self.chat_name} 提取和分发消息失败: {e}")
import traceback
traceback.print_exc()
# 即使失败也保持时间戳更新,避免频繁重试
async def _trigger_expression_learning(
self,
timestamp_start: float,
timestamp_end: float,
messages: List[Any]
self, timestamp_start: float, timestamp_end: float, messages: List[Any]
) -> None:
"""
触发 expression 学习,使用指定的消息列表
Args:
timestamp_start: 开始时间戳
timestamp_end: 结束时间戳
messages: 消息列表
"""
try:
# 传递消息和过滤函数给 ExpressionLearner
learnt_style = await self.expression_learner.learn_and_store(
messages=messages,
person_name_filter=self.contains_person_name
)
# 传递消息给 ExpressionLearner(必需参数)
learnt_style = await self.expression_learner.learn_and_store(messages=messages)
if learnt_style:
logger.info(f"聊天流 {self.chat_name} 表达学习完成")
else:
@@ -193,148 +159,37 @@ class MessageRecorder:
except Exception as e:
logger.error(f"为聊天流 {self.chat_name} 触发表达学习失败: {e}")
import traceback
traceback.print_exc()
async def _trigger_jargon_extraction(
self,
timestamp_start: float,
timestamp_end: float,
messages: List[Any]
self, timestamp_start: float, timestamp_end: float, messages: List[Any]
) -> None:
"""
触发 jargon 提取,使用指定的消息列表
Args:
timestamp_start: 开始时间戳
timestamp_end: 结束时间戳
messages: 消息列表
"""
try:
# 传递消息和过滤函数给 JargonMiner
await self.jargon_miner.run_once(
messages=messages,
person_name_filter=self.contains_person_name
)
# 传递消息给 JargonMiner,避免它重复获取
await self.jargon_miner.run_once(messages=messages)
except Exception as e:
logger.error(f"为聊天流 {self.chat_name} 触发黑话提取失败: {e}")
import traceback
traceback.print_exc()
def _update_person_list(self, messages: List[Any]) -> None:
"""
从消息中提取人物信息并更新人物列表
Args:
messages: 消息列表
"""
for msg in messages:
# 获取消息发送者信息
# 消息对象可能是 DatabaseMessages它有 user_info 属性
if hasattr(msg, 'user_info'):
# DatabaseMessages 类型
user_info = msg.user_info
user_id = getattr(user_info, 'user_id', None) or ''
user_platform = getattr(user_info, 'platform', None) or ''
user_nickname = getattr(user_info, 'user_nickname', None) or ''
user_cardname = getattr(user_info, 'user_cardname', None)
else:
# 直接属性访问
user_id = getattr(msg, 'user_id', None) or ''
user_platform = getattr(msg, 'user_platform', None) or ''
user_nickname = getattr(msg, 'user_nickname', None) or ''
user_cardname = getattr(msg, 'user_cardname', None)
msg_time = getattr(msg, 'time', time.time())
# 检查必要信息
if not user_id or not user_platform:
continue
# 获取 person_name
try:
person = Person(platform=user_platform, user_id=str(user_id))
person_name = person.person_name or user_nickname or (user_cardname if user_cardname else "未知用户")
except Exception as e:
logger.info(f"获取person_name失败: {e}, 使用nickname")
person_name = user_nickname or (user_cardname if user_cardname else "未知用户")
# 生成唯一key
unique_key = f"{user_platform}:{user_id}"
# 如果已存在,更新最后发言时间
if unique_key in self._person_list:
self._person_list[unique_key].last_seen_time = msg_time
# 移动到末尾(表示最近活跃)
self._person_list.move_to_end(unique_key)
else:
# 如果超过最大数量,移除最早的(最前面的)
if len(self._person_list) >= self._max_person_count:
oldest_key = next(iter(self._person_list))
del self._person_list[oldest_key]
logger.info(f"人物列表已满,移除最早的人物: {oldest_key}")
# 添加新人物
person_info = PersonInfo(
user_id=str(user_id),
user_platform=user_platform,
user_nickname=user_nickname or "",
user_cardname=user_cardname,
person_name=person_name,
last_seen_time=msg_time
)
self._person_list[unique_key] = person_info
logger.info(f"添加新人物到列表: {unique_key}, person_name={person_name}")
def contains_person_name(self, content: str) -> bool:
"""
检查内容是否包含任何参与聊天的人物的名称或昵称
Args:
content: 要检查的内容
Returns:
bool: 如果包含任何人物名称或昵称返回True
"""
if not content or not self._person_list:
return False
content_lower = content.strip().lower()
if not content_lower:
return False
# 检查所有人物
for person_info in self._person_list.values():
# 检查 person_name
if person_info.person_name:
person_name_lower = person_info.person_name.strip().lower()
if person_name_lower and person_name_lower in content_lower:
logger.debug(f"内容包含person_name: {person_info.person_name} in {content}")
return True
# 检查 user_nickname
if person_info.user_nickname:
nickname_lower = person_info.user_nickname.strip().lower()
if nickname_lower and nickname_lower in content_lower:
logger.debug(f"内容包含nickname: {person_info.user_nickname} in {content}")
return True
# 检查 user_cardname群昵称
if person_info.user_cardname:
cardname_lower = person_info.user_cardname.strip().lower()
if cardname_lower and cardname_lower in content_lower:
logger.debug(f"内容包含cardname: {person_info.user_cardname} in {content}")
return True
return False
class MessageRecorderManager:
"""MessageRecorder 管理器"""
def __init__(self) -> None:
self._recorders: dict[str, MessageRecorder] = {}
def get_recorder(self, chat_id: str) -> MessageRecorder:
"""获取或创建指定 chat_id 的 MessageRecorder"""
if chat_id not in self._recorders:
@@ -349,10 +204,9 @@ recorder_manager = MessageRecorderManager()
async def extract_and_distribute_messages(chat_id: str) -> None:
"""
统一的消息提取和分发入口函数
Args:
chat_id: 聊天流ID
"""
recorder = recorder_manager.get_recorder(chat_id)
await recorder.extract_and_distribute()

View File

@@ -176,19 +176,19 @@ class BrainChatting:
# 如果有新消息,更新 last_read_time
if len(recent_messages_list) >= 1:
self.last_read_time = time.time()
# 总是执行一次思考迭代(不管有没有新消息)
# wait 动作会在其内部等待,不需要在这里处理
should_continue = await self._observe(recent_messages_list=recent_messages_list)
if not should_continue:
# 选择了 complete_talk返回 False 表示需要等待新消息
return False
# 继续下一次迭代(除非选择了 complete_talk
# 短暂等待后再继续,避免过于频繁的循环
await asyncio.sleep(0.1)
return True
async def _send_and_store_reply(
@@ -328,9 +328,7 @@ class BrainChatting:
)
# 检查是否有 complete_talk 动作(会停止后续迭代)
has_complete_talk = any(
action.action_type == "complete_talk" for action in action_to_use_info
)
has_complete_talk = any(action.action_type == "complete_talk" for action in action_to_use_info)
# 并行执行所有动作
action_tasks = [
@@ -430,12 +428,12 @@ class BrainChatting:
await asyncio.sleep(3)
self._loop_task = asyncio.create_task(self._main_chat_loop())
logger.error(f"{self.log_prefix} 结束了当前聊天循环")
async def _wait_for_new_message(self):
"""等待新消息到达"""
last_check_time = self.last_read_time
check_interval = 1.0 # 每秒检查一次
while self.running:
# 检查是否有新消息
recent_messages_list = message_api.get_messages_by_time_in_chat(
@@ -448,13 +446,13 @@ class BrainChatting:
filter_command=False,
filter_intercept_message_level=1,
)
# 如果有新消息,更新 last_read_time 并返回
if len(recent_messages_list) >= 1:
self.last_read_time = time.time()
logger.info(f"{self.log_prefix} 检测到新消息,恢复循环")
return
# 等待一段时间后再次检查
await asyncio.sleep(check_interval)
@@ -660,9 +658,9 @@ class BrainChatting:
except (ValueError, TypeError):
logger.warning(f"{self.log_prefix} wait_seconds 参数格式错误,使用默认值 5 秒")
wait_seconds = 5
logger.info(f"{self.log_prefix} 执行 wait 动作,等待 {wait_seconds}")
# 记录动作信息
await database_api.store_action_info(
chat_stream=self.chat_stream,
@@ -673,12 +671,12 @@ class BrainChatting:
action_data={"reason": reason, "wait_seconds": wait_seconds},
action_name="wait",
)
# 等待指定时间
await asyncio.sleep(wait_seconds)
logger.info(f"{self.log_prefix} wait 动作完成,继续下一次思考")
# 这些动作本身不产生文本回复
self._last_successful_reply = False
return {
@@ -693,9 +691,9 @@ class BrainChatting:
logger.debug(f"{self.log_prefix} 检测到 listening 动作,已合并到 wait自动转换")
# 使用默认等待时间
wait_seconds = 3
logger.info(f"{self.log_prefix} 执行 listening转换为 wait动作等待 {wait_seconds}")
# 记录动作信息
await database_api.store_action_info(
chat_stream=self.chat_stream,
@@ -706,12 +704,12 @@ class BrainChatting:
action_data={"reason": reason, "wait_seconds": wait_seconds},
action_name="listening",
)
# 等待指定时间
await asyncio.sleep(wait_seconds)
logger.info(f"{self.log_prefix} listening 动作完成,继续下一次思考")
# 这些动作本身不产生文本回复
self._last_successful_reply = False
return {

View File

@@ -147,7 +147,7 @@ class BrainPlanner:
) # 用于动作规划
self.last_obs_time_mark = 0.0
# 计划日志记录
self.plan_log: List[Tuple[str, float, List[ActionPlannerInfo]]] = []
@@ -203,9 +203,11 @@ class BrainPlanner:
# 内部保留动作(不依赖插件系统)
# 注意listening 已合并到 wait 中,如果遇到 listening 则转换为 wait
internal_action_names = ["complete_talk", "reply", "wait_time", "wait", "listening"]
logger.debug(f"{self.log_prefix}动作验证: action={action}, internal={internal_action_names}, available={available_action_names}")
logger.debug(
f"{self.log_prefix}动作验证: action={action}, internal={internal_action_names}, available={available_action_names}"
)
# 将 listening 转换为 wait向后兼容
if action == "listening":
logger.debug(f"{self.log_prefix}检测到 listening 动作,已合并到 wait自动转换")
@@ -521,7 +523,7 @@ class BrainPlanner:
if json_objects:
logger.info(f"{self.log_prefix}从响应中提取到{len(json_objects)}个JSON对象")
for i, json_obj in enumerate(json_objects):
logger.info(f"{self.log_prefix}解析第{i+1}个JSON对象: {json_obj}")
logger.info(f"{self.log_prefix}解析第{i + 1}个JSON对象: {json_obj}")
filtered_actions_list = list(filtered_actions.items())
for json_obj in json_objects:
parsed_actions = self._parse_single_action(json_obj, message_id_list, filtered_actions_list)
@@ -553,7 +555,9 @@ class BrainPlanner:
return extracted_reasoning, actions
def _create_complete_talk(self, reasoning: str, available_actions: Dict[str, ActionInfo]) -> List[ActionPlannerInfo]:
def _create_complete_talk(
self, reasoning: str, available_actions: Dict[str, ActionInfo]
) -> List[ActionPlannerInfo]:
"""创建complete_talk"""
return [
ActionPlannerInfo(
@@ -564,7 +568,7 @@ class BrainPlanner:
available_actions=available_actions,
)
]
def add_plan_log(self, reasoning: str, actions: List[ActionPlannerInfo]):
"""添加计划日志"""
self.plan_log.append((reasoning, time.time(), actions))

View File

@@ -271,7 +271,7 @@ def _to_emoji_objects(data: Any) -> Tuple[List["MaiEmoji"], int]:
emoji.description = emoji_data.description
# Deserialize emotion string from DB to list
emoji.emotion = emoji_data.emotion.replace("",",").split(",") if emoji_data.emotion else []
emoji.emotion = emoji_data.emotion.replace("", ",").split(",") if emoji_data.emotion else []
emoji.usage_count = emoji_data.usage_count
db_last_used_time = emoji_data.last_used_time
@@ -732,7 +732,7 @@ class EmojiManager:
emoji_record = Emoji.get_or_none(Emoji.emoji_hash == emoji_hash)
if emoji_record and emoji_record.emotion:
logger.info(f"[缓存命中] 从数据库获取表情包情感标签: {emoji_record.emotion[:50]}...")
return emoji_record.emotion.replace("",",").split(",")
return emoji_record.emotion.replace("", ",").split(",")
except Exception as e:
logger.error(f"从数据库查询表情包情感标签时出错: {e}")
@@ -993,7 +993,7 @@ class EmojiManager:
)
# 处理情感列表
emotions = [e.strip() for e in emotions_text.replace("",",").split(",") if e.strip()]
emotions = [e.strip() for e in emotions_text.replace("", ",").split(",") if e.strip()]
# 根据情感标签数量随机选择 - 超过5个选3个超过2个选2个
if len(emotions) > 5:

View File

@@ -619,13 +619,13 @@ class HeartFChatting:
think_level = 0
# 使用 action_reasoningplanner 的整体思考理由)作为 reply_reason
planner_reasoning = action_planner_info.action_reasoning or reason
record_replyer_action_temp(
chat_id=self.stream_id,
reason=reason,
think_level=think_level,
)
await database_api.store_action_info(
chat_stream=self.chat_stream,
action_build_into_prompt=False,

View File

@@ -123,7 +123,11 @@ class ChatBot:
logger.warning(f"命令执行失败: {command_class.__name__} - {response}")
# 根据命令的拦截设置决定是否继续处理消息
return True, response, not bool(intercept_message_level) # 找到命令根据intercept_message决定是否继续
return (
True,
response,
not bool(intercept_message_level),
) # 找到命令根据intercept_message决定是否继续
except Exception as e:
logger.error(f"执行命令时出错: {command_class.__name__} - {e}")

View File

@@ -213,6 +213,68 @@ class MessageRecv(Message):
}
"""
return ""
elif segment.type == "video_card":
# 处理视频卡片消息
self.is_picid = False
self.is_emoji = False
self.is_voice = False
if isinstance(segment.data, dict):
file_name = segment.data.get("file", "未知视频")
file_size = segment.data.get("file_size", "")
url = segment.data.get("url", "")
text = f"[视频: {file_name}"
if file_size:
text += f", 大小: {file_size}字节"
text += "]"
if url:
text += f" 链接: {url}"
return text
return "[视频]"
elif segment.type == "music_card":
# 处理音乐卡片消息
self.is_picid = False
self.is_emoji = False
self.is_voice = False
if isinstance(segment.data, dict):
title = segment.data.get("title", "未知歌曲")
singer = segment.data.get("singer", "")
tag = segment.data.get("tag", "") # 音乐来源,如"网易云音乐"
jump_url = segment.data.get("jump_url", "")
music_url = segment.data.get("music_url", "")
text = f"[音乐: {title}"
if singer:
text += f" - {singer}"
if tag:
text += f" ({tag})"
text += "]"
if jump_url:
text += f" 跳转链接: {jump_url}"
if music_url:
text += f" 音乐链接: {music_url}"
return text
return "[音乐]"
elif segment.type == "miniapp_card":
# 处理小程序分享卡片如B站视频分享
self.is_picid = False
self.is_emoji = False
self.is_voice = False
if isinstance(segment.data, dict):
title = segment.data.get("title", "") # 小程序名称
desc = segment.data.get("desc", "") # 内容描述
source_url = segment.data.get("source_url", "") # 原始链接
url = segment.data.get("url", "") # 小程序链接
text = "[小程序分享"
if title:
text += f" - {title}"
text += "]"
if desc:
text += f" {desc}"
if source_url:
text += f" 链接: {source_url}"
elif url:
text += f" 链接: {url}"
return text
return "[小程序分享]"
else:
return ""
except Exception as e:

View File

@@ -42,22 +42,21 @@ def is_webui_virtual_group(group_id: str) -> bool:
def parse_message_segments(segment) -> list:
"""解析消息段,转换为 WebUI 可用的格式
参考 NapCat 适配器的消息解析逻辑
Args:
segment: Seg 消息段对象
Returns:
list: 消息段列表,每个元素为 {"type": "...", "data": ...}
"""
from maim_message import Seg
result = []
if segment is None:
return result
if segment.type == "seglist":
# 处理消息段列表
if segment.data:
@@ -112,15 +111,19 @@ def parse_message_segments(segment) -> list:
forward_items = []
if segment.data:
for item in segment.data:
forward_items.append({
"content": parse_message_segments(item.get("message_segment", {})) if isinstance(item, dict) else []
})
forward_items.append(
{
"content": parse_message_segments(item.get("message_segment", {}))
if isinstance(item, dict)
else []
}
)
result.append({"type": "forward", "data": forward_items})
else:
# 未知类型,尝试作为文本处理
if segment.data:
result.append({"type": "unknown", "original_type": segment.type, "data": str(segment.data)})
return result
@@ -134,7 +137,7 @@ async def _send_message(message: MessageSending, show_log=True) -> bool:
# 检查是否是 WebUI 平台的消息,或者是 WebUI 虚拟群的消息
chat_manager, webui_platform = get_webui_chat_broadcaster()
is_webui_message = (platform == webui_platform) or is_webui_virtual_group(group_id)
if is_webui_message and chat_manager is not None:
# WebUI 聊天室消息(包括虚拟身份模式),通过 WebSocket 广播
import time
@@ -142,7 +145,7 @@ async def _send_message(message: MessageSending, show_log=True) -> bool:
# 解析消息段,获取富文本内容
message_segments = parse_message_segments(message.message_segment)
# 判断消息类型
# 如果只有一个文本段,使用简单的 text 类型
# 否则使用 rich 类型,包含完整的消息段

View File

@@ -77,8 +77,7 @@ target_message_id为必填表示触发消息的id
```""",
"planner_prompt",
)
Prompt(
"""
{action_name}

View File

@@ -250,7 +250,12 @@ class DefaultReplyer:
# 使用从处理器传来的选中表达方式
# 使用模型预测选择表达方式
selected_expressions, selected_ids = await expression_selector.select_suitable_expressions(
self.chat_stream.stream_id, chat_history, max_num=8, target_message=target, reply_reason=reply_reason, think_level=think_level
self.chat_stream.stream_id,
chat_history,
max_num=8,
target_message=target,
reply_reason=reply_reason,
think_level=think_level,
)
if selected_expressions:
@@ -273,7 +278,6 @@ class DefaultReplyer:
return f"{expression_habits_title}\n{expression_habits_block}", selected_ids
async def build_tool_info(self, chat_history: str, sender: str, target: str, enable_tool: bool = True) -> str:
"""构建工具信息块
@@ -788,7 +792,8 @@ class DefaultReplyer:
# 并行执行八个构建任务(包括黑话解释)
task_results = await asyncio.gather(
self._time_and_run_task(
self.build_expression_habits(chat_talking_prompt_short, target, reply_reason, think_level=think_level), "expression_habits"
self.build_expression_habits(chat_talking_prompt_short, target, reply_reason, think_level=think_level),
"expression_habits",
),
self._time_and_run_task(
self.build_tool_info(chat_talking_prompt_short, sender, target, enable_tool=enable_tool), "tool_info"
@@ -980,7 +985,6 @@ class DefaultReplyer:
else:
reply_target_block = ""
chat_target_1 = await global_prompt_manager.get_prompt_async("chat_target_group1")
chat_target_2 = await global_prompt_manager.get_prompt_async("chat_target_group2")

View File

@@ -287,7 +287,6 @@ class PrivateReplyer:
return f"{expression_habits_title}\n{expression_habits_block}", selected_ids
async def build_tool_info(self, chat_history: str, sender: str, target: str, enable_tool: bool = True) -> str:
"""构建工具信息块
@@ -907,16 +906,11 @@ class PrivateReplyer:
else:
reply_target_block = ""
chat_target_name = "对方"
if self.chat_target_info:
chat_target_name = self.chat_target_info.person_name or self.chat_target_info.user_nickname or "对方"
chat_target_1 = await global_prompt_manager.format_prompt(
"chat_target_private1", sender_name=chat_target_name
)
chat_target_2 = await global_prompt_manager.format_prompt(
"chat_target_private2", sender_name=chat_target_name
)
chat_target_1 = await global_prompt_manager.format_prompt("chat_target_private1", sender_name=chat_target_name)
chat_target_2 = await global_prompt_manager.format_prompt("chat_target_private2", sender_name=chat_target_name)
template_name = "default_expressor_prompt"

View File

@@ -1,8 +1,9 @@
from src.chat.utils.prompt_builder import Prompt
def init_replyer_private_prompt():
Prompt(
"""{knowledge_prompt}{tool_info_block}{extra_info_block}
"""{knowledge_prompt}{tool_info_block}{extra_info_block}
{expression_habits_block}{memory_retrieval}{jargon_explanation}
你正在和{sender_name}聊天,这是你们之前聊的内容:
@@ -17,9 +18,9 @@ def init_replyer_private_prompt():
{reply_style}
请注意不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只输出回复内容。
{moderation_prompt}不要输出多余内容(包括前后缀冒号和引号括号表情包at或 @等 )。""",
"private_replyer_prompt",
)
"private_replyer_prompt",
)
Prompt(
"""{knowledge_prompt}{tool_info_block}{extra_info_block}
{expression_habits_block}{memory_retrieval}{jargon_explanation}
@@ -37,4 +38,4 @@ def init_replyer_private_prompt():
{moderation_prompt}不要输出多余内容(包括冒号和引号括号表情包at或 @等 )。
""",
"private_replyer_self_prompt",
)
)

View File

@@ -23,7 +23,7 @@ def init_replyer_prompt():
现在,你说:""",
"replyer_prompt_0",
)
Prompt(
"""{knowledge_prompt}{tool_info_block}{extra_info_block}
{expression_habits_block}{memory_retrieval}{jargon_explanation}
@@ -44,4 +44,3 @@ def init_replyer_prompt():
现在,你说:""",
"replyer_prompt",
)

View File

@@ -311,7 +311,10 @@ def get_raw_msg_before_timestamp_with_chat(
filter_query = {"chat_id": chat_id, "time": {"$lt": timestamp}}
sort_order = [("time", 1)]
return find_messages(
message_filter=filter_query, sort=sort_order, limit=limit, filter_intercept_message_level=filter_intercept_message_level
message_filter=filter_query,
sort=sort_order,
limit=limit,
filter_intercept_message_level=filter_intercept_message_level,
)

View File

@@ -746,7 +746,7 @@ class StatisticOutputTask(AsyncTask):
data_fmt = "{:<32} {:>10} {:>12} {:>12} {:>12} {:>9.2f}¥ {:>10.1f} {:>10.1f} {:>12} {:>12}"
total_replies = stats.get(TOTAL_REPLY_CNT, 0)
output = [
"按模型分类统计:",
" 模型名称 调用次数 输入Token 输出Token Token总量 累计花费 平均耗时(秒) 标准差(秒) 每次回复平均调用次数 每次回复平均Token数",
@@ -759,11 +759,11 @@ class StatisticOutputTask(AsyncTask):
cost = stats[COST_BY_MODEL][model_name]
avg_time_cost = stats[AVG_TIME_COST_BY_MODEL][model_name]
std_time_cost = stats[STD_TIME_COST_BY_MODEL][model_name]
# 计算每次回复平均值
avg_count_per_reply = count / total_replies if total_replies > 0 else 0.0
avg_tokens_per_reply = tokens / total_replies if total_replies > 0 else 0.0
# 格式化大数字
formatted_count = _format_large_number(count)
formatted_in_tokens = _format_large_number(in_tokens)
@@ -771,7 +771,7 @@ class StatisticOutputTask(AsyncTask):
formatted_tokens = _format_large_number(tokens)
formatted_avg_count = _format_large_number(avg_count_per_reply) if total_replies > 0 else "N/A"
formatted_avg_tokens = _format_large_number(avg_tokens_per_reply) if total_replies > 0 else "N/A"
output.append(
data_fmt.format(
name,
@@ -800,7 +800,7 @@ class StatisticOutputTask(AsyncTask):
data_fmt = "{:<32} {:>10} {:>12} {:>12} {:>12} {:>9.2f}¥ {:>10.1f} {:>10.1f} {:>12} {:>12}"
total_replies = stats.get(TOTAL_REPLY_CNT, 0)
output = [
"按模块分类统计:",
" 模块名称 调用次数 输入Token 输出Token Token总量 累计花费 平均耗时(秒) 标准差(秒) 每次回复平均调用次数 每次回复平均Token数",
@@ -813,11 +813,11 @@ class StatisticOutputTask(AsyncTask):
cost = stats[COST_BY_MODULE][module_name]
avg_time_cost = stats[AVG_TIME_COST_BY_MODULE][module_name]
std_time_cost = stats[STD_TIME_COST_BY_MODULE][module_name]
# 计算每次回复平均值
avg_count_per_reply = count / total_replies if total_replies > 0 else 0.0
avg_tokens_per_reply = tokens / total_replies if total_replies > 0 else 0.0
# 格式化大数字
formatted_count = _format_large_number(count)
formatted_in_tokens = _format_large_number(in_tokens)
@@ -825,7 +825,7 @@ class StatisticOutputTask(AsyncTask):
formatted_tokens = _format_large_number(tokens)
formatted_avg_count = _format_large_number(avg_count_per_reply) if total_replies > 0 else "N/A"
formatted_avg_tokens = _format_large_number(avg_tokens_per_reply) if total_replies > 0 else "N/A"
output.append(
data_fmt.format(
name,

View File

@@ -646,7 +646,7 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional["TargetP
def record_replyer_action_temp(chat_id: str, reason: str, think_level: int) -> None:
"""
临时记录replyer动作被选择的信息仅群聊
Args:
chat_id: 聊天ID
reason: 选择理由
@@ -656,7 +656,7 @@ def record_replyer_action_temp(chat_id: str, reason: str, think_level: int) -> N
# 确保data/temp目录存在
temp_dir = "data/temp"
os.makedirs(temp_dir, exist_ok=True)
# 创建记录数据
record_data = {
"chat_id": chat_id,
@@ -664,16 +664,16 @@ def record_replyer_action_temp(chat_id: str, reason: str, think_level: int) -> N
"think_level": think_level,
"timestamp": datetime.now().isoformat(),
}
# 生成文件名(使用时间戳避免冲突)
timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
filename = f"replyer_action_{timestamp_str}.json"
filepath = os.path.join(temp_dir, filename)
# 写入文件
with open(filepath, "w", encoding="utf-8") as f:
json.dump(record_data, f, ensure_ascii=False, indent=2)
logger.debug(f"已记录replyer动作选择: chat_id={chat_id}, think_level={think_level}")
except Exception as e:
logger.warning(f"记录replyer动作选择失败: {e}")

View File

@@ -130,12 +130,10 @@ class ImageManager:
try:
# 清理Images表中type为emoji的记录
deleted_images = Images.delete().where(Images.type == "emoji").execute()
# 清理ImageDescriptions表中type为emoji的记录
deleted_descriptions = (
ImageDescriptions.delete().where(ImageDescriptions.type == "emoji").execute()
)
deleted_descriptions = ImageDescriptions.delete().where(ImageDescriptions.type == "emoji").execute()
total_deleted = deleted_images + deleted_descriptions
if total_deleted > 0:
logger.info(
@@ -166,7 +164,7 @@ class ImageManager:
async def _save_emoji_file_if_needed(self, image_base64: str, image_hash: str, image_format: str) -> None:
"""如果启用了steal_emoji且表情包未注册保存文件到data/emoji目录
Args:
image_base64: 图片的base64编码
image_hash: 图片的MD5哈希值
@@ -174,7 +172,7 @@ class ImageManager:
"""
if not global_config.emoji.steal_emoji:
return
try:
from src.chat.emoji_system.emoji_manager import EMOJI_DIR
from src.chat.emoji_system.emoji_manager import get_emoji_manager
@@ -236,12 +234,16 @@ class ImageManager:
# 优先使用情感标签,如果没有则使用详细描述
result_text = ""
if cache_record.emotion_tags:
logger.info(f"[缓存命中] 使用EmojiDescriptionCache表中的情感标签: {cache_record.emotion_tags[:50]}...")
logger.info(
f"[缓存命中] 使用EmojiDescriptionCache表中的情感标签: {cache_record.emotion_tags[:50]}..."
)
result_text = f"[表情包:{cache_record.emotion_tags}]"
elif cache_record.description:
logger.info(f"[缓存命中] 使用EmojiDescriptionCache表中的描述: {cache_record.description[:50]}...")
logger.info(
f"[缓存命中] 使用EmojiDescriptionCache表中的描述: {cache_record.description[:50]}..."
)
result_text = f"[表情包:{cache_record.description}]"
# 即使缓存命中如果启用了steal_emoji也检查是否需要保存文件
if result_text:
await self._save_emoji_file_if_needed(image_base64, image_hash, image_format)

View File

@@ -609,23 +609,23 @@ def _fix_table_constraints(table_name, model, constraints_to_fix):
fields = list(model._meta.fields.keys())
# Peewee 默认使用 'id' 作为主键字段名
# 尝试获取主键字段名,如果获取失败则默认使用 'id'
primary_key_name = 'id' # 默认值
primary_key_name = "id" # 默认值
try:
if hasattr(model._meta, 'primary_key') and model._meta.primary_key:
if hasattr(model._meta.primary_key, 'name'):
if hasattr(model._meta, "primary_key") and model._meta.primary_key:
if hasattr(model._meta.primary_key, "name"):
primary_key_name = model._meta.primary_key.name
elif isinstance(model._meta.primary_key, str):
primary_key_name = model._meta.primary_key
except Exception:
pass # 如果获取失败,使用默认值 'id'
# 如果字段列表包含主键,则排除它
if primary_key_name in fields:
fields_without_pk = [f for f in fields if f != primary_key_name]
logger.info(f"排除主键字段 '{primary_key_name}',让数据库自动生成新的主键")
else:
fields_without_pk = fields
fields_str = ", ".join(fields_without_pk)
# 检查是否有字段需要从 NULL 改为 NOT NULL

View File

@@ -34,7 +34,7 @@ def _format_toml_value(obj: Any, threshold: int, depth: int = 0) -> Any:
return obj
# 决定是否多行:仅在顶层且长度超过阈值时
should_multiline = (depth == 0 and len(obj) > threshold)
should_multiline = depth == 0 and len(obj) > threshold
# 如果已经是 tomlkit Array原地修改以保留注释
if isinstance(obj, Array):
@@ -46,7 +46,7 @@ def _format_toml_value(obj: Any, threshold: int, depth: int = 0) -> Any:
# 普通 list转换为 tomlkit 数组
arr = tomlkit.array()
arr.multiline(should_multiline)
for item in obj:
arr.append(_format_toml_value(item, threshold, depth + 1))
return arr
@@ -112,7 +112,7 @@ def save_toml_with_format(
formatted = _format_toml_value(data, multiline_threshold) if multiline_threshold >= 0 else data
output = tomlkit.dumps(formatted)
# 规范化:将 3+ 连续空行压缩为 1 个空行,防止空行累积
output = re.sub(r'\n{3,}', '\n\n', output)
output = re.sub(r"\n{3,}", "\n\n", output)
with open(file_path, "w", encoding="utf-8") as f:
f.write(output)
@@ -122,4 +122,4 @@ def format_toml_string(data: Any, multiline_threshold: int = 1) -> str:
formatted = _format_toml_value(data, multiline_threshold) if multiline_threshold >= 0 else data
output = tomlkit.dumps(formatted)
# 规范化:将 3+ 连续空行压缩为 1 个空行,防止空行累积
return re.sub(r'\n{3,}', '\n\n', output)
return re.sub(r"\n{3,}", "\n\n", output)

View File

@@ -778,9 +778,9 @@ class DreamConfig(ConfigBase):
"""
if not self.dream_time_ranges:
return True
now_min = self._now_minutes()
for time_range in self.dream_time_ranges:
if not isinstance(time_range, str):
continue
@@ -790,7 +790,7 @@ class DreamConfig(ConfigBase):
start_min, end_min = parsed
if self._in_range(now_min, start_min, end_min):
return True
return False
def __post_init__(self):
@@ -800,4 +800,4 @@ class DreamConfig(ConfigBase):
if self.max_iterations < 1:
raise ValueError(f"max_iterations 必须至少为1当前值: {self.max_iterations}")
if self.first_delay_seconds < 0:
raise ValueError(f"first_delay_seconds 不能为负数,当前值: {self.first_delay_seconds}")
raise ValueError(f"first_delay_seconds 不能为负数,当前值: {self.first_delay_seconds}")

View File

@@ -1,14 +1,13 @@
import asyncio
import random
import time
import json
from typing import Any, Dict, List, Optional, Tuple
from peewee import fn
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.common.database.database_model import ChatHistory, Jargon
from src.common.database.database_model import ChatHistory
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.llm_models.payload_content.message import MessageBuilder, RoleType, Message
from src.plugin_system.apis import llm_api
@@ -82,7 +81,6 @@ def init_dream_prompts() -> None:
)
class DreamTool:
"""dream 模块内部使用的简易工具封装"""
@@ -150,7 +148,13 @@ def init_dream_tools(chat_id: str) -> None:
"search_chat_history",
"根据关键词或参与人查询当前 chat_id 下的 ChatHistory 概览,便于快速定位相关记忆。",
[
("keyword", ToolParamType.STRING, "关键词(可选,支持多个关键词,可用空格、逗号等分隔)。", False, None),
(
"keyword",
ToolParamType.STRING,
"关键词(可选,支持多个关键词,可用空格、逗号等分隔)。",
False,
None,
),
("participant", ToolParamType.STRING, "参与人昵称(可选)。", False, None),
],
search_chat_history,
@@ -201,8 +205,20 @@ def init_dream_tools(chat_id: str) -> None:
[
("theme", ToolParamType.STRING, "新的主题标题(必填)。", True, None),
("summary", ToolParamType.STRING, "新的概括内容(必填)。", True, None),
("keywords", ToolParamType.STRING, "新的关键词 JSON 字符串,如 ['关键词1','关键词2'](必填)。", True, None),
("key_point", ToolParamType.STRING, "新的关键信息 JSON 字符串,如 ['要点1','要点2'](必填)。", True, None),
(
"keywords",
ToolParamType.STRING,
"新的关键词 JSON 字符串,如 ['关键词1','关键词2'](必填)。",
True,
None,
),
(
"key_point",
ToolParamType.STRING,
"新的关键信息 JSON 字符串,如 ['要点1','要点2'](必填)。",
True,
None,
),
("start_time", ToolParamType.STRING, "起始时间戳Unix 时间,必填)。", True, None),
("end_time", ToolParamType.STRING, "结束时间戳Unix 时间,必填)。", True, None),
],
@@ -215,7 +231,13 @@ def init_dream_tools(chat_id: str) -> None:
"finish_maintenance",
"结束本次 dream 维护任务。当你认为当前 chat_id 下的维护工作已经完成,没有更多需要整理、合并或修改的内容时,调用此工具来主动结束本次运行。",
[
("reason", ToolParamType.STRING, "结束维护的原因说明(可选),例如 '已完成所有记录的整理''当前记录质量良好,无需进一步维护'", False, None),
(
"reason",
ToolParamType.STRING,
"结束维护的原因说明(可选),例如 '已完成所有记录的整理''当前记录质量良好,无需进一步维护'",
False,
None,
),
],
finish_maintenance,
)
@@ -246,7 +268,7 @@ async def run_dream_agent_once(
"""
if max_iterations is None:
max_iterations = global_config.dream.max_iterations
start_ts = time.time()
logger.info(f"[dream] 开始对 chat_id={chat_id} 进行 dream 维护,最多迭代 {max_iterations}")
@@ -282,9 +304,7 @@ async def run_dream_agent_once(
else "未知"
)
end_time_str = (
time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(record.end_time))
if record.end_time
else "未知"
time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(record.end_time)) if record.end_time else "未知"
)
detail_text = (
f"ID={record.id}\n"
@@ -305,8 +325,7 @@ async def run_dream_agent_once(
start_detail_builder = MessageBuilder()
start_detail_builder.set_role(RoleType.User)
start_detail_builder.add_text_content(
"【起始记忆详情】以下是本轮随机/指定的起始记忆的详细信息,供你在整理时优先参考:\n\n"
+ detail_text
"【起始记忆详情】以下是本轮随机/指定的起始记忆的详细信息,供你在整理时优先参考:\n\n" + detail_text
)
conversation_messages.append(start_detail_builder.build())
else:
@@ -343,13 +362,17 @@ async def run_dream_agent_once(
conversation_messages.append(round_info_builder.build())
# 调用 LLM 让其决定是否要使用工具
success, response, reasoning_content, model_name, tool_calls = (
await llm_api.generate_with_model_with_tools_by_message_factory(
message_factory,
model_config=model_config.model_task_config.tool_use,
tool_options=tool_defs,
request_type="dream.react",
)
(
success,
response,
reasoning_content,
model_name,
tool_calls,
) = await llm_api.generate_with_model_with_tools_by_message_factory(
message_factory,
model_config=model_config.model_task_config.tool_use,
tool_options=tool_defs,
request_type="dream.react",
)
if not success:
@@ -522,7 +545,7 @@ async def start_dream_scheduler(
if interval_seconds is None:
interval_seconds = global_config.dream.interval_minutes * 60
logger.info(
f"[dream] dream 调度器启动:首次延迟 {first_delay_seconds}s之后每隔 {interval_seconds}s ({interval_seconds // 60} 分钟) 运行一次 dream agent"
)
@@ -555,4 +578,3 @@ async def start_dream_scheduler(
# 初始化提示词
init_dream_prompts()

View File

@@ -86,7 +86,7 @@ async def generate_dream_summary(
try:
import json
from src.chat.utils.prompt_builder import global_prompt_manager
# 第一步:建立工具调用结果映射 (call_id -> result)
tool_results_map: dict[str, str] = {}
for msg in conversation_messages:
@@ -98,11 +98,11 @@ async def generate_dream_summary(
else:
content = str(msg.content)
tool_results_map[msg.tool_call_id] = content
# 第二步:详细记录所有工具调用操作和结果到日志
tool_call_count = 0
logger.info(f"[dream][工具调用详情] 开始记录 chat_id={chat_id} 的所有工具调用操作:")
for msg in conversation_messages:
if msg.role == RoleType.Assistant and msg.tool_calls:
tool_call_count += 1
@@ -110,34 +110,38 @@ async def generate_dream_summary(
thought_content = ""
if msg.content:
if isinstance(msg.content, list) and msg.content:
thought_content = msg.content[0].text if hasattr(msg.content[0], "text") else str(msg.content[0])
thought_content = (
msg.content[0].text if hasattr(msg.content[0], "text") else str(msg.content[0])
)
else:
thought_content = str(msg.content)
logger.info(f"[dream][工具调用详情] === 第 {tool_call_count} 组工具调用 ===")
if thought_content:
logger.info(f"[dream][工具调用详情] 思考内容:{thought_content[:500]}{'...' if len(thought_content) > 500 else ''}")
logger.info(
f"[dream][工具调用详情] 思考内容:{thought_content[:500]}{'...' if len(thought_content) > 500 else ''}"
)
# 记录每个工具调用的详细信息
for idx, tool_call in enumerate(msg.tool_calls, 1):
tool_name = tool_call.func_name
tool_args = tool_call.args or {}
tool_call_id = tool_call.call_id
tool_result = tool_results_map.get(tool_call_id, "未找到执行结果")
# 格式化参数
try:
args_str = json.dumps(tool_args, ensure_ascii=False, indent=2) if tool_args else "无参数"
except Exception:
args_str = str(tool_args)
logger.info(f"[dream][工具调用详情] --- 工具 {idx}: {tool_name} ---")
logger.info(f"[dream][工具调用详情] 调用参数:\n{args_str}")
logger.info(f"[dream][工具调用详情] 执行结果:\n{tool_result}")
logger.info(f"[dream][工具调用详情] {'-' * 60}")
logger.info(f"[dream][工具调用详情] 共记录了 {tool_call_count} 组工具调用操作")
# 第三步:构建对话历史摘要(用于生成梦境)
conversation_summary = []
for msg in conversation_messages:
@@ -145,11 +149,11 @@ async def generate_dream_summary(
content = ""
if msg.content:
content = msg.content[0].text if isinstance(msg.content, list) and msg.content else str(msg.content)
if role == "user" and "轮次信息" in content:
# 跳过轮次信息消息
continue
if role == "assistant":
# 只保留思考内容,简化工具调用信息
if content:
@@ -162,13 +166,13 @@ async def generate_dream_summary(
# 截取前300字符
content_preview = content[:300] + ("..." if len(content) > 300 else "")
conversation_summary.append(f"[工具执行] {content_preview}")
conversation_text = "\n".join(conversation_summary[-20:]) # 只保留最后20条消息
# 随机选择2个梦境风格
selected_styles = get_random_dream_styles(2)
dream_styles_text = "\n".join([f"{i+1}. {style}" for i, style in enumerate(selected_styles)])
dream_styles_text = "\n".join([f"{i + 1}. {style}" for i, style in enumerate(selected_styles)])
# 使用 Prompt 管理器格式化梦境生成 prompt
dream_prompt = await global_prompt_manager.format_prompt(
"dream_summary_prompt",
@@ -186,13 +190,14 @@ async def generate_dream_summary(
max_tokens=512,
temperature=0.8,
)
if dream_content:
logger.info(f"[dream][梦境总结] 对 chat_id={chat_id} 的整理过程梦境:\n{dream_content}")
else:
logger.warning("[dream][梦境总结] 未能生成梦境总结")
except Exception as e:
logger.error(f"[dream][梦境总结] 生成梦境总结失败: {e}", exc_info=True)
init_dream_summary_prompt()
init_dream_summary_prompt()

View File

@@ -4,8 +4,3 @@ dream agent 工具实现模块。
每个工具的具体实现放在独立文件中,通过 make_xxx(chat_id) 工厂函数
生成绑定到特定 chat_id 的协程函数,由 dream_agent.init_dream_tools 统一注册。
"""

View File

@@ -60,8 +60,3 @@ def make_create_chat_history(chat_id: str):
return f"create_chat_history 执行失败: {e}"
return create_chat_history

View File

@@ -23,8 +23,3 @@ def make_delete_chat_history(chat_id: str): # chat_id 目前未直接使用,
return f"delete_chat_history 执行失败: {e}"
return delete_chat_history

View File

@@ -23,8 +23,3 @@ def make_delete_jargon(chat_id: str): # chat_id 目前未直接使用,预留
return f"delete_jargon 执行失败: {e}"
return delete_jargon

View File

@@ -14,8 +14,3 @@ def make_finish_maintenance(chat_id: str): # chat_id 目前未直接使用,
return msg
return finish_maintenance

View File

@@ -1,5 +1,4 @@
import time
from typing import Optional
from src.common.logger import get_logger
from src.common.database.database_model import ChatHistory
@@ -20,14 +19,10 @@ def make_get_chat_history_detail(chat_id: str): # chat_id 目前未直接使用
# 将时间戳转换为可读时间格式
start_time_str = (
time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(record.start_time))
if record.start_time
else "未知"
time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(record.start_time)) if record.start_time else "未知"
)
end_time_str = (
time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(record.end_time))
if record.end_time
else "未知"
time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(record.end_time)) if record.end_time else "未知"
)
result = (
@@ -40,17 +35,10 @@ def make_get_chat_history_detail(chat_id: str): # chat_id 目前未直接使用
f"概括={record.summary or ''}\n"
f"关键信息={record.key_point or ''}"
)
logger.debug(
f"[dream][tool] get_chat_history_detail 成功,预览: {result[:200].replace(chr(10), ' ')}"
)
logger.debug(f"[dream][tool] get_chat_history_detail 成功,预览: {result[:200].replace(chr(10), ' ')}")
return result
except Exception as e:
logger.error(f"get_chat_history_detail 失败: {e}")
return f"get_chat_history_detail 执行失败: {e}"
return get_chat_history_detail

View File

@@ -78,9 +78,7 @@ def make_search_chat_history(chat_id: str):
if record.keywords:
try:
keywords_data = (
json.loads(record.keywords)
if isinstance(record.keywords, str)
else record.keywords
json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
)
if isinstance(keywords_data, list):
record_keywords_list = [str(k).lower() for k in keywords_data]
@@ -125,9 +123,7 @@ def make_search_chat_history(chat_id: str):
keywords_str = "".join(keywords_list)
if len(keywords_list) > 2:
required_count = len(keywords_list) - 1
return (
f"未找到包含至少{required_count}个关键词(共{len(keywords_list)}个)'{keywords_str}'的聊天记录"
)
return f"未找到包含至少{required_count}个关键词(共{len(keywords_list)}个)'{keywords_str}'的聊天记录"
else:
return f"未找到包含所有关键词'{keywords_str}'的聊天记录"
elif participant:
@@ -142,9 +138,7 @@ def make_search_chat_history(chat_id: str):
if record.keywords:
try:
keywords_data = (
json.loads(record.keywords)
if isinstance(record.keywords, str)
else record.keywords
json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
)
if isinstance(keywords_data, list):
for k in keywords_data:
@@ -160,13 +154,13 @@ def make_search_chat_history(chat_id: str):
keywords_str = "".join(sorted(all_keywords_set))
response_text = (
f"包含“{search_label}”的结果过多,请尝试更多关键词精确查找\n\n"
f"有关\"{search_label}\"的关键词:\n"
f'有关"{search_label}"的关键词:\n'
f"{keywords_str}"
)
else:
response_text = (
f"包含“{search_label}”的结果过多,请尝试更多关键词精确查找\n\n"
f"有关\"{search_label}\"的关键词信息为空"
f'有关"{search_label}"的关键词信息为空'
)
logger.info(
@@ -192,9 +186,7 @@ def make_search_chat_history(chat_id: str):
if record.keywords:
try:
keywords_data = (
json.loads(record.keywords)
if isinstance(record.keywords, str)
else record.keywords
json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
)
if isinstance(keywords_data, list) and keywords_data:
keywords_str = "".join([str(k) for k in keywords_data])
@@ -220,8 +212,3 @@ def make_search_chat_history(chat_id: str):
return f"search_chat_history 执行失败: {e}"
return search_chat_history

View File

@@ -16,9 +16,7 @@ def make_search_jargon(chat_id: str):
if not keyword or not keyword.strip():
return "未指定查询关键词(参数 keyword 为必填,且不能为空)"
logger.info(
f"[dream][tool] 调用 search_jargon(keyword={keyword}) (作用域 chat_id={chat_id})"
)
logger.info(f"[dream][tool] 调用 search_jargon(keyword={keyword}) (作用域 chat_id={chat_id})")
# 基础条件:只查 is_jargon=True 的记录
query = Jargon.select().where(Jargon.is_jargon)
@@ -102,5 +100,3 @@ def make_search_jargon(chat_id: str):
return f"search_jargon 执行失败: {e}"
return search_jargon

View File

@@ -49,8 +49,3 @@ def make_update_chat_history(chat_id: str): # chat_id 目前未直接使用,
return f"update_chat_history 执行失败: {e}"
return update_chat_history

View File

@@ -49,8 +49,3 @@ def make_update_jargon(chat_id: str): # chat_id 目前未直接使用,预留
return f"update_jargon 执行失败: {e}"
return update_jargon

View File

@@ -316,7 +316,9 @@ class ChatHistorySummarizer:
before_count = len(self.current_batch.messages)
self.current_batch.messages.extend(new_messages)
self.current_batch.end_time = current_time
logger.info(f"{self.log_prefix} 更新聊天检查批次: {before_count} -> {len(self.current_batch.messages)} 条消息")
logger.info(
f"{self.log_prefix} 更新聊天检查批次: {before_count} -> {len(self.current_batch.messages)} 条消息"
)
# 更新批次后持久化
self._persist_topic_cache()
else:
@@ -362,9 +364,7 @@ class ChatHistorySummarizer:
else:
time_str = f"{time_since_last_check / 3600:.1f}小时"
logger.debug(
f"{self.log_prefix} 批次状态检查 | 消息数: {message_count} | 距上次检查: {time_str}"
)
logger.debug(f"{self.log_prefix} 批次状态检查 | 消息数: {message_count} | 距上次检查: {time_str}")
# 检查“话题检查”触发条件
should_check = False
@@ -414,7 +414,7 @@ class ChatHistorySummarizer:
# 说明 bot 没有参与这段对话,不应该记录
bot_user_id = str(global_config.bot.qq_account)
has_bot_message = False
for msg in messages:
if msg.user_info.user_id == bot_user_id:
has_bot_message = True
@@ -427,7 +427,9 @@ class ChatHistorySummarizer:
return
# 2. 构造编号后的消息字符串和参与者信息
numbered_lines, index_to_msg_str, index_to_msg_text, index_to_participants = self._build_numbered_messages_for_llm(messages)
numbered_lines, index_to_msg_str, index_to_msg_text, index_to_participants = (
self._build_numbered_messages_for_llm(messages)
)
# 3. 调用 LLM 识别话题,并得到 topic -> indices失败时最多重试 3 次)
existing_topics = list(self.topic_cache.keys())
@@ -456,9 +458,7 @@ class ChatHistorySummarizer:
)
if not success or not topic_to_indices:
logger.error(
f"{self.log_prefix} 话题识别连续 {max_retries} 次失败或始终无有效话题,本次检查放弃"
)
logger.error(f"{self.log_prefix} 话题识别连续 {max_retries} 次失败或始终无有效话题,本次检查放弃")
# 即使识别失败,也认为是一次“检查”,但不更新 no_update_checks保持原状
return
@@ -610,9 +610,7 @@ class ChatHistorySummarizer:
if not numbered_lines:
return False, {}
history_topics_block = (
"\n".join(f"- {t}" for t in existing_topics) if existing_topics else "(当前无历史话题)"
)
history_topics_block = "\n".join(f"- {t}" for t in existing_topics) if existing_topics else "(当前无历史话题)"
messages_block = "\n".join(numbered_lines)
prompt = await global_prompt_manager.format_prompt(
@@ -635,17 +633,17 @@ class ChatHistorySummarizer:
json_str = None
json_pattern = r"```json\s*(.*?)\s*```"
matches = re.findall(json_pattern, response, re.DOTALL)
if matches:
# 找到JSON代码块使用第一个匹配
json_str = matches[0].strip()
else:
# 如果没有找到代码块尝试查找JSON数组的开始和结束位置
# 查找第一个 [ 和最后一个 ]
start_idx = response.find('[')
end_idx = response.rfind(']')
start_idx = response.find("[")
end_idx = response.rfind("]")
if start_idx != -1 and end_idx != -1 and end_idx > start_idx:
json_str = response[start_idx:end_idx + 1].strip()
json_str = response[start_idx : end_idx + 1].strip()
else:
# 如果还是找不到尝试直接使用整个响应移除可能的markdown标记
json_str = response.strip()
@@ -942,4 +940,3 @@ class ChatHistorySummarizer:
init_prompt()

View File

@@ -98,7 +98,10 @@ def _convert_messages(
content: List[Part] = []
for item in message.content:
if isinstance(item, tuple):
image_format = "jpeg" if item[0].lower() == "jpg" else item[0].lower()
image_format = item[0].lower()
# 规范 JPEG MIME 类型后缀,统一使用 image/jpeg
if image_format in ("jpg", "jpeg"):
image_format = "jpeg"
content.append(Part.from_bytes(data=base64.b64decode(item[1]), mime_type=f"image/{image_format}"))
elif isinstance(item, str):
content.append(Part.from_text(text=item))

View File

@@ -61,10 +61,16 @@ def _convert_messages(messages: list[Message]) -> list[ChatCompletionMessagePara
content = []
for item in message.content:
if isinstance(item, tuple):
image_format = item[0].lower()
# 规范 JPEG MIME 类型后缀,统一使用 image/jpeg
if image_format in ("jpg", "jpeg"):
mime_suffix = "jpeg"
else:
mime_suffix = image_format
content.append(
{
"type": "image_url",
"image_url": {"url": f"data:image/{item[0].lower()};base64,{item[1]}"},
"image_url": {"url": f"data:image/{mime_suffix};base64,{item[1]}"},
}
)
elif isinstance(item, str):

View File

@@ -49,7 +49,7 @@ class LLMRequest:
def _check_slow_request(self, time_cost: float, model_name: str) -> None:
"""检查请求是否过慢并输出警告日志
Args:
time_cost: 请求耗时(秒)
model_name: 使用的模型名称
@@ -323,7 +323,7 @@ class LLMRequest:
effective_temperature = (model_info.extra_params or {}).get("temperature")
if effective_temperature is None:
effective_temperature = self.model_for_task.temperature
# max_tokens 优先级:参数传入 > 模型级别配置 > extra_params > 任务配置
effective_max_tokens = max_tokens
if effective_max_tokens is None:
@@ -332,7 +332,7 @@ class LLMRequest:
effective_max_tokens = (model_info.extra_params or {}).get("max_tokens")
if effective_max_tokens is None:
effective_max_tokens = self.model_for_task.max_tokens
return await client.get_response(
model_info=model_info,
message_list=(compressed_messages or message_list),
@@ -366,7 +366,9 @@ class LLMRequest:
logger.error(f"模型 '{model_info.name}' 在多次出现空回复后仍然失败。{original_error_info}")
raise ModelAttemptFailed(f"模型 '{model_info.name}' 重试耗尽", original_exception=e) from e
logger.warning(f"模型 '{model_info.name}' 返回空回复(可重试){original_error_info}。剩余重试次数: {retry_remain}")
logger.warning(
f"模型 '{model_info.name}' 返回空回复(可重试){original_error_info}。剩余重试次数: {retry_remain}"
)
await asyncio.sleep(api_provider.retry_interval)
except NetworkConnectionError as e:
@@ -394,7 +396,9 @@ class LLMRequest:
if e.status_code == 429 or e.status_code >= 500:
retry_remain -= 1
if retry_remain <= 0:
logger.error(f"模型 '{model_info.name}' 在遇到 {e.status_code} 错误并用尽重试次数后仍然失败。{original_error_info}")
logger.error(
f"模型 '{model_info.name}' 在遇到 {e.status_code} 错误并用尽重试次数后仍然失败。{original_error_info}"
)
raise ModelAttemptFailed(f"模型 '{model_info.name}' 重试耗尽", original_exception=e) from e
logger.warning(
@@ -540,7 +544,5 @@ class LLMRequest:
if e.__cause__:
original_error_type = type(e.__cause__).__name__
original_error_msg = str(e.__cause__)
return (
f"\n 底层异常类型: {original_error_type}\n 底层异常信息: {original_error_msg}"
)
return f"\n 底层异常类型: {original_error_type}\n 底层异常信息: {original_error_msg}"
return ""

View File

@@ -113,7 +113,6 @@ class MainSystem:
get_emoji_manager().initialize()
logger.info("表情包管理器初始化成功")
# 初始化聊天管理器
await get_chat_manager()._initialize()
asyncio.create_task(get_chat_manager()._auto_save_task())

View File

@@ -136,8 +136,6 @@ def init_memory_retrieval_prompt():
)
def _log_conversation_messages(
conversation_messages: List[Message],
head_prompt: Optional[str] = None,
@@ -172,7 +170,9 @@ def _log_conversation_messages(
# 构建单条消息的日志信息
# msg_info = f"\n========================================\n[消息 {idx}] 角色: {role_name} 内容类型: {content_type}\n-----------------------------"
msg_info = f"\n========================================\n[消息 {idx}] 角色: {role_name}\n-----------------------------"
msg_info = (
f"\n========================================\n[消息 {idx}] 角色: {role_name}\n-----------------------------"
)
# if full_content:
# msg_info += f"\n{full_content}"
@@ -185,8 +185,7 @@ def _log_conversation_messages(
msg_info += f"\n - {tool_call.func_name}: {json.dumps(tool_call.args, ensure_ascii=False)}"
# if msg.tool_call_id:
# msg_info += f"\n 工具调用ID: {msg.tool_call_id}"
# msg_info += f"\n 工具调用ID: {msg.tool_call_id}"
log_lines.append(msg_info)
@@ -330,7 +329,7 @@ async def _react_agent_solve_question(
remaining_iterations=remaining_iterations,
max_iterations=max_iterations,
)
# 后续迭代都复用第一次构建的head_prompt
head_prompt = first_head_prompt
@@ -365,7 +364,7 @@ async def _react_agent_solve_question(
)
# logger.info(
# f"ReAct Agent 第 {iteration + 1} 次迭代 模型: {model_name} ,调用工具数量: {len(tool_calls) if tool_calls else 0} ,调用工具响应: {response}"
# f"ReAct Agent 第 {iteration + 1} 次迭代 模型: {model_name} ,调用工具数量: {len(tool_calls) if tool_calls else 0} ,调用工具响应: {response}"
# )
if not success:
@@ -409,20 +408,20 @@ async def _react_agent_solve_question(
"""从文本中解析finish_search函数调用返回(found_answer, answer)元组,如果未找到则返回(None, None)"""
if not text:
return None, None
# 查找finish_search函数调用位置不区分大小写
func_pattern = "finish_search"
text_lower = text.lower()
func_pos = text_lower.find(func_pattern)
if func_pos == -1:
return None, None
# 查找函数调用的开始和结束位置
# 从func_pos开始向后查找左括号
start_pos = text.find("(", func_pos)
if start_pos == -1:
return None, None
# 查找匹配的右括号(考虑嵌套)
paren_count = 0
end_pos = start_pos
@@ -437,10 +436,10 @@ async def _react_agent_solve_question(
else:
# 没有找到匹配的右括号
return None, None
# 提取函数参数部分
params_text = text[start_pos + 1 : end_pos]
# 解析found_answer参数布尔值可能是true/false/True/False
found_answer = None
found_answer_patterns = [
@@ -454,49 +453,60 @@ async def _react_agent_solve_question(
if match:
found_answer = "true" in match.group(0).lower()
break
# 解析answer参数字符串使用extract_quoted_content
answer = extract_quoted_content(text, "finish_search", "answer")
return found_answer, answer
parsed_found_answer, parsed_answer = parse_finish_search_from_text(response)
if parsed_found_answer is not None:
# 检测到finish_search函数调用格式
if parsed_found_answer:
# 找到了答案
if parsed_answer:
step["actions"].append({"action_type": "finish_search", "action_params": {"found_answer": True, "answer": parsed_answer}})
step["actions"].append(
{
"action_type": "finish_search",
"action_params": {"found_answer": True, "answer": parsed_answer},
}
)
step["observations"] = ["检测到finish_search文本格式调用找到答案"]
thinking_steps.append(step)
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 通过finish_search文本格式找到关于问题{question}的答案: {parsed_answer}")
logger.info(
f"ReAct Agent 第 {iteration + 1} 次迭代 通过finish_search文本格式找到关于问题{question}的答案: {parsed_answer}"
)
_log_conversation_messages(
conversation_messages,
head_prompt=first_head_prompt,
final_status=f"找到答案:{parsed_answer}",
)
return True, parsed_answer, thinking_steps, False
else:
# found_answer为True但没有提供answer视为错误继续迭代
logger.warning(f"ReAct Agent 第 {iteration + 1} 次迭代 finish_search文本格式found_answer为True但未提供answer")
logger.warning(
f"ReAct Agent 第 {iteration + 1} 次迭代 finish_search文本格式found_answer为True但未提供answer"
)
else:
# 未找到答案,直接退出查询
step["actions"].append({"action_type": "finish_search", "action_params": {"found_answer": False}})
step["actions"].append(
{"action_type": "finish_search", "action_params": {"found_answer": False}}
)
step["observations"] = ["检测到finish_search文本格式调用未找到答案"]
thinking_steps.append(step)
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 通过finish_search文本格式判断未找到答案")
_log_conversation_messages(
conversation_messages,
head_prompt=first_head_prompt,
final_status="未找到答案通过finish_search文本格式判断未找到答案",
)
return False, "", thinking_steps, False
# 如果没有检测到finish_search格式记录思考过程继续下一轮迭代
step["observations"] = [f"思考完成,但未调用工具。响应: {response}"]
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 思考完成但未调用工具: {response}")
@@ -514,44 +524,53 @@ async def _react_agent_solve_question(
for tool_call in tool_calls:
tool_name = tool_call.func_name
tool_args = tool_call.args or {}
if tool_name == "finish_search":
finish_search_found = tool_args.get("found_answer", False)
finish_search_answer = tool_args.get("answer", "")
if finish_search_found:
# 找到了答案
if finish_search_answer:
step["actions"].append({"action_type": "finish_search", "action_params": {"found_answer": True, "answer": finish_search_answer}})
step["actions"].append(
{
"action_type": "finish_search",
"action_params": {"found_answer": True, "answer": finish_search_answer},
}
)
step["observations"] = ["检测到finish_search工具调用找到答案"]
thinking_steps.append(step)
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 通过finish_search工具找到关于问题{question}的答案: {finish_search_answer}")
logger.info(
f"ReAct Agent 第 {iteration + 1} 次迭代 通过finish_search工具找到关于问题{question}的答案: {finish_search_answer}"
)
_log_conversation_messages(
conversation_messages,
head_prompt=first_head_prompt,
final_status=f"找到答案:{finish_search_answer}",
)
return True, finish_search_answer, thinking_steps, False
else:
# found_answer为True但没有提供answer视为错误
logger.warning(f"ReAct Agent 第 {iteration + 1} 次迭代 finish_search工具found_answer为True但未提供answer")
logger.warning(
f"ReAct Agent 第 {iteration + 1} 次迭代 finish_search工具found_answer为True但未提供answer"
)
else:
# 未找到答案,直接退出查询
step["actions"].append({"action_type": "finish_search", "action_params": {"found_answer": False}})
step["observations"] = ["检测到finish_search工具调用未找到答案"]
thinking_steps.append(step)
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 通过finish_search工具判断未找到答案")
_log_conversation_messages(
conversation_messages,
head_prompt=first_head_prompt,
final_status="未找到答案通过finish_search工具判断未找到答案",
)
return False, "", thinking_steps, False
# 如果没有finish_search工具调用继续处理其他工具
tool_tasks = []
for i, tool_call in enumerate(tool_calls):
@@ -627,7 +646,7 @@ async def _react_agent_solve_question(
observation_text += f"\n\n{jargon_info}"
collected_info += f"\n{jargon_info}\n"
logger.info(f"工具输出触发黑话解析: {new_concepts}")
tool_builder = MessageBuilder()
tool_builder.set_role(RoleType.Tool)
tool_builder.add_text_content(observation_text)
@@ -645,7 +664,7 @@ async def _react_agent_solve_question(
elif iteration + 1 >= max_iterations:
should_do_final_evaluation = True
logger.info(f"ReAct Agent达到最大迭代次数已迭代{iteration + 1}次),进入最终评估")
if should_do_final_evaluation:
# 获取必要变量用于最终评估
tool_registry = get_tool_registry()
@@ -653,7 +672,7 @@ async def _react_agent_solve_question(
time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
current_iteration = iteration + 1
remaining_iterations = 0
# 提取函数调用中参数的值,支持单引号和双引号
def extract_quoted_content(text, func_name, param_name):
"""从文本中提取函数调用中参数的值,支持单引号和双引号
@@ -724,7 +743,13 @@ async def _react_agent_solve_question(
max_iterations=max_iterations,
)
eval_success, eval_response, eval_reasoning_content, eval_model_name, eval_tool_calls = await llm_api.generate_with_model_with_tools(
(
eval_success,
eval_response,
eval_reasoning_content,
eval_model_name,
eval_tool_calls,
) = await llm_api.generate_with_model_with_tools(
evaluation_prompt,
model_config=model_config.model_task_config.tool_use,
tool_options=[], # 最终评估阶段不提供工具
@@ -739,7 +764,7 @@ async def _react_agent_solve_question(
final_status="未找到答案最终评估阶段LLM调用失败",
)
return False, "最终评估阶段LLM调用失败", thinking_steps, is_timeout
if global_config.debug.show_memory_prompt:
logger.info(f"ReAct Agent 最终评估Prompt: {evaluation_prompt}")
logger.info(f"ReAct Agent 最终评估响应: {eval_response}")
@@ -759,17 +784,17 @@ async def _react_agent_solve_question(
"iteration": current_iteration,
"thought": f"[最终评估] {eval_response}",
"actions": [{"action_type": "found_answer", "action_params": {"answer": found_answer_content}}],
"observations": ["最终评估阶段检测到found_answer"]
"observations": ["最终评估阶段检测到found_answer"],
}
thinking_steps.append(eval_step)
logger.info(f"ReAct Agent 最终评估阶段找到关于问题{question}的答案: {found_answer_content}")
_log_conversation_messages(
conversation_messages,
head_prompt=first_head_prompt,
final_status=f"找到答案:{found_answer_content}",
)
return True, found_answer_content, thinking_steps, False
# 如果评估为not_enough_info返回空字符串不返回任何信息
@@ -778,35 +803,37 @@ async def _react_agent_solve_question(
"iteration": current_iteration,
"thought": f"[最终评估] {eval_response}",
"actions": [{"action_type": "not_enough_info", "action_params": {"reason": not_enough_info_reason}}],
"observations": ["最终评估阶段检测到not_enough_info"]
"observations": ["最终评估阶段检测到not_enough_info"],
}
thinking_steps.append(eval_step)
logger.info(f"ReAct Agent 最终评估阶段判断信息不足: {not_enough_info_reason}")
_log_conversation_messages(
conversation_messages,
head_prompt=first_head_prompt,
final_status=f"未找到答案:{not_enough_info_reason}",
)
return False, "", thinking_steps, is_timeout
# 如果没有明确判断视为not_enough_info返回空字符串不返回任何信息
eval_step = {
"iteration": current_iteration,
"thought": f"[最终评估] {eval_response}",
"actions": [{"action_type": "not_enough_info", "action_params": {"reason": "已到达最大迭代次数,无法找到答案"}}],
"observations": ["已到达最大迭代次数,无法找到答案"]
"actions": [
{"action_type": "not_enough_info", "action_params": {"reason": "已到达最大迭代次数,无法找到答案"}}
],
"observations": ["已到达最大迭代次数,无法找到答案"],
}
thinking_steps.append(eval_step)
logger.info("ReAct Agent 已到达最大迭代次数,无法找到答案")
_log_conversation_messages(
conversation_messages,
head_prompt=first_head_prompt,
final_status="未找到答案:已到达最大迭代次数,无法找到答案",
)
return False, "", thinking_steps, is_timeout
# 如果正常迭代过程中提前找到答案返回,不会到达这里
@@ -817,7 +844,7 @@ async def _react_agent_solve_question(
head_prompt=first_head_prompt,
final_status="未找到答案:正常迭代结束",
)
return False, "", thinking_steps, is_timeout
@@ -1129,7 +1156,9 @@ async def build_memory_retrieval_prompt(
else:
max_iterations = base_max_iterations
timeout_seconds = global_config.memory.agent_timeout_seconds
logger.debug(f"问题数量: {len(questions)}think_level={think_level},设置最大迭代次数: {max_iterations}(基础值: {base_max_iterations}),超时时间: {timeout_seconds}")
logger.debug(
f"问题数量: {len(questions)}think_level={think_level},设置最大迭代次数: {max_iterations}(基础值: {base_max_iterations}),超时时间: {timeout_seconds}"
)
# 并行处理所有问题,将概念检索结果作为初始信息传递
question_tasks = [
@@ -1157,10 +1186,10 @@ async def build_memory_retrieval_prompt(
# 获取最近10分钟内已找到答案的缓存记录
cached_answers = _get_recent_found_answers(chat_id, time_window_seconds=600.0)
# 合并当前查询结果和缓存答案(去重:如果当前查询的问题在缓存中已存在,优先使用当前结果)
all_results = []
# 先添加当前查询的结果
current_questions = set()
for result in question_results:
@@ -1170,7 +1199,7 @@ async def build_memory_retrieval_prompt(
if question_end != -1:
current_questions.add(result[4:question_end])
all_results.append(result)
# 添加缓存答案(排除当前查询中已存在的问题)
for cached_answer in cached_answers:
if cached_answer.startswith("问题:"):
@@ -1198,4 +1227,3 @@ async def build_memory_retrieval_prompt(
except Exception as e:
logger.error(f"记忆检索时发生异常: {str(e)}")
return ""

View File

@@ -17,7 +17,6 @@ from src.common.logger import get_logger
logger = get_logger("memory_utils")
def parse_questions_json(response: str) -> Tuple[List[str], List[str]]:
"""解析问题JSON返回概念列表和问题列表
@@ -68,6 +67,7 @@ def parse_questions_json(response: str) -> Tuple[List[str], List[str]]:
logger.error(f"解析问题JSON失败: {e}, 响应内容: {response[:200]}...")
return [], []
def parse_datetime_to_timestamp(value: str) -> float:
"""
接受多种常见格式并转换为时间戳(秒)

View File

@@ -47,4 +47,3 @@ def register_tool():
],
execute_func=finish_search,
)

View File

@@ -16,9 +16,7 @@ from .tool_registry import register_memory_retrieval_tool
logger = get_logger("memory_retrieval_tools")
async def search_chat_history(
chat_id: str, keyword: Optional[str] = None, participant: Optional[str] = None
) -> str:
async def search_chat_history(chat_id: str, keyword: Optional[str] = None, participant: Optional[str] = None) -> str:
"""根据关键词或参与人查询记忆返回匹配的记忆id、记忆标题theme和关键词keywords
Args:
@@ -117,7 +115,7 @@ async def search_chat_history(
)
if kw_matched:
matched_count += 1
# 计算需要匹配的关键词数量
total_keywords = len(keywords_lower)
if total_keywords > 2:
@@ -126,7 +124,7 @@ async def search_chat_history(
else:
# 关键词数量<=2必须全部匹配
required_matches = total_keywords
keyword_matched = matched_count >= required_matches
# 两者都匹配如果同时有participant和keyword需要两者都匹配如果只有一个条件只需要该条件匹配
@@ -144,7 +142,9 @@ async def search_chat_history(
keywords_list = parse_keywords_string(keyword)
if len(keywords_list) > 2:
required_count = len(keywords_list) - 1
return f"未找到包含至少{required_count}个关键词(共{len(keywords_list)}个)'{keywords_str}'的聊天记录"
return (
f"未找到包含至少{required_count}个关键词(共{len(keywords_list)}个)'{keywords_str}'的聊天记录"
)
else:
return f"未找到包含所有关键词'{keywords_str}'的聊天记录"
elif participant:
@@ -160,9 +160,7 @@ async def search_chat_history(
if record.keywords:
try:
keywords_data = (
json.loads(record.keywords)
if isinstance(record.keywords, str)
else record.keywords
json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
)
if isinstance(keywords_data, list):
for k in keywords_data:
@@ -179,13 +177,12 @@ async def search_chat_history(
keywords_str = "".join(sorted(all_keywords_set))
return (
f"包含“{search_label}”的结果过多,请尝试更多关键词精确查找\n\n"
f"有关\"{search_label}\"的关键词:\n"
f'有关"{search_label}"的关键词:\n'
f"{keywords_str}"
)
else:
return (
f"包含“{search_label}”的结果过多,请尝试更多关键词精确查找\n\n"
f"有关\"{search_label}\"的关键词信息为空"
f'包含“{search_label}”的结果过多,请尝试更多关键词精确查找\n\n有关"{search_label}"的关键词信息为空'
)
# 构建结果文本返回id、theme和keywords最多20条

View File

@@ -11,6 +11,9 @@ from .base import (
BaseCommand,
BaseTool,
ConfigField,
ConfigSection,
ConfigLayout,
ConfigTab,
ComponentType,
ActionActivationType,
ChatMode,
@@ -116,6 +119,9 @@ __all__ = [
# 装饰器
"register_plugin",
"ConfigField",
"ConfigSection",
"ConfigLayout",
"ConfigTab",
# 工具函数
"ManifestValidator",
"get_logger",

View File

@@ -29,7 +29,7 @@ from .component_types import (
ForwardNode,
ReplySetModel,
)
from .config_types import ConfigField
from .config_types import ConfigField, ConfigSection, ConfigLayout, ConfigTab
__all__ = [
"BasePlugin",
@@ -46,6 +46,9 @@ __all__ = [
"PluginInfo",
"PythonDependency",
"ConfigField",
"ConfigSection",
"ConfigLayout",
"ConfigTab",
"EventHandlerInfo",
"EventType",
"BaseEventHandler",

View File

@@ -70,6 +70,12 @@ class ConfigField:
depends_on: Optional[str] = None # 依赖的字段路径,如 "section.field"
depends_value: Any = None # 依赖字段需要的值(当依赖字段等于此值时显示)
# === 列表类型专用 ===
item_type: Optional[str] = None # 数组元素类型: "string", "number", "object"
item_fields: Optional[Dict[str, Any]] = None # 当 item_type="object" 时,定义对象的字段结构
min_items: Optional[int] = None # 数组最小元素数量
max_items: Optional[int] = None # 数组最大元素数量
def get_ui_type(self) -> str:
"""
获取 UI 控件类型
@@ -132,6 +138,10 @@ class ConfigField:
"group": self.group,
"depends_on": self.depends_on,
"depends_value": self.depends_value,
"item_type": self.item_type,
"item_fields": self.item_fields,
"min_items": self.min_items,
"max_items": self.max_items,
}

784
src/webui/anti_crawler.py Normal file
View File

@@ -0,0 +1,784 @@
"""
WebUI 防爬虫模块
提供爬虫检测和阻止功能,保护 WebUI 不被搜索引擎和恶意爬虫访问
"""
import os
import time
import ipaddress
import re
from collections import deque
from typing import Optional
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import PlainTextResponse
from src.common.logger import get_logger
logger = get_logger("webui.anti_crawler")
# 常见爬虫 User-Agent 列表(使用更精确的关键词,避免误报)
CRAWLER_USER_AGENTS = {
# 搜索引擎爬虫(精确匹配)
"googlebot",
"bingbot",
"baiduspider",
"yandexbot",
"slurp", # Yahoo
"duckduckbot",
"sogou",
"exabot",
"facebot",
"ia_archiver", # Internet Archive
# 通用爬虫(移除过于宽泛的关键词)
"crawler",
"spider",
"scraper",
"wget", # 保留wget因为通常用于自动化脚本
"scrapy", # 保留scrapy因为这是爬虫框架
# 安全扫描工具(这些是明确的扫描工具)
"masscan",
"nmap",
"nikto",
"sqlmap",
# 注意:移除了以下过于宽泛的关键词以避免误报:
# - "bot" (会误匹配GitHub-Robot等)
# - "curl" (正常工具)
# - "python-requests" (正常库)
# - "httpx" (正常库)
# - "aiohttp" (正常库)
}
# 资产测绘工具 User-Agent 标识
ASSET_SCANNER_USER_AGENTS = {
# 知名资产测绘平台
"shodan",
"censys",
"zoomeye",
"fofa",
"quake",
"hunter",
"binaryedge",
"onyphe",
"securitytrails",
"virustotal",
"passivetotal",
# 安全扫描工具
"acunetix",
"appscan",
"burpsuite",
"nessus",
"openvas",
"qualys",
"rapid7",
"tenable",
"veracode",
"zap",
"awvs", # Acunetix Web Vulnerability Scanner
"netsparker",
"skipfish",
"w3af",
"arachni",
# 其他扫描工具
"masscan",
"zmap",
"nmap",
"whatweb",
"wpscan",
"joomscan",
"dnsenum",
"subfinder",
"amass",
"sublist3r",
"theharvester",
}
# 资产测绘工具常用的HTTP头标识
ASSET_SCANNER_HEADERS = {
# 常见的扫描工具自定义头
"x-scan": {"shodan", "censys", "zoomeye", "fofa"},
"x-scanner": {"nmap", "masscan", "zmap"},
"x-probe": {"masscan", "zmap"},
# 其他可疑头(移除反向代理标准头)
"x-originating-ip": set(),
"x-remote-ip": set(),
"x-remote-addr": set(),
# 注意:移除了以下反向代理标准头以避免误报:
# - "x-forwarded-proto" (反向代理标准头)
# - "x-real-ip" (反向代理标准头已在_get_client_ip中使用)
}
# 仅检查特定HTTP头中的可疑模式收紧匹配范围
# 只检查这些特定头,不检查所有头
SCANNER_SPECIFIC_HEADERS = {
"x-scan",
"x-scanner",
"x-probe",
"x-originating-ip",
"x-remote-ip",
"x-remote-addr",
}
# 防爬虫模式配置
# false: 禁用
# strict: 严格模式(更严格的检测,更低的频率限制)
# loose: 宽松模式(较宽松的检测,较高的频率限制)
# basic: 基础模式只记录恶意访问不阻止不限制请求数不跟踪IP
ANTI_CRAWLER_MODE = os.getenv("WEBUI_ANTI_CRAWLER_MODE", "basic").lower()
# IP白名单配置从环境变量读取逗号分隔
# 支持格式:
# - 精确IP127.0.0.1, 192.168.1.100
# - CIDR格式192.168.1.0/24, 172.17.0.0/16 (适用于Docker网络)
# - 通配符192.168.*.*, 10.*.*.*, *.*.*.* (匹配所有)
# - IPv6::1, 2001:db8::/32
def _parse_allowed_ips(ip_string: str) -> list:
"""
解析IP白名单字符串支持精确IP、CIDR格式和通配符
Args:
ip_string: 逗号分隔的IP字符串
Returns:
IP白名单列表每个元素可能是
- ipaddress.IPv4Network/IPv6Network对象CIDR格式
- ipaddress.IPv4Address/IPv6Address对象精确IP
- str通配符模式已转换为正则表达式
"""
allowed = []
if not ip_string:
return allowed
for ip_entry in ip_string.split(","):
ip_entry = ip_entry.strip() # 去除空格
if not ip_entry:
continue
# 检查通配符格式(包含*
if "*" in ip_entry:
# 处理通配符
pattern = _convert_wildcard_to_regex(ip_entry)
if pattern:
allowed.append(pattern)
else:
logger.warning(f"无效的通配符IP格式已忽略: {ip_entry}")
continue
try:
# 尝试解析为CIDR格式包含/
if "/" in ip_entry:
allowed.append(ipaddress.ip_network(ip_entry, strict=False))
else:
# 精确IP地址
allowed.append(ipaddress.ip_address(ip_entry))
except (ValueError, AttributeError) as e:
logger.warning(f"无效的IP白名单条目已忽略: {ip_entry} ({e})")
return allowed
def _convert_wildcard_to_regex(wildcard_pattern: str) -> Optional[str]:
"""
将通配符IP模式转换为正则表达式
支持的格式:
- 192.168.*.* 或 192.168.*
- 10.*.*.* 或 10.*
- *.*.*.* 或 *
Args:
wildcard_pattern: 通配符模式字符串
Returns:
正则表达式字符串如果格式无效则返回None
"""
# 去除空格
pattern = wildcard_pattern.strip()
# 处理单个*(匹配所有)
if pattern == "*":
return r".*"
# 处理IPv4通配符格式
# 支持192.168.*.*, 192.168.*, 10.*.*.*, 10.* 等
parts = pattern.split(".")
if len(parts) > 4:
return None # IPv4最多4段
# 构建正则表达式
regex_parts = []
for part in parts:
part = part.strip()
if part == "*":
regex_parts.append(r"\d+") # 匹配任意数字
elif part.isdigit():
# 验证数字范围0-255
num = int(part)
if 0 <= num <= 255:
regex_parts.append(re.escape(part))
else:
return None # 无效的数字
else:
return None # 无效的格式
# 如果部分少于4段补充.*
while len(regex_parts) < 4:
regex_parts.append(r"\d+")
# 组合成正则表达式
regex = r"^" + r"\.".join(regex_parts) + r"$"
return regex
ALLOWED_IPS = _parse_allowed_ips(os.getenv("WEBUI_ALLOWED_IPS", ""))
# 信任的代理IP配置从环境变量读取逗号分隔
# 只有在信任的代理IP下才使用X-Forwarded-For头
# 默认关闭(空),不信任任何代理
TRUSTED_PROXIES = _parse_allowed_ips(os.getenv("WEBUI_TRUSTED_PROXIES", ""))
TRUST_XFF = os.getenv("WEBUI_TRUST_XFF", "false").lower() == "true"
def _get_mode_config(mode: str) -> dict:
"""
根据模式获取配置参数
Args:
mode: 防爬虫模式 (false/strict/loose/basic)
Returns:
配置字典,包含所有相关参数
"""
mode = mode.lower()
if mode == "false":
return {
"enabled": False,
"rate_limit_window": 60,
"rate_limit_max_requests": 1000, # 禁用时设置很高的值
"max_tracked_ips": 0,
"check_user_agent": False,
"check_asset_scanner": False,
"check_rate_limit": False,
"block_on_detect": False, # 不阻止
}
elif mode == "strict":
return {
"enabled": True,
"rate_limit_window": 60,
"rate_limit_max_requests": 15, # 严格模式:更低的请求数
"max_tracked_ips": 20000,
"check_user_agent": True,
"check_asset_scanner": True,
"check_rate_limit": True,
"block_on_detect": True, # 阻止恶意访问
}
elif mode == "loose":
return {
"enabled": True,
"rate_limit_window": 60,
"rate_limit_max_requests": 60, # 宽松模式:更高的请求数
"max_tracked_ips": 5000,
"check_user_agent": True,
"check_asset_scanner": True,
"check_rate_limit": True,
"block_on_detect": True, # 阻止恶意访问
}
else: # basic (默认模式)
return {
"enabled": True,
"rate_limit_window": 60,
"rate_limit_max_requests": 1000, # 不限制请求数
"max_tracked_ips": 0, # 不跟踪IP
"check_user_agent": True, # 检测但不阻止
"check_asset_scanner": True, # 检测但不阻止
"check_rate_limit": False, # 不限制请求频率
"block_on_detect": False, # 只记录,不阻止
}
class AntiCrawlerMiddleware(BaseHTTPMiddleware):
"""防爬虫中间件"""
def __init__(self, app, mode: str = "standard"):
"""
初始化防爬虫中间件
Args:
app: FastAPI 应用实例
mode: 防爬虫模式 (false/strict/loose/standard)
"""
super().__init__(app)
self.mode = mode.lower()
# 根据模式获取配置
config = _get_mode_config(self.mode)
self.enabled = config["enabled"]
self.rate_limit_window = config["rate_limit_window"]
self.rate_limit_max_requests = config["rate_limit_max_requests"]
self.max_tracked_ips = config["max_tracked_ips"]
self.check_user_agent = config["check_user_agent"]
self.check_asset_scanner = config["check_asset_scanner"]
self.check_rate_limit = config["check_rate_limit"]
self.block_on_detect = config["block_on_detect"] # 是否阻止检测到的恶意访问
# 用于存储每个IP的请求时间戳使用deque提高性能
self.request_times: dict[str, deque] = {}
# 上次清理时间
self.last_cleanup = time.time()
# 将关键词列表转换为集合以提高查找性能
self.crawler_keywords_set = set(CRAWLER_USER_AGENTS)
self.scanner_keywords_set = set(ASSET_SCANNER_USER_AGENTS)
def _is_crawler_user_agent(self, user_agent: Optional[str]) -> bool:
"""
检测是否为爬虫 User-Agent
Args:
user_agent: User-Agent 字符串
Returns:
如果是爬虫则返回 True
"""
if not user_agent:
# 没有 User-Agent 的请求记录日志但不直接阻止
# 改为只记录,让频率限制来处理
logger.debug("请求缺少User-Agent")
return False # 不再直接阻止无User-Agent的请求
user_agent_lower = user_agent.lower()
# 使用集合查找提高性能(检查是否包含爬虫关键词)
for crawler_keyword in self.crawler_keywords_set:
if crawler_keyword in user_agent_lower:
return True
return False
def _is_asset_scanner_header(self, request: Request) -> bool:
"""
检测是否为资产测绘工具的HTTP头只检查特定头收紧匹配
Args:
request: 请求对象
Returns:
如果检测到资产测绘工具头则返回 True
"""
# 只检查特定的扫描工具头,不检查所有头
for header_name, header_value in request.headers.items():
header_name_lower = header_name.lower()
header_value_lower = header_value.lower() if header_value else ""
# 检查已知的扫描工具头
if header_name_lower in ASSET_SCANNER_HEADERS:
# 如果该头有特定的工具集合,检查值是否匹配
expected_tools = ASSET_SCANNER_HEADERS[header_name_lower]
if expected_tools:
for tool in expected_tools:
if tool in header_value_lower:
return True
else:
# 如果没有特定工具集合,只要存在该头就视为可疑
if header_value_lower:
return True
# 只检查特定头中的可疑模式(收紧匹配)
if header_name_lower in SCANNER_SPECIFIC_HEADERS:
# 检查头值中是否包含已知扫描工具名称
for tool in self.scanner_keywords_set:
if tool in header_value_lower:
return True
return False
def _detect_asset_scanner(self, request: Request) -> tuple[bool, Optional[str]]:
"""
检测资产测绘工具
Args:
request: 请求对象
Returns:
(是否检测到, 检测到的工具名称)
"""
user_agent = request.headers.get("User-Agent")
# 检查 User-Agent使用集合查找提高性能
if user_agent:
user_agent_lower = user_agent.lower()
for scanner_keyword in self.scanner_keywords_set:
if scanner_keyword in user_agent_lower:
return True, scanner_keyword
# 检查HTTP头
if self._is_asset_scanner_header(request):
# 尝试从User-Agent或头中提取工具名称
detected_tool = None
if user_agent:
user_agent_lower = user_agent.lower()
for tool in self.scanner_keywords_set:
if tool in user_agent_lower:
detected_tool = tool
break
# 检查HTTP头中的工具标识只检查特定头
if not detected_tool:
for header_name, header_value in request.headers.items():
header_name_lower = header_name.lower()
if header_name_lower in SCANNER_SPECIFIC_HEADERS:
header_value_lower = (header_value or "").lower()
for tool in self.scanner_keywords_set:
if tool in header_value_lower:
detected_tool = tool
break
if detected_tool:
break
return True, detected_tool or "unknown_scanner"
return False, None
def _check_rate_limit(self, client_ip: str) -> bool:
"""
检查请求频率限制
Args:
client_ip: 客户端IP地址
Returns:
如果超过限制则返回 True需要阻止
"""
# 检查IP白名单
if self._is_ip_allowed(client_ip):
return False
current_time = time.time()
# 定期清理过期的请求记录每5分钟清理一次
if current_time - self.last_cleanup > 300:
self._cleanup_old_requests(current_time)
self.last_cleanup = current_time
# 限制跟踪的IP数量防止内存泄漏
if self.max_tracked_ips > 0 and len(self.request_times) > self.max_tracked_ips:
# 清理最旧的记录删除最久未访问的IP
self._cleanup_oldest_ips()
# 获取或创建该IP的请求时间deque不使用maxlen避免限流变松
if client_ip not in self.request_times:
self.request_times[client_ip] = deque()
request_times = self.request_times[client_ip]
# 移除时间窗口外的请求记录(从左侧弹出过期记录)
while request_times and current_time - request_times[0] >= self.rate_limit_window:
request_times.popleft()
# 检查是否超过限制
if len(request_times) >= self.rate_limit_max_requests:
return True
# 记录当前请求时间
request_times.append(current_time)
return False
def _cleanup_old_requests(self, current_time: float):
"""清理过期的请求记录只清理当前需要检查的IP不全量遍历"""
# 这个方法现在主要用于定期清理实际清理在_check_rate_limit中按需进行
# 清理最久未访问的IP记录
if len(self.request_times) > self.max_tracked_ips * 0.8:
self._cleanup_oldest_ips()
def _cleanup_oldest_ips(self):
"""清理最久未访问的IP记录全量遍历找真正的oldest"""
if not self.request_times:
return
# 先收集空deque的IP优先删除
empty_ips = []
# 找到最久未访问的IP最旧时间戳
oldest_ip = None
oldest_time = float("inf")
# 全量遍历找真正的oldest超限时性能可接受
for ip, times in self.request_times.items():
if not times:
# 空deque记录待删除
empty_ips.append(ip)
else:
# 找到最旧的时间戳
if times[0] < oldest_time:
oldest_time = times[0]
oldest_ip = ip
# 先删除空deque的IP
for ip in empty_ips:
del self.request_times[ip]
# 如果没有空deque可删除且仍需要清理删除最旧的一个IP
if not empty_ips and oldest_ip:
del self.request_times[oldest_ip]
def _is_trusted_proxy(self, ip: str) -> bool:
"""
检查IP是否在信任的代理列表中
Args:
ip: IP地址字符串
Returns:
如果是信任的代理则返回 True
"""
if not TRUSTED_PROXIES or ip == "unknown":
return False
# 检查代理列表中的每个条目
for trusted_entry in TRUSTED_PROXIES:
# 通配符模式(字符串,正则表达式)
if isinstance(trusted_entry, str):
try:
if re.match(trusted_entry, ip):
return True
except re.error:
continue
# CIDR格式网络对象
elif isinstance(trusted_entry, (ipaddress.IPv4Network, ipaddress.IPv6Network)):
try:
client_ip_obj = ipaddress.ip_address(ip)
if client_ip_obj in trusted_entry:
return True
except (ValueError, AttributeError):
continue
# 精确IP地址对象
elif isinstance(trusted_entry, (ipaddress.IPv4Address, ipaddress.IPv6Address)):
try:
client_ip_obj = ipaddress.ip_address(ip)
if client_ip_obj == trusted_entry:
return True
except (ValueError, AttributeError):
continue
return False
def _get_client_ip(self, request: Request) -> str:
"""
获取客户端真实IP地址带基本验证和代理信任检查
Args:
request: 请求对象
Returns:
客户端IP地址
"""
# 获取直接连接的客户端IP用于验证代理
direct_client_ip = None
if request.client:
direct_client_ip = request.client.host
# 检查是否信任X-Forwarded-For头
# TRUST_XFF 只表示"启用代理解析能力",但仍要求直连 IP 在 TRUSTED_PROXIES 中
use_xff = False
if TRUST_XFF and TRUSTED_PROXIES and direct_client_ip:
# 只有在启用 TRUST_XFF 且直连 IP 在信任列表中时,才信任 XFF
use_xff = self._is_trusted_proxy(direct_client_ip)
# 如果信任代理,优先从 X-Forwarded-For 获取
if use_xff:
forwarded_for = request.headers.get("X-Forwarded-For")
if forwarded_for:
# X-Forwarded-For 可能包含多个IP取第一个
ip = forwarded_for.split(",")[0].strip()
# 基本验证IP格式
if self._validate_ip(ip):
return ip
# 从 X-Real-IP 获取(如果信任代理)
if use_xff:
real_ip = request.headers.get("X-Real-IP")
if real_ip:
ip = real_ip.strip()
if self._validate_ip(ip):
return ip
# 使用直接连接的客户端IP
if direct_client_ip and self._validate_ip(direct_client_ip):
return direct_client_ip
return "unknown"
def _validate_ip(self, ip: str) -> bool:
"""
验证IP地址格式
Args:
ip: IP地址字符串
Returns:
如果格式有效则返回 True
"""
try:
ipaddress.ip_address(ip)
return True
except (ValueError, AttributeError):
return False
def _is_ip_allowed(self, ip: str) -> bool:
"""
检查IP是否在白名单中支持精确IP、CIDR格式和通配符
Args:
ip: 客户端IP地址
Returns:
如果IP在白名单中则返回 True
"""
if not ALLOWED_IPS or ip == "unknown":
return False
# 检查白名单中的每个条目
for allowed_entry in ALLOWED_IPS:
# 通配符模式(字符串,正则表达式)
if isinstance(allowed_entry, str):
try:
if re.match(allowed_entry, ip):
return True
except re.error:
# 正则表达式错误,跳过
continue
# CIDR格式网络对象
elif isinstance(allowed_entry, (ipaddress.IPv4Network, ipaddress.IPv6Network)):
try:
client_ip_obj = ipaddress.ip_address(ip)
if client_ip_obj in allowed_entry:
return True
except (ValueError, AttributeError):
# IP格式无效跳过
continue
# 精确IP地址对象
elif isinstance(allowed_entry, (ipaddress.IPv4Address, ipaddress.IPv6Address)):
try:
client_ip_obj = ipaddress.ip_address(ip)
if client_ip_obj == allowed_entry:
return True
except (ValueError, AttributeError):
# IP格式无效跳过
continue
return False
async def dispatch(self, request: Request, call_next):
"""
处理请求
Args:
request: 请求对象
call_next: 下一个中间件或路由处理函数
Returns:
响应对象
"""
# 如果未启用,直接通过
if not self.enabled:
return await call_next(request)
# 允许访问 robots.txt由专门的路由处理
if request.url.path == "/robots.txt":
return await call_next(request)
# 允许访问静态资源CSS、JS、图片等
# 注意:.json 已移除,避免 API 路径绕过防护
# 静态资源只在特定前缀下放行(/static/、/assets/、/dist/
static_extensions = {
".css",
".js",
".png",
".jpg",
".jpeg",
".gif",
".svg",
".ico",
".woff",
".woff2",
".ttf",
".eot",
}
static_prefixes = {"/static/", "/assets/", "/dist/"}
# 检查是否是静态资源路径(特定前缀下的静态文件)
path = request.url.path
is_static_path = any(path.startswith(prefix) for prefix in static_prefixes) and any(
path.endswith(ext) for ext in static_extensions
)
# 也允许根路径下的静态文件(如 /favicon.ico
is_root_static = path.count("/") == 1 and any(path.endswith(ext) for ext in static_extensions)
if is_static_path or is_root_static:
return await call_next(request)
# 获取客户端IP只获取一次避免重复调用
client_ip = self._get_client_ip(request)
# 检查IP白名单优先检查白名单IP直接通过
if self._is_ip_allowed(client_ip):
return await call_next(request)
# 获取 User-Agent
user_agent = request.headers.get("User-Agent")
# 检测资产测绘工具(优先检测,因为更危险)
if self.check_asset_scanner:
is_scanner, scanner_name = self._detect_asset_scanner(request)
if is_scanner:
logger.warning(
f"🚫 检测到资产测绘工具请求 - IP: {client_ip}, 工具: {scanner_name}, "
f"User-Agent: {user_agent}, Path: {request.url.path}"
)
# 根据配置决定是否阻止
if self.block_on_detect:
return PlainTextResponse(
"Access Denied: Asset scanning tools are not allowed",
status_code=403,
)
# 检测爬虫 User-Agent
if self.check_user_agent and self._is_crawler_user_agent(user_agent):
logger.warning(f"🚫 检测到爬虫请求 - IP: {client_ip}, User-Agent: {user_agent}, Path: {request.url.path}")
# 根据配置决定是否阻止
if self.block_on_detect:
return PlainTextResponse(
"Access Denied: Crawlers are not allowed",
status_code=403,
)
# 检查请求频率限制
if self.check_rate_limit and self._check_rate_limit(client_ip):
logger.warning(f"🚫 请求频率过高 - IP: {client_ip}, User-Agent: {user_agent}, Path: {request.url.path}")
return PlainTextResponse(
"Too Many Requests: Rate limit exceeded",
status_code=429,
)
# 正常请求,继续处理
return await call_next(request)
def create_robots_txt_response() -> PlainTextResponse:
"""
创建 robots.txt 响应
Returns:
robots.txt 响应对象
"""
robots_content = """User-agent: *
Disallow: /
# 禁止所有爬虫访问
"""
return PlainTextResponse(
content=robots_content,
media_type="text/plain",
headers={"Cache-Control": "public, max-age=86400"}, # 缓存24小时
)

View File

@@ -3,6 +3,7 @@ WebUI 认证模块
提供统一的认证依赖,支持 Cookie 和 Header 两种方式
"""
import os
from typing import Optional
from fastapi import HTTPException, Cookie, Header, Response, Request
from src.common.logger import get_logger
@@ -15,6 +16,28 @@ COOKIE_NAME = "maibot_session"
COOKIE_MAX_AGE = 7 * 24 * 60 * 60 # 7天
def _is_secure_environment() -> bool:
"""
检测是否应该启用安全 CookieHTTPS
Returns:
bool: 如果应该使用 secure cookie 则返回 True
"""
# 检查环境变量
if os.environ.get("WEBUI_SECURE_COOKIE", "").lower() in ("true", "1", "yes"):
return True
if os.environ.get("WEBUI_SECURE_COOKIE", "").lower() in ("false", "0", "no"):
return False
# 检查是否是生产环境
env = os.environ.get("WEBUI_MODE", "").lower()
if env in ("production", "prod"):
return True
# 默认:开发环境不启用(因为通常是 HTTP
return False
def get_current_token(
request: Request,
maibot_session: Optional[str] = Cookie(None),
@@ -22,69 +45,76 @@ def get_current_token(
) -> str:
"""
获取当前请求的 token优先从 Cookie 获取,其次从 Header 获取
Args:
request: FastAPI Request 对象
maibot_session: Cookie 中的 token
authorization: Authorization Header (Bearer token)
Returns:
验证通过的 token
Raises:
HTTPException: 认证失败时抛出 401 错误
"""
token = None
# 优先从 Cookie 获取
if maibot_session:
token = maibot_session
# 其次从 Header 获取(兼容旧版本)
elif authorization and authorization.startswith("Bearer "):
token = authorization.replace("Bearer ", "")
if not token:
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
# 验证 token
token_manager = get_token_manager()
if not token_manager.verify_token(token):
raise HTTPException(status_code=401, detail="Token 无效或已过期")
return token
def set_auth_cookie(response: Response, token: str) -> None:
"""
设置认证 Cookie
Args:
response: FastAPI Response 对象
token: 要设置的 token
"""
# 根据环境决定安全设置
is_secure = _is_secure_environment()
response.set_cookie(
key=COOKIE_NAME,
value=token,
max_age=COOKIE_MAX_AGE,
httponly=True, # 防止 JS 读取
samesite="lax", # 允许同站导航时发送 Cookie兼容开发环境代理
secure=False, # 本地开发不强制 HTTPS,生产环境建议设为 True
httponly=True, # 防止 JS 读取,阻止 XSS 窃取
samesite="strict" if is_secure else "lax", # 生产环境使用 strict 防止 CSRF
secure=is_secure, # 生产环境强制 HTTPS
path="/", # 确保 Cookie 在所有路径下可用
)
logger.debug(f"已设置认证 Cookie: {token[:8]}...")
logger.debug(f"已设置认证 Cookie: {token[:8]}... (secure={is_secure})")
def clear_auth_cookie(response: Response) -> None:
"""
清除认证 Cookie
Args:
response: FastAPI Response 对象
"""
# 保持与 set_auth_cookie 相同的安全设置
is_secure = _is_secure_environment()
response.delete_cookie(
key=COOKIE_NAME,
httponly=True,
samesite="lax",
samesite="strict" if is_secure else "lax",
secure=is_secure,
path="/",
)
logger.debug("已清除认证 Cookie")
@@ -96,32 +126,32 @@ def verify_auth_token_from_cookie_or_header(
) -> bool:
"""
验证认证 Token支持从 Cookie 或 Header 获取
Args:
maibot_session: Cookie 中的 token
authorization: Authorization header (Bearer token)
Returns:
验证成功返回 True
Raises:
HTTPException: 认证失败时抛出 401 错误
"""
token = None
# 优先从 Cookie 获取
if maibot_session:
token = maibot_session
# 其次从 Header 获取(兼容旧版本)
elif authorization and authorization.startswith("Bearer "):
token = authorization.replace("Bearer ", "")
if not token:
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
# 验证 token
token_manager = get_token_manager()
if not token_manager.verify_token(token):
raise HTTPException(status_code=401, detail="Token 无效或已过期")
return True

View File

@@ -8,18 +8,30 @@
import time
import uuid
from typing import Dict, Any, Optional, List
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query, Depends, Cookie, Header
from pydantic import BaseModel
from src.common.logger import get_logger
from src.common.database.database_model import Messages, PersonInfo
from src.config.config import global_config
from src.chat.message_receive.bot import chat_bot
from src.webui.auth import verify_auth_token_from_cookie_or_header
from src.webui.token_manager import get_token_manager
from src.webui.ws_auth import verify_ws_token
logger = get_logger("webui.chat")
router = APIRouter(prefix="/api/chat", tags=["LocalChat"])
def require_auth(
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> bool:
"""认证依赖:验证用户是否已登录"""
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
# WebUI 聊天的虚拟群组 ID
WEBUI_CHAT_GROUP_ID = "webui_local_chat"
WEBUI_CHAT_PLATFORM = "webui"
@@ -63,14 +75,14 @@ class ChatHistoryManager:
def _message_to_dict(self, msg: Messages, group_id: Optional[str] = None) -> Dict[str, Any]:
"""将数据库消息转换为前端格式
Args:
msg: 数据库消息对象
group_id: 群 ID用于判断是否是虚拟群
"""
# 判断是否是机器人消息
user_id = msg.user_id or ""
# 对于虚拟群,通过比较机器人 QQ 账号来判断
# 对于普通 WebUI 群,检查 user_id 是否以 webui_ 开头
if group_id and group_id.startswith(VIRTUAL_GROUP_ID_PREFIX):
@@ -256,6 +268,7 @@ async def get_chat_history(
limit: int = Query(default=50, ge=1, le=200),
user_id: Optional[str] = Query(default=None), # 保留参数兼容性,但不用于过滤
group_id: Optional[str] = Query(default=None), # 可选:指定群 ID 获取历史
_auth: bool = Depends(require_auth),
):
"""获取聊天历史记录
@@ -272,7 +285,7 @@ async def get_chat_history(
@router.get("/platforms")
async def get_available_platforms():
async def get_available_platforms(_auth: bool = Depends(require_auth)):
"""获取可用平台列表
从 PersonInfo 表中获取所有已知的平台
@@ -303,6 +316,7 @@ async def get_persons_by_platform(
platform: str = Query(..., description="平台名称"),
search: Optional[str] = Query(default=None, description="搜索关键词"),
limit: int = Query(default=50, ge=1, le=200),
_auth: bool = Depends(require_auth),
):
"""获取指定平台的用户列表
@@ -350,7 +364,7 @@ async def get_persons_by_platform(
@router.delete("/history")
async def clear_chat_history(group_id: Optional[str] = Query(default=None)):
async def clear_chat_history(group_id: Optional[str] = Query(default=None), _auth: bool = Depends(require_auth)):
"""清空聊天历史记录
Args:
@@ -372,6 +386,7 @@ async def websocket_chat(
person_id: Optional[str] = Query(default=None),
group_name: Optional[str] = Query(default=None),
group_id: Optional[str] = Query(default=None), # 前端传递的稳定 group_id
token: Optional[str] = Query(default=None), # 认证 token
):
"""WebSocket 聊天端点
@@ -382,9 +397,45 @@ async def websocket_chat(
person_id: 虚拟身份模式的用户 person_id可选
group_name: 虚拟身份模式的群名(可选)
group_id: 虚拟身份模式的群 ID可选由前端生成并持久化
token: 认证 token可选也可从 Cookie 获取)
虚拟身份模式可通过 URL 参数直接配置,或通过消息中的 set_virtual_identity 配置
支持三种认证方式(按优先级):
1. query 参数 token推荐通过 /api/webui/ws-token 获取临时 token
2. Cookie 中的 maibot_session
3. 直接使用 session token兼容
示例ws://host/api/chat/ws?token=xxx
"""
is_authenticated = False
# 方式 1: 尝试验证临时 WebSocket token推荐方式
if token and verify_ws_token(token):
is_authenticated = True
logger.debug("聊天 WebSocket 使用临时 token 认证成功")
# 方式 2: 尝试从 Cookie 获取 session token
if not is_authenticated:
cookie_token = websocket.cookies.get("maibot_session")
if cookie_token:
token_manager = get_token_manager()
if token_manager.verify_token(cookie_token):
is_authenticated = True
logger.debug("聊天 WebSocket 使用 Cookie 认证成功")
# 方式 3: 尝试直接验证 query 参数作为 session token兼容旧方式
if not is_authenticated and token:
token_manager = get_token_manager()
if token_manager.verify_token(token):
is_authenticated = True
logger.debug("聊天 WebSocket 使用 session token 认证成功")
if not is_authenticated:
logger.warning("聊天 WebSocket 连接被拒绝:认证失败")
await websocket.close(code=4001, reason="认证失败,请重新登录")
return
# 生成会话 ID每次连接都是新的
session_id = str(uuid.uuid4())
@@ -414,7 +465,9 @@ async def websocket_chat(
group_id=virtual_group_id,
group_name=group_name or "WebUI虚拟群聊",
)
logger.info(f"虚拟身份模式已通过 URL 参数激活: {current_virtual_config.user_nickname} @ {current_virtual_config.platform}, group_id={virtual_group_id}")
logger.info(
f"虚拟身份模式已通过 URL 参数激活: {current_virtual_config.user_nickname} @ {current_virtual_config.platform}, group_id={virtual_group_id}"
)
except Exception as e:
logger.warning(f"通过 URL 参数配置虚拟身份失败: {e}")
@@ -710,7 +763,7 @@ async def websocket_chat(
@router.get("/info")
async def get_chat_info():
async def get_chat_info(_auth: bool = Depends(require_auth)):
"""获取聊天室信息"""
return {
"bot_name": global_config.bot.nickname,

View File

@@ -4,10 +4,11 @@
import os
import tomlkit
from fastapi import APIRouter, HTTPException, Body
from typing import Any, Annotated
from fastapi import APIRouter, HTTPException, Body, Depends, Cookie, Header
from typing import Any, Annotated, Optional
from src.common.logger import get_logger
from src.webui.auth import verify_auth_token_from_cookie_or_header
from src.common.toml_utils import save_toml_with_format, _update_toml_doc
from src.config.config import Config, APIAdapterConfig, CONFIG_DIR, PROJECT_ROOT
from src.config.official_configs import (
@@ -49,11 +50,19 @@ PathBody = Annotated[dict[str, str], Body()]
router = APIRouter(prefix="/config", tags=["config"])
def require_auth(
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> bool:
"""认证依赖:验证用户是否已登录"""
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
# ===== 架构获取接口 =====
@router.get("/schema/bot")
async def get_bot_config_schema():
async def get_bot_config_schema(_auth: bool = Depends(require_auth)):
"""获取麦麦主程序配置架构"""
try:
# Config 类包含所有子配置
@@ -65,7 +74,7 @@ async def get_bot_config_schema():
@router.get("/schema/model")
async def get_model_config_schema():
async def get_model_config_schema(_auth: bool = Depends(require_auth)):
"""获取模型配置架构(包含提供商和模型任务配置)"""
try:
schema = ConfigSchemaGenerator.generate_config_schema(APIAdapterConfig)
@@ -79,7 +88,7 @@ async def get_model_config_schema():
@router.get("/schema/section/{section_name}")
async def get_config_section_schema(section_name: str):
async def get_config_section_schema(section_name: str, _auth: bool = Depends(require_auth)):
"""
获取指定配置节的架构
@@ -149,7 +158,7 @@ async def get_config_section_schema(section_name: str):
@router.get("/bot")
async def get_bot_config():
async def get_bot_config(_auth: bool = Depends(require_auth)):
"""获取麦麦主程序配置"""
try:
config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
@@ -168,7 +177,7 @@ async def get_bot_config():
@router.get("/model")
async def get_model_config():
async def get_model_config(_auth: bool = Depends(require_auth)):
"""获取模型配置(包含提供商和模型任务配置)"""
try:
config_path = os.path.join(CONFIG_DIR, "model_config.toml")
@@ -190,7 +199,7 @@ async def get_model_config():
@router.post("/bot")
async def update_bot_config(config_data: ConfigBody):
async def update_bot_config(config_data: ConfigBody, _auth: bool = Depends(require_auth)):
"""更新麦麦主程序配置"""
try:
# 验证配置数据
@@ -213,7 +222,7 @@ async def update_bot_config(config_data: ConfigBody):
@router.post("/model")
async def update_model_config(config_data: ConfigBody):
async def update_model_config(config_data: ConfigBody, _auth: bool = Depends(require_auth)):
"""更新模型配置"""
try:
# 验证配置数据
@@ -239,7 +248,7 @@ async def update_model_config(config_data: ConfigBody):
@router.post("/bot/section/{section_name}")
async def update_bot_config_section(section_name: str, section_data: SectionBody):
async def update_bot_config_section(section_name: str, section_data: SectionBody, _auth: bool = Depends(require_auth)):
"""更新麦麦主程序配置的指定节(保留注释和格式)"""
try:
# 读取现有配置
@@ -288,7 +297,7 @@ async def update_bot_config_section(section_name: str, section_data: SectionBody
@router.get("/bot/raw")
async def get_bot_config_raw():
async def get_bot_config_raw(_auth: bool = Depends(require_auth)):
"""获取麦麦主程序配置的原始 TOML 内容"""
try:
config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
@@ -307,7 +316,7 @@ async def get_bot_config_raw():
@router.post("/bot/raw")
async def update_bot_config_raw(raw_content: RawContentBody):
async def update_bot_config_raw(raw_content: RawContentBody, _auth: bool = Depends(require_auth)):
"""更新麦麦主程序配置(直接保存原始 TOML 内容,会先验证格式)"""
try:
# 验证 TOML 格式
@@ -337,7 +346,9 @@ async def update_bot_config_raw(raw_content: RawContentBody):
@router.post("/model/section/{section_name}")
async def update_model_config_section(section_name: str, section_data: SectionBody):
async def update_model_config_section(
section_name: str, section_data: SectionBody, _auth: bool = Depends(require_auth)
):
"""更新模型配置的指定节(保留注释和格式)"""
try:
# 读取现有配置
@@ -368,6 +379,17 @@ async def update_model_config_section(section_name: str, section_data: SectionBo
try:
APIAdapterConfig.from_dict(config_data)
except Exception as e:
logger.error(f"配置数据验证失败,详细错误: {str(e)}")
# 特殊处理:如果是更新 api_providers检查是否有模型引用了已删除的provider
if section_name == "api_providers" and "api_provider" in str(e):
provider_names = {p.get("name") for p in section_data if isinstance(p, dict)}
models = config_data.get("models", [])
orphaned_models = [
m.get("name") for m in models if isinstance(m, dict) and m.get("api_provider") not in provider_names
]
if orphaned_models:
error_msg = f"以下模型引用了已删除的提供商: {', '.join(orphaned_models)}。请先在模型管理页面删除这些模型,或重新分配它们的提供商。"
raise HTTPException(status_code=400, detail=error_msg) from e
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
# 保存配置(格式化数组为多行,保留注释)
@@ -418,7 +440,7 @@ def _to_relative_path(path: str) -> str:
@router.get("/adapter-config/path")
async def get_adapter_config_path():
async def get_adapter_config_path(_auth: bool = Depends(require_auth)):
"""获取保存的适配器配置文件路径"""
try:
# 从 data/webui.json 读取路径偏好
@@ -457,7 +479,7 @@ async def get_adapter_config_path():
@router.post("/adapter-config/path")
async def save_adapter_config_path(data: PathBody):
async def save_adapter_config_path(data: PathBody, _auth: bool = Depends(require_auth)):
"""保存适配器配置文件路径偏好"""
try:
path = data.get("path")
@@ -500,7 +522,7 @@ async def save_adapter_config_path(data: PathBody):
@router.get("/adapter-config")
async def get_adapter_config(path: str):
async def get_adapter_config(path: str, _auth: bool = Depends(require_auth)):
"""从指定路径读取适配器配置文件"""
try:
if not path:
@@ -532,7 +554,7 @@ async def get_adapter_config(path: str):
@router.post("/adapter-config")
async def save_adapter_config(data: PathBody):
async def save_adapter_config(data: PathBody, _auth: bool = Depends(require_auth)):
"""保存适配器配置到指定路径"""
try:
path = data.get("path")

View File

@@ -1,4 +1,4 @@
""" 表情包管理 API 路由"""
"""表情包管理 API 路由"""
from fastapi import APIRouter, HTTPException, Header, Query, UploadFile, File, Form, Cookie
from fastapi.responses import FileResponse, JSONResponse
@@ -48,7 +48,7 @@ def _get_thumbnail_lock(file_hash: str) -> threading.Lock:
def _background_generate_thumbnail(source_path: str, file_hash: str) -> None:
"""
后台生成缩略图(在线程池中执行)
生成完成后自动从 generating 集合中移除
"""
try:
@@ -74,14 +74,14 @@ def _get_thumbnail_cache_path(file_hash: str) -> Path:
def _generate_thumbnail(source_path: str, file_hash: str) -> Path:
"""
生成缩略图并保存到缓存目录
Args:
source_path: 原图路径
file_hash: 文件哈希值,用作缓存文件名
Returns:
缩略图路径
Features:
- GIF: 提取第一帧作为缩略图
- 所有格式统一转为 WebP
@@ -89,63 +89,63 @@ def _generate_thumbnail(source_path: str, file_hash: str) -> Path:
"""
_ensure_thumbnail_cache_dir()
cache_path = _get_thumbnail_cache_path(file_hash)
# 使用锁防止并发生成同一缩略图
lock = _get_thumbnail_lock(file_hash)
with lock:
# 双重检查,可能在等待锁时已被其他线程生成
if cache_path.exists():
return cache_path
try:
with Image.open(source_path) as img:
# GIF 处理:提取第一帧
if hasattr(img, 'n_frames') and img.n_frames > 1:
if hasattr(img, "n_frames") and img.n_frames > 1:
img.seek(0) # 确保在第一帧
# 转换为 RGB/RGBAWebP 支持透明度)
if img.mode in ('P', 'PA'):
if img.mode in ("P", "PA"):
# 调色板模式转换为 RGBA 以保留透明度
img = img.convert('RGBA')
elif img.mode == 'LA':
img = img.convert('RGBA')
elif img.mode not in ('RGB', 'RGBA'):
img = img.convert('RGB')
img = img.convert("RGBA")
elif img.mode == "LA":
img = img.convert("RGBA")
elif img.mode not in ("RGB", "RGBA"):
img = img.convert("RGB")
# 创建缩略图(保持宽高比)
img.thumbnail(THUMBNAIL_SIZE, Image.Resampling.LANCZOS)
# 保存为 WebP 格式
img.save(cache_path, 'WEBP', quality=THUMBNAIL_QUALITY, method=6)
img.save(cache_path, "WEBP", quality=THUMBNAIL_QUALITY, method=6)
logger.debug(f"生成缩略图: {file_hash} -> {cache_path}")
except Exception as e:
logger.warning(f"生成缩略图失败 {file_hash}: {e},将返回原图")
# 生成失败时不创建缓存文件,下次会重试
raise
return cache_path
def cleanup_orphaned_thumbnails() -> tuple[int, int]:
"""
清理孤立的缩略图缓存(原图已不存在的缩略图)
Returns:
(清理数量, 保留数量)
"""
if not THUMBNAIL_CACHE_DIR.exists():
return 0, 0
# 获取所有表情包的哈希值
valid_hashes = set()
for emoji in Emoji.select(Emoji.emoji_hash):
valid_hashes.add(emoji.emoji_hash)
cleaned = 0
kept = 0
for cache_file in THUMBNAIL_CACHE_DIR.glob("*.webp"):
file_hash = cache_file.stem
if file_hash not in valid_hashes:
@@ -157,12 +157,13 @@ def cleanup_orphaned_thumbnails() -> tuple[int, int]:
logger.warning(f"清理缩略图失败 {cache_file.name}: {e}")
else:
kept += 1
if cleaned > 0:
logger.info(f"清理孤立缩略图: 删除 {cleaned} 个,保留 {kept}")
return cleaned, kept
# 模块级别的类型别名(解决 B008 ruff 错误)
EmojiFile = Annotated[UploadFile, File(description="表情包图片文件")]
EmojiFiles = Annotated[List[UploadFile], File(description="多个表情包图片文件")]
@@ -365,7 +366,9 @@ async def get_emoji_list(
@router.get("/{emoji_id}", response_model=EmojiDetailResponse)
async def get_emoji_detail(emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
async def get_emoji_detail(
emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
):
"""
获取表情包详细信息
@@ -394,7 +397,12 @@ async def get_emoji_detail(emoji_id: int, maibot_session: Optional[str] = Cookie
@router.patch("/{emoji_id}", response_model=EmojiUpdateResponse)
async def update_emoji(emoji_id: int, request: EmojiUpdateRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
async def update_emoji(
emoji_id: int,
request: EmojiUpdateRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
"""
增量更新表情包(只更新提供的字段)
@@ -446,7 +454,9 @@ async def update_emoji(emoji_id: int, request: EmojiUpdateRequest, maibot_sessio
@router.delete("/{emoji_id}", response_model=EmojiDeleteResponse)
async def delete_emoji(emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
async def delete_emoji(
emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
):
"""
删除表情包
@@ -538,7 +548,9 @@ async def get_emoji_stats(maibot_session: Optional[str] = Cookie(None), authoriz
@router.post("/{emoji_id}/register", response_model=EmojiUpdateResponse)
async def register_emoji(emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
async def register_emoji(
emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
):
"""
注册表情包(快捷操作)
@@ -578,7 +590,9 @@ async def register_emoji(emoji_id: int, maibot_session: Optional[str] = Cookie(N
@router.post("/{emoji_id}/ban", response_model=EmojiUpdateResponse)
async def ban_emoji(emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
async def ban_emoji(
emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
):
"""
禁用表情包(快捷操作)
@@ -633,7 +647,7 @@ async def get_emoji_thumbnail(
Returns:
表情包缩略图WebP 格式)或原图
Features:
- 懒加载:首次请求时生成缩略图
- 缓存:后续请求直接返回缓存
@@ -643,7 +657,7 @@ async def get_emoji_thumbnail(
try:
token_manager = get_token_manager()
is_valid = False
# 1. 优先使用 Cookie
if maibot_session and token_manager.verify_token(maibot_session):
is_valid = True
@@ -655,7 +669,7 @@ async def get_emoji_thumbnail(
auth_token = authorization.replace("Bearer ", "")
if token_manager.verify_token(auth_token):
is_valid = True
if not is_valid:
raise HTTPException(status_code=401, detail="Token 无效或已过期")
@@ -680,35 +694,27 @@ async def get_emoji_thumbnail(
}
media_type = mime_types.get(emoji.format.lower(), "application/octet-stream")
return FileResponse(
path=emoji.full_path,
media_type=media_type,
filename=f"{emoji.emoji_hash}.{emoji.format}"
path=emoji.full_path, media_type=media_type, filename=f"{emoji.emoji_hash}.{emoji.format}"
)
# 尝试获取或生成缩略图
cache_path = _get_thumbnail_cache_path(emoji.emoji_hash)
# 检查缓存是否存在
if cache_path.exists():
# 缓存命中,直接返回
return FileResponse(
path=str(cache_path),
media_type="image/webp",
filename=f"{emoji.emoji_hash}_thumb.webp"
path=str(cache_path), media_type="image/webp", filename=f"{emoji.emoji_hash}_thumb.webp"
)
# 缓存未命中,触发后台生成并返回 202
with _generating_lock:
if emoji.emoji_hash not in _generating_thumbnails:
# 标记为正在生成
_generating_thumbnails.add(emoji.emoji_hash)
# 提交到线程池后台生成
_thumbnail_executor.submit(
_background_generate_thumbnail,
emoji.full_path,
emoji.emoji_hash
)
_thumbnail_executor.submit(_background_generate_thumbnail, emoji.full_path, emoji.emoji_hash)
# 返回 202 Accepted告诉前端缩略图正在生成中
return JSONResponse(
status_code=202,
@@ -719,7 +725,7 @@ async def get_emoji_thumbnail(
},
headers={
"Retry-After": "1", # 建议 1 秒后重试
}
},
)
except HTTPException:
@@ -730,7 +736,11 @@ async def get_emoji_thumbnail(
@router.post("/batch/delete", response_model=BatchDeleteResponse)
async def batch_delete_emojis(request: BatchDeleteRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
async def batch_delete_emojis(
request: BatchDeleteRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
"""
批量删除表情包
@@ -1079,7 +1089,7 @@ async def batch_upload_emoji(
class ThumbnailCacheStatsResponse(BaseModel):
"""缩略图缓存统计响应"""
success: bool
cache_dir: str
total_count: int
@@ -1090,7 +1100,7 @@ class ThumbnailCacheStatsResponse(BaseModel):
class ThumbnailCleanupResponse(BaseModel):
"""缩略图清理响应"""
success: bool
message: str
cleaned_count: int
@@ -1099,7 +1109,7 @@ class ThumbnailCleanupResponse(BaseModel):
class ThumbnailPreheatResponse(BaseModel):
"""缩略图预热响应"""
success: bool
message: str
generated_count: int
@@ -1114,27 +1124,27 @@ async def get_thumbnail_cache_stats(
):
"""
获取缩略图缓存统计信息
Returns:
缓存目录、缓存数量、总大小、覆盖率等统计信息
"""
try:
verify_auth_token(maibot_session, authorization)
_ensure_thumbnail_cache_dir()
# 统计缓存文件
cache_files = list(THUMBNAIL_CACHE_DIR.glob("*.webp"))
total_count = len(cache_files)
total_size = sum(f.stat().st_size for f in cache_files)
total_size_mb = round(total_size / (1024 * 1024), 2)
# 统计表情包总数
emoji_count = Emoji.select().count()
# 计算覆盖率
coverage_percent = round((total_count / emoji_count * 100) if emoji_count > 0 else 0, 1)
return ThumbnailCacheStatsResponse(
success=True,
cache_dir=str(THUMBNAIL_CACHE_DIR.absolute()),
@@ -1143,7 +1153,7 @@ async def get_thumbnail_cache_stats(
emoji_count=emoji_count,
coverage_percent=coverage_percent,
)
except HTTPException:
raise
except Exception as e:
@@ -1158,22 +1168,22 @@ async def cleanup_thumbnail_cache(
):
"""
清理孤立的缩略图缓存(原图已删除的表情包对应的缩略图)
Returns:
清理结果
"""
try:
verify_auth_token(maibot_session, authorization)
cleaned, kept = cleanup_orphaned_thumbnails()
return ThumbnailCleanupResponse(
success=True,
message=f"清理完成:删除 {cleaned} 个孤立缓存,保留 {kept} 个有效缓存",
cleaned_count=cleaned,
kept_count=kept,
)
except HTTPException:
raise
except Exception as e:
@@ -1189,20 +1199,20 @@ async def preheat_thumbnail_cache(
):
"""
预热缩略图缓存(提前生成未缓存的缩略图)
优先处理使用次数高的表情包
Args:
limit: 最多预热数量 (1-1000)
Returns:
预热结果
"""
try:
verify_auth_token(maibot_session, authorization)
_ensure_thumbnail_cache_dir()
# 获取使用次数最高的表情包(未缓存的优先)
emojis = (
Emoji.select()
@@ -1210,41 +1220,36 @@ async def preheat_thumbnail_cache(
.order_by(Emoji.usage_count.desc())
.limit(limit * 2) # 多查一些,因为有些可能已缓存
)
generated = 0
skipped = 0
failed = 0
for emoji in emojis:
if generated >= limit:
break
cache_path = _get_thumbnail_cache_path(emoji.emoji_hash)
# 已缓存,跳过
if cache_path.exists():
skipped += 1
continue
# 原文件不存在,跳过
if not os.path.exists(emoji.full_path):
failed += 1
continue
try:
# 使用线程池异步生成缩略图,避免阻塞事件循环
loop = asyncio.get_event_loop()
await loop.run_in_executor(
_thumbnail_executor,
_generate_thumbnail,
emoji.full_path,
emoji.emoji_hash
)
await loop.run_in_executor(_thumbnail_executor, _generate_thumbnail, emoji.full_path, emoji.emoji_hash)
generated += 1
except Exception as e:
logger.warning(f"预热缩略图失败 {emoji.emoji_hash}: {e}")
failed += 1
return ThumbnailPreheatResponse(
success=True,
message=f"预热完成:生成 {generated} 个,跳过 {skipped} 个已缓存,失败 {failed}",
@@ -1252,7 +1257,7 @@ async def preheat_thumbnail_cache(
skipped_count=skipped,
failed_count=failed,
)
except HTTPException:
raise
except Exception as e:
@@ -1267,13 +1272,13 @@ async def clear_all_thumbnail_cache(
):
"""
清空所有缩略图缓存(下次访问时会重新生成)
Returns:
清理结果
"""
try:
verify_auth_token(maibot_session, authorization)
if not THUMBNAIL_CACHE_DIR.exists():
return ThumbnailCleanupResponse(
success=True,
@@ -1281,7 +1286,7 @@ async def clear_all_thumbnail_cache(
cleaned_count=0,
kept_count=0,
)
cleaned = 0
for cache_file in THUMBNAIL_CACHE_DIR.glob("*.webp"):
try:
@@ -1289,16 +1294,16 @@ async def clear_all_thumbnail_cache(
cleaned += 1
except Exception as e:
logger.warning(f"删除缓存文件失败 {cache_file.name}: {e}")
logger.info(f"已清空缩略图缓存: 删除 {cleaned} 个文件")
return ThumbnailCleanupResponse(
success=True,
message=f"已清空所有缩略图缓存:删除 {cleaned} 个文件",
cleaned_count=cleaned,
kept_count=0,
)
except HTTPException:
raise
except Exception as e:

View File

@@ -256,7 +256,9 @@ async def get_expression_list(
@router.get("/{expression_id}", response_model=ExpressionDetailResponse)
async def get_expression_detail(expression_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
async def get_expression_detail(
expression_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
):
"""
获取表达方式详细信息
@@ -285,7 +287,11 @@ async def get_expression_detail(expression_id: int, maibot_session: Optional[str
@router.post("/", response_model=ExpressionCreateResponse)
async def create_expression(request: ExpressionCreateRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
async def create_expression(
request: ExpressionCreateRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
"""
创建新的表达方式
@@ -326,7 +332,10 @@ async def create_expression(request: ExpressionCreateRequest, maibot_session: Op
@router.patch("/{expression_id}", response_model=ExpressionUpdateResponse)
async def update_expression(
expression_id: int, request: ExpressionUpdateRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
expression_id: int,
request: ExpressionUpdateRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
"""
增量更新表达方式(只更新提供的字段)
@@ -376,7 +385,9 @@ async def update_expression(
@router.delete("/{expression_id}", response_model=ExpressionDeleteResponse)
async def delete_expression(expression_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
async def delete_expression(
expression_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
):
"""
删除表达方式
@@ -419,7 +430,11 @@ class BatchDeleteRequest(BaseModel):
@router.post("/batch/delete", response_model=ExpressionDeleteResponse)
async def batch_delete_expressions(request: BatchDeleteRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
async def batch_delete_expressions(
request: BatchDeleteRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
"""
批量删除表达方式
@@ -460,7 +475,9 @@ async def batch_delete_expressions(request: BatchDeleteRequest, maibot_session:
@router.get("/stats/summary")
async def get_expression_stats(maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
async def get_expression_stats(
maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
):
"""
获取表达方式统计数据

View File

@@ -24,7 +24,7 @@ def parse_chat_id_to_stream_ids(chat_id_str: str) -> List[str]:
"""
if not chat_id_str:
return []
try:
# 尝试解析为 JSON
parsed = json.loads(chat_id_str)
@@ -49,10 +49,10 @@ def get_display_name_for_chat_id(chat_id_str: str) -> str:
尝试解析 JSON 并查询 ChatStreams 表获取群聊名称
"""
stream_ids = parse_chat_id_to_stream_ids(chat_id_str)
if not stream_ids:
return chat_id_str
# 查询所有 stream_id 对应的名称
names = []
for stream_id in stream_ids:
@@ -62,7 +62,7 @@ def get_display_name_for_chat_id(chat_id_str: str) -> str:
else:
# 如果没找到,显示截断的 stream_id
names.append(stream_id[:8] + "..." if len(stream_id) > 8 else stream_id)
return ", ".join(names) if names else chat_id_str
@@ -187,7 +187,7 @@ def jargon_to_dict(jargon: Jargon) -> dict:
chat_name = get_display_name_for_chat_id(jargon.chat_id) if jargon.chat_id else None
stream_ids = parse_chat_id_to_stream_ids(jargon.chat_id) if jargon.chat_id else []
stream_id = stream_ids[0] if stream_ids else None
return {
"id": jargon.id,
"content": jargon.content,
@@ -277,17 +277,13 @@ async def get_chat_list():
"""获取所有有黑话记录的聊天列表"""
try:
# 获取所有不同的 chat_id
chat_ids = (
Jargon.select(Jargon.chat_id)
.distinct()
.where(Jargon.chat_id.is_null(False))
)
chat_ids = Jargon.select(Jargon.chat_id).distinct().where(Jargon.chat_id.is_null(False))
chat_id_list = [j.chat_id for j in chat_ids if j.chat_id]
# 用于按 stream_id 去重
seen_stream_ids: set[str] = set()
for chat_id in chat_id_list:
stream_ids = parse_chat_id_to_stream_ids(chat_id)
if stream_ids:
@@ -346,12 +342,7 @@ async def get_jargon_stats():
complete_count = Jargon.select().where(Jargon.is_complete).count()
# 关联的聊天数量
chat_count = (
Jargon.select(Jargon.chat_id)
.distinct()
.where(Jargon.chat_id.is_null(False))
.count()
)
chat_count = Jargon.select(Jargon.chat_id).distinct().where(Jargon.chat_id.is_null(False)).count()
# 按聊天统计 TOP 5
top_chats = (
@@ -403,9 +394,7 @@ async def create_jargon(request: JargonCreateRequest):
"""创建黑话"""
try:
# 检查是否已存在相同内容的黑话
existing = Jargon.get_or_none(
(Jargon.content == request.content) & (Jargon.chat_id == request.chat_id)
)
existing = Jargon.get_or_none((Jargon.content == request.content) & (Jargon.chat_id == request.chat_id))
if existing:
raise HTTPException(status_code=400, detail="该聊天中已存在相同内容的黑话")
@@ -527,11 +516,7 @@ async def batch_set_jargon_status(
if not ids:
raise HTTPException(status_code=400, detail="ID列表不能为空")
updated_count = (
Jargon.update(is_jargon=is_jargon)
.where(Jargon.id.in_(ids))
.execute()
)
updated_count = Jargon.update(is_jargon=is_jargon).where(Jargon.id.in_(ids)).execute()
logger.info(f"批量更新黑话状态成功: 更新了 {updated_count} 条记录is_jargon={is_jargon}")

View File

@@ -1,15 +1,24 @@
"""知识库图谱可视化 API 路由"""
from typing import List, Optional
from fastapi import APIRouter, Query
from fastapi import APIRouter, Query, Depends, Cookie, Header
from pydantic import BaseModel
import logging
from src.webui.auth import verify_auth_token_from_cookie_or_header
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/webui/knowledge", tags=["knowledge"])
def require_auth(
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> bool:
"""认证依赖:验证用户是否已登录"""
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
class KnowledgeNode(BaseModel):
"""知识节点"""
@@ -113,6 +122,7 @@ def _convert_graph_to_json(kg_manager) -> KnowledgeGraph:
async def get_knowledge_graph(
limit: int = Query(100, ge=1, le=10000, description="返回的最大节点数"),
node_type: str = Query("all", description="节点类型过滤: all, entity, paragraph"),
_auth: bool = Depends(require_auth),
):
"""获取知识图谱(限制节点数量)
@@ -199,7 +209,7 @@ async def get_knowledge_graph(
@router.get("/stats", response_model=KnowledgeStats)
async def get_knowledge_stats():
async def get_knowledge_stats(_auth: bool = Depends(require_auth)):
"""获取知识库统计信息
Returns:
@@ -248,7 +258,7 @@ async def get_knowledge_stats():
@router.get("/search", response_model=List[KnowledgeNode])
async def search_knowledge_node(query: str = Query(..., min_length=1)):
async def search_knowledge_node(query: str = Query(..., min_length=1), _auth: bool = Depends(require_auth)):
"""搜索知识节点
Args:

View File

@@ -1,10 +1,12 @@
"""WebSocket 日志推送模块"""
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from typing import Set
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query
from typing import Set, Optional
import json
from pathlib import Path
from src.common.logger import get_logger
from src.webui.token_manager import get_token_manager
from src.webui.ws_auth import verify_ws_token
logger = get_logger("webui.logs_ws")
router = APIRouter()
@@ -73,14 +75,48 @@ def load_recent_logs(limit: int = 100) -> list[dict]:
@router.websocket("/ws/logs")
async def websocket_logs(websocket: WebSocket):
async def websocket_logs(websocket: WebSocket, token: Optional[str] = Query(None)):
"""WebSocket 日志推送端点
客户端连接后会持续接收服务器端的日志消息
支持三种认证方式(按优先级):
1. query 参数 token推荐通过 /api/webui/ws-token 获取临时 token
2. Cookie 中的 maibot_session
3. 直接使用 session token兼容
示例ws://host/ws/logs?token=xxx
"""
is_authenticated = False
# 方式 1: 尝试验证临时 WebSocket token推荐方式
if token and verify_ws_token(token):
is_authenticated = True
logger.debug("WebSocket 使用临时 token 认证成功")
# 方式 2: 尝试从 Cookie 获取 session token
if not is_authenticated:
cookie_token = websocket.cookies.get("maibot_session")
if cookie_token:
token_manager = get_token_manager()
if token_manager.verify_token(cookie_token):
is_authenticated = True
logger.debug("WebSocket 使用 Cookie 认证成功")
# 方式 3: 尝试直接验证 query 参数作为 session token兼容旧方式
if not is_authenticated and token:
token_manager = get_token_manager()
if token_manager.verify_token(token):
is_authenticated = True
logger.debug("WebSocket 使用 session token 认证成功")
if not is_authenticated:
logger.warning("WebSocket 连接被拒绝:认证失败")
await websocket.close(code=4001, reason="认证失败,请重新登录")
return
await websocket.accept()
active_connections.add(websocket)
logger.info(f"📡 WebSocket 客户端已连接,当前连接数: {len(active_connections)}")
logger.info(f"📡 WebSocket 客户端已连接(已认证),当前连接数: {len(active_connections)}")
# 连接建立后,立即发送历史日志
try:

View File

@@ -6,18 +6,27 @@
import os
import httpx
from fastapi import APIRouter, HTTPException, Query
from fastapi import APIRouter, HTTPException, Query, Depends, Cookie, Header
from typing import Optional
import tomlkit
from src.common.logger import get_logger
from src.config.config import CONFIG_DIR
from src.webui.auth import verify_auth_token_from_cookie_or_header
logger = get_logger("webui")
router = APIRouter(prefix="/models", tags=["models"])
def require_auth(
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> bool:
"""认证依赖:验证用户是否已登录"""
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
# 模型获取器配置
MODEL_FETCHER_CONFIG = {
# OpenAI 兼容格式的提供商
@@ -184,6 +193,7 @@ async def get_provider_models(
provider_name: str = Query(..., description="提供商名称"),
parser: str = Query("openai", description="响应解析器类型 (openai | gemini)"),
endpoint: str = Query("/models", description="获取模型列表的端点"),
_auth: bool = Depends(require_auth),
):
"""
获取指定提供商的可用模型列表
@@ -228,6 +238,7 @@ async def get_models_by_url(
parser: str = Query("openai", description="响应解析器类型 (openai | gemini)"),
endpoint: str = Query("/models", description="获取模型列表的端点"),
client_type: str = Query("openai", description="客户端类型 (openai | gemini)"),
_auth: bool = Depends(require_auth),
):
"""
通过 URL 直接获取模型列表(用于自定义提供商)
@@ -251,6 +262,7 @@ async def get_models_by_url(
async def test_provider_connection(
base_url: str = Query(..., description="提供商的基础 URL"),
api_key: Optional[str] = Query(None, description="API Key可选用于验证 Key 有效性)"),
_auth: bool = Depends(require_auth),
):
"""
测试提供商连接状态
@@ -337,6 +349,7 @@ async def test_provider_connection(
@router.post("/test-connection-by-name")
async def test_provider_connection_by_name(
provider_name: str = Query(..., description="提供商名称"),
_auth: bool = Depends(require_auth),
):
"""
通过提供商名称测试连接(从配置文件读取信息)

View File

@@ -200,7 +200,9 @@ async def get_person_list(
@router.get("/{person_id}", response_model=PersonDetailResponse)
async def get_person_detail(person_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
async def get_person_detail(
person_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
):
"""
获取人物详细信息
@@ -229,7 +231,12 @@ async def get_person_detail(person_id: str, maibot_session: Optional[str] = Cook
@router.patch("/{person_id}", response_model=PersonUpdateResponse)
async def update_person(person_id: str, request: PersonUpdateRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
async def update_person(
person_id: str,
request: PersonUpdateRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
"""
增量更新人物信息(只更新提供的字段)
@@ -278,7 +285,9 @@ async def update_person(person_id: str, request: PersonUpdateRequest, maibot_ses
@router.delete("/{person_id}", response_model=PersonDeleteResponse)
async def delete_person(person_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
async def delete_person(
person_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
):
"""
删除人物信息
@@ -348,7 +357,11 @@ async def get_person_stats(maibot_session: Optional[str] = Cookie(None), authori
@router.post("/batch/delete", response_model=BatchDeleteResponse)
async def batch_delete_persons(request: BatchDeleteRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
async def batch_delete_persons(
request: BatchDeleteRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
"""
批量删除人物信息

View File

@@ -1,10 +1,12 @@
"""WebSocket 插件加载进度推送模块"""
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from typing import Set, Dict, Any
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query
from typing import Set, Dict, Any, Optional
import json
import asyncio
from src.common.logger import get_logger
from src.webui.token_manager import get_token_manager
from src.webui.ws_auth import verify_ws_token
logger = get_logger("webui.plugin_progress")
@@ -89,14 +91,48 @@ async def update_progress(
@router.websocket("/ws/plugin-progress")
async def websocket_plugin_progress(websocket: WebSocket):
async def websocket_plugin_progress(websocket: WebSocket, token: Optional[str] = Query(None)):
"""WebSocket 插件加载进度推送端点
客户端连接后会立即收到当前进度状态
支持三种认证方式(按优先级):
1. query 参数 token推荐通过 /api/webui/ws-token 获取临时 token
2. Cookie 中的 maibot_session
3. 直接使用 session token兼容
示例ws://host/ws/plugin-progress?token=xxx
"""
is_authenticated = False
# 方式 1: 尝试验证临时 WebSocket token推荐方式
if token and verify_ws_token(token):
is_authenticated = True
logger.debug("插件进度 WebSocket 使用临时 token 认证成功")
# 方式 2: 尝试从 Cookie 获取 session token
if not is_authenticated:
cookie_token = websocket.cookies.get("maibot_session")
if cookie_token:
token_manager = get_token_manager()
if token_manager.verify_token(cookie_token):
is_authenticated = True
logger.debug("插件进度 WebSocket 使用 Cookie 认证成功")
# 方式 3: 尝试直接验证 query 参数作为 session token兼容旧方式
if not is_authenticated and token:
token_manager = get_token_manager()
if token_manager.verify_token(token):
is_authenticated = True
logger.debug("插件进度 WebSocket 使用 session token 认证成功")
if not is_authenticated:
logger.warning("插件进度 WebSocket 连接被拒绝:认证失败")
await websocket.close(code=4001, reason="认证失败,请重新登录")
return
await websocket.accept()
active_connections.add(websocket)
logger.info(f"📡 插件进度 WebSocket 客户端已连接,当前连接数: {len(active_connections)}")
logger.info(f"📡 插件进度 WebSocket 客户端已连接(已认证),当前连接数: {len(active_connections)}")
try:
# 发送当前进度状态

View File

@@ -34,6 +34,85 @@ def get_token_from_cookie_or_header(
return None
def validate_safe_path(user_path: str, base_path: Path) -> Path:
"""
验证用户提供的路径是否安全,防止路径遍历攻击
Args:
user_path: 用户输入的路径(相对路径)
base_path: 允许的基础目录
Returns:
安全的绝对路径
Raises:
HTTPException: 如果检测到路径遍历攻击
"""
# 规范化基础路径
base_resolved = base_path.resolve()
# 检查用户路径是否包含可疑字符
# 禁止: .., 绝对路径开头, 空字节等
if any(pattern in user_path for pattern in ["..", "\x00"]):
logger.warning(f"检测到可疑路径: {user_path}")
raise HTTPException(status_code=400, detail="路径包含非法字符")
# 检查是否为绝对路径Windows 和 Unix
if user_path.startswith("/") or user_path.startswith("\\") or (len(user_path) > 1 and user_path[1] == ":"):
logger.warning(f"检测到绝对路径: {user_path}")
raise HTTPException(status_code=400, detail="不允许使用绝对路径")
# 构建目标路径并解析
target_path = (base_path / user_path).resolve()
# 验证解析后的路径仍在基础目录内
try:
target_path.relative_to(base_resolved)
except ValueError as e:
logger.warning(f"路径遍历攻击检测: {user_path} -> {target_path}")
raise HTTPException(status_code=400, detail="路径超出允许范围") from e
return target_path
def validate_plugin_id(plugin_id: str) -> str:
"""
验证插件 ID 格式是否安全
Args:
plugin_id: 插件 ID (支持 author.name 格式,允许中文)
Returns:
验证通过的插件 ID
Raises:
HTTPException: 如果插件 ID 格式不安全
"""
# 禁止空字符串
if not plugin_id or not plugin_id.strip():
logger.warning("非法插件 ID: 空字符串")
raise HTTPException(status_code=400, detail="插件 ID 不能为空")
# 禁止危险字符: 路径分隔符、空字节、控制字符等
dangerous_patterns = ["/", "\\", "\x00", "..", "\n", "\r", "\t"]
for pattern in dangerous_patterns:
if pattern in plugin_id:
logger.warning(f"非法插件 ID 格式: {plugin_id} (包含危险字符)")
raise HTTPException(status_code=400, detail="插件 ID 包含非法字符")
# 禁止以点开头或结尾(防止隐藏文件和路径问题)
if plugin_id.startswith(".") or plugin_id.endswith("."):
logger.warning(f"非法插件 ID: {plugin_id}")
raise HTTPException(status_code=400, detail="插件 ID 不能以点开头或结尾")
# 禁止特殊名称
if plugin_id in (".", ".."):
logger.warning(f"非法插件 ID: {plugin_id}")
raise HTTPException(status_code=400, detail="插件 ID 不能为特殊目录名")
return plugin_id
def parse_version(version_str: str) -> tuple[int, int, int]:
"""
解析版本号字符串
@@ -125,6 +204,7 @@ def coerce_types(schema_part: Dict[str, Any], config_part: Dict[str, Any]) -> No
"""
根据 schema 将配置中的类型纠正(目前只纠正 list-from-str
"""
def _is_list_type(tp: Any) -> bool:
origin = get_origin(tp)
return tp is list or origin is list
@@ -313,7 +393,9 @@ async def check_git_status() -> GitStatusResponse:
@router.get("/mirrors", response_model=AvailableMirrorsResponse)
async def get_available_mirrors(maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> AvailableMirrorsResponse:
async def get_available_mirrors(
maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
) -> AvailableMirrorsResponse:
"""
获取所有可用的镜像源配置
"""
@@ -343,7 +425,9 @@ async def get_available_mirrors(maibot_session: Optional[str] = Cookie(None), au
@router.post("/mirrors", response_model=MirrorConfigResponse)
async def add_mirror(request: AddMirrorRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> MirrorConfigResponse:
async def add_mirror(
request: AddMirrorRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
) -> MirrorConfigResponse:
"""
添加新的镜像源
"""
@@ -383,7 +467,10 @@ async def add_mirror(request: AddMirrorRequest, maibot_session: Optional[str] =
@router.put("/mirrors/{mirror_id}", response_model=MirrorConfigResponse)
async def update_mirror(
mirror_id: str, request: UpdateMirrorRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
mirror_id: str,
request: UpdateMirrorRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> MirrorConfigResponse:
"""
更新镜像源配置
@@ -426,7 +513,9 @@ async def update_mirror(
@router.delete("/mirrors/{mirror_id}")
async def delete_mirror(mirror_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
async def delete_mirror(
mirror_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
) -> Dict[str, Any]:
"""
删除镜像源
"""
@@ -449,26 +538,24 @@ async def delete_mirror(mirror_id: str, maibot_session: Optional[str] = Cookie(N
@router.post("/fetch-raw", response_model=FetchRawFileResponse)
async def fetch_raw_file(
request: FetchRawFileRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
request: FetchRawFileRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> FetchRawFileResponse:
"""
获取 GitHub 仓库的 Raw 文件内容
支持多镜像源自动切换和错误重试
注意:此接口可公开访问,用于获取插件仓库等公开资源
需要认证才能访问,防止被滥用作为 SSRF 跳板
"""
# Token 验证(可选,用于日志记录
# Token 验证(强制
token = get_token_from_cookie_or_header(maibot_session, authorization)
token_manager = get_token_manager()
is_authenticated = token and token_manager.verify_token(token)
if not token or not token_manager.verify_token(token):
raise HTTPException(status_code=401, detail="未授权:无效的访问令牌")
# 对于公开仓库的访问,不强制要求认证
# 只在日志中记录是否认证
logger.info(
f"收到获取 Raw 文件请求 (认证: {is_authenticated}): "
f"{request.owner}/{request.repo}/{request.branch}/{request.file_path}"
)
logger.info(f"收到获取 Raw 文件请求: {request.owner}/{request.repo}/{request.branch}/{request.file_path}")
# 发送开始加载进度
await update_progress(
@@ -534,7 +621,9 @@ async def fetch_raw_file(
@router.post("/clone", response_model=CloneRepositoryResponse)
async def clone_repository(
request: CloneRepositoryRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
request: CloneRepositoryRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> CloneRepositoryResponse:
"""
克隆 GitHub 仓库到本地
@@ -550,10 +639,10 @@ async def clone_repository(
logger.info(f"收到克隆仓库请求: {request.owner}/{request.repo} -> {request.target_path}")
try:
# TODO: 验证 target_path 的安全性,防止路径遍历攻击
# TODO: 确定实际的插件目录基路径
base_plugin_path = Path("./plugins") # 临时路径
target_path = base_plugin_path / request.target_path
# 验证 target_path 的安全性,防止路径遍历攻击
base_plugin_path = Path("./plugins").resolve()
base_plugin_path.mkdir(exist_ok=True)
target_path = validate_safe_path(request.target_path, base_plugin_path)
service = get_git_mirror_service()
result = await service.clone_repository(
@@ -574,7 +663,11 @@ async def clone_repository(
@router.post("/install")
async def install_plugin(request: InstallPluginRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
async def install_plugin(
request: InstallPluginRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> Dict[str, Any]:
"""
安装插件
@@ -589,13 +682,16 @@ async def install_plugin(request: InstallPluginRequest, maibot_session: Optional
logger.info(f"收到安装插件请求: {request.plugin_id}")
try:
# 验证插件 ID 格式安全性
plugin_id = validate_plugin_id(request.plugin_id)
# 推送进度:开始安装
await update_progress(
stage="loading",
progress=5,
message=f"开始安装插件: {request.plugin_id}",
message=f"开始安装插件: {plugin_id}",
operation="install",
plugin_id=request.plugin_id,
plugin_id=plugin_id,
)
# 1. 解析仓库 URL
@@ -616,27 +712,28 @@ async def install_plugin(request: InstallPluginRequest, maibot_session: Optional
progress=10,
message=f"解析仓库信息: {owner}/{repo}",
operation="install",
plugin_id=request.plugin_id,
plugin_id=plugin_id,
)
# 2. 确定插件安装路径
plugins_dir = Path("plugins")
plugins_dir = Path("plugins").resolve()
plugins_dir.mkdir(exist_ok=True)
# 将插件 ID 中的点替换为下划线作为文件夹名称(避免文件系统问题)
# 例如: SengokuCola.Mute-Plugin -> SengokuCola_Mute-Plugin
folder_name = request.plugin_id.replace(".", "_")
target_path = plugins_dir / folder_name
folder_name = plugin_id.replace(".", "_")
# 使用安全路径验证,防止路径遍历
target_path = validate_safe_path(folder_name, plugins_dir)
# 检查插件是否已安装(需要检查两种格式:新格式下划线和旧格式点)
old_format_path = plugins_dir / request.plugin_id
old_format_path = plugins_dir / plugin_id
if target_path.exists() or old_format_path.exists():
await update_progress(
stage="error",
progress=0,
message="插件已存在",
operation="install",
plugin_id=request.plugin_id,
plugin_id=plugin_id,
error="插件已安装,请先卸载",
)
raise HTTPException(status_code=400, detail="插件已安装")
@@ -646,7 +743,7 @@ async def install_plugin(request: InstallPluginRequest, maibot_session: Optional
progress=15,
message=f"准备克隆到: {target_path}",
operation="install",
plugin_id=request.plugin_id,
plugin_id=plugin_id,
)
# 3. 克隆仓库(这里会自动推送 20%-80% 的进度)
@@ -675,14 +772,14 @@ async def install_plugin(request: InstallPluginRequest, maibot_session: Optional
progress=0,
message="克隆仓库失败",
operation="install",
plugin_id=request.plugin_id,
plugin_id=plugin_id,
error=error_msg,
)
raise HTTPException(status_code=500, detail=error_msg)
# 4. 验证插件完整性
await update_progress(
stage="loading", progress=85, message="验证插件文件...", operation="install", plugin_id=request.plugin_id
stage="loading", progress=85, message="验证插件文件...", operation="install", plugin_id=plugin_id
)
manifest_path = target_path / "_manifest.json"
@@ -697,14 +794,14 @@ async def install_plugin(request: InstallPluginRequest, maibot_session: Optional
progress=0,
message="插件缺少 _manifest.json",
operation="install",
plugin_id=request.plugin_id,
plugin_id=plugin_id,
error="无效的插件格式",
)
raise HTTPException(status_code=400, detail="无效的插件:缺少 _manifest.json")
# 5. 读取并验证 manifest
await update_progress(
stage="loading", progress=90, message="读取插件配置...", operation="install", plugin_id=request.plugin_id
stage="loading", progress=90, message="读取插件配置...", operation="install", plugin_id=plugin_id
)
try:
@@ -721,7 +818,7 @@ async def install_plugin(request: InstallPluginRequest, maibot_session: Optional
# 将插件 ID 写入 manifest用于后续准确识别
# 这样即使文件夹名称改变,也能通过 manifest 准确识别插件
manifest["id"] = request.plugin_id
manifest["id"] = plugin_id
with open(manifest_path, "w", encoding="utf-8") as f:
json_module.dump(manifest, f, ensure_ascii=False, indent=2)
@@ -736,7 +833,7 @@ async def install_plugin(request: InstallPluginRequest, maibot_session: Optional
progress=0,
message="_manifest.json 无效",
operation="install",
plugin_id=request.plugin_id,
plugin_id=plugin_id,
error=str(e),
)
raise HTTPException(status_code=400, detail=f"无效的 _manifest.json: {e}") from e
@@ -747,13 +844,13 @@ async def install_plugin(request: InstallPluginRequest, maibot_session: Optional
progress=100,
message=f"成功安装插件: {manifest['name']} v{manifest['version']}",
operation="install",
plugin_id=request.plugin_id,
plugin_id=plugin_id,
)
return {
"success": True,
"message": "插件安装成功",
"plugin_id": request.plugin_id,
"plugin_id": plugin_id,
"plugin_name": manifest["name"],
"version": manifest["version"],
"path": str(target_path),
@@ -769,7 +866,7 @@ async def install_plugin(request: InstallPluginRequest, maibot_session: Optional
progress=0,
message="安装失败",
operation="install",
plugin_id=request.plugin_id,
plugin_id=plugin_id,
error=str(e),
)
@@ -778,7 +875,9 @@ async def install_plugin(request: InstallPluginRequest, maibot_session: Optional
@router.post("/uninstall")
async def uninstall_plugin(
request: UninstallPluginRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
request: UninstallPluginRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> Dict[str, Any]:
"""
卸载插件
@@ -794,22 +893,26 @@ async def uninstall_plugin(
logger.info(f"收到卸载插件请求: {request.plugin_id}")
try:
# 验证插件 ID 格式安全性
plugin_id = validate_plugin_id(request.plugin_id)
# 推送进度:开始卸载
await update_progress(
stage="loading",
progress=10,
message=f"开始卸载插件: {request.plugin_id}",
message=f"开始卸载插件: {plugin_id}",
operation="uninstall",
plugin_id=request.plugin_id,
plugin_id=plugin_id,
)
# 1. 检查插件是否存在(支持新旧两种格式)
plugins_dir = Path("plugins")
plugins_dir = Path("plugins").resolve()
# 新格式:下划线
folder_name = request.plugin_id.replace(".", "_")
plugin_path = plugins_dir / folder_name
folder_name = plugin_id.replace(".", "_")
# 使用安全路径验证
plugin_path = validate_safe_path(folder_name, plugins_dir)
# 旧格式:点
old_format_path = plugins_dir / request.plugin_id
old_format_path = validate_safe_path(plugin_id, plugins_dir)
# 优先使用新格式,如果不存在则尝试旧格式
if not plugin_path.exists():
@@ -821,7 +924,7 @@ async def uninstall_plugin(
progress=0,
message="插件不存在",
operation="uninstall",
plugin_id=request.plugin_id,
plugin_id=plugin_id,
error="插件未安装或已被删除",
)
raise HTTPException(status_code=404, detail="插件未安装")
@@ -831,12 +934,12 @@ async def uninstall_plugin(
progress=30,
message=f"正在删除插件文件: {plugin_path}",
operation="uninstall",
plugin_id=request.plugin_id,
plugin_id=plugin_id,
)
# 2. 读取插件信息(用于日志)
manifest_path = plugin_path / "_manifest.json"
plugin_name = request.plugin_id
plugin_name = plugin_id
if manifest_path.exists():
try:
@@ -844,7 +947,7 @@ async def uninstall_plugin(
with open(manifest_path, "r", encoding="utf-8") as f:
manifest = json_module.load(f)
plugin_name = manifest.get("name", request.plugin_id)
plugin_name = manifest.get("name", plugin_id)
except Exception:
pass # 如果读取失败,使用插件 ID 作为名称
@@ -853,7 +956,7 @@ async def uninstall_plugin(
progress=50,
message=f"正在删除 {plugin_name}...",
operation="uninstall",
plugin_id=request.plugin_id,
plugin_id=plugin_id,
)
# 3. 删除插件目录
@@ -869,7 +972,7 @@ async def uninstall_plugin(
shutil.rmtree(plugin_path, onerror=remove_readonly)
logger.info(f"成功卸载插件: {request.plugin_id} ({plugin_name})")
logger.info(f"成功卸载插件: {plugin_id} ({plugin_name})")
# 4. 推送成功状态
await update_progress(
@@ -877,10 +980,10 @@ async def uninstall_plugin(
progress=100,
message=f"成功卸载插件: {plugin_name}",
operation="uninstall",
plugin_id=request.plugin_id,
plugin_id=plugin_id,
)
return {"success": True, "message": "插件卸载成功", "plugin_id": request.plugin_id, "plugin_name": plugin_name}
return {"success": True, "message": "插件卸载成功", "plugin_id": plugin_id, "plugin_name": plugin_name}
except HTTPException:
raise
@@ -892,7 +995,7 @@ async def uninstall_plugin(
progress=0,
message="卸载失败",
operation="uninstall",
plugin_id=request.plugin_id,
plugin_id=plugin_id,
error="权限不足,无法删除插件文件",
)
@@ -905,7 +1008,7 @@ async def uninstall_plugin(
progress=0,
message="卸载失败",
operation="uninstall",
plugin_id=request.plugin_id,
plugin_id=plugin_id,
error=str(e),
)
@@ -913,7 +1016,11 @@ async def uninstall_plugin(
@router.post("/update")
async def update_plugin(request: UpdatePluginRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
async def update_plugin(
request: UpdatePluginRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> Dict[str, Any]:
"""
更新插件
@@ -928,22 +1035,26 @@ async def update_plugin(request: UpdatePluginRequest, maibot_session: Optional[s
logger.info(f"收到更新插件请求: {request.plugin_id}")
try:
# 验证插件 ID 格式安全性
plugin_id = validate_plugin_id(request.plugin_id)
# 推送进度:开始更新
await update_progress(
stage="loading",
progress=5,
message=f"开始更新插件: {request.plugin_id}",
message=f"开始更新插件: {plugin_id}",
operation="update",
plugin_id=request.plugin_id,
plugin_id=plugin_id,
)
# 1. 检查插件是否已安装(支持新旧两种格式)
plugins_dir = Path("plugins")
plugins_dir = Path("plugins").resolve()
# 新格式:下划线
folder_name = request.plugin_id.replace(".", "_")
plugin_path = plugins_dir / folder_name
folder_name = plugin_id.replace(".", "_")
# 使用安全路径验证
plugin_path = validate_safe_path(folder_name, plugins_dir)
# 旧格式:点
old_format_path = plugins_dir / request.plugin_id
old_format_path = validate_safe_path(plugin_id, plugins_dir)
# 优先使用新格式,如果不存在则尝试旧格式
if not plugin_path.exists():
@@ -955,7 +1066,7 @@ async def update_plugin(request: UpdatePluginRequest, maibot_session: Optional[s
progress=0,
message="插件不存在",
operation="update",
plugin_id=request.plugin_id,
plugin_id=plugin_id,
error="插件未安装,请先安装",
)
raise HTTPException(status_code=404, detail="插件未安装")
@@ -979,12 +1090,12 @@ async def update_plugin(request: UpdatePluginRequest, maibot_session: Optional[s
progress=10,
message=f"当前版本: {old_version},准备更新...",
operation="update",
plugin_id=request.plugin_id,
plugin_id=plugin_id,
)
# 3. 删除旧版本
await update_progress(
stage="loading", progress=20, message="正在删除旧版本...", operation="update", plugin_id=request.plugin_id
stage="loading", progress=20, message="正在删除旧版本...", operation="update", plugin_id=plugin_id
)
import shutil
@@ -999,7 +1110,7 @@ async def update_plugin(request: UpdatePluginRequest, maibot_session: Optional[s
shutil.rmtree(plugin_path, onerror=remove_readonly)
logger.info(f"已删除旧版本: {request.plugin_id} v{old_version}")
logger.info(f"已删除旧版本: {plugin_id} v{old_version}")
# 4. 解析仓库 URL
await update_progress(
@@ -1007,7 +1118,7 @@ async def update_plugin(request: UpdatePluginRequest, maibot_session: Optional[s
progress=30,
message="正在准备下载新版本...",
operation="update",
plugin_id=request.plugin_id,
plugin_id=plugin_id,
)
repo_url = request.repository_url.rstrip("/")
@@ -1045,14 +1156,14 @@ async def update_plugin(request: UpdatePluginRequest, maibot_session: Optional[s
progress=0,
message="下载新版本失败",
operation="update",
plugin_id=request.plugin_id,
plugin_id=plugin_id,
error=error_msg,
)
raise HTTPException(status_code=500, detail=error_msg)
# 6. 验证新版本
await update_progress(
stage="loading", progress=90, message="验证新版本...", operation="update", plugin_id=request.plugin_id
stage="loading", progress=90, message="验证新版本...", operation="update", plugin_id=plugin_id
)
new_manifest_path = plugin_path / "_manifest.json"
@@ -1072,7 +1183,7 @@ async def update_plugin(request: UpdatePluginRequest, maibot_session: Optional[s
progress=0,
message="新版本缺少 _manifest.json",
operation="update",
plugin_id=request.plugin_id,
plugin_id=plugin_id,
error="无效的插件格式",
)
raise HTTPException(status_code=400, detail="无效的插件:缺少 _manifest.json")
@@ -1083,9 +1194,9 @@ async def update_plugin(request: UpdatePluginRequest, maibot_session: Optional[s
new_manifest = json_module.load(f)
new_version = new_manifest.get("version", "unknown")
new_name = new_manifest.get("name", request.plugin_id)
new_name = new_manifest.get("name", plugin_id)
logger.info(f"成功更新插件: {request.plugin_id} {old_version}{new_version}")
logger.info(f"成功更新插件: {plugin_id} {old_version}{new_version}")
# 8. 推送成功状态
await update_progress(
@@ -1093,13 +1204,13 @@ async def update_plugin(request: UpdatePluginRequest, maibot_session: Optional[s
progress=100,
message=f"成功更新 {new_name}: {old_version}{new_version}",
operation="update",
plugin_id=request.plugin_id,
plugin_id=plugin_id,
)
return {
"success": True,
"message": "插件更新成功",
"plugin_id": request.plugin_id,
"plugin_id": plugin_id,
"plugin_name": new_name,
"old_version": old_version,
"new_version": new_version,
@@ -1114,7 +1225,7 @@ async def update_plugin(request: UpdatePluginRequest, maibot_session: Optional[s
progress=0,
message="_manifest.json 无效",
operation="update",
plugin_id=request.plugin_id,
plugin_id=plugin_id,
error=str(e),
)
raise HTTPException(status_code=400, detail=f"无效的 _manifest.json: {e}") from e
@@ -1125,14 +1236,16 @@ async def update_plugin(request: UpdatePluginRequest, maibot_session: Optional[s
logger.error(f"更新插件失败: {e}", exc_info=True)
await update_progress(
stage="error", progress=0, message="更新失败", operation="update", plugin_id=request.plugin_id, error=str(e)
stage="error", progress=0, message="更新失败", operation="update", plugin_id=plugin_id, error=str(e)
)
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
@router.get("/installed")
async def get_installed_plugins(maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
async def get_installed_plugins(
maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
) -> Dict[str, Any]:
"""
获取已安装的插件列表
@@ -1272,7 +1385,9 @@ class UpdatePluginConfigRequest(BaseModel):
@router.get("/config/{plugin_id}/schema")
async def get_plugin_config_schema(plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
async def get_plugin_config_schema(
plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
) -> Dict[str, Any]:
"""
获取插件配置 Schema
@@ -1373,12 +1488,34 @@ async def get_plugin_config_schema(plugin_id: str, maibot_session: Optional[str]
# 推断字段类型
field_type = type(field_value).__name__
ui_type = "text"
item_type = None
item_fields = None
if isinstance(field_value, bool):
ui_type = "switch"
elif isinstance(field_value, (int, float)):
ui_type = "number"
elif isinstance(field_value, list):
ui_type = "list"
# 推断数组元素类型
if field_value:
first_item = field_value[0]
if isinstance(first_item, dict):
item_type = "object"
# 从第一个元素推断字段结构
item_fields = {}
for k, v in first_item.items():
item_fields[k] = {
"type": "number" if isinstance(v, (int, float)) else "string",
"label": k,
"default": "" if isinstance(v, str) else 0,
}
elif isinstance(first_item, (int, float)):
item_type = "number"
else:
item_type = "string"
else:
item_type = "string"
elif isinstance(field_value, dict):
ui_type = "json"
@@ -1393,6 +1530,26 @@ async def get_plugin_config_schema(plugin_id: str, maibot_session: Optional[str]
"hidden": False,
"disabled": False,
"order": 0,
"item_type": item_type,
"item_fields": item_fields,
"min_items": None,
"max_items": None,
# 补充缺失的字段
"placeholder": None,
"hint": None,
"icon": None,
"example": None,
"choices": None,
"min": None,
"max": None,
"step": None,
"pattern": None,
"max_length": None,
"input_type": None,
"rows": 3,
"group": None,
"depends_on": None,
"depends_value": None,
}
return {"success": True, "schema": schema}
@@ -1405,7 +1562,9 @@ async def get_plugin_config_schema(plugin_id: str, maibot_session: Optional[str]
@router.get("/config/{plugin_id}")
async def get_plugin_config(plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
async def get_plugin_config(
plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
) -> Dict[str, Any]:
"""
获取插件当前配置值
@@ -1461,7 +1620,10 @@ async def get_plugin_config(plugin_id: str, maibot_session: Optional[str] = Cook
@router.put("/config/{plugin_id}")
async def update_plugin_config(
plugin_id: str, request: UpdatePluginConfigRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
plugin_id: str,
request: UpdatePluginConfigRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> Dict[str, Any]:
"""
更新插件配置
@@ -1532,7 +1694,9 @@ async def update_plugin_config(
@router.post("/config/{plugin_id}/reset")
async def reset_plugin_config(plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
async def reset_plugin_config(
plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
) -> Dict[str, Any]:
"""
重置插件配置为默认值
@@ -1592,7 +1756,9 @@ async def reset_plugin_config(plugin_id: str, maibot_session: Optional[str] = Co
@router.post("/config/{plugin_id}/toggle")
async def toggle_plugin(plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
async def toggle_plugin(
plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
) -> Dict[str, Any]:
"""
切换插件启用状态

245
src/webui/rate_limiter.py Normal file
View File

@@ -0,0 +1,245 @@
"""
WebUI 请求频率限制模块
防止暴力破解和 API 滥用
"""
import time
from collections import defaultdict
from typing import Dict, Tuple, Optional
from fastapi import Request, HTTPException
from src.common.logger import get_logger
logger = get_logger("webui.rate_limiter")
class RateLimiter:
"""
简单的内存请求频率限制器
使用滑动窗口算法实现
"""
def __init__(self):
# 存储格式: {key: [(timestamp, count), ...]}
self._requests: Dict[str, list] = defaultdict(list)
# 被封禁的 IP: {ip: unblock_timestamp}
self._blocked: Dict[str, float] = {}
def _get_client_ip(self, request: Request) -> str:
"""获取客户端 IP 地址"""
# 检查代理头
forwarded = request.headers.get("X-Forwarded-For")
if forwarded:
# 取第一个 IP最原始的客户端
return forwarded.split(",")[0].strip()
real_ip = request.headers.get("X-Real-IP")
if real_ip:
return real_ip
# 直接连接的客户端
if request.client:
return request.client.host
return "unknown"
def _cleanup_old_requests(self, key: str, window_seconds: int):
"""清理过期的请求记录"""
now = time.time()
cutoff = now - window_seconds
self._requests[key] = [(ts, count) for ts, count in self._requests[key] if ts > cutoff]
def _cleanup_expired_blocks(self):
"""清理过期的封禁"""
now = time.time()
expired = [ip for ip, unblock_time in self._blocked.items() if now > unblock_time]
for ip in expired:
del self._blocked[ip]
logger.info(f"🔓 IP {ip} 封禁已解除")
def is_blocked(self, request: Request) -> Tuple[bool, Optional[int]]:
"""
检查 IP 是否被封禁
Returns:
(是否被封禁, 剩余封禁秒数)
"""
self._cleanup_expired_blocks()
ip = self._get_client_ip(request)
if ip in self._blocked:
remaining = int(self._blocked[ip] - time.time())
return True, max(0, remaining)
return False, None
def check_rate_limit(
self, request: Request, max_requests: int, window_seconds: int, key_suffix: str = ""
) -> Tuple[bool, int]:
"""
检查请求是否超过频率限制
Args:
request: FastAPI Request 对象
max_requests: 窗口期内允许的最大请求数
window_seconds: 窗口时间(秒)
key_suffix: 键后缀,用于区分不同的限制规则
Returns:
(是否允许, 剩余请求数)
"""
ip = self._get_client_ip(request)
key = f"{ip}:{key_suffix}" if key_suffix else ip
# 清理过期记录
self._cleanup_old_requests(key, window_seconds)
# 计算当前窗口内的请求数
current_count = sum(count for _, count in self._requests[key])
if current_count >= max_requests:
return False, 0
# 记录新请求
now = time.time()
self._requests[key].append((now, 1))
remaining = max_requests - current_count - 1
return True, remaining
def block_ip(self, request: Request, duration_seconds: int):
"""
封禁 IP
Args:
request: FastAPI Request 对象
duration_seconds: 封禁时长(秒)
"""
ip = self._get_client_ip(request)
self._blocked[ip] = time.time() + duration_seconds
logger.warning(f"🔒 IP {ip} 已被封禁 {duration_seconds}")
def record_failed_attempt(
self, request: Request, max_failures: int = 5, window_seconds: int = 300, block_duration: int = 600
) -> Tuple[bool, int]:
"""
记录失败尝试(如登录失败)
如果在窗口期内失败次数过多,自动封禁 IP
Args:
request: FastAPI Request 对象
max_failures: 允许的最大失败次数
window_seconds: 统计窗口(秒)
block_duration: 封禁时长(秒)
Returns:
(是否被封禁, 剩余尝试次数)
"""
ip = self._get_client_ip(request)
key = f"{ip}:auth_failures"
# 清理过期记录
self._cleanup_old_requests(key, window_seconds)
# 计算当前失败次数
current_failures = sum(count for _, count in self._requests[key])
# 记录本次失败
now = time.time()
self._requests[key].append((now, 1))
current_failures += 1
remaining = max_failures - current_failures
# 检查是否需要封禁
if current_failures >= max_failures:
self.block_ip(request, block_duration)
logger.warning(f"⚠️ IP {ip} 认证失败次数过多 ({current_failures}/{max_failures}),已封禁")
return True, 0
if current_failures >= max_failures - 2:
logger.warning(f"⚠️ IP {ip} 认证失败 {current_failures}/{max_failures}")
return False, max(0, remaining)
def reset_failures(self, request: Request):
"""
重置失败计数(认证成功后调用)
"""
ip = self._get_client_ip(request)
key = f"{ip}:auth_failures"
if key in self._requests:
del self._requests[key]
# 全局单例
_rate_limiter: Optional[RateLimiter] = None
def get_rate_limiter() -> RateLimiter:
"""获取 RateLimiter 单例"""
global _rate_limiter
if _rate_limiter is None:
_rate_limiter = RateLimiter()
return _rate_limiter
async def check_auth_rate_limit(request: Request):
"""
认证接口的频率限制依赖
规则:
- 每个 IP 每分钟最多 10 次认证请求
- 连续失败 5 次后封禁 10 分钟
"""
limiter = get_rate_limiter()
# 检查是否被封禁
blocked, remaining_block = limiter.is_blocked(request)
if blocked:
raise HTTPException(
status_code=429,
detail=f"请求过于频繁,请在 {remaining_block} 秒后重试",
headers={"Retry-After": str(remaining_block)},
)
# 检查频率限制
allowed, remaining = limiter.check_rate_limit(
request,
max_requests=10, # 每分钟 10 次
window_seconds=60,
key_suffix="auth",
)
if not allowed:
raise HTTPException(status_code=429, detail="认证请求过于频繁,请稍后重试", headers={"Retry-After": "60"})
async def check_api_rate_limit(request: Request):
"""
普通 API 的频率限制依赖
规则:每个 IP 每分钟最多 100 次请求
"""
limiter = get_rate_limiter()
# 检查是否被封禁
blocked, remaining_block = limiter.is_blocked(request)
if blocked:
raise HTTPException(
status_code=429,
detail=f"请求过于频繁,请在 {remaining_block} 秒后重试",
headers={"Retry-After": str(remaining_block)},
)
# 检查频率限制
allowed, _ = limiter.check_rate_limit(
request,
max_requests=100, # 每分钟 100 次
window_seconds=60,
key_suffix="api",
)
if not allowed:
raise HTTPException(status_code=429, detail="请求过于频繁,请稍后重试", headers={"Retry-After": "60"})

View File

@@ -7,10 +7,12 @@
import os
import time
from datetime import datetime
from fastapi import APIRouter, HTTPException
from typing import Optional
from fastapi import APIRouter, HTTPException, Depends, Cookie, Header
from pydantic import BaseModel
from src.config.config import MMC_VERSION
from src.common.logger import get_logger
from src.webui.auth import verify_auth_token_from_cookie_or_header
router = APIRouter(prefix="/system", tags=["system"])
logger = get_logger("webui_system")
@@ -19,6 +21,14 @@ logger = get_logger("webui_system")
_start_time = time.time()
def require_auth(
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> bool:
"""认证依赖:验证用户是否已登录"""
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
class RestartResponse(BaseModel):
"""重启响应"""
@@ -36,7 +46,7 @@ class StatusResponse(BaseModel):
@router.post("/restart", response_model=RestartResponse)
async def restart_maibot():
async def restart_maibot(_auth: bool = Depends(require_auth)):
"""
重启麦麦主程序
@@ -67,7 +77,7 @@ async def restart_maibot():
@router.get("/status", response_model=StatusResponse)
async def get_maibot_status():
async def get_maibot_status(_auth: bool = Depends(require_auth)):
"""
获取麦麦运行状态
@@ -90,7 +100,7 @@ async def get_maibot_status():
@router.post("/reload-config")
async def reload_config():
async def reload_config(_auth: bool = Depends(require_auth)):
"""
热重载配置(不重启进程)

View File

@@ -1,11 +1,12 @@
"""WebUI API 路由"""
from fastapi import APIRouter, HTTPException, Header, Response, Request, Cookie
from fastapi import APIRouter, HTTPException, Header, Response, Request, Cookie, Depends
from pydantic import BaseModel, Field
from typing import Optional
from src.common.logger import get_logger
from .token_manager import get_token_manager
from .auth import set_auth_cookie, clear_auth_cookie
from .rate_limiter import get_rate_limiter, check_auth_rate_limit
from .config_routes import router as config_router
from .statistics_routes import router as statistics_router
from .person_routes import router as person_router
@@ -16,6 +17,7 @@ from .plugin_routes import router as plugin_router
from .plugin_progress_ws import get_progress_router
from .routers.system import router as system_router
from .model_routes import router as model_router
from .ws_auth import router as ws_auth_router
logger = get_logger("webui.api")
@@ -42,6 +44,8 @@ router.include_router(get_progress_router())
router.include_router(system_router)
# 注册模型列表获取路由
router.include_router(model_router)
# 注册 WebSocket 认证路由
router.include_router(ws_auth_router)
class TokenVerifyRequest(BaseModel):
@@ -107,12 +111,18 @@ async def health_check():
@router.post("/auth/verify", response_model=TokenVerifyResponse)
async def verify_token(request: TokenVerifyRequest, response: Response):
async def verify_token(
request_body: TokenVerifyRequest,
request: Request,
response: Response,
_rate_limit: None = Depends(check_auth_rate_limit),
):
"""
验证访问令牌,验证成功后设置 HttpOnly Cookie
Args:
request: 包含 token 的验证请求
request_body: 包含 token 的验证请求
request: FastAPI Request 对象(用于获取客户端 IP
response: FastAPI Response 对象
Returns:
@@ -120,16 +130,37 @@ async def verify_token(request: TokenVerifyRequest, response: Response):
"""
try:
token_manager = get_token_manager()
is_valid = token_manager.verify_token(request.token)
rate_limiter = get_rate_limiter()
is_valid = token_manager.verify_token(request_body.token)
if is_valid:
# 认证成功,重置失败计数
rate_limiter.reset_failures(request)
# 设置 HttpOnly Cookie
set_auth_cookie(response, request.token)
set_auth_cookie(response, request_body.token)
# 同时返回首次配置状态,避免额外请求
is_first_setup = token_manager.is_first_setup()
return TokenVerifyResponse(valid=True, message="Token 验证成功", is_first_setup=is_first_setup)
else:
return TokenVerifyResponse(valid=False, message="Token 无效或已过期")
# 记录失败尝试
blocked, remaining = rate_limiter.record_failed_attempt(
request,
max_failures=5, # 5 次失败
window_seconds=300, # 5 分钟窗口
block_duration=600, # 封禁 10 分钟
)
if blocked:
raise HTTPException(status_code=429, detail="认证失败次数过多,您的 IP 已被临时封禁 10 分钟")
message = "Token 无效或已过期"
if remaining <= 2:
message += f"(剩余 {remaining} 次尝试机会)"
return TokenVerifyResponse(valid=False, message=message)
except HTTPException:
raise
except Exception as e:
logger.error(f"Token 验证失败: {e}")
raise HTTPException(status_code=500, detail="Token 验证失败") from e
@@ -139,10 +170,10 @@ async def verify_token(request: TokenVerifyRequest, response: Response):
async def logout(response: Response):
"""
登出并清除认证 Cookie
Args:
response: FastAPI Response 对象
Returns:
登出结果
"""
@@ -158,23 +189,23 @@ async def check_auth_status(
):
"""
检查当前认证状态(用于前端判断是否已登录)
Returns:
认证状态
"""
try:
token = None
# 优先从 Cookie 获取
if maibot_session:
token = maibot_session
# 其次从 Header 获取
elif authorization and authorization.startswith("Bearer "):
token = authorization.replace("Bearer ", "")
if not token:
return {"authenticated": False}
token_manager = get_token_manager()
if token_manager.verify_token(token):
return {"authenticated": True}
@@ -211,7 +242,7 @@ async def update_token(
current_token = maibot_session
elif authorization and authorization.startswith("Bearer "):
current_token = authorization.replace("Bearer ", "")
if not current_token:
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
@@ -222,7 +253,7 @@ async def update_token(
# 更新 token
success, message = token_manager.update_token(request.new_token)
# 如果更新成功,清除 Cookie要求用户重新登录
if success:
clear_auth_cookie(response)
@@ -263,7 +294,7 @@ async def regenerate_token(
if not current_token:
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
token_manager = get_token_manager()
if not token_manager.verify_token(current_token):
@@ -271,7 +302,7 @@ async def regenerate_token(
# 重新生成 token
new_token = token_manager.regenerate_token()
# 清除 Cookie要求用户重新登录
clear_auth_cookie(response)
@@ -306,7 +337,7 @@ async def get_setup_status(
current_token = maibot_session
elif authorization and authorization.startswith("Bearer "):
current_token = authorization.replace("Bearer ", "")
if not current_token:
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
@@ -349,7 +380,7 @@ async def complete_setup(
current_token = maibot_session
elif authorization and authorization.startswith("Bearer "):
current_token = authorization.replace("Bearer ", "")
if not current_token:
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
@@ -392,7 +423,7 @@ async def reset_setup(
current_token = maibot_session
elif authorization and authorization.startswith("Bearer "):
current_token = authorization.replace("Bearer ", "")
if not current_token:
raise HTTPException(status_code=401, detail="未提供有效的认证信息")

View File

@@ -1,19 +1,28 @@
"""统计数据 API 路由"""
from fastapi import APIRouter, HTTPException
from fastapi import APIRouter, HTTPException, Depends, Cookie, Header
from pydantic import BaseModel, Field
from typing import Dict, Any, List
from typing import Dict, Any, List, Optional
from datetime import datetime, timedelta
from peewee import fn
from src.common.logger import get_logger
from src.common.database.database_model import LLMUsage, OnlineTime, Messages
from src.webui.auth import verify_auth_token_from_cookie_or_header
logger = get_logger("webui.statistics")
router = APIRouter(prefix="/statistics", tags=["statistics"])
def require_auth(
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> bool:
"""认证依赖:验证用户是否已登录"""
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
class StatisticsSummary(BaseModel):
"""统计数据摘要"""
@@ -58,7 +67,7 @@ class DashboardData(BaseModel):
@router.get("/dashboard", response_model=DashboardData)
async def get_dashboard_data(hours: int = 24):
async def get_dashboard_data(hours: int = 24, _auth: bool = Depends(require_auth)):
"""
获取仪表盘统计数据
@@ -275,7 +284,7 @@ async def _get_recent_activity(limit: int = 10) -> List[Dict[str, Any]]:
@router.get("/summary")
async def get_summary(hours: int = 24):
async def get_summary(hours: int = 24, _auth: bool = Depends(require_auth)):
"""
获取统计摘要
@@ -293,7 +302,7 @@ async def get_summary(hours: int = 24):
@router.get("/models")
async def get_model_stats(hours: int = 24):
async def get_model_stats(hours: int = 24, _auth: bool = Depends(require_auth)):
"""
获取模型统计

View File

@@ -166,22 +166,22 @@ class TokenManager:
str: 新生成的 token
"""
logger.info("正在重新生成 WebUI Token...")
# 生成新的 64 位十六进制字符串
new_token = secrets.token_hex(32)
# 加载现有配置,保留 first_setup_completed 状态
config = self._load_config()
old_token = config.get("access_token", "")[:8] if config.get("access_token") else ""
first_setup_completed = config.get("first_setup_completed", True) # 默认为 True表示已完成配置
config["access_token"] = new_token
config["updated_at"] = self._get_current_timestamp()
config["first_setup_completed"] = first_setup_completed # 保留原来的状态
self._save_config(config)
logger.info(f"WebUI Token 已重新生成: {old_token}... -> {new_token[:8]}...")
return new_token
def _validate_token_format(self, token: str) -> bool:

View File

@@ -22,6 +22,9 @@ class WebUIServer:
self.app = FastAPI(title="MaiBot WebUI")
self._server = None
# 配置防爬虫中间件需要在CORS之前注册
self._setup_anti_crawler()
# 配置 CORS支持开发环境跨域请求
self._setup_cors()
@@ -32,6 +35,9 @@ class WebUIServer:
self._register_api_routes()
self._setup_static_files()
# 注册robots.txt路由
self._setup_robots_txt()
def _setup_cors(self):
"""配置 CORS 中间件"""
# 开发环境需要允许前端开发服务器的跨域请求
@@ -40,12 +46,21 @@ class WebUIServer:
allow_origins=[
"http://localhost:5173", # Vite 开发服务器
"http://127.0.0.1:5173",
"http://localhost:7999", # 前端开发服务器备用端口
"http://127.0.0.1:7999",
"http://localhost:8001", # 生产环境
"http://127.0.0.1:8001",
],
allow_credentials=True, # 允许携带 Cookie
allow_methods=["*"],
allow_headers=["*"],
allow_methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"], # 明确指定允许的方法
allow_headers=[
"Content-Type",
"Authorization",
"Accept",
"Origin",
"X-Requested-With",
], # 明确指定允许的头
expose_headers=["Content-Length", "Content-Type"], # 允许前端读取的响应头
)
logger.debug("✅ CORS 中间件已配置")
@@ -89,20 +104,60 @@ class WebUIServer:
"""服务单页应用 - 只处理非 API 请求"""
# 如果是根路径,直接返回 index.html
if not full_path or full_path == "/":
return FileResponse(static_path / "index.html", media_type="text/html")
response = FileResponse(static_path / "index.html", media_type="text/html")
response.headers["X-Robots-Tag"] = "noindex, nofollow, noarchive"
return response
# 检查是否是静态文件
file_path = static_path / full_path
if file_path.is_file() and file_path.exists():
# 自动检测 MIME 类型
media_type = mimetypes.guess_type(str(file_path))[0]
return FileResponse(file_path, media_type=media_type)
response = FileResponse(file_path, media_type=media_type)
# HTML 文件添加防索引头
if str(file_path).endswith(".html"):
response.headers["X-Robots-Tag"] = "noindex, nofollow, noarchive"
return response
# 其他路径返回 index.htmlSPA 路由)
return FileResponse(static_path / "index.html", media_type="text/html")
response = FileResponse(static_path / "index.html", media_type="text/html")
response.headers["X-Robots-Tag"] = "noindex, nofollow, noarchive"
return response
logger.info(f"✅ WebUI 静态文件服务已配置: {static_path}")
def _setup_anti_crawler(self):
"""配置防爬虫中间件"""
try:
from src.webui.anti_crawler import AntiCrawlerMiddleware
# 从环境变量读取防爬虫模式false/strict/loose/basic
anti_crawler_mode = os.getenv("WEBUI_ANTI_CRAWLER_MODE", "basic").lower()
# 注意:中间件按注册顺序反向执行,所以先注册的中间件后执行
# 我们需要在CORS之前注册这样防爬虫检查会在CORS之前执行
self.app.add_middleware(AntiCrawlerMiddleware, mode=anti_crawler_mode)
mode_descriptions = {"false": "已禁用", "strict": "严格模式", "loose": "宽松模式", "basic": "基础模式"}
mode_desc = mode_descriptions.get(anti_crawler_mode, "基础模式")
logger.info(f"🛡️ 防爬虫中间件已配置: {mode_desc}")
except Exception as e:
logger.error(f"❌ 配置防爬虫中间件失败: {e}", exc_info=True)
def _setup_robots_txt(self):
"""设置robots.txt路由"""
try:
from src.webui.anti_crawler import create_robots_txt_response
@self.app.get("/robots.txt", include_in_schema=False)
async def robots_txt():
"""返回robots.txt禁止所有爬虫"""
return create_robots_txt_response()
logger.debug("✅ robots.txt 路由已注册")
except Exception as e:
logger.error(f"❌ 注册robots.txt路由失败: {e}", exc_info=True)
def _register_api_routes(self):
"""注册所有 WebUI API 路由"""
try:
@@ -110,8 +165,10 @@ class WebUIServer:
from src.webui.routes import router as webui_router
from src.webui.logs_ws import router as logs_router
from src.webui.knowledge_routes import router as knowledge_router
# 导入本地聊天室路由
from src.webui.chat_routes import router as chat_router
# 注册路由
self.app.include_router(webui_router)
self.app.include_router(logs_router)
@@ -166,6 +223,7 @@ class WebUIServer:
def _check_port_available(self) -> bool:
"""检查端口是否可用"""
import socket
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.settimeout(1)

114
src/webui/ws_auth.py Normal file
View File

@@ -0,0 +1,114 @@
"""WebSocket 认证模块
提供所有 WebSocket 端点统一使用的临时 token 认证机制。
临时 token 有效期 60 秒,且只能使用一次,用于解决 WebSocket 握手时 Cookie 不可用的问题。
"""
from fastapi import APIRouter, Cookie, Header
from typing import Optional
import secrets
import time
from src.common.logger import get_logger
from src.webui.token_manager import get_token_manager
logger = get_logger("webui.ws_auth")
router = APIRouter()
# WebSocket 临时 token 存储 {token: (expire_time, session_token)}
# 临时 token 有效期 60 秒,仅用于 WebSocket 握手
_ws_temp_tokens: dict[str, tuple[float, str]] = {}
_WS_TOKEN_EXPIRE_SECONDS = 60
def _cleanup_expired_ws_tokens():
"""清理过期的临时 token"""
now = time.time()
expired = [t for t, (exp, _) in _ws_temp_tokens.items() if now > exp]
for t in expired:
del _ws_temp_tokens[t]
def generate_ws_token(session_token: str) -> str:
"""生成 WebSocket 临时 token
Args:
session_token: 原始的 session token
Returns:
临时 token 字符串
"""
_cleanup_expired_ws_tokens()
temp_token = secrets.token_urlsafe(32)
_ws_temp_tokens[temp_token] = (time.time() + _WS_TOKEN_EXPIRE_SECONDS, session_token)
logger.debug(f"生成 WS 临时 token: {temp_token[:8]}... 有效期 {_WS_TOKEN_EXPIRE_SECONDS}s")
return temp_token
def verify_ws_token(temp_token: str) -> bool:
"""验证并消费 WebSocket 临时 token一次性使用
Args:
temp_token: 临时 token
Returns:
验证是否通过
"""
_cleanup_expired_ws_tokens()
if temp_token not in _ws_temp_tokens:
logger.warning(f"WS token 不存在: {temp_token[:8]}...")
return False
expire_time, session_token = _ws_temp_tokens[temp_token]
if time.time() > expire_time:
del _ws_temp_tokens[temp_token]
logger.warning(f"WS token 已过期: {temp_token[:8]}...")
return False
# 验证原始 session token 仍然有效
token_manager = get_token_manager()
if not token_manager.verify_token(session_token):
del _ws_temp_tokens[temp_token]
logger.warning(f"WS token 关联的 session 已失效: {temp_token[:8]}...")
return False
# 消费 token一次性使用
del _ws_temp_tokens[temp_token]
logger.debug(f"WS token 验证成功: {temp_token[:8]}...")
return True
@router.get("/ws-token")
async def get_ws_token(
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
"""
获取 WebSocket 连接用的临时 token
此端点验证当前会话的 Cookie 或 Authorization header
然后返回一个临时 token 用于 WebSocket 握手认证。
临时 token 有效期 60 秒,且只能使用一次。
注意:在未认证时返回 200 状态码但 success=False避免前端因 401 刷新页面。
"""
# 获取当前 session token
session_token = None
if maibot_session:
session_token = maibot_session
elif authorization and authorization.startswith("Bearer "):
session_token = authorization.replace("Bearer ", "")
if not session_token:
# 返回 200 但 success=False避免前端因 401 刷新页面
# 这在登录页面是正常情况,不应该触发错误处理
logger.debug("ws-token 请求:未提供认证信息(可能在登录页面)")
return {"success": False, "message": "未提供认证信息,请先登录", "token": None, "expires_in": 0}
# 验证 session token
token_manager = get_token_manager()
if not token_manager.verify_token(session_token):
# 同样返回 200 但 success=False避免前端刷新
logger.debug("ws-token 请求:认证已过期")
return {"success": False, "message": "认证已过期,请重新登录", "token": None, "expires_in": 0}
# 生成临时 WebSocket token
ws_token = generate_ws_token(session_token)
return {"success": True, "token": ws_token, "expires_in": _WS_TOKEN_EXPIRE_SECONDS}