Merge remote-tracking branch 'upstream/r-dev' into sync/pr-1564-upstream-20260331
# Conflicts: # src/chat/brain_chat/PFC/conversation.py # src/chat/brain_chat/PFC/pfc_KnowledgeFetcher.py # src/chat/knowledge/lpmm_ops.py
This commit is contained in:
@@ -4,14 +4,14 @@ import json
|
||||
import time
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from sqlalchemy import delete, func, select
|
||||
from sqlmodel import SQLModel
|
||||
|
||||
from src.chat.message_receive.chat_manager import BotChatSession
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import ActionRecord
|
||||
from src.common.database.database_model import ToolRecord
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("database_service")
|
||||
@@ -65,7 +65,7 @@ async def db_save(
|
||||
record = None
|
||||
if key_field and key_value is not None:
|
||||
key_column = _get_model_field(model_class, key_field)
|
||||
record = session.exec(select(model_class).where(key_column == key_value)).first()
|
||||
record = session.exec(cast(Any, select(model_class).where(key_column == key_value))).first()
|
||||
|
||||
if record is None:
|
||||
record = model_class(**data)
|
||||
@@ -99,7 +99,7 @@ async def db_get(
|
||||
statement = _apply_order_by(statement, model_class, order_by)
|
||||
if limit:
|
||||
statement = statement.limit(limit)
|
||||
results = session.exec(statement).all()
|
||||
results = session.exec(cast(Any, statement)).all()
|
||||
data = [_to_dict(item) for item in results]
|
||||
if single_result:
|
||||
return data[0] if data else None
|
||||
@@ -116,7 +116,7 @@ async def db_update(model_class: type[SQLModel], data: dict[str, Any], filters:
|
||||
statement = select(model_class)
|
||||
if conditions := _build_filters(model_class, filters):
|
||||
statement = statement.where(*conditions)
|
||||
records = session.exec(statement).all()
|
||||
records = session.exec(cast(Any, statement)).all()
|
||||
for record in records:
|
||||
for field_name, value in data.items():
|
||||
_get_model_field(model_class, field_name)
|
||||
@@ -149,7 +149,7 @@ async def db_count(model_class: type[SQLModel], filters: Optional[dict[str, Any]
|
||||
statement = select(func.count()).select_from(model_class)
|
||||
if conditions := _build_filters(model_class, filters):
|
||||
statement = statement.where(*conditions)
|
||||
result = session.exec(statement).one()
|
||||
result = session.exec(cast(Any, statement)).one()
|
||||
return int(result or 0)
|
||||
except Exception as e:
|
||||
logger.error(f"[DatabaseService] 统计数据库记录出错: {e}")
|
||||
@@ -157,6 +157,39 @@ async def db_count(model_class: type[SQLModel], filters: Optional[dict[str, Any]
|
||||
return 0
|
||||
|
||||
|
||||
async def store_tool_info(
|
||||
chat_stream: BotChatSession,
|
||||
builtin_prompt: Optional[str] = None,
|
||||
display_prompt: str = "",
|
||||
tool_id: str = "",
|
||||
tool_data: Optional[dict[str, Any]] = None,
|
||||
tool_name: str = "",
|
||||
tool_reasoning: str = "",
|
||||
) -> Optional[dict[str, Any]]:
|
||||
try:
|
||||
record_data = {
|
||||
"tool_id": tool_id or str(int(time.time() * 1000000)),
|
||||
"timestamp": datetime.now(),
|
||||
"session_id": chat_stream.session_id,
|
||||
"tool_name": tool_name,
|
||||
"tool_data": json.dumps(tool_data or {}, ensure_ascii=False),
|
||||
"tool_reasoning": tool_reasoning,
|
||||
"tool_builtin_prompt": builtin_prompt,
|
||||
"tool_display_prompt": display_prompt,
|
||||
}
|
||||
|
||||
saved_record = await db_save(ToolRecord, data=record_data, key_field="tool_id", key_value=record_data["tool_id"])
|
||||
if saved_record:
|
||||
logger.debug(f"[DatabaseService] 成功存储工具信息: {tool_name} (ID: {record_data['tool_id']})")
|
||||
else:
|
||||
logger.error(f"[DatabaseService] 存储工具信息失败: {tool_name}")
|
||||
return saved_record
|
||||
except Exception as e:
|
||||
logger.error(f"[DatabaseService] 存储工具信息时发生错误: {e}")
|
||||
traceback.print_exc()
|
||||
return None
|
||||
|
||||
|
||||
async def store_action_info(
|
||||
chat_stream: BotChatSession,
|
||||
builtin_prompt: Optional[str] = None,
|
||||
@@ -166,27 +199,13 @@ async def store_action_info(
|
||||
action_name: str = "",
|
||||
action_reasoning: str = "",
|
||||
) -> Optional[dict[str, Any]]:
|
||||
try:
|
||||
record_data = {
|
||||
"action_id": thinking_id or str(int(time.time() * 1000000)),
|
||||
"timestamp": datetime.now(),
|
||||
"session_id": chat_stream.session_id,
|
||||
"action_name": action_name,
|
||||
"action_data": json.dumps(action_data or {}, ensure_ascii=False),
|
||||
"action_reasoning": action_reasoning,
|
||||
"action_builtin_prompt": builtin_prompt,
|
||||
"action_display_prompt": display_prompt,
|
||||
}
|
||||
|
||||
saved_record = await db_save(
|
||||
ActionRecord, data=record_data, key_field="action_id", key_value=record_data["action_id"]
|
||||
)
|
||||
if saved_record:
|
||||
logger.debug(f"[DatabaseService] 成功存储动作信息: {action_name} (ID: {record_data['action_id']})")
|
||||
else:
|
||||
logger.error(f"[DatabaseService] 存储动作信息失败: {action_name}")
|
||||
return saved_record
|
||||
except Exception as e:
|
||||
logger.error(f"[DatabaseService] 存储动作信息时发生错误: {e}")
|
||||
traceback.print_exc()
|
||||
return None
|
||||
"""兼容旧接口,内部转发到 ``store_tool_info``。"""
|
||||
return await store_tool_info(
|
||||
chat_stream=chat_stream,
|
||||
builtin_prompt=builtin_prompt,
|
||||
display_prompt=display_prompt,
|
||||
tool_id=thinking_id,
|
||||
tool_data=action_data,
|
||||
tool_name=action_name,
|
||||
tool_reasoning=action_reasoning,
|
||||
)
|
||||
|
||||
@@ -1,191 +1,492 @@
|
||||
"""LLM 服务模块
|
||||
"""LLM 服务层。
|
||||
|
||||
提供与 LLM 模型交互的核心功能。
|
||||
该模块负责在宿主侧收口统一的 LLM 服务请求模型,并将其转发到
|
||||
`src.llm_models` 中的底层请求调度器。
|
||||
"""
|
||||
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
import json
|
||||
|
||||
from src.common.data_models.llm_service_data_models import (
|
||||
LLMAudioTranscriptionResult,
|
||||
LLMEmbeddingResult,
|
||||
LLMGenerationOptions,
|
||||
LLMImageOptions,
|
||||
LLMResponseResult,
|
||||
LLMServiceRequest,
|
||||
LLMServiceResult,
|
||||
MessageFactory,
|
||||
PromptInput,
|
||||
PromptMessage,
|
||||
)
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import config_manager
|
||||
from src.config.model_configs import TaskConfig
|
||||
from src.llm_models.model_client.base_client import BaseClient
|
||||
from src.llm_models.payload_content.message import Message
|
||||
from src.llm_models.payload_content.message import Message, MessageBuilder, RoleType
|
||||
from src.llm_models.payload_content.tool_option import ToolCall
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.llm_models.utils_model import LLMOrchestrator
|
||||
|
||||
logger = get_logger("llm_service")
|
||||
|
||||
class LLMServiceClient:
|
||||
"""面向上层模块的 LLM 服务对象式门面。
|
||||
|
||||
async def _generate_response(
|
||||
model_config: TaskConfig,
|
||||
request_type: str,
|
||||
prompt: Optional[str] = None,
|
||||
message_factory: Optional[Callable[[BaseClient], List[Message]]] = None,
|
||||
tool_options: Optional[List[Dict[str, Any]]] = None,
|
||||
temperature: Optional[float] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
) -> Tuple[str, str, str, List[ToolCall] | None]:
|
||||
llm_request = LLMRequest(model_set=model_config, request_type=request_type)
|
||||
当前推荐优先使用以下正式接口:
|
||||
- `generate_response`
|
||||
- `generate_response_with_messages`
|
||||
- `generate_response_for_image`
|
||||
- `transcribe_audio`
|
||||
- `embed_text`
|
||||
"""
|
||||
|
||||
if message_factory is not None:
|
||||
response, (reasoning_content, model_name, tool_call) = await llm_request.generate_response_with_message_async(
|
||||
message_factory=message_factory,
|
||||
tools=tool_options,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
def __init__(self, task_name: str, request_type: str = "") -> None:
|
||||
"""初始化 LLM 服务门面。
|
||||
|
||||
Args:
|
||||
task_name: 任务配置名称,对应 `model_task_config` 下的字段名。
|
||||
request_type: 当前请求的业务类型标识。
|
||||
"""
|
||||
self.task_name = resolve_task_name(task_name)
|
||||
self.request_type = request_type
|
||||
self._orchestrator = LLMOrchestrator(task_name=self.task_name, request_type=request_type)
|
||||
|
||||
@staticmethod
|
||||
def _normalize_generation_options(options: LLMGenerationOptions | None = None) -> LLMGenerationOptions:
|
||||
"""规范化文本生成选项。
|
||||
|
||||
Args:
|
||||
options: 原始生成选项。
|
||||
|
||||
Returns:
|
||||
LLMGenerationOptions: 可直接用于执行请求的完整选项对象。
|
||||
"""
|
||||
if options is None:
|
||||
return LLMGenerationOptions()
|
||||
return options
|
||||
|
||||
@staticmethod
|
||||
def _normalize_image_options(options: LLMImageOptions | None = None) -> LLMImageOptions:
|
||||
"""规范化图像理解选项。
|
||||
|
||||
Args:
|
||||
options: 原始图像理解选项。
|
||||
|
||||
Returns:
|
||||
LLMImageOptions: 可直接用于执行请求的完整选项对象。
|
||||
"""
|
||||
if options is None:
|
||||
return LLMImageOptions()
|
||||
return options
|
||||
|
||||
async def generate_response(
|
||||
self,
|
||||
prompt: str,
|
||||
options: LLMGenerationOptions | None = None,
|
||||
) -> LLMResponseResult:
|
||||
"""生成单轮文本响应。
|
||||
|
||||
Args:
|
||||
prompt: 文本提示词。
|
||||
options: 文本生成选项。
|
||||
|
||||
Returns:
|
||||
LLMResponseResult: 统一文本生成结果。
|
||||
"""
|
||||
active_options = self._normalize_generation_options(options)
|
||||
return await self._orchestrator.generate_response_async(
|
||||
prompt=prompt,
|
||||
temperature=active_options.temperature,
|
||||
max_tokens=active_options.max_tokens,
|
||||
tools=active_options.tool_options,
|
||||
response_format=active_options.response_format,
|
||||
raise_when_empty=active_options.raise_when_empty,
|
||||
interrupt_flag=active_options.interrupt_flag,
|
||||
)
|
||||
return response, reasoning_content, model_name, tool_call
|
||||
|
||||
if prompt is None:
|
||||
raise ValueError("prompt 与 message_factory 不能同时为空")
|
||||
async def generate_response_with_messages(
|
||||
self,
|
||||
message_factory: MessageFactory,
|
||||
options: LLMGenerationOptions | None = None,
|
||||
) -> LLMResponseResult:
|
||||
"""基于消息工厂生成响应。
|
||||
|
||||
response, (reasoning_content, model_name, tool_call) = await llm_request.generate_response_async(
|
||||
prompt,
|
||||
tools=tool_options,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
return response, reasoning_content, model_name, tool_call
|
||||
Args:
|
||||
message_factory: 消息工厂,会根据客户端能力构建消息列表。
|
||||
options: 文本生成选项。
|
||||
|
||||
Returns:
|
||||
LLMResponseResult: 统一文本生成结果。
|
||||
"""
|
||||
active_options = self._normalize_generation_options(options)
|
||||
return await self._orchestrator.generate_response_with_message_async(
|
||||
message_factory=message_factory,
|
||||
temperature=active_options.temperature,
|
||||
max_tokens=active_options.max_tokens,
|
||||
tools=active_options.tool_options,
|
||||
response_format=active_options.response_format,
|
||||
raise_when_empty=active_options.raise_when_empty,
|
||||
interrupt_flag=active_options.interrupt_flag,
|
||||
)
|
||||
|
||||
async def generate_response_for_image(
|
||||
self,
|
||||
prompt: str,
|
||||
image_base64: str,
|
||||
image_format: str,
|
||||
options: LLMImageOptions | None = None,
|
||||
) -> LLMResponseResult:
|
||||
"""为图像内容生成响应。
|
||||
|
||||
Args:
|
||||
prompt: 文本提示词。
|
||||
image_base64: 图像的 Base64 编码字符串。
|
||||
image_format: 图像格式,例如 ``png``、``jpeg``。
|
||||
options: 图像理解选项。
|
||||
|
||||
Returns:
|
||||
LLMResponseResult: 统一文本生成结果。
|
||||
"""
|
||||
active_options = self._normalize_image_options(options)
|
||||
return await self._orchestrator.generate_response_for_image(
|
||||
prompt=prompt,
|
||||
image_base64=image_base64,
|
||||
image_format=image_format,
|
||||
temperature=active_options.temperature,
|
||||
max_tokens=active_options.max_tokens,
|
||||
interrupt_flag=active_options.interrupt_flag,
|
||||
)
|
||||
|
||||
async def transcribe_audio(self, voice_base64: str) -> LLMAudioTranscriptionResult:
|
||||
"""执行音频转写请求。
|
||||
|
||||
Args:
|
||||
voice_base64: 音频的 Base64 编码字符串。
|
||||
|
||||
Returns:
|
||||
LLMAudioTranscriptionResult: 音频转写结果对象。
|
||||
"""
|
||||
return await self._orchestrator.generate_response_for_voice(voice_base64)
|
||||
|
||||
async def embed_text(self, embedding_input: str) -> LLMEmbeddingResult:
|
||||
"""生成文本嵌入向量。
|
||||
|
||||
Args:
|
||||
embedding_input: 待编码的文本。
|
||||
|
||||
Returns:
|
||||
LLMEmbeddingResult: 向量生成结果对象。
|
||||
"""
|
||||
return await self._orchestrator.get_embedding(embedding_input)
|
||||
|
||||
|
||||
def get_available_models() -> Dict[str, TaskConfig]:
|
||||
"""获取所有可用的模型配置
|
||||
"""获取所有可用模型配置。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 模型配置字典,key为模型名称,value为模型配置
|
||||
Dict[str, TaskConfig]: 以模型任务名为键的配置映射。
|
||||
"""
|
||||
try:
|
||||
models = config_manager.get_model_config().model_task_config
|
||||
attrs = dir(models)
|
||||
rets: Dict[str, TaskConfig] = {}
|
||||
for attr in attrs:
|
||||
if not attr.startswith("__"):
|
||||
try:
|
||||
value = getattr(models, attr)
|
||||
if not callable(value) and isinstance(value, TaskConfig):
|
||||
rets[attr] = value
|
||||
except Exception as e:
|
||||
logger.debug(f"[LLMService] 获取属性 {attr} 失败: {e}")
|
||||
continue
|
||||
return rets
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[LLMService] 获取可用模型失败: {e}")
|
||||
available_models: Dict[str, TaskConfig] = {}
|
||||
for attr_name in dir(models):
|
||||
if attr_name.startswith("__"):
|
||||
continue
|
||||
try:
|
||||
attr_value = getattr(models, attr_name)
|
||||
except Exception as exc:
|
||||
logger.debug(f"[LLMService] 获取属性 {attr_name} 失败: {exc}")
|
||||
continue
|
||||
if not callable(attr_value) and isinstance(attr_value, TaskConfig):
|
||||
available_models[attr_name] = attr_value
|
||||
return available_models
|
||||
except Exception as exc:
|
||||
logger.error(f"[LLMService] 获取可用模型失败: {exc}")
|
||||
return {}
|
||||
|
||||
|
||||
async def generate_with_model(
|
||||
prompt: str,
|
||||
model_config: TaskConfig,
|
||||
request_type: str = "plugin.generate",
|
||||
temperature: Optional[float] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
) -> Tuple[bool, str, str, str]:
|
||||
"""使用指定模型生成内容
|
||||
def resolve_task_name(task_name: str = "") -> str:
|
||||
"""根据名称解析任务配置名。
|
||||
|
||||
Args:
|
||||
prompt: 提示词
|
||||
model_config: 模型配置(从 get_available_models 获取的模型配置)
|
||||
request_type: 请求类型标识
|
||||
task_name: 目标任务配置名;为空时返回首个可用任务名。
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str, str, str]: (是否成功, 生成的内容, 推理过程, 模型名称)
|
||||
str: 解析得到的任务配置名。
|
||||
|
||||
Raises:
|
||||
RuntimeError: 当前没有任何可用模型配置。
|
||||
ValueError: 指定名称不存在时抛出。
|
||||
"""
|
||||
try:
|
||||
logger.debug(f"[LLMService] 完整提示词: {prompt}")
|
||||
response, reasoning_content, model_name, _ = await _generate_response(
|
||||
model_config=model_config,
|
||||
request_type=request_type,
|
||||
prompt=prompt,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
return True, response, reasoning_content, model_name
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"生成内容时出错: {str(e)}"
|
||||
logger.error(f"[LLMService] {error_msg}")
|
||||
return False, error_msg, "", ""
|
||||
models = get_available_models()
|
||||
if not models:
|
||||
raise RuntimeError("没有可用的模型配置")
|
||||
normalized_task_name = task_name.strip()
|
||||
if not normalized_task_name:
|
||||
return next(iter(models.keys()))
|
||||
if normalized_task_name not in models:
|
||||
raise ValueError(f"未找到名为 `{normalized_task_name}` 的模型配置")
|
||||
return normalized_task_name
|
||||
|
||||
|
||||
async def generate_with_model_with_tools(
|
||||
prompt: str,
|
||||
model_config: TaskConfig,
|
||||
tool_options: List[Dict[str, Any]] | None = None,
|
||||
request_type: str = "plugin.generate",
|
||||
temperature: Optional[float] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
) -> Tuple[bool, str, str, str, List[ToolCall] | None]:
|
||||
"""使用指定模型和工具生成内容
|
||||
def _normalize_role(role_name: str) -> RoleType:
|
||||
"""将原始角色字符串转换为内部角色枚举。
|
||||
|
||||
Args:
|
||||
prompt: 提示词
|
||||
model_config: 模型配置(从 get_available_models 获取的模型配置)
|
||||
tool_options: 工具选项列表
|
||||
request_type: 请求类型标识
|
||||
temperature: 温度参数
|
||||
max_tokens: 最大token数
|
||||
role_name: 原始角色名称。
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str, str, str, List[ToolCall] | None]: (是否成功, 生成的内容, 推理过程, 模型名称, 工具调用列表)
|
||||
RoleType: 规范化后的角色枚举。
|
||||
|
||||
Raises:
|
||||
ValueError: 角色类型不受支持时抛出。
|
||||
"""
|
||||
normalized_role_name = role_name.strip().lower()
|
||||
try:
|
||||
model_name_list = model_config.model_list
|
||||
logger.info(f"使用模型{model_name_list}生成内容")
|
||||
logger.debug(f"完整提示词: {prompt}")
|
||||
|
||||
response, reasoning_content, model_name, tool_call = await _generate_response(
|
||||
model_config=model_config,
|
||||
request_type=request_type,
|
||||
prompt=prompt,
|
||||
tool_options=tool_options,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
return True, response, reasoning_content, model_name, tool_call
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"生成内容时出错: {str(e)}"
|
||||
logger.error(f"[LLMService] {error_msg}")
|
||||
return False, error_msg, "", "", None
|
||||
return RoleType(normalized_role_name)
|
||||
except ValueError as exc:
|
||||
raise ValueError(f"不支持的消息角色: {role_name}") from exc
|
||||
|
||||
|
||||
async def generate_with_model_with_tools_by_message_factory(
|
||||
message_factory: Callable[[BaseClient], List[Message]],
|
||||
model_config: TaskConfig,
|
||||
tool_options: List[Dict[str, Any]] | None = None,
|
||||
request_type: str = "plugin.generate",
|
||||
temperature: Optional[float] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
) -> Tuple[bool, str, str, str, List[ToolCall] | None]:
|
||||
"""使用指定模型和工具生成内容(通过消息工厂构建消息列表)
|
||||
def _parse_data_url_image(image_url: str) -> Tuple[str, str]:
|
||||
"""解析 Data URL 形式的图片内容。
|
||||
|
||||
Args:
|
||||
message_factory: 消息工厂函数
|
||||
model_config: 模型配置
|
||||
tool_options: 工具选项列表
|
||||
request_type: 请求类型标识
|
||||
temperature: 温度参数
|
||||
max_tokens: 最大token数
|
||||
image_url: 图片 URL。
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str, str, str, List[ToolCall] | None]: (是否成功, 生成的内容, 推理过程, 模型名称, 工具调用列表)
|
||||
Tuple[str, str]: `(图片格式, Base64 数据)`。
|
||||
|
||||
Raises:
|
||||
ValueError: 输入不是受支持的 Data URL 时抛出。
|
||||
"""
|
||||
try:
|
||||
model_name_list = model_config.model_list
|
||||
logger.info(f"使用模型 {model_name_list} 生成内容")
|
||||
if not image_url.startswith("data:image/") or ";base64," not in image_url:
|
||||
raise ValueError("仅支持 Data URL 形式的图片输入")
|
||||
prefix, image_base64 = image_url.split(";base64,", maxsplit=1)
|
||||
image_format = prefix.removeprefix("data:image/")
|
||||
if not image_format or not image_base64:
|
||||
raise ValueError("图片 Data URL 不完整")
|
||||
return image_format, image_base64
|
||||
|
||||
response, reasoning_content, model_name, tool_call = await _generate_response(
|
||||
model_config=model_config,
|
||||
request_type=request_type,
|
||||
message_factory=message_factory,
|
||||
tool_options=tool_options,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
|
||||
def _append_content_parts(message_builder: MessageBuilder, content: Any) -> None:
|
||||
"""将原始消息内容追加到内部消息构建器。
|
||||
|
||||
Args:
|
||||
message_builder: 目标消息构建器。
|
||||
content: 原始消息内容。
|
||||
|
||||
Raises:
|
||||
ValueError: 消息内容结构不受支持时抛出。
|
||||
"""
|
||||
if isinstance(content, str):
|
||||
message_builder.add_text_content(content)
|
||||
return
|
||||
|
||||
content_items: List[Any]
|
||||
if isinstance(content, list):
|
||||
content_items = content
|
||||
elif isinstance(content, dict):
|
||||
content_items = [content]
|
||||
else:
|
||||
raise ValueError("消息内容必须为字符串、字典或列表")
|
||||
|
||||
for content_item in content_items:
|
||||
if isinstance(content_item, str):
|
||||
message_builder.add_text_content(content_item)
|
||||
continue
|
||||
if not isinstance(content_item, dict):
|
||||
raise ValueError("消息内容列表中仅支持字符串或字典片段")
|
||||
|
||||
part_type = str(content_item.get("type", "text")).strip().lower()
|
||||
if part_type == "text":
|
||||
text_content = content_item.get("text")
|
||||
if not isinstance(text_content, str):
|
||||
raise ValueError("文本片段缺少 `text` 字段")
|
||||
message_builder.add_text_content(text_content)
|
||||
continue
|
||||
|
||||
if part_type in {"image", "image_url", "input_image"}:
|
||||
image_url = content_item.get("image_url")
|
||||
if isinstance(image_url, dict):
|
||||
image_url = image_url.get("url")
|
||||
if isinstance(image_url, str):
|
||||
image_format, image_base64 = _parse_data_url_image(image_url)
|
||||
message_builder.add_image_content(image_format=image_format, image_base64=image_base64)
|
||||
continue
|
||||
|
||||
image_format = content_item.get("image_format")
|
||||
image_base64 = content_item.get("image_base64")
|
||||
if isinstance(image_format, str) and isinstance(image_base64, str):
|
||||
message_builder.add_image_content(image_format=image_format, image_base64=image_base64)
|
||||
continue
|
||||
raise ValueError("图片片段缺少可识别的图片数据")
|
||||
|
||||
raise ValueError(f"不支持的消息片段类型: {part_type}")
|
||||
|
||||
|
||||
def _normalize_tool_arguments(arguments: Any) -> Dict[str, Any] | None:
|
||||
"""将原始工具参数规范化为字典。
|
||||
|
||||
Args:
|
||||
arguments: 原始工具参数。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any] | None: 规范化后的参数字典。
|
||||
"""
|
||||
if arguments is None:
|
||||
return None
|
||||
if isinstance(arguments, dict):
|
||||
return arguments
|
||||
if isinstance(arguments, str):
|
||||
stripped_arguments = arguments.strip()
|
||||
if not stripped_arguments:
|
||||
return {}
|
||||
try:
|
||||
parsed_arguments = json.loads(stripped_arguments)
|
||||
except json.JSONDecodeError:
|
||||
return {"raw_arguments": arguments}
|
||||
if isinstance(parsed_arguments, dict):
|
||||
return parsed_arguments
|
||||
return {"value": parsed_arguments}
|
||||
return {"value": arguments}
|
||||
|
||||
|
||||
def _build_tool_calls(raw_tool_calls: Any) -> List[ToolCall] | None:
|
||||
"""从原始消息中提取工具调用列表。
|
||||
|
||||
Args:
|
||||
raw_tool_calls: 原始工具调用结构。
|
||||
|
||||
Returns:
|
||||
List[ToolCall] | None: 规范化后的工具调用列表。
|
||||
|
||||
Raises:
|
||||
ValueError: 工具调用结构缺失必要字段时抛出。
|
||||
"""
|
||||
if raw_tool_calls is None:
|
||||
return None
|
||||
if not isinstance(raw_tool_calls, list):
|
||||
raise ValueError("`tool_calls` 必须为列表")
|
||||
|
||||
tool_calls: List[ToolCall] = []
|
||||
for raw_tool_call in raw_tool_calls:
|
||||
if not isinstance(raw_tool_call, dict):
|
||||
raise ValueError("工具调用项必须为字典")
|
||||
|
||||
function_info = raw_tool_call.get("function")
|
||||
if isinstance(function_info, dict):
|
||||
func_name = function_info.get("name")
|
||||
arguments = function_info.get("arguments")
|
||||
else:
|
||||
func_name = raw_tool_call.get("name") or raw_tool_call.get("func_name")
|
||||
arguments = raw_tool_call.get("arguments") or raw_tool_call.get("args")
|
||||
|
||||
call_id = raw_tool_call.get("id") or raw_tool_call.get("call_id")
|
||||
if not isinstance(call_id, str) or not isinstance(func_name, str):
|
||||
raise ValueError("工具调用缺少 `id` 或函数名称")
|
||||
|
||||
tool_calls.append(
|
||||
ToolCall(
|
||||
call_id=call_id,
|
||||
func_name=func_name,
|
||||
args=_normalize_tool_arguments(arguments),
|
||||
)
|
||||
)
|
||||
return True, response, reasoning_content, model_name, tool_call
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"生成内容时出错: {str(e)}"
|
||||
logger.error(f"[LLMService] {error_msg}")
|
||||
return False, error_msg, "", "", None
|
||||
return tool_calls or None
|
||||
|
||||
|
||||
def _build_message_from_dict(raw_message: PromptMessage) -> Message:
|
||||
"""将原始消息字典转换为内部消息对象。
|
||||
|
||||
Args:
|
||||
raw_message: 原始消息字典。
|
||||
|
||||
Returns:
|
||||
Message: 规范化后的消息对象。
|
||||
|
||||
Raises:
|
||||
ValueError: 原始消息结构不合法时抛出。
|
||||
"""
|
||||
raw_role = raw_message.get("role")
|
||||
if not isinstance(raw_role, str):
|
||||
raise ValueError("消息缺少字符串类型的 `role` 字段")
|
||||
|
||||
role = _normalize_role(raw_role)
|
||||
message_builder = MessageBuilder().set_role(role)
|
||||
|
||||
tool_calls = _build_tool_calls(raw_message.get("tool_calls"))
|
||||
if tool_calls is not None:
|
||||
message_builder.set_tool_calls(tool_calls)
|
||||
|
||||
tool_call_id = raw_message.get("tool_call_id")
|
||||
if isinstance(tool_call_id, str) and role == RoleType.Tool:
|
||||
message_builder.set_tool_call_id(tool_call_id)
|
||||
|
||||
if "content" in raw_message and raw_message["content"] not in (None, "", []):
|
||||
_append_content_parts(message_builder, raw_message["content"])
|
||||
|
||||
return message_builder.build()
|
||||
|
||||
|
||||
def _build_prompt_message_factory(prompt: PromptInput) -> MessageFactory:
|
||||
"""将统一提示输入转换为消息工厂。
|
||||
|
||||
Args:
|
||||
prompt: 原始提示输入。
|
||||
|
||||
Returns:
|
||||
MessageFactory: 惰性构建消息列表的工厂函数。
|
||||
"""
|
||||
if isinstance(prompt, str):
|
||||
def build_messages(_: BaseClient) -> List[Message]:
|
||||
"""构建单条用户消息。"""
|
||||
message_builder = MessageBuilder()
|
||||
message_builder.add_text_content(prompt)
|
||||
return [message_builder.build()]
|
||||
|
||||
return build_messages
|
||||
|
||||
def build_messages(_: BaseClient) -> List[Message]:
|
||||
"""构建多消息对话输入。"""
|
||||
return [_build_message_from_dict(raw_message) for raw_message in prompt]
|
||||
|
||||
return build_messages
|
||||
|
||||
|
||||
async def generate(request: LLMServiceRequest) -> LLMServiceResult:
|
||||
"""执行统一的 LLM 服务请求。
|
||||
|
||||
Args:
|
||||
request: 服务层统一请求对象。
|
||||
|
||||
Returns:
|
||||
LLMServiceResult: 统一响应对象;失败时 `success=False`。
|
||||
"""
|
||||
llm_client = LLMServiceClient(task_name=request.task_name, request_type=request.request_type)
|
||||
if request.message_factory is not None:
|
||||
active_message_factory = request.message_factory
|
||||
else:
|
||||
prompt = request.prompt
|
||||
if prompt is None:
|
||||
raise ValueError("`prompt` 与 `message_factory` 必须且只能提供一个")
|
||||
active_message_factory = _build_prompt_message_factory(prompt)
|
||||
|
||||
try:
|
||||
generation_result = await llm_client.generate_response_with_messages(
|
||||
message_factory=active_message_factory,
|
||||
options=LLMGenerationOptions(
|
||||
temperature=request.temperature,
|
||||
max_tokens=request.max_tokens,
|
||||
tool_options=request.tool_options,
|
||||
response_format=request.response_format,
|
||||
interrupt_flag=request.interrupt_flag,
|
||||
),
|
||||
)
|
||||
return LLMServiceResult.from_response_result(generation_result)
|
||||
except Exception as exc:
|
||||
error_message = f"生成内容时出错: {exc}"
|
||||
logger.error(f"[LLMService] {error_message}")
|
||||
return LLMServiceResult.from_error(error_message, str(exc))
|
||||
|
||||
@@ -7,9 +7,9 @@ from typing import List, Optional, Tuple
|
||||
from sqlmodel import col, select
|
||||
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
from src.common.data_models.action_record_data_model import MaiActionRecord
|
||||
from src.common.data_models.tool_record_data_model import MaiToolRecord
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import ActionRecord, Images, ImageType
|
||||
from src.common.database.database_model import Images, ImageType, ToolRecord
|
||||
from src.common.message_repository import count_messages, find_messages
|
||||
from src.common.utils.math_utils import translate_timestamp_to_human_readable
|
||||
from src.common.utils.utils_action import ActionUtils
|
||||
@@ -238,18 +238,18 @@ def get_actions_by_timestamp_with_chat(
|
||||
timestamp_start: float,
|
||||
timestamp_end: float,
|
||||
limit: Optional[int] = None,
|
||||
) -> List[MaiActionRecord]:
|
||||
) -> List[MaiToolRecord]:
|
||||
with get_db_session() as session:
|
||||
statement = (
|
||||
select(ActionRecord)
|
||||
.where(col(ActionRecord.session_id) == chat_id)
|
||||
.where(col(ActionRecord.timestamp) >= datetime.fromtimestamp(timestamp_start))
|
||||
.where(col(ActionRecord.timestamp) <= datetime.fromtimestamp(timestamp_end))
|
||||
.order_by(col(ActionRecord.timestamp))
|
||||
select(ToolRecord)
|
||||
.where(col(ToolRecord.session_id) == chat_id)
|
||||
.where(col(ToolRecord.timestamp) >= datetime.fromtimestamp(timestamp_start))
|
||||
.where(col(ToolRecord.timestamp) <= datetime.fromtimestamp(timestamp_end))
|
||||
.order_by(col(ToolRecord.timestamp))
|
||||
)
|
||||
if limit is not None:
|
||||
statement = statement.limit(limit)
|
||||
return [MaiActionRecord.from_db_instance(item) for item in session.exec(statement).all()]
|
||||
return [MaiToolRecord.from_db_instance(item) for item in session.exec(statement).all()]
|
||||
|
||||
|
||||
def replace_user_references(text: str, platform: str, replace_bot_name: bool = False) -> str:
|
||||
|
||||
@@ -281,6 +281,26 @@ def _build_processed_plain_text(message: SessionMessage) -> str:
|
||||
return " ".join(part for part in processed_parts if part)
|
||||
|
||||
|
||||
def _build_outbound_log_preview(message: SessionMessage, max_length: int = 160) -> str:
|
||||
"""构造出站消息的日志预览文本。
|
||||
|
||||
Args:
|
||||
message: 待发送的内部消息对象。
|
||||
max_length: 预览文本最大长度。
|
||||
|
||||
Returns:
|
||||
str: 适用于日志展示的消息摘要。
|
||||
"""
|
||||
preview_text = (message.processed_plain_text or message.display_message or "").strip()
|
||||
if not preview_text:
|
||||
preview_text = f"[{_describe_message_sequence(message.raw_message)}]"
|
||||
|
||||
normalized_preview = " ".join(preview_text.split())
|
||||
if len(normalized_preview) <= max_length:
|
||||
return normalized_preview
|
||||
return f"{normalized_preview[:max_length]}..."
|
||||
|
||||
|
||||
def _build_outbound_session_message(
|
||||
message_sequence: MessageSequence,
|
||||
stream_id: str,
|
||||
@@ -424,11 +444,7 @@ def _log_platform_io_failures(delivery_batch: DeliveryBatch) -> None:
|
||||
f"driver={receipt.driver_id} status={receipt.status} error={receipt.error}"
|
||||
for receipt in delivery_batch.failed_receipts
|
||||
) or "未命中任何发送路由"
|
||||
logger.warning(
|
||||
"[SendService] Platform IO 发送失败: platform=%s %s",
|
||||
delivery_batch.route_key.platform,
|
||||
failed_details,
|
||||
)
|
||||
logger.warning(f"[SendService] Platform IO 发送失败: platform={delivery_batch.route_key.platform} {failed_details}")
|
||||
|
||||
|
||||
async def _send_via_platform_io(
|
||||
@@ -493,9 +509,9 @@ async def _send_via_platform_io(
|
||||
for receipt in delivery_batch.sent_receipts
|
||||
]
|
||||
logger.info(
|
||||
"[SendService] 已通过 Platform IO 将消息发往平台 '%s' (drivers: %s)",
|
||||
route_key.platform,
|
||||
", ".join(successful_driver_ids),
|
||||
f"[SendService] 已通过 Platform IO 将消息发往平台 '{route_key.platform}' "
|
||||
f"(drivers: {', '.join(successful_driver_ids)}) "
|
||||
f"message={_build_outbound_log_preview(message)}"
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
Reference in New Issue
Block a user