炸 service 层

This commit is contained in:
DrSmoothl
2026-03-14 00:13:35 +08:00
parent 898fab6de9
commit 43c5b34623
13 changed files with 1408 additions and 1736 deletions

View File

@@ -225,9 +225,7 @@ class BrainChatting:
await database_api.store_action_info(
chat_stream=self.chat_stream,
action_build_into_prompt=False,
action_prompt_display=action_prompt_display,
action_done=True,
display_prompt=action_prompt_display,
thinking_id=thinking_id,
action_data={"reply_text": reply_text},
action_name="reply",
@@ -573,9 +571,7 @@ class BrainChatting:
# 存储complete_talk信息到数据库
await database_api.store_action_info(
chat_stream=self.chat_stream,
action_build_into_prompt=False,
action_prompt_display=reason,
action_done=True,
display_prompt=reason,
thinking_id=thinking_id,
action_data={"reason": reason},
action_name="complete_talk",
@@ -678,9 +674,7 @@ class BrainChatting:
# 记录动作信息
await database_api.store_action_info(
chat_stream=self.chat_stream,
action_build_into_prompt=False,
action_prompt_display=reason or f"等待 {wait_seconds}",
action_done=True,
display_prompt=reason or f"等待 {wait_seconds}",
thinking_id=thinking_id,
action_data={"reason": reason, "wait_seconds": wait_seconds},
action_name="wait",
@@ -722,9 +716,7 @@ class BrainChatting:
# 记录动作信息
await database_api.store_action_info(
chat_stream=self.chat_stream,
action_build_into_prompt=False,
action_prompt_display=reason or f"倾听并等待 {wait_seconds}",
action_done=True,
display_prompt=reason or f"倾听并等待 {wait_seconds}",
thinking_id=thinking_id,
action_data={"reason": reason, "wait_seconds": wait_seconds},
action_name="listening",

View File

@@ -8,6 +8,7 @@ import json
import time
import re
import difflib
import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional, Set
from dataclasses import dataclass, field
@@ -917,21 +918,18 @@ class ChatHistorySummarizer:
# 准备数据
data = {
"chat_id": self.session_id,
"start_time": start_time,
"end_time": end_time,
"original_text": original_text,
"session_id": self.session_id,
"start_timestamp": datetime.fromtimestamp(start_time),
"end_timestamp": datetime.fromtimestamp(end_time),
"original_messages": original_text,
"participants": json.dumps(participants, ensure_ascii=False),
"theme": theme,
"keywords": json.dumps(keywords, ensure_ascii=False),
"summary": summary,
"count": 0,
"query_count": 0,
"query_forget_count": 0,
}
# 使用db_save存储使用start_time和chat_id作为唯一标识
# 由于可能有多条记录我们使用组合键但peewee不支持所以使用start_time作为唯一标识
# 但为了避免冲突我们使用组合键chat_id + start_time
# 由于peewee不支持组合键我们直接创建新记录不提供key_field和key_value
saved_record = await database_api.db_save(
ChatHistory,
data=data,

View File

@@ -0,0 +1,9 @@
from .components import RuntimeComponentCapabilityMixin
from .core import RuntimeCoreCapabilityMixin
from .data import RuntimeDataCapabilityMixin
__all__ = [
"RuntimeComponentCapabilityMixin",
"RuntimeCoreCapabilityMixin",
"RuntimeDataCapabilityMixin",
]

View File

@@ -0,0 +1,194 @@
from typing import Any, Dict, List, Optional
from src.common.logger import get_logger
logger = get_logger("plugin_runtime.integration")
class RuntimeComponentCapabilityMixin:
async def _cap_component_get_all_plugins(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
result: Dict[str, Any] = {}
for sv in self.supervisors:
for pid, reg in sv._registered_plugins.items():
if pid in result:
logger.error(f"检测到重复插件 ID {pid}component.get_all_plugins 结果已拒绝聚合")
return {"success": False, "error": f"检测到重复插件 ID: {pid}"}
comps = sv.component_registry.get_components_by_plugin(pid, enabled_only=False)
components_list = [
{
"name": component.name,
"full_name": component.full_name,
"type": component.component_type,
"enabled": component.enabled,
"metadata": component.metadata,
}
for component in comps
]
result[pid] = {
"name": pid,
"version": reg.plugin_version,
"description": "",
"author": "",
"enabled": True,
"components": components_list,
}
return {"success": True, "plugins": result}
async def _cap_component_get_plugin_info(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
plugin_name: str = args.get("plugin_name", plugin_id)
try:
sv = self._get_supervisor_for_plugin(plugin_name)
except RuntimeError as exc:
return {"success": False, "error": str(exc)}
if sv is not None and (reg := sv._registered_plugins.get(plugin_name)) is not None:
return {
"success": True,
"plugin": {
"name": plugin_name,
"version": reg.plugin_version,
"description": "",
"author": "",
"enabled": True,
},
}
return {"success": False, "error": f"未找到插件: {plugin_name}"}
async def _cap_component_list_loaded_plugins(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
plugins: List[str] = []
for sv in self.supervisors:
plugins.extend(sv._registered_plugins.keys())
return {"success": True, "plugins": plugins}
async def _cap_component_list_registered_plugins(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
plugins: List[str] = []
for sv in self.supervisors:
plugins.extend(sv._registered_plugins.keys())
return {"success": True, "plugins": plugins}
def _resolve_component_toggle_target(self, name: str, component_type: str) -> tuple[Optional[Any], Optional[str]]:
short_name_matches: List[Any] = []
for sv in self.supervisors:
comp = sv.component_registry.get_component(name)
if comp is not None and comp.component_type == component_type:
return comp, None
short_name_matches.extend(
candidate
for candidate in sv.component_registry.get_components_by_type(component_type, enabled_only=False)
if candidate.name == name
)
if len(short_name_matches) == 1:
return short_name_matches[0], None
if len(short_name_matches) > 1:
return None, f"组件名不唯一: {name} ({component_type}),请使用完整名 plugin_id.component_name"
return None, f"未找到组件: {name} ({component_type})"
async def _cap_component_enable(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
name: str = args.get("name", "")
component_type: str = args.get("component_type", "")
scope: str = args.get("scope", "global")
stream_id: str = args.get("stream_id", "")
if not name or not component_type:
return {"success": False, "error": "缺少必要参数 name 或 component_type"}
if scope != "global" or stream_id:
return {"success": False, "error": "当前仅支持全局组件启用,不支持 scope/stream_id 定位"}
comp, error = self._resolve_component_toggle_target(name, component_type)
if comp is None:
return {"success": False, "error": error or f"未找到组件: {name} ({component_type})"}
comp.enabled = True
return {"success": True}
async def _cap_component_disable(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
name: str = args.get("name", "")
component_type: str = args.get("component_type", "")
scope: str = args.get("scope", "global")
stream_id: str = args.get("stream_id", "")
if not name or not component_type:
return {"success": False, "error": "缺少必要参数 name 或 component_type"}
if scope != "global" or stream_id:
return {"success": False, "error": "当前仅支持全局组件禁用,不支持 scope/stream_id 定位"}
comp, error = self._resolve_component_toggle_target(name, component_type)
if comp is None:
return {"success": False, "error": error or f"未找到组件: {name} ({component_type})"}
comp.enabled = False
return {"success": True}
async def _cap_component_load_plugin(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
plugin_name: str = args.get("plugin_name", "")
if not plugin_name:
return {"success": False, "error": "缺少必要参数 plugin_name"}
import os
if duplicate_plugin_ids := self._find_duplicate_plugin_ids(list(self._iter_plugin_dirs())):
details = "; ".join(
f"{conflict_plugin_id}: {', '.join(paths)}"
for conflict_plugin_id, paths in sorted(duplicate_plugin_ids.items())
)
return {"success": False, "error": f"检测到重复插件 ID拒绝热重载: {details}"}
try:
registered_supervisor = self._get_supervisor_for_plugin(plugin_name)
except RuntimeError as exc:
return {"success": False, "error": str(exc)}
if registered_supervisor is not None:
try:
reloaded = await registered_supervisor.reload_plugins(reason=f"load {plugin_name}")
if reloaded:
return {"success": True, "count": 1}
return {"success": False, "error": f"插件 {plugin_name} 热重载失败,已回滚"}
except Exception as e:
logger.error(f"[cap.component.load_plugin] 热重载失败: {e}")
return {"success": False, "error": str(e)}
for sv in self.supervisors:
for pdir in sv._plugin_dirs:
if os.path.isdir(os.path.join(pdir, plugin_name)):
try:
reloaded = await sv.reload_plugins(reason=f"load {plugin_name}")
if reloaded:
return {"success": True, "count": 1}
return {"success": False, "error": f"插件 {plugin_name} 热重载失败,已回滚"}
except Exception as e:
logger.error(f"[cap.component.load_plugin] 热重载失败: {e}")
return {"success": False, "error": str(e)}
return {"success": False, "error": f"未找到插件: {plugin_name}"}
async def _cap_component_unload_plugin(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
return {"success": False, "error": "新运行时不支持单独卸载插件,请使用 reload"}
async def _cap_component_reload_plugin(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
plugin_name: str = args.get("plugin_name", "")
if not plugin_name:
return {"success": False, "error": "缺少必要参数 plugin_name"}
if duplicate_plugin_ids := self._find_duplicate_plugin_ids(list(self._iter_plugin_dirs())):
details = "; ".join(
f"{conflict_plugin_id}: {', '.join(paths)}"
for conflict_plugin_id, paths in sorted(duplicate_plugin_ids.items())
)
return {"success": False, "error": f"检测到重复插件 ID拒绝热重载: {details}"}
try:
sv = self._get_supervisor_for_plugin(plugin_name)
except RuntimeError as exc:
return {"success": False, "error": str(exc)}
if sv is not None:
try:
reloaded = await sv.reload_plugins(reason=f"reload {plugin_name}")
if reloaded:
return {"success": True}
return {"success": False, "error": f"插件 {plugin_name} 热重载失败,已回滚"}
except Exception as e:
logger.error(f"[cap.component.reload_plugin] 热重载失败: {e}")
return {"success": False, "error": str(e)}
return {"success": False, "error": f"未找到插件: {plugin_name}"}

View File

@@ -0,0 +1,283 @@
from typing import Any, Dict
from src.common.logger import get_logger
logger = get_logger("plugin_runtime.integration")
class RuntimeCoreCapabilityMixin:
async def _cap_send_text(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
from src.services import send_service as send_api
text: str = args.get("text", "")
stream_id: str = args.get("stream_id", "")
if not text or not stream_id:
return {"success": False, "error": "缺少必要参数 text 或 stream_id"}
try:
result = await send_api.text_to_stream(
text=text,
stream_id=stream_id,
typing=args.get("typing", False),
set_reply=args.get("set_reply", False),
storage_message=args.get("storage_message", True),
)
return {"success": result}
except Exception as e:
logger.error(f"[cap.send.text] 执行失败: {e}", exc_info=True)
return {"success": False, "error": str(e)}
async def _cap_send_emoji(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
from src.services import send_service as send_api
emoji_base64: str = args.get("emoji_base64", "")
stream_id: str = args.get("stream_id", "")
if not emoji_base64 or not stream_id:
return {"success": False, "error": "缺少必要参数 emoji_base64 或 stream_id"}
try:
result = await send_api.emoji_to_stream(
emoji_base64=emoji_base64,
stream_id=stream_id,
storage_message=args.get("storage_message", True),
)
return {"success": result}
except Exception as e:
logger.error(f"[cap.send.emoji] 执行失败: {e}", exc_info=True)
return {"success": False, "error": str(e)}
async def _cap_send_image(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
from src.services import send_service as send_api
image_base64: str = args.get("image_base64", "")
stream_id: str = args.get("stream_id", "")
if not image_base64 or not stream_id:
return {"success": False, "error": "缺少必要参数 image_base64 或 stream_id"}
try:
result = await send_api.image_to_stream(
image_base64=image_base64,
stream_id=stream_id,
storage_message=args.get("storage_message", True),
)
return {"success": result}
except Exception as e:
logger.error(f"[cap.send.image] 执行失败: {e}", exc_info=True)
return {"success": False, "error": str(e)}
async def _cap_send_command(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
from src.services import send_service as send_api
command = args.get("command", "")
stream_id: str = args.get("stream_id", "")
if not command or not stream_id:
return {"success": False, "error": "缺少必要参数 command 或 stream_id"}
try:
result = await send_api.command_to_stream(
command=command,
stream_id=stream_id,
storage_message=args.get("storage_message", True),
display_message=args.get("display_message", ""),
)
return {"success": result}
except Exception as e:
logger.error(f"[cap.send.command] 执行失败: {e}", exc_info=True)
return {"success": False, "error": str(e)}
async def _cap_send_custom(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
from src.services import send_service as send_api
message_type: str = args.get("message_type", "") or args.get("custom_type", "")
content = args.get("content")
if content is None:
content = args.get("data", "")
stream_id: str = args.get("stream_id", "")
if not message_type or not stream_id:
return {"success": False, "error": "缺少必要参数 message_type 或 stream_id"}
try:
result = await send_api.custom_to_stream(
message_type=message_type,
content=content,
stream_id=stream_id,
display_message=args.get("display_message", ""),
typing=args.get("typing", False),
storage_message=args.get("storage_message", True),
)
return {"success": result}
except Exception as e:
logger.error(f"[cap.send.custom] 执行失败: {e}", exc_info=True)
return {"success": False, "error": str(e)}
async def _cap_send_forward(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
from src.services import send_service as send_api
messages = args.get("messages", [])
stream_id: str = args.get("stream_id", "")
if not messages or not stream_id:
return {"success": False, "error": "缺少必要参数 messages 或 stream_id"}
try:
result = await send_api.forward_to_stream(messages=messages, stream_id=stream_id)
return {"success": result}
except Exception as e:
logger.error(f"[cap.send.forward] 执行失败: {e}", exc_info=True)
return {"success": False, "error": str(e)}
async def _cap_send_hybrid(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
from src.services import send_service as send_api
segments = args.get("segments", [])
stream_id: str = args.get("stream_id", "")
if not segments or not stream_id:
return {"success": False, "error": "缺少必要参数 segments 或 stream_id"}
try:
result = await send_api.hybrid_to_stream(segments=segments, stream_id=stream_id)
return {"success": result}
except Exception as e:
logger.error(f"[cap.send.hybrid] 执行失败: {e}", exc_info=True)
return {"success": False, "error": str(e)}
async def _cap_llm_generate(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
from src.services import llm_service as llm_api
prompt: str = args.get("prompt", "")
if not prompt:
return {"success": False, "error": "缺少必要参数 prompt"}
model_name: str = args.get("model", "") or args.get("model_name", "")
temperature = args.get("temperature")
max_tokens = args.get("max_tokens")
try:
models = llm_api.get_available_models()
if model_name and model_name in models:
model_config = models[model_name]
else:
if not models:
return {"success": False, "error": "没有可用的模型配置"}
model_config = next(iter(models.values()))
success, response, reasoning, used_model = await llm_api.generate_with_model(
prompt=prompt,
model_config=model_config,
request_type=f"plugin.{plugin_id}",
temperature=temperature,
max_tokens=max_tokens,
)
return {
"success": success,
"response": response,
"reasoning": reasoning,
"model_name": used_model,
}
except Exception as e:
logger.error(f"[cap.llm.generate] 执行失败: {e}", exc_info=True)
return {"success": False, "error": str(e)}
async def _cap_llm_generate_with_tools(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
from src.services import llm_service as llm_api
prompt: str = args.get("prompt", "")
if not prompt:
return {"success": False, "error": "缺少必要参数 prompt"}
model_name: str = args.get("model", "") or args.get("model_name", "")
tool_options = args.get("tools") or args.get("tool_options")
temperature = args.get("temperature")
max_tokens = args.get("max_tokens")
try:
models = llm_api.get_available_models()
if model_name and model_name in models:
model_config = models[model_name]
else:
if not models:
return {"success": False, "error": "没有可用的模型配置"}
model_config = next(iter(models.values()))
success, response, reasoning, used_model, tool_calls = await llm_api.generate_with_model_with_tools(
prompt=prompt,
model_config=model_config,
tool_options=tool_options,
request_type=f"plugin.{plugin_id}",
temperature=temperature,
max_tokens=max_tokens,
)
serialized_tool_calls = None
if tool_calls:
serialized_tool_calls = [
{"id": tc.id, "function": {"name": tc.function.name, "arguments": tc.function.arguments}}
for tc in tool_calls
if hasattr(tc, "function")
]
return {
"success": success,
"response": response,
"reasoning": reasoning,
"model_name": used_model,
"tool_calls": serialized_tool_calls,
}
except Exception as e:
logger.error(f"[cap.llm.generate_with_tools] 执行失败: {e}", exc_info=True)
return {"success": False, "error": str(e)}
async def _cap_llm_get_available_models(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
from src.services import llm_service as llm_api
try:
models = llm_api.get_available_models()
return {"success": True, "models": list(models.keys())}
except Exception as e:
logger.error(f"[cap.llm.get_available_models] 执行失败: {e}", exc_info=True)
return {"success": False, "error": str(e)}
async def _cap_config_get(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
from src.services import config_service as config_api
key: str = args.get("key", "")
default = args.get("default")
if not key:
return {"success": False, "value": None, "error": "缺少必要参数 key"}
try:
value = config_api.get_global_config(key, default)
return {"success": True, "value": value}
except Exception as e:
return {"success": False, "value": None, "error": str(e)}
async def _cap_config_get_plugin(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
from src.core.component_registry import component_registry as core_registry
plugin_name: str = args.get("plugin_name", plugin_id)
key: str = args.get("key", "")
default = args.get("default")
try:
config = core_registry.get_plugin_config(plugin_name)
if config is None:
return {"success": False, "value": default, "error": f"未找到插件 {plugin_name} 的配置"}
if key:
from src.services import config_service as config_api
value = config_api.get_plugin_config(config, key, default)
return {"success": True, "value": value}
return {"success": True, "value": config}
except Exception as e:
return {"success": False, "value": default, "error": str(e)}
async def _cap_config_get_all(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
from src.core.component_registry import component_registry as core_registry
plugin_name: str = args.get("plugin_name", plugin_id)
try:
config = core_registry.get_plugin_config(plugin_name)
if config is None:
return {"success": True, "value": {}}
return {"success": True, "value": config}
except Exception as e:
return {"success": False, "value": {}, "error": str(e)}

View File

@@ -0,0 +1,582 @@
from typing import Any, Dict, List, Optional
from src.chat.message_receive.chat_manager import BotChatSession, chat_manager
from src.common.logger import get_logger
logger = get_logger("plugin_runtime.integration")
class RuntimeDataCapabilityMixin:
async def _cap_database_query(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
from src.services import database_service as database_api
model_name: str = args.get("model_name", "")
if not model_name:
return {"success": False, "error": "缺少必要参数 model_name"}
try:
import src.common.database.database_model as db_models
model_class = getattr(db_models, model_name, None)
if model_class is None:
return {"success": False, "error": f"未找到数据模型: {model_name}"}
query_type = args.get("query_type", "get")
if query_type == "get":
result = await database_api.db_get(
model_class=model_class,
filters=args.get("filters"),
limit=args.get("limit"),
order_by=args.get("order_by"),
single_result=args.get("single_result", False),
)
elif query_type == "create":
if not (data := args.get("data")):
return {"success": False, "error": "create 需要 data"}
result = await database_api.db_save(model_class=model_class, data=data)
elif query_type == "update":
if not (data := args.get("data")):
return {"success": False, "error": "update 需要 data"}
result = await database_api.db_update(
model_class=model_class,
data=data,
filters=args.get("filters"),
)
elif query_type == "delete":
result = await database_api.db_delete(model_class=model_class, filters=args.get("filters"))
elif query_type == "count":
result = await database_api.db_count(model_class=model_class, filters=args.get("filters"))
else:
return {"success": False, "error": f"不支持的 query_type: {query_type}"}
return {"success": True, "result": result}
except Exception as e:
logger.error(f"[cap.database.query] 执行失败: {e}", exc_info=True)
return {"success": False, "error": str(e)}
async def _cap_database_save(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
from src.services import database_service as database_api
model_name: str = args.get("model_name", "")
data: Optional[Dict[str, Any]] = args.get("data")
if not model_name or not data:
return {"success": False, "error": "缺少必要参数 model_name 或 data"}
try:
import src.common.database.database_model as db_models
model_class = getattr(db_models, model_name, None)
if model_class is None:
return {"success": False, "error": f"未找到数据模型: {model_name}"}
result = await database_api.db_save(
model_class=model_class,
data=data,
key_field=args.get("key_field"),
key_value=args.get("key_value"),
)
return {"success": True, "result": result}
except Exception as e:
logger.error(f"[cap.database.save] 执行失败: {e}", exc_info=True)
return {"success": False, "error": str(e)}
async def _cap_database_get(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
from src.services import database_service as database_api
model_name: str = args.get("model_name", "")
if not model_name:
return {"success": False, "error": "缺少必要参数 model_name"}
try:
import src.common.database.database_model as db_models
model_class = getattr(db_models, model_name, None)
if model_class is None:
return {"success": False, "error": f"未找到数据模型: {model_name}"}
result = await database_api.db_get(
model_class=model_class,
filters=args.get("filters"),
limit=args.get("limit"),
order_by=args.get("order_by"),
single_result=args.get("single_result", False),
)
return {"success": True, "result": result}
except Exception as e:
logger.error(f"[cap.database.get] 执行失败: {e}", exc_info=True)
return {"success": False, "error": str(e)}
async def _cap_database_delete(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
from src.services import database_service as database_api
model_name: str = args.get("model_name", "")
filters = args.get("filters", {})
if not model_name:
return {"success": False, "error": "缺少必要参数 model_name"}
if not filters:
return {"success": False, "error": "缺少必要参数 filters不允许无条件删除"}
try:
import src.common.database.database_model as db_models
model_class = getattr(db_models, model_name, None)
if model_class is None:
return {"success": False, "error": f"未找到数据模型: {model_name}"}
result = await database_api.db_delete(model_class=model_class, filters=filters)
return {"success": True, "result": result}
except Exception as e:
logger.error(f"[cap.database.delete] 执行失败: {e}", exc_info=True)
return {"success": False, "error": str(e)}
async def _cap_database_count(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
from src.services import database_service as database_api
model_name: str = args.get("model_name", "")
if not model_name:
return {"success": False, "error": "缺少必要参数 model_name"}
try:
import src.common.database.database_model as db_models
model_class = getattr(db_models, model_name, None)
if model_class is None:
return {"success": False, "error": f"未找到数据模型: {model_name}"}
result = await database_api.db_count(model_class=model_class, filters=args.get("filters"))
return {"success": True, "count": result}
except Exception as e:
logger.error(f"[cap.database.count] 执行失败: {e}", exc_info=True)
return {"success": False, "error": str(e)}
def _list_sessions(self, platform: str, is_group_session: Optional[bool] = None) -> List[BotChatSession]:
return [
session
for session in chat_manager.sessions.values()
if (platform == "all_platforms" or session.platform == platform)
and (is_group_session is None or session.is_group_session == is_group_session)
]
@staticmethod
def _serialize_stream(stream: BotChatSession) -> Dict[str, Any]:
return {
"session_id": stream.session_id,
"platform": stream.platform,
"user_id": stream.user_id,
"group_id": stream.group_id,
"is_group_session": stream.is_group_session,
}
async def _cap_chat_get_all_streams(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
platform: str = args.get("platform", "qq")
try:
streams = self._list_sessions(platform=platform)
return {"success": True, "streams": [self._serialize_stream(item) for item in streams]}
except Exception as e:
logger.error(f"[cap.chat.get_all_streams] 执行失败: {e}", exc_info=True)
return {"success": False, "error": str(e)}
async def _cap_chat_get_group_streams(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
platform: str = args.get("platform", "qq")
try:
streams = self._list_sessions(platform=platform, is_group_session=True)
return {"success": True, "streams": [self._serialize_stream(item) for item in streams]}
except Exception as e:
logger.error(f"[cap.chat.get_group_streams] 执行失败: {e}", exc_info=True)
return {"success": False, "error": str(e)}
async def _cap_chat_get_private_streams(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
platform: str = args.get("platform", "qq")
try:
streams = self._list_sessions(platform=platform, is_group_session=False)
return {"success": True, "streams": [self._serialize_stream(item) for item in streams]}
except Exception as e:
logger.error(f"[cap.chat.get_private_streams] 执行失败: {e}", exc_info=True)
return {"success": False, "error": str(e)}
async def _cap_chat_get_stream_by_group_id(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
group_id: str = args.get("group_id", "")
if not group_id:
return {"success": False, "error": "缺少必要参数 group_id"}
platform: str = args.get("platform", "qq")
try:
stream = next(
(
item
for item in self._list_sessions(platform=platform, is_group_session=True)
if str(item.group_id) == str(group_id)
),
None,
)
return {"success": True, "stream": None if stream is None else self._serialize_stream(stream)}
except Exception as e:
logger.error(f"[cap.chat.get_stream_by_group_id] 执行失败: {e}", exc_info=True)
return {"success": False, "error": str(e)}
async def _cap_chat_get_stream_by_user_id(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
user_id: str = args.get("user_id", "")
if not user_id:
return {"success": False, "error": "缺少必要参数 user_id"}
platform: str = args.get("platform", "qq")
try:
stream = next(
(
item
for item in self._list_sessions(platform=platform, is_group_session=False)
if str(item.user_id) == str(user_id)
),
None,
)
return {"success": True, "stream": None if stream is None else self._serialize_stream(stream)}
except Exception as e:
logger.error(f"[cap.chat.get_stream_by_user_id] 执行失败: {e}", exc_info=True)
return {"success": False, "error": str(e)}
@staticmethod
def _serialize_messages(messages: list) -> List[Dict[str, Any]]:
result: List[Dict[str, Any]] = []
for msg in messages:
if hasattr(msg, "model_dump"):
result.append(msg.model_dump())
elif hasattr(msg, "__dict__"):
result.append(dict(msg.__dict__))
else:
result.append(str(msg))
return result
async def _cap_message_get_by_time(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
from src.services import message_service as message_api
try:
messages = message_api.get_messages_by_time(
start_time=float(args.get("start_time", 0.0)),
end_time=float(args.get("end_time", 0.0)),
limit=args.get("limit", 0),
limit_mode=args.get("limit_mode", "latest"),
filter_mai=args.get("filter_mai", False),
)
return {"success": True, "messages": self._serialize_messages(messages)}
except Exception as e:
logger.error(f"[cap.message.get_by_time] 执行失败: {e}", exc_info=True)
return {"success": False, "error": str(e)}
async def _cap_message_get_by_time_in_chat(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
from src.services import message_service as message_api
chat_id: str = args.get("chat_id", "")
if not chat_id:
return {"success": False, "error": "缺少必要参数 chat_id"}
try:
messages = message_api.get_messages_by_time_in_chat(
chat_id=chat_id,
start_time=float(args.get("start_time", 0.0)),
end_time=float(args.get("end_time", 0.0)),
limit=args.get("limit", 0),
limit_mode=args.get("limit_mode", "latest"),
filter_mai=args.get("filter_mai", False),
filter_command=args.get("filter_command", False),
)
return {"success": True, "messages": self._serialize_messages(messages)}
except Exception as e:
logger.error(f"[cap.message.get_by_time_in_chat] 执行失败: {e}", exc_info=True)
return {"success": False, "error": str(e)}
async def _cap_message_get_recent(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
from src.services import message_service as message_api
chat_id: str = args.get("chat_id", "")
if not chat_id:
return {"success": False, "error": "缺少必要参数 chat_id"}
try:
messages = message_api.get_recent_messages(
chat_id=chat_id,
hours=float(args.get("hours", 24.0)),
limit=args.get("limit", 100),
limit_mode=args.get("limit_mode", "latest"),
filter_mai=args.get("filter_mai", False),
)
return {"success": True, "messages": self._serialize_messages(messages)}
except Exception as e:
logger.error(f"[cap.message.get_recent] 执行失败: {e}", exc_info=True)
return {"success": False, "error": str(e)}
async def _cap_message_count_new(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
from src.services import message_service as message_api
chat_id: str = args.get("chat_id", "")
if not chat_id:
return {"success": False, "error": "缺少必要参数 chat_id"}
try:
since = args.get("since")
start_time = float(since) if since is not None else float(args.get("start_time", 0.0))
count = message_api.count_new_messages(
chat_id=chat_id,
start_time=start_time,
end_time=args.get("end_time"),
)
return {"success": True, "count": count}
except Exception as e:
logger.error(f"[cap.message.count_new] 执行失败: {e}", exc_info=True)
return {"success": False, "error": str(e)}
async def _cap_message_build_readable(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
from src.services import message_service as message_api
try:
messages = args.get("messages")
if messages is None:
if not (chat_id := args.get("chat_id", "")):
return {"success": False, "error": "缺少必要参数: messages 或 chat_id"}
messages = message_api.get_messages_by_time_in_chat(
chat_id=chat_id,
start_time=float(args.get("start_time", 0.0)),
end_time=float(args.get("end_time", 0.0)),
limit=args.get("limit", 0),
)
readable = message_api.build_readable_messages_to_str(
messages=messages,
replace_bot_name=args.get("replace_bot_name", True),
timestamp_mode=args.get("timestamp_mode", "relative"),
truncate=args.get("truncate", False),
)
return {"success": True, "text": readable}
except Exception as e:
logger.error(f"[cap.message.build_readable] 执行失败: {e}", exc_info=True)
return {"success": False, "error": str(e)}
async def _cap_person_get_id(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
from src.person_info.person_info import Person
platform: str = args.get("platform", "")
user_id = args.get("user_id", "")
if not platform or not user_id:
return {"success": False, "error": "缺少必要参数 platform 或 user_id"}
try:
pid = Person(platform=platform, user_id=str(user_id)).person_id
return {"success": True, "person_id": pid}
except Exception as e:
logger.error(f"[cap.person.get_id] 执行失败: {e}", exc_info=True)
return {"success": False, "error": str(e)}
async def _cap_person_get_value(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
from src.person_info.person_info import Person
person_id: str = args.get("person_id", "")
field_name: str = args.get("field_name", "")
if not person_id or not field_name:
return {"success": False, "error": "缺少必要参数 person_id 或 field_name"}
try:
person = Person(person_id=person_id)
value = getattr(person, field_name)
if value is None:
value = args.get("default")
return {"success": True, "value": value}
except Exception as e:
logger.error(f"[cap.person.get_value] 执行失败: {e}", exc_info=True)
return {"success": False, "error": str(e)}
async def _cap_person_get_id_by_name(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
from src.person_info.person_info import Person
person_name: str = args.get("person_name", "")
if not person_name:
return {"success": False, "error": "缺少必要参数 person_name"}
try:
pid = Person(person_name=person_name).person_id
return {"success": True, "person_id": pid}
except Exception as e:
logger.error(f"[cap.person.get_id_by_name] 执行失败: {e}", exc_info=True)
return {"success": False, "error": str(e)}
async def _cap_emoji_get_by_description(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
from src.services import emoji_service as emoji_api
description: str = args.get("description", "")
if not description:
return {"success": False, "error": "缺少必要参数 description"}
try:
result = await emoji_api.get_by_description(description=description)
if result is None:
return {"success": True, "emoji": None}
emoji_base64, emoji_desc, matched_emotion = result
return {
"success": True,
"emoji": {
"base64": emoji_base64,
"description": emoji_desc,
"emotion": matched_emotion,
},
}
except Exception as e:
logger.error(f"[cap.emoji.get_by_description] 执行失败: {e}", exc_info=True)
return {"success": False, "error": str(e)}
async def _cap_emoji_get_random(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
from src.services import emoji_service as emoji_api
count: int = args.get("count", 1)
try:
results = await emoji_api.get_random(count=count)
emojis = [{"base64": b64, "description": desc, "emotion": emo} for b64, desc, emo in results]
return {"success": True, "emojis": emojis}
except Exception as e:
logger.error(f"[cap.emoji.get_random] 执行失败: {e}", exc_info=True)
return {"success": False, "error": str(e)}
async def _cap_emoji_get_count(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
from src.services import emoji_service as emoji_api
try:
return {"success": True, "count": emoji_api.get_count()}
except Exception as e:
logger.error(f"[cap.emoji.get_count] 执行失败: {e}", exc_info=True)
return {"success": False, "error": str(e)}
async def _cap_emoji_get_emotions(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
from src.services import emoji_service as emoji_api
try:
return {"success": True, "emotions": emoji_api.get_emotions()}
except Exception as e:
logger.error(f"[cap.emoji.get_emotions] 执行失败: {e}", exc_info=True)
return {"success": False, "error": str(e)}
async def _cap_emoji_get_all(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
from src.services import emoji_service as emoji_api
try:
results = await emoji_api.get_all()
emojis = [{"base64": b64, "description": desc, "emotion": emo} for b64, desc, emo in results] if results else []
return {"success": True, "emojis": emojis}
except Exception as e:
logger.error(f"[cap.emoji.get_all] 执行失败: {e}", exc_info=True)
return {"success": False, "error": str(e)}
async def _cap_emoji_get_info(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
from src.services import emoji_service as emoji_api
try:
return {"success": True, "info": emoji_api.get_info()}
except Exception as e:
logger.error(f"[cap.emoji.get_info] 执行失败: {e}", exc_info=True)
return {"success": False, "error": str(e)}
async def _cap_emoji_register(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
from src.services import emoji_service as emoji_api
emoji_base64: str = args.get("emoji_base64", "")
if not emoji_base64:
return {"success": False, "error": "缺少必要参数 emoji_base64"}
try:
return await emoji_api.register_emoji(emoji_base64)
except Exception as e:
logger.error(f"[cap.emoji.register] 执行失败: {e}", exc_info=True)
return {"success": False, "error": str(e)}
async def _cap_emoji_delete(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
from src.services import emoji_service as emoji_api
emoji_hash: str = args.get("emoji_hash", "")
if not emoji_hash:
return {"success": False, "error": "缺少必要参数 emoji_hash"}
try:
return await emoji_api.delete_emoji(emoji_hash)
except Exception as e:
logger.error(f"[cap.emoji.delete] 执行失败: {e}", exc_info=True)
return {"success": False, "error": str(e)}
@staticmethod
def _get_frequency_adjust_value(chat_id: str) -> float:
from src.chat.heart_flow.heartflow_manager import heartflow_manager
heartflow_chat = heartflow_manager.heartflow_chat_list.get(chat_id)
return 1.0 if heartflow_chat is None else heartflow_chat._talk_frequency_adjust
async def _cap_frequency_get_current_talk_value(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
from src.common.utils.utils_config import ChatConfigUtils
chat_id: str = args.get("chat_id", "")
if not chat_id:
return {"success": False, "error": "缺少必要参数 chat_id"}
try:
value = self._get_frequency_adjust_value(chat_id) * ChatConfigUtils.get_talk_value(chat_id)
return {"success": True, "value": value}
except Exception as e:
logger.error(f"[cap.frequency.get_current_talk_value] 执行失败: {e}", exc_info=True)
return {"success": False, "error": str(e)}
async def _cap_frequency_set_adjust(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
from src.chat.heart_flow.heartflow_manager import heartflow_manager
chat_id: str = args.get("chat_id", "")
value = args.get("value")
if not chat_id or value is None:
return {"success": False, "error": "缺少必要参数 chat_id 或 value"}
try:
heartflow_manager.adjust_talk_frequency(chat_id, float(value))
return {"success": True}
except Exception as e:
logger.error(f"[cap.frequency.set_adjust] 执行失败: {e}", exc_info=True)
return {"success": False, "error": str(e)}
async def _cap_frequency_get_adjust(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
chat_id: str = args.get("chat_id", "")
if not chat_id:
return {"success": False, "error": "缺少必要参数 chat_id"}
try:
value = self._get_frequency_adjust_value(chat_id)
return {"success": True, "value": value}
except Exception as e:
logger.error(f"[cap.frequency.get_adjust] 执行失败: {e}", exc_info=True)
return {"success": False, "error": str(e)}
async def _cap_tool_get_definitions(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
from src.core.component_registry import component_registry as core_registry
try:
tools = core_registry.get_llm_available_tools()
return {
"success": True,
"tools": [{"name": name, "definition": info.get_llm_definition()} for name, info in tools.items()],
}
except Exception as e:
logger.error(f"[cap.tool.get_definitions] 执行失败: {e}", exc_info=True)
return {"success": False, "error": str(e)}
async def _cap_knowledge_search(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
query: str = args.get("query", "")
if not query:
return {"success": False, "error": "缺少必要参数 query"}
limit = args.get("limit", 5)
try:
limit_value = max(1, int(limit))
except (TypeError, ValueError):
limit_value = 5
try:
from src.chat.knowledge import qa_manager
if qa_manager is None:
return {"success": True, "content": "LPMM知识库已禁用"}
knowledge_info = await qa_manager.get_knowledge(query, limit=limit_value)
content = f"你知道这些知识: {knowledge_info}" if knowledge_info else f"你不太了解有关{query}的知识"
return {"success": True, "content": content}
except Exception as e:
logger.error(f"[cap.knowledge.search] 执行失败: {e}", exc_info=True)
return {"success": False, "error": str(e)}

View File

@@ -0,0 +1,80 @@
from typing import TYPE_CHECKING
from src.common.logger import get_logger
from src.plugin_runtime.host.supervisor import PluginSupervisor
if TYPE_CHECKING:
from src.plugin_runtime.integration import PluginRuntimeManager
logger = get_logger("plugin_runtime.integration")
def register_capability_impls(manager: "PluginRuntimeManager", supervisor: PluginSupervisor) -> None:
"""向指定 Supervisor 注册主程序提供的能力实现。"""
cap_service = supervisor.capability_service
cap_service.register_capability("send.text", manager._cap_send_text)
cap_service.register_capability("send.emoji", manager._cap_send_emoji)
cap_service.register_capability("send.image", manager._cap_send_image)
cap_service.register_capability("send.command", manager._cap_send_command)
cap_service.register_capability("send.custom", manager._cap_send_custom)
cap_service.register_capability("send.forward", manager._cap_send_forward)
cap_service.register_capability("send.hybrid", manager._cap_send_hybrid)
cap_service.register_capability("llm.generate", manager._cap_llm_generate)
cap_service.register_capability("llm.generate_with_tools", manager._cap_llm_generate_with_tools)
cap_service.register_capability("llm.get_available_models", manager._cap_llm_get_available_models)
cap_service.register_capability("config.get", manager._cap_config_get)
cap_service.register_capability("config.get_plugin", manager._cap_config_get_plugin)
cap_service.register_capability("config.get_all", manager._cap_config_get_all)
cap_service.register_capability("database.query", manager._cap_database_query)
cap_service.register_capability("database.save", manager._cap_database_save)
cap_service.register_capability("database.get", manager._cap_database_get)
cap_service.register_capability("database.delete", manager._cap_database_delete)
cap_service.register_capability("database.count", manager._cap_database_count)
cap_service.register_capability("chat.get_all_streams", manager._cap_chat_get_all_streams)
cap_service.register_capability("chat.get_group_streams", manager._cap_chat_get_group_streams)
cap_service.register_capability("chat.get_private_streams", manager._cap_chat_get_private_streams)
cap_service.register_capability("chat.get_stream_by_group_id", manager._cap_chat_get_stream_by_group_id)
cap_service.register_capability("chat.get_stream_by_user_id", manager._cap_chat_get_stream_by_user_id)
cap_service.register_capability("message.get_by_time", manager._cap_message_get_by_time)
cap_service.register_capability("message.get_by_time_in_chat", manager._cap_message_get_by_time_in_chat)
cap_service.register_capability("message.get_recent", manager._cap_message_get_recent)
cap_service.register_capability("message.count_new", manager._cap_message_count_new)
cap_service.register_capability("message.build_readable", manager._cap_message_build_readable)
cap_service.register_capability("person.get_id", manager._cap_person_get_id)
cap_service.register_capability("person.get_value", manager._cap_person_get_value)
cap_service.register_capability("person.get_id_by_name", manager._cap_person_get_id_by_name)
cap_service.register_capability("emoji.get_by_description", manager._cap_emoji_get_by_description)
cap_service.register_capability("emoji.get_random", manager._cap_emoji_get_random)
cap_service.register_capability("emoji.get_count", manager._cap_emoji_get_count)
cap_service.register_capability("emoji.get_emotions", manager._cap_emoji_get_emotions)
cap_service.register_capability("emoji.get_all", manager._cap_emoji_get_all)
cap_service.register_capability("emoji.get_info", manager._cap_emoji_get_info)
cap_service.register_capability("emoji.register", manager._cap_emoji_register)
cap_service.register_capability("emoji.delete", manager._cap_emoji_delete)
cap_service.register_capability("frequency.get_current_talk_value", manager._cap_frequency_get_current_talk_value)
cap_service.register_capability("frequency.set_adjust", manager._cap_frequency_set_adjust)
cap_service.register_capability("frequency.get_adjust", manager._cap_frequency_get_adjust)
cap_service.register_capability("tool.get_definitions", manager._cap_tool_get_definitions)
cap_service.register_capability("component.get_all_plugins", manager._cap_component_get_all_plugins)
cap_service.register_capability("component.get_plugin_info", manager._cap_component_get_plugin_info)
cap_service.register_capability("component.list_loaded_plugins", manager._cap_component_list_loaded_plugins)
cap_service.register_capability("component.list_registered_plugins", manager._cap_component_list_registered_plugins)
cap_service.register_capability("component.enable", manager._cap_component_enable)
cap_service.register_capability("component.disable", manager._cap_component_disable)
cap_service.register_capability("component.load_plugin", manager._cap_component_load_plugin)
cap_service.register_capability("component.unload_plugin", manager._cap_component_unload_plugin)
cap_service.register_capability("component.reload_plugin", manager._cap_component_reload_plugin)
cap_service.register_capability("knowledge.search", manager._cap_knowledge_search)
logger.debug("已注册全部主程序能力实现")

File diff suppressed because it is too large Load Diff

View File

@@ -4,7 +4,7 @@
提供聊天信息查询和管理的核心功能。
"""
from typing import Any, Dict, List, Optional
from typing import Any, Callable, Dict, List, Optional
from enum import Enum
from src.chat.message_receive.chat_manager import BotChatSession, chat_manager as _chat_manager
@@ -23,50 +23,58 @@ class ChatManager:
"""聊天管理器 - 负责聊天信息的查询和管理"""
@staticmethod
def get_all_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]:
# sourcery skip: for-append-to-extend
def _validate_platform(platform: Optional[str] | SpecialTypes) -> None:
if not isinstance(platform, (str, SpecialTypes)):
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
streams = []
@staticmethod
def _match_platform(chat_stream: BotChatSession, platform: Optional[str] | SpecialTypes) -> bool:
return platform == SpecialTypes.ALL_PLATFORMS or chat_stream.platform == platform
@staticmethod
def _get_streams(
platform: Optional[str] | SpecialTypes = "qq", is_group_session: Optional[bool] = None
) -> List[BotChatSession]:
ChatManager._validate_platform(platform)
try:
for _, stream in _chat_manager.sessions.items():
if platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform:
streams.append(stream)
logger.debug(f"[ChatService] 获取到 {len(streams)}{platform} 平台的聊天流")
streams = [
stream
for stream in _chat_manager.sessions.values()
if ChatManager._match_platform(stream, platform)
and (is_group_session is None or stream.is_group_session == is_group_session)
]
return streams
except Exception as e:
logger.error(f"[ChatService] 获取聊天流失败: {e}")
return []
@staticmethod
def _find_stream(
predicate: Callable[[BotChatSession], bool],
platform: Optional[str] | SpecialTypes = "qq",
) -> Optional[BotChatSession]:
for stream in ChatManager._get_streams(platform=platform):
if predicate(stream):
return stream
return None
@staticmethod
def get_all_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]:
streams = ChatManager._get_streams(platform=platform)
logger.debug(f"[ChatService] 获取到 {len(streams)}{platform} 平台的聊天流")
return streams
@staticmethod
def get_group_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]:
# sourcery skip: for-append-to-extend
if not isinstance(platform, (str, SpecialTypes)):
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
streams = []
try:
for _, stream in _chat_manager.sessions.items():
if (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform) and stream.is_group_session:
streams.append(stream)
logger.debug(f"[ChatService] 获取到 {len(streams)}{platform} 平台的群聊流")
except Exception as e:
logger.error(f"[ChatService] 获取群聊流失败: {e}")
streams = ChatManager._get_streams(platform=platform, is_group_session=True)
logger.debug(f"[ChatService] 获取到 {len(streams)}{platform} 平台的群聊流")
return streams
@staticmethod
def get_private_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]:
# sourcery skip: for-append-to-extend
if not isinstance(platform, (str, SpecialTypes)):
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
streams = []
try:
for _, stream in _chat_manager.sessions.items():
if (
platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform
) and not stream.is_group_session:
streams.append(stream)
logger.debug(f"[ChatService] 获取到 {len(streams)}{platform} 平台的私聊流")
except Exception as e:
logger.error(f"[ChatService] 获取私聊流失败: {e}")
streams = ChatManager._get_streams(platform=platform, is_group_session=False)
logger.debug(f"[ChatService] 获取到 {len(streams)}{platform} 平台的私聊流")
return streams
@staticmethod
@@ -75,19 +83,17 @@ class ChatManager:
) -> Optional[BotChatSession]: # sourcery skip: remove-unnecessary-cast
if not isinstance(group_id, str):
raise TypeError("group_id 必须是字符串类型")
if not isinstance(platform, (str, SpecialTypes)):
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
ChatManager._validate_platform(platform)
if not group_id:
raise ValueError("group_id 不能为空")
try:
for _, stream in _chat_manager.sessions.items():
if (
stream.is_group_session
and str(stream.group_id) == str(group_id)
and (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform)
):
logger.debug(f"[ChatService] 找到群ID {group_id} 的聊天流")
return stream
stream = ChatManager._find_stream(
lambda item: item.is_group_session and str(item.group_id) == str(group_id),
platform=platform,
)
if stream is not None:
logger.debug(f"[ChatService] 找到群ID {group_id} 的聊天流")
return stream
logger.warning(f"[ChatService] 未找到群ID {group_id} 的聊天流")
except Exception as e:
logger.error(f"[ChatService] 查找群聊流失败: {e}")
@@ -99,19 +105,17 @@ class ChatManager:
) -> Optional[BotChatSession]: # sourcery skip: remove-unnecessary-cast
if not isinstance(user_id, str):
raise TypeError("user_id 必须是字符串类型")
if not isinstance(platform, (str, SpecialTypes)):
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
ChatManager._validate_platform(platform)
if not user_id:
raise ValueError("user_id 不能为空")
try:
for _, stream in _chat_manager.sessions.items():
if (
not stream.is_group_session
and str(stream.user_id) == str(user_id)
and (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform)
):
logger.debug(f"[ChatService] 找到用户ID {user_id} 的私聊流")
return stream
stream = ChatManager._find_stream(
lambda item: (not item.is_group_session) and str(item.user_id) == str(user_id),
platform=platform,
)
if stream is not None:
logger.debug(f"[ChatService] 找到用户ID {user_id} 的私聊流")
return stream
logger.warning(f"[ChatService] 未找到用户ID {user_id} 的私聊流")
except Exception as e:
logger.error(f"[ChatService] 查找私聊流失败: {e}")

View File

@@ -1,13 +1,16 @@
"""数据库服务模块
提供数据库操作相关的核心功能。
"""
"""数据库服务模块"""
import json
import time
import traceback
from datetime import datetime
from typing import Any, Optional
from sqlalchemy import delete, func, select
from sqlmodel import SQLModel
from src.common.database.database import get_db_session
from src.common.database.database_model import ActionRecord
from src.common.logger import get_logger
logger = get_logger("database_service")
@@ -25,73 +28,57 @@ def _to_dict(record: Any) -> dict[str, Any]:
return {}
async def db_query(
model_class,
data: Optional[dict[str, Any]] = None,
query_type: str = "get",
filters: Optional[dict[str, Any]] = None,
limit: Optional[int] = None,
order_by: Optional[list[str]] = None,
single_result: bool = False,
def _get_model_field(model_class: type[SQLModel], field_name: str):
field = getattr(model_class, field_name, None)
if field is None:
raise ValueError(f"{model_class.__name__} 不存在字段 {field_name}")
return field
def _build_filters(model_class: type[SQLModel], filters: Optional[dict[str, Any]] = None) -> list[Any]:
if not filters:
return []
return [_get_model_field(model_class, field_name) == value for field_name, value in filters.items()]
def _apply_order_by(statement, model_class: type[SQLModel], order_by: Optional[str | list[str]] = None):
if not order_by:
return statement
order_fields = [order_by] if isinstance(order_by, str) else order_by
clauses = []
for item in order_fields:
descending = item.startswith("-")
field_name = item[1:] if descending else item
field = _get_model_field(model_class, field_name)
clauses.append(field.desc() if descending else field.asc())
return statement.order_by(*clauses)
async def db_save(
model_class: type[SQLModel],
data: dict[str, Any],
key_field: Optional[str] = None,
key_value: Optional[Any] = None,
):
try:
if query_type not in ["get", "create", "update", "delete", "count"]:
raise ValueError("query_type must be 'get' or 'create' or 'update' or 'delete' or 'count'")
with get_db_session() as session:
record = None
if key_field and key_value is not None:
key_column = _get_model_field(model_class, key_field)
record = session.exec(select(model_class).where(key_column == key_value)).first()
if query_type == "get":
query = model_class.select()
if filters:
for field, value in filters.items():
query = query.where(getattr(model_class, field) == value)
if order_by:
query = query.order_by(*order_by)
if limit:
query = query.limit(limit)
results = list(query.dicts())
if single_result:
return results[0] if results else None
return results
if record is None:
record = model_class(**data)
else:
for field_name, value in data.items():
_get_model_field(model_class, field_name)
setattr(record, field_name, value)
if query_type == "create":
if not data:
raise ValueError("创建记录需要提供data参数")
record = model_class.create(**data)
session.add(record)
session.flush()
session.refresh(record)
return _to_dict(record)
query = model_class.select()
if filters:
for field, value in filters.items():
query = query.where(getattr(model_class, field) == value)
if query_type == "update":
if not data:
raise ValueError("更新记录需要提供data参数")
return query.model_class.update(**data).where(*query.stmt._where_criteria).execute()
if query_type == "delete":
return model_class.delete().where(*query.stmt._where_criteria).execute()
return query.count()
except Exception as e:
logger.error(f"[DatabaseService] 数据库操作出错: {e}")
traceback.print_exc()
if query_type == "get":
return None if single_result else []
return None
async def db_save(model_class, data: dict[str, Any], key_field: Optional[str] = None, key_value: Optional[Any] = None):
try:
if key_field and key_value is not None:
record = model_class.get_or_none(getattr(model_class, key_field) == key_value)
if record is not None:
for field, value in data.items():
setattr(record, field, value)
record.save()
return _to_dict(record)
new_record = model_class.create(**data)
return _to_dict(new_record)
except Exception as e:
logger.error(f"[DatabaseService] 保存数据库记录出错: {e}")
traceback.print_exc()
@@ -99,68 +86,108 @@ async def db_save(model_class, data: dict[str, Any], key_field: Optional[str] =
async def db_get(
model_class,
model_class: type[SQLModel],
filters: Optional[dict[str, Any]] = None,
limit: Optional[int] = None,
order_by: Optional[str] = None,
order_by: Optional[str | list[str]] = None,
single_result: bool = False,
):
try:
query = model_class.select()
if filters:
for field, value in filters.items():
query = query.where(getattr(model_class, field) == value)
if order_by:
query = query.order_by(order_by)
if limit:
query = query.limit(limit)
results = list(query.dicts())
if single_result:
return results[0] if results else None
return results
with get_db_session(auto_commit=False) as session:
statement = select(model_class)
conditions = _build_filters(model_class, filters)
if conditions:
statement = statement.where(*conditions)
statement = _apply_order_by(statement, model_class, order_by)
if limit:
statement = statement.limit(limit)
results = session.exec(statement).all()
data = [_to_dict(item) for item in results]
if single_result:
return data[0] if data else None
return data
except Exception as e:
logger.error(f"[DatabaseService] 获取数据库记录出错: {e}")
traceback.print_exc()
return None if single_result else []
async def db_update(model_class: type[SQLModel], data: dict[str, Any], filters: Optional[dict[str, Any]] = None) -> int:
try:
with get_db_session() as session:
statement = select(model_class)
conditions = _build_filters(model_class, filters)
if conditions:
statement = statement.where(*conditions)
records = session.exec(statement).all()
for record in records:
for field_name, value in data.items():
_get_model_field(model_class, field_name)
setattr(record, field_name, value)
session.add(record)
return len(records)
except Exception as e:
logger.error(f"[DatabaseService] 更新数据库记录出错: {e}")
traceback.print_exc()
return 0
async def db_delete(model_class: type[SQLModel], filters: Optional[dict[str, Any]] = None) -> int:
try:
with get_db_session() as session:
statement = delete(model_class)
conditions = _build_filters(model_class, filters)
if conditions:
statement = statement.where(*conditions)
result = session.exec(statement)
return result.rowcount or 0
except Exception as e:
logger.error(f"[DatabaseService] 删除数据库记录出错: {e}")
traceback.print_exc()
return 0
async def db_count(model_class: type[SQLModel], filters: Optional[dict[str, Any]] = None) -> int:
try:
with get_db_session(auto_commit=False) as session:
statement = select(func.count()).select_from(model_class)
conditions = _build_filters(model_class, filters)
if conditions:
statement = statement.where(*conditions)
result = session.exec(statement).one()
return int(result or 0)
except Exception as e:
logger.error(f"[DatabaseService] 统计数据库记录出错: {e}")
traceback.print_exc()
return 0
async def store_action_info(
chat_stream=None,
action_build_into_prompt: bool = False,
action_prompt_display: str = "",
action_done: bool = True,
builtin_prompt: Optional[str] = None,
display_prompt: str = "",
thinking_id: str = "",
action_data: Optional[dict] = None,
action_name: str = "",
action_reasoning: str = "",
):
try:
from src.common.database.database_model import ActionRecords
if chat_stream is None:
raise ValueError("store_action_info 需要 chat_stream")
record_data = {
"action_id": thinking_id or str(int(time.time() * 1000000)),
"time": time.time(),
"timestamp": datetime.now(),
"session_id": chat_stream.session_id,
"action_name": action_name,
"action_data": json.dumps(action_data or {}, ensure_ascii=False),
"action_done": action_done,
"action_reasoning": action_reasoning,
"action_build_into_prompt": action_build_into_prompt,
"action_prompt_display": action_prompt_display,
"action_builtin_prompt": builtin_prompt,
"action_display_prompt": display_prompt,
}
if chat_stream:
record_data.update(
{
"chat_id": getattr(chat_stream, "stream_id", ""),
"chat_info_stream_id": getattr(chat_stream, "stream_id", ""),
"chat_info_platform": getattr(chat_stream, "platform", ""),
}
)
else:
record_data.update({"chat_id": "", "chat_info_stream_id": "", "chat_info_platform": ""})
saved_record = await db_save(
ActionRecords, data=record_data, key_field="action_id", key_value=record_data["action_id"]
ActionRecord, data=record_data, key_field="action_id", key_value=record_data["action_id"]
)
if saved_record:
logger.debug(f"[DatabaseService] 成功存储动作信息: {action_name} (ID: {record_data['action_id']})")

View File

@@ -1,21 +0,0 @@
"""频率控制服务模块
提供聊天频率控制的核心功能。
"""
from src.chat.heart_flow.frequency_control import frequency_control_manager
from src.config.config import global_config
def get_current_talk_value(chat_id: str) -> float:
return frequency_control_manager.get_or_create_frequency_control(
chat_id
).get_talk_frequency_adjust() * global_config.chat.get_talk_value(chat_id)
def set_talk_frequency_adjust(chat_id: str, talk_frequency_adjust: float) -> None:
frequency_control_manager.get_or_create_frequency_control(chat_id).set_talk_frequency_adjust(talk_frequency_adjust)
def get_talk_frequency_adjust(chat_id: str) -> float:
return frequency_control_manager.get_or_create_frequency_control(chat_id).get_talk_frequency_adjust()

View File

@@ -16,6 +16,38 @@ from src.llm_models.utils_model import LLMRequest
logger = get_logger("llm_service")
async def _generate_response(
model_config: TaskConfig,
request_type: str,
prompt: Optional[str] = None,
message_factory: Optional[Callable[[BaseClient], List[Message]]] = None,
tool_options: Optional[List[Dict[str, Any]]] = None,
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
) -> Tuple[str, str, str, List[ToolCall] | None]:
llm_request = LLMRequest(model_set=model_config, request_type=request_type)
if message_factory is not None:
response, (reasoning_content, model_name, tool_call) = await llm_request.generate_response_with_message_async(
message_factory=message_factory,
tools=tool_options,
temperature=temperature,
max_tokens=max_tokens,
)
return response, reasoning_content, model_name, tool_call
if prompt is None:
raise ValueError("prompt 与 message_factory 不能同时为空")
response, (reasoning_content, model_name, tool_call) = await llm_request.generate_response_async(
prompt,
tools=tool_options,
temperature=temperature,
max_tokens=max_tokens,
)
return response, reasoning_content, model_name, tool_call
def get_available_models() -> Dict[str, TaskConfig]:
"""获取所有可用的模型配置
@@ -61,11 +93,12 @@ async def generate_with_model(
"""
try:
logger.debug(f"[LLMService] 完整提示词: {prompt}")
llm_request = LLMRequest(model_set=model_config, request_type=request_type)
response, (reasoning_content, model_name, _) = await llm_request.generate_response_async(
prompt, temperature=temperature, max_tokens=max_tokens
response, reasoning_content, model_name, _ = await _generate_response(
model_config=model_config,
request_type=request_type,
prompt=prompt,
temperature=temperature,
max_tokens=max_tokens,
)
return True, response, reasoning_content, model_name
@@ -101,10 +134,13 @@ async def generate_with_model_with_tools(
logger.info(f"使用模型{model_name_list}生成内容")
logger.debug(f"完整提示词: {prompt}")
llm_request = LLMRequest(model_set=model_config, request_type=request_type)
response, (reasoning_content, model_name, tool_call) = await llm_request.generate_response_async(
prompt, tools=tool_options, temperature=temperature, max_tokens=max_tokens
response, reasoning_content, model_name, tool_call = await _generate_response(
model_config=model_config,
request_type=request_type,
prompt=prompt,
tool_options=tool_options,
temperature=temperature,
max_tokens=max_tokens,
)
return True, response, reasoning_content, model_name, tool_call
@@ -139,11 +175,11 @@ async def generate_with_model_with_tools_by_message_factory(
model_name_list = model_config.model_list
logger.info(f"使用模型 {model_name_list} 生成内容")
llm_request = LLMRequest(model_set=model_config, request_type=request_type)
response, (reasoning_content, model_name, tool_call) = await llm_request.generate_response_with_message_async(
response, reasoning_content, model_name, tool_call = await _generate_response(
model_config=model_config,
request_type=request_type,
message_factory=message_factory,
tools=tool_options,
tool_options=tool_options,
temperature=temperature,
max_tokens=max_tokens,
)

View File

@@ -1,65 +0,0 @@
"""个人信息服务模块
提供个人信息查询的核心功能。
"""
from typing import Any
from src.common.logger import get_logger
from src.person_info.person_info import Person
logger = get_logger("person_service")
def get_person_id(platform: str, user_id: int | str) -> str:
"""根据平台和用户ID获取person_id
Args:
platform: 平台名称,如 "qq", "telegram"
user_id: 用户ID
Returns:
str: 唯一的person_idMD5哈希值
"""
try:
return Person(platform=platform, user_id=str(user_id)).person_id
except Exception as e:
logger.error(f"[PersonService] 获取person_id失败: platform={platform}, user_id={user_id}, error={e}")
return ""
async def get_person_value(person_id: str, field_name: str, default: Any = None) -> Any:
"""根据person_id和字段名获取某个值
Args:
person_id: 用户的唯一标识ID
field_name: 要获取的字段名,如 "nickname", "impression"
default: 当字段不存在或获取失败时返回的默认值
Returns:
Any: 字段值或默认值
"""
try:
person = Person(person_id=person_id)
value = getattr(person, field_name)
return value if value is not None else default
except Exception as e:
logger.error(f"[PersonService] 获取用户信息失败: person_id={person_id}, field={field_name}, error={e}")
return default
def get_person_id_by_name(person_name: str) -> str:
"""根据用户名获取person_id
Args:
person_name: 用户名
Returns:
str: person_id如果未找到返回空字符串
"""
try:
person = Person(person_name=person_name)
return person.person_id
except Exception as e:
logger.error(f"[PersonService] 根据用户名获取person_id失败: person_name={person_name}, error={e}")
return ""