diff --git a/src/chat/brain_chat/brain_chat.py b/src/chat/brain_chat/brain_chat.py index b4a43b23..28054717 100644 --- a/src/chat/brain_chat/brain_chat.py +++ b/src/chat/brain_chat/brain_chat.py @@ -225,9 +225,7 @@ class BrainChatting: await database_api.store_action_info( chat_stream=self.chat_stream, - action_build_into_prompt=False, - action_prompt_display=action_prompt_display, - action_done=True, + display_prompt=action_prompt_display, thinking_id=thinking_id, action_data={"reply_text": reply_text}, action_name="reply", @@ -573,9 +571,7 @@ class BrainChatting: # 存储complete_talk信息到数据库 await database_api.store_action_info( chat_stream=self.chat_stream, - action_build_into_prompt=False, - action_prompt_display=reason, - action_done=True, + display_prompt=reason, thinking_id=thinking_id, action_data={"reason": reason}, action_name="complete_talk", @@ -678,9 +674,7 @@ class BrainChatting: # 记录动作信息 await database_api.store_action_info( chat_stream=self.chat_stream, - action_build_into_prompt=False, - action_prompt_display=reason or f"等待 {wait_seconds} 秒", - action_done=True, + display_prompt=reason or f"等待 {wait_seconds} 秒", thinking_id=thinking_id, action_data={"reason": reason, "wait_seconds": wait_seconds}, action_name="wait", @@ -722,9 +716,7 @@ class BrainChatting: # 记录动作信息 await database_api.store_action_info( chat_stream=self.chat_stream, - action_build_into_prompt=False, - action_prompt_display=reason or f"倾听并等待 {wait_seconds} 秒", - action_done=True, + display_prompt=reason or f"倾听并等待 {wait_seconds} 秒", thinking_id=thinking_id, action_data={"reason": reason, "wait_seconds": wait_seconds}, action_name="listening", diff --git a/src/memory_system/chat_history_summarizer.py b/src/memory_system/chat_history_summarizer.py index 7613de0a..9a46f340 100644 --- a/src/memory_system/chat_history_summarizer.py +++ b/src/memory_system/chat_history_summarizer.py @@ -8,6 +8,7 @@ import json import time import re import difflib +import datetime from pathlib import Path from typing import Any, Dict, List, Optional, Set from dataclasses import dataclass, field @@ -917,21 +918,18 @@ class ChatHistorySummarizer: # 准备数据 data = { - "chat_id": self.session_id, - "start_time": start_time, - "end_time": end_time, - "original_text": original_text, + "session_id": self.session_id, + "start_timestamp": datetime.fromtimestamp(start_time), + "end_timestamp": datetime.fromtimestamp(end_time), + "original_messages": original_text, "participants": json.dumps(participants, ensure_ascii=False), "theme": theme, "keywords": json.dumps(keywords, ensure_ascii=False), "summary": summary, - "count": 0, + "query_count": 0, + "query_forget_count": 0, } - # 使用db_save存储(使用start_time和chat_id作为唯一标识) - # 由于可能有多条记录,我们使用组合键,但peewee不支持,所以使用start_time作为唯一标识 - # 但为了避免冲突,我们使用组合键:chat_id + start_time - # 由于peewee不支持组合键,我们直接创建新记录(不提供key_field和key_value) saved_record = await database_api.db_save( ChatHistory, data=data, diff --git a/src/plugin_runtime/capabilities/__init__.py b/src/plugin_runtime/capabilities/__init__.py new file mode 100644 index 00000000..3ade3886 --- /dev/null +++ b/src/plugin_runtime/capabilities/__init__.py @@ -0,0 +1,9 @@ +from .components import RuntimeComponentCapabilityMixin +from .core import RuntimeCoreCapabilityMixin +from .data import RuntimeDataCapabilityMixin + +__all__ = [ + "RuntimeComponentCapabilityMixin", + "RuntimeCoreCapabilityMixin", + "RuntimeDataCapabilityMixin", +] diff --git a/src/plugin_runtime/capabilities/components.py b/src/plugin_runtime/capabilities/components.py new file mode 100644 index 00000000..c5d3d20f --- /dev/null +++ b/src/plugin_runtime/capabilities/components.py @@ -0,0 +1,194 @@ +from typing import Any, Dict, List, Optional + +from src.common.logger import get_logger + +logger = get_logger("plugin_runtime.integration") + + +class RuntimeComponentCapabilityMixin: + async def _cap_component_get_all_plugins(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: + result: Dict[str, Any] = {} + for sv in self.supervisors: + for pid, reg in sv._registered_plugins.items(): + if pid in result: + logger.error(f"检测到重复插件 ID {pid},component.get_all_plugins 结果已拒绝聚合") + return {"success": False, "error": f"检测到重复插件 ID: {pid}"} + comps = sv.component_registry.get_components_by_plugin(pid, enabled_only=False) + components_list = [ + { + "name": component.name, + "full_name": component.full_name, + "type": component.component_type, + "enabled": component.enabled, + "metadata": component.metadata, + } + for component in comps + ] + result[pid] = { + "name": pid, + "version": reg.plugin_version, + "description": "", + "author": "", + "enabled": True, + "components": components_list, + } + return {"success": True, "plugins": result} + + async def _cap_component_get_plugin_info(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: + plugin_name: str = args.get("plugin_name", plugin_id) + try: + sv = self._get_supervisor_for_plugin(plugin_name) + except RuntimeError as exc: + return {"success": False, "error": str(exc)} + + if sv is not None and (reg := sv._registered_plugins.get(plugin_name)) is not None: + return { + "success": True, + "plugin": { + "name": plugin_name, + "version": reg.plugin_version, + "description": "", + "author": "", + "enabled": True, + }, + } + return {"success": False, "error": f"未找到插件: {plugin_name}"} + + async def _cap_component_list_loaded_plugins(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: + plugins: List[str] = [] + for sv in self.supervisors: + plugins.extend(sv._registered_plugins.keys()) + return {"success": True, "plugins": plugins} + + async def _cap_component_list_registered_plugins(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: + plugins: List[str] = [] + for sv in self.supervisors: + plugins.extend(sv._registered_plugins.keys()) + return {"success": True, "plugins": plugins} + + def _resolve_component_toggle_target(self, name: str, component_type: str) -> tuple[Optional[Any], Optional[str]]: + short_name_matches: List[Any] = [] + for sv in self.supervisors: + comp = sv.component_registry.get_component(name) + if comp is not None and comp.component_type == component_type: + return comp, None + + short_name_matches.extend( + candidate + for candidate in sv.component_registry.get_components_by_type(component_type, enabled_only=False) + if candidate.name == name + ) + + if len(short_name_matches) == 1: + return short_name_matches[0], None + if len(short_name_matches) > 1: + return None, f"组件名不唯一: {name} ({component_type}),请使用完整名 plugin_id.component_name" + return None, f"未找到组件: {name} ({component_type})" + + async def _cap_component_enable(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: + name: str = args.get("name", "") + component_type: str = args.get("component_type", "") + scope: str = args.get("scope", "global") + stream_id: str = args.get("stream_id", "") + if not name or not component_type: + return {"success": False, "error": "缺少必要参数 name 或 component_type"} + if scope != "global" or stream_id: + return {"success": False, "error": "当前仅支持全局组件启用,不支持 scope/stream_id 定位"} + + comp, error = self._resolve_component_toggle_target(name, component_type) + if comp is None: + return {"success": False, "error": error or f"未找到组件: {name} ({component_type})"} + + comp.enabled = True + return {"success": True} + + async def _cap_component_disable(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: + name: str = args.get("name", "") + component_type: str = args.get("component_type", "") + scope: str = args.get("scope", "global") + stream_id: str = args.get("stream_id", "") + if not name or not component_type: + return {"success": False, "error": "缺少必要参数 name 或 component_type"} + if scope != "global" or stream_id: + return {"success": False, "error": "当前仅支持全局组件禁用,不支持 scope/stream_id 定位"} + + comp, error = self._resolve_component_toggle_target(name, component_type) + if comp is None: + return {"success": False, "error": error or f"未找到组件: {name} ({component_type})"} + + comp.enabled = False + return {"success": True} + + async def _cap_component_load_plugin(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: + plugin_name: str = args.get("plugin_name", "") + if not plugin_name: + return {"success": False, "error": "缺少必要参数 plugin_name"} + + import os + + if duplicate_plugin_ids := self._find_duplicate_plugin_ids(list(self._iter_plugin_dirs())): + details = "; ".join( + f"{conflict_plugin_id}: {', '.join(paths)}" + for conflict_plugin_id, paths in sorted(duplicate_plugin_ids.items()) + ) + return {"success": False, "error": f"检测到重复插件 ID,拒绝热重载: {details}"} + + try: + registered_supervisor = self._get_supervisor_for_plugin(plugin_name) + except RuntimeError as exc: + return {"success": False, "error": str(exc)} + + if registered_supervisor is not None: + try: + reloaded = await registered_supervisor.reload_plugins(reason=f"load {plugin_name}") + if reloaded: + return {"success": True, "count": 1} + return {"success": False, "error": f"插件 {plugin_name} 热重载失败,已回滚"} + except Exception as e: + logger.error(f"[cap.component.load_plugin] 热重载失败: {e}") + return {"success": False, "error": str(e)} + + for sv in self.supervisors: + for pdir in sv._plugin_dirs: + if os.path.isdir(os.path.join(pdir, plugin_name)): + try: + reloaded = await sv.reload_plugins(reason=f"load {plugin_name}") + if reloaded: + return {"success": True, "count": 1} + return {"success": False, "error": f"插件 {plugin_name} 热重载失败,已回滚"} + except Exception as e: + logger.error(f"[cap.component.load_plugin] 热重载失败: {e}") + return {"success": False, "error": str(e)} + + return {"success": False, "error": f"未找到插件: {plugin_name}"} + + async def _cap_component_unload_plugin(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: + return {"success": False, "error": "新运行时不支持单独卸载插件,请使用 reload"} + + async def _cap_component_reload_plugin(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: + plugin_name: str = args.get("plugin_name", "") + if not plugin_name: + return {"success": False, "error": "缺少必要参数 plugin_name"} + + if duplicate_plugin_ids := self._find_duplicate_plugin_ids(list(self._iter_plugin_dirs())): + details = "; ".join( + f"{conflict_plugin_id}: {', '.join(paths)}" + for conflict_plugin_id, paths in sorted(duplicate_plugin_ids.items()) + ) + return {"success": False, "error": f"检测到重复插件 ID,拒绝热重载: {details}"} + + try: + sv = self._get_supervisor_for_plugin(plugin_name) + except RuntimeError as exc: + return {"success": False, "error": str(exc)} + + if sv is not None: + try: + reloaded = await sv.reload_plugins(reason=f"reload {plugin_name}") + if reloaded: + return {"success": True} + return {"success": False, "error": f"插件 {plugin_name} 热重载失败,已回滚"} + except Exception as e: + logger.error(f"[cap.component.reload_plugin] 热重载失败: {e}") + return {"success": False, "error": str(e)} + return {"success": False, "error": f"未找到插件: {plugin_name}"} diff --git a/src/plugin_runtime/capabilities/core.py b/src/plugin_runtime/capabilities/core.py new file mode 100644 index 00000000..bea5557b --- /dev/null +++ b/src/plugin_runtime/capabilities/core.py @@ -0,0 +1,283 @@ +from typing import Any, Dict + +from src.common.logger import get_logger + +logger = get_logger("plugin_runtime.integration") + + +class RuntimeCoreCapabilityMixin: + async def _cap_send_text(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: + from src.services import send_service as send_api + + text: str = args.get("text", "") + stream_id: str = args.get("stream_id", "") + if not text or not stream_id: + return {"success": False, "error": "缺少必要参数 text 或 stream_id"} + + try: + result = await send_api.text_to_stream( + text=text, + stream_id=stream_id, + typing=args.get("typing", False), + set_reply=args.get("set_reply", False), + storage_message=args.get("storage_message", True), + ) + return {"success": result} + except Exception as e: + logger.error(f"[cap.send.text] 执行失败: {e}", exc_info=True) + return {"success": False, "error": str(e)} + + async def _cap_send_emoji(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: + from src.services import send_service as send_api + + emoji_base64: str = args.get("emoji_base64", "") + stream_id: str = args.get("stream_id", "") + if not emoji_base64 or not stream_id: + return {"success": False, "error": "缺少必要参数 emoji_base64 或 stream_id"} + + try: + result = await send_api.emoji_to_stream( + emoji_base64=emoji_base64, + stream_id=stream_id, + storage_message=args.get("storage_message", True), + ) + return {"success": result} + except Exception as e: + logger.error(f"[cap.send.emoji] 执行失败: {e}", exc_info=True) + return {"success": False, "error": str(e)} + + async def _cap_send_image(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: + from src.services import send_service as send_api + + image_base64: str = args.get("image_base64", "") + stream_id: str = args.get("stream_id", "") + if not image_base64 or not stream_id: + return {"success": False, "error": "缺少必要参数 image_base64 或 stream_id"} + + try: + result = await send_api.image_to_stream( + image_base64=image_base64, + stream_id=stream_id, + storage_message=args.get("storage_message", True), + ) + return {"success": result} + except Exception as e: + logger.error(f"[cap.send.image] 执行失败: {e}", exc_info=True) + return {"success": False, "error": str(e)} + + async def _cap_send_command(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: + from src.services import send_service as send_api + + command = args.get("command", "") + stream_id: str = args.get("stream_id", "") + if not command or not stream_id: + return {"success": False, "error": "缺少必要参数 command 或 stream_id"} + + try: + result = await send_api.command_to_stream( + command=command, + stream_id=stream_id, + storage_message=args.get("storage_message", True), + display_message=args.get("display_message", ""), + ) + return {"success": result} + except Exception as e: + logger.error(f"[cap.send.command] 执行失败: {e}", exc_info=True) + return {"success": False, "error": str(e)} + + async def _cap_send_custom(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: + from src.services import send_service as send_api + + message_type: str = args.get("message_type", "") or args.get("custom_type", "") + content = args.get("content") + if content is None: + content = args.get("data", "") + stream_id: str = args.get("stream_id", "") + if not message_type or not stream_id: + return {"success": False, "error": "缺少必要参数 message_type 或 stream_id"} + + try: + result = await send_api.custom_to_stream( + message_type=message_type, + content=content, + stream_id=stream_id, + display_message=args.get("display_message", ""), + typing=args.get("typing", False), + storage_message=args.get("storage_message", True), + ) + return {"success": result} + except Exception as e: + logger.error(f"[cap.send.custom] 执行失败: {e}", exc_info=True) + return {"success": False, "error": str(e)} + + async def _cap_send_forward(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: + from src.services import send_service as send_api + + messages = args.get("messages", []) + stream_id: str = args.get("stream_id", "") + if not messages or not stream_id: + return {"success": False, "error": "缺少必要参数 messages 或 stream_id"} + + try: + result = await send_api.forward_to_stream(messages=messages, stream_id=stream_id) + return {"success": result} + except Exception as e: + logger.error(f"[cap.send.forward] 执行失败: {e}", exc_info=True) + return {"success": False, "error": str(e)} + + async def _cap_send_hybrid(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: + from src.services import send_service as send_api + + segments = args.get("segments", []) + stream_id: str = args.get("stream_id", "") + if not segments or not stream_id: + return {"success": False, "error": "缺少必要参数 segments 或 stream_id"} + + try: + result = await send_api.hybrid_to_stream(segments=segments, stream_id=stream_id) + return {"success": result} + except Exception as e: + logger.error(f"[cap.send.hybrid] 执行失败: {e}", exc_info=True) + return {"success": False, "error": str(e)} + + async def _cap_llm_generate(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: + from src.services import llm_service as llm_api + + prompt: str = args.get("prompt", "") + if not prompt: + return {"success": False, "error": "缺少必要参数 prompt"} + + model_name: str = args.get("model", "") or args.get("model_name", "") + temperature = args.get("temperature") + max_tokens = args.get("max_tokens") + + try: + models = llm_api.get_available_models() + if model_name and model_name in models: + model_config = models[model_name] + else: + if not models: + return {"success": False, "error": "没有可用的模型配置"} + model_config = next(iter(models.values())) + + success, response, reasoning, used_model = await llm_api.generate_with_model( + prompt=prompt, + model_config=model_config, + request_type=f"plugin.{plugin_id}", + temperature=temperature, + max_tokens=max_tokens, + ) + return { + "success": success, + "response": response, + "reasoning": reasoning, + "model_name": used_model, + } + except Exception as e: + logger.error(f"[cap.llm.generate] 执行失败: {e}", exc_info=True) + return {"success": False, "error": str(e)} + + async def _cap_llm_generate_with_tools(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: + from src.services import llm_service as llm_api + + prompt: str = args.get("prompt", "") + if not prompt: + return {"success": False, "error": "缺少必要参数 prompt"} + + model_name: str = args.get("model", "") or args.get("model_name", "") + tool_options = args.get("tools") or args.get("tool_options") + temperature = args.get("temperature") + max_tokens = args.get("max_tokens") + + try: + models = llm_api.get_available_models() + if model_name and model_name in models: + model_config = models[model_name] + else: + if not models: + return {"success": False, "error": "没有可用的模型配置"} + model_config = next(iter(models.values())) + + success, response, reasoning, used_model, tool_calls = await llm_api.generate_with_model_with_tools( + prompt=prompt, + model_config=model_config, + tool_options=tool_options, + request_type=f"plugin.{plugin_id}", + temperature=temperature, + max_tokens=max_tokens, + ) + serialized_tool_calls = None + if tool_calls: + serialized_tool_calls = [ + {"id": tc.id, "function": {"name": tc.function.name, "arguments": tc.function.arguments}} + for tc in tool_calls + if hasattr(tc, "function") + ] + return { + "success": success, + "response": response, + "reasoning": reasoning, + "model_name": used_model, + "tool_calls": serialized_tool_calls, + } + except Exception as e: + logger.error(f"[cap.llm.generate_with_tools] 执行失败: {e}", exc_info=True) + return {"success": False, "error": str(e)} + + async def _cap_llm_get_available_models(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: + from src.services import llm_service as llm_api + + try: + models = llm_api.get_available_models() + return {"success": True, "models": list(models.keys())} + except Exception as e: + logger.error(f"[cap.llm.get_available_models] 执行失败: {e}", exc_info=True) + return {"success": False, "error": str(e)} + + async def _cap_config_get(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: + from src.services import config_service as config_api + + key: str = args.get("key", "") + default = args.get("default") + if not key: + return {"success": False, "value": None, "error": "缺少必要参数 key"} + + try: + value = config_api.get_global_config(key, default) + return {"success": True, "value": value} + except Exception as e: + return {"success": False, "value": None, "error": str(e)} + + async def _cap_config_get_plugin(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: + from src.core.component_registry import component_registry as core_registry + + plugin_name: str = args.get("plugin_name", plugin_id) + key: str = args.get("key", "") + default = args.get("default") + + try: + config = core_registry.get_plugin_config(plugin_name) + if config is None: + return {"success": False, "value": default, "error": f"未找到插件 {plugin_name} 的配置"} + + if key: + from src.services import config_service as config_api + + value = config_api.get_plugin_config(config, key, default) + return {"success": True, "value": value} + + return {"success": True, "value": config} + except Exception as e: + return {"success": False, "value": default, "error": str(e)} + + async def _cap_config_get_all(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: + from src.core.component_registry import component_registry as core_registry + + plugin_name: str = args.get("plugin_name", plugin_id) + try: + config = core_registry.get_plugin_config(plugin_name) + if config is None: + return {"success": True, "value": {}} + return {"success": True, "value": config} + except Exception as e: + return {"success": False, "value": {}, "error": str(e)} diff --git a/src/plugin_runtime/capabilities/data.py b/src/plugin_runtime/capabilities/data.py new file mode 100644 index 00000000..e5b183cd --- /dev/null +++ b/src/plugin_runtime/capabilities/data.py @@ -0,0 +1,582 @@ +from typing import Any, Dict, List, Optional + +from src.chat.message_receive.chat_manager import BotChatSession, chat_manager +from src.common.logger import get_logger + +logger = get_logger("plugin_runtime.integration") + + +class RuntimeDataCapabilityMixin: + async def _cap_database_query(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: + from src.services import database_service as database_api + + model_name: str = args.get("model_name", "") + if not model_name: + return {"success": False, "error": "缺少必要参数 model_name"} + + try: + import src.common.database.database_model as db_models + + model_class = getattr(db_models, model_name, None) + if model_class is None: + return {"success": False, "error": f"未找到数据模型: {model_name}"} + + query_type = args.get("query_type", "get") + if query_type == "get": + result = await database_api.db_get( + model_class=model_class, + filters=args.get("filters"), + limit=args.get("limit"), + order_by=args.get("order_by"), + single_result=args.get("single_result", False), + ) + elif query_type == "create": + if not (data := args.get("data")): + return {"success": False, "error": "create 需要 data"} + result = await database_api.db_save(model_class=model_class, data=data) + elif query_type == "update": + if not (data := args.get("data")): + return {"success": False, "error": "update 需要 data"} + result = await database_api.db_update( + model_class=model_class, + data=data, + filters=args.get("filters"), + ) + elif query_type == "delete": + result = await database_api.db_delete(model_class=model_class, filters=args.get("filters")) + elif query_type == "count": + result = await database_api.db_count(model_class=model_class, filters=args.get("filters")) + else: + return {"success": False, "error": f"不支持的 query_type: {query_type}"} + return {"success": True, "result": result} + except Exception as e: + logger.error(f"[cap.database.query] 执行失败: {e}", exc_info=True) + return {"success": False, "error": str(e)} + + async def _cap_database_save(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: + from src.services import database_service as database_api + + model_name: str = args.get("model_name", "") + data: Optional[Dict[str, Any]] = args.get("data") + if not model_name or not data: + return {"success": False, "error": "缺少必要参数 model_name 或 data"} + + try: + import src.common.database.database_model as db_models + + model_class = getattr(db_models, model_name, None) + if model_class is None: + return {"success": False, "error": f"未找到数据模型: {model_name}"} + + result = await database_api.db_save( + model_class=model_class, + data=data, + key_field=args.get("key_field"), + key_value=args.get("key_value"), + ) + return {"success": True, "result": result} + except Exception as e: + logger.error(f"[cap.database.save] 执行失败: {e}", exc_info=True) + return {"success": False, "error": str(e)} + + async def _cap_database_get(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: + from src.services import database_service as database_api + + model_name: str = args.get("model_name", "") + if not model_name: + return {"success": False, "error": "缺少必要参数 model_name"} + + try: + import src.common.database.database_model as db_models + + model_class = getattr(db_models, model_name, None) + if model_class is None: + return {"success": False, "error": f"未找到数据模型: {model_name}"} + + result = await database_api.db_get( + model_class=model_class, + filters=args.get("filters"), + limit=args.get("limit"), + order_by=args.get("order_by"), + single_result=args.get("single_result", False), + ) + return {"success": True, "result": result} + except Exception as e: + logger.error(f"[cap.database.get] 执行失败: {e}", exc_info=True) + return {"success": False, "error": str(e)} + + async def _cap_database_delete(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: + from src.services import database_service as database_api + + model_name: str = args.get("model_name", "") + filters = args.get("filters", {}) + if not model_name: + return {"success": False, "error": "缺少必要参数 model_name"} + if not filters: + return {"success": False, "error": "缺少必要参数 filters(不允许无条件删除)"} + + try: + import src.common.database.database_model as db_models + + model_class = getattr(db_models, model_name, None) + if model_class is None: + return {"success": False, "error": f"未找到数据模型: {model_name}"} + + result = await database_api.db_delete(model_class=model_class, filters=filters) + return {"success": True, "result": result} + except Exception as e: + logger.error(f"[cap.database.delete] 执行失败: {e}", exc_info=True) + return {"success": False, "error": str(e)} + + async def _cap_database_count(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: + from src.services import database_service as database_api + + model_name: str = args.get("model_name", "") + if not model_name: + return {"success": False, "error": "缺少必要参数 model_name"} + + try: + import src.common.database.database_model as db_models + + model_class = getattr(db_models, model_name, None) + if model_class is None: + return {"success": False, "error": f"未找到数据模型: {model_name}"} + + result = await database_api.db_count(model_class=model_class, filters=args.get("filters")) + return {"success": True, "count": result} + except Exception as e: + logger.error(f"[cap.database.count] 执行失败: {e}", exc_info=True) + return {"success": False, "error": str(e)} + + def _list_sessions(self, platform: str, is_group_session: Optional[bool] = None) -> List[BotChatSession]: + return [ + session + for session in chat_manager.sessions.values() + if (platform == "all_platforms" or session.platform == platform) + and (is_group_session is None or session.is_group_session == is_group_session) + ] + + @staticmethod + def _serialize_stream(stream: BotChatSession) -> Dict[str, Any]: + return { + "session_id": stream.session_id, + "platform": stream.platform, + "user_id": stream.user_id, + "group_id": stream.group_id, + "is_group_session": stream.is_group_session, + } + + async def _cap_chat_get_all_streams(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: + platform: str = args.get("platform", "qq") + try: + streams = self._list_sessions(platform=platform) + return {"success": True, "streams": [self._serialize_stream(item) for item in streams]} + except Exception as e: + logger.error(f"[cap.chat.get_all_streams] 执行失败: {e}", exc_info=True) + return {"success": False, "error": str(e)} + + async def _cap_chat_get_group_streams(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: + platform: str = args.get("platform", "qq") + try: + streams = self._list_sessions(platform=platform, is_group_session=True) + return {"success": True, "streams": [self._serialize_stream(item) for item in streams]} + except Exception as e: + logger.error(f"[cap.chat.get_group_streams] 执行失败: {e}", exc_info=True) + return {"success": False, "error": str(e)} + + async def _cap_chat_get_private_streams(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: + platform: str = args.get("platform", "qq") + try: + streams = self._list_sessions(platform=platform, is_group_session=False) + return {"success": True, "streams": [self._serialize_stream(item) for item in streams]} + except Exception as e: + logger.error(f"[cap.chat.get_private_streams] 执行失败: {e}", exc_info=True) + return {"success": False, "error": str(e)} + + async def _cap_chat_get_stream_by_group_id(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: + group_id: str = args.get("group_id", "") + if not group_id: + return {"success": False, "error": "缺少必要参数 group_id"} + + platform: str = args.get("platform", "qq") + try: + stream = next( + ( + item + for item in self._list_sessions(platform=platform, is_group_session=True) + if str(item.group_id) == str(group_id) + ), + None, + ) + return {"success": True, "stream": None if stream is None else self._serialize_stream(stream)} + except Exception as e: + logger.error(f"[cap.chat.get_stream_by_group_id] 执行失败: {e}", exc_info=True) + return {"success": False, "error": str(e)} + + async def _cap_chat_get_stream_by_user_id(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: + user_id: str = args.get("user_id", "") + if not user_id: + return {"success": False, "error": "缺少必要参数 user_id"} + + platform: str = args.get("platform", "qq") + try: + stream = next( + ( + item + for item in self._list_sessions(platform=platform, is_group_session=False) + if str(item.user_id) == str(user_id) + ), + None, + ) + return {"success": True, "stream": None if stream is None else self._serialize_stream(stream)} + except Exception as e: + logger.error(f"[cap.chat.get_stream_by_user_id] 执行失败: {e}", exc_info=True) + return {"success": False, "error": str(e)} + + @staticmethod + def _serialize_messages(messages: list) -> List[Dict[str, Any]]: + result: List[Dict[str, Any]] = [] + for msg in messages: + if hasattr(msg, "model_dump"): + result.append(msg.model_dump()) + elif hasattr(msg, "__dict__"): + result.append(dict(msg.__dict__)) + else: + result.append(str(msg)) + return result + + async def _cap_message_get_by_time(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: + from src.services import message_service as message_api + + try: + messages = message_api.get_messages_by_time( + start_time=float(args.get("start_time", 0.0)), + end_time=float(args.get("end_time", 0.0)), + limit=args.get("limit", 0), + limit_mode=args.get("limit_mode", "latest"), + filter_mai=args.get("filter_mai", False), + ) + return {"success": True, "messages": self._serialize_messages(messages)} + except Exception as e: + logger.error(f"[cap.message.get_by_time] 执行失败: {e}", exc_info=True) + return {"success": False, "error": str(e)} + + async def _cap_message_get_by_time_in_chat(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: + from src.services import message_service as message_api + + chat_id: str = args.get("chat_id", "") + if not chat_id: + return {"success": False, "error": "缺少必要参数 chat_id"} + + try: + messages = message_api.get_messages_by_time_in_chat( + chat_id=chat_id, + start_time=float(args.get("start_time", 0.0)), + end_time=float(args.get("end_time", 0.0)), + limit=args.get("limit", 0), + limit_mode=args.get("limit_mode", "latest"), + filter_mai=args.get("filter_mai", False), + filter_command=args.get("filter_command", False), + ) + return {"success": True, "messages": self._serialize_messages(messages)} + except Exception as e: + logger.error(f"[cap.message.get_by_time_in_chat] 执行失败: {e}", exc_info=True) + return {"success": False, "error": str(e)} + + async def _cap_message_get_recent(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: + from src.services import message_service as message_api + + chat_id: str = args.get("chat_id", "") + if not chat_id: + return {"success": False, "error": "缺少必要参数 chat_id"} + + try: + messages = message_api.get_recent_messages( + chat_id=chat_id, + hours=float(args.get("hours", 24.0)), + limit=args.get("limit", 100), + limit_mode=args.get("limit_mode", "latest"), + filter_mai=args.get("filter_mai", False), + ) + return {"success": True, "messages": self._serialize_messages(messages)} + except Exception as e: + logger.error(f"[cap.message.get_recent] 执行失败: {e}", exc_info=True) + return {"success": False, "error": str(e)} + + async def _cap_message_count_new(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: + from src.services import message_service as message_api + + chat_id: str = args.get("chat_id", "") + if not chat_id: + return {"success": False, "error": "缺少必要参数 chat_id"} + + try: + since = args.get("since") + start_time = float(since) if since is not None else float(args.get("start_time", 0.0)) + count = message_api.count_new_messages( + chat_id=chat_id, + start_time=start_time, + end_time=args.get("end_time"), + ) + return {"success": True, "count": count} + except Exception as e: + logger.error(f"[cap.message.count_new] 执行失败: {e}", exc_info=True) + return {"success": False, "error": str(e)} + + async def _cap_message_build_readable(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: + from src.services import message_service as message_api + + try: + messages = args.get("messages") + if messages is None: + if not (chat_id := args.get("chat_id", "")): + return {"success": False, "error": "缺少必要参数: messages 或 chat_id"} + messages = message_api.get_messages_by_time_in_chat( + chat_id=chat_id, + start_time=float(args.get("start_time", 0.0)), + end_time=float(args.get("end_time", 0.0)), + limit=args.get("limit", 0), + ) + + readable = message_api.build_readable_messages_to_str( + messages=messages, + replace_bot_name=args.get("replace_bot_name", True), + timestamp_mode=args.get("timestamp_mode", "relative"), + truncate=args.get("truncate", False), + ) + return {"success": True, "text": readable} + except Exception as e: + logger.error(f"[cap.message.build_readable] 执行失败: {e}", exc_info=True) + return {"success": False, "error": str(e)} + + async def _cap_person_get_id(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: + from src.person_info.person_info import Person + + platform: str = args.get("platform", "") + user_id = args.get("user_id", "") + if not platform or not user_id: + return {"success": False, "error": "缺少必要参数 platform 或 user_id"} + + try: + pid = Person(platform=platform, user_id=str(user_id)).person_id + return {"success": True, "person_id": pid} + except Exception as e: + logger.error(f"[cap.person.get_id] 执行失败: {e}", exc_info=True) + return {"success": False, "error": str(e)} + + async def _cap_person_get_value(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: + from src.person_info.person_info import Person + + person_id: str = args.get("person_id", "") + field_name: str = args.get("field_name", "") + if not person_id or not field_name: + return {"success": False, "error": "缺少必要参数 person_id 或 field_name"} + + try: + person = Person(person_id=person_id) + value = getattr(person, field_name) + if value is None: + value = args.get("default") + return {"success": True, "value": value} + except Exception as e: + logger.error(f"[cap.person.get_value] 执行失败: {e}", exc_info=True) + return {"success": False, "error": str(e)} + + async def _cap_person_get_id_by_name(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: + from src.person_info.person_info import Person + + person_name: str = args.get("person_name", "") + if not person_name: + return {"success": False, "error": "缺少必要参数 person_name"} + + try: + pid = Person(person_name=person_name).person_id + return {"success": True, "person_id": pid} + except Exception as e: + logger.error(f"[cap.person.get_id_by_name] 执行失败: {e}", exc_info=True) + return {"success": False, "error": str(e)} + + async def _cap_emoji_get_by_description(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: + from src.services import emoji_service as emoji_api + + description: str = args.get("description", "") + if not description: + return {"success": False, "error": "缺少必要参数 description"} + + try: + result = await emoji_api.get_by_description(description=description) + if result is None: + return {"success": True, "emoji": None} + emoji_base64, emoji_desc, matched_emotion = result + return { + "success": True, + "emoji": { + "base64": emoji_base64, + "description": emoji_desc, + "emotion": matched_emotion, + }, + } + except Exception as e: + logger.error(f"[cap.emoji.get_by_description] 执行失败: {e}", exc_info=True) + return {"success": False, "error": str(e)} + + async def _cap_emoji_get_random(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: + from src.services import emoji_service as emoji_api + + count: int = args.get("count", 1) + try: + results = await emoji_api.get_random(count=count) + emojis = [{"base64": b64, "description": desc, "emotion": emo} for b64, desc, emo in results] + return {"success": True, "emojis": emojis} + except Exception as e: + logger.error(f"[cap.emoji.get_random] 执行失败: {e}", exc_info=True) + return {"success": False, "error": str(e)} + + async def _cap_emoji_get_count(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: + from src.services import emoji_service as emoji_api + + try: + return {"success": True, "count": emoji_api.get_count()} + except Exception as e: + logger.error(f"[cap.emoji.get_count] 执行失败: {e}", exc_info=True) + return {"success": False, "error": str(e)} + + async def _cap_emoji_get_emotions(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: + from src.services import emoji_service as emoji_api + + try: + return {"success": True, "emotions": emoji_api.get_emotions()} + except Exception as e: + logger.error(f"[cap.emoji.get_emotions] 执行失败: {e}", exc_info=True) + return {"success": False, "error": str(e)} + + async def _cap_emoji_get_all(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: + from src.services import emoji_service as emoji_api + + try: + results = await emoji_api.get_all() + emojis = [{"base64": b64, "description": desc, "emotion": emo} for b64, desc, emo in results] if results else [] + return {"success": True, "emojis": emojis} + except Exception as e: + logger.error(f"[cap.emoji.get_all] 执行失败: {e}", exc_info=True) + return {"success": False, "error": str(e)} + + async def _cap_emoji_get_info(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: + from src.services import emoji_service as emoji_api + + try: + return {"success": True, "info": emoji_api.get_info()} + except Exception as e: + logger.error(f"[cap.emoji.get_info] 执行失败: {e}", exc_info=True) + return {"success": False, "error": str(e)} + + async def _cap_emoji_register(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: + from src.services import emoji_service as emoji_api + + emoji_base64: str = args.get("emoji_base64", "") + if not emoji_base64: + return {"success": False, "error": "缺少必要参数 emoji_base64"} + + try: + return await emoji_api.register_emoji(emoji_base64) + except Exception as e: + logger.error(f"[cap.emoji.register] 执行失败: {e}", exc_info=True) + return {"success": False, "error": str(e)} + + async def _cap_emoji_delete(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: + from src.services import emoji_service as emoji_api + + emoji_hash: str = args.get("emoji_hash", "") + if not emoji_hash: + return {"success": False, "error": "缺少必要参数 emoji_hash"} + + try: + return await emoji_api.delete_emoji(emoji_hash) + except Exception as e: + logger.error(f"[cap.emoji.delete] 执行失败: {e}", exc_info=True) + return {"success": False, "error": str(e)} + + @staticmethod + def _get_frequency_adjust_value(chat_id: str) -> float: + from src.chat.heart_flow.heartflow_manager import heartflow_manager + + heartflow_chat = heartflow_manager.heartflow_chat_list.get(chat_id) + return 1.0 if heartflow_chat is None else heartflow_chat._talk_frequency_adjust + + async def _cap_frequency_get_current_talk_value(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: + from src.common.utils.utils_config import ChatConfigUtils + + chat_id: str = args.get("chat_id", "") + if not chat_id: + return {"success": False, "error": "缺少必要参数 chat_id"} + + try: + value = self._get_frequency_adjust_value(chat_id) * ChatConfigUtils.get_talk_value(chat_id) + return {"success": True, "value": value} + except Exception as e: + logger.error(f"[cap.frequency.get_current_talk_value] 执行失败: {e}", exc_info=True) + return {"success": False, "error": str(e)} + + async def _cap_frequency_set_adjust(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: + from src.chat.heart_flow.heartflow_manager import heartflow_manager + + chat_id: str = args.get("chat_id", "") + value = args.get("value") + if not chat_id or value is None: + return {"success": False, "error": "缺少必要参数 chat_id 或 value"} + + try: + heartflow_manager.adjust_talk_frequency(chat_id, float(value)) + return {"success": True} + except Exception as e: + logger.error(f"[cap.frequency.set_adjust] 执行失败: {e}", exc_info=True) + return {"success": False, "error": str(e)} + + async def _cap_frequency_get_adjust(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: + chat_id: str = args.get("chat_id", "") + if not chat_id: + return {"success": False, "error": "缺少必要参数 chat_id"} + + try: + value = self._get_frequency_adjust_value(chat_id) + return {"success": True, "value": value} + except Exception as e: + logger.error(f"[cap.frequency.get_adjust] 执行失败: {e}", exc_info=True) + return {"success": False, "error": str(e)} + + async def _cap_tool_get_definitions(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: + from src.core.component_registry import component_registry as core_registry + + try: + tools = core_registry.get_llm_available_tools() + return { + "success": True, + "tools": [{"name": name, "definition": info.get_llm_definition()} for name, info in tools.items()], + } + except Exception as e: + logger.error(f"[cap.tool.get_definitions] 执行失败: {e}", exc_info=True) + return {"success": False, "error": str(e)} + + async def _cap_knowledge_search(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: + query: str = args.get("query", "") + if not query: + return {"success": False, "error": "缺少必要参数 query"} + + limit = args.get("limit", 5) + try: + limit_value = max(1, int(limit)) + except (TypeError, ValueError): + limit_value = 5 + + try: + from src.chat.knowledge import qa_manager + + if qa_manager is None: + return {"success": True, "content": "LPMM知识库已禁用"} + + knowledge_info = await qa_manager.get_knowledge(query, limit=limit_value) + content = f"你知道这些知识: {knowledge_info}" if knowledge_info else f"你不太了解有关{query}的知识" + return {"success": True, "content": content} + except Exception as e: + logger.error(f"[cap.knowledge.search] 执行失败: {e}", exc_info=True) + return {"success": False, "error": str(e)} diff --git a/src/plugin_runtime/capabilities/registry.py b/src/plugin_runtime/capabilities/registry.py new file mode 100644 index 00000000..7890b7e4 --- /dev/null +++ b/src/plugin_runtime/capabilities/registry.py @@ -0,0 +1,80 @@ +from typing import TYPE_CHECKING + +from src.common.logger import get_logger +from src.plugin_runtime.host.supervisor import PluginSupervisor + +if TYPE_CHECKING: + from src.plugin_runtime.integration import PluginRuntimeManager + +logger = get_logger("plugin_runtime.integration") + + +def register_capability_impls(manager: "PluginRuntimeManager", supervisor: PluginSupervisor) -> None: + """向指定 Supervisor 注册主程序提供的能力实现。""" + cap_service = supervisor.capability_service + + cap_service.register_capability("send.text", manager._cap_send_text) + cap_service.register_capability("send.emoji", manager._cap_send_emoji) + cap_service.register_capability("send.image", manager._cap_send_image) + cap_service.register_capability("send.command", manager._cap_send_command) + cap_service.register_capability("send.custom", manager._cap_send_custom) + cap_service.register_capability("send.forward", manager._cap_send_forward) + cap_service.register_capability("send.hybrid", manager._cap_send_hybrid) + + cap_service.register_capability("llm.generate", manager._cap_llm_generate) + cap_service.register_capability("llm.generate_with_tools", manager._cap_llm_generate_with_tools) + cap_service.register_capability("llm.get_available_models", manager._cap_llm_get_available_models) + + cap_service.register_capability("config.get", manager._cap_config_get) + cap_service.register_capability("config.get_plugin", manager._cap_config_get_plugin) + cap_service.register_capability("config.get_all", manager._cap_config_get_all) + + cap_service.register_capability("database.query", manager._cap_database_query) + cap_service.register_capability("database.save", manager._cap_database_save) + cap_service.register_capability("database.get", manager._cap_database_get) + cap_service.register_capability("database.delete", manager._cap_database_delete) + cap_service.register_capability("database.count", manager._cap_database_count) + + cap_service.register_capability("chat.get_all_streams", manager._cap_chat_get_all_streams) + cap_service.register_capability("chat.get_group_streams", manager._cap_chat_get_group_streams) + cap_service.register_capability("chat.get_private_streams", manager._cap_chat_get_private_streams) + cap_service.register_capability("chat.get_stream_by_group_id", manager._cap_chat_get_stream_by_group_id) + cap_service.register_capability("chat.get_stream_by_user_id", manager._cap_chat_get_stream_by_user_id) + + cap_service.register_capability("message.get_by_time", manager._cap_message_get_by_time) + cap_service.register_capability("message.get_by_time_in_chat", manager._cap_message_get_by_time_in_chat) + cap_service.register_capability("message.get_recent", manager._cap_message_get_recent) + cap_service.register_capability("message.count_new", manager._cap_message_count_new) + cap_service.register_capability("message.build_readable", manager._cap_message_build_readable) + + cap_service.register_capability("person.get_id", manager._cap_person_get_id) + cap_service.register_capability("person.get_value", manager._cap_person_get_value) + cap_service.register_capability("person.get_id_by_name", manager._cap_person_get_id_by_name) + + cap_service.register_capability("emoji.get_by_description", manager._cap_emoji_get_by_description) + cap_service.register_capability("emoji.get_random", manager._cap_emoji_get_random) + cap_service.register_capability("emoji.get_count", manager._cap_emoji_get_count) + cap_service.register_capability("emoji.get_emotions", manager._cap_emoji_get_emotions) + cap_service.register_capability("emoji.get_all", manager._cap_emoji_get_all) + cap_service.register_capability("emoji.get_info", manager._cap_emoji_get_info) + cap_service.register_capability("emoji.register", manager._cap_emoji_register) + cap_service.register_capability("emoji.delete", manager._cap_emoji_delete) + + cap_service.register_capability("frequency.get_current_talk_value", manager._cap_frequency_get_current_talk_value) + cap_service.register_capability("frequency.set_adjust", manager._cap_frequency_set_adjust) + cap_service.register_capability("frequency.get_adjust", manager._cap_frequency_get_adjust) + + cap_service.register_capability("tool.get_definitions", manager._cap_tool_get_definitions) + + cap_service.register_capability("component.get_all_plugins", manager._cap_component_get_all_plugins) + cap_service.register_capability("component.get_plugin_info", manager._cap_component_get_plugin_info) + cap_service.register_capability("component.list_loaded_plugins", manager._cap_component_list_loaded_plugins) + cap_service.register_capability("component.list_registered_plugins", manager._cap_component_list_registered_plugins) + cap_service.register_capability("component.enable", manager._cap_component_enable) + cap_service.register_capability("component.disable", manager._cap_component_disable) + cap_service.register_capability("component.load_plugin", manager._cap_component_load_plugin) + cap_service.register_capability("component.unload_plugin", manager._cap_component_unload_plugin) + cap_service.register_capability("component.reload_plugin", manager._cap_component_reload_plugin) + + cap_service.register_capability("knowledge.search", manager._cap_knowledge_search) + logger.debug("已注册全部主程序能力实现") diff --git a/src/plugin_runtime/integration.py b/src/plugin_runtime/integration.py index 142aca43..a6d04d95 100644 --- a/src/plugin_runtime/integration.py +++ b/src/plugin_runtime/integration.py @@ -7,17 +7,25 @@ 4. 提供统一的能力实现注册接口,使插件可以调用主程序功能 """ -from typing import Any, Dict, Iterable, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple import asyncio import json import os from pathlib import Path -from src.chat.message_receive.chat_manager import BotChatSession from src.common.logger import get_logger from src.config.config import config_manager, global_config from src.config.file_watcher import FileChange, FileWatcher +from src.plugin_runtime.capabilities import ( + RuntimeComponentCapabilityMixin, + RuntimeCoreCapabilityMixin, + RuntimeDataCapabilityMixin, +) +from src.plugin_runtime.capabilities.registry import register_capability_impls + +if TYPE_CHECKING: + from src.plugin_runtime.host.supervisor import PluginSupervisor logger = get_logger("plugin_runtime.integration") @@ -36,7 +44,11 @@ _EVENT_TYPE_MAP: Dict[str, str] = { } -class PluginRuntimeManager: +class PluginRuntimeManager( + RuntimeCoreCapabilityMixin, + RuntimeDataCapabilityMixin, + RuntimeComponentCapabilityMixin, +): """插件运行时管理器(单例) 内置插件与第三方插件分别运行在各自的 Supervisor / Runner 子进程中。 @@ -463,1467 +475,8 @@ class PluginRuntimeManager: # ─── 能力实现注册 ────────────────────────────────────────── - def _register_capability_impls(self, supervisor: Any) -> None: - """向指定 Supervisor 注册主程序提供的能力实现""" - cap_service = supervisor.capability_service - - # ── send.* ───────────────────────────────────────── - cap_service.register_capability("send.text", self._cap_send_text) - cap_service.register_capability("send.emoji", self._cap_send_emoji) - cap_service.register_capability("send.image", self._cap_send_image) - cap_service.register_capability("send.command", self._cap_send_command) - cap_service.register_capability("send.custom", self._cap_send_custom) - cap_service.register_capability("send.forward", self._cap_send_forward) - cap_service.register_capability("send.hybrid", self._cap_send_hybrid) - - # ── llm.* ───────────────────────────────────────── - cap_service.register_capability("llm.generate", self._cap_llm_generate) - cap_service.register_capability("llm.generate_with_tools", self._cap_llm_generate_with_tools) - cap_service.register_capability("llm.get_available_models", self._cap_llm_get_available_models) - - # ── config.* ────────────────────────────────────── - cap_service.register_capability("config.get", self._cap_config_get) - cap_service.register_capability("config.get_plugin", self._cap_config_get_plugin) - cap_service.register_capability("config.get_all", self._cap_config_get_all) - - # ── database.* ──────────────────────────────────── - cap_service.register_capability("database.query", self._cap_database_query) - cap_service.register_capability("database.save", self._cap_database_save) - cap_service.register_capability("database.get", self._cap_database_get) - cap_service.register_capability("database.delete", self._cap_database_delete) - cap_service.register_capability("database.count", self._cap_database_count) - - # ── chat.* ──────────────────────────────────────── - cap_service.register_capability("chat.get_all_streams", self._cap_chat_get_all_streams) - cap_service.register_capability("chat.get_group_streams", self._cap_chat_get_group_streams) - cap_service.register_capability("chat.get_private_streams", self._cap_chat_get_private_streams) - cap_service.register_capability("chat.get_stream_by_group_id", self._cap_chat_get_stream_by_group_id) - cap_service.register_capability("chat.get_stream_by_user_id", self._cap_chat_get_stream_by_user_id) - - # ── message.* ───────────────────────────────────── - cap_service.register_capability("message.get_by_time", self._cap_message_get_by_time) - cap_service.register_capability("message.get_by_time_in_chat", self._cap_message_get_by_time_in_chat) - cap_service.register_capability("message.get_recent", self._cap_message_get_recent) - cap_service.register_capability("message.count_new", self._cap_message_count_new) - cap_service.register_capability("message.build_readable", self._cap_message_build_readable) - - # ── person.* ────────────────────────────────────── - cap_service.register_capability("person.get_id", self._cap_person_get_id) - cap_service.register_capability("person.get_value", self._cap_person_get_value) - cap_service.register_capability("person.get_id_by_name", self._cap_person_get_id_by_name) - - # ── emoji.* ─────────────────────────────────────── - cap_service.register_capability("emoji.get_by_description", self._cap_emoji_get_by_description) - cap_service.register_capability("emoji.get_random", self._cap_emoji_get_random) - cap_service.register_capability("emoji.get_count", self._cap_emoji_get_count) - cap_service.register_capability("emoji.get_emotions", self._cap_emoji_get_emotions) - cap_service.register_capability("emoji.get_all", self._cap_emoji_get_all) - cap_service.register_capability("emoji.get_info", self._cap_emoji_get_info) - cap_service.register_capability("emoji.register", self._cap_emoji_register) - cap_service.register_capability("emoji.delete", self._cap_emoji_delete) - - # ── frequency.* ─────────────────────────────────── - cap_service.register_capability("frequency.get_current_talk_value", self._cap_frequency_get_current_talk_value) - cap_service.register_capability("frequency.set_adjust", self._cap_frequency_set_adjust) - cap_service.register_capability("frequency.get_adjust", self._cap_frequency_get_adjust) - - # ── tool.* ──────────────────────────────────────── - cap_service.register_capability("tool.get_definitions", self._cap_tool_get_definitions) - - # ── component.* ─────────────────────────────────── - cap_service.register_capability("component.get_all_plugins", self._cap_component_get_all_plugins) - cap_service.register_capability("component.get_plugin_info", self._cap_component_get_plugin_info) - cap_service.register_capability("component.list_loaded_plugins", self._cap_component_list_loaded_plugins) - cap_service.register_capability( - "component.list_registered_plugins", self._cap_component_list_registered_plugins - ) - cap_service.register_capability("component.enable", self._cap_component_enable) - cap_service.register_capability("component.disable", self._cap_component_disable) - cap_service.register_capability("component.load_plugin", self._cap_component_load_plugin) - cap_service.register_capability("component.unload_plugin", self._cap_component_unload_plugin) - cap_service.register_capability("component.reload_plugin", self._cap_component_reload_plugin) - - # ── knowledge.* ─────────────────────────────────── - cap_service.register_capability("knowledge.search", self._cap_knowledge_search) - - # 注意:logging.* 能力已移除——Runner 端通过 RunnerIPCLogHandler 将 stdlib - # logging 日志批量发送到 Host,由 RunnerLogBridge 重放到主进程 Logger, - # 不再需要单独的 logging.log RPC 能力。 - - logger.debug("已注册全部主程序能力实现") - - # ═════════════════════════════════════════════════════════ - # send.* 能力实现 - # ═════════════════════════════════════════════════════════ - - @staticmethod - async def _cap_send_text(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: - """发送文本消息 - - args: text, stream_id, typing?, set_reply?, storage_message? - """ - from src.services import send_service as send_api - - text: str = args.get("text", "") - stream_id: str = args.get("stream_id", "") - if not text or not stream_id: - return {"success": False, "error": "缺少必要参数 text 或 stream_id"} - - try: - result = await send_api.text_to_stream( - text=text, - stream_id=stream_id, - typing=args.get("typing", False), - set_reply=args.get("set_reply", False), - storage_message=args.get("storage_message", True), - ) - return {"success": result} - except Exception as e: - logger.error(f"[cap.send.text] 执行失败: {e}", exc_info=True) - return {"success": False, "error": str(e)} - - @staticmethod - async def _cap_send_emoji(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: - """发送表情 - - args: emoji_base64, stream_id, storage_message? - """ - from src.services import send_service as send_api - - emoji_base64: str = args.get("emoji_base64", "") - stream_id: str = args.get("stream_id", "") - if not emoji_base64 or not stream_id: - return {"success": False, "error": "缺少必要参数 emoji_base64 或 stream_id"} - - try: - result = await send_api.emoji_to_stream( - emoji_base64=emoji_base64, - stream_id=stream_id, - storage_message=args.get("storage_message", True), - ) - return {"success": result} - except Exception as e: - logger.error(f"[cap.send.emoji] 执行失败: {e}", exc_info=True) - return {"success": False, "error": str(e)} - - @staticmethod - async def _cap_send_image(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: - """发送图片 - - args: image_base64, stream_id, storage_message? - """ - from src.services import send_service as send_api - - image_base64: str = args.get("image_base64", "") - stream_id: str = args.get("stream_id", "") - if not image_base64 or not stream_id: - return {"success": False, "error": "缺少必要参数 image_base64 或 stream_id"} - - try: - result = await send_api.image_to_stream( - image_base64=image_base64, - stream_id=stream_id, - storage_message=args.get("storage_message", True), - ) - return {"success": result} - except Exception as e: - logger.error(f"[cap.send.image] 执行失败: {e}", exc_info=True) - return {"success": False, "error": str(e)} - - @staticmethod - async def _cap_send_command(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: - """发送命令 - - args: command, stream_id, storage_message?, display_message? - """ - from src.services import send_service as send_api - - command = args.get("command", "") - stream_id: str = args.get("stream_id", "") - if not command or not stream_id: - return {"success": False, "error": "缺少必要参数 command 或 stream_id"} - - try: - result = await send_api.command_to_stream( - command=command, - stream_id=stream_id, - storage_message=args.get("storage_message", True), - display_message=args.get("display_message", ""), - ) - return {"success": result} - except Exception as e: - logger.error(f"[cap.send.command] 执行失败: {e}", exc_info=True) - return {"success": False, "error": str(e)} - - @staticmethod - async def _cap_send_custom(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: - """发送自定义类型消息 - - args: message_type, content, stream_id, display_message?, typing?, storage_message? - """ - from src.services import send_service as send_api - - message_type: str = args.get("message_type", "") or args.get("custom_type", "") - content = args.get("content") - if content is None: - content = args.get("data", "") - stream_id: str = args.get("stream_id", "") - if not message_type or not stream_id: - return {"success": False, "error": "缺少必要参数 message_type 或 stream_id"} - - try: - result = await send_api.custom_to_stream( - message_type=message_type, - content=content, - stream_id=stream_id, - display_message=args.get("display_message", ""), - typing=args.get("typing", False), - storage_message=args.get("storage_message", True), - ) - return {"success": result} - except Exception as e: - logger.error(f"[cap.send.custom] 执行失败: {e}", exc_info=True) - return {"success": False, "error": str(e)} - - @staticmethod - async def _cap_send_forward(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: - """发送转发消息 - - args: messages, stream_id - """ - from src.services import send_service as send_api - - messages = args.get("messages", []) - stream_id: str = args.get("stream_id", "") - if not messages or not stream_id: - return {"success": False, "error": "缺少必要参数 messages 或 stream_id"} - - try: - result = await send_api.forward_to_stream( - messages=messages, - stream_id=stream_id, - ) - return {"success": result} - except Exception as e: - logger.error(f"[cap.send.forward] 执行失败: {e}", exc_info=True) - return {"success": False, "error": str(e)} - - @staticmethod - async def _cap_send_hybrid(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: - """发送混合消息(图文混合) - - args: segments, stream_id - """ - from src.services import send_service as send_api - - segments = args.get("segments", []) - stream_id: str = args.get("stream_id", "") - if not segments or not stream_id: - return {"success": False, "error": "缺少必要参数 segments 或 stream_id"} - - try: - result = await send_api.hybrid_to_stream( - segments=segments, - stream_id=stream_id, - ) - return {"success": result} - except Exception as e: - logger.error(f"[cap.send.hybrid] 执行失败: {e}", exc_info=True) - return {"success": False, "error": str(e)} - - # ═════════════════════════════════════════════════════════ - # llm.* 能力实现 - # ═════════════════════════════════════════════════════════ - - @staticmethod - async def _cap_llm_generate(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: - """LLM 生成 - - args: prompt, model|model_name?, temperature?, max_tokens? - """ - from src.services import llm_service as llm_api - - prompt: str = args.get("prompt", "") - if not prompt: - return {"success": False, "error": "缺少必要参数 prompt"} - - # 兼容 SDK 发送的 "model" 和旧版的 "model_name" - model_name: str = args.get("model", "") or args.get("model_name", "") - temperature = args.get("temperature") - max_tokens = args.get("max_tokens") - - try: - models = llm_api.get_available_models() - if model_name and model_name in models: - model_config = models[model_name] - else: - # 选取第一个可用模型配置 - if not models: - return {"success": False, "error": "没有可用的模型配置"} - model_config = next(iter(models.values())) - - success, response, reasoning, used_model = await llm_api.generate_with_model( - prompt=prompt, - model_config=model_config, - request_type=f"plugin.{plugin_id}", - temperature=temperature, - max_tokens=max_tokens, - ) - return { - "success": success, - "response": response, - "reasoning": reasoning, - "model_name": used_model, - } - except Exception as e: - logger.error(f"[cap.llm.generate] 执行失败: {e}", exc_info=True) - return {"success": False, "error": str(e)} - - @staticmethod - async def _cap_llm_generate_with_tools(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: - """LLM 带工具生成 - - args: prompt, model|model_name?, tools|tool_options?, temperature?, max_tokens? - """ - from src.services import llm_service as llm_api - - prompt: str = args.get("prompt", "") - if not prompt: - return {"success": False, "error": "缺少必要参数 prompt"} - - # 兼容 SDK 发送的 "model"/"tools" 和旧版的 "model_name"/"tool_options" - model_name: str = args.get("model", "") or args.get("model_name", "") - tool_options = args.get("tools") or args.get("tool_options") - temperature = args.get("temperature") - max_tokens = args.get("max_tokens") - - try: - models = llm_api.get_available_models() - if model_name and model_name in models: - model_config = models[model_name] - else: - if not models: - return {"success": False, "error": "没有可用的模型配置"} - model_config = next(iter(models.values())) - - success, response, reasoning, used_model, tool_calls = await llm_api.generate_with_model_with_tools( - prompt=prompt, - model_config=model_config, - tool_options=tool_options, - request_type=f"plugin.{plugin_id}", - temperature=temperature, - max_tokens=max_tokens, - ) - # 将 ToolCall 对象序列化为 dict - serialized_tool_calls = None - if tool_calls: - serialized_tool_calls = [ - {"id": tc.id, "function": {"name": tc.function.name, "arguments": tc.function.arguments}} - for tc in tool_calls - if hasattr(tc, "function") - ] - return { - "success": success, - "response": response, - "reasoning": reasoning, - "model_name": used_model, - "tool_calls": serialized_tool_calls, - } - except Exception as e: - logger.error(f"[cap.llm.generate_with_tools] 执行失败: {e}", exc_info=True) - return {"success": False, "error": str(e)} - - @staticmethod - async def _cap_llm_get_available_models(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: - """获取可用模型列表""" - from src.services import llm_service as llm_api - - try: - models = llm_api.get_available_models() - return {"success": True, "models": list(models.keys())} - except Exception as e: - logger.error(f"[cap.llm.get_available_models] 执行失败: {e}", exc_info=True) - return {"success": False, "error": str(e)} - - # ═════════════════════════════════════════════════════════ - # config.* 能力实现 - # ═════════════════════════════════════════════════════════ - - @staticmethod - async def _cap_config_get(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: - """读取全局配置 - - args: key, default? - """ - from src.services import config_service as config_api - - key: str = args.get("key", "") - default = args.get("default") - - if not key: - return {"success": False, "value": None, "error": "缺少必要参数 key"} - - try: - value = config_api.get_global_config(key, default) - return {"success": True, "value": value} - except Exception as e: - return {"success": False, "value": None, "error": str(e)} - - @staticmethod - async def _cap_config_get_plugin(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: - """读取插件配置 - - args: key, default?, plugin_name? - """ - from src.core.component_registry import component_registry as core_registry - - plugin_name: str = args.get("plugin_name", plugin_id) - key: str = args.get("key", "") - default = args.get("default") - - try: - config = core_registry.get_plugin_config(plugin_name) - if config is None: - return {"success": False, "value": default, "error": f"未找到插件 {plugin_name} 的配置"} - - if key: - from src.services import config_service as config_api - - value = config_api.get_plugin_config(config, key, default) - return {"success": True, "value": value} - - return {"success": True, "value": config} - except Exception as e: - return {"success": False, "value": default, "error": str(e)} - - @staticmethod - async def _cap_config_get_all(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: - """获取当前插件的全部配置""" - from src.core.component_registry import component_registry as core_registry - - plugin_name: str = args.get("plugin_name", plugin_id) - - try: - config = core_registry.get_plugin_config(plugin_name) - if config is None: - return {"success": True, "value": {}} - return {"success": True, "value": config} - except Exception as e: - return {"success": False, "value": {}, "error": str(e)} - - # ═════════════════════════════════════════════════════════ - # database.* 能力实现 - # ═════════════════════════════════════════════════════════ - - @staticmethod - async def _cap_database_query(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: - """数据库查询 - - args: model_name|table, query_type?, filters?, limit?, order_by?, data?, single_result? - model_name/table 应为 src.common.database.database_model 中的类名。 - """ - from src.services import database_service as database_api - - # 兼容 SDK 发送的 "table" 和旧版的 "model_name" - model_name: str = args.get("model_name", "") or args.get("table", "") - if not model_name: - return {"success": False, "error": "缺少必要参数 model_name 或 table"} - - try: - import src.common.database.database_model as db_models - - model_class = getattr(db_models, model_name, None) - if model_class is None: - return {"success": False, "error": f"未找到数据模型: {model_name}"} - - result = await database_api.db_query( - model_class=model_class, - data=args.get("data"), - query_type=args.get("query_type", "get"), - filters=args.get("filters"), - limit=args.get("limit"), - order_by=args.get("order_by"), - single_result=args.get("single_result", False), - ) - return {"success": True, "result": result} - except Exception as e: - logger.error(f"[cap.database.query] 执行失败: {e}", exc_info=True) - return {"success": False, "error": str(e)} - - @staticmethod - async def _cap_database_save(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: - """数据库保存 - - args: model_name|table, data, key_field?, key_value? - """ - from src.services import database_service as database_api - - # 兼容 SDK 发送的 "table" 和旧版的 "model_name" - model_name: str = args.get("model_name", "") or args.get("table", "") - data: Optional[Dict[str, Any]] = args.get("data") - if not model_name or not data: - return {"success": False, "error": "缺少必要参数 model_name/table 或 data"} - - try: - import src.common.database.database_model as db_models - - model_class = getattr(db_models, model_name, None) - if model_class is None: - return {"success": False, "error": f"未找到数据模型: {model_name}"} - - result = await database_api.db_save( - model_class=model_class, - data=data, - key_field=args.get("key_field"), - key_value=args.get("key_value"), - ) - return {"success": True, "result": result} - except Exception as e: - logger.error(f"[cap.database.save] 执行失败: {e}", exc_info=True) - return {"success": False, "error": str(e)} - - @staticmethod - async def _cap_database_get(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: - """数据库简单查询 - - args: model_name|table, filters?, key_field?, key_value?, limit?, order_by?, single_result? - """ - from src.services import database_service as database_api - - # 兼容 SDK 发送的 "table" 和旧版的 "model_name" - model_name: str = args.get("model_name", "") or args.get("table", "") - if not model_name: - return {"success": False, "error": "缺少必要参数 model_name 或 table"} - - try: - import src.common.database.database_model as db_models - - model_class = getattr(db_models, model_name, None) - if model_class is None: - return {"success": False, "error": f"未找到数据模型: {model_name}"} - - # 兼容 SDK 的 key_field/key_value 参数,自动转换为 filters - filters = args.get("filters") - key_value = args.get("key_value") - if not filters: - key_field = args.get("key_field", "id") - if key_value is not None: - filters = {key_field: key_value} - - result = await database_api.db_get( - model_class=model_class, - filters=filters, - limit=args.get("limit"), - order_by=args.get("order_by"), - single_result=args.get("single_result", key_value is not None), - ) - return {"success": True, "result": result} - except Exception as e: - logger.error(f"[cap.database.get] 执行失败: {e}", exc_info=True) - return {"success": False, "error": str(e)} - - @staticmethod - async def _cap_database_delete(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: - """数据库删除 - - args: model_name|table, filters - """ - from src.services import database_service as database_api - - # 兼容 SDK 发送的 "table" 和旧版的 "model_name" - model_name: str = args.get("model_name", "") or args.get("table", "") - filters = args.get("filters", {}) - if not model_name: - return {"success": False, "error": "缺少必要参数 model_name 或 table"} - if not filters: - return {"success": False, "error": "缺少必要参数 filters(不允许无条件删除)"} - - try: - import src.common.database.database_model as db_models - - model_class = getattr(db_models, model_name, None) - if model_class is None: - return {"success": False, "error": f"未找到数据模型: {model_name}"} - - result = await database_api.db_delete( - model_class=model_class, - filters=filters, - ) - return {"success": True, "result": result} - except Exception as e: - logger.error(f"[cap.database.delete] 执行失败: {e}", exc_info=True) - return {"success": False, "error": str(e)} - - @staticmethod - async def _cap_database_count(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: - """数据库计数 - - args: model_name|table, filters? - """ - from src.services import database_service as database_api - - # 兼容 SDK 发送的 "table" 和旧版的 "model_name" - model_name: str = args.get("model_name", "") or args.get("table", "") - if not model_name: - return {"success": False, "error": "缺少必要参数 model_name 或 table"} - - try: - import src.common.database.database_model as db_models - - model_class = getattr(db_models, model_name, None) - if model_class is None: - return {"success": False, "error": f"未找到数据模型: {model_name}"} - - result = await database_api.db_count( - model_class=model_class, - filters=args.get("filters"), - ) - return {"success": True, "count": result} - except Exception as e: - logger.error(f"[cap.database.count] 执行失败: {e}", exc_info=True) - return {"success": False, "error": str(e)} - - # ═════════════════════════════════════════════════════════ - # chat.* 能力实现 - # ═════════════════════════════════════════════════════════ - - @staticmethod - def _serialize_stream(stream: BotChatSession) -> Dict[str, Any]: - """将 BotChatSession 序列化为可通过 RPC 传输的字典""" - return { - "session_id": stream.session_id, - "platform": stream.platform, - "user_id": stream.user_id, - "group_id": stream.group_id, - "is_group_session": stream.is_group_session, - } - - @staticmethod - async def _cap_chat_get_all_streams(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: - """获取所有聊天流 - - args: platform? - """ - from src.services import chat_service as chat_api - - platform: str = args.get("platform", "qq") - try: - streams = chat_api.ChatManager.get_all_streams(platform=platform) - return { - "success": True, - "streams": [PluginRuntimeManager._serialize_stream(s) for s in streams], - } - except Exception as e: - logger.error(f"[cap.chat.get_all_streams] 执行失败: {e}", exc_info=True) - return {"success": False, "error": str(e)} - - @staticmethod - async def _cap_chat_get_group_streams(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: - """获取所有群聊流 - - args: platform? - """ - from src.services import chat_service as chat_api - - platform: str = args.get("platform", "qq") - try: - streams = chat_api.ChatManager.get_group_streams(platform=platform) - return { - "success": True, - "streams": [PluginRuntimeManager._serialize_stream(s) for s in streams], - } - except Exception as e: - logger.error(f"[cap.chat.get_group_streams] 执行失败: {e}", exc_info=True) - return {"success": False, "error": str(e)} - - @staticmethod - async def _cap_chat_get_private_streams(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: - """获取所有私聊流 - - args: platform? - """ - from src.services import chat_service as chat_api - - platform: str = args.get("platform", "qq") - try: - streams = chat_api.ChatManager.get_private_streams(platform=platform) - return { - "success": True, - "streams": [PluginRuntimeManager._serialize_stream(s) for s in streams], - } - except Exception as e: - logger.error(f"[cap.chat.get_private_streams] 执行失败: {e}", exc_info=True) - return {"success": False, "error": str(e)} - - @staticmethod - async def _cap_chat_get_stream_by_group_id(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: - """按群 ID 查找聊天流 - - args: group_id, platform? - """ - from src.services import chat_service as chat_api - - group_id: str = args.get("group_id", "") - if not group_id: - return {"success": False, "error": "缺少必要参数 group_id"} - - platform: str = args.get("platform", "qq") - try: - stream = chat_api.ChatManager.get_group_stream_by_group_id(group_id=group_id, platform=platform) - if stream is None: - return {"success": True, "stream": None} - return {"success": True, "stream": PluginRuntimeManager._serialize_stream(stream)} - except Exception as e: - logger.error(f"[cap.chat.get_stream_by_group_id] 执行失败: {e}", exc_info=True) - return {"success": False, "error": str(e)} - - @staticmethod - async def _cap_chat_get_stream_by_user_id(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: - """按用户 ID 查找私聊流 - - args: user_id, platform? - """ - from src.services import chat_service as chat_api - - user_id: str = args.get("user_id", "") - if not user_id: - return {"success": False, "error": "缺少必要参数 user_id"} - - platform: str = args.get("platform", "qq") - try: - stream = chat_api.ChatManager.get_private_stream_by_user_id(user_id=user_id, platform=platform) - if stream is None: - return {"success": True, "stream": None} - return {"success": True, "stream": PluginRuntimeManager._serialize_stream(stream)} - except Exception as e: - logger.error(f"[cap.chat.get_stream_by_user_id] 执行失败: {e}", exc_info=True) - return {"success": False, "error": str(e)} - - # ═════════════════════════════════════════════════════════ - # message.* 能力实现 - # ═════════════════════════════════════════════════════════ - - @staticmethod - def _serialize_messages(messages: list) -> List[Dict[str, Any]]: - """将 DatabaseMessages 列表序列化为 dict 列表""" - result: List[Dict[str, Any]] = [] - for msg in messages: - if hasattr(msg, "model_dump"): - result.append(msg.model_dump()) - elif hasattr(msg, "__dict__"): - result.append(dict(msg.__dict__)) - else: - result.append(str(msg)) - return result - - @staticmethod - async def _cap_message_get_by_time(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: - """按时间范围查询消息 - - args: start_time, end_time, limit?, filter_mai? - """ - from src.services import message_service as message_api - - start_time = args.get("start_time", 0.0) - end_time = args.get("end_time", 0.0) - - try: - messages = message_api.get_messages_by_time( - start_time=float(start_time), - end_time=float(end_time), - limit=args.get("limit", 0), - limit_mode=args.get("limit_mode", "latest"), - filter_mai=args.get("filter_mai", False), - ) - return {"success": True, "messages": PluginRuntimeManager._serialize_messages(messages)} - except Exception as e: - logger.error(f"[cap.message.get_by_time] 执行失败: {e}", exc_info=True) - return {"success": False, "error": str(e)} - - @staticmethod - async def _cap_message_get_by_time_in_chat(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: - """按时间范围查询指定聊天消息 - - args: chat_id, start_time, end_time, limit?, filter_mai?, filter_command? - """ - from src.services import message_service as message_api - - chat_id: str = args.get("chat_id", "") - if not chat_id: - return {"success": False, "error": "缺少必要参数 chat_id"} - - try: - messages = message_api.get_messages_by_time_in_chat( - chat_id=chat_id, - start_time=float(args.get("start_time", 0.0)), - end_time=float(args.get("end_time", 0.0)), - limit=args.get("limit", 0), - limit_mode=args.get("limit_mode", "latest"), - filter_mai=args.get("filter_mai", False), - filter_command=args.get("filter_command", False), - ) - return {"success": True, "messages": PluginRuntimeManager._serialize_messages(messages)} - except Exception as e: - logger.error(f"[cap.message.get_by_time_in_chat] 执行失败: {e}", exc_info=True) - return {"success": False, "error": str(e)} - - @staticmethod - async def _cap_message_get_recent(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: - """获取最近的消息 - - args: chat_id, hours?, limit?, filter_mai? - """ - from src.services import message_service as message_api - - chat_id: str = args.get("chat_id", "") - if not chat_id: - return {"success": False, "error": "缺少必要参数 chat_id"} - - try: - messages = message_api.get_recent_messages( - chat_id=chat_id, - hours=float(args.get("hours", 24.0)), - limit=args.get("limit", 100), - limit_mode=args.get("limit_mode", "latest"), - filter_mai=args.get("filter_mai", False), - ) - return {"success": True, "messages": PluginRuntimeManager._serialize_messages(messages)} - except Exception as e: - logger.error(f"[cap.message.get_recent] 执行失败: {e}", exc_info=True) - return {"success": False, "error": str(e)} - - @staticmethod - async def _cap_message_count_new(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: - """统计新消息数量 - - args: chat_id, since? | start_time?, end_time? - """ - from src.services import message_service as message_api - - chat_id: str = args.get("chat_id", "") - if not chat_id: - return {"success": False, "error": "缺少必要参数 chat_id"} - - try: - # 兼容 SDK 传 since 和直接传 start_time 两种方式 - since = args.get("since") - start_time = float(since) if since is not None else float(args.get("start_time", 0.0)) - count = message_api.count_new_messages( - chat_id=chat_id, - start_time=start_time, - end_time=args.get("end_time"), - ) - return {"success": True, "count": count} - except Exception as e: - logger.error(f"[cap.message.count_new] 执行失败: {e}", exc_info=True) - return {"success": False, "error": str(e)} - - @staticmethod - async def _cap_message_build_readable(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: - """将消息列表构建成可读字符串 - - 支持两种调用方式: - 1. SDK 方式: 传入 messages(已查询的消息列表) - 2. 直接方式: 传入 chat_id + start_time + end_time(Host 端查询) - - args: messages? | chat_id?, start_time?, end_time?, limit?, replace_bot_name?, timestamp_mode? - """ - from src.services import message_service as message_api - - try: - # 优先使用调用方已提供的消息列表 - messages = args.get("messages") - if messages is None: - # 回退到 chat_id + 时间范围查询 - chat_id: str = args.get("chat_id", "") - if not chat_id: - return {"success": False, "error": "缺少必要参数: messages 或 chat_id"} - messages = message_api.get_messages_by_time_in_chat( - chat_id=chat_id, - start_time=float(args.get("start_time", 0.0)), - end_time=float(args.get("end_time", 0.0)), - limit=args.get("limit", 0), - ) - - readable = message_api.build_readable_messages_to_str( - messages=messages, - replace_bot_name=args.get("replace_bot_name", True), - timestamp_mode=args.get("timestamp_mode", "relative"), - truncate=args.get("truncate", False), - ) - return {"success": True, "text": readable} - except Exception as e: - logger.error(f"[cap.message.build_readable] 执行失败: {e}", exc_info=True) - return {"success": False, "error": str(e)} - - # ═════════════════════════════════════════════════════════ - # person.* 能力实现 - # ═════════════════════════════════════════════════════════ - - @staticmethod - async def _cap_person_get_id(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: - """获取 person_id - - args: platform, user_id - """ - from src.services import person_service as person_api - - platform: str = args.get("platform", "") - user_id = args.get("user_id", "") - if not platform or not user_id: - return {"success": False, "error": "缺少必要参数 platform 或 user_id"} - - try: - pid = person_api.get_person_id(platform=platform, user_id=user_id) - return {"success": True, "person_id": pid} - except Exception as e: - logger.error(f"[cap.person.get_id] 执行失败: {e}", exc_info=True) - return {"success": False, "error": str(e)} - - @staticmethod - async def _cap_person_get_value(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: - """获取用户字段值 - - args: person_id, field_name, default? - """ - from src.services import person_service as person_api - - person_id: str = args.get("person_id", "") - field_name: str = args.get("field_name", "") - if not person_id or not field_name: - return {"success": False, "error": "缺少必要参数 person_id 或 field_name"} - - try: - value = await person_api.get_person_value( - person_id=person_id, - field_name=field_name, - default=args.get("default"), - ) - return {"success": True, "value": value} - except Exception as e: - logger.error(f"[cap.person.get_value] 执行失败: {e}", exc_info=True) - return {"success": False, "error": str(e)} - - @staticmethod - async def _cap_person_get_id_by_name(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: - """根据用户名获取 person_id - - args: person_name - """ - from src.services import person_service as person_api - - person_name: str = args.get("person_name", "") - if not person_name: - return {"success": False, "error": "缺少必要参数 person_name"} - - try: - pid = person_api.get_person_id_by_name(person_name=person_name) - return {"success": True, "person_id": pid} - except Exception as e: - logger.error(f"[cap.person.get_id_by_name] 执行失败: {e}", exc_info=True) - return {"success": False, "error": str(e)} - - # ═════════════════════════════════════════════════════════ - # emoji.* 能力实现 - # ═════════════════════════════════════════════════════════ - - @staticmethod - async def _cap_emoji_get_by_description(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: - """根据描述获取表情包 - - args: description - """ - from src.services import emoji_service as emoji_api - - description: str = args.get("description", "") - if not description: - return {"success": False, "error": "缺少必要参数 description"} - - try: - result = await emoji_api.get_by_description(description=description) - if result is None: - return {"success": True, "emoji": None} - emoji_base64, emoji_desc, matched_emotion = result - return { - "success": True, - "emoji": { - "base64": emoji_base64, - "description": emoji_desc, - "emotion": matched_emotion, - }, - } - except Exception as e: - logger.error(f"[cap.emoji.get_by_description] 执行失败: {e}", exc_info=True) - return {"success": False, "error": str(e)} - - @staticmethod - async def _cap_emoji_get_random(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: - """随机获取表情包 - - args: count? - """ - from src.services import emoji_service as emoji_api - - count: int = args.get("count", 1) - try: - results = await emoji_api.get_random(count=count) - emojis = [{"base64": b64, "description": desc, "emotion": emo} for b64, desc, emo in results] - return {"success": True, "emojis": emojis} - except Exception as e: - logger.error(f"[cap.emoji.get_random] 执行失败: {e}", exc_info=True) - return {"success": False, "error": str(e)} - - @staticmethod - async def _cap_emoji_get_count(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: - """获取表情包数量""" - from src.services import emoji_service as emoji_api - - try: - return {"success": True, "count": emoji_api.get_count()} - except Exception as e: - logger.error(f"[cap.emoji.get_count] 执行失败: {e}", exc_info=True) - return {"success": False, "error": str(e)} - - @staticmethod - async def _cap_emoji_get_emotions(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: - """获取所有情绪标签""" - from src.services import emoji_service as emoji_api - - try: - return {"success": True, "emotions": emoji_api.get_emotions()} - except Exception as e: - logger.error(f"[cap.emoji.get_emotions] 执行失败: {e}", exc_info=True) - return {"success": False, "error": str(e)} - - @staticmethod - async def _cap_emoji_get_all(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: - """获取所有表情包""" - from src.services import emoji_service as emoji_api - - try: - results = await emoji_api.get_all() - emojis = ( - [{"base64": b64, "description": desc, "emotion": emo} for b64, desc, emo in results] if results else [] - ) - return {"success": True, "emojis": emojis} - except Exception as e: - logger.error(f"[cap.emoji.get_all] 执行失败: {e}", exc_info=True) - return {"success": False, "error": str(e)} - - @staticmethod - async def _cap_emoji_get_info(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: - """获取表情包统计信息""" - from src.services import emoji_service as emoji_api - - try: - return {"success": True, "info": emoji_api.get_info()} - except Exception as e: - logger.error(f"[cap.emoji.get_info] 执行失败: {e}", exc_info=True) - return {"success": False, "error": str(e)} - - @staticmethod - async def _cap_emoji_register(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: - """注册表情包 - - args: emoji_base64 - """ - from src.services import emoji_service as emoji_api - - emoji_base64: str = args.get("emoji_base64", "") - if not emoji_base64: - return {"success": False, "error": "缺少必要参数 emoji_base64"} - - try: - result = await emoji_api.register_emoji(emoji_base64) - return result - except Exception as e: - logger.error(f"[cap.emoji.register] 执行失败: {e}", exc_info=True) - return {"success": False, "error": str(e)} - - @staticmethod - async def _cap_emoji_delete(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: - """删除表情包 - - args: emoji_hash - """ - from src.services import emoji_service as emoji_api - - emoji_hash: str = args.get("emoji_hash", "") - if not emoji_hash: - return {"success": False, "error": "缺少必要参数 emoji_hash"} - - try: - result = await emoji_api.delete_emoji(emoji_hash) - return result - except Exception as e: - logger.error(f"[cap.emoji.delete] 执行失败: {e}", exc_info=True) - return {"success": False, "error": str(e)} - - # ═════════════════════════════════════════════════════════ - # frequency.* 能力实现 - # ═════════════════════════════════════════════════════════ - - @staticmethod - async def _cap_frequency_get_current_talk_value(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: - """获取当前说话频率值 - - args: chat_id - """ - from src.services import frequency_service as frequency_api - - chat_id: str = args.get("chat_id", "") - if not chat_id: - return {"success": False, "error": "缺少必要参数 chat_id"} - - try: - value = frequency_api.get_current_talk_value(chat_id) - return {"success": True, "value": value} - except Exception as e: - logger.error(f"[cap.frequency.get_current_talk_value] 执行失败: {e}", exc_info=True) - return {"success": False, "error": str(e)} - - @staticmethod - async def _cap_frequency_set_adjust(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: - """设置说话频率调整值 - - args: chat_id, value - """ - from src.services import frequency_service as frequency_api - - chat_id: str = args.get("chat_id", "") - value = args.get("value") - if not chat_id or value is None: - return {"success": False, "error": "缺少必要参数 chat_id 或 value"} - - try: - frequency_api.set_talk_frequency_adjust(chat_id, float(value)) - return {"success": True} - except Exception as e: - logger.error(f"[cap.frequency.set_adjust] 执行失败: {e}", exc_info=True) - return {"success": False, "error": str(e)} - - @staticmethod - async def _cap_frequency_get_adjust(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: - """获取说话频率调整值 - - args: chat_id - """ - from src.services import frequency_service as frequency_api - - chat_id: str = args.get("chat_id", "") - if not chat_id: - return {"success": False, "error": "缺少必要参数 chat_id"} - - try: - value = frequency_api.get_talk_frequency_adjust(chat_id) - return {"success": True, "value": value} - except Exception as e: - logger.error(f"[cap.frequency.get_adjust] 执行失败: {e}", exc_info=True) - return {"success": False, "error": str(e)} - - # ═════════════════════════════════════════════════════════ - # tool.* 能力实现 - # ═════════════════════════════════════════════════════════ - - @staticmethod - async def _cap_tool_get_definitions(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: - """获取 LLM 可用的工具定义列表""" - from src.core.component_registry import component_registry as core_registry - - try: - tools = core_registry.get_llm_available_tools() - return { - "success": True, - "tools": [{"name": name, "definition": info.get_llm_definition()} for name, info in tools.items()], - } - except Exception as e: - logger.error(f"[cap.tool.get_definitions] 执行失败: {e}", exc_info=True) - return {"success": False, "error": str(e)} - - # ═════════════════════════════════════════════════════════ - # component.* 能力实现 - # ═════════════════════════════════════════════════════════ - - @staticmethod - async def _cap_component_get_all_plugins(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: - """获取所有插件信息(汇总所有 Supervisor 的注册信息,包含组件列表)""" - mgr = get_plugin_runtime_manager() - result: Dict[str, Any] = {} - for sv in mgr.supervisors: - for pid, reg in sv._registered_plugins.items(): - if pid in result: - logger.error(f"检测到重复插件 ID {pid},component.get_all_plugins 结果已拒绝聚合") - return {"success": False, "error": f"检测到重复插件 ID: {pid}"} - # 从 ComponentRegistry 中获取该插件的所有组件 - comps = sv.component_registry.get_components_by_plugin(pid, enabled_only=False) - components_list = [ - { - "name": c.name, - "full_name": c.full_name, - "type": c.component_type, - "enabled": c.enabled, - "metadata": c.metadata, - } - for c in comps - ] - result[pid] = { - "name": pid, - "version": reg.plugin_version, - "description": "", - "author": "", - "enabled": True, - "components": components_list, - } - return {"success": True, "plugins": result} - - @staticmethod - async def _cap_component_get_plugin_info(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: - """获取指定插件信息 - - args: plugin_name - """ - plugin_name: str = args.get("plugin_name", plugin_id) - mgr = get_plugin_runtime_manager() - try: - sv = mgr._get_supervisor_for_plugin(plugin_name) - except RuntimeError as exc: - return {"success": False, "error": str(exc)} - - if sv is not None: - reg = sv._registered_plugins.get(plugin_name) - if reg is not None: - return { - "success": True, - "plugin": { - "name": plugin_name, - "version": reg.plugin_version, - "description": "", - "author": "", - "enabled": True, - }, - } - return {"success": False, "error": f"未找到插件: {plugin_name}"} - - @staticmethod - async def _cap_component_list_loaded_plugins(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: - """列出已加载的插件""" - mgr = get_plugin_runtime_manager() - plugins: List[str] = [] - for sv in mgr.supervisors: - plugins.extend(sv._registered_plugins.keys()) - return {"success": True, "plugins": plugins} - - @staticmethod - async def _cap_component_list_registered_plugins(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: - """列出已注册的插件(同 list_loaded)""" - mgr = get_plugin_runtime_manager() - plugins: List[str] = [] - for sv in mgr.supervisors: - plugins.extend(sv._registered_plugins.keys()) - return {"success": True, "plugins": plugins} - - @staticmethod - def _resolve_component_toggle_target(name: str, component_type: str) -> tuple[Optional[Any], Optional[str]]: - """解析组件启停目标。 - - 支持全名 plugin_id.component_name;短名仅在全局唯一时允许, - 否则返回歧义错误,避免跨 Supervisor 误操作。 - """ - mgr = get_plugin_runtime_manager() - short_name_matches: List[Any] = [] - - for sv in mgr.supervisors: - comp = sv.component_registry.get_component(name) - if comp is not None and comp.component_type == component_type: - return comp, None - - short_name_matches.extend( - candidate - for candidate in sv.component_registry.get_components_by_type(component_type, enabled_only=False) - if candidate.name == name - ) - - if len(short_name_matches) == 1: - return short_name_matches[0], None - if len(short_name_matches) > 1: - return None, f"组件名不唯一: {name} ({component_type}),请使用完整名 plugin_id.component_name" - return None, f"未找到组件: {name} ({component_type})" - - @staticmethod - async def _cap_component_enable(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: - """启用组件 - - args: name, component_type, scope, stream_id - """ - name: str = args.get("name", "") - component_type: str = args.get("component_type", "") - scope: str = args.get("scope", "global") - stream_id: str = args.get("stream_id", "") - if not name or not component_type: - return {"success": False, "error": "缺少必要参数 name 或 component_type"} - if scope != "global" or stream_id: - return {"success": False, "error": "当前仅支持全局组件启用,不支持 scope/stream_id 定位"} - - comp, error = PluginRuntimeManager._resolve_component_toggle_target(name, component_type) - if comp is None: - return {"success": False, "error": error or f"未找到组件: {name} ({component_type})"} - - comp.enabled = True - return {"success": True} - - @staticmethod - async def _cap_component_disable(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: - """禁用组件 - - args: name, component_type, scope, stream_id - """ - name: str = args.get("name", "") - component_type: str = args.get("component_type", "") - scope: str = args.get("scope", "global") - stream_id: str = args.get("stream_id", "") - if not name or not component_type: - return {"success": False, "error": "缺少必要参数 name 或 component_type"} - if scope != "global" or stream_id: - return {"success": False, "error": "当前仅支持全局组件禁用,不支持 scope/stream_id 定位"} - - comp, error = PluginRuntimeManager._resolve_component_toggle_target(name, component_type) - if comp is None: - return {"success": False, "error": error or f"未找到组件: {name} ({component_type})"} - - comp.enabled = False - return {"success": True} - - @staticmethod - async def _cap_component_load_plugin(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: - """加载插件(在新运行时中通过热重载实现) - - 先验证目标插件是否已注册或插件目录是否存在于某个 Supervisor, - 然后只对拥有该插件的 Supervisor 执行热重载。 - - args: plugin_name - """ - plugin_name: str = args.get("plugin_name", "") - if not plugin_name: - return {"success": False, "error": "缺少必要参数 plugin_name"} - - import os - - mgr = get_plugin_runtime_manager() - if duplicate_plugin_ids := mgr._find_duplicate_plugin_ids(list(mgr._iter_plugin_dirs())): - details = "; ".join( - f"{conflict_plugin_id}: {', '.join(paths)}" - for conflict_plugin_id, paths in sorted(duplicate_plugin_ids.items()) - ) - return {"success": False, "error": f"检测到重复插件 ID,拒绝热重载: {details}"} - - # 优先查找已注册该插件的 Supervisor - try: - registered_supervisor = mgr._get_supervisor_for_plugin(plugin_name) - except RuntimeError as exc: - return {"success": False, "error": str(exc)} - - if registered_supervisor is not None: - try: - reloaded = await registered_supervisor.reload_plugins(reason=f"load {plugin_name}") - if reloaded: - return {"success": True, "count": 1} - return {"success": False, "error": f"插件 {plugin_name} 热重载失败,已回滚"} - except Exception as e: - logger.error(f"[cap.component.load_plugin] 热重载失败: {e}") - return {"success": False, "error": str(e)} - - # 插件尚未注册,检查是否有 Supervisor 的 plugin_dirs 下包含该插件目录 - for sv in mgr.supervisors: - for pdir in sv._plugin_dirs: - if os.path.isdir(os.path.join(pdir, plugin_name)): - try: - reloaded = await sv.reload_plugins(reason=f"load {plugin_name}") - if reloaded: - return {"success": True, "count": 1} - return {"success": False, "error": f"插件 {plugin_name} 热重载失败,已回滚"} - except Exception as e: - logger.error(f"[cap.component.load_plugin] 热重载失败: {e}") - return {"success": False, "error": str(e)} - - return {"success": False, "error": f"未找到插件: {plugin_name}"} - - @staticmethod - async def _cap_component_unload_plugin(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: - """卸载插件(在新运行时中不支持单独卸载) - - args: plugin_name - """ - return {"success": False, "error": "新运行时不支持单独卸载插件,请使用 reload"} - - @staticmethod - async def _cap_component_reload_plugin(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: - """重新加载插件(触发对应 Supervisor 的热重载) - - args: plugin_name - """ - plugin_name: str = args.get("plugin_name", "") - if not plugin_name: - return {"success": False, "error": "缺少必要参数 plugin_name"} - - mgr = get_plugin_runtime_manager() - if duplicate_plugin_ids := mgr._find_duplicate_plugin_ids(list(mgr._iter_plugin_dirs())): - details = "; ".join( - f"{conflict_plugin_id}: {', '.join(paths)}" - for conflict_plugin_id, paths in sorted(duplicate_plugin_ids.items()) - ) - return {"success": False, "error": f"检测到重复插件 ID,拒绝热重载: {details}"} - - try: - sv = mgr._get_supervisor_for_plugin(plugin_name) - except RuntimeError as exc: - return {"success": False, "error": str(exc)} - - if sv is not None: - try: - reloaded = await sv.reload_plugins(reason=f"reload {plugin_name}") - if reloaded: - return {"success": True} - return {"success": False, "error": f"插件 {plugin_name} 热重载失败,已回滚"} - except Exception as e: - logger.error(f"[cap.component.reload_plugin] 热重载失败: {e}") - return {"success": False, "error": str(e)} - return {"success": False, "error": f"未找到插件: {plugin_name}"} - - # ═════════════════════════════════════════════════════════ - # knowledge.* 能力实现 - # ═════════════════════════════════════════════════════════ - - @staticmethod - async def _cap_knowledge_search(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: - """从 LPMM 知识库搜索知识 - - args: query, limit? - """ - query: str = args.get("query", "") - if not query: - return {"success": False, "error": "缺少必要参数 query"} - - limit = args.get("limit", 5) - try: - limit_value = max(1, int(limit)) - except (TypeError, ValueError): - limit_value = 5 - - try: - from src.chat.knowledge import qa_manager - - if qa_manager is None: - return {"success": True, "content": "LPMM知识库已禁用"} - - knowledge_info = await qa_manager.get_knowledge(query, limit=limit_value) - if knowledge_info: - content = f"你知道这些知识: {knowledge_info}" - else: - content = f"你不太了解有关{query}的知识" - return {"success": True, "content": content} - except Exception as e: - logger.error(f"[cap.knowledge.search] 执行失败: {e}", exc_info=True) - return {"success": False, "error": str(e)} + def _register_capability_impls(self, supervisor: "PluginSupervisor") -> None: + register_capability_impls(self, supervisor) # ─── 单例 ────────────────────────────────────────────────── diff --git a/src/services/chat_service.py b/src/services/chat_service.py index 35324bd9..effb0b93 100644 --- a/src/services/chat_service.py +++ b/src/services/chat_service.py @@ -4,7 +4,7 @@ 提供聊天信息查询和管理的核心功能。 """ -from typing import Any, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional from enum import Enum from src.chat.message_receive.chat_manager import BotChatSession, chat_manager as _chat_manager @@ -23,50 +23,58 @@ class ChatManager: """聊天管理器 - 负责聊天信息的查询和管理""" @staticmethod - def get_all_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]: - # sourcery skip: for-append-to-extend + def _validate_platform(platform: Optional[str] | SpecialTypes) -> None: if not isinstance(platform, (str, SpecialTypes)): raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举") - streams = [] + + @staticmethod + def _match_platform(chat_stream: BotChatSession, platform: Optional[str] | SpecialTypes) -> bool: + return platform == SpecialTypes.ALL_PLATFORMS or chat_stream.platform == platform + + @staticmethod + def _get_streams( + platform: Optional[str] | SpecialTypes = "qq", is_group_session: Optional[bool] = None + ) -> List[BotChatSession]: + ChatManager._validate_platform(platform) + try: - for _, stream in _chat_manager.sessions.items(): - if platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform: - streams.append(stream) - logger.debug(f"[ChatService] 获取到 {len(streams)} 个 {platform} 平台的聊天流") + streams = [ + stream + for stream in _chat_manager.sessions.values() + if ChatManager._match_platform(stream, platform) + and (is_group_session is None or stream.is_group_session == is_group_session) + ] + return streams except Exception as e: logger.error(f"[ChatService] 获取聊天流失败: {e}") + return [] + + @staticmethod + def _find_stream( + predicate: Callable[[BotChatSession], bool], + platform: Optional[str] | SpecialTypes = "qq", + ) -> Optional[BotChatSession]: + for stream in ChatManager._get_streams(platform=platform): + if predicate(stream): + return stream + return None + + @staticmethod + def get_all_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]: + streams = ChatManager._get_streams(platform=platform) + logger.debug(f"[ChatService] 获取到 {len(streams)} 个 {platform} 平台的聊天流") return streams @staticmethod def get_group_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]: - # sourcery skip: for-append-to-extend - if not isinstance(platform, (str, SpecialTypes)): - raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举") - streams = [] - try: - for _, stream in _chat_manager.sessions.items(): - if (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform) and stream.is_group_session: - streams.append(stream) - logger.debug(f"[ChatService] 获取到 {len(streams)} 个 {platform} 平台的群聊流") - except Exception as e: - logger.error(f"[ChatService] 获取群聊流失败: {e}") + streams = ChatManager._get_streams(platform=platform, is_group_session=True) + logger.debug(f"[ChatService] 获取到 {len(streams)} 个 {platform} 平台的群聊流") return streams @staticmethod def get_private_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]: - # sourcery skip: for-append-to-extend - if not isinstance(platform, (str, SpecialTypes)): - raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举") - streams = [] - try: - for _, stream in _chat_manager.sessions.items(): - if ( - platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform - ) and not stream.is_group_session: - streams.append(stream) - logger.debug(f"[ChatService] 获取到 {len(streams)} 个 {platform} 平台的私聊流") - except Exception as e: - logger.error(f"[ChatService] 获取私聊流失败: {e}") + streams = ChatManager._get_streams(platform=platform, is_group_session=False) + logger.debug(f"[ChatService] 获取到 {len(streams)} 个 {platform} 平台的私聊流") return streams @staticmethod @@ -75,19 +83,17 @@ class ChatManager: ) -> Optional[BotChatSession]: # sourcery skip: remove-unnecessary-cast if not isinstance(group_id, str): raise TypeError("group_id 必须是字符串类型") - if not isinstance(platform, (str, SpecialTypes)): - raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举") + ChatManager._validate_platform(platform) if not group_id: raise ValueError("group_id 不能为空") try: - for _, stream in _chat_manager.sessions.items(): - if ( - stream.is_group_session - and str(stream.group_id) == str(group_id) - and (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform) - ): - logger.debug(f"[ChatService] 找到群ID {group_id} 的聊天流") - return stream + stream = ChatManager._find_stream( + lambda item: item.is_group_session and str(item.group_id) == str(group_id), + platform=platform, + ) + if stream is not None: + logger.debug(f"[ChatService] 找到群ID {group_id} 的聊天流") + return stream logger.warning(f"[ChatService] 未找到群ID {group_id} 的聊天流") except Exception as e: logger.error(f"[ChatService] 查找群聊流失败: {e}") @@ -99,19 +105,17 @@ class ChatManager: ) -> Optional[BotChatSession]: # sourcery skip: remove-unnecessary-cast if not isinstance(user_id, str): raise TypeError("user_id 必须是字符串类型") - if not isinstance(platform, (str, SpecialTypes)): - raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举") + ChatManager._validate_platform(platform) if not user_id: raise ValueError("user_id 不能为空") try: - for _, stream in _chat_manager.sessions.items(): - if ( - not stream.is_group_session - and str(stream.user_id) == str(user_id) - and (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform) - ): - logger.debug(f"[ChatService] 找到用户ID {user_id} 的私聊流") - return stream + stream = ChatManager._find_stream( + lambda item: (not item.is_group_session) and str(item.user_id) == str(user_id), + platform=platform, + ) + if stream is not None: + logger.debug(f"[ChatService] 找到用户ID {user_id} 的私聊流") + return stream logger.warning(f"[ChatService] 未找到用户ID {user_id} 的私聊流") except Exception as e: logger.error(f"[ChatService] 查找私聊流失败: {e}") diff --git a/src/services/database_service.py b/src/services/database_service.py index 31705745..8d9192d2 100644 --- a/src/services/database_service.py +++ b/src/services/database_service.py @@ -1,13 +1,16 @@ -"""数据库服务模块 - -提供数据库操作相关的核心功能。 -""" +"""数据库服务模块。""" import json import time import traceback +from datetime import datetime from typing import Any, Optional +from sqlalchemy import delete, func, select +from sqlmodel import SQLModel + +from src.common.database.database import get_db_session +from src.common.database.database_model import ActionRecord from src.common.logger import get_logger logger = get_logger("database_service") @@ -25,73 +28,57 @@ def _to_dict(record: Any) -> dict[str, Any]: return {} -async def db_query( - model_class, - data: Optional[dict[str, Any]] = None, - query_type: str = "get", - filters: Optional[dict[str, Any]] = None, - limit: Optional[int] = None, - order_by: Optional[list[str]] = None, - single_result: bool = False, +def _get_model_field(model_class: type[SQLModel], field_name: str): + field = getattr(model_class, field_name, None) + if field is None: + raise ValueError(f"{model_class.__name__} 不存在字段 {field_name}") + return field + + +def _build_filters(model_class: type[SQLModel], filters: Optional[dict[str, Any]] = None) -> list[Any]: + if not filters: + return [] + return [_get_model_field(model_class, field_name) == value for field_name, value in filters.items()] + + +def _apply_order_by(statement, model_class: type[SQLModel], order_by: Optional[str | list[str]] = None): + if not order_by: + return statement + + order_fields = [order_by] if isinstance(order_by, str) else order_by + clauses = [] + for item in order_fields: + descending = item.startswith("-") + field_name = item[1:] if descending else item + field = _get_model_field(model_class, field_name) + clauses.append(field.desc() if descending else field.asc()) + return statement.order_by(*clauses) + + +async def db_save( + model_class: type[SQLModel], + data: dict[str, Any], + key_field: Optional[str] = None, + key_value: Optional[Any] = None, ): try: - if query_type not in ["get", "create", "update", "delete", "count"]: - raise ValueError("query_type must be 'get' or 'create' or 'update' or 'delete' or 'count'") + with get_db_session() as session: + record = None + if key_field and key_value is not None: + key_column = _get_model_field(model_class, key_field) + record = session.exec(select(model_class).where(key_column == key_value)).first() - if query_type == "get": - query = model_class.select() - if filters: - for field, value in filters.items(): - query = query.where(getattr(model_class, field) == value) - if order_by: - query = query.order_by(*order_by) - if limit: - query = query.limit(limit) - results = list(query.dicts()) - if single_result: - return results[0] if results else None - return results + if record is None: + record = model_class(**data) + else: + for field_name, value in data.items(): + _get_model_field(model_class, field_name) + setattr(record, field_name, value) - if query_type == "create": - if not data: - raise ValueError("创建记录需要提供data参数") - record = model_class.create(**data) + session.add(record) + session.flush() + session.refresh(record) return _to_dict(record) - - query = model_class.select() - if filters: - for field, value in filters.items(): - query = query.where(getattr(model_class, field) == value) - - if query_type == "update": - if not data: - raise ValueError("更新记录需要提供data参数") - return query.model_class.update(**data).where(*query.stmt._where_criteria).execute() - - if query_type == "delete": - return model_class.delete().where(*query.stmt._where_criteria).execute() - - return query.count() - except Exception as e: - logger.error(f"[DatabaseService] 数据库操作出错: {e}") - traceback.print_exc() - if query_type == "get": - return None if single_result else [] - return None - - -async def db_save(model_class, data: dict[str, Any], key_field: Optional[str] = None, key_value: Optional[Any] = None): - try: - if key_field and key_value is not None: - record = model_class.get_or_none(getattr(model_class, key_field) == key_value) - if record is not None: - for field, value in data.items(): - setattr(record, field, value) - record.save() - return _to_dict(record) - - new_record = model_class.create(**data) - return _to_dict(new_record) except Exception as e: logger.error(f"[DatabaseService] 保存数据库记录出错: {e}") traceback.print_exc() @@ -99,68 +86,108 @@ async def db_save(model_class, data: dict[str, Any], key_field: Optional[str] = async def db_get( - model_class, + model_class: type[SQLModel], filters: Optional[dict[str, Any]] = None, limit: Optional[int] = None, - order_by: Optional[str] = None, + order_by: Optional[str | list[str]] = None, single_result: bool = False, ): try: - query = model_class.select() - if filters: - for field, value in filters.items(): - query = query.where(getattr(model_class, field) == value) - if order_by: - query = query.order_by(order_by) - if limit: - query = query.limit(limit) - results = list(query.dicts()) - if single_result: - return results[0] if results else None - return results + with get_db_session(auto_commit=False) as session: + statement = select(model_class) + conditions = _build_filters(model_class, filters) + if conditions: + statement = statement.where(*conditions) + statement = _apply_order_by(statement, model_class, order_by) + if limit: + statement = statement.limit(limit) + results = session.exec(statement).all() + data = [_to_dict(item) for item in results] + if single_result: + return data[0] if data else None + return data except Exception as e: logger.error(f"[DatabaseService] 获取数据库记录出错: {e}") traceback.print_exc() return None if single_result else [] +async def db_update(model_class: type[SQLModel], data: dict[str, Any], filters: Optional[dict[str, Any]] = None) -> int: + try: + with get_db_session() as session: + statement = select(model_class) + conditions = _build_filters(model_class, filters) + if conditions: + statement = statement.where(*conditions) + records = session.exec(statement).all() + for record in records: + for field_name, value in data.items(): + _get_model_field(model_class, field_name) + setattr(record, field_name, value) + session.add(record) + return len(records) + except Exception as e: + logger.error(f"[DatabaseService] 更新数据库记录出错: {e}") + traceback.print_exc() + return 0 + + +async def db_delete(model_class: type[SQLModel], filters: Optional[dict[str, Any]] = None) -> int: + try: + with get_db_session() as session: + statement = delete(model_class) + conditions = _build_filters(model_class, filters) + if conditions: + statement = statement.where(*conditions) + result = session.exec(statement) + return result.rowcount or 0 + except Exception as e: + logger.error(f"[DatabaseService] 删除数据库记录出错: {e}") + traceback.print_exc() + return 0 + + +async def db_count(model_class: type[SQLModel], filters: Optional[dict[str, Any]] = None) -> int: + try: + with get_db_session(auto_commit=False) as session: + statement = select(func.count()).select_from(model_class) + conditions = _build_filters(model_class, filters) + if conditions: + statement = statement.where(*conditions) + result = session.exec(statement).one() + return int(result or 0) + except Exception as e: + logger.error(f"[DatabaseService] 统计数据库记录出错: {e}") + traceback.print_exc() + return 0 + + async def store_action_info( chat_stream=None, - action_build_into_prompt: bool = False, - action_prompt_display: str = "", - action_done: bool = True, + builtin_prompt: Optional[str] = None, + display_prompt: str = "", thinking_id: str = "", action_data: Optional[dict] = None, action_name: str = "", action_reasoning: str = "", ): try: - from src.common.database.database_model import ActionRecords + if chat_stream is None: + raise ValueError("store_action_info 需要 chat_stream") record_data = { "action_id": thinking_id or str(int(time.time() * 1000000)), - "time": time.time(), + "timestamp": datetime.now(), + "session_id": chat_stream.session_id, "action_name": action_name, "action_data": json.dumps(action_data or {}, ensure_ascii=False), - "action_done": action_done, "action_reasoning": action_reasoning, - "action_build_into_prompt": action_build_into_prompt, - "action_prompt_display": action_prompt_display, + "action_builtin_prompt": builtin_prompt, + "action_display_prompt": display_prompt, } - if chat_stream: - record_data.update( - { - "chat_id": getattr(chat_stream, "stream_id", ""), - "chat_info_stream_id": getattr(chat_stream, "stream_id", ""), - "chat_info_platform": getattr(chat_stream, "platform", ""), - } - ) - else: - record_data.update({"chat_id": "", "chat_info_stream_id": "", "chat_info_platform": ""}) - saved_record = await db_save( - ActionRecords, data=record_data, key_field="action_id", key_value=record_data["action_id"] + ActionRecord, data=record_data, key_field="action_id", key_value=record_data["action_id"] ) if saved_record: logger.debug(f"[DatabaseService] 成功存储动作信息: {action_name} (ID: {record_data['action_id']})") diff --git a/src/services/frequency_service.py b/src/services/frequency_service.py deleted file mode 100644 index eceb6b95..00000000 --- a/src/services/frequency_service.py +++ /dev/null @@ -1,21 +0,0 @@ -"""频率控制服务模块 - -提供聊天频率控制的核心功能。 -""" - -from src.chat.heart_flow.frequency_control import frequency_control_manager -from src.config.config import global_config - - -def get_current_talk_value(chat_id: str) -> float: - return frequency_control_manager.get_or_create_frequency_control( - chat_id - ).get_talk_frequency_adjust() * global_config.chat.get_talk_value(chat_id) - - -def set_talk_frequency_adjust(chat_id: str, talk_frequency_adjust: float) -> None: - frequency_control_manager.get_or_create_frequency_control(chat_id).set_talk_frequency_adjust(talk_frequency_adjust) - - -def get_talk_frequency_adjust(chat_id: str) -> float: - return frequency_control_manager.get_or_create_frequency_control(chat_id).get_talk_frequency_adjust() diff --git a/src/services/llm_service.py b/src/services/llm_service.py index b267e67c..2927b5c1 100644 --- a/src/services/llm_service.py +++ b/src/services/llm_service.py @@ -16,6 +16,38 @@ from src.llm_models.utils_model import LLMRequest logger = get_logger("llm_service") +async def _generate_response( + model_config: TaskConfig, + request_type: str, + prompt: Optional[str] = None, + message_factory: Optional[Callable[[BaseClient], List[Message]]] = None, + tool_options: Optional[List[Dict[str, Any]]] = None, + temperature: Optional[float] = None, + max_tokens: Optional[int] = None, +) -> Tuple[str, str, str, List[ToolCall] | None]: + llm_request = LLMRequest(model_set=model_config, request_type=request_type) + + if message_factory is not None: + response, (reasoning_content, model_name, tool_call) = await llm_request.generate_response_with_message_async( + message_factory=message_factory, + tools=tool_options, + temperature=temperature, + max_tokens=max_tokens, + ) + return response, reasoning_content, model_name, tool_call + + if prompt is None: + raise ValueError("prompt 与 message_factory 不能同时为空") + + response, (reasoning_content, model_name, tool_call) = await llm_request.generate_response_async( + prompt, + tools=tool_options, + temperature=temperature, + max_tokens=max_tokens, + ) + return response, reasoning_content, model_name, tool_call + + def get_available_models() -> Dict[str, TaskConfig]: """获取所有可用的模型配置 @@ -61,11 +93,12 @@ async def generate_with_model( """ try: logger.debug(f"[LLMService] 完整提示词: {prompt}") - - llm_request = LLMRequest(model_set=model_config, request_type=request_type) - - response, (reasoning_content, model_name, _) = await llm_request.generate_response_async( - prompt, temperature=temperature, max_tokens=max_tokens + response, reasoning_content, model_name, _ = await _generate_response( + model_config=model_config, + request_type=request_type, + prompt=prompt, + temperature=temperature, + max_tokens=max_tokens, ) return True, response, reasoning_content, model_name @@ -101,10 +134,13 @@ async def generate_with_model_with_tools( logger.info(f"使用模型{model_name_list}生成内容") logger.debug(f"完整提示词: {prompt}") - llm_request = LLMRequest(model_set=model_config, request_type=request_type) - - response, (reasoning_content, model_name, tool_call) = await llm_request.generate_response_async( - prompt, tools=tool_options, temperature=temperature, max_tokens=max_tokens + response, reasoning_content, model_name, tool_call = await _generate_response( + model_config=model_config, + request_type=request_type, + prompt=prompt, + tool_options=tool_options, + temperature=temperature, + max_tokens=max_tokens, ) return True, response, reasoning_content, model_name, tool_call @@ -139,11 +175,11 @@ async def generate_with_model_with_tools_by_message_factory( model_name_list = model_config.model_list logger.info(f"使用模型 {model_name_list} 生成内容") - llm_request = LLMRequest(model_set=model_config, request_type=request_type) - - response, (reasoning_content, model_name, tool_call) = await llm_request.generate_response_with_message_async( + response, reasoning_content, model_name, tool_call = await _generate_response( + model_config=model_config, + request_type=request_type, message_factory=message_factory, - tools=tool_options, + tool_options=tool_options, temperature=temperature, max_tokens=max_tokens, ) diff --git a/src/services/person_service.py b/src/services/person_service.py deleted file mode 100644 index 74c02feb..00000000 --- a/src/services/person_service.py +++ /dev/null @@ -1,65 +0,0 @@ -"""个人信息服务模块 - -提供个人信息查询的核心功能。 -""" - -from typing import Any - -from src.common.logger import get_logger -from src.person_info.person_info import Person - -logger = get_logger("person_service") - - -def get_person_id(platform: str, user_id: int | str) -> str: - """根据平台和用户ID获取person_id - - Args: - platform: 平台名称,如 "qq", "telegram" 等 - user_id: 用户ID - - Returns: - str: 唯一的person_id(MD5哈希值) - """ - try: - return Person(platform=platform, user_id=str(user_id)).person_id - except Exception as e: - logger.error(f"[PersonService] 获取person_id失败: platform={platform}, user_id={user_id}, error={e}") - return "" - - -async def get_person_value(person_id: str, field_name: str, default: Any = None) -> Any: - """根据person_id和字段名获取某个值 - - Args: - person_id: 用户的唯一标识ID - field_name: 要获取的字段名,如 "nickname", "impression" 等 - default: 当字段不存在或获取失败时返回的默认值 - - Returns: - Any: 字段值或默认值 - """ - try: - person = Person(person_id=person_id) - value = getattr(person, field_name) - return value if value is not None else default - except Exception as e: - logger.error(f"[PersonService] 获取用户信息失败: person_id={person_id}, field={field_name}, error={e}") - return default - - -def get_person_id_by_name(person_name: str) -> str: - """根据用户名获取person_id - - Args: - person_name: 用户名 - - Returns: - str: person_id,如果未找到返回空字符串 - """ - try: - person = Person(person_name=person_name) - return person.person_id - except Exception as e: - logger.error(f"[PersonService] 根据用户名获取person_id失败: person_name={person_name}, error={e}") - return ""