feat: Enhance OpenAI compatibility and introduce unified LLM service data models

- Refactored model fetching logic to support various authentication methods for OpenAI-compatible APIs.
- Introduced new data models for LLM service requests and responses to standardize interactions across layers.
- Added an adapter base class for unified request execution across different providers.
- Implemented utility functions for building OpenAI-compatible client configurations and request overrides.
DrSmoothl
2026-03-26 16:15:42 +08:00
parent 6e7daae55d
commit 777d4cb0d2
48 changed files with 5443 additions and 2945 deletions
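
The recurring change across these diffs is a call-site migration: LLMRequest(model_set=...) with tuple-unpacked generate_response_async(...) becomes LLMServiceClient(task_name=...) returning a structured result object. A minimal sketch of the new pattern, using only names that appear in the diffs below; the surrounding setup is illustrative, not code from this commit:

# Sketch of the migration pattern repeated in the diffs below; the class,
# method and option names are taken from this commit, the rest is illustrative.
from src.common.data_models.llm_service_data_models import LLMGenerationOptions
from src.services.llm_service import LLMServiceClient

async def plan(prompt: str) -> str:
    # Before: llm = LLMRequest(model_set=model_config.model_task_config.planner, ...)
    #         content, (reasoning, model_name, tool_calls) = await llm.generate_response_async(prompt)
    llm = LLMServiceClient(task_name="planner", request_type="action_planning")
    result = await llm.generate_response(
        prompt, options=LLMGenerationOptions(temperature=0.6, max_tokens=1024)
    )
    # reasoning, model_name and tool_calls are now named fields on the result.
    return result.response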

View File

@@ -19,7 +19,7 @@ dependencies = [
"jieba>=0.42.1",
"json-repair>=0.47.6",
"maim-message>=0.6.2",
"maibot-plugin-sdk>=2.0.0",
"maibot-plugin-sdk>=2.1.0",
"mcp",
"msgpack>=1.1.2",
"numpy>=2.2.6",

View File

@@ -2,8 +2,8 @@ import time
from typing import Tuple, Optional # 增加了 Optional
from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
from src.services.llm_service import LLMServiceClient
from src.config.config import global_config
import random
from .chat_observer import ChatObserver
from .pfc_utils import get_items_from_json
@@ -109,8 +109,8 @@ class ActionPlanner:
"""行动规划器"""
def __init__(self, stream_id: str, private_name: str):
self.llm = LLMRequest(
model_set=model_config.model_task_config.planner,
self.llm = LLMServiceClient(
task_name="planner",
request_type="action_planning",
)
self.personality_info = self._get_personality_prompt()
@@ -398,7 +398,8 @@ class ActionPlanner:
logger.debug(f"[私聊][{self.private_name}]发送到LLM的最终提示词:\n------\n{prompt}\n------")
try:
content, _ = await self.llm.generate_response_async(prompt)
generation_result = await self.llm.generate_response(prompt)
content = generation_result.response
logger.debug(f"[私聊][{self.private_name}]LLM (行动规划) 原始返回内容: {content}")
# --- 初始行动规划解析 ---
@@ -427,7 +428,8 @@ class ActionPlanner:
f"[私聊][{self.private_name}]发送到LLM的结束决策提示词:\n------\n{end_decision_prompt}\n------"
)
try:
end_content, _ = await self.llm.generate_response_async(end_decision_prompt) # 再次调用LLM
end_generation_result = await self.llm.generate_response(end_decision_prompt)
end_content = end_generation_result.response # 再次调用LLM
logger.debug(f"[私聊][{self.private_name}]LLM (结束决策) 原始返回内容: {end_content}")
# 解析结束决策的JSON

View File

@@ -1,7 +1,7 @@
from typing import List, Tuple, TYPE_CHECKING
from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
from src.services.llm_service import LLMServiceClient
from src.config.config import global_config
import random
from .chat_observer import ChatObserver
from .pfc_utils import get_items_from_json
@@ -43,7 +43,9 @@ class GoalAnalyzer:
"""对话目标分析器"""
def __init__(self, stream_id: str, private_name: str):
self.llm = LLMRequest(model_set=model_config.model_task_config.planner, request_type="conversation_goal")
self.llm = LLMServiceClient(
task_name="planner", request_type="conversation_goal"
)
self.personality_info = self._get_personality_prompt()
self.name = global_config.bot.nickname
@@ -157,7 +159,8 @@ class GoalAnalyzer:
logger.debug(f"[私聊][{self.private_name}]发送到LLM的提示词: {prompt}")
try:
content, _ = await self.llm.generate_response_async(prompt)
generation_result = await self.llm.generate_response(prompt)
content = generation_result.response
logger.debug(f"[私聊][{self.private_name}]LLM原始返回内容: {content}")
except Exception as e:
logger.error(f"[私聊][{self.private_name}]分析对话目标时出错: {str(e)}")
@@ -271,7 +274,8 @@ class GoalAnalyzer:
}}"""
try:
content, _ = await self.llm.generate_response_async(prompt)
generation_result = await self.llm.generate_response(prompt)
content = generation_result.response
logger.debug(f"[私聊][{self.private_name}]LLM原始返回内容: {content}")
# 尝试解析JSON

View File

@@ -3,8 +3,7 @@ from src.common.logger import get_logger
# NOTE: HippocampusManager doesn't exist in v0.12.2 - memory system was redesigned
# from src.plugins.memory_system.Hippocampus import HippocampusManager
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config
from src.services.llm_service import LLMServiceClient
from src.chat.knowledge import qa_manager
logger = get_logger("knowledge_fetcher")
@@ -14,7 +13,7 @@ class KnowledgeFetcher:
"""知识调取器"""
def __init__(self, private_name: str):
self.llm = LLMRequest(model_set=model_config.model_task_config.utils)
self.llm = LLMServiceClient(task_name="utils")
self.private_name = private_name
def _lpmm_get_knowledge(self, query: str) -> str:

View File

@@ -2,8 +2,8 @@ import json
import random
from typing import Tuple, List, Dict, Any
from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
from src.services.llm_service import LLMServiceClient
from src.config.config import global_config
from .chat_observer import ChatObserver
from maim_message import UserInfo
@@ -14,7 +14,7 @@ class ReplyChecker:
"""回复检查器"""
def __init__(self, stream_id: str, private_name: str):
self.llm = LLMRequest(model_set=model_config.model_task_config.utils, request_type="reply_check")
self.llm = LLMServiceClient(task_name="utils", request_type="reply_check")
self.personality_info = self._get_personality_prompt()
self.name = global_config.bot.nickname
self.private_name = private_name
@@ -137,7 +137,8 @@ class ReplyChecker:
注意,请严格按照JSON格式输出,不要包含任何其他内容。"""
try:
content, _ = await self.llm.generate_response_async(prompt)
generation_result = await self.llm.generate_response(prompt)
content = generation_result.response
logger.debug(f"[私聊][{self.private_name}]检查回复的原始返回: {content}")
# 清理内容尝试提取JSON部分

View File

@@ -1,7 +1,7 @@
from typing import Tuple, List, Dict, Any
from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
from src.services.llm_service import LLMServiceClient
from src.config.config import global_config
import random
from .chat_observer import ChatObserver
from .reply_checker import ReplyChecker
@@ -87,8 +87,8 @@ class ReplyGenerator:
"""回复生成器"""
def __init__(self, stream_id: str, private_name: str):
self.llm = LLMRequest(
model_set=model_config.model_task_config.replyer,
self.llm = LLMServiceClient(
task_name="replyer",
request_type="reply_generation",
)
self.personality_info = self._get_personality_prompt()
@@ -223,7 +223,8 @@ class ReplyGenerator:
# --- 调用 LLM 生成 ---
logger.debug(f"[私聊][{self.private_name}]发送到LLM的生成提示词:\n------\n{prompt}\n------")
try:
content, _ = await self.llm.generate_response_async(prompt)
generation_result = await self.llm.generate_response(prompt)
content = generation_result.response
logger.debug(f"[私聊][{self.private_name}]生成的回复: {content}")
# 移除旧的检查新消息逻辑,这应该由 conversation 控制流处理
return content

View File

@@ -17,9 +17,9 @@ from src.chat.utils.utils import get_chat_type_and_target_info
from src.common.data_models.info_data_model import ActionPlannerInfo
from src.common.logger import get_logger
from src.common.utils.utils_action import ActionUtils
from src.config.config import global_config, model_config
from src.config.config import global_config
from src.core.types import ActionActivationType, ActionInfo, ComponentType
from src.llm_models.utils_model import LLMRequest
from src.services.llm_service import LLMServiceClient
from src.plugin_runtime.component_query import component_query_service
from src.prompt.prompt_manager import prompt_manager
from src.services.message_service import (
@@ -43,8 +43,8 @@ class BrainPlanner:
self.log_prefix = f"[{_chat_manager.get_session_name(chat_id) or chat_id}]"
self.action_manager = action_manager
# LLM规划器配置
self.planner_llm = LLMRequest(
model_set=model_config.model_task_config.planner, request_type="planner"
self.planner_llm = LLMServiceClient(
task_name="planner", request_type="planner"
) # 用于动作规划
self.last_obs_time_mark = 0.0
@@ -412,7 +412,9 @@ class BrainPlanner:
try:
# 调用LLM
llm_start = time.perf_counter()
llm_content, (reasoning_content, _, _) = await self.planner_llm.generate_response_async(prompt=prompt)
generation_result = await self.planner_llm.generate_response(prompt=prompt)
llm_content = generation_result.response
reasoning_content = generation_result.reasoning
llm_duration_ms = (time.perf_counter() - llm_start) * 1000
llm_reasoning = reasoning_content

View File

@@ -17,8 +17,9 @@ from src.common.database.database_model import Images, ImageType
from src.common.database.database import get_db_session, get_db_session_manual
from src.common.utils.utils_image import ImageUtils
from src.prompt.prompt_manager import prompt_manager
from src.config.config import config_manager, global_config, model_config
from src.llm_models.utils_model import LLMRequest
from src.config.config import config_manager, global_config
from src.common.data_models.llm_service_data_models import LLMGenerationOptions, LLMImageOptions
from src.services.llm_service import LLMServiceClient
logger = get_logger("emoji")
@@ -38,8 +39,10 @@ def _ensure_directories() -> None:
# TODO: 修改这个vlm为获取的vlm client,暂时使用这个VLM方法
emoji_manager_vlm = LLMRequest(model_set=model_config.model_task_config.vlm, request_type="emoji.see")
emoji_manager_emotion_judge_llm = LLMRequest(model_set=model_config.model_task_config.utils, request_type="emoji")
emoji_manager_vlm = LLMServiceClient(task_name="vlm", request_type="emoji.see")
emoji_manager_emotion_judge_llm = LLMServiceClient(
task_name="utils", request_type="emoji"
)
class EmojiManager:
@@ -461,9 +464,11 @@ class EmojiManager:
emoji_replace_prompt_template.add_context("emoji_list", "\n".join(emoji_info_list))
emoji_replace_prompt = await prompt_manager.render_prompt(emoji_replace_prompt_template)
decision, _ = await emoji_manager_emotion_judge_llm.generate_response_async(
emoji_replace_prompt, temperature=0.8, max_tokens=600
decision_result = await emoji_manager_emotion_judge_llm.generate_response(
emoji_replace_prompt,
options=LLMGenerationOptions(temperature=0.8, max_tokens=600),
)
decision = decision_result.response
logger.info(f"[决策] 结果: {decision}")
# 解析决策结果
@@ -524,24 +529,36 @@ class EmojiManager:
return False, target_emoji
prompt: str = "这是一个动态图表情包每一张图代表了动态图的某一帧黑色背景代表透明简短描述一下表情包表达的情感和内容从互联网梗、meme的角度去分析精简回答"
image_base64 = ImageUtils.image_bytes_to_base64(image_bytes)
description, _ = await emoji_manager_vlm.generate_response_for_image(
prompt, image_base64, "jpg", temperature=0.5
description_result = await emoji_manager_vlm.generate_response_for_image(
prompt,
image_base64,
"jpg",
options=LLMImageOptions(temperature=0.5),
)
description = description_result.response
else:
prompt: str = "这是一个表情包请详细描述一下表情包所表达的情感和内容简短描述细节从互联网梗、meme的角度去分析精简回答"
image_base64 = ImageUtils.image_bytes_to_base64(image_bytes)
description, _ = await emoji_manager_vlm.generate_response_for_image(
prompt, image_base64, image_format, temperature=0.5
description_result = await emoji_manager_vlm.generate_response_for_image(
prompt,
image_base64,
image_format,
options=LLMImageOptions(temperature=0.5),
)
description = description_result.response
# 表情包审查
if global_config.emoji.content_filtration:
filtration_prompt_template = prompt_manager.get_prompt("emoji_content_filtration")
filtration_prompt_template.add_context("demand", global_config.emoji.filtration_prompt)
filtration_prompt = await prompt_manager.render_prompt(filtration_prompt_template)
llm_response, _ = await emoji_manager_vlm.generate_response_for_image(
filtration_prompt, image_base64, image_format, temperature=0.3
filtration_result = await emoji_manager_vlm.generate_response_for_image(
filtration_prompt,
image_base64,
image_format,
options=LLMImageOptions(temperature=0.3),
)
llm_response = filtration_result.response
if "" in llm_response:
logger.warning(f"[表情包审查] 表情包内容不符合要求,拒绝注册: {target_emoji.file_name}")
return False, target_emoji
@@ -567,9 +584,11 @@ class EmojiManager:
emotion_prompt_template.add_context("description", target_emoji.description)
emotion_prompt = await prompt_manager.render_prompt(emotion_prompt_template)
# 调用LLM生成情感标签
emotion_result, _ = await emoji_manager_emotion_judge_llm.generate_response_async(
emotion_prompt, temperature=0.3, max_tokens=200
emotion_generation_result = await emoji_manager_emotion_judge_llm.generate_response(
emotion_prompt,
options=LLMGenerationOptions(temperature=0.3, max_tokens=200),
)
emotion_result = emotion_generation_result.response
# 解析情感标签结果
emotions = [e.strip() for e in emotion_result.replace(",", ",").split(",") if e.strip()]

View File

@@ -11,8 +11,9 @@ from src.common.logger import get_logger
from src.common.database.database import get_db_session
from src.common.database.database_model import Images, ImageType
from src.common.data_models.image_data_model import MaiImage
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config
from src.common.data_models.llm_service_data_models import LLMImageOptions
from src.services.llm_service import LLMServiceClient
install(extra_lines=3)
@@ -27,7 +28,7 @@ def _ensure_image_dir_exists():
IMAGE_DIR.mkdir(parents=True, exist_ok=True)
vlm = LLMRequest(model_set=model_config.model_task_config.vlm, request_type="image")
vlm = LLMServiceClient(task_name="vlm", request_type="image")
class ImageManager:
@@ -260,7 +261,13 @@ class ImageManager:
prompt = global_config.personality.visual_style
image_base64 = base64.b64encode(image_bytes).decode("utf-8")
description, _ = await vlm.generate_response_for_image(prompt, image_base64, image_format, 0.4)
generation_result = await vlm.generate_response_for_image(
prompt,
image_base64,
image_format,
options=LLMImageOptions(temperature=0.4),
)
description = generation_result.response
if not description:
logger.warning("VLM未能生成图片描述")
return description or ""

View File

@@ -139,14 +139,14 @@ class EmbeddingStore:
asyncio.set_event_loop(loop)
try:
# 创建新的LLMRequest实例
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config
# 创建新的服务层实例
from src.services.llm_service import LLMServiceClient
llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding")
llm = LLMServiceClient(task_name="embedding", request_type="embedding")
# 使用新的事件循环运行异步方法
embedding, _ = loop.run_until_complete(llm.get_embedding(s))
embedding_result = loop.run_until_complete(llm.embed_text(s))
embedding = embedding_result.embedding
if embedding and len(embedding) > 0:
return embedding
@@ -195,13 +195,12 @@ class EmbeddingStore:
start_idx, chunk_strs = chunk_data
chunk_results = []
# 为每个线程创建独立的LLMRequest实例
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config
# 为每个线程创建独立的服务层实例
from src.services.llm_service import LLMServiceClient
try:
# 创建线程专用的LLM实例
llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding")
# 创建线程专用的服务层实例
llm = LLMServiceClient(task_name="embedding", request_type="embedding")
for i, s in enumerate(chunk_strs):
try:
@@ -209,7 +208,8 @@ class EmbeddingStore:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
embedding = loop.run_until_complete(llm.get_embedding(s))
embedding_result = loop.run_until_complete(llm.embed_text(s))
embedding = embedding_result.embedding
finally:
loop.close()

View File

@@ -1,18 +1,27 @@
import asyncio
import json
import time
from typing import List, Union
from typing import Dict, List, Tuple, Union
from .global_logger import logger
from . import prompt_template
from . import INVALID_ENTITY
from src.llm_models.utils_model import LLMRequest
from json_repair import repair_json
from src.services.llm_service import LLMServiceClient
def _extract_json_from_text(text: str):
from . import INVALID_ENTITY
from . import prompt_template
from .global_logger import logger
def _extract_json_from_text(text: str) -> List[str] | List[List[str]] | Dict[str, object]:
# sourcery skip: assign-if-exp, extract-method
"""从文本中提取JSON数据的高容错方法"""
"""从文本中提取 JSON 数据
Args:
text: 原始模型输出文本。
Returns:
List[str] | List[List[str]] | Dict[str, object]: 修复并解析后的 JSON 结果。
"""
if text is None:
logger.error("输入文本为None")
return []
@@ -46,20 +55,30 @@ def _extract_json_from_text(text: str):
return []
def _entity_extract(llm_req: LLMRequest, paragraph: str) -> List[str]:
def _entity_extract(llm_req: LLMServiceClient, paragraph: str) -> List[str]:
# sourcery skip: reintroduce-else, swap-if-else-branches, use-named-expression
"""段落进行实体提取返回提取出的实体列表JSON格式"""
"""单段文本执行实体提取。
Args:
llm_req: LLM 服务门面实例。
paragraph: 待提取实体的原始段落文本。
Returns:
List[str]: 提取出的实体列表。
"""
entity_extract_context = prompt_template.build_entity_extract_context(paragraph)
# 使用 asyncio.run 来运行异步方法
try:
# 如果当前已有事件循环在运行,使用它
loop = asyncio.get_running_loop()
future = asyncio.run_coroutine_threadsafe(llm_req.generate_response_async(entity_extract_context), loop)
response, _ = future.result()
future = asyncio.run_coroutine_threadsafe(llm_req.generate_response(entity_extract_context), loop)
generation_result = future.result()
response = generation_result.response
except RuntimeError:
# 如果没有运行中的事件循环,直接使用 asyncio.run
response, _ = asyncio.run(llm_req.generate_response_async(entity_extract_context))
generation_result = asyncio.run(llm_req.generate_response(entity_extract_context))
response = generation_result.response
# 添加调试日志
logger.debug(f"LLM返回的原始响应: {response}")
@@ -92,8 +111,21 @@ def _entity_extract(llm_req: LLMRequest, paragraph: str) -> List[str]:
return entity_extract_result
def _rdf_triple_extract(llm_req: LLMRequest, paragraph: str, entities: list) -> List[List[str]]:
"""对段落进行实体提取返回提取出的实体列表JSON格式"""
def _rdf_triple_extract(
llm_req: LLMServiceClient,
paragraph: str,
entities: List[str],
) -> List[List[str]]:
"""对单段文本执行 RDF 三元组提取。
Args:
llm_req: LLM 服务门面实例。
paragraph: 待提取的原始段落文本。
entities: 已识别出的实体列表。
Returns:
List[List[str]]: 提取出的三元组列表。
"""
rdf_extract_context = prompt_template.build_rdf_triple_extract_context(
paragraph, entities=json.dumps(entities, ensure_ascii=False)
)
@@ -102,11 +134,13 @@ def _rdf_triple_extract(llm_req: LLMRequest, paragraph: str, entities: list) ->
try:
# 如果当前已有事件循环在运行,使用它
loop = asyncio.get_running_loop()
future = asyncio.run_coroutine_threadsafe(llm_req.generate_response_async(rdf_extract_context), loop)
response, _ = future.result()
future = asyncio.run_coroutine_threadsafe(llm_req.generate_response(rdf_extract_context), loop)
generation_result = future.result()
response = generation_result.response
except RuntimeError:
# 如果没有运行中的事件循环,直接使用 asyncio.run
response, _ = asyncio.run(llm_req.generate_response_async(rdf_extract_context))
generation_result = asyncio.run(llm_req.generate_response(rdf_extract_context))
response = generation_result.response
# 添加调试日志
logger.debug(f"RDF LLM返回的原始响应: {response}")
@@ -140,8 +174,21 @@ def _rdf_triple_extract(llm_req: LLMRequest, paragraph: str, entities: list) ->
def info_extract_from_str(
llm_client_for_ner: LLMRequest, llm_client_for_rdf: LLMRequest, paragraph: str
) -> Union[tuple[None, None], tuple[list[str], list[list[str]]]]:
llm_client_for_ner: LLMServiceClient,
llm_client_for_rdf: LLMServiceClient,
paragraph: str,
) -> Union[Tuple[None, None], Tuple[List[str], List[List[str]]]]:
"""从文本中提取实体与三元组信息。
Args:
llm_client_for_ner: 实体提取使用的 LLM 服务门面。
llm_client_for_rdf: RDF 三元组提取使用的 LLM 服务门面。
paragraph: 原始段落文本。
Returns:
Union[Tuple[None, None], Tuple[List[str], List[List[str]]]]: 成功时返回
``(实体列表, 三元组列表)``,失败时返回 ``(None, None)``。
"""
try_count = 0
while True:
try:
@@ -176,17 +223,30 @@ def info_extract_from_str(
class IEProcess:
"""
信息抽取处理器类,提供更方便的批次处理接口。
"""
"""信息抽取处理器。"""
def __init__(self, llm_ner: LLMRequest, llm_rdf: LLMRequest = None):
def __init__(
self,
llm_ner: LLMServiceClient,
llm_rdf: LLMServiceClient | None = None,
) -> None:
"""初始化信息抽取处理器。
Args:
llm_ner: 实体提取使用的 LLM 服务门面。
llm_rdf: RDF 三元组提取使用的 LLM 服务门面;为空时复用 `llm_ner`。
"""
self.llm_ner = llm_ner
self.llm_rdf = llm_rdf or llm_ner
async def process_paragraphs(self, paragraphs: List[str]) -> List[dict]:
"""
异步处理多个段落。
async def process_paragraphs(self, paragraphs: List[str]) -> List[Dict[str, object]]:
"""异步处理多个段落。
Args:
paragraphs: 待处理的段落列表。
Returns:
List[Dict[str, object]]: 每个成功段落对应的抽取结果。
"""
from .utils.hash import get_sha256

View File

@@ -91,13 +91,14 @@ class LPMMOperations:
# 2. 实体与三元组抽取 (内部调用大模型)
from src.chat.knowledge.ie_process import IEProcess
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config
from src.services.llm_service import LLMServiceClient
llm_ner = LLMRequest(
model_set=model_config.model_task_config.lpmm_entity_extract, request_type="lpmm.entity_extract"
llm_ner = LLMServiceClient(
task_name="lpmm_entity_extract", request_type="lpmm.entity_extract"
)
llm_rdf = LLMServiceClient(
task_name="lpmm_rdf_build", request_type="lpmm.rdf_build"
)
llm_rdf = LLMRequest(model_set=model_config.model_task_config.lpmm_rdf_build, request_type="lpmm.rdf_build")
ie_process = IEProcess(llm_ner, llm_rdf)
logger.info(f"[Plugin API] 正在对 {len(paragraphs)} 段文本执行信息抽取...")

View File

@@ -149,7 +149,7 @@ class ActionModifier:
random.shuffle(actions_to_check)
for action_name, action_info in actions_to_check:
activation_type = action_info.activation_type or action_info.focus_activation_type
activation_type = action_info.activation_type
if activation_type == ActionActivationType.ALWAYS:
continue # 总是激活,无需处理

View File

@@ -19,9 +19,9 @@ from src.chat.planner_actions.action_manager import ActionManager
from src.chat.utils.utils import get_chat_type_and_target_info, is_bot_self
from src.common.data_models.info_data_model import ActionPlannerInfo
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.config.config import global_config
from src.core.types import ActionActivationType, ActionInfo, ComponentType
from src.llm_models.utils_model import LLMRequest
from src.services.llm_service import LLMServiceClient
from src.person_info.person_info import Person
from src.plugin_runtime.component_query import component_query_service
from src.prompt.prompt_manager import prompt_manager
@@ -46,8 +46,8 @@ class ActionPlanner:
self.log_prefix = f"[{_chat_manager.get_session_name(chat_id) or chat_id}]"
self.action_manager = action_manager
# LLM规划器配置
self.planner_llm = LLMRequest(
model_set=model_config.model_task_config.planner, request_type="planner"
self.planner_llm = LLMServiceClient(
task_name="planner", request_type="planner"
) # 用于动作规划
self.last_obs_time_mark = 0.0
@@ -725,7 +725,9 @@ class ActionPlanner:
try:
# 调用LLM
llm_start = time.perf_counter()
llm_content, (reasoning_content, _, _) = await self.planner_llm.generate_response_async(prompt=prompt)
generation_result = await self.planner_llm.generate_response(prompt=prompt)
llm_content = generation_result.response
reasoning_content = generation_result.reasoning
llm_duration_ms = (time.perf_counter() - llm_start) * 1000
llm_reasoning = reasoning_content

View File

@@ -10,8 +10,8 @@ from datetime import datetime
from src.common.logger import get_logger
from src.common.data_models.info_data_model import ActionPlannerInfo
from src.common.data_models.llm_data_model import LLMGenerationDataModel
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config
from src.services.llm_service import LLMServiceClient
from maim_message import BaseMessageInfo, MessageBase, Seg, UserInfo as MaimUserInfo
from src.common.data_models.mai_message_data_model import MaiMessage
@@ -56,7 +56,9 @@ class DefaultReplyer:
chat_stream: 当前绑定的聊天会话。
request_type: LLM 请求类型标识。
"""
self.express_model = LLMRequest(model_set=model_config.model_task_config.replyer, request_type=request_type)
self.express_model = LLMServiceClient(
task_name="replyer", request_type=request_type
)
self.chat_stream = chat_stream
self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_stream.session_id)
@@ -1158,9 +1160,11 @@ class DefaultReplyer:
# else:
# logger.debug(f"\nreplyer_Prompt:{prompt}\n")
content, (reasoning_content, model_name, tool_calls) = await self.express_model.generate_response_async(
prompt
)
generation_result = await self.express_model.generate_response(prompt)
content = generation_result.response
reasoning_content = generation_result.reasoning
model_name = generation_result.model_name
tool_calls = generation_result.tool_calls
# 移除 content 前后的换行符和空格
content = content.strip()
@@ -1200,11 +1204,15 @@ class DefaultReplyer:
template_prompt.add_context("sender", sender)
template_prompt.add_context("target_message", target)
prompt = await prompt_manager.render_prompt(template_prompt)
_, _, _, _, tool_calls = await llm_api.generate_with_model_with_tools(
prompt,
model_config=model_config.model_task_config.tool_use,
tool_options=[search_knowledge_tool.get_tool_definition()],
generation_result = await llm_api.generate(
llm_api.LLMServiceRequest(
task_name="tool_use",
request_type="replyer.lpmm_knowledge",
prompt=prompt,
tool_options=[search_knowledge_tool.get_tool_definition()],
)
)
tool_calls = generation_result.completion.tool_calls
# logger.info(f"工具调用提示词: {prompt}")
# logger.info(f"工具调用: {tool_calls}")

View File

@@ -9,8 +9,8 @@ from datetime import datetime
from src.common.logger import get_logger
from src.common.data_models.info_data_model import ActionPlannerInfo
from src.common.data_models.llm_data_model import LLMGenerationDataModel
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config
from src.services.llm_service import LLMServiceClient
from maim_message import BaseMessageInfo, MessageBase, Seg, UserInfo as MaimUserInfo
from src.common.data_models.mai_message_data_model import MaiMessage
@@ -52,7 +52,9 @@ class PrivateReplyer:
chat_stream: 当前绑定的聊天会话。
request_type: LLM 请求类型标识。
"""
self.express_model = LLMRequest(model_set=model_config.model_task_config.replyer, request_type=request_type)
self.express_model = LLMServiceClient(
task_name="replyer", request_type=request_type
)
self.chat_stream = chat_stream
self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_stream.session_id)
# self.memory_activator = MemoryActivator()
@@ -997,9 +999,11 @@ class PrivateReplyer:
else:
logger.debug(f"\n{prompt}\n")
content, (reasoning_content, model_name, tool_calls) = await self.express_model.generate_response_async(
prompt
)
generation_result = await self.express_model.generate_response(prompt)
content = generation_result.response
reasoning_content = generation_result.reasoning
model_name = generation_result.model_name
tool_calls = generation_result.tool_calls
content = content.strip()

View File

@@ -4,16 +4,18 @@
自动判断并执行相应的工具,返回结构化的工具执行结果。
"""
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple, cast
import hashlib
import time
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.config.config import global_config
from src.core.announcement_manager import global_announcement_manager
from src.llm_models.payload_content import ToolCall
from src.llm_models.utils_model import LLMRequest
from src.llm_models.payload_content.tool_option import ToolDefinitionInput
from src.common.data_models.llm_service_data_models import LLMGenerationOptions
from src.services.llm_service import LLMServiceClient
from src.plugin_runtime.component_query import component_query_service
from src.prompt.prompt_manager import prompt_manager
@@ -33,7 +35,9 @@ class ToolExecutor:
self.chat_stream = _chat_manager.get_session_by_session_id(self.chat_id)
self.log_prefix = f"[{_chat_manager.get_session_name(self.chat_id) or self.chat_id}]"
self.llm_model = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="tool_executor")
self.llm_model = LLMServiceClient(
task_name="tool_use", request_type="tool_executor"
)
self.enable_cache = enable_cache
self.cache_ttl = cache_ttl
@@ -69,9 +73,11 @@ class ToolExecutor:
logger.debug(f"{self.log_prefix}开始LLM工具调用分析")
response, (reasoning_content, model_name, tool_calls) = await self.llm_model.generate_response_async(
prompt=prompt, tools=tools, raise_when_empty=False
generation_result = await self.llm_model.generate_response(
prompt=prompt,
options=LLMGenerationOptions(tool_options=tools, raise_when_empty=False),
)
tool_calls = generation_result.tool_calls
tool_results, used_tools = await self.execute_tool_calls(tool_calls)
@@ -85,11 +91,15 @@ class ToolExecutor:
return tool_results, used_tools, prompt
return tool_results, [], ""
def _get_tool_definitions(self) -> List[Dict[str, Any]]:
def _get_tool_definitions(self) -> List[ToolDefinitionInput]:
"""获取 LLM 可用的工具定义列表"""
all_tools = component_query_service.get_llm_available_tools()
user_disabled_tools = global_announcement_manager.get_disabled_chat_tools(self.chat_id)
return [info.get_llm_definition() for name, info in all_tools.items() if name not in user_disabled_tools]
return [
cast(ToolDefinitionInput, info.get_llm_definition())
for name, info in all_tools.items()
if name not in user_disabled_tools
]
async def execute_tool_calls(self, tool_calls: Optional[List[ToolCall]]) -> Tuple[List[Dict[str, Any]], List[str]]:
"""执行工具调用列表"""

View File

@@ -13,8 +13,8 @@ import jieba
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
from src.chat.message_receive.message import SessionMessage
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config
from src.services.llm_service import LLMServiceClient
from src.person_info.person_info import Person
from .typo_generator import ChineseTypoGenerator
@@ -235,10 +235,11 @@ def is_mentioned_bot_in_message(message: SessionMessage) -> tuple[bool, bool, fl
async def get_embedding(text, request_type="embedding") -> Optional[List[float]]:
"""获取文本的embedding向量"""
# 每次都创建新的LLMRequest实例以避免事件循环冲突
llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type=request_type)
# 每次都创建新的服务层实例以避免事件循环冲突
llm = LLMServiceClient(task_name="embedding", request_type=request_type)
try:
embedding, _ = await llm.get_embedding(text)
embedding_result = await llm.embed_text(text)
embedding = embedding_result.embedding
except Exception as e:
logger.error(f"获取embedding失败: {str(e)}")
embedding = None

View File

@@ -0,0 +1,187 @@
"""LLM 服务层与编排层共享数据模型。
该模块集中定义 LLM 服务层与底层编排器共同使用的请求、选项与结果对象,
用于替代散落在各层之间的复杂元组返回值。
"""
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Callable, Dict, List, TypeAlias
import asyncio
from src.common.data_models import BaseDataModel
from src.llm_models.payload_content.resp_format import RespFormat
from src.llm_models.payload_content.tool_option import ToolCall, ToolDefinitionInput
if TYPE_CHECKING:
from src.llm_models.model_client.base_client import BaseClient
from src.llm_models.payload_content.message import Message
PromptMessage: TypeAlias = Dict[str, Any]
"""统一的原始提示消息结构。"""
PromptInput: TypeAlias = str | List[PromptMessage]
"""统一的提示输入类型。"""
MessageFactory: TypeAlias = Callable[["BaseClient"], List["Message"]]
"""统一的消息工厂类型。"""
@dataclass(slots=True)
class LLMServiceRequest(BaseDataModel):
"""LLM 服务层统一请求对象。"""
task_name: str
request_type: str
prompt: PromptInput | None = None
message_factory: MessageFactory | None = None
tool_options: List[ToolDefinitionInput] | None = None
temperature: float | None = None
max_tokens: int | None = None
response_format: RespFormat | None = None
interrupt_flag: asyncio.Event | None = None
def __post_init__(self) -> None:
"""校验请求对象的必要字段。
Raises:
ValueError: 当 `task_name` 为空,或 `prompt` 与 `message_factory`
的组合非法时抛出。
"""
self.task_name = self.task_name.strip()
if not self.task_name:
raise ValueError("`task_name` 不能为空")
has_prompt = self.prompt is not None
has_message_factory = self.message_factory is not None
if has_prompt == has_message_factory:
raise ValueError("`prompt` 与 `message_factory` 必须且只能提供一个")
@dataclass(slots=True)
class LLMResponseResult(BaseDataModel):
"""单次 LLM 响应结果。"""
response: str = field(default_factory=str)
reasoning: str = field(default_factory=str)
model_name: str = field(default_factory=str)
tool_calls: List[ToolCall] | None = None
@dataclass(slots=True)
class LLMServiceResult(BaseDataModel):
"""LLM 服务层统一响应对象。"""
success: bool = False
completion: LLMResponseResult = field(default_factory=LLMResponseResult)
error: str | None = None
@classmethod
def from_response_result(cls, completion: LLMResponseResult) -> "LLMServiceResult":
"""从单次 LLM 响应结果构建服务响应。
Args:
completion: 单次 LLM 响应结果。
Returns:
LLMServiceResult: 标记为成功的服务响应对象。
"""
return cls(
success=True,
completion=completion,
error=None,
)
@classmethod
def from_error(cls, error_message: str, error_detail: str | None = None) -> "LLMServiceResult":
"""构建失败的服务响应对象。
Args:
error_message: 对上层展示的错误消息。
error_detail: 底层错误详情。
Returns:
LLMServiceResult: 标记为失败的服务响应对象。
"""
return cls(
success=False,
completion=LLMResponseResult(response=error_message),
error=error_detail or error_message,
)
def to_capability_payload(self) -> Dict[str, Any]:
"""转换为插件能力层可直接返回的结构。
Returns:
Dict[str, Any]: 标准化后的能力返回值。
"""
payload: Dict[str, Any] = {
"success": self.success,
"response": self.completion.response,
"reasoning": self.completion.reasoning,
"model_name": self.completion.model_name,
}
if self.completion.tool_calls is not None:
payload["tool_calls"] = [
{
"id": tool_call.call_id,
"function": {
"name": tool_call.func_name,
"arguments": tool_call.args or {},
},
}
for tool_call in self.completion.tool_calls
]
if self.error:
payload["error"] = self.error
return payload
@dataclass(slots=True)
class LLMGenerationOptions(BaseDataModel):
"""LLM 文本生成选项。"""
temperature: float | None = None
max_tokens: int | None = None
tool_options: List[ToolDefinitionInput] | None = None
response_format: RespFormat | None = None
interrupt_flag: asyncio.Event | None = None
raise_when_empty: bool = True
@dataclass(slots=True)
class LLMImageOptions(BaseDataModel):
"""LLM 图像理解选项。"""
temperature: float | None = None
max_tokens: int | None = None
interrupt_flag: asyncio.Event | None = None
@dataclass(slots=True)
class LLMAudioTranscriptionResult(BaseDataModel):
"""LLM 音频转写结果。"""
text: str | None = None
@dataclass(slots=True)
class LLMEmbeddingResult(BaseDataModel):
"""LLM 向量生成结果。"""
embedding: List[float] = field(default_factory=list)
model_name: str = field(default_factory=str)
__all__ = [
"LLMAudioTranscriptionResult",
"LLMEmbeddingResult",
"LLMGenerationOptions",
"LLMImageOptions",
"LLMResponseResult",
"LLMServiceRequest",
"LLMServiceResult",
"MessageFactory",
"PromptInput",
"PromptMessage",
]
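
As a quick orientation for the data models above, a minimal usage sketch; only the class, method and field names come from this file, the values are illustrative:

# Illustrative only: exercises the constructors and helpers defined above.
request = LLMServiceRequest(task_name="planner", request_type="demo", prompt="hi")

ok = LLMServiceResult.from_response_result(
    LLMResponseResult(response="ok", reasoning="", model_name="some-model")
)
payload = ok.to_capability_payload()
# -> {"success": True, "response": "ok", "reasoning": "", "model_name": "some-model"}

failed = LLMServiceResult.from_error("generation failed", error_detail="upstream timeout")
assert failed.success is False and failed.error == "upstream timeout"

# __post_init__ enforces exactly one of prompt / message_factory, so this raises:
# LLMServiceRequest(task_name="planner", request_type="demo")  # ValueError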

View File

@@ -4,16 +4,15 @@ from typing import Optional
import base64
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config
from src.services.llm_service import LLMServiceClient
install(extra_lines=3)
logger = get_logger("voice_utils")
# TODO: 在LLMRequest重构后修改这里
asr_model = LLMRequest(model_set=model_config.model_task_config.voice, request_type="audio")
asr_model = LLMServiceClient(task_name="voice", request_type="audio")
async def get_voice_text(voice_bytes: bytes) -> Optional[str]:
@@ -30,7 +29,8 @@ async def get_voice_text(voice_bytes: bytes) -> Optional[str]:
return None
try:
voice_base64 = base64.b64encode(voice_bytes).decode("utf-8")
text = await asr_model.generate_response_for_voice(voice_base64)
transcription_result = await asr_model.transcribe_audio(voice_base64)
text = transcription_result.text
if not text:
logger.warning("语音转文字结果为空")

View File

@@ -1,7 +1,35 @@
from enum import Enum
from typing import Any
from .config_base import ConfigBase, Field
from src.common.i18n import t
from .config_base import ConfigBase, Field
class OpenAICompatibleAuthType(str, Enum):
"""OpenAI 兼容接口的鉴权方式。"""
BEARER = "bearer"
HEADER = "header"
QUERY = "query"
NONE = "none"
class ReasoningParseMode(str, Enum):
"""推理内容解析策略。"""
AUTO = "auto"
NATIVE = "native"
THINK_TAG = "think_tag"
NONE = "none"
class ToolArgumentParseMode(str, Enum):
"""工具调用参数的解析策略。"""
AUTO = "auto"
STRICT = "strict"
REPAIR = "repair"
DOUBLE_DECODE = "double_decode"
class APIProvider(ConfigBase):
@@ -33,7 +61,7 @@ class APIProvider(ConfigBase):
"x-icon": "key",
},
)
"""API密钥"""
"""API密钥。对于不需要鉴权的兼容端点,可将 `auth_type` 设为 `none`。"""
client_type: str = Field(
default="openai",
@@ -44,6 +72,105 @@ class APIProvider(ConfigBase):
)
"""客户端类型 (可选: openai/google, 默认为openai)"""
auth_type: str = Field(
default=OpenAICompatibleAuthType.BEARER.value,
json_schema_extra={
"x-widget": "select",
"x-icon": "shield",
},
)
"""OpenAI 兼容接口的鉴权方式。可选值:`bearer`、`header`、`query`、`none`。"""
auth_header_name: str = Field(
default="Authorization",
json_schema_extra={
"x-widget": "input",
"x-icon": "header",
},
)
"""当 `auth_type` 为 `header` 时使用的请求头名称。"""
auth_header_prefix: str = Field(
default="Bearer",
json_schema_extra={
"x-widget": "input",
"x-icon": "shield-check",
},
)
"""当 `auth_type` 为 `header` 时使用的请求头前缀。留空表示直接发送原始密钥。"""
auth_query_name: str = Field(
default="api_key",
json_schema_extra={
"x-widget": "input",
"x-icon": "link",
},
)
"""当 `auth_type` 为 `query` 时使用的查询参数名称。"""
default_headers: dict[str, str] = Field(
default_factory=dict,
json_schema_extra={
"x-widget": "custom",
"x-icon": "header",
},
)
"""所有请求默认附带的 HTTP Header。"""
default_query: dict[str, str] = Field(
default_factory=dict,
json_schema_extra={
"x-widget": "custom",
"x-icon": "list-filter",
},
)
"""所有请求默认附带的查询参数。"""
organization: str | None = Field(
default=None,
json_schema_extra={
"x-widget": "input",
"x-icon": "building-2",
},
)
"""OpenAI 官方接口可选的 `organization`。"""
project: str | None = Field(
default=None,
json_schema_extra={
"x-widget": "input",
"x-icon": "folder-kanban",
},
)
"""OpenAI 官方接口可选的 `project`。"""
model_list_endpoint: str = Field(
default="/models",
json_schema_extra={
"x-widget": "input",
"x-icon": "list",
},
)
"""模型列表端点路径。适用于 OpenAI 兼容接口的探测与管理。"""
reasoning_parse_mode: str = Field(
default=ReasoningParseMode.AUTO.value,
json_schema_extra={
"x-widget": "select",
"x-icon": "brain",
},
)
"""推理内容解析模式。可选值:`auto`、`native`、`think_tag`、`none`。"""
tool_argument_parse_mode: str = Field(
default=ToolArgumentParseMode.AUTO.value,
json_schema_extra={
"x-widget": "select",
"x-icon": "braces",
},
)
"""工具参数解析模式。可选值:`auto`、`strict`、`repair`、`double_decode`。"""
max_retry: int = Field(
default=2,
ge=0,
@@ -76,15 +203,26 @@ class APIProvider(ConfigBase):
)
"""重试间隔 (如果API调用失败, 重试的间隔时间, 单位: 秒)"""
def model_post_init(self, context: Any = None):
"""确保api_key在repr中不被显示"""
if not self.api_key:
def model_post_init(self, context: Any = None) -> None:
"""执行 API 提供商配置的后置校验。
Args:
context: Pydantic 传入的上下文对象。
Raises:
ValueError: 当配置项缺失或组合不合法时抛出。
"""
if self.auth_type != OpenAICompatibleAuthType.NONE and not self.api_key:
raise ValueError(t("config.api_key_empty"))
if not self.base_url and self.client_type != "gemini": # TODO: 允许gemini使用base_url
raise ValueError(t("config.api_base_url_empty"))
if not self.name:
raise ValueError(t("config.api_provider_name_empty"))
return super().model_post_init(context)
if self.auth_type == OpenAICompatibleAuthType.HEADER and not self.auth_header_name.strip():
raise ValueError("当 auth_type=header 时auth_header_name 不能为空")
if self.auth_type == OpenAICompatibleAuthType.QUERY and not self.auth_query_name.strip():
raise ValueError("当 auth_type=query 时auth_query_name 不能为空")
super().model_post_init(context)
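
For reference, a hedged sketch of how the new auth fields compose, assuming ConfigBase supports keyword construction like a pydantic model (name and base_url are existing fields referenced by model_post_init above):

# Illustrative: a compatible endpoint that expects the key as a query parameter.
provider = APIProvider(
    name="local-gateway",
    base_url="http://localhost:8000/v1",
    api_key="sk-local",  # still required because auth_type != "none"
    auth_type=OpenAICompatibleAuthType.QUERY.value,
    auth_query_name="key",  # requests then carry ...?key=<api_key>
    default_headers={"X-Client": "example-client"},
    model_list_endpoint="/models",
)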
class ModelInfo(ConfigBase):

View File

@@ -1,12 +1,13 @@
import copy
import warnings
from dataclasses import dataclass, field, fields
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional
import copy
import warnings
from maim_message import Seg
from src.llm_models.payload_content.tool_option import ToolParamType as ToolParamType
from src.llm_models.payload_content.tool_option import ToolCall as ToolCall
from src.llm_models.payload_content.tool_option import ToolCall
# from src.common.data_models.message_data_model import ReplyContentType as ReplyContentType
# from src.common.data_models.message_data_model import ReplyContent as ReplyContent
# from src.common.data_models.message_data_model import ForwardNode as ForwardNode
@@ -15,49 +16,42 @@ from src.llm_models.payload_content.tool_option import ToolCall as ToolCall
# 组件类型枚举
class ComponentType(Enum):
"""组件类型枚举"""
"""Host 内部使用的组件类型枚举"""
ACTION = "action" # 动作组件
COMMAND = "command" # 命令组件
TOOL = "tool" # 服务组件(预留)
SCHEDULER = "scheduler" # 定时任务组件(预留)
EVENT_HANDLER = "event_handler" # 事件处理组件(预留)
TOOL = "tool" # 工具组件
def __str__(self) -> str:
"""返回枚举值字符串。
Returns:
str: 当前组件类型对应的字符串值。
"""
return self.value
# 动作激活类型枚举
class ActionActivationType(Enum):
"""动作激活类型枚举"""
"""动作激活类型枚举"""
NEVER = "never" # 从不激活(默认关闭)
ALWAYS = "always" # 默认参与到planner
RANDOM = "random" # 随机启用action到planner
KEYWORD = "keyword" # 关键词触发启用action到planner
def __str__(self):
return self.value
def __str__(self) -> str:
"""返回枚举值字符串。
# 聊天模式枚举
class ChatMode(Enum):
"""聊天模式枚举"""
FOCUS = "focus" # Focus聊天模式
NORMAL = "normal" # Normal聊天模式
PRIORITY = "priority" # 优先级聊天模式
ALL = "all" # 所有聊天模式
def __str__(self):
Returns:
str: 当前激活类型对应的字符串值。
"""
return self.value
# 事件类型枚举
class EventType(Enum):
"""
事件类型枚举类
"""
"""事件类型枚举。"""
ON_START = "on_start" # 启动事件,用于调用按时任务
ON_STOP = "on_stop" # 停止事件,用于调用按时任务
@@ -72,185 +66,96 @@ class EventType(Enum):
UNKNOWN = "unknown" # 未知事件类型
def __str__(self) -> str:
"""返回枚举值字符串。
Returns:
str: 当前事件类型对应的字符串值。
"""
return self.value
@dataclass
class PythonDependency:
"""Python包依赖信息"""
package_name: str # 包名称
version: str = "" # 版本要求,例如: ">=1.0.0", "==2.1.3", ""表示任意版本
optional: bool = False # 是否为可选依赖
description: str = "" # 依赖描述
install_name: str = "" # 安装时的包名如果与import名不同
def __post_init__(self):
if not self.install_name:
self.install_name = self.package_name
def get_pip_requirement(self) -> str:
"""获取pip安装格式的依赖字符串"""
if self.version:
return f"{self.install_name}{self.version}"
return self.install_name
@dataclass
@dataclass(slots=True)
class ComponentInfo:
"""组件信息"""
"""Host 内部使用的组件信息快照。"""
name: str # 组件名称
component_type: ComponentType # 组件类型
description: str = "" # 组件描述
enabled: bool = True # 是否启用
plugin_name: str = "" # 所属插件名称
is_built_in: bool = False # 是否为内置组件
metadata: Dict[str, Any] = field(default_factory=dict) # 额外元数据
name: str
"""组件名称。"""
def __post_init__(self):
if self.metadata is None:
self.metadata = {}
description: str = ""
"""组件描述。"""
enabled: bool = True
"""组件是否启用。"""
plugin_name: str = ""
"""所属插件 ID。"""
component_type: ComponentType = field(init=False)
"""组件类型。"""
@dataclass
@dataclass(slots=True)
class ActionInfo(ComponentInfo):
"""动作组件信息"""
"""供 Planner 与回复链使用的动作信息快照。"""
action_parameters: Dict[str, str] = field(
default_factory=dict
) # 动作参数与描述,例如 {"param1": "描述1", "param2": "描述2"}
action_require: List[str] = field(default_factory=list) # 动作需求说明
associated_types: List[str] = field(default_factory=list) # 关联的消息类型
# 激活类型相关
focus_activation_type: ActionActivationType = ActionActivationType.ALWAYS # 已弃用
normal_activation_type: ActionActivationType = ActionActivationType.ALWAYS # 已弃用
activation_type: ActionActivationType = ActionActivationType.ALWAYS
random_activation_probability: float = 0.0
activation_keywords: List[str] = field(default_factory=list) # 激活关键词列表
keyword_case_sensitive: bool = False
# 模式和并行设置
parallel_action: bool = False
component_type: ComponentType = field(init=False, default=ComponentType.ACTION)
"""组件类型。"""
def __post_init__(self):
super().__post_init__()
if self.activation_keywords is None:
self.activation_keywords = []
if self.action_parameters is None:
self.action_parameters = {}
if self.action_require is None:
self.action_require = []
if self.associated_types is None:
self.associated_types = []
self.component_type = ComponentType.ACTION
def __post_init__(self) -> None:
"""归一化动作快照中的集合字段。"""
self.action_parameters = dict(self.action_parameters or {})
self.action_require = list(self.action_require or [])
self.associated_types = list(self.associated_types or [])
self.activation_keywords = list(self.activation_keywords or [])
@dataclass
@dataclass(slots=True)
class CommandInfo(ComponentInfo):
"""命令组件信息"""
"""命令处理链使用的命令信息快照。"""
command_pattern: str = "" # 命令匹配模式(正则表达式)
def __post_init__(self):
super().__post_init__()
self.component_type = ComponentType.COMMAND
component_type: ComponentType = field(init=False, default=ComponentType.COMMAND)
"""组件类型。"""
@dataclass
@dataclass(slots=True)
class ToolInfo(ComponentInfo):
"""工具组件信息"""
"""工具执行链使用的工具信息快照。"""
tool_parameters: List[Tuple[str, ToolParamType, str, bool, List[str] | None]] = field(
default_factory=list
) # 工具参数定义
tool_description: str = "" # 工具描述
parameters_schema: Dict[str, Any] | None = None
"""对象级工具参数 Schema。"""
def __post_init__(self):
super().__post_init__()
self.component_type = ComponentType.TOOL
component_type: ComponentType = field(init=False, default=ComponentType.TOOL)
"""组件类型。"""
def get_llm_definition(self) -> dict:
"""生成 LLM function-calling 所需的工具定义"""
return {
def get_llm_definition(self) -> Dict[str, Any]:
"""生成 LLM 使用的规范化工具定义
Returns:
Dict[str, Any]: 统一工具定义字典。
"""
definition: Dict[str, Any] = {
"name": self.name,
"description": self.tool_description,
"parameters": self.tool_parameters,
"description": self.description,
}
if self.parameters_schema is not None:
definition["parameters_schema"] = copy.deepcopy(self.parameters_schema)
return definition
@dataclass
class EventHandlerInfo(ComponentInfo):
"""事件处理器组件信息"""
event_type: EventType | str = EventType.ON_MESSAGE # 监听事件类型
intercept_message: bool = False # 是否拦截消息处理(默认不拦截)
weight: int = 0 # 事件处理器权重,决定执行顺序
def __post_init__(self):
super().__post_init__()
self.component_type = ComponentType.EVENT_HANDLER
@dataclass
class PluginInfo:
"""插件信息"""
display_name: str # 插件显示名称
name: str # 插件名称
description: str # 插件描述
version: str = "1.0.0" # 插件版本
author: str = "" # 插件作者
enabled: bool = True # 是否启用
is_built_in: bool = False # 是否为内置插件
components: List[ComponentInfo] = field(default_factory=list) # 包含的组件列表
dependencies: List[str] = field(default_factory=list) # 依赖的其他插件
python_dependencies: List[PythonDependency] = field(default_factory=list) # Python包依赖
config_file: str = "" # 配置文件路径
metadata: Dict[str, Any] = field(default_factory=dict) # 额外元数据
# 新增manifest相关信息
manifest_data: Dict[str, Any] = field(default_factory=dict) # manifest文件数据
license: str = "" # 插件许可证
homepage_url: str = "" # 插件主页
repository_url: str = "" # 插件仓库地址
keywords: List[str] = field(default_factory=list) # 插件关键词
categories: List[str] = field(default_factory=list) # 插件分类
min_host_version: str = "" # 最低主机版本要求
max_host_version: str = "" # 最高主机版本要求
def __post_init__(self):
if self.components is None:
self.components = []
if self.dependencies is None:
self.dependencies = []
if self.python_dependencies is None:
self.python_dependencies = []
if self.metadata is None:
self.metadata = {}
if self.manifest_data is None:
self.manifest_data = {}
if self.keywords is None:
self.keywords = []
if self.categories is None:
self.categories = []
def get_missing_packages(self) -> List[PythonDependency]:
"""检查缺失的Python包"""
missing = []
for dep in self.python_dependencies:
try:
__import__(dep.package_name)
except ImportError:
if not dep.optional:
missing.append(dep)
return missing
def get_pip_requirements(self) -> List[str]:
"""获取所有pip安装格式的依赖"""
return [dep.get_pip_requirement() for dep in self.python_dependencies]
@dataclass
@dataclass(slots=True)
class ModifyFlag:
"""消息修改标记集合。"""
modify_message_segments: bool = False
modify_plain_text: bool = False
modify_llm_prompt: bool = False
@@ -258,9 +163,9 @@ class ModifyFlag:
modify_llm_response_reasoning: bool = False
@dataclass
@dataclass(slots=True)
class MaiMessages:
"""MaiM插件消息"""
"""核心事件系统使用的统一消息模型。"""
message_segments: List[Seg] = field(default_factory=list)
"""消息段列表,支持多段消息"""
@@ -306,11 +211,17 @@ class MaiMessages:
_modify_flags: ModifyFlag = field(default_factory=ModifyFlag)
def __post_init__(self):
def __post_init__(self) -> None:
"""归一化消息段列表。"""
if self.message_segments is None:
self.message_segments = []
def deepcopy(self):
def deepcopy(self) -> "MaiMessages":
"""深拷贝当前消息对象。
Returns:
MaiMessages: 深拷贝后的消息对象。
"""
return copy.deepcopy(self)
def to_transport_dict(self) -> Dict[str, Any]:
@@ -347,6 +258,14 @@ class MaiMessages:
@staticmethod
def _serialize_transport_value(value: Any) -> Any:
"""递归序列化字段值为可传输结构。
Args:
value: 任意字段值。
Returns:
Any: 可用于 IPC 传输的纯 Python 值。
"""
if isinstance(value, (str, int, float, bool)) or value is None:
return value
if isinstance(value, Enum):
@@ -367,13 +286,22 @@ class MaiMessages:
@staticmethod
def _deserialize_transport_field(field_name: str, value: Any) -> Any:
"""反序列化特定字段的传输值。
Args:
field_name: 字段名称。
value: 传输层返回的字段值。
Returns:
Any: 反序列化后的字段值。
"""
if field_name == "message_segments" and isinstance(value, list):
deserialized_segments: List[Seg] = []
for segment in value:
if isinstance(segment, Seg):
deserialized_segments.append(segment)
elif isinstance(segment, dict) and "type" in segment:
deserialized_segments.append(Seg(type=segment.get("type", "text"), data=segment.get("data")))
deserialized_segments.append(Seg(type=segment.get("type", "text"), data=segment.get("data", "")))
return deserialized_segments
if field_name == "llm_response_tool_call" and isinstance(value, list):
@@ -393,15 +321,15 @@ class MaiMessages:
return value
def modify_message_segments(self, new_segments: List[Seg], suppress_warning: bool = False):
"""
修改消息段列表
def modify_message_segments(self, new_segments: List[Seg], suppress_warning: bool = False) -> None:
"""修改消息段列表。
Warning:
在生成了plain_text的情况下调用此方法,可能会导致plain_text内容与消息段不一致
在生成了 ``plain_text`` 的情况下调用此方法,可能会导致文本与消息段不一致
Args:
new_segments (List[Seg]): 新的消息段列表
new_segments: 新的消息段列表
suppress_warning: 是否抑制潜在不一致警告。
"""
if self.plain_text and not suppress_warning:
warnings.warn(
@@ -412,15 +340,15 @@ class MaiMessages:
self.message_segments = new_segments
self._modify_flags.modify_message_segments = True
def modify_llm_prompt(self, new_prompt: str, suppress_warning: bool = False):
"""
修改LLM提示词
def modify_llm_prompt(self, new_prompt: str, suppress_warning: bool = False) -> None:
"""修改 LLM 提示词。
Warning:
在没有生成llm_prompt的情况下调用此方法,可能会导致修改无效
在没有生成 ``llm_prompt`` 的情况下调用此方法,可能会导致修改无效
Args:
new_prompt (str): 新的提示词内容
new_prompt: 新的提示词内容
suppress_warning: 是否抑制潜在无效修改警告。
"""
if self.llm_prompt is None and not suppress_warning:
warnings.warn(
@@ -431,15 +359,15 @@ class MaiMessages:
self.llm_prompt = new_prompt
self._modify_flags.modify_llm_prompt = True
def modify_plain_text(self, new_text: str, suppress_warning: bool = False):
"""
修改生成的plain_text内容
def modify_plain_text(self, new_text: str, suppress_warning: bool = False) -> None:
"""修改生成的纯文本内容。
Warning:
在未生成plain_text的情况下调用此方法,可能会导致plain_text为空或者修改无效
在未生成 ``plain_text`` 的情况下调用此方法,可能会导致修改无效
Args:
new_text (str): 新的纯文本内容
new_text: 新的纯文本内容
suppress_warning: 是否抑制潜在无效修改警告。
"""
if not self.plain_text and not suppress_warning:
warnings.warn(
@@ -450,15 +378,15 @@ class MaiMessages:
self.plain_text = new_text
self._modify_flags.modify_plain_text = True
def modify_llm_response_content(self, new_content: str, suppress_warning: bool = False):
"""
修改生成的llm_response_content内容
def modify_llm_response_content(self, new_content: str, suppress_warning: bool = False) -> None:
"""修改生成的 LLM 响应正文。
Warning:
在未生成llm_response_content的情况下调用此方法,可能会导致llm_response_content为空或者修改无效
在未生成 ``llm_response_content`` 的情况下调用此方法,可能会导致修改无效
Args:
new_content (str): 新的LLM响应内容
new_content: 新的 LLM 响应内容
suppress_warning: 是否抑制潜在无效修改警告。
"""
if not self.llm_response_content and not suppress_warning:
warnings.warn(
@@ -469,15 +397,15 @@ class MaiMessages:
self.llm_response_content = new_content
self._modify_flags.modify_llm_response_content = True
def modify_llm_response_reasoning(self, new_reasoning: str, suppress_warning: bool = False):
"""
修改生成的llm_response_reasoning内容
def modify_llm_response_reasoning(self, new_reasoning: str, suppress_warning: bool = False) -> None:
"""修改生成的 LLM 推理内容。
Warning:
在未生成llm_response_reasoning的情况下调用此方法,可能会导致llm_response_reasoning为空或者修改无效
在未生成 ``llm_response_reasoning`` 的情况下调用此方法,可能会导致修改无效
Args:
new_reasoning (str): 新的LLM响应推理内容
new_reasoning: 新的 LLM 推理内容
suppress_warning: 是否抑制潜在无效修改警告。
"""
if not self.llm_response_reasoning and not suppress_warning:
warnings.warn(
@@ -487,10 +415,3 @@ class MaiMessages:
)
self.llm_response_reasoning = new_reasoning
self._modify_flags.modify_llm_response_reasoning = True
@dataclass
class CustomEventHandlerResult:
message: str = ""
timestamp: float = 0.0
extra_info: Optional[Dict] = None

View File

@@ -20,8 +20,8 @@ from src.common.database.database import get_db_session
from src.common.database.database_model import Expression
from src.common.logger import get_logger
from src.config.config import global_config
from src.config.config import model_config
from src.llm_models.utils_model import LLMRequest
from src.common.data_models.llm_service_data_models import LLMGenerationOptions
from src.services.llm_service import LLMServiceClient
from src.manager.async_task_manager import AsyncTask
logger = get_logger("expression_auto_check_task")
@@ -76,7 +76,7 @@ def create_evaluation_prompt(situation: str, style: str) -> str:
return prompt
judge_llm = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="expression_check")
judge_llm = LLMServiceClient(task_name="tool_use", request_type="expression_check")
async def single_expression_check(situation: str, style: str) -> tuple[bool, str, str | None]:
@@ -94,10 +94,11 @@ async def single_expression_check(situation: str, style: str) -> tuple[bool, str
prompt = create_evaluation_prompt(situation, style)
logger.debug(f"正在评估表达方式: situation={situation}, style={style}")
response, (reasoning, model_name, _) = await judge_llm.generate_response_async(
prompt=prompt, temperature=0.6, max_tokens=1024
generation_result = await judge_llm.generate_response(
prompt=prompt,
options=LLMGenerationOptions(temperature=0.6, max_tokens=1024),
)
response = generation_result.response
logger.debug(f"LLM响应: {response}")
# 解析JSON响应

View File

@@ -7,8 +7,9 @@ import difflib
import json
import re
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config, global_config
from src.common.data_models.llm_service_data_models import LLMGenerationOptions
from src.services.llm_service import LLMServiceClient
from src.config.config import global_config
from src.prompt.prompt_manager import prompt_manager
from src.common.logger import get_logger
from src.common.database.database_model import Expression
@@ -26,10 +27,11 @@ if TYPE_CHECKING:
logger = get_logger("expressor")
# TODO: 重构完LLM相关内容后替换成新的模型调用方式
express_learn_model = LLMRequest(model_set=model_config.model_task_config.utils, request_type="expression.learner")
summary_model = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="expression.summary")
check_model = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="expression.check")
express_learn_model = LLMServiceClient(
task_name="utils", request_type="expression.learner"
)
summary_model = LLMServiceClient(task_name="tool_use", request_type="expression.summary")
check_model = LLMServiceClient(task_name="tool_use", request_type="expression.check")
class ExpressionLearner:
@@ -74,7 +76,10 @@ class ExpressionLearner:
# 调用 LLM 学习表达方式
try:
response, _ = await express_learn_model.generate_response_async(prompt, temperature=0.3)
generation_result = await express_learn_model.generate_response(
prompt, options=LLMGenerationOptions(temperature=0.3)
)
response = generation_result.response
except Exception as e:
logger.error(f"学习表达方式失败,模型生成出错:{e}")
return
@@ -413,7 +418,10 @@ class ExpressionLearner:
"只输出概括内容。"
)
try:
summary, _ = await summary_model.generate_response_async(prompt, temperature=0.2)
summary_result = await summary_model.generate_response(
prompt, options=LLMGenerationOptions(temperature=0.2)
)
summary = summary_result.response
if summary := summary.strip():
return summary
except Exception as e:

View File

@@ -4,10 +4,11 @@ import time
from typing import List, Dict, Optional, Any, Tuple
from json_repair import repair_json
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
from src.services.llm_service import LLMServiceClient
from src.config.config import global_config
from src.common.logger import get_logger
from src.common.database.database_model import Expression
from src.common.utils.utils_session import SessionUtils
from src.prompt.prompt_manager import prompt_manager
from src.learners.learner_utils_old import weighted_sample
from src.chat.utils.common_utils import TempMethodsExpression
@@ -17,8 +18,8 @@ logger = get_logger("expression_selector")
class ExpressionSelector:
def __init__(self):
self.llm_model = LLMRequest(
model_set=model_config.model_task_config.tool_use, request_type="expression.selector"
self.llm_model = LLMServiceClient(
task_name="tool_use", request_type="expression.selector"
)
def can_use_expression_for_chat(self, chat_id: str) -> bool:
@@ -383,8 +384,8 @@ class ExpressionSelector:
prompt = await prompt_manager.render_prompt(prompt_template)
# 4. 调用LLM
content, (reasoning_content, model_name, _) = await self.llm_model.generate_response_async(prompt=prompt)
generation_result = await self.llm_model.generate_response(prompt=prompt)
content = generation_result.response
# print(prompt)
# print(content)

View File

@@ -1,19 +1,40 @@
from json_repair import repair_json
from typing import Tuple, Optional, List
from typing import Any, List, Optional, Tuple
import json
import re
from src.config.config import model_config
from src.config.config import global_config
from src.llm_models.utils_model import LLMRequest
from src.common.data_models.llm_service_data_models import LLMGenerationOptions
from src.services.llm_service import LLMServiceClient
from src.prompt.prompt_manager import prompt_manager
from src.common.logger import get_logger
logger = get_logger("expression_utils")
# TODO: 重构完LLM相关内容后替换成新的模型调用方式
judge_llm = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="expression_check")
judge_llm = LLMServiceClient(task_name="tool_use", request_type="expression_check")
def _normalize_repair_json_result(repaired_result: Any) -> str:
"""将 repair_json 的返回值规范化为 JSON 字符串。
Args:
repaired_result: `repair_json` 的返回值,可能是字符串或带附加信息的元组。
Returns:
str: 可供 `json.loads` 继续解析的 JSON 字符串。
Raises:
TypeError: 当返回值无法规范化为字符串时抛出。
"""
if isinstance(repaired_result, str):
return repaired_result
if isinstance(repaired_result, tuple) and repaired_result:
first_item = repaired_result[0]
if isinstance(first_item, str):
return first_item
return json.dumps(first_item, ensure_ascii=False)
raise TypeError(f"repair_json 返回了无法处理的结果类型: {type(repaired_result)}")
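A quick usage sketch of the helper above (illustrative, not part of this diff). `repair_json` normally returns a string; the tuple branch guards against variants that return extra information alongside it:

import json
from json_repair import repair_json

raw = '{"suitable": true, "reason": "ok",}'  # trailing comma: invalid JSON
fixed = _normalize_repair_json_result(repair_json(raw))
data = json.loads(fixed)  # {'suitable': True, 'reason': 'ok'}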
async def check_expression_suitability(situation: str, style: str) -> Tuple[bool, str, Optional[str]]:
@@ -51,7 +72,11 @@ async def check_expression_suitability(situation: str, style: str) -> Tuple[bool, str, Optional[str]]:
logger.info(f"正在评估表达方式: situation={situation}, style={style}")
response, _ = await judge_llm.generate_response_async(prompt=prompt, temperature=0.6, max_tokens=1024)
generation_result = await judge_llm.generate_response(
prompt=prompt,
options=LLMGenerationOptions(temperature=0.6, max_tokens=1024),
)
response = generation_result.response
logger.debug(f"评估结果: {response}")
@@ -59,7 +84,7 @@ async def check_expression_suitability(situation: str, style: str) -> Tuple[bool, str, Optional[str]]:
evaluation = json.loads(response)
except json.JSONDecodeError:
try:
response_repaired = repair_json(response)
response_repaired = _normalize_repair_json_result(repair_json(response))
evaluation = json.loads(response_repaired)
except Exception as e:
raise ValueError(f"无法解析LLM响应为JSON: {response}") from e
@@ -74,7 +99,7 @@ async def check_expression_suitability(situation: str, style: str) -> Tuple[bool, str, Optional[str]]:
return False, f"评估结果格式错误: {e}", str(e)
def fix_chinese_quotes_in_json(text):
def fix_chinese_quotes_in_json(text: str) -> str:
"""使用状态机修复 JSON 字符串值中的中文引号"""
result = []
i = 0
@@ -201,12 +226,12 @@ def is_single_char_jargon(content: str) -> bool:
)
def _try_parse(text):
def _try_parse(text: str) -> Any:
try:
return json.loads(text)
except Exception:
try:
repaired = repair_json(text)
repaired = _normalize_repair_json_result(repair_json(text))
return json.loads(repaired)
except Exception:
return None

View File

@@ -4,8 +4,9 @@ from typing import List, Dict, Optional, Any
from src.common.logger import get_logger
from src.common.database.database_model import Jargon
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config, global_config
from src.common.data_models.llm_service_data_models import LLMGenerationOptions
from src.services.llm_service import LLMServiceClient
from src.config.config import global_config
from src.prompt.prompt_manager import prompt_manager
from src.learners.jargon_miner_old import search_jargon
from src.learners.learner_utils_old import (
@@ -23,8 +24,8 @@ class JargonExplainer:
def __init__(self, chat_id: str) -> None:
self.chat_id = chat_id
self.llm = LLMRequest(
model_set=model_config.model_task_config.tool_use,
self.llm = LLMServiceClient(
task_name="tool_use",
request_type="jargon.explain",
)
@@ -206,7 +207,10 @@ class JargonExplainer:
prompt_of_summarize.add_context("jargon_explanations", lambda _: explanations_text)
summarize_prompt = await prompt_manager.render_prompt(prompt_of_summarize)
summary, _ = await self.llm.generate_response_async(summarize_prompt, temperature=0.3)
summary_result = await self.llm.generate_response(
summarize_prompt, options=LLMGenerationOptions(temperature=0.3)
)
summary = summary_result.response
if not summary:
# 如果LLM概括失败直接返回原始解释
return f"上下文中的黑话解释:\n{explanations_text}"

View File

@@ -12,17 +12,17 @@ from src.common.data_models.jargon_data_model import MaiJargon
from src.common.database.database import get_db_session
from src.common.database.database_model import Jargon
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config
from src.common.data_models.llm_service_data_models import LLMGenerationOptions
from src.services.llm_service import LLMServiceClient
from src.prompt.prompt_manager import prompt_manager
from .expression_utils import is_single_char_jargon
logger = get_logger("jargon")
# TODO: 重构完LLM相关内容后替换成新的模型调用方式
llm_extract = LLMRequest(model_set=model_config.model_task_config.utils, request_type="jargon.extract")
llm_inference = LLMRequest(model_set=model_config.model_task_config.utils, request_type="jargon.inference")
llm_extract = LLMServiceClient(task_name="utils", request_type="jargon.extract")
llm_inference = LLMServiceClient(task_name="utils", request_type="jargon.inference")
class JargonEntry(TypedDict):
@@ -100,7 +100,10 @@ class JargonMiner:
prompt1_template.add_context("previous_meaning_instruction", previous_meaning_instruction)
prompt1 = await prompt_manager.render_prompt(prompt1_template)
llm_response_1, _ = await llm_inference.generate_response_async(prompt1, temperature=0.3)
generation_result_1 = await llm_inference.generate_response(
prompt1, options=LLMGenerationOptions(temperature=0.3)
)
llm_response_1 = generation_result_1.response
if not llm_response_1:
logger.warning(f"jargon {content} 推断1失败无响应")
return
@@ -129,7 +132,10 @@ class JargonMiner:
prompt2_template.add_context("content", content)
prompt2 = await prompt_manager.render_prompt(prompt2_template)
llm_response_2, _ = await llm_inference.generate_response_async(prompt2, temperature=0.3)
generation_result_2 = await llm_inference.generate_response(
prompt2, options=LLMGenerationOptions(temperature=0.3)
)
llm_response_2 = generation_result_2.response
if not llm_response_2:
logger.warning(f"jargon {content} 推断2失败无响应")
return
@@ -153,7 +159,10 @@ class JargonMiner:
if global_config.debug.show_jargon_prompt:
logger.info(f"jargon {content} 比较提示词: {prompt3}")
llm_response_3, _ = await llm_inference.generate_response_async(prompt3, temperature=0.3)
generation_result_3 = await llm_inference.generate_response(
prompt3, options=LLMGenerationOptions(temperature=0.3)
)
llm_response_3 = generation_result_3.response
if not llm_response_3:
logger.warning(f"jargon {content} 比较失败:无响应")
return

View File

@@ -0,0 +1,259 @@
from abc import ABC, abstractmethod
from typing import Any, Callable, Coroutine, Generic, Tuple, TypeVar, cast
import asyncio
from src.config.model_configs import ModelInfo
from .base_client import (
APIResponse,
AudioTranscriptionRequest,
BaseClient,
EmbeddingRequest,
ResponseRequest,
UsageRecord,
UsageTuple,
)
RawStreamT = TypeVar("RawStreamT")
"""流式原始响应类型变量。"""
RawResponseT = TypeVar("RawResponseT")
"""非流式原始响应类型变量。"""
TaskResultT = TypeVar("TaskResultT")
"""异步任务返回值类型变量。"""
ProviderStreamResponseHandler = Callable[
[RawStreamT, asyncio.Event | None],
Coroutine[Any, Any, Tuple[APIResponse, UsageTuple | None]],
]
"""Provider 专用流式响应处理函数类型。"""
ProviderResponseParser = Callable[[RawResponseT], Tuple[APIResponse, UsageTuple | None]]
"""Provider 专用非流式响应解析函数类型。"""
async def await_task_with_interrupt(
task: asyncio.Task[TaskResultT],
interrupt_flag: asyncio.Event | None,
*,
interval_seconds: float = 0.1,
) -> TaskResultT:
"""在支持外部中断的前提下等待异步任务完成。
Args:
task: 待等待的异步任务。
interrupt_flag: 外部中断标记。
interval_seconds: 轮询检查间隔,单位秒。
Returns:
TaskResultT: 任务执行结果。
Raises:
ReqAbortException: 等待期间收到外部中断信号时抛出。
"""
from src.llm_models.exceptions import ReqAbortException
while not task.done():
if interrupt_flag and interrupt_flag.is_set():
task.cancel()
raise ReqAbortException("请求被外部信号中断")
await asyncio.sleep(interval_seconds)
return await task
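A minimal sketch of how the helper is meant to be driven (illustrative only): the interrupt event is normally set from another coroutine; here it stays unset, so the task result is returned once the task completes.

import asyncio

async def demo() -> str:
    interrupt = asyncio.Event()
    task = asyncio.create_task(asyncio.sleep(0.5, result="done"))
    # If another coroutine called interrupt.set() while we wait, the task
    # would be cancelled and ReqAbortException raised instead.
    return await await_task_with_interrupt(task, interrupt)

asyncio.run(demo())  # -> "done"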
class AdapterClient(BaseClient, ABC, Generic[RawStreamT, RawResponseT]):
"""提供统一请求执行骨架的 Provider 适配基类。"""
async def get_response(self, request: ResponseRequest) -> APIResponse:
"""获取对话响应。
Args:
request: 统一响应请求对象。
Returns:
APIResponse: 解析完成的统一响应对象。
"""
stream_response_handler = self._resolve_stream_response_handler(request)
response_parser = self._resolve_response_parser(request)
response, usage_record = await self._execute_response_request(
request,
stream_response_handler,
response_parser,
)
return self._attach_usage_record(response, request.model_info, usage_record)
async def get_embedding(self, request: EmbeddingRequest) -> APIResponse:
"""获取文本嵌入。
Args:
request: 统一嵌入请求对象。
Returns:
APIResponse: 解析完成的统一嵌入响应。
"""
response, usage_record = await self._execute_embedding_request(request)
return self._attach_usage_record(response, request.model_info, usage_record)
async def get_audio_transcriptions(self, request: AudioTranscriptionRequest) -> APIResponse:
"""获取音频转录。
Args:
request: 统一音频转录请求对象。
Returns:
APIResponse: 解析完成的统一音频转录响应。
"""
response, usage_record = await self._execute_audio_transcription_request(request)
return self._attach_usage_record(response, request.model_info, usage_record)
def _resolve_stream_response_handler(
self,
request: ResponseRequest,
) -> ProviderStreamResponseHandler[RawStreamT]:
"""解析实际使用的流式响应处理器。
Args:
request: 统一响应请求对象。
Returns:
ProviderStreamResponseHandler[RawStreamT]: 流式响应处理器。
"""
if request.stream_response_handler is not None:
return cast(ProviderStreamResponseHandler[RawStreamT], request.stream_response_handler)
return self._build_default_stream_response_handler(request)
def _resolve_response_parser(
self,
request: ResponseRequest,
) -> ProviderResponseParser[RawResponseT]:
"""解析实际使用的非流式响应解析器。
Args:
request: 统一响应请求对象。
Returns:
ProviderResponseParser[RawResponseT]: 非流式响应解析器。
"""
if request.async_response_parser is not None:
return cast(ProviderResponseParser[RawResponseT], request.async_response_parser)
return self._build_default_response_parser(request)
@staticmethod
def _build_usage_record(model_info: ModelInfo, usage_record: UsageTuple) -> UsageRecord:
"""根据统一使用量三元组构建 `UsageRecord`。
Args:
model_info: 模型信息。
usage_record: 使用量三元组。
Returns:
UsageRecord: 可直接挂载到 `APIResponse` 的使用记录对象。
"""
return UsageRecord(
model_name=model_info.name,
provider_name=model_info.api_provider,
prompt_tokens=usage_record[0],
completion_tokens=usage_record[1],
total_tokens=usage_record[2],
)
def _attach_usage_record(
self,
response: APIResponse,
model_info: ModelInfo,
usage_record: UsageTuple | None,
) -> APIResponse:
"""在响应对象上附加统一使用量信息。
Args:
response: 已解析的统一响应对象。
model_info: 模型信息。
usage_record: 可选的使用量三元组。
Returns:
APIResponse: 附加使用量后的响应对象。
"""
if usage_record is not None:
response.usage = self._build_usage_record(model_info, usage_record)
return response
@abstractmethod
def _build_default_stream_response_handler(
self,
request: ResponseRequest,
) -> ProviderStreamResponseHandler[RawStreamT]:
"""构建默认流式响应处理器。
Args:
request: 统一响应请求对象。
Returns:
ProviderStreamResponseHandler[RawStreamT]: 默认流式处理器。
"""
raise NotImplementedError
@abstractmethod
def _build_default_response_parser(
self,
request: ResponseRequest,
) -> ProviderResponseParser[RawResponseT]:
"""构建默认非流式响应解析器。
Args:
request: 统一响应请求对象。
Returns:
ProviderResponseParser[RawResponseT]: 默认非流式解析器。
"""
raise NotImplementedError
@abstractmethod
async def _execute_response_request(
self,
request: ResponseRequest,
stream_response_handler: ProviderStreamResponseHandler[RawStreamT],
response_parser: ProviderResponseParser[RawResponseT],
) -> Tuple[APIResponse, UsageTuple | None]:
"""执行 Provider 的文本/多模态响应请求。
Args:
request: 统一响应请求对象。
stream_response_handler: 流式响应处理器。
response_parser: 非流式响应解析器。
Returns:
Tuple[APIResponse, UsageTuple | None]: 统一响应对象与可选使用量信息。
"""
raise NotImplementedError
@abstractmethod
async def _execute_embedding_request(
self,
request: EmbeddingRequest,
) -> Tuple[APIResponse, UsageTuple | None]:
"""执行 Provider 的嵌入请求。
Args:
request: 统一嵌入请求对象。
Returns:
Tuple[APIResponse, UsageTuple | None]: 统一响应对象与可选使用量信息。
"""
raise NotImplementedError
@abstractmethod
async def _execute_audio_transcription_request(
self,
request: AudioTranscriptionRequest,
) -> Tuple[APIResponse, UsageTuple | None]:
"""执行 Provider 的音频转录请求。
Args:
request: 统一音频转录请求对象。
Returns:
Tuple[APIResponse, UsageTuple | None]: 统一响应对象与可选使用量信息。
"""
raise NotImplementedError
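A hedged sketch of what a concrete adapter could look like. The class below is hypothetical; a real provider would wrap its SDK calls here. It also assumes `APIResponse` accepts a `content` keyword, which this hunk does not show.

class EchoAdapterClient(AdapterClient[str, str]):
    """Toy adapter: no SDK call, it just echoes a fixed string."""

    def _build_default_stream_response_handler(self, request):
        async def handler(raw_stream, interrupt_flag):
            return APIResponse(content=raw_stream), None
        return handler

    def _build_default_response_parser(self, request):
        def parser(raw_response):
            # UsageTuple order: (prompt_tokens, completion_tokens, total_tokens)
            return APIResponse(content=raw_response), (1, 1, 2)
        return parser

    async def _execute_response_request(self, request, stream_response_handler, response_parser):
        # A real implementation would invoke the provider SDK and route the
        # raw result through the resolved handler/parser.
        return response_parser("echo")

    async def _execute_embedding_request(self, request):
        return APIResponse(embedding=[0.0, 0.0]), None

    async def _execute_audio_transcription_request(self, request):
        return APIResponse(content=""), None

    def get_support_image_formats(self):
        return ["png", "jpeg"]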

View File

@@ -1,14 +1,15 @@
import asyncio
from dataclasses import dataclass
from abc import ABC, abstractmethod
from typing import Callable, Any, Optional
from dataclasses import dataclass, field
from typing import Any, Callable, Coroutine, Dict, List, Tuple, Type
import asyncio
from src.common.logger import get_logger
from src.config.config import config_manager
from src.config.model_configs import ModelInfo, APIProvider
from ..payload_content.message import Message
from ..payload_content.resp_format import RespFormat
from ..payload_content.tool_option import ToolOption, ToolCall
from src.config.model_configs import APIProvider, ModelInfo
from src.llm_models.payload_content.message import Message
from src.llm_models.payload_content.resp_format import RespFormat
from src.llm_models.payload_content.tool_option import ToolCall, ToolOption
logger = get_logger("model_client_registry")
@@ -47,10 +48,10 @@ class APIResponse:
reasoning_content: str | None = None
"""推理内容"""
tool_calls: list[ToolCall] | None = None
tool_calls: List[ToolCall] | None = None
"""工具调用 [(工具名称, 工具参数), ...]"""
embedding: list[float] | None = None
embedding: List[float] | None = None
"""嵌入向量"""
usage: UsageRecord | None = None
@@ -60,6 +61,82 @@ class APIResponse:
"""响应原始数据"""
UsageTuple = Tuple[int, int, int]
"""统一的使用量三元组类型,顺序为 `(prompt_tokens, completion_tokens, total_tokens)`。"""
StreamResponseHandler = Callable[
[Any, asyncio.Event | None],
Coroutine[Any, Any, Tuple["APIResponse", UsageTuple | None]],
]
"""统一的流式响应处理函数类型。"""
ResponseParser = Callable[[Any], Tuple["APIResponse", UsageTuple | None]]
"""统一的非流式响应解析函数类型。"""
@dataclass(slots=True)
class ResponseRequest:
"""统一的文本/多模态响应请求。"""
model_info: ModelInfo
message_list: List[Message]
tool_options: List[ToolOption] | None = None
max_tokens: int | None = None
temperature: float | None = None
response_format: RespFormat | None = None
stream_response_handler: StreamResponseHandler | None = None
async_response_parser: ResponseParser | None = None
interrupt_flag: asyncio.Event | None = None
extra_params: Dict[str, Any] = field(default_factory=dict)
def copy_with(self, **changes: Any) -> "ResponseRequest":
"""基于当前请求创建一个带局部变更的新请求。
Args:
**changes: 需要覆盖的字段值。
Returns:
ResponseRequest: 复制后的请求对象。
"""
payload = {
"model_info": self.model_info,
"message_list": list(self.message_list),
"tool_options": None if self.tool_options is None else list(self.tool_options),
"max_tokens": self.max_tokens,
"temperature": self.temperature,
"response_format": self.response_format,
"stream_response_handler": self.stream_response_handler,
"async_response_parser": self.async_response_parser,
"interrupt_flag": self.interrupt_flag,
"extra_params": dict(self.extra_params),
}
payload.update(changes)
return ResponseRequest(**payload)
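A short sketch of the intended use (illustrative; `model_info` and `user_message` are assumed to be constructed elsewhere). `copy_with` shallow-copies the mutable fields, so a retry request cannot accidentally mutate the original's message list or extra_params:

base_request = ResponseRequest(
    model_info=model_info,
    message_list=[user_message],
    temperature=0.7,
)
# Retry with a lower temperature; everything else is carried over.
retry_request = base_request.copy_with(temperature=0.2, max_tokens=512)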
@dataclass(slots=True)
class EmbeddingRequest:
"""统一的嵌入请求。"""
model_info: ModelInfo
embedding_input: str
extra_params: Dict[str, Any] = field(default_factory=dict)
@dataclass(slots=True)
class AudioTranscriptionRequest:
"""统一的音频转录请求。"""
model_info: ModelInfo
audio_base64: str
max_tokens: int | None = None
extra_params: Dict[str, Any] = field(default_factory=dict)
ClientRequest = ResponseRequest | EmbeddingRequest | AudioTranscriptionRequest
"""统一客户端请求类型。"""
class BaseClient(ABC):
"""
基础客户端
@@ -67,97 +144,82 @@ class BaseClient(ABC):
api_provider: APIProvider
def __init__(self, api_provider: APIProvider):
def __init__(self, api_provider: APIProvider) -> None:
"""初始化基础客户端。
Args:
api_provider: API 提供商配置。
"""
self.api_provider = api_provider
@abstractmethod
async def get_response(
self,
model_info: ModelInfo,
message_list: list[Message],
tool_options: list[ToolOption] | None = None,
max_tokens: Optional[int] = None,
temperature: Optional[float] = None,
response_format: RespFormat | None = None,
stream_response_handler: Optional[
Callable[[Any, asyncio.Event | None], tuple[APIResponse, tuple[int, int, int]]]
] = None,
async_response_parser: Callable[[Any], tuple[APIResponse, tuple[int, int, int]]] | None = None,
interrupt_flag: asyncio.Event | None = None,
extra_params: dict[str, Any] | None = None,
) -> APIResponse:
"""
获取对话响应
:param model_info: 模型信息
:param message_list: 对话体
:param tool_options: 工具选项可选默认为None
:param max_tokens: 最大token数可选默认为1024
:param temperature: 温度可选默认为0.7
:param response_format: 响应格式(可选,默认为 NotGiven
:param stream_response_handler: 流式响应处理函数(可选)
:param async_response_parser: 响应解析函数(可选)
:param interrupt_flag: 中断信号量可选默认为None
:return: (响应文本, 推理文本, 工具调用, 其他数据)
async def get_response(self, request: ResponseRequest) -> APIResponse:
"""获取对话响应。
Args:
request: 统一响应请求对象。
Returns:
APIResponse: 统一响应对象。
"""
raise NotImplementedError("'get_response' method should be overridden in subclasses")
@abstractmethod
async def get_embedding(
self,
model_info: ModelInfo,
embedding_input: str,
extra_params: dict[str, Any] | None = None,
) -> APIResponse:
"""
获取文本嵌入
:param model_info: 模型信息
:param embedding_input: 嵌入输入文本
:return: 嵌入响应
async def get_embedding(self, request: EmbeddingRequest) -> APIResponse:
"""获取文本嵌入。
Args:
request: 统一嵌入请求对象。
Returns:
APIResponse: 嵌入响应。
"""
raise NotImplementedError("'get_embedding' method should be overridden in subclasses")
@abstractmethod
async def get_audio_transcriptions(
self,
model_info: ModelInfo,
audio_base64: str,
max_tokens: Optional[int] = None,
extra_params: dict[str, Any] | None = None,
) -> APIResponse:
"""
获取音频转录
:param model_info: 模型信息
:param audio_base64: base64编码的音频数据
:extra_params: 附加的请求参数
:return: 音频转录响应
async def get_audio_transcriptions(self, request: AudioTranscriptionRequest) -> APIResponse:
"""获取音频转录。
Args:
request: 统一音频转录请求对象。
Returns:
APIResponse: 音频转录响应。
"""
raise NotImplementedError("'get_audio_transcriptions' method should be overridden in subclasses")
@abstractmethod
def get_support_image_formats(self) -> list[str]:
"""
获取支持的图片格式
:return: 支持的图片格式列表
def get_support_image_formats(self) -> List[str]:
"""获取支持的图片格式。
Returns:
List[str]: 支持的图片格式列表。
"""
raise NotImplementedError("'get_support_image_formats' method should be overridden in subclasses")
class ClientRegistry:
"""客户端注册表。"""
def __init__(self) -> None:
self.client_registry: dict[str, type[BaseClient]] = {}
"""初始化注册表并绑定配置重载回调。"""
self.client_registry: Dict[str, Type[BaseClient]] = {}
"""APIProvider.type -> BaseClient的映射表"""
self.client_instance_cache: dict[str, BaseClient] = {}
self.client_instance_cache: Dict[str, BaseClient] = {}
"""APIProvider.name -> BaseClient的映射表"""
config_manager.register_reload_callback(self.clear_client_instance_cache)
def register_client_class(self, client_type: str):
"""
注册API客户端类
def register_client_class(self, client_type: str) -> Callable[[Type[BaseClient]], Type[BaseClient]]:
"""注册 API 客户端类。
Args:
client_class: API客户端类
client_type: 客户端类型标识。
Returns:
Callable[[Type[BaseClient]], Type[BaseClient]]: 装饰器函数。
"""
def decorator(cls: type[BaseClient]) -> type[BaseClient]:
def decorator(cls: Type[BaseClient]) -> Type[BaseClient]:
if not issubclass(cls, BaseClient):
raise TypeError(f"{cls.__name__} is not a subclass of BaseClient")
self.client_registry[client_type] = cls
@@ -165,14 +227,15 @@ class ClientRegistry:
return decorator
def get_client_class_instance(self, api_provider: APIProvider, force_new=False) -> BaseClient:
"""
获取注册的API客户端实例
def get_client_class_instance(self, api_provider: APIProvider, force_new: bool = False) -> BaseClient:
"""获取注册的 API 客户端实例。
Args:
api_provider: APIProvider实例
force_new: 是否强制创建新实例(用于解决事件循环问题)
api_provider: APIProvider 实例
force_new: 是否强制创建新实例
Returns:
BaseClient: 注册的API客户端实例
BaseClient: 注册的 API 客户端实例
"""
from . import ensure_client_type_loaded
@@ -194,6 +257,7 @@ class ClientRegistry:
return self.client_instance_cache[api_provider.name]
def clear_client_instance_cache(self) -> None:
"""清空客户端实例缓存。"""
self.client_instance_cache.clear()
logger.info("检测到配置重载已清空LLM客户端实例缓存")
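A registration sketch (illustrative; the project presumably exposes a shared registry instance, and a fresh one is created here only for demonstration). The key passed to the decorator matches `APIProvider.type`, while instances are cached per `APIProvider.name` and evicted on config reload:

registry = ClientRegistry()

@registry.register_client_class("echo-compatible")
class EchoProviderClient(BaseClient):
    # A real client must implement get_response / get_embedding /
    # get_audio_transcriptions / get_support_image_formats.
    ...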

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@@ -0,0 +1,140 @@
from dataclasses import dataclass, field
from typing import Any, Mapping
from src.config.model_configs import APIProvider, OpenAICompatibleAuthType
@dataclass(slots=True)
class OpenAICompatibleClientConfig:
"""OpenAI 兼容客户端的基础配置。"""
api_key: str
base_url: str
default_headers: dict[str, str] = field(default_factory=dict)
default_query: dict[str, object] = field(default_factory=dict)
@dataclass(slots=True)
class OpenAICompatibleRequestOverrides:
"""单次请求级别的附加配置。"""
extra_headers: dict[str, str] = field(default_factory=dict)
extra_query: dict[str, object] = field(default_factory=dict)
extra_body: dict[str, Any] = field(default_factory=dict)
def normalize_openai_base_url(base_url: str) -> str:
"""规范化 OpenAI 兼容接口的基础地址。
Args:
base_url: 原始基础地址。
Returns:
str: 去掉尾部斜杠后的地址。
"""
return base_url.rstrip("/")
def _build_auth_header_value(prefix: str, api_key: str) -> str:
"""构造鉴权请求头的值。
Args:
prefix: 请求头前缀。
api_key: 实际密钥。
Returns:
str: 拼接完成的请求头值。
"""
normalized_prefix = prefix.strip()
if not normalized_prefix:
return api_key
return f"{normalized_prefix} {api_key}"
def build_openai_compatible_client_config(api_provider: APIProvider) -> OpenAICompatibleClientConfig:
"""构建 OpenAI 兼容客户端配置。
Args:
api_provider: API 提供商配置。
Returns:
OpenAICompatibleClientConfig: 可直接用于初始化 SDK 客户端的配置。
"""
default_headers = dict(api_provider.default_headers)
default_query: dict[str, object] = dict(api_provider.default_query)
client_api_key = api_provider.api_key
if api_provider.auth_type == OpenAICompatibleAuthType.BEARER:
if (
api_provider.auth_header_name != "Authorization"
or api_provider.auth_header_prefix.strip() != "Bearer"
):
client_api_key = ""
default_headers[api_provider.auth_header_name] = _build_auth_header_value(
prefix=api_provider.auth_header_prefix,
api_key=api_provider.api_key,
)
elif api_provider.auth_type == OpenAICompatibleAuthType.HEADER:
client_api_key = ""
default_headers[api_provider.auth_header_name] = _build_auth_header_value(
prefix=api_provider.auth_header_prefix,
api_key=api_provider.api_key,
)
elif api_provider.auth_type == OpenAICompatibleAuthType.QUERY:
client_api_key = ""
default_query[api_provider.auth_query_name] = api_provider.api_key
elif api_provider.auth_type == OpenAICompatibleAuthType.NONE:
client_api_key = ""
return OpenAICompatibleClientConfig(
api_key=client_api_key,
base_url=normalize_openai_base_url(api_provider.base_url),
default_headers=default_headers,
default_query=default_query,
)
def _extract_mapping(value: Any) -> dict[str, Any]:
"""将任意映射值规范化为普通字典。
Args:
value: 原始输入值。
Returns:
dict[str, Any]: 规范化后的字典。非映射值时返回空字典。
"""
if isinstance(value, Mapping):
return {str(key): item for key, item in value.items()}
return {}
def split_openai_request_overrides(
extra_params: Mapping[str, Any] | None,
*,
reserved_body_keys: set[str] | None = None,
) -> OpenAICompatibleRequestOverrides:
"""拆分单次请求中的头、查询参数和请求体扩展字段。
Args:
extra_params: 模型级别或请求级别的附加参数。
reserved_body_keys: 由 SDK 原生参数承载、因此不应再进入 `extra_body` 的字段集合。
Returns:
OpenAICompatibleRequestOverrides: 拆分后的请求覆盖配置。
"""
raw_params = dict(extra_params or {})
extra_headers = _extract_mapping(raw_params.pop("headers", None))
extra_query = _extract_mapping(raw_params.pop("query", None))
extra_body = _extract_mapping(raw_params.pop("body", None))
blocked_body_keys = reserved_body_keys or set()
for key, value in raw_params.items():
if key in blocked_body_keys:
continue
extra_body[key] = value
return OpenAICompatibleRequestOverrides(
extra_headers={key: str(value) for key, value in extra_headers.items()},
extra_query=extra_query,
extra_body=extra_body,
)
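A behavioural sketch of the two helpers above (illustrative values; the assertions follow directly from the code in this hunk):

assert normalize_openai_base_url("https://api.example.com/v1/") == "https://api.example.com/v1"

overrides = split_openai_request_overrides(
    {
        "headers": {"X-Trace-Id": "abc123"},
        "query": {"api-version": "2024-01-01"},
        "top_k": 40,         # loose keys fall through into extra_body
        "max_tokens": 1024,  # reserved keys are dropped, not duplicated
    },
    reserved_body_keys={"max_tokens"},
)
assert overrides.extra_headers == {"X-Trace-Id": "abc123"}
assert overrides.extra_query == {"api-version": "2024-01-01"}
assert overrides.extra_body == {"top_k": 40}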

View File

@@ -1,133 +1,280 @@
from dataclasses import dataclass, field
from enum import Enum
from typing import List, Optional
from typing import List, Tuple
from .tool_option import ToolCall
# 设计这系列类的目的是为未来可能的扩展做准备
class RoleType(str, Enum):
"""消息角色类型。"""
class RoleType(Enum):
System = "system"
User = "user"
Assistant = "assistant"
Tool = "tool"
SUPPORTED_IMAGE_FORMATS = ["jpg", "jpeg", "png", "webp", "gif"] # openai支持的图片格式
SUPPORTED_IMAGE_FORMATS = ["jpg", "jpeg", "png", "webp", "gif"]
"""默认支持的图片格式列表。"""
@dataclass(slots=True)
class TextMessagePart:
"""文本消息片段。"""
text: str
def __post_init__(self) -> None:
"""执行文本片段的基础校验。
Raises:
ValueError: 当文本为空时抛出。
"""
if self.text == "":
raise ValueError("文本消息片段不能为空字符串")
@dataclass(slots=True)
class ImageMessagePart:
"""Base64 图片消息片段。"""
image_format: str
image_base64: str
def __post_init__(self) -> None:
"""执行图片片段的基础校验。
Raises:
ValueError: 当图片格式或 Base64 数据无效时抛出。
"""
if self.image_format.lower() not in SUPPORTED_IMAGE_FORMATS:
raise ValueError("不受支持的图片格式")
if not self.image_base64:
raise ValueError("图片的 base64 编码不能为空")
@property
def normalized_image_format(self) -> str:
"""获取规范化后的图片格式。
Returns:
str: 规范化后的图片格式。`jpg` 会被统一为 `jpeg`。
"""
image_format = self.image_format.lower()
if image_format in {"jpg", "jpeg"}:
return "jpeg"
return image_format
MessagePart = TextMessagePart | ImageMessagePart
@dataclass(slots=True)
class Message:
def __init__(
self,
role: RoleType,
content: str | list[tuple[str, str] | str],
tool_call_id: str | None = None,
tool_calls: Optional[List[ToolCall]] = None,
):
"""统一消息模型。"""
role: RoleType
parts: List[MessagePart] = field(default_factory=list)
tool_call_id: str | None = None
tool_calls: List[ToolCall] | None = None
def __post_init__(self) -> None:
"""执行消息对象的基础校验。
Raises:
ValueError: 当消息内容或工具调用信息不完整时抛出。
"""
初始化消息对象
不应直接修改Message类而应使用MessageBuilder类来构建对象
if not self.parts and not (self.role == RoleType.Assistant and self.tool_calls):
raise ValueError("消息内容不能为空")
if self.role == RoleType.Tool and not self.tool_call_id:
raise ValueError("Tool 角色的工具调用 ID 不能为空")
@property
def content(self) -> str | List[Tuple[str, str] | str]:
"""获取兼容旧逻辑的内容视图。
Returns:
str | List[Tuple[str, str] | str]: 当仅包含一个文本片段时返回字符串,
否则返回混合列表,其中图片片段表示为 `(format, base64)` 元组。
"""
self.role: RoleType = role
self.content: str | list[tuple[str, str] | str] = content
self.tool_call_id: str | None = tool_call_id
self.tool_calls: Optional[List[ToolCall]] = tool_calls
if len(self.parts) == 1 and isinstance(self.parts[0], TextMessagePart):
return self.parts[0].text
content_items: List[Tuple[str, str] | str] = []
for part in self.parts:
if isinstance(part, TextMessagePart):
content_items.append(part.text)
else:
content_items.append((part.image_format, part.image_base64))
return content_items
def get_text_content(self) -> str:
"""提取消息中的所有文本片段。
Returns:
str: 以原始顺序拼接后的文本内容。
"""
return "".join(part.text for part in self.parts if isinstance(part, TextMessagePart))
def __str__(self) -> str:
"""生成便于调试的字符串表示。
Returns:
str: 当前消息对象的可读摘要。
"""
return (
f"Role: {self.role}, Content: {self.content}, "
f"Role: {self.role}, Parts: {self.parts}, "
f"Tool Call ID: {self.tool_call_id}, Tool Calls: {self.tool_calls}"
)
class MessageBuilder:
def __init__(self):
"""消息构建器。"""
def __init__(self) -> None:
"""初始化构建器。"""
self.__role: RoleType = RoleType.User
self.__content: list[tuple[str, str] | str] = []
self.__parts: List[MessagePart] = []
self.__tool_call_id: str | None = None
self.__tool_calls: Optional[List[ToolCall]] = None
self.__tool_calls: List[ToolCall] | None = None
def set_role(self, role: RoleType = RoleType.User) -> "MessageBuilder":
"""
设置角色默认为User
:param role: 角色
:return: MessageBuilder对象
"""设置消息角色。
Args:
role: 目标角色,默认为 `user`。
Returns:
MessageBuilder: 当前构建器实例。
"""
self.__role = role
return self
def add_text_part(self, text: str) -> "MessageBuilder":
"""追加文本片段。
Args:
text: 文本内容。
Returns:
MessageBuilder: 当前构建器实例。
"""
self.__parts.append(TextMessagePart(text=text))
return self
def add_text_content(self, text: str) -> "MessageBuilder":
"""追加文本片段。
Args:
text: 文本内容。
Returns:
MessageBuilder: 当前构建器实例。
"""
添加文本内容
:param text: 文本内容
:return: MessageBuilder对象
return self.add_text_part(text)
def add_image_base64_part(
self,
image_format: str,
image_base64: str,
support_formats: List[str] = SUPPORTED_IMAGE_FORMATS,
) -> "MessageBuilder":
"""追加 Base64 图片片段。
Args:
image_format: 图片格式。
image_base64: 图片的 Base64 编码。
support_formats: 允许的图片格式列表。
Returns:
MessageBuilder: 当前构建器实例。
Raises:
ValueError: 当图片格式不被支持时抛出。
"""
self.__content.append(text)
if image_format.lower() not in support_formats:
raise ValueError("不受支持的图片格式")
self.__parts.append(ImageMessagePart(image_format=image_format, image_base64=image_base64))
return self
def add_image_content(
self,
image_format: str,
image_base64: str,
support_formats: list[str] = SUPPORTED_IMAGE_FORMATS, # 默认支持格式
support_formats: List[str] = SUPPORTED_IMAGE_FORMATS,
) -> "MessageBuilder":
"""
添加图片内容
:param image_format: 图片格式
:param image_base64: 图片的base64编码
:return: MessageBuilder对象
"""
if image_format.lower() not in support_formats:
raise ValueError("不受支持的图片格式")
if not image_base64:
raise ValueError("图片的base64编码不能为空")
self.__content.append((image_format, image_base64))
return self
"""追加 Base64 图片片段。
def add_tool_call(self, tool_call_id: str) -> "MessageBuilder":
Args:
image_format: 图片格式。
image_base64: 图片的 Base64 编码。
support_formats: 允许的图片格式列表。
Returns:
MessageBuilder: 当前构建器实例。
"""
添加工具调用指令调用时请确保已设置为Tool角色
:param tool_call_id: 工具调用指令的id
:return: MessageBuilder对象
return self.add_image_base64_part(
image_format=image_format,
image_base64=image_base64,
support_formats=support_formats,
)
def set_tool_call_id(self, tool_call_id: str) -> "MessageBuilder":
"""设置工具结果消息引用的工具调用 ID。
Args:
tool_call_id: 工具调用 ID。
Returns:
MessageBuilder: 当前构建器实例。
Raises:
ValueError: 当当前角色不是 `tool` 或 ID 为空时抛出。
"""
if self.__role != RoleType.Tool:
raise ValueError("仅当角色为Tool时才能添加工具调用ID")
raise ValueError("仅当角色为 Tool 时才能设置工具调用 ID")
if not tool_call_id:
raise ValueError("工具调用ID不能为空")
raise ValueError("工具调用 ID 不能为空")
self.__tool_call_id = tool_call_id
return self
def set_tool_calls(self, tool_calls: List[ToolCall]) -> "MessageBuilder":
def add_tool_call(self, tool_call_id: str) -> "MessageBuilder":
"""设置工具结果消息引用的工具调用 ID。
Args:
tool_call_id: 工具调用 ID。
Returns:
MessageBuilder: 当前构建器实例。
"""
设置助手消息的工具调用列表
:param tool_calls: 工具调用列表
:return: MessageBuilder对象
return self.set_tool_call_id(tool_call_id)
def set_tool_calls(self, tool_calls: List[ToolCall]) -> "MessageBuilder":
"""设置助手消息中的工具调用列表。
Args:
tool_calls: 工具调用列表。
Returns:
MessageBuilder: 当前构建器实例。
Raises:
ValueError: 当当前角色不是 `assistant` 或列表为空时抛出。
"""
if self.__role != RoleType.Assistant:
raise ValueError("仅当角色为Assistant时才能设置工具调用列表")
raise ValueError("仅当角色为 Assistant 时才能设置工具调用列表")
if not tool_calls:
raise ValueError("工具调用列表不能为空")
self.__tool_calls = tool_calls
self.__tool_calls = list(tool_calls)
return self
def build(self) -> Message:
"""
构建消息对象
:return: Message对象
"""
if len(self.__content) == 0 and not (self.__role == RoleType.Assistant and self.__tool_calls):
raise ValueError("内容不能为空")
if self.__role == RoleType.Tool and self.__tool_call_id is None:
raise ValueError("Tool角色的工具调用ID不能为空")
"""构建消息对象。
Returns:
Message: 构建完成的消息对象。
"""
return Message(
role=self.__role,
content=(
self.__content[0]
if (len(self.__content) == 1 and isinstance(self.__content[0], str))
else self.__content
),
parts=list(self.__parts),
tool_call_id=self.__tool_call_id,
tool_calls=self.__tool_calls,
tool_calls=list(self.__tool_calls) if self.__tool_calls else None,
)
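A usage sketch for the reworked builder (illustrative; the base64 payload is a placeholder, not real image data):

message = (
    MessageBuilder()
    .set_role(RoleType.User)
    .add_text_part("Describe this image: ")
    .add_image_base64_part("png", "aGVsbG8=")  # placeholder base64
    .build()
)
assert message.get_text_content() == "Describe this image: "
# Legacy-compatible view: a mixed list of strings and (format, base64) tuples.
assert message.content == ["Describe this image: ", ("png", "aGVsbG8=")]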

View File

@@ -1,51 +1,40 @@
from copy import deepcopy
from enum import Enum
from typing import Optional, Any
from typing import Any, Dict, List, Mapping, Optional, Type, cast
from pydantic import BaseModel
from typing_extensions import TypedDict, Required
from typing_extensions import Required, TypedDict
class RespFormatType(Enum):
TEXT = "text" # 文本
JSON_OBJ = "json_object" # JSON
JSON_SCHEMA = "json_schema" # JSON Schema
"""响应格式类型。"""
TEXT = "text"
JSON_OBJ = "json_object"
JSON_SCHEMA = "json_schema"
class JsonSchema(TypedDict, total=False):
"""内部使用的 JSON Schema 包装结构。"""
name: Required[str]
"""
The name of the response format.
Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length
of 64.
"""
description: Optional[str]
"""
A description of what the response format is for, used by the model to determine
how to respond in the format.
"""
schema: dict[str, object]
"""
The schema for the response format, described as a JSON Schema object. Learn how
to build JSON schemas [here](https://json-schema.org/).
"""
schema: Dict[str, Any]
strict: Optional[bool]
"""
Whether to enable strict schema adherence when generating the output. If set to
true, the model will always follow the exact schema defined in the `schema`
field. Only a subset of JSON Schema is supported when `strict` is `true`. To
learn more, read the
[Structured Outputs guide](https://platform.openai.com/docs/guides/structured-outputs).
"""
def _json_schema_type_check(instance) -> str | None:
def _json_schema_type_check(instance: Mapping[str, Any]) -> str | None:
"""检查 JSON Schema 包装结构是否合法。
Args:
instance: 待检查的 JSON Schema 包装字典。
Returns:
str | None: 不合法时返回错误信息,合法时返回 `None`。
"""
if "name" not in instance:
return "schema必须包含'name'字段"
elif not isinstance(instance["name"], str) or instance["name"].strip() == "":
if not isinstance(instance["name"], str) or instance["name"].strip() == "":
return "schema的'name'字段必须是非空字符串"
if "description" in instance and (
not isinstance(instance["description"], str) or instance["description"].strip() == ""
@@ -53,164 +42,198 @@ def _json_schema_type_check(instance) -> str | None:
return "schema的'description'字段只能填入非空字符串"
if "schema" not in instance:
return "schema必须包含'schema'字段"
elif not isinstance(instance["schema"], dict):
if not isinstance(instance["schema"], dict):
return "schema的'schema'字段必须是字典详见https://json-schema.org/"
if "strict" in instance and not isinstance(instance["strict"], bool):
return "schema的'strict'字段只能填入布尔值"
return None
def _remove_title(schema: dict[str, Any] | list[Any]) -> dict[str, Any] | list[Any]:
"""
递归移除JSON Schema中的title字段
def _remove_title(schema: Dict[str, Any] | List[Any]) -> Dict[str, Any] | List[Any]:
"""递归移除 JSON Schema 中的 `title` 字段。
Args:
schema: 待处理的 Schema 树。
Returns:
Dict[str, Any] | List[Any]: 处理后的 Schema 树。
"""
if isinstance(schema, list):
# 如果当前Schema是列表则对所有dict/list子元素递归调用
for idx, item in enumerate(schema):
for index, item in enumerate(schema):
if isinstance(item, (dict, list)):
schema[idx] = _remove_title(item)
elif isinstance(schema, dict):
# 是字典移除title字段并对所有dict/list子元素递归调用
if "title" in schema:
del schema["title"]
for key, value in schema.items():
if isinstance(value, (dict, list)):
schema[key] = _remove_title(value)
schema[index] = _remove_title(item)
return schema
if "title" in schema:
del schema["title"]
for key, value in schema.items():
if isinstance(value, (dict, list)):
schema[key] = _remove_title(value)
return schema
def _link_definitions(schema: dict[str, Any]) -> dict[str, Any]:
"""
链接JSON Schema中的definitions字段
def _link_definitions(schema: Dict[str, Any]) -> Dict[str, Any]:
"""展开 Schema 中的本地 `$defs`/`$ref` 引用。
Args:
schema: 待处理的根 Schema。
Returns:
Dict[str, Any]: 展开后的 Schema。
"""
def link_definitions_recursive(
path: str, sub_schema: list[Any] | dict[str, Any], defs: dict[str, Any]
) -> dict[str, Any]:
"""
递归链接JSON Schema中的definitions字段
:param path: 当前路径
:param sub_schema: 子Schema
:param defs: Schema定义集
:return:
path: str,
sub_schema: Dict[str, Any] | List[Any],
definitions: Dict[str, Any],
) -> Dict[str, Any] | List[Any]:
"""递归展开局部定义。
Args:
path: 当前递归路径。
sub_schema: 当前子 Schema。
definitions: 已收集的定义字典。
Returns:
Dict[str, Any] | List[Any]: 展开后的子 Schema。
"""
if isinstance(sub_schema, list):
# 如果当前Schema是列表则遍历每个元素
for i in range(len(sub_schema)):
if isinstance(sub_schema[i], dict):
sub_schema[i] = link_definitions_recursive(f"{path}/{str(i)}", sub_schema[i], defs)
else:
# 否则为字典
if "$defs" in sub_schema:
# 如果当前Schema有$def字段则将其添加到defs中
key_prefix = f"{path}/$defs/"
for key, value in sub_schema["$defs"].items():
def_key = key_prefix + key
if def_key not in defs:
defs[def_key] = value
del sub_schema["$defs"]
if "$ref" in sub_schema:
# 如果当前Schema有$ref字段则将其替换为defs中的定义
def_key = sub_schema["$ref"]
if def_key in defs:
sub_schema = defs[def_key]
else:
raise ValueError(f"Schema中引用的定义'{def_key}'不存在")
# 遍历键值对
for key, value in sub_schema.items():
if isinstance(value, (dict, list)):
# 如果当前值是字典或列表,则递归调用
sub_schema[key] = link_definitions_recursive(f"{path}/{key}", value, defs)
for index, item in enumerate(sub_schema):
if isinstance(item, (dict, list)):
sub_schema[index] = link_definitions_recursive(f"{path}/{index}", item, definitions)
return sub_schema
if "$defs" in sub_schema:
key_prefix = f"{path}/$defs/"
defs_payload = cast(Dict[str, Any], sub_schema["$defs"])
for key, value in defs_payload.items():
definition_key = key_prefix + key
if definition_key not in definitions:
definitions[definition_key] = value
del sub_schema["$defs"]
if "$ref" in sub_schema:
definition_key = cast(str, sub_schema["$ref"])
if definition_key in definitions:
return definitions[definition_key]
raise ValueError(f"Schema中引用的定义'{definition_key}'不存在")
for key, value in sub_schema.items():
if isinstance(value, (dict, list)):
sub_schema[key] = link_definitions_recursive(f"{path}/{key}", value, definitions)
return sub_schema
return link_definitions_recursive("#", schema, {})
return cast(Dict[str, Any], link_definitions_recursive("#", schema, {}))
def _remove_defs(schema: dict[str, Any]) -> dict[str, Any]:
"""
递归移除JSON Schema中的$defs字段
def _remove_defs(schema: Dict[str, Any] | List[Any]) -> Dict[str, Any] | List[Any]:
"""递归移除 JSON Schema 中的 `$defs` 字段。
Args:
schema: 待处理的 Schema 树。
Returns:
Dict[str, Any] | List[Any]: 处理后的 Schema 树。
"""
if isinstance(schema, list):
# 如果当前Schema是列表则对所有dict/list子元素递归调用
for idx, item in enumerate(schema):
for index, item in enumerate(schema):
if isinstance(item, (dict, list)):
schema[idx] = _remove_title(item)
elif isinstance(schema, dict):
# 是字典移除title字段并对所有dict/list子元素递归调用
if "$defs" in schema:
del schema["$defs"]
for key, value in schema.items():
if isinstance(value, (dict, list)):
schema[key] = _remove_title(value)
schema[index] = _remove_defs(item)
return schema
if "$defs" in schema:
del schema["$defs"]
for key, value in schema.items():
if isinstance(value, (dict, list)):
schema[key] = _remove_defs(value)
return schema
class RespFormat:
"""
响应格式
"""
"""统一响应格式定义。"""
@staticmethod
def _generate_schema_from_model(schema):
json_schema = {
"name": schema.__name__,
"schema": _remove_defs(_link_definitions(_remove_title(schema.model_json_schema()))),
def _generate_schema_from_model(schema_model: Type[BaseModel]) -> JsonSchema:
"""从 Pydantic 模型生成内部 JSON Schema 包装结构。
Args:
schema_model: Pydantic 模型类。
Returns:
JsonSchema: 内部统一 JSON Schema 包装结构。
"""
schema_tree = deepcopy(schema_model.model_json_schema())
json_schema: JsonSchema = {
"name": schema_model.__name__,
"schema": cast(
Dict[str, Any],
_remove_defs(_link_definitions(cast(Dict[str, Any], _remove_title(schema_tree)))),
),
"strict": False,
}
if schema.__doc__:
json_schema["description"] = schema.__doc__
if schema_model.__doc__:
json_schema["description"] = schema_model.__doc__
return json_schema
def __init__(
self,
format_type: RespFormatType = RespFormatType.TEXT,
schema: type | JsonSchema | None = None,
):
"""
响应格式
:param format_type: 响应格式类型(默认为文本)
:param schema: 模板类或JsonSchema仅当format_type为JSON Schema时有效
schema: Type[BaseModel] | JsonSchema | None = None,
) -> None:
"""初始化响应格式对象。
Args:
format_type: 响应格式类型。
schema: 模型类或 JSON Schema 包装结构,仅 `JSON_SCHEMA` 模式使用。
"""
self.format_type: RespFormatType = format_type
self.schema_source: Type[BaseModel] | JsonSchema | None = schema
self.schema: JsonSchema | None = None
if format_type == RespFormatType.JSON_SCHEMA:
if schema is None:
raise ValueError("当format_type为'JSON_SCHEMA'schema不能为空")
if isinstance(schema, dict):
if check_msg := _json_schema_type_check(schema):
raise ValueError(f"schema格式不正确{check_msg}")
if format_type != RespFormatType.JSON_SCHEMA:
return
if schema is None:
raise ValueError("当format_type为'JSON_SCHEMA'schema不能为空")
if isinstance(schema, dict):
if check_msg := _json_schema_type_check(schema):
raise ValueError(f"schema格式不正确{check_msg}")
self.schema = cast(JsonSchema, deepcopy(schema))
return
if isinstance(schema, type) and issubclass(schema, BaseModel):
try:
self.schema = self._generate_schema_from_model(schema)
except Exception as exc:
raise ValueError(
f"自动生成JSON Schema时发生异常请检查模型类{schema.__name__}的定义,详细信息:\n"
f"{schema.__name__}:\n"
) from exc
return
raise ValueError("schema必须是BaseModel的子类或JsonSchema")
self.schema = schema
elif issubclass(schema, BaseModel):
try:
json_schema = self._generate_schema_from_model(schema)
def get_schema_object(self) -> Dict[str, Any] | None:
"""获取内部包装中的对象级 JSON Schema。
self.schema = json_schema
except Exception as e:
raise ValueError(
f"自动生成JSON Schema时发生异常请检查模型类{schema.__name__}的定义,详细信息:\n"
f"{schema.__name__}:\n"
) from e
else:
raise ValueError("schema必须是BaseModel的子类或JsonSchema")
else:
self.schema = None
def to_dict(self):
Returns:
Dict[str, Any] | None: 对象级 JSON Schema不存在时返回 `None`。
"""
将响应格式转换为字典
:return: 字典
if self.schema is None:
return None
schema_payload = self.schema.get("schema")
if isinstance(schema_payload, dict):
return cast(Dict[str, Any], deepcopy(schema_payload))
return None
def to_dict(self) -> Dict[str, Any]:
"""将响应格式转换为字典。
Returns:
Dict[str, Any]: 序列化后的响应格式字典。
"""
if self.schema:
return {
"format_type": self.format_type.value,
"schema": self.schema,
}
else:
return {
"format_type": self.format_type.value,
}
return {
"format_type": self.format_type.value,
}
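A minimal sketch of the Pydantic path (illustrative model): the class docstring becomes the schema description, `title` fields are stripped, and local `$defs`/`$ref` pairs are inlined before the wrapper is stored.

from pydantic import BaseModel

class Verdict(BaseModel):
    """Suitability verdict produced by the judge model."""
    suitable: bool
    reason: str

resp_format = RespFormat(RespFormatType.JSON_SCHEMA, schema=Verdict)
object_schema = resp_format.get_schema_object()  # object-level JSON Schema dict
payload = resp_format.to_dict()                  # {"format_type": "json_schema", "schema": {...}}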

View File

@@ -1,83 +1,368 @@
from copy import deepcopy
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Tuple, TypeAlias, cast
class ToolParamType(Enum):
class ToolParamType(str, Enum):
"""工具参数类型。"""
STRING = "string"
INTEGER = "integer"
NUMBER = "number"
FLOAT = "number"
BOOLEAN = "boolean"
ARRAY = "array"
OBJECT = "object"
LegacyToolParameterTuple = Tuple[str, ToolParamType, str, bool, List[str] | None]
"""旧版工具参数元组格式。"""
def normalize_tool_param_type(raw_value: ToolParamType | str | None) -> ToolParamType:
"""将任意输入值规范化为内部工具参数类型。
Args:
raw_value: 原始参数类型值。
Returns:
ToolParamType: 规范化后的参数类型。未知值会回退为 `STRING`。
"""
工具调用参数类型
if isinstance(raw_value, ToolParamType):
return raw_value
normalized_value = str(raw_value or "").strip().lower()
if normalized_value in {"integer", "int"}:
return ToolParamType.INTEGER
if normalized_value in {"number", "float"}:
return ToolParamType.NUMBER
if normalized_value in {"boolean", "bool"}:
return ToolParamType.BOOLEAN
if normalized_value == "array":
return ToolParamType.ARRAY
if normalized_value == "object":
return ToolParamType.OBJECT
return ToolParamType.STRING
def _is_object_schema(schema: Dict[str, Any]) -> bool:
"""判断输入字典是否已经是对象级 JSON Schema。
Args:
schema: 待判断的字典。
Returns:
bool: 为对象级 JSON Schema 时返回 `True`。
"""
STRING = "string" # 字符串
INTEGER = "integer" # 整型
FLOAT = "float" # 浮点型
BOOLEAN = "bool" # 布尔型
return schema.get("type") == "object" or "properties" in schema or "required" in schema
def _build_parameters_schema_from_property_map(property_map: Dict[str, Any]) -> Dict[str, Any]:
"""将属性映射转换为对象级 JSON Schema。
Args:
property_map: 仅包含属性定义的映射。
Returns:
Dict[str, Any]: 对象级 JSON Schema。
"""
required_names: List[str] = []
normalized_properties: Dict[str, Any] = {}
for property_name, property_schema in property_map.items():
if not isinstance(property_schema, dict):
continue
property_schema_copy = deepcopy(property_schema)
is_required = bool(property_schema_copy.pop("required", False))
if is_required:
required_names.append(str(property_name))
normalized_properties[str(property_name)] = property_schema_copy
parameters_schema: Dict[str, Any] = {
"type": "object",
"properties": normalized_properties,
}
if required_names:
parameters_schema["required"] = required_names
return parameters_schema
@dataclass(slots=True)
class ToolParam:
"""
工具调用参数
"""
"""工具参数定义。"""
def __init__(
self,
name: str
param_type: ToolParamType
description: str
required: bool
enum_values: List[Any] | None = None
items_schema: Dict[str, Any] | None = None
properties: Dict[str, Dict[str, Any]] | None = None
required_properties: List[str] = field(default_factory=list)
additional_properties: bool | Dict[str, Any] | None = None
default: Any = None
def __post_init__(self) -> None:
"""执行参数定义的基础校验。
Raises:
ValueError: 当参数名称或复杂类型定义不合法时抛出。
"""
if not self.name:
raise ValueError("参数名称不能为空")
if self.param_type == ToolParamType.ARRAY and self.items_schema is None:
raise ValueError("数组参数必须提供 items_schema")
if self.param_type == ToolParamType.OBJECT and self.properties is None:
self.properties = {}
@classmethod
def from_legacy_tuple(cls, parameter: LegacyToolParameterTuple) -> "ToolParam":
"""从旧版五元组参数定义构建工具参数。
Args:
parameter: 旧版参数元组。
Returns:
ToolParam: 规范化后的工具参数对象。
"""
return cls(
name=parameter[0],
param_type=parameter[1],
description=parameter[2],
required=parameter[3],
enum_values=parameter[4],
)
@classmethod
def from_dict(
cls,
name: str,
param_type: ToolParamType,
description: str,
required: bool,
enum_values: list[str] | None = None,
):
parameter_schema: Dict[str, Any],
*,
required: bool = False,
) -> "ToolParam":
"""从属性级 JSON Schema 或结构化参数字典构建工具参数。
Args:
name: 参数名称。
parameter_schema: 参数对应的 Schema 或结构化定义。
required: 参数是否必填。
Returns:
ToolParam: 规范化后的工具参数对象。
"""
初始化工具调用参数
不应直接修改ToolParam类而应使用ToolOptionBuilder类来构建对象
:param name: 参数名称
:param param_type: 参数类型
:param description: 参数描述
:param required: 是否必填
raw_required_properties = parameter_schema.get("required_properties")
if raw_required_properties is None and isinstance(parameter_schema.get("required"), list):
raw_required_properties = parameter_schema.get("required")
return cls(
name=name,
param_type=normalize_tool_param_type(parameter_schema.get("param_type") or parameter_schema.get("type")),
description=str(parameter_schema.get("description", "") or ""),
required=required,
enum_values=deepcopy(parameter_schema.get("enum_values") or parameter_schema.get("enum")),
items_schema=deepcopy(parameter_schema.get("items_schema") or parameter_schema.get("items")),
properties=deepcopy(parameter_schema.get("properties")),
required_properties=list(raw_required_properties or []),
additional_properties=deepcopy(
parameter_schema["additional_properties"]
if "additional_properties" in parameter_schema
else parameter_schema.get("additionalProperties")
),
default=deepcopy(parameter_schema.get("default")),
)
def to_json_schema(self) -> Dict[str, Any]:
"""将参数定义转换为 JSON Schema。
Returns:
Dict[str, Any]: 参数对应的 JSON Schema 片段。
"""
self.name: str = name
self.param_type: ToolParamType = param_type
self.description: str = description
self.required: bool = required
self.enum_values: list[str] | None = enum_values
schema: Dict[str, Any] = {
"type": self.param_type.value,
"description": self.description,
}
if self.enum_values:
schema["enum"] = list(self.enum_values)
if self.default is not None:
schema["default"] = deepcopy(self.default)
if self.param_type == ToolParamType.ARRAY and self.items_schema is not None:
schema["items"] = deepcopy(self.items_schema)
if self.param_type == ToolParamType.OBJECT:
schema["properties"] = deepcopy(self.properties or {})
if self.required_properties:
schema["required"] = list(self.required_properties)
if self.additional_properties is not None:
schema["additionalProperties"] = deepcopy(self.additional_properties)
return schema
@dataclass(slots=True)
class ToolOption:
"""
工具调用项
"""
"""工具定义。"""
def __init__(
self,
name: str,
description: str,
params: list[ToolParam] | None = None,
):
name: str
description: str
params: List[ToolParam] | None = None
parameters_schema_override: Dict[str, Any] | None = None
def __post_init__(self) -> None:
"""执行工具定义的基础校验。
Raises:
ValueError: 当工具名称、描述或参数 Schema 不合法时抛出。
"""
初始化工具调用项
不应直接修改ToolOption类而应使用ToolOptionBuilder类来构建对象
:param name: 工具名称
:param description: 工具描述
:param params: 工具参数列表
if not self.name:
raise ValueError("工具名称不能为空")
if not self.description:
raise ValueError("工具描述不能为空")
if self.parameters_schema_override is not None:
schema_type = self.parameters_schema_override.get("type")
if schema_type != "object":
raise ValueError("工具参数 Schema 必须是 object 类型")
@classmethod
def from_definition(cls, definition: Dict[str, Any]) -> "ToolOption":
"""从任意支持的工具定义字典构建内部工具对象。
支持以下输入形状:
- `{"name", "description", "parameters_schema"}`
- `{"name", "description", "parameters"}`
- OpenAI function tool`{"type": "function", "function": {...}}`
- 仅属性映射的对象参数定义:`{"query": {"type": "string"}}`
Args:
definition: 原始工具定义字典。
Returns:
ToolOption: 规范化后的工具定义对象。
Raises:
ValueError: 当工具定义缺少必要字段时抛出。
"""
self.name: str = name
self.description: str = description
self.params: list[ToolParam] | None = params
if definition.get("type") == "function" and isinstance(definition.get("function"), dict):
function_definition = cast(Dict[str, Any], definition["function"])
return cls.from_definition(
{
"name": function_definition.get("name", ""),
"description": function_definition.get("description", ""),
"parameters_schema": function_definition.get("parameters"),
}
)
name = str(definition.get("name", "") or "").strip()
description = str(definition.get("description", "") or "").strip()
if not name:
raise ValueError("工具定义缺少 name")
if not description:
description = f"工具 {name}"
parameters_schema = definition.get("parameters_schema")
if isinstance(parameters_schema, dict):
normalized_schema = deepcopy(parameters_schema)
if not _is_object_schema(normalized_schema):
normalized_schema = _build_parameters_schema_from_property_map(normalized_schema)
return cls(
name=name,
description=description,
params=None,
parameters_schema_override=normalized_schema,
)
raw_parameters = definition.get("parameters")
if isinstance(raw_parameters, dict):
normalized_schema = deepcopy(raw_parameters)
if not _is_object_schema(normalized_schema):
normalized_schema = _build_parameters_schema_from_property_map(normalized_schema)
return cls(
name=name,
description=description,
params=None,
parameters_schema_override=normalized_schema,
)
if isinstance(raw_parameters, list):
params: List[ToolParam] = []
for raw_parameter in raw_parameters:
if isinstance(raw_parameter, tuple) and len(raw_parameter) == 5:
params.append(ToolParam.from_legacy_tuple(raw_parameter))
continue
if isinstance(raw_parameter, dict):
parameter_name = str(raw_parameter.get("name", "") or "").strip()
if not parameter_name:
continue
params.append(
ToolParam.from_dict(
parameter_name,
raw_parameter,
required=bool(raw_parameter.get("required", False)),
)
)
return cls(
name=name,
description=description,
params=params or None,
parameters_schema_override=None,
)
return cls(name=name, description=description, params=None, parameters_schema_override=None)
@property
def parameters_schema(self) -> Dict[str, Any] | None:
"""获取工具参数的对象级 JSON Schema。
Returns:
Dict[str, Any] | None: 工具参数 Schema。无参数工具时返回 `None`。
"""
if self.parameters_schema_override is not None:
return deepcopy(self.parameters_schema_override)
if not self.params:
return None
return {
"type": "object",
"properties": {param.name: param.to_json_schema() for param in self.params},
"required": [param.name for param in self.params if param.required],
}
def to_openai_function_schema(self) -> Dict[str, Any]:
"""转换为 OpenAI function calling 结构。
Returns:
Dict[str, Any]: OpenAI 兼容的工具定义。
"""
function_schema: Dict[str, Any] = {
"name": self.name,
"description": self.description,
}
if self.parameters_schema is not None:
function_schema["parameters"] = self.parameters_schema
return {
"type": "function",
"function": function_schema,
}
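A round-trip sketch (illustrative tool name) showing that an OpenAI-style function tool normalizes into `ToolOption` and converts back without loss:

tool = ToolOption.from_definition(
    {
        "type": "function",
        "function": {
            "name": "search_web",
            "description": "Search the web for a query.",
            "parameters": {
                "type": "object",
                "properties": {"query": {"type": "string"}},
                "required": ["query"],
            },
        },
    }
)
assert tool.parameters_schema == {
    "type": "object",
    "properties": {"query": {"type": "string"}},
    "required": ["query"],
}
openai_payload = tool.to_openai_function_schema()  # back to the OpenAI shape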
class ToolOptionBuilder:
"""
工具调用项构建器
"""
"""工具定义构建器。"""
def __init__(self):
def __init__(self) -> None:
"""初始化构建器。"""
self.__name: str = ""
self.__description: str = ""
self.__params: list[ToolParam] = []
self.__params: List[ToolParam] = []
self.__parameters_schema_override: Dict[str, Any] | None = None
def set_name(self, name: str) -> "ToolOptionBuilder":
"""
设置工具名称
:param name: 工具名称
:return: ToolBuilder实例
"""设置工具名称。
Args:
name: 工具名称。
Returns:
ToolOptionBuilder: 当前构建器实例。
Raises:
ValueError: 当名称为空时抛出。
"""
if not name:
raise ValueError("工具名称不能为空")
@@ -85,35 +370,76 @@ class ToolOptionBuilder:
return self
def set_description(self, description: str) -> "ToolOptionBuilder":
"""
设置工具描述
:param description: 工具描述
:return: ToolBuilder实例
"""设置工具描述。
Args:
description: 工具描述。
Returns:
ToolOptionBuilder: 当前构建器实例。
Raises:
ValueError: 当描述为空时抛出。
"""
if not description:
raise ValueError("工具描述不能为空")
self.__description = description
return self
def set_parameters_schema(self, schema: Dict[str, Any]) -> "ToolOptionBuilder":
"""直接设置完整的参数对象 Schema。
Args:
schema: 完整的对象级 JSON Schema。
Returns:
ToolOptionBuilder: 当前构建器实例。
Raises:
ValueError: 当 schema 不是 object 类型时抛出。
"""
if schema.get("type") != "object":
raise ValueError("工具参数 Schema 必须是 object 类型")
self.__parameters_schema_override = deepcopy(schema)
self.__params.clear()
return self
def add_param(
self,
name: str,
param_type: ToolParamType,
description: str,
required: bool = False,
enum_values: list[str] | None = None,
enum_values: List[Any] | None = None,
*,
items_schema: Dict[str, Any] | None = None,
properties: Dict[str, Dict[str, Any]] | None = None,
required_properties: List[str] | None = None,
additional_properties: bool | Dict[str, Any] | None = None,
default: Any = None,
) -> "ToolOptionBuilder":
"""
添加工具参数
:param name: 参数名称
:param param_type: 参数类型
:param description: 参数描述
:param required: 是否必填默认为False
:return: ToolBuilder实例
"""
if not name or not description:
raise ValueError("参数名称/描述不能为空")
"""添加一个参数定义。
Args:
name: 参数名称。
param_type: 参数类型。
description: 参数描述。
required: 参数是否必填。
enum_values: 可选的枚举值列表。
items_schema: 数组参数的元素 Schema。
properties: 对象参数的属性定义。
required_properties: 对象参数内部的必填字段。
additional_properties: 对象参数是否允许额外字段。
default: 参数默认值。
Returns:
ToolOptionBuilder: 当前构建器实例。
Raises:
ValueError: 当构建器已经设置完整 Schema 时抛出。
"""
if self.__parameters_schema_override is not None:
raise ValueError("已设置完整参数 Schema不能再逐项添加参数")
self.__params.append(
ToolParam(
name=name,
@@ -121,43 +447,83 @@ class ToolOptionBuilder:
description=description,
required=required,
enum_values=enum_values,
items_schema=deepcopy(items_schema),
properties=deepcopy(properties),
required_properties=list(required_properties or []),
additional_properties=deepcopy(additional_properties),
default=deepcopy(default),
)
)
return self
def build(self):
"""
构建工具调用项
:return: 工具调用项
"""
if self.__name == "" or self.__description == "":
raise ValueError("工具名称/描述不能为空")
def build(self) -> ToolOption:
"""构建工具定义。
Returns:
ToolOption: 构建完成的工具定义。
Raises:
ValueError: 当工具名称或描述缺失时抛出。
"""
if not self.__name or not self.__description:
raise ValueError("工具名称和描述不能为空")
return ToolOption(
name=self.__name,
description=self.__description,
params=None if len(self.__params) == 0 else self.__params,
params=None if not self.__params else list(self.__params),
parameters_schema_override=deepcopy(self.__parameters_schema_override),
)
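A builder sketch with the new complex-type parameters (illustrative tool): array params must carry an `items_schema`, which `ToolParam.__post_init__` enforces.

tool = (
    ToolOptionBuilder()
    .set_name("fetch_page")
    .set_description("Fetch a web page and return its text.")
    .add_param("url", ToolParamType.STRING, "Target URL.", required=True)
    .add_param(
        "tags",
        ToolParamType.ARRAY,
        "Optional tag filters.",
        items_schema={"type": "string"},
    )
    .build()
)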
class ToolCall:
"""
来自模型反馈的工具调用
"""
ToolDefinitionInput: TypeAlias = ToolOption | Dict[str, Any]
"""统一的工具定义输入类型。"""
def __init__(
self,
call_id: str,
func_name: str,
args: dict | None = None,
):
def normalize_tool_option(tool_definition: ToolDefinitionInput) -> ToolOption:
"""将任意支持的工具输入规范化为内部 `ToolOption`。
Args:
tool_definition: 原始工具定义输入。
Returns:
ToolOption: 规范化后的工具定义对象。
"""
if isinstance(tool_definition, ToolOption):
return tool_definition
return ToolOption.from_definition(tool_definition)
def normalize_tool_options(
tool_definitions: List[ToolDefinitionInput] | None,
) -> List[ToolOption] | None:
"""批量规范化工具定义列表。
Args:
tool_definitions: 原始工具定义列表。
Returns:
List[ToolOption] | None: 规范化后的工具列表;输入为空时返回 `None`。
"""
if not tool_definitions:
return None
return [normalize_tool_option(tool_definition) for tool_definition in tool_definitions]
@dataclass(slots=True)
class ToolCall:
"""来自模型输出的工具调用。"""
call_id: str
func_name: str
args: Dict[str, Any] | None = None
def __post_init__(self) -> None:
"""执行工具调用的基础校验。
Raises:
ValueError: 当工具调用标识或函数名缺失时抛出。
"""
初始化工具调用
:param call_id: 工具调用ID
:param func_name: 要调用的函数名称
:param args: 工具调用参数
"""
self.call_id: str = call_id
self.func_name: str = func_name
self.args: dict | None = args
if not self.call_id:
raise ValueError("工具调用 ID 不能为空")
if not self.func_name:
raise ValueError("工具函数名称不能为空")

File diff suppressed because it is too large

View File

@@ -65,8 +65,8 @@ class BufferCLI:
self._mcp_manager: Optional[MCPManager] = None
self._init_llm()
def _init_llm(self):
"""Initialize the LLM service from the main project config."""
def _init_llm(self) -> None:
"""从主项目配置初始化 LLM 服务。"""
thinking_env = os.getenv("ENABLE_THINKING", "").strip().lower()
enable_thinking: Optional[bool] = True if thinking_env == "true" else False if thinking_env == "false" else None
@@ -77,7 +77,7 @@ class BufferCLI:
enable_thinking=enable_thinking,
)
model_name = self.llm_service._model_name
model_name = self.llm_service.get_current_model_name()
console.print(f"[success][OK] LLM service initialized[/success] [muted](model: {model_name})[/muted]")
def _build_tool_context(self) -> ToolHandlerContext:

View File

@@ -1,6 +1,6 @@
"""
MaiSaka LLM 服务 - 使用主项目 LLM 系统
将主项目的 LLMRequest 适配为 MaiSaka 需要的接口
"""MaiSaka LLM 服务。
该模块基于主项目服务层封装 MaiSaka 所需的对话与工具调用接口
"""
from base64 import b64decode
@@ -8,7 +8,7 @@ from dataclasses import dataclass
from datetime import datetime
from io import BytesIO
from time import perf_counter
from typing import Any, List, Optional
from typing import Any, Dict, List, Optional
import asyncio
import random
@@ -23,9 +23,16 @@ from src.common.data_models.mai_message_data_model import MaiMessage
from src.common.logger import get_logger
from src.common.prompt_i18n import load_prompt
from src.config.config import config_manager, global_config
from src.llm_models.model_client.base_client import BaseClient
from src.llm_models.payload_content.message import Message, MessageBuilder, RoleType
from src.llm_models.payload_content.tool_option import ToolCall, ToolOption, ToolParamType
from src.llm_models.utils_model import LLMRequest
from src.llm_models.payload_content.tool_option import (
ToolCall,
ToolDefinitionInput,
ToolOption,
normalize_tool_options,
)
from src.common.data_models.llm_service_data_models import LLMGenerationOptions
from src.services.llm_service import LLMServiceClient
from . import config
from .config import console
@@ -36,17 +43,15 @@ from .message_adapter import (
get_message_kind,
get_message_role,
get_message_text,
get_tool_call_id,
get_tool_calls,
remove_last_perception,
to_llm_message,
)
logger = get_logger("maisaka_llm")
@dataclass
@dataclass(slots=True)
class ChatResponse:
"""LLM 对话循环单步响应"""
"""LLM 对话循环单步响应"""
content: Optional[str]
tool_calls: List[ToolCall]
@@ -65,38 +70,34 @@ class MaiSakaLLMService:
temperature: float = 0.5,
max_tokens: int = 2048,
enable_thinking: Optional[bool] = None,
):
"""
初始化 LLM 服务
) -> None:
"""初始化 MaiSaka LLM 服务。
参数仅为兼容性保留,实际使用主项目配置
Args:
api_key: 兼容旧接口保留的参数,当前不使用。
base_url: 兼容旧接口保留的参数,当前不使用。
model: 兼容旧接口保留的参数,当前不使用。
chat_system_prompt: 可选的系统提示词覆盖值。
temperature: 默认温度参数。
max_tokens: 默认最大输出 token 数。
enable_thinking: 是否启用思考模式。
"""
del api_key, base_url, model
self._temperature = temperature
self._max_tokens = max_tokens
self._enable_thinking = enable_thinking
self._extra_tools: List[dict] = []
self._extra_tools: List[ToolOption] = []
self._prompts_loaded = False
self._prompt_load_lock = asyncio.Lock()
# 获取主项目模型配置
try:
model_config = config_manager.get_model_config()
self._model_configs = model_config.model_task_config
except Exception:
# 如果配置加载失败,使用默认配置
from src.config.model_configs import ModelTaskConfig
self._model_configs = ModelTaskConfig()
logger.warning("无法加载主项目模型配置,使用默认配置")
# 初始化 LLMRequest 实例(只使用 tool_use 和 replyer
self._llm_tool_use = LLMRequest(model_set=self._model_configs.tool_use, request_type="maisaka_tool_use")
# 主对话也使用 tool_use 模型(因为需要工具调用支持)
self._llm_planner = LLMRequest(model_set=self._model_configs.planner, request_type="maisaka_planner")
# 初始化服务层 LLM 门面(按任务名实时解析配置,确保热重载生效)
self._llm_tool_use = LLMServiceClient(task_name="tool_use", request_type="maisaka_tool_use")
# 主对话也使用 planner 模型
self._llm_planner = LLMServiceClient(task_name="planner", request_type="maisaka_planner")
self._llm_chat = self._llm_planner
self._llm_utils = self._llm_tool_use
# 回复生成使用 replyer 模型
self._llm_replyer = LLMRequest(model_set=self._model_configs.replyer, request_type="maisaka_replyer")
self._llm_replyer = LLMServiceClient(task_name="replyer", request_type="maisaka_replyer")
# 尝试修复数据库 schema忽略错误
self._try_fix_database_schema()
@@ -111,15 +112,30 @@ class MaiSakaLLMService:
else:
self._chat_system_prompt = chat_system_prompt
self._model_name = (
self._model_configs.planner.model_list[0] if self._model_configs.planner.model_list else "未配置"
)
# 子模块提示词同样采用懒加载
self._emotion_prompt: Optional[str] = None
self._cognition_prompt: Optional[str] = None
def get_current_model_name(self) -> str:
"""获取当前 Maisaka 对话主模型名称。
Returns:
str: 当前 planner 任务的首选模型名;未配置时返回 ``未配置``。
"""
try:
model_task_config = config_manager.get_model_config().model_task_config
if model_task_config.planner.model_list:
return model_task_config.planner.model_list[0]
except Exception as exc:
logger.warning(f"获取当前 Maisaka 模型名称失败: {exc}")
return "未配置"
def _try_fix_database_schema(self) -> None:
"""尝试修复数据库 schema,添加缺失的列"""
"""尝试修复数据库 schema
Returns:
None: 该方法仅执行数据库修复副作用。
"""
try:
from src.common.database.database_client import get_db_session
from sqlalchemy import text
@@ -139,7 +155,11 @@ class MaiSakaLLMService:
pass
def _build_personality_prompt(self) -> str:
"""构建人设信息,参考 replyer 的做法"""
"""构建当前人设提示词。
Returns:
str: 最终用于系统提示词的人设描述。
"""
try:
bot_name = global_config.bot.nickname
if global_config.bot.alias_names:
@@ -169,60 +189,21 @@ class MaiSakaLLMService:
# 返回默认人设
return "你的名字是麦麦你是一个活泼可爱的AI助手。"
def set_extra_tools(self, tools: List[dict]) -> None:
"""设置额外工具定义(如 MCP 工具)"""
self._extra_tools = [self._normalize_extra_tool(tool) for tool in tools]
logger.info(f"Normalized {len(self._extra_tools)} extra tool(s) for Maisaka")
def set_extra_tools(self, tools: List[ToolDefinitionInput]) -> None:
"""设置额外工具定义
@staticmethod
def _json_type_to_tool_param_type(json_type: str) -> ToolParamType:
normalized = (json_type or "").lower()
if normalized == "integer":
return ToolParamType.INTEGER
if normalized == "number":
return ToolParamType.FLOAT
if normalized == "boolean":
return ToolParamType.BOOLEAN
return ToolParamType.STRING
@classmethod
def _normalize_extra_tool(cls, tool: dict) -> dict:
"""Normalize external/OpenAI-style tool definitions into the internal tool schema."""
if "name" in tool and "description" in tool:
return tool
if tool.get("type") != "function":
return tool
function_info = tool.get("function", {})
parameters_schema = function_info.get("parameters", {}) or {}
required_names = set(parameters_schema.get("required", []) or [])
properties = parameters_schema.get("properties", {}) or {}
parameters: list[tuple[str, ToolParamType, str, bool, list[str] | None]] = []
for param_name, param_schema in properties.items():
if not isinstance(param_schema, dict):
continue
enum_values = param_schema.get("enum")
normalized_enum = [str(value) for value in enum_values] if isinstance(enum_values, list) else None
parameters.append(
(
str(param_name),
cls._json_type_to_tool_param_type(str(param_schema.get("type", "string"))),
str(param_schema.get("description", "")),
param_name in required_names,
normalized_enum,
)
)
return {
"name": str(function_info.get("name", "")),
"description": str(function_info.get("description", "")),
"parameters": parameters,
}
Args:
tools: 外部传入的工具定义列表,例如 MCP 暴露的 OpenAI-compatible 工具。
"""
self._extra_tools = normalize_tool_options(tools) or []
logger.info(f"已为 Maisaka 加载 {len(self._extra_tools)} 个额外工具")
async def _ensure_prompts_loaded(self) -> None:
"""异步懒加载提示词,避免在运行中的事件循环里同步渲染 prompt。"""
"""异步懒加载提示词
Returns:
None: 该方法仅刷新内部提示词缓存。
"""
if self._prompts_loaded:
return
@@ -260,7 +241,14 @@ class MaiSakaLLMService:
@staticmethod
def _get_role_badge_style(role: str) -> str:
"""为不同 role 返回不同的标签样式。"""
"""为不同角色返回终端标签样式。
Args:
role: 消息角色名称。
Returns:
str: Rich 可识别的样式字符串。
"""
if role == "system":
return "bold white on blue"
if role == "user":
@@ -273,7 +261,14 @@ class MaiSakaLLMService:
@staticmethod
def _build_terminal_image_preview(image_base64: str) -> Optional[str]:
"""Build a low-resolution ASCII preview for terminals without inline-image support."""
"""构建终端 ASCII 图片预览。
Args:
image_base64: 图片的 Base64 数据。
Returns:
Optional[str]: 可渲染的 ASCII 预览文本;失败时返回 `None`。
"""
ascii_chars = " .:-=+*#%@"
try:
@@ -291,7 +286,7 @@ class MaiSakaLLMService:
except Exception:
return None
rows: list[str] = []
rows: List[str] = []
for row_index in range(preview_height):
row_pixels = pixels[row_index * preview_width : (row_index + 1) * preview_width]
row = "".join(ascii_chars[min(len(ascii_chars) - 1, pixel * len(ascii_chars) // 256)] for pixel in row_pixels)
@@ -301,12 +296,19 @@ class MaiSakaLLMService:
@staticmethod
def _render_message_content(content: Any) -> object:
"""消息内容转成适合 Rich 输出的 renderable。"""
"""消息内容转换为 Rich 可渲染对象。
Args:
content: 原始消息内容。
Returns:
object: Rich 可渲染对象。
"""
if isinstance(content, str):
return Text(content)
if isinstance(content, list):
parts: list[object] = []
parts: List[object] = []
for item in content:
if isinstance(item, str):
parts.append(Text(item))
@@ -316,7 +318,7 @@ class MaiSakaLLMService:
if isinstance(image_format, str) and isinstance(image_base64, str):
approx_size = max(0, len(image_base64) * 3 // 4)
size_text = f"{approx_size / 1024:.1f} KB" if approx_size >= 1024 else f"{approx_size} B"
preview_parts: list[object] = [
preview_parts: List[object] = [
Text(f"image/{image_format} {size_text}\nbase64 omitted", style="magenta")
]
if config.TERMINAL_IMAGE_PREVIEW:
@@ -343,8 +345,15 @@ class MaiSakaLLMService:
return Pretty(content, expand_all=True)
@staticmethod
def _format_tool_call_for_display(tool_call: Any) -> dict[str, Any]:
""" tool call 转成适合 CLI 展示结构。"""
def _format_tool_call_for_display(tool_call: Any) -> Dict[str, Any]:
"""工具调用转换为 CLI 展示结构。
Args:
tool_call: 原始工具调用对象或字典。
Returns:
Dict[str, Any]: 统一后的展示字典。
"""
if isinstance(tool_call, dict):
function_info = tool_call.get("function", {})
return {
@@ -360,7 +369,16 @@ class MaiSakaLLMService:
}
def _render_tool_call_panel(self, tool_call: Any, index: int, parent_index: int) -> Panel:
"""Render assistant tool calls as standalone cards."""
"""渲染单个工具调用面板。
Args:
tool_call: 原始工具调用对象或字典。
index: 当前工具调用在父消息中的序号。
parent_index: 父消息在消息列表中的序号。
Returns:
Panel: 可直接打印的工具调用面板。
"""
title = Text.assemble(
Text(" TOOL CALL ", style="bold white on magenta"),
Text(f" #{parent_index}.{index}", style="muted"),
@@ -373,16 +391,22 @@ class MaiSakaLLMService:
)
def _render_message_panel(self, message: Any, index: int) -> Panel:
"""渲染主循环 prompt 中的一条消息。"""
"""渲染主循环 Prompt 中的一条消息。
Args:
message: 原始消息对象或字典。
index: 当前消息序号。
Returns:
Panel: 可直接打印的消息面板。
"""
if isinstance(message, dict):
raw_role = message.get("role", "unknown")
content = message.get("content")
tool_calls = message.get("tool_calls")
tool_call_id = message.get("tool_call_id")
else:
raw_role = getattr(message, "role", "unknown")
content = getattr(message, "content", None)
tool_calls = getattr(message, "tool_calls", None)
tool_call_id = getattr(message, "tool_call_id", None)
role = raw_role.value if hasattr(raw_role, "value") else str(raw_role)
@@ -391,7 +415,7 @@ class MaiSakaLLMService:
Text(f" #{index}", style="muted"),
)
parts: list[object] = []
parts: List[object] = []
if content not in (None, "", []):
parts.append(Text(" message ", style="bold cyan"))
parts.append(self._render_message_content(content))
@@ -415,30 +439,27 @@ class MaiSakaLLMService:
padding=(0, 1),
)
@staticmethod
def _tool_option_to_dict(tool: "ToolOption") -> dict:
"""将 ToolOption 对象转换为主项目期望的 dict 格式
async def chat_loop_step(self, chat_history: List[MaiMessage]) -> ChatResponse:
"""执行主对话循环的一步。
主项目的 _build_tool_options() 期望的格式:
{
"name": str,
"description": str,
"parameters": List[Tuple[name, ToolParamType, description, required, enum_values]]
}
Args:
chat_history: 当前对话历史。
Returns:
ChatResponse: 本轮对话生成结果。
"""
params = []
if tool.params:
for param in tool.params:
params.append((param.name, param.param_type, param.description, param.required, param.enum_values))
return {"name": tool.name, "description": tool.description, "parameters": params}
async def chat_loop_step(self, chat_history: list[MaiMessage]) -> ChatResponse:
"""执行对话循环的一步 - 使用 tool_use 模型"""
await self._ensure_prompts_loaded()
def message_factory(client) -> list[Message]:
"""将 MaiSaka 的 chat_history 转换为主项目的 Message 格式"""
messages: list[Message] = []
def message_factory(_client: BaseClient) -> List[Message]:
"""将 MaiSaka 对话历史转换为内部消息列表。
Args:
_client: 当前底层客户端实例。
Returns:
List[Message]: 规范化后的消息列表。
"""
messages: List[Message] = []
# 首先添加系统提示词
system_msg = MessageBuilder().set_role(RoleType.System)
@@ -454,15 +475,13 @@ class MaiSakaLLMService:
return messages
# 调用 LLM使用带消息的接口
# 合并内置工具和额外工具(将 ToolOption 对象转换为 dict
all_tools = [self._tool_option_to_dict(t) for t in get_builtin_tools()] + (
self._extra_tools if self._extra_tools else []
)
# 合并内置工具和额外工具,统一交给底层规范化流程处理。
all_tools = [*get_builtin_tools(), *self._extra_tools]
# 打印消息列表
built_messages = message_factory(None)
ordered_panels: list[Panel] = []
ordered_panels: List[Panel] = []
for index, msg in enumerate(built_messages, start=1):
ordered_panels.append(self._render_message_panel(msg, index))
tool_calls = getattr(msg, "tool_calls", None)
@@ -483,13 +502,18 @@ class MaiSakaLLMService:
request_started_at = perf_counter()
logger.info("chat_loop_step calling planner model generate_response_with_message_async")
response, (reasoning, model, tool_calls) = await self._llm_chat.generate_response_with_message_async(
logger.info("chat_loop_step calling planner model generate_response_with_messages")
generation_result = await self._llm_chat.generate_response_with_messages(
message_factory=message_factory,
tools=all_tools if all_tools else None,
temperature=self._temperature,
max_tokens=self._max_tokens,
options=LLMGenerationOptions(
tool_options=all_tools if all_tools else None,
temperature=self._temperature,
max_tokens=self._max_tokens,
),
)
response = generation_result.response
model = generation_result.model_name
tool_calls = generation_result.tool_calls
elapsed = perf_counter() - request_started_at
logger.info(
f"chat_loop_step planner model returned in {elapsed:.2f}s "
@@ -509,8 +533,15 @@ class MaiSakaLLMService:
raw_message=raw_message,
)
def _filter_for_api(self, chat_history: list[MaiMessage]) -> str:
"""过滤对话历史为 API 格式"""
def _filter_for_api(self, chat_history: List[MaiMessage]) -> str:
"""对话历史过滤为简单文本格式
Args:
chat_history: 当前对话历史。
Returns:
str: 过滤后的文本上下文。
"""
parts = []
for msg in chat_history:
role = get_message_role(msg)
@@ -535,8 +566,15 @@ class MaiSakaLLMService:
return "\n\n".join(parts)
def build_chat_context(self, user_text: str) -> list[MaiMessage]:
"""构建对话上下文"""
def build_chat_context(self, user_text: str) -> List[MaiMessage]:
"""构建新的对话上下文
Args:
user_text: 用户输入文本。
Returns:
List[MaiMessage]: 初始对话上下文消息列表。
"""
return [
build_message(
role=RoleType.User.value,
@@ -547,8 +585,15 @@ class MaiSakaLLMService:
# ──────── 分析模块(使用 utils 模型) ────────
async def analyze_emotion(self, chat_history: list[MaiMessage]) -> str:
"""情绪分析 - 使用 utils 模型"""
async def analyze_emotion(self, chat_history: List[MaiMessage]) -> str:
"""执行情绪分析
Args:
chat_history: 当前对话历史。
Returns:
str: 情绪分析文本。
"""
await self._ensure_prompts_loaded()
filtered = [m for m in chat_history if get_message_kind(m) != "perception"]
recent = filtered[-10:] if len(filtered) > 10 else filtered
@@ -574,19 +619,26 @@ class MaiSakaLLMService:
print("=" * 60 + "\n")
try:
response, _ = await self._llm_utils.generate_response_async(
generation_result = await self._llm_utils.generate_response(
prompt=prompt,
temperature=0.3,
max_tokens=512,
options=LLMGenerationOptions(temperature=0.3, max_tokens=512),
)
response = generation_result.response
return response
except Exception as e:
logger.error(f"情绪分析 LLM 调用出错: {e}")
return ""
async def analyze_cognition(self, chat_history: list[MaiMessage]) -> str:
"""认知分析 - 使用 utils 模型"""
async def analyze_cognition(self, chat_history: List[MaiMessage]) -> str:
"""执行认知分析
Args:
chat_history: 当前对话历史。
Returns:
str: 认知分析文本。
"""
await self._ensure_prompts_loaded()
filtered = [m for m in chat_history if get_message_kind(m) != "perception"]
recent = filtered[-10:] if len(filtered) > 10 else filtered
@@ -612,19 +664,27 @@ class MaiSakaLLMService:
print("=" * 60 + "\n")
try:
response, _ = await self._llm_utils.generate_response_async(
generation_result = await self._llm_utils.generate_response(
prompt=prompt,
temperature=0.3,
max_tokens=512,
options=LLMGenerationOptions(temperature=0.3, max_tokens=512),
)
response = generation_result.response
return response
except Exception as e:
logger.error(f"认知分析 LLM 调用出错: {e}")
return ""
async def _removed_analyze_timing(self, chat_history: list[MaiMessage], timing_info: str) -> str:
"""时间分析 - 使用 utils 模型"""
async def _removed_analyze_timing(self, chat_history: List[MaiMessage], timing_info: str) -> str:
"""执行时间节奏分析。
Args:
chat_history: 当前对话历史。
timing_info: 外部传入的时间信息摘要。
Returns:
str: 时间分析文本。
"""
await self._ensure_prompts_loaded()
filtered = [
m
@@ -653,11 +713,11 @@ class MaiSakaLLMService:
print("=" * 60 + "\n")
try:
response, _ = await self._llm_utils.generate_response_async(
generation_result = await self._llm_utils.generate_response(
prompt=prompt,
temperature=0.3,
max_tokens=512,
options=LLMGenerationOptions(temperature=0.3, max_tokens=512),
)
response = generation_result.response
return response
except Exception as e:
@@ -666,10 +726,15 @@ class MaiSakaLLMService:
# ──────── 回复生成(使用 replyer 模型) ────────
async def generate_reply(self, reason: str, chat_history: list[MaiMessage]) -> str:
"""
生成回复 - 使用 replyer 模型
可供 Replyer 类直接调用
async def generate_reply(self, reason: str, chat_history: List[MaiMessage]) -> str:
"""生成最终回复文本。
Args:
reason: 当前轮次的内部想法或回复理由。
chat_history: 当前对话历史。
Returns:
str: 最终回复文本。
"""
await self._ensure_prompts_loaded()
from datetime import datetime
@@ -704,17 +769,12 @@ class MaiSakaLLMService:
print("=" * 60 + "\n")
try:
response, _ = await self._llm_replyer.generate_response_async(
generation_result = await self._llm_replyer.generate_response(
prompt=messages,
temperature=0.8,
max_tokens=512,
options=LLMGenerationOptions(temperature=0.8, max_tokens=512),
)
response = generation_result.response
return response.strip() if response else "..."
except Exception as e:
logger.error(f"回复生成 LLM 调用出错: {e}")
return "..."

View File

@@ -16,8 +16,9 @@ from json_repair import repair_json
from src.chat.message_receive.message import SessionMessage
from src.common.logger import get_logger
from src.config.config import model_config, global_config
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config
from src.common.data_models.llm_service_data_models import LLMGenerationOptions
from src.services.llm_service import LLMServiceClient
from src.services import message_service as message_api
from src.chat.utils.utils import is_bot_self
from src.person_info.person_info import Person
@@ -88,8 +89,8 @@ class ChatHistorySummarizer:
# 注意:批次加载需要异步查询消息,所以在 start() 中调用
# LLM请求器用于压缩聊天内容
self.summarizer_llm = LLMRequest(
model_set=model_config.model_task_config.utils, request_type="chat_history_summarizer"
self.summarizer_llm = LLMServiceClient(
task_name="utils", request_type="chat_history_summarizer"
)
# 后台循环相关
@@ -656,10 +657,11 @@ class ChatHistorySummarizer:
prompt = await prompt_manager.render_prompt(prompt_template)
try:
response, _ = await self.summarizer_llm.generate_response_async(
generation_result = await self.summarizer_llm.generate_response(
prompt=prompt,
temperature=0.3,
options=LLMGenerationOptions(temperature=0.3),
)
response = generation_result.response
logger.info(f"{self.log_prefix} 话题识别LLM Prompt: {prompt}")
logger.info(f"{self.log_prefix} 话题识别LLM Response: {response}")
@@ -812,7 +814,8 @@ class ChatHistorySummarizer:
prompt = await prompt_manager.render_prompt(prompt_template)
try:
response, _ = await self.summarizer_llm.generate_response_async(prompt=prompt)
generation_result = await self.summarizer_llm.generate_response(prompt=prompt)
response = generation_result.response
# 解析JSON响应
json_str = response.strip()

View File

@@ -5,7 +5,7 @@ import asyncio
from datetime import datetime
from typing import List, Dict, Any, Optional, Tuple, Callable
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.config.config import global_config
from src.prompt.prompt_manager import prompt_manager
from src.services import llm_service as llm_api
from sqlmodel import select, col
@@ -269,18 +269,18 @@ async def _react_agent_solve_question(
return messages
message_factory_fn: Callable[..., List[Message]] = _build_messages # pyright: ignore[reportGeneralTypeIssues]
(
success,
response,
reasoning_content,
model_name,
tool_calls,
) = await llm_api.generate_with_model_with_tools_by_message_factory(
message_factory_fn, # type: ignore[arg-type]
model_config=model_config.model_task_config.tool_use,
tool_options=tool_definitions,
request_type="memory.react",
generation_result = await llm_api.generate(
llm_api.LLMServiceRequest(
task_name="tool_use",
request_type="memory.react",
message_factory=message_factory_fn, # type: ignore[arg-type]
tool_options=tool_definitions,
)
)
success = generation_result.success
response = generation_result.completion.response
reasoning_content = generation_result.completion.reasoning
tool_calls = generation_result.completion.tool_calls
# logger.info(
# f"ReAct Agent 第 {iteration + 1} 次迭代 模型: {model_name} ,调用工具数量: {len(tool_calls) if tool_calls else 0} ,调用工具响应: {response}"
@@ -679,18 +679,16 @@ async def _react_agent_solve_question(
evaluation_prompt_template.add_context("max_iterations", str(max_iterations))
evaluation_prompt = await prompt_manager.render_prompt(evaluation_prompt_template)
(
eval_success,
eval_response,
eval_reasoning_content,
eval_model_name,
eval_tool_calls,
) = await llm_api.generate_with_model_with_tools(
evaluation_prompt,
model_config=model_config.model_task_config.tool_use,
tool_options=[], # 最终评估阶段不提供工具
request_type="memory.react.final",
evaluation_result = await llm_api.generate(
llm_api.LLMServiceRequest(
task_name="tool_use",
request_type="memory.react.final",
prompt=evaluation_prompt,
tool_options=[],
)
)
eval_success = evaluation_result.success
eval_response = evaluation_result.completion.response
if not eval_success:
logger.error(f"ReAct Agent 最终评估阶段 LLM调用失败: {eval_response}")

View File

@@ -1,11 +1,12 @@
"""
工具注册系统
提供统一的工具注册和管理接口
"""工具注册系统。
提供统一的工具注册和管理接口
"""
from typing import List, Dict, Any, Optional, Callable, Awaitable
from typing import Any, Awaitable, Callable, Dict, List, Optional
from src.common.logger import get_logger
from src.llm_models.payload_content.tool_option import ToolParamType
from src.llm_models.payload_content.tool_option import ToolParamType, normalize_tool_option
logger = get_logger("memory_retrieval_tools")
@@ -14,16 +15,19 @@ class MemoryRetrievalTool:
"""记忆检索工具基类"""
def __init__(
self, name: str, description: str, parameters: List[Dict[str, Any]], execute_func: Callable[..., Awaitable[str]]
):
"""
初始化工具
self,
name: str,
description: str,
parameters: List[Dict[str, Any]],
execute_func: Callable[..., Awaitable[str]],
) -> None:
"""初始化工具。
Args:
name: 工具名称
description: 工具描述
parameters: 参数定义列表,格式:[{"name": "param_name", "type": "string", "description": "参数描述", "required": True}]
execute_func: 执行函数,必须是异步函数
name: 工具名称
description: 工具描述
parameters: 参数定义列表
execute_func: 执行函数,必须是异步函数
"""
self.name = name
self.description = description
@@ -44,20 +48,17 @@ class MemoryRetrievalTool:
params_str = "\n".join(param_descriptions) if param_descriptions else " 无参数"
return f"{self.name}({', '.join([p['name'] for p in self.parameters])}): {self.description}\n{params_str}"
async def execute(self, **kwargs) -> str:
"""执行工具"""
async def execute(self, **kwargs: Any) -> str:
"""执行工具"""
return await self.execute_func(**kwargs)
def get_tool_definition(self) -> Dict[str, Any]:
"""获取工具定义用于LLM function calling
"""获取规范化的工具定义。
Returns:
Dict[str, Any]: 工具定义字典,格式与BaseTool一致
格式: {"name": str, "description": str, "parameters": List[Tuple]}
Dict[str, Any]: 统一工具定义字典
"""
# 转换参数格式为元组列表格式与BaseTool一致
# 格式: [("param_name", ToolParamType, "description", required, enum_values)]
param_tuples = []
legacy_parameters: list[tuple[str, ToolParamType, str, bool, list[str] | None]] = []
for param in self.parameters:
param_name = param.get("name", "")
@@ -77,20 +78,27 @@ class MemoryRetrievalTool:
}
param_type = type_mapping.get(param_type_str, ToolParamType.STRING)
# 构建参数元组
param_tuple = (param_name, param_type, param_desc, is_required, enum_values)
param_tuples.append(param_tuple)
legacy_parameters.append((param_name, param_type, param_desc, is_required, enum_values))
# 构建工具定义,格式与BaseTool.get_tool_definition()一致
tool_def = {"name": self.name, "description": self.description, "parameters": param_tuples}
return tool_def
normalized_option = normalize_tool_option(
{
"name": self.name,
"description": self.description,
"parameters": legacy_parameters,
}
)
return {
"name": normalized_option.name,
"description": normalized_option.description,
"parameters_schema": normalized_option.parameters_schema,
}
class MemoryRetrievalToolRegistry:
"""工具注册器"""
def __init__(self):
def __init__(self) -> None:
"""初始化工具注册器。"""
self.tools: Dict[str, MemoryRetrievalTool] = {}
def register_tool(self, tool: MemoryRetrievalTool) -> None:
@@ -137,15 +145,18 @@ _tool_registry = MemoryRetrievalToolRegistry()
def register_memory_retrieval_tool(
name: str, description: str, parameters: List[Dict[str, Any]], execute_func: Callable[..., Awaitable[str]]
name: str,
description: str,
parameters: List[Dict[str, Any]],
execute_func: Callable[..., Awaitable[str]],
) -> None:
"""注册记忆检索工具的便捷函数
"""注册记忆检索工具的便捷函数
Args:
name: 工具名称
description: 工具描述
parameters: 参数定义列表
execute_func: 执行函数
name: 工具名称
description: 工具描述
parameters: 参数定义列表
execute_func: 执行函数
"""
tool = MemoryRetrievalTool(name, description, parameters, execute_func)
_tool_registry.register_tool(tool)
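A sketch of registering a tool through the convenience function above; the legacy dict-style parameter list is turned into tuples and then routed through `normalize_tool_option()` whenever `get_tool_definition()` is called:

async def _search_memory(query: str) -> str:
    return f"memories matching {query!r}"

register_memory_retrieval_tool(
    name="search_memory",
    description="Search stored memories by keyword",
    parameters=[
        {"name": "query", "type": "string", "description": "Search keyword", "required": True},
    ],
    execute_func=_search_memory,
)
# _tool_registry.tools["search_memory"].get_tool_definition() now returns
# {"name": ..., "description": ..., "parameters_schema": <object-level JSON Schema>}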

View File

@@ -17,14 +17,14 @@ from src.common.data_models.person_info_data_model import dump_group_cardname_re
from src.common.database.database import get_db_session
from src.common.database.database_model import PersonInfo
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config
from src.services.llm_service import LLMServiceClient
logger = get_logger("person_info")
relation_selection_model = LLMRequest(
model_set=model_config.model_task_config.tool_use, request_type="relation_selection"
relation_selection_model = LLMServiceClient(
task_name="tool_use", request_type="relation_selection"
)
@@ -578,7 +578,8 @@ class Person:
<分类1><分类2><分类3>......
如果没有相关的分类,请输出<none>"""
response, _ = await relation_selection_model.generate_response_async(prompt)
generation_result = await relation_selection_model.generate_response(prompt)
response = generation_result.response
# print(prompt)
# print(response)
category_list = extract_categories_from_response(response)
@@ -600,7 +601,8 @@ class Person:
例如:
<分类1><分类2><分类3>......
如果没有相关的分类,请输出<none>"""
response, _ = await relation_selection_model.generate_response_async(prompt)
generation_result = await relation_selection_model.generate_response(prompt)
response = generation_result.response
# print(prompt)
# print(response)
category_list = extract_categories_from_response(response)
@@ -634,7 +636,9 @@ class Person:
class PersonInfoManager:
def __init__(self):
self.person_name_list = {}
self.qv_name_llm = LLMRequest(model_set=model_config.model_task_config.utils, request_type="relation.qv_name")
self.qv_name_llm = LLMServiceClient(
task_name="utils", request_type="relation.qv_name"
)
try:
with get_db_session() as _:
pass
@@ -737,7 +741,8 @@ class PersonInfoManager:
"nickname": "昵称",
"reason": "理由"
}"""
response, _ = await self.qv_name_llm.generate_response_async(qv_name_prompt)
generation_result = await self.qv_name_llm.generate_response(qv_name_prompt)
response = generation_result.response
# logger.info(f"取名提示词:{qv_name_prompt}\n取名回复{response}")
result = self._extract_json_from_text(response)

View File

@@ -1,33 +1,80 @@
from typing import Any, Dict
from typing import Any, Dict, List
from src.common.logger import get_logger
from src.config.config import global_config
from src.llm_models.payload_content.tool_option import ToolCall
logger = get_logger("plugin_runtime.integration")
def _get_nested_config_value(source: Any, key: str, default: Any = None) -> Any:
"""从嵌套对象或字典中读取配置值。
Args:
source: 配置对象或字典。
key: 以点号分隔的路径。
default: 未命中时返回的默认值。
Returns:
Any: 命中的值;读取失败时返回默认值。
"""
current = source
try:
for part in key.split("."):
if isinstance(current, dict) and part in current:
current = current[part]
elif hasattr(current, part):
continue
if hasattr(current, part):
current = getattr(current, part)
else:
raise KeyError(part)
continue
raise KeyError(part)
return current
except Exception:
return default
def _normalize_prompt_arg(prompt: Any) -> str | List[Dict[str, Any]]:
"""校验并规范化插件传入的提示参数。
Args:
prompt: 原始提示参数。
Returns:
str | List[Dict[str, Any]]: 规范化后的提示输入。
Raises:
ValueError: 提示参数缺失或结构不受支持时抛出。
"""
if isinstance(prompt, str):
if not prompt.strip():
raise ValueError("缺少必要参数 prompt")
return prompt
if isinstance(prompt, list) and prompt:
for index, prompt_message in enumerate(prompt, start=1):
if not isinstance(prompt_message, dict):
raise ValueError(f"prompt 第 {index} 项必须为字典")
return prompt
raise ValueError("缺少必要参数 prompt")
class RuntimeCoreCapabilityMixin:
"""插件运行时的核心能力混入。"""
async def _cap_send_text(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
"""向指定流发送文本消息。
Args:
plugin_id: 插件标识。
capability: 能力名称。
args: 能力调用参数。
Returns:
Any: 能力执行结果。
"""
del plugin_id, capability
from src.services import send_service as send_api
text: str = args.get("text", "")
stream_id: str = args.get("stream_id", "")
text = str(args.get("text", ""))
stream_id = str(args.get("stream_id", ""))
if not text or not stream_id:
return {"success": False, "error": "缺少必要参数 text 或 stream_id"}
@@ -35,20 +82,31 @@ class RuntimeCoreCapabilityMixin:
result = await send_api.text_to_stream(
text=text,
stream_id=stream_id,
typing=args.get("typing", False),
set_reply=args.get("set_reply", False),
storage_message=args.get("storage_message", True),
typing=bool(args.get("typing", False)),
set_reply=bool(args.get("set_reply", False)),
storage_message=bool(args.get("storage_message", True)),
)
return {"success": result}
except Exception as e:
logger.error(f"[cap.send.text] 执行失败: {e}", exc_info=True)
return {"success": False, "error": str(e)}
except Exception as exc:
logger.error(f"[cap.send.text] 执行失败: {exc}", exc_info=True)
return {"success": False, "error": str(exc)}
async def _cap_send_emoji(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
"""向指定流发送表情图片。
Args:
plugin_id: 插件标识。
capability: 能力名称。
args: 能力调用参数。
Returns:
Any: 能力执行结果。
"""
del plugin_id, capability
from src.services import send_service as send_api
emoji_base64: str = args.get("emoji_base64", "")
stream_id: str = args.get("stream_id", "")
emoji_base64 = str(args.get("emoji_base64", ""))
stream_id = str(args.get("stream_id", ""))
if not emoji_base64 or not stream_id:
return {"success": False, "error": "缺少必要参数 emoji_base64 或 stream_id"}
@@ -56,18 +114,29 @@ class RuntimeCoreCapabilityMixin:
result = await send_api.emoji_to_stream(
emoji_base64=emoji_base64,
stream_id=stream_id,
storage_message=args.get("storage_message", True),
storage_message=bool(args.get("storage_message", True)),
)
return {"success": result}
except Exception as e:
logger.error(f"[cap.send.emoji] 执行失败: {e}", exc_info=True)
return {"success": False, "error": str(e)}
except Exception as exc:
logger.error(f"[cap.send.emoji] 执行失败: {exc}", exc_info=True)
return {"success": False, "error": str(exc)}
async def _cap_send_image(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
"""向指定流发送图片。
Args:
plugin_id: 插件标识。
capability: 能力名称。
args: 能力调用参数。
Returns:
Any: 能力执行结果。
"""
del plugin_id, capability
from src.services import send_service as send_api
image_base64: str = args.get("image_base64", "")
stream_id: str = args.get("stream_id", "")
image_base64 = str(args.get("image_base64", ""))
stream_id = str(args.get("stream_id", ""))
if not image_base64 or not stream_id:
return {"success": False, "error": "缺少必要参数 image_base64 或 stream_id"}
@@ -75,18 +144,29 @@ class RuntimeCoreCapabilityMixin:
result = await send_api.image_to_stream(
image_base64=image_base64,
stream_id=stream_id,
storage_message=args.get("storage_message", True),
storage_message=bool(args.get("storage_message", True)),
)
return {"success": result}
except Exception as e:
logger.error(f"[cap.send.image] 执行失败: {e}", exc_info=True)
return {"success": False, "error": str(e)}
except Exception as exc:
logger.error(f"[cap.send.image] 执行失败: {exc}", exc_info=True)
return {"success": False, "error": str(exc)}
async def _cap_send_command(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
"""向指定流发送命令消息。
Args:
plugin_id: 插件标识。
capability: 能力名称。
args: 能力调用参数。
Returns:
Any: 能力执行结果。
"""
del plugin_id, capability
from src.services import send_service as send_api
command = args.get("command", "")
stream_id: str = args.get("stream_id", "")
command = str(args.get("command", ""))
stream_id = str(args.get("stream_id", ""))
if not command or not stream_id:
return {"success": False, "error": "缺少必要参数 command 或 stream_id"}
@@ -95,22 +175,33 @@ class RuntimeCoreCapabilityMixin:
message_type="command",
content=command,
stream_id=stream_id,
storage_message=args.get("storage_message", True),
display_message=args.get("display_message", ""),
storage_message=bool(args.get("storage_message", True)),
display_message=str(args.get("display_message", "")),
)
return {"success": result}
except Exception as e:
logger.error(f"[cap.send.command] 执行失败: {e}", exc_info=True)
return {"success": False, "error": str(e)}
except Exception as exc:
logger.error(f"[cap.send.command] 执行失败: {exc}", exc_info=True)
return {"success": False, "error": str(exc)}
async def _cap_send_custom(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
"""向指定流发送自定义消息。
Args:
plugin_id: 插件标识。
capability: 能力名称。
args: 能力调用参数。
Returns:
Any: 能力执行结果。
"""
del plugin_id, capability
from src.services import send_service as send_api
message_type: str = args.get("message_type", "") or args.get("custom_type", "")
message_type = str(args.get("message_type", "") or args.get("custom_type", ""))
content = args.get("content")
if content is None:
content = args.get("data", "")
stream_id: str = args.get("stream_id", "")
stream_id = str(args.get("stream_id", ""))
if not message_type or not stream_id:
return {"success": False, "error": "缺少必要参数 message_type 或 stream_id"}
@@ -119,114 +210,116 @@ class RuntimeCoreCapabilityMixin:
message_type=message_type,
content=content,
stream_id=stream_id,
display_message=args.get("display_message", ""),
typing=args.get("typing", False),
storage_message=args.get("storage_message", True),
display_message=str(args.get("display_message", "")),
typing=bool(args.get("typing", False)),
storage_message=bool(args.get("storage_message", True)),
)
return {"success": result}
except Exception as e:
logger.error(f"[cap.send.custom] 执行失败: {e}", exc_info=True)
return {"success": False, "error": str(e)}
except Exception as exc:
logger.error(f"[cap.send.custom] 执行失败: {exc}", exc_info=True)
return {"success": False, "error": str(exc)}
async def _cap_llm_generate(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
"""执行无工具的 LLM 生成能力。
Args:
plugin_id: 插件标识。
capability: 能力名称。
args: 能力调用参数。
Returns:
Any: 标准化后的 LLM 响应结构。
"""
del capability
from src.services import llm_service as llm_api
prompt: str = args.get("prompt", "")
if not prompt:
return {"success": False, "error": "缺少必要参数 prompt"}
model_name: str = args.get("model", "") or args.get("model_name", "")
temperature = args.get("temperature")
max_tokens = args.get("max_tokens")
try:
models = llm_api.get_available_models()
if model_name and model_name in models:
model_config = models[model_name]
else:
if not models:
return {"success": False, "error": "没有可用的模型配置"}
model_config = next(iter(models.values()))
success, response, reasoning, used_model = await llm_api.generate_with_model(
prompt=prompt,
model_config=model_config,
request_type=f"plugin.{plugin_id}",
temperature=temperature,
max_tokens=max_tokens,
prompt = _normalize_prompt_arg(args.get("prompt"))
task_name = llm_api.resolve_task_name(str(args.get("model", "") or args.get("model_name", "")))
result = await llm_api.generate(
llm_api.LLMServiceRequest(
task_name=task_name,
request_type=f"plugin.{plugin_id}",
prompt=prompt,
temperature=args.get("temperature"),
max_tokens=args.get("max_tokens"),
)
)
return {
"success": success,
"response": response,
"reasoning": reasoning,
"model_name": used_model,
}
except Exception as e:
logger.error(f"[cap.llm.generate] 执行失败: {e}", exc_info=True)
return {"success": False, "error": str(e)}
return result.to_capability_payload()
except Exception as exc:
logger.error(f"[cap.llm.generate] 执行失败: {exc}", exc_info=True)
return {"success": False, "error": str(exc)}
async def _cap_llm_generate_with_tools(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
"""执行带工具的 LLM 生成能力。
Args:
plugin_id: 插件标识。
capability: 能力名称。
args: 能力调用参数。
Returns:
Any: 标准化后的 LLM 响应结构。
"""
del capability
from src.services import llm_service as llm_api
prompt: str = args.get("prompt", "")
if not prompt:
return {"success": False, "error": "缺少必要参数 prompt"}
model_name: str = args.get("model", "") or args.get("model_name", "")
tool_options = args.get("tools") or args.get("tool_options")
temperature = args.get("temperature")
max_tokens = args.get("max_tokens")
if tool_options is not None and not isinstance(tool_options, list):
return {"success": False, "error": "tools 必须为列表"}
try:
models = llm_api.get_available_models()
if model_name and model_name in models:
model_config = models[model_name]
else:
if not models:
return {"success": False, "error": "没有可用的模型配置"}
model_config = next(iter(models.values()))
success, response, reasoning, used_model, tool_calls = await llm_api.generate_with_model_with_tools(
prompt=prompt,
model_config=model_config,
tool_options=tool_options,
request_type=f"plugin.{plugin_id}",
temperature=temperature,
max_tokens=max_tokens,
prompt = _normalize_prompt_arg(args.get("prompt"))
task_name = llm_api.resolve_task_name(str(args.get("model", "") or args.get("model_name", "")))
result = await llm_api.generate(
llm_api.LLMServiceRequest(
task_name=task_name,
request_type=f"plugin.{plugin_id}",
prompt=prompt,
tool_options=tool_options,
temperature=args.get("temperature"),
max_tokens=args.get("max_tokens"),
)
)
serialized_tool_calls = None
if tool_calls:
serialized_tool_calls = [
{
"id": tool_call.call_id,
"function": {"name": tool_call.func_name, "arguments": tool_call.args or {}},
}
for tool_call in tool_calls
if isinstance(tool_call, ToolCall)
]
return {
"success": success,
"response": response,
"reasoning": reasoning,
"model_name": used_model,
"tool_calls": serialized_tool_calls,
}
except Exception as e:
logger.error(f"[cap.llm.generate_with_tools] 执行失败: {e}", exc_info=True)
return {"success": False, "error": str(e)}
return result.to_capability_payload()
except Exception as exc:
logger.error(f"[cap.llm.generate_with_tools] 执行失败: {exc}", exc_info=True)
return {"success": False, "error": str(exc)}
async def _cap_llm_get_available_models(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
"""获取当前宿主可用的模型任务列表。
Args:
plugin_id: 插件标识。
capability: 能力名称。
args: 能力调用参数。
Returns:
Any: 可用模型列表。
"""
del plugin_id, capability, args
from src.services import llm_service as llm_api
try:
models = llm_api.get_available_models()
return {"success": True, "models": list(models.keys())}
except Exception as e:
logger.error(f"[cap.llm.get_available_models] 执行失败: {e}", exc_info=True)
return {"success": False, "error": str(e)}
except Exception as exc:
logger.error(f"[cap.llm.get_available_models] 执行失败: {exc}", exc_info=True)
return {"success": False, "error": str(exc)}
async def _cap_config_get(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
key: str = args.get("key", "")
"""读取宿主全局配置中的单个字段。
Args:
plugin_id: 插件标识。
capability: 能力名称。
args: 能力调用参数。
Returns:
Any: 配置读取结果。
"""
del plugin_id, capability
key = str(args.get("key", ""))
default = args.get("default")
if not key:
return {"success": False, "value": None, "error": "缺少必要参数 key"}
@@ -234,37 +327,57 @@ class RuntimeCoreCapabilityMixin:
try:
value = _get_nested_config_value(global_config, key, default)
return {"success": True, "value": value}
except Exception as e:
return {"success": False, "value": None, "error": str(e)}
except Exception as exc:
return {"success": False, "value": None, "error": str(exc)}
async def _cap_config_get_plugin(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
"""读取指定插件的配置。
Args:
plugin_id: 当前插件标识。
capability: 能力名称。
args: 能力调用参数。
Returns:
Any: 配置读取结果。
"""
del capability
from src.plugin_runtime.component_query import component_query_service
plugin_name: str = args.get("plugin_name", plugin_id)
key: str = args.get("key", "")
plugin_name = str(args.get("plugin_name", plugin_id))
key = str(args.get("key", ""))
default = args.get("default")
try:
config = component_query_service.get_plugin_config(plugin_name)
if config is None:
return {"success": False, "value": default, "error": f"未找到插件 {plugin_name} 的配置"}
if key:
value = _get_nested_config_value(config, key, default)
return {"success": True, "value": value}
return {"success": True, "value": config}
except Exception as e:
return {"success": False, "value": default, "error": str(e)}
except Exception as exc:
return {"success": False, "value": default, "error": str(exc)}
async def _cap_config_get_all(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
"""读取指定插件的全部配置。
Args:
plugin_id: 当前插件标识。
capability: 能力名称。
args: 能力调用参数。
Returns:
Any: 配置读取结果。
"""
del capability
from src.plugin_runtime.component_query import component_query_service
plugin_name: str = args.get("plugin_name", plugin_id)
plugin_name = str(args.get("plugin_name", plugin_id))
try:
config = component_query_service.get_plugin_config(plugin_name)
if config is None:
return {"success": True, "value": {}}
return {"success": True, "value": config}
except Exception as e:
return {"success": False, "value": {}, "error": str(e)}
except Exception as exc:
return {"success": False, "value": {}, "error": str(exc)}

View File

@@ -10,7 +10,7 @@ from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional, Tupl
from src.common.logger import get_logger
from src.core.types import ActionActivationType, ActionInfo, CommandInfo, ComponentInfo, ComponentType, ToolInfo
from src.llm_models.payload_content.tool_option import ToolParamType
from src.llm_models.payload_content.tool_option import normalize_tool_option
if TYPE_CHECKING:
from src.plugin_runtime.host.component_registry import ActionEntry, CommandEntry, ComponentEntry, ToolEntry
@@ -28,13 +28,6 @@ _HOST_COMPONENT_TYPE_MAP: Dict[ComponentType, str] = {
ComponentType.COMMAND: "COMMAND",
ComponentType.TOOL: "TOOL",
}
_TOOL_PARAM_TYPE_MAP: Dict[str, ToolParamType] = {
"string": ToolParamType.STRING,
"integer": ToolParamType.INTEGER,
"float": ToolParamType.FLOAT,
"boolean": ToolParamType.BOOLEAN,
"bool": ToolParamType.BOOLEAN,
}
class ComponentQueryService:
@@ -171,11 +164,9 @@ class ComponentQueryService:
return ActionInfo(
name=entry.name,
component_type=ComponentType.ACTION,
description=str(metadata.get("description", "") or ""),
enabled=bool(entry.enabled),
plugin_name=entry.plugin_id,
metadata=metadata,
action_parameters=action_parameters,
action_require=action_require,
associated_types=associated_types,
@@ -202,72 +193,48 @@ class ComponentQueryService:
metadata = dict(entry.metadata)
return CommandInfo(
name=entry.name,
component_type=ComponentType.COMMAND,
description=str(metadata.get("description", "") or ""),
enabled=bool(entry.enabled),
plugin_name=entry.plugin_id,
metadata=metadata,
command_pattern=str(metadata.get("command_pattern", "") or ""),
)
@staticmethod
def _coerce_tool_param_type(raw_value: Any) -> ToolParamType:
"""规范化工具参数类型
Args:
raw_value: 原始工具参数类型值。
Returns:
ToolParamType: 规范化后的工具参数类型。
"""
normalized_value = str(raw_value or "").strip().lower()
return _TOOL_PARAM_TYPE_MAP.get(normalized_value, ToolParamType.STRING)
@staticmethod
def _build_tool_parameters(entry: "ToolEntry") -> list[tuple[str, ToolParamType, str, bool, list[str] | None]]:
"""将运行时工具参数元数据转换为核心 ToolInfo 参数列表。
def _build_tool_definition(entry: "ToolEntry") -> dict[str, Any]:
"""将运行时 Tool 条目转换为原始工具定义字典
Args:
entry: 插件运行时中的 Tool 条目。
Returns:
list[tuple[str, ToolParamType, str, bool, list[str] | None]]: 转换后的参数列表
dict[str, Any]: 可交给 `normalize_tool_option()` 的原始工具定义
"""
raw_definition: dict[str, Any] = {
"name": entry.name,
"description": entry.description,
}
if isinstance(entry.parameters_raw, dict) and entry.parameters_raw:
raw_definition["parameters_schema"] = entry.parameters_raw
return raw_definition
if isinstance(entry.parameters, list) and entry.parameters:
raw_definition["parameters"] = entry.parameters
return raw_definition
if isinstance(entry.parameters_raw, list) and entry.parameters_raw:
raw_definition["parameters"] = entry.parameters_raw
return raw_definition
return raw_definition
structured_parameters = entry.parameters if isinstance(entry.parameters, list) else []
if not structured_parameters and isinstance(entry.parameters_raw, dict):
structured_parameters = [
{"name": key, **value}
for key, value in entry.parameters_raw.items()
if isinstance(value, dict)
]
@staticmethod
def _build_tool_parameters_schema(entry: "ToolEntry") -> dict[str, Any] | None:
"""将运行时 Tool 条目转换为对象级参数 Schema。
normalized_parameters: list[tuple[str, ToolParamType, str, bool, list[str] | None]] = []
for parameter in structured_parameters:
if not isinstance(parameter, dict):
continue
Args:
entry: 插件运行时中的 Tool 条目。
parameter_name = str(parameter.get("name", "") or "").strip()
if not parameter_name:
continue
enum_values = parameter.get("enum")
normalized_enum_values = (
[str(item) for item in enum_values if item is not None]
if isinstance(enum_values, list)
else None
)
normalized_parameters.append(
(
parameter_name,
ComponentQueryService._coerce_tool_param_type(parameter.get("param_type") or parameter.get("type")),
str(parameter.get("description", "") or ""),
bool(parameter.get("required", True)),
normalized_enum_values,
)
)
return normalized_parameters
Returns:
dict[str, Any] | None: 规范化后的对象级参数 Schema。
"""
normalized_option = normalize_tool_option(ComponentQueryService._build_tool_definition(entry))
return normalized_option.parameters_schema
@staticmethod
def _build_tool_info(entry: "ToolEntry") -> ToolInfo:
@@ -282,13 +249,10 @@ class ComponentQueryService:
return ToolInfo(
name=entry.name,
component_type=ComponentType.TOOL,
description=entry.description,
enabled=bool(entry.enabled),
plugin_name=entry.plugin_id,
metadata=dict(entry.metadata),
tool_parameters=ComponentQueryService._build_tool_parameters(entry),
tool_description=entry.description,
parameters_schema=ComponentQueryService._build_tool_parameters_schema(entry),
)
@staticmethod

View File

@@ -91,7 +91,7 @@ class ToolEntry(ComponentEntry):
def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
self.description: str = metadata.get("description", "")
self.parameters: List[Dict[str, Any]] = metadata.get("parameters", [])
self.parameters_raw: List[Dict[str, Any]] = metadata.get("parameters_raw", [])
self.parameters_raw: Dict[str, Any] | List[Dict[str, Any]] = metadata.get("parameters_raw", {})
super().__init__(name, component_type, plugin_id, metadata)

View File

@@ -1,191 +1,492 @@
"""LLM 服务模块
"""LLM 服务层。
提供与 LLM 模型交互的核心功能。
该模块负责在宿主侧收口统一的 LLM 服务请求模型,并将其转发到
`src.llm_models` 中的底层请求调度器。
"""
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Tuple
import json
from src.common.data_models.llm_service_data_models import (
LLMAudioTranscriptionResult,
LLMEmbeddingResult,
LLMGenerationOptions,
LLMImageOptions,
LLMResponseResult,
LLMServiceRequest,
LLMServiceResult,
MessageFactory,
PromptInput,
PromptMessage,
)
from src.common.logger import get_logger
from src.config.config import config_manager
from src.config.model_configs import TaskConfig
from src.llm_models.model_client.base_client import BaseClient
from src.llm_models.payload_content.message import Message
from src.llm_models.payload_content.message import Message, MessageBuilder, RoleType
from src.llm_models.payload_content.tool_option import ToolCall
from src.llm_models.utils_model import LLMRequest
from src.llm_models.utils_model import LLMOrchestrator
logger = get_logger("llm_service")
class LLMServiceClient:
"""面向上层模块的 LLM 服务对象式门面。
async def _generate_response(
model_config: TaskConfig,
request_type: str,
prompt: Optional[str] = None,
message_factory: Optional[Callable[[BaseClient], List[Message]]] = None,
tool_options: Optional[List[Dict[str, Any]]] = None,
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
) -> Tuple[str, str, str, List[ToolCall] | None]:
llm_request = LLMRequest(model_set=model_config, request_type=request_type)
当前推荐优先使用以下正式接口:
- `generate_response`
- `generate_response_with_messages`
- `generate_response_for_image`
- `transcribe_audio`
- `embed_text`
"""
if message_factory is not None:
response, (reasoning_content, model_name, tool_call) = await llm_request.generate_response_with_message_async(
message_factory=message_factory,
tools=tool_options,
temperature=temperature,
max_tokens=max_tokens,
def __init__(self, task_name: str, request_type: str = "") -> None:
"""初始化 LLM 服务门面。
Args:
task_name: 任务配置名称,对应 `model_task_config` 下的字段名。
request_type: 当前请求的业务类型标识。
"""
self.task_name = resolve_task_name(task_name)
self.request_type = request_type
self._orchestrator = LLMOrchestrator(task_name=self.task_name, request_type=request_type)
@staticmethod
def _normalize_generation_options(options: LLMGenerationOptions | None = None) -> LLMGenerationOptions:
"""规范化文本生成选项。
Args:
options: 原始生成选项。
Returns:
LLMGenerationOptions: 可直接用于执行请求的完整选项对象。
"""
if options is None:
return LLMGenerationOptions()
return options
@staticmethod
def _normalize_image_options(options: LLMImageOptions | None = None) -> LLMImageOptions:
"""规范化图像理解选项。
Args:
options: 原始图像理解选项。
Returns:
LLMImageOptions: 可直接用于执行请求的完整选项对象。
"""
if options is None:
return LLMImageOptions()
return options
async def generate_response(
self,
prompt: str,
options: LLMGenerationOptions | None = None,
) -> LLMResponseResult:
"""生成单轮文本响应。
Args:
prompt: 文本提示词。
options: 文本生成选项。
Returns:
LLMResponseResult: 统一文本生成结果。
"""
active_options = self._normalize_generation_options(options)
return await self._orchestrator.generate_response_async(
prompt=prompt,
temperature=active_options.temperature,
max_tokens=active_options.max_tokens,
tools=active_options.tool_options,
response_format=active_options.response_format,
raise_when_empty=active_options.raise_when_empty,
interrupt_flag=active_options.interrupt_flag,
)
return response, reasoning_content, model_name, tool_call
if prompt is None:
raise ValueError("prompt 与 message_factory 不能同时为空")
async def generate_response_with_messages(
self,
message_factory: MessageFactory,
options: LLMGenerationOptions | None = None,
) -> LLMResponseResult:
"""基于消息工厂生成响应。
response, (reasoning_content, model_name, tool_call) = await llm_request.generate_response_async(
prompt,
tools=tool_options,
temperature=temperature,
max_tokens=max_tokens,
)
return response, reasoning_content, model_name, tool_call
Args:
message_factory: 消息工厂,会根据客户端能力构建消息列表。
options: 文本生成选项。
Returns:
LLMResponseResult: 统一文本生成结果。
"""
active_options = self._normalize_generation_options(options)
return await self._orchestrator.generate_response_with_message_async(
message_factory=message_factory,
temperature=active_options.temperature,
max_tokens=active_options.max_tokens,
tools=active_options.tool_options,
response_format=active_options.response_format,
raise_when_empty=active_options.raise_when_empty,
interrupt_flag=active_options.interrupt_flag,
)
async def generate_response_for_image(
self,
prompt: str,
image_base64: str,
image_format: str,
options: LLMImageOptions | None = None,
) -> LLMResponseResult:
"""为图像内容生成响应。
Args:
prompt: 文本提示词。
image_base64: 图像的 Base64 编码字符串。
image_format: 图像格式,例如 ``png``、``jpeg``。
options: 图像理解选项。
Returns:
LLMResponseResult: 统一文本生成结果。
"""
active_options = self._normalize_image_options(options)
return await self._orchestrator.generate_response_for_image(
prompt=prompt,
image_base64=image_base64,
image_format=image_format,
temperature=active_options.temperature,
max_tokens=active_options.max_tokens,
interrupt_flag=active_options.interrupt_flag,
)
async def transcribe_audio(self, voice_base64: str) -> LLMAudioTranscriptionResult:
"""执行音频转写请求。
Args:
voice_base64: 音频的 Base64 编码字符串。
Returns:
LLMAudioTranscriptionResult: 音频转写结果对象。
"""
return await self._orchestrator.generate_response_for_voice(voice_base64)
async def embed_text(self, embedding_input: str) -> LLMEmbeddingResult:
"""生成文本嵌入向量。
Args:
embedding_input: 待编码的文本。
Returns:
LLMEmbeddingResult: 向量生成结果对象。
"""
return await self._orchestrator.get_embedding(embedding_input)
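Beyond plain text generation, the facade exposes the media entry points wired above; a sketch (the Base64 payloads are assumed valid):

client = LLMServiceClient(task_name="utils", request_type="example.media")
vision_result = await client.generate_response_for_image(
    prompt="描述这张图片",
    image_base64=image_b64,
    image_format="png",
)
transcription = await client.transcribe_audio(voice_base64=voice_b64)
embedding = await client.embed_text("你好")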
def get_available_models() -> Dict[str, TaskConfig]:
"""获取所有可用模型配置
"""获取所有可用模型配置
Returns:
Dict[str, Any]: 模型配置字典,key为模型名称,value为模型配置
Dict[str, TaskConfig]: 模型任务名为键的配置映射。
"""
try:
models = config_manager.get_model_config().model_task_config
attrs = dir(models)
rets: Dict[str, TaskConfig] = {}
for attr in attrs:
if not attr.startswith("__"):
try:
value = getattr(models, attr)
if not callable(value) and isinstance(value, TaskConfig):
rets[attr] = value
except Exception as e:
logger.debug(f"[LLMService] 获取属性 {attr} 失败: {e}")
continue
return rets
except Exception as e:
logger.error(f"[LLMService] 获取可用模型失败: {e}")
available_models: Dict[str, TaskConfig] = {}
for attr_name in dir(models):
if attr_name.startswith("__"):
continue
try:
attr_value = getattr(models, attr_name)
except Exception as exc:
logger.debug(f"[LLMService] 获取属性 {attr_name} 失败: {exc}")
continue
if not callable(attr_value) and isinstance(attr_value, TaskConfig):
available_models[attr_name] = attr_value
return available_models
except Exception as exc:
logger.error(f"[LLMService] 获取可用模型失败: {exc}")
return {}
async def generate_with_model(
prompt: str,
model_config: TaskConfig,
request_type: str = "plugin.generate",
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
) -> Tuple[bool, str, str, str]:
"""使用指定模型生成内容
def resolve_task_name(task_name: str = "") -> str:
"""根据名称解析任务配置名。
Args:
prompt: 提示词
model_config: 模型配置(从 get_available_models 获取的模型配置)
request_type: 请求类型标识
task_name: 目标任务配置名;为空时返回首个可用任务名。
Returns:
Tuple[bool, str, str, str]: (是否成功, 生成的内容, 推理过程, 模型名称)
str: 解析得到的任务配置名。
Raises:
RuntimeError: 当前没有任何可用模型配置。
ValueError: 指定名称不存在时抛出。
"""
try:
logger.debug(f"[LLMService] 完整提示词: {prompt}")
response, reasoning_content, model_name, _ = await _generate_response(
model_config=model_config,
request_type=request_type,
prompt=prompt,
temperature=temperature,
max_tokens=max_tokens,
)
return True, response, reasoning_content, model_name
except Exception as e:
error_msg = f"生成内容时出错: {str(e)}"
logger.error(f"[LLMService] {error_msg}")
return False, error_msg, "", ""
models = get_available_models()
if not models:
raise RuntimeError("没有可用的模型配置")
normalized_task_name = task_name.strip()
if not normalized_task_name:
return next(iter(models.keys()))
if normalized_task_name not in models:
raise ValueError(f"未找到名为 `{normalized_task_name}` 的模型配置")
return normalized_task_name
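The resolution rules in effect:

resolve_task_name("planner")   # -> "planner", if model_task_config defines it
resolve_task_name("")          # -> first configured task name
resolve_task_name("missing")   # raises ValueError; an empty config raises RuntimeError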
async def generate_with_model_with_tools(
prompt: str,
model_config: TaskConfig,
tool_options: List[Dict[str, Any]] | None = None,
request_type: str = "plugin.generate",
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
) -> Tuple[bool, str, str, str, List[ToolCall] | None]:
"""使用指定模型和工具生成内容
def _normalize_role(role_name: str) -> RoleType:
"""将原始角色字符串转换为内部角色枚举。
Args:
prompt: 提示词
model_config: 模型配置(从 get_available_models 获取的模型配置)
tool_options: 工具选项列表
request_type: 请求类型标识
temperature: 温度参数
max_tokens: 最大token数
role_name: 原始角色名称。
Returns:
Tuple[bool, str, str, str, List[ToolCall] | None]: (是否成功, 生成的内容, 推理过程, 模型名称, 工具调用列表)
RoleType: 规范化后的角色枚举。
Raises:
ValueError: 角色类型不受支持时抛出。
"""
normalized_role_name = role_name.strip().lower()
try:
model_name_list = model_config.model_list
logger.info(f"使用模型{model_name_list}生成内容")
logger.debug(f"完整提示词: {prompt}")
response, reasoning_content, model_name, tool_call = await _generate_response(
model_config=model_config,
request_type=request_type,
prompt=prompt,
tool_options=tool_options,
temperature=temperature,
max_tokens=max_tokens,
)
return True, response, reasoning_content, model_name, tool_call
except Exception as e:
error_msg = f"生成内容时出错: {str(e)}"
logger.error(f"[LLMService] {error_msg}")
return False, error_msg, "", "", None
def _normalize_role(role_name: str) -> RoleType:
"""将原始角色字符串转换为内部角色枚举。
Args:
role_name: 原始角色名称。
Returns:
RoleType: 规范化后的角色枚举。
Raises:
ValueError: 角色类型不受支持时抛出。
"""
normalized_role_name = role_name.strip().lower()
try:
return RoleType(normalized_role_name)
except ValueError as exc:
raise ValueError(f"不支持的消息角色: {role_name}") from exc
async def generate_with_model_with_tools_by_message_factory(
message_factory: Callable[[BaseClient], List[Message]],
model_config: TaskConfig,
tool_options: List[Dict[str, Any]] | None = None,
request_type: str = "plugin.generate",
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
) -> Tuple[bool, str, str, str, List[ToolCall] | None]:
"""使用指定模型和工具生成内容(通过消息工厂构建消息列表)
def _parse_data_url_image(image_url: str) -> Tuple[str, str]:
"""解析 Data URL 形式的图片内容。
Args:
message_factory: 消息工厂函数
model_config: 模型配置
tool_options: 工具选项列表
request_type: 请求类型标识
temperature: 温度参数
max_tokens: 最大token数
image_url: 图片 URL。
Returns:
Tuple[bool, str, str, str, List[ToolCall] | None]: (是否成功, 生成的内容, 推理过程, 模型名称, 工具调用列表)
Tuple[str, str]: `(图片格式, Base64 数据)`。
Raises:
ValueError: 输入不是受支持的 Data URL 时抛出。
"""
try:
model_name_list = model_config.model_list
logger.info(f"使用模型 {model_name_list} 生成内容")
response, reasoning_content, model_name, tool_call = await _generate_response(
model_config=model_config,
request_type=request_type,
message_factory=message_factory,
tool_options=tool_options,
temperature=temperature,
max_tokens=max_tokens,
)
return True, response, reasoning_content, model_name, tool_call
except Exception as e:
error_msg = f"生成内容时出错: {str(e)}"
logger.error(f"[LLMService] {error_msg}")
return False, error_msg, "", "", None
def _parse_data_url_image(image_url: str) -> Tuple[str, str]:
"""解析 Data URL 形式的图片内容。
Args:
image_url: 图片 URL。
Returns:
Tuple[str, str]: `(图片格式, Base64 数据)`。
Raises:
ValueError: 输入不是受支持的 Data URL 时抛出。
"""
if not image_url.startswith("data:image/") or ";base64," not in image_url:
raise ValueError("仅支持 Data URL 形式的图片输入")
prefix, image_base64 = image_url.split(";base64,", maxsplit=1)
image_format = prefix.removeprefix("data:image/")
if not image_format or not image_base64:
raise ValueError("图片 Data URL 不完整")
return image_format, image_base64
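# Usage sketch: a well-formed Data URL splits into (format, base64 payload);
# anything else raises ValueError. The payload below is truncated for brevity.
demo_format, demo_base64 = _parse_data_url_image("data:image/png;base64,iVBORw0KGgo=")
assert demo_format == "png"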
def _append_content_parts(message_builder: MessageBuilder, content: Any) -> None:
"""将原始消息内容追加到内部消息构建器。
Args:
message_builder: 目标消息构建器。
content: 原始消息内容。
Raises:
ValueError: 消息内容结构不受支持时抛出。
"""
if isinstance(content, str):
message_builder.add_text_content(content)
return
content_items: List[Any]
if isinstance(content, list):
content_items = content
elif isinstance(content, dict):
content_items = [content]
else:
raise ValueError("消息内容必须为字符串、字典或列表")
for content_item in content_items:
if isinstance(content_item, str):
message_builder.add_text_content(content_item)
continue
if not isinstance(content_item, dict):
raise ValueError("消息内容列表中仅支持字符串或字典片段")
part_type = str(content_item.get("type", "text")).strip().lower()
if part_type == "text":
text_content = content_item.get("text")
if not isinstance(text_content, str):
raise ValueError("文本片段缺少 `text` 字段")
message_builder.add_text_content(text_content)
continue
if part_type in {"image", "image_url", "input_image"}:
image_url = content_item.get("image_url")
if isinstance(image_url, dict):
image_url = image_url.get("url")
if isinstance(image_url, str):
image_format, image_base64 = _parse_data_url_image(image_url)
message_builder.add_image_content(image_format=image_format, image_base64=image_base64)
continue
image_format = content_item.get("image_format")
image_base64 = content_item.get("image_base64")
if isinstance(image_format, str) and isinstance(image_base64, str):
message_builder.add_image_content(image_format=image_format, image_base64=image_base64)
continue
raise ValueError("图片片段缺少可识别的图片数据")
raise ValueError(f"不支持的消息片段类型: {part_type}")
def _normalize_tool_arguments(arguments: Any) -> Dict[str, Any] | None:
"""将原始工具参数规范化为字典。
Args:
arguments: 原始工具参数。
Returns:
Dict[str, Any] | None: 规范化后的参数字典。
"""
if arguments is None:
return None
if isinstance(arguments, dict):
return arguments
if isinstance(arguments, str):
stripped_arguments = arguments.strip()
if not stripped_arguments:
return {}
try:
parsed_arguments = json.loads(stripped_arguments)
except json.JSONDecodeError:
return {"raw_arguments": arguments}
if isinstance(parsed_arguments, dict):
return parsed_arguments
return {"value": parsed_arguments}
return {"value": arguments}
def _build_tool_calls(raw_tool_calls: Any) -> List[ToolCall] | None:
"""从原始消息中提取工具调用列表。
Args:
raw_tool_calls: 原始工具调用结构。
Returns:
List[ToolCall] | None: 规范化后的工具调用列表。
Raises:
ValueError: 工具调用结构缺失必要字段时抛出。
"""
if raw_tool_calls is None:
return None
if not isinstance(raw_tool_calls, list):
raise ValueError("`tool_calls` 必须为列表")
tool_calls: List[ToolCall] = []
for raw_tool_call in raw_tool_calls:
if not isinstance(raw_tool_call, dict):
raise ValueError("工具调用项必须为字典")
function_info = raw_tool_call.get("function")
if isinstance(function_info, dict):
func_name = function_info.get("name")
arguments = function_info.get("arguments")
else:
func_name = raw_tool_call.get("name") or raw_tool_call.get("func_name")
arguments = raw_tool_call.get("arguments") or raw_tool_call.get("args")
call_id = raw_tool_call.get("id") or raw_tool_call.get("call_id")
if not isinstance(call_id, str) or not isinstance(func_name, str):
raise ValueError("工具调用缺少 `id` 或函数名称")
tool_calls.append(
ToolCall(
call_id=call_id,
func_name=func_name,
args=_normalize_tool_arguments(arguments),
)
)
return tool_calls or None
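# Usage sketch: both the OpenAI-style nested `function` payload and the flat
# name/arguments variant are accepted; string arguments are JSON-decoded.
demo_calls = _build_tool_calls(
    [{"id": "call_1", "function": {"name": "get_weather", "arguments": '{"city": "北京"}'}}]
)
# -> [ToolCall(call_id="call_1", func_name="get_weather", args={"city": "北京"})]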
def _build_message_from_dict(raw_message: PromptMessage) -> Message:
"""将原始消息字典转换为内部消息对象。
Args:
raw_message: 原始消息字典。
Returns:
Message: 规范化后的消息对象。
Raises:
ValueError: 原始消息结构不合法时抛出。
"""
raw_role = raw_message.get("role")
if not isinstance(raw_role, str):
raise ValueError("消息缺少字符串类型的 `role` 字段")
role = _normalize_role(raw_role)
message_builder = MessageBuilder().set_role(role)
tool_calls = _build_tool_calls(raw_message.get("tool_calls"))
if tool_calls is not None:
message_builder.set_tool_calls(tool_calls)
tool_call_id = raw_message.get("tool_call_id")
if isinstance(tool_call_id, str) and role == RoleType.Tool:
message_builder.set_tool_call_id(tool_call_id)
if "content" in raw_message and raw_message["content"] not in (None, "", []):
_append_content_parts(message_builder, raw_message["content"])
return message_builder.build()
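# Usage sketch: a plain OpenAI-style message dict becomes an internal Message
# ("user" is assumed to be a valid RoleType value).
demo_user_message = _build_message_from_dict({"role": "user", "content": "你好"})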
def _build_prompt_message_factory(prompt: PromptInput) -> MessageFactory:
"""将统一提示输入转换为消息工厂。
Args:
prompt: 原始提示输入。
Returns:
MessageFactory: 惰性构建消息列表的工厂函数。
"""
if isinstance(prompt, str):
def build_messages(_: BaseClient) -> List[Message]:
"""构建单条用户消息。"""
message_builder = MessageBuilder()
message_builder.add_text_content(prompt)
return [message_builder.build()]
return build_messages
def build_messages(_: BaseClient) -> List[Message]:
"""构建多消息对话输入。"""
return [_build_message_from_dict(raw_message) for raw_message in prompt]
return build_messages
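# Usage sketch: strings and OpenAI-style message lists are both accepted; the
# returned factory defers message construction until a client is supplied.
# Role values here are assumed to exist on RoleType.
demo_factory = _build_prompt_message_factory(
    [
        {"role": "system", "content": "你是一个乐于助人的助手"},
        {"role": "user", "content": "介绍一下自己"},
    ]
)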
async def generate(request: LLMServiceRequest) -> LLMServiceResult:
"""执行统一的 LLM 服务请求。
Args:
request: 服务层统一请求对象。
Returns:
LLMServiceResult: 统一响应对象;失败时 `success=False`。
"""
llm_client = LLMServiceClient(task_name=request.task_name, request_type=request.request_type)
if request.message_factory is not None:
active_message_factory = request.message_factory
else:
prompt = request.prompt
if prompt is None:
raise ValueError("`prompt` 与 `message_factory` 必须且只能提供一个")
active_message_factory = _build_prompt_message_factory(prompt)
try:
generation_result = await llm_client.generate_response_with_messages(
message_factory=active_message_factory,
options=LLMGenerationOptions(
temperature=request.temperature,
max_tokens=request.max_tokens,
tool_options=request.tool_options,
response_format=request.response_format,
interrupt_flag=request.interrupt_flag,
),
)
return LLMServiceResult.from_response_result(generation_result)
except Exception as exc:
error_message = f"生成内容时出错: {exc}"
logger.error(f"[LLMService] {error_message}")
return LLMServiceResult.from_error(error_message, str(exc))
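# End-to-end sketch (hedged): only the request fields read by `generate` above
# are set; LLMServiceRequest's remaining fields are assumed to have defaults,
# and `result.response` is an assumed attribute of LLMServiceResult.
async def _demo_generate() -> None:
    request = LLMServiceRequest(
        task_name=resolve_task_name(),
        request_type="demo.generate",
        prompt="用一句话介绍你自己",
    )
    result = await generate(request)
    if result.success:
        print(result.response)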

View File

@@ -13,6 +13,8 @@ from fastapi import APIRouter, Depends, HTTPException, Query
from src.common.logger import get_logger
from src.config.config import CONFIG_DIR
from src.config.model_configs import APIProvider
from src.llm_models.openai_compat import build_openai_compatible_client_config, normalize_openai_base_url
from src.webui.dependencies import require_auth
from src.webui.utils.network_security import validate_public_url
@@ -35,8 +37,8 @@ MODEL_FETCHER_CONFIG = {
def _normalize_url(url: str) -> str:
"""规范化 URL去掉尾部斜杠"""
return url.rstrip("/") if url else ""
"""规范化 URL去掉尾部斜杠"""
return normalize_openai_base_url(url) if url else ""
def _parse_openai_response(data: Dict) -> List[Dict]:
@@ -89,19 +91,30 @@ async def _fetch_models_from_provider(
endpoint: str,
parser: str,
client_type: str = "openai",
auth_type: str = "bearer",
auth_header_name: str = "Authorization",
auth_header_prefix: str = "Bearer",
auth_query_name: str = "api_key",
default_headers: Optional[Dict[str, str]] = None,
default_query: Optional[Dict[str, str]] = None,
) -> List[Dict]:
"""
从提供商 API 获取模型列表
"""从提供商 API 获取模型列表。
Args:
base_url: 提供商的基础 URL
api_key: API 密钥
endpoint: 获取模型列表的端点
parser: 响应解析器类型 ('openai' | 'gemini')
client_type: 客户端类型 ('openai' | 'gemini')
base_url: 提供商的基础 URL
api_key: API 密钥
endpoint: 获取模型列表的端点
parser: 响应解析器类型
client_type: 客户端类型
auth_type: OpenAI 兼容接口的鉴权方式。
auth_header_name: Header 鉴权时使用的请求头名称。
auth_header_prefix: Header 鉴权时使用的请求头前缀。
auth_query_name: Query 鉴权时使用的查询参数名称。
default_headers: 默认附带的请求头。
default_query: 默认附带的查询参数。
Returns:
模型列表
List[Dict]: 解析后的模型列表
"""
try:
base_url = validate_public_url(_normalize_url(base_url))
@@ -118,8 +131,21 @@ async def _fetch_models_from_provider(
# Gemini 使用 URL 参数传递 API Key
params["key"] = api_key
else:
# OpenAI 兼容格式使用 Authorization 头
headers["Authorization"] = f"Bearer {api_key}"
provider = APIProvider(
name="webui-openai-compatible-fetcher",
base_url=base_url,
api_key=api_key,
client_type="openai",
auth_type=auth_type,
auth_header_name=auth_header_name,
auth_header_prefix=auth_header_prefix,
auth_query_name=auth_query_name,
default_headers=default_headers or {},
default_query=default_query or {},
)
client_config = build_openai_compatible_client_config(provider)
headers.update(client_config.default_headers)
params.update(client_config.default_query)
try:
async with httpx.AsyncClient(timeout=30.0) as client:
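# Hedged sketch of the override path above: with auth_type="query" the shared
# helper is expected to surface the key via default_query rather than an
# Authorization header (the exact population is inferred from how
# client_config.default_headers / default_query are consumed here).
demo_provider = APIProvider(
    name="demo-provider",
    base_url="https://api.example.com/v1",
    api_key="sk-demo",
    client_type="openai",
    auth_type="query",
    auth_header_name="Authorization",
    auth_header_prefix="Bearer",
    auth_query_name="key",
    default_headers={},
    default_query={},
)
demo_client_config = build_openai_compatible_client_config(demo_provider)
# demo_client_config.default_query would then include {"key": "sk-demo"} (assumed)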
@@ -186,10 +212,9 @@ async def get_provider_models(
parser: str = Query("openai", description="响应解析器类型 (openai | gemini)"),
endpoint: str = Query("/models", description="获取模型列表的端点"),
):
"""
获取指定提供商的可用模型列表
"""获取指定提供商的可用模型列表。
通过提供商名称查找配置,然后请求对应的模型列表端点
"""
# 获取提供商配置
provider_config = _get_provider_config(provider_name)
@@ -205,13 +230,21 @@ async def get_provider_models(
if not api_key:
raise HTTPException(status_code=400, detail="提供商配置缺少 api_key")
resolved_endpoint = provider_config.get("model_list_endpoint", endpoint) if endpoint == "/models" else endpoint
# 获取模型列表
models = await _fetch_models_from_provider(
base_url=base_url,
api_key=api_key,
endpoint=endpoint,
endpoint=resolved_endpoint,
parser=parser,
client_type=client_type,
auth_type=provider_config.get("auth_type", "bearer"),
auth_header_name=provider_config.get("auth_header_name", "Authorization"),
auth_header_prefix=provider_config.get("auth_header_prefix", "Bearer"),
auth_query_name=provider_config.get("auth_query_name", "api_key"),
default_headers=provider_config.get("default_headers", {}),
default_query=provider_config.get("default_query", {}),
)
return {
@@ -229,16 +262,22 @@ async def get_models_by_url(
parser: str = Query("openai", description="响应解析器类型 (openai | gemini)"),
endpoint: str = Query("/models", description="获取模型列表的端点"),
client_type: str = Query("openai", description="客户端类型 (openai | gemini)"),
auth_type: str = Query("bearer", description="鉴权方式 (bearer | header | query | none)"),
auth_header_name: str = Query("Authorization", description="Header 鉴权名称"),
auth_header_prefix: str = Query("Bearer", description="Header 鉴权前缀"),
auth_query_name: str = Query("api_key", description="Query 鉴权参数名"),
):
"""
通过 URL 直接获取模型列表(用于自定义提供商)
"""
"""通过 URL 直接获取模型列表。"""
models = await _fetch_models_from_provider(
base_url=base_url,
api_key=api_key,
endpoint=endpoint,
parser=parser,
client_type=client_type,
auth_type=auth_type,
auth_header_name=auth_header_name,
auth_header_prefix=auth_header_prefix,
auth_query_name=auth_query_name,
)
return {