Ruff Format
This commit is contained in:
@@ -455,6 +455,7 @@ class ExpressionSelector:
|
||||
expr_obj.save()
|
||||
logger.debug("表达方式激活: 更新last_active_time in db")
|
||||
|
||||
|
||||
try:
|
||||
expression_selector = ExpressionSelector()
|
||||
except Exception as e:
|
||||
|
||||
@@ -17,6 +17,7 @@ from src.bw_learner.learner_utils import (
|
||||
|
||||
logger = get_logger("jargon")
|
||||
|
||||
|
||||
class JargonExplainer:
|
||||
"""黑话解释器,用于在回复前识别和解释上下文中的黑话"""
|
||||
|
||||
|
||||
@@ -60,31 +60,31 @@ def calculate_style_similarity(style1: str, style2: str) -> float:
|
||||
"""
|
||||
计算两个 style 的相似度,返回0-1之间的值
|
||||
在计算前会移除"使用"和"句式"这两个词(参考 expression_similarity_analysis.py)
|
||||
|
||||
|
||||
Args:
|
||||
style1: 第一个 style
|
||||
style2: 第二个 style
|
||||
|
||||
|
||||
Returns:
|
||||
float: 相似度值,范围0-1
|
||||
"""
|
||||
if not style1 or not style2:
|
||||
return 0.0
|
||||
|
||||
|
||||
# 移除"使用"和"句式"这两个词
|
||||
def remove_ignored_words(text: str) -> str:
|
||||
"""移除需要忽略的词"""
|
||||
text = text.replace("使用", "")
|
||||
text = text.replace("句式", "")
|
||||
return text.strip()
|
||||
|
||||
|
||||
cleaned_style1 = remove_ignored_words(style1)
|
||||
cleaned_style2 = remove_ignored_words(style2)
|
||||
|
||||
|
||||
# 如果清理后文本为空,返回0
|
||||
if not cleaned_style1 or not cleaned_style2:
|
||||
return 0.0
|
||||
|
||||
|
||||
return difflib.SequenceMatcher(None, cleaned_style1, cleaned_style2).ratio()
|
||||
|
||||
|
||||
@@ -495,4 +495,4 @@ def parse_expression_response(response: str) -> Tuple[List[Tuple[str, str, str]]
|
||||
if content and source_id:
|
||||
jargon_entries.append((content, source_id))
|
||||
|
||||
return expressions, jargon_entries
|
||||
return expressions, jargon_entries
|
||||
|
||||
@@ -2,7 +2,6 @@ import time
|
||||
import asyncio
|
||||
from typing import List, Any
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat_inclusive
|
||||
from src.chat.utils.common_utils import TempMethodsExpression
|
||||
@@ -119,9 +118,7 @@ class MessageRecorder:
|
||||
|
||||
# 触发 expression_learner 和 jargon_miner 的处理
|
||||
if self.enable_expression_learning:
|
||||
asyncio.create_task(
|
||||
self._trigger_expression_learning(messages)
|
||||
)
|
||||
asyncio.create_task(self._trigger_expression_learning(messages))
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"为聊天流 {self.chat_name} 提取和分发消息失败: {e}")
|
||||
@@ -130,9 +127,7 @@ class MessageRecorder:
|
||||
traceback.print_exc()
|
||||
# 即使失败也保持时间戳更新,避免频繁重试
|
||||
|
||||
async def _trigger_expression_learning(
|
||||
self, messages: List[Any]
|
||||
) -> None:
|
||||
async def _trigger_expression_learning(self, messages: List[Any]) -> None:
|
||||
"""
|
||||
触发 expression 学习,使用指定的消息列表
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import time
|
||||
from typing import Tuple, Optional, Dict, Any # 增加了 Optional
|
||||
from typing import Tuple, Optional # 增加了 Optional
|
||||
from src.common.logger import get_logger
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config, model_config
|
||||
@@ -120,7 +120,7 @@ class ActionPlanner:
|
||||
def _get_personality_prompt(self) -> str:
|
||||
"""获取个性提示信息"""
|
||||
prompt_personality = global_config.personality.personality
|
||||
|
||||
|
||||
# 检查是否需要随机替换为状态
|
||||
if (
|
||||
global_config.personality.states
|
||||
@@ -128,7 +128,7 @@ class ActionPlanner:
|
||||
and random.random() < global_config.personality.state_probability
|
||||
):
|
||||
prompt_personality = random.choice(global_config.personality.states)
|
||||
|
||||
|
||||
bot_name = global_config.bot.nickname
|
||||
return f"你的名字是{bot_name},你{prompt_personality};"
|
||||
|
||||
@@ -170,13 +170,10 @@ class ActionPlanner:
|
||||
)
|
||||
break
|
||||
else:
|
||||
logger.debug(
|
||||
f"[私聊][{self.private_name}]聊天历史为空或尚未加载,跳过 Bot 发言时间检查。"
|
||||
)
|
||||
logger.debug(f"[私聊][{self.private_name}]聊天历史为空或尚未加载,跳过 Bot 发言时间检查。")
|
||||
except Exception as e:
|
||||
logger.debug(f"[私聊][{self.private_name}]获取 Bot 上次发言时间时出错: {e}")
|
||||
|
||||
|
||||
# --- 获取超时提示信息 ---
|
||||
# (这部分逻辑不变)
|
||||
timeout_context = ""
|
||||
|
||||
@@ -112,10 +112,10 @@ class Conversation:
|
||||
"user_nickname": msg.user_info.user_nickname if msg.user_info else "",
|
||||
"user_cardname": msg.user_info.user_cardname if msg.user_info else None,
|
||||
"platform": msg.user_info.platform if msg.user_info else "",
|
||||
}
|
||||
},
|
||||
}
|
||||
initial_messages_dict.append(msg_dict)
|
||||
|
||||
|
||||
# 将加载的消息填充到 ObservationInfo 的 chat_history
|
||||
self.observation_info.chat_history = initial_messages_dict
|
||||
self.observation_info.chat_history_str = chat_talking_prompt + "\n"
|
||||
|
||||
@@ -66,9 +66,9 @@ class DirectMessageSender:
|
||||
|
||||
# 发送消息(直接调用底层 API)
|
||||
from src.chat.message_receive.uni_message_sender import _send_message
|
||||
|
||||
|
||||
sent = await _send_message(message, show_log=True)
|
||||
|
||||
|
||||
if sent:
|
||||
# 存储消息
|
||||
await self.storage.store_message(message, chat_stream)
|
||||
|
||||
@@ -5,7 +5,7 @@ from src.common.logger import get_logger
|
||||
from .chat_observer import ChatObserver
|
||||
from .chat_states import NotificationHandler, NotificationType, Notification
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages
|
||||
from src.common.data_models.database_data_model import DatabaseMessages, DatabaseUserInfo
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
import traceback # 导入 traceback 用于调试
|
||||
|
||||
logger = get_logger("observation_info")
|
||||
@@ -13,15 +13,15 @@ logger = get_logger("observation_info")
|
||||
|
||||
def dict_to_database_message(msg_dict: Dict[str, Any]) -> DatabaseMessages:
|
||||
"""Convert PFC dict format to DatabaseMessages object
|
||||
|
||||
|
||||
Args:
|
||||
msg_dict: Message in PFC dict format with nested user_info
|
||||
|
||||
|
||||
Returns:
|
||||
DatabaseMessages object compatible with build_readable_messages()
|
||||
"""
|
||||
user_info_dict: Dict[str, Any] = msg_dict.get("user_info", {})
|
||||
|
||||
|
||||
return DatabaseMessages(
|
||||
message_id=msg_dict.get("message_id", ""),
|
||||
time=msg_dict.get("time", 0.0),
|
||||
|
||||
@@ -42,9 +42,7 @@ class GoalAnalyzer:
|
||||
"""对话目标分析器"""
|
||||
|
||||
def __init__(self, stream_id: str, private_name: str):
|
||||
self.llm = LLMRequest(
|
||||
model_set=model_config.model_task_config.planner, request_type="conversation_goal"
|
||||
)
|
||||
self.llm = LLMRequest(model_set=model_config.model_task_config.planner, request_type="conversation_goal")
|
||||
|
||||
self.personality_info = self._get_personality_prompt()
|
||||
self.name = global_config.bot.nickname
|
||||
@@ -60,7 +58,7 @@ class GoalAnalyzer:
|
||||
def _get_personality_prompt(self) -> str:
|
||||
"""获取个性提示信息"""
|
||||
prompt_personality = global_config.personality.personality
|
||||
|
||||
|
||||
# 检查是否需要随机替换为状态
|
||||
if (
|
||||
global_config.personality.states
|
||||
@@ -68,7 +66,7 @@ class GoalAnalyzer:
|
||||
and random.random() < global_config.personality.state_probability
|
||||
):
|
||||
prompt_personality = random.choice(global_config.personality.states)
|
||||
|
||||
|
||||
bot_name = global_config.bot.nickname
|
||||
return f"你的名字是{bot_name},你{prompt_personality};"
|
||||
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
from typing import List, Tuple, Dict, Any
|
||||
from src.common.logger import get_logger
|
||||
|
||||
# NOTE: HippocampusManager doesn't exist in v0.12.2 - memory system was redesigned
|
||||
# from src.plugins.memory_system.Hippocampus import HippocampusManager
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config, model_config
|
||||
from src.chat.message_receive.message import Message
|
||||
from src.config.config import model_config
|
||||
from src.chat.knowledge import qa_manager
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages
|
||||
from src.chat.brain_chat.PFC.observation_info import dict_to_database_message
|
||||
@@ -16,9 +16,7 @@ class KnowledgeFetcher:
|
||||
"""知识调取器"""
|
||||
|
||||
def __init__(self, private_name: str):
|
||||
self.llm = LLMRequest(
|
||||
model_set=model_config.model_task_config.utils
|
||||
)
|
||||
self.llm = LLMRequest(model_set=model_config.model_task_config.utils)
|
||||
self.private_name = private_name
|
||||
|
||||
def _lpmm_get_knowledge(self, query: str) -> str:
|
||||
@@ -64,7 +62,7 @@ class KnowledgeFetcher:
|
||||
# TODO: Integrate with new memory system if needed
|
||||
knowledge_text = ""
|
||||
sources_text = "无记忆匹配" # 默认值
|
||||
|
||||
|
||||
# # 从记忆中获取相关知识 (DISABLED - old Hippocampus API)
|
||||
# related_memory = await HippocampusManager.get_instance().get_memory_from_text(
|
||||
# text=f"{query}\n{chat_history_text}",
|
||||
|
||||
@@ -14,10 +14,7 @@ class ReplyChecker:
|
||||
"""回复检查器"""
|
||||
|
||||
def __init__(self, stream_id: str, private_name: str):
|
||||
self.llm = LLMRequest(
|
||||
model_set=model_config.model_task_config.utils,
|
||||
request_type="reply_check"
|
||||
)
|
||||
self.llm = LLMRequest(model_set=model_config.model_task_config.utils, request_type="reply_check")
|
||||
self.personality_info = self._get_personality_prompt()
|
||||
self.name = global_config.bot.nickname
|
||||
self.private_name = private_name
|
||||
@@ -27,7 +24,7 @@ class ReplyChecker:
|
||||
def _get_personality_prompt(self) -> str:
|
||||
"""获取个性提示信息"""
|
||||
prompt_personality = global_config.personality.personality
|
||||
|
||||
|
||||
# 检查是否需要随机替换为状态
|
||||
if (
|
||||
global_config.personality.states
|
||||
@@ -35,7 +32,7 @@ class ReplyChecker:
|
||||
and random.random() < global_config.personality.state_probability
|
||||
):
|
||||
prompt_personality = random.choice(global_config.personality.states)
|
||||
|
||||
|
||||
bot_name = global_config.bot.nickname
|
||||
return f"你的名字是{bot_name},你{prompt_personality};"
|
||||
|
||||
|
||||
@@ -99,7 +99,7 @@ class ReplyGenerator:
|
||||
def _get_personality_prompt(self) -> str:
|
||||
"""获取个性提示信息"""
|
||||
prompt_personality = global_config.personality.personality
|
||||
|
||||
|
||||
# 检查是否需要随机替换为状态
|
||||
if (
|
||||
global_config.personality.states
|
||||
@@ -107,7 +107,7 @@ class ReplyGenerator:
|
||||
and random.random() < global_config.personality.state_probability
|
||||
):
|
||||
prompt_personality = random.choice(global_config.personality.states)
|
||||
|
||||
|
||||
bot_name = global_config.bot.nickname
|
||||
return f"你的名字是{bot_name},你{prompt_personality};"
|
||||
|
||||
|
||||
@@ -704,10 +704,7 @@ class BrainChatting:
|
||||
|
||||
# 等待指定时间,但可被新消息打断
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
self._new_message_event.wait(),
|
||||
timeout=wait_seconds
|
||||
)
|
||||
await asyncio.wait_for(self._new_message_event.wait(), timeout=wait_seconds)
|
||||
# 如果事件被触发,说明有新消息到达
|
||||
logger.info(f"{self.log_prefix} wait 动作被新消息打断,提前结束等待")
|
||||
except asyncio.TimeoutError:
|
||||
@@ -731,7 +728,9 @@ class BrainChatting:
|
||||
# 使用默认等待时间
|
||||
wait_seconds = 3
|
||||
|
||||
logger.info(f"{self.log_prefix} 执行 listening(转换为 wait)动作,等待 {wait_seconds} 秒(可被新消息打断)")
|
||||
logger.info(
|
||||
f"{self.log_prefix} 执行 listening(转换为 wait)动作,等待 {wait_seconds} 秒(可被新消息打断)"
|
||||
)
|
||||
|
||||
# 清除事件状态,准备等待新消息
|
||||
self._new_message_event.clear()
|
||||
@@ -749,10 +748,7 @@ class BrainChatting:
|
||||
|
||||
# 等待指定时间,但可被新消息打断
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
self._new_message_event.wait(),
|
||||
timeout=wait_seconds
|
||||
)
|
||||
await asyncio.wait_for(self._new_message_event.wait(), timeout=wait_seconds)
|
||||
# 如果事件被触发,说明有新消息到达
|
||||
logger.info(f"{self.log_prefix} listening 动作被新消息打断,提前结束等待")
|
||||
except asyncio.TimeoutError:
|
||||
|
||||
@@ -431,15 +431,21 @@ class BrainPlanner:
|
||||
except Exception as req_e:
|
||||
logger.error(f"{self.log_prefix}LLM 请求执行失败: {req_e}")
|
||||
extracted_reasoning = f"LLM 请求失败,模型出现问题: {req_e}"
|
||||
return extracted_reasoning, [
|
||||
ActionPlannerInfo(
|
||||
action_type="complete_talk",
|
||||
reasoning=extracted_reasoning,
|
||||
action_data={},
|
||||
action_message=None,
|
||||
available_actions=available_actions,
|
||||
)
|
||||
], llm_content, llm_reasoning, llm_duration_ms
|
||||
return (
|
||||
extracted_reasoning,
|
||||
[
|
||||
ActionPlannerInfo(
|
||||
action_type="complete_talk",
|
||||
reasoning=extracted_reasoning,
|
||||
action_data={},
|
||||
action_message=None,
|
||||
available_actions=available_actions,
|
||||
)
|
||||
],
|
||||
llm_content,
|
||||
llm_reasoning,
|
||||
llm_duration_ms,
|
||||
)
|
||||
|
||||
# 解析LLM响应
|
||||
if llm_content:
|
||||
|
||||
@@ -105,7 +105,7 @@ class EmbeddingStore:
|
||||
self.embedding_file_path = f"{dir_path}/{namespace}.parquet"
|
||||
self.index_file_path = f"{dir_path}/{namespace}.index"
|
||||
self.idx2hash_file_path = f"{dir_path}/{namespace}_i2h.json"
|
||||
|
||||
|
||||
self.dirty = False # 标记是否有新增数据需要重建索引
|
||||
|
||||
# 多线程配置参数验证和设置
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from typing import List, Union, Dict, Any
|
||||
from typing import List, Union
|
||||
|
||||
from .global_logger import logger
|
||||
from . import prompt_template
|
||||
@@ -192,17 +192,15 @@ class IEProcess:
|
||||
|
||||
results = []
|
||||
total = len(paragraphs)
|
||||
|
||||
|
||||
for i, pg in enumerate(paragraphs, start=1):
|
||||
# 打印进度日志,让用户知道没有卡死
|
||||
logger.info(f"[IEProcess] 正在处理第 {i}/{total} 段文本 (长度: {len(pg)})...")
|
||||
|
||||
|
||||
# 使用 asyncio.to_thread 包装同步阻塞调用,防止死锁
|
||||
# 这样 info_extract_from_str 内部的 asyncio.run 会在独立线程的新 loop 中运行
|
||||
try:
|
||||
entities, triples = await asyncio.to_thread(
|
||||
info_extract_from_str, self.llm_ner, self.llm_rdf, pg
|
||||
)
|
||||
entities, triples = await asyncio.to_thread(info_extract_from_str, self.llm_ner, self.llm_rdf, pg)
|
||||
|
||||
if entities is not None:
|
||||
results.append(
|
||||
|
||||
@@ -395,8 +395,7 @@ class KGManager:
|
||||
appear_cnt = self.ent_appear_cnt.get(ent_hash)
|
||||
if not appear_cnt or appear_cnt <= 0:
|
||||
logger.debug(
|
||||
f"实体 {ent_hash} 在 ent_appear_cnt 中不存在或计数为 0,"
|
||||
f"将使用 1.0 作为默认出现次数参与权重计算"
|
||||
f"实体 {ent_hash} 在 ent_appear_cnt 中不存在或计数为 0,将使用 1.0 作为默认出现次数参与权重计算"
|
||||
)
|
||||
appear_cnt = 1.0
|
||||
ent_weights[ent_hash] = float(np.sum(scores)) / float(appear_cnt)
|
||||
|
||||
@@ -11,31 +11,30 @@ from src.chat.knowledge import get_qa_manager, lpmm_start_up
|
||||
|
||||
logger = get_logger("LPMM-Plugin-API")
|
||||
|
||||
|
||||
class LPMMOperations:
|
||||
"""
|
||||
LPMM 内部操作接口。
|
||||
封装了 LPMM 的核心操作,供插件系统 API 或其他内部组件调用。
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
self._initialized = False
|
||||
|
||||
async def _run_cancellable_executor(
|
||||
self, func: Callable, *args, **kwargs
|
||||
) -> Any:
|
||||
async def _run_cancellable_executor(self, func: Callable, *args, **kwargs) -> Any:
|
||||
"""
|
||||
在线程池中执行可取消的同步操作。
|
||||
当任务被取消时(如 Ctrl+C),会立即响应并抛出 CancelledError。
|
||||
注意:线程池中的操作可能仍在运行,但协程会立即返回,不会阻塞主进程。
|
||||
|
||||
|
||||
Args:
|
||||
func: 要执行的同步函数
|
||||
*args: 函数的位置参数
|
||||
**kwargs: 函数的关键字参数
|
||||
|
||||
|
||||
Returns:
|
||||
函数的返回值
|
||||
|
||||
|
||||
Raises:
|
||||
asyncio.CancelledError: 当任务被取消时
|
||||
"""
|
||||
@@ -51,42 +50,42 @@ class LPMMOperations:
|
||||
# 如果全局没初始化,尝试初始化
|
||||
if not global_config.lpmm_knowledge.enable:
|
||||
logger.warning("LPMM 知识库在全局配置中未启用,操作可能受限。")
|
||||
|
||||
|
||||
lpmm_start_up()
|
||||
qa_mgr = get_qa_manager()
|
||||
|
||||
|
||||
if qa_mgr is None:
|
||||
raise RuntimeError("无法获取 LPMM QAManager,请检查 LPMM 是否已正确安装和配置。")
|
||||
|
||||
|
||||
return qa_mgr.embed_manager, qa_mgr.kg_manager, qa_mgr
|
||||
|
||||
async def add_content(self, text: str, auto_split: bool = True) -> dict:
|
||||
"""
|
||||
向知识库添加新内容。
|
||||
|
||||
|
||||
Args:
|
||||
text: 原始文本。
|
||||
auto_split: 是否自动按双换行符分割段落。
|
||||
- True: 自动分割(默认),支持多段文本(用双换行分隔)
|
||||
- False: 不分割,将整个文本作为完整一段处理
|
||||
|
||||
|
||||
Returns:
|
||||
dict: {"status": "success/error", "count": 导入段落数, "message": "描述"}
|
||||
"""
|
||||
try:
|
||||
embed_mgr, kg_mgr, _ = await self._get_managers()
|
||||
|
||||
|
||||
# 1. 分段处理
|
||||
if auto_split:
|
||||
# 自动按双换行符分割
|
||||
paragraphs = [p.strip() for p in text.split('\n\n') if p.strip()]
|
||||
paragraphs = [p.strip() for p in text.split("\n\n") if p.strip()]
|
||||
else:
|
||||
# 不分割,作为完整一段
|
||||
text_stripped = text.strip()
|
||||
if not text_stripped:
|
||||
return {"status": "error", "message": "文本内容为空"}
|
||||
paragraphs = [text_stripped]
|
||||
|
||||
|
||||
if not paragraphs:
|
||||
return {"status": "error", "message": "文本内容为空"}
|
||||
|
||||
@@ -94,14 +93,16 @@ class LPMMOperations:
|
||||
from src.chat.knowledge.ie_process import IEProcess
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config
|
||||
|
||||
llm_ner = LLMRequest(model_set=model_config.model_task_config.lpmm_entity_extract, request_type="lpmm.entity_extract")
|
||||
|
||||
llm_ner = LLMRequest(
|
||||
model_set=model_config.model_task_config.lpmm_entity_extract, request_type="lpmm.entity_extract"
|
||||
)
|
||||
llm_rdf = LLMRequest(model_set=model_config.model_task_config.lpmm_rdf_build, request_type="lpmm.rdf_build")
|
||||
ie_process = IEProcess(llm_ner, llm_rdf)
|
||||
|
||||
|
||||
logger.info(f"[Plugin API] 正在对 {len(paragraphs)} 段文本执行信息抽取...")
|
||||
extracted_docs = await ie_process.process_paragraphs(paragraphs)
|
||||
|
||||
|
||||
# 3. 构造并导入数据
|
||||
# 这里我们手动实现导入逻辑,不依赖外部脚本
|
||||
# a. 准备段落
|
||||
@@ -115,7 +116,7 @@ class LPMMOperations:
|
||||
# store_new_data_set 期望的格式:raw_paragraphs 的键是段落hash(不带前缀),值是段落文本
|
||||
new_raw_paragraphs = {}
|
||||
new_triple_list_data = {}
|
||||
|
||||
|
||||
for pg_hash, passage in raw_paragraphs.items():
|
||||
key = f"paragraph-{pg_hash}"
|
||||
if key not in embed_mgr.stored_pg_hashes:
|
||||
@@ -128,26 +129,22 @@ class LPMMOperations:
|
||||
# 2. 使用 EmbeddingManager 的标准方法存储段落、实体和关系的嵌入
|
||||
# store_new_data_set 会自动处理嵌入生成和存储
|
||||
# 将同步阻塞操作放到线程池中执行,避免阻塞事件循环
|
||||
await self._run_cancellable_executor(
|
||||
embed_mgr.store_new_data_set,
|
||||
new_raw_paragraphs,
|
||||
new_triple_list_data
|
||||
)
|
||||
await self._run_cancellable_executor(embed_mgr.store_new_data_set, new_raw_paragraphs, new_triple_list_data)
|
||||
|
||||
# 3. 构建知识图谱(只需要三元组数据和embedding_manager)
|
||||
await self._run_cancellable_executor(
|
||||
kg_mgr.build_kg,
|
||||
new_triple_list_data,
|
||||
embed_mgr
|
||||
)
|
||||
await self._run_cancellable_executor(kg_mgr.build_kg, new_triple_list_data, embed_mgr)
|
||||
|
||||
# 4. 持久化
|
||||
await self._run_cancellable_executor(embed_mgr.rebuild_faiss_index)
|
||||
await self._run_cancellable_executor(embed_mgr.save_to_file)
|
||||
await self._run_cancellable_executor(kg_mgr.save_to_file)
|
||||
|
||||
return {"status": "success", "count": len(new_raw_paragraphs), "message": f"成功导入 {len(new_raw_paragraphs)} 条知识"}
|
||||
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"count": len(new_raw_paragraphs),
|
||||
"message": f"成功导入 {len(new_raw_paragraphs)} 条知识",
|
||||
}
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.warning("[Plugin API] 导入操作被用户中断")
|
||||
return {"status": "cancelled", "message": "导入操作已被用户中断"}
|
||||
@@ -158,11 +155,11 @@ class LPMMOperations:
|
||||
async def search(self, query: str, top_k: int = 3) -> List[str]:
|
||||
"""
|
||||
检索知识库。
|
||||
|
||||
|
||||
Args:
|
||||
query: 查询问题。
|
||||
top_k: 返回最相关的条目数。
|
||||
|
||||
|
||||
Returns:
|
||||
List[str]: 相关文段列表。
|
||||
"""
|
||||
@@ -179,21 +176,21 @@ class LPMMOperations:
|
||||
async def delete(self, keyword: str, exact_match: bool = False) -> dict:
|
||||
"""
|
||||
根据关键词或完整文段删除知识库内容。
|
||||
|
||||
|
||||
Args:
|
||||
keyword: 匹配关键词或完整文段。
|
||||
exact_match: 是否使用完整文段匹配(True=完全匹配,False=关键词模糊匹配)。
|
||||
|
||||
|
||||
Returns:
|
||||
dict: {"status": "success/info", "deleted_count": 删除条数, "message": "描述"}
|
||||
"""
|
||||
try:
|
||||
embed_mgr, kg_mgr, _ = await self._get_managers()
|
||||
|
||||
|
||||
# 1. 查找匹配的段落
|
||||
to_delete_keys = []
|
||||
to_delete_hashes = []
|
||||
|
||||
|
||||
for key, item in embed_mgr.paragraphs_embedding_store.store.items():
|
||||
if exact_match:
|
||||
# 完整文段匹配
|
||||
@@ -205,29 +202,25 @@ class LPMMOperations:
|
||||
if keyword in item.str:
|
||||
to_delete_keys.append(key)
|
||||
to_delete_hashes.append(key.replace("paragraph-", "", 1))
|
||||
|
||||
|
||||
if not to_delete_keys:
|
||||
match_type = "完整文段" if exact_match else "关键词"
|
||||
return {"status": "info", "deleted_count": 0, "message": f"未找到匹配的内容({match_type}匹配)"}
|
||||
|
||||
# 2. 执行删除
|
||||
# 将同步阻塞操作放到线程池中执行,避免阻塞事件循环
|
||||
|
||||
|
||||
# a. 从向量库删除
|
||||
deleted_count, _ = await self._run_cancellable_executor(
|
||||
embed_mgr.paragraphs_embedding_store.delete_items,
|
||||
to_delete_keys
|
||||
embed_mgr.paragraphs_embedding_store.delete_items, to_delete_keys
|
||||
)
|
||||
embed_mgr.stored_pg_hashes = set(embed_mgr.paragraphs_embedding_store.store.keys())
|
||||
|
||||
|
||||
# b. 从知识图谱删除
|
||||
# 注意:必须使用关键字参数,避免 True 被误当作 ent_hashes 参数
|
||||
# 使用 partial 来传递关键字参数,因为 run_in_executor 不支持 **kwargs
|
||||
delete_func = partial(
|
||||
kg_mgr.delete_paragraphs,
|
||||
to_delete_hashes,
|
||||
ent_hashes=None,
|
||||
remove_orphan_entities=True
|
||||
kg_mgr.delete_paragraphs, to_delete_hashes, ent_hashes=None, remove_orphan_entities=True
|
||||
)
|
||||
await self._run_cancellable_executor(delete_func)
|
||||
|
||||
@@ -235,9 +228,13 @@ class LPMMOperations:
|
||||
await self._run_cancellable_executor(embed_mgr.rebuild_faiss_index)
|
||||
await self._run_cancellable_executor(embed_mgr.save_to_file)
|
||||
await self._run_cancellable_executor(kg_mgr.save_to_file)
|
||||
|
||||
|
||||
match_type = "完整文段" if exact_match else "关键词"
|
||||
return {"status": "success", "deleted_count": deleted_count, "message": f"已成功删除 {deleted_count} 条相关知识({match_type}匹配)"}
|
||||
return {
|
||||
"status": "success",
|
||||
"deleted_count": deleted_count,
|
||||
"message": f"已成功删除 {deleted_count} 条相关知识({match_type}匹配)",
|
||||
}
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.warning("[Plugin API] 删除操作被用户中断")
|
||||
@@ -249,13 +246,13 @@ class LPMMOperations:
|
||||
async def clear_all(self) -> dict:
|
||||
"""
|
||||
清空整个LPMM知识库(删除所有段落、实体、关系和知识图谱数据)。
|
||||
|
||||
|
||||
Returns:
|
||||
dict: {"status": "success/error", "message": "描述", "stats": {...}}
|
||||
"""
|
||||
try:
|
||||
embed_mgr, kg_mgr, _ = await self._get_managers()
|
||||
|
||||
|
||||
# 记录清空前的统计信息
|
||||
before_stats = {
|
||||
"paragraphs": len(embed_mgr.paragraphs_embedding_store.store),
|
||||
@@ -264,40 +261,37 @@ class LPMMOperations:
|
||||
"kg_nodes": len(kg_mgr.graph.get_node_list()),
|
||||
"kg_edges": len(kg_mgr.graph.get_edge_list()),
|
||||
}
|
||||
|
||||
|
||||
# 将同步阻塞操作放到线程池中执行,避免阻塞事件循环
|
||||
|
||||
|
||||
# 1. 清空所有向量库
|
||||
# 获取所有keys
|
||||
para_keys = list(embed_mgr.paragraphs_embedding_store.store.keys())
|
||||
ent_keys = list(embed_mgr.entities_embedding_store.store.keys())
|
||||
rel_keys = list(embed_mgr.relation_embedding_store.store.keys())
|
||||
|
||||
|
||||
# 删除所有段落向量
|
||||
para_deleted, _ = await self._run_cancellable_executor(
|
||||
embed_mgr.paragraphs_embedding_store.delete_items,
|
||||
para_keys
|
||||
embed_mgr.paragraphs_embedding_store.delete_items, para_keys
|
||||
)
|
||||
embed_mgr.stored_pg_hashes.clear()
|
||||
|
||||
|
||||
# 删除所有实体向量
|
||||
if ent_keys:
|
||||
ent_deleted, _ = await self._run_cancellable_executor(
|
||||
embed_mgr.entities_embedding_store.delete_items,
|
||||
ent_keys
|
||||
embed_mgr.entities_embedding_store.delete_items, ent_keys
|
||||
)
|
||||
else:
|
||||
ent_deleted = 0
|
||||
|
||||
|
||||
# 删除所有关系向量
|
||||
if rel_keys:
|
||||
rel_deleted, _ = await self._run_cancellable_executor(
|
||||
embed_mgr.relation_embedding_store.delete_items,
|
||||
rel_keys
|
||||
embed_mgr.relation_embedding_store.delete_items, rel_keys
|
||||
)
|
||||
else:
|
||||
rel_deleted = 0
|
||||
|
||||
|
||||
# 2. 清空所有 embedding store 的索引和映射
|
||||
# 确保 faiss_index 和 idx2hash 也被重置,并删除旧的索引文件
|
||||
def _clear_embedding_indices():
|
||||
@@ -310,7 +304,7 @@ class LPMMOperations:
|
||||
os.remove(embed_mgr.paragraphs_embedding_store.index_file_path)
|
||||
if os.path.exists(embed_mgr.paragraphs_embedding_store.idx2hash_file_path):
|
||||
os.remove(embed_mgr.paragraphs_embedding_store.idx2hash_file_path)
|
||||
|
||||
|
||||
# 清空实体索引
|
||||
embed_mgr.entities_embedding_store.faiss_index = None
|
||||
embed_mgr.entities_embedding_store.idx2hash = None
|
||||
@@ -320,7 +314,7 @@ class LPMMOperations:
|
||||
os.remove(embed_mgr.entities_embedding_store.index_file_path)
|
||||
if os.path.exists(embed_mgr.entities_embedding_store.idx2hash_file_path):
|
||||
os.remove(embed_mgr.entities_embedding_store.idx2hash_file_path)
|
||||
|
||||
|
||||
# 清空关系索引
|
||||
embed_mgr.relation_embedding_store.faiss_index = None
|
||||
embed_mgr.relation_embedding_store.idx2hash = None
|
||||
@@ -330,9 +324,9 @@ class LPMMOperations:
|
||||
os.remove(embed_mgr.relation_embedding_store.index_file_path)
|
||||
if os.path.exists(embed_mgr.relation_embedding_store.idx2hash_file_path):
|
||||
os.remove(embed_mgr.relation_embedding_store.idx2hash_file_path)
|
||||
|
||||
|
||||
await self._run_cancellable_executor(_clear_embedding_indices)
|
||||
|
||||
|
||||
# 3. 清空知识图谱
|
||||
# 获取所有段落hash
|
||||
all_pg_hashes = list(kg_mgr.stored_paragraph_hashes)
|
||||
@@ -341,24 +335,22 @@ class LPMMOperations:
|
||||
# 注意:必须使用关键字参数,避免 True 被误当作 ent_hashes 参数
|
||||
# 使用 partial 来传递关键字参数,因为 run_in_executor 不支持 **kwargs
|
||||
delete_func = partial(
|
||||
kg_mgr.delete_paragraphs,
|
||||
all_pg_hashes,
|
||||
ent_hashes=None,
|
||||
remove_orphan_entities=True
|
||||
kg_mgr.delete_paragraphs, all_pg_hashes, ent_hashes=None, remove_orphan_entities=True
|
||||
)
|
||||
await self._run_cancellable_executor(delete_func)
|
||||
|
||||
|
||||
# 完全清空KG:创建新的空图(无论是否有段落hash都要执行)
|
||||
from quick_algo import di_graph
|
||||
|
||||
kg_mgr.graph = di_graph.DiGraph()
|
||||
kg_mgr.stored_paragraph_hashes.clear()
|
||||
kg_mgr.ent_appear_cnt.clear()
|
||||
|
||||
|
||||
# 4. 保存所有数据(此时所有store都是空的,索引也是None)
|
||||
# 注意:即使store为空,save_to_file也会保存空的DataFrame,这是正确的
|
||||
await self._run_cancellable_executor(embed_mgr.save_to_file)
|
||||
await self._run_cancellable_executor(kg_mgr.save_to_file)
|
||||
|
||||
|
||||
after_stats = {
|
||||
"paragraphs": len(embed_mgr.paragraphs_embedding_store.store),
|
||||
"entities": len(embed_mgr.entities_embedding_store.store),
|
||||
@@ -366,14 +358,14 @@ class LPMMOperations:
|
||||
"kg_nodes": len(kg_mgr.graph.get_node_list()),
|
||||
"kg_edges": len(kg_mgr.graph.get_edge_list()),
|
||||
}
|
||||
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"message": f"已成功清空LPMM知识库(删除 {para_deleted} 个段落、{ent_deleted} 个实体、{rel_deleted} 个关系)",
|
||||
"stats": {
|
||||
"before": before_stats,
|
||||
"after": after_stats,
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
except asyncio.CancelledError:
|
||||
@@ -383,6 +375,6 @@ class LPMMOperations:
|
||||
logger.error(f"[Plugin API] 清空知识库失败: {e}", exc_info=True)
|
||||
return {"status": "error", "message": str(e)}
|
||||
|
||||
|
||||
# 内部使用的单例
|
||||
lpmm_ops = LPMMOperations()
|
||||
|
||||
|
||||
@@ -136,4 +136,3 @@ class PlanReplyLogger:
|
||||
return str(value)
|
||||
# Fallback to string for other complex types
|
||||
return str(value)
|
||||
|
||||
|
||||
@@ -85,17 +85,17 @@ class ChatBot:
|
||||
|
||||
async def _create_pfc_chat(self, message: MessageRecv):
|
||||
"""创建或获取PFC对话实例
|
||||
|
||||
|
||||
Args:
|
||||
message: 消息对象
|
||||
"""
|
||||
try:
|
||||
chat_id = str(message.chat_stream.stream_id)
|
||||
private_name = str(message.message_info.user_info.user_nickname)
|
||||
|
||||
|
||||
logger.debug(f"[私聊][{private_name}]创建或获取PFC对话: {chat_id}")
|
||||
await self.pfc_manager.get_or_create_conversation(chat_id, private_name)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"创建PFC聊天失败: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
@@ -96,7 +96,7 @@ class Message(MessageBase):
|
||||
if processed_text:
|
||||
return f"{global_config.bot.nickname}: {processed_text}"
|
||||
return None
|
||||
|
||||
|
||||
tasks = [process_forward_node(node_dict) for node_dict in segment.data]
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
segments_text = []
|
||||
|
||||
@@ -189,7 +189,7 @@ async def _send_message(message: MessageSending, show_log=True) -> bool:
|
||||
|
||||
# 如果未开启 API Server,直接跳过 Fallback
|
||||
if not global_config.maim_message.enable_api_server:
|
||||
logger.debug(f"[API Server Fallback] API Server未开启,跳过fallback")
|
||||
logger.debug("[API Server Fallback] API Server未开启,跳过fallback")
|
||||
if legacy_exception:
|
||||
raise legacy_exception
|
||||
return False
|
||||
@@ -198,13 +198,13 @@ async def _send_message(message: MessageSending, show_log=True) -> bool:
|
||||
extra_server = getattr(global_api, "extra_server", None)
|
||||
|
||||
if not extra_server:
|
||||
logger.warning(f"[API Server Fallback] extra_server不存在")
|
||||
logger.warning("[API Server Fallback] extra_server不存在")
|
||||
if legacy_exception:
|
||||
raise legacy_exception
|
||||
return False
|
||||
|
||||
if not extra_server.is_running():
|
||||
logger.warning(f"[API Server Fallback] extra_server未运行")
|
||||
logger.warning("[API Server Fallback] extra_server未运行")
|
||||
if legacy_exception:
|
||||
raise legacy_exception
|
||||
return False
|
||||
@@ -253,7 +253,7 @@ async def _send_message(message: MessageSending, show_log=True) -> bool:
|
||||
)
|
||||
|
||||
# 直接调用 Server 的 send_message 接口,它会自动处理路由
|
||||
logger.debug(f"[API Server Fallback] 正在通过extra_server发送消息...")
|
||||
logger.debug("[API Server Fallback] 正在通过extra_server发送消息...")
|
||||
results = await extra_server.send_message(api_message)
|
||||
logger.debug(f"[API Server Fallback] 发送结果: {results}")
|
||||
|
||||
|
||||
@@ -35,6 +35,7 @@ logger = get_logger("planner")
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
|
||||
class ActionPlanner:
|
||||
def __init__(self, chat_id: str, action_manager: ActionManager):
|
||||
self.chat_id = chat_id
|
||||
@@ -48,7 +49,7 @@ class ActionPlanner:
|
||||
self.last_obs_time_mark = 0.0
|
||||
|
||||
self.plan_log: List[Tuple[str, float, Union[List[ActionPlannerInfo], str]]] = []
|
||||
|
||||
|
||||
# 黑话缓存:使用 OrderedDict 实现 LRU,最多缓存10个
|
||||
self.unknown_words_cache: OrderedDict[str, None] = OrderedDict()
|
||||
self.unknown_words_cache_limit = 10
|
||||
@@ -111,20 +112,29 @@ class ActionPlanner:
|
||||
|
||||
# 替换 [picid:xxx] 为 [图片:描述]
|
||||
pic_pattern = r"\[picid:([^\]]+)\]"
|
||||
|
||||
def replace_pic_id(pic_match: re.Match) -> str:
|
||||
pic_id = pic_match.group(1)
|
||||
description = translate_pid_to_description(pic_id)
|
||||
return f"[图片:{description}]"
|
||||
|
||||
msg_text = re.sub(pic_pattern, replace_pic_id, msg_text)
|
||||
|
||||
# 替换用户引用格式:回复<aaa:bbb> 和 @<aaa:bbb>
|
||||
platform = getattr(message, "user_info", None) and message.user_info.platform or getattr(message, "chat_info", None) and message.chat_info.platform or "qq"
|
||||
platform = (
|
||||
getattr(message, "user_info", None)
|
||||
and message.user_info.platform
|
||||
or getattr(message, "chat_info", None)
|
||||
and message.chat_info.platform
|
||||
or "qq"
|
||||
)
|
||||
msg_text = replace_user_references(msg_text, platform, replace_bot_name=True)
|
||||
|
||||
# 替换单独的 <用户名:用户ID> 格式(replace_user_references 已处理回复<和@<格式)
|
||||
# 匹配所有 <aaa:bbb> 格式,由于 replace_user_references 已经替换了回复<和@<格式,
|
||||
# 这里匹配到的应该都是单独的格式
|
||||
user_ref_pattern = r"<([^:<>]+):([^:<>]+)>"
|
||||
|
||||
def replace_user_ref(user_match: re.Match) -> str:
|
||||
user_name = user_match.group(1)
|
||||
user_id = user_match.group(2)
|
||||
@@ -137,6 +147,7 @@ class ActionPlanner:
|
||||
except Exception:
|
||||
# 如果解析失败,使用原始昵称
|
||||
return user_name
|
||||
|
||||
msg_text = re.sub(user_ref_pattern, replace_user_ref, msg_text)
|
||||
|
||||
preview = msg_text if len(msg_text) <= 100 else f"{msg_text[:97]}..."
|
||||
@@ -165,7 +176,7 @@ class ActionPlanner:
|
||||
else:
|
||||
reasoning = "未提供原因"
|
||||
action_data = {key: value for key, value in action_json.items() if key not in ["action"]}
|
||||
|
||||
|
||||
# 非no_reply动作需要target_message_id
|
||||
target_message = None
|
||||
|
||||
@@ -244,7 +255,7 @@ class ActionPlanner:
|
||||
def _update_unknown_words_cache(self, new_words: List[str]) -> None:
|
||||
"""
|
||||
更新黑话缓存,将新的黑话加入缓存
|
||||
|
||||
|
||||
Args:
|
||||
new_words: 新提取的黑话列表
|
||||
"""
|
||||
@@ -254,7 +265,7 @@ class ActionPlanner:
|
||||
word = word.strip()
|
||||
if not word:
|
||||
continue
|
||||
|
||||
|
||||
# 如果已存在,移到末尾(LRU)
|
||||
if word in self.unknown_words_cache:
|
||||
self.unknown_words_cache.move_to_end(word)
|
||||
@@ -269,10 +280,10 @@ class ActionPlanner:
|
||||
def _merge_unknown_words_with_cache(self, new_words: Optional[List[str]]) -> List[str]:
|
||||
"""
|
||||
合并新提取的黑话和缓存中的黑话
|
||||
|
||||
|
||||
Args:
|
||||
new_words: 新提取的黑话列表(可能为None)
|
||||
|
||||
|
||||
Returns:
|
||||
合并后的黑话列表(去重)
|
||||
"""
|
||||
@@ -284,31 +295,29 @@ class ActionPlanner:
|
||||
word = word.strip()
|
||||
if word:
|
||||
cleaned_new_words.append(word)
|
||||
|
||||
|
||||
# 获取缓存中的黑话列表
|
||||
cached_words = list(self.unknown_words_cache.keys())
|
||||
|
||||
|
||||
# 合并并去重(保留顺序:新提取的在前,缓存的在后)
|
||||
merged_words: List[str] = []
|
||||
seen = set()
|
||||
|
||||
|
||||
# 先添加新提取的
|
||||
for word in cleaned_new_words:
|
||||
if word not in seen:
|
||||
merged_words.append(word)
|
||||
seen.add(word)
|
||||
|
||||
|
||||
# 再添加缓存的(如果不在新提取的列表中)
|
||||
for word in cached_words:
|
||||
if word not in seen:
|
||||
merged_words.append(word)
|
||||
seen.add(word)
|
||||
|
||||
|
||||
return merged_words
|
||||
|
||||
def _process_unknown_words_cache(
|
||||
self, actions: List[ActionPlannerInfo]
|
||||
) -> None:
|
||||
def _process_unknown_words_cache(self, actions: List[ActionPlannerInfo]) -> None:
|
||||
"""
|
||||
处理黑话缓存逻辑:
|
||||
1. 检查是否有 reply action 提取了 unknown_words
|
||||
@@ -316,7 +325,7 @@ class ActionPlanner:
|
||||
3. 如果缓存数量大于5,移除最老的2个
|
||||
4. 对于每个 reply action,合并缓存和新提取的黑话
|
||||
5. 更新缓存
|
||||
|
||||
|
||||
Args:
|
||||
actions: 解析后的动作列表
|
||||
"""
|
||||
@@ -330,7 +339,7 @@ class ActionPlanner:
|
||||
removed_count += 1
|
||||
if removed_count > 0:
|
||||
logger.debug(f"{self.log_prefix}缓存数量大于5,移除最老的{removed_count}个缓存")
|
||||
|
||||
|
||||
# 检查是否有 reply action 提取了 unknown_words
|
||||
has_extracted_unknown_words = False
|
||||
for action in actions:
|
||||
@@ -340,22 +349,22 @@ class ActionPlanner:
|
||||
if unknown_words and isinstance(unknown_words, list) and len(unknown_words) > 0:
|
||||
has_extracted_unknown_words = True
|
||||
break
|
||||
|
||||
|
||||
# 如果当前 plan 的 reply 没有提取,移除最老的1个
|
||||
if not has_extracted_unknown_words:
|
||||
if len(self.unknown_words_cache) > 0:
|
||||
self.unknown_words_cache.popitem(last=False)
|
||||
logger.debug(f"{self.log_prefix}当前 plan 的 reply 没有提取黑话,移除最老的1个缓存")
|
||||
|
||||
|
||||
# 对于每个 reply action,合并缓存和新提取的黑话
|
||||
for action in actions:
|
||||
if action.action_type == "reply":
|
||||
action_data = action.action_data or {}
|
||||
new_words = action_data.get("unknown_words")
|
||||
|
||||
|
||||
# 合并新提取的和缓存的黑话列表
|
||||
merged_words = self._merge_unknown_words_with_cache(new_words)
|
||||
|
||||
|
||||
# 更新 action_data
|
||||
if merged_words:
|
||||
action_data["unknown_words"] = merged_words
|
||||
@@ -366,7 +375,7 @@ class ActionPlanner:
|
||||
else:
|
||||
# 如果没有合并后的黑话,移除 unknown_words 字段
|
||||
action_data.pop("unknown_words", None)
|
||||
|
||||
|
||||
# 更新缓存(将新提取的黑话加入缓存)
|
||||
if new_words:
|
||||
self._update_unknown_words_cache(new_words)
|
||||
@@ -442,15 +451,19 @@ class ActionPlanner:
|
||||
# 检查是否已经有回复该消息的 action
|
||||
has_reply_to_force_message = False
|
||||
for action in actions:
|
||||
if action.action_type == "reply" and action.action_message and action.action_message.message_id == force_reply_message.message_id:
|
||||
if (
|
||||
action.action_type == "reply"
|
||||
and action.action_message
|
||||
and action.action_message.message_id == force_reply_message.message_id
|
||||
):
|
||||
has_reply_to_force_message = True
|
||||
break
|
||||
|
||||
|
||||
# 如果没有回复该消息,强制添加回复 action
|
||||
if not has_reply_to_force_message:
|
||||
# 移除所有 no_reply action(如果有)
|
||||
actions = [a for a in actions if a.action_type != "no_reply"]
|
||||
|
||||
|
||||
# 创建强制回复 action
|
||||
available_actions_dict = dict(current_available_actions)
|
||||
force_reply_action = ActionPlannerInfo(
|
||||
@@ -577,10 +590,11 @@ class ActionPlanner:
|
||||
if global_config.chat.think_mode == "classic":
|
||||
reply_action_example = ""
|
||||
if global_config.chat.llm_quote:
|
||||
reply_action_example += "5.如果要明确回复消息,使用quote,如果消息不多不需要明确回复,设置quote为false\n"
|
||||
reply_action_example += (
|
||||
"5.如果要明确回复消息,使用quote,如果消息不多不需要明确回复,设置quote为false\n"
|
||||
)
|
||||
reply_action_example += (
|
||||
'{{"action":"reply", "target_message_id":"消息id(m+数字)", '
|
||||
'"unknown_words":["词语1","词语2"]'
|
||||
'{{"action":"reply", "target_message_id":"消息id(m+数字)", "unknown_words":["词语1","词语2"]'
|
||||
)
|
||||
if global_config.chat.llm_quote:
|
||||
reply_action_example += ', "quote":"如果需要引用该message,设置为true"'
|
||||
@@ -590,7 +604,9 @@ class ActionPlanner:
|
||||
"5.think_level表示思考深度,0表示该回复不需要思考和回忆,1表示该回复需要进行回忆和思考\n"
|
||||
)
|
||||
if global_config.chat.llm_quote:
|
||||
reply_action_example += "6.如果要明确回复消息,使用quote,如果消息不多不需要明确回复,设置quote为false\n"
|
||||
reply_action_example += (
|
||||
"6.如果要明确回复消息,使用quote,如果消息不多不需要明确回复,设置quote为false\n"
|
||||
)
|
||||
reply_action_example += (
|
||||
'{{"action":"reply", "think_level":数值等级(0或1), '
|
||||
'"target_message_id":"消息id(m+数字)", '
|
||||
@@ -741,15 +757,21 @@ class ActionPlanner:
|
||||
|
||||
except Exception as req_e:
|
||||
logger.error(f"{self.log_prefix}LLM 请求执行失败: {req_e}")
|
||||
return f"LLM 请求失败,模型出现问题: {req_e}", [
|
||||
ActionPlannerInfo(
|
||||
action_type="no_reply",
|
||||
reasoning=f"LLM 请求失败,模型出现问题: {req_e}",
|
||||
action_data={},
|
||||
action_message=None,
|
||||
available_actions=available_actions,
|
||||
)
|
||||
], llm_content, llm_reasoning, llm_duration_ms
|
||||
return (
|
||||
f"LLM 请求失败,模型出现问题: {req_e}",
|
||||
[
|
||||
ActionPlannerInfo(
|
||||
action_type="no_reply",
|
||||
reasoning=f"LLM 请求失败,模型出现问题: {req_e}",
|
||||
action_data={},
|
||||
action_message=None,
|
||||
available_actions=available_actions,
|
||||
)
|
||||
],
|
||||
llm_content,
|
||||
llm_reasoning,
|
||||
llm_duration_ms,
|
||||
)
|
||||
|
||||
# 解析LLM响应
|
||||
extracted_reasoning = ""
|
||||
|
||||
@@ -1071,7 +1071,6 @@ class DefaultReplyer:
|
||||
chat_target_2_prompt = prompt_manager.get_prompt("chat_target_group2")
|
||||
chat_target_2 = await prompt_manager.render_prompt(chat_target_2_prompt)
|
||||
|
||||
|
||||
# 根据配置构建最终的 reply_style:支持 multiple_reply_style 按概率随机替换
|
||||
reply_style = global_config.personality.reply_style
|
||||
multi_styles = global_config.personality.multiple_reply_style
|
||||
|
||||
@@ -26,6 +26,7 @@ from src.chat.utils.chat_message_builder import (
|
||||
)
|
||||
from src.bw_learner.expression_selector import expression_selector
|
||||
from src.plugin_system.apis.message_api import translate_pid_to_description
|
||||
|
||||
# from src.memory_system.memory_activator import MemoryActivator
|
||||
from src.person_info.person_info import Person, is_person_known
|
||||
from src.plugin_system.base.component_types import ActionInfo, EventType
|
||||
@@ -807,7 +808,7 @@ class PrivateReplyer:
|
||||
reply_style = global_config.personality.reply_style
|
||||
|
||||
# 使用统一的 is_bot_self 函数判断是否是机器人自己(支持多平台,包括 WebUI)
|
||||
|
||||
|
||||
if is_bot_self(platform, user_id):
|
||||
prompt_template = prompt_manager.get_prompt("private_replyer_self")
|
||||
prompt_template.add_context("target", target)
|
||||
|
||||
@@ -5,6 +5,7 @@ from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("common_utils")
|
||||
|
||||
|
||||
class TempMethodsExpression:
|
||||
"""用于临时存放一些方法的类"""
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ from src.common.database.database_model import ChatSession
|
||||
|
||||
from . import BaseDatabaseDataModel
|
||||
|
||||
|
||||
class MaiChatSession(BaseDatabaseDataModel[ChatSession]):
|
||||
def __init__(self, session_id: str, platform: str, user_id: Optional[str] = None, group_id: Optional[str] = None):
|
||||
self.session_id = session_id
|
||||
@@ -33,4 +34,4 @@ class MaiChatSession(BaseDatabaseDataModel[ChatSession]):
|
||||
platform=self.platform,
|
||||
user_id=self.user_id,
|
||||
group_id=self.group_id,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from dataclasses import dataclass, field
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
|
||||
@@ -93,11 +93,15 @@ class Images(SQLModel, table=True):
|
||||
query_count: int = Field(default=0) # 被查询次数
|
||||
is_registered: bool = Field(default=False) # 是否已经注册
|
||||
is_banned: bool = Field(default=False) # 被手动禁用
|
||||
|
||||
|
||||
no_file_flag: bool = Field(default=False) # 文件不存在标记,如果为True表示文件已经不存在,仅保留描述字段
|
||||
|
||||
record_time: datetime = Field(default_factory=datetime.now, sa_column=Column(DateTime, index=True)) # 记录时间(数据库记录被创建的时间)
|
||||
register_time: Optional[datetime] = Field(default=None, sa_column=Column(DateTime, nullable=True)) # 注册时间(被注册为可用表情包的时间)
|
||||
record_time: datetime = Field(
|
||||
default_factory=datetime.now, sa_column=Column(DateTime, index=True)
|
||||
) # 记录时间(数据库记录被创建的时间)
|
||||
register_time: Optional[datetime] = Field(
|
||||
default=None, sa_column=Column(DateTime, nullable=True)
|
||||
) # 注册时间(被注册为可用表情包的时间)
|
||||
last_used_time: Optional[datetime] = Field(default=None, sa_column=Column(DateTime, nullable=True)) # 上次使用时间
|
||||
|
||||
vlm_processed: bool = Field(default=False) # 是否已经过VLM处理
|
||||
@@ -171,7 +175,9 @@ class Expression(SQLModel, table=True):
|
||||
|
||||
content_list: str # 内容列表,JSON格式存储
|
||||
count: int = Field(default=0) # 使用次数
|
||||
last_active_time: datetime = Field(default_factory=datetime.now, sa_column=Column(DateTime, index=True)) # 上次使用时间
|
||||
last_active_time: datetime = Field(
|
||||
default_factory=datetime.now, sa_column=Column(DateTime, index=True)
|
||||
) # 上次使用时间
|
||||
create_time: datetime = Field(default_factory=datetime.now, sa_column=Column(DateTime)) # 创建时间
|
||||
session_id: Optional[str] = Field(default=None, max_length=255, nullable=True) # 会话ID,区分是否为全局表达方式
|
||||
|
||||
@@ -232,8 +238,12 @@ class ThinkingQuestion(SQLModel, table=True):
|
||||
answer: Optional[str] = Field(default=None, nullable=True) # 问题答案
|
||||
|
||||
thinking_steps: Optional[str] = Field(default=None, nullable=True) # 思考步骤,JSON格式存储
|
||||
created_timestamp: datetime = Field(default_factory=datetime.now, sa_column=Column(DateTime, index=True)) # 创建时间
|
||||
updated_timestamp: datetime = Field(default_factory=datetime.now, sa_column=Column(DateTime, index=True)) # 最后更新时间
|
||||
created_timestamp: datetime = Field(
|
||||
default_factory=datetime.now, sa_column=Column(DateTime, index=True)
|
||||
) # 创建时间
|
||||
updated_timestamp: datetime = Field(
|
||||
default_factory=datetime.now, sa_column=Column(DateTime, index=True)
|
||||
) # 最后更新时间
|
||||
|
||||
|
||||
class BinaryData(SQLModel, table=True):
|
||||
@@ -272,7 +282,9 @@ class PersonInfo(SQLModel, table=True):
|
||||
|
||||
# 认识次数和时间
|
||||
know_counts: int = Field(default=0) # 认识次数
|
||||
first_known_time: Optional[datetime] = Field(default=None, sa_column=Column(DateTime, nullable=True)) # 首次认识时间
|
||||
first_known_time: Optional[datetime] = Field(
|
||||
default=None, sa_column=Column(DateTime, nullable=True)
|
||||
) # 首次认识时间
|
||||
last_known_time: Optional[datetime] = Field(default=None, sa_column=Column(DateTime, nullable=True)) # 最后认识时间
|
||||
|
||||
|
||||
@@ -285,8 +297,12 @@ class ChatSession(SQLModel, table=True):
|
||||
|
||||
session_id: str = Field(unique=True, index=True, max_length=255) # 聊天会话ID
|
||||
|
||||
created_timestamp: datetime = Field(default_factory=datetime.now, sa_column=Column(DateTime, index=True)) # 创建时间
|
||||
last_active_timestamp: datetime = Field(default_factory=datetime.now, sa_column=Column(DateTime, index=True)) # 最后活跃时间
|
||||
created_timestamp: datetime = Field(
|
||||
default_factory=datetime.now, sa_column=Column(DateTime, index=True)
|
||||
) # 创建时间
|
||||
last_active_timestamp: datetime = Field(
|
||||
default_factory=datetime.now, sa_column=Column(DateTime, index=True)
|
||||
) # 最后活跃时间
|
||||
|
||||
# 身份元数据
|
||||
user_id: Optional[str] = Field(index=True, max_length=255, nullable=True) # 用户ID
|
||||
|
||||
@@ -221,5 +221,7 @@ if not supports_truecolor():
|
||||
CONVERTED_MODULE_COLORS[name] = escape_str
|
||||
else:
|
||||
for name, (hex_fore_color, hex_back_color, bold) in MODULE_COLORS.items():
|
||||
escape_str = rgb_pair_to_ansi_truecolor(hex_to_rgb(hex_fore_color), hex_to_rgb(hex_back_color) if hex_back_color else None, bold)
|
||||
CONVERTED_MODULE_COLORS[name] = escape_str
|
||||
escape_str = rgb_pair_to_ansi_truecolor(
|
||||
hex_to_rgb(hex_fore_color), hex_to_rgb(hex_back_color) if hex_back_color else None, bold
|
||||
)
|
||||
CONVERTED_MODULE_COLORS[name] = escape_str
|
||||
|
||||
@@ -9,6 +9,7 @@ from .server import get_global_server
|
||||
|
||||
global_api = None
|
||||
|
||||
|
||||
def get_global_api() -> MessageServer: # sourcery skip: extract-method
|
||||
"""获取全局MessageServer实例"""
|
||||
global global_api
|
||||
@@ -80,12 +81,12 @@ def get_global_api() -> MessageServer: # sourcery skip: extract-method
|
||||
api_logger.warning(f"Rejected connection with invalid API Key: {api_key}")
|
||||
return False
|
||||
|
||||
server_config.on_auth = auth_handler # type: ignore # maim_message库写错类型了
|
||||
server_config.on_auth = auth_handler # type: ignore # maim_message库写错类型了
|
||||
|
||||
# 3. Setup Message Bridge
|
||||
# Initialize refined route map if not exists
|
||||
if not hasattr(global_api, "platform_map"):
|
||||
global_api.platform_map = {} # type: ignore # 不知道这是什么神奇写法
|
||||
global_api.platform_map = {} # type: ignore # 不知道这是什么神奇写法
|
||||
|
||||
async def bridge_message_handler(message: APIMessageBase, metadata: dict):
|
||||
# 使用 MessageConverter 转换 APIMessageBase 到 Legacy MessageBase
|
||||
@@ -108,7 +109,7 @@ def get_global_api() -> MessageServer: # sourcery skip: extract-method
|
||||
api_logger.debug(f"Bridge received: api_key='{api_key}', platform='{platform}'")
|
||||
|
||||
if platform:
|
||||
global_api.platform_map[platform] = api_key # type: ignore
|
||||
global_api.platform_map[platform] = api_key # type: ignore
|
||||
api_logger.info(f"Updated platform_map: {platform} -> {api_key}")
|
||||
except Exception as e:
|
||||
api_logger.warning(f"Failed to update platform map: {e}")
|
||||
@@ -117,21 +118,21 @@ def get_global_api() -> MessageServer: # sourcery skip: extract-method
|
||||
if "raw_message" not in msg_dict:
|
||||
msg_dict["raw_message"] = None
|
||||
|
||||
await global_api.process_message(msg_dict) # type: ignore
|
||||
await global_api.process_message(msg_dict) # type: ignore
|
||||
|
||||
server_config.on_message = bridge_message_handler # type: ignore # maim_message库写错类型了
|
||||
server_config.on_message = bridge_message_handler # type: ignore # maim_message库写错类型了
|
||||
|
||||
# 3.5. Register custom message handlers (bridge to Legacy handlers)
|
||||
# message_id_echo: handles message ID echo from adapters
|
||||
# 兼容新旧两个版本的 maim_message:
|
||||
# - 旧版: handler(payload)
|
||||
# - 新版: handler(payload, metadata)
|
||||
async def custom_message_id_echo_handler(payload: dict, metadata: dict = None): # type: ignore
|
||||
async def custom_message_id_echo_handler(payload: dict, metadata: dict = None): # type: ignore
|
||||
# Bridge to the Legacy custom handler registered in main.py
|
||||
try:
|
||||
# The Legacy handler expects the payload format directly
|
||||
if hasattr(global_api, "_custom_message_handlers"):
|
||||
handler = global_api._custom_message_handlers.get("message_id_echo") # type: ignore # 已经不知道这是什么了
|
||||
handler = global_api._custom_message_handlers.get("message_id_echo") # type: ignore # 已经不知道这是什么了
|
||||
if handler:
|
||||
await handler(payload)
|
||||
api_logger.debug(f"Processed message_id_echo: {payload}")
|
||||
@@ -140,7 +141,7 @@ def get_global_api() -> MessageServer: # sourcery skip: extract-method
|
||||
except Exception as e:
|
||||
api_logger.warning(f"Failed to process message_id_echo: {e}")
|
||||
|
||||
server_config.register_custom_handler("message_id_echo", custom_message_id_echo_handler) # type: ignore # maim_message库写错类型了
|
||||
server_config.register_custom_handler("message_id_echo", custom_message_id_echo_handler) # type: ignore # maim_message库写错类型了
|
||||
|
||||
# 4. Initialize Server
|
||||
extra_server = WebSocketServer(config=server_config)
|
||||
@@ -167,7 +168,7 @@ def get_global_api() -> MessageServer: # sourcery skip: extract-method
|
||||
global_api.stop = patched_stop
|
||||
|
||||
# Attach for reference
|
||||
global_api.extra_server = extra_server # type: ignore # 这是什么
|
||||
global_api.extra_server = extra_server # type: ignore # 这是什么
|
||||
|
||||
except ImportError:
|
||||
get_logger("maim_message").error(
|
||||
|
||||
@@ -9,6 +9,7 @@ from src.common.database.database import get_db_session
|
||||
|
||||
logger = get_logger("file_utils")
|
||||
|
||||
|
||||
class FileUtils:
|
||||
@staticmethod
|
||||
def save_binary_to_file(file_path: Path, data: bytes):
|
||||
@@ -35,7 +36,7 @@ class FileUtils:
|
||||
except Exception as e:
|
||||
logger.error(f"保存文件 {file_path} 失败: {e}")
|
||||
raise e
|
||||
|
||||
|
||||
@staticmethod
|
||||
def get_file_path_by_hash(data_hash: str) -> Path:
|
||||
"""
|
||||
@@ -52,4 +53,4 @@ class FileUtils:
|
||||
if binary_data := session.exec(statement).first():
|
||||
return Path(binary_data.full_path)
|
||||
else:
|
||||
raise FileNotFoundError(f"未找到哈希值为 {data_hash} 的数据文件记录")
|
||||
raise FileNotFoundError(f"未找到哈希值为 {data_hash} 的数据文件记录")
|
||||
|
||||
@@ -278,4 +278,3 @@ def try_migrate_legacy_bot_config_dict(data: dict[str, Any]) -> MigrationResult:
|
||||
|
||||
reason = ",".join(reasons)
|
||||
return MigrationResult(data=data, migrated=migrated_any, reason=reason)
|
||||
|
||||
|
||||
@@ -54,8 +54,6 @@ async def generate_dream_summary(
|
||||
) -> None:
|
||||
"""生成梦境总结,输出到日志,并根据配置可选地推送给指定用户"""
|
||||
try:
|
||||
|
||||
|
||||
# 第一步:建立工具调用结果映射 (call_id -> result)
|
||||
tool_results_map: dict[str, str] = {}
|
||||
for msg in conversation_messages:
|
||||
|
||||
@@ -4,4 +4,3 @@ dream agent 工具实现模块。
|
||||
每个工具的具体实现放在独立文件中,通过 make_xxx(chat_id) 工厂函数
|
||||
生成绑定到特定 chat_id 的协程函数,由 dream_agent.init_dream_tools 统一注册。
|
||||
"""
|
||||
|
||||
|
||||
@@ -63,4 +63,3 @@ def make_create_chat_history(chat_id: str):
|
||||
return f"create_chat_history 执行失败: {e}"
|
||||
|
||||
return create_chat_history
|
||||
|
||||
|
||||
@@ -23,4 +23,3 @@ def make_delete_chat_history(chat_id: str): # chat_id 目前未直接使用,
|
||||
return f"delete_chat_history 执行失败: {e}"
|
||||
|
||||
return delete_chat_history
|
||||
|
||||
|
||||
@@ -23,4 +23,3 @@ def make_delete_jargon(chat_id: str): # chat_id 目前未直接使用,预留
|
||||
return f"delete_jargon 执行失败: {e}"
|
||||
|
||||
return delete_jargon
|
||||
|
||||
|
||||
@@ -14,4 +14,3 @@ def make_finish_maintenance(chat_id: str): # chat_id 目前未直接使用,
|
||||
return msg
|
||||
|
||||
return finish_maintenance
|
||||
|
||||
|
||||
@@ -41,4 +41,3 @@ def make_get_chat_history_detail(chat_id: str): # chat_id 目前未直接使用
|
||||
return f"get_chat_history_detail 执行失败: {e}"
|
||||
|
||||
return get_chat_history_detail
|
||||
|
||||
|
||||
@@ -212,4 +212,3 @@ def make_search_chat_history(chat_id: str):
|
||||
return f"search_chat_history 执行失败: {e}"
|
||||
|
||||
return search_chat_history
|
||||
|
||||
|
||||
@@ -46,4 +46,3 @@ def make_update_chat_history(chat_id: str): # chat_id 目前未直接使用,
|
||||
return f"update_chat_history 执行失败: {e}"
|
||||
|
||||
return update_chat_history
|
||||
|
||||
|
||||
@@ -49,4 +49,3 @@ def make_update_jargon(chat_id: str): # chat_id 目前未直接使用,预留
|
||||
return f"update_jargon 执行失败: {e}"
|
||||
|
||||
return update_jargon
|
||||
|
||||
|
||||
@@ -458,8 +458,8 @@ def _default_normal_response_parser(
|
||||
if not isinstance(arguments, dict):
|
||||
# 此时为了调试方便,建议打印出 arguments 的类型
|
||||
raise RespParseException(
|
||||
resp,
|
||||
f"响应解析失败,工具调用参数无法解析为字典类型 type={type(arguments)} arguments={arguments}"
|
||||
resp,
|
||||
f"响应解析失败,工具调用参数无法解析为字典类型 type={type(arguments)} arguments={arguments}",
|
||||
)
|
||||
api_response.tool_calls.append(ToolCall(call.id, call.function.name, arguments))
|
||||
except json.JSONDecodeError as e:
|
||||
|
||||
@@ -2,7 +2,7 @@ import time
|
||||
import json
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Any, Optional, Tuple, Callable, cast
|
||||
from typing import List, Dict, Any, Optional, Tuple, Callable
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
|
||||
@@ -105,25 +105,27 @@ async def search_chat_history(
|
||||
# 检查参数
|
||||
if not keyword and not participant and not start_time and not end_time:
|
||||
return "未指定查询参数(需要提供keyword、participant、start_time或end_time之一)"
|
||||
|
||||
|
||||
# 解析时间参数
|
||||
start_timestamp = None
|
||||
end_timestamp = None
|
||||
|
||||
|
||||
if start_time:
|
||||
try:
|
||||
from src.memory_system.memory_utils import parse_datetime_to_timestamp
|
||||
|
||||
start_timestamp = parse_datetime_to_timestamp(start_time)
|
||||
except ValueError as e:
|
||||
return f"开始时间格式错误: {str(e)},支持格式如:'2025-01-01' 或 '2025-01-01 12:00:00' 或 '2025/01/01'"
|
||||
|
||||
|
||||
if end_time:
|
||||
try:
|
||||
from src.memory_system.memory_utils import parse_datetime_to_timestamp
|
||||
|
||||
end_timestamp = parse_datetime_to_timestamp(end_time)
|
||||
except ValueError as e:
|
||||
return f"结束时间格式错误: {str(e)},支持格式如:'2025-01-01' 或 '2025-01-01 12:00:00' 或 '2025/01/01'"
|
||||
|
||||
|
||||
# 验证时间范围
|
||||
if start_timestamp and end_timestamp and start_timestamp > end_timestamp:
|
||||
return "开始时间不能晚于结束时间"
|
||||
@@ -158,23 +160,20 @@ async def search_chat_history(
|
||||
f"search_chat_history 当前聊天流在黑名单中,强制使用本地查询,chat_id={chat_id}, keyword={keyword}, participant={participant}"
|
||||
)
|
||||
query = ChatHistory.select().where(ChatHistory.chat_id == chat_id)
|
||||
|
||||
|
||||
# 添加时间过滤条件
|
||||
if start_timestamp is not None and end_timestamp is not None:
|
||||
# 查询指定时间段内的记录(记录的时间范围与查询时间段有交集)
|
||||
# 记录的开始时间在查询时间段内,或记录的结束时间在查询时间段内,或记录完全包含查询时间段
|
||||
query = query.where(
|
||||
(
|
||||
(ChatHistory.start_time >= start_timestamp)
|
||||
& (ChatHistory.start_time <= end_timestamp)
|
||||
(ChatHistory.start_time >= start_timestamp) & (ChatHistory.start_time <= end_timestamp)
|
||||
) # 记录开始时间在查询时间段内
|
||||
| (
|
||||
(ChatHistory.end_time >= start_timestamp)
|
||||
& (ChatHistory.end_time <= end_timestamp)
|
||||
(ChatHistory.end_time >= start_timestamp) & (ChatHistory.end_time <= end_timestamp)
|
||||
) # 记录结束时间在查询时间段内
|
||||
| (
|
||||
(ChatHistory.start_time <= start_timestamp)
|
||||
& (ChatHistory.end_time >= end_timestamp)
|
||||
(ChatHistory.start_time <= start_timestamp) & (ChatHistory.end_time >= end_timestamp)
|
||||
) # 记录完全包含查询时间段
|
||||
)
|
||||
logger.debug(
|
||||
@@ -302,7 +301,7 @@ async def search_chat_history(
|
||||
time_desc = f"时间<='{end_str}'"
|
||||
if time_desc:
|
||||
conditions.append(time_desc)
|
||||
|
||||
|
||||
if conditions:
|
||||
conditions_str = "且".join(conditions)
|
||||
return f"未找到满足条件({conditions_str})的聊天记录"
|
||||
|
||||
@@ -30,7 +30,7 @@ async def query_words(chat_id: str, words: str) -> str:
|
||||
if separator in words:
|
||||
words_list = [w.strip() for w in words.split(separator) if w.strip()]
|
||||
break
|
||||
|
||||
|
||||
# 如果没有找到分隔符,整个字符串作为一个词语
|
||||
if not words_list:
|
||||
words_list = [words.strip()]
|
||||
@@ -76,4 +76,3 @@ def register_tool():
|
||||
],
|
||||
execute_func=query_words,
|
||||
)
|
||||
|
||||
|
||||
@@ -123,7 +123,7 @@ async def generate_reply(
|
||||
# 如果 reply_time_point 未传入,设置为当前时间戳
|
||||
if reply_time_point is None:
|
||||
reply_time_point = time.time()
|
||||
|
||||
|
||||
# 获取回复器
|
||||
logger.debug("[GeneratorAPI] 开始生成回复")
|
||||
replyer = get_replyer(chat_stream, chat_id, request_type=request_type)
|
||||
|
||||
@@ -558,7 +558,9 @@ class PluginBase(ABC):
|
||||
if version_spec:
|
||||
is_ok, msg = self._is_version_spec_satisfied(dep_version, version_spec)
|
||||
if not is_ok:
|
||||
logger.error(f"{self.log_prefix} 依赖插件版本不满足: {dep_name} {version_spec}, 当前版本={dep_version} ({msg})")
|
||||
logger.error(
|
||||
f"{self.log_prefix} 依赖插件版本不满足: {dep_name} {version_spec}, 当前版本={dep_version} ({msg})"
|
||||
)
|
||||
return False
|
||||
|
||||
if min_version or max_version:
|
||||
|
||||
@@ -751,9 +751,7 @@ class ComponentRegistry:
|
||||
"enabled_plugins": len([p for p in self._plugins.values() if p.enabled]),
|
||||
"workflow_steps": workflow_step_count,
|
||||
"enabled_workflow_steps": enabled_workflow_step_count,
|
||||
"workflow_steps_by_stage": {
|
||||
stage.value: len(steps) for stage, steps in self._workflow_steps.items()
|
||||
},
|
||||
"workflow_steps_by_stage": {stage.value: len(steps) for stage, steps in self._workflow_steps.items()},
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -429,7 +429,9 @@ class PluginManager:
|
||||
|
||||
def _resolve_plugin_load_order(self, dependency_graph: Dict[str, Set[str]]) -> Tuple[List[str], Set[str]]:
|
||||
"""根据依赖图计算加载顺序,并检测循环依赖。"""
|
||||
indegree: Dict[str, int] = {plugin_name: len(dependencies) for plugin_name, dependencies in dependency_graph.items()}
|
||||
indegree: Dict[str, int] = {
|
||||
plugin_name: len(dependencies) for plugin_name, dependencies in dependency_graph.items()
|
||||
}
|
||||
reverse_graph: Dict[str, Set[str]] = {plugin_name: set() for plugin_name in dependency_graph}
|
||||
|
||||
for plugin_name, dependencies in dependency_graph.items():
|
||||
|
||||
@@ -55,7 +55,9 @@ class PluginServiceRegistry:
|
||||
full_name = self._resolve_full_name(service_name, plugin_name)
|
||||
return self._service_handlers.get(full_name) if full_name else None
|
||||
|
||||
def list_services(self, plugin_name: Optional[str] = None, enabled_only: bool = False) -> Dict[str, PluginServiceInfo]:
|
||||
def list_services(
|
||||
self, plugin_name: Optional[str] = None, enabled_only: bool = False
|
||||
) -> Dict[str, PluginServiceInfo]:
|
||||
"""列出插件服务。"""
|
||||
services = self._services.copy()
|
||||
if plugin_name:
|
||||
@@ -120,7 +122,12 @@ class PluginServiceRegistry:
|
||||
target_name = f"{plugin_name}.{service_name}" if plugin_name and "." not in service_name else service_name
|
||||
raise ValueError(f"插件服务未注册: {target_name}")
|
||||
|
||||
if "." not in service_name and plugin_name is None and caller_plugin and service_info.plugin_name != caller_plugin:
|
||||
if (
|
||||
"." not in service_name
|
||||
and plugin_name is None
|
||||
and caller_plugin
|
||||
and service_info.plugin_name != caller_plugin
|
||||
):
|
||||
raise PermissionError("跨插件服务调用必须使用完整服务名或显式指定plugin_name")
|
||||
|
||||
if not self._is_call_authorized(service_info, caller_plugin):
|
||||
@@ -153,7 +160,9 @@ class PluginServiceRegistry:
|
||||
allowed_callers = {caller.strip() for caller in service_info.allowed_callers if caller.strip()}
|
||||
return "*" in allowed_callers or caller_plugin in allowed_callers
|
||||
|
||||
def _validate_input_contract(self, service_info: PluginServiceInfo, args: tuple[Any, ...], kwargs: dict[str, Any]) -> None:
|
||||
def _validate_input_contract(
|
||||
self, service_info: PluginServiceInfo, args: tuple[Any, ...], kwargs: dict[str, Any]
|
||||
) -> None:
|
||||
"""校验服务入参契约。"""
|
||||
schema = service_info.params_schema
|
||||
if not schema:
|
||||
|
||||
@@ -96,7 +96,9 @@ class WorkflowEngine:
|
||||
except Exception as e:
|
||||
workflow_context.timings[stage_key] = time.perf_counter() - stage_start
|
||||
workflow_context.errors.append(f"{stage_key}: {e}")
|
||||
logger.error(f"[trace_id={workflow_context.trace_id}] Workflow阶段 {stage_key} 执行异常: {e}", exc_info=True)
|
||||
logger.error(
|
||||
f"[trace_id={workflow_context.trace_id}] Workflow阶段 {stage_key} 执行异常: {e}", exc_info=True
|
||||
)
|
||||
self._execution_history[workflow_context.trace_id]["status"] = "failed"
|
||||
self._execution_history[workflow_context.trace_id]["errors"] = workflow_context.errors.copy()
|
||||
return (
|
||||
@@ -195,7 +197,9 @@ class WorkflowEngine:
|
||||
except Exception as e:
|
||||
context.timings[step_timing_key] = time.perf_counter() - step_start
|
||||
context.errors.append(f"{step_info.full_name}: {e}")
|
||||
logger.error(f"[trace_id={context.trace_id}] Workflow step {step_info.full_name} 执行异常: {e}", exc_info=True)
|
||||
logger.error(
|
||||
f"[trace_id={context.trace_id}] Workflow step {step_info.full_name} 执行异常: {e}", exc_info=True
|
||||
)
|
||||
return WorkflowStepResult(
|
||||
status="failed",
|
||||
return_message=str(e),
|
||||
|
||||
@@ -117,7 +117,7 @@ class PromptManager:
|
||||
def add_context_construct_function(self, name: str, func: Callable[[str], str | Coroutine[Any, Any, str]]) -> None:
|
||||
"""
|
||||
添加一个上下文构造函数
|
||||
|
||||
|
||||
Args:
|
||||
name (str): 上下文名称
|
||||
func (Callable[[str], str | Coroutine[Any, Any, str]]): 构造函数,接受 Prompt 名称作为参数,返回字符串或返回字符串的协程
|
||||
@@ -144,7 +144,7 @@ class PromptManager:
|
||||
def get_prompt(self, prompt_name: str) -> Prompt:
|
||||
"""
|
||||
获取指定名称的 Prompt 实例的克隆
|
||||
|
||||
|
||||
Args:
|
||||
prompt_name (str): 要获取的 Prompt 名称
|
||||
Returns:
|
||||
@@ -161,7 +161,7 @@ class PromptManager:
|
||||
async def render_prompt(self, prompt: Prompt) -> str:
|
||||
"""
|
||||
渲染一个 Prompt 实例
|
||||
|
||||
|
||||
Args:
|
||||
prompt (Prompt): 要渲染的 Prompt 实例
|
||||
Returns:
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
2. 日志列表使用文件名解析时间戳,只在需要时读取完整内容
|
||||
3. 详情按需加载
|
||||
"""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Optional
|
||||
@@ -21,6 +22,7 @@ PLAN_LOG_DIR = Path("logs/plan")
|
||||
|
||||
class ChatSummary(BaseModel):
|
||||
"""聊天摘要 - 轻量级,不读取文件内容"""
|
||||
|
||||
chat_id: str
|
||||
plan_count: int
|
||||
latest_timestamp: float
|
||||
@@ -29,6 +31,7 @@ class ChatSummary(BaseModel):
|
||||
|
||||
class PlanLogSummary(BaseModel):
|
||||
"""规划日志摘要"""
|
||||
|
||||
chat_id: str
|
||||
timestamp: float
|
||||
filename: str
|
||||
@@ -41,6 +44,7 @@ class PlanLogSummary(BaseModel):
|
||||
|
||||
class PlanLogDetail(BaseModel):
|
||||
"""规划日志详情"""
|
||||
|
||||
type: str
|
||||
chat_id: str
|
||||
timestamp: float
|
||||
@@ -54,6 +58,7 @@ class PlanLogDetail(BaseModel):
|
||||
|
||||
class PlannerOverview(BaseModel):
|
||||
"""规划器总览 - 轻量级统计"""
|
||||
|
||||
total_chats: int
|
||||
total_plans: int
|
||||
chats: List[ChatSummary]
|
||||
@@ -61,6 +66,7 @@ class PlannerOverview(BaseModel):
|
||||
|
||||
class PaginatedChatLogs(BaseModel):
|
||||
"""分页的聊天日志列表"""
|
||||
|
||||
data: List[PlanLogSummary]
|
||||
total: int
|
||||
page: int
|
||||
@@ -71,7 +77,7 @@ class PaginatedChatLogs(BaseModel):
|
||||
def parse_timestamp_from_filename(filename: str) -> float:
|
||||
"""从文件名解析时间戳: 1766497488220_af92bdb1.json -> 1766497488.220"""
|
||||
try:
|
||||
timestamp_str = filename.split('_')[0]
|
||||
timestamp_str = filename.split("_")[0]
|
||||
# 时间戳是毫秒级,需要转换为秒
|
||||
return float(timestamp_str) / 1000
|
||||
except (ValueError, IndexError):
|
||||
@@ -86,41 +92,39 @@ async def get_planner_overview():
|
||||
"""
|
||||
if not PLAN_LOG_DIR.exists():
|
||||
return PlannerOverview(total_chats=0, total_plans=0, chats=[])
|
||||
|
||||
|
||||
chats = []
|
||||
total_plans = 0
|
||||
|
||||
|
||||
for chat_dir in PLAN_LOG_DIR.iterdir():
|
||||
if not chat_dir.is_dir():
|
||||
continue
|
||||
|
||||
|
||||
# 只统计json文件数量
|
||||
json_files = list(chat_dir.glob("*.json"))
|
||||
plan_count = len(json_files)
|
||||
total_plans += plan_count
|
||||
|
||||
|
||||
if plan_count == 0:
|
||||
continue
|
||||
|
||||
|
||||
# 从文件名获取最新时间戳
|
||||
latest_file = max(json_files, key=lambda f: parse_timestamp_from_filename(f.name))
|
||||
latest_timestamp = parse_timestamp_from_filename(latest_file.name)
|
||||
|
||||
chats.append(ChatSummary(
|
||||
chat_id=chat_dir.name,
|
||||
plan_count=plan_count,
|
||||
latest_timestamp=latest_timestamp,
|
||||
latest_filename=latest_file.name
|
||||
))
|
||||
|
||||
|
||||
chats.append(
|
||||
ChatSummary(
|
||||
chat_id=chat_dir.name,
|
||||
plan_count=plan_count,
|
||||
latest_timestamp=latest_timestamp,
|
||||
latest_filename=latest_file.name,
|
||||
)
|
||||
)
|
||||
|
||||
# 按最新时间戳排序
|
||||
chats.sort(key=lambda x: x.latest_timestamp, reverse=True)
|
||||
|
||||
return PlannerOverview(
|
||||
total_chats=len(chats),
|
||||
total_plans=total_plans,
|
||||
chats=chats
|
||||
)
|
||||
|
||||
return PlannerOverview(total_chats=len(chats), total_plans=total_plans, chats=chats)
|
||||
|
||||
|
||||
@router.get("/chat/{chat_id}/logs", response_model=PaginatedChatLogs)
|
||||
@@ -128,7 +132,7 @@ async def get_chat_plan_logs(
|
||||
chat_id: str,
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(20, ge=1, le=100),
|
||||
search: Optional[str] = Query(None, description="搜索关键词,匹配提示词内容")
|
||||
search: Optional[str] = Query(None, description="搜索关键词,匹配提示词内容"),
|
||||
):
|
||||
"""
|
||||
获取指定聊天的规划日志列表(分页)
|
||||
@@ -137,73 +141,69 @@ async def get_chat_plan_logs(
|
||||
"""
|
||||
chat_dir = PLAN_LOG_DIR / chat_id
|
||||
if not chat_dir.exists():
|
||||
return PaginatedChatLogs(
|
||||
data=[], total=0, page=page, page_size=page_size, chat_id=chat_id
|
||||
)
|
||||
|
||||
return PaginatedChatLogs(data=[], total=0, page=page, page_size=page_size, chat_id=chat_id)
|
||||
|
||||
# 先获取所有文件并按时间戳排序
|
||||
json_files = list(chat_dir.glob("*.json"))
|
||||
json_files.sort(key=lambda f: parse_timestamp_from_filename(f.name), reverse=True)
|
||||
|
||||
|
||||
# 如果有搜索关键词,需要过滤文件
|
||||
if search:
|
||||
search_lower = search.lower()
|
||||
filtered_files = []
|
||||
for log_file in json_files:
|
||||
try:
|
||||
with open(log_file, 'r', encoding='utf-8') as f:
|
||||
with open(log_file, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
prompt = data.get('prompt', '')
|
||||
prompt = data.get("prompt", "")
|
||||
if search_lower in prompt.lower():
|
||||
filtered_files.append(log_file)
|
||||
except Exception:
|
||||
continue
|
||||
json_files = filtered_files
|
||||
|
||||
|
||||
total = len(json_files)
|
||||
|
||||
|
||||
# 分页 - 只读取当前页的文件
|
||||
offset = (page - 1) * page_size
|
||||
page_files = json_files[offset:offset + page_size]
|
||||
|
||||
page_files = json_files[offset : offset + page_size]
|
||||
|
||||
logs = []
|
||||
for log_file in page_files:
|
||||
try:
|
||||
with open(log_file, 'r', encoding='utf-8') as f:
|
||||
with open(log_file, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
reasoning = data.get('reasoning', '')
|
||||
actions = data.get('actions', [])
|
||||
action_types = [a.get('action_type', '') for a in actions if a.get('action_type')]
|
||||
logs.append(PlanLogSummary(
|
||||
chat_id=data.get('chat_id', chat_id),
|
||||
timestamp=data.get('timestamp', parse_timestamp_from_filename(log_file.name)),
|
||||
filename=log_file.name,
|
||||
action_count=len(actions),
|
||||
action_types=action_types,
|
||||
total_plan_ms=data.get('timing', {}).get('total_plan_ms', 0),
|
||||
llm_duration_ms=data.get('timing', {}).get('llm_duration_ms', 0),
|
||||
reasoning_preview=reasoning[:100] if reasoning else ''
|
||||
))
|
||||
reasoning = data.get("reasoning", "")
|
||||
actions = data.get("actions", [])
|
||||
action_types = [a.get("action_type", "") for a in actions if a.get("action_type")]
|
||||
logs.append(
|
||||
PlanLogSummary(
|
||||
chat_id=data.get("chat_id", chat_id),
|
||||
timestamp=data.get("timestamp", parse_timestamp_from_filename(log_file.name)),
|
||||
filename=log_file.name,
|
||||
action_count=len(actions),
|
||||
action_types=action_types,
|
||||
total_plan_ms=data.get("timing", {}).get("total_plan_ms", 0),
|
||||
llm_duration_ms=data.get("timing", {}).get("llm_duration_ms", 0),
|
||||
reasoning_preview=reasoning[:100] if reasoning else "",
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
# 文件读取失败时使用文件名信息
|
||||
logs.append(PlanLogSummary(
|
||||
chat_id=chat_id,
|
||||
timestamp=parse_timestamp_from_filename(log_file.name),
|
||||
filename=log_file.name,
|
||||
action_count=0,
|
||||
action_types=[],
|
||||
total_plan_ms=0,
|
||||
llm_duration_ms=0,
|
||||
reasoning_preview='[读取失败]'
|
||||
))
|
||||
|
||||
return PaginatedChatLogs(
|
||||
data=logs,
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
chat_id=chat_id
|
||||
)
|
||||
logs.append(
|
||||
PlanLogSummary(
|
||||
chat_id=chat_id,
|
||||
timestamp=parse_timestamp_from_filename(log_file.name),
|
||||
filename=log_file.name,
|
||||
action_count=0,
|
||||
action_types=[],
|
||||
total_plan_ms=0,
|
||||
llm_duration_ms=0,
|
||||
reasoning_preview="[读取失败]",
|
||||
)
|
||||
)
|
||||
|
||||
return PaginatedChatLogs(data=logs, total=total, page=page, page_size=page_size, chat_id=chat_id)
|
||||
|
||||
|
||||
@router.get("/log/{chat_id}/{filename}", response_model=PlanLogDetail)
|
||||
@@ -212,9 +212,9 @@ async def get_log_detail(chat_id: str, filename: str):
|
||||
log_file = PLAN_LOG_DIR / chat_id / filename
|
||||
if not log_file.exists():
|
||||
raise HTTPException(status_code=404, detail="日志文件不存在")
|
||||
|
||||
|
||||
try:
|
||||
with open(log_file, 'r', encoding='utf-8') as f:
|
||||
with open(log_file, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
return PlanLogDetail(**data)
|
||||
except Exception as e:
|
||||
@@ -223,11 +223,12 @@ async def get_log_detail(chat_id: str, filename: str):
|
||||
|
||||
# ========== 兼容旧接口 ==========
|
||||
|
||||
|
||||
@router.get("/stats")
|
||||
async def get_planner_stats():
|
||||
"""获取规划器统计信息 - 兼容旧接口"""
|
||||
overview = await get_planner_overview()
|
||||
|
||||
|
||||
# 获取最近10条计划的摘要
|
||||
recent_plans = []
|
||||
for chat in overview.chats[:5]: # 从最近5个聊天中获取
|
||||
@@ -236,17 +237,17 @@ async def get_planner_stats():
|
||||
recent_plans.extend(chat_logs.data)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
|
||||
# 按时间排序取前10
|
||||
recent_plans.sort(key=lambda x: x.timestamp, reverse=True)
|
||||
recent_plans = recent_plans[:10]
|
||||
|
||||
|
||||
return {
|
||||
"total_chats": overview.total_chats,
|
||||
"total_plans": overview.total_plans,
|
||||
"avg_plan_time_ms": 0,
|
||||
"avg_llm_time_ms": 0,
|
||||
"recent_plans": recent_plans
|
||||
"recent_plans": recent_plans,
|
||||
}
|
||||
|
||||
|
||||
@@ -258,44 +259,43 @@ async def get_chat_list():
|
||||
|
||||
|
||||
@router.get("/all-logs")
|
||||
async def get_all_logs(
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(20, ge=1, le=100)
|
||||
):
|
||||
async def get_all_logs(page: int = Query(1, ge=1), page_size: int = Query(20, ge=1, le=100)):
|
||||
"""获取所有规划日志 - 兼容旧接口"""
|
||||
if not PLAN_LOG_DIR.exists():
|
||||
return {"data": [], "total": 0, "page": page, "page_size": page_size}
|
||||
|
||||
|
||||
# 收集所有文件
|
||||
all_files = []
|
||||
for chat_dir in PLAN_LOG_DIR.iterdir():
|
||||
if chat_dir.is_dir():
|
||||
for log_file in chat_dir.glob("*.json"):
|
||||
all_files.append((chat_dir.name, log_file))
|
||||
|
||||
|
||||
# 按时间戳排序
|
||||
all_files.sort(key=lambda x: parse_timestamp_from_filename(x[1].name), reverse=True)
|
||||
|
||||
|
||||
total = len(all_files)
|
||||
offset = (page - 1) * page_size
|
||||
page_files = all_files[offset:offset + page_size]
|
||||
|
||||
page_files = all_files[offset : offset + page_size]
|
||||
|
||||
logs = []
|
||||
for chat_id, log_file in page_files:
|
||||
try:
|
||||
with open(log_file, 'r', encoding='utf-8') as f:
|
||||
with open(log_file, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
reasoning = data.get('reasoning', '')
|
||||
logs.append({
|
||||
"chat_id": data.get('chat_id', chat_id),
|
||||
"timestamp": data.get('timestamp', parse_timestamp_from_filename(log_file.name)),
|
||||
"filename": log_file.name,
|
||||
"action_count": len(data.get('actions', [])),
|
||||
"total_plan_ms": data.get('timing', {}).get('total_plan_ms', 0),
|
||||
"llm_duration_ms": data.get('timing', {}).get('llm_duration_ms', 0),
|
||||
"reasoning_preview": reasoning[:100] if reasoning else ''
|
||||
})
|
||||
reasoning = data.get("reasoning", "")
|
||||
logs.append(
|
||||
{
|
||||
"chat_id": data.get("chat_id", chat_id),
|
||||
"timestamp": data.get("timestamp", parse_timestamp_from_filename(log_file.name)),
|
||||
"filename": log_file.name,
|
||||
"action_count": len(data.get("actions", [])),
|
||||
"total_plan_ms": data.get("timing", {}).get("total_plan_ms", 0),
|
||||
"llm_duration_ms": data.get("timing", {}).get("llm_duration_ms", 0),
|
||||
"reasoning_preview": reasoning[:100] if reasoning else "",
|
||||
}
|
||||
)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
return {"data": logs, "total": total, "page": page, "page_size": page_size}
|
||||
|
||||
return {"data": logs, "total": total, "page": page, "page_size": page_size}
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
2. 日志列表使用文件名解析时间戳,只在需要时读取完整内容
|
||||
3. 详情按需加载
|
||||
"""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Optional
|
||||
@@ -21,6 +22,7 @@ REPLY_LOG_DIR = Path("logs/reply")
|
||||
|
||||
class ReplierChatSummary(BaseModel):
|
||||
"""聊天摘要 - 轻量级,不读取文件内容"""
|
||||
|
||||
chat_id: str
|
||||
reply_count: int
|
||||
latest_timestamp: float
|
||||
@@ -29,6 +31,7 @@ class ReplierChatSummary(BaseModel):
|
||||
|
||||
class ReplyLogSummary(BaseModel):
|
||||
"""回复日志摘要"""
|
||||
|
||||
chat_id: str
|
||||
timestamp: float
|
||||
filename: str
|
||||
@@ -41,6 +44,7 @@ class ReplyLogSummary(BaseModel):
|
||||
|
||||
class ReplyLogDetail(BaseModel):
|
||||
"""回复日志详情"""
|
||||
|
||||
type: str
|
||||
chat_id: str
|
||||
timestamp: float
|
||||
@@ -57,6 +61,7 @@ class ReplyLogDetail(BaseModel):
|
||||
|
||||
class ReplierOverview(BaseModel):
|
||||
"""回复器总览 - 轻量级统计"""
|
||||
|
||||
total_chats: int
|
||||
total_replies: int
|
||||
chats: List[ReplierChatSummary]
|
||||
@@ -64,6 +69,7 @@ class ReplierOverview(BaseModel):
|
||||
|
||||
class PaginatedReplyLogs(BaseModel):
|
||||
"""分页的回复日志列表"""
|
||||
|
||||
data: List[ReplyLogSummary]
|
||||
total: int
|
||||
page: int
|
||||
@@ -74,7 +80,7 @@ class PaginatedReplyLogs(BaseModel):
|
||||
def parse_timestamp_from_filename(filename: str) -> float:
|
||||
"""从文件名解析时间戳: 1766497488220_af92bdb1.json -> 1766497488.220"""
|
||||
try:
|
||||
timestamp_str = filename.split('_')[0]
|
||||
timestamp_str = filename.split("_")[0]
|
||||
# 时间戳是毫秒级,需要转换为秒
|
||||
return float(timestamp_str) / 1000
|
||||
except (ValueError, IndexError):
|
||||
@@ -89,41 +95,39 @@ async def get_replier_overview():
|
||||
"""
|
||||
if not REPLY_LOG_DIR.exists():
|
||||
return ReplierOverview(total_chats=0, total_replies=0, chats=[])
|
||||
|
||||
|
||||
chats = []
|
||||
total_replies = 0
|
||||
|
||||
|
||||
for chat_dir in REPLY_LOG_DIR.iterdir():
|
||||
if not chat_dir.is_dir():
|
||||
continue
|
||||
|
||||
|
||||
# 只统计json文件数量
|
||||
json_files = list(chat_dir.glob("*.json"))
|
||||
reply_count = len(json_files)
|
||||
total_replies += reply_count
|
||||
|
||||
|
||||
if reply_count == 0:
|
||||
continue
|
||||
|
||||
|
||||
# 从文件名获取最新时间戳
|
||||
latest_file = max(json_files, key=lambda f: parse_timestamp_from_filename(f.name))
|
||||
latest_timestamp = parse_timestamp_from_filename(latest_file.name)
|
||||
|
||||
chats.append(ReplierChatSummary(
|
||||
chat_id=chat_dir.name,
|
||||
reply_count=reply_count,
|
||||
latest_timestamp=latest_timestamp,
|
||||
latest_filename=latest_file.name
|
||||
))
|
||||
|
||||
|
||||
chats.append(
|
||||
ReplierChatSummary(
|
||||
chat_id=chat_dir.name,
|
||||
reply_count=reply_count,
|
||||
latest_timestamp=latest_timestamp,
|
||||
latest_filename=latest_file.name,
|
||||
)
|
||||
)
|
||||
|
||||
# 按最新时间戳排序
|
||||
chats.sort(key=lambda x: x.latest_timestamp, reverse=True)
|
||||
|
||||
return ReplierOverview(
|
||||
total_chats=len(chats),
|
||||
total_replies=total_replies,
|
||||
chats=chats
|
||||
)
|
||||
|
||||
return ReplierOverview(total_chats=len(chats), total_replies=total_replies, chats=chats)
|
||||
|
||||
|
||||
@router.get("/chat/{chat_id}/logs", response_model=PaginatedReplyLogs)
|
||||
@@ -131,7 +135,7 @@ async def get_chat_reply_logs(
|
||||
chat_id: str,
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(20, ge=1, le=100),
|
||||
search: Optional[str] = Query(None, description="搜索关键词,匹配提示词内容")
|
||||
search: Optional[str] = Query(None, description="搜索关键词,匹配提示词内容"),
|
||||
):
|
||||
"""
|
||||
获取指定聊天的回复日志列表(分页)
|
||||
@@ -140,71 +144,67 @@ async def get_chat_reply_logs(
|
||||
"""
|
||||
chat_dir = REPLY_LOG_DIR / chat_id
|
||||
if not chat_dir.exists():
|
||||
return PaginatedReplyLogs(
|
||||
data=[], total=0, page=page, page_size=page_size, chat_id=chat_id
|
||||
)
|
||||
|
||||
return PaginatedReplyLogs(data=[], total=0, page=page, page_size=page_size, chat_id=chat_id)
|
||||
|
||||
# 先获取所有文件并按时间戳排序
|
||||
json_files = list(chat_dir.glob("*.json"))
|
||||
json_files.sort(key=lambda f: parse_timestamp_from_filename(f.name), reverse=True)
|
||||
|
||||
|
||||
# 如果有搜索关键词,需要过滤文件
|
||||
if search:
|
||||
search_lower = search.lower()
|
||||
filtered_files = []
|
||||
for log_file in json_files:
|
||||
try:
|
||||
with open(log_file, 'r', encoding='utf-8') as f:
|
||||
with open(log_file, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
prompt = data.get('prompt', '')
|
||||
prompt = data.get("prompt", "")
|
||||
if search_lower in prompt.lower():
|
||||
filtered_files.append(log_file)
|
||||
except Exception:
|
||||
continue
|
||||
json_files = filtered_files
|
||||
|
||||
|
||||
total = len(json_files)
|
||||
|
||||
|
||||
# 分页 - 只读取当前页的文件
|
||||
offset = (page - 1) * page_size
|
||||
page_files = json_files[offset:offset + page_size]
|
||||
|
||||
page_files = json_files[offset : offset + page_size]
|
||||
|
||||
logs = []
|
||||
for log_file in page_files:
|
||||
try:
|
||||
with open(log_file, 'r', encoding='utf-8') as f:
|
||||
with open(log_file, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
output = data.get('output', '')
|
||||
logs.append(ReplyLogSummary(
|
||||
chat_id=data.get('chat_id', chat_id),
|
||||
timestamp=data.get('timestamp', parse_timestamp_from_filename(log_file.name)),
|
||||
filename=log_file.name,
|
||||
model=data.get('model', ''),
|
||||
success=data.get('success', True),
|
||||
llm_ms=data.get('timing', {}).get('llm_ms', 0),
|
||||
overall_ms=data.get('timing', {}).get('overall_ms', 0),
|
||||
output_preview=output[:100] if output else ''
|
||||
))
|
||||
output = data.get("output", "")
|
||||
logs.append(
|
||||
ReplyLogSummary(
|
||||
chat_id=data.get("chat_id", chat_id),
|
||||
timestamp=data.get("timestamp", parse_timestamp_from_filename(log_file.name)),
|
||||
filename=log_file.name,
|
||||
model=data.get("model", ""),
|
||||
success=data.get("success", True),
|
||||
llm_ms=data.get("timing", {}).get("llm_ms", 0),
|
||||
overall_ms=data.get("timing", {}).get("overall_ms", 0),
|
||||
output_preview=output[:100] if output else "",
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
# 文件读取失败时使用文件名信息
|
||||
logs.append(ReplyLogSummary(
|
||||
chat_id=chat_id,
|
||||
timestamp=parse_timestamp_from_filename(log_file.name),
|
||||
filename=log_file.name,
|
||||
model='',
|
||||
success=False,
|
||||
llm_ms=0,
|
||||
overall_ms=0,
|
||||
output_preview='[读取失败]'
|
||||
))
|
||||
|
||||
return PaginatedReplyLogs(
|
||||
data=logs,
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
chat_id=chat_id
|
||||
)
|
||||
logs.append(
|
||||
ReplyLogSummary(
|
||||
chat_id=chat_id,
|
||||
timestamp=parse_timestamp_from_filename(log_file.name),
|
||||
filename=log_file.name,
|
||||
model="",
|
||||
success=False,
|
||||
llm_ms=0,
|
||||
overall_ms=0,
|
||||
output_preview="[读取失败]",
|
||||
)
|
||||
)
|
||||
|
||||
return PaginatedReplyLogs(data=logs, total=total, page=page, page_size=page_size, chat_id=chat_id)
|
||||
|
||||
|
||||
@router.get("/log/{chat_id}/{filename}", response_model=ReplyLogDetail)
|
||||
@@ -213,23 +213,23 @@ async def get_reply_log_detail(chat_id: str, filename: str):
|
||||
log_file = REPLY_LOG_DIR / chat_id / filename
|
||||
if not log_file.exists():
|
||||
raise HTTPException(status_code=404, detail="日志文件不存在")
|
||||
|
||||
|
||||
try:
|
||||
with open(log_file, 'r', encoding='utf-8') as f:
|
||||
with open(log_file, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
return ReplyLogDetail(
|
||||
type=data.get('type', 'reply'),
|
||||
chat_id=data.get('chat_id', chat_id),
|
||||
timestamp=data.get('timestamp', 0),
|
||||
prompt=data.get('prompt', ''),
|
||||
output=data.get('output', ''),
|
||||
processed_output=data.get('processed_output', []),
|
||||
model=data.get('model', ''),
|
||||
reasoning=data.get('reasoning', ''),
|
||||
think_level=data.get('think_level', 0),
|
||||
timing=data.get('timing', {}),
|
||||
error=data.get('error'),
|
||||
success=data.get('success', True)
|
||||
type=data.get("type", "reply"),
|
||||
chat_id=data.get("chat_id", chat_id),
|
||||
timestamp=data.get("timestamp", 0),
|
||||
prompt=data.get("prompt", ""),
|
||||
output=data.get("output", ""),
|
||||
processed_output=data.get("processed_output", []),
|
||||
model=data.get("model", ""),
|
||||
reasoning=data.get("reasoning", ""),
|
||||
think_level=data.get("think_level", 0),
|
||||
timing=data.get("timing", {}),
|
||||
error=data.get("error"),
|
||||
success=data.get("success", True),
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"读取日志失败: {str(e)}")
|
||||
@@ -237,11 +237,12 @@ async def get_reply_log_detail(chat_id: str, filename: str):
|
||||
|
||||
# ========== 兼容接口 ==========
|
||||
|
||||
|
||||
@router.get("/stats")
|
||||
async def get_replier_stats():
|
||||
"""获取回复器统计信息"""
|
||||
overview = await get_replier_overview()
|
||||
|
||||
|
||||
# 获取最近10条回复的摘要
|
||||
recent_replies = []
|
||||
for chat in overview.chats[:5]: # 从最近5个聊天中获取
|
||||
@@ -250,15 +251,15 @@ async def get_replier_stats():
|
||||
recent_replies.extend(chat_logs.data)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
|
||||
# 按时间排序取前10
|
||||
recent_replies.sort(key=lambda x: x.timestamp, reverse=True)
|
||||
recent_replies = recent_replies[:10]
|
||||
|
||||
|
||||
return {
|
||||
"total_chats": overview.total_chats,
|
||||
"total_replies": overview.total_replies,
|
||||
"recent_replies": recent_replies
|
||||
"recent_replies": recent_replies,
|
||||
}
|
||||
|
||||
|
||||
@@ -266,4 +267,4 @@ async def get_replier_stats():
|
||||
async def get_replier_chat_list():
|
||||
"""获取所有聊天ID列表"""
|
||||
overview = await get_replier_overview()
|
||||
return [chat.chat_id for chat in overview.chats]
|
||||
return [chat.chat_id for chat in overview.chats]
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import Optional
|
||||
from fastapi import Depends, Cookie, Header, Request, HTTPException
|
||||
from .core import get_current_token, get_token_manager, check_auth_rate_limit, check_api_rate_limit
|
||||
from fastapi import Depends, Cookie, Header, Request
|
||||
from .core import get_current_token, get_token_manager, check_auth_rate_limit
|
||||
|
||||
|
||||
async def require_auth(
|
||||
|
||||
@@ -124,6 +124,7 @@ SCANNER_SPECIFIC_HEADERS = {
|
||||
# loose: 宽松模式(较宽松的检测,较高的频率限制)
|
||||
# basic: 基础模式(只记录恶意访问,不阻止,不限制请求数,不跟踪IP)
|
||||
|
||||
|
||||
# IP白名单配置(从配置文件读取,逗号分隔)
|
||||
# 支持格式:
|
||||
# - 精确IP:127.0.0.1, 192.168.1.100
|
||||
@@ -151,7 +152,7 @@ def _parse_allowed_ips(ip_string: str) -> list:
|
||||
ip_entry = ip_entry.strip() # 去除空格
|
||||
if not ip_entry:
|
||||
continue
|
||||
|
||||
|
||||
# 跳过注释行(以#开头)
|
||||
if ip_entry.startswith("#"):
|
||||
continue
|
||||
@@ -237,19 +238,21 @@ def _convert_wildcard_to_regex(wildcard_pattern: str) -> Optional[str]:
|
||||
def _get_anti_crawler_config():
|
||||
"""获取防爬虫配置"""
|
||||
from src.config.config import global_config
|
||||
|
||||
return {
|
||||
'mode': global_config.webui.anti_crawler_mode,
|
||||
'allowed_ips': _parse_allowed_ips(global_config.webui.allowed_ips),
|
||||
'trusted_proxies': _parse_allowed_ips(global_config.webui.trusted_proxies),
|
||||
'trust_xff': global_config.webui.trust_xff
|
||||
"mode": global_config.webui.anti_crawler_mode,
|
||||
"allowed_ips": _parse_allowed_ips(global_config.webui.allowed_ips),
|
||||
"trusted_proxies": _parse_allowed_ips(global_config.webui.trusted_proxies),
|
||||
"trust_xff": global_config.webui.trust_xff,
|
||||
}
|
||||
|
||||
|
||||
# 初始化配置(将在模块加载时执行)
|
||||
_config = _get_anti_crawler_config()
|
||||
ANTI_CRAWLER_MODE = _config['mode']
|
||||
ALLOWED_IPS = _config['allowed_ips']
|
||||
TRUSTED_PROXIES = _config['trusted_proxies']
|
||||
TRUST_XFF = _config['trust_xff']
|
||||
ANTI_CRAWLER_MODE = _config["mode"]
|
||||
ALLOWED_IPS = _config["allowed_ips"]
|
||||
TRUSTED_PROXIES = _config["trusted_proxies"]
|
||||
TRUST_XFF = _config["trust_xff"]
|
||||
|
||||
|
||||
def _get_mode_config(mode: str) -> dict:
|
||||
|
||||
@@ -17,36 +17,36 @@ _paragraph_store_cache = None
|
||||
|
||||
def _get_paragraph_store():
|
||||
"""延迟加载段落 embedding store(只读模式,轻量级)
|
||||
|
||||
|
||||
Returns:
|
||||
EmbeddingStore | None: 如果配置启用则返回store,否则返回None
|
||||
"""
|
||||
# 检查配置是否启用
|
||||
if not global_config.webui.enable_paragraph_content:
|
||||
return None
|
||||
|
||||
|
||||
global _paragraph_store_cache
|
||||
if _paragraph_store_cache is not None:
|
||||
return _paragraph_store_cache
|
||||
|
||||
|
||||
try:
|
||||
from src.chat.knowledge.embedding_store import EmbeddingStore
|
||||
import os
|
||||
|
||||
|
||||
# 获取数据路径
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
root_path = os.path.abspath(os.path.join(current_dir, "..", ".."))
|
||||
embedding_dir = os.path.join(root_path, "data/embedding")
|
||||
|
||||
|
||||
# 只加载段落 embedding store(轻量级)
|
||||
paragraph_store = EmbeddingStore(
|
||||
namespace="paragraph",
|
||||
dir_path=embedding_dir,
|
||||
max_workers=1, # 只读不需要多线程
|
||||
chunk_size=100
|
||||
chunk_size=100,
|
||||
)
|
||||
paragraph_store.load_from_file()
|
||||
|
||||
|
||||
_paragraph_store_cache = paragraph_store
|
||||
logger.info(f"成功加载段落 embedding store,包含 {len(paragraph_store.store)} 个段落")
|
||||
return paragraph_store
|
||||
@@ -57,10 +57,10 @@ def _get_paragraph_store():
|
||||
|
||||
def _get_paragraph_content(node_id: str) -> tuple[Optional[str], bool]:
|
||||
"""从 embedding store 获取段落完整内容
|
||||
|
||||
|
||||
Args:
|
||||
node_id: 段落节点ID,格式为 'paragraph-{hash}'
|
||||
|
||||
|
||||
Returns:
|
||||
tuple[str | None, bool]: (段落完整内容或None, 是否启用了功能)
|
||||
"""
|
||||
@@ -69,12 +69,12 @@ def _get_paragraph_content(node_id: str) -> tuple[Optional[str], bool]:
|
||||
if paragraph_store is None:
|
||||
# 功能未启用
|
||||
return None, False
|
||||
|
||||
|
||||
# 从 store 中获取完整内容
|
||||
paragraph_item = paragraph_store.store.get(node_id)
|
||||
if paragraph_item is not None:
|
||||
# paragraph_item 是 EmbeddingStoreItem,其 str 属性包含完整文本
|
||||
content: str = getattr(paragraph_item, 'str', '')
|
||||
content: str = getattr(paragraph_item, "str", "")
|
||||
if content:
|
||||
return content, True
|
||||
return None, True
|
||||
@@ -156,14 +156,18 @@ def _convert_graph_to_json(kg_manager) -> KnowledgeGraph:
|
||||
node_data = graph[node_id]
|
||||
# 节点类型: "ent" -> "entity", "pg" -> "paragraph"
|
||||
node_type = "entity" if ("type" in node_data and node_data["type"] == "ent") else "paragraph"
|
||||
|
||||
|
||||
# 对于段落节点,尝试从 embedding store 获取完整内容
|
||||
if node_type == "paragraph":
|
||||
full_content, _ = _get_paragraph_content(node_id)
|
||||
content = full_content if full_content is not None else (node_data["content"] if "content" in node_data else node_id)
|
||||
content = (
|
||||
full_content
|
||||
if full_content is not None
|
||||
else (node_data["content"] if "content" in node_data else node_id)
|
||||
)
|
||||
else:
|
||||
content = node_data["content"] if "content" in node_data else node_id
|
||||
|
||||
|
||||
create_time = node_data["create_time"] if "create_time" in node_data else None
|
||||
|
||||
nodes.append(KnowledgeNode(id=node_id, type=node_type, content=content, create_time=create_time))
|
||||
@@ -245,14 +249,18 @@ async def get_knowledge_graph(
|
||||
try:
|
||||
node_data = graph[node_id]
|
||||
node_type_val = "entity" if ("type" in node_data and node_data["type"] == "ent") else "paragraph"
|
||||
|
||||
|
||||
# 对于段落节点,尝试从 embedding store 获取完整内容
|
||||
if node_type_val == "paragraph":
|
||||
full_content, _ = _get_paragraph_content(node_id)
|
||||
content = full_content if full_content is not None else (node_data["content"] if "content" in node_data else node_id)
|
||||
content = (
|
||||
full_content
|
||||
if full_content is not None
|
||||
else (node_data["content"] if "content" in node_data else node_id)
|
||||
)
|
||||
else:
|
||||
content = node_data["content"] if "content" in node_data else node_id
|
||||
|
||||
|
||||
create_time = node_data["create_time"] if "create_time" in node_data else None
|
||||
|
||||
nodes.append(KnowledgeNode(id=node_id, type=node_type_val, content=content, create_time=create_time))
|
||||
@@ -368,11 +376,15 @@ async def search_knowledge_node(query: str = Query(..., min_length=1), _auth: bo
|
||||
try:
|
||||
node_data = graph[node_id]
|
||||
node_type = "entity" if ("type" in node_data and node_data["type"] == "ent") else "paragraph"
|
||||
|
||||
|
||||
# 对于段落节点,尝试从 embedding store 获取完整内容
|
||||
if node_type == "paragraph":
|
||||
full_content, _ = _get_paragraph_content(node_id)
|
||||
content = full_content if full_content is not None else (node_data["content"] if "content" in node_data else node_id)
|
||||
content = (
|
||||
full_content
|
||||
if full_content is not None
|
||||
else (node_data["content"] if "content" in node_data else node_id)
|
||||
)
|
||||
else:
|
||||
content = node_data["content"] if "content" in node_data else node_id
|
||||
|
||||
|
||||
@@ -659,4 +659,4 @@ def get_git_mirror_service() -> GitMirrorService:
|
||||
global _git_mirror_service
|
||||
if _git_mirror_service is None:
|
||||
_git_mirror_service = GitMirrorService()
|
||||
return _git_mirror_service
|
||||
return _git_mirror_service
|
||||
|
||||
Reference in New Issue
Block a user