feat:支持msg_id获取消息能力
This commit is contained in:
@@ -19,7 +19,7 @@ def get_tool_spec() -> ToolSpec:
|
|||||||
|
|
||||||
return ToolSpec(
|
return ToolSpec(
|
||||||
name="at",
|
name="at",
|
||||||
brief_description="根据一条已知 msg_id 找到发言用户,并发送一条 @ 该用户的消息。",
|
brief_description="当明确提及某位用户时,发送一条 @ 该用户的消息。",
|
||||||
detailed_description=(
|
detailed_description=(
|
||||||
"参数说明:\n"
|
"参数说明:\n"
|
||||||
"- msg_id:string,必填。要 @ 的目标用户发过的消息编号。\n"
|
"- msg_id:string,必填。要 @ 的目标用户发过的消息编号。\n"
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ from src.chat.message_receive.chat_manager import BotChatSession, chat_manager
|
|||||||
from src.common.data_models.image_data_model import MaiEmoji
|
from src.common.data_models.image_data_model import MaiEmoji
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.common.utils.utils_image import ImageUtils
|
from src.common.utils.utils_image import ImageUtils
|
||||||
|
from src.plugin_runtime.host.message_utils import PluginMessageUtils
|
||||||
|
|
||||||
logger = get_logger("plugin_runtime.integration")
|
logger = get_logger("plugin_runtime.integration")
|
||||||
|
|
||||||
@@ -298,7 +299,9 @@ class RuntimeDataCapabilityMixin:
|
|||||||
def _serialize_messages(messages: list) -> List[Any]:
|
def _serialize_messages(messages: list) -> List[Any]:
|
||||||
result: List[Any] = []
|
result: List[Any] = []
|
||||||
for msg in messages:
|
for msg in messages:
|
||||||
if hasattr(msg, "model_dump"):
|
if all(hasattr(msg, attr) for attr in ("message_id", "timestamp", "platform", "message_info", "raw_message")):
|
||||||
|
result.append(dict(PluginMessageUtils._session_message_to_dict(msg)))
|
||||||
|
elif hasattr(msg, "model_dump"):
|
||||||
result.append(msg.model_dump())
|
result.append(msg.model_dump())
|
||||||
elif hasattr(msg, "__dict__"):
|
elif hasattr(msg, "__dict__"):
|
||||||
result.append(dict(msg.__dict__))
|
result.append(dict(msg.__dict__))
|
||||||
@@ -306,6 +309,24 @@ class RuntimeDataCapabilityMixin:
|
|||||||
result.append(str(msg))
|
result.append(str(msg))
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
async def _cap_message_get_by_id(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||||
|
from src.services import message_service
|
||||||
|
|
||||||
|
message_id = str(args.get("message_id") or args.get("msg_id") or "").strip()
|
||||||
|
if not message_id:
|
||||||
|
return {"success": False, "error": "缺少必要参数 message_id"}
|
||||||
|
|
||||||
|
try:
|
||||||
|
message = message_service.get_message_by_id(
|
||||||
|
message_id=message_id,
|
||||||
|
chat_id=str(args.get("chat_id") or args.get("stream_id") or "").strip() or None,
|
||||||
|
)
|
||||||
|
serialized_message = self._serialize_messages([message])[0] if message is not None else None
|
||||||
|
return {"success": True, "message": serialized_message}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[cap.message.get_by_id] 执行失败: {e}", exc_info=True)
|
||||||
|
return {"success": False, "error": str(e)}
|
||||||
|
|
||||||
async def _cap_message_get_by_time(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
async def _cap_message_get_by_time(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||||
from src.services import message_service
|
from src.services import message_service
|
||||||
|
|
||||||
|
|||||||
@@ -51,6 +51,7 @@ def register_capability_impls(manager: "PluginRuntimeManager", supervisor: Plugi
|
|||||||
|
|
||||||
_register("message.get_by_time", manager._cap_message_get_by_time)
|
_register("message.get_by_time", manager._cap_message_get_by_time)
|
||||||
_register("message.get_by_time_in_chat", manager._cap_message_get_by_time_in_chat)
|
_register("message.get_by_time_in_chat", manager._cap_message_get_by_time_in_chat)
|
||||||
|
_register("message.get_by_id", manager._cap_message_get_by_id)
|
||||||
_register("message.get_recent", manager._cap_message_get_recent)
|
_register("message.get_recent", manager._cap_message_get_recent)
|
||||||
_register("message.count_new", manager._cap_message_count_new)
|
_register("message.count_new", manager._cap_message_count_new)
|
||||||
_register("message.build_readable", manager._cap_message_build_readable)
|
_register("message.build_readable", manager._cap_message_build_readable)
|
||||||
|
|||||||
@@ -97,6 +97,24 @@ def get_messages_by_time_in_chat(
|
|||||||
return _normalize_messages(messages)
|
return _normalize_messages(messages)
|
||||||
|
|
||||||
|
|
||||||
|
def get_message_by_id(message_id: str, chat_id: Optional[str] = None) -> Optional[SessionMessage]:
|
||||||
|
"""按消息 ID 查询单条消息,可选限定会话。"""
|
||||||
|
|
||||||
|
normalized_message_id = str(message_id or "").strip()
|
||||||
|
if not normalized_message_id:
|
||||||
|
raise ValueError("message_id 不能为空")
|
||||||
|
|
||||||
|
normalized_chat_id = str(chat_id or "").strip()
|
||||||
|
messages = find_messages(
|
||||||
|
session_id=normalized_chat_id or None,
|
||||||
|
message_id=normalized_message_id,
|
||||||
|
limit=1,
|
||||||
|
limit_mode="latest",
|
||||||
|
)
|
||||||
|
normalized_messages = _normalize_messages(messages)
|
||||||
|
return normalized_messages[0] if normalized_messages else None
|
||||||
|
|
||||||
|
|
||||||
def get_messages_before_time(timestamp: float, limit: int = 0, filter_mai: bool = False) -> List[SessionMessage]:
|
def get_messages_before_time(timestamp: float, limit: int = 0, filter_mai: bool = False) -> List[SessionMessage]:
|
||||||
if not isinstance(timestamp, (int, float)):
|
if not isinstance(timestamp, (int, float)):
|
||||||
raise ValueError("timestamp 必须是数字类型")
|
raise ValueError("timestamp 必须是数字类型")
|
||||||
|
|||||||
Reference in New Issue
Block a user