Merge branch 'Mai-with-u:r-dev' into r-dev
@@ -1,5 +1,6 @@
"""Built-in emoji tool capabilities for Maisaka."""

from collections.abc import Awaitable, Callable
from dataclasses import dataclass, field
from typing import Any, Optional, Sequence

@@ -17,6 +18,11 @@ from .emoji_manager import _serialize_emoji_for_hook, emoji_manager, emoji_manag

logger = get_logger("emoji_maisaka_tool")

EmojiSelector = Callable[
    [str, str, Sequence[str] | None, int],
    Awaitable[tuple[MaiEmoji | None, str]],
]


@dataclass(slots=True)
class MaisakaEmojiSendResult:
@@ -198,13 +204,14 @@ async def send_emoji_for_maisaka(
    requested_emotion: str = "",
    reasoning: str = "",
    context_texts: Sequence[str] | None = None,
    emoji_selector: EmojiSelector | None = None,
) -> MaisakaEmojiSendResult:
    """Select and send one emoji sticker for Maisaka."""

    normalized_requested_emotion = requested_emotion.strip()
    normalized_reasoning = reasoning.strip()
    normalized_context_texts = _normalize_context_texts(context_texts)
    sample_size = 30
    sample_size = 20

    before_select_result = await _get_runtime_manager().invoke_hook(
        "emoji.maisaka.before_select",
@@ -232,12 +239,20 @@ async def send_emoji_for_maisaka(
    normalized_context_texts = _normalize_context_texts(before_select_kwargs.get("context_texts"))
    sample_size = _coerce_positive_int(before_select_kwargs.get("sample_size"), sample_size)

    selected_emoji, matched_emotion = await select_emoji_for_maisaka(
        requested_emotion=normalized_requested_emotion,
        reasoning=normalized_reasoning,
        context_texts=normalized_context_texts,
        sample_size=sample_size,
    )
    if emoji_selector is None:
        selected_emoji, matched_emotion = await select_emoji_for_maisaka(
            requested_emotion=normalized_requested_emotion,
            reasoning=normalized_reasoning,
            context_texts=normalized_context_texts,
            sample_size=sample_size,
        )
    else:
        selected_emoji, matched_emotion = await emoji_selector(
            normalized_requested_emotion,
            normalized_reasoning,
            normalized_context_texts,
            sample_size,
        )
    after_select_result = await _get_runtime_manager().invoke_hook(
        "emoji.maisaka.after_select",
        stream_id=stream_id,

@@ -205,7 +205,7 @@ class HeartFChatting:
        # TODO: planner logic
        # TODO: action-execution logic

        cycle_detail = self._end_cycle(current_cycle_detail)
        self._end_cycle(current_cycle_detail)
        await asyncio.sleep(0.1)  # minimum wait so the loop does not spin too fast
        return True


@@ -1,31 +1,32 @@
from dataclasses import dataclass
import json
import os
import math
import asyncio
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Dict, List, Tuple
from dataclasses import dataclass
from typing import Callable, Dict, List, Tuple

import json
import math
import os

import faiss
import numpy as np
import pandas as pd

# import tqdm
import faiss

from .utils.hash import get_sha256
from .global_logger import logger
from rich.traceback import install
from rich.progress import (
    Progress,
    BarColumn,
    MofNCompleteColumn,
    Progress,
    SpinnerColumn,
    TextColumn,
    TimeElapsedColumn,
    TimeRemainingColumn,
    TaskProgressColumn,
    MofNCompleteColumn,
    SpinnerColumn,
    TextColumn,
)

from src.config.config import global_config
from src.services.embedding_service import EmbeddingServiceClient

from .global_logger import logger
from .utils.hash import get_sha256


install(extra_lines=3)
@@ -133,19 +134,20 @@ class EmbeddingStore:
        return [f"{namespace}-{get_sha256(t)}" for t in texts]

    def _get_embedding(self, s: str) -> List[float]:
        """Get the embedding for a string, fully synchronously, to avoid event-loop issues."""
        # Create a fresh event loop and close it right after use
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        """Synchronously get the embedding vector for a single string.

        Args:
            s: Text to encode.

        Returns:
            List[float]: Embedding vector; an empty list on failure.
        """
        try:
            # Create a new service-layer instance
            from src.services.llm_service import LLMServiceClient

            llm = LLMServiceClient(task_name="embedding", request_type="embedding")

            # Run the async call on the new event loop
            embedding_result = loop.run_until_complete(llm.embed_text(s))
            embedding_client = EmbeddingServiceClient(
                task_name="embedding",
                request_type="embedding",
            )
            embedding_result = embedding_client.embed_text_sync(s)
            embedding = embedding_result.embedding

            if embedding and len(embedding) > 0:
@@ -157,17 +159,15 @@ class EmbeddingStore:
        except Exception as e:
            logger.error(f"获取嵌入时发生异常: {s}, 错误: {e}")
            return []
        finally:
            # Make sure the event loop is closed properly
            try:
                loop.close()
            except Exception:
                pass

    def _get_embeddings_batch_threaded(
        self, strs: List[str], chunk_size: int = 10, max_workers: int = 10, progress_callback=None
        self,
        strs: List[str],
        chunk_size: int = 10,
        max_workers: int = 10,
        progress_callback: Callable[[int], None] | None = None,
    ) -> List[Tuple[str, List[float]]]:
        """Fetch embedding vectors in batches using multiple threads
        """Fetch embedding vectors in batches using multiple threads.

        Args:
            strs: Strings to embed
@@ -190,53 +190,42 @@ class EmbeddingStore:
        # Store results in a dict keyed by index to preserve order
        results = {}

        def process_chunk(chunk_data):
            """Process a single chunk of data"""
            start_idx, chunk_strs = chunk_data
            chunk_results = []
        def process_chunk(chunk_data: Tuple[int, List[str]]) -> List[Tuple[int, str, List[float]]]:
            """Process a single chunk.

            # Create an independent service-layer instance per thread
            from src.services.llm_service import LLMServiceClient
            Args:
                chunk_data: Start index of the chunk plus its list of strings.

            Returns:
                List[Tuple[int, str, List[float]]]: Results tagged with their original indices.
            """
            start_idx, chunk_strs = chunk_data
            chunk_results: List[Tuple[int, str, List[float]]] = []

            try:
                # Create a thread-local service-layer instance
                llm = LLMServiceClient(task_name="embedding", request_type="embedding")

                for i, s in enumerate(chunk_strs):
                    try:
                        # Create an independent event loop inside the thread
                        loop = asyncio.new_event_loop()
                        asyncio.set_event_loop(loop)
                        try:
                            embedding_result = loop.run_until_complete(llm.embed_text(s))
                            embedding = embedding_result.embedding
                        finally:
                            loop.close()

                        if embedding and len(embedding) > 0:
                            chunk_results.append((start_idx + i, s, embedding[0]))  # embedding[0] is the actual vector
                        else:
                            logger.error(f"获取嵌入失败: {s}")
                            chunk_results.append((start_idx + i, s, []))

                        # Bump progress as soon as each embedding finishes
                        if progress_callback:
                            progress_callback(1)

                    except Exception as e:
                        logger.error(f"获取嵌入时发生异常: {s}, 错误: {e}")
                embedding_client = EmbeddingServiceClient(
                    task_name="embedding",
                    request_type="embedding",
                )
                embedding_results = embedding_client.embed_texts_sync(
                    chunk_strs,
                    max_concurrent=1,
                )
                for i, (s, embedding_result) in enumerate(zip(chunk_strs, embedding_results, strict=False)):
                    embedding = embedding_result.embedding
                    if embedding and len(embedding) > 0:
                        chunk_results.append((start_idx + i, s, embedding))
                    else:
                        logger.error(f"获取嵌入失败: {s}")
                        chunk_results.append((start_idx + i, s, []))

                        # Update progress even on failure
                        if progress_callback:
                            progress_callback(1)
                    if progress_callback:
                        progress_callback(1)

            except Exception as e:
                logger.error(f"创建LLM实例失败: {e}")
                # If creating the LLM instance fails, return empty results
                logger.error(f"创建 EmbeddingService 实例失败: {e}")
                for i, s in enumerate(chunk_strs):
                    chunk_results.append((start_idx + i, s, []))
                # Update progress even on failure
                if progress_callback:
                    progress_callback(1)


@@ -12,7 +12,6 @@ from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiv
from src.common.logger import get_logger
from src.common.utils.utils_message import MessageUtils
from src.common.utils.utils_session import SessionUtils
from src.config.config import global_config
from src.platform_io.route_key_factory import RouteKeyFactory
from src.core.announcement_manager import global_announcement_manager
from src.plugin_runtime.component_query import component_query_service

@@ -1135,9 +1135,6 @@ class DefaultReplyer:
        return content, reasoning_content, model_name, tool_calls

    async def get_prompt_info(self, message: str, sender: str, target: str):
        del message
        del sender
        del target
        return ""
        related_info = ""
        start_time = time.time()

@@ -151,7 +151,7 @@ class MaisakaReplyGenerator:
            content = self._normalize_content(content_body)
            if not content:
                continue
            visible_speaker = speaker_name or global_config.maisaka.user_name.strip() or "User"
            visible_speaker = speaker_name or global_config.maisaka.cli_user_name.strip() or "User"
            parts.append(f"{timestamp} {visible_speaker}: {content}")
            continue


@@ -162,7 +162,7 @@ class MaisakaReplyGenerator:
    def _build_history_messages(self, chat_history: List[LLMContextMessage]) -> List[Message]:
        """Split the replyer context into multiple LLM messages."""
        bot_nickname = global_config.bot.nickname.strip() or "Bot"
        default_user_name = global_config.maisaka.user_name.strip() or "User"
        default_user_name = global_config.maisaka.cli_user_name.strip() or "User"
        messages: List[Message] = []

        for message in chat_history:

@@ -1058,12 +1058,9 @@ class StatisticOutputTask(AsyncTask):
        from src.chat.message_receive.chat_manager import chat_manager as _stat_chat_manager

        if chat_id in _stat_chat_manager.sessions:
            session = _stat_chat_manager.sessions[chat_id]
            name = _stat_chat_manager.get_session_name(chat_id)
            if name and name.strip():
                return name.strip()
        if user_name and user_name.strip():
            return user_name.strip()

        # If the chat_stream lookup fails, try parsing the chat_id format
        if chat_id.startswith("g"):

@@ -14,8 +14,8 @@ from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
from src.chat.message_receive.message import SessionMessage
from src.common.logger import get_logger
from src.config.config import global_config
from src.services.llm_service import LLMServiceClient
from src.person_info.person_info import Person
from src.services.embedding_service import EmbeddingServiceClient

from .typo_generator import ChineseTypoGenerator

@@ -233,12 +233,19 @@ def is_mentioned_bot_in_message(message: SessionMessage) -> tuple[bool, bool, fl
    return is_mentioned, is_at, reply_probability


async def get_embedding(text, request_type="embedding") -> Optional[List[float]]:
    """Get the embedding vector for a text"""
    # Create a fresh service-layer instance per call to avoid event-loop conflicts
    llm = LLMServiceClient(task_name="embedding", request_type=request_type)
async def get_embedding(text: str, request_type: str = "embedding") -> Optional[List[float]]:
    """Get the embedding vector for a text.

    Args:
        text: Text to encode.
        request_type: Business-type tag for this request.

    Returns:
        Optional[List[float]]: The embedding on success, `None` on failure.
    """
    embedding_client = EmbeddingServiceClient(task_name="embedding", request_type=request_type)
    try:
        embedding_result = await llm.embed_text(text)
        embedding_result = await embedding_client.embed_text(text)
        embedding = embedding_result.embedding
    except Exception as e:
        logger.error(f"获取embedding失败: {str(e)}")

@@ -67,7 +67,7 @@ class BufferCLI:
            timestamp=timestamp,
            platform=BufferCLI._CLI_PLATFORM,
        )
        user_name = global_config.maisaka.user_name.strip() or "用户"
        user_name = global_config.maisaka.cli_user_name.strip() or "用户"
        message.message_info = MessageInfo(
            user_info=UserInfo(
                user_id=BufferCLI._CLI_USER_ID,

src/common/data_models/embedding_service_data_models.py (new file, +19 lines)
@@ -0,0 +1,19 @@
"""Shared data models for the embedding service layer."""

from dataclasses import dataclass, field
from typing import List

from src.common.data_models import BaseDataModel


@dataclass(slots=True)
class EmbeddingResult(BaseDataModel):
    """Unified response object for the embedding service layer."""

    embedding: List[float] = field(default_factory=list)
    model_name: str = field(default_factory=str)


__all__ = [
    "EmbeddingResult",
]
@@ -92,7 +92,6 @@ class UniversalMessageSender:
        """
        # TODO: refactor onto the new sending model
        message_preview = (message.processed_plain_text or "")[:200]
        platform = message.platform

        try:
            # Try sending through the primary API

@@ -282,7 +282,24 @@ class ChatConfig(ConfigBase):
            "x-icon": "list",
        },
    )
    """List of extra prompt configurations appended for the specified chat"""

    direct_image_input: bool = Field(
        default=True,
        json_schema_extra={
            "x-widget": "switch",
            "x-icon": "image",
        },
    )
    """Whether to feed images directly into the model"""

    replyer_generator_type: Literal["legacy", "multi"] = Field(
        default="legacy",
        json_schema_extra={
            "x-widget": "select",
            "x-icon": "git-branch",
        },
    )
    """Maisaka replyer generator type: legacy (old single-prompt) / multi (multi-message)"""

    enable_talk_value_rules: bool = Field(
        default=True,
@@ -1018,6 +1035,14 @@ class DebugConfig(ConfigBase):
            "x-icon": "brain",
        },
    )

    show_maisaka_thinking: bool = Field(
        default=True,
        json_schema_extra={
            "x-widget": "switch",
            "x-icon": "brain",
        },
    )
    """Whether to show replyer reasoning"""

    show_jargon_prompt: bool = Field(
@@ -1481,16 +1506,7 @@ class MaiSakaConfig(ConfigBase):
        },
    )
    """Enable the knowledge-base module"""
    show_thinking: bool = Field(
        default=True,
        json_schema_extra={
            "x-widget": "switch",
            "x-icon": "brain",
        },
    )
    """Whether to show the MaiSaka thinking process"""

    user_name: str = Field(
    cli_user_name: str = Field(
        default="用户",
        json_schema_extra={
            "x-widget": "input",
@@ -1499,33 +1515,6 @@ class MaiSakaConfig(ConfigBase):
    )
    """User name used by MaiSaka"""

    direct_image_input: bool = Field(
        default=True,
        json_schema_extra={
            "x-widget": "switch",
            "x-icon": "image",
        },
    )
    """Whether to feed images directly into the model"""

    merge_user_messages: bool = Field(
        default=True,
        json_schema_extra={
            "x-widget": "switch",
            "x-icon": "merge",
        },
    )
    """Whether to merge newly received user messages into a single user message"""

    replyer_generator_type: Literal["legacy", "multi"] = Field(
        default="legacy",
        json_schema_extra={
            "x-widget": "select",
            "x-icon": "git-branch",
        },
    )
    """Maisaka replyer generator type: legacy (old single-prompt) / multi (multi-message)"""

    max_internal_rounds: int = Field(
        default=6,
        ge=1,
@@ -1565,14 +1554,14 @@ class MaiSakaConfig(ConfigBase):
    )
    """Maximum number of non-builtin tools kept by the tool-filtering stage"""

    terminal_image_display_mode: Literal["legacy", "path_link"] = Field(
        default="legacy",
    show_image_path: bool = Field(
        default=True,
        json_schema_extra={
            "x-widget": "select",
            "x-widget": "switch",
            "x-icon": "image",
        },
    )
    """Image display mode: legacy (metadata only) / path_link (clickable local path)"""
    """Whether to show the local image path"""


class MCPAuthorizationConfig(ConfigBase):

@@ -1,16 +1,40 @@
"""Built-in send_emoji tool."""

from datetime import datetime
from random import sample
from secrets import token_hex
from typing import Any, Dict, Optional

import asyncio

from pydantic import BaseModel, Field as PydanticField

from src.chat.emoji_system.emoji_manager import emoji_manager
from src.chat.emoji_system.maisaka_tool import send_emoji_for_maisaka
from src.common.data_models.message_component_data_model import ImageComponent, MessageSequence, TextComponent
from src.common.data_models.image_data_model import MaiEmoji
from src.common.logger import get_logger
from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolSpec
from src.maisaka.context_messages import LLMContextMessage
from src.llm_models.payload_content.resp_format import RespFormat, RespFormatType
from src.maisaka.context_messages import LLMContextMessage, ReferenceMessage, ReferenceMessageType, SessionBackedMessage

from .context import BuiltinToolRuntimeContext

logger = get_logger("maisaka_builtin_send_emoji")

_EMOJI_SUB_AGENT_CONTEXT_LIMIT = 12
_EMOJI_SUB_AGENT_MAX_TOKENS = 240
_EMOJI_SUB_AGENT_SAMPLE_SIZE = 20
_EMOJI_SUCCESS_MESSAGE = "表情包发送成功"


class EmojiSelectionResult(BaseModel):
    """Structured selection result from the emoji sub-agent."""

    emoji_id: str = PydanticField(default="", description="选中的候选表情包 ID。")
    matched_emotion: str = PydanticField(default="", description="本次命中的情绪标签,可为空。")
    reason: str = PydanticField(default="", description="简短选择理由。")


def get_tool_spec() -> ToolSpec:
    """Get the send_emoji tool declaration."""
@@ -33,6 +57,105 @@ def get_tool_spec() -> ToolSpec:
    )


async def _build_emoji_candidate_message(emoji: MaiEmoji, candidate_id: str) -> SessionBackedMessage:
    """Build an image candidate message for the sub-agent to pick from."""

    image_bytes = await asyncio.to_thread(emoji.full_path.read_bytes)
    raw_message = MessageSequence(
        [
            TextComponent(f"ID: {candidate_id}"),
            ImageComponent(binary_hash=str(emoji.file_hash or ""), binary_data=image_bytes),
        ]
    )
    return SessionBackedMessage(
        raw_message=raw_message,
        visible_text=f"ID: {candidate_id}",
        timestamp=datetime.now(),
        source_kind="emoji_candidate",
    )


async def _select_emoji_with_sub_agent(
    tool_ctx: BuiltinToolRuntimeContext,
    requested_emotion: str,
    reasoning: str,
    context_texts: list[str],
    sample_size: int,
) -> tuple[MaiEmoji | None, str]:
    """Pick one result from the candidate emojis via a short-lived sub-agent."""

    available_emojis = list(emoji_manager.emojis)
    if not available_emojis:
        return None, ""

    effective_sample_size = min(max(sample_size, 1), _EMOJI_SUB_AGENT_SAMPLE_SIZE, len(available_emojis))
    sampled_emojis = sample(available_emojis, effective_sample_size)

    candidate_map: dict[str, MaiEmoji] = {}
    candidate_messages: list[LLMContextMessage] = []
    for emoji in sampled_emojis:
        candidate_id = token_hex(4)
        while candidate_id in candidate_map:
            candidate_id = token_hex(4)
        candidate_map[candidate_id] = emoji
        candidate_messages.append(await _build_emoji_candidate_message(emoji, candidate_id))

    context_text = "\n".join(context_texts[-5:]) if context_texts else "(暂无额外上下文)"
    system_prompt = (
        "你是 Maisaka 的临时表情包选择子代理。\n"
        "你会收到一段群聊上下文,以及若干条候选表情包消息。每条候选消息里都有一个临时 ID。\n"
        "你的任务是根据上下文、当前语气和发送意图,从候选里选出最合适的一个表情包。\n"
        "必须只从候选消息中选择,不能编造新的 ID。\n"
        "如果提供了 requested_emotion,请优先考虑与其接近的候选;如果没有完全匹配,则选择最符合上下文语气的候选。\n"
        "你必须返回一个 JSON 对象(json object),不要输出任何 JSON 之外的内容。\n"
        '返回格式固定为:{"emoji_id":"候选ID","matched_emotion":"情绪标签","reason":"简短理由"}'
    )
    prompt_message = ReferenceMessage(
        content=(
            f"[选择任务]\n"
            f"requested_emotion: {requested_emotion or '未指定'}\n"
            f"reasoning: {reasoning or '辅助表达当前语气和情绪'}\n"
            f"recent_context:\n{context_text}\n"
            '请只输出 JSON。'
        ),
        timestamp=datetime.now(),
        reference_type=ReferenceMessageType.TOOL_HINT,
        remaining_uses_value=1,
        display_prefix="[表情包选择任务]",
    )

    response = await tool_ctx.runtime.run_sub_agent(
        context_message_limit=_EMOJI_SUB_AGENT_CONTEXT_LIMIT,
        system_prompt=system_prompt,
        extra_messages=[prompt_message, *candidate_messages],
        max_tokens=_EMOJI_SUB_AGENT_MAX_TOKENS,
        response_format=RespFormat(
            format_type=RespFormatType.JSON_SCHEMA,
            schema=EmojiSelectionResult,
        ),
    )

    try:
        selection = EmojiSelectionResult.model_validate_json(response.content or "")
    except Exception as exc:
        logger.warning(f"{tool_ctx.runtime.log_prefix} 表情包子代理结果解析失败,将回退到候选首项: {exc}")
        fallback_emoji = sampled_emojis[0] if sampled_emojis else None
        return fallback_emoji, requested_emotion

    selected_emoji = candidate_map.get(selection.emoji_id.strip())
    if selected_emoji is None:
        logger.warning(
            f"{tool_ctx.runtime.log_prefix} 表情包子代理返回了无效 ID: {selection.emoji_id!r},将回退到候选首项"
        )
        fallback_emoji = sampled_emojis[0] if sampled_emojis else None
        return fallback_emoji, requested_emotion

    matched_emotion = selection.matched_emotion.strip()
    if not matched_emotion:
        matched_emotion = requested_emotion.strip()
    return selected_emoji, matched_emotion


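Editor's note: the sub-agent contract above relies on strict JSON output validated by pydantic. A minimal sketch of that round-trip, with hypothetical values (the real schema is the EmojiSelectionResult declared earlier in this file):

    # Sketch only: re-declares the schema for a self-contained demo.
    from pydantic import BaseModel, Field

    class EmojiSelectionResult(BaseModel):
        emoji_id: str = Field(default="")
        matched_emotion: str = Field(default="")
        reason: str = Field(default="")

    raw = '{"emoji_id": "a1b2c3d4", "matched_emotion": "开心", "reason": "fits the tone"}'
    selection = EmojiSelectionResult.model_validate_json(raw)
    assert selection.emoji_id == "a1b2c3d4"
    # Any parse failure (or an unknown ID) falls back to the first sampled candidate.
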
async def handle_tool(
    tool_ctx: BuiltinToolRuntimeContext,
    invocation: ToolInvocation,
@@ -64,6 +187,13 @@ async def handle_tool(
            requested_emotion=emotion,
            reasoning=tool_ctx.engine.last_reasoning_content,
            context_texts=context_texts,
            emoji_selector=lambda requested_emotion, reasoning, context_texts, sample_size: _select_emoji_with_sub_agent(
                tool_ctx,
                requested_emotion,
                reasoning,
                list(context_texts or []),
                sample_size,
            ),
        )
    except Exception as exc:
        logger.exception(f"{tool_ctx.runtime.log_prefix} 发送表情包时发生异常: {exc}")
@@ -74,28 +204,29 @@ async def handle_tool(
            structured_content=structured_result,
        )

    structured_result["description"] = send_result.description
    structured_result["emotion"] = list(send_result.emotions)
    structured_result["matched_emotion"] = send_result.matched_emotion
    structured_result["message"] = send_result.message

    if send_result.success:
        structured_result["message"] = _EMOJI_SUCCESS_MESSAGE
        logger.info(
            f"{tool_ctx.runtime.log_prefix} 表情包发送成功 "
            f"描述={send_result.description!r} 情绪标签={send_result.emotions} "
            f"请求情绪={emotion!r} 命中情绪={send_result.matched_emotion!r}"
            f"{tool_ctx.runtime.log_prefix} 表情包发送成功 "
            f"描述={send_result.description!r} 情绪标签={send_result.emotions} "
            f"请求情绪={emotion!r} 命中情绪={send_result.matched_emotion!r}"
        )
        tool_ctx.append_sent_emoji_to_chat_history(
            emoji_base64=send_result.emoji_base64,
            success_message=send_result.message,
            success_message=_EMOJI_SUCCESS_MESSAGE,
        )
        structured_result["success"] = True
        return tool_ctx.build_success_result(
            invocation.tool_name,
            send_result.message,
            _EMOJI_SUCCESS_MESSAGE,
            structured_content=structured_result,
        )

    structured_result["description"] = send_result.description
    structured_result["emotion"] = list(send_result.emotions)
    structured_result["matched_emotion"] = send_result.matched_emotion
    structured_result["message"] = send_result.message

    logger.warning(
        f"{tool_ctx.runtime.log_prefix} 表情包发送失败 "
        f"请求情绪={emotion!r} 错误信息={send_result.message}"

@@ -210,7 +210,7 @@ class MaisakaChatLoopService:
        self._extra_tools: List[ToolOption] = []
        self._interrupt_flag: asyncio.Event | None = None
        self._tool_registry: ToolRegistry | None = None
        self._prompts_loaded = False
        self._prompts_loaded = chat_system_prompt is not None
        self._prompt_load_lock = asyncio.Lock()
        self._personality_prompt = self._build_personality_prompt()
        if chat_system_prompt is None:
@@ -392,7 +392,12 @@ class MaisakaChatLoopService:
        """Set the interrupt flag used by the current planner request."""
        self._interrupt_flag = interrupt_flag

    def _build_request_messages(self, selected_history: List[LLMContextMessage]) -> List[Message]:
    def _build_request_messages(
        self,
        selected_history: List[LLMContextMessage],
        *,
        system_prompt: Optional[str] = None,
    ) -> List[Message]:
        """Build the message list sent to the LLM.

        Args:
@@ -404,7 +409,7 @@ class MaisakaChatLoopService:

        messages: List[Message] = []
        system_msg = MessageBuilder().set_role(RoleType.System)
        system_msg.add_text_content(self._chat_system_prompt)
        system_msg.add_text_content(system_prompt if system_prompt is not None else self._chat_system_prompt)
        messages.append(system_msg.build())

        for msg in selected_history:
@@ -691,7 +696,13 @@ class MaisakaChatLoopService:

        return extract_category_ids_from_result(generation_result.response or "")

    async def chat_loop_step(self, chat_history: List[LLMContextMessage]) -> ChatResponse:
    async def chat_loop_step(
        self,
        chat_history: List[LLMContextMessage],
        *,
        response_format: RespFormat | None = None,
        tool_definitions: Sequence[ToolDefinitionInput] | None = None,
    ) -> ChatResponse:
        """Run one round of the Maisaka planner request.

        Args:
@@ -701,8 +712,9 @@ class MaisakaChatLoopService:
            ChatResponse: The result returned by this planner round.
        """

        await self.ensure_chat_prompt_loaded()
        selected_history, selection_reason = self._select_llm_context_messages(chat_history)
        if not self._prompts_loaded:
            await self.ensure_chat_prompt_loaded()
        selected_history, selection_reason = self.select_llm_context_messages(chat_history)
        built_messages = self._build_request_messages(selected_history)

        def message_factory(_client: BaseClient) -> List[Message]:
@@ -719,7 +731,9 @@ class MaisakaChatLoopService:
            return built_messages

        all_tools: List[ToolDefinitionInput]
        if self._tool_registry is not None:
        if tool_definitions is not None:
            all_tools = list(tool_definitions)
        elif self._tool_registry is not None:
            tool_specs = await self._tool_registry.list_tools()
            filtered_tool_specs = await self._filter_tool_specs_for_planner(selected_history, tool_specs)
            all_tools = [tool_spec.to_llm_definition() for tool_spec in filtered_tool_specs]
@@ -748,10 +762,10 @@ class MaisakaChatLoopService:

        ordered_panels = PromptCLIVisualizer.build_prompt_panels(
            built_messages,
            image_display_mode=global_config.maisaka.terminal_image_display_mode,
            image_display_mode="path_link" if global_config.maisaka.show_image_path else "legacy",
        )

        if global_config.maisaka.show_thinking and ordered_panels:
        if global_config.debug.show_maisaka_thinking and ordered_panels:
            console.print(
                Panel(
                    Group(*ordered_panels),
@@ -776,6 +790,7 @@ class MaisakaChatLoopService:
                tool_options=all_tools if all_tools else None,
                temperature=self._temperature,
                max_tokens=self._max_tokens,
                response_format=response_format,
                interrupt_flag=self._interrupt_flag,
            ),
        )
@@ -837,6 +852,40 @@ class MaisakaChatLoopService:
            total_tokens=total_tokens,
        )

    @staticmethod
    def select_llm_context_messages(
        chat_history: List[LLMContextMessage],
        *,
        max_context_size: Optional[int] = None,
    ) -> tuple[List[LLMContextMessage], str]:
        """Select the context messages actually sent to the LLM."""

        effective_context_size = max(1, int(max_context_size or global_config.chat.max_context_size))
        selected_indices: List[int] = []
        counted_message_count = 0

        for index in range(len(chat_history) - 1, -1, -1):
            message = chat_history[index]
            if message.to_llm_message() is None:
                continue

            selected_indices.append(index)
            if message.count_in_context:
                counted_message_count += 1
                if counted_message_count >= effective_context_size:
                    break

        if not selected_indices:
            return [], f"按上下文上限选取最近 {effective_context_size} 条 user/assistant,命中 0 条"

        selected_indices.reverse()
        selected_history = [chat_history[index] for index in selected_indices]
        selected_history = MaisakaChatLoopService._drop_leading_orphan_tool_results(selected_history)
        return (
            selected_history,
            f"按上下文上限选取最近 {effective_context_size} 条 user/assistant,清理孤立工具结果后保留 {len(selected_history)} 条",
        )

    @staticmethod
    def _select_llm_context_messages(chat_history: List[LLMContextMessage]) -> tuple[List[LLMContextMessage], str]:
        """Select the context messages actually sent to the LLM.
@@ -905,4 +954,4 @@ class MaisakaChatLoopService:

        if first_valid_index == 0:
            return selected_history
        return selected_history[first_valid_index:]
        return selected_history[first_valid_index:]

@@ -266,7 +266,7 @@ class MaisakaReasoningEngine:
            source_sequence = message.raw_message

            planner_components = clone_message_sequence(source_sequence).components
            if global_config.maisaka.direct_image_input:
            if global_config.chat.direct_image_input:
                await self._hydrate_visual_components(planner_components)
            if planner_components and isinstance(planner_components[0], TextComponent):
                planner_components[0].text = planner_prefix + planner_components[0].text
@@ -610,16 +610,8 @@ class MaisakaReasoningEngine:
            return f"你尝试回复消息 {target_message_id or 'unknown'},但失败了:{error_text}"

        if invocation.tool_name == "send_emoji":
            description = str(structured_content.get("description") or "").strip()
            emotion_list = structured_content.get("emotion")
            if isinstance(emotion_list, list):
                emotion_text = "、".join(str(item).strip() for item in emotion_list if str(item).strip())
            else:
                emotion_text = ""
            if result.success and description:
                if emotion_text:
                    return f"你发送了表情包:{description}(情绪:{emotion_text})"
                return f"你发送了表情包:{description}"
            if result.success:
                return "你发送了表情包。"
            return f"你尝试发送表情包,但失败了:{self._truncate_tool_record_text(result.error_message or history_content, 120)}"

        if invocation.tool_name == "wait":

@@ -1,6 +1,6 @@
"""Maisaka non-CLI runtime."""

from typing import Any, Literal, Optional
from typing import Any, Literal, Optional, Sequence

import asyncio
import time
@@ -20,12 +20,14 @@ from src.core.tooling import ToolRegistry
from src.know_u.knowledge import KnowledgeLearner
from src.learners.expression_learner import ExpressionLearner
from src.learners.jargon_miner import JargonMiner
from src.llm_models.payload_content.resp_format import RespFormat
from src.llm_models.payload_content.tool_option import ToolDefinitionInput
from src.mcp_module import MCPManager
from src.mcp_module.host_llm_bridge import MCPHostLLMBridge
from src.mcp_module.provider import MCPToolProvider
from src.plugin_runtime.tool_provider import PluginToolProvider

from .chat_loop_service import MaisakaChatLoopService
from .chat_loop_service import ChatResponse, MaisakaChatLoopService
from .context_messages import LLMContextMessage
from .reasoning_engine import MaisakaReasoningEngine
from .tool_provider import MaisakaBuiltinToolProvider
@@ -197,6 +199,40 @@ class MaisakaHeartFlowChatting:
        self._tool_registry.register_provider(PluginToolProvider())
        self._chat_loop_service.set_tool_registry(self._tool_registry)

    async def run_sub_agent(
        self,
        *,
        context_message_limit: int,
        system_prompt: str,
        extra_messages: Optional[Sequence[LLMContextMessage]] = None,
        max_tokens: int = 512,
        response_format: RespFormat | None = None,
        temperature: float = 0.2,
        tool_definitions: Optional[Sequence[ToolDefinitionInput]] = None,
    ) -> ChatResponse:
        """Run a short-lived sub-agent over a copy of the context, torn down right after it finishes."""

        selected_history, _ = MaisakaChatLoopService.select_llm_context_messages(
            self._chat_history,
            max_context_size=context_message_limit,
        )
        sub_agent_history = list(selected_history)
        if extra_messages:
            sub_agent_history.extend(list(extra_messages))

        sub_agent = MaisakaChatLoopService(
            chat_system_prompt=system_prompt,
            session_id=self.session_id,
            is_group_chat=self.chat_stream.is_group_session,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        return await sub_agent.chat_loop_step(
            sub_agent_history,
            response_format=response_format,
            tool_definitions=[] if tool_definitions is None else tool_definitions,
        )

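Editor's note: a hedged usage sketch of run_sub_agent; the `runtime` handle and prompt text are hypothetical, only the keyword arguments come from the signature above.

    # Throwaway classifier over a copy of the last 8 context messages.
    response = await runtime.run_sub_agent(
        context_message_limit=8,
        system_prompt="You are a short-lived classifier. Reply with one word.",
        max_tokens=64,
    )
    print(response.content)
    # tool_definitions defaults to [] inside chat_loop_step, so the sub-agent
    # cannot call tools unless definitions are passed in explicitly.
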
    async def _main_loop(self) -> None:
        try:
            while self._running:
@@ -421,7 +457,7 @@ class MaisakaHeartFlowChatting:
        if self.chat_stream.user_id:
            return UserInfo(
                user_id=self.chat_stream.user_id,
                user_nickname=global_config.maisaka.user_name.strip() or "用户",
                user_nickname=global_config.maisaka.cli_user_name.strip() or "用户",
                user_cardname=None,
            )
        return UserInfo(user_id="maisaka_user", user_nickname="用户", user_cardname=None)
@@ -455,7 +491,7 @@ class MaisakaHeartFlowChatting:
        tool_results: Optional[list[str]] = None,
    ) -> None:
        """Show the current chat stream's context usage, planning result, and tool summary in the terminal."""
        if not global_config.maisaka.show_thinking:
        if not global_config.debug.show_maisaka_thinking:
            return

        session_name = chat_manager.get_session_name(self.session_id) or self.session_id

@@ -458,6 +458,62 @@ class PluginRunner:
            logger.warning(f"插件配置归一化失败,将回退为原始配置: {exc}")
            return normalized_config, False

    @staticmethod
    def _merge_plugin_config_document(target: Any, source: Any) -> None:
        """Recursively update an existing TOML document, preserving its comments and formatting where possible.

        The strategy is to update existing keys and add missing ones instead of rewriting
        the document wholesale, so that when plugin startup persists back-filled defaults,
        hand-written user comments survive as far as possible.

        Args:
            target: Existing TOML document or table object.
            source: Latest configuration dict.
        """

        if isinstance(source, list) or not isinstance(source, dict) or not isinstance(target, dict):
            return

        for key, value in source.items():
            if key in target:
                target_value = target[key]
                if isinstance(value, dict) and isinstance(target_value, dict):
                    PluginRunner._merge_plugin_config_document(target_value, value)
                else:
                    try:
                        target[key] = tomlkit.item(value)
                    except (TypeError, ValueError):
                        target[key] = value
            else:
                try:
                    target[key] = tomlkit.item(value)
                except (TypeError, ValueError):
                    target[key] = value

    @staticmethod
    def _has_extra_config_keys(existing_config: Any, latest_config: Any) -> bool:
        """Check whether the existing config contains keys absent from the new config.

        If normalization removed some old keys, we must fall back to a full rewrite;
        an incremental merge alone would leave the stale keys in the file.

        Args:
            existing_config: Existing configuration dict.
            latest_config: Latest configuration dict.

        Returns:
            bool: Whether stale keys exist that only a full-file rewrite can remove.
        """

        if not isinstance(existing_config, dict) or not isinstance(latest_config, dict):
            return False

        for key, existing_value in existing_config.items():
            if key not in latest_config:
                return True
            if PluginRunner._has_extra_config_keys(existing_value, latest_config[key]):
                return True
        return False

    @staticmethod
    def _is_plugin_enabled(config_data: Optional[Mapping[str, Any]]) -> bool:
        """Decide from the config content whether the plugin should be treated as enabled.
@@ -496,6 +552,19 @@ class PluginRunner:

        config_path = Path(plugin_dir) / "config.toml"
        config_path.parent.mkdir(parents=True, exist_ok=True)
        if config_path.exists():
            try:
                with config_path.open("r", encoding="utf-8") as handle:
                    existing_document = tomlkit.load(handle)
                existing_config = existing_document.unwrap()
                if not PluginRunner._has_extra_config_keys(existing_config, config_data):
                    PluginRunner._merge_plugin_config_document(existing_document, config_data)
                    with config_path.open("w", encoding="utf-8") as handle:
                        handle.write(tomlkit.dumps(existing_document))
                    return
            except Exception as exc:
                logger.warning(f"保留插件配置注释失败,将回退为整文件重写: {config_path}: {exc}")

        with config_path.open("w", encoding="utf-8") as handle:
            handle.write(tomlkit.dumps(config_data))


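Editor's note: a minimal sketch of why the merge path preserves comments where a full rewrite would not (document content is hypothetical; tomlkit round-trips comments through parse/dumps):

    import tomlkit

    doc = tomlkit.parse("# hand-written note\n[plugin]\nenabled = true\n")
    PluginRunner._merge_plugin_config_document(doc, {"plugin": {"enabled": True, "timeout": 30}})
    print(tomlkit.dumps(doc))
    # "# hand-written note" survives and timeout = 30 is appended; dumping a plain
    # dict instead would emit a fresh document with no comments at all.
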
src/services/embedding_service.py (new file, +160 lines)
@@ -0,0 +1,160 @@
"""Embedding service layer.

This module funnels unified text-embedding requests on the host side and
forwards them to the underlying embedding orchestrator in `src.llm_models`.
"""

from __future__ import annotations

from typing import Any, Coroutine, List, TypeVar

import asyncio

from src.common.data_models.embedding_service_data_models import EmbeddingResult
from src.common.logger import get_logger
from src.llm_models.utils_model import LLMOrchestrator
from src.services.service_task_resolver import resolve_task_name

logger = get_logger("embedding_service")

_CoroutineReturnT = TypeVar("_CoroutineReturnT")


class EmbeddingServiceClient:
    """Object-style facade of the embedding service for upper-layer modules."""

    def __init__(self, task_name: str = "embedding", request_type: str = "") -> None:
        """Initialize the embedding service facade.

        Args:
            task_name: Task config name, matching a field under `model_task_config`.
            request_type: Business-type tag for this request.
        """
        self.task_name = resolve_task_name(task_name)
        self.request_type = request_type
        self._orchestrator = LLMOrchestrator(task_name=self.task_name, request_type=request_type)

    async def embed_text(self, embedding_input: str) -> EmbeddingResult:
        """Generate the embedding vector for a single text.

        Args:
            embedding_input: Text to encode.

        Returns:
            EmbeddingResult: Unified embedding result object.
        """
        raw_result = await self._orchestrator.get_embedding(embedding_input)
        return EmbeddingResult(
            embedding=list(raw_result.embedding),
            model_name=raw_result.model_name,
        )

    async def embed_texts(
        self,
        embedding_inputs: List[str],
        max_concurrent: int | None = None,
    ) -> List[EmbeddingResult]:
        """Generate embedding vectors in batch.

        Args:
            embedding_inputs: Texts to encode.
            max_concurrent: Maximum concurrency; runs serially when omitted.

        Returns:
            List[EmbeddingResult]: Results in the same order as the inputs.
        """
        if not embedding_inputs:
            return []

        safe_max_concurrent = max(1, int(max_concurrent or 1))
        if safe_max_concurrent == 1:
            results: List[EmbeddingResult] = []
            for embedding_input in embedding_inputs:
                results.append(await self.embed_text(embedding_input))
            return results

        semaphore = asyncio.Semaphore(safe_max_concurrent)

        async def _embed_one(index: int, embedding_input: str) -> tuple[int, EmbeddingResult]:
            """Embed one input while keeping its original order index.

            Args:
                index: Original input index.
                embedding_input: Text to encode.

            Returns:
                tuple[int, EmbeddingResult]: Input index with its embedding result.
            """
            async with semaphore:
                result = await self.embed_text(embedding_input)
            return index, result

        ordered_results = await asyncio.gather(
            *[_embed_one(index, embedding_input) for index, embedding_input in enumerate(embedding_inputs)]
        )
        ordered_results.sort(key=lambda item: item[0])
        return [result for _, result in ordered_results]

    def embed_text_sync(self, embedding_input: str) -> EmbeddingResult:
        """Synchronously generate the embedding vector for a single text.

        Args:
            embedding_input: Text to encode.

        Returns:
            EmbeddingResult: Unified embedding result object.
        """
        return self._run_coroutine_sync(self.embed_text(embedding_input))

    def embed_texts_sync(
        self,
        embedding_inputs: List[str],
        max_concurrent: int | None = None,
    ) -> List[EmbeddingResult]:
        """Synchronously generate embedding vectors in batch.

        Args:
            embedding_inputs: Texts to encode.
            max_concurrent: Maximum concurrency; runs serially when omitted.

        Returns:
            List[EmbeddingResult]: Results in the same order as the inputs.
        """
        return self._run_coroutine_sync(
            self.embed_texts(
                embedding_inputs,
                max_concurrent=max_concurrent,
            )
        )

    @staticmethod
    def _run_coroutine_sync(coroutine: Coroutine[Any, Any, _CoroutineReturnT]) -> _CoroutineReturnT:
        """Run a coroutine on an isolated event loop.

        Args:
            coroutine: Coroutine to execute synchronously.

        Returns:
            _CoroutineReturnT: The coroutine's return value.

        Raises:
            RuntimeError: Raised when the current thread already has a running event loop.
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            pass
        else:
            raise RuntimeError("当前线程存在运行中的事件循环,请改用异步 Embedding 接口")

        loop = asyncio.new_event_loop()
        try:
            asyncio.set_event_loop(loop)
            return loop.run_until_complete(coroutine)
        finally:
            try:
                loop.run_until_complete(loop.shutdown_asyncgens())
            except Exception as exc:
                logger.debug(f"关闭 EmbeddingService 临时异步生成器失败: {exc}")
            asyncio.set_event_loop(None)
            loop.close()
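Editor's note: a hedged usage sketch of the new client; it assumes an `embedding` task exists under `model_task_config` in the deployed config.

    from src.services.embedding_service import EmbeddingServiceClient

    client = EmbeddingServiceClient(task_name="embedding", request_type="demo")

    # Sync path: only valid when no event loop is running in this thread.
    result = client.embed_text_sync("hello world")
    print(result.model_name, len(result.embedding))

    # Async path, inside a coroutine; result order matches the input order:
    # results = await client.embed_texts(["a", "b", "c"], max_concurrent=4)
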
@@ -8,9 +8,9 @@ from typing import Any, Dict, List, Tuple

import json

from src.common.data_models.embedding_service_data_models import EmbeddingResult
from src.common.data_models.llm_service_data_models import (
    LLMAudioTranscriptionResult,
    LLMEmbeddingResult,
    LLMGenerationOptions,
    LLMImageOptions,
    LLMResponseResult,
@@ -21,15 +21,20 @@ from src.common.data_models.llm_service_data_models import (
    PromptMessage,
)
from src.common.logger import get_logger
from src.config.config import config_manager
from src.config.model_configs import TaskConfig
from src.llm_models.model_client.base_client import BaseClient
from src.llm_models.payload_content.message import Message, MessageBuilder, RoleType
from src.llm_models.payload_content.tool_option import ToolCall
from src.llm_models.utils_model import LLMOrchestrator
from src.services.embedding_service import EmbeddingServiceClient
from src.services.service_task_resolver import (
    get_available_models as _get_available_models,
    resolve_task_name as _resolve_task_name,
    resolve_task_name_from_model_config as _resolve_task_name_from_model_config,
)

logger = get_logger("llm_service")


class LLMServiceClient:
    """Object-style facade of the LLM service for upper-layer modules.

@@ -38,7 +43,7 @@ class LLMServiceClient:
    - `generate_response_with_messages`
    - `generate_response_for_image`
    - `transcribe_audio`
    - `embed_text`
    - `embed_text` (compatibility entry point; prefer `EmbeddingServiceClient`)
    """

    def __init__(self, task_name: str, request_type: str = "") -> None:
@@ -48,7 +53,7 @@ class LLMServiceClient:
            task_name: Task config name, matching a field under `model_task_config`.
            request_type: Business-type tag for this request.
        """
        self.task_name = resolve_task_name(task_name)
        self.task_name = _resolve_task_name(task_name)
        self.request_type = request_type
        self._orchestrator = LLMOrchestrator(task_name=self.task_name, request_type=request_type)

@@ -169,41 +174,29 @@ class LLMServiceClient:
        """
        return await self._orchestrator.generate_response_for_voice(voice_base64)

    async def embed_text(self, embedding_input: str) -> LLMEmbeddingResult:
        """Generate a text embedding vector.
    async def embed_text(self, embedding_input: str) -> EmbeddingResult:
        """Text-embedding entry point kept for legacy callers.

        Args:
            embedding_input: Text to encode.

        Returns:
            LLMEmbeddingResult: Vector generation result object.
            EmbeddingResult: Vector generation result object.
        """
        return await self._orchestrator.get_embedding(embedding_input)
        embedding_client = EmbeddingServiceClient(
            task_name=self.task_name,
            request_type=self.request_type,
        )
        return await embedding_client.embed_text(embedding_input)


def get_available_models() -> Dict[str, TaskConfig]:
def get_available_models() -> Dict[str, Any]:
    """Get all available model configurations.

    Returns:
        Dict[str, TaskConfig]: Config mapping keyed by model task name.
        Dict[str, Any]: Config mapping keyed by model task name.
    """
    try:
        models = config_manager.get_model_config().model_task_config
        available_models: Dict[str, TaskConfig] = {}
        for attr_name in dir(models):
            if attr_name.startswith("__"):
                continue
            try:
                attr_value = getattr(models, attr_name)
            except Exception as exc:
                logger.debug(f"[LLMService] 获取属性 {attr_name} 失败: {exc}")
                continue
            if not callable(attr_value) and isinstance(attr_value, TaskConfig):
                available_models[attr_name] = attr_value
        return available_models
    except Exception as exc:
        logger.error(f"[LLMService] 获取可用模型失败: {exc}")
        return {}
    return _get_available_models()


def resolve_task_name(task_name: str = "") -> str:
@@ -214,75 +207,24 @@ def resolve_task_name(task_name: str = "") -> str:

    Returns:
        str: Resolved task config name.

    Raises:
        RuntimeError: No model configuration is currently available.
        ValueError: Raised when the given name does not exist.
    """
    models = get_available_models()
    if not models:
        raise RuntimeError("没有可用的模型配置")
    normalized_task_name = task_name.strip()
    if not normalized_task_name:
        return next(iter(models.keys()))
    if normalized_task_name not in models:
        raise ValueError(f"未找到名为 `{normalized_task_name}` 的模型配置")
    return normalized_task_name
    return _resolve_task_name(task_name)


def resolve_task_name_from_model_config(model_config: Any, preferred_task_name: str = "") -> str:
    """Resolve a usable task name from legacy `TaskConfig`-style arguments.

    This keeps compatibility with callers that still pass `model_config`:
    1. prefer an explicitly given `preferred_task_name`;
    2. then match by object identity;
    3. then try an exact `model_list` match;
    4. finally map approximately via the first model in `model_list` that matches.

    Args:
        model_config: Task config object held by the legacy caller.
        preferred_task_name: Candidate task name (optional).

    Returns:
        str: A task name usable as `LLMServiceRequest.task_name`.

    Raises:
        RuntimeError: No model configuration is currently available.
        ValueError: Raised when no usable task name can be resolved.
    """
    models = get_available_models()
    if not models:
        raise RuntimeError("没有可用的模型配置")

    normalized_preferred = str(preferred_task_name or "").strip()
    if normalized_preferred and normalized_preferred in models:
        return normalized_preferred

    for task_name, task_cfg in models.items():
        if task_cfg is model_config:
            return task_name

    requested_model_list_raw = getattr(model_config, "model_list", [])
    requested_model_list = [str(item).strip() for item in (requested_model_list_raw or []) if str(item).strip()]
    if requested_model_list:
        for task_name, task_cfg in models.items():
            candidate_list = [str(item).strip() for item in getattr(task_cfg, "model_list", []) if str(item).strip()]
            if candidate_list == requested_model_list:
                return task_name

        for requested_model in requested_model_list:
            for task_name, task_cfg in models.items():
                candidate_list = [str(item).strip() for item in getattr(task_cfg, "model_list", []) if str(item).strip()]
                if requested_model in candidate_list:
                    logger.info(
                        "[LLMService] 旧版 model_config 未命中任务配置,"
                        f"按模型 `{requested_model}` 近似映射到任务 `{task_name}`"
                    )
                    return task_name

    if normalized_preferred:
        logger.warning(f"[LLMService] 无法映射旧版 model_config,回退默认任务: preferred={normalized_preferred}")
    return resolve_task_name("")
    return _resolve_task_name_from_model_config(
        model_config=model_config,
        preferred_task_name=preferred_task_name,
    )


def _normalize_role(role_name: str) -> RoleType:

src/services/service_task_resolver.py (new file, +108 lines)
@@ -0,0 +1,108 @@
"""Model-task resolution helpers for the service layer."""

from typing import Any, Dict

from src.common.logger import get_logger
from src.config.config import config_manager
from src.config.model_configs import TaskConfig

logger = get_logger("service_task_resolver")


def get_available_models() -> Dict[str, TaskConfig]:
    """Get all currently available model task configurations.

    Returns:
        Dict[str, TaskConfig]: Available task configs keyed by task name.
    """
    try:
        models = config_manager.get_model_config().model_task_config
        available_models: Dict[str, TaskConfig] = {}
        for attr_name in dir(models):
            if attr_name.startswith("__"):
                continue
            try:
                attr_value = getattr(models, attr_name)
            except Exception as exc:
                logger.debug(f"获取模型任务配置属性 {attr_name} 失败: {exc}")
                continue
            if not callable(attr_value) and isinstance(attr_value, TaskConfig):
                available_models[attr_name] = attr_value
        return available_models
    except Exception as exc:
        logger.error(f"获取可用模型配置失败: {exc}")
        return {}


def resolve_task_name(task_name: str = "") -> str:
    """Resolve the actually usable model task name from a task name.

    Args:
        task_name: Target task name; when empty, the first available task is returned.

    Returns:
        str: Resolved model task name.

    Raises:
        RuntimeError: Raised when no model configuration is available.
        ValueError: Raised when the given task name does not exist.
    """
    models = get_available_models()
    if not models:
        raise RuntimeError("没有可用的模型配置")

    normalized_task_name = task_name.strip()
    if not normalized_task_name:
        return next(iter(models.keys()))
    if normalized_task_name not in models:
        raise ValueError(f"未找到名为 `{normalized_task_name}` 的模型配置")
    return normalized_task_name


def resolve_task_name_from_model_config(model_config: Any, preferred_task_name: str = "") -> str:
    """Resolve a task name from a legacy model-config object.

    Args:
        model_config: Task config object held by the legacy caller.
        preferred_task_name: Candidate task name.

    Returns:
        str: Resolved model task name.

    Raises:
        RuntimeError: Raised when no model configuration is available.
        ValueError: Raised when no usable task name can be resolved.
    """
    models = get_available_models()
    if not models:
        raise RuntimeError("没有可用的模型配置")

    normalized_preferred = str(preferred_task_name or "").strip()
    if normalized_preferred and normalized_preferred in models:
        return normalized_preferred

    for task_name, task_cfg in models.items():
        if task_cfg is model_config:
            return task_name

    requested_model_list_raw = getattr(model_config, "model_list", [])
    requested_model_list = [str(item).strip() for item in (requested_model_list_raw or []) if str(item).strip()]
    if requested_model_list:
        for task_name, task_cfg in models.items():
            candidate_list = [str(item).strip() for item in getattr(task_cfg, "model_list", []) if str(item).strip()]
            if candidate_list == requested_model_list:
                return task_name

        for requested_model in requested_model_list:
            for task_name, task_cfg in models.items():
                candidate_list = [str(item).strip() for item in getattr(task_cfg, "model_list", []) if str(item).strip()]
                if requested_model in candidate_list:
                    logger.info(
                        "旧版 model_config 未命中任务配置,"
                        f"按模型 `{requested_model}` 近似映射到任务 `{task_name}`"
                    )
                    return task_name

    if normalized_preferred:
        logger.warning(f"无法映射旧版 model_config,回退默认任务: preferred={normalized_preferred}")
    return resolve_task_name("")
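Editor's note: a hedged sketch of the resolver behaviour; the task names are illustrative and depend on the deployed `model_task_config`.

    from src.services.service_task_resolver import get_available_models, resolve_task_name

    print(list(get_available_models()))    # e.g. ["chat", "embedding", ...]
    print(resolve_task_name(""))           # empty name falls back to the first available task
    print(resolve_task_name("embedding"))  # known names pass through after validation
    # Unknown names raise ValueError; an empty model config raises RuntimeError.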