feat: Enhance OpenAI compatibility and introduce unified LLM service data models
- Refactored model fetching logic to support various authentication methods for OpenAI-compatible APIs. - Introduced new data models for LLM service requests and responses to standardize interactions across layers. - Added an adapter base class for unified request execution across different providers. - Implemented utility functions for building OpenAI-compatible client configurations and request overrides.
This commit is contained in:
@@ -19,7 +19,7 @@ dependencies = [
|
||||
"jieba>=0.42.1",
|
||||
"json-repair>=0.47.6",
|
||||
"maim-message>=0.6.2",
|
||||
"maibot-plugin-sdk>=2.0.0",
|
||||
"maibot-plugin-sdk>=2.1.0",
|
||||
"mcp",
|
||||
"msgpack>=1.1.2",
|
||||
"numpy>=2.2.6",
|
||||
|
||||
@@ -2,8 +2,8 @@ import time
|
||||
from typing import Tuple, Optional # 增加了 Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config, model_config
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
from src.config.config import global_config
|
||||
import random
|
||||
from .chat_observer import ChatObserver
|
||||
from .pfc_utils import get_items_from_json
|
||||
@@ -109,8 +109,8 @@ class ActionPlanner:
|
||||
"""行动规划器"""
|
||||
|
||||
def __init__(self, stream_id: str, private_name: str):
|
||||
self.llm = LLMRequest(
|
||||
model_set=model_config.model_task_config.planner,
|
||||
self.llm = LLMServiceClient(
|
||||
task_name="planner",
|
||||
request_type="action_planning",
|
||||
)
|
||||
self.personality_info = self._get_personality_prompt()
|
||||
@@ -398,7 +398,8 @@ class ActionPlanner:
|
||||
|
||||
logger.debug(f"[私聊][{self.private_name}]发送到LLM的最终提示词:\n------\n{prompt}\n------")
|
||||
try:
|
||||
content, _ = await self.llm.generate_response_async(prompt)
|
||||
generation_result = await self.llm.generate_response(prompt)
|
||||
content = generation_result.response
|
||||
logger.debug(f"[私聊][{self.private_name}]LLM (行动规划) 原始返回内容: {content}")
|
||||
|
||||
# --- 初始行动规划解析 ---
|
||||
@@ -427,7 +428,8 @@ class ActionPlanner:
|
||||
f"[私聊][{self.private_name}]发送到LLM的结束决策提示词:\n------\n{end_decision_prompt}\n------"
|
||||
)
|
||||
try:
|
||||
end_content, _ = await self.llm.generate_response_async(end_decision_prompt) # 再次调用LLM
|
||||
end_generation_result = await self.llm.generate_response(end_decision_prompt)
|
||||
end_content = end_generation_result.response # 再次调用LLM
|
||||
logger.debug(f"[私聊][{self.private_name}]LLM (结束决策) 原始返回内容: {end_content}")
|
||||
|
||||
# 解析结束决策的JSON
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import List, Tuple, TYPE_CHECKING
|
||||
from src.common.logger import get_logger
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config, model_config
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
from src.config.config import global_config
|
||||
import random
|
||||
from .chat_observer import ChatObserver
|
||||
from .pfc_utils import get_items_from_json
|
||||
@@ -43,7 +43,9 @@ class GoalAnalyzer:
|
||||
"""对话目标分析器"""
|
||||
|
||||
def __init__(self, stream_id: str, private_name: str):
|
||||
self.llm = LLMRequest(model_set=model_config.model_task_config.planner, request_type="conversation_goal")
|
||||
self.llm = LLMServiceClient(
|
||||
task_name="planner", request_type="conversation_goal"
|
||||
)
|
||||
|
||||
self.personality_info = self._get_personality_prompt()
|
||||
self.name = global_config.bot.nickname
|
||||
@@ -157,7 +159,8 @@ class GoalAnalyzer:
|
||||
|
||||
logger.debug(f"[私聊][{self.private_name}]发送到LLM的提示词: {prompt}")
|
||||
try:
|
||||
content, _ = await self.llm.generate_response_async(prompt)
|
||||
generation_result = await self.llm.generate_response(prompt)
|
||||
content = generation_result.response
|
||||
logger.debug(f"[私聊][{self.private_name}]LLM原始返回内容: {content}")
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]分析对话目标时出错: {str(e)}")
|
||||
@@ -271,7 +274,8 @@ class GoalAnalyzer:
|
||||
}}"""
|
||||
|
||||
try:
|
||||
content, _ = await self.llm.generate_response_async(prompt)
|
||||
generation_result = await self.llm.generate_response(prompt)
|
||||
content = generation_result.response
|
||||
logger.debug(f"[私聊][{self.private_name}]LLM原始返回内容: {content}")
|
||||
|
||||
# 尝试解析JSON
|
||||
|
||||
@@ -3,8 +3,7 @@ from src.common.logger import get_logger
|
||||
|
||||
# NOTE: HippocampusManager doesn't exist in v0.12.2 - memory system was redesigned
|
||||
# from src.plugins.memory_system.Hippocampus import HippocampusManager
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
from src.chat.knowledge import qa_manager
|
||||
|
||||
logger = get_logger("knowledge_fetcher")
|
||||
@@ -14,7 +13,7 @@ class KnowledgeFetcher:
|
||||
"""知识调取器"""
|
||||
|
||||
def __init__(self, private_name: str):
|
||||
self.llm = LLMRequest(model_set=model_config.model_task_config.utils)
|
||||
self.llm = LLMServiceClient(task_name="utils")
|
||||
self.private_name = private_name
|
||||
|
||||
def _lpmm_get_knowledge(self, query: str) -> str:
|
||||
|
||||
@@ -2,8 +2,8 @@ import json
|
||||
import random
|
||||
from typing import Tuple, List, Dict, Any
|
||||
from src.common.logger import get_logger
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config, model_config
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
from src.config.config import global_config
|
||||
from .chat_observer import ChatObserver
|
||||
from maim_message import UserInfo
|
||||
|
||||
@@ -14,7 +14,7 @@ class ReplyChecker:
|
||||
"""回复检查器"""
|
||||
|
||||
def __init__(self, stream_id: str, private_name: str):
|
||||
self.llm = LLMRequest(model_set=model_config.model_task_config.utils, request_type="reply_check")
|
||||
self.llm = LLMServiceClient(task_name="utils", request_type="reply_check")
|
||||
self.personality_info = self._get_personality_prompt()
|
||||
self.name = global_config.bot.nickname
|
||||
self.private_name = private_name
|
||||
@@ -137,7 +137,8 @@ class ReplyChecker:
|
||||
注意:请严格按照JSON格式输出,不要包含任何其他内容。"""
|
||||
|
||||
try:
|
||||
content, _ = await self.llm.generate_response_async(prompt)
|
||||
generation_result = await self.llm.generate_response(prompt)
|
||||
content = generation_result.response
|
||||
logger.debug(f"[私聊][{self.private_name}]检查回复的原始返回: {content}")
|
||||
|
||||
# 清理内容,尝试提取JSON部分
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import Tuple, List, Dict, Any
|
||||
from src.common.logger import get_logger
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config, model_config
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
from src.config.config import global_config
|
||||
import random
|
||||
from .chat_observer import ChatObserver
|
||||
from .reply_checker import ReplyChecker
|
||||
@@ -87,8 +87,8 @@ class ReplyGenerator:
|
||||
"""回复生成器"""
|
||||
|
||||
def __init__(self, stream_id: str, private_name: str):
|
||||
self.llm = LLMRequest(
|
||||
model_set=model_config.model_task_config.replyer,
|
||||
self.llm = LLMServiceClient(
|
||||
task_name="replyer",
|
||||
request_type="reply_generation",
|
||||
)
|
||||
self.personality_info = self._get_personality_prompt()
|
||||
@@ -223,7 +223,8 @@ class ReplyGenerator:
|
||||
# --- 调用 LLM 生成 ---
|
||||
logger.debug(f"[私聊][{self.private_name}]发送到LLM的生成提示词:\n------\n{prompt}\n------")
|
||||
try:
|
||||
content, _ = await self.llm.generate_response_async(prompt)
|
||||
generation_result = await self.llm.generate_response(prompt)
|
||||
content = generation_result.response
|
||||
logger.debug(f"[私聊][{self.private_name}]生成的回复: {content}")
|
||||
# 移除旧的检查新消息逻辑,这应该由 conversation 控制流处理
|
||||
return content
|
||||
|
||||
@@ -17,9 +17,9 @@ from src.chat.utils.utils import get_chat_type_and_target_info
|
||||
from src.common.data_models.info_data_model import ActionPlannerInfo
|
||||
from src.common.logger import get_logger
|
||||
from src.common.utils.utils_action import ActionUtils
|
||||
from src.config.config import global_config, model_config
|
||||
from src.config.config import global_config
|
||||
from src.core.types import ActionActivationType, ActionInfo, ComponentType
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
from src.plugin_runtime.component_query import component_query_service
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
from src.services.message_service import (
|
||||
@@ -43,8 +43,8 @@ class BrainPlanner:
|
||||
self.log_prefix = f"[{_chat_manager.get_session_name(chat_id) or chat_id}]"
|
||||
self.action_manager = action_manager
|
||||
# LLM规划器配置
|
||||
self.planner_llm = LLMRequest(
|
||||
model_set=model_config.model_task_config.planner, request_type="planner"
|
||||
self.planner_llm = LLMServiceClient(
|
||||
task_name="planner", request_type="planner"
|
||||
) # 用于动作规划
|
||||
|
||||
self.last_obs_time_mark = 0.0
|
||||
@@ -412,7 +412,9 @@ class BrainPlanner:
|
||||
try:
|
||||
# 调用LLM
|
||||
llm_start = time.perf_counter()
|
||||
llm_content, (reasoning_content, _, _) = await self.planner_llm.generate_response_async(prompt=prompt)
|
||||
generation_result = await self.planner_llm.generate_response(prompt=prompt)
|
||||
llm_content = generation_result.response
|
||||
reasoning_content = generation_result.reasoning
|
||||
llm_duration_ms = (time.perf_counter() - llm_start) * 1000
|
||||
llm_reasoning = reasoning_content
|
||||
|
||||
|
||||
@@ -17,8 +17,9 @@ from src.common.database.database_model import Images, ImageType
|
||||
from src.common.database.database import get_db_session, get_db_session_manual
|
||||
from src.common.utils.utils_image import ImageUtils
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
from src.config.config import config_manager, global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import config_manager, global_config
|
||||
from src.common.data_models.llm_service_data_models import LLMGenerationOptions, LLMImageOptions
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
|
||||
logger = get_logger("emoji")
|
||||
|
||||
@@ -38,8 +39,10 @@ def _ensure_directories() -> None:
|
||||
|
||||
|
||||
# TODO: 修改这个vlm为获取的vlm client,暂时使用这个VLM方法
|
||||
emoji_manager_vlm = LLMRequest(model_set=model_config.model_task_config.vlm, request_type="emoji.see")
|
||||
emoji_manager_emotion_judge_llm = LLMRequest(model_set=model_config.model_task_config.utils, request_type="emoji")
|
||||
emoji_manager_vlm = LLMServiceClient(task_name="vlm", request_type="emoji.see")
|
||||
emoji_manager_emotion_judge_llm = LLMServiceClient(
|
||||
task_name="utils", request_type="emoji"
|
||||
)
|
||||
|
||||
|
||||
class EmojiManager:
|
||||
@@ -461,9 +464,11 @@ class EmojiManager:
|
||||
emoji_replace_prompt_template.add_context("emoji_list", "\n".join(emoji_info_list))
|
||||
emoji_replace_prompt = await prompt_manager.render_prompt(emoji_replace_prompt_template)
|
||||
|
||||
decision, _ = await emoji_manager_emotion_judge_llm.generate_response_async(
|
||||
emoji_replace_prompt, temperature=0.8, max_tokens=600
|
||||
decision_result = await emoji_manager_emotion_judge_llm.generate_response(
|
||||
emoji_replace_prompt,
|
||||
options=LLMGenerationOptions(temperature=0.8, max_tokens=600),
|
||||
)
|
||||
decision = decision_result.response
|
||||
logger.info(f"[决策] 结果: {decision}")
|
||||
|
||||
# 解析决策结果
|
||||
@@ -524,24 +529,36 @@ class EmojiManager:
|
||||
return False, target_emoji
|
||||
prompt: str = "这是一个动态图表情包,每一张图代表了动态图的某一帧,黑色背景代表透明,简短描述一下表情包表达的情感和内容,从互联网梗、meme的角度去分析,精简回答"
|
||||
image_base64 = ImageUtils.image_bytes_to_base64(image_bytes)
|
||||
description, _ = await emoji_manager_vlm.generate_response_for_image(
|
||||
prompt, image_base64, "jpg", temperature=0.5
|
||||
description_result = await emoji_manager_vlm.generate_response_for_image(
|
||||
prompt,
|
||||
image_base64,
|
||||
"jpg",
|
||||
options=LLMImageOptions(temperature=0.5),
|
||||
)
|
||||
description = description_result.response
|
||||
else:
|
||||
prompt: str = "这是一个表情包,请详细描述一下表情包所表达的情感和内容,简短描述细节,从互联网梗、meme的角度去分析,精简回答"
|
||||
image_base64 = ImageUtils.image_bytes_to_base64(image_bytes)
|
||||
description, _ = await emoji_manager_vlm.generate_response_for_image(
|
||||
prompt, image_base64, image_format, temperature=0.5
|
||||
description_result = await emoji_manager_vlm.generate_response_for_image(
|
||||
prompt,
|
||||
image_base64,
|
||||
image_format,
|
||||
options=LLMImageOptions(temperature=0.5),
|
||||
)
|
||||
description = description_result.response
|
||||
|
||||
# 表情包审查
|
||||
if global_config.emoji.content_filtration:
|
||||
filtration_prompt_template = prompt_manager.get_prompt("emoji_content_filtration")
|
||||
filtration_prompt_template.add_context("demand", global_config.emoji.filtration_prompt)
|
||||
filtration_prompt = await prompt_manager.render_prompt(filtration_prompt_template)
|
||||
llm_response, _ = await emoji_manager_vlm.generate_response_for_image(
|
||||
filtration_prompt, image_base64, image_format, temperature=0.3
|
||||
filtration_result = await emoji_manager_vlm.generate_response_for_image(
|
||||
filtration_prompt,
|
||||
image_base64,
|
||||
image_format,
|
||||
options=LLMImageOptions(temperature=0.3),
|
||||
)
|
||||
llm_response = filtration_result.response
|
||||
if "否" in llm_response:
|
||||
logger.warning(f"[表情包审查] 表情包内容不符合要求,拒绝注册: {target_emoji.file_name}")
|
||||
return False, target_emoji
|
||||
@@ -567,9 +584,11 @@ class EmojiManager:
|
||||
emotion_prompt_template.add_context("description", target_emoji.description)
|
||||
emotion_prompt = await prompt_manager.render_prompt(emotion_prompt_template)
|
||||
# 调用LLM生成情感标签
|
||||
emotion_result, _ = await emoji_manager_emotion_judge_llm.generate_response_async(
|
||||
emotion_prompt, temperature=0.3, max_tokens=200
|
||||
emotion_generation_result = await emoji_manager_emotion_judge_llm.generate_response(
|
||||
emotion_prompt,
|
||||
options=LLMGenerationOptions(temperature=0.3, max_tokens=200),
|
||||
)
|
||||
emotion_result = emotion_generation_result.response
|
||||
|
||||
# 解析情感标签结果
|
||||
emotions = [e.strip() for e in emotion_result.replace(",", ",").split(",") if e.strip()]
|
||||
|
||||
@@ -11,8 +11,9 @@ from src.common.logger import get_logger
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import Images, ImageType
|
||||
from src.common.data_models.image_data_model import MaiImage
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config
|
||||
from src.common.data_models.llm_service_data_models import LLMImageOptions
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
@@ -27,7 +28,7 @@ def _ensure_image_dir_exists():
|
||||
IMAGE_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
vlm = LLMRequest(model_set=model_config.model_task_config.vlm, request_type="image")
|
||||
vlm = LLMServiceClient(task_name="vlm", request_type="image")
|
||||
|
||||
|
||||
class ImageManager:
|
||||
@@ -260,7 +261,13 @@ class ImageManager:
|
||||
prompt = global_config.personality.visual_style
|
||||
image_base64 = base64.b64encode(image_bytes).decode("utf-8")
|
||||
|
||||
description, _ = await vlm.generate_response_for_image(prompt, image_base64, image_format, 0.4)
|
||||
generation_result = await vlm.generate_response_for_image(
|
||||
prompt,
|
||||
image_base64,
|
||||
image_format,
|
||||
options=LLMImageOptions(temperature=0.4),
|
||||
)
|
||||
description = generation_result.response
|
||||
if not description:
|
||||
logger.warning("VLM未能生成图片描述")
|
||||
return description or ""
|
||||
|
||||
@@ -139,14 +139,14 @@ class EmbeddingStore:
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
# 创建新的LLMRequest实例
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config
|
||||
# 创建新的服务层实例
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
|
||||
llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding")
|
||||
llm = LLMServiceClient(task_name="embedding", request_type="embedding")
|
||||
|
||||
# 使用新的事件循环运行异步方法
|
||||
embedding, _ = loop.run_until_complete(llm.get_embedding(s))
|
||||
embedding_result = loop.run_until_complete(llm.embed_text(s))
|
||||
embedding = embedding_result.embedding
|
||||
|
||||
if embedding and len(embedding) > 0:
|
||||
return embedding
|
||||
@@ -195,13 +195,12 @@ class EmbeddingStore:
|
||||
start_idx, chunk_strs = chunk_data
|
||||
chunk_results = []
|
||||
|
||||
# 为每个线程创建独立的LLMRequest实例
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config
|
||||
# 为每个线程创建独立的服务层实例
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
|
||||
try:
|
||||
# 创建线程专用的LLM实例
|
||||
llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding")
|
||||
# 创建线程专用的服务层实例
|
||||
llm = LLMServiceClient(task_name="embedding", request_type="embedding")
|
||||
|
||||
for i, s in enumerate(chunk_strs):
|
||||
try:
|
||||
@@ -209,7 +208,8 @@ class EmbeddingStore:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
embedding = loop.run_until_complete(llm.get_embedding(s))
|
||||
embedding_result = loop.run_until_complete(llm.embed_text(s))
|
||||
embedding = embedding_result.embedding
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
@@ -1,18 +1,27 @@
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from typing import List, Union
|
||||
from typing import Dict, List, Tuple, Union
|
||||
|
||||
from .global_logger import logger
|
||||
from . import prompt_template
|
||||
from . import INVALID_ENTITY
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from json_repair import repair_json
|
||||
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
|
||||
def _extract_json_from_text(text: str):
|
||||
from . import INVALID_ENTITY
|
||||
from . import prompt_template
|
||||
from .global_logger import logger
|
||||
|
||||
|
||||
def _extract_json_from_text(text: str) -> List[str] | List[List[str]] | Dict[str, object]:
|
||||
# sourcery skip: assign-if-exp, extract-method
|
||||
"""从文本中提取JSON数据的高容错方法"""
|
||||
"""从文本中提取 JSON 数据。
|
||||
|
||||
Args:
|
||||
text: 原始模型输出文本。
|
||||
|
||||
Returns:
|
||||
List[str] | List[List[str]] | Dict[str, object]: 修复并解析后的 JSON 结果。
|
||||
"""
|
||||
if text is None:
|
||||
logger.error("输入文本为None")
|
||||
return []
|
||||
@@ -46,20 +55,30 @@ def _extract_json_from_text(text: str):
|
||||
return []
|
||||
|
||||
|
||||
def _entity_extract(llm_req: LLMRequest, paragraph: str) -> List[str]:
|
||||
def _entity_extract(llm_req: LLMServiceClient, paragraph: str) -> List[str]:
|
||||
# sourcery skip: reintroduce-else, swap-if-else-branches, use-named-expression
|
||||
"""对段落进行实体提取,返回提取出的实体列表(JSON格式)"""
|
||||
"""对单段文本执行实体提取。
|
||||
|
||||
Args:
|
||||
llm_req: LLM 服务门面实例。
|
||||
paragraph: 待提取实体的原始段落文本。
|
||||
|
||||
Returns:
|
||||
List[str]: 提取出的实体列表。
|
||||
"""
|
||||
entity_extract_context = prompt_template.build_entity_extract_context(paragraph)
|
||||
|
||||
# 使用 asyncio.run 来运行异步方法
|
||||
try:
|
||||
# 如果当前已有事件循环在运行,使用它
|
||||
loop = asyncio.get_running_loop()
|
||||
future = asyncio.run_coroutine_threadsafe(llm_req.generate_response_async(entity_extract_context), loop)
|
||||
response, _ = future.result()
|
||||
future = asyncio.run_coroutine_threadsafe(llm_req.generate_response(entity_extract_context), loop)
|
||||
generation_result = future.result()
|
||||
response = generation_result.response
|
||||
except RuntimeError:
|
||||
# 如果没有运行中的事件循环,直接使用 asyncio.run
|
||||
response, _ = asyncio.run(llm_req.generate_response_async(entity_extract_context))
|
||||
generation_result = asyncio.run(llm_req.generate_response(entity_extract_context))
|
||||
response = generation_result.response
|
||||
|
||||
# 添加调试日志
|
||||
logger.debug(f"LLM返回的原始响应: {response}")
|
||||
@@ -92,8 +111,21 @@ def _entity_extract(llm_req: LLMRequest, paragraph: str) -> List[str]:
|
||||
return entity_extract_result
|
||||
|
||||
|
||||
def _rdf_triple_extract(llm_req: LLMRequest, paragraph: str, entities: list) -> List[List[str]]:
|
||||
"""对段落进行实体提取,返回提取出的实体列表(JSON格式)"""
|
||||
def _rdf_triple_extract(
|
||||
llm_req: LLMServiceClient,
|
||||
paragraph: str,
|
||||
entities: List[str],
|
||||
) -> List[List[str]]:
|
||||
"""对单段文本执行 RDF 三元组提取。
|
||||
|
||||
Args:
|
||||
llm_req: LLM 服务门面实例。
|
||||
paragraph: 待提取的原始段落文本。
|
||||
entities: 已识别出的实体列表。
|
||||
|
||||
Returns:
|
||||
List[List[str]]: 提取出的三元组列表。
|
||||
"""
|
||||
rdf_extract_context = prompt_template.build_rdf_triple_extract_context(
|
||||
paragraph, entities=json.dumps(entities, ensure_ascii=False)
|
||||
)
|
||||
@@ -102,11 +134,13 @@ def _rdf_triple_extract(llm_req: LLMRequest, paragraph: str, entities: list) ->
|
||||
try:
|
||||
# 如果当前已有事件循环在运行,使用它
|
||||
loop = asyncio.get_running_loop()
|
||||
future = asyncio.run_coroutine_threadsafe(llm_req.generate_response_async(rdf_extract_context), loop)
|
||||
response, _ = future.result()
|
||||
future = asyncio.run_coroutine_threadsafe(llm_req.generate_response(rdf_extract_context), loop)
|
||||
generation_result = future.result()
|
||||
response = generation_result.response
|
||||
except RuntimeError:
|
||||
# 如果没有运行中的事件循环,直接使用 asyncio.run
|
||||
response, _ = asyncio.run(llm_req.generate_response_async(rdf_extract_context))
|
||||
generation_result = asyncio.run(llm_req.generate_response(rdf_extract_context))
|
||||
response = generation_result.response
|
||||
|
||||
# 添加调试日志
|
||||
logger.debug(f"RDF LLM返回的原始响应: {response}")
|
||||
@@ -140,8 +174,21 @@ def _rdf_triple_extract(llm_req: LLMRequest, paragraph: str, entities: list) ->
|
||||
|
||||
|
||||
def info_extract_from_str(
|
||||
llm_client_for_ner: LLMRequest, llm_client_for_rdf: LLMRequest, paragraph: str
|
||||
) -> Union[tuple[None, None], tuple[list[str], list[list[str]]]]:
|
||||
llm_client_for_ner: LLMServiceClient,
|
||||
llm_client_for_rdf: LLMServiceClient,
|
||||
paragraph: str,
|
||||
) -> Union[Tuple[None, None], Tuple[List[str], List[List[str]]]]:
|
||||
"""从文本中提取实体与三元组信息。
|
||||
|
||||
Args:
|
||||
llm_client_for_ner: 实体提取使用的 LLM 服务门面。
|
||||
llm_client_for_rdf: RDF 三元组提取使用的 LLM 服务门面。
|
||||
paragraph: 原始段落文本。
|
||||
|
||||
Returns:
|
||||
Union[Tuple[None, None], Tuple[List[str], List[List[str]]]]: 成功时返回
|
||||
``(实体列表, 三元组列表)``,失败时返回 ``(None, None)``。
|
||||
"""
|
||||
try_count = 0
|
||||
while True:
|
||||
try:
|
||||
@@ -176,17 +223,30 @@ def info_extract_from_str(
|
||||
|
||||
|
||||
class IEProcess:
|
||||
"""
|
||||
信息抽取处理器类,提供更方便的批次处理接口。
|
||||
"""
|
||||
"""信息抽取处理器。"""
|
||||
|
||||
def __init__(self, llm_ner: LLMRequest, llm_rdf: LLMRequest = None):
|
||||
def __init__(
|
||||
self,
|
||||
llm_ner: LLMServiceClient,
|
||||
llm_rdf: LLMServiceClient | None = None,
|
||||
) -> None:
|
||||
"""初始化信息抽取处理器。
|
||||
|
||||
Args:
|
||||
llm_ner: 实体提取使用的 LLM 服务门面。
|
||||
llm_rdf: RDF 三元组提取使用的 LLM 服务门面;为空时复用 `llm_ner`。
|
||||
"""
|
||||
self.llm_ner = llm_ner
|
||||
self.llm_rdf = llm_rdf or llm_ner
|
||||
|
||||
async def process_paragraphs(self, paragraphs: List[str]) -> List[dict]:
|
||||
"""
|
||||
异步处理多个段落。
|
||||
async def process_paragraphs(self, paragraphs: List[str]) -> List[Dict[str, object]]:
|
||||
"""异步处理多个段落。
|
||||
|
||||
Args:
|
||||
paragraphs: 待处理的段落列表。
|
||||
|
||||
Returns:
|
||||
List[Dict[str, object]]: 每个成功段落对应的抽取结果。
|
||||
"""
|
||||
from .utils.hash import get_sha256
|
||||
|
||||
|
||||
@@ -91,13 +91,14 @@ class LPMMOperations:
|
||||
|
||||
# 2. 实体与三元组抽取 (内部调用大模型)
|
||||
from src.chat.knowledge.ie_process import IEProcess
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
|
||||
llm_ner = LLMRequest(
|
||||
model_set=model_config.model_task_config.lpmm_entity_extract, request_type="lpmm.entity_extract"
|
||||
llm_ner = LLMServiceClient(
|
||||
task_name="lpmm_entity_extract", request_type="lpmm.entity_extract"
|
||||
)
|
||||
llm_rdf = LLMServiceClient(
|
||||
task_name="lpmm_rdf_build", request_type="lpmm.rdf_build"
|
||||
)
|
||||
llm_rdf = LLMRequest(model_set=model_config.model_task_config.lpmm_rdf_build, request_type="lpmm.rdf_build")
|
||||
ie_process = IEProcess(llm_ner, llm_rdf)
|
||||
|
||||
logger.info(f"[Plugin API] 正在对 {len(paragraphs)} 段文本执行信息抽取...")
|
||||
|
||||
@@ -149,7 +149,7 @@ class ActionModifier:
|
||||
random.shuffle(actions_to_check)
|
||||
|
||||
for action_name, action_info in actions_to_check:
|
||||
activation_type = action_info.activation_type or action_info.focus_activation_type
|
||||
activation_type = action_info.activation_type
|
||||
|
||||
if activation_type == ActionActivationType.ALWAYS:
|
||||
continue # 总是激活,无需处理
|
||||
|
||||
@@ -19,9 +19,9 @@ from src.chat.planner_actions.action_manager import ActionManager
|
||||
from src.chat.utils.utils import get_chat_type_and_target_info, is_bot_self
|
||||
from src.common.data_models.info_data_model import ActionPlannerInfo
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.config.config import global_config
|
||||
from src.core.types import ActionActivationType, ActionInfo, ComponentType
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
from src.person_info.person_info import Person
|
||||
from src.plugin_runtime.component_query import component_query_service
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
@@ -46,8 +46,8 @@ class ActionPlanner:
|
||||
self.log_prefix = f"[{_chat_manager.get_session_name(chat_id) or chat_id}]"
|
||||
self.action_manager = action_manager
|
||||
# LLM规划器配置
|
||||
self.planner_llm = LLMRequest(
|
||||
model_set=model_config.model_task_config.planner, request_type="planner"
|
||||
self.planner_llm = LLMServiceClient(
|
||||
task_name="planner", request_type="planner"
|
||||
) # 用于动作规划
|
||||
|
||||
self.last_obs_time_mark = 0.0
|
||||
@@ -725,7 +725,9 @@ class ActionPlanner:
|
||||
try:
|
||||
# 调用LLM
|
||||
llm_start = time.perf_counter()
|
||||
llm_content, (reasoning_content, _, _) = await self.planner_llm.generate_response_async(prompt=prompt)
|
||||
generation_result = await self.planner_llm.generate_response(prompt=prompt)
|
||||
llm_content = generation_result.response
|
||||
reasoning_content = generation_result.reasoning
|
||||
llm_duration_ms = (time.perf_counter() - llm_start) * 1000
|
||||
llm_reasoning = reasoning_content
|
||||
|
||||
|
||||
@@ -10,8 +10,8 @@ from datetime import datetime
|
||||
from src.common.logger import get_logger
|
||||
from src.common.data_models.info_data_model import ActionPlannerInfo
|
||||
from src.common.data_models.llm_data_model import LLMGenerationDataModel
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
from maim_message import BaseMessageInfo, MessageBase, Seg, UserInfo as MaimUserInfo
|
||||
|
||||
from src.common.data_models.mai_message_data_model import MaiMessage
|
||||
@@ -56,7 +56,9 @@ class DefaultReplyer:
|
||||
chat_stream: 当前绑定的聊天会话。
|
||||
request_type: LLM 请求类型标识。
|
||||
"""
|
||||
self.express_model = LLMRequest(model_set=model_config.model_task_config.replyer, request_type=request_type)
|
||||
self.express_model = LLMServiceClient(
|
||||
task_name="replyer", request_type=request_type
|
||||
)
|
||||
self.chat_stream = chat_stream
|
||||
self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_stream.session_id)
|
||||
|
||||
@@ -1158,9 +1160,11 @@ class DefaultReplyer:
|
||||
# else:
|
||||
# logger.debug(f"\nreplyer_Prompt:{prompt}\n")
|
||||
|
||||
content, (reasoning_content, model_name, tool_calls) = await self.express_model.generate_response_async(
|
||||
prompt
|
||||
)
|
||||
generation_result = await self.express_model.generate_response(prompt)
|
||||
content = generation_result.response
|
||||
reasoning_content = generation_result.reasoning
|
||||
model_name = generation_result.model_name
|
||||
tool_calls = generation_result.tool_calls
|
||||
|
||||
# 移除 content 前后的换行符和空格
|
||||
content = content.strip()
|
||||
@@ -1200,11 +1204,15 @@ class DefaultReplyer:
|
||||
template_prompt.add_context("sender", sender)
|
||||
template_prompt.add_context("target_message", target)
|
||||
prompt = await prompt_manager.render_prompt(template_prompt)
|
||||
_, _, _, _, tool_calls = await llm_api.generate_with_model_with_tools(
|
||||
prompt,
|
||||
model_config=model_config.model_task_config.tool_use,
|
||||
tool_options=[search_knowledge_tool.get_tool_definition()],
|
||||
generation_result = await llm_api.generate(
|
||||
llm_api.LLMServiceRequest(
|
||||
task_name="tool_use",
|
||||
request_type="replyer.lpmm_knowledge",
|
||||
prompt=prompt,
|
||||
tool_options=[search_knowledge_tool.get_tool_definition()],
|
||||
)
|
||||
)
|
||||
tool_calls = generation_result.completion.tool_calls
|
||||
|
||||
# logger.info(f"工具调用提示词: {prompt}")
|
||||
# logger.info(f"工具调用: {tool_calls}")
|
||||
|
||||
@@ -9,8 +9,8 @@ from datetime import datetime
|
||||
from src.common.logger import get_logger
|
||||
from src.common.data_models.info_data_model import ActionPlannerInfo
|
||||
from src.common.data_models.llm_data_model import LLMGenerationDataModel
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
from maim_message import BaseMessageInfo, MessageBase, Seg, UserInfo as MaimUserInfo
|
||||
|
||||
from src.common.data_models.mai_message_data_model import MaiMessage
|
||||
@@ -52,7 +52,9 @@ class PrivateReplyer:
|
||||
chat_stream: 当前绑定的聊天会话。
|
||||
request_type: LLM 请求类型标识。
|
||||
"""
|
||||
self.express_model = LLMRequest(model_set=model_config.model_task_config.replyer, request_type=request_type)
|
||||
self.express_model = LLMServiceClient(
|
||||
task_name="replyer", request_type=request_type
|
||||
)
|
||||
self.chat_stream = chat_stream
|
||||
self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_stream.session_id)
|
||||
# self.memory_activator = MemoryActivator()
|
||||
@@ -997,9 +999,11 @@ class PrivateReplyer:
|
||||
else:
|
||||
logger.debug(f"\n{prompt}\n")
|
||||
|
||||
content, (reasoning_content, model_name, tool_calls) = await self.express_model.generate_response_async(
|
||||
prompt
|
||||
)
|
||||
generation_result = await self.express_model.generate_response(prompt)
|
||||
content = generation_result.response
|
||||
reasoning_content = generation_result.reasoning
|
||||
model_name = generation_result.model_name
|
||||
tool_calls = generation_result.tool_calls
|
||||
|
||||
content = content.strip()
|
||||
|
||||
|
||||
@@ -4,16 +4,18 @@
|
||||
自动判断并执行相应的工具,返回结构化的工具执行结果。
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import Any, Dict, List, Optional, Tuple, cast
|
||||
|
||||
import hashlib
|
||||
import time
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.config.config import global_config
|
||||
from src.core.announcement_manager import global_announcement_manager
|
||||
from src.llm_models.payload_content import ToolCall
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.llm_models.payload_content.tool_option import ToolDefinitionInput
|
||||
from src.common.data_models.llm_service_data_models import LLMGenerationOptions
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
from src.plugin_runtime.component_query import component_query_service
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
|
||||
@@ -33,7 +35,9 @@ class ToolExecutor:
|
||||
self.chat_stream = _chat_manager.get_session_by_session_id(self.chat_id)
|
||||
self.log_prefix = f"[{_chat_manager.get_session_name(self.chat_id) or self.chat_id}]"
|
||||
|
||||
self.llm_model = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="tool_executor")
|
||||
self.llm_model = LLMServiceClient(
|
||||
task_name="tool_use", request_type="tool_executor"
|
||||
)
|
||||
|
||||
self.enable_cache = enable_cache
|
||||
self.cache_ttl = cache_ttl
|
||||
@@ -69,9 +73,11 @@ class ToolExecutor:
|
||||
|
||||
logger.debug(f"{self.log_prefix}开始LLM工具调用分析")
|
||||
|
||||
response, (reasoning_content, model_name, tool_calls) = await self.llm_model.generate_response_async(
|
||||
prompt=prompt, tools=tools, raise_when_empty=False
|
||||
generation_result = await self.llm_model.generate_response(
|
||||
prompt=prompt,
|
||||
options=LLMGenerationOptions(tool_options=tools, raise_when_empty=False),
|
||||
)
|
||||
tool_calls = generation_result.tool_calls
|
||||
|
||||
tool_results, used_tools = await self.execute_tool_calls(tool_calls)
|
||||
|
||||
@@ -85,11 +91,15 @@ class ToolExecutor:
|
||||
return tool_results, used_tools, prompt
|
||||
return tool_results, [], ""
|
||||
|
||||
def _get_tool_definitions(self) -> List[Dict[str, Any]]:
|
||||
def _get_tool_definitions(self) -> List[ToolDefinitionInput]:
|
||||
"""获取 LLM 可用的工具定义列表"""
|
||||
all_tools = component_query_service.get_llm_available_tools()
|
||||
user_disabled_tools = global_announcement_manager.get_disabled_chat_tools(self.chat_id)
|
||||
return [info.get_llm_definition() for name, info in all_tools.items() if name not in user_disabled_tools]
|
||||
return [
|
||||
cast(ToolDefinitionInput, info.get_llm_definition())
|
||||
for name, info in all_tools.items()
|
||||
if name not in user_disabled_tools
|
||||
]
|
||||
|
||||
async def execute_tool_calls(self, tool_calls: Optional[List[ToolCall]]) -> Tuple[List[Dict[str, Any]], List[str]]:
|
||||
"""执行工具调用列表"""
|
||||
|
||||
@@ -13,8 +13,8 @@ import jieba
|
||||
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
from src.person_info.person_info import Person
|
||||
|
||||
from .typo_generator import ChineseTypoGenerator
|
||||
@@ -235,10 +235,11 @@ def is_mentioned_bot_in_message(message: SessionMessage) -> tuple[bool, bool, fl
|
||||
|
||||
async def get_embedding(text, request_type="embedding") -> Optional[List[float]]:
|
||||
"""获取文本的embedding向量"""
|
||||
# 每次都创建新的LLMRequest实例以避免事件循环冲突
|
||||
llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type=request_type)
|
||||
# 每次都创建新的服务层实例以避免事件循环冲突
|
||||
llm = LLMServiceClient(task_name="embedding", request_type=request_type)
|
||||
try:
|
||||
embedding, _ = await llm.get_embedding(text)
|
||||
embedding_result = await llm.embed_text(text)
|
||||
embedding = embedding_result.embedding
|
||||
except Exception as e:
|
||||
logger.error(f"获取embedding失败: {str(e)}")
|
||||
embedding = None
|
||||
|
||||
187
src/common/data_models/llm_service_data_models.py
Normal file
187
src/common/data_models/llm_service_data_models.py
Normal file
@@ -0,0 +1,187 @@
|
||||
"""LLM 服务层与编排层共享数据模型。
|
||||
|
||||
该模块集中定义 LLM 服务层与底层编排器共同使用的请求、选项与结果对象,
|
||||
用于替代散落在各层之间的复杂元组返回值。
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, TypeAlias
|
||||
|
||||
import asyncio
|
||||
|
||||
from src.common.data_models import BaseDataModel
|
||||
from src.llm_models.payload_content.resp_format import RespFormat
|
||||
from src.llm_models.payload_content.tool_option import ToolCall, ToolDefinitionInput
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.llm_models.model_client.base_client import BaseClient
|
||||
from src.llm_models.payload_content.message import Message
|
||||
|
||||
|
||||
PromptMessage: TypeAlias = Dict[str, Any]
|
||||
"""统一的原始提示消息结构。"""
|
||||
|
||||
PromptInput: TypeAlias = str | List[PromptMessage]
|
||||
"""统一的提示输入类型。"""
|
||||
|
||||
MessageFactory: TypeAlias = Callable[["BaseClient"], List["Message"]]
|
||||
"""统一的消息工厂类型。"""
|
||||
|
||||
|
||||
@dataclass(slots=True)
class LLMServiceRequest(BaseDataModel):
    """Unified request object consumed by the LLM service layer.

    Exactly one of ``prompt`` / ``message_factory`` must be provided;
    everything else is an optional per-call override.
    """

    task_name: str
    request_type: str
    prompt: PromptInput | None = None
    message_factory: MessageFactory | None = None
    tool_options: List[ToolDefinitionInput] | None = None
    temperature: float | None = None
    max_tokens: int | None = None
    response_format: RespFormat | None = None
    interrupt_flag: asyncio.Event | None = None

    def __post_init__(self) -> None:
        """Validate the request's required fields.

        Raises:
            ValueError: when ``task_name`` is blank, or when the
                ``prompt`` / ``message_factory`` combination is invalid.
        """
        self.task_name = self.task_name.strip()
        if not self.task_name:
            raise ValueError("`task_name` 不能为空")
        # XOR constraint: exactly one prompt source must be supplied.
        if (self.prompt is None) == (self.message_factory is None):
            raise ValueError("`prompt` 与 `message_factory` 必须且只能提供一个")
|
||||
|
||||
|
||||
@dataclass(slots=True)
class LLMResponseResult(BaseDataModel):
    """Result of one LLM completion call."""

    # Generated answer text.
    response: str = field(default_factory=str)
    # Reasoning content returned alongside the answer (may be empty).
    reasoning: str = field(default_factory=str)
    # Name of the model that produced this completion.
    model_name: str = field(default_factory=str)
    # Parsed tool invocations, when the model requested any.
    tool_calls: List[ToolCall] | None = None
|
||||
|
||||
|
||||
@dataclass(slots=True)
class LLMServiceResult(BaseDataModel):
    """Unified response envelope produced by the LLM service layer."""

    success: bool = False
    completion: LLMResponseResult = field(default_factory=LLMResponseResult)
    error: str | None = None

    @classmethod
    def from_response_result(cls, completion: LLMResponseResult) -> "LLMServiceResult":
        """Wrap a completion as a successful service response.

        Args:
            completion: the single LLM completion result.

        Returns:
            LLMServiceResult: a response object marked as successful.
        """
        return cls(success=True, completion=completion, error=None)

    @classmethod
    def from_error(cls, error_message: str, error_detail: str | None = None) -> "LLMServiceResult":
        """Build a failed service response.

        Args:
            error_message: message shown to the caller; also stored as the
                completion's response text.
            error_detail: low-level error detail; falls back to
                ``error_message`` when omitted.

        Returns:
            LLMServiceResult: a response object marked as failed.
        """
        return cls(
            success=False,
            completion=LLMResponseResult(response=error_message),
            error=error_detail or error_message,
        )

    def to_capability_payload(self) -> Dict[str, Any]:
        """Convert to the plain-dict structure the plugin capability layer returns.

        Returns:
            Dict[str, Any]: normalized capability return value; ``tool_calls``
            and ``error`` keys are present only when populated.
        """
        payload: Dict[str, Any] = {
            "success": self.success,
            "response": self.completion.response,
            "reasoning": self.completion.reasoning,
            "model_name": self.completion.model_name,
        }
        calls = self.completion.tool_calls
        if calls is not None:
            serialized: List[Dict[str, Any]] = []
            for call in calls:
                serialized.append(
                    {
                        "id": call.call_id,
                        "function": {
                            "name": call.func_name,
                            "arguments": call.args or {},
                        },
                    }
                )
            payload["tool_calls"] = serialized
        if self.error:
            payload["error"] = self.error
        return payload
|
||||
|
||||
|
||||
@dataclass(slots=True)
class LLMGenerationOptions(BaseDataModel):
    """Per-call options for LLM text generation.

    Fields left as ``None`` presumably fall back to the service-level
    defaults — confirm against the service implementation.
    """

    temperature: float | None = None
    max_tokens: int | None = None
    tool_options: List[ToolDefinitionInput] | None = None
    response_format: RespFormat | None = None
    interrupt_flag: asyncio.Event | None = None
    # When False, callers opt out of raising on an empty completion
    # (see ToolExecutor usage); default keeps the stricter behavior.
    raise_when_empty: bool = True
|
||||
|
||||
|
||||
@dataclass(slots=True)
class LLMImageOptions(BaseDataModel):
    """Per-call options for LLM image understanding.

    All fields are optional overrides; ``None`` leaves the service default
    in effect.
    """

    temperature: float | None = None
    max_tokens: int | None = None
    interrupt_flag: asyncio.Event | None = None
|
||||
|
||||
|
||||
@dataclass(slots=True)
class LLMAudioTranscriptionResult(BaseDataModel):
    """Result of an LLM audio transcription call."""

    # Transcribed text; None when transcription produced nothing.
    text: str | None = None
|
||||
|
||||
|
||||
@dataclass(slots=True)
class LLMEmbeddingResult(BaseDataModel):
    """Result of an LLM embedding call."""

    # Embedding vector; empty list when no embedding was produced.
    embedding: List[float] = field(default_factory=list)
    # Name of the model that produced the vector.
    model_name: str = field(default_factory=str)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"LLMAudioTranscriptionResult",
|
||||
"LLMEmbeddingResult",
|
||||
"LLMGenerationOptions",
|
||||
"LLMImageOptions",
|
||||
"LLMResponseResult",
|
||||
"LLMServiceRequest",
|
||||
"LLMServiceResult",
|
||||
"MessageFactory",
|
||||
"PromptInput",
|
||||
"PromptMessage",
|
||||
]
|
||||
@@ -4,16 +4,15 @@ from typing import Optional
|
||||
import base64
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
logger = get_logger("voice_utils")
|
||||
|
||||
# TODO: 在LLMRequest重构后修改这里
|
||||
asr_model = LLMRequest(model_set=model_config.model_task_config.voice, request_type="audio")
|
||||
asr_model = LLMServiceClient(task_name="voice", request_type="audio")
|
||||
|
||||
|
||||
async def get_voice_text(voice_bytes: bytes) -> Optional[str]:
|
||||
@@ -30,7 +29,8 @@ async def get_voice_text(voice_bytes: bytes) -> Optional[str]:
|
||||
return None
|
||||
try:
|
||||
voice_base64 = base64.b64encode(voice_bytes).decode("utf-8")
|
||||
text = await asr_model.generate_response_for_voice(voice_base64)
|
||||
transcription_result = await asr_model.transcribe_audio(voice_base64)
|
||||
text = transcription_result.text
|
||||
if not text:
|
||||
logger.warning("语音转文字结果为空")
|
||||
|
||||
|
||||
@@ -1,7 +1,35 @@
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from .config_base import ConfigBase, Field
|
||||
from src.common.i18n import t
|
||||
from .config_base import ConfigBase, Field
|
||||
|
||||
|
||||
class OpenAICompatibleAuthType(str, Enum):
    """Authentication schemes supported by OpenAI-compatible endpoints."""

    BEARER = "bearer"  # key sent via the standard Bearer Authorization header
    HEADER = "header"  # key sent in a configurable header (auth_header_name/prefix)
    QUERY = "query"    # key sent as a query parameter (auth_query_name)
    NONE = "none"      # endpoint requires no authentication
|
||||
|
||||
|
||||
class ReasoningParseMode(str, Enum):
    """Strategy for extracting reasoning content from model output."""

    AUTO = "auto"            # pick a strategy automatically
    NATIVE = "native"        # use the provider-native reasoning field — confirm in parser
    THINK_TAG = "think_tag"  # parse think-tag style markers — confirm in parser
    NONE = "none"            # do not parse reasoning content
|
||||
|
||||
|
||||
class ToolArgumentParseMode(str, Enum):
    """Strategy for parsing tool-call argument payloads."""

    AUTO = "auto"                    # pick a strategy automatically
    STRICT = "strict"                # strict parsing — confirm exact semantics in parser
    REPAIR = "repair"                # attempt to repair malformed arguments
    DOUBLE_DECODE = "double_decode"  # decode twice (double-encoded JSON) — confirm in parser
|
||||
|
||||
|
||||
class APIProvider(ConfigBase):
|
||||
@@ -33,7 +61,7 @@ class APIProvider(ConfigBase):
|
||||
"x-icon": "key",
|
||||
},
|
||||
)
|
||||
"""API密钥"""
|
||||
"""API密钥。对于不需要鉴权的兼容端点,可将 `auth_type` 设为 `none`。"""
|
||||
|
||||
client_type: str = Field(
|
||||
default="openai",
|
||||
@@ -44,6 +72,105 @@ class APIProvider(ConfigBase):
|
||||
)
|
||||
"""客户端类型 (可选: openai/google, 默认为openai)"""
|
||||
|
||||
auth_type: str = Field(
|
||||
default=OpenAICompatibleAuthType.BEARER.value,
|
||||
json_schema_extra={
|
||||
"x-widget": "select",
|
||||
"x-icon": "shield",
|
||||
},
|
||||
)
|
||||
"""OpenAI 兼容接口的鉴权方式。可选值:`bearer`、`header`、`query`、`none`。"""
|
||||
|
||||
auth_header_name: str = Field(
|
||||
default="Authorization",
|
||||
json_schema_extra={
|
||||
"x-widget": "input",
|
||||
"x-icon": "header",
|
||||
},
|
||||
)
|
||||
"""当 `auth_type` 为 `header` 时使用的请求头名称。"""
|
||||
|
||||
auth_header_prefix: str = Field(
|
||||
default="Bearer",
|
||||
json_schema_extra={
|
||||
"x-widget": "input",
|
||||
"x-icon": "shield-check",
|
||||
},
|
||||
)
|
||||
"""当 `auth_type` 为 `header` 时使用的请求头前缀。留空表示直接发送原始密钥。"""
|
||||
|
||||
auth_query_name: str = Field(
|
||||
default="api_key",
|
||||
json_schema_extra={
|
||||
"x-widget": "input",
|
||||
"x-icon": "link",
|
||||
},
|
||||
)
|
||||
"""当 `auth_type` 为 `query` 时使用的查询参数名称。"""
|
||||
|
||||
default_headers: dict[str, str] = Field(
|
||||
default_factory=dict,
|
||||
json_schema_extra={
|
||||
"x-widget": "custom",
|
||||
"x-icon": "header",
|
||||
},
|
||||
)
|
||||
"""所有请求默认附带的 HTTP Header。"""
|
||||
|
||||
default_query: dict[str, str] = Field(
|
||||
default_factory=dict,
|
||||
json_schema_extra={
|
||||
"x-widget": "custom",
|
||||
"x-icon": "list-filter",
|
||||
},
|
||||
)
|
||||
"""所有请求默认附带的查询参数。"""
|
||||
|
||||
organization: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"x-widget": "input",
|
||||
"x-icon": "building-2",
|
||||
},
|
||||
)
|
||||
"""OpenAI 官方接口可选的 `organization`。"""
|
||||
|
||||
project: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"x-widget": "input",
|
||||
"x-icon": "folder-kanban",
|
||||
},
|
||||
)
|
||||
"""OpenAI 官方接口可选的 `project`。"""
|
||||
|
||||
model_list_endpoint: str = Field(
|
||||
default="/models",
|
||||
json_schema_extra={
|
||||
"x-widget": "input",
|
||||
"x-icon": "list",
|
||||
},
|
||||
)
|
||||
"""模型列表端点路径。适用于 OpenAI 兼容接口的探测与管理。"""
|
||||
|
||||
reasoning_parse_mode: str = Field(
|
||||
default=ReasoningParseMode.AUTO.value,
|
||||
json_schema_extra={
|
||||
"x-widget": "select",
|
||||
"x-icon": "brain",
|
||||
},
|
||||
)
|
||||
"""推理内容解析模式。可选值:`auto`、`native`、`think_tag`、`none`。"""
|
||||
|
||||
tool_argument_parse_mode: str = Field(
|
||||
default=ToolArgumentParseMode.AUTO.value,
|
||||
json_schema_extra={
|
||||
"x-widget": "select",
|
||||
"x-icon": "braces",
|
||||
},
|
||||
)
|
||||
"""工具参数解析模式。可选值:`auto`、`strict`、`repair`、`double_decode`。"""
|
||||
|
||||
max_retry: int = Field(
|
||||
default=2,
|
||||
ge=0,
|
||||
@@ -76,15 +203,26 @@ class APIProvider(ConfigBase):
|
||||
)
|
||||
"""重试间隔 (如果API调用失败, 重试的间隔时间, 单位: 秒)"""
|
||||
|
||||
def model_post_init(self, context: Any = None):
|
||||
"""确保api_key在repr中不被显示"""
|
||||
if not self.api_key:
|
||||
def model_post_init(self, context: Any = None) -> None:
|
||||
"""执行 API 提供商配置的后置校验。
|
||||
|
||||
Args:
|
||||
context: Pydantic 传入的上下文对象。
|
||||
|
||||
Raises:
|
||||
ValueError: 当配置项缺失或组合不合法时抛出。
|
||||
"""
|
||||
if self.auth_type != OpenAICompatibleAuthType.NONE and not self.api_key:
|
||||
raise ValueError(t("config.api_key_empty"))
|
||||
if not self.base_url and self.client_type != "gemini": # TODO: 允许gemini使用base_url
|
||||
raise ValueError(t("config.api_base_url_empty"))
|
||||
if not self.name:
|
||||
raise ValueError(t("config.api_provider_name_empty"))
|
||||
return super().model_post_init(context)
|
||||
if self.auth_type == OpenAICompatibleAuthType.HEADER and not self.auth_header_name.strip():
|
||||
raise ValueError("当 auth_type=header 时,auth_header_name 不能为空")
|
||||
if self.auth_type == OpenAICompatibleAuthType.QUERY and not self.auth_query_name.strip():
|
||||
raise ValueError("当 auth_type=query 时,auth_query_name 不能为空")
|
||||
super().model_post_init(context)
|
||||
|
||||
|
||||
class ModelInfo(ConfigBase):
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
import copy
|
||||
import warnings
|
||||
from dataclasses import dataclass, field, fields
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import copy
|
||||
import warnings
|
||||
|
||||
from maim_message import Seg
|
||||
|
||||
from src.llm_models.payload_content.tool_option import ToolParamType as ToolParamType
|
||||
from src.llm_models.payload_content.tool_option import ToolCall as ToolCall
|
||||
from src.llm_models.payload_content.tool_option import ToolCall
|
||||
# from src.common.data_models.message_data_model import ReplyContentType as ReplyContentType
|
||||
# from src.common.data_models.message_data_model import ReplyContent as ReplyContent
|
||||
# from src.common.data_models.message_data_model import ForwardNode as ForwardNode
|
||||
@@ -15,49 +16,42 @@ from src.llm_models.payload_content.tool_option import ToolCall as ToolCall
|
||||
|
||||
# 组件类型枚举
|
||||
class ComponentType(Enum):
|
||||
"""组件类型枚举"""
|
||||
"""Host 内部使用的组件类型枚举。"""
|
||||
|
||||
ACTION = "action" # 动作组件
|
||||
COMMAND = "command" # 命令组件
|
||||
TOOL = "tool" # 服务组件(预留)
|
||||
SCHEDULER = "scheduler" # 定时任务组件(预留)
|
||||
EVENT_HANDLER = "event_handler" # 事件处理组件(预留)
|
||||
TOOL = "tool" # 工具组件
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""返回枚举值字符串。
|
||||
|
||||
Returns:
|
||||
str: 当前组件类型对应的字符串值。
|
||||
"""
|
||||
return self.value
|
||||
|
||||
|
||||
# 动作激活类型枚举
|
||||
class ActionActivationType(Enum):
|
||||
"""动作激活类型枚举"""
|
||||
"""动作激活类型枚举。"""
|
||||
|
||||
NEVER = "never" # 从不激活(默认关闭)
|
||||
ALWAYS = "always" # 默认参与到planner
|
||||
RANDOM = "random" # 随机启用action到planner
|
||||
KEYWORD = "keyword" # 关键词触发启用action到planner
|
||||
|
||||
def __str__(self):
|
||||
return self.value
|
||||
def __str__(self) -> str:
|
||||
"""返回枚举值字符串。
|
||||
|
||||
|
||||
# 聊天模式枚举
|
||||
class ChatMode(Enum):
|
||||
"""聊天模式枚举"""
|
||||
|
||||
FOCUS = "focus" # Focus聊天模式
|
||||
NORMAL = "normal" # Normal聊天模式
|
||||
PRIORITY = "priority" # 优先级聊天模式
|
||||
ALL = "all" # 所有聊天模式
|
||||
|
||||
def __str__(self):
|
||||
Returns:
|
||||
str: 当前激活类型对应的字符串值。
|
||||
"""
|
||||
return self.value
|
||||
|
||||
|
||||
# 事件类型枚举
|
||||
class EventType(Enum):
|
||||
"""
|
||||
事件类型枚举类
|
||||
"""
|
||||
"""事件类型枚举。"""
|
||||
|
||||
ON_START = "on_start" # 启动事件,用于调用按时任务
|
||||
ON_STOP = "on_stop" # 停止事件,用于调用按时任务
|
||||
@@ -72,185 +66,96 @@ class EventType(Enum):
|
||||
UNKNOWN = "unknown" # 未知事件类型
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""返回枚举值字符串。
|
||||
|
||||
Returns:
|
||||
str: 当前事件类型对应的字符串值。
|
||||
"""
|
||||
return self.value
|
||||
|
||||
|
||||
@dataclass
class PythonDependency:
    """Declares one Python package dependency of a plugin."""

    package_name: str  # importable module name
    version: str = ""  # version constraint such as ">=1.0.0" or "==2.1.3"; "" means any
    optional: bool = False  # optional dependencies are tolerated when missing
    description: str = ""  # free-form note on why the dependency is needed
    install_name: str = ""  # pip package name, when it differs from the import name

    def __post_init__(self) -> None:
        # Default the pip install name to the import name.
        if not self.install_name:
            self.install_name = self.package_name

    def get_pip_requirement(self) -> str:
        """Return the pip requirement string, e.g. ``"numpy>=1.0"``."""
        return f"{self.install_name}{self.version}" if self.version else self.install_name
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclass(slots=True)
|
||||
class ComponentInfo:
|
||||
"""组件信息"""
|
||||
"""Host 内部使用的组件信息快照。"""
|
||||
|
||||
name: str # 组件名称
|
||||
component_type: ComponentType # 组件类型
|
||||
description: str = "" # 组件描述
|
||||
enabled: bool = True # 是否启用
|
||||
plugin_name: str = "" # 所属插件名称
|
||||
is_built_in: bool = False # 是否为内置组件
|
||||
metadata: Dict[str, Any] = field(default_factory=dict) # 额外元数据
|
||||
name: str
|
||||
"""组件名称。"""
|
||||
|
||||
def __post_init__(self):
|
||||
if self.metadata is None:
|
||||
self.metadata = {}
|
||||
description: str = ""
|
||||
"""组件描述。"""
|
||||
|
||||
enabled: bool = True
|
||||
"""组件是否启用。"""
|
||||
|
||||
plugin_name: str = ""
|
||||
"""所属插件 ID。"""
|
||||
|
||||
component_type: ComponentType = field(init=False)
|
||||
"""组件类型。"""
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclass(slots=True)
|
||||
class ActionInfo(ComponentInfo):
|
||||
"""动作组件信息"""
|
||||
"""供 Planner 与回复链使用的动作信息快照。"""
|
||||
|
||||
action_parameters: Dict[str, str] = field(
|
||||
default_factory=dict
|
||||
) # 动作参数与描述,例如 {"param1": "描述1", "param2": "描述2"}
|
||||
action_require: List[str] = field(default_factory=list) # 动作需求说明
|
||||
associated_types: List[str] = field(default_factory=list) # 关联的消息类型
|
||||
# 激活类型相关
|
||||
focus_activation_type: ActionActivationType = ActionActivationType.ALWAYS # 已弃用
|
||||
normal_activation_type: ActionActivationType = ActionActivationType.ALWAYS # 已弃用
|
||||
activation_type: ActionActivationType = ActionActivationType.ALWAYS
|
||||
random_activation_probability: float = 0.0
|
||||
activation_keywords: List[str] = field(default_factory=list) # 激活关键词列表
|
||||
keyword_case_sensitive: bool = False
|
||||
# 模式和并行设置
|
||||
parallel_action: bool = False
|
||||
component_type: ComponentType = field(init=False, default=ComponentType.ACTION)
|
||||
"""组件类型。"""
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
if self.activation_keywords is None:
|
||||
self.activation_keywords = []
|
||||
if self.action_parameters is None:
|
||||
self.action_parameters = {}
|
||||
if self.action_require is None:
|
||||
self.action_require = []
|
||||
if self.associated_types is None:
|
||||
self.associated_types = []
|
||||
self.component_type = ComponentType.ACTION
|
||||
def __post_init__(self) -> None:
|
||||
"""归一化动作快照中的集合字段。"""
|
||||
self.action_parameters = dict(self.action_parameters or {})
|
||||
self.action_require = list(self.action_require or [])
|
||||
self.associated_types = list(self.associated_types or [])
|
||||
self.activation_keywords = list(self.activation_keywords or [])
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclass(slots=True)
|
||||
class CommandInfo(ComponentInfo):
|
||||
"""命令组件信息"""
|
||||
"""供命令处理链使用的命令信息快照。"""
|
||||
|
||||
command_pattern: str = "" # 命令匹配模式(正则表达式)
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
self.component_type = ComponentType.COMMAND
|
||||
component_type: ComponentType = field(init=False, default=ComponentType.COMMAND)
|
||||
"""组件类型。"""
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclass(slots=True)
|
||||
class ToolInfo(ComponentInfo):
|
||||
"""工具组件信息"""
|
||||
"""供工具执行链使用的工具信息快照。"""
|
||||
|
||||
tool_parameters: List[Tuple[str, ToolParamType, str, bool, List[str] | None]] = field(
|
||||
default_factory=list
|
||||
) # 工具参数定义
|
||||
tool_description: str = "" # 工具描述
|
||||
parameters_schema: Dict[str, Any] | None = None
|
||||
"""对象级工具参数 Schema。"""
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
self.component_type = ComponentType.TOOL
|
||||
component_type: ComponentType = field(init=False, default=ComponentType.TOOL)
|
||||
"""组件类型。"""
|
||||
|
||||
def get_llm_definition(self) -> dict:
|
||||
"""生成 LLM function-calling 所需的工具定义"""
|
||||
return {
|
||||
def get_llm_definition(self) -> Dict[str, Any]:
|
||||
"""生成供 LLM 使用的规范化工具定义。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 统一工具定义字典。
|
||||
"""
|
||||
definition: Dict[str, Any] = {
|
||||
"name": self.name,
|
||||
"description": self.tool_description,
|
||||
"parameters": self.tool_parameters,
|
||||
"description": self.description,
|
||||
}
|
||||
if self.parameters_schema is not None:
|
||||
definition["parameters_schema"] = copy.deepcopy(self.parameters_schema)
|
||||
return definition
|
||||
|
||||
|
||||
@dataclass
|
||||
class EventHandlerInfo(ComponentInfo):
|
||||
"""事件处理器组件信息"""
|
||||
|
||||
event_type: EventType | str = EventType.ON_MESSAGE # 监听事件类型
|
||||
intercept_message: bool = False # 是否拦截消息处理(默认不拦截)
|
||||
weight: int = 0 # 事件处理器权重,决定执行顺序
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
self.component_type = ComponentType.EVENT_HANDLER
|
||||
|
||||
|
||||
@dataclass
class PluginInfo:
    """Aggregated metadata describing a plugin and the components it ships."""

    display_name: str  # human-readable plugin name
    name: str  # plugin identifier
    description: str  # plugin description
    version: str = "1.0.0"  # plugin version
    author: str = ""  # plugin author
    enabled: bool = True  # whether the plugin is enabled
    is_built_in: bool = False  # whether this is a built-in plugin
    components: List[ComponentInfo] = field(default_factory=list)  # contained components
    dependencies: List[str] = field(default_factory=list)  # other plugins required
    python_dependencies: List[PythonDependency] = field(default_factory=list)  # pip deps
    config_file: str = ""  # path to the plugin's config file
    metadata: Dict[str, Any] = field(default_factory=dict)  # extra metadata
    # Manifest-derived information.
    manifest_data: Dict[str, Any] = field(default_factory=dict)  # raw manifest data
    license: str = ""  # plugin license
    homepage_url: str = ""  # plugin homepage
    repository_url: str = ""  # plugin repository URL
    keywords: List[str] = field(default_factory=list)  # plugin keywords
    categories: List[str] = field(default_factory=list)  # plugin categories
    min_host_version: str = ""  # minimum host version required
    max_host_version: str = ""  # maximum host version supported

    def __post_init__(self) -> None:
        """Normalize explicitly-passed ``None`` collections back to empty containers."""
        for list_attr in ("components", "dependencies", "python_dependencies", "keywords", "categories"):
            if getattr(self, list_attr) is None:
                setattr(self, list_attr, [])
        for dict_attr in ("metadata", "manifest_data"):
            if getattr(self, dict_attr) is None:
                setattr(self, dict_attr, {})

    @staticmethod
    def _is_importable(module_name: str) -> bool:
        """Return True when *module_name* can be imported."""
        try:
            __import__(module_name)
        except ImportError:
            return False
        return True

    def get_missing_packages(self) -> List[PythonDependency]:
        """Return required (non-optional) Python dependencies that fail to import."""
        return [
            dep
            for dep in self.python_dependencies
            if not self._is_importable(dep.package_name) and not dep.optional
        ]

    def get_pip_requirements(self) -> List[str]:
        """Return pip requirement strings for every declared Python dependency."""
        return [dep.get_pip_requirement() for dep in self.python_dependencies]
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclass(slots=True)
|
||||
class ModifyFlag:
|
||||
"""消息修改标记集合。"""
|
||||
|
||||
modify_message_segments: bool = False
|
||||
modify_plain_text: bool = False
|
||||
modify_llm_prompt: bool = False
|
||||
@@ -258,9 +163,9 @@ class ModifyFlag:
|
||||
modify_llm_response_reasoning: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclass(slots=True)
|
||||
class MaiMessages:
|
||||
"""MaiM插件消息"""
|
||||
"""核心事件系统使用的统一消息模型。"""
|
||||
|
||||
message_segments: List[Seg] = field(default_factory=list)
|
||||
"""消息段列表,支持多段消息"""
|
||||
@@ -306,11 +211,17 @@ class MaiMessages:
|
||||
|
||||
_modify_flags: ModifyFlag = field(default_factory=ModifyFlag)
|
||||
|
||||
def __post_init__(self):
|
||||
def __post_init__(self) -> None:
|
||||
"""归一化消息段列表。"""
|
||||
if self.message_segments is None:
|
||||
self.message_segments = []
|
||||
|
||||
def deepcopy(self):
|
||||
def deepcopy(self) -> "MaiMessages":
|
||||
"""深拷贝当前消息对象。
|
||||
|
||||
Returns:
|
||||
MaiMessages: 深拷贝后的消息对象。
|
||||
"""
|
||||
return copy.deepcopy(self)
|
||||
|
||||
def to_transport_dict(self) -> Dict[str, Any]:
|
||||
@@ -347,6 +258,14 @@ class MaiMessages:
|
||||
|
||||
@staticmethod
|
||||
def _serialize_transport_value(value: Any) -> Any:
|
||||
"""递归序列化字段值为可传输结构。
|
||||
|
||||
Args:
|
||||
value: 任意字段值。
|
||||
|
||||
Returns:
|
||||
Any: 可用于 IPC 传输的纯 Python 值。
|
||||
"""
|
||||
if isinstance(value, (str, int, float, bool)) or value is None:
|
||||
return value
|
||||
if isinstance(value, Enum):
|
||||
@@ -367,13 +286,22 @@ class MaiMessages:
|
||||
|
||||
@staticmethod
|
||||
def _deserialize_transport_field(field_name: str, value: Any) -> Any:
|
||||
"""反序列化特定字段的传输值。
|
||||
|
||||
Args:
|
||||
field_name: 字段名称。
|
||||
value: 传输层返回的字段值。
|
||||
|
||||
Returns:
|
||||
Any: 反序列化后的字段值。
|
||||
"""
|
||||
if field_name == "message_segments" and isinstance(value, list):
|
||||
deserialized_segments: List[Seg] = []
|
||||
for segment in value:
|
||||
if isinstance(segment, Seg):
|
||||
deserialized_segments.append(segment)
|
||||
elif isinstance(segment, dict) and "type" in segment:
|
||||
deserialized_segments.append(Seg(type=segment.get("type", "text"), data=segment.get("data")))
|
||||
deserialized_segments.append(Seg(type=segment.get("type", "text"), data=segment.get("data", "")))
|
||||
return deserialized_segments
|
||||
|
||||
if field_name == "llm_response_tool_call" and isinstance(value, list):
|
||||
@@ -393,15 +321,15 @@ class MaiMessages:
|
||||
|
||||
return value
|
||||
|
||||
def modify_message_segments(self, new_segments: List[Seg], suppress_warning: bool = False):
|
||||
"""
|
||||
修改消息段列表
|
||||
def modify_message_segments(self, new_segments: List[Seg], suppress_warning: bool = False) -> None:
|
||||
"""修改消息段列表。
|
||||
|
||||
Warning:
|
||||
在生成了plain_text的情况下调用此方法,可能会导致plain_text内容与消息段不一致
|
||||
在生成了 ``plain_text`` 的情况下调用此方法,可能会导致文本与消息段不一致。
|
||||
|
||||
Args:
|
||||
new_segments (List[Seg]): 新的消息段列表
|
||||
new_segments: 新的消息段列表。
|
||||
suppress_warning: 是否抑制潜在不一致警告。
|
||||
"""
|
||||
if self.plain_text and not suppress_warning:
|
||||
warnings.warn(
|
||||
@@ -412,15 +340,15 @@ class MaiMessages:
|
||||
self.message_segments = new_segments
|
||||
self._modify_flags.modify_message_segments = True
|
||||
|
||||
def modify_llm_prompt(self, new_prompt: str, suppress_warning: bool = False):
|
||||
"""
|
||||
修改LLM提示词
|
||||
def modify_llm_prompt(self, new_prompt: str, suppress_warning: bool = False) -> None:
|
||||
"""修改 LLM 提示词。
|
||||
|
||||
Warning:
|
||||
在没有生成llm_prompt的情况下调用此方法,可能会导致修改无效
|
||||
在没有生成 ``llm_prompt`` 的情况下调用此方法,可能会导致修改无效。
|
||||
|
||||
Args:
|
||||
new_prompt (str): 新的提示词内容
|
||||
new_prompt: 新的提示词内容。
|
||||
suppress_warning: 是否抑制潜在无效修改警告。
|
||||
"""
|
||||
if self.llm_prompt is None and not suppress_warning:
|
||||
warnings.warn(
|
||||
@@ -431,15 +359,15 @@ class MaiMessages:
|
||||
self.llm_prompt = new_prompt
|
||||
self._modify_flags.modify_llm_prompt = True
|
||||
|
||||
def modify_plain_text(self, new_text: str, suppress_warning: bool = False):
|
||||
"""
|
||||
修改生成的plain_text内容
|
||||
def modify_plain_text(self, new_text: str, suppress_warning: bool = False) -> None:
|
||||
"""修改生成的纯文本内容。
|
||||
|
||||
Warning:
|
||||
在未生成plain_text的情况下调用此方法,可能会导致plain_text为空或者修改无效
|
||||
在未生成 ``plain_text`` 的情况下调用此方法,可能会导致修改无效。
|
||||
|
||||
Args:
|
||||
new_text (str): 新的纯文本内容
|
||||
new_text: 新的纯文本内容。
|
||||
suppress_warning: 是否抑制潜在无效修改警告。
|
||||
"""
|
||||
if not self.plain_text and not suppress_warning:
|
||||
warnings.warn(
|
||||
@@ -450,15 +378,15 @@ class MaiMessages:
|
||||
self.plain_text = new_text
|
||||
self._modify_flags.modify_plain_text = True
|
||||
|
||||
def modify_llm_response_content(self, new_content: str, suppress_warning: bool = False):
|
||||
"""
|
||||
修改生成的llm_response_content内容
|
||||
def modify_llm_response_content(self, new_content: str, suppress_warning: bool = False) -> None:
|
||||
"""修改生成的 LLM 响应正文。
|
||||
|
||||
Warning:
|
||||
在未生成llm_response_content的情况下调用此方法,可能会导致llm_response_content为空或者修改无效
|
||||
在未生成 ``llm_response_content`` 的情况下调用此方法,可能会导致修改无效。
|
||||
|
||||
Args:
|
||||
new_content (str): 新的LLM响应内容
|
||||
new_content: 新的 LLM 响应内容。
|
||||
suppress_warning: 是否抑制潜在无效修改警告。
|
||||
"""
|
||||
if not self.llm_response_content and not suppress_warning:
|
||||
warnings.warn(
|
||||
@@ -469,15 +397,15 @@ class MaiMessages:
|
||||
self.llm_response_content = new_content
|
||||
self._modify_flags.modify_llm_response_content = True
|
||||
|
||||
def modify_llm_response_reasoning(self, new_reasoning: str, suppress_warning: bool = False):
|
||||
"""
|
||||
修改生成的llm_response_reasoning内容
|
||||
def modify_llm_response_reasoning(self, new_reasoning: str, suppress_warning: bool = False) -> None:
|
||||
"""修改生成的 LLM 推理内容。
|
||||
|
||||
Warning:
|
||||
在未生成llm_response_reasoning的情况下调用此方法,可能会导致llm_response_reasoning为空或者修改无效
|
||||
在未生成 ``llm_response_reasoning`` 的情况下调用此方法,可能会导致修改无效。
|
||||
|
||||
Args:
|
||||
new_reasoning (str): 新的LLM响应推理内容
|
||||
new_reasoning: 新的 LLM 推理内容。
|
||||
suppress_warning: 是否抑制潜在无效修改警告。
|
||||
"""
|
||||
if not self.llm_response_reasoning and not suppress_warning:
|
||||
warnings.warn(
|
||||
@@ -487,10 +415,3 @@ class MaiMessages:
|
||||
)
|
||||
self.llm_response_reasoning = new_reasoning
|
||||
self._modify_flags.modify_llm_response_reasoning = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class CustomEventHandlerResult:
|
||||
message: str = ""
|
||||
timestamp: float = 0.0
|
||||
extra_info: Optional[Dict] = None
|
||||
|
||||
@@ -20,8 +20,8 @@ from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import Expression
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.config.config import model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.common.data_models.llm_service_data_models import LLMGenerationOptions
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
from src.manager.async_task_manager import AsyncTask
|
||||
|
||||
logger = get_logger("expression_auto_check_task")
|
||||
@@ -76,7 +76,7 @@ def create_evaluation_prompt(situation: str, style: str) -> str:
|
||||
return prompt
|
||||
|
||||
|
||||
judge_llm = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="expression_check")
|
||||
judge_llm = LLMServiceClient(task_name="tool_use", request_type="expression_check")
|
||||
|
||||
|
||||
async def single_expression_check(situation: str, style: str) -> tuple[bool, str, str | None]:
|
||||
@@ -94,10 +94,11 @@ async def single_expression_check(situation: str, style: str) -> tuple[bool, str
|
||||
prompt = create_evaluation_prompt(situation, style)
|
||||
logger.debug(f"正在评估表达方式: situation={situation}, style={style}")
|
||||
|
||||
response, (reasoning, model_name, _) = await judge_llm.generate_response_async(
|
||||
prompt=prompt, temperature=0.6, max_tokens=1024
|
||||
generation_result = await judge_llm.generate_response(
|
||||
prompt=prompt,
|
||||
options=LLMGenerationOptions(temperature=0.6, max_tokens=1024),
|
||||
)
|
||||
|
||||
response = generation_result.response
|
||||
logger.debug(f"LLM响应: {response}")
|
||||
|
||||
# 解析JSON响应
|
||||
|
||||
@@ -7,8 +7,9 @@ import difflib
|
||||
import json
|
||||
import re
|
||||
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config, global_config
|
||||
from src.common.data_models.llm_service_data_models import LLMGenerationOptions
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
from src.config.config import global_config
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import Expression
|
||||
@@ -26,10 +27,11 @@ if TYPE_CHECKING:
|
||||
|
||||
logger = get_logger("expressor")
|
||||
|
||||
# TODO: 重构完LLM相关内容后,替换成新的模型调用方式
|
||||
express_learn_model = LLMRequest(model_set=model_config.model_task_config.utils, request_type="expression.learner")
|
||||
summary_model = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="expression.summary")
|
||||
check_model = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="expression.check")
|
||||
express_learn_model = LLMServiceClient(
|
||||
task_name="utils", request_type="expression.learner"
|
||||
)
|
||||
summary_model = LLMServiceClient(task_name="tool_use", request_type="expression.summary")
|
||||
check_model = LLMServiceClient(task_name="tool_use", request_type="expression.check")
|
||||
|
||||
|
||||
class ExpressionLearner:
|
||||
@@ -74,7 +76,10 @@ class ExpressionLearner:
|
||||
|
||||
# 调用 LLM 学习表达方式
|
||||
try:
|
||||
response, _ = await express_learn_model.generate_response_async(prompt, temperature=0.3)
|
||||
generation_result = await express_learn_model.generate_response(
|
||||
prompt, options=LLMGenerationOptions(temperature=0.3)
|
||||
)
|
||||
response = generation_result.response
|
||||
except Exception as e:
|
||||
logger.error(f"学习表达方式失败,模型生成出错:{e}")
|
||||
return
|
||||
@@ -413,7 +418,10 @@ class ExpressionLearner:
|
||||
"只输出概括内容。"
|
||||
)
|
||||
try:
|
||||
summary, _ = await summary_model.generate_response_async(prompt, temperature=0.2)
|
||||
summary_result = await summary_model.generate_response(
|
||||
prompt, options=LLMGenerationOptions(temperature=0.2)
|
||||
)
|
||||
summary = summary_result.response
|
||||
if summary := summary.strip():
|
||||
return summary
|
||||
except Exception as e:
|
||||
|
||||
@@ -4,10 +4,11 @@ import time
|
||||
from typing import List, Dict, Optional, Any, Tuple
|
||||
from json_repair import repair_json
|
||||
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config, model_config
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
from src.config.config import global_config
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import Expression
|
||||
from src.common.utils.utils_session import SessionUtils
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
from src.learners.learner_utils_old import weighted_sample
|
||||
from src.chat.utils.common_utils import TempMethodsExpression
|
||||
@@ -17,8 +18,8 @@ logger = get_logger("expression_selector")
|
||||
|
||||
class ExpressionSelector:
|
||||
def __init__(self):
|
||||
self.llm_model = LLMRequest(
|
||||
model_set=model_config.model_task_config.tool_use, request_type="expression.selector"
|
||||
self.llm_model = LLMServiceClient(
|
||||
task_name="tool_use", request_type="expression.selector"
|
||||
)
|
||||
|
||||
def can_use_expression_for_chat(self, chat_id: str) -> bool:
|
||||
@@ -383,8 +384,8 @@ class ExpressionSelector:
|
||||
prompt = await prompt_manager.render_prompt(prompt_template)
|
||||
|
||||
# 4. 调用LLM
|
||||
content, (reasoning_content, model_name, _) = await self.llm_model.generate_response_async(prompt=prompt)
|
||||
|
||||
generation_result = await self.llm_model.generate_response(prompt=prompt)
|
||||
content = generation_result.response
|
||||
# print(prompt)
|
||||
# print(content)
|
||||
|
||||
|
||||
@@ -1,19 +1,40 @@
|
||||
from json_repair import repair_json
|
||||
from typing import Tuple, Optional, List
|
||||
from typing import Any, List, Optional, Tuple
|
||||
|
||||
import json
|
||||
import re
|
||||
|
||||
from src.config.config import model_config
|
||||
from src.config.config import global_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.common.data_models.llm_service_data_models import LLMGenerationOptions
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("expression_utils")
|
||||
|
||||
# TODO: 重构完LLM相关内容后,替换成新的模型调用方式
|
||||
judge_llm = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="expression_check")
|
||||
judge_llm = LLMServiceClient(task_name="tool_use", request_type="expression_check")
|
||||
|
||||
|
||||
def _normalize_repair_json_result(repaired_result: Any) -> str:
|
||||
"""将 repair_json 的返回值规范化为 JSON 字符串。
|
||||
|
||||
Args:
|
||||
repaired_result: `repair_json` 的返回值,可能是字符串或带附加信息的元组。
|
||||
|
||||
Returns:
|
||||
str: 可供 `json.loads` 继续解析的 JSON 字符串。
|
||||
|
||||
Raises:
|
||||
TypeError: 当返回值无法规范化为字符串时抛出。
|
||||
"""
|
||||
if isinstance(repaired_result, str):
|
||||
return repaired_result
|
||||
if isinstance(repaired_result, tuple) and repaired_result:
|
||||
first_item = repaired_result[0]
|
||||
if isinstance(first_item, str):
|
||||
return first_item
|
||||
return json.dumps(first_item, ensure_ascii=False)
|
||||
raise TypeError(f"repair_json 返回了无法处理的结果类型: {type(repaired_result)}")
|
||||
|
||||
|
||||
async def check_expression_suitability(situation: str, style: str) -> Tuple[bool, str, Optional[str]]:
|
||||
@@ -51,7 +72,11 @@ async def check_expression_suitability(situation: str, style: str) -> Tuple[bool
|
||||
|
||||
logger.info(f"正在评估表达方式: situation={situation}, style={style}")
|
||||
|
||||
response, _ = await judge_llm.generate_response_async(prompt=prompt, temperature=0.6, max_tokens=1024)
|
||||
generation_result = await judge_llm.generate_response(
|
||||
prompt=prompt,
|
||||
options=LLMGenerationOptions(temperature=0.6, max_tokens=1024),
|
||||
)
|
||||
response = generation_result.response
|
||||
|
||||
logger.debug(f"评估结果: {response}")
|
||||
|
||||
@@ -59,7 +84,7 @@ async def check_expression_suitability(situation: str, style: str) -> Tuple[bool
|
||||
evaluation = json.loads(response)
|
||||
except json.JSONDecodeError:
|
||||
try:
|
||||
response_repaired = repair_json(response)
|
||||
response_repaired = _normalize_repair_json_result(repair_json(response))
|
||||
evaluation = json.loads(response_repaired)
|
||||
except Exception as e:
|
||||
raise ValueError(f"无法解析LLM响应为JSON: {response}") from e
|
||||
@@ -74,7 +99,7 @@ async def check_expression_suitability(situation: str, style: str) -> Tuple[bool
|
||||
return False, f"评估结果格式错误: {e}", str(e)
|
||||
|
||||
|
||||
def fix_chinese_quotes_in_json(text):
|
||||
def fix_chinese_quotes_in_json(text: str) -> str:
|
||||
"""使用状态机修复 JSON 字符串值中的中文引号"""
|
||||
result = []
|
||||
i = 0
|
||||
@@ -201,12 +226,12 @@ def is_single_char_jargon(content: str) -> bool:
|
||||
)
|
||||
|
||||
|
||||
def _try_parse(text):
|
||||
def _try_parse(text: str) -> Any:
|
||||
try:
|
||||
return json.loads(text)
|
||||
except Exception:
|
||||
try:
|
||||
repaired = repair_json(text)
|
||||
repaired = _normalize_repair_json_result(repair_json(text))
|
||||
return json.loads(repaired)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
@@ -4,8 +4,9 @@ from typing import List, Dict, Optional, Any
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import Jargon
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config, global_config
|
||||
from src.common.data_models.llm_service_data_models import LLMGenerationOptions
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
from src.config.config import global_config
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
from src.learners.jargon_miner_old import search_jargon
|
||||
from src.learners.learner_utils_old import (
|
||||
@@ -23,8 +24,8 @@ class JargonExplainer:
|
||||
|
||||
def __init__(self, chat_id: str) -> None:
|
||||
self.chat_id = chat_id
|
||||
self.llm = LLMRequest(
|
||||
model_set=model_config.model_task_config.tool_use,
|
||||
self.llm = LLMServiceClient(
|
||||
task_name="tool_use",
|
||||
request_type="jargon.explain",
|
||||
)
|
||||
|
||||
@@ -206,7 +207,10 @@ class JargonExplainer:
|
||||
prompt_of_summarize.add_context("jargon_explanations", lambda _: explanations_text)
|
||||
summarize_prompt = await prompt_manager.render_prompt(prompt_of_summarize)
|
||||
|
||||
summary, _ = await self.llm.generate_response_async(summarize_prompt, temperature=0.3)
|
||||
summary_result = await self.llm.generate_response(
|
||||
summarize_prompt, options=LLMGenerationOptions(temperature=0.3)
|
||||
)
|
||||
summary = summary_result.response
|
||||
if not summary:
|
||||
# 如果LLM概括失败,直接返回原始解释
|
||||
return f"上下文中的黑话解释:\n{explanations_text}"
|
||||
|
||||
@@ -12,17 +12,17 @@ from src.common.data_models.jargon_data_model import MaiJargon
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import Jargon
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config
|
||||
from src.common.data_models.llm_service_data_models import LLMGenerationOptions
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
|
||||
from .expression_utils import is_single_char_jargon
|
||||
|
||||
logger = get_logger("jargon")
|
||||
|
||||
# TODO: 重构完LLM相关内容后,替换成新的模型调用方式
|
||||
llm_extract = LLMRequest(model_set=model_config.model_task_config.utils, request_type="jargon.extract")
|
||||
llm_inference = LLMRequest(model_set=model_config.model_task_config.utils, request_type="jargon.inference")
|
||||
llm_extract = LLMServiceClient(task_name="utils", request_type="jargon.extract")
|
||||
llm_inference = LLMServiceClient(task_name="utils", request_type="jargon.inference")
|
||||
|
||||
|
||||
class JargonEntry(TypedDict):
|
||||
@@ -100,7 +100,10 @@ class JargonMiner:
|
||||
prompt1_template.add_context("previous_meaning_instruction", previous_meaning_instruction)
|
||||
prompt1 = await prompt_manager.render_prompt(prompt1_template)
|
||||
|
||||
llm_response_1, _ = await llm_inference.generate_response_async(prompt1, temperature=0.3)
|
||||
generation_result_1 = await llm_inference.generate_response(
|
||||
prompt1, options=LLMGenerationOptions(temperature=0.3)
|
||||
)
|
||||
llm_response_1 = generation_result_1.response
|
||||
if not llm_response_1:
|
||||
logger.warning(f"jargon {content} 推断1失败:无响应")
|
||||
return
|
||||
@@ -129,7 +132,10 @@ class JargonMiner:
|
||||
prompt2_template.add_context("content", content)
|
||||
prompt2 = await prompt_manager.render_prompt(prompt2_template)
|
||||
|
||||
llm_response_2, _ = await llm_inference.generate_response_async(prompt2, temperature=0.3)
|
||||
generation_result_2 = await llm_inference.generate_response(
|
||||
prompt2, options=LLMGenerationOptions(temperature=0.3)
|
||||
)
|
||||
llm_response_2 = generation_result_2.response
|
||||
if not llm_response_2:
|
||||
logger.warning(f"jargon {content} 推断2失败:无响应")
|
||||
return
|
||||
@@ -153,7 +159,10 @@ class JargonMiner:
|
||||
if global_config.debug.show_jargon_prompt:
|
||||
logger.info(f"jargon {content} 比较提示词: {prompt3}")
|
||||
|
||||
llm_response_3, _ = await llm_inference.generate_response_async(prompt3, temperature=0.3)
|
||||
generation_result_3 = await llm_inference.generate_response(
|
||||
prompt3, options=LLMGenerationOptions(temperature=0.3)
|
||||
)
|
||||
llm_response_3 = generation_result_3.response
|
||||
if not llm_response_3:
|
||||
logger.warning(f"jargon {content} 比较失败:无响应")
|
||||
return
|
||||
|
||||
259
src/llm_models/model_client/adapter_base.py
Normal file
259
src/llm_models/model_client/adapter_base.py
Normal file
@@ -0,0 +1,259 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, Coroutine, Generic, Tuple, TypeVar, cast
|
||||
|
||||
import asyncio
|
||||
|
||||
from src.config.model_configs import ModelInfo
|
||||
|
||||
from .base_client import (
|
||||
APIResponse,
|
||||
AudioTranscriptionRequest,
|
||||
BaseClient,
|
||||
EmbeddingRequest,
|
||||
ResponseRequest,
|
||||
UsageRecord,
|
||||
UsageTuple,
|
||||
)
|
||||
|
||||
RawStreamT = TypeVar("RawStreamT")
|
||||
"""流式原始响应类型变量。"""
|
||||
|
||||
RawResponseT = TypeVar("RawResponseT")
|
||||
"""非流式原始响应类型变量。"""
|
||||
|
||||
TaskResultT = TypeVar("TaskResultT")
|
||||
"""异步任务返回值类型变量。"""
|
||||
|
||||
ProviderStreamResponseHandler = Callable[
|
||||
[RawStreamT, asyncio.Event | None],
|
||||
Coroutine[Any, Any, Tuple[APIResponse, UsageTuple | None]],
|
||||
]
|
||||
"""Provider 专用流式响应处理函数类型。"""
|
||||
|
||||
ProviderResponseParser = Callable[[RawResponseT], Tuple[APIResponse, UsageTuple | None]]
|
||||
"""Provider 专用非流式响应解析函数类型。"""
|
||||
|
||||
|
||||
async def await_task_with_interrupt(
|
||||
task: asyncio.Task[TaskResultT],
|
||||
interrupt_flag: asyncio.Event | None,
|
||||
*,
|
||||
interval_seconds: float = 0.1,
|
||||
) -> TaskResultT:
|
||||
"""在支持外部中断的前提下等待异步任务完成。
|
||||
|
||||
Args:
|
||||
task: 待等待的异步任务。
|
||||
interrupt_flag: 外部中断标记。
|
||||
interval_seconds: 轮询检查间隔,单位秒。
|
||||
|
||||
Returns:
|
||||
TaskResultT: 任务执行结果。
|
||||
|
||||
Raises:
|
||||
ReqAbortException: 等待期间收到外部中断信号时抛出。
|
||||
"""
|
||||
from src.llm_models.exceptions import ReqAbortException
|
||||
|
||||
while not task.done():
|
||||
if interrupt_flag and interrupt_flag.is_set():
|
||||
task.cancel()
|
||||
raise ReqAbortException("请求被外部信号中断")
|
||||
await asyncio.sleep(interval_seconds)
|
||||
return await task
|
||||
|
||||
|
||||
class AdapterClient(BaseClient, ABC, Generic[RawStreamT, RawResponseT]):
|
||||
"""提供统一请求执行骨架的 Provider 适配基类。"""
|
||||
|
||||
async def get_response(self, request: ResponseRequest) -> APIResponse:
|
||||
"""获取对话响应。
|
||||
|
||||
Args:
|
||||
request: 统一响应请求对象。
|
||||
|
||||
Returns:
|
||||
APIResponse: 解析完成的统一响应对象。
|
||||
"""
|
||||
stream_response_handler = self._resolve_stream_response_handler(request)
|
||||
response_parser = self._resolve_response_parser(request)
|
||||
response, usage_record = await self._execute_response_request(
|
||||
request,
|
||||
stream_response_handler,
|
||||
response_parser,
|
||||
)
|
||||
return self._attach_usage_record(response, request.model_info, usage_record)
|
||||
|
||||
async def get_embedding(self, request: EmbeddingRequest) -> APIResponse:
|
||||
"""获取文本嵌入。
|
||||
|
||||
Args:
|
||||
request: 统一嵌入请求对象。
|
||||
|
||||
Returns:
|
||||
APIResponse: 解析完成的统一嵌入响应。
|
||||
"""
|
||||
response, usage_record = await self._execute_embedding_request(request)
|
||||
return self._attach_usage_record(response, request.model_info, usage_record)
|
||||
|
||||
async def get_audio_transcriptions(self, request: AudioTranscriptionRequest) -> APIResponse:
|
||||
"""获取音频转录。
|
||||
|
||||
Args:
|
||||
request: 统一音频转录请求对象。
|
||||
|
||||
Returns:
|
||||
APIResponse: 解析完成的统一音频转录响应。
|
||||
"""
|
||||
response, usage_record = await self._execute_audio_transcription_request(request)
|
||||
return self._attach_usage_record(response, request.model_info, usage_record)
|
||||
|
||||
def _resolve_stream_response_handler(
|
||||
self,
|
||||
request: ResponseRequest,
|
||||
) -> ProviderStreamResponseHandler[RawStreamT]:
|
||||
"""解析实际使用的流式响应处理器。
|
||||
|
||||
Args:
|
||||
request: 统一响应请求对象。
|
||||
|
||||
Returns:
|
||||
ProviderStreamResponseHandler[RawStreamT]: 流式响应处理器。
|
||||
"""
|
||||
if request.stream_response_handler is not None:
|
||||
return cast(ProviderStreamResponseHandler[RawStreamT], request.stream_response_handler)
|
||||
return self._build_default_stream_response_handler(request)
|
||||
|
||||
def _resolve_response_parser(
|
||||
self,
|
||||
request: ResponseRequest,
|
||||
) -> ProviderResponseParser[RawResponseT]:
|
||||
"""解析实际使用的非流式响应解析器。
|
||||
|
||||
Args:
|
||||
request: 统一响应请求对象。
|
||||
|
||||
Returns:
|
||||
ProviderResponseParser[RawResponseT]: 非流式响应解析器。
|
||||
"""
|
||||
if request.async_response_parser is not None:
|
||||
return cast(ProviderResponseParser[RawResponseT], request.async_response_parser)
|
||||
return self._build_default_response_parser(request)
|
||||
|
||||
@staticmethod
|
||||
def _build_usage_record(model_info: ModelInfo, usage_record: UsageTuple) -> UsageRecord:
|
||||
"""根据统一使用量三元组构建 `UsageRecord`。
|
||||
|
||||
Args:
|
||||
model_info: 模型信息。
|
||||
usage_record: 使用量三元组。
|
||||
|
||||
Returns:
|
||||
UsageRecord: 可直接挂载到 `APIResponse` 的使用记录对象。
|
||||
"""
|
||||
return UsageRecord(
|
||||
model_name=model_info.name,
|
||||
provider_name=model_info.api_provider,
|
||||
prompt_tokens=usage_record[0],
|
||||
completion_tokens=usage_record[1],
|
||||
total_tokens=usage_record[2],
|
||||
)
|
||||
|
||||
def _attach_usage_record(
|
||||
self,
|
||||
response: APIResponse,
|
||||
model_info: ModelInfo,
|
||||
usage_record: UsageTuple | None,
|
||||
) -> APIResponse:
|
||||
"""在响应对象上附加统一使用量信息。
|
||||
|
||||
Args:
|
||||
response: 已解析的统一响应对象。
|
||||
model_info: 模型信息。
|
||||
usage_record: 可选的使用量三元组。
|
||||
|
||||
Returns:
|
||||
APIResponse: 附加使用量后的响应对象。
|
||||
"""
|
||||
if usage_record is not None:
|
||||
response.usage = self._build_usage_record(model_info, usage_record)
|
||||
return response
|
||||
|
||||
@abstractmethod
|
||||
def _build_default_stream_response_handler(
|
||||
self,
|
||||
request: ResponseRequest,
|
||||
) -> ProviderStreamResponseHandler[RawStreamT]:
|
||||
"""构建默认流式响应处理器。
|
||||
|
||||
Args:
|
||||
request: 统一响应请求对象。
|
||||
|
||||
Returns:
|
||||
ProviderStreamResponseHandler[RawStreamT]: 默认流式处理器。
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def _build_default_response_parser(
|
||||
self,
|
||||
request: ResponseRequest,
|
||||
) -> ProviderResponseParser[RawResponseT]:
|
||||
"""构建默认非流式响应解析器。
|
||||
|
||||
Args:
|
||||
request: 统一响应请求对象。
|
||||
|
||||
Returns:
|
||||
ProviderResponseParser[RawResponseT]: 默认非流式解析器。
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def _execute_response_request(
|
||||
self,
|
||||
request: ResponseRequest,
|
||||
stream_response_handler: ProviderStreamResponseHandler[RawStreamT],
|
||||
response_parser: ProviderResponseParser[RawResponseT],
|
||||
) -> Tuple[APIResponse, UsageTuple | None]:
|
||||
"""执行 Provider 的文本/多模态响应请求。
|
||||
|
||||
Args:
|
||||
request: 统一响应请求对象。
|
||||
stream_response_handler: 流式响应处理器。
|
||||
response_parser: 非流式响应解析器。
|
||||
|
||||
Returns:
|
||||
Tuple[APIResponse, UsageTuple | None]: 统一响应对象与可选使用量信息。
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def _execute_embedding_request(
|
||||
self,
|
||||
request: EmbeddingRequest,
|
||||
) -> Tuple[APIResponse, UsageTuple | None]:
|
||||
"""执行 Provider 的嵌入请求。
|
||||
|
||||
Args:
|
||||
request: 统一嵌入请求对象。
|
||||
|
||||
Returns:
|
||||
Tuple[APIResponse, UsageTuple | None]: 统一响应对象与可选使用量信息。
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def _execute_audio_transcription_request(
|
||||
self,
|
||||
request: AudioTranscriptionRequest,
|
||||
) -> Tuple[APIResponse, UsageTuple | None]:
|
||||
"""执行 Provider 的音频转录请求。
|
||||
|
||||
Args:
|
||||
request: 统一音频转录请求对象。
|
||||
|
||||
Returns:
|
||||
Tuple[APIResponse, UsageTuple | None]: 统一响应对象与可选使用量信息。
|
||||
"""
|
||||
raise NotImplementedError
|
||||
@@ -1,14 +1,15 @@
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Callable, Any, Optional
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable, Coroutine, Dict, List, Tuple, Type
|
||||
|
||||
import asyncio
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import config_manager
|
||||
from src.config.model_configs import ModelInfo, APIProvider
|
||||
from ..payload_content.message import Message
|
||||
from ..payload_content.resp_format import RespFormat
|
||||
from ..payload_content.tool_option import ToolOption, ToolCall
|
||||
from src.config.model_configs import APIProvider, ModelInfo
|
||||
from src.llm_models.payload_content.message import Message
|
||||
from src.llm_models.payload_content.resp_format import RespFormat
|
||||
from src.llm_models.payload_content.tool_option import ToolCall, ToolOption
|
||||
|
||||
logger = get_logger("model_client_registry")
|
||||
|
||||
@@ -47,10 +48,10 @@ class APIResponse:
|
||||
reasoning_content: str | None = None
|
||||
"""推理内容"""
|
||||
|
||||
tool_calls: list[ToolCall] | None = None
|
||||
tool_calls: List[ToolCall] | None = None
|
||||
"""工具调用 [(工具名称, 工具参数), ...]"""
|
||||
|
||||
embedding: list[float] | None = None
|
||||
embedding: List[float] | None = None
|
||||
"""嵌入向量"""
|
||||
|
||||
usage: UsageRecord | None = None
|
||||
@@ -60,6 +61,82 @@ class APIResponse:
|
||||
"""响应原始数据"""
|
||||
|
||||
|
||||
UsageTuple = Tuple[int, int, int]
|
||||
"""统一的使用量三元组类型,顺序为 `(prompt_tokens, completion_tokens, total_tokens)`。"""
|
||||
|
||||
StreamResponseHandler = Callable[
|
||||
[Any, asyncio.Event | None],
|
||||
Coroutine[Any, Any, Tuple["APIResponse", UsageTuple | None]],
|
||||
]
|
||||
"""统一的流式响应处理函数类型。"""
|
||||
|
||||
ResponseParser = Callable[[Any], Tuple["APIResponse", UsageTuple | None]]
|
||||
"""统一的非流式响应解析函数类型。"""
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ResponseRequest:
|
||||
"""统一的文本/多模态响应请求。"""
|
||||
|
||||
model_info: ModelInfo
|
||||
message_list: List[Message]
|
||||
tool_options: List[ToolOption] | None = None
|
||||
max_tokens: int | None = None
|
||||
temperature: float | None = None
|
||||
response_format: RespFormat | None = None
|
||||
stream_response_handler: StreamResponseHandler | None = None
|
||||
async_response_parser: ResponseParser | None = None
|
||||
interrupt_flag: asyncio.Event | None = None
|
||||
extra_params: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def copy_with(self, **changes: Any) -> "ResponseRequest":
|
||||
"""基于当前请求创建一个带局部变更的新请求。
|
||||
|
||||
Args:
|
||||
**changes: 需要覆盖的字段值。
|
||||
|
||||
Returns:
|
||||
ResponseRequest: 复制后的请求对象。
|
||||
"""
|
||||
payload = {
|
||||
"model_info": self.model_info,
|
||||
"message_list": list(self.message_list),
|
||||
"tool_options": None if self.tool_options is None else list(self.tool_options),
|
||||
"max_tokens": self.max_tokens,
|
||||
"temperature": self.temperature,
|
||||
"response_format": self.response_format,
|
||||
"stream_response_handler": self.stream_response_handler,
|
||||
"async_response_parser": self.async_response_parser,
|
||||
"interrupt_flag": self.interrupt_flag,
|
||||
"extra_params": dict(self.extra_params),
|
||||
}
|
||||
payload.update(changes)
|
||||
return ResponseRequest(**payload)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class EmbeddingRequest:
|
||||
"""统一的嵌入请求。"""
|
||||
|
||||
model_info: ModelInfo
|
||||
embedding_input: str
|
||||
extra_params: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class AudioTranscriptionRequest:
|
||||
"""统一的音频转录请求。"""
|
||||
|
||||
model_info: ModelInfo
|
||||
audio_base64: str
|
||||
max_tokens: int | None = None
|
||||
extra_params: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
ClientRequest = ResponseRequest | EmbeddingRequest | AudioTranscriptionRequest
|
||||
"""统一客户端请求类型。"""
|
||||
|
||||
|
||||
class BaseClient(ABC):
|
||||
"""
|
||||
基础客户端
|
||||
@@ -67,97 +144,82 @@ class BaseClient(ABC):
|
||||
|
||||
api_provider: APIProvider
|
||||
|
||||
def __init__(self, api_provider: APIProvider):
|
||||
def __init__(self, api_provider: APIProvider) -> None:
|
||||
"""初始化基础客户端。
|
||||
|
||||
Args:
|
||||
api_provider: API 提供商配置。
|
||||
"""
|
||||
self.api_provider = api_provider
|
||||
|
||||
@abstractmethod
|
||||
async def get_response(
|
||||
self,
|
||||
model_info: ModelInfo,
|
||||
message_list: list[Message],
|
||||
tool_options: list[ToolOption] | None = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
temperature: Optional[float] = None,
|
||||
response_format: RespFormat | None = None,
|
||||
stream_response_handler: Optional[
|
||||
Callable[[Any, asyncio.Event | None], tuple[APIResponse, tuple[int, int, int]]]
|
||||
] = None,
|
||||
async_response_parser: Callable[[Any], tuple[APIResponse, tuple[int, int, int]]] | None = None,
|
||||
interrupt_flag: asyncio.Event | None = None,
|
||||
extra_params: dict[str, Any] | None = None,
|
||||
) -> APIResponse:
|
||||
"""
|
||||
获取对话响应
|
||||
:param model_info: 模型信息
|
||||
:param message_list: 对话体
|
||||
:param tool_options: 工具选项(可选,默认为None)
|
||||
:param max_tokens: 最大token数(可选,默认为1024)
|
||||
:param temperature: 温度(可选,默认为0.7)
|
||||
:param response_format: 响应格式(可选,默认为 NotGiven )
|
||||
:param stream_response_handler: 流式响应处理函数(可选)
|
||||
:param async_response_parser: 响应解析函数(可选)
|
||||
:param interrupt_flag: 中断信号量(可选,默认为None)
|
||||
:return: (响应文本, 推理文本, 工具调用, 其他数据)
|
||||
async def get_response(self, request: ResponseRequest) -> APIResponse:
|
||||
"""获取对话响应。
|
||||
|
||||
Args:
|
||||
request: 统一响应请求对象。
|
||||
|
||||
Returns:
|
||||
APIResponse: 统一响应对象。
|
||||
"""
|
||||
raise NotImplementedError("'get_response' method should be overridden in subclasses")
|
||||
|
||||
@abstractmethod
|
||||
async def get_embedding(
|
||||
self,
|
||||
model_info: ModelInfo,
|
||||
embedding_input: str,
|
||||
extra_params: dict[str, Any] | None = None,
|
||||
) -> APIResponse:
|
||||
"""
|
||||
获取文本嵌入
|
||||
:param model_info: 模型信息
|
||||
:param embedding_input: 嵌入输入文本
|
||||
:return: 嵌入响应
|
||||
async def get_embedding(self, request: EmbeddingRequest) -> APIResponse:
|
||||
"""获取文本嵌入。
|
||||
|
||||
Args:
|
||||
request: 统一嵌入请求对象。
|
||||
|
||||
Returns:
|
||||
APIResponse: 嵌入响应。
|
||||
"""
|
||||
raise NotImplementedError("'get_embedding' method should be overridden in subclasses")
|
||||
|
||||
@abstractmethod
|
||||
async def get_audio_transcriptions(
|
||||
self,
|
||||
model_info: ModelInfo,
|
||||
audio_base64: str,
|
||||
max_tokens: Optional[int] = None,
|
||||
extra_params: dict[str, Any] | None = None,
|
||||
) -> APIResponse:
|
||||
"""
|
||||
获取音频转录
|
||||
:param model_info: 模型信息
|
||||
:param audio_base64: base64编码的音频数据
|
||||
:extra_params: 附加的请求参数
|
||||
:return: 音频转录响应
|
||||
async def get_audio_transcriptions(self, request: AudioTranscriptionRequest) -> APIResponse:
|
||||
"""获取音频转录。
|
||||
|
||||
Args:
|
||||
request: 统一音频转录请求对象。
|
||||
|
||||
Returns:
|
||||
APIResponse: 音频转录响应。
|
||||
"""
|
||||
raise NotImplementedError("'get_audio_transcriptions' method should be overridden in subclasses")
|
||||
|
||||
@abstractmethod
|
||||
def get_support_image_formats(self) -> list[str]:
|
||||
"""
|
||||
获取支持的图片格式
|
||||
:return: 支持的图片格式列表
|
||||
def get_support_image_formats(self) -> List[str]:
|
||||
"""获取支持的图片格式。
|
||||
|
||||
Returns:
|
||||
List[str]: 支持的图片格式列表。
|
||||
"""
|
||||
raise NotImplementedError("'get_support_image_formats' method should be overridden in subclasses")
|
||||
|
||||
|
||||
class ClientRegistry:
|
||||
"""客户端注册表。"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.client_registry: dict[str, type[BaseClient]] = {}
|
||||
"""初始化注册表并绑定配置重载回调。"""
|
||||
self.client_registry: Dict[str, Type[BaseClient]] = {}
|
||||
"""APIProvider.type -> BaseClient的映射表"""
|
||||
self.client_instance_cache: dict[str, BaseClient] = {}
|
||||
self.client_instance_cache: Dict[str, BaseClient] = {}
|
||||
"""APIProvider.name -> BaseClient的映射表"""
|
||||
config_manager.register_reload_callback(self.clear_client_instance_cache)
|
||||
|
||||
def register_client_class(self, client_type: str):
|
||||
"""
|
||||
注册API客户端类
|
||||
def register_client_class(self, client_type: str) -> Callable[[Type[BaseClient]], Type[BaseClient]]:
|
||||
"""注册 API 客户端类。
|
||||
|
||||
Args:
|
||||
client_class: API客户端类
|
||||
client_type: 客户端类型标识。
|
||||
|
||||
Returns:
|
||||
Callable[[Type[BaseClient]], Type[BaseClient]]: 装饰器函数。
|
||||
"""
|
||||
|
||||
def decorator(cls: type[BaseClient]) -> type[BaseClient]:
|
||||
def decorator(cls: Type[BaseClient]) -> Type[BaseClient]:
|
||||
if not issubclass(cls, BaseClient):
|
||||
raise TypeError(f"{cls.__name__} is not a subclass of BaseClient")
|
||||
self.client_registry[client_type] = cls
|
||||
@@ -165,14 +227,15 @@ class ClientRegistry:
|
||||
|
||||
return decorator
|
||||
|
||||
def get_client_class_instance(self, api_provider: APIProvider, force_new=False) -> BaseClient:
|
||||
"""
|
||||
获取注册的API客户端实例
|
||||
def get_client_class_instance(self, api_provider: APIProvider, force_new: bool = False) -> BaseClient:
|
||||
"""获取注册的 API 客户端实例。
|
||||
|
||||
Args:
|
||||
api_provider: APIProvider实例
|
||||
force_new: 是否强制创建新实例(用于解决事件循环问题)
|
||||
api_provider: APIProvider 实例。
|
||||
force_new: 是否强制创建新实例。
|
||||
|
||||
Returns:
|
||||
BaseClient: 注册的API客户端实例
|
||||
BaseClient: 注册的 API 客户端实例。
|
||||
"""
|
||||
from . import ensure_client_type_loaded
|
||||
|
||||
@@ -194,6 +257,7 @@ class ClientRegistry:
|
||||
return self.client_instance_cache[api_provider.name]
|
||||
|
||||
def clear_client_instance_cache(self) -> None:
|
||||
"""清空客户端实例缓存。"""
|
||||
self.client_instance_cache.clear()
|
||||
logger.info("检测到配置重载,已清空LLM客户端实例缓存")
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
140
src/llm_models/openai_compat.py
Normal file
140
src/llm_models/openai_compat.py
Normal file
@@ -0,0 +1,140 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Mapping
|
||||
|
||||
from src.config.model_configs import APIProvider, OpenAICompatibleAuthType
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class OpenAICompatibleClientConfig:
|
||||
"""OpenAI 兼容客户端的基础配置。"""
|
||||
|
||||
api_key: str
|
||||
base_url: str
|
||||
default_headers: dict[str, str] = field(default_factory=dict)
|
||||
default_query: dict[str, object] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class OpenAICompatibleRequestOverrides:
|
||||
"""单次请求级别的附加配置。"""
|
||||
|
||||
extra_headers: dict[str, str] = field(default_factory=dict)
|
||||
extra_query: dict[str, object] = field(default_factory=dict)
|
||||
extra_body: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
def normalize_openai_base_url(base_url: str) -> str:
|
||||
"""规范化 OpenAI 兼容接口的基础地址。
|
||||
|
||||
Args:
|
||||
base_url: 原始基础地址。
|
||||
|
||||
Returns:
|
||||
str: 去掉尾部斜杠后的地址。
|
||||
"""
|
||||
return base_url.rstrip("/")
|
||||
|
||||
|
||||
def _build_auth_header_value(prefix: str, api_key: str) -> str:
|
||||
"""构造鉴权请求头的值。
|
||||
|
||||
Args:
|
||||
prefix: 请求头前缀。
|
||||
api_key: 实际密钥。
|
||||
|
||||
Returns:
|
||||
str: 拼接完成的请求头值。
|
||||
"""
|
||||
normalized_prefix = prefix.strip()
|
||||
if not normalized_prefix:
|
||||
return api_key
|
||||
return f"{normalized_prefix} {api_key}"
|
||||
|
||||
|
||||
def build_openai_compatible_client_config(api_provider: APIProvider) -> OpenAICompatibleClientConfig:
|
||||
"""构建 OpenAI 兼容客户端配置。
|
||||
|
||||
Args:
|
||||
api_provider: API 提供商配置。
|
||||
|
||||
Returns:
|
||||
OpenAICompatibleClientConfig: 可直接用于初始化 SDK 客户端的配置。
|
||||
"""
|
||||
default_headers = dict(api_provider.default_headers)
|
||||
default_query: dict[str, object] = dict(api_provider.default_query)
|
||||
client_api_key = api_provider.api_key
|
||||
|
||||
if api_provider.auth_type == OpenAICompatibleAuthType.BEARER:
|
||||
if (
|
||||
api_provider.auth_header_name != "Authorization"
|
||||
or api_provider.auth_header_prefix.strip() != "Bearer"
|
||||
):
|
||||
client_api_key = ""
|
||||
default_headers[api_provider.auth_header_name] = _build_auth_header_value(
|
||||
prefix=api_provider.auth_header_prefix,
|
||||
api_key=api_provider.api_key,
|
||||
)
|
||||
elif api_provider.auth_type == OpenAICompatibleAuthType.HEADER:
|
||||
client_api_key = ""
|
||||
default_headers[api_provider.auth_header_name] = _build_auth_header_value(
|
||||
prefix=api_provider.auth_header_prefix,
|
||||
api_key=api_provider.api_key,
|
||||
)
|
||||
elif api_provider.auth_type == OpenAICompatibleAuthType.QUERY:
|
||||
client_api_key = ""
|
||||
default_query[api_provider.auth_query_name] = api_provider.api_key
|
||||
elif api_provider.auth_type == OpenAICompatibleAuthType.NONE:
|
||||
client_api_key = ""
|
||||
|
||||
return OpenAICompatibleClientConfig(
|
||||
api_key=client_api_key,
|
||||
base_url=normalize_openai_base_url(api_provider.base_url),
|
||||
default_headers=default_headers,
|
||||
default_query=default_query,
|
||||
)
|
||||
|
||||
|
||||
def _extract_mapping(value: Any) -> dict[str, Any]:
|
||||
"""将任意映射值规范化为普通字典。
|
||||
|
||||
Args:
|
||||
value: 原始输入值。
|
||||
|
||||
Returns:
|
||||
dict[str, Any]: 规范化后的字典。非映射值时返回空字典。
|
||||
"""
|
||||
if isinstance(value, Mapping):
|
||||
return {str(key): item for key, item in value.items()}
|
||||
return {}
|
||||
|
||||
|
||||
def split_openai_request_overrides(
|
||||
extra_params: Mapping[str, Any] | None,
|
||||
*,
|
||||
reserved_body_keys: set[str] | None = None,
|
||||
) -> OpenAICompatibleRequestOverrides:
|
||||
"""拆分单次请求中的头、查询参数和请求体扩展字段。
|
||||
|
||||
Args:
|
||||
extra_params: 模型级别或请求级别的附加参数。
|
||||
reserved_body_keys: 由 SDK 原生参数承载、因此不应再进入 `extra_body` 的字段集合。
|
||||
|
||||
Returns:
|
||||
OpenAICompatibleRequestOverrides: 拆分后的请求覆盖配置。
|
||||
"""
|
||||
raw_params = dict(extra_params or {})
|
||||
extra_headers = _extract_mapping(raw_params.pop("headers", None))
|
||||
extra_query = _extract_mapping(raw_params.pop("query", None))
|
||||
extra_body = _extract_mapping(raw_params.pop("body", None))
|
||||
blocked_body_keys = reserved_body_keys or set()
|
||||
|
||||
for key, value in raw_params.items():
|
||||
if key in blocked_body_keys:
|
||||
continue
|
||||
extra_body[key] = value
|
||||
|
||||
return OpenAICompatibleRequestOverrides(
|
||||
extra_headers={key: str(value) for key, value in extra_headers.items()},
|
||||
extra_query=extra_query,
|
||||
extra_body=extra_body,
|
||||
)
|
||||
@@ -1,133 +1,280 @@
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import List, Optional
|
||||
from typing import List, Tuple
|
||||
|
||||
from .tool_option import ToolCall
|
||||
|
||||
|
||||
# 设计这系列类的目的是为未来可能的扩展做准备
|
||||
class RoleType(str, Enum):
|
||||
"""消息角色类型。"""
|
||||
|
||||
|
||||
class RoleType(Enum):
|
||||
System = "system"
|
||||
User = "user"
|
||||
Assistant = "assistant"
|
||||
Tool = "tool"
|
||||
|
||||
|
||||
SUPPORTED_IMAGE_FORMATS = ["jpg", "jpeg", "png", "webp", "gif"] # openai支持的图片格式
|
||||
SUPPORTED_IMAGE_FORMATS = ["jpg", "jpeg", "png", "webp", "gif"]
|
||||
"""默认支持的图片格式列表。"""
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class TextMessagePart:
|
||||
"""文本消息片段。"""
|
||||
|
||||
text: str
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""执行文本片段的基础校验。
|
||||
|
||||
Raises:
|
||||
ValueError: 当文本为空时抛出。
|
||||
"""
|
||||
if self.text == "":
|
||||
raise ValueError("文本消息片段不能为空字符串")
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ImageMessagePart:
|
||||
"""Base64 图片消息片段。"""
|
||||
|
||||
image_format: str
|
||||
image_base64: str
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""执行图片片段的基础校验。
|
||||
|
||||
Raises:
|
||||
ValueError: 当图片格式或 Base64 数据无效时抛出。
|
||||
"""
|
||||
if self.image_format.lower() not in SUPPORTED_IMAGE_FORMATS:
|
||||
raise ValueError("不受支持的图片格式")
|
||||
if not self.image_base64:
|
||||
raise ValueError("图片的 base64 编码不能为空")
|
||||
|
||||
@property
|
||||
def normalized_image_format(self) -> str:
|
||||
"""获取规范化后的图片格式。
|
||||
|
||||
Returns:
|
||||
str: 规范化后的图片格式。`jpg` 会被统一为 `jpeg`。
|
||||
"""
|
||||
image_format = self.image_format.lower()
|
||||
if image_format in {"jpg", "jpeg"}:
|
||||
return "jpeg"
|
||||
return image_format
|
||||
|
||||
|
||||
MessagePart = TextMessagePart | ImageMessagePart
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class Message:
|
||||
def __init__(
|
||||
self,
|
||||
role: RoleType,
|
||||
content: str | list[tuple[str, str] | str],
|
||||
tool_call_id: str | None = None,
|
||||
tool_calls: Optional[List[ToolCall]] = None,
|
||||
):
|
||||
"""统一消息模型。"""
|
||||
|
||||
role: RoleType
|
||||
parts: List[MessagePart] = field(default_factory=list)
|
||||
tool_call_id: str | None = None
|
||||
tool_calls: List[ToolCall] | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""执行消息对象的基础校验。
|
||||
|
||||
Raises:
|
||||
ValueError: 当消息内容或工具调用信息不完整时抛出。
|
||||
"""
|
||||
初始化消息对象
|
||||
(不应直接修改Message类,而应使用MessageBuilder类来构建对象)
|
||||
if not self.parts and not (self.role == RoleType.Assistant and self.tool_calls):
|
||||
raise ValueError("消息内容不能为空")
|
||||
if self.role == RoleType.Tool and not self.tool_call_id:
|
||||
raise ValueError("Tool 角色的工具调用 ID 不能为空")
|
||||
|
||||
@property
|
||||
def content(self) -> str | List[Tuple[str, str] | str]:
|
||||
"""获取兼容旧逻辑的内容视图。
|
||||
|
||||
Returns:
|
||||
str | List[Tuple[str, str] | str]: 当仅包含一个文本片段时返回字符串,
|
||||
否则返回混合列表,其中图片片段表示为 `(format, base64)` 元组。
|
||||
"""
|
||||
self.role: RoleType = role
|
||||
self.content: str | list[tuple[str, str] | str] = content
|
||||
self.tool_call_id: str | None = tool_call_id
|
||||
self.tool_calls: Optional[List[ToolCall]] = tool_calls
|
||||
if len(self.parts) == 1 and isinstance(self.parts[0], TextMessagePart):
|
||||
return self.parts[0].text
|
||||
content_items: List[Tuple[str, str] | str] = []
|
||||
for part in self.parts:
|
||||
if isinstance(part, TextMessagePart):
|
||||
content_items.append(part.text)
|
||||
else:
|
||||
content_items.append((part.image_format, part.image_base64))
|
||||
return content_items
|
||||
|
||||
def get_text_content(self) -> str:
|
||||
"""提取消息中的所有文本片段。
|
||||
|
||||
Returns:
|
||||
str: 以原始顺序拼接后的文本内容。
|
||||
"""
|
||||
return "".join(part.text for part in self.parts if isinstance(part, TextMessagePart))
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""生成便于调试的字符串表示。
|
||||
|
||||
Returns:
|
||||
str: 当前消息对象的可读摘要。
|
||||
"""
|
||||
return (
|
||||
f"Role: {self.role}, Content: {self.content}, "
|
||||
f"Role: {self.role}, Parts: {self.parts}, "
|
||||
f"Tool Call ID: {self.tool_call_id}, Tool Calls: {self.tool_calls}"
|
||||
)
|
||||
|
||||
|
||||
class MessageBuilder:
|
||||
def __init__(self):
|
||||
"""消息构建器。"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""初始化构建器。"""
|
||||
self.__role: RoleType = RoleType.User
|
||||
self.__content: list[tuple[str, str] | str] = []
|
||||
self.__parts: List[MessagePart] = []
|
||||
self.__tool_call_id: str | None = None
|
||||
self.__tool_calls: Optional[List[ToolCall]] = None
|
||||
self.__tool_calls: List[ToolCall] | None = None
|
||||
|
||||
def set_role(self, role: RoleType = RoleType.User) -> "MessageBuilder":
|
||||
"""
|
||||
设置角色(默认为User)
|
||||
:param role: 角色
|
||||
:return: MessageBuilder对象
|
||||
"""设置消息角色。
|
||||
|
||||
Args:
|
||||
role: 目标角色,默认为 `user`。
|
||||
|
||||
Returns:
|
||||
MessageBuilder: 当前构建器实例。
|
||||
"""
|
||||
self.__role = role
|
||||
return self
|
||||
|
||||
def add_text_part(self, text: str) -> "MessageBuilder":
|
||||
"""追加文本片段。
|
||||
|
||||
Args:
|
||||
text: 文本内容。
|
||||
|
||||
Returns:
|
||||
MessageBuilder: 当前构建器实例。
|
||||
"""
|
||||
self.__parts.append(TextMessagePart(text=text))
|
||||
return self
|
||||
|
||||
def add_text_content(self, text: str) -> "MessageBuilder":
|
||||
"""追加文本片段。
|
||||
|
||||
Args:
|
||||
text: 文本内容。
|
||||
|
||||
Returns:
|
||||
MessageBuilder: 当前构建器实例。
|
||||
"""
|
||||
添加文本内容
|
||||
:param text: 文本内容
|
||||
:return: MessageBuilder对象
|
||||
return self.add_text_part(text)
|
||||
|
||||
def add_image_base64_part(
|
||||
self,
|
||||
image_format: str,
|
||||
image_base64: str,
|
||||
support_formats: List[str] = SUPPORTED_IMAGE_FORMATS,
|
||||
) -> "MessageBuilder":
|
||||
"""追加 Base64 图片片段。
|
||||
|
||||
Args:
|
||||
image_format: 图片格式。
|
||||
image_base64: 图片的 Base64 编码。
|
||||
support_formats: 允许的图片格式列表。
|
||||
|
||||
Returns:
|
||||
MessageBuilder: 当前构建器实例。
|
||||
|
||||
Raises:
|
||||
ValueError: 当图片格式不被支持时抛出。
|
||||
"""
|
||||
self.__content.append(text)
|
||||
if image_format.lower() not in support_formats:
|
||||
raise ValueError("不受支持的图片格式")
|
||||
self.__parts.append(ImageMessagePart(image_format=image_format, image_base64=image_base64))
|
||||
return self
|
||||
|
||||
def add_image_content(
|
||||
self,
|
||||
image_format: str,
|
||||
image_base64: str,
|
||||
support_formats: list[str] = SUPPORTED_IMAGE_FORMATS, # 默认支持格式
|
||||
support_formats: List[str] = SUPPORTED_IMAGE_FORMATS,
|
||||
) -> "MessageBuilder":
|
||||
"""
|
||||
添加图片内容
|
||||
:param image_format: 图片格式
|
||||
:param image_base64: 图片的base64编码
|
||||
:return: MessageBuilder对象
|
||||
"""
|
||||
if image_format.lower() not in support_formats:
|
||||
raise ValueError("不受支持的图片格式")
|
||||
if not image_base64:
|
||||
raise ValueError("图片的base64编码不能为空")
|
||||
self.__content.append((image_format, image_base64))
|
||||
return self
|
||||
"""追加 Base64 图片片段。
|
||||
|
||||
def add_tool_call(self, tool_call_id: str) -> "MessageBuilder":
|
||||
Args:
|
||||
image_format: 图片格式。
|
||||
image_base64: 图片的 Base64 编码。
|
||||
support_formats: 允许的图片格式列表。
|
||||
|
||||
Returns:
|
||||
MessageBuilder: 当前构建器实例。
|
||||
"""
|
||||
添加工具调用指令(调用时请确保已设置为Tool角色)
|
||||
:param tool_call_id: 工具调用指令的id
|
||||
:return: MessageBuilder对象
|
||||
return self.add_image_base64_part(
|
||||
image_format=image_format,
|
||||
image_base64=image_base64,
|
||||
support_formats=support_formats,
|
||||
)
|
||||
|
||||
def set_tool_call_id(self, tool_call_id: str) -> "MessageBuilder":
|
||||
"""设置工具结果消息引用的工具调用 ID。
|
||||
|
||||
Args:
|
||||
tool_call_id: 工具调用 ID。
|
||||
|
||||
Returns:
|
||||
MessageBuilder: 当前构建器实例。
|
||||
|
||||
Raises:
|
||||
ValueError: 当当前角色不是 `tool` 或 ID 为空时抛出。
|
||||
"""
|
||||
if self.__role != RoleType.Tool:
|
||||
raise ValueError("仅当角色为Tool时才能添加工具调用ID")
|
||||
raise ValueError("仅当角色为 Tool 时才能设置工具调用 ID")
|
||||
if not tool_call_id:
|
||||
raise ValueError("工具调用ID不能为空")
|
||||
raise ValueError("工具调用 ID 不能为空")
|
||||
self.__tool_call_id = tool_call_id
|
||||
return self
|
||||
|
||||
def set_tool_calls(self, tool_calls: List[ToolCall]) -> "MessageBuilder":
|
||||
def add_tool_call(self, tool_call_id: str) -> "MessageBuilder":
|
||||
"""设置工具结果消息引用的工具调用 ID。
|
||||
|
||||
Args:
|
||||
tool_call_id: 工具调用 ID。
|
||||
|
||||
Returns:
|
||||
MessageBuilder: 当前构建器实例。
|
||||
"""
|
||||
设置助手消息的工具调用列表
|
||||
:param tool_calls: 工具调用列表
|
||||
:return: MessageBuilder对象
|
||||
return self.set_tool_call_id(tool_call_id)
|
||||
|
||||
def set_tool_calls(self, tool_calls: List[ToolCall]) -> "MessageBuilder":
|
||||
"""设置助手消息中的工具调用列表。
|
||||
|
||||
Args:
|
||||
tool_calls: 工具调用列表。
|
||||
|
||||
Returns:
|
||||
MessageBuilder: 当前构建器实例。
|
||||
|
||||
Raises:
|
||||
ValueError: 当当前角色不是 `assistant` 或列表为空时抛出。
|
||||
"""
|
||||
if self.__role != RoleType.Assistant:
|
||||
raise ValueError("仅当角色为Assistant时才能设置工具调用列表")
|
||||
raise ValueError("仅当角色为 Assistant 时才能设置工具调用列表")
|
||||
if not tool_calls:
|
||||
raise ValueError("工具调用列表不能为空")
|
||||
self.__tool_calls = tool_calls
|
||||
self.__tool_calls = list(tool_calls)
|
||||
return self
|
||||
|
||||
def build(self) -> Message:
|
||||
"""
|
||||
构建消息对象
|
||||
:return: Message对象
|
||||
"""
|
||||
if len(self.__content) == 0 and not (self.__role == RoleType.Assistant and self.__tool_calls):
|
||||
raise ValueError("内容不能为空")
|
||||
if self.__role == RoleType.Tool and self.__tool_call_id is None:
|
||||
raise ValueError("Tool角色的工具调用ID不能为空")
|
||||
"""构建消息对象。
|
||||
|
||||
Returns:
|
||||
Message: 构建完成的消息对象。
|
||||
"""
|
||||
return Message(
|
||||
role=self.__role,
|
||||
content=(
|
||||
self.__content[0]
|
||||
if (len(self.__content) == 1 and isinstance(self.__content[0], str))
|
||||
else self.__content
|
||||
),
|
||||
parts=list(self.__parts),
|
||||
tool_call_id=self.__tool_call_id,
|
||||
tool_calls=self.__tool_calls,
|
||||
tool_calls=list(self.__tool_calls) if self.__tool_calls else None,
|
||||
)
|
||||
|
||||
@@ -1,51 +1,40 @@
|
||||
from copy import deepcopy
|
||||
from enum import Enum
|
||||
from typing import Optional, Any
|
||||
from typing import Any, Dict, List, Mapping, Optional, Type, cast
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import TypedDict, Required
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class RespFormatType(Enum):
|
||||
TEXT = "text" # 文本
|
||||
JSON_OBJ = "json_object" # JSON
|
||||
JSON_SCHEMA = "json_schema" # JSON Schema
|
||||
"""响应格式类型。"""
|
||||
|
||||
TEXT = "text"
|
||||
JSON_OBJ = "json_object"
|
||||
JSON_SCHEMA = "json_schema"
|
||||
|
||||
|
||||
class JsonSchema(TypedDict, total=False):
|
||||
"""内部使用的 JSON Schema 包装结构。"""
|
||||
|
||||
name: Required[str]
|
||||
"""
|
||||
The name of the response format.
|
||||
|
||||
Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length
|
||||
of 64.
|
||||
"""
|
||||
|
||||
description: Optional[str]
|
||||
"""
|
||||
A description of what the response format is for, used by the model to determine
|
||||
how to respond in the format.
|
||||
"""
|
||||
|
||||
schema: dict[str, object]
|
||||
"""
|
||||
The schema for the response format, described as a JSON Schema object. Learn how
|
||||
to build JSON schemas [here](https://json-schema.org/).
|
||||
"""
|
||||
|
||||
schema: Dict[str, Any]
|
||||
strict: Optional[bool]
|
||||
"""
|
||||
Whether to enable strict schema adherence when generating the output. If set to
|
||||
true, the model will always follow the exact schema defined in the `schema`
|
||||
field. Only a subset of JSON Schema is supported when `strict` is `true`. To
|
||||
learn more, read the
|
||||
[Structured Outputs guide](https://platform.openai.com/docs/guides/structured-outputs).
|
||||
"""
|
||||
|
||||
|
||||
def _json_schema_type_check(instance) -> str | None:
|
||||
def _json_schema_type_check(instance: Mapping[str, Any]) -> str | None:
|
||||
"""检查 JSON Schema 包装结构是否合法。
|
||||
|
||||
Args:
|
||||
instance: 待检查的 JSON Schema 包装字典。
|
||||
|
||||
Returns:
|
||||
str | None: 不合法时返回错误信息,合法时返回 `None`。
|
||||
"""
|
||||
if "name" not in instance:
|
||||
return "schema必须包含'name'字段"
|
||||
elif not isinstance(instance["name"], str) or instance["name"].strip() == "":
|
||||
if not isinstance(instance["name"], str) or instance["name"].strip() == "":
|
||||
return "schema的'name'字段必须是非空字符串"
|
||||
if "description" in instance and (
|
||||
not isinstance(instance["description"], str) or instance["description"].strip() == ""
|
||||
@@ -53,164 +42,198 @@ def _json_schema_type_check(instance) -> str | None:
|
||||
return "schema的'description'字段只能填入非空字符串"
|
||||
if "schema" not in instance:
|
||||
return "schema必须包含'schema'字段"
|
||||
elif not isinstance(instance["schema"], dict):
|
||||
if not isinstance(instance["schema"], dict):
|
||||
return "schema的'schema'字段必须是字典,详见https://json-schema.org/"
|
||||
if "strict" in instance and not isinstance(instance["strict"], bool):
|
||||
return "schema的'strict'字段只能填入布尔值"
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _remove_title(schema: dict[str, Any] | list[Any]) -> dict[str, Any] | list[Any]:
|
||||
"""
|
||||
递归移除JSON Schema中的title字段
|
||||
def _remove_title(schema: Dict[str, Any] | List[Any]) -> Dict[str, Any] | List[Any]:
|
||||
"""递归移除 JSON Schema 中的 `title` 字段。
|
||||
|
||||
Args:
|
||||
schema: 待处理的 Schema 树。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any] | List[Any]: 处理后的 Schema 树。
|
||||
"""
|
||||
if isinstance(schema, list):
|
||||
# 如果当前Schema是列表,则对所有dict/list子元素递归调用
|
||||
for idx, item in enumerate(schema):
|
||||
for index, item in enumerate(schema):
|
||||
if isinstance(item, (dict, list)):
|
||||
schema[idx] = _remove_title(item)
|
||||
elif isinstance(schema, dict):
|
||||
# 是字典,移除title字段,并对所有dict/list子元素递归调用
|
||||
if "title" in schema:
|
||||
del schema["title"]
|
||||
for key, value in schema.items():
|
||||
if isinstance(value, (dict, list)):
|
||||
schema[key] = _remove_title(value)
|
||||
schema[index] = _remove_title(item)
|
||||
return schema
|
||||
|
||||
if "title" in schema:
|
||||
del schema["title"]
|
||||
for key, value in schema.items():
|
||||
if isinstance(value, (dict, list)):
|
||||
schema[key] = _remove_title(value)
|
||||
return schema
|
||||
|
||||
|
||||
def _link_definitions(schema: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
链接JSON Schema中的definitions字段
|
||||
def _link_definitions(schema: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""展开 Schema 中的本地 `$defs`/`$ref` 引用。
|
||||
|
||||
Args:
|
||||
schema: 待处理的根 Schema。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 展开后的 Schema。
|
||||
"""
|
||||
|
||||
def link_definitions_recursive(
|
||||
path: str, sub_schema: list[Any] | dict[str, Any], defs: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
递归链接JSON Schema中的definitions字段
|
||||
:param path: 当前路径
|
||||
:param sub_schema: 子Schema
|
||||
:param defs: Schema定义集
|
||||
:return:
|
||||
path: str,
|
||||
sub_schema: Dict[str, Any] | List[Any],
|
||||
definitions: Dict[str, Any],
|
||||
) -> Dict[str, Any] | List[Any]:
|
||||
"""递归展开局部定义。
|
||||
|
||||
Args:
|
||||
path: 当前递归路径。
|
||||
sub_schema: 当前子 Schema。
|
||||
definitions: 已收集的定义字典。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any] | List[Any]: 展开后的子 Schema。
|
||||
"""
|
||||
if isinstance(sub_schema, list):
|
||||
# 如果当前Schema是列表,则遍历每个元素
|
||||
for i in range(len(sub_schema)):
|
||||
if isinstance(sub_schema[i], dict):
|
||||
sub_schema[i] = link_definitions_recursive(f"{path}/{str(i)}", sub_schema[i], defs)
|
||||
else:
|
||||
# 否则为字典
|
||||
if "$defs" in sub_schema:
|
||||
# 如果当前Schema有$def字段,则将其添加到defs中
|
||||
key_prefix = f"{path}/$defs/"
|
||||
for key, value in sub_schema["$defs"].items():
|
||||
def_key = key_prefix + key
|
||||
if def_key not in defs:
|
||||
defs[def_key] = value
|
||||
del sub_schema["$defs"]
|
||||
if "$ref" in sub_schema:
|
||||
# 如果当前Schema有$ref字段,则将其替换为defs中的定义
|
||||
def_key = sub_schema["$ref"]
|
||||
if def_key in defs:
|
||||
sub_schema = defs[def_key]
|
||||
else:
|
||||
raise ValueError(f"Schema中引用的定义'{def_key}'不存在")
|
||||
# 遍历键值对
|
||||
for key, value in sub_schema.items():
|
||||
if isinstance(value, (dict, list)):
|
||||
# 如果当前值是字典或列表,则递归调用
|
||||
sub_schema[key] = link_definitions_recursive(f"{path}/{key}", value, defs)
|
||||
for index, item in enumerate(sub_schema):
|
||||
if isinstance(item, (dict, list)):
|
||||
sub_schema[index] = link_definitions_recursive(f"{path}/{index}", item, definitions)
|
||||
return sub_schema
|
||||
|
||||
if "$defs" in sub_schema:
|
||||
key_prefix = f"{path}/$defs/"
|
||||
defs_payload = cast(Dict[str, Any], sub_schema["$defs"])
|
||||
for key, value in defs_payload.items():
|
||||
definition_key = key_prefix + key
|
||||
if definition_key not in definitions:
|
||||
definitions[definition_key] = value
|
||||
del sub_schema["$defs"]
|
||||
|
||||
if "$ref" in sub_schema:
|
||||
definition_key = cast(str, sub_schema["$ref"])
|
||||
if definition_key in definitions:
|
||||
return definitions[definition_key]
|
||||
raise ValueError(f"Schema中引用的定义'{definition_key}'不存在")
|
||||
|
||||
for key, value in sub_schema.items():
|
||||
if isinstance(value, (dict, list)):
|
||||
sub_schema[key] = link_definitions_recursive(f"{path}/{key}", value, definitions)
|
||||
return sub_schema
|
||||
|
||||
return link_definitions_recursive("#", schema, {})
|
||||
return cast(Dict[str, Any], link_definitions_recursive("#", schema, {}))
|
||||
|
||||
|
||||
def _remove_defs(schema: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
递归移除JSON Schema中的$defs字段
|
||||
def _remove_defs(schema: Dict[str, Any] | List[Any]) -> Dict[str, Any] | List[Any]:
|
||||
"""递归移除 JSON Schema 中的 `$defs` 字段。
|
||||
|
||||
Args:
|
||||
schema: 待处理的 Schema 树。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any] | List[Any]: 处理后的 Schema 树。
|
||||
"""
|
||||
if isinstance(schema, list):
|
||||
# 如果当前Schema是列表,则对所有dict/list子元素递归调用
|
||||
for idx, item in enumerate(schema):
|
||||
for index, item in enumerate(schema):
|
||||
if isinstance(item, (dict, list)):
|
||||
schema[idx] = _remove_title(item)
|
||||
elif isinstance(schema, dict):
|
||||
# 是字典,移除title字段,并对所有dict/list子元素递归调用
|
||||
if "$defs" in schema:
|
||||
del schema["$defs"]
|
||||
for key, value in schema.items():
|
||||
if isinstance(value, (dict, list)):
|
||||
schema[key] = _remove_title(value)
|
||||
schema[index] = _remove_defs(item)
|
||||
return schema
|
||||
|
||||
if "$defs" in schema:
|
||||
del schema["$defs"]
|
||||
for key, value in schema.items():
|
||||
if isinstance(value, (dict, list)):
|
||||
schema[key] = _remove_defs(value)
|
||||
return schema
|
||||
|
||||
|
||||
class RespFormat:
|
||||
"""
|
||||
响应格式
|
||||
"""
|
||||
"""统一响应格式定义。"""
|
||||
|
||||
@staticmethod
|
||||
def _generate_schema_from_model(schema):
|
||||
json_schema = {
|
||||
"name": schema.__name__,
|
||||
"schema": _remove_defs(_link_definitions(_remove_title(schema.model_json_schema()))),
|
||||
def _generate_schema_from_model(schema_model: Type[BaseModel]) -> JsonSchema:
|
||||
"""从 Pydantic 模型生成内部 JSON Schema 包装结构。
|
||||
|
||||
Args:
|
||||
schema_model: Pydantic 模型类。
|
||||
|
||||
Returns:
|
||||
JsonSchema: 内部统一 JSON Schema 包装结构。
|
||||
"""
|
||||
schema_tree = deepcopy(schema_model.model_json_schema())
|
||||
json_schema: JsonSchema = {
|
||||
"name": schema_model.__name__,
|
||||
"schema": cast(
|
||||
Dict[str, Any],
|
||||
_remove_defs(_link_definitions(cast(Dict[str, Any], _remove_title(schema_tree)))),
|
||||
),
|
||||
"strict": False,
|
||||
}
|
||||
if schema.__doc__:
|
||||
json_schema["description"] = schema.__doc__
|
||||
if schema_model.__doc__:
|
||||
json_schema["description"] = schema_model.__doc__
|
||||
return json_schema
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
format_type: RespFormatType = RespFormatType.TEXT,
|
||||
schema: type | JsonSchema | None = None,
|
||||
):
|
||||
"""
|
||||
响应格式
|
||||
:param format_type: 响应格式类型(默认为文本)
|
||||
:param schema: 模板类或JsonSchema(仅当format_type为JSON Schema时有效)
|
||||
schema: Type[BaseModel] | JsonSchema | None = None,
|
||||
) -> None:
|
||||
"""初始化响应格式对象。
|
||||
|
||||
Args:
|
||||
format_type: 响应格式类型。
|
||||
schema: 模型类或 JSON Schema 包装结构,仅 `JSON_SCHEMA` 模式使用。
|
||||
"""
|
||||
self.format_type: RespFormatType = format_type
|
||||
self.schema_source: Type[BaseModel] | JsonSchema | None = schema
|
||||
self.schema: JsonSchema | None = None
|
||||
|
||||
if format_type == RespFormatType.JSON_SCHEMA:
|
||||
if schema is None:
|
||||
raise ValueError("当format_type为'JSON_SCHEMA'时,schema不能为空")
|
||||
if isinstance(schema, dict):
|
||||
if check_msg := _json_schema_type_check(schema):
|
||||
raise ValueError(f"schema格式不正确,{check_msg}")
|
||||
if format_type != RespFormatType.JSON_SCHEMA:
|
||||
return
|
||||
if schema is None:
|
||||
raise ValueError("当format_type为'JSON_SCHEMA'时,schema不能为空")
|
||||
if isinstance(schema, dict):
|
||||
if check_msg := _json_schema_type_check(schema):
|
||||
raise ValueError(f"schema格式不正确,{check_msg}")
|
||||
self.schema = cast(JsonSchema, deepcopy(schema))
|
||||
return
|
||||
if isinstance(schema, type) and issubclass(schema, BaseModel):
|
||||
try:
|
||||
self.schema = self._generate_schema_from_model(schema)
|
||||
except Exception as exc:
|
||||
raise ValueError(
|
||||
f"自动生成JSON Schema时发生异常,请检查模型类{schema.__name__}的定义,详细信息:\n"
|
||||
f"{schema.__name__}:\n"
|
||||
) from exc
|
||||
return
|
||||
raise ValueError("schema必须是BaseModel的子类或JsonSchema")
|
||||
|
||||
self.schema = schema
|
||||
elif issubclass(schema, BaseModel):
|
||||
try:
|
||||
json_schema = self._generate_schema_from_model(schema)
|
||||
def get_schema_object(self) -> Dict[str, Any] | None:
|
||||
"""获取内部包装中的对象级 JSON Schema。
|
||||
|
||||
self.schema = json_schema
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
f"自动生成JSON Schema时发生异常,请检查模型类{schema.__name__}的定义,详细信息:\n"
|
||||
f"{schema.__name__}:\n"
|
||||
) from e
|
||||
else:
|
||||
raise ValueError("schema必须是BaseModel的子类或JsonSchema")
|
||||
else:
|
||||
self.schema = None
|
||||
|
||||
def to_dict(self):
|
||||
Returns:
|
||||
Dict[str, Any] | None: 对象级 JSON Schema;不存在时返回 `None`。
|
||||
"""
|
||||
将响应格式转换为字典
|
||||
:return: 字典
|
||||
if self.schema is None:
|
||||
return None
|
||||
schema_payload = self.schema.get("schema")
|
||||
if isinstance(schema_payload, dict):
|
||||
return cast(Dict[str, Any], deepcopy(schema_payload))
|
||||
return None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""将响应格式转换为字典。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 序列化后的响应格式字典。
|
||||
"""
|
||||
if self.schema:
|
||||
return {
|
||||
"format_type": self.format_type.value,
|
||||
"schema": self.schema,
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"format_type": self.format_type.value,
|
||||
}
|
||||
return {
|
||||
"format_type": self.format_type.value,
|
||||
}
|
||||
|
||||
@@ -1,83 +1,368 @@
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Tuple, TypeAlias, cast
|
||||
|
||||
|
||||
class ToolParamType(Enum):
|
||||
class ToolParamType(str, Enum):
|
||||
"""工具参数类型。"""
|
||||
|
||||
STRING = "string"
|
||||
INTEGER = "integer"
|
||||
NUMBER = "number"
|
||||
FLOAT = "number"
|
||||
BOOLEAN = "boolean"
|
||||
ARRAY = "array"
|
||||
OBJECT = "object"
|
||||
|
||||
|
||||
LegacyToolParameterTuple = Tuple[str, ToolParamType, str, bool, List[str] | None]
|
||||
"""旧版工具参数元组格式。"""
|
||||
|
||||
|
||||
def normalize_tool_param_type(raw_value: ToolParamType | str | None) -> ToolParamType:
|
||||
"""将任意输入值规范化为内部工具参数类型。
|
||||
|
||||
Args:
|
||||
raw_value: 原始参数类型值。
|
||||
|
||||
Returns:
|
||||
ToolParamType: 规范化后的参数类型。未知值会回退为 `STRING`。
|
||||
"""
|
||||
工具调用参数类型
|
||||
if isinstance(raw_value, ToolParamType):
|
||||
return raw_value
|
||||
|
||||
normalized_value = str(raw_value or "").strip().lower()
|
||||
if normalized_value in {"integer", "int"}:
|
||||
return ToolParamType.INTEGER
|
||||
if normalized_value in {"number", "float"}:
|
||||
return ToolParamType.NUMBER
|
||||
if normalized_value in {"boolean", "bool"}:
|
||||
return ToolParamType.BOOLEAN
|
||||
if normalized_value == "array":
|
||||
return ToolParamType.ARRAY
|
||||
if normalized_value == "object":
|
||||
return ToolParamType.OBJECT
|
||||
return ToolParamType.STRING
|
||||
|
||||
|
||||
def _is_object_schema(schema: Dict[str, Any]) -> bool:
|
||||
"""判断输入字典是否已经是对象级 JSON Schema。
|
||||
|
||||
Args:
|
||||
schema: 待判断的字典。
|
||||
|
||||
Returns:
|
||||
bool: 为对象级 JSON Schema 时返回 `True`。
|
||||
"""
|
||||
|
||||
STRING = "string" # 字符串
|
||||
INTEGER = "integer" # 整型
|
||||
FLOAT = "float" # 浮点型
|
||||
BOOLEAN = "bool" # 布尔型
|
||||
return schema.get("type") == "object" or "properties" in schema or "required" in schema
|
||||
|
||||
|
||||
def _build_parameters_schema_from_property_map(property_map: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""将属性映射转换为对象级 JSON Schema。
|
||||
|
||||
Args:
|
||||
property_map: 仅包含属性定义的映射。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 对象级 JSON Schema。
|
||||
"""
|
||||
required_names: List[str] = []
|
||||
normalized_properties: Dict[str, Any] = {}
|
||||
for property_name, property_schema in property_map.items():
|
||||
if not isinstance(property_schema, dict):
|
||||
continue
|
||||
|
||||
property_schema_copy = deepcopy(property_schema)
|
||||
is_required = bool(property_schema_copy.pop("required", False))
|
||||
if is_required:
|
||||
required_names.append(str(property_name))
|
||||
normalized_properties[str(property_name)] = property_schema_copy
|
||||
|
||||
parameters_schema: Dict[str, Any] = {
|
||||
"type": "object",
|
||||
"properties": normalized_properties,
|
||||
}
|
||||
if required_names:
|
||||
parameters_schema["required"] = required_names
|
||||
return parameters_schema
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ToolParam:
|
||||
"""
|
||||
工具调用参数
|
||||
"""
|
||||
"""工具参数定义。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str
|
||||
param_type: ToolParamType
|
||||
description: str
|
||||
required: bool
|
||||
enum_values: List[Any] | None = None
|
||||
items_schema: Dict[str, Any] | None = None
|
||||
properties: Dict[str, Dict[str, Any]] | None = None
|
||||
required_properties: List[str] = field(default_factory=list)
|
||||
additional_properties: bool | Dict[str, Any] | None = None
|
||||
default: Any = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""执行参数定义的基础校验。
|
||||
|
||||
Raises:
|
||||
ValueError: 当参数名称或复杂类型定义不合法时抛出。
|
||||
"""
|
||||
if not self.name:
|
||||
raise ValueError("参数名称不能为空")
|
||||
if self.param_type == ToolParamType.ARRAY and self.items_schema is None:
|
||||
raise ValueError("数组参数必须提供 items_schema")
|
||||
if self.param_type == ToolParamType.OBJECT and self.properties is None:
|
||||
self.properties = {}
|
||||
|
||||
@classmethod
|
||||
def from_legacy_tuple(cls, parameter: LegacyToolParameterTuple) -> "ToolParam":
|
||||
"""从旧版五元组参数定义构建工具参数。
|
||||
|
||||
Args:
|
||||
parameter: 旧版参数元组。
|
||||
|
||||
Returns:
|
||||
ToolParam: 规范化后的工具参数对象。
|
||||
"""
|
||||
return cls(
|
||||
name=parameter[0],
|
||||
param_type=parameter[1],
|
||||
description=parameter[2],
|
||||
required=parameter[3],
|
||||
enum_values=parameter[4],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_dict(
|
||||
cls,
|
||||
name: str,
|
||||
param_type: ToolParamType,
|
||||
description: str,
|
||||
required: bool,
|
||||
enum_values: list[str] | None = None,
|
||||
):
|
||||
parameter_schema: Dict[str, Any],
|
||||
*,
|
||||
required: bool = False,
|
||||
) -> "ToolParam":
|
||||
"""从属性级 JSON Schema 或结构化参数字典构建工具参数。
|
||||
|
||||
Args:
|
||||
name: 参数名称。
|
||||
parameter_schema: 参数对应的 Schema 或结构化定义。
|
||||
required: 参数是否必填。
|
||||
|
||||
Returns:
|
||||
ToolParam: 规范化后的工具参数对象。
|
||||
"""
|
||||
初始化工具调用参数
|
||||
(不应直接修改ToolParam类,而应使用ToolOptionBuilder类来构建对象)
|
||||
:param name: 参数名称
|
||||
:param param_type: 参数类型
|
||||
:param description: 参数描述
|
||||
:param required: 是否必填
|
||||
raw_required_properties = parameter_schema.get("required_properties")
|
||||
if raw_required_properties is None and isinstance(parameter_schema.get("required"), list):
|
||||
raw_required_properties = parameter_schema.get("required")
|
||||
return cls(
|
||||
name=name,
|
||||
param_type=normalize_tool_param_type(parameter_schema.get("param_type") or parameter_schema.get("type")),
|
||||
description=str(parameter_schema.get("description", "") or ""),
|
||||
required=required,
|
||||
enum_values=deepcopy(parameter_schema.get("enum_values") or parameter_schema.get("enum")),
|
||||
items_schema=deepcopy(parameter_schema.get("items_schema") or parameter_schema.get("items")),
|
||||
properties=deepcopy(parameter_schema.get("properties")),
|
||||
required_properties=list(raw_required_properties or []),
|
||||
additional_properties=deepcopy(
|
||||
parameter_schema["additional_properties"]
|
||||
if "additional_properties" in parameter_schema
|
||||
else parameter_schema.get("additionalProperties")
|
||||
),
|
||||
default=deepcopy(parameter_schema.get("default")),
|
||||
)
|
||||
|
||||
def to_json_schema(self) -> Dict[str, Any]:
|
||||
"""将参数定义转换为 JSON Schema。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 参数对应的 JSON Schema 片段。
|
||||
"""
|
||||
self.name: str = name
|
||||
self.param_type: ToolParamType = param_type
|
||||
self.description: str = description
|
||||
self.required: bool = required
|
||||
self.enum_values: list[str] | None = enum_values
|
||||
schema: Dict[str, Any] = {
|
||||
"type": self.param_type.value,
|
||||
"description": self.description,
|
||||
}
|
||||
if self.enum_values:
|
||||
schema["enum"] = list(self.enum_values)
|
||||
if self.default is not None:
|
||||
schema["default"] = deepcopy(self.default)
|
||||
if self.param_type == ToolParamType.ARRAY and self.items_schema is not None:
|
||||
schema["items"] = deepcopy(self.items_schema)
|
||||
if self.param_type == ToolParamType.OBJECT:
|
||||
schema["properties"] = deepcopy(self.properties or {})
|
||||
if self.required_properties:
|
||||
schema["required"] = list(self.required_properties)
|
||||
if self.additional_properties is not None:
|
||||
schema["additionalProperties"] = deepcopy(self.additional_properties)
|
||||
return schema
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ToolOption:
|
||||
"""
|
||||
工具调用项
|
||||
"""
|
||||
"""工具定义。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
description: str,
|
||||
params: list[ToolParam] | None = None,
|
||||
):
|
||||
name: str
|
||||
description: str
|
||||
params: List[ToolParam] | None = None
|
||||
parameters_schema_override: Dict[str, Any] | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""执行工具定义的基础校验。
|
||||
|
||||
Raises:
|
||||
ValueError: 当工具名称、描述或参数 Schema 不合法时抛出。
|
||||
"""
|
||||
初始化工具调用项
|
||||
(不应直接修改ToolOption类,而应使用ToolOptionBuilder类来构建对象)
|
||||
:param name: 工具名称
|
||||
:param description: 工具描述
|
||||
:param params: 工具参数列表
|
||||
if not self.name:
|
||||
raise ValueError("工具名称不能为空")
|
||||
if not self.description:
|
||||
raise ValueError("工具描述不能为空")
|
||||
if self.parameters_schema_override is not None:
|
||||
schema_type = self.parameters_schema_override.get("type")
|
||||
if schema_type != "object":
|
||||
raise ValueError("工具参数 Schema 必须是 object 类型")
|
||||
|
||||
@classmethod
|
||||
def from_definition(cls, definition: Dict[str, Any]) -> "ToolOption":
|
||||
"""从任意支持的工具定义字典构建内部工具对象。
|
||||
|
||||
支持以下输入形状:
|
||||
- `{"name", "description", "parameters_schema"}`
|
||||
- `{"name", "description", "parameters"}`
|
||||
- OpenAI function tool:`{"type": "function", "function": {...}}`
|
||||
- 仅属性映射的对象参数定义:`{"query": {"type": "string"}}`
|
||||
|
||||
Args:
|
||||
definition: 原始工具定义字典。
|
||||
|
||||
Returns:
|
||||
ToolOption: 规范化后的工具定义对象。
|
||||
|
||||
Raises:
|
||||
ValueError: 当工具定义缺少必要字段时抛出。
|
||||
"""
|
||||
self.name: str = name
|
||||
self.description: str = description
|
||||
self.params: list[ToolParam] | None = params
|
||||
if definition.get("type") == "function" and isinstance(definition.get("function"), dict):
|
||||
function_definition = cast(Dict[str, Any], definition["function"])
|
||||
return cls.from_definition(
|
||||
{
|
||||
"name": function_definition.get("name", ""),
|
||||
"description": function_definition.get("description", ""),
|
||||
"parameters_schema": function_definition.get("parameters"),
|
||||
}
|
||||
)
|
||||
|
||||
name = str(definition.get("name", "") or "").strip()
|
||||
description = str(definition.get("description", "") or "").strip()
|
||||
if not name:
|
||||
raise ValueError("工具定义缺少 name")
|
||||
if not description:
|
||||
description = f"工具 {name}"
|
||||
|
||||
parameters_schema = definition.get("parameters_schema")
|
||||
if isinstance(parameters_schema, dict):
|
||||
normalized_schema = deepcopy(parameters_schema)
|
||||
if not _is_object_schema(normalized_schema):
|
||||
normalized_schema = _build_parameters_schema_from_property_map(normalized_schema)
|
||||
return cls(
|
||||
name=name,
|
||||
description=description,
|
||||
params=None,
|
||||
parameters_schema_override=normalized_schema,
|
||||
)
|
||||
|
||||
raw_parameters = definition.get("parameters")
|
||||
if isinstance(raw_parameters, dict):
|
||||
normalized_schema = deepcopy(raw_parameters)
|
||||
if not _is_object_schema(normalized_schema):
|
||||
normalized_schema = _build_parameters_schema_from_property_map(normalized_schema)
|
||||
return cls(
|
||||
name=name,
|
||||
description=description,
|
||||
params=None,
|
||||
parameters_schema_override=normalized_schema,
|
||||
)
|
||||
|
||||
if isinstance(raw_parameters, list):
|
||||
params: List[ToolParam] = []
|
||||
for raw_parameter in raw_parameters:
|
||||
if isinstance(raw_parameter, tuple) and len(raw_parameter) == 5:
|
||||
params.append(ToolParam.from_legacy_tuple(raw_parameter))
|
||||
continue
|
||||
if isinstance(raw_parameter, dict):
|
||||
parameter_name = str(raw_parameter.get("name", "") or "").strip()
|
||||
if not parameter_name:
|
||||
continue
|
||||
params.append(
|
||||
ToolParam.from_dict(
|
||||
parameter_name,
|
||||
raw_parameter,
|
||||
required=bool(raw_parameter.get("required", False)),
|
||||
)
|
||||
)
|
||||
return cls(
|
||||
name=name,
|
||||
description=description,
|
||||
params=params or None,
|
||||
parameters_schema_override=None,
|
||||
)
|
||||
|
||||
return cls(name=name, description=description, params=None, parameters_schema_override=None)
|
||||
|
||||
@property
|
||||
def parameters_schema(self) -> Dict[str, Any] | None:
|
||||
"""获取工具参数的对象级 JSON Schema。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any] | None: 工具参数 Schema。无参数工具时返回 `None`。
|
||||
"""
|
||||
if self.parameters_schema_override is not None:
|
||||
return deepcopy(self.parameters_schema_override)
|
||||
if not self.params:
|
||||
return None
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {param.name: param.to_json_schema() for param in self.params},
|
||||
"required": [param.name for param in self.params if param.required],
|
||||
}
|
||||
|
||||
def to_openai_function_schema(self) -> Dict[str, Any]:
|
||||
"""转换为 OpenAI function calling 结构。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: OpenAI 兼容的工具定义。
|
||||
"""
|
||||
function_schema: Dict[str, Any] = {
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
}
|
||||
if self.parameters_schema is not None:
|
||||
function_schema["parameters"] = self.parameters_schema
|
||||
return {
|
||||
"type": "function",
|
||||
"function": function_schema,
|
||||
}
|
||||
|
||||
|
||||
class ToolOptionBuilder:
|
||||
"""
|
||||
工具调用项构建器
|
||||
"""
|
||||
"""工具定义构建器。"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
"""初始化构建器。"""
|
||||
self.__name: str = ""
|
||||
self.__description: str = ""
|
||||
self.__params: list[ToolParam] = []
|
||||
self.__params: List[ToolParam] = []
|
||||
self.__parameters_schema_override: Dict[str, Any] | None = None
|
||||
|
||||
def set_name(self, name: str) -> "ToolOptionBuilder":
|
||||
"""
|
||||
设置工具名称
|
||||
:param name: 工具名称
|
||||
:return: ToolBuilder实例
|
||||
"""设置工具名称。
|
||||
|
||||
Args:
|
||||
name: 工具名称。
|
||||
|
||||
Returns:
|
||||
ToolOptionBuilder: 当前构建器实例。
|
||||
|
||||
Raises:
|
||||
ValueError: 当名称为空时抛出。
|
||||
"""
|
||||
if not name:
|
||||
raise ValueError("工具名称不能为空")
|
||||
@@ -85,35 +370,76 @@ class ToolOptionBuilder:
|
||||
return self
|
||||
|
||||
def set_description(self, description: str) -> "ToolOptionBuilder":
|
||||
"""
|
||||
设置工具描述
|
||||
:param description: 工具描述
|
||||
:return: ToolBuilder实例
|
||||
"""设置工具描述。
|
||||
|
||||
Args:
|
||||
description: 工具描述。
|
||||
|
||||
Returns:
|
||||
ToolOptionBuilder: 当前构建器实例。
|
||||
|
||||
Raises:
|
||||
ValueError: 当描述为空时抛出。
|
||||
"""
|
||||
if not description:
|
||||
raise ValueError("工具描述不能为空")
|
||||
self.__description = description
|
||||
return self
|
||||
|
||||
def set_parameters_schema(self, schema: Dict[str, Any]) -> "ToolOptionBuilder":
|
||||
"""直接设置完整的参数对象 Schema。
|
||||
|
||||
Args:
|
||||
schema: 完整的对象级 JSON Schema。
|
||||
|
||||
Returns:
|
||||
ToolOptionBuilder: 当前构建器实例。
|
||||
|
||||
Raises:
|
||||
ValueError: 当 schema 不是 object 类型时抛出。
|
||||
"""
|
||||
if schema.get("type") != "object":
|
||||
raise ValueError("工具参数 Schema 必须是 object 类型")
|
||||
self.__parameters_schema_override = deepcopy(schema)
|
||||
self.__params.clear()
|
||||
return self
|
||||
|
||||
def add_param(
|
||||
self,
|
||||
name: str,
|
||||
param_type: ToolParamType,
|
||||
description: str,
|
||||
required: bool = False,
|
||||
enum_values: list[str] | None = None,
|
||||
enum_values: List[Any] | None = None,
|
||||
*,
|
||||
items_schema: Dict[str, Any] | None = None,
|
||||
properties: Dict[str, Dict[str, Any]] | None = None,
|
||||
required_properties: List[str] | None = None,
|
||||
additional_properties: bool | Dict[str, Any] | None = None,
|
||||
default: Any = None,
|
||||
) -> "ToolOptionBuilder":
|
||||
"""
|
||||
添加工具参数
|
||||
:param name: 参数名称
|
||||
:param param_type: 参数类型
|
||||
:param description: 参数描述
|
||||
:param required: 是否必填(默认为False)
|
||||
:return: ToolBuilder实例
|
||||
"""
|
||||
if not name or not description:
|
||||
raise ValueError("参数名称/描述不能为空")
|
||||
"""添加一个参数定义。
|
||||
|
||||
Args:
|
||||
name: 参数名称。
|
||||
param_type: 参数类型。
|
||||
description: 参数描述。
|
||||
required: 参数是否必填。
|
||||
enum_values: 可选的枚举值列表。
|
||||
items_schema: 数组参数的元素 Schema。
|
||||
properties: 对象参数的属性定义。
|
||||
required_properties: 对象参数内部的必填字段。
|
||||
additional_properties: 对象参数是否允许额外字段。
|
||||
default: 参数默认值。
|
||||
|
||||
Returns:
|
||||
ToolOptionBuilder: 当前构建器实例。
|
||||
|
||||
Raises:
|
||||
ValueError: 当构建器已经设置完整 Schema 时抛出。
|
||||
"""
|
||||
if self.__parameters_schema_override is not None:
|
||||
raise ValueError("已设置完整参数 Schema,不能再逐项添加参数")
|
||||
self.__params.append(
|
||||
ToolParam(
|
||||
name=name,
|
||||
@@ -121,43 +447,83 @@ class ToolOptionBuilder:
|
||||
description=description,
|
||||
required=required,
|
||||
enum_values=enum_values,
|
||||
items_schema=deepcopy(items_schema),
|
||||
properties=deepcopy(properties),
|
||||
required_properties=list(required_properties or []),
|
||||
additional_properties=deepcopy(additional_properties),
|
||||
default=deepcopy(default),
|
||||
)
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
def build(self):
|
||||
"""
|
||||
构建工具调用项
|
||||
:return: 工具调用项
|
||||
"""
|
||||
if self.__name == "" or self.__description == "":
|
||||
raise ValueError("工具名称/描述不能为空")
|
||||
def build(self) -> ToolOption:
|
||||
"""构建工具定义。
|
||||
|
||||
Returns:
|
||||
ToolOption: 构建完成的工具定义。
|
||||
|
||||
Raises:
|
||||
ValueError: 当工具名称或描述缺失时抛出。
|
||||
"""
|
||||
if not self.__name or not self.__description:
|
||||
raise ValueError("工具名称和描述不能为空")
|
||||
return ToolOption(
|
||||
name=self.__name,
|
||||
description=self.__description,
|
||||
params=None if len(self.__params) == 0 else self.__params,
|
||||
params=None if not self.__params else list(self.__params),
|
||||
parameters_schema_override=deepcopy(self.__parameters_schema_override),
|
||||
)
|
||||
|
||||
|
||||
class ToolCall:
|
||||
"""
|
||||
来自模型反馈的工具调用
|
||||
"""
|
||||
ToolDefinitionInput: TypeAlias = ToolOption | Dict[str, Any]
|
||||
"""统一的工具定义输入类型。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
call_id: str,
|
||||
func_name: str,
|
||||
args: dict | None = None,
|
||||
):
|
||||
|
||||
def normalize_tool_option(tool_definition: ToolDefinitionInput) -> ToolOption:
|
||||
"""将任意支持的工具输入规范化为内部 `ToolOption`。
|
||||
|
||||
Args:
|
||||
tool_definition: 原始工具定义输入。
|
||||
|
||||
Returns:
|
||||
ToolOption: 规范化后的工具定义对象。
|
||||
"""
|
||||
if isinstance(tool_definition, ToolOption):
|
||||
return tool_definition
|
||||
return ToolOption.from_definition(tool_definition)
|
||||
|
||||
|
||||
def normalize_tool_options(
|
||||
tool_definitions: List[ToolDefinitionInput] | None,
|
||||
) -> List[ToolOption] | None:
|
||||
"""批量规范化工具定义列表。
|
||||
|
||||
Args:
|
||||
tool_definitions: 原始工具定义列表。
|
||||
|
||||
Returns:
|
||||
List[ToolOption] | None: 规范化后的工具列表;输入为空时返回 `None`。
|
||||
"""
|
||||
if not tool_definitions:
|
||||
return None
|
||||
return [normalize_tool_option(tool_definition) for tool_definition in tool_definitions]
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ToolCall:
|
||||
"""来自模型输出的工具调用。"""
|
||||
|
||||
call_id: str
|
||||
func_name: str
|
||||
args: Dict[str, Any] | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""执行工具调用的基础校验。
|
||||
|
||||
Raises:
|
||||
ValueError: 当工具调用标识或函数名缺失时抛出。
|
||||
"""
|
||||
初始化工具调用
|
||||
:param call_id: 工具调用ID
|
||||
:param func_name: 要调用的函数名称
|
||||
:param args: 工具调用参数
|
||||
"""
|
||||
self.call_id: str = call_id
|
||||
self.func_name: str = func_name
|
||||
self.args: dict | None = args
|
||||
if not self.call_id:
|
||||
raise ValueError("工具调用 ID 不能为空")
|
||||
if not self.func_name:
|
||||
raise ValueError("工具函数名称不能为空")
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -65,8 +65,8 @@ class BufferCLI:
|
||||
self._mcp_manager: Optional[MCPManager] = None
|
||||
self._init_llm()
|
||||
|
||||
def _init_llm(self):
|
||||
"""Initialize the LLM service from the main project config."""
|
||||
def _init_llm(self) -> None:
|
||||
"""从主项目配置初始化 LLM 服务。"""
|
||||
thinking_env = os.getenv("ENABLE_THINKING", "").strip().lower()
|
||||
enable_thinking: Optional[bool] = True if thinking_env == "true" else False if thinking_env == "false" else None
|
||||
|
||||
@@ -77,7 +77,7 @@ class BufferCLI:
|
||||
enable_thinking=enable_thinking,
|
||||
)
|
||||
|
||||
model_name = self.llm_service._model_name
|
||||
model_name = self.llm_service.get_current_model_name()
|
||||
console.print(f"[success][OK] LLM service initialized[/success] [muted](model: {model_name})[/muted]")
|
||||
|
||||
def _build_tool_context(self) -> ToolHandlerContext:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""
|
||||
MaiSaka LLM 服务 - 使用主项目 LLM 系统
|
||||
将主项目的 LLMRequest 适配为 MaiSaka 需要的接口
|
||||
"""MaiSaka LLM 服务。
|
||||
|
||||
该模块基于主项目服务层封装 MaiSaka 所需的对话与工具调用接口。
|
||||
"""
|
||||
|
||||
from base64 import b64decode
|
||||
@@ -8,7 +8,7 @@ from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from io import BytesIO
|
||||
from time import perf_counter
|
||||
from typing import Any, List, Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import asyncio
|
||||
import random
|
||||
@@ -23,9 +23,16 @@ from src.common.data_models.mai_message_data_model import MaiMessage
|
||||
from src.common.logger import get_logger
|
||||
from src.common.prompt_i18n import load_prompt
|
||||
from src.config.config import config_manager, global_config
|
||||
from src.llm_models.model_client.base_client import BaseClient
|
||||
from src.llm_models.payload_content.message import Message, MessageBuilder, RoleType
|
||||
from src.llm_models.payload_content.tool_option import ToolCall, ToolOption, ToolParamType
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.llm_models.payload_content.tool_option import (
|
||||
ToolCall,
|
||||
ToolDefinitionInput,
|
||||
ToolOption,
|
||||
normalize_tool_options,
|
||||
)
|
||||
from src.common.data_models.llm_service_data_models import LLMGenerationOptions
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
|
||||
from . import config
|
||||
from .config import console
|
||||
@@ -36,17 +43,15 @@ from .message_adapter import (
|
||||
get_message_kind,
|
||||
get_message_role,
|
||||
get_message_text,
|
||||
get_tool_call_id,
|
||||
get_tool_calls,
|
||||
remove_last_perception,
|
||||
to_llm_message,
|
||||
)
|
||||
|
||||
logger = get_logger("maisaka_llm")
|
||||
|
||||
@dataclass
|
||||
@dataclass(slots=True)
|
||||
class ChatResponse:
|
||||
"""LLM 对话循环单步响应"""
|
||||
"""LLM 对话循环单步响应。"""
|
||||
|
||||
content: Optional[str]
|
||||
tool_calls: List[ToolCall]
|
||||
@@ -65,38 +70,34 @@ class MaiSakaLLMService:
|
||||
temperature: float = 0.5,
|
||||
max_tokens: int = 2048,
|
||||
enable_thinking: Optional[bool] = None,
|
||||
):
|
||||
"""
|
||||
初始化 LLM 服务
|
||||
) -> None:
|
||||
"""初始化 MaiSaka LLM 服务。
|
||||
|
||||
参数仅为兼容性保留,实际使用主项目配置
|
||||
Args:
|
||||
api_key: 兼容旧接口保留的参数,当前不使用。
|
||||
base_url: 兼容旧接口保留的参数,当前不使用。
|
||||
model: 兼容旧接口保留的参数,当前不使用。
|
||||
chat_system_prompt: 可选的系统提示词覆盖值。
|
||||
temperature: 默认温度参数。
|
||||
max_tokens: 默认最大输出 token 数。
|
||||
enable_thinking: 是否启用思考模式。
|
||||
"""
|
||||
del api_key, base_url, model
|
||||
self._temperature = temperature
|
||||
self._max_tokens = max_tokens
|
||||
self._enable_thinking = enable_thinking
|
||||
self._extra_tools: List[dict] = []
|
||||
self._extra_tools: List[ToolOption] = []
|
||||
self._prompts_loaded = False
|
||||
self._prompt_load_lock = asyncio.Lock()
|
||||
|
||||
# 获取主项目模型配置
|
||||
try:
|
||||
model_config = config_manager.get_model_config()
|
||||
self._model_configs = model_config.model_task_config
|
||||
except Exception:
|
||||
# 如果配置加载失败,使用默认配置
|
||||
from src.config.model_configs import ModelTaskConfig
|
||||
|
||||
self._model_configs = ModelTaskConfig()
|
||||
logger.warning("无法加载主项目模型配置,使用默认配置")
|
||||
|
||||
# 初始化 LLMRequest 实例(只使用 tool_use 和 replyer)
|
||||
self._llm_tool_use = LLMRequest(model_set=self._model_configs.tool_use, request_type="maisaka_tool_use")
|
||||
# 主对话也使用 tool_use 模型(因为需要工具调用支持)
|
||||
self._llm_planner = LLMRequest(model_set=self._model_configs.planner, request_type="maisaka_planner")
|
||||
# 初始化服务层 LLM 门面(按任务名实时解析配置,确保热重载生效)
|
||||
self._llm_tool_use = LLMServiceClient(task_name="tool_use", request_type="maisaka_tool_use")
|
||||
# 主对话也使用 planner 模型
|
||||
self._llm_planner = LLMServiceClient(task_name="planner", request_type="maisaka_planner")
|
||||
self._llm_chat = self._llm_planner
|
||||
self._llm_utils = self._llm_tool_use
|
||||
# 回复生成使用 replyer 模型
|
||||
self._llm_replyer = LLMRequest(model_set=self._model_configs.replyer, request_type="maisaka_replyer")
|
||||
self._llm_replyer = LLMServiceClient(task_name="replyer", request_type="maisaka_replyer")
|
||||
|
||||
# 尝试修复数据库 schema(忽略错误)
|
||||
self._try_fix_database_schema()
|
||||
@@ -111,15 +112,30 @@ class MaiSakaLLMService:
|
||||
else:
|
||||
self._chat_system_prompt = chat_system_prompt
|
||||
|
||||
self._model_name = (
|
||||
self._model_configs.planner.model_list[0] if self._model_configs.planner.model_list else "未配置"
|
||||
)
|
||||
# 子模块提示词同样采用懒加载
|
||||
self._emotion_prompt: Optional[str] = None
|
||||
self._cognition_prompt: Optional[str] = None
|
||||
|
||||
def get_current_model_name(self) -> str:
|
||||
"""获取当前 Maisaka 对话主模型名称。
|
||||
|
||||
Returns:
|
||||
str: 当前 planner 任务的首选模型名;未配置时返回 ``未配置``。
|
||||
"""
|
||||
try:
|
||||
model_task_config = config_manager.get_model_config().model_task_config
|
||||
if model_task_config.planner.model_list:
|
||||
return model_task_config.planner.model_list[0]
|
||||
except Exception as exc:
|
||||
logger.warning(f"获取当前 Maisaka 模型名称失败: {exc}")
|
||||
return "未配置"
|
||||
|
||||
def _try_fix_database_schema(self) -> None:
|
||||
"""尝试修复数据库 schema,添加缺失的列"""
|
||||
"""尝试修复数据库 schema。
|
||||
|
||||
Returns:
|
||||
None: 该方法仅执行数据库修复副作用。
|
||||
"""
|
||||
try:
|
||||
from src.common.database.database_client import get_db_session
|
||||
from sqlalchemy import text
|
||||
@@ -139,7 +155,11 @@ class MaiSakaLLMService:
|
||||
pass
|
||||
|
||||
def _build_personality_prompt(self) -> str:
|
||||
"""构建人设信息,参考 replyer 的做法"""
|
||||
"""构建当前人设提示词。
|
||||
|
||||
Returns:
|
||||
str: 最终用于系统提示词的人设描述。
|
||||
"""
|
||||
try:
|
||||
bot_name = global_config.bot.nickname
|
||||
if global_config.bot.alias_names:
|
||||
@@ -169,60 +189,21 @@ class MaiSakaLLMService:
|
||||
# 返回默认人设
|
||||
return "你的名字是麦麦,你是一个活泼可爱的AI助手。"
|
||||
|
||||
def set_extra_tools(self, tools: List[dict]) -> None:
|
||||
"""设置额外的工具定义(如 MCP 工具)"""
|
||||
self._extra_tools = [self._normalize_extra_tool(tool) for tool in tools]
|
||||
logger.info(f"Normalized {len(self._extra_tools)} extra tool(s) for Maisaka")
|
||||
def set_extra_tools(self, tools: List[ToolDefinitionInput]) -> None:
|
||||
"""设置额外工具定义。
|
||||
|
||||
@staticmethod
|
||||
def _json_type_to_tool_param_type(json_type: str) -> ToolParamType:
|
||||
normalized = (json_type or "").lower()
|
||||
if normalized == "integer":
|
||||
return ToolParamType.INTEGER
|
||||
if normalized == "number":
|
||||
return ToolParamType.FLOAT
|
||||
if normalized == "boolean":
|
||||
return ToolParamType.BOOLEAN
|
||||
return ToolParamType.STRING
|
||||
|
||||
@classmethod
|
||||
def _normalize_extra_tool(cls, tool: dict) -> dict:
|
||||
"""Normalize external/OpenAI-style tool definitions into the internal tool schema."""
|
||||
if "name" in tool and "description" in tool:
|
||||
return tool
|
||||
|
||||
if tool.get("type") != "function":
|
||||
return tool
|
||||
|
||||
function_info = tool.get("function", {})
|
||||
parameters_schema = function_info.get("parameters", {}) or {}
|
||||
required_names = set(parameters_schema.get("required", []) or [])
|
||||
properties = parameters_schema.get("properties", {}) or {}
|
||||
parameters: list[tuple[str, ToolParamType, str, bool, list[str] | None]] = []
|
||||
|
||||
for param_name, param_schema in properties.items():
|
||||
if not isinstance(param_schema, dict):
|
||||
continue
|
||||
enum_values = param_schema.get("enum")
|
||||
normalized_enum = [str(value) for value in enum_values] if isinstance(enum_values, list) else None
|
||||
parameters.append(
|
||||
(
|
||||
str(param_name),
|
||||
cls._json_type_to_tool_param_type(str(param_schema.get("type", "string"))),
|
||||
str(param_schema.get("description", "")),
|
||||
param_name in required_names,
|
||||
normalized_enum,
|
||||
)
|
||||
)
|
||||
|
||||
return {
|
||||
"name": str(function_info.get("name", "")),
|
||||
"description": str(function_info.get("description", "")),
|
||||
"parameters": parameters,
|
||||
}
|
||||
Args:
|
||||
tools: 外部传入的工具定义列表,例如 MCP 暴露的 OpenAI-compatible 工具。
|
||||
"""
|
||||
self._extra_tools = normalize_tool_options(tools) or []
|
||||
logger.info(f"已为 Maisaka 加载 {len(self._extra_tools)} 个额外工具")
|
||||
|
||||
async def _ensure_prompts_loaded(self) -> None:
|
||||
"""异步懒加载提示词,避免在运行中的事件循环里同步渲染 prompt。"""
|
||||
"""异步懒加载提示词。
|
||||
|
||||
Returns:
|
||||
None: 该方法仅刷新内部提示词缓存。
|
||||
"""
|
||||
if self._prompts_loaded:
|
||||
return
|
||||
|
||||
@@ -260,7 +241,14 @@ class MaiSakaLLMService:
|
||||
|
||||
@staticmethod
|
||||
def _get_role_badge_style(role: str) -> str:
|
||||
"""为不同 role 返回不同的标签样式。"""
|
||||
"""为不同角色返回终端标签样式。
|
||||
|
||||
Args:
|
||||
role: 消息角色名称。
|
||||
|
||||
Returns:
|
||||
str: Rich 可识别的样式字符串。
|
||||
"""
|
||||
if role == "system":
|
||||
return "bold white on blue"
|
||||
if role == "user":
|
||||
@@ -273,7 +261,14 @@ class MaiSakaLLMService:
|
||||
|
||||
@staticmethod
|
||||
def _build_terminal_image_preview(image_base64: str) -> Optional[str]:
|
||||
"""Build a low-resolution ASCII preview for terminals without inline-image support."""
|
||||
"""构建终端 ASCII 图片预览。
|
||||
|
||||
Args:
|
||||
image_base64: 图片的 Base64 数据。
|
||||
|
||||
Returns:
|
||||
Optional[str]: 可渲染的 ASCII 预览文本;失败时返回 `None`。
|
||||
"""
|
||||
ascii_chars = " .:-=+*#%@"
|
||||
|
||||
try:
|
||||
@@ -291,7 +286,7 @@ class MaiSakaLLMService:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
rows: list[str] = []
|
||||
rows: List[str] = []
|
||||
for row_index in range(preview_height):
|
||||
row_pixels = pixels[row_index * preview_width : (row_index + 1) * preview_width]
|
||||
row = "".join(ascii_chars[min(len(ascii_chars) - 1, pixel * len(ascii_chars) // 256)] for pixel in row_pixels)
|
||||
@@ -301,12 +296,19 @@ class MaiSakaLLMService:
|
||||
|
||||
@staticmethod
|
||||
def _render_message_content(content: Any) -> object:
|
||||
"""把消息内容转成适合 Rich 输出的 renderable。"""
|
||||
"""将消息内容转换为 Rich 可渲染对象。
|
||||
|
||||
Args:
|
||||
content: 原始消息内容。
|
||||
|
||||
Returns:
|
||||
object: Rich 可渲染对象。
|
||||
"""
|
||||
if isinstance(content, str):
|
||||
return Text(content)
|
||||
|
||||
if isinstance(content, list):
|
||||
parts: list[object] = []
|
||||
parts: List[object] = []
|
||||
for item in content:
|
||||
if isinstance(item, str):
|
||||
parts.append(Text(item))
|
||||
@@ -316,7 +318,7 @@ class MaiSakaLLMService:
|
||||
if isinstance(image_format, str) and isinstance(image_base64, str):
|
||||
approx_size = max(0, len(image_base64) * 3 // 4)
|
||||
size_text = f"{approx_size / 1024:.1f} KB" if approx_size >= 1024 else f"{approx_size} B"
|
||||
preview_parts: list[object] = [
|
||||
preview_parts: List[object] = [
|
||||
Text(f"image/{image_format} {size_text}\nbase64 omitted", style="magenta")
|
||||
]
|
||||
if config.TERMINAL_IMAGE_PREVIEW:
|
||||
@@ -343,8 +345,15 @@ class MaiSakaLLMService:
|
||||
return Pretty(content, expand_all=True)
|
||||
|
||||
@staticmethod
|
||||
def _format_tool_call_for_display(tool_call: Any) -> dict[str, Any]:
|
||||
"""将 tool call 转成适合 CLI 展示的结构。"""
|
||||
def _format_tool_call_for_display(tool_call: Any) -> Dict[str, Any]:
|
||||
"""将工具调用转换为 CLI 展示结构。
|
||||
|
||||
Args:
|
||||
tool_call: 原始工具调用对象或字典。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 统一后的展示字典。
|
||||
"""
|
||||
if isinstance(tool_call, dict):
|
||||
function_info = tool_call.get("function", {})
|
||||
return {
|
||||
@@ -360,7 +369,16 @@ class MaiSakaLLMService:
|
||||
}
|
||||
|
||||
def _render_tool_call_panel(self, tool_call: Any, index: int, parent_index: int) -> Panel:
|
||||
"""Render assistant tool calls as standalone cards."""
|
||||
"""渲染单个工具调用面板。
|
||||
|
||||
Args:
|
||||
tool_call: 原始工具调用对象或字典。
|
||||
index: 当前工具调用在父消息中的序号。
|
||||
parent_index: 父消息在消息列表中的序号。
|
||||
|
||||
Returns:
|
||||
Panel: 可直接打印的工具调用面板。
|
||||
"""
|
||||
title = Text.assemble(
|
||||
Text(" TOOL CALL ", style="bold white on magenta"),
|
||||
Text(f" #{parent_index}.{index}", style="muted"),
|
||||
@@ -373,16 +391,22 @@ class MaiSakaLLMService:
|
||||
)
|
||||
|
||||
def _render_message_panel(self, message: Any, index: int) -> Panel:
|
||||
"""渲染主循环 prompt 中的一条消息。"""
|
||||
"""渲染主循环 Prompt 中的一条消息。
|
||||
|
||||
Args:
|
||||
message: 原始消息对象或字典。
|
||||
index: 当前消息序号。
|
||||
|
||||
Returns:
|
||||
Panel: 可直接打印的消息面板。
|
||||
"""
|
||||
if isinstance(message, dict):
|
||||
raw_role = message.get("role", "unknown")
|
||||
content = message.get("content")
|
||||
tool_calls = message.get("tool_calls")
|
||||
tool_call_id = message.get("tool_call_id")
|
||||
else:
|
||||
raw_role = getattr(message, "role", "unknown")
|
||||
content = getattr(message, "content", None)
|
||||
tool_calls = getattr(message, "tool_calls", None)
|
||||
tool_call_id = getattr(message, "tool_call_id", None)
|
||||
|
||||
role = raw_role.value if hasattr(raw_role, "value") else str(raw_role)
|
||||
@@ -391,7 +415,7 @@ class MaiSakaLLMService:
|
||||
Text(f" #{index}", style="muted"),
|
||||
)
|
||||
|
||||
parts: list[object] = []
|
||||
parts: List[object] = []
|
||||
if content not in (None, "", []):
|
||||
parts.append(Text(" message ", style="bold cyan"))
|
||||
parts.append(self._render_message_content(content))
|
||||
@@ -415,30 +439,27 @@ class MaiSakaLLMService:
|
||||
padding=(0, 1),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _tool_option_to_dict(tool: "ToolOption") -> dict:
|
||||
"""将 ToolOption 对象转换为主项目期望的 dict 格式
|
||||
async def chat_loop_step(self, chat_history: List[MaiMessage]) -> ChatResponse:
|
||||
"""执行主对话循环的一步。
|
||||
|
||||
主项目的 _build_tool_options() 期望的格式:
|
||||
{
|
||||
"name": str,
|
||||
"description": str,
|
||||
"parameters": List[Tuple[name, ToolParamType, description, required, enum_values]]
|
||||
}
|
||||
Args:
|
||||
chat_history: 当前对话历史。
|
||||
|
||||
Returns:
|
||||
ChatResponse: 本轮对话生成结果。
|
||||
"""
|
||||
params = []
|
||||
if tool.params:
|
||||
for param in tool.params:
|
||||
params.append((param.name, param.param_type, param.description, param.required, param.enum_values))
|
||||
return {"name": tool.name, "description": tool.description, "parameters": params}
|
||||
|
||||
async def chat_loop_step(self, chat_history: list[MaiMessage]) -> ChatResponse:
|
||||
"""执行对话循环的一步 - 使用 tool_use 模型"""
|
||||
await self._ensure_prompts_loaded()
|
||||
|
||||
def message_factory(client) -> list[Message]:
|
||||
"""将 MaiSaka 的 chat_history 转换为主项目的 Message 格式"""
|
||||
messages: list[Message] = []
|
||||
def message_factory(_client: BaseClient) -> List[Message]:
|
||||
"""将 MaiSaka 对话历史转换为内部消息列表。
|
||||
|
||||
Args:
|
||||
_client: 当前底层客户端实例。
|
||||
|
||||
Returns:
|
||||
List[Message]: 规范化后的消息列表。
|
||||
"""
|
||||
messages: List[Message] = []
|
||||
|
||||
# 首先添加系统提示词
|
||||
system_msg = MessageBuilder().set_role(RoleType.System)
|
||||
@@ -454,15 +475,13 @@ class MaiSakaLLMService:
|
||||
return messages
|
||||
|
||||
# 调用 LLM(使用带消息的接口)
|
||||
# 合并内置工具和额外工具(将 ToolOption 对象转换为 dict)
|
||||
all_tools = [self._tool_option_to_dict(t) for t in get_builtin_tools()] + (
|
||||
self._extra_tools if self._extra_tools else []
|
||||
)
|
||||
# 合并内置工具和额外工具,统一交给底层规范化流程处理。
|
||||
all_tools = [*get_builtin_tools(), *self._extra_tools]
|
||||
|
||||
# 打印消息列表
|
||||
built_messages = message_factory(None)
|
||||
|
||||
ordered_panels: list[Panel] = []
|
||||
ordered_panels: List[Panel] = []
|
||||
for index, msg in enumerate(built_messages, start=1):
|
||||
ordered_panels.append(self._render_message_panel(msg, index))
|
||||
tool_calls = getattr(msg, "tool_calls", None)
|
||||
@@ -483,13 +502,18 @@ class MaiSakaLLMService:
|
||||
|
||||
|
||||
request_started_at = perf_counter()
|
||||
logger.info("chat_loop_step calling planner model generate_response_with_message_async")
|
||||
response, (reasoning, model, tool_calls) = await self._llm_chat.generate_response_with_message_async(
|
||||
logger.info("chat_loop_step calling planner model generate_response_with_messages")
|
||||
generation_result = await self._llm_chat.generate_response_with_messages(
|
||||
message_factory=message_factory,
|
||||
tools=all_tools if all_tools else None,
|
||||
temperature=self._temperature,
|
||||
max_tokens=self._max_tokens,
|
||||
options=LLMGenerationOptions(
|
||||
tool_options=all_tools if all_tools else None,
|
||||
temperature=self._temperature,
|
||||
max_tokens=self._max_tokens,
|
||||
),
|
||||
)
|
||||
response = generation_result.response
|
||||
model = generation_result.model_name
|
||||
tool_calls = generation_result.tool_calls
|
||||
elapsed = perf_counter() - request_started_at
|
||||
logger.info(
|
||||
f"chat_loop_step planner model returned in {elapsed:.2f}s "
|
||||
@@ -509,8 +533,15 @@ class MaiSakaLLMService:
|
||||
raw_message=raw_message,
|
||||
)
|
||||
|
||||
def _filter_for_api(self, chat_history: list[MaiMessage]) -> str:
|
||||
"""过滤对话历史为 API 格式"""
|
||||
def _filter_for_api(self, chat_history: List[MaiMessage]) -> str:
|
||||
"""将对话历史过滤为简单文本格式。
|
||||
|
||||
Args:
|
||||
chat_history: 当前对话历史。
|
||||
|
||||
Returns:
|
||||
str: 过滤后的文本上下文。
|
||||
"""
|
||||
parts = []
|
||||
for msg in chat_history:
|
||||
role = get_message_role(msg)
|
||||
@@ -535,8 +566,15 @@ class MaiSakaLLMService:
|
||||
|
||||
return "\n\n".join(parts)
|
||||
|
||||
def build_chat_context(self, user_text: str) -> list[MaiMessage]:
|
||||
"""构建对话上下文"""
|
||||
def build_chat_context(self, user_text: str) -> List[MaiMessage]:
|
||||
"""构建新的对话上下文。
|
||||
|
||||
Args:
|
||||
user_text: 用户输入文本。
|
||||
|
||||
Returns:
|
||||
List[MaiMessage]: 初始对话上下文消息列表。
|
||||
"""
|
||||
return [
|
||||
build_message(
|
||||
role=RoleType.User.value,
|
||||
@@ -547,8 +585,15 @@ class MaiSakaLLMService:
|
||||
|
||||
# ──────── 分析模块(使用 utils 模型) ────────
|
||||
|
||||
async def analyze_emotion(self, chat_history: list[MaiMessage]) -> str:
|
||||
"""情绪分析 - 使用 utils 模型"""
|
||||
async def analyze_emotion(self, chat_history: List[MaiMessage]) -> str:
|
||||
"""执行情绪分析。
|
||||
|
||||
Args:
|
||||
chat_history: 当前对话历史。
|
||||
|
||||
Returns:
|
||||
str: 情绪分析文本。
|
||||
"""
|
||||
await self._ensure_prompts_loaded()
|
||||
filtered = [m for m in chat_history if get_message_kind(m) != "perception"]
|
||||
recent = filtered[-10:] if len(filtered) > 10 else filtered
|
||||
@@ -574,19 +619,26 @@ class MaiSakaLLMService:
|
||||
print("=" * 60 + "\n")
|
||||
|
||||
try:
|
||||
response, _ = await self._llm_utils.generate_response_async(
|
||||
generation_result = await self._llm_utils.generate_response(
|
||||
prompt=prompt,
|
||||
temperature=0.3,
|
||||
max_tokens=512,
|
||||
options=LLMGenerationOptions(temperature=0.3, max_tokens=512),
|
||||
)
|
||||
response = generation_result.response
|
||||
|
||||
return response
|
||||
except Exception as e:
|
||||
logger.error(f"情绪分析 LLM 调用出错: {e}")
|
||||
return ""
|
||||
|
||||
async def analyze_cognition(self, chat_history: list[MaiMessage]) -> str:
|
||||
"""认知分析 - 使用 utils 模型"""
|
||||
async def analyze_cognition(self, chat_history: List[MaiMessage]) -> str:
|
||||
"""执行认知分析。
|
||||
|
||||
Args:
|
||||
chat_history: 当前对话历史。
|
||||
|
||||
Returns:
|
||||
str: 认知分析文本。
|
||||
"""
|
||||
await self._ensure_prompts_loaded()
|
||||
filtered = [m for m in chat_history if get_message_kind(m) != "perception"]
|
||||
recent = filtered[-10:] if len(filtered) > 10 else filtered
|
||||
@@ -612,19 +664,27 @@ class MaiSakaLLMService:
|
||||
print("=" * 60 + "\n")
|
||||
|
||||
try:
|
||||
response, _ = await self._llm_utils.generate_response_async(
|
||||
generation_result = await self._llm_utils.generate_response(
|
||||
prompt=prompt,
|
||||
temperature=0.3,
|
||||
max_tokens=512,
|
||||
options=LLMGenerationOptions(temperature=0.3, max_tokens=512),
|
||||
)
|
||||
response = generation_result.response
|
||||
|
||||
return response
|
||||
except Exception as e:
|
||||
logger.error(f"认知分析 LLM 调用出错: {e}")
|
||||
return ""
|
||||
|
||||
async def _removed_analyze_timing(self, chat_history: list[MaiMessage], timing_info: str) -> str:
|
||||
"""时间分析 - 使用 utils 模型"""
|
||||
async def _removed_analyze_timing(self, chat_history: List[MaiMessage], timing_info: str) -> str:
|
||||
"""执行时间节奏分析。
|
||||
|
||||
Args:
|
||||
chat_history: 当前对话历史。
|
||||
timing_info: 外部传入的时间信息摘要。
|
||||
|
||||
Returns:
|
||||
str: 时间分析文本。
|
||||
"""
|
||||
await self._ensure_prompts_loaded()
|
||||
filtered = [
|
||||
m
|
||||
@@ -653,11 +713,11 @@ class MaiSakaLLMService:
|
||||
print("=" * 60 + "\n")
|
||||
|
||||
try:
|
||||
response, _ = await self._llm_utils.generate_response_async(
|
||||
generation_result = await self._llm_utils.generate_response(
|
||||
prompt=prompt,
|
||||
temperature=0.3,
|
||||
max_tokens=512,
|
||||
options=LLMGenerationOptions(temperature=0.3, max_tokens=512),
|
||||
)
|
||||
response = generation_result.response
|
||||
|
||||
return response
|
||||
except Exception as e:
|
||||
@@ -666,10 +726,15 @@ class MaiSakaLLMService:
|
||||
|
||||
# ──────── 回复生成(使用 replyer 模型) ────────
|
||||
|
||||
async def generate_reply(self, reason: str, chat_history: list[MaiMessage]) -> str:
|
||||
"""
|
||||
生成回复 - 使用 replyer 模型
|
||||
可供 Replyer 类直接调用
|
||||
async def generate_reply(self, reason: str, chat_history: List[MaiMessage]) -> str:
|
||||
"""生成最终回复文本。
|
||||
|
||||
Args:
|
||||
reason: 当前轮次的内部想法或回复理由。
|
||||
chat_history: 当前对话历史。
|
||||
|
||||
Returns:
|
||||
str: 最终回复文本。
|
||||
"""
|
||||
await self._ensure_prompts_loaded()
|
||||
from datetime import datetime
|
||||
@@ -704,17 +769,12 @@ class MaiSakaLLMService:
|
||||
print("=" * 60 + "\n")
|
||||
|
||||
try:
|
||||
response, _ = await self._llm_replyer.generate_response_async(
|
||||
generation_result = await self._llm_replyer.generate_response(
|
||||
prompt=messages,
|
||||
temperature=0.8,
|
||||
max_tokens=512,
|
||||
options=LLMGenerationOptions(temperature=0.8, max_tokens=512),
|
||||
)
|
||||
response = generation_result.response
|
||||
return response.strip() if response else "..."
|
||||
except Exception as e:
|
||||
logger.error(f"回复生成 LLM 调用出错: {e}")
|
||||
return "..."
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -16,8 +16,9 @@ from json_repair import repair_json
|
||||
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import model_config, global_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config
|
||||
from src.common.data_models.llm_service_data_models import LLMGenerationOptions
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
from src.services import message_service as message_api
|
||||
from src.chat.utils.utils import is_bot_self
|
||||
from src.person_info.person_info import Person
|
||||
@@ -88,8 +89,8 @@ class ChatHistorySummarizer:
|
||||
# 注意:批次加载需要异步查询消息,所以在 start() 中调用
|
||||
|
||||
# LLM请求器,用于压缩聊天内容
|
||||
self.summarizer_llm = LLMRequest(
|
||||
model_set=model_config.model_task_config.utils, request_type="chat_history_summarizer"
|
||||
self.summarizer_llm = LLMServiceClient(
|
||||
task_name="utils", request_type="chat_history_summarizer"
|
||||
)
|
||||
|
||||
# 后台循环相关
|
||||
@@ -656,10 +657,11 @@ class ChatHistorySummarizer:
|
||||
prompt = await prompt_manager.render_prompt(prompt_template)
|
||||
|
||||
try:
|
||||
response, _ = await self.summarizer_llm.generate_response_async(
|
||||
generation_result = await self.summarizer_llm.generate_response(
|
||||
prompt=prompt,
|
||||
temperature=0.3,
|
||||
options=LLMGenerationOptions(temperature=0.3),
|
||||
)
|
||||
response = generation_result.response
|
||||
|
||||
logger.info(f"{self.log_prefix} 话题识别LLM Prompt: {prompt}")
|
||||
logger.info(f"{self.log_prefix} 话题识别LLM Response: {response}")
|
||||
@@ -812,7 +814,8 @@ class ChatHistorySummarizer:
|
||||
prompt = await prompt_manager.render_prompt(prompt_template)
|
||||
|
||||
try:
|
||||
response, _ = await self.summarizer_llm.generate_response_async(prompt=prompt)
|
||||
generation_result = await self.summarizer_llm.generate_response(prompt=prompt)
|
||||
response = generation_result.response
|
||||
|
||||
# 解析JSON响应
|
||||
json_str = response.strip()
|
||||
|
||||
@@ -5,7 +5,7 @@ import asyncio
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Any, Optional, Tuple, Callable
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.config.config import global_config
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
from src.services import llm_service as llm_api
|
||||
from sqlmodel import select, col
|
||||
@@ -269,18 +269,18 @@ async def _react_agent_solve_question(
|
||||
return messages
|
||||
|
||||
message_factory_fn: Callable[..., List[Message]] = _build_messages # pyright: ignore[reportGeneralTypeIssues]
|
||||
(
|
||||
success,
|
||||
response,
|
||||
reasoning_content,
|
||||
model_name,
|
||||
tool_calls,
|
||||
) = await llm_api.generate_with_model_with_tools_by_message_factory(
|
||||
message_factory_fn, # type: ignore[arg-type]
|
||||
model_config=model_config.model_task_config.tool_use,
|
||||
tool_options=tool_definitions,
|
||||
request_type="memory.react",
|
||||
generation_result = await llm_api.generate(
|
||||
llm_api.LLMServiceRequest(
|
||||
task_name="tool_use",
|
||||
request_type="memory.react",
|
||||
message_factory=message_factory_fn, # type: ignore[arg-type]
|
||||
tool_options=tool_definitions,
|
||||
)
|
||||
)
|
||||
success = generation_result.success
|
||||
response = generation_result.completion.response
|
||||
reasoning_content = generation_result.completion.reasoning
|
||||
tool_calls = generation_result.completion.tool_calls
|
||||
|
||||
# logger.info(
|
||||
# f"ReAct Agent 第 {iteration + 1} 次迭代 模型: {model_name} ,调用工具数量: {len(tool_calls) if tool_calls else 0} ,调用工具响应: {response}"
|
||||
@@ -679,18 +679,16 @@ async def _react_agent_solve_question(
|
||||
evaluation_prompt_template.add_context("max_iterations", str(max_iterations))
|
||||
evaluation_prompt = await prompt_manager.render_prompt(evaluation_prompt_template)
|
||||
|
||||
(
|
||||
eval_success,
|
||||
eval_response,
|
||||
eval_reasoning_content,
|
||||
eval_model_name,
|
||||
eval_tool_calls,
|
||||
) = await llm_api.generate_with_model_with_tools(
|
||||
evaluation_prompt,
|
||||
model_config=model_config.model_task_config.tool_use,
|
||||
tool_options=[], # 最终评估阶段不提供工具
|
||||
request_type="memory.react.final",
|
||||
evaluation_result = await llm_api.generate(
|
||||
llm_api.LLMServiceRequest(
|
||||
task_name="tool_use",
|
||||
request_type="memory.react.final",
|
||||
prompt=evaluation_prompt,
|
||||
tool_options=[],
|
||||
)
|
||||
)
|
||||
eval_success = evaluation_result.success
|
||||
eval_response = evaluation_result.completion.response
|
||||
|
||||
if not eval_success:
|
||||
logger.error(f"ReAct Agent 最终评估阶段 LLM调用失败: {eval_response}")
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
"""
|
||||
工具注册系统
|
||||
提供统一的工具注册和管理接口
|
||||
"""工具注册系统。
|
||||
|
||||
提供统一的工具注册和管理接口。
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Any, Optional, Callable, Awaitable
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.llm_models.payload_content.tool_option import ToolParamType
|
||||
from src.llm_models.payload_content.tool_option import ToolParamType, normalize_tool_option
|
||||
|
||||
logger = get_logger("memory_retrieval_tools")
|
||||
|
||||
@@ -14,16 +15,19 @@ class MemoryRetrievalTool:
|
||||
"""记忆检索工具基类"""
|
||||
|
||||
def __init__(
|
||||
self, name: str, description: str, parameters: List[Dict[str, Any]], execute_func: Callable[..., Awaitable[str]]
|
||||
):
|
||||
"""
|
||||
初始化工具
|
||||
self,
|
||||
name: str,
|
||||
description: str,
|
||||
parameters: List[Dict[str, Any]],
|
||||
execute_func: Callable[..., Awaitable[str]],
|
||||
) -> None:
|
||||
"""初始化工具。
|
||||
|
||||
Args:
|
||||
name: 工具名称
|
||||
description: 工具描述
|
||||
parameters: 参数定义列表,格式:[{"name": "param_name", "type": "string", "description": "参数描述", "required": True}]
|
||||
execute_func: 执行函数,必须是异步函数
|
||||
name: 工具名称。
|
||||
description: 工具描述。
|
||||
parameters: 参数定义列表。
|
||||
execute_func: 执行函数,必须是异步函数。
|
||||
"""
|
||||
self.name = name
|
||||
self.description = description
|
||||
@@ -44,20 +48,17 @@ class MemoryRetrievalTool:
|
||||
params_str = "\n".join(param_descriptions) if param_descriptions else " 无参数"
|
||||
return f"{self.name}({', '.join([p['name'] for p in self.parameters])}): {self.description}\n{params_str}"
|
||||
|
||||
async def execute(self, **kwargs) -> str:
|
||||
"""执行工具"""
|
||||
async def execute(self, **kwargs: Any) -> str:
|
||||
"""执行工具。"""
|
||||
return await self.execute_func(**kwargs)
|
||||
|
||||
def get_tool_definition(self) -> Dict[str, Any]:
|
||||
"""获取工具定义,用于LLM function calling
|
||||
"""获取规范化的工具定义。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 工具定义字典,格式与BaseTool一致
|
||||
格式: {"name": str, "description": str, "parameters": List[Tuple]}
|
||||
Dict[str, Any]: 统一工具定义字典。
|
||||
"""
|
||||
# 转换参数格式为元组列表,格式与BaseTool一致
|
||||
# 格式: [("param_name", ToolParamType, "description", required, enum_values)]
|
||||
param_tuples = []
|
||||
legacy_parameters: list[tuple[str, ToolParamType, str, bool, list[str] | None]] = []
|
||||
|
||||
for param in self.parameters:
|
||||
param_name = param.get("name", "")
|
||||
@@ -77,20 +78,27 @@ class MemoryRetrievalTool:
|
||||
}
|
||||
param_type = type_mapping.get(param_type_str, ToolParamType.STRING)
|
||||
|
||||
# 构建参数元组
|
||||
param_tuple = (param_name, param_type, param_desc, is_required, enum_values)
|
||||
param_tuples.append(param_tuple)
|
||||
legacy_parameters.append((param_name, param_type, param_desc, is_required, enum_values))
|
||||
|
||||
# 构建工具定义,格式与BaseTool.get_tool_definition()一致
|
||||
tool_def = {"name": self.name, "description": self.description, "parameters": param_tuples}
|
||||
|
||||
return tool_def
|
||||
normalized_option = normalize_tool_option(
|
||||
{
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"parameters": legacy_parameters,
|
||||
}
|
||||
)
|
||||
return {
|
||||
"name": normalized_option.name,
|
||||
"description": normalized_option.description,
|
||||
"parameters_schema": normalized_option.parameters_schema,
|
||||
}
|
||||
|
||||
|
||||
class MemoryRetrievalToolRegistry:
|
||||
"""工具注册器"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
"""初始化工具注册器。"""
|
||||
self.tools: Dict[str, MemoryRetrievalTool] = {}
|
||||
|
||||
def register_tool(self, tool: MemoryRetrievalTool) -> None:
|
||||
@@ -137,15 +145,18 @@ _tool_registry = MemoryRetrievalToolRegistry()
|
||||
|
||||
|
||||
def register_memory_retrieval_tool(
|
||||
name: str, description: str, parameters: List[Dict[str, Any]], execute_func: Callable[..., Awaitable[str]]
|
||||
name: str,
|
||||
description: str,
|
||||
parameters: List[Dict[str, Any]],
|
||||
execute_func: Callable[..., Awaitable[str]],
|
||||
) -> None:
|
||||
"""注册记忆检索工具的便捷函数
|
||||
"""注册记忆检索工具的便捷函数。
|
||||
|
||||
Args:
|
||||
name: 工具名称
|
||||
description: 工具描述
|
||||
parameters: 参数定义列表
|
||||
execute_func: 执行函数
|
||||
name: 工具名称。
|
||||
description: 工具描述。
|
||||
parameters: 参数定义列表。
|
||||
execute_func: 执行函数。
|
||||
"""
|
||||
tool = MemoryRetrievalTool(name, description, parameters, execute_func)
|
||||
_tool_registry.register_tool(tool)
|
||||
|
||||
@@ -17,14 +17,14 @@ from src.common.data_models.person_info_data_model import dump_group_cardname_re
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import PersonInfo
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
|
||||
|
||||
logger = get_logger("person_info")
|
||||
|
||||
relation_selection_model = LLMRequest(
|
||||
model_set=model_config.model_task_config.tool_use, request_type="relation_selection"
|
||||
relation_selection_model = LLMServiceClient(
|
||||
task_name="tool_use", request_type="relation_selection"
|
||||
)
|
||||
|
||||
|
||||
@@ -578,7 +578,8 @@ class Person:
|
||||
<分类1><分类2><分类3>......
|
||||
如果没有相关的分类,请输出<none>"""
|
||||
|
||||
response, _ = await relation_selection_model.generate_response_async(prompt)
|
||||
generation_result = await relation_selection_model.generate_response(prompt)
|
||||
response = generation_result.response
|
||||
# print(prompt)
|
||||
# print(response)
|
||||
category_list = extract_categories_from_response(response)
|
||||
@@ -600,7 +601,8 @@ class Person:
|
||||
例如:
|
||||
<分类1><分类2><分类3>......
|
||||
如果没有相关的分类,请输出<none>"""
|
||||
response, _ = await relation_selection_model.generate_response_async(prompt)
|
||||
generation_result = await relation_selection_model.generate_response(prompt)
|
||||
response = generation_result.response
|
||||
# print(prompt)
|
||||
# print(response)
|
||||
category_list = extract_categories_from_response(response)
|
||||
@@ -634,7 +636,9 @@ class Person:
|
||||
class PersonInfoManager:
|
||||
def __init__(self):
|
||||
self.person_name_list = {}
|
||||
self.qv_name_llm = LLMRequest(model_set=model_config.model_task_config.utils, request_type="relation.qv_name")
|
||||
self.qv_name_llm = LLMServiceClient(
|
||||
task_name="utils", request_type="relation.qv_name"
|
||||
)
|
||||
try:
|
||||
with get_db_session() as _:
|
||||
pass
|
||||
@@ -737,7 +741,8 @@ class PersonInfoManager:
|
||||
"nickname": "昵称",
|
||||
"reason": "理由"
|
||||
}"""
|
||||
response, _ = await self.qv_name_llm.generate_response_async(qv_name_prompt)
|
||||
generation_result = await self.qv_name_llm.generate_response(qv_name_prompt)
|
||||
response = generation_result.response
|
||||
# logger.info(f"取名提示词:{qv_name_prompt}\n取名回复:{response}")
|
||||
result = self._extract_json_from_text(response)
|
||||
|
||||
|
||||
@@ -1,33 +1,80 @@
|
||||
from typing import Any, Dict
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.llm_models.payload_content.tool_option import ToolCall
|
||||
|
||||
logger = get_logger("plugin_runtime.integration")
|
||||
|
||||
|
||||
def _get_nested_config_value(source: Any, key: str, default: Any = None) -> Any:
|
||||
"""从嵌套对象或字典中读取配置值。
|
||||
|
||||
Args:
|
||||
source: 配置对象或字典。
|
||||
key: 以点号分隔的路径。
|
||||
default: 未命中时返回的默认值。
|
||||
|
||||
Returns:
|
||||
Any: 命中的值;读取失败时返回默认值。
|
||||
"""
|
||||
current = source
|
||||
try:
|
||||
for part in key.split("."):
|
||||
if isinstance(current, dict) and part in current:
|
||||
current = current[part]
|
||||
elif hasattr(current, part):
|
||||
continue
|
||||
if hasattr(current, part):
|
||||
current = getattr(current, part)
|
||||
else:
|
||||
raise KeyError(part)
|
||||
continue
|
||||
raise KeyError(part)
|
||||
return current
|
||||
except Exception:
|
||||
return default
|
||||
|
||||
|
||||
def _normalize_prompt_arg(prompt: Any) -> str | List[Dict[str, Any]]:
|
||||
"""校验并规范化插件传入的提示参数。
|
||||
|
||||
Args:
|
||||
prompt: 原始提示参数。
|
||||
|
||||
Returns:
|
||||
str | List[Dict[str, Any]]: 规范化后的提示输入。
|
||||
|
||||
Raises:
|
||||
ValueError: 提示参数缺失或结构不受支持时抛出。
|
||||
"""
|
||||
if isinstance(prompt, str):
|
||||
if not prompt.strip():
|
||||
raise ValueError("缺少必要参数 prompt")
|
||||
return prompt
|
||||
if isinstance(prompt, list) and prompt:
|
||||
for index, prompt_message in enumerate(prompt, start=1):
|
||||
if not isinstance(prompt_message, dict):
|
||||
raise ValueError(f"prompt 第 {index} 项必须为字典")
|
||||
return prompt
|
||||
raise ValueError("缺少必要参数 prompt")
|
||||
|
||||
|
||||
class RuntimeCoreCapabilityMixin:
|
||||
"""插件运行时的核心能力混入。"""
|
||||
|
||||
async def _cap_send_text(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
"""向指定流发送文本消息。
|
||||
|
||||
Args:
|
||||
plugin_id: 插件标识。
|
||||
capability: 能力名称。
|
||||
args: 能力调用参数。
|
||||
|
||||
Returns:
|
||||
Any: 能力执行结果。
|
||||
"""
|
||||
del plugin_id, capability
|
||||
from src.services import send_service as send_api
|
||||
|
||||
text: str = args.get("text", "")
|
||||
stream_id: str = args.get("stream_id", "")
|
||||
text = str(args.get("text", ""))
|
||||
stream_id = str(args.get("stream_id", ""))
|
||||
if not text or not stream_id:
|
||||
return {"success": False, "error": "缺少必要参数 text 或 stream_id"}
|
||||
|
||||
@@ -35,20 +82,31 @@ class RuntimeCoreCapabilityMixin:
|
||||
result = await send_api.text_to_stream(
|
||||
text=text,
|
||||
stream_id=stream_id,
|
||||
typing=args.get("typing", False),
|
||||
set_reply=args.get("set_reply", False),
|
||||
storage_message=args.get("storage_message", True),
|
||||
typing=bool(args.get("typing", False)),
|
||||
set_reply=bool(args.get("set_reply", False)),
|
||||
storage_message=bool(args.get("storage_message", True)),
|
||||
)
|
||||
return {"success": result}
|
||||
except Exception as e:
|
||||
logger.error(f"[cap.send.text] 执行失败: {e}", exc_info=True)
|
||||
return {"success": False, "error": str(e)}
|
||||
except Exception as exc:
|
||||
logger.error(f"[cap.send.text] 执行失败: {exc}", exc_info=True)
|
||||
return {"success": False, "error": str(exc)}
|
||||
|
||||
async def _cap_send_emoji(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
"""向指定流发送表情图片。
|
||||
|
||||
Args:
|
||||
plugin_id: 插件标识。
|
||||
capability: 能力名称。
|
||||
args: 能力调用参数。
|
||||
|
||||
Returns:
|
||||
Any: 能力执行结果。
|
||||
"""
|
||||
del plugin_id, capability
|
||||
from src.services import send_service as send_api
|
||||
|
||||
emoji_base64: str = args.get("emoji_base64", "")
|
||||
stream_id: str = args.get("stream_id", "")
|
||||
emoji_base64 = str(args.get("emoji_base64", ""))
|
||||
stream_id = str(args.get("stream_id", ""))
|
||||
if not emoji_base64 or not stream_id:
|
||||
return {"success": False, "error": "缺少必要参数 emoji_base64 或 stream_id"}
|
||||
|
||||
@@ -56,18 +114,29 @@ class RuntimeCoreCapabilityMixin:
|
||||
result = await send_api.emoji_to_stream(
|
||||
emoji_base64=emoji_base64,
|
||||
stream_id=stream_id,
|
||||
storage_message=args.get("storage_message", True),
|
||||
storage_message=bool(args.get("storage_message", True)),
|
||||
)
|
||||
return {"success": result}
|
||||
except Exception as e:
|
||||
logger.error(f"[cap.send.emoji] 执行失败: {e}", exc_info=True)
|
||||
return {"success": False, "error": str(e)}
|
||||
except Exception as exc:
|
||||
logger.error(f"[cap.send.emoji] 执行失败: {exc}", exc_info=True)
|
||||
return {"success": False, "error": str(exc)}
|
||||
|
||||
async def _cap_send_image(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
"""向指定流发送图片。
|
||||
|
||||
Args:
|
||||
plugin_id: 插件标识。
|
||||
capability: 能力名称。
|
||||
args: 能力调用参数。
|
||||
|
||||
Returns:
|
||||
Any: 能力执行结果。
|
||||
"""
|
||||
del plugin_id, capability
|
||||
from src.services import send_service as send_api
|
||||
|
||||
image_base64: str = args.get("image_base64", "")
|
||||
stream_id: str = args.get("stream_id", "")
|
||||
image_base64 = str(args.get("image_base64", ""))
|
||||
stream_id = str(args.get("stream_id", ""))
|
||||
if not image_base64 or not stream_id:
|
||||
return {"success": False, "error": "缺少必要参数 image_base64 或 stream_id"}
|
||||
|
||||
@@ -75,18 +144,29 @@ class RuntimeCoreCapabilityMixin:
|
||||
result = await send_api.image_to_stream(
|
||||
image_base64=image_base64,
|
||||
stream_id=stream_id,
|
||||
storage_message=args.get("storage_message", True),
|
||||
storage_message=bool(args.get("storage_message", True)),
|
||||
)
|
||||
return {"success": result}
|
||||
except Exception as e:
|
||||
logger.error(f"[cap.send.image] 执行失败: {e}", exc_info=True)
|
||||
return {"success": False, "error": str(e)}
|
||||
except Exception as exc:
|
||||
logger.error(f"[cap.send.image] 执行失败: {exc}", exc_info=True)
|
||||
return {"success": False, "error": str(exc)}
|
||||
|
||||
async def _cap_send_command(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
"""向指定流发送命令消息。
|
||||
|
||||
Args:
|
||||
plugin_id: 插件标识。
|
||||
capability: 能力名称。
|
||||
args: 能力调用参数。
|
||||
|
||||
Returns:
|
||||
Any: 能力执行结果。
|
||||
"""
|
||||
del plugin_id, capability
|
||||
from src.services import send_service as send_api
|
||||
|
||||
command = args.get("command", "")
|
||||
stream_id: str = args.get("stream_id", "")
|
||||
command = str(args.get("command", ""))
|
||||
stream_id = str(args.get("stream_id", ""))
|
||||
if not command or not stream_id:
|
||||
return {"success": False, "error": "缺少必要参数 command 或 stream_id"}
|
||||
|
||||
@@ -95,22 +175,33 @@ class RuntimeCoreCapabilityMixin:
|
||||
message_type="command",
|
||||
content=command,
|
||||
stream_id=stream_id,
|
||||
storage_message=args.get("storage_message", True),
|
||||
display_message=args.get("display_message", ""),
|
||||
storage_message=bool(args.get("storage_message", True)),
|
||||
display_message=str(args.get("display_message", "")),
|
||||
)
|
||||
return {"success": result}
|
||||
except Exception as e:
|
||||
logger.error(f"[cap.send.command] 执行失败: {e}", exc_info=True)
|
||||
return {"success": False, "error": str(e)}
|
||||
except Exception as exc:
|
||||
logger.error(f"[cap.send.command] 执行失败: {exc}", exc_info=True)
|
||||
return {"success": False, "error": str(exc)}
|
||||
|
||||
async def _cap_send_custom(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
"""向指定流发送自定义消息。
|
||||
|
||||
Args:
|
||||
plugin_id: 插件标识。
|
||||
capability: 能力名称。
|
||||
args: 能力调用参数。
|
||||
|
||||
Returns:
|
||||
Any: 能力执行结果。
|
||||
"""
|
||||
del plugin_id, capability
|
||||
from src.services import send_service as send_api
|
||||
|
||||
message_type: str = args.get("message_type", "") or args.get("custom_type", "")
|
||||
message_type = str(args.get("message_type", "") or args.get("custom_type", ""))
|
||||
content = args.get("content")
|
||||
if content is None:
|
||||
content = args.get("data", "")
|
||||
stream_id: str = args.get("stream_id", "")
|
||||
stream_id = str(args.get("stream_id", ""))
|
||||
if not message_type or not stream_id:
|
||||
return {"success": False, "error": "缺少必要参数 message_type 或 stream_id"}
|
||||
|
||||
@@ -119,114 +210,116 @@ class RuntimeCoreCapabilityMixin:
|
||||
message_type=message_type,
|
||||
content=content,
|
||||
stream_id=stream_id,
|
||||
display_message=args.get("display_message", ""),
|
||||
typing=args.get("typing", False),
|
||||
storage_message=args.get("storage_message", True),
|
||||
display_message=str(args.get("display_message", "")),
|
||||
typing=bool(args.get("typing", False)),
|
||||
storage_message=bool(args.get("storage_message", True)),
|
||||
)
|
||||
return {"success": result}
|
||||
except Exception as e:
|
||||
logger.error(f"[cap.send.custom] 执行失败: {e}", exc_info=True)
|
||||
return {"success": False, "error": str(e)}
|
||||
except Exception as exc:
|
||||
logger.error(f"[cap.send.custom] 执行失败: {exc}", exc_info=True)
|
||||
return {"success": False, "error": str(exc)}
|
||||
|
||||
async def _cap_llm_generate(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
"""执行无工具的 LLM 生成能力。
|
||||
|
||||
Args:
|
||||
plugin_id: 插件标识。
|
||||
capability: 能力名称。
|
||||
args: 能力调用参数。
|
||||
|
||||
Returns:
|
||||
Any: 标准化后的 LLM 响应结构。
|
||||
"""
|
||||
del capability
|
||||
from src.services import llm_service as llm_api
|
||||
|
||||
prompt: str = args.get("prompt", "")
|
||||
if not prompt:
|
||||
return {"success": False, "error": "缺少必要参数 prompt"}
|
||||
|
||||
model_name: str = args.get("model", "") or args.get("model_name", "")
|
||||
temperature = args.get("temperature")
|
||||
max_tokens = args.get("max_tokens")
|
||||
|
||||
try:
|
||||
models = llm_api.get_available_models()
|
||||
if model_name and model_name in models:
|
||||
model_config = models[model_name]
|
||||
else:
|
||||
if not models:
|
||||
return {"success": False, "error": "没有可用的模型配置"}
|
||||
model_config = next(iter(models.values()))
|
||||
|
||||
success, response, reasoning, used_model = await llm_api.generate_with_model(
|
||||
prompt=prompt,
|
||||
model_config=model_config,
|
||||
request_type=f"plugin.{plugin_id}",
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
prompt = _normalize_prompt_arg(args.get("prompt"))
|
||||
task_name = llm_api.resolve_task_name(str(args.get("model", "") or args.get("model_name", "")))
|
||||
result = await llm_api.generate(
|
||||
llm_api.LLMServiceRequest(
|
||||
task_name=task_name,
|
||||
request_type=f"plugin.{plugin_id}",
|
||||
prompt=prompt,
|
||||
temperature=args.get("temperature"),
|
||||
max_tokens=args.get("max_tokens"),
|
||||
)
|
||||
)
|
||||
return {
|
||||
"success": success,
|
||||
"response": response,
|
||||
"reasoning": reasoning,
|
||||
"model_name": used_model,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"[cap.llm.generate] 执行失败: {e}", exc_info=True)
|
||||
return {"success": False, "error": str(e)}
|
||||
return result.to_capability_payload()
|
||||
except Exception as exc:
|
||||
logger.error(f"[cap.llm.generate] 执行失败: {exc}", exc_info=True)
|
||||
return {"success": False, "error": str(exc)}
|
||||
|
||||
async def _cap_llm_generate_with_tools(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
"""执行带工具的 LLM 生成能力。
|
||||
|
||||
Args:
|
||||
plugin_id: 插件标识。
|
||||
capability: 能力名称。
|
||||
args: 能力调用参数。
|
||||
|
||||
Returns:
|
||||
Any: 标准化后的 LLM 响应结构。
|
||||
"""
|
||||
del capability
|
||||
from src.services import llm_service as llm_api
|
||||
|
||||
prompt: str = args.get("prompt", "")
|
||||
if not prompt:
|
||||
return {"success": False, "error": "缺少必要参数 prompt"}
|
||||
|
||||
model_name: str = args.get("model", "") or args.get("model_name", "")
|
||||
tool_options = args.get("tools") or args.get("tool_options")
|
||||
temperature = args.get("temperature")
|
||||
max_tokens = args.get("max_tokens")
|
||||
if tool_options is not None and not isinstance(tool_options, list):
|
||||
return {"success": False, "error": "tools 必须为列表"}
|
||||
|
||||
try:
|
||||
models = llm_api.get_available_models()
|
||||
if model_name and model_name in models:
|
||||
model_config = models[model_name]
|
||||
else:
|
||||
if not models:
|
||||
return {"success": False, "error": "没有可用的模型配置"}
|
||||
model_config = next(iter(models.values()))
|
||||
|
||||
success, response, reasoning, used_model, tool_calls = await llm_api.generate_with_model_with_tools(
|
||||
prompt=prompt,
|
||||
model_config=model_config,
|
||||
tool_options=tool_options,
|
||||
request_type=f"plugin.{plugin_id}",
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
prompt = _normalize_prompt_arg(args.get("prompt"))
|
||||
task_name = llm_api.resolve_task_name(str(args.get("model", "") or args.get("model_name", "")))
|
||||
result = await llm_api.generate(
|
||||
llm_api.LLMServiceRequest(
|
||||
task_name=task_name,
|
||||
request_type=f"plugin.{plugin_id}",
|
||||
prompt=prompt,
|
||||
tool_options=tool_options,
|
||||
temperature=args.get("temperature"),
|
||||
max_tokens=args.get("max_tokens"),
|
||||
)
|
||||
)
|
||||
serialized_tool_calls = None
|
||||
if tool_calls:
|
||||
serialized_tool_calls = [
|
||||
{
|
||||
"id": tool_call.call_id,
|
||||
"function": {"name": tool_call.func_name, "arguments": tool_call.args or {}},
|
||||
}
|
||||
for tool_call in tool_calls
|
||||
if isinstance(tool_call, ToolCall)
|
||||
]
|
||||
return {
|
||||
"success": success,
|
||||
"response": response,
|
||||
"reasoning": reasoning,
|
||||
"model_name": used_model,
|
||||
"tool_calls": serialized_tool_calls,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"[cap.llm.generate_with_tools] 执行失败: {e}", exc_info=True)
|
||||
return {"success": False, "error": str(e)}
|
||||
return result.to_capability_payload()
|
||||
except Exception as exc:
|
||||
logger.error(f"[cap.llm.generate_with_tools] 执行失败: {exc}", exc_info=True)
|
||||
return {"success": False, "error": str(exc)}
|
||||
|
||||
async def _cap_llm_get_available_models(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
"""获取当前宿主可用的模型任务列表。
|
||||
|
||||
Args:
|
||||
plugin_id: 插件标识。
|
||||
capability: 能力名称。
|
||||
args: 能力调用参数。
|
||||
|
||||
Returns:
|
||||
Any: 可用模型列表。
|
||||
"""
|
||||
del plugin_id, capability, args
|
||||
from src.services import llm_service as llm_api
|
||||
|
||||
try:
|
||||
models = llm_api.get_available_models()
|
||||
return {"success": True, "models": list(models.keys())}
|
||||
except Exception as e:
|
||||
logger.error(f"[cap.llm.get_available_models] 执行失败: {e}", exc_info=True)
|
||||
return {"success": False, "error": str(e)}
|
||||
except Exception as exc:
|
||||
logger.error(f"[cap.llm.get_available_models] 执行失败: {exc}", exc_info=True)
|
||||
return {"success": False, "error": str(exc)}
|
||||
|
||||
async def _cap_config_get(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
key: str = args.get("key", "")
|
||||
"""读取宿主全局配置中的单个字段。
|
||||
|
||||
Args:
|
||||
plugin_id: 插件标识。
|
||||
capability: 能力名称。
|
||||
args: 能力调用参数。
|
||||
|
||||
Returns:
|
||||
Any: 配置读取结果。
|
||||
"""
|
||||
del plugin_id, capability
|
||||
key = str(args.get("key", ""))
|
||||
default = args.get("default")
|
||||
if not key:
|
||||
return {"success": False, "value": None, "error": "缺少必要参数 key"}
|
||||
@@ -234,37 +327,57 @@ class RuntimeCoreCapabilityMixin:
|
||||
try:
|
||||
value = _get_nested_config_value(global_config, key, default)
|
||||
return {"success": True, "value": value}
|
||||
except Exception as e:
|
||||
return {"success": False, "value": None, "error": str(e)}
|
||||
except Exception as exc:
|
||||
return {"success": False, "value": None, "error": str(exc)}
|
||||
|
||||
async def _cap_config_get_plugin(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
"""读取指定插件的配置。
|
||||
|
||||
Args:
|
||||
plugin_id: 当前插件标识。
|
||||
capability: 能力名称。
|
||||
args: 能力调用参数。
|
||||
|
||||
Returns:
|
||||
Any: 配置读取结果。
|
||||
"""
|
||||
del capability
|
||||
from src.plugin_runtime.component_query import component_query_service
|
||||
|
||||
plugin_name: str = args.get("plugin_name", plugin_id)
|
||||
key: str = args.get("key", "")
|
||||
plugin_name = str(args.get("plugin_name", plugin_id))
|
||||
key = str(args.get("key", ""))
|
||||
default = args.get("default")
|
||||
|
||||
try:
|
||||
config = component_query_service.get_plugin_config(plugin_name)
|
||||
if config is None:
|
||||
return {"success": False, "value": default, "error": f"未找到插件 {plugin_name} 的配置"}
|
||||
|
||||
if key:
|
||||
value = _get_nested_config_value(config, key, default)
|
||||
return {"success": True, "value": value}
|
||||
|
||||
return {"success": True, "value": config}
|
||||
except Exception as e:
|
||||
return {"success": False, "value": default, "error": str(e)}
|
||||
except Exception as exc:
|
||||
return {"success": False, "value": default, "error": str(exc)}
|
||||
|
||||
async def _cap_config_get_all(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
"""读取指定插件的全部配置。
|
||||
|
||||
Args:
|
||||
plugin_id: 当前插件标识。
|
||||
capability: 能力名称。
|
||||
args: 能力调用参数。
|
||||
|
||||
Returns:
|
||||
Any: 配置读取结果。
|
||||
"""
|
||||
del capability
|
||||
from src.plugin_runtime.component_query import component_query_service
|
||||
|
||||
plugin_name: str = args.get("plugin_name", plugin_id)
|
||||
plugin_name = str(args.get("plugin_name", plugin_id))
|
||||
try:
|
||||
config = component_query_service.get_plugin_config(plugin_name)
|
||||
if config is None:
|
||||
return {"success": True, "value": {}}
|
||||
return {"success": True, "value": config}
|
||||
except Exception as e:
|
||||
return {"success": False, "value": {}, "error": str(e)}
|
||||
except Exception as exc:
|
||||
return {"success": False, "value": {}, "error": str(exc)}
|
||||
|
||||
@@ -10,7 +10,7 @@ from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional, Tupl
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.core.types import ActionActivationType, ActionInfo, CommandInfo, ComponentInfo, ComponentType, ToolInfo
|
||||
from src.llm_models.payload_content.tool_option import ToolParamType
|
||||
from src.llm_models.payload_content.tool_option import normalize_tool_option
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.plugin_runtime.host.component_registry import ActionEntry, CommandEntry, ComponentEntry, ToolEntry
|
||||
@@ -28,13 +28,6 @@ _HOST_COMPONENT_TYPE_MAP: Dict[ComponentType, str] = {
|
||||
ComponentType.COMMAND: "COMMAND",
|
||||
ComponentType.TOOL: "TOOL",
|
||||
}
|
||||
_TOOL_PARAM_TYPE_MAP: Dict[str, ToolParamType] = {
|
||||
"string": ToolParamType.STRING,
|
||||
"integer": ToolParamType.INTEGER,
|
||||
"float": ToolParamType.FLOAT,
|
||||
"boolean": ToolParamType.BOOLEAN,
|
||||
"bool": ToolParamType.BOOLEAN,
|
||||
}
|
||||
|
||||
|
||||
class ComponentQueryService:
|
||||
@@ -171,11 +164,9 @@ class ComponentQueryService:
|
||||
|
||||
return ActionInfo(
|
||||
name=entry.name,
|
||||
component_type=ComponentType.ACTION,
|
||||
description=str(metadata.get("description", "") or ""),
|
||||
enabled=bool(entry.enabled),
|
||||
plugin_name=entry.plugin_id,
|
||||
metadata=metadata,
|
||||
action_parameters=action_parameters,
|
||||
action_require=action_require,
|
||||
associated_types=associated_types,
|
||||
@@ -202,72 +193,48 @@ class ComponentQueryService:
|
||||
metadata = dict(entry.metadata)
|
||||
return CommandInfo(
|
||||
name=entry.name,
|
||||
component_type=ComponentType.COMMAND,
|
||||
description=str(metadata.get("description", "") or ""),
|
||||
enabled=bool(entry.enabled),
|
||||
plugin_name=entry.plugin_id,
|
||||
metadata=metadata,
|
||||
command_pattern=str(metadata.get("command_pattern", "") or ""),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _coerce_tool_param_type(raw_value: Any) -> ToolParamType:
|
||||
"""规范化工具参数类型。
|
||||
|
||||
Args:
|
||||
raw_value: 原始工具参数类型值。
|
||||
|
||||
Returns:
|
||||
ToolParamType: 规范化后的工具参数类型。
|
||||
"""
|
||||
|
||||
normalized_value = str(raw_value or "").strip().lower()
|
||||
return _TOOL_PARAM_TYPE_MAP.get(normalized_value, ToolParamType.STRING)
|
||||
|
||||
@staticmethod
|
||||
def _build_tool_parameters(entry: "ToolEntry") -> list[tuple[str, ToolParamType, str, bool, list[str] | None]]:
|
||||
"""将运行时工具参数元数据转换为核心 ToolInfo 参数列表。
|
||||
def _build_tool_definition(entry: "ToolEntry") -> dict[str, Any]:
|
||||
"""将运行时 Tool 条目转换为原始工具定义字典。
|
||||
|
||||
Args:
|
||||
entry: 插件运行时中的 Tool 条目。
|
||||
|
||||
Returns:
|
||||
list[tuple[str, ToolParamType, str, bool, list[str] | None]]: 转换后的参数列表。
|
||||
dict[str, Any]: 可交给 `normalize_tool_option()` 的原始工具定义。
|
||||
"""
|
||||
raw_definition: dict[str, Any] = {
|
||||
"name": entry.name,
|
||||
"description": entry.description,
|
||||
}
|
||||
if isinstance(entry.parameters_raw, dict) and entry.parameters_raw:
|
||||
raw_definition["parameters_schema"] = entry.parameters_raw
|
||||
return raw_definition
|
||||
if isinstance(entry.parameters, list) and entry.parameters:
|
||||
raw_definition["parameters"] = entry.parameters
|
||||
return raw_definition
|
||||
if isinstance(entry.parameters_raw, list) and entry.parameters_raw:
|
||||
raw_definition["parameters"] = entry.parameters_raw
|
||||
return raw_definition
|
||||
return raw_definition
|
||||
|
||||
structured_parameters = entry.parameters if isinstance(entry.parameters, list) else []
|
||||
if not structured_parameters and isinstance(entry.parameters_raw, dict):
|
||||
structured_parameters = [
|
||||
{"name": key, **value}
|
||||
for key, value in entry.parameters_raw.items()
|
||||
if isinstance(value, dict)
|
||||
]
|
||||
@staticmethod
|
||||
def _build_tool_parameters_schema(entry: "ToolEntry") -> dict[str, Any] | None:
|
||||
"""将运行时 Tool 条目转换为对象级参数 Schema。
|
||||
|
||||
normalized_parameters: list[tuple[str, ToolParamType, str, bool, list[str] | None]] = []
|
||||
for parameter in structured_parameters:
|
||||
if not isinstance(parameter, dict):
|
||||
continue
|
||||
Args:
|
||||
entry: 插件运行时中的 Tool 条目。
|
||||
|
||||
parameter_name = str(parameter.get("name", "") or "").strip()
|
||||
if not parameter_name:
|
||||
continue
|
||||
|
||||
enum_values = parameter.get("enum")
|
||||
normalized_enum_values = (
|
||||
[str(item) for item in enum_values if item is not None]
|
||||
if isinstance(enum_values, list)
|
||||
else None
|
||||
)
|
||||
normalized_parameters.append(
|
||||
(
|
||||
parameter_name,
|
||||
ComponentQueryService._coerce_tool_param_type(parameter.get("param_type") or parameter.get("type")),
|
||||
str(parameter.get("description", "") or ""),
|
||||
bool(parameter.get("required", True)),
|
||||
normalized_enum_values,
|
||||
)
|
||||
)
|
||||
return normalized_parameters
|
||||
Returns:
|
||||
dict[str, Any] | None: 规范化后的对象级参数 Schema。
|
||||
"""
|
||||
normalized_option = normalize_tool_option(ComponentQueryService._build_tool_definition(entry))
|
||||
return normalized_option.parameters_schema
|
||||
|
||||
@staticmethod
|
||||
def _build_tool_info(entry: "ToolEntry") -> ToolInfo:
|
||||
@@ -282,13 +249,10 @@ class ComponentQueryService:
|
||||
|
||||
return ToolInfo(
|
||||
name=entry.name,
|
||||
component_type=ComponentType.TOOL,
|
||||
description=entry.description,
|
||||
enabled=bool(entry.enabled),
|
||||
plugin_name=entry.plugin_id,
|
||||
metadata=dict(entry.metadata),
|
||||
tool_parameters=ComponentQueryService._build_tool_parameters(entry),
|
||||
tool_description=entry.description,
|
||||
parameters_schema=ComponentQueryService._build_tool_parameters_schema(entry),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -91,7 +91,7 @@ class ToolEntry(ComponentEntry):
|
||||
def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
|
||||
self.description: str = metadata.get("description", "")
|
||||
self.parameters: List[Dict[str, Any]] = metadata.get("parameters", [])
|
||||
self.parameters_raw: List[Dict[str, Any]] = metadata.get("parameters_raw", [])
|
||||
self.parameters_raw: Dict[str, Any] | List[Dict[str, Any]] = metadata.get("parameters_raw", {})
|
||||
super().__init__(name, component_type, plugin_id, metadata)
|
||||
|
||||
|
||||
|
||||
@@ -1,191 +1,492 @@
|
||||
"""LLM 服务模块
|
||||
"""LLM 服务层。
|
||||
|
||||
提供与 LLM 模型交互的核心功能。
|
||||
该模块负责在宿主侧收口统一的 LLM 服务请求模型,并将其转发到
|
||||
`src.llm_models` 中的底层请求调度器。
|
||||
"""
|
||||
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
import json
|
||||
|
||||
from src.common.data_models.llm_service_data_models import (
|
||||
LLMAudioTranscriptionResult,
|
||||
LLMEmbeddingResult,
|
||||
LLMGenerationOptions,
|
||||
LLMImageOptions,
|
||||
LLMResponseResult,
|
||||
LLMServiceRequest,
|
||||
LLMServiceResult,
|
||||
MessageFactory,
|
||||
PromptInput,
|
||||
PromptMessage,
|
||||
)
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import config_manager
|
||||
from src.config.model_configs import TaskConfig
|
||||
from src.llm_models.model_client.base_client import BaseClient
|
||||
from src.llm_models.payload_content.message import Message
|
||||
from src.llm_models.payload_content.message import Message, MessageBuilder, RoleType
|
||||
from src.llm_models.payload_content.tool_option import ToolCall
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.llm_models.utils_model import LLMOrchestrator
|
||||
|
||||
logger = get_logger("llm_service")
|
||||
|
||||
class LLMServiceClient:
|
||||
"""面向上层模块的 LLM 服务对象式门面。
|
||||
|
||||
async def _generate_response(
|
||||
model_config: TaskConfig,
|
||||
request_type: str,
|
||||
prompt: Optional[str] = None,
|
||||
message_factory: Optional[Callable[[BaseClient], List[Message]]] = None,
|
||||
tool_options: Optional[List[Dict[str, Any]]] = None,
|
||||
temperature: Optional[float] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
) -> Tuple[str, str, str, List[ToolCall] | None]:
|
||||
llm_request = LLMRequest(model_set=model_config, request_type=request_type)
|
||||
当前推荐优先使用以下正式接口:
|
||||
- `generate_response`
|
||||
- `generate_response_with_messages`
|
||||
- `generate_response_for_image`
|
||||
- `transcribe_audio`
|
||||
- `embed_text`
|
||||
"""
|
||||
|
||||
if message_factory is not None:
|
||||
response, (reasoning_content, model_name, tool_call) = await llm_request.generate_response_with_message_async(
|
||||
message_factory=message_factory,
|
||||
tools=tool_options,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
def __init__(self, task_name: str, request_type: str = "") -> None:
|
||||
"""初始化 LLM 服务门面。
|
||||
|
||||
Args:
|
||||
task_name: 任务配置名称,对应 `model_task_config` 下的字段名。
|
||||
request_type: 当前请求的业务类型标识。
|
||||
"""
|
||||
self.task_name = resolve_task_name(task_name)
|
||||
self.request_type = request_type
|
||||
self._orchestrator = LLMOrchestrator(task_name=self.task_name, request_type=request_type)
|
||||
|
||||
@staticmethod
|
||||
def _normalize_generation_options(options: LLMGenerationOptions | None = None) -> LLMGenerationOptions:
|
||||
"""规范化文本生成选项。
|
||||
|
||||
Args:
|
||||
options: 原始生成选项。
|
||||
|
||||
Returns:
|
||||
LLMGenerationOptions: 可直接用于执行请求的完整选项对象。
|
||||
"""
|
||||
if options is None:
|
||||
return LLMGenerationOptions()
|
||||
return options
|
||||
|
||||
@staticmethod
|
||||
def _normalize_image_options(options: LLMImageOptions | None = None) -> LLMImageOptions:
|
||||
"""规范化图像理解选项。
|
||||
|
||||
Args:
|
||||
options: 原始图像理解选项。
|
||||
|
||||
Returns:
|
||||
LLMImageOptions: 可直接用于执行请求的完整选项对象。
|
||||
"""
|
||||
if options is None:
|
||||
return LLMImageOptions()
|
||||
return options
|
||||
|
||||
async def generate_response(
|
||||
self,
|
||||
prompt: str,
|
||||
options: LLMGenerationOptions | None = None,
|
||||
) -> LLMResponseResult:
|
||||
"""生成单轮文本响应。
|
||||
|
||||
Args:
|
||||
prompt: 文本提示词。
|
||||
options: 文本生成选项。
|
||||
|
||||
Returns:
|
||||
LLMResponseResult: 统一文本生成结果。
|
||||
"""
|
||||
active_options = self._normalize_generation_options(options)
|
||||
return await self._orchestrator.generate_response_async(
|
||||
prompt=prompt,
|
||||
temperature=active_options.temperature,
|
||||
max_tokens=active_options.max_tokens,
|
||||
tools=active_options.tool_options,
|
||||
response_format=active_options.response_format,
|
||||
raise_when_empty=active_options.raise_when_empty,
|
||||
interrupt_flag=active_options.interrupt_flag,
|
||||
)
|
||||
return response, reasoning_content, model_name, tool_call
|
||||
|
||||
if prompt is None:
|
||||
raise ValueError("prompt 与 message_factory 不能同时为空")
|
||||
async def generate_response_with_messages(
|
||||
self,
|
||||
message_factory: MessageFactory,
|
||||
options: LLMGenerationOptions | None = None,
|
||||
) -> LLMResponseResult:
|
||||
"""基于消息工厂生成响应。
|
||||
|
||||
response, (reasoning_content, model_name, tool_call) = await llm_request.generate_response_async(
|
||||
prompt,
|
||||
tools=tool_options,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
return response, reasoning_content, model_name, tool_call
|
||||
Args:
|
||||
message_factory: 消息工厂,会根据客户端能力构建消息列表。
|
||||
options: 文本生成选项。
|
||||
|
||||
Returns:
|
||||
LLMResponseResult: 统一文本生成结果。
|
||||
"""
|
||||
active_options = self._normalize_generation_options(options)
|
||||
return await self._orchestrator.generate_response_with_message_async(
|
||||
message_factory=message_factory,
|
||||
temperature=active_options.temperature,
|
||||
max_tokens=active_options.max_tokens,
|
||||
tools=active_options.tool_options,
|
||||
response_format=active_options.response_format,
|
||||
raise_when_empty=active_options.raise_when_empty,
|
||||
interrupt_flag=active_options.interrupt_flag,
|
||||
)
|
||||
|
||||
async def generate_response_for_image(
|
||||
self,
|
||||
prompt: str,
|
||||
image_base64: str,
|
||||
image_format: str,
|
||||
options: LLMImageOptions | None = None,
|
||||
) -> LLMResponseResult:
|
||||
"""为图像内容生成响应。
|
||||
|
||||
Args:
|
||||
prompt: 文本提示词。
|
||||
image_base64: 图像的 Base64 编码字符串。
|
||||
image_format: 图像格式,例如 ``png``、``jpeg``。
|
||||
options: 图像理解选项。
|
||||
|
||||
Returns:
|
||||
LLMResponseResult: 统一文本生成结果。
|
||||
"""
|
||||
active_options = self._normalize_image_options(options)
|
||||
return await self._orchestrator.generate_response_for_image(
|
||||
prompt=prompt,
|
||||
image_base64=image_base64,
|
||||
image_format=image_format,
|
||||
temperature=active_options.temperature,
|
||||
max_tokens=active_options.max_tokens,
|
||||
interrupt_flag=active_options.interrupt_flag,
|
||||
)
|
||||
|
||||
async def transcribe_audio(self, voice_base64: str) -> LLMAudioTranscriptionResult:
|
||||
"""执行音频转写请求。
|
||||
|
||||
Args:
|
||||
voice_base64: 音频的 Base64 编码字符串。
|
||||
|
||||
Returns:
|
||||
LLMAudioTranscriptionResult: 音频转写结果对象。
|
||||
"""
|
||||
return await self._orchestrator.generate_response_for_voice(voice_base64)
|
||||
|
||||
async def embed_text(self, embedding_input: str) -> LLMEmbeddingResult:
|
||||
"""生成文本嵌入向量。
|
||||
|
||||
Args:
|
||||
embedding_input: 待编码的文本。
|
||||
|
||||
Returns:
|
||||
LLMEmbeddingResult: 向量生成结果对象。
|
||||
"""
|
||||
return await self._orchestrator.get_embedding(embedding_input)
|
||||
|
||||
|
||||
def get_available_models() -> Dict[str, TaskConfig]:
|
||||
"""获取所有可用的模型配置
|
||||
"""获取所有可用模型配置。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 模型配置字典,key为模型名称,value为模型配置
|
||||
Dict[str, TaskConfig]: 以模型任务名为键的配置映射。
|
||||
"""
|
||||
try:
|
||||
models = config_manager.get_model_config().model_task_config
|
||||
attrs = dir(models)
|
||||
rets: Dict[str, TaskConfig] = {}
|
||||
for attr in attrs:
|
||||
if not attr.startswith("__"):
|
||||
try:
|
||||
value = getattr(models, attr)
|
||||
if not callable(value) and isinstance(value, TaskConfig):
|
||||
rets[attr] = value
|
||||
except Exception as e:
|
||||
logger.debug(f"[LLMService] 获取属性 {attr} 失败: {e}")
|
||||
continue
|
||||
return rets
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[LLMService] 获取可用模型失败: {e}")
|
||||
available_models: Dict[str, TaskConfig] = {}
|
||||
for attr_name in dir(models):
|
||||
if attr_name.startswith("__"):
|
||||
continue
|
||||
try:
|
||||
attr_value = getattr(models, attr_name)
|
||||
except Exception as exc:
|
||||
logger.debug(f"[LLMService] 获取属性 {attr_name} 失败: {exc}")
|
||||
continue
|
||||
if not callable(attr_value) and isinstance(attr_value, TaskConfig):
|
||||
available_models[attr_name] = attr_value
|
||||
return available_models
|
||||
except Exception as exc:
|
||||
logger.error(f"[LLMService] 获取可用模型失败: {exc}")
|
||||
return {}
|
||||
|
||||
|
||||
async def generate_with_model(
|
||||
prompt: str,
|
||||
model_config: TaskConfig,
|
||||
request_type: str = "plugin.generate",
|
||||
temperature: Optional[float] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
) -> Tuple[bool, str, str, str]:
|
||||
"""使用指定模型生成内容
|
||||
def resolve_task_name(task_name: str = "") -> str:
|
||||
"""根据名称解析任务配置名。
|
||||
|
||||
Args:
|
||||
prompt: 提示词
|
||||
model_config: 模型配置(从 get_available_models 获取的模型配置)
|
||||
request_type: 请求类型标识
|
||||
task_name: 目标任务配置名;为空时返回首个可用任务名。
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str, str, str]: (是否成功, 生成的内容, 推理过程, 模型名称)
|
||||
str: 解析得到的任务配置名。
|
||||
|
||||
Raises:
|
||||
RuntimeError: 当前没有任何可用模型配置。
|
||||
ValueError: 指定名称不存在时抛出。
|
||||
"""
|
||||
try:
|
||||
logger.debug(f"[LLMService] 完整提示词: {prompt}")
|
||||
response, reasoning_content, model_name, _ = await _generate_response(
|
||||
model_config=model_config,
|
||||
request_type=request_type,
|
||||
prompt=prompt,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
return True, response, reasoning_content, model_name
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"生成内容时出错: {str(e)}"
|
||||
logger.error(f"[LLMService] {error_msg}")
|
||||
return False, error_msg, "", ""
|
||||
models = get_available_models()
|
||||
if not models:
|
||||
raise RuntimeError("没有可用的模型配置")
|
||||
normalized_task_name = task_name.strip()
|
||||
if not normalized_task_name:
|
||||
return next(iter(models.keys()))
|
||||
if normalized_task_name not in models:
|
||||
raise ValueError(f"未找到名为 `{normalized_task_name}` 的模型配置")
|
||||
return normalized_task_name
|
||||
|
||||
|
||||
async def generate_with_model_with_tools(
|
||||
prompt: str,
|
||||
model_config: TaskConfig,
|
||||
tool_options: List[Dict[str, Any]] | None = None,
|
||||
request_type: str = "plugin.generate",
|
||||
temperature: Optional[float] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
) -> Tuple[bool, str, str, str, List[ToolCall] | None]:
|
||||
"""使用指定模型和工具生成内容
|
||||
def _normalize_role(role_name: str) -> RoleType:
|
||||
"""将原始角色字符串转换为内部角色枚举。
|
||||
|
||||
Args:
|
||||
prompt: 提示词
|
||||
model_config: 模型配置(从 get_available_models 获取的模型配置)
|
||||
tool_options: 工具选项列表
|
||||
request_type: 请求类型标识
|
||||
temperature: 温度参数
|
||||
max_tokens: 最大token数
|
||||
role_name: 原始角色名称。
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str, str, str, List[ToolCall] | None]: (是否成功, 生成的内容, 推理过程, 模型名称, 工具调用列表)
|
||||
RoleType: 规范化后的角色枚举。
|
||||
|
||||
Raises:
|
||||
ValueError: 角色类型不受支持时抛出。
|
||||
"""
|
||||
normalized_role_name = role_name.strip().lower()
|
||||
try:
|
||||
model_name_list = model_config.model_list
|
||||
logger.info(f"使用模型{model_name_list}生成内容")
|
||||
logger.debug(f"完整提示词: {prompt}")
|
||||
|
||||
response, reasoning_content, model_name, tool_call = await _generate_response(
|
||||
model_config=model_config,
|
||||
request_type=request_type,
|
||||
prompt=prompt,
|
||||
tool_options=tool_options,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
return True, response, reasoning_content, model_name, tool_call
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"生成内容时出错: {str(e)}"
|
||||
logger.error(f"[LLMService] {error_msg}")
|
||||
return False, error_msg, "", "", None
|
||||
return RoleType(normalized_role_name)
|
||||
except ValueError as exc:
|
||||
raise ValueError(f"不支持的消息角色: {role_name}") from exc
|
||||
|
||||
|
||||
async def generate_with_model_with_tools_by_message_factory(
|
||||
message_factory: Callable[[BaseClient], List[Message]],
|
||||
model_config: TaskConfig,
|
||||
tool_options: List[Dict[str, Any]] | None = None,
|
||||
request_type: str = "plugin.generate",
|
||||
temperature: Optional[float] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
) -> Tuple[bool, str, str, str, List[ToolCall] | None]:
|
||||
"""使用指定模型和工具生成内容(通过消息工厂构建消息列表)
|
||||
def _parse_data_url_image(image_url: str) -> Tuple[str, str]:
|
||||
"""解析 Data URL 形式的图片内容。
|
||||
|
||||
Args:
|
||||
message_factory: 消息工厂函数
|
||||
model_config: 模型配置
|
||||
tool_options: 工具选项列表
|
||||
request_type: 请求类型标识
|
||||
temperature: 温度参数
|
||||
max_tokens: 最大token数
|
||||
image_url: 图片 URL。
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str, str, str, List[ToolCall] | None]: (是否成功, 生成的内容, 推理过程, 模型名称, 工具调用列表)
|
||||
Tuple[str, str]: `(图片格式, Base64 数据)`。
|
||||
|
||||
Raises:
|
||||
ValueError: 输入不是受支持的 Data URL 时抛出。
|
||||
"""
|
||||
try:
|
||||
model_name_list = model_config.model_list
|
||||
logger.info(f"使用模型 {model_name_list} 生成内容")
|
||||
if not image_url.startswith("data:image/") or ";base64," not in image_url:
|
||||
raise ValueError("仅支持 Data URL 形式的图片输入")
|
||||
prefix, image_base64 = image_url.split(";base64,", maxsplit=1)
|
||||
image_format = prefix.removeprefix("data:image/")
|
||||
if not image_format or not image_base64:
|
||||
raise ValueError("图片 Data URL 不完整")
|
||||
return image_format, image_base64
|
||||
|
||||
response, reasoning_content, model_name, tool_call = await _generate_response(
|
||||
model_config=model_config,
|
||||
request_type=request_type,
|
||||
message_factory=message_factory,
|
||||
tool_options=tool_options,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
|
||||
def _append_content_parts(message_builder: MessageBuilder, content: Any) -> None:
|
||||
"""将原始消息内容追加到内部消息构建器。
|
||||
|
||||
Args:
|
||||
message_builder: 目标消息构建器。
|
||||
content: 原始消息内容。
|
||||
|
||||
Raises:
|
||||
ValueError: 消息内容结构不受支持时抛出。
|
||||
"""
|
||||
if isinstance(content, str):
|
||||
message_builder.add_text_content(content)
|
||||
return
|
||||
|
||||
content_items: List[Any]
|
||||
if isinstance(content, list):
|
||||
content_items = content
|
||||
elif isinstance(content, dict):
|
||||
content_items = [content]
|
||||
else:
|
||||
raise ValueError("消息内容必须为字符串、字典或列表")
|
||||
|
||||
for content_item in content_items:
|
||||
if isinstance(content_item, str):
|
||||
message_builder.add_text_content(content_item)
|
||||
continue
|
||||
if not isinstance(content_item, dict):
|
||||
raise ValueError("消息内容列表中仅支持字符串或字典片段")
|
||||
|
||||
part_type = str(content_item.get("type", "text")).strip().lower()
|
||||
if part_type == "text":
|
||||
text_content = content_item.get("text")
|
||||
if not isinstance(text_content, str):
|
||||
raise ValueError("文本片段缺少 `text` 字段")
|
||||
message_builder.add_text_content(text_content)
|
||||
continue
|
||||
|
||||
if part_type in {"image", "image_url", "input_image"}:
|
||||
image_url = content_item.get("image_url")
|
||||
if isinstance(image_url, dict):
|
||||
image_url = image_url.get("url")
|
||||
if isinstance(image_url, str):
|
||||
image_format, image_base64 = _parse_data_url_image(image_url)
|
||||
message_builder.add_image_content(image_format=image_format, image_base64=image_base64)
|
||||
continue
|
||||
|
||||
image_format = content_item.get("image_format")
|
||||
image_base64 = content_item.get("image_base64")
|
||||
if isinstance(image_format, str) and isinstance(image_base64, str):
|
||||
message_builder.add_image_content(image_format=image_format, image_base64=image_base64)
|
||||
continue
|
||||
raise ValueError("图片片段缺少可识别的图片数据")
|
||||
|
||||
raise ValueError(f"不支持的消息片段类型: {part_type}")
|
||||
|
||||
|
||||
def _normalize_tool_arguments(arguments: Any) -> Dict[str, Any] | None:
|
||||
"""将原始工具参数规范化为字典。
|
||||
|
||||
Args:
|
||||
arguments: 原始工具参数。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any] | None: 规范化后的参数字典。
|
||||
"""
|
||||
if arguments is None:
|
||||
return None
|
||||
if isinstance(arguments, dict):
|
||||
return arguments
|
||||
if isinstance(arguments, str):
|
||||
stripped_arguments = arguments.strip()
|
||||
if not stripped_arguments:
|
||||
return {}
|
||||
try:
|
||||
parsed_arguments = json.loads(stripped_arguments)
|
||||
except json.JSONDecodeError:
|
||||
return {"raw_arguments": arguments}
|
||||
if isinstance(parsed_arguments, dict):
|
||||
return parsed_arguments
|
||||
return {"value": parsed_arguments}
|
||||
return {"value": arguments}
|
||||
|
||||
|
||||
def _build_tool_calls(raw_tool_calls: Any) -> List[ToolCall] | None:
|
||||
"""从原始消息中提取工具调用列表。
|
||||
|
||||
Args:
|
||||
raw_tool_calls: 原始工具调用结构。
|
||||
|
||||
Returns:
|
||||
List[ToolCall] | None: 规范化后的工具调用列表。
|
||||
|
||||
Raises:
|
||||
ValueError: 工具调用结构缺失必要字段时抛出。
|
||||
"""
|
||||
if raw_tool_calls is None:
|
||||
return None
|
||||
if not isinstance(raw_tool_calls, list):
|
||||
raise ValueError("`tool_calls` 必须为列表")
|
||||
|
||||
tool_calls: List[ToolCall] = []
|
||||
for raw_tool_call in raw_tool_calls:
|
||||
if not isinstance(raw_tool_call, dict):
|
||||
raise ValueError("工具调用项必须为字典")
|
||||
|
||||
function_info = raw_tool_call.get("function")
|
||||
if isinstance(function_info, dict):
|
||||
func_name = function_info.get("name")
|
||||
arguments = function_info.get("arguments")
|
||||
else:
|
||||
func_name = raw_tool_call.get("name") or raw_tool_call.get("func_name")
|
||||
arguments = raw_tool_call.get("arguments") or raw_tool_call.get("args")
|
||||
|
||||
call_id = raw_tool_call.get("id") or raw_tool_call.get("call_id")
|
||||
if not isinstance(call_id, str) or not isinstance(func_name, str):
|
||||
raise ValueError("工具调用缺少 `id` 或函数名称")
|
||||
|
||||
tool_calls.append(
|
||||
ToolCall(
|
||||
call_id=call_id,
|
||||
func_name=func_name,
|
||||
args=_normalize_tool_arguments(arguments),
|
||||
)
|
||||
)
|
||||
return True, response, reasoning_content, model_name, tool_call
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"生成内容时出错: {str(e)}"
|
||||
logger.error(f"[LLMService] {error_msg}")
|
||||
return False, error_msg, "", "", None
|
||||
return tool_calls or None
|
||||
|
||||
|
||||
def _build_message_from_dict(raw_message: PromptMessage) -> Message:
|
||||
"""将原始消息字典转换为内部消息对象。
|
||||
|
||||
Args:
|
||||
raw_message: 原始消息字典。
|
||||
|
||||
Returns:
|
||||
Message: 规范化后的消息对象。
|
||||
|
||||
Raises:
|
||||
ValueError: 原始消息结构不合法时抛出。
|
||||
"""
|
||||
raw_role = raw_message.get("role")
|
||||
if not isinstance(raw_role, str):
|
||||
raise ValueError("消息缺少字符串类型的 `role` 字段")
|
||||
|
||||
role = _normalize_role(raw_role)
|
||||
message_builder = MessageBuilder().set_role(role)
|
||||
|
||||
tool_calls = _build_tool_calls(raw_message.get("tool_calls"))
|
||||
if tool_calls is not None:
|
||||
message_builder.set_tool_calls(tool_calls)
|
||||
|
||||
tool_call_id = raw_message.get("tool_call_id")
|
||||
if isinstance(tool_call_id, str) and role == RoleType.Tool:
|
||||
message_builder.set_tool_call_id(tool_call_id)
|
||||
|
||||
if "content" in raw_message and raw_message["content"] not in (None, "", []):
|
||||
_append_content_parts(message_builder, raw_message["content"])
|
||||
|
||||
return message_builder.build()
|
||||
|
||||
|
||||
def _build_prompt_message_factory(prompt: PromptInput) -> MessageFactory:
|
||||
"""将统一提示输入转换为消息工厂。
|
||||
|
||||
Args:
|
||||
prompt: 原始提示输入。
|
||||
|
||||
Returns:
|
||||
MessageFactory: 惰性构建消息列表的工厂函数。
|
||||
"""
|
||||
if isinstance(prompt, str):
|
||||
def build_messages(_: BaseClient) -> List[Message]:
|
||||
"""构建单条用户消息。"""
|
||||
message_builder = MessageBuilder()
|
||||
message_builder.add_text_content(prompt)
|
||||
return [message_builder.build()]
|
||||
|
||||
return build_messages
|
||||
|
||||
def build_messages(_: BaseClient) -> List[Message]:
|
||||
"""构建多消息对话输入。"""
|
||||
return [_build_message_from_dict(raw_message) for raw_message in prompt]
|
||||
|
||||
return build_messages
|
||||
|
||||
|
||||
async def generate(request: LLMServiceRequest) -> LLMServiceResult:
|
||||
"""执行统一的 LLM 服务请求。
|
||||
|
||||
Args:
|
||||
request: 服务层统一请求对象。
|
||||
|
||||
Returns:
|
||||
LLMServiceResult: 统一响应对象;失败时 `success=False`。
|
||||
"""
|
||||
llm_client = LLMServiceClient(task_name=request.task_name, request_type=request.request_type)
|
||||
if request.message_factory is not None:
|
||||
active_message_factory = request.message_factory
|
||||
else:
|
||||
prompt = request.prompt
|
||||
if prompt is None:
|
||||
raise ValueError("`prompt` 与 `message_factory` 必须且只能提供一个")
|
||||
active_message_factory = _build_prompt_message_factory(prompt)
|
||||
|
||||
try:
|
||||
generation_result = await llm_client.generate_response_with_messages(
|
||||
message_factory=active_message_factory,
|
||||
options=LLMGenerationOptions(
|
||||
temperature=request.temperature,
|
||||
max_tokens=request.max_tokens,
|
||||
tool_options=request.tool_options,
|
||||
response_format=request.response_format,
|
||||
interrupt_flag=request.interrupt_flag,
|
||||
),
|
||||
)
|
||||
return LLMServiceResult.from_response_result(generation_result)
|
||||
except Exception as exc:
|
||||
error_message = f"生成内容时出错: {exc}"
|
||||
logger.error(f"[LLMService] {error_message}")
|
||||
return LLMServiceResult.from_error(error_message, str(exc))
|
||||
|
||||
@@ -13,6 +13,8 @@ from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import CONFIG_DIR
|
||||
from src.config.model_configs import APIProvider
|
||||
from src.llm_models.openai_compat import build_openai_compatible_client_config, normalize_openai_base_url
|
||||
from src.webui.dependencies import require_auth
|
||||
from src.webui.utils.network_security import validate_public_url
|
||||
|
||||
@@ -35,8 +37,8 @@ MODEL_FETCHER_CONFIG = {
|
||||
|
||||
|
||||
def _normalize_url(url: str) -> str:
|
||||
"""规范化 URL(去掉尾部斜杠)"""
|
||||
return url.rstrip("/") if url else ""
|
||||
"""规范化 URL(去掉尾部斜杠)。"""
|
||||
return normalize_openai_base_url(url) if url else ""
|
||||
|
||||
|
||||
def _parse_openai_response(data: Dict) -> List[Dict]:
|
||||
@@ -89,19 +91,30 @@ async def _fetch_models_from_provider(
|
||||
endpoint: str,
|
||||
parser: str,
|
||||
client_type: str = "openai",
|
||||
auth_type: str = "bearer",
|
||||
auth_header_name: str = "Authorization",
|
||||
auth_header_prefix: str = "Bearer",
|
||||
auth_query_name: str = "api_key",
|
||||
default_headers: Optional[Dict[str, str]] = None,
|
||||
default_query: Optional[Dict[str, str]] = None,
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
从提供商 API 获取模型列表
|
||||
"""从提供商 API 获取模型列表。
|
||||
|
||||
Args:
|
||||
base_url: 提供商的基础 URL
|
||||
api_key: API 密钥
|
||||
endpoint: 获取模型列表的端点
|
||||
parser: 响应解析器类型 ('openai' | 'gemini')
|
||||
client_type: 客户端类型 ('openai' | 'gemini')
|
||||
base_url: 提供商的基础 URL。
|
||||
api_key: API 密钥。
|
||||
endpoint: 获取模型列表的端点。
|
||||
parser: 响应解析器类型。
|
||||
client_type: 客户端类型。
|
||||
auth_type: OpenAI 兼容接口的鉴权方式。
|
||||
auth_header_name: Header 鉴权时使用的请求头名称。
|
||||
auth_header_prefix: Header 鉴权时使用的请求头前缀。
|
||||
auth_query_name: Query 鉴权时使用的查询参数名称。
|
||||
default_headers: 默认附带的请求头。
|
||||
default_query: 默认附带的查询参数。
|
||||
|
||||
Returns:
|
||||
模型列表
|
||||
List[Dict]: 解析后的模型列表。
|
||||
"""
|
||||
try:
|
||||
base_url = validate_public_url(_normalize_url(base_url))
|
||||
@@ -118,8 +131,21 @@ async def _fetch_models_from_provider(
|
||||
# Gemini 使用 URL 参数传递 API Key
|
||||
params["key"] = api_key
|
||||
else:
|
||||
# OpenAI 兼容格式使用 Authorization 头
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
provider = APIProvider(
|
||||
name="webui-openai-compatible-fetcher",
|
||||
base_url=base_url,
|
||||
api_key=api_key,
|
||||
client_type="openai",
|
||||
auth_type=auth_type,
|
||||
auth_header_name=auth_header_name,
|
||||
auth_header_prefix=auth_header_prefix,
|
||||
auth_query_name=auth_query_name,
|
||||
default_headers=default_headers or {},
|
||||
default_query=default_query or {},
|
||||
)
|
||||
client_config = build_openai_compatible_client_config(provider)
|
||||
headers.update(client_config.default_headers)
|
||||
params.update(client_config.default_query)
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
@@ -186,10 +212,9 @@ async def get_provider_models(
|
||||
parser: str = Query("openai", description="响应解析器类型 (openai | gemini)"),
|
||||
endpoint: str = Query("/models", description="获取模型列表的端点"),
|
||||
):
|
||||
"""
|
||||
获取指定提供商的可用模型列表
|
||||
"""获取指定提供商的可用模型列表。
|
||||
|
||||
通过提供商名称查找配置,然后请求对应的模型列表端点
|
||||
通过提供商名称查找配置,然后请求对应的模型列表端点。
|
||||
"""
|
||||
# 获取提供商配置
|
||||
provider_config = _get_provider_config(provider_name)
|
||||
@@ -205,13 +230,21 @@ async def get_provider_models(
|
||||
if not api_key:
|
||||
raise HTTPException(status_code=400, detail="提供商配置缺少 api_key")
|
||||
|
||||
resolved_endpoint = provider_config.get("model_list_endpoint", endpoint) if endpoint == "/models" else endpoint
|
||||
|
||||
# 获取模型列表
|
||||
models = await _fetch_models_from_provider(
|
||||
base_url=base_url,
|
||||
api_key=api_key,
|
||||
endpoint=endpoint,
|
||||
endpoint=resolved_endpoint,
|
||||
parser=parser,
|
||||
client_type=client_type,
|
||||
auth_type=provider_config.get("auth_type", "bearer"),
|
||||
auth_header_name=provider_config.get("auth_header_name", "Authorization"),
|
||||
auth_header_prefix=provider_config.get("auth_header_prefix", "Bearer"),
|
||||
auth_query_name=provider_config.get("auth_query_name", "api_key"),
|
||||
default_headers=provider_config.get("default_headers", {}),
|
||||
default_query=provider_config.get("default_query", {}),
|
||||
)
|
||||
|
||||
return {
|
||||
@@ -229,16 +262,22 @@ async def get_models_by_url(
|
||||
parser: str = Query("openai", description="响应解析器类型 (openai | gemini)"),
|
||||
endpoint: str = Query("/models", description="获取模型列表的端点"),
|
||||
client_type: str = Query("openai", description="客户端类型 (openai | gemini)"),
|
||||
auth_type: str = Query("bearer", description="鉴权方式 (bearer | header | query | none)"),
|
||||
auth_header_name: str = Query("Authorization", description="Header 鉴权名称"),
|
||||
auth_header_prefix: str = Query("Bearer", description="Header 鉴权前缀"),
|
||||
auth_query_name: str = Query("api_key", description="Query 鉴权参数名"),
|
||||
):
|
||||
"""
|
||||
通过 URL 直接获取模型列表(用于自定义提供商)
|
||||
"""
|
||||
"""通过 URL 直接获取模型列表。"""
|
||||
models = await _fetch_models_from_provider(
|
||||
base_url=base_url,
|
||||
api_key=api_key,
|
||||
endpoint=endpoint,
|
||||
parser=parser,
|
||||
client_type=client_type,
|
||||
auth_type=auth_type,
|
||||
auth_header_name=auth_header_name,
|
||||
auth_header_prefix=auth_header_prefix,
|
||||
auth_query_name=auth_query_name,
|
||||
)
|
||||
|
||||
return {
|
||||
|
||||
Reference in New Issue
Block a user