diff --git a/pyproject.toml b/pyproject.toml index f02add00..0ba96250 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,7 +19,7 @@ dependencies = [ "jieba>=0.42.1", "json-repair>=0.47.6", "maim-message>=0.6.2", - "maibot-plugin-sdk>=2.0.0", + "maibot-plugin-sdk>=2.1.0", "mcp", "msgpack>=1.1.2", "numpy>=2.2.6", diff --git a/src/chat/brain_chat/PFC/action_planner.py b/src/chat/brain_chat/PFC/action_planner.py index 94f68585..83bc28d9 100644 --- a/src/chat/brain_chat/PFC/action_planner.py +++ b/src/chat/brain_chat/PFC/action_planner.py @@ -2,8 +2,8 @@ import time from typing import Tuple, Optional # 增加了 Optional from src.common.logger import get_logger -from src.llm_models.utils_model import LLMRequest -from src.config.config import global_config, model_config +from src.services.llm_service import LLMServiceClient +from src.config.config import global_config import random from .chat_observer import ChatObserver from .pfc_utils import get_items_from_json @@ -109,8 +109,8 @@ class ActionPlanner: """行动规划器""" def __init__(self, stream_id: str, private_name: str): - self.llm = LLMRequest( - model_set=model_config.model_task_config.planner, + self.llm = LLMServiceClient( + task_name="planner", request_type="action_planning", ) self.personality_info = self._get_personality_prompt() @@ -398,7 +398,8 @@ class ActionPlanner: logger.debug(f"[私聊][{self.private_name}]发送到LLM的最终提示词:\n------\n{prompt}\n------") try: - content, _ = await self.llm.generate_response_async(prompt) + generation_result = await self.llm.generate_response(prompt) + content = generation_result.response logger.debug(f"[私聊][{self.private_name}]LLM (行动规划) 原始返回内容: {content}") # --- 初始行动规划解析 --- @@ -427,7 +428,8 @@ class ActionPlanner: f"[私聊][{self.private_name}]发送到LLM的结束决策提示词:\n------\n{end_decision_prompt}\n------" ) try: - end_content, _ = await self.llm.generate_response_async(end_decision_prompt) # 再次调用LLM + end_generation_result = await self.llm.generate_response(end_decision_prompt) + end_content = end_generation_result.response # 再次调用LLM logger.debug(f"[私聊][{self.private_name}]LLM (结束决策) 原始返回内容: {end_content}") # 解析结束决策的JSON diff --git a/src/chat/brain_chat/PFC/pfc.py b/src/chat/brain_chat/PFC/pfc.py index 5d051716..7d5fef84 100644 --- a/src/chat/brain_chat/PFC/pfc.py +++ b/src/chat/brain_chat/PFC/pfc.py @@ -1,7 +1,7 @@ from typing import List, Tuple, TYPE_CHECKING from src.common.logger import get_logger -from src.llm_models.utils_model import LLMRequest -from src.config.config import global_config, model_config +from src.services.llm_service import LLMServiceClient +from src.config.config import global_config import random from .chat_observer import ChatObserver from .pfc_utils import get_items_from_json @@ -43,7 +43,9 @@ class GoalAnalyzer: """对话目标分析器""" def __init__(self, stream_id: str, private_name: str): - self.llm = LLMRequest(model_set=model_config.model_task_config.planner, request_type="conversation_goal") + self.llm = LLMServiceClient( + task_name="planner", request_type="conversation_goal" + ) self.personality_info = self._get_personality_prompt() self.name = global_config.bot.nickname @@ -157,7 +159,8 @@ class GoalAnalyzer: logger.debug(f"[私聊][{self.private_name}]发送到LLM的提示词: {prompt}") try: - content, _ = await self.llm.generate_response_async(prompt) + generation_result = await self.llm.generate_response(prompt) + content = generation_result.response logger.debug(f"[私聊][{self.private_name}]LLM原始返回内容: {content}") except Exception as e: logger.error(f"[私聊][{self.private_name}]分析对话目标时出错: {str(e)}") @@ -271,7 
+274,8 @@ class GoalAnalyzer: }}""" try: - content, _ = await self.llm.generate_response_async(prompt) + generation_result = await self.llm.generate_response(prompt) + content = generation_result.response logger.debug(f"[私聊][{self.private_name}]LLM原始返回内容: {content}") # 尝试解析JSON diff --git a/src/chat/brain_chat/PFC/pfc_KnowledgeFetcher.py b/src/chat/brain_chat/PFC/pfc_KnowledgeFetcher.py index 67509bd5..f6adc718 100644 --- a/src/chat/brain_chat/PFC/pfc_KnowledgeFetcher.py +++ b/src/chat/brain_chat/PFC/pfc_KnowledgeFetcher.py @@ -3,8 +3,7 @@ from src.common.logger import get_logger # NOTE: HippocampusManager doesn't exist in v0.12.2 - memory system was redesigned # from src.plugins.memory_system.Hippocampus import HippocampusManager -from src.llm_models.utils_model import LLMRequest -from src.config.config import model_config +from src.services.llm_service import LLMServiceClient from src.chat.knowledge import qa_manager logger = get_logger("knowledge_fetcher") @@ -14,7 +13,7 @@ class KnowledgeFetcher: """知识调取器""" def __init__(self, private_name: str): - self.llm = LLMRequest(model_set=model_config.model_task_config.utils) + self.llm = LLMServiceClient(task_name="utils") self.private_name = private_name def _lpmm_get_knowledge(self, query: str) -> str: diff --git a/src/chat/brain_chat/PFC/reply_checker.py b/src/chat/brain_chat/PFC/reply_checker.py index c6304b30..37359e2f 100644 --- a/src/chat/brain_chat/PFC/reply_checker.py +++ b/src/chat/brain_chat/PFC/reply_checker.py @@ -2,8 +2,8 @@ import json import random from typing import Tuple, List, Dict, Any from src.common.logger import get_logger -from src.llm_models.utils_model import LLMRequest -from src.config.config import global_config, model_config +from src.services.llm_service import LLMServiceClient +from src.config.config import global_config from .chat_observer import ChatObserver from maim_message import UserInfo @@ -14,7 +14,7 @@ class ReplyChecker: """回复检查器""" def __init__(self, stream_id: str, private_name: str): - self.llm = LLMRequest(model_set=model_config.model_task_config.utils, request_type="reply_check") + self.llm = LLMServiceClient(task_name="utils", request_type="reply_check") self.personality_info = self._get_personality_prompt() self.name = global_config.bot.nickname self.private_name = private_name @@ -137,7 +137,8 @@ class ReplyChecker: 注意:请严格按照JSON格式输出,不要包含任何其他内容。""" try: - content, _ = await self.llm.generate_response_async(prompt) + generation_result = await self.llm.generate_response(prompt) + content = generation_result.response logger.debug(f"[私聊][{self.private_name}]检查回复的原始返回: {content}") # 清理内容,尝试提取JSON部分 diff --git a/src/chat/brain_chat/PFC/reply_generator.py b/src/chat/brain_chat/PFC/reply_generator.py index 95853e26..6cece33d 100644 --- a/src/chat/brain_chat/PFC/reply_generator.py +++ b/src/chat/brain_chat/PFC/reply_generator.py @@ -1,7 +1,7 @@ from typing import Tuple, List, Dict, Any from src.common.logger import get_logger -from src.llm_models.utils_model import LLMRequest -from src.config.config import global_config, model_config +from src.services.llm_service import LLMServiceClient +from src.config.config import global_config import random from .chat_observer import ChatObserver from .reply_checker import ReplyChecker @@ -87,8 +87,8 @@ class ReplyGenerator: """回复生成器""" def __init__(self, stream_id: str, private_name: str): - self.llm = LLMRequest( - model_set=model_config.model_task_config.replyer, + self.llm = LLMServiceClient( + task_name="replyer", request_type="reply_generation", ) 
self.personality_info = self._get_personality_prompt() @@ -223,7 +223,8 @@ class ReplyGenerator: # --- 调用 LLM 生成 --- logger.debug(f"[私聊][{self.private_name}]发送到LLM的生成提示词:\n------\n{prompt}\n------") try: - content, _ = await self.llm.generate_response_async(prompt) + generation_result = await self.llm.generate_response(prompt) + content = generation_result.response logger.debug(f"[私聊][{self.private_name}]生成的回复: {content}") # 移除旧的检查新消息逻辑,这应该由 conversation 控制流处理 return content diff --git a/src/chat/brain_chat/brain_planner.py b/src/chat/brain_chat/brain_planner.py index 709be8ee..f2a69cbe 100644 --- a/src/chat/brain_chat/brain_planner.py +++ b/src/chat/brain_chat/brain_planner.py @@ -17,9 +17,9 @@ from src.chat.utils.utils import get_chat_type_and_target_info from src.common.data_models.info_data_model import ActionPlannerInfo from src.common.logger import get_logger from src.common.utils.utils_action import ActionUtils -from src.config.config import global_config, model_config +from src.config.config import global_config from src.core.types import ActionActivationType, ActionInfo, ComponentType -from src.llm_models.utils_model import LLMRequest +from src.services.llm_service import LLMServiceClient from src.plugin_runtime.component_query import component_query_service from src.prompt.prompt_manager import prompt_manager from src.services.message_service import ( @@ -43,8 +43,8 @@ class BrainPlanner: self.log_prefix = f"[{_chat_manager.get_session_name(chat_id) or chat_id}]" self.action_manager = action_manager # LLM规划器配置 - self.planner_llm = LLMRequest( - model_set=model_config.model_task_config.planner, request_type="planner" + self.planner_llm = LLMServiceClient( + task_name="planner", request_type="planner" ) # 用于动作规划 self.last_obs_time_mark = 0.0 @@ -412,7 +412,9 @@ class BrainPlanner: try: # 调用LLM llm_start = time.perf_counter() - llm_content, (reasoning_content, _, _) = await self.planner_llm.generate_response_async(prompt=prompt) + generation_result = await self.planner_llm.generate_response(prompt=prompt) + llm_content = generation_result.response + reasoning_content = generation_result.reasoning llm_duration_ms = (time.perf_counter() - llm_start) * 1000 llm_reasoning = reasoning_content diff --git a/src/chat/emoji_system/emoji_manager.py b/src/chat/emoji_system/emoji_manager.py index 1e8b5479..780049d2 100644 --- a/src/chat/emoji_system/emoji_manager.py +++ b/src/chat/emoji_system/emoji_manager.py @@ -17,8 +17,9 @@ from src.common.database.database_model import Images, ImageType from src.common.database.database import get_db_session, get_db_session_manual from src.common.utils.utils_image import ImageUtils from src.prompt.prompt_manager import prompt_manager -from src.config.config import config_manager, global_config, model_config -from src.llm_models.utils_model import LLMRequest +from src.config.config import config_manager, global_config +from src.common.data_models.llm_service_data_models import LLMGenerationOptions, LLMImageOptions +from src.services.llm_service import LLMServiceClient logger = get_logger("emoji") @@ -38,8 +39,10 @@ def _ensure_directories() -> None: # TODO: 修改这个vlm为获取的vlm client,暂时使用这个VLM方法 -emoji_manager_vlm = LLMRequest(model_set=model_config.model_task_config.vlm, request_type="emoji.see") -emoji_manager_emotion_judge_llm = LLMRequest(model_set=model_config.model_task_config.utils, request_type="emoji") +emoji_manager_vlm = LLMServiceClient(task_name="vlm", request_type="emoji.see") +emoji_manager_emotion_judge_llm = LLMServiceClient( + 
task_name="utils", request_type="emoji" +) class EmojiManager: @@ -461,9 +464,11 @@ class EmojiManager: emoji_replace_prompt_template.add_context("emoji_list", "\n".join(emoji_info_list)) emoji_replace_prompt = await prompt_manager.render_prompt(emoji_replace_prompt_template) - decision, _ = await emoji_manager_emotion_judge_llm.generate_response_async( - emoji_replace_prompt, temperature=0.8, max_tokens=600 + decision_result = await emoji_manager_emotion_judge_llm.generate_response( + emoji_replace_prompt, + options=LLMGenerationOptions(temperature=0.8, max_tokens=600), ) + decision = decision_result.response logger.info(f"[决策] 结果: {decision}") # 解析决策结果 @@ -524,24 +529,36 @@ class EmojiManager: return False, target_emoji prompt: str = "这是一个动态图表情包,每一张图代表了动态图的某一帧,黑色背景代表透明,简短描述一下表情包表达的情感和内容,从互联网梗、meme的角度去分析,精简回答" image_base64 = ImageUtils.image_bytes_to_base64(image_bytes) - description, _ = await emoji_manager_vlm.generate_response_for_image( - prompt, image_base64, "jpg", temperature=0.5 + description_result = await emoji_manager_vlm.generate_response_for_image( + prompt, + image_base64, + "jpg", + options=LLMImageOptions(temperature=0.5), ) + description = description_result.response else: prompt: str = "这是一个表情包,请详细描述一下表情包所表达的情感和内容,简短描述细节,从互联网梗、meme的角度去分析,精简回答" image_base64 = ImageUtils.image_bytes_to_base64(image_bytes) - description, _ = await emoji_manager_vlm.generate_response_for_image( - prompt, image_base64, image_format, temperature=0.5 + description_result = await emoji_manager_vlm.generate_response_for_image( + prompt, + image_base64, + image_format, + options=LLMImageOptions(temperature=0.5), ) + description = description_result.response # 表情包审查 if global_config.emoji.content_filtration: filtration_prompt_template = prompt_manager.get_prompt("emoji_content_filtration") filtration_prompt_template.add_context("demand", global_config.emoji.filtration_prompt) filtration_prompt = await prompt_manager.render_prompt(filtration_prompt_template) - llm_response, _ = await emoji_manager_vlm.generate_response_for_image( - filtration_prompt, image_base64, image_format, temperature=0.3 + filtration_result = await emoji_manager_vlm.generate_response_for_image( + filtration_prompt, + image_base64, + image_format, + options=LLMImageOptions(temperature=0.3), ) + llm_response = filtration_result.response if "否" in llm_response: logger.warning(f"[表情包审查] 表情包内容不符合要求,拒绝注册: {target_emoji.file_name}") return False, target_emoji @@ -567,9 +584,11 @@ class EmojiManager: emotion_prompt_template.add_context("description", target_emoji.description) emotion_prompt = await prompt_manager.render_prompt(emotion_prompt_template) # 调用LLM生成情感标签 - emotion_result, _ = await emoji_manager_emotion_judge_llm.generate_response_async( - emotion_prompt, temperature=0.3, max_tokens=200 + emotion_generation_result = await emoji_manager_emotion_judge_llm.generate_response( + emotion_prompt, + options=LLMGenerationOptions(temperature=0.3, max_tokens=200), ) + emotion_result = emotion_generation_result.response # 解析情感标签结果 emotions = [e.strip() for e in emotion_result.replace(",", ",").split(",") if e.strip()] diff --git a/src/chat/image_system/image_manager.py b/src/chat/image_system/image_manager.py index ec545894..5edbf134 100644 --- a/src/chat/image_system/image_manager.py +++ b/src/chat/image_system/image_manager.py @@ -11,8 +11,9 @@ from src.common.logger import get_logger from src.common.database.database import get_db_session from src.common.database.database_model import Images, ImageType from 
src.common.data_models.image_data_model import MaiImage -from src.config.config import global_config, model_config -from src.llm_models.utils_model import LLMRequest +from src.config.config import global_config +from src.common.data_models.llm_service_data_models import LLMImageOptions +from src.services.llm_service import LLMServiceClient install(extra_lines=3) @@ -27,7 +28,7 @@ def _ensure_image_dir_exists(): IMAGE_DIR.mkdir(parents=True, exist_ok=True) -vlm = LLMRequest(model_set=model_config.model_task_config.vlm, request_type="image") +vlm = LLMServiceClient(task_name="vlm", request_type="image") class ImageManager: @@ -260,7 +261,13 @@ class ImageManager: prompt = global_config.personality.visual_style image_base64 = base64.b64encode(image_bytes).decode("utf-8") - description, _ = await vlm.generate_response_for_image(prompt, image_base64, image_format, 0.4) + generation_result = await vlm.generate_response_for_image( + prompt, + image_base64, + image_format, + options=LLMImageOptions(temperature=0.4), + ) + description = generation_result.response if not description: logger.warning("VLM未能生成图片描述") return description or "" diff --git a/src/chat/knowledge/embedding_store.py b/src/chat/knowledge/embedding_store.py index 6d041af6..026c72ee 100644 --- a/src/chat/knowledge/embedding_store.py +++ b/src/chat/knowledge/embedding_store.py @@ -139,14 +139,14 @@ class EmbeddingStore: asyncio.set_event_loop(loop) try: - # 创建新的LLMRequest实例 - from src.llm_models.utils_model import LLMRequest - from src.config.config import model_config + # 创建新的服务层实例 + from src.services.llm_service import LLMServiceClient - llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding") + llm = LLMServiceClient(task_name="embedding", request_type="embedding") # 使用新的事件循环运行异步方法 - embedding, _ = loop.run_until_complete(llm.get_embedding(s)) + embedding_result = loop.run_until_complete(llm.embed_text(s)) + embedding = embedding_result.embedding if embedding and len(embedding) > 0: return embedding @@ -195,13 +195,12 @@ class EmbeddingStore: start_idx, chunk_strs = chunk_data chunk_results = [] - # 为每个线程创建独立的LLMRequest实例 - from src.llm_models.utils_model import LLMRequest - from src.config.config import model_config + # 为每个线程创建独立的服务层实例 + from src.services.llm_service import LLMServiceClient try: - # 创建线程专用的LLM实例 - llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding") + # 创建线程专用的服务层实例 + llm = LLMServiceClient(task_name="embedding", request_type="embedding") for i, s in enumerate(chunk_strs): try: @@ -209,7 +208,8 @@ class EmbeddingStore: loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) try: - embedding = loop.run_until_complete(llm.get_embedding(s)) + embedding_result = loop.run_until_complete(llm.embed_text(s)) + embedding = embedding_result.embedding finally: loop.close() diff --git a/src/chat/knowledge/ie_process.py b/src/chat/knowledge/ie_process.py index d7413bdc..91ba83dc 100644 --- a/src/chat/knowledge/ie_process.py +++ b/src/chat/knowledge/ie_process.py @@ -1,18 +1,27 @@ import asyncio import json import time -from typing import List, Union +from typing import Dict, List, Tuple, Union -from .global_logger import logger -from . import prompt_template -from . import INVALID_ENTITY -from src.llm_models.utils_model import LLMRequest from json_repair import repair_json +from src.services.llm_service import LLMServiceClient -def _extract_json_from_text(text: str): +from . import INVALID_ENTITY +from . 
import prompt_template +from .global_logger import logger + + +def _extract_json_from_text(text: str) -> List[str] | List[List[str]] | Dict[str, object]: # sourcery skip: assign-if-exp, extract-method - """从文本中提取JSON数据的高容错方法""" + """从文本中提取 JSON 数据。 + + Args: + text: 原始模型输出文本。 + + Returns: + List[str] | List[List[str]] | Dict[str, object]: 修复并解析后的 JSON 结果。 + """ if text is None: logger.error("输入文本为None") return [] @@ -46,20 +55,30 @@ def _extract_json_from_text(text: str): return [] -def _entity_extract(llm_req: LLMRequest, paragraph: str) -> List[str]: +def _entity_extract(llm_req: LLMServiceClient, paragraph: str) -> List[str]: # sourcery skip: reintroduce-else, swap-if-else-branches, use-named-expression - """对段落进行实体提取,返回提取出的实体列表(JSON格式)""" + """对单段文本执行实体提取。 + + Args: + llm_req: LLM 服务门面实例。 + paragraph: 待提取实体的原始段落文本。 + + Returns: + List[str]: 提取出的实体列表。 + """ entity_extract_context = prompt_template.build_entity_extract_context(paragraph) # 使用 asyncio.run 来运行异步方法 try: # 如果当前已有事件循环在运行,使用它 loop = asyncio.get_running_loop() - future = asyncio.run_coroutine_threadsafe(llm_req.generate_response_async(entity_extract_context), loop) - response, _ = future.result() + future = asyncio.run_coroutine_threadsafe(llm_req.generate_response(entity_extract_context), loop) + generation_result = future.result() + response = generation_result.response except RuntimeError: # 如果没有运行中的事件循环,直接使用 asyncio.run - response, _ = asyncio.run(llm_req.generate_response_async(entity_extract_context)) + generation_result = asyncio.run(llm_req.generate_response(entity_extract_context)) + response = generation_result.response # 添加调试日志 logger.debug(f"LLM返回的原始响应: {response}") @@ -92,8 +111,21 @@ def _entity_extract(llm_req: LLMRequest, paragraph: str) -> List[str]: return entity_extract_result -def _rdf_triple_extract(llm_req: LLMRequest, paragraph: str, entities: list) -> List[List[str]]: - """对段落进行实体提取,返回提取出的实体列表(JSON格式)""" +def _rdf_triple_extract( + llm_req: LLMServiceClient, + paragraph: str, + entities: List[str], +) -> List[List[str]]: + """对单段文本执行 RDF 三元组提取。 + + Args: + llm_req: LLM 服务门面实例。 + paragraph: 待提取的原始段落文本。 + entities: 已识别出的实体列表。 + + Returns: + List[List[str]]: 提取出的三元组列表。 + """ rdf_extract_context = prompt_template.build_rdf_triple_extract_context( paragraph, entities=json.dumps(entities, ensure_ascii=False) ) @@ -102,11 +134,13 @@ def _rdf_triple_extract(llm_req: LLMRequest, paragraph: str, entities: list) -> try: # 如果当前已有事件循环在运行,使用它 loop = asyncio.get_running_loop() - future = asyncio.run_coroutine_threadsafe(llm_req.generate_response_async(rdf_extract_context), loop) - response, _ = future.result() + future = asyncio.run_coroutine_threadsafe(llm_req.generate_response(rdf_extract_context), loop) + generation_result = future.result() + response = generation_result.response except RuntimeError: # 如果没有运行中的事件循环,直接使用 asyncio.run - response, _ = asyncio.run(llm_req.generate_response_async(rdf_extract_context)) + generation_result = asyncio.run(llm_req.generate_response(rdf_extract_context)) + response = generation_result.response # 添加调试日志 logger.debug(f"RDF LLM返回的原始响应: {response}") @@ -140,8 +174,21 @@ def _rdf_triple_extract(llm_req: LLMRequest, paragraph: str, entities: list) -> def info_extract_from_str( - llm_client_for_ner: LLMRequest, llm_client_for_rdf: LLMRequest, paragraph: str -) -> Union[tuple[None, None], tuple[list[str], list[list[str]]]]: + llm_client_for_ner: LLMServiceClient, + llm_client_for_rdf: LLMServiceClient, + paragraph: str, +) -> Union[Tuple[None, None], Tuple[List[str], List[List[str]]]]: 
+ """从文本中提取实体与三元组信息。 + + Args: + llm_client_for_ner: 实体提取使用的 LLM 服务门面。 + llm_client_for_rdf: RDF 三元组提取使用的 LLM 服务门面。 + paragraph: 原始段落文本。 + + Returns: + Union[Tuple[None, None], Tuple[List[str], List[List[str]]]]: 成功时返回 + ``(实体列表, 三元组列表)``,失败时返回 ``(None, None)``。 + """ try_count = 0 while True: try: @@ -176,17 +223,30 @@ def info_extract_from_str( class IEProcess: - """ - 信息抽取处理器类,提供更方便的批次处理接口。 - """ + """信息抽取处理器。""" - def __init__(self, llm_ner: LLMRequest, llm_rdf: LLMRequest = None): + def __init__( + self, + llm_ner: LLMServiceClient, + llm_rdf: LLMServiceClient | None = None, + ) -> None: + """初始化信息抽取处理器。 + + Args: + llm_ner: 实体提取使用的 LLM 服务门面。 + llm_rdf: RDF 三元组提取使用的 LLM 服务门面;为空时复用 `llm_ner`。 + """ self.llm_ner = llm_ner self.llm_rdf = llm_rdf or llm_ner - async def process_paragraphs(self, paragraphs: List[str]) -> List[dict]: - """ - 异步处理多个段落。 + async def process_paragraphs(self, paragraphs: List[str]) -> List[Dict[str, object]]: + """异步处理多个段落。 + + Args: + paragraphs: 待处理的段落列表。 + + Returns: + List[Dict[str, object]]: 每个成功段落对应的抽取结果。 """ from .utils.hash import get_sha256 diff --git a/src/chat/knowledge/lpmm_ops.py b/src/chat/knowledge/lpmm_ops.py index acaac4ca..2fb72709 100644 --- a/src/chat/knowledge/lpmm_ops.py +++ b/src/chat/knowledge/lpmm_ops.py @@ -91,13 +91,14 @@ class LPMMOperations: # 2. 实体与三元组抽取 (内部调用大模型) from src.chat.knowledge.ie_process import IEProcess - from src.llm_models.utils_model import LLMRequest - from src.config.config import model_config + from src.services.llm_service import LLMServiceClient - llm_ner = LLMRequest( - model_set=model_config.model_task_config.lpmm_entity_extract, request_type="lpmm.entity_extract" + llm_ner = LLMServiceClient( + task_name="lpmm_entity_extract", request_type="lpmm.entity_extract" + ) + llm_rdf = LLMServiceClient( + task_name="lpmm_rdf_build", request_type="lpmm.rdf_build" ) - llm_rdf = LLMRequest(model_set=model_config.model_task_config.lpmm_rdf_build, request_type="lpmm.rdf_build") ie_process = IEProcess(llm_ner, llm_rdf) logger.info(f"[Plugin API] 正在对 {len(paragraphs)} 段文本执行信息抽取...") diff --git a/src/chat/planner_actions/action_modifier.py b/src/chat/planner_actions/action_modifier.py index 0d81c18f..f94997e1 100644 --- a/src/chat/planner_actions/action_modifier.py +++ b/src/chat/planner_actions/action_modifier.py @@ -149,7 +149,7 @@ class ActionModifier: random.shuffle(actions_to_check) for action_name, action_info in actions_to_check: - activation_type = action_info.activation_type or action_info.focus_activation_type + activation_type = action_info.activation_type if activation_type == ActionActivationType.ALWAYS: continue # 总是激活,无需处理 diff --git a/src/chat/planner_actions/planner.py b/src/chat/planner_actions/planner.py index b21efa6b..a0e6f898 100644 --- a/src/chat/planner_actions/planner.py +++ b/src/chat/planner_actions/planner.py @@ -19,9 +19,9 @@ from src.chat.planner_actions.action_manager import ActionManager from src.chat.utils.utils import get_chat_type_and_target_info, is_bot_self from src.common.data_models.info_data_model import ActionPlannerInfo from src.common.logger import get_logger -from src.config.config import global_config, model_config +from src.config.config import global_config from src.core.types import ActionActivationType, ActionInfo, ComponentType -from src.llm_models.utils_model import LLMRequest +from src.services.llm_service import LLMServiceClient from src.person_info.person_info import Person from src.plugin_runtime.component_query import component_query_service from src.prompt.prompt_manager 
import prompt_manager @@ -46,8 +46,8 @@ class ActionPlanner: self.log_prefix = f"[{_chat_manager.get_session_name(chat_id) or chat_id}]" self.action_manager = action_manager # LLM规划器配置 - self.planner_llm = LLMRequest( - model_set=model_config.model_task_config.planner, request_type="planner" + self.planner_llm = LLMServiceClient( + task_name="planner", request_type="planner" ) # 用于动作规划 self.last_obs_time_mark = 0.0 @@ -725,7 +725,9 @@ class ActionPlanner: try: # 调用LLM llm_start = time.perf_counter() - llm_content, (reasoning_content, _, _) = await self.planner_llm.generate_response_async(prompt=prompt) + generation_result = await self.planner_llm.generate_response(prompt=prompt) + llm_content = generation_result.response + reasoning_content = generation_result.reasoning llm_duration_ms = (time.perf_counter() - llm_start) * 1000 llm_reasoning = reasoning_content diff --git a/src/chat/replyer/group_generator.py b/src/chat/replyer/group_generator.py index 4ffa14a7..db007d3d 100644 --- a/src/chat/replyer/group_generator.py +++ b/src/chat/replyer/group_generator.py @@ -10,8 +10,8 @@ from datetime import datetime from src.common.logger import get_logger from src.common.data_models.info_data_model import ActionPlannerInfo from src.common.data_models.llm_data_model import LLMGenerationDataModel -from src.config.config import global_config, model_config -from src.llm_models.utils_model import LLMRequest +from src.config.config import global_config +from src.services.llm_service import LLMServiceClient from maim_message import BaseMessageInfo, MessageBase, Seg, UserInfo as MaimUserInfo from src.common.data_models.mai_message_data_model import MaiMessage @@ -56,7 +56,9 @@ class DefaultReplyer: chat_stream: 当前绑定的聊天会话。 request_type: LLM 请求类型标识。 """ - self.express_model = LLMRequest(model_set=model_config.model_task_config.replyer, request_type=request_type) + self.express_model = LLMServiceClient( + task_name="replyer", request_type=request_type + ) self.chat_stream = chat_stream self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_stream.session_id) @@ -1158,9 +1160,11 @@ class DefaultReplyer: # else: # logger.debug(f"\nreplyer_Prompt:{prompt}\n") - content, (reasoning_content, model_name, tool_calls) = await self.express_model.generate_response_async( - prompt - ) + generation_result = await self.express_model.generate_response(prompt) + content = generation_result.response + reasoning_content = generation_result.reasoning + model_name = generation_result.model_name + tool_calls = generation_result.tool_calls # 移除 content 前后的换行符和空格 content = content.strip() @@ -1200,11 +1204,15 @@ class DefaultReplyer: template_prompt.add_context("sender", sender) template_prompt.add_context("target_message", target) prompt = await prompt_manager.render_prompt(template_prompt) - _, _, _, _, tool_calls = await llm_api.generate_with_model_with_tools( - prompt, - model_config=model_config.model_task_config.tool_use, - tool_options=[search_knowledge_tool.get_tool_definition()], + generation_result = await llm_api.generate( + llm_api.LLMServiceRequest( + task_name="tool_use", + request_type="replyer.lpmm_knowledge", + prompt=prompt, + tool_options=[search_knowledge_tool.get_tool_definition()], + ) ) + tool_calls = generation_result.completion.tool_calls # logger.info(f"工具调用提示词: {prompt}") # logger.info(f"工具调用: {tool_calls}") diff --git a/src/chat/replyer/private_generator.py b/src/chat/replyer/private_generator.py index c125a42f..ccbb0086 100644 --- a/src/chat/replyer/private_generator.py 
+++ b/src/chat/replyer/private_generator.py @@ -9,8 +9,8 @@ from datetime import datetime from src.common.logger import get_logger from src.common.data_models.info_data_model import ActionPlannerInfo from src.common.data_models.llm_data_model import LLMGenerationDataModel -from src.config.config import global_config, model_config -from src.llm_models.utils_model import LLMRequest +from src.config.config import global_config +from src.services.llm_service import LLMServiceClient from maim_message import BaseMessageInfo, MessageBase, Seg, UserInfo as MaimUserInfo from src.common.data_models.mai_message_data_model import MaiMessage @@ -52,7 +52,9 @@ class PrivateReplyer: chat_stream: 当前绑定的聊天会话。 request_type: LLM 请求类型标识。 """ - self.express_model = LLMRequest(model_set=model_config.model_task_config.replyer, request_type=request_type) + self.express_model = LLMServiceClient( + task_name="replyer", request_type=request_type + ) self.chat_stream = chat_stream self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_stream.session_id) # self.memory_activator = MemoryActivator() @@ -997,9 +999,11 @@ class PrivateReplyer: else: logger.debug(f"\n{prompt}\n") - content, (reasoning_content, model_name, tool_calls) = await self.express_model.generate_response_async( - prompt - ) + generation_result = await self.express_model.generate_response(prompt) + content = generation_result.response + reasoning_content = generation_result.reasoning + model_name = generation_result.model_name + tool_calls = generation_result.tool_calls content = content.strip() diff --git a/src/chat/tool_executor.py b/src/chat/tool_executor.py index aa99fce8..bfe7ce96 100644 --- a/src/chat/tool_executor.py +++ b/src/chat/tool_executor.py @@ -4,16 +4,18 @@ 自动判断并执行相应的工具,返回结构化的工具执行结果。 """ -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, cast import hashlib import time from src.common.logger import get_logger -from src.config.config import global_config, model_config +from src.config.config import global_config from src.core.announcement_manager import global_announcement_manager from src.llm_models.payload_content import ToolCall -from src.llm_models.utils_model import LLMRequest +from src.llm_models.payload_content.tool_option import ToolDefinitionInput +from src.common.data_models.llm_service_data_models import LLMGenerationOptions +from src.services.llm_service import LLMServiceClient from src.plugin_runtime.component_query import component_query_service from src.prompt.prompt_manager import prompt_manager @@ -33,7 +35,9 @@ class ToolExecutor: self.chat_stream = _chat_manager.get_session_by_session_id(self.chat_id) self.log_prefix = f"[{_chat_manager.get_session_name(self.chat_id) or self.chat_id}]" - self.llm_model = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="tool_executor") + self.llm_model = LLMServiceClient( + task_name="tool_use", request_type="tool_executor" + ) self.enable_cache = enable_cache self.cache_ttl = cache_ttl @@ -69,9 +73,11 @@ class ToolExecutor: logger.debug(f"{self.log_prefix}开始LLM工具调用分析") - response, (reasoning_content, model_name, tool_calls) = await self.llm_model.generate_response_async( - prompt=prompt, tools=tools, raise_when_empty=False + generation_result = await self.llm_model.generate_response( + prompt=prompt, + options=LLMGenerationOptions(tool_options=tools, raise_when_empty=False), ) + tool_calls = generation_result.tool_calls tool_results, used_tools = await 
self.execute_tool_calls(tool_calls)
@@ -85,11 +91,15 @@
             return tool_results, used_tools, prompt
         return tool_results, [], ""
 
-    def _get_tool_definitions(self) -> List[Dict[str, Any]]:
+    def _get_tool_definitions(self) -> List[ToolDefinitionInput]:
         """获取 LLM 可用的工具定义列表"""
         all_tools = component_query_service.get_llm_available_tools()
         user_disabled_tools = global_announcement_manager.get_disabled_chat_tools(self.chat_id)
-        return [info.get_llm_definition() for name, info in all_tools.items() if name not in user_disabled_tools]
+        return [
+            cast(ToolDefinitionInput, info.get_llm_definition())
+            for name, info in all_tools.items()
+            if name not in user_disabled_tools
+        ]
 
     async def execute_tool_calls(self, tool_calls: Optional[List[ToolCall]]) -> Tuple[List[Dict[str, Any]], List[str]]:
         """执行工具调用列表"""
diff --git a/src/chat/utils/utils.py b/src/chat/utils/utils.py
index 07cec0b4..aa14e790 100644
--- a/src/chat/utils/utils.py
+++ b/src/chat/utils/utils.py
@@ -13,8 +13,8 @@ import jieba
 from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
 from src.chat.message_receive.message import SessionMessage
 from src.common.logger import get_logger
-from src.config.config import global_config, model_config
-from src.llm_models.utils_model import LLMRequest
+from src.config.config import global_config
+from src.services.llm_service import LLMServiceClient
 from src.person_info.person_info import Person
 
 from .typo_generator import ChineseTypoGenerator
@@ -235,10 +235,11 @@ def is_mentioned_bot_in_message(message: SessionMessage) -> tuple[bool, bool, fl
 
 async def get_embedding(text, request_type="embedding") -> Optional[List[float]]:
     """获取文本的embedding向量"""
-    # 每次都创建新的LLMRequest实例以避免事件循环冲突
-    llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type=request_type)
+    # 每次都创建新的服务层实例以避免事件循环冲突
+    llm = LLMServiceClient(task_name="embedding", request_type=request_type)
     try:
-        embedding, _ = await llm.get_embedding(text)
+        embedding_result = await llm.embed_text(text)
+        embedding = embedding_result.embedding
     except Exception as e:
         logger.error(f"获取embedding失败: {str(e)}")
         embedding = None
diff --git a/src/common/data_models/llm_service_data_models.py b/src/common/data_models/llm_service_data_models.py
new file mode 100644
index 00000000..15b530ca
--- /dev/null
+++ b/src/common/data_models/llm_service_data_models.py
@@ -0,0 +1,187 @@
+"""LLM 服务层与编排层共享数据模型。
+
+该模块集中定义 LLM 服务层与底层编排器共同使用的请求、选项与结果对象,
+用于替代散落在各层之间的复杂元组返回值。
+"""
+
+from dataclasses import dataclass, field
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, TypeAlias
+
+import asyncio
+
+from src.common.data_models import BaseDataModel
+from src.llm_models.payload_content.resp_format import RespFormat
+from src.llm_models.payload_content.tool_option import ToolCall, ToolDefinitionInput
+
+if TYPE_CHECKING:
+    from src.llm_models.model_client.base_client import BaseClient
+    from src.llm_models.payload_content.message import Message
+
+
+PromptMessage: TypeAlias = Dict[str, Any]
+"""统一的原始提示消息结构。"""
+
+PromptInput: TypeAlias = str | List[PromptMessage]
+"""统一的提示输入类型。"""
+
+MessageFactory: TypeAlias = Callable[["BaseClient"], List["Message"]]
+"""统一的消息工厂类型。"""
+
+
+@dataclass(slots=True)
+class LLMServiceRequest(BaseDataModel):
+    """LLM 服务层统一请求对象。"""
+
+    task_name: str
+    request_type: str
+    prompt: PromptInput | None = None
+    message_factory: MessageFactory | None = None
+    tool_options: List[ToolDefinitionInput] | None = None
+    temperature: float | None = None
+    max_tokens: int | None = None
+    response_format: RespFormat | None = None
+    interrupt_flag: asyncio.Event | None = None
+
+    def __post_init__(self) -> None:
+        """校验请求对象的必要字段。
+
+        Raises:
+            ValueError: 当 `task_name` 为空,或 `prompt` 与 `message_factory`
+                的组合非法时抛出。
+        """
+        self.task_name = self.task_name.strip()
+        if not self.task_name:
+            raise ValueError("`task_name` 不能为空")
+        has_prompt = self.prompt is not None
+        has_message_factory = self.message_factory is not None
+        if has_prompt == has_message_factory:
+            raise ValueError("`prompt` 与 `message_factory` 必须且只能提供一个")
+
+
+@dataclass(slots=True)
+class LLMResponseResult(BaseDataModel):
+    """单次 LLM 响应结果。"""
+
+    response: str = field(default_factory=str)
+    reasoning: str = field(default_factory=str)
+    model_name: str = field(default_factory=str)
+    tool_calls: List[ToolCall] | None = None
+
+
+@dataclass(slots=True)
+class LLMServiceResult(BaseDataModel):
+    """LLM 服务层统一响应对象。"""
+
+    success: bool = False
+    completion: LLMResponseResult = field(default_factory=LLMResponseResult)
+    error: str | None = None
+
+    @classmethod
+    def from_response_result(cls, completion: LLMResponseResult) -> "LLMServiceResult":
+        """从单次 LLM 响应结果构建服务响应。
+
+        Args:
+            completion: 单次 LLM 响应结果。
+
+        Returns:
+            LLMServiceResult: 标记为成功的服务响应对象。
+        """
+        return cls(
+            success=True,
+            completion=completion,
+            error=None,
+        )
+
+    @classmethod
+    def from_error(cls, error_message: str, error_detail: str | None = None) -> "LLMServiceResult":
+        """构建失败的服务响应对象。
+
+        Args:
+            error_message: 对上层展示的错误消息。
+            error_detail: 底层错误详情。
+
+        Returns:
+            LLMServiceResult: 标记为失败的服务响应对象。
+        """
+        return cls(
+            success=False,
+            completion=LLMResponseResult(response=error_message),
+            error=error_detail or error_message,
+        )
+
+    def to_capability_payload(self) -> Dict[str, Any]:
+        """转换为插件能力层可直接返回的结构。
+
+        Returns:
+            Dict[str, Any]: 标准化后的能力返回值。
+        """
+        payload: Dict[str, Any] = {
+            "success": self.success,
+            "response": self.completion.response,
+            "reasoning": self.completion.reasoning,
+            "model_name": self.completion.model_name,
+        }
+        if self.completion.tool_calls is not None:
+            payload["tool_calls"] = [
+                {
+                    "id": tool_call.call_id,
+                    "function": {
+                        "name": tool_call.func_name,
+                        "arguments": tool_call.args or {},
+                    },
+                }
+                for tool_call in self.completion.tool_calls
+            ]
+        if self.error:
+            payload["error"] = self.error
+        return payload
+
+
+@dataclass(slots=True)
+class LLMGenerationOptions(BaseDataModel):
+    """LLM 文本生成选项。"""
+
+    temperature: float | None = None
+    max_tokens: int | None = None
+    tool_options: List[ToolDefinitionInput] | None = None
+    response_format: RespFormat | None = None
+    interrupt_flag: asyncio.Event | None = None
+    raise_when_empty: bool = True
+
+
+@dataclass(slots=True)
+class LLMImageOptions(BaseDataModel):
+    """LLM 图像理解选项。"""
+
+    temperature: float | None = None
+    max_tokens: int | None = None
+    interrupt_flag: asyncio.Event | None = None
+
+
+@dataclass(slots=True)
+class LLMAudioTranscriptionResult(BaseDataModel):
+    """LLM 音频转写结果。"""
+
+    text: str | None = None
+
+
+@dataclass(slots=True)
+class LLMEmbeddingResult(BaseDataModel):
+    """LLM 向量生成结果。"""
+
+    embedding: List[float] = field(default_factory=list)
+    model_name: str = field(default_factory=str)
+
+
+__all__ = [
+    "LLMAudioTranscriptionResult",
+    "LLMEmbeddingResult",
+    "LLMGenerationOptions",
+    "LLMImageOptions",
+    "LLMResponseResult",
+    "LLMServiceRequest",
+    "LLMServiceResult",
+    "MessageFactory",
+    "PromptInput",
+    "PromptMessage",
+]
diff --git a/src/common/utils/utils_voice.py b/src/common/utils/utils_voice.py
index 651febf0..cef30119 100644
---
a/src/common/utils/utils_voice.py +++ b/src/common/utils/utils_voice.py @@ -4,16 +4,15 @@ from typing import Optional import base64 from src.common.logger import get_logger -from src.config.config import global_config, model_config -from src.llm_models.utils_model import LLMRequest +from src.config.config import global_config +from src.services.llm_service import LLMServiceClient install(extra_lines=3) logger = get_logger("voice_utils") -# TODO: 在LLMRequest重构后修改这里 -asr_model = LLMRequest(model_set=model_config.model_task_config.voice, request_type="audio") +asr_model = LLMServiceClient(task_name="voice", request_type="audio") async def get_voice_text(voice_bytes: bytes) -> Optional[str]: @@ -30,7 +29,8 @@ async def get_voice_text(voice_bytes: bytes) -> Optional[str]: return None try: voice_base64 = base64.b64encode(voice_bytes).decode("utf-8") - text = await asr_model.generate_response_for_voice(voice_base64) + transcription_result = await asr_model.transcribe_audio(voice_base64) + text = transcription_result.text if not text: logger.warning("语音转文字结果为空") diff --git a/src/config/model_configs.py b/src/config/model_configs.py index 6f10ff83..3f0feb54 100644 --- a/src/config/model_configs.py +++ b/src/config/model_configs.py @@ -1,7 +1,35 @@ +from enum import Enum from typing import Any -from .config_base import ConfigBase, Field from src.common.i18n import t +from .config_base import ConfigBase, Field + + +class OpenAICompatibleAuthType(str, Enum): + """OpenAI 兼容接口的鉴权方式。""" + + BEARER = "bearer" + HEADER = "header" + QUERY = "query" + NONE = "none" + + +class ReasoningParseMode(str, Enum): + """推理内容解析策略。""" + + AUTO = "auto" + NATIVE = "native" + THINK_TAG = "think_tag" + NONE = "none" + + +class ToolArgumentParseMode(str, Enum): + """工具调用参数的解析策略。""" + + AUTO = "auto" + STRICT = "strict" + REPAIR = "repair" + DOUBLE_DECODE = "double_decode" class APIProvider(ConfigBase): @@ -33,7 +61,7 @@ class APIProvider(ConfigBase): "x-icon": "key", }, ) - """API密钥""" + """API密钥。对于不需要鉴权的兼容端点,可将 `auth_type` 设为 `none`。""" client_type: str = Field( default="openai", @@ -44,6 +72,105 @@ class APIProvider(ConfigBase): ) """客户端类型 (可选: openai/google, 默认为openai)""" + auth_type: str = Field( + default=OpenAICompatibleAuthType.BEARER.value, + json_schema_extra={ + "x-widget": "select", + "x-icon": "shield", + }, + ) + """OpenAI 兼容接口的鉴权方式。可选值:`bearer`、`header`、`query`、`none`。""" + + auth_header_name: str = Field( + default="Authorization", + json_schema_extra={ + "x-widget": "input", + "x-icon": "header", + }, + ) + """当 `auth_type` 为 `header` 时使用的请求头名称。""" + + auth_header_prefix: str = Field( + default="Bearer", + json_schema_extra={ + "x-widget": "input", + "x-icon": "shield-check", + }, + ) + """当 `auth_type` 为 `header` 时使用的请求头前缀。留空表示直接发送原始密钥。""" + + auth_query_name: str = Field( + default="api_key", + json_schema_extra={ + "x-widget": "input", + "x-icon": "link", + }, + ) + """当 `auth_type` 为 `query` 时使用的查询参数名称。""" + + default_headers: dict[str, str] = Field( + default_factory=dict, + json_schema_extra={ + "x-widget": "custom", + "x-icon": "header", + }, + ) + """所有请求默认附带的 HTTP Header。""" + + default_query: dict[str, str] = Field( + default_factory=dict, + json_schema_extra={ + "x-widget": "custom", + "x-icon": "list-filter", + }, + ) + """所有请求默认附带的查询参数。""" + + organization: str | None = Field( + default=None, + json_schema_extra={ + "x-widget": "input", + "x-icon": "building-2", + }, + ) + """OpenAI 官方接口可选的 `organization`。""" + + project: str | None = Field( + default=None, + json_schema_extra={ + "x-widget": 
"input", + "x-icon": "folder-kanban", + }, + ) + """OpenAI 官方接口可选的 `project`。""" + + model_list_endpoint: str = Field( + default="/models", + json_schema_extra={ + "x-widget": "input", + "x-icon": "list", + }, + ) + """模型列表端点路径。适用于 OpenAI 兼容接口的探测与管理。""" + + reasoning_parse_mode: str = Field( + default=ReasoningParseMode.AUTO.value, + json_schema_extra={ + "x-widget": "select", + "x-icon": "brain", + }, + ) + """推理内容解析模式。可选值:`auto`、`native`、`think_tag`、`none`。""" + + tool_argument_parse_mode: str = Field( + default=ToolArgumentParseMode.AUTO.value, + json_schema_extra={ + "x-widget": "select", + "x-icon": "braces", + }, + ) + """工具参数解析模式。可选值:`auto`、`strict`、`repair`、`double_decode`。""" + max_retry: int = Field( default=2, ge=0, @@ -76,15 +203,26 @@ class APIProvider(ConfigBase): ) """重试间隔 (如果API调用失败, 重试的间隔时间, 单位: 秒)""" - def model_post_init(self, context: Any = None): - """确保api_key在repr中不被显示""" - if not self.api_key: + def model_post_init(self, context: Any = None) -> None: + """执行 API 提供商配置的后置校验。 + + Args: + context: Pydantic 传入的上下文对象。 + + Raises: + ValueError: 当配置项缺失或组合不合法时抛出。 + """ + if self.auth_type != OpenAICompatibleAuthType.NONE and not self.api_key: raise ValueError(t("config.api_key_empty")) if not self.base_url and self.client_type != "gemini": # TODO: 允许gemini使用base_url raise ValueError(t("config.api_base_url_empty")) if not self.name: raise ValueError(t("config.api_provider_name_empty")) - return super().model_post_init(context) + if self.auth_type == OpenAICompatibleAuthType.HEADER and not self.auth_header_name.strip(): + raise ValueError("当 auth_type=header 时,auth_header_name 不能为空") + if self.auth_type == OpenAICompatibleAuthType.QUERY and not self.auth_query_name.strip(): + raise ValueError("当 auth_type=query 时,auth_query_name 不能为空") + super().model_post_init(context) class ModelInfo(ConfigBase): diff --git a/src/core/types.py b/src/core/types.py index 535352f3..aff857a3 100644 --- a/src/core/types.py +++ b/src/core/types.py @@ -1,12 +1,13 @@ -import copy -import warnings from dataclasses import dataclass, field, fields from enum import Enum -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional + +import copy +import warnings + from maim_message import Seg -from src.llm_models.payload_content.tool_option import ToolParamType as ToolParamType -from src.llm_models.payload_content.tool_option import ToolCall as ToolCall +from src.llm_models.payload_content.tool_option import ToolCall # from src.common.data_models.message_data_model import ReplyContentType as ReplyContentType # from src.common.data_models.message_data_model import ReplyContent as ReplyContent # from src.common.data_models.message_data_model import ForwardNode as ForwardNode @@ -15,49 +16,42 @@ from src.llm_models.payload_content.tool_option import ToolCall as ToolCall # 组件类型枚举 class ComponentType(Enum): - """组件类型枚举""" + """Host 内部使用的组件类型枚举。""" ACTION = "action" # 动作组件 COMMAND = "command" # 命令组件 - TOOL = "tool" # 服务组件(预留) - SCHEDULER = "scheduler" # 定时任务组件(预留) - EVENT_HANDLER = "event_handler" # 事件处理组件(预留) + TOOL = "tool" # 工具组件 def __str__(self) -> str: + """返回枚举值字符串。 + + Returns: + str: 当前组件类型对应的字符串值。 + """ return self.value # 动作激活类型枚举 class ActionActivationType(Enum): - """动作激活类型枚举""" + """动作激活类型枚举。""" NEVER = "never" # 从不激活(默认关闭) ALWAYS = "always" # 默认参与到planner RANDOM = "random" # 随机启用action到planner KEYWORD = "keyword" # 关键词触发启用action到planner - def __str__(self): - return self.value + def __str__(self) -> str: + """返回枚举值字符串。 - -# 聊天模式枚举 -class ChatMode(Enum): - 
"""聊天模式枚举""" - - FOCUS = "focus" # Focus聊天模式 - NORMAL = "normal" # Normal聊天模式 - PRIORITY = "priority" # 优先级聊天模式 - ALL = "all" # 所有聊天模式 - - def __str__(self): + Returns: + str: 当前激活类型对应的字符串值。 + """ return self.value # 事件类型枚举 class EventType(Enum): - """ - 事件类型枚举类 - """ + """事件类型枚举。""" ON_START = "on_start" # 启动事件,用于调用按时任务 ON_STOP = "on_stop" # 停止事件,用于调用按时任务 @@ -72,185 +66,96 @@ class EventType(Enum): UNKNOWN = "unknown" # 未知事件类型 def __str__(self) -> str: + """返回枚举值字符串。 + + Returns: + str: 当前事件类型对应的字符串值。 + """ return self.value -@dataclass -class PythonDependency: - """Python包依赖信息""" - - package_name: str # 包名称 - version: str = "" # 版本要求,例如: ">=1.0.0", "==2.1.3", ""表示任意版本 - optional: bool = False # 是否为可选依赖 - description: str = "" # 依赖描述 - install_name: str = "" # 安装时的包名(如果与import名不同) - - def __post_init__(self): - if not self.install_name: - self.install_name = self.package_name - - def get_pip_requirement(self) -> str: - """获取pip安装格式的依赖字符串""" - if self.version: - return f"{self.install_name}{self.version}" - return self.install_name - - -@dataclass +@dataclass(slots=True) class ComponentInfo: - """组件信息""" + """Host 内部使用的组件信息快照。""" - name: str # 组件名称 - component_type: ComponentType # 组件类型 - description: str = "" # 组件描述 - enabled: bool = True # 是否启用 - plugin_name: str = "" # 所属插件名称 - is_built_in: bool = False # 是否为内置组件 - metadata: Dict[str, Any] = field(default_factory=dict) # 额外元数据 + name: str + """组件名称。""" - def __post_init__(self): - if self.metadata is None: - self.metadata = {} + description: str = "" + """组件描述。""" + + enabled: bool = True + """组件是否启用。""" + + plugin_name: str = "" + """所属插件 ID。""" + + component_type: ComponentType = field(init=False) + """组件类型。""" -@dataclass +@dataclass(slots=True) class ActionInfo(ComponentInfo): - """动作组件信息""" + """供 Planner 与回复链使用的动作信息快照。""" action_parameters: Dict[str, str] = field( default_factory=dict ) # 动作参数与描述,例如 {"param1": "描述1", "param2": "描述2"} action_require: List[str] = field(default_factory=list) # 动作需求说明 associated_types: List[str] = field(default_factory=list) # 关联的消息类型 - # 激活类型相关 - focus_activation_type: ActionActivationType = ActionActivationType.ALWAYS # 已弃用 - normal_activation_type: ActionActivationType = ActionActivationType.ALWAYS # 已弃用 activation_type: ActionActivationType = ActionActivationType.ALWAYS random_activation_probability: float = 0.0 activation_keywords: List[str] = field(default_factory=list) # 激活关键词列表 keyword_case_sensitive: bool = False - # 模式和并行设置 parallel_action: bool = False + component_type: ComponentType = field(init=False, default=ComponentType.ACTION) + """组件类型。""" - def __post_init__(self): - super().__post_init__() - if self.activation_keywords is None: - self.activation_keywords = [] - if self.action_parameters is None: - self.action_parameters = {} - if self.action_require is None: - self.action_require = [] - if self.associated_types is None: - self.associated_types = [] - self.component_type = ComponentType.ACTION + def __post_init__(self) -> None: + """归一化动作快照中的集合字段。""" + self.action_parameters = dict(self.action_parameters or {}) + self.action_require = list(self.action_require or []) + self.associated_types = list(self.associated_types or []) + self.activation_keywords = list(self.activation_keywords or []) -@dataclass +@dataclass(slots=True) class CommandInfo(ComponentInfo): - """命令组件信息""" + """供命令处理链使用的命令信息快照。""" - command_pattern: str = "" # 命令匹配模式(正则表达式) - - def __post_init__(self): - super().__post_init__() - self.component_type = ComponentType.COMMAND + component_type: ComponentType = 
field(init=False, default=ComponentType.COMMAND) + """组件类型。""" -@dataclass +@dataclass(slots=True) class ToolInfo(ComponentInfo): - """工具组件信息""" + """供工具执行链使用的工具信息快照。""" - tool_parameters: List[Tuple[str, ToolParamType, str, bool, List[str] | None]] = field( - default_factory=list - ) # 工具参数定义 - tool_description: str = "" # 工具描述 + parameters_schema: Dict[str, Any] | None = None + """对象级工具参数 Schema。""" - def __post_init__(self): - super().__post_init__() - self.component_type = ComponentType.TOOL + component_type: ComponentType = field(init=False, default=ComponentType.TOOL) + """组件类型。""" - def get_llm_definition(self) -> dict: - """生成 LLM function-calling 所需的工具定义""" - return { + def get_llm_definition(self) -> Dict[str, Any]: + """生成供 LLM 使用的规范化工具定义。 + + Returns: + Dict[str, Any]: 统一工具定义字典。 + """ + definition: Dict[str, Any] = { "name": self.name, - "description": self.tool_description, - "parameters": self.tool_parameters, + "description": self.description, } + if self.parameters_schema is not None: + definition["parameters_schema"] = copy.deepcopy(self.parameters_schema) + return definition -@dataclass -class EventHandlerInfo(ComponentInfo): - """事件处理器组件信息""" - - event_type: EventType | str = EventType.ON_MESSAGE # 监听事件类型 - intercept_message: bool = False # 是否拦截消息处理(默认不拦截) - weight: int = 0 # 事件处理器权重,决定执行顺序 - - def __post_init__(self): - super().__post_init__() - self.component_type = ComponentType.EVENT_HANDLER - - -@dataclass -class PluginInfo: - """插件信息""" - - display_name: str # 插件显示名称 - name: str # 插件名称 - description: str # 插件描述 - version: str = "1.0.0" # 插件版本 - author: str = "" # 插件作者 - enabled: bool = True # 是否启用 - is_built_in: bool = False # 是否为内置插件 - components: List[ComponentInfo] = field(default_factory=list) # 包含的组件列表 - dependencies: List[str] = field(default_factory=list) # 依赖的其他插件 - python_dependencies: List[PythonDependency] = field(default_factory=list) # Python包依赖 - config_file: str = "" # 配置文件路径 - metadata: Dict[str, Any] = field(default_factory=dict) # 额外元数据 - # 新增:manifest相关信息 - manifest_data: Dict[str, Any] = field(default_factory=dict) # manifest文件数据 - license: str = "" # 插件许可证 - homepage_url: str = "" # 插件主页 - repository_url: str = "" # 插件仓库地址 - keywords: List[str] = field(default_factory=list) # 插件关键词 - categories: List[str] = field(default_factory=list) # 插件分类 - min_host_version: str = "" # 最低主机版本要求 - max_host_version: str = "" # 最高主机版本要求 - - def __post_init__(self): - if self.components is None: - self.components = [] - if self.dependencies is None: - self.dependencies = [] - if self.python_dependencies is None: - self.python_dependencies = [] - if self.metadata is None: - self.metadata = {} - if self.manifest_data is None: - self.manifest_data = {} - if self.keywords is None: - self.keywords = [] - if self.categories is None: - self.categories = [] - - def get_missing_packages(self) -> List[PythonDependency]: - """检查缺失的Python包""" - missing = [] - for dep in self.python_dependencies: - try: - __import__(dep.package_name) - except ImportError: - if not dep.optional: - missing.append(dep) - return missing - - def get_pip_requirements(self) -> List[str]: - """获取所有pip安装格式的依赖""" - return [dep.get_pip_requirement() for dep in self.python_dependencies] - - -@dataclass +@dataclass(slots=True) class ModifyFlag: + """消息修改标记集合。""" + modify_message_segments: bool = False modify_plain_text: bool = False modify_llm_prompt: bool = False @@ -258,9 +163,9 @@ class ModifyFlag: modify_llm_response_reasoning: bool = False -@dataclass +@dataclass(slots=True) class MaiMessages: - 
"""MaiM插件消息""" + """核心事件系统使用的统一消息模型。""" message_segments: List[Seg] = field(default_factory=list) """消息段列表,支持多段消息""" @@ -306,11 +211,17 @@ class MaiMessages: _modify_flags: ModifyFlag = field(default_factory=ModifyFlag) - def __post_init__(self): + def __post_init__(self) -> None: + """归一化消息段列表。""" if self.message_segments is None: self.message_segments = [] - def deepcopy(self): + def deepcopy(self) -> "MaiMessages": + """深拷贝当前消息对象。 + + Returns: + MaiMessages: 深拷贝后的消息对象。 + """ return copy.deepcopy(self) def to_transport_dict(self) -> Dict[str, Any]: @@ -347,6 +258,14 @@ class MaiMessages: @staticmethod def _serialize_transport_value(value: Any) -> Any: + """递归序列化字段值为可传输结构。 + + Args: + value: 任意字段值。 + + Returns: + Any: 可用于 IPC 传输的纯 Python 值。 + """ if isinstance(value, (str, int, float, bool)) or value is None: return value if isinstance(value, Enum): @@ -367,13 +286,22 @@ class MaiMessages: @staticmethod def _deserialize_transport_field(field_name: str, value: Any) -> Any: + """反序列化特定字段的传输值。 + + Args: + field_name: 字段名称。 + value: 传输层返回的字段值。 + + Returns: + Any: 反序列化后的字段值。 + """ if field_name == "message_segments" and isinstance(value, list): deserialized_segments: List[Seg] = [] for segment in value: if isinstance(segment, Seg): deserialized_segments.append(segment) elif isinstance(segment, dict) and "type" in segment: - deserialized_segments.append(Seg(type=segment.get("type", "text"), data=segment.get("data"))) + deserialized_segments.append(Seg(type=segment.get("type", "text"), data=segment.get("data", ""))) return deserialized_segments if field_name == "llm_response_tool_call" and isinstance(value, list): @@ -393,15 +321,15 @@ class MaiMessages: return value - def modify_message_segments(self, new_segments: List[Seg], suppress_warning: bool = False): - """ - 修改消息段列表 + def modify_message_segments(self, new_segments: List[Seg], suppress_warning: bool = False) -> None: + """修改消息段列表。 Warning: - 在生成了plain_text的情况下调用此方法,可能会导致plain_text内容与消息段不一致 + 在生成了 ``plain_text`` 的情况下调用此方法,可能会导致文本与消息段不一致。 Args: - new_segments (List[Seg]): 新的消息段列表 + new_segments: 新的消息段列表。 + suppress_warning: 是否抑制潜在不一致警告。 """ if self.plain_text and not suppress_warning: warnings.warn( @@ -412,15 +340,15 @@ class MaiMessages: self.message_segments = new_segments self._modify_flags.modify_message_segments = True - def modify_llm_prompt(self, new_prompt: str, suppress_warning: bool = False): - """ - 修改LLM提示词 + def modify_llm_prompt(self, new_prompt: str, suppress_warning: bool = False) -> None: + """修改 LLM 提示词。 Warning: - 在没有生成llm_prompt的情况下调用此方法,可能会导致修改无效 + 在没有生成 ``llm_prompt`` 的情况下调用此方法,可能会导致修改无效。 Args: - new_prompt (str): 新的提示词内容 + new_prompt: 新的提示词内容。 + suppress_warning: 是否抑制潜在无效修改警告。 """ if self.llm_prompt is None and not suppress_warning: warnings.warn( @@ -431,15 +359,15 @@ class MaiMessages: self.llm_prompt = new_prompt self._modify_flags.modify_llm_prompt = True - def modify_plain_text(self, new_text: str, suppress_warning: bool = False): - """ - 修改生成的plain_text内容 + def modify_plain_text(self, new_text: str, suppress_warning: bool = False) -> None: + """修改生成的纯文本内容。 Warning: - 在未生成plain_text的情况下调用此方法,可能会导致plain_text为空或者修改无效 + 在未生成 ``plain_text`` 的情况下调用此方法,可能会导致修改无效。 Args: - new_text (str): 新的纯文本内容 + new_text: 新的纯文本内容。 + suppress_warning: 是否抑制潜在无效修改警告。 """ if not self.plain_text and not suppress_warning: warnings.warn( @@ -450,15 +378,15 @@ class MaiMessages: self.plain_text = new_text self._modify_flags.modify_plain_text = True - def modify_llm_response_content(self, new_content: str, suppress_warning: bool = False): - 
""" - 修改生成的llm_response_content内容 + def modify_llm_response_content(self, new_content: str, suppress_warning: bool = False) -> None: + """修改生成的 LLM 响应正文。 Warning: - 在未生成llm_response_content的情况下调用此方法,可能会导致llm_response_content为空或者修改无效 + 在未生成 ``llm_response_content`` 的情况下调用此方法,可能会导致修改无效。 Args: - new_content (str): 新的LLM响应内容 + new_content: 新的 LLM 响应内容。 + suppress_warning: 是否抑制潜在无效修改警告。 """ if not self.llm_response_content and not suppress_warning: warnings.warn( @@ -469,15 +397,15 @@ class MaiMessages: self.llm_response_content = new_content self._modify_flags.modify_llm_response_content = True - def modify_llm_response_reasoning(self, new_reasoning: str, suppress_warning: bool = False): - """ - 修改生成的llm_response_reasoning内容 + def modify_llm_response_reasoning(self, new_reasoning: str, suppress_warning: bool = False) -> None: + """修改生成的 LLM 推理内容。 Warning: - 在未生成llm_response_reasoning的情况下调用此方法,可能会导致llm_response_reasoning为空或者修改无效 + 在未生成 ``llm_response_reasoning`` 的情况下调用此方法,可能会导致修改无效。 Args: - new_reasoning (str): 新的LLM响应推理内容 + new_reasoning: 新的 LLM 推理内容。 + suppress_warning: 是否抑制潜在无效修改警告。 """ if not self.llm_response_reasoning and not suppress_warning: warnings.warn( @@ -487,10 +415,3 @@ class MaiMessages: ) self.llm_response_reasoning = new_reasoning self._modify_flags.modify_llm_response_reasoning = True - - -@dataclass -class CustomEventHandlerResult: - message: str = "" - timestamp: float = 0.0 - extra_info: Optional[Dict] = None diff --git a/src/learners/expression_auto_check_task.py b/src/learners/expression_auto_check_task.py index e5af1057..311d69e8 100644 --- a/src/learners/expression_auto_check_task.py +++ b/src/learners/expression_auto_check_task.py @@ -20,8 +20,8 @@ from src.common.database.database import get_db_session from src.common.database.database_model import Expression from src.common.logger import get_logger from src.config.config import global_config -from src.config.config import model_config -from src.llm_models.utils_model import LLMRequest +from src.common.data_models.llm_service_data_models import LLMGenerationOptions +from src.services.llm_service import LLMServiceClient from src.manager.async_task_manager import AsyncTask logger = get_logger("expression_auto_check_task") @@ -76,7 +76,7 @@ def create_evaluation_prompt(situation: str, style: str) -> str: return prompt -judge_llm = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="expression_check") +judge_llm = LLMServiceClient(task_name="tool_use", request_type="expression_check") async def single_expression_check(situation: str, style: str) -> tuple[bool, str, str | None]: @@ -94,10 +94,11 @@ async def single_expression_check(situation: str, style: str) -> tuple[bool, str prompt = create_evaluation_prompt(situation, style) logger.debug(f"正在评估表达方式: situation={situation}, style={style}") - response, (reasoning, model_name, _) = await judge_llm.generate_response_async( - prompt=prompt, temperature=0.6, max_tokens=1024 + generation_result = await judge_llm.generate_response( + prompt=prompt, + options=LLMGenerationOptions(temperature=0.6, max_tokens=1024), ) - + response = generation_result.response logger.debug(f"LLM响应: {response}") # 解析JSON响应 diff --git a/src/learners/expression_learner.py b/src/learners/expression_learner.py index b82ae1fa..34d2cb8b 100644 --- a/src/learners/expression_learner.py +++ b/src/learners/expression_learner.py @@ -7,8 +7,9 @@ import difflib import json import re -from src.llm_models.utils_model import LLMRequest -from src.config.config import model_config, 
global_config +from src.common.data_models.llm_service_data_models import LLMGenerationOptions +from src.services.llm_service import LLMServiceClient +from src.config.config import global_config from src.prompt.prompt_manager import prompt_manager from src.common.logger import get_logger from src.common.database.database_model import Expression @@ -26,10 +27,11 @@ if TYPE_CHECKING: logger = get_logger("expressor") -# TODO: 重构完LLM相关内容后,替换成新的模型调用方式 -express_learn_model = LLMRequest(model_set=model_config.model_task_config.utils, request_type="expression.learner") -summary_model = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="expression.summary") -check_model = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="expression.check") +express_learn_model = LLMServiceClient( + task_name="utils", request_type="expression.learner" +) +summary_model = LLMServiceClient(task_name="tool_use", request_type="expression.summary") +check_model = LLMServiceClient(task_name="tool_use", request_type="expression.check") class ExpressionLearner: @@ -74,7 +76,10 @@ class ExpressionLearner: # 调用 LLM 学习表达方式 try: - response, _ = await express_learn_model.generate_response_async(prompt, temperature=0.3) + generation_result = await express_learn_model.generate_response( + prompt, options=LLMGenerationOptions(temperature=0.3) + ) + response = generation_result.response except Exception as e: logger.error(f"学习表达方式失败,模型生成出错:{e}") return @@ -413,7 +418,10 @@ class ExpressionLearner: "只输出概括内容。" ) try: - summary, _ = await summary_model.generate_response_async(prompt, temperature=0.2) + summary_result = await summary_model.generate_response( + prompt, options=LLMGenerationOptions(temperature=0.2) + ) + summary = summary_result.response if summary := summary.strip(): return summary except Exception as e: diff --git a/src/learners/expression_selector.py b/src/learners/expression_selector.py index c96e84cf..7fc714ea 100644 --- a/src/learners/expression_selector.py +++ b/src/learners/expression_selector.py @@ -4,10 +4,11 @@ import time from typing import List, Dict, Optional, Any, Tuple from json_repair import repair_json -from src.llm_models.utils_model import LLMRequest -from src.config.config import global_config, model_config +from src.services.llm_service import LLMServiceClient +from src.config.config import global_config from src.common.logger import get_logger from src.common.database.database_model import Expression +from src.common.utils.utils_session import SessionUtils from src.prompt.prompt_manager import prompt_manager from src.learners.learner_utils_old import weighted_sample from src.chat.utils.common_utils import TempMethodsExpression @@ -17,8 +18,8 @@ logger = get_logger("expression_selector") class ExpressionSelector: def __init__(self): - self.llm_model = LLMRequest( - model_set=model_config.model_task_config.tool_use, request_type="expression.selector" + self.llm_model = LLMServiceClient( + task_name="tool_use", request_type="expression.selector" ) def can_use_expression_for_chat(self, chat_id: str) -> bool: @@ -383,8 +384,8 @@ class ExpressionSelector: prompt = await prompt_manager.render_prompt(prompt_template) # 4. 
调用LLM - content, (reasoning_content, model_name, _) = await self.llm_model.generate_response_async(prompt=prompt) - + generation_result = await self.llm_model.generate_response(prompt=prompt) + content = generation_result.response # print(prompt) # print(content) diff --git a/src/learners/expression_utils.py b/src/learners/expression_utils.py index 88237e57..573ce364 100644 --- a/src/learners/expression_utils.py +++ b/src/learners/expression_utils.py @@ -1,19 +1,40 @@ from json_repair import repair_json -from typing import Tuple, Optional, List +from typing import Any, List, Optional, Tuple import json import re -from src.config.config import model_config from src.config.config import global_config -from src.llm_models.utils_model import LLMRequest +from src.common.data_models.llm_service_data_models import LLMGenerationOptions +from src.services.llm_service import LLMServiceClient from src.prompt.prompt_manager import prompt_manager from src.common.logger import get_logger logger = get_logger("expression_utils") -# TODO: 重构完LLM相关内容后,替换成新的模型调用方式 -judge_llm = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="expression_check") +judge_llm = LLMServiceClient(task_name="tool_use", request_type="expression_check") + + +def _normalize_repair_json_result(repaired_result: Any) -> str: + """将 repair_json 的返回值规范化为 JSON 字符串。 + + Args: + repaired_result: `repair_json` 的返回值,可能是字符串或带附加信息的元组。 + + Returns: + str: 可供 `json.loads` 继续解析的 JSON 字符串。 + + Raises: + TypeError: 当返回值无法规范化为字符串时抛出。 + """ + if isinstance(repaired_result, str): + return repaired_result + if isinstance(repaired_result, tuple) and repaired_result: + first_item = repaired_result[0] + if isinstance(first_item, str): + return first_item + return json.dumps(first_item, ensure_ascii=False) + raise TypeError(f"repair_json 返回了无法处理的结果类型: {type(repaired_result)}") async def check_expression_suitability(situation: str, style: str) -> Tuple[bool, str, Optional[str]]: @@ -51,7 +72,11 @@ async def check_expression_suitability(situation: str, style: str) -> Tuple[bool logger.info(f"正在评估表达方式: situation={situation}, style={style}") - response, _ = await judge_llm.generate_response_async(prompt=prompt, temperature=0.6, max_tokens=1024) + generation_result = await judge_llm.generate_response( + prompt=prompt, + options=LLMGenerationOptions(temperature=0.6, max_tokens=1024), + ) + response = generation_result.response logger.debug(f"评估结果: {response}") @@ -59,7 +84,7 @@ async def check_expression_suitability(situation: str, style: str) -> Tuple[bool evaluation = json.loads(response) except json.JSONDecodeError: try: - response_repaired = repair_json(response) + response_repaired = _normalize_repair_json_result(repair_json(response)) evaluation = json.loads(response_repaired) except Exception as e: raise ValueError(f"无法解析LLM响应为JSON: {response}") from e @@ -74,7 +99,7 @@ async def check_expression_suitability(situation: str, style: str) -> Tuple[bool return False, f"评估结果格式错误: {e}", str(e) -def fix_chinese_quotes_in_json(text): +def fix_chinese_quotes_in_json(text: str) -> str: """使用状态机修复 JSON 字符串值中的中文引号""" result = [] i = 0 @@ -201,12 +226,12 @@ def is_single_char_jargon(content: str) -> bool: ) -def _try_parse(text): +def _try_parse(text: str) -> Any: try: return json.loads(text) except Exception: try: - repaired = repair_json(text) + repaired = _normalize_repair_json_result(repair_json(text)) return json.loads(repaired) except Exception: return None diff --git a/src/learners/jargon_explainer_old.py 
b/src/learners/jargon_explainer_old.py index 0cfafa82..fded9019 100644 --- a/src/learners/jargon_explainer_old.py +++ b/src/learners/jargon_explainer_old.py @@ -4,8 +4,9 @@ from typing import List, Dict, Optional, Any from src.common.logger import get_logger from src.common.database.database_model import Jargon -from src.llm_models.utils_model import LLMRequest -from src.config.config import model_config, global_config +from src.common.data_models.llm_service_data_models import LLMGenerationOptions +from src.services.llm_service import LLMServiceClient +from src.config.config import global_config from src.prompt.prompt_manager import prompt_manager from src.learners.jargon_miner_old import search_jargon from src.learners.learner_utils_old import ( @@ -23,8 +24,8 @@ class JargonExplainer: def __init__(self, chat_id: str) -> None: self.chat_id = chat_id - self.llm = LLMRequest( - model_set=model_config.model_task_config.tool_use, + self.llm = LLMServiceClient( + task_name="tool_use", request_type="jargon.explain", ) @@ -206,7 +207,10 @@ class JargonExplainer: prompt_of_summarize.add_context("jargon_explanations", lambda _: explanations_text) summarize_prompt = await prompt_manager.render_prompt(prompt_of_summarize) - summary, _ = await self.llm.generate_response_async(summarize_prompt, temperature=0.3) + summary_result = await self.llm.generate_response( + summarize_prompt, options=LLMGenerationOptions(temperature=0.3) + ) + summary = summary_result.response if not summary: # 如果LLM概括失败,直接返回原始解释 return f"上下文中的黑话解释:\n{explanations_text}" diff --git a/src/learners/jargon_miner.py b/src/learners/jargon_miner.py index 32926894..b0b4ae33 100644 --- a/src/learners/jargon_miner.py +++ b/src/learners/jargon_miner.py @@ -12,17 +12,17 @@ from src.common.data_models.jargon_data_model import MaiJargon from src.common.database.database import get_db_session from src.common.database.database_model import Jargon from src.common.logger import get_logger -from src.config.config import global_config, model_config -from src.llm_models.utils_model import LLMRequest +from src.config.config import global_config +from src.common.data_models.llm_service_data_models import LLMGenerationOptions +from src.services.llm_service import LLMServiceClient from src.prompt.prompt_manager import prompt_manager from .expression_utils import is_single_char_jargon logger = get_logger("jargon") -# TODO: 重构完LLM相关内容后,替换成新的模型调用方式 -llm_extract = LLMRequest(model_set=model_config.model_task_config.utils, request_type="jargon.extract") -llm_inference = LLMRequest(model_set=model_config.model_task_config.utils, request_type="jargon.inference") +llm_extract = LLMServiceClient(task_name="utils", request_type="jargon.extract") +llm_inference = LLMServiceClient(task_name="utils", request_type="jargon.inference") class JargonEntry(TypedDict): @@ -100,7 +100,10 @@ class JargonMiner: prompt1_template.add_context("previous_meaning_instruction", previous_meaning_instruction) prompt1 = await prompt_manager.render_prompt(prompt1_template) - llm_response_1, _ = await llm_inference.generate_response_async(prompt1, temperature=0.3) + generation_result_1 = await llm_inference.generate_response( + prompt1, options=LLMGenerationOptions(temperature=0.3) + ) + llm_response_1 = generation_result_1.response if not llm_response_1: logger.warning(f"jargon {content} 推断1失败:无响应") return @@ -129,7 +132,10 @@ class JargonMiner: prompt2_template.add_context("content", content) prompt2 = await prompt_manager.render_prompt(prompt2_template) - llm_response_2, _ = await 
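How `_normalize_repair_json_result` (added to expression_utils above) slots into a tolerant parse — a minimal sketch, assuming `repair_json` may return either a repaired string or a tuple carrying extra info depending on its options:

    import json
    from json_repair import repair_json
    from src.learners.expression_utils import _normalize_repair_json_result

    def parse_llm_json(raw: str):
        """Best-effort parse: strict JSON first, then repaired JSON."""
        try:
            return json.loads(raw)
        except json.JSONDecodeError:
            repaired = _normalize_repair_json_result(repair_json(raw))
            return json.loads(repaired)

    parse_llm_json('{"suitable": true,}')  # trailing comma repaired -> {'suitable': True}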
llm_inference.generate_response_async(prompt2, temperature=0.3) + generation_result_2 = await llm_inference.generate_response( + prompt2, options=LLMGenerationOptions(temperature=0.3) + ) + llm_response_2 = generation_result_2.response if not llm_response_2: logger.warning(f"jargon {content} 推断2失败:无响应") return @@ -153,7 +159,10 @@ class JargonMiner: if global_config.debug.show_jargon_prompt: logger.info(f"jargon {content} 比较提示词: {prompt3}") - llm_response_3, _ = await llm_inference.generate_response_async(prompt3, temperature=0.3) + generation_result_3 = await llm_inference.generate_response( + prompt3, options=LLMGenerationOptions(temperature=0.3) + ) + llm_response_3 = generation_result_3.response if not llm_response_3: logger.warning(f"jargon {content} 比较失败:无响应") return diff --git a/src/llm_models/model_client/adapter_base.py b/src/llm_models/model_client/adapter_base.py new file mode 100644 index 00000000..d631870c --- /dev/null +++ b/src/llm_models/model_client/adapter_base.py @@ -0,0 +1,259 @@ +from abc import ABC, abstractmethod +from typing import Any, Callable, Coroutine, Generic, Tuple, TypeVar, cast + +import asyncio + +from src.config.model_configs import ModelInfo + +from .base_client import ( + APIResponse, + AudioTranscriptionRequest, + BaseClient, + EmbeddingRequest, + ResponseRequest, + UsageRecord, + UsageTuple, +) + +RawStreamT = TypeVar("RawStreamT") +"""流式原始响应类型变量。""" + +RawResponseT = TypeVar("RawResponseT") +"""非流式原始响应类型变量。""" + +TaskResultT = TypeVar("TaskResultT") +"""异步任务返回值类型变量。""" + +ProviderStreamResponseHandler = Callable[ + [RawStreamT, asyncio.Event | None], + Coroutine[Any, Any, Tuple[APIResponse, UsageTuple | None]], +] +"""Provider 专用流式响应处理函数类型。""" + +ProviderResponseParser = Callable[[RawResponseT], Tuple[APIResponse, UsageTuple | None]] +"""Provider 专用非流式响应解析函数类型。""" + + +async def await_task_with_interrupt( + task: asyncio.Task[TaskResultT], + interrupt_flag: asyncio.Event | None, + *, + interval_seconds: float = 0.1, +) -> TaskResultT: + """在支持外部中断的前提下等待异步任务完成。 + + Args: + task: 待等待的异步任务。 + interrupt_flag: 外部中断标记。 + interval_seconds: 轮询检查间隔,单位秒。 + + Returns: + TaskResultT: 任务执行结果。 + + Raises: + ReqAbortException: 等待期间收到外部中断信号时抛出。 + """ + from src.llm_models.exceptions import ReqAbortException + + while not task.done(): + if interrupt_flag and interrupt_flag.is_set(): + task.cancel() + raise ReqAbortException("请求被外部信号中断") + await asyncio.sleep(interval_seconds) + return await task + + +class AdapterClient(BaseClient, ABC, Generic[RawStreamT, RawResponseT]): + """提供统一请求执行骨架的 Provider 适配基类。""" + + async def get_response(self, request: ResponseRequest) -> APIResponse: + """获取对话响应。 + + Args: + request: 统一响应请求对象。 + + Returns: + APIResponse: 解析完成的统一响应对象。 + """ + stream_response_handler = self._resolve_stream_response_handler(request) + response_parser = self._resolve_response_parser(request) + response, usage_record = await self._execute_response_request( + request, + stream_response_handler, + response_parser, + ) + return self._attach_usage_record(response, request.model_info, usage_record) + + async def get_embedding(self, request: EmbeddingRequest) -> APIResponse: + """获取文本嵌入。 + + Args: + request: 统一嵌入请求对象。 + + Returns: + APIResponse: 解析完成的统一嵌入响应。 + """ + response, usage_record = await self._execute_embedding_request(request) + return self._attach_usage_record(response, request.model_info, usage_record) + + async def get_audio_transcriptions(self, request: AudioTranscriptionRequest) -> APIResponse: + """获取音频转录。 + + Args: + request: 统一音频转录请求对象。 + + 
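A usage sketch for `await_task_with_interrupt` defined above (the wrapped coroutine is a stand-in for a real provider call):

    import asyncio
    from src.llm_models.model_client.adapter_base import await_task_with_interrupt

    async def call_provider() -> str:
        await asyncio.sleep(5)  # stand-in for a slow SDK call
        return "done"

    async def main() -> None:
        interrupt = asyncio.Event()
        task = asyncio.create_task(call_provider())
        # Polls every 0.1s; cancels the task and raises ReqAbortException if
        # `interrupt` is set while the call is in flight, else returns its result.
        print(await await_task_with_interrupt(task, interrupt))

    asyncio.run(main())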
Returns: + APIResponse: 解析完成的统一音频转录响应。 + """ + response, usage_record = await self._execute_audio_transcription_request(request) + return self._attach_usage_record(response, request.model_info, usage_record) + + def _resolve_stream_response_handler( + self, + request: ResponseRequest, + ) -> ProviderStreamResponseHandler[RawStreamT]: + """解析实际使用的流式响应处理器。 + + Args: + request: 统一响应请求对象。 + + Returns: + ProviderStreamResponseHandler[RawStreamT]: 流式响应处理器。 + """ + if request.stream_response_handler is not None: + return cast(ProviderStreamResponseHandler[RawStreamT], request.stream_response_handler) + return self._build_default_stream_response_handler(request) + + def _resolve_response_parser( + self, + request: ResponseRequest, + ) -> ProviderResponseParser[RawResponseT]: + """解析实际使用的非流式响应解析器。 + + Args: + request: 统一响应请求对象。 + + Returns: + ProviderResponseParser[RawResponseT]: 非流式响应解析器。 + """ + if request.async_response_parser is not None: + return cast(ProviderResponseParser[RawResponseT], request.async_response_parser) + return self._build_default_response_parser(request) + + @staticmethod + def _build_usage_record(model_info: ModelInfo, usage_record: UsageTuple) -> UsageRecord: + """根据统一使用量三元组构建 `UsageRecord`。 + + Args: + model_info: 模型信息。 + usage_record: 使用量三元组。 + + Returns: + UsageRecord: 可直接挂载到 `APIResponse` 的使用记录对象。 + """ + return UsageRecord( + model_name=model_info.name, + provider_name=model_info.api_provider, + prompt_tokens=usage_record[0], + completion_tokens=usage_record[1], + total_tokens=usage_record[2], + ) + + def _attach_usage_record( + self, + response: APIResponse, + model_info: ModelInfo, + usage_record: UsageTuple | None, + ) -> APIResponse: + """在响应对象上附加统一使用量信息。 + + Args: + response: 已解析的统一响应对象。 + model_info: 模型信息。 + usage_record: 可选的使用量三元组。 + + Returns: + APIResponse: 附加使用量后的响应对象。 + """ + if usage_record is not None: + response.usage = self._build_usage_record(model_info, usage_record) + return response + + @abstractmethod + def _build_default_stream_response_handler( + self, + request: ResponseRequest, + ) -> ProviderStreamResponseHandler[RawStreamT]: + """构建默认流式响应处理器。 + + Args: + request: 统一响应请求对象。 + + Returns: + ProviderStreamResponseHandler[RawStreamT]: 默认流式处理器。 + """ + raise NotImplementedError + + @abstractmethod + def _build_default_response_parser( + self, + request: ResponseRequest, + ) -> ProviderResponseParser[RawResponseT]: + """构建默认非流式响应解析器。 + + Args: + request: 统一响应请求对象。 + + Returns: + ProviderResponseParser[RawResponseT]: 默认非流式解析器。 + """ + raise NotImplementedError + + @abstractmethod + async def _execute_response_request( + self, + request: ResponseRequest, + stream_response_handler: ProviderStreamResponseHandler[RawStreamT], + response_parser: ProviderResponseParser[RawResponseT], + ) -> Tuple[APIResponse, UsageTuple | None]: + """执行 Provider 的文本/多模态响应请求。 + + Args: + request: 统一响应请求对象。 + stream_response_handler: 流式响应处理器。 + response_parser: 非流式响应解析器。 + + Returns: + Tuple[APIResponse, UsageTuple | None]: 统一响应对象与可选使用量信息。 + """ + raise NotImplementedError + + @abstractmethod + async def _execute_embedding_request( + self, + request: EmbeddingRequest, + ) -> Tuple[APIResponse, UsageTuple | None]: + """执行 Provider 的嵌入请求。 + + Args: + request: 统一嵌入请求对象。 + + Returns: + Tuple[APIResponse, UsageTuple | None]: 统一响应对象与可选使用量信息。 + """ + raise NotImplementedError + + @abstractmethod + async def _execute_audio_transcription_request( + self, + request: AudioTranscriptionRequest, + ) -> Tuple[APIResponse, UsageTuple | None]: + """执行 Provider 的音频转录请求。 + + Args: + request: 
统一音频转录请求对象。 + + Returns: + Tuple[APIResponse, UsageTuple | None]: 统一响应对象与可选使用量信息。 + """ + raise NotImplementedError diff --git a/src/llm_models/model_client/base_client.py b/src/llm_models/model_client/base_client.py index 226c725f..fc03ac02 100644 --- a/src/llm_models/model_client/base_client.py +++ b/src/llm_models/model_client/base_client.py @@ -1,14 +1,15 @@ -import asyncio -from dataclasses import dataclass from abc import ABC, abstractmethod -from typing import Callable, Any, Optional +from dataclasses import dataclass, field +from typing import Any, Callable, Coroutine, Dict, List, Tuple, Type + +import asyncio from src.common.logger import get_logger from src.config.config import config_manager -from src.config.model_configs import ModelInfo, APIProvider -from ..payload_content.message import Message -from ..payload_content.resp_format import RespFormat -from ..payload_content.tool_option import ToolOption, ToolCall +from src.config.model_configs import APIProvider, ModelInfo +from src.llm_models.payload_content.message import Message +from src.llm_models.payload_content.resp_format import RespFormat +from src.llm_models.payload_content.tool_option import ToolCall, ToolOption logger = get_logger("model_client_registry") @@ -47,10 +48,10 @@ class APIResponse: reasoning_content: str | None = None """推理内容""" - tool_calls: list[ToolCall] | None = None + tool_calls: List[ToolCall] | None = None """工具调用 [(工具名称, 工具参数), ...]""" - embedding: list[float] | None = None + embedding: List[float] | None = None """嵌入向量""" usage: UsageRecord | None = None @@ -60,6 +61,82 @@ class APIResponse: """响应原始数据""" +UsageTuple = Tuple[int, int, int] +"""统一的使用量三元组类型,顺序为 `(prompt_tokens, completion_tokens, total_tokens)`。""" + +StreamResponseHandler = Callable[ + [Any, asyncio.Event | None], + Coroutine[Any, Any, Tuple["APIResponse", UsageTuple | None]], +] +"""统一的流式响应处理函数类型。""" + +ResponseParser = Callable[[Any], Tuple["APIResponse", UsageTuple | None]] +"""统一的非流式响应解析函数类型。""" + + +@dataclass(slots=True) +class ResponseRequest: + """统一的文本/多模态响应请求。""" + + model_info: ModelInfo + message_list: List[Message] + tool_options: List[ToolOption] | None = None + max_tokens: int | None = None + temperature: float | None = None + response_format: RespFormat | None = None + stream_response_handler: StreamResponseHandler | None = None + async_response_parser: ResponseParser | None = None + interrupt_flag: asyncio.Event | None = None + extra_params: Dict[str, Any] = field(default_factory=dict) + + def copy_with(self, **changes: Any) -> "ResponseRequest": + """基于当前请求创建一个带局部变更的新请求。 + + Args: + **changes: 需要覆盖的字段值。 + + Returns: + ResponseRequest: 复制后的请求对象。 + """ + payload = { + "model_info": self.model_info, + "message_list": list(self.message_list), + "tool_options": None if self.tool_options is None else list(self.tool_options), + "max_tokens": self.max_tokens, + "temperature": self.temperature, + "response_format": self.response_format, + "stream_response_handler": self.stream_response_handler, + "async_response_parser": self.async_response_parser, + "interrupt_flag": self.interrupt_flag, + "extra_params": dict(self.extra_params), + } + payload.update(changes) + return ResponseRequest(**payload) + + +@dataclass(slots=True) +class EmbeddingRequest: + """统一的嵌入请求。""" + + model_info: ModelInfo + embedding_input: str + extra_params: Dict[str, Any] = field(default_factory=dict) + + +@dataclass(slots=True) +class AudioTranscriptionRequest: + """统一的音频转录请求。""" + + model_info: ModelInfo + audio_base64: str + max_tokens: int | None = 
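`ResponseRequest.copy_with` above supports per-retry tweaks without mutating the original request — a sketch in which `model_info` and `user_message` are placeholders for existing objects:

    base_request = ResponseRequest(
        model_info=model_info,        # placeholder: an existing ModelInfo
        message_list=[user_message],  # placeholder: an existing Message
        temperature=0.7,
    )
    retry_request = base_request.copy_with(temperature=0.2, max_tokens=512)

    assert base_request.temperature == 0.7  # original left untouched
    assert retry_request.message_list is not base_request.message_list  # list re-copied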
None + extra_params: Dict[str, Any] = field(default_factory=dict) + + +ClientRequest = ResponseRequest | EmbeddingRequest | AudioTranscriptionRequest +"""统一客户端请求类型。""" + + class BaseClient(ABC): """ 基础客户端 @@ -67,97 +144,82 @@ class BaseClient(ABC): api_provider: APIProvider - def __init__(self, api_provider: APIProvider): + def __init__(self, api_provider: APIProvider) -> None: + """初始化基础客户端。 + + Args: + api_provider: API 提供商配置。 + """ self.api_provider = api_provider @abstractmethod - async def get_response( - self, - model_info: ModelInfo, - message_list: list[Message], - tool_options: list[ToolOption] | None = None, - max_tokens: Optional[int] = None, - temperature: Optional[float] = None, - response_format: RespFormat | None = None, - stream_response_handler: Optional[ - Callable[[Any, asyncio.Event | None], tuple[APIResponse, tuple[int, int, int]]] - ] = None, - async_response_parser: Callable[[Any], tuple[APIResponse, tuple[int, int, int]]] | None = None, - interrupt_flag: asyncio.Event | None = None, - extra_params: dict[str, Any] | None = None, - ) -> APIResponse: - """ - 获取对话响应 - :param model_info: 模型信息 - :param message_list: 对话体 - :param tool_options: 工具选项(可选,默认为None) - :param max_tokens: 最大token数(可选,默认为1024) - :param temperature: 温度(可选,默认为0.7) - :param response_format: 响应格式(可选,默认为 NotGiven ) - :param stream_response_handler: 流式响应处理函数(可选) - :param async_response_parser: 响应解析函数(可选) - :param interrupt_flag: 中断信号量(可选,默认为None) - :return: (响应文本, 推理文本, 工具调用, 其他数据) + async def get_response(self, request: ResponseRequest) -> APIResponse: + """获取对话响应。 + + Args: + request: 统一响应请求对象。 + + Returns: + APIResponse: 统一响应对象。 """ raise NotImplementedError("'get_response' method should be overridden in subclasses") @abstractmethod - async def get_embedding( - self, - model_info: ModelInfo, - embedding_input: str, - extra_params: dict[str, Any] | None = None, - ) -> APIResponse: - """ - 获取文本嵌入 - :param model_info: 模型信息 - :param embedding_input: 嵌入输入文本 - :return: 嵌入响应 + async def get_embedding(self, request: EmbeddingRequest) -> APIResponse: + """获取文本嵌入。 + + Args: + request: 统一嵌入请求对象。 + + Returns: + APIResponse: 嵌入响应。 """ raise NotImplementedError("'get_embedding' method should be overridden in subclasses") @abstractmethod - async def get_audio_transcriptions( - self, - model_info: ModelInfo, - audio_base64: str, - max_tokens: Optional[int] = None, - extra_params: dict[str, Any] | None = None, - ) -> APIResponse: - """ - 获取音频转录 - :param model_info: 模型信息 - :param audio_base64: base64编码的音频数据 - :extra_params: 附加的请求参数 - :return: 音频转录响应 + async def get_audio_transcriptions(self, request: AudioTranscriptionRequest) -> APIResponse: + """获取音频转录。 + + Args: + request: 统一音频转录请求对象。 + + Returns: + APIResponse: 音频转录响应。 """ raise NotImplementedError("'get_audio_transcriptions' method should be overridden in subclasses") @abstractmethod - def get_support_image_formats(self) -> list[str]: - """ - 获取支持的图片格式 - :return: 支持的图片格式列表 + def get_support_image_formats(self) -> List[str]: + """获取支持的图片格式。 + + Returns: + List[str]: 支持的图片格式列表。 """ raise NotImplementedError("'get_support_image_formats' method should be overridden in subclasses") class ClientRegistry: + """客户端注册表。""" + def __init__(self) -> None: - self.client_registry: dict[str, type[BaseClient]] = {} + """初始化注册表并绑定配置重载回调。""" + self.client_registry: Dict[str, Type[BaseClient]] = {} """APIProvider.type -> BaseClient的映射表""" - self.client_instance_cache: dict[str, BaseClient] = {} + self.client_instance_cache: Dict[str, BaseClient] = {} """APIProvider.name -> 
BaseClient的映射表""" config_manager.register_reload_callback(self.clear_client_instance_cache) - def register_client_class(self, client_type: str): - """ - 注册API客户端类 + def register_client_class(self, client_type: str) -> Callable[[Type[BaseClient]], Type[BaseClient]]: + """注册 API 客户端类。 + Args: - client_class: API客户端类 + client_type: 客户端类型标识。 + + Returns: + Callable[[Type[BaseClient]], Type[BaseClient]]: 装饰器函数。 """ - def decorator(cls: type[BaseClient]) -> type[BaseClient]: + def decorator(cls: Type[BaseClient]) -> Type[BaseClient]: if not issubclass(cls, BaseClient): raise TypeError(f"{cls.__name__} is not a subclass of BaseClient") self.client_registry[client_type] = cls @@ -165,14 +227,15 @@ class ClientRegistry: return decorator - def get_client_class_instance(self, api_provider: APIProvider, force_new=False) -> BaseClient: - """ - 获取注册的API客户端实例 + def get_client_class_instance(self, api_provider: APIProvider, force_new: bool = False) -> BaseClient: + """获取注册的 API 客户端实例。 + Args: - api_provider: APIProvider实例 - force_new: 是否强制创建新实例(用于解决事件循环问题) + api_provider: APIProvider 实例。 + force_new: 是否强制创建新实例。 + Returns: - BaseClient: 注册的API客户端实例 + BaseClient: 注册的 API 客户端实例。 """ from . import ensure_client_type_loaded @@ -194,6 +257,7 @@ class ClientRegistry: return self.client_instance_cache[api_provider.name] def clear_client_instance_cache(self) -> None: + """清空客户端实例缓存。""" self.client_instance_cache.clear() logger.info("检测到配置重载,已清空LLM客户端实例缓存") diff --git a/src/llm_models/model_client/gemini_client.py b/src/llm_models/model_client/gemini_client.py index f63707d9..17cedb45 100644 --- a/src/llm_models/model_client/gemini_client.py +++ b/src/llm_models/model_client/gemini_client.py @@ -1,771 +1,973 @@ +from typing import Any, AsyncIterator, Callable, Coroutine, Dict, List, Optional, Tuple, cast + import asyncio -import io import base64 -from typing import Callable, AsyncIterator, Optional, Coroutine, Any, List, Dict +import io +import json from google import genai -from google.genai.types import ( - Content, - Part, - FunctionDeclaration, - GenerateContentResponse, - ContentListUnion, - ContentUnion, - ThinkingConfig, - Tool, - GoogleSearch, - GenerateContentConfig, - EmbedContentResponse, - EmbedContentConfig, - SafetySetting, - HttpOptions, - HarmCategory, - HarmBlockThreshold, -) from google.genai.errors import ( ClientError, + FunctionInvocationError, ServerError, UnknownFunctionCallArgumentError, UnsupportedFunctionError, - FunctionInvocationError, +) +from google.genai.types import ( + Candidate, + Content, + ContentListUnion, + ContentUnion, + EmbedContentConfig, + EmbedContentResponse, + FunctionDeclaration, + GenerateContentConfig, + GenerateContentResponse, + GoogleSearch, + HarmBlockThreshold, + HarmCategory, + HttpOptions, + Part, + SafetySetting, + ThinkingConfig, + Tool, ) -from src.config.model_configs import ModelInfo, APIProvider from src.common.logger import get_logger - -from .base_client import APIResponse, UsageRecord, BaseClient, client_registry -from ..exceptions import ( - RespParseException, - NetworkConnectionError, - RespNotOkException, - ReqAbortException, +from src.config.model_configs import APIProvider +from src.llm_models.exceptions import ( EmptyResponseException, + NetworkConnectionError, + ReqAbortException, + RespNotOkException, + RespParseException, +) +from src.llm_models.payload_content.message import ImageMessagePart, Message, RoleType, TextMessagePart +from src.llm_models.payload_content.resp_format import RespFormat, RespFormatType +from 
src.llm_models.payload_content.tool_option import ToolCall, ToolOption + +from .adapter_base import ( + AdapterClient, + ProviderResponseParser, + ProviderStreamResponseHandler, + await_task_with_interrupt, +) +from .base_client import ( + APIResponse, + AudioTranscriptionRequest, + EmbeddingRequest, + ResponseRequest, + UsageTuple, + client_registry, ) -from ..payload_content.message import Message, RoleType -from ..payload_content.resp_format import RespFormat, RespFormatType -from ..payload_content.tool_option import ToolOption, ToolParam, ToolCall logger = get_logger("Gemini客户端") -# gemini_thinking参数(默认范围) -# 不同模型的思考预算范围配置 -THINKING_BUDGET_LIMITS = { +GeminiStreamResponseHandler = Callable[ + [AsyncIterator[GenerateContentResponse], asyncio.Event | None], + Coroutine[Any, Any, Tuple[APIResponse, Optional[UsageTuple]]], +] +"""Gemini 流式响应处理函数类型。""" + +GeminiResponseParser = Callable[[GenerateContentResponse], Tuple[APIResponse, Optional[UsageTuple]]] +"""Gemini 非流式响应解析函数类型。""" + +THINKING_BUDGET_LIMITS: Dict[str, Dict[str, int | bool]] = { "gemini-2.5-flash": {"min": 1, "max": 24576, "can_disable": True}, "gemini-2.5-flash-lite": {"min": 512, "max": 24576, "can_disable": True}, "gemini-2.5-pro": {"min": 128, "max": 32768, "can_disable": False}, } -# 思维预算特殊值 -THINKING_BUDGET_AUTO = -1 # 自动调整思考预算,由模型决定 -THINKING_BUDGET_DISABLED = 0 # 禁用思考预算(如果模型允许禁用) +"""不同 Gemini 模型允许的思考预算范围。""" -gemini_safe_settings = [ +THINKING_BUDGET_AUTO = -1 +"""自动思考预算模式,由模型自行决定。""" + +THINKING_BUDGET_DISABLED = 0 +"""禁用思考预算模式。仅部分模型支持。""" + +GEMINI_SAFE_SETTINGS: List[SafetySetting] = [ SafetySetting(category=HarmCategory.HARM_CATEGORY_HATE_SPEECH, threshold=HarmBlockThreshold.BLOCK_NONE), SafetySetting(category=HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, threshold=HarmBlockThreshold.BLOCK_NONE), SafetySetting(category=HarmCategory.HARM_CATEGORY_HARASSMENT, threshold=HarmBlockThreshold.BLOCK_NONE), SafetySetting(category=HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, threshold=HarmBlockThreshold.BLOCK_NONE), SafetySetting(category=HarmCategory.HARM_CATEGORY_CIVIC_INTEGRITY, threshold=HarmBlockThreshold.BLOCK_NONE), ] +"""默认安全策略,避免 Gemini 在部分内容上返回空响应。""" + +GENERATE_CONFIG_RESERVED_EXTRA_PARAMS = { + "thinking_budget", + "include_thoughts", + "enable_google_search", + "transcription_prompt", + "audio_mime_type", +} +"""由当前客户端自行处理、不再直接透传给 `GenerateContentConfig` 的额外参数。""" + +EMBED_CONFIG_SUPPORTED_EXTRA_PARAMS = { + "task_type", + "title", + "output_dimensionality", + "mime_type", + "auto_truncate", +} +"""可透传给 `EmbedContentConfig` 的额外参数字段。""" -def _convert_messages( - messages: list[Message], -) -> tuple[ContentListUnion, list[str] | None]: +def _normalize_image_mime_type(image_format: str) -> str: + """将图片格式名称转换为标准 MIME 类型。 + + Args: + image_format: 图片格式名,例如 `png`、`jpg`。 + + Returns: + str: 规范化后的图片 MIME 类型。 """ - 转换消息格式 - 将消息转换为Gemini API所需的格式 - :param messages: 消息列表 - :return: 转换后的消息列表(和可能存在的system消息) + normalized_image_format = image_format.lower() + if normalized_image_format in {"jpg", "jpeg"}: + return "image/jpeg" + return f"image/{normalized_image_format}" + + +def _build_non_tool_parts(message: Message) -> List[Part]: + """将消息中的文本与图片片段转换为 Gemini `Part` 列表。 + + Args: + message: 内部统一消息对象。 + + Returns: + List[Part]: Gemini 所需的内容片段列表。 """ + converted_parts: List[Part] = [] + for message_part in message.parts: + if isinstance(message_part, TextMessagePart): + converted_parts.append(Part.from_text(text=message_part.text)) + continue + if isinstance(message_part, ImageMessagePart): + converted_parts.append( + 
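Expected mappings for `_normalize_image_mime_type` above, read directly off its definition:

    from src.llm_models.model_client.gemini_client import _normalize_image_mime_type

    assert _normalize_image_mime_type("png") == "image/png"
    assert _normalize_image_mime_type("JPG") == "image/jpeg"   # jpg/jpeg collapse to image/jpeg
    assert _normalize_image_mime_type("jpeg") == "image/jpeg"
    assert _normalize_image_mime_type("WEBP") == "image/webp"  # others pass through lowercased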
Part.from_bytes( + data=base64.b64decode(message_part.image_base64), + mime_type=_normalize_image_mime_type(message_part.normalized_image_format), + ) + ) + return converted_parts - def _convert_message_item(message: Message) -> Content: - """ - 转换单个消息格式,除了system和tool类型的消息 - :param message: 消息对象 - :return: 转换后的消息字典 - """ - # 将openai格式的角色重命名为gemini格式的角色 - if message.role == RoleType.Assistant: - role = "model" - elif message.role == RoleType.User: - role = "user" - else: - raise ValueError(f"Unsupported role: {message.role}") +def _normalize_function_response_payload(message: Message) -> Dict[str, Any]: + """将内部工具结果消息转换为 Gemini 函数响应负载。 - # 添加Content - if isinstance(message.content, str): - content = [Part.from_text(text=message.content)] - elif isinstance(message.content, list): - content: List[Part] = [] - for item in message.content: - if isinstance(item, tuple): - image_format = item[0].lower() - # 规范 JPEG MIME 类型后缀,统一使用 image/jpeg - if image_format in ("jpg", "jpeg"): - image_format = "jpeg" - content.append(Part.from_bytes(data=base64.b64decode(item[1]), mime_type=f"image/{image_format}")) - elif isinstance(item, str): - content.append(Part.from_text(text=item)) - else: - raise RuntimeError("无法触及的代码:请使用MessageBuilder类构建消息对象") + Args: + message: 工具结果消息。 - return Content(role=role, parts=content) + Returns: + Dict[str, Any]: 可用于 `Part.from_function_response()` 的响应对象。 + """ + content = message.content + if isinstance(content, str): + stripped_content = content.strip() + if not stripped_content: + return {} + try: + parsed_content = json.loads(stripped_content) + except json.JSONDecodeError: + return {"result": content} + if isinstance(parsed_content, dict): + return parsed_content + return {"result": parsed_content} + + return {"result": content} + + +def _get_candidates(response: GenerateContentResponse) -> List[Candidate]: + """安全获取 Gemini 响应中的候选列表。 + + Args: + response: Gemini 响应对象。 + + Returns: + List[Candidate]: 非空时返回原候选列表,否则返回空列表。 + """ + return response.candidates or [] + + +def _extract_response_json_schema(response_format: RespFormat) -> Dict[str, object] | None: + """从内部响应格式中提取可供 Gemini 使用的 JSON Schema。 + + Args: + response_format: 输出格式定义。 + + Returns: + Dict[str, object] | None: 可直接传给 `response_json_schema` 的 JSON Schema。 + """ + schema_payload = response_format.get_schema_object() + if schema_payload is None: + return None + return cast(Dict[str, object], schema_payload) + + +def _convert_messages(messages: List[Message]) -> Tuple[ContentListUnion, str | None]: + """将内部统一消息列表转换为 Gemini 内容结构。 + + Args: + messages: 内部统一消息列表。 + + Returns: + Tuple[ContentListUnion, str | None]: `contents` 与可选的 `system_instruction`。 + + Raises: + ValueError: 当消息结构无法映射到 Gemini 内容模型时抛出。 + """ + contents: List[ContentUnion] = [] + system_instruction_chunks: List[str] = [] + tool_name_by_call_id: Dict[str, str] = {} - temp_list: list[ContentUnion] = [] - system_instructions: list[str] = [] for message in messages: if message.role == RoleType.System: - if isinstance(message.content, str): - system_instructions.append(message.content) - else: - raise ValueError("你tm怎么往system里面塞图片base64?") - elif message.role == RoleType.Tool: + system_text = message.get_text_content().strip() + if not system_text: + raise ValueError("Gemini 的 system message 必须为非空文本") + system_instruction_chunks.append(system_text) + continue + + if message.role == RoleType.User: + contents.append(Content(role="user", parts=_build_non_tool_parts(message))) + continue + + if message.role == RoleType.Assistant: + assistant_parts = 
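Case-by-case behaviour of `_normalize_function_response_payload` above, checked with a hypothetical stand-in that exposes only `.content`:

    from dataclasses import dataclass
    from typing import Any

    from src.llm_models.model_client.gemini_client import _normalize_function_response_payload

    @dataclass
    class _StubMessage:  # hypothetical: mimics Message.content only
        content: Any

    assert _normalize_function_response_payload(_StubMessage('{"ok": true}')) == {"ok": True}
    assert _normalize_function_response_payload(_StubMessage("[1, 2]")) == {"result": [1, 2]}
    assert _normalize_function_response_payload(_StubMessage("plain text")) == {"result": "plain text"}
    assert _normalize_function_response_payload(_StubMessage("   ")) == {}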
_build_non_tool_parts(message) + if message.tool_calls: + for tool_call in message.tool_calls: + assistant_parts.append( + Part.from_function_call( + name=tool_call.func_name, + args=tool_call.args or {}, + ) + ) + tool_name_by_call_id[tool_call.call_id] = tool_call.func_name + contents.append(Content(role="model", parts=assistant_parts)) + continue + + if message.role == RoleType.Tool: if not message.tool_call_id: - raise ValueError("无法触及的代码:请使用MessageBuilder类构建消息对象") - else: - temp_list.append(_convert_message_item(message)) - if system_instructions: - # 如果有system消息,就把它加上去 - ret: tuple = (temp_list, system_instructions) - else: - # 如果没有system消息,就直接返回 - ret: tuple = (temp_list, None) + raise ValueError("Gemini 工具结果消息缺少 tool_call_id") + tool_name = tool_name_by_call_id.get(message.tool_call_id) + if not tool_name: + raise ValueError(f"Gemini 无法根据 tool_call_id={message.tool_call_id} 找到对应的工具名称") + function_response_part = Part.from_function_response( + name=tool_name, + response=_normalize_function_response_payload(message), + ) + contents.append(Content(role="tool", parts=[function_response_part])) + continue - return ret + raise ValueError(f"不支持的消息角色: {message.role}") + + system_instruction = "\n\n".join(chunk for chunk in system_instruction_chunks if chunk.strip()) or None + return contents, system_instruction -def _convert_tool_options(tool_options: list[ToolOption]) -> list[FunctionDeclaration]: +def _build_tools(tool_options: List[ToolOption]) -> List[Tool]: + """将内部工具定义转换为 Gemini `Tool` 列表。 + + Args: + tool_options: 内部统一工具定义列表。 + + Returns: + List[Tool]: Gemini 所需工具列表。 """ - 转换工具选项格式 - 将工具选项转换为Gemini API所需的格式 - :param tool_options: 工具选项列表 - :return: 转换后的工具对象列表 - """ - - def _convert_tool_param(tool_option_param: ToolParam) -> dict: - """ - 转换单个工具参数格式 - :param tool_option_param: 工具参数对象 - :return: 转换后的工具参数字典 - """ - # JSON Schema 类型名称修正: - # - 布尔类型使用 "boolean" 而不是 "bool" - # - 浮点数使用 "number" 而不是 "float" - param_type_value = tool_option_param.param_type.value - if param_type_value == "bool": - param_type_value = "boolean" - elif param_type_value == "float": - param_type_value = "number" - - return_dict: dict[str, Any] = { - "type": param_type_value, - "description": tool_option_param.description, - } - if tool_option_param.enum_values: - return_dict["enum"] = tool_option_param.enum_values - return return_dict - - def _convert_tool_option_item(tool_option: ToolOption) -> FunctionDeclaration: - """ - 转换单个工具项格式 - :param tool_option: 工具选项对象 - :return: 转换后的Gemini工具选项对象 - """ - ret: dict[str, Any] = { + function_declarations: List[FunctionDeclaration] = [] + for tool_option in tool_options: + payload: Dict[str, Any] = { "name": tool_option.name, "description": tool_option.description, } - if tool_option.params: - ret["parameters"] = { - "type": "object", - "properties": {param.name: _convert_tool_param(param) for param in tool_option.params}, - "required": [param.name for param in tool_option.params if param.required], - } - ret1 = FunctionDeclaration(**ret) - return ret1 - - return [_convert_tool_option_item(tool_option) for tool_option in tool_options] + if tool_option.parameters_schema is not None: + payload["parameters_json_schema"] = tool_option.parameters_schema + function_declarations.append(FunctionDeclaration(**payload)) + return [Tool(function_declarations=function_declarations)] if function_declarations else [] -def _process_delta( - delta: GenerateContentResponse, - fc_delta_buffer: io.StringIO, - tool_calls_buffer: list[tuple[str, str, dict[str, Any]]], - resp: APIResponse | None 
= None, -): - if not hasattr(delta, "candidates") or not delta.candidates: - raise RespParseException(delta, "响应解析失败,缺失candidates字段") +def _extract_usage_record(response: GenerateContentResponse) -> Optional[UsageTuple]: + """从 Gemini 响应中提取使用量信息。 - # 处理 thought(Gemini 的特殊字段) - for c in getattr(delta, "candidates", []): - if c.content and getattr(c.content, "parts", None): - for p in c.content.parts: - if getattr(p, "thought", False) and getattr(p, "text", None): - # 保存到 reasoning_content - if resp is not None: - resp.reasoning_content = (resp.reasoning_content or "") + p.text - elif getattr(p, "text", None): - # 正常输出写入 buffer - fc_delta_buffer.write(p.text) + Args: + response: Gemini 响应对象。 - if delta.function_calls: # 为什么不用hasattr呢,是因为这个属性一定有,即使是个空的 - for call in delta.function_calls: - try: - if not isinstance(call.args, dict): # gemini返回的function call参数就是dict格式的了 - raise RespParseException(delta, "响应解析失败,工具调用参数无法解析为字典类型") - if not call.id or not call.name: - raise RespParseException(delta, "响应解析失败,工具调用缺失id或name字段") - tool_calls_buffer.append( - ( - call.id, - call.name, - call.args or {}, # 如果args是None,则转换为一个空字典 - ) - ) - except Exception as e: - raise RespParseException(delta, "响应解析失败,无法解析工具调用参数") from e + Returns: + Optional[UsageTuple]: 统一的使用量三元组;缺失时返回 `None`。 + """ + usage_metadata = getattr(response, "usage_metadata", None) + if usage_metadata is None: + return None + prompt_tokens = getattr(usage_metadata, "prompt_token_count", 0) or 0 + completion_tokens = ( + (getattr(usage_metadata, "candidates_token_count", 0) or 0) + + (getattr(usage_metadata, "thoughts_token_count", 0) or 0) + ) + total_tokens = getattr(usage_metadata, "total_token_count", 0) or 0 + return prompt_tokens, completion_tokens, total_tokens -def _build_stream_api_resp( - _fc_delta_buffer: io.StringIO, - _tool_calls_buffer: list[tuple[str, str, dict]], - last_resp: GenerateContentResponse | None = None, # 传入 last_resp - resp: APIResponse | None = None, -) -> APIResponse: - # sourcery skip: simplify-len-comparison, use-assigned-variable - if resp is None: - resp = APIResponse() +def _extract_finish_reason(response: GenerateContentResponse | None) -> str | None: + """提取 Gemini 响应的结束原因。 - if _fc_delta_buffer.tell() > 0: - # 如果正式内容缓冲区不为空,则将其写入APIResponse对象 - resp.content = _fc_delta_buffer.getvalue() - _fc_delta_buffer.close() - if len(_tool_calls_buffer) > 0: - # 如果工具调用缓冲区不为空,则将其解析为ToolCall对象列表 - resp.tool_calls = [] - for call_id, function_name, arguments_buffer in _tool_calls_buffer: - if arguments_buffer is not None: - arguments = arguments_buffer - if not isinstance(arguments, dict): - raise RespParseException( - None, - f"响应解析失败,工具调用参数无法解析为字典类型。工具调用参数原始响应:\n{arguments_buffer}", - ) + Args: + response: Gemini 响应对象。 + + Returns: + str | None: 结束原因字符串;获取失败时返回 `None`。 + """ + if response is None: + return None + candidates = _get_candidates(response) + if not candidates: + return None + for candidate in candidates: + finish_reason = getattr(candidate, "finish_reason", None) or getattr(candidate, "finishReason", None) + if finish_reason: + return str(finish_reason) + return None + + +def _warn_if_max_tokens_truncated( + response: GenerateContentResponse | None, + content: str | None, + tool_calls: List[ToolCall] | None, +) -> None: + """在 Gemini 因 token 限制截断时输出警告。 + + Args: + response: Gemini 响应对象。 + content: 已解析的可见文本内容。 + tool_calls: 已解析的工具调用列表。 + """ + finish_reason = _extract_finish_reason(response) + if finish_reason is None or "MAX_TOKENS" not in finish_reason: + return + has_visible_output = bool((content and 
content.strip()) or tool_calls) + if has_visible_output: + logger.warning( + "Gemini 响应因达到 max_tokens 限制被部分截断,可能影响回复完整性,建议调整模型 max_tokens 配置。" + ) + return + logger.warning("Gemini 响应因达到 max_tokens 限制被截断,且未返回可见输出,请检查模型 max_tokens 配置。") + + +def _collect_function_calls(response: GenerateContentResponse) -> List[ToolCall]: + """从 Gemini 响应中提取工具调用列表。 + + Args: + response: Gemini 响应对象。 + + Returns: + List[ToolCall]: 规范化后的工具调用列表。 + + Raises: + RespParseException: 当函数调用结构不合法时抛出。 + """ + raw_function_calls = getattr(response, "function_calls", None) + candidates = _get_candidates(response) + if not raw_function_calls and candidates: + raw_function_calls = [] + for candidate in candidates: + content = getattr(candidate, "content", None) + parts = getattr(content, "parts", None) or [] + for part in parts: + function_call = getattr(part, "function_call", None) + if function_call is not None: + raw_function_calls.append(function_call) + + if not raw_function_calls: + return [] + + tool_calls: List[ToolCall] = [] + for index, function_call in enumerate(raw_function_calls, start=1): + call_name = getattr(function_call, "name", None) + call_id = getattr(function_call, "id", None) or f"gemini-tool-call-{index}" + call_args = getattr(function_call, "args", None) or {} + if not isinstance(call_name, str) or not call_name: + raise RespParseException(response, "响应解析失败,Gemini 工具调用缺少 name 字段") + if not isinstance(call_args, dict): + raise RespParseException(response, "响应解析失败,Gemini 工具调用参数无法解析为字典") + tool_calls.append(ToolCall(call_id=call_id, func_name=call_name, args=call_args)) + return tool_calls + + +def _process_stream_chunk( + chunk: GenerateContentResponse, + content_buffer: io.StringIO, + tool_calls_buffer: List[ToolCall], + response: APIResponse, +) -> None: + """处理单个 Gemini 流式响应块。 + + Args: + chunk: 当前流式响应块。 + content_buffer: 正文缓冲区。 + tool_calls_buffer: 工具调用缓冲区。 + response: 当前累积的统一响应对象。 + """ + candidates = _get_candidates(chunk) + for candidate in candidates: + content = getattr(candidate, "content", None) + parts = getattr(content, "parts", None) or [] + for part in parts: + part_text = getattr(part, "text", None) + if not part_text: + continue + if getattr(part, "thought", False): + response.reasoning_content = (response.reasoning_content or "") + part_text else: - arguments = None + content_buffer.write(part_text) - resp.tool_calls.append(ToolCall(call_id, function_name, arguments)) + tool_calls_buffer.extend(_collect_function_calls(chunk)) - # 检查是否因为 max_tokens 截断 - reason = None - if last_resp and getattr(last_resp, "candidates", None): - for c in last_resp.candidates: - fr = getattr(c, "finish_reason", None) or getattr(c, "finishReason", None) - if fr: - reason = str(fr) - break - if str(reason).endswith("MAX_TOKENS"): - has_visible_output = bool(resp.content and resp.content.strip()) - if has_visible_output: - logger.warning( - "⚠ Gemini 响应因达到 max_tokens 限制被部分截断,\n" - " 可能会对回复内容造成影响,建议修改模型 max_tokens 配置!" 
- ) - else: - logger.warning("⚠ Gemini 响应因达到 max_tokens 限制被截断,\n 请修改模型 max_tokens 配置!") +def _build_stream_api_response( + content_buffer: io.StringIO, + tool_calls_buffer: List[ToolCall], + last_response: GenerateContentResponse | None, + response: APIResponse, +) -> APIResponse: + """根据流式缓冲区内容构建统一响应对象。 - if not resp.content and not resp.tool_calls: - if not getattr(resp, "reasoning_content", None): - raise EmptyResponseException() + Args: + content_buffer: 正文缓冲区。 + tool_calls_buffer: 工具调用缓冲区。 + last_response: 最后一个 Gemini 响应块。 + response: 已累积的响应对象。 - return resp + Returns: + APIResponse: 构建完成的统一响应对象。 + + Raises: + EmptyResponseException: 响应中既无正文也无工具调用且无思考内容时抛出。 + """ + if content_buffer.tell() > 0: + response.content = content_buffer.getvalue() + content_buffer.close() + + if tool_calls_buffer: + response.tool_calls = list(tool_calls_buffer) + response.raw_data = last_response + + _warn_if_max_tokens_truncated(last_response, response.content, response.tool_calls) + if not response.content and not response.tool_calls and not response.reasoning_content: + raise EmptyResponseException() + return response async def _default_stream_response_handler( - resp_stream: AsyncIterator[GenerateContentResponse], + response_stream: AsyncIterator[GenerateContentResponse], interrupt_flag: asyncio.Event | None, -) -> tuple[APIResponse, Optional[tuple[int, int, int]]]: +) -> Tuple[APIResponse, Optional[UsageTuple]]: + """处理 Gemini 流式响应。 + + Args: + response_stream: Gemini 异步流式响应迭代器。 + interrupt_flag: 外部中断标记。 + + Returns: + Tuple[APIResponse, Optional[UsageTuple]]: 统一响应对象与可选的使用量信息。 """ - 流式响应处理函数 - 处理Gemini API的流式响应 - :param resp_stream: 流式响应对象,是一个神秘的iterator,我完全不知道这个玩意能不能跑,不过遍历一遍之后它就空了,如果跑不了一点的话可以考虑改成别的东西 - :return: APIResponse对象 - """ - _fc_delta_buffer = io.StringIO() # 正式内容缓冲区,用于存储接收到的正式内容 - _tool_calls_buffer: list[tuple[str, str, dict]] = [] # 工具调用缓冲区,用于存储接收到的工具调用 - _usage_record = None # 使用情况记录 - last_resp: GenerateContentResponse | None = None # 保存最后一个 chunk - resp = APIResponse() - - def _insure_buffer_closed(): - if _fc_delta_buffer and not _fc_delta_buffer.closed: - _fc_delta_buffer.close() - - async for chunk in resp_stream: - last_resp = chunk # 保存最后一个响应 - # 检查是否有中断量 - if interrupt_flag and interrupt_flag.is_set(): - # 如果中断量被设置,则抛出ReqAbortException - raise ReqAbortException("请求被外部信号中断") - - _process_delta( - chunk, - _fc_delta_buffer, - _tool_calls_buffer, - resp=resp, - ) - - if chunk.usage_metadata: - # 如果有使用情况,则将其存储在APIResponse对象中 - _usage_record = ( - chunk.usage_metadata.prompt_token_count or 0, - (chunk.usage_metadata.candidates_token_count or 0) + (chunk.usage_metadata.thoughts_token_count or 0), - chunk.usage_metadata.total_token_count or 0, - ) + content_buffer = io.StringIO() + tool_calls_buffer: List[ToolCall] = [] + api_response = APIResponse() + usage_record: Optional[UsageTuple] = None + last_response: GenerateContentResponse | None = None try: - return _build_stream_api_resp( - _fc_delta_buffer, - _tool_calls_buffer, - last_resp=last_resp, - resp=resp, - ), _usage_record + async for chunk in response_stream: + last_response = chunk + if interrupt_flag and interrupt_flag.is_set(): + raise ReqAbortException("请求被外部信号中断") + _process_stream_chunk(chunk, content_buffer, tool_calls_buffer, api_response) + usage_record = _extract_usage_record(chunk) or usage_record + return _build_stream_api_response(content_buffer, tool_calls_buffer, last_response, api_response), usage_record except Exception: - # 确保缓冲区被关闭 - _insure_buffer_closed() + if not content_buffer.closed: + content_buffer.close() 
raise def _default_normal_response_parser( - resp: GenerateContentResponse, -) -> tuple[APIResponse, Optional[tuple[int, int, int]]]: + response: GenerateContentResponse, +) -> Tuple[APIResponse, Optional[UsageTuple]]: + """解析 Gemini 非流式响应。 + + Args: + response: Gemini 响应对象。 + + Returns: + Tuple[APIResponse, Optional[UsageTuple]]: 统一响应对象与可选的使用量信息。 + + Raises: + EmptyResponseException: 响应中既无正文也无工具调用且无思考内容时抛出。 """ - 解析对话补全响应 - 将Gemini API响应解析为APIResponse对象 - :param resp: 响应对象 - :return: APIResponse对象 - """ - api_response = APIResponse() + api_response = APIResponse(raw_data=response) + visible_parts: List[str] = [] - # 解析思考内容 - try: - if candidates := resp.candidates: - if candidates[0].content and candidates[0].content.parts: - for part in candidates[0].content.parts: - if not part.text: - continue - if part.thought: - api_response.reasoning_content = ( - api_response.reasoning_content + part.text if api_response.reasoning_content else part.text - ) - except Exception as e: - logger.warning(f"解析思考内容时发生错误: {e},跳过解析") + for candidate in _get_candidates(response): + content = getattr(candidate, "content", None) + parts = getattr(content, "parts", None) or [] + for part in parts: + part_text = getattr(part, "text", None) + if not part_text: + continue + if getattr(part, "thought", False): + api_response.reasoning_content = (api_response.reasoning_content or "") + part_text + else: + visible_parts.append(part_text) - # 解析响应内容 - api_response.content = resp.text + api_response.content = "".join(visible_parts).strip() or getattr(response, "text", None) - # 解析工具调用 - if function_calls := resp.function_calls: - api_response.tool_calls = [] - for call in function_calls: - try: - if not isinstance(call.args, dict): - raise RespParseException(resp, "响应解析失败,工具调用参数无法解析为字典类型") - if not call.name: - raise RespParseException(resp, "响应解析失败,工具调用缺失name字段") - api_response.tool_calls.append(ToolCall(call.id or "gemini-tool_call", call.name, call.args or {})) - except Exception as e: - raise RespParseException(resp, "响应解析失败,无法解析工具调用参数") from e + tool_calls = _collect_function_calls(response) + if tool_calls: + api_response.tool_calls = tool_calls - # 解析使用情况 - if usage_metadata := resp.usage_metadata: - _usage_record = ( - usage_metadata.prompt_token_count or 0, - (usage_metadata.candidates_token_count or 0) + (usage_metadata.thoughts_token_count or 0), - usage_metadata.total_token_count or 0, - ) - else: - _usage_record = None - - api_response.raw_data = resp - - # 检查是否因为 max_tokens 截断 - try: - if resp.candidates: - c0 = resp.candidates[0] - reason = getattr(c0, "finish_reason", None) or getattr(c0, "finishReason", None) - if reason and "MAX_TOKENS" in str(reason): - # 检查第二个及之后的 parts 是否有内容 - has_real_output = False - if getattr(c0, "content", None) and getattr(c0.content, "parts", None): - for p in c0.content.parts[1:]: # 跳过第一个 thought - if getattr(p, "text", None) and p.text.strip(): - has_real_output = True - break - - if not has_real_output and getattr(resp, "text", None): - has_real_output = True - - if has_real_output: - logger.warning( - "⚠ Gemini 响应因达到 max_tokens 限制被部分截断,\n" - " 可能会对回复内容造成影响,建议修改模型 max_tokens 配置!" 
- ) - else: - logger.warning("⚠ Gemini 响应因达到 max_tokens 限制被截断,\n 请修改模型 max_tokens 配置!") - - return api_response, _usage_record - except Exception as e: - logger.debug(f"检查 MAX_TOKENS 截断时异常: {e}") - - # 最终的、唯一的空响应检查 - if not api_response.content and not api_response.tool_calls: + usage_record = _extract_usage_record(response) + _warn_if_max_tokens_truncated(response, api_response.content, api_response.tool_calls) + if not api_response.content and not api_response.tool_calls and not api_response.reasoning_content: raise EmptyResponseException("响应中既无文本内容也无工具调用") + return api_response, usage_record - return api_response, _usage_record + +def _build_http_options(api_provider: APIProvider) -> HttpOptions: + """根据 Provider 配置构建 Gemini SDK 的 `HttpOptions`。 + + Args: + api_provider: API 提供商配置。 + + Returns: + HttpOptions: Gemini SDK HTTP 选项对象。 + """ + http_options_payload: Dict[str, Any] = {} + if api_provider.timeout is not None: + http_options_payload["timeout"] = int(api_provider.timeout * 1000) + + base_url = api_provider.base_url.strip() + if base_url: + normalized_base_url = base_url.rstrip("/") + version_candidate = normalized_base_url.rsplit("/", 1) + if len(version_candidate) == 2 and version_candidate[1].startswith("v"): + http_options_payload["base_url"] = f"{version_candidate[0]}/" + http_options_payload["api_version"] = version_candidate[1] + else: + http_options_payload["base_url"] = f"{normalized_base_url}/" + + return HttpOptions(**http_options_payload) + + +def _filter_generate_content_extra_params(extra_params: Dict[str, Any]) -> Dict[str, Any]: + """筛选可透传给 `GenerateContentConfig` 的额外参数。 + + Args: + extra_params: 模型级额外参数。 + + Returns: + Dict[str, Any]: 可直接透传到 `GenerateContentConfig` 的字段字典。 + """ + filtered_params: Dict[str, Any] = {} + for key, value in extra_params.items(): + if key in GENERATE_CONFIG_RESERVED_EXTRA_PARAMS: + continue + if key in GenerateContentConfig.model_fields: + filtered_params[key] = value + return filtered_params + + +def _build_embed_content_config(extra_params: Dict[str, Any]) -> EmbedContentConfig: + """构建 Gemini 嵌入配置。 + + Args: + extra_params: 模型级额外参数。 + + Returns: + EmbedContentConfig: Gemini 嵌入配置对象。 + """ + config_payload: Dict[str, Any] = {"task_type": extra_params.get("task_type", "SEMANTIC_SIMILARITY")} + for key in EMBED_CONFIG_SUPPORTED_EXTRA_PARAMS: + if key == "task_type": + continue + if key in extra_params: + config_payload[key] = extra_params[key] + return EmbedContentConfig(**config_payload) @client_registry.register_client_class("gemini") -class GeminiClient(BaseClient): +class GeminiClient(AdapterClient[AsyncIterator[GenerateContentResponse], GenerateContentResponse]): + """Gemini 官方 SDK 客户端适配器。""" + client: genai.Client - def __init__(self, api_provider: APIProvider): + def __init__(self, api_provider: APIProvider) -> None: + """初始化 Gemini 客户端。 + + Args: + api_provider: API 提供商配置。 + """ super().__init__(api_provider) - - # 增加传入参数处理 - http_options_kwargs: Dict[str, Any] = {} - - # 秒转换为毫秒传入 - if api_provider.timeout is not None: - http_options_kwargs["timeout"] = int(api_provider.timeout * 1000) - - # 传入并处理地址和版本(必须为Gemini格式) - if api_provider.base_url: - parts = api_provider.base_url.rstrip("/").rsplit("/", 1) - if len(parts) == 2 and parts[1].startswith("v"): - http_options_kwargs["base_url"] = f"{parts[0]}/" - http_options_kwargs["api_version"] = parts[1] - else: - http_options_kwargs["base_url"] = api_provider.base_url - http_options_kwargs["api_version"] = None self.client = genai.Client( - 
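Worked examples of `_build_http_options` above (values derived from the definition; the provider objects themselves are elided):

    # A base_url whose last path segment looks like an API version is split:
    #   "https://generativelanguage.googleapis.com/v1beta"
    #     -> base_url="https://generativelanguage.googleapis.com/", api_version="v1beta"
    # Any other base_url only gets its trailing slash normalized:
    #   "https://proxy.example.com/gemini" -> base_url="https://proxy.example.com/gemini/"
    # Timeouts are converted from provider seconds to SDK milliseconds:
    #   timeout=30 -> HttpOptions(timeout=30000)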
http_options=HttpOptions(**http_options_kwargs), api_key=api_provider.api_key, - ) # 这里和openai不一样,gemini会自己决定自己是否需要retry + http_options=_build_http_options(api_provider), + ) @staticmethod - def clamp_thinking_budget(extra_params: dict[str, Any] | None, model_id: str) -> int: - """ - 按模型限制思考预算范围,仅支持指定的模型(支持带数字后缀的新版本) - """ - limits = None + def clamp_thinking_budget(extra_params: Dict[str, Any] | None, model_id: str) -> int: + """将思考预算裁剪到模型允许的范围内。 - # 参数传入处理 - tb = THINKING_BUDGET_AUTO + Args: + extra_params: 请求额外参数。 + model_id: 当前模型标识。 + + Returns: + int: 裁剪后的思考预算值。 + """ + thinking_budget = THINKING_BUDGET_AUTO if extra_params and "thinking_budget" in extra_params: try: - tb = int(extra_params["thinking_budget"]) - except (ValueError, TypeError): - logger.warning( - f"无效的 thinking_budget 值 {extra_params['thinking_budget']},将使用模型自动预算模式 {tb}" - ) + thinking_budget = int(extra_params["thinking_budget"]) + except (TypeError, ValueError): + logger.warning("无效的 thinking_budget=%s,已回退为自动模式", extra_params["thinking_budget"]) - # 优先尝试精确匹配 + limits: Dict[str, int | bool] | None = None if model_id in THINKING_BUDGET_LIMITS: limits = THINKING_BUDGET_LIMITS[model_id] else: - # 按 key 长度倒序,保证更长的(更具体的,如 -lite)优先 - sorted_keys = sorted(THINKING_BUDGET_LIMITS.keys(), key=len, reverse=True) - for key in sorted_keys: - # 必须满足:完全等于 或者 前缀匹配(带 "-" 边界) - if model_id == key or model_id.startswith(f"{key}-"): - limits = THINKING_BUDGET_LIMITS[key] + for candidate_prefix in sorted(THINKING_BUDGET_LIMITS.keys(), key=len, reverse=True): + if model_id == candidate_prefix or model_id.startswith(f"{candidate_prefix}-"): + limits = THINKING_BUDGET_LIMITS[candidate_prefix] break - # 预算值处理 - if tb == THINKING_BUDGET_AUTO: + if thinking_budget == THINKING_BUDGET_AUTO: return THINKING_BUDGET_AUTO - if tb == THINKING_BUDGET_DISABLED: - if limits and limits.get("can_disable", False): + + if thinking_budget == THINKING_BUDGET_DISABLED: + if limits and bool(limits.get("can_disable", False)): return THINKING_BUDGET_DISABLED if limits: - logger.warning(f"模型 {model_id} 不支持禁用思考预算,已回退到最小值 {limits['min']}") - return limits["min"] + minimum_value = int(limits["min"]) + logger.warning("模型 %s 不支持禁用思考预算,已回退为最小值 %s", model_id, minimum_value) + return minimum_value return THINKING_BUDGET_AUTO - # 已知模型范围裁剪 + 提示 - if limits: - if tb < limits["min"]: - logger.warning(f"模型 {model_id} 的 thinking_budget={tb} 过小,已调整为最小值 {limits['min']}") - return limits["min"] - if tb > limits["max"]: - logger.warning(f"模型 {model_id} 的 thinking_budget={tb} 过大,已调整为最大值 {limits['max']}") - return limits["max"] - return tb + if limits is None: + logger.warning("模型 %s 未配置思考预算范围,已回退为自动模式", model_id) + return THINKING_BUDGET_AUTO - # 未知模型 → 默认自动模式 - logger.warning(f"模型 {model_id} 未在 THINKING_BUDGET_LIMITS 中定义,已启用模型自动预算兼容") - return THINKING_BUDGET_AUTO + minimum_value = int(limits["min"]) + maximum_value = int(limits["max"]) + if thinking_budget < minimum_value: + logger.warning("模型 %s 的 thinking_budget=%s 过小,已调整为 %s", model_id, thinking_budget, minimum_value) + return minimum_value + if thinking_budget > maximum_value: + logger.warning("模型 %s 的 thinking_budget=%s 过大,已调整为 %s", model_id, thinking_budget, maximum_value) + return maximum_value + return thinking_budget + + @staticmethod + def _resolve_model_identifier(model_identifier: str, extra_params: Dict[str, Any]) -> Tuple[str, bool]: + """解析请求实际使用的 Gemini 模型标识。 - async def get_response( - self, - model_info: ModelInfo, - message_list: list[Message], - tool_options: list[ToolOption] | None = None, - max_tokens: 
Optional[int] = 1024, - temperature: Optional[float] = 0.4, - response_format: RespFormat | None = None, - stream_response_handler: Optional[ - Callable[ - [AsyncIterator[GenerateContentResponse], asyncio.Event | None], - Coroutine[Any, Any, tuple[APIResponse, Optional[tuple[int, int, int]]]], - ] - ] = None, - async_response_parser: Optional[ - Callable[[GenerateContentResponse], tuple[APIResponse, Optional[tuple[int, int, int]]]] - ] = None, - interrupt_flag: asyncio.Event | None = None, - extra_params: dict[str, Any] | None = None, - ) -> APIResponse: - """ - 获取对话响应 Args: - model_info: 模型信息 - message_list: 对话体 - tool_options: 工具选项(可选,默认为None) - max_tokens: 最大token数(可选,默认为1024) - temperature: 温度(可选,默认为0.7) - response_format: 响应格式(默认为text/plain,如果是输入的JSON Schema则必须遵守OpenAPI3.0格式,理论上和openai是一样的,暂不支持其它相应格式输入) - stream_response_handler: 流式响应处理函数(可选,默认为default_stream_response_handler) - async_response_parser: 响应解析函数(可选,默认为default_response_parser) - interrupt_flag: 中断信号量(可选,默认为None) + model_identifier: 原始模型标识。 + extra_params: 模型级额外参数。 + Returns: - APIResponse对象,包含响应内容、推理内容、工具调用等信息 + Tuple[str, bool]: `(实际模型标识, 是否启用 Google Search)`。 """ - if stream_response_handler is None: - stream_response_handler = _default_stream_response_handler - - if async_response_parser is None: - async_response_parser = _default_normal_response_parser - - # 将messages构造为Gemini API所需的格式 - messages = _convert_messages(message_list) - # 将tool_options转换为Gemini API所需的格式 - tools = _convert_tool_options(tool_options) if tool_options else None - # 解析并裁剪 thinking_budget - tb = self.clamp_thinking_budget(extra_params, model_info.model_identifier) - # 检测是否为带 -search 的模型 - enable_google_search = False - model_identifier = model_info.model_identifier - if model_identifier.endswith("-search"): + enable_google_search = bool(extra_params.get("enable_google_search", False)) + resolved_model_identifier = model_identifier + if resolved_model_identifier.endswith("-search"): + resolved_model_identifier = resolved_model_identifier.removesuffix("-search") enable_google_search = True - # 去掉后缀并更新模型ID - model_identifier = model_identifier.removesuffix("-search") - model_info.model_identifier = model_identifier - logger.info(f"模型已启用 GoogleSearch 功能:{model_identifier}") + return resolved_model_identifier, enable_google_search - # 将response_format转换为Gemini API所需的格式 - generation_config_dict = { - "max_output_tokens": max_tokens, - "temperature": temperature, - "response_modalities": ["TEXT"], - "thinking_config": ThinkingConfig( - include_thoughts=True, - thinking_budget=tb, - ), - "safety_settings": gemini_safe_settings, # 防止空回复问题 - } - if tools: - generation_config_dict["tools"] = Tool(function_declarations=tools) - if messages[1]: - # 如果有system消息,则将其添加到配置中 - generation_config_dict["system_instructions"] = messages[1] - if response_format and response_format.format_type == RespFormatType.TEXT: - generation_config_dict["response_mime_type"] = "text/plain" - elif response_format and response_format.format_type in (RespFormatType.JSON_OBJ, RespFormatType.JSON_SCHEMA): - generation_config_dict["response_mime_type"] = "application/json" - generation_config_dict["response_schema"] = response_format.to_dict() - # 自动启用 GoogleSearch grounding_tool + def _build_generation_config( + self, + *, + model_identifier: str, + system_instruction: str | None, + tool_options: List[ToolOption] | None, + response_format: RespFormat | None, + max_tokens: int | None, + temperature: float | None, + extra_params: Dict[str, Any], + enable_google_search: bool, + ) -> 
GenerateContentConfig: + """构建 Gemini 生成配置。 + + Args: + model_identifier: 当前请求实际使用的模型标识。 + system_instruction: 系统指令文本。 + tool_options: 内部工具定义列表。 + response_format: 输出格式定义。 + max_tokens: 最大输出 token 数。 + temperature: 温度参数。 + extra_params: 模型级额外参数。 + enable_google_search: 是否自动追加 Google Search 工具。 + + Returns: + GenerateContentConfig: Gemini 生成配置对象。 + """ + config_payload = _filter_generate_content_extra_params(extra_params) + + if max_tokens is not None and "max_output_tokens" not in config_payload: + config_payload["max_output_tokens"] = max_tokens + if temperature is not None and "temperature" not in config_payload: + config_payload["temperature"] = temperature + if system_instruction and "system_instruction" not in config_payload: + config_payload["system_instruction"] = system_instruction + if "response_modalities" not in config_payload: + config_payload["response_modalities"] = ["TEXT"] + if "safety_settings" not in config_payload: + config_payload["safety_settings"] = GEMINI_SAFE_SETTINGS + if "thinking_config" not in config_payload: + config_payload["thinking_config"] = ThinkingConfig( + include_thoughts=bool(extra_params.get("include_thoughts", True)), + thinking_budget=self.clamp_thinking_budget(extra_params, model_identifier), + ) + + tools = _build_tools(tool_options) if tool_options else [] if enable_google_search: - grounding_tool = Tool(google_search=GoogleSearch()) - if "tools" in generation_config_dict: - existing = generation_config_dict["tools"] - if isinstance(existing, list): - existing.append(grounding_tool) + tools.append(Tool(google_search=GoogleSearch())) + if tools: + if "tools" in config_payload: + existing_tools = config_payload["tools"] + if isinstance(existing_tools, list): + config_payload["tools"] = [*existing_tools, *tools] else: - generation_config_dict["tools"] = [existing, grounding_tool] + config_payload["tools"] = [existing_tools, *tools] else: - generation_config_dict["tools"] = [grounding_tool] + config_payload["tools"] = tools - generation_config = GenerateContentConfig(**generation_config_dict) + if response_format is not None: + if response_format.format_type == RespFormatType.TEXT: + config_payload.setdefault("response_mime_type", "text/plain") + elif response_format.format_type == RespFormatType.JSON_OBJ: + config_payload.setdefault("response_mime_type", "application/json") + elif response_format.format_type == RespFormatType.JSON_SCHEMA: + config_payload.setdefault("response_mime_type", "application/json") + response_json_schema = _extract_response_json_schema(response_format) + if ( + response_json_schema is not None + and "response_json_schema" not in config_payload + and "response_schema" not in config_payload + ): + config_payload["response_json_schema"] = response_json_schema + + return GenerateContentConfig(**config_payload) + + def _build_default_stream_response_handler( + self, + request: ResponseRequest, + ) -> ProviderStreamResponseHandler[AsyncIterator[GenerateContentResponse]]: + """构建 Gemini 默认流式响应处理器。 + + Args: + request: 统一响应请求对象。 + + Returns: + ProviderStreamResponseHandler[AsyncIterator[GenerateContentResponse]]: 默认流式处理器。 + """ + del request + return _default_stream_response_handler + + def _build_default_response_parser( + self, + request: ResponseRequest, + ) -> ProviderResponseParser[GenerateContentResponse]: + """构建 Gemini 默认非流式响应解析器。 + + Args: + request: 统一响应请求对象。 + + Returns: + ProviderResponseParser[GenerateContentResponse]: 默认非流式解析器。 + """ + del request + return _default_normal_response_parser + + async def 
_execute_response_request( + self, + request: ResponseRequest, + stream_response_handler: ProviderStreamResponseHandler[AsyncIterator[GenerateContentResponse]], + response_parser: ProviderResponseParser[GenerateContentResponse], + ) -> Tuple[APIResponse, UsageTuple | None]: + """执行 Gemini 的文本/多模态响应请求。 + + Args: + request: 统一响应请求对象。 + stream_response_handler: 流式响应处理器。 + response_parser: 非流式响应解析器。 + + Returns: + Tuple[APIResponse, UsageTuple | None]: 统一响应对象与可选使用量信息。 + """ + model_info = request.model_info + contents, system_instruction = _convert_messages(request.message_list) + model_identifier, enable_google_search = self._resolve_model_identifier( + model_info.model_identifier, + request.extra_params, + ) + generation_config = self._build_generation_config( + model_identifier=model_identifier, + system_instruction=system_instruction, + tool_options=request.tool_options, + response_format=request.response_format, + max_tokens=request.max_tokens, + temperature=request.temperature, + extra_params=request.extra_params, + enable_google_search=enable_google_search, + ) try: if model_info.force_stream_mode: - req_task = asyncio.create_task( + stream_task: asyncio.Task[AsyncIterator[GenerateContentResponse]] = asyncio.create_task( self.client.aio.models.generate_content_stream( - model=model_info.model_identifier, - contents=messages[0], + model=model_identifier, + contents=contents, config=generation_config, ) ) - while not req_task.done(): - if interrupt_flag and interrupt_flag.is_set(): - # 如果中断量存在且被设置,则取消任务并抛出异常 - req_task.cancel() - raise ReqAbortException("请求被外部信号中断") - await asyncio.sleep(0.1) # 等待0.1秒后再次检查任务&中断信号量状态 - resp, usage_record = await stream_response_handler(req_task.result(), interrupt_flag) - else: - req_task = asyncio.create_task( - self.client.aio.models.generate_content( - model=model_info.model_identifier, - contents=messages[0], - config=generation_config, - ) + raw_response_stream = cast( + AsyncIterator[GenerateContentResponse], + await await_task_with_interrupt(stream_task, request.interrupt_flag), ) - while not req_task.done(): - if interrupt_flag and interrupt_flag.is_set(): - # 如果中断量存在且被设置,则取消任务并抛出异常 - req_task.cancel() - raise ReqAbortException("请求被外部信号中断") - await asyncio.sleep(0.5) # 等待0.5秒后再次检查任务&中断信号量状态 + return await stream_response_handler(raw_response_stream, request.interrupt_flag) - resp, usage_record = async_response_parser(req_task.result()) - except (ClientError, ServerError) as e: - # 重封装 ClientError 和 ServerError 为 RespNotOkException - raise RespNotOkException(e.code, e.message) from None - except ( - UnknownFunctionCallArgumentError, - UnsupportedFunctionError, - FunctionInvocationError, - ) as e: - # 工具调用相关错误 - raise RespParseException(None, f"工具调用参数错误: {str(e)}") from None - except EmptyResponseException as e: - # 保持原始异常,便于区分“空响应”和网络异常 - raise e - except Exception as e: - # 其他未预料的错误,才归为网络连接类 - raise NetworkConnectionError() from e - - if usage_record: - resp.usage = UsageRecord( - model_name=model_info.name, - provider_name=model_info.api_provider, - prompt_tokens=usage_record[0], - completion_tokens=usage_record[1], - total_tokens=usage_record[2], + completion_task: asyncio.Task[GenerateContentResponse] = asyncio.create_task( + self.client.aio.models.generate_content( + model=model_identifier, + contents=contents, + config=generation_config, + ) ) + raw_response = cast( + GenerateContentResponse, + await await_task_with_interrupt(completion_task, request.interrupt_flag), + ) + return response_parser(raw_response) + except ReqAbortException: + 
raise + except (ClientError, ServerError) as exc: + status_code = int(getattr(exc, "code", 500) or 500) + raise RespNotOkException(status_code, str(exc)) from exc + except (UnknownFunctionCallArgumentError, UnsupportedFunctionError, FunctionInvocationError) as exc: + raise RespParseException(None, f"Gemini 工具调用参数错误: {exc}") from exc + except EmptyResponseException: + raise + except Exception as exc: + raise NetworkConnectionError(str(exc)) from exc - return resp - - async def get_embedding( + async def _execute_embedding_request( self, - model_info: ModelInfo, - embedding_input: str, - extra_params: dict[str, Any] | None = None, - ) -> APIResponse: - """ - 获取文本嵌入 - :param model_info: 模型信息 - :param embedding_input: 嵌入输入文本 - :return: 嵌入响应 + request: EmbeddingRequest, + ) -> Tuple[APIResponse, UsageTuple | None]: + """执行 Gemini 文本嵌入请求。 + + Args: + request: 统一嵌入请求对象。 + + Returns: + Tuple[APIResponse, UsageTuple | None]: 统一响应对象与可选使用量信息。 """ + model_info = request.model_info + embedding_input = request.embedding_input + extra_params = request.extra_params + embed_config = _build_embed_content_config(extra_params) + try: raw_response: EmbedContentResponse = await self.client.aio.models.embed_content( model=model_info.model_identifier, contents=embedding_input, - config=EmbedContentConfig(task_type="SEMANTIC_SIMILARITY"), + config=embed_config, ) - except (ClientError, ServerError) as e: - # 重封装ClientError和ServerError为RespNotOkException - raise RespNotOkException(e.code) from None - except Exception as e: - raise NetworkConnectionError() from e + except (ClientError, ServerError) as exc: + status_code = int(getattr(exc, "code", 500) or 500) + raise RespNotOkException(status_code, str(exc)) from exc + except Exception as exc: + raise NetworkConnectionError(str(exc)) from exc - response = APIResponse() - - # 解析嵌入响应和使用情况 - if hasattr(raw_response, "embeddings") and raw_response.embeddings: + response = APIResponse(raw_data=raw_response) + if raw_response.embeddings: response.embedding = raw_response.embeddings[0].values else: - raise RespParseException(raw_response, "响应解析失败,缺失embeddings字段") + raise RespParseException(raw_response, "响应解析失败,缺失 embeddings 字段") - response.usage = UsageRecord( - model_name=model_info.name, - provider_name=model_info.api_provider, - prompt_tokens=len(embedding_input), - completion_tokens=0, - total_tokens=len(embedding_input), + billable_character_count = 0 + if raw_response.metadata is not None: + billable_character_count = getattr(raw_response.metadata, "billable_character_count", 0) or 0 + usage_record: UsageTuple = ( + billable_character_count or len(embedding_input), + 0, + billable_character_count or len(embedding_input), ) + return response, usage_record - return response - - async def get_audio_transcriptions( + async def _execute_audio_transcription_request( self, - model_info: ModelInfo, - audio_base64: str, - max_tokens: Optional[int] = 2048, - extra_params: dict[str, Any] | None = None, - ) -> APIResponse: - """ - 获取音频转录 - :param model_info: 模型信息 - :param audio_base64: 音频文件的Base64编码字符串 - :param max_tokens: 最大输出token数(默认2048) - :param extra_params: 额外参数(可选) - :return: 转录响应 - """ - # 解析并裁剪 thinking_budget - tb = self.clamp_thinking_budget(extra_params, model_info.model_identifier) + request: AudioTranscriptionRequest, + ) -> Tuple[APIResponse, UsageTuple | None]: + """执行 Gemini 音频转录请求。 - # 构造 prompt + 音频输入 - prompt = "Generate a transcript of the speech. The language of the transcript should **match the language of the speech**." 
- contents = [ + Args: + request: 统一音频转录请求对象。 + + Returns: + Tuple[APIResponse, UsageTuple | None]: 统一响应对象与可选使用量信息。 + """ + model_info = request.model_info + audio_base64 = request.audio_base64 + max_tokens = request.max_tokens + extra_params = request.extra_params + model_identifier, _ = self._resolve_model_identifier(model_info.model_identifier, extra_params) + + transcription_prompt = str( + extra_params.get( + "transcription_prompt", + "Generate a transcript of the speech. The language of the transcript should match the speech.", + ) + ) + audio_mime_type = str(extra_params.get("audio_mime_type", "audio/wav")) + contents: List[ContentUnion] = [ Content( role="user", parts=[ - Part.from_text(text=prompt), - Part.from_bytes(data=base64.b64decode(audio_base64), mime_type="audio/wav"), + Part.from_text(text=transcription_prompt), + Part.from_bytes(data=base64.b64decode(audio_base64), mime_type=audio_mime_type), ], ) ] - - generation_config_dict = { - "max_output_tokens": max_tokens, - "response_modalities": ["TEXT"], - "thinking_config": ThinkingConfig( - include_thoughts=True, - thinking_budget=tb, - ), - "safety_settings": gemini_safe_settings, - } - generate_content_config = GenerateContentConfig(**generation_config_dict) + generation_config = self._build_generation_config( + model_identifier=model_identifier, + system_instruction=None, + tool_options=None, + response_format=None, + max_tokens=max_tokens, + temperature=None, + extra_params=extra_params, + enable_google_search=False, + ) try: raw_response: GenerateContentResponse = await self.client.aio.models.generate_content( - model=model_info.model_identifier, + model=model_identifier, contents=contents, - config=generate_content_config, + config=generation_config, ) - resp, usage_record = _default_normal_response_parser(raw_response) - except (ClientError, ServerError) as e: - # 重封装ClientError和ServerError为RespNotOkException - raise RespNotOkException(e.code) from None - except Exception as e: - raise NetworkConnectionError() from e + response, usage_record = _default_normal_response_parser(raw_response) + except (ClientError, ServerError) as exc: + status_code = int(getattr(exc, "code", 500) or 500) + raise RespNotOkException(status_code, str(exc)) from exc + except Exception as exc: + raise NetworkConnectionError(str(exc)) from exc - if usage_record: - resp.usage = UsageRecord( - model_name=model_info.name, - provider_name=model_info.api_provider, - prompt_tokens=usage_record[0], - completion_tokens=usage_record[1], - total_tokens=usage_record[2], - ) + return response, usage_record - return resp + def get_support_image_formats(self) -> List[str]: + """获取 Gemini 当前支持的图片格式列表。 - def get_support_image_formats(self) -> list[str]: - """ - 获取支持的图片格式 - :return: 支持的图片格式列表 + Returns: + List[str]: 当前客户端支持的图片格式列表。 """ return ["png", "jpg", "jpeg", "webp", "heic", "heif"] diff --git a/src/llm_models/model_client/openai_client.py b/src/llm_models/model_client/openai_client.py index 99efe8d9..47f75263 100644 --- a/src/llm_models/model_client/openai_client.py +++ b/src/llm_models/model_client/openai_client.py @@ -1,742 +1,1000 @@ +from collections.abc import Iterable +from dataclasses import dataclass, field +from typing import Any, Callable, Coroutine, Dict, List, Tuple, cast + import asyncio +import base64 import io import json import re -import base64 -from collections.abc import Iterable -from typing import Callable, Any, Coroutine, Optional -from json_repair import repair_json -from openai import ( - AsyncOpenAI, - APIConnectionError, - 
APIStatusError,
-    NOT_GIVEN,
-    AsyncStream,
-)
+from json_repair import repair_json
+from openai import APIConnectionError, APIStatusError, AsyncOpenAI, AsyncStream
+from openai._types import FileTypes, Omit, omit
 from openai.types.chat import (
     ChatCompletion,
+    ChatCompletionAssistantMessageParam,
     ChatCompletionChunk,
+    ChatCompletionContentPartImageParam,
+    ChatCompletionContentPartParam,
+    ChatCompletionContentPartTextParam,
+    ChatCompletionMessageFunctionToolCallParam,
     ChatCompletionMessageParam,
+    ChatCompletionSystemMessageParam,
+    ChatCompletionToolMessageParam,
     ChatCompletionToolParam,
+    ChatCompletionUserMessageParam,
 )
+from openai.types.shared_params.function_definition import FunctionDefinition
 from openai.types.chat.chat_completion_chunk import ChoiceDelta
-from src.config.model_configs import ModelInfo, APIProvider
 from src.common.logger import get_logger
-from .base_client import APIResponse, UsageRecord, BaseClient, client_registry
-from ..exceptions import (
-    RespParseException,
-    NetworkConnectionError,
-    RespNotOkException,
-    ReqAbortException,
+from src.config.model_configs import APIProvider, ReasoningParseMode, ToolArgumentParseMode
+from src.llm_models.exceptions import (
     EmptyResponseException,
+    NetworkConnectionError,
+    ReqAbortException,
+    RespNotOkException,
+    RespParseException,
+)
+from src.llm_models.openai_compat import (
+    build_openai_compatible_client_config,
+    split_openai_request_overrides,
+)
+from src.llm_models.payload_content.message import ImageMessagePart, Message, RoleType, TextMessagePart
+from src.llm_models.payload_content.resp_format import RespFormat, RespFormatType
+from src.llm_models.payload_content.tool_option import ToolCall, ToolOption
+
+from .adapter_base import (
+    AdapterClient,
+    ProviderResponseParser,
+    ProviderStreamResponseHandler,
+    await_task_with_interrupt,
+)
+from .base_client import (
+    APIResponse,
+    AudioTranscriptionRequest,
+    EmbeddingRequest,
+    ResponseRequest,
+    UsageTuple,
+    client_registry,
 )
-from ..payload_content.message import Message, RoleType
-from ..payload_content.resp_format import RespFormat, RespFormatType
-from ..payload_content.tool_option import ToolOption, ToolParam, ToolCall

 logger = get_logger("llm_models")

+THINK_CONTENT_PATTERN = re.compile(
+    r"<think>(?P<think>.*?)</think>(?P<content>.*)|<think>(?P<think_unclosed>.*)|(?P<content_only>.+)",
+    re.DOTALL,
+)
+"""用于解析 `<think>` 推理块的正则表达式。"""
+
+CHAT_COMPLETIONS_RESERVED_EXTRA_BODY_KEYS = {
+    "max_tokens",
+    "messages",
+    "model",
+    "response_format",
+    "stream",
+    "temperature",
+    "tools",
+}
+"""由当前客户端显式承载、不应再落入 `extra_body` 的字段集合。"""
+
+OpenAIStreamResponseHandler = Callable[
+    [AsyncStream[ChatCompletionChunk], asyncio.Event | None],
+    Coroutine[Any, Any, Tuple[APIResponse, UsageTuple | None]],
+]
+"""OpenAI 流式响应处理函数类型。"""
+
+OpenAIResponseParser = Callable[[ChatCompletion], Tuple[APIResponse, UsageTuple | None]]
+"""OpenAI 非流式响应解析函数类型。"""
+
+
+def _normalize_reasoning_parse_mode(parse_mode: str | ReasoningParseMode) -> ReasoningParseMode:
+    """将配置中的推理解析模式收敛为枚举值。
+
+    Args:
+        parse_mode: 原始解析模式配置。
+
+    Returns:
+        ReasoningParseMode: 规范化后的解析模式;未知值会回退为 `AUTO`。
+    """
+    if isinstance(parse_mode, ReasoningParseMode):
+        return parse_mode
+    try:
+        return ReasoningParseMode(parse_mode)
+    except ValueError:
+        logger.warning("未识别的推理解析模式 %s,已回退为 auto", parse_mode)
+        return ReasoningParseMode.AUTO
+
+
+def _normalize_tool_argument_parse_mode(parse_mode: str | ToolArgumentParseMode) -> ToolArgumentParseMode:
+    """将配置中的工具参数解析模式收敛为枚举值。
+
+    Args:
+        parse_mode: 原始解析模式配置。
+
+    Returns:
+        ToolArgumentParseMode: 规范化后的解析模式;未知值会回退为 `AUTO`。
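+
+    示例(说明性示例,假设各枚举成员的值为其同名小写字符串):
+        _normalize_tool_argument_parse_mode("strict")   # -> ToolArgumentParseMode.STRICT
+        _normalize_tool_argument_parse_mode("unknown")  # -> 记录告警并回退为 ToolArgumentParseMode.AUTO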
+ """ + if isinstance(parse_mode, ToolArgumentParseMode): + return parse_mode + try: + return ToolArgumentParseMode(parse_mode) + except ValueError: + logger.warning("未识别的工具参数解析模式 %s,已回退为 auto", parse_mode) + return ToolArgumentParseMode.AUTO + + +def _build_text_content_part(text: str) -> ChatCompletionContentPartTextParam: + """构建文本内容片段。 + + Args: + text: 文本内容。 + + Returns: + ChatCompletionContentPartTextParam: OpenAI 兼容的文本片段。 + """ + return { + "type": "text", + "text": text, + } + + +def _build_image_content_part(part: ImageMessagePart) -> ChatCompletionContentPartImageParam: + """构建图片内容片段。 + + Args: + part: 内部图片片段。 + + Returns: + ChatCompletionContentPartImageParam: OpenAI 兼容的图片片段。 + """ + return { + "type": "image_url", + "image_url": { + "url": f"data:image/{part.normalized_image_format};base64,{part.image_base64}", + }, + } + def _convert_response_format(response_format: RespFormat | None) -> Any: - """ - 转换响应格式 - 将内部RespFormat转换为OpenAI API所需格式 - """ - if response_format is None: - return NOT_GIVEN + """将内部响应格式转换为 OpenAI 兼容结构。 - if response_format.format_type == RespFormatType.TEXT: - return NOT_GIVEN + Args: + response_format: 内部响应格式定义。 + Returns: + Any: OpenAI SDK 可接受的响应格式参数;未指定时返回 `omit`。 + """ + if response_format is None or response_format.format_type == RespFormatType.TEXT: + return omit if response_format.format_type == RespFormatType.JSON_OBJ: return {"type": "json_object"} - if response_format.format_type == RespFormatType.JSON_SCHEMA: return { "type": "json_schema", "json_schema": response_format.schema, } - - return NOT_GIVEN + return omit -def _convert_messages(messages: list[Message]) -> list[ChatCompletionMessageParam]: +def _convert_text_only_message_content( + message: Message, +) -> str | List[ChatCompletionContentPartTextParam]: + """将仅允许文本的消息转换为 OpenAI 兼容内容。 + + Args: + message: 内部统一消息对象。 + + Returns: + str | List[ChatCompletionContentPartTextParam]: 文本内容结构。 + + Raises: + ValueError: 当消息中包含非文本片段时抛出。 """ - 转换消息格式 - 将消息转换为OpenAI API所需的格式 - :param messages: 消息列表 - :return: 转换后的消息列表 + if not message.parts: + return "" + if len(message.parts) == 1 and isinstance(message.parts[0], TextMessagePart): + return message.parts[0].text + + content: List[ChatCompletionContentPartTextParam] = [] + for part in message.parts: + if not isinstance(part, TextMessagePart): + raise ValueError(f"{message.role.value} 消息仅支持文本片段") + content.append(_build_text_content_part(part.text)) + return content + + +def _convert_user_message_content(message: Message) -> str | List[ChatCompletionContentPartParam]: + """将用户消息转换为 OpenAI 兼容内容。 + + Args: + message: 内部统一消息对象。 + + Returns: + str | List[ChatCompletionContentPartParam]: 用户消息内容结构。 """ + if len(message.parts) == 1 and isinstance(message.parts[0], TextMessagePart): + return message.parts[0].text - def _convert_message_item(message: Message) -> ChatCompletionMessageParam: - """ - 转换单个消息格式 - :param message: 消息对象 - :return: 转换后的消息字典 - """ + content: List[ChatCompletionContentPartParam] = [] + for part in message.parts: + if isinstance(part, TextMessagePart): + content.append(_build_text_content_part(part.text)) + continue + content.append(_build_image_content_part(part)) + return content - # 添加Content - content: str | list[dict[str, Any]] - if isinstance(message.content, str): - content = message.content - elif isinstance(message.content, list): - content = [] - for item in message.content: - if isinstance(item, tuple): - image_format = item[0].lower() - # 规范 JPEG MIME 类型后缀,统一使用 image/jpeg - if image_format in ("jpg", "jpeg"): - mime_suffix = 
"jpeg" - else: - mime_suffix = image_format - content.append( - { - "type": "image_url", - "image_url": {"url": f"data:image/{mime_suffix};base64,{item[1]}"}, - } - ) - elif isinstance(item, str): - content.append({"type": "text", "text": item}) - else: - raise RuntimeError("无法触及的代码:请使用MessageBuilder类构建消息对象") - ret = { - "role": message.role.value, - "content": content, - } +def _convert_assistant_tool_calls(tool_calls: List[ToolCall]) -> List[ChatCompletionMessageFunctionToolCallParam]: + """将内部工具调用转换为 OpenAI assistant tool_calls 结构。 - if message.role == RoleType.Assistant and getattr(message, "tool_calls", None): - tool_calls_payload: list[dict[str, Any]] = [] - for call in message.tool_calls or []: - tool_calls_payload.append( - { - "id": call.call_id, - "type": "function", - "function": { - "name": call.func_name, - "arguments": json.dumps(call.args or {}, ensure_ascii=False), - }, - } - ) - ret["tool_calls"] = tool_calls_payload - if ret["content"] == []: - ret["content"] = "" + Args: + tool_calls: 内部工具调用列表。 + + Returns: + List[ChatCompletionMessageFunctionToolCallParam]: OpenAI 兼容工具调用结构。 + """ + converted_tool_calls: List[ChatCompletionMessageFunctionToolCallParam] = [] + for tool_call in tool_calls: + converted_tool_calls.append( + { + "id": tool_call.call_id, + "type": "function", + "function": { + "name": tool_call.func_name, + "arguments": json.dumps(tool_call.args or {}, ensure_ascii=False), + }, + } + ) + return converted_tool_calls + + +def _convert_messages(messages: List[Message]) -> List[ChatCompletionMessageParam]: + """将内部消息列表转换为 OpenAI 兼容消息列表。 + + Args: + messages: 内部统一消息列表。 + + Returns: + List[ChatCompletionMessageParam]: OpenAI SDK 所需的消息结构列表。 + """ + converted_messages: List[ChatCompletionMessageParam] = [] + for message in messages: + if message.role == RoleType.System: + system_payload: ChatCompletionSystemMessageParam = { + "role": "system", + "content": _convert_text_only_message_content(message), + } + converted_messages.append(system_payload) + continue + + if message.role == RoleType.User: + user_payload: ChatCompletionUserMessageParam = { + "role": "user", + "content": _convert_user_message_content(message), + } + converted_messages.append(user_payload) + continue + + if message.role == RoleType.Assistant: + assistant_payload: ChatCompletionAssistantMessageParam = { + "role": "assistant", + "content": None if not message.parts and message.tool_calls else _convert_text_only_message_content(message), + } + if message.tool_calls: + assistant_payload["tool_calls"] = _convert_assistant_tool_calls(message.tool_calls) + converted_messages.append(assistant_payload) + continue - # 添加工具调用ID if message.role == RoleType.Tool: - if not message.tool_call_id: - raise ValueError("无法触及的代码:请使用MessageBuilder类构建消息对象") - ret["tool_call_id"] = message.tool_call_id + if message.tool_call_id is None: + raise ValueError("Tool 消息缺少 tool_call_id") + tool_payload: ChatCompletionToolMessageParam = { + "role": "tool", + "content": _convert_text_only_message_content(message), + "tool_call_id": message.tool_call_id, + } + converted_messages.append(tool_payload) + continue - return ret # type: ignore + raise ValueError(f"不支持的消息角色:{message.role}") - return [_convert_message_item(message) for message in messages] + return converted_messages -def _convert_tool_options(tool_options: list[ToolOption]) -> list[dict[str, Any]]: +def _convert_tool_options(tool_options: List[ToolOption]) -> List[ChatCompletionToolParam]: + """将工具定义转换为 OpenAI 兼容的工具列表。 + + Args: + tool_options: 内部统一工具定义列表。 + + Returns: + 
List[ChatCompletionToolParam]: OpenAI SDK 所需的工具定义列表。
     """
-    转换工具选项格式
-    将工具选项转换为OpenAI API所需的格式
-    :param tool_options: 工具选项列表
-    :return: 转换后的工具选项列表
-    """
-
-    def _convert_tool_param(tool_option_param: ToolParam) -> dict[str, Any]:
-        """
-        转换单个工具参数格式
-        :param tool_option_param: 工具参数对象
-        :return: 转换后的工具参数字典
-        """
-        # JSON Schema 类型名称修正:
-        # - 布尔类型使用 "boolean" 而不是 "bool"
-        # - 浮点数使用 "number" 而不是 "float"
-        param_type_value = tool_option_param.param_type.value
-        if param_type_value == "bool":
-            param_type_value = "boolean"
-        elif param_type_value == "float":
-            param_type_value = "number"
-
-        return_dict: dict[str, Any] = {
-            "type": param_type_value,
-            "description": tool_option_param.description,
-        }
-        if tool_option_param.enum_values:
-            return_dict["enum"] = tool_option_param.enum_values
-        return return_dict
-
-    def _convert_tool_option_item(tool_option: ToolOption) -> dict[str, Any]:
-        """
-        转换单个工具项格式
-        :param tool_option: 工具选项对象
-        :return: 转换后的工具选项字典
-        """
-        ret: dict[str, Any] = {
+    converted_tools: List[ChatCompletionToolParam] = []
+    for tool_option in tool_options:
+        function_schema: FunctionDefinition = {
             "name": tool_option.name,
             "description": tool_option.description,
         }
-        if tool_option.params:
-            ret["parameters"] = {
-                "type": "object",
-                "properties": {param.name: _convert_tool_param(param) for param in tool_option.params},
-                "required": [param.name for param in tool_option.params if param.required],
+        parameters_schema = tool_option.parameters_schema
+        if parameters_schema is not None:
+            function_schema["parameters"] = cast(Dict[str, object], parameters_schema)
+        converted_tools.append(
+            {
+                "type": "function",
+                "function": function_schema,
             }
-        return ret
-
-    return [
-        {
-            "type": "function",
-            "function": _convert_tool_option_item(tool_option),
-        }
-        for tool_option in tool_options
-    ]
+        )
+    return converted_tools


-def _process_delta(
-    delta: ChoiceDelta,
-    has_rc_attr_flag: bool,
-    in_rc_flag: bool,
-    rc_delta_buffer: io.StringIO,
-    fc_delta_buffer: io.StringIO,
-    tool_calls_buffer: list[tuple[str, str, io.StringIO]],
-) -> bool:
-    # 接收content
-    if has_rc_attr_flag:
-        # 有独立的推理内容块,则无需考虑content内容的判读
-        if hasattr(delta, "reasoning_content") and delta.reasoning_content:  # type: ignore
-            # 如果有推理内容,则将其写入推理内容缓冲区
-            assert isinstance(delta.reasoning_content, str)  # type: ignore
-            rc_delta_buffer.write(delta.reasoning_content)  # type: ignore
-        elif delta.content:
-            # 如果有正式内容,则将其写入正式内容缓冲区
-            fc_delta_buffer.write(delta.content)
-    elif hasattr(delta, "content") and delta.content is not None:
-        # 没有独立的推理内容块,但有正式内容
-        if in_rc_flag:
-            # 当前在推理内容块中
-            if delta.content == "</think>":
-                # 如果当前内容是</think>,则将其视为推理内容的结束标记,退出推理内容块
-                in_rc_flag = False
-            else:
-                # 其他情况视为推理内容,加入推理内容缓冲区
-                rc_delta_buffer.write(delta.content)
-        elif delta.content == "<think>" and not fc_delta_buffer.getvalue():
-            # 如果当前内容是<think>,且正式内容缓冲区为空,说明为输出的首个token
-            # 则将其视为推理内容的开始标记,进入推理内容块
-            in_rc_flag = True
+def _extract_usage_record(usage: Any) -> UsageTuple | None:
+    """从响应对象中提取 usage 三元组。
+
+    Args:
+        usage: OpenAI SDK 返回的 usage 对象。
+
+    Returns:
+        UsageTuple | None: `(prompt_tokens, completion_tokens, total_tokens)`。
+    """
+    if usage is None:
+        return None
+    return (
+        getattr(usage, "prompt_tokens", 0) or 0,
+        getattr(usage, "completion_tokens", 0) or 0,
+        getattr(usage, "total_tokens", 0) or 0,
+    )
+
+
+def _parse_tool_arguments(
+    raw_arguments: str,
+    parse_mode: ToolArgumentParseMode,
+    response: Any,
+) -> Dict[str, Any]:
+    """解析工具调用参数字符串。
+
+    Args:
+        raw_arguments: 工具调用参数原始字符串。
+        parse_mode: 参数解析模式。
+        response: 原始响应对象,用于异常上下文。
+
+    Returns:
+        Dict[str, Any]: 解析后的参数字典。
+
+    Raises:
+        RespParseException: 当参数无法解析为字典时抛出。
+    """
+    try:
+        if parse_mode == ToolArgumentParseMode.STRICT:
+            arguments: Any = json.loads(raw_arguments)
+        elif parse_mode == ToolArgumentParseMode.REPAIR:
+            arguments = repair_json(raw_arguments, return_objects=True, logging=False)
         else:
-            # 其他情况视为正式内容,加入正式内容缓冲区
-            fc_delta_buffer.write(delta.content)
-    # 接收tool_calls
-    if hasattr(delta, "tool_calls") and delta.tool_calls:
-        tool_call_delta = delta.tool_calls[0]
+            arguments = repair_json(raw_arguments, return_objects=True, logging=False)
+            if isinstance(arguments, str) and parse_mode in {
+                ToolArgumentParseMode.AUTO,
+                ToolArgumentParseMode.DOUBLE_DECODE,
+            }:
+                arguments = repair_json(arguments, return_objects=True, logging=False)
+    except json.JSONDecodeError as exc:
+        raise RespParseException(response, f"响应解析失败,无法解析工具调用参数。原始参数:{raw_arguments}") from exc

-        if tool_call_delta.index >= len(tool_calls_buffer):
-            # 调用索引号大于等于缓冲区长度,说明是新的工具调用
-            if tool_call_delta.id and tool_call_delta.function and tool_call_delta.function.name:
-                tool_calls_buffer.append(
-                    (
-                        tool_call_delta.id,
-                        tool_call_delta.function.name,
-                        io.StringIO(),
-                    )
+    if not isinstance(arguments, dict):
+        raise RespParseException(
+            response,
+            f"响应解析失败,工具调用参数必须解析为字典,实际类型为 {type(arguments).__name__}。",
+        )
+    return arguments
+
+
+def _extract_reasoning_and_content(
+    content: str,
+    parse_mode: ReasoningParseMode,
+) -> Tuple[str | None, str | None]:
+    """从文本内容中提取推理内容与正式输出。
+
+    Args:
+        content: 模型返回的文本内容。
+        parse_mode: 推理解析模式。
+
+    Returns:
+        Tuple[str | None, str | None]: `(reasoning_content, content)`。
+    """
+    if parse_mode in {ReasoningParseMode.NATIVE, ReasoningParseMode.NONE}:
+        return None, content
+
+    match = THINK_CONTENT_PATTERN.match(content)
+    if not match:
+        return None, content
+    if match.group("think") is not None:
+        reasoning_content = match.group("think").strip() or None
+        final_content = match.group("content").strip() or None
+        return reasoning_content, final_content
+    if match.group("think_unclosed") is not None:
+        return match.group("think_unclosed").strip() or None, None
+    return None, match.group("content_only").strip() or None
+
+
+def _log_length_truncation(finish_reason: str | None, model_name: str | None) -> None:
+    """记录因长度截断导致的告警日志。
+
+    Args:
+        finish_reason: OpenAI 兼容接口返回的完成原因。
+        model_name: 上游返回的模型标识。
+    """
+    if finish_reason == "length":
+        logger.info("模型%s因超过最大 max_tokens 限制,可能仅输出部分内容,可视情况调整", model_name or "")
+
+
+def _coerce_openai_argument(value: Any) -> Any | Omit:
+    """将可选参数转换为 OpenAI SDK 期望的值。
+
+    Args:
+        value: 原始参数值。
+
+    Returns:
+        Any | Omit: `None` 会被转换为 `omit`,其余值原样返回。
+    """
+    if value is None:
+        return omit
+    return value
+
+
+def _build_api_status_message(error: APIStatusError) -> str:
+    """构建更适合记录和展示的状态错误信息。
+
+    Args:
+        error: OpenAI SDK 抛出的状态错误。
+
+    Returns:
+        str: 拼装后的错误信息。
+    """
+    message_parts: List[str] = []
+    if getattr(error, "message", None):
+        message_parts.append(str(error.message))
+    response_text = getattr(getattr(error, "response", None), "text", None)
+    if response_text:
+        message_parts.append(str(response_text)[:300])
+    if message_parts:
+        return " | ".join(message_parts)
+    return f"上游接口返回状态码 {error.status_code}"
+
+
+@dataclass(slots=True)
+class _StreamedToolCallState:
+    """流式工具调用累积状态。"""
+
+    index: int
+    call_id: str = ""
+    function_name: str = ""
+    arguments_buffer: io.StringIO = field(default_factory=io.StringIO)
+
+    def append_arguments(self, arguments_chunk: str) -> None:
+        """追加一段工具调用参数字符串。
+
+        Args:
+            arguments_chunk: 参数增量片段。
+        """
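+        # 说明:流式 tool_call 的参数以 JSON 片段分块到达,这里仅做累积,
+        # 待流结束后在 build_response 中整体解析;片段形如先收到 '{"city": "Par'
+        # 再收到 'is"}'(该示例片段为假设值,仅作演示)。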
+        self.arguments_buffer.write(arguments_chunk)
+
+    def close(self) -> None:
+        """关闭内部缓存。"""
+        if not self.arguments_buffer.closed:
+            self.arguments_buffer.close()
+
+
+class _OpenAIStreamAccumulator:
+    """OpenAI 兼容流式响应累积器。"""
+
+    def __init__(
+        self,
+        reasoning_parse_mode: ReasoningParseMode,
+        tool_argument_parse_mode: ToolArgumentParseMode,
+    ) -> None:
+        """初始化累积器。
+
+        Args:
+            reasoning_parse_mode: 推理内容解析模式。
+            tool_argument_parse_mode: 工具参数解析模式。
+        """
+        self.reasoning_parse_mode = reasoning_parse_mode
+        self.tool_argument_parse_mode = tool_argument_parse_mode
+        self.reasoning_buffer = io.StringIO()
+        self.content_buffer = io.StringIO()
+        self.tool_call_states: Dict[int, _StreamedToolCallState] = {}
+        self.finish_reason: str | None = None
+        self.model_name: str | None = None
+        self._using_native_reasoning = False
+
+    def capture_event_metadata(self, event: ChatCompletionChunk) -> None:
+        """捕获事件中的完成原因和模型名。
+
+        Args:
+            event: 当前流式事件。
+        """
+        if getattr(event, "model", None) and not self.model_name:
+            self.model_name = event.model
+        if getattr(event, "choices", None):
+            finish_reason = getattr(event.choices[0], "finish_reason", None)
+            if finish_reason:
+                self.finish_reason = finish_reason
+
+    def process_delta(self, delta: ChoiceDelta) -> None:
+        """处理一个增量块。
+
+        Args:
+            delta: 当前增量对象。
+        """
+        self._process_reasoning_delta(delta)
+        self._process_tool_call_delta(delta)
+
+    def _process_reasoning_delta(self, delta: ChoiceDelta) -> None:
+        """处理推理内容与正式内容。
+
+        Args:
+            delta: 当前增量对象。
+        """
+        native_reasoning = getattr(delta, "reasoning_content", None)
+        if isinstance(native_reasoning, str) and native_reasoning:
+            self._using_native_reasoning = True
+            if self.reasoning_parse_mode != ReasoningParseMode.NONE:
+                self.reasoning_buffer.write(native_reasoning)
+            return
+
+        content_chunk = getattr(delta, "content", None)
+        if not isinstance(content_chunk, str) or content_chunk == "":
+            return
+
+        # 各解析模式下的 content 统一先累积;<think> 推理块的提取延后到
+        # build_response 中一次性完成,流式阶段不做逐块判定。
+        self.content_buffer.write(content_chunk)
+
+    def _process_tool_call_delta(self, delta: ChoiceDelta) -> None:
+        """处理工具调用增量。
+
+        Args:
+            delta: 当前增量对象。
+        """
+        tool_call_deltas = getattr(delta, "tool_calls", None) or []
+        for tool_call_delta in tool_call_deltas:
+            state = self.tool_call_states.setdefault(
+                tool_call_delta.index, _StreamedToolCallState(index=tool_call_delta.index)
+            )
+            if tool_call_delta.id:
+                state.call_id = tool_call_delta.id
+            function = tool_call_delta.function
+            if function is not None and function.name:
+                state.function_name = function.name
+            if function is not None and function.arguments:
+                state.append_arguments(function.arguments)
+
+    def build_response(self) -> APIResponse:
+        """构建最终 APIResponse 对象。
+
+        Returns:
+            APIResponse: 累积完成的响应对象。
+
+        Raises:
+            EmptyResponseException: 当响应中既无可见内容也无工具调用时抛出。
+            RespParseException: 当工具调用结构不完整时抛出。
+        """
+        response = APIResponse()
+
+        content = self.content_buffer.getvalue().strip()
+        reasoning_content = self.reasoning_buffer.getvalue().strip()
+        if not self._using_native_reasoning and self.reasoning_parse_mode != ReasoningParseMode.NONE and content:
+            parsed_reasoning_content, parsed_content = _extract_reasoning_and_content(
+                content=content,
+                parse_mode=self.reasoning_parse_mode,
+            )
+            if parsed_reasoning_content:
+                reasoning_content = parsed_reasoning_content
+                content = parsed_content or ""
+        if reasoning_content:
response.reasoning_content = reasoning_content + if content: + response.content = content + + if self.tool_call_states: + response.tool_calls = [] + for index in sorted(self.tool_call_states): + state = self.tool_call_states[index] + if not state.function_name: + raise RespParseException(None, f"响应解析失败,工具调用 {index} 缺少函数名。") + raw_arguments = state.arguments_buffer.getvalue().strip() + arguments = ( + _parse_tool_arguments(raw_arguments, self.tool_argument_parse_mode, None) + if raw_arguments + else None ) - else: - logger.warning("工具调用索引号大于等于缓冲区长度,但缺少ID或函数信息。") + call_id = state.call_id or f"tool_call_{index}" + response.tool_calls.append(ToolCall(call_id=call_id, func_name=state.function_name, args=arguments)) - if tool_call_delta.function and tool_call_delta.function.arguments: - # 如果有工具调用参数,则添加到对应的工具调用的参数串缓冲区中 - tool_calls_buffer[tool_call_delta.index][2].write(tool_call_delta.function.arguments) + response.raw_data = {"model": self.model_name} if self.model_name else None - return in_rc_flag + if not response.content and not response.tool_calls: + raise EmptyResponseException() + return response -def _build_stream_api_resp( - _fc_delta_buffer: io.StringIO, - _rc_delta_buffer: io.StringIO, - _tool_calls_buffer: list[tuple[str, str, io.StringIO]], - finish_reason: str | None = None, -) -> APIResponse: - resp = APIResponse() - - if _rc_delta_buffer.tell() > 0: - # 如果推理内容缓冲区不为空,则将其写入APIResponse对象 - resp.reasoning_content = _rc_delta_buffer.getvalue() - _rc_delta_buffer.close() - if _fc_delta_buffer.tell() > 0: - # 如果正式内容缓冲区不为空,则将其写入APIResponse对象 - resp.content = _fc_delta_buffer.getvalue() - _fc_delta_buffer.close() - if _tool_calls_buffer: - # 如果工具调用缓冲区不为空,则将其解析为ToolCall对象列表 - resp.tool_calls = [] - for call_id, function_name, arguments_buffer in _tool_calls_buffer: - if arguments_buffer.tell() > 0: - # 如果参数串缓冲区不为空,则解析为JSON对象 - raw_arg_data = arguments_buffer.getvalue() - arguments_buffer.close() - try: - arguments = json.loads(repair_json(raw_arg_data)) - if not isinstance(arguments, dict): - raise RespParseException( - None, - f"响应解析失败,工具调用参数无法解析为字典类型。工具调用参数原始响应:\n{raw_arg_data}", - ) - except json.JSONDecodeError as e: - raise RespParseException( - None, - f"响应解析失败,无法解析工具调用参数。工具调用参数原始响应:{raw_arg_data}", - ) from e - else: - arguments_buffer.close() - arguments = None - - resp.tool_calls.append(ToolCall(call_id, function_name, arguments)) - - # 检查 max_tokens 截断(流式的告警改由处理函数统一输出,这里不再输出) - # 保留 finish_reason 仅用于上层判断 - - if not resp.content and not resp.tool_calls: - raise EmptyResponseException() - - return resp + def close(self) -> None: + """关闭内部缓冲区。""" + if not self.reasoning_buffer.closed: + self.reasoning_buffer.close() + if not self.content_buffer.closed: + self.content_buffer.close() + for state in self.tool_call_states.values(): + state.close() async def _default_stream_response_handler( resp_stream: AsyncStream[ChatCompletionChunk], interrupt_flag: asyncio.Event | None, -) -> tuple[APIResponse, Optional[tuple[int, int, int]]]: + *, + reasoning_parse_mode: ReasoningParseMode, + tool_argument_parse_mode: ToolArgumentParseMode, +) -> Tuple[APIResponse, UsageTuple | None]: + """处理 OpenAI 兼容流式响应。 + + Args: + resp_stream: OpenAI SDK 返回的流式响应对象。 + interrupt_flag: 外部中断标记。 + reasoning_parse_mode: 推理内容解析模式。 + tool_argument_parse_mode: 工具参数解析模式。 + + Returns: + Tuple[APIResponse, UsageTuple | None]: 解析后的响应与 usage 统计。 """ - 流式响应处理函数 - 处理OpenAI API的流式响应 - :param resp_stream: 流式响应对象 - :return: APIResponse对象 - """ - - _has_rc_attr_flag = False # 标记是否有独立的推理内容块 - _in_rc_flag = False # 标记是否在推理内容块中 - 
_rc_delta_buffer = io.StringIO()  # 推理内容缓冲区,用于存储接收到的推理内容
-    _fc_delta_buffer = io.StringIO()  # 正式内容缓冲区,用于存储接收到的正式内容
-    _tool_calls_buffer: list[tuple[str, str, io.StringIO]] = []  # 工具调用缓冲区,用于存储接收到的工具调用
-    _usage_record = None  # 使用情况记录
-    finish_reason: str | None = None  # 记录最后的 finish_reason
-    _model_name: str | None = None  # 记录模型名
-
-    def _insure_buffer_closed():
-        # 确保缓冲区被关闭
-        if _rc_delta_buffer and not _rc_delta_buffer.closed:
-            _rc_delta_buffer.close()
-        if _fc_delta_buffer and not _fc_delta_buffer.closed:
-            _fc_delta_buffer.close()
-        for _, _, buffer in _tool_calls_buffer:
-            if buffer and not buffer.closed:
-                buffer.close()
-
-    async for event in resp_stream:
-        if interrupt_flag and interrupt_flag.is_set():
-            # 如果中断量被设置,则抛出ReqAbortException
-            _insure_buffer_closed()
-            raise ReqAbortException("请求被外部信号中断")
-        # 空 choices / usage-only 帧的防御
-        if not hasattr(event, "choices") or not event.choices:
-            if hasattr(event, "usage") and event.usage:
-                _usage_record = (
-                    event.usage.prompt_tokens or 0,
-                    event.usage.completion_tokens or 0,
-                    event.usage.total_tokens or 0,
-                )
-            continue  # 跳过本帧,避免访问 choices[0]
-        delta = event.choices[0].delta  # 获取当前块的delta内容
-
-        if hasattr(event.choices[0], "finish_reason") and event.choices[0].finish_reason:
-            finish_reason = event.choices[0].finish_reason
-
-        if hasattr(event, "model") and event.model and not _model_name:
-            _model_name = event.model  # 记录模型名
-
-        if hasattr(delta, "reasoning_content") and delta.reasoning_content:  # type: ignore
-            # 标记:有独立的推理内容块
-            _has_rc_attr_flag = True
-
-        _in_rc_flag = _process_delta(
-            delta,
-            _has_rc_attr_flag,
-            _in_rc_flag,
-            _rc_delta_buffer,
-            _fc_delta_buffer,
-            _tool_calls_buffer,
-        )
-
-        if event.usage:
-            # 如果有使用情况,则将其存储在APIResponse对象中
-            _usage_record = (
-                event.usage.prompt_tokens or 0,
-                event.usage.completion_tokens or 0,
-                event.usage.total_tokens or 0,
-            )
+    accumulator = _OpenAIStreamAccumulator(
+        reasoning_parse_mode=reasoning_parse_mode,
+        tool_argument_parse_mode=tool_argument_parse_mode,
+    )
+    usage_record: UsageTuple | None = None
     try:
-        resp = _build_stream_api_resp(
-            _fc_delta_buffer,
-            _rc_delta_buffer,
-            _tool_calls_buffer,
-            finish_reason=finish_reason,
-        )
-        # 统一在这里输出 max_tokens 截断的警告,并从 resp 中读取
-        if finish_reason == "length":
-            # 把模型名塞到 resp.raw_data,后续严格“从 resp 提取”
-            try:
-                if _model_name:
-                    resp.raw_data = {"model": _model_name}
-            except Exception:
-                pass
-            model_dbg = None
-            try:
-                if isinstance(resp.raw_data, dict):
-                    model_dbg = resp.raw_data.get("model")
-            except Exception:
-                model_dbg = None
+        async for event in resp_stream:
+            if interrupt_flag and interrupt_flag.is_set():
+                raise ReqAbortException("请求被外部信号中断")

-            # 统一日志格式
-            logger.info("模型%s因为超过最大max_token限制,可能仅输出部分内容,可视情况调整" % (model_dbg or ""))
+            accumulator.capture_event_metadata(event)
+            event_usage = _extract_usage_record(getattr(event, "usage", None))
+            if event_usage is not None:
+                usage_record = event_usage

-        return resp, _usage_record
-    except Exception:
-        # 确保缓冲区被关闭
-        _insure_buffer_closed()
-        raise
+            if not getattr(event, "choices", None):
+                continue
+            accumulator.process_delta(event.choices[0].delta)

-pattern = re.compile(
-    r"<think>(?P<think>.*?)</think>(?P<content>.*)|<think>(?P<think_unclosed>.*)|(?P<content_only>.+)",
-    re.DOTALL,
-)
-"""用于解析推理内容的正则表达式"""
+        response = accumulator.build_response()
+        model_name = None
+        if isinstance(response.raw_data, dict):
+            model_name = response.raw_data.get("model")
+        _log_length_truncation(accumulator.finish_reason, model_name)
+        return response, usage_record
+    finally:
+        accumulator.close()


 def _default_normal_response_parser(
     resp: ChatCompletion,
-) ->
tuple[APIResponse, Optional[tuple[int, int, int]]]: - """ - 解析对话补全响应 - 将OpenAI API响应解析为APIResponse对象 - :param resp: 响应对象 - :return: APIResponse对象 - """ - api_response = APIResponse() + *, + reasoning_parse_mode: ReasoningParseMode, + tool_argument_parse_mode: ToolArgumentParseMode, +) -> Tuple[APIResponse, UsageTuple | None]: + """解析 OpenAI 兼容的非流式响应。 - # 兼容部分 OpenAI 兼容服务在空回复时返回 choices=None 的情况 + Args: + resp: OpenAI SDK 返回的聊天补全响应。 + reasoning_parse_mode: 推理内容解析模式。 + tool_argument_parse_mode: 工具参数解析模式。 + + Returns: + Tuple[APIResponse, UsageTuple | None]: 解析后的响应与 usage 统计。 + + Raises: + EmptyResponseException: 当 choices 为空或响应内容为空时抛出。 + """ choices = getattr(resp, "choices", None) if not choices: - try: - model_dbg = getattr(resp, "model", None) - id_dbg = getattr(resp, "id", None) - usage_dbg = None - if hasattr(resp, "usage") and resp.usage: - usage_dbg = { - "prompt": getattr(resp.usage, "prompt_tokens", None), - "completion": getattr(resp.usage, "completion_tokens", None), - "total": getattr(resp.usage, "total_tokens", None), - } - try: - raw_snippet = str(resp)[:300] - except Exception: - raw_snippet = "" - logger.debug(f"empty choices: model={model_dbg} id={id_dbg} usage={usage_dbg} raw≈{raw_snippet}") - except Exception: - # 日志采集失败不应影响控制流 - pass - # 统一抛出可重试的 EmptyResponseException,触发上层重试逻辑 raise EmptyResponseException("响应解析失败,choices 为空或缺失") + + api_response = APIResponse() message_part = choices[0].message + native_reasoning = getattr(message_part, "reasoning_content", None) + message_content = message_part.content if isinstance(message_part.content, str) else None - if hasattr(message_part, "reasoning_content") and message_part.reasoning_content: # type: ignore - # 有有效的推理字段 - api_response.content = message_part.content - api_response.reasoning_content = message_part.reasoning_content # type: ignore - elif message_part.content: - # 提取推理和内容 - match = pattern.match(message_part.content) - if not match: - raise RespParseException(resp, "响应解析失败,无法捕获推理内容和输出内容") - if match.group("think") is not None: - result = match.group("think").strip(), match.group("content").strip() - elif match.group("think_unclosed") is not None: - result = match.group("think_unclosed").strip(), None - else: - result = None, match.group("content_only").strip() - api_response.reasoning_content, api_response.content = result - - # 提取工具调用 - if message_part.tool_calls: - api_response.tool_calls = [] - for call in message_part.tool_calls: - try: - arguments = json.loads(repair_json(call.function.arguments)) - # 【新增修复逻辑】如果解析出来还是字符串,说明发生了双重编码,尝试二次解析 - if isinstance(arguments, str): - try: - # 尝试对字符串内容再次进行修复和解析 - arguments = json.loads(repair_json(arguments)) - except Exception: - # 如果二次解析失败,保留原值,让下方的 isinstance(dict) 抛出更具体的错误 - pass - if not isinstance(arguments, dict): - # 此时为了调试方便,建议打印出 arguments 的类型 - raise RespParseException( - resp, - f"响应解析失败,工具调用参数无法解析为字典类型 type={type(arguments)} arguments={arguments}", - ) - api_response.tool_calls.append(ToolCall(call.id, call.function.name, arguments)) - except json.JSONDecodeError as e: - raise RespParseException(resp, "响应解析失败,无法解析工具调用参数") from e - - # 提取Usage信息 - if resp.usage: - _usage_record = ( - resp.usage.prompt_tokens or 0, - resp.usage.completion_tokens or 0, - resp.usage.total_tokens or 0, + if isinstance(native_reasoning, str) and native_reasoning and reasoning_parse_mode != ReasoningParseMode.NONE: + api_response.reasoning_content = native_reasoning + api_response.content = message_content + elif isinstance(message_content, str) and message_content: + 
reasoning_content, final_content = _extract_reasoning_and_content( + content=message_content, + parse_mode=reasoning_parse_mode, ) - else: - _usage_record = None + api_response.reasoning_content = reasoning_content + api_response.content = final_content - # 将原始响应存储在原始数据中 + tool_calls = getattr(message_part, "tool_calls", None) or [] + if tool_calls: + api_response.tool_calls = [] + for tool_call in tool_calls: + if tool_call.type != "function": + raise RespParseException(resp, f"响应解析失败,暂不支持工具调用类型 {tool_call.type}。") + raw_arguments = tool_call.function.arguments or "" + arguments = _parse_tool_arguments(raw_arguments, tool_argument_parse_mode, resp) + api_response.tool_calls.append( + ToolCall( + call_id=tool_call.id, + func_name=tool_call.function.name, + args=arguments, + ) + ) + + usage_record = _extract_usage_record(getattr(resp, "usage", None)) api_response.raw_data = resp - # 检查 max_tokens 截断 - try: - choice0 = resp.choices[0] - reason = getattr(choice0, "finish_reason", None) - if reason and reason == "length": - # print(resp) - _model_name = resp.model - # 统一日志格式 - logger.info("模型%s因为超过最大max_token限制,可能仅输出部分内容,可视情况调整" % (_model_name or "")) - return api_response, _usage_record - except Exception as e: - logger.debug(f"检查 MAX_TOKENS 截断时异常: {e}") + finish_reason = getattr(resp.choices[0], "finish_reason", None) + _log_length_truncation(finish_reason, getattr(resp, "model", None)) if not api_response.content and not api_response.tool_calls: raise EmptyResponseException() - return api_response, _usage_record + return api_response, usage_record @client_registry.register_client_class("openai") -class OpenaiClient(BaseClient): - def __init__(self, api_provider: APIProvider): +class OpenaiClient(AdapterClient[AsyncStream[ChatCompletionChunk], ChatCompletion]): + """OpenAI 兼容客户端。""" + + client: AsyncOpenAI + reasoning_parse_mode: ReasoningParseMode + tool_argument_parse_mode: ToolArgumentParseMode + + def __init__(self, api_provider: APIProvider) -> None: + """初始化 OpenAI 兼容客户端。 + + Args: + api_provider: API 提供商配置。 + """ super().__init__(api_provider) - self.client: AsyncOpenAI = AsyncOpenAI( - base_url=api_provider.base_url, - api_key=api_provider.api_key, - max_retries=0, + client_config = build_openai_compatible_client_config(api_provider) + self.reasoning_parse_mode = _normalize_reasoning_parse_mode(api_provider.reasoning_parse_mode) + self.tool_argument_parse_mode = _normalize_tool_argument_parse_mode(api_provider.tool_argument_parse_mode) + self.client = AsyncOpenAI( + api_key=client_config.api_key, + organization=api_provider.organization, + project=api_provider.project, + base_url=client_config.base_url, timeout=api_provider.timeout, + max_retries=api_provider.max_retry, + default_headers=client_config.default_headers or None, + default_query=client_config.default_query or None, ) - async def get_response( + def _build_default_stream_response_handler( self, - model_info: ModelInfo, - message_list: list[Message], - tool_options: list[ToolOption] | None = None, - max_tokens: Optional[int] = 1024, - temperature: Optional[float] = 0.7, - response_format: RespFormat | None = None, - stream_response_handler: Optional[ - Callable[ - [AsyncStream[ChatCompletionChunk], asyncio.Event | None], - Coroutine[Any, Any, tuple[APIResponse, Optional[tuple[int, int, int]]]], - ] - ] = None, - async_response_parser: Optional[ - Callable[[ChatCompletion], tuple[APIResponse, Optional[tuple[int, int, int]]]] - ] = None, - interrupt_flag: asyncio.Event | None = None, - extra_params: dict[str, Any] | None = 
None, - ) -> APIResponse: - """ - 获取对话响应 + request: ResponseRequest, + ) -> ProviderStreamResponseHandler[AsyncStream[ChatCompletionChunk]]: + """构建 OpenAI 默认流式响应处理器。 + Args: - model_info: 模型信息 - message_list: 对话体 - tool_options: 工具选项(可选,默认为None) - max_tokens: 最大token数(可选,默认为1024) - temperature: 温度(可选,默认为0.7) - response_format: 响应格式(可选,默认为 NotGiven ) - stream_response_handler: 流式响应处理函数(可选,默认为default_stream_response_handler) - async_response_parser: 响应解析函数(可选,默认为default_response_parser) - interrupt_flag: 中断信号量(可选,默认为None) + request: 统一响应请求对象。 + Returns: - (响应文本, 推理文本, 工具调用, 其他数据) + ProviderStreamResponseHandler[AsyncStream[ChatCompletionChunk]]: 默认流式处理器。 """ - if stream_response_handler is None: - stream_response_handler = _default_stream_response_handler + del request - if async_response_parser is None: - async_response_parser = _default_normal_response_parser + async def default_stream_handler( + resp_stream: AsyncStream[ChatCompletionChunk], + flag: asyncio.Event | None, + ) -> Tuple[APIResponse, UsageTuple | None]: + """包装默认流式解析器。""" + return await _default_stream_response_handler( + resp_stream, + flag, + reasoning_parse_mode=self.reasoning_parse_mode, + tool_argument_parse_mode=self.tool_argument_parse_mode, + ) - # 将messages构造为OpenAI API所需的格式 - messages: Iterable[ChatCompletionMessageParam] = _convert_messages(message_list) - # 将tool_options转换为OpenAI API所需的格式 - tools: Iterable[ChatCompletionToolParam] = _convert_tool_options(tool_options) if tool_options else NOT_GIVEN # type: ignore - openai_response_format = _convert_response_format(response_format) + return default_stream_handler + + def _build_default_response_parser( + self, + request: ResponseRequest, + ) -> ProviderResponseParser[ChatCompletion]: + """构建 OpenAI 默认非流式响应解析器。 + + Args: + request: 统一响应请求对象。 + + Returns: + ProviderResponseParser[ChatCompletion]: 默认非流式解析器。 + """ + del request + + def default_response_parser( + response: ChatCompletion, + ) -> Tuple[APIResponse, UsageTuple | None]: + """包装默认非流式解析器。""" + return _default_normal_response_parser( + response, + reasoning_parse_mode=self.reasoning_parse_mode, + tool_argument_parse_mode=self.tool_argument_parse_mode, + ) + + return default_response_parser + + async def _execute_response_request( + self, + request: ResponseRequest, + stream_response_handler: ProviderStreamResponseHandler[AsyncStream[ChatCompletionChunk]], + response_parser: ProviderResponseParser[ChatCompletion], + ) -> Tuple[APIResponse, UsageTuple | None]: + """执行 OpenAI 兼容的文本/多模态响应请求。 + + Args: + request: 统一响应请求对象。 + stream_response_handler: 流式响应处理器。 + response_parser: 非流式响应解析器。 + + Returns: + Tuple[APIResponse, UsageTuple | None]: 统一响应对象与可选使用量信息。 + """ + model_info = request.model_info + messages: Iterable[ChatCompletionMessageParam] = _convert_messages(request.message_list) + tools: Iterable[ChatCompletionToolParam] | Omit = ( + _convert_tool_options(request.tool_options) if request.tool_options else omit + ) + openai_response_format = _convert_response_format(request.response_format) + request_overrides = split_openai_request_overrides( + request.extra_params, + reserved_body_keys=CHAT_COMPLETIONS_RESERVED_EXTRA_BODY_KEYS, + ) + + temperature_argument = ( + omit if "temperature" in request_overrides.extra_body else _coerce_openai_argument(request.temperature) + ) + max_tokens_argument = ( + omit + if "max_tokens" in request_overrides.extra_body or "max_completion_tokens" in request_overrides.extra_body + else _coerce_openai_argument(request.max_tokens) + ) try: if model_info.force_stream_mode: - 
req_task = asyncio.create_task( + stream_task: asyncio.Task[AsyncStream[ChatCompletionChunk]] = asyncio.create_task( self.client.chat.completions.create( model=model_info.model_identifier, messages=messages, tools=tools, - temperature=temperature, - max_tokens=max_tokens, + temperature=temperature_argument, + max_tokens=max_tokens_argument, stream=True, response_format=openai_response_format, - extra_body=extra_params, + extra_headers=request_overrides.extra_headers or None, + extra_query=request_overrides.extra_query or None, + extra_body=request_overrides.extra_body or None, ) ) - while not req_task.done(): - if interrupt_flag and interrupt_flag.is_set(): - # 如果中断量存在且被设置,则取消任务并抛出异常 - req_task.cancel() - raise ReqAbortException("请求被外部信号中断") - await asyncio.sleep(0.1) # 等待0.1秒后再次检查任务&中断信号量状态 - - resp, usage_record = await stream_response_handler(req_task.result(), interrupt_flag) - else: - # 发送请求并获取响应 - # start_time = time.time() - req_task = asyncio.create_task( - self.client.chat.completions.create( - model=model_info.model_identifier, - messages=messages, - tools=tools, - temperature=temperature, - max_tokens=max_tokens, - stream=False, - response_format=openai_response_format, - extra_body=extra_params, - ) + raw_response = cast( + AsyncStream[ChatCompletionChunk], + await await_task_with_interrupt(stream_task, request.interrupt_flag), ) - while not req_task.done(): - if interrupt_flag and interrupt_flag.is_set(): - # 如果中断量存在且被设置,则取消任务并抛出异常 - req_task.cancel() - raise ReqAbortException("请求被外部信号中断") - await asyncio.sleep(0.1) # 等待0.5秒后再次检查任务&中断信号量状态 + return await stream_response_handler(raw_response, request.interrupt_flag) - # logger. - # logger.debug(f"OpenAI API响应(非流式): {req_task.result()}") - - # logger.info(f"OpenAI请求时间: {model_info.model_identifier} {time.time() - start_time} \n{messages}") - - resp, usage_record = async_response_parser(req_task.result()) - except APIConnectionError as e: - # 重封装APIConnectionError为NetworkConnectionError - raise NetworkConnectionError() from e - except APIStatusError as e: - # 重封装APIError为RespNotOkException - raise RespNotOkException(e.status_code, e.message) from e - - if usage_record: - resp.usage = UsageRecord( - model_name=model_info.name, - provider_name=model_info.api_provider, - prompt_tokens=usage_record[0], - completion_tokens=usage_record[1], - total_tokens=usage_record[2], + completion_task: asyncio.Task[ChatCompletion] = asyncio.create_task( + self.client.chat.completions.create( + model=model_info.model_identifier, + messages=messages, + tools=tools, + temperature=temperature_argument, + max_tokens=max_tokens_argument, + stream=False, + response_format=openai_response_format, + extra_headers=request_overrides.extra_headers or None, + extra_query=request_overrides.extra_query or None, + extra_body=request_overrides.extra_body or None, + ) ) + raw_response = cast( + ChatCompletion, + await await_task_with_interrupt(completion_task, request.interrupt_flag), + ) + return response_parser(raw_response) + except APIConnectionError as exc: + raise NetworkConnectionError(str(exc)) from exc + except APIStatusError as exc: + raise RespNotOkException(exc.status_code, _build_api_status_message(exc)) from exc - # logger.debug(f"OpenAI API响应: {resp}") - - return resp - - async def get_embedding( + async def _execute_embedding_request( self, - model_info: ModelInfo, - embedding_input: str, - extra_params: dict[str, Any] | None = None, - ) -> APIResponse: - """ - 获取文本嵌入 - :param model_info: 模型信息 - :param embedding_input: 嵌入输入文本 - :return: 嵌入响应 + 
request: EmbeddingRequest, + ) -> Tuple[APIResponse, UsageTuple | None]: + """执行 OpenAI 兼容的文本嵌入请求。 + + Args: + request: 统一嵌入请求对象。 + + Returns: + Tuple[APIResponse, UsageTuple | None]: 统一响应对象与可选使用量信息。 """ + model_info = request.model_info + embedding_input = request.embedding_input + extra_params = request.extra_params + request_overrides = split_openai_request_overrides(extra_params) + try: raw_response = await self.client.embeddings.create( model=model_info.model_identifier, input=embedding_input, - extra_body=extra_params, + extra_headers=request_overrides.extra_headers or None, + extra_query=request_overrides.extra_query or None, + extra_body=request_overrides.extra_body or None, ) - except APIConnectionError as e: - # 添加详细的错误信息以便调试 - logger.error(f"OpenAI API连接错误(嵌入模型): {str(e)}") - logger.error(f"错误类型: {type(e)}") - if hasattr(e, "__cause__") and e.__cause__: - logger.error(f"底层错误: {str(e.__cause__)}") - raise NetworkConnectionError() from e - except APIStatusError as e: - # 重封装APIError为RespNotOkException - raise RespNotOkException(e.status_code) from e + except APIConnectionError as exc: + raise NetworkConnectionError(str(exc)) from exc + except APIStatusError as exc: + raise RespNotOkException(exc.status_code, _build_api_status_message(exc)) from exc response = APIResponse() - - # 解析嵌入响应 - if len(raw_response.data) > 0: + if raw_response.data: response.embedding = raw_response.data[0].embedding else: - raise RespParseException( - raw_response, - "响应解析失败,缺失嵌入数据。", - ) + raise RespParseException(raw_response, "响应解析失败,缺失嵌入数据。") - # 解析使用情况 - if hasattr(raw_response, "usage"): - response.usage = UsageRecord( - model_name=model_info.name, - provider_name=model_info.api_provider, - prompt_tokens=raw_response.usage.prompt_tokens or 0, - completion_tokens=getattr(raw_response.usage, "completion_tokens", 0), - total_tokens=raw_response.usage.total_tokens or 0, - ) + usage_record = _extract_usage_record(getattr(raw_response, "usage", None)) + return response, usage_record - return response - - async def get_audio_transcriptions( + async def _execute_audio_transcription_request( self, - model_info: ModelInfo, - audio_base64: str, - extra_params: dict[str, Any] | None = None, - ) -> APIResponse: - """ - 获取音频转录 - :param model_info: 模型信息 - :param audio_base64: base64编码的音频数据 - :extra_params: 附加的请求参数 - :return: 音频转录响应 + request: AudioTranscriptionRequest, + ) -> Tuple[APIResponse, UsageTuple | None]: + """执行 OpenAI 兼容的音频转录请求。 + + Args: + request: 统一音频转录请求对象。 + + Returns: + Tuple[APIResponse, UsageTuple | None]: 统一响应对象与可选使用量信息。 """ + model_info = request.model_info + audio_base64 = request.audio_base64 + extra_params = request.extra_params + request_overrides = split_openai_request_overrides(extra_params) + audio_file: FileTypes = ("audio.wav", io.BytesIO(base64.b64decode(audio_base64))) + try: raw_response = await self.client.audio.transcriptions.create( model=model_info.model_identifier, - file=("audio.wav", io.BytesIO(base64.b64decode(audio_base64))), - extra_body=extra_params, + file=audio_file, + extra_headers=request_overrides.extra_headers or None, + extra_query=request_overrides.extra_query or None, + extra_body=request_overrides.extra_body or None, ) - except APIConnectionError as e: - raise NetworkConnectionError() from e - except APIStatusError as e: - # 重封装APIError为RespNotOkException - raise RespNotOkException(e.status_code) from e - response = APIResponse() - # 解析转录响应 - if hasattr(raw_response, "text"): - response.content = raw_response.text - else: - raise RespParseException( - 
raw_response, - "响应解析失败,缺失转录文本。", - ) - return response + except APIConnectionError as exc: + raise NetworkConnectionError(str(exc)) from exc + except APIStatusError as exc: + raise RespNotOkException(exc.status_code, _build_api_status_message(exc)) from exc - def get_support_image_formats(self) -> list[str]: - """ - 获取支持的图片格式 - :return: 支持的图片格式列表 + response = APIResponse() + transcription_text = raw_response if isinstance(raw_response, str) else getattr(raw_response, "text", None) + if isinstance(transcription_text, str): + response.content = transcription_text + return response, None + raise RespParseException(raw_response, "响应解析失败,缺失转录文本。") + + def get_support_image_formats(self) -> List[str]: + """获取支持的图片格式列表。 + + Returns: + List[str]: 当前客户端支持的图片格式列表。 """ return ["jpg", "jpeg", "png", "webp", "gif"] diff --git a/src/llm_models/openai_compat.py b/src/llm_models/openai_compat.py new file mode 100644 index 00000000..19190e0a --- /dev/null +++ b/src/llm_models/openai_compat.py @@ -0,0 +1,140 @@ +from dataclasses import dataclass, field +from typing import Any, Mapping + +from src.config.model_configs import APIProvider, OpenAICompatibleAuthType + + +@dataclass(slots=True) +class OpenAICompatibleClientConfig: + """OpenAI 兼容客户端的基础配置。""" + + api_key: str + base_url: str + default_headers: dict[str, str] = field(default_factory=dict) + default_query: dict[str, object] = field(default_factory=dict) + + +@dataclass(slots=True) +class OpenAICompatibleRequestOverrides: + """单次请求级别的附加配置。""" + + extra_headers: dict[str, str] = field(default_factory=dict) + extra_query: dict[str, object] = field(default_factory=dict) + extra_body: dict[str, Any] = field(default_factory=dict) + + +def normalize_openai_base_url(base_url: str) -> str: + """规范化 OpenAI 兼容接口的基础地址。 + + Args: + base_url: 原始基础地址。 + + Returns: + str: 去掉尾部斜杠后的地址。 + """ + return base_url.rstrip("/") + + +def _build_auth_header_value(prefix: str, api_key: str) -> str: + """构造鉴权请求头的值。 + + Args: + prefix: 请求头前缀。 + api_key: 实际密钥。 + + Returns: + str: 拼接完成的请求头值。 + """ + normalized_prefix = prefix.strip() + if not normalized_prefix: + return api_key + return f"{normalized_prefix} {api_key}" + + +def build_openai_compatible_client_config(api_provider: APIProvider) -> OpenAICompatibleClientConfig: + """构建 OpenAI 兼容客户端配置。 + + Args: + api_provider: API 提供商配置。 + + Returns: + OpenAICompatibleClientConfig: 可直接用于初始化 SDK 客户端的配置。 + """ + default_headers = dict(api_provider.default_headers) + default_query: dict[str, object] = dict(api_provider.default_query) + client_api_key = api_provider.api_key + + if api_provider.auth_type == OpenAICompatibleAuthType.BEARER: + if ( + api_provider.auth_header_name != "Authorization" + or api_provider.auth_header_prefix.strip() != "Bearer" + ): + client_api_key = "" + default_headers[api_provider.auth_header_name] = _build_auth_header_value( + prefix=api_provider.auth_header_prefix, + api_key=api_provider.api_key, + ) + elif api_provider.auth_type == OpenAICompatibleAuthType.HEADER: + client_api_key = "" + default_headers[api_provider.auth_header_name] = _build_auth_header_value( + prefix=api_provider.auth_header_prefix, + api_key=api_provider.api_key, + ) + elif api_provider.auth_type == OpenAICompatibleAuthType.QUERY: + client_api_key = "" + default_query[api_provider.auth_query_name] = api_provider.api_key + elif api_provider.auth_type == OpenAICompatibleAuthType.NONE: + client_api_key = "" + + return OpenAICompatibleClientConfig( + api_key=client_api_key, + base_url=normalize_openai_base_url(api_provider.base_url), + 
default_headers=default_headers, + default_query=default_query, + ) + + +def _extract_mapping(value: Any) -> dict[str, Any]: + """将任意映射值规范化为普通字典。 + + Args: + value: 原始输入值。 + + Returns: + dict[str, Any]: 规范化后的字典。非映射值时返回空字典。 + """ + if isinstance(value, Mapping): + return {str(key): item for key, item in value.items()} + return {} + + +def split_openai_request_overrides( + extra_params: Mapping[str, Any] | None, + *, + reserved_body_keys: set[str] | None = None, +) -> OpenAICompatibleRequestOverrides: + """拆分单次请求中的头、查询参数和请求体扩展字段。 + + Args: + extra_params: 模型级别或请求级别的附加参数。 + reserved_body_keys: 由 SDK 原生参数承载、因此不应再进入 `extra_body` 的字段集合。 + + Returns: + OpenAICompatibleRequestOverrides: 拆分后的请求覆盖配置。 + """ + raw_params = dict(extra_params or {}) + extra_headers = _extract_mapping(raw_params.pop("headers", None)) + extra_query = _extract_mapping(raw_params.pop("query", None)) + extra_body = _extract_mapping(raw_params.pop("body", None)) + blocked_body_keys = reserved_body_keys or set() + + for key, value in raw_params.items(): + if key in blocked_body_keys: + continue + extra_body[key] = value + + return OpenAICompatibleRequestOverrides( + extra_headers={key: str(value) for key, value in extra_headers.items()}, + extra_query=extra_query, + extra_body=extra_body, + ) diff --git a/src/llm_models/payload_content/message.py b/src/llm_models/payload_content/message.py index 960de08b..8ed392ef 100644 --- a/src/llm_models/payload_content/message.py +++ b/src/llm_models/payload_content/message.py @@ -1,133 +1,280 @@ +from dataclasses import dataclass, field from enum import Enum -from typing import List, Optional +from typing import List, Tuple from .tool_option import ToolCall -# 设计这系列类的目的是为未来可能的扩展做准备 +class RoleType(str, Enum): + """消息角色类型。""" - -class RoleType(Enum): System = "system" User = "user" Assistant = "assistant" Tool = "tool" -SUPPORTED_IMAGE_FORMATS = ["jpg", "jpeg", "png", "webp", "gif"] # openai支持的图片格式 +SUPPORTED_IMAGE_FORMATS = ["jpg", "jpeg", "png", "webp", "gif"] +"""默认支持的图片格式列表。""" +@dataclass(slots=True) +class TextMessagePart: + """文本消息片段。""" + + text: str + + def __post_init__(self) -> None: + """执行文本片段的基础校验。 + + Raises: + ValueError: 当文本为空时抛出。 + """ + if self.text == "": + raise ValueError("文本消息片段不能为空字符串") + + +@dataclass(slots=True) +class ImageMessagePart: + """Base64 图片消息片段。""" + + image_format: str + image_base64: str + + def __post_init__(self) -> None: + """执行图片片段的基础校验。 + + Raises: + ValueError: 当图片格式或 Base64 数据无效时抛出。 + """ + if self.image_format.lower() not in SUPPORTED_IMAGE_FORMATS: + raise ValueError("不受支持的图片格式") + if not self.image_base64: + raise ValueError("图片的 base64 编码不能为空") + + @property + def normalized_image_format(self) -> str: + """获取规范化后的图片格式。 + + Returns: + str: 规范化后的图片格式。`jpg` 会被统一为 `jpeg`。 + """ + image_format = self.image_format.lower() + if image_format in {"jpg", "jpeg"}: + return "jpeg" + return image_format + + +MessagePart = TextMessagePart | ImageMessagePart + + +@dataclass(slots=True) class Message: - def __init__( - self, - role: RoleType, - content: str | list[tuple[str, str] | str], - tool_call_id: str | None = None, - tool_calls: Optional[List[ToolCall]] = None, - ): + """统一消息模型。""" + + role: RoleType + parts: List[MessagePart] = field(default_factory=list) + tool_call_id: str | None = None + tool_calls: List[ToolCall] | None = None + + def __post_init__(self) -> None: + """执行消息对象的基础校验。 + + Raises: + ValueError: 当消息内容或工具调用信息不完整时抛出。 """ - 初始化消息对象 - (不应直接修改Message类,而应使用MessageBuilder类来构建对象) + if not self.parts and not (self.role == RoleType.Assistant 
and self.tool_calls): + raise ValueError("消息内容不能为空") + if self.role == RoleType.Tool and not self.tool_call_id: + raise ValueError("Tool 角色的工具调用 ID 不能为空") + + @property + def content(self) -> str | List[Tuple[str, str] | str]: + """获取兼容旧逻辑的内容视图。 + + Returns: + str | List[Tuple[str, str] | str]: 当仅包含一个文本片段时返回字符串, + 否则返回混合列表,其中图片片段表示为 `(format, base64)` 元组。 """ - self.role: RoleType = role - self.content: str | list[tuple[str, str] | str] = content - self.tool_call_id: str | None = tool_call_id - self.tool_calls: Optional[List[ToolCall]] = tool_calls + if len(self.parts) == 1 and isinstance(self.parts[0], TextMessagePart): + return self.parts[0].text + content_items: List[Tuple[str, str] | str] = [] + for part in self.parts: + if isinstance(part, TextMessagePart): + content_items.append(part.text) + else: + content_items.append((part.image_format, part.image_base64)) + return content_items + + def get_text_content(self) -> str: + """提取消息中的所有文本片段。 + + Returns: + str: 以原始顺序拼接后的文本内容。 + """ + return "".join(part.text for part in self.parts if isinstance(part, TextMessagePart)) def __str__(self) -> str: + """生成便于调试的字符串表示。 + + Returns: + str: 当前消息对象的可读摘要。 + """ return ( - f"Role: {self.role}, Content: {self.content}, " + f"Role: {self.role}, Parts: {self.parts}, " f"Tool Call ID: {self.tool_call_id}, Tool Calls: {self.tool_calls}" ) class MessageBuilder: - def __init__(self): + """消息构建器。""" + + def __init__(self) -> None: + """初始化构建器。""" self.__role: RoleType = RoleType.User - self.__content: list[tuple[str, str] | str] = [] + self.__parts: List[MessagePart] = [] self.__tool_call_id: str | None = None - self.__tool_calls: Optional[List[ToolCall]] = None + self.__tool_calls: List[ToolCall] | None = None def set_role(self, role: RoleType = RoleType.User) -> "MessageBuilder": - """ - 设置角色(默认为User) - :param role: 角色 - :return: MessageBuilder对象 + """设置消息角色。 + + Args: + role: 目标角色,默认为 `user`。 + + Returns: + MessageBuilder: 当前构建器实例。 """ self.__role = role return self + def add_text_part(self, text: str) -> "MessageBuilder": + """追加文本片段。 + + Args: + text: 文本内容。 + + Returns: + MessageBuilder: 当前构建器实例。 + """ + self.__parts.append(TextMessagePart(text=text)) + return self + def add_text_content(self, text: str) -> "MessageBuilder": + """追加文本片段。 + + Args: + text: 文本内容。 + + Returns: + MessageBuilder: 当前构建器实例。 """ - 添加文本内容 - :param text: 文本内容 - :return: MessageBuilder对象 + return self.add_text_part(text) + + def add_image_base64_part( + self, + image_format: str, + image_base64: str, + support_formats: List[str] = SUPPORTED_IMAGE_FORMATS, + ) -> "MessageBuilder": + """追加 Base64 图片片段。 + + Args: + image_format: 图片格式。 + image_base64: 图片的 Base64 编码。 + support_formats: 允许的图片格式列表。 + + Returns: + MessageBuilder: 当前构建器实例。 + + Raises: + ValueError: 当图片格式不被支持时抛出。 """ - self.__content.append(text) + if image_format.lower() not in support_formats: + raise ValueError("不受支持的图片格式") + self.__parts.append(ImageMessagePart(image_format=image_format, image_base64=image_base64)) return self def add_image_content( self, image_format: str, image_base64: str, - support_formats: list[str] = SUPPORTED_IMAGE_FORMATS, # 默认支持格式 + support_formats: List[str] = SUPPORTED_IMAGE_FORMATS, ) -> "MessageBuilder": - """ - 添加图片内容 - :param image_format: 图片格式 - :param image_base64: 图片的base64编码 - :return: MessageBuilder对象 - """ - if image_format.lower() not in support_formats: - raise ValueError("不受支持的图片格式") - if not image_base64: - raise ValueError("图片的base64编码不能为空") - self.__content.append((image_format, image_base64)) - return self + """追加 Base64 
图片片段。 - def add_tool_call(self, tool_call_id: str) -> "MessageBuilder": + Args: + image_format: 图片格式。 + image_base64: 图片的 Base64 编码。 + support_formats: 允许的图片格式列表。 + + Returns: + MessageBuilder: 当前构建器实例。 """ - 添加工具调用指令(调用时请确保已设置为Tool角色) - :param tool_call_id: 工具调用指令的id - :return: MessageBuilder对象 + return self.add_image_base64_part( + image_format=image_format, + image_base64=image_base64, + support_formats=support_formats, + ) + + def set_tool_call_id(self, tool_call_id: str) -> "MessageBuilder": + """设置工具结果消息引用的工具调用 ID。 + + Args: + tool_call_id: 工具调用 ID。 + + Returns: + MessageBuilder: 当前构建器实例。 + + Raises: + ValueError: 当当前角色不是 `tool` 或 ID 为空时抛出。 """ if self.__role != RoleType.Tool: - raise ValueError("仅当角色为Tool时才能添加工具调用ID") + raise ValueError("仅当角色为 Tool 时才能设置工具调用 ID") if not tool_call_id: - raise ValueError("工具调用ID不能为空") + raise ValueError("工具调用 ID 不能为空") self.__tool_call_id = tool_call_id return self - def set_tool_calls(self, tool_calls: List[ToolCall]) -> "MessageBuilder": + def add_tool_call(self, tool_call_id: str) -> "MessageBuilder": + """设置工具结果消息引用的工具调用 ID。 + + Args: + tool_call_id: 工具调用 ID。 + + Returns: + MessageBuilder: 当前构建器实例。 """ - 设置助手消息的工具调用列表 - :param tool_calls: 工具调用列表 - :return: MessageBuilder对象 + return self.set_tool_call_id(tool_call_id) + + def set_tool_calls(self, tool_calls: List[ToolCall]) -> "MessageBuilder": + """设置助手消息中的工具调用列表。 + + Args: + tool_calls: 工具调用列表。 + + Returns: + MessageBuilder: 当前构建器实例。 + + Raises: + ValueError: 当当前角色不是 `assistant` 或列表为空时抛出。 """ if self.__role != RoleType.Assistant: - raise ValueError("仅当角色为Assistant时才能设置工具调用列表") + raise ValueError("仅当角色为 Assistant 时才能设置工具调用列表") if not tool_calls: raise ValueError("工具调用列表不能为空") - self.__tool_calls = tool_calls + self.__tool_calls = list(tool_calls) return self def build(self) -> Message: - """ - 构建消息对象 - :return: Message对象 - """ - if len(self.__content) == 0 and not (self.__role == RoleType.Assistant and self.__tool_calls): - raise ValueError("内容不能为空") - if self.__role == RoleType.Tool and self.__tool_call_id is None: - raise ValueError("Tool角色的工具调用ID不能为空") + """构建消息对象。 + Returns: + Message: 构建完成的消息对象。 + """ return Message( role=self.__role, - content=( - self.__content[0] - if (len(self.__content) == 1 and isinstance(self.__content[0], str)) - else self.__content - ), + parts=list(self.__parts), tool_call_id=self.__tool_call_id, - tool_calls=self.__tool_calls, + tool_calls=list(self.__tool_calls) if self.__tool_calls else None, ) diff --git a/src/llm_models/payload_content/resp_format.py b/src/llm_models/payload_content/resp_format.py index e1baa374..4319b03d 100644 --- a/src/llm_models/payload_content/resp_format.py +++ b/src/llm_models/payload_content/resp_format.py @@ -1,51 +1,40 @@ +from copy import deepcopy from enum import Enum -from typing import Optional, Any +from typing import Any, Dict, List, Mapping, Optional, Type, cast from pydantic import BaseModel -from typing_extensions import TypedDict, Required +from typing_extensions import Required, TypedDict class RespFormatType(Enum): - TEXT = "text" # 文本 - JSON_OBJ = "json_object" # JSON - JSON_SCHEMA = "json_schema" # JSON Schema + """响应格式类型。""" + + TEXT = "text" + JSON_OBJ = "json_object" + JSON_SCHEMA = "json_schema" class JsonSchema(TypedDict, total=False): + """内部使用的 JSON Schema 包装结构。""" + name: Required[str] - """ - The name of the response format. - - Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length - of 64. 
- """ - description: Optional[str] - """ - A description of what the response format is for, used by the model to determine - how to respond in the format. - """ - - schema: dict[str, object] - """ - The schema for the response format, described as a JSON Schema object. Learn how - to build JSON schemas [here](https://json-schema.org/). - """ - + schema: Dict[str, Any] strict: Optional[bool] - """ - Whether to enable strict schema adherence when generating the output. If set to - true, the model will always follow the exact schema defined in the `schema` - field. Only a subset of JSON Schema is supported when `strict` is `true`. To - learn more, read the - [Structured Outputs guide](https://platform.openai.com/docs/guides/structured-outputs). - """ -def _json_schema_type_check(instance) -> str | None: +def _json_schema_type_check(instance: Mapping[str, Any]) -> str | None: + """检查 JSON Schema 包装结构是否合法。 + + Args: + instance: 待检查的 JSON Schema 包装字典。 + + Returns: + str | None: 不合法时返回错误信息,合法时返回 `None`。 + """ if "name" not in instance: return "schema必须包含'name'字段" - elif not isinstance(instance["name"], str) or instance["name"].strip() == "": + if not isinstance(instance["name"], str) or instance["name"].strip() == "": return "schema的'name'字段必须是非空字符串" if "description" in instance and ( not isinstance(instance["description"], str) or instance["description"].strip() == "" @@ -53,164 +42,198 @@ def _json_schema_type_check(instance) -> str | None: return "schema的'description'字段只能填入非空字符串" if "schema" not in instance: return "schema必须包含'schema'字段" - elif not isinstance(instance["schema"], dict): + if not isinstance(instance["schema"], dict): return "schema的'schema'字段必须是字典,详见https://json-schema.org/" if "strict" in instance and not isinstance(instance["strict"], bool): return "schema的'strict'字段只能填入布尔值" - return None -def _remove_title(schema: dict[str, Any] | list[Any]) -> dict[str, Any] | list[Any]: - """ - 递归移除JSON Schema中的title字段 +def _remove_title(schema: Dict[str, Any] | List[Any]) -> Dict[str, Any] | List[Any]: + """递归移除 JSON Schema 中的 `title` 字段。 + + Args: + schema: 待处理的 Schema 树。 + + Returns: + Dict[str, Any] | List[Any]: 处理后的 Schema 树。 """ if isinstance(schema, list): - # 如果当前Schema是列表,则对所有dict/list子元素递归调用 - for idx, item in enumerate(schema): + for index, item in enumerate(schema): if isinstance(item, (dict, list)): - schema[idx] = _remove_title(item) - elif isinstance(schema, dict): - # 是字典,移除title字段,并对所有dict/list子元素递归调用 - if "title" in schema: - del schema["title"] - for key, value in schema.items(): - if isinstance(value, (dict, list)): - schema[key] = _remove_title(value) + schema[index] = _remove_title(item) + return schema + if "title" in schema: + del schema["title"] + for key, value in schema.items(): + if isinstance(value, (dict, list)): + schema[key] = _remove_title(value) return schema -def _link_definitions(schema: dict[str, Any]) -> dict[str, Any]: - """ - 链接JSON Schema中的definitions字段 +def _link_definitions(schema: Dict[str, Any]) -> Dict[str, Any]: + """展开 Schema 中的本地 `$defs`/`$ref` 引用。 + + Args: + schema: 待处理的根 Schema。 + + Returns: + Dict[str, Any]: 展开后的 Schema。 """ def link_definitions_recursive( - path: str, sub_schema: list[Any] | dict[str, Any], defs: dict[str, Any] - ) -> dict[str, Any]: - """ - 递归链接JSON Schema中的definitions字段 - :param path: 当前路径 - :param sub_schema: 子Schema - :param defs: Schema定义集 - :return: + path: str, + sub_schema: Dict[str, Any] | List[Any], + definitions: Dict[str, Any], + ) -> Dict[str, Any] | List[Any]: + """递归展开局部定义。 + + Args: + path: 当前递归路径。 + 
sub_schema: 当前子 Schema。 + definitions: 已收集的定义字典。 + + Returns: + Dict[str, Any] | List[Any]: 展开后的子 Schema。 """ if isinstance(sub_schema, list): - # 如果当前Schema是列表,则遍历每个元素 - for i in range(len(sub_schema)): - if isinstance(sub_schema[i], dict): - sub_schema[i] = link_definitions_recursive(f"{path}/{str(i)}", sub_schema[i], defs) - else: - # 否则为字典 - if "$defs" in sub_schema: - # 如果当前Schema有$def字段,则将其添加到defs中 - key_prefix = f"{path}/$defs/" - for key, value in sub_schema["$defs"].items(): - def_key = key_prefix + key - if def_key not in defs: - defs[def_key] = value - del sub_schema["$defs"] - if "$ref" in sub_schema: - # 如果当前Schema有$ref字段,则将其替换为defs中的定义 - def_key = sub_schema["$ref"] - if def_key in defs: - sub_schema = defs[def_key] - else: - raise ValueError(f"Schema中引用的定义'{def_key}'不存在") - # 遍历键值对 - for key, value in sub_schema.items(): - if isinstance(value, (dict, list)): - # 如果当前值是字典或列表,则递归调用 - sub_schema[key] = link_definitions_recursive(f"{path}/{key}", value, defs) + for index, item in enumerate(sub_schema): + if isinstance(item, (dict, list)): + sub_schema[index] = link_definitions_recursive(f"{path}/{index}", item, definitions) + return sub_schema + if "$defs" in sub_schema: + key_prefix = f"{path}/$defs/" + defs_payload = cast(Dict[str, Any], sub_schema["$defs"]) + for key, value in defs_payload.items(): + definition_key = key_prefix + key + if definition_key not in definitions: + definitions[definition_key] = value + del sub_schema["$defs"] + + if "$ref" in sub_schema: + definition_key = cast(str, sub_schema["$ref"]) + if definition_key in definitions: + return definitions[definition_key] + raise ValueError(f"Schema中引用的定义'{definition_key}'不存在") + + for key, value in sub_schema.items(): + if isinstance(value, (dict, list)): + sub_schema[key] = link_definitions_recursive(f"{path}/{key}", value, definitions) return sub_schema - return link_definitions_recursive("#", schema, {}) + return cast(Dict[str, Any], link_definitions_recursive("#", schema, {})) -def _remove_defs(schema: dict[str, Any]) -> dict[str, Any]: - """ - 递归移除JSON Schema中的$defs字段 +def _remove_defs(schema: Dict[str, Any] | List[Any]) -> Dict[str, Any] | List[Any]: + """递归移除 JSON Schema 中的 `$defs` 字段。 + + Args: + schema: 待处理的 Schema 树。 + + Returns: + Dict[str, Any] | List[Any]: 处理后的 Schema 树。 """ if isinstance(schema, list): - # 如果当前Schema是列表,则对所有dict/list子元素递归调用 - for idx, item in enumerate(schema): + for index, item in enumerate(schema): if isinstance(item, (dict, list)): - schema[idx] = _remove_title(item) - elif isinstance(schema, dict): - # 是字典,移除title字段,并对所有dict/list子元素递归调用 - if "$defs" in schema: - del schema["$defs"] - for key, value in schema.items(): - if isinstance(value, (dict, list)): - schema[key] = _remove_title(value) + schema[index] = _remove_defs(item) + return schema + if "$defs" in schema: + del schema["$defs"] + for key, value in schema.items(): + if isinstance(value, (dict, list)): + schema[key] = _remove_defs(value) return schema class RespFormat: - """ - 响应格式 - """ + """统一响应格式定义。""" @staticmethod - def _generate_schema_from_model(schema): - json_schema = { - "name": schema.__name__, - "schema": _remove_defs(_link_definitions(_remove_title(schema.model_json_schema()))), + def _generate_schema_from_model(schema_model: Type[BaseModel]) -> JsonSchema: + """从 Pydantic 模型生成内部 JSON Schema 包装结构。 + + Args: + schema_model: Pydantic 模型类。 + + Returns: + JsonSchema: 内部统一 JSON Schema 包装结构。 + """ + schema_tree = deepcopy(schema_model.model_json_schema()) + json_schema: JsonSchema = { + "name": 
schema_model.__name__, + "schema": cast( + Dict[str, Any], + _remove_defs(_link_definitions(cast(Dict[str, Any], _remove_title(schema_tree)))), + ), "strict": False, } - if schema.__doc__: - json_schema["description"] = schema.__doc__ + if schema_model.__doc__: + json_schema["description"] = schema_model.__doc__ return json_schema def __init__( self, format_type: RespFormatType = RespFormatType.TEXT, - schema: type | JsonSchema | None = None, - ): - """ - 响应格式 - :param format_type: 响应格式类型(默认为文本) - :param schema: 模板类或JsonSchema(仅当format_type为JSON Schema时有效) + schema: Type[BaseModel] | JsonSchema | None = None, + ) -> None: + """初始化响应格式对象。 + + Args: + format_type: 响应格式类型。 + schema: 模型类或 JSON Schema 包装结构,仅 `JSON_SCHEMA` 模式使用。 """ self.format_type: RespFormatType = format_type + self.schema_source: Type[BaseModel] | JsonSchema | None = schema + self.schema: JsonSchema | None = None - if format_type == RespFormatType.JSON_SCHEMA: - if schema is None: - raise ValueError("当format_type为'JSON_SCHEMA'时,schema不能为空") - if isinstance(schema, dict): - if check_msg := _json_schema_type_check(schema): - raise ValueError(f"schema格式不正确,{check_msg}") + if format_type != RespFormatType.JSON_SCHEMA: + return + if schema is None: + raise ValueError("当format_type为'JSON_SCHEMA'时,schema不能为空") + if isinstance(schema, dict): + if check_msg := _json_schema_type_check(schema): + raise ValueError(f"schema格式不正确,{check_msg}") + self.schema = cast(JsonSchema, deepcopy(schema)) + return + if isinstance(schema, type) and issubclass(schema, BaseModel): + try: + self.schema = self._generate_schema_from_model(schema) + except Exception as exc: + raise ValueError( + f"自动生成JSON Schema时发生异常,请检查模型类{schema.__name__}的定义,详细信息:\n" + f"{schema.__name__}:\n" + ) from exc + return + raise ValueError("schema必须是BaseModel的子类或JsonSchema") - self.schema = schema - elif issubclass(schema, BaseModel): - try: - json_schema = self._generate_schema_from_model(schema) + def get_schema_object(self) -> Dict[str, Any] | None: + """获取内部包装中的对象级 JSON Schema。 - self.schema = json_schema - except Exception as e: - raise ValueError( - f"自动生成JSON Schema时发生异常,请检查模型类{schema.__name__}的定义,详细信息:\n" - f"{schema.__name__}:\n" - ) from e - else: - raise ValueError("schema必须是BaseModel的子类或JsonSchema") - else: - self.schema = None - - def to_dict(self): + Returns: + Dict[str, Any] | None: 对象级 JSON Schema;不存在时返回 `None`。 """ - 将响应格式转换为字典 - :return: 字典 + if self.schema is None: + return None + schema_payload = self.schema.get("schema") + if isinstance(schema_payload, dict): + return cast(Dict[str, Any], deepcopy(schema_payload)) + return None + + def to_dict(self) -> Dict[str, Any]: + """将响应格式转换为字典。 + + Returns: + Dict[str, Any]: 序列化后的响应格式字典。 """ if self.schema: return { "format_type": self.format_type.value, "schema": self.schema, } - else: - return { - "format_type": self.format_type.value, - } + return { + "format_type": self.format_type.value, + } diff --git a/src/llm_models/payload_content/tool_option.py b/src/llm_models/payload_content/tool_option.py index 9fedbc86..ac5224cc 100644 --- a/src/llm_models/payload_content/tool_option.py +++ b/src/llm_models/payload_content/tool_option.py @@ -1,83 +1,368 @@ +from copy import deepcopy +from dataclasses import dataclass, field from enum import Enum +from typing import Any, Dict, List, Tuple, TypeAlias, cast -class ToolParamType(Enum): +class ToolParamType(str, Enum): + """工具参数类型。""" + + STRING = "string" + INTEGER = "integer" + NUMBER = "number" + FLOAT = "number" + BOOLEAN = "boolean" + ARRAY = "array" + OBJECT = "object" 
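
Note: because the enum now mixes in `str`, members compare equal to their raw string values, and `FLOAT`, declared with the same value as `NUMBER`, becomes an enum alias rather than a distinct member. Below is a minimal usage sketch of how this enum is expected to compose with the `normalize_tool_param_type` helper and the reworked `ToolOptionBuilder` introduced later in this patch (the tool name and descriptions are hypothetical; the import path assumes the repository layout shown in this diff):

    from src.llm_models.payload_content.tool_option import (
        ToolOptionBuilder,
        ToolParamType,
        normalize_tool_param_type,
    )

    # str mixin: members compare equal to their raw string values.
    assert ToolParamType.NUMBER == "number"
    # FLOAT shares NUMBER's value, so it is an alias, not a separate member.
    assert ToolParamType.FLOAT is ToolParamType.NUMBER

    # Loose inputs collapse onto canonical members; unknown values fall back to STRING.
    assert normalize_tool_param_type("int") is ToolParamType.INTEGER
    assert normalize_tool_param_type(None) is ToolParamType.STRING

    # Builder -> OpenAI function-calling schema round trip (hypothetical tool).
    tool = (
        ToolOptionBuilder()
        .set_name("search")
        .set_description("Full-text search over the knowledge base")
        .add_param("query", ToolParamType.STRING, "Search keywords", required=True)
        .build()
    )
    assert tool.to_openai_function_schema()["function"]["parameters"]["required"] == ["query"]
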
+ + +LegacyToolParameterTuple = Tuple[str, ToolParamType, str, bool, List[str] | None] +"""旧版工具参数元组格式。""" + + +def normalize_tool_param_type(raw_value: ToolParamType | str | None) -> ToolParamType: + """将任意输入值规范化为内部工具参数类型。 + + Args: + raw_value: 原始参数类型值。 + + Returns: + ToolParamType: 规范化后的参数类型。未知值会回退为 `STRING`。 """ - 工具调用参数类型 + if isinstance(raw_value, ToolParamType): + return raw_value + + normalized_value = str(raw_value or "").strip().lower() + if normalized_value in {"integer", "int"}: + return ToolParamType.INTEGER + if normalized_value in {"number", "float"}: + return ToolParamType.NUMBER + if normalized_value in {"boolean", "bool"}: + return ToolParamType.BOOLEAN + if normalized_value == "array": + return ToolParamType.ARRAY + if normalized_value == "object": + return ToolParamType.OBJECT + return ToolParamType.STRING + + +def _is_object_schema(schema: Dict[str, Any]) -> bool: + """判断输入字典是否已经是对象级 JSON Schema。 + + Args: + schema: 待判断的字典。 + + Returns: + bool: 为对象级 JSON Schema 时返回 `True`。 """ - - STRING = "string" # 字符串 - INTEGER = "integer" # 整型 - FLOAT = "float" # 浮点型 - BOOLEAN = "bool" # 布尔型 + return schema.get("type") == "object" or "properties" in schema or "required" in schema +def _build_parameters_schema_from_property_map(property_map: Dict[str, Any]) -> Dict[str, Any]: + """将属性映射转换为对象级 JSON Schema。 + + Args: + property_map: 仅包含属性定义的映射。 + + Returns: + Dict[str, Any]: 对象级 JSON Schema。 + """ + required_names: List[str] = [] + normalized_properties: Dict[str, Any] = {} + for property_name, property_schema in property_map.items(): + if not isinstance(property_schema, dict): + continue + + property_schema_copy = deepcopy(property_schema) + is_required = bool(property_schema_copy.pop("required", False)) + if is_required: + required_names.append(str(property_name)) + normalized_properties[str(property_name)] = property_schema_copy + + parameters_schema: Dict[str, Any] = { + "type": "object", + "properties": normalized_properties, + } + if required_names: + parameters_schema["required"] = required_names + return parameters_schema + + +@dataclass(slots=True) class ToolParam: - """ - 工具调用参数 - """ + """工具参数定义。""" - def __init__( - self, + name: str + param_type: ToolParamType + description: str + required: bool + enum_values: List[Any] | None = None + items_schema: Dict[str, Any] | None = None + properties: Dict[str, Dict[str, Any]] | None = None + required_properties: List[str] = field(default_factory=list) + additional_properties: bool | Dict[str, Any] | None = None + default: Any = None + + def __post_init__(self) -> None: + """执行参数定义的基础校验。 + + Raises: + ValueError: 当参数名称或复杂类型定义不合法时抛出。 + """ + if not self.name: + raise ValueError("参数名称不能为空") + if self.param_type == ToolParamType.ARRAY and self.items_schema is None: + raise ValueError("数组参数必须提供 items_schema") + if self.param_type == ToolParamType.OBJECT and self.properties is None: + self.properties = {} + + @classmethod + def from_legacy_tuple(cls, parameter: LegacyToolParameterTuple) -> "ToolParam": + """从旧版五元组参数定义构建工具参数。 + + Args: + parameter: 旧版参数元组。 + + Returns: + ToolParam: 规范化后的工具参数对象。 + """ + return cls( + name=parameter[0], + param_type=parameter[1], + description=parameter[2], + required=parameter[3], + enum_values=parameter[4], + ) + + @classmethod + def from_dict( + cls, name: str, - param_type: ToolParamType, - description: str, - required: bool, - enum_values: list[str] | None = None, - ): + parameter_schema: Dict[str, Any], + *, + required: bool = False, + ) -> "ToolParam": + """从属性级 JSON Schema 或结构化参数字典构建工具参数。 + + Args: 
+ name: 参数名称。 + parameter_schema: 参数对应的 Schema 或结构化定义。 + required: 参数是否必填。 + + Returns: + ToolParam: 规范化后的工具参数对象。 """ - 初始化工具调用参数 - (不应直接修改ToolParam类,而应使用ToolOptionBuilder类来构建对象) - :param name: 参数名称 - :param param_type: 参数类型 - :param description: 参数描述 - :param required: 是否必填 + raw_required_properties = parameter_schema.get("required_properties") + if raw_required_properties is None and isinstance(parameter_schema.get("required"), list): + raw_required_properties = parameter_schema.get("required") + return cls( + name=name, + param_type=normalize_tool_param_type(parameter_schema.get("param_type") or parameter_schema.get("type")), + description=str(parameter_schema.get("description", "") or ""), + required=required, + enum_values=deepcopy(parameter_schema.get("enum_values") or parameter_schema.get("enum")), + items_schema=deepcopy(parameter_schema.get("items_schema") or parameter_schema.get("items")), + properties=deepcopy(parameter_schema.get("properties")), + required_properties=list(raw_required_properties or []), + additional_properties=deepcopy( + parameter_schema["additional_properties"] + if "additional_properties" in parameter_schema + else parameter_schema.get("additionalProperties") + ), + default=deepcopy(parameter_schema.get("default")), + ) + + def to_json_schema(self) -> Dict[str, Any]: + """将参数定义转换为 JSON Schema。 + + Returns: + Dict[str, Any]: 参数对应的 JSON Schema 片段。 """ - self.name: str = name - self.param_type: ToolParamType = param_type - self.description: str = description - self.required: bool = required - self.enum_values: list[str] | None = enum_values + schema: Dict[str, Any] = { + "type": self.param_type.value, + "description": self.description, + } + if self.enum_values: + schema["enum"] = list(self.enum_values) + if self.default is not None: + schema["default"] = deepcopy(self.default) + if self.param_type == ToolParamType.ARRAY and self.items_schema is not None: + schema["items"] = deepcopy(self.items_schema) + if self.param_type == ToolParamType.OBJECT: + schema["properties"] = deepcopy(self.properties or {}) + if self.required_properties: + schema["required"] = list(self.required_properties) + if self.additional_properties is not None: + schema["additionalProperties"] = deepcopy(self.additional_properties) + return schema +@dataclass(slots=True) class ToolOption: - """ - 工具调用项 - """ + """工具定义。""" - def __init__( - self, - name: str, - description: str, - params: list[ToolParam] | None = None, - ): + name: str + description: str + params: List[ToolParam] | None = None + parameters_schema_override: Dict[str, Any] | None = None + + def __post_init__(self) -> None: + """执行工具定义的基础校验。 + + Raises: + ValueError: 当工具名称、描述或参数 Schema 不合法时抛出。 """ - 初始化工具调用项 - (不应直接修改ToolOption类,而应使用ToolOptionBuilder类来构建对象) - :param name: 工具名称 - :param description: 工具描述 - :param params: 工具参数列表 + if not self.name: + raise ValueError("工具名称不能为空") + if not self.description: + raise ValueError("工具描述不能为空") + if self.parameters_schema_override is not None: + schema_type = self.parameters_schema_override.get("type") + if schema_type != "object": + raise ValueError("工具参数 Schema 必须是 object 类型") + + @classmethod + def from_definition(cls, definition: Dict[str, Any]) -> "ToolOption": + """从任意支持的工具定义字典构建内部工具对象。 + + 支持以下输入形状: + - `{"name", "description", "parameters_schema"}` + - `{"name", "description", "parameters"}` + - OpenAI function tool:`{"type": "function", "function": {...}}` + - 仅属性映射的对象参数定义:`{"query": {"type": "string"}}` + + Args: + definition: 原始工具定义字典。 + + Returns: + ToolOption: 规范化后的工具定义对象。 + + 
Raises: + ValueError: 当工具定义缺少必要字段时抛出。 """ - self.name: str = name - self.description: str = description - self.params: list[ToolParam] | None = params + if definition.get("type") == "function" and isinstance(definition.get("function"), dict): + function_definition = cast(Dict[str, Any], definition["function"]) + return cls.from_definition( + { + "name": function_definition.get("name", ""), + "description": function_definition.get("description", ""), + "parameters_schema": function_definition.get("parameters"), + } + ) + + name = str(definition.get("name", "") or "").strip() + description = str(definition.get("description", "") or "").strip() + if not name: + raise ValueError("工具定义缺少 name") + if not description: + description = f"工具 {name}" + + parameters_schema = definition.get("parameters_schema") + if isinstance(parameters_schema, dict): + normalized_schema = deepcopy(parameters_schema) + if not _is_object_schema(normalized_schema): + normalized_schema = _build_parameters_schema_from_property_map(normalized_schema) + return cls( + name=name, + description=description, + params=None, + parameters_schema_override=normalized_schema, + ) + + raw_parameters = definition.get("parameters") + if isinstance(raw_parameters, dict): + normalized_schema = deepcopy(raw_parameters) + if not _is_object_schema(normalized_schema): + normalized_schema = _build_parameters_schema_from_property_map(normalized_schema) + return cls( + name=name, + description=description, + params=None, + parameters_schema_override=normalized_schema, + ) + + if isinstance(raw_parameters, list): + params: List[ToolParam] = [] + for raw_parameter in raw_parameters: + if isinstance(raw_parameter, tuple) and len(raw_parameter) == 5: + params.append(ToolParam.from_legacy_tuple(raw_parameter)) + continue + if isinstance(raw_parameter, dict): + parameter_name = str(raw_parameter.get("name", "") or "").strip() + if not parameter_name: + continue + params.append( + ToolParam.from_dict( + parameter_name, + raw_parameter, + required=bool(raw_parameter.get("required", False)), + ) + ) + return cls( + name=name, + description=description, + params=params or None, + parameters_schema_override=None, + ) + + return cls(name=name, description=description, params=None, parameters_schema_override=None) + + @property + def parameters_schema(self) -> Dict[str, Any] | None: + """获取工具参数的对象级 JSON Schema。 + + Returns: + Dict[str, Any] | None: 工具参数 Schema。无参数工具时返回 `None`。 + """ + if self.parameters_schema_override is not None: + return deepcopy(self.parameters_schema_override) + if not self.params: + return None + return { + "type": "object", + "properties": {param.name: param.to_json_schema() for param in self.params}, + "required": [param.name for param in self.params if param.required], + } + + def to_openai_function_schema(self) -> Dict[str, Any]: + """转换为 OpenAI function calling 结构。 + + Returns: + Dict[str, Any]: OpenAI 兼容的工具定义。 + """ + function_schema: Dict[str, Any] = { + "name": self.name, + "description": self.description, + } + if self.parameters_schema is not None: + function_schema["parameters"] = self.parameters_schema + return { + "type": "function", + "function": function_schema, + } class ToolOptionBuilder: - """ - 工具调用项构建器 - """ + """工具定义构建器。""" - def __init__(self): + def __init__(self) -> None: + """初始化构建器。""" self.__name: str = "" self.__description: str = "" - self.__params: list[ToolParam] = [] + self.__params: List[ToolParam] = [] + self.__parameters_schema_override: Dict[str, Any] | None = None def set_name(self, name: str) -> 
"ToolOptionBuilder": - """ - 设置工具名称 - :param name: 工具名称 - :return: ToolBuilder实例 + """设置工具名称。 + + Args: + name: 工具名称。 + + Returns: + ToolOptionBuilder: 当前构建器实例。 + + Raises: + ValueError: 当名称为空时抛出。 """ if not name: raise ValueError("工具名称不能为空") @@ -85,35 +370,76 @@ class ToolOptionBuilder: return self def set_description(self, description: str) -> "ToolOptionBuilder": - """ - 设置工具描述 - :param description: 工具描述 - :return: ToolBuilder实例 + """设置工具描述。 + + Args: + description: 工具描述。 + + Returns: + ToolOptionBuilder: 当前构建器实例。 + + Raises: + ValueError: 当描述为空时抛出。 """ if not description: raise ValueError("工具描述不能为空") self.__description = description return self + def set_parameters_schema(self, schema: Dict[str, Any]) -> "ToolOptionBuilder": + """直接设置完整的参数对象 Schema。 + + Args: + schema: 完整的对象级 JSON Schema。 + + Returns: + ToolOptionBuilder: 当前构建器实例。 + + Raises: + ValueError: 当 schema 不是 object 类型时抛出。 + """ + if schema.get("type") != "object": + raise ValueError("工具参数 Schema 必须是 object 类型") + self.__parameters_schema_override = deepcopy(schema) + self.__params.clear() + return self + def add_param( self, name: str, param_type: ToolParamType, description: str, required: bool = False, - enum_values: list[str] | None = None, + enum_values: List[Any] | None = None, + *, + items_schema: Dict[str, Any] | None = None, + properties: Dict[str, Dict[str, Any]] | None = None, + required_properties: List[str] | None = None, + additional_properties: bool | Dict[str, Any] | None = None, + default: Any = None, ) -> "ToolOptionBuilder": - """ - 添加工具参数 - :param name: 参数名称 - :param param_type: 参数类型 - :param description: 参数描述 - :param required: 是否必填(默认为False) - :return: ToolBuilder实例 - """ - if not name or not description: - raise ValueError("参数名称/描述不能为空") + """添加一个参数定义。 + Args: + name: 参数名称。 + param_type: 参数类型。 + description: 参数描述。 + required: 参数是否必填。 + enum_values: 可选的枚举值列表。 + items_schema: 数组参数的元素 Schema。 + properties: 对象参数的属性定义。 + required_properties: 对象参数内部的必填字段。 + additional_properties: 对象参数是否允许额外字段。 + default: 参数默认值。 + + Returns: + ToolOptionBuilder: 当前构建器实例。 + + Raises: + ValueError: 当构建器已经设置完整 Schema 时抛出。 + """ + if self.__parameters_schema_override is not None: + raise ValueError("已设置完整参数 Schema,不能再逐项添加参数") self.__params.append( ToolParam( name=name, @@ -121,43 +447,83 @@ class ToolOptionBuilder: description=description, required=required, enum_values=enum_values, + items_schema=deepcopy(items_schema), + properties=deepcopy(properties), + required_properties=list(required_properties or []), + additional_properties=deepcopy(additional_properties), + default=deepcopy(default), ) ) - return self - def build(self): - """ - 构建工具调用项 - :return: 工具调用项 - """ - if self.__name == "" or self.__description == "": - raise ValueError("工具名称/描述不能为空") + def build(self) -> ToolOption: + """构建工具定义。 + Returns: + ToolOption: 构建完成的工具定义。 + + Raises: + ValueError: 当工具名称或描述缺失时抛出。 + """ + if not self.__name or not self.__description: + raise ValueError("工具名称和描述不能为空") return ToolOption( name=self.__name, description=self.__description, - params=None if len(self.__params) == 0 else self.__params, + params=None if not self.__params else list(self.__params), + parameters_schema_override=deepcopy(self.__parameters_schema_override), ) -class ToolCall: - """ - 来自模型反馈的工具调用 - """ +ToolDefinitionInput: TypeAlias = ToolOption | Dict[str, Any] +"""统一的工具定义输入类型。""" - def __init__( - self, - call_id: str, - func_name: str, - args: dict | None = None, - ): + +def normalize_tool_option(tool_definition: ToolDefinitionInput) -> ToolOption: + 
"""将任意支持的工具输入规范化为内部 `ToolOption`。 + + Args: + tool_definition: 原始工具定义输入。 + + Returns: + ToolOption: 规范化后的工具定义对象。 + """ + if isinstance(tool_definition, ToolOption): + return tool_definition + return ToolOption.from_definition(tool_definition) + + +def normalize_tool_options( + tool_definitions: List[ToolDefinitionInput] | None, +) -> List[ToolOption] | None: + """批量规范化工具定义列表。 + + Args: + tool_definitions: 原始工具定义列表。 + + Returns: + List[ToolOption] | None: 规范化后的工具列表;输入为空时返回 `None`。 + """ + if not tool_definitions: + return None + return [normalize_tool_option(tool_definition) for tool_definition in tool_definitions] + + +@dataclass(slots=True) +class ToolCall: + """来自模型输出的工具调用。""" + + call_id: str + func_name: str + args: Dict[str, Any] | None = None + + def __post_init__(self) -> None: + """执行工具调用的基础校验。 + + Raises: + ValueError: 当工具调用标识或函数名缺失时抛出。 """ - 初始化工具调用 - :param call_id: 工具调用ID - :param func_name: 要调用的函数名称 - :param args: 工具调用参数 - """ - self.call_id: str = call_id - self.func_name: str = func_name - self.args: dict | None = args + if not self.call_id: + raise ValueError("工具调用 ID 不能为空") + if not self.func_name: + raise ValueError("工具函数名称不能为空") diff --git a/src/llm_models/utils_model.py b/src/llm_models/utils_model.py index c84a4f34..84af5052 100644 --- a/src/llm_models/utils_model.py +++ b/src/llm_models/utils_model.py @@ -1,29 +1,49 @@ -import re -import asyncio -import time -import random -import json - +from dataclasses import dataclass from enum import Enum +from typing import Any, Callable, Dict, List, Optional, Set, Tuple + from rich.traceback import install -from typing import Tuple, List, Dict, Optional, Callable, Any, Set + +import asyncio +import random +import re +import time import traceback from src.common.logger import get_logger +from src.common.data_models.llm_service_data_models import ( + LLMAudioTranscriptionResult, + LLMEmbeddingResult, + LLMResponseResult, +) from src.config.config import config_manager from src.config.model_configs import APIProvider, ModelInfo, TaskConfig -from .payload_content.message import MessageBuilder, Message -from .payload_content.resp_format import RespFormat, RespFormatType -from .payload_content.tool_option import ToolOption, ToolCall, ToolOptionBuilder, ToolParamType -from .model_client.base_client import BaseClient, APIResponse, client_registry -from .model_client import ensure_configured_clients_loaded -from .utils import compress_messages, llm_usage_recorder -from .exceptions import ( - NetworkConnectionError, - RespNotOkException, +from src.llm_models.exceptions import ( EmptyResponseException, ModelAttemptFailed, + NetworkConnectionError, + RespNotOkException, + RespParseException, ) +from src.llm_models.model_client import ensure_configured_clients_loaded +from src.llm_models.model_client.base_client import ( + APIResponse, + AudioTranscriptionRequest, + BaseClient, + ClientRequest, + EmbeddingRequest, + ResponseRequest, + client_registry, +) +from src.llm_models.payload_content.message import Message, MessageBuilder +from src.llm_models.payload_content.resp_format import RespFormat +from src.llm_models.payload_content.tool_option import ( + ToolCall, + ToolDefinitionInput, + ToolOption, + normalize_tool_options, +) +from src.llm_models.utils import compress_messages, llm_usage_recorder install(extra_lines=3) @@ -38,106 +58,69 @@ class RequestType(Enum): AUDIO = "audio" -class LLMRequest: - """LLM请求类""" +@dataclass(slots=True) +class LLMExecutionResult: + """单次模型执行结果。""" - def __init__(self, model_set: TaskConfig, request_type: 
str = "") -> None: - self.task_name = request_type - self.model_for_task = model_set + api_response: APIResponse + model_info: ModelInfo + + +class LLMOrchestrator: + """LLM 编排调度器。""" + + def __init__(self, task_name: str, request_type: str = "") -> None: + """初始化 LLM 请求调度器。 + + Args: + task_name: 任务配置名称,对应 `model_task_config` 下的字段名。 + request_type: 当前请求的业务类型标识。 + """ + self.task_name = task_name.strip() self.request_type = request_type - self._task_config_signature = self._build_task_config_signature(model_set) - self._task_config_name = self._resolve_task_config_name(model_set) + self.model_for_task = self._get_task_config_or_raise() self.model_usage: Dict[str, Tuple[int, int, int]] = { model: (0, 0, 0) for model in self.model_for_task.model_list } """模型使用量记录,用于进行负载均衡,对应为(total_tokens, penalty, usage_penalty),惩罚值是为了能在某个模型请求不给力或正在被使用的时候进行调整""" - @staticmethod - def _build_task_config_signature(model_set: TaskConfig) -> tuple: - return ( - tuple(model_set.model_list), - model_set.selection_strategy, - model_set.temperature, - model_set.max_tokens, - model_set.slow_threshold, - ) + def _get_task_config_or_raise(self) -> TaskConfig: + """获取当前任务名对应的最新任务配置。 - @staticmethod - def _iter_task_config_items(model_task_config: Any) -> list[tuple[str, TaskConfig]]: - cls = type(model_task_config) - if hasattr(cls, "model_fields"): - attrs = [name for name in cls.model_fields.keys() if not name.startswith("__")] - else: - attrs = [name for name in dir(model_task_config) if not name.startswith("__")] + Returns: + TaskConfig: 当前任务对应的最新任务配置对象。 - items: list[tuple[str, TaskConfig]] = [] - for attr in attrs: - value = getattr(model_task_config, attr, None) - if isinstance(value, TaskConfig): - items.append((attr, value)) - return items + Raises: + ValueError: 当任务名为空或对应配置不存在时抛出。 + """ + if not self.task_name: + raise ValueError("任务配置名称不能为空") - def _resolve_task_config_by_signature(self, model_set: TaskConfig) -> Optional[str]: - target_signature = self._build_task_config_signature(model_set) model_task_config = config_manager.get_model_config().model_task_config - return next( - ( - attr - for attr, value in self._iter_task_config_items(model_task_config) - if self._build_task_config_signature(value) == target_signature - ), - None, - ) - - def _resolve_task_config_name(self, model_set: TaskConfig) -> Optional[str]: - try: - model_task_config = config_manager.get_model_config().model_task_config - except Exception: - return None - for attr, value in self._iter_task_config_items(model_task_config): - if value is model_set: - return attr - try: - return self._resolve_task_config_by_signature(model_set) - except Exception: - return None - return None - - def _get_latest_task_config(self) -> TaskConfig: - if self._task_config_name: - try: - model_task_config = config_manager.get_model_config().model_task_config - value = getattr(model_task_config, self._task_config_name, None) - if isinstance(value, TaskConfig): - return value - except Exception: - return self.model_for_task - try: - if resolved_name := self._resolve_task_config_by_signature(self.model_for_task): - self._task_config_name = resolved_name - model_task_config = config_manager.get_model_config().model_task_config - value = getattr(model_task_config, resolved_name, None) - if isinstance(value, TaskConfig): - return value - except Exception: - return self.model_for_task - return self.model_for_task + task_config = getattr(model_task_config, self.task_name, None) + if not isinstance(task_config, TaskConfig): + raise ValueError(f"未找到名为 
'{self.task_name}' 的任务配置") + return task_config def _refresh_task_config(self) -> TaskConfig: - latest = self._get_latest_task_config() + """刷新并同步任务配置缓存。 + + Returns: + TaskConfig: 刷新后的任务配置对象。 + """ + latest = self._get_task_config_or_raise() if latest is not self.model_for_task: self.model_for_task = latest - self._task_config_signature = self._build_task_config_signature(latest) if list(self.model_usage.keys()) != latest.model_list: self.model_usage = {model: self.model_usage.get(model, (0, 0, 0)) for model in latest.model_list} return self.model_for_task def _check_slow_request(self, time_cost: float, model_name: str) -> None: - """检查请求是否过慢并输出警告日志 + """检查请求是否过慢并输出警告日志。 Args: - time_cost: 请求耗时(秒) - model_name: 使用的模型名称 + time_cost: 请求耗时(秒)。 + model_name: 使用的模型名称。 """ threshold = self.model_for_task.slow_threshold if time_cost > threshold: @@ -147,6 +130,31 @@ class LLMRequest: f" 如果你认为该警告出现得过于频繁,请调整model_config.toml中对应任务的slow_threshold至符合你实际情况的合理值" ) + @staticmethod + def _build_generation_result( + content: str, + reasoning_content: str, + model_name: str, + tool_calls: List[ToolCall] | None, + ) -> LLMResponseResult: + """构建统一的文本响应结果。 + + Args: + content: 模型返回的正文内容。 + reasoning_content: 模型返回的推理内容。 + model_name: 实际使用的模型名称。 + tool_calls: 模型返回的工具调用列表。 + + Returns: + LLMResponseResult: 统一文本响应结果对象。 + """ + return LLMResponseResult( + response=content, + reasoning=reasoning_content, + model_name=model_name, + tool_calls=tool_calls, + ) + async def generate_response_for_image( self, prompt: str, @@ -154,15 +162,20 @@ class LLMRequest: image_format: str, temperature: Optional[float] = None, max_tokens: Optional[int] = None, - ) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]: - """ - 为图像生成响应 + interrupt_flag: asyncio.Event | None = None, + ) -> LLMResponseResult: + """为图像生成响应。 + Args: - prompt (str): 提示词 - image_base64 (str): 图像的Base64编码字符串 - image_format (str): 图像格式(如 'png', 'jpeg' 等) + prompt: 文本提示词。 + image_base64: 图像的 Base64 编码字符串。 + image_format: 图像格式,例如 `png`、`jpeg`。 + temperature: 显式指定的温度参数。 + max_tokens: 显式指定的最大输出 token 数。 + interrupt_flag: 外部中断标记;被设置时会尽快终止请求。 + Returns: - (Tuple[str, str, str, Optional[List[ToolCall]]]): 响应内容、推理内容、模型名称、工具调用列表 + LLMResponseResult: 统一文本响应结果对象。 """ self._refresh_task_config() start_time = time.time() @@ -175,12 +188,15 @@ class LLMRequest: ) return [message_builder.build()] - response, model_info = await self._execute_request( + execution_result = await self._execute_request( request_type=RequestType.RESPONSE, message_factory=message_factory, temperature=temperature, max_tokens=max_tokens, + interrupt_flag=interrupt_flag, ) + response = execution_result.api_response + model_info = execution_result.model_info content = response.content or "" reasoning_content = response.reasoning_content or "" tool_calls = response.tool_calls @@ -198,44 +214,49 @@ class LLMRequest: endpoint="/chat/completions", time_cost=time_cost, ) - return content, (reasoning_content, model_info.name, tool_calls) + return self._build_generation_result(content, reasoning_content, model_info.name, tool_calls) + + async def generate_response_for_voice(self, voice_base64: str) -> LLMAudioTranscriptionResult: + """为语音生成转录响应。 - async def generate_response_for_voice(self, voice_base64: str) -> Optional[str]: - """ - 为语音生成响应 Args: - voice_base64 (str): 语音的Base64编码字符串 + voice_base64: 语音的 Base64 编码字符串。 + Returns: - (Optional[str]): 生成的文本描述或None + LLMAudioTranscriptionResult: 语音转写结果对象。 """ self._refresh_task_config() - response, _ = await self._execute_request( + execution_result = await 
self._execute_request( request_type=RequestType.AUDIO, audio_base64=voice_base64, ) - return response.content or None + return LLMAudioTranscriptionResult(text=execution_result.api_response.content or None) async def generate_response_async( self, prompt: str, temperature: Optional[float] = None, max_tokens: Optional[int] = None, - tools: Optional[List[Dict[str, Any]]] = None, + tools: List[ToolDefinitionInput] | None = None, response_format: RespFormat | None = None, raise_when_empty: bool = True, - ) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]: - """ - 异步生成响应 + interrupt_flag: asyncio.Event | None = None, + ) -> LLMResponseResult: + """异步生成文本响应。 + Args: - prompt (str): 提示词 - temperature (float, optional): 温度参数 - max_tokens (int, optional): 最大token数 - tools (Optional[List[Dict[str, Any]]]): 工具列表 - response_format (RespFormat | None): 响应格式 - raise_when_empty (bool): 当响应为空时是否抛出异常 + prompt: 提示词。 + temperature: 显式指定的温度参数。 + max_tokens: 显式指定的最大输出 token 数。 + tools: 原始工具定义列表。 + response_format: 响应格式约束。 + raise_when_empty: 保留字段,当前版本暂未单独使用。 + interrupt_flag: 外部中断标记;被设置时会尽快终止请求。 + Returns: - (Tuple[str, str, str, Optional[List[ToolCall]]]): 响应内容、推理内容、模型名称、工具调用列表 + LLMResponseResult: 统一文本响应结果对象。 """ + del raise_when_empty self._refresh_task_config() start_time = time.time() @@ -246,14 +267,17 @@ class LLMRequest: tool_built = self._build_tool_options(tools) - response, model_info = await self._execute_request( + execution_result = await self._execute_request( request_type=RequestType.RESPONSE, message_factory=message_factory, temperature=temperature, max_tokens=max_tokens, tool_options=tool_built, response_format=response_format, + interrupt_flag=interrupt_flag, ) + response = execution_result.api_response + model_info = execution_result.model_info logger.debug(f"LLM请求总耗时: {time.time() - start_time}") logger.debug(f"LLM生成内容: {response}") @@ -273,54 +297,63 @@ class LLMRequest: endpoint="/chat/completions", time_cost=time.time() - start_time, ) - return content or "", (reasoning_content, model_info.name, tool_calls) + return self._build_generation_result(content or "", reasoning_content, model_info.name, tool_calls) async def generate_response_with_message_async( self, message_factory: Callable[[BaseClient], List[Message]], temperature: Optional[float] = None, max_tokens: Optional[int] = None, - tools: Optional[List[Dict[str, Any]]] = None, + tools: List[ToolDefinitionInput] | None = None, response_format: RespFormat | None = None, raise_when_empty: bool = True, - ) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]: - """ - 异步生成响应 + interrupt_flag: asyncio.Event | None = None, + ) -> LLMResponseResult: + """基于外部消息工厂异步生成响应。 + Args: - message_factory (Callable[[BaseClient], List[Message]]): 已构建好的消息工厂 - temperature (float, optional): 温度参数 - max_tokens (int, optional): 最大token数 - tools (Optional[List[Dict[str, Any]]]): 工具列表 - response_format (RespFormat | None): 响应格式 - raise_when_empty (bool): 当响应为空时是否抛出异常 + message_factory: 消息工厂,会根据客户端能力构建消息列表。 + temperature: 显式指定的温度参数。 + max_tokens: 显式指定的最大输出 token 数。 + tools: 原始工具定义列表。 + response_format: 响应格式约束。 + raise_when_empty: 保留字段,当前版本暂未单独使用。 + interrupt_flag: 外部中断标记;被设置时会尽快终止请求。 + Returns: - (Tuple[str, str, str, Optional[List[ToolCall]]]): 响应内容、推理内容、模型名称、工具调用列表 + LLMResponseResult: 统一文本响应结果对象。 """ + del raise_when_empty self._refresh_task_config() start_time = time.time() if self.request_type.startswith("maisaka_"): logger.info( - f"LLMRequest[{self.request_type}] generate_response_with_message_async started " + 
f"LLMOrchestrator[{self.request_type}] generate_response_with_message_async started " f"(temperature={temperature}, max_tokens={max_tokens}, tools={len(tools or [])})" ) if self.request_type.startswith("maisaka_"): - logger.info(f"LLMRequest[{self.request_type}] building internal tool options from {len(tools or [])} tool(s)") + logger.info( + f"LLMOrchestrator[{self.request_type}] building internal tool options from {len(tools or [])} tool(s)" + ) tool_built = self._build_tool_options(tools) if self.request_type.startswith("maisaka_"): - logger.info(f"LLMRequest[{self.request_type}] built {len(tool_built or [])} internal tool option(s)") + logger.info(f"LLMOrchestrator[{self.request_type}] built {len(tool_built or [])} internal tool option(s)") - response, model_info = await self._execute_request( + execution_result = await self._execute_request( request_type=RequestType.RESPONSE, message_factory=message_factory, temperature=temperature, max_tokens=max_tokens, tool_options=tool_built, response_format=response_format, + interrupt_flag=interrupt_flag, ) + response = execution_result.api_response + model_info = execution_result.model_info if self.request_type.startswith("maisaka_"): logger.info( - f"LLMRequest[{self.request_type}] generate_response_with_message_async finished " + f"LLMOrchestrator[{self.request_type}] generate_response_with_message_async finished " f"(model={model_info.name}, time_cost={time.time() - start_time:.2f}s)" ) @@ -344,116 +377,25 @@ class LLMRequest: endpoint="/chat/completions", time_cost=time_cost, ) - return content or "", (reasoning_content, model_info.name, tool_calls) + return self._build_generation_result(content or "", reasoning_content, model_info.name, tool_calls) - async def generate_structured_response_async( - self, - prompt: str, - schema: type | dict[str, Any], - fallback_result: dict[str, Any] | None = None, - temperature: Optional[float] = 0.0, - max_tokens: Optional[int] = None, - ) -> Tuple[dict[str, Any], Tuple[str, str, Optional[List[ToolCall]]], bool]: - """ - 结构化输出快速接口: - - 默认启用 JSON_SCHEMA 严格模式 - - 单模型单次尝试(不重试、不切换模型) - - 失败时立即返回 fallback_result + async def get_embedding(self, embedding_input: str) -> LLMEmbeddingResult: + """获取嵌入向量。 - Returns: - (结构化结果, (推理内容, 模型名, 工具调用), 是否成功) - """ - self._refresh_task_config() - start_time = time.time() - - message_builder = MessageBuilder() - message_builder.add_text_content(prompt) - message_list = [message_builder.build()] - - response_format = RespFormat(schema=schema, format_type=RespFormatType.JSON_SCHEMA) - if response_format.schema: - response_format.schema["strict"] = True - - model_info, api_provider, client = self._select_model() - fallback_data = fallback_result or {} - - try: - response = await self._attempt_request_on_model( - model_info=model_info, - api_provider=api_provider, - client=client, - request_type=RequestType.RESPONSE, - message_list=message_list, - tool_options=None, - response_format=response_format, - stream_response_handler=None, - async_response_parser=None, - temperature=temperature, - max_tokens=max_tokens, - embedding_input=None, - audio_base64=None, - retry_limit=1, - ) - - time_cost = time.time() - start_time - self._check_slow_request(time_cost, model_info.name) - - reasoning_content = response.reasoning_content or "" - tool_calls = response.tool_calls - - parsed_result: dict[str, Any] | None = None - if response.content: - try: - parsed = json.loads(response.content) - if isinstance(parsed, dict): - parsed_result = parsed - except json.JSONDecodeError: - parsed_result 
= None - - if parsed_result is None: - logger.warning(f"结构化输出解析失败,使用降级结果。模型: {model_info.name}") - total_tokens, penalty, usage_penalty = self.model_usage[model_info.name] - self.model_usage[model_info.name] = (total_tokens, penalty + 1, max(usage_penalty - 1, 0)) - return fallback_data, (reasoning_content, model_info.name, tool_calls), False - - total_tokens, penalty, usage_penalty = self.model_usage[model_info.name] - if response_usage := response.usage: - total_tokens += response_usage.total_tokens - llm_usage_recorder.record_usage_to_database( - model_info=model_info, - model_usage=response_usage, - user_id="system", - request_type=self.request_type, - endpoint="/chat/completions", - time_cost=time_cost, - ) - self.model_usage[model_info.name] = (total_tokens, penalty, max(usage_penalty - 1, 0)) - return parsed_result, (reasoning_content, model_info.name, tool_calls), True - - except Exception as e: - time_cost = time.time() - start_time - self._check_slow_request(time_cost, model_info.name) - logger.warning(f"结构化输出请求失败,直接降级。模型: {model_info.name}, 错误: {e}") - - total_tokens, penalty, usage_penalty = self.model_usage[model_info.name] - self.model_usage[model_info.name] = (total_tokens, penalty + 1, max(usage_penalty - 1, 0)) - - return fallback_data, ("", model_info.name, None), False - - async def get_embedding(self, embedding_input: str) -> Tuple[List[float], str]: - """ - 获取嵌入向量 Args: - embedding_input (str): 获取嵌入的目标 + embedding_input: 待编码的文本。 + Returns: - (Tuple[List[float], str]): (嵌入向量,使用的模型名称) + LLMEmbeddingResult: 向量生成结果对象。 """ self._refresh_task_config() start_time = time.time() - response, model_info = await self._execute_request( + execution_result = await self._execute_request( request_type=RequestType.EMBEDDING, embedding_input=embedding_input, ) + response = execution_result.api_response + model_info = execution_result.model_info embedding = response.embedding if usage := response.usage: llm_usage_recorder.record_usage_to_database( @@ -466,11 +408,207 @@ class LLMRequest: ) if not embedding: raise RuntimeError("获取embedding失败") - return embedding, model_info.name + return LLMEmbeddingResult(embedding=embedding, model_name=model_info.name) + + def _resolve_effective_temperature( + self, + model_info: ModelInfo, + temperature: Optional[float], + ) -> Optional[float]: + """解析响应请求最终使用的温度参数。 + + Args: + model_info: 当前模型信息。 + temperature: 调用方显式传入的温度。 + + Returns: + Optional[float]: 最终生效的温度参数。 + """ + if temperature is not None: + return temperature + if model_info.temperature is not None: + return model_info.temperature + if "temperature" in model_info.extra_params: + return model_info.extra_params["temperature"] + return self.model_for_task.temperature + + def _resolve_effective_max_tokens( + self, + model_info: ModelInfo, + max_tokens: Optional[int], + ) -> Optional[int]: + """解析响应请求最终使用的最大输出 token 数。 + + Args: + model_info: 当前模型信息。 + max_tokens: 调用方显式传入的最大 token 数。 + + Returns: + Optional[int]: 最终生效的最大 token 数。 + """ + if max_tokens is not None: + return max_tokens + if model_info.max_tokens is not None: + return model_info.max_tokens + if "max_tokens" in model_info.extra_params: + return model_info.extra_params["max_tokens"] + return self.model_for_task.max_tokens + + def _build_response_request( + self, + model_info: ModelInfo, + message_list: List[Message], + tool_options: List[ToolOption] | None, + response_format: RespFormat | None, + stream_response_handler: Optional[Callable[..., Any]], + async_response_parser: Optional[Callable[..., Any]], + interrupt_flag: 
asyncio.Event | None, + temperature: Optional[float], + max_tokens: Optional[int], + ) -> ResponseRequest: + """构建统一响应请求对象。 + + Args: + model_info: 当前模型信息。 + message_list: 请求消息列表。 + tool_options: 工具定义列表。 + response_format: 输出格式定义。 + stream_response_handler: 流式响应处理函数。 + async_response_parser: 非流式响应解析函数。 + interrupt_flag: 外部中断标记。 + temperature: 调用方显式传入的温度。 + max_tokens: 调用方显式传入的最大 token 数。 + + Returns: + ResponseRequest: 统一响应请求对象。 + """ + return ResponseRequest( + model_info=model_info, + message_list=list(message_list), + tool_options=None if tool_options is None else list(tool_options), + max_tokens=self._resolve_effective_max_tokens(model_info, max_tokens), + temperature=self._resolve_effective_temperature(model_info, temperature), + response_format=response_format, + stream_response_handler=stream_response_handler, + async_response_parser=async_response_parser, + interrupt_flag=interrupt_flag, + extra_params=dict(model_info.extra_params), + ) + + @staticmethod + def _build_embedding_request( + model_info: ModelInfo, + embedding_input: str, + ) -> EmbeddingRequest: + """构建统一嵌入请求对象。 + + Args: + model_info: 当前模型信息。 + embedding_input: 嵌入输入文本。 + + Returns: + EmbeddingRequest: 统一嵌入请求对象。 + """ + return EmbeddingRequest( + model_info=model_info, + embedding_input=embedding_input, + extra_params=dict(model_info.extra_params), + ) + + @staticmethod + def _build_audio_transcription_request( + model_info: ModelInfo, + audio_base64: str, + max_tokens: Optional[int] = None, + ) -> AudioTranscriptionRequest: + """构建统一音频转录请求对象。 + + Args: + model_info: 当前模型信息。 + audio_base64: Base64 编码的音频数据。 + max_tokens: 调用方显式传入的最大 token 数。 + + Returns: + AudioTranscriptionRequest: 统一音频转录请求对象。 + """ + return AudioTranscriptionRequest( + model_info=model_info, + audio_base64=audio_base64, + max_tokens=max_tokens, + extra_params=dict(model_info.extra_params), + ) + + def _build_client_request( + self, + request_type: RequestType, + model_info: ModelInfo, + message_list: List[Message], + tool_options: List[ToolOption] | None, + response_format: RespFormat | None, + stream_response_handler: Optional[Callable[..., Any]], + async_response_parser: Optional[Callable[..., Any]], + interrupt_flag: asyncio.Event | None, + temperature: Optional[float], + max_tokens: Optional[int], + embedding_input: str | None, + audio_base64: str | None, + ) -> ClientRequest: + """按请求类型构建统一客户端请求对象。 + + Args: + request_type: 请求类型。 + model_info: 当前模型信息。 + message_list: 请求消息列表。 + tool_options: 工具定义列表。 + response_format: 响应格式定义。 + stream_response_handler: 流式响应处理函数。 + async_response_parser: 非流式响应解析函数。 + interrupt_flag: 外部中断标记。 + temperature: 调用方显式传入的温度。 + max_tokens: 调用方显式传入的最大 token 数。 + embedding_input: 嵌入输入文本。 + audio_base64: Base64 编码的音频数据。 + + Returns: + ClientRequest: 对应请求类型的统一请求对象。 + + Raises: + ValueError: 请求类型未知或缺少必需字段时抛出。 + """ + if request_type == RequestType.RESPONSE: + return self._build_response_request( + model_info=model_info, + message_list=message_list, + tool_options=tool_options, + response_format=response_format, + stream_response_handler=stream_response_handler, + async_response_parser=async_response_parser, + interrupt_flag=interrupt_flag, + temperature=temperature, + max_tokens=max_tokens, + ) + if request_type == RequestType.EMBEDDING: + if embedding_input is None: + raise ValueError("嵌入输入不能为空") + return self._build_embedding_request(model_info=model_info, embedding_input=embedding_input) + if request_type == RequestType.AUDIO: + if audio_base64 is None: + raise ValueError("音频 Base64 不能为空") + return 
self._build_audio_transcription_request( + model_info=model_info, + audio_base64=audio_base64, + max_tokens=max_tokens, + ) + raise ValueError(f"不支持的请求类型: {request_type}") def _select_model(self, exclude_models: Optional[Set[str]] = None) -> Tuple[ModelInfo, APIProvider, BaseClient]: - """ - 根据配置的策略选择模型:balance(负载均衡)或 random(随机选择) + """根据策略选择一个可用模型。 + + Args: + exclude_models: 本次请求中需要排除的模型名称集合。 + + Returns: + Tuple[ModelInfo, APIProvider, BaseClient]: 选中的模型、提供商与客户端实例。 """ self._refresh_task_config() available_models = { @@ -513,75 +651,38 @@ class LLMRequest: async def _attempt_request_on_model( self, - model_info: ModelInfo, api_provider: APIProvider, client: BaseClient, - request_type: RequestType, - message_list: List[Message], - tool_options: list[ToolOption] | None, - response_format: RespFormat | None, - stream_response_handler: Optional[Callable[..., Any]], - async_response_parser: Optional[Callable[..., Any]], - temperature: Optional[float], - max_tokens: Optional[int], - embedding_input: str | None, - audio_base64: str | None, + request: ClientRequest, retry_limit: Optional[int] = None, ) -> APIResponse: - """ - 在单个模型上执行请求,包含针对临时错误的重试逻辑。 - 如果成功,返回APIResponse。如果失败(重试耗尽或硬错误),则抛出ModelAttemptFailed异常。 + """在单个模型上执行请求,并处理重试逻辑。 + + Args: + api_provider: 当前请求对应的 API 提供商配置。 + client: 已初始化的客户端实例。 + request: 统一客户端请求对象。 + retry_limit: 显式指定的重试次数;未指定时使用 Provider 配置。 + + Returns: + APIResponse: 统一响应对象。 + + Raises: + ModelAttemptFailed: 当当前模型重试耗尽或遇到硬错误时抛出。 """ retry_remain = retry_limit if retry_limit is not None else api_provider.max_retry retry_remain = max(1, retry_remain) - compressed_messages: Optional[List[Message]] = None + model_info = request.model_info + original_response_request = request if isinstance(request, ResponseRequest) else None + active_request: ClientRequest = request while retry_remain > 0: try: - if request_type == RequestType.RESPONSE: - # 温度优先级:参数传入 > 模型级别配置 > extra_params > 任务配置 - effective_temperature = temperature - if effective_temperature is None: - effective_temperature = model_info.temperature - if effective_temperature is None: - effective_temperature = (model_info.extra_params or {}).get("temperature") - if effective_temperature is None: - effective_temperature = self.model_for_task.temperature - - # max_tokens 优先级:参数传入 > 模型级别配置 > extra_params > 任务配置 - effective_max_tokens = max_tokens - if effective_max_tokens is None: - effective_max_tokens = model_info.max_tokens - if effective_max_tokens is None: - effective_max_tokens = (model_info.extra_params or {}).get("max_tokens") - if effective_max_tokens is None: - effective_max_tokens = self.model_for_task.max_tokens - - return await client.get_response( - model_info=model_info, - message_list=(compressed_messages or message_list), - tool_options=tool_options, - max_tokens=effective_max_tokens, - temperature=effective_temperature, - response_format=response_format, - stream_response_handler=stream_response_handler, - async_response_parser=async_response_parser, - extra_params=model_info.extra_params, - ) - elif request_type == RequestType.EMBEDDING: - assert embedding_input is not None, "嵌入输入不能为空" - return await client.get_embedding( - model_info=model_info, - embedding_input=embedding_input, - extra_params=model_info.extra_params, - ) - elif request_type == RequestType.AUDIO: - assert audio_base64 is not None, "音频Base64不能为空" - return await client.get_audio_transcriptions( - model_info=model_info, - audio_base64=audio_base64, - extra_params=model_info.extra_params, - ) + if isinstance(active_request, 
ResponseRequest): + return await client.get_response(active_request) + if isinstance(active_request, EmbeddingRequest): + return await client.get_embedding(active_request) + return await client.get_audio_transcriptions(active_request) except EmptyResponseException as e: # 空回复:通常为临时问题,单独记录并重试 original_error_info = self._get_original_error_info(e) @@ -639,12 +740,19 @@ class LLMRequest: continue # 特殊处理413,尝试压缩 - if e.status_code == 413 and message_list and not compressed_messages: + if ( + e.status_code == 413 + and isinstance(active_request, ResponseRequest) + and active_request.message_list + and original_response_request is not None + and active_request.message_list == original_response_request.message_list + ): logger.warning( f"任务 '{task_display}' 的模型 '{model_info.name}' 返回413请求体过大,尝试压缩后重试..." ) # 压缩消息本身不消耗重试次数 - compressed_messages = compress_messages(message_list) + compressed_messages = compress_messages(active_request.message_list) + active_request = active_request.copy_with(message_list=compressed_messages) continue # 不可重试的HTTP错误 @@ -653,6 +761,22 @@ class LLMRequest: ) raise ModelAttemptFailed(f"模型 '{model_info.name}' 遇到硬错误", original_exception=e) from e + except RespParseException as e: + original_error_info = self._get_original_error_info(e) + retry_remain -= 1 + task_display = self.request_type or "未知任务" + if retry_remain <= 0: + logger.error( + f"任务 '{task_display}' 的模型 '{model_info.name}' 在响应解析多次失败后仍然失败。{original_error_info}" + ) + raise ModelAttemptFailed(f"模型 '{model_info.name}' 重试耗尽", original_exception=e) from e + + logger.warning( + f"任务 '{task_display}' 的模型 '{model_info.name}' 返回内容解析失败(可重试): {str(e)}{original_error_info}。" + f"剩余重试次数: {retry_remain}" + ) + await asyncio.sleep(api_provider.retry_interval) + except Exception as e: logger.error(traceback.format_exc()) @@ -672,7 +796,7 @@ class LLMRequest: self, request_type: RequestType, message_factory: Optional[Callable[[BaseClient], List[Message]]] = None, - tool_options: list[ToolOption] | None = None, + tool_options: List[ToolOption] | None = None, response_format: RespFormat | None = None, stream_response_handler: Optional[Callable[..., Any]] = None, async_response_parser: Optional[Callable[..., Any]] = None, @@ -680,9 +804,25 @@ class LLMRequest: max_tokens: Optional[int] = None, embedding_input: str | None = None, audio_base64: str | None = None, - ) -> Tuple[APIResponse, ModelInfo]: - """ - 调度器函数,负责模型选择、故障切换。 + interrupt_flag: asyncio.Event | None = None, + ) -> LLMExecutionResult: + """执行一次完整的模型调度请求。 + + Args: + request_type: 请求类型。 + message_factory: 消息工厂,仅在响应请求中使用。 + tool_options: 工具定义列表。 + response_format: 响应格式定义。 + stream_response_handler: 流式响应处理函数。 + async_response_parser: 非流式响应解析函数。 + temperature: 显式指定的温度参数。 + max_tokens: 显式指定的最大输出 token 数。 + embedding_input: 嵌入输入文本。 + audio_base64: Base64 编码的音频数据。 + interrupt_flag: 外部中断标记。 + + Returns: + LLMExecutionResult: 单次模型执行结果对象。 """ failed_models_this_request: Set[str] = set() max_attempts = len(self.model_for_task.model_list) @@ -692,36 +832,30 @@ class LLMRequest: model_info, api_provider, client = self._select_model(exclude_models=failed_models_this_request) if self.request_type.startswith("maisaka_"): logger.info( - f"LLMRequest[{self.request_type}] selected model={model_info.name} " + f"LLMOrchestrator[{self.request_type}] selected model={model_info.name} " f"provider={api_provider.name} request_type={request_type.value}" ) message_list = [] if message_factory: if self.request_type.startswith("maisaka_"): - logger.info(f"LLMRequest[{self.request_type}] building 
message list via message_factory") + logger.info(f"LLMOrchestrator[{self.request_type}] building message list via message_factory") message_list = message_factory(client) if self.request_type.startswith("maisaka_"): logger.info( - f"LLMRequest[{self.request_type}] message_factory returned {len(message_list)} message(s)" + f"LLMOrchestrator[{self.request_type}] message_factory returned {len(message_list)} message(s)" ) try: - if self.request_type.startswith("maisaka_"): - logger.info( - f"LLMRequest[{self.request_type}] sending request to model={model_info.name} " - f"with tool_options={len(tool_options or [])}" - ) - response = await self._attempt_request_on_model( - model_info, - api_provider, - client, - request_type, + request = self._build_client_request( + request_type=request_type, + model_info=model_info, message_list=message_list, tool_options=tool_options, response_format=response_format, stream_response_handler=stream_response_handler, async_response_parser=async_response_parser, + interrupt_flag=interrupt_flag, temperature=temperature, max_tokens=max_tokens, embedding_input=embedding_input, @@ -729,13 +863,23 @@ class LLMRequest: ) if self.request_type.startswith("maisaka_"): logger.info( - f"LLMRequest[{self.request_type}] model={model_info.name} returned API response" + f"LLMOrchestrator[{self.request_type}] sending request to model={model_info.name} " + f"with tool_options={len(tool_options or [])}" + ) + response = await self._attempt_request_on_model( + api_provider, + client, + request=request, + ) + if self.request_type.startswith("maisaka_"): + logger.info( + f"LLMOrchestrator[{self.request_type}] model={model_info.name} returned API response" ) total_tokens, penalty, usage_penalty = self.model_usage[model_info.name] if response_usage := response.usage: total_tokens += response_usage.total_tokens self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty - 1) - return response, model_info + return LLMExecutionResult(api_response=response, model_info=model_info) except ModelAttemptFailed as e: last_exception = e.original_exception or e @@ -753,46 +897,27 @@ class LLMRequest: raise last_exception raise RuntimeError("请求失败,所有可用模型均已尝试失败。") - def _build_tool_options(self, tools: Optional[List[Dict[str, Any]]]) -> Optional[List[ToolOption]]: - # sourcery skip: extract-method - """构建工具选项列表""" - if not tools: - return None - tool_options: List[ToolOption] = [] - for tool in tools: - tool_legal = True - tool_options_builder = ToolOptionBuilder() - tool_options_builder.set_name(tool.get("name", "")) - tool_options_builder.set_description(tool.get("description", "")) - parameters: List[Tuple[str, str, str, bool, List[str] | None]] = tool.get("parameters", []) - for param in parameters: - try: - assert isinstance(param, tuple) and len(param) == 5, "参数必须是包含5个元素的元组" - assert isinstance(param[0], str), "参数名称必须是字符串" - assert isinstance(param[1], ToolParamType), "参数类型必须是ToolParamType枚举" - assert isinstance(param[2], str), "参数描述必须是字符串" - assert isinstance(param[3], bool), "参数是否必填必须是布尔值" - assert isinstance(param[4], list) or param[4] is None, "参数枚举值必须是列表或None" - tool_options_builder.add_param( - name=param[0], - param_type=param[1], - description=param[2], - required=param[3], - enum_values=param[4], - ) - except AssertionError as ae: - tool_legal = False - logger.error(f"{param[0]} 参数定义错误: {str(ae)}") - except Exception as e: - tool_legal = False - logger.error(f"构建工具参数失败: {str(e)}") - if tool_legal: - tool_options.append(tool_options_builder.build()) - return tool_options 
or None + def _build_tool_options(self, tools: List[ToolDefinitionInput] | None) -> List[ToolOption] | None: + """将任意输入工具定义列表规范化为内部工具选项。 + + Args: + tools: 原始工具定义列表。 + + Returns: + List[ToolOption] | None: 规范化后的工具选项列表。 + """ + return normalize_tool_options(tools) @staticmethod def _extract_reasoning(content: str) -> Tuple[str, str]: - """CoT思维链提取,向后兼容""" + """提取 `<think>` 思维链内容。 + + Args: + content: 原始模型输出文本。 + + Returns: + Tuple[str, str]: `(正文内容, 推理内容)`。 + """ match = re.search(r"(?:<think>)?(.*?)</think>", content, re.DOTALL) content = re.sub(r"(?:<think>)?.*?</think>", "", content, flags=re.DOTALL, count=1).strip() reasoning = match[1].strip() if match else "" @@ -800,7 +925,14 @@ class LLMRequest: @staticmethod def _get_original_error_info(e: Exception) -> str: - """获取原始错误信息""" + """提取底层异常信息。 + + Args: + e: 当前捕获的异常对象。 + + Returns: + str: 可直接拼接到日志中的底层异常描述。 + """ if e.__cause__: original_error_type = type(e.__cause__).__name__ original_error_msg = str(e.__cause__) @@ -811,17 +943,16 @@ class LLMRequest: class TempMethodsLLMUtils: @staticmethod def get_model_info_by_name(model_name: str) -> ModelInfo: - """根据模型名称获取模型信息 + """根据模型名称获取模型信息。 Args: - model_config: ModelConfig实例 model_name: 模型名称 Returns: - ModelInfo: 模型信息 + ModelInfo: 模型信息。 Raises: - ValueError: 未找到指定模型 + ValueError: 未找到指定模型。 """ for model in config_manager.get_model_config().models: if model.name == model_name: @@ -830,17 +961,16 @@ class TempMethodsLLMUtils: @staticmethod def get_provider_by_name(provider_name: str) -> APIProvider: - """根据提供商名称获取提供商信息 + """根据提供商名称获取提供商信息。 Args: - model_config: ModelConfig实例 provider_name: 提供商名称 Returns: - APIProvider: API提供商信息 + APIProvider: API 提供商信息。 Raises: - ValueError: 未找到指定提供商 + ValueError: 未找到指定提供商。 """ for provider in config_manager.get_model_config().api_providers: if provider.name == provider_name: diff --git a/src/maisaka/cli.py b/src/maisaka/cli.py index f54ce8af..15bf694e 100644 --- a/src/maisaka/cli.py +++ b/src/maisaka/cli.py @@ -65,8 +65,8 @@ class BufferCLI: self._mcp_manager: Optional[MCPManager] = None self._init_llm() - def _init_llm(self): - """Initialize the LLM service from the main project config.""" + def _init_llm(self) -> None: + """从主项目配置初始化 LLM 服务。""" thinking_env = os.getenv("ENABLE_THINKING", "").strip().lower() enable_thinking: Optional[bool] = True if thinking_env == "true" else False if thinking_env == "false" else None @@ -77,7 +77,7 @@ class BufferCLI: enable_thinking=enable_thinking, ) - model_name = self.llm_service._model_name + model_name = self.llm_service.get_current_model_name() console.print(f"[success][OK] LLM service initialized[/success] [muted](model: {model_name})[/muted]") def _build_tool_context(self) -> ToolHandlerContext: diff --git a/src/maisaka/llm_service.py b/src/maisaka/llm_service.py index e955cb66..5ab2b3ad 100644 --- a/src/maisaka/llm_service.py +++ b/src/maisaka/llm_service.py @@ -1,6 +1,6 @@ -""" -MaiSaka LLM 服务 - 使用主项目 LLM 系统 -将主项目的 LLMRequest 适配为 MaiSaka 需要的接口 +"""MaiSaka LLM 服务。 + +该模块基于主项目服务层封装 MaiSaka 所需的对话与工具调用接口。 """ from base64 import b64decode @@ -8,7 +8,7 @@ from dataclasses import dataclass from datetime import datetime from io import BytesIO from time import perf_counter -from typing import Any, List, Optional +from typing import Any, Dict, List, Optional import asyncio import random @@ -23,9 +23,16 @@ from src.common.data_models.mai_message_data_model import MaiMessage from src.common.logger import get_logger from src.common.prompt_i18n import load_prompt from src.config.config import config_manager, global_config
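For reference, a minimal self-contained sketch of the `<think>` extraction that `_extract_reasoning` above performs. The regex mirrors the implementation; the helper name and sample reply here are illustrative only:

```python
# Standalone sketch of the <think> chain-of-thought extraction shown in
# _extract_reasoning. The regex is the one used above; the sample reply
# is invented for illustration.
import re
from typing import Tuple


def extract_reasoning(content: str) -> Tuple[str, str]:
    """Split a raw model reply into (visible content, reasoning)."""
    match = re.search(r"(?:<think>)?(.*?)</think>", content, re.DOTALL)
    content = re.sub(r"(?:<think>)?.*?</think>", "", content, flags=re.DOTALL, count=1).strip()
    reasoning = match[1].strip() if match else ""
    return content, reasoning


body, reasoning = extract_reasoning("<think>greet warmly</think>你好呀!")
assert body == "你好呀!"
assert reasoning == "greet warmly"
```

+from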
src.llm_models.model_client.base_client import BaseClient from src.llm_models.payload_content.message import Message, MessageBuilder, RoleType -from src.llm_models.payload_content.tool_option import ToolCall, ToolOption, ToolParamType -from src.llm_models.utils_model import LLMRequest +from src.llm_models.payload_content.tool_option import ( + ToolCall, + ToolDefinitionInput, + ToolOption, + normalize_tool_options, +) +from src.common.data_models.llm_service_data_models import LLMGenerationOptions +from src.services.llm_service import LLMServiceClient from . import config from .config import console @@ -36,17 +43,15 @@ from .message_adapter import ( get_message_kind, get_message_role, get_message_text, - get_tool_call_id, get_tool_calls, - remove_last_perception, to_llm_message, ) logger = get_logger("maisaka_llm") -@dataclass +@dataclass(slots=True) class ChatResponse: - """LLM 对话循环单步响应""" + """LLM 对话循环单步响应。""" content: Optional[str] tool_calls: List[ToolCall] @@ -65,38 +70,34 @@ class MaiSakaLLMService: temperature: float = 0.5, max_tokens: int = 2048, enable_thinking: Optional[bool] = None, - ): - """ - 初始化 LLM 服务 + ) -> None: + """初始化 MaiSaka LLM 服务。 - 参数仅为兼容性保留,实际使用主项目配置 + Args: + api_key: 兼容旧接口保留的参数,当前不使用。 + base_url: 兼容旧接口保留的参数,当前不使用。 + model: 兼容旧接口保留的参数,当前不使用。 + chat_system_prompt: 可选的系统提示词覆盖值。 + temperature: 默认温度参数。 + max_tokens: 默认最大输出 token 数。 + enable_thinking: 是否启用思考模式。 """ + del api_key, base_url, model self._temperature = temperature self._max_tokens = max_tokens self._enable_thinking = enable_thinking - self._extra_tools: List[dict] = [] + self._extra_tools: List[ToolOption] = [] self._prompts_loaded = False self._prompt_load_lock = asyncio.Lock() - # 获取主项目模型配置 - try: - model_config = config_manager.get_model_config() - self._model_configs = model_config.model_task_config - except Exception: - # 如果配置加载失败,使用默认配置 - from src.config.model_configs import ModelTaskConfig - - self._model_configs = ModelTaskConfig() - logger.warning("无法加载主项目模型配置,使用默认配置") - - # 初始化 LLMRequest 实例(只使用 tool_use 和 replyer) - self._llm_tool_use = LLMRequest(model_set=self._model_configs.tool_use, request_type="maisaka_tool_use") - # 主对话也使用 tool_use 模型(因为需要工具调用支持) - self._llm_planner = LLMRequest(model_set=self._model_configs.planner, request_type="maisaka_planner") + # 初始化服务层 LLM 门面(按任务名实时解析配置,确保热重载生效) + self._llm_tool_use = LLMServiceClient(task_name="tool_use", request_type="maisaka_tool_use") + # 主对话也使用 planner 模型 + self._llm_planner = LLMServiceClient(task_name="planner", request_type="maisaka_planner") self._llm_chat = self._llm_planner self._llm_utils = self._llm_tool_use # 回复生成使用 replyer 模型 - self._llm_replyer = LLMRequest(model_set=self._model_configs.replyer, request_type="maisaka_replyer") + self._llm_replyer = LLMServiceClient(task_name="replyer", request_type="maisaka_replyer") # 尝试修复数据库 schema(忽略错误) self._try_fix_database_schema() @@ -111,15 +112,30 @@ class MaiSakaLLMService: else: self._chat_system_prompt = chat_system_prompt - self._model_name = ( - self._model_configs.planner.model_list[0] if self._model_configs.planner.model_list else "未配置" - ) # 子模块提示词同样采用懒加载 self._emotion_prompt: Optional[str] = None self._cognition_prompt: Optional[str] = None + def get_current_model_name(self) -> str: + """获取当前 Maisaka 对话主模型名称。 + + Returns: + str: 当前 planner 任务的首选模型名;未配置时返回 ``未配置``。 + """ + try: + model_task_config = config_manager.get_model_config().model_task_config + if model_task_config.planner.model_list: + return model_task_config.planner.model_list[0] + except Exception as exc: + 
logger.warning(f"获取当前 Maisaka 模型名称失败: {exc}") + return "未配置" + def _try_fix_database_schema(self) -> None: - """尝试修复数据库 schema,添加缺失的列""" + """尝试修复数据库 schema。 + + Returns: + None: 该方法仅执行数据库修复副作用。 + """ try: from src.common.database.database_client import get_db_session from sqlalchemy import text @@ -139,7 +155,11 @@ class MaiSakaLLMService: pass def _build_personality_prompt(self) -> str: - """构建人设信息,参考 replyer 的做法""" + """构建当前人设提示词。 + + Returns: + str: 最终用于系统提示词的人设描述。 + """ try: bot_name = global_config.bot.nickname if global_config.bot.alias_names: @@ -169,60 +189,21 @@ class MaiSakaLLMService: # 返回默认人设 return "你的名字是麦麦,你是一个活泼可爱的AI助手。" - def set_extra_tools(self, tools: List[dict]) -> None: - """设置额外的工具定义(如 MCP 工具)""" - self._extra_tools = [self._normalize_extra_tool(tool) for tool in tools] - logger.info(f"Normalized {len(self._extra_tools)} extra tool(s) for Maisaka") + def set_extra_tools(self, tools: List[ToolDefinitionInput]) -> None: + """设置额外工具定义。 - @staticmethod - def _json_type_to_tool_param_type(json_type: str) -> ToolParamType: - normalized = (json_type or "").lower() - if normalized == "integer": - return ToolParamType.INTEGER - if normalized == "number": - return ToolParamType.FLOAT - if normalized == "boolean": - return ToolParamType.BOOLEAN - return ToolParamType.STRING - - @classmethod - def _normalize_extra_tool(cls, tool: dict) -> dict: - """Normalize external/OpenAI-style tool definitions into the internal tool schema.""" - if "name" in tool and "description" in tool: - return tool - - if tool.get("type") != "function": - return tool - - function_info = tool.get("function", {}) - parameters_schema = function_info.get("parameters", {}) or {} - required_names = set(parameters_schema.get("required", []) or []) - properties = parameters_schema.get("properties", {}) or {} - parameters: list[tuple[str, ToolParamType, str, bool, list[str] | None]] = [] - - for param_name, param_schema in properties.items(): - if not isinstance(param_schema, dict): - continue - enum_values = param_schema.get("enum") - normalized_enum = [str(value) for value in enum_values] if isinstance(enum_values, list) else None - parameters.append( - ( - str(param_name), - cls._json_type_to_tool_param_type(str(param_schema.get("type", "string"))), - str(param_schema.get("description", "")), - param_name in required_names, - normalized_enum, - ) - ) - - return { - "name": str(function_info.get("name", "")), - "description": str(function_info.get("description", "")), - "parameters": parameters, - } + Args: + tools: 外部传入的工具定义列表,例如 MCP 暴露的 OpenAI-compatible 工具。 + """ + self._extra_tools = normalize_tool_options(tools) or [] + logger.info(f"已为 Maisaka 加载 {len(self._extra_tools)} 个额外工具") async def _ensure_prompts_loaded(self) -> None: - """异步懒加载提示词,避免在运行中的事件循环里同步渲染 prompt。""" + """异步懒加载提示词。 + + Returns: + None: 该方法仅刷新内部提示词缓存。 + """ if self._prompts_loaded: return @@ -260,7 +241,14 @@ class MaiSakaLLMService: @staticmethod def _get_role_badge_style(role: str) -> str: - """为不同 role 返回不同的标签样式。""" + """为不同角色返回终端标签样式。 + + Args: + role: 消息角色名称。 + + Returns: + str: Rich 可识别的样式字符串。 + """ if role == "system": return "bold white on blue" if role == "user": @@ -273,7 +261,14 @@ class MaiSakaLLMService: @staticmethod def _build_terminal_image_preview(image_base64: str) -> Optional[str]: - """Build a low-resolution ASCII preview for terminals without inline-image support.""" + """构建终端 ASCII 图片预览。 + + Args: + image_base64: 图片的 Base64 数据。 + + Returns: + Optional[str]: 可渲染的 ASCII 预览文本;失败时返回 `None`。 + """ ascii_chars = " 
.:-=+*#%@" try: @@ -291,7 +286,7 @@ class MaiSakaLLMService: except Exception: return None - rows: list[str] = [] + rows: List[str] = [] for row_index in range(preview_height): row_pixels = pixels[row_index * preview_width : (row_index + 1) * preview_width] row = "".join(ascii_chars[min(len(ascii_chars) - 1, pixel * len(ascii_chars) // 256)] for pixel in row_pixels) @@ -301,12 +296,19 @@ class MaiSakaLLMService: @staticmethod def _render_message_content(content: Any) -> object: - """把消息内容转成适合 Rich 输出的 renderable。""" + """将消息内容转换为 Rich 可渲染对象。 + + Args: + content: 原始消息内容。 + + Returns: + object: Rich 可渲染对象。 + """ if isinstance(content, str): return Text(content) if isinstance(content, list): - parts: list[object] = [] + parts: List[object] = [] for item in content: if isinstance(item, str): parts.append(Text(item)) @@ -316,7 +318,7 @@ class MaiSakaLLMService: if isinstance(image_format, str) and isinstance(image_base64, str): approx_size = max(0, len(image_base64) * 3 // 4) size_text = f"{approx_size / 1024:.1f} KB" if approx_size >= 1024 else f"{approx_size} B" - preview_parts: list[object] = [ + preview_parts: List[object] = [ Text(f"image/{image_format} {size_text}\nbase64 omitted", style="magenta") ] if config.TERMINAL_IMAGE_PREVIEW: @@ -343,8 +345,15 @@ class MaiSakaLLMService: return Pretty(content, expand_all=True) @staticmethod - def _format_tool_call_for_display(tool_call: Any) -> dict[str, Any]: - """将 tool call 转成适合 CLI 展示的结构。""" + def _format_tool_call_for_display(tool_call: Any) -> Dict[str, Any]: + """将工具调用转换为 CLI 展示结构。 + + Args: + tool_call: 原始工具调用对象或字典。 + + Returns: + Dict[str, Any]: 统一后的展示字典。 + """ if isinstance(tool_call, dict): function_info = tool_call.get("function", {}) return { @@ -360,7 +369,16 @@ class MaiSakaLLMService: } def _render_tool_call_panel(self, tool_call: Any, index: int, parent_index: int) -> Panel: - """Render assistant tool calls as standalone cards.""" + """渲染单个工具调用面板。 + + Args: + tool_call: 原始工具调用对象或字典。 + index: 当前工具调用在父消息中的序号。 + parent_index: 父消息在消息列表中的序号。 + + Returns: + Panel: 可直接打印的工具调用面板。 + """ title = Text.assemble( Text(" TOOL CALL ", style="bold white on magenta"), Text(f" #{parent_index}.{index}", style="muted"), @@ -373,16 +391,22 @@ class MaiSakaLLMService: ) def _render_message_panel(self, message: Any, index: int) -> Panel: - """渲染主循环 prompt 中的一条消息。""" + """渲染主循环 Prompt 中的一条消息。 + + Args: + message: 原始消息对象或字典。 + index: 当前消息序号。 + + Returns: + Panel: 可直接打印的消息面板。 + """ if isinstance(message, dict): raw_role = message.get("role", "unknown") content = message.get("content") - tool_calls = message.get("tool_calls") tool_call_id = message.get("tool_call_id") else: raw_role = getattr(message, "role", "unknown") content = getattr(message, "content", None) - tool_calls = getattr(message, "tool_calls", None) tool_call_id = getattr(message, "tool_call_id", None) role = raw_role.value if hasattr(raw_role, "value") else str(raw_role) @@ -391,7 +415,7 @@ class MaiSakaLLMService: Text(f" #{index}", style="muted"), ) - parts: list[object] = [] + parts: List[object] = [] if content not in (None, "", []): parts.append(Text(" message ", style="bold cyan")) parts.append(self._render_message_content(content)) @@ -415,30 +439,27 @@ class MaiSakaLLMService: padding=(0, 1), ) - @staticmethod - def _tool_option_to_dict(tool: "ToolOption") -> dict: - """将 ToolOption 对象转换为主项目期望的 dict 格式 + async def chat_loop_step(self, chat_history: List[MaiMessage]) -> ChatResponse: + """执行主对话循环的一步。 - 主项目的 _build_tool_options() 期望的格式: - { - "name": str, - "description": str, - 
"parameters": List[Tuple[name, ToolParamType, description, required, enum_values]] - } + Args: + chat_history: 当前对话历史。 + + Returns: + ChatResponse: 本轮对话生成结果。 """ - params = [] - if tool.params: - for param in tool.params: - params.append((param.name, param.param_type, param.description, param.required, param.enum_values)) - return {"name": tool.name, "description": tool.description, "parameters": params} - - async def chat_loop_step(self, chat_history: list[MaiMessage]) -> ChatResponse: - """执行对话循环的一步 - 使用 tool_use 模型""" await self._ensure_prompts_loaded() - def message_factory(client) -> list[Message]: - """将 MaiSaka 的 chat_history 转换为主项目的 Message 格式""" - messages: list[Message] = [] + def message_factory(_client: BaseClient) -> List[Message]: + """将 MaiSaka 对话历史转换为内部消息列表。 + + Args: + _client: 当前底层客户端实例。 + + Returns: + List[Message]: 规范化后的消息列表。 + """ + messages: List[Message] = [] # 首先添加系统提示词 system_msg = MessageBuilder().set_role(RoleType.System) @@ -454,15 +475,13 @@ class MaiSakaLLMService: return messages # 调用 LLM(使用带消息的接口) - # 合并内置工具和额外工具(将 ToolOption 对象转换为 dict) - all_tools = [self._tool_option_to_dict(t) for t in get_builtin_tools()] + ( - self._extra_tools if self._extra_tools else [] - ) + # 合并内置工具和额外工具,统一交给底层规范化流程处理。 + all_tools = [*get_builtin_tools(), *self._extra_tools] # 打印消息列表 built_messages = message_factory(None) - ordered_panels: list[Panel] = [] + ordered_panels: List[Panel] = [] for index, msg in enumerate(built_messages, start=1): ordered_panels.append(self._render_message_panel(msg, index)) tool_calls = getattr(msg, "tool_calls", None) @@ -483,13 +502,18 @@ class MaiSakaLLMService: request_started_at = perf_counter() - logger.info("chat_loop_step calling planner model generate_response_with_message_async") - response, (reasoning, model, tool_calls) = await self._llm_chat.generate_response_with_message_async( + logger.info("chat_loop_step calling planner model generate_response_with_messages") + generation_result = await self._llm_chat.generate_response_with_messages( message_factory=message_factory, - tools=all_tools if all_tools else None, - temperature=self._temperature, - max_tokens=self._max_tokens, + options=LLMGenerationOptions( + tool_options=all_tools if all_tools else None, + temperature=self._temperature, + max_tokens=self._max_tokens, + ), ) + response = generation_result.response + model = generation_result.model_name + tool_calls = generation_result.tool_calls elapsed = perf_counter() - request_started_at logger.info( f"chat_loop_step planner model returned in {elapsed:.2f}s " @@ -509,8 +533,15 @@ class MaiSakaLLMService: raw_message=raw_message, ) - def _filter_for_api(self, chat_history: list[MaiMessage]) -> str: - """过滤对话历史为 API 格式""" + def _filter_for_api(self, chat_history: List[MaiMessage]) -> str: + """将对话历史过滤为简单文本格式。 + + Args: + chat_history: 当前对话历史。 + + Returns: + str: 过滤后的文本上下文。 + """ parts = [] for msg in chat_history: role = get_message_role(msg) @@ -535,8 +566,15 @@ class MaiSakaLLMService: return "\n\n".join(parts) - def build_chat_context(self, user_text: str) -> list[MaiMessage]: - """构建对话上下文""" + def build_chat_context(self, user_text: str) -> List[MaiMessage]: + """构建新的对话上下文。 + + Args: + user_text: 用户输入文本。 + + Returns: + List[MaiMessage]: 初始对话上下文消息列表。 + """ return [ build_message( role=RoleType.User.value, @@ -547,8 +585,15 @@ class MaiSakaLLMService: # ──────── 分析模块(使用 utils 模型) ──────── - async def analyze_emotion(self, chat_history: list[MaiMessage]) -> str: - """情绪分析 - 使用 utils 模型""" + async def analyze_emotion(self, chat_history: 
List[MaiMessage]) -> str: + """执行情绪分析。 + + Args: + chat_history: 当前对话历史。 + + Returns: + str: 情绪分析文本。 + """ await self._ensure_prompts_loaded() filtered = [m for m in chat_history if get_message_kind(m) != "perception"] recent = filtered[-10:] if len(filtered) > 10 else filtered @@ -574,19 +619,26 @@ class MaiSakaLLMService: print("=" * 60 + "\n") try: - response, _ = await self._llm_utils.generate_response_async( + generation_result = await self._llm_utils.generate_response( prompt=prompt, - temperature=0.3, - max_tokens=512, + options=LLMGenerationOptions(temperature=0.3, max_tokens=512), ) + response = generation_result.response return response except Exception as e: logger.error(f"情绪分析 LLM 调用出错: {e}") return "" - async def analyze_cognition(self, chat_history: list[MaiMessage]) -> str: - """认知分析 - 使用 utils 模型""" + async def analyze_cognition(self, chat_history: List[MaiMessage]) -> str: + """执行认知分析。 + + Args: + chat_history: 当前对话历史。 + + Returns: + str: 认知分析文本。 + """ await self._ensure_prompts_loaded() filtered = [m for m in chat_history if get_message_kind(m) != "perception"] recent = filtered[-10:] if len(filtered) > 10 else filtered @@ -612,19 +664,27 @@ class MaiSakaLLMService: print("=" * 60 + "\n") try: - response, _ = await self._llm_utils.generate_response_async( + generation_result = await self._llm_utils.generate_response( prompt=prompt, - temperature=0.3, - max_tokens=512, + options=LLMGenerationOptions(temperature=0.3, max_tokens=512), ) + response = generation_result.response return response except Exception as e: logger.error(f"认知分析 LLM 调用出错: {e}") return "" - async def _removed_analyze_timing(self, chat_history: list[MaiMessage], timing_info: str) -> str: - """时间分析 - 使用 utils 模型""" + async def _removed_analyze_timing(self, chat_history: List[MaiMessage], timing_info: str) -> str: + """执行时间节奏分析。 + + Args: + chat_history: 当前对话历史。 + timing_info: 外部传入的时间信息摘要。 + + Returns: + str: 时间分析文本。 + """ await self._ensure_prompts_loaded() filtered = [ m @@ -653,11 +713,11 @@ class MaiSakaLLMService: print("=" * 60 + "\n") try: - response, _ = await self._llm_utils.generate_response_async( + generation_result = await self._llm_utils.generate_response( prompt=prompt, - temperature=0.3, - max_tokens=512, + options=LLMGenerationOptions(temperature=0.3, max_tokens=512), ) + response = generation_result.response return response except Exception as e: @@ -666,10 +726,15 @@ class MaiSakaLLMService: # ──────── 回复生成(使用 replyer 模型) ──────── - async def generate_reply(self, reason: str, chat_history: list[MaiMessage]) -> str: - """ - 生成回复 - 使用 replyer 模型 - 可供 Replyer 类直接调用 + async def generate_reply(self, reason: str, chat_history: List[MaiMessage]) -> str: + """生成最终回复文本。 + + Args: + reason: 当前轮次的内部想法或回复理由。 + chat_history: 当前对话历史。 + + Returns: + str: 最终回复文本。 """ await self._ensure_prompts_loaded() from datetime import datetime @@ -704,17 +769,12 @@ class MaiSakaLLMService: print("=" * 60 + "\n") try: - response, _ = await self._llm_replyer.generate_response_async( + generation_result = await self._llm_replyer.generate_response( prompt=messages, - temperature=0.8, - max_tokens=512, + options=LLMGenerationOptions(temperature=0.8, max_tokens=512), ) + response = generation_result.response return response.strip() if response else "..." except Exception as e: logger.error(f"回复生成 LLM 调用出错: {e}") return "..." 
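The call sites above all follow one migration pattern: tuple unpacking from `generate_response_async` becomes attribute access on the result object returned by `LLMServiceClient.generate_response`. A minimal sketch of that convention with stand-in dataclasses (the real `LLMGenerationOptions` is imported from `src.common.data_models.llm_service_data_models`; the result type's name and exact field set here are assumptions based on the attributes this diff reads):

```python
# Stand-in shapes for the new calling convention. Only the attributes the
# migrated call sites actually read (.response, .model_name, .tool_calls)
# are modeled; the real definitions live in the service layer.
from dataclasses import dataclass
from typing import Any, List, Optional


@dataclass
class LLMGenerationOptions:  # stand-in for the imported options object
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None
    tool_options: Optional[List[Any]] = None


@dataclass
class LLMGenerationResult:  # hypothetical name; fields match caller usage
    response: str
    model_name: str = ""
    tool_calls: Optional[List[Any]] = None


# Before: response, _ = await llm.generate_response_async(prompt, temperature=0.3, max_tokens=512)
# After:  result = await llm.generate_response(
#             prompt=prompt, options=LLMGenerationOptions(temperature=0.3, max_tokens=512))
#         response = result.response
```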
- - - - - diff --git a/src/memory_system/chat_history_summarizer.py b/src/memory_system/chat_history_summarizer.py index b984c66d..3d18187d 100644 --- a/src/memory_system/chat_history_summarizer.py +++ b/src/memory_system/chat_history_summarizer.py @@ -16,8 +16,9 @@ from json_repair import repair_json from src.chat.message_receive.message import SessionMessage from src.common.logger import get_logger -from src.config.config import model_config, global_config -from src.llm_models.utils_model import LLMRequest +from src.config.config import global_config +from src.common.data_models.llm_service_data_models import LLMGenerationOptions +from src.services.llm_service import LLMServiceClient from src.services import message_service as message_api from src.chat.utils.utils import is_bot_self from src.person_info.person_info import Person @@ -88,8 +89,8 @@ class ChatHistorySummarizer: # 注意:批次加载需要异步查询消息,所以在 start() 中调用 # LLM请求器,用于压缩聊天内容 - self.summarizer_llm = LLMRequest( - model_set=model_config.model_task_config.utils, request_type="chat_history_summarizer" + self.summarizer_llm = LLMServiceClient( + task_name="utils", request_type="chat_history_summarizer" ) # 后台循环相关 @@ -656,10 +657,11 @@ class ChatHistorySummarizer: prompt = await prompt_manager.render_prompt(prompt_template) try: - response, _ = await self.summarizer_llm.generate_response_async( + generation_result = await self.summarizer_llm.generate_response( prompt=prompt, - temperature=0.3, + options=LLMGenerationOptions(temperature=0.3), ) + response = generation_result.response logger.info(f"{self.log_prefix} 话题识别LLM Prompt: {prompt}") logger.info(f"{self.log_prefix} 话题识别LLM Response: {response}") @@ -812,7 +814,8 @@ class ChatHistorySummarizer: prompt = await prompt_manager.render_prompt(prompt_template) try: - response, _ = await self.summarizer_llm.generate_response_async(prompt=prompt) + generation_result = await self.summarizer_llm.generate_response(prompt=prompt) + response = generation_result.response # 解析JSON响应 json_str = response.strip()
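The memory-retrieval changes below switch to the higher-level `llm_api.generate(...)` entry point. A hedged sketch of the request/response envelope it relies on; the shapes are inferred from the fields the call sites read (`success`, `completion.response`, `completion.reasoning`, `completion.tool_calls`) rather than from the actual definitions in `src/services/llm_service.py`:

```python
# Assumed shapes for the service-layer envelope used below; inferred from
# the attributes the call sites read, not from the real definitions.
from dataclasses import dataclass, field
from typing import Any, Callable, List, Optional


@dataclass
class LLMServiceRequest:
    task_name: str                 # resolves the model set, e.g. "tool_use"
    request_type: str              # usage-tracking label, e.g. "memory.react"
    prompt: Optional[str] = None   # plain-prompt path...
    message_factory: Optional[Callable[..., Any]] = None  # ...or prebuilt-message path
    tool_options: Optional[List[Any]] = None


@dataclass
class LLMCompletion:
    response: str = ""
    reasoning: str = ""
    tool_calls: Optional[List[Any]] = None


@dataclass
class LLMServiceResult:
    success: bool = False
    completion: LLMCompletion = field(default_factory=LLMCompletion)
```

diff --git a/src/memory_system/memory_retrieval.py b/src/memory_system/memory_retrieval.py index 982db166..41851408 100644 --- a/src/memory_system/memory_retrieval.py +++ b/src/memory_system/memory_retrieval.py @@ -5,7 +5,7 @@ import asyncio from datetime import datetime from typing import List, Dict, Any, Optional, Tuple, Callable from src.common.logger import get_logger -from src.config.config import global_config, model_config +from src.config.config import global_config from src.prompt.prompt_manager import prompt_manager from src.services import llm_service as llm_api from sqlmodel import select, col @@ -269,18 +269,18 @@ async def _react_agent_solve_question( return messages message_factory_fn: Callable[..., List[Message]] = _build_messages # pyright: ignore[reportGeneralTypeIssues] - ( - success, - response, - reasoning_content, - model_name, - tool_calls, - ) = await llm_api.generate_with_model_with_tools_by_message_factory( - message_factory_fn, # type: ignore[arg-type] - model_config=model_config.model_task_config.tool_use, - tool_options=tool_definitions, - request_type="memory.react", + generation_result = await llm_api.generate( + llm_api.LLMServiceRequest( + task_name="tool_use", + request_type="memory.react", + message_factory=message_factory_fn, # type: ignore[arg-type] + tool_options=tool_definitions, + ) ) + success = generation_result.success + response = generation_result.completion.response + reasoning_content = generation_result.completion.reasoning + tool_calls =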
generation_result.completion.tool_calls # logger.info( # f"ReAct Agent 第 {iteration + 1} 次迭代 模型: {model_name} ,调用工具数量: {len(tool_calls) if tool_calls else 0} ,调用工具响应: {response}" # ) @@ -679,18 +679,16 @@ async def _react_agent_solve_question( evaluation_prompt_template.add_context("max_iterations", str(max_iterations)) evaluation_prompt = await prompt_manager.render_prompt(evaluation_prompt_template) - ( - eval_success, - eval_response, - eval_reasoning_content, - eval_model_name, - eval_tool_calls, - ) = await llm_api.generate_with_model_with_tools( - evaluation_prompt, - model_config=model_config.model_task_config.tool_use, - tool_options=[], # 最终评估阶段不提供工具 - request_type="memory.react.final", + evaluation_result = await llm_api.generate( + llm_api.LLMServiceRequest( + task_name="tool_use", + request_type="memory.react.final", + prompt=evaluation_prompt, + tool_options=[], + ) ) + eval_success = evaluation_result.success + eval_response = evaluation_result.completion.response if not eval_success: logger.error(f"ReAct Agent 最终评估阶段 LLM调用失败: {eval_response}")
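The registry below replaces hand-rolled parameter tuples with `normalize_tool_option`. A rough sketch of that normalization step under stated assumptions: the legacy 5-tuple input appears in the diff, but the JSON-schema output shape is inferred only from the `parameters_schema` field name, and plain type strings stand in for the `ToolParamType` enum:

```python
# Rough sketch of turning legacy parameter tuples
# (name, type, description, required, enum_values) into a JSON-schema-style
# parameters_schema. The output shape is an assumption based on the
# parameters_schema field used below; the real logic lives in
# src/llm_models/payload_content/tool_option.py.
from typing import Any, Dict, List, Optional, Tuple

LegacyParam = Tuple[str, str, str, bool, Optional[List[str]]]


def build_parameters_schema(params: List[LegacyParam]) -> Dict[str, Any]:
    properties: Dict[str, Any] = {}
    required: List[str] = []
    for name, json_type, description, is_required, enum_values in params:
        prop: Dict[str, Any] = {"type": json_type, "description": description}
        if enum_values:
            prop["enum"] = enum_values
        properties[name] = prop
        if is_required:
            required.append(name)
    return {"type": "object", "properties": properties, "required": required}


schema = build_parameters_schema([("query", "string", "检索关键词", True, None)])
assert schema["required"] == ["query"]
```

diff --git a/src/memory_system/retrieval_tools/tool_registry.py b/src/memory_system/retrieval_tools/tool_registry.py index 1e1fa62b..f2dd1f0d 100644 --- a/src/memory_system/retrieval_tools/tool_registry.py +++ b/src/memory_system/retrieval_tools/tool_registry.py @@ -1,11 +1,12 @@ -""" -工具注册系统 -提供统一的工具注册和管理接口 +"""工具注册系统。 + +提供统一的工具注册和管理接口。 """ -from typing import List, Dict, Any, Optional, Callable, Awaitable +from typing import Any, Awaitable, Callable, Dict, List, Optional + from src.common.logger import get_logger -from src.llm_models.payload_content.tool_option import ToolParamType +from src.llm_models.payload_content.tool_option import ToolParamType, normalize_tool_option logger = get_logger("memory_retrieval_tools") @@ -14,16 +15,19 @@ class MemoryRetrievalTool: """记忆检索工具基类""" def __init__( - self, name: str, description: str, parameters: List[Dict[str, Any]], execute_func: Callable[..., Awaitable[str]] - ): - """ - 初始化工具 + self, + name: str, + description: str, + parameters: List[Dict[str, Any]], + execute_func: Callable[..., Awaitable[str]], + ) -> None: + """初始化工具。 Args: - name: 工具名称 - description: 工具描述 - parameters: 参数定义列表,格式:[{"name": "param_name", "type": "string", "description": "参数描述", "required": True}] - execute_func: 执行函数,必须是异步函数 + name: 工具名称。 + description: 工具描述。 + parameters: 参数定义列表。 + execute_func: 执行函数,必须是异步函数。 """ self.name = name self.description = description @@ -44,20 +48,17 @@ class MemoryRetrievalTool: params_str = "\n".join(param_descriptions) if param_descriptions else " 无参数" return f"{self.name}({', '.join([p['name'] for p in self.parameters])}): {self.description}\n{params_str}" - async def execute(self, **kwargs) -> str: - """执行工具""" + async def execute(self, **kwargs: Any) -> str: + """执行工具。""" return await self.execute_func(**kwargs) def get_tool_definition(self) -> Dict[str, Any]: - """获取工具定义,用于LLM function calling + """获取规范化的工具定义。 Returns: - Dict[str, Any]: 工具定义字典,格式与BaseTool一致 - 格式: {"name": str, "description": str, "parameters": List[Tuple]} + Dict[str, Any]: 统一工具定义字典。 """ - # 转换参数格式为元组列表,格式与BaseTool一致 - # 格式: [("param_name", ToolParamType, "description", required, enum_values)] - param_tuples = [] + legacy_parameters: list[tuple[str, ToolParamType, str, bool, list[str] | None]] = [] for param in self.parameters: param_name = param.get("name", "") @@ -77,20 +78,27 @@ class MemoryRetrievalTool: } param_type = type_mapping.get(param_type_str, ToolParamType.STRING) - # 构建参数元组 - param_tuple = (param_name,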
param_type, param_desc, is_required, enum_values) - param_tuples.append(param_tuple) + legacy_parameters.append((param_name, param_type, param_desc, is_required, enum_values)) - # 构建工具定义,格式与BaseTool.get_tool_definition()一致 - tool_def = {"name": self.name, "description": self.description, "parameters": param_tuples} - - return tool_def + normalized_option = normalize_tool_option( + { + "name": self.name, + "description": self.description, + "parameters": legacy_parameters, + } + ) + return { + "name": normalized_option.name, + "description": normalized_option.description, + "parameters_schema": normalized_option.parameters_schema, + } class MemoryRetrievalToolRegistry: """工具注册器""" - def __init__(self): + def __init__(self) -> None: + """初始化工具注册器。""" self.tools: Dict[str, MemoryRetrievalTool] = {} def register_tool(self, tool: MemoryRetrievalTool) -> None: @@ -137,15 +145,18 @@ _tool_registry = MemoryRetrievalToolRegistry() def register_memory_retrieval_tool( - name: str, description: str, parameters: List[Dict[str, Any]], execute_func: Callable[..., Awaitable[str]] + name: str, + description: str, + parameters: List[Dict[str, Any]], + execute_func: Callable[..., Awaitable[str]], ) -> None: - """注册记忆检索工具的便捷函数 + """注册记忆检索工具的便捷函数。 Args: - name: 工具名称 - description: 工具描述 - parameters: 参数定义列表 - execute_func: 执行函数 + name: 工具名称。 + description: 工具描述。 + parameters: 参数定义列表。 + execute_func: 执行函数。 """ tool = MemoryRetrievalTool(name, description, parameters, execute_func) _tool_registry.register_tool(tool) diff --git a/src/person_info/person_info.py b/src/person_info/person_info.py index 15ef0049..cf8143c6 100644 --- a/src/person_info/person_info.py +++ b/src/person_info/person_info.py @@ -17,14 +17,14 @@ from src.common.data_models.person_info_data_model import dump_group_cardname_re from src.common.database.database import get_db_session from src.common.database.database_model import PersonInfo from src.common.logger import get_logger -from src.config.config import global_config, model_config -from src.llm_models.utils_model import LLMRequest +from src.config.config import global_config +from src.services.llm_service import LLMServiceClient logger = get_logger("person_info") -relation_selection_model = LLMRequest( - model_set=model_config.model_task_config.tool_use, request_type="relation_selection" +relation_selection_model = LLMServiceClient( + task_name="tool_use", request_type="relation_selection" ) @@ -578,7 +578,8 @@ class Person: <分类1><分类2><分类3>...... 如果没有相关的分类,请输出""" - response, _ = await relation_selection_model.generate_response_async(prompt) + generation_result = await relation_selection_model.generate_response(prompt) + response = generation_result.response # print(prompt) # print(response) category_list = extract_categories_from_response(response) @@ -600,7 +601,8 @@ class Person: 例如: <分类1><分类2><分类3>...... 
如果没有相关的分类,请输出""" - response, _ = await relation_selection_model.generate_response_async(prompt) + generation_result = await relation_selection_model.generate_response(prompt) + response = generation_result.response # print(prompt) # print(response) category_list = extract_categories_from_response(response) @@ -634,7 +636,9 @@ class Person: class PersonInfoManager: def __init__(self): self.person_name_list = {} - self.qv_name_llm = LLMRequest(model_set=model_config.model_task_config.utils, request_type="relation.qv_name") + self.qv_name_llm = LLMServiceClient( + task_name="utils", request_type="relation.qv_name" + ) try: with get_db_session() as _: pass @@ -737,7 +741,8 @@ class PersonInfoManager: "nickname": "昵称", "reason": "理由" }""" - response, _ = await self.qv_name_llm.generate_response_async(qv_name_prompt) + generation_result = await self.qv_name_llm.generate_response(qv_name_prompt) + response = generation_result.response # logger.info(f"取名提示词:{qv_name_prompt}\n取名回复:{response}") result = self._extract_json_from_text(response) diff --git a/src/plugin_runtime/capabilities/core.py b/src/plugin_runtime/capabilities/core.py index 9bb1755b..843b8ce0 100644 --- a/src/plugin_runtime/capabilities/core.py +++ b/src/plugin_runtime/capabilities/core.py @@ -1,33 +1,80 @@ -from typing import Any, Dict +from typing import Any, Dict, List from src.common.logger import get_logger from src.config.config import global_config -from src.llm_models.payload_content.tool_option import ToolCall logger = get_logger("plugin_runtime.integration") def _get_nested_config_value(source: Any, key: str, default: Any = None) -> Any: + """从嵌套对象或字典中读取配置值。 + + Args: + source: 配置对象或字典。 + key: 以点号分隔的路径。 + default: 未命中时返回的默认值。 + + Returns: + Any: 命中的值;读取失败时返回默认值。 + """ current = source try: for part in key.split("."): if isinstance(current, dict) and part in current: current = current[part] - elif hasattr(current, part): + continue + if hasattr(current, part): current = getattr(current, part) - else: - raise KeyError(part) + continue + raise KeyError(part) return current except Exception: return default +def _normalize_prompt_arg(prompt: Any) -> str | List[Dict[str, Any]]: + """校验并规范化插件传入的提示参数。 + + Args: + prompt: 原始提示参数。 + + Returns: + str | List[Dict[str, Any]]: 规范化后的提示输入。 + + Raises: + ValueError: 提示参数缺失或结构不受支持时抛出。 + """ + if isinstance(prompt, str): + if not prompt.strip(): + raise ValueError("缺少必要参数 prompt") + return prompt + if isinstance(prompt, list) and prompt: + for index, prompt_message in enumerate(prompt, start=1): + if not isinstance(prompt_message, dict): + raise ValueError(f"prompt 第 {index} 项必须为字典") + return prompt + raise ValueError("缺少必要参数 prompt") + + class RuntimeCoreCapabilityMixin: + """插件运行时的核心能力混入。""" + async def _cap_send_text(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: + """向指定流发送文本消息。 + + Args: + plugin_id: 插件标识。 + capability: 能力名称。 + args: 能力调用参数。 + + Returns: + Any: 能力执行结果。 + """ + del plugin_id, capability from src.services import send_service as send_api - text: str = args.get("text", "") - stream_id: str = args.get("stream_id", "") + text = str(args.get("text", "")) + stream_id = str(args.get("stream_id", "")) if not text or not stream_id: return {"success": False, "error": "缺少必要参数 text 或 stream_id"} @@ -35,20 +82,31 @@ class RuntimeCoreCapabilityMixin: result = await send_api.text_to_stream( text=text, stream_id=stream_id, - typing=args.get("typing", False), - set_reply=args.get("set_reply", False), - storage_message=args.get("storage_message", True), + 
typing=bool(args.get("typing", False)), + set_reply=bool(args.get("set_reply", False)), + storage_message=bool(args.get("storage_message", True)), ) return {"success": result} - except Exception as e: - logger.error(f"[cap.send.text] 执行失败: {e}", exc_info=True) - return {"success": False, "error": str(e)} + except Exception as exc: + logger.error(f"[cap.send.text] 执行失败: {exc}", exc_info=True) + return {"success": False, "error": str(exc)} async def _cap_send_emoji(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: + """向指定流发送表情图片。 + + Args: + plugin_id: 插件标识。 + capability: 能力名称。 + args: 能力调用参数。 + + Returns: + Any: 能力执行结果。 + """ + del plugin_id, capability from src.services import send_service as send_api - emoji_base64: str = args.get("emoji_base64", "") - stream_id: str = args.get("stream_id", "") + emoji_base64 = str(args.get("emoji_base64", "")) + stream_id = str(args.get("stream_id", "")) if not emoji_base64 or not stream_id: return {"success": False, "error": "缺少必要参数 emoji_base64 或 stream_id"} @@ -56,18 +114,29 @@ class RuntimeCoreCapabilityMixin: result = await send_api.emoji_to_stream( emoji_base64=emoji_base64, stream_id=stream_id, - storage_message=args.get("storage_message", True), + storage_message=bool(args.get("storage_message", True)), ) return {"success": result} - except Exception as e: - logger.error(f"[cap.send.emoji] 执行失败: {e}", exc_info=True) - return {"success": False, "error": str(e)} + except Exception as exc: + logger.error(f"[cap.send.emoji] 执行失败: {exc}", exc_info=True) + return {"success": False, "error": str(exc)} async def _cap_send_image(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: + """向指定流发送图片。 + + Args: + plugin_id: 插件标识。 + capability: 能力名称。 + args: 能力调用参数。 + + Returns: + Any: 能力执行结果。 + """ + del plugin_id, capability from src.services import send_service as send_api - image_base64: str = args.get("image_base64", "") - stream_id: str = args.get("stream_id", "") + image_base64 = str(args.get("image_base64", "")) + stream_id = str(args.get("stream_id", "")) if not image_base64 or not stream_id: return {"success": False, "error": "缺少必要参数 image_base64 或 stream_id"} @@ -75,18 +144,29 @@ class RuntimeCoreCapabilityMixin: result = await send_api.image_to_stream( image_base64=image_base64, stream_id=stream_id, - storage_message=args.get("storage_message", True), + storage_message=bool(args.get("storage_message", True)), ) return {"success": result} - except Exception as e: - logger.error(f"[cap.send.image] 执行失败: {e}", exc_info=True) - return {"success": False, "error": str(e)} + except Exception as exc: + logger.error(f"[cap.send.image] 执行失败: {exc}", exc_info=True) + return {"success": False, "error": str(exc)} async def _cap_send_command(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: + """向指定流发送命令消息。 + + Args: + plugin_id: 插件标识。 + capability: 能力名称。 + args: 能力调用参数。 + + Returns: + Any: 能力执行结果。 + """ + del plugin_id, capability from src.services import send_service as send_api - command = args.get("command", "") - stream_id: str = args.get("stream_id", "") + command = str(args.get("command", "")) + stream_id = str(args.get("stream_id", "")) if not command or not stream_id: return {"success": False, "error": "缺少必要参数 command 或 stream_id"} @@ -95,22 +175,33 @@ class RuntimeCoreCapabilityMixin: message_type="command", content=command, stream_id=stream_id, - storage_message=args.get("storage_message", True), - display_message=args.get("display_message", ""), + storage_message=bool(args.get("storage_message", 
True)), + display_message=str(args.get("display_message", "")), ) return {"success": result} - except Exception as e: - logger.error(f"[cap.send.command] 执行失败: {e}", exc_info=True) - return {"success": False, "error": str(e)} + except Exception as exc: + logger.error(f"[cap.send.command] 执行失败: {exc}", exc_info=True) + return {"success": False, "error": str(exc)} async def _cap_send_custom(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: + """向指定流发送自定义消息。 + + Args: + plugin_id: 插件标识。 + capability: 能力名称。 + args: 能力调用参数。 + + Returns: + Any: 能力执行结果。 + """ + del plugin_id, capability from src.services import send_service as send_api - message_type: str = args.get("message_type", "") or args.get("custom_type", "") + message_type = str(args.get("message_type", "") or args.get("custom_type", "")) content = args.get("content") if content is None: content = args.get("data", "") - stream_id: str = args.get("stream_id", "") + stream_id = str(args.get("stream_id", "")) if not message_type or not stream_id: return {"success": False, "error": "缺少必要参数 message_type 或 stream_id"} @@ -119,114 +210,116 @@ class RuntimeCoreCapabilityMixin: message_type=message_type, content=content, stream_id=stream_id, - display_message=args.get("display_message", ""), - typing=args.get("typing", False), - storage_message=args.get("storage_message", True), + display_message=str(args.get("display_message", "")), + typing=bool(args.get("typing", False)), + storage_message=bool(args.get("storage_message", True)), ) return {"success": result} - except Exception as e: - logger.error(f"[cap.send.custom] 执行失败: {e}", exc_info=True) - return {"success": False, "error": str(e)} + except Exception as exc: + logger.error(f"[cap.send.custom] 执行失败: {exc}", exc_info=True) + return {"success": False, "error": str(exc)} async def _cap_llm_generate(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: + """执行无工具的 LLM 生成能力。 + + Args: + plugin_id: 插件标识。 + capability: 能力名称。 + args: 能力调用参数。 + + Returns: + Any: 标准化后的 LLM 响应结构。 + """ + del capability from src.services import llm_service as llm_api - prompt: str = args.get("prompt", "") - if not prompt: - return {"success": False, "error": "缺少必要参数 prompt"} - - model_name: str = args.get("model", "") or args.get("model_name", "") - temperature = args.get("temperature") - max_tokens = args.get("max_tokens") - try: - models = llm_api.get_available_models() - if model_name and model_name in models: - model_config = models[model_name] - else: - if not models: - return {"success": False, "error": "没有可用的模型配置"} - model_config = next(iter(models.values())) - - success, response, reasoning, used_model = await llm_api.generate_with_model( - prompt=prompt, - model_config=model_config, - request_type=f"plugin.{plugin_id}", - temperature=temperature, - max_tokens=max_tokens, + prompt = _normalize_prompt_arg(args.get("prompt")) + task_name = llm_api.resolve_task_name(str(args.get("model", "") or args.get("model_name", ""))) + result = await llm_api.generate( + llm_api.LLMServiceRequest( + task_name=task_name, + request_type=f"plugin.{plugin_id}", + prompt=prompt, + temperature=args.get("temperature"), + max_tokens=args.get("max_tokens"), + ) ) - return { - "success": success, - "response": response, - "reasoning": reasoning, - "model_name": used_model, - } - except Exception as e: - logger.error(f"[cap.llm.generate] 执行失败: {e}", exc_info=True) - return {"success": False, "error": str(e)} + return result.to_capability_payload() + except Exception as exc: + 
logger.error(f"[cap.llm.generate] 执行失败: {exc}", exc_info=True) + return {"success": False, "error": str(exc)} async def _cap_llm_generate_with_tools(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: + """执行带工具的 LLM 生成能力。 + + Args: + plugin_id: 插件标识。 + capability: 能力名称。 + args: 能力调用参数。 + + Returns: + Any: 标准化后的 LLM 响应结构。 + """ + del capability from src.services import llm_service as llm_api - prompt: str = args.get("prompt", "") - if not prompt: - return {"success": False, "error": "缺少必要参数 prompt"} - - model_name: str = args.get("model", "") or args.get("model_name", "") tool_options = args.get("tools") or args.get("tool_options") - temperature = args.get("temperature") - max_tokens = args.get("max_tokens") + if tool_options is not None and not isinstance(tool_options, list): + return {"success": False, "error": "tools 必须为列表"} try: - models = llm_api.get_available_models() - if model_name and model_name in models: - model_config = models[model_name] - else: - if not models: - return {"success": False, "error": "没有可用的模型配置"} - model_config = next(iter(models.values())) - - success, response, reasoning, used_model, tool_calls = await llm_api.generate_with_model_with_tools( - prompt=prompt, - model_config=model_config, - tool_options=tool_options, - request_type=f"plugin.{plugin_id}", - temperature=temperature, - max_tokens=max_tokens, + prompt = _normalize_prompt_arg(args.get("prompt")) + task_name = llm_api.resolve_task_name(str(args.get("model", "") or args.get("model_name", ""))) + result = await llm_api.generate( + llm_api.LLMServiceRequest( + task_name=task_name, + request_type=f"plugin.{plugin_id}", + prompt=prompt, + tool_options=tool_options, + temperature=args.get("temperature"), + max_tokens=args.get("max_tokens"), + ) ) - serialized_tool_calls = None - if tool_calls: - serialized_tool_calls = [ - { - "id": tool_call.call_id, - "function": {"name": tool_call.func_name, "arguments": tool_call.args or {}}, - } - for tool_call in tool_calls - if isinstance(tool_call, ToolCall) - ] - return { - "success": success, - "response": response, - "reasoning": reasoning, - "model_name": used_model, - "tool_calls": serialized_tool_calls, - } - except Exception as e: - logger.error(f"[cap.llm.generate_with_tools] 执行失败: {e}", exc_info=True) - return {"success": False, "error": str(e)} + return result.to_capability_payload() + except Exception as exc: + logger.error(f"[cap.llm.generate_with_tools] 执行失败: {exc}", exc_info=True) + return {"success": False, "error": str(exc)} async def _cap_llm_get_available_models(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: + """获取当前宿主可用的模型任务列表。 + + Args: + plugin_id: 插件标识。 + capability: 能力名称。 + args: 能力调用参数。 + + Returns: + Any: 可用模型列表。 + """ + del plugin_id, capability, args from src.services import llm_service as llm_api try: models = llm_api.get_available_models() return {"success": True, "models": list(models.keys())} - except Exception as e: - logger.error(f"[cap.llm.get_available_models] 执行失败: {e}", exc_info=True) - return {"success": False, "error": str(e)} + except Exception as exc: + logger.error(f"[cap.llm.get_available_models] 执行失败: {exc}", exc_info=True) + return {"success": False, "error": str(exc)} async def _cap_config_get(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: - key: str = args.get("key", "") + """读取宿主全局配置中的单个字段。 + + Args: + plugin_id: 插件标识。 + capability: 能力名称。 + args: 能力调用参数。 + + Returns: + Any: 配置读取结果。 + """ + del plugin_id, capability + key = str(args.get("key", "")) default = 
args.get("default") if not key: return {"success": False, "value": None, "error": "缺少必要参数 key"} @@ -234,37 +327,57 @@ class RuntimeCoreCapabilityMixin: try: value = _get_nested_config_value(global_config, key, default) return {"success": True, "value": value} - except Exception as e: - return {"success": False, "value": None, "error": str(e)} + except Exception as exc: + return {"success": False, "value": None, "error": str(exc)} async def _cap_config_get_plugin(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: + """读取指定插件的配置。 + + Args: + plugin_id: 当前插件标识。 + capability: 能力名称。 + args: 能力调用参数。 + + Returns: + Any: 配置读取结果。 + """ + del capability from src.plugin_runtime.component_query import component_query_service - plugin_name: str = args.get("plugin_name", plugin_id) - key: str = args.get("key", "") + plugin_name = str(args.get("plugin_name", plugin_id)) + key = str(args.get("key", "")) default = args.get("default") try: config = component_query_service.get_plugin_config(plugin_name) if config is None: return {"success": False, "value": default, "error": f"未找到插件 {plugin_name} 的配置"} - if key: value = _get_nested_config_value(config, key, default) return {"success": True, "value": value} - return {"success": True, "value": config} - except Exception as e: - return {"success": False, "value": default, "error": str(e)} + except Exception as exc: + return {"success": False, "value": default, "error": str(exc)} async def _cap_config_get_all(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: + """读取指定插件的全部配置。 + + Args: + plugin_id: 当前插件标识。 + capability: 能力名称。 + args: 能力调用参数。 + + Returns: + Any: 配置读取结果。 + """ + del capability from src.plugin_runtime.component_query import component_query_service - plugin_name: str = args.get("plugin_name", plugin_id) + plugin_name = str(args.get("plugin_name", plugin_id)) try: config = component_query_service.get_plugin_config(plugin_name) if config is None: return {"success": True, "value": {}} return {"success": True, "value": config} - except Exception as e: - return {"success": False, "value": {}, "error": str(e)} + except Exception as exc: + return {"success": False, "value": {}, "error": str(exc)} diff --git a/src/plugin_runtime/component_query.py b/src/plugin_runtime/component_query.py index 7d23d202..5a6c39f5 100644 --- a/src/plugin_runtime/component_query.py +++ b/src/plugin_runtime/component_query.py @@ -10,7 +10,7 @@ from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional, Tupl from src.common.logger import get_logger from src.core.types import ActionActivationType, ActionInfo, CommandInfo, ComponentInfo, ComponentType, ToolInfo -from src.llm_models.payload_content.tool_option import ToolParamType +from src.llm_models.payload_content.tool_option import normalize_tool_option if TYPE_CHECKING: from src.plugin_runtime.host.component_registry import ActionEntry, CommandEntry, ComponentEntry, ToolEntry @@ -28,13 +28,6 @@ _HOST_COMPONENT_TYPE_MAP: Dict[ComponentType, str] = { ComponentType.COMMAND: "COMMAND", ComponentType.TOOL: "TOOL", } -_TOOL_PARAM_TYPE_MAP: Dict[str, ToolParamType] = { - "string": ToolParamType.STRING, - "integer": ToolParamType.INTEGER, - "float": ToolParamType.FLOAT, - "boolean": ToolParamType.BOOLEAN, - "bool": ToolParamType.BOOLEAN, -} class ComponentQueryService: @@ -171,11 +164,9 @@ class ComponentQueryService: return ActionInfo( name=entry.name, - component_type=ComponentType.ACTION, description=str(metadata.get("description", "") or ""), enabled=bool(entry.enabled), 
diff --git a/src/plugin_runtime/component_query.py b/src/plugin_runtime/component_query.py
index 7d23d202..5a6c39f5 100644
--- a/src/plugin_runtime/component_query.py
+++ b/src/plugin_runtime/component_query.py
@@ -10,7 +10,7 @@ from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional, Tupl
 from src.common.logger import get_logger
 from src.core.types import ActionActivationType, ActionInfo, CommandInfo, ComponentInfo, ComponentType, ToolInfo
-from src.llm_models.payload_content.tool_option import ToolParamType
+from src.llm_models.payload_content.tool_option import normalize_tool_option
 
 if TYPE_CHECKING:
     from src.plugin_runtime.host.component_registry import ActionEntry, CommandEntry, ComponentEntry, ToolEntry
@@ -28,13 +28,6 @@ _HOST_COMPONENT_TYPE_MAP: Dict[ComponentType, str] = {
     ComponentType.COMMAND: "COMMAND",
     ComponentType.TOOL: "TOOL",
 }
-_TOOL_PARAM_TYPE_MAP: Dict[str, ToolParamType] = {
-    "string": ToolParamType.STRING,
-    "integer": ToolParamType.INTEGER,
-    "float": ToolParamType.FLOAT,
-    "boolean": ToolParamType.BOOLEAN,
-    "bool": ToolParamType.BOOLEAN,
-}
 
 
 class ComponentQueryService:
@@ -171,11 +164,9 @@ class ComponentQueryService:
 
         return ActionInfo(
             name=entry.name,
-            component_type=ComponentType.ACTION,
             description=str(metadata.get("description", "") or ""),
             enabled=bool(entry.enabled),
             plugin_name=entry.plugin_id,
-            metadata=metadata,
             action_parameters=action_parameters,
             action_require=action_require,
             associated_types=associated_types,
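The ActionInfo construction above drops the component_type and metadata keyword arguments, and the CommandInfo/ToolInfo calls in the next hunks do the same — consistent with a slimmer info model in maibot-plugin-sdk 2.1, where the component type is implied by the class itself. A rough sketch of the shape these constructors now imply (illustrative only; the real definitions are in src.core.types):

    from dataclasses import dataclass, field
    from typing import Any, Dict, List

    @dataclass
    class ActionInfo:  # illustrative, not the actual class
        name: str
        description: str
        enabled: bool
        plugin_name: str
        action_parameters: Dict[str, Any] = field(default_factory=dict)
        action_require: List[str] = field(default_factory=list)
        associated_types: List[str] = field(default_factory=list)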
@@ -202,72 +193,48 @@ class ComponentQueryService:
         metadata = dict(entry.metadata)
         return CommandInfo(
             name=entry.name,
-            component_type=ComponentType.COMMAND,
             description=str(metadata.get("description", "") or ""),
             enabled=bool(entry.enabled),
             plugin_name=entry.plugin_id,
-            metadata=metadata,
-            command_pattern=str(metadata.get("command_pattern", "") or ""),
         )
 
     @staticmethod
-    def _coerce_tool_param_type(raw_value: Any) -> ToolParamType:
-        """规范化工具参数类型。
-
-        Args:
-            raw_value: 原始工具参数类型值。
-
-        Returns:
-            ToolParamType: 规范化后的工具参数类型。
-        """
-
-        normalized_value = str(raw_value or "").strip().lower()
-        return _TOOL_PARAM_TYPE_MAP.get(normalized_value, ToolParamType.STRING)
-
-    @staticmethod
-    def _build_tool_parameters(entry: "ToolEntry") -> list[tuple[str, ToolParamType, str, bool, list[str] | None]]:
-        """将运行时工具参数元数据转换为核心 ToolInfo 参数列表。
+    def _build_tool_definition(entry: "ToolEntry") -> dict[str, Any]:
+        """将运行时 Tool 条目转换为原始工具定义字典。
 
         Args:
             entry: 插件运行时中的 Tool 条目。
 
         Returns:
-            list[tuple[str, ToolParamType, str, bool, list[str] | None]]: 转换后的参数列表。
+            dict[str, Any]: 可交给 `normalize_tool_option()` 的原始工具定义。
         """
+        raw_definition: dict[str, Any] = {
+            "name": entry.name,
+            "description": entry.description,
+        }
+        if isinstance(entry.parameters_raw, dict) and entry.parameters_raw:
+            raw_definition["parameters_schema"] = entry.parameters_raw
+            return raw_definition
+        if isinstance(entry.parameters, list) and entry.parameters:
+            raw_definition["parameters"] = entry.parameters
+            return raw_definition
+        if isinstance(entry.parameters_raw, list) and entry.parameters_raw:
+            raw_definition["parameters"] = entry.parameters_raw
+            return raw_definition
+        return raw_definition
 
-        structured_parameters = entry.parameters if isinstance(entry.parameters, list) else []
-        if not structured_parameters and isinstance(entry.parameters_raw, dict):
-            structured_parameters = [
-                {"name": key, **value}
-                for key, value in entry.parameters_raw.items()
-                if isinstance(value, dict)
-            ]
+    @staticmethod
+    def _build_tool_parameters_schema(entry: "ToolEntry") -> dict[str, Any] | None:
+        """将运行时 Tool 条目转换为对象级参数 Schema。
 
-        normalized_parameters: list[tuple[str, ToolParamType, str, bool, list[str] | None]] = []
-        for parameter in structured_parameters:
-            if not isinstance(parameter, dict):
-                continue
+        Args:
+            entry: 插件运行时中的 Tool 条目。
 
-            parameter_name = str(parameter.get("name", "") or "").strip()
-            if not parameter_name:
-                continue
-
-            enum_values = parameter.get("enum")
-            normalized_enum_values = (
-                [str(item) for item in enum_values if item is not None]
-                if isinstance(enum_values, list)
-                else None
-            )
-            normalized_parameters.append(
-                (
-                    parameter_name,
-                    ComponentQueryService._coerce_tool_param_type(parameter.get("param_type") or parameter.get("type")),
-                    str(parameter.get("description", "") or ""),
-                    bool(parameter.get("required", True)),
-                    normalized_enum_values,
-                )
-            )
-        return normalized_parameters
+        Returns:
+            dict[str, Any] | None: 规范化后的对象级参数 Schema。
+        """
+        normalized_option = normalize_tool_option(ComponentQueryService._build_tool_definition(entry))
+        return normalized_option.parameters_schema
 
     @staticmethod
     def _build_tool_info(entry: "ToolEntry") -> ToolInfo:
@@ -282,13 +249,10 @@ class ComponentQueryService:
 
         return ToolInfo(
             name=entry.name,
-            component_type=ComponentType.TOOL,
             description=entry.description,
             enabled=bool(entry.enabled),
             plugin_name=entry.plugin_id,
-            metadata=dict(entry.metadata),
-            tool_parameters=ComponentQueryService._build_tool_parameters(entry),
-            tool_description=entry.description,
+            parameters_schema=ComponentQueryService._build_tool_parameters_schema(entry),
         )
 
     @staticmethod
diff --git a/src/plugin_runtime/host/component_registry.py b/src/plugin_runtime/host/component_registry.py
index 8f995e2a..bc1fbe6f 100644
--- a/src/plugin_runtime/host/component_registry.py
+++ b/src/plugin_runtime/host/component_registry.py
@@ -91,7 +91,7 @@ class ToolEntry(ComponentEntry):
     def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
         self.description: str = metadata.get("description", "")
         self.parameters: List[Dict[str, Any]] = metadata.get("parameters", [])
-        self.parameters_raw: List[Dict[str, Any]] = metadata.get("parameters_raw", [])
+        self.parameters_raw: Dict[str, Any] | List[Dict[str, Any]] = metadata.get("parameters_raw", {})
         super().__init__(name, component_type, plugin_id, metadata)
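With the ToolParamType mapping gone, every tool declaration funnels through normalize_tool_option(), and _build_tool_definition now prefers a dict-shaped parameters_raw (a ready JSON Schema) over the structured parameters list — the old code checked them in the opposite order, which is worth calling out in review. Assuming normalize_tool_option accepts both raw shapes the builder emits, these two declarations should normalize to the same object-level schema (the flat-list field names are an assumption):

    raw_schema_style = {
        "name": "weather_query",
        "description": "查询天气",
        "parameters_schema": {
            "type": "object",
            "properties": {"city": {"type": "string", "description": "城市名"}},
            "required": ["city"],
        },
    }
    raw_flat_style = {
        "name": "weather_query",
        "description": "查询天气",
        "parameters": [
            {"name": "city", "type": "string", "description": "城市名", "required": True},
        ],
    }
    # normalize_tool_option(raw_schema_style).parameters_schema should agree
    # with normalize_tool_option(raw_flat_style).parameters_schema.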
diff --git a/src/services/llm_service.py b/src/services/llm_service.py
index 2927b5c1..de116507 100644
--- a/src/services/llm_service.py
+++ b/src/services/llm_service.py
@@ -1,191 +1,492 @@
-"""LLM 服务模块
+"""LLM 服务层。
 
-提供与 LLM 模型交互的核心功能。
+该模块负责在宿主侧收口统一的 LLM 服务请求模型,并将其转发到
+`src.llm_models` 中的底层请求调度器。
 """
 
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Tuple
+import json
+
+from src.common.data_models.llm_service_data_models import (
+    LLMAudioTranscriptionResult,
+    LLMEmbeddingResult,
+    LLMGenerationOptions,
+    LLMImageOptions,
+    LLMResponseResult,
+    LLMServiceRequest,
+    LLMServiceResult,
+    MessageFactory,
+    PromptInput,
+    PromptMessage,
+)
 from src.common.logger import get_logger
 from src.config.config import config_manager
 from src.config.model_configs import TaskConfig
 from src.llm_models.model_client.base_client import BaseClient
-from src.llm_models.payload_content.message import Message
+from src.llm_models.payload_content.message import Message, MessageBuilder, RoleType
 from src.llm_models.payload_content.tool_option import ToolCall
-from src.llm_models.utils_model import LLMRequest
+from src.llm_models.utils_model import LLMOrchestrator
 
 logger = get_logger("llm_service")
 
 
+class LLMServiceClient:
+    """面向上层模块的 LLM 服务对象式门面。
-
-async def _generate_response(
-    model_config: TaskConfig,
-    request_type: str,
-    prompt: Optional[str] = None,
-    message_factory: Optional[Callable[[BaseClient], List[Message]]] = None,
-    tool_options: Optional[List[Dict[str, Any]]] = None,
-    temperature: Optional[float] = None,
-    max_tokens: Optional[int] = None,
-) -> Tuple[str, str, str, List[ToolCall] | None]:
-    llm_request = LLMRequest(model_set=model_config, request_type=request_type)
+
+    当前推荐优先使用以下正式接口:
+    - `generate_response`
+    - `generate_response_with_messages`
+    - `generate_response_for_image`
+    - `transcribe_audio`
+    - `embed_text`
+    """
 
-    if message_factory is not None:
-        response, (reasoning_content, model_name, tool_call) = await llm_request.generate_response_with_message_async(
-            message_factory=message_factory,
-            tools=tool_options,
-            temperature=temperature,
-            max_tokens=max_tokens,
+    def __init__(self, task_name: str, request_type: str = "") -> None:
+        """初始化 LLM 服务门面。
+
+        Args:
+            task_name: 任务配置名称,对应 `model_task_config` 下的字段名。
+            request_type: 当前请求的业务类型标识。
+        """
+        self.task_name = resolve_task_name(task_name)
+        self.request_type = request_type
+        self._orchestrator = LLMOrchestrator(task_name=self.task_name, request_type=request_type)
+
+    @staticmethod
+    def _normalize_generation_options(options: LLMGenerationOptions | None = None) -> LLMGenerationOptions:
+        """规范化文本生成选项。
+
+        Args:
+            options: 原始生成选项。
+
+        Returns:
+            LLMGenerationOptions: 可直接用于执行请求的完整选项对象。
+        """
+        if options is None:
+            return LLMGenerationOptions()
+        return options
+
+    @staticmethod
+    def _normalize_image_options(options: LLMImageOptions | None = None) -> LLMImageOptions:
+        """规范化图像理解选项。
+
+        Args:
+            options: 原始图像理解选项。
+
+        Returns:
+            LLMImageOptions: 可直接用于执行请求的完整选项对象。
+        """
+        if options is None:
+            return LLMImageOptions()
+        return options
+
+    async def generate_response(
+        self,
+        prompt: str,
+        options: LLMGenerationOptions | None = None,
+    ) -> LLMResponseResult:
+        """生成单轮文本响应。
+
+        Args:
+            prompt: 文本提示词。
+            options: 文本生成选项。
+
+        Returns:
+            LLMResponseResult: 统一文本生成结果。
+        """
+        active_options = self._normalize_generation_options(options)
+        return await self._orchestrator.generate_response_async(
+            prompt=prompt,
+            temperature=active_options.temperature,
+            max_tokens=active_options.max_tokens,
+            tools=active_options.tool_options,
+            response_format=active_options.response_format,
+            raise_when_empty=active_options.raise_when_empty,
+            interrupt_flag=active_options.interrupt_flag,
         )
-        return response, reasoning_content, model_name, tool_call
 
-    if prompt is None:
-        raise ValueError("prompt 与 message_factory 不能同时为空")
+    async def generate_response_with_messages(
+        self,
+        message_factory: MessageFactory,
+        options: LLMGenerationOptions | None = None,
+    ) -> LLMResponseResult:
+        """基于消息工厂生成响应。
 
-    response, (reasoning_content, model_name, tool_call) = await llm_request.generate_response_async(
-        prompt,
-        tools=tool_options,
-        temperature=temperature,
-        max_tokens=max_tokens,
-    )
-    return response, reasoning_content, model_name, tool_call
+        Args:
+            message_factory: 消息工厂,会根据客户端能力构建消息列表。
+            options: 文本生成选项。
+
+        Returns:
+            LLMResponseResult: 统一文本生成结果。
+        """
+        active_options = self._normalize_generation_options(options)
+        return await self._orchestrator.generate_response_with_message_async(
+            message_factory=message_factory,
+            temperature=active_options.temperature,
+            max_tokens=active_options.max_tokens,
+            tools=active_options.tool_options,
+            response_format=active_options.response_format,
+            raise_when_empty=active_options.raise_when_empty,
+            interrupt_flag=active_options.interrupt_flag,
+        )
+
+    async def generate_response_for_image(
+        self,
+        prompt: str,
+        image_base64: str,
+        image_format: str,
+        options: LLMImageOptions | None = None,
+    ) -> LLMResponseResult:
+        """为图像内容生成响应。
+
+        Args:
+            prompt: 文本提示词。
+            image_base64: 图像的 Base64 编码字符串。
+            image_format: 图像格式,例如 ``png``、``jpeg``。
+            options: 图像理解选项。
+
+        Returns:
+            LLMResponseResult: 统一文本生成结果。
+        """
+        active_options = self._normalize_image_options(options)
+        return await self._orchestrator.generate_response_for_image(
+            prompt=prompt,
+            image_base64=image_base64,
+            image_format=image_format,
+            temperature=active_options.temperature,
+            max_tokens=active_options.max_tokens,
+            interrupt_flag=active_options.interrupt_flag,
+        )
+
+    async def transcribe_audio(self, voice_base64: str) -> LLMAudioTranscriptionResult:
+        """执行音频转写请求。
+
+        Args:
+            voice_base64: 音频的 Base64 编码字符串。
+
+        Returns:
+            LLMAudioTranscriptionResult: 音频转写结果对象。
+        """
+        return await self._orchestrator.generate_response_for_voice(voice_base64)
+
+    async def embed_text(self, embedding_input: str) -> LLMEmbeddingResult:
+        """生成文本嵌入向量。
+
+        Args:
+            embedding_input: 待编码的文本。
+
+        Returns:
+            LLMEmbeddingResult: 向量生成结果对象。
+        """
+        return await self._orchestrator.get_embedding(embedding_input)
 
 
 def get_available_models() -> Dict[str, TaskConfig]:
-    """获取所有可用的模型配置
+    """获取所有可用模型配置。
 
     Returns:
-        Dict[str, Any]: 模型配置字典,key为模型名称,value为模型配置
+        Dict[str, TaskConfig]: 以模型任务名为键的配置映射。
     """
     try:
         models = config_manager.get_model_config().model_task_config
-        attrs = dir(models)
-        rets: Dict[str, TaskConfig] = {}
-        for attr in attrs:
-            if not attr.startswith("__"):
-                try:
-                    value = getattr(models, attr)
-                    if not callable(value) and isinstance(value, TaskConfig):
-                        rets[attr] = value
-                except Exception as e:
-                    logger.debug(f"[LLMService] 获取属性 {attr} 失败: {e}")
-                    continue
-        return rets
-
-    except Exception as e:
-        logger.error(f"[LLMService] 获取可用模型失败: {e}")
+        available_models: Dict[str, TaskConfig] = {}
+        for attr_name in dir(models):
+            if attr_name.startswith("__"):
+                continue
+            try:
+                attr_value = getattr(models, attr_name)
+            except Exception as exc:
+                logger.debug(f"[LLMService] 获取属性 {attr_name} 失败: {exc}")
+                continue
+            if not callable(attr_value) and isinstance(attr_value, TaskConfig):
+                available_models[attr_name] = attr_value
+        return available_models
+    except Exception as exc:
+        logger.error(f"[LLMService] 获取可用模型失败: {exc}")
         return {}
 
 
-async def generate_with_model(
-    prompt: str,
-    model_config: TaskConfig,
-    request_type: str = "plugin.generate",
-    temperature: Optional[float] = None,
-    max_tokens: Optional[int] = None,
-) -> Tuple[bool, str, str, str]:
-    """使用指定模型生成内容
+def resolve_task_name(task_name: str = "") -> str:
+    """根据名称解析任务配置名。
 
     Args:
-        prompt: 提示词
-        model_config: 模型配置(从 get_available_models 获取的模型配置)
-        request_type: 请求类型标识
+        task_name: 目标任务配置名;为空时返回首个可用任务名。
 
     Returns:
-        Tuple[bool, str, str, str]: (是否成功, 生成的内容, 推理过程, 模型名称)
+        str: 解析得到的任务配置名。
+
+    Raises:
+        RuntimeError: 当前没有任何可用模型配置。
+        ValueError: 指定名称不存在时抛出。
     """
-    try:
-        logger.debug(f"[LLMService] 完整提示词: {prompt}")
-        response, reasoning_content, model_name, _ = await _generate_response(
-            model_config=model_config,
-            request_type=request_type,
-            prompt=prompt,
-            temperature=temperature,
-            max_tokens=max_tokens,
-        )
-        return True, response, reasoning_content, model_name
-
-    except Exception as e:
-        error_msg = f"生成内容时出错: {str(e)}"
-        logger.error(f"[LLMService] {error_msg}")
-        return False, error_msg, "", ""
+    models = get_available_models()
+    if not models:
+        raise RuntimeError("没有可用的模型配置")
+    normalized_task_name = task_name.strip()
+    if not normalized_task_name:
+        return next(iter(models.keys()))
+    if normalized_task_name not in models:
+        raise ValueError(f"未找到名为 `{normalized_task_name}` 的模型配置")
+    return normalized_task_name
 
 
-async def generate_with_model_with_tools(
-    prompt: str,
-    model_config: TaskConfig,
-    tool_options: List[Dict[str, Any]] | None = None,
-    request_type: str = "plugin.generate",
-    temperature: Optional[float] = None,
-    max_tokens: Optional[int] = None,
-) -> Tuple[bool, str, str, str, List[ToolCall] | None]:
-    """使用指定模型和工具生成内容
+def _normalize_role(role_name: str) -> RoleType:
+    """将原始角色字符串转换为内部角色枚举。
 
     Args:
-        prompt: 提示词
-        model_config: 模型配置(从 get_available_models 获取的模型配置)
-        tool_options: 工具选项列表
-        request_type: 请求类型标识
-        temperature: 温度参数
-        max_tokens: 最大token数
+        role_name: 原始角色名称。
 
     Returns:
-        Tuple[bool, str, str, str, List[ToolCall] | None]: (是否成功, 生成的内容, 推理过程, 模型名称, 工具调用列表)
+        RoleType: 规范化后的角色枚举。
+
+    Raises:
+        ValueError: 角色类型不受支持时抛出。
     """
+    normalized_role_name = role_name.strip().lower()
     try:
-        model_name_list = model_config.model_list
-        logger.info(f"使用模型{model_name_list}生成内容")
-        logger.debug(f"完整提示词: {prompt}")
-
-        response, reasoning_content, model_name, tool_call = await _generate_response(
-            model_config=model_config,
-            request_type=request_type,
-            prompt=prompt,
-            tool_options=tool_options,
-            temperature=temperature,
-            max_tokens=max_tokens,
-        )
-        return True, response, reasoning_content, model_name, tool_call
-
-    except Exception as e:
-        error_msg = f"生成内容时出错: {str(e)}"
-        logger.error(f"[LLMService] {error_msg}")
-        return False, error_msg, "", "", None
+        return RoleType(normalized_role_name)
+    except ValueError as exc:
+        raise ValueError(f"不支持的消息角色: {role_name}") from exc
 
 
-async def generate_with_model_with_tools_by_message_factory(
-    message_factory: Callable[[BaseClient], List[Message]],
-    model_config: TaskConfig,
-    tool_options: List[Dict[str, Any]] | None = None,
-    request_type: str = "plugin.generate",
-    temperature: Optional[float] = None,
-    max_tokens: Optional[int] = None,
-) -> Tuple[bool, str, str, str, List[ToolCall] | None]:
-    """使用指定模型和工具生成内容(通过消息工厂构建消息列表)
+def _parse_data_url_image(image_url: str) -> Tuple[str, str]:
+    """解析 Data URL 形式的图片内容。
 
     Args:
-        message_factory: 消息工厂函数
-        model_config: 模型配置
-        tool_options: 工具选项列表
-        request_type: 请求类型标识
-        temperature: 温度参数
-        max_tokens: 最大token数
+        image_url: 图片 URL。
 
     Returns:
-        Tuple[bool, str, str, str, List[ToolCall] | None]: (是否成功, 生成的内容, 推理过程, 模型名称, 工具调用列表)
+        Tuple[str, str]: `(图片格式, Base64 数据)`。
+
+    Raises:
+        ValueError: 输入不是受支持的 Data URL 时抛出。
     """
-    try:
-        model_name_list = model_config.model_list
-        logger.info(f"使用模型 {model_name_list} 生成内容")
+    if not image_url.startswith("data:image/") or ";base64," not in image_url:
+        raise ValueError("仅支持 Data URL 形式的图片输入")
+    prefix, image_base64 = image_url.split(";base64,", maxsplit=1)
+    image_format = prefix.removeprefix("data:image/")
+    if not image_format or not image_base64:
+        raise ValueError("图片 Data URL 不完整")
+    return image_format, image_base64
 
-        response, reasoning_content, model_name, tool_call = await _generate_response(
-            model_config=model_config,
-            request_type=request_type,
-            message_factory=message_factory,
-            tool_options=tool_options,
-            temperature=temperature,
-            max_tokens=max_tokens,
+
+def _append_content_parts(message_builder: MessageBuilder, content: Any) -> None:
+    """将原始消息内容追加到内部消息构建器。
+
+    Args:
+        message_builder: 目标消息构建器。
+        content: 原始消息内容。
+
+    Raises:
+        ValueError: 消息内容结构不受支持时抛出。
+    """
+    if isinstance(content, str):
+        message_builder.add_text_content(content)
+        return
+
+    content_items: List[Any]
+    if isinstance(content, list):
+        content_items = content
+    elif isinstance(content, dict):
+        content_items = [content]
+    else:
+        raise ValueError("消息内容必须为字符串、字典或列表")
+
+    for content_item in content_items:
+        if isinstance(content_item, str):
+            message_builder.add_text_content(content_item)
+            continue
+        if not isinstance(content_item, dict):
+            raise ValueError("消息内容列表中仅支持字符串或字典片段")
+
+        part_type = str(content_item.get("type", "text")).strip().lower()
+        if part_type == "text":
+            text_content = content_item.get("text")
+            if not isinstance(text_content, str):
+                raise ValueError("文本片段缺少 `text` 字段")
+            message_builder.add_text_content(text_content)
+            continue
+
+        if part_type in {"image", "image_url", "input_image"}:
+            image_url = content_item.get("image_url")
+            if isinstance(image_url, dict):
+                image_url = image_url.get("url")
+            if isinstance(image_url, str):
+                image_format, image_base64 = _parse_data_url_image(image_url)
+                message_builder.add_image_content(image_format=image_format, image_base64=image_base64)
+                continue
+
+            image_format = content_item.get("image_format")
+            image_base64 = content_item.get("image_base64")
+            if isinstance(image_format, str) and isinstance(image_base64, str):
+                message_builder.add_image_content(image_format=image_format, image_base64=image_base64)
+                continue
+            raise ValueError("图片片段缺少可识别的图片数据")
+
+        raise ValueError(f"不支持的消息片段类型: {part_type}")
+
+
+def _normalize_tool_arguments(arguments: Any) -> Dict[str, Any] | None:
+    """将原始工具参数规范化为字典。
+
+    Args:
+        arguments: 原始工具参数。
+
+    Returns:
+        Dict[str, Any] | None: 规范化后的参数字典。
+    """
+    if arguments is None:
+        return None
+    if isinstance(arguments, dict):
+        return arguments
+    if isinstance(arguments, str):
+        stripped_arguments = arguments.strip()
+        if not stripped_arguments:
+            return {}
+        try:
+            parsed_arguments = json.loads(stripped_arguments)
+        except json.JSONDecodeError:
+            return {"raw_arguments": arguments}
+        if isinstance(parsed_arguments, dict):
+            return parsed_arguments
+        return {"value": parsed_arguments}
+    return {"value": arguments}
+
+
+def _build_tool_calls(raw_tool_calls: Any) -> List[ToolCall] | None:
+    """从原始消息中提取工具调用列表。
+
+    Args:
+        raw_tool_calls: 原始工具调用结构。
+
+    Returns:
+        List[ToolCall] | None: 规范化后的工具调用列表。
+
+    Raises:
+        ValueError: 工具调用结构缺失必要字段时抛出。
+    """
+    if raw_tool_calls is None:
+        return None
+    if not isinstance(raw_tool_calls, list):
+        raise ValueError("`tool_calls` 必须为列表")
+
+    tool_calls: List[ToolCall] = []
+    for raw_tool_call in raw_tool_calls:
+        if not isinstance(raw_tool_call, dict):
+            raise ValueError("工具调用项必须为字典")
+
+        function_info = raw_tool_call.get("function")
+        if isinstance(function_info, dict):
+            func_name = function_info.get("name")
+            arguments = function_info.get("arguments")
+        else:
+            func_name = raw_tool_call.get("name") or raw_tool_call.get("func_name")
+            arguments = raw_tool_call.get("arguments") or raw_tool_call.get("args")
+
+        call_id = raw_tool_call.get("id") or raw_tool_call.get("call_id")
+        if not isinstance(call_id, str) or not isinstance(func_name, str):
+            raise ValueError("工具调用缺少 `id` 或函数名称")
+
+        tool_calls.append(
+            ToolCall(
+                call_id=call_id,
+                func_name=func_name,
+                args=_normalize_tool_arguments(arguments),
+            )
         )
-        return True, response, reasoning_content, model_name, tool_call
-    except Exception as e:
-        error_msg = f"生成内容时出错: {str(e)}"
-        logger.error(f"[LLMService] {error_msg}")
-        return False, error_msg, "", "", None
+    return tool_calls or None
+
+
+def _build_message_from_dict(raw_message: PromptMessage) -> Message:
+    """将原始消息字典转换为内部消息对象。
+
+    Args:
+        raw_message: 原始消息字典。
+
+    Returns:
+        Message: 规范化后的消息对象。
+
+    Raises:
+        ValueError: 原始消息结构不合法时抛出。
+    """
+    raw_role = raw_message.get("role")
+    if not isinstance(raw_role, str):
+        raise ValueError("消息缺少字符串类型的 `role` 字段")
+
+    role = _normalize_role(raw_role)
+    message_builder = MessageBuilder().set_role(role)
+
+    tool_calls = _build_tool_calls(raw_message.get("tool_calls"))
+    if tool_calls is not None:
+        message_builder.set_tool_calls(tool_calls)
+
+    tool_call_id = raw_message.get("tool_call_id")
+    if isinstance(tool_call_id, str) and role == RoleType.Tool:
+        message_builder.set_tool_call_id(tool_call_id)
+
+    if "content" in raw_message and raw_message["content"] not in (None, "", []):
+        _append_content_parts(message_builder, raw_message["content"])
+
+    return message_builder.build()
+
+
+def _build_prompt_message_factory(prompt: PromptInput) -> MessageFactory:
+    """将统一提示输入转换为消息工厂。
+
+    Args:
+        prompt: 原始提示输入。
+
+    Returns:
+        MessageFactory: 惰性构建消息列表的工厂函数。
+    """
+    if isinstance(prompt, str):
+        def build_messages(_: BaseClient) -> List[Message]:
+            """构建单条用户消息。"""
+            message_builder = MessageBuilder()
+            message_builder.add_text_content(prompt)
+            return [message_builder.build()]
+
+        return build_messages
+
+    def build_messages(_: BaseClient) -> List[Message]:
+        """构建多消息对话输入。"""
+        return [_build_message_from_dict(raw_message) for raw_message in prompt]
+
+    return build_messages
+
+
+async def generate(request: LLMServiceRequest) -> LLMServiceResult:
+    """执行统一的 LLM 服务请求。
+
+    Args:
+        request: 服务层统一请求对象。
+
+    Returns:
+        LLMServiceResult: 统一响应对象;失败时 `success=False`。
+    """
+    llm_client = LLMServiceClient(task_name=request.task_name, request_type=request.request_type)
+    if request.message_factory is not None:
+        active_message_factory = request.message_factory
+    else:
+        prompt = request.prompt
+        if prompt is None:
+            raise ValueError("`prompt` 与 `message_factory` 必须且只能提供一个")
+        active_message_factory = _build_prompt_message_factory(prompt)
+
+    try:
+        generation_result = await llm_client.generate_response_with_messages(
+            message_factory=active_message_factory,
+            options=LLMGenerationOptions(
+                temperature=request.temperature,
+                max_tokens=request.max_tokens,
+                tool_options=request.tool_options,
+                response_format=request.response_format,
+                interrupt_flag=request.interrupt_flag,
+            ),
+        )
+        return LLMServiceResult.from_response_result(generation_result)
+    except Exception as exc:
+        error_message = f"生成内容时出错: {exc}"
+        logger.error(f"[LLMService] {error_message}")
+        return LLMServiceResult.from_error(error_message, str(exc))
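generate() above is the single entry point the plugin capabilities call into: a plain string prompt or an OpenAI-style message list goes in via LLMServiceRequest, and _build_prompt_message_factory converts it lazily into internal Message objects. A usage sketch — the request_type label is made up, and the success/to_capability_payload members on the result are inferred from the call sites in this patch rather than from the data-model file:

    from src.services import llm_service as llm_api

    async def demo() -> None:
        # An empty task name falls back to the first configured task (see resolve_task_name).
        task_name = llm_api.resolve_task_name("")
        result = await llm_api.generate(
            llm_api.LLMServiceRequest(
                task_name=task_name,
                request_type="plugin.demo",  # hypothetical label
                prompt=[
                    {"role": "system", "content": "你是测试助手"},
                    {"role": "user", "content": [{"type": "text", "text": "ping"}]},
                ],
                temperature=0.7,
            )
        )
        if result.success:
            print(result.to_capability_payload())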
client_type="openai", + auth_type=auth_type, + auth_header_name=auth_header_name, + auth_header_prefix=auth_header_prefix, + auth_query_name=auth_query_name, + default_headers=default_headers or {}, + default_query=default_query or {}, + ) + client_config = build_openai_compatible_client_config(provider) + headers.update(client_config.default_headers) + params.update(client_config.default_query) try: async with httpx.AsyncClient(timeout=30.0) as client: @@ -186,10 +212,9 @@ async def get_provider_models( parser: str = Query("openai", description="响应解析器类型 (openai | gemini)"), endpoint: str = Query("/models", description="获取模型列表的端点"), ): - """ - 获取指定提供商的可用模型列表 + """获取指定提供商的可用模型列表。 - 通过提供商名称查找配置,然后请求对应的模型列表端点 + 通过提供商名称查找配置,然后请求对应的模型列表端点。 """ # 获取提供商配置 provider_config = _get_provider_config(provider_name) @@ -205,13 +230,21 @@ async def get_provider_models( if not api_key: raise HTTPException(status_code=400, detail="提供商配置缺少 api_key") + resolved_endpoint = provider_config.get("model_list_endpoint", endpoint) if endpoint == "/models" else endpoint + # 获取模型列表 models = await _fetch_models_from_provider( base_url=base_url, api_key=api_key, - endpoint=endpoint, + endpoint=resolved_endpoint, parser=parser, client_type=client_type, + auth_type=provider_config.get("auth_type", "bearer"), + auth_header_name=provider_config.get("auth_header_name", "Authorization"), + auth_header_prefix=provider_config.get("auth_header_prefix", "Bearer"), + auth_query_name=provider_config.get("auth_query_name", "api_key"), + default_headers=provider_config.get("default_headers", {}), + default_query=provider_config.get("default_query", {}), ) return { @@ -229,16 +262,22 @@ async def get_models_by_url( parser: str = Query("openai", description="响应解析器类型 (openai | gemini)"), endpoint: str = Query("/models", description="获取模型列表的端点"), client_type: str = Query("openai", description="客户端类型 (openai | gemini)"), + auth_type: str = Query("bearer", description="鉴权方式 (bearer | header | query | none)"), + auth_header_name: str = Query("Authorization", description="Header 鉴权名称"), + auth_header_prefix: str = Query("Bearer", description="Header 鉴权前缀"), + auth_query_name: str = Query("api_key", description="Query 鉴权参数名"), ): - """ - 通过 URL 直接获取模型列表(用于自定义提供商) - """ + """通过 URL 直接获取模型列表。""" models = await _fetch_models_from_provider( base_url=base_url, api_key=api_key, endpoint=endpoint, parser=parser, client_type=client_type, + auth_type=auth_type, + auth_header_name=auth_header_name, + auth_header_prefix=auth_header_prefix, + auth_query_name=auth_query_name, ) return {