Ruff Format

This commit is contained in:
DrSmoothl
2026-03-13 11:45:26 +08:00
parent 2a510312bc
commit a576313b22
70 changed files with 956 additions and 731 deletions

View File

@@ -14,9 +14,7 @@ class BetterFrequencyPlugin(MaiBotPlugin):
description="设置当前聊天的talk_frequency值/chat talk_frequency <数字> 或 /chat t <数字>", description="设置当前聊天的talk_frequency值/chat talk_frequency <数字> 或 /chat t <数字>",
pattern=r"^/chat\s+(?:talk_frequency|t)\s+(?P<value>[+-]?\d*\.?\d+)$", pattern=r"^/chat\s+(?:talk_frequency|t)\s+(?P<value>[+-]?\d*\.?\d+)$",
) )
async def handle_set_talk_frequency( async def handle_set_talk_frequency(self, stream_id: str = "", matched_groups: dict | None = None, **kwargs):
self, stream_id: str = "", matched_groups: dict | None = None, **kwargs
):
"""设置当前聊天的 talk_frequency""" """设置当前聊天的 talk_frequency"""
if not matched_groups or "value" not in matched_groups: if not matched_groups or "value" not in matched_groups:
return False, "命令格式错误", False return False, "命令格式错误", False

View File

@@ -116,7 +116,9 @@ class HelloWorldPlugin(MaiBotPlugin):
print(f"接收到消息: {raw}") print(f"接收到消息: {raw}")
return True, True, "消息已打印", None, None return True, True, "消息已打印", None, None
@EventHandler("forward_messages_handler", description="把接收到的消息转发到指定聊天ID", event_type=EventType.ON_MESSAGE) @EventHandler(
"forward_messages_handler", description="把接收到的消息转发到指定聊天ID", event_type=EventType.ON_MESSAGE
)
async def handle_forward_messages(self, message=None, stream_id: str = "", **kwargs): async def handle_forward_messages(self, message=None, stream_id: str = "", **kwargs):
"""收集消息并定期转发""" """收集消息并定期转发"""
if not message: if not message:

View File

@@ -19,6 +19,7 @@ sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "packages", "ma
# ─── 协议层测试 ─────────────────────────────────────────── # ─── 协议层测试 ───────────────────────────────────────────
class TestProtocol: class TestProtocol:
"""协议层测试""" """协议层测试"""
@@ -111,6 +112,7 @@ class TestProtocol:
# ─── 传输层测试 ─────────────────────────────────────────── # ─── 传输层测试 ───────────────────────────────────────────
class TestTransport: class TestTransport:
"""传输层测试""" """传输层测试"""
@@ -198,6 +200,7 @@ class TestTransport:
# ─── Host 层测试 ────────────────────────────────────────── # ─── Host 层测试 ──────────────────────────────────────────
class TestHost: class TestHost:
"""Host 端基础设施测试""" """Host 端基础设施测试"""
@@ -244,6 +247,7 @@ class TestHost:
# ─── SDK 测试 ───────────────────────────────────────────── # ─── SDK 测试 ─────────────────────────────────────────────
class TestSDK: class TestSDK:
"""SDK 框架测试""" """SDK 框架测试"""
@@ -312,6 +316,7 @@ class TestSDK:
# ─── 端到端集成测试 ──────────────────────────────────────── # ─── 端到端集成测试 ────────────────────────────────────────
class TestE2E: class TestE2E:
"""端到端集成测试Host + Runner 通信)""" """端到端集成测试Host + Runner 通信)"""
@@ -392,6 +397,7 @@ class TestE2E:
# ─── Manifest 校验测试 ───────────────────────────────────── # ─── Manifest 校验测试 ─────────────────────────────────────
class TestManifestValidator: class TestManifestValidator:
"""Manifest 校验器测试""" """Manifest 校验器测试"""
@@ -489,6 +495,7 @@ class TestVersionComparator:
# ─── 依赖解析测试 ────────────────────────────────────────── # ─── 依赖解析测试 ──────────────────────────────────────────
class TestDependencyResolution: class TestDependencyResolution:
"""插件依赖解析测试""" """插件依赖解析测试"""
@@ -498,8 +505,16 @@ class TestDependencyResolution:
loader = PluginLoader() loader = PluginLoader()
candidates = { candidates = {
"core": ("dir_core", {"name": "core", "version": "1.0", "description": "d", "author": "a"}, "plugin.py"), "core": ("dir_core", {"name": "core", "version": "1.0", "description": "d", "author": "a"}, "plugin.py"),
"auth": ("dir_auth", {"name": "auth", "version": "1.0", "description": "d", "author": "a", "dependencies": ["core"]}, "plugin.py"), "auth": (
"api": ("dir_api", {"name": "api", "version": "1.0", "description": "d", "author": "a", "dependencies": ["core", "auth"]}, "plugin.py"), "dir_auth",
{"name": "auth", "version": "1.0", "description": "d", "author": "a", "dependencies": ["core"]},
"plugin.py",
),
"api": (
"dir_api",
{"name": "api", "version": "1.0", "description": "d", "author": "a", "dependencies": ["core", "auth"]},
"plugin.py",
),
} }
order, failed = loader._resolve_dependencies(candidates) order, failed = loader._resolve_dependencies(candidates)
@@ -512,7 +527,17 @@ class TestDependencyResolution:
loader = PluginLoader() loader = PluginLoader()
candidates = { candidates = {
"plugin_a": ("dir_a", {"name": "plugin_a", "version": "1.0", "description": "d", "author": "a", "dependencies": ["nonexistent"]}, "plugin.py"), "plugin_a": (
"dir_a",
{
"name": "plugin_a",
"version": "1.0",
"description": "d",
"author": "a",
"dependencies": ["nonexistent"],
},
"plugin.py",
),
} }
order, failed = loader._resolve_dependencies(candidates) order, failed = loader._resolve_dependencies(candidates)
@@ -524,8 +549,16 @@ class TestDependencyResolution:
loader = PluginLoader() loader = PluginLoader()
candidates = { candidates = {
"a": ("dir_a", {"name": "a", "version": "1.0", "description": "d", "author": "x", "dependencies": ["b"]}, "p.py"), "a": (
"b": ("dir_b", {"name": "b", "version": "1.0", "description": "d", "author": "x", "dependencies": ["a"]}, "p.py"), "dir_a",
{"name": "a", "version": "1.0", "description": "d", "author": "x", "dependencies": ["b"]},
"p.py",
),
"b": (
"dir_b",
{"name": "b", "version": "1.0", "description": "d", "author": "x", "dependencies": ["a"]},
"p.py",
),
} }
order, failed = loader._resolve_dependencies(candidates) order, failed = loader._resolve_dependencies(candidates)
@@ -534,6 +567,7 @@ class TestDependencyResolution:
# ─── Host-side ComponentRegistry 测试 ────────────────────── # ─── Host-side ComponentRegistry 测试 ──────────────────────
class TestComponentRegistry: class TestComponentRegistry:
"""Host-side 组件注册表测试""" """Host-side 组件注册表测试"""
@@ -541,17 +575,32 @@ class TestComponentRegistry:
from src.plugin_runtime.host.component_registry import ComponentRegistry from src.plugin_runtime.host.component_registry import ComponentRegistry
reg = ComponentRegistry() reg = ComponentRegistry()
reg.register_component("greet", "action", "plugin_a", { reg.register_component(
"description": "打招呼", "greet",
"activation_type": "keyword", "action",
"activation_keywords": ["hi"], "plugin_a",
}) {
reg.register_component("help", "command", "plugin_a", { "description": "打招呼",
"command_pattern": r"^/help", "activation_type": "keyword",
}) "activation_keywords": ["hi"],
reg.register_component("search", "tool", "plugin_b", { },
"description": "搜索", )
}) reg.register_component(
"help",
"command",
"plugin_a",
{
"command_pattern": r"^/help",
},
)
reg.register_component(
"search",
"tool",
"plugin_b",
{
"description": "搜索",
},
)
stats = reg.get_stats() stats = reg.get_stats()
assert stats["total"] == 3 assert stats["total"] == 3
@@ -573,12 +622,22 @@ class TestComponentRegistry:
from src.plugin_runtime.host.component_registry import ComponentRegistry from src.plugin_runtime.host.component_registry import ComponentRegistry
reg = ComponentRegistry() reg = ComponentRegistry()
reg.register_component("help", "command", "p1", { reg.register_component(
"command_pattern": r"^/help", "help",
}) "command",
reg.register_component("echo", "command", "p1", { "p1",
"command_pattern": r"^/echo\s", {
}) "command_pattern": r"^/help",
},
)
reg.register_component(
"echo",
"command",
"p1",
{
"command_pattern": r"^/echo\s",
},
)
match = reg.find_command_by_text("/help me") match = reg.find_command_by_text("/help me")
assert match is not None assert match is not None
@@ -622,12 +681,24 @@ class TestComponentRegistry:
from src.plugin_runtime.host.component_registry import ComponentRegistry from src.plugin_runtime.host.component_registry import ComponentRegistry
reg = ComponentRegistry() reg = ComponentRegistry()
reg.register_component("h_low", "event_handler", "p1", { reg.register_component(
"event_type": "on_message", "weight": 10, "h_low",
}) "event_handler",
reg.register_component("h_high", "event_handler", "p2", { "p1",
"event_type": "on_message", "weight": 100, {
}) "event_type": "on_message",
"weight": 10,
},
)
reg.register_component(
"h_high",
"event_handler",
"p2",
{
"event_type": "on_message",
"weight": 100,
},
)
handlers = reg.get_event_handlers("on_message") handlers = reg.get_event_handlers("on_message")
assert handlers[0].name == "h_high" assert handlers[0].name == "h_high"
@@ -637,10 +708,15 @@ class TestComponentRegistry:
from src.plugin_runtime.host.component_registry import ComponentRegistry from src.plugin_runtime.host.component_registry import ComponentRegistry
reg = ComponentRegistry() reg = ComponentRegistry()
reg.register_component("search", "tool", "p1", { reg.register_component(
"description": "搜索工具", "search",
"parameters_raw": {"query": {"type": "string"}}, "tool",
}) "p1",
{
"description": "搜索工具",
"parameters_raw": {"query": {"type": "string"}},
},
)
tools = reg.get_tools_for_llm() tools = reg.get_tools_for_llm()
assert len(tools) == 1 assert len(tools) == 1
@@ -650,6 +726,7 @@ class TestComponentRegistry:
# ─── EventDispatcher 测试 ───────────────────────────────── # ─── EventDispatcher 测试 ─────────────────────────────────
class TestEventDispatcher: class TestEventDispatcher:
"""Host-side 事件分发器测试""" """Host-side 事件分发器测试"""
@@ -659,11 +736,16 @@ class TestEventDispatcher:
from src.plugin_runtime.host.event_dispatcher import EventDispatcher from src.plugin_runtime.host.event_dispatcher import EventDispatcher
reg = ComponentRegistry() reg = ComponentRegistry()
reg.register_component("h1", "event_handler", "p1", { reg.register_component(
"event_type": "on_start", "h1",
"weight": 0, "event_handler",
"intercept_message": False, "p1",
}) {
"event_type": "on_start",
"weight": 0,
"intercept_message": False,
},
)
dispatcher = EventDispatcher(reg) dispatcher = EventDispatcher(reg)
call_log = [] call_log = []
@@ -672,9 +754,7 @@ class TestEventDispatcher:
call_log.append((plugin_id, comp_name)) call_log.append((plugin_id, comp_name))
return {"success": True, "continue_processing": True} return {"success": True, "continue_processing": True}
should_continue, modified = await dispatcher.dispatch_event( should_continue, modified = await dispatcher.dispatch_event("on_start", mock_invoke)
"on_start", mock_invoke
)
assert should_continue assert should_continue
# 非阻塞分发是异步的,等一下让 task 完成 # 非阻塞分发是异步的,等一下让 task 完成
await asyncio.sleep(0.1) await asyncio.sleep(0.1)
@@ -687,11 +767,16 @@ class TestEventDispatcher:
from src.plugin_runtime.host.event_dispatcher import EventDispatcher from src.plugin_runtime.host.event_dispatcher import EventDispatcher
reg = ComponentRegistry() reg = ComponentRegistry()
reg.register_component("filter", "event_handler", "p1", { reg.register_component(
"event_type": "on_message_pre_process", "filter",
"weight": 100, "event_handler",
"intercept_message": True, "p1",
}) {
"event_type": "on_message_pre_process",
"weight": 100,
"intercept_message": True,
},
)
dispatcher = EventDispatcher(reg) dispatcher = EventDispatcher(reg)
@@ -755,6 +840,7 @@ class TestEventBus:
# ─── MaiMessages 测试 ───────────────────────────────────── # ─── MaiMessages 测试 ─────────────────────────────────────
class TestMaiMessages: class TestMaiMessages:
"""统一消息模型测试""" """统一消息模型测试"""
@@ -799,6 +885,7 @@ class TestMaiMessages:
# ─── WorkflowExecutor 测试 ──────────────────────────────── # ─── WorkflowExecutor 测试 ────────────────────────────────
class TestWorkflowExecutor: class TestWorkflowExecutor:
"""Host-side Workflow 执行器测试(新 pipeline 模型)""" """Host-side Workflow 执行器测试(新 pipeline 模型)"""
@@ -829,11 +916,16 @@ class TestWorkflowExecutor:
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
reg = ComponentRegistry() reg = ComponentRegistry()
reg.register_component("upper", "workflow_step", "p1", { reg.register_component(
"stage": "pre_process", "upper",
"priority": 10, "workflow_step",
"blocking": True, "p1",
}) {
"stage": "pre_process",
"priority": 10,
"blocking": True,
},
)
executor = WorkflowExecutor(reg) executor = WorkflowExecutor(reg)
async def mock_invoke(plugin_id, comp_name, args): async def mock_invoke(plugin_id, comp_name, args):
@@ -859,11 +951,16 @@ class TestWorkflowExecutor:
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
reg = ComponentRegistry() reg = ComponentRegistry()
reg.register_component("blocker", "workflow_step", "p1", { reg.register_component(
"stage": "pre_process", "blocker",
"priority": 10, "workflow_step",
"blocking": True, "p1",
}) {
"stage": "pre_process",
"priority": 10,
"blocking": True,
},
)
executor = WorkflowExecutor(reg) executor = WorkflowExecutor(reg)
async def mock_invoke(plugin_id, comp_name, args): async def mock_invoke(plugin_id, comp_name, args):
@@ -884,17 +981,27 @@ class TestWorkflowExecutor:
reg = ComponentRegistry() reg = ComponentRegistry()
# high-priority hook 返回 skip_stage # high-priority hook 返回 skip_stage
reg.register_component("skipper", "workflow_step", "p1", { reg.register_component(
"stage": "ingress", "skipper",
"priority": 100, "workflow_step",
"blocking": True, "p1",
}) {
"stage": "ingress",
"priority": 100,
"blocking": True,
},
)
# low-priority hook 不应被执行 # low-priority hook 不应被执行
reg.register_component("checker", "workflow_step", "p2", { reg.register_component(
"stage": "ingress", "checker",
"priority": 1, "workflow_step",
"blocking": True, "p2",
}) {
"stage": "ingress",
"priority": 1,
"blocking": True,
},
)
executor = WorkflowExecutor(reg) executor = WorkflowExecutor(reg)
call_log = [] call_log = []
@@ -905,9 +1012,7 @@ class TestWorkflowExecutor:
return {"hook_result": "skip_stage"} return {"hook_result": "skip_stage"}
return {"hook_result": "continue"} return {"hook_result": "continue"}
result, _, _ = await executor.execute( result, _, _ = await executor.execute(mock_invoke, message={"plain_text": "test"})
mock_invoke, message={"plain_text": "test"}
)
assert result.status == "completed" assert result.status == "completed"
# 只有 skipper 被调用checker 被跳过 # 只有 skipper 被调用checker 被跳过
assert call_log == ["skipper"] assert call_log == ["skipper"]
@@ -919,12 +1024,17 @@ class TestWorkflowExecutor:
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
reg = ComponentRegistry() reg = ComponentRegistry()
reg.register_component("only_dm", "workflow_step", "p1", { reg.register_component(
"stage": "ingress", "only_dm",
"priority": 10, "workflow_step",
"blocking": True, "p1",
"filter": {"chat_type": "direct"}, {
}) "stage": "ingress",
"priority": 10,
"blocking": True,
"filter": {"chat_type": "direct"},
},
)
executor = WorkflowExecutor(reg) executor = WorkflowExecutor(reg)
call_log = [] call_log = []
@@ -934,15 +1044,11 @@ class TestWorkflowExecutor:
return {"hook_result": "continue"} return {"hook_result": "continue"}
# 不匹配 filter —— hook 不应被调用 # 不匹配 filter —— hook 不应被调用
await executor.execute( await executor.execute(mock_invoke, message={"plain_text": "hi", "chat_type": "group"})
mock_invoke, message={"plain_text": "hi", "chat_type": "group"}
)
assert not call_log assert not call_log
# 匹配 filter —— hook 应被调用 # 匹配 filter —— hook 应被调用
await executor.execute( await executor.execute(mock_invoke, message={"plain_text": "hi", "chat_type": "direct"})
mock_invoke, message={"plain_text": "hi", "chat_type": "direct"}
)
assert call_log == ["only_dm"] assert call_log == ["only_dm"]
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -952,17 +1058,27 @@ class TestWorkflowExecutor:
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
reg = ComponentRegistry() reg = ComponentRegistry()
reg.register_component("failer", "workflow_step", "p1", { reg.register_component(
"stage": "ingress", "failer",
"priority": 100, "workflow_step",
"blocking": True, "p1",
"error_policy": "skip", {
}) "stage": "ingress",
reg.register_component("ok_step", "workflow_step", "p2", { "priority": 100,
"stage": "ingress", "blocking": True,
"priority": 1, "error_policy": "skip",
"blocking": True, },
}) )
reg.register_component(
"ok_step",
"workflow_step",
"p2",
{
"stage": "ingress",
"priority": 1,
"blocking": True,
},
)
executor = WorkflowExecutor(reg) executor = WorkflowExecutor(reg)
call_log = [] call_log = []
@@ -973,9 +1089,7 @@ class TestWorkflowExecutor:
raise RuntimeError("boom") raise RuntimeError("boom")
return {"hook_result": "continue"} return {"hook_result": "continue"}
result, _, ctx = await executor.execute( result, _, ctx = await executor.execute(mock_invoke, message={"plain_text": "test"})
mock_invoke, message={"plain_text": "test"}
)
assert result.status == "completed" assert result.status == "completed"
assert "failer" in call_log assert "failer" in call_log
assert "ok_step" in call_log assert "ok_step" in call_log
@@ -988,20 +1102,23 @@ class TestWorkflowExecutor:
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
reg = ComponentRegistry() reg = ComponentRegistry()
reg.register_component("failer", "workflow_step", "p1", { reg.register_component(
"stage": "ingress", "failer",
"priority": 10, "workflow_step",
"blocking": True, "p1",
# error_policy defaults to "abort" {
}) "stage": "ingress",
"priority": 10,
"blocking": True,
# error_policy defaults to "abort"
},
)
executor = WorkflowExecutor(reg) executor = WorkflowExecutor(reg)
async def mock_invoke(plugin_id, comp_name, args): async def mock_invoke(plugin_id, comp_name, args):
raise RuntimeError("fatal") raise RuntimeError("fatal")
result, _, ctx = await executor.execute( result, _, ctx = await executor.execute(mock_invoke, message={"plain_text": "test"})
mock_invoke, message={"plain_text": "test"}
)
assert result.status == "failed" assert result.status == "failed"
assert result.stopped_at == "ingress" assert result.stopped_at == "ingress"
@@ -1013,11 +1130,16 @@ class TestWorkflowExecutor:
reg = ComponentRegistry() reg = ComponentRegistry()
for i in range(3): for i in range(3):
reg.register_component(f"nb_{i}", "workflow_step", f"p{i}", { reg.register_component(
"stage": "post_process", f"nb_{i}",
"priority": 0, "workflow_step",
"blocking": False, f"p{i}",
}) {
"stage": "post_process",
"priority": 0,
"blocking": False,
},
)
executor = WorkflowExecutor(reg) executor = WorkflowExecutor(reg)
call_log = [] call_log = []
@@ -1026,9 +1148,7 @@ class TestWorkflowExecutor:
call_log.append(comp_name) call_log.append(comp_name)
return {"hook_result": "continue", "modified_message": {"plain_text": "ignored"}} return {"hook_result": "continue", "modified_message": {"plain_text": "ignored"}}
result, final_msg, _ = await executor.execute( result, final_msg, _ = await executor.execute(mock_invoke, message={"plain_text": "original"})
mock_invoke, message={"plain_text": "original"}
)
# non-blocking 的 modified_message 被忽略 # non-blocking 的 modified_message 被忽略
assert final_msg["plain_text"] == "original" assert final_msg["plain_text"] == "original"
# 给异步 task 时间完成 # 给异步 task 时间完成
@@ -1042,9 +1162,14 @@ class TestWorkflowExecutor:
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
reg = ComponentRegistry() reg = ComponentRegistry()
reg.register_component("help", "command", "p1", { reg.register_component(
"command_pattern": r"^/help", "help",
}) "command",
"p1",
{
"command_pattern": r"^/help",
},
)
executor = WorkflowExecutor(reg) executor = WorkflowExecutor(reg)
async def mock_invoke(plugin_id, comp_name, args): async def mock_invoke(plugin_id, comp_name, args):
@@ -1052,9 +1177,7 @@ class TestWorkflowExecutor:
return {"output": "帮助信息"} return {"output": "帮助信息"}
return {"hook_result": "continue"} return {"hook_result": "continue"}
result, _, ctx = await executor.execute( result, _, ctx = await executor.execute(mock_invoke, message={"plain_text": "/help topic"})
mock_invoke, message={"plain_text": "/help topic"}
)
assert result.status == "completed" assert result.status == "completed"
assert ctx.matched_command == "p1.help" assert ctx.matched_command == "p1.help"
cmd_result = ctx.get_stage_output("plan", "command_result") cmd_result = ctx.get_stage_output("plan", "command_result")
@@ -1069,17 +1192,27 @@ class TestWorkflowExecutor:
reg = ComponentRegistry() reg = ComponentRegistry()
# ingress 阶段写入数据 # ingress 阶段写入数据
reg.register_component("writer", "workflow_step", "p1", { reg.register_component(
"stage": "ingress", "writer",
"priority": 10, "workflow_step",
"blocking": True, "p1",
}) {
"stage": "ingress",
"priority": 10,
"blocking": True,
},
)
# pre_process 阶段读取数据 # pre_process 阶段读取数据
reg.register_component("reader", "workflow_step", "p2", { reg.register_component(
"stage": "pre_process", "reader",
"priority": 10, "workflow_step",
"blocking": True, "p2",
}) {
"stage": "pre_process",
"priority": 10,
"blocking": True,
},
)
executor = WorkflowExecutor(reg) executor = WorkflowExecutor(reg)
async def mock_invoke(plugin_id, comp_name, args): async def mock_invoke(plugin_id, comp_name, args):
@@ -1096,9 +1229,7 @@ class TestWorkflowExecutor:
return {"hook_result": "continue"} return {"hook_result": "continue"}
return {"hook_result": "continue"} return {"hook_result": "continue"}
result, _, ctx = await executor.execute( result, _, ctx = await executor.execute(mock_invoke, message={"plain_text": "hi"})
mock_invoke, message={"plain_text": "hi"}
)
assert result.status == "completed" assert result.status == "completed"
assert ctx.get_stage_output("ingress", "parsed_intent") == "greeting" assert ctx.get_stage_output("ingress", "parsed_intent") == "greeting"
@@ -1324,10 +1455,15 @@ class TestIntegration:
async def stop(self): async def stop(self):
self.stopped = True self.stopped = True
monkeypatch.setattr(integration_module.PluginRuntimeManager, "_get_builtin_plugin_dirs", staticmethod(lambda: ["builtin"])) monkeypatch.setattr(
monkeypatch.setattr(integration_module.PluginRuntimeManager, "_get_thirdparty_plugin_dirs", staticmethod(lambda: ["thirdparty"])) integration_module.PluginRuntimeManager, "_get_builtin_plugin_dirs", staticmethod(lambda: ["builtin"])
)
monkeypatch.setattr(
integration_module.PluginRuntimeManager, "_get_thirdparty_plugin_dirs", staticmethod(lambda: ["thirdparty"])
)
import src.plugin_runtime.host.supervisor as supervisor_module import src.plugin_runtime.host.supervisor as supervisor_module
monkeypatch.setattr(supervisor_module, "PluginSupervisor", FakeSupervisor) monkeypatch.setattr(supervisor_module, "PluginSupervisor", FakeSupervisor)
manager = integration_module.PluginRuntimeManager() manager = integration_module.PluginRuntimeManager()

View File

@@ -9,16 +9,12 @@ from datetime import datetime
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING: if TYPE_CHECKING:
from src.common.data_models.message_component_data_model import MessageSequence, ForwardComponent from src.common.data_models.message_component_data_model import MessageSequence
from src.chat.message_receive.message import ( from src.chat.message_receive.message import (
SessionMessage, SessionMessage,
TextComponent, TextComponent,
ImageComponent, ImageComponent,
EmojiComponent,
VoiceComponent,
AtComponent, AtComponent,
ReplyComponent,
ForwardNodeComponent,
) )
@@ -203,17 +199,23 @@ def load_message_via_file(monkeypatch):
def dummy_number_to_short_id(original_id: int, salt: str, length: int = 6) -> str: def dummy_number_to_short_id(original_id: int, salt: str, length: int = 6) -> str:
return "X" * length # 返回固定的字符串长度由参数决定模拟生成短ID的行为 return "X" * length # 返回固定的字符串长度由参数决定模拟生成短ID的行为
def dummy_is_bot_self(user_id: str, platform) -> bool: def dummy_is_bot_self(user_id: str, platform) -> bool:
return user_id == "bot_self" return user_id == "bot_self"
def load_utils_via_file(monkeypatch): def load_utils_via_file(monkeypatch):
setup_mocks(monkeypatch) setup_mocks(monkeypatch)
# Mock math_utils 模块,供 from .math_utils import number_to_short_id 使用 # Mock math_utils 模块,供 from .math_utils import number_to_short_id 使用
math_utils_mod = ModuleType("src.common.utils.math_utils") math_utils_mod = ModuleType("src.common.utils.math_utils")
math_utils_mod.number_to_short_id = dummy_number_to_short_id math_utils_mod.number_to_short_id = dummy_number_to_short_id
math_utils_mod.TimestampMode = type("TimestampMode", (), {"NORMAL": "%Y-%m-%d %H:%M:%S", "NORMAL_NO_YMD": "%H:%M:%S", "RELATIVE": "relative"}) math_utils_mod.TimestampMode = type(
math_utils_mod.translate_timestamp_to_human_readable = lambda timestamp, mode: "2024-01-01 12:00:00" # 返回固定的时间字符串 "TimestampMode", (), {"NORMAL": "%Y-%m-%d %H:%M:%S", "NORMAL_NO_YMD": "%H:%M:%S", "RELATIVE": "relative"}
)
math_utils_mod.translate_timestamp_to_human_readable = lambda timestamp, mode: (
"2024-01-01 12:00:00"
) # 返回固定的时间字符串
monkeypatch.setitem(sys.modules, "src.common.utils.math_utils", math_utils_mod) monkeypatch.setitem(sys.modules, "src.common.utils.math_utils", math_utils_mod)
# 确保包层级模块存在于 sys.modules 中,使相对导入能正确解析 # 确保包层级模块存在于 sys.modules 中,使相对导入能正确解析
@@ -349,6 +351,7 @@ async def test_build_readable_message_anonymize_and_replace_bot_name_and_lineno(
assert "u_comb" in mapping assert "u_comb" in mapping
assert mapping["u_comb"][0] == "XXXXXX" assert mapping["u_comb"][0] == "XXXXXX"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_build_readable_message_with_at(monkeypatch): async def test_build_readable_message_with_at(monkeypatch):
"""包含@组件的消息:验证@组件中的用户信息也被匿名化和替换""" """包含@组件的消息:验证@组件中的用户信息也被匿名化和替换"""
@@ -363,7 +366,9 @@ async def test_build_readable_message_with_at(monkeypatch):
msg.session_id = "s_at" msg.session_id = "s_at"
msg.raw_message = MessageSequence([at_comp]) msg.raw_message = MessageSequence([at_comp])
msg.message_info = MessageInfo(UserInfo(user_id="u_main", user_nickname="MainUser")) msg.message_info = MessageInfo(UserInfo(user_id="u_main", user_nickname="MainUser"))
text, mapping, _ = await MessageUtils.build_readable_message([msg], anonymize=True, replace_bot_name=True, target_bot_name="MAIBot") text, mapping, _ = await MessageUtils.build_readable_message(
[msg], anonymize=True, replace_bot_name=True, target_bot_name="MAIBot"
)
# 验证主消息和@组件中的用户信息都被处理 # 验证主消息和@组件中的用户信息都被处理
assert "XXXXXX说" in text # 主消息用户被匿名化 assert "XXXXXX说" in text # 主消息用户被匿名化
assert "XXXXXX说@XXXXXX" in text # @组件用户被匿名化 assert "XXXXXX说@XXXXXX" in text # @组件用户被匿名化

View File

@@ -22,11 +22,7 @@ def should_skip(path: Path) -> bool:
def iter_python_files(root: Path) -> list[Path]: def iter_python_files(root: Path) -> list[Path]:
return sorted( return sorted(path for path in root.rglob("*.py") if path.is_file() and not should_skip(path.relative_to(root)))
path
for path in root.rglob("*.py")
if path.is_file() and not should_skip(path.relative_to(root))
)
class CandidateExtractor(ast.NodeVisitor): class CandidateExtractor(ast.NodeVisitor):

View File

@@ -23,7 +23,6 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from src.common.logger import initialize_logging, get_logger from src.common.logger import initialize_logging, get_logger
from src.common.database.database import db from src.common.database.database import db
from src.common.database.database_model import LLMUsage from src.common.database.database_model import LLMUsage
from src.chat.message_receive.chat_manager import BotChatSession
from maim_message import UserInfo, GroupInfo from maim_message import UserInfo, GroupInfo
logger = get_logger("test_memory_retrieval") logger = get_logger("test_memory_retrieval")

View File

@@ -10,7 +10,6 @@ from src.common.logger import get_logger
from src.common.database.database_model import Expression from src.common.database.database_model import Expression
from src.prompt.prompt_manager import prompt_manager from src.prompt.prompt_manager import prompt_manager
from src.bw_learner.learner_utils_old import weighted_sample from src.bw_learner.learner_utils_old import weighted_sample
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.utils.common_utils import TempMethodsExpression from src.chat.utils.common_utils import TempMethodsExpression
logger = get_logger("expression_selector") logger = get_logger("expression_selector")

View File

@@ -1,22 +1,14 @@
import re
import difflib
import random import random
import json import json
from typing import Optional, List, Dict, Any, Tuple from typing import Optional, List, Dict, Any
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.config import global_config from src.config.config import global_config
from src.chat.utils.chat_message_builder import (
build_readable_messages,
)
from src.chat.utils.utils import parse_platform_accounts
from json_repair import repair_json
logger = get_logger("learner_utils") logger = get_logger("learner_utils")
def _compute_weights(population: List[Dict]) -> List[float]: def _compute_weights(population: List[Dict]) -> List[float]:
""" """
根据表达的count计算权重范围限定在1~5之间。 根据表达的count计算权重范围限定在1~5之间。

View File

@@ -16,7 +16,7 @@ from .action_planner import ActionPlanner
from .observation_info import ObservationInfo from .observation_info import ObservationInfo
from .conversation_info import ConversationInfo # 确保导入 ConversationInfo from .conversation_info import ConversationInfo # 确保导入 ConversationInfo
from .reply_generator import ReplyGenerator from .reply_generator import ReplyGenerator
from src.chat.message_receive.chat_manager import BotChatSession, chat_manager as _chat_manager from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
from maim_message import UserInfo from maim_message import UserInfo
from .pfc_KnowledgeFetcher import KnowledgeFetcher from .pfc_KnowledgeFetcher import KnowledgeFetcher
from .waiter import Waiter from .waiter import Waiter

View File

@@ -9,7 +9,7 @@ from src.config.config import global_config
from src.common.logger import get_logger from src.common.logger import get_logger
from src.common.data_models.info_data_model import ActionPlannerInfo from src.common.data_models.info_data_model import ActionPlannerInfo
from src.common.data_models.message_data_model import ReplyContentType from src.common.data_models.message_data_model import ReplyContentType
from src.chat.message_receive.chat_manager import BotChatSession, chat_manager as _chat_manager from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
from src.chat.utils.prompt_builder import global_prompt_manager from src.chat.utils.prompt_builder import global_prompt_manager
from src.chat.utils.timer_calculator import Timer from src.chat.utils.timer_calculator import Timer
from src.chat.brain_chat.brain_planner import BrainPlanner from src.chat.brain_chat.brain_planner import BrainPlanner
@@ -22,7 +22,12 @@ from src.person_info.person_info import Person
from src.core.types import ActionInfo, EventType from src.core.types import ActionInfo, EventType
from src.core.event_bus import event_bus from src.core.event_bus import event_bus
from src.chat.event_helpers import build_event_message from src.chat.event_helpers import build_event_message
from src.services import generator_service as generator_api, send_service as send_api, message_service as message_api, database_service as database_api from src.services import (
generator_service as generator_api,
send_service as send_api,
message_service as message_api,
database_service as database_api,
)
from src.chat.utils.chat_message_builder import ( from src.chat.utils.chat_message_builder import (
build_readable_messages_with_id, build_readable_messages_with_id,
get_raw_msg_before_timestamp_with_chat, get_raw_msg_before_timestamp_with_chat,
@@ -294,10 +299,10 @@ class BrainChatting:
message_id_list=message_id_list, message_id_list=message_id_list,
prompt_key="brain_planner", prompt_key="brain_planner",
) )
_event_msg = build_event_message(EventType.ON_PLAN, llm_prompt=prompt_info[0], stream_id=self.chat_stream.stream_id) _event_msg = build_event_message(
continue_flag, modified_message = await event_bus.emit( EventType.ON_PLAN, llm_prompt=prompt_info[0], stream_id=self.chat_stream.stream_id
EventType.ON_PLAN, _event_msg
) )
continue_flag, modified_message = await event_bus.emit(EventType.ON_PLAN, _event_msg)
if not continue_flag: if not continue_flag:
return False return False
if modified_message and modified_message._modify_flags.modify_llm_prompt: if modified_message and modified_message._modify_flags.modify_llm_prompt:

View File

@@ -1,5 +1,5 @@
from rich.traceback import install from rich.traceback import install
from typing import Optional, List, TYPE_CHECKING, Tuple, Dict from typing import Optional, List, TYPE_CHECKING
import asyncio import asyncio
import time import time
@@ -14,7 +14,7 @@ from src.chat.message_receive.chat_manager import chat_manager
from src.bw_learner.expression_learner import ExpressionLearner from src.bw_learner.expression_learner import ExpressionLearner
from src.bw_learner.jargon_miner import JargonMiner from src.bw_learner.jargon_miner import JargonMiner
from .heartFC_utils import CycleDetail, CycleActionInfo, CyclePlanInfo from .heartFC_utils import CycleDetail
if TYPE_CHECKING: if TYPE_CHECKING:
from src.chat.message_receive.message import SessionMessage from src.chat.message_receive.message import SessionMessage

View File

@@ -1,4 +1,4 @@
from typing import Optional, Dict from typing import Dict
import traceback import traceback
@@ -9,6 +9,7 @@ from src.chat.heart_flow.heartFC_chat import HeartFChatting
logger = get_logger("heartflow") logger = get_logger("heartflow")
# TODO: 恢复PFC现在暂时禁用 # TODO: 恢复PFC现在暂时禁用
class HeartflowManager: class HeartflowManager:
"""主心流协调器,负责初始化并协调聊天,控制聊天属性""" """主心流协调器,负责初始化并协调聊天,控制聊天属性"""
@@ -17,7 +18,7 @@ class HeartflowManager:
# self.heartflow_chat_list: Dict[str, HeartFChatting | BrainChatting] = {} # self.heartflow_chat_list: Dict[str, HeartFChatting | BrainChatting] = {}
self.heartflow_chat_list: Dict[str, HeartFChatting] = {} self.heartflow_chat_list: Dict[str, HeartFChatting] = {}
async def get_or_create_heartflow_chat(self, session_id: str): # -> Optional[HeartFChatting | BrainChatting]: async def get_or_create_heartflow_chat(self, session_id: str): # -> Optional[HeartFChatting | BrainChatting]:
"""获取或创建一个新的HeartFChatting实例""" """获取或创建一个新的HeartFChatting实例"""
try: try:
if chat := self.heartflow_chat_list.get(session_id): if chat := self.heartflow_chat_list.get(session_id):

View File

@@ -9,11 +9,10 @@ from src.common.logger import get_logger
from src.common.utils.utils_message import MessageUtils from src.common.utils.utils_message import MessageUtils
from src.common.utils.utils_session import SessionUtils from src.common.utils.utils_session import SessionUtils
from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiver from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiver
# from src.chat.brain_chat.PFC.pfc_manager import PFCManager # from src.chat.brain_chat.PFC.pfc_manager import PFCManager
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.core.announcement_manager import global_announcement_manager from src.core.announcement_manager import global_announcement_manager
from src.core.component_registry import component_registry from src.core.component_registry import component_registry
from src.core.types import EventType
from .message import SessionMessage from .message import SessionMessage
from .chat_manager import chat_manager from .chat_manager import chat_manager
@@ -391,6 +390,7 @@ class ChatBot:
else: else:
logger.debug("[群聊]检测到群聊消息路由到HeartFlow系统") logger.debug("[群聊]检测到群聊消息路由到HeartFlow系统")
await self.heartflow_message_receiver.process_message(message) await self.heartflow_message_receiver.process_message(message)
await preprocess() await preprocess()
except Exception as e: except Exception as e:

View File

@@ -8,7 +8,7 @@ import asyncio
from src.common.logger import get_logger from src.common.logger import get_logger
from src.common.database.database import get_db_session from src.common.database.database import get_db_session
from src.common.database.database_model import Messages from src.common.database.database_model import Messages
from src.common.data_models.mai_message_data_model import MaiMessage, UserInfo, GroupInfo, MessageInfo from src.common.data_models.mai_message_data_model import MaiMessage, UserInfo
from src.common.data_models.message_component_data_model import ( from src.common.data_models.message_component_data_model import (
TextComponent, TextComponent,
ImageComponent, ImageComponent,

View File

@@ -22,6 +22,7 @@ _webui_chat_broadcaster = None
# 虚拟群 ID 前缀(与 chat_routes.py 保持一致) # 虚拟群 ID 前缀(与 chat_routes.py 保持一致)
VIRTUAL_GROUP_ID_PREFIX = "webui_virtual_group_" VIRTUAL_GROUP_ID_PREFIX = "webui_virtual_group_"
# TODO: 重构完成后完成webui相关 # TODO: 重构完成后完成webui相关
def get_webui_chat_broadcaster(): def get_webui_chat_broadcaster():
"""获取 WebUI 聊天室广播器""" """获取 WebUI 聊天室广播器"""

View File

@@ -4,7 +4,7 @@ from src.chat.message_receive.chat_manager import BotChatSession
from src.common.logger import get_logger from src.common.logger import get_logger
from src.common.data_models.database_data_model import DatabaseMessages from src.common.data_models.database_data_model import DatabaseMessages
from src.core.component_registry import component_registry, ActionExecutor from src.core.component_registry import component_registry, ActionExecutor
from src.core.types import ActionInfo, ComponentType from src.core.types import ActionInfo
logger = get_logger("action_manager") logger = get_logger("action_manager")

View File

@@ -1,6 +1,6 @@
import random import random
import time import time
from typing import List, Dict, TYPE_CHECKING, Tuple from typing import List, Dict, Tuple
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.config import global_config from src.config.config import global_config

View File

@@ -119,9 +119,7 @@ class PrivateReplyer:
if not from_plugin: if not from_plugin:
_event_msg = build_event_message(EventType.POST_LLM, llm_prompt=prompt, stream_id=stream_id) _event_msg = build_event_message(EventType.POST_LLM, llm_prompt=prompt, stream_id=stream_id)
continue_flag, modified_message = await event_bus.emit( continue_flag, modified_message = await event_bus.emit(EventType.POST_LLM, _event_msg)
EventType.POST_LLM, _event_msg
)
if not continue_flag: if not continue_flag:
raise UserWarning("插件于请求前中断了内容生成") raise UserWarning("插件于请求前中断了内容生成")
if modified_message and modified_message._modify_flags.modify_llm_prompt: if modified_message and modified_message._modify_flags.modify_llm_prompt:
@@ -140,10 +138,10 @@ class PrivateReplyer:
llm_response.reasoning = reasoning_content llm_response.reasoning = reasoning_content
llm_response.model = model_name llm_response.model = model_name
llm_response.tool_calls = tool_call llm_response.tool_calls = tool_call
_event_msg = build_event_message(EventType.AFTER_LLM, llm_prompt=prompt, llm_response=llm_response, stream_id=stream_id) _event_msg = build_event_message(
continue_flag, modified_message = await event_bus.emit( EventType.AFTER_LLM, llm_prompt=prompt, llm_response=llm_response, stream_id=stream_id
EventType.AFTER_LLM, _event_msg
) )
continue_flag, modified_message = await event_bus.emit(EventType.AFTER_LLM, _event_msg)
if not from_plugin and not continue_flag: if not from_plugin and not continue_flag:
raise UserWarning("插件于请求后取消了内容生成") raise UserWarning("插件于请求后取消了内容生成")
if modified_message: if modified_message:

View File

@@ -817,6 +817,8 @@ def assign_message_ids(messages: List[SessionMessage]) -> List[Tuple[str, Sessio
result.append((message_id, message)) result.append((message_id, message))
return result return result
# break # break
# result.append((message_id, message)) # result.append((message_id, message))

View File

@@ -25,4 +25,4 @@
# available_actions: Optional[Dict[str, "ActionInfo"]] = None # available_actions: Optional[Dict[str, "ActionInfo"]] = None
# loop_start_time: Optional[float] = None # loop_start_time: Optional[float] = None
# action_reasoning: Optional[str] = None # action_reasoning: Optional[str] = None
# TODO: 重构 # TODO: 重构

View File

@@ -20,4 +20,4 @@
# timing: Optional[Dict[str, Any]] = None # timing: Optional[Dict[str, Any]] = None
# processed_output: Optional[List[str]] = None # processed_output: Optional[List[str]] = None
# timing_logs: Optional[List[str]] = None # timing_logs: Optional[List[str]] = None
# TODO: 重构 # TODO: 重构

View File

@@ -13,8 +13,10 @@ from src.common.logger import get_logger
logger = get_logger("base_message_component_model") logger = get_logger("base_message_component_model")
class UnknownUser(str): ... class UnknownUser(str): ...
class BaseMessageComponentModel(ABC): class BaseMessageComponentModel(ABC):
@property @property
@abstractmethod @abstractmethod

View File

@@ -107,7 +107,7 @@ class UniversalMessageSender:
logger.info(f"已将消息 '{message_preview}' 发往平台'{message.platform}'") logger.info(f"已将消息 '{message_preview}' 发往平台'{message.platform}'")
return True return True
except Exception as legacy_e: except Exception:
# # Legacy API 抛出异常,尝试 Fallback # # Legacy API 抛出异常,尝试 Fallback
# return await self._send_with_fallback( # return await self._send_with_fallback(
# message, message_preview, platform, show_log, legacy_exception=legacy_e # message, message_preview, platform, show_log, legacy_exception=legacy_e

View File

@@ -76,4 +76,4 @@ def assert_port_available(
port=port, port=port,
config_hint=config_hint, config_hint=config_hint,
) )
raise OSError(build_port_conflict_message(service_name=service_name, host=host, port=port)) raise OSError(build_port_conflict_message(service_name=service_name, host=host, port=port))

View File

@@ -133,8 +133,8 @@ class ConfigBase(BaseModel, AttrDocBase):
# UI 分组元数据:子类可覆盖以声明所属 Tab 分组 # UI 分组元数据:子类可覆盖以声明所属 Tab 分组
__ui_parent__: ClassVar[str] = "" # 父配置类在 Config 中的字段名,空表示独立 Tab __ui_parent__: ClassVar[str] = "" # 父配置类在 Config 中的字段名,空表示独立 Tab
__ui_label__: ClassVar[str] = "" # Tab 显示名称(仅做 Tab 主人时使用),空则使用 classDoc __ui_label__: ClassVar[str] = "" # Tab 显示名称(仅做 Tab 主人时使用),空则使用 classDoc
__ui_icon__: ClassVar[str] = "" # Tab 图标名称Lucide 图标名) __ui_icon__: ClassVar[str] = "" # Tab 图标名称Lucide 图标名)
@classmethod @classmethod
def from_dict(cls, attribute_data: AttributeData, data: dict[str, Any]): def from_dict(cls, attribute_data: AttributeData, data: dict[str, Any]):

View File

@@ -182,7 +182,9 @@ class FileWatcher:
self._stats.callbacks_skipped_cooldown += 1 self._stats.callbacks_skipped_cooldown += 1
continue continue
try: try:
await asyncio.wait_for(self._invoke_callback(subscription.callback, matched_changes), timeout=self._callback_timeout_s) await asyncio.wait_for(
self._invoke_callback(subscription.callback, matched_changes), timeout=self._callback_timeout_s
)
state.consecutive_failures = 0 state.consecutive_failures = 0
self._stats.callbacks_succeeded += 1 self._stats.callbacks_succeeded += 1
except asyncio.TimeoutError: except asyncio.TimeoutError:

View File

@@ -10,11 +10,10 @@
""" """
import re import re
from typing import Any, Awaitable, Callable, Dict, List, Optional, Pattern, Tuple, Union from typing import Any, Awaitable, Callable, Dict, Optional, Pattern, Tuple
from src.common.logger import get_logger from src.common.logger import get_logger
from src.core.types import ( from src.core.types import (
ActionActivationType,
ActionInfo, ActionInfo,
CommandInfo, CommandInfo,
ComponentInfo, ComponentInfo,
@@ -130,9 +129,7 @@ class ComponentRegistry:
logger.debug(f"注册 Command: {name}") logger.debug(f"注册 Command: {name}")
return True return True
def find_command_by_text( def find_command_by_text(self, text: str) -> Optional[Tuple[CommandExecutor, dict, CommandInfo]]:
self, text: str
) -> Optional[Tuple[CommandExecutor, dict, CommandInfo]]:
"""根据文本查找匹配的命令 """根据文本查找匹配的命令
Returns: Returns:

View File

@@ -117,9 +117,7 @@ class EventBus:
self._fire_and_forget(entry, event_type, current_message) self._fire_and_forget(entry, event_type, current_message)
# 桥接到 IPC 插件运行时 # 桥接到 IPC 插件运行时
continue_flag, current_message = await self._bridge_to_ipc_runtime( continue_flag, current_message = await self._bridge_to_ipc_runtime(event_type, continue_flag, current_message)
event_type, continue_flag, current_message
)
return continue_flag, current_message return continue_flag, current_message

View File

@@ -7,6 +7,7 @@ MaiSaka - 内置工具定义
from typing import List, Dict, Any from typing import List, Dict, Any
from src.llm_models.payload_content.tool_option import ToolOption, ToolParamType from src.llm_models.payload_content.tool_option import ToolOption, ToolParamType
# 内置工具定义 # 内置工具定义
def create_builtin_tools() -> List[ToolOption]: def create_builtin_tools() -> List[ToolOption]:
"""创建内置工具列表""" """创建内置工具列表"""
@@ -17,57 +18,66 @@ def create_builtin_tools() -> List[ToolOption]:
# say 工具 # say 工具
say_builder = ToolOptionBuilder() say_builder = ToolOptionBuilder()
say_builder.set_name("say") say_builder.set_name("say")
say_builder.set_description("对用户说话。你所有想让用户看到的正式发言都必须通过此工具输出。直接输出的文本会被视为你的内心思考用户无法阅读。reason 参数描述你想要回复的方式、想法和内容,系统会根据你的想法和对话上下文生成具体的回复。") say_builder.set_description(
"对用户说话。你所有想让用户看到的正式发言都必须通过此工具输出。直接输出的文本会被视为你的内心思考用户无法阅读。reason 参数描述你想要回复的方式、想法和内容,系统会根据你的想法和对话上下文生成具体的回复。"
)
say_builder.add_param( say_builder.add_param(
name="reason", name="reason",
param_type=ToolParamType.STRING, param_type=ToolParamType.STRING,
description="描述你想要回复的方式、想法和内容。例如:'同意对方的看法,并分享自己的经历''礼貌地拒绝,表示现在不方便聊天'", description="描述你想要回复的方式、想法和内容。例如:'同意对方的看法,并分享自己的经历''礼貌地拒绝,表示现在不方便聊天'",
required=True, required=True,
enum_values=None enum_values=None,
) )
tools.append(say_builder.build()) tools.append(say_builder.build())
# wait 工具 # wait 工具
wait_builder = ToolOptionBuilder() wait_builder = ToolOptionBuilder()
wait_builder.set_name("wait") wait_builder.set_name("wait")
wait_builder.set_description("暂时结束你的发言,把话语权交给用户,等待对方说话。这就像现实对话中你说完一句话后停下来等对方回应。如果用户在等待期间说了话,你会通过工具返回结果收到内容。如果超时没有回复,你也会收到超时通知。") wait_builder.set_description(
"暂时结束你的发言,把话语权交给用户,等待对方说话。这就像现实对话中你说完一句话后停下来等对方回应。如果用户在等待期间说了话,你会通过工具返回结果收到内容。如果超时没有回复,你也会收到超时通知。"
)
wait_builder.add_param( wait_builder.add_param(
name="seconds", name="seconds",
param_type=ToolParamType.INTEGER, param_type=ToolParamType.INTEGER,
description="等待的秒数。建议 3-10 秒。超过这个时间用户没有回复会显示超时提示。", description="等待的秒数。建议 3-10 秒。超过这个时间用户没有回复会显示超时提示。",
required=True, required=True,
enum_values=None enum_values=None,
) )
tools.append(wait_builder.build()) tools.append(wait_builder.build())
# stop 工具 # stop 工具
stop_builder = ToolOptionBuilder() stop_builder = ToolOptionBuilder()
stop_builder.set_name("stop") stop_builder.set_name("stop")
stop_builder.set_description("结束当前对话循环,进入待机状态,直到用户下次输入新内容时再唤醒你。当对话自然结束、用户表示不想继续聊、或连续多次等待超时用户没有回复时使用。") stop_builder.set_description(
"结束当前对话循环,进入待机状态,直到用户下次输入新内容时再唤醒你。当对话自然结束、用户表示不想继续聊、或连续多次等待超时用户没有回复时使用。"
)
tools.append(stop_builder.build()) tools.append(stop_builder.build())
# store_context 工具 # store_context 工具
store_context_builder = ToolOptionBuilder() store_context_builder = ToolOptionBuilder()
store_context_builder.set_name("store_context") store_context_builder.set_name("store_context")
store_context_builder.set_description("将指定范围的对话上下文存入记忆系统,然后从当前对话中移除这些内容。适合在对话上下文过长、话题转换、或遇到重要内容需要保存时使用。") store_context_builder.set_description(
"将指定范围的对话上下文存入记忆系统,然后从当前对话中移除这些内容。适合在对话上下文过长、话题转换、或遇到重要内容需要保存时使用。"
)
store_context_builder.add_param( store_context_builder.add_param(
name="count", name="count",
param_type=ToolParamType.INTEGER, param_type=ToolParamType.INTEGER,
description="要保存的消息条数(从最早的对话开始计数)。建议 5-20 条。", description="要保存的消息条数(从最早的对话开始计数)。建议 5-20 条。",
required=True, required=True,
enum_values=None enum_values=None,
) )
store_context_builder.add_param( store_context_builder.add_param(
name="reason", name="reason",
param_type=ToolParamType.STRING, param_type=ToolParamType.STRING,
description="保存原因,用于后续检索。例如:'讨论了用户的工作情况''用户分享了对电影的看法'", description="保存原因,用于后续检索。例如:'讨论了用户的工作情况''用户分享了对电影的看法'",
required=True, required=True,
enum_values=None enum_values=None,
) )
tools.append(store_context_builder.build()) tools.append(store_context_builder.build())
return tools return tools
# 为了兼容性,创建一个函数来将工具转换为 dict 格式(用于调试显示) # 为了兼容性,创建一个函数来将工具转换为 dict 格式(用于调试显示)
def builtin_tools_as_dicts() -> List[Dict[str, Any]]: def builtin_tools_as_dicts() -> List[Dict[str, Any]]:
"""将内置工具转换为 dict 格式(用于调试)""" """将内置工具转换为 dict 格式(用于调试)"""
@@ -77,31 +87,23 @@ def builtin_tools_as_dicts() -> List[Dict[str, Any]]:
"description": "对用户说话。你所有想让用户看到的正式发言都必须通过此工具输出。", "description": "对用户说话。你所有想让用户看到的正式发言都必须通过此工具输出。",
"parameters": { "parameters": {
"type": "object", "type": "object",
"properties": { "properties": {"reason": {"type": "string", "description": "回复的想法和内容"}},
"reason": {"type": "string", "description": "回复的想法和内容"} "required": ["reason"],
}, },
"required": ["reason"]
}
}, },
{ {
"name": "wait", "name": "wait",
"description": "暂时结束发言,等待用户回应", "description": "暂时结束发言,等待用户回应",
"parameters": { "parameters": {
"type": "object", "type": "object",
"properties": { "properties": {"seconds": {"type": "number", "description": "等待秒数"}},
"seconds": {"type": "number", "description": "等待秒数"} "required": ["seconds"],
}, },
"required": ["seconds"]
}
}, },
{ {
"name": "stop", "name": "stop",
"description": "结束对话循环", "description": "结束对话循环",
"parameters": { "parameters": {"type": "object", "properties": {}, "required": []},
"type": "object",
"properties": {},
"required": []
}
}, },
{ {
"name": "store_context", "name": "store_context",
@@ -110,17 +112,19 @@ def builtin_tools_as_dicts() -> List[Dict[str, Any]]:
"type": "object", "type": "object",
"properties": { "properties": {
"count": {"type": "number", "description": "保存的消息条数"}, "count": {"type": "number", "description": "保存的消息条数"},
"reason": {"type": "string", "description": "保存原因"} "reason": {"type": "string", "description": "保存原因"},
}, },
"required": ["count", "reason"] "required": ["count", "reason"],
} },
} },
] ]
# 导出工具创建函数和列表 # 导出工具创建函数和列表
def get_builtin_tools() -> List[ToolOption]: def get_builtin_tools() -> List[ToolOption]:
"""获取内置工具列表""" """获取内置工具列表"""
return create_builtin_tools() return create_builtin_tools()
# 为了向后兼容,也导出 dict 格式 # 为了向后兼容,也导出 dict 格式
BUILTIN_TOOLS_DICTS = builtin_tools_as_dicts() BUILTIN_TOOLS_DICTS = builtin_tools_as_dicts()

View File

@@ -13,10 +13,17 @@ from rich.markdown import Markdown
from rich.text import Text from rich.text import Text
from rich import box from rich import box
from config import console, ENABLE_EMOTION_MODULE, ENABLE_COGNITION_MODULE, ENABLE_TIMING_MODULE, ENABLE_KNOWLEDGE_MODULE, ENABLE_MCP from config import (
console,
ENABLE_EMOTION_MODULE,
ENABLE_COGNITION_MODULE,
ENABLE_TIMING_MODULE,
ENABLE_KNOWLEDGE_MODULE,
ENABLE_MCP,
)
from input_reader import InputReader from input_reader import InputReader
from timing import build_timing_info from timing import build_timing_info
from knowledge import store_knowledge_from_context, retrieve_relevant_knowledge, build_knowledge_summary from knowledge import store_knowledge_from_context, retrieve_relevant_knowledge
from knowledge_store import get_knowledge_store from knowledge_store import get_knowledge_store
from llm_service import MaiSakaLLMService, build_message, remove_last_perception from llm_service import MaiSakaLLMService, build_message, remove_last_perception
from mcp_client import MCPManager from mcp_client import MCPManager
@@ -64,11 +71,7 @@ class BufferCLI:
def _init_llm(self): def _init_llm(self):
"""初始化 LLM 服务 - 使用主项目配置系统""" """初始化 LLM 服务 - 使用主项目配置系统"""
thinking_env = os.getenv("ENABLE_THINKING", "").strip().lower() thinking_env = os.getenv("ENABLE_THINKING", "").strip().lower()
enable_thinking: Optional[bool] = ( enable_thinking: Optional[bool] = True if thinking_env == "true" else False if thinking_env == "false" else None
True if thinking_env == "true"
else False if thinking_env == "false"
else None
)
# MaiSakaLLMService 现在使用主项目的配置系统 # MaiSakaLLMService 现在使用主项目的配置系统
# 参数仅为兼容性保留,实际从 config_manager 读取配置 # 参数仅为兼容性保留,实际从 config_manager 读取配置
@@ -210,7 +213,7 @@ class BufferCLI:
to_compress, to_compress,
store_result_callback=lambda cat_id, cat_name, content: console.print( store_result_callback=lambda cat_id, cat_name, content: console.print(
f"[muted] [OK] 存储了解信息: {cat_name}[/muted]" f"[muted] [OK] 存储了解信息: {cat_name}[/muted]"
) ),
) )
if knowledge_count > 0: if knowledge_count > 0:
console.print(f"[success][OK] 了解模块: 存储{knowledge_count}条特征信息[/success]") console.print(f"[success][OK] 了解模块: 存储{knowledge_count}条特征信息[/success]")
@@ -272,10 +275,12 @@ class BufferCLI:
self._chat_history = self.llm_service.build_chat_context(user_text) self._chat_history = self.llm_service.build_chat_context(user_text)
else: else:
# 后续对话:追加用户消息到已有上下文 # 后续对话:追加用户消息到已有上下文
self._chat_history.append({ self._chat_history.append(
"role": "user", {
"content": user_text, "role": "user",
}) "content": user_text,
}
)
await self._run_llm_loop(self._chat_history) await self._run_llm_loop(self._chat_history)
@@ -436,16 +441,17 @@ class BufferCLI:
if perception_parts: if perception_parts:
# 添加感知消息AI 的感知能力结果) # 添加感知消息AI 的感知能力结果)
chat_history.append(build_message( chat_history.append(
role="assistant", build_message(
content="\n\n".join(perception_parts), role="assistant",
msg_type="perception", content="\n\n".join(perception_parts),
)) msg_type="perception",
)
)
else: else:
# 上次没有调用工具,跳过模块分析 # 上次没有调用工具,跳过模块分析
console.print("[muted] 上次未调用工具,跳过模块分析[/muted]") console.print("[muted] 上次未调用工具,跳过模块分析[/muted]")
# ── 调用 LLM ── # ── 调用 LLM ──
with console.status("[info]💬 AI 正在思考...[/info]", spinner="dots"): with console.status("[info]💬 AI 正在思考...[/info]", spinner="dots"):
try: try:
@@ -540,7 +546,8 @@ class BufferCLI:
async def _init_mcp(self): async def _init_mcp(self):
"""初始化 MCP 服务器连接,发现并注册外部工具。""" """初始化 MCP 服务器连接,发现并注册外部工具。"""
config_path = os.path.join( config_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "mcp_config.json", os.path.dirname(os.path.abspath(__file__)),
"mcp_config.json",
) )
self._mcp_manager = await MCPManager.from_config(config_path) self._mcp_manager = await MCPManager.from_config(config_path)

View File

@@ -15,16 +15,20 @@ if str(_root) not in sys.path:
# ──────────────────── 从主配置读取 ──────────────────── # ──────────────────── 从主配置读取 ────────────────────
def _get_maisaka_config(): def _get_maisaka_config():
"""获取 MaiSaka 配置""" """获取 MaiSaka 配置"""
try: try:
from src.config.config import config_manager from src.config.config import config_manager
return config_manager.config.maisaka return config_manager.config.maisaka
except Exception: except Exception:
# 如果配置加载失败,返回默认值 # 如果配置加载失败,返回默认值
from src.config.official_configs import MaiSakaConfig from src.config.official_configs import MaiSakaConfig
return MaiSakaConfig() return MaiSakaConfig()
_maisaka_config = _get_maisaka_config() _maisaka_config = _get_maisaka_config()
# ──────────────────── 模块开关配置 ──────────────────── # ──────────────────── 模块开关配置 ────────────────────

View File

@@ -48,9 +48,7 @@ class DebugViewer:
conn = socket.socket(socket.AF_INET, socket.SOCK_STREAM) conn = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
conn.connect(("127.0.0.1", self._port)) conn.connect(("127.0.0.1", self._port))
self._conn = conn self._conn = conn
console.print( console.print(f"[success]✓ 调试窗口已启动[/success] [muted](port {self._port})[/muted]")
f"[success]✓ 调试窗口已启动[/success] [muted](port {self._port})[/muted]"
)
return return
except ConnectionRefusedError: except ConnectionRefusedError:
conn.close() conn.close()

View File

@@ -16,10 +16,10 @@ from rich import box
console = Console() console = Console()
ROLE_STYLES = { ROLE_STYLES = {
"system": ("📋", "bold blue"), "system": ("📋", "bold blue"),
"user": ("👤", "bold green"), "user": ("👤", "bold green"),
"assistant": ("🤖", "bold magenta"), "assistant": ("🤖", "bold magenta"),
"tool": ("🔧", "bold yellow"), "tool": ("🔧", "bold yellow"),
} }
@@ -54,8 +54,10 @@ def format_message(idx: int, msg: dict) -> str:
# 正文 # 正文
if content: if content:
display = content if len(content) <= 3000 else ( display = (
content[:3000] + f"\n[dim]... (截断, 共 {len(content)} 字符)[/dim]" content
if len(content) <= 3000
else (content[:3000] + f"\n[dim]... (截断, 共 {len(content)} 字符)[/dim]")
) )
parts.append(display) parts.append(display)
@@ -88,8 +90,7 @@ def main():
console.print( console.print(
Panel( Panel(
f"[bold cyan]MaiSaka Debug Viewer[/bold cyan]\n" f"[bold cyan]MaiSaka Debug Viewer[/bold cyan]\n[dim]监听端口: {port} 等待主进程连接...[/dim]",
f"[dim]监听端口: {port} 等待主进程连接...[/dim]",
box=box.DOUBLE_EDGE, box=box.DOUBLE_EDGE,
border_style="cyan", border_style="cyan",
) )
@@ -131,8 +132,7 @@ def main():
# ── 标题栏 ── # ── 标题栏 ──
console.print(f"\n{'' * 90}") console.print(f"\n{'' * 90}")
console.print( console.print(
f"[bold yellow]#{call_count} {label}[/bold yellow] " f"[bold yellow]#{call_count} {label}[/bold yellow] [dim]({len(messages)} messages)[/dim]"
f"[dim]({len(messages)} messages)[/dim]"
) )
console.print(f"{'' * 90}") console.print(f"{'' * 90}")
@@ -144,12 +144,8 @@ def main():
# ── tools 信息 ── # ── tools 信息 ──
if tools: if tools:
tool_names = [ tool_names = [t.get("function", {}).get("name", "?") for t in tools]
t.get("function", {}).get("name", "?") for t in tools console.print(f"\n[dim]可用工具: {', '.join(tool_names)}[/dim]")
]
console.print(
f"\n[dim]可用工具: {', '.join(tool_names)}[/dim]"
)
except Exception as e: except Exception as e:
console.print(f"\n[red]数据处理错误: {e}[/red]") console.print(f"\n[red]数据处理错误: {e}[/red]")
console.print(f"[dim]Payload: {payload}[/dim]") console.print(f"[dim]Payload: {payload}[/dim]")
@@ -161,8 +157,12 @@ def main():
console.print("\n[bold cyan]📤 LLM 响应:[/bold cyan]") console.print("\n[bold cyan]📤 LLM 响应:[/bold cyan]")
resp_content = response.get("content", "") resp_content = response.get("content", "")
if resp_content: if resp_content:
display = resp_content if len(str(resp_content)) <= 3000 else ( display = (
str(resp_content)[:3000] + f"\n[dim]... (截断, 共 {len(str(resp_content))} 字符)[/dim]" resp_content
if len(str(resp_content)) <= 3000
else (
str(resp_content)[:3000] + f"\n[dim]... (截断, 共 {len(str(resp_content))} 字符)[/dim]"
)
) )
console.print(Panel(display, border_style="cyan", padding=(0, 1))) console.print(Panel(display, border_style="cyan", padding=(0, 1)))
resp_tool_calls = response.get("tool_calls", []) resp_tool_calls = response.get("tool_calls", [])

View File

@@ -3,7 +3,7 @@ MaiSaka - 了解模块
负责从对话中提取和存储用户个人特征信息。 负责从对话中提取和存储用户个人特征信息。
""" """
from typing import List, Optional from typing import List
from knowledge_store import get_knowledge_store, KNOWLEDGE_CATEGORIES from knowledge_store import get_knowledge_store, KNOWLEDGE_CATEGORIES
@@ -100,9 +100,7 @@ async def store_knowledge_from_context(
try: try:
# 第一步:分析涉及哪些分类 # 第一步:分析涉及哪些分类
category_ids = await llm_service.analyze_knowledge_categories( category_ids = await llm_service.analyze_knowledge_categories(context_messages, categories_summary)
context_messages, categories_summary
)
if not category_ids: if not category_ids:
return 0 return 0
@@ -119,25 +117,19 @@ async def store_knowledge_from_context(
if extracted_content: if extracted_content:
# 存储到了解列表 # 存储到了解列表
success = store.add_knowledge( success = store.add_knowledge(
category_id=category_id, category_id=category_id, content=extracted_content, metadata={"source": "context_compression"}
content=extracted_content,
metadata={"source": "context_compression"}
) )
if success: if success:
stored_count += 1 stored_count += 1
if store_result_callback: if store_result_callback:
store_result_callback( store_result_callback(category_id, store.get_category_name(category_id), extracted_content)
category_id, except Exception:
store.get_category_name(category_id),
extracted_content
)
except Exception as e:
# 单个分类失败不影响其他分类 # 单个分类失败不影响其他分类
continue continue
return stored_count return stored_count
except Exception as e: except Exception:
return 0 return 0
@@ -165,9 +157,7 @@ async def retrieve_relevant_knowledge(
try: try:
# 分析需要哪些分类 # 分析需要哪些分类
category_ids = await llm_service.analyze_knowledge_need( category_ids = await llm_service.analyze_knowledge_need(chat_history, categories_summary)
chat_history, categories_summary
)
if not category_ids: if not category_ids:
return "" return ""

View File

@@ -45,9 +45,7 @@ class KnowledgeStore:
def __init__(self): def __init__(self):
"""初始化了解存储""" """初始化了解存储"""
self._knowledge: Dict[str, List[Dict[str, Any]]] = { self._knowledge: Dict[str, List[Dict[str, Any]]] = {category_id: [] for category_id in KNOWLEDGE_CATEGORIES}
category_id: [] for category_id in KNOWLEDGE_CATEGORIES
}
self._ensure_data_dir() self._ensure_data_dir()
self._load() self._load()
@@ -58,9 +56,7 @@ class KnowledgeStore:
def _load(self): def _load(self):
"""从文件加载了解数据""" """从文件加载了解数据"""
if not KNOWLEDGE_FILE.exists(): if not KNOWLEDGE_FILE.exists():
self._knowledge = { self._knowledge = {category_id: [] for category_id in KNOWLEDGE_CATEGORIES}
category_id: [] for category_id in KNOWLEDGE_CATEGORIES
}
return return
try: try:
@@ -73,9 +69,7 @@ class KnowledgeStore:
self._knowledge = loaded self._knowledge = loaded
except Exception as e: except Exception as e:
print(f"[warning]加载了解数据失败: {e}[/warning]") print(f"[warning]加载了解数据失败: {e}[/warning]")
self._knowledge = { self._knowledge = {category_id: [] for category_id in KNOWLEDGE_CATEGORIES}
category_id: [] for category_id in KNOWLEDGE_CATEGORIES
}
def _save(self): def _save(self):
"""保存了解数据到文件""" """保存了解数据到文件"""

View File

@@ -4,7 +4,6 @@ MaiSaka LLM 服务 - 使用主项目 LLM 系统
""" """
import json import json
import os
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Optional, Literal from typing import List, Optional, Literal
@@ -34,6 +33,7 @@ MSG_TYPE_FIELD = "_type"
@dataclass @dataclass
class ToolCall: class ToolCall:
"""工具调用信息""" """工具调用信息"""
id: str id: str
name: str name: str
arguments: dict arguments: dict
@@ -42,6 +42,7 @@ class ToolCall:
@dataclass @dataclass
class ChatResponse: class ChatResponse:
"""LLM 对话循环单步响应""" """LLM 对话循环单步响应"""
content: Optional[str] content: Optional[str]
tool_calls: List[ToolCall] tool_calls: List[ToolCall]
raw_message: dict # 可直接追加到对话历史的消息字典 raw_message: dict # 可直接追加到对话历史的消息字典
@@ -49,6 +50,7 @@ class ChatResponse:
# ──────────────────── 工具函数 ──────────────────── # ──────────────────── 工具函数 ────────────────────
def build_message(role: str, content: str, msg_type: MessageType = "user", **kwargs) -> dict: def build_message(role: str, content: str, msg_type: MessageType = "user", **kwargs) -> dict:
"""构建消息字典,包含消息类型标记。""" """构建消息字典,包含消息类型标记。"""
msg = {"role": role, "content": content, MSG_TYPE_FIELD: msg_type, **kwargs} msg = {"role": role, "content": content, MSG_TYPE_FIELD: msg_type, **kwargs}
@@ -93,23 +95,18 @@ class MaiSakaLLMService:
except Exception: except Exception:
# 如果配置加载失败,使用默认配置 # 如果配置加载失败,使用默认配置
from src.config.model_configs import ModelTaskConfig from src.config.model_configs import ModelTaskConfig
self._model_configs = ModelTaskConfig() self._model_configs = ModelTaskConfig()
logger.warning("无法加载主项目模型配置,使用默认配置") logger.warning("无法加载主项目模型配置,使用默认配置")
# 初始化 LLMRequest 实例(只使用 tool_use 和 replyer # 初始化 LLMRequest 实例(只使用 tool_use 和 replyer
self._llm_tool_use = LLMRequest( self._llm_tool_use = LLMRequest(model_set=self._model_configs.tool_use, request_type="maisaka_tool_use")
model_set=self._model_configs.tool_use,
request_type="maisaka_tool_use"
)
# 主对话也使用 tool_use 模型(因为需要工具调用支持) # 主对话也使用 tool_use 模型(因为需要工具调用支持)
self._llm_chat = self._llm_tool_use self._llm_chat = self._llm_tool_use
# 分析模块也使用 tool_use 模型 # 分析模块也使用 tool_use 模型
self._llm_utils = self._llm_tool_use self._llm_utils = self._llm_tool_use
# 回复生成使用 replyer 模型 # 回复生成使用 replyer 模型
self._llm_replyer = LLMRequest( self._llm_replyer = LLMRequest(model_set=self._model_configs.replyer, request_type="maisaka_replyer")
model_set=self._model_configs.replyer,
request_type="maisaka_replyer"
)
# 尝试修复数据库 schema忽略错误 # 尝试修复数据库 schema忽略错误
self._try_fix_database_schema() self._try_fix_database_schema()
@@ -133,6 +130,7 @@ class MaiSakaLLMService:
chat_prompt.add_context("file_tools_section", tools_section if tools_section else "") chat_prompt.add_context("file_tools_section", tools_section if tools_section else "")
import asyncio import asyncio
loop = asyncio.new_event_loop() loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop) asyncio.set_event_loop(loop)
try: try:
@@ -147,7 +145,9 @@ class MaiSakaLLMService:
self._chat_system_prompt = chat_system_prompt self._chat_system_prompt = chat_system_prompt
# 获取模型名称用于显示 # 获取模型名称用于显示
self._model_name = self._model_configs.tool_use.model_list[0] if self._model_configs.tool_use.model_list else "未配置" self._model_name = (
self._model_configs.tool_use.model_list[0] if self._model_configs.tool_use.model_list else "未配置"
)
# 加载子模块提示词 # 加载子模块提示词
self._emotion_prompt: Optional[str] = None self._emotion_prompt: Optional[str] = None
@@ -157,21 +157,22 @@ class MaiSakaLLMService:
try: try:
import asyncio import asyncio
loop = asyncio.new_event_loop() loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop) asyncio.set_event_loop(loop)
try: try:
self._emotion_prompt = loop.run_until_complete(prompt_manager.render_prompt( self._emotion_prompt = loop.run_until_complete(
prompt_manager.get_prompt("maidairy_emotion") prompt_manager.render_prompt(prompt_manager.get_prompt("maidairy_emotion"))
)) )
self._cognition_prompt = loop.run_until_complete(prompt_manager.render_prompt( self._cognition_prompt = loop.run_until_complete(
prompt_manager.get_prompt("maidairy_cognition") prompt_manager.render_prompt(prompt_manager.get_prompt("maidairy_cognition"))
)) )
self._timing_prompt = loop.run_until_complete(prompt_manager.render_prompt( self._timing_prompt = loop.run_until_complete(
prompt_manager.get_prompt("maidairy_timing") prompt_manager.render_prompt(prompt_manager.get_prompt("maidairy_timing"))
)) )
self._context_summarize_prompt = loop.run_until_complete(prompt_manager.render_prompt( self._context_summarize_prompt = loop.run_until_complete(
prompt_manager.get_prompt("maidairy_context_summarize") prompt_manager.render_prompt(prompt_manager.get_prompt("maidairy_context_summarize"))
)) )
logger.info("成功加载 MaiSaka 子模块提示词") logger.info("成功加载 MaiSaka 子模块提示词")
finally: finally:
loop.close() loop.close()
@@ -191,9 +192,7 @@ class MaiSakaLLMService:
if "model_api_provider_name" not in columns: if "model_api_provider_name" not in columns:
# 添加缺失的列 # 添加缺失的列
session.execute(text( session.execute(text("ALTER TABLE llm_usage ADD COLUMN model_api_provider_name VARCHAR(255)"))
"ALTER TABLE llm_usage ADD COLUMN model_api_provider_name VARCHAR(255)"
))
session.commit() session.commit()
logger.info("数据库 schema 已修复:添加 model_api_provider_name 列") logger.info("数据库 schema 已修复:添加 model_api_provider_name 列")
except Exception: except Exception:
@@ -205,7 +204,7 @@ class MaiSakaLLMService:
self._extra_tools = list(tools) self._extra_tools = list(tools)
@staticmethod @staticmethod
def _tool_option_to_dict(tool: 'ToolOption') -> dict: def _tool_option_to_dict(tool: "ToolOption") -> dict:
"""将 ToolOption 对象转换为主项目期望的 dict 格式 """将 ToolOption 对象转换为主项目期望的 dict 格式
主项目的 _build_tool_options() 期望的格式: 主项目的 _build_tool_options() 期望的格式:
@@ -218,18 +217,8 @@ class MaiSakaLLMService:
params = [] params = []
if tool.params: if tool.params:
for param in tool.params: for param in tool.params:
params.append(( params.append((param.name, param.param_type, param.description, param.required, param.enum_values))
param.name, return {"name": tool.name, "description": tool.description, "parameters": params}
param.param_type,
param.description,
param.required,
param.enum_values
))
return {
"name": tool.name,
"description": tool.description,
"parameters": params
}
async def chat_loop_step(self, chat_history: List[dict]) -> ChatResponse: async def chat_loop_step(self, chat_history: List[dict]) -> ChatResponse:
"""执行对话循环的一步 - 使用 tool_use 模型""" """执行对话循环的一步 - 使用 tool_use 模型"""
@@ -271,11 +260,13 @@ class MaiSakaLLMService:
for tc in msg["tool_calls"]: for tc in msg["tool_calls"]:
tc_func = tc.get("function", {}) tc_func = tc.get("function", {})
# 主项目的 ToolCall: call_id, func_name, args # 主项目的 ToolCall: call_id, func_name, args
tool_calls_list.append(ToolCallOption( tool_calls_list.append(
call_id=tc.get("id", ""), ToolCallOption(
func_name=tc_func.get("name", ""), call_id=tc.get("id", ""),
args=json.loads(tc_func.get("arguments", "{}")) if tc_func.get("arguments") else {} func_name=tc_func.get("name", ""),
)) args=json.loads(tc_func.get("arguments", "{}")) if tc_func.get("arguments") else {},
)
)
builder.set_tool_calls(tool_calls_list) builder.set_tool_calls(tool_calls_list)
elif role == "tool" and "tool_call_id" in msg: elif role == "tool" and "tool_call_id" in msg:
builder.add_tool_call(msg["tool_call_id"]) builder.add_tool_call(msg["tool_call_id"])
@@ -290,15 +281,17 @@ class MaiSakaLLMService:
# 调用 LLM使用带消息的接口 # 调用 LLM使用带消息的接口
# 合并内置工具和额外工具(将 ToolOption 对象转换为 dict # 合并内置工具和额外工具(将 ToolOption 对象转换为 dict
all_tools = [self._tool_option_to_dict(t) for t in get_builtin_tools()] + (self._extra_tools if self._extra_tools else []) all_tools = [self._tool_option_to_dict(t) for t in get_builtin_tools()] + (
self._extra_tools if self._extra_tools else []
)
# 打印消息列表 # 打印消息列表
built_messages = message_factory(None) built_messages = message_factory(None)
print("\n" + "="*60) print("\n" + "=" * 60)
print("MaiSaka LLM Request - chat_loop_step:") print("MaiSaka LLM Request - chat_loop_step:")
for msg in built_messages: for msg in built_messages:
print(f" {msg}") print(f" {msg}")
print("="*60 + "\n") print("=" * 60 + "\n")
response, (reasoning, model, tool_calls) = await self._llm_chat.generate_response_with_message_async( response, (reasoning, model, tool_calls) = await self._llm_chat.generate_response_with_message_async(
message_factory=message_factory, message_factory=message_factory,
@@ -312,15 +305,17 @@ class MaiSakaLLMService:
if tool_calls: if tool_calls:
for tc in tool_calls: for tc in tool_calls:
# 主项目的 ToolCall 有 call_id, func_name, args # 主项目的 ToolCall 有 call_id, func_name, args
call_id = tc.call_id if hasattr(tc, 'call_id') else "" call_id = tc.call_id if hasattr(tc, "call_id") else ""
func_name = tc.func_name if hasattr(tc, 'func_name') else "" func_name = tc.func_name if hasattr(tc, "func_name") else ""
args = tc.args if hasattr(tc, 'args') else {} args = tc.args if hasattr(tc, "args") else {}
converted_tool_calls.append(ToolCall( converted_tool_calls.append(
id=call_id, ToolCall(
name=func_name, id=call_id,
arguments=args, name=func_name,
)) arguments=args,
)
)
# 构建原始消息格式MaiSaka 风格) # 构建原始消息格式MaiSaka 风格)
raw_message = {"role": "assistant", "content": response} raw_message = {"role": "assistant", "content": response}
@@ -394,10 +389,10 @@ class MaiSakaLLMService:
prompt = "\n".join(prompt_parts) prompt = "\n".join(prompt_parts)
print("\n" + "="*60) print("\n" + "=" * 60)
print("MaiSaka LLM Request - analyze_emotion:") print("MaiSaka LLM Request - analyze_emotion:")
print(f" {prompt}") print(f" {prompt}")
print("="*60 + "\n") print("=" * 60 + "\n")
try: try:
response, _ = await self._llm_utils.generate_response_async( response, _ = await self._llm_utils.generate_response_async(
@@ -428,10 +423,10 @@ class MaiSakaLLMService:
prompt = "\n".join(prompt_parts) prompt = "\n".join(prompt_parts)
print("\n" + "="*60) print("\n" + "=" * 60)
print("MaiSaka LLM Request - analyze_cognition:") print("MaiSaka LLM Request - analyze_cognition:")
print(f" {prompt}") print(f" {prompt}")
print("="*60 + "\n") print("=" * 60 + "\n")
try: try:
response, _ = await self._llm_utils.generate_response_async( response, _ = await self._llm_utils.generate_response_async(
@@ -463,10 +458,10 @@ class MaiSakaLLMService:
prompt = "\n".join(prompt_parts) prompt = "\n".join(prompt_parts)
print("\n" + "="*60) print("\n" + "=" * 60)
print("MaiSaka LLM Request - analyze_timing:") print("MaiSaka LLM Request - analyze_timing:")
print(f" {prompt}") print(f" {prompt}")
print("="*60 + "\n") print("=" * 60 + "\n")
try: try:
response, _ = await self._llm_utils.generate_response_async( response, _ = await self._llm_utils.generate_response_async(
@@ -498,10 +493,10 @@ class MaiSakaLLMService:
prompt = "\n".join(prompt_parts) prompt = "\n".join(prompt_parts)
print("\n" + "="*60) print("\n" + "=" * 60)
print("MaiSaka LLM Request - summarize_context:") print("MaiSaka LLM Request - summarize_context:")
print(f" {prompt}") print(f" {prompt}")
print("="*60 + "\n") print("=" * 60 + "\n")
try: try:
response, _ = await self._llm_utils.generate_response_async( response, _ = await self._llm_utils.generate_response_async(
@@ -529,8 +524,7 @@ class MaiSakaLLMService:
# 格式化对话历史 # 格式化对话历史
filtered_history = [ filtered_history = [
msg for msg in chat_history msg for msg in chat_history if msg.get("role") != "system" and msg.get("_type") != "perception"
if msg.get("role") != "system" and msg.get("_type") != "perception"
] ]
formatted_history = format_chat_history(filtered_history) formatted_history = format_chat_history(filtered_history)
@@ -542,18 +536,15 @@ class MaiSakaLLMService:
system_prompt = "你是一个友好的 AI 助手,请根据用户的想法生成自然的回复。" system_prompt = "你是一个友好的 AI 助手,请根据用户的想法生成自然的回复。"
user_prompt = ( user_prompt = (
f"当前时间:{current_time}\n\n" f"当前时间:{current_time}\n\n【聊天记录】\n{formatted_history}\n\n【你的想法】\n{reason}\n\n现在,你说:"
f"【聊天记录】\n{formatted_history}\n\n"
f"【你的想法】\n{reason}\n\n"
f"现在,你说:"
) )
messages = f"System: {system_prompt}\n\nUser: {user_prompt}" messages = f"System: {system_prompt}\n\nUser: {user_prompt}"
print("\n" + "="*60) print("\n" + "=" * 60)
print("MaiSaka LLM Request - generate_reply:") print("MaiSaka LLM Request - generate_reply:")
print(f" {messages}") print(f" {messages}")
print("="*60 + "\n") print("=" * 60 + "\n")
try: try:
response, _ = await self._llm_replyer.generate_response_async( response, _ = await self._llm_replyer.generate_response_async(

View File

@@ -95,9 +95,7 @@ def load_mcp_config(config_path: str = "mcp_config.json") -> list[MCPServerConfi
) )
if server.transport_type == "unknown": if server.transport_type == "unknown":
console.print( console.print(f"[warning]⚠️ MCP 服务器 '{name}' 缺少 command 或 url已跳过[/warning]")
f"[warning]⚠️ MCP 服务器 '{name}' 缺少 command 或 url已跳过[/warning]"
)
continue continue
configs.append(server) configs.append(server)

View File

@@ -63,9 +63,7 @@ class MCPConnection:
True 表示连接成功False 表示失败。 True 表示连接成功False 表示失败。
""" """
if not MCP_AVAILABLE: if not MCP_AVAILABLE:
console.print( console.print("[warning]⚠️ 未安装 mcp SDK请运行: pip install mcp[/warning]")
"[warning]⚠️ 未安装 mcp SDK请运行: pip install mcp[/warning]"
)
return False return False
try: try:
@@ -76,15 +74,11 @@ class MCPConnection:
elif self.config.transport_type == "sse": elif self.config.transport_type == "sse":
read_stream, write_stream = await self._connect_sse() read_stream, write_stream = await self._connect_sse()
else: else:
console.print( console.print(f"[warning]MCP '{self.config.name}': 未知传输类型[/warning]")
f"[warning]MCP '{self.config.name}': 未知传输类型[/warning]"
)
return False return False
# 创建并初始化 MCP 会话 # 创建并初始化 MCP 会话
self.session = await self._exit_stack.enter_async_context( self.session = await self._exit_stack.enter_async_context(ClientSession(read_stream, write_stream))
ClientSession(read_stream, write_stream)
)
await self.session.initialize() await self.session.initialize()
# 发现工具 # 发现工具
@@ -94,9 +88,7 @@ class MCPConnection:
return True return True
except Exception as e: except Exception as e:
console.print( console.print(f"[warning]⚠️ MCP 服务器 '{self.config.name}' 连接失败: {e}[/warning]")
f"[warning]⚠️ MCP 服务器 '{self.config.name}' 连接失败: {e}[/warning]"
)
await self.close() await self.close()
return False return False
@@ -107,19 +99,13 @@ class MCPConnection:
args=self.config.args, args=self.config.args,
env=self.config.env, env=self.config.env,
) )
return await self._exit_stack.enter_async_context( return await self._exit_stack.enter_async_context(stdio_client(params))
stdio_client(params)
)
async def _connect_sse(self): async def _connect_sse(self):
"""建立 SSE 传输连接。""" """建立 SSE 传输连接。"""
if not SSE_AVAILABLE: if not SSE_AVAILABLE:
raise ImportError( raise ImportError("SSE 传输需要额外依赖,请运行: pip install mcp[sse]")
"SSE 传输需要额外依赖,请运行: pip install mcp[sse]" return await self._exit_stack.enter_async_context(sse_client(url=self.config.url, headers=self.config.headers))
)
return await self._exit_stack.enter_async_context(
sse_client(url=self.config.url, headers=self.config.headers)
)
async def call_tool(self, tool_name: str, arguments: dict) -> str: async def call_tool(self, tool_name: str, arguments: dict) -> str:
""" """

View File

@@ -10,10 +10,16 @@ from .config import MCPServerConfig, load_mcp_config
from .connection import MCPConnection, MCP_AVAILABLE from .connection import MCPConnection, MCP_AVAILABLE
# 内置工具名称集合 —— MCP 工具不允许与这些名称冲突 # 内置工具名称集合 —— MCP 工具不允许与这些名称冲突
BUILTIN_TOOL_NAMES = frozenset({ BUILTIN_TOOL_NAMES = frozenset(
"say", "wait", "stop", {
"create_table", "list_tables", "view_table", "say",
}) "wait",
"stop",
"create_table",
"list_tables",
"view_table",
}
)
class MCPManager: class MCPManager:
@@ -35,7 +41,8 @@ class MCPManager:
@classmethod @classmethod
async def from_config( async def from_config(
cls, config_path: str = "mcp_config.json", cls,
config_path: str = "mcp_config.json",
) -> Optional["MCPManager"]: ) -> Optional["MCPManager"]:
""" """
从配置文件创建并初始化 MCPManager。 从配置文件创建并初始化 MCPManager。
@@ -51,10 +58,7 @@ class MCPManager:
return None return None
if not MCP_AVAILABLE: if not MCP_AVAILABLE:
console.print( console.print("[warning]⚠️ 发现 MCP 配置但未安装 mcp SDK请运行: pip install mcp[/warning]")
"[warning]⚠️ 发现 MCP 配置但未安装 mcp SDK"
"请运行: pip install mcp[/warning]"
)
return None return None
manager = cls() manager = cls()
@@ -85,8 +89,7 @@ class MCPManager:
if tool_name in BUILTIN_TOOL_NAMES: if tool_name in BUILTIN_TOOL_NAMES:
console.print( console.print(
f"[warning]⚠️ MCP 工具 '{tool_name}' " f"[warning]⚠️ MCP 工具 '{tool_name}' (来自 {cfg.name}) 与内置工具冲突,已跳过[/warning]"
f"(来自 {cfg.name}) 与内置工具冲突,已跳过[/warning]"
) )
continue continue
@@ -102,8 +105,7 @@ class MCPManager:
registered += 1 registered += 1
console.print( console.print(
f"[success]✓ MCP 服务器 '{cfg.name}' 已连接[/success] " f"[success]✓ MCP 服务器 '{cfg.name}' 已连接[/success] [muted]({registered} 个工具已注册)[/muted]"
f"[muted]({registered} 个工具已注册)[/muted]"
) )
# ──────── 工具发现 ──────── # ──────── 工具发现 ────────
@@ -134,17 +136,16 @@ class MCPManager:
# 移除 $schema 字段(部分 MCP 服务器会带上OpenAI 不接受) # 移除 $schema 字段(部分 MCP 服务器会带上OpenAI 不接受)
parameters.pop("$schema", None) parameters.pop("$schema", None)
tools.append({ tools.append(
"type": "function", {
"function": { "type": "function",
"name": tool.name, "function": {
"description": ( "name": tool.name,
tool.description "description": (tool.description or f"MCP tool from {server_name}"),
or f"MCP tool from {server_name}" "parameters": parameters,
), },
"parameters": parameters, }
}, )
})
return tools return tools
@@ -184,9 +185,9 @@ class MCPManager:
parts: list[str] = [] parts: list[str] = []
for server_name, conn in self._connections.items(): for server_name, conn in self._connections.items():
tool_names = [ tool_names = [
t.name for t in conn.tools t.name
if t.name in self._tool_to_server for t in conn.tools
and self._tool_to_server[t.name] == server_name if t.name in self._tool_to_server and self._tool_to_server[t.name] == server_name
] ]
if tool_names: if tool_names:
parts.append(f"{server_name}: {', '.join(tool_names)}") parts.append(f"{server_name}: {', '.join(tool_names)}")

View File

@@ -49,8 +49,7 @@ def build_timing_info(
if len(user_input_times) >= 2: if len(user_input_times) >= 2:
intervals = [ intervals = [
(user_input_times[i] - user_input_times[i - 1]).total_seconds() (user_input_times[i] - user_input_times[i - 1]).total_seconds() for i in range(1, len(user_input_times))
for i in range(1, len(user_input_times))
] ]
avg_interval = sum(intervals) / len(intervals) avg_interval = sum(intervals) / len(intervals)
parts.append(f"用户平均回复间隔: {int(avg_interval)}") parts.append(f"用户平均回复间隔: {int(avg_interval)}")

View File

@@ -4,7 +4,6 @@ MaiSaka - 工具调用处理器
""" """
import json as _json import json as _json
import asyncio
import os import os
from datetime import datetime from datetime import datetime
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Optional
@@ -92,27 +91,33 @@ async def handle_say(tc, chat_history: list, ctx: ToolHandlerContext):
) )
) )
# 生成的回复作为 tool 结果写入上下文 # 生成的回复作为 tool 结果写入上下文
chat_history.append({ chat_history.append(
"role": "tool", {
"tool_call_id": tc.id, "role": "tool",
"content": f"已向用户展示(实际输出):{reply}", "tool_call_id": tc.id,
}) "content": f"已向用户展示(实际输出):{reply}",
}
)
else: else:
chat_history.append({ chat_history.append(
"role": "tool", {
"tool_call_id": tc.id, "role": "tool",
"content": "reason 内容为空,未展示", "tool_call_id": tc.id,
}) "content": "reason 内容为空,未展示",
}
)
async def handle_stop(tc, chat_history: list): async def handle_stop(tc, chat_history: list):
"""处理 stop 工具:结束对话循环。""" """处理 stop 工具:结束对话循环。"""
console.print("[accent]🔧 调用工具: stop()[/accent]") console.print("[accent]🔧 调用工具: stop()[/accent]")
chat_history.append({ chat_history.append(
"role": "tool", {
"tool_call_id": tc.id, "role": "tool",
"content": "对话循环已停止,等待用户下次输入。", "tool_call_id": tc.id,
}) "content": "对话循环已停止,等待用户下次输入。",
}
)
async def handle_wait(tc, chat_history: list, ctx: ToolHandlerContext) -> str: async def handle_wait(tc, chat_history: list, ctx: ToolHandlerContext) -> str:
@@ -128,11 +133,13 @@ async def handle_wait(tc, chat_history: list, ctx: ToolHandlerContext) -> str:
tool_result = await _do_wait(seconds, ctx) tool_result = await _do_wait(seconds, ctx)
chat_history.append({ chat_history.append(
"role": "tool", {
"tool_call_id": tc.id, "role": "tool",
"content": tool_result, "tool_call_id": tc.id,
}) "content": tool_result,
}
)
return tool_result return tool_result
@@ -193,28 +200,32 @@ async def handle_mcp_tool(tc, chat_history: list, mcp_manager: "MCPManager"):
) )
) )
chat_history.append({ chat_history.append(
"role": "tool", {
"tool_call_id": tc.id, "role": "tool",
"content": result, "tool_call_id": tc.id,
}) "content": result,
}
)
async def handle_unknown_tool(tc, chat_history: list): async def handle_unknown_tool(tc, chat_history: list):
"""处理未知工具调用。""" """处理未知工具调用。"""
console.print(f"[accent]🔧 调用工具: {tc.name}({tc.arguments})[/accent]") console.print(f"[accent]🔧 调用工具: {tc.name}({tc.arguments})[/accent]")
chat_history.append({ chat_history.append(
"role": "tool", {
"tool_call_id": tc.id, "role": "tool",
"content": f"未知工具: {tc.name}", "tool_call_id": tc.id,
}) "content": f"未知工具: {tc.name}",
}
)
async def handle_write_file(tc, chat_history: list): async def handle_write_file(tc, chat_history: list):
"""处理 write_file 工具:在 mai_files 目录下写入文件。""" """处理 write_file 工具:在 mai_files 目录下写入文件。"""
filename = tc.arguments.get("filename", "") filename = tc.arguments.get("filename", "")
content = tc.arguments.get("content", "") content = tc.arguments.get("content", "")
console.print(f"[accent]🔧 调用工具: write_file(\"{filename}\")[/accent]") console.print(f'[accent]🔧 调用工具: write_file("{filename}")[/accent]')
# 确保目录存在 # 确保目录存在
MAI_FILES_DIR.mkdir(parents=True, exist_ok=True) MAI_FILES_DIR.mkdir(parents=True, exist_ok=True)
@@ -242,25 +253,29 @@ async def handle_write_file(tc, chat_history: list):
) )
) )
chat_history.append({ chat_history.append(
"role": "tool", {
"tool_call_id": tc.id, "role": "tool",
"content": f"文件「{filename}」已成功写入,共 {file_size} 个字符。", "tool_call_id": tc.id,
}) "content": f"文件「{filename}」已成功写入,共 {file_size} 个字符。",
}
)
except Exception as e: except Exception as e:
error_msg = f"写入文件失败: {e}" error_msg = f"写入文件失败: {e}"
console.print(f"[error]{error_msg}[/error]") console.print(f"[error]{error_msg}[/error]")
chat_history.append({ chat_history.append(
"role": "tool", {
"tool_call_id": tc.id, "role": "tool",
"content": error_msg, "tool_call_id": tc.id,
}) "content": error_msg,
}
)
async def handle_read_file(tc, chat_history: list): async def handle_read_file(tc, chat_history: list):
"""处理 read_file 工具:读取 mai_files 目录下的文件。""" """处理 read_file 工具:读取 mai_files 目录下的文件。"""
filename = tc.arguments.get("filename", "") filename = tc.arguments.get("filename", "")
console.print(f"[accent]🔧 调用工具: read_file(\"{filename}\")[/accent]") console.print(f'[accent]🔧 调用工具: read_file("{filename}")[/accent]')
# 构建完整文件路径 # 构建完整文件路径
file_path = MAI_FILES_DIR / filename file_path = MAI_FILES_DIR / filename
@@ -269,21 +284,25 @@ async def handle_read_file(tc, chat_history: list):
if not file_path.exists(): if not file_path.exists():
error_msg = f"文件「{filename}」不存在。" error_msg = f"文件「{filename}」不存在。"
console.print(f"[warning]{error_msg}[/warning]") console.print(f"[warning]{error_msg}[/warning]")
chat_history.append({ chat_history.append(
"role": "tool", {
"tool_call_id": tc.id, "role": "tool",
"content": error_msg, "tool_call_id": tc.id,
}) "content": error_msg,
}
)
return return
if not file_path.is_file(): if not file_path.is_file():
error_msg = f"{filename}」不是一个文件。" error_msg = f"{filename}」不是一个文件。"
console.print(f"[warning]{error_msg}[/warning]") console.print(f"[warning]{error_msg}[/warning]")
chat_history.append({ chat_history.append(
"role": "tool", {
"tool_call_id": tc.id, "role": "tool",
"content": error_msg, "tool_call_id": tc.id,
}) "content": error_msg,
}
)
return return
# 读取文件内容 # 读取文件内容
@@ -304,19 +323,23 @@ async def handle_read_file(tc, chat_history: list):
) )
) )
chat_history.append({ chat_history.append(
"role": "tool", {
"tool_call_id": tc.id, "role": "tool",
"content": f"文件「{filename}」内容:\n{file_content}", "tool_call_id": tc.id,
}) "content": f"文件「{filename}」内容:\n{file_content}",
}
)
except Exception as e: except Exception as e:
error_msg = f"读取文件失败: {e}" error_msg = f"读取文件失败: {e}"
console.print(f"[error]{error_msg}[/error]") console.print(f"[error]{error_msg}[/error]")
chat_history.append({ chat_history.append(
"role": "tool", {
"tool_call_id": tc.id, "role": "tool",
"content": error_msg, "tool_call_id": tc.id,
}) "content": error_msg,
}
)
async def handle_list_files(tc, chat_history: list): async def handle_list_files(tc, chat_history: list):
@@ -334,11 +357,13 @@ async def handle_list_files(tc, chat_history: list):
# 获取相对路径 # 获取相对路径
rel_path = item.relative_to(MAI_FILES_DIR) rel_path = item.relative_to(MAI_FILES_DIR)
stat = item.stat() stat = item.stat()
files_info.append({ files_info.append(
"name": str(rel_path), {
"size": stat.st_size, "name": str(rel_path),
"modified": datetime.fromtimestamp(stat.st_mtime).strftime("%Y-%m-%d %H:%M:%S"), "size": stat.st_size,
}) "modified": datetime.fromtimestamp(stat.st_mtime).strftime("%Y-%m-%d %H:%M:%S"),
}
)
if not files_info: if not files_info:
result_text = "mai_files 目录为空,没有任何文件。" result_text = "mai_files 目录为空,没有任何文件。"
@@ -360,19 +385,23 @@ async def handle_list_files(tc, chat_history: list):
) )
) )
chat_history.append({ chat_history.append(
"role": "tool", {
"tool_call_id": tc.id, "role": "tool",
"content": result_text, "tool_call_id": tc.id,
}) "content": result_text,
}
)
except Exception as e: except Exception as e:
error_msg = f"获取文件列表失败: {e}" error_msg = f"获取文件列表失败: {e}"
console.print(f"[error]{error_msg}[/error]") console.print(f"[error]{error_msg}[/error]")
chat_history.append({ chat_history.append(
"role": "tool", {
"tool_call_id": tc.id, "role": "tool",
"content": error_msg, "tool_call_id": tc.id,
}) "content": error_msg,
}
)
async def handle_store_context(tc, chat_history: list, ctx: ToolHandlerContext): async def handle_store_context(tc, chat_history: list, ctx: ToolHandlerContext):
@@ -385,16 +414,18 @@ async def handle_store_context(tc, chat_history: list, ctx: ToolHandlerContext):
""" """
count = tc.arguments.get("count", 0) count = tc.arguments.get("count", 0)
reason = tc.arguments.get("reason", "") reason = tc.arguments.get("reason", "")
console.print(f"[accent]🔧 调用工具: store_context(count={count}, reason=\"{reason}\")[/accent]") console.print(f'[accent]🔧 调用工具: store_context(count={count}, reason="{reason}")[/accent]')
if count <= 0: if count <= 0:
error_msg = "count 参数必须大于 0" error_msg = "count 参数必须大于 0"
console.print(f"[error]{error_msg}[/error]") console.print(f"[error]{error_msg}[/error]")
chat_history.append({ chat_history.append(
"role": "tool", {
"tool_call_id": tc.id, "role": "tool",
"content": error_msg, "tool_call_id": tc.id,
}) "content": error_msg,
}
)
return return
# 计算实际消息数量(排除 role=tool 的工具返回消息) # 计算实际消息数量(排除 role=tool 的工具返回消息)
@@ -423,9 +454,7 @@ async def handle_store_context(tc, chat_history: list, ctx: ToolHandlerContext):
if role == "assistant" and "tool_calls" in msg: if role == "assistant" and "tool_calls" in msg:
# 检查这个消息是否包含当前的 tool_callstore_context 自己) # 检查这个消息是否包含当前的 tool_callstore_context 自己)
# 如果包含,跳过不删除(否则会导致 tool 响应孤儿) # 如果包含,跳过不删除(否则会导致 tool 响应孤儿)
contains_current_call = any( contains_current_call = any(tc.get("id") == tc.id for tc in msg.get("tool_calls", []))
tc.get("id") == tc.id for tc in msg.get("tool_calls", [])
)
if contains_current_call: if contains_current_call:
i += 1 i += 1
continue continue
@@ -453,11 +482,13 @@ async def handle_store_context(tc, chat_history: list, ctx: ToolHandlerContext):
if not indices_to_remove: if not indices_to_remove:
result_msg = "没有找到可存入记忆的消息" result_msg = "没有找到可存入记忆的消息"
chat_history.append({ chat_history.append(
"role": "tool", {
"tool_call_id": tc.id, "role": "tool",
"content": result_msg, "tool_call_id": tc.id,
}) "content": result_msg,
}
)
return return
# 收集要总结的消息(在删除前) # 收集要总结的消息(在删除前)
@@ -516,38 +547,45 @@ async def handle_store_context(tc, chat_history: list, ctx: ToolHandlerContext):
chat_history.pop(i) chat_history.pop(i)
i -= 1 i -= 1
chat_history.append({ chat_history.append(
"role": "tool", {
"tool_call_id": tc.id, "role": "tool",
"content": result_msg, "tool_call_id": tc.id,
}) "content": result_msg,
}
)
async def handle_get_qq_chat_info(tc, chat_history: list): async def handle_get_qq_chat_info(tc, chat_history: list):
"""处理 get_qq_chat_info 工具:通过 HTTP 获取 QQ 聊天内容。""" """处理 get_qq_chat_info 工具:通过 HTTP 获取 QQ 聊天内容。"""
chat = tc.arguments.get("chat", "") chat = tc.arguments.get("chat", "")
limit = tc.arguments.get("limit", 20) limit = tc.arguments.get("limit", 20)
console.print(f"[accent]🔧 调用工具: get_qq_chat_info(\"{chat}\", limit={limit})[/accent]") console.print(f'[accent]🔧 调用工具: get_qq_chat_info("{chat}", limit={limit})[/accent]')
if not AIOHTTP_AVAILABLE: if not AIOHTTP_AVAILABLE:
error_msg = "aiohttp 模块未安装,请运行: pip install aiohttp" error_msg = "aiohttp 模块未安装,请运行: pip install aiohttp"
console.print(f"[error]{error_msg}[/error]") console.print(f"[error]{error_msg}[/error]")
chat_history.append({ chat_history.append(
"role": "tool", {
"tool_call_id": tc.id, "role": "tool",
"content": error_msg, "tool_call_id": tc.id,
}) "content": error_msg,
}
)
return return
from config import QQ_API_BASE_URL, QQ_API_KEY from config import QQ_API_BASE_URL, QQ_API_KEY
if not QQ_API_BASE_URL: if not QQ_API_BASE_URL:
error_msg = "QQ_API_BASE_URL 未配置,请在 .env 中设置" error_msg = "QQ_API_BASE_URL 未配置,请在 .env 中设置"
console.print(f"[error]{error_msg}[/error]") console.print(f"[error]{error_msg}[/error]")
chat_history.append({ chat_history.append(
"role": "tool", {
"tool_call_id": tc.id, "role": "tool",
"content": error_msg, "tool_call_id": tc.id,
}) "content": error_msg,
}
)
return return
try: try:
@@ -577,55 +615,66 @@ async def handle_get_qq_chat_info(tc, chat_history: list):
) )
) )
chat_history.append({ chat_history.append(
"role": "tool", {
"tool_call_id": tc.id, "role": "tool",
"content": text if text.strip() else "暂无聊天记录", "tool_call_id": tc.id,
}) "content": text if text.strip() else "暂无聊天记录",
}
)
else: else:
error_text = await response.text() error_text = await response.text()
error_msg = f"HTTP 请求失败 (状态码 {response.status}): {error_text}" error_msg = f"HTTP 请求失败 (状态码 {response.status}): {error_text}"
console.print(f"[error]{error_msg}[/error]") console.print(f"[error]{error_msg}[/error]")
chat_history.append({ chat_history.append(
"role": "tool", {
"tool_call_id": tc.id, "role": "tool",
"content": error_msg, "tool_call_id": tc.id,
}) "content": error_msg,
}
)
except Exception as e: except Exception as e:
error_msg = f"获取 QQ 聊天记录失败: {e}" error_msg = f"获取 QQ 聊天记录失败: {e}"
console.print(f"[error]{error_msg}[/error]") console.print(f"[error]{error_msg}[/error]")
chat_history.append({ chat_history.append(
"role": "tool", {
"tool_call_id": tc.id, "role": "tool",
"content": error_msg, "tool_call_id": tc.id,
}) "content": error_msg,
}
)
async def handle_send_info(tc, chat_history: list): async def handle_send_info(tc, chat_history: list):
"""处理 send_info 工具:通过 HTTP 发送消息到 QQ。""" """处理 send_info 工具:通过 HTTP 发送消息到 QQ。"""
chat = tc.arguments.get("chat", "") chat = tc.arguments.get("chat", "")
message = tc.arguments.get("message", "") message = tc.arguments.get("message", "")
console.print(f"[accent]🔧 调用工具: send_info(\"{chat}\")[/accent]") console.print(f'[accent]🔧 调用工具: send_info("{chat}")[/accent]')
if not AIOHTTP_AVAILABLE: if not AIOHTTP_AVAILABLE:
error_msg = "aiohttp 模块未安装,请运行: pip install aiohttp" error_msg = "aiohttp 模块未安装,请运行: pip install aiohttp"
console.print(f"[error]{error_msg}[/error]") console.print(f"[error]{error_msg}[/error]")
chat_history.append({ chat_history.append(
"role": "tool", {
"tool_call_id": tc.id, "role": "tool",
"content": error_msg, "tool_call_id": tc.id,
}) "content": error_msg,
}
)
return return
from config import QQ_API_BASE_URL, QQ_API_KEY from config import QQ_API_BASE_URL, QQ_API_KEY
if not QQ_API_BASE_URL: if not QQ_API_BASE_URL:
error_msg = "QQ_API_BASE_URL 未配置,请在 .env 中设置" error_msg = "QQ_API_BASE_URL 未配置,请在 .env 中设置"
console.print(f"[error]{error_msg}[/error]") console.print(f"[error]{error_msg}[/error]")
chat_history.append({ chat_history.append(
"role": "tool", {
"tool_call_id": tc.id, "role": "tool",
"content": error_msg, "tool_call_id": tc.id,
}) "content": error_msg,
}
)
return return
try: try:
@@ -654,27 +703,33 @@ async def handle_send_info(tc, chat_history: list):
) )
) )
chat_history.append({ chat_history.append(
"role": "tool", {
"tool_call_id": tc.id, "role": "tool",
"content": f"消息发送成功: {data.get('message', '发送成功')}", "tool_call_id": tc.id,
}) "content": f"消息发送成功: {data.get('message', '发送成功')}",
}
)
else: else:
error_msg = f"发送失败: {data.get('message', '未知错误')}" error_msg = f"发送失败: {data.get('message', '未知错误')}"
console.print(f"[error]{error_msg}[/error]") console.print(f"[error]{error_msg}[/error]")
chat_history.append({ chat_history.append(
"role": "tool", {
"tool_call_id": tc.id, "role": "tool",
"content": error_msg, "tool_call_id": tc.id,
}) "content": error_msg,
}
)
except Exception as e: except Exception as e:
error_msg = f"发送消息失败: {e}" error_msg = f"发送消息失败: {e}"
console.print(f"[error]{error_msg}[/error]") console.print(f"[error]{error_msg}[/error]")
chat_history.append({ chat_history.append(
"role": "tool", {
"tool_call_id": tc.id, "role": "tool",
"content": error_msg, "tool_call_id": tc.id,
}) "content": error_msg,
}
)
async def handle_list_qq_chats(tc, chat_history: list): async def handle_list_qq_chats(tc, chat_history: list):
@@ -684,22 +739,27 @@ async def handle_list_qq_chats(tc, chat_history: list):
if not AIOHTTP_AVAILABLE: if not AIOHTTP_AVAILABLE:
error_msg = "aiohttp 模块未安装,请运行: pip install aiohttp" error_msg = "aiohttp 模块未安装,请运行: pip install aiohttp"
console.print(f"[error]{error_msg}[/error]") console.print(f"[error]{error_msg}[/error]")
chat_history.append({ chat_history.append(
"role": "tool", {
"tool_call_id": tc.id, "role": "tool",
"content": error_msg, "tool_call_id": tc.id,
}) "content": error_msg,
}
)
return return
from config import QQ_API_BASE_URL, QQ_API_KEY from config import QQ_API_BASE_URL, QQ_API_KEY
if not QQ_API_BASE_URL: if not QQ_API_BASE_URL:
error_msg = "QQ_API_BASE_URL 未配置,请在 .env 中设置" error_msg = "QQ_API_BASE_URL 未配置,请在 .env 中设置"
console.print(f"[error]{error_msg}[/error]") console.print(f"[error]{error_msg}[/error]")
chat_history.append({ chat_history.append(
"role": "tool", {
"tool_call_id": tc.id, "role": "tool",
"content": error_msg, "tool_call_id": tc.id,
}) "content": error_msg,
}
)
return return
try: try:
@@ -721,10 +781,12 @@ async def handle_list_qq_chats(tc, chat_history: list):
# 格式化聊天列表 # 格式化聊天列表
if chats: if chats:
chat_list_text = "\n".join([ chat_list_text = "\n".join(
f" • [{c.get('platform', 'qq')}] {c.get('name', '未知')} (chat: {c.get('chat', 'N/A')})" [
for c in chats f" • [{c.get('platform', 'qq')}] {c.get('name', '未知')} (chat: {c.get('chat', 'N/A')})"
]) for c in chats
]
)
result_text = f"可用的聊天 (共 {len(chats)} 个):\n{chat_list_text}" result_text = f"可用的聊天 (共 {len(chats)} 个):\n{chat_list_text}"
else: else:
result_text = "没有可用的聊天" result_text = "没有可用的聊天"
@@ -738,27 +800,33 @@ async def handle_list_qq_chats(tc, chat_history: list):
) )
) )
chat_history.append({ chat_history.append(
"role": "tool", {
"tool_call_id": tc.id, "role": "tool",
"content": result_text, "tool_call_id": tc.id,
}) "content": result_text,
}
)
else: else:
error_msg = f"获取失败: {data.get('message', '未知错误')}" error_msg = f"获取失败: {data.get('message', '未知错误')}"
console.print(f"[error]{error_msg}[/error]") console.print(f"[error]{error_msg}[/error]")
chat_history.append({ chat_history.append(
"role": "tool", {
"tool_call_id": tc.id, "role": "tool",
"content": error_msg, "tool_call_id": tc.id,
}) "content": error_msg,
}
)
except Exception as e: except Exception as e:
error_msg = f"获取聊天列表失败: {e}" error_msg = f"获取聊天列表失败: {e}"
console.print(f"[error]{error_msg}[/error]") console.print(f"[error]{error_msg}[/error]")
chat_history.append({ chat_history.append(
"role": "tool", {
"tool_call_id": tc.id, "role": "tool",
"content": error_msg, "tool_call_id": tc.id,
}) "content": error_msg,
}
)
# ──────────────────── 初始化 mai_files 目录 ──────────────────── # ──────────────────── 初始化 mai_files 目录 ────────────────────

View File

@@ -847,11 +847,7 @@ def _get_recent_found_answers(chat_id: str, time_window_seconds: float = 600.0)
if not records: if not records:
return [] return []
return [ return [f"问题:{record.question}\n答案:{record.answer}" for record in records if record.answer]
f"问题:{record.question}\n答案:{record.answer}"
for record in records
if record.answer
]
except Exception as e: except Exception as e:
logger.error(f"获取最近已找到答案的记录失败: {e}") logger.error(f"获取最近已找到答案的记录失败: {e}")

View File

@@ -13,4 +13,3 @@ ENV_SESSION_TOKEN = "MAIBOT_SESSION_TOKEN"
ENV_PLUGIN_DIRS = "MAIBOT_PLUGIN_DIRS" ENV_PLUGIN_DIRS = "MAIBOT_PLUGIN_DIRS"
"""Runner 需要加载的插件目录列表os.pathsep 分隔)""" """Runner 需要加载的插件目录列表os.pathsep 分隔)"""

View File

@@ -67,7 +67,9 @@ class CapabilityService:
# 1. 权限校验 # 1. 权限校验
allowed, reason = self._policy.check_capability(plugin_id, capability, envelope.generation) allowed, reason = self._policy.check_capability(plugin_id, capability, envelope.generation)
if not allowed: if not allowed:
error_code = ErrorCode.E_GENERATION_MISMATCH if "generation 不匹配" in reason else ErrorCode.E_CAPABILITY_DENIED error_code = (
ErrorCode.E_GENERATION_MISMATCH if "generation 不匹配" in reason else ErrorCode.E_CAPABILITY_DENIED
)
return envelope.make_error_response( return envelope.make_error_response(
error_code.value, error_code.value,
reason, reason,

View File

@@ -22,8 +22,13 @@ class RegisteredComponent:
"""已注册的组件条目""" """已注册的组件条目"""
__slots__ = ( __slots__ = (
"name", "full_name", "component_type", "plugin_id", "name",
"metadata", "enabled", "_compiled_pattern", "full_name",
"component_type",
"plugin_id",
"metadata",
"enabled",
"_compiled_pattern",
) )
def __init__( def __init__(
@@ -165,18 +170,14 @@ class ComponentRegistry:
"""按全名查询。""" """按全名查询。"""
return self._components.get(full_name) return self._components.get(full_name)
def get_components_by_type( def get_components_by_type(self, component_type: str, *, enabled_only: bool = True) -> List[RegisteredComponent]:
self, component_type: str, *, enabled_only: bool = True
) -> List[RegisteredComponent]:
"""按类型查询。""" """按类型查询。"""
type_dict = self._by_type.get(component_type, {}) type_dict = self._by_type.get(component_type, {})
if enabled_only: if enabled_only:
return [c for c in type_dict.values() if c.enabled] return [c for c in type_dict.values() if c.enabled]
return list(type_dict.values()) return list(type_dict.values())
def get_components_by_plugin( def get_components_by_plugin(self, plugin_id: str, *, enabled_only: bool = True) -> List[RegisteredComponent]:
self, plugin_id: str, *, enabled_only: bool = True
) -> List[RegisteredComponent]:
"""按插件查询。""" """按插件查询。"""
comps = self._by_plugin.get(plugin_id, []) comps = self._by_plugin.get(plugin_id, [])
return [c for c in comps if c.enabled] if enabled_only else list(comps) return [c for c in comps if c.enabled] if enabled_only else list(comps)
@@ -200,9 +201,7 @@ class ComponentRegistry:
return comp, {} return comp, {}
return None return None
def get_event_handlers( def get_event_handlers(self, event_type: str, *, enabled_only: bool = True) -> List[RegisteredComponent]:
self, event_type: str, *, enabled_only: bool = True
) -> List[RegisteredComponent]:
"""获取特定事件类型的所有 event_handler按 weight 降序排列。""" """获取特定事件类型的所有 event_handler按 weight 降序排列。"""
handlers = [] handlers = []
for comp in self._by_type.get("event_handler", {}).values(): for comp in self._by_type.get("event_handler", {}).values():
@@ -213,9 +212,7 @@ class ComponentRegistry:
handlers.sort(key=lambda c: c.metadata.get("weight", 0), reverse=True) handlers.sort(key=lambda c: c.metadata.get("weight", 0), reverse=True)
return handlers return handlers
def get_workflow_steps( def get_workflow_steps(self, stage: str, *, enabled_only: bool = True) -> List[RegisteredComponent]:
self, stage: str, *, enabled_only: bool = True
) -> List[RegisteredComponent]:
"""获取特定 workflow 阶段的所有步骤,按 priority 降序。""" """获取特定 workflow 阶段的所有步骤,按 priority 降序。"""
steps = [] steps = []
for comp in self._by_type.get("workflow_step", {}).values(): for comp in self._by_type.get("workflow_step", {}).values():

View File

@@ -22,6 +22,7 @@ InvokeFn = Callable[[str, str, Dict[str, Any]], Awaitable[Dict[str, Any]]]
class EventResult: class EventResult:
"""单个 EventHandler 的执行结果""" """单个 EventHandler 的执行结果"""
__slots__ = ("handler_name", "success", "continue_processing", "modified_message", "custom_result") __slots__ = ("handler_name", "success", "continue_processing", "modified_message", "custom_result")
def __init__( def __init__(
@@ -107,9 +108,7 @@ class EventDispatcher:
modified_message = result.modified_message modified_message = result.modified_message
else: else:
# 非阻塞:保持实例级强引用,防止 task 被 GC 回收 # 非阻塞:保持实例级强引用,防止 task 被 GC 回收
task = asyncio.create_task( task = asyncio.create_task(self._invoke_handler(invoke_fn, handler, args, event_type))
self._invoke_handler(invoke_fn, handler, args, event_type)
)
self._background_tasks.add(task) self._background_tasks.add(task)
task.add_done_callback(self._background_tasks.discard) task.add_done_callback(self._background_tasks.discard)

View File

@@ -11,6 +11,7 @@ from typing import Dict, List, Optional, Set, Tuple
@dataclass @dataclass
class CapabilityToken: class CapabilityToken:
"""能力令牌""" """能力令牌"""
plugin_id: str plugin_id: str
generation: int generation: int
capabilities: Set[str] = field(default_factory=set) capabilities: Set[str] = field(default_factory=set)

View File

@@ -231,9 +231,7 @@ class RPCServer:
stale_count = 0 stale_count = 0
for _req_id, future in list(self._pending_requests.items()): for _req_id, future in list(self._pending_requests.items()):
if not future.done(): if not future.done():
future.set_exception( future.set_exception(RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 连接已被新 generation 接管"))
RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 连接已被新 generation 接管")
)
stale_count += 1 stale_count += 1
self._pending_requests.clear() self._pending_requests.clear()
if stale_count: if stale_count:
@@ -399,9 +397,7 @@ class RPCServer:
result = await handler(envelope) result = await handler(envelope)
# 检查 handler 返回的信封是否包含错误信息 # 检查 handler 返回的信封是否包含错误信息
if result is not None and isinstance(result, Envelope) and result.error: if result is not None and isinstance(result, Envelope) and result.error:
logger.warning( logger.warning(f"事件 {envelope.method} handler 返回错误: {result.error.get('message', '')}")
f"事件 {envelope.method} handler 返回错误: {result.error.get('message', '')}"
)
except Exception as e: except Exception as e:
logger.error(f"处理事件 {envelope.method} 异常: {e}", exc_info=True) logger.error(f"处理事件 {envelope.method} 异常: {e}", exc_info=True)

View File

@@ -39,6 +39,7 @@ logger = get_logger("plugin_runtime.host.supervisor")
# ─── 日志桥 ────────────────────────────────────────────────────── # ─── 日志桥 ──────────────────────────────────────────────────────
class RunnerLogBridge: class RunnerLogBridge:
"""将 Runner 进程上报的批量日志重放到主进程的 Logger 中。 """将 Runner 进程上报的批量日志重放到主进程的 Logger 中。
@@ -80,9 +81,7 @@ class RunnerLogBridge:
stdlib_logging.getLogger(entry.logger_name).handle(record) stdlib_logging.getLogger(entry.logger_name).handle(record)
return envelope.make_response( return envelope.make_response(payload={"accepted": True, "count": len(batch.entries)})
payload={"accepted": True, "count": len(batch.entries)}
)
class PluginSupervisor: class PluginSupervisor:
@@ -101,8 +100,12 @@ class PluginSupervisor:
): ):
_cfg = global_config.plugin_runtime _cfg = global_config.plugin_runtime
self._plugin_dirs = plugin_dirs or [] self._plugin_dirs = plugin_dirs or []
self._health_interval = health_check_interval_sec if health_check_interval_sec is not None else _cfg.health_check_interval_sec self._health_interval = (
self._runner_spawn_timeout = runner_spawn_timeout_sec if runner_spawn_timeout_sec is not None else _cfg.runner_spawn_timeout_sec health_check_interval_sec if health_check_interval_sec is not None else _cfg.health_check_interval_sec
)
self._runner_spawn_timeout = (
runner_spawn_timeout_sec if runner_spawn_timeout_sec is not None else _cfg.runner_spawn_timeout_sec
)
# 基础设施 # 基础设施
self._transport = create_transport_server(socket_path=socket_path) self._transport = create_transport_server(socket_path=socket_path)
@@ -114,6 +117,7 @@ class PluginSupervisor:
# 编解码 # 编解码
from src.plugin_runtime.protocol.codec import MsgPackCodec from src.plugin_runtime.protocol.codec import MsgPackCodec
codec = MsgPackCodec() codec = MsgPackCodec()
self._rpc_server = RPCServer( self._rpc_server = RPCServer(
@@ -124,7 +128,9 @@ class PluginSupervisor:
# Runner 子进程 # Runner 子进程
self._runner_process: Optional[asyncio.subprocess.Process] = None self._runner_process: Optional[asyncio.subprocess.Process] = None
self._runner_generation: int = 0 self._runner_generation: int = 0
self._max_restart_attempts: int = max_restart_attempts if max_restart_attempts is not None else _cfg.max_restart_attempts self._max_restart_attempts: int = (
max_restart_attempts if max_restart_attempts is not None else _cfg.max_restart_attempts
)
self._restart_count: int = 0 self._restart_count: int = 0
# 已注册的插件组件信息 # 已注册的插件组件信息
@@ -173,6 +179,7 @@ class PluginSupervisor:
extra_args: Optional[Dict[str, Any]] = None, extra_args: Optional[Dict[str, Any]] = None,
) -> Tuple[bool, Optional[Dict[str, Any]]]: ) -> Tuple[bool, Optional[Dict[str, Any]]]:
"""分发事件到所有对应 handler 的快捷方法。""" """分发事件到所有对应 handler 的快捷方法。"""
async def _invoke(plugin_id: str, component_name: str, args: Dict[str, Any]) -> Dict[str, Any]: async def _invoke(plugin_id: str, component_name: str, args: Dict[str, Any]) -> Dict[str, Any]:
resp = await self.invoke_plugin( resp = await self.invoke_plugin(
method="plugin.emit_event", method="plugin.emit_event",
@@ -196,6 +203,7 @@ class PluginSupervisor:
context: Optional[WorkflowContext] = None, context: Optional[WorkflowContext] = None,
) -> Tuple[WorkflowResult, Optional[Dict[str, Any]], WorkflowContext]: ) -> Tuple[WorkflowResult, Optional[Dict[str, Any]], WorkflowContext]:
"""执行 Workflow Pipeline 的快捷方法。""" """执行 Workflow Pipeline 的快捷方法。"""
async def _invoke(plugin_id: str, component_name: str, args: Dict[str, Any]) -> Dict[str, Any]: async def _invoke(plugin_id: str, component_name: str, args: Dict[str, Any]) -> Dict[str, Any]:
resp = await self.invoke_plugin( resp = await self.invoke_plugin(
method="plugin.invoke_workflow_step", method="plugin.invoke_workflow_step",
@@ -415,7 +423,9 @@ class PluginSupervisor:
env[ENV_PLUGIN_DIRS] = os.pathsep.join(self._plugin_dirs) env[ENV_PLUGIN_DIRS] = os.pathsep.join(self._plugin_dirs)
self._runner_process = await asyncio.create_subprocess_exec( self._runner_process = await asyncio.create_subprocess_exec(
sys.executable, "-m", runner_module, sys.executable,
"-m",
runner_module,
env=env, env=env,
# stdout 不捕获Runner 的日志均通过 IPC 传㛹RunnerIPCLogHandler # stdout 不捕获Runner 的日志均通过 IPC 传㛹RunnerIPCLogHandler
stdout=None, stdout=None,
@@ -557,9 +567,7 @@ class PluginSupervisor:
) )
self._stderr_drain_task = task self._stderr_drain_task = task
task.add_done_callback( task.add_done_callback(
lambda done_task: None lambda done_task: None if self._stderr_drain_task is not done_task else self._clear_stderr_drain_task()
if self._stderr_drain_task is not done_task
else self._clear_stderr_drain_task()
) )
def _clear_stderr_drain_task(self) -> None: def _clear_stderr_drain_task(self) -> None:

View File

@@ -44,6 +44,7 @@ HOOK_CONTINUE = "continue"
HOOK_SKIP_STAGE = "skip_stage" HOOK_SKIP_STAGE = "skip_stage"
HOOK_ABORT = "abort" HOOK_ABORT = "abort"
# blocking hook 全局最大超时(秒):即使 hook 声明 timeout_ms=0 也不会无限等待 # blocking hook 全局最大超时(秒):即使 hook 声明 timeout_ms=0 也不会无限等待
# 从配置文件读取,允许用户调整 # 从配置文件读取,允许用户调整
def _get_blocking_timeout() -> float: def _get_blocking_timeout() -> float:
@@ -52,6 +53,7 @@ def _get_blocking_timeout() -> float:
class ModificationRecord: class ModificationRecord:
"""消息修改记录""" """消息修改记录"""
__slots__ = ("stage", "hook_name", "timestamp", "fields_changed") __slots__ = ("stage", "hook_name", "timestamp", "fields_changed")
def __init__(self, stage: str, hook_name: str, fields_changed: List[str]) -> None: def __init__(self, stage: str, hook_name: str, fields_changed: List[str]) -> None:
@@ -141,9 +143,7 @@ class WorkflowExecutor:
try: try:
# PLAN 阶段: 先做 Command 路由 # PLAN 阶段: 先做 Command 路由
if stage == "plan" and current_message: if stage == "plan" and current_message:
cmd_result = await self._route_command( cmd_result = await self._route_command(command_invoke_fn or invoke_fn, current_message, ctx)
command_invoke_fn or invoke_fn, current_message, ctx
)
if cmd_result is not None: if cmd_result is not None:
# 命令匹配成功,跳过 PLAN 阶段的 hook直接存结果进 stage_outputs # 命令匹配成功,跳过 PLAN 阶段的 hook直接存结果进 stage_outputs
ctx.set_stage_output("plan", "command_result", cmd_result) ctx.set_stage_output("plan", "command_result", cmd_result)
@@ -195,10 +195,10 @@ class WorkflowExecutor:
# 更新消息(仅 blocking hook 有权修改) # 更新消息(仅 blocking hook 有权修改)
if modified: if modified:
changed_fields = _diff_keys(current_message, modified) if current_message else list(modified.keys()) changed_fields = (
ctx.modification_log.append( _diff_keys(current_message, modified) if current_message else list(modified.keys())
ModificationRecord(stage, step.full_name, changed_fields)
) )
ctx.modification_log.append(ModificationRecord(stage, step.full_name, changed_fields))
current_message = modified current_message = modified
if hook_result == HOOK_ABORT: if hook_result == HOOK_ABORT:
@@ -222,9 +222,7 @@ class WorkflowExecutor:
if nonblocking_steps and not skip_stage: if nonblocking_steps and not skip_stage:
nb_tasks = [ nb_tasks = [
asyncio.create_task( asyncio.create_task(
self._invoke_step_fire_and_forget( self._invoke_step_fire_and_forget(invoke_fn, step, stage, ctx, current_message)
invoke_fn, step, stage, ctx, current_message
)
) )
for step in nonblocking_steps for step in nonblocking_steps
] ]
@@ -314,12 +312,16 @@ class WorkflowExecutor:
step_start = time.perf_counter() step_start = time.perf_counter()
try: try:
coro = invoke_fn(step.plugin_id, step.name, { coro = invoke_fn(
"stage": stage, step.plugin_id,
"trace_id": ctx.trace_id, step.name,
"message": message, {
"stage_outputs": ctx.stage_outputs, "stage": stage,
}) "trace_id": ctx.trace_id,
"message": message,
"stage_outputs": ctx.stage_outputs,
},
)
resp = await asyncio.wait_for(coro, timeout=timeout_sec) resp = await asyncio.wait_for(coro, timeout=timeout_sec)
ctx.timings[step_key] = time.perf_counter() - step_start ctx.timings[step_key] = time.perf_counter() - step_start
@@ -355,12 +357,16 @@ class WorkflowExecutor:
timeout_sec = timeout_ms / 1000 if timeout_ms > 0 else _get_blocking_timeout() timeout_sec = timeout_ms / 1000 if timeout_ms > 0 else _get_blocking_timeout()
try: try:
coro = invoke_fn(step.plugin_id, step.name, { coro = invoke_fn(
"stage": stage, step.plugin_id,
"trace_id": ctx.trace_id, step.name,
"message": message, {
"stage_outputs": ctx.stage_outputs, "stage": stage,
}) "trace_id": ctx.trace_id,
"message": message,
"stage_outputs": ctx.stage_outputs,
},
)
await asyncio.wait_for(coro, timeout=timeout_sec) await asyncio.wait_for(coro, timeout=timeout_sec)
except asyncio.TimeoutError: except asyncio.TimeoutError:
logger.warning(f"[{ctx.trace_id}] non-blocking hook {step.full_name} 超时 ({timeout_sec}s)") logger.warning(f"[{ctx.trace_id}] non-blocking hook {step.full_name} 超时 ({timeout_sec}s)")
@@ -393,12 +399,16 @@ class WorkflowExecutor:
logger.info(f"[{ctx.trace_id}] 命令匹配: {matched.full_name}") logger.info(f"[{ctx.trace_id}] 命令匹配: {matched.full_name}")
try: try:
return await invoke_fn(matched.plugin_id, matched.name, { return await invoke_fn(
"text": plain_text, matched.plugin_id,
"message": message, matched.name,
"trace_id": ctx.trace_id, {
"matched_groups": matched_groups, "text": plain_text,
}) "message": message,
"trace_id": ctx.trace_id,
"matched_groups": matched_groups,
},
)
except Exception as e: except Exception as e:
logger.error(f"[{ctx.trace_id}] 命令 {matched.full_name} 执行失败: {e}", exc_info=True) logger.error(f"[{ctx.trace_id}] 命令 {matched.full_name} 执行失败: {e}", exc_info=True)
ctx.errors.append(f"command:{matched.full_name}: {e}") ctx.errors.append(f"command:{matched.full_name}: {e}")

View File

@@ -113,9 +113,7 @@ class PluginRuntimeManager:
await self._thirdparty_supervisor.start() await self._thirdparty_supervisor.start()
started_supervisors.append(self._thirdparty_supervisor) started_supervisors.append(self._thirdparty_supervisor)
self._started = True self._started = True
logger.info( logger.info(f"插件运行时已启动 — 内置: {builtin_dirs or ''}, 第三方: {thirdparty_dirs or ''}")
f"插件运行时已启动 — 内置: {builtin_dirs or ''}, 第三方: {thirdparty_dirs or ''}"
)
except Exception as e: except Exception as e:
logger.error(f"插件运行时启动失败: {e}", exc_info=True) logger.error(f"插件运行时启动失败: {e}", exc_info=True)
await asyncio.gather(*(sv.stop() for sv in started_supervisors), return_exceptions=True) await asyncio.gather(*(sv.stop() for sv in started_supervisors), return_exceptions=True)
@@ -303,7 +301,9 @@ class PluginRuntimeManager:
cap_service.register_capability("component.get_all_plugins", self._cap_component_get_all_plugins) cap_service.register_capability("component.get_all_plugins", self._cap_component_get_all_plugins)
cap_service.register_capability("component.get_plugin_info", self._cap_component_get_plugin_info) cap_service.register_capability("component.get_plugin_info", self._cap_component_get_plugin_info)
cap_service.register_capability("component.list_loaded_plugins", self._cap_component_list_loaded_plugins) cap_service.register_capability("component.list_loaded_plugins", self._cap_component_list_loaded_plugins)
cap_service.register_capability("component.list_registered_plugins", self._cap_component_list_registered_plugins) cap_service.register_capability(
"component.list_registered_plugins", self._cap_component_list_registered_plugins
)
cap_service.register_capability("component.enable", self._cap_component_enable) cap_service.register_capability("component.enable", self._cap_component_enable)
cap_service.register_capability("component.disable", self._cap_component_disable) cap_service.register_capability("component.disable", self._cap_component_disable)
cap_service.register_capability("component.load_plugin", self._cap_component_load_plugin) cap_service.register_capability("component.load_plugin", self._cap_component_load_plugin)
@@ -1232,9 +1232,7 @@ class PluginRuntimeManager:
count: int = args.get("count", 1) count: int = args.get("count", 1)
try: try:
results = await emoji_api.get_random(count=count) results = await emoji_api.get_random(count=count)
emojis = [ emojis = [{"base64": b64, "description": desc, "emotion": emo} for b64, desc, emo in results]
{"base64": b64, "description": desc, "emotion": emo} for b64, desc, emo in results
]
return {"success": True, "emojis": emojis} return {"success": True, "emojis": emojis}
except Exception as e: except Exception as e:
logger.error(f"[cap.emoji.get_random] 执行失败: {e}", exc_info=True) logger.error(f"[cap.emoji.get_random] 执行失败: {e}", exc_info=True)
@@ -1269,9 +1267,9 @@ class PluginRuntimeManager:
try: try:
results = await emoji_api.get_all() results = await emoji_api.get_all()
emojis = [ emojis = (
{"base64": b64, "description": desc, "emotion": emo} for b64, desc, emo in results [{"base64": b64, "description": desc, "emotion": emo} for b64, desc, emo in results] if results else []
] if results else [] )
return {"success": True, "emojis": emojis} return {"success": True, "emojis": emojis}
except Exception as e: except Exception as e:
logger.error(f"[cap.emoji.get_all] 执行失败: {e}", exc_info=True) logger.error(f"[cap.emoji.get_all] 执行失败: {e}", exc_info=True)

View File

@@ -12,20 +12,16 @@ class Codec(ABC):
"""消息编解码器基类""" """消息编解码器基类"""
@abstractmethod @abstractmethod
def encode_envelope(self, envelope: Envelope) -> bytes: def encode_envelope(self, envelope: Envelope) -> bytes: ...
...
@abstractmethod @abstractmethod
def decode_envelope(self, data: bytes) -> Envelope: def decode_envelope(self, data: bytes) -> Envelope: ...
...
@abstractmethod @abstractmethod
def encode(self, obj: Dict[str, Any]) -> bytes: def encode(self, obj: Dict[str, Any]) -> bytes: ...
...
@abstractmethod @abstractmethod
def decode(self, data: bytes) -> Dict[str, Any]: def decode(self, data: bytes) -> Dict[str, Any]: ...
...
class MsgPackCodec(Codec): class MsgPackCodec(Codec):

View File

@@ -24,8 +24,10 @@ MAX_SDK_VERSION = "1.99.99"
# ─── 消息类型 ────────────────────────────────────────────────────── # ─── 消息类型 ──────────────────────────────────────────────────────
class MessageType(str, Enum): class MessageType(str, Enum):
"""RPC 消息类型""" """RPC 消息类型"""
REQUEST = "request" REQUEST = "request"
RESPONSE = "response" RESPONSE = "response"
EVENT = "event" EVENT = "event"
@@ -33,6 +35,7 @@ class MessageType(str, Enum):
# ─── 请求 ID 生成器 ─────────────────────────────────────────────── # ─── 请求 ID 生成器 ───────────────────────────────────────────────
class RequestIdGenerator: class RequestIdGenerator:
"""单调递增 int64 请求 ID 生成器(线程安全由调用方保证或使用 asyncio""" """单调递增 int64 请求 ID 生成器(线程安全由调用方保证或使用 asyncio"""
@@ -47,6 +50,7 @@ class RequestIdGenerator:
# ─── Envelope 模型 ───────────────────────────────────────────────── # ─── Envelope 模型 ─────────────────────────────────────────────────
class Envelope(BaseModel): class Envelope(BaseModel):
"""RPC 统一信封 """RPC 统一信封
@@ -75,7 +79,9 @@ class Envelope(BaseModel):
def is_event(self) -> bool: def is_event(self) -> bool:
return self.message_type == MessageType.EVENT return self.message_type == MessageType.EVENT
def make_response(self, payload: Optional[Dict[str, Any]] = None, error: Optional[Dict[str, Any]] = None) -> "Envelope": def make_response(
self, payload: Optional[Dict[str, Any]] = None, error: Optional[Dict[str, Any]] = None
) -> "Envelope":
"""基于当前请求创建对应的响应信封""" """基于当前请求创建对应的响应信封"""
return Envelope( return Envelope(
protocol_version=self.protocol_version, protocol_version=self.protocol_version,
@@ -101,8 +107,10 @@ class Envelope(BaseModel):
# ─── 握手消息 ────────────────────────────────────────────────────── # ─── 握手消息 ──────────────────────────────────────────────────────
class HelloPayload(BaseModel): class HelloPayload(BaseModel):
"""runner.hello 握手请求 payload""" """runner.hello 握手请求 payload"""
runner_id: str = Field(description="Runner 进程唯一标识") runner_id: str = Field(description="Runner 进程唯一标识")
sdk_version: str = Field(description="SDK 版本号") sdk_version: str = Field(description="SDK 版本号")
session_token: str = Field(description="一次性会话令牌") session_token: str = Field(description="一次性会话令牌")
@@ -110,6 +118,7 @@ class HelloPayload(BaseModel):
class HelloResponsePayload(BaseModel): class HelloResponsePayload(BaseModel):
"""runner.hello 握手响应 payload""" """runner.hello 握手响应 payload"""
accepted: bool = Field(description="是否接受连接") accepted: bool = Field(description="是否接受连接")
host_version: str = Field(default="", description="Host 版本号") host_version: str = Field(default="", description="Host 版本号")
assigned_generation: int = Field(default=0, description="分配的 generation 编号") assigned_generation: int = Field(default=0, description="分配的 generation 编号")
@@ -118,8 +127,10 @@ class HelloResponsePayload(BaseModel):
# ─── 组件注册消息 ────────────────────────────────────────────────── # ─── 组件注册消息 ──────────────────────────────────────────────────
class ComponentDeclaration(BaseModel): class ComponentDeclaration(BaseModel):
"""单个组件声明""" """单个组件声明"""
name: str = Field(description="组件名称") name: str = Field(description="组件名称")
component_type: str = Field(description="组件类型: action/command/tool/event_handler") component_type: str = Field(description="组件类型: action/command/tool/event_handler")
plugin_id: str = Field(description="所属插件 ID") plugin_id: str = Field(description="所属插件 ID")
@@ -128,6 +139,7 @@ class ComponentDeclaration(BaseModel):
class RegisterComponentsPayload(BaseModel): class RegisterComponentsPayload(BaseModel):
"""plugin.register_components 请求 payload""" """plugin.register_components 请求 payload"""
plugin_id: str = Field(description="插件 ID") plugin_id: str = Field(description="插件 ID")
plugin_version: str = Field(default="1.0.0", description="插件版本") plugin_version: str = Field(default="1.0.0", description="插件版本")
components: List[ComponentDeclaration] = Field(default_factory=list, description="组件列表") components: List[ComponentDeclaration] = Field(default_factory=list, description="组件列表")
@@ -136,36 +148,44 @@ class RegisterComponentsPayload(BaseModel):
# ─── 调用消息 ────────────────────────────────────────────────────── # ─── 调用消息 ──────────────────────────────────────────────────────
class InvokePayload(BaseModel): class InvokePayload(BaseModel):
"""plugin.invoke_* 请求 payload""" """plugin.invoke_* 请求 payload"""
component_name: str = Field(description="要调用的组件名称") component_name: str = Field(description="要调用的组件名称")
args: Dict[str, Any] = Field(default_factory=dict, description="调用参数") args: Dict[str, Any] = Field(default_factory=dict, description="调用参数")
class InvokeResultPayload(BaseModel): class InvokeResultPayload(BaseModel):
"""plugin.invoke_* 响应 payload""" """plugin.invoke_* 响应 payload"""
success: bool = Field(description="是否成功") success: bool = Field(description="是否成功")
result: Any = Field(default=None, description="返回值") result: Any = Field(default=None, description="返回值")
# ─── 能力调用消息 ────────────────────────────────────────────────── # ─── 能力调用消息 ──────────────────────────────────────────────────
class CapabilityRequestPayload(BaseModel): class CapabilityRequestPayload(BaseModel):
"""cap.* 请求 payload插件 -> Host 能力调用)""" """cap.* 请求 payload插件 -> Host 能力调用)"""
capability: str = Field(description="能力名称,如 send.text, db.query") capability: str = Field(description="能力名称,如 send.text, db.query")
args: Dict[str, Any] = Field(default_factory=dict, description="调用参数") args: Dict[str, Any] = Field(default_factory=dict, description="调用参数")
class CapabilityResponsePayload(BaseModel): class CapabilityResponsePayload(BaseModel):
"""cap.* 响应 payload""" """cap.* 响应 payload"""
success: bool = Field(description="是否成功") success: bool = Field(description="是否成功")
result: Any = Field(default=None, description="返回值") result: Any = Field(default=None, description="返回值")
# ─── 健康检查 ────────────────────────────────────────────────────── # ─── 健康检查 ──────────────────────────────────────────────────────
class HealthPayload(BaseModel): class HealthPayload(BaseModel):
"""plugin.health 响应 payload""" """plugin.health 响应 payload"""
healthy: bool = Field(description="是否健康") healthy: bool = Field(description="是否健康")
loaded_plugins: List[str] = Field(default_factory=list, description="已加载的插件列表") loaded_plugins: List[str] = Field(default_factory=list, description="已加载的插件列表")
uptime_ms: int = Field(default=0, description="运行时长(ms)") uptime_ms: int = Field(default=0, description="运行时长(ms)")
@@ -173,11 +193,13 @@ class HealthPayload(BaseModel):
# ─── 配置更新 ────────────────────────────────────────────────────── # ─── 配置更新 ──────────────────────────────────────────────────────
# TODO: Host 侧尚未实现配置变更检测与推送。Runner 端的 _handle_config_updated # TODO: Host 侧尚未实现配置变更检测与推送。Runner 端的 _handle_config_updated
# 已就绪,但当前无任何调用方通过 RPC 发送 plugin.config_updated 消息。 # 已就绪,但当前无任何调用方通过 RPC 发送 plugin.config_updated 消息。
# 需要在 Supervisor 或 CapabilityService 中监听配置文件变化并主动推送。 # 需要在 Supervisor 或 CapabilityService 中监听配置文件变化并主动推送。
class ConfigUpdatedPayload(BaseModel): class ConfigUpdatedPayload(BaseModel):
"""plugin.config_updated 事件 payload""" """plugin.config_updated 事件 payload"""
plugin_id: str = Field(description="插件 ID") plugin_id: str = Field(description="插件 ID")
config_version: str = Field(description="新配置版本") config_version: str = Field(description="新配置版本")
config_data: Dict[str, Any] = Field(default_factory=dict, description="配置内容") config_data: Dict[str, Any] = Field(default_factory=dict, description="配置内容")
@@ -185,14 +207,17 @@ class ConfigUpdatedPayload(BaseModel):
# ─── 关停 ────────────────────────────────────────────────────────── # ─── 关停 ──────────────────────────────────────────────────────────
class ShutdownPayload(BaseModel): class ShutdownPayload(BaseModel):
"""plugin.shutdown / plugin.prepare_shutdown payload""" """plugin.shutdown / plugin.prepare_shutdown payload"""
reason: str = Field(default="normal", description="关停原因") reason: str = Field(default="normal", description="关停原因")
drain_timeout_ms: int = Field(default=5000, description="排空超时(ms)") drain_timeout_ms: int = Field(default=5000, description="排空超时(ms)")
# ─── 日志传输 ────────────────────────────────────────────────────── # ─── 日志传输 ──────────────────────────────────────────────────────
class LogEntry(BaseModel): class LogEntry(BaseModel):
"""单条日志记录Runner → Host 传输格式)""" """单条日志记录Runner → Host 传输格式)"""
@@ -200,10 +225,7 @@ class LogEntry(BaseModel):
description="日志时间戳Unix epoch 毫秒", description="日志时间戳Unix epoch 毫秒",
) )
level: int = Field( level: int = Field(
description=( description=("stdlib logging 整数级别: 10=DEBUG, 20=INFO, 30=WARNING, 40=ERROR, 50=CRITICAL"),
"stdlib logging 整数级别:"
" 10=DEBUG, 20=INFO, 30=WARNING, 40=ERROR, 50=CRITICAL"
),
) )
logger_name: str = Field( logger_name: str = Field(
description="Logger 名称,如 plugin.my_plugin.submodule", description="Logger 名称,如 plugin.my_plugin.submodule",

View File

@@ -22,6 +22,7 @@ Host 端将其重放到主进程的 Logger以 plugin.<name> 为名)中,
- 后台刷新协程每 FLUSH_INTERVAL_SEC 秒或 FLUSH_BATCH_SIZE 条后批量发送 - 后台刷新协程每 FLUSH_INTERVAL_SEC 秒或 FLUSH_BATCH_SIZE 条后批量发送
- IPC 发送失败时静默忽略stderr fallback 由 supervisor 的 drain task 覆盖 - IPC 发送失败时静默忽略stderr fallback 由 supervisor 的 drain task 覆盖
""" """
from __future__ import annotations from __future__ import annotations
from typing import TYPE_CHECKING, List, Optional from typing import TYPE_CHECKING, List, Optional
@@ -203,6 +204,7 @@ class RunnerIPCLogHandler(logging.Handler):
# IPC 连接断开时回退到 stderr避免日志静默丢失 # IPC 连接断开时回退到 stderr避免日志静默丢失
if not self._rpc_client.is_connected: if not self._rpc_client.is_connected:
import sys import sys
for entry in entries: for entry in entries:
print( print(
f"[LOG-FALLBACK] [{entry.logger_name}] {entry.message}", f"[LOG-FALLBACK] [{entry.logger_name}] {entry.message}",
@@ -218,6 +220,7 @@ class RunnerIPCLogHandler(logging.Handler):
) )
except Exception: except Exception:
import sys import sys
for entry in entries: for entry in entries:
print( print(
f"[LOG-FALLBACK] [{entry.logger_name}] {entry.message}", f"[LOG-FALLBACK] [{entry.logger_name}] {entry.message}",

View File

@@ -105,9 +105,7 @@ class ManifestValidator:
def _check_manifest_version(self, manifest: Dict[str, Any]) -> None: def _check_manifest_version(self, manifest: Dict[str, Any]) -> None:
mv = manifest.get("manifest_version") mv = manifest.get("manifest_version")
if mv is not None and mv not in self.SUPPORTED_MANIFEST_VERSIONS: if mv is not None and mv not in self.SUPPORTED_MANIFEST_VERSIONS:
self.errors.append( self.errors.append(f"不支持的 manifest_version: {mv},支持: {self.SUPPORTED_MANIFEST_VERSIONS}")
f"不支持的 manifest_version: {mv},支持: {self.SUPPORTED_MANIFEST_VERSIONS}"
)
def _check_author(self, manifest: Dict[str, Any]) -> None: def _check_author(self, manifest: Dict[str, Any]) -> None:
author = manifest.get("author") author = manifest.get("author")

View File

@@ -240,8 +240,7 @@ class PluginLoader:
instance = self._try_load_legacy_plugin(module, plugin_id) instance = self._try_load_legacy_plugin(module, plugin_id)
if instance is not None: if instance is not None:
logger.info( logger.info(
f"插件 {plugin_id} v{manifest.get('version', '?')} " f"插件 {plugin_id} v{manifest.get('version', '?')} 通过旧版兼容层加载成功(请尽快迁移到 maibot_sdk"
f"通过旧版兼容层加载成功(请尽快迁移到 maibot_sdk"
) )
return PluginMeta( return PluginMeta(
plugin_id=plugin_id, plugin_id=plugin_id,
@@ -261,6 +260,7 @@ class PluginLoader:
return return
try: try:
from maibot_sdk.compat._import_hook import install_hook from maibot_sdk.compat._import_hook import install_hook
install_hook() install_hook()
self._compat_hook_installed = True self._compat_hook_installed = True
except ImportError: except ImportError:
@@ -281,11 +281,7 @@ class PluginLoader:
for attr_name in dir(module): for attr_name in dir(module):
obj = getattr(module, attr_name, None) obj = getattr(module, attr_name, None)
if ( if isinstance(obj, type) and issubclass(obj, LegacyBasePlugin) and obj is not LegacyBasePlugin:
isinstance(obj, type)
and issubclass(obj, LegacyBasePlugin)
and obj is not LegacyBasePlugin
):
legacy_cls = obj legacy_cls = obj
break break
@@ -294,6 +290,7 @@ class PluginLoader:
try: try:
from maibot_sdk.compat.legacy_adapter import LegacyPluginAdapter from maibot_sdk.compat.legacy_adapter import LegacyPluginAdapter
legacy_instance = legacy_cls() legacy_instance = legacy_cls()
return LegacyPluginAdapter(legacy_instance) return LegacyPluginAdapter(legacy_instance)
except Exception as e: except Exception as e:

View File

@@ -37,6 +37,7 @@ def _get_sdk_version() -> str:
"""从 maibot_sdk 包元数据中读取实际版本号,失败时回退到 1.0.0。""" """从 maibot_sdk 包元数据中读取实际版本号,失败时回退到 1.0.0。"""
try: try:
from importlib.metadata import version from importlib.metadata import version
return version("maibot-plugin-sdk") return version("maibot-plugin-sdk")
except Exception: except Exception:
return "1.0.0" return "1.0.0"

View File

@@ -133,7 +133,9 @@ class PluginRunner:
self._suspend_console_handlers() self._suspend_console_handlers()
stdlib_logging.root.addHandler(handler) stdlib_logging.root.addHandler(handler)
self._log_handler = handler self._log_handler = handler
logger.debug("RunnerIPCLogHandler \u5df2\u5b89\u88c3\uff0c\u63d2\u4ef6\u65e5\u5fd7\u5c06\u901a\u8fc7 IPC \u8f6c\u53d1\u5230\u4e3b\u8fdb\u7a0b") logger.debug(
"RunnerIPCLogHandler \u5df2\u5b89\u88c3\uff0c\u63d2\u4ef6\u65e5\u5fd7\u5c06\u901a\u8fc7 IPC \u8f6c\u53d1\u5230\u4e3b\u8fdb\u7a0b"
)
async def _uninstall_log_handler(self) -> None: async def _uninstall_log_handler(self) -> None:
"""关停前从 logging.root 移除 Handler 并刷空缓冲。 """关停前从 logging.root 移除 Handler 并刷空缓冲。
@@ -291,7 +293,11 @@ class PluginRunner:
) )
try: try:
result = await handler_method(**invoke.args) if inspect.iscoroutinefunction(handler_method) else handler_method(**invoke.args) result = (
await handler_method(**invoke.args)
if inspect.iscoroutinefunction(handler_method)
else handler_method(**invoke.args)
)
resp_payload = InvokeResultPayload(success=True, result=result) resp_payload = InvokeResultPayload(success=True, result=result)
return envelope.make_response(payload=resp_payload.model_dump()) return envelope.make_response(payload=resp_payload.model_dump())
except Exception as e: except Exception as e:
@@ -332,7 +338,11 @@ class PluginRunner:
) )
try: try:
raw = await handler_method(**invoke.args) if inspect.iscoroutinefunction(handler_method) else handler_method(**invoke.args) raw = (
await handler_method(**invoke.args)
if inspect.iscoroutinefunction(handler_method)
else handler_method(**invoke.args)
)
# 规范化返回值:将 EventHandler 返回展平到 payload 顶层 # 规范化返回值:将 EventHandler 返回展平到 payload 顶层
if raw is None: if raw is None:
@@ -341,7 +351,9 @@ class PluginRunner:
result = { result = {
"success": True, "success": True,
# 兼容 guide.md 中文档的 {"blocked": True} 写法 # 兼容 guide.md 中文档的 {"blocked": True} 写法
"continue_processing": not raw.get("blocked", False) if "blocked" in raw else raw.get("continue_processing", True), "continue_processing": not raw.get("blocked", False)
if "blocked" in raw
else raw.get("continue_processing", True),
"modified_message": raw.get("modified_message"), "modified_message": raw.get("modified_message"),
"custom_result": raw.get("custom_result"), "custom_result": raw.get("custom_result"),
} }
@@ -383,7 +395,11 @@ class PluginRunner:
) )
try: try:
raw = await handler_method(**invoke.args) if inspect.iscoroutinefunction(handler_method) else handler_method(**invoke.args) raw = (
await handler_method(**invoke.args)
if inspect.iscoroutinefunction(handler_method)
else handler_method(**invoke.args)
)
# 规范化返回值 # 规范化返回值
if isinstance(raw, str): if isinstance(raw, str):
@@ -455,6 +471,7 @@ class PluginRunner:
# ─── sys.path 隔离 ──────────────────────────────────────── # ─── sys.path 隔离 ────────────────────────────────────────
def _isolate_sys_path(plugin_dirs: List[str]) -> None: def _isolate_sys_path(plugin_dirs: List[str]) -> None:
"""清理 sys.path限制 Runner 子进程只能访问标准库、SDK 和插件目录。 """清理 sys.path限制 Runner 子进程只能访问标准库、SDK 和插件目录。
@@ -504,9 +521,7 @@ def _isolate_sys_path(plugin_dirs: List[str]) -> None:
return self if self._should_block(fullname) else None return self if self._should_block(fullname) else None
def load_module(self, fullname): def load_module(self, fullname):
raise ImportError( raise ImportError(f"Runner 子进程不允许导入主程序模块: {fullname}")
f"Runner 子进程不允许导入主程序模块: {fullname}"
)
def _should_block(self, fullname: str) -> bool: def _should_block(self, fullname: str) -> bool:
# 放行非 src.* 的导入、以及 "src" 本身 # 放行非 src.* 的导入、以及 "src" 本身
@@ -514,8 +529,7 @@ def _isolate_sys_path(plugin_dirs: List[str]) -> None:
return False return False
# 放行白名单前缀 # 放行白名单前缀
return not any( return not any(
fullname == prefix or fullname.startswith(f"{prefix}.") fullname == prefix or fullname.startswith(f"{prefix}.") for prefix in self._ALLOWED_SRC_PREFIXES
for prefix in self._ALLOWED_SRC_PREFIXES
) )
sys.meta_path.insert(0, _PluginImportBlocker()) sys.meta_path.insert(0, _PluginImportBlocker())
@@ -523,6 +537,7 @@ def _isolate_sys_path(plugin_dirs: List[str]) -> None:
# ─── 进程入口 ────────────────────────────────────────────── # ─── 进程入口 ──────────────────────────────────────────────
async def _async_main() -> None: async def _async_main() -> None:
"""异步主入口""" """异步主入口"""
host_address = os.environ.get(ENV_IPC_ADDRESS, "") host_address = os.environ.get(ENV_IPC_ADDRESS, "")

View File

@@ -20,6 +20,7 @@ MAX_FRAME_SIZE = 16 * 1024 * 1024 # 16 MB 最大帧大小
class ConnectionClosed(Exception): class ConnectionClosed(Exception):
"""连接已关闭""" """连接已关闭"""
pass pass

View File

@@ -20,10 +20,12 @@ def create_transport_server(socket_path: Optional[str] = None) -> TransportServe
""" """
if sys.platform != "win32": if sys.platform != "win32":
from .uds import UDSTransportServer from .uds import UDSTransportServer
return UDSTransportServer(socket_path=socket_path) return UDSTransportServer(socket_path=socket_path)
else: else:
# Windows 回退到 TCP后续可改为 Named Pipe # Windows 回退到 TCP后续可改为 Named Pipe
from .tcp import TCPTransportServer from .tcp import TCPTransportServer
return TCPTransportServer() return TCPTransportServer()
@@ -39,9 +41,11 @@ def create_transport_client(address: str) -> TransportClient:
""" """
if "/" in address or address.endswith(".sock"): if "/" in address or address.endswith(".sock"):
from .uds import UDSTransportClient from .uds import UDSTransportClient
return UDSTransportClient(socket_path=address) return UDSTransportClient(socket_path=address)
elif ":" in address: elif ":" in address:
from .tcp import TCPTransportClient from .tcp import TCPTransportClient
host, port_str = address.rsplit(":", 1) host, port_str = address.rsplit(":", 1)
return TCPTransportClient(host=host, port=int(port_str)) return TCPTransportClient(host=host, port=int(port_str))
else: else:

View File

@@ -13,6 +13,7 @@ from .base import Connection, ConnectionHandler, TransportClient, TransportServe
class TCPConnection(Connection): class TCPConnection(Connection):
"""基于 TCP 的连接""" """基于 TCP 的连接"""
pass pass

View File

@@ -15,6 +15,7 @@ from .base import Connection, ConnectionHandler, TransportClient, TransportServe
class UDSConnection(Connection): class UDSConnection(Connection):
"""基于 UDS 的连接""" """基于 UDS 的连接"""
pass # 直接复用 Connection 基类的分帧读写 pass # 直接复用 Connection 基类的分帧读写
@@ -30,16 +31,17 @@ class UDSTransportServer(TransportServer):
if socket_path is None: if socket_path is None:
# 默认放在临时目录,使用 uuid 确保同一进程多实例不碰撞 # 默认放在临时目录,使用 uuid 确保同一进程多实例不碰撞
import uuid import uuid
socket_path = os.path.join(tempfile.gettempdir(), f"maibot-plugin-{os.getpid()}-{uuid.uuid4().hex[:8]}.sock")
socket_path = os.path.join(
tempfile.gettempdir(), f"maibot-plugin-{os.getpid()}-{uuid.uuid4().hex[:8]}.sock"
)
# 如果路径超出 UDS 限制,回退到更短的路径 # 如果路径超出 UDS 限制,回退到更短的路径
if len(socket_path.encode()) > _UDS_PATH_MAX: if len(socket_path.encode()) > _UDS_PATH_MAX:
socket_path = os.path.join("/tmp", f"mb-{os.getpid()}-{uuid.uuid4().hex[:8]}.sock") socket_path = os.path.join("/tmp", f"mb-{os.getpid()}-{uuid.uuid4().hex[:8]}.sock")
if len(socket_path.encode()) > _UDS_PATH_MAX: if len(socket_path.encode()) > _UDS_PATH_MAX:
raise OSError( raise OSError(f"UDS socket 路径过长 ({len(socket_path.encode())} > {_UDS_PATH_MAX} 字节): {socket_path}")
f"UDS socket 路径过长 ({len(socket_path.encode())} > {_UDS_PATH_MAX} 字节): {socket_path}"
)
self._socket_path = socket_path self._socket_path = socket_path
self._server: Optional[asyncio.AbstractServer] = None self._server: Optional[asyncio.AbstractServer] = None

View File

@@ -14,8 +14,16 @@ class KnowledgePlugin(MaiBotPlugin):
"lpmm_search_knowledge", "lpmm_search_knowledge",
description="从知识库中搜索相关信息,如果你需要知识,就使用这个工具", description="从知识库中搜索相关信息,如果你需要知识,就使用这个工具",
parameters=[ parameters=[
ToolParameterInfo(name="query", param_type=ToolParamType.STRING, description="搜索查询关键词", required=True), ToolParameterInfo(
ToolParameterInfo(name="limit", param_type=ToolParamType.INTEGER, description="希望返回的相关知识条数默认5", required=False, default=5), name="query", param_type=ToolParamType.STRING, description="搜索查询关键词", required=True
),
ToolParameterInfo(
name="limit",
param_type=ToolParamType.INTEGER,
description="希望返回的相关知识条数默认5",
required=False,
default=5,
),
], ],
) )
async def handle_lpmm_search_knowledge(self, query: str = "", limit: int = 5, **kwargs): async def handle_lpmm_search_knowledge(self, query: str = "", limit: int = 5, **kwargs):

View File

@@ -231,9 +231,7 @@ class PluginManagementPlugin(MaiBotPlugin):
text = ", ".join(f"{c['name']} ({c['type']})" for c in filtered) text = ", ".join(f"{c['name']} ({c['type']})" for c in filtered)
await self.ctx.send.text(f"满足条件的{label}{scope_label}组件: {text}", stream_id) await self.ctx.send.text(f"满足条件的{label}{scope_label}组件: {text}", stream_id)
async def _handle_component_toggle( async def _handle_component_toggle(self, action: str, scope: str, comp_name: str, comp_type: str, stream_id: str):
self, action: str, scope: str, comp_name: str, comp_type: str, stream_id: str
):
if action not in ("enable", "disable"): if action not in ("enable", "disable"):
await self.ctx.send.text("插件管理命令不合法", stream_id) await self.ctx.send.text("插件管理命令不合法", stream_id)
return return
@@ -245,13 +243,9 @@ class PluginManagementPlugin(MaiBotPlugin):
return return
if action == "enable": if action == "enable":
result = await self.ctx.component.enable_component( result = await self.ctx.component.enable_component(comp_name, comp_type, scope=scope, stream_id=stream_id)
comp_name, comp_type, scope=scope, stream_id=stream_id
)
else: else:
result = await self.ctx.component.disable_component( result = await self.ctx.component.disable_component(comp_name, comp_type, scope=scope, stream_id=stream_id)
comp_name, comp_type, scope=scope, stream_id=stream_id
)
ok = result.get("success", False) if isinstance(result, dict) else bool(result) ok = result.get("success", False) if isinstance(result, dict) else bool(result)
scope_label = "全局" if scope == "global" else "本地" scope_label = "全局" if scope == "global" else "本地"

View File

@@ -142,13 +142,21 @@ class ChatManager:
if chat_stream.is_group_session: if chat_stream.is_group_session:
info["group_id"] = chat_stream.group_id info["group_id"] = chat_stream.group_id
if chat_stream.context and chat_stream.context.message and chat_stream.context.message.message_info.group_info: if (
chat_stream.context
and chat_stream.context.message
and chat_stream.context.message.message_info.group_info
):
info["group_name"] = chat_stream.context.message.message_info.group_info.group_name or "未知群聊" info["group_name"] = chat_stream.context.message.message_info.group_info.group_name or "未知群聊"
else: else:
info["group_name"] = "未知群聊" info["group_name"] = "未知群聊"
else: else:
info["user_id"] = chat_stream.user_id info["user_id"] = chat_stream.user_id
if chat_stream.context and chat_stream.context.message and chat_stream.context.message.message_info.user_info: if (
chat_stream.context
and chat_stream.context.message
and chat_stream.context.message.message_info.user_info
):
info["user_name"] = chat_stream.context.message.message_info.user_info.user_nickname info["user_name"] = chat_stream.context.message.message_info.user_info.user_nickname
else: else:
info["user_name"] = "未知用户" info["user_name"] = "未知用户"

View File

@@ -44,7 +44,9 @@ def get_replyer(
if not chat_id and not chat_stream: if not chat_id and not chat_stream:
raise ValueError("chat_stream 和 chat_id 不可均为空") raise ValueError("chat_stream 和 chat_id 不可均为空")
try: try:
logger.debug(f"[GeneratorService] 正在获取回复器chat_id: {chat_id}, chat_stream: {'' if chat_stream else ''}") logger.debug(
f"[GeneratorService] 正在获取回复器chat_id: {chat_id}, chat_stream: {'' if chat_stream else ''}"
)
return replyer_manager.get_replyer( return replyer_manager.get_replyer(
chat_stream=chat_stream, chat_stream=chat_stream,
chat_id=chat_id, chat_id=chat_id,

View File

@@ -70,10 +70,10 @@ async def _send_to_target(
if reply_message: if reply_message:
anchor_message = db_message_to_mai_message(reply_message) anchor_message = db_message_to_mai_message(reply_message)
if anchor_message: if anchor_message:
logger.debug(f"[SendService] 找到匹配的回复消息,发送者: {anchor_message.message_info.user_info.user_id}") logger.debug(
reply_to_platform_id = ( f"[SendService] 找到匹配的回复消息,发送者: {anchor_message.message_info.user_info.user_id}"
f"{anchor_message.platform}:{anchor_message.message_info.user_info.user_id}"
) )
reply_to_platform_id = f"{anchor_message.platform}:{anchor_message.message_info.user_info.user_id}"
sender_info = None sender_info = None
if target_stream.context and target_stream.context.message: if target_stream.context and target_stream.context.message:

View File

@@ -174,4 +174,4 @@ async def broadcast_log(log_data: dict):
# 清理断开的连接 # 清理断开的连接
if disconnected: if disconnected:
active_connections.difference_update(disconnected) active_connections.difference_update(disconnected)
logger.debug(f"清理了 {len(disconnected)} 个断开的 WebSocket 连接") logger.debug(f"清理了 {len(disconnected)} 个断开的 WebSocket 连接")

View File

@@ -551,9 +551,7 @@ async def get_expression_vibe(year: int = 2025) -> ExpressionVibeData:
try: try:
# 1. 表情包之王 - 使用次数最多的表情包 # 1. 表情包之王 - 使用次数最多的表情包
with get_db_session() as session: with get_db_session() as session:
statement = ( statement = select(Images).where(col(Images.is_registered)).order_by(desc(col(Images.query_count))).limit(5)
select(Images).where(col(Images.is_registered)).order_by(desc(col(Images.query_count))).limit(5)
)
top_emojis = session.exec(statement).all() top_emojis = session.exec(statement).all()
if top_emojis: if top_emojis:
data.top_emoji = { data.top_emoji = {

View File

@@ -314,17 +314,13 @@ async def get_jargon_stats():
with get_db_session() as session: with get_db_session() as session:
total = session.exec(select(fn.count()).select_from(Jargon)).one() total = session.exec(select(fn.count()).select_from(Jargon)).one()
confirmed_jargon = session.exec( confirmed_jargon = session.exec(select(fn.count()).select_from(Jargon).where(col(Jargon.is_jargon))).one()
select(fn.count()).select_from(Jargon).where(col(Jargon.is_jargon))
).one()
confirmed_not_jargon = session.exec( confirmed_not_jargon = session.exec(
select(fn.count()).select_from(Jargon).where(col(Jargon.is_jargon).is_(False)) select(fn.count()).select_from(Jargon).where(col(Jargon.is_jargon).is_(False))
).one() ).one()
pending = session.exec(select(fn.count()).select_from(Jargon).where(col(Jargon.is_jargon).is_(None))).one() pending = session.exec(select(fn.count()).select_from(Jargon).where(col(Jargon.is_jargon).is_(None))).one()
complete_count = session.exec( complete_count = session.exec(select(fn.count()).select_from(Jargon).where(col(Jargon.is_complete))).one()
select(fn.count()).select_from(Jargon).where(col(Jargon.is_complete))
).one()
chat_count = session.exec( chat_count = session.exec(
select(fn.count()).select_from( select(fn.count()).select_from(