diff --git a/plugins/ChatFrequency/plugin.py b/plugins/ChatFrequency/plugin.py index 97fae04f..b3f69384 100644 --- a/plugins/ChatFrequency/plugin.py +++ b/plugins/ChatFrequency/plugin.py @@ -14,9 +14,7 @@ class BetterFrequencyPlugin(MaiBotPlugin): description="设置当前聊天的talk_frequency值:/chat talk_frequency <数字> 或 /chat t <数字>", pattern=r"^/chat\s+(?:talk_frequency|t)\s+(?P[+-]?\d*\.?\d+)$", ) - async def handle_set_talk_frequency( - self, stream_id: str = "", matched_groups: dict | None = None, **kwargs - ): + async def handle_set_talk_frequency(self, stream_id: str = "", matched_groups: dict | None = None, **kwargs): """设置当前聊天的 talk_frequency""" if not matched_groups or "value" not in matched_groups: return False, "命令格式错误", False diff --git a/plugins/hello_world_plugin/plugin.py b/plugins/hello_world_plugin/plugin.py index 4e8460c0..4e21d6ae 100644 --- a/plugins/hello_world_plugin/plugin.py +++ b/plugins/hello_world_plugin/plugin.py @@ -116,7 +116,9 @@ class HelloWorldPlugin(MaiBotPlugin): print(f"接收到消息: {raw}") return True, True, "消息已打印", None, None - @EventHandler("forward_messages_handler", description="把接收到的消息转发到指定聊天ID", event_type=EventType.ON_MESSAGE) + @EventHandler( + "forward_messages_handler", description="把接收到的消息转发到指定聊天ID", event_type=EventType.ON_MESSAGE + ) async def handle_forward_messages(self, message=None, stream_id: str = "", **kwargs): """收集消息并定期转发""" if not message: diff --git a/pytests/test_plugin_runtime.py b/pytests/test_plugin_runtime.py index 3de238ee..7553b54c 100644 --- a/pytests/test_plugin_runtime.py +++ b/pytests/test_plugin_runtime.py @@ -19,6 +19,7 @@ sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "packages", "ma # ─── 协议层测试 ─────────────────────────────────────────── + class TestProtocol: """协议层测试""" @@ -111,6 +112,7 @@ class TestProtocol: # ─── 传输层测试 ─────────────────────────────────────────── + class TestTransport: """传输层测试""" @@ -198,6 +200,7 @@ class TestTransport: # ─── Host 层测试 ────────────────────────────────────────── + class TestHost: """Host 端基础设施测试""" @@ -244,6 +247,7 @@ class TestHost: # ─── SDK 测试 ───────────────────────────────────────────── + class TestSDK: """SDK 框架测试""" @@ -312,6 +316,7 @@ class TestSDK: # ─── 端到端集成测试 ──────────────────────────────────────── + class TestE2E: """端到端集成测试(Host + Runner 通信)""" @@ -392,6 +397,7 @@ class TestE2E: # ─── Manifest 校验测试 ───────────────────────────────────── + class TestManifestValidator: """Manifest 校验器测试""" @@ -489,6 +495,7 @@ class TestVersionComparator: # ─── 依赖解析测试 ────────────────────────────────────────── + class TestDependencyResolution: """插件依赖解析测试""" @@ -498,8 +505,16 @@ class TestDependencyResolution: loader = PluginLoader() candidates = { "core": ("dir_core", {"name": "core", "version": "1.0", "description": "d", "author": "a"}, "plugin.py"), - "auth": ("dir_auth", {"name": "auth", "version": "1.0", "description": "d", "author": "a", "dependencies": ["core"]}, "plugin.py"), - "api": ("dir_api", {"name": "api", "version": "1.0", "description": "d", "author": "a", "dependencies": ["core", "auth"]}, "plugin.py"), + "auth": ( + "dir_auth", + {"name": "auth", "version": "1.0", "description": "d", "author": "a", "dependencies": ["core"]}, + "plugin.py", + ), + "api": ( + "dir_api", + {"name": "api", "version": "1.0", "description": "d", "author": "a", "dependencies": ["core", "auth"]}, + "plugin.py", + ), } order, failed = loader._resolve_dependencies(candidates) @@ -512,7 +527,17 @@ class TestDependencyResolution: loader = PluginLoader() candidates = { - "plugin_a": ("dir_a", {"name": "plugin_a", "version": "1.0", "description": "d", "author": "a", "dependencies": ["nonexistent"]}, "plugin.py"), + "plugin_a": ( + "dir_a", + { + "name": "plugin_a", + "version": "1.0", + "description": "d", + "author": "a", + "dependencies": ["nonexistent"], + }, + "plugin.py", + ), } order, failed = loader._resolve_dependencies(candidates) @@ -524,8 +549,16 @@ class TestDependencyResolution: loader = PluginLoader() candidates = { - "a": ("dir_a", {"name": "a", "version": "1.0", "description": "d", "author": "x", "dependencies": ["b"]}, "p.py"), - "b": ("dir_b", {"name": "b", "version": "1.0", "description": "d", "author": "x", "dependencies": ["a"]}, "p.py"), + "a": ( + "dir_a", + {"name": "a", "version": "1.0", "description": "d", "author": "x", "dependencies": ["b"]}, + "p.py", + ), + "b": ( + "dir_b", + {"name": "b", "version": "1.0", "description": "d", "author": "x", "dependencies": ["a"]}, + "p.py", + ), } order, failed = loader._resolve_dependencies(candidates) @@ -534,6 +567,7 @@ class TestDependencyResolution: # ─── Host-side ComponentRegistry 测试 ────────────────────── + class TestComponentRegistry: """Host-side 组件注册表测试""" @@ -541,17 +575,32 @@ class TestComponentRegistry: from src.plugin_runtime.host.component_registry import ComponentRegistry reg = ComponentRegistry() - reg.register_component("greet", "action", "plugin_a", { - "description": "打招呼", - "activation_type": "keyword", - "activation_keywords": ["hi"], - }) - reg.register_component("help", "command", "plugin_a", { - "command_pattern": r"^/help", - }) - reg.register_component("search", "tool", "plugin_b", { - "description": "搜索", - }) + reg.register_component( + "greet", + "action", + "plugin_a", + { + "description": "打招呼", + "activation_type": "keyword", + "activation_keywords": ["hi"], + }, + ) + reg.register_component( + "help", + "command", + "plugin_a", + { + "command_pattern": r"^/help", + }, + ) + reg.register_component( + "search", + "tool", + "plugin_b", + { + "description": "搜索", + }, + ) stats = reg.get_stats() assert stats["total"] == 3 @@ -573,12 +622,22 @@ class TestComponentRegistry: from src.plugin_runtime.host.component_registry import ComponentRegistry reg = ComponentRegistry() - reg.register_component("help", "command", "p1", { - "command_pattern": r"^/help", - }) - reg.register_component("echo", "command", "p1", { - "command_pattern": r"^/echo\s", - }) + reg.register_component( + "help", + "command", + "p1", + { + "command_pattern": r"^/help", + }, + ) + reg.register_component( + "echo", + "command", + "p1", + { + "command_pattern": r"^/echo\s", + }, + ) match = reg.find_command_by_text("/help me") assert match is not None @@ -622,12 +681,24 @@ class TestComponentRegistry: from src.plugin_runtime.host.component_registry import ComponentRegistry reg = ComponentRegistry() - reg.register_component("h_low", "event_handler", "p1", { - "event_type": "on_message", "weight": 10, - }) - reg.register_component("h_high", "event_handler", "p2", { - "event_type": "on_message", "weight": 100, - }) + reg.register_component( + "h_low", + "event_handler", + "p1", + { + "event_type": "on_message", + "weight": 10, + }, + ) + reg.register_component( + "h_high", + "event_handler", + "p2", + { + "event_type": "on_message", + "weight": 100, + }, + ) handlers = reg.get_event_handlers("on_message") assert handlers[0].name == "h_high" @@ -637,10 +708,15 @@ class TestComponentRegistry: from src.plugin_runtime.host.component_registry import ComponentRegistry reg = ComponentRegistry() - reg.register_component("search", "tool", "p1", { - "description": "搜索工具", - "parameters_raw": {"query": {"type": "string"}}, - }) + reg.register_component( + "search", + "tool", + "p1", + { + "description": "搜索工具", + "parameters_raw": {"query": {"type": "string"}}, + }, + ) tools = reg.get_tools_for_llm() assert len(tools) == 1 @@ -650,6 +726,7 @@ class TestComponentRegistry: # ─── EventDispatcher 测试 ───────────────────────────────── + class TestEventDispatcher: """Host-side 事件分发器测试""" @@ -659,11 +736,16 @@ class TestEventDispatcher: from src.plugin_runtime.host.event_dispatcher import EventDispatcher reg = ComponentRegistry() - reg.register_component("h1", "event_handler", "p1", { - "event_type": "on_start", - "weight": 0, - "intercept_message": False, - }) + reg.register_component( + "h1", + "event_handler", + "p1", + { + "event_type": "on_start", + "weight": 0, + "intercept_message": False, + }, + ) dispatcher = EventDispatcher(reg) call_log = [] @@ -672,9 +754,7 @@ class TestEventDispatcher: call_log.append((plugin_id, comp_name)) return {"success": True, "continue_processing": True} - should_continue, modified = await dispatcher.dispatch_event( - "on_start", mock_invoke - ) + should_continue, modified = await dispatcher.dispatch_event("on_start", mock_invoke) assert should_continue # 非阻塞分发是异步的,等一下让 task 完成 await asyncio.sleep(0.1) @@ -687,11 +767,16 @@ class TestEventDispatcher: from src.plugin_runtime.host.event_dispatcher import EventDispatcher reg = ComponentRegistry() - reg.register_component("filter", "event_handler", "p1", { - "event_type": "on_message_pre_process", - "weight": 100, - "intercept_message": True, - }) + reg.register_component( + "filter", + "event_handler", + "p1", + { + "event_type": "on_message_pre_process", + "weight": 100, + "intercept_message": True, + }, + ) dispatcher = EventDispatcher(reg) @@ -755,6 +840,7 @@ class TestEventBus: # ─── MaiMessages 测试 ───────────────────────────────────── + class TestMaiMessages: """统一消息模型测试""" @@ -799,6 +885,7 @@ class TestMaiMessages: # ─── WorkflowExecutor 测试 ──────────────────────────────── + class TestWorkflowExecutor: """Host-side Workflow 执行器测试(新 pipeline 模型)""" @@ -829,11 +916,16 @@ class TestWorkflowExecutor: from src.plugin_runtime.host.workflow_executor import WorkflowExecutor reg = ComponentRegistry() - reg.register_component("upper", "workflow_step", "p1", { - "stage": "pre_process", - "priority": 10, - "blocking": True, - }) + reg.register_component( + "upper", + "workflow_step", + "p1", + { + "stage": "pre_process", + "priority": 10, + "blocking": True, + }, + ) executor = WorkflowExecutor(reg) async def mock_invoke(plugin_id, comp_name, args): @@ -859,11 +951,16 @@ class TestWorkflowExecutor: from src.plugin_runtime.host.workflow_executor import WorkflowExecutor reg = ComponentRegistry() - reg.register_component("blocker", "workflow_step", "p1", { - "stage": "pre_process", - "priority": 10, - "blocking": True, - }) + reg.register_component( + "blocker", + "workflow_step", + "p1", + { + "stage": "pre_process", + "priority": 10, + "blocking": True, + }, + ) executor = WorkflowExecutor(reg) async def mock_invoke(plugin_id, comp_name, args): @@ -884,17 +981,27 @@ class TestWorkflowExecutor: reg = ComponentRegistry() # high-priority hook 返回 skip_stage - reg.register_component("skipper", "workflow_step", "p1", { - "stage": "ingress", - "priority": 100, - "blocking": True, - }) + reg.register_component( + "skipper", + "workflow_step", + "p1", + { + "stage": "ingress", + "priority": 100, + "blocking": True, + }, + ) # low-priority hook 不应被执行 - reg.register_component("checker", "workflow_step", "p2", { - "stage": "ingress", - "priority": 1, - "blocking": True, - }) + reg.register_component( + "checker", + "workflow_step", + "p2", + { + "stage": "ingress", + "priority": 1, + "blocking": True, + }, + ) executor = WorkflowExecutor(reg) call_log = [] @@ -905,9 +1012,7 @@ class TestWorkflowExecutor: return {"hook_result": "skip_stage"} return {"hook_result": "continue"} - result, _, _ = await executor.execute( - mock_invoke, message={"plain_text": "test"} - ) + result, _, _ = await executor.execute(mock_invoke, message={"plain_text": "test"}) assert result.status == "completed" # 只有 skipper 被调用,checker 被跳过 assert call_log == ["skipper"] @@ -919,12 +1024,17 @@ class TestWorkflowExecutor: from src.plugin_runtime.host.workflow_executor import WorkflowExecutor reg = ComponentRegistry() - reg.register_component("only_dm", "workflow_step", "p1", { - "stage": "ingress", - "priority": 10, - "blocking": True, - "filter": {"chat_type": "direct"}, - }) + reg.register_component( + "only_dm", + "workflow_step", + "p1", + { + "stage": "ingress", + "priority": 10, + "blocking": True, + "filter": {"chat_type": "direct"}, + }, + ) executor = WorkflowExecutor(reg) call_log = [] @@ -934,15 +1044,11 @@ class TestWorkflowExecutor: return {"hook_result": "continue"} # 不匹配 filter —— hook 不应被调用 - await executor.execute( - mock_invoke, message={"plain_text": "hi", "chat_type": "group"} - ) + await executor.execute(mock_invoke, message={"plain_text": "hi", "chat_type": "group"}) assert not call_log # 匹配 filter —— hook 应被调用 - await executor.execute( - mock_invoke, message={"plain_text": "hi", "chat_type": "direct"} - ) + await executor.execute(mock_invoke, message={"plain_text": "hi", "chat_type": "direct"}) assert call_log == ["only_dm"] @pytest.mark.asyncio @@ -952,17 +1058,27 @@ class TestWorkflowExecutor: from src.plugin_runtime.host.workflow_executor import WorkflowExecutor reg = ComponentRegistry() - reg.register_component("failer", "workflow_step", "p1", { - "stage": "ingress", - "priority": 100, - "blocking": True, - "error_policy": "skip", - }) - reg.register_component("ok_step", "workflow_step", "p2", { - "stage": "ingress", - "priority": 1, - "blocking": True, - }) + reg.register_component( + "failer", + "workflow_step", + "p1", + { + "stage": "ingress", + "priority": 100, + "blocking": True, + "error_policy": "skip", + }, + ) + reg.register_component( + "ok_step", + "workflow_step", + "p2", + { + "stage": "ingress", + "priority": 1, + "blocking": True, + }, + ) executor = WorkflowExecutor(reg) call_log = [] @@ -973,9 +1089,7 @@ class TestWorkflowExecutor: raise RuntimeError("boom") return {"hook_result": "continue"} - result, _, ctx = await executor.execute( - mock_invoke, message={"plain_text": "test"} - ) + result, _, ctx = await executor.execute(mock_invoke, message={"plain_text": "test"}) assert result.status == "completed" assert "failer" in call_log assert "ok_step" in call_log @@ -988,20 +1102,23 @@ class TestWorkflowExecutor: from src.plugin_runtime.host.workflow_executor import WorkflowExecutor reg = ComponentRegistry() - reg.register_component("failer", "workflow_step", "p1", { - "stage": "ingress", - "priority": 10, - "blocking": True, - # error_policy defaults to "abort" - }) + reg.register_component( + "failer", + "workflow_step", + "p1", + { + "stage": "ingress", + "priority": 10, + "blocking": True, + # error_policy defaults to "abort" + }, + ) executor = WorkflowExecutor(reg) async def mock_invoke(plugin_id, comp_name, args): raise RuntimeError("fatal") - result, _, ctx = await executor.execute( - mock_invoke, message={"plain_text": "test"} - ) + result, _, ctx = await executor.execute(mock_invoke, message={"plain_text": "test"}) assert result.status == "failed" assert result.stopped_at == "ingress" @@ -1013,11 +1130,16 @@ class TestWorkflowExecutor: reg = ComponentRegistry() for i in range(3): - reg.register_component(f"nb_{i}", "workflow_step", f"p{i}", { - "stage": "post_process", - "priority": 0, - "blocking": False, - }) + reg.register_component( + f"nb_{i}", + "workflow_step", + f"p{i}", + { + "stage": "post_process", + "priority": 0, + "blocking": False, + }, + ) executor = WorkflowExecutor(reg) call_log = [] @@ -1026,9 +1148,7 @@ class TestWorkflowExecutor: call_log.append(comp_name) return {"hook_result": "continue", "modified_message": {"plain_text": "ignored"}} - result, final_msg, _ = await executor.execute( - mock_invoke, message={"plain_text": "original"} - ) + result, final_msg, _ = await executor.execute(mock_invoke, message={"plain_text": "original"}) # non-blocking 的 modified_message 被忽略 assert final_msg["plain_text"] == "original" # 给异步 task 时间完成 @@ -1042,9 +1162,14 @@ class TestWorkflowExecutor: from src.plugin_runtime.host.workflow_executor import WorkflowExecutor reg = ComponentRegistry() - reg.register_component("help", "command", "p1", { - "command_pattern": r"^/help", - }) + reg.register_component( + "help", + "command", + "p1", + { + "command_pattern": r"^/help", + }, + ) executor = WorkflowExecutor(reg) async def mock_invoke(plugin_id, comp_name, args): @@ -1052,9 +1177,7 @@ class TestWorkflowExecutor: return {"output": "帮助信息"} return {"hook_result": "continue"} - result, _, ctx = await executor.execute( - mock_invoke, message={"plain_text": "/help topic"} - ) + result, _, ctx = await executor.execute(mock_invoke, message={"plain_text": "/help topic"}) assert result.status == "completed" assert ctx.matched_command == "p1.help" cmd_result = ctx.get_stage_output("plan", "command_result") @@ -1069,17 +1192,27 @@ class TestWorkflowExecutor: reg = ComponentRegistry() # ingress 阶段写入数据 - reg.register_component("writer", "workflow_step", "p1", { - "stage": "ingress", - "priority": 10, - "blocking": True, - }) + reg.register_component( + "writer", + "workflow_step", + "p1", + { + "stage": "ingress", + "priority": 10, + "blocking": True, + }, + ) # pre_process 阶段读取数据 - reg.register_component("reader", "workflow_step", "p2", { - "stage": "pre_process", - "priority": 10, - "blocking": True, - }) + reg.register_component( + "reader", + "workflow_step", + "p2", + { + "stage": "pre_process", + "priority": 10, + "blocking": True, + }, + ) executor = WorkflowExecutor(reg) async def mock_invoke(plugin_id, comp_name, args): @@ -1096,9 +1229,7 @@ class TestWorkflowExecutor: return {"hook_result": "continue"} return {"hook_result": "continue"} - result, _, ctx = await executor.execute( - mock_invoke, message={"plain_text": "hi"} - ) + result, _, ctx = await executor.execute(mock_invoke, message={"plain_text": "hi"}) assert result.status == "completed" assert ctx.get_stage_output("ingress", "parsed_intent") == "greeting" @@ -1324,10 +1455,15 @@ class TestIntegration: async def stop(self): self.stopped = True - monkeypatch.setattr(integration_module.PluginRuntimeManager, "_get_builtin_plugin_dirs", staticmethod(lambda: ["builtin"])) - monkeypatch.setattr(integration_module.PluginRuntimeManager, "_get_thirdparty_plugin_dirs", staticmethod(lambda: ["thirdparty"])) + monkeypatch.setattr( + integration_module.PluginRuntimeManager, "_get_builtin_plugin_dirs", staticmethod(lambda: ["builtin"]) + ) + monkeypatch.setattr( + integration_module.PluginRuntimeManager, "_get_thirdparty_plugin_dirs", staticmethod(lambda: ["thirdparty"]) + ) import src.plugin_runtime.host.supervisor as supervisor_module + monkeypatch.setattr(supervisor_module, "PluginSupervisor", FakeSupervisor) manager = integration_module.PluginRuntimeManager() diff --git a/pytests/utils_test/message_utils_test.py b/pytests/utils_test/message_utils_test.py index 60c32cbb..fba99324 100644 --- a/pytests/utils_test/message_utils_test.py +++ b/pytests/utils_test/message_utils_test.py @@ -9,16 +9,12 @@ from datetime import datetime from typing import TYPE_CHECKING, Optional if TYPE_CHECKING: - from src.common.data_models.message_component_data_model import MessageSequence, ForwardComponent + from src.common.data_models.message_component_data_model import MessageSequence from src.chat.message_receive.message import ( SessionMessage, TextComponent, ImageComponent, - EmojiComponent, - VoiceComponent, AtComponent, - ReplyComponent, - ForwardNodeComponent, ) @@ -203,17 +199,23 @@ def load_message_via_file(monkeypatch): def dummy_number_to_short_id(original_id: int, salt: str, length: int = 6) -> str: return "X" * length # 返回固定的字符串,长度由参数决定,模拟生成短ID的行为 + def dummy_is_bot_self(user_id: str, platform) -> bool: return user_id == "bot_self" + def load_utils_via_file(monkeypatch): setup_mocks(monkeypatch) # Mock math_utils 模块,供 from .math_utils import number_to_short_id 使用 math_utils_mod = ModuleType("src.common.utils.math_utils") math_utils_mod.number_to_short_id = dummy_number_to_short_id - math_utils_mod.TimestampMode = type("TimestampMode", (), {"NORMAL": "%Y-%m-%d %H:%M:%S", "NORMAL_NO_YMD": "%H:%M:%S", "RELATIVE": "relative"}) - math_utils_mod.translate_timestamp_to_human_readable = lambda timestamp, mode: "2024-01-01 12:00:00" # 返回固定的时间字符串 + math_utils_mod.TimestampMode = type( + "TimestampMode", (), {"NORMAL": "%Y-%m-%d %H:%M:%S", "NORMAL_NO_YMD": "%H:%M:%S", "RELATIVE": "relative"} + ) + math_utils_mod.translate_timestamp_to_human_readable = lambda timestamp, mode: ( + "2024-01-01 12:00:00" + ) # 返回固定的时间字符串 monkeypatch.setitem(sys.modules, "src.common.utils.math_utils", math_utils_mod) # 确保包层级模块存在于 sys.modules 中,使相对导入能正确解析 @@ -349,6 +351,7 @@ async def test_build_readable_message_anonymize_and_replace_bot_name_and_lineno( assert "u_comb" in mapping assert mapping["u_comb"][0] == "XXXXXX" + @pytest.mark.asyncio async def test_build_readable_message_with_at(monkeypatch): """包含@组件的消息:验证@组件中的用户信息也被匿名化和替换""" @@ -363,7 +366,9 @@ async def test_build_readable_message_with_at(monkeypatch): msg.session_id = "s_at" msg.raw_message = MessageSequence([at_comp]) msg.message_info = MessageInfo(UserInfo(user_id="u_main", user_nickname="MainUser")) - text, mapping, _ = await MessageUtils.build_readable_message([msg], anonymize=True, replace_bot_name=True, target_bot_name="MAIBot") + text, mapping, _ = await MessageUtils.build_readable_message( + [msg], anonymize=True, replace_bot_name=True, target_bot_name="MAIBot" + ) # 验证主消息和@组件中的用户信息都被处理 assert "XXXXXX说:" in text # 主消息用户被匿名化 - assert "XXXXXX说:@XXXXXX" in text # @组件用户被匿名化 \ No newline at end of file + assert "XXXXXX说:@XXXXXX" in text # @组件用户被匿名化 diff --git a/scripts/i18n_extract_candidates.py b/scripts/i18n_extract_candidates.py index 855b5429..830dba2f 100644 --- a/scripts/i18n_extract_candidates.py +++ b/scripts/i18n_extract_candidates.py @@ -22,11 +22,7 @@ def should_skip(path: Path) -> bool: def iter_python_files(root: Path) -> list[Path]: - return sorted( - path - for path in root.rglob("*.py") - if path.is_file() and not should_skip(path.relative_to(root)) - ) + return sorted(path for path in root.rglob("*.py") if path.is_file() and not should_skip(path.relative_to(root))) class CandidateExtractor(ast.NodeVisitor): diff --git a/scripts/test_memory_retrieval.py b/scripts/test_memory_retrieval.py index 63bc3bd4..40aff8aa 100644 --- a/scripts/test_memory_retrieval.py +++ b/scripts/test_memory_retrieval.py @@ -23,7 +23,6 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) from src.common.logger import initialize_logging, get_logger from src.common.database.database import db from src.common.database.database_model import LLMUsage -from src.chat.message_receive.chat_manager import BotChatSession from maim_message import UserInfo, GroupInfo logger = get_logger("test_memory_retrieval") diff --git a/src/bw_learner/expression_selector.py b/src/bw_learner/expression_selector.py index 6944e1f0..c6cfe469 100644 --- a/src/bw_learner/expression_selector.py +++ b/src/bw_learner/expression_selector.py @@ -10,7 +10,6 @@ from src.common.logger import get_logger from src.common.database.database_model import Expression from src.prompt.prompt_manager import prompt_manager from src.bw_learner.learner_utils_old import weighted_sample -from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.utils.common_utils import TempMethodsExpression logger = get_logger("expression_selector") diff --git a/src/bw_learner/learner_utils_old.py b/src/bw_learner/learner_utils_old.py index 42fda4b3..3f21c55d 100644 --- a/src/bw_learner/learner_utils_old.py +++ b/src/bw_learner/learner_utils_old.py @@ -1,22 +1,14 @@ -import re -import difflib import random import json -from typing import Optional, List, Dict, Any, Tuple +from typing import Optional, List, Dict, Any from src.common.logger import get_logger from src.config.config import global_config -from src.chat.utils.chat_message_builder import ( - build_readable_messages, -) -from src.chat.utils.utils import parse_platform_accounts -from json_repair import repair_json logger = get_logger("learner_utils") - def _compute_weights(population: List[Dict]) -> List[float]: """ 根据表达的count计算权重,范围限定在1~5之间。 diff --git a/src/chat/brain_chat/PFC/conversation.py b/src/chat/brain_chat/PFC/conversation.py index 5b9b60ac..119c4ecd 100644 --- a/src/chat/brain_chat/PFC/conversation.py +++ b/src/chat/brain_chat/PFC/conversation.py @@ -16,7 +16,7 @@ from .action_planner import ActionPlanner from .observation_info import ObservationInfo from .conversation_info import ConversationInfo # 确保导入 ConversationInfo from .reply_generator import ReplyGenerator -from src.chat.message_receive.chat_manager import BotChatSession, chat_manager as _chat_manager +from src.chat.message_receive.chat_manager import chat_manager as _chat_manager from maim_message import UserInfo from .pfc_KnowledgeFetcher import KnowledgeFetcher from .waiter import Waiter diff --git a/src/chat/brain_chat/brain_chat.py b/src/chat/brain_chat/brain_chat.py index 74da0b89..9c6e5847 100644 --- a/src/chat/brain_chat/brain_chat.py +++ b/src/chat/brain_chat/brain_chat.py @@ -9,7 +9,7 @@ from src.config.config import global_config from src.common.logger import get_logger from src.common.data_models.info_data_model import ActionPlannerInfo from src.common.data_models.message_data_model import ReplyContentType -from src.chat.message_receive.chat_manager import BotChatSession, chat_manager as _chat_manager +from src.chat.message_receive.chat_manager import chat_manager as _chat_manager from src.chat.utils.prompt_builder import global_prompt_manager from src.chat.utils.timer_calculator import Timer from src.chat.brain_chat.brain_planner import BrainPlanner @@ -22,7 +22,12 @@ from src.person_info.person_info import Person from src.core.types import ActionInfo, EventType from src.core.event_bus import event_bus from src.chat.event_helpers import build_event_message -from src.services import generator_service as generator_api, send_service as send_api, message_service as message_api, database_service as database_api +from src.services import ( + generator_service as generator_api, + send_service as send_api, + message_service as message_api, + database_service as database_api, +) from src.chat.utils.chat_message_builder import ( build_readable_messages_with_id, get_raw_msg_before_timestamp_with_chat, @@ -294,10 +299,10 @@ class BrainChatting: message_id_list=message_id_list, prompt_key="brain_planner", ) - _event_msg = build_event_message(EventType.ON_PLAN, llm_prompt=prompt_info[0], stream_id=self.chat_stream.stream_id) - continue_flag, modified_message = await event_bus.emit( - EventType.ON_PLAN, _event_msg + _event_msg = build_event_message( + EventType.ON_PLAN, llm_prompt=prompt_info[0], stream_id=self.chat_stream.stream_id ) + continue_flag, modified_message = await event_bus.emit(EventType.ON_PLAN, _event_msg) if not continue_flag: return False if modified_message and modified_message._modify_flags.modify_llm_prompt: diff --git a/src/chat/heart_flow/heartFC_chat.py b/src/chat/heart_flow/heartFC_chat.py index 40f2d91d..3778a62f 100644 --- a/src/chat/heart_flow/heartFC_chat.py +++ b/src/chat/heart_flow/heartFC_chat.py @@ -1,5 +1,5 @@ from rich.traceback import install -from typing import Optional, List, TYPE_CHECKING, Tuple, Dict +from typing import Optional, List, TYPE_CHECKING import asyncio import time @@ -14,7 +14,7 @@ from src.chat.message_receive.chat_manager import chat_manager from src.bw_learner.expression_learner import ExpressionLearner from src.bw_learner.jargon_miner import JargonMiner -from .heartFC_utils import CycleDetail, CycleActionInfo, CyclePlanInfo +from .heartFC_utils import CycleDetail if TYPE_CHECKING: from src.chat.message_receive.message import SessionMessage diff --git a/src/chat/heart_flow/heartflow_manager.py b/src/chat/heart_flow/heartflow_manager.py index 128ba3fe..5b3ece9b 100644 --- a/src/chat/heart_flow/heartflow_manager.py +++ b/src/chat/heart_flow/heartflow_manager.py @@ -1,4 +1,4 @@ -from typing import Optional, Dict +from typing import Dict import traceback @@ -9,6 +9,7 @@ from src.chat.heart_flow.heartFC_chat import HeartFChatting logger = get_logger("heartflow") + # TODO: 恢复PFC,现在暂时禁用 class HeartflowManager: """主心流协调器,负责初始化并协调聊天,控制聊天属性""" @@ -17,7 +18,7 @@ class HeartflowManager: # self.heartflow_chat_list: Dict[str, HeartFChatting | BrainChatting] = {} self.heartflow_chat_list: Dict[str, HeartFChatting] = {} - async def get_or_create_heartflow_chat(self, session_id: str): # -> Optional[HeartFChatting | BrainChatting]: + async def get_or_create_heartflow_chat(self, session_id: str): # -> Optional[HeartFChatting | BrainChatting]: """获取或创建一个新的HeartFChatting实例""" try: if chat := self.heartflow_chat_list.get(session_id): diff --git a/src/chat/message_receive/bot.py b/src/chat/message_receive/bot.py index f79ca12d..b02e037d 100644 --- a/src/chat/message_receive/bot.py +++ b/src/chat/message_receive/bot.py @@ -9,11 +9,10 @@ from src.common.logger import get_logger from src.common.utils.utils_message import MessageUtils from src.common.utils.utils_session import SessionUtils from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiver + # from src.chat.brain_chat.PFC.pfc_manager import PFCManager -from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.core.announcement_manager import global_announcement_manager from src.core.component_registry import component_registry -from src.core.types import EventType from .message import SessionMessage from .chat_manager import chat_manager @@ -391,6 +390,7 @@ class ChatBot: else: logger.debug("[群聊]检测到群聊消息,路由到HeartFlow系统") await self.heartflow_message_receiver.process_message(message) + await preprocess() except Exception as e: diff --git a/src/chat/message_receive/message.py b/src/chat/message_receive/message.py index c23921c1..be2ef026 100644 --- a/src/chat/message_receive/message.py +++ b/src/chat/message_receive/message.py @@ -8,7 +8,7 @@ import asyncio from src.common.logger import get_logger from src.common.database.database import get_db_session from src.common.database.database_model import Messages -from src.common.data_models.mai_message_data_model import MaiMessage, UserInfo, GroupInfo, MessageInfo +from src.common.data_models.mai_message_data_model import MaiMessage, UserInfo from src.common.data_models.message_component_data_model import ( TextComponent, ImageComponent, diff --git a/src/chat/message_receive/uni_message_sender.py b/src/chat/message_receive/uni_message_sender.py index 1b4db86d..075a695d 100644 --- a/src/chat/message_receive/uni_message_sender.py +++ b/src/chat/message_receive/uni_message_sender.py @@ -22,6 +22,7 @@ _webui_chat_broadcaster = None # 虚拟群 ID 前缀(与 chat_routes.py 保持一致) VIRTUAL_GROUP_ID_PREFIX = "webui_virtual_group_" + # TODO: 重构完成后完成webui相关 def get_webui_chat_broadcaster(): """获取 WebUI 聊天室广播器""" diff --git a/src/chat/planner_actions/action_manager.py b/src/chat/planner_actions/action_manager.py index 05673778..a6ffec2b 100644 --- a/src/chat/planner_actions/action_manager.py +++ b/src/chat/planner_actions/action_manager.py @@ -4,7 +4,7 @@ from src.chat.message_receive.chat_manager import BotChatSession from src.common.logger import get_logger from src.common.data_models.database_data_model import DatabaseMessages from src.core.component_registry import component_registry, ActionExecutor -from src.core.types import ActionInfo, ComponentType +from src.core.types import ActionInfo logger = get_logger("action_manager") diff --git a/src/chat/planner_actions/action_modifier.py b/src/chat/planner_actions/action_modifier.py index a34f42d4..fa8635ee 100644 --- a/src/chat/planner_actions/action_modifier.py +++ b/src/chat/planner_actions/action_modifier.py @@ -1,6 +1,6 @@ import random import time -from typing import List, Dict, TYPE_CHECKING, Tuple +from typing import List, Dict, Tuple from src.common.logger import get_logger from src.config.config import global_config diff --git a/src/chat/replyer/private_generator.py b/src/chat/replyer/private_generator.py index 735e0afe..39e66e91 100644 --- a/src/chat/replyer/private_generator.py +++ b/src/chat/replyer/private_generator.py @@ -119,9 +119,7 @@ class PrivateReplyer: if not from_plugin: _event_msg = build_event_message(EventType.POST_LLM, llm_prompt=prompt, stream_id=stream_id) - continue_flag, modified_message = await event_bus.emit( - EventType.POST_LLM, _event_msg - ) + continue_flag, modified_message = await event_bus.emit(EventType.POST_LLM, _event_msg) if not continue_flag: raise UserWarning("插件于请求前中断了内容生成") if modified_message and modified_message._modify_flags.modify_llm_prompt: @@ -140,10 +138,10 @@ class PrivateReplyer: llm_response.reasoning = reasoning_content llm_response.model = model_name llm_response.tool_calls = tool_call - _event_msg = build_event_message(EventType.AFTER_LLM, llm_prompt=prompt, llm_response=llm_response, stream_id=stream_id) - continue_flag, modified_message = await event_bus.emit( - EventType.AFTER_LLM, _event_msg + _event_msg = build_event_message( + EventType.AFTER_LLM, llm_prompt=prompt, llm_response=llm_response, stream_id=stream_id ) + continue_flag, modified_message = await event_bus.emit(EventType.AFTER_LLM, _event_msg) if not from_plugin and not continue_flag: raise UserWarning("插件于请求后取消了内容生成") if modified_message: diff --git a/src/chat/utils/utils.py b/src/chat/utils/utils.py index f7566630..8fe1cbef 100644 --- a/src/chat/utils/utils.py +++ b/src/chat/utils/utils.py @@ -817,6 +817,8 @@ def assign_message_ids(messages: List[SessionMessage]) -> List[Tuple[str, Sessio result.append((message_id, message)) return result + + # break # result.append((message_id, message)) diff --git a/src/common/data_models/info_data_model.py b/src/common/data_models/info_data_model.py index 672953b0..83636e85 100644 --- a/src/common/data_models/info_data_model.py +++ b/src/common/data_models/info_data_model.py @@ -25,4 +25,4 @@ # available_actions: Optional[Dict[str, "ActionInfo"]] = None # loop_start_time: Optional[float] = None # action_reasoning: Optional[str] = None -# TODO: 重构 \ No newline at end of file +# TODO: 重构 diff --git a/src/common/data_models/llm_data_model.py b/src/common/data_models/llm_data_model.py index 567eefc6..cb32702e 100644 --- a/src/common/data_models/llm_data_model.py +++ b/src/common/data_models/llm_data_model.py @@ -20,4 +20,4 @@ # timing: Optional[Dict[str, Any]] = None # processed_output: Optional[List[str]] = None # timing_logs: Optional[List[str]] = None -# TODO: 重构 \ No newline at end of file +# TODO: 重构 diff --git a/src/common/data_models/message_component_data_model.py b/src/common/data_models/message_component_data_model.py index d7b5c287..995e54ce 100644 --- a/src/common/data_models/message_component_data_model.py +++ b/src/common/data_models/message_component_data_model.py @@ -13,8 +13,10 @@ from src.common.logger import get_logger logger = get_logger("base_message_component_model") + class UnknownUser(str): ... + class BaseMessageComponentModel(ABC): @property @abstractmethod diff --git a/src/common/message_server/universal_message_sender.py b/src/common/message_server/universal_message_sender.py index 4098f057..b022b9db 100644 --- a/src/common/message_server/universal_message_sender.py +++ b/src/common/message_server/universal_message_sender.py @@ -107,7 +107,7 @@ class UniversalMessageSender: logger.info(f"已将消息 '{message_preview}' 发往平台'{message.platform}'") return True - except Exception as legacy_e: + except Exception: # # Legacy API 抛出异常,尝试 Fallback # return await self._send_with_fallback( # message, message_preview, platform, show_log, legacy_exception=legacy_e diff --git a/src/common/utils/port_checker.py b/src/common/utils/port_checker.py index 8ff3c598..fc5acd64 100644 --- a/src/common/utils/port_checker.py +++ b/src/common/utils/port_checker.py @@ -76,4 +76,4 @@ def assert_port_available( port=port, config_hint=config_hint, ) - raise OSError(build_port_conflict_message(service_name=service_name, host=host, port=port)) \ No newline at end of file + raise OSError(build_port_conflict_message(service_name=service_name, host=host, port=port)) diff --git a/src/config/config_base.py b/src/config/config_base.py index 8661adab..533d1e8b 100644 --- a/src/config/config_base.py +++ b/src/config/config_base.py @@ -133,8 +133,8 @@ class ConfigBase(BaseModel, AttrDocBase): # UI 分组元数据:子类可覆盖以声明所属 Tab 分组 __ui_parent__: ClassVar[str] = "" # 父配置类在 Config 中的字段名,空表示独立 Tab - __ui_label__: ClassVar[str] = "" # Tab 显示名称(仅做 Tab 主人时使用),空则使用 classDoc - __ui_icon__: ClassVar[str] = "" # Tab 图标名称(Lucide 图标名) + __ui_label__: ClassVar[str] = "" # Tab 显示名称(仅做 Tab 主人时使用),空则使用 classDoc + __ui_icon__: ClassVar[str] = "" # Tab 图标名称(Lucide 图标名) @classmethod def from_dict(cls, attribute_data: AttributeData, data: dict[str, Any]): diff --git a/src/config/file_watcher.py b/src/config/file_watcher.py index ee7ac076..ed86b7ee 100644 --- a/src/config/file_watcher.py +++ b/src/config/file_watcher.py @@ -182,7 +182,9 @@ class FileWatcher: self._stats.callbacks_skipped_cooldown += 1 continue try: - await asyncio.wait_for(self._invoke_callback(subscription.callback, matched_changes), timeout=self._callback_timeout_s) + await asyncio.wait_for( + self._invoke_callback(subscription.callback, matched_changes), timeout=self._callback_timeout_s + ) state.consecutive_failures = 0 self._stats.callbacks_succeeded += 1 except asyncio.TimeoutError: diff --git a/src/core/component_registry.py b/src/core/component_registry.py index b9c7af73..bb58682a 100644 --- a/src/core/component_registry.py +++ b/src/core/component_registry.py @@ -10,11 +10,10 @@ """ import re -from typing import Any, Awaitable, Callable, Dict, List, Optional, Pattern, Tuple, Union +from typing import Any, Awaitable, Callable, Dict, Optional, Pattern, Tuple from src.common.logger import get_logger from src.core.types import ( - ActionActivationType, ActionInfo, CommandInfo, ComponentInfo, @@ -130,9 +129,7 @@ class ComponentRegistry: logger.debug(f"注册 Command: {name}") return True - def find_command_by_text( - self, text: str - ) -> Optional[Tuple[CommandExecutor, dict, CommandInfo]]: + def find_command_by_text(self, text: str) -> Optional[Tuple[CommandExecutor, dict, CommandInfo]]: """根据文本查找匹配的命令 Returns: diff --git a/src/core/event_bus.py b/src/core/event_bus.py index c84a86b0..7d93d04b 100644 --- a/src/core/event_bus.py +++ b/src/core/event_bus.py @@ -117,9 +117,7 @@ class EventBus: self._fire_and_forget(entry, event_type, current_message) # 桥接到 IPC 插件运行时 - continue_flag, current_message = await self._bridge_to_ipc_runtime( - event_type, continue_flag, current_message - ) + continue_flag, current_message = await self._bridge_to_ipc_runtime(event_type, continue_flag, current_message) return continue_flag, current_message diff --git a/src/maisaka/builtin_tools.py b/src/maisaka/builtin_tools.py index 8642a7e5..0260f778 100644 --- a/src/maisaka/builtin_tools.py +++ b/src/maisaka/builtin_tools.py @@ -7,6 +7,7 @@ MaiSaka - 内置工具定义 from typing import List, Dict, Any from src.llm_models.payload_content.tool_option import ToolOption, ToolParamType + # 内置工具定义 def create_builtin_tools() -> List[ToolOption]: """创建内置工具列表""" @@ -17,57 +18,66 @@ def create_builtin_tools() -> List[ToolOption]: # say 工具 say_builder = ToolOptionBuilder() say_builder.set_name("say") - say_builder.set_description("对用户说话。你所有想让用户看到的正式发言都必须通过此工具输出。直接输出的文本会被视为你的内心思考,用户无法阅读。reason 参数描述你想要回复的方式、想法和内容,系统会根据你的想法和对话上下文生成具体的回复。") + say_builder.set_description( + "对用户说话。你所有想让用户看到的正式发言都必须通过此工具输出。直接输出的文本会被视为你的内心思考,用户无法阅读。reason 参数描述你想要回复的方式、想法和内容,系统会根据你的想法和对话上下文生成具体的回复。" + ) say_builder.add_param( name="reason", param_type=ToolParamType.STRING, description="描述你想要回复的方式、想法和内容。例如:'同意对方的看法,并分享自己的经历' 或 '礼貌地拒绝,表示现在不方便聊天'", required=True, - enum_values=None + enum_values=None, ) tools.append(say_builder.build()) # wait 工具 wait_builder = ToolOptionBuilder() wait_builder.set_name("wait") - wait_builder.set_description("暂时结束你的发言,把话语权交给用户,等待对方说话。这就像现实对话中你说完一句话后停下来等对方回应。如果用户在等待期间说了话,你会通过工具返回结果收到内容。如果超时没有回复,你也会收到超时通知。") + wait_builder.set_description( + "暂时结束你的发言,把话语权交给用户,等待对方说话。这就像现实对话中你说完一句话后停下来等对方回应。如果用户在等待期间说了话,你会通过工具返回结果收到内容。如果超时没有回复,你也会收到超时通知。" + ) wait_builder.add_param( name="seconds", param_type=ToolParamType.INTEGER, description="等待的秒数。建议 3-10 秒。超过这个时间用户没有回复会显示超时提示。", required=True, - enum_values=None + enum_values=None, ) tools.append(wait_builder.build()) # stop 工具 stop_builder = ToolOptionBuilder() stop_builder.set_name("stop") - stop_builder.set_description("结束当前对话循环,进入待机状态,直到用户下次输入新内容时再唤醒你。当对话自然结束、用户表示不想继续聊、或连续多次等待超时用户没有回复时使用。") + stop_builder.set_description( + "结束当前对话循环,进入待机状态,直到用户下次输入新内容时再唤醒你。当对话自然结束、用户表示不想继续聊、或连续多次等待超时用户没有回复时使用。" + ) tools.append(stop_builder.build()) # store_context 工具 store_context_builder = ToolOptionBuilder() store_context_builder.set_name("store_context") - store_context_builder.set_description("将指定范围的对话上下文存入记忆系统,然后从当前对话中移除这些内容。适合在对话上下文过长、话题转换、或遇到重要内容需要保存时使用。") + store_context_builder.set_description( + "将指定范围的对话上下文存入记忆系统,然后从当前对话中移除这些内容。适合在对话上下文过长、话题转换、或遇到重要内容需要保存时使用。" + ) store_context_builder.add_param( name="count", param_type=ToolParamType.INTEGER, description="要保存的消息条数(从最早的对话开始计数)。建议 5-20 条。", required=True, - enum_values=None + enum_values=None, ) store_context_builder.add_param( name="reason", param_type=ToolParamType.STRING, description="保存原因,用于后续检索。例如:'讨论了用户的工作情况' 或 '用户分享了对电影的看法'", required=True, - enum_values=None + enum_values=None, ) tools.append(store_context_builder.build()) return tools + # 为了兼容性,创建一个函数来将工具转换为 dict 格式(用于调试显示) def builtin_tools_as_dicts() -> List[Dict[str, Any]]: """将内置工具转换为 dict 格式(用于调试)""" @@ -77,31 +87,23 @@ def builtin_tools_as_dicts() -> List[Dict[str, Any]]: "description": "对用户说话。你所有想让用户看到的正式发言都必须通过此工具输出。", "parameters": { "type": "object", - "properties": { - "reason": {"type": "string", "description": "回复的想法和内容"} - }, - "required": ["reason"] - } + "properties": {"reason": {"type": "string", "description": "回复的想法和内容"}}, + "required": ["reason"], + }, }, { "name": "wait", "description": "暂时结束发言,等待用户回应", "parameters": { "type": "object", - "properties": { - "seconds": {"type": "number", "description": "等待秒数"} - }, - "required": ["seconds"] - } + "properties": {"seconds": {"type": "number", "description": "等待秒数"}}, + "required": ["seconds"], + }, }, { "name": "stop", "description": "结束对话循环", - "parameters": { - "type": "object", - "properties": {}, - "required": [] - } + "parameters": {"type": "object", "properties": {}, "required": []}, }, { "name": "store_context", @@ -110,17 +112,19 @@ def builtin_tools_as_dicts() -> List[Dict[str, Any]]: "type": "object", "properties": { "count": {"type": "number", "description": "保存的消息条数"}, - "reason": {"type": "string", "description": "保存原因"} + "reason": {"type": "string", "description": "保存原因"}, }, - "required": ["count", "reason"] - } - } + "required": ["count", "reason"], + }, + }, ] + # 导出工具创建函数和列表 def get_builtin_tools() -> List[ToolOption]: """获取内置工具列表""" return create_builtin_tools() + # 为了向后兼容,也导出 dict 格式 BUILTIN_TOOLS_DICTS = builtin_tools_as_dicts() diff --git a/src/maisaka/cli.py b/src/maisaka/cli.py index 1edf1483..e4dc0c0f 100644 --- a/src/maisaka/cli.py +++ b/src/maisaka/cli.py @@ -13,10 +13,17 @@ from rich.markdown import Markdown from rich.text import Text from rich import box -from config import console, ENABLE_EMOTION_MODULE, ENABLE_COGNITION_MODULE, ENABLE_TIMING_MODULE, ENABLE_KNOWLEDGE_MODULE, ENABLE_MCP +from config import ( + console, + ENABLE_EMOTION_MODULE, + ENABLE_COGNITION_MODULE, + ENABLE_TIMING_MODULE, + ENABLE_KNOWLEDGE_MODULE, + ENABLE_MCP, +) from input_reader import InputReader from timing import build_timing_info -from knowledge import store_knowledge_from_context, retrieve_relevant_knowledge, build_knowledge_summary +from knowledge import store_knowledge_from_context, retrieve_relevant_knowledge from knowledge_store import get_knowledge_store from llm_service import MaiSakaLLMService, build_message, remove_last_perception from mcp_client import MCPManager @@ -64,11 +71,7 @@ class BufferCLI: def _init_llm(self): """初始化 LLM 服务 - 使用主项目配置系统""" thinking_env = os.getenv("ENABLE_THINKING", "").strip().lower() - enable_thinking: Optional[bool] = ( - True if thinking_env == "true" - else False if thinking_env == "false" - else None - ) + enable_thinking: Optional[bool] = True if thinking_env == "true" else False if thinking_env == "false" else None # MaiSakaLLMService 现在使用主项目的配置系统 # 参数仅为兼容性保留,实际从 config_manager 读取配置 @@ -210,7 +213,7 @@ class BufferCLI: to_compress, store_result_callback=lambda cat_id, cat_name, content: console.print( f"[muted] [OK] 存储了解信息: {cat_name}[/muted]" - ) + ), ) if knowledge_count > 0: console.print(f"[success][OK] 了解模块: 存储{knowledge_count}条特征信息[/success]") @@ -272,10 +275,12 @@ class BufferCLI: self._chat_history = self.llm_service.build_chat_context(user_text) else: # 后续对话:追加用户消息到已有上下文 - self._chat_history.append({ - "role": "user", - "content": user_text, - }) + self._chat_history.append( + { + "role": "user", + "content": user_text, + } + ) await self._run_llm_loop(self._chat_history) @@ -436,16 +441,17 @@ class BufferCLI: if perception_parts: # 添加感知消息(AI 的感知能力结果) - chat_history.append(build_message( - role="assistant", - content="\n\n".join(perception_parts), - msg_type="perception", - )) + chat_history.append( + build_message( + role="assistant", + content="\n\n".join(perception_parts), + msg_type="perception", + ) + ) else: # 上次没有调用工具,跳过模块分析 console.print("[muted]ℹ️ 上次未调用工具,跳过模块分析[/muted]") - # ── 调用 LLM ── with console.status("[info]💬 AI 正在思考...[/info]", spinner="dots"): try: @@ -540,7 +546,8 @@ class BufferCLI: async def _init_mcp(self): """初始化 MCP 服务器连接,发现并注册外部工具。""" config_path = os.path.join( - os.path.dirname(os.path.abspath(__file__)), "mcp_config.json", + os.path.dirname(os.path.abspath(__file__)), + "mcp_config.json", ) self._mcp_manager = await MCPManager.from_config(config_path) diff --git a/src/maisaka/config.py b/src/maisaka/config.py index 486f1388..c9c95d8a 100644 --- a/src/maisaka/config.py +++ b/src/maisaka/config.py @@ -15,16 +15,20 @@ if str(_root) not in sys.path: # ──────────────────── 从主配置读取 ──────────────────── + def _get_maisaka_config(): """获取 MaiSaka 配置""" try: from src.config.config import config_manager + return config_manager.config.maisaka except Exception: # 如果配置加载失败,返回默认值 from src.config.official_configs import MaiSakaConfig + return MaiSakaConfig() + _maisaka_config = _get_maisaka_config() # ──────────────────── 模块开关配置 ──────────────────── diff --git a/src/maisaka/debug_client.py b/src/maisaka/debug_client.py index f8b5dec3..1488de6b 100644 --- a/src/maisaka/debug_client.py +++ b/src/maisaka/debug_client.py @@ -48,9 +48,7 @@ class DebugViewer: conn = socket.socket(socket.AF_INET, socket.SOCK_STREAM) conn.connect(("127.0.0.1", self._port)) self._conn = conn - console.print( - f"[success]✓ 调试窗口已启动[/success] [muted](port {self._port})[/muted]" - ) + console.print(f"[success]✓ 调试窗口已启动[/success] [muted](port {self._port})[/muted]") return except ConnectionRefusedError: conn.close() diff --git a/src/maisaka/debug_viewer.py b/src/maisaka/debug_viewer.py index 3732c8d3..0a11f84f 100644 --- a/src/maisaka/debug_viewer.py +++ b/src/maisaka/debug_viewer.py @@ -16,10 +16,10 @@ from rich import box console = Console() ROLE_STYLES = { - "system": ("📋", "bold blue"), - "user": ("👤", "bold green"), + "system": ("📋", "bold blue"), + "user": ("👤", "bold green"), "assistant": ("🤖", "bold magenta"), - "tool": ("🔧", "bold yellow"), + "tool": ("🔧", "bold yellow"), } @@ -54,8 +54,10 @@ def format_message(idx: int, msg: dict) -> str: # 正文 if content: - display = content if len(content) <= 3000 else ( - content[:3000] + f"\n[dim]... (截断, 共 {len(content)} 字符)[/dim]" + display = ( + content + if len(content) <= 3000 + else (content[:3000] + f"\n[dim]... (截断, 共 {len(content)} 字符)[/dim]") ) parts.append(display) @@ -88,8 +90,7 @@ def main(): console.print( Panel( - f"[bold cyan]MaiSaka Debug Viewer[/bold cyan]\n" - f"[dim]监听端口: {port} 等待主进程连接...[/dim]", + f"[bold cyan]MaiSaka Debug Viewer[/bold cyan]\n[dim]监听端口: {port} 等待主进程连接...[/dim]", box=box.DOUBLE_EDGE, border_style="cyan", ) @@ -131,8 +132,7 @@ def main(): # ── 标题栏 ── console.print(f"\n{'═' * 90}") console.print( - f"[bold yellow]#{call_count} {label}[/bold yellow] " - f"[dim]({len(messages)} messages)[/dim]" + f"[bold yellow]#{call_count} {label}[/bold yellow] [dim]({len(messages)} messages)[/dim]" ) console.print(f"{'═' * 90}") @@ -144,12 +144,8 @@ def main(): # ── tools 信息 ── if tools: - tool_names = [ - t.get("function", {}).get("name", "?") for t in tools - ] - console.print( - f"\n[dim]可用工具: {', '.join(tool_names)}[/dim]" - ) + tool_names = [t.get("function", {}).get("name", "?") for t in tools] + console.print(f"\n[dim]可用工具: {', '.join(tool_names)}[/dim]") except Exception as e: console.print(f"\n[red]数据处理错误: {e}[/red]") console.print(f"[dim]Payload: {payload}[/dim]") @@ -161,8 +157,12 @@ def main(): console.print("\n[bold cyan]📤 LLM 响应:[/bold cyan]") resp_content = response.get("content", "") if resp_content: - display = resp_content if len(str(resp_content)) <= 3000 else ( - str(resp_content)[:3000] + f"\n[dim]... (截断, 共 {len(str(resp_content))} 字符)[/dim]" + display = ( + resp_content + if len(str(resp_content)) <= 3000 + else ( + str(resp_content)[:3000] + f"\n[dim]... (截断, 共 {len(str(resp_content))} 字符)[/dim]" + ) ) console.print(Panel(display, border_style="cyan", padding=(0, 1))) resp_tool_calls = response.get("tool_calls", []) diff --git a/src/maisaka/knowledge.py b/src/maisaka/knowledge.py index 28aa1993..0352fd14 100644 --- a/src/maisaka/knowledge.py +++ b/src/maisaka/knowledge.py @@ -3,7 +3,7 @@ MaiSaka - 了解模块 负责从对话中提取和存储用户个人特征信息。 """ -from typing import List, Optional +from typing import List from knowledge_store import get_knowledge_store, KNOWLEDGE_CATEGORIES @@ -100,9 +100,7 @@ async def store_knowledge_from_context( try: # 第一步:分析涉及哪些分类 - category_ids = await llm_service.analyze_knowledge_categories( - context_messages, categories_summary - ) + category_ids = await llm_service.analyze_knowledge_categories(context_messages, categories_summary) if not category_ids: return 0 @@ -119,25 +117,19 @@ async def store_knowledge_from_context( if extracted_content: # 存储到了解列表 success = store.add_knowledge( - category_id=category_id, - content=extracted_content, - metadata={"source": "context_compression"} + category_id=category_id, content=extracted_content, metadata={"source": "context_compression"} ) if success: stored_count += 1 if store_result_callback: - store_result_callback( - category_id, - store.get_category_name(category_id), - extracted_content - ) - except Exception as e: + store_result_callback(category_id, store.get_category_name(category_id), extracted_content) + except Exception: # 单个分类失败不影响其他分类 continue return stored_count - except Exception as e: + except Exception: return 0 @@ -165,9 +157,7 @@ async def retrieve_relevant_knowledge( try: # 分析需要哪些分类 - category_ids = await llm_service.analyze_knowledge_need( - chat_history, categories_summary - ) + category_ids = await llm_service.analyze_knowledge_need(chat_history, categories_summary) if not category_ids: return "" diff --git a/src/maisaka/knowledge_store.py b/src/maisaka/knowledge_store.py index 57d7b9bb..f91573d2 100644 --- a/src/maisaka/knowledge_store.py +++ b/src/maisaka/knowledge_store.py @@ -45,9 +45,7 @@ class KnowledgeStore: def __init__(self): """初始化了解存储""" - self._knowledge: Dict[str, List[Dict[str, Any]]] = { - category_id: [] for category_id in KNOWLEDGE_CATEGORIES - } + self._knowledge: Dict[str, List[Dict[str, Any]]] = {category_id: [] for category_id in KNOWLEDGE_CATEGORIES} self._ensure_data_dir() self._load() @@ -58,9 +56,7 @@ class KnowledgeStore: def _load(self): """从文件加载了解数据""" if not KNOWLEDGE_FILE.exists(): - self._knowledge = { - category_id: [] for category_id in KNOWLEDGE_CATEGORIES - } + self._knowledge = {category_id: [] for category_id in KNOWLEDGE_CATEGORIES} return try: @@ -73,9 +69,7 @@ class KnowledgeStore: self._knowledge = loaded except Exception as e: print(f"[warning]加载了解数据失败: {e}[/warning]") - self._knowledge = { - category_id: [] for category_id in KNOWLEDGE_CATEGORIES - } + self._knowledge = {category_id: [] for category_id in KNOWLEDGE_CATEGORIES} def _save(self): """保存了解数据到文件""" diff --git a/src/maisaka/llm_service.py b/src/maisaka/llm_service.py index c2c183bb..60a28f07 100644 --- a/src/maisaka/llm_service.py +++ b/src/maisaka/llm_service.py @@ -4,7 +4,6 @@ MaiSaka LLM 服务 - 使用主项目 LLM 系统 """ import json -import os from dataclasses import dataclass from typing import List, Optional, Literal @@ -34,6 +33,7 @@ MSG_TYPE_FIELD = "_type" @dataclass class ToolCall: """工具调用信息""" + id: str name: str arguments: dict @@ -42,6 +42,7 @@ class ToolCall: @dataclass class ChatResponse: """LLM 对话循环单步响应""" + content: Optional[str] tool_calls: List[ToolCall] raw_message: dict # 可直接追加到对话历史的消息字典 @@ -49,6 +50,7 @@ class ChatResponse: # ──────────────────── 工具函数 ──────────────────── + def build_message(role: str, content: str, msg_type: MessageType = "user", **kwargs) -> dict: """构建消息字典,包含消息类型标记。""" msg = {"role": role, "content": content, MSG_TYPE_FIELD: msg_type, **kwargs} @@ -93,23 +95,18 @@ class MaiSakaLLMService: except Exception: # 如果配置加载失败,使用默认配置 from src.config.model_configs import ModelTaskConfig + self._model_configs = ModelTaskConfig() logger.warning("无法加载主项目模型配置,使用默认配置") # 初始化 LLMRequest 实例(只使用 tool_use 和 replyer) - self._llm_tool_use = LLMRequest( - model_set=self._model_configs.tool_use, - request_type="maisaka_tool_use" - ) + self._llm_tool_use = LLMRequest(model_set=self._model_configs.tool_use, request_type="maisaka_tool_use") # 主对话也使用 tool_use 模型(因为需要工具调用支持) self._llm_chat = self._llm_tool_use # 分析模块也使用 tool_use 模型 self._llm_utils = self._llm_tool_use # 回复生成使用 replyer 模型 - self._llm_replyer = LLMRequest( - model_set=self._model_configs.replyer, - request_type="maisaka_replyer" - ) + self._llm_replyer = LLMRequest(model_set=self._model_configs.replyer, request_type="maisaka_replyer") # 尝试修复数据库 schema(忽略错误) self._try_fix_database_schema() @@ -133,6 +130,7 @@ class MaiSakaLLMService: chat_prompt.add_context("file_tools_section", tools_section if tools_section else "") import asyncio + loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) try: @@ -147,7 +145,9 @@ class MaiSakaLLMService: self._chat_system_prompt = chat_system_prompt # 获取模型名称用于显示 - self._model_name = self._model_configs.tool_use.model_list[0] if self._model_configs.tool_use.model_list else "未配置" + self._model_name = ( + self._model_configs.tool_use.model_list[0] if self._model_configs.tool_use.model_list else "未配置" + ) # 加载子模块提示词 self._emotion_prompt: Optional[str] = None @@ -157,21 +157,22 @@ class MaiSakaLLMService: try: import asyncio + loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) try: - self._emotion_prompt = loop.run_until_complete(prompt_manager.render_prompt( - prompt_manager.get_prompt("maidairy_emotion") - )) - self._cognition_prompt = loop.run_until_complete(prompt_manager.render_prompt( - prompt_manager.get_prompt("maidairy_cognition") - )) - self._timing_prompt = loop.run_until_complete(prompt_manager.render_prompt( - prompt_manager.get_prompt("maidairy_timing") - )) - self._context_summarize_prompt = loop.run_until_complete(prompt_manager.render_prompt( - prompt_manager.get_prompt("maidairy_context_summarize") - )) + self._emotion_prompt = loop.run_until_complete( + prompt_manager.render_prompt(prompt_manager.get_prompt("maidairy_emotion")) + ) + self._cognition_prompt = loop.run_until_complete( + prompt_manager.render_prompt(prompt_manager.get_prompt("maidairy_cognition")) + ) + self._timing_prompt = loop.run_until_complete( + prompt_manager.render_prompt(prompt_manager.get_prompt("maidairy_timing")) + ) + self._context_summarize_prompt = loop.run_until_complete( + prompt_manager.render_prompt(prompt_manager.get_prompt("maidairy_context_summarize")) + ) logger.info("成功加载 MaiSaka 子模块提示词") finally: loop.close() @@ -191,9 +192,7 @@ class MaiSakaLLMService: if "model_api_provider_name" not in columns: # 添加缺失的列 - session.execute(text( - "ALTER TABLE llm_usage ADD COLUMN model_api_provider_name VARCHAR(255)" - )) + session.execute(text("ALTER TABLE llm_usage ADD COLUMN model_api_provider_name VARCHAR(255)")) session.commit() logger.info("数据库 schema 已修复:添加 model_api_provider_name 列") except Exception: @@ -205,7 +204,7 @@ class MaiSakaLLMService: self._extra_tools = list(tools) @staticmethod - def _tool_option_to_dict(tool: 'ToolOption') -> dict: + def _tool_option_to_dict(tool: "ToolOption") -> dict: """将 ToolOption 对象转换为主项目期望的 dict 格式 主项目的 _build_tool_options() 期望的格式: @@ -218,18 +217,8 @@ class MaiSakaLLMService: params = [] if tool.params: for param in tool.params: - params.append(( - param.name, - param.param_type, - param.description, - param.required, - param.enum_values - )) - return { - "name": tool.name, - "description": tool.description, - "parameters": params - } + params.append((param.name, param.param_type, param.description, param.required, param.enum_values)) + return {"name": tool.name, "description": tool.description, "parameters": params} async def chat_loop_step(self, chat_history: List[dict]) -> ChatResponse: """执行对话循环的一步 - 使用 tool_use 模型""" @@ -271,11 +260,13 @@ class MaiSakaLLMService: for tc in msg["tool_calls"]: tc_func = tc.get("function", {}) # 主项目的 ToolCall: call_id, func_name, args - tool_calls_list.append(ToolCallOption( - call_id=tc.get("id", ""), - func_name=tc_func.get("name", ""), - args=json.loads(tc_func.get("arguments", "{}")) if tc_func.get("arguments") else {} - )) + tool_calls_list.append( + ToolCallOption( + call_id=tc.get("id", ""), + func_name=tc_func.get("name", ""), + args=json.loads(tc_func.get("arguments", "{}")) if tc_func.get("arguments") else {}, + ) + ) builder.set_tool_calls(tool_calls_list) elif role == "tool" and "tool_call_id" in msg: builder.add_tool_call(msg["tool_call_id"]) @@ -290,15 +281,17 @@ class MaiSakaLLMService: # 调用 LLM(使用带消息的接口) # 合并内置工具和额外工具(将 ToolOption 对象转换为 dict) - all_tools = [self._tool_option_to_dict(t) for t in get_builtin_tools()] + (self._extra_tools if self._extra_tools else []) + all_tools = [self._tool_option_to_dict(t) for t in get_builtin_tools()] + ( + self._extra_tools if self._extra_tools else [] + ) # 打印消息列表 built_messages = message_factory(None) - print("\n" + "="*60) + print("\n" + "=" * 60) print("MaiSaka LLM Request - chat_loop_step:") for msg in built_messages: print(f" {msg}") - print("="*60 + "\n") + print("=" * 60 + "\n") response, (reasoning, model, tool_calls) = await self._llm_chat.generate_response_with_message_async( message_factory=message_factory, @@ -312,15 +305,17 @@ class MaiSakaLLMService: if tool_calls: for tc in tool_calls: # 主项目的 ToolCall 有 call_id, func_name, args - call_id = tc.call_id if hasattr(tc, 'call_id') else "" - func_name = tc.func_name if hasattr(tc, 'func_name') else "" - args = tc.args if hasattr(tc, 'args') else {} + call_id = tc.call_id if hasattr(tc, "call_id") else "" + func_name = tc.func_name if hasattr(tc, "func_name") else "" + args = tc.args if hasattr(tc, "args") else {} - converted_tool_calls.append(ToolCall( - id=call_id, - name=func_name, - arguments=args, - )) + converted_tool_calls.append( + ToolCall( + id=call_id, + name=func_name, + arguments=args, + ) + ) # 构建原始消息格式(MaiSaka 风格) raw_message = {"role": "assistant", "content": response} @@ -394,10 +389,10 @@ class MaiSakaLLMService: prompt = "\n".join(prompt_parts) - print("\n" + "="*60) + print("\n" + "=" * 60) print("MaiSaka LLM Request - analyze_emotion:") print(f" {prompt}") - print("="*60 + "\n") + print("=" * 60 + "\n") try: response, _ = await self._llm_utils.generate_response_async( @@ -428,10 +423,10 @@ class MaiSakaLLMService: prompt = "\n".join(prompt_parts) - print("\n" + "="*60) + print("\n" + "=" * 60) print("MaiSaka LLM Request - analyze_cognition:") print(f" {prompt}") - print("="*60 + "\n") + print("=" * 60 + "\n") try: response, _ = await self._llm_utils.generate_response_async( @@ -463,10 +458,10 @@ class MaiSakaLLMService: prompt = "\n".join(prompt_parts) - print("\n" + "="*60) + print("\n" + "=" * 60) print("MaiSaka LLM Request - analyze_timing:") print(f" {prompt}") - print("="*60 + "\n") + print("=" * 60 + "\n") try: response, _ = await self._llm_utils.generate_response_async( @@ -498,10 +493,10 @@ class MaiSakaLLMService: prompt = "\n".join(prompt_parts) - print("\n" + "="*60) + print("\n" + "=" * 60) print("MaiSaka LLM Request - summarize_context:") print(f" {prompt}") - print("="*60 + "\n") + print("=" * 60 + "\n") try: response, _ = await self._llm_utils.generate_response_async( @@ -529,8 +524,7 @@ class MaiSakaLLMService: # 格式化对话历史 filtered_history = [ - msg for msg in chat_history - if msg.get("role") != "system" and msg.get("_type") != "perception" + msg for msg in chat_history if msg.get("role") != "system" and msg.get("_type") != "perception" ] formatted_history = format_chat_history(filtered_history) @@ -542,18 +536,15 @@ class MaiSakaLLMService: system_prompt = "你是一个友好的 AI 助手,请根据用户的想法生成自然的回复。" user_prompt = ( - f"当前时间:{current_time}\n\n" - f"【聊天记录】\n{formatted_history}\n\n" - f"【你的想法】\n{reason}\n\n" - f"现在,你说:" + f"当前时间:{current_time}\n\n【聊天记录】\n{formatted_history}\n\n【你的想法】\n{reason}\n\n现在,你说:" ) messages = f"System: {system_prompt}\n\nUser: {user_prompt}" - print("\n" + "="*60) + print("\n" + "=" * 60) print("MaiSaka LLM Request - generate_reply:") print(f" {messages}") - print("="*60 + "\n") + print("=" * 60 + "\n") try: response, _ = await self._llm_replyer.generate_response_async( diff --git a/src/maisaka/mcp_client/config.py b/src/maisaka/mcp_client/config.py index 0ed45de5..5803d557 100644 --- a/src/maisaka/mcp_client/config.py +++ b/src/maisaka/mcp_client/config.py @@ -95,9 +95,7 @@ def load_mcp_config(config_path: str = "mcp_config.json") -> list[MCPServerConfi ) if server.transport_type == "unknown": - console.print( - f"[warning]⚠️ MCP 服务器 '{name}' 缺少 command 或 url,已跳过[/warning]" - ) + console.print(f"[warning]⚠️ MCP 服务器 '{name}' 缺少 command 或 url,已跳过[/warning]") continue configs.append(server) diff --git a/src/maisaka/mcp_client/connection.py b/src/maisaka/mcp_client/connection.py index d3e3999e..d7d92df7 100644 --- a/src/maisaka/mcp_client/connection.py +++ b/src/maisaka/mcp_client/connection.py @@ -63,9 +63,7 @@ class MCPConnection: True 表示连接成功,False 表示失败。 """ if not MCP_AVAILABLE: - console.print( - "[warning]⚠️ 未安装 mcp SDK,请运行: pip install mcp[/warning]" - ) + console.print("[warning]⚠️ 未安装 mcp SDK,请运行: pip install mcp[/warning]") return False try: @@ -76,15 +74,11 @@ class MCPConnection: elif self.config.transport_type == "sse": read_stream, write_stream = await self._connect_sse() else: - console.print( - f"[warning]MCP '{self.config.name}': 未知传输类型[/warning]" - ) + console.print(f"[warning]MCP '{self.config.name}': 未知传输类型[/warning]") return False # 创建并初始化 MCP 会话 - self.session = await self._exit_stack.enter_async_context( - ClientSession(read_stream, write_stream) - ) + self.session = await self._exit_stack.enter_async_context(ClientSession(read_stream, write_stream)) await self.session.initialize() # 发现工具 @@ -94,9 +88,7 @@ class MCPConnection: return True except Exception as e: - console.print( - f"[warning]⚠️ MCP 服务器 '{self.config.name}' 连接失败: {e}[/warning]" - ) + console.print(f"[warning]⚠️ MCP 服务器 '{self.config.name}' 连接失败: {e}[/warning]") await self.close() return False @@ -107,19 +99,13 @@ class MCPConnection: args=self.config.args, env=self.config.env, ) - return await self._exit_stack.enter_async_context( - stdio_client(params) - ) + return await self._exit_stack.enter_async_context(stdio_client(params)) async def _connect_sse(self): """建立 SSE 传输连接。""" if not SSE_AVAILABLE: - raise ImportError( - "SSE 传输需要额外依赖,请运行: pip install mcp[sse]" - ) - return await self._exit_stack.enter_async_context( - sse_client(url=self.config.url, headers=self.config.headers) - ) + raise ImportError("SSE 传输需要额外依赖,请运行: pip install mcp[sse]") + return await self._exit_stack.enter_async_context(sse_client(url=self.config.url, headers=self.config.headers)) async def call_tool(self, tool_name: str, arguments: dict) -> str: """ diff --git a/src/maisaka/mcp_client/manager.py b/src/maisaka/mcp_client/manager.py index 42f2257f..ba46c707 100644 --- a/src/maisaka/mcp_client/manager.py +++ b/src/maisaka/mcp_client/manager.py @@ -10,10 +10,16 @@ from .config import MCPServerConfig, load_mcp_config from .connection import MCPConnection, MCP_AVAILABLE # 内置工具名称集合 —— MCP 工具不允许与这些名称冲突 -BUILTIN_TOOL_NAMES = frozenset({ - "say", "wait", "stop", - "create_table", "list_tables", "view_table", -}) +BUILTIN_TOOL_NAMES = frozenset( + { + "say", + "wait", + "stop", + "create_table", + "list_tables", + "view_table", + } +) class MCPManager: @@ -35,7 +41,8 @@ class MCPManager: @classmethod async def from_config( - cls, config_path: str = "mcp_config.json", + cls, + config_path: str = "mcp_config.json", ) -> Optional["MCPManager"]: """ 从配置文件创建并初始化 MCPManager。 @@ -51,10 +58,7 @@ class MCPManager: return None if not MCP_AVAILABLE: - console.print( - "[warning]⚠️ 发现 MCP 配置但未安装 mcp SDK," - "请运行: pip install mcp[/warning]" - ) + console.print("[warning]⚠️ 发现 MCP 配置但未安装 mcp SDK,请运行: pip install mcp[/warning]") return None manager = cls() @@ -85,8 +89,7 @@ class MCPManager: if tool_name in BUILTIN_TOOL_NAMES: console.print( - f"[warning]⚠️ MCP 工具 '{tool_name}' " - f"(来自 {cfg.name}) 与内置工具冲突,已跳过[/warning]" + f"[warning]⚠️ MCP 工具 '{tool_name}' (来自 {cfg.name}) 与内置工具冲突,已跳过[/warning]" ) continue @@ -102,8 +105,7 @@ class MCPManager: registered += 1 console.print( - f"[success]✓ MCP 服务器 '{cfg.name}' 已连接[/success] " - f"[muted]({registered} 个工具已注册)[/muted]" + f"[success]✓ MCP 服务器 '{cfg.name}' 已连接[/success] [muted]({registered} 个工具已注册)[/muted]" ) # ──────── 工具发现 ──────── @@ -134,17 +136,16 @@ class MCPManager: # 移除 $schema 字段(部分 MCP 服务器会带上,OpenAI 不接受) parameters.pop("$schema", None) - tools.append({ - "type": "function", - "function": { - "name": tool.name, - "description": ( - tool.description - or f"MCP tool from {server_name}" - ), - "parameters": parameters, - }, - }) + tools.append( + { + "type": "function", + "function": { + "name": tool.name, + "description": (tool.description or f"MCP tool from {server_name}"), + "parameters": parameters, + }, + } + ) return tools @@ -184,9 +185,9 @@ class MCPManager: parts: list[str] = [] for server_name, conn in self._connections.items(): tool_names = [ - t.name for t in conn.tools - if t.name in self._tool_to_server - and self._tool_to_server[t.name] == server_name + t.name + for t in conn.tools + if t.name in self._tool_to_server and self._tool_to_server[t.name] == server_name ] if tool_names: parts.append(f" • {server_name}: {', '.join(tool_names)}") diff --git a/src/maisaka/timing.py b/src/maisaka/timing.py index 63d0a167..522511de 100644 --- a/src/maisaka/timing.py +++ b/src/maisaka/timing.py @@ -49,8 +49,7 @@ def build_timing_info( if len(user_input_times) >= 2: intervals = [ - (user_input_times[i] - user_input_times[i - 1]).total_seconds() - for i in range(1, len(user_input_times)) + (user_input_times[i] - user_input_times[i - 1]).total_seconds() for i in range(1, len(user_input_times)) ] avg_interval = sum(intervals) / len(intervals) parts.append(f"用户平均回复间隔: {int(avg_interval)}秒") diff --git a/src/maisaka/tool_handlers.py b/src/maisaka/tool_handlers.py index ff4220d7..210bfc97 100644 --- a/src/maisaka/tool_handlers.py +++ b/src/maisaka/tool_handlers.py @@ -4,7 +4,6 @@ MaiSaka - 工具调用处理器 """ import json as _json -import asyncio import os from datetime import datetime from typing import TYPE_CHECKING, Optional @@ -92,27 +91,33 @@ async def handle_say(tc, chat_history: list, ctx: ToolHandlerContext): ) ) # 生成的回复作为 tool 结果写入上下文 - chat_history.append({ - "role": "tool", - "tool_call_id": tc.id, - "content": f"已向用户展示(实际输出):{reply}", - }) + chat_history.append( + { + "role": "tool", + "tool_call_id": tc.id, + "content": f"已向用户展示(实际输出):{reply}", + } + ) else: - chat_history.append({ - "role": "tool", - "tool_call_id": tc.id, - "content": "reason 内容为空,未展示", - }) + chat_history.append( + { + "role": "tool", + "tool_call_id": tc.id, + "content": "reason 内容为空,未展示", + } + ) async def handle_stop(tc, chat_history: list): """处理 stop 工具:结束对话循环。""" console.print("[accent]🔧 调用工具: stop()[/accent]") - chat_history.append({ - "role": "tool", - "tool_call_id": tc.id, - "content": "对话循环已停止,等待用户下次输入。", - }) + chat_history.append( + { + "role": "tool", + "tool_call_id": tc.id, + "content": "对话循环已停止,等待用户下次输入。", + } + ) async def handle_wait(tc, chat_history: list, ctx: ToolHandlerContext) -> str: @@ -128,11 +133,13 @@ async def handle_wait(tc, chat_history: list, ctx: ToolHandlerContext) -> str: tool_result = await _do_wait(seconds, ctx) - chat_history.append({ - "role": "tool", - "tool_call_id": tc.id, - "content": tool_result, - }) + chat_history.append( + { + "role": "tool", + "tool_call_id": tc.id, + "content": tool_result, + } + ) return tool_result @@ -193,28 +200,32 @@ async def handle_mcp_tool(tc, chat_history: list, mcp_manager: "MCPManager"): ) ) - chat_history.append({ - "role": "tool", - "tool_call_id": tc.id, - "content": result, - }) + chat_history.append( + { + "role": "tool", + "tool_call_id": tc.id, + "content": result, + } + ) async def handle_unknown_tool(tc, chat_history: list): """处理未知工具调用。""" console.print(f"[accent]🔧 调用工具: {tc.name}({tc.arguments})[/accent]") - chat_history.append({ - "role": "tool", - "tool_call_id": tc.id, - "content": f"未知工具: {tc.name}", - }) + chat_history.append( + { + "role": "tool", + "tool_call_id": tc.id, + "content": f"未知工具: {tc.name}", + } + ) async def handle_write_file(tc, chat_history: list): """处理 write_file 工具:在 mai_files 目录下写入文件。""" filename = tc.arguments.get("filename", "") content = tc.arguments.get("content", "") - console.print(f"[accent]🔧 调用工具: write_file(\"{filename}\")[/accent]") + console.print(f'[accent]🔧 调用工具: write_file("{filename}")[/accent]') # 确保目录存在 MAI_FILES_DIR.mkdir(parents=True, exist_ok=True) @@ -242,25 +253,29 @@ async def handle_write_file(tc, chat_history: list): ) ) - chat_history.append({ - "role": "tool", - "tool_call_id": tc.id, - "content": f"文件「{filename}」已成功写入,共 {file_size} 个字符。", - }) + chat_history.append( + { + "role": "tool", + "tool_call_id": tc.id, + "content": f"文件「{filename}」已成功写入,共 {file_size} 个字符。", + } + ) except Exception as e: error_msg = f"写入文件失败: {e}" console.print(f"[error]{error_msg}[/error]") - chat_history.append({ - "role": "tool", - "tool_call_id": tc.id, - "content": error_msg, - }) + chat_history.append( + { + "role": "tool", + "tool_call_id": tc.id, + "content": error_msg, + } + ) async def handle_read_file(tc, chat_history: list): """处理 read_file 工具:读取 mai_files 目录下的文件。""" filename = tc.arguments.get("filename", "") - console.print(f"[accent]🔧 调用工具: read_file(\"{filename}\")[/accent]") + console.print(f'[accent]🔧 调用工具: read_file("{filename}")[/accent]') # 构建完整文件路径 file_path = MAI_FILES_DIR / filename @@ -269,21 +284,25 @@ async def handle_read_file(tc, chat_history: list): if not file_path.exists(): error_msg = f"文件「{filename}」不存在。" console.print(f"[warning]{error_msg}[/warning]") - chat_history.append({ - "role": "tool", - "tool_call_id": tc.id, - "content": error_msg, - }) + chat_history.append( + { + "role": "tool", + "tool_call_id": tc.id, + "content": error_msg, + } + ) return if not file_path.is_file(): error_msg = f"「{filename}」不是一个文件。" console.print(f"[warning]{error_msg}[/warning]") - chat_history.append({ - "role": "tool", - "tool_call_id": tc.id, - "content": error_msg, - }) + chat_history.append( + { + "role": "tool", + "tool_call_id": tc.id, + "content": error_msg, + } + ) return # 读取文件内容 @@ -304,19 +323,23 @@ async def handle_read_file(tc, chat_history: list): ) ) - chat_history.append({ - "role": "tool", - "tool_call_id": tc.id, - "content": f"文件「{filename}」内容:\n{file_content}", - }) + chat_history.append( + { + "role": "tool", + "tool_call_id": tc.id, + "content": f"文件「{filename}」内容:\n{file_content}", + } + ) except Exception as e: error_msg = f"读取文件失败: {e}" console.print(f"[error]{error_msg}[/error]") - chat_history.append({ - "role": "tool", - "tool_call_id": tc.id, - "content": error_msg, - }) + chat_history.append( + { + "role": "tool", + "tool_call_id": tc.id, + "content": error_msg, + } + ) async def handle_list_files(tc, chat_history: list): @@ -334,11 +357,13 @@ async def handle_list_files(tc, chat_history: list): # 获取相对路径 rel_path = item.relative_to(MAI_FILES_DIR) stat = item.stat() - files_info.append({ - "name": str(rel_path), - "size": stat.st_size, - "modified": datetime.fromtimestamp(stat.st_mtime).strftime("%Y-%m-%d %H:%M:%S"), - }) + files_info.append( + { + "name": str(rel_path), + "size": stat.st_size, + "modified": datetime.fromtimestamp(stat.st_mtime).strftime("%Y-%m-%d %H:%M:%S"), + } + ) if not files_info: result_text = "mai_files 目录为空,没有任何文件。" @@ -360,19 +385,23 @@ async def handle_list_files(tc, chat_history: list): ) ) - chat_history.append({ - "role": "tool", - "tool_call_id": tc.id, - "content": result_text, - }) + chat_history.append( + { + "role": "tool", + "tool_call_id": tc.id, + "content": result_text, + } + ) except Exception as e: error_msg = f"获取文件列表失败: {e}" console.print(f"[error]{error_msg}[/error]") - chat_history.append({ - "role": "tool", - "tool_call_id": tc.id, - "content": error_msg, - }) + chat_history.append( + { + "role": "tool", + "tool_call_id": tc.id, + "content": error_msg, + } + ) async def handle_store_context(tc, chat_history: list, ctx: ToolHandlerContext): @@ -385,16 +414,18 @@ async def handle_store_context(tc, chat_history: list, ctx: ToolHandlerContext): """ count = tc.arguments.get("count", 0) reason = tc.arguments.get("reason", "") - console.print(f"[accent]🔧 调用工具: store_context(count={count}, reason=\"{reason}\")[/accent]") + console.print(f'[accent]🔧 调用工具: store_context(count={count}, reason="{reason}")[/accent]') if count <= 0: error_msg = "count 参数必须大于 0" console.print(f"[error]{error_msg}[/error]") - chat_history.append({ - "role": "tool", - "tool_call_id": tc.id, - "content": error_msg, - }) + chat_history.append( + { + "role": "tool", + "tool_call_id": tc.id, + "content": error_msg, + } + ) return # 计算实际消息数量(排除 role=tool 的工具返回消息) @@ -423,9 +454,7 @@ async def handle_store_context(tc, chat_history: list, ctx: ToolHandlerContext): if role == "assistant" and "tool_calls" in msg: # 检查这个消息是否包含当前的 tool_call(store_context 自己) # 如果包含,跳过不删除(否则会导致 tool 响应孤儿) - contains_current_call = any( - tc.get("id") == tc.id for tc in msg.get("tool_calls", []) - ) + contains_current_call = any(tc.get("id") == tc.id for tc in msg.get("tool_calls", [])) if contains_current_call: i += 1 continue @@ -453,11 +482,13 @@ async def handle_store_context(tc, chat_history: list, ctx: ToolHandlerContext): if not indices_to_remove: result_msg = "没有找到可存入记忆的消息" - chat_history.append({ - "role": "tool", - "tool_call_id": tc.id, - "content": result_msg, - }) + chat_history.append( + { + "role": "tool", + "tool_call_id": tc.id, + "content": result_msg, + } + ) return # 收集要总结的消息(在删除前) @@ -516,38 +547,45 @@ async def handle_store_context(tc, chat_history: list, ctx: ToolHandlerContext): chat_history.pop(i) i -= 1 - chat_history.append({ - "role": "tool", - "tool_call_id": tc.id, - "content": result_msg, - }) + chat_history.append( + { + "role": "tool", + "tool_call_id": tc.id, + "content": result_msg, + } + ) async def handle_get_qq_chat_info(tc, chat_history: list): """处理 get_qq_chat_info 工具:通过 HTTP 获取 QQ 聊天内容。""" chat = tc.arguments.get("chat", "") limit = tc.arguments.get("limit", 20) - console.print(f"[accent]🔧 调用工具: get_qq_chat_info(\"{chat}\", limit={limit})[/accent]") + console.print(f'[accent]🔧 调用工具: get_qq_chat_info("{chat}", limit={limit})[/accent]') if not AIOHTTP_AVAILABLE: error_msg = "aiohttp 模块未安装,请运行: pip install aiohttp" console.print(f"[error]{error_msg}[/error]") - chat_history.append({ - "role": "tool", - "tool_call_id": tc.id, - "content": error_msg, - }) + chat_history.append( + { + "role": "tool", + "tool_call_id": tc.id, + "content": error_msg, + } + ) return from config import QQ_API_BASE_URL, QQ_API_KEY + if not QQ_API_BASE_URL: error_msg = "QQ_API_BASE_URL 未配置,请在 .env 中设置" console.print(f"[error]{error_msg}[/error]") - chat_history.append({ - "role": "tool", - "tool_call_id": tc.id, - "content": error_msg, - }) + chat_history.append( + { + "role": "tool", + "tool_call_id": tc.id, + "content": error_msg, + } + ) return try: @@ -577,55 +615,66 @@ async def handle_get_qq_chat_info(tc, chat_history: list): ) ) - chat_history.append({ - "role": "tool", - "tool_call_id": tc.id, - "content": text if text.strip() else "暂无聊天记录", - }) + chat_history.append( + { + "role": "tool", + "tool_call_id": tc.id, + "content": text if text.strip() else "暂无聊天记录", + } + ) else: error_text = await response.text() error_msg = f"HTTP 请求失败 (状态码 {response.status}): {error_text}" console.print(f"[error]{error_msg}[/error]") - chat_history.append({ - "role": "tool", - "tool_call_id": tc.id, - "content": error_msg, - }) + chat_history.append( + { + "role": "tool", + "tool_call_id": tc.id, + "content": error_msg, + } + ) except Exception as e: error_msg = f"获取 QQ 聊天记录失败: {e}" console.print(f"[error]{error_msg}[/error]") - chat_history.append({ - "role": "tool", - "tool_call_id": tc.id, - "content": error_msg, - }) + chat_history.append( + { + "role": "tool", + "tool_call_id": tc.id, + "content": error_msg, + } + ) async def handle_send_info(tc, chat_history: list): """处理 send_info 工具:通过 HTTP 发送消息到 QQ。""" chat = tc.arguments.get("chat", "") message = tc.arguments.get("message", "") - console.print(f"[accent]🔧 调用工具: send_info(\"{chat}\")[/accent]") + console.print(f'[accent]🔧 调用工具: send_info("{chat}")[/accent]') if not AIOHTTP_AVAILABLE: error_msg = "aiohttp 模块未安装,请运行: pip install aiohttp" console.print(f"[error]{error_msg}[/error]") - chat_history.append({ - "role": "tool", - "tool_call_id": tc.id, - "content": error_msg, - }) + chat_history.append( + { + "role": "tool", + "tool_call_id": tc.id, + "content": error_msg, + } + ) return from config import QQ_API_BASE_URL, QQ_API_KEY + if not QQ_API_BASE_URL: error_msg = "QQ_API_BASE_URL 未配置,请在 .env 中设置" console.print(f"[error]{error_msg}[/error]") - chat_history.append({ - "role": "tool", - "tool_call_id": tc.id, - "content": error_msg, - }) + chat_history.append( + { + "role": "tool", + "tool_call_id": tc.id, + "content": error_msg, + } + ) return try: @@ -654,27 +703,33 @@ async def handle_send_info(tc, chat_history: list): ) ) - chat_history.append({ - "role": "tool", - "tool_call_id": tc.id, - "content": f"消息发送成功: {data.get('message', '发送成功')}", - }) + chat_history.append( + { + "role": "tool", + "tool_call_id": tc.id, + "content": f"消息发送成功: {data.get('message', '发送成功')}", + } + ) else: error_msg = f"发送失败: {data.get('message', '未知错误')}" console.print(f"[error]{error_msg}[/error]") - chat_history.append({ - "role": "tool", - "tool_call_id": tc.id, - "content": error_msg, - }) + chat_history.append( + { + "role": "tool", + "tool_call_id": tc.id, + "content": error_msg, + } + ) except Exception as e: error_msg = f"发送消息失败: {e}" console.print(f"[error]{error_msg}[/error]") - chat_history.append({ - "role": "tool", - "tool_call_id": tc.id, - "content": error_msg, - }) + chat_history.append( + { + "role": "tool", + "tool_call_id": tc.id, + "content": error_msg, + } + ) async def handle_list_qq_chats(tc, chat_history: list): @@ -684,22 +739,27 @@ async def handle_list_qq_chats(tc, chat_history: list): if not AIOHTTP_AVAILABLE: error_msg = "aiohttp 模块未安装,请运行: pip install aiohttp" console.print(f"[error]{error_msg}[/error]") - chat_history.append({ - "role": "tool", - "tool_call_id": tc.id, - "content": error_msg, - }) + chat_history.append( + { + "role": "tool", + "tool_call_id": tc.id, + "content": error_msg, + } + ) return from config import QQ_API_BASE_URL, QQ_API_KEY + if not QQ_API_BASE_URL: error_msg = "QQ_API_BASE_URL 未配置,请在 .env 中设置" console.print(f"[error]{error_msg}[/error]") - chat_history.append({ - "role": "tool", - "tool_call_id": tc.id, - "content": error_msg, - }) + chat_history.append( + { + "role": "tool", + "tool_call_id": tc.id, + "content": error_msg, + } + ) return try: @@ -721,10 +781,12 @@ async def handle_list_qq_chats(tc, chat_history: list): # 格式化聊天列表 if chats: - chat_list_text = "\n".join([ - f" • [{c.get('platform', 'qq')}] {c.get('name', '未知')} (chat: {c.get('chat', 'N/A')})" - for c in chats - ]) + chat_list_text = "\n".join( + [ + f" • [{c.get('platform', 'qq')}] {c.get('name', '未知')} (chat: {c.get('chat', 'N/A')})" + for c in chats + ] + ) result_text = f"可用的聊天 (共 {len(chats)} 个):\n{chat_list_text}" else: result_text = "没有可用的聊天" @@ -738,27 +800,33 @@ async def handle_list_qq_chats(tc, chat_history: list): ) ) - chat_history.append({ - "role": "tool", - "tool_call_id": tc.id, - "content": result_text, - }) + chat_history.append( + { + "role": "tool", + "tool_call_id": tc.id, + "content": result_text, + } + ) else: error_msg = f"获取失败: {data.get('message', '未知错误')}" console.print(f"[error]{error_msg}[/error]") - chat_history.append({ - "role": "tool", - "tool_call_id": tc.id, - "content": error_msg, - }) + chat_history.append( + { + "role": "tool", + "tool_call_id": tc.id, + "content": error_msg, + } + ) except Exception as e: error_msg = f"获取聊天列表失败: {e}" console.print(f"[error]{error_msg}[/error]") - chat_history.append({ - "role": "tool", - "tool_call_id": tc.id, - "content": error_msg, - }) + chat_history.append( + { + "role": "tool", + "tool_call_id": tc.id, + "content": error_msg, + } + ) # ──────────────────── 初始化 mai_files 目录 ──────────────────── diff --git a/src/memory_system/memory_retrieval.py b/src/memory_system/memory_retrieval.py index 46c56f0a..4193a16a 100644 --- a/src/memory_system/memory_retrieval.py +++ b/src/memory_system/memory_retrieval.py @@ -847,11 +847,7 @@ def _get_recent_found_answers(chat_id: str, time_window_seconds: float = 600.0) if not records: return [] - return [ - f"问题:{record.question}\n答案:{record.answer}" - for record in records - if record.answer - ] + return [f"问题:{record.question}\n答案:{record.answer}" for record in records if record.answer] except Exception as e: logger.error(f"获取最近已找到答案的记录失败: {e}") diff --git a/src/plugin_runtime/__init__.py b/src/plugin_runtime/__init__.py index a4a6768b..940a4f2b 100644 --- a/src/plugin_runtime/__init__.py +++ b/src/plugin_runtime/__init__.py @@ -13,4 +13,3 @@ ENV_SESSION_TOKEN = "MAIBOT_SESSION_TOKEN" ENV_PLUGIN_DIRS = "MAIBOT_PLUGIN_DIRS" """Runner 需要加载的插件目录列表(os.pathsep 分隔)""" - diff --git a/src/plugin_runtime/host/capability_service.py b/src/plugin_runtime/host/capability_service.py index 3f3a000c..6685ff60 100644 --- a/src/plugin_runtime/host/capability_service.py +++ b/src/plugin_runtime/host/capability_service.py @@ -67,7 +67,9 @@ class CapabilityService: # 1. 权限校验 allowed, reason = self._policy.check_capability(plugin_id, capability, envelope.generation) if not allowed: - error_code = ErrorCode.E_GENERATION_MISMATCH if "generation 不匹配" in reason else ErrorCode.E_CAPABILITY_DENIED + error_code = ( + ErrorCode.E_GENERATION_MISMATCH if "generation 不匹配" in reason else ErrorCode.E_CAPABILITY_DENIED + ) return envelope.make_error_response( error_code.value, reason, diff --git a/src/plugin_runtime/host/component_registry.py b/src/plugin_runtime/host/component_registry.py index 4abbd08e..220a19c0 100644 --- a/src/plugin_runtime/host/component_registry.py +++ b/src/plugin_runtime/host/component_registry.py @@ -22,8 +22,13 @@ class RegisteredComponent: """已注册的组件条目""" __slots__ = ( - "name", "full_name", "component_type", "plugin_id", - "metadata", "enabled", "_compiled_pattern", + "name", + "full_name", + "component_type", + "plugin_id", + "metadata", + "enabled", + "_compiled_pattern", ) def __init__( @@ -165,18 +170,14 @@ class ComponentRegistry: """按全名查询。""" return self._components.get(full_name) - def get_components_by_type( - self, component_type: str, *, enabled_only: bool = True - ) -> List[RegisteredComponent]: + def get_components_by_type(self, component_type: str, *, enabled_only: bool = True) -> List[RegisteredComponent]: """按类型查询。""" type_dict = self._by_type.get(component_type, {}) if enabled_only: return [c for c in type_dict.values() if c.enabled] return list(type_dict.values()) - def get_components_by_plugin( - self, plugin_id: str, *, enabled_only: bool = True - ) -> List[RegisteredComponent]: + def get_components_by_plugin(self, plugin_id: str, *, enabled_only: bool = True) -> List[RegisteredComponent]: """按插件查询。""" comps = self._by_plugin.get(plugin_id, []) return [c for c in comps if c.enabled] if enabled_only else list(comps) @@ -200,9 +201,7 @@ class ComponentRegistry: return comp, {} return None - def get_event_handlers( - self, event_type: str, *, enabled_only: bool = True - ) -> List[RegisteredComponent]: + def get_event_handlers(self, event_type: str, *, enabled_only: bool = True) -> List[RegisteredComponent]: """获取特定事件类型的所有 event_handler,按 weight 降序排列。""" handlers = [] for comp in self._by_type.get("event_handler", {}).values(): @@ -213,9 +212,7 @@ class ComponentRegistry: handlers.sort(key=lambda c: c.metadata.get("weight", 0), reverse=True) return handlers - def get_workflow_steps( - self, stage: str, *, enabled_only: bool = True - ) -> List[RegisteredComponent]: + def get_workflow_steps(self, stage: str, *, enabled_only: bool = True) -> List[RegisteredComponent]: """获取特定 workflow 阶段的所有步骤,按 priority 降序。""" steps = [] for comp in self._by_type.get("workflow_step", {}).values(): diff --git a/src/plugin_runtime/host/event_dispatcher.py b/src/plugin_runtime/host/event_dispatcher.py index a9b5793c..48dfcedd 100644 --- a/src/plugin_runtime/host/event_dispatcher.py +++ b/src/plugin_runtime/host/event_dispatcher.py @@ -22,6 +22,7 @@ InvokeFn = Callable[[str, str, Dict[str, Any]], Awaitable[Dict[str, Any]]] class EventResult: """单个 EventHandler 的执行结果""" + __slots__ = ("handler_name", "success", "continue_processing", "modified_message", "custom_result") def __init__( @@ -107,9 +108,7 @@ class EventDispatcher: modified_message = result.modified_message else: # 非阻塞:保持实例级强引用,防止 task 被 GC 回收 - task = asyncio.create_task( - self._invoke_handler(invoke_fn, handler, args, event_type) - ) + task = asyncio.create_task(self._invoke_handler(invoke_fn, handler, args, event_type)) self._background_tasks.add(task) task.add_done_callback(self._background_tasks.discard) diff --git a/src/plugin_runtime/host/policy_engine.py b/src/plugin_runtime/host/policy_engine.py index 9ce7ccde..c0425cbb 100644 --- a/src/plugin_runtime/host/policy_engine.py +++ b/src/plugin_runtime/host/policy_engine.py @@ -11,6 +11,7 @@ from typing import Dict, List, Optional, Set, Tuple @dataclass class CapabilityToken: """能力令牌""" + plugin_id: str generation: int capabilities: Set[str] = field(default_factory=set) diff --git a/src/plugin_runtime/host/rpc_server.py b/src/plugin_runtime/host/rpc_server.py index 687f35f6..6b2933d5 100644 --- a/src/plugin_runtime/host/rpc_server.py +++ b/src/plugin_runtime/host/rpc_server.py @@ -231,9 +231,7 @@ class RPCServer: stale_count = 0 for _req_id, future in list(self._pending_requests.items()): if not future.done(): - future.set_exception( - RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 连接已被新 generation 接管") - ) + future.set_exception(RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 连接已被新 generation 接管")) stale_count += 1 self._pending_requests.clear() if stale_count: @@ -399,9 +397,7 @@ class RPCServer: result = await handler(envelope) # 检查 handler 返回的信封是否包含错误信息 if result is not None and isinstance(result, Envelope) and result.error: - logger.warning( - f"事件 {envelope.method} handler 返回错误: {result.error.get('message', '')}" - ) + logger.warning(f"事件 {envelope.method} handler 返回错误: {result.error.get('message', '')}") except Exception as e: logger.error(f"处理事件 {envelope.method} 异常: {e}", exc_info=True) diff --git a/src/plugin_runtime/host/supervisor.py b/src/plugin_runtime/host/supervisor.py index e65a9f66..164d2641 100644 --- a/src/plugin_runtime/host/supervisor.py +++ b/src/plugin_runtime/host/supervisor.py @@ -39,6 +39,7 @@ logger = get_logger("plugin_runtime.host.supervisor") # ─── 日志桥 ────────────────────────────────────────────────────── + class RunnerLogBridge: """将 Runner 进程上报的批量日志重放到主进程的 Logger 中。 @@ -80,9 +81,7 @@ class RunnerLogBridge: stdlib_logging.getLogger(entry.logger_name).handle(record) - return envelope.make_response( - payload={"accepted": True, "count": len(batch.entries)} - ) + return envelope.make_response(payload={"accepted": True, "count": len(batch.entries)}) class PluginSupervisor: @@ -101,8 +100,12 @@ class PluginSupervisor: ): _cfg = global_config.plugin_runtime self._plugin_dirs = plugin_dirs or [] - self._health_interval = health_check_interval_sec if health_check_interval_sec is not None else _cfg.health_check_interval_sec - self._runner_spawn_timeout = runner_spawn_timeout_sec if runner_spawn_timeout_sec is not None else _cfg.runner_spawn_timeout_sec + self._health_interval = ( + health_check_interval_sec if health_check_interval_sec is not None else _cfg.health_check_interval_sec + ) + self._runner_spawn_timeout = ( + runner_spawn_timeout_sec if runner_spawn_timeout_sec is not None else _cfg.runner_spawn_timeout_sec + ) # 基础设施 self._transport = create_transport_server(socket_path=socket_path) @@ -114,6 +117,7 @@ class PluginSupervisor: # 编解码 from src.plugin_runtime.protocol.codec import MsgPackCodec + codec = MsgPackCodec() self._rpc_server = RPCServer( @@ -124,7 +128,9 @@ class PluginSupervisor: # Runner 子进程 self._runner_process: Optional[asyncio.subprocess.Process] = None self._runner_generation: int = 0 - self._max_restart_attempts: int = max_restart_attempts if max_restart_attempts is not None else _cfg.max_restart_attempts + self._max_restart_attempts: int = ( + max_restart_attempts if max_restart_attempts is not None else _cfg.max_restart_attempts + ) self._restart_count: int = 0 # 已注册的插件组件信息 @@ -173,6 +179,7 @@ class PluginSupervisor: extra_args: Optional[Dict[str, Any]] = None, ) -> Tuple[bool, Optional[Dict[str, Any]]]: """分发事件到所有对应 handler 的快捷方法。""" + async def _invoke(plugin_id: str, component_name: str, args: Dict[str, Any]) -> Dict[str, Any]: resp = await self.invoke_plugin( method="plugin.emit_event", @@ -196,6 +203,7 @@ class PluginSupervisor: context: Optional[WorkflowContext] = None, ) -> Tuple[WorkflowResult, Optional[Dict[str, Any]], WorkflowContext]: """执行 Workflow Pipeline 的快捷方法。""" + async def _invoke(plugin_id: str, component_name: str, args: Dict[str, Any]) -> Dict[str, Any]: resp = await self.invoke_plugin( method="plugin.invoke_workflow_step", @@ -415,7 +423,9 @@ class PluginSupervisor: env[ENV_PLUGIN_DIRS] = os.pathsep.join(self._plugin_dirs) self._runner_process = await asyncio.create_subprocess_exec( - sys.executable, "-m", runner_module, + sys.executable, + "-m", + runner_module, env=env, # stdout 不捕获:Runner 的日志均通过 IPC 传㛹(RunnerIPCLogHandler) stdout=None, @@ -557,9 +567,7 @@ class PluginSupervisor: ) self._stderr_drain_task = task task.add_done_callback( - lambda done_task: None - if self._stderr_drain_task is not done_task - else self._clear_stderr_drain_task() + lambda done_task: None if self._stderr_drain_task is not done_task else self._clear_stderr_drain_task() ) def _clear_stderr_drain_task(self) -> None: diff --git a/src/plugin_runtime/host/workflow_executor.py b/src/plugin_runtime/host/workflow_executor.py index 60fe7d07..732a3888 100644 --- a/src/plugin_runtime/host/workflow_executor.py +++ b/src/plugin_runtime/host/workflow_executor.py @@ -44,6 +44,7 @@ HOOK_CONTINUE = "continue" HOOK_SKIP_STAGE = "skip_stage" HOOK_ABORT = "abort" + # blocking hook 全局最大超时(秒):即使 hook 声明 timeout_ms=0 也不会无限等待 # 从配置文件读取,允许用户调整 def _get_blocking_timeout() -> float: @@ -52,6 +53,7 @@ def _get_blocking_timeout() -> float: class ModificationRecord: """消息修改记录""" + __slots__ = ("stage", "hook_name", "timestamp", "fields_changed") def __init__(self, stage: str, hook_name: str, fields_changed: List[str]) -> None: @@ -141,9 +143,7 @@ class WorkflowExecutor: try: # PLAN 阶段: 先做 Command 路由 if stage == "plan" and current_message: - cmd_result = await self._route_command( - command_invoke_fn or invoke_fn, current_message, ctx - ) + cmd_result = await self._route_command(command_invoke_fn or invoke_fn, current_message, ctx) if cmd_result is not None: # 命令匹配成功,跳过 PLAN 阶段的 hook,直接存结果进 stage_outputs ctx.set_stage_output("plan", "command_result", cmd_result) @@ -195,10 +195,10 @@ class WorkflowExecutor: # 更新消息(仅 blocking hook 有权修改) if modified: - changed_fields = _diff_keys(current_message, modified) if current_message else list(modified.keys()) - ctx.modification_log.append( - ModificationRecord(stage, step.full_name, changed_fields) + changed_fields = ( + _diff_keys(current_message, modified) if current_message else list(modified.keys()) ) + ctx.modification_log.append(ModificationRecord(stage, step.full_name, changed_fields)) current_message = modified if hook_result == HOOK_ABORT: @@ -222,9 +222,7 @@ class WorkflowExecutor: if nonblocking_steps and not skip_stage: nb_tasks = [ asyncio.create_task( - self._invoke_step_fire_and_forget( - invoke_fn, step, stage, ctx, current_message - ) + self._invoke_step_fire_and_forget(invoke_fn, step, stage, ctx, current_message) ) for step in nonblocking_steps ] @@ -314,12 +312,16 @@ class WorkflowExecutor: step_start = time.perf_counter() try: - coro = invoke_fn(step.plugin_id, step.name, { - "stage": stage, - "trace_id": ctx.trace_id, - "message": message, - "stage_outputs": ctx.stage_outputs, - }) + coro = invoke_fn( + step.plugin_id, + step.name, + { + "stage": stage, + "trace_id": ctx.trace_id, + "message": message, + "stage_outputs": ctx.stage_outputs, + }, + ) resp = await asyncio.wait_for(coro, timeout=timeout_sec) ctx.timings[step_key] = time.perf_counter() - step_start @@ -355,12 +357,16 @@ class WorkflowExecutor: timeout_sec = timeout_ms / 1000 if timeout_ms > 0 else _get_blocking_timeout() try: - coro = invoke_fn(step.plugin_id, step.name, { - "stage": stage, - "trace_id": ctx.trace_id, - "message": message, - "stage_outputs": ctx.stage_outputs, - }) + coro = invoke_fn( + step.plugin_id, + step.name, + { + "stage": stage, + "trace_id": ctx.trace_id, + "message": message, + "stage_outputs": ctx.stage_outputs, + }, + ) await asyncio.wait_for(coro, timeout=timeout_sec) except asyncio.TimeoutError: logger.warning(f"[{ctx.trace_id}] non-blocking hook {step.full_name} 超时 ({timeout_sec}s)") @@ -393,12 +399,16 @@ class WorkflowExecutor: logger.info(f"[{ctx.trace_id}] 命令匹配: {matched.full_name}") try: - return await invoke_fn(matched.plugin_id, matched.name, { - "text": plain_text, - "message": message, - "trace_id": ctx.trace_id, - "matched_groups": matched_groups, - }) + return await invoke_fn( + matched.plugin_id, + matched.name, + { + "text": plain_text, + "message": message, + "trace_id": ctx.trace_id, + "matched_groups": matched_groups, + }, + ) except Exception as e: logger.error(f"[{ctx.trace_id}] 命令 {matched.full_name} 执行失败: {e}", exc_info=True) ctx.errors.append(f"command:{matched.full_name}: {e}") diff --git a/src/plugin_runtime/integration.py b/src/plugin_runtime/integration.py index d070ef15..4272e315 100644 --- a/src/plugin_runtime/integration.py +++ b/src/plugin_runtime/integration.py @@ -113,9 +113,7 @@ class PluginRuntimeManager: await self._thirdparty_supervisor.start() started_supervisors.append(self._thirdparty_supervisor) self._started = True - logger.info( - f"插件运行时已启动 — 内置: {builtin_dirs or '无'}, 第三方: {thirdparty_dirs or '无'}" - ) + logger.info(f"插件运行时已启动 — 内置: {builtin_dirs or '无'}, 第三方: {thirdparty_dirs or '无'}") except Exception as e: logger.error(f"插件运行时启动失败: {e}", exc_info=True) await asyncio.gather(*(sv.stop() for sv in started_supervisors), return_exceptions=True) @@ -303,7 +301,9 @@ class PluginRuntimeManager: cap_service.register_capability("component.get_all_plugins", self._cap_component_get_all_plugins) cap_service.register_capability("component.get_plugin_info", self._cap_component_get_plugin_info) cap_service.register_capability("component.list_loaded_plugins", self._cap_component_list_loaded_plugins) - cap_service.register_capability("component.list_registered_plugins", self._cap_component_list_registered_plugins) + cap_service.register_capability( + "component.list_registered_plugins", self._cap_component_list_registered_plugins + ) cap_service.register_capability("component.enable", self._cap_component_enable) cap_service.register_capability("component.disable", self._cap_component_disable) cap_service.register_capability("component.load_plugin", self._cap_component_load_plugin) @@ -1232,9 +1232,7 @@ class PluginRuntimeManager: count: int = args.get("count", 1) try: results = await emoji_api.get_random(count=count) - emojis = [ - {"base64": b64, "description": desc, "emotion": emo} for b64, desc, emo in results - ] + emojis = [{"base64": b64, "description": desc, "emotion": emo} for b64, desc, emo in results] return {"success": True, "emojis": emojis} except Exception as e: logger.error(f"[cap.emoji.get_random] 执行失败: {e}", exc_info=True) @@ -1269,9 +1267,9 @@ class PluginRuntimeManager: try: results = await emoji_api.get_all() - emojis = [ - {"base64": b64, "description": desc, "emotion": emo} for b64, desc, emo in results - ] if results else [] + emojis = ( + [{"base64": b64, "description": desc, "emotion": emo} for b64, desc, emo in results] if results else [] + ) return {"success": True, "emojis": emojis} except Exception as e: logger.error(f"[cap.emoji.get_all] 执行失败: {e}", exc_info=True) diff --git a/src/plugin_runtime/protocol/codec.py b/src/plugin_runtime/protocol/codec.py index 578aad2a..83239348 100644 --- a/src/plugin_runtime/protocol/codec.py +++ b/src/plugin_runtime/protocol/codec.py @@ -12,20 +12,16 @@ class Codec(ABC): """消息编解码器基类""" @abstractmethod - def encode_envelope(self, envelope: Envelope) -> bytes: - ... + def encode_envelope(self, envelope: Envelope) -> bytes: ... @abstractmethod - def decode_envelope(self, data: bytes) -> Envelope: - ... + def decode_envelope(self, data: bytes) -> Envelope: ... @abstractmethod - def encode(self, obj: Dict[str, Any]) -> bytes: - ... + def encode(self, obj: Dict[str, Any]) -> bytes: ... @abstractmethod - def decode(self, data: bytes) -> Dict[str, Any]: - ... + def decode(self, data: bytes) -> Dict[str, Any]: ... class MsgPackCodec(Codec): diff --git a/src/plugin_runtime/protocol/envelope.py b/src/plugin_runtime/protocol/envelope.py index 0b276c7c..61a45ed8 100644 --- a/src/plugin_runtime/protocol/envelope.py +++ b/src/plugin_runtime/protocol/envelope.py @@ -24,8 +24,10 @@ MAX_SDK_VERSION = "1.99.99" # ─── 消息类型 ────────────────────────────────────────────────────── + class MessageType(str, Enum): """RPC 消息类型""" + REQUEST = "request" RESPONSE = "response" EVENT = "event" @@ -33,6 +35,7 @@ class MessageType(str, Enum): # ─── 请求 ID 生成器 ─────────────────────────────────────────────── + class RequestIdGenerator: """单调递增 int64 请求 ID 生成器(线程安全由调用方保证或使用 asyncio)""" @@ -47,6 +50,7 @@ class RequestIdGenerator: # ─── Envelope 模型 ───────────────────────────────────────────────── + class Envelope(BaseModel): """RPC 统一信封 @@ -75,7 +79,9 @@ class Envelope(BaseModel): def is_event(self) -> bool: return self.message_type == MessageType.EVENT - def make_response(self, payload: Optional[Dict[str, Any]] = None, error: Optional[Dict[str, Any]] = None) -> "Envelope": + def make_response( + self, payload: Optional[Dict[str, Any]] = None, error: Optional[Dict[str, Any]] = None + ) -> "Envelope": """基于当前请求创建对应的响应信封""" return Envelope( protocol_version=self.protocol_version, @@ -101,8 +107,10 @@ class Envelope(BaseModel): # ─── 握手消息 ────────────────────────────────────────────────────── + class HelloPayload(BaseModel): """runner.hello 握手请求 payload""" + runner_id: str = Field(description="Runner 进程唯一标识") sdk_version: str = Field(description="SDK 版本号") session_token: str = Field(description="一次性会话令牌") @@ -110,6 +118,7 @@ class HelloPayload(BaseModel): class HelloResponsePayload(BaseModel): """runner.hello 握手响应 payload""" + accepted: bool = Field(description="是否接受连接") host_version: str = Field(default="", description="Host 版本号") assigned_generation: int = Field(default=0, description="分配的 generation 编号") @@ -118,8 +127,10 @@ class HelloResponsePayload(BaseModel): # ─── 组件注册消息 ────────────────────────────────────────────────── + class ComponentDeclaration(BaseModel): """单个组件声明""" + name: str = Field(description="组件名称") component_type: str = Field(description="组件类型: action/command/tool/event_handler") plugin_id: str = Field(description="所属插件 ID") @@ -128,6 +139,7 @@ class ComponentDeclaration(BaseModel): class RegisterComponentsPayload(BaseModel): """plugin.register_components 请求 payload""" + plugin_id: str = Field(description="插件 ID") plugin_version: str = Field(default="1.0.0", description="插件版本") components: List[ComponentDeclaration] = Field(default_factory=list, description="组件列表") @@ -136,36 +148,44 @@ class RegisterComponentsPayload(BaseModel): # ─── 调用消息 ────────────────────────────────────────────────────── + class InvokePayload(BaseModel): """plugin.invoke_* 请求 payload""" + component_name: str = Field(description="要调用的组件名称") args: Dict[str, Any] = Field(default_factory=dict, description="调用参数") class InvokeResultPayload(BaseModel): """plugin.invoke_* 响应 payload""" + success: bool = Field(description="是否成功") result: Any = Field(default=None, description="返回值") # ─── 能力调用消息 ────────────────────────────────────────────────── + class CapabilityRequestPayload(BaseModel): """cap.* 请求 payload(插件 -> Host 能力调用)""" + capability: str = Field(description="能力名称,如 send.text, db.query") args: Dict[str, Any] = Field(default_factory=dict, description="调用参数") class CapabilityResponsePayload(BaseModel): """cap.* 响应 payload""" + success: bool = Field(description="是否成功") result: Any = Field(default=None, description="返回值") # ─── 健康检查 ────────────────────────────────────────────────────── + class HealthPayload(BaseModel): """plugin.health 响应 payload""" + healthy: bool = Field(description="是否健康") loaded_plugins: List[str] = Field(default_factory=list, description="已加载的插件列表") uptime_ms: int = Field(default=0, description="运行时长(ms)") @@ -173,11 +193,13 @@ class HealthPayload(BaseModel): # ─── 配置更新 ────────────────────────────────────────────────────── + # TODO: Host 侧尚未实现配置变更检测与推送。Runner 端的 _handle_config_updated # 已就绪,但当前无任何调用方通过 RPC 发送 plugin.config_updated 消息。 # 需要在 Supervisor 或 CapabilityService 中监听配置文件变化并主动推送。 class ConfigUpdatedPayload(BaseModel): """plugin.config_updated 事件 payload""" + plugin_id: str = Field(description="插件 ID") config_version: str = Field(description="新配置版本") config_data: Dict[str, Any] = Field(default_factory=dict, description="配置内容") @@ -185,14 +207,17 @@ class ConfigUpdatedPayload(BaseModel): # ─── 关停 ────────────────────────────────────────────────────────── + class ShutdownPayload(BaseModel): """plugin.shutdown / plugin.prepare_shutdown payload""" + reason: str = Field(default="normal", description="关停原因") drain_timeout_ms: int = Field(default=5000, description="排空超时(ms)") # ─── 日志传输 ────────────────────────────────────────────────────── + class LogEntry(BaseModel): """单条日志记录(Runner → Host 传输格式)""" @@ -200,10 +225,7 @@ class LogEntry(BaseModel): description="日志时间戳,Unix epoch 毫秒", ) level: int = Field( - description=( - "stdlib logging 整数级别:" - " 10=DEBUG, 20=INFO, 30=WARNING, 40=ERROR, 50=CRITICAL" - ), + description=("stdlib logging 整数级别: 10=DEBUG, 20=INFO, 30=WARNING, 40=ERROR, 50=CRITICAL"), ) logger_name: str = Field( description="Logger 名称,如 plugin.my_plugin.submodule", diff --git a/src/plugin_runtime/runner/log_handler.py b/src/plugin_runtime/runner/log_handler.py index 74b4a3e9..b5a0a328 100644 --- a/src/plugin_runtime/runner/log_handler.py +++ b/src/plugin_runtime/runner/log_handler.py @@ -22,6 +22,7 @@ Host 端将其重放到主进程的 Logger(以 plugin. 为名)中, - 后台刷新协程每 FLUSH_INTERVAL_SEC 秒或 FLUSH_BATCH_SIZE 条后批量发送 - IPC 发送失败时静默忽略;stderr fallback 由 supervisor 的 drain task 覆盖 """ + from __future__ import annotations from typing import TYPE_CHECKING, List, Optional @@ -203,6 +204,7 @@ class RunnerIPCLogHandler(logging.Handler): # IPC 连接断开时回退到 stderr,避免日志静默丢失 if not self._rpc_client.is_connected: import sys + for entry in entries: print( f"[LOG-FALLBACK] [{entry.logger_name}] {entry.message}", @@ -218,6 +220,7 @@ class RunnerIPCLogHandler(logging.Handler): ) except Exception: import sys + for entry in entries: print( f"[LOG-FALLBACK] [{entry.logger_name}] {entry.message}", diff --git a/src/plugin_runtime/runner/manifest_validator.py b/src/plugin_runtime/runner/manifest_validator.py index bc407e6d..b6990850 100644 --- a/src/plugin_runtime/runner/manifest_validator.py +++ b/src/plugin_runtime/runner/manifest_validator.py @@ -105,9 +105,7 @@ class ManifestValidator: def _check_manifest_version(self, manifest: Dict[str, Any]) -> None: mv = manifest.get("manifest_version") if mv is not None and mv not in self.SUPPORTED_MANIFEST_VERSIONS: - self.errors.append( - f"不支持的 manifest_version: {mv},支持: {self.SUPPORTED_MANIFEST_VERSIONS}" - ) + self.errors.append(f"不支持的 manifest_version: {mv},支持: {self.SUPPORTED_MANIFEST_VERSIONS}") def _check_author(self, manifest: Dict[str, Any]) -> None: author = manifest.get("author") diff --git a/src/plugin_runtime/runner/plugin_loader.py b/src/plugin_runtime/runner/plugin_loader.py index 50767363..1ddc55b7 100644 --- a/src/plugin_runtime/runner/plugin_loader.py +++ b/src/plugin_runtime/runner/plugin_loader.py @@ -240,8 +240,7 @@ class PluginLoader: instance = self._try_load_legacy_plugin(module, plugin_id) if instance is not None: logger.info( - f"插件 {plugin_id} v{manifest.get('version', '?')} " - f"通过旧版兼容层加载成功(请尽快迁移到 maibot_sdk)" + f"插件 {plugin_id} v{manifest.get('version', '?')} 通过旧版兼容层加载成功(请尽快迁移到 maibot_sdk)" ) return PluginMeta( plugin_id=plugin_id, @@ -261,6 +260,7 @@ class PluginLoader: return try: from maibot_sdk.compat._import_hook import install_hook + install_hook() self._compat_hook_installed = True except ImportError: @@ -281,11 +281,7 @@ class PluginLoader: for attr_name in dir(module): obj = getattr(module, attr_name, None) - if ( - isinstance(obj, type) - and issubclass(obj, LegacyBasePlugin) - and obj is not LegacyBasePlugin - ): + if isinstance(obj, type) and issubclass(obj, LegacyBasePlugin) and obj is not LegacyBasePlugin: legacy_cls = obj break @@ -294,6 +290,7 @@ class PluginLoader: try: from maibot_sdk.compat.legacy_adapter import LegacyPluginAdapter + legacy_instance = legacy_cls() return LegacyPluginAdapter(legacy_instance) except Exception as e: diff --git a/src/plugin_runtime/runner/rpc_client.py b/src/plugin_runtime/runner/rpc_client.py index 1ef39946..46a9e9af 100644 --- a/src/plugin_runtime/runner/rpc_client.py +++ b/src/plugin_runtime/runner/rpc_client.py @@ -37,6 +37,7 @@ def _get_sdk_version() -> str: """从 maibot_sdk 包元数据中读取实际版本号,失败时回退到 1.0.0。""" try: from importlib.metadata import version + return version("maibot-plugin-sdk") except Exception: return "1.0.0" diff --git a/src/plugin_runtime/runner/runner_main.py b/src/plugin_runtime/runner/runner_main.py index 46e9b641..f881ab25 100644 --- a/src/plugin_runtime/runner/runner_main.py +++ b/src/plugin_runtime/runner/runner_main.py @@ -133,7 +133,9 @@ class PluginRunner: self._suspend_console_handlers() stdlib_logging.root.addHandler(handler) self._log_handler = handler - logger.debug("RunnerIPCLogHandler \u5df2\u5b89\u88c3\uff0c\u63d2\u4ef6\u65e5\u5fd7\u5c06\u901a\u8fc7 IPC \u8f6c\u53d1\u5230\u4e3b\u8fdb\u7a0b") + logger.debug( + "RunnerIPCLogHandler \u5df2\u5b89\u88c3\uff0c\u63d2\u4ef6\u65e5\u5fd7\u5c06\u901a\u8fc7 IPC \u8f6c\u53d1\u5230\u4e3b\u8fdb\u7a0b" + ) async def _uninstall_log_handler(self) -> None: """关停前从 logging.root 移除 Handler 并刷空缓冲。 @@ -291,7 +293,11 @@ class PluginRunner: ) try: - result = await handler_method(**invoke.args) if inspect.iscoroutinefunction(handler_method) else handler_method(**invoke.args) + result = ( + await handler_method(**invoke.args) + if inspect.iscoroutinefunction(handler_method) + else handler_method(**invoke.args) + ) resp_payload = InvokeResultPayload(success=True, result=result) return envelope.make_response(payload=resp_payload.model_dump()) except Exception as e: @@ -332,7 +338,11 @@ class PluginRunner: ) try: - raw = await handler_method(**invoke.args) if inspect.iscoroutinefunction(handler_method) else handler_method(**invoke.args) + raw = ( + await handler_method(**invoke.args) + if inspect.iscoroutinefunction(handler_method) + else handler_method(**invoke.args) + ) # 规范化返回值:将 EventHandler 返回展平到 payload 顶层 if raw is None: @@ -341,7 +351,9 @@ class PluginRunner: result = { "success": True, # 兼容 guide.md 中文档的 {"blocked": True} 写法 - "continue_processing": not raw.get("blocked", False) if "blocked" in raw else raw.get("continue_processing", True), + "continue_processing": not raw.get("blocked", False) + if "blocked" in raw + else raw.get("continue_processing", True), "modified_message": raw.get("modified_message"), "custom_result": raw.get("custom_result"), } @@ -383,7 +395,11 @@ class PluginRunner: ) try: - raw = await handler_method(**invoke.args) if inspect.iscoroutinefunction(handler_method) else handler_method(**invoke.args) + raw = ( + await handler_method(**invoke.args) + if inspect.iscoroutinefunction(handler_method) + else handler_method(**invoke.args) + ) # 规范化返回值 if isinstance(raw, str): @@ -455,6 +471,7 @@ class PluginRunner: # ─── sys.path 隔离 ──────────────────────────────────────── + def _isolate_sys_path(plugin_dirs: List[str]) -> None: """清理 sys.path,限制 Runner 子进程只能访问标准库、SDK 和插件目录。 @@ -504,9 +521,7 @@ def _isolate_sys_path(plugin_dirs: List[str]) -> None: return self if self._should_block(fullname) else None def load_module(self, fullname): - raise ImportError( - f"Runner 子进程不允许导入主程序模块: {fullname}" - ) + raise ImportError(f"Runner 子进程不允许导入主程序模块: {fullname}") def _should_block(self, fullname: str) -> bool: # 放行非 src.* 的导入、以及 "src" 本身 @@ -514,8 +529,7 @@ def _isolate_sys_path(plugin_dirs: List[str]) -> None: return False # 放行白名单前缀 return not any( - fullname == prefix or fullname.startswith(f"{prefix}.") - for prefix in self._ALLOWED_SRC_PREFIXES + fullname == prefix or fullname.startswith(f"{prefix}.") for prefix in self._ALLOWED_SRC_PREFIXES ) sys.meta_path.insert(0, _PluginImportBlocker()) @@ -523,6 +537,7 @@ def _isolate_sys_path(plugin_dirs: List[str]) -> None: # ─── 进程入口 ────────────────────────────────────────────── + async def _async_main() -> None: """异步主入口""" host_address = os.environ.get(ENV_IPC_ADDRESS, "") diff --git a/src/plugin_runtime/transport/base.py b/src/plugin_runtime/transport/base.py index 229de0b1..ca14e6df 100644 --- a/src/plugin_runtime/transport/base.py +++ b/src/plugin_runtime/transport/base.py @@ -20,6 +20,7 @@ MAX_FRAME_SIZE = 16 * 1024 * 1024 # 16 MB 最大帧大小 class ConnectionClosed(Exception): """连接已关闭""" + pass diff --git a/src/plugin_runtime/transport/factory.py b/src/plugin_runtime/transport/factory.py index c8de493c..80a5c625 100644 --- a/src/plugin_runtime/transport/factory.py +++ b/src/plugin_runtime/transport/factory.py @@ -20,10 +20,12 @@ def create_transport_server(socket_path: Optional[str] = None) -> TransportServe """ if sys.platform != "win32": from .uds import UDSTransportServer + return UDSTransportServer(socket_path=socket_path) else: # Windows 回退到 TCP(后续可改为 Named Pipe) from .tcp import TCPTransportServer + return TCPTransportServer() @@ -39,9 +41,11 @@ def create_transport_client(address: str) -> TransportClient: """ if "/" in address or address.endswith(".sock"): from .uds import UDSTransportClient + return UDSTransportClient(socket_path=address) elif ":" in address: from .tcp import TCPTransportClient + host, port_str = address.rsplit(":", 1) return TCPTransportClient(host=host, port=int(port_str)) else: diff --git a/src/plugin_runtime/transport/tcp.py b/src/plugin_runtime/transport/tcp.py index 870d33b3..20b708ea 100644 --- a/src/plugin_runtime/transport/tcp.py +++ b/src/plugin_runtime/transport/tcp.py @@ -13,6 +13,7 @@ from .base import Connection, ConnectionHandler, TransportClient, TransportServe class TCPConnection(Connection): """基于 TCP 的连接""" + pass diff --git a/src/plugin_runtime/transport/uds.py b/src/plugin_runtime/transport/uds.py index 0ea885e7..da10f0d4 100644 --- a/src/plugin_runtime/transport/uds.py +++ b/src/plugin_runtime/transport/uds.py @@ -15,6 +15,7 @@ from .base import Connection, ConnectionHandler, TransportClient, TransportServe class UDSConnection(Connection): """基于 UDS 的连接""" + pass # 直接复用 Connection 基类的分帧读写 @@ -30,16 +31,17 @@ class UDSTransportServer(TransportServer): if socket_path is None: # 默认放在临时目录,使用 uuid 确保同一进程多实例不碰撞 import uuid - socket_path = os.path.join(tempfile.gettempdir(), f"maibot-plugin-{os.getpid()}-{uuid.uuid4().hex[:8]}.sock") + + socket_path = os.path.join( + tempfile.gettempdir(), f"maibot-plugin-{os.getpid()}-{uuid.uuid4().hex[:8]}.sock" + ) # 如果路径超出 UDS 限制,回退到更短的路径 if len(socket_path.encode()) > _UDS_PATH_MAX: socket_path = os.path.join("/tmp", f"mb-{os.getpid()}-{uuid.uuid4().hex[:8]}.sock") if len(socket_path.encode()) > _UDS_PATH_MAX: - raise OSError( - f"UDS socket 路径过长 ({len(socket_path.encode())} > {_UDS_PATH_MAX} 字节): {socket_path}" - ) + raise OSError(f"UDS socket 路径过长 ({len(socket_path.encode())} > {_UDS_PATH_MAX} 字节): {socket_path}") self._socket_path = socket_path self._server: Optional[asyncio.AbstractServer] = None diff --git a/src/plugins/built_in/knowledge/plugin.py b/src/plugins/built_in/knowledge/plugin.py index 92185521..665029a8 100644 --- a/src/plugins/built_in/knowledge/plugin.py +++ b/src/plugins/built_in/knowledge/plugin.py @@ -14,8 +14,16 @@ class KnowledgePlugin(MaiBotPlugin): "lpmm_search_knowledge", description="从知识库中搜索相关信息,如果你需要知识,就使用这个工具", parameters=[ - ToolParameterInfo(name="query", param_type=ToolParamType.STRING, description="搜索查询关键词", required=True), - ToolParameterInfo(name="limit", param_type=ToolParamType.INTEGER, description="希望返回的相关知识条数,默认5", required=False, default=5), + ToolParameterInfo( + name="query", param_type=ToolParamType.STRING, description="搜索查询关键词", required=True + ), + ToolParameterInfo( + name="limit", + param_type=ToolParamType.INTEGER, + description="希望返回的相关知识条数,默认5", + required=False, + default=5, + ), ], ) async def handle_lpmm_search_knowledge(self, query: str = "", limit: int = 5, **kwargs): diff --git a/src/plugins/built_in/plugin_management/plugin.py b/src/plugins/built_in/plugin_management/plugin.py index 81c77faa..fe0888c6 100644 --- a/src/plugins/built_in/plugin_management/plugin.py +++ b/src/plugins/built_in/plugin_management/plugin.py @@ -231,9 +231,7 @@ class PluginManagementPlugin(MaiBotPlugin): text = ", ".join(f"{c['name']} ({c['type']})" for c in filtered) await self.ctx.send.text(f"满足条件的{label}{scope_label}组件: {text}", stream_id) - async def _handle_component_toggle( - self, action: str, scope: str, comp_name: str, comp_type: str, stream_id: str - ): + async def _handle_component_toggle(self, action: str, scope: str, comp_name: str, comp_type: str, stream_id: str): if action not in ("enable", "disable"): await self.ctx.send.text("插件管理命令不合法", stream_id) return @@ -245,13 +243,9 @@ class PluginManagementPlugin(MaiBotPlugin): return if action == "enable": - result = await self.ctx.component.enable_component( - comp_name, comp_type, scope=scope, stream_id=stream_id - ) + result = await self.ctx.component.enable_component(comp_name, comp_type, scope=scope, stream_id=stream_id) else: - result = await self.ctx.component.disable_component( - comp_name, comp_type, scope=scope, stream_id=stream_id - ) + result = await self.ctx.component.disable_component(comp_name, comp_type, scope=scope, stream_id=stream_id) ok = result.get("success", False) if isinstance(result, dict) else bool(result) scope_label = "全局" if scope == "global" else "本地" diff --git a/src/services/chat_service.py b/src/services/chat_service.py index ea6e6541..35324bd9 100644 --- a/src/services/chat_service.py +++ b/src/services/chat_service.py @@ -142,13 +142,21 @@ class ChatManager: if chat_stream.is_group_session: info["group_id"] = chat_stream.group_id - if chat_stream.context and chat_stream.context.message and chat_stream.context.message.message_info.group_info: + if ( + chat_stream.context + and chat_stream.context.message + and chat_stream.context.message.message_info.group_info + ): info["group_name"] = chat_stream.context.message.message_info.group_info.group_name or "未知群聊" else: info["group_name"] = "未知群聊" else: info["user_id"] = chat_stream.user_id - if chat_stream.context and chat_stream.context.message and chat_stream.context.message.message_info.user_info: + if ( + chat_stream.context + and chat_stream.context.message + and chat_stream.context.message.message_info.user_info + ): info["user_name"] = chat_stream.context.message.message_info.user_info.user_nickname else: info["user_name"] = "未知用户" diff --git a/src/services/generator_service.py b/src/services/generator_service.py index 8b5c5152..7a84b911 100644 --- a/src/services/generator_service.py +++ b/src/services/generator_service.py @@ -44,7 +44,9 @@ def get_replyer( if not chat_id and not chat_stream: raise ValueError("chat_stream 和 chat_id 不可均为空") try: - logger.debug(f"[GeneratorService] 正在获取回复器,chat_id: {chat_id}, chat_stream: {'有' if chat_stream else '无'}") + logger.debug( + f"[GeneratorService] 正在获取回复器,chat_id: {chat_id}, chat_stream: {'有' if chat_stream else '无'}" + ) return replyer_manager.get_replyer( chat_stream=chat_stream, chat_id=chat_id, diff --git a/src/services/send_service.py b/src/services/send_service.py index 0f2eac85..29a79ced 100644 --- a/src/services/send_service.py +++ b/src/services/send_service.py @@ -70,10 +70,10 @@ async def _send_to_target( if reply_message: anchor_message = db_message_to_mai_message(reply_message) if anchor_message: - logger.debug(f"[SendService] 找到匹配的回复消息,发送者: {anchor_message.message_info.user_info.user_id}") - reply_to_platform_id = ( - f"{anchor_message.platform}:{anchor_message.message_info.user_info.user_id}" + logger.debug( + f"[SendService] 找到匹配的回复消息,发送者: {anchor_message.message_info.user_info.user_id}" ) + reply_to_platform_id = f"{anchor_message.platform}:{anchor_message.message_info.user_info.user_id}" sender_info = None if target_stream.context and target_stream.context.message: diff --git a/src/webui/logs_ws.py b/src/webui/logs_ws.py index f7fa86da..1d43f306 100644 --- a/src/webui/logs_ws.py +++ b/src/webui/logs_ws.py @@ -174,4 +174,4 @@ async def broadcast_log(log_data: dict): # 清理断开的连接 if disconnected: active_connections.difference_update(disconnected) - logger.debug(f"清理了 {len(disconnected)} 个断开的 WebSocket 连接") \ No newline at end of file + logger.debug(f"清理了 {len(disconnected)} 个断开的 WebSocket 连接") diff --git a/src/webui/routers/annual_report.py b/src/webui/routers/annual_report.py index b9abc54d..2eeb8165 100644 --- a/src/webui/routers/annual_report.py +++ b/src/webui/routers/annual_report.py @@ -551,9 +551,7 @@ async def get_expression_vibe(year: int = 2025) -> ExpressionVibeData: try: # 1. 表情包之王 - 使用次数最多的表情包 with get_db_session() as session: - statement = ( - select(Images).where(col(Images.is_registered)).order_by(desc(col(Images.query_count))).limit(5) - ) + statement = select(Images).where(col(Images.is_registered)).order_by(desc(col(Images.query_count))).limit(5) top_emojis = session.exec(statement).all() if top_emojis: data.top_emoji = { diff --git a/src/webui/routers/jargon.py b/src/webui/routers/jargon.py index bee2d276..02414109 100644 --- a/src/webui/routers/jargon.py +++ b/src/webui/routers/jargon.py @@ -314,17 +314,13 @@ async def get_jargon_stats(): with get_db_session() as session: total = session.exec(select(fn.count()).select_from(Jargon)).one() - confirmed_jargon = session.exec( - select(fn.count()).select_from(Jargon).where(col(Jargon.is_jargon)) - ).one() + confirmed_jargon = session.exec(select(fn.count()).select_from(Jargon).where(col(Jargon.is_jargon))).one() confirmed_not_jargon = session.exec( select(fn.count()).select_from(Jargon).where(col(Jargon.is_jargon).is_(False)) ).one() pending = session.exec(select(fn.count()).select_from(Jargon).where(col(Jargon.is_jargon).is_(None))).one() - complete_count = session.exec( - select(fn.count()).select_from(Jargon).where(col(Jargon.is_complete)) - ).one() + complete_count = session.exec(select(fn.count()).select_from(Jargon).where(col(Jargon.is_complete))).one() chat_count = session.exec( select(fn.count()).select_from(