Ruff Format
This commit is contained in:
@@ -14,9 +14,7 @@ class BetterFrequencyPlugin(MaiBotPlugin):
|
|||||||
description="设置当前聊天的talk_frequency值:/chat talk_frequency <数字> 或 /chat t <数字>",
|
description="设置当前聊天的talk_frequency值:/chat talk_frequency <数字> 或 /chat t <数字>",
|
||||||
pattern=r"^/chat\s+(?:talk_frequency|t)\s+(?P<value>[+-]?\d*\.?\d+)$",
|
pattern=r"^/chat\s+(?:talk_frequency|t)\s+(?P<value>[+-]?\d*\.?\d+)$",
|
||||||
)
|
)
|
||||||
async def handle_set_talk_frequency(
|
async def handle_set_talk_frequency(self, stream_id: str = "", matched_groups: dict | None = None, **kwargs):
|
||||||
self, stream_id: str = "", matched_groups: dict | None = None, **kwargs
|
|
||||||
):
|
|
||||||
"""设置当前聊天的 talk_frequency"""
|
"""设置当前聊天的 talk_frequency"""
|
||||||
if not matched_groups or "value" not in matched_groups:
|
if not matched_groups or "value" not in matched_groups:
|
||||||
return False, "命令格式错误", False
|
return False, "命令格式错误", False
|
||||||
|
|||||||
@@ -116,7 +116,9 @@ class HelloWorldPlugin(MaiBotPlugin):
|
|||||||
print(f"接收到消息: {raw}")
|
print(f"接收到消息: {raw}")
|
||||||
return True, True, "消息已打印", None, None
|
return True, True, "消息已打印", None, None
|
||||||
|
|
||||||
@EventHandler("forward_messages_handler", description="把接收到的消息转发到指定聊天ID", event_type=EventType.ON_MESSAGE)
|
@EventHandler(
|
||||||
|
"forward_messages_handler", description="把接收到的消息转发到指定聊天ID", event_type=EventType.ON_MESSAGE
|
||||||
|
)
|
||||||
async def handle_forward_messages(self, message=None, stream_id: str = "", **kwargs):
|
async def handle_forward_messages(self, message=None, stream_id: str = "", **kwargs):
|
||||||
"""收集消息并定期转发"""
|
"""收集消息并定期转发"""
|
||||||
if not message:
|
if not message:
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "packages", "ma
|
|||||||
|
|
||||||
# ─── 协议层测试 ───────────────────────────────────────────
|
# ─── 协议层测试 ───────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
class TestProtocol:
|
class TestProtocol:
|
||||||
"""协议层测试"""
|
"""协议层测试"""
|
||||||
|
|
||||||
@@ -111,6 +112,7 @@ class TestProtocol:
|
|||||||
|
|
||||||
# ─── 传输层测试 ───────────────────────────────────────────
|
# ─── 传输层测试 ───────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
class TestTransport:
|
class TestTransport:
|
||||||
"""传输层测试"""
|
"""传输层测试"""
|
||||||
|
|
||||||
@@ -198,6 +200,7 @@ class TestTransport:
|
|||||||
|
|
||||||
# ─── Host 层测试 ──────────────────────────────────────────
|
# ─── Host 层测试 ──────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
class TestHost:
|
class TestHost:
|
||||||
"""Host 端基础设施测试"""
|
"""Host 端基础设施测试"""
|
||||||
|
|
||||||
@@ -244,6 +247,7 @@ class TestHost:
|
|||||||
|
|
||||||
# ─── SDK 测试 ─────────────────────────────────────────────
|
# ─── SDK 测试 ─────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
class TestSDK:
|
class TestSDK:
|
||||||
"""SDK 框架测试"""
|
"""SDK 框架测试"""
|
||||||
|
|
||||||
@@ -312,6 +316,7 @@ class TestSDK:
|
|||||||
|
|
||||||
# ─── 端到端集成测试 ────────────────────────────────────────
|
# ─── 端到端集成测试 ────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
class TestE2E:
|
class TestE2E:
|
||||||
"""端到端集成测试(Host + Runner 通信)"""
|
"""端到端集成测试(Host + Runner 通信)"""
|
||||||
|
|
||||||
@@ -392,6 +397,7 @@ class TestE2E:
|
|||||||
|
|
||||||
# ─── Manifest 校验测试 ─────────────────────────────────────
|
# ─── Manifest 校验测试 ─────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
class TestManifestValidator:
|
class TestManifestValidator:
|
||||||
"""Manifest 校验器测试"""
|
"""Manifest 校验器测试"""
|
||||||
|
|
||||||
@@ -489,6 +495,7 @@ class TestVersionComparator:
|
|||||||
|
|
||||||
# ─── 依赖解析测试 ──────────────────────────────────────────
|
# ─── 依赖解析测试 ──────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
class TestDependencyResolution:
|
class TestDependencyResolution:
|
||||||
"""插件依赖解析测试"""
|
"""插件依赖解析测试"""
|
||||||
|
|
||||||
@@ -498,8 +505,16 @@ class TestDependencyResolution:
|
|||||||
loader = PluginLoader()
|
loader = PluginLoader()
|
||||||
candidates = {
|
candidates = {
|
||||||
"core": ("dir_core", {"name": "core", "version": "1.0", "description": "d", "author": "a"}, "plugin.py"),
|
"core": ("dir_core", {"name": "core", "version": "1.0", "description": "d", "author": "a"}, "plugin.py"),
|
||||||
"auth": ("dir_auth", {"name": "auth", "version": "1.0", "description": "d", "author": "a", "dependencies": ["core"]}, "plugin.py"),
|
"auth": (
|
||||||
"api": ("dir_api", {"name": "api", "version": "1.0", "description": "d", "author": "a", "dependencies": ["core", "auth"]}, "plugin.py"),
|
"dir_auth",
|
||||||
|
{"name": "auth", "version": "1.0", "description": "d", "author": "a", "dependencies": ["core"]},
|
||||||
|
"plugin.py",
|
||||||
|
),
|
||||||
|
"api": (
|
||||||
|
"dir_api",
|
||||||
|
{"name": "api", "version": "1.0", "description": "d", "author": "a", "dependencies": ["core", "auth"]},
|
||||||
|
"plugin.py",
|
||||||
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
order, failed = loader._resolve_dependencies(candidates)
|
order, failed = loader._resolve_dependencies(candidates)
|
||||||
@@ -512,7 +527,17 @@ class TestDependencyResolution:
|
|||||||
|
|
||||||
loader = PluginLoader()
|
loader = PluginLoader()
|
||||||
candidates = {
|
candidates = {
|
||||||
"plugin_a": ("dir_a", {"name": "plugin_a", "version": "1.0", "description": "d", "author": "a", "dependencies": ["nonexistent"]}, "plugin.py"),
|
"plugin_a": (
|
||||||
|
"dir_a",
|
||||||
|
{
|
||||||
|
"name": "plugin_a",
|
||||||
|
"version": "1.0",
|
||||||
|
"description": "d",
|
||||||
|
"author": "a",
|
||||||
|
"dependencies": ["nonexistent"],
|
||||||
|
},
|
||||||
|
"plugin.py",
|
||||||
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
order, failed = loader._resolve_dependencies(candidates)
|
order, failed = loader._resolve_dependencies(candidates)
|
||||||
@@ -524,8 +549,16 @@ class TestDependencyResolution:
|
|||||||
|
|
||||||
loader = PluginLoader()
|
loader = PluginLoader()
|
||||||
candidates = {
|
candidates = {
|
||||||
"a": ("dir_a", {"name": "a", "version": "1.0", "description": "d", "author": "x", "dependencies": ["b"]}, "p.py"),
|
"a": (
|
||||||
"b": ("dir_b", {"name": "b", "version": "1.0", "description": "d", "author": "x", "dependencies": ["a"]}, "p.py"),
|
"dir_a",
|
||||||
|
{"name": "a", "version": "1.0", "description": "d", "author": "x", "dependencies": ["b"]},
|
||||||
|
"p.py",
|
||||||
|
),
|
||||||
|
"b": (
|
||||||
|
"dir_b",
|
||||||
|
{"name": "b", "version": "1.0", "description": "d", "author": "x", "dependencies": ["a"]},
|
||||||
|
"p.py",
|
||||||
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
order, failed = loader._resolve_dependencies(candidates)
|
order, failed = loader._resolve_dependencies(candidates)
|
||||||
@@ -534,6 +567,7 @@ class TestDependencyResolution:
|
|||||||
|
|
||||||
# ─── Host-side ComponentRegistry 测试 ──────────────────────
|
# ─── Host-side ComponentRegistry 测试 ──────────────────────
|
||||||
|
|
||||||
|
|
||||||
class TestComponentRegistry:
|
class TestComponentRegistry:
|
||||||
"""Host-side 组件注册表测试"""
|
"""Host-side 组件注册表测试"""
|
||||||
|
|
||||||
@@ -541,17 +575,32 @@ class TestComponentRegistry:
|
|||||||
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
||||||
|
|
||||||
reg = ComponentRegistry()
|
reg = ComponentRegistry()
|
||||||
reg.register_component("greet", "action", "plugin_a", {
|
reg.register_component(
|
||||||
"description": "打招呼",
|
"greet",
|
||||||
"activation_type": "keyword",
|
"action",
|
||||||
"activation_keywords": ["hi"],
|
"plugin_a",
|
||||||
})
|
{
|
||||||
reg.register_component("help", "command", "plugin_a", {
|
"description": "打招呼",
|
||||||
"command_pattern": r"^/help",
|
"activation_type": "keyword",
|
||||||
})
|
"activation_keywords": ["hi"],
|
||||||
reg.register_component("search", "tool", "plugin_b", {
|
},
|
||||||
"description": "搜索",
|
)
|
||||||
})
|
reg.register_component(
|
||||||
|
"help",
|
||||||
|
"command",
|
||||||
|
"plugin_a",
|
||||||
|
{
|
||||||
|
"command_pattern": r"^/help",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
reg.register_component(
|
||||||
|
"search",
|
||||||
|
"tool",
|
||||||
|
"plugin_b",
|
||||||
|
{
|
||||||
|
"description": "搜索",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
stats = reg.get_stats()
|
stats = reg.get_stats()
|
||||||
assert stats["total"] == 3
|
assert stats["total"] == 3
|
||||||
@@ -573,12 +622,22 @@ class TestComponentRegistry:
|
|||||||
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
||||||
|
|
||||||
reg = ComponentRegistry()
|
reg = ComponentRegistry()
|
||||||
reg.register_component("help", "command", "p1", {
|
reg.register_component(
|
||||||
"command_pattern": r"^/help",
|
"help",
|
||||||
})
|
"command",
|
||||||
reg.register_component("echo", "command", "p1", {
|
"p1",
|
||||||
"command_pattern": r"^/echo\s",
|
{
|
||||||
})
|
"command_pattern": r"^/help",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
reg.register_component(
|
||||||
|
"echo",
|
||||||
|
"command",
|
||||||
|
"p1",
|
||||||
|
{
|
||||||
|
"command_pattern": r"^/echo\s",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
match = reg.find_command_by_text("/help me")
|
match = reg.find_command_by_text("/help me")
|
||||||
assert match is not None
|
assert match is not None
|
||||||
@@ -622,12 +681,24 @@ class TestComponentRegistry:
|
|||||||
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
||||||
|
|
||||||
reg = ComponentRegistry()
|
reg = ComponentRegistry()
|
||||||
reg.register_component("h_low", "event_handler", "p1", {
|
reg.register_component(
|
||||||
"event_type": "on_message", "weight": 10,
|
"h_low",
|
||||||
})
|
"event_handler",
|
||||||
reg.register_component("h_high", "event_handler", "p2", {
|
"p1",
|
||||||
"event_type": "on_message", "weight": 100,
|
{
|
||||||
})
|
"event_type": "on_message",
|
||||||
|
"weight": 10,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
reg.register_component(
|
||||||
|
"h_high",
|
||||||
|
"event_handler",
|
||||||
|
"p2",
|
||||||
|
{
|
||||||
|
"event_type": "on_message",
|
||||||
|
"weight": 100,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
handlers = reg.get_event_handlers("on_message")
|
handlers = reg.get_event_handlers("on_message")
|
||||||
assert handlers[0].name == "h_high"
|
assert handlers[0].name == "h_high"
|
||||||
@@ -637,10 +708,15 @@ class TestComponentRegistry:
|
|||||||
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
||||||
|
|
||||||
reg = ComponentRegistry()
|
reg = ComponentRegistry()
|
||||||
reg.register_component("search", "tool", "p1", {
|
reg.register_component(
|
||||||
"description": "搜索工具",
|
"search",
|
||||||
"parameters_raw": {"query": {"type": "string"}},
|
"tool",
|
||||||
})
|
"p1",
|
||||||
|
{
|
||||||
|
"description": "搜索工具",
|
||||||
|
"parameters_raw": {"query": {"type": "string"}},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
tools = reg.get_tools_for_llm()
|
tools = reg.get_tools_for_llm()
|
||||||
assert len(tools) == 1
|
assert len(tools) == 1
|
||||||
@@ -650,6 +726,7 @@ class TestComponentRegistry:
|
|||||||
|
|
||||||
# ─── EventDispatcher 测试 ─────────────────────────────────
|
# ─── EventDispatcher 测试 ─────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
class TestEventDispatcher:
|
class TestEventDispatcher:
|
||||||
"""Host-side 事件分发器测试"""
|
"""Host-side 事件分发器测试"""
|
||||||
|
|
||||||
@@ -659,11 +736,16 @@ class TestEventDispatcher:
|
|||||||
from src.plugin_runtime.host.event_dispatcher import EventDispatcher
|
from src.plugin_runtime.host.event_dispatcher import EventDispatcher
|
||||||
|
|
||||||
reg = ComponentRegistry()
|
reg = ComponentRegistry()
|
||||||
reg.register_component("h1", "event_handler", "p1", {
|
reg.register_component(
|
||||||
"event_type": "on_start",
|
"h1",
|
||||||
"weight": 0,
|
"event_handler",
|
||||||
"intercept_message": False,
|
"p1",
|
||||||
})
|
{
|
||||||
|
"event_type": "on_start",
|
||||||
|
"weight": 0,
|
||||||
|
"intercept_message": False,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
dispatcher = EventDispatcher(reg)
|
dispatcher = EventDispatcher(reg)
|
||||||
call_log = []
|
call_log = []
|
||||||
@@ -672,9 +754,7 @@ class TestEventDispatcher:
|
|||||||
call_log.append((plugin_id, comp_name))
|
call_log.append((plugin_id, comp_name))
|
||||||
return {"success": True, "continue_processing": True}
|
return {"success": True, "continue_processing": True}
|
||||||
|
|
||||||
should_continue, modified = await dispatcher.dispatch_event(
|
should_continue, modified = await dispatcher.dispatch_event("on_start", mock_invoke)
|
||||||
"on_start", mock_invoke
|
|
||||||
)
|
|
||||||
assert should_continue
|
assert should_continue
|
||||||
# 非阻塞分发是异步的,等一下让 task 完成
|
# 非阻塞分发是异步的,等一下让 task 完成
|
||||||
await asyncio.sleep(0.1)
|
await asyncio.sleep(0.1)
|
||||||
@@ -687,11 +767,16 @@ class TestEventDispatcher:
|
|||||||
from src.plugin_runtime.host.event_dispatcher import EventDispatcher
|
from src.plugin_runtime.host.event_dispatcher import EventDispatcher
|
||||||
|
|
||||||
reg = ComponentRegistry()
|
reg = ComponentRegistry()
|
||||||
reg.register_component("filter", "event_handler", "p1", {
|
reg.register_component(
|
||||||
"event_type": "on_message_pre_process",
|
"filter",
|
||||||
"weight": 100,
|
"event_handler",
|
||||||
"intercept_message": True,
|
"p1",
|
||||||
})
|
{
|
||||||
|
"event_type": "on_message_pre_process",
|
||||||
|
"weight": 100,
|
||||||
|
"intercept_message": True,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
dispatcher = EventDispatcher(reg)
|
dispatcher = EventDispatcher(reg)
|
||||||
|
|
||||||
@@ -755,6 +840,7 @@ class TestEventBus:
|
|||||||
|
|
||||||
# ─── MaiMessages 测试 ─────────────────────────────────────
|
# ─── MaiMessages 测试 ─────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
class TestMaiMessages:
|
class TestMaiMessages:
|
||||||
"""统一消息模型测试"""
|
"""统一消息模型测试"""
|
||||||
|
|
||||||
@@ -799,6 +885,7 @@ class TestMaiMessages:
|
|||||||
|
|
||||||
# ─── WorkflowExecutor 测试 ────────────────────────────────
|
# ─── WorkflowExecutor 测试 ────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
class TestWorkflowExecutor:
|
class TestWorkflowExecutor:
|
||||||
"""Host-side Workflow 执行器测试(新 pipeline 模型)"""
|
"""Host-side Workflow 执行器测试(新 pipeline 模型)"""
|
||||||
|
|
||||||
@@ -829,11 +916,16 @@ class TestWorkflowExecutor:
|
|||||||
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
|
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
|
||||||
|
|
||||||
reg = ComponentRegistry()
|
reg = ComponentRegistry()
|
||||||
reg.register_component("upper", "workflow_step", "p1", {
|
reg.register_component(
|
||||||
"stage": "pre_process",
|
"upper",
|
||||||
"priority": 10,
|
"workflow_step",
|
||||||
"blocking": True,
|
"p1",
|
||||||
})
|
{
|
||||||
|
"stage": "pre_process",
|
||||||
|
"priority": 10,
|
||||||
|
"blocking": True,
|
||||||
|
},
|
||||||
|
)
|
||||||
executor = WorkflowExecutor(reg)
|
executor = WorkflowExecutor(reg)
|
||||||
|
|
||||||
async def mock_invoke(plugin_id, comp_name, args):
|
async def mock_invoke(plugin_id, comp_name, args):
|
||||||
@@ -859,11 +951,16 @@ class TestWorkflowExecutor:
|
|||||||
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
|
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
|
||||||
|
|
||||||
reg = ComponentRegistry()
|
reg = ComponentRegistry()
|
||||||
reg.register_component("blocker", "workflow_step", "p1", {
|
reg.register_component(
|
||||||
"stage": "pre_process",
|
"blocker",
|
||||||
"priority": 10,
|
"workflow_step",
|
||||||
"blocking": True,
|
"p1",
|
||||||
})
|
{
|
||||||
|
"stage": "pre_process",
|
||||||
|
"priority": 10,
|
||||||
|
"blocking": True,
|
||||||
|
},
|
||||||
|
)
|
||||||
executor = WorkflowExecutor(reg)
|
executor = WorkflowExecutor(reg)
|
||||||
|
|
||||||
async def mock_invoke(plugin_id, comp_name, args):
|
async def mock_invoke(plugin_id, comp_name, args):
|
||||||
@@ -884,17 +981,27 @@ class TestWorkflowExecutor:
|
|||||||
|
|
||||||
reg = ComponentRegistry()
|
reg = ComponentRegistry()
|
||||||
# high-priority hook 返回 skip_stage
|
# high-priority hook 返回 skip_stage
|
||||||
reg.register_component("skipper", "workflow_step", "p1", {
|
reg.register_component(
|
||||||
"stage": "ingress",
|
"skipper",
|
||||||
"priority": 100,
|
"workflow_step",
|
||||||
"blocking": True,
|
"p1",
|
||||||
})
|
{
|
||||||
|
"stage": "ingress",
|
||||||
|
"priority": 100,
|
||||||
|
"blocking": True,
|
||||||
|
},
|
||||||
|
)
|
||||||
# low-priority hook 不应被执行
|
# low-priority hook 不应被执行
|
||||||
reg.register_component("checker", "workflow_step", "p2", {
|
reg.register_component(
|
||||||
"stage": "ingress",
|
"checker",
|
||||||
"priority": 1,
|
"workflow_step",
|
||||||
"blocking": True,
|
"p2",
|
||||||
})
|
{
|
||||||
|
"stage": "ingress",
|
||||||
|
"priority": 1,
|
||||||
|
"blocking": True,
|
||||||
|
},
|
||||||
|
)
|
||||||
executor = WorkflowExecutor(reg)
|
executor = WorkflowExecutor(reg)
|
||||||
|
|
||||||
call_log = []
|
call_log = []
|
||||||
@@ -905,9 +1012,7 @@ class TestWorkflowExecutor:
|
|||||||
return {"hook_result": "skip_stage"}
|
return {"hook_result": "skip_stage"}
|
||||||
return {"hook_result": "continue"}
|
return {"hook_result": "continue"}
|
||||||
|
|
||||||
result, _, _ = await executor.execute(
|
result, _, _ = await executor.execute(mock_invoke, message={"plain_text": "test"})
|
||||||
mock_invoke, message={"plain_text": "test"}
|
|
||||||
)
|
|
||||||
assert result.status == "completed"
|
assert result.status == "completed"
|
||||||
# 只有 skipper 被调用,checker 被跳过
|
# 只有 skipper 被调用,checker 被跳过
|
||||||
assert call_log == ["skipper"]
|
assert call_log == ["skipper"]
|
||||||
@@ -919,12 +1024,17 @@ class TestWorkflowExecutor:
|
|||||||
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
|
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
|
||||||
|
|
||||||
reg = ComponentRegistry()
|
reg = ComponentRegistry()
|
||||||
reg.register_component("only_dm", "workflow_step", "p1", {
|
reg.register_component(
|
||||||
"stage": "ingress",
|
"only_dm",
|
||||||
"priority": 10,
|
"workflow_step",
|
||||||
"blocking": True,
|
"p1",
|
||||||
"filter": {"chat_type": "direct"},
|
{
|
||||||
})
|
"stage": "ingress",
|
||||||
|
"priority": 10,
|
||||||
|
"blocking": True,
|
||||||
|
"filter": {"chat_type": "direct"},
|
||||||
|
},
|
||||||
|
)
|
||||||
executor = WorkflowExecutor(reg)
|
executor = WorkflowExecutor(reg)
|
||||||
|
|
||||||
call_log = []
|
call_log = []
|
||||||
@@ -934,15 +1044,11 @@ class TestWorkflowExecutor:
|
|||||||
return {"hook_result": "continue"}
|
return {"hook_result": "continue"}
|
||||||
|
|
||||||
# 不匹配 filter —— hook 不应被调用
|
# 不匹配 filter —— hook 不应被调用
|
||||||
await executor.execute(
|
await executor.execute(mock_invoke, message={"plain_text": "hi", "chat_type": "group"})
|
||||||
mock_invoke, message={"plain_text": "hi", "chat_type": "group"}
|
|
||||||
)
|
|
||||||
assert not call_log
|
assert not call_log
|
||||||
|
|
||||||
# 匹配 filter —— hook 应被调用
|
# 匹配 filter —— hook 应被调用
|
||||||
await executor.execute(
|
await executor.execute(mock_invoke, message={"plain_text": "hi", "chat_type": "direct"})
|
||||||
mock_invoke, message={"plain_text": "hi", "chat_type": "direct"}
|
|
||||||
)
|
|
||||||
assert call_log == ["only_dm"]
|
assert call_log == ["only_dm"]
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@@ -952,17 +1058,27 @@ class TestWorkflowExecutor:
|
|||||||
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
|
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
|
||||||
|
|
||||||
reg = ComponentRegistry()
|
reg = ComponentRegistry()
|
||||||
reg.register_component("failer", "workflow_step", "p1", {
|
reg.register_component(
|
||||||
"stage": "ingress",
|
"failer",
|
||||||
"priority": 100,
|
"workflow_step",
|
||||||
"blocking": True,
|
"p1",
|
||||||
"error_policy": "skip",
|
{
|
||||||
})
|
"stage": "ingress",
|
||||||
reg.register_component("ok_step", "workflow_step", "p2", {
|
"priority": 100,
|
||||||
"stage": "ingress",
|
"blocking": True,
|
||||||
"priority": 1,
|
"error_policy": "skip",
|
||||||
"blocking": True,
|
},
|
||||||
})
|
)
|
||||||
|
reg.register_component(
|
||||||
|
"ok_step",
|
||||||
|
"workflow_step",
|
||||||
|
"p2",
|
||||||
|
{
|
||||||
|
"stage": "ingress",
|
||||||
|
"priority": 1,
|
||||||
|
"blocking": True,
|
||||||
|
},
|
||||||
|
)
|
||||||
executor = WorkflowExecutor(reg)
|
executor = WorkflowExecutor(reg)
|
||||||
|
|
||||||
call_log = []
|
call_log = []
|
||||||
@@ -973,9 +1089,7 @@ class TestWorkflowExecutor:
|
|||||||
raise RuntimeError("boom")
|
raise RuntimeError("boom")
|
||||||
return {"hook_result": "continue"}
|
return {"hook_result": "continue"}
|
||||||
|
|
||||||
result, _, ctx = await executor.execute(
|
result, _, ctx = await executor.execute(mock_invoke, message={"plain_text": "test"})
|
||||||
mock_invoke, message={"plain_text": "test"}
|
|
||||||
)
|
|
||||||
assert result.status == "completed"
|
assert result.status == "completed"
|
||||||
assert "failer" in call_log
|
assert "failer" in call_log
|
||||||
assert "ok_step" in call_log
|
assert "ok_step" in call_log
|
||||||
@@ -988,20 +1102,23 @@ class TestWorkflowExecutor:
|
|||||||
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
|
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
|
||||||
|
|
||||||
reg = ComponentRegistry()
|
reg = ComponentRegistry()
|
||||||
reg.register_component("failer", "workflow_step", "p1", {
|
reg.register_component(
|
||||||
"stage": "ingress",
|
"failer",
|
||||||
"priority": 10,
|
"workflow_step",
|
||||||
"blocking": True,
|
"p1",
|
||||||
# error_policy defaults to "abort"
|
{
|
||||||
})
|
"stage": "ingress",
|
||||||
|
"priority": 10,
|
||||||
|
"blocking": True,
|
||||||
|
# error_policy defaults to "abort"
|
||||||
|
},
|
||||||
|
)
|
||||||
executor = WorkflowExecutor(reg)
|
executor = WorkflowExecutor(reg)
|
||||||
|
|
||||||
async def mock_invoke(plugin_id, comp_name, args):
|
async def mock_invoke(plugin_id, comp_name, args):
|
||||||
raise RuntimeError("fatal")
|
raise RuntimeError("fatal")
|
||||||
|
|
||||||
result, _, ctx = await executor.execute(
|
result, _, ctx = await executor.execute(mock_invoke, message={"plain_text": "test"})
|
||||||
mock_invoke, message={"plain_text": "test"}
|
|
||||||
)
|
|
||||||
assert result.status == "failed"
|
assert result.status == "failed"
|
||||||
assert result.stopped_at == "ingress"
|
assert result.stopped_at == "ingress"
|
||||||
|
|
||||||
@@ -1013,11 +1130,16 @@ class TestWorkflowExecutor:
|
|||||||
|
|
||||||
reg = ComponentRegistry()
|
reg = ComponentRegistry()
|
||||||
for i in range(3):
|
for i in range(3):
|
||||||
reg.register_component(f"nb_{i}", "workflow_step", f"p{i}", {
|
reg.register_component(
|
||||||
"stage": "post_process",
|
f"nb_{i}",
|
||||||
"priority": 0,
|
"workflow_step",
|
||||||
"blocking": False,
|
f"p{i}",
|
||||||
})
|
{
|
||||||
|
"stage": "post_process",
|
||||||
|
"priority": 0,
|
||||||
|
"blocking": False,
|
||||||
|
},
|
||||||
|
)
|
||||||
executor = WorkflowExecutor(reg)
|
executor = WorkflowExecutor(reg)
|
||||||
|
|
||||||
call_log = []
|
call_log = []
|
||||||
@@ -1026,9 +1148,7 @@ class TestWorkflowExecutor:
|
|||||||
call_log.append(comp_name)
|
call_log.append(comp_name)
|
||||||
return {"hook_result": "continue", "modified_message": {"plain_text": "ignored"}}
|
return {"hook_result": "continue", "modified_message": {"plain_text": "ignored"}}
|
||||||
|
|
||||||
result, final_msg, _ = await executor.execute(
|
result, final_msg, _ = await executor.execute(mock_invoke, message={"plain_text": "original"})
|
||||||
mock_invoke, message={"plain_text": "original"}
|
|
||||||
)
|
|
||||||
# non-blocking 的 modified_message 被忽略
|
# non-blocking 的 modified_message 被忽略
|
||||||
assert final_msg["plain_text"] == "original"
|
assert final_msg["plain_text"] == "original"
|
||||||
# 给异步 task 时间完成
|
# 给异步 task 时间完成
|
||||||
@@ -1042,9 +1162,14 @@ class TestWorkflowExecutor:
|
|||||||
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
|
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
|
||||||
|
|
||||||
reg = ComponentRegistry()
|
reg = ComponentRegistry()
|
||||||
reg.register_component("help", "command", "p1", {
|
reg.register_component(
|
||||||
"command_pattern": r"^/help",
|
"help",
|
||||||
})
|
"command",
|
||||||
|
"p1",
|
||||||
|
{
|
||||||
|
"command_pattern": r"^/help",
|
||||||
|
},
|
||||||
|
)
|
||||||
executor = WorkflowExecutor(reg)
|
executor = WorkflowExecutor(reg)
|
||||||
|
|
||||||
async def mock_invoke(plugin_id, comp_name, args):
|
async def mock_invoke(plugin_id, comp_name, args):
|
||||||
@@ -1052,9 +1177,7 @@ class TestWorkflowExecutor:
|
|||||||
return {"output": "帮助信息"}
|
return {"output": "帮助信息"}
|
||||||
return {"hook_result": "continue"}
|
return {"hook_result": "continue"}
|
||||||
|
|
||||||
result, _, ctx = await executor.execute(
|
result, _, ctx = await executor.execute(mock_invoke, message={"plain_text": "/help topic"})
|
||||||
mock_invoke, message={"plain_text": "/help topic"}
|
|
||||||
)
|
|
||||||
assert result.status == "completed"
|
assert result.status == "completed"
|
||||||
assert ctx.matched_command == "p1.help"
|
assert ctx.matched_command == "p1.help"
|
||||||
cmd_result = ctx.get_stage_output("plan", "command_result")
|
cmd_result = ctx.get_stage_output("plan", "command_result")
|
||||||
@@ -1069,17 +1192,27 @@ class TestWorkflowExecutor:
|
|||||||
|
|
||||||
reg = ComponentRegistry()
|
reg = ComponentRegistry()
|
||||||
# ingress 阶段写入数据
|
# ingress 阶段写入数据
|
||||||
reg.register_component("writer", "workflow_step", "p1", {
|
reg.register_component(
|
||||||
"stage": "ingress",
|
"writer",
|
||||||
"priority": 10,
|
"workflow_step",
|
||||||
"blocking": True,
|
"p1",
|
||||||
})
|
{
|
||||||
|
"stage": "ingress",
|
||||||
|
"priority": 10,
|
||||||
|
"blocking": True,
|
||||||
|
},
|
||||||
|
)
|
||||||
# pre_process 阶段读取数据
|
# pre_process 阶段读取数据
|
||||||
reg.register_component("reader", "workflow_step", "p2", {
|
reg.register_component(
|
||||||
"stage": "pre_process",
|
"reader",
|
||||||
"priority": 10,
|
"workflow_step",
|
||||||
"blocking": True,
|
"p2",
|
||||||
})
|
{
|
||||||
|
"stage": "pre_process",
|
||||||
|
"priority": 10,
|
||||||
|
"blocking": True,
|
||||||
|
},
|
||||||
|
)
|
||||||
executor = WorkflowExecutor(reg)
|
executor = WorkflowExecutor(reg)
|
||||||
|
|
||||||
async def mock_invoke(plugin_id, comp_name, args):
|
async def mock_invoke(plugin_id, comp_name, args):
|
||||||
@@ -1096,9 +1229,7 @@ class TestWorkflowExecutor:
|
|||||||
return {"hook_result": "continue"}
|
return {"hook_result": "continue"}
|
||||||
return {"hook_result": "continue"}
|
return {"hook_result": "continue"}
|
||||||
|
|
||||||
result, _, ctx = await executor.execute(
|
result, _, ctx = await executor.execute(mock_invoke, message={"plain_text": "hi"})
|
||||||
mock_invoke, message={"plain_text": "hi"}
|
|
||||||
)
|
|
||||||
assert result.status == "completed"
|
assert result.status == "completed"
|
||||||
assert ctx.get_stage_output("ingress", "parsed_intent") == "greeting"
|
assert ctx.get_stage_output("ingress", "parsed_intent") == "greeting"
|
||||||
|
|
||||||
@@ -1324,10 +1455,15 @@ class TestIntegration:
|
|||||||
async def stop(self):
|
async def stop(self):
|
||||||
self.stopped = True
|
self.stopped = True
|
||||||
|
|
||||||
monkeypatch.setattr(integration_module.PluginRuntimeManager, "_get_builtin_plugin_dirs", staticmethod(lambda: ["builtin"]))
|
monkeypatch.setattr(
|
||||||
monkeypatch.setattr(integration_module.PluginRuntimeManager, "_get_thirdparty_plugin_dirs", staticmethod(lambda: ["thirdparty"]))
|
integration_module.PluginRuntimeManager, "_get_builtin_plugin_dirs", staticmethod(lambda: ["builtin"])
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
integration_module.PluginRuntimeManager, "_get_thirdparty_plugin_dirs", staticmethod(lambda: ["thirdparty"])
|
||||||
|
)
|
||||||
|
|
||||||
import src.plugin_runtime.host.supervisor as supervisor_module
|
import src.plugin_runtime.host.supervisor as supervisor_module
|
||||||
|
|
||||||
monkeypatch.setattr(supervisor_module, "PluginSupervisor", FakeSupervisor)
|
monkeypatch.setattr(supervisor_module, "PluginSupervisor", FakeSupervisor)
|
||||||
|
|
||||||
manager = integration_module.PluginRuntimeManager()
|
manager = integration_module.PluginRuntimeManager()
|
||||||
|
|||||||
@@ -9,16 +9,12 @@ from datetime import datetime
|
|||||||
from typing import TYPE_CHECKING, Optional
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from src.common.data_models.message_component_data_model import MessageSequence, ForwardComponent
|
from src.common.data_models.message_component_data_model import MessageSequence
|
||||||
from src.chat.message_receive.message import (
|
from src.chat.message_receive.message import (
|
||||||
SessionMessage,
|
SessionMessage,
|
||||||
TextComponent,
|
TextComponent,
|
||||||
ImageComponent,
|
ImageComponent,
|
||||||
EmojiComponent,
|
|
||||||
VoiceComponent,
|
|
||||||
AtComponent,
|
AtComponent,
|
||||||
ReplyComponent,
|
|
||||||
ForwardNodeComponent,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -203,17 +199,23 @@ def load_message_via_file(monkeypatch):
|
|||||||
def dummy_number_to_short_id(original_id: int, salt: str, length: int = 6) -> str:
|
def dummy_number_to_short_id(original_id: int, salt: str, length: int = 6) -> str:
|
||||||
return "X" * length # 返回固定的字符串,长度由参数决定,模拟生成短ID的行为
|
return "X" * length # 返回固定的字符串,长度由参数决定,模拟生成短ID的行为
|
||||||
|
|
||||||
|
|
||||||
def dummy_is_bot_self(user_id: str, platform) -> bool:
|
def dummy_is_bot_self(user_id: str, platform) -> bool:
|
||||||
return user_id == "bot_self"
|
return user_id == "bot_self"
|
||||||
|
|
||||||
|
|
||||||
def load_utils_via_file(monkeypatch):
|
def load_utils_via_file(monkeypatch):
|
||||||
setup_mocks(monkeypatch)
|
setup_mocks(monkeypatch)
|
||||||
|
|
||||||
# Mock math_utils 模块,供 from .math_utils import number_to_short_id 使用
|
# Mock math_utils 模块,供 from .math_utils import number_to_short_id 使用
|
||||||
math_utils_mod = ModuleType("src.common.utils.math_utils")
|
math_utils_mod = ModuleType("src.common.utils.math_utils")
|
||||||
math_utils_mod.number_to_short_id = dummy_number_to_short_id
|
math_utils_mod.number_to_short_id = dummy_number_to_short_id
|
||||||
math_utils_mod.TimestampMode = type("TimestampMode", (), {"NORMAL": "%Y-%m-%d %H:%M:%S", "NORMAL_NO_YMD": "%H:%M:%S", "RELATIVE": "relative"})
|
math_utils_mod.TimestampMode = type(
|
||||||
math_utils_mod.translate_timestamp_to_human_readable = lambda timestamp, mode: "2024-01-01 12:00:00" # 返回固定的时间字符串
|
"TimestampMode", (), {"NORMAL": "%Y-%m-%d %H:%M:%S", "NORMAL_NO_YMD": "%H:%M:%S", "RELATIVE": "relative"}
|
||||||
|
)
|
||||||
|
math_utils_mod.translate_timestamp_to_human_readable = lambda timestamp, mode: (
|
||||||
|
"2024-01-01 12:00:00"
|
||||||
|
) # 返回固定的时间字符串
|
||||||
monkeypatch.setitem(sys.modules, "src.common.utils.math_utils", math_utils_mod)
|
monkeypatch.setitem(sys.modules, "src.common.utils.math_utils", math_utils_mod)
|
||||||
|
|
||||||
# 确保包层级模块存在于 sys.modules 中,使相对导入能正确解析
|
# 确保包层级模块存在于 sys.modules 中,使相对导入能正确解析
|
||||||
@@ -349,6 +351,7 @@ async def test_build_readable_message_anonymize_and_replace_bot_name_and_lineno(
|
|||||||
assert "u_comb" in mapping
|
assert "u_comb" in mapping
|
||||||
assert mapping["u_comb"][0] == "XXXXXX"
|
assert mapping["u_comb"][0] == "XXXXXX"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_build_readable_message_with_at(monkeypatch):
|
async def test_build_readable_message_with_at(monkeypatch):
|
||||||
"""包含@组件的消息:验证@组件中的用户信息也被匿名化和替换"""
|
"""包含@组件的消息:验证@组件中的用户信息也被匿名化和替换"""
|
||||||
@@ -363,7 +366,9 @@ async def test_build_readable_message_with_at(monkeypatch):
|
|||||||
msg.session_id = "s_at"
|
msg.session_id = "s_at"
|
||||||
msg.raw_message = MessageSequence([at_comp])
|
msg.raw_message = MessageSequence([at_comp])
|
||||||
msg.message_info = MessageInfo(UserInfo(user_id="u_main", user_nickname="MainUser"))
|
msg.message_info = MessageInfo(UserInfo(user_id="u_main", user_nickname="MainUser"))
|
||||||
text, mapping, _ = await MessageUtils.build_readable_message([msg], anonymize=True, replace_bot_name=True, target_bot_name="MAIBot")
|
text, mapping, _ = await MessageUtils.build_readable_message(
|
||||||
|
[msg], anonymize=True, replace_bot_name=True, target_bot_name="MAIBot"
|
||||||
|
)
|
||||||
# 验证主消息和@组件中的用户信息都被处理
|
# 验证主消息和@组件中的用户信息都被处理
|
||||||
assert "XXXXXX说:" in text # 主消息用户被匿名化
|
assert "XXXXXX说:" in text # 主消息用户被匿名化
|
||||||
assert "XXXXXX说:@XXXXXX" in text # @组件用户被匿名化
|
assert "XXXXXX说:@XXXXXX" in text # @组件用户被匿名化
|
||||||
|
|||||||
@@ -22,11 +22,7 @@ def should_skip(path: Path) -> bool:
|
|||||||
|
|
||||||
|
|
||||||
def iter_python_files(root: Path) -> list[Path]:
|
def iter_python_files(root: Path) -> list[Path]:
|
||||||
return sorted(
|
return sorted(path for path in root.rglob("*.py") if path.is_file() and not should_skip(path.relative_to(root)))
|
||||||
path
|
|
||||||
for path in root.rglob("*.py")
|
|
||||||
if path.is_file() and not should_skip(path.relative_to(root))
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class CandidateExtractor(ast.NodeVisitor):
|
class CandidateExtractor(ast.NodeVisitor):
|
||||||
|
|||||||
@@ -23,7 +23,6 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
|||||||
from src.common.logger import initialize_logging, get_logger
|
from src.common.logger import initialize_logging, get_logger
|
||||||
from src.common.database.database import db
|
from src.common.database.database import db
|
||||||
from src.common.database.database_model import LLMUsage
|
from src.common.database.database_model import LLMUsage
|
||||||
from src.chat.message_receive.chat_manager import BotChatSession
|
|
||||||
from maim_message import UserInfo, GroupInfo
|
from maim_message import UserInfo, GroupInfo
|
||||||
|
|
||||||
logger = get_logger("test_memory_retrieval")
|
logger = get_logger("test_memory_retrieval")
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ from src.common.logger import get_logger
|
|||||||
from src.common.database.database_model import Expression
|
from src.common.database.database_model import Expression
|
||||||
from src.prompt.prompt_manager import prompt_manager
|
from src.prompt.prompt_manager import prompt_manager
|
||||||
from src.bw_learner.learner_utils_old import weighted_sample
|
from src.bw_learner.learner_utils_old import weighted_sample
|
||||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
|
||||||
from src.chat.utils.common_utils import TempMethodsExpression
|
from src.chat.utils.common_utils import TempMethodsExpression
|
||||||
|
|
||||||
logger = get_logger("expression_selector")
|
logger = get_logger("expression_selector")
|
||||||
|
|||||||
@@ -1,22 +1,14 @@
|
|||||||
import re
|
|
||||||
import difflib
|
|
||||||
import random
|
import random
|
||||||
import json
|
import json
|
||||||
from typing import Optional, List, Dict, Any, Tuple
|
from typing import Optional, List, Dict, Any
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from src.chat.utils.chat_message_builder import (
|
|
||||||
build_readable_messages,
|
|
||||||
)
|
|
||||||
from src.chat.utils.utils import parse_platform_accounts
|
|
||||||
from json_repair import repair_json
|
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger("learner_utils")
|
logger = get_logger("learner_utils")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def _compute_weights(population: List[Dict]) -> List[float]:
|
def _compute_weights(population: List[Dict]) -> List[float]:
|
||||||
"""
|
"""
|
||||||
根据表达的count计算权重,范围限定在1~5之间。
|
根据表达的count计算权重,范围限定在1~5之间。
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ from .action_planner import ActionPlanner
|
|||||||
from .observation_info import ObservationInfo
|
from .observation_info import ObservationInfo
|
||||||
from .conversation_info import ConversationInfo # 确保导入 ConversationInfo
|
from .conversation_info import ConversationInfo # 确保导入 ConversationInfo
|
||||||
from .reply_generator import ReplyGenerator
|
from .reply_generator import ReplyGenerator
|
||||||
from src.chat.message_receive.chat_manager import BotChatSession, chat_manager as _chat_manager
|
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
||||||
from maim_message import UserInfo
|
from maim_message import UserInfo
|
||||||
from .pfc_KnowledgeFetcher import KnowledgeFetcher
|
from .pfc_KnowledgeFetcher import KnowledgeFetcher
|
||||||
from .waiter import Waiter
|
from .waiter import Waiter
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ from src.config.config import global_config
|
|||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.common.data_models.info_data_model import ActionPlannerInfo
|
from src.common.data_models.info_data_model import ActionPlannerInfo
|
||||||
from src.common.data_models.message_data_model import ReplyContentType
|
from src.common.data_models.message_data_model import ReplyContentType
|
||||||
from src.chat.message_receive.chat_manager import BotChatSession, chat_manager as _chat_manager
|
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
||||||
from src.chat.utils.prompt_builder import global_prompt_manager
|
from src.chat.utils.prompt_builder import global_prompt_manager
|
||||||
from src.chat.utils.timer_calculator import Timer
|
from src.chat.utils.timer_calculator import Timer
|
||||||
from src.chat.brain_chat.brain_planner import BrainPlanner
|
from src.chat.brain_chat.brain_planner import BrainPlanner
|
||||||
@@ -22,7 +22,12 @@ from src.person_info.person_info import Person
|
|||||||
from src.core.types import ActionInfo, EventType
|
from src.core.types import ActionInfo, EventType
|
||||||
from src.core.event_bus import event_bus
|
from src.core.event_bus import event_bus
|
||||||
from src.chat.event_helpers import build_event_message
|
from src.chat.event_helpers import build_event_message
|
||||||
from src.services import generator_service as generator_api, send_service as send_api, message_service as message_api, database_service as database_api
|
from src.services import (
|
||||||
|
generator_service as generator_api,
|
||||||
|
send_service as send_api,
|
||||||
|
message_service as message_api,
|
||||||
|
database_service as database_api,
|
||||||
|
)
|
||||||
from src.chat.utils.chat_message_builder import (
|
from src.chat.utils.chat_message_builder import (
|
||||||
build_readable_messages_with_id,
|
build_readable_messages_with_id,
|
||||||
get_raw_msg_before_timestamp_with_chat,
|
get_raw_msg_before_timestamp_with_chat,
|
||||||
@@ -294,10 +299,10 @@ class BrainChatting:
|
|||||||
message_id_list=message_id_list,
|
message_id_list=message_id_list,
|
||||||
prompt_key="brain_planner",
|
prompt_key="brain_planner",
|
||||||
)
|
)
|
||||||
_event_msg = build_event_message(EventType.ON_PLAN, llm_prompt=prompt_info[0], stream_id=self.chat_stream.stream_id)
|
_event_msg = build_event_message(
|
||||||
continue_flag, modified_message = await event_bus.emit(
|
EventType.ON_PLAN, llm_prompt=prompt_info[0], stream_id=self.chat_stream.stream_id
|
||||||
EventType.ON_PLAN, _event_msg
|
|
||||||
)
|
)
|
||||||
|
continue_flag, modified_message = await event_bus.emit(EventType.ON_PLAN, _event_msg)
|
||||||
if not continue_flag:
|
if not continue_flag:
|
||||||
return False
|
return False
|
||||||
if modified_message and modified_message._modify_flags.modify_llm_prompt:
|
if modified_message and modified_message._modify_flags.modify_llm_prompt:
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
from typing import Optional, List, TYPE_CHECKING, Tuple, Dict
|
from typing import Optional, List, TYPE_CHECKING
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import time
|
import time
|
||||||
@@ -14,7 +14,7 @@ from src.chat.message_receive.chat_manager import chat_manager
|
|||||||
from src.bw_learner.expression_learner import ExpressionLearner
|
from src.bw_learner.expression_learner import ExpressionLearner
|
||||||
from src.bw_learner.jargon_miner import JargonMiner
|
from src.bw_learner.jargon_miner import JargonMiner
|
||||||
|
|
||||||
from .heartFC_utils import CycleDetail, CycleActionInfo, CyclePlanInfo
|
from .heartFC_utils import CycleDetail
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from src.chat.message_receive.message import SessionMessage
|
from src.chat.message_receive.message import SessionMessage
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from typing import Optional, Dict
|
from typing import Dict
|
||||||
|
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
@@ -9,6 +9,7 @@ from src.chat.heart_flow.heartFC_chat import HeartFChatting
|
|||||||
|
|
||||||
logger = get_logger("heartflow")
|
logger = get_logger("heartflow")
|
||||||
|
|
||||||
|
|
||||||
# TODO: 恢复PFC,现在暂时禁用
|
# TODO: 恢复PFC,现在暂时禁用
|
||||||
class HeartflowManager:
|
class HeartflowManager:
|
||||||
"""主心流协调器,负责初始化并协调聊天,控制聊天属性"""
|
"""主心流协调器,负责初始化并协调聊天,控制聊天属性"""
|
||||||
@@ -17,7 +18,7 @@ class HeartflowManager:
|
|||||||
# self.heartflow_chat_list: Dict[str, HeartFChatting | BrainChatting] = {}
|
# self.heartflow_chat_list: Dict[str, HeartFChatting | BrainChatting] = {}
|
||||||
self.heartflow_chat_list: Dict[str, HeartFChatting] = {}
|
self.heartflow_chat_list: Dict[str, HeartFChatting] = {}
|
||||||
|
|
||||||
async def get_or_create_heartflow_chat(self, session_id: str): # -> Optional[HeartFChatting | BrainChatting]:
|
async def get_or_create_heartflow_chat(self, session_id: str): # -> Optional[HeartFChatting | BrainChatting]:
|
||||||
"""获取或创建一个新的HeartFChatting实例"""
|
"""获取或创建一个新的HeartFChatting实例"""
|
||||||
try:
|
try:
|
||||||
if chat := self.heartflow_chat_list.get(session_id):
|
if chat := self.heartflow_chat_list.get(session_id):
|
||||||
|
|||||||
@@ -9,11 +9,10 @@ from src.common.logger import get_logger
|
|||||||
from src.common.utils.utils_message import MessageUtils
|
from src.common.utils.utils_message import MessageUtils
|
||||||
from src.common.utils.utils_session import SessionUtils
|
from src.common.utils.utils_session import SessionUtils
|
||||||
from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiver
|
from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiver
|
||||||
|
|
||||||
# from src.chat.brain_chat.PFC.pfc_manager import PFCManager
|
# from src.chat.brain_chat.PFC.pfc_manager import PFCManager
|
||||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
|
||||||
from src.core.announcement_manager import global_announcement_manager
|
from src.core.announcement_manager import global_announcement_manager
|
||||||
from src.core.component_registry import component_registry
|
from src.core.component_registry import component_registry
|
||||||
from src.core.types import EventType
|
|
||||||
|
|
||||||
from .message import SessionMessage
|
from .message import SessionMessage
|
||||||
from .chat_manager import chat_manager
|
from .chat_manager import chat_manager
|
||||||
@@ -391,6 +390,7 @@ class ChatBot:
|
|||||||
else:
|
else:
|
||||||
logger.debug("[群聊]检测到群聊消息,路由到HeartFlow系统")
|
logger.debug("[群聊]检测到群聊消息,路由到HeartFlow系统")
|
||||||
await self.heartflow_message_receiver.process_message(message)
|
await self.heartflow_message_receiver.process_message(message)
|
||||||
|
|
||||||
await preprocess()
|
await preprocess()
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ import asyncio
|
|||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.common.database.database import get_db_session
|
from src.common.database.database import get_db_session
|
||||||
from src.common.database.database_model import Messages
|
from src.common.database.database_model import Messages
|
||||||
from src.common.data_models.mai_message_data_model import MaiMessage, UserInfo, GroupInfo, MessageInfo
|
from src.common.data_models.mai_message_data_model import MaiMessage, UserInfo
|
||||||
from src.common.data_models.message_component_data_model import (
|
from src.common.data_models.message_component_data_model import (
|
||||||
TextComponent,
|
TextComponent,
|
||||||
ImageComponent,
|
ImageComponent,
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ _webui_chat_broadcaster = None
|
|||||||
# 虚拟群 ID 前缀(与 chat_routes.py 保持一致)
|
# 虚拟群 ID 前缀(与 chat_routes.py 保持一致)
|
||||||
VIRTUAL_GROUP_ID_PREFIX = "webui_virtual_group_"
|
VIRTUAL_GROUP_ID_PREFIX = "webui_virtual_group_"
|
||||||
|
|
||||||
|
|
||||||
# TODO: 重构完成后完成webui相关
|
# TODO: 重构完成后完成webui相关
|
||||||
def get_webui_chat_broadcaster():
|
def get_webui_chat_broadcaster():
|
||||||
"""获取 WebUI 聊天室广播器"""
|
"""获取 WebUI 聊天室广播器"""
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ from src.chat.message_receive.chat_manager import BotChatSession
|
|||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.common.data_models.database_data_model import DatabaseMessages
|
from src.common.data_models.database_data_model import DatabaseMessages
|
||||||
from src.core.component_registry import component_registry, ActionExecutor
|
from src.core.component_registry import component_registry, ActionExecutor
|
||||||
from src.core.types import ActionInfo, ComponentType
|
from src.core.types import ActionInfo
|
||||||
|
|
||||||
logger = get_logger("action_manager")
|
logger = get_logger("action_manager")
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
from typing import List, Dict, TYPE_CHECKING, Tuple
|
from typing import List, Dict, Tuple
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
|
|||||||
@@ -119,9 +119,7 @@ class PrivateReplyer:
|
|||||||
|
|
||||||
if not from_plugin:
|
if not from_plugin:
|
||||||
_event_msg = build_event_message(EventType.POST_LLM, llm_prompt=prompt, stream_id=stream_id)
|
_event_msg = build_event_message(EventType.POST_LLM, llm_prompt=prompt, stream_id=stream_id)
|
||||||
continue_flag, modified_message = await event_bus.emit(
|
continue_flag, modified_message = await event_bus.emit(EventType.POST_LLM, _event_msg)
|
||||||
EventType.POST_LLM, _event_msg
|
|
||||||
)
|
|
||||||
if not continue_flag:
|
if not continue_flag:
|
||||||
raise UserWarning("插件于请求前中断了内容生成")
|
raise UserWarning("插件于请求前中断了内容生成")
|
||||||
if modified_message and modified_message._modify_flags.modify_llm_prompt:
|
if modified_message and modified_message._modify_flags.modify_llm_prompt:
|
||||||
@@ -140,10 +138,10 @@ class PrivateReplyer:
|
|||||||
llm_response.reasoning = reasoning_content
|
llm_response.reasoning = reasoning_content
|
||||||
llm_response.model = model_name
|
llm_response.model = model_name
|
||||||
llm_response.tool_calls = tool_call
|
llm_response.tool_calls = tool_call
|
||||||
_event_msg = build_event_message(EventType.AFTER_LLM, llm_prompt=prompt, llm_response=llm_response, stream_id=stream_id)
|
_event_msg = build_event_message(
|
||||||
continue_flag, modified_message = await event_bus.emit(
|
EventType.AFTER_LLM, llm_prompt=prompt, llm_response=llm_response, stream_id=stream_id
|
||||||
EventType.AFTER_LLM, _event_msg
|
|
||||||
)
|
)
|
||||||
|
continue_flag, modified_message = await event_bus.emit(EventType.AFTER_LLM, _event_msg)
|
||||||
if not from_plugin and not continue_flag:
|
if not from_plugin and not continue_flag:
|
||||||
raise UserWarning("插件于请求后取消了内容生成")
|
raise UserWarning("插件于请求后取消了内容生成")
|
||||||
if modified_message:
|
if modified_message:
|
||||||
|
|||||||
@@ -817,6 +817,8 @@ def assign_message_ids(messages: List[SessionMessage]) -> List[Tuple[str, Sessio
|
|||||||
result.append((message_id, message))
|
result.append((message_id, message))
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
# break
|
# break
|
||||||
# result.append((message_id, message))
|
# result.append((message_id, message))
|
||||||
|
|
||||||
|
|||||||
@@ -25,4 +25,4 @@
|
|||||||
# available_actions: Optional[Dict[str, "ActionInfo"]] = None
|
# available_actions: Optional[Dict[str, "ActionInfo"]] = None
|
||||||
# loop_start_time: Optional[float] = None
|
# loop_start_time: Optional[float] = None
|
||||||
# action_reasoning: Optional[str] = None
|
# action_reasoning: Optional[str] = None
|
||||||
# TODO: 重构
|
# TODO: 重构
|
||||||
|
|||||||
@@ -20,4 +20,4 @@
|
|||||||
# timing: Optional[Dict[str, Any]] = None
|
# timing: Optional[Dict[str, Any]] = None
|
||||||
# processed_output: Optional[List[str]] = None
|
# processed_output: Optional[List[str]] = None
|
||||||
# timing_logs: Optional[List[str]] = None
|
# timing_logs: Optional[List[str]] = None
|
||||||
# TODO: 重构
|
# TODO: 重构
|
||||||
|
|||||||
@@ -13,8 +13,10 @@ from src.common.logger import get_logger
|
|||||||
|
|
||||||
logger = get_logger("base_message_component_model")
|
logger = get_logger("base_message_component_model")
|
||||||
|
|
||||||
|
|
||||||
class UnknownUser(str): ...
|
class UnknownUser(str): ...
|
||||||
|
|
||||||
|
|
||||||
class BaseMessageComponentModel(ABC):
|
class BaseMessageComponentModel(ABC):
|
||||||
@property
|
@property
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
|||||||
@@ -107,7 +107,7 @@ class UniversalMessageSender:
|
|||||||
logger.info(f"已将消息 '{message_preview}' 发往平台'{message.platform}'")
|
logger.info(f"已将消息 '{message_preview}' 发往平台'{message.platform}'")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
except Exception as legacy_e:
|
except Exception:
|
||||||
# # Legacy API 抛出异常,尝试 Fallback
|
# # Legacy API 抛出异常,尝试 Fallback
|
||||||
# return await self._send_with_fallback(
|
# return await self._send_with_fallback(
|
||||||
# message, message_preview, platform, show_log, legacy_exception=legacy_e
|
# message, message_preview, platform, show_log, legacy_exception=legacy_e
|
||||||
|
|||||||
@@ -76,4 +76,4 @@ def assert_port_available(
|
|||||||
port=port,
|
port=port,
|
||||||
config_hint=config_hint,
|
config_hint=config_hint,
|
||||||
)
|
)
|
||||||
raise OSError(build_port_conflict_message(service_name=service_name, host=host, port=port))
|
raise OSError(build_port_conflict_message(service_name=service_name, host=host, port=port))
|
||||||
|
|||||||
@@ -133,8 +133,8 @@ class ConfigBase(BaseModel, AttrDocBase):
|
|||||||
|
|
||||||
# UI 分组元数据:子类可覆盖以声明所属 Tab 分组
|
# UI 分组元数据:子类可覆盖以声明所属 Tab 分组
|
||||||
__ui_parent__: ClassVar[str] = "" # 父配置类在 Config 中的字段名,空表示独立 Tab
|
__ui_parent__: ClassVar[str] = "" # 父配置类在 Config 中的字段名,空表示独立 Tab
|
||||||
__ui_label__: ClassVar[str] = "" # Tab 显示名称(仅做 Tab 主人时使用),空则使用 classDoc
|
__ui_label__: ClassVar[str] = "" # Tab 显示名称(仅做 Tab 主人时使用),空则使用 classDoc
|
||||||
__ui_icon__: ClassVar[str] = "" # Tab 图标名称(Lucide 图标名)
|
__ui_icon__: ClassVar[str] = "" # Tab 图标名称(Lucide 图标名)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_dict(cls, attribute_data: AttributeData, data: dict[str, Any]):
|
def from_dict(cls, attribute_data: AttributeData, data: dict[str, Any]):
|
||||||
|
|||||||
@@ -182,7 +182,9 @@ class FileWatcher:
|
|||||||
self._stats.callbacks_skipped_cooldown += 1
|
self._stats.callbacks_skipped_cooldown += 1
|
||||||
continue
|
continue
|
||||||
try:
|
try:
|
||||||
await asyncio.wait_for(self._invoke_callback(subscription.callback, matched_changes), timeout=self._callback_timeout_s)
|
await asyncio.wait_for(
|
||||||
|
self._invoke_callback(subscription.callback, matched_changes), timeout=self._callback_timeout_s
|
||||||
|
)
|
||||||
state.consecutive_failures = 0
|
state.consecutive_failures = 0
|
||||||
self._stats.callbacks_succeeded += 1
|
self._stats.callbacks_succeeded += 1
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
|
|||||||
@@ -10,11 +10,10 @@
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import re
|
import re
|
||||||
from typing import Any, Awaitable, Callable, Dict, List, Optional, Pattern, Tuple, Union
|
from typing import Any, Awaitable, Callable, Dict, Optional, Pattern, Tuple
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.core.types import (
|
from src.core.types import (
|
||||||
ActionActivationType,
|
|
||||||
ActionInfo,
|
ActionInfo,
|
||||||
CommandInfo,
|
CommandInfo,
|
||||||
ComponentInfo,
|
ComponentInfo,
|
||||||
@@ -130,9 +129,7 @@ class ComponentRegistry:
|
|||||||
logger.debug(f"注册 Command: {name}")
|
logger.debug(f"注册 Command: {name}")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def find_command_by_text(
|
def find_command_by_text(self, text: str) -> Optional[Tuple[CommandExecutor, dict, CommandInfo]]:
|
||||||
self, text: str
|
|
||||||
) -> Optional[Tuple[CommandExecutor, dict, CommandInfo]]:
|
|
||||||
"""根据文本查找匹配的命令
|
"""根据文本查找匹配的命令
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
|||||||
@@ -117,9 +117,7 @@ class EventBus:
|
|||||||
self._fire_and_forget(entry, event_type, current_message)
|
self._fire_and_forget(entry, event_type, current_message)
|
||||||
|
|
||||||
# 桥接到 IPC 插件运行时
|
# 桥接到 IPC 插件运行时
|
||||||
continue_flag, current_message = await self._bridge_to_ipc_runtime(
|
continue_flag, current_message = await self._bridge_to_ipc_runtime(event_type, continue_flag, current_message)
|
||||||
event_type, continue_flag, current_message
|
|
||||||
)
|
|
||||||
|
|
||||||
return continue_flag, current_message
|
return continue_flag, current_message
|
||||||
|
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ MaiSaka - 内置工具定义
|
|||||||
from typing import List, Dict, Any
|
from typing import List, Dict, Any
|
||||||
from src.llm_models.payload_content.tool_option import ToolOption, ToolParamType
|
from src.llm_models.payload_content.tool_option import ToolOption, ToolParamType
|
||||||
|
|
||||||
|
|
||||||
# 内置工具定义
|
# 内置工具定义
|
||||||
def create_builtin_tools() -> List[ToolOption]:
|
def create_builtin_tools() -> List[ToolOption]:
|
||||||
"""创建内置工具列表"""
|
"""创建内置工具列表"""
|
||||||
@@ -17,57 +18,66 @@ def create_builtin_tools() -> List[ToolOption]:
|
|||||||
# say 工具
|
# say 工具
|
||||||
say_builder = ToolOptionBuilder()
|
say_builder = ToolOptionBuilder()
|
||||||
say_builder.set_name("say")
|
say_builder.set_name("say")
|
||||||
say_builder.set_description("对用户说话。你所有想让用户看到的正式发言都必须通过此工具输出。直接输出的文本会被视为你的内心思考,用户无法阅读。reason 参数描述你想要回复的方式、想法和内容,系统会根据你的想法和对话上下文生成具体的回复。")
|
say_builder.set_description(
|
||||||
|
"对用户说话。你所有想让用户看到的正式发言都必须通过此工具输出。直接输出的文本会被视为你的内心思考,用户无法阅读。reason 参数描述你想要回复的方式、想法和内容,系统会根据你的想法和对话上下文生成具体的回复。"
|
||||||
|
)
|
||||||
say_builder.add_param(
|
say_builder.add_param(
|
||||||
name="reason",
|
name="reason",
|
||||||
param_type=ToolParamType.STRING,
|
param_type=ToolParamType.STRING,
|
||||||
description="描述你想要回复的方式、想法和内容。例如:'同意对方的看法,并分享自己的经历' 或 '礼貌地拒绝,表示现在不方便聊天'",
|
description="描述你想要回复的方式、想法和内容。例如:'同意对方的看法,并分享自己的经历' 或 '礼貌地拒绝,表示现在不方便聊天'",
|
||||||
required=True,
|
required=True,
|
||||||
enum_values=None
|
enum_values=None,
|
||||||
)
|
)
|
||||||
tools.append(say_builder.build())
|
tools.append(say_builder.build())
|
||||||
|
|
||||||
# wait 工具
|
# wait 工具
|
||||||
wait_builder = ToolOptionBuilder()
|
wait_builder = ToolOptionBuilder()
|
||||||
wait_builder.set_name("wait")
|
wait_builder.set_name("wait")
|
||||||
wait_builder.set_description("暂时结束你的发言,把话语权交给用户,等待对方说话。这就像现实对话中你说完一句话后停下来等对方回应。如果用户在等待期间说了话,你会通过工具返回结果收到内容。如果超时没有回复,你也会收到超时通知。")
|
wait_builder.set_description(
|
||||||
|
"暂时结束你的发言,把话语权交给用户,等待对方说话。这就像现实对话中你说完一句话后停下来等对方回应。如果用户在等待期间说了话,你会通过工具返回结果收到内容。如果超时没有回复,你也会收到超时通知。"
|
||||||
|
)
|
||||||
wait_builder.add_param(
|
wait_builder.add_param(
|
||||||
name="seconds",
|
name="seconds",
|
||||||
param_type=ToolParamType.INTEGER,
|
param_type=ToolParamType.INTEGER,
|
||||||
description="等待的秒数。建议 3-10 秒。超过这个时间用户没有回复会显示超时提示。",
|
description="等待的秒数。建议 3-10 秒。超过这个时间用户没有回复会显示超时提示。",
|
||||||
required=True,
|
required=True,
|
||||||
enum_values=None
|
enum_values=None,
|
||||||
)
|
)
|
||||||
tools.append(wait_builder.build())
|
tools.append(wait_builder.build())
|
||||||
|
|
||||||
# stop 工具
|
# stop 工具
|
||||||
stop_builder = ToolOptionBuilder()
|
stop_builder = ToolOptionBuilder()
|
||||||
stop_builder.set_name("stop")
|
stop_builder.set_name("stop")
|
||||||
stop_builder.set_description("结束当前对话循环,进入待机状态,直到用户下次输入新内容时再唤醒你。当对话自然结束、用户表示不想继续聊、或连续多次等待超时用户没有回复时使用。")
|
stop_builder.set_description(
|
||||||
|
"结束当前对话循环,进入待机状态,直到用户下次输入新内容时再唤醒你。当对话自然结束、用户表示不想继续聊、或连续多次等待超时用户没有回复时使用。"
|
||||||
|
)
|
||||||
tools.append(stop_builder.build())
|
tools.append(stop_builder.build())
|
||||||
|
|
||||||
# store_context 工具
|
# store_context 工具
|
||||||
store_context_builder = ToolOptionBuilder()
|
store_context_builder = ToolOptionBuilder()
|
||||||
store_context_builder.set_name("store_context")
|
store_context_builder.set_name("store_context")
|
||||||
store_context_builder.set_description("将指定范围的对话上下文存入记忆系统,然后从当前对话中移除这些内容。适合在对话上下文过长、话题转换、或遇到重要内容需要保存时使用。")
|
store_context_builder.set_description(
|
||||||
|
"将指定范围的对话上下文存入记忆系统,然后从当前对话中移除这些内容。适合在对话上下文过长、话题转换、或遇到重要内容需要保存时使用。"
|
||||||
|
)
|
||||||
store_context_builder.add_param(
|
store_context_builder.add_param(
|
||||||
name="count",
|
name="count",
|
||||||
param_type=ToolParamType.INTEGER,
|
param_type=ToolParamType.INTEGER,
|
||||||
description="要保存的消息条数(从最早的对话开始计数)。建议 5-20 条。",
|
description="要保存的消息条数(从最早的对话开始计数)。建议 5-20 条。",
|
||||||
required=True,
|
required=True,
|
||||||
enum_values=None
|
enum_values=None,
|
||||||
)
|
)
|
||||||
store_context_builder.add_param(
|
store_context_builder.add_param(
|
||||||
name="reason",
|
name="reason",
|
||||||
param_type=ToolParamType.STRING,
|
param_type=ToolParamType.STRING,
|
||||||
description="保存原因,用于后续检索。例如:'讨论了用户的工作情况' 或 '用户分享了对电影的看法'",
|
description="保存原因,用于后续检索。例如:'讨论了用户的工作情况' 或 '用户分享了对电影的看法'",
|
||||||
required=True,
|
required=True,
|
||||||
enum_values=None
|
enum_values=None,
|
||||||
)
|
)
|
||||||
tools.append(store_context_builder.build())
|
tools.append(store_context_builder.build())
|
||||||
|
|
||||||
return tools
|
return tools
|
||||||
|
|
||||||
|
|
||||||
# 为了兼容性,创建一个函数来将工具转换为 dict 格式(用于调试显示)
|
# 为了兼容性,创建一个函数来将工具转换为 dict 格式(用于调试显示)
|
||||||
def builtin_tools_as_dicts() -> List[Dict[str, Any]]:
|
def builtin_tools_as_dicts() -> List[Dict[str, Any]]:
|
||||||
"""将内置工具转换为 dict 格式(用于调试)"""
|
"""将内置工具转换为 dict 格式(用于调试)"""
|
||||||
@@ -77,31 +87,23 @@ def builtin_tools_as_dicts() -> List[Dict[str, Any]]:
|
|||||||
"description": "对用户说话。你所有想让用户看到的正式发言都必须通过此工具输出。",
|
"description": "对用户说话。你所有想让用户看到的正式发言都必须通过此工具输出。",
|
||||||
"parameters": {
|
"parameters": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {"reason": {"type": "string", "description": "回复的想法和内容"}},
|
||||||
"reason": {"type": "string", "description": "回复的想法和内容"}
|
"required": ["reason"],
|
||||||
},
|
},
|
||||||
"required": ["reason"]
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "wait",
|
"name": "wait",
|
||||||
"description": "暂时结束发言,等待用户回应",
|
"description": "暂时结束发言,等待用户回应",
|
||||||
"parameters": {
|
"parameters": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {"seconds": {"type": "number", "description": "等待秒数"}},
|
||||||
"seconds": {"type": "number", "description": "等待秒数"}
|
"required": ["seconds"],
|
||||||
},
|
},
|
||||||
"required": ["seconds"]
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "stop",
|
"name": "stop",
|
||||||
"description": "结束对话循环",
|
"description": "结束对话循环",
|
||||||
"parameters": {
|
"parameters": {"type": "object", "properties": {}, "required": []},
|
||||||
"type": "object",
|
|
||||||
"properties": {},
|
|
||||||
"required": []
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "store_context",
|
"name": "store_context",
|
||||||
@@ -110,17 +112,19 @@ def builtin_tools_as_dicts() -> List[Dict[str, Any]]:
|
|||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"count": {"type": "number", "description": "保存的消息条数"},
|
"count": {"type": "number", "description": "保存的消息条数"},
|
||||||
"reason": {"type": "string", "description": "保存原因"}
|
"reason": {"type": "string", "description": "保存原因"},
|
||||||
},
|
},
|
||||||
"required": ["count", "reason"]
|
"required": ["count", "reason"],
|
||||||
}
|
},
|
||||||
}
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
# 导出工具创建函数和列表
|
# 导出工具创建函数和列表
|
||||||
def get_builtin_tools() -> List[ToolOption]:
|
def get_builtin_tools() -> List[ToolOption]:
|
||||||
"""获取内置工具列表"""
|
"""获取内置工具列表"""
|
||||||
return create_builtin_tools()
|
return create_builtin_tools()
|
||||||
|
|
||||||
|
|
||||||
# 为了向后兼容,也导出 dict 格式
|
# 为了向后兼容,也导出 dict 格式
|
||||||
BUILTIN_TOOLS_DICTS = builtin_tools_as_dicts()
|
BUILTIN_TOOLS_DICTS = builtin_tools_as_dicts()
|
||||||
|
|||||||
@@ -13,10 +13,17 @@ from rich.markdown import Markdown
|
|||||||
from rich.text import Text
|
from rich.text import Text
|
||||||
from rich import box
|
from rich import box
|
||||||
|
|
||||||
from config import console, ENABLE_EMOTION_MODULE, ENABLE_COGNITION_MODULE, ENABLE_TIMING_MODULE, ENABLE_KNOWLEDGE_MODULE, ENABLE_MCP
|
from config import (
|
||||||
|
console,
|
||||||
|
ENABLE_EMOTION_MODULE,
|
||||||
|
ENABLE_COGNITION_MODULE,
|
||||||
|
ENABLE_TIMING_MODULE,
|
||||||
|
ENABLE_KNOWLEDGE_MODULE,
|
||||||
|
ENABLE_MCP,
|
||||||
|
)
|
||||||
from input_reader import InputReader
|
from input_reader import InputReader
|
||||||
from timing import build_timing_info
|
from timing import build_timing_info
|
||||||
from knowledge import store_knowledge_from_context, retrieve_relevant_knowledge, build_knowledge_summary
|
from knowledge import store_knowledge_from_context, retrieve_relevant_knowledge
|
||||||
from knowledge_store import get_knowledge_store
|
from knowledge_store import get_knowledge_store
|
||||||
from llm_service import MaiSakaLLMService, build_message, remove_last_perception
|
from llm_service import MaiSakaLLMService, build_message, remove_last_perception
|
||||||
from mcp_client import MCPManager
|
from mcp_client import MCPManager
|
||||||
@@ -64,11 +71,7 @@ class BufferCLI:
|
|||||||
def _init_llm(self):
|
def _init_llm(self):
|
||||||
"""初始化 LLM 服务 - 使用主项目配置系统"""
|
"""初始化 LLM 服务 - 使用主项目配置系统"""
|
||||||
thinking_env = os.getenv("ENABLE_THINKING", "").strip().lower()
|
thinking_env = os.getenv("ENABLE_THINKING", "").strip().lower()
|
||||||
enable_thinking: Optional[bool] = (
|
enable_thinking: Optional[bool] = True if thinking_env == "true" else False if thinking_env == "false" else None
|
||||||
True if thinking_env == "true"
|
|
||||||
else False if thinking_env == "false"
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
|
|
||||||
# MaiSakaLLMService 现在使用主项目的配置系统
|
# MaiSakaLLMService 现在使用主项目的配置系统
|
||||||
# 参数仅为兼容性保留,实际从 config_manager 读取配置
|
# 参数仅为兼容性保留,实际从 config_manager 读取配置
|
||||||
@@ -210,7 +213,7 @@ class BufferCLI:
|
|||||||
to_compress,
|
to_compress,
|
||||||
store_result_callback=lambda cat_id, cat_name, content: console.print(
|
store_result_callback=lambda cat_id, cat_name, content: console.print(
|
||||||
f"[muted] [OK] 存储了解信息: {cat_name}[/muted]"
|
f"[muted] [OK] 存储了解信息: {cat_name}[/muted]"
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
if knowledge_count > 0:
|
if knowledge_count > 0:
|
||||||
console.print(f"[success][OK] 了解模块: 存储{knowledge_count}条特征信息[/success]")
|
console.print(f"[success][OK] 了解模块: 存储{knowledge_count}条特征信息[/success]")
|
||||||
@@ -272,10 +275,12 @@ class BufferCLI:
|
|||||||
self._chat_history = self.llm_service.build_chat_context(user_text)
|
self._chat_history = self.llm_service.build_chat_context(user_text)
|
||||||
else:
|
else:
|
||||||
# 后续对话:追加用户消息到已有上下文
|
# 后续对话:追加用户消息到已有上下文
|
||||||
self._chat_history.append({
|
self._chat_history.append(
|
||||||
"role": "user",
|
{
|
||||||
"content": user_text,
|
"role": "user",
|
||||||
})
|
"content": user_text,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
await self._run_llm_loop(self._chat_history)
|
await self._run_llm_loop(self._chat_history)
|
||||||
|
|
||||||
@@ -436,16 +441,17 @@ class BufferCLI:
|
|||||||
|
|
||||||
if perception_parts:
|
if perception_parts:
|
||||||
# 添加感知消息(AI 的感知能力结果)
|
# 添加感知消息(AI 的感知能力结果)
|
||||||
chat_history.append(build_message(
|
chat_history.append(
|
||||||
role="assistant",
|
build_message(
|
||||||
content="\n\n".join(perception_parts),
|
role="assistant",
|
||||||
msg_type="perception",
|
content="\n\n".join(perception_parts),
|
||||||
))
|
msg_type="perception",
|
||||||
|
)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# 上次没有调用工具,跳过模块分析
|
# 上次没有调用工具,跳过模块分析
|
||||||
console.print("[muted]ℹ️ 上次未调用工具,跳过模块分析[/muted]")
|
console.print("[muted]ℹ️ 上次未调用工具,跳过模块分析[/muted]")
|
||||||
|
|
||||||
|
|
||||||
# ── 调用 LLM ──
|
# ── 调用 LLM ──
|
||||||
with console.status("[info]💬 AI 正在思考...[/info]", spinner="dots"):
|
with console.status("[info]💬 AI 正在思考...[/info]", spinner="dots"):
|
||||||
try:
|
try:
|
||||||
@@ -540,7 +546,8 @@ class BufferCLI:
|
|||||||
async def _init_mcp(self):
|
async def _init_mcp(self):
|
||||||
"""初始化 MCP 服务器连接,发现并注册外部工具。"""
|
"""初始化 MCP 服务器连接,发现并注册外部工具。"""
|
||||||
config_path = os.path.join(
|
config_path = os.path.join(
|
||||||
os.path.dirname(os.path.abspath(__file__)), "mcp_config.json",
|
os.path.dirname(os.path.abspath(__file__)),
|
||||||
|
"mcp_config.json",
|
||||||
)
|
)
|
||||||
self._mcp_manager = await MCPManager.from_config(config_path)
|
self._mcp_manager = await MCPManager.from_config(config_path)
|
||||||
|
|
||||||
|
|||||||
@@ -15,16 +15,20 @@ if str(_root) not in sys.path:
|
|||||||
|
|
||||||
# ──────────────────── 从主配置读取 ────────────────────
|
# ──────────────────── 从主配置读取 ────────────────────
|
||||||
|
|
||||||
|
|
||||||
def _get_maisaka_config():
|
def _get_maisaka_config():
|
||||||
"""获取 MaiSaka 配置"""
|
"""获取 MaiSaka 配置"""
|
||||||
try:
|
try:
|
||||||
from src.config.config import config_manager
|
from src.config.config import config_manager
|
||||||
|
|
||||||
return config_manager.config.maisaka
|
return config_manager.config.maisaka
|
||||||
except Exception:
|
except Exception:
|
||||||
# 如果配置加载失败,返回默认值
|
# 如果配置加载失败,返回默认值
|
||||||
from src.config.official_configs import MaiSakaConfig
|
from src.config.official_configs import MaiSakaConfig
|
||||||
|
|
||||||
return MaiSakaConfig()
|
return MaiSakaConfig()
|
||||||
|
|
||||||
|
|
||||||
_maisaka_config = _get_maisaka_config()
|
_maisaka_config = _get_maisaka_config()
|
||||||
|
|
||||||
# ──────────────────── 模块开关配置 ────────────────────
|
# ──────────────────── 模块开关配置 ────────────────────
|
||||||
|
|||||||
@@ -48,9 +48,7 @@ class DebugViewer:
|
|||||||
conn = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
conn = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||||
conn.connect(("127.0.0.1", self._port))
|
conn.connect(("127.0.0.1", self._port))
|
||||||
self._conn = conn
|
self._conn = conn
|
||||||
console.print(
|
console.print(f"[success]✓ 调试窗口已启动[/success] [muted](port {self._port})[/muted]")
|
||||||
f"[success]✓ 调试窗口已启动[/success] [muted](port {self._port})[/muted]"
|
|
||||||
)
|
|
||||||
return
|
return
|
||||||
except ConnectionRefusedError:
|
except ConnectionRefusedError:
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|||||||
@@ -16,10 +16,10 @@ from rich import box
|
|||||||
console = Console()
|
console = Console()
|
||||||
|
|
||||||
ROLE_STYLES = {
|
ROLE_STYLES = {
|
||||||
"system": ("📋", "bold blue"),
|
"system": ("📋", "bold blue"),
|
||||||
"user": ("👤", "bold green"),
|
"user": ("👤", "bold green"),
|
||||||
"assistant": ("🤖", "bold magenta"),
|
"assistant": ("🤖", "bold magenta"),
|
||||||
"tool": ("🔧", "bold yellow"),
|
"tool": ("🔧", "bold yellow"),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -54,8 +54,10 @@ def format_message(idx: int, msg: dict) -> str:
|
|||||||
|
|
||||||
# 正文
|
# 正文
|
||||||
if content:
|
if content:
|
||||||
display = content if len(content) <= 3000 else (
|
display = (
|
||||||
content[:3000] + f"\n[dim]... (截断, 共 {len(content)} 字符)[/dim]"
|
content
|
||||||
|
if len(content) <= 3000
|
||||||
|
else (content[:3000] + f"\n[dim]... (截断, 共 {len(content)} 字符)[/dim]")
|
||||||
)
|
)
|
||||||
parts.append(display)
|
parts.append(display)
|
||||||
|
|
||||||
@@ -88,8 +90,7 @@ def main():
|
|||||||
|
|
||||||
console.print(
|
console.print(
|
||||||
Panel(
|
Panel(
|
||||||
f"[bold cyan]MaiSaka Debug Viewer[/bold cyan]\n"
|
f"[bold cyan]MaiSaka Debug Viewer[/bold cyan]\n[dim]监听端口: {port} 等待主进程连接...[/dim]",
|
||||||
f"[dim]监听端口: {port} 等待主进程连接...[/dim]",
|
|
||||||
box=box.DOUBLE_EDGE,
|
box=box.DOUBLE_EDGE,
|
||||||
border_style="cyan",
|
border_style="cyan",
|
||||||
)
|
)
|
||||||
@@ -131,8 +132,7 @@ def main():
|
|||||||
# ── 标题栏 ──
|
# ── 标题栏 ──
|
||||||
console.print(f"\n{'═' * 90}")
|
console.print(f"\n{'═' * 90}")
|
||||||
console.print(
|
console.print(
|
||||||
f"[bold yellow]#{call_count} {label}[/bold yellow] "
|
f"[bold yellow]#{call_count} {label}[/bold yellow] [dim]({len(messages)} messages)[/dim]"
|
||||||
f"[dim]({len(messages)} messages)[/dim]"
|
|
||||||
)
|
)
|
||||||
console.print(f"{'═' * 90}")
|
console.print(f"{'═' * 90}")
|
||||||
|
|
||||||
@@ -144,12 +144,8 @@ def main():
|
|||||||
|
|
||||||
# ── tools 信息 ──
|
# ── tools 信息 ──
|
||||||
if tools:
|
if tools:
|
||||||
tool_names = [
|
tool_names = [t.get("function", {}).get("name", "?") for t in tools]
|
||||||
t.get("function", {}).get("name", "?") for t in tools
|
console.print(f"\n[dim]可用工具: {', '.join(tool_names)}[/dim]")
|
||||||
]
|
|
||||||
console.print(
|
|
||||||
f"\n[dim]可用工具: {', '.join(tool_names)}[/dim]"
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
console.print(f"\n[red]数据处理错误: {e}[/red]")
|
console.print(f"\n[red]数据处理错误: {e}[/red]")
|
||||||
console.print(f"[dim]Payload: {payload}[/dim]")
|
console.print(f"[dim]Payload: {payload}[/dim]")
|
||||||
@@ -161,8 +157,12 @@ def main():
|
|||||||
console.print("\n[bold cyan]📤 LLM 响应:[/bold cyan]")
|
console.print("\n[bold cyan]📤 LLM 响应:[/bold cyan]")
|
||||||
resp_content = response.get("content", "")
|
resp_content = response.get("content", "")
|
||||||
if resp_content:
|
if resp_content:
|
||||||
display = resp_content if len(str(resp_content)) <= 3000 else (
|
display = (
|
||||||
str(resp_content)[:3000] + f"\n[dim]... (截断, 共 {len(str(resp_content))} 字符)[/dim]"
|
resp_content
|
||||||
|
if len(str(resp_content)) <= 3000
|
||||||
|
else (
|
||||||
|
str(resp_content)[:3000] + f"\n[dim]... (截断, 共 {len(str(resp_content))} 字符)[/dim]"
|
||||||
|
)
|
||||||
)
|
)
|
||||||
console.print(Panel(display, border_style="cyan", padding=(0, 1)))
|
console.print(Panel(display, border_style="cyan", padding=(0, 1)))
|
||||||
resp_tool_calls = response.get("tool_calls", [])
|
resp_tool_calls = response.get("tool_calls", [])
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ MaiSaka - 了解模块
|
|||||||
负责从对话中提取和存储用户个人特征信息。
|
负责从对话中提取和存储用户个人特征信息。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import List, Optional
|
from typing import List
|
||||||
from knowledge_store import get_knowledge_store, KNOWLEDGE_CATEGORIES
|
from knowledge_store import get_knowledge_store, KNOWLEDGE_CATEGORIES
|
||||||
|
|
||||||
|
|
||||||
@@ -100,9 +100,7 @@ async def store_knowledge_from_context(
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# 第一步:分析涉及哪些分类
|
# 第一步:分析涉及哪些分类
|
||||||
category_ids = await llm_service.analyze_knowledge_categories(
|
category_ids = await llm_service.analyze_knowledge_categories(context_messages, categories_summary)
|
||||||
context_messages, categories_summary
|
|
||||||
)
|
|
||||||
|
|
||||||
if not category_ids:
|
if not category_ids:
|
||||||
return 0
|
return 0
|
||||||
@@ -119,25 +117,19 @@ async def store_knowledge_from_context(
|
|||||||
if extracted_content:
|
if extracted_content:
|
||||||
# 存储到了解列表
|
# 存储到了解列表
|
||||||
success = store.add_knowledge(
|
success = store.add_knowledge(
|
||||||
category_id=category_id,
|
category_id=category_id, content=extracted_content, metadata={"source": "context_compression"}
|
||||||
content=extracted_content,
|
|
||||||
metadata={"source": "context_compression"}
|
|
||||||
)
|
)
|
||||||
if success:
|
if success:
|
||||||
stored_count += 1
|
stored_count += 1
|
||||||
if store_result_callback:
|
if store_result_callback:
|
||||||
store_result_callback(
|
store_result_callback(category_id, store.get_category_name(category_id), extracted_content)
|
||||||
category_id,
|
except Exception:
|
||||||
store.get_category_name(category_id),
|
|
||||||
extracted_content
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
# 单个分类失败不影响其他分类
|
# 单个分类失败不影响其他分类
|
||||||
continue
|
continue
|
||||||
|
|
||||||
return stored_count
|
return stored_count
|
||||||
|
|
||||||
except Exception as e:
|
except Exception:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|
||||||
@@ -165,9 +157,7 @@ async def retrieve_relevant_knowledge(
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# 分析需要哪些分类
|
# 分析需要哪些分类
|
||||||
category_ids = await llm_service.analyze_knowledge_need(
|
category_ids = await llm_service.analyze_knowledge_need(chat_history, categories_summary)
|
||||||
chat_history, categories_summary
|
|
||||||
)
|
|
||||||
|
|
||||||
if not category_ids:
|
if not category_ids:
|
||||||
return ""
|
return ""
|
||||||
|
|||||||
@@ -45,9 +45,7 @@ class KnowledgeStore:
|
|||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""初始化了解存储"""
|
"""初始化了解存储"""
|
||||||
self._knowledge: Dict[str, List[Dict[str, Any]]] = {
|
self._knowledge: Dict[str, List[Dict[str, Any]]] = {category_id: [] for category_id in KNOWLEDGE_CATEGORIES}
|
||||||
category_id: [] for category_id in KNOWLEDGE_CATEGORIES
|
|
||||||
}
|
|
||||||
self._ensure_data_dir()
|
self._ensure_data_dir()
|
||||||
self._load()
|
self._load()
|
||||||
|
|
||||||
@@ -58,9 +56,7 @@ class KnowledgeStore:
|
|||||||
def _load(self):
|
def _load(self):
|
||||||
"""从文件加载了解数据"""
|
"""从文件加载了解数据"""
|
||||||
if not KNOWLEDGE_FILE.exists():
|
if not KNOWLEDGE_FILE.exists():
|
||||||
self._knowledge = {
|
self._knowledge = {category_id: [] for category_id in KNOWLEDGE_CATEGORIES}
|
||||||
category_id: [] for category_id in KNOWLEDGE_CATEGORIES
|
|
||||||
}
|
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -73,9 +69,7 @@ class KnowledgeStore:
|
|||||||
self._knowledge = loaded
|
self._knowledge = loaded
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"[warning]加载了解数据失败: {e}[/warning]")
|
print(f"[warning]加载了解数据失败: {e}[/warning]")
|
||||||
self._knowledge = {
|
self._knowledge = {category_id: [] for category_id in KNOWLEDGE_CATEGORIES}
|
||||||
category_id: [] for category_id in KNOWLEDGE_CATEGORIES
|
|
||||||
}
|
|
||||||
|
|
||||||
def _save(self):
|
def _save(self):
|
||||||
"""保存了解数据到文件"""
|
"""保存了解数据到文件"""
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ MaiSaka LLM 服务 - 使用主项目 LLM 系统
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import os
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import List, Optional, Literal
|
from typing import List, Optional, Literal
|
||||||
|
|
||||||
@@ -34,6 +33,7 @@ MSG_TYPE_FIELD = "_type"
|
|||||||
@dataclass
|
@dataclass
|
||||||
class ToolCall:
|
class ToolCall:
|
||||||
"""工具调用信息"""
|
"""工具调用信息"""
|
||||||
|
|
||||||
id: str
|
id: str
|
||||||
name: str
|
name: str
|
||||||
arguments: dict
|
arguments: dict
|
||||||
@@ -42,6 +42,7 @@ class ToolCall:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class ChatResponse:
|
class ChatResponse:
|
||||||
"""LLM 对话循环单步响应"""
|
"""LLM 对话循环单步响应"""
|
||||||
|
|
||||||
content: Optional[str]
|
content: Optional[str]
|
||||||
tool_calls: List[ToolCall]
|
tool_calls: List[ToolCall]
|
||||||
raw_message: dict # 可直接追加到对话历史的消息字典
|
raw_message: dict # 可直接追加到对话历史的消息字典
|
||||||
@@ -49,6 +50,7 @@ class ChatResponse:
|
|||||||
|
|
||||||
# ──────────────────── 工具函数 ────────────────────
|
# ──────────────────── 工具函数 ────────────────────
|
||||||
|
|
||||||
|
|
||||||
def build_message(role: str, content: str, msg_type: MessageType = "user", **kwargs) -> dict:
|
def build_message(role: str, content: str, msg_type: MessageType = "user", **kwargs) -> dict:
|
||||||
"""构建消息字典,包含消息类型标记。"""
|
"""构建消息字典,包含消息类型标记。"""
|
||||||
msg = {"role": role, "content": content, MSG_TYPE_FIELD: msg_type, **kwargs}
|
msg = {"role": role, "content": content, MSG_TYPE_FIELD: msg_type, **kwargs}
|
||||||
@@ -93,23 +95,18 @@ class MaiSakaLLMService:
|
|||||||
except Exception:
|
except Exception:
|
||||||
# 如果配置加载失败,使用默认配置
|
# 如果配置加载失败,使用默认配置
|
||||||
from src.config.model_configs import ModelTaskConfig
|
from src.config.model_configs import ModelTaskConfig
|
||||||
|
|
||||||
self._model_configs = ModelTaskConfig()
|
self._model_configs = ModelTaskConfig()
|
||||||
logger.warning("无法加载主项目模型配置,使用默认配置")
|
logger.warning("无法加载主项目模型配置,使用默认配置")
|
||||||
|
|
||||||
# 初始化 LLMRequest 实例(只使用 tool_use 和 replyer)
|
# 初始化 LLMRequest 实例(只使用 tool_use 和 replyer)
|
||||||
self._llm_tool_use = LLMRequest(
|
self._llm_tool_use = LLMRequest(model_set=self._model_configs.tool_use, request_type="maisaka_tool_use")
|
||||||
model_set=self._model_configs.tool_use,
|
|
||||||
request_type="maisaka_tool_use"
|
|
||||||
)
|
|
||||||
# 主对话也使用 tool_use 模型(因为需要工具调用支持)
|
# 主对话也使用 tool_use 模型(因为需要工具调用支持)
|
||||||
self._llm_chat = self._llm_tool_use
|
self._llm_chat = self._llm_tool_use
|
||||||
# 分析模块也使用 tool_use 模型
|
# 分析模块也使用 tool_use 模型
|
||||||
self._llm_utils = self._llm_tool_use
|
self._llm_utils = self._llm_tool_use
|
||||||
# 回复生成使用 replyer 模型
|
# 回复生成使用 replyer 模型
|
||||||
self._llm_replyer = LLMRequest(
|
self._llm_replyer = LLMRequest(model_set=self._model_configs.replyer, request_type="maisaka_replyer")
|
||||||
model_set=self._model_configs.replyer,
|
|
||||||
request_type="maisaka_replyer"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 尝试修复数据库 schema(忽略错误)
|
# 尝试修复数据库 schema(忽略错误)
|
||||||
self._try_fix_database_schema()
|
self._try_fix_database_schema()
|
||||||
@@ -133,6 +130,7 @@ class MaiSakaLLMService:
|
|||||||
|
|
||||||
chat_prompt.add_context("file_tools_section", tools_section if tools_section else "")
|
chat_prompt.add_context("file_tools_section", tools_section if tools_section else "")
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
loop = asyncio.new_event_loop()
|
loop = asyncio.new_event_loop()
|
||||||
asyncio.set_event_loop(loop)
|
asyncio.set_event_loop(loop)
|
||||||
try:
|
try:
|
||||||
@@ -147,7 +145,9 @@ class MaiSakaLLMService:
|
|||||||
self._chat_system_prompt = chat_system_prompt
|
self._chat_system_prompt = chat_system_prompt
|
||||||
|
|
||||||
# 获取模型名称用于显示
|
# 获取模型名称用于显示
|
||||||
self._model_name = self._model_configs.tool_use.model_list[0] if self._model_configs.tool_use.model_list else "未配置"
|
self._model_name = (
|
||||||
|
self._model_configs.tool_use.model_list[0] if self._model_configs.tool_use.model_list else "未配置"
|
||||||
|
)
|
||||||
|
|
||||||
# 加载子模块提示词
|
# 加载子模块提示词
|
||||||
self._emotion_prompt: Optional[str] = None
|
self._emotion_prompt: Optional[str] = None
|
||||||
@@ -157,21 +157,22 @@ class MaiSakaLLMService:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
loop = asyncio.new_event_loop()
|
loop = asyncio.new_event_loop()
|
||||||
asyncio.set_event_loop(loop)
|
asyncio.set_event_loop(loop)
|
||||||
try:
|
try:
|
||||||
self._emotion_prompt = loop.run_until_complete(prompt_manager.render_prompt(
|
self._emotion_prompt = loop.run_until_complete(
|
||||||
prompt_manager.get_prompt("maidairy_emotion")
|
prompt_manager.render_prompt(prompt_manager.get_prompt("maidairy_emotion"))
|
||||||
))
|
)
|
||||||
self._cognition_prompt = loop.run_until_complete(prompt_manager.render_prompt(
|
self._cognition_prompt = loop.run_until_complete(
|
||||||
prompt_manager.get_prompt("maidairy_cognition")
|
prompt_manager.render_prompt(prompt_manager.get_prompt("maidairy_cognition"))
|
||||||
))
|
)
|
||||||
self._timing_prompt = loop.run_until_complete(prompt_manager.render_prompt(
|
self._timing_prompt = loop.run_until_complete(
|
||||||
prompt_manager.get_prompt("maidairy_timing")
|
prompt_manager.render_prompt(prompt_manager.get_prompt("maidairy_timing"))
|
||||||
))
|
)
|
||||||
self._context_summarize_prompt = loop.run_until_complete(prompt_manager.render_prompt(
|
self._context_summarize_prompt = loop.run_until_complete(
|
||||||
prompt_manager.get_prompt("maidairy_context_summarize")
|
prompt_manager.render_prompt(prompt_manager.get_prompt("maidairy_context_summarize"))
|
||||||
))
|
)
|
||||||
logger.info("成功加载 MaiSaka 子模块提示词")
|
logger.info("成功加载 MaiSaka 子模块提示词")
|
||||||
finally:
|
finally:
|
||||||
loop.close()
|
loop.close()
|
||||||
@@ -191,9 +192,7 @@ class MaiSakaLLMService:
|
|||||||
|
|
||||||
if "model_api_provider_name" not in columns:
|
if "model_api_provider_name" not in columns:
|
||||||
# 添加缺失的列
|
# 添加缺失的列
|
||||||
session.execute(text(
|
session.execute(text("ALTER TABLE llm_usage ADD COLUMN model_api_provider_name VARCHAR(255)"))
|
||||||
"ALTER TABLE llm_usage ADD COLUMN model_api_provider_name VARCHAR(255)"
|
|
||||||
))
|
|
||||||
session.commit()
|
session.commit()
|
||||||
logger.info("数据库 schema 已修复:添加 model_api_provider_name 列")
|
logger.info("数据库 schema 已修复:添加 model_api_provider_name 列")
|
||||||
except Exception:
|
except Exception:
|
||||||
@@ -205,7 +204,7 @@ class MaiSakaLLMService:
|
|||||||
self._extra_tools = list(tools)
|
self._extra_tools = list(tools)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _tool_option_to_dict(tool: 'ToolOption') -> dict:
|
def _tool_option_to_dict(tool: "ToolOption") -> dict:
|
||||||
"""将 ToolOption 对象转换为主项目期望的 dict 格式
|
"""将 ToolOption 对象转换为主项目期望的 dict 格式
|
||||||
|
|
||||||
主项目的 _build_tool_options() 期望的格式:
|
主项目的 _build_tool_options() 期望的格式:
|
||||||
@@ -218,18 +217,8 @@ class MaiSakaLLMService:
|
|||||||
params = []
|
params = []
|
||||||
if tool.params:
|
if tool.params:
|
||||||
for param in tool.params:
|
for param in tool.params:
|
||||||
params.append((
|
params.append((param.name, param.param_type, param.description, param.required, param.enum_values))
|
||||||
param.name,
|
return {"name": tool.name, "description": tool.description, "parameters": params}
|
||||||
param.param_type,
|
|
||||||
param.description,
|
|
||||||
param.required,
|
|
||||||
param.enum_values
|
|
||||||
))
|
|
||||||
return {
|
|
||||||
"name": tool.name,
|
|
||||||
"description": tool.description,
|
|
||||||
"parameters": params
|
|
||||||
}
|
|
||||||
|
|
||||||
async def chat_loop_step(self, chat_history: List[dict]) -> ChatResponse:
|
async def chat_loop_step(self, chat_history: List[dict]) -> ChatResponse:
|
||||||
"""执行对话循环的一步 - 使用 tool_use 模型"""
|
"""执行对话循环的一步 - 使用 tool_use 模型"""
|
||||||
@@ -271,11 +260,13 @@ class MaiSakaLLMService:
|
|||||||
for tc in msg["tool_calls"]:
|
for tc in msg["tool_calls"]:
|
||||||
tc_func = tc.get("function", {})
|
tc_func = tc.get("function", {})
|
||||||
# 主项目的 ToolCall: call_id, func_name, args
|
# 主项目的 ToolCall: call_id, func_name, args
|
||||||
tool_calls_list.append(ToolCallOption(
|
tool_calls_list.append(
|
||||||
call_id=tc.get("id", ""),
|
ToolCallOption(
|
||||||
func_name=tc_func.get("name", ""),
|
call_id=tc.get("id", ""),
|
||||||
args=json.loads(tc_func.get("arguments", "{}")) if tc_func.get("arguments") else {}
|
func_name=tc_func.get("name", ""),
|
||||||
))
|
args=json.loads(tc_func.get("arguments", "{}")) if tc_func.get("arguments") else {},
|
||||||
|
)
|
||||||
|
)
|
||||||
builder.set_tool_calls(tool_calls_list)
|
builder.set_tool_calls(tool_calls_list)
|
||||||
elif role == "tool" and "tool_call_id" in msg:
|
elif role == "tool" and "tool_call_id" in msg:
|
||||||
builder.add_tool_call(msg["tool_call_id"])
|
builder.add_tool_call(msg["tool_call_id"])
|
||||||
@@ -290,15 +281,17 @@ class MaiSakaLLMService:
|
|||||||
|
|
||||||
# 调用 LLM(使用带消息的接口)
|
# 调用 LLM(使用带消息的接口)
|
||||||
# 合并内置工具和额外工具(将 ToolOption 对象转换为 dict)
|
# 合并内置工具和额外工具(将 ToolOption 对象转换为 dict)
|
||||||
all_tools = [self._tool_option_to_dict(t) for t in get_builtin_tools()] + (self._extra_tools if self._extra_tools else [])
|
all_tools = [self._tool_option_to_dict(t) for t in get_builtin_tools()] + (
|
||||||
|
self._extra_tools if self._extra_tools else []
|
||||||
|
)
|
||||||
|
|
||||||
# 打印消息列表
|
# 打印消息列表
|
||||||
built_messages = message_factory(None)
|
built_messages = message_factory(None)
|
||||||
print("\n" + "="*60)
|
print("\n" + "=" * 60)
|
||||||
print("MaiSaka LLM Request - chat_loop_step:")
|
print("MaiSaka LLM Request - chat_loop_step:")
|
||||||
for msg in built_messages:
|
for msg in built_messages:
|
||||||
print(f" {msg}")
|
print(f" {msg}")
|
||||||
print("="*60 + "\n")
|
print("=" * 60 + "\n")
|
||||||
|
|
||||||
response, (reasoning, model, tool_calls) = await self._llm_chat.generate_response_with_message_async(
|
response, (reasoning, model, tool_calls) = await self._llm_chat.generate_response_with_message_async(
|
||||||
message_factory=message_factory,
|
message_factory=message_factory,
|
||||||
@@ -312,15 +305,17 @@ class MaiSakaLLMService:
|
|||||||
if tool_calls:
|
if tool_calls:
|
||||||
for tc in tool_calls:
|
for tc in tool_calls:
|
||||||
# 主项目的 ToolCall 有 call_id, func_name, args
|
# 主项目的 ToolCall 有 call_id, func_name, args
|
||||||
call_id = tc.call_id if hasattr(tc, 'call_id') else ""
|
call_id = tc.call_id if hasattr(tc, "call_id") else ""
|
||||||
func_name = tc.func_name if hasattr(tc, 'func_name') else ""
|
func_name = tc.func_name if hasattr(tc, "func_name") else ""
|
||||||
args = tc.args if hasattr(tc, 'args') else {}
|
args = tc.args if hasattr(tc, "args") else {}
|
||||||
|
|
||||||
converted_tool_calls.append(ToolCall(
|
converted_tool_calls.append(
|
||||||
id=call_id,
|
ToolCall(
|
||||||
name=func_name,
|
id=call_id,
|
||||||
arguments=args,
|
name=func_name,
|
||||||
))
|
arguments=args,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# 构建原始消息格式(MaiSaka 风格)
|
# 构建原始消息格式(MaiSaka 风格)
|
||||||
raw_message = {"role": "assistant", "content": response}
|
raw_message = {"role": "assistant", "content": response}
|
||||||
@@ -394,10 +389,10 @@ class MaiSakaLLMService:
|
|||||||
|
|
||||||
prompt = "\n".join(prompt_parts)
|
prompt = "\n".join(prompt_parts)
|
||||||
|
|
||||||
print("\n" + "="*60)
|
print("\n" + "=" * 60)
|
||||||
print("MaiSaka LLM Request - analyze_emotion:")
|
print("MaiSaka LLM Request - analyze_emotion:")
|
||||||
print(f" {prompt}")
|
print(f" {prompt}")
|
||||||
print("="*60 + "\n")
|
print("=" * 60 + "\n")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response, _ = await self._llm_utils.generate_response_async(
|
response, _ = await self._llm_utils.generate_response_async(
|
||||||
@@ -428,10 +423,10 @@ class MaiSakaLLMService:
|
|||||||
|
|
||||||
prompt = "\n".join(prompt_parts)
|
prompt = "\n".join(prompt_parts)
|
||||||
|
|
||||||
print("\n" + "="*60)
|
print("\n" + "=" * 60)
|
||||||
print("MaiSaka LLM Request - analyze_cognition:")
|
print("MaiSaka LLM Request - analyze_cognition:")
|
||||||
print(f" {prompt}")
|
print(f" {prompt}")
|
||||||
print("="*60 + "\n")
|
print("=" * 60 + "\n")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response, _ = await self._llm_utils.generate_response_async(
|
response, _ = await self._llm_utils.generate_response_async(
|
||||||
@@ -463,10 +458,10 @@ class MaiSakaLLMService:
|
|||||||
|
|
||||||
prompt = "\n".join(prompt_parts)
|
prompt = "\n".join(prompt_parts)
|
||||||
|
|
||||||
print("\n" + "="*60)
|
print("\n" + "=" * 60)
|
||||||
print("MaiSaka LLM Request - analyze_timing:")
|
print("MaiSaka LLM Request - analyze_timing:")
|
||||||
print(f" {prompt}")
|
print(f" {prompt}")
|
||||||
print("="*60 + "\n")
|
print("=" * 60 + "\n")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response, _ = await self._llm_utils.generate_response_async(
|
response, _ = await self._llm_utils.generate_response_async(
|
||||||
@@ -498,10 +493,10 @@ class MaiSakaLLMService:
|
|||||||
|
|
||||||
prompt = "\n".join(prompt_parts)
|
prompt = "\n".join(prompt_parts)
|
||||||
|
|
||||||
print("\n" + "="*60)
|
print("\n" + "=" * 60)
|
||||||
print("MaiSaka LLM Request - summarize_context:")
|
print("MaiSaka LLM Request - summarize_context:")
|
||||||
print(f" {prompt}")
|
print(f" {prompt}")
|
||||||
print("="*60 + "\n")
|
print("=" * 60 + "\n")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response, _ = await self._llm_utils.generate_response_async(
|
response, _ = await self._llm_utils.generate_response_async(
|
||||||
@@ -529,8 +524,7 @@ class MaiSakaLLMService:
|
|||||||
|
|
||||||
# 格式化对话历史
|
# 格式化对话历史
|
||||||
filtered_history = [
|
filtered_history = [
|
||||||
msg for msg in chat_history
|
msg for msg in chat_history if msg.get("role") != "system" and msg.get("_type") != "perception"
|
||||||
if msg.get("role") != "system" and msg.get("_type") != "perception"
|
|
||||||
]
|
]
|
||||||
formatted_history = format_chat_history(filtered_history)
|
formatted_history = format_chat_history(filtered_history)
|
||||||
|
|
||||||
@@ -542,18 +536,15 @@ class MaiSakaLLMService:
|
|||||||
system_prompt = "你是一个友好的 AI 助手,请根据用户的想法生成自然的回复。"
|
system_prompt = "你是一个友好的 AI 助手,请根据用户的想法生成自然的回复。"
|
||||||
|
|
||||||
user_prompt = (
|
user_prompt = (
|
||||||
f"当前时间:{current_time}\n\n"
|
f"当前时间:{current_time}\n\n【聊天记录】\n{formatted_history}\n\n【你的想法】\n{reason}\n\n现在,你说:"
|
||||||
f"【聊天记录】\n{formatted_history}\n\n"
|
|
||||||
f"【你的想法】\n{reason}\n\n"
|
|
||||||
f"现在,你说:"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
messages = f"System: {system_prompt}\n\nUser: {user_prompt}"
|
messages = f"System: {system_prompt}\n\nUser: {user_prompt}"
|
||||||
|
|
||||||
print("\n" + "="*60)
|
print("\n" + "=" * 60)
|
||||||
print("MaiSaka LLM Request - generate_reply:")
|
print("MaiSaka LLM Request - generate_reply:")
|
||||||
print(f" {messages}")
|
print(f" {messages}")
|
||||||
print("="*60 + "\n")
|
print("=" * 60 + "\n")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response, _ = await self._llm_replyer.generate_response_async(
|
response, _ = await self._llm_replyer.generate_response_async(
|
||||||
|
|||||||
@@ -95,9 +95,7 @@ def load_mcp_config(config_path: str = "mcp_config.json") -> list[MCPServerConfi
|
|||||||
)
|
)
|
||||||
|
|
||||||
if server.transport_type == "unknown":
|
if server.transport_type == "unknown":
|
||||||
console.print(
|
console.print(f"[warning]⚠️ MCP 服务器 '{name}' 缺少 command 或 url,已跳过[/warning]")
|
||||||
f"[warning]⚠️ MCP 服务器 '{name}' 缺少 command 或 url,已跳过[/warning]"
|
|
||||||
)
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
configs.append(server)
|
configs.append(server)
|
||||||
|
|||||||
@@ -63,9 +63,7 @@ class MCPConnection:
|
|||||||
True 表示连接成功,False 表示失败。
|
True 表示连接成功,False 表示失败。
|
||||||
"""
|
"""
|
||||||
if not MCP_AVAILABLE:
|
if not MCP_AVAILABLE:
|
||||||
console.print(
|
console.print("[warning]⚠️ 未安装 mcp SDK,请运行: pip install mcp[/warning]")
|
||||||
"[warning]⚠️ 未安装 mcp SDK,请运行: pip install mcp[/warning]"
|
|
||||||
)
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -76,15 +74,11 @@ class MCPConnection:
|
|||||||
elif self.config.transport_type == "sse":
|
elif self.config.transport_type == "sse":
|
||||||
read_stream, write_stream = await self._connect_sse()
|
read_stream, write_stream = await self._connect_sse()
|
||||||
else:
|
else:
|
||||||
console.print(
|
console.print(f"[warning]MCP '{self.config.name}': 未知传输类型[/warning]")
|
||||||
f"[warning]MCP '{self.config.name}': 未知传输类型[/warning]"
|
|
||||||
)
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# 创建并初始化 MCP 会话
|
# 创建并初始化 MCP 会话
|
||||||
self.session = await self._exit_stack.enter_async_context(
|
self.session = await self._exit_stack.enter_async_context(ClientSession(read_stream, write_stream))
|
||||||
ClientSession(read_stream, write_stream)
|
|
||||||
)
|
|
||||||
await self.session.initialize()
|
await self.session.initialize()
|
||||||
|
|
||||||
# 发现工具
|
# 发现工具
|
||||||
@@ -94,9 +88,7 @@ class MCPConnection:
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
console.print(
|
console.print(f"[warning]⚠️ MCP 服务器 '{self.config.name}' 连接失败: {e}[/warning]")
|
||||||
f"[warning]⚠️ MCP 服务器 '{self.config.name}' 连接失败: {e}[/warning]"
|
|
||||||
)
|
|
||||||
await self.close()
|
await self.close()
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@@ -107,19 +99,13 @@ class MCPConnection:
|
|||||||
args=self.config.args,
|
args=self.config.args,
|
||||||
env=self.config.env,
|
env=self.config.env,
|
||||||
)
|
)
|
||||||
return await self._exit_stack.enter_async_context(
|
return await self._exit_stack.enter_async_context(stdio_client(params))
|
||||||
stdio_client(params)
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _connect_sse(self):
|
async def _connect_sse(self):
|
||||||
"""建立 SSE 传输连接。"""
|
"""建立 SSE 传输连接。"""
|
||||||
if not SSE_AVAILABLE:
|
if not SSE_AVAILABLE:
|
||||||
raise ImportError(
|
raise ImportError("SSE 传输需要额外依赖,请运行: pip install mcp[sse]")
|
||||||
"SSE 传输需要额外依赖,请运行: pip install mcp[sse]"
|
return await self._exit_stack.enter_async_context(sse_client(url=self.config.url, headers=self.config.headers))
|
||||||
)
|
|
||||||
return await self._exit_stack.enter_async_context(
|
|
||||||
sse_client(url=self.config.url, headers=self.config.headers)
|
|
||||||
)
|
|
||||||
|
|
||||||
async def call_tool(self, tool_name: str, arguments: dict) -> str:
|
async def call_tool(self, tool_name: str, arguments: dict) -> str:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -10,10 +10,16 @@ from .config import MCPServerConfig, load_mcp_config
|
|||||||
from .connection import MCPConnection, MCP_AVAILABLE
|
from .connection import MCPConnection, MCP_AVAILABLE
|
||||||
|
|
||||||
# 内置工具名称集合 —— MCP 工具不允许与这些名称冲突
|
# 内置工具名称集合 —— MCP 工具不允许与这些名称冲突
|
||||||
BUILTIN_TOOL_NAMES = frozenset({
|
BUILTIN_TOOL_NAMES = frozenset(
|
||||||
"say", "wait", "stop",
|
{
|
||||||
"create_table", "list_tables", "view_table",
|
"say",
|
||||||
})
|
"wait",
|
||||||
|
"stop",
|
||||||
|
"create_table",
|
||||||
|
"list_tables",
|
||||||
|
"view_table",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class MCPManager:
|
class MCPManager:
|
||||||
@@ -35,7 +41,8 @@ class MCPManager:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def from_config(
|
async def from_config(
|
||||||
cls, config_path: str = "mcp_config.json",
|
cls,
|
||||||
|
config_path: str = "mcp_config.json",
|
||||||
) -> Optional["MCPManager"]:
|
) -> Optional["MCPManager"]:
|
||||||
"""
|
"""
|
||||||
从配置文件创建并初始化 MCPManager。
|
从配置文件创建并初始化 MCPManager。
|
||||||
@@ -51,10 +58,7 @@ class MCPManager:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
if not MCP_AVAILABLE:
|
if not MCP_AVAILABLE:
|
||||||
console.print(
|
console.print("[warning]⚠️ 发现 MCP 配置但未安装 mcp SDK,请运行: pip install mcp[/warning]")
|
||||||
"[warning]⚠️ 发现 MCP 配置但未安装 mcp SDK,"
|
|
||||||
"请运行: pip install mcp[/warning]"
|
|
||||||
)
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
manager = cls()
|
manager = cls()
|
||||||
@@ -85,8 +89,7 @@ class MCPManager:
|
|||||||
|
|
||||||
if tool_name in BUILTIN_TOOL_NAMES:
|
if tool_name in BUILTIN_TOOL_NAMES:
|
||||||
console.print(
|
console.print(
|
||||||
f"[warning]⚠️ MCP 工具 '{tool_name}' "
|
f"[warning]⚠️ MCP 工具 '{tool_name}' (来自 {cfg.name}) 与内置工具冲突,已跳过[/warning]"
|
||||||
f"(来自 {cfg.name}) 与内置工具冲突,已跳过[/warning]"
|
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -102,8 +105,7 @@ class MCPManager:
|
|||||||
registered += 1
|
registered += 1
|
||||||
|
|
||||||
console.print(
|
console.print(
|
||||||
f"[success]✓ MCP 服务器 '{cfg.name}' 已连接[/success] "
|
f"[success]✓ MCP 服务器 '{cfg.name}' 已连接[/success] [muted]({registered} 个工具已注册)[/muted]"
|
||||||
f"[muted]({registered} 个工具已注册)[/muted]"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# ──────── 工具发现 ────────
|
# ──────── 工具发现 ────────
|
||||||
@@ -134,17 +136,16 @@ class MCPManager:
|
|||||||
# 移除 $schema 字段(部分 MCP 服务器会带上,OpenAI 不接受)
|
# 移除 $schema 字段(部分 MCP 服务器会带上,OpenAI 不接受)
|
||||||
parameters.pop("$schema", None)
|
parameters.pop("$schema", None)
|
||||||
|
|
||||||
tools.append({
|
tools.append(
|
||||||
"type": "function",
|
{
|
||||||
"function": {
|
"type": "function",
|
||||||
"name": tool.name,
|
"function": {
|
||||||
"description": (
|
"name": tool.name,
|
||||||
tool.description
|
"description": (tool.description or f"MCP tool from {server_name}"),
|
||||||
or f"MCP tool from {server_name}"
|
"parameters": parameters,
|
||||||
),
|
},
|
||||||
"parameters": parameters,
|
}
|
||||||
},
|
)
|
||||||
})
|
|
||||||
|
|
||||||
return tools
|
return tools
|
||||||
|
|
||||||
@@ -184,9 +185,9 @@ class MCPManager:
|
|||||||
parts: list[str] = []
|
parts: list[str] = []
|
||||||
for server_name, conn in self._connections.items():
|
for server_name, conn in self._connections.items():
|
||||||
tool_names = [
|
tool_names = [
|
||||||
t.name for t in conn.tools
|
t.name
|
||||||
if t.name in self._tool_to_server
|
for t in conn.tools
|
||||||
and self._tool_to_server[t.name] == server_name
|
if t.name in self._tool_to_server and self._tool_to_server[t.name] == server_name
|
||||||
]
|
]
|
||||||
if tool_names:
|
if tool_names:
|
||||||
parts.append(f" • {server_name}: {', '.join(tool_names)}")
|
parts.append(f" • {server_name}: {', '.join(tool_names)}")
|
||||||
|
|||||||
@@ -49,8 +49,7 @@ def build_timing_info(
|
|||||||
|
|
||||||
if len(user_input_times) >= 2:
|
if len(user_input_times) >= 2:
|
||||||
intervals = [
|
intervals = [
|
||||||
(user_input_times[i] - user_input_times[i - 1]).total_seconds()
|
(user_input_times[i] - user_input_times[i - 1]).total_seconds() for i in range(1, len(user_input_times))
|
||||||
for i in range(1, len(user_input_times))
|
|
||||||
]
|
]
|
||||||
avg_interval = sum(intervals) / len(intervals)
|
avg_interval = sum(intervals) / len(intervals)
|
||||||
parts.append(f"用户平均回复间隔: {int(avg_interval)}秒")
|
parts.append(f"用户平均回复间隔: {int(avg_interval)}秒")
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ MaiSaka - 工具调用处理器
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import json as _json
|
import json as _json
|
||||||
import asyncio
|
|
||||||
import os
|
import os
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import TYPE_CHECKING, Optional
|
from typing import TYPE_CHECKING, Optional
|
||||||
@@ -92,27 +91,33 @@ async def handle_say(tc, chat_history: list, ctx: ToolHandlerContext):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
# 生成的回复作为 tool 结果写入上下文
|
# 生成的回复作为 tool 结果写入上下文
|
||||||
chat_history.append({
|
chat_history.append(
|
||||||
"role": "tool",
|
{
|
||||||
"tool_call_id": tc.id,
|
"role": "tool",
|
||||||
"content": f"已向用户展示(实际输出):{reply}",
|
"tool_call_id": tc.id,
|
||||||
})
|
"content": f"已向用户展示(实际输出):{reply}",
|
||||||
|
}
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
chat_history.append({
|
chat_history.append(
|
||||||
"role": "tool",
|
{
|
||||||
"tool_call_id": tc.id,
|
"role": "tool",
|
||||||
"content": "reason 内容为空,未展示",
|
"tool_call_id": tc.id,
|
||||||
})
|
"content": "reason 内容为空,未展示",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def handle_stop(tc, chat_history: list):
|
async def handle_stop(tc, chat_history: list):
|
||||||
"""处理 stop 工具:结束对话循环。"""
|
"""处理 stop 工具:结束对话循环。"""
|
||||||
console.print("[accent]🔧 调用工具: stop()[/accent]")
|
console.print("[accent]🔧 调用工具: stop()[/accent]")
|
||||||
chat_history.append({
|
chat_history.append(
|
||||||
"role": "tool",
|
{
|
||||||
"tool_call_id": tc.id,
|
"role": "tool",
|
||||||
"content": "对话循环已停止,等待用户下次输入。",
|
"tool_call_id": tc.id,
|
||||||
})
|
"content": "对话循环已停止,等待用户下次输入。",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def handle_wait(tc, chat_history: list, ctx: ToolHandlerContext) -> str:
|
async def handle_wait(tc, chat_history: list, ctx: ToolHandlerContext) -> str:
|
||||||
@@ -128,11 +133,13 @@ async def handle_wait(tc, chat_history: list, ctx: ToolHandlerContext) -> str:
|
|||||||
|
|
||||||
tool_result = await _do_wait(seconds, ctx)
|
tool_result = await _do_wait(seconds, ctx)
|
||||||
|
|
||||||
chat_history.append({
|
chat_history.append(
|
||||||
"role": "tool",
|
{
|
||||||
"tool_call_id": tc.id,
|
"role": "tool",
|
||||||
"content": tool_result,
|
"tool_call_id": tc.id,
|
||||||
})
|
"content": tool_result,
|
||||||
|
}
|
||||||
|
)
|
||||||
return tool_result
|
return tool_result
|
||||||
|
|
||||||
|
|
||||||
@@ -193,28 +200,32 @@ async def handle_mcp_tool(tc, chat_history: list, mcp_manager: "MCPManager"):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
chat_history.append({
|
chat_history.append(
|
||||||
"role": "tool",
|
{
|
||||||
"tool_call_id": tc.id,
|
"role": "tool",
|
||||||
"content": result,
|
"tool_call_id": tc.id,
|
||||||
})
|
"content": result,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def handle_unknown_tool(tc, chat_history: list):
|
async def handle_unknown_tool(tc, chat_history: list):
|
||||||
"""处理未知工具调用。"""
|
"""处理未知工具调用。"""
|
||||||
console.print(f"[accent]🔧 调用工具: {tc.name}({tc.arguments})[/accent]")
|
console.print(f"[accent]🔧 调用工具: {tc.name}({tc.arguments})[/accent]")
|
||||||
chat_history.append({
|
chat_history.append(
|
||||||
"role": "tool",
|
{
|
||||||
"tool_call_id": tc.id,
|
"role": "tool",
|
||||||
"content": f"未知工具: {tc.name}",
|
"tool_call_id": tc.id,
|
||||||
})
|
"content": f"未知工具: {tc.name}",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def handle_write_file(tc, chat_history: list):
|
async def handle_write_file(tc, chat_history: list):
|
||||||
"""处理 write_file 工具:在 mai_files 目录下写入文件。"""
|
"""处理 write_file 工具:在 mai_files 目录下写入文件。"""
|
||||||
filename = tc.arguments.get("filename", "")
|
filename = tc.arguments.get("filename", "")
|
||||||
content = tc.arguments.get("content", "")
|
content = tc.arguments.get("content", "")
|
||||||
console.print(f"[accent]🔧 调用工具: write_file(\"{filename}\")[/accent]")
|
console.print(f'[accent]🔧 调用工具: write_file("{filename}")[/accent]')
|
||||||
|
|
||||||
# 确保目录存在
|
# 确保目录存在
|
||||||
MAI_FILES_DIR.mkdir(parents=True, exist_ok=True)
|
MAI_FILES_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
@@ -242,25 +253,29 @@ async def handle_write_file(tc, chat_history: list):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
chat_history.append({
|
chat_history.append(
|
||||||
"role": "tool",
|
{
|
||||||
"tool_call_id": tc.id,
|
"role": "tool",
|
||||||
"content": f"文件「{filename}」已成功写入,共 {file_size} 个字符。",
|
"tool_call_id": tc.id,
|
||||||
})
|
"content": f"文件「{filename}」已成功写入,共 {file_size} 个字符。",
|
||||||
|
}
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = f"写入文件失败: {e}"
|
error_msg = f"写入文件失败: {e}"
|
||||||
console.print(f"[error]{error_msg}[/error]")
|
console.print(f"[error]{error_msg}[/error]")
|
||||||
chat_history.append({
|
chat_history.append(
|
||||||
"role": "tool",
|
{
|
||||||
"tool_call_id": tc.id,
|
"role": "tool",
|
||||||
"content": error_msg,
|
"tool_call_id": tc.id,
|
||||||
})
|
"content": error_msg,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def handle_read_file(tc, chat_history: list):
|
async def handle_read_file(tc, chat_history: list):
|
||||||
"""处理 read_file 工具:读取 mai_files 目录下的文件。"""
|
"""处理 read_file 工具:读取 mai_files 目录下的文件。"""
|
||||||
filename = tc.arguments.get("filename", "")
|
filename = tc.arguments.get("filename", "")
|
||||||
console.print(f"[accent]🔧 调用工具: read_file(\"{filename}\")[/accent]")
|
console.print(f'[accent]🔧 调用工具: read_file("{filename}")[/accent]')
|
||||||
|
|
||||||
# 构建完整文件路径
|
# 构建完整文件路径
|
||||||
file_path = MAI_FILES_DIR / filename
|
file_path = MAI_FILES_DIR / filename
|
||||||
@@ -269,21 +284,25 @@ async def handle_read_file(tc, chat_history: list):
|
|||||||
if not file_path.exists():
|
if not file_path.exists():
|
||||||
error_msg = f"文件「{filename}」不存在。"
|
error_msg = f"文件「{filename}」不存在。"
|
||||||
console.print(f"[warning]{error_msg}[/warning]")
|
console.print(f"[warning]{error_msg}[/warning]")
|
||||||
chat_history.append({
|
chat_history.append(
|
||||||
"role": "tool",
|
{
|
||||||
"tool_call_id": tc.id,
|
"role": "tool",
|
||||||
"content": error_msg,
|
"tool_call_id": tc.id,
|
||||||
})
|
"content": error_msg,
|
||||||
|
}
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
if not file_path.is_file():
|
if not file_path.is_file():
|
||||||
error_msg = f"「{filename}」不是一个文件。"
|
error_msg = f"「{filename}」不是一个文件。"
|
||||||
console.print(f"[warning]{error_msg}[/warning]")
|
console.print(f"[warning]{error_msg}[/warning]")
|
||||||
chat_history.append({
|
chat_history.append(
|
||||||
"role": "tool",
|
{
|
||||||
"tool_call_id": tc.id,
|
"role": "tool",
|
||||||
"content": error_msg,
|
"tool_call_id": tc.id,
|
||||||
})
|
"content": error_msg,
|
||||||
|
}
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
# 读取文件内容
|
# 读取文件内容
|
||||||
@@ -304,19 +323,23 @@ async def handle_read_file(tc, chat_history: list):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
chat_history.append({
|
chat_history.append(
|
||||||
"role": "tool",
|
{
|
||||||
"tool_call_id": tc.id,
|
"role": "tool",
|
||||||
"content": f"文件「{filename}」内容:\n{file_content}",
|
"tool_call_id": tc.id,
|
||||||
})
|
"content": f"文件「{filename}」内容:\n{file_content}",
|
||||||
|
}
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = f"读取文件失败: {e}"
|
error_msg = f"读取文件失败: {e}"
|
||||||
console.print(f"[error]{error_msg}[/error]")
|
console.print(f"[error]{error_msg}[/error]")
|
||||||
chat_history.append({
|
chat_history.append(
|
||||||
"role": "tool",
|
{
|
||||||
"tool_call_id": tc.id,
|
"role": "tool",
|
||||||
"content": error_msg,
|
"tool_call_id": tc.id,
|
||||||
})
|
"content": error_msg,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def handle_list_files(tc, chat_history: list):
|
async def handle_list_files(tc, chat_history: list):
|
||||||
@@ -334,11 +357,13 @@ async def handle_list_files(tc, chat_history: list):
|
|||||||
# 获取相对路径
|
# 获取相对路径
|
||||||
rel_path = item.relative_to(MAI_FILES_DIR)
|
rel_path = item.relative_to(MAI_FILES_DIR)
|
||||||
stat = item.stat()
|
stat = item.stat()
|
||||||
files_info.append({
|
files_info.append(
|
||||||
"name": str(rel_path),
|
{
|
||||||
"size": stat.st_size,
|
"name": str(rel_path),
|
||||||
"modified": datetime.fromtimestamp(stat.st_mtime).strftime("%Y-%m-%d %H:%M:%S"),
|
"size": stat.st_size,
|
||||||
})
|
"modified": datetime.fromtimestamp(stat.st_mtime).strftime("%Y-%m-%d %H:%M:%S"),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
if not files_info:
|
if not files_info:
|
||||||
result_text = "mai_files 目录为空,没有任何文件。"
|
result_text = "mai_files 目录为空,没有任何文件。"
|
||||||
@@ -360,19 +385,23 @@ async def handle_list_files(tc, chat_history: list):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
chat_history.append({
|
chat_history.append(
|
||||||
"role": "tool",
|
{
|
||||||
"tool_call_id": tc.id,
|
"role": "tool",
|
||||||
"content": result_text,
|
"tool_call_id": tc.id,
|
||||||
})
|
"content": result_text,
|
||||||
|
}
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = f"获取文件列表失败: {e}"
|
error_msg = f"获取文件列表失败: {e}"
|
||||||
console.print(f"[error]{error_msg}[/error]")
|
console.print(f"[error]{error_msg}[/error]")
|
||||||
chat_history.append({
|
chat_history.append(
|
||||||
"role": "tool",
|
{
|
||||||
"tool_call_id": tc.id,
|
"role": "tool",
|
||||||
"content": error_msg,
|
"tool_call_id": tc.id,
|
||||||
})
|
"content": error_msg,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def handle_store_context(tc, chat_history: list, ctx: ToolHandlerContext):
|
async def handle_store_context(tc, chat_history: list, ctx: ToolHandlerContext):
|
||||||
@@ -385,16 +414,18 @@ async def handle_store_context(tc, chat_history: list, ctx: ToolHandlerContext):
|
|||||||
"""
|
"""
|
||||||
count = tc.arguments.get("count", 0)
|
count = tc.arguments.get("count", 0)
|
||||||
reason = tc.arguments.get("reason", "")
|
reason = tc.arguments.get("reason", "")
|
||||||
console.print(f"[accent]🔧 调用工具: store_context(count={count}, reason=\"{reason}\")[/accent]")
|
console.print(f'[accent]🔧 调用工具: store_context(count={count}, reason="{reason}")[/accent]')
|
||||||
|
|
||||||
if count <= 0:
|
if count <= 0:
|
||||||
error_msg = "count 参数必须大于 0"
|
error_msg = "count 参数必须大于 0"
|
||||||
console.print(f"[error]{error_msg}[/error]")
|
console.print(f"[error]{error_msg}[/error]")
|
||||||
chat_history.append({
|
chat_history.append(
|
||||||
"role": "tool",
|
{
|
||||||
"tool_call_id": tc.id,
|
"role": "tool",
|
||||||
"content": error_msg,
|
"tool_call_id": tc.id,
|
||||||
})
|
"content": error_msg,
|
||||||
|
}
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
# 计算实际消息数量(排除 role=tool 的工具返回消息)
|
# 计算实际消息数量(排除 role=tool 的工具返回消息)
|
||||||
@@ -423,9 +454,7 @@ async def handle_store_context(tc, chat_history: list, ctx: ToolHandlerContext):
|
|||||||
if role == "assistant" and "tool_calls" in msg:
|
if role == "assistant" and "tool_calls" in msg:
|
||||||
# 检查这个消息是否包含当前的 tool_call(store_context 自己)
|
# 检查这个消息是否包含当前的 tool_call(store_context 自己)
|
||||||
# 如果包含,跳过不删除(否则会导致 tool 响应孤儿)
|
# 如果包含,跳过不删除(否则会导致 tool 响应孤儿)
|
||||||
contains_current_call = any(
|
contains_current_call = any(tc.get("id") == tc.id for tc in msg.get("tool_calls", []))
|
||||||
tc.get("id") == tc.id for tc in msg.get("tool_calls", [])
|
|
||||||
)
|
|
||||||
if contains_current_call:
|
if contains_current_call:
|
||||||
i += 1
|
i += 1
|
||||||
continue
|
continue
|
||||||
@@ -453,11 +482,13 @@ async def handle_store_context(tc, chat_history: list, ctx: ToolHandlerContext):
|
|||||||
|
|
||||||
if not indices_to_remove:
|
if not indices_to_remove:
|
||||||
result_msg = "没有找到可存入记忆的消息"
|
result_msg = "没有找到可存入记忆的消息"
|
||||||
chat_history.append({
|
chat_history.append(
|
||||||
"role": "tool",
|
{
|
||||||
"tool_call_id": tc.id,
|
"role": "tool",
|
||||||
"content": result_msg,
|
"tool_call_id": tc.id,
|
||||||
})
|
"content": result_msg,
|
||||||
|
}
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
# 收集要总结的消息(在删除前)
|
# 收集要总结的消息(在删除前)
|
||||||
@@ -516,38 +547,45 @@ async def handle_store_context(tc, chat_history: list, ctx: ToolHandlerContext):
|
|||||||
chat_history.pop(i)
|
chat_history.pop(i)
|
||||||
i -= 1
|
i -= 1
|
||||||
|
|
||||||
chat_history.append({
|
chat_history.append(
|
||||||
"role": "tool",
|
{
|
||||||
"tool_call_id": tc.id,
|
"role": "tool",
|
||||||
"content": result_msg,
|
"tool_call_id": tc.id,
|
||||||
})
|
"content": result_msg,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def handle_get_qq_chat_info(tc, chat_history: list):
|
async def handle_get_qq_chat_info(tc, chat_history: list):
|
||||||
"""处理 get_qq_chat_info 工具:通过 HTTP 获取 QQ 聊天内容。"""
|
"""处理 get_qq_chat_info 工具:通过 HTTP 获取 QQ 聊天内容。"""
|
||||||
chat = tc.arguments.get("chat", "")
|
chat = tc.arguments.get("chat", "")
|
||||||
limit = tc.arguments.get("limit", 20)
|
limit = tc.arguments.get("limit", 20)
|
||||||
console.print(f"[accent]🔧 调用工具: get_qq_chat_info(\"{chat}\", limit={limit})[/accent]")
|
console.print(f'[accent]🔧 调用工具: get_qq_chat_info("{chat}", limit={limit})[/accent]')
|
||||||
|
|
||||||
if not AIOHTTP_AVAILABLE:
|
if not AIOHTTP_AVAILABLE:
|
||||||
error_msg = "aiohttp 模块未安装,请运行: pip install aiohttp"
|
error_msg = "aiohttp 模块未安装,请运行: pip install aiohttp"
|
||||||
console.print(f"[error]{error_msg}[/error]")
|
console.print(f"[error]{error_msg}[/error]")
|
||||||
chat_history.append({
|
chat_history.append(
|
||||||
"role": "tool",
|
{
|
||||||
"tool_call_id": tc.id,
|
"role": "tool",
|
||||||
"content": error_msg,
|
"tool_call_id": tc.id,
|
||||||
})
|
"content": error_msg,
|
||||||
|
}
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
from config import QQ_API_BASE_URL, QQ_API_KEY
|
from config import QQ_API_BASE_URL, QQ_API_KEY
|
||||||
|
|
||||||
if not QQ_API_BASE_URL:
|
if not QQ_API_BASE_URL:
|
||||||
error_msg = "QQ_API_BASE_URL 未配置,请在 .env 中设置"
|
error_msg = "QQ_API_BASE_URL 未配置,请在 .env 中设置"
|
||||||
console.print(f"[error]{error_msg}[/error]")
|
console.print(f"[error]{error_msg}[/error]")
|
||||||
chat_history.append({
|
chat_history.append(
|
||||||
"role": "tool",
|
{
|
||||||
"tool_call_id": tc.id,
|
"role": "tool",
|
||||||
"content": error_msg,
|
"tool_call_id": tc.id,
|
||||||
})
|
"content": error_msg,
|
||||||
|
}
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -577,55 +615,66 @@ async def handle_get_qq_chat_info(tc, chat_history: list):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
chat_history.append({
|
chat_history.append(
|
||||||
"role": "tool",
|
{
|
||||||
"tool_call_id": tc.id,
|
"role": "tool",
|
||||||
"content": text if text.strip() else "暂无聊天记录",
|
"tool_call_id": tc.id,
|
||||||
})
|
"content": text if text.strip() else "暂无聊天记录",
|
||||||
|
}
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
error_text = await response.text()
|
error_text = await response.text()
|
||||||
error_msg = f"HTTP 请求失败 (状态码 {response.status}): {error_text}"
|
error_msg = f"HTTP 请求失败 (状态码 {response.status}): {error_text}"
|
||||||
console.print(f"[error]{error_msg}[/error]")
|
console.print(f"[error]{error_msg}[/error]")
|
||||||
chat_history.append({
|
chat_history.append(
|
||||||
"role": "tool",
|
{
|
||||||
"tool_call_id": tc.id,
|
"role": "tool",
|
||||||
"content": error_msg,
|
"tool_call_id": tc.id,
|
||||||
})
|
"content": error_msg,
|
||||||
|
}
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = f"获取 QQ 聊天记录失败: {e}"
|
error_msg = f"获取 QQ 聊天记录失败: {e}"
|
||||||
console.print(f"[error]{error_msg}[/error]")
|
console.print(f"[error]{error_msg}[/error]")
|
||||||
chat_history.append({
|
chat_history.append(
|
||||||
"role": "tool",
|
{
|
||||||
"tool_call_id": tc.id,
|
"role": "tool",
|
||||||
"content": error_msg,
|
"tool_call_id": tc.id,
|
||||||
})
|
"content": error_msg,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def handle_send_info(tc, chat_history: list):
|
async def handle_send_info(tc, chat_history: list):
|
||||||
"""处理 send_info 工具:通过 HTTP 发送消息到 QQ。"""
|
"""处理 send_info 工具:通过 HTTP 发送消息到 QQ。"""
|
||||||
chat = tc.arguments.get("chat", "")
|
chat = tc.arguments.get("chat", "")
|
||||||
message = tc.arguments.get("message", "")
|
message = tc.arguments.get("message", "")
|
||||||
console.print(f"[accent]🔧 调用工具: send_info(\"{chat}\")[/accent]")
|
console.print(f'[accent]🔧 调用工具: send_info("{chat}")[/accent]')
|
||||||
|
|
||||||
if not AIOHTTP_AVAILABLE:
|
if not AIOHTTP_AVAILABLE:
|
||||||
error_msg = "aiohttp 模块未安装,请运行: pip install aiohttp"
|
error_msg = "aiohttp 模块未安装,请运行: pip install aiohttp"
|
||||||
console.print(f"[error]{error_msg}[/error]")
|
console.print(f"[error]{error_msg}[/error]")
|
||||||
chat_history.append({
|
chat_history.append(
|
||||||
"role": "tool",
|
{
|
||||||
"tool_call_id": tc.id,
|
"role": "tool",
|
||||||
"content": error_msg,
|
"tool_call_id": tc.id,
|
||||||
})
|
"content": error_msg,
|
||||||
|
}
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
from config import QQ_API_BASE_URL, QQ_API_KEY
|
from config import QQ_API_BASE_URL, QQ_API_KEY
|
||||||
|
|
||||||
if not QQ_API_BASE_URL:
|
if not QQ_API_BASE_URL:
|
||||||
error_msg = "QQ_API_BASE_URL 未配置,请在 .env 中设置"
|
error_msg = "QQ_API_BASE_URL 未配置,请在 .env 中设置"
|
||||||
console.print(f"[error]{error_msg}[/error]")
|
console.print(f"[error]{error_msg}[/error]")
|
||||||
chat_history.append({
|
chat_history.append(
|
||||||
"role": "tool",
|
{
|
||||||
"tool_call_id": tc.id,
|
"role": "tool",
|
||||||
"content": error_msg,
|
"tool_call_id": tc.id,
|
||||||
})
|
"content": error_msg,
|
||||||
|
}
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -654,27 +703,33 @@ async def handle_send_info(tc, chat_history: list):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
chat_history.append({
|
chat_history.append(
|
||||||
"role": "tool",
|
{
|
||||||
"tool_call_id": tc.id,
|
"role": "tool",
|
||||||
"content": f"消息发送成功: {data.get('message', '发送成功')}",
|
"tool_call_id": tc.id,
|
||||||
})
|
"content": f"消息发送成功: {data.get('message', '发送成功')}",
|
||||||
|
}
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
error_msg = f"发送失败: {data.get('message', '未知错误')}"
|
error_msg = f"发送失败: {data.get('message', '未知错误')}"
|
||||||
console.print(f"[error]{error_msg}[/error]")
|
console.print(f"[error]{error_msg}[/error]")
|
||||||
chat_history.append({
|
chat_history.append(
|
||||||
"role": "tool",
|
{
|
||||||
"tool_call_id": tc.id,
|
"role": "tool",
|
||||||
"content": error_msg,
|
"tool_call_id": tc.id,
|
||||||
})
|
"content": error_msg,
|
||||||
|
}
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = f"发送消息失败: {e}"
|
error_msg = f"发送消息失败: {e}"
|
||||||
console.print(f"[error]{error_msg}[/error]")
|
console.print(f"[error]{error_msg}[/error]")
|
||||||
chat_history.append({
|
chat_history.append(
|
||||||
"role": "tool",
|
{
|
||||||
"tool_call_id": tc.id,
|
"role": "tool",
|
||||||
"content": error_msg,
|
"tool_call_id": tc.id,
|
||||||
})
|
"content": error_msg,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def handle_list_qq_chats(tc, chat_history: list):
|
async def handle_list_qq_chats(tc, chat_history: list):
|
||||||
@@ -684,22 +739,27 @@ async def handle_list_qq_chats(tc, chat_history: list):
|
|||||||
if not AIOHTTP_AVAILABLE:
|
if not AIOHTTP_AVAILABLE:
|
||||||
error_msg = "aiohttp 模块未安装,请运行: pip install aiohttp"
|
error_msg = "aiohttp 模块未安装,请运行: pip install aiohttp"
|
||||||
console.print(f"[error]{error_msg}[/error]")
|
console.print(f"[error]{error_msg}[/error]")
|
||||||
chat_history.append({
|
chat_history.append(
|
||||||
"role": "tool",
|
{
|
||||||
"tool_call_id": tc.id,
|
"role": "tool",
|
||||||
"content": error_msg,
|
"tool_call_id": tc.id,
|
||||||
})
|
"content": error_msg,
|
||||||
|
}
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
from config import QQ_API_BASE_URL, QQ_API_KEY
|
from config import QQ_API_BASE_URL, QQ_API_KEY
|
||||||
|
|
||||||
if not QQ_API_BASE_URL:
|
if not QQ_API_BASE_URL:
|
||||||
error_msg = "QQ_API_BASE_URL 未配置,请在 .env 中设置"
|
error_msg = "QQ_API_BASE_URL 未配置,请在 .env 中设置"
|
||||||
console.print(f"[error]{error_msg}[/error]")
|
console.print(f"[error]{error_msg}[/error]")
|
||||||
chat_history.append({
|
chat_history.append(
|
||||||
"role": "tool",
|
{
|
||||||
"tool_call_id": tc.id,
|
"role": "tool",
|
||||||
"content": error_msg,
|
"tool_call_id": tc.id,
|
||||||
})
|
"content": error_msg,
|
||||||
|
}
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -721,10 +781,12 @@ async def handle_list_qq_chats(tc, chat_history: list):
|
|||||||
|
|
||||||
# 格式化聊天列表
|
# 格式化聊天列表
|
||||||
if chats:
|
if chats:
|
||||||
chat_list_text = "\n".join([
|
chat_list_text = "\n".join(
|
||||||
f" • [{c.get('platform', 'qq')}] {c.get('name', '未知')} (chat: {c.get('chat', 'N/A')})"
|
[
|
||||||
for c in chats
|
f" • [{c.get('platform', 'qq')}] {c.get('name', '未知')} (chat: {c.get('chat', 'N/A')})"
|
||||||
])
|
for c in chats
|
||||||
|
]
|
||||||
|
)
|
||||||
result_text = f"可用的聊天 (共 {len(chats)} 个):\n{chat_list_text}"
|
result_text = f"可用的聊天 (共 {len(chats)} 个):\n{chat_list_text}"
|
||||||
else:
|
else:
|
||||||
result_text = "没有可用的聊天"
|
result_text = "没有可用的聊天"
|
||||||
@@ -738,27 +800,33 @@ async def handle_list_qq_chats(tc, chat_history: list):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
chat_history.append({
|
chat_history.append(
|
||||||
"role": "tool",
|
{
|
||||||
"tool_call_id": tc.id,
|
"role": "tool",
|
||||||
"content": result_text,
|
"tool_call_id": tc.id,
|
||||||
})
|
"content": result_text,
|
||||||
|
}
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
error_msg = f"获取失败: {data.get('message', '未知错误')}"
|
error_msg = f"获取失败: {data.get('message', '未知错误')}"
|
||||||
console.print(f"[error]{error_msg}[/error]")
|
console.print(f"[error]{error_msg}[/error]")
|
||||||
chat_history.append({
|
chat_history.append(
|
||||||
"role": "tool",
|
{
|
||||||
"tool_call_id": tc.id,
|
"role": "tool",
|
||||||
"content": error_msg,
|
"tool_call_id": tc.id,
|
||||||
})
|
"content": error_msg,
|
||||||
|
}
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = f"获取聊天列表失败: {e}"
|
error_msg = f"获取聊天列表失败: {e}"
|
||||||
console.print(f"[error]{error_msg}[/error]")
|
console.print(f"[error]{error_msg}[/error]")
|
||||||
chat_history.append({
|
chat_history.append(
|
||||||
"role": "tool",
|
{
|
||||||
"tool_call_id": tc.id,
|
"role": "tool",
|
||||||
"content": error_msg,
|
"tool_call_id": tc.id,
|
||||||
})
|
"content": error_msg,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# ──────────────────── 初始化 mai_files 目录 ────────────────────
|
# ──────────────────── 初始化 mai_files 目录 ────────────────────
|
||||||
|
|||||||
@@ -847,11 +847,7 @@ def _get_recent_found_answers(chat_id: str, time_window_seconds: float = 600.0)
|
|||||||
if not records:
|
if not records:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
return [
|
return [f"问题:{record.question}\n答案:{record.answer}" for record in records if record.answer]
|
||||||
f"问题:{record.question}\n答案:{record.answer}"
|
|
||||||
for record in records
|
|
||||||
if record.answer
|
|
||||||
]
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"获取最近已找到答案的记录失败: {e}")
|
logger.error(f"获取最近已找到答案的记录失败: {e}")
|
||||||
|
|||||||
@@ -13,4 +13,3 @@ ENV_SESSION_TOKEN = "MAIBOT_SESSION_TOKEN"
|
|||||||
|
|
||||||
ENV_PLUGIN_DIRS = "MAIBOT_PLUGIN_DIRS"
|
ENV_PLUGIN_DIRS = "MAIBOT_PLUGIN_DIRS"
|
||||||
"""Runner 需要加载的插件目录列表(os.pathsep 分隔)"""
|
"""Runner 需要加载的插件目录列表(os.pathsep 分隔)"""
|
||||||
|
|
||||||
|
|||||||
@@ -67,7 +67,9 @@ class CapabilityService:
|
|||||||
# 1. 权限校验
|
# 1. 权限校验
|
||||||
allowed, reason = self._policy.check_capability(plugin_id, capability, envelope.generation)
|
allowed, reason = self._policy.check_capability(plugin_id, capability, envelope.generation)
|
||||||
if not allowed:
|
if not allowed:
|
||||||
error_code = ErrorCode.E_GENERATION_MISMATCH if "generation 不匹配" in reason else ErrorCode.E_CAPABILITY_DENIED
|
error_code = (
|
||||||
|
ErrorCode.E_GENERATION_MISMATCH if "generation 不匹配" in reason else ErrorCode.E_CAPABILITY_DENIED
|
||||||
|
)
|
||||||
return envelope.make_error_response(
|
return envelope.make_error_response(
|
||||||
error_code.value,
|
error_code.value,
|
||||||
reason,
|
reason,
|
||||||
|
|||||||
@@ -22,8 +22,13 @@ class RegisteredComponent:
|
|||||||
"""已注册的组件条目"""
|
"""已注册的组件条目"""
|
||||||
|
|
||||||
__slots__ = (
|
__slots__ = (
|
||||||
"name", "full_name", "component_type", "plugin_id",
|
"name",
|
||||||
"metadata", "enabled", "_compiled_pattern",
|
"full_name",
|
||||||
|
"component_type",
|
||||||
|
"plugin_id",
|
||||||
|
"metadata",
|
||||||
|
"enabled",
|
||||||
|
"_compiled_pattern",
|
||||||
)
|
)
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -165,18 +170,14 @@ class ComponentRegistry:
|
|||||||
"""按全名查询。"""
|
"""按全名查询。"""
|
||||||
return self._components.get(full_name)
|
return self._components.get(full_name)
|
||||||
|
|
||||||
def get_components_by_type(
|
def get_components_by_type(self, component_type: str, *, enabled_only: bool = True) -> List[RegisteredComponent]:
|
||||||
self, component_type: str, *, enabled_only: bool = True
|
|
||||||
) -> List[RegisteredComponent]:
|
|
||||||
"""按类型查询。"""
|
"""按类型查询。"""
|
||||||
type_dict = self._by_type.get(component_type, {})
|
type_dict = self._by_type.get(component_type, {})
|
||||||
if enabled_only:
|
if enabled_only:
|
||||||
return [c for c in type_dict.values() if c.enabled]
|
return [c for c in type_dict.values() if c.enabled]
|
||||||
return list(type_dict.values())
|
return list(type_dict.values())
|
||||||
|
|
||||||
def get_components_by_plugin(
|
def get_components_by_plugin(self, plugin_id: str, *, enabled_only: bool = True) -> List[RegisteredComponent]:
|
||||||
self, plugin_id: str, *, enabled_only: bool = True
|
|
||||||
) -> List[RegisteredComponent]:
|
|
||||||
"""按插件查询。"""
|
"""按插件查询。"""
|
||||||
comps = self._by_plugin.get(plugin_id, [])
|
comps = self._by_plugin.get(plugin_id, [])
|
||||||
return [c for c in comps if c.enabled] if enabled_only else list(comps)
|
return [c for c in comps if c.enabled] if enabled_only else list(comps)
|
||||||
@@ -200,9 +201,7 @@ class ComponentRegistry:
|
|||||||
return comp, {}
|
return comp, {}
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_event_handlers(
|
def get_event_handlers(self, event_type: str, *, enabled_only: bool = True) -> List[RegisteredComponent]:
|
||||||
self, event_type: str, *, enabled_only: bool = True
|
|
||||||
) -> List[RegisteredComponent]:
|
|
||||||
"""获取特定事件类型的所有 event_handler,按 weight 降序排列。"""
|
"""获取特定事件类型的所有 event_handler,按 weight 降序排列。"""
|
||||||
handlers = []
|
handlers = []
|
||||||
for comp in self._by_type.get("event_handler", {}).values():
|
for comp in self._by_type.get("event_handler", {}).values():
|
||||||
@@ -213,9 +212,7 @@ class ComponentRegistry:
|
|||||||
handlers.sort(key=lambda c: c.metadata.get("weight", 0), reverse=True)
|
handlers.sort(key=lambda c: c.metadata.get("weight", 0), reverse=True)
|
||||||
return handlers
|
return handlers
|
||||||
|
|
||||||
def get_workflow_steps(
|
def get_workflow_steps(self, stage: str, *, enabled_only: bool = True) -> List[RegisteredComponent]:
|
||||||
self, stage: str, *, enabled_only: bool = True
|
|
||||||
) -> List[RegisteredComponent]:
|
|
||||||
"""获取特定 workflow 阶段的所有步骤,按 priority 降序。"""
|
"""获取特定 workflow 阶段的所有步骤,按 priority 降序。"""
|
||||||
steps = []
|
steps = []
|
||||||
for comp in self._by_type.get("workflow_step", {}).values():
|
for comp in self._by_type.get("workflow_step", {}).values():
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ InvokeFn = Callable[[str, str, Dict[str, Any]], Awaitable[Dict[str, Any]]]
|
|||||||
|
|
||||||
class EventResult:
|
class EventResult:
|
||||||
"""单个 EventHandler 的执行结果"""
|
"""单个 EventHandler 的执行结果"""
|
||||||
|
|
||||||
__slots__ = ("handler_name", "success", "continue_processing", "modified_message", "custom_result")
|
__slots__ = ("handler_name", "success", "continue_processing", "modified_message", "custom_result")
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -107,9 +108,7 @@ class EventDispatcher:
|
|||||||
modified_message = result.modified_message
|
modified_message = result.modified_message
|
||||||
else:
|
else:
|
||||||
# 非阻塞:保持实例级强引用,防止 task 被 GC 回收
|
# 非阻塞:保持实例级强引用,防止 task 被 GC 回收
|
||||||
task = asyncio.create_task(
|
task = asyncio.create_task(self._invoke_handler(invoke_fn, handler, args, event_type))
|
||||||
self._invoke_handler(invoke_fn, handler, args, event_type)
|
|
||||||
)
|
|
||||||
self._background_tasks.add(task)
|
self._background_tasks.add(task)
|
||||||
task.add_done_callback(self._background_tasks.discard)
|
task.add_done_callback(self._background_tasks.discard)
|
||||||
|
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ from typing import Dict, List, Optional, Set, Tuple
|
|||||||
@dataclass
|
@dataclass
|
||||||
class CapabilityToken:
|
class CapabilityToken:
|
||||||
"""能力令牌"""
|
"""能力令牌"""
|
||||||
|
|
||||||
plugin_id: str
|
plugin_id: str
|
||||||
generation: int
|
generation: int
|
||||||
capabilities: Set[str] = field(default_factory=set)
|
capabilities: Set[str] = field(default_factory=set)
|
||||||
|
|||||||
@@ -231,9 +231,7 @@ class RPCServer:
|
|||||||
stale_count = 0
|
stale_count = 0
|
||||||
for _req_id, future in list(self._pending_requests.items()):
|
for _req_id, future in list(self._pending_requests.items()):
|
||||||
if not future.done():
|
if not future.done():
|
||||||
future.set_exception(
|
future.set_exception(RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 连接已被新 generation 接管"))
|
||||||
RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 连接已被新 generation 接管")
|
|
||||||
)
|
|
||||||
stale_count += 1
|
stale_count += 1
|
||||||
self._pending_requests.clear()
|
self._pending_requests.clear()
|
||||||
if stale_count:
|
if stale_count:
|
||||||
@@ -399,9 +397,7 @@ class RPCServer:
|
|||||||
result = await handler(envelope)
|
result = await handler(envelope)
|
||||||
# 检查 handler 返回的信封是否包含错误信息
|
# 检查 handler 返回的信封是否包含错误信息
|
||||||
if result is not None and isinstance(result, Envelope) and result.error:
|
if result is not None and isinstance(result, Envelope) and result.error:
|
||||||
logger.warning(
|
logger.warning(f"事件 {envelope.method} handler 返回错误: {result.error.get('message', '')}")
|
||||||
f"事件 {envelope.method} handler 返回错误: {result.error.get('message', '')}"
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"处理事件 {envelope.method} 异常: {e}", exc_info=True)
|
logger.error(f"处理事件 {envelope.method} 异常: {e}", exc_info=True)
|
||||||
|
|
||||||
|
|||||||
@@ -39,6 +39,7 @@ logger = get_logger("plugin_runtime.host.supervisor")
|
|||||||
|
|
||||||
# ─── 日志桥 ──────────────────────────────────────────────────────
|
# ─── 日志桥 ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
class RunnerLogBridge:
|
class RunnerLogBridge:
|
||||||
"""将 Runner 进程上报的批量日志重放到主进程的 Logger 中。
|
"""将 Runner 进程上报的批量日志重放到主进程的 Logger 中。
|
||||||
|
|
||||||
@@ -80,9 +81,7 @@ class RunnerLogBridge:
|
|||||||
|
|
||||||
stdlib_logging.getLogger(entry.logger_name).handle(record)
|
stdlib_logging.getLogger(entry.logger_name).handle(record)
|
||||||
|
|
||||||
return envelope.make_response(
|
return envelope.make_response(payload={"accepted": True, "count": len(batch.entries)})
|
||||||
payload={"accepted": True, "count": len(batch.entries)}
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class PluginSupervisor:
|
class PluginSupervisor:
|
||||||
@@ -101,8 +100,12 @@ class PluginSupervisor:
|
|||||||
):
|
):
|
||||||
_cfg = global_config.plugin_runtime
|
_cfg = global_config.plugin_runtime
|
||||||
self._plugin_dirs = plugin_dirs or []
|
self._plugin_dirs = plugin_dirs or []
|
||||||
self._health_interval = health_check_interval_sec if health_check_interval_sec is not None else _cfg.health_check_interval_sec
|
self._health_interval = (
|
||||||
self._runner_spawn_timeout = runner_spawn_timeout_sec if runner_spawn_timeout_sec is not None else _cfg.runner_spawn_timeout_sec
|
health_check_interval_sec if health_check_interval_sec is not None else _cfg.health_check_interval_sec
|
||||||
|
)
|
||||||
|
self._runner_spawn_timeout = (
|
||||||
|
runner_spawn_timeout_sec if runner_spawn_timeout_sec is not None else _cfg.runner_spawn_timeout_sec
|
||||||
|
)
|
||||||
|
|
||||||
# 基础设施
|
# 基础设施
|
||||||
self._transport = create_transport_server(socket_path=socket_path)
|
self._transport = create_transport_server(socket_path=socket_path)
|
||||||
@@ -114,6 +117,7 @@ class PluginSupervisor:
|
|||||||
|
|
||||||
# 编解码
|
# 编解码
|
||||||
from src.plugin_runtime.protocol.codec import MsgPackCodec
|
from src.plugin_runtime.protocol.codec import MsgPackCodec
|
||||||
|
|
||||||
codec = MsgPackCodec()
|
codec = MsgPackCodec()
|
||||||
|
|
||||||
self._rpc_server = RPCServer(
|
self._rpc_server = RPCServer(
|
||||||
@@ -124,7 +128,9 @@ class PluginSupervisor:
|
|||||||
# Runner 子进程
|
# Runner 子进程
|
||||||
self._runner_process: Optional[asyncio.subprocess.Process] = None
|
self._runner_process: Optional[asyncio.subprocess.Process] = None
|
||||||
self._runner_generation: int = 0
|
self._runner_generation: int = 0
|
||||||
self._max_restart_attempts: int = max_restart_attempts if max_restart_attempts is not None else _cfg.max_restart_attempts
|
self._max_restart_attempts: int = (
|
||||||
|
max_restart_attempts if max_restart_attempts is not None else _cfg.max_restart_attempts
|
||||||
|
)
|
||||||
self._restart_count: int = 0
|
self._restart_count: int = 0
|
||||||
|
|
||||||
# 已注册的插件组件信息
|
# 已注册的插件组件信息
|
||||||
@@ -173,6 +179,7 @@ class PluginSupervisor:
|
|||||||
extra_args: Optional[Dict[str, Any]] = None,
|
extra_args: Optional[Dict[str, Any]] = None,
|
||||||
) -> Tuple[bool, Optional[Dict[str, Any]]]:
|
) -> Tuple[bool, Optional[Dict[str, Any]]]:
|
||||||
"""分发事件到所有对应 handler 的快捷方法。"""
|
"""分发事件到所有对应 handler 的快捷方法。"""
|
||||||
|
|
||||||
async def _invoke(plugin_id: str, component_name: str, args: Dict[str, Any]) -> Dict[str, Any]:
|
async def _invoke(plugin_id: str, component_name: str, args: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
resp = await self.invoke_plugin(
|
resp = await self.invoke_plugin(
|
||||||
method="plugin.emit_event",
|
method="plugin.emit_event",
|
||||||
@@ -196,6 +203,7 @@ class PluginSupervisor:
|
|||||||
context: Optional[WorkflowContext] = None,
|
context: Optional[WorkflowContext] = None,
|
||||||
) -> Tuple[WorkflowResult, Optional[Dict[str, Any]], WorkflowContext]:
|
) -> Tuple[WorkflowResult, Optional[Dict[str, Any]], WorkflowContext]:
|
||||||
"""执行 Workflow Pipeline 的快捷方法。"""
|
"""执行 Workflow Pipeline 的快捷方法。"""
|
||||||
|
|
||||||
async def _invoke(plugin_id: str, component_name: str, args: Dict[str, Any]) -> Dict[str, Any]:
|
async def _invoke(plugin_id: str, component_name: str, args: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
resp = await self.invoke_plugin(
|
resp = await self.invoke_plugin(
|
||||||
method="plugin.invoke_workflow_step",
|
method="plugin.invoke_workflow_step",
|
||||||
@@ -415,7 +423,9 @@ class PluginSupervisor:
|
|||||||
env[ENV_PLUGIN_DIRS] = os.pathsep.join(self._plugin_dirs)
|
env[ENV_PLUGIN_DIRS] = os.pathsep.join(self._plugin_dirs)
|
||||||
|
|
||||||
self._runner_process = await asyncio.create_subprocess_exec(
|
self._runner_process = await asyncio.create_subprocess_exec(
|
||||||
sys.executable, "-m", runner_module,
|
sys.executable,
|
||||||
|
"-m",
|
||||||
|
runner_module,
|
||||||
env=env,
|
env=env,
|
||||||
# stdout 不捕获:Runner 的日志均通过 IPC 传㛹(RunnerIPCLogHandler)
|
# stdout 不捕获:Runner 的日志均通过 IPC 传㛹(RunnerIPCLogHandler)
|
||||||
stdout=None,
|
stdout=None,
|
||||||
@@ -557,9 +567,7 @@ class PluginSupervisor:
|
|||||||
)
|
)
|
||||||
self._stderr_drain_task = task
|
self._stderr_drain_task = task
|
||||||
task.add_done_callback(
|
task.add_done_callback(
|
||||||
lambda done_task: None
|
lambda done_task: None if self._stderr_drain_task is not done_task else self._clear_stderr_drain_task()
|
||||||
if self._stderr_drain_task is not done_task
|
|
||||||
else self._clear_stderr_drain_task()
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def _clear_stderr_drain_task(self) -> None:
|
def _clear_stderr_drain_task(self) -> None:
|
||||||
|
|||||||
@@ -44,6 +44,7 @@ HOOK_CONTINUE = "continue"
|
|||||||
HOOK_SKIP_STAGE = "skip_stage"
|
HOOK_SKIP_STAGE = "skip_stage"
|
||||||
HOOK_ABORT = "abort"
|
HOOK_ABORT = "abort"
|
||||||
|
|
||||||
|
|
||||||
# blocking hook 全局最大超时(秒):即使 hook 声明 timeout_ms=0 也不会无限等待
|
# blocking hook 全局最大超时(秒):即使 hook 声明 timeout_ms=0 也不会无限等待
|
||||||
# 从配置文件读取,允许用户调整
|
# 从配置文件读取,允许用户调整
|
||||||
def _get_blocking_timeout() -> float:
|
def _get_blocking_timeout() -> float:
|
||||||
@@ -52,6 +53,7 @@ def _get_blocking_timeout() -> float:
|
|||||||
|
|
||||||
class ModificationRecord:
|
class ModificationRecord:
|
||||||
"""消息修改记录"""
|
"""消息修改记录"""
|
||||||
|
|
||||||
__slots__ = ("stage", "hook_name", "timestamp", "fields_changed")
|
__slots__ = ("stage", "hook_name", "timestamp", "fields_changed")
|
||||||
|
|
||||||
def __init__(self, stage: str, hook_name: str, fields_changed: List[str]) -> None:
|
def __init__(self, stage: str, hook_name: str, fields_changed: List[str]) -> None:
|
||||||
@@ -141,9 +143,7 @@ class WorkflowExecutor:
|
|||||||
try:
|
try:
|
||||||
# PLAN 阶段: 先做 Command 路由
|
# PLAN 阶段: 先做 Command 路由
|
||||||
if stage == "plan" and current_message:
|
if stage == "plan" and current_message:
|
||||||
cmd_result = await self._route_command(
|
cmd_result = await self._route_command(command_invoke_fn or invoke_fn, current_message, ctx)
|
||||||
command_invoke_fn or invoke_fn, current_message, ctx
|
|
||||||
)
|
|
||||||
if cmd_result is not None:
|
if cmd_result is not None:
|
||||||
# 命令匹配成功,跳过 PLAN 阶段的 hook,直接存结果进 stage_outputs
|
# 命令匹配成功,跳过 PLAN 阶段的 hook,直接存结果进 stage_outputs
|
||||||
ctx.set_stage_output("plan", "command_result", cmd_result)
|
ctx.set_stage_output("plan", "command_result", cmd_result)
|
||||||
@@ -195,10 +195,10 @@ class WorkflowExecutor:
|
|||||||
|
|
||||||
# 更新消息(仅 blocking hook 有权修改)
|
# 更新消息(仅 blocking hook 有权修改)
|
||||||
if modified:
|
if modified:
|
||||||
changed_fields = _diff_keys(current_message, modified) if current_message else list(modified.keys())
|
changed_fields = (
|
||||||
ctx.modification_log.append(
|
_diff_keys(current_message, modified) if current_message else list(modified.keys())
|
||||||
ModificationRecord(stage, step.full_name, changed_fields)
|
|
||||||
)
|
)
|
||||||
|
ctx.modification_log.append(ModificationRecord(stage, step.full_name, changed_fields))
|
||||||
current_message = modified
|
current_message = modified
|
||||||
|
|
||||||
if hook_result == HOOK_ABORT:
|
if hook_result == HOOK_ABORT:
|
||||||
@@ -222,9 +222,7 @@ class WorkflowExecutor:
|
|||||||
if nonblocking_steps and not skip_stage:
|
if nonblocking_steps and not skip_stage:
|
||||||
nb_tasks = [
|
nb_tasks = [
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
self._invoke_step_fire_and_forget(
|
self._invoke_step_fire_and_forget(invoke_fn, step, stage, ctx, current_message)
|
||||||
invoke_fn, step, stage, ctx, current_message
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
for step in nonblocking_steps
|
for step in nonblocking_steps
|
||||||
]
|
]
|
||||||
@@ -314,12 +312,16 @@ class WorkflowExecutor:
|
|||||||
step_start = time.perf_counter()
|
step_start = time.perf_counter()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
coro = invoke_fn(step.plugin_id, step.name, {
|
coro = invoke_fn(
|
||||||
"stage": stage,
|
step.plugin_id,
|
||||||
"trace_id": ctx.trace_id,
|
step.name,
|
||||||
"message": message,
|
{
|
||||||
"stage_outputs": ctx.stage_outputs,
|
"stage": stage,
|
||||||
})
|
"trace_id": ctx.trace_id,
|
||||||
|
"message": message,
|
||||||
|
"stage_outputs": ctx.stage_outputs,
|
||||||
|
},
|
||||||
|
)
|
||||||
resp = await asyncio.wait_for(coro, timeout=timeout_sec)
|
resp = await asyncio.wait_for(coro, timeout=timeout_sec)
|
||||||
ctx.timings[step_key] = time.perf_counter() - step_start
|
ctx.timings[step_key] = time.perf_counter() - step_start
|
||||||
|
|
||||||
@@ -355,12 +357,16 @@ class WorkflowExecutor:
|
|||||||
timeout_sec = timeout_ms / 1000 if timeout_ms > 0 else _get_blocking_timeout()
|
timeout_sec = timeout_ms / 1000 if timeout_ms > 0 else _get_blocking_timeout()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
coro = invoke_fn(step.plugin_id, step.name, {
|
coro = invoke_fn(
|
||||||
"stage": stage,
|
step.plugin_id,
|
||||||
"trace_id": ctx.trace_id,
|
step.name,
|
||||||
"message": message,
|
{
|
||||||
"stage_outputs": ctx.stage_outputs,
|
"stage": stage,
|
||||||
})
|
"trace_id": ctx.trace_id,
|
||||||
|
"message": message,
|
||||||
|
"stage_outputs": ctx.stage_outputs,
|
||||||
|
},
|
||||||
|
)
|
||||||
await asyncio.wait_for(coro, timeout=timeout_sec)
|
await asyncio.wait_for(coro, timeout=timeout_sec)
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
logger.warning(f"[{ctx.trace_id}] non-blocking hook {step.full_name} 超时 ({timeout_sec}s)")
|
logger.warning(f"[{ctx.trace_id}] non-blocking hook {step.full_name} 超时 ({timeout_sec}s)")
|
||||||
@@ -393,12 +399,16 @@ class WorkflowExecutor:
|
|||||||
logger.info(f"[{ctx.trace_id}] 命令匹配: {matched.full_name}")
|
logger.info(f"[{ctx.trace_id}] 命令匹配: {matched.full_name}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return await invoke_fn(matched.plugin_id, matched.name, {
|
return await invoke_fn(
|
||||||
"text": plain_text,
|
matched.plugin_id,
|
||||||
"message": message,
|
matched.name,
|
||||||
"trace_id": ctx.trace_id,
|
{
|
||||||
"matched_groups": matched_groups,
|
"text": plain_text,
|
||||||
})
|
"message": message,
|
||||||
|
"trace_id": ctx.trace_id,
|
||||||
|
"matched_groups": matched_groups,
|
||||||
|
},
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[{ctx.trace_id}] 命令 {matched.full_name} 执行失败: {e}", exc_info=True)
|
logger.error(f"[{ctx.trace_id}] 命令 {matched.full_name} 执行失败: {e}", exc_info=True)
|
||||||
ctx.errors.append(f"command:{matched.full_name}: {e}")
|
ctx.errors.append(f"command:{matched.full_name}: {e}")
|
||||||
|
|||||||
@@ -113,9 +113,7 @@ class PluginRuntimeManager:
|
|||||||
await self._thirdparty_supervisor.start()
|
await self._thirdparty_supervisor.start()
|
||||||
started_supervisors.append(self._thirdparty_supervisor)
|
started_supervisors.append(self._thirdparty_supervisor)
|
||||||
self._started = True
|
self._started = True
|
||||||
logger.info(
|
logger.info(f"插件运行时已启动 — 内置: {builtin_dirs or '无'}, 第三方: {thirdparty_dirs or '无'}")
|
||||||
f"插件运行时已启动 — 内置: {builtin_dirs or '无'}, 第三方: {thirdparty_dirs or '无'}"
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"插件运行时启动失败: {e}", exc_info=True)
|
logger.error(f"插件运行时启动失败: {e}", exc_info=True)
|
||||||
await asyncio.gather(*(sv.stop() for sv in started_supervisors), return_exceptions=True)
|
await asyncio.gather(*(sv.stop() for sv in started_supervisors), return_exceptions=True)
|
||||||
@@ -303,7 +301,9 @@ class PluginRuntimeManager:
|
|||||||
cap_service.register_capability("component.get_all_plugins", self._cap_component_get_all_plugins)
|
cap_service.register_capability("component.get_all_plugins", self._cap_component_get_all_plugins)
|
||||||
cap_service.register_capability("component.get_plugin_info", self._cap_component_get_plugin_info)
|
cap_service.register_capability("component.get_plugin_info", self._cap_component_get_plugin_info)
|
||||||
cap_service.register_capability("component.list_loaded_plugins", self._cap_component_list_loaded_plugins)
|
cap_service.register_capability("component.list_loaded_plugins", self._cap_component_list_loaded_plugins)
|
||||||
cap_service.register_capability("component.list_registered_plugins", self._cap_component_list_registered_plugins)
|
cap_service.register_capability(
|
||||||
|
"component.list_registered_plugins", self._cap_component_list_registered_plugins
|
||||||
|
)
|
||||||
cap_service.register_capability("component.enable", self._cap_component_enable)
|
cap_service.register_capability("component.enable", self._cap_component_enable)
|
||||||
cap_service.register_capability("component.disable", self._cap_component_disable)
|
cap_service.register_capability("component.disable", self._cap_component_disable)
|
||||||
cap_service.register_capability("component.load_plugin", self._cap_component_load_plugin)
|
cap_service.register_capability("component.load_plugin", self._cap_component_load_plugin)
|
||||||
@@ -1232,9 +1232,7 @@ class PluginRuntimeManager:
|
|||||||
count: int = args.get("count", 1)
|
count: int = args.get("count", 1)
|
||||||
try:
|
try:
|
||||||
results = await emoji_api.get_random(count=count)
|
results = await emoji_api.get_random(count=count)
|
||||||
emojis = [
|
emojis = [{"base64": b64, "description": desc, "emotion": emo} for b64, desc, emo in results]
|
||||||
{"base64": b64, "description": desc, "emotion": emo} for b64, desc, emo in results
|
|
||||||
]
|
|
||||||
return {"success": True, "emojis": emojis}
|
return {"success": True, "emojis": emojis}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[cap.emoji.get_random] 执行失败: {e}", exc_info=True)
|
logger.error(f"[cap.emoji.get_random] 执行失败: {e}", exc_info=True)
|
||||||
@@ -1269,9 +1267,9 @@ class PluginRuntimeManager:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
results = await emoji_api.get_all()
|
results = await emoji_api.get_all()
|
||||||
emojis = [
|
emojis = (
|
||||||
{"base64": b64, "description": desc, "emotion": emo} for b64, desc, emo in results
|
[{"base64": b64, "description": desc, "emotion": emo} for b64, desc, emo in results] if results else []
|
||||||
] if results else []
|
)
|
||||||
return {"success": True, "emojis": emojis}
|
return {"success": True, "emojis": emojis}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[cap.emoji.get_all] 执行失败: {e}", exc_info=True)
|
logger.error(f"[cap.emoji.get_all] 执行失败: {e}", exc_info=True)
|
||||||
|
|||||||
@@ -12,20 +12,16 @@ class Codec(ABC):
|
|||||||
"""消息编解码器基类"""
|
"""消息编解码器基类"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def encode_envelope(self, envelope: Envelope) -> bytes:
|
def encode_envelope(self, envelope: Envelope) -> bytes: ...
|
||||||
...
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def decode_envelope(self, data: bytes) -> Envelope:
|
def decode_envelope(self, data: bytes) -> Envelope: ...
|
||||||
...
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def encode(self, obj: Dict[str, Any]) -> bytes:
|
def encode(self, obj: Dict[str, Any]) -> bytes: ...
|
||||||
...
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def decode(self, data: bytes) -> Dict[str, Any]:
|
def decode(self, data: bytes) -> Dict[str, Any]: ...
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
class MsgPackCodec(Codec):
|
class MsgPackCodec(Codec):
|
||||||
|
|||||||
@@ -24,8 +24,10 @@ MAX_SDK_VERSION = "1.99.99"
|
|||||||
|
|
||||||
# ─── 消息类型 ──────────────────────────────────────────────────────
|
# ─── 消息类型 ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
class MessageType(str, Enum):
|
class MessageType(str, Enum):
|
||||||
"""RPC 消息类型"""
|
"""RPC 消息类型"""
|
||||||
|
|
||||||
REQUEST = "request"
|
REQUEST = "request"
|
||||||
RESPONSE = "response"
|
RESPONSE = "response"
|
||||||
EVENT = "event"
|
EVENT = "event"
|
||||||
@@ -33,6 +35,7 @@ class MessageType(str, Enum):
|
|||||||
|
|
||||||
# ─── 请求 ID 生成器 ───────────────────────────────────────────────
|
# ─── 请求 ID 生成器 ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
class RequestIdGenerator:
|
class RequestIdGenerator:
|
||||||
"""单调递增 int64 请求 ID 生成器(线程安全由调用方保证或使用 asyncio)"""
|
"""单调递增 int64 请求 ID 生成器(线程安全由调用方保证或使用 asyncio)"""
|
||||||
|
|
||||||
@@ -47,6 +50,7 @@ class RequestIdGenerator:
|
|||||||
|
|
||||||
# ─── Envelope 模型 ─────────────────────────────────────────────────
|
# ─── Envelope 模型 ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
class Envelope(BaseModel):
|
class Envelope(BaseModel):
|
||||||
"""RPC 统一信封
|
"""RPC 统一信封
|
||||||
|
|
||||||
@@ -75,7 +79,9 @@ class Envelope(BaseModel):
|
|||||||
def is_event(self) -> bool:
|
def is_event(self) -> bool:
|
||||||
return self.message_type == MessageType.EVENT
|
return self.message_type == MessageType.EVENT
|
||||||
|
|
||||||
def make_response(self, payload: Optional[Dict[str, Any]] = None, error: Optional[Dict[str, Any]] = None) -> "Envelope":
|
def make_response(
|
||||||
|
self, payload: Optional[Dict[str, Any]] = None, error: Optional[Dict[str, Any]] = None
|
||||||
|
) -> "Envelope":
|
||||||
"""基于当前请求创建对应的响应信封"""
|
"""基于当前请求创建对应的响应信封"""
|
||||||
return Envelope(
|
return Envelope(
|
||||||
protocol_version=self.protocol_version,
|
protocol_version=self.protocol_version,
|
||||||
@@ -101,8 +107,10 @@ class Envelope(BaseModel):
|
|||||||
|
|
||||||
# ─── 握手消息 ──────────────────────────────────────────────────────
|
# ─── 握手消息 ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
class HelloPayload(BaseModel):
|
class HelloPayload(BaseModel):
|
||||||
"""runner.hello 握手请求 payload"""
|
"""runner.hello 握手请求 payload"""
|
||||||
|
|
||||||
runner_id: str = Field(description="Runner 进程唯一标识")
|
runner_id: str = Field(description="Runner 进程唯一标识")
|
||||||
sdk_version: str = Field(description="SDK 版本号")
|
sdk_version: str = Field(description="SDK 版本号")
|
||||||
session_token: str = Field(description="一次性会话令牌")
|
session_token: str = Field(description="一次性会话令牌")
|
||||||
@@ -110,6 +118,7 @@ class HelloPayload(BaseModel):
|
|||||||
|
|
||||||
class HelloResponsePayload(BaseModel):
|
class HelloResponsePayload(BaseModel):
|
||||||
"""runner.hello 握手响应 payload"""
|
"""runner.hello 握手响应 payload"""
|
||||||
|
|
||||||
accepted: bool = Field(description="是否接受连接")
|
accepted: bool = Field(description="是否接受连接")
|
||||||
host_version: str = Field(default="", description="Host 版本号")
|
host_version: str = Field(default="", description="Host 版本号")
|
||||||
assigned_generation: int = Field(default=0, description="分配的 generation 编号")
|
assigned_generation: int = Field(default=0, description="分配的 generation 编号")
|
||||||
@@ -118,8 +127,10 @@ class HelloResponsePayload(BaseModel):
|
|||||||
|
|
||||||
# ─── 组件注册消息 ──────────────────────────────────────────────────
|
# ─── 组件注册消息 ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
class ComponentDeclaration(BaseModel):
|
class ComponentDeclaration(BaseModel):
|
||||||
"""单个组件声明"""
|
"""单个组件声明"""
|
||||||
|
|
||||||
name: str = Field(description="组件名称")
|
name: str = Field(description="组件名称")
|
||||||
component_type: str = Field(description="组件类型: action/command/tool/event_handler")
|
component_type: str = Field(description="组件类型: action/command/tool/event_handler")
|
||||||
plugin_id: str = Field(description="所属插件 ID")
|
plugin_id: str = Field(description="所属插件 ID")
|
||||||
@@ -128,6 +139,7 @@ class ComponentDeclaration(BaseModel):
|
|||||||
|
|
||||||
class RegisterComponentsPayload(BaseModel):
|
class RegisterComponentsPayload(BaseModel):
|
||||||
"""plugin.register_components 请求 payload"""
|
"""plugin.register_components 请求 payload"""
|
||||||
|
|
||||||
plugin_id: str = Field(description="插件 ID")
|
plugin_id: str = Field(description="插件 ID")
|
||||||
plugin_version: str = Field(default="1.0.0", description="插件版本")
|
plugin_version: str = Field(default="1.0.0", description="插件版本")
|
||||||
components: List[ComponentDeclaration] = Field(default_factory=list, description="组件列表")
|
components: List[ComponentDeclaration] = Field(default_factory=list, description="组件列表")
|
||||||
@@ -136,36 +148,44 @@ class RegisterComponentsPayload(BaseModel):
|
|||||||
|
|
||||||
# ─── 调用消息 ──────────────────────────────────────────────────────
|
# ─── 调用消息 ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
class InvokePayload(BaseModel):
|
class InvokePayload(BaseModel):
|
||||||
"""plugin.invoke_* 请求 payload"""
|
"""plugin.invoke_* 请求 payload"""
|
||||||
|
|
||||||
component_name: str = Field(description="要调用的组件名称")
|
component_name: str = Field(description="要调用的组件名称")
|
||||||
args: Dict[str, Any] = Field(default_factory=dict, description="调用参数")
|
args: Dict[str, Any] = Field(default_factory=dict, description="调用参数")
|
||||||
|
|
||||||
|
|
||||||
class InvokeResultPayload(BaseModel):
|
class InvokeResultPayload(BaseModel):
|
||||||
"""plugin.invoke_* 响应 payload"""
|
"""plugin.invoke_* 响应 payload"""
|
||||||
|
|
||||||
success: bool = Field(description="是否成功")
|
success: bool = Field(description="是否成功")
|
||||||
result: Any = Field(default=None, description="返回值")
|
result: Any = Field(default=None, description="返回值")
|
||||||
|
|
||||||
|
|
||||||
# ─── 能力调用消息 ──────────────────────────────────────────────────
|
# ─── 能力调用消息 ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
class CapabilityRequestPayload(BaseModel):
|
class CapabilityRequestPayload(BaseModel):
|
||||||
"""cap.* 请求 payload(插件 -> Host 能力调用)"""
|
"""cap.* 请求 payload(插件 -> Host 能力调用)"""
|
||||||
|
|
||||||
capability: str = Field(description="能力名称,如 send.text, db.query")
|
capability: str = Field(description="能力名称,如 send.text, db.query")
|
||||||
args: Dict[str, Any] = Field(default_factory=dict, description="调用参数")
|
args: Dict[str, Any] = Field(default_factory=dict, description="调用参数")
|
||||||
|
|
||||||
|
|
||||||
class CapabilityResponsePayload(BaseModel):
|
class CapabilityResponsePayload(BaseModel):
|
||||||
"""cap.* 响应 payload"""
|
"""cap.* 响应 payload"""
|
||||||
|
|
||||||
success: bool = Field(description="是否成功")
|
success: bool = Field(description="是否成功")
|
||||||
result: Any = Field(default=None, description="返回值")
|
result: Any = Field(default=None, description="返回值")
|
||||||
|
|
||||||
|
|
||||||
# ─── 健康检查 ──────────────────────────────────────────────────────
|
# ─── 健康检查 ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
class HealthPayload(BaseModel):
|
class HealthPayload(BaseModel):
|
||||||
"""plugin.health 响应 payload"""
|
"""plugin.health 响应 payload"""
|
||||||
|
|
||||||
healthy: bool = Field(description="是否健康")
|
healthy: bool = Field(description="是否健康")
|
||||||
loaded_plugins: List[str] = Field(default_factory=list, description="已加载的插件列表")
|
loaded_plugins: List[str] = Field(default_factory=list, description="已加载的插件列表")
|
||||||
uptime_ms: int = Field(default=0, description="运行时长(ms)")
|
uptime_ms: int = Field(default=0, description="运行时长(ms)")
|
||||||
@@ -173,11 +193,13 @@ class HealthPayload(BaseModel):
|
|||||||
|
|
||||||
# ─── 配置更新 ──────────────────────────────────────────────────────
|
# ─── 配置更新 ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
# TODO: Host 侧尚未实现配置变更检测与推送。Runner 端的 _handle_config_updated
|
# TODO: Host 侧尚未实现配置变更检测与推送。Runner 端的 _handle_config_updated
|
||||||
# 已就绪,但当前无任何调用方通过 RPC 发送 plugin.config_updated 消息。
|
# 已就绪,但当前无任何调用方通过 RPC 发送 plugin.config_updated 消息。
|
||||||
# 需要在 Supervisor 或 CapabilityService 中监听配置文件变化并主动推送。
|
# 需要在 Supervisor 或 CapabilityService 中监听配置文件变化并主动推送。
|
||||||
class ConfigUpdatedPayload(BaseModel):
|
class ConfigUpdatedPayload(BaseModel):
|
||||||
"""plugin.config_updated 事件 payload"""
|
"""plugin.config_updated 事件 payload"""
|
||||||
|
|
||||||
plugin_id: str = Field(description="插件 ID")
|
plugin_id: str = Field(description="插件 ID")
|
||||||
config_version: str = Field(description="新配置版本")
|
config_version: str = Field(description="新配置版本")
|
||||||
config_data: Dict[str, Any] = Field(default_factory=dict, description="配置内容")
|
config_data: Dict[str, Any] = Field(default_factory=dict, description="配置内容")
|
||||||
@@ -185,14 +207,17 @@ class ConfigUpdatedPayload(BaseModel):
|
|||||||
|
|
||||||
# ─── 关停 ──────────────────────────────────────────────────────────
|
# ─── 关停 ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
class ShutdownPayload(BaseModel):
|
class ShutdownPayload(BaseModel):
|
||||||
"""plugin.shutdown / plugin.prepare_shutdown payload"""
|
"""plugin.shutdown / plugin.prepare_shutdown payload"""
|
||||||
|
|
||||||
reason: str = Field(default="normal", description="关停原因")
|
reason: str = Field(default="normal", description="关停原因")
|
||||||
drain_timeout_ms: int = Field(default=5000, description="排空超时(ms)")
|
drain_timeout_ms: int = Field(default=5000, description="排空超时(ms)")
|
||||||
|
|
||||||
|
|
||||||
# ─── 日志传输 ──────────────────────────────────────────────────────
|
# ─── 日志传输 ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
class LogEntry(BaseModel):
|
class LogEntry(BaseModel):
|
||||||
"""单条日志记录(Runner → Host 传输格式)"""
|
"""单条日志记录(Runner → Host 传输格式)"""
|
||||||
|
|
||||||
@@ -200,10 +225,7 @@ class LogEntry(BaseModel):
|
|||||||
description="日志时间戳,Unix epoch 毫秒",
|
description="日志时间戳,Unix epoch 毫秒",
|
||||||
)
|
)
|
||||||
level: int = Field(
|
level: int = Field(
|
||||||
description=(
|
description=("stdlib logging 整数级别: 10=DEBUG, 20=INFO, 30=WARNING, 40=ERROR, 50=CRITICAL"),
|
||||||
"stdlib logging 整数级别:"
|
|
||||||
" 10=DEBUG, 20=INFO, 30=WARNING, 40=ERROR, 50=CRITICAL"
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
logger_name: str = Field(
|
logger_name: str = Field(
|
||||||
description="Logger 名称,如 plugin.my_plugin.submodule",
|
description="Logger 名称,如 plugin.my_plugin.submodule",
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ Host 端将其重放到主进程的 Logger(以 plugin.<name> 为名)中,
|
|||||||
- 后台刷新协程每 FLUSH_INTERVAL_SEC 秒或 FLUSH_BATCH_SIZE 条后批量发送
|
- 后台刷新协程每 FLUSH_INTERVAL_SEC 秒或 FLUSH_BATCH_SIZE 条后批量发送
|
||||||
- IPC 发送失败时静默忽略;stderr fallback 由 supervisor 的 drain task 覆盖
|
- IPC 发送失败时静默忽略;stderr fallback 由 supervisor 的 drain task 覆盖
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import TYPE_CHECKING, List, Optional
|
from typing import TYPE_CHECKING, List, Optional
|
||||||
@@ -203,6 +204,7 @@ class RunnerIPCLogHandler(logging.Handler):
|
|||||||
# IPC 连接断开时回退到 stderr,避免日志静默丢失
|
# IPC 连接断开时回退到 stderr,避免日志静默丢失
|
||||||
if not self._rpc_client.is_connected:
|
if not self._rpc_client.is_connected:
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
for entry in entries:
|
for entry in entries:
|
||||||
print(
|
print(
|
||||||
f"[LOG-FALLBACK] [{entry.logger_name}] {entry.message}",
|
f"[LOG-FALLBACK] [{entry.logger_name}] {entry.message}",
|
||||||
@@ -218,6 +220,7 @@ class RunnerIPCLogHandler(logging.Handler):
|
|||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
for entry in entries:
|
for entry in entries:
|
||||||
print(
|
print(
|
||||||
f"[LOG-FALLBACK] [{entry.logger_name}] {entry.message}",
|
f"[LOG-FALLBACK] [{entry.logger_name}] {entry.message}",
|
||||||
|
|||||||
@@ -105,9 +105,7 @@ class ManifestValidator:
|
|||||||
def _check_manifest_version(self, manifest: Dict[str, Any]) -> None:
|
def _check_manifest_version(self, manifest: Dict[str, Any]) -> None:
|
||||||
mv = manifest.get("manifest_version")
|
mv = manifest.get("manifest_version")
|
||||||
if mv is not None and mv not in self.SUPPORTED_MANIFEST_VERSIONS:
|
if mv is not None and mv not in self.SUPPORTED_MANIFEST_VERSIONS:
|
||||||
self.errors.append(
|
self.errors.append(f"不支持的 manifest_version: {mv},支持: {self.SUPPORTED_MANIFEST_VERSIONS}")
|
||||||
f"不支持的 manifest_version: {mv},支持: {self.SUPPORTED_MANIFEST_VERSIONS}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def _check_author(self, manifest: Dict[str, Any]) -> None:
|
def _check_author(self, manifest: Dict[str, Any]) -> None:
|
||||||
author = manifest.get("author")
|
author = manifest.get("author")
|
||||||
|
|||||||
@@ -240,8 +240,7 @@ class PluginLoader:
|
|||||||
instance = self._try_load_legacy_plugin(module, plugin_id)
|
instance = self._try_load_legacy_plugin(module, plugin_id)
|
||||||
if instance is not None:
|
if instance is not None:
|
||||||
logger.info(
|
logger.info(
|
||||||
f"插件 {plugin_id} v{manifest.get('version', '?')} "
|
f"插件 {plugin_id} v{manifest.get('version', '?')} 通过旧版兼容层加载成功(请尽快迁移到 maibot_sdk)"
|
||||||
f"通过旧版兼容层加载成功(请尽快迁移到 maibot_sdk)"
|
|
||||||
)
|
)
|
||||||
return PluginMeta(
|
return PluginMeta(
|
||||||
plugin_id=plugin_id,
|
plugin_id=plugin_id,
|
||||||
@@ -261,6 +260,7 @@ class PluginLoader:
|
|||||||
return
|
return
|
||||||
try:
|
try:
|
||||||
from maibot_sdk.compat._import_hook import install_hook
|
from maibot_sdk.compat._import_hook import install_hook
|
||||||
|
|
||||||
install_hook()
|
install_hook()
|
||||||
self._compat_hook_installed = True
|
self._compat_hook_installed = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@@ -281,11 +281,7 @@ class PluginLoader:
|
|||||||
|
|
||||||
for attr_name in dir(module):
|
for attr_name in dir(module):
|
||||||
obj = getattr(module, attr_name, None)
|
obj = getattr(module, attr_name, None)
|
||||||
if (
|
if isinstance(obj, type) and issubclass(obj, LegacyBasePlugin) and obj is not LegacyBasePlugin:
|
||||||
isinstance(obj, type)
|
|
||||||
and issubclass(obj, LegacyBasePlugin)
|
|
||||||
and obj is not LegacyBasePlugin
|
|
||||||
):
|
|
||||||
legacy_cls = obj
|
legacy_cls = obj
|
||||||
break
|
break
|
||||||
|
|
||||||
@@ -294,6 +290,7 @@ class PluginLoader:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
from maibot_sdk.compat.legacy_adapter import LegacyPluginAdapter
|
from maibot_sdk.compat.legacy_adapter import LegacyPluginAdapter
|
||||||
|
|
||||||
legacy_instance = legacy_cls()
|
legacy_instance = legacy_cls()
|
||||||
return LegacyPluginAdapter(legacy_instance)
|
return LegacyPluginAdapter(legacy_instance)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -37,6 +37,7 @@ def _get_sdk_version() -> str:
|
|||||||
"""从 maibot_sdk 包元数据中读取实际版本号,失败时回退到 1.0.0。"""
|
"""从 maibot_sdk 包元数据中读取实际版本号,失败时回退到 1.0.0。"""
|
||||||
try:
|
try:
|
||||||
from importlib.metadata import version
|
from importlib.metadata import version
|
||||||
|
|
||||||
return version("maibot-plugin-sdk")
|
return version("maibot-plugin-sdk")
|
||||||
except Exception:
|
except Exception:
|
||||||
return "1.0.0"
|
return "1.0.0"
|
||||||
|
|||||||
@@ -133,7 +133,9 @@ class PluginRunner:
|
|||||||
self._suspend_console_handlers()
|
self._suspend_console_handlers()
|
||||||
stdlib_logging.root.addHandler(handler)
|
stdlib_logging.root.addHandler(handler)
|
||||||
self._log_handler = handler
|
self._log_handler = handler
|
||||||
logger.debug("RunnerIPCLogHandler \u5df2\u5b89\u88c3\uff0c\u63d2\u4ef6\u65e5\u5fd7\u5c06\u901a\u8fc7 IPC \u8f6c\u53d1\u5230\u4e3b\u8fdb\u7a0b")
|
logger.debug(
|
||||||
|
"RunnerIPCLogHandler \u5df2\u5b89\u88c3\uff0c\u63d2\u4ef6\u65e5\u5fd7\u5c06\u901a\u8fc7 IPC \u8f6c\u53d1\u5230\u4e3b\u8fdb\u7a0b"
|
||||||
|
)
|
||||||
|
|
||||||
async def _uninstall_log_handler(self) -> None:
|
async def _uninstall_log_handler(self) -> None:
|
||||||
"""关停前从 logging.root 移除 Handler 并刷空缓冲。
|
"""关停前从 logging.root 移除 Handler 并刷空缓冲。
|
||||||
@@ -291,7 +293,11 @@ class PluginRunner:
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = await handler_method(**invoke.args) if inspect.iscoroutinefunction(handler_method) else handler_method(**invoke.args)
|
result = (
|
||||||
|
await handler_method(**invoke.args)
|
||||||
|
if inspect.iscoroutinefunction(handler_method)
|
||||||
|
else handler_method(**invoke.args)
|
||||||
|
)
|
||||||
resp_payload = InvokeResultPayload(success=True, result=result)
|
resp_payload = InvokeResultPayload(success=True, result=result)
|
||||||
return envelope.make_response(payload=resp_payload.model_dump())
|
return envelope.make_response(payload=resp_payload.model_dump())
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -332,7 +338,11 @@ class PluginRunner:
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
raw = await handler_method(**invoke.args) if inspect.iscoroutinefunction(handler_method) else handler_method(**invoke.args)
|
raw = (
|
||||||
|
await handler_method(**invoke.args)
|
||||||
|
if inspect.iscoroutinefunction(handler_method)
|
||||||
|
else handler_method(**invoke.args)
|
||||||
|
)
|
||||||
|
|
||||||
# 规范化返回值:将 EventHandler 返回展平到 payload 顶层
|
# 规范化返回值:将 EventHandler 返回展平到 payload 顶层
|
||||||
if raw is None:
|
if raw is None:
|
||||||
@@ -341,7 +351,9 @@ class PluginRunner:
|
|||||||
result = {
|
result = {
|
||||||
"success": True,
|
"success": True,
|
||||||
# 兼容 guide.md 中文档的 {"blocked": True} 写法
|
# 兼容 guide.md 中文档的 {"blocked": True} 写法
|
||||||
"continue_processing": not raw.get("blocked", False) if "blocked" in raw else raw.get("continue_processing", True),
|
"continue_processing": not raw.get("blocked", False)
|
||||||
|
if "blocked" in raw
|
||||||
|
else raw.get("continue_processing", True),
|
||||||
"modified_message": raw.get("modified_message"),
|
"modified_message": raw.get("modified_message"),
|
||||||
"custom_result": raw.get("custom_result"),
|
"custom_result": raw.get("custom_result"),
|
||||||
}
|
}
|
||||||
@@ -383,7 +395,11 @@ class PluginRunner:
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
raw = await handler_method(**invoke.args) if inspect.iscoroutinefunction(handler_method) else handler_method(**invoke.args)
|
raw = (
|
||||||
|
await handler_method(**invoke.args)
|
||||||
|
if inspect.iscoroutinefunction(handler_method)
|
||||||
|
else handler_method(**invoke.args)
|
||||||
|
)
|
||||||
|
|
||||||
# 规范化返回值
|
# 规范化返回值
|
||||||
if isinstance(raw, str):
|
if isinstance(raw, str):
|
||||||
@@ -455,6 +471,7 @@ class PluginRunner:
|
|||||||
|
|
||||||
# ─── sys.path 隔离 ────────────────────────────────────────
|
# ─── sys.path 隔离 ────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
def _isolate_sys_path(plugin_dirs: List[str]) -> None:
|
def _isolate_sys_path(plugin_dirs: List[str]) -> None:
|
||||||
"""清理 sys.path,限制 Runner 子进程只能访问标准库、SDK 和插件目录。
|
"""清理 sys.path,限制 Runner 子进程只能访问标准库、SDK 和插件目录。
|
||||||
|
|
||||||
@@ -504,9 +521,7 @@ def _isolate_sys_path(plugin_dirs: List[str]) -> None:
|
|||||||
return self if self._should_block(fullname) else None
|
return self if self._should_block(fullname) else None
|
||||||
|
|
||||||
def load_module(self, fullname):
|
def load_module(self, fullname):
|
||||||
raise ImportError(
|
raise ImportError(f"Runner 子进程不允许导入主程序模块: {fullname}")
|
||||||
f"Runner 子进程不允许导入主程序模块: {fullname}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def _should_block(self, fullname: str) -> bool:
|
def _should_block(self, fullname: str) -> bool:
|
||||||
# 放行非 src.* 的导入、以及 "src" 本身
|
# 放行非 src.* 的导入、以及 "src" 本身
|
||||||
@@ -514,8 +529,7 @@ def _isolate_sys_path(plugin_dirs: List[str]) -> None:
|
|||||||
return False
|
return False
|
||||||
# 放行白名单前缀
|
# 放行白名单前缀
|
||||||
return not any(
|
return not any(
|
||||||
fullname == prefix or fullname.startswith(f"{prefix}.")
|
fullname == prefix or fullname.startswith(f"{prefix}.") for prefix in self._ALLOWED_SRC_PREFIXES
|
||||||
for prefix in self._ALLOWED_SRC_PREFIXES
|
|
||||||
)
|
)
|
||||||
|
|
||||||
sys.meta_path.insert(0, _PluginImportBlocker())
|
sys.meta_path.insert(0, _PluginImportBlocker())
|
||||||
@@ -523,6 +537,7 @@ def _isolate_sys_path(plugin_dirs: List[str]) -> None:
|
|||||||
|
|
||||||
# ─── 进程入口 ──────────────────────────────────────────────
|
# ─── 进程入口 ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
async def _async_main() -> None:
|
async def _async_main() -> None:
|
||||||
"""异步主入口"""
|
"""异步主入口"""
|
||||||
host_address = os.environ.get(ENV_IPC_ADDRESS, "")
|
host_address = os.environ.get(ENV_IPC_ADDRESS, "")
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ MAX_FRAME_SIZE = 16 * 1024 * 1024 # 16 MB 最大帧大小
|
|||||||
|
|
||||||
class ConnectionClosed(Exception):
|
class ConnectionClosed(Exception):
|
||||||
"""连接已关闭"""
|
"""连接已关闭"""
|
||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -20,10 +20,12 @@ def create_transport_server(socket_path: Optional[str] = None) -> TransportServe
|
|||||||
"""
|
"""
|
||||||
if sys.platform != "win32":
|
if sys.platform != "win32":
|
||||||
from .uds import UDSTransportServer
|
from .uds import UDSTransportServer
|
||||||
|
|
||||||
return UDSTransportServer(socket_path=socket_path)
|
return UDSTransportServer(socket_path=socket_path)
|
||||||
else:
|
else:
|
||||||
# Windows 回退到 TCP(后续可改为 Named Pipe)
|
# Windows 回退到 TCP(后续可改为 Named Pipe)
|
||||||
from .tcp import TCPTransportServer
|
from .tcp import TCPTransportServer
|
||||||
|
|
||||||
return TCPTransportServer()
|
return TCPTransportServer()
|
||||||
|
|
||||||
|
|
||||||
@@ -39,9 +41,11 @@ def create_transport_client(address: str) -> TransportClient:
|
|||||||
"""
|
"""
|
||||||
if "/" in address or address.endswith(".sock"):
|
if "/" in address or address.endswith(".sock"):
|
||||||
from .uds import UDSTransportClient
|
from .uds import UDSTransportClient
|
||||||
|
|
||||||
return UDSTransportClient(socket_path=address)
|
return UDSTransportClient(socket_path=address)
|
||||||
elif ":" in address:
|
elif ":" in address:
|
||||||
from .tcp import TCPTransportClient
|
from .tcp import TCPTransportClient
|
||||||
|
|
||||||
host, port_str = address.rsplit(":", 1)
|
host, port_str = address.rsplit(":", 1)
|
||||||
return TCPTransportClient(host=host, port=int(port_str))
|
return TCPTransportClient(host=host, port=int(port_str))
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ from .base import Connection, ConnectionHandler, TransportClient, TransportServe
|
|||||||
|
|
||||||
class TCPConnection(Connection):
|
class TCPConnection(Connection):
|
||||||
"""基于 TCP 的连接"""
|
"""基于 TCP 的连接"""
|
||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ from .base import Connection, ConnectionHandler, TransportClient, TransportServe
|
|||||||
|
|
||||||
class UDSConnection(Connection):
|
class UDSConnection(Connection):
|
||||||
"""基于 UDS 的连接"""
|
"""基于 UDS 的连接"""
|
||||||
|
|
||||||
pass # 直接复用 Connection 基类的分帧读写
|
pass # 直接复用 Connection 基类的分帧读写
|
||||||
|
|
||||||
|
|
||||||
@@ -30,16 +31,17 @@ class UDSTransportServer(TransportServer):
|
|||||||
if socket_path is None:
|
if socket_path is None:
|
||||||
# 默认放在临时目录,使用 uuid 确保同一进程多实例不碰撞
|
# 默认放在临时目录,使用 uuid 确保同一进程多实例不碰撞
|
||||||
import uuid
|
import uuid
|
||||||
socket_path = os.path.join(tempfile.gettempdir(), f"maibot-plugin-{os.getpid()}-{uuid.uuid4().hex[:8]}.sock")
|
|
||||||
|
socket_path = os.path.join(
|
||||||
|
tempfile.gettempdir(), f"maibot-plugin-{os.getpid()}-{uuid.uuid4().hex[:8]}.sock"
|
||||||
|
)
|
||||||
|
|
||||||
# 如果路径超出 UDS 限制,回退到更短的路径
|
# 如果路径超出 UDS 限制,回退到更短的路径
|
||||||
if len(socket_path.encode()) > _UDS_PATH_MAX:
|
if len(socket_path.encode()) > _UDS_PATH_MAX:
|
||||||
socket_path = os.path.join("/tmp", f"mb-{os.getpid()}-{uuid.uuid4().hex[:8]}.sock")
|
socket_path = os.path.join("/tmp", f"mb-{os.getpid()}-{uuid.uuid4().hex[:8]}.sock")
|
||||||
|
|
||||||
if len(socket_path.encode()) > _UDS_PATH_MAX:
|
if len(socket_path.encode()) > _UDS_PATH_MAX:
|
||||||
raise OSError(
|
raise OSError(f"UDS socket 路径过长 ({len(socket_path.encode())} > {_UDS_PATH_MAX} 字节): {socket_path}")
|
||||||
f"UDS socket 路径过长 ({len(socket_path.encode())} > {_UDS_PATH_MAX} 字节): {socket_path}"
|
|
||||||
)
|
|
||||||
|
|
||||||
self._socket_path = socket_path
|
self._socket_path = socket_path
|
||||||
self._server: Optional[asyncio.AbstractServer] = None
|
self._server: Optional[asyncio.AbstractServer] = None
|
||||||
|
|||||||
@@ -14,8 +14,16 @@ class KnowledgePlugin(MaiBotPlugin):
|
|||||||
"lpmm_search_knowledge",
|
"lpmm_search_knowledge",
|
||||||
description="从知识库中搜索相关信息,如果你需要知识,就使用这个工具",
|
description="从知识库中搜索相关信息,如果你需要知识,就使用这个工具",
|
||||||
parameters=[
|
parameters=[
|
||||||
ToolParameterInfo(name="query", param_type=ToolParamType.STRING, description="搜索查询关键词", required=True),
|
ToolParameterInfo(
|
||||||
ToolParameterInfo(name="limit", param_type=ToolParamType.INTEGER, description="希望返回的相关知识条数,默认5", required=False, default=5),
|
name="query", param_type=ToolParamType.STRING, description="搜索查询关键词", required=True
|
||||||
|
),
|
||||||
|
ToolParameterInfo(
|
||||||
|
name="limit",
|
||||||
|
param_type=ToolParamType.INTEGER,
|
||||||
|
description="希望返回的相关知识条数,默认5",
|
||||||
|
required=False,
|
||||||
|
default=5,
|
||||||
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
async def handle_lpmm_search_knowledge(self, query: str = "", limit: int = 5, **kwargs):
|
async def handle_lpmm_search_knowledge(self, query: str = "", limit: int = 5, **kwargs):
|
||||||
|
|||||||
@@ -231,9 +231,7 @@ class PluginManagementPlugin(MaiBotPlugin):
|
|||||||
text = ", ".join(f"{c['name']} ({c['type']})" for c in filtered)
|
text = ", ".join(f"{c['name']} ({c['type']})" for c in filtered)
|
||||||
await self.ctx.send.text(f"满足条件的{label}{scope_label}组件: {text}", stream_id)
|
await self.ctx.send.text(f"满足条件的{label}{scope_label}组件: {text}", stream_id)
|
||||||
|
|
||||||
async def _handle_component_toggle(
|
async def _handle_component_toggle(self, action: str, scope: str, comp_name: str, comp_type: str, stream_id: str):
|
||||||
self, action: str, scope: str, comp_name: str, comp_type: str, stream_id: str
|
|
||||||
):
|
|
||||||
if action not in ("enable", "disable"):
|
if action not in ("enable", "disable"):
|
||||||
await self.ctx.send.text("插件管理命令不合法", stream_id)
|
await self.ctx.send.text("插件管理命令不合法", stream_id)
|
||||||
return
|
return
|
||||||
@@ -245,13 +243,9 @@ class PluginManagementPlugin(MaiBotPlugin):
|
|||||||
return
|
return
|
||||||
|
|
||||||
if action == "enable":
|
if action == "enable":
|
||||||
result = await self.ctx.component.enable_component(
|
result = await self.ctx.component.enable_component(comp_name, comp_type, scope=scope, stream_id=stream_id)
|
||||||
comp_name, comp_type, scope=scope, stream_id=stream_id
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
result = await self.ctx.component.disable_component(
|
result = await self.ctx.component.disable_component(comp_name, comp_type, scope=scope, stream_id=stream_id)
|
||||||
comp_name, comp_type, scope=scope, stream_id=stream_id
|
|
||||||
)
|
|
||||||
|
|
||||||
ok = result.get("success", False) if isinstance(result, dict) else bool(result)
|
ok = result.get("success", False) if isinstance(result, dict) else bool(result)
|
||||||
scope_label = "全局" if scope == "global" else "本地"
|
scope_label = "全局" if scope == "global" else "本地"
|
||||||
|
|||||||
@@ -142,13 +142,21 @@ class ChatManager:
|
|||||||
|
|
||||||
if chat_stream.is_group_session:
|
if chat_stream.is_group_session:
|
||||||
info["group_id"] = chat_stream.group_id
|
info["group_id"] = chat_stream.group_id
|
||||||
if chat_stream.context and chat_stream.context.message and chat_stream.context.message.message_info.group_info:
|
if (
|
||||||
|
chat_stream.context
|
||||||
|
and chat_stream.context.message
|
||||||
|
and chat_stream.context.message.message_info.group_info
|
||||||
|
):
|
||||||
info["group_name"] = chat_stream.context.message.message_info.group_info.group_name or "未知群聊"
|
info["group_name"] = chat_stream.context.message.message_info.group_info.group_name or "未知群聊"
|
||||||
else:
|
else:
|
||||||
info["group_name"] = "未知群聊"
|
info["group_name"] = "未知群聊"
|
||||||
else:
|
else:
|
||||||
info["user_id"] = chat_stream.user_id
|
info["user_id"] = chat_stream.user_id
|
||||||
if chat_stream.context and chat_stream.context.message and chat_stream.context.message.message_info.user_info:
|
if (
|
||||||
|
chat_stream.context
|
||||||
|
and chat_stream.context.message
|
||||||
|
and chat_stream.context.message.message_info.user_info
|
||||||
|
):
|
||||||
info["user_name"] = chat_stream.context.message.message_info.user_info.user_nickname
|
info["user_name"] = chat_stream.context.message.message_info.user_info.user_nickname
|
||||||
else:
|
else:
|
||||||
info["user_name"] = "未知用户"
|
info["user_name"] = "未知用户"
|
||||||
|
|||||||
@@ -44,7 +44,9 @@ def get_replyer(
|
|||||||
if not chat_id and not chat_stream:
|
if not chat_id and not chat_stream:
|
||||||
raise ValueError("chat_stream 和 chat_id 不可均为空")
|
raise ValueError("chat_stream 和 chat_id 不可均为空")
|
||||||
try:
|
try:
|
||||||
logger.debug(f"[GeneratorService] 正在获取回复器,chat_id: {chat_id}, chat_stream: {'有' if chat_stream else '无'}")
|
logger.debug(
|
||||||
|
f"[GeneratorService] 正在获取回复器,chat_id: {chat_id}, chat_stream: {'有' if chat_stream else '无'}"
|
||||||
|
)
|
||||||
return replyer_manager.get_replyer(
|
return replyer_manager.get_replyer(
|
||||||
chat_stream=chat_stream,
|
chat_stream=chat_stream,
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
|
|||||||
@@ -70,10 +70,10 @@ async def _send_to_target(
|
|||||||
if reply_message:
|
if reply_message:
|
||||||
anchor_message = db_message_to_mai_message(reply_message)
|
anchor_message = db_message_to_mai_message(reply_message)
|
||||||
if anchor_message:
|
if anchor_message:
|
||||||
logger.debug(f"[SendService] 找到匹配的回复消息,发送者: {anchor_message.message_info.user_info.user_id}")
|
logger.debug(
|
||||||
reply_to_platform_id = (
|
f"[SendService] 找到匹配的回复消息,发送者: {anchor_message.message_info.user_info.user_id}"
|
||||||
f"{anchor_message.platform}:{anchor_message.message_info.user_info.user_id}"
|
|
||||||
)
|
)
|
||||||
|
reply_to_platform_id = f"{anchor_message.platform}:{anchor_message.message_info.user_info.user_id}"
|
||||||
|
|
||||||
sender_info = None
|
sender_info = None
|
||||||
if target_stream.context and target_stream.context.message:
|
if target_stream.context and target_stream.context.message:
|
||||||
|
|||||||
@@ -174,4 +174,4 @@ async def broadcast_log(log_data: dict):
|
|||||||
# 清理断开的连接
|
# 清理断开的连接
|
||||||
if disconnected:
|
if disconnected:
|
||||||
active_connections.difference_update(disconnected)
|
active_connections.difference_update(disconnected)
|
||||||
logger.debug(f"清理了 {len(disconnected)} 个断开的 WebSocket 连接")
|
logger.debug(f"清理了 {len(disconnected)} 个断开的 WebSocket 连接")
|
||||||
|
|||||||
@@ -551,9 +551,7 @@ async def get_expression_vibe(year: int = 2025) -> ExpressionVibeData:
|
|||||||
try:
|
try:
|
||||||
# 1. 表情包之王 - 使用次数最多的表情包
|
# 1. 表情包之王 - 使用次数最多的表情包
|
||||||
with get_db_session() as session:
|
with get_db_session() as session:
|
||||||
statement = (
|
statement = select(Images).where(col(Images.is_registered)).order_by(desc(col(Images.query_count))).limit(5)
|
||||||
select(Images).where(col(Images.is_registered)).order_by(desc(col(Images.query_count))).limit(5)
|
|
||||||
)
|
|
||||||
top_emojis = session.exec(statement).all()
|
top_emojis = session.exec(statement).all()
|
||||||
if top_emojis:
|
if top_emojis:
|
||||||
data.top_emoji = {
|
data.top_emoji = {
|
||||||
|
|||||||
@@ -314,17 +314,13 @@ async def get_jargon_stats():
|
|||||||
with get_db_session() as session:
|
with get_db_session() as session:
|
||||||
total = session.exec(select(fn.count()).select_from(Jargon)).one()
|
total = session.exec(select(fn.count()).select_from(Jargon)).one()
|
||||||
|
|
||||||
confirmed_jargon = session.exec(
|
confirmed_jargon = session.exec(select(fn.count()).select_from(Jargon).where(col(Jargon.is_jargon))).one()
|
||||||
select(fn.count()).select_from(Jargon).where(col(Jargon.is_jargon))
|
|
||||||
).one()
|
|
||||||
confirmed_not_jargon = session.exec(
|
confirmed_not_jargon = session.exec(
|
||||||
select(fn.count()).select_from(Jargon).where(col(Jargon.is_jargon).is_(False))
|
select(fn.count()).select_from(Jargon).where(col(Jargon.is_jargon).is_(False))
|
||||||
).one()
|
).one()
|
||||||
pending = session.exec(select(fn.count()).select_from(Jargon).where(col(Jargon.is_jargon).is_(None))).one()
|
pending = session.exec(select(fn.count()).select_from(Jargon).where(col(Jargon.is_jargon).is_(None))).one()
|
||||||
|
|
||||||
complete_count = session.exec(
|
complete_count = session.exec(select(fn.count()).select_from(Jargon).where(col(Jargon.is_complete))).one()
|
||||||
select(fn.count()).select_from(Jargon).where(col(Jargon.is_complete))
|
|
||||||
).one()
|
|
||||||
|
|
||||||
chat_count = session.exec(
|
chat_count = session.exec(
|
||||||
select(fn.count()).select_from(
|
select(fn.count()).select_from(
|
||||||
|
|||||||
Reference in New Issue
Block a user