feat:加入MaiSaKa(MaiBaKa?(MaiZako?))
This commit is contained in:
21
src/MaiDiary/LICENSE
Normal file
21
src/MaiDiary/LICENSE
Normal file
@@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2026 SengokuCola
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
613
src/MaiDiary/cli.py
Normal file
613
src/MaiDiary/cli.py
Normal file
@@ -0,0 +1,613 @@
|
||||
"""
|
||||
MaiSaka - CLI 交互界面与对话引擎
|
||||
BufferCLI 整合主循环、对话引擎、子代理管理。
|
||||
"""
|
||||
|
||||
import os
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from rich.panel import Panel
|
||||
from rich.markdown import Markdown
|
||||
from rich.text import Text
|
||||
from rich import box
|
||||
|
||||
from config import console, ENABLE_EMOTION_MODULE, ENABLE_COGNITION_MODULE, ENABLE_TIMING_MODULE, ENABLE_KNOWLEDGE_MODULE, ENABLE_MCP
|
||||
from input_reader import InputReader
|
||||
from debug_client import DebugViewer
|
||||
from timing import build_timing_info
|
||||
from knowledge import store_knowledge_from_context, retrieve_relevant_knowledge, build_knowledge_summary
|
||||
from knowledge_store import get_knowledge_store
|
||||
from llm_service import BaseLLMService, OpenAILLMService
|
||||
from llm_service.utils import build_message, remove_last_perception
|
||||
from mcp_client import MCPManager
|
||||
from tool_handlers import (
|
||||
ToolHandlerContext,
|
||||
handle_say,
|
||||
handle_stop,
|
||||
handle_wait,
|
||||
handle_write_file,
|
||||
handle_read_file,
|
||||
handle_list_files,
|
||||
handle_store_context,
|
||||
handle_mcp_tool,
|
||||
handle_unknown_tool,
|
||||
handle_get_qq_chat_info,
|
||||
handle_send_info,
|
||||
handle_list_qq_chats,
|
||||
)
|
||||
|
||||
|
||||
class BufferCLI:
|
||||
"""命令行交互界面"""
|
||||
|
||||
def __init__(self):
    """Set up CLI state, print knowledge-store stats, and initialise the LLM service."""
    self.llm_service: Optional[BaseLLMService] = None
    self._reader = InputReader()
    self._chat_history: Optional[list] = None  # persistent conversation history (None until first chat)
    self._knowledge_store = get_knowledge_store()  # shared knowledge-store instance

    # Show knowledge-store statistics at startup.
    knowledge_stats = self._knowledge_store.get_stats()
    if knowledge_stats["total_items"] > 0:
        console.print(f"[success]✓ 了解系统: {knowledge_stats['total_items']}条特征信息[/success]")
    else:
        console.print("[muted]✓ 了解系统: 已初始化 (暂无数据)[/muted]")
    # Timestamps tracked for the timing module.
    self._chat_start_time: Optional[datetime] = None
    self._last_user_input_time: Optional[datetime] = None
    self._last_assistant_response_time: Optional[datetime] = None
    self._user_input_times: list[datetime] = []  # every user-input timestamp, in order
    # MCP manager (initialised asynchronously in run()).
    self._mcp_manager: Optional[MCPManager] = None
    # Debug viewer (child process + socket wrapper).
    self._debug_viewer = DebugViewer()
    self._init_llm()
|
||||
|
||||
def _init_llm(self):
    """Initialise the LLM service from environment variables.

    Reads OPENAI_API_KEY / OPENAI_BASE_URL / OPENAI_MODEL / ENABLE_THINKING.
    When no API key is configured, prints a configuration hint and leaves
    ``self.llm_service`` as None (callers must check for that).
    """
    api_key = os.getenv("OPENAI_API_KEY", "")
    base_url = os.getenv("OPENAI_BASE_URL", "")
    model = os.getenv("OPENAI_MODEL", "gpt-4o")
    thinking_env = os.getenv("ENABLE_THINKING", "").strip().lower()
    # Tri-state: True/False when set explicitly, None to omit the parameter entirely.
    enable_thinking: Optional[bool] = (
        True if thinking_env == "true"
        else False if thinking_env == "false"
        else None
    )

    if not api_key:
        console.print(
            Panel(
                "[warning]未检测到 OPENAI_API_KEY 环境变量![/warning]\n\n"
                "请设置以下环境变量(或在 .env 文件中配置):\n"
                " • OPENAI_API_KEY - 必填,API 密钥\n"
                " • OPENAI_BASE_URL - 可选,API 基地址\n"
                " • OPENAI_MODEL - 可选,模型名称(默认 gpt-4o)\n\n"
                "[muted]程序无法运行,请配置后重试。[/muted]",
                title="⚠️ 配置提示",
                border_style="yellow",
            )
        )
        return

    self.llm_service = OpenAILLMService(
        api_key=api_key,
        base_url=base_url if base_url else None,
        model=model,
        enable_thinking=enable_thinking,
    )
    # Forward every LLM request/response to the debug viewer window.
    self.llm_service.set_debug_callback(self._debug_viewer.send)
    console.print(f"[success]✓ LLM 服务已初始化[/success] [muted](模型: {model})[/muted]")
|
||||
|
||||
def _build_tool_context(self) -> ToolHandlerContext:
    """Assemble the shared context object passed to every tool handler."""
    handler_ctx = ToolHandlerContext(
        llm_service=self.llm_service,
        reader=self._reader,
        user_input_times=self._user_input_times,
    )
    # Seed the handler with the most recent user-input timestamp so the
    # wait() handler can detect fresh input.
    handler_ctx.last_user_input_time = self._last_user_input_time
    return handler_ctx
|
||||
|
||||
# ──────── 显示方法 ────────
|
||||
|
||||
def _show_banner(self):
    """Print the startup banner inside a double-edged panel."""
    banner = Text()
    for segment, segment_style in (
        ("MaiSaka", "bold cyan"),
        (" v2.0\n", "muted"),
        ("直接输入文字开始对话 | Ctrl+C 退出", "muted"),
    ):
        banner.append(segment, style=segment_style)

    console.print(Panel(banner, box=box.DOUBLE_EDGE, border_style="cyan", padding=(1, 2)))
    console.print()
|
||||
|
||||
# ──────── 上下文管理 ────────
|
||||
|
||||
def _get_safe_removal_indices(self, chat_history: list, count: int) -> list[int]:
|
||||
"""
|
||||
获取可以安全删除的消息索引。
|
||||
|
||||
确保 tool_calls 和 tool 响应消息成对删除,避免破坏 API 要求的配对关系。
|
||||
只删除完整的消息块(user/assistant + 可选的 tool 响应序列)。
|
||||
|
||||
保留最后 3 条非 tool 消息,避免删除可能还在处理中的内容。
|
||||
|
||||
Returns:
|
||||
可以安全删除的消息索引列表(从后往前排序)
|
||||
"""
|
||||
indices_to_remove = []
|
||||
removed_count = 0
|
||||
i = 0
|
||||
|
||||
# 计算保留的消息数量(最后 3 条非 tool 消息)
|
||||
safe_zone_count = 3
|
||||
non_tool_count = 0
|
||||
for msg in reversed(chat_history):
|
||||
if msg.get("role") != "tool":
|
||||
non_tool_count += 1
|
||||
if non_tool_count >= safe_zone_count:
|
||||
break
|
||||
|
||||
# 只处理前 (len - non_tool_count) 条消息
|
||||
max_process_index = len(chat_history) - non_tool_count
|
||||
|
||||
while i < max_process_index and removed_count < count:
|
||||
msg = chat_history[i]
|
||||
role = msg.get("role", "")
|
||||
|
||||
# 跳过 role=tool 的消息(它们会被对应的 assistant 消息一起处理)
|
||||
if role == "tool":
|
||||
i += 1
|
||||
continue
|
||||
|
||||
# 检查这是否是一个带 tool_calls 的 assistant 消息
|
||||
if role == "assistant" and "tool_calls" in msg:
|
||||
# 收集这个 assistant 消息及其后续的 tool 响应消息
|
||||
block_indices = [i]
|
||||
j = i + 1
|
||||
while j < len(chat_history):
|
||||
next_msg = chat_history[j]
|
||||
if next_msg.get("role") == "tool":
|
||||
block_indices.append(j)
|
||||
j += 1
|
||||
else:
|
||||
break
|
||||
indices_to_remove.extend(block_indices)
|
||||
removed_count += 1
|
||||
i = j
|
||||
elif role in ["user", "assistant"]:
|
||||
# 普通消息,可以直接删除
|
||||
indices_to_remove.append(i)
|
||||
removed_count += 1
|
||||
i += 1
|
||||
else:
|
||||
i += 1
|
||||
|
||||
# 从后往前排序,避免索引问题
|
||||
return sorted(indices_to_remove, reverse=True)
|
||||
|
||||
async def _manage_context_length(self, chat_history: list) -> None:
    """
    Context management: compress the history in place when it grows too long.

    Once 20 non-tool messages accumulate:
    1. Pick up to 10 of the earliest messages that are safe to remove.
    2. Summarise them with the LLM and (if enabled) mine knowledge items.
    3. Delete them from the history and drop any orphaned tool messages.
    """
    CONTEXT_LIMIT = 20    # non-tool message count that triggers compression
    COMPRESS_COUNT = 10   # how many message blocks to fold away per pass

    # Count only real messages (exclude role="tool" tool-result entries).
    actual_messages = [m for m in chat_history if m.get("role") != "tool"]

    if len(actual_messages) >= CONTEXT_LIMIT:
        # Indices that can be removed without breaking tool_calls pairing.
        indices_to_remove = self._get_safe_removal_indices(chat_history, COMPRESS_COUNT)

        if indices_to_remove:
            # Collect the messages to summarise BEFORE deleting anything.
            to_compress = []
            for i in sorted(indices_to_remove):
                if 0 <= i < len(chat_history):
                    to_compress.append(chat_history[i])

            if to_compress:
                # Summarise the removed context via the LLM.
                try:
                    console.print("[accent]🧠 上下文过长,正在压缩并存入记忆...[/accent]")
                    summary = await self.llm_service.summarize_context(to_compress)

                    # Store user-trait knowledge extracted from the removed context.
                    if ENABLE_KNOWLEDGE_MODULE:
                        try:
                            knowledge_count = await store_knowledge_from_context(
                                self.llm_service,
                                to_compress,
                                store_result_callback=lambda cat_id, cat_name, content: console.print(
                                    f"[muted] ✓ 存储了解信息: {cat_name}[/muted]"
                                )
                            )
                            if knowledge_count > 0:
                                console.print(f"[success]✓ 了解模块: 存储{knowledge_count}条特征信息[/success]")
                        except Exception as e:
                            # Knowledge storage is best-effort; never abort compression.
                            console.print(f"[warning]了解存储失败: {e}[/warning]")
                    if summary:
                        # Show the compression result to the user.
                        console.print(
                            Panel(
                                Markdown(summary),
                                title="📝 上下文已压缩",
                                border_style="green",
                                padding=(0, 1),
                                style="dim",
                            )
                        )
                except Exception as e:
                    console.print(f"[warning]上下文总结失败: {e}[/warning]")

            # Delete from the back so earlier indices stay valid.
            for i in indices_to_remove:
                if 0 <= i < len(chat_history):
                    chat_history.pop(i)

        # Clean up "orphaned" tool messages (tool results whose matching
        # assistant tool_calls entry no longer exists).
        valid_tool_call_ids = set()
        for msg in chat_history:
            if msg.get("role") == "assistant" and "tool_calls" in msg:
                for tool_call in msg["tool_calls"]:
                    valid_tool_call_ids.add(tool_call.get("id", ""))

        # Drop invalid tool messages, scanning back-to-front.
        i = len(chat_history) - 1
        while i >= 0:
            msg = chat_history[i]
            if msg.get("role") == "tool":
                tool_call_id = msg.get("tool_call_id", "")
                if tool_call_id not in valid_tool_call_ids:
                    chat_history.pop(i)
            i -= 1
|
||||
|
||||
# ──────── LLM 循环架构 ────────
|
||||
|
||||
async def _start_chat(self, user_text: str):
    """Record the user's input, then start (or resume) the LLM loop."""
    if not self.llm_service:
        console.print("[warning]LLM 服务未初始化,跳过对话。[/warning]")
        return

    timestamp = datetime.now()
    self._last_user_input_time = timestamp
    self._user_input_times.append(timestamp)

    if self._chat_history is None:
        # First message: build a fresh context and reset timing state.
        self._chat_start_time = timestamp
        self._last_assistant_response_time = None
        self._chat_history = self.llm_service.build_chat_context(user_text)
    else:
        # Follow-up message: append to the existing conversation.
        self._chat_history.append({"role": "user", "content": user_text})

    await self._run_llm_loop(self._chat_history)
|
||||
|
||||
async def _run_llm_loop(self, chat_history: list):
    """
    Core of the LLM loop architecture.

    The LLM runs continuously; each step may emit text (inner monologue)
    and/or call tools:
    - say(text): speak to the user
    - wait(seconds): pause for user input; resume on input or timeout
    - stop(): end the loop and idle until the next user input
    - no tool call: go straight into another think/generate round

    Per-round flow:
    1. Context management: auto-compress when the history hits the limit.
    2. Emotion + timing + knowledge modules (run in parallel): analyse the
       user's mood, conversational pacing, and retrieve user traits.
       (Skipped when the previous round made no tool calls.)
    3. Call the main LLM to generate a response from the full context.
    """
    consecutive_errors = 0
    last_had_tool_calls = True  # always run module analysis on the first round

    while True:
        # ── Context management ──
        await self._manage_context_length(chat_history)

        # ── Emotion + timing + knowledge modules (parallel) ──
        # Re-analyse only when the previous round called a tool.
        if last_had_tool_calls:
            timing_info = build_timing_info(
                self._chat_start_time,
                self._last_user_input_time,
                self._last_assistant_response_time,
                self._user_input_times,
            )

            # Decide which modules to run based on configuration.
            tasks = []
            status_text_parts = []

            if ENABLE_EMOTION_MODULE:
                tasks.append(("eq", self.llm_service.analyze_emotion(chat_history)))
                status_text_parts.append("🎭")
            if ENABLE_COGNITION_MODULE:
                tasks.append(("cognition", self.llm_service.analyze_cognition(chat_history)))
                status_text_parts.append("🧩")
            if ENABLE_TIMING_MODULE:
                tasks.append(("timing", self.llm_service.analyze_timing(chat_history, timing_info)))
                status_text_parts.append("⏱️🪞")
            if ENABLE_KNOWLEDGE_MODULE:
                tasks.append(("knowledge", retrieve_relevant_knowledge(self.llm_service, chat_history)))
                status_text_parts.append("👤")

            # Fix: the icon list used to be rendered twice in the status line
            # (once space-joined, once " + "-joined); show it exactly once.
            with console.status(
                f"[info]{' + '.join(status_text_parts)} 模块并行分析中...[/info]",
                spinner="dots",
            ):
                results = await asyncio.gather(*[task for _, task in tasks], return_exceptions=True)

            # Map results back to their modules (same order as `tasks`).
            eq_result, cognition_result, timing_result, knowledge_result = None, None, None, None
            result_idx = 0
            if ENABLE_EMOTION_MODULE:
                eq_result = results[result_idx]
                result_idx += 1
            if ENABLE_COGNITION_MODULE:
                cognition_result = results[result_idx]
                result_idx += 1
            if ENABLE_TIMING_MODULE:
                timing_result = results[result_idx]
                result_idx += 1
            if ENABLE_KNOWLEDGE_MODULE:
                knowledge_result = results[result_idx]
                result_idx += 1

            # Emotion-module result.
            eq_analysis = ""
            if ENABLE_EMOTION_MODULE:
                if isinstance(eq_result, Exception):
                    console.print(f"[warning]情商模块分析失败: {eq_result}[/warning]")
                elif eq_result:
                    eq_analysis = eq_result
                    console.print(
                        Panel(
                            Markdown(eq_analysis),
                            title="🎭 情绪感知",
                            border_style="bright_yellow",
                            padding=(0, 1),
                            style="dim",
                        )
                    )

            # Cognition-module result.
            cognition_analysis = ""
            if ENABLE_COGNITION_MODULE:
                if isinstance(cognition_result, Exception):
                    console.print(f"[warning]认知模块分析失败: {cognition_result}[/warning]")
                elif cognition_result:
                    cognition_analysis = cognition_result
                    console.print(
                        Panel(
                            Markdown(cognition_analysis),
                            title="🧩 意图感知",
                            border_style="bright_cyan",
                            padding=(0, 1),
                            style="dim",
                        )
                    )

            # Timing-module result (includes self-reflection).
            timing_analysis = ""
            if ENABLE_TIMING_MODULE:
                if isinstance(timing_result, Exception):
                    console.print(f"[warning]Timing 模块分析失败: {timing_result}[/warning]")
                elif timing_result:
                    timing_analysis = timing_result
                    console.print(
                        Panel(
                            Markdown(timing_analysis),
                            title="⏱️🪞 时间感知 & 自我反思",
                            border_style="bright_blue",
                            padding=(0, 1),
                            style="dim",
                        )
                    )

            # Knowledge-module result.
            knowledge_analysis = ""
            if ENABLE_KNOWLEDGE_MODULE:
                if isinstance(knowledge_result, Exception):
                    console.print(f"[warning]了解模块分析失败: {knowledge_result}[/warning]")
                elif knowledge_result:
                    knowledge_analysis = knowledge_result
                    console.print(
                        Panel(
                            Markdown(knowledge_analysis),
                            title="👤 用户特征",
                            border_style="bright_magenta",
                            padding=(0, 1),
                            style="dim",
                        )
                    )

            # Inject the perception info as an assistant "perception" message,
            # replacing the previous perception message if one exists.
            remove_last_perception(chat_history)

            # Build the perception content.
            perception_parts = []
            if eq_analysis:
                perception_parts.append(f"情绪感知\n{eq_analysis}")
            if cognition_analysis:
                perception_parts.append(f"意图感知\n{cognition_analysis}")
            if timing_analysis:
                perception_parts.append(f"时间感知 & 自我反思\n{timing_analysis}")
            if knowledge_analysis:
                perception_parts.append(f"用户特征\n{knowledge_analysis}")

            if perception_parts:
                # Append the perception message (output of the AI's senses).
                chat_history.append(build_message(
                    role="assistant",
                    content="\n\n".join(perception_parts),
                    msg_type="perception",
                ))
        else:
            # The previous round made no tool calls: skip module analysis.
            console.print("[muted]ℹ️ 上次未调用工具,跳过模块分析[/muted]")

        # ── Call the main LLM ──
        with console.status("[info]💬 AI 正在思考...[/info]", spinner="dots"):
            try:
                response = await self.llm_service.chat_loop_step(chat_history)
                consecutive_errors = 0
            except Exception as e:
                consecutive_errors += 1
                console.print(f"[error]LLM 调用出错: {e}[/error]")
                if consecutive_errors >= 3:
                    console.print("[error]连续出错,退出对话[/error]\n")
                    break
                continue

        # Append the assistant message to the history.
        chat_history.append(response.raw_message)
        self._last_assistant_response_time = datetime.now()

        # Show the inner monologue (the content part), dimmed.
        if response.content:
            console.print(
                Panel(
                    Markdown(response.content),
                    title="💭 内心思考",
                    border_style="dim",
                    padding=(1, 2),
                    style="dim",
                )
            )

        # ── Handle tool calls ──
        if response.tool_calls:
            should_stop = False
            ctx = self._build_tool_context()

            for tc in response.tool_calls:
                if tc.name == "say":
                    await handle_say(tc, chat_history, ctx)

                elif tc.name == "stop":
                    await handle_stop(tc, chat_history)
                    should_stop = True

                elif tc.name == "wait":
                    tool_result = await handle_wait(tc, chat_history, ctx)
                    # Sync back the timing timestamp updated by the handler.
                    if ctx.last_user_input_time != self._last_user_input_time:
                        self._last_user_input_time = ctx.last_user_input_time
                    if tool_result.startswith("[[QUIT]]"):
                        should_stop = True

                elif tc.name == "write_file":
                    await handle_write_file(tc, chat_history)

                elif tc.name == "read_file":
                    await handle_read_file(tc, chat_history)

                elif tc.name == "list_files":
                    await handle_list_files(tc, chat_history)

                elif tc.name == "store_context":
                    await handle_store_context(tc, chat_history, ctx)

                elif tc.name == "get_qq_chat_info":
                    await handle_get_qq_chat_info(tc, chat_history)

                elif tc.name == "send_info":
                    await handle_send_info(tc, chat_history)

                elif tc.name == "list_qq_chats":
                    await handle_list_qq_chats(tc, chat_history)

                elif self._mcp_manager and self._mcp_manager.is_mcp_tool(tc.name):
                    await handle_mcp_tool(tc, chat_history, self._mcp_manager)

                else:
                    await handle_unknown_tool(tc, chat_history)

            if should_stop:
                console.print("[muted]对话暂停,等待新输入...[/muted]\n")
                break

            # Tools were called: re-run module analysis next round.
            last_had_tool_calls = True
        else:
            # No tool calls: go straight into another thinking round and
            # skip module analysis next time.
            last_had_tool_calls = False
|
||||
|
||||
# ──────── 主循环 ────────
|
||||
|
||||
async def _init_mcp(self):
    """Connect to configured MCP servers and register their tools with the LLM.

    Loads ``mcp_config.json`` from this file's directory; discovered tools
    are exposed to the LLM via ``set_extra_tools`` and summarised on screen.
    """
    config_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "mcp_config.json",
    )
    self._mcp_manager = await MCPManager.from_config(config_path)

    if self._mcp_manager and self.llm_service:
        mcp_tools = self._mcp_manager.get_openai_tools()
        if mcp_tools:
            self.llm_service.set_extra_tools(mcp_tools)
            summary = self._mcp_manager.get_tool_summary()
            console.print(
                Panel(
                    f"已加载 {len(mcp_tools)} 个 MCP 工具:\n{summary}",
                    title="🔌 MCP 工具",
                    border_style="green",
                    padding=(0, 1),
                )
            )
|
||||
|
||||
async def run(self):
    """Main loop: read user input lines and feed them to the chat engine.

    Starts the debug viewer, optionally connects MCP servers, then reads
    input until EOF / Ctrl+C. Resources are always released in ``finally``.
    """
    # Launch the external debug window.
    self._debug_viewer.start()

    # Initialise MCP servers only when enabled by configuration.
    if ENABLE_MCP:
        await self._init_mcp()
    else:
        console.print("[muted]🔌 MCP 已禁用 (ENABLE_MCP=false)[/muted]")

    # Start the asynchronous input reader on the active event loop.
    # Fix: asyncio.get_event_loop() is deprecated inside a coroutine;
    # get_running_loop() is the correct call here and never creates a loop.
    self._reader.start(asyncio.get_running_loop())

    self._show_banner()

    try:
        while True:
            console.print("[bold cyan]> [/bold cyan]", end="")
            raw_input = await self._reader.get_line()

            if raw_input is None:  # EOF
                console.print("\n[muted]再见![/muted]")
                break

            raw_input = raw_input.strip()
            if not raw_input:
                continue

            await self._start_chat(raw_input)
    finally:
        # Always shut down the viewer and MCP connections, even on Ctrl+C.
        self._debug_viewer.close()
        if self._mcp_manager:
            await self._mcp_manager.close()
|
||||
46
src/MaiDiary/config.py
Normal file
46
src/MaiDiary/config.py
Normal file
@@ -0,0 +1,46 @@
|
||||
"""
|
||||
MaiSaka - 全局配置
|
||||
环境变量加载、Rich Console 实例、主题定义。
|
||||
"""
|
||||
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
from rich.console import Console
|
||||
from rich.theme import Theme
|
||||
|
||||
# ──────────────────── 加载 .env ────────────────────
|
||||
|
||||
load_dotenv()
|
||||
|
||||
# ──────────────────── 模块开关配置 ────────────────────
|
||||
|
||||
ENABLE_EMOTION_MODULE = os.getenv("ENABLE_EMOTION_MODULE", "true").strip().lower() == "true"
|
||||
ENABLE_COGNITION_MODULE = os.getenv("ENABLE_COGNITION_MODULE", "true").strip().lower() == "true"
|
||||
# Timing 模块已包含自我反思功能
|
||||
ENABLE_TIMING_MODULE = os.getenv("ENABLE_TIMING_MODULE", "true").strip().lower() == "true"
|
||||
ENABLE_KNOWLEDGE_MODULE = os.getenv("ENABLE_KNOWLEDGE_MODULE", "true").strip().lower() == "true"
|
||||
ENABLE_MCP = os.getenv("ENABLE_MCP", "true").strip().lower() == "true"
|
||||
ENABLE_WRITE_FILE = os.getenv("ENABLE_WRITE_FILE", "true").strip().lower() == "true"
|
||||
ENABLE_READ_FILE = os.getenv("ENABLE_READ_FILE", "true").strip().lower() == "true"
|
||||
ENABLE_LIST_FILES = os.getenv("ENABLE_LIST_FILES", "true").strip().lower() == "true"
|
||||
|
||||
# ──────────────────── QQ 工具配置 ────────────────────
|
||||
|
||||
ENABLE_QQ_TOOLS = os.getenv("ENABLE_QQ_TOOLS", "false").strip().lower() == "true"
|
||||
QQ_API_BASE_URL = os.getenv("QQ_API_BASE_URL", "").strip()
|
||||
QQ_API_KEY = os.getenv("QQ_API_KEY", "").strip()
|
||||
|
||||
# ──────────────────── Rich 主题 & Console ────────────────────
|
||||
|
||||
custom_theme = Theme(
|
||||
{
|
||||
"info": "cyan",
|
||||
"success": "green",
|
||||
"warning": "yellow",
|
||||
"error": "bold red",
|
||||
"muted": "dim",
|
||||
"accent": "bold magenta",
|
||||
}
|
||||
)
|
||||
|
||||
console = Console(theme=custom_theme)
|
||||
95
src/MaiDiary/debug_client.py
Normal file
95
src/MaiDiary/debug_client.py
Normal file
@@ -0,0 +1,95 @@
|
||||
"""
|
||||
MaiSaka - Debug Viewer 客户端
|
||||
在独立命令行窗口中显示每次 LLM 调用的完整 Prompt。
|
||||
通过 TCP socket 将数据发送给 debug_viewer.py 子进程。
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import socket
|
||||
import struct
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from config import console
|
||||
|
||||
|
||||
class DebugViewer:
    """
    Shows the full prompt of every LLM call in a separate console window.

    Data is sent to the ``debug_viewer.py`` child process over TCP using a
    4-byte big-endian length prefix followed by a UTF-8 JSON payload.
    """

    def __init__(self, port: int = 19876):
        # TCP port the viewer child process listens on.
        self._port = port
        self._conn: Optional[socket.socket] = None
        self._process: Optional[subprocess.Popen] = None

    def start(self):
        """Spawn the viewer child process and establish the TCP connection.

        Failures are reported on the console but never raised — the viewer
        is strictly best-effort.
        """
        script_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "debug_viewer.py")

        try:
            self._process = subprocess.Popen(
                [sys.executable, script_path, str(self._port)],
                # CREATE_NEW_CONSOLE exists only on Windows; elsewhere the
                # getattr default of 0 is a no-op.
                creationflags=getattr(subprocess, "CREATE_NEW_CONSOLE", 0),
            )
        except Exception as e:
            console.print(f"[warning]⚠️ 无法启动调试窗口: {e}[/warning]")
            return

        # Retry while the child starts listening (20 × 0.3 s ≈ 6 s total).
        # NOTE(review): time.sleep blocks the calling thread — confirm this
        # is acceptable where start() is invoked.
        for attempt in range(20):
            try:
                time.sleep(0.3)
                conn = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
                conn.connect(("127.0.0.1", self._port))
                self._conn = conn
                console.print(
                    f"[success]✓ 调试窗口已启动[/success] [muted](port {self._port})[/muted]"
                )
                return
            except ConnectionRefusedError:
                conn.close()

        console.print("[warning]⚠️ 无法连接到调试窗口(超时)[/warning]")

    def send(self, label: str, messages: list, tools: Optional[list] = None, response: Optional[dict] = None):
        """Send one LLM call's full prompt and response to the viewer window."""
        if not self._conn:
            return

        # Only send when a response exists (avoids showing each call twice:
        # once in-flight and once completed).
        if response is None:
            return

        payload = {"label": label, "messages": messages}
        if tools:
            payload["tools"] = tools
        payload["response"] = response

        try:
            data = json.dumps(payload, ensure_ascii=False).encode("utf-8")
            # 4-byte big-endian length prefix framing.
            header = struct.pack(">I", len(data))
            self._conn.sendall(header + data)
        except Exception:
            # Drop the connection silently on any send failure.
            self._conn = None

    def close(self):
        """Close the socket and terminate the child process (best-effort)."""
        if self._conn:
            try:
                self._conn.close()
            except Exception:
                pass
            self._conn = None
        if self._process:
            try:
                self._process.terminate()
            except Exception:
                pass
            self._process = None
|
||||
194
src/MaiDiary/debug_viewer.py
Normal file
194
src/MaiDiary/debug_viewer.py
Normal file
@@ -0,0 +1,194 @@
|
||||
"""
|
||||
MaiSaka Debug Viewer — 在独立命令行窗口中显示每次 LLM 调用的完整 Prompt。
|
||||
|
||||
由主进程自动启动,通过 TCP socket 接收数据。
|
||||
"""
|
||||
|
||||
import socket
|
||||
import struct
|
||||
import json
|
||||
import sys
|
||||
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich import box
|
||||
|
||||
console = Console()
|
||||
|
||||
ROLE_STYLES = {
|
||||
"system": ("📋", "bold blue"),
|
||||
"user": ("👤", "bold green"),
|
||||
"assistant": ("🤖", "bold magenta"),
|
||||
"tool": ("🔧", "bold yellow"),
|
||||
}
|
||||
|
||||
|
||||
def recv_exact(conn: socket.socket, n: int) -> bytes | None:
|
||||
"""精确接收 n 字节数据。"""
|
||||
data = b""
|
||||
while len(data) < n:
|
||||
chunk = conn.recv(n - len(data))
|
||||
if not chunk:
|
||||
return None
|
||||
data += chunk
|
||||
return data
|
||||
|
||||
|
||||
def format_message(idx: int, msg: dict) -> str:
    """Render a single chat message as Rich markup for the terminal."""
    try:
        raw_role = msg.get("role")
        role = str(raw_role) if raw_role else "?"
        raw_content = msg.get("content")
        content = str(raw_content) if raw_content else ""
        tool_calls = msg.get("tool_calls", []) or []
        raw_id = msg.get("tool_call_id")
        tool_call_id = str(raw_id) if raw_id else ""

        icon, style = ROLE_STYLES.get(role, ("❓", "white"))

        # Header line: icon, index, role, and optional tool_call_id.
        header = f"[{style}]{icon} [{idx}] {role}[/{style}]"
        if tool_call_id:
            header += f" [dim](tool_call_id: {tool_call_id})[/dim]"
        lines: list[str] = [header]

        # Body, truncated for very long content.
        if content:
            if len(content) > 3000:
                shown = content[:3000] + f"\n[dim]... (截断, 共 {len(content)} 字符)[/dim]"
            else:
                shown = content
            lines.append(shown)

        # One line per well-formed tool call.
        if isinstance(tool_calls, list):
            for call in tool_calls:
                if not isinstance(call, dict):
                    continue
                fn = call.get("function", {})
                if not isinstance(fn, dict):
                    continue
                fn_name = fn.get("name", "?")
                fn_args = fn.get("arguments", "")
                if isinstance(fn_args, str) and len(fn_args) > 500:
                    fn_args = fn_args[:500] + "..."
                lines.append(f"  [yellow]→ tool_call: {fn_name}({fn_args})[/yellow]")

        return "\n".join(lines)
    except Exception:
        return f"[red]消息 [{idx}] 格式化错误[/red]"
|
||||
|
||||
|
||||
def main():
    """Accept one connection from the main process and render framed JSON payloads.

    Each frame is a 4-byte big-endian length prefix followed by a UTF-8 JSON
    body containing label / messages / tools / response. Runs until the
    connection drops, then waits for Enter so the window stays readable.
    """
    port = int(sys.argv[1]) if len(sys.argv) > 1 else 19876

    server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    server.bind(("127.0.0.1", port))
    server.listen(1)

    console.print(
        Panel(
            f"[bold cyan]MaiSaka Debug Viewer[/bold cyan]\n"
            f"[dim]监听端口: {port} 等待主进程连接...[/dim]",
            box=box.DOUBLE_EDGE,
            border_style="cyan",
        )
    )

    conn, _ = server.accept()
    console.print("[green]✓ 已连接到主进程[/green]\n")

    call_count = 0
    try:
        while True:
            # Read the 4-byte big-endian length prefix.
            length_bytes = recv_exact(conn, 4)
            if not length_bytes:
                break

            length = struct.unpack(">I", length_bytes)[0]

            # Read the JSON payload of exactly `length` bytes.
            payload_bytes = recv_exact(conn, length)
            if not payload_bytes:
                break

            call_count += 1

            try:
                payload = json.loads(payload_bytes.decode("utf-8"))
            except json.JSONDecodeError as e:
                console.print(f"\n[red]JSON 解析错误: {e}[/red]")
                console.print(f"[dim]原始数据: {payload_bytes[:200]}...[/dim]")
                continue

            try:
                label = payload.get("label", "LLM Call")
                messages = payload.get("messages", [])
                tools = payload.get("tools")
                response = payload.get("response")

                # ── Title bar ──
                console.print(f"\n{'═' * 90}")
                console.print(
                    f"[bold yellow]#{call_count} {label}[/bold yellow] "
                    f"[dim]({len(messages)} messages)[/dim]"
                )
                console.print(f"{'═' * 90}")

                # ── Each message, separated by a dotted rule ──
                for i, msg in enumerate(messages):
                    console.print(format_message(i, msg))
                    if i < len(messages) - 1:
                        console.print("[dim]─ ─ ─[/dim]")

                # ── Available tools summary ──
                if tools:
                    tool_names = [
                        t.get("function", {}).get("name", "?") for t in tools
                    ]
                    console.print(
                        f"\n[dim]可用工具: {', '.join(tool_names)}[/dim]"
                    )
            except Exception as e:
                console.print(f"\n[red]数据处理错误: {e}[/red]")
                console.print(f"[dim]Payload: {payload}[/dim]")
                continue

            # ── Response section ──
            if response:
                try:
                    console.print("\n[bold cyan]📤 LLM 响应:[/bold cyan]")
                    resp_content = response.get("content", "")
                    if resp_content:
                        # Truncate very long responses for display.
                        display = resp_content if len(str(resp_content)) <= 3000 else (
                            str(resp_content)[:3000] + f"\n[dim]... (截断, 共 {len(str(resp_content))} 字符)[/dim]"
                        )
                        console.print(Panel(display, border_style="cyan", padding=(0, 1)))
                    resp_tool_calls = response.get("tool_calls", [])
                    if resp_tool_calls:
                        for tc in resp_tool_calls:
                            func = tc.get("function", {})
                            name = func.get("name", "?")
                            args = func.get("arguments", "")
                            if isinstance(args, str) and len(args) > 300:
                                args = args[:300] + "..."
                            console.print(f"  [cyan]→ tool_call: {name}({args})[/cyan]")
                except Exception as e:
                    console.print(f"\n[red]响应解析错误: {e}[/red]")
                    console.print(f"[dim]原始数据: {response}[/dim]")

            console.print(f"[dim]{'─' * 90}[/dim]")

    except (ConnectionResetError, ConnectionAbortedError):
        # Peer vanished; fall through to cleanup.
        pass
    finally:
        conn.close()
        server.close()

    console.print("\n[red]连接已断开[/red]")
    input("按 Enter 关闭窗口...")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
54
src/MaiDiary/emotion.py
Normal file
54
src/MaiDiary/emotion.py
Normal file
@@ -0,0 +1,54 @@
|
||||
"""
|
||||
MaiSaka - Emotion 模块
|
||||
情绪感知分析,分析用户的情绪状态和言语态度。
|
||||
|
||||
注意:EQ_SYSTEM_PROMPT 已迁移至 prompts/emotion.system.prompt
|
||||
使用 prompt_loader.load_prompt("emotion.system") 加载。
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
|
||||
def extract_user_messages(chat_history: List[dict], limit: Optional[int] = None) -> List[dict]:
    """
    Extract user messages from the chat history.

    Args:
        chat_history: Full conversation history.
        limit: Maximum number of user messages to keep (the most recent ones
            are retained); ``None`` means no limit.

    Returns:
        A list containing only the user messages. ``limit=0`` now returns an
        empty list (previously ``0`` was treated as falsy, so the whole list
        was returned, and ``lst[-0:]`` would also have yielded everything).
    """
    user_messages = [msg for msg in chat_history if msg.get("role") == "user"]
    if limit is None:
        return user_messages
    # Guard limit <= 0 explicitly: lst[-0:] is the entire list, not empty.
    return user_messages[-limit:] if limit > 0 else []
|
||||
|
||||
|
||||
def build_emotion_context(chat_history: List[dict]) -> str:
    """
    Build the conversation-context text used for emotion analysis.

    Args:
        chat_history: full conversation history

    Returns:
        Formatted conversation context text (one "role: content" line per message).
    """
    # Keep only the most recent conversation (about 8-10 messages).
    recent_messages = chat_history[-10:] if len(chat_history) > 10 else chat_history

    context_parts = []
    for msg in recent_messages:
        role = msg.get("role", "")
        content = msg.get("content", "")

        if role == "user":
            context_parts.append(f"用户: {content}")
        elif role == "assistant":
            # Only include the assistant's actual utterances; skip perception
            # blocks. Guard against content=None (e.g. tool-call-only messages),
            # which previously raised TypeError on the `in` test.
            if content and "【AI 感知】" not in content:
                context_parts.append(f"助手: {content}")

    return "\n".join(context_parts)
|
||||
58
src/MaiDiary/env.example
Normal file
58
src/MaiDiary/env.example
Normal file
@@ -0,0 +1,58 @@
|
||||
# MaiSaka - LLM API 配置
|
||||
# 复制本文件为 .env 并填入你的配置
|
||||
|
||||
# 必填: API 密钥
|
||||
OPENAI_API_KEY=sk-xxxxxxxxxxxxxxxx
|
||||
|
||||
# 可选: API 基地址 (如使用第三方兼容接口或自建代理)
|
||||
OPENAI_BASE_URL=https://api.openai.com/v1
|
||||
|
||||
# 可选: 模型名称 (默认 gpt-4o, 需支持视觉能力以处理图片)
|
||||
OPENAI_MODEL=gpt-4o
|
||||
|
||||
# 可选: 是否启用 LLM 思考模式 (true/false, 不设置则不发送该参数)
|
||||
# 设为 true 时允许 LLM 先进行思考再输出,设为 false 时直接输出
|
||||
ENABLE_THINKING=true
|
||||
|
||||
# 可选: 是否启用情绪猜测模块 (true/false, 默认 true)
|
||||
# 设为 false 时禁用情绪分析,可节省 API 调用成本
|
||||
ENABLE_EMOTION_MODULE=true
|
||||
|
||||
# 可选: 是否启用认知感知模块 (true/false, 默认 true)
|
||||
# 设为 false 时禁用意图分析,可节省 API 调用成本
|
||||
ENABLE_COGNITION_MODULE=true
|
||||
|
||||
# 可选: 是否启用记忆模块 (true/false, 默认 true)
|
||||
# 设为 false 时禁用记忆存储和检索功能
|
||||
# 注意: 关闭记忆模块不会影响了解(Knowledge)模块
|
||||
ENABLE_MEMORY_MODULE=true
|
||||
|
||||
# 可选: 是否启用 Timing 模块 (true/false, 默认 true)
|
||||
# 设为 false 时禁用时间节奏分析,可节省 API 调用成本
|
||||
# 注意: Timing 模块已包含自我反思功能
|
||||
ENABLE_TIMING_MODULE=true
|
||||
|
||||
# 可选: 是否启用文件写入工具 (true/false, 代码内默认值为 true)
# 本示例显式设为 false 以禁用 write_file 工具
ENABLE_WRITE_FILE=false

# 可选: 是否启用文件读取工具 (true/false, 代码内默认值为 true)
# 本示例显式设为 false 以禁用 read_file 工具
ENABLE_READ_FILE=false

# 可选: 是否启用文件列表工具 (true/false, 代码内默认值为 true)
# 本示例显式设为 false 以禁用 list_files 工具
ENABLE_LIST_FILES=false
|
||||
|
||||
# 可选: 是否启用 QQ 工具 (true/false, 默认 false)
|
||||
# 设为 true 时启用 get_qq_chat_info、send_info、list_qq_chats 工具
|
||||
ENABLE_QQ_TOOLS=false
|
||||
|
||||
# 可选: QQ API 基地址 (启用 QQ_TOOLS 时必填)
|
||||
# 指向提供 QQ 聊天功能的 HTTP 服务端点
|
||||
# 示例: http://localhost:8017
|
||||
QQ_API_BASE_URL=http://localhost:8017
|
||||
|
||||
# 可选: QQ API 密钥 (如果服务需要认证)
|
||||
# 留空则不发送认证头
|
||||
QQ_API_KEY=your-api-key
|
||||
62
src/MaiDiary/input_reader.py
Normal file
62
src/MaiDiary/input_reader.py
Normal file
@@ -0,0 +1,62 @@
|
||||
"""
|
||||
MaiSaka - 异步输入读取器
|
||||
基于后台线程的异步标准输入读取,通过 asyncio.Queue 传递给异步代码。
|
||||
"""
|
||||
|
||||
import sys
|
||||
import asyncio
|
||||
import threading
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class InputReader:
|
||||
"""
|
||||
基于后台线程的异步标准输入读取器。
|
||||
|
||||
使用单一守护线程持续读取 stdin,通过 asyncio.Queue 传递给异步代码。
|
||||
保证整个应用只有一个线程读 stdin,避免多线程竞争。
|
||||
支持带超时的读取,用于 LLM wait 工具。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._queue: asyncio.Queue = asyncio.Queue()
|
||||
self._loop: Optional[asyncio.AbstractEventLoop] = None
|
||||
self._thread: Optional[threading.Thread] = None
|
||||
|
||||
def start(self, loop: asyncio.AbstractEventLoop):
|
||||
"""启动后台读取线程(仅首次调用生效)"""
|
||||
if self._thread is not None:
|
||||
return
|
||||
self._loop = loop
|
||||
self._thread = threading.Thread(target=self._read_loop, daemon=True)
|
||||
self._thread.start()
|
||||
|
||||
def _read_loop(self):
|
||||
"""后台线程:持续从 stdin 读取行"""
|
||||
try:
|
||||
while True:
|
||||
line = sys.stdin.readline()
|
||||
if not line: # EOF
|
||||
self._loop.call_soon_threadsafe(self._queue.put_nowait, None)
|
||||
break
|
||||
stripped = line.rstrip("\n").rstrip("\r")
|
||||
self._loop.call_soon_threadsafe(self._queue.put_nowait, stripped)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def get_line(self, timeout: Optional[float] = None) -> Optional[str]:
|
||||
"""
|
||||
异步获取下一行输入。
|
||||
|
||||
Args:
|
||||
timeout: 超时秒数,None 表示无限等待
|
||||
|
||||
Returns:
|
||||
输入的字符串,超时或 EOF 返回 None
|
||||
"""
|
||||
try:
|
||||
if timeout is not None:
|
||||
return await asyncio.wait_for(self._queue.get(), timeout=timeout)
|
||||
return await self._queue.get()
|
||||
except asyncio.TimeoutError:
|
||||
return None
|
||||
181
src/MaiDiary/knowledge.py
Normal file
181
src/MaiDiary/knowledge.py
Normal file
@@ -0,0 +1,181 @@
|
||||
"""
|
||||
MaiSaka - 了解模块
|
||||
负责从对话中提取和存储用户个人特征信息。
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
from knowledge_store import get_knowledge_store, KNOWLEDGE_CATEGORIES
|
||||
|
||||
|
||||
def build_knowledge_summary() -> str:
    """
    Build the knowledge-category summary used in LLM requests.

    Returns:
        Formatted category-list text.
    """
    return get_knowledge_store().get_categories_summary()
|
||||
|
||||
|
||||
def extract_category_ids_from_result(result: str) -> List[str]:
    """
    Parse knowledge-category ids out of an LLM response.

    Args:
        result: raw text returned by the LLM

    Returns:
        List of category ids, in first-appearance order, without duplicates.
    """
    if not result:
        return []

    # Treat explicit "nothing relevant" style answers as empty.
    if any(keyword in result for keyword in ["无", "没有", "不适用", "无需", "无相关"]):
        return []

    # Accept ASCII/fullwidth commas, whitespace and newlines as separators.
    # Deduplicate while preserving order: the LLM may repeat an id, and a
    # duplicate would previously trigger a second extraction for the same
    # category downstream.
    category_ids: List[str] = []
    seen = set()
    for token in result.replace(",", " ").replace(",", " ").replace("\n", " ").split():
        token = token.strip()
        if token in KNOWLEDGE_CATEGORIES and token not in seen:
            seen.add(token)
            category_ids.append(token)

    return category_ids
|
||||
|
||||
|
||||
def format_context_for_memory(context_messages: List[dict]) -> str:
    """
    Format context messages as plain text for memory analysis.

    Args:
        context_messages: list of chat messages

    Returns:
        Formatted text, one "role: content" line per kept message.
    """
    parts = []
    for msg in context_messages:
        role = msg.get("role", "")
        content = msg.get("content", "")

        if role == "user":
            parts.append(f"用户: {content}")
        elif role == "assistant":
            # Skip perception blocks; also guard content=None (tool-call-only
            # messages), which previously raised TypeError on the `in` test.
            if content and "【AI 感知】" not in content:
                parts.append(f"助手: {content}")

    return "\n".join(parts)
|
||||
|
||||
|
||||
async def store_knowledge_from_context(
    llm_service,
    context_messages: List[dict],
    store_result_callback=None,
) -> int:
    """
    Memorization half of the knowledge module: extract and persist
    user-trait information from a slice of conversation context.

    Triggered when the context window is trimmed:
    1. Ask the LLM which trait categories the chat touches on.
    2. For each category, run a sub-agent extraction of the relevant content.
    3. Store the result in the knowledge list.

    Args:
        llm_service: LLM service instance
        context_messages: context messages to analyze
        store_result_callback: optional callback invoked as
            callback(category_id, category_name, extracted_content)
            after each successful store

    Returns:
        Number of knowledge entries successfully stored (0 on any
        top-level failure — this routine is deliberately best-effort).
    """
    store = get_knowledge_store()
    context_text = format_context_for_memory(context_messages)
    categories_summary = build_knowledge_summary()

    # Nothing usable in the context — bail out early.
    if not context_text:
        return 0

    try:
        # Step 1: determine which categories the conversation involves.
        category_ids = await llm_service.analyze_knowledge_categories(
            context_messages, categories_summary
        )

        if not category_ids:
            return 0

        # Step 2: extract and store content per category.
        stored_count = 0
        for category_id in category_ids:
            try:
                # Extract the content relevant to this category.
                extracted_content = await llm_service.extract_knowledge_for_category(
                    context_messages, category_id, store.get_category_name(category_id)
                )

                if extracted_content:
                    # Persist into the knowledge list.
                    success = store.add_knowledge(
                        category_id=category_id,
                        content=extracted_content,
                        metadata={"source": "context_compression"}
                    )
                    if success:
                        stored_count += 1
                        if store_result_callback:
                            store_result_callback(
                                category_id,
                                store.get_category_name(category_id),
                                extracted_content
                            )
            except Exception as e:
                # One failing category must not abort the others.
                continue

        return stored_count

    except Exception as e:
        # Best-effort: swallow analysis failures and report nothing stored.
        return 0
|
||||
|
||||
|
||||
async def retrieve_relevant_knowledge(
    llm_service,
    chat_history: List[dict],
) -> str:
    """
    Retrieval half of the knowledge module: fetch stored user-trait
    information relevant to the current conversation.

    Runs before each reply (alongside the EQ and timing modules):
    1. Ask the LLM which knowledge categories are needed.
    2. Collect all entries of those categories.
    3. Return them as formatted text.

    Args:
        llm_service: LLM service instance
        chat_history: current conversation history

    Returns:
        Formatted knowledge text, or "" when nothing is needed or on failure.
    """
    store = get_knowledge_store()
    categories_summary = store.get_categories_summary()

    try:
        # Which categories does the current conversation call for?
        needed_ids = await llm_service.analyze_knowledge_need(
            chat_history, categories_summary
        )
        if not needed_ids:
            return ""
        # Pull and format everything stored under those categories.
        return store.get_formatted_knowledge(needed_ids)
    except Exception:
        # Retrieval is best-effort; fail silently.
        return ""
|
||||
196
src/MaiDiary/knowledge_store.py
Normal file
196
src/MaiDiary/knowledge_store.py
Normal file
@@ -0,0 +1,196 @@
|
||||
"""
|
||||
MaiSaka - 了解列表持久化存储
|
||||
存储用户个人特征信息,支持层级结构和本地持久化。
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Any
|
||||
from datetime import datetime
|
||||
|
||||
# Data directory: mai_knowledge under this module's directory.
# Path(__file__).resolve().parent is the idiomatic pathlib equivalent of the
# previous Path(os.path.dirname(os.path.abspath(__file__))).
PROJECT_ROOT = Path(__file__).resolve().parent
KNOWLEDGE_DATA_DIR = PROJECT_ROOT / "mai_knowledge"
KNOWLEDGE_FILE = KNOWLEDGE_DATA_DIR / "knowledge.json"


# Predefined personal-trait categories (id -> Chinese display name).
KNOWLEDGE_CATEGORIES = {
    "1": "性别",
    "2": "性格",
    "3": "饮食口味",
    "4": "交友喜好",
    "5": "情绪/理性倾向",
    "6": "兴趣爱好",
    "7": "职业/专业",
    "8": "生活习惯",
    "9": "价值观",
    "10": "沟通风格",
    "11": "学习方式",
    "12": "压力应对方式",
}
|
||||
|
||||
|
||||
class KnowledgeStore:
    """
    Persistent store for the "knowledge" (user-trait) list.

    Features:
    - Persists to a JSON file (KNOWLEDGE_FILE)
    - Hierarchical storage keyed by category id
    - Supports incremental additions
    - Loads automatically on construction
    """

    def __init__(self):
        """Initialize the store: empty per-category lists, then load from disk."""
        # Mapping: category id -> list of knowledge items.
        self._knowledge: Dict[str, List[Dict[str, Any]]] = {
            category_id: [] for category_id in KNOWLEDGE_CATEGORIES
        }
        self._ensure_data_dir()
        self._load()

    def _ensure_data_dir(self):
        """Create the data directory if it does not exist."""
        KNOWLEDGE_DATA_DIR.mkdir(parents=True, exist_ok=True)

    def _load(self):
        """Load knowledge data from the JSON file, resetting on any failure."""
        if not KNOWLEDGE_FILE.exists():
            self._knowledge = {
                category_id: [] for category_id in KNOWLEDGE_CATEGORIES
            }
            return

        try:
            with open(KNOWLEDGE_FILE, "r", encoding="utf-8") as f:
                loaded = json.load(f)
            # Make sure every known category key is present.
            for category_id in KNOWLEDGE_CATEGORIES:
                if category_id not in loaded:
                    loaded[category_id] = []
            self._knowledge = loaded
        except Exception as e:
            # NOTE(review): plain print() will not render these rich-style
            # markup tags — confirm this is intended or route through console.
            print(f"[warning]加载了解数据失败: {e}[/warning]")
            self._knowledge = {
                category_id: [] for category_id in KNOWLEDGE_CATEGORIES
            }

    def _save(self):
        """Persist knowledge data to the JSON file (best-effort)."""
        try:
            with open(KNOWLEDGE_FILE, "w", encoding="utf-8") as f:
                json.dump(self._knowledge, f, ensure_ascii=False, indent=2)
        except Exception as e:
            print(f"[warning]保存了解数据失败: {e}[/warning]")

    def add_knowledge(
        self,
        category_id: str,
        content: str,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> bool:
        """
        Add one knowledge entry.

        Args:
            category_id: category id (must be in KNOWLEDGE_CATEGORIES)
            content: knowledge text
            metadata: optional metadata dict

        Returns:
            True on success, False on unknown category or any error.
        """
        if category_id not in KNOWLEDGE_CATEGORIES:
            return False

        try:
            knowledge_item = {
                # Timestamp-based id; unique enough for a single process.
                "id": f"know_{category_id}_{datetime.now().timestamp()}",
                "content": content,
                "metadata": metadata or {},
                "created_at": datetime.now().isoformat(),
            }
            self._knowledge[category_id].append(knowledge_item)
            self._save()
            return True
        except Exception:
            return False

    def get_category_knowledge(self, category_id: str) -> List[Dict[str, Any]]:
        """
        Get all knowledge entries of one category.

        Args:
            category_id: category id

        Returns:
            The category's entries (empty list for unknown ids).
        """
        return self._knowledge.get(category_id, [])

    def get_all_knowledge(self) -> Dict[str, List[Dict[str, Any]]]:
        """Return the full category -> entries mapping."""
        return self._knowledge

    def get_category_name(self, category_id: str) -> str:
        """Return a category's display name ("未知分类" for unknown ids)."""
        return KNOWLEDGE_CATEGORIES.get(category_id, "未知分类")

    def get_categories_summary(self) -> str:
        """Return a one-line-per-category summary (for showing to the LLM)."""
        lines = []
        for category_id, category_name in KNOWLEDGE_CATEGORIES.items():
            count = len(self._knowledge.get(category_id, []))
            if count > 0:
                lines.append(f"{category_id}. {category_name} ({count}条)")
            else:
                lines.append(f"{category_id}. {category_name} (无数据)")
        return "\n".join(lines)

    def get_formatted_knowledge(self, category_ids: List[str]) -> str:
        """
        Format the entries of the given categories as readable text.

        Args:
            category_ids: list of category ids

        Returns:
            Formatted knowledge text, or a "no data" placeholder string.
        """
        parts = []
        for category_id in category_ids:
            category_name = self.get_category_name(category_id)
            items = self.get_category_knowledge(category_id)

            if items:
                parts.append(f"【{category_name}】")
                for item in items:
                    content = item.get("content", "")
                    parts.append(f" - {content}")

        return "\n".join(parts) if parts else "暂无相关了解信息"

    def get_stats(self) -> Dict[str, Any]:
        """Return statistics about the stored knowledge data."""
        total_items = sum(len(items) for items in self._knowledge.values())
        return {
            "total_categories": len(KNOWLEDGE_CATEGORIES),
            "total_items": total_items,
            "data_file": str(KNOWLEDGE_FILE),
            "data_exists": KNOWLEDGE_FILE.exists(),
            # NOTE(review): exists()/stat() race if the file is deleted between
            # the two calls — acceptable for a stats helper.
            "data_size_kb": KNOWLEDGE_FILE.stat().st_size / 1024 if KNOWLEDGE_FILE.exists() else 0,
        }
|
||||
|
||||
|
||||
# Global singleton instance (lazily created).
_knowledge_store_instance: Optional[KnowledgeStore] = None


def get_knowledge_store() -> KnowledgeStore:
    """Return the process-wide KnowledgeStore, creating it on first call."""
    global _knowledge_store_instance
    if _knowledge_store_instance is None:
        _knowledge_store_instance = KnowledgeStore()
    return _knowledge_store_instance
|
||||
17
src/MaiDiary/llm_service/__init__.py
Normal file
17
src/MaiDiary/llm_service/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""
|
||||
MaiSaka - LLM 服务包
|
||||
提供抽象接口 (BaseLLMService) 和 OpenAI 兼容实现 (OpenAILLMService)。
|
||||
"""
|
||||
|
||||
from .base import BaseLLMService, ChatResponse, ModelInfo, ToolCall
|
||||
from .openai_impl import OpenAILLMService
|
||||
from .utils import format_chat_history
|
||||
|
||||
__all__ = [
|
||||
"BaseLLMService",
|
||||
"ChatResponse",
|
||||
"ModelInfo",
|
||||
"ToolCall",
|
||||
"OpenAILLMService",
|
||||
"format_chat_history",
|
||||
]
|
||||
200
src/MaiDiary/llm_service/base.py
Normal file
200
src/MaiDiary/llm_service/base.py
Normal file
@@ -0,0 +1,200 @@
|
||||
"""
|
||||
MaiSaka - LLM 服务数据结构与抽象接口
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional
|
||||
|
||||
|
||||
# ──────────────────── 数据结构 ────────────────────
|
||||
|
||||
@dataclass
class ModelInfo:
    """Descriptive information about the model in use."""
    model_name: str  # model identifier, e.g. "gpt-4o"
    base_url: str    # API base address
||||
|
||||
|
||||
@dataclass
class ToolCall:
    """A single tool invocation requested by the LLM."""
    id: str          # provider-assigned tool-call id
    name: str        # tool (function) name
    arguments: dict  # parsed arguments ({} when absent or unparseable)
||||
|
||||
|
||||
@dataclass
class ChatResponse:
    """One step of the LLM conversation loop."""
    content: Optional[str]      # assistant text (may be None)
    tool_calls: List[ToolCall]  # parsed tool calls, empty when none
    raw_message: dict  # message dict that can be appended to chat history as-is
||||
|
||||
|
||||
# ──────────────────── 抽象接口 ────────────────────
|
||||
|
||||
class BaseLLMService(ABC):
    """
    Abstract base class for LLM services.

    Every LLM backend implementation should inherit from this class and
    implement the abstract methods below.
    """

    def set_extra_tools(self, tools: List[dict]) -> None:
        """
        Register extra tool definitions (e.g. MCP tools) to be merged with
        the built-in tools.

        Args:
            tools: tool definitions in OpenAI function-calling format
        """
        # Default no-op; subclasses may override.
        pass

    @abstractmethod
    async def chat_loop_step(self, chat_history: List[dict]) -> ChatResponse:
        """
        Execute one step of the conversation loop.

        Sends the current history and returns the LLM response (text and/or
        tool calls). The caller must append raw_message to chat_history,
        execute any tool_calls, append the tool results, and call this
        method again.

        Args:
            chat_history: conversation history (system / user / assistant /
                tool messages)

        Returns:
            ChatResponse
        """
        ...

    @abstractmethod
    def build_chat_context(self, user_text: str) -> List[dict]:
        """Build the initial loop context (system + user) from the user's first input."""
        ...

    @abstractmethod
    async def analyze_timing(
        self, chat_history: List[dict], timing_info: str,
    ) -> str:
        """
        Timing module (includes self-reflection): analyze the conversation's
        temporal dimension and reflect on prior replies.

        Evaluates how long the conversation has run, time since the last
        reply, suggested wait duration, and other pacing considerations,
        while also reviewing reply logic for persona consistency,
        reasonableness, and cognitive limits.

        Args:
            chat_history: current conversation history (same context as the
                main agent)
            timing_info: precise timestamp info supplied by the system
                (conversation start, per-message times, ...)

        Returns:
            Combined timing-analysis and self-reflection text.
        """
        ...

    @abstractmethod
    async def analyze_emotion(self, chat_history: List[dict]) -> str:
        """
        EQ module: analyze the user's emotional state and tone.

        Receives the same context as the main agent and returns a concise
        emotion-analysis text that is injected back into the main agent's
        context to help it understand the user's state.

        Args:
            chat_history: current conversation history (same context as the
                main agent)

        Returns:
            Emotion-analysis text.
        """
        ...

    @abstractmethod
    async def analyze_cognition(self, chat_history: List[dict]) -> str:
        """
        Cognition module: analyze the user's intent, cognitive state and goals.

        Receives the same context as the main agent and returns a concise
        cognition-analysis text that is injected back into the main agent's
        context to help it understand the user's intent.

        Args:
            chat_history: current conversation history (same context as the
                main agent)

        Returns:
            Cognition-analysis text.
        """
        ...

    @abstractmethod
    def get_model_info(self) -> ModelInfo:
        """Return information about the model currently in use."""
        ...

    @abstractmethod
    async def summarize_context(self, context_messages: List[dict]) -> str:
        """
        Context-summarization module: summarize context that is being compressed.

        Called when the conversation history grows too long, to summarize
        its earlier portion.

        Args:
            context_messages: messages to summarize

        Returns:
            Summary text.
        """
        ...

    @abstractmethod
    async def analyze_knowledge_categories(
        self, context_messages: List[dict], categories_summary: str
    ) -> List[str]:
        """
        Knowledge module — category analysis: which personal-trait categories
        does the conversation touch on?

        Triggered on context trimming, to decide which categories to extract.

        Args:
            context_messages: context messages to analyze
            categories_summary: summary of all categories

        Returns:
            List of involved category ids.
        """
        ...

    @abstractmethod
    async def extract_knowledge_for_category(
        self, context_messages: List[dict], category_id: str, category_name: str
    ) -> str:
        """
        Knowledge module — content extraction: pull the personal-trait
        information for one category out of the conversation.

        A sub-agent is created per category to extract its relevant content.

        Args:
            context_messages: context messages to analyze
            category_id: category id
            category_name: category display name

        Returns:
            Extracted personal-trait content.
        """
        ...

    @abstractmethod
    async def analyze_knowledge_need(
        self, chat_history: List[dict], categories_summary: str
    ) -> List[str]:
        """
        Knowledge module — need analysis: which trait categories does the
        current conversation require?

        Triggered before every reply, to decide which categories to retrieve.

        Args:
            chat_history: current conversation history
            categories_summary: summary of all categories

        Returns:
            List of needed category ids.
        """
        ...
|
||||
491
src/MaiDiary/llm_service/openai_impl.py
Normal file
491
src/MaiDiary/llm_service/openai_impl.py
Normal file
@@ -0,0 +1,491 @@
|
||||
"""
|
||||
MaiSaka - OpenAI 兼容 LLM 服务实现
|
||||
支持所有兼容 OpenAI Chat Completions 接口的服务商。
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Callable, List, Optional
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from .base import BaseLLMService, ChatResponse, ModelInfo, ToolCall
|
||||
from .prompts import get_enabled_chat_tools
|
||||
from .utils import format_chat_history, format_chat_history_for_eq, filter_for_api
|
||||
from prompt_loader import load_prompt
|
||||
from knowledge import extract_category_ids_from_result
|
||||
|
||||
|
||||
class OpenAILLMService(BaseLLMService):
|
||||
"""
|
||||
基于 OpenAI 兼容 API 的 LLM 服务实现。
|
||||
支持所有兼容 OpenAI Chat Completions 接口的服务商。
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    api_key: str,
    base_url: Optional[str] = None,
    model: str = "gpt-4o",
    chat_system_prompt: Optional[str] = None,
    temperature: float = 0.5,
    max_tokens: int = 2048,
    enable_thinking: Optional[bool] = None,
):
    """
    Args:
        api_key: API key
        base_url: API base address (defaults to the official OpenAI endpoint)
        model: model name
        chat_system_prompt: custom chat system prompt (None = build default)
        temperature: sampling temperature
        max_tokens: maximum output tokens
        enable_thinking: thinking mode flag (True/False/None = don't send)
    """
    self._base_url = base_url or "https://api.openai.com/v1"
    self._model = model
    self._temperature = temperature
    self._max_tokens = max_tokens
    self._enable_thinking = enable_thinking

    # When no custom prompt is supplied, build one from the current config.
    if chat_system_prompt is None:
        # Deferred import to avoid a circular import with config.
        from config import ENABLE_WRITE_FILE, ENABLE_READ_FILE, ENABLE_LIST_FILES, ENABLE_QQ_TOOLS

        # File-tool descriptions (only for enabled tools).
        file_tools_parts = []
        if ENABLE_WRITE_FILE:
            file_tools_parts.append("• write_file(filename, content) — 在 mai_files 目录下写入文件,支持任意格式。")
        if ENABLE_READ_FILE:
            file_tools_parts.append("• read_file(filename) — 读取 mai_files 目录下的文件内容。")
        if ENABLE_LIST_FILES:
            file_tools_parts.append("• list_files() — 获取 mai_files 目录下所有文件的元信息列表。")

        # QQ-tool descriptions.
        qq_tools_parts = []
        if ENABLE_QQ_TOOLS:
            qq_tools_parts.append("• get_qq_chat_info(chat, limit) — 获取指定 QQ 聊天的聊天记录。")
            qq_tools_parts.append("• send_info(chat, message) — 发送消息到指定的 QQ 聊天。")
            qq_tools_parts.append("• list_qq_chats() — 获取所有可用的 QQ 聊天列表。")

        # Merge all tool descriptions.
        tools_parts = []
        if file_tools_parts:
            tools_parts.extend(file_tools_parts)
        if qq_tools_parts:
            tools_parts.extend(qq_tools_parts)

        # Surround the list with blank lines when any tool is enabled.
        if tools_parts:
            tools_section = "\n" + "\n".join(tools_parts) + "\n"
        else:
            tools_section = ""

        # Load the prompt template and inject the tools section.
        self._chat_system_prompt = load_prompt("chat.system", file_tools_section=tools_section)
    else:
        self._chat_system_prompt = chat_system_prompt

    self._client = AsyncOpenAI(
        api_key=api_key,
        base_url=self._base_url,
    )
    self._debug_callback: Optional[Callable] = None
    self._extra_tools: List[dict] = []  # external tools such as MCP
|
||||
|
||||
def set_extra_tools(self, tools: List[dict]) -> None:
    """Store extra tool definitions (e.g. MCP) to be merged with built-in tools.

    A shallow copy is kept so later caller-side mutation of the list has
    no effect here.
    """
    self._extra_tools = [*tools]
|
||||
|
||||
def set_debug_callback(self, callback: Callable[[str, list, Optional[list], Optional[dict]], None]):
    """Register a debug hook fired on every LLM call (before and after).

    callback(label, messages, tools, response) — tools and response may be None.
    """
    self._debug_callback = callback
|
||||
|
||||
async def _call_llm(self, label: str, messages: list, tools: Optional[list] = None, **kwargs):
    """Single entry point for LLM calls: fire the debug hook, then call the API.

    The debug callback is invoked twice — once before the request (no
    response argument) and once after, with a serializable response dict.
    Callback failures are deliberately swallowed so debugging never breaks
    the main flow.
    """
    if self._debug_callback:
        try:
            self._debug_callback(label, messages, tools)
        except Exception:
            pass

    create_kwargs = {"model": self._model, "messages": messages, **kwargs}
    if tools:
        create_kwargs["tools"] = tools

    response = await self._client.chat.completions.create(**create_kwargs)

    # Forward the response to the debug window.
    if self._debug_callback:
        try:
            # Convert tool_calls into a JSON-serializable shape.
            tool_calls_list = []
            if response.choices[0].message.tool_calls:
                for tc in response.choices[0].message.tool_calls:
                    tool_calls_list.append({
                        "id": tc.id,
                        "type": tc.type,
                        "function": {
                            "name": tc.function.name,
                            "arguments": tc.function.arguments,
                        },
                    })

            resp_dict = {
                "content": response.choices[0].message.content,
                "tool_calls": tool_calls_list,
            }
            self._debug_callback(label, messages, tools, resp_dict)
        except Exception:
            pass

    return response
|
||||
|
||||
def _build_extra_body(self) -> dict:
    """Build the extra_body request payload (currently only enable_thinking)."""
    if self._enable_thinking is None:
        return {}
    return {"enable_thinking": self._enable_thinking}
|
||||
|
||||
def _parse_tool_calls(self, msg) -> List[ToolCall]:
    """Parse the tool-call list out of an API response message.

    Unparseable or missing argument JSON degrades to an empty dict.
    """
    parsed: List[ToolCall] = []
    for call in msg.tool_calls or []:
        raw_args = call.function.arguments
        try:
            decoded_args = json.loads(raw_args) if raw_args else {}
        except json.JSONDecodeError:
            decoded_args = {}
        parsed.append(ToolCall(
            id=call.id,
            name=call.function.name,
            arguments=decoded_args,
        ))
    return parsed
|
||||
|
||||
def _build_raw_message(self, msg) -> dict:
    """Convert an API response message into a dict appendable to chat history."""
    raw: dict = {"role": "assistant", "content": msg.content}
    if not msg.tool_calls:
        return raw

    serialized = []
    for call in msg.tool_calls:
        serialized.append({
            "id": call.id,
            "type": "function",
            "function": {
                "name": call.function.name,
                # Guarantee a valid JSON string: empty arguments become "{}".
                "arguments": call.function.arguments or "{}",
            },
        })
    raw["tool_calls"] = serialized
    return raw
|
||||
|
||||
# ──────── 接口实现 ────────
|
||||
|
||||
async def chat_loop_step(self, chat_history: List[dict]) -> ChatResponse:
    """Run one conversation-loop step; returns text and/or tool calls."""
    extra_body = self._build_extra_body()

    # Deferred import to avoid a circular import with config.
    from config import ENABLE_WRITE_FILE, ENABLE_READ_FILE, ENABLE_LIST_FILES, ENABLE_QQ_TOOLS

    # Built-in tools enabled by configuration.
    enabled_tools = get_enabled_chat_tools(
        enable_write_file=ENABLE_WRITE_FILE,
        enable_read_file=ENABLE_READ_FILE,
        enable_list_files=ENABLE_LIST_FILES,
        enable_qq_tools=ENABLE_QQ_TOOLS,
    )

    # Merge built-in tools with external (e.g. MCP) tools.
    all_tools = enabled_tools + self._extra_tools

    # Strip internal-only fields (such as _type), keeping only API fields.
    api_messages = filter_for_api(chat_history)

    response = await self._call_llm(
        "主 Agent 对话",
        api_messages,
        tools=all_tools,
        temperature=self._temperature,
        max_tokens=self._max_tokens,
        **({"extra_body": extra_body} if extra_body else {}),
    )

    msg = response.choices[0].message
    return ChatResponse(
        content=msg.content,
        tool_calls=self._parse_tool_calls(msg),
        raw_message=self._build_raw_message(msg),
    )
|
||||
|
||||
def get_model_info(self) -> ModelInfo:
    """Describe the model currently in use."""
    info = ModelInfo(model_name=self._model, base_url=self._base_url)
    return info
|
||||
|
||||
# ──────── Timing 模块(含自我反思功能) ────────
|
||||
|
||||
async def analyze_timing(
    self, chat_history: List[dict], timing_info: str,
) -> str:
    """Timing module (incl. self-reflection): analyze conversation pacing and reflect on prior replies."""
    # Drop perception messages and system messages from the analyzed history.
    filtered_history = [
        msg for msg in chat_history
        if msg.get("_type") != "perception" and msg.get("role") != "system"
    ]
    formatted = format_chat_history(filtered_history)
    timing_messages = [
        {"role": "system", "content": load_prompt("timing.system")},
        {
            "role": "user",
            "content": (
                f"【系统时间戳信息】\n{timing_info}\n\n"
                f"【当前对话记录】\n{formatted}"
            ),
        },
    ]
    extra_body = self._build_extra_body()

    # Low temperature: analysis should be stable, not creative.
    response = await self._call_llm(
        "Timing 模块",
        timing_messages,
        temperature=0.3,
        max_tokens=512,
        **({"extra_body": extra_body} if extra_body else {}),
    )

    return response.choices[0].message.content or ""
|
||||
|
||||
# ──────── 情商模块 (EQ Module) ────────
|
||||
|
||||
async def analyze_emotion(self, chat_history: List[dict]) -> str:
    """EQ module: analyze the user's emotional state and tone."""
    # Drop perception messages (the AI's internal perceptions need no re-analysis).
    filtered_history = [msg for msg in chat_history if msg.get("_type") != "perception"]
    # Keep the most recent ~8-10 messages (about 3-5 turns).
    recent_messages = filtered_history[-10:] if len(filtered_history) > 10 else filtered_history
    # EQ-specific formatting: user replies, assistant thoughts, assistant speech only.
    formatted = format_chat_history_for_eq(recent_messages)

    eq_messages = [
        {"role": "system", "content": load_prompt("emotion.system")},
        {
            "role": "user",
            "content": f"以下是最近几轮对话记录,请分析其中用户的情绪状态和言语态度:\n\n{formatted}",
        },
    ]
    extra_body = self._build_extra_body()

    # Low temperature: analysis should be stable, not creative.
    response = await self._call_llm(
        "情商模块 (EQ)",
        eq_messages,
        temperature=0.3,
        max_tokens=512,
        **({"extra_body": extra_body} if extra_body else {}),
    )

    return response.choices[0].message.content or ""
|
||||
|
||||
# ──────── 认知模块 (Cognition Module) ────────
|
||||
|
||||
async def analyze_cognition(self, chat_history: List[dict]) -> str:
    """Cognition module: analyze the user's intent, cognitive state and goals."""
    # Drop perception messages (the AI's internal perceptions need no re-analysis).
    filtered_history = [msg for msg in chat_history if msg.get("_type") != "perception"]
    # Keep the most recent ~8-10 messages (about 3-5 turns).
    recent_messages = filtered_history[-10:] if len(filtered_history) > 10 else filtered_history
    # Reuses the EQ formatting: user replies, assistant thoughts, assistant speech only.
    formatted = format_chat_history_for_eq(recent_messages)

    cognition_messages = [
        {"role": "system", "content": load_prompt("cognition.system")},
        {
            "role": "user",
            "content": f"以下是最近几轮对话记录,请分析其中用户的意图、认知状态和目的:\n\n{formatted}",
        },
    ]
    extra_body = self._build_extra_body()

    # Low temperature: analysis should be stable, not creative.
    response = await self._call_llm(
        "认知模块 (Cognition)",
        cognition_messages,
        temperature=0.3,
        max_tokens=512,
        **({"extra_body": extra_body} if extra_body else {}),
    )

    return response.choices[0].message.content or ""
|
||||
|
||||
# ──────── 上下文总结模块 ────────
|
||||
|
||||
async def summarize_context(self, context_messages: List[dict]) -> str:
    """Context-summarization module: summarize context that is being compressed."""
    # Exclude system messages from the summarized text.
    filtered_messages = [msg for msg in context_messages if msg.get("role") != "system"]
    formatted = format_chat_history(filtered_messages)

    summarize_messages = [
        {"role": "system", "content": load_prompt("context_summarize.system")},
        {
            "role": "user",
            "content": f"请对以下对话内容进行总结,以便存入记忆系统:\n\n{formatted}",
        },
    ]
    extra_body = self._build_extra_body()

    try:
        response = await self._call_llm(
            "上下文总结",
            summarize_messages,
            temperature=0.3,
            max_tokens=1024,
            **({"extra_body": extra_body} if extra_body else {}),
        )
        return response.choices[0].message.content or ""
    except Exception:
        # Summarization is best-effort; return "" on failure.
        return ""
|
||||
|
||||
# ──────── 了解模块 (Knowledge Module) ────────
|
||||
|
||||
async def analyze_knowledge_categories(
    self, context_messages: List[dict], categories_summary: str
) -> List[str]:
    """
    Knowledge module — category analysis: which personal-trait categories
    does the conversation touch on?

    Triggered on context trimming, to decide which categories to extract.
    Best-effort: any failure yields an empty list.
    """
    # Deferred import to avoid a circular import with the knowledge module.
    from knowledge import format_context_for_memory

    context_text = format_context_for_memory(context_messages)
    if not context_text:
        return []

    # Load the category-analysis prompt with the category summary injected.
    prompt = load_prompt("knowledge_category.system", categories_summary=categories_summary)

    category_messages = [
        {"role": "system", "content": prompt},
        {
            "role": "user",
            "content": f"请分析以下对话内容涉及哪些个人特征分类:\n\n{context_text}",
        },
    ]
    extra_body = self._build_extra_body()

    try:
        response = await self._call_llm(
            "了解模块-分类分析",
            category_messages,
            temperature=0.3,
            max_tokens=256,
            **({"extra_body": extra_body} if extra_body else {}),
        )
        result = response.choices[0].message.content or ""
        return extract_category_ids_from_result(result)
    except Exception:
        return []
|
||||
|
||||
async def extract_knowledge_for_category(
    self, context_messages: List[dict], category_id: str, category_name: str
) -> str:
    """Knowledge module — content extraction.

    Spawns a sub-agent for one category and extracts the personal-trait
    information related to it from the given conversation.

    Args:
        context_messages: the conversation to mine.
        category_id: id of the target category (currently unused here,
            kept for interface symmetry with the category analyzer).
        category_name: human-readable category name, used in the prompts.

    Returns:
        The extracted text, or "" when nothing was found or the call failed.
    """
    from knowledge import format_context_for_memory

    context_text = format_context_for_memory(context_messages)
    if not context_text:
        return ""

    # Render the extraction prompt for this category.
    prompt = load_prompt("knowledge_extract.system", category_name=category_name)

    extract_messages = [
        {"role": "system", "content": prompt},
        {
            "role": "user",
            "content": f"请从以下对话内容中提取与「{category_name}」相关的信息:\n\n{context_text}",
        },
    ]
    extra_body = self._build_extra_body()

    try:
        response = await self._call_llm(
            f"了解模块-{category_name}提取",
            extract_messages,
            temperature=0.3,
            max_tokens=512,
            **({"extra_body": extra_body} if extra_body else {}),
        )
        result = response.choices[0].message.content or ""

        # BUG FIX: the previous check (`"无" in result`) discarded any
        # extraction that merely *contained* the character 无 (e.g. "无锡",
        # "无聊"). Only treat the reply as empty when it IS the bare
        # "nothing" marker.
        stripped = result.strip()
        if not stripped or stripped == "无":
            return ""

        return result
    except Exception:
        return ""
|
||||
|
||||
async def analyze_knowledge_need(
    self, chat_history: List[dict], categories_summary: str
) -> List[str]:
    """Knowledge module — need analysis.

    Runs before each reply: decides which personal-trait categories should be
    retrieved from memory to inform the next response.

    Args:
        chat_history: full chat history (perception and system entries are
            filtered out before analysis).
        categories_summary: summary of known categories for the prompt.

    Returns:
        Category ids to retrieve; empty on failure.
    """
    # Drop perception and system entries; analyze only real dialogue turns.
    dialogue = [
        m for m in chat_history
        if m.get("_type") != "perception" and m.get("role") != "system"
    ]
    # Keep the most recent turns for analysis (at most 10 messages).
    formatted = format_chat_history(dialogue[-10:])

    # Render the need-analysis prompt with the recent context baked in.
    system_prompt = load_prompt(
        "knowledge_retrieve.system",
        chat_context=formatted,
        categories_summary=categories_summary,
    )
    request = [
        {"role": "system", "content": system_prompt},
        {
            "role": "user",
            "content": "请分析当前对话需要哪些个人特征信息。",
        },
    ]

    call_kwargs = {}
    extra_body = self._build_extra_body()
    if extra_body:
        call_kwargs["extra_body"] = extra_body

    try:
        response = await self._call_llm(
            "了解模块-需求分析",
            request,
            temperature=0.3,
            max_tokens=256,
            **call_kwargs,
        )
        raw = response.choices[0].message.content or ""
        return extract_category_ids_from_result(raw)
    except Exception:
        return []
|
||||
|
||||
# ──────── 对话上下文构建 ────────
|
||||
|
||||
def build_chat_context(self, user_text: str) -> List[dict]:
    """Build the initial context for a conversation loop from the user's opening input."""
    system_msg = {"role": "system", "content": self._chat_system_prompt}
    user_msg = {"role": "user", "content": user_text}
    return [system_msg, user_msg]
|
||||
273
src/MaiDiary/llm_service/prompts.py
Normal file
273
src/MaiDiary/llm_service/prompts.py
Normal file
@@ -0,0 +1,273 @@
|
||||
"""
|
||||
MaiSaka - LLM 工具定义
|
||||
所有 Tool Schema 集中管理。
|
||||
|
||||
注意:所有 Prompt 模板已迁移至 prompts/ 目录,使用 .prompt 文件存储。
|
||||
使用 prompt_loader.load_prompt() 加载模板。
|
||||
"""
|
||||
|
||||
# ──────────────────── 工具定义 ────────────────────
|
||||
|
||||
# Core tools (always enabled): say / wait / stop drive the agent's turn-taking.
CORE_TOOLS = [
    {
        "type": "function",
        "function": {
            "name": "say",
            "description": (
                "对用户说话。你的所有正式发言都必须通过此工具输出。"
                "直接输出的 content 文本会被视为你的内心思考,用户无法看到。"
                "请描述你想要回复的方式、想法和内容,系统会根据你的想法和对话上下文生成具体的回复。"
            ),
            "parameters": {
                "type": "object",
                "properties": {
                    "reason": {
                        "type": "string",
                        "description": "你想要回复的方式、想法、内容(例如:'我觉得他说得对,表示认同' 或 '这个观点太离谱了,想质疑一下')",
                    }
                },
                "required": ["reason"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "wait",
            "description": (
                "暂时结束你的发言,把话语权交给用户,等待对方说话。"
                "指定等待的最大秒数。"
                "如果用户在等待期间说了话,你会通过工具结果收到内容;"
                "如果超时对方没有说话,你会收到超时通知。"
            ),
            "parameters": {
                "type": "object",
                "properties": {
                    "seconds": {
                        "type": "integer",
                        # FIX: the closing fullwidth parenthesis was missing.
                        "description": "等待的秒数(1-24*3600)",
                    }
                },
                "required": ["seconds"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "stop",
            "description": (
                "结束当前对话循环,进入待机状态。"
                "调用后主循环会停止,直到用户下次输入新内容时重新唤醒。"
                "适合在对话自然结束、用户不再回复、或深夜等不适合继续聊天时使用。"
            ),
            "parameters": {
                "type": "object",
                "properties": {},
            },
        },
    },
]
|
||||
|
||||
# Optional tools (can be enabled/disabled via configuration);
# see get_enabled_chat_tools() below for the selection logic.
OPTIONAL_TOOLS = {
    # Fetch recent messages from a QQ chat over HTTP (bridged by another program).
    "get_qq_chat_info": {
        "type": "function",
        "function": {
            "name": "get_qq_chat_info",
            "description": (
                "获取指定 QQ 聊天的聊天记录。"
                "通过 HTTP 请求获取另一个程序的 QQ 聊天内容,返回最近的聊天消息(纯文本格式)。"
                "可用于查看用户在 QQ 上的对话,了解用户当前的聊天状态。"
            ),
            "parameters": {
                "type": "object",
                "properties": {
                    "chat": {
                        "type": "string",
                        "description": "QQ 聊天标识符,格式如 'qq:群号:group' 或 'qq:QQ号:private'",
                    },
                    "limit": {
                        "type": "integer",
                        "description": "获取的聊天消息数量限制,默认 20 条",
                    },
                },
                "required": ["chat"],
            },
        },
    },
    # Send a message to a QQ group or private chat over HTTP.
    "send_info": {
        "type": "function",
        "function": {
            "name": "send_info",
            "description": (
                "发送消息到指定的 QQ 聊天。"
                "通过 HTTP 请求将消息发送到 QQ,可以发送到群聊或私聊。"
                "适合在需要主动向 QQ 发送通知、回复或消息时使用。"
            ),
            "parameters": {
                "type": "object",
                "properties": {
                    "chat": {
                        "type": "string",
                        "description": "目标 QQ 聊天标识符,格式如 'qq:群号:group' 或 'qq:QQ号:private'",
                    },
                    "message": {
                        "type": "string",
                        "description": "要发送的消息内容",
                    },
                },
                "required": ["chat", "message"],
            },
        },
    },
    # List all reachable QQ group chats (name, group id, chat identifier).
    "list_qq_chats": {
        "type": "function",
        "function": {
            "name": "list_qq_chats",
            "description": (
                "获取所有可用的 QQ 群聊列表。"
                "返回当前可访问的所有 QQ 群聊信息(包括群名、群号、聊天标识符等)。"
                "可用于查看有哪些 QQ 群聊可以获取消息或发送消息。"
            ),
            "parameters": {
                "type": "object",
                "properties": {},
            },
        },
    },
    # Write (or overwrite) a file under the mai_files directory.
    "write_file": {
        "type": "function",
        "function": {
            "name": "write_file",
            "description": (
                "在 mai_files 目录下写入文件,支持任意格式(文本、代码、Markdown等)。"
                "如果文件已存在,会覆盖原有内容。可用于保存笔记、代码片段、配置等。"
            ),
            "parameters": {
                "type": "object",
                "properties": {
                    "filename": {
                        "type": "string",
                        "description": "文件名,可包含路径,如 'notes.txt' 或 'diary/2024-03-09.md'",
                    },
                    "content": {
                        "type": "string",
                        "description": "要写入的文件内容",
                    },
                },
                "required": ["filename", "content"],
            },
        },
    },
    # Read the full text of a file under mai_files.
    "read_file": {
        "type": "function",
        "function": {
            "name": "read_file",
            "description": (
                "读取 mai_files 目录下的文件内容。"
                "返回文件的完整文本内容。"
            ),
            "parameters": {
                "type": "object",
                "properties": {
                    "filename": {
                        "type": "string",
                        "description": "要读取的文件名,可包含路径",
                    },
                },
                "required": ["filename"],
            },
        },
    },
    # List metadata (name, size, mtime) for every file under mai_files.
    "list_files": {
        "type": "function",
        "function": {
            "name": "list_files",
            "description": (
                "获取 mai_files 目录下所有文件的元信息列表。"
                "返回每个文件的名称、大小、修改时间等信息,帮助你了解有哪些文件可用。"
            ),
            "parameters": {
                "type": "object",
                "properties": {},
            },
        },
    },
}
|
||||
|
||||
# Tools that are always enabled regardless of configuration.
ALWAYS_ENABLED_TOOLS = [
    # store_context: archive the oldest N messages into the memory system
    # and drop them from the live context (compression / topic change).
    {
        "type": "function",
        "function": {
            "name": "store_context",
            "description": (
                "将指定范围的对话上下文存入记忆系统,然后从当前对话中移除这些内容。"
                "适合在以下情况使用:"
                "1. 对话上下文过长,需要压缩以保持效率"
                "2. 对话话题已经转换,旧话题的内容可以归档"
                "3. 遇到重要的对话内容,需要保存到长期记忆中"
            ),
            "parameters": {
                "type": "object",
                "properties": {
                    "count": {
                        "type": "integer",
                        "description": "要存入记忆的消息数量,从最早的消息开始计算。例如传入10会将最早的10条消息存入记忆并移除。",
                    },
                    "reason": {
                        "type": "string",
                        "description": "说明为什么要存入这段上下文,帮助记忆系统更好地组织信息。例如:「话题从游戏转换到了工作」或「上下文过长需要压缩」。",
                    },
                },
                "required": ["count", "reason"],
            },
        },
    },
]
|
||||
|
||||
# ──────────────────── Main-agent tool definitions ────────────────────

# Kept for backward compatibility with existing callers; new code should
# prefer get_enabled_chat_tools() to honor configuration flags.
CHAT_TOOLS = CORE_TOOLS + [
    OPTIONAL_TOOLS["write_file"],
    OPTIONAL_TOOLS["read_file"],
    OPTIONAL_TOOLS["list_files"],
    ALWAYS_ENABLED_TOOLS[0],  # store_context
]
|
||||
|
||||
|
||||
def get_enabled_chat_tools(
    enable_write_file: bool = True,
    enable_read_file: bool = True,
    enable_list_files: bool = True,
    enable_qq_tools: bool = False,
) -> list:
    """Assemble the main agent's tool list from configuration flags.

    Core tools and the always-enabled tools are included unconditionally;
    optional tools are appended according to the flags.

    Args:
        enable_write_file: include the write_file tool.
        enable_read_file: include the read_file tool.
        enable_list_files: include the list_files tool.
        enable_qq_tools: include the QQ tools
            (get_qq_chat_info, send_info, list_qq_chats).

    Returns:
        The list of enabled tool definitions.
    """
    tools = [*CORE_TOOLS, *ALWAYS_ENABLED_TOOLS]

    if enable_qq_tools:
        for key in ("get_qq_chat_info", "send_info", "list_qq_chats"):
            tools.append(OPTIONAL_TOOLS[key])
    for enabled, key in (
        (enable_write_file, "write_file"),
        (enable_read_file, "read_file"),
        (enable_list_files, "list_files"),
    ):
        if enabled:
            tools.append(OPTIONAL_TOOLS[key])

    return tools
|
||||
144
src/MaiDiary/llm_service/utils.py
Normal file
144
src/MaiDiary/llm_service/utils.py
Normal file
@@ -0,0 +1,144 @@
|
||||
"""
|
||||
MaiSaka - LLM 服务工具函数
|
||||
"""
|
||||
|
||||
from typing import Literal
|
||||
|
||||
# ──────────────────── Message types ────────────────────

# Internal tag describing what produced a message.
MessageType = Literal["user", "assistant", "system", "perception"]

# Prefix marking internal metadata fields that must never be sent to the API.
INTERNAL_FIELD_PREFIX = "_"

# Name of the field holding the message's internal type.
MSG_TYPE_FIELD = "_type"
|
||||
|
||||
|
||||
# ──────────────────── 消息构建 ────────────────────
|
||||
|
||||
def build_message(role: str, content: str, msg_type: MessageType = "user", **kwargs) -> dict:
    """Create a message dict tagged with its internal message type.

    Args:
        role: message role (user/assistant/system).
        content: message text.
        msg_type: internal message type (user/assistant/system/perception).
        **kwargs: extra fields to merge in (e.g. tool_calls).

    Returns:
        The assembled message dict.
    """
    message = {"role": role, "content": content, MSG_TYPE_FIELD: msg_type}
    message.update(kwargs)
    return message
|
||||
|
||||
|
||||
def filter_for_api(messages: list[dict]) -> list[dict]:
    """Strip internal (underscore-prefixed) fields before sending to the API.

    Args:
        messages: raw message list.

    Returns:
        New message list with every internal field removed.
    """
    cleaned: list[dict] = []
    for msg in messages:
        cleaned.append(
            {key: value for key, value in msg.items()
             if not key.startswith(INTERNAL_FIELD_PREFIX)}
        )
    return cleaned
|
||||
|
||||
|
||||
def filter_by_type(messages: list[dict], msg_type: MessageType) -> list[dict]:
    """Return only the messages whose internal type equals *msg_type*.

    Args:
        messages: raw message list.
        msg_type: message type to keep.

    Returns:
        Messages of the requested type, in original order.
    """
    return list(filter(lambda m: m.get(MSG_TYPE_FIELD) == msg_type, messages))
|
||||
|
||||
|
||||
def remove_last_perception(messages: list[dict]) -> None:
    """Remove the most recent perception message in place (no-op if none).

    Args:
        messages: message list (mutated in place).
    """
    for idx in reversed(range(len(messages))):
        if messages[idx].get(MSG_TYPE_FIELD) == "perception":
            del messages[idx]
            return
|
||||
|
||||
|
||||
def format_chat_history(messages: list) -> str:
|
||||
"""将聊天消息列表格式化为可读文本,用于子代理上下文构建。"""
|
||||
parts: list[str] = []
|
||||
for msg in messages:
|
||||
role = msg.get("role", "?")
|
||||
content = msg.get("content", "") or ""
|
||||
if role == "system":
|
||||
parts.append(f"[系统] {content[:500]}")
|
||||
elif role == "user":
|
||||
parts.append(f"[用户] {content[:500]}")
|
||||
elif role == "assistant":
|
||||
if content:
|
||||
parts.append(f"[助手思考] {content[:500]}")
|
||||
for tc in msg.get("tool_calls", []):
|
||||
func = tc.get("function", {})
|
||||
name = func.get("name", "?")
|
||||
args = func.get("arguments", "")
|
||||
if isinstance(args, str) and len(args) > 200:
|
||||
args = args[:200] + "..."
|
||||
parts.append(f"[助手调用 {name}] {args}")
|
||||
elif role == "tool":
|
||||
parts.append(f"[工具结果] {content[:300]}")
|
||||
return "\n".join(parts)
|
||||
|
||||
|
||||
def format_chat_history_for_eq(messages: list) -> str:
    """Render chat history for the EQ (emotion) module.

    Includes exactly three kinds of content:
      1. the assistant's own thinking (assistant ``content``),
      2. what the assistant actually said (results of the ``say`` tool),
      3. the user's replies (user messages).

    Tool-call records themselves and results of other tools are omitted.
    """
    # Pass 1: collect the tool_call ids that belong to the `say` tool.
    say_ids = {
        call.get("id", "")
        for msg in messages
        if msg.get("role") == "assistant" and "tool_calls" in msg
        for call in msg.get("tool_calls", [])
        if call.get("function", {}).get("name") == "say"
    }

    # Pass 2: emit only the lines the EQ module should see.
    lines: list[str] = []
    for msg in messages:
        role = msg.get("role", "?")
        text = msg.get("content", "") or ""
        if role == "user":
            lines.append(f"[用户] {text[:500]}")
        elif role == "assistant" and text:
            # Assistant thinking only — tool calls themselves are excluded.
            lines.append(f"[助手思考] {text[:500]}")
        elif role == "tool" and msg.get("tool_call_id", "") in say_ids:
            # Only the results of the `say` tool count as assistant speech.
            lines.append(f"[助手说] {text[:500]}")
    return "\n".join(lines)
|
||||
30
src/MaiDiary/main.py
Normal file
30
src/MaiDiary/main.py
Normal file
@@ -0,0 +1,30 @@
|
||||
"""
|
||||
MaiSaka - 程序入口
|
||||
使用方法:
|
||||
python main.py
|
||||
|
||||
环境变量 (可通过 .env 文件设置):
|
||||
OPENAI_API_KEY - API 密钥
|
||||
OPENAI_BASE_URL - API 基地址 (可选, 默认 https://api.openai.com/v1)
|
||||
OPENAI_MODEL - 模型名称 (可选, 默认 gpt-4o)
|
||||
ENABLE_THINKING - 是否启用思考模式 (可选, true/false, 不设置则不发送该参数)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
from config import console
|
||||
from cli import BufferCLI
|
||||
|
||||
|
||||
def main():
    """Program entry point: construct BufferCLI and run its event loop until exit."""
    cli = BufferCLI()
    try:
        asyncio.run(cli.run())
    except KeyboardInterrupt:
        # Ctrl-C is a normal way to quit; print a notice instead of a traceback.
        console.print("\n[muted]程序已终止[/muted]")
    finally:
        # NOTE(review): reaches into a private attribute of BufferCLI;
        # consider exposing a public close()/shutdown() method instead.
        cli._debug_viewer.close()


if __name__ == "__main__":
    main()
|
||||
18
src/MaiDiary/mcp_client/__init__.py
Normal file
18
src/MaiDiary/mcp_client/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
||||
"""
|
||||
MaiSaka - MCP (Model Context Protocol) 客户端包
|
||||
|
||||
提供 MCPManager 用于管理 MCP 服务器连接、发现工具、调用工具。
|
||||
|
||||
用法:
|
||||
from mcp_client import MCPManager
|
||||
|
||||
manager = await MCPManager.from_config("mcp_config.json")
|
||||
if manager:
|
||||
tools = manager.get_openai_tools() # 获取 OpenAI 格式工具列表
|
||||
result = await manager.call_tool(name, args) # 调用工具
|
||||
await manager.close() # 关闭连接
|
||||
"""
|
||||
|
||||
from .manager import MCPManager
|
||||
|
||||
__all__ = ["MCPManager"]
|
||||
105
src/MaiDiary/mcp_client/config.py
Normal file
105
src/MaiDiary/mcp_client/config.py
Normal file
@@ -0,0 +1,105 @@
|
||||
"""
|
||||
MaiSaka - MCP 配置加载与验证
|
||||
从 mcp_config.json 读取 MCP 服务器定义,解析为结构化配置对象。
|
||||
|
||||
配置格式示例:
|
||||
{
|
||||
"mcpServers": {
|
||||
"filesystem": {
|
||||
"command": "npx",
|
||||
"args": ["-y", "@modelcontextprotocol/server-filesystem", "C:/Users"],
|
||||
"env": {}
|
||||
},
|
||||
"remote-api": {
|
||||
"url": "http://localhost:8080/sse",
|
||||
"headers": {"Authorization": "Bearer xxx"}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
- command + args: Stdio 传输(启动子进程)
|
||||
- url: SSE 传输(连接远程服务器)
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
from config import console
|
||||
|
||||
|
||||
@dataclass
class MCPServerConfig:
    """Configuration for a single MCP server.

    Exactly one transport is expected per entry:
      * Stdio — ``command`` (plus ``args``/``env``) spawns a subprocess.
      * SSE   — ``url`` (plus ``headers``) connects to a remote endpoint.
    """

    name: str

    # ── Stdio transport ──
    command: Optional[str] = None
    args: list[str] = field(default_factory=list)
    env: Optional[dict[str, str]] = None

    # ── SSE transport ──
    url: Optional[str] = None
    headers: dict[str, str] = field(default_factory=dict)

    @property
    def transport_type(self) -> str:
        """Transport kind for this config: 'stdio' / 'sse' / 'unknown'."""
        if self.command:
            return "stdio"
        return "sse" if self.url else "unknown"
|
||||
|
||||
|
||||
def load_mcp_config(config_path: str = "mcp_config.json") -> list[MCPServerConfig]:
    """Load the MCP server list from a JSON config file.

    Args:
        config_path: path to the configuration file.

    Returns:
        Parsed MCPServerConfig entries; an empty list when the file is
        missing, unreadable, or the config is malformed.
    """
    if not os.path.isfile(config_path):
        return []

    try:
        with open(config_path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (json.JSONDecodeError, OSError) as e:
        console.print(f"[warning]⚠️ 读取 MCP 配置失败: {e}[/warning]")
        return []

    servers = data.get("mcpServers", {})
    if not isinstance(servers, dict):
        console.print("[warning]⚠️ mcp_config.json 中 mcpServers 格式无效[/warning]")
        return []

    parsed: list[MCPServerConfig] = []
    for name, raw in servers.items():
        if not isinstance(raw, dict):
            console.print(f"[warning]⚠️ MCP 服务器 '{name}' 配置格式无效,已跳过[/warning]")
            continue

        cfg = MCPServerConfig(
            name=name,
            command=raw.get("command"),
            args=raw.get("args", []),
            env=raw.get("env"),
            url=raw.get("url"),
            headers=raw.get("headers", {}),
        )
        # A server must declare at least one usable transport.
        if cfg.transport_type == "unknown":
            console.print(
                f"[warning]⚠️ MCP 服务器 '{name}' 缺少 command 或 url,已跳过[/warning]"
            )
            continue

        parsed.append(cfg)

    return parsed
|
||||
161
src/MaiDiary/mcp_client/connection.py
Normal file
161
src/MaiDiary/mcp_client/connection.py
Normal file
@@ -0,0 +1,161 @@
|
||||
"""
|
||||
MaiSaka - 单个 MCP 服务器连接管理
|
||||
封装单个 MCP 服务器的连接生命周期:连接 → 发现工具 → 调用工具 → 断开。
|
||||
"""
|
||||
|
||||
from contextlib import AsyncExitStack
|
||||
from typing import Any, Optional
|
||||
|
||||
from config import console
|
||||
from .config import MCPServerConfig
|
||||
|
||||
# ──────────────────── MCP SDK 可选导入 ────────────────────
|
||||
#
|
||||
# mcp 是可选依赖。如果未安装,MCP_AVAILABLE = False,
|
||||
# MCPManager.from_config() 会检测到并返回 None,不影响主程序运行。
|
||||
|
||||
try:
|
||||
from mcp import ClientSession
|
||||
|
||||
try:
|
||||
from mcp.client.stdio import StdioServerParameters
|
||||
except ImportError:
|
||||
from mcp import StdioServerParameters # type: ignore[attr-defined]
|
||||
|
||||
from mcp.client.stdio import stdio_client
|
||||
|
||||
MCP_AVAILABLE = True
|
||||
except ImportError:
|
||||
MCP_AVAILABLE = False
|
||||
ClientSession = None # type: ignore[assignment,misc]
|
||||
StdioServerParameters = None # type: ignore[assignment,misc]
|
||||
stdio_client = None # type: ignore[assignment]
|
||||
|
||||
try:
|
||||
from mcp.client.sse import sse_client
|
||||
|
||||
SSE_AVAILABLE = True
|
||||
except ImportError:
|
||||
SSE_AVAILABLE = False
|
||||
sse_client = None # type: ignore[assignment]
|
||||
|
||||
|
||||
class MCPConnection:
    """
    Manage the connection lifecycle of a single MCP server:
    connect → discover tools → call tools → disconnect.

    Two transports are supported:
    - Stdio: spawn a subprocess and talk over stdin/stdout
    - SSE: connect to a remote HTTP SSE endpoint
    """

    def __init__(self, config: MCPServerConfig):
        # Server definition (transport kind, command/url, headers, ...).
        self.config = config
        # Live session once connect() succeeds, else None.
        self.session: Optional[Any] = None  # mcp.ClientSession
        # Tool objects discovered from the server.
        self.tools: list = []  # mcp Tool objects
        # Owns every async context entered while connecting, so close()
        # can unwind them all with one aclose().
        self._exit_stack = AsyncExitStack()

    async def connect(self) -> bool:
        """
        Connect to the MCP server and discover its available tools.

        Returns:
            True on success, False on failure (already reported to console).
        """
        if not MCP_AVAILABLE:
            console.print(
                "[warning]⚠️ 未安装 mcp SDK,请运行: pip install mcp[/warning]"
            )
            return False

        try:
            # Enter the stack manually so its lifetime outlives this method.
            await self._exit_stack.__aenter__()

            if self.config.transport_type == "stdio":
                read_stream, write_stream = await self._connect_stdio()
            elif self.config.transport_type == "sse":
                read_stream, write_stream = await self._connect_sse()
            else:
                console.print(
                    f"[warning]MCP '{self.config.name}': 未知传输类型[/warning]"
                )
                return False

            # Create and initialize the MCP session on top of the transport.
            self.session = await self._exit_stack.enter_async_context(
                ClientSession(read_stream, write_stream)
            )
            await self.session.initialize()

            # Discover the tools this server exposes.
            result = await self.session.list_tools()
            self.tools = result.tools if hasattr(result, "tools") else []

            return True

        except Exception as e:
            console.print(
                f"[warning]⚠️ MCP 服务器 '{self.config.name}' 连接失败: {e}[/warning]"
            )
            # Unwind any partially-entered contexts before reporting failure.
            await self.close()
            return False

    async def _connect_stdio(self):
        """Establish a Stdio transport connection (spawns the subprocess)."""
        params = StdioServerParameters(
            command=self.config.command,
            args=self.config.args,
            env=self.config.env,
        )
        return await self._exit_stack.enter_async_context(
            stdio_client(params)
        )

    async def _connect_sse(self):
        """Establish an SSE transport connection to the remote endpoint."""
        if not SSE_AVAILABLE:
            raise ImportError(
                "SSE 传输需要额外依赖,请运行: pip install mcp[sse]"
            )
        return await self._exit_stack.enter_async_context(
            sse_client(url=self.config.url, headers=self.config.headers)
        )

    async def call_tool(self, tool_name: str, arguments: dict) -> str:
        """
        Invoke an MCP tool and return its result as text.

        Args:
            tool_name: name of the tool to call.
            arguments: tool argument dict.

        Returns:
            Text representation of the tool's output (binary/typed content
            is rendered as a placeholder describing its type).
        """
        if not self.session:
            return f"MCP 服务器 '{self.config.name}' 未连接"

        result = await self.session.call_tool(tool_name, arguments=arguments)

        # Flatten the result content items into plain text.
        parts: list[str] = []
        for content in result.content:
            if hasattr(content, "text"):
                parts.append(content.text)
            elif hasattr(content, "data"):
                # Binary/image payload — show only its MIME type.
                content_type = getattr(content, "mimeType", "unknown")
                parts.append(f"[{content_type} 二进制内容]")
            elif hasattr(content, "type"):
                parts.append(f"[{content.type} 内容]")

        return "\n".join(parts) if parts else "工具执行成功(无输出)"

    async def close(self):
        """Close the connection and release all transport/session resources."""
        try:
            await self._exit_stack.aclose()
        except Exception:
            # Best-effort teardown; a failed close must not propagate.
            pass
        self.session = None
        self.tools = []
|
||||
212
src/MaiDiary/mcp_client/manager.py
Normal file
212
src/MaiDiary/mcp_client/manager.py
Normal file
@@ -0,0 +1,212 @@
|
||||
"""
|
||||
MaiSaka - MCP 管理器
|
||||
管理所有 MCP 服务器连接,提供统一的工具发现与调用接口。
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from config import console
|
||||
from .config import MCPServerConfig, load_mcp_config
|
||||
from .connection import MCPConnection, MCP_AVAILABLE
|
||||
|
||||
# Built-in tool names — MCP tools are not allowed to shadow these.
# CONSISTENCY FIX: the set previously omitted the built-in tools actually
# defined in llm_service/prompts.py (store_context, the file tools and the
# QQ tools), so an MCP tool with one of those names would have collided
# silently. The previous entries are retained for safety.
BUILTIN_TOOL_NAMES = frozenset({
    "say", "wait", "stop",
    "store_context",
    "write_file", "read_file", "list_files",
    "get_qq_chat_info", "send_info", "list_qq_chats",
    "create_table", "list_tables", "view_table",
})
|
||||
|
||||
|
||||
class MCPManager:
    """
    Connection manager for MCP servers.

    Responsibilities:
    - connect to every server listed in the config file
    - expose MCP tools in OpenAI function-calling format
    - route tool calls to the server that owns them
    - manage the connection lifecycle centrally
    """

    def __init__(self):
        self._connections: dict[str, MCPConnection] = {}  # server_name → connection
        self._tool_to_server: dict[str, str] = {}  # tool_name → server_name

    # ──────── factory ────────

    @classmethod
    async def from_config(
        cls, config_path: str = "mcp_config.json",
    ) -> Optional["MCPManager"]:
        """
        Create and initialize an MCPManager from a config file.

        Args:
            config_path: path to mcp_config.json

        Returns:
            The initialized manager; None when there is no config, the SDK
            is missing, or every connection failed.
        """
        configs = load_mcp_config(config_path)
        if not configs:
            return None

        if not MCP_AVAILABLE:
            console.print(
                "[warning]⚠️ 发现 MCP 配置但未安装 mcp SDK,"
                "请运行: pip install mcp[/warning]"
            )
            return None

        manager = cls()
        await manager._connect_all(configs)

        if not manager._connections:
            console.print("[warning]⚠️ 所有 MCP 服务器连接失败[/warning]")
            return None

        return manager

    # ──────── connection management ────────

    async def _connect_all(self, configs: list[MCPServerConfig]):
        """Connect to every configured MCP server, skipping failed ones."""
        for cfg in configs:
            conn = MCPConnection(cfg)
            success = await conn.connect()
            if not success:
                continue

            self._connections[cfg.name] = conn

            # Register tools, checking for name conflicts.
            registered = 0
            for tool in conn.tools:
                tool_name = tool.name

                # Never let an MCP tool shadow a built-in tool name.
                if tool_name in BUILTIN_TOOL_NAMES:
                    console.print(
                        f"[warning]⚠️ MCP 工具 '{tool_name}' "
                        f"(来自 {cfg.name}) 与内置工具冲突,已跳过[/warning]"
                    )
                    continue

                # First registration wins across servers.
                if tool_name in self._tool_to_server:
                    existing_server = self._tool_to_server[tool_name]
                    console.print(
                        f"[warning]⚠️ MCP 工具 '{tool_name}' "
                        f"(来自 {cfg.name}) 与 {existing_server} 冲突,已跳过[/warning]"
                    )
                    continue

                self._tool_to_server[tool_name] = cfg.name
                registered += 1

            console.print(
                f"[success]✓ MCP 服务器 '{cfg.name}' 已连接[/success] "
                f"[muted]({registered} 个工具已注册)[/muted]"
            )

    # ──────── tool discovery ────────

    def get_openai_tools(self) -> list[dict]:
        """
        Convert every registered MCP tool to OpenAI function-calling format.

        Returns:
            Tool definitions in OpenAI ``tools`` format.
        """
        tools: list[dict] = []

        for server_name, conn in self._connections.items():
            for tool in conn.tools:
                # Only include tools that were successfully registered,
                # and registered for THIS server (not a conflicting one).
                if tool.name not in self._tool_to_server:
                    continue
                if self._tool_to_server[tool.name] != server_name:
                    continue

                # MCP inputSchema → OpenAI parameters
                parameters = (
                    dict(tool.inputSchema)
                    if hasattr(tool, "inputSchema") and tool.inputSchema
                    else {"type": "object", "properties": {}}
                )
                # Drop $schema (some MCP servers emit it; OpenAI rejects it).
                parameters.pop("$schema", None)

                tools.append({
                    "type": "function",
                    "function": {
                        "name": tool.name,
                        "description": (
                            tool.description
                            or f"MCP tool from {server_name}"
                        ),
                        "parameters": parameters,
                    },
                })

        return tools

    # ──────── tool invocation ────────

    def is_mcp_tool(self, tool_name: str) -> bool:
        """Return True if *tool_name* is a registered MCP tool."""
        return tool_name in self._tool_to_server

    async def call_tool(self, tool_name: str, arguments: dict) -> str:
        """
        Invoke the named MCP tool, automatically routing to its server.

        Args:
            tool_name: tool name
            arguments: tool arguments

        Returns:
            The tool's result text. Failures are returned as text, not raised.
        """
        server_name = self._tool_to_server.get(tool_name)
        if not server_name or server_name not in self._connections:
            return f"MCP 工具 '{tool_name}' 未找到"

        conn = self._connections[server_name]
        try:
            return await conn.call_tool(tool_name, arguments)
        except Exception as e:
            return f"MCP 工具 '{tool_name}' 执行失败: {e}"

    # ──────── reporting ────────

    def get_tool_summary(self) -> str:
        """Return a human-readable summary of all registered MCP tools."""
        parts: list[str] = []
        for server_name, conn in self._connections.items():
            tool_names = [
                t.name for t in conn.tools
                if t.name in self._tool_to_server
                and self._tool_to_server[t.name] == server_name
            ]
            if tool_names:
                parts.append(f" • {server_name}: {', '.join(tool_names)}")
        return "\n".join(parts)

    @property
    def server_count(self) -> int:
        """Number of connected MCP servers."""
        return len(self._connections)

    @property
    def tool_count(self) -> int:
        """Total number of registered MCP tools."""
        return len(self._tool_to_server)

    # ──────── lifecycle ────────

    async def close(self):
        """Close every MCP server connection and clear the registries."""
        for conn in self._connections.values():
            await conn.close()
        self._connections.clear()
        self._tool_to_server.clear()
|
||||
13
src/MaiDiary/mcp_config.json
Normal file
13
src/MaiDiary/mcp_config.json
Normal file
@@ -0,0 +1,13 @@
|
||||
{
|
||||
"mcpServers": {
|
||||
"tavily": {
|
||||
"command": "npx",
|
||||
"args": [
|
||||
"-y",
|
||||
"mcp-remote",
|
||||
"https://mcp.tavily.com/mcp/?tavilyApiKey=YOUR_API_KEY_HERE"
|
||||
],
|
||||
"env": {}
|
||||
}
|
||||
}
|
||||
}
|
||||
13
src/MaiDiary/mcp_config.json.template
Normal file
13
src/MaiDiary/mcp_config.json.template
Normal file
@@ -0,0 +1,13 @@
|
||||
{
|
||||
"mcpServers": {
|
||||
"tavily": {
|
||||
"command": "npx",
|
||||
"args": [
|
||||
"-y",
|
||||
"mcp-remote",
|
||||
"https://mcp.tavily.com/mcp/?tavilyApiKey=YOUR_API_KEY_HERE"
|
||||
],
|
||||
"env": {}
|
||||
}
|
||||
}
|
||||
}
|
||||
84
src/MaiDiary/prompt_loader.py
Normal file
84
src/MaiDiary/prompt_loader.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""
|
||||
MaiSaka - Prompt 加载器
|
||||
支持从 .prompt 文件加载模板,并进行变量替换。
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
||||
class PromptLoader:
|
||||
"""Prompt 模板加载器"""
|
||||
|
||||
def __init__(self, prompts_dir: str | None = None):
|
||||
"""
|
||||
初始化加载器。
|
||||
|
||||
Args:
|
||||
prompts_dir: prompts 目录路径,默认为项目根目录下的 prompts/
|
||||
"""
|
||||
if prompts_dir is None:
|
||||
# 默认为项目根目录下的 prompts/
|
||||
project_root = Path(__file__).parent
|
||||
prompts_dir = project_root / "prompts"
|
||||
|
||||
self.prompts_dir = Path(prompts_dir)
|
||||
self._cache: dict[str, str] = {}
|
||||
|
||||
def load(self, name: str, **kwargs: Any) -> str:
|
||||
"""
|
||||
加载并渲染 prompt 模板。
|
||||
|
||||
Args:
|
||||
name: 模板文件名(不含 .prompt 后缀)
|
||||
**kwargs: 模板变量
|
||||
|
||||
Returns:
|
||||
渲染后的 prompt 文本
|
||||
"""
|
||||
# 从缓存读取
|
||||
if name not in self._cache:
|
||||
template_path = self.prompts_dir / f"{name}.prompt"
|
||||
if not template_path.exists():
|
||||
raise FileNotFoundError(f"Prompt template not found: {template_path}")
|
||||
|
||||
self._cache[name] = template_path.read_text(encoding="utf-8")
|
||||
|
||||
template = self._cache[name]
|
||||
|
||||
# 变量替换
|
||||
if kwargs:
|
||||
try:
|
||||
return template.format(**kwargs)
|
||||
except KeyError as e:
|
||||
raise ValueError(f"Missing template variable: {e}") from e
|
||||
|
||||
return template
|
||||
|
||||
def clear_cache(self):
|
||||
"""清空缓存"""
|
||||
self._cache.clear()
|
||||
|
||||
|
||||
# Module-level singleton loader shared by the convenience functions below.
_loader = PromptLoader()
|
||||
|
||||
|
||||
def load_prompt(name: str, **kwargs: Any) -> str:
    """Render prompt template *name* via the module-level singleton loader.

    Args:
        name: Template file name without the ``.prompt`` suffix.
        **kwargs: Variables substituted into the template.

    Returns:
        The rendered prompt text.
    """
    return _loader.load(name, **kwargs)
|
||||
|
||||
|
||||
def reload_prompts():
    """Force all prompts to be re-read from disk on next use (clears the cache)."""
    _loader.clear_cache()
|
||||
14
src/MaiDiary/prompts/_example_with_vars.prompt
Normal file
14
src/MaiDiary/prompts/_example_with_vars.prompt
Normal file
@@ -0,0 +1,14 @@
|
||||
# 这是一个带变量替换的示例模板
|
||||
# 使用 {variable_name} 语法定义变量
|
||||
# 调用时使用 load_prompt("example_with_vars", name="麦麦", mood="开心")
|
||||
|
||||
你好 {name}!
|
||||
|
||||
今天看起来你心情{mood}。
|
||||
|
||||
作为你的 AI 助手,我会:
|
||||
1. {task1}
|
||||
2. {task2}
|
||||
3. {task3}
|
||||
|
||||
祝你今天愉快!
|
||||
24
src/MaiDiary/prompts/chat.system.prompt
Normal file
24
src/MaiDiary/prompts/chat.system.prompt
Normal file
@@ -0,0 +1,24 @@
|
||||
你叫麦麦,是一个女大学生,你说话比较口语化
|
||||
|
||||
你有以下工具:
|
||||
• say(reason) — 对用户说话。你所有想让用户看到的正式发言都必须通过此工具输出。
|
||||
直接输出的文本会被视为你的内心思考,用户无法阅读。
|
||||
reason 参数描述你想要回复的方式、想法和内容,系统会根据你的想法和对话上下文生成具体的回复。
|
||||
• wait(seconds) — 暂时结束你的发言,把话语权交给用户,等待对方说话。
|
||||
这就像现实对话中你说完一句话后停下来等对方回应。
|
||||
如果用户在等待期间说了话,你会通过工具返回结果收到内容。
|
||||
如果超时没有回复,你也会收到超时通知。
|
||||
• stop() — 结束当前对话循环,进入待机状态,直到用户下次输入新内容时再唤醒你。
|
||||
{file_tools_section}• store_context(count, reason) — 将指定范围的对话上下文存入记忆系统,然后从当前对话中移除这些内容。适合在对话上下文过长、话题转换、或遇到重要内容需要保存时使用。
|
||||
|
||||
思考规则:
|
||||
你必须先进行内心思考,然后选择需要使用的工具,如果你想说话,必须使用say工具。
|
||||
在内心思考中分析当前对话状态和你的想法,然后通过 say 工具的 reason 参数描述你想要回复的方式、想法和内容。
|
||||
只有使用say工具,你才能向用户说话。用户才能看到你的发言。
|
||||
交互规则:
|
||||
1. 你可以自由选择是否调用工具——如果你还想继续思考,可以不调用任何工具
|
||||
2. 想对用户说话时,必须调用 say 工具;直接输出的文本只会被视为内心独白
|
||||
3. 当你说完想说的话、想把话语权交给用户时,调用 wait 暂时结束发言,等待对方回应
|
||||
4. 当对话自然结束、用户表示不想继续聊、或连续多次等待超时用户没有回复时,调用 stop 结束对话
|
||||
5. 你可以在同一轮同时调用多个工具,例如先 say 再 wait
|
||||
|
||||
11
src/MaiDiary/prompts/cognition.system.prompt
Normal file
11
src/MaiDiary/prompts/cognition.system.prompt
Normal file
@@ -0,0 +1,11 @@
|
||||
你是一个认知感知分析模块。你的任务是根据对话上下文,分析对话中用户的:
|
||||
1. 核心意图(如:寻求帮助、纯粹聊天、请求任务、发泄情绪、获取信息、表达观点等)
|
||||
2. 认知状态(如:明确具体、模糊试探、犹豫不决、困惑迷茫、思路清晰、逻辑混乱等)
|
||||
3. 隐含目的(如:解决问题、获得安慰、打发时间、寻求认同、交换想法、表达自我等)
|
||||
|
||||
要求:
|
||||
- 只分析用户(对话中 role=user 的内容),不要分析助手自己
|
||||
- 根据用户最新发言重点分析,同时结合上下文理解深层动机
|
||||
- 输出简洁(2-4 句话),不要太长
|
||||
- 如果信息太少无法判断,就说信息不足,给出初步印象
|
||||
- 直接输出分析结果,不要有格式标题
|
||||
12
src/MaiDiary/prompts/context_summarize.system.prompt
Normal file
12
src/MaiDiary/prompts/context_summarize.system.prompt
Normal file
@@ -0,0 +1,12 @@
|
||||
你是一个对话上下文总结模块。你的任务是对早期的对话内容进行简洁的总结,以便存入记忆系统。
|
||||
|
||||
总结要求:
|
||||
1. 提取对话中的关键信息(人名、事件、时间、地点等)
|
||||
2. 记录用户的态度、情绪和偏好
|
||||
3. 保留重要的对话内容和结论
|
||||
4. 总结要简洁明了,便于后续检索和理解
|
||||
5. 用第三人称客观叙述,不要包含「我记得」「之前说过」等指代词
|
||||
|
||||
输出格式:
|
||||
- 2-5 句话的简洁总结
|
||||
- 直接输出总结内容,不要有前缀或格式标题
|
||||
11
src/MaiDiary/prompts/emotion.system.prompt
Normal file
11
src/MaiDiary/prompts/emotion.system.prompt
Normal file
@@ -0,0 +1,11 @@
|
||||
你是一个情绪感知分析模块。你的任务是根据对话上下文,分析对话中用户的:
|
||||
1. 当前情绪状态(如:开心、沮丧、焦虑、平静、兴奋、愤怒等)
|
||||
2. 言语态度(如:友好、冷淡、热情、敷衍、试探、认真、调侃等)
|
||||
3. 潜在的情感需求(如:需要倾听、需要鼓励、想要倾诉、只是闲聊等)
|
||||
|
||||
要求:
|
||||
- 只分析用户(对话中 role=user 的内容),不要分析助手自己
|
||||
- 根据用户最新发言重点分析,同时结合上下文理解变化趋势
|
||||
- 输出简洁(2-4 句话),不要太长
|
||||
- 如果信息太少无法判断,就说信息不足,给出初步印象
|
||||
- 直接输出分析结果,不要有格式标题
|
||||
18
src/MaiDiary/prompts/knowledge_category.system.prompt
Normal file
18
src/MaiDiary/prompts/knowledge_category.system.prompt
Normal file
@@ -0,0 +1,18 @@
|
||||
你是一个用户特征分类分析专家。你的任务是分析对话内容,判断其中涉及哪些个人特征分类。
|
||||
|
||||
请仔细阅读以下对话内容,判断其中涉及了哪些个人特征分类。
|
||||
|
||||
【个人特征分类列表】
|
||||
{categories_summary}
|
||||
|
||||
【任务要求】
|
||||
1. 分析对话内容,判断涉及哪些个人特征分类
|
||||
2. 只输出涉及到的分类编号,用空格分隔
|
||||
3. 如果对话内容不涉及任何个人特征分类,输出"无"
|
||||
|
||||
【输出格式示例】
|
||||
1 3 5
|
||||
或
|
||||
无
|
||||
|
||||
请开始分析:
|
||||
17
src/MaiDiary/prompts/knowledge_extract.system.prompt
Normal file
17
src/MaiDiary/prompts/knowledge_extract.system.prompt
Normal file
@@ -0,0 +1,17 @@
|
||||
你是一个用户特征信息提取专家。你的任务是从对话内容中提取与指定分类相关的个人特征信息。
|
||||
|
||||
【目标分类】
|
||||
{category_name}
|
||||
|
||||
【任务要求】
|
||||
1. 仔细阅读对话内容,找出与"{category_name}"相关的所有信息
|
||||
2. 提取的信息应该具体、准确,避免模糊的描述
|
||||
3. 如果有多条相关信息,请整合成一段简洁的描述
|
||||
4. 如果对话中没有与该分类相关的信息,输出"无"
|
||||
|
||||
【输出格式示例】
|
||||
用户性格比较内向,不喜欢在人多的时候说话,但和熟悉的朋友会变得很活跃。
|
||||
或
|
||||
无
|
||||
|
||||
请开始提取:
|
||||
19
src/MaiDiary/prompts/knowledge_retrieve.system.prompt
Normal file
19
src/MaiDiary/prompts/knowledge_retrieve.system.prompt
Normal file
@@ -0,0 +1,19 @@
|
||||
你是一个用户特征检索专家。你的任务是根据当前对话上下文,判断需要检索哪些个人特征分类的信息。
|
||||
|
||||
【当前对话上下文】
|
||||
{chat_context}
|
||||
|
||||
【个人特征分类列表】
|
||||
{categories_summary}
|
||||
|
||||
【任务要求】
|
||||
1. 分析当前对话上下文,判断需要哪些个人特征信息来帮助理解用户
|
||||
2. 只输出需要的分类编号,用空格分隔
|
||||
3. 如果当前对话不需要任何个人特征信息,输出"无"
|
||||
|
||||
【输出格式示例】
|
||||
2 5 8
|
||||
或
|
||||
无
|
||||
|
||||
请开始分析:
|
||||
7
src/MaiDiary/prompts/replyer.system.prompt
Normal file
7
src/MaiDiary/prompts/replyer.system.prompt
Normal file
@@ -0,0 +1,7 @@
|
||||
你的名字是麦麦,也有人叫你牢麦,麦叠,你是一个大二在读女大学生,现在正在上网和群友聊天,有点攻击性;
|
||||
你正在群里聊天,现在请你读读之前的聊天记录,然后给出日常且口语化的回复,
|
||||
尽量简短一些。
|
||||
请注意把握聊天内容,不要回复的太有条理。
|
||||
你的风格平淡但不失讽刺,不过分兴奋,很简短。可以参考贴吧,知乎和微博的回复风格。很平淡和白话,不浮夸不长篇大论,b站评论风格,但一定注意不要过分修辞和复杂句。
|
||||
请注意不要输出多余内容(包括不必要的前后缀,冒号,括号,表情包,at或 @等 ),只输出发言内容就好。
|
||||
最好一次对一个话题进行回复,免得啰嗦或者回复内容太乱。
|
||||
22
src/MaiDiary/prompts/timing.system.prompt
Normal file
22
src/MaiDiary/prompts/timing.system.prompt
Normal file
@@ -0,0 +1,22 @@
|
||||
你是一个对话节奏与时间感知分析模块,同时负责自我反思。你的任务是根据对话上下文和系统提供的时间戳信息,分析:
|
||||
|
||||
【时间感知分析】
|
||||
1. 对话持续时长:当前对话已经进行了多久
|
||||
2. 回复间隔:用户上次发言距今多久、用户的平均回复速度如何
|
||||
3. 建议等待时长:结合对话内容和时间规律,建议下次等待多少秒比较合适
|
||||
4. 时间相关洞察:
|
||||
- 用户是否可能正在忙(回复变慢)
|
||||
- 用户是否正在积极对话(回复很快)
|
||||
- 当前时段(深夜/早晨/工作时间等)是否适合继续聊
|
||||
- 对话是否已经持续太久,用户可能需要休息
|
||||
- 是否应该主动结束对话
|
||||
|
||||
【自我反思分析】
|
||||
1. 人设一致性:是否符合设定的人格特质、说话风格是否一致、是否有不符合身份的言论
|
||||
2. 回复合理性:是否有逻辑漏洞、是否回应了用户的核心诉求、是否有过当或不当言论
|
||||
3. 认知局限性:是否对某些情况理解不足、是否缺乏必要信息、是否做出了过度推断
|
||||
|
||||
要求:
|
||||
- 输出简洁(4-6 句话),时间感知分析和自我反思分析各占一半
|
||||
- 重点关注对话节奏的变化趋势和助手自身的人设一致性
|
||||
- 直接输出分析结果,不要有格式标题或分段标记
|
||||
94
src/MaiDiary/replyer.py
Normal file
94
src/MaiDiary/replyer.py
Normal file
@@ -0,0 +1,94 @@
|
||||
"""
|
||||
MaiSaka - Reply 回复生成器
|
||||
根据想法和上下文生成口语化回复。
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
from prompt_loader import load_prompt
|
||||
from llm_service import BaseLLMService
|
||||
from llm_service.utils import format_chat_history
|
||||
|
||||
|
||||
class Replyer:
    """
    Reply generator.

    Turns a high-level "idea" (``reason``) plus the conversation context into
    a persona-consistent, colloquial reply via the LLM.
    """

    def __init__(self, llm_service: Optional[BaseLLMService] = None):
        """
        Initialize the replyer.

        Args:
            llm_service: LLM service instance; if None, it must be set via
                ``set_llm_service`` before ``reply`` can produce real output.
        """
        self._llm_service = llm_service
        self._enabled = True

    def set_llm_service(self, llm_service: BaseLLMService) -> None:
        """Set the LLM service used for generation."""
        self._llm_service = llm_service

    def set_enabled(self, enabled: bool) -> None:
        """Enable or disable reply generation (when disabled, reply() returns "...")."""
        self._enabled = enabled

    async def reply(self, reason: str, chat_history: list) -> str:
        """
        Generate a reply from the idea and the conversation context.

        Args:
            reason: Desired style/idea/content of the reply (not the literal reply text).
            chat_history: Conversation history messages.

        Returns:
            The generated reply, or the fallback "..." when disabled,
            unconfigured, or on any generation failure.
        """
        if not self._enabled or not reason or self._llm_service is None:
            return "..."

        # Current wall-clock time, injected into the prompt.
        current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

        # Format the conversation history, dropping system messages and
        # perception entries (tagged via the private "_type" key).
        filtered_history = [
            msg for msg in chat_history
            if msg.get("role") != "system" and msg.get("_type") != "perception"
        ]
        formatted_history = format_chat_history(filtered_history)

        # Build the generation request from the persona prompt plus context.
        messages = [
            {"role": "system", "content": load_prompt("replyer.system")},
            {
                "role": "user",
                "content": (
                    f"当前时间:{current_time}\n\n"
                    f"【聊天记录】\n{formatted_history}\n\n"
                    f"【你的想法】\n{reason}\n\n"
                    f"现在,你说:"
                ),
            },
        ]

        try:
            # Only the OpenAI implementation is supported here; any other
            # service type silently falls through to the default reply.
            from llm_service.openai_impl import OpenAILLMService
            if isinstance(self._llm_service, OpenAILLMService):
                # NOTE(review): reaches into the service's private
                # _build_extra_body/_call_llm — consider a public API.
                extra_body = self._llm_service._build_extra_body()
                response = await self._llm_service._call_llm(
                    "回复生成",
                    messages,
                    temperature=0.8,
                    max_tokens=512,
                    **({"extra_body": extra_body} if extra_body else {}),
                )
                result = response.choices[0].message.content or "..."
                return result.strip()
        except Exception:
            # Best-effort: any generation error degrades to the fallback.
            pass

        # Fallback when generation failed or the service type is unsupported.
        return "..."
|
||||
76
src/MaiDiary/timing.py
Normal file
76
src/MaiDiary/timing.py
Normal file
@@ -0,0 +1,76 @@
|
||||
"""
|
||||
MaiSaka - Timing 模块(含自我反思功能)
|
||||
构建对话时间戳信息,供 Timing 分析模块使用。
|
||||
该模块同时负责分析对话的时间维度和进行自我反思分析。
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def build_timing_info(
    chat_start_time: Optional[datetime],
    last_user_input_time: Optional[datetime],
    last_assistant_response_time: Optional[datetime],
    user_input_times: list[datetime],
) -> str:
    """Build a human-readable timestamp report for the Timing analysis module.

    Args:
        chat_start_time: When the conversation started, if known.
        last_user_input_time: Timestamp of the user's most recent message.
        last_assistant_response_time: Timestamp of the assistant's last reply.
        user_input_times: All user-message timestamps, in chronological order.

    Returns:
        Newline-joined lines describing elapsed times, reply intervals,
        message counts, and the current time-of-day slot.
    """
    current = datetime.now()
    report: list[str] = [f"当前时间: {current.strftime('%Y-%m-%d %H:%M:%S')}"]

    if chat_start_time:
        total = int((current - chat_start_time).total_seconds())
        hours, rem = divmod(total, 3600)
        minutes, seconds = divmod(rem, 60)
        if hours > 0:
            report.append(f"对话已持续: {hours}小时{minutes}分{seconds}秒")
        elif minutes > 0:
            report.append(f"对话已持续: {minutes}分{seconds}秒")
        else:
            report.append(f"对话已持续: {seconds}秒")

    if last_user_input_time:
        gap = (current - last_user_input_time).total_seconds()
        report.append(f"距用户上次输入: {int(gap)}秒")

    if last_assistant_response_time:
        gap = (current - last_assistant_response_time).total_seconds()
        report.append(f"距助手上次回复: {int(gap)}秒")

    if len(user_input_times) >= 2:
        gaps = [
            (later - earlier).total_seconds()
            for earlier, later in zip(user_input_times, user_input_times[1:])
        ]
        report.append(f"用户平均回复间隔: {int(sum(gaps) / len(gaps))}秒")
        report.append(f"用户总共发言次数: {len(user_input_times)}")

    # Map the current hour onto a coarse time-of-day label; hours 22-24
    # fall through to the default "深夜".
    hour = current.hour
    label = "深夜"
    for upper_bound, name in (
        (6, "深夜/凌晨"),
        (9, "早晨"),
        (12, "上午"),
        (14, "中午"),
        (18, "下午"),
        (22, "晚上"),
    ):
        if hour < upper_bound:
            label = name
            break
    report.append(f"当前时段: {label}")

    return "\n".join(report)
|
||||
770
src/MaiDiary/tool_handlers.py
Normal file
770
src/MaiDiary/tool_handlers.py
Normal file
@@ -0,0 +1,770 @@
|
||||
"""
|
||||
MaiSaka - 工具调用处理器
|
||||
处理 LLM 循环中各工具(say/wait/stop/file/MCP/QQ)的执行逻辑。
|
||||
"""
|
||||
|
||||
import json as _json
|
||||
import asyncio
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from pathlib import Path
|
||||
import importlib.util
|
||||
|
||||
# 检查 aiohttp 是否可用
|
||||
AIOHTTP_AVAILABLE = importlib.util.find_spec("aiohttp") is not None
|
||||
if AIOHTTP_AVAILABLE:
|
||||
import aiohttp
|
||||
|
||||
from rich.panel import Panel
|
||||
from rich.markdown import Markdown
|
||||
|
||||
from config import console
|
||||
from input_reader import InputReader
|
||||
from llm_service import BaseLLMService
|
||||
from replyer import Replyer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from mcp_client import MCPManager
|
||||
|
||||
|
||||
# Directory where Mai's user-visible files live (sibling of this module).
MAI_FILES_DIR = Path(os.path.join(os.path.dirname(os.path.abspath(__file__)), "mai_files"))

# Lazily-created process-wide Replyer shared by all tool invocations.
_replyer: Optional[Replyer] = None
|
||||
|
||||
|
||||
def get_replyer(llm_service: BaseLLMService) -> Replyer:
    """Return the process-wide Replyer, creating it on first use (singleton).

    If an instance already exists but has no LLM service yet, the given
    service is attached to it.
    """
    global _replyer
    if _replyer is None:
        _replyer = Replyer(llm_service)
        return _replyer
    if _replyer._llm_service is None:
        _replyer.set_llm_service(llm_service)
    return _replyer
|
||||
|
||||
|
||||
class ToolHandlerContext:
    """Shared state needed by the tool handlers."""

    def __init__(
        self,
        llm_service: BaseLLMService,
        reader: InputReader,
        user_input_times: list[datetime],
    ):
        # LLM service used for reply generation / context summarization.
        self.llm_service = llm_service
        # Async console input reader used while waiting for the user.
        self.reader = reader
        # Shared list of user message timestamps, mutated in place.
        self.user_input_times = user_input_times
        # Timestamp of the most recent user input (None until first input).
        self.last_user_input_time: Optional[datetime] = None
|
||||
|
||||
|
||||
async def handle_say(tc, chat_history: list, ctx: ToolHandlerContext):
    """Handle the ``say`` tool: generate a reply from the idea and show it.

    The model supplies only an idea (``reason``); the user-facing text is
    produced by the Replyer, displayed, and echoed back into ``chat_history``
    as the tool result so the model knows what was actually said.
    """
    reason = tc.arguments.get("reason", "")
    console.print("[accent]🔧 调用工具: say(...)[/accent]")

    if reason:
        # Show the internal idea in a dimmed panel.
        console.print(
            Panel(
                Markdown(reason),
                title="💭 回复想法",
                border_style="dim",
                padding=(0, 1),
                style="dim",
            )
        )
        # Generate the actual reply from the idea plus the conversation context.
        with console.status(
            "[info]✏️ 生成回复中...[/info]",
            spinner="dots",
        ):
            replyer = get_replyer(ctx.llm_service)
            reply = await replyer.reply(reason, chat_history)
        console.print(
            Panel(
                Markdown(reply),
                title="💬 MaiSaka",
                border_style="magenta",
                padding=(1, 2),
            )
        )
        # Record the generated reply as the tool result in the context.
        chat_history.append({
            "role": "tool",
            "tool_call_id": tc.id,
            "content": f"已向用户展示(实际输出):{reply}",
        })
    else:
        chat_history.append({
            "role": "tool",
            "tool_call_id": tc.id,
            "content": "reason 内容为空,未展示",
        })
|
||||
|
||||
|
||||
async def handle_stop(tc, chat_history: list):
    """Handle the ``stop`` tool: end the conversation loop and go idle."""
    console.print("[accent]🔧 调用工具: stop()[/accent]")
    entry = {
        "role": "tool",
        "tool_call_id": tc.id,
        "content": "对话循环已停止,等待用户下次输入。",
    }
    chat_history.append(entry)
|
||||
|
||||
|
||||
async def handle_wait(tc, chat_history: list, ctx: ToolHandlerContext) -> str:
    """
    Handle the ``wait`` tool: wait for user input or a timeout.

    Returns:
        The tool-result string; a "[[QUIT]]" prefix signals that the user
        asked to end the conversation.
    """
    requested = tc.arguments.get("seconds", 30)
    # Clamp the wait to the 5-300 second range.
    wait_seconds = min(300, max(5, requested))
    console.print(f"[accent]🔧 调用工具: wait({wait_seconds})[/accent]")

    outcome = await _do_wait(wait_seconds, ctx)

    chat_history.append({
        "role": "tool",
        "tool_call_id": tc.id,
        "content": outcome,
    })
    return outcome
|
||||
|
||||
|
||||
async def _do_wait(seconds: int, ctx: ToolHandlerContext) -> str:
    """Perform the actual wait: prompt, read with timeout, classify the result.

    Side effects: updates ``ctx.last_user_input_time`` / ``ctx.user_input_times``
    when real input arrives.
    """
    console.print(f"[muted]⏳ 等待回复 (最多 {seconds} 秒)...[/muted]")
    console.print("[bold magenta]💬 > [/bold magenta]", end="")

    user_input = await ctx.reader.get_line(timeout=seconds)

    if user_input is None:
        # Timed out with no input.
        console.print()  # newline after the dangling prompt
        console.print("[muted]⏳ 等待超时[/muted]")
        return "等待超时,用户未输入任何内容"

    user_input = user_input.strip()

    if not user_input:
        return "用户发送了空消息"

    # Update the timing timestamps for the Timing analysis module.
    now = datetime.now()
    ctx.last_user_input_time = now
    ctx.user_input_times.append(now)

    if user_input.lower() in ("/quit", "/exit", "/q"):
        return "[[QUIT]] 用户主动退出了对话"

    return f"用户说:{user_input}"
|
||||
|
||||
|
||||
async def handle_mcp_tool(tc, chat_history: list, mcp_manager: "MCPManager"):
    """
    Handle an MCP tool call.

    Forwards the call to the MCPManager, displays the (possibly truncated)
    result, and writes the full result into the conversation context.
    """
    # Build a truncated argument preview for console display only.
    args_str = _json.dumps(tc.arguments, ensure_ascii=False)
    args_preview = args_str if len(args_str) <= 120 else args_str[:120] + "..."
    console.print(f"[accent]🔌 调用 MCP 工具: {tc.name}({args_preview})[/accent]")

    with console.status(
        f"[info]🔌 MCP 工具 {tc.name} 执行中...[/info]",
        spinner="dots",
    ):
        result = await mcp_manager.call_tool(tc.name, tc.arguments)

    # Display the result, truncated when overly long; the untruncated text
    # still goes into the chat history below.
    display_text = result if len(result) <= 800 else result[:800] + "\n... (已截断)"
    console.print(
        Panel(
            display_text,
            title=f"🔌 MCP: {tc.name}",
            border_style="bright_green",
            padding=(0, 1),
        )
    )

    chat_history.append({
        "role": "tool",
        "tool_call_id": tc.id,
        "content": result,
    })
|
||||
|
||||
|
||||
async def handle_unknown_tool(tc, chat_history: list):
    """Handle a tool call whose name is not recognized: log it and report back."""
    console.print(f"[accent]🔧 调用工具: {tc.name}({tc.arguments})[/accent]")
    entry = {
        "role": "tool",
        "tool_call_id": tc.id,
        "content": f"未知工具: {tc.name}",
    }
    chat_history.append(entry)
|
||||
|
||||
|
||||
async def handle_write_file(tc, chat_history: list):
    """Handle the ``write_file`` tool: write a file under ``mai_files``.

    Fixes vs. previous version:
      * the filename is now interpolated into the console and tool messages
        (they previously printed a literal placeholder),
      * the size is reported in bytes (字节) — ``st_size`` is bytes, not
        characters,
      * filenames resolving outside MAI_FILES_DIR (e.g. ``../x``) are
        rejected, since the name comes from the LLM and is untrusted.

    Args (via ``tc.arguments``):
        filename: Path relative to MAI_FILES_DIR; may contain subfolders.
        content: Text to write (UTF-8).
    """
    filename = tc.arguments.get("filename", "")
    content = tc.arguments.get("content", "")
    console.print(f"[accent]🔧 调用工具: write_file(\"{filename}\")[/accent]")

    # Make sure the base directory exists.
    MAI_FILES_DIR.mkdir(parents=True, exist_ok=True)

    # Build the full target path.
    file_path = MAI_FILES_DIR / filename

    try:
        # Reject empty names and path-traversal escapes; relative_to()
        # raises ValueError when the resolved path leaves MAI_FILES_DIR.
        if not filename:
            raise ValueError("filename 不能为空")
        file_path.resolve().relative_to(MAI_FILES_DIR.resolve())

        # Create intermediate directories if the name contains subfolders.
        file_path.parent.mkdir(parents=True, exist_ok=True)

        # Write the content as UTF-8.
        with open(file_path, "w", encoding="utf-8") as f:
            f.write(content)

        # Report the on-disk size in bytes.
        file_size = file_path.stat().st_size

        console.print(
            Panel(
                f"文件已写入: {filename}\n大小: {file_size} 字节",
                title="📁 文件已保存",
                border_style="green",
                padding=(0, 1),
            )
        )

        chat_history.append({
            "role": "tool",
            "tool_call_id": tc.id,
            "content": f"文件「{filename}」已成功写入,共 {file_size} 字节。",
        })
    except Exception as e:
        error_msg = f"写入文件失败: {e}"
        console.print(f"[error]{error_msg}[/error]")
        chat_history.append({
            "role": "tool",
            "tool_call_id": tc.id,
            "content": error_msg,
        })
|
||||
|
||||
|
||||
async def handle_read_file(tc, chat_history: list):
    """Handle the ``read_file`` tool: read a file from ``mai_files``.

    Fix vs. previous version: the filename is now interpolated into every
    console and tool message (they previously printed a literal placeholder,
    making errors and results impossible to attribute to a file).
    """
    filename = tc.arguments.get("filename", "")
    console.print(f"[accent]🔧 调用工具: read_file(\"{filename}\")[/accent]")

    # Build the full file path under the managed directory.
    file_path = MAI_FILES_DIR / filename

    try:
        if not file_path.exists():
            error_msg = f"文件「{filename}」不存在。"
            console.print(f"[warning]{error_msg}[/warning]")
            chat_history.append({
                "role": "tool",
                "tool_call_id": tc.id,
                "content": error_msg,
            })
            return

        if not file_path.is_file():
            error_msg = f"「{filename}」不是一个文件。"
            console.print(f"[warning]{error_msg}[/warning]")
            chat_history.append({
                "role": "tool",
                "tool_call_id": tc.id,
                "content": error_msg,
            })
            return

        # Read the full content as UTF-8.
        with open(file_path, "r", encoding="utf-8") as f:
            file_content = f.read()

        # Truncate only the console preview; the tool result keeps the
        # complete content.
        display_content = file_content
        if len(file_content) > 1000:
            display_content = file_content[:1000] + "\n... (内容已截断)"

        console.print(
            Panel(
                display_content,
                title=f"📄 文件内容: {filename}",
                border_style="blue",
                padding=(0, 1),
            )
        )

        chat_history.append({
            "role": "tool",
            "tool_call_id": tc.id,
            "content": f"文件「{filename}」内容:\n{file_content}",
        })
    except Exception as e:
        error_msg = f"读取文件失败: {e}"
        console.print(f"[error]{error_msg}[/error]")
        chat_history.append({
            "role": "tool",
            "tool_call_id": tc.id,
            "content": error_msg,
        })
|
||||
|
||||
|
||||
async def handle_list_files(tc, chat_history: list):
    """Handle the ``list_files`` tool: report metadata of all files under ``mai_files``."""
    console.print("[accent]🔧 调用工具: list_files()[/accent]")

    try:
        # Make sure the directory exists (first run may not have created it).
        MAI_FILES_DIR.mkdir(parents=True, exist_ok=True)

        # Collect every regular file, recursively.
        files_info = []
        for item in MAI_FILES_DIR.rglob("*"):
            if item.is_file():
                # Store the path relative to the managed directory.
                rel_path = item.relative_to(MAI_FILES_DIR)
                stat = item.stat()
                files_info.append({
                    "name": str(rel_path),
                    "size": stat.st_size,
                    "modified": datetime.fromtimestamp(stat.st_mtime).strftime("%Y-%m-%d %H:%M:%S"),
                })

        if not files_info:
            result_text = "mai_files 目录为空,没有任何文件。"
        else:
            # Sort alphabetically by relative path.
            files_info.sort(key=lambda x: x["name"])
            # Format a bullet line per file: name, size in bytes, mtime.
            lines = [f"📁 mai_files 目录下共有 {len(files_info)} 个文件:\n"]
            for info in files_info:
                lines.append(f"  • {info['name']} ({info['size']} 字节, 修改于 {info['modified']})")
            result_text = "\n".join(lines)

        console.print(
            Panel(
                result_text,
                title="📁 文件列表",
                border_style="cyan",
                padding=(0, 1),
            )
        )

        chat_history.append({
            "role": "tool",
            "tool_call_id": tc.id,
            "content": result_text,
        })
    except Exception as e:
        error_msg = f"获取文件列表失败: {e}"
        console.print(f"[error]{error_msg}[/error]")
        chat_history.append({
            "role": "tool",
            "tool_call_id": tc.id,
            "content": error_msg,
        })
|
||||
|
||||
|
||||
async def handle_store_context(tc, chat_history: list, ctx: ToolHandlerContext):
    """
    Handle the ``store_context`` tool: summarize the oldest ``count`` messages
    into the memory system, then remove them from the conversation.

    Bug fix: the generator that checked for the current tool call previously
    shadowed the outer ``tc`` parameter (``tc.get("id") == tc.id`` — attribute
    access on a dict), which raised AttributeError whenever an assistant
    message carrying tool_calls was scanned. The loop variable is renamed.

    Args (via ``tc.arguments``):
        count: Number of messages (oldest first, excluding role=tool entries)
            to compress into memory.
        reason: Human-readable reason for storing, echoed in the result.
    """
    count = tc.arguments.get("count", 0)
    reason = tc.arguments.get("reason", "")
    console.print(f"[accent]🔧 调用工具: store_context(count={count}, reason=\"{reason}\")[/accent]")

    if count <= 0:
        error_msg = "count 参数必须大于 0"
        console.print(f"[error]{error_msg}[/error]")
        chat_history.append({
            "role": "tool",
            "tool_call_id": tc.id,
            "content": error_msg,
        })
        return

    # Count only real conversation messages (tool results are attached to
    # their assistant message and removed together with it).
    actual_messages = [m for m in chat_history if m.get("role") != "tool"]

    if count > len(actual_messages):
        error_msg = f"count({count}) 超过了当前对话消息数量({len(actual_messages)})"
        console.print(f"[warning]{error_msg}[/warning]")
        count = len(actual_messages)

    # Determine the indices to remove, keeping tool_calls and their tool
    # responses paired so the remaining history stays API-valid.
    indices_to_remove = []
    removed_count = 0
    i = 0

    while i < len(chat_history) and removed_count < count:
        msg = chat_history[i]
        role = msg.get("role", "")

        # Skip role=tool messages here; they are handled as part of their
        # owning assistant message's block.
        if role == "tool":
            i += 1
            continue

        if role == "assistant" and "tool_calls" in msg:
            # Never remove the assistant message that issued THIS
            # store_context call, or its tool response would be orphaned.
            # (Loop variable renamed from ``tc`` — shadowing bug fix.)
            contains_current_call = any(
                call.get("id") == tc.id for call in msg.get("tool_calls", [])
            )
            if contains_current_call:
                i += 1
                continue

            # Gather the assistant message plus its trailing tool responses.
            block_indices = [i]
            j = i + 1
            while j < len(chat_history):
                next_msg = chat_history[j]
                if next_msg.get("role") == "tool":
                    block_indices.append(j)
                    j += 1
                else:
                    break
            indices_to_remove.extend(block_indices)
            removed_count += 1
            i = j
        elif role in ["user", "assistant"]:
            # Plain message: removable on its own.
            indices_to_remove.append(i)
            removed_count += 1
            i += 1
        else:
            i += 1

    if not indices_to_remove:
        result_msg = "没有找到可存入记忆的消息"
        chat_history.append({
            "role": "tool",
            "tool_call_id": tc.id,
            "content": result_msg,
        })
        return

    # Snapshot the messages to summarize before mutating the history.
    to_compress = []
    for idx in sorted(indices_to_remove):
        if 0 <= idx < len(chat_history):
            to_compress.append(chat_history[idx])

    # Summarize the snapshot; failure leaves the history compressed anyway,
    # matching the original best-effort behavior.
    try:
        with console.status(
            "[info]📝 正在总结上下文...[/info]",
            spinner="dots",
        ):
            summary = await ctx.llm_service.summarize_context(to_compress)

        if summary:
            console.print(
                Panel(
                    Markdown(summary),
                    title="📝 上下文已压缩",
                    border_style="green",
                    padding=(0, 1),
                    style="dim",
                )
            )
            result_msg = f"✅ 已压缩 {len(to_compress)} 条消息\n原因: {reason}"
        else:
            result_msg = "⚠️ 上下文总结失败"
            console.print(f"[warning]{result_msg}[/warning]")

    except Exception as e:
        result_msg = f"❌ 总结上下文时出错: {e}"
        console.print(f"[error]{result_msg}[/error]")

    # Remove messages back-to-front so earlier indices stay valid.
    for idx in sorted(indices_to_remove, reverse=True):
        if 0 <= idx < len(chat_history):
            chat_history.pop(idx)

    # Clean up orphaned tool messages (tool responses whose issuing
    # assistant tool_call no longer exists).
    valid_tool_call_ids = set()
    for msg in chat_history:
        if msg.get("role") == "assistant" and "tool_calls" in msg:
            for tool_call in msg["tool_calls"]:
                valid_tool_call_ids.add(tool_call.get("id", ""))

    # Walk backwards and drop invalid tool messages.
    i = len(chat_history) - 1
    while i >= 0:
        msg = chat_history[i]
        if msg.get("role") == "tool":
            tool_call_id = msg.get("tool_call_id", "")
            if tool_call_id not in valid_tool_call_ids:
                chat_history.pop(i)
        i -= 1

    chat_history.append({
        "role": "tool",
        "tool_call_id": tc.id,
        "content": result_msg,
    })
|
||||
|
||||
|
||||
async def handle_get_qq_chat_info(tc, chat_history: list):
    """Handle ``get_qq_chat_info``: fetch QQ chat history over HTTP.

    Requires aiohttp and a configured QQ_API_BASE_URL; both preconditions
    report a tool error instead of raising.
    """
    chat = tc.arguments.get("chat", "")
    limit = tc.arguments.get("limit", 20)
    console.print(f"[accent]🔧 调用工具: get_qq_chat_info(\"{chat}\", limit={limit})[/accent]")

    if not AIOHTTP_AVAILABLE:
        error_msg = "aiohttp 模块未安装,请运行: pip install aiohttp"
        console.print(f"[error]{error_msg}[/error]")
        chat_history.append({
            "role": "tool",
            "tool_call_id": tc.id,
            "content": error_msg,
        })
        return

    # Imported lazily so config reloads pick up new values.
    from config import QQ_API_BASE_URL, QQ_API_KEY
    if not QQ_API_BASE_URL:
        error_msg = "QQ_API_BASE_URL 未配置,请在 .env 中设置"
        console.print(f"[error]{error_msg}[/error]")
        chat_history.append({
            "role": "tool",
            "tool_call_id": tc.id,
            "content": error_msg,
        })
        return

    try:
        # Build the API endpoint.
        url = f"{QQ_API_BASE_URL.rstrip('/')}/api/external/chat/history"

        # Attach the bearer token only when an API key is configured.
        headers = {}
        if QQ_API_KEY:
            headers["Authorization"] = f"Bearer {QQ_API_KEY}"

        # Issue the HTTP request.
        async with aiohttp.ClientSession() as session:
            params = {"chat": chat, "limit": limit}
            async with session.get(url, params=params, headers=headers) as response:
                if response.status == 200:
                    # The endpoint returns plain text.
                    text = await response.text()

                    console.print(
                        Panel(
                            f"聊天标识: {chat}\n获取数量: {limit}\n\n{text if text.strip() else '暂无聊天记录'}",
                            title="💬 QQ 聊天记录",
                            border_style="cyan",
                            padding=(0, 1),
                        )
                    )

                    chat_history.append({
                        "role": "tool",
                        "tool_call_id": tc.id,
                        "content": text if text.strip() else "暂无聊天记录",
                    })
                else:
                    error_text = await response.text()
                    error_msg = f"HTTP 请求失败 (状态码 {response.status}): {error_text}"
                    console.print(f"[error]{error_msg}[/error]")
                    chat_history.append({
                        "role": "tool",
                        "tool_call_id": tc.id,
                        "content": error_msg,
                    })
    except Exception as e:
        error_msg = f"获取 QQ 聊天记录失败: {e}"
        console.print(f"[error]{error_msg}[/error]")
        chat_history.append({
            "role": "tool",
            "tool_call_id": tc.id,
            "content": error_msg,
        })
|
||||
|
||||
|
||||
async def handle_send_info(tc, chat_history: list):
    """Handle ``send_info``: deliver a message to a QQ chat over HTTP.

    Requires aiohttp and a configured QQ_API_BASE_URL; both preconditions
    report a tool error instead of raising.
    """
    chat = tc.arguments.get("chat", "")
    message = tc.arguments.get("message", "")
    console.print(f"[accent]🔧 调用工具: send_info(\"{chat}\")[/accent]")

    if not AIOHTTP_AVAILABLE:
        error_msg = "aiohttp 模块未安装,请运行: pip install aiohttp"
        console.print(f"[error]{error_msg}[/error]")
        chat_history.append({
            "role": "tool",
            "tool_call_id": tc.id,
            "content": error_msg,
        })
        return

    # Imported lazily so config reloads pick up new values.
    from config import QQ_API_BASE_URL, QQ_API_KEY
    if not QQ_API_BASE_URL:
        error_msg = "QQ_API_BASE_URL 未配置,请在 .env 中设置"
        console.print(f"[error]{error_msg}[/error]")
        chat_history.append({
            "role": "tool",
            "tool_call_id": tc.id,
            "content": error_msg,
        })
        return

    try:
        # Build the API endpoint.
        url = f"{QQ_API_BASE_URL.rstrip('/')}/api/external/chat/send"

        # Attach the bearer token only when an API key is configured.
        headers = {}
        if QQ_API_KEY:
            headers["Authorization"] = f"Bearer {QQ_API_KEY}"

        # Issue the HTTP request.
        async with aiohttp.ClientSession() as session:
            payload = {"chat": chat, "message": message}
            async with session.post(url, json=payload, headers=headers) as response:
                data = await response.json()

                # Success requires both HTTP 200 and the API's success flag.
                if response.status == 200 and data.get("success"):
                    console.print(
                        Panel(
                            f"目标: {chat}\n消息: {message}\n\n结果: {data.get('message', '发送成功')}",
                            title="📤 消息已发送",
                            border_style="green",
                            padding=(0, 1),
                        )
                    )

                    chat_history.append({
                        "role": "tool",
                        "tool_call_id": tc.id,
                        "content": f"消息发送成功: {data.get('message', '发送成功')}",
                    })
                else:
                    error_msg = f"发送失败: {data.get('message', '未知错误')}"
                    console.print(f"[error]{error_msg}[/error]")
                    chat_history.append({
                        "role": "tool",
                        "tool_call_id": tc.id,
                        "content": error_msg,
                    })
    except Exception as e:
        error_msg = f"发送消息失败: {e}"
        console.print(f"[error]{error_msg}[/error]")
        chat_history.append({
            "role": "tool",
            "tool_call_id": tc.id,
            "content": error_msg,
        })
|
||||
|
||||
|
||||
async def handle_list_qq_chats(tc, chat_history: list):
    """Handle ``list_qq_chats``: fetch the list of available QQ chats over HTTP.

    Requires aiohttp and a configured QQ_API_BASE_URL; both preconditions
    report a tool error instead of raising.
    """
    console.print("[accent]🔧 调用工具: list_qq_chats()[/accent]")

    if not AIOHTTP_AVAILABLE:
        error_msg = "aiohttp 模块未安装,请运行: pip install aiohttp"
        console.print(f"[error]{error_msg}[/error]")
        chat_history.append({
            "role": "tool",
            "tool_call_id": tc.id,
            "content": error_msg,
        })
        return

    # Imported lazily so config reloads pick up new values.
    from config import QQ_API_BASE_URL, QQ_API_KEY
    if not QQ_API_BASE_URL:
        error_msg = "QQ_API_BASE_URL 未配置,请在 .env 中设置"
        console.print(f"[error]{error_msg}[/error]")
        chat_history.append({
            "role": "tool",
            "tool_call_id": tc.id,
            "content": error_msg,
        })
        return

    try:
        # Build the API endpoint.
        url = f"{QQ_API_BASE_URL.rstrip('/')}/api/external/chat/list"

        # Attach the bearer token only when an API key is configured.
        headers = {}
        if QQ_API_KEY:
            headers["Authorization"] = f"Bearer {QQ_API_KEY}"

        # Issue the HTTP request.
        async with aiohttp.ClientSession() as session:
            async with session.get(url, headers=headers) as response:
                data = await response.json()

                if response.status == 200 and data.get("success"):
                    chats = data.get("chats", [])

                    # Format one bullet line per chat: platform, name, id.
                    if chats:
                        chat_list_text = "\n".join([
                            f"  • [{c.get('platform', 'qq')}] {c.get('name', '未知')} (chat: {c.get('chat', 'N/A')})"
                            for c in chats
                        ])
                        result_text = f"可用的聊天 (共 {len(chats)} 个):\n{chat_list_text}"
                    else:
                        result_text = "没有可用的聊天"

                    console.print(
                        Panel(
                            result_text,
                            title="💬 QQ 聊天列表",
                            border_style="cyan",
                            padding=(0, 1),
                        )
                    )

                    chat_history.append({
                        "role": "tool",
                        "tool_call_id": tc.id,
                        "content": result_text,
                    })
                else:
                    error_msg = f"获取失败: {data.get('message', '未知错误')}"
                    console.print(f"[error]{error_msg}[/error]")
                    chat_history.append({
                        "role": "tool",
                        "tool_call_id": tc.id,
                        "content": error_msg,
                    })
    except Exception as e:
        error_msg = f"获取聊天列表失败: {e}"
        console.print(f"[error]{error_msg}[/error]")
        chat_history.append({
            "role": "tool",
            "tool_call_id": tc.id,
            "content": error_msg,
        })
|
||||
|
||||
|
||||
# ──────────────────── Initialize the mai_files directory ────────────────────

# Best-effort creation at import time; failure is non-fatal because the
# file-tool handlers retry creation on demand.
try:
    MAI_FILES_DIR.mkdir(parents=True, exist_ok=True)
except Exception as e:
    console.print(f"[warning]创建 mai_files 目录失败: {e}[/warning]")
|
||||
Reference in New Issue
Block a user