Merge branch 'dev' into r-dev

This commit is contained in:
UnCLAS-Prommer
2026-01-15 17:21:56 +08:00
234 changed files with 60773 additions and 1506 deletions

View File

@@ -1,7 +1,7 @@
import asyncio
import json
import time
from typing import List, Union
from typing import List, Union, Dict, Any
from .global_logger import logger
from . import prompt_template
@@ -173,3 +173,50 @@ def info_extract_from_str(
return None, None
return entity_extract_result, rdf_triple_extract_result
class IEProcess:
    """Information-extraction processor providing a convenient batch interface."""

    def __init__(self, llm_ner: LLMRequest, llm_rdf: LLMRequest = None):
        # Fall back to the NER model when no dedicated RDF model is supplied.
        self.llm_ner = llm_ner
        self.llm_rdf = llm_rdf or llm_ner

    async def process_paragraphs(self, paragraphs: List[str]) -> List[dict]:
        """Asynchronously process multiple paragraphs.

        Returns one dict per successfully processed paragraph with keys
        ``idx`` (sha256 of the text), ``passage``, ``extracted_entities``
        and ``extracted_triples``; failed paragraphs are skipped.
        """
        from .utils.hash import get_sha256

        total = len(paragraphs)
        collected: List[dict] = []
        for seq, passage in enumerate(paragraphs, start=1):
            # Progress log so the user can see the process is not stuck.
            logger.info(f"[IEProcess] 正在处理第 {seq}/{total} 段文本 (长度: {len(passage)})...")
            # Wrap the synchronous blocking call in asyncio.to_thread to avoid
            # deadlocks: info_extract_from_str's internal asyncio.run then gets
            # a fresh event loop in a separate thread.
            try:
                entities, triples = await asyncio.to_thread(
                    info_extract_from_str, self.llm_ner, self.llm_rdf, passage
                )
            except Exception as e:
                logger.error(f"[IEProcess] 处理第 {seq}/{total} 段时发生异常: {e}")
                continue
            if entities is None:
                logger.warning(f"[IEProcess] 第 {seq}/{total} 段提取失败(返回为空)")
                continue
            collected.append(
                {
                    "idx": get_sha256(passage),
                    "passage": passage,
                    "extracted_entities": entities,
                    "extracted_triples": triples,
                }
            )
            logger.info(f"[IEProcess] 第 {seq}/{total} 段处理完成,提取到 {len(entities)} 个实体")
        return collected

View File

@@ -0,0 +1,388 @@
import asyncio
import os
from functools import partial
from typing import List, Callable, Any
from src.chat.knowledge.embedding_store import EmbeddingManager
from src.chat.knowledge.kg_manager import KGManager
from src.chat.knowledge.qa_manager import QAManager
from src.common.logger import get_logger
from src.config.config import global_config
from src.chat.knowledge import get_qa_manager, lpmm_start_up
logger = get_logger("LPMM-Plugin-API")
class LPMMOperations:
    """
    Internal operations facade for the LPMM knowledge base.

    Wraps LPMM's core operations (import, search, delete, clear) for use by
    the plugin-system API or other internal components.
    """

    def __init__(self):
        # Reserved flag; managers are resolved lazily via _get_managers().
        self._initialized = False

    async def _run_cancellable_executor(
        self, func: Callable, *args, **kwargs
    ) -> Any:
        """
        Run a blocking callable in the default thread pool, cancellably.

        When the awaiting task is cancelled (e.g. Ctrl+C) the coroutine raises
        CancelledError immediately. The thread-pool work may keep running in
        the background, but the event loop is never blocked.

        Args:
            func: synchronous callable to execute
            *args: positional arguments for the callable
            **kwargs: keyword arguments for the callable

        Returns:
            The callable's return value.

        Raises:
            asyncio.CancelledError: if the awaiting task is cancelled.
        """
        # get_running_loop() is the correct call inside a coroutine;
        # get_event_loop() is deprecated here.
        loop = asyncio.get_running_loop()
        # BUG FIX: run_in_executor() does not forward keyword arguments to the
        # callable (passing **kwargs raised TypeError), so bind everything
        # with functools.partial first.
        return await loop.run_in_executor(None, partial(func, *args, **kwargs))

    async def _get_managers(self) -> tuple[EmbeddingManager, KGManager, QAManager]:
        """Return (embed_manager, kg_manager, qa_manager), initializing LPMM if needed."""
        qa_mgr = get_qa_manager()
        if qa_mgr is None:
            # Not initialized globally yet — try to start LPMM up ourselves.
            if not global_config.lpmm_knowledge.enable:
                logger.warning("LPMM 知识库在全局配置中未启用,操作可能受限。")
            lpmm_start_up()
            qa_mgr = get_qa_manager()
        if qa_mgr is None:
            raise RuntimeError("无法获取 LPMM QAManager请检查 LPMM 是否已正确安装和配置。")
        return qa_mgr.embed_manager, qa_mgr.kg_manager, qa_mgr

    async def add_content(self, text: str, auto_split: bool = True) -> dict:
        """
        Add new content to the knowledge base.

        Args:
            text: raw text.
            auto_split: whether to split paragraphs on blank lines.
                - True (default): split on double newlines; supports multiple paragraphs
                - False: no splitting; the whole text is treated as one paragraph

        Returns:
            dict: {"status": "success/error", "count": imported paragraph count,
                   "message": description}
        """
        try:
            embed_mgr, kg_mgr, _ = await self._get_managers()
            # 1. Paragraph splitting.
            if auto_split:
                paragraphs = [p.strip() for p in text.split('\n\n') if p.strip()]
            else:
                text_stripped = text.strip()
                if not text_stripped:
                    return {"status": "error", "message": "文本内容为空"}
                paragraphs = [text_stripped]
            if not paragraphs:
                return {"status": "error", "message": "文本内容为空"}
            # 2. Entity / triple extraction (calls the LLM internally).
            from src.chat.knowledge.ie_process import IEProcess
            from src.llm_models.utils_model import LLMRequest
            from src.config.config import model_config
            llm_ner = LLMRequest(model_set=model_config.model_task_config.lpmm_entity_extract, request_type="lpmm.entity_extract")
            llm_rdf = LLMRequest(model_set=model_config.model_task_config.lpmm_rdf_build, request_type="lpmm.rdf_build")
            ie_process = IEProcess(llm_ner, llm_rdf)
            logger.info(f"[Plugin API] 正在对 {len(paragraphs)} 段文本执行信息抽取...")
            extracted_docs = await ie_process.process_paragraphs(paragraphs)
            # 3. Build and import the data (mirrors import_openie.py's core
            # logic, without depending on the external script).
            # a. Paragraph texts keyed by hash.
            raw_paragraphs = {doc["idx"]: doc["passage"] for doc in extracted_docs}
            # b. Triples keyed by the same hash.
            triple_list_data = {doc["idx"]: doc["extracted_triples"] for doc in extracted_docs}
            # Deduplicate: process only paragraphs not already stored.
            # store_new_data_set expects raw_paragraphs keyed by the bare
            # paragraph hash (no prefix) mapped to the paragraph text.
            new_raw_paragraphs = {}
            new_triple_list_data = {}
            for pg_hash, passage in raw_paragraphs.items():
                key = f"paragraph-{pg_hash}"
                if key not in embed_mgr.stored_pg_hashes:
                    new_raw_paragraphs[pg_hash] = passage
                    new_triple_list_data[pg_hash] = triple_list_data[pg_hash]
            if not new_raw_paragraphs:
                return {"status": "success", "count": 0, "message": "内容已存在,无需重复导入"}
            # store_new_data_set generates and stores embeddings for
            # paragraphs, entities and relations; run it off-loop so the
            # event loop stays responsive.
            await self._run_cancellable_executor(
                embed_mgr.store_new_data_set,
                new_raw_paragraphs,
                new_triple_list_data
            )
            # Build the knowledge graph (needs only triples + embedding manager).
            await self._run_cancellable_executor(
                kg_mgr.build_kg,
                new_triple_list_data,
                embed_mgr
            )
            # 4. Persist everything.
            await self._run_cancellable_executor(embed_mgr.rebuild_faiss_index)
            await self._run_cancellable_executor(embed_mgr.save_to_file)
            await self._run_cancellable_executor(kg_mgr.save_to_file)
            return {"status": "success", "count": len(new_raw_paragraphs), "message": f"成功导入 {len(new_raw_paragraphs)} 条知识"}
        except asyncio.CancelledError:
            logger.warning("[Plugin API] 导入操作被用户中断")
            return {"status": "cancelled", "message": "导入操作已被用户中断"}
        except Exception as e:
            logger.error(f"[Plugin API] 导入知识失败: {e}", exc_info=True)
            return {"status": "error", "message": str(e)}

    async def search(self, query: str, top_k: int = 3) -> List[str]:
        """
        Query the knowledge base.

        Args:
            query: the question to look up.
            top_k: number of most relevant entries to return.

        Returns:
            List[str]: relevant passages. QAManager returns one pre-joined
            string, so the list currently holds at most one element.
        """
        try:
            _, _, qa_mgr = await self._get_managers()
            # Delegate directly to QAManager's retrieval interface.
            knowledge = qa_mgr.get_knowledge(query, top_k=top_k)
            return [knowledge] if knowledge else []
        except Exception as e:
            logger.error(f"[Plugin API] 检索知识失败: {e}")
            return []

    async def delete(self, keyword: str, exact_match: bool = False) -> dict:
        """
        Delete knowledge matching a keyword or a full passage.

        Args:
            keyword: keyword (substring) or full passage to match.
            exact_match: True for whole-passage equality, False for substring match.

        Returns:
            dict: {"status": "success/info", "deleted_count": n, "message": description}
        """
        try:
            embed_mgr, kg_mgr, _ = await self._get_managers()
            # 1. Collect matching paragraphs.
            to_delete_keys = []
            to_delete_hashes = []
            for key, item in embed_mgr.paragraphs_embedding_store.store.items():
                if exact_match:
                    # Whole-passage equality (whitespace-trimmed).
                    matched = item.str.strip() == keyword.strip()
                else:
                    # Substring (fuzzy keyword) match.
                    matched = keyword in item.str
                if matched:
                    to_delete_keys.append(key)
                    to_delete_hashes.append(key.replace("paragraph-", "", 1))
            if not to_delete_keys:
                match_type = "完整文段" if exact_match else "关键词"
                return {"status": "info", "deleted_count": 0, "message": f"未找到匹配的内容({match_type}匹配)"}
            # 2. Perform the deletion; blocking work runs in the thread pool.
            # a. Remove from the vector store.
            deleted_count, _ = await self._run_cancellable_executor(
                embed_mgr.paragraphs_embedding_store.delete_items,
                to_delete_keys
            )
            embed_mgr.stored_pg_hashes = set(embed_mgr.paragraphs_embedding_store.store.keys())
            # b. Remove from the knowledge graph. Keyword arguments are used
            # so True cannot be mistaken for the ent_hashes parameter;
            # _run_cancellable_executor forwards **kwargs itself now.
            await self._run_cancellable_executor(
                kg_mgr.delete_paragraphs,
                to_delete_hashes,
                ent_hashes=None,
                remove_orphan_entities=True,
            )
            # 3. Persist.
            await self._run_cancellable_executor(embed_mgr.rebuild_faiss_index)
            await self._run_cancellable_executor(embed_mgr.save_to_file)
            await self._run_cancellable_executor(kg_mgr.save_to_file)
            match_type = "完整文段" if exact_match else "关键词"
            return {"status": "success", "deleted_count": deleted_count, "message": f"已成功删除 {deleted_count} 条相关知识({match_type}匹配)"}
        except asyncio.CancelledError:
            logger.warning("[Plugin API] 删除操作被用户中断")
            return {"status": "cancelled", "message": "删除操作已被用户中断"}
        except Exception as e:
            logger.error(f"[Plugin API] 删除知识失败: {e}", exc_info=True)
            return {"status": "error", "message": str(e)}

    async def clear_all(self) -> dict:
        """
        Wipe the whole LPMM knowledge base: paragraphs, entities, relations
        and the knowledge graph.

        Returns:
            dict: {"status": "success/error", "message": description, "stats": {...}}
        """
        try:
            embed_mgr, kg_mgr, _ = await self._get_managers()
            # Snapshot statistics before clearing, for the report.
            before_stats = {
                "paragraphs": len(embed_mgr.paragraphs_embedding_store.store),
                "entities": len(embed_mgr.entities_embedding_store.store),
                "relations": len(embed_mgr.relation_embedding_store.store),
                "kg_nodes": len(kg_mgr.graph.get_node_list()),
                "kg_edges": len(kg_mgr.graph.get_edge_list()),
            }
            # 1. Empty every vector store; blocking work runs in the thread pool.
            para_keys = list(embed_mgr.paragraphs_embedding_store.store.keys())
            ent_keys = list(embed_mgr.entities_embedding_store.store.keys())
            rel_keys = list(embed_mgr.relation_embedding_store.store.keys())
            # Delete all paragraph vectors.
            para_deleted, _ = await self._run_cancellable_executor(
                embed_mgr.paragraphs_embedding_store.delete_items,
                para_keys
            )
            embed_mgr.stored_pg_hashes.clear()
            # Delete all entity vectors.
            if ent_keys:
                ent_deleted, _ = await self._run_cancellable_executor(
                    embed_mgr.entities_embedding_store.delete_items,
                    ent_keys
                )
            else:
                ent_deleted = 0
            # Delete all relation vectors.
            if rel_keys:
                rel_deleted, _ = await self._run_cancellable_executor(
                    embed_mgr.relation_embedding_store.delete_items,
                    rel_keys
                )
            else:
                rel_deleted = 0

            # 2. Reset each embedding store's FAISS index and idx2hash mapping
            # and remove the stale index files on disk.
            def _clear_embedding_indices():
                for store in (
                    embed_mgr.paragraphs_embedding_store,
                    embed_mgr.entities_embedding_store,
                    embed_mgr.relation_embedding_store,
                ):
                    store.faiss_index = None
                    store.idx2hash = None
                    store.dirty = False
                    # Delete stale index files if present.
                    if os.path.exists(store.index_file_path):
                        os.remove(store.index_file_path)
                    if os.path.exists(store.idx2hash_file_path):
                        os.remove(store.idx2hash_file_path)

            await self._run_cancellable_executor(_clear_embedding_indices)
            # 3. Empty the knowledge graph.
            all_pg_hashes = list(kg_mgr.stored_paragraph_hashes)
            if all_pg_hashes:
                # Deleting all paragraph nodes also cleans related edges and
                # orphan entities. Keyword arguments prevent True from being
                # mistaken for the ent_hashes parameter.
                await self._run_cancellable_executor(
                    kg_mgr.delete_paragraphs,
                    all_pg_hashes,
                    ent_hashes=None,
                    remove_orphan_entities=True,
                )
            # Replace the graph with a fresh empty one regardless of whether
            # any paragraph hashes were present.
            from quick_algo import di_graph
            kg_mgr.graph = di_graph.DiGraph()
            kg_mgr.stored_paragraph_hashes.clear()
            kg_mgr.ent_appear_cnt.clear()
            # 4. Persist: stores are empty and indices are None; save_to_file
            # writing empty DataFrames is the intended outcome.
            await self._run_cancellable_executor(embed_mgr.save_to_file)
            await self._run_cancellable_executor(kg_mgr.save_to_file)
            after_stats = {
                "paragraphs": len(embed_mgr.paragraphs_embedding_store.store),
                "entities": len(embed_mgr.entities_embedding_store.store),
                "relations": len(embed_mgr.relation_embedding_store.store),
                "kg_nodes": len(kg_mgr.graph.get_node_list()),
                "kg_edges": len(kg_mgr.graph.get_edge_list()),
            }
            return {
                "status": "success",
                "message": f"已成功清空LPMM知识库删除 {para_deleted} 个段落、{ent_deleted} 个实体、{rel_deleted} 个关系)",
                "stats": {
                    "before": before_stats,
                    "after": after_stats,
                }
            }
        except asyncio.CancelledError:
            logger.warning("[Plugin API] 清空操作被用户中断")
            return {"status": "cancelled", "message": "清空操作已被用户中断"}
        except Exception as e:
            logger.error(f"[Plugin API] 清空知识库失败: {e}", exc_info=True)
            return {"status": "error", "message": str(e)}


# Singleton for internal use.
lpmm_ops = LPMMOperations()

View File

@@ -4,6 +4,7 @@ import traceback
import random
import re
from typing import Dict, Optional, Tuple, List, TYPE_CHECKING, Union
from collections import OrderedDict
from rich.traceback import install
from datetime import datetime
from json_repair import repair_json
@@ -110,6 +111,10 @@ class ActionPlanner:
self.last_obs_time_mark = 0.0
self.plan_log: List[Tuple[str, float, Union[List[ActionPlannerInfo], str]]] = []
# 黑话缓存:使用 OrderedDict 实现 LRU最多缓存10个
self.unknown_words_cache: OrderedDict[str, None] = OrderedDict()
self.unknown_words_cache_limit = 10
def find_message_by_id(
self, message_id: str, message_id_list: List[Tuple[str, "DatabaseMessages"]]
@@ -299,6 +304,136 @@ class ActionPlanner:
logger.warning(f"{self.log_prefix}检测消息发送者失败,缺少必要字段")
return False
def _update_unknown_words_cache(self, new_words: List[str]) -> None:
    """Insert newly extracted slang terms into the LRU cache.

    Existing entries are refreshed (moved to the most-recent end); once the
    cache grows past its limit, the oldest entry is evicted.

    Args:
        new_words: list of newly extracted slang terms
    """
    for raw in new_words:
        if not isinstance(raw, str):
            continue
        term = raw.strip()
        if not term:
            continue
        if term in self.unknown_words_cache:
            # Known term: just refresh its recency (LRU behaviour).
            self.unknown_words_cache.move_to_end(term)
            continue
        self.unknown_words_cache[term] = None
        # Evict the least-recently-used entry once over capacity.
        if len(self.unknown_words_cache) > self.unknown_words_cache_limit:
            self.unknown_words_cache.popitem(last=False)
            logger.debug(f"{self.log_prefix}黑话缓存已满,移除最老的黑话")
def _merge_unknown_words_with_cache(self, new_words: Optional[List[str]]) -> List[str]:
    """Combine freshly extracted slang with the cached slang, deduplicated.

    Args:
        new_words: newly extracted slang terms (may be None)

    Returns:
        Deduplicated list with fresh terms first, then cached ones.
    """
    # Normalise the fresh terms: keep only non-empty strings, stripped.
    fresh: List[str] = [
        w.strip() for w in (new_words or []) if isinstance(w, str) and w.strip()
    ]
    # Fresh terms take precedence; cached terms follow, skipping duplicates
    # while preserving order.
    merged: List[str] = []
    seen = set()
    for word in fresh + list(self.unknown_words_cache.keys()):
        if word not in seen:
            seen.add(word)
            merged.append(word)
    return merged
def _process_unknown_words_cache(
    self, actions: List[ActionPlannerInfo]
) -> None:
    """
    Apply the slang ("unknown words") cache policy to a planned action list:
    1. check whether any reply action extracted unknown_words
    2. if none did, evict the oldest cache entry
    3. if the cache holds more than 5 entries, evict the oldest 2 first
    4. for each reply action, merge cached and newly extracted slang
    5. feed the fresh extractions back into the cache

    Args:
        actions: parsed action list (mutated in place via each
            action's action_data dict)
    """
    # Cache holding more than 5 entries: evict the 2 oldest first.
    if len(self.unknown_words_cache) > 5:
        removed_count = 0
        for _ in range(2):
            if len(self.unknown_words_cache) > 0:
                self.unknown_words_cache.popitem(last=False)
                removed_count += 1
        if removed_count > 0:
            logger.debug(f"{self.log_prefix}缓存数量大于5移除最老的{removed_count}个缓存")
    # Did any reply action extract a non-empty unknown_words list?
    has_extracted_unknown_words = False
    for action in actions:
        if action.action_type == "reply":
            action_data = action.action_data or {}
            unknown_words = action_data.get("unknown_words")
            if unknown_words and isinstance(unknown_words, list) and len(unknown_words) > 0:
                has_extracted_unknown_words = True
                break
    # No reply extracted anything this plan: age the cache by one entry.
    if not has_extracted_unknown_words:
        if len(self.unknown_words_cache) > 0:
            self.unknown_words_cache.popitem(last=False)
            logger.debug(f"{self.log_prefix}当前 plan 的 reply 没有提取黑话移除最老的1个缓存")
    # For each reply action, merge the cached slang with the fresh extraction.
    for action in actions:
        if action.action_type == "reply":
            action_data = action.action_data or {}
            new_words = action_data.get("unknown_words")
            # Merge fresh and cached slang lists (deduplicated, fresh first).
            merged_words = self._merge_unknown_words_with_cache(new_words)
            # Write the merged result back into action_data.
            if merged_words:
                action_data["unknown_words"] = merged_words
                logger.debug(
                    f"{self.log_prefix}合并黑话:新提取 {len(new_words) if new_words else 0} 个,"
                    f"缓存 {len(self.unknown_words_cache)} 个,合并后 {len(merged_words)}"
                )
            else:
                # Nothing left after merging: drop the unknown_words field.
                action_data.pop("unknown_words", None)
            # Feed the fresh extraction back into the LRU cache.
            if new_words:
                self._update_unknown_words_cache(new_words)
async def plan(
self,
available_actions: Dict[str, ActionInfo],
@@ -722,6 +857,9 @@ class ActionPlanner:
random.shuffle(shuffled)
actions = list({a.action_type: a for a in shuffled}.values())
# 处理黑话缓存逻辑
self._process_unknown_words_cache(actions)
logger.debug(f"{self.log_prefix}规划器选择了{len(actions)}个动作: {' '.join([a.action_type for a in actions])}")
return extracted_reasoning, actions, llm_content, llm_reasoning, llm_duration_ms

View File

@@ -368,7 +368,7 @@ class ChatHistory(BaseModel):
theme = TextField() # 主题:这段对话的主要内容,一个简短的标题
keywords = TextField() # 关键词这段对话的关键词JSON格式存储
summary = TextField() # 概括:对这段话的平文本概括
key_point = TextField(null=True) # 关键信息话题中的关键信息点JSON格式存储
# key_point = TextField(null=True) # 关键信息话题中的关键信息点JSON格式存储
count = IntegerField(default=0) # 被检索次数
forget_times = IntegerField(default=0) # 被遗忘检查的次数

View File

@@ -192,7 +192,6 @@ def init_dream_tools(chat_id: str) -> None:
("theme", ToolParamType.STRING, "新的主题标题,如果不需要修改可不填。", False, None),
("summary", ToolParamType.STRING, "新的概括内容,如果不需要修改可不填。", False, None),
("keywords", ToolParamType.STRING, "新的关键词 JSON 字符串,如 ['关键词1','关键词2']。", False, None),
("key_point", ToolParamType.STRING, "新的关键信息 JSON 字符串,如 ['要点1','要点2']。", False, None),
],
update_chat_history,
)
@@ -201,7 +200,7 @@ def init_dream_tools(chat_id: str) -> None:
_dream_tool_registry.register_tool(
DreamTool(
"create_chat_history",
"根据整理后的理解创建一条新的 ChatHistory 概括记录(主题、概括、关键词、关键信息等)。",
"根据整理后的理解创建一条新的 ChatHistory 概括记录(主题、概括、关键词等)。",
[
("theme", ToolParamType.STRING, "新的主题标题(必填)。", True, None),
("summary", ToolParamType.STRING, "新的概括内容(必填)。", True, None),
@@ -212,10 +211,11 @@ def init_dream_tools(chat_id: str) -> None:
True,
None,
),
("original_text", ToolParamType.STRING, "对话原文内容(必填)。", True, None),
(
"key_point",
"participants",
ToolParamType.STRING,
"新的关键信息 JSON 字符串,如 ['要点1','要点2'](必填)。",
"参与人的 JSON 字符串,如 ['用户1','用户2'](必填)。",
True,
None,
),
@@ -313,8 +313,7 @@ async def run_dream_agent_once(
f"主题={record.theme or ''}\n"
f"关键词={record.keywords or ''}\n"
f"参与者={record.participants or ''}\n"
f"概括={record.summary or ''}\n"
f"关键信息={record.key_point or ''}"
f"概括={record.summary or ''}"
)
logger.debug(

View File

@@ -11,7 +11,8 @@ def make_create_chat_history(chat_id: str):
theme: str,
summary: str,
keywords: str,
key_point: str,
original_text: str,
participants: str,
start_time: float,
end_time: float,
) -> str:
@@ -20,7 +21,8 @@ def make_create_chat_history(chat_id: str):
logger.info(
f"[dream][tool] 调用 create_chat_history("
f"theme={bool(theme)}, summary={bool(summary)}, "
f"keywords={bool(keywords)}, key_point={bool(key_point)}, "
f"keywords={bool(keywords)}, original_text={bool(original_text)}, "
f"participants={bool(participants)}, "
f"start_time={start_time}, end_time={end_time}) (chat_id={chat_id})"
)
@@ -43,7 +45,8 @@ def make_create_chat_history(chat_id: str):
theme=theme,
summary=summary,
keywords=keywords,
key_point=key_point,
original_text=original_text,
participants=participants,
# 对于由 dream 整理产生的新概括,时间范围优先使用工具提供的时间,否则使用当前时间占位
start_time=start_ts,
end_time=end_ts,

View File

@@ -32,8 +32,7 @@ def make_get_chat_history_detail(chat_id: str): # chat_id 目前未直接使用
f"主题={record.theme or ''}\n"
f"关键词={record.keywords or ''}\n"
f"参与者={record.participants or ''}\n"
f"概括={record.summary or ''}\n"
f"关键信息={record.key_point or ''}"
f"概括={record.summary or ''}"
)
logger.debug(f"[dream][tool] get_chat_history_detail 成功,预览: {result[:200].replace(chr(10), ' ')}")
return result

View File

@@ -13,13 +13,12 @@ def make_update_chat_history(chat_id: str): # chat_id 目前未直接使用,
theme: Optional[str] = None,
summary: Optional[str] = None,
keywords: Optional[str] = None,
key_point: Optional[str] = None,
) -> str:
"""按字段更新 chat_history字符串字段要求 JSON 的字段须传入已序列化的字符串)"""
try:
logger.info(
f"[dream][tool] 调用 update_chat_history(memory_id={memory_id}, "
f"theme={bool(theme)}, summary={bool(summary)}, keywords={bool(keywords)}, key_point={bool(key_point)})"
f"theme={bool(theme)}, summary={bool(summary)}, keywords={bool(keywords)})"
)
record = ChatHistory.get_or_none(ChatHistory.id == memory_id)
if not record:
@@ -34,8 +33,6 @@ def make_update_chat_history(chat_id: str): # chat_id 目前未直接使用,
data["summary"] = summary
if keywords is not None:
data["keywords"] = keywords
if key_point is not None:
data["key_point"] = key_point
if not data:
return "未提供任何需要更新的字段。"

View File

@@ -71,16 +71,14 @@ def init_prompt():
1. 关键词提取与话题相关的关键词用列表形式返回3-10个关键词
2. 概括对这段话的平文本概括50-200字要求
- 仔细地转述发生的事件和聊天内容;
- 可以适当摘取聊天记录中的原文;
- 重点突出事件的发展过程和结果;
- 围绕话题这个中心进行概括。
3. 关键信息:提取话题中的关键信息点,用列表形式返回3-8个关键信息点每个关键信息点应该简洁明了。
- 提取话题中的关键信息点,关键信息点应该简洁明了。
请以JSON格式返回格式如下
{{
"keywords": ["关键词1", "关键词2", ...],
"summary": "概括内容",
"key_point": ["关键信息1", "关键信息2", ...]
"summary": "概括内容"
}}
聊天记录:
@@ -815,12 +813,38 @@ class ChatHistorySummarizer:
original_text = "\n".join(item.messages)
logger.info(
f"{self.log_prefix} 开始打包话题[{topic}] | 消息数: {len(item.messages)} | 时间范围: {start_time:.2f} - {end_time:.2f}"
f"{self.log_prefix} 开始将聊天记录构建成记忆:[{topic}] | 消息数: {len(item.messages)} | 时间范围: {start_time:.2f} - {end_time:.2f}"
)
# 使用 LLM 进行总结(基于话题名)
success, keywords, summary, key_point = await self._compress_with_llm(original_text, topic)
if not success:
# 使用 LLM 进行总结(基于话题名),带重试机制
max_retries = 3
attempt = 0
success = False
keywords = []
summary = ""
while attempt < max_retries:
attempt += 1
success, keywords, summary = await self._compress_with_llm(original_text, topic)
if success and keywords and summary:
# 成功获取到有效的 keywords 和 summary
if attempt > 1:
logger.info(
f"{self.log_prefix} 话题[{topic}] LLM 概括在第 {attempt} 次重试后成功"
)
break
if attempt < max_retries:
logger.warning(
f"{self.log_prefix} 话题[{topic}] LLM 概括失败(第 {attempt} 次尝试),准备重试"
)
else:
logger.error(
f"{self.log_prefix} 话题[{topic}] LLM 概括连续 {max_retries} 次失败,放弃存储"
)
if not success or not keywords or not summary:
logger.warning(f"{self.log_prefix} 话题[{topic}] LLM 概括失败,不写入数据库")
return
@@ -834,14 +858,13 @@ class ChatHistorySummarizer:
theme=topic, # 主题直接使用话题名
keywords=keywords,
summary=summary,
key_point=key_point,
)
logger.info(
f"{self.log_prefix} 话题[{topic}] 成功打包并存储 | 消息数: {len(item.messages)} | 参与者数: {len(participants)}"
)
async def _compress_with_llm(self, original_text: str, topic: str) -> tuple[bool, List[str], str, List[str]]:
async def _compress_with_llm(self, original_text: str, topic: str) -> tuple[bool, List[str], str]:
"""
使用LLM压缩聊天内容用于单个话题的最终总结
@@ -850,7 +873,7 @@ class ChatHistorySummarizer:
topic: 话题名称
Returns:
tuple[bool, List[str], str, List[str]]: (是否成功, 关键词列表, 概括, 关键信息列表)
tuple[bool, List[str], str]: (是否成功, 关键词列表, 概括)
"""
prompt = await global_prompt_manager.format_prompt(
"hippo_topic_summary_prompt",
@@ -920,24 +943,24 @@ class ChatHistorySummarizer:
keywords = result.get("keywords", [])
summary = result.get("summary", "")
key_point = result.get("key_point", [])
if not (keywords and summary) and key_point:
logger.warning(f"{self.log_prefix} LLM返回的JSON中缺少字段原文\n{response}")
# 检查必需字段是否为空
if not keywords or not summary:
logger.warning(f"{self.log_prefix} LLM返回的JSON中缺少必需字段原文\n{response}")
# 返回失败,和模型出错一样,让上层进行重试
return False, [], ""
# 确保keywords和key_point是列表
# 确保keywords是列表
if isinstance(keywords, str):
keywords = [keywords]
if isinstance(key_point, str):
key_point = [key_point]
return True, keywords, summary, key_point
return True, keywords, summary
except Exception as e:
logger.error(f"{self.log_prefix} LLM压缩聊天内容时出错: {e}")
logger.error(f"{self.log_prefix} LLM响应: {response if 'response' in locals() else 'N/A'}")
# 返回失败标志和默认值
return False, [], "压缩失败,无法生成概括", []
return False, [], "压缩失败,无法生成概括"
async def _store_to_database(
self,
@@ -948,7 +971,6 @@ class ChatHistorySummarizer:
theme: str,
keywords: List[str],
summary: str,
key_point: Optional[List[str]] = None,
):
"""存储到数据库"""
try:
@@ -968,10 +990,6 @@ class ChatHistorySummarizer:
"count": 0,
}
# 存储 key_point如果存在
if key_point is not None:
data["key_point"] = json.dumps(key_point, ensure_ascii=False)
# 使用db_save存储使用start_time和chat_id作为唯一标识
# 由于可能有多条记录我们使用组合键但peewee不支持所以使用start_time作为唯一标识
# 但为了避免冲突我们使用组合键chat_id + start_time
@@ -991,7 +1009,6 @@ class ChatHistorySummarizer:
await self._import_to_lpmm_knowledge(
theme=theme,
summary=summary,
key_point=key_point,
participants=participants,
original_text=original_text,
)
@@ -1007,7 +1024,6 @@ class ChatHistorySummarizer:
self,
theme: str,
summary: str,
key_point: Optional[List[str]],
participants: List[str],
original_text: str,
):
@@ -1017,7 +1033,6 @@ class ChatHistorySummarizer:
Args:
theme: 话题主题
summary: 概括内容
key_point: 关键信息点列表
participants: 参与者列表
original_text: 原始文本(可能很长,需要截断)
"""
@@ -1025,46 +1040,43 @@ class ChatHistorySummarizer:
from src.chat.knowledge.lpmm_ops import lpmm_ops
# 构造要导入的文本内容
# 格式:主题 + 概括 + 关键信息 + 参与者信息
# 格式:主题 + 概括 + 参与者信息 + 原始内容摘要
# 注意使用单换行符连接确保整个内容作为一段导入不被LPMM分段
content_parts = []
# 1. 话题主题
if theme:
content_parts.append(f"话题:{theme}")
# if theme:
# content_parts.append(f"话题:{theme}")
# 2. 概括内容
if summary:
content_parts.append(f"概括:{summary}")
# 3. 关键信息
if key_point:
key_points_text = "".join(key_point)
content_parts.append(f"关键信息:{key_points_text}")
# 4. 参与者信息
# 3. 参与者信息
if participants:
participants_text = "".join(participants)
content_parts.append(f"参与者:{participants_text}")
# 5. 原始文本摘要如果原始文本太长只取前500字
if original_text:
# 截断原始文本,避免过长
max_original_length = 500
if len(original_text) > max_original_length:
truncated_text = original_text[:max_original_length] + "..."
content_parts.append(f"原始内容摘要:{truncated_text}")
else:
content_parts.append(f"原始内容:{original_text}")
# 4. 原始文本摘要如果原始文本太长只取前500字
# if original_text:
# # 截断原始文本,避免过长
# max_original_length = 500
# if len(original_text) > max_original_length:
# truncated_text = original_text[:max_original_length] + "..."
# content_parts.append(f"原始内容摘要:{truncated_text}")
# else:
# content_parts.append(f"原始内容:{original_text}")
# 将所有部分合并为一个段落(用双换行分隔符合lpmm_ops.add_content的格式要求
content_to_import = "\n\n".join(content_parts)
# 将所有部分合并为一个完整段落(使用单换行符避免被LPMM分段
# LPMM使用 \n\n 作为段落分隔符,所以这里使用 \n 确保不会被分段
content_to_import = "\n".join(content_parts)
if not content_to_import.strip():
logger.warning(f"{self.log_prefix} 聊天历史总结内容为空,跳过导入知识库")
return
# 调用lpmm_ops导入
result = await lpmm_ops.add_content(content_to_import)
result = await lpmm_ops.add_content(text=content_to_import, auto_split=False)
if result["status"] == "success":
logger.info(

View File

@@ -463,18 +463,6 @@ async def get_chat_history_detail(chat_id: str, memory_ids: str) -> str:
if record.summary:
result_parts.append(f"概括:{record.summary}")
# 添加关键信息点
if record.key_point:
try:
key_point_data = (
json.loads(record.key_point) if isinstance(record.key_point, str) else record.key_point
)
if isinstance(key_point_data, list) and key_point_data:
key_point_str = "\n".join([f" - {str(kp)}" for kp in key_point_data])
result_parts.append(f"关键信息点:\n{key_point_str}")
except (json.JSONDecodeError, TypeError, ValueError):
pass
results.append("\n".join(result_parts))
if not results:

161
src/webui/app.py Normal file
View File

@@ -0,0 +1,161 @@
"""FastAPI 应用工厂 - 创建和配置 WebUI 应用实例"""
import mimetypes
from pathlib import Path
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from src.common.logger import get_logger
logger = get_logger("webui.app")
def create_app(
    host: str = "0.0.0.0",
    port: int = 8001,
    enable_static: bool = True,
) -> FastAPI:
    """
    Build and configure the WebUI FastAPI application instance.

    Args:
        host: server host address
        port: server port
        enable_static: whether to serve the built frontend's static files
    """
    application = FastAPI(title="MaiBot WebUI")
    # Middleware first, then API routes; the static catch-all goes last so it
    # cannot shadow the API endpoints.
    _setup_anti_crawler(application)
    _setup_cors(application, port)
    _register_api_routes(application)
    _setup_robots_txt(application)
    if enable_static:
        _setup_static_files(application)
    return application
def _setup_cors(app: FastAPI, port: int):
    """Attach the CORS middleware allowing the dev-server and served origins."""
    # Vite dev server and the fixed preview port.
    dev_origins = [
        "http://localhost:5173",
        "http://127.0.0.1:5173",
        "http://localhost:7999",
        "http://127.0.0.1:7999",
    ]
    # Whatever port this instance is actually serving on.
    self_origins = [f"http://localhost:{port}", f"http://127.0.0.1:{port}"]
    app.add_middleware(
        CORSMiddleware,
        allow_origins=dev_origins + self_origins,
        allow_credentials=True,
        allow_methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"],
        allow_headers=[
            "Content-Type",
            "Authorization",
            "Accept",
            "Origin",
            "X-Requested-With",
        ],
        expose_headers=["Content-Length", "Content-Type"],
    )
    logger.debug("✅ CORS 中间件已配置")
def _setup_anti_crawler(app: FastAPI):
    """Install the anti-crawler middleware in the globally configured mode."""
    try:
        from src.webui.middleware import AntiCrawlerMiddleware
        from src.config.config import global_config

        anti_crawler_mode = global_config.webui.anti_crawler_mode
        app.add_middleware(AntiCrawlerMiddleware, mode=anti_crawler_mode)
        descriptions = {
            "false": "已禁用",
            "strict": "严格模式",
            "loose": "宽松模式",
            "basic": "基础模式",
        }
        # Unknown modes are reported with the "basic mode" label.
        mode_desc = descriptions.get(anti_crawler_mode, "基础模式")
        logger.info(f"🛡️ 防爬虫中间件已配置: {mode_desc}")
    except Exception as e:
        logger.error(f"❌ 配置防爬虫中间件失败: {e}", exc_info=True)
def _setup_robots_txt(app: FastAPI):
    """Register a /robots.txt route serving the crawler-blocking policy."""
    try:
        from src.webui.middleware import create_robots_txt_response

        # Hidden from the OpenAPI schema; crawlers fetch it directly.
        @app.get("/robots.txt", include_in_schema=False)
        async def robots_txt():
            return create_robots_txt_response()

        logger.debug("✅ robots.txt 路由已注册")
    except Exception as e:
        logger.error(f"❌ 注册robots.txt路由失败: {e}", exc_info=True)
def _register_api_routes(app: FastAPI):
    """Mount every router exposed by src.webui.routers onto the app."""
    try:
        from src.webui.routers import get_all_routers

        for api_router in get_all_routers():
            app.include_router(api_router)
        logger.info("✅ WebUI API 路由已注册")
    except Exception as e:
        logger.error(f"❌ 注册 WebUI API 路由失败: {e}", exc_info=True)
def _setup_static_files(app: FastAPI):
    """
    Serve the built SPA frontend from webui/dist with an SPA fallback.

    Registers a catch-all GET route that serves real files found under the
    dist directory and falls back to index.html otherwise, so client-side
    routing works. HTML responses carry a noindex X-Robots-Tag header.
    """
    # Ensure JS/CSS/JSON get correct MIME types even on platforms whose
    # system mime database is sparse.
    mimetypes.init()
    mimetypes.add_type("application/javascript", ".js")
    mimetypes.add_type("application/javascript", ".mjs")
    mimetypes.add_type("text/css", ".css")
    mimetypes.add_type("application/json", ".json")
    base_dir = Path(__file__).parent.parent.parent
    static_path = base_dir / "webui" / "dist"
    if not static_path.exists():
        logger.warning(f"❌ WebUI 静态文件目录不存在: {static_path}")
        logger.warning("💡 请先构建前端: cd webui && npm run build")
        return
    if not (static_path / "index.html").exists():
        logger.warning(f"❌ 未找到 index.html: {static_path / 'index.html'}")
        logger.warning("💡 请确认前端已正确构建")
        return
    static_root = static_path.resolve()

    def _spa_index() -> FileResponse:
        # SPA fallback: serve index.html and forbid search-engine indexing.
        response = FileResponse(static_path / "index.html", media_type="text/html")
        response.headers["X-Robots-Tag"] = "noindex, nofollow, noarchive"
        return response

    @app.get("/{full_path:path}", include_in_schema=False)
    async def serve_spa(full_path: str):
        if not full_path or full_path == "/":
            return _spa_index()
        file_path = (static_path / full_path).resolve()
        # SECURITY: reject paths that escape the dist directory (e.g. via
        # ".." segments) instead of serving arbitrary files from disk.
        if not file_path.is_relative_to(static_root):
            return _spa_index()
        if file_path.is_file():
            media_type = mimetypes.guess_type(str(file_path))[0]
            response = FileResponse(file_path, media_type=media_type)
            if str(file_path).endswith(".html"):
                response.headers["X-Robots-Tag"] = "noindex, nofollow, noarchive"
            return response
        # Unknown path: hand routing over to the SPA.
        return _spa_index()

    logger.info(f"✅ WebUI 静态文件服务已配置: {static_path}")
def show_access_token():
    """Log the current WebUI access token (intended to be called at startup)."""
    try:
        from src.webui.core import get_token_manager

        current_token = get_token_manager().get_token()
        logger.info(f"🔑 WebUI Access Token: {current_token}")
        logger.info("💡 请使用此 Token 登录 WebUI")
    except Exception as e:
        logger.error(f"❌ 获取 Access Token 失败: {e}")

View File

@@ -0,0 +1,30 @@
from .security import TokenManager, get_token_manager
from .rate_limiter import (
RateLimiter,
get_rate_limiter,
check_auth_rate_limit,
check_api_rate_limit,
)
from .auth import (
COOKIE_NAME,
COOKIE_MAX_AGE,
get_current_token,
set_auth_cookie,
clear_auth_cookie,
verify_auth_token_from_cookie_or_header,
)
__all__ = [
"TokenManager",
"get_token_manager",
"RateLimiter",
"get_rate_limiter",
"check_auth_rate_limit",
"check_api_rate_limit",
"COOKIE_NAME",
"COOKIE_MAX_AGE",
"get_current_token",
"set_auth_cookie",
"clear_auth_cookie",
"verify_auth_token_from_cookie_or_header",
]

185
src/webui/core/auth.py Normal file
View File

@@ -0,0 +1,185 @@
"""
WebUI 认证模块
提供统一的认证依赖,支持 Cookie 和 Header 两种方式
"""
from typing import Optional
from fastapi import HTTPException, Cookie, Header, Response, Request
from src.common.logger import get_logger
from src.config.config import global_config
from .security import get_token_manager
logger = get_logger("webui.auth")
# Cookie 配置
COOKIE_NAME = "maibot_session"
COOKIE_MAX_AGE = 7 * 24 * 60 * 60 # 7天
def _is_secure_environment() -> bool:
    """Decide whether auth cookies should carry the Secure (HTTPS-only) flag.

    Returns:
        bool: True when the config enables secure cookies or the WebUI runs
        in production mode; False in the (typically HTTP) dev environment.
    """
    webui_cfg = global_config.webui
    # Explicit opt-in from the configuration wins.
    if webui_cfg.secure_cookie:
        logger.info("配置中启用了 secure_cookie")
        return True
    # Production deployments are assumed to sit behind HTTPS.
    if webui_cfg.mode == "production":
        logger.info("WebUI运行在生产模式启用 secure cookie")
        return True
    logger.debug("WebUI运行在开发模式禁用 secure cookie")
    return False
def get_current_token(
    request: Request,
    maibot_session: Optional[str] = Cookie(None),
    authorization: Optional[str] = Header(None),
) -> str:
    """
    Resolve and verify the caller's token; Cookie first, then Bearer header.

    Args:
        request: FastAPI Request object
        maibot_session: token from the session cookie
        authorization: Authorization header value ("Bearer <token>")

    Returns:
        The verified token.

    Raises:
        HTTPException: 401 when no token is supplied or it fails verification.
    """
    token = None
    # Cookie takes precedence; the Bearer header is kept for older clients.
    if maibot_session:
        token = maibot_session
    elif authorization and authorization.startswith("Bearer "):
        # removeprefix strips only the leading scheme; str.replace would also
        # mangle any literal "Bearer " sequence occurring inside the token.
        token = authorization.removeprefix("Bearer ")
    if not token:
        raise HTTPException(status_code=401, detail="未提供有效的认证信息")
    token_manager = get_token_manager()
    if not token_manager.verify_token(token):
        raise HTTPException(status_code=401, detail="Token 无效或已过期")
    return token
def set_auth_cookie(response: Response, token: str, request: Optional[Request] = None) -> None:
    """
    Set the authentication cookie on *response*.

    Args:
        response: FastAPI Response object
        token: token value to store in the cookie
        request: optional Request, used to detect the protocol actually in use
    """
    # Baseline Secure flag derived from configuration / run mode.
    is_secure = _is_secure_environment()
    # When the request is available, inspect the real protocol being used.
    if request:
        # Prefer X-Forwarded-Proto (set by proxies / load balancers)...
        forwarded_proto = request.headers.get("x-forwarded-proto", "").lower()
        if forwarded_proto:
            is_https = forwarded_proto == "https"
            logger.debug(f"检测到 X-Forwarded-Proto: {forwarded_proto}, is_https={is_https}")
        else:
            # ...otherwise fall back to the request URL's scheme.
            is_https = request.url.scheme == "https"
            logger.debug(f"检测到 scheme: {request.url.scheme}, is_https={is_https}")
        # A Secure cookie is never sent over plain HTTP, which would make
        # login impossible — drop the flag and warn loudly instead.
        if not is_https and is_secure:
            logger.warning("=" * 80)
            logger.warning("检测到 HTTP 连接但环境配置要求 HTTPS (secure cookie)")
            logger.warning("已自动禁用 secure 标志以允许登录,但建议修改配置:")
            logger.warning("1. 在配置文件中设置: webui.secure_cookie = false")
            logger.warning("2. 如果使用反向代理,请确保正确配置 X-Forwarded-Proto 头")
            logger.warning("=" * 80)
            is_secure = False
    # Issue the cookie.
    response.set_cookie(
        key=COOKIE_NAME,
        value=token,
        max_age=COOKIE_MAX_AGE,
        httponly=True,  # not readable from JS — mitigates XSS token theft
        samesite="lax",  # lax works for both development and production flows
        secure=is_secure,  # follows the protocol actually in use
        path="/",  # cookie valid on every path
    )
    logger.info(
        f"已设置认证 Cookie: {token[:8]}... (secure={is_secure}, samesite=lax, httponly=True, path=/, max_age={COOKIE_MAX_AGE})"
    )
    logger.debug(f"完整 token 前缀: {token[:20]}...")
def clear_auth_cookie(response: Response) -> None:
    """
    Clear the authentication cookie.

    The delete must mirror the attributes used by set_auth_cookie (path,
    httponly, samesite): browsers match cookies on their attributes, so a
    mismatched delete can leave the session cookie behind. set_auth_cookie
    always issues samesite="lax", therefore the delete does too (the old
    "strict"-when-secure variant contradicted its own comment).

    Args:
        response: FastAPI Response object
    """
    is_secure = _is_secure_environment()
    response.delete_cookie(
        key=COOKIE_NAME,
        httponly=True,
        samesite="lax",  # must match set_auth_cookie, which always uses "lax"
        secure=is_secure,
        path="/",
    )
    logger.debug("已清除认证 Cookie")
def verify_auth_token_from_cookie_or_header(
    maibot_session: Optional[str] = None,
    authorization: Optional[str] = None,
) -> bool:
    """
    Verify an auth token taken from the session cookie or the Authorization header.

    Args:
        maibot_session: token from the cookie
        authorization: Authorization header value ("Bearer <token>")

    Returns:
        True when verification succeeds.

    Raises:
        HTTPException: 401 when no token is supplied or it is invalid.
    """
    token = None
    # Cookie takes precedence; the Bearer header is kept for older clients.
    if maibot_session:
        token = maibot_session
    elif authorization and authorization.startswith("Bearer "):
        # removeprefix strips only the leading scheme; str.replace would also
        # corrupt any literal "Bearer " sequence inside the token itself.
        token = authorization.removeprefix("Bearer ")
    if not token:
        raise HTTPException(status_code=401, detail="未提供有效的认证信息")
    token_manager = get_token_manager()
    if not token_manager.verify_token(token):
        raise HTTPException(status_code=401, detail="Token 无效或已过期")
    return True

View File

@@ -0,0 +1,245 @@
"""
WebUI 请求频率限制模块
防止暴力破解和 API 滥用
"""
import time
from collections import defaultdict
from typing import Dict, Tuple, Optional
from fastapi import Request, HTTPException
from src.common.logger import get_logger
logger = get_logger("webui.rate_limiter")
class RateLimiter:
    """
    Simple in-memory sliding-window rate limiter.

    Tracks per-key request timestamps plus temporary IP bans. All state lives
    in process memory: limits reset on restart and are not shared between
    worker processes.
    """

    def __init__(self):
        # {key: [(timestamp, count), ...]} — one entry per recorded request.
        self._requests: Dict[str, list] = defaultdict(list)
        # Banned IPs: {ip: unblock_timestamp}
        self._blocked: Dict[str, float] = {}

    def _get_client_ip(self, request: "Request") -> str:
        """Best-effort client IP: X-Forwarded-For, then X-Real-IP, then the socket peer."""
        forwarded = request.headers.get("X-Forwarded-For")
        if forwarded:
            # First entry is the original client.
            return forwarded.split(",")[0].strip()
        real_ip = request.headers.get("X-Real-IP")
        if real_ip:
            return real_ip
        if request.client:
            return request.client.host
        return "unknown"

    def _cleanup_old_requests(self, key: str, window_seconds: int):
        """Drop records older than the window; purge the key entirely when empty.

        Deleting empty keys matters: the backing defaultdict would otherwise
        accumulate one (empty) entry per client forever — unbounded growth.
        """
        if key not in self._requests:  # avoid defaultdict creating the key on read
            return
        cutoff = time.time() - window_seconds
        kept = [(ts, count) for ts, count in self._requests[key] if ts > cutoff]
        if kept:
            self._requests[key] = kept
        else:
            del self._requests[key]

    def _cleanup_expired_blocks(self):
        """Lift bans whose unblock time has passed."""
        now = time.time()
        expired = [ip for ip, unblock_time in self._blocked.items() if now > unblock_time]
        for ip in expired:
            del self._blocked[ip]
            logger.info(f"🔓 IP {ip} 封禁已解除")

    def is_blocked(self, request: "Request") -> Tuple[bool, Optional[int]]:
        """
        Check whether the requesting IP is currently banned.

        Returns:
            (banned?, remaining ban seconds or None)
        """
        self._cleanup_expired_blocks()
        ip = self._get_client_ip(request)
        if ip in self._blocked:
            remaining = int(self._blocked[ip] - time.time())
            return True, max(0, remaining)
        return False, None

    def check_rate_limit(
        self, request: "Request", max_requests: int, window_seconds: int, key_suffix: str = ""
    ) -> Tuple[bool, int]:
        """
        Check whether the request exceeds the sliding-window limit.

        Args:
            request: FastAPI Request object
            max_requests: maximum requests allowed within the window
            window_seconds: window length in seconds
            key_suffix: suffix to separate different limit rules per IP

        Returns:
            (allowed?, remaining request budget)
        """
        ip = self._get_client_ip(request)
        key = f"{ip}:{key_suffix}" if key_suffix else ip
        # Expire old records (and the key itself when nothing remains).
        self._cleanup_old_requests(key, window_seconds)
        current_count = sum(count for _, count in self._requests[key])
        if current_count >= max_requests:
            return False, 0
        # Record this request.
        self._requests[key].append((time.time(), 1))
        return True, max_requests - current_count - 1

    def block_ip(self, request: "Request", duration_seconds: int):
        """
        Ban the requesting IP.

        Args:
            request: FastAPI Request object
            duration_seconds: ban length in seconds
        """
        ip = self._get_client_ip(request)
        self._blocked[ip] = time.time() + duration_seconds
        logger.warning(f"🔒 IP {ip} 已被封禁 {duration_seconds} 秒")

    def record_failed_attempt(
        self, request: "Request", max_failures: int = 5, window_seconds: int = 300, block_duration: int = 600
    ) -> Tuple[bool, int]:
        """
        Record a failed attempt (e.g. a failed login); auto-ban on too many.

        Args:
            request: FastAPI Request object
            max_failures: failures allowed within the window
            window_seconds: statistics window in seconds
            block_duration: ban length in seconds once the limit is hit

        Returns:
            (banned?, remaining attempts)
        """
        ip = self._get_client_ip(request)
        key = f"{ip}:auth_failures"
        self._cleanup_old_requests(key, window_seconds)
        current_failures = sum(count for _, count in self._requests[key])
        # Record this failure.
        self._requests[key].append((time.time(), 1))
        current_failures += 1
        remaining = max_failures - current_failures
        if current_failures >= max_failures:
            self.block_ip(request, block_duration)
            logger.warning(f"⚠️ IP {ip} 认证失败次数过多 ({current_failures}/{max_failures}),已封禁")
            return True, 0
        if current_failures >= max_failures - 2:
            logger.warning(f"⚠️ IP {ip} 认证失败 {current_failures}/{max_failures} 次")
        return False, max(0, remaining)

    def reset_failures(self, request: "Request"):
        """Forget the failure history for this IP (call after a successful auth)."""
        ip = self._get_client_ip(request)
        key = f"{ip}:auth_failures"
        if key in self._requests:
            del self._requests[key]
# Module-level singleton instance
_rate_limiter: Optional[RateLimiter] = None


def get_rate_limiter() -> RateLimiter:
    """Return the process-wide RateLimiter, creating it lazily on first use."""
    global _rate_limiter
    if _rate_limiter is not None:
        return _rate_limiter
    _rate_limiter = RateLimiter()
    return _rate_limiter
async def check_auth_rate_limit(request: Request):
    """
    Rate-limit dependency for authentication endpoints.

    Rules: at most 10 auth requests per IP per minute; IPs banned by the
    limiter (after repeated failures) are rejected outright.

    Raises:
        HTTPException: 429 when banned or over the limit.
    """
    limiter = get_rate_limiter()
    # Banned IPs are refused before any counting happens.
    blocked, remaining_block = limiter.is_blocked(request)
    if blocked:
        raise HTTPException(
            status_code=429,
            detail=f"请求过于频繁,请在 {remaining_block} 秒后重试",
            headers={"Retry-After": str(remaining_block)},
        )
    # Sliding-window budget: 10 requests per 60 seconds.
    allowed, _ = limiter.check_rate_limit(request, max_requests=10, window_seconds=60, key_suffix="auth")
    if not allowed:
        raise HTTPException(status_code=429, detail="认证请求过于频繁,请稍后重试", headers={"Retry-After": "60"})
async def check_api_rate_limit(request: Request):
    """
    Rate-limit dependency for regular API endpoints.

    Rule: at most 100 requests per IP per minute; banned IPs are rejected.

    Raises:
        HTTPException: 429 when banned or over the limit.
    """
    limiter = get_rate_limiter()
    # Banned IPs are refused before any counting happens.
    blocked, remaining_block = limiter.is_blocked(request)
    if blocked:
        raise HTTPException(
            status_code=429,
            detail=f"请求过于频繁,请在 {remaining_block} 秒后重试",
            headers={"Retry-After": str(remaining_block)},
        )
    # Sliding-window budget: 100 requests per 60 seconds.
    allowed, _ = limiter.check_rate_limit(request, max_requests=100, window_seconds=60, key_suffix="api")
    if not allowed:
        raise HTTPException(status_code=429, detail="请求过于频繁,请稍后重试", headers={"Retry-After": "60"})

309
src/webui/core/security.py Normal file
View File

@@ -0,0 +1,309 @@
"""
WebUI Token 管理模块
负责生成、保存、验证和更新访问令牌
"""
import json
import secrets
from pathlib import Path
from typing import Optional
from src.common.logger import get_logger
logger = get_logger("webui")
class TokenManager:
    """Manage the WebUI access token: creation, persistence, verification, rotation."""

    def __init__(self, config_path: Optional[Path] = None):
        """
        Initialize the token manager.

        Args:
            config_path: config file path; defaults to <project root>/data/webui.json
        """
        if config_path is None:
            # Project root: src/webui/core -> src/webui -> src -> root
            project_root = Path(__file__).parent.parent.parent.parent
            config_path = project_root / "data" / "webui.json"
        self.config_path = config_path
        self.config_path.parent.mkdir(parents=True, exist_ok=True)
        # Guarantee a config file with a usable token exists from here on.
        self._ensure_config()

    def _ensure_config(self):
        """Ensure the config file exists and contains a non-empty access_token."""
        if not self.config_path.exists():
            logger.info(f"WebUI 配置文件不存在,正在创建: {self.config_path}")
            self._create_new_token()
        else:
            try:
                config = self._load_config()
                if not config.get("access_token"):
                    logger.warning("WebUI 配置文件中缺少 access_token正在重新生成")
                    self._create_new_token()
                else:
                    logger.info(f"WebUI Token 已加载: {config['access_token'][:8]}...")
            except Exception as e:
                logger.error(f"读取 WebUI 配置文件失败: {e},正在重新创建")
                self._create_new_token()

    def _load_config(self) -> dict:
        """Load the JSON config; return an empty dict on any read/parse error."""
        try:
            with open(self.config_path, "r", encoding="utf-8") as f:
                return json.load(f)
        except Exception as e:
            logger.error(f"加载 WebUI 配置失败: {e}")
            return {}

    def _save_config(self, config: dict):
        """
        Persist the config atomically.

        Writes to a sibling temp file and then swaps it in via Path.replace
        (atomic on both POSIX and Windows), so a crash mid-write can no
        longer truncate webui.json and destroy the stored token.

        Raises:
            Exception: re-raised after logging when the write fails.
        """
        try:
            tmp_path = self.config_path.with_name(self.config_path.name + ".tmp")
            with open(tmp_path, "w", encoding="utf-8") as f:
                json.dump(config, f, ensure_ascii=False, indent=2)
            tmp_path.replace(self.config_path)
            logger.info(f"WebUI 配置已保存到: {self.config_path}")
        except Exception as e:
            logger.error(f"保存 WebUI 配置失败: {e}")
            raise

    def _create_new_token(self) -> str:
        """Generate a fresh random token and write a brand-new config file."""
        # 32 random bytes -> 64 hex characters.
        token = secrets.token_hex(32)
        config = {
            "access_token": token,
            "created_at": self._get_current_timestamp(),
            "updated_at": self._get_current_timestamp(),
            "first_setup_completed": False,  # first-run wizard not done yet
        }
        self._save_config(config)
        logger.info(f"新的 WebUI Token 已生成: {token[:8]}...")
        return token

    def _get_current_timestamp(self) -> str:
        """Current local time as an ISO-8601 string."""
        from datetime import datetime

        return datetime.now().isoformat()

    def get_token(self) -> str:
        """Return the currently valid token (empty string when none is stored)."""
        return self._load_config().get("access_token", "")

    def verify_token(self, token: str) -> bool:
        """
        Verify a candidate token against the stored one.

        Args:
            token: token to verify

        Returns:
            bool: True when it matches
        """
        if not token:
            return False
        current_token = self.get_token()
        if not current_token:
            logger.error("系统中没有有效的 token")
            return False
        # Constant-time comparison to prevent timing attacks.
        is_valid = secrets.compare_digest(token, current_token)
        if is_valid:
            logger.debug("Token 验证成功")
        else:
            logger.warning("Token 验证失败")
        return is_valid

    def update_token(self, new_token: str) -> tuple[bool, str]:
        """
        Replace the token with a user-chosen value.

        Args:
            new_token: new token (>= 10 chars, must mix upper/lower case and a special symbol)

        Returns:
            tuple[bool, str]: (success, message)
        """
        is_valid, error_msg = self._validate_custom_token(new_token)
        if not is_valid:
            logger.error(f"Token 格式无效: {error_msg}")
            return False, error_msg
        try:
            config = self._load_config()
            old_token = config.get("access_token", "")[:8]
            config["access_token"] = new_token
            config["updated_at"] = self._get_current_timestamp()
            self._save_config(config)
            logger.info(f"Token 已更新: {old_token}... -> {new_token[:8]}...")
            return True, "Token 更新成功"
        except Exception as e:
            logger.error(f"更新 Token 失败: {e}")
            return False, f"更新失败: {str(e)}"

    def regenerate_token(self) -> str:
        """
        Regenerate the token (the first_setup_completed flag is preserved).

        Returns:
            str: the newly generated token
        """
        logger.info("正在重新生成 WebUI Token...")
        new_token = secrets.token_hex(32)
        # Preserve the setup-completed flag across the rotation.
        config = self._load_config()
        old_token = config.get("access_token", "")[:8] if config.get("access_token") else ""
        first_setup_completed = config.get("first_setup_completed", True)  # default True: setup already done
        config["access_token"] = new_token
        config["updated_at"] = self._get_current_timestamp()
        config["first_setup_completed"] = first_setup_completed
        self._save_config(config)
        logger.info(f"WebUI Token 已重新生成: {old_token}... -> {new_token[:8]}...")
        return new_token

    def _validate_token_format(self, token: str) -> bool:
        """
        Validate the legacy system-generated format: exactly 64 hex characters.

        Kept for tokens produced by _create_new_token / regenerate_token.

        Args:
            token: token to check

        Returns:
            bool: True when the format is valid
        """
        if not token or not isinstance(token, str):
            return False
        if len(token) != 64:
            return False
        try:
            int(token, 16)
            return True
        except ValueError:
            return False

    def _validate_custom_token(self, token: str) -> tuple[bool, str]:
        """
        Validate a user-supplied token.

        Rules: at least 10 characters, containing upper case, lower case and
        at least one special symbol.

        Args:
            token: token to check

        Returns:
            tuple[bool, str]: (valid, message)
        """
        if not token or not isinstance(token, str):
            return False, "Token 不能为空"
        if len(token) < 10:
            return False, "Token 长度至少为 10 位"
        if not any(c.isupper() for c in token):
            return False, "Token 必须包含大写字母"
        if not any(c.islower() for c in token):
            return False, "Token 必须包含小写字母"
        special_chars = "!@#$%^&*()_+-=[]{}|;:,.<>?/"
        if not any(c in special_chars for c in token):
            return False, f"Token 必须包含特殊符号 ({special_chars})"
        return True, "Token 格式正确"

    def is_first_setup(self) -> bool:
        """Return True while the first-run setup wizard has not been completed."""
        config = self._load_config()
        return not config.get("first_setup_completed", False)

    def mark_setup_completed(self) -> bool:
        """Mark the first-run setup as completed; returns True on success."""
        try:
            config = self._load_config()
            config["first_setup_completed"] = True
            config["setup_completed_at"] = self._get_current_timestamp()
            self._save_config(config)
            logger.info("首次配置已标记为完成")
            return True
        except Exception as e:
            logger.error(f"标记首次配置完成失败: {e}")
            return False

    def reset_setup_status(self) -> bool:
        """Reset the first-run flag so the setup wizard can run again; returns True on success."""
        try:
            config = self._load_config()
            config["first_setup_completed"] = False
            if "setup_completed_at" in config:
                del config["setup_completed_at"]
            self._save_config(config)
            logger.info("首次配置状态已重置")
            return True
        except Exception as e:
            logger.error(f"重置首次配置状态失败: {e}")
            return False
# Module-level singleton instance
_token_manager_instance: Optional[TokenManager] = None


def get_token_manager() -> TokenManager:
    """Return the process-wide TokenManager, creating it lazily on first use."""
    global _token_manager_instance
    if _token_manager_instance is not None:
        return _token_manager_instance
    _token_manager_instance = TokenManager()
    return _token_manager_instance

87
src/webui/dependencies.py Normal file
View File

@@ -0,0 +1,87 @@
from typing import Optional
from fastapi import Depends, Cookie, Header, Request, HTTPException
from .core import get_current_token, get_token_manager, check_auth_rate_limit, check_api_rate_limit
async def require_auth(
request: Request,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> str:
"""
FastAPI 依赖:要求有效认证
用于保护需要认证的路由,自动从 Cookie 或 Header 获取并验证 token
Returns:
验证通过的 token
Raises:
HTTPException 401: 认证失败
"""
return get_current_token(request, maibot_session, authorization)
async def require_auth_with_rate_limit(
request: Request,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
_rate_limit: None = Depends(check_auth_rate_limit),
) -> str:
"""
FastAPI 依赖:要求有效认证 + 频率限制
组合了认证检查和频率限制,适用于敏感操作
Returns:
验证通过的 token
Raises:
HTTPException 401: 认证失败
HTTPException 429: 请求过于频繁
"""
return get_current_token(request, maibot_session, authorization)
def get_optional_token(
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> Optional[str]:
"""
FastAPI 依赖:可选获取 token不验证
用于某些需要知道是否有 token 但不强制验证的场景
Returns:
token 字符串或 None
"""
if maibot_session:
return maibot_session
if authorization and authorization.startswith("Bearer "):
return authorization.replace("Bearer ", "")
return None
async def verify_token_optional(
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> bool:
"""
FastAPI 依赖:可选验证 token
返回 token 是否有效,不抛出异常
Returns:
True 如果 token 有效,否则 False
"""
token = None
if maibot_session:
token = maibot_session
elif authorization and authorization.startswith("Bearer "):
token = authorization.replace("Bearer ", "")
if not token:
return False
token_manager = get_token_manager()
return token_manager.verify_token(token)

View File

@@ -0,0 +1,17 @@
from .anti_crawler import (
AntiCrawlerMiddleware,
create_robots_txt_response,
ANTI_CRAWLER_MODE,
ALLOWED_IPS,
TRUSTED_PROXIES,
TRUST_XFF,
)
__all__ = [
"AntiCrawlerMiddleware",
"create_robots_txt_response",
"ANTI_CRAWLER_MODE",
"ALLOWED_IPS",
"TRUSTED_PROXIES",
"TRUST_XFF",
]

View File

@@ -0,0 +1,795 @@
"""
WebUI 防爬虫模块
提供爬虫检测和阻止功能,保护 WebUI 不被搜索引擎和恶意爬虫访问
"""
import time
import ipaddress
import re
from collections import deque
from typing import Optional
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import PlainTextResponse
from src.common.logger import get_logger
logger = get_logger("webui.anti_crawler")
# 常见爬虫 User-Agent 列表(使用更精确的关键词,避免误报)
CRAWLER_USER_AGENTS = {
# 搜索引擎爬虫(精确匹配)
"googlebot",
"bingbot",
"baiduspider",
"yandexbot",
"slurp", # Yahoo
"duckduckbot",
"sogou",
"exabot",
"facebot",
"ia_archiver", # Internet Archive
# 通用爬虫(移除过于宽泛的关键词)
"crawler",
"spider",
"scraper",
"wget", # 保留wget因为通常用于自动化脚本
"scrapy", # 保留scrapy因为这是爬虫框架
# 安全扫描工具(这些是明确的扫描工具)
"masscan",
"nmap",
"nikto",
"sqlmap",
# 注意:移除了以下过于宽泛的关键词以避免误报:
# - "bot" (会误匹配GitHub-Robot等)
# - "curl" (正常工具)
# - "python-requests" (正常库)
# - "httpx" (正常库)
# - "aiohttp" (正常库)
}
# 资产测绘工具 User-Agent 标识
ASSET_SCANNER_USER_AGENTS = {
# 知名资产测绘平台
"shodan",
"censys",
"zoomeye",
"fofa",
"quake",
"hunter",
"binaryedge",
"onyphe",
"securitytrails",
"virustotal",
"passivetotal",
# 安全扫描工具
"acunetix",
"appscan",
"burpsuite",
"nessus",
"openvas",
"qualys",
"rapid7",
"tenable",
"veracode",
"zap",
"awvs", # Acunetix Web Vulnerability Scanner
"netsparker",
"skipfish",
"w3af",
"arachni",
# 其他扫描工具
"masscan",
"zmap",
"nmap",
"whatweb",
"wpscan",
"joomscan",
"dnsenum",
"subfinder",
"amass",
"sublist3r",
"theharvester",
}
# 资产测绘工具常用的HTTP头标识
ASSET_SCANNER_HEADERS = {
# 常见的扫描工具自定义头
"x-scan": {"shodan", "censys", "zoomeye", "fofa"},
"x-scanner": {"nmap", "masscan", "zmap"},
"x-probe": {"masscan", "zmap"},
# 其他可疑头(移除反向代理标准头)
"x-originating-ip": set(),
"x-remote-ip": set(),
"x-remote-addr": set(),
# 注意:移除了以下反向代理标准头以避免误报:
# - "x-forwarded-proto" (反向代理标准头)
# - "x-real-ip" (反向代理标准头已在_get_client_ip中使用)
}
# 仅检查特定HTTP头中的可疑模式收紧匹配范围
# 只检查这些特定头,不检查所有头
SCANNER_SPECIFIC_HEADERS = {
"x-scan",
"x-scanner",
"x-probe",
"x-originating-ip",
"x-remote-ip",
"x-remote-addr",
}
# 防爬虫模式配置
# false: 禁用
# strict: 严格模式(更严格的检测,更低的频率限制)
# loose: 宽松模式(较宽松的检测,较高的频率限制)
# basic: 基础模式只记录恶意访问不阻止不限制请求数不跟踪IP
# IP白名单配置从配置文件读取逗号分隔
# 支持格式:
# - 精确IP127.0.0.1, 192.168.1.100
# - CIDR格式192.168.1.0/24, 172.17.0.0/16 (适用于Docker网络)
# - 通配符192.168.*.*, 10.*.*.*, *.*.*.* (匹配所有)
# - IPv6::1, 2001:db8::/32
def _parse_allowed_ips(ip_string: str) -> list:
"""
解析IP白名单字符串支持精确IP、CIDR格式和通配符
Args:
ip_string: 逗号分隔的IP字符串
Returns:
IP白名单列表每个元素可能是
- ipaddress.IPv4Network/IPv6Network对象CIDR格式
- ipaddress.IPv4Address/IPv6Address对象精确IP
- str通配符模式已转换为正则表达式
"""
allowed = []
if not ip_string:
return allowed
for ip_entry in ip_string.split(","):
ip_entry = ip_entry.strip() # 去除空格
if not ip_entry:
continue
# 跳过注释行(以#开头)
if ip_entry.startswith("#"):
continue
# 检查通配符格式(包含*
if "*" in ip_entry:
# 处理通配符
pattern = _convert_wildcard_to_regex(ip_entry)
if pattern:
allowed.append(pattern)
else:
logger.warning(f"无效的通配符IP格式已忽略: {ip_entry}")
continue
try:
# 尝试解析为CIDR格式包含/
if "/" in ip_entry:
allowed.append(ipaddress.ip_network(ip_entry, strict=False))
else:
# 精确IP地址
allowed.append(ipaddress.ip_address(ip_entry))
except (ValueError, AttributeError) as e:
logger.warning(f"无效的IP白名单条目已忽略: {ip_entry} ({e})")
return allowed
def _convert_wildcard_to_regex(wildcard_pattern: str) -> Optional[str]:
"""
将通配符IP模式转换为正则表达式
支持的格式:
- 192.168.*.* 或 192.168.*
- 10.*.*.* 或 10.*
- *.*.*.* 或 *
Args:
wildcard_pattern: 通配符模式字符串
Returns:
正则表达式字符串如果格式无效则返回None
"""
# 去除空格
pattern = wildcard_pattern.strip()
# 处理单个*(匹配所有)
if pattern == "*":
return r".*"
# 处理IPv4通配符格式
# 支持192.168.*.*, 192.168.*, 10.*.*.*, 10.* 等
parts = pattern.split(".")
if len(parts) > 4:
return None # IPv4最多4段
# 构建正则表达式
regex_parts = []
for part in parts:
part = part.strip()
if part == "*":
regex_parts.append(r"\d+") # 匹配任意数字
elif part.isdigit():
# 验证数字范围0-255
num = int(part)
if 0 <= num <= 255:
regex_parts.append(re.escape(part))
else:
return None # 无效的数字
else:
return None # 无效的格式
# 如果部分少于4段补充.*
while len(regex_parts) < 4:
regex_parts.append(r"\d+")
# 组合成正则表达式
regex = r"^" + r"\.".join(regex_parts) + r"$"
return regex
# 从配置读取防爬虫设置(延迟导入避免循环依赖)
def _get_anti_crawler_config():
    """Read anti-crawler settings from the global config.

    The import is done lazily inside the function to avoid a circular
    import with the config module.
    """
    from src.config.config import global_config

    webui_cfg = global_config.webui
    return {
        'mode': webui_cfg.anti_crawler_mode,
        'allowed_ips': _parse_allowed_ips(webui_cfg.allowed_ips),
        'trusted_proxies': _parse_allowed_ips(webui_cfg.trusted_proxies),
        'trust_xff': webui_cfg.trust_xff,
    }
# Initialize module-level settings at import time (runs once on module load).
_config = _get_anti_crawler_config()
ANTI_CRAWLER_MODE = _config['mode']  # "false" / "strict" / "loose" / "basic"
ALLOWED_IPS = _config['allowed_ips']  # parsed whitelist: addresses, networks, wildcard regexes
TRUSTED_PROXIES = _config['trusted_proxies']  # proxies whose forwarded headers may be trusted
TRUST_XFF = _config['trust_xff']  # master switch for honoring X-Forwarded-For
def _get_mode_config(mode: str) -> dict:
"""
根据模式获取配置参数
Args:
mode: 防爬虫模式 (false/strict/loose/basic)
Returns:
配置字典,包含所有相关参数
"""
mode = mode.lower()
if mode == "false":
return {
"enabled": False,
"rate_limit_window": 60,
"rate_limit_max_requests": 1000, # 禁用时设置很高的值
"max_tracked_ips": 0,
"check_user_agent": False,
"check_asset_scanner": False,
"check_rate_limit": False,
"block_on_detect": False, # 不阻止
}
elif mode == "strict":
return {
"enabled": True,
"rate_limit_window": 60,
"rate_limit_max_requests": 15, # 严格模式:更低的请求数
"max_tracked_ips": 20000,
"check_user_agent": True,
"check_asset_scanner": True,
"check_rate_limit": True,
"block_on_detect": True, # 阻止恶意访问
}
elif mode == "loose":
return {
"enabled": True,
"rate_limit_window": 60,
"rate_limit_max_requests": 60, # 宽松模式:更高的请求数
"max_tracked_ips": 5000,
"check_user_agent": True,
"check_asset_scanner": True,
"check_rate_limit": True,
"block_on_detect": True, # 阻止恶意访问
}
else: # basic (默认模式)
return {
"enabled": True,
"rate_limit_window": 60,
"rate_limit_max_requests": 1000, # 不限制请求数
"max_tracked_ips": 0, # 不跟踪IP
"check_user_agent": True, # 检测但不阻止
"check_asset_scanner": True, # 检测但不阻止
"check_rate_limit": False, # 不限制请求频率
"block_on_detect": False, # 只记录,不阻止
}
class AntiCrawlerMiddleware(BaseHTTPMiddleware):
"""防爬虫中间件"""
def __init__(self, app, mode: str = "standard"):
    """
    Initialize the anti-crawler middleware.

    Args:
        app: FastAPI application instance
        mode: anti-crawler mode (false/strict/loose/standard)

    NOTE(review): the default "standard" is not a named profile in
    _get_mode_config, so it falls through to the "basic" profile — confirm
    this is intentional.
    """
    super().__init__(app)
    self.mode = mode.lower()
    # Resolve mode-specific toggles and limits.
    config = _get_mode_config(self.mode)
    self.enabled = config["enabled"]
    self.rate_limit_window = config["rate_limit_window"]
    self.rate_limit_max_requests = config["rate_limit_max_requests"]
    self.max_tracked_ips = config["max_tracked_ips"]
    self.check_user_agent = config["check_user_agent"]
    self.check_asset_scanner = config["check_asset_scanner"]
    self.check_rate_limit = config["check_rate_limit"]
    self.block_on_detect = config["block_on_detect"]  # whether detections are blocked (vs. logged only)
    # Per-IP request timestamps; deque gives O(1) expiry from the left end.
    self.request_times: dict[str, deque] = {}
    # Timestamp of the last periodic cleanup pass.
    self.last_cleanup = time.time()
    # Keyword sets for fast membership scans.
    self.crawler_keywords_set = set(CRAWLER_USER_AGENTS)
    self.scanner_keywords_set = set(ASSET_SCANNER_USER_AGENTS)
def _is_crawler_user_agent(self, user_agent: Optional[str]) -> bool:
    """
    Return True when the User-Agent contains a known crawler keyword.

    A missing User-Agent is only logged, not treated as a crawler — the
    rate limiter handles such requests instead.

    Args:
        user_agent: User-Agent string (may be None)
    """
    if not user_agent:
        logger.debug("请求缺少User-Agent")
        return False
    ua_lower = user_agent.lower()
    return any(keyword in ua_lower for keyword in self.crawler_keywords_set)
def _is_asset_scanner_header(self, request: Request) -> bool:
    """
    Detect asset-scanner fingerprints in the request headers.

    Only scanner-specific headers are examined (matching is deliberately
    narrow to avoid false positives on standard proxy headers).

    Args:
        request: incoming request

    Returns:
        True when a scanner-style header is found.
    """
    for name, value in request.headers.items():
        name_lower = name.lower()
        value_lower = value.lower() if value else ""
        known_tools = ASSET_SCANNER_HEADERS.get(name_lower)
        if known_tools is not None:
            if known_tools:
                # Header maps to specific tools: match the value against them.
                if any(tool in value_lower for tool in known_tools):
                    return True
            elif value_lower:
                # No tool set recorded: the header's bare presence is suspicious.
                return True
        # Scanner-specific headers: scan the value for known tool names.
        if name_lower in SCANNER_SPECIFIC_HEADERS and any(
            tool in value_lower for tool in self.scanner_keywords_set
        ):
            return True
    return False
def _detect_asset_scanner(self, request: Request) -> tuple[bool, Optional[str]]:
    """
    Detect asset-mapping / security-scanning tools.

    Args:
        request: incoming request

    Returns:
        (detected, tool name) — the name is "unknown_scanner" when a
        suspicious header was found but no known tool keyword matched.
    """
    user_agent = request.headers.get("User-Agent")
    # 1) Scan the User-Agent for known scanner keywords (set lookup for speed).
    if user_agent:
        user_agent_lower = user_agent.lower()
        for scanner_keyword in self.scanner_keywords_set:
            if scanner_keyword in user_agent_lower:
                return True, scanner_keyword
    # 2) Fall back to header-based fingerprinting.
    if self._is_asset_scanner_header(request):
        # Try to recover a concrete tool name, first from the User-Agent...
        detected_tool = None
        if user_agent:
            user_agent_lower = user_agent.lower()
            for tool in self.scanner_keywords_set:
                if tool in user_agent_lower:
                    detected_tool = tool
                    break
        # ...then from the scanner-specific headers themselves.
        if not detected_tool:
            for header_name, header_value in request.headers.items():
                header_name_lower = header_name.lower()
                if header_name_lower in SCANNER_SPECIFIC_HEADERS:
                    header_value_lower = (header_value or "").lower()
                    for tool in self.scanner_keywords_set:
                        if tool in header_value_lower:
                            detected_tool = tool
                            break
                    if detected_tool:
                        break
        return True, detected_tool or "unknown_scanner"
    return False, None
def _check_rate_limit(self, client_ip: str) -> bool:
    """
    Sliding-window rate-limit check for one client IP.

    Args:
        client_ip: client IP address

    Returns:
        True when the IP exceeded the limit and must be blocked.
    """
    # Whitelisted IPs are never rate limited.
    if self._is_ip_allowed(client_ip):
        return False
    current_time = time.time()
    # Periodic housekeeping, at most once every 5 minutes.
    if current_time - self.last_cleanup > 300:
        self._cleanup_old_requests(current_time)
        self.last_cleanup = current_time
    # Bound the number of tracked IPs to prevent unbounded memory growth.
    if self.max_tracked_ips > 0 and len(self.request_times) > self.max_tracked_ips:
        # Evict the longest-idle entries.
        self._cleanup_oldest_ips()
    # Lazily create the per-IP deque (no maxlen: a cap would loosen the limit).
    if client_ip not in self.request_times:
        self.request_times[client_ip] = deque()
    request_times = self.request_times[client_ip]
    # Expire timestamps that fell out of the window (popleft is O(1)).
    while request_times and current_time - request_times[0] >= self.rate_limit_window:
        request_times.popleft()
    # Already at the limit? Reject without recording this request.
    if len(request_times) >= self.rate_limit_max_requests:
        return True
    # Record the current request.
    request_times.append(current_time)
    return False
def _cleanup_old_requests(self, current_time: float):
    """Periodic housekeeping hook.

    Per-key expiry happens lazily inside _check_rate_limit; this pass only
    evicts tracked IPs once the table approaches 80% of its capacity.
    """
    capacity_threshold = self.max_tracked_ips * 0.8
    if len(self.request_times) > capacity_threshold:
        self._cleanup_oldest_ips()
def _cleanup_oldest_ips(self):
    """Evict stale entries: drop all empty deques, or the single oldest entry.

    A full scan is used to find the genuinely oldest timestamp; this only
    runs when the table is over its limit, so the cost is acceptable.
    """
    if not self.request_times:
        return
    # Empty deques are pure overhead — delete them first.
    empty_ips = [ip for ip, times in self.request_times.items() if not times]
    for ip in empty_ips:
        del self.request_times[ip]
    if empty_ips:
        return
    # Nothing empty to drop: evict the IP whose earliest timestamp is oldest.
    oldest_ip = min(self.request_times, key=lambda ip: self.request_times[ip][0], default=None)
    if oldest_ip is not None:
        del self.request_times[oldest_ip]
def _is_trusted_proxy(self, ip: str) -> bool:
    """
    Check whether *ip* matches any configured trusted-proxy entry.

    Args:
        ip: IP address string

    Returns:
        True when the IP is a trusted proxy.
    """
    if not TRUSTED_PROXIES or ip == "unknown":
        return False
    # Parse once; entries needing an address object are skipped if invalid.
    try:
        ip_obj = ipaddress.ip_address(ip)
    except (ValueError, AttributeError):
        ip_obj = None
    for entry in TRUSTED_PROXIES:
        if isinstance(entry, str):
            # Wildcard entries are stored as regex strings.
            try:
                if re.match(entry, ip):
                    return True
            except re.error:
                continue
        elif ip_obj is None:
            continue
        elif isinstance(entry, (ipaddress.IPv4Network, ipaddress.IPv6Network)):
            if ip_obj in entry:
                return True
        elif isinstance(entry, (ipaddress.IPv4Address, ipaddress.IPv6Address)):
            if ip_obj == entry:
                return True
    return False
def _get_client_ip(self, request: Request) -> str:
    """
    Resolve the real client IP with basic validation and proxy-trust checks.

    Args:
        request: incoming request

    Returns:
        the client IP, or "unknown" when nothing valid could be determined
    """
    # IP of the directly connected peer — used to decide whether proxy
    # headers may be trusted at all.
    direct_client_ip = None
    if request.client:
        direct_client_ip = request.client.host
    # TRUST_XFF only enables proxy-header parsing; the direct peer must
    # still appear in TRUSTED_PROXIES before X-Forwarded-For is believed.
    use_xff = False
    if TRUST_XFF and TRUSTED_PROXIES and direct_client_ip:
        use_xff = self._is_trusted_proxy(direct_client_ip)
    # Trusted proxy: prefer X-Forwarded-For (first entry = original client).
    if use_xff:
        forwarded_for = request.headers.get("X-Forwarded-For")
        if forwarded_for:
            ip = forwarded_for.split(",")[0].strip()
            # Basic format validation before accepting the value.
            if self._validate_ip(ip):
                return ip
    # Trusted proxy: X-Real-IP as the secondary source.
    if use_xff:
        real_ip = request.headers.get("X-Real-IP")
        if real_ip:
            ip = real_ip.strip()
            if self._validate_ip(ip):
                return ip
    # Fall back to the direct peer address.
    if direct_client_ip and self._validate_ip(direct_client_ip):
        return direct_client_ip
    return "unknown"
def _validate_ip(self, ip: str) -> bool:
    """
    Return True when *ip* parses as a valid IPv4/IPv6 address.

    Args:
        ip: IP address string
    """
    try:
        ipaddress.ip_address(ip)
    except (ValueError, AttributeError):
        return False
    return True
def _is_ip_allowed(self, ip: str) -> bool:
    """
    Check whether *ip* is whitelisted (exact IPs, CIDR networks, wildcard regexes).

    Args:
        ip: client IP address

    Returns:
        True when the IP matches a whitelist entry.
    """
    if not ALLOWED_IPS or ip == "unknown":
        return False
    # Parse once; entries needing an address object are skipped if invalid.
    try:
        ip_obj = ipaddress.ip_address(ip)
    except (ValueError, AttributeError):
        ip_obj = None
    for entry in ALLOWED_IPS:
        if isinstance(entry, str):
            # Wildcard entries are stored as regex strings.
            try:
                if re.match(entry, ip):
                    return True
            except re.error:
                # Broken regex entry — ignore it.
                continue
        elif ip_obj is None:
            continue
        elif isinstance(entry, (ipaddress.IPv4Network, ipaddress.IPv6Network)):
            if ip_obj in entry:
                return True
        elif isinstance(entry, (ipaddress.IPv4Address, ipaddress.IPv6Address)):
            if ip_obj == entry:
                return True
    return False
async def dispatch(self, request: Request, call_next):
    """Middleware entry point: screen every request before passing it downstream.

    Check order (cheapest / most permissive first):
      1. Bypass entirely when the middleware is disabled.
      2. Allow robots.txt and static assets.
      3. Allow whitelisted client IPs unconditionally.
      4. Detect asset-scanning tools, then crawler User-Agents.
      5. Enforce the per-IP rate limit.

    Args:
        request: Incoming request object.
        call_next: Next middleware or route handler.

    Returns:
        The downstream response, or a 403/429 PlainTextResponse when blocked.
    """
    # Disabled: pass everything straight through.
    if not self.enabled:
        return await call_next(request)
    # robots.txt is served by a dedicated route.
    if request.url.path == "/robots.txt":
        return await call_next(request)
    # Static assets (CSS, JS, images, fonts) are allowed.
    # NOTE: .json is deliberately NOT listed, so API paths cannot bypass the
    # checks by masquerading as static files.
    # Assets are only exempted under specific prefixes (/static/, /assets/, /dist/).
    static_extensions = {
        ".css",
        ".js",
        ".png",
        ".jpg",
        ".jpeg",
        ".gif",
        ".svg",
        ".ico",
        ".woff",
        ".woff2",
        ".ttf",
        ".eot",
    }
    static_prefixes = {"/static/", "/assets/", "/dist/"}
    # A request is static when it lives under an exempt prefix AND ends with a
    # known static extension.
    path = request.url.path
    is_static_path = any(path.startswith(prefix) for prefix in static_prefixes) and any(
        path.endswith(ext) for ext in static_extensions
    )
    # Root-level static files (e.g. /favicon.ico) are also allowed: exactly one
    # "/" means the file sits directly under the root.
    is_root_static = path.count("/") == 1 and any(path.endswith(ext) for ext in static_extensions)
    if is_static_path or is_root_static:
        return await call_next(request)
    # Resolve the client IP once and reuse it for every check below.
    client_ip = self._get_client_ip(request)
    # Whitelisted IPs skip all further checks.
    if self._is_ip_allowed(client_ip):
        return await call_next(request)
    # User-Agent header (may be None).
    user_agent = request.headers.get("User-Agent")
    # Asset-mapping / reconnaissance scanners are checked first (more dangerous
    # than ordinary crawlers).
    if self.check_asset_scanner:
        is_scanner, scanner_name = self._detect_asset_scanner(request)
        if is_scanner:
            logger.warning(
                f"🚫 检测到资产测绘工具请求 - IP: {client_ip}, 工具: {scanner_name}, "
                f"User-Agent: {user_agent}, Path: {request.url.path}"
            )
            # Blocking is configurable; otherwise the event is only logged.
            if self.block_on_detect:
                return PlainTextResponse(
                    "Access Denied: Asset scanning tools are not allowed",
                    status_code=403,
                )
    # Crawler detection via User-Agent.
    if self.check_user_agent and self._is_crawler_user_agent(user_agent):
        logger.warning(f"🚫 检测到爬虫请求 - IP: {client_ip}, User-Agent: {user_agent}, Path: {request.url.path}")
        # Blocking is configurable; otherwise the event is only logged.
        if self.block_on_detect:
            return PlainTextResponse(
                "Access Denied: Crawlers are not allowed",
                status_code=403,
            )
    # Per-IP rate limiting — always blocks when triggered, regardless of
    # block_on_detect.
    if self.check_rate_limit and self._check_rate_limit(client_ip):
        logger.warning(f"🚫 请求频率过高 - IP: {client_ip}, User-Agent: {user_agent}, Path: {request.url.path}")
        return PlainTextResponse(
            "Too Many Requests: Rate limit exceeded",
            status_code=429,
        )
    # Clean request: hand off downstream.
    return await call_next(request)
def create_robots_txt_response() -> PlainTextResponse:
    """Build the robots.txt response that denies access to all crawlers.

    Returns:
        A text/plain response cached by clients for 24 hours.
    """
    body = (
        "User-agent: *\n"
        "Disallow: /\n"
        "# 禁止所有爬虫访问\n"
    )
    return PlainTextResponse(
        content=body,
        media_type="text/plain",
        headers={"Cache-Control": "public, max-age=86400"},  # cache for 24h
    )

View File

@@ -0,0 +1,35 @@
"""WebUI 路由聚合模块 - 提供统一的路由注册接口"""
from fastapi import APIRouter
def get_api_router() -> APIRouter:
    """Return the main API router that aggregates every sub-route.

    The import is deferred into the function body — presumably to avoid an
    import cycle at module load time (TODO confirm).
    """
    from src.webui.routes import router

    return router
def get_all_routers() -> list[APIRouter]:
    """Collect every router that must be registered on the app individually.

    Imports are performed lazily inside the function — presumably to avoid
    import cycles at module load time (TODO confirm).
    """
    from src.webui.routes import router as main_router
    from src.webui.routers.websocket.logs import router as logs_router
    from src.webui.routers.knowledge import router as knowledge_router
    from src.webui.routers.chat import router as chat_router
    from src.webui.api.planner import router as planner_router
    from src.webui.api.replier import router as replier_router

    routers: list[APIRouter] = [
        main_router,
        logs_router,
        knowledge_router,
        chat_router,
        planner_router,
        replier_router,
    ]
    return routers
# Explicit public interface of this aggregation module.
__all__ = [
    "get_api_router",
    "get_all_routers",
]

View File

@@ -0,0 +1,938 @@
"""麦麦 2025 年度总结 API 路由"""
from fastapi import APIRouter, HTTPException, Depends, Cookie, Header
from pydantic import BaseModel, Field
from typing import Dict, Any, List, Optional
from datetime import datetime
from peewee import fn
from src.common.logger import get_logger
from src.common.database.database_model import (
LLMUsage,
OnlineTime,
Messages,
ChatStreams,
PersonInfo,
Emoji,
Expression,
ActionRecords,
Jargon,
)
from src.webui.core import verify_auth_token_from_cookie_or_header
logger = get_logger("webui.annual_report")
# All endpoints in this module are mounted under the /annual-report prefix.
router = APIRouter(prefix="/annual-report", tags=["annual-report"])
def require_auth(
    maibot_session: Optional[str] = Cookie(None),
    authorization: Optional[str] = Header(None),
) -> bool:
    """FastAPI dependency: confirm the caller is authenticated.

    Accepts either the session cookie or the Authorization header and
    delegates validation to the shared WebUI auth helper.
    """
    is_authenticated = verify_auth_token_from_cookie_or_header(maibot_session, authorization)
    return is_authenticated
# ==================== Pydantic 模型定义 ====================
class TimeFootprintData(BaseModel):
    """"Time footprint" dimension: online hours, first message, activity rhythm."""
    total_online_hours: float = Field(0.0, description="年度在线总时长(小时)")
    first_message_time: Optional[str] = Field(None, description="初次消息时间")
    first_message_user: Optional[str] = Field(None, description="初次消息用户昵称")
    first_message_content: Optional[str] = Field(None, description="初次消息内容(截断)")
    busiest_day: Optional[str] = Field(None, description="最忙碌的一天")
    busiest_day_count: int = Field(0, description="最忙碌那天的消息数")
    hourly_distribution: List[int] = Field(default_factory=lambda: [0] * 24, description="24小时活跃分布")
    midnight_chat_count: int = Field(0, description="深夜(0-4点)互动次数")
    is_night_owl: bool = Field(False, description="是否是夜猫子")
class SocialNetworkData(BaseModel):
    """"Social network" dimension: groups, top users, mentions, companionship."""
    total_groups: int = Field(0, description="加入的群组总数")
    top_groups: List[Dict[str, Any]] = Field(default_factory=list, description="话痨群组TOP5")
    top_users: List[Dict[str, Any]] = Field(default_factory=list, description="互动最多的用户TOP5")
    at_count: int = Field(0, description="被@次数")
    mentioned_count: int = Field(0, description="被提及次数")
    longest_companion_user: Optional[str] = Field(None, description="最长情陪伴的用户")
    longest_companion_days: int = Field(0, description="陪伴天数")
class BrainPowerData(BaseModel):
    """"Brain power" dimension: token/cost totals, model usage, silence rate."""
    total_tokens: int = Field(0, description="年度消耗Token总量")
    total_cost: float = Field(0.0, description="年度总花费")
    favorite_model: Optional[str] = Field(None, description="最爱用的模型")
    favorite_model_count: int = Field(0, description="最爱模型的调用次数")
    model_distribution: List[Dict[str, Any]] = Field(default_factory=list, description="模型使用分布")
    top_reply_models: List[Dict[str, Any]] = Field(default_factory=list, description="最喜欢的回复模型TOP5")
    most_expensive_cost: float = Field(0.0, description="最昂贵的一次思考花费")
    most_expensive_time: Optional[str] = Field(None, description="最昂贵思考的时间")
    top_token_consumers: List[Dict[str, Any]] = Field(default_factory=list, description="烧钱大户TOP3")
    silence_rate: float = Field(0.0, description="高冷指数(沉默率)")
    total_actions: int = Field(0, description="总动作数")
    no_reply_count: int = Field(0, description="选择沉默的次数")
    avg_interest_value: float = Field(0.0, description="平均兴趣值")
    max_interest_value: float = Field(0.0, description="最高兴趣值")
    max_interest_time: Optional[str] = Field(None, description="最高兴趣值时间")
    avg_reasoning_length: float = Field(0.0, description="平均思考长度")
    max_reasoning_length: int = Field(0, description="最长思考长度")
    max_reasoning_time: Optional[str] = Field(None, description="最长思考的时间")
class ExpressionVibeData(BaseModel):
    """"Personality & expression" dimension: emojis, styles, actions, replies."""
    top_emoji: Optional[Dict[str, Any]] = Field(None, description="表情包之王")
    top_emojis: List[Dict[str, Any]] = Field(default_factory=list, description="TOP3表情包")
    top_expressions: List[Dict[str, Any]] = Field(default_factory=list, description="印象最深刻的表达风格")
    rejected_expression_count: int = Field(0, description="被拒绝的表达次数")
    checked_expression_count: int = Field(0, description="已检查的表达次数")
    total_expressions: int = Field(0, description="表达总数")
    action_types: List[Dict[str, Any]] = Field(default_factory=list, description="动作类型分布")
    image_processed_count: int = Field(0, description="处理的图片数量")
    late_night_reply: Optional[Dict[str, Any]] = Field(None, description="深夜还在回复")
    favorite_reply: Optional[Dict[str, Any]] = Field(None, description="最喜欢的回复")
class AchievementData(BaseModel):
    """"Fun achievements" dimension: learned jargon and message/reply totals."""
    new_jargon_count: int = Field(0, description="新学到的黑话数量")
    sample_jargons: List[Dict[str, Any]] = Field(default_factory=list, description="代表性黑话示例")
    total_messages: int = Field(0, description="总消息数")
    total_replies: int = Field(0, description="总回复数")
class AnnualReportData(BaseModel):
    """Complete annual report: the five dimension payloads plus metadata."""
    year: int = Field(2025, description="报告年份")
    bot_name: str = Field("麦麦", description="Bot名称")
    generated_at: str = Field(..., description="报告生成时间")
    time_footprint: TimeFootprintData = Field(default_factory=TimeFootprintData)
    social_network: SocialNetworkData = Field(default_factory=SocialNetworkData)
    brain_power: BrainPowerData = Field(default_factory=BrainPowerData)
    expression_vibe: ExpressionVibeData = Field(default_factory=ExpressionVibeData)
    achievements: AchievementData = Field(default_factory=AchievementData)
# ==================== 辅助函数 ====================
def get_year_time_range(year: int = 2025) -> tuple[float, float]:
    """Return (start, end) Unix timestamps spanning the given calendar year.

    Uses naive datetimes, so the bounds are interpreted in local time.
    The end bound is 23:59:59, so sub-second times after that instant fall
    outside the range.
    """
    first_moment = datetime(year, 1, 1)
    last_moment = datetime(year, 12, 31, 23, 59, 59)
    return first_moment.timestamp(), last_moment.timestamp()
def get_year_datetime_range(year: int = 2025) -> tuple[datetime, datetime]:
    """Return naive (start, end) datetimes covering the given calendar year.

    The end bound is 23:59:59 on December 31st.
    """
    bounds = (
        datetime(year, 1, 1),
        datetime(year, 12, 31, 23, 59, 59),
    )
    return bounds
# ==================== 维度一:时光足迹 ====================
async def get_time_footprint(year: int = 2025) -> TimeFootprintData:
    """Collect the "time footprint" dimension of the annual report.

    Aggregates annual online time, the first message of the year, the busiest
    day, the 24-hour activity histogram, and night-owl statistics.

    Args:
        year: Report year (default 2025).

    Returns:
        A populated TimeFootprintData. On any query failure the error is
        logged and the partially filled (or default) object is returned.
    """
    data = TimeFootprintData()
    start_ts, end_ts = get_year_time_range(year)
    start_dt, end_dt = get_year_datetime_range(year)
    try:
        # 1. Annual online duration.
        # BUGFIX: the previous filter `(start >= start_dt) | (end <= end_dt)`
        # matched essentially every row (any record satisfies at least one
        # side), so the whole table was fetched and filtering happened only in
        # Python. The standard interval-overlap condition below fetches just
        # the records that intersect the year; the per-record clamping loop is
        # kept unchanged and still guards the final sum.
        online_records = list(
            OnlineTime.select().where(
                (OnlineTime.start_timestamp <= end_dt) & (OnlineTime.end_timestamp >= start_dt)
            )
        )
        total_seconds = 0
        for record in online_records:
            try:
                # Clamp each interval to the year's boundaries.
                start = max(record.start_timestamp, start_dt)
                end = min(record.end_timestamp, end_dt)
                if end > start:
                    total_seconds += (end - start).total_seconds()
            except Exception:
                continue
        data.total_online_hours = round(total_seconds / 3600, 2)
        # 2. First encounter — the first message of the year.
        first_msg = (
            Messages.select()
            .where((Messages.time >= start_ts) & (Messages.time <= end_ts))
            .order_by(Messages.time.asc())
            .first()
        )
        if first_msg:
            data.first_message_time = datetime.fromtimestamp(first_msg.time).strftime("%Y-%m-%d %H:%M:%S")
            data.first_message_user = first_msg.user_nickname or first_msg.user_id or "未知用户"
            content = first_msg.processed_plain_text or first_msg.display_message or ""
            data.first_message_content = content[:50] + "..." if len(content) > 50 else content
        # 3. Busiest day — group by calendar date via SQLite's date() function.
        # NOTE(review): "unixepoch" renders dates in UTC, not local time —
        # confirm this matches how the report is presented.
        busiest_query = (
            Messages.select(
                fn.date(Messages.time, "unixepoch").alias("day"),
                fn.COUNT(Messages.id).alias("count"),
            )
            .where((Messages.time >= start_ts) & (Messages.time <= end_ts))
            .group_by(fn.date(Messages.time, "unixepoch"))
            .order_by(fn.COUNT(Messages.id).desc())
            .limit(1)
        )
        busiest_result = list(busiest_query.dicts())
        if busiest_result:
            data.busiest_day = busiest_result[0].get("day")
            data.busiest_day_count = busiest_result[0].get("count", 0)
        # 4. Circadian rhythm — 24-hour activity histogram.
        hourly_query = (
            Messages.select(
                fn.strftime("%H", Messages.time, "unixepoch").alias("hour"),
                fn.COUNT(Messages.id).alias("count"),
            )
            .where((Messages.time >= start_ts) & (Messages.time <= end_ts))
            .group_by(fn.strftime("%H", Messages.time, "unixepoch"))
        )
        hourly_distribution = [0] * 24
        for row in hourly_query.dicts():
            try:
                hour = int(row.get("hour", 0))
                if 0 <= hour < 24:
                    hourly_distribution[hour] = row.get("count", 0)
            except (ValueError, TypeError):
                continue
        data.hourly_distribution = hourly_distribution
        # 5. Midnight diner — messages between 00:00 and 04:59 (hours 0-4).
        data.midnight_chat_count = sum(hourly_distribution[0:5])
        # 6. Night-owl flag: activity 22:00-04:59 vs morning 06:00-12:59.
        night_activity = sum(hourly_distribution[22:24]) + sum(hourly_distribution[0:5])
        morning_activity = sum(hourly_distribution[6:13])
        data.is_night_owl = night_activity > morning_activity
    except Exception as e:
        logger.error(f"获取时光足迹数据失败: {e}")
    return data
# ==================== 维度二:社交网络 ====================
async def get_social_network(year: int = 2025) -> SocialNetworkData:
    """Collect the "social network" dimension of the annual report.

    Aggregates group counts, top-5 chattiest groups/users, @ and mention
    counts, and the longest-companion chat stream.

    Args:
        year: Report year (default 2025).

    Returns:
        A populated SocialNetworkData. On any query failure the error is
        logged and the partially filled (or default) object is returned.
    """
    from src.config.config import global_config
    data = SocialNetworkData()
    start_ts, end_ts = get_year_time_range(year)
    # The bot's own QQ account, used to exclude the bot from user rankings.
    bot_qq = str(global_config.bot.qq_account or "")
    try:
        # 1. Total number of joined groups (all time, not limited to the year).
        data.total_groups = ChatStreams.select().where(ChatStreams.group_id.is_null(False)).count()
        # 2. Chattiest groups TOP5.
        top_groups_query = (
            Messages.select(
                Messages.chat_info_group_id,
                Messages.chat_info_group_name,
                fn.COUNT(Messages.id).alias("count"),
            )
            .where(
                (Messages.time >= start_ts)
                & (Messages.time <= end_ts)
                & (Messages.chat_info_group_id.is_null(False))
            )
            .group_by(Messages.chat_info_group_id)
            .order_by(fn.COUNT(Messages.id).desc())
            .limit(5)
        )
        data.top_groups = [
            {
                "group_id": row["chat_info_group_id"],
                "group_name": row["chat_info_group_name"] or "未知群组",
                "message_count": row["count"],
                "is_webui": str(row["chat_info_group_id"]).startswith("webui_"),
            }
            for row in top_groups_query.dicts()
        ]
        # 3. Most-interacted users TOP5 (bot itself excluded).
        top_users_query = (
            Messages.select(
                Messages.user_id,
                Messages.user_nickname,
                fn.COUNT(Messages.id).alias("count"),
            )
            .where(
                (Messages.time >= start_ts)
                & (Messages.time <= end_ts)
                & (Messages.user_id.is_null(False))
                & (Messages.user_id != bot_qq)  # exclude the bot itself
            )
            .group_by(Messages.user_id)
            .order_by(fn.COUNT(Messages.id).desc())
            .limit(5)
        )
        data.top_users = [
            {
                "user_id": row["user_id"],
                "user_nickname": row["user_nickname"] or "未知用户",
                "message_count": row["count"],
                "is_webui": str(row["user_id"]).startswith("webui_"),
            }
            for row in top_users_query.dicts()
        ]
        # 4. Number of times the bot was @-ed.
        data.at_count = (
            Messages.select()
            .where(
                (Messages.time >= start_ts)
                & (Messages.time <= end_ts)
                & (Messages.is_at == True)
            )
            .count()
        )
        # 5. Number of times the bot was mentioned.
        data.mentioned_count = (
            Messages.select()
            .where(
                (Messages.time >= start_ts)
                & (Messages.time <= end_ts)
                & (Messages.is_mentioned == True)
            )
            .count()
        )
        # 6. Longest companion: the chat stream with the largest active span
        # (bot itself excluded). Not limited to the report year.
        companion_query = (
            ChatStreams.select(
                ChatStreams.user_id,
                ChatStreams.user_nickname,
                (ChatStreams.last_active_time - ChatStreams.create_time).alias("duration"),
            )
            .where(
                (ChatStreams.user_id.is_null(False))
                & (ChatStreams.user_id != bot_qq)  # exclude the bot itself
            )
            .order_by((ChatStreams.last_active_time - ChatStreams.create_time).desc())
            .limit(1)
        )
        companion_result = list(companion_query.dicts())
        if companion_result:
            data.longest_companion_user = companion_result[0].get("user_nickname") or "未知用户"
            duration = companion_result[0].get("duration", 0) or 0
            data.longest_companion_days = int(duration / 86400)  # seconds -> days
    except Exception as e:
        logger.error(f"获取社交网络数据失败: {e}")
    return data
# ==================== 维度三:最强大脑 ====================
async def get_brain_power(year: int = 2025) -> BrainPowerData:
    """Collect the "brain power" dimension of the annual report.

    Aggregates LLM token/cost totals, model usage rankings, the most expensive
    call, top spenders, the silence rate, interest values and reasoning-length
    statistics.

    Args:
        year: Report year (default 2025).

    Returns:
        A populated BrainPowerData. On any query failure the error is logged
        and the partially filled (or default) object is returned.
    """
    data = BrainPowerData()
    start_dt, end_dt = get_year_datetime_range(year)
    start_ts, end_ts = get_year_time_range(year)
    try:
        # 1. Annual token total and total cost.
        token_query = LLMUsage.select(
            fn.COALESCE(fn.SUM(LLMUsage.total_tokens), 0).alias("total_tokens"),
            fn.COALESCE(fn.SUM(LLMUsage.cost), 0).alias("total_cost"),
        ).where((LLMUsage.timestamp >= start_dt) & (LLMUsage.timestamp <= end_dt))
        result = token_query.dicts().get()
        data.total_tokens = int(result.get("total_tokens", 0) or 0)
        data.total_cost = round(float(result.get("total_cost", 0) or 0), 4)
        # 2. Favourite model (highest call count); also builds the top-10
        # model usage distribution.
        model_query = (
            LLMUsage.select(
                fn.COALESCE(LLMUsage.model_assign_name, LLMUsage.model_name).alias("model"),
                fn.COUNT(LLMUsage.id).alias("count"),
                fn.COALESCE(fn.SUM(LLMUsage.total_tokens), 0).alias("tokens"),
                fn.COALESCE(fn.SUM(LLMUsage.cost), 0).alias("cost"),
            )
            .where((LLMUsage.timestamp >= start_dt) & (LLMUsage.timestamp <= end_dt))
            .group_by(fn.COALESCE(LLMUsage.model_assign_name, LLMUsage.model_name))
            .order_by(fn.COUNT(LLMUsage.id).desc())
            .limit(10)
        )
        model_results = list(model_query.dicts())
        if model_results:
            data.favorite_model = model_results[0].get("model")
            data.favorite_model_count = model_results[0].get("count", 0)
            data.model_distribution = [
                {
                    "model": row["model"],
                    "count": row["count"],
                    "tokens": row["tokens"],
                    "cost": round(row["cost"], 4),
                }
                for row in model_results
            ]
        # 3. The single most expensive call.
        expensive_query = (
            LLMUsage.select(LLMUsage.cost, LLMUsage.timestamp)
            .where((LLMUsage.timestamp >= start_dt) & (LLMUsage.timestamp <= end_dt))
            .order_by(LLMUsage.cost.desc())
            .limit(1)
        )
        expensive_result = expensive_query.first()
        if expensive_result:
            data.most_expensive_cost = round(expensive_result.cost or 0, 4)
            data.most_expensive_time = expensive_result.timestamp.strftime("%Y-%m-%d %H:%M:%S")
        # 4. Top spenders TOP3 (grouped by user, "system" excluded).
        consumer_query = (
            LLMUsage.select(
                LLMUsage.user_id,
                fn.COALESCE(fn.SUM(LLMUsage.cost), 0).alias("cost"),
                fn.COALESCE(fn.SUM(LLMUsage.total_tokens), 0).alias("tokens"),
            )
            .where(
                (LLMUsage.timestamp >= start_dt)
                & (LLMUsage.timestamp <= end_dt)
                & (LLMUsage.user_id != "system")  # exclude the system user
                & (LLMUsage.user_id.is_null(False))
            )
            .group_by(LLMUsage.user_id)
            .order_by(fn.SUM(LLMUsage.cost).desc())
            .limit(3)
        )
        data.top_token_consumers = [
            {
                "user_id": row["user_id"],
                "cost": round(row["cost"], 4),
                "tokens": row["tokens"],
            }
            for row in consumer_query.dicts()
        ]
        # 5. Favourite reply models TOP5 — heuristic: replyer calls are
        # identified by their assign name (contains "replyer"/"回复") or by a
        # missing assign name. NOTE(review): this is an approximation; confirm
        # against how assign names are actually set.
        reply_model_query = (
            LLMUsage.select(
                fn.COALESCE(LLMUsage.model_assign_name, LLMUsage.model_name).alias("model"),
                fn.COUNT(LLMUsage.id).alias("count"),
            )
            .where(
                (LLMUsage.timestamp >= start_dt)
                & (LLMUsage.timestamp <= end_dt)
                & (
                    LLMUsage.model_assign_name.contains("replyer")
                    | LLMUsage.model_assign_name.contains("回复")
                    | LLMUsage.model_assign_name.is_null(True)  # include rows without an assign name
                )
            )
            .group_by(fn.COALESCE(LLMUsage.model_assign_name, LLMUsage.model_name))
            .order_by(fn.COUNT(LLMUsage.id).desc())
            .limit(5)
        )
        data.top_reply_models = [
            {"model": row["model"], "count": row["count"]}
            for row in reply_model_query.dicts()
        ]
        # 6. Aloofness index (silence rate) — share of no_reply actions among
        # all ActionRecords in the year.
        total_actions = ActionRecords.select().where(
            (ActionRecords.time >= start_ts) & (ActionRecords.time <= end_ts)
        ).count()
        no_reply_count = ActionRecords.select().where(
            (ActionRecords.time >= start_ts)
            & (ActionRecords.time <= end_ts)
            & (ActionRecords.action_name == "no_reply")
        ).count()
        data.total_actions = total_actions
        data.no_reply_count = no_reply_count
        data.silence_rate = round(no_reply_count / total_actions * 100, 2) if total_actions > 0 else 0
        # 7. Mood swings — average and maximum message interest value.
        interest_query = Messages.select(
            fn.AVG(Messages.interest_value).alias("avg_interest"),
            fn.MAX(Messages.interest_value).alias("max_interest"),
        ).where(
            (Messages.time >= start_ts)
            & (Messages.time <= end_ts)
            & (Messages.interest_value.is_null(False))
        )
        interest_result = interest_query.dicts().get()
        data.avg_interest_value = round(float(interest_result.get("avg_interest") or 0), 2)
        data.max_interest_value = round(float(interest_result.get("max_interest") or 0), 2)
        # Locate the timestamp of the highest-interest message.
        # NOTE(review): max_interest_value was rounded above, so this equality
        # lookup may miss the exact row — confirm whether that matters.
        if data.max_interest_value > 0:
            max_interest_msg = (
                Messages.select(Messages.time)
                .where(
                    (Messages.time >= start_ts)
                    & (Messages.time <= end_ts)
                    & (Messages.interest_value == data.max_interest_value)
                )
                .first()
            )
            if max_interest_msg:
                data.max_interest_time = datetime.fromtimestamp(max_interest_msg.time).strftime(
                    "%Y-%m-%d %H:%M:%S"
                )
        # 8. Thinking depth — based on the length of action_reasoning text.
        reasoning_records = (
            ActionRecords.select(ActionRecords.action_reasoning, ActionRecords.time)
            .where(
                (ActionRecords.time >= start_ts)
                & (ActionRecords.time <= end_ts)
                & (ActionRecords.action_reasoning.is_null(False))
                & (ActionRecords.action_reasoning != "")
            )
        )
        reasoning_lengths = []
        max_len = 0
        max_len_time = None
        for record in reasoning_records:
            if record.action_reasoning:
                length = len(record.action_reasoning)
                reasoning_lengths.append(length)
                if length > max_len:
                    max_len = length
                    max_len_time = record.time
        if reasoning_lengths:
            data.avg_reasoning_length = round(sum(reasoning_lengths) / len(reasoning_lengths), 1)
            data.max_reasoning_length = max_len
            if max_len_time:
                data.max_reasoning_time = datetime.fromtimestamp(max_len_time).strftime("%Y-%m-%d %H:%M:%S")
    except Exception as e:
        logger.error(f"获取最强大脑数据失败: {e}")
    return data
# ==================== 维度四:个性与表达 ====================
async def get_expression_vibe(year: int = 2025) -> ExpressionVibeData:
    """Collect the "personality & expression" dimension of the annual report.

    Aggregates top emojis, expression styles, action-type distribution,
    processed-image counts, a sample late-night reply, and the most repeated
    reply text.

    Args:
        year: Report year (default 2025).

    Returns:
        A populated ExpressionVibeData. On any query failure the error is
        logged and the partially filled (or default) object is returned.
    """
    from src.config.config import global_config
    data = ExpressionVibeData()
    start_ts, end_ts = get_year_time_range(year)
    # The bot's own QQ account, used to pick out messages the bot sent.
    bot_qq = str(global_config.bot.qq_account or "")
    try:
        # 1. Emoji king — the most-used registered emojis (all time).
        top_emoji_query = (
            Emoji.select(Emoji.id, Emoji.full_path, Emoji.description, Emoji.usage_count, Emoji.emoji_hash)
            .where(Emoji.is_registered == True)
            .order_by(Emoji.usage_count.desc())
            .limit(5)
        )
        top_emojis = list(top_emoji_query.dicts())
        if top_emojis:
            data.top_emoji = {
                "id": top_emojis[0].get("id"),
                "path": top_emojis[0].get("full_path"),
                "description": top_emojis[0].get("description"),
                "usage_count": top_emojis[0].get("usage_count", 0),
                "hash": top_emojis[0].get("emoji_hash"),
            }
            data.top_emojis = [
                {
                    "id": e.get("id"),
                    "path": e.get("full_path"),
                    "description": e.get("description"),
                    "usage_count": e.get("usage_count", 0),
                    "hash": e.get("emoji_hash"),
                }
                for e in top_emojis
            ]
        # 2. Most-used expression styles within the year.
        expression_query = (
            Expression.select(
                Expression.style,
                fn.SUM(Expression.count).alias("total_count"),
            )
            .where(
                (Expression.last_active_time >= start_ts)
                & (Expression.last_active_time <= end_ts)
            )
            .group_by(Expression.style)
            .order_by(fn.SUM(Expression.count).desc())
            .limit(5)
        )
        data.top_expressions = [
            {"style": row["style"], "count": row["total_count"]}
            for row in expression_query.dicts()
        ]
        # 3. Rejected expressions.
        data.rejected_expression_count = (
            Expression.select()
            .where(
                (Expression.last_active_time >= start_ts)
                & (Expression.last_active_time <= end_ts)
                & (Expression.rejected == True)
            )
            .count()
        )
        # 4. Checked expressions.
        data.checked_expression_count = (
            Expression.select()
            .where(
                (Expression.last_active_time >= start_ts)
                & (Expression.last_active_time <= end_ts)
                & (Expression.checked == True)
            )
            .count()
        )
        # 5. Total expressions active during the year.
        data.total_expressions = (
            Expression.select()
            .where(
                (Expression.last_active_time >= start_ts)
                & (Expression.last_active_time <= end_ts)
            )
            .count()
        )
        # 6. Action-type distribution, excluding bookkeeping actions:
        # no_reply_until_call, make_question, no_action, wait, complete_talk,
        # listening, block_and_ignore (plus reply/no_reply themselves).
        excluded_actions = [
            "reply", "no_reply", "no_reply_until_call", "make_question",
            "no_action", "wait", "complete_talk", "listening", "block_and_ignore"
        ]
        action_query = (
            ActionRecords.select(
                ActionRecords.action_name,
                fn.COUNT(ActionRecords.id).alias("count"),
            )
            .where(
                (ActionRecords.time >= start_ts)
                & (ActionRecords.time <= end_ts)
                & (ActionRecords.action_name.not_in(excluded_actions))
            )
            .group_by(ActionRecords.action_name)
            .order_by(fn.COUNT(ActionRecords.id).desc())
            .limit(10)
        )
        data.action_types = [
            {"action": row["action_name"], "count": row["count"]}
            for row in action_query.dicts()
        ]
        # 7. Number of processed images.
        data.image_processed_count = (
            Messages.select()
            .where(
                (Messages.time >= start_ts)
                & (Messages.time <= end_ts)
                & (Messages.is_picid == True)
            )
            .count()
        )
        # 8. Still replying late at night — pick one at random from the 10
        # most recent 00:00-06:00 bot messages.
        import random
        import re
        def clean_message_content(content: str) -> str:
            """Strip reply-quote and media placeholders from a message body."""
            if not content:
                return ""
            # Remove quoted-reply markers of the form [回复<xxx:xxx> 的消息:...]
            content = re.sub(r'\[回复<[^>]+>\s*的消息[:][^\]]*\]', '', content)
            # Remove media placeholders such as [图片] / [表情].
            content = re.sub(r'\[(图片|表情|语音|视频|文件)\]', '', content)
            # Collapse runs of whitespace.
            content = re.sub(r'\s+', ' ', content).strip()
            return content
        # Bot-sent messages are identified by user_id == bot QQ account.
        late_night_messages = list(
            Messages.select(
                Messages.time,
                Messages.processed_plain_text,
                Messages.display_message,
            )
            .where(
                (Messages.time >= start_ts)
                & (Messages.time <= end_ts)
                & (Messages.user_id == bot_qq)  # messages the bot sent
            )
            .order_by(Messages.time.desc())
        )
        # Keep only messages sent between 00:00 and 06:00 (local time).
        late_night_filtered = []
        for msg in late_night_messages:
            msg_dt = datetime.fromtimestamp(msg.time)
            hour = msg_dt.hour
            if 0 <= hour < 6:  # between midnight and 6am
                raw_content = msg.processed_plain_text or msg.display_message or ""
                cleaned_content = clean_message_content(raw_content)
                # Only keep messages with meaningful content (> 2 chars).
                if cleaned_content and len(cleaned_content) > 2:
                    late_night_filtered.append({
                        "time": msg.time,
                        "hour": hour,
                        "minute": msg_dt.minute,
                        "content": cleaned_content,
                        "datetime_str": msg_dt.strftime("%H:%M"),
                    })
                if len(late_night_filtered) >= 10:
                    break
        if late_night_filtered:
            selected = random.choice(late_night_filtered)
            content = selected["content"][:50] + "..." if len(selected["content"]) > 50 else selected["content"]
            data.late_night_reply = {
                "time": selected["datetime_str"],
                "content": content,
            }
        # 9. Favourite reply — most frequent reply text extracted from
        # ActionRecords.action_data of "reply" actions.
        from collections import Counter
        import json as json_lib
        reply_records = (
            ActionRecords.select(ActionRecords.action_data)
            .where(
                (ActionRecords.time >= start_ts)
                & (ActionRecords.time <= end_ts)
                & (ActionRecords.action_name == "reply")
                & (ActionRecords.action_data.is_null(False))
                & (ActionRecords.action_data != "")
            )
        )
        reply_contents = []
        for record in reply_records:
            try:
                action_data = record.action_data
                if action_data:
                    content = None
                    # First try strict JSON.
                    try:
                        parsed = json_lib.loads(action_data)
                        if isinstance(parsed, dict):
                            # Prefer reply_text, fall back to content.
                            content = parsed.get("reply_text") or parsed.get("content")
                        elif isinstance(parsed, str):
                            content = parsed
                    except (json_lib.JSONDecodeError, TypeError):
                        pass
                    # If JSON parsing failed, try a Python dict-literal string,
                    # e.g. "{'reply_text': '墨白灵不知道哦'}".
                    if content is None:
                        import ast
                        try:
                            parsed = ast.literal_eval(action_data)
                            if isinstance(parsed, dict):
                                content = parsed.get("reply_text") or parsed.get("content")
                            elif isinstance(parsed, str):
                                content = parsed
                        except (ValueError, SyntaxError):
                            # Unparseable — keep the raw string.
                            content = action_data
                    # Only count meaningful replies (length > 2).
                    if content and len(content) > 2:
                        reply_contents.append(content)
            except Exception:
                continue
        if reply_contents:
            content_counter = Counter(reply_contents)
            most_common = content_counter.most_common(1)
            if most_common:
                fav_content, fav_count = most_common[0]
                # Truncate over-long content for display.
                display_content = fav_content[:50] + "..." if len(fav_content) > 50 else fav_content
                data.favorite_reply = {
                    "content": display_content,
                    "count": fav_count,
                }
    except Exception as e:
        logger.error(f"获取个性与表达数据失败: {e}")
    return data
# ==================== 维度五:趣味成就 ====================
async def get_achievements(year: int = 2025) -> AchievementData:
    """Collect the "fun achievements" dimension of the annual report.

    Aggregates learned jargon (all time — the Jargon table has no timestamp
    column) plus the year's message and reply totals.

    Args:
        year: Report year (default 2025).

    Returns:
        A populated AchievementData. On any query failure the error is logged
        and the partially filled (or default) object is returned.
    """
    data = AchievementData()
    start_ts, end_ts = get_year_time_range(year)
    try:
        # 1. Count of learned jargon. The Jargon table has no time column, so
        # all confirmed jargon entries are counted regardless of year.
        data.new_jargon_count = Jargon.select().where(Jargon.is_jargon == True).count()
        # 2. Representative jargon samples (most frequently seen).
        jargon_samples = (
            Jargon.select(Jargon.content, Jargon.meaning, Jargon.count)
            .where(Jargon.is_jargon == True)
            .order_by(Jargon.count.desc())
            .limit(5)
        )
        data.sample_jargons = [
            {
                "content": j.content,
                "meaning": j.meaning,
                "count": j.count,
            }
            for j in jargon_samples
        ]
        # 3. Total messages within the year.
        data.total_messages = (
            Messages.select()
            .where((Messages.time >= start_ts) & (Messages.time <= end_ts))
            .count()
        )
        # 4. Total replies — messages carrying a reply_to reference.
        data.total_replies = (
            Messages.select()
            .where(
                (Messages.time >= start_ts)
                & (Messages.time <= end_ts)
                & (Messages.reply_to.is_null(False))
            )
            .count()
        )
    except Exception as e:
        logger.error(f"获取趣味成就数据失败: {e}")
    return data
# ==================== API 路由 ====================
@router.get("/full", response_model=AnnualReportData)
async def get_full_annual_report(year: int = 2025, _auth: bool = Depends(require_auth)):
    """GET /annual-report/full — assemble the complete annual report.

    Args:
        year: Report year (default 2025).

    Returns:
        The full AnnualReportData combining all five dimensions.

    Raises:
        HTTPException: 500 when report generation fails.
    """
    try:
        from src.config.config import global_config
        logger.info(f"开始生成 {year} 年度报告...")
        # Bot display name for the report header.
        bot_name = global_config.bot.nickname or "麦麦"
        # Fetch each dimension sequentially (the helpers are async but are
        # awaited one by one here, not run in parallel).
        time_footprint = await get_time_footprint(year)
        social_network = await get_social_network(year)
        brain_power = await get_brain_power(year)
        expression_vibe = await get_expression_vibe(year)
        achievements = await get_achievements(year)
        report = AnnualReportData(
            year=year,
            bot_name=bot_name,
            generated_at=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            time_footprint=time_footprint,
            social_network=social_network,
            brain_power=brain_power,
            expression_vibe=expression_vibe,
            achievements=achievements,
        )
        logger.info(f"{year} 年度报告生成完成")
        return report
    except Exception as e:
        logger.error(f"生成年度报告失败: {e}")
        raise HTTPException(status_code=500, detail=f"生成年度报告失败: {str(e)}") from e
@router.get("/time-footprint", response_model=TimeFootprintData)
async def get_time_footprint_api(year: int = 2025, _auth: bool = Depends(require_auth)):
    """GET /annual-report/time-footprint — time-footprint dimension only."""
    try:
        return await get_time_footprint(year)
    except Exception as e:
        logger.error(f"获取时光足迹数据失败: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.get("/social-network", response_model=SocialNetworkData)
async def get_social_network_api(year: int = 2025, _auth: bool = Depends(require_auth)):
    """GET /annual-report/social-network — social-network dimension only."""
    try:
        return await get_social_network(year)
    except Exception as e:
        logger.error(f"获取社交网络数据失败: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.get("/brain-power", response_model=BrainPowerData)
async def get_brain_power_api(year: int = 2025, _auth: bool = Depends(require_auth)):
    """GET /annual-report/brain-power — brain-power dimension only."""
    try:
        return await get_brain_power(year)
    except Exception as e:
        logger.error(f"获取最强大脑数据失败: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.get("/expression-vibe", response_model=ExpressionVibeData)
async def get_expression_vibe_api(year: int = 2025, _auth: bool = Depends(require_auth)):
    """GET /annual-report/expression-vibe — personality & expression dimension only."""
    try:
        return await get_expression_vibe(year)
    except Exception as e:
        logger.error(f"获取个性与表达数据失败: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.get("/achievements", response_model=AchievementData)
async def get_achievements_api(year: int = 2025, _auth: bool = Depends(require_auth)):
    """GET /annual-report/achievements — fun-achievements dimension only."""
    try:
        return await get_achievements(year)
    except Exception as e:
        logger.error(f"获取趣味成就数据失败: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e

781
src/webui/routers/chat.py Normal file
View File

@@ -0,0 +1,781 @@
"""本地聊天室路由 - WebUI 与麦麦直接对话
支持两种模式:
1. WebUI 模式:使用 WebUI 平台独立身份聊天
2. 虚拟身份模式:使用真实平台用户的身份,在虚拟群聊中与麦麦对话
"""
import time
import uuid
from typing import Dict, Any, Optional, List
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query, Depends, Cookie, Header
from pydantic import BaseModel
from src.common.logger import get_logger
from src.common.database.database_model import Messages, PersonInfo
from src.config.config import global_config
from src.chat.message_receive.bot import chat_bot
from src.webui.core import verify_auth_token_from_cookie_or_header, get_token_manager
from src.webui.routers.websocket.auth import verify_ws_token
logger = get_logger("webui.chat")
router = APIRouter(prefix="/api/chat", tags=["LocalChat"])
def require_auth(
    maibot_session: Optional[str] = Cookie(None),
    authorization: Optional[str] = Header(None),
) -> bool:
    """FastAPI dependency guarding the chat routes.

    Validates either the session cookie or the Authorization header via the
    shared WebUI auth helper and returns the result.
    """
    authenticated = verify_auth_token_from_cookie_or_header(maibot_session, authorization)
    return authenticated
# Virtual group ID used for the default WebUI chat room.
WEBUI_CHAT_GROUP_ID = "webui_local_chat"
WEBUI_CHAT_PLATFORM = "webui"
# Group-ID prefix for virtual-identity-mode sessions.
VIRTUAL_GROUP_ID_PREFIX = "webui_virtual_group_"
# Fixed prefix for WebUI user IDs.
WEBUI_USER_ID_PREFIX = "webui_user_"
class VirtualIdentityConfig(BaseModel):
    """Configuration for virtual-identity chat mode."""
    enabled: bool = False  # whether virtual-identity mode is active
    platform: Optional[str] = None  # target platform (e.g. qq, discord)
    person_id: Optional[str] = None  # PersonInfo.person_id
    user_id: Optional[str] = None  # original platform user ID
    user_nickname: Optional[str] = None  # user nickname
    group_id: Optional[str] = None  # virtual group ID (generated or user-supplied)
    group_name: Optional[str] = None  # virtual group name (user-defined)
class ChatHistoryMessage(BaseModel):
    """One chat-history entry in the shape the frontend expects."""
    id: str
    type: str  # 'user' | 'bot' | 'system'
    content: str
    timestamp: float
    sender_name: str
    sender_id: Optional[str] = None
    is_bot: bool = False
class ChatHistoryManager:
    """Chat history manager backed by the SQLite Messages table."""
    def __init__(self, max_messages: int = 200):
        # NOTE(review): max_messages is stored but not enforced by
        # get_history/clear_history — confirm intended use.
        self.max_messages = max_messages
    def _message_to_dict(self, msg: Messages, group_id: Optional[str] = None) -> Dict[str, Any]:
        """Convert a database message row into the frontend message format.

        Args:
            msg: Database message object.
            group_id: Group ID, used to decide how to detect bot messages.
        """
        # Decide whether this row is a bot message.
        user_id = msg.user_id or ""
        # Virtual groups: compare against the bot's QQ account.
        # Plain WebUI groups: check for the webui_ user-id prefix.
        if group_id and group_id.startswith(VIRTUAL_GROUP_ID_PREFIX):
            # Virtual group: a message is the bot's when user_id equals the
            # bot's QQ account.
            bot_qq = str(global_config.bot.qq_account)
            is_bot = user_id == bot_qq
        else:
            # Plain WebUI group: anything NOT prefixed with webui_ is the bot.
            is_bot = not user_id.startswith("webui_") and not user_id.startswith(WEBUI_USER_ID_PREFIX)
        return {
            "id": msg.message_id,
            "type": "bot" if is_bot else "user",
            "content": msg.processed_plain_text or msg.display_message or "",
            "timestamp": msg.time,
            "sender_name": msg.user_nickname or (global_config.bot.nickname if is_bot else "未知用户"),
            "sender_id": "bot" if is_bot else user_id,
            "is_bot": is_bot,
        }
    def get_history(self, limit: int = 50, group_id: Optional[str] = None) -> List[Dict[str, Any]]:
        """Load the most recent history for a group from the database.

        Args:
            limit: Maximum number of messages to return.
            group_id: Group ID; defaults to WEBUI_CHAT_GROUP_ID.

        Returns:
            Messages in frontend format, oldest first; empty list on failure.
        """
        target_group_id = group_id if group_id else WEBUI_CHAT_GROUP_ID
        try:
            # Query the group's messages, newest first.
            messages = (
                Messages.select()
                .where(Messages.chat_info_group_id == target_group_id)
                .order_by(Messages.time.desc())
                .limit(limit)
            )
            # Materialize and reverse so the oldest message comes first.
            # Pass group_id through so virtual-group bot messages are detected.
            result = [self._message_to_dict(msg, target_group_id) for msg in messages]
            result.reverse()
            logger.debug(f"从数据库加载了 {len(result)} 条聊天记录 (group_id={target_group_id})")
            return result
        except Exception as e:
            logger.error(f"从数据库加载聊天记录失败: {e}")
            return []
    def clear_history(self, group_id: Optional[str] = None) -> int:
        """Delete all chat history for a group.

        Args:
            group_id: Group ID; defaults to the WebUI default chat room.

        Returns:
            Number of deleted rows; 0 on failure.
        """
        target_group_id = group_id if group_id else WEBUI_CHAT_GROUP_ID
        try:
            deleted = Messages.delete().where(Messages.chat_info_group_id == target_group_id).execute()
            logger.info(f"已清空 {deleted} 条聊天记录 (group_id={target_group_id})")
            return deleted
        except Exception as e:
            logger.error(f"清空聊天记录失败: {e}")
            return 0
# Module-wide chat-history manager shared by all endpoints below.
chat_history = ChatHistoryManager()
# Registry of live WebSocket connections
class ChatConnectionManager:
    """Tracks active WebUI chat WebSocket connections per session/user."""

    def __init__(self):
        # session_id -> live WebSocket
        self.active_connections: Dict[str, WebSocket] = {}
        # user_id -> session_id mapping
        self.user_sessions: Dict[str, str] = {}

    async def connect(self, websocket: WebSocket, session_id: str, user_id: str):
        """Accept a new connection and register it under its session and user."""
        await websocket.accept()
        self.active_connections[session_id] = websocket
        self.user_sessions[user_id] = session_id
        logger.info(f"WebUI 聊天会话已连接: session={session_id}, user={user_id}")

    def disconnect(self, session_id: str, user_id: str):
        """Forget a closed connection; ignores stale user→session mappings."""
        self.active_connections.pop(session_id, None)
        if self.user_sessions.get(user_id) == session_id:
            del self.user_sessions[user_id]
        logger.info(f"WebUI 聊天会话已断开: session={session_id}")

    async def send_message(self, session_id: str, message: dict):
        """Send a JSON payload to one session; send errors are logged, not raised."""
        ws = self.active_connections.get(session_id)
        if ws is None:
            return
        try:
            await ws.send_json(message)
        except Exception as e:
            logger.error(f"发送消息失败: {e}")

    async def broadcast(self, message: dict):
        """Send a payload to every connected session."""
        for sid in list(self.active_connections):
            await self.send_message(sid, message)


chat_manager = ChatConnectionManager()
def create_message_data(
    content: str,
    user_id: str,
    user_name: str,
    message_id: Optional[str] = None,
    is_at_bot: bool = True,
    virtual_config: Optional[VirtualIdentityConfig] = None,
) -> Dict[str, Any]:
    """Build a message payload in MaiBot's expected wire format.

    Args:
        content: message text
        user_id: sender user ID
        user_name: sender nickname
        message_id: message ID (auto-generated when omitted)
        is_at_bot: whether the message mentions the bot
        virtual_config: optional virtual-identity config; when enabled the
            message is attributed to a real platform identity
    """
    if message_id is None:
        message_id = str(uuid.uuid4())
    # Resolve platform, group and user identity depending on the mode.
    if virtual_config and virtual_config.enabled:
        # Virtual-identity mode: impersonate a real platform user.
        platform = virtual_config.platform or WEBUI_CHAT_PLATFORM
        group_id = virtual_config.group_id or f"{VIRTUAL_GROUP_ID_PREFIX}{uuid.uuid4().hex[:8]}"
        group_name = virtual_config.group_name or "WebUI虚拟群聊"
        actual_user_id = virtual_config.user_id or user_id
        actual_user_name = virtual_config.user_nickname or user_name
    else:
        # Standard WebUI mode: fixed local chat-room identity.
        platform = WEBUI_CHAT_PLATFORM
        group_id = WEBUI_CHAT_GROUP_ID
        group_name = "WebUI本地聊天室"
        actual_user_id = user_id
        actual_user_name = user_name
    # Fix: only attach the mention_bot segment when the message actually
    # mentions the bot. Previously it was appended unconditionally, which
    # contradicted additional_config.at_bot whenever is_at_bot was False.
    # All current callers pass is_at_bot=True, so their payloads are unchanged.
    segments: List[Dict[str, Any]] = [{"type": "text", "data": content}]
    if is_at_bot:
        segments.append({"type": "mention_bot", "data": "1.0"})
    return {
        "message_info": {
            "platform": platform,
            "message_id": message_id,
            "time": time.time(),
            "group_info": {
                "group_id": group_id,
                "group_name": group_name,
                "platform": platform,
            },
            "user_info": {
                "user_id": actual_user_id,
                "user_nickname": actual_user_name,
                "user_cardname": actual_user_name,
                "platform": platform,
            },
            "additional_config": {
                "at_bot": is_at_bot,
            },
        },
        "message_segment": {
            "type": "seglist",
            "data": segments,
        },
        "raw_message": content,
        "processed_plain_text": content,
    }
@router.get("/history")
async def get_chat_history(
    limit: int = Query(default=50, ge=1, le=200),
    user_id: Optional[str] = Query(default=None),  # kept for compatibility; not used for filtering
    group_id: Optional[str] = Query(default=None),  # optional: fetch a specific group's history
    _auth: bool = Depends(require_auth),
):
    """Return chat history.

    All WebUI users share one chat room, so history is not filtered by user.
    When group_id is provided, that virtual group's history is returned.
    """
    messages = chat_history.get_history(limit, group_id or WEBUI_CHAT_GROUP_ID)
    return {"success": True, "messages": messages, "total": len(messages)}
@router.get("/platforms")
async def get_available_platforms(_auth: bool = Depends(require_auth)):
    """List every platform known to the PersonInfo table with person counts."""
    try:
        from peewee import fn

        rows = (
            PersonInfo.select(PersonInfo.platform, fn.COUNT(PersonInfo.id).alias("count"))
            .group_by(PersonInfo.platform)
            .order_by(fn.COUNT(PersonInfo.id).desc())
        )
        # Rows with an empty platform value are skipped.
        platforms = [{"platform": row.platform, "count": row.count} for row in rows if row.platform]
        return {"success": True, "platforms": platforms}
    except Exception as e:
        logger.error(f"获取平台列表失败: {e}")
        return {"success": False, "error": str(e), "platforms": []}
@router.get("/persons")
async def get_persons_by_platform(
    platform: str = Query(..., description="平台名称"),
    search: Optional[str] = Query(default=None, description="搜索关键词"),
    limit: int = Query(default=50, ge=1, le=200),
    _auth: bool = Depends(require_auth),
):
    """List persons known on one platform.

    Args:
        platform: platform name (e.g. qq, discord)
        search: keyword matched against person_name / nickname / user_id
        limit: maximum number of rows returned
    """
    try:
        from peewee import Case

        query = PersonInfo.select().where(PersonInfo.platform == platform)
        # Optional keyword filter across the three identifying columns.
        if search:
            query = query.where(
                (PersonInfo.person_name.contains(search))
                | (PersonInfo.nickname.contains(search))
                | (PersonInfo.user_id.contains(search))
            )
        # Active users first: rows with a last_know value sort before rows
        # without one, newest interaction first.
        query = query.order_by(
            Case(None, [(PersonInfo.last_know.is_null(), 1)], 0),
            PersonInfo.last_know.desc(),
        ).limit(limit)
        persons = [
            {
                "person_id": p.person_id,
                "user_id": p.user_id,
                "person_name": p.person_name,
                "nickname": p.nickname,
                "is_known": p.is_known,
                "platform": p.platform,
                "display_name": p.person_name or p.nickname or p.user_id,
            }
            for p in query
        ]
        return {"success": True, "persons": persons, "total": len(persons)}
    except Exception as e:
        logger.error(f"获取用户列表失败: {e}")
        return {"success": False, "error": str(e), "persons": []}
@router.delete("/history")
async def clear_chat_history(group_id: Optional[str] = Query(default=None), _auth: bool = Depends(require_auth)):
    """Clear chat history.

    Args:
        group_id: optional group to clear; defaults to the WebUI default room
    """
    count = chat_history.clear_history(group_id)
    return {"success": True, "message": f"已清空 {count} 条聊天记录"}
@router.websocket("/ws")
async def websocket_chat(
    websocket: WebSocket,
    user_id: Optional[str] = Query(default=None),
    user_name: Optional[str] = Query(default="WebUI用户"),
    platform: Optional[str] = Query(default=None),
    person_id: Optional[str] = Query(default=None),
    group_name: Optional[str] = Query(default=None),
    group_id: Optional[str] = Query(default=None),  # stable group_id supplied by the frontend
    token: Optional[str] = Query(default=None),  # auth token
):
    """WebSocket chat endpoint.

    Args:
        user_id: unique user identifier (generated and persisted by the frontend)
        user_name: display nickname (changeable)
        platform: platform for virtual-identity mode (optional)
        person_id: person_id for virtual-identity mode (optional)
        group_name: group name for virtual-identity mode (optional)
        group_id: group ID for virtual-identity mode (optional, generated and
            persisted by the frontend)
        token: auth token (optional; can also come from the Cookie)

    Virtual-identity mode can be activated directly via URL parameters or
    later through a ``set_virtual_identity`` message.

    Three authentication methods, in priority order:
    1. ``token`` query parameter (recommended; fetch a temporary token from
       /api/webui/ws-token)
    2. ``maibot_session`` Cookie
    3. the session token itself passed as ``token`` (backward compatible)

    Example: ws://host/api/chat/ws?token=xxx
    """
    is_authenticated = False
    # Method 1: temporary WebSocket token (recommended).
    if token and verify_ws_token(token):
        is_authenticated = True
        logger.debug("聊天 WebSocket 使用临时 token 认证成功")
    # Method 2: session token from the Cookie.
    if not is_authenticated:
        cookie_token = websocket.cookies.get("maibot_session")
        if cookie_token:
            token_manager = get_token_manager()
            if token_manager.verify_token(cookie_token):
                is_authenticated = True
                logger.debug("聊天 WebSocket 使用 Cookie 认证成功")
    # Method 3: treat the query parameter as a session token (legacy path).
    if not is_authenticated and token:
        token_manager = get_token_manager()
        if token_manager.verify_token(token):
            is_authenticated = True
            logger.debug("聊天 WebSocket 使用 session token 认证成功")
    if not is_authenticated:
        logger.warning("聊天 WebSocket 连接被拒绝:认证失败")
        await websocket.close(code=4001, reason="认证失败,请重新登录")
        return
    # A fresh session ID is generated for every connection.
    session_id = str(uuid.uuid4())
    # Generate a user ID when none was supplied.
    if not user_id:
        user_id = f"{WEBUI_USER_ID_PREFIX}{uuid.uuid4().hex[:16]}"
    elif not user_id.startswith(WEBUI_USER_ID_PREFIX):
        # Ensure the user ID carries the expected prefix.
        user_id = f"{WEBUI_USER_ID_PREFIX}{user_id}"
    # Virtual-identity config of this session (updatable via messages).
    current_virtual_config: Optional[VirtualIdentityConfig] = None
    # Auto-configure virtual identity when URL parameters provide it.
    if platform and person_id:
        try:
            person = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
            if person:
                # Use the frontend-provided group_id, or derive a stable one.
                virtual_group_id = group_id or f"{VIRTUAL_GROUP_ID_PREFIX}{platform}_{person.user_id}"
                current_virtual_config = VirtualIdentityConfig(
                    enabled=True,
                    platform=person.platform,
                    person_id=person.person_id,
                    user_id=person.user_id,
                    user_nickname=person.person_name or person.nickname or person.user_id,
                    group_id=virtual_group_id,
                    group_name=group_name or "WebUI虚拟群聊",
                )
                logger.info(
                    f"虚拟身份模式已通过 URL 参数激活: {current_virtual_config.user_nickname} @ {current_virtual_config.platform}, group_id={virtual_group_id}"
                )
        except Exception as e:
            logger.warning(f"通过 URL 参数配置虚拟身份失败: {e}")
    await chat_manager.connect(websocket, session_id, user_id)
    try:
        # Build the session-info payload.
        session_info_data = {
            "type": "session_info",
            "session_id": session_id,
            "user_id": user_id,
            "user_name": user_name,
            "bot_name": global_config.bot.nickname,
        }
        # Attach virtual-identity details when that mode is active.
        if current_virtual_config and current_virtual_config.enabled:
            session_info_data["virtual_mode"] = True
            session_info_data["group_id"] = current_virtual_config.group_id
            session_info_data["virtual_identity"] = {
                "platform": current_virtual_config.platform,
                "user_id": current_virtual_config.user_id,
                "user_nickname": current_virtual_config.user_nickname,
                "group_name": current_virtual_config.group_name,
            }
        # Send the session info (the frontend persists the user ID).
        await chat_manager.send_message(session_id, session_info_data)
        # Send history from the group matching the current mode.
        if current_virtual_config and current_virtual_config.enabled:
            history = chat_history.get_history(50, current_virtual_config.group_id)
        else:
            history = chat_history.get_history(50)
        if history:
            await chat_manager.send_message(
                session_id,
                {
                    "type": "history",
                    "messages": history,
                },
            )
        # Send a welcome message (not persisted to history).
        if current_virtual_config and current_virtual_config.enabled:
            welcome_msg = f"已以 {current_virtual_config.user_nickname} 的身份连接到「{current_virtual_config.group_name}」,开始与 {global_config.bot.nickname} 对话吧!"
        else:
            welcome_msg = f"已连接到本地聊天室,可以开始与 {global_config.bot.nickname} 对话了!"
        await chat_manager.send_message(
            session_id,
            {
                "type": "system",
                "content": welcome_msg,
                "timestamp": time.time(),
            },
        )
        while True:
            data = await websocket.receive_json()
            if data.get("type") == "message":
                content = data.get("content", "").strip()
                if not content:
                    continue
                # The user may send an updated nickname with each message.
                # NOTE(review): this falls back to the connect-time user_name
                # whenever a message omits user_name, even after an
                # update_nickname — confirm that reset is intended.
                current_user_name = data.get("user_name", user_name)
                message_id = str(uuid.uuid4())
                timestamp = time.time()
                # Resolve sender identity (virtual identity vs plain WebUI user).
                if current_virtual_config and current_virtual_config.enabled:
                    sender_name = current_virtual_config.user_nickname or current_user_name
                    sender_user_id = current_virtual_config.user_id or user_id
                else:
                    sender_name = current_user_name
                    sender_user_id = user_id
                # Broadcast the user message to every connection (sender included).
                # The message itself is persisted by chat_bot.message_process.
                await chat_manager.broadcast(
                    {
                        "type": "user_message",
                        "content": content,
                        "message_id": message_id,
                        "timestamp": timestamp,
                        "sender": {
                            "name": sender_name,
                            "user_id": sender_user_id,
                            "is_bot": False,
                        },
                        "virtual_mode": current_virtual_config.enabled if current_virtual_config else False,
                    }
                )
                # Build the MaiBot-format message payload.
                message_data = create_message_data(
                    content=content,
                    user_id=user_id,
                    user_name=current_user_name,
                    message_id=message_id,
                    is_at_bot=True,
                    virtual_config=current_virtual_config,
                )
                try:
                    # Show a typing indicator while the bot processes.
                    await chat_manager.broadcast(
                        {
                            "type": "typing",
                            "is_typing": True,
                        }
                    )
                    # Hand the message to MaiBot's processing pipeline.
                    await chat_bot.message_process(message_data)
                except Exception as e:
                    logger.error(f"处理消息时出错: {e}")
                    await chat_manager.send_message(
                        session_id,
                        {
                            "type": "error",
                            "content": f"处理消息时出错: {str(e)}",
                            "timestamp": time.time(),
                        },
                    )
                finally:
                    await chat_manager.broadcast(
                        {
                            "type": "typing",
                            "is_typing": False,
                        }
                    )
            elif data.get("type") == "ping":
                await chat_manager.send_message(
                    session_id,
                    {
                        "type": "pong",
                        "timestamp": time.time(),
                    },
                )
            elif data.get("type") == "update_nickname":
                # Allow the user to change their display nickname.
                if new_name := data.get("user_name", "").strip():
                    current_user_name = new_name
                    await chat_manager.send_message(
                        session_id,
                        {
                            "type": "nickname_updated",
                            "user_name": current_user_name,
                            "timestamp": time.time(),
                        },
                    )
            elif data.get("type") == "set_virtual_identity":
                # Configure or update the virtual identity for this session.
                virtual_data = data.get("config", {})
                if virtual_data.get("enabled"):
                    # The required fields must both be present.
                    if not virtual_data.get("platform") or not virtual_data.get("person_id"):
                        await chat_manager.send_message(
                            session_id,
                            {
                                "type": "error",
                                "content": "虚拟身份配置缺少必要字段: platform 和 person_id",
                                "timestamp": time.time(),
                            },
                        )
                        continue
                    # Look up the person record.
                    try:
                        person = PersonInfo.get_or_none(PersonInfo.person_id == virtual_data.get("person_id"))
                        if not person:
                            await chat_manager.send_message(
                                session_id,
                                {
                                    "type": "error",
                                    "content": f"找不到用户: {virtual_data.get('person_id')}",
                                    "timestamp": time.time(),
                                },
                            )
                            continue
                        # Derive the virtual group ID.
                        # NOTE(review): this rebinds the `group_id` query
                        # parameter of the endpoint — confirm the shadowing
                        # is intended.
                        custom_group_id = virtual_data.get("group_id")
                        if custom_group_id:
                            group_id = f"{VIRTUAL_GROUP_ID_PREFIX}{custom_group_id}"
                        else:
                            group_id = f"{VIRTUAL_GROUP_ID_PREFIX}{session_id[:8]}"
                        current_virtual_config = VirtualIdentityConfig(
                            enabled=True,
                            platform=person.platform,
                            person_id=person.person_id,
                            user_id=person.user_id,
                            user_nickname=person.person_name or person.nickname or person.user_id,
                            group_id=group_id,
                            group_name=virtual_data.get("group_name", "WebUI虚拟群聊"),
                        )
                        # Confirm the activated virtual identity to the client.
                        await chat_manager.send_message(
                            session_id,
                            {
                                "type": "virtual_identity_set",
                                "config": {
                                    "enabled": True,
                                    "platform": current_virtual_config.platform,
                                    "user_id": current_virtual_config.user_id,
                                    "user_nickname": current_virtual_config.user_nickname,
                                    "group_id": current_virtual_config.group_id,
                                    "group_name": current_virtual_config.group_name,
                                },
                                "timestamp": time.time(),
                            },
                        )
                        # Load the virtual group's history.
                        virtual_history = chat_history.get_history(50, current_virtual_config.group_id)
                        await chat_manager.send_message(
                            session_id,
                            {
                                "type": "history",
                                "messages": virtual_history,
                                "group_id": current_virtual_config.group_id,
                            },
                        )
                        # Send a system notice about the mode switch.
                        await chat_manager.send_message(
                            session_id,
                            {
                                "type": "system",
                                "content": f"已切换到虚拟身份模式:以 {current_virtual_config.user_nickname} 的身份在「{current_virtual_config.group_name}」与 {global_config.bot.nickname} 对话",
                                "timestamp": time.time(),
                            },
                        )
                    except Exception as e:
                        logger.error(f"设置虚拟身份失败: {e}")
                        await chat_manager.send_message(
                            session_id,
                            {
                                "type": "error",
                                "content": f"设置虚拟身份失败: {str(e)}",
                                "timestamp": time.time(),
                            },
                        )
                else:
                    # Disable virtual-identity mode.
                    current_virtual_config = None
                    await chat_manager.send_message(
                        session_id,
                        {
                            "type": "virtual_identity_set",
                            "config": {"enabled": False},
                            "timestamp": time.time(),
                        },
                    )
                    # Reload the default chat room's history.
                    default_history = chat_history.get_history(50, WEBUI_CHAT_GROUP_ID)
                    await chat_manager.send_message(
                        session_id,
                        {
                            "type": "history",
                            "messages": default_history,
                            "group_id": WEBUI_CHAT_GROUP_ID,
                        },
                    )
                    await chat_manager.send_message(
                        session_id,
                        {
                            "type": "system",
                            "content": "已切换回 WebUI 独立用户模式",
                            "timestamp": time.time(),
                        },
                    )
    except WebSocketDisconnect:
        logger.info(f"WebSocket 断开: session={session_id}, user={user_id}")
    except Exception as e:
        logger.error(f"WebSocket 错误: {e}")
    finally:
        chat_manager.disconnect(session_id, user_id)
@router.get("/info")
async def get_chat_info(_auth: bool = Depends(require_auth)):
    """Return chat-room metadata and the number of live connections."""
    info = {
        "bot_name": global_config.bot.nickname,
        "platform": WEBUI_CHAT_PLATFORM,
        "group_id": WEBUI_CHAT_GROUP_ID,
        "active_sessions": len(chat_manager.active_connections),
    }
    return info
def get_webui_chat_broadcaster() -> tuple:
    """Expose the WebUI chat broadcaster for external modules.

    Returns:
        A ``(chat_manager, WEBUI_CHAT_PLATFORM)`` tuple.
    """
    return (chat_manager, WEBUI_CHAT_PLATFORM)

597
src/webui/routers/config.py Normal file
View File

@@ -0,0 +1,597 @@
"""
配置管理API路由
"""
import os
import tomlkit
from fastapi import APIRouter, HTTPException, Body, Depends, Cookie, Header
from typing import Any, Annotated, Optional
from src.common.logger import get_logger
from src.webui.core import verify_auth_token_from_cookie_or_header
from src.common.toml_utils import save_toml_with_format, _update_toml_doc
from src.config.config import Config, APIAdapterConfig, CONFIG_DIR, PROJECT_ROOT
from src.config.official_configs import (
BotConfig,
PersonalityConfig,
RelationshipConfig,
ChatConfig,
MessageReceiveConfig,
EmojiConfig,
ExpressionConfig,
KeywordReactionConfig,
ChineseTypoConfig,
ResponsePostProcessConfig,
ResponseSplitterConfig,
TelemetryConfig,
ExperimentalConfig,
MaimMessageConfig,
LPMMKnowledgeConfig,
ToolConfig,
MemoryConfig,
DebugConfig,
VoiceConfig,
)
from src.config.api_ada_configs import (
ModelTaskConfig,
ModelInfo,
APIProvider,
)
from src.webui.config_schema import ConfigSchemaGenerator
logger = get_logger("webui")
# 模块级别的类型别名(解决 B008 ruff 错误)
ConfigBody = Annotated[dict[str, Any], Body()]
SectionBody = Annotated[Any, Body()]
RawContentBody = Annotated[str, Body(embed=True)]
PathBody = Annotated[dict[str, str], Body()]
router = APIRouter(prefix="/config", tags=["config"])
def require_auth(
    maibot_session: Optional[str] = Cookie(None),
    authorization: Optional[str] = Header(None),
) -> bool:
    """FastAPI dependency: verify the caller is logged in via Cookie or header."""
    return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
# ===== Schema endpoints =====
@router.get("/schema/bot")
async def get_bot_config_schema(_auth: bool = Depends(require_auth)):
    """Return the configuration schema of the main bot config."""
    try:
        # The Config class aggregates every sub-config section.
        return {"success": True, "schema": ConfigSchemaGenerator.generate_config_schema(Config)}
    except Exception as e:
        logger.error(f"获取配置架构失败: {e}")
        raise HTTPException(status_code=500, detail=f"获取配置架构失败: {str(e)}") from e
@router.get("/schema/model")
async def get_model_config_schema(_auth: bool = Depends(require_auth)):
    """Return the model configuration schema (providers + model tasks)."""
    try:
        schema = ConfigSchemaGenerator.generate_config_schema(APIAdapterConfig)
    except Exception as e:
        logger.error(f"获取模型配置架构失败: {e}")
        raise HTTPException(status_code=500, detail=f"获取模型配置架构失败: {str(e)}") from e
    return {"success": True, "schema": schema}
# ===== Sub-config schema endpoints =====
@router.get("/schema/section/{section_name}")
async def get_config_section_schema(section_name: str, _auth: bool = Depends(require_auth)):
    """
    Return the schema of a single configuration section.

    Supported section_name values (the keys of the map below):
    bot, personality, relationship, chat, message_receive, emoji, expression,
    keyword_reaction, chinese_typo, response_post_process, response_splitter,
    telemetry, experimental, maim_message, lpmm_knowledge, tool, memory,
    debug, voice, model_task_config, api_provider, model_info.

    (The previous doc also listed ``jargon``, but no such entry exists in the
    map and JargonConfig is not imported here.)
    """
    section_map = {
        "bot": BotConfig,
        "personality": PersonalityConfig,
        "relationship": RelationshipConfig,
        "chat": ChatConfig,
        "message_receive": MessageReceiveConfig,
        "emoji": EmojiConfig,
        "expression": ExpressionConfig,
        "keyword_reaction": KeywordReactionConfig,
        "chinese_typo": ChineseTypoConfig,
        "response_post_process": ResponsePostProcessConfig,
        "response_splitter": ResponseSplitterConfig,
        "telemetry": TelemetryConfig,
        "experimental": ExperimentalConfig,
        "maim_message": MaimMessageConfig,
        "lpmm_knowledge": LPMMKnowledgeConfig,
        "tool": ToolConfig,
        "memory": MemoryConfig,
        "debug": DebugConfig,
        "voice": VoiceConfig,
        "model_task_config": ModelTaskConfig,
        "api_provider": APIProvider,
        "model_info": ModelInfo,
    }
    if section_name not in section_map:
        raise HTTPException(status_code=404, detail=f"配置节 '{section_name}' 不存在")
    try:
        config_class = section_map[section_name]
        schema = ConfigSchemaGenerator.generate_schema(config_class, include_nested=False)
        return {"success": True, "schema": schema}
    except Exception as e:
        logger.error(f"获取配置节架构失败: {e}")
        raise HTTPException(status_code=500, detail=f"获取配置节架构失败: {str(e)}") from e
# ===== Config read endpoints =====
@router.get("/bot")
async def get_bot_config(_auth: bool = Depends(require_auth)):
    """Read and return the main bot configuration file as parsed TOML."""
    try:
        config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
        if not os.path.exists(config_path):
            raise HTTPException(status_code=404, detail="配置文件不存在")
        with open(config_path, "r", encoding="utf-8") as f:
            parsed = tomlkit.load(f)
        return {"success": True, "config": parsed}
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"读取配置文件失败: {e}")
        raise HTTPException(status_code=500, detail=f"读取配置文件失败: {str(e)}") from e
@router.get("/model")
async def get_model_config(_auth: bool = Depends(require_auth)):
    """Read and return the model configuration file (providers + tasks)."""
    try:
        config_path = os.path.join(CONFIG_DIR, "model_config.toml")
        if not os.path.exists(config_path):
            raise HTTPException(status_code=404, detail="配置文件不存在")
        with open(config_path, "r", encoding="utf-8") as f:
            parsed = tomlkit.load(f)
        return {"success": True, "config": parsed}
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"读取配置文件失败: {e}")
        raise HTTPException(status_code=500, detail=f"读取配置文件失败: {str(e)}") from e
# ===== Config update endpoints =====
@router.post("/bot")
async def update_bot_config(config_data: ConfigBody, _auth: bool = Depends(require_auth)):
    """Validate and persist the full main bot configuration."""
    try:
        # Reject payloads that do not form a valid Config.
        try:
            Config.from_dict(config_data)
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
        # Write back while preserving comments and formatting.
        target = os.path.join(CONFIG_DIR, "bot_config.toml")
        save_toml_with_format(config_data, target)
        logger.info("麦麦主程序配置已更新")
        return {"success": True, "message": "配置已保存"}
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"保存配置文件失败: {e}")
        raise HTTPException(status_code=500, detail=f"保存配置文件失败: {str(e)}") from e
@router.post("/model")
async def update_model_config(config_data: ConfigBody, _auth: bool = Depends(require_auth)):
    """Validate and persist the full model configuration."""
    try:
        # Reject payloads that do not form a valid APIAdapterConfig.
        try:
            APIAdapterConfig.from_dict(config_data)
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
        # Write back while preserving comments and formatting.
        target = os.path.join(CONFIG_DIR, "model_config.toml")
        save_toml_with_format(config_data, target)
        logger.info("模型配置已更新")
        return {"success": True, "message": "配置已保存"}
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"保存配置文件失败: {e}")
        raise HTTPException(status_code=500, detail=f"保存配置文件失败: {str(e)}") from e
# ===== Section update endpoints =====
@router.post("/bot/section/{section_name}")
async def update_bot_config_section(section_name: str, section_data: SectionBody, _auth: bool = Depends(require_auth)):
    """Update one section of the main bot config, preserving comments/format.

    Dict sections are merged recursively so tomlkit comments survive; list
    sections (e.g. platforms, aliases) and scalars are replaced wholesale.
    The merged document is re-validated as a whole before being written.
    """
    try:
        # Load the current on-disk document.
        config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
        if not os.path.exists(config_path):
            raise HTTPException(status_code=404, detail="配置文件不存在")
        with open(config_path, "r", encoding="utf-8") as f:
            config_data = tomlkit.load(f)
        # The section must already exist in the document.
        if section_name not in config_data:
            raise HTTPException(status_code=404, detail=f"配置节 '{section_name}' 不存在")
        # Recursive merge preserves comments for dict sections;
        # arrays (e.g. platforms, aliases) are replaced directly.
        if isinstance(section_data, list):
            # Lists: replace as a whole.
            config_data[section_name] = section_data
        elif isinstance(section_data, dict) and isinstance(config_data[section_name], dict):
            # Dicts: merge recursively.
            _update_toml_doc(config_data[section_name], section_data)
        else:
            # Anything else: replace directly.
            config_data[section_name] = section_data
        # Validate the complete merged configuration.
        try:
            Config.from_dict(config_data)
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
        # Persist (arrays formatted multi-line, comments preserved).
        save_toml_with_format(config_data, config_path)
        logger.info(f"配置节 '{section_name}' 已更新(保留注释)")
        return {"success": True, "message": f"配置节 '{section_name}' 已保存"}
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"更新配置节失败: {e}")
        raise HTTPException(status_code=500, detail=f"更新配置节失败: {str(e)}") from e
# ===== Raw TOML file endpoints =====
@router.get("/bot/raw")
async def get_bot_config_raw(_auth: bool = Depends(require_auth)):
    """Return the main bot config file as raw TOML text."""
    try:
        config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
        if not os.path.exists(config_path):
            raise HTTPException(status_code=404, detail="配置文件不存在")
        with open(config_path, "r", encoding="utf-8") as f:
            content = f.read()
        return {"success": True, "content": content}
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"读取配置文件失败: {e}")
        raise HTTPException(status_code=500, detail=f"读取配置文件失败: {str(e)}") from e
@router.post("/bot/raw")
async def update_bot_config_raw(raw_content: RawContentBody, _auth: bool = Depends(require_auth)):
    """Overwrite the bot config with raw TOML text, validating it first."""
    try:
        # Parse first: reject syntactically invalid TOML.
        try:
            parsed = tomlkit.loads(raw_content)
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"TOML 格式错误: {str(e)}") from e
        # Then validate the resulting structure against the Config model.
        try:
            Config.from_dict(parsed)
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
        # Write the raw text through unchanged.
        target = os.path.join(CONFIG_DIR, "bot_config.toml")
        with open(target, "w", encoding="utf-8") as f:
            f.write(raw_content)
        logger.info("麦麦主程序配置已更新(原始模式)")
        return {"success": True, "message": "配置已保存"}
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"保存配置文件失败: {e}")
        raise HTTPException(status_code=500, detail=f"保存配置文件失败: {str(e)}") from e
@router.post("/model/section/{section_name}")
async def update_model_config_section(
    section_name: str, section_data: SectionBody, _auth: bool = Depends(require_auth)
):
    """Update one section of the model config, preserving comments/format.

    Dict sections are merged recursively; array-of-table sections such as
    [[models]] and [[api_providers]] are replaced wholesale. The merged
    document is validated before writing; removing a provider that models
    still reference produces a targeted error message.
    """
    try:
        # Load the current on-disk document.
        config_path = os.path.join(CONFIG_DIR, "model_config.toml")
        if not os.path.exists(config_path):
            raise HTTPException(status_code=404, detail="配置文件不存在")
        with open(config_path, "r", encoding="utf-8") as f:
            config_data = tomlkit.load(f)
        # The section must already exist in the document.
        if section_name not in config_data:
            raise HTTPException(status_code=404, detail=f"配置节 '{section_name}' 不存在")
        # Recursive merge preserves comments for dict sections;
        # array tables ([[models]], [[api_providers]]) are replaced directly.
        if isinstance(section_data, list):
            # Lists: replace as a whole.
            config_data[section_name] = section_data
        elif isinstance(section_data, dict) and isinstance(config_data[section_name], dict):
            # Dicts: merge recursively.
            _update_toml_doc(config_data[section_name], section_data)
        else:
            # Anything else: replace directly.
            config_data[section_name] = section_data
        # Validate the complete merged configuration.
        try:
            APIAdapterConfig.from_dict(config_data)
        except Exception as e:
            logger.error(f"配置数据验证失败,详细错误: {str(e)}")
            # Special case: when updating api_providers, detect models that
            # still reference a provider that was just removed.
            # NOTE(review): matching "api_provider" against str(e) is a
            # heuristic on the validator's message text — confirm it stays
            # in sync with the validator.
            if section_name == "api_providers" and "api_provider" in str(e):
                provider_names = {p.get("name") for p in section_data if isinstance(p, dict)}
                models = config_data.get("models", [])
                orphaned_models = [
                    m.get("name") for m in models if isinstance(m, dict) and m.get("api_provider") not in provider_names
                ]
                if orphaned_models:
                    error_msg = f"以下模型引用了已删除的提供商: {', '.join(orphaned_models)}。请先在模型管理页面删除这些模型,或重新分配它们的提供商。"
                    raise HTTPException(status_code=400, detail=error_msg) from e
            raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
        # Persist (arrays formatted multi-line, comments preserved).
        save_toml_with_format(config_data, config_path)
        logger.info(f"配置节 '{section_name}' 已更新(保留注释)")
        return {"success": True, "message": f"配置节 '{section_name}' 已保存"}
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"更新配置节失败: {e}")
        raise HTTPException(status_code=500, detail=f"更新配置节失败: {str(e)}") from e
# ===== 适配器配置管理接口 =====
def _normalize_adapter_path(path: str) -> str:
"""将路径转换为绝对路径(如果是相对路径,则相对于项目根目录)"""
if not path:
return path
# 如果已经是绝对路径,直接返回
if os.path.isabs(path):
return path
# 相对路径,转换为相对于项目根目录的绝对路径
return os.path.normpath(os.path.join(PROJECT_ROOT, path))
def _to_relative_path(path: str) -> str:
"""尝试将绝对路径转换为相对于项目根目录的相对路径,如果无法转换则返回原路径"""
if not path or not os.path.isabs(path):
return path
try:
# 尝试获取相对路径
rel_path = os.path.relpath(path, PROJECT_ROOT)
# 如果相对路径不是以 .. 开头(说明文件在项目目录内),则返回相对路径
if not rel_path.startswith(".."):
return rel_path
except (ValueError, TypeError):
# 在 Windows 上如果路径在不同驱动器relpath 会抛出 ValueError
pass
# 无法转换为相对路径,返回绝对路径
return path
@router.get("/adapter-config/path")
async def get_adapter_config_path(_auth: bool = Depends(require_auth)):
    """Return the stored adapter-config path preference, if any.

    Reads data/webui.json; when the referenced file exists, the response also
    carries its last-modified timestamp and a project-relative display path.
    """
    try:
        # The path preference lives in data/webui.json.
        webui_data_path = os.path.join("data", "webui.json")
        if not os.path.exists(webui_data_path):
            return {"success": True, "path": None}
        import json
        with open(webui_data_path, "r", encoding="utf-8") as f:
            webui_data = json.load(f)
        adapter_config_path = webui_data.get("adapter_config_path")
        if not adapter_config_path:
            return {"success": True, "path": None}
        # Normalize to an absolute path.
        abs_path = _normalize_adapter_path(adapter_config_path)
        # When the file exists, include its last-modified time.
        if os.path.exists(abs_path):
            import datetime
            mtime = os.path.getmtime(abs_path)
            last_modified = datetime.datetime.fromtimestamp(mtime).isoformat()
            # Prefer a project-relative path for display when possible.
            display_path = _to_relative_path(abs_path)
            return {"success": True, "path": display_path, "lastModified": last_modified}
        else:
            # File missing: echo back the stored path unchanged.
            return {"success": True, "path": adapter_config_path, "lastModified": None}
    except Exception as e:
        logger.error(f"获取适配器配置路径失败: {e}")
        raise HTTPException(status_code=500, detail=f"获取配置路径失败: {str(e)}") from e
@router.post("/adapter-config/path")
async def save_adapter_config_path(data: PathBody, _auth: bool = Depends(require_auth)):
    """Persist the preferred adapter-config file path into data/webui.json.

    Raises:
        HTTPException: 400 when the path is empty, 500 on I/O failures.
    """
    try:
        path = data.get("path")
        if not path:
            raise HTTPException(status_code=400, detail="路径不能为空")
        # Preferences are stored in data/webui.json.
        webui_data_path = os.path.join("data", "webui.json")
        import json

        # Load existing preferences (if any) so other keys are preserved.
        if os.path.exists(webui_data_path):
            with open(webui_data_path, "r", encoding="utf-8") as f:
                webui_data = json.load(f)
        else:
            webui_data = {}
        # Normalize to an absolute path, then store a project-relative path
        # when possible (keeps the preference portable across checkouts).
        abs_path = _normalize_adapter_path(path)
        save_path = _to_relative_path(abs_path)
        webui_data["adapter_config_path"] = save_path
        os.makedirs("data", exist_ok=True)
        with open(webui_data_path, "w", encoding="utf-8") as f:
            json.dump(webui_data, f, ensure_ascii=False, indent=2)
        # Fix: the closing full-width parenthesis was missing from this log line.
        logger.info(f"适配器配置路径已保存: {save_path}(绝对路径: {abs_path})")
        return {"success": True, "message": "路径已保存"}
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"保存适配器配置路径失败: {e}")
        raise HTTPException(status_code=500, detail=f"保存路径失败: {str(e)}") from e
@router.get("/adapter-config")
async def get_adapter_config(path: str, _auth: bool = Depends(require_auth)):
    """Read an adapter TOML config file from the given path."""
    try:
        if not path:
            raise HTTPException(status_code=400, detail="路径参数不能为空")
        # Resolve to an absolute path first.
        abs_path = _normalize_adapter_path(path)
        if not os.path.exists(abs_path):
            raise HTTPException(status_code=404, detail=f"配置文件不存在: {path}")
        # Only .toml files are accepted.
        if not abs_path.endswith(".toml"):
            raise HTTPException(status_code=400, detail="只支持 .toml 格式的配置文件")
        with open(abs_path, "r", encoding="utf-8") as f:
            content = f.read()
        logger.info(f"已读取适配器配置: {path} (绝对路径: {abs_path})")
        return {"success": True, "content": content}
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"读取适配器配置失败: {e}")
        raise HTTPException(status_code=500, detail=f"读取配置失败: {str(e)}") from e
@router.post("/adapter-config")
async def save_adapter_config(data: PathBody, _auth: bool = Depends(require_auth)):
    """Write adapter TOML content to the given path after validating it."""
    try:
        path = data.get("path")
        content = data.get("content")
        if not path:
            raise HTTPException(status_code=400, detail="路径不能为空")
        if content is None:
            raise HTTPException(status_code=400, detail="配置内容不能为空")
        abs_path = _normalize_adapter_path(path)
        # Only .toml files are accepted.
        if not abs_path.endswith(".toml"):
            raise HTTPException(status_code=400, detail="只支持 .toml 格式的配置文件")
        # Reject syntactically invalid TOML before touching the filesystem.
        try:
            tomlkit.loads(content)
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"TOML 格式错误: {str(e)}") from e
        # Create the parent directory when needed, then write.
        parent = os.path.dirname(abs_path)
        if parent:
            os.makedirs(parent, exist_ok=True)
        with open(abs_path, "w", encoding="utf-8") as f:
            f.write(content)
        logger.info(f"适配器配置已保存: {path} (绝对路径: {abs_path})")
        return {"success": True, "message": "配置已保存"}
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"保存适配器配置失败: {e}")
        raise HTTPException(status_code=500, detail=f"保存配置失败: {str(e)}") from e

1310
src/webui/routers/emoji.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,773 @@
"""表达方式管理 API 路由"""
from fastapi import APIRouter, HTTPException, Header, Query, Cookie
from pydantic import BaseModel
from typing import Optional, List, Dict
from src.common.logger import get_logger
from src.common.database.database_model import Expression, ChatStreams
from src.webui.core import verify_auth_token_from_cookie_or_header
import time
logger = get_logger("webui.expression")
# 创建路由器
router = APIRouter(prefix="/expression", tags=["Expression"])
class ExpressionResponse(BaseModel):
    """Single expression record as exposed by the API."""

    id: int
    situation: str  # the chat situation this expression applies to
    style: str  # the wording/style used in that situation
    last_active_time: float  # unix timestamp of last use/update
    chat_id: str
    create_date: Optional[float]  # unix timestamp; may be None on legacy rows
    checked: bool  # True once reviewed (by AI or a person)
    rejected: bool  # True if the review rejected it
    modified_by: Optional[str] = None  # 'ai', 'user', or None


class ExpressionListResponse(BaseModel):
    """Paginated expression list."""

    success: bool
    total: int
    page: int
    page_size: int
    data: List[ExpressionResponse]


class ExpressionDetailResponse(BaseModel):
    """Single-expression detail wrapper."""

    success: bool
    data: ExpressionResponse


class ExpressionCreateRequest(BaseModel):
    """Create-request body."""

    situation: str
    style: str
    chat_id: str


class ExpressionUpdateRequest(BaseModel):
    """Partial-update body; only explicitly-set fields are applied."""

    situation: Optional[str] = None
    style: Optional[str] = None
    chat_id: Optional[str] = None
    checked: Optional[bool] = None
    rejected: Optional[bool] = None
    require_unchecked: Optional[bool] = False  # conflict guard for manual review; not a DB column


class ExpressionUpdateResponse(BaseModel):
    """Update result."""

    success: bool
    message: str
    data: Optional[ExpressionResponse] = None


class ExpressionDeleteResponse(BaseModel):
    """Delete result."""

    success: bool
    message: str


class ExpressionCreateResponse(BaseModel):
    """Create result."""

    success: bool
    message: str
    data: ExpressionResponse
def verify_auth_token(
    maibot_session: Optional[str] = None,
    authorization: Optional[str] = None,
) -> bool:
    """Validate the auth token from the session cookie or the Authorization header.

    Thin delegation to the shared webui helper.
    NOTE(review): callers in this module ignore the return value — presumably the
    helper raises on auth failure; confirm.
    """
    return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
def expression_to_response(expression: Expression) -> ExpressionResponse:
    """Map an Expression ORM row onto the API response model."""
    payload = {
        "id": expression.id,
        "situation": expression.situation,
        "style": expression.style,
        "last_active_time": expression.last_active_time,
        "chat_id": expression.chat_id,
        "create_date": expression.create_date,
        "checked": expression.checked,
        "rejected": expression.rejected,
        "modified_by": expression.modified_by,
    }
    return ExpressionResponse(**payload)
def get_chat_name(chat_id: str) -> str:
    """Resolve a chat_id to a display name (group name > user nickname > raw id)."""
    try:
        stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id)
        if stream is None:
            return chat_id  # unknown stream: fall back to the raw id
        return stream.group_name or stream.user_nickname or chat_id
    except Exception:
        # Best-effort lookup: any DB error degrades to the raw id.
        return chat_id
def get_chat_names_batch(chat_ids: List[str]) -> Dict[str, str]:
    """Resolve many chat_ids to display names in a single query.

    Ids without a stream record (or without any name) map to themselves.
    """
    names = {cid: cid for cid in chat_ids}  # default every id to itself
    try:
        for stream in ChatStreams.select().where(ChatStreams.stream_id.in_(chat_ids)):
            display = stream.group_name or stream.user_nickname
            if display:
                names[stream.stream_id] = display
    except Exception as e:
        logger.warning(f"批量获取聊天名称失败: {e}")
    return names
class ChatInfo(BaseModel):
    """Chat entry for dropdown selection."""

    chat_id: str
    chat_name: str  # resolved display name
    platform: Optional[str] = None
    is_group: bool = False


class ChatListResponse(BaseModel):
    """Chat list wrapper."""

    success: bool
    data: List[ChatInfo]
@router.get("/chats", response_model=ChatListResponse)
async def get_chat_list(maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
"""
获取所有聊天列表(用于下拉选择)
Args:
authorization: Authorization header
Returns:
聊天列表
"""
try:
verify_auth_token(maibot_session, authorization)
chat_list = []
for cs in ChatStreams.select():
chat_name = cs.group_name if cs.group_name else (cs.user_nickname if cs.user_nickname else cs.stream_id)
chat_list.append(
ChatInfo(
chat_id=cs.stream_id,
chat_name=chat_name,
platform=cs.platform,
is_group=bool(cs.group_id),
)
)
# 按名称排序
chat_list.sort(key=lambda x: x.chat_name)
return ChatListResponse(success=True, data=chat_list)
except HTTPException:
raise
except Exception as e:
logger.exception(f"获取聊天列表失败: {e}")
raise HTTPException(status_code=500, detail=f"获取聊天列表失败: {str(e)}") from e
@router.get("/list", response_model=ExpressionListResponse)
async def get_expression_list(
page: int = Query(1, ge=1, description="页码"),
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
search: Optional[str] = Query(None, description="搜索关键词"),
chat_id: Optional[str] = Query(None, description="聊天ID筛选"),
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
"""
获取表达方式列表
Args:
page: 页码 (从 1 开始)
page_size: 每页数量 (1-100)
search: 搜索关键词 (匹配 situation, style)
chat_id: 聊天ID筛选
authorization: Authorization header
Returns:
表达方式列表
"""
try:
verify_auth_token(maibot_session, authorization)
# 构建查询
query = Expression.select()
# 搜索过滤
if search:
query = query.where((Expression.situation.contains(search)) | (Expression.style.contains(search)))
# 聊天ID过滤
if chat_id:
query = query.where(Expression.chat_id == chat_id)
# 排序最后活跃时间倒序NULL 值放在最后)
from peewee import Case
query = query.order_by(
Case(None, [(Expression.last_active_time.is_null(), 1)], 0), Expression.last_active_time.desc()
)
# 获取总数
total = query.count()
# 分页
offset = (page - 1) * page_size
expressions = query.offset(offset).limit(page_size)
# 转换为响应对象
data = [expression_to_response(expr) for expr in expressions]
return ExpressionListResponse(success=True, total=total, page=page, page_size=page_size, data=data)
except HTTPException:
raise
except Exception as e:
logger.exception(f"获取表达方式列表失败: {e}")
raise HTTPException(status_code=500, detail=f"获取表达方式列表失败: {str(e)}") from e
@router.get("/{expression_id}", response_model=ExpressionDetailResponse)
async def get_expression_detail(
expression_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
):
"""
获取表达方式详细信息
Args:
expression_id: 表达方式ID
authorization: Authorization header
Returns:
表达方式详细信息
"""
try:
verify_auth_token(maibot_session, authorization)
expression = Expression.get_or_none(Expression.id == expression_id)
if not expression:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式")
return ExpressionDetailResponse(success=True, data=expression_to_response(expression))
except HTTPException:
raise
except Exception as e:
logger.exception(f"获取表达方式详情失败: {e}")
raise HTTPException(status_code=500, detail=f"获取表达方式详情失败: {str(e)}") from e
@router.post("/", response_model=ExpressionCreateResponse)
async def create_expression(
request: ExpressionCreateRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
"""
创建新的表达方式
Args:
request: 创建请求
authorization: Authorization header
Returns:
创建结果
"""
try:
verify_auth_token(maibot_session, authorization)
current_time = time.time()
# 创建表达方式
expression = Expression.create(
situation=request.situation,
style=request.style,
chat_id=request.chat_id,
last_active_time=current_time,
create_date=current_time,
)
logger.info(f"表达方式已创建: ID={expression.id}, situation={request.situation}")
return ExpressionCreateResponse(
success=True, message="表达方式创建成功", data=expression_to_response(expression)
)
except HTTPException:
raise
except Exception as e:
logger.exception(f"创建表达方式失败: {e}")
raise HTTPException(status_code=500, detail=f"创建表达方式失败: {str(e)}") from e
@router.patch("/{expression_id}", response_model=ExpressionUpdateResponse)
async def update_expression(
expression_id: int,
request: ExpressionUpdateRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
"""
增量更新表达方式(只更新提供的字段)
Args:
expression_id: 表达方式ID
request: 更新请求(只包含需要更新的字段)
authorization: Authorization header
Returns:
更新结果
"""
try:
verify_auth_token(maibot_session, authorization)
expression = Expression.get_or_none(Expression.id == expression_id)
if not expression:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式")
# 冲突检测:如果要求未检查状态,但已经被检查了
if request.require_unchecked and expression.checked:
raise HTTPException(
status_code=409,
detail=f"此表达方式已被{'AI自动' if expression.modified_by == 'ai' else '人工'}检查,请刷新列表",
)
# 只更新提供的字段
update_data = request.model_dump(exclude_unset=True)
# 移除 require_unchecked它不是数据库字段
update_data.pop("require_unchecked", None)
if not update_data:
raise HTTPException(status_code=400, detail="未提供任何需要更新的字段")
# 如果更新了 checked 或 rejected标记为用户修改
if "checked" in update_data or "rejected" in update_data:
update_data["modified_by"] = "user"
# 更新最后活跃时间
update_data["last_active_time"] = time.time()
# 执行更新
for field, value in update_data.items():
setattr(expression, field, value)
expression.save()
logger.info(f"表达方式已更新: ID={expression_id}, 字段: {list(update_data.keys())}")
return ExpressionUpdateResponse(
success=True, message=f"成功更新 {len(update_data)} 个字段", data=expression_to_response(expression)
)
except HTTPException:
raise
except Exception as e:
logger.exception(f"更新表达方式失败: {e}")
raise HTTPException(status_code=500, detail=f"更新表达方式失败: {str(e)}") from e
@router.delete("/{expression_id}", response_model=ExpressionDeleteResponse)
async def delete_expression(
expression_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
):
"""
删除表达方式
Args:
expression_id: 表达方式ID
authorization: Authorization header
Returns:
删除结果
"""
try:
verify_auth_token(maibot_session, authorization)
expression = Expression.get_or_none(Expression.id == expression_id)
if not expression:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式")
# 记录删除信息
situation = expression.situation
# 执行删除
expression.delete_instance()
logger.info(f"表达方式已删除: ID={expression_id}, situation={situation}")
return ExpressionDeleteResponse(success=True, message=f"成功删除表达方式: {situation}")
except HTTPException:
raise
except Exception as e:
logger.exception(f"删除表达方式失败: {e}")
raise HTTPException(status_code=500, detail=f"删除表达方式失败: {str(e)}") from e
class BatchDeleteRequest(BaseModel):
    """Batch-delete request body."""

    ids: List[int]  # ids of the expressions to delete
@router.post("/batch/delete", response_model=ExpressionDeleteResponse)
async def batch_delete_expressions(
request: BatchDeleteRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
"""
批量删除表达方式
Args:
request: 包含要删除的ID列表的请求
authorization: Authorization header
Returns:
删除结果
"""
try:
verify_auth_token(maibot_session, authorization)
if not request.ids:
raise HTTPException(status_code=400, detail="未提供要删除的表达方式ID")
# 查找所有要删除的表达方式
expressions = Expression.select().where(Expression.id.in_(request.ids))
found_ids = [expr.id for expr in expressions]
# 检查是否有未找到的ID
not_found_ids = set(request.ids) - set(found_ids)
if not_found_ids:
logger.warning(f"部分表达方式未找到: {not_found_ids}")
# 执行批量删除
deleted_count = Expression.delete().where(Expression.id.in_(found_ids)).execute()
logger.info(f"批量删除了 {deleted_count} 个表达方式")
return ExpressionDeleteResponse(success=True, message=f"成功删除 {deleted_count} 个表达方式")
except HTTPException:
raise
except Exception as e:
logger.exception(f"批量删除表达方式失败: {e}")
raise HTTPException(status_code=500, detail=f"批量删除表达方式失败: {str(e)}") from e
@router.get("/stats/summary")
async def get_expression_stats(
maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
):
"""
获取表达方式统计数据
Args:
authorization: Authorization header
Returns:
统计数据
"""
try:
verify_auth_token(maibot_session, authorization)
total = Expression.select().count()
# 按 chat_id 统计
chat_stats = {}
for expr in Expression.select(Expression.chat_id):
chat_id = expr.chat_id
chat_stats[chat_id] = chat_stats.get(chat_id, 0) + 1
# 获取最近创建的记录数7天内
seven_days_ago = time.time() - (7 * 24 * 60 * 60)
recent = (
Expression.select()
.where((Expression.create_date.is_null(False)) & (Expression.create_date >= seven_days_ago))
.count()
)
return {
"success": True,
"data": {
"total": total,
"recent_7days": recent,
"chat_count": len(chat_stats),
"top_chats": dict(sorted(chat_stats.items(), key=lambda x: x[1], reverse=True)[:10]),
},
}
except HTTPException:
raise
except Exception as e:
logger.exception(f"获取统计数据失败: {e}")
raise HTTPException(status_code=500, detail=f"获取统计数据失败: {str(e)}") from e
# ============ 审核相关接口 ============
class ReviewStatsResponse(BaseModel):
    """Review-progress counters."""

    total: int
    unchecked: int  # not reviewed yet
    passed: int  # checked and not rejected
    rejected: int  # checked and rejected
    ai_checked: int  # last modified by AI
    user_checked: int  # last modified by a person
@router.get("/review/stats", response_model=ReviewStatsResponse)
async def get_review_stats(maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
"""
获取审核统计数据
Returns:
审核统计数据
"""
try:
verify_auth_token(maibot_session, authorization)
total = Expression.select().count()
unchecked = Expression.select().where(Expression.checked == False).count()
passed = Expression.select().where((Expression.checked == True) & (Expression.rejected == False)).count()
rejected = Expression.select().where((Expression.checked == True) & (Expression.rejected == True)).count()
ai_checked = Expression.select().where(Expression.modified_by == "ai").count()
user_checked = Expression.select().where(Expression.modified_by == "user").count()
return ReviewStatsResponse(
total=total,
unchecked=unchecked,
passed=passed,
rejected=rejected,
ai_checked=ai_checked,
user_checked=user_checked,
)
except HTTPException:
raise
except Exception as e:
logger.exception(f"获取审核统计失败: {e}")
raise HTTPException(status_code=500, detail=f"获取审核统计失败: {str(e)}") from e
class ReviewListResponse(BaseModel):
    """Paginated review list."""

    success: bool
    total: int
    page: int
    page_size: int
    data: List[ExpressionResponse]
@router.get("/review/list", response_model=ReviewListResponse)
async def get_review_list(
page: int = Query(1, ge=1, description="页码"),
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
filter_type: str = Query("unchecked", description="筛选类型: unchecked/passed/rejected/all"),
search: Optional[str] = Query(None, description="搜索关键词"),
chat_id: Optional[str] = Query(None, description="聊天ID筛选"),
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
"""
获取待审核/已审核的表达方式列表
Args:
page: 页码
page_size: 每页数量
filter_type: 筛选类型 (unchecked/passed/rejected/all)
search: 搜索关键词
chat_id: 聊天ID筛选
Returns:
表达方式列表
"""
try:
verify_auth_token(maibot_session, authorization)
query = Expression.select()
# 根据筛选类型过滤
if filter_type == "unchecked":
query = query.where(Expression.checked == False)
elif filter_type == "passed":
query = query.where((Expression.checked == True) & (Expression.rejected == False))
elif filter_type == "rejected":
query = query.where((Expression.checked == True) & (Expression.rejected == True))
# all 不需要额外过滤
# 搜索过滤
if search:
query = query.where((Expression.situation.contains(search)) | (Expression.style.contains(search)))
# 聊天ID过滤
if chat_id:
query = query.where(Expression.chat_id == chat_id)
# 排序:创建时间倒序
from peewee import Case
query = query.order_by(Case(None, [(Expression.create_date.is_null(), 1)], 0), Expression.create_date.desc())
total = query.count()
offset = (page - 1) * page_size
expressions = query.offset(offset).limit(page_size)
return ReviewListResponse(
success=True,
total=total,
page=page,
page_size=page_size,
data=[expression_to_response(expr) for expr in expressions],
)
except HTTPException:
raise
except Exception as e:
logger.exception(f"获取审核列表失败: {e}")
raise HTTPException(status_code=500, detail=f"获取审核列表失败: {str(e)}") from e
class BatchReviewItem(BaseModel):
    """One review decision in a batch request."""

    id: int
    rejected: bool  # True = reject, False = approve
    require_unchecked: bool = True  # skip this item if the row was already reviewed


class BatchReviewRequest(BaseModel):
    """Batch review request body."""

    items: List[BatchReviewItem]


class BatchReviewResultItem(BaseModel):
    """Per-item outcome of a batch review."""

    id: int
    success: bool
    message: str


class BatchReviewResponse(BaseModel):
    """Batch review summary."""

    success: bool
    total: int
    succeeded: int
    failed: int
    results: List[BatchReviewResultItem]
@router.post("/review/batch", response_model=BatchReviewResponse)
async def batch_review_expressions(
request: BatchReviewRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
"""
批量审核表达方式
Args:
request: 批量审核请求
Returns:
批量审核结果
"""
try:
verify_auth_token(maibot_session, authorization)
if not request.items:
raise HTTPException(status_code=400, detail="未提供要审核的表达方式")
results = []
succeeded = 0
failed = 0
for item in request.items:
try:
expression = Expression.get_or_none(Expression.id == item.id)
if not expression:
results.append(
BatchReviewResultItem(id=item.id, success=False, message=f"未找到 ID 为 {item.id} 的表达方式")
)
failed += 1
continue
# 冲突检测
if item.require_unchecked and expression.checked:
results.append(
BatchReviewResultItem(
id=item.id,
success=False,
message=f"已被{'AI自动' if expression.modified_by == 'ai' else '人工'}检查",
)
)
failed += 1
continue
# 更新状态
expression.checked = True
expression.rejected = item.rejected
expression.modified_by = "user"
expression.last_active_time = time.time()
expression.save()
results.append(
BatchReviewResultItem(id=item.id, success=True, message="通过" if not item.rejected else "拒绝")
)
succeeded += 1
except Exception as e:
results.append(BatchReviewResultItem(id=item.id, success=False, message=str(e)))
failed += 1
logger.info(f"批量审核完成: 成功 {succeeded}, 失败 {failed}")
return BatchReviewResponse(
success=True, total=len(request.items), succeeded=succeeded, failed=failed, results=results
)
except HTTPException:
raise
except Exception as e:
logger.exception(f"批量审核失败: {e}")
raise HTTPException(status_code=500, detail=f"批量审核失败: {str(e)}") from e

532
src/webui/routers/jargon.py Normal file
View File

@@ -0,0 +1,532 @@
"""黑话(俚语)管理路由"""
import json
from typing import Optional, List, Annotated
from fastapi import APIRouter, HTTPException, Query
from pydantic import BaseModel, Field
from peewee import fn
from src.common.logger import get_logger
from src.common.database.database_model import Jargon, ChatStreams
logger = get_logger("webui.jargon")
router = APIRouter(prefix="/jargon", tags=["Jargon"])
# ==================== 辅助函数 ====================
def parse_chat_id_to_stream_ids(chat_id_str: str) -> List[str]:
    """Extract all stream_ids from a stored chat_id value.

    The column holds either a JSON list of [stream_id, user_id] pairs or a
    bare stream_id string; anything unparseable is returned as-is.
    """
    if not chat_id_str:
        return []
    try:
        decoded = json.loads(chat_id_str)
    except (json.JSONDecodeError, TypeError):
        # Not valid JSON — treat the whole string as a single stream_id.
        return [chat_id_str]
    if not isinstance(decoded, list):
        # Some other JSON value: fall back to the raw string.
        return [chat_id_str]
    # Keep the first element of every non-empty inner list.
    return [str(entry[0]) for entry in decoded if isinstance(entry, list) and len(entry) >= 1]
def get_display_name_for_chat_id(chat_id_str: str) -> str:
    """Build a human-readable label for a stored chat_id value.

    Each parsed stream_id is resolved via ChatStreams; unknown streams are
    shown as a truncated id. Labels are joined with ", ".
    """
    stream_ids = parse_chat_id_to_stream_ids(chat_id_str)
    if not stream_ids:
        return chat_id_str
    labels = []
    for sid in stream_ids:
        record = ChatStreams.get_or_none(ChatStreams.stream_id == sid)
        if record and record.group_name:
            labels.append(record.group_name)
        else:
            # No group name on file: show a shortened stream_id instead.
            labels.append(sid[:8] + "..." if len(sid) > 8 else sid)
    return ", ".join(labels) if labels else chat_id_str
# ==================== 请求/响应模型 ====================
class JargonResponse(BaseModel):
    """Jargon entry as exposed by the API."""

    id: int
    content: str  # the jargon term
    raw_content: Optional[str] = None  # original text it was captured from
    meaning: Optional[str] = None  # inferred meaning, if any
    chat_id: str  # JSON [["stream_id", user_id], ...] or a bare stream_id
    stream_id: Optional[str] = None  # parsed stream_id, used by the frontend when editing
    chat_name: Optional[str] = None  # resolved chat display name for the frontend
    is_global: bool = False
    count: int = 0  # usage counter
    is_jargon: Optional[bool] = None  # None = not judged yet
    is_complete: bool = False  # meaning inference finished
    inference_with_context: Optional[str] = None
    inference_content_only: Optional[str] = None


class JargonListResponse(BaseModel):
    """Paginated jargon list."""

    success: bool = True
    total: int
    page: int
    page_size: int
    data: List[JargonResponse]


class JargonDetailResponse(BaseModel):
    """Single-entry wrapper."""

    success: bool = True
    data: JargonResponse


class JargonCreateRequest(BaseModel):
    """Create-request body."""

    content: str = Field(..., description="黑话内容")
    raw_content: Optional[str] = Field(None, description="原始内容")
    meaning: Optional[str] = Field(None, description="含义")
    chat_id: str = Field(..., description="聊天ID")
    is_global: bool = Field(False, description="是否全局")


class JargonUpdateRequest(BaseModel):
    """Partial-update body; only set fields are applied."""

    content: Optional[str] = None
    raw_content: Optional[str] = None
    meaning: Optional[str] = None
    chat_id: Optional[str] = None
    is_global: Optional[bool] = None
    is_jargon: Optional[bool] = None


class JargonCreateResponse(BaseModel):
    """Create result."""

    success: bool = True
    message: str
    data: JargonResponse


class JargonUpdateResponse(BaseModel):
    """Update result."""

    success: bool = True
    message: str
    data: Optional[JargonResponse] = None


class JargonDeleteResponse(BaseModel):
    """Delete result."""

    success: bool = True
    message: str
    deleted_count: int = 0


class BatchDeleteRequest(BaseModel):
    """Batch-delete request body."""

    ids: List[int] = Field(..., description="要删除的黑话ID列表")


class JargonStatsResponse(BaseModel):
    """Statistics wrapper."""

    success: bool = True
    data: dict


class ChatInfoResponse(BaseModel):
    """Chat entry for filtering dropdowns."""

    chat_id: str
    chat_name: str
    platform: Optional[str] = None
    is_group: bool = False


class ChatListResponse(BaseModel):
    """Chat list wrapper."""

    success: bool = True
    data: List[ChatInfoResponse]
# ==================== 工具函数 ====================
def jargon_to_dict(jargon: Jargon) -> dict:
    """Flatten a Jargon ORM row into the API response dict."""
    chat_name = None
    stream_id = None
    if jargon.chat_id:
        # Resolve a display name and the first parsed stream_id for the frontend.
        chat_name = get_display_name_for_chat_id(jargon.chat_id)
        parsed = parse_chat_id_to_stream_ids(jargon.chat_id)
        if parsed:
            stream_id = parsed[0]
    return {
        "id": jargon.id,
        "content": jargon.content,
        "raw_content": jargon.raw_content,
        "meaning": jargon.meaning,
        "chat_id": jargon.chat_id,
        "stream_id": stream_id,
        "chat_name": chat_name,
        "is_global": jargon.is_global,
        "count": jargon.count,
        "is_jargon": jargon.is_jargon,
        "is_complete": jargon.is_complete,
        "inference_with_context": jargon.inference_with_context,
        "inference_content_only": jargon.inference_content_only,
    }
# ==================== API 端点 ====================
@router.get("/list", response_model=JargonListResponse)
async def get_jargon_list(
page: int = Query(1, ge=1, description="页码"),
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
search: Optional[str] = Query(None, description="搜索关键词"),
chat_id: Optional[str] = Query(None, description="按聊天ID筛选"),
is_jargon: Optional[bool] = Query(None, description="按是否是黑话筛选"),
is_global: Optional[bool] = Query(None, description="按是否全局筛选"),
):
"""获取黑话列表"""
try:
# 构建查询
query = Jargon.select()
# 搜索过滤
if search:
query = query.where(
(Jargon.content.contains(search))
| (Jargon.meaning.contains(search))
| (Jargon.raw_content.contains(search))
)
# 按聊天ID筛选使用 contains 匹配,因为 chat_id 是 JSON 格式)
if chat_id:
# 从传入的 chat_id 中解析出 stream_id
stream_ids = parse_chat_id_to_stream_ids(chat_id)
if stream_ids:
# 使用第一个 stream_id 进行模糊匹配
query = query.where(Jargon.chat_id.contains(stream_ids[0]))
else:
# 如果无法解析,使用精确匹配
query = query.where(Jargon.chat_id == chat_id)
# 按是否是黑话筛选
if is_jargon is not None:
query = query.where(Jargon.is_jargon == is_jargon)
# 按是否全局筛选
if is_global is not None:
query = query.where(Jargon.is_global == is_global)
# 获取总数
total = query.count()
# 分页和排序(按使用次数降序)
query = query.order_by(Jargon.count.desc(), Jargon.id.desc())
query = query.paginate(page, page_size)
# 转换为响应格式
data = [jargon_to_dict(j) for j in query]
return JargonListResponse(
success=True,
total=total,
page=page,
page_size=page_size,
data=data,
)
except Exception as e:
logger.error(f"获取黑话列表失败: {e}")
raise HTTPException(status_code=500, detail=f"获取黑话列表失败: {str(e)}") from e
@router.get("/chats", response_model=ChatListResponse)
async def get_chat_list():
"""获取所有有黑话记录的聊天列表"""
try:
# 获取所有不同的 chat_id
chat_ids = Jargon.select(Jargon.chat_id).distinct().where(Jargon.chat_id.is_null(False))
chat_id_list = [j.chat_id for j in chat_ids if j.chat_id]
# 用于按 stream_id 去重
seen_stream_ids: set[str] = set()
for chat_id in chat_id_list:
stream_ids = parse_chat_id_to_stream_ids(chat_id)
if stream_ids:
seen_stream_ids.add(stream_ids[0])
result = []
for stream_id in seen_stream_ids:
# 尝试从 ChatStreams 表获取聊天名称
chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == stream_id)
if chat_stream:
result.append(
ChatInfoResponse(
chat_id=stream_id, # 使用 stream_id方便筛选匹配
chat_name=chat_stream.group_name or stream_id,
platform=chat_stream.platform,
is_group=True,
)
)
else:
result.append(
ChatInfoResponse(
chat_id=stream_id, # 使用 stream_id
chat_name=stream_id[:8] + "..." if len(stream_id) > 8 else stream_id,
platform=None,
is_group=False,
)
)
return ChatListResponse(success=True, data=result)
except Exception as e:
logger.error(f"获取聊天列表失败: {e}")
raise HTTPException(status_code=500, detail=f"获取聊天列表失败: {str(e)}") from e
@router.get("/stats/summary", response_model=JargonStatsResponse)
async def get_jargon_stats():
"""获取黑话统计数据"""
try:
# 总数量
total = Jargon.select().count()
# 已确认是黑话的数量
confirmed_jargon = Jargon.select().where(Jargon.is_jargon).count()
# 已确认不是黑话的数量
confirmed_not_jargon = Jargon.select().where(~Jargon.is_jargon).count()
# 未判定的数量
pending = Jargon.select().where(Jargon.is_jargon.is_null()).count()
# 全局黑话数量
global_count = Jargon.select().where(Jargon.is_global).count()
# 已完成推断的数量
complete_count = Jargon.select().where(Jargon.is_complete).count()
# 关联的聊天数量
chat_count = Jargon.select(Jargon.chat_id).distinct().where(Jargon.chat_id.is_null(False)).count()
# 按聊天统计 TOP 5
top_chats = (
Jargon.select(Jargon.chat_id, fn.COUNT(Jargon.id).alias("count"))
.group_by(Jargon.chat_id)
.order_by(fn.COUNT(Jargon.id).desc())
.limit(5)
)
top_chats_dict = {j.chat_id: j.count for j in top_chats if j.chat_id}
return JargonStatsResponse(
success=True,
data={
"total": total,
"confirmed_jargon": confirmed_jargon,
"confirmed_not_jargon": confirmed_not_jargon,
"pending": pending,
"global_count": global_count,
"complete_count": complete_count,
"chat_count": chat_count,
"top_chats": top_chats_dict,
},
)
except Exception as e:
logger.error(f"获取黑话统计失败: {e}")
raise HTTPException(status_code=500, detail=f"获取黑话统计失败: {str(e)}") from e
@router.get("/{jargon_id}", response_model=JargonDetailResponse)
async def get_jargon_detail(jargon_id: int):
"""获取黑话详情"""
try:
jargon = Jargon.get_or_none(Jargon.id == jargon_id)
if not jargon:
raise HTTPException(status_code=404, detail="黑话不存在")
return JargonDetailResponse(success=True, data=jargon_to_dict(jargon))
except HTTPException:
raise
except Exception as e:
logger.error(f"获取黑话详情失败: {e}")
raise HTTPException(status_code=500, detail=f"获取黑话详情失败: {str(e)}") from e
@router.post("/", response_model=JargonCreateResponse)
async def create_jargon(request: JargonCreateRequest):
"""创建黑话"""
try:
# 检查是否已存在相同内容的黑话
existing = Jargon.get_or_none((Jargon.content == request.content) & (Jargon.chat_id == request.chat_id))
if existing:
raise HTTPException(status_code=400, detail="该聊天中已存在相同内容的黑话")
# 创建黑话
jargon = Jargon.create(
content=request.content,
raw_content=request.raw_content,
meaning=request.meaning,
chat_id=request.chat_id,
is_global=request.is_global,
count=0,
is_jargon=None,
is_complete=False,
)
logger.info(f"创建黑话成功: id={jargon.id}, content={request.content}")
return JargonCreateResponse(
success=True,
message="创建成功",
data=jargon_to_dict(jargon),
)
except HTTPException:
raise
except Exception as e:
logger.error(f"创建黑话失败: {e}")
raise HTTPException(status_code=500, detail=f"创建黑话失败: {str(e)}") from e
@router.patch("/{jargon_id}", response_model=JargonUpdateResponse)
async def update_jargon(jargon_id: int, request: JargonUpdateRequest):
"""更新黑话(增量更新)"""
try:
jargon = Jargon.get_or_none(Jargon.id == jargon_id)
if not jargon:
raise HTTPException(status_code=404, detail="黑话不存在")
# 增量更新字段
update_data = request.model_dump(exclude_unset=True)
if update_data:
for field, value in update_data.items():
if value is not None or field in ["meaning", "raw_content", "is_jargon"]:
setattr(jargon, field, value)
jargon.save()
logger.info(f"更新黑话成功: id={jargon_id}")
return JargonUpdateResponse(
success=True,
message="更新成功",
data=jargon_to_dict(jargon),
)
except HTTPException:
raise
except Exception as e:
logger.error(f"更新黑话失败: {e}")
raise HTTPException(status_code=500, detail=f"更新黑话失败: {str(e)}") from e
@router.delete("/{jargon_id}", response_model=JargonDeleteResponse)
async def delete_jargon(jargon_id: int):
"""删除黑话"""
try:
jargon = Jargon.get_or_none(Jargon.id == jargon_id)
if not jargon:
raise HTTPException(status_code=404, detail="黑话不存在")
content = jargon.content
jargon.delete_instance()
logger.info(f"删除黑话成功: id={jargon_id}, content={content}")
return JargonDeleteResponse(
success=True,
message="删除成功",
deleted_count=1,
)
except HTTPException:
raise
except Exception as e:
logger.error(f"删除黑话失败: {e}")
raise HTTPException(status_code=500, detail=f"删除黑话失败: {str(e)}") from e
@router.post("/batch/delete", response_model=JargonDeleteResponse)
async def batch_delete_jargons(request: BatchDeleteRequest):
"""批量删除黑话"""
try:
if not request.ids:
raise HTTPException(status_code=400, detail="ID列表不能为空")
deleted_count = Jargon.delete().where(Jargon.id.in_(request.ids)).execute()
logger.info(f"批量删除黑话成功: 删除了 {deleted_count} 条记录")
return JargonDeleteResponse(
success=True,
message=f"成功删除 {deleted_count} 条黑话",
deleted_count=deleted_count,
)
except HTTPException:
raise
except Exception as e:
logger.error(f"批量删除黑话失败: {e}")
raise HTTPException(status_code=500, detail=f"批量删除黑话失败: {str(e)}") from e
@router.post("/batch/set-jargon", response_model=JargonUpdateResponse)
async def batch_set_jargon_status(
ids: Annotated[List[int], Query(description="黑话ID列表")],
is_jargon: Annotated[bool, Query(description="是否是黑话")],
):
"""批量设置黑话状态"""
try:
if not ids:
raise HTTPException(status_code=400, detail="ID列表不能为空")
updated_count = Jargon.update(is_jargon=is_jargon).where(Jargon.id.in_(ids)).execute()
logger.info(f"批量更新黑话状态成功: 更新了 {updated_count} 条记录is_jargon={is_jargon}")
return JargonUpdateResponse(
success=True,
message=f"成功更新 {updated_count} 条黑话状态",
)
except HTTPException:
raise
except Exception as e:
logger.error(f"批量更新黑话状态失败: {e}")
raise HTTPException(status_code=500, detail=f"批量更新黑话状态失败: {str(e)}") from e

View File

@@ -0,0 +1,390 @@
"""知识库图谱可视化 API 路由"""
from typing import List, Optional
from fastapi import APIRouter, Query, Depends, Cookie, Header
from pydantic import BaseModel
import logging
from src.webui.core import verify_auth_token_from_cookie_or_header
from src.config.config import global_config
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/webui/knowledge", tags=["knowledge"])
# 延迟初始化的轻量级 embedding store只读仅用于获取段落完整文本
_paragraph_store_cache = None
def _get_paragraph_store():
    """Lazily load the read-only paragraph embedding store.

    Returns:
        EmbeddingStore | None: cached store when the feature is enabled,
        otherwise None (also None when loading fails).
    """
    global _paragraph_store_cache

    # Feature-gated by webui config.
    if not global_config.webui.enable_paragraph_content:
        return None
    if _paragraph_store_cache is not None:
        return _paragraph_store_cache
    try:
        import os
        from src.chat.knowledge.embedding_store import EmbeddingStore

        # Resolve the project data directory relative to this file.
        here = os.path.dirname(os.path.abspath(__file__))
        root_path = os.path.abspath(os.path.join(here, "..", ".."))
        store = EmbeddingStore(
            namespace="paragraph",
            dir_path=os.path.join(root_path, "data/embedding"),
            max_workers=1,  # read-only: no concurrency needed
            chunk_size=100,
        )
        store.load_from_file()
        _paragraph_store_cache = store
        logger.info(f"成功加载段落 embedding store包含 {len(store.store)} 个段落")
        return store
    except Exception as e:
        # Best-effort: a failed load just disables paragraph content.
        logger.warning(f"加载段落 embedding store 失败: {e}")
        return None
def _get_paragraph_content(node_id: str) -> tuple[Optional[str], bool]:
    """Fetch the full paragraph text for a node from the embedding store.

    Args:
        node_id: paragraph node id, e.g. 'paragraph-{hash}'.

    Returns:
        (content or None, enabled flag). The flag is False when the store is
        unavailable (feature disabled or failed to load).
    """
    try:
        store = _get_paragraph_store()
        if store is None:
            return None, False
        item = store.store.get(node_id)
        if item is None:
            return None, True
        # EmbeddingStoreItem keeps the full text in its `str` attribute.
        text: str = getattr(item, 'str', '')
        return (text, True) if text else (None, True)
    except Exception as e:
        logger.debug(f"获取段落内容失败: {e}")
        return None, True
def require_auth(
    maibot_session: Optional[str] = Cookie(None),
    authorization: Optional[str] = Header(None),
) -> bool:
    """Auth dependency: confirm the caller is logged in.

    Delegates to the shared cookie/header token validator; FastAPI injects
    the session cookie and Authorization header values automatically.
    """
    is_authenticated = verify_auth_token_from_cookie_or_header(maibot_session, authorization)
    return is_authenticated
class KnowledgeNode(BaseModel):
    """知识节点 — a single graph node as returned to the WebUI."""

    id: str  # graph node id (paragraph nodes use the 'paragraph-{hash}' form)
    type: str  # 'entity' or 'paragraph'
    content: str  # node text; paragraphs may carry the full stored passage
    create_time: Optional[float] = None  # unix timestamp, when recorded on the node
class KnowledgeEdge(BaseModel):
    """知识边 — a directed edge between two graph nodes."""

    source: str  # source node id
    target: str  # target node id
    weight: float  # edge weight (upstream defaults to 1.0 when absent)
    create_time: Optional[float] = None  # unix timestamp, when recorded
    update_time: Optional[float] = None  # unix timestamp, when recorded
class KnowledgeGraph(BaseModel):
    """知识图谱 — serialized node/edge lists for the WebUI graph view."""

    nodes: List[KnowledgeNode]
    edges: List[KnowledgeEdge]
class KnowledgeStats(BaseModel):
    """知识库统计信息"""

    total_nodes: int  # total node count in the graph
    total_edges: int  # total edge count in the graph
    entity_nodes: int  # nodes whose stored type is "ent"
    paragraph_nodes: int  # nodes whose stored type is "pg"
    avg_connections: float  # average degree: total_edges * 2 / total_nodes
def _load_kg_manager():
    """Lazily construct a KGManager and hydrate it from disk.

    Returns the loaded manager, or None when import, construction, or
    loading fails — the error is logged rather than propagated.
    """
    try:
        from src.chat.knowledge.kg_manager import KGManager

        manager = KGManager()
        manager.load_from_file()
    except Exception as e:
        logger.error(f"加载 KGManager 失败: {e}")
        return None
    return manager
def _convert_graph_to_json(kg_manager) -> KnowledgeGraph:
    """将 DiGraph 转换为 JSON 格式

    Serializes every node and edge of the manager's graph into the
    KnowledgeGraph response model. Paragraph nodes are enriched with their
    full text from the embedding store when that feature is enabled.
    Any node/edge that raises during conversion is skipped with a warning.

    Note: the graph wrapper supports `[]` access but not `.get()`, hence
    the `"key" in data` membership checks throughout.
    """
    if kg_manager is None or kg_manager.graph is None:
        return KnowledgeGraph(nodes=[], edges=[])
    graph = kg_manager.graph
    nodes = []
    edges = []
    # 转换节点
    node_list = graph.get_node_list()
    for node_id in node_list:
        try:
            node_data = graph[node_id]
            # 节点类型: "ent" -> "entity", "pg" -> "paragraph"
            node_type = "entity" if ("type" in node_data and node_data["type"] == "ent") else "paragraph"
            # 对于段落节点,尝试从 embedding store 获取完整内容
            if node_type == "paragraph":
                full_content, _ = _get_paragraph_content(node_id)
                content = full_content if full_content is not None else (node_data["content"] if "content" in node_data else node_id)
            else:
                content = node_data["content"] if "content" in node_data else node_id
            create_time = node_data["create_time"] if "create_time" in node_data else None
            nodes.append(KnowledgeNode(id=node_id, type=node_type, content=content, create_time=create_time))
        except Exception as e:
            logger.warning(f"跳过节点 {node_id}: {e}")
            continue
    # 转换边
    edge_list = graph.get_edge_list()
    for edge_tuple in edge_list:
        try:
            # edge_tuple 是 (source, target) 元组
            source, target = edge_tuple[0], edge_tuple[1]
            # 通过 graph[source, target] 获取边的属性数据
            edge_data = graph[source, target]
            # edge_data 支持 [] 操作符但不支持 .get()
            weight = edge_data["weight"] if "weight" in edge_data else 1.0
            create_time = edge_data["create_time"] if "create_time" in edge_data else None
            update_time = edge_data["update_time"] if "update_time" in edge_data else None
            edges.append(
                KnowledgeEdge(
                    source=source, target=target, weight=weight, create_time=create_time, update_time=update_time
                )
            )
        except Exception as e:
            logger.warning(f"跳过边 {edge_tuple}: {e}")
            continue
    return KnowledgeGraph(nodes=nodes, edges=edges)
@router.get("/graph", response_model=KnowledgeGraph)
async def get_knowledge_graph(
limit: int = Query(100, ge=1, le=10000, description="返回的最大节点数"),
node_type: str = Query("all", description="节点类型过滤: all, entity, paragraph"),
_auth: bool = Depends(require_auth),
):
"""获取知识图谱(限制节点数量)
Args:
limit: 返回的最大节点数,默认 100,最大 10000
node_type: 节点类型过滤 - all(全部), entity(实体), paragraph(段落)
Returns:
KnowledgeGraph: 包含指定数量节点和相关边的知识图谱
"""
try:
kg_manager = _load_kg_manager()
if kg_manager is None:
logger.warning("KGManager 未初始化,返回空图谱")
return KnowledgeGraph(nodes=[], edges=[])
graph = kg_manager.graph
all_node_list = graph.get_node_list()
# 按类型过滤节点
if node_type == "entity":
all_node_list = [
n for n in all_node_list if n in graph and "type" in graph[n] and graph[n]["type"] == "ent"
]
elif node_type == "paragraph":
all_node_list = [n for n in all_node_list if n in graph and "type" in graph[n] and graph[n]["type"] == "pg"]
# 限制节点数量
total_nodes = len(all_node_list)
if len(all_node_list) > limit:
node_list = all_node_list[:limit]
else:
node_list = all_node_list
logger.info(f"总节点数: {total_nodes}, 返回节点: {len(node_list)} (limit={limit}, type={node_type})")
# 转换节点
nodes = []
node_ids = set()
for node_id in node_list:
try:
node_data = graph[node_id]
node_type_val = "entity" if ("type" in node_data and node_data["type"] == "ent") else "paragraph"
# 对于段落节点,尝试从 embedding store 获取完整内容
if node_type_val == "paragraph":
full_content, _ = _get_paragraph_content(node_id)
content = full_content if full_content is not None else (node_data["content"] if "content" in node_data else node_id)
else:
content = node_data["content"] if "content" in node_data else node_id
create_time = node_data["create_time"] if "create_time" in node_data else None
nodes.append(KnowledgeNode(id=node_id, type=node_type_val, content=content, create_time=create_time))
node_ids.add(node_id)
except Exception as e:
logger.warning(f"跳过节点 {node_id}: {e}")
continue
# 只获取涉及当前节点集的边(保证图的完整性)
edges = []
edge_list = graph.get_edge_list()
for edge_tuple in edge_list:
try:
source, target = edge_tuple[0], edge_tuple[1]
# 只包含两端都在当前节点集中的边
if source not in node_ids or target not in node_ids:
continue
edge_data = graph[source, target]
weight = edge_data["weight"] if "weight" in edge_data else 1.0
create_time = edge_data["create_time"] if "create_time" in edge_data else None
update_time = edge_data["update_time"] if "update_time" in edge_data else None
edges.append(
KnowledgeEdge(
source=source, target=target, weight=weight, create_time=create_time, update_time=update_time
)
)
except Exception as e:
logger.warning(f"跳过边 {edge_tuple}: {e}")
continue
graph_data = KnowledgeGraph(nodes=nodes, edges=edges)
logger.info(f"返回知识图谱: {len(nodes)} 个节点, {len(edges)} 条边")
return graph_data
except Exception as e:
logger.error(f"获取知识图谱失败: {e}", exc_info=True)
return KnowledgeGraph(nodes=[], edges=[])
@router.get("/stats", response_model=KnowledgeStats)
async def get_knowledge_stats(_auth: bool = Depends(require_auth)):
"""获取知识库统计信息
Returns:
KnowledgeStats: 统计信息
"""
try:
kg_manager = _load_kg_manager()
if kg_manager is None or kg_manager.graph is None:
return KnowledgeStats(total_nodes=0, total_edges=0, entity_nodes=0, paragraph_nodes=0, avg_connections=0.0)
graph = kg_manager.graph
node_list = graph.get_node_list()
edge_list = graph.get_edge_list()
total_nodes = len(node_list)
total_edges = len(edge_list)
# 统计节点类型
entity_nodes = 0
paragraph_nodes = 0
for node_id in node_list:
try:
node_data = graph[node_id]
node_type = node_data["type"] if "type" in node_data else "ent"
if node_type == "ent":
entity_nodes += 1
elif node_type == "pg":
paragraph_nodes += 1
except Exception:
continue
# 计算平均连接数
avg_connections = (total_edges * 2) / total_nodes if total_nodes > 0 else 0.0
return KnowledgeStats(
total_nodes=total_nodes,
total_edges=total_edges,
entity_nodes=entity_nodes,
paragraph_nodes=paragraph_nodes,
avg_connections=round(avg_connections, 2),
)
except Exception as e:
logger.error(f"获取统计信息失败: {e}", exc_info=True)
return KnowledgeStats(total_nodes=0, total_edges=0, entity_nodes=0, paragraph_nodes=0, avg_connections=0.0)
@router.get("/search", response_model=List[KnowledgeNode])
async def search_knowledge_node(query: str = Query(..., min_length=1), _auth: bool = Depends(require_auth)):
"""搜索知识节点
Args:
query: 搜索关键词
Returns:
List[KnowledgeNode]: 匹配的节点列表
"""
try:
kg_manager = _load_kg_manager()
if kg_manager is None or kg_manager.graph is None:
return []
graph = kg_manager.graph
node_list = graph.get_node_list()
results = []
query_lower = query.lower()
# 在节点内容中搜索
for node_id in node_list:
try:
node_data = graph[node_id]
node_type = "entity" if ("type" in node_data and node_data["type"] == "ent") else "paragraph"
# 对于段落节点,尝试从 embedding store 获取完整内容
if node_type == "paragraph":
full_content, _ = _get_paragraph_content(node_id)
content = full_content if full_content is not None else (node_data["content"] if "content" in node_data else node_id)
else:
content = node_data["content"] if "content" in node_data else node_id
if query_lower in content.lower() or query_lower in node_id.lower():
create_time = node_data["create_time"] if "create_time" in node_data else None
results.append(KnowledgeNode(id=node_id, type=node_type, content=content, create_time=create_time))
except Exception:
continue
logger.info(f"搜索 '{query}' 找到 {len(results)} 个节点")
return results[:50] # 限制返回数量
except Exception as e:
logger.error(f"搜索节点失败: {e}", exc_info=True)
return []

View File

383
src/webui/routers/model.py Normal file
View File

@@ -0,0 +1,383 @@
"""
模型列表获取API路由
提供从各个 AI 厂商 API 获取可用模型列表的代理接口
"""
import os
import httpx
from fastapi import APIRouter, HTTPException, Query, Depends, Cookie, Header
from typing import Optional
import tomlkit
from src.common.logger import get_logger
from src.config.config import CONFIG_DIR
from src.webui.core import verify_auth_token_from_cookie_or_header
logger = get_logger("webui")
router = APIRouter(prefix="/models", tags=["models"])
def require_auth(
    maibot_session: Optional[str] = Cookie(None),
    authorization: Optional[str] = Header(None),
) -> bool:
    """Auth dependency: confirm the caller is logged in.

    Delegates to the shared cookie/header token validator; FastAPI injects
    the session cookie and Authorization header values automatically.
    """
    is_authenticated = verify_auth_token_from_cookie_or_header(maibot_session, authorization)
    return is_authenticated
# 模型获取器配置
# Maps a provider style to how its model list is fetched: which endpoint to
# hit and which parser (_parse_openai_response / _parse_gemini_response) to
# apply to the response.
# NOTE(review): not referenced anywhere in this module's visible code —
# confirm it is consumed externally (or by the frontend defaults).
MODEL_FETCHER_CONFIG = {
    # OpenAI 兼容格式的提供商
    "openai": {
        "endpoint": "/models",
        "parser": "openai",
    },
    # Gemini 格式
    "gemini": {
        "endpoint": "/models",
        "parser": "gemini",
    },
}
def _normalize_url(url: str) -> str:
"""规范化 URL去掉尾部斜杠"""
if not url:
return ""
return url.rstrip("/")
def _parse_openai_response(data: dict) -> list[dict]:
"""
解析 OpenAI 格式的模型列表响应
格式: { "data": [{ "id": "gpt-4", "object": "model", ... }] }
"""
models = []
if "data" in data and isinstance(data["data"], list):
for model in data["data"]:
if isinstance(model, dict) and "id" in model:
models.append(
{
"id": model["id"],
"name": model.get("name") or model["id"],
"owned_by": model.get("owned_by", ""),
}
)
return models
def _parse_gemini_response(data: dict) -> list[dict]:
"""
解析 Gemini 格式的模型列表响应
格式: { "models": [{ "name": "models/gemini-pro", "displayName": "Gemini Pro", ... }] }
"""
models = []
if "models" in data and isinstance(data["models"], list):
for model in data["models"]:
if isinstance(model, dict) and "name" in model:
# Gemini 的 name 格式是 "models/gemini-pro",我们只取后面部分
model_id = model["name"]
if model_id.startswith("models/"):
model_id = model_id[7:] # 去掉 "models/" 前缀
models.append(
{
"id": model_id,
"name": model.get("displayName") or model_id,
"owned_by": "google",
}
)
return models
async def _fetch_models_from_provider(
    base_url: str,
    api_key: str,
    endpoint: str,
    parser: str,
    client_type: str = "openai",
) -> list[dict]:
    """
    从提供商 API 获取模型列表

    Args:
        base_url: 提供商的基础 URL
        api_key: API 密钥
        endpoint: 获取模型列表的端点
        parser: 响应解析器类型 ('openai' | 'gemini')
        client_type: 客户端类型 ('openai' | 'gemini')

    Returns:
        模型列表 (each item: {"id", "name", "owned_by"})

    Raises:
        HTTPException: 504 on upstream timeout, 502 for upstream auth/
            availability problems, 500 for anything else, 400 for an
            unknown parser type.
    """
    url = f"{_normalize_url(base_url)}{endpoint}"
    # 根据客户端类型设置请求头
    headers = {}
    params = {}
    if client_type == "gemini":
        # Gemini 使用 URL 参数传递 API Key
        params["key"] = api_key
    else:
        # OpenAI 兼容格式使用 Authorization 头
        headers["Authorization"] = f"Bearer {api_key}"
    try:
        async with httpx.AsyncClient(timeout=30.0) as client:
            response = await client.get(url, headers=headers, params=params)
            response.raise_for_status()
            data = response.json()
    except httpx.TimeoutException as e:
        # Must precede the generic Exception handler below.
        raise HTTPException(status_code=504, detail="请求超时,请稍后重试") from e
    except httpx.HTTPStatusError as e:
        # 注意:使用 502 Bad Gateway 而不是原始的 401/403
        # 因为前端的 fetchWithAuth 会把 401 当作 WebUI 认证失败处理
        if e.response.status_code == 401:
            raise HTTPException(status_code=502, detail="API Key 无效或已过期") from e
        elif e.response.status_code == 403:
            raise HTTPException(status_code=502, detail="没有权限访问模型列表,请检查 API Key 权限") from e
        elif e.response.status_code == 404:
            raise HTTPException(status_code=502, detail="该提供商不支持获取模型列表") from e
        else:
            raise HTTPException(
                status_code=502, detail=f"上游服务请求失败 ({e.response.status_code}): {e.response.text[:200]}"
            ) from e
    except Exception as e:
        logger.error(f"获取模型列表失败: {e}")
        raise HTTPException(status_code=500, detail=f"获取模型列表失败: {str(e)}") from e
    # 根据解析器类型解析响应
    if parser == "openai":
        return _parse_openai_response(data)
    elif parser == "gemini":
        return _parse_gemini_response(data)
    else:
        raise HTTPException(status_code=400, detail=f"不支持的解析器类型: {parser}")
def _get_provider_config(provider_name: str) -> Optional[dict]:
    """
    从 model_config.toml 获取指定提供商的配置

    Args:
        provider_name: 提供商名称

    Returns:
        提供商配置,如果未找到则返回 None. Also returns None when the
        config file is missing or fails to parse (the error is logged).
    """
    config_path = os.path.join(CONFIG_DIR, "model_config.toml")
    if not os.path.exists(config_path):
        return None
    try:
        with open(config_path, "r", encoding="utf-8") as f:
            config_data = tomlkit.load(f)
        providers = config_data.get("api_providers", [])
        # First provider with a matching "name" wins.
        for provider in providers:
            if provider.get("name") == provider_name:
                return dict(provider)
        return None
    except Exception as e:
        logger.error(f"读取提供商配置失败: {e}")
        return None
@router.get("/list")
async def get_provider_models(
provider_name: str = Query(..., description="提供商名称"),
parser: str = Query("openai", description="响应解析器类型 (openai | gemini)"),
endpoint: str = Query("/models", description="获取模型列表的端点"),
_auth: bool = Depends(require_auth),
):
"""
获取指定提供商的可用模型列表
通过提供商名称查找配置,然后请求对应的模型列表端点
"""
# 获取提供商配置
provider_config = _get_provider_config(provider_name)
if not provider_config:
raise HTTPException(status_code=404, detail=f"未找到提供商: {provider_name}")
base_url = provider_config.get("base_url")
api_key = provider_config.get("api_key")
client_type = provider_config.get("client_type", "openai")
if not base_url:
raise HTTPException(status_code=400, detail="提供商配置缺少 base_url")
if not api_key:
raise HTTPException(status_code=400, detail="提供商配置缺少 api_key")
# 获取模型列表
models = await _fetch_models_from_provider(
base_url=base_url,
api_key=api_key,
endpoint=endpoint,
parser=parser,
client_type=client_type,
)
return {
"success": True,
"models": models,
"provider": provider_name,
"count": len(models),
}
@router.get("/list-by-url")
async def get_models_by_url(
base_url: str = Query(..., description="提供商的基础 URL"),
api_key: str = Query(..., description="API Key"),
parser: str = Query("openai", description="响应解析器类型 (openai | gemini)"),
endpoint: str = Query("/models", description="获取模型列表的端点"),
client_type: str = Query("openai", description="客户端类型 (openai | gemini)"),
_auth: bool = Depends(require_auth),
):
"""
通过 URL 直接获取模型列表(用于自定义提供商)
"""
models = await _fetch_models_from_provider(
base_url=base_url,
api_key=api_key,
endpoint=endpoint,
parser=parser,
client_type=client_type,
)
return {
"success": True,
"models": models,
"count": len(models),
}
@router.get("/test-connection")
async def test_provider_connection(
base_url: str = Query(..., description="提供商的基础 URL"),
api_key: Optional[str] = Query(None, description="API Key可选用于验证 Key 有效性)"),
_auth: bool = Depends(require_auth),
):
"""
测试提供商连接状态
分两步测试:
1. 网络连通性测试:向 base_url 发送请求,检查是否能连接
2. API Key 验证(可选):如果提供了 api_key尝试获取模型列表验证 Key 是否有效
返回:
- network_ok: 网络是否连通
- api_key_valid: API Key 是否有效(仅在提供 api_key 时返回)
- latency_ms: 响应延迟(毫秒)
- error: 错误信息(如果有)
"""
import time
base_url = _normalize_url(base_url)
if not base_url:
raise HTTPException(status_code=400, detail="base_url 不能为空")
result = {
"network_ok": False,
"api_key_valid": None,
"latency_ms": None,
"error": None,
"http_status": None,
}
# 第一步:测试网络连通性
try:
start_time = time.time()
async with httpx.AsyncClient(timeout=10.0, follow_redirects=True) as client:
# 尝试 GET 请求 base_url不需要 API Key
response = await client.get(base_url)
latency = (time.time() - start_time) * 1000
result["network_ok"] = True
result["latency_ms"] = round(latency, 2)
result["http_status"] = response.status_code
except httpx.ConnectError as e:
result["error"] = f"连接失败:无法连接到服务器 ({str(e)})"
return result
except httpx.TimeoutException:
result["error"] = "连接超时:服务器响应时间过长"
return result
except httpx.RequestError as e:
result["error"] = f"请求错误:{str(e)}"
return result
except Exception as e:
result["error"] = f"未知错误:{str(e)}"
return result
# 第二步:如果提供了 API Key验证其有效性
if api_key:
try:
start_time = time.time()
async with httpx.AsyncClient(timeout=15.0, follow_redirects=True) as client:
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
}
# 尝试获取模型列表
models_url = f"{base_url}/models"
response = await client.get(models_url, headers=headers)
if response.status_code == 200:
result["api_key_valid"] = True
elif response.status_code in (401, 403):
result["api_key_valid"] = False
result["error"] = "API Key 无效或已过期"
else:
# 其他状态码,可能是端点不支持,但 Key 可能是有效的
result["api_key_valid"] = None
except Exception as e:
# API Key 验证失败不影响网络连通性结果
logger.warning(f"API Key 验证失败: {e}")
result["api_key_valid"] = None
return result
@router.post("/test-connection-by-name")
async def test_provider_connection_by_name(
provider_name: str = Query(..., description="提供商名称"),
_auth: bool = Depends(require_auth),
):
"""
通过提供商名称测试连接(从配置文件读取信息)
"""
# 读取配置文件
model_config_path = os.path.join(CONFIG_DIR, "model_config.toml")
if not os.path.exists(model_config_path):
raise HTTPException(status_code=404, detail="配置文件不存在")
with open(model_config_path, "r", encoding="utf-8") as f:
config = tomlkit.load(f)
# 查找提供商
providers = config.get("api_providers", [])
provider = None
for p in providers:
if p.get("name") == provider_name:
provider = p
break
if not provider:
raise HTTPException(status_code=404, detail=f"未找到提供商: {provider_name}")
base_url = provider.get("base_url", "")
api_key = provider.get("api_key", "")
if not base_url:
raise HTTPException(status_code=400, detail="提供商配置缺少 base_url")
# 调用测试接口
return await test_provider_connection(base_url=base_url, api_key=api_key if api_key else None)

416
src/webui/routers/person.py Normal file
View File

@@ -0,0 +1,416 @@
"""人物信息管理 API 路由"""
from fastapi import APIRouter, HTTPException, Header, Query, Cookie
from pydantic import BaseModel
from typing import Optional, List, Dict
from src.common.logger import get_logger
from src.common.database.database_model import PersonInfo
from src.webui.core import verify_auth_token_from_cookie_or_header
import json
import time
logger = get_logger("webui.person")
# 创建路由器
router = APIRouter(prefix="/person", tags=["Person"])
class PersonInfoResponse(BaseModel):
    """人物信息响应 — one PersonInfo database row, serialized for the WebUI."""

    id: int  # database primary key
    is_known: bool  # whether the bot has "met" this person
    person_id: str  # stable unique person identifier
    person_name: Optional[str]  # name assigned by the bot, if any
    name_reason: Optional[str]  # rationale recorded for that name
    platform: str  # chat platform the person belongs to
    user_id: str  # platform-specific user id
    nickname: Optional[str]  # platform nickname
    group_nick_name: Optional[List[Dict[str, str]]]  # 解析后的 JSON (per-group nicknames)
    memory_points: Optional[str]  # accumulated memory notes
    know_times: Optional[float]  # NOTE(review): semantics inferred from name (interaction count?) — confirm
    know_since: Optional[float]  # unix timestamp of first encounter (inferred from name — confirm)
    last_know: Optional[float]  # unix timestamp of most recent update (set by update_person)
class PersonListResponse(BaseModel):
    """人物列表响应 — paginated page of person records."""

    success: bool
    total: int  # total matching rows (before pagination)
    page: int  # 1-based page number echoed back
    page_size: int  # page size echoed back
    data: List[PersonInfoResponse]
class PersonDetailResponse(BaseModel):
    """人物详情响应 — single person record wrapper."""

    success: bool
    data: PersonInfoResponse
class PersonUpdateRequest(BaseModel):
    """人物信息更新请求 — partial update; only fields explicitly set are applied."""

    person_name: Optional[str] = None
    name_reason: Optional[str] = None
    nickname: Optional[str] = None
    memory_points: Optional[str] = None
    is_known: Optional[bool] = None
class PersonUpdateResponse(BaseModel):
    """人物信息更新响应"""

    success: bool
    message: str  # human-readable summary of the update
    data: Optional[PersonInfoResponse] = None  # the record after the update
class PersonDeleteResponse(BaseModel):
    """人物删除响应"""

    success: bool
    message: str  # human-readable summary of the deletion
class BatchDeleteRequest(BaseModel):
    """批量删除请求"""

    person_ids: List[str]  # person_id values (not database primary keys) to delete
class BatchDeleteResponse(BaseModel):
    """批量删除响应"""

    success: bool
    message: str  # summary, e.g. "成功删除 N 个人物M 个失败"
    deleted_count: int  # rows actually deleted
    failed_count: int  # ids that were missing or whose deletion raised
    failed_ids: List[str] = []  # the failing person_id values
def verify_auth_token(
    maibot_session: Optional[str] = None,
    authorization: Optional[str] = None,
) -> bool:
    """Validate the auth token taken from either the session cookie or the
    Authorization header.

    Thin wrapper around the shared WebUI validator so route handlers can
    forward the raw Cookie/Header values unchanged.
    """
    token_ok = verify_auth_token_from_cookie_or_header(maibot_session, authorization)
    return token_ok
def parse_group_nick_name(group_nick_name_str: Optional[str]) -> Optional[List[Dict[str, str]]]:
    """Decode a stored group-nickname JSON string.

    Returns the parsed value (expected: a list of group->nickname mappings),
    or None when the input is empty/missing or is not valid JSON.
    """
    if not group_nick_name_str:
        return None
    try:
        parsed = json.loads(group_nick_name_str)
    except (json.JSONDecodeError, TypeError):
        return None
    return parsed
def person_to_response(person: PersonInfo) -> PersonInfoResponse:
    """将 PersonInfo 模型转换为响应对象

    Straight field-by-field mapping from the peewee row to the Pydantic
    response model; the group_nick_name JSON string is decoded into a list
    (None when empty or invalid).
    """
    return PersonInfoResponse(
        id=person.id,
        is_known=person.is_known,
        person_id=person.person_id,
        person_name=person.person_name,
        name_reason=person.name_reason,
        platform=person.platform,
        user_id=person.user_id,
        nickname=person.nickname,
        group_nick_name=parse_group_nick_name(person.group_nick_name),
        memory_points=person.memory_points,
        know_times=person.know_times,
        know_since=person.know_since,
        last_know=person.last_know,
    )
@router.get("/list", response_model=PersonListResponse)
async def get_person_list(
page: int = Query(1, ge=1, description="页码"),
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
search: Optional[str] = Query(None, description="搜索关键词"),
is_known: Optional[bool] = Query(None, description="是否已认识筛选"),
platform: Optional[str] = Query(None, description="平台筛选"),
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
"""
获取人物信息列表
Args:
page: 页码 (从 1 开始)
page_size: 每页数量 (1-100)
search: 搜索关键词 (匹配 person_name, nickname, user_id)
is_known: 是否已认识筛选
platform: 平台筛选
authorization: Authorization header
Returns:
人物信息列表
"""
try:
verify_auth_token(maibot_session, authorization)
# 构建查询
query = PersonInfo.select()
# 搜索过滤
if search:
query = query.where(
(PersonInfo.person_name.contains(search))
| (PersonInfo.nickname.contains(search))
| (PersonInfo.user_id.contains(search))
)
# 已认识状态过滤
if is_known is not None:
query = query.where(PersonInfo.is_known == is_known)
# 平台过滤
if platform:
query = query.where(PersonInfo.platform == platform)
# 排序最后更新时间倒序NULL 值放在最后)
# Peewee 不支持 nulls_last使用 CASE WHEN 来实现
from peewee import Case
query = query.order_by(Case(None, [(PersonInfo.last_know.is_null(), 1)], 0), PersonInfo.last_know.desc())
# 获取总数
total = query.count()
# 分页
offset = (page - 1) * page_size
persons = query.offset(offset).limit(page_size)
# 转换为响应对象
data = [person_to_response(person) for person in persons]
return PersonListResponse(success=True, total=total, page=page, page_size=page_size, data=data)
except HTTPException:
raise
except Exception as e:
logger.exception(f"获取人物列表失败: {e}")
raise HTTPException(status_code=500, detail=f"获取人物列表失败: {str(e)}") from e
@router.get("/{person_id}", response_model=PersonDetailResponse)
async def get_person_detail(
person_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
):
"""
获取人物详细信息
Args:
person_id: 人物唯一 ID
authorization: Authorization header
Returns:
人物详细信息
"""
try:
verify_auth_token(maibot_session, authorization)
person = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
if not person:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {person_id} 的人物信息")
return PersonDetailResponse(success=True, data=person_to_response(person))
except HTTPException:
raise
except Exception as e:
logger.exception(f"获取人物详情失败: {e}")
raise HTTPException(status_code=500, detail=f"获取人物详情失败: {str(e)}") from e
@router.patch("/{person_id}", response_model=PersonUpdateResponse)
async def update_person(
person_id: str,
request: PersonUpdateRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
"""
增量更新人物信息(只更新提供的字段)
Args:
person_id: 人物唯一 ID
request: 更新请求(只包含需要更新的字段)
authorization: Authorization header
Returns:
更新结果
"""
try:
verify_auth_token(maibot_session, authorization)
person = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
if not person:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {person_id} 的人物信息")
# 只更新提供的字段
update_data = request.model_dump(exclude_unset=True)
if not update_data:
raise HTTPException(status_code=400, detail="未提供任何需要更新的字段")
# 更新最后修改时间
update_data["last_know"] = time.time()
# 执行更新
for field, value in update_data.items():
setattr(person, field, value)
person.save()
logger.info(f"人物信息已更新: {person_id}, 字段: {list(update_data.keys())}")
return PersonUpdateResponse(
success=True, message=f"成功更新 {len(update_data)} 个字段", data=person_to_response(person)
)
except HTTPException:
raise
except Exception as e:
logger.exception(f"更新人物信息失败: {e}")
raise HTTPException(status_code=500, detail=f"更新人物信息失败: {str(e)}") from e
@router.delete("/{person_id}", response_model=PersonDeleteResponse)
async def delete_person(
person_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
):
"""
删除人物信息
Args:
person_id: 人物唯一 ID
authorization: Authorization header
Returns:
删除结果
"""
try:
verify_auth_token(maibot_session, authorization)
person = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
if not person:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {person_id} 的人物信息")
# 记录删除信息
person_name = person.person_name or person.nickname or person.user_id
# 执行删除
person.delete_instance()
logger.info(f"人物信息已删除: {person_id} ({person_name})")
return PersonDeleteResponse(success=True, message=f"成功删除人物信息: {person_name}")
except HTTPException:
raise
except Exception as e:
logger.exception(f"删除人物信息失败: {e}")
raise HTTPException(status_code=500, detail=f"删除人物信息失败: {str(e)}") from e
@router.get("/stats/summary")
async def get_person_stats(maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
"""
获取人物信息统计数据
Args:
authorization: Authorization header
Returns:
统计数据
"""
try:
verify_auth_token(maibot_session, authorization)
total = PersonInfo.select().count()
known = PersonInfo.select().where(PersonInfo.is_known).count()
unknown = total - known
# 按平台统计
platforms = {}
for person in PersonInfo.select(PersonInfo.platform):
platform = person.platform
platforms[platform] = platforms.get(platform, 0) + 1
return {"success": True, "data": {"total": total, "known": known, "unknown": unknown, "platforms": platforms}}
except HTTPException:
raise
except Exception as e:
logger.exception(f"获取统计数据失败: {e}")
raise HTTPException(status_code=500, detail=f"获取统计数据失败: {str(e)}") from e
@router.post("/batch/delete", response_model=BatchDeleteResponse)
async def batch_delete_persons(
request: BatchDeleteRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
"""
批量删除人物信息
Args:
request: 包含person_ids列表的请求
authorization: Authorization header
Returns:
批量删除结果
"""
try:
verify_auth_token(maibot_session, authorization)
if not request.person_ids:
raise HTTPException(status_code=400, detail="未提供要删除的人物ID")
deleted_count = 0
failed_count = 0
failed_ids = []
for person_id in request.person_ids:
try:
person = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
if person:
person.delete_instance()
deleted_count += 1
logger.info(f"批量删除: {person_id}")
else:
failed_count += 1
failed_ids.append(person_id)
except Exception as e:
logger.error(f"删除 {person_id} 失败: {e}")
failed_count += 1
failed_ids.append(person_id)
message = f"成功删除 {deleted_count} 个人物"
if failed_count > 0:
message += f"{failed_count} 个失败"
return BatchDeleteResponse(
success=True,
message=message,
deleted_count=deleted_count,
failed_count=failed_count,
failed_ids=failed_ids,
)
except HTTPException:
raise
except Exception as e:
logger.exception(f"批量删除人物信息失败: {e}")
raise HTTPException(status_code=500, detail=f"批量删除失败: {str(e)}") from e

2059
src/webui/routers/plugin.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,319 @@
"""统计数据 API 路由"""
from fastapi import APIRouter, HTTPException, Depends, Cookie, Header
from pydantic import BaseModel, Field
from typing import Dict, Any, List, Optional
from datetime import datetime, timedelta
from peewee import fn
from src.common.logger import get_logger
from src.common.database.database_model import LLMUsage, OnlineTime, Messages
from src.webui.core import verify_auth_token_from_cookie_or_header
logger = get_logger("webui.statistics")
router = APIRouter(prefix="/statistics", tags=["statistics"])
def require_auth(
    maibot_session: Optional[str] = Cookie(None),
    authorization: Optional[str] = Header(None),
) -> bool:
    """Auth dependency: confirm the caller is logged in.

    Delegates to the shared cookie/header token validator; FastAPI injects
    the session cookie and Authorization header values automatically.
    """
    is_authenticated = verify_auth_token_from_cookie_or_header(maibot_session, authorization)
    return is_authenticated
class StatisticsSummary(BaseModel):
    """统计数据摘要 — aggregated LLM usage and activity figures for one time
    window; per-hour rates are derived from online_time."""

    total_requests: int = Field(0, description="总请求数")
    total_cost: float = Field(0.0, description="总花费")
    total_tokens: int = Field(0, description="总token数")
    online_time: float = Field(0.0, description="在线时间(秒)")
    total_messages: int = Field(0, description="总消息数")
    total_replies: int = Field(0, description="总回复数")
    avg_response_time: float = Field(0.0, description="平均响应时间")
    cost_per_hour: float = Field(0.0, description="每小时花费")
    tokens_per_hour: float = Field(0.0, description="每小时token数")
class ModelStatistics(BaseModel):
    """模型统计 — per-model LLM usage aggregates."""

    model_name: str  # assigned name, falling back to the raw model name
    request_count: int
    total_cost: float
    total_tokens: int  # prompt + completion tokens
    avg_response_time: float
class TimeSeriesData(BaseModel):
    """时间序列数据 — one bucket (hour or day) of aggregated usage."""

    timestamp: str  # bucket start, ISO-like "%Y-%m-%dT%H:00:00" format
    requests: int = 0
    cost: float = 0.0
    tokens: int = 0
class DashboardData(BaseModel):
    """仪表盘数据 — full payload for the WebUI statistics dashboard."""

    summary: StatisticsSummary  # aggregates over the requested window
    model_stats: List[ModelStatistics]  # top models by request count
    hourly_data: List[TimeSeriesData]  # hour buckets over the window
    daily_data: List[TimeSeriesData]  # day buckets over the last 7 days
    recent_activity: List[Dict[str, Any]]
@router.get("/dashboard", response_model=DashboardData)
async def get_dashboard_data(hours: int = 24, _auth: bool = Depends(require_auth)):
"""
获取仪表盘统计数据
Args:
hours: 统计时间范围小时默认24小时
Returns:
仪表盘数据
"""
try:
now = datetime.now()
start_time = now - timedelta(hours=hours)
# 获取摘要数据
summary = await _get_summary_statistics(start_time, now)
# 获取模型统计
model_stats = await _get_model_statistics(start_time)
# 获取小时级时间序列数据
hourly_data = await _get_hourly_statistics(start_time, now)
# 获取日级时间序列数据最近7天
daily_start = now - timedelta(days=7)
daily_data = await _get_daily_statistics(daily_start, now)
# 获取最近活动
recent_activity = await _get_recent_activity(limit=10)
return DashboardData(
summary=summary,
model_stats=model_stats,
hourly_data=hourly_data,
daily_data=daily_data,
recent_activity=recent_activity,
)
except Exception as e:
logger.error(f"获取仪表盘数据失败: {e}")
raise HTTPException(status_code=500, detail=f"获取统计数据失败: {str(e)}") from e
async def _get_summary_statistics(start_time: datetime, end_time: datetime) -> StatisticsSummary:
    """获取摘要统计数据(优化:使用数据库聚合)

    Aggregates LLM usage, online time, and message counts for the
    [start_time, end_time] window into a StatisticsSummary. Heavy counts
    are pushed into SQL aggregates; only online-time records are iterated
    in Python.
    """
    summary = StatisticsSummary()
    # 使用聚合查询替代全量加载
    query = LLMUsage.select(
        fn.COUNT(LLMUsage.id).alias("total_requests"),
        fn.COALESCE(fn.SUM(LLMUsage.cost), 0).alias("total_cost"),
        fn.COALESCE(fn.SUM(LLMUsage.prompt_tokens + LLMUsage.completion_tokens), 0).alias("total_tokens"),
        fn.COALESCE(fn.AVG(LLMUsage.time_cost), 0).alias("avg_response_time"),
    ).where((LLMUsage.timestamp >= start_time) & (LLMUsage.timestamp <= end_time))
    # Aggregate query always yields exactly one row.
    result = query.dicts().get()
    summary.total_requests = result["total_requests"]
    summary.total_cost = result["total_cost"]
    summary.total_tokens = result["total_tokens"]
    summary.avg_response_time = result["avg_response_time"] or 0.0
    # 查询在线时间 - 这个数据量通常不大,保留原逻辑
    online_records = list(
        OnlineTime.select().where((OnlineTime.start_timestamp >= start_time) | (OnlineTime.end_timestamp >= start_time))
    )
    for record in online_records:
        # Clamp each session to the requested window before summing.
        # NOTE(review): assumes end_timestamp is always set; a still-open
        # session with a NULL end would raise here (caught by the caller) — confirm.
        start = max(record.start_timestamp, start_time)
        end = min(record.end_timestamp, end_time)
        if end > start:
            summary.online_time += (end - start).total_seconds()
    # 查询消息数量 - 使用聚合优化 (Messages.time is stored as a unix timestamp)
    messages_query = Messages.select(fn.COUNT(Messages.id).alias("total")).where(
        (Messages.time >= start_time.timestamp()) & (Messages.time <= end_time.timestamp())
    )
    summary.total_messages = messages_query.scalar() or 0
    # 统计回复数量 (messages with a non-null reply_to)
    replies_query = Messages.select(fn.COUNT(Messages.id).alias("total")).where(
        (Messages.time >= start_time.timestamp())
        & (Messages.time <= end_time.timestamp())
        & (Messages.reply_to.is_null(False))
    )
    summary.total_replies = replies_query.scalar() or 0
    # 计算派生指标 (per-hour rates relative to measured online time)
    if summary.online_time > 0:
        online_hours = summary.online_time / 3600.0
        summary.cost_per_hour = summary.total_cost / online_hours
        summary.tokens_per_hour = summary.total_tokens / online_hours
    return summary
async def _get_model_statistics(start_time: datetime) -> List[ModelStatistics]:
    """获取模型统计数据(优化:使用数据库聚合和分组)

    Groups LLM usage since start_time by model (assigned name, falling back
    to the raw model name, then "unknown") and returns the top 10 models by
    request count.
    """
    # 使用GROUP BY聚合避免全量加载
    query = (
        LLMUsage.select(
            fn.COALESCE(LLMUsage.model_assign_name, LLMUsage.model_name, "unknown").alias("model_name"),
            fn.COUNT(LLMUsage.id).alias("request_count"),
            fn.COALESCE(fn.SUM(LLMUsage.cost), 0).alias("total_cost"),
            fn.COALESCE(fn.SUM(LLMUsage.prompt_tokens + LLMUsage.completion_tokens), 0).alias("total_tokens"),
            fn.COALESCE(fn.AVG(LLMUsage.time_cost), 0).alias("avg_response_time"),
        )
        .where(LLMUsage.timestamp >= start_time)
        .group_by(fn.COALESCE(LLMUsage.model_assign_name, LLMUsage.model_name, "unknown"))
        .order_by(fn.COUNT(LLMUsage.id).desc())
        .limit(10)  # 只取前10个
    )
    result = []
    for row in query.dicts():
        result.append(
            ModelStatistics(
                model_name=row["model_name"],
                request_count=row["request_count"],
                total_cost=row["total_cost"],
                total_tokens=row["total_tokens"],
                avg_response_time=row["avg_response_time"] or 0.0,
            )
        )
    return result
async def _get_hourly_statistics(start_time: datetime, end_time: datetime) -> List[TimeSeriesData]:
    """Return a gapless per-hour time series of requests/cost/tokens.

    The database groups rows into hour buckets via strftime; hours without
    any usage are filled in afterwards with zero entries.
    """
    hour_bucket = fn.strftime("%Y-%m-%dT%H:00:00", LLMUsage.timestamp)
    rows = (
        LLMUsage.select(
            hour_bucket.alias("hour"),
            fn.COUNT(LLMUsage.id).alias("requests"),
            fn.COALESCE(fn.SUM(LLMUsage.cost), 0).alias("cost"),
            fn.COALESCE(fn.SUM(LLMUsage.prompt_tokens + LLMUsage.completion_tokens), 0).alias("tokens"),
        )
        .where((LLMUsage.timestamp >= start_time) & (LLMUsage.timestamp <= end_time))
        .group_by(hour_bucket)
    )
    # Index aggregated rows by their hour label for O(1) lookup while filling.
    by_hour = {row["hour"]: row for row in rows.dicts()}
    series: List[TimeSeriesData] = []
    cursor = start_time.replace(minute=0, second=0, microsecond=0)
    while cursor <= end_time:
        label = cursor.strftime("%Y-%m-%dT%H:00:00")
        row = by_hour.get(label)
        if row is None:
            series.append(TimeSeriesData(timestamp=label, requests=0, cost=0.0, tokens=0))
        else:
            series.append(
                TimeSeriesData(timestamp=label, requests=row["requests"], cost=row["cost"], tokens=row["tokens"])
            )
        cursor += timedelta(hours=1)
    return series
async def _get_daily_statistics(start_time: datetime, end_time: datetime) -> List[TimeSeriesData]:
    """Return a gapless per-day time series of requests/cost/tokens.

    Mirrors the hourly variant but buckets by calendar day; days without
    usage are padded with zero entries.
    """
    day_bucket = fn.strftime("%Y-%m-%dT00:00:00", LLMUsage.timestamp)
    rows = (
        LLMUsage.select(
            day_bucket.alias("day"),
            fn.COUNT(LLMUsage.id).alias("requests"),
            fn.COALESCE(fn.SUM(LLMUsage.cost), 0).alias("cost"),
            fn.COALESCE(fn.SUM(LLMUsage.prompt_tokens + LLMUsage.completion_tokens), 0).alias("tokens"),
        )
        .where((LLMUsage.timestamp >= start_time) & (LLMUsage.timestamp <= end_time))
        .group_by(day_bucket)
    )
    by_day = {row["day"]: row for row in rows.dicts()}
    series: List[TimeSeriesData] = []
    cursor = start_time.replace(hour=0, minute=0, second=0, microsecond=0)
    while cursor <= end_time:
        label = cursor.strftime("%Y-%m-%dT00:00:00")
        row = by_day.get(label)
        if row is None:
            series.append(TimeSeriesData(timestamp=label, requests=0, cost=0.0, tokens=0))
        else:
            series.append(
                TimeSeriesData(timestamp=label, requests=row["requests"], cost=row["cost"], tokens=row["tokens"])
            )
        cursor += timedelta(days=1)
    return series
async def _get_recent_activity(limit: int = 10) -> List[Dict[str, Any]]:
    """Return the ``limit`` most recent LLM usage records, newest first."""
    recent = LLMUsage.select().order_by(LLMUsage.timestamp.desc()).limit(limit)
    return [
        {
            "timestamp": rec.timestamp.isoformat(),
            "model": rec.model_assign_name or rec.model_name,
            "request_type": rec.request_type,
            # Token counts may be NULL in the DB; treat missing as zero.
            "tokens": (rec.prompt_tokens or 0) + (rec.completion_tokens or 0),
            "cost": rec.cost or 0.0,
            "time_cost": rec.time_cost or 0.0,
            "status": rec.status,
        }
        for rec in recent
    ]
@router.get("/summary")
async def get_summary(hours: int = 24, _auth: bool = Depends(require_auth)):
    """Return the aggregated statistics summary for the last ``hours`` hours."""
    try:
        now = datetime.now()
        return await _get_summary_statistics(now - timedelta(hours=hours), now)
    except Exception as e:
        logger.error(f"获取统计摘要失败: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.get("/models")
async def get_model_stats(hours: int = 24, _auth: bool = Depends(require_auth)):
    """Return per-model usage statistics over the last ``hours`` hours."""
    try:
        now = datetime.now()
        return await _get_model_statistics(now - timedelta(hours=hours))
    except Exception as e:
        logger.error(f"获取模型统计失败: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e

113
src/webui/routers/system.py Normal file
View File

@@ -0,0 +1,113 @@
"""
系统控制路由
提供系统重启、状态查询等功能
"""
import os
import time
from datetime import datetime
from typing import Optional
from fastapi import APIRouter, HTTPException, Depends, Cookie, Header
from pydantic import BaseModel
from src.config.config import MMC_VERSION
from src.common.logger import get_logger
from src.webui.core import verify_auth_token_from_cookie_or_header
router = APIRouter(prefix="/system", tags=["system"])
logger = get_logger("webui_system")
# 记录启动时间
_start_time = time.time()
def require_auth(
    maibot_session: Optional[str] = Cookie(None),
    authorization: Optional[str] = Header(None),
) -> bool:
    """Auth dependency: verify the caller is logged in.

    Reads the session token from the ``maibot_session`` cookie or the
    ``Authorization`` header and delegates validation to
    ``verify_auth_token_from_cookie_or_header``.
    """
    return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
class RestartResponse(BaseModel):
    """Response body for POST /system/restart."""
    success: bool  # whether the restart was scheduled
    message: str  # human-readable status text
class StatusResponse(BaseModel):
    """Response body for GET /system/status."""
    running: bool  # always True when the handler can respond at all
    uptime: float  # seconds since module import (_start_time)
    version: str  # MMC_VERSION constant
    start_time: str  # ISO-8601 timestamp derived from _start_time
@router.post("/restart", response_model=RestartResponse)
async def restart_maibot(_auth: bool = Depends(require_auth)):
    """Restart the main MaiBot process.

    Schedules a delayed process exit (status code 42) so the HTTP response
    can be delivered first; an external runner script is expected to relaunch
    the process when it observes exit code 42. MaiBot goes offline briefly.
    """
    import asyncio

    try:
        logger.info("WebUI 触发重启操作")

        async def delayed_restart():
            # Wait 0.5s so FastAPI can flush the success response before exit.
            await asyncio.sleep(0.5)
            # os._exit(42): hard-exit with the agreed restart status code.
            logger.info("WebUI 请求重启,退出代码 42")
            os._exit(42)

        # FIX: the event loop keeps only a weak reference to tasks created by
        # asyncio.create_task(); without a strong reference the restart task
        # could be garbage-collected before it runs and the restart would
        # silently never happen. Stash the reference on the function object.
        restart_maibot._restart_task = asyncio.create_task(delayed_restart())
        # Respond immediately; the process dies ~0.5s later.
        return RestartResponse(success=True, message="麦麦正在重启中...")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"重启失败: {str(e)}") from e
@router.get("/status", response_model=StatusResponse)
async def get_maibot_status(_auth: bool = Depends(require_auth)):
    """Report MaiBot's running state, uptime, version and start time."""
    try:
        started = datetime.fromtimestamp(_start_time)
        return StatusResponse(
            running=True,
            uptime=time.time() - _start_time,
            version=MMC_VERSION,  # version constant from config
            start_time=started.isoformat(),
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"获取状态失败: {str(e)}") from e
# 可选:添加更多系统控制功能
@router.post("/reload-config")
async def reload_config(_auth: bool = Depends(require_auth)):
    """Hot-reload configuration without restarting the process.

    Currently a stub: the reload hook in the main program (e.g.
    ``await app_instance.reload_config()``) is not wired up yet, and some
    settings will still need a full restart to take effect.
    """
    # TODO: invoke the main application's config-reload routine once it exists.
    return {"success": True, "message": "配置重载功能待实现"}

View File

@@ -0,0 +1,9 @@
from .logs import router as logs_router
from .plugin_progress import get_progress_router
from .auth import router as ws_auth_router
__all__ = [
"logs_router",
"get_progress_router",
"ws_auth_router",
]

View File

@@ -0,0 +1,114 @@
"""WebSocket 认证模块
提供所有 WebSocket 端点统一使用的临时 token 认证机制。
临时 token 有效期 60 秒,且只能使用一次,用于解决 WebSocket 握手时 Cookie 不可用的问题。
"""
from fastapi import APIRouter, Cookie, Header
from typing import Optional
import secrets
import time
from src.common.logger import get_logger
from src.webui.core import get_token_manager
logger = get_logger("webui.ws_auth")
router = APIRouter()
# WebSocket 临时 token 存储 {token: (expire_time, session_token)}
# 临时 token 有效期 60 秒,仅用于 WebSocket 握手
_ws_temp_tokens: dict[str, tuple[float, str]] = {}
_WS_TOKEN_EXPIRE_SECONDS = 60
def _cleanup_expired_ws_tokens():
    """Drop every temporary WebSocket token whose expiry time has passed."""
    cutoff = time.time()
    stale = [tok for tok, entry in _ws_temp_tokens.items() if entry[0] < cutoff]
    for tok in stale:
        _ws_temp_tokens.pop(tok, None)
def generate_ws_token(session_token: str) -> str:
    """Create a single-use temporary token for WebSocket handshakes.

    Args:
        session_token: The already-validated session token to associate.

    Returns:
        A URL-safe random token valid for ``_WS_TOKEN_EXPIRE_SECONDS`` seconds.
    """
    _cleanup_expired_ws_tokens()
    ws_token = secrets.token_urlsafe(32)
    expires_at = time.time() + _WS_TOKEN_EXPIRE_SECONDS
    _ws_temp_tokens[ws_token] = (expires_at, session_token)
    logger.debug(f"生成 WS 临时 token: {ws_token[:8]}... 有效期 {_WS_TOKEN_EXPIRE_SECONDS}s")
    return ws_token
def verify_ws_token(temp_token: str) -> bool:
    """Validate and consume a temporary WebSocket token (single use).

    The token must exist, be unexpired, and map to a session that the token
    manager still accepts. On success — or on session failure — the entry is
    removed so the token can never be replayed.
    """
    _cleanup_expired_ws_tokens()
    entry = _ws_temp_tokens.get(temp_token)
    if entry is None:
        logger.warning(f"WS token 不存在: {temp_token[:8]}...")
        return False
    expire_time, session_token = entry
    if time.time() > expire_time:
        del _ws_temp_tokens[temp_token]
        logger.warning(f"WS token 已过期: {temp_token[:8]}...")
        return False
    # The backing session must still be valid; otherwise invalidate this token.
    if not get_token_manager().verify_token(session_token):
        del _ws_temp_tokens[temp_token]
        logger.warning(f"WS token 关联的 session 已失效: {temp_token[:8]}...")
        return False
    # Single use: consume the token on success.
    del _ws_temp_tokens[temp_token]
    logger.debug(f"WS token 验证成功: {temp_token[:8]}...")
    return True
@router.get("/ws-token")
async def get_ws_token(
    maibot_session: Optional[str] = Cookie(None),
    authorization: Optional[str] = Header(None),
):
    """Issue a temporary token for WebSocket handshake authentication.

    Validates the caller's session (cookie first, then Bearer header) and
    returns a one-shot token valid for 60 seconds. Unauthenticated callers
    get HTTP 200 with ``success=False`` instead of 401 so the login page
    does not trigger the front-end's 401 refresh handling.
    """
    session_token = maibot_session
    if not session_token and authorization and authorization.startswith("Bearer "):
        session_token = authorization.replace("Bearer ", "")

    if not session_token:
        # Normal on the login page — report politely rather than erroring.
        logger.debug("ws-token 请求:未提供认证信息(可能在登录页面)")
        return {"success": False, "message": "未提供认证信息,请先登录", "token": None, "expires_in": 0}

    if not get_token_manager().verify_token(session_token):
        logger.debug("ws-token 请求:认证已过期")
        return {"success": False, "message": "认证已过期,请重新登录", "token": None, "expires_in": 0}

    return {"success": True, "token": generate_ws_token(session_token), "expires_in": _WS_TOKEN_EXPIRE_SECONDS}

View File

@@ -0,0 +1,177 @@
"""WebSocket 日志推送模块"""
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query
from typing import Set, Optional
import json
from pathlib import Path
from src.common.logger import get_logger
from src.webui.core import get_token_manager
from src.webui.routers.websocket.auth import verify_ws_token
logger = get_logger("webui.logs_ws")
router = APIRouter()
# 全局 WebSocket 连接池
active_connections: Set[WebSocket] = set()
def load_recent_logs(limit: int = 100) -> list[dict]:
    """Load up to ``limit`` recent log entries from logs/app_*.log.jsonl.

    Files are scanned newest-first (by mtime) and read from the tail so the
    most recent entries are collected first; the returned list is in
    chronological order (oldest to newest).
    """
    log_dir = Path("logs")
    if not log_dir.exists():
        return []

    collected: list[dict] = []
    seq = 0  # counter that keeps generated ids unique within one call
    newest_first = sorted(log_dir.glob("app_*.log.jsonl"), key=lambda f: f.stat().st_mtime, reverse=True)
    for log_file in newest_first:
        if len(collected) >= limit:
            break
        try:
            with open(log_file, "r", encoding="utf-8") as fh:
                tail_first = reversed(fh.readlines())
            for raw in tail_first:
                if len(collected) >= limit:
                    break
                try:
                    entry = json.loads(raw.strip())
                    # Build a unique id from the (punctuation-stripped)
                    # timestamp plus a per-call counter.
                    ts_id = entry.get("timestamp", "0").replace("-", "").replace(" ", "").replace(":", "")
                    collected.append(
                        {
                            "id": f"{ts_id}_{seq}",
                            "timestamp": entry.get("timestamp", ""),
                            "level": entry.get("level", "INFO").upper(),
                            "module": entry.get("logger_name", ""),
                            "message": entry.get("event", ""),
                        }
                    )
                    seq += 1
                except (json.JSONDecodeError, KeyError):
                    continue  # skip malformed lines
        except Exception as e:
            logger.error(f"读取日志文件失败 {log_file}: {e}")
            continue
    collected.reverse()
    return collected
@router.websocket("/ws/logs")
async def websocket_logs(websocket: WebSocket, token: Optional[str] = Query(None)):
    """WebSocket log-push endpoint.

    After a successful handshake the client first receives up to 100
    historical log entries, then stays connected for live messages pushed
    via ``broadcast_log``.

    Three authentication methods, tried in priority order:
    1. ``token`` query parameter holding a one-shot temporary token from
       /api/webui/ws-token (recommended)
    2. ``maibot_session`` cookie
    3. the raw session token passed as the query parameter (legacy)

    Example: ws://host/ws/logs?token=xxx
    """
    is_authenticated = False
    # Method 1: one-shot temporary WebSocket token (preferred).
    if token and verify_ws_token(token):
        is_authenticated = True
        logger.debug("WebSocket 使用临时 token 认证成功")
    # Method 2: session token taken from the cookie.
    if not is_authenticated:
        cookie_token = websocket.cookies.get("maibot_session")
        if cookie_token:
            token_manager = get_token_manager()
            if token_manager.verify_token(cookie_token):
                is_authenticated = True
                logger.debug("WebSocket 使用 Cookie 认证成功")
    # Method 3: treat the query parameter as a raw session token (legacy).
    if not is_authenticated and token:
        token_manager = get_token_manager()
        if token_manager.verify_token(token):
            is_authenticated = True
            logger.debug("WebSocket 使用 session token 认证成功")
    if not is_authenticated:
        logger.warning("WebSocket 连接被拒绝:认证失败")
        await websocket.close(code=4001, reason="认证失败,请重新登录")
        return
    await websocket.accept()
    active_connections.add(websocket)
    logger.info(f"📡 WebSocket 客户端已连接(已认证),当前连接数: {len(active_connections)}")
    # Replay recent history immediately so the client starts with context.
    # Failures here are non-fatal; the live stream below still works.
    try:
        recent_logs = load_recent_logs(limit=100)
        logger.info(f"发送 {len(recent_logs)} 条历史日志到客户端")
        for log_entry in recent_logs:
            await websocket.send_text(json.dumps(log_entry, ensure_ascii=False))
    except Exception as e:
        logger.error(f"发送历史日志失败: {e}")
    try:
        # Keep the connection open; live entries arrive via broadcast_log().
        # Incoming client messages are used only for heartbeat/control.
        while True:
            data = await websocket.receive_text()
            if data == "ping":
                await websocket.send_text("pong")
    except WebSocketDisconnect:
        active_connections.discard(websocket)
        logger.info(f"📡 WebSocket 客户端已断开,当前连接数: {len(active_connections)}")
    except Exception as e:
        logger.error(f"❌ WebSocket 错误: {e}")
        active_connections.discard(websocket)
async def broadcast_log(log_data: dict):
    """Send one log record to every connected WebSocket client.

    Clients whose send fails are assumed dead and dropped from the pool.
    """
    if not active_connections:
        return
    payload = json.dumps(log_data, ensure_ascii=False)
    dead = set()
    for ws in active_connections:
        try:
            await ws.send_text(payload)
        except Exception:
            dead.add(ws)  # failed send -> mark for removal
    if dead:
        active_connections.difference_update(dead)
        logger.debug(f"清理了 {len(dead)} 个断开的 WebSocket 连接")

View File

@@ -0,0 +1,164 @@
"""WebSocket 插件加载进度推送模块"""
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query
from typing import Set, Dict, Any, Optional
import json
import asyncio
from src.common.logger import get_logger
from src.webui.core import get_token_manager
from src.webui.routers.websocket.auth import verify_ws_token
logger = get_logger("webui.plugin_progress")
# 创建路由器
router = APIRouter()
# 全局 WebSocket 连接池
active_connections: Set[WebSocket] = set()
# 当前加载进度状态
current_progress: Dict[str, Any] = {
"operation": "idle", # idle, fetch, install, uninstall, update
"stage": "idle", # idle, loading, success, error
"progress": 0, # 0-100
"message": "",
"error": None,
"plugin_id": None, # 当前操作的插件 ID
"total_plugins": 0,
"loaded_plugins": 0,
}
async def broadcast_progress(progress_data: Dict[str, Any]):
    """Broadcast a progress snapshot to all clients and cache it globally.

    The snapshot is copied into ``current_progress`` so late-joining clients
    can receive the latest state right after connecting.
    """
    global current_progress
    current_progress = progress_data.copy()
    if not active_connections:
        return
    payload = json.dumps(progress_data, ensure_ascii=False)
    dead = set()
    for ws in active_connections:
        try:
            await ws.send_text(payload)
        except Exception as e:
            logger.error(f"发送进度更新失败: {e}")
            dead.add(ws)
    # Drop every connection whose send failed.
    for ws in dead:
        active_connections.discard(ws)
async def update_progress(
    stage: str,
    progress: int,
    message: str,
    operation: str = "fetch",
    error: Optional[str] = None,
    plugin_id: Optional[str] = None,
    total_plugins: int = 0,
    loaded_plugins: int = 0,
):
    """Build a progress snapshot and broadcast it to all connected clients.

    Args:
        stage: Stage name (idle, loading, success, error).
        progress: Percentage complete (0-100).
        message: Human-readable progress message.
        operation: Operation type (fetch, install, uninstall, update).
        error: Optional error description.
        plugin_id: ID of the plugin currently being operated on, if any.
        total_plugins: Total number of plugins involved.
        loaded_plugins: Number of plugins already processed.

    Note: ``error`` and ``plugin_id`` were previously annotated ``str = None``
    (implicit Optional, rejected by modern type checkers); fixed to
    ``Optional[str]`` with unchanged defaults.
    """
    progress_data = {
        "operation": operation,
        "stage": stage,
        "progress": progress,
        "message": message,
        "error": error,
        "plugin_id": plugin_id,
        "total_plugins": total_plugins,
        "loaded_plugins": loaded_plugins,
        # FIX: asyncio.get_event_loop() is deprecated inside coroutines;
        # get_running_loop() is the documented replacement and yields the
        # same monotonic loop clock value.
        "timestamp": asyncio.get_running_loop().time(),
    }
    await broadcast_progress(progress_data)
    logger.debug(f"进度更新: [{operation}] {stage} - {progress}% - {message}")
@router.websocket("/ws/plugin-progress")
async def websocket_plugin_progress(websocket: WebSocket, token: Optional[str] = Query(None)):
    """WebSocket endpoint pushing plugin-load progress updates.

    On connect the client immediately receives the current progress snapshot,
    then live updates pushed via ``broadcast_progress``.

    Three authentication methods, tried in priority order:
    1. ``token`` query parameter holding a one-shot temporary token from
       /api/webui/ws-token (recommended)
    2. ``maibot_session`` cookie
    3. the raw session token passed as the query parameter (legacy)

    Example: ws://host/ws/plugin-progress?token=xxx
    """
    is_authenticated = False
    # Method 1: one-shot temporary WebSocket token (preferred).
    if token and verify_ws_token(token):
        is_authenticated = True
        logger.debug("插件进度 WebSocket 使用临时 token 认证成功")
    # Method 2: session token taken from the cookie.
    if not is_authenticated:
        cookie_token = websocket.cookies.get("maibot_session")
        if cookie_token:
            token_manager = get_token_manager()
            if token_manager.verify_token(cookie_token):
                is_authenticated = True
                logger.debug("插件进度 WebSocket 使用 Cookie 认证成功")
    # Method 3: treat the query parameter as a raw session token (legacy).
    if not is_authenticated and token:
        token_manager = get_token_manager()
        if token_manager.verify_token(token):
            is_authenticated = True
            logger.debug("插件进度 WebSocket 使用 session token 认证成功")
    if not is_authenticated:
        logger.warning("插件进度 WebSocket 连接被拒绝:认证失败")
        await websocket.close(code=4001, reason="认证失败,请重新登录")
        return
    await websocket.accept()
    active_connections.add(websocket)
    logger.info(f"📡 插件进度 WebSocket 客户端已连接(已认证),当前连接数: {len(active_connections)}")
    try:
        # Send the latest known snapshot so the client is never blank.
        await websocket.send_text(json.dumps(current_progress, ensure_ascii=False))
        # Keep the connection open and answer heartbeats.
        while True:
            try:
                data = await websocket.receive_text()
                if data == "ping":
                    await websocket.send_text("pong")
            except WebSocketDisconnect:
                # BUG FIX: WebSocketDisconnect subclasses Exception, so the
                # generic handler below used to swallow it, log it as an
                # error, and break out WITHOUT removing the socket from
                # active_connections. Re-raise so the outer handler performs
                # proper cleanup on a normal disconnect.
                raise
            except Exception as e:
                logger.error(f"处理客户端消息时出错: {e}")
                break
    except WebSocketDisconnect:
        active_connections.discard(websocket)
        logger.info(f"📡 插件进度 WebSocket 客户端已断开,当前连接数: {len(active_connections)}")
    except Exception as e:
        logger.error(f"❌ WebSocket 错误: {e}")
        active_connections.discard(websocket)
def get_progress_router() -> APIRouter:
    """Expose the module-level plugin-progress WebSocket router."""
    return router

462
src/webui/routes.py Normal file
View File

@@ -0,0 +1,462 @@
"""WebUI API 路由"""
from fastapi import APIRouter, HTTPException, Header, Response, Request, Cookie, Depends
from pydantic import BaseModel, Field
from typing import Optional
from src.common.logger import get_logger
from src.webui.core import (
get_token_manager,
set_auth_cookie,
clear_auth_cookie,
get_rate_limiter,
check_auth_rate_limit,
)
from src.webui.routers.config import router as config_router
from src.webui.routers.statistics import router as statistics_router
from src.webui.routers.person import router as person_router
from src.webui.routers.expression import router as expression_router
from src.webui.routers.jargon import router as jargon_router
from src.webui.routers.emoji import router as emoji_router
from src.webui.routers.plugin import router as plugin_router
from src.webui.routers.websocket.plugin_progress import get_progress_router
from src.webui.routers.system import router as system_router
from src.webui.routers.model import router as model_router
from src.webui.routers.websocket.auth import router as ws_auth_router
from src.webui.routers.annual_report import router as annual_report_router
logger = get_logger("webui.api")
# 创建路由器
router = APIRouter(prefix="/api/webui", tags=["WebUI"])
# 注册配置管理路由
router.include_router(config_router)
# 注册统计数据路由
router.include_router(statistics_router)
# 注册人物信息管理路由
router.include_router(person_router)
# 注册表达方式管理路由
router.include_router(expression_router)
# 注册黑话管理路由
router.include_router(jargon_router)
# 注册表情包管理路由
router.include_router(emoji_router)
# 注册插件管理路由
router.include_router(plugin_router)
# 注册插件进度 WebSocket 路由
router.include_router(get_progress_router())
# 注册系统控制路由
router.include_router(system_router)
# 注册模型列表获取路由
router.include_router(model_router)
# 注册 WebSocket 认证路由
router.include_router(ws_auth_router)
# 注册年度报告路由
router.include_router(annual_report_router)
class TokenVerifyRequest(BaseModel):
    """Request body for POST /auth/verify.

    NOTE(review): this and the following response models duplicate the
    definitions in src/webui/schemas/auth.py — consider importing from there.
    """
    token: str = Field(..., description="访问令牌")
class TokenVerifyResponse(BaseModel):
    """Response body for POST /auth/verify."""
    valid: bool = Field(..., description="Token 是否有效")
    message: str = Field(..., description="验证结果消息")
    is_first_setup: bool = Field(False, description="是否为首次设置")
class TokenUpdateRequest(BaseModel):
    """Request body for POST /auth/update."""
    new_token: str = Field(..., description="新的访问令牌", min_length=10)
class TokenUpdateResponse(BaseModel):
    """Response body for POST /auth/update."""
    success: bool = Field(..., description="是否更新成功")
    message: str = Field(..., description="更新结果消息")
class TokenRegenerateResponse(BaseModel):
    """Response body for POST /auth/regenerate."""
    success: bool = Field(..., description="是否生成成功")
    token: str = Field(..., description="新生成的令牌")
    message: str = Field(..., description="生成结果消息")
class FirstSetupStatusResponse(BaseModel):
    """Response body for GET /setup/status."""
    is_first_setup: bool = Field(..., description="是否为首次配置")
    message: str = Field(..., description="状态消息")
class CompleteSetupResponse(BaseModel):
    """Response body for POST /setup/complete."""
    success: bool = Field(..., description="是否成功")
    message: str = Field(..., description="结果消息")
class ResetSetupResponse(BaseModel):
    """Response body for POST /setup/reset."""
    success: bool = Field(..., description="是否成功")
    message: str = Field(..., description="结果消息")
@router.get("/health")
async def health_check():
    """Liveness probe: report that the WebUI API service is up."""
    return {"status": "healthy", "service": "MaiBot WebUI"}
@router.post("/auth/verify", response_model=TokenVerifyResponse)
async def verify_token(
    request_body: TokenVerifyRequest,
    request: Request,
    response: Response,
    _rate_limit: None = Depends(check_auth_rate_limit),
):
    """
    Verify an access token; on success set the HttpOnly auth cookie.

    Args:
        request_body: Verification request carrying the token.
        request: FastAPI Request object (used to identify the client IP).
        response: FastAPI Response object (cookie is attached here).
        _rate_limit: Dependency that rejects the request early when the
            client IP is currently rate-limit blocked.

    Returns:
        Verification result, including the first-setup status.
    """
    try:
        token_manager = get_token_manager()
        rate_limiter = get_rate_limiter()
        is_valid = token_manager.verify_token(request_body.token)
        if is_valid:
            # Successful auth: reset this IP's failure counter.
            rate_limiter.reset_failures(request)
            # Set the HttpOnly cookie (request passed to detect the scheme).
            set_auth_cookie(response, request_body.token, request)
            # Include first-setup status to save the client an extra request.
            is_first_setup = token_manager.is_first_setup()
            return TokenVerifyResponse(valid=True, message="Token 验证成功", is_first_setup=is_first_setup)
        else:
            # Record the failed attempt; may trigger a temporary IP ban.
            blocked, remaining = rate_limiter.record_failed_attempt(
                request,
                max_failures=5,  # block after 5 failures
                window_seconds=300,  # counted within a 5-minute window
                block_duration=600,  # ban lasts 10 minutes
            )
            if blocked:
                raise HTTPException(status_code=429, detail="认证失败次数过多,您的 IP 已被临时封禁 10 分钟")
            message = "Token 无效或已过期"
            if remaining <= 2:
                message += f"(剩余 {remaining} 次尝试机会)"
            return TokenVerifyResponse(valid=False, message=message)
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Token 验证失败: {e}")
        raise HTTPException(status_code=500, detail="Token 验证失败") from e
@router.post("/auth/logout")
async def logout(response: Response):
    """Log the caller out by clearing the HttpOnly auth cookie."""
    clear_auth_cookie(response)
    return {"success": True, "message": "已成功登出"}
@router.get("/auth/check")
async def check_auth_status(
    request: Request,
    maibot_session: Optional[str] = Cookie(None),
    authorization: Optional[str] = Header(None),
):
    """Report whether the caller currently holds a valid session.

    The token is taken from the cookie first, then from a Bearer header.
    Internal failures are reported as unauthenticated rather than as errors.
    """
    try:
        # NOTE(review): this logs a token prefix at DEBUG level — confirm that
        # is acceptable for the deployment's log retention policy.
        logger.debug(
            f"检查认证状态 - Cookie: {maibot_session[:20] if maibot_session else 'None'}..., Authorization: {'Present' if authorization else 'None'}"
        )
        token = None
        if maibot_session:
            token = maibot_session
            logger.debug("使用 Cookie 中的 token")
        elif authorization and authorization.startswith("Bearer "):
            token = authorization.replace("Bearer ", "")
            logger.debug("使用 Header 中的 token")
        if not token:
            logger.debug("未找到 token返回未认证")
            return {"authenticated": False}
        is_valid = get_token_manager().verify_token(token)
        logger.debug(f"Token 验证结果: {is_valid}")
        return {"authenticated": bool(is_valid)}
    except Exception as e:
        logger.error(f"认证检查失败: {e}", exc_info=True)
        return {"authenticated": False}
@router.post("/auth/update", response_model=TokenUpdateResponse)
async def update_token(
    request: TokenUpdateRequest,
    response: Response,
    req: Request,
    maibot_session: Optional[str] = Cookie(None),
    authorization: Optional[str] = Header(None),
):
    """Replace the access token; requires a valid current token.

    On a successful update the auth cookie is cleared, forcing the user to
    log in again with the new token.
    """
    try:
        # Resolve the caller's current token: cookie first, then Bearer header.
        current_token = maibot_session
        if not current_token and authorization and authorization.startswith("Bearer "):
            current_token = authorization.replace("Bearer ", "")
        if not current_token:
            raise HTTPException(status_code=401, detail="未提供有效的认证信息")
        token_manager = get_token_manager()
        if not token_manager.verify_token(current_token):
            raise HTTPException(status_code=401, detail="当前 Token 无效")
        success, message = token_manager.update_token(request.new_token)
        if success:
            # Invalidate the session cookie so the user must re-authenticate.
            clear_auth_cookie(response)
        return TokenUpdateResponse(success=success, message=message)
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Token 更新失败: {e}")
        raise HTTPException(status_code=500, detail="Token 更新失败") from e
@router.post("/auth/regenerate", response_model=TokenRegenerateResponse)
async def regenerate_token(
    response: Response,
    request: Request,
    maibot_session: Optional[str] = Cookie(None),
    authorization: Optional[str] = Header(None),
):
    """
    Regenerate the access token; requires a valid current token.

    Args:
        response: FastAPI Response object (cookie is cleared here).
        maibot_session: Token from the cookie.
        authorization: Authorization header (Bearer token).

    Returns:
        The newly generated token.
    """
    try:
        # Resolve the caller's current token: cookie first, then Bearer header.
        current_token = None
        if maibot_session:
            current_token = maibot_session
        elif authorization and authorization.startswith("Bearer "):
            current_token = authorization.replace("Bearer ", "")
        if not current_token:
            raise HTTPException(status_code=401, detail="未提供有效的认证信息")
        token_manager = get_token_manager()
        if not token_manager.verify_token(current_token):
            raise HTTPException(status_code=401, detail="当前 Token 无效")
        # Generate a fresh token.
        new_token = token_manager.regenerate_token()
        # Clear the cookie: the user must log in again with the new token.
        clear_auth_cookie(response)
        return TokenRegenerateResponse(success=True, token=new_token, message="Token 已重新生成")
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Token 重新生成失败: {e}")
        raise HTTPException(status_code=500, detail="Token 重新生成失败") from e
@router.get("/setup/status", response_model=FirstSetupStatusResponse)
async def get_setup_status(
    request: Request,
    maibot_session: Optional[str] = Cookie(None),
    authorization: Optional[str] = Header(None),
):
    """Tell whether the first-run setup wizard still needs to be completed.

    Requires a valid session token (cookie or Bearer header).
    """
    try:
        # Resolve the caller's token: cookie first, then Bearer header.
        current_token = maibot_session
        if not current_token and authorization and authorization.startswith("Bearer "):
            current_token = authorization.replace("Bearer ", "")
        if not current_token:
            raise HTTPException(status_code=401, detail="未提供有效的认证信息")
        token_manager = get_token_manager()
        if not token_manager.verify_token(current_token):
            raise HTTPException(status_code=401, detail="Token 无效")
        first_run = token_manager.is_first_setup()
        return FirstSetupStatusResponse(
            is_first_setup=first_run,
            message="首次配置" if first_run else "已完成配置",
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"获取配置状态失败: {e}")
        raise HTTPException(status_code=500, detail="获取配置状态失败") from e
@router.post("/setup/complete", response_model=CompleteSetupResponse)
async def complete_setup(
    request: Request,
    maibot_session: Optional[str] = Cookie(None),
    authorization: Optional[str] = Header(None),
):
    """
    Mark the first-run setup wizard as completed.

    Args:
        maibot_session: Token from the cookie.
        authorization: Authorization header (Bearer token).

    Returns:
        Completion result.
    """
    try:
        # Resolve the caller's token: cookie first, then Bearer header.
        current_token = None
        if maibot_session:
            current_token = maibot_session
        elif authorization and authorization.startswith("Bearer "):
            current_token = authorization.replace("Bearer ", "")
        if not current_token:
            raise HTTPException(status_code=401, detail="未提供有效的认证信息")
        token_manager = get_token_manager()
        if not token_manager.verify_token(current_token):
            raise HTTPException(status_code=401, detail="Token 无效")
        # Persist the "setup completed" flag.
        success = token_manager.mark_setup_completed()
        return CompleteSetupResponse(success=success, message="配置已完成" if success else "标记失败")
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"标记配置完成失败: {e}")
        raise HTTPException(status_code=500, detail="标记配置完成失败") from e
@router.post("/setup/reset", response_model=ResetSetupResponse)
async def reset_setup(
    request: Request,
    maibot_session: Optional[str] = Cookie(None),
    authorization: Optional[str] = Header(None),
):
    """Reset the first-run setup flag so the wizard can be re-entered.

    Requires a valid session token (cookie or Bearer header).
    """
    try:
        # Resolve the caller's token: cookie first, then Bearer header.
        current_token = maibot_session
        if not current_token and authorization and authorization.startswith("Bearer "):
            current_token = authorization.replace("Bearer ", "")
        if not current_token:
            raise HTTPException(status_code=401, detail="未提供有效的认证信息")
        token_manager = get_token_manager()
        if not token_manager.verify_token(current_token):
            raise HTTPException(status_code=401, detail="Token 无效")
        success = token_manager.reset_setup_status()
        return ResetSetupResponse(success=success, message="配置状态已重置" if success else "重置失败")
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"重置配置状态失败: {e}")
        raise HTTPException(status_code=500, detail="重置配置状态失败") from e

View File

@@ -0,0 +1,109 @@
"""WebUI Schemas - Pydantic models for API requests and responses."""
# Auth schemas
from .auth import (
TokenVerifyRequest,
TokenVerifyResponse,
TokenUpdateRequest,
TokenUpdateResponse,
TokenRegenerateResponse,
FirstSetupStatusResponse,
CompleteSetupResponse,
ResetSetupResponse,
)
# Statistics schemas
from .statistics import (
StatisticsSummary,
ModelStatistics,
TimeSeriesData,
DashboardData,
)
# Emoji schemas
from .emoji import (
EmojiResponse,
EmojiListResponse,
EmojiDetailResponse,
EmojiUpdateRequest,
EmojiUpdateResponse,
EmojiDeleteResponse,
BatchDeleteRequest,
BatchDeleteResponse,
EmojiUploadResponse,
ThumbnailCacheStatsResponse,
ThumbnailCleanupResponse,
ThumbnailPreheatResponse,
)
# Chat schemas
from .chat import (
VirtualIdentityConfig,
ChatHistoryMessage,
)
# Plugin schemas
from .plugin import (
FetchRawFileRequest,
FetchRawFileResponse,
CloneRepositoryRequest,
CloneRepositoryResponse,
MirrorConfigResponse,
AvailableMirrorsResponse,
AddMirrorRequest,
UpdateMirrorRequest,
GitStatusResponse,
InstallPluginRequest,
VersionResponse,
UninstallPluginRequest,
UpdatePluginRequest,
UpdatePluginConfigRequest,
)
__all__ = [
# Auth
"TokenVerifyRequest",
"TokenVerifyResponse",
"TokenUpdateRequest",
"TokenUpdateResponse",
"TokenRegenerateResponse",
"FirstSetupStatusResponse",
"CompleteSetupResponse",
"ResetSetupResponse",
# Statistics
"StatisticsSummary",
"ModelStatistics",
"TimeSeriesData",
"DashboardData",
# Emoji
"EmojiResponse",
"EmojiListResponse",
"EmojiDetailResponse",
"EmojiUpdateRequest",
"EmojiUpdateResponse",
"EmojiDeleteResponse",
"BatchDeleteRequest",
"BatchDeleteResponse",
"EmojiUploadResponse",
"ThumbnailCacheStatsResponse",
"ThumbnailCleanupResponse",
"ThumbnailPreheatResponse",
# Chat
"VirtualIdentityConfig",
"ChatHistoryMessage",
# Plugin
"FetchRawFileRequest",
"FetchRawFileResponse",
"CloneRepositoryRequest",
"CloneRepositoryResponse",
"MirrorConfigResponse",
"AvailableMirrorsResponse",
"AddMirrorRequest",
"UpdateMirrorRequest",
"GitStatusResponse",
"InstallPluginRequest",
"VersionResponse",
"UninstallPluginRequest",
"UpdatePluginRequest",
"UpdatePluginConfigRequest",
]

41
src/webui/schemas/auth.py Normal file
View File

@@ -0,0 +1,41 @@
from pydantic import BaseModel, Field
class TokenVerifyRequest(BaseModel):
    """Request body for POST /auth/verify."""
    token: str = Field(..., description="访问令牌")
class TokenVerifyResponse(BaseModel):
    """Response body for POST /auth/verify."""
    valid: bool = Field(..., description="Token 是否有效")
    message: str = Field(..., description="验证结果消息")
    is_first_setup: bool = Field(False, description="是否为首次设置")
class TokenUpdateRequest(BaseModel):
    """Request body for POST /auth/update."""
    new_token: str = Field(..., description="新的访问令牌", min_length=10)
class TokenUpdateResponse(BaseModel):
    """Response body for POST /auth/update."""
    success: bool = Field(..., description="是否更新成功")
    message: str = Field(..., description="更新结果消息")
class TokenRegenerateResponse(BaseModel):
    """Response body for POST /auth/regenerate."""
    success: bool = Field(..., description="是否生成成功")
    token: str = Field(..., description="新生成的令牌")
    message: str = Field(..., description="生成结果消息")
class FirstSetupStatusResponse(BaseModel):
    """Response body for GET /setup/status."""
    is_first_setup: bool = Field(..., description="是否为首次配置")
    message: str = Field(..., description="状态消息")
class CompleteSetupResponse(BaseModel):
    """Response body for POST /setup/complete."""
    success: bool = Field(..., description="是否成功")
    message: str = Field(..., description="结果消息")
class ResetSetupResponse(BaseModel):
    """Result of resetting the WebUI back to first-time setup state."""

    success: bool = Field(..., description="是否成功")
    message: str = Field(..., description="结果消息")

26
src/webui/schemas/chat.py Normal file
View File

@@ -0,0 +1,26 @@
from pydantic import BaseModel
from typing import Optional
class VirtualIdentityConfig(BaseModel):
    """Virtual identity configuration.

    Describes the platform/user/group identity used by the WebUI chat.
    NOTE(review): the identity fields presumably only matter when
    ``enabled`` is True — confirm against callers.
    """

    enabled: bool = False
    platform: Optional[str] = None
    person_id: Optional[str] = None
    user_id: Optional[str] = None
    user_nickname: Optional[str] = None
    group_id: Optional[str] = None
    group_name: Optional[str] = None
class ChatHistoryMessage(BaseModel):
    """A single message in the WebUI chat history."""

    id: str
    type: str  # one of: 'user' | 'bot' | 'system'
    content: str
    timestamp: float  # presumably a Unix timestamp in seconds — verify at call sites
    sender_name: str
    sender_id: Optional[str] = None
    is_bot: bool = False

115
src/webui/schemas/emoji.py Normal file
View File

@@ -0,0 +1,115 @@
from pydantic import BaseModel
from typing import Optional, List
class EmojiResponse(BaseModel):
    """Single emoji (sticker) record as returned by the emoji API."""

    id: int
    full_path: str
    format: str  # image format string, e.g. "png"/"gif" — presumably; confirm upstream
    emoji_hash: str
    description: str
    query_count: int
    is_registered: bool
    is_banned: bool
    emotion: Optional[str]  # required field, but may be null
    record_time: float
    register_time: Optional[float]
    usage_count: int
    last_used_time: Optional[float]
class EmojiListResponse(BaseModel):
    """Paginated emoji list response."""

    success: bool
    total: int  # total record count across all pages
    page: int
    page_size: int
    data: List[EmojiResponse]
class EmojiDetailResponse(BaseModel):
    """Detail response for a single emoji."""

    success: bool
    data: EmojiResponse
class EmojiUpdateRequest(BaseModel):
    """Partial-update request for an emoji.

    All fields are optional; presumably omitted fields are left
    unchanged by the handler — confirm against the update route.
    """

    description: Optional[str] = None
    is_registered: Optional[bool] = None
    is_banned: Optional[bool] = None
    emotion: Optional[str] = None
class EmojiUpdateResponse(BaseModel):
    """Result of an emoji update; carries the updated record on success."""

    success: bool
    message: str
    data: Optional[EmojiResponse] = None
class EmojiDeleteResponse(BaseModel):
    """Result of deleting a single emoji."""

    success: bool
    message: str
class BatchDeleteRequest(BaseModel):
    """Request to delete multiple emojis by id."""

    emoji_ids: List[int]
class BatchDeleteResponse(BaseModel):
    """Per-batch outcome of a bulk emoji delete."""

    success: bool
    message: str
    deleted_count: int
    failed_count: int
    # pydantic deep-copies mutable defaults per instance, so [] is safe here.
    failed_ids: List[int] = []
class EmojiUploadResponse(BaseModel):
    """Result of an emoji upload; carries the created record on success."""

    success: bool
    message: str
    data: Optional[EmojiResponse] = None
class ThumbnailCacheStatsResponse(BaseModel):
    """Statistics about the emoji thumbnail cache."""

    success: bool
    cache_dir: str
    total_count: int
    total_size_mb: float
    emoji_count: int
    coverage_percent: float  # presumably the share of emojis with a cached thumbnail — verify
class ThumbnailCleanupResponse(BaseModel):
    """Result of a thumbnail-cache cleanup pass."""

    success: bool
    message: str
    cleaned_count: int
    kept_count: int
class ThumbnailPreheatResponse(BaseModel):
    """Result of pre-generating (preheating) thumbnails."""

    success: bool
    message: str
    generated_count: int
    skipped_count: int
    failed_count: int

135
src/webui/schemas/plugin.py Normal file
View File

@@ -0,0 +1,135 @@
from pydantic import BaseModel, Field
from typing import Optional, List, Dict, Any
class FetchRawFileRequest(BaseModel):
    """Request to fetch a raw file from a repository, optionally via a mirror."""

    owner: str = Field(..., description="仓库所有者", example="MaiM-with-u")
    repo: str = Field(..., description="仓库名称", example="plugin-repo")
    branch: str = Field(..., description="分支名称", example="main")
    file_path: str = Field(..., description="文件路径", example="plugin_details.json")
    # When set, mirror_id pins a specific mirror; custom_url presumably
    # bypasses mirror resolution entirely — confirm in the fetch handler.
    mirror_id: Optional[str] = Field(None, description="指定镜像源 ID")
    custom_url: Optional[str] = Field(None, description="自定义完整 URL")
class FetchRawFileResponse(BaseModel):
    """Result of a raw-file fetch attempt."""

    success: bool = Field(..., description="是否成功")
    data: Optional[str] = Field(None, description="文件内容")
    error: Optional[str] = Field(None, description="错误信息")
    mirror_used: Optional[str] = Field(None, description="使用的镜像源")
    attempts: int = Field(..., description="尝试次数")
    url: Optional[str] = Field(None, description="实际请求的 URL")
class CloneRepositoryRequest(BaseModel):
    """Request to clone a git repository into the plugin directory."""

    owner: str = Field(..., description="仓库所有者", example="MaiM-with-u")
    repo: str = Field(..., description="仓库名称", example="plugin-repo")
    target_path: str = Field(..., description="目标路径(相对于插件目录)")
    branch: Optional[str] = Field(None, description="分支名称", example="main")
    mirror_id: Optional[str] = Field(None, description="指定镜像源 ID")
    custom_url: Optional[str] = Field(None, description="自定义克隆 URL")
    # ge=1 forbids zero/negative clone depths.
    depth: Optional[int] = Field(None, description="克隆深度(浅克隆)", ge=1)
class CloneRepositoryResponse(BaseModel):
    """Result of a repository clone attempt."""

    success: bool = Field(..., description="是否成功")
    path: Optional[str] = Field(None, description="克隆路径")
    error: Optional[str] = Field(None, description="错误信息")
    mirror_used: Optional[str] = Field(None, description="使用的镜像源")
    attempts: int = Field(..., description="尝试次数")
    url: Optional[str] = Field(None, description="实际克隆的 URL")
    message: Optional[str] = Field(None, description="附加信息")
class MirrorConfigResponse(BaseModel):
    """Configuration of a single download mirror."""

    id: str = Field(..., description="镜像源 ID")
    name: str = Field(..., description="镜像源名称")
    raw_prefix: str = Field(..., description="Raw 文件前缀")
    clone_prefix: str = Field(..., description="克隆前缀")
    enabled: bool = Field(..., description="是否启用")
    # Lower number = higher priority (per the description text).
    priority: int = Field(..., description="优先级(数字越小优先级越高)")
class AvailableMirrorsResponse(BaseModel):
    """List of configured mirrors plus the default priority order."""

    mirrors: List[MirrorConfigResponse] = Field(..., description="镜像源列表")
    default_priority: List[str] = Field(..., description="默认优先级顺序ID 列表)")
class AddMirrorRequest(BaseModel):
    """Request to register a new download mirror."""

    id: str = Field(..., description="镜像源 ID", example="custom-mirror")
    name: str = Field(..., description="镜像源名称", example="自定义镜像源")
    raw_prefix: str = Field(..., description="Raw 文件前缀", example="https://example.com/raw")
    clone_prefix: str = Field(..., description="克隆前缀", example="https://example.com/clone")
    enabled: bool = Field(True, description="是否启用")
    priority: Optional[int] = Field(None, description="优先级")
class UpdateMirrorRequest(BaseModel):
    """Partial-update request for an existing mirror; all fields optional."""

    name: Optional[str] = Field(None, description="镜像源名称")
    raw_prefix: Optional[str] = Field(None, description="Raw 文件前缀")
    clone_prefix: Optional[str] = Field(None, description="克隆前缀")
    enabled: Optional[bool] = Field(None, description="是否启用")
    priority: Optional[int] = Field(None, description="优先级")
class GitStatusResponse(BaseModel):
    """Reports whether git is installed and, if so, its version/path."""

    installed: bool = Field(..., description="是否已安装 Git")
    version: Optional[str] = Field(None, description="Git 版本号")
    path: Optional[str] = Field(None, description="Git 可执行文件路径")
    error: Optional[str] = Field(None, description="错误信息")
class InstallPluginRequest(BaseModel):
    """Request to install a plugin from a repository."""

    plugin_id: str = Field(..., description="插件 ID")
    repository_url: str = Field(..., description="插件仓库 URL")
    branch: Optional[str] = Field("main", description="分支名称")
    mirror_id: Optional[str] = Field(None, description="指定镜像源 ID")
class VersionResponse(BaseModel):
    """Application version, both as a string and split into components."""

    version: str = Field(..., description="麦麦版本号")
    version_major: int = Field(..., description="主版本号")
    version_minor: int = Field(..., description="次版本号")
    version_patch: int = Field(..., description="补丁版本号")
class UninstallPluginRequest(BaseModel):
    """Request to uninstall a plugin by id."""

    plugin_id: str = Field(..., description="插件 ID")
class UpdatePluginRequest(BaseModel):
    """Request to update an installed plugin from its repository."""

    plugin_id: str = Field(..., description="插件 ID")
    repository_url: str = Field(..., description="插件仓库 URL")
    branch: Optional[str] = Field("main", description="分支名称")
    mirror_id: Optional[str] = Field(None, description="指定镜像源 ID")
class UpdatePluginConfigRequest(BaseModel):
    """Request carrying a plugin's new configuration payload."""

    # Free-form mapping; the schema is defined by each plugin itself.
    config: Dict[str, Any] = Field(..., description="配置数据")

View File

@@ -0,0 +1,45 @@
from pydantic import BaseModel, Field
from typing import Dict, Any, List
class StatisticsSummary(BaseModel):
    """Aggregate usage statistics; all fields default to zero."""

    total_requests: int = Field(0, description="总请求数")
    total_cost: float = Field(0.0, description="总花费")
    total_tokens: int = Field(0, description="总token数")
    online_time: float = Field(0.0, description="在线时间(秒)")
    total_messages: int = Field(0, description="总消息数")
    total_replies: int = Field(0, description="总回复数")
    avg_response_time: float = Field(0.0, description="平均响应时间")
    cost_per_hour: float = Field(0.0, description="每小时花费")
    tokens_per_hour: float = Field(0.0, description="每小时token数")
class ModelStatistics(BaseModel):
    """Per-LLM-model usage statistics."""

    # NOTE(review): field names starting with "model_" trigger protected-
    # namespace warnings under pydantic v2 — confirm the pydantic version
    # in use and silence via protected_namespaces if needed.
    model_name: str
    request_count: int
    total_cost: float
    total_tokens: int
    avg_response_time: float
class TimeSeriesData(BaseModel):
    """One bucket of time-series usage data."""

    timestamp: str  # presumably an ISO-8601 or formatted label — confirm producer
    requests: int = 0
    cost: float = 0.0
    tokens: int = 0
class DashboardData(BaseModel):
    """Full payload backing the statistics dashboard."""

    summary: StatisticsSummary
    model_stats: List[ModelStatistics]
    hourly_data: List[TimeSeriesData]
    daily_data: List[TimeSeriesData]
    recent_activity: List[Dict[str, Any]]  # free-form activity entries

View File

@@ -0,0 +1 @@

View File

@@ -0,0 +1 @@

128
src/webui/webui_server.py Normal file
View File

@@ -0,0 +1,128 @@
"""独立的 WebUI 服务器 - 运行在 0.0.0.0:8001"""
import asyncio
from uvicorn import Config, Server as UvicornServer
from src.common.logger import get_logger
from src.webui.app import create_app, show_access_token
logger = get_logger("webui_server")
class WebUIServer:
    """Standalone WebUI server wrapping a uvicorn server instance.

    The FastAPI app is built eagerly at construction time (with static
    file serving enabled) and the access token is printed once so the
    operator can log in.
    """

    def __init__(self, host: str = "0.0.0.0", port: int = 8001):
        self.host = host
        self.port = port
        self.app = create_app(host=host, port=port, enable_static=True)
        # Holds the uvicorn Server while running; None otherwise.
        self._server = None
        show_access_token()

    async def start(self):
        """Start serving; raises OSError when the port is already taken."""
        # Pre-flight port check so we can log actionable hints instead of
        # a bare bind error.  NOTE(review): this is inherently TOCTOU —
        # another process may still grab the port before uvicorn binds.
        if not self._check_port_available():
            error_msg = f"❌ WebUI 服务器启动失败: 端口 {self.port} 已被占用"
            logger.error(error_msg)
            logger.error(f"💡 请检查是否有其他程序正在使用端口 {self.port}")
            logger.error("💡 可以在 .env 文件中修改 WEBUI_PORT 来更改 WebUI 端口")
            logger.error(f"💡 Windows 用户可以运行: netstat -ano | findstr :{self.port}")
            logger.error(f"💡 Linux/Mac 用户可以运行: lsof -i :{self.port}")
            raise OSError(f"端口 {self.port} 已被占用,无法启动 WebUI 服务器")
        config = Config(
            app=self.app,
            host=self.host,
            port=self.port,
            log_config=None,  # keep the project's logging config; silence uvicorn's
            access_log=False,
        )
        self._server = UvicornServer(config=config)
        logger.info("🌐 WebUI 服务器启动中...")
        # Display a correctly formatted URL for the bound address family.
        if ":" in self.host:
            # IPv6 literals must be wrapped in brackets inside URLs.
            logger.info(f"🌐 访问地址: http://[{self.host}]:{self.port}")
            if self.host == "::":
                logger.info(f"💡 IPv6 本机访问: http://[::1]:{self.port}")
                logger.info(f"💡 IPv4 本机访问: http://127.0.0.1:{self.port}")
            elif self.host == "::1":
                logger.info("💡 仅支持 IPv6 本地访问")
        else:
            # Plain IPv4 address.
            logger.info(f"🌐 访问地址: http://{self.host}:{self.port}")
            if self.host == "0.0.0.0":
                logger.info(f"💡 本机访问: http://localhost:{self.port} 或 http://127.0.0.1:{self.port}")
        try:
            await self._server.serve()
        except OSError as e:
            # Map bind failures to friendly hints.  NOTE(review): macOS uses
            # EADDRINUSE=48, which is only caught by the string match here.
            if "address already in use" in str(e).lower() or e.errno in (98, 10048):  # 98: Linux, 10048: Windows
                logger.error(f"❌ WebUI 服务器启动失败: 端口 {self.port} 已被占用")
                logger.error(f"💡 请检查是否有其他程序正在使用端口 {self.port}")
                logger.error("💡 可以在 .env 文件中修改 WEBUI_PORT 来更改 WebUI 端口")
            else:
                logger.error(f"❌ WebUI 服务器启动失败 (网络错误): {e}")
            raise
        except Exception as e:
            logger.error(f"❌ WebUI 服务器运行错误: {e}", exc_info=True)
            raise

    def _check_port_available(self) -> bool:
        """Return True when the configured port can be bound (IPv4 or IPv6)."""
        import socket

        # Pick the socket family matching the configured host literal.
        if ":" in self.host:
            # IPv6; wildcard "::" is probed via the loopback address instead.
            family = socket.AF_INET6
            test_host = self.host if self.host != "::" else "::1"
        else:
            # IPv4; wildcard "0.0.0.0" is probed via loopback instead.
            family = socket.AF_INET
            test_host = self.host if self.host != "0.0.0.0" else "127.0.0.1"
        try:
            with socket.socket(family, socket.SOCK_STREAM) as s:
                s.settimeout(1)
                # Bind succeeds only when no other listener owns the port.
                s.bind((test_host, self.port))
                return True
        except OSError:
            return False

    async def shutdown(self):
        """Ask uvicorn to exit and wait briefly for a clean shutdown."""
        if self._server:
            logger.info("正在关闭 WebUI 服务器...")
            # Signal uvicorn's serve() loop to stop and exit.
            self._server.should_exit = True
            try:
                await asyncio.wait_for(self._server.shutdown(), timeout=3.0)
                logger.info("✅ WebUI 服务器已关闭")
            except asyncio.TimeoutError:
                logger.warning("⚠️ WebUI 服务器关闭超时")
            except Exception as e:
                logger.error(f"❌ WebUI 服务器关闭失败: {e}")
            finally:
                self._server = None
# Lazily-created, process-wide WebUI server instance.
_webui_server = None


def get_webui_server() -> WebUIServer:
    """Return the global WebUI server, creating it on first call.

    Host and port are read from the WEBUI_HOST / WEBUI_PORT environment
    variables, defaulting to 127.0.0.1:8001.
    """
    global _webui_server
    if _webui_server is not None:
        return _webui_server

    import os

    _webui_server = WebUIServer(
        host=os.getenv("WEBUI_HOST", "127.0.0.1"),
        port=int(os.getenv("WEBUI_PORT", "8001")),
    )
    return _webui_server