634 lines
22 KiB
Python
634 lines
22 KiB
Python
"""LLM 服务层。
|
||
|
||
该模块负责在宿主侧收口统一的 LLM 服务请求模型,并将其转发到
|
||
`src.llm_models` 中的底层请求调度器。
|
||
"""
|
||
|
||
from typing import Any, Dict, List, Tuple
|
||
|
||
import hashlib
|
||
import inspect
|
||
import json
|
||
|
||
from src.common.data_models.embedding_service_data_models import EmbeddingResult
|
||
from src.common.data_models.llm_service_data_models import (
|
||
LLMAudioTranscriptionResult,
|
||
LLMGenerationOptions,
|
||
LLMImageOptions,
|
||
LLMResponseResult,
|
||
LLMServiceRequest,
|
||
LLMServiceResult,
|
||
MessageFactory,
|
||
PromptInput,
|
||
PromptMessage,
|
||
)
|
||
from src.common.logger import get_logger
|
||
from src.llm_models.model_client.base_client import BaseClient
|
||
from src.llm_models.payload_content.message import Message, MessageBuilder, RoleType
|
||
from src.llm_models.payload_content.tool_option import ToolCall
|
||
from src.llm_models.utils_model import LLMOrchestrator
|
||
from src.services.embedding_service import EmbeddingServiceClient
|
||
from src.services.llm_cache_stats import record_llm_cache_usage
|
||
from src.services.service_task_resolver import (
|
||
get_available_models as _get_available_models,
|
||
resolve_task_name as _resolve_task_name,
|
||
resolve_task_name_from_model_config as _resolve_task_name_from_model_config,
|
||
)
|
||
|
||
# Module-level logger requested from the project logger factory under the name "llm_service".
logger = get_logger("llm_service")
|
||
|
||
|
||
class LLMServiceClient:
    """Object-style facade over the LLM service for upper-layer modules.

    Preferred public entry points:
    - `generate_response`
    - `generate_response_with_messages`
    - `generate_response_for_image`
    - `transcribe_audio`
    - `embed_text` (compatibility entry; prefer `EmbeddingServiceClient`)
    """

    def __init__(self, task_name: str, request_type: str = "", session_id: str = "") -> None:
        """Initialise the facade and its underlying orchestrator.

        Args:
            task_name: Task configuration name (a field name under
                `model_task_config`); it is passed through `_resolve_task_name`
                before being stored.
            request_type: Business-type identifier of the current request.
            session_id: Session identifier reported with cache statistics.
                Falsy values become an empty string; surrounding whitespace
                is stripped.
        """
        self.task_name = _resolve_task_name(task_name)
        self.request_type = request_type
        self.session_id = str(session_id or "").strip()
        self._orchestrator = LLMOrchestrator(task_name=self.task_name, request_type=request_type)

    @staticmethod
    def _normalize_generation_options(options: LLMGenerationOptions | None = None) -> LLMGenerationOptions:
        """Return usable text-generation options.

        Args:
            options: Caller-supplied options, or ``None``.

        Returns:
            LLMGenerationOptions: ``options`` itself when given, otherwise a
            default-constructed instance.
        """
        if options is None:
            return LLMGenerationOptions()
        return options

    @staticmethod
    def _normalize_image_options(options: LLMImageOptions | None = None) -> LLMImageOptions:
        """Return usable image-understanding options.

        Args:
            options: Caller-supplied options, or ``None``.

        Returns:
            LLMImageOptions: ``options`` itself when given, otherwise a
            default-constructed instance.
        """
        if options is None:
            return LLMImageOptions()
        return options

    @staticmethod
    def _describe_image_for_cache_stats(image_format: Any, image_base64: str) -> Dict[str, Any]:
        """Summarise an image part without embedding its payload.

        Args:
            image_format: Image format label stored as-is in the summary.
            image_base64: Base64 text of the image.

        Returns:
            Dict[str, Any]: ``type``/``format``/``size``/``sha256`` entries,
            where ``size`` is the length of the Base64 text and ``sha256`` is
            the hex digest of its UTF-8 bytes (empty string for empty input).
        """
        image_digest = hashlib.sha256(image_base64.encode("utf-8")).hexdigest() if image_base64 else ""
        return {
            "type": "image",
            "format": image_format,
            "size": len(image_base64),
            "sha256": image_digest,
        }

    @staticmethod
    def _accepts_model_info(message_factory: MessageFactory) -> bool:
        """Tell whether a message factory can take ``model_info`` positionally.

        Only parameters that can bind a positional argument are considered, so
        a factory with a single positional parameter plus keyword-only or
        ``**kwargs`` parameters is still invoked with the client alone.

        Args:
            message_factory: Callable supplied by the caller.

        Returns:
            bool: ``True`` when the factory should be called as
            ``message_factory(client, model_info)``. ``False`` otherwise,
            including when the signature cannot be introspected.
        """
        try:
            parameters = list(inspect.signature(message_factory).parameters.values())
        except (TypeError, ValueError):
            # No introspectable signature: fall back to the single-argument call.
            return False

        positional_kinds = (
            inspect.Parameter.POSITIONAL_ONLY,
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
        )
        positional_count = sum(1 for parameter in parameters if parameter.kind in positional_kinds)
        has_var_positional = any(
            parameter.kind is inspect.Parameter.VAR_POSITIONAL for parameter in parameters
        )
        # `*args` can absorb the second argument; requiring two declared
        # parameters keeps a bare `def f(*args)` on the single-argument path.
        return positional_count >= 2 or (has_var_positional and len(parameters) >= 2)

    @staticmethod
    def _serialize_message_for_cache_stats(message: Message) -> Dict[str, Any]:
        """Convert a message into a JSON-friendly dict for cache statistics.

        Parts exposing a ``text`` attribute are stored verbatim; every other
        part is treated as an image and reduced to format, size and digest.

        Args:
            message: Message to serialise.

        Returns:
            Dict[str, Any]: Role, parts, tool-call id, tool name and tool calls.
        """
        parts: list[dict[str, Any]] = []
        for part in message.parts:
            if hasattr(part, "text"):
                parts.append({"type": "text", "text": part.text})
                continue

            parts.append(
                LLMServiceClient._describe_image_for_cache_stats(
                    image_format=getattr(part, "image_format", ""),
                    image_base64=getattr(part, "image_base64", ""),
                )
            )

        return {
            "role": str(message.role.value if hasattr(message.role, "value") else message.role),
            "parts": parts,
            "tool_call_id": message.tool_call_id,
            "tool_name": message.tool_name,
            "tool_calls": [
                {
                    "id": tool_call.call_id,
                    "name": tool_call.func_name,
                    "arguments": tool_call.args,
                    "extra_content": tool_call.extra_content,
                }
                for tool_call in (message.tool_calls or [])
            ],
        }

    @classmethod
    def _build_cache_stats_prompt_text(
        cls,
        *,
        messages: List[Message],
        tool_options: Any,
        response_format: Any,
    ) -> str:
        """Render the request as a JSON string for cache statistics.

        Args:
            messages: Messages that make up the prompt.
            tool_options: Tool options of the request; falsy values are
                rendered as an empty list.
            response_format: Response format of the request.

        Returns:
            str: JSON text with sorted keys; values JSON cannot encode are
            rendered through ``str``.
        """
        payload = {
            "messages": [cls._serialize_message_for_cache_stats(message) for message in messages],
            "tool_options": tool_options or [],
            "response_format": response_format,
        }
        return json.dumps(payload, ensure_ascii=False, sort_keys=True, default=str)

    def _record_cache_stats(self, result: LLMResponseResult, prompt_text: str | None = None) -> None:
        """Record prompt-cache statistics for the current call.

        Args:
            result: Generation result supplying the model name and the prompt
                token counters.
            prompt_text: Serialised prompt to attach, if one was built.
        """
        record_llm_cache_usage(
            task_name=self.task_name,
            request_type=self.request_type,
            model_name=result.model_name,
            session_id=self.session_id,
            prompt_tokens=result.prompt_tokens,
            prompt_cache_hit_tokens=result.prompt_cache_hit_tokens,
            prompt_cache_miss_tokens=result.prompt_cache_miss_tokens,
            prompt_text=prompt_text,
        )

    async def generate_response(
        self,
        prompt: str,
        options: LLMGenerationOptions | None = None,
    ) -> LLMResponseResult:
        """Generate a single-turn text response.

        Args:
            prompt: Text prompt.
            options: Text-generation options; defaults are used when omitted.

        Returns:
            LLMResponseResult: Unified text-generation result.
        """
        active_options = self._normalize_generation_options(options)
        prompt_text = self._build_cache_stats_prompt_text(
            messages=[MessageBuilder().add_text_content(prompt).build()],
            tool_options=active_options.tool_options,
            response_format=active_options.response_format,
        )
        result = await self._orchestrator.generate_response_async(
            prompt=prompt,
            temperature=active_options.temperature,
            max_tokens=active_options.max_tokens,
            tools=active_options.tool_options,
            response_format=active_options.response_format,
            raise_when_empty=active_options.raise_when_empty,
            interrupt_flag=active_options.interrupt_flag,
        )
        self._record_cache_stats(result, prompt_text=prompt_text)
        return result

    async def generate_response_with_messages(
        self,
        message_factory: MessageFactory,
        options: LLMGenerationOptions | None = None,
    ) -> LLMResponseResult:
        """Generate a response from messages produced by a factory.

        The factory is wrapped so that the messages it returns are also
        serialised for cache statistics.

        Args:
            message_factory: Callable building the message list. It is called
                as ``message_factory(client, model_info)`` when it can accept
                a second positional argument, otherwise as
                ``message_factory(client)``.
            options: Text-generation options; defaults are used when omitted.

        Returns:
            LLMResponseResult: Unified text-generation result.
        """
        active_options = self._normalize_generation_options(options)
        # Inspect the signature once instead of on every factory invocation.
        accepts_model_info = self._accepts_model_info(message_factory)
        prompt_text_holder: dict[str, str] = {}

        def cache_stats_message_factory(client: BaseClient, model_info: Any = None) -> List[Message]:
            """Delegate to the caller's factory and remember the serialised prompt."""
            if accepts_model_info:
                messages = message_factory(client, model_info)
            else:
                messages = message_factory(client)
            # Overwritten on each invocation, so the last built prompt is recorded.
            prompt_text_holder["prompt_text"] = self._build_cache_stats_prompt_text(
                messages=messages,
                tool_options=active_options.tool_options,
                response_format=active_options.response_format,
            )
            return messages

        result = await self._orchestrator.generate_response_with_message_async(
            message_factory=cache_stats_message_factory,
            temperature=active_options.temperature,
            max_tokens=active_options.max_tokens,
            tools=active_options.tool_options,
            response_format=active_options.response_format,
            raise_when_empty=active_options.raise_when_empty,
            interrupt_flag=active_options.interrupt_flag,
        )
        self._record_cache_stats(result, prompt_text=prompt_text_holder.get("prompt_text"))
        return result

    async def generate_response_for_image(
        self,
        prompt: str,
        image_base64: str,
        image_format: str,
        options: LLMImageOptions | None = None,
    ) -> LLMResponseResult:
        """Generate a response for an image.

        Args:
            prompt: Text prompt.
            image_base64: Base64-encoded image data.
            image_format: Image format, e.g. ``png`` or ``jpeg``.
            options: Image-understanding options; defaults are used when omitted.

        Returns:
            LLMResponseResult: Unified text-generation result.
        """
        active_options = self._normalize_image_options(options)
        # The image payload is represented by format, size and digest only.
        prompt_text = json.dumps(
            {
                "messages": [
                    {
                        "role": "user",
                        "parts": [
                            {"type": "text", "text": prompt},
                            self._describe_image_for_cache_stats(image_format, image_base64),
                        ],
                    }
                ],
                "tool_options": [],
                "response_format": None,
            },
            ensure_ascii=False,
            sort_keys=True,
        )
        result = await self._orchestrator.generate_response_for_image(
            prompt=prompt,
            image_base64=image_base64,
            image_format=image_format,
            temperature=active_options.temperature,
            max_tokens=active_options.max_tokens,
            interrupt_flag=active_options.interrupt_flag,
        )
        self._record_cache_stats(result, prompt_text=prompt_text)
        return result

    async def transcribe_audio(self, voice_base64: str) -> LLMAudioTranscriptionResult:
        """Run an audio transcription request.

        Args:
            voice_base64: Base64-encoded audio data.

        Returns:
            LLMAudioTranscriptionResult: Transcription result object.
        """
        return await self._orchestrator.generate_response_for_voice(voice_base64)

    async def embed_text(self, embedding_input: str) -> EmbeddingResult:
        """Compatibility entry point for text embedding.

        Delegates to a new `EmbeddingServiceClient` created with this
        client's task name and request type.

        Args:
            embedding_input: Text to embed.

        Returns:
            EmbeddingResult: Embedding result object.
        """
        embedding_client = EmbeddingServiceClient(
            task_name=self.task_name,
            request_type=self.request_type,
        )
        return await embedding_client.embed_text(embedding_input)
|
||
|
||
|
||
def get_available_models() -> Dict[str, Any]:
    """Return every available model configuration.

    Thin public wrapper around the task resolver's `get_available_models`.

    Returns:
        Dict[str, Any]: Configuration mapping keyed by model task name.
    """
    available_models = _get_available_models()
    return available_models
|
||
|
||
|
||
def resolve_task_name(task_name: str = "") -> str:
    """Resolve a task configuration name.

    Thin public wrapper around the task resolver's `resolve_task_name`.

    Args:
        task_name: Requested task configuration name. The original
            documentation states that an empty value yields the first
            available task name; that behaviour lives in the resolver.

    Returns:
        str: The resolved task configuration name.
    """
    resolved_name = _resolve_task_name(task_name)
    return resolved_name
|
||
|
||
|
||
def resolve_task_name_from_model_config(model_config: Any, preferred_task_name: str = "") -> str:
    """Resolve a usable task name from legacy `TaskConfig`-style arguments.

    Thin public wrapper around the task resolver's function of the same name.

    Args:
        model_config: Task configuration object held by a legacy caller.
        preferred_task_name: Optional candidate task name.

    Returns:
        str: Task name suitable for `LLMServiceRequest.task_name`.
    """
    resolver_arguments = {
        "model_config": model_config,
        "preferred_task_name": preferred_task_name,
    }
    return _resolve_task_name_from_model_config(**resolver_arguments)
|
||
|
||
|
||
def _normalize_role(role_name: str) -> RoleType:
    """Map a raw role string onto the internal role enum.

    The name is stripped and lower-cased before the enum lookup.

    Args:
        role_name: Raw role name.

    Returns:
        RoleType: The matching role enum member.

    Raises:
        ValueError: If the role is not supported; the message carries the
            original, unmodified role name.
    """
    canonical_name = role_name.strip().lower()
    try:
        role = RoleType(canonical_name)
    except ValueError as exc:
        raise ValueError(f"不支持的消息角色: {role_name}") from exc
    return role
|
||
|
||
|
||
def _parse_data_url_image(image_url: str) -> Tuple[str, str]:
|
||
"""解析 Data URL 形式的图片内容。
|
||
|
||
Args:
|
||
image_url: 图片 URL。
|
||
|
||
Returns:
|
||
Tuple[str, str]: `(图片格式, Base64 数据)`。
|
||
|
||
Raises:
|
||
ValueError: 输入不是受支持的 Data URL 时抛出。
|
||
"""
|
||
if not image_url.startswith("data:image/") or ";base64," not in image_url:
|
||
raise ValueError("仅支持 Data URL 形式的图片输入")
|
||
prefix, image_base64 = image_url.split(";base64,", maxsplit=1)
|
||
image_format = prefix.removeprefix("data:image/")
|
||
if not image_format or not image_base64:
|
||
raise ValueError("图片 Data URL 不完整")
|
||
return image_format, image_base64
|
||
|
||
|
||
def _append_image_content(message_builder: MessageBuilder, content_item: Any) -> bool:
    """Append an image part to the message builder when the item is an image.

    Two input shapes are accepted:
    1. A two-element tuple or list ``(image_format, image_base64)`` from the
       legacy serialised format.
    2. A dict whose ``type`` is ``image``, ``image_url`` or ``input_image``,
       carrying either a Data URL under ``image_url`` (directly or as
       ``{"url": ...}``) or the ``image_format``/``image_base64`` fields.

    Args:
        message_builder: Builder receiving the image part.
        content_item: Candidate content item.

    Returns:
        bool: ``True`` when an image part was appended, ``False`` when the
        item is not an image item.

    Raises:
        ValueError: If a two-element pair holds non-string values, or an
            image dict carries no recognisable image data.
    """
    if isinstance(content_item, (tuple, list)) and len(content_item) == 2:
        image_format, image_base64 = content_item
        if not (isinstance(image_format, str) and isinstance(image_base64, str)):
            raise ValueError("图片元组片段必须包含字符串类型的 image_format 和 image_base64")
    elif not isinstance(content_item, dict):
        return False
    else:
        declared_type = str(content_item.get("type", "text")).strip().lower()
        if declared_type not in {"image", "image_url", "input_image"}:
            return False

        url_field = content_item.get("image_url")
        # `image_url` may be the URL itself or a nested {"url": ...} mapping.
        url_value = url_field.get("url") if isinstance(url_field, dict) else url_field
        if isinstance(url_value, str):
            image_format, image_base64 = _parse_data_url_image(url_value)
        else:
            image_format = content_item.get("image_format")
            image_base64 = content_item.get("image_base64")
            if not (isinstance(image_format, str) and isinstance(image_base64, str)):
                raise ValueError("图片片段缺少可识别的图片数据")

    message_builder.add_image_content(image_format=image_format, image_base64=image_base64)
    return True
|
||
|
||
|
||
def _append_content_parts(message_builder: MessageBuilder, content: Any) -> None:
    """Append raw message content to the internal message builder.

    A string is appended as one text part. A dict is treated as a
    single-item list. List items may be strings, image items accepted by
    `_append_image_content`, or ``text`` dicts.

    Args:
        message_builder: Builder receiving the parts.
        content: Raw message content.

    Raises:
        ValueError: If the content or one of its items has an unsupported
            structure.
    """
    if isinstance(content, str):
        message_builder.add_text_content(content)
        return

    if isinstance(content, list):
        pending_items: List[Any] = content
    elif isinstance(content, dict):
        pending_items = [content]
    else:
        raise ValueError("消息内容必须为字符串、字典或列表")

    for item in pending_items:
        if isinstance(item, str):
            message_builder.add_text_content(item)
            continue
        if _append_image_content(message_builder, item):
            continue
        if not isinstance(item, dict):
            raise ValueError("消息内容列表中仅支持字符串、图片元组或字典片段")

        part_type = str(item.get("type", "text")).strip().lower()
        if part_type != "text":
            raise ValueError(f"不支持的消息片段类型: {part_type}")

        text_value = item.get("text")
        if not isinstance(text_value, str):
            raise ValueError("文本片段缺少 `text` 字段")
        message_builder.add_text_content(text_value)
|
||
|
||
|
||
def _normalize_tool_arguments(arguments: Any) -> Dict[str, Any] | None:
|
||
"""将原始工具参数规范化为字典。
|
||
|
||
Args:
|
||
arguments: 原始工具参数。
|
||
|
||
Returns:
|
||
Dict[str, Any] | None: 规范化后的参数字典。
|
||
"""
|
||
if arguments is None:
|
||
return None
|
||
if isinstance(arguments, dict):
|
||
return arguments
|
||
if isinstance(arguments, str):
|
||
stripped_arguments = arguments.strip()
|
||
if not stripped_arguments:
|
||
return {}
|
||
try:
|
||
parsed_arguments = json.loads(stripped_arguments)
|
||
except json.JSONDecodeError:
|
||
return {"raw_arguments": arguments}
|
||
if isinstance(parsed_arguments, dict):
|
||
return parsed_arguments
|
||
return {"value": parsed_arguments}
|
||
return {"value": arguments}
|
||
|
||
|
||
def _build_tool_calls(raw_tool_calls: Any) -> List[ToolCall] | None:
    """Extract the tool-call list from a raw message.

    Each item may carry its name and arguments either nested under a
    ``function`` dict or flat on the item (``name``/``func_name`` and
    ``arguments``/``args``). The call id is read from ``id`` or ``call_id``.

    Args:
        raw_tool_calls: Raw tool-call structure.

    Returns:
        List[ToolCall] | None: Normalised tool calls, or ``None`` when the
        input is ``None`` or yields no calls.

    Raises:
        ValueError: If the structure is not a list of dicts, or an item lacks
            a string id or a string function name.
    """
    if raw_tool_calls is None:
        return None
    if not isinstance(raw_tool_calls, list):
        raise ValueError("`tool_calls` 必须为列表")

    tool_calls: List[ToolCall] = []
    for raw_tool_call in raw_tool_calls:
        if not isinstance(raw_tool_call, dict):
            raise ValueError("工具调用项必须为字典")

        function_info = raw_tool_call.get("function")
        if isinstance(function_info, dict):
            func_name = function_info.get("name")
            arguments = function_info.get("arguments")
        else:
            func_name = raw_tool_call.get("name") or raw_tool_call.get("func_name")
            arguments = raw_tool_call.get("arguments")
            if not arguments:
                # Prefer the legacy `args` key only when it actually holds a
                # value, so an explicit empty `arguments` ({} or "") is kept
                # rather than collapsing to None.
                legacy_arguments = raw_tool_call.get("args")
                if legacy_arguments is not None:
                    arguments = legacy_arguments

        call_id = raw_tool_call.get("id") or raw_tool_call.get("call_id")
        if not isinstance(call_id, str) or not isinstance(func_name, str):
            raise ValueError("工具调用缺少 `id` 或函数名称")

        extra_content = raw_tool_call.get("extra_content")
        tool_calls.append(
            ToolCall(
                call_id=call_id,
                func_name=func_name,
                args=_normalize_tool_arguments(arguments),
                # Non-dict extra content is dropped.
                extra_content=extra_content if isinstance(extra_content, dict) else None,
            )
        )

    return tool_calls or None
|
||
|
||
|
||
def _build_message_from_dict(raw_message: PromptMessage) -> Message:
    """Convert a raw message dict into an internal message object.

    Args:
        raw_message: Raw message dict.

    Returns:
        Message: The normalised message.

    Raises:
        ValueError: If the raw message structure is invalid.
    """
    role_text = raw_message.get("role")
    if not isinstance(role_text, str):
        raise ValueError("消息缺少字符串类型的 `role` 字段")

    role = _normalize_role(role_text)
    message_builder = MessageBuilder().set_role(role)

    parsed_tool_calls = _build_tool_calls(raw_message.get("tool_calls"))
    if parsed_tool_calls is not None:
        message_builder.set_tool_calls(parsed_tool_calls)

    # A tool-call id is only attached to messages with the tool role.
    linked_call_id = raw_message.get("tool_call_id")
    if role == RoleType.Tool and isinstance(linked_call_id, str):
        message_builder.set_tool_call_id(linked_call_id)

    # Missing, None, empty-string and empty-list content add no parts.
    raw_content = raw_message.get("content")
    if raw_content not in (None, "", []):
        _append_content_parts(message_builder, raw_content)

    return message_builder.build()
|
||
|
||
|
||
def _build_prompt_message_factory(prompt: PromptInput) -> MessageFactory:
    """Turn a unified prompt input into a message factory.

    Args:
        prompt: Raw prompt input, either a plain string or an iterable of
            raw message dicts.

    Returns:
        MessageFactory: Factory that builds the message list when called;
        its client argument is ignored.
    """
    if not isinstance(prompt, str):

        def build_messages(_: BaseClient) -> List[Message]:
            """Build the multi-message conversation input."""
            converted: List[Message] = []
            for raw_message in prompt:
                converted.append(_build_message_from_dict(raw_message))
            return converted

        return build_messages

    def build_messages(_: BaseClient) -> List[Message]:
        """Build a single message holding the prompt text."""
        single_builder = MessageBuilder()
        single_builder.add_text_content(prompt)
        return [single_builder.build()]

    return build_messages
|
||
|
||
|
||
async def generate(request: LLMServiceRequest) -> LLMServiceResult:
    """Execute a unified LLM service request.

    Args:
        request: Unified service-layer request object.

    Returns:
        LLMServiceResult: Unified response object. Exceptions raised during
        generation are logged and reported as an error result.

    Raises:
        ValueError: If the request has neither `message_factory` nor `prompt`.
    """
    service_client = LLMServiceClient(task_name=request.task_name, request_type=request.request_type)

    # An explicit message factory takes precedence over the prompt.
    message_factory = request.message_factory
    if message_factory is None:
        prompt = request.prompt
        if prompt is None:
            raise ValueError("`prompt` 与 `message_factory` 必须且只能提供一个")
        message_factory = _build_prompt_message_factory(prompt)

    try:
        generation_options = LLMGenerationOptions(
            temperature=request.temperature,
            max_tokens=request.max_tokens,
            tool_options=request.tool_options,
            response_format=request.response_format,
            interrupt_flag=request.interrupt_flag,
        )
        response = await service_client.generate_response_with_messages(
            message_factory=message_factory,
            options=generation_options,
        )
        return LLMServiceResult.from_response_result(response)
    except Exception as exc:
        error_message = f"生成内容时出错: {exc}"
        logger.error(f"[LLMService] {error_message}")
        return LLMServiceResult.from_error(error_message, str(exc))
|