炸 service 层 x 2,把能归类为现有重构好的模块的都归类过去
This commit is contained in:
@@ -28,10 +28,7 @@ from src.services import (
|
|||||||
message_service as message_api,
|
message_service as message_api,
|
||||||
database_service as database_api,
|
database_service as database_api,
|
||||||
)
|
)
|
||||||
from src.chat.utils.chat_message_builder import (
|
from src.services.message_service import build_readable_messages_with_id, get_messages_before_time_in_chat
|
||||||
build_readable_messages_with_id,
|
|
||||||
get_raw_msg_before_timestamp_with_chat,
|
|
||||||
)
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from src.common.data_models.database_data_model import DatabaseMessages
|
from src.common.data_models.database_data_model import DatabaseMessages
|
||||||
@@ -275,7 +272,7 @@ class BrainChatting:
|
|||||||
|
|
||||||
# 一次思考迭代:Think - Act - Observe
|
# 一次思考迭代:Think - Act - Observe
|
||||||
# 获取聊天上下文
|
# 获取聊天上下文
|
||||||
message_list_before_now = get_raw_msg_before_timestamp_with_chat(
|
message_list_before_now = get_messages_before_time_in_chat(
|
||||||
chat_id=self.stream_id,
|
chat_id=self.stream_id,
|
||||||
timestamp=time.time(),
|
timestamp=time.time(),
|
||||||
limit=int(global_config.chat.max_context_size * 0.6),
|
limit=int(global_config.chat.max_context_size * 0.6),
|
||||||
|
|||||||
@@ -14,11 +14,11 @@ from src.common.logger import get_logger
|
|||||||
from src.chat.logger.plan_reply_logger import PlanReplyLogger
|
from src.chat.logger.plan_reply_logger import PlanReplyLogger
|
||||||
from src.common.data_models.info_data_model import ActionPlannerInfo
|
from src.common.data_models.info_data_model import ActionPlannerInfo
|
||||||
from src.prompt.prompt_manager import prompt_manager
|
from src.prompt.prompt_manager import prompt_manager
|
||||||
from src.chat.utils.chat_message_builder import (
|
from src.services.message_service import (
|
||||||
build_readable_actions,
|
build_readable_actions,
|
||||||
get_actions_by_timestamp_with_chat,
|
|
||||||
build_readable_messages_with_id,
|
build_readable_messages_with_id,
|
||||||
get_raw_msg_before_timestamp_with_chat,
|
get_actions_by_timestamp_with_chat,
|
||||||
|
get_messages_before_time_in_chat,
|
||||||
)
|
)
|
||||||
from src.chat.utils.utils import get_chat_type_and_target_info
|
from src.chat.utils.utils import get_chat_type_and_target_info
|
||||||
from src.chat.planner_actions.action_manager import ActionManager
|
from src.chat.planner_actions.action_manager import ActionManager
|
||||||
@@ -163,7 +163,7 @@ class BrainPlanner:
|
|||||||
plan_start = time.perf_counter()
|
plan_start = time.perf_counter()
|
||||||
|
|
||||||
# 获取聊天上下文
|
# 获取聊天上下文
|
||||||
message_list_before_now = get_raw_msg_before_timestamp_with_chat(
|
message_list_before_now = get_messages_before_time_in_chat(
|
||||||
chat_id=self.chat_id,
|
chat_id=self.chat_id,
|
||||||
timestamp=time.time(),
|
timestamp=time.time(),
|
||||||
limit=int(global_config.chat.max_context_size * 0.6),
|
limit=int(global_config.chat.max_context_size * 0.6),
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from src.common.logger import get_logger
|
|||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from src.chat.message_receive.chat_manager import BotChatSession, chat_manager as _chat_manager
|
from src.chat.message_receive.chat_manager import BotChatSession, chat_manager as _chat_manager
|
||||||
from src.chat.planner_actions.action_manager import ActionManager
|
from src.chat.planner_actions.action_manager import ActionManager
|
||||||
from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat, build_readable_messages
|
from src.services.message_service import build_readable_messages, get_messages_before_time_in_chat
|
||||||
from src.core.types import ActionActivationType, ActionInfo
|
from src.core.types import ActionActivationType, ActionInfo
|
||||||
from src.core.announcement_manager import global_announcement_manager
|
from src.core.announcement_manager import global_announcement_manager
|
||||||
|
|
||||||
@@ -51,7 +51,7 @@ class ActionModifier:
|
|||||||
self.action_manager.restore_actions()
|
self.action_manager.restore_actions()
|
||||||
all_actions = self.action_manager.get_using_actions()
|
all_actions = self.action_manager.get_using_actions()
|
||||||
|
|
||||||
message_list_before_now_half = get_raw_msg_before_timestamp_with_chat(
|
message_list_before_now_half = get_messages_before_time_in_chat(
|
||||||
chat_id=self.chat_stream.stream_id,
|
chat_id=self.chat_stream.stream_id,
|
||||||
timestamp=time.time(),
|
timestamp=time.time(),
|
||||||
limit=min(int(global_config.chat.max_context_size * 0.33), 10),
|
limit=min(int(global_config.chat.max_context_size * 0.33), 10),
|
||||||
|
|||||||
@@ -15,17 +15,17 @@ from src.common.logger import get_logger
|
|||||||
from src.chat.logger.plan_reply_logger import PlanReplyLogger
|
from src.chat.logger.plan_reply_logger import PlanReplyLogger
|
||||||
from src.common.data_models.info_data_model import ActionPlannerInfo
|
from src.common.data_models.info_data_model import ActionPlannerInfo
|
||||||
from src.prompt.prompt_manager import prompt_manager
|
from src.prompt.prompt_manager import prompt_manager
|
||||||
from src.chat.utils.chat_message_builder import (
|
from src.services.message_service import (
|
||||||
build_readable_messages_with_id,
|
build_readable_messages_with_id,
|
||||||
get_raw_msg_before_timestamp_with_chat,
|
|
||||||
replace_user_references,
|
replace_user_references,
|
||||||
|
get_messages_before_time_in_chat,
|
||||||
|
translate_pid_to_description,
|
||||||
)
|
)
|
||||||
from src.chat.utils.utils import get_chat_type_and_target_info, is_bot_self
|
from src.chat.utils.utils import get_chat_type_and_target_info, is_bot_self
|
||||||
from src.chat.planner_actions.action_manager import ActionManager
|
from src.chat.planner_actions.action_manager import ActionManager
|
||||||
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
||||||
from src.core.types import ActionActivationType, ActionInfo, ComponentType
|
from src.core.types import ActionActivationType, ActionInfo, ComponentType
|
||||||
from src.core.component_registry import component_registry
|
from src.core.component_registry import component_registry
|
||||||
from src.services.message_service import translate_pid_to_description
|
|
||||||
from src.person_info.person_info import Person
|
from src.person_info.person_info import Person
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -389,7 +389,7 @@ class ActionPlanner:
|
|||||||
plan_start = time.perf_counter()
|
plan_start = time.perf_counter()
|
||||||
|
|
||||||
# 获取聊天上下文
|
# 获取聊天上下文
|
||||||
message_list_before_now = get_raw_msg_before_timestamp_with_chat(
|
message_list_before_now = get_messages_before_time_in_chat(
|
||||||
chat_id=self.chat_id,
|
chat_id=self.chat_id,
|
||||||
timestamp=time.time(),
|
timestamp=time.time(),
|
||||||
limit=int(global_config.chat.max_context_size * 0.6),
|
limit=int(global_config.chat.max_context_size * 0.6),
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ from src.chat.utils.utils import get_chat_type_and_target_info, is_bot_self
|
|||||||
from src.prompt.prompt_manager import prompt_manager
|
from src.prompt.prompt_manager import prompt_manager
|
||||||
from src.services.message_service import (
|
from src.services.message_service import (
|
||||||
build_readable_messages,
|
build_readable_messages,
|
||||||
get_raw_msg_before_timestamp_with_chat,
|
get_messages_before_time_in_chat,
|
||||||
replace_user_references,
|
replace_user_references,
|
||||||
translate_pid_to_description,
|
translate_pid_to_description,
|
||||||
)
|
)
|
||||||
@@ -809,14 +809,14 @@ class DefaultReplyer:
|
|||||||
# 将[picid:xxx]替换为具体的图片描述
|
# 将[picid:xxx]替换为具体的图片描述
|
||||||
target = self._replace_picids_with_descriptions(target)
|
target = self._replace_picids_with_descriptions(target)
|
||||||
|
|
||||||
message_list_before_now_long = get_raw_msg_before_timestamp_with_chat(
|
message_list_before_now_long = get_messages_before_time_in_chat(
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
timestamp=reply_time_point,
|
timestamp=reply_time_point,
|
||||||
limit=global_config.chat.max_context_size * 1,
|
limit=global_config.chat.max_context_size * 1,
|
||||||
filter_intercept_message_level=1,
|
filter_intercept_message_level=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
message_list_before_short = get_raw_msg_before_timestamp_with_chat(
|
message_list_before_short = get_messages_before_time_in_chat(
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
timestamp=reply_time_point,
|
timestamp=reply_time_point,
|
||||||
limit=int(global_config.chat.max_context_size * 0.33),
|
limit=int(global_config.chat.max_context_size * 0.33),
|
||||||
@@ -1022,7 +1022,7 @@ class DefaultReplyer:
|
|||||||
# 将[picid:xxx]替换为具体的图片描述
|
# 将[picid:xxx]替换为具体的图片描述
|
||||||
target = self._replace_picids_with_descriptions(target)
|
target = self._replace_picids_with_descriptions(target)
|
||||||
|
|
||||||
message_list_before_now_half = get_raw_msg_before_timestamp_with_chat(
|
message_list_before_now_half = get_messages_before_time_in_chat(
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
timestamp=time.time(),
|
timestamp=time.time(),
|
||||||
limit=min(int(global_config.chat.max_context_size * 0.33), 15),
|
limit=min(int(global_config.chat.max_context_size * 0.33), 15),
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ from src.prompt.prompt_manager import prompt_manager
|
|||||||
from src.chat.utils.common_utils import TempMethodsExpression
|
from src.chat.utils.common_utils import TempMethodsExpression
|
||||||
from src.services.message_service import (
|
from src.services.message_service import (
|
||||||
build_readable_messages,
|
build_readable_messages,
|
||||||
get_raw_msg_before_timestamp_with_chat,
|
get_messages_before_time_in_chat,
|
||||||
replace_user_references,
|
replace_user_references,
|
||||||
translate_pid_to_description,
|
translate_pid_to_description,
|
||||||
)
|
)
|
||||||
@@ -650,7 +650,7 @@ class PrivateReplyer:
|
|||||||
# 将[picid:xxx]替换为具体的图片描述
|
# 将[picid:xxx]替换为具体的图片描述
|
||||||
target = self._replace_picids_with_descriptions(target)
|
target = self._replace_picids_with_descriptions(target)
|
||||||
|
|
||||||
message_list_before_now_long = get_raw_msg_before_timestamp_with_chat(
|
message_list_before_now_long = get_messages_before_time_in_chat(
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
timestamp=time.time(),
|
timestamp=time.time(),
|
||||||
limit=global_config.chat.max_context_size,
|
limit=global_config.chat.max_context_size,
|
||||||
@@ -666,7 +666,7 @@ class PrivateReplyer:
|
|||||||
long_time_notice=True,
|
long_time_notice=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
message_list_before_short = get_raw_msg_before_timestamp_with_chat(
|
message_list_before_short = get_messages_before_time_in_chat(
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
timestamp=time.time(),
|
timestamp=time.time(),
|
||||||
limit=int(global_config.chat.max_context_size * 0.33),
|
limit=int(global_config.chat.max_context_size * 0.33),
|
||||||
@@ -857,7 +857,7 @@ class PrivateReplyer:
|
|||||||
# 将[picid:xxx]替换为具体的图片描述
|
# 将[picid:xxx]替换为具体的图片描述
|
||||||
target = self._replace_picids_with_descriptions(target)
|
target = self._replace_picids_with_descriptions(target)
|
||||||
|
|
||||||
message_list_before_now_half = get_raw_msg_before_timestamp_with_chat(
|
message_list_before_now_half = get_messages_before_time_in_chat(
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
timestamp=time.time(),
|
timestamp=time.time(),
|
||||||
limit=min(int(global_config.chat.max_context_size * 0.33), 15),
|
limit=min(int(global_config.chat.max_context_size * 0.33), 15),
|
||||||
|
|||||||
@@ -1,12 +1,38 @@
|
|||||||
|
import random
|
||||||
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
import time
|
||||||
|
|
||||||
from src.chat.message_receive.chat_manager import BotChatSession, chat_manager
|
from src.chat.message_receive.chat_manager import BotChatSession, chat_manager
|
||||||
|
from src.common.data_models.image_data_model import MaiEmoji
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
from src.common.utils.utils_image import ImageUtils
|
||||||
|
|
||||||
logger = get_logger("plugin_runtime.integration")
|
logger = get_logger("plugin_runtime.integration")
|
||||||
|
|
||||||
|
|
||||||
class RuntimeDataCapabilityMixin:
|
class RuntimeDataCapabilityMixin:
|
||||||
|
@staticmethod
|
||||||
|
def _serialize_emoji_payload(emoji: MaiEmoji) -> Optional[Dict[str, str]]:
|
||||||
|
emoji_base64 = ImageUtils.image_path_to_base64(str(emoji.full_path))
|
||||||
|
if not emoji_base64:
|
||||||
|
return None
|
||||||
|
|
||||||
|
matched_emotion = emoji.emotion[0] if emoji.emotion else ""
|
||||||
|
return {
|
||||||
|
"base64": emoji_base64,
|
||||||
|
"description": emoji.description,
|
||||||
|
"emotion": matched_emotion,
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _build_emoji_temp_path() -> Path:
|
||||||
|
from src.chat.emoji_system.emoji_manager import EMOJI_DIR
|
||||||
|
|
||||||
|
EMOJI_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
return EMOJI_DIR / f"emoji_cap_{int(time.time() * 1000000)}.png"
|
||||||
|
|
||||||
async def _cap_database_query(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
async def _cap_database_query(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||||
from src.services import database_service as database_api
|
from src.services import database_service as database_api
|
||||||
|
|
||||||
@@ -338,7 +364,7 @@ class RuntimeDataCapabilityMixin:
|
|||||||
limit=args.get("limit", 0),
|
limit=args.get("limit", 0),
|
||||||
)
|
)
|
||||||
|
|
||||||
readable = message_api.build_readable_messages_to_str(
|
readable = message_api.build_readable_messages(
|
||||||
messages=messages,
|
messages=messages,
|
||||||
replace_bot_name=args.get("replace_bot_name", True),
|
replace_bot_name=args.get("replace_bot_name", True),
|
||||||
timestamp_mode=args.get("timestamp_mode", "relative"),
|
timestamp_mode=args.get("timestamp_mode", "relative"),
|
||||||
@@ -397,101 +423,173 @@ class RuntimeDataCapabilityMixin:
|
|||||||
return {"success": False, "error": str(e)}
|
return {"success": False, "error": str(e)}
|
||||||
|
|
||||||
async def _cap_emoji_get_by_description(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
async def _cap_emoji_get_by_description(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||||
from src.services import emoji_service as emoji_api
|
from src.chat.emoji_system.emoji_manager import emoji_manager
|
||||||
|
|
||||||
description: str = args.get("description", "")
|
description: str = args.get("description", "")
|
||||||
if not description:
|
if not description:
|
||||||
return {"success": False, "error": "缺少必要参数 description"}
|
return {"success": False, "error": "缺少必要参数 description"}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = await emoji_api.get_by_description(description=description)
|
emoji = await emoji_manager.get_emoji_for_emotion(description)
|
||||||
if result is None:
|
if emoji is None:
|
||||||
|
return {"success": True, "emoji": None}
|
||||||
|
serialized = self._serialize_emoji_payload(emoji)
|
||||||
|
if serialized is None:
|
||||||
return {"success": True, "emoji": None}
|
return {"success": True, "emoji": None}
|
||||||
emoji_base64, emoji_desc, matched_emotion = result
|
|
||||||
return {
|
return {
|
||||||
"success": True,
|
"success": True,
|
||||||
"emoji": {
|
"emoji": serialized,
|
||||||
"base64": emoji_base64,
|
|
||||||
"description": emoji_desc,
|
|
||||||
"emotion": matched_emotion,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[cap.emoji.get_by_description] 执行失败: {e}", exc_info=True)
|
logger.error(f"[cap.emoji.get_by_description] 执行失败: {e}", exc_info=True)
|
||||||
return {"success": False, "error": str(e)}
|
return {"success": False, "error": str(e)}
|
||||||
|
|
||||||
async def _cap_emoji_get_random(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
async def _cap_emoji_get_random(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||||
from src.services import emoji_service as emoji_api
|
from src.chat.emoji_system.emoji_manager import emoji_manager
|
||||||
|
|
||||||
count: int = args.get("count", 1)
|
count: int = args.get("count", 1)
|
||||||
try:
|
try:
|
||||||
results = await emoji_api.get_random(count=count)
|
if count < 0:
|
||||||
emojis = [{"base64": b64, "description": desc, "emotion": emo} for b64, desc, emo in results]
|
return {"success": False, "error": "count 不能为负数"}
|
||||||
|
|
||||||
|
emojis_source = list(emoji_manager.emojis)
|
||||||
|
if count == 0 or not emojis_source:
|
||||||
|
return {"success": True, "emojis": []}
|
||||||
|
|
||||||
|
selected = random.sample(emojis_source, min(count, len(emojis_source)))
|
||||||
|
emojis: List[Dict[str, str]] = []
|
||||||
|
for emoji in selected:
|
||||||
|
emoji_manager.update_emoji_usage(emoji)
|
||||||
|
serialized = self._serialize_emoji_payload(emoji)
|
||||||
|
if serialized is not None:
|
||||||
|
if not serialized["emotion"]:
|
||||||
|
serialized["emotion"] = "随机表情"
|
||||||
|
emojis.append(serialized)
|
||||||
return {"success": True, "emojis": emojis}
|
return {"success": True, "emojis": emojis}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[cap.emoji.get_random] 执行失败: {e}", exc_info=True)
|
logger.error(f"[cap.emoji.get_random] 执行失败: {e}", exc_info=True)
|
||||||
return {"success": False, "error": str(e)}
|
return {"success": False, "error": str(e)}
|
||||||
|
|
||||||
async def _cap_emoji_get_count(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
async def _cap_emoji_get_count(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||||
from src.services import emoji_service as emoji_api
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return {"success": True, "count": emoji_api.get_count()}
|
from src.chat.emoji_system.emoji_manager import emoji_manager
|
||||||
|
|
||||||
|
return {"success": True, "count": len(emoji_manager.emojis)}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[cap.emoji.get_count] 执行失败: {e}", exc_info=True)
|
logger.error(f"[cap.emoji.get_count] 执行失败: {e}", exc_info=True)
|
||||||
return {"success": False, "error": str(e)}
|
return {"success": False, "error": str(e)}
|
||||||
|
|
||||||
async def _cap_emoji_get_emotions(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
async def _cap_emoji_get_emotions(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||||
from src.services import emoji_service as emoji_api
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return {"success": True, "emotions": emoji_api.get_emotions()}
|
from src.chat.emoji_system.emoji_manager import emoji_manager
|
||||||
|
|
||||||
|
emotions = sorted({emotion for emoji in emoji_manager.emojis for emotion in emoji.emotion})
|
||||||
|
return {"success": True, "emotions": emotions}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[cap.emoji.get_emotions] 执行失败: {e}", exc_info=True)
|
logger.error(f"[cap.emoji.get_emotions] 执行失败: {e}", exc_info=True)
|
||||||
return {"success": False, "error": str(e)}
|
return {"success": False, "error": str(e)}
|
||||||
|
|
||||||
async def _cap_emoji_get_all(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
async def _cap_emoji_get_all(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||||
from src.services import emoji_service as emoji_api
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
results = await emoji_api.get_all()
|
from src.chat.emoji_system.emoji_manager import emoji_manager
|
||||||
emojis = [{"base64": b64, "description": desc, "emotion": emo} for b64, desc, emo in results] if results else []
|
|
||||||
|
emojis = []
|
||||||
|
for emoji in emoji_manager.emojis:
|
||||||
|
serialized = self._serialize_emoji_payload(emoji)
|
||||||
|
if serialized is not None:
|
||||||
|
if not serialized["emotion"]:
|
||||||
|
serialized["emotion"] = "随机表情"
|
||||||
|
emojis.append(serialized)
|
||||||
return {"success": True, "emojis": emojis}
|
return {"success": True, "emojis": emojis}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[cap.emoji.get_all] 执行失败: {e}", exc_info=True)
|
logger.error(f"[cap.emoji.get_all] 执行失败: {e}", exc_info=True)
|
||||||
return {"success": False, "error": str(e)}
|
return {"success": False, "error": str(e)}
|
||||||
|
|
||||||
async def _cap_emoji_get_info(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
async def _cap_emoji_get_info(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||||
from src.services import emoji_service as emoji_api
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return {"success": True, "info": emoji_api.get_info()}
|
from src.chat.emoji_system.emoji_manager import emoji_manager
|
||||||
|
from src.config.config import global_config
|
||||||
|
|
||||||
|
current_count = len(emoji_manager.emojis)
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"info": {
|
||||||
|
"current_count": current_count,
|
||||||
|
"max_count": global_config.emoji.max_reg_num,
|
||||||
|
"available_emojis": current_count,
|
||||||
|
},
|
||||||
|
}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[cap.emoji.get_info] 执行失败: {e}", exc_info=True)
|
logger.error(f"[cap.emoji.get_info] 执行失败: {e}", exc_info=True)
|
||||||
return {"success": False, "error": str(e)}
|
return {"success": False, "error": str(e)}
|
||||||
|
|
||||||
async def _cap_emoji_register(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
async def _cap_emoji_register(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||||
from src.services import emoji_service as emoji_api
|
from src.chat.emoji_system.emoji_manager import emoji_manager
|
||||||
|
|
||||||
emoji_base64: str = args.get("emoji_base64", "")
|
emoji_base64: str = args.get("emoji_base64", "")
|
||||||
if not emoji_base64:
|
if not emoji_base64:
|
||||||
return {"success": False, "error": "缺少必要参数 emoji_base64"}
|
return {"success": False, "error": "缺少必要参数 emoji_base64"}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return await emoji_api.register_emoji(emoji_base64)
|
count_before = len(emoji_manager.emojis)
|
||||||
|
temp_file_path = self._build_emoji_temp_path()
|
||||||
|
if not ImageUtils.base64_to_image(emoji_base64, str(temp_file_path)):
|
||||||
|
return {"success": False, "message": "无法保存图片文件", "description": None, "emotions": None, "replaced": None, "hash": None}
|
||||||
|
|
||||||
|
register_success = await emoji_manager.register_emoji_by_filename(temp_file_path)
|
||||||
|
if not register_success:
|
||||||
|
if temp_file_path.exists():
|
||||||
|
temp_file_path.unlink(missing_ok=True)
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"message": "表情包注册失败,可能因为重复、格式不支持或审核未通过",
|
||||||
|
"description": None,
|
||||||
|
"emotions": None,
|
||||||
|
"replaced": None,
|
||||||
|
"hash": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
count_after = len(emoji_manager.emojis)
|
||||||
|
replaced = count_after <= count_before
|
||||||
|
new_emoji = next(
|
||||||
|
(
|
||||||
|
item
|
||||||
|
for item in reversed(emoji_manager.emojis)
|
||||||
|
if temp_file_path.name == item.file_name or temp_file_path.name in str(item.full_path)
|
||||||
|
),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"message": f"表情包注册成功 {'(替换旧表情包)' if replaced else '(新增表情包)'}",
|
||||||
|
"description": None if new_emoji is None else new_emoji.description,
|
||||||
|
"emotions": None if new_emoji is None else new_emoji.emotion,
|
||||||
|
"replaced": replaced,
|
||||||
|
"hash": None if new_emoji is None else new_emoji.file_hash,
|
||||||
|
}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[cap.emoji.register] 执行失败: {e}", exc_info=True)
|
logger.error(f"[cap.emoji.register] 执行失败: {e}", exc_info=True)
|
||||||
return {"success": False, "error": str(e)}
|
return {"success": False, "error": str(e)}
|
||||||
|
|
||||||
async def _cap_emoji_delete(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
async def _cap_emoji_delete(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||||
from src.services import emoji_service as emoji_api
|
from src.chat.emoji_system.emoji_manager import emoji_manager
|
||||||
|
|
||||||
emoji_hash: str = args.get("emoji_hash", "")
|
emoji_hash: str = args.get("emoji_hash", "")
|
||||||
if not emoji_hash:
|
if not emoji_hash:
|
||||||
return {"success": False, "error": "缺少必要参数 emoji_hash"}
|
return {"success": False, "error": "缺少必要参数 emoji_hash"}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return await emoji_api.delete_emoji(emoji_hash)
|
emoji = emoji_manager.get_emoji_by_hash(emoji_hash)
|
||||||
|
if emoji is None:
|
||||||
|
return {"success": False, "message": f"未找到表情包: {emoji_hash}", "hash": emoji_hash}
|
||||||
|
|
||||||
|
success = emoji_manager.delete_emoji(emoji, not bool(emoji.description and emoji.description.strip()))
|
||||||
|
if not success:
|
||||||
|
return {"success": False, "message": f"删除表情包失败: {emoji_hash}", "hash": emoji_hash}
|
||||||
|
|
||||||
|
emoji_manager.emojis = [item for item in emoji_manager.emojis if item.file_hash != emoji_hash]
|
||||||
|
emoji_manager._emoji_num = len(emoji_manager.emojis)
|
||||||
|
return {"success": True, "message": f"成功删除表情包: {emoji_hash}", "hash": emoji_hash}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[cap.emoji.delete] 执行失败: {e}", exc_info=True)
|
logger.error(f"[cap.emoji.delete] 执行失败: {e}", exc_info=True)
|
||||||
return {"success": False, "error": str(e)}
|
return {"success": False, "error": str(e)}
|
||||||
|
|||||||
@@ -1,171 +0,0 @@
|
|||||||
"""
|
|
||||||
聊天服务模块
|
|
||||||
|
|
||||||
提供聊天信息查询和管理的核心功能。
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Any, Callable, Dict, List, Optional
|
|
||||||
from enum import Enum
|
|
||||||
|
|
||||||
from src.chat.message_receive.chat_manager import BotChatSession, chat_manager as _chat_manager
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
|
|
||||||
logger = get_logger("chat_service")
|
|
||||||
|
|
||||||
|
|
||||||
class SpecialTypes(Enum):
|
|
||||||
"""特殊枚举类型"""
|
|
||||||
|
|
||||||
ALL_PLATFORMS = "all_platforms"
|
|
||||||
|
|
||||||
|
|
||||||
class ChatManager:
|
|
||||||
"""聊天管理器 - 负责聊天信息的查询和管理"""
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _validate_platform(platform: Optional[str] | SpecialTypes) -> None:
|
|
||||||
if not isinstance(platform, (str, SpecialTypes)):
|
|
||||||
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _match_platform(chat_stream: BotChatSession, platform: Optional[str] | SpecialTypes) -> bool:
|
|
||||||
return platform == SpecialTypes.ALL_PLATFORMS or chat_stream.platform == platform
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _get_streams(
|
|
||||||
platform: Optional[str] | SpecialTypes = "qq", is_group_session: Optional[bool] = None
|
|
||||||
) -> List[BotChatSession]:
|
|
||||||
ChatManager._validate_platform(platform)
|
|
||||||
|
|
||||||
try:
|
|
||||||
streams = [
|
|
||||||
stream
|
|
||||||
for stream in _chat_manager.sessions.values()
|
|
||||||
if ChatManager._match_platform(stream, platform)
|
|
||||||
and (is_group_session is None or stream.is_group_session == is_group_session)
|
|
||||||
]
|
|
||||||
return streams
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"[ChatService] 获取聊天流失败: {e}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _find_stream(
|
|
||||||
predicate: Callable[[BotChatSession], bool],
|
|
||||||
platform: Optional[str] | SpecialTypes = "qq",
|
|
||||||
) -> Optional[BotChatSession]:
|
|
||||||
for stream in ChatManager._get_streams(platform=platform):
|
|
||||||
if predicate(stream):
|
|
||||||
return stream
|
|
||||||
return None
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_all_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]:
|
|
||||||
streams = ChatManager._get_streams(platform=platform)
|
|
||||||
logger.debug(f"[ChatService] 获取到 {len(streams)} 个 {platform} 平台的聊天流")
|
|
||||||
return streams
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_group_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]:
|
|
||||||
streams = ChatManager._get_streams(platform=platform, is_group_session=True)
|
|
||||||
logger.debug(f"[ChatService] 获取到 {len(streams)} 个 {platform} 平台的群聊流")
|
|
||||||
return streams
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_private_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]:
|
|
||||||
streams = ChatManager._get_streams(platform=platform, is_group_session=False)
|
|
||||||
logger.debug(f"[ChatService] 获取到 {len(streams)} 个 {platform} 平台的私聊流")
|
|
||||||
return streams
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_group_stream_by_group_id(
|
|
||||||
group_id: str, platform: Optional[str] | SpecialTypes = "qq"
|
|
||||||
) -> Optional[BotChatSession]: # sourcery skip: remove-unnecessary-cast
|
|
||||||
if not isinstance(group_id, str):
|
|
||||||
raise TypeError("group_id 必须是字符串类型")
|
|
||||||
ChatManager._validate_platform(platform)
|
|
||||||
if not group_id:
|
|
||||||
raise ValueError("group_id 不能为空")
|
|
||||||
try:
|
|
||||||
stream = ChatManager._find_stream(
|
|
||||||
lambda item: item.is_group_session and str(item.group_id) == str(group_id),
|
|
||||||
platform=platform,
|
|
||||||
)
|
|
||||||
if stream is not None:
|
|
||||||
logger.debug(f"[ChatService] 找到群ID {group_id} 的聊天流")
|
|
||||||
return stream
|
|
||||||
logger.warning(f"[ChatService] 未找到群ID {group_id} 的聊天流")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"[ChatService] 查找群聊流失败: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_private_stream_by_user_id(
|
|
||||||
user_id: str, platform: Optional[str] | SpecialTypes = "qq"
|
|
||||||
) -> Optional[BotChatSession]: # sourcery skip: remove-unnecessary-cast
|
|
||||||
if not isinstance(user_id, str):
|
|
||||||
raise TypeError("user_id 必须是字符串类型")
|
|
||||||
ChatManager._validate_platform(platform)
|
|
||||||
if not user_id:
|
|
||||||
raise ValueError("user_id 不能为空")
|
|
||||||
try:
|
|
||||||
stream = ChatManager._find_stream(
|
|
||||||
lambda item: (not item.is_group_session) and str(item.user_id) == str(user_id),
|
|
||||||
platform=platform,
|
|
||||||
)
|
|
||||||
if stream is not None:
|
|
||||||
logger.debug(f"[ChatService] 找到用户ID {user_id} 的私聊流")
|
|
||||||
return stream
|
|
||||||
logger.warning(f"[ChatService] 未找到用户ID {user_id} 的私聊流")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"[ChatService] 查找私聊流失败: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_stream_type(chat_stream: BotChatSession) -> str:
|
|
||||||
if not isinstance(chat_stream, BotChatSession):
|
|
||||||
raise TypeError("chat_stream 必须是 BotChatSession 类型")
|
|
||||||
if not chat_stream:
|
|
||||||
raise ValueError("chat_stream 不能为 None")
|
|
||||||
|
|
||||||
return "group" if chat_stream.is_group_session else "private"
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_stream_info(chat_stream: BotChatSession) -> Dict[str, Any]:
|
|
||||||
if not chat_stream:
|
|
||||||
raise ValueError("chat_stream 不能为 None")
|
|
||||||
if not isinstance(chat_stream, BotChatSession):
|
|
||||||
raise TypeError("chat_stream 必须是 BotChatSession 类型")
|
|
||||||
|
|
||||||
try:
|
|
||||||
info: Dict[str, Any] = {
|
|
||||||
"session_id": chat_stream.session_id,
|
|
||||||
"platform": chat_stream.platform,
|
|
||||||
"type": ChatManager.get_stream_type(chat_stream),
|
|
||||||
}
|
|
||||||
|
|
||||||
if chat_stream.is_group_session:
|
|
||||||
info["group_id"] = chat_stream.group_id
|
|
||||||
if (
|
|
||||||
chat_stream.context
|
|
||||||
and chat_stream.context.message
|
|
||||||
and chat_stream.context.message.message_info.group_info
|
|
||||||
):
|
|
||||||
info["group_name"] = chat_stream.context.message.message_info.group_info.group_name or "未知群聊"
|
|
||||||
else:
|
|
||||||
info["group_name"] = "未知群聊"
|
|
||||||
else:
|
|
||||||
info["user_id"] = chat_stream.user_id
|
|
||||||
if (
|
|
||||||
chat_stream.context
|
|
||||||
and chat_stream.context.message
|
|
||||||
and chat_stream.context.message.message_info.user_info
|
|
||||||
):
|
|
||||||
info["user_name"] = chat_stream.context.message.message_info.user_info.user_nickname
|
|
||||||
else:
|
|
||||||
info["user_name"] = "未知用户"
|
|
||||||
|
|
||||||
return info
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"[ChatService] 获取聊天流信息失败: {e}")
|
|
||||||
return {}
|
|
||||||
@@ -9,6 +9,7 @@ from typing import Any, Optional
|
|||||||
from sqlalchemy import delete, func, select
|
from sqlalchemy import delete, func, select
|
||||||
from sqlmodel import SQLModel
|
from sqlmodel import SQLModel
|
||||||
|
|
||||||
|
from src.chat.message_receive.chat_manager import BotChatSession
|
||||||
from src.common.database.database import get_db_session
|
from src.common.database.database import get_db_session
|
||||||
from src.common.database.database_model import ActionRecord
|
from src.common.database.database_model import ActionRecord
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
@@ -23,12 +24,10 @@ def _to_dict(record: Any) -> dict[str, Any]:
|
|||||||
return record
|
return record
|
||||||
if hasattr(record, "model_dump"):
|
if hasattr(record, "model_dump"):
|
||||||
return record.model_dump()
|
return record.model_dump()
|
||||||
if hasattr(record, "__dict__"):
|
return dict(record.__dict__) if hasattr(record, "__dict__") else {}
|
||||||
return dict(record.__dict__)
|
|
||||||
return {}
|
|
||||||
|
|
||||||
|
|
||||||
def _get_model_field(model_class: type[SQLModel], field_name: str):
|
def _get_model_field(model_class: type[SQLModel], field_name: str) -> Any:
|
||||||
field = getattr(model_class, field_name, None)
|
field = getattr(model_class, field_name, None)
|
||||||
if field is None:
|
if field is None:
|
||||||
raise ValueError(f"{model_class.__name__} 不存在字段 {field_name}")
|
raise ValueError(f"{model_class.__name__} 不存在字段 {field_name}")
|
||||||
@@ -41,7 +40,7 @@ def _build_filters(model_class: type[SQLModel], filters: Optional[dict[str, Any]
|
|||||||
return [_get_model_field(model_class, field_name) == value for field_name, value in filters.items()]
|
return [_get_model_field(model_class, field_name) == value for field_name, value in filters.items()]
|
||||||
|
|
||||||
|
|
||||||
def _apply_order_by(statement, model_class: type[SQLModel], order_by: Optional[str | list[str]] = None):
|
def _apply_order_by(statement: Any, model_class: type[SQLModel], order_by: Optional[str | list[str]] = None) -> Any:
|
||||||
if not order_by:
|
if not order_by:
|
||||||
return statement
|
return statement
|
||||||
|
|
||||||
@@ -60,7 +59,7 @@ async def db_save(
|
|||||||
data: dict[str, Any],
|
data: dict[str, Any],
|
||||||
key_field: Optional[str] = None,
|
key_field: Optional[str] = None,
|
||||||
key_value: Optional[Any] = None,
|
key_value: Optional[Any] = None,
|
||||||
):
|
) -> Optional[dict[str, Any]]:
|
||||||
try:
|
try:
|
||||||
with get_db_session() as session:
|
with get_db_session() as session:
|
||||||
record = None
|
record = None
|
||||||
@@ -91,12 +90,11 @@ async def db_get(
|
|||||||
limit: Optional[int] = None,
|
limit: Optional[int] = None,
|
||||||
order_by: Optional[str | list[str]] = None,
|
order_by: Optional[str | list[str]] = None,
|
||||||
single_result: bool = False,
|
single_result: bool = False,
|
||||||
):
|
) -> Optional[dict[str, Any]] | list[dict[str, Any]]:
|
||||||
try:
|
try:
|
||||||
with get_db_session(auto_commit=False) as session:
|
with get_db_session(auto_commit=False) as session:
|
||||||
statement = select(model_class)
|
statement = select(model_class)
|
||||||
conditions = _build_filters(model_class, filters)
|
if conditions := _build_filters(model_class, filters):
|
||||||
if conditions:
|
|
||||||
statement = statement.where(*conditions)
|
statement = statement.where(*conditions)
|
||||||
statement = _apply_order_by(statement, model_class, order_by)
|
statement = _apply_order_by(statement, model_class, order_by)
|
||||||
if limit:
|
if limit:
|
||||||
@@ -116,8 +114,7 @@ async def db_update(model_class: type[SQLModel], data: dict[str, Any], filters:
|
|||||||
try:
|
try:
|
||||||
with get_db_session() as session:
|
with get_db_session() as session:
|
||||||
statement = select(model_class)
|
statement = select(model_class)
|
||||||
conditions = _build_filters(model_class, filters)
|
if conditions := _build_filters(model_class, filters):
|
||||||
if conditions:
|
|
||||||
statement = statement.where(*conditions)
|
statement = statement.where(*conditions)
|
||||||
records = session.exec(statement).all()
|
records = session.exec(statement).all()
|
||||||
for record in records:
|
for record in records:
|
||||||
@@ -136,8 +133,7 @@ async def db_delete(model_class: type[SQLModel], filters: Optional[dict[str, Any
|
|||||||
try:
|
try:
|
||||||
with get_db_session() as session:
|
with get_db_session() as session:
|
||||||
statement = delete(model_class)
|
statement = delete(model_class)
|
||||||
conditions = _build_filters(model_class, filters)
|
if conditions := _build_filters(model_class, filters):
|
||||||
if conditions:
|
|
||||||
statement = statement.where(*conditions)
|
statement = statement.where(*conditions)
|
||||||
result = session.exec(statement)
|
result = session.exec(statement)
|
||||||
return result.rowcount or 0
|
return result.rowcount or 0
|
||||||
@@ -151,8 +147,7 @@ async def db_count(model_class: type[SQLModel], filters: Optional[dict[str, Any]
|
|||||||
try:
|
try:
|
||||||
with get_db_session(auto_commit=False) as session:
|
with get_db_session(auto_commit=False) as session:
|
||||||
statement = select(func.count()).select_from(model_class)
|
statement = select(func.count()).select_from(model_class)
|
||||||
conditions = _build_filters(model_class, filters)
|
if conditions := _build_filters(model_class, filters):
|
||||||
if conditions:
|
|
||||||
statement = statement.where(*conditions)
|
statement = statement.where(*conditions)
|
||||||
result = session.exec(statement).one()
|
result = session.exec(statement).one()
|
||||||
return int(result or 0)
|
return int(result or 0)
|
||||||
@@ -163,18 +158,15 @@ async def db_count(model_class: type[SQLModel], filters: Optional[dict[str, Any]
|
|||||||
|
|
||||||
|
|
||||||
async def store_action_info(
|
async def store_action_info(
|
||||||
chat_stream=None,
|
chat_stream: BotChatSession,
|
||||||
builtin_prompt: Optional[str] = None,
|
builtin_prompt: Optional[str] = None,
|
||||||
display_prompt: str = "",
|
display_prompt: str = "",
|
||||||
thinking_id: str = "",
|
thinking_id: str = "",
|
||||||
action_data: Optional[dict] = None,
|
action_data: Optional[dict[str, Any]] = None,
|
||||||
action_name: str = "",
|
action_name: str = "",
|
||||||
action_reasoning: str = "",
|
action_reasoning: str = "",
|
||||||
):
|
) -> Optional[dict[str, Any]]:
|
||||||
try:
|
try:
|
||||||
if chat_stream is None:
|
|
||||||
raise ValueError("store_action_info 需要 chat_stream")
|
|
||||||
|
|
||||||
record_data = {
|
record_data = {
|
||||||
"action_id": thinking_id or str(int(time.time() * 1000000)),
|
"action_id": thinking_id or str(int(time.time() * 1000000)),
|
||||||
"timestamp": datetime.now(),
|
"timestamp": datetime.now(),
|
||||||
|
|||||||
@@ -1,406 +0,0 @@
|
|||||||
"""
|
|
||||||
表情服务模块
|
|
||||||
|
|
||||||
提供表情包相关的核心功能。
|
|
||||||
"""
|
|
||||||
|
|
||||||
import base64
|
|
||||||
import os
|
|
||||||
import random
|
|
||||||
import uuid
|
|
||||||
|
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
|
||||||
|
|
||||||
from src.chat.emoji_system.emoji_manager import emoji_manager, EMOJI_DIR
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
from src.common.utils.utils_image import ImageUtils
|
|
||||||
from src.config.config import global_config
|
|
||||||
|
|
||||||
logger = get_logger("emoji_service")
|
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# 表情包获取函数
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
|
|
||||||
async def get_by_description(description: str) -> Optional[Tuple[str, str, str]]:
|
|
||||||
"""根据描述选择表情包"""
|
|
||||||
if not description:
|
|
||||||
raise ValueError("描述不能为空")
|
|
||||||
if not isinstance(description, str):
|
|
||||||
raise TypeError("描述必须是字符串类型")
|
|
||||||
try:
|
|
||||||
logger.debug(f"[EmojiService] 根据描述获取表情包: {description}")
|
|
||||||
|
|
||||||
emoji_obj = await emoji_manager.get_emoji_for_emotion(description)
|
|
||||||
|
|
||||||
if not emoji_obj:
|
|
||||||
logger.warning(f"[EmojiService] 未找到匹配描述 '{description}' 的表情包")
|
|
||||||
return None
|
|
||||||
|
|
||||||
emoji_path = str(emoji_obj.full_path)
|
|
||||||
emoji_description = emoji_obj.description
|
|
||||||
matched_emotion = random.choice(emoji_obj.emotion) if emoji_obj.emotion else ""
|
|
||||||
emoji_base64 = ImageUtils.image_path_to_base64(emoji_path)
|
|
||||||
|
|
||||||
if not emoji_base64:
|
|
||||||
logger.error(f"[EmojiService] 无法将表情包文件转换为base64: {emoji_path}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
logger.debug(f"[EmojiService] 成功获取表情包: {emoji_description}, 匹配情感: {matched_emotion}")
|
|
||||||
return emoji_base64, emoji_description, matched_emotion
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"[EmojiService] 获取表情包失败: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
async def get_random(count: Optional[int] = 1) -> List[Tuple[str, str, str]]:
|
|
||||||
"""随机获取指定数量的表情包"""
|
|
||||||
if not isinstance(count, int):
|
|
||||||
raise TypeError("count 必须是整数类型")
|
|
||||||
if count < 0:
|
|
||||||
raise ValueError("count 不能为负数")
|
|
||||||
if count == 0:
|
|
||||||
logger.warning("[EmojiService] count 为0,返回空列表")
|
|
||||||
return []
|
|
||||||
|
|
||||||
try:
|
|
||||||
all_emojis = emoji_manager.emojis
|
|
||||||
|
|
||||||
if not all_emojis:
|
|
||||||
logger.warning("[EmojiService] 没有可用的表情包")
|
|
||||||
return []
|
|
||||||
|
|
||||||
valid_emojis = [emoji for emoji in all_emojis if not emoji.is_deleted]
|
|
||||||
if not valid_emojis:
|
|
||||||
logger.warning("[EmojiService] 没有有效的表情包")
|
|
||||||
return []
|
|
||||||
|
|
||||||
if len(valid_emojis) < count:
|
|
||||||
logger.debug(
|
|
||||||
f"[EmojiService] 有效表情包数量 ({len(valid_emojis)}) 少于请求的数量 ({count}),将返回所有有效表情包"
|
|
||||||
)
|
|
||||||
count = len(valid_emojis)
|
|
||||||
|
|
||||||
selected_emojis = random.sample(valid_emojis, count)
|
|
||||||
|
|
||||||
results = []
|
|
||||||
for selected_emoji in selected_emojis:
|
|
||||||
emoji_base64 = ImageUtils.image_path_to_base64(str(selected_emoji.full_path))
|
|
||||||
|
|
||||||
if not emoji_base64:
|
|
||||||
logger.error(f"[EmojiService] 无法转换表情包为base64: {selected_emoji.full_path}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
matched_emotion = random.choice(selected_emoji.emotion) if selected_emoji.emotion else "随机表情"
|
|
||||||
|
|
||||||
emoji_manager.update_emoji_usage(selected_emoji)
|
|
||||||
results.append((emoji_base64, selected_emoji.description, matched_emotion))
|
|
||||||
|
|
||||||
if not results and count > 0:
|
|
||||||
logger.warning("[EmojiService] 随机获取表情包失败,没有一个可以成功处理")
|
|
||||||
return []
|
|
||||||
|
|
||||||
logger.debug(f"[EmojiService] 成功获取 {len(results)} 个随机表情包")
|
|
||||||
return results
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"[EmojiService] 获取随机表情包失败: {e}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
|
|
||||||
async def get_by_emotion(emotion: str) -> Optional[Tuple[str, str, str]]:
|
|
||||||
"""根据情感标签获取表情包"""
|
|
||||||
if not emotion:
|
|
||||||
raise ValueError("情感标签不能为空")
|
|
||||||
if not isinstance(emotion, str):
|
|
||||||
raise TypeError("情感标签必须是字符串类型")
|
|
||||||
try:
|
|
||||||
logger.info(f"[EmojiService] 根据情感获取表情包: {emotion}")
|
|
||||||
|
|
||||||
all_emojis = emoji_manager.emojis
|
|
||||||
|
|
||||||
matching_emojis = []
|
|
||||||
matching_emojis.extend(
|
|
||||||
emoji_obj
|
|
||||||
for emoji_obj in all_emojis
|
|
||||||
if not emoji_obj.is_deleted and emotion.lower() in [e.lower() for e in emoji_obj.emotion]
|
|
||||||
)
|
|
||||||
if not matching_emojis:
|
|
||||||
logger.warning(f"[EmojiService] 未找到匹配情感 '{emotion}' 的表情包")
|
|
||||||
return None
|
|
||||||
|
|
||||||
selected_emoji = random.choice(matching_emojis)
|
|
||||||
emoji_base64 = ImageUtils.image_path_to_base64(selected_emoji.full_path)
|
|
||||||
|
|
||||||
if not emoji_base64:
|
|
||||||
logger.error(f"[EmojiService] 无法转换表情包为base64: {selected_emoji.full_path}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
emoji_manager.update_emoji_usage(selected_emoji)
|
|
||||||
|
|
||||||
logger.info(f"[EmojiService] 成功获取情感表情包: {selected_emoji.description}")
|
|
||||||
return emoji_base64, selected_emoji.description, emotion
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"[EmojiService] 根据情感获取表情包失败: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# 表情包信息查询函数
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
|
|
||||||
def get_count() -> int:
|
|
||||||
try:
|
|
||||||
return len(emoji_manager.emojis)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"[EmojiService] 获取表情包数量失败: {e}")
|
|
||||||
return 0
|
|
||||||
|
|
||||||
|
|
||||||
def get_info():
|
|
||||||
try:
|
|
||||||
return {
|
|
||||||
"current_count": len(emoji_manager.emojis),
|
|
||||||
"max_count": global_config.emoji.max_reg_num,
|
|
||||||
"available_emojis": len([e for e in emoji_manager.emojis if not e.is_deleted]),
|
|
||||||
}
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"[EmojiService] 获取表情包信息失败: {e}")
|
|
||||||
return {"current_count": 0, "max_count": 0, "available_emojis": 0}
|
|
||||||
|
|
||||||
|
|
||||||
def get_emotions() -> List[str]:
|
|
||||||
try:
|
|
||||||
emotions = set()
|
|
||||||
|
|
||||||
for emoji_obj in emoji_manager.emojis:
|
|
||||||
if not emoji_obj.is_deleted and emoji_obj.emotion:
|
|
||||||
emotions.update(emoji_obj.emotion)
|
|
||||||
|
|
||||||
return sorted(list(emotions))
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"[EmojiService] 获取情感标签失败: {e}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
|
|
||||||
async def get_all() -> List[Tuple[str, str, str]]:
|
|
||||||
try:
|
|
||||||
all_emojis = emoji_manager.emojis
|
|
||||||
|
|
||||||
if not all_emojis:
|
|
||||||
logger.warning("[EmojiService] 没有可用的表情包")
|
|
||||||
return []
|
|
||||||
|
|
||||||
results = []
|
|
||||||
for emoji_obj in all_emojis:
|
|
||||||
if emoji_obj.is_deleted:
|
|
||||||
continue
|
|
||||||
|
|
||||||
emoji_base64 = ImageUtils.image_path_to_base64(str(emoji_obj.full_path))
|
|
||||||
|
|
||||||
if not emoji_base64:
|
|
||||||
logger.error(f"[EmojiService] 无法转换表情包为base64: {emoji_obj.full_path}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
matched_emotion = random.choice(emoji_obj.emotion) if emoji_obj.emotion else "随机表情"
|
|
||||||
results.append((emoji_base64, emoji_obj.description, matched_emotion))
|
|
||||||
|
|
||||||
logger.debug(f"[EmojiService] 成功获取 {len(results)} 个表情包")
|
|
||||||
return results
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"[EmojiService] 获取所有表情包失败: {e}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
|
|
||||||
def get_descriptions() -> List[str]:
|
|
||||||
try:
|
|
||||||
descriptions = []
|
|
||||||
|
|
||||||
descriptions.extend(
|
|
||||||
emoji_obj.description
|
|
||||||
for emoji_obj in emoji_manager.emojis
|
|
||||||
if not emoji_obj.is_deleted and emoji_obj.description
|
|
||||||
)
|
|
||||||
return descriptions
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"[EmojiService] 获取表情包描述失败: {e}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# 表情包注册函数
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
|
|
||||||
async def register_emoji(image_base64: str, filename: Optional[str] = None) -> Dict[str, Any]:
|
|
||||||
"""注册新的表情包"""
|
|
||||||
if not image_base64:
|
|
||||||
raise ValueError("图片base64编码不能为空")
|
|
||||||
if not isinstance(image_base64, str):
|
|
||||||
raise TypeError("image_base64必须是字符串类型")
|
|
||||||
if filename is not None and not isinstance(filename, str):
|
|
||||||
raise TypeError("filename必须是字符串类型或None")
|
|
||||||
|
|
||||||
try:
|
|
||||||
logger.info(f"[EmojiService] 开始注册表情包,文件名: {filename or '自动生成'}")
|
|
||||||
|
|
||||||
count_before = len(emoji_manager.emojis)
|
|
||||||
max_count = global_config.emoji.max_reg_num
|
|
||||||
|
|
||||||
can_register = count_before < max_count or (count_before >= max_count and global_config.emoji.do_replace)
|
|
||||||
|
|
||||||
if not can_register:
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"message": f"表情包数量已达上限({count_before}/{max_count})且未启用替换功能",
|
|
||||||
"description": None,
|
|
||||||
"emotions": None,
|
|
||||||
"replaced": None,
|
|
||||||
"hash": None,
|
|
||||||
}
|
|
||||||
|
|
||||||
os.makedirs(EMOJI_DIR, exist_ok=True)
|
|
||||||
|
|
||||||
if not filename:
|
|
||||||
import time as _time
|
|
||||||
|
|
||||||
timestamp = int(_time.time())
|
|
||||||
microseconds = int(_time.time() * 1000000) % 1000000
|
|
||||||
|
|
||||||
random_bytes = random.getrandbits(72).to_bytes(9, "big")
|
|
||||||
short_id = base64.b64encode(random_bytes).decode("ascii")[:12].rstrip("=")
|
|
||||||
short_id = short_id.replace("/", "_").replace("+", "-")
|
|
||||||
filename = f"emoji_{timestamp}_{microseconds}_{short_id}"
|
|
||||||
|
|
||||||
if not filename.lower().endswith((".jpg", ".jpeg", ".png", ".gif")):
|
|
||||||
filename = f"{filename}.png"
|
|
||||||
|
|
||||||
temp_file_path = os.path.join(EMOJI_DIR, filename)
|
|
||||||
attempts = 0
|
|
||||||
max_attempts = 10
|
|
||||||
while os.path.exists(temp_file_path) and attempts < max_attempts:
|
|
||||||
random_bytes = random.getrandbits(48).to_bytes(6, "big")
|
|
||||||
short_id = base64.b64encode(random_bytes).decode("ascii")[:8].rstrip("=")
|
|
||||||
short_id = short_id.replace("/", "_").replace("+", "-")
|
|
||||||
|
|
||||||
name_part, ext = os.path.splitext(filename)
|
|
||||||
base_name = name_part.rsplit("_", 1)[0]
|
|
||||||
filename = f"{base_name}_{short_id}{ext}"
|
|
||||||
temp_file_path = os.path.join(EMOJI_DIR, filename)
|
|
||||||
attempts += 1
|
|
||||||
|
|
||||||
if os.path.exists(temp_file_path):
|
|
||||||
uuid_short = str(uuid.uuid4())[:8]
|
|
||||||
name_part, ext = os.path.splitext(filename)
|
|
||||||
base_name = name_part.rsplit("_", 1)[0]
|
|
||||||
filename = f"{base_name}_{uuid_short}{ext}"
|
|
||||||
temp_file_path = os.path.join(EMOJI_DIR, filename)
|
|
||||||
|
|
||||||
counter = 1
|
|
||||||
original_filename = filename
|
|
||||||
while os.path.exists(temp_file_path):
|
|
||||||
name_part, ext = os.path.splitext(original_filename)
|
|
||||||
filename = f"{name_part}_{counter}{ext}"
|
|
||||||
temp_file_path = os.path.join(EMOJI_DIR, filename)
|
|
||||||
counter += 1
|
|
||||||
|
|
||||||
if counter > 100:
|
|
||||||
logger.error(f"[EmojiService] 无法生成唯一文件名,尝试次数过多: {original_filename}")
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"message": "无法生成唯一文件名,请稍后重试",
|
|
||||||
"description": None,
|
|
||||||
"emotions": None,
|
|
||||||
"replaced": None,
|
|
||||||
"hash": None,
|
|
||||||
}
|
|
||||||
|
|
||||||
try:
|
|
||||||
if not ImageUtils.base64_to_image(image_base64, temp_file_path):
|
|
||||||
logger.error(f"[EmojiService] 无法保存base64图片到文件: {temp_file_path}")
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"message": "无法保存图片文件",
|
|
||||||
"description": None,
|
|
||||||
"emotions": None,
|
|
||||||
"replaced": None,
|
|
||||||
"hash": None,
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.debug(f"[EmojiService] 图片已保存到临时文件: {temp_file_path}")
|
|
||||||
|
|
||||||
except Exception as save_error:
|
|
||||||
logger.error(f"[EmojiService] 保存图片文件失败: {save_error}")
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"message": f"保存图片文件失败: {str(save_error)}",
|
|
||||||
"description": None,
|
|
||||||
"emotions": None,
|
|
||||||
"replaced": None,
|
|
||||||
"hash": None,
|
|
||||||
}
|
|
||||||
|
|
||||||
register_success = await emoji_manager.register_emoji_by_filename(filename)
|
|
||||||
|
|
||||||
if not register_success and os.path.exists(temp_file_path):
|
|
||||||
try:
|
|
||||||
os.remove(temp_file_path)
|
|
||||||
logger.debug(f"[EmojiService] 已清理临时文件: {temp_file_path}")
|
|
||||||
except Exception as cleanup_error:
|
|
||||||
logger.warning(f"[EmojiService] 清理临时文件失败: {cleanup_error}")
|
|
||||||
|
|
||||||
if register_success:
|
|
||||||
count_after = len(emoji_manager.emojis)
|
|
||||||
replaced = count_after <= count_before
|
|
||||||
|
|
||||||
new_emoji_info = None
|
|
||||||
if count_after > count_before or replaced:
|
|
||||||
try:
|
|
||||||
for emoji_obj in reversed(emoji_manager.emojis):
|
|
||||||
if not emoji_obj.is_deleted and (
|
|
||||||
emoji_obj.file_name == filename
|
|
||||||
or (hasattr(emoji_obj, "full_path") and filename in str(emoji_obj.full_path))
|
|
||||||
):
|
|
||||||
new_emoji_info = emoji_obj
|
|
||||||
break
|
|
||||||
except Exception as find_error:
|
|
||||||
logger.warning(f"[EmojiService] 查找新注册表情包信息失败: {find_error}")
|
|
||||||
|
|
||||||
description = new_emoji_info.description if new_emoji_info else None
|
|
||||||
emotions = new_emoji_info.emotion if new_emoji_info else None
|
|
||||||
emoji_hash = new_emoji_info.emoji_hash if new_emoji_info else None
|
|
||||||
|
|
||||||
return {
|
|
||||||
"success": True,
|
|
||||||
"message": f"表情包注册成功 {'(替换旧表情包)' if replaced else '(新增表情包)'}",
|
|
||||||
"description": description,
|
|
||||||
"emotions": emotions,
|
|
||||||
"replaced": replaced,
|
|
||||||
"hash": emoji_hash,
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"message": "表情包注册失败,可能因为重复、格式不支持或审核未通过",
|
|
||||||
"description": None,
|
|
||||||
"emotions": None,
|
|
||||||
"replaced": None,
|
|
||||||
"hash": None,
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"[EmojiService] 注册表情包时发生异常: {e}")
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"message": f"注册过程中发生错误: {str(e)}",
|
|
||||||
"description": None,
|
|
||||||
"emotions": None,
|
|
||||||
"replaced": None,
|
|
||||||
"hash": None,
|
|
||||||
}
|
|
||||||
@@ -35,7 +35,7 @@ logger = get_logger("generator_service")
|
|||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
def get_replyer(
|
def _get_replyer(
|
||||||
chat_stream: Optional[BotChatSession] = None,
|
chat_stream: Optional[BotChatSession] = None,
|
||||||
chat_id: Optional[str] = None,
|
chat_id: Optional[str] = None,
|
||||||
request_type: str = "replyer",
|
request_type: str = "replyer",
|
||||||
@@ -58,6 +58,35 @@ def get_replyer(
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_unknown_words(action_data: Optional[Dict[str, Any]]) -> Optional[List[str]]:
|
||||||
|
if not action_data:
|
||||||
|
return None
|
||||||
|
|
||||||
|
unknown_words = action_data.get("unknown_words")
|
||||||
|
if not isinstance(unknown_words, list):
|
||||||
|
return None
|
||||||
|
|
||||||
|
cleaned_words: List[str] = []
|
||||||
|
for item in unknown_words:
|
||||||
|
if isinstance(item, str) and (cleaned_item := item.strip()):
|
||||||
|
cleaned_words.append(cleaned_item)
|
||||||
|
|
||||||
|
return cleaned_words or None
|
||||||
|
|
||||||
|
|
||||||
|
def _build_message_sequence(
|
||||||
|
content: Optional[str],
|
||||||
|
*,
|
||||||
|
enable_splitter: bool,
|
||||||
|
enable_chinese_typo: bool,
|
||||||
|
) -> tuple[Optional[MessageSequence], List[str]]:
|
||||||
|
if not content:
|
||||||
|
return None, []
|
||||||
|
|
||||||
|
processed_output = process_llm_response(content, enable_splitter, enable_chinese_typo)
|
||||||
|
return MessageSequence(components=[TextComponent(text) for text in processed_output]), processed_output
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# 回复生成函数
|
# 回复生成函数
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
@@ -87,7 +116,7 @@ async def generate_reply(
|
|||||||
reply_time_point = time.time()
|
reply_time_point = time.time()
|
||||||
|
|
||||||
logger.debug("[GeneratorService] 开始生成回复")
|
logger.debug("[GeneratorService] 开始生成回复")
|
||||||
replyer = get_replyer(chat_stream, chat_id, request_type=request_type)
|
replyer = _get_replyer(chat_stream, chat_id, request_type=request_type)
|
||||||
if not replyer:
|
if not replyer:
|
||||||
logger.error("[GeneratorService] 无法获取回复器")
|
logger.error("[GeneratorService] 无法获取回复器")
|
||||||
return False, None
|
return False, None
|
||||||
@@ -98,16 +127,7 @@ async def generate_reply(
|
|||||||
if not reply_reason:
|
if not reply_reason:
|
||||||
reply_reason = action_data.get("reason", "")
|
reply_reason = action_data.get("reason", "")
|
||||||
if unknown_words is None:
|
if unknown_words is None:
|
||||||
uw = action_data.get("unknown_words")
|
unknown_words = _extract_unknown_words(action_data)
|
||||||
if isinstance(uw, list):
|
|
||||||
cleaned: List[str] = []
|
|
||||||
for item in uw:
|
|
||||||
if isinstance(item, str):
|
|
||||||
s = item.strip()
|
|
||||||
if s:
|
|
||||||
cleaned.append(s)
|
|
||||||
if cleaned:
|
|
||||||
unknown_words = cleaned
|
|
||||||
|
|
||||||
success, llm_response = await replyer.generate_reply_with_context(
|
success, llm_response = await replyer.generate_reply_with_context(
|
||||||
extra_info=extra_info,
|
extra_info=extra_info,
|
||||||
@@ -126,13 +146,12 @@ async def generate_reply(
|
|||||||
if not success:
|
if not success:
|
||||||
logger.warning("[GeneratorService] 回复生成失败")
|
logger.warning("[GeneratorService] 回复生成失败")
|
||||||
return False, None
|
return False, None
|
||||||
reply_set: Optional[MessageSequence] = None
|
reply_set, processed_output = _build_message_sequence(
|
||||||
if content := llm_response.content:
|
llm_response.content,
|
||||||
processed_response = process_llm_response(content, enable_splitter, enable_chinese_typo)
|
enable_splitter=enable_splitter,
|
||||||
llm_response.processed_output = processed_response
|
enable_chinese_typo=enable_chinese_typo,
|
||||||
reply_set = MessageSequence(components=[])
|
)
|
||||||
for text in processed_response:
|
llm_response.processed_output = processed_output
|
||||||
reply_set.components.append(TextComponent(text))
|
|
||||||
llm_response.reply_set = reply_set
|
llm_response.reply_set = reply_set
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"[GeneratorService] 回复生成成功,生成了 {len(reply_set.components) if reply_set else 0} 个回复项"
|
f"[GeneratorService] 回复生成成功,生成了 {len(reply_set.components) if reply_set else 0} 个回复项"
|
||||||
@@ -181,7 +200,7 @@ async def rewrite_reply(
|
|||||||
) -> Tuple[bool, Optional["LLMGenerationDataModel"]]:
|
) -> Tuple[bool, Optional["LLMGenerationDataModel"]]:
|
||||||
"""重写回复"""
|
"""重写回复"""
|
||||||
try:
|
try:
|
||||||
replyer = get_replyer(chat_stream, chat_id, request_type=request_type)
|
replyer = _get_replyer(chat_stream, chat_id, request_type=request_type)
|
||||||
if not replyer:
|
if not replyer:
|
||||||
logger.error("[GeneratorService] 无法获取回复器")
|
logger.error("[GeneratorService] 无法获取回复器")
|
||||||
return False, None
|
return False, None
|
||||||
@@ -198,9 +217,13 @@ async def rewrite_reply(
|
|||||||
reason=reason,
|
reason=reason,
|
||||||
reply_to=reply_to,
|
reply_to=reply_to,
|
||||||
)
|
)
|
||||||
reply_set: Optional[MessageSequence] = None
|
reply_set, processed_output = _build_message_sequence(
|
||||||
if success and llm_response and (content := llm_response.content):
|
llm_response.content if success and llm_response else None,
|
||||||
reply_set = process_human_text(content, enable_splitter, enable_chinese_typo)
|
enable_splitter=enable_splitter,
|
||||||
|
enable_chinese_typo=enable_chinese_typo,
|
||||||
|
)
|
||||||
|
if llm_response is not None:
|
||||||
|
llm_response.processed_output = processed_output
|
||||||
llm_response.reply_set = reply_set
|
llm_response.reply_set = reply_set
|
||||||
if success:
|
if success:
|
||||||
logger.info(
|
logger.info(
|
||||||
@@ -219,44 +242,3 @@ async def rewrite_reply(
|
|||||||
return False, None
|
return False, None
|
||||||
|
|
||||||
|
|
||||||
def process_human_text(content: str, enable_splitter: bool, enable_chinese_typo: bool) -> Optional[MessageSequence]:
|
|
||||||
"""将文本处理为更拟人化的文本"""
|
|
||||||
if not isinstance(content, str):
|
|
||||||
raise ValueError("content 必须是字符串类型")
|
|
||||||
try:
|
|
||||||
reply_set = MessageSequence(components=[])
|
|
||||||
processed_response = process_llm_response(content, enable_splitter, enable_chinese_typo)
|
|
||||||
|
|
||||||
for text in processed_response:
|
|
||||||
reply_set.components.append(TextComponent(text))
|
|
||||||
|
|
||||||
return reply_set
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"[GeneratorService] 处理人形文本时出错: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
async def generate_response_custom(
|
|
||||||
chat_stream: Optional[BotChatSession] = None,
|
|
||||||
chat_id: Optional[str] = None,
|
|
||||||
request_type: str = "generator_api",
|
|
||||||
prompt: str = "",
|
|
||||||
) -> Optional[str]:
|
|
||||||
replyer = get_replyer(chat_stream, chat_id, request_type=request_type)
|
|
||||||
if not replyer:
|
|
||||||
logger.error("[GeneratorService] 无法获取回复器")
|
|
||||||
return None
|
|
||||||
|
|
||||||
try:
|
|
||||||
logger.debug("[GeneratorService] 开始生成自定义回复")
|
|
||||||
response, _, _, _ = await replyer.llm_generate_content(prompt)
|
|
||||||
if response:
|
|
||||||
logger.debug("[GeneratorService] 自定义回复生成成功")
|
|
||||||
return response
|
|
||||||
else:
|
|
||||||
logger.warning("[GeneratorService] 自定义回复生成失败")
|
|
||||||
return None
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"[GeneratorService] 生成自定义回复时出错: {e}")
|
|
||||||
return None
|
|
||||||
|
|||||||
@@ -14,7 +14,6 @@ from src.common.database.database_model import ActionRecord, Images, ImageType
|
|||||||
from src.common.message_repository import count_messages, find_messages
|
from src.common.message_repository import count_messages, find_messages
|
||||||
from src.common.utils.math_utils import translate_timestamp_to_human_readable
|
from src.common.utils.math_utils import translate_timestamp_to_human_readable
|
||||||
from src.common.utils.utils_action import ActionUtils
|
from src.common.utils.utils_action import ActionUtils
|
||||||
from src.chat.utils.utils import is_bot_self
|
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
|
|
||||||
|
|
||||||
@@ -113,103 +112,6 @@ def get_messages_by_time_in_chat(
|
|||||||
return _normalize_messages(messages)
|
return _normalize_messages(messages)
|
||||||
|
|
||||||
|
|
||||||
def get_messages_by_time_in_chat_inclusive(
|
|
||||||
chat_id: str,
|
|
||||||
start_time: float,
|
|
||||||
end_time: float,
|
|
||||||
limit: int = 0,
|
|
||||||
limit_mode: str = "latest",
|
|
||||||
filter_mai: bool = False,
|
|
||||||
filter_command: bool = False,
|
|
||||||
filter_intercept_message_level: Optional[int] = None,
|
|
||||||
) -> List[SessionMessage]:
|
|
||||||
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
|
|
||||||
raise ValueError("start_time 和 end_time 必须是数字类型")
|
|
||||||
if limit < 0:
|
|
||||||
raise ValueError("limit 不能为负数")
|
|
||||||
if not chat_id:
|
|
||||||
raise ValueError("chat_id 不能为空")
|
|
||||||
if not isinstance(chat_id, str):
|
|
||||||
raise ValueError("chat_id 必须是字符串类型")
|
|
||||||
messages = find_messages(
|
|
||||||
message_filter={
|
|
||||||
"chat_id": chat_id,
|
|
||||||
"time": {
|
|
||||||
"$gte": start_time,
|
|
||||||
"$lte": end_time,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
limit=limit,
|
|
||||||
limit_mode=limit_mode,
|
|
||||||
filter_bot=filter_mai,
|
|
||||||
filter_command=filter_command,
|
|
||||||
filter_intercept_message_level=filter_intercept_message_level,
|
|
||||||
)
|
|
||||||
return _normalize_messages(messages)
|
|
||||||
|
|
||||||
|
|
||||||
def get_messages_by_time_in_chat_for_users(
|
|
||||||
chat_id: str,
|
|
||||||
start_time: float,
|
|
||||||
end_time: float,
|
|
||||||
person_ids: List[str],
|
|
||||||
limit: int = 0,
|
|
||||||
limit_mode: str = "latest",
|
|
||||||
) -> List[SessionMessage]:
|
|
||||||
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
|
|
||||||
raise ValueError("start_time 和 end_time 必须是数字类型")
|
|
||||||
if limit < 0:
|
|
||||||
raise ValueError("limit 不能为负数")
|
|
||||||
if not chat_id:
|
|
||||||
raise ValueError("chat_id 不能为空")
|
|
||||||
if not isinstance(chat_id, str):
|
|
||||||
raise ValueError("chat_id 必须是字符串类型")
|
|
||||||
messages = find_messages(
|
|
||||||
message_filter={
|
|
||||||
"chat_id": chat_id,
|
|
||||||
"time": {
|
|
||||||
"$gte": start_time,
|
|
||||||
"$lte": end_time,
|
|
||||||
},
|
|
||||||
"user_id": {"$in": person_ids},
|
|
||||||
},
|
|
||||||
limit=limit,
|
|
||||||
limit_mode=limit_mode,
|
|
||||||
)
|
|
||||||
return _normalize_messages(messages)
|
|
||||||
|
|
||||||
|
|
||||||
def get_random_chat_messages(
|
|
||||||
start_time: float, end_time: float, limit: int = 0, limit_mode: str = "latest", filter_mai: bool = False
|
|
||||||
) -> List[SessionMessage]:
|
|
||||||
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
|
|
||||||
raise ValueError("start_time 和 end_time 必须是数字类型")
|
|
||||||
if limit < 0:
|
|
||||||
raise ValueError("limit 不能为负数")
|
|
||||||
return get_messages_by_time(start_time, end_time, limit, limit_mode, filter_mai)
|
|
||||||
|
|
||||||
|
|
||||||
def get_messages_by_time_for_users(
|
|
||||||
start_time: float, end_time: float, person_ids: List[str], limit: int = 0, limit_mode: str = "latest"
|
|
||||||
) -> List[SessionMessage]:
|
|
||||||
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
|
|
||||||
raise ValueError("start_time 和 end_time 必须是数字类型")
|
|
||||||
if limit < 0:
|
|
||||||
raise ValueError("limit 不能为负数")
|
|
||||||
messages = find_messages(
|
|
||||||
message_filter={
|
|
||||||
"time": {
|
|
||||||
"$gte": start_time,
|
|
||||||
"$lte": end_time,
|
|
||||||
},
|
|
||||||
"user_id": {"$in": person_ids},
|
|
||||||
},
|
|
||||||
limit=limit,
|
|
||||||
limit_mode=limit_mode,
|
|
||||||
)
|
|
||||||
return _normalize_messages(messages)
|
|
||||||
|
|
||||||
|
|
||||||
def get_messages_before_time(timestamp: float, limit: int = 0, filter_mai: bool = False) -> List[SessionMessage]:
|
def get_messages_before_time(timestamp: float, limit: int = 0, filter_mai: bool = False) -> List[SessionMessage]:
|
||||||
if not isinstance(timestamp, (int, float)):
|
if not isinstance(timestamp, (int, float)):
|
||||||
raise ValueError("timestamp 必须是数字类型")
|
raise ValueError("timestamp 必须是数字类型")
|
||||||
@@ -252,24 +154,6 @@ def get_messages_before_time_in_chat(
|
|||||||
return _normalize_messages(messages)
|
return _normalize_messages(messages)
|
||||||
|
|
||||||
|
|
||||||
def get_messages_before_time_for_users(
|
|
||||||
timestamp: float, person_ids: List[str], limit: int = 0
|
|
||||||
) -> List[SessionMessage]:
|
|
||||||
if not isinstance(timestamp, (int, float)):
|
|
||||||
raise ValueError("timestamp 必须是数字类型")
|
|
||||||
if limit < 0:
|
|
||||||
raise ValueError("limit 不能为负数")
|
|
||||||
messages = find_messages(
|
|
||||||
message_filter={
|
|
||||||
"time": {"$lt": timestamp},
|
|
||||||
"user_id": {"$in": person_ids},
|
|
||||||
},
|
|
||||||
limit=limit,
|
|
||||||
limit_mode="latest",
|
|
||||||
)
|
|
||||||
return _normalize_messages(messages)
|
|
||||||
|
|
||||||
|
|
||||||
def get_recent_messages(
|
def get_recent_messages(
|
||||||
chat_id: str, hours: float = 24.0, limit: int = 100, limit_mode: str = "latest", filter_mai: bool = False
|
chat_id: str, hours: float = 24.0, limit: int = 100, limit_mode: str = "latest", filter_mai: bool = False
|
||||||
) -> List[SessionMessage]:
|
) -> List[SessionMessage]:
|
||||||
@@ -307,22 +191,6 @@ def count_new_messages(chat_id: str, start_time: float = 0.0, end_time: Optional
|
|||||||
return count_messages(message_filter)
|
return count_messages(message_filter)
|
||||||
|
|
||||||
|
|
||||||
def count_new_messages_for_users(chat_id: str, start_time: float, end_time: float, person_ids: List[str]) -> int:
|
|
||||||
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
|
|
||||||
raise ValueError("start_time 和 end_time 必须是数字类型")
|
|
||||||
if not chat_id:
|
|
||||||
raise ValueError("chat_id 不能为空")
|
|
||||||
if not isinstance(chat_id, str):
|
|
||||||
raise ValueError("chat_id 必须是字符串类型")
|
|
||||||
return count_messages(
|
|
||||||
{
|
|
||||||
"chat_id": chat_id,
|
|
||||||
"time": {"$gt": start_time, "$lte": end_time},
|
|
||||||
"user_id": {"$in": person_ids},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# 消息格式化函数
|
# 消息格式化函数
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
@@ -365,17 +233,6 @@ def build_readable_messages(
|
|||||||
return "\n".join(lines)
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
def build_readable_messages_to_str(
|
|
||||||
messages: List[SessionMessage],
|
|
||||||
replace_bot_name: bool = True,
|
|
||||||
timestamp_mode: str = "relative",
|
|
||||||
read_mark: float = 0.0,
|
|
||||||
truncate: bool = False,
|
|
||||||
show_actions: bool = False,
|
|
||||||
) -> str:
|
|
||||||
return build_readable_messages(messages, replace_bot_name, timestamp_mode, read_mark, truncate, show_actions)
|
|
||||||
|
|
||||||
|
|
||||||
def build_readable_messages_with_id(
|
def build_readable_messages_with_id(
|
||||||
messages: List[SessionMessage],
|
messages: List[SessionMessage],
|
||||||
replace_bot_name: bool = True,
|
replace_bot_name: bool = True,
|
||||||
@@ -415,148 +272,6 @@ def build_readable_messages_with_id(
|
|||||||
return "\n".join(lines), message_id_list
|
return "\n".join(lines), message_id_list
|
||||||
|
|
||||||
|
|
||||||
async def build_readable_messages_with_details(
|
|
||||||
messages: List[SessionMessage],
|
|
||||||
replace_bot_name: bool = True,
|
|
||||||
timestamp_mode: str = "relative",
|
|
||||||
truncate: bool = False,
|
|
||||||
) -> Tuple[str, List[Tuple[float, str, str]]]:
|
|
||||||
normalized_messages = _normalize_messages(messages)
|
|
||||||
message_list = [
|
|
||||||
(
|
|
||||||
message.timestamp.timestamp(),
|
|
||||||
message.message_info.user_info.user_id,
|
|
||||||
message.processed_plain_text or "",
|
|
||||||
)
|
|
||||||
for message in normalized_messages
|
|
||||||
]
|
|
||||||
return build_readable_messages(normalized_messages, replace_bot_name, timestamp_mode, truncate=truncate), message_list
|
|
||||||
|
|
||||||
|
|
||||||
async def get_person_ids_from_messages(messages: List[Any]) -> List[str]:
|
|
||||||
person_ids: List[str] = []
|
|
||||||
for message in messages:
|
|
||||||
if isinstance(message, SessionMessage):
|
|
||||||
person_ids.append(message.message_info.user_info.user_id)
|
|
||||||
elif isinstance(message, dict) and (user_id := message.get("user_id")):
|
|
||||||
person_ids.append(str(user_id))
|
|
||||||
return person_ids
|
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# 消息过滤函数
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
|
|
||||||
def filter_mai_messages(messages: List[SessionMessage]) -> List[SessionMessage]:
|
|
||||||
"""从消息列表中移除麦麦的消息"""
|
|
||||||
return [
|
|
||||||
msg
|
|
||||||
for msg in messages
|
|
||||||
if not is_bot_self(msg.platform, msg.message_info.user_info.user_id)
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def get_raw_msg_by_timestamp(
|
|
||||||
timestamp_start: float,
|
|
||||||
timestamp_end: float,
|
|
||||||
limit: int = 0,
|
|
||||||
limit_mode: str = "latest",
|
|
||||||
) -> List[SessionMessage]:
|
|
||||||
return get_messages_by_time(timestamp_start, timestamp_end, limit, limit_mode)
|
|
||||||
|
|
||||||
|
|
||||||
def get_raw_msg_by_timestamp_with_chat(
|
|
||||||
chat_id: str,
|
|
||||||
timestamp_start: float,
|
|
||||||
timestamp_end: float,
|
|
||||||
limit: int = 0,
|
|
||||||
limit_mode: str = "latest",
|
|
||||||
filter_bot: bool = False,
|
|
||||||
filter_command: bool = False,
|
|
||||||
filter_intercept_message_level: Optional[int] = None,
|
|
||||||
) -> List[SessionMessage]:
|
|
||||||
return get_messages_by_time_in_chat(
|
|
||||||
chat_id,
|
|
||||||
timestamp_start,
|
|
||||||
timestamp_end,
|
|
||||||
limit,
|
|
||||||
limit_mode,
|
|
||||||
filter_bot,
|
|
||||||
filter_command,
|
|
||||||
filter_intercept_message_level,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_raw_msg_by_timestamp_with_chat_inclusive(
|
|
||||||
chat_id: str,
|
|
||||||
timestamp_start: float,
|
|
||||||
timestamp_end: float,
|
|
||||||
limit: int = 0,
|
|
||||||
limit_mode: str = "latest",
|
|
||||||
filter_bot: bool = False,
|
|
||||||
filter_command: bool = False,
|
|
||||||
filter_intercept_message_level: Optional[int] = None,
|
|
||||||
) -> List[SessionMessage]:
|
|
||||||
return get_messages_by_time_in_chat_inclusive(
|
|
||||||
chat_id,
|
|
||||||
timestamp_start,
|
|
||||||
timestamp_end,
|
|
||||||
limit,
|
|
||||||
limit_mode,
|
|
||||||
filter_bot,
|
|
||||||
filter_command,
|
|
||||||
filter_intercept_message_level,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_raw_msg_by_timestamp_with_chat_users(
|
|
||||||
chat_id: str,
|
|
||||||
timestamp_start: float,
|
|
||||||
timestamp_end: float,
|
|
||||||
person_ids: List[str],
|
|
||||||
limit: int = 0,
|
|
||||||
limit_mode: str = "latest",
|
|
||||||
) -> List[SessionMessage]:
|
|
||||||
return get_messages_by_time_in_chat_for_users(chat_id, timestamp_start, timestamp_end, person_ids, limit, limit_mode)
|
|
||||||
|
|
||||||
|
|
||||||
def get_raw_msg_by_timestamp_with_users(
|
|
||||||
timestamp_start: float,
|
|
||||||
timestamp_end: float,
|
|
||||||
person_ids: List[str],
|
|
||||||
limit: int = 0,
|
|
||||||
limit_mode: str = "latest",
|
|
||||||
) -> List[SessionMessage]:
|
|
||||||
return get_messages_by_time_for_users(timestamp_start, timestamp_end, person_ids, limit, limit_mode)
|
|
||||||
|
|
||||||
|
|
||||||
def get_raw_msg_before_timestamp(timestamp: float, limit: int = 0) -> List[SessionMessage]:
|
|
||||||
return get_messages_before_time(timestamp, limit)
|
|
||||||
|
|
||||||
|
|
||||||
def get_raw_msg_before_timestamp_with_chat(
|
|
||||||
chat_id: str,
|
|
||||||
timestamp: float,
|
|
||||||
limit: int = 0,
|
|
||||||
filter_intercept_message_level: Optional[int] = None,
|
|
||||||
) -> List[SessionMessage]:
|
|
||||||
return get_messages_before_time_in_chat(chat_id, timestamp, limit, False, filter_intercept_message_level)
|
|
||||||
|
|
||||||
|
|
||||||
def get_raw_msg_before_timestamp_with_users(timestamp: float, person_ids: List[str], limit: int = 0) -> List[SessionMessage]:
|
|
||||||
return get_messages_before_time_for_users(timestamp, person_ids, limit)
|
|
||||||
|
|
||||||
|
|
||||||
def get_raw_msg_by_timestamp_random(
|
|
||||||
timestamp_start: float,
|
|
||||||
timestamp_end: float,
|
|
||||||
limit: int = 0,
|
|
||||||
limit_mode: str = "latest",
|
|
||||||
) -> List[SessionMessage]:
|
|
||||||
return get_random_chat_messages(timestamp_start, timestamp_end, limit, limit_mode)
|
|
||||||
|
|
||||||
|
|
||||||
def get_actions_by_timestamp_with_chat(chat_id: str, timestamp_start: float, timestamp_end: float) -> List[MaiActionRecord]:
|
def get_actions_by_timestamp_with_chat(chat_id: str, timestamp_start: float, timestamp_end: float) -> List[MaiActionRecord]:
|
||||||
with get_db_session() as session:
|
with get_db_session() as session:
|
||||||
statement = (
|
statement = (
|
||||||
|
|||||||
Reference in New Issue
Block a user