feat: Enhance OpenAI compatibility and introduce unified LLM service data models
- Refactored model fetching logic to support various authentication methods for OpenAI-compatible APIs.
- Introduced new data models for LLM service requests and responses to standardize interactions across layers.
- Added an adapter base class for unified request execution across different providers.
- Implemented utility functions for building OpenAI-compatible client configurations and request overrides.
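For orientation, the recurring change across the hunks below is replacing direct LLMRequest(model_set=...) construction with the LLMServiceClient facade, which is addressed by task_name and returns a structured result instead of a tuple. A minimal sketch of the before/after call pattern; only the names visible in the diff (LLMServiceClient, generate_response, .response, .reasoning, .model_name, .tool_calls) are taken as given, the surrounding scaffolding here is assumed:

# Sketch only: illustrates the migration pattern shown in the hunks below.
from src.services.llm_service import LLMServiceClient

async def plan(prompt: str) -> str:
    # Old style (removed): LLMRequest(model_set=model_config.model_task_config.planner, ...)
    # returned a (content, (reasoning, model_name, tool_calls)) tuple.
    client = LLMServiceClient(task_name="planner", request_type="planner")
    result = await client.generate_response(prompt)
    # The structured result replaces tuple unpacking at every call site.
    return result.response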
@@ -2,8 +2,8 @@ import time
 from typing import Tuple, Optional # 增加了 Optional
 from src.common.logger import get_logger
-from src.llm_models.utils_model import LLMRequest
-from src.config.config import global_config, model_config
+from src.services.llm_service import LLMServiceClient
+from src.config.config import global_config
 import random
 from .chat_observer import ChatObserver
 from .pfc_utils import get_items_from_json
@@ -109,8 +109,8 @@ class ActionPlanner:
     """行动规划器"""
 
     def __init__(self, stream_id: str, private_name: str):
-        self.llm = LLMRequest(
-            model_set=model_config.model_task_config.planner,
+        self.llm = LLMServiceClient(
+            task_name="planner",
             request_type="action_planning",
         )
         self.personality_info = self._get_personality_prompt()
@@ -398,7 +398,8 @@ class ActionPlanner:
 
         logger.debug(f"[私聊][{self.private_name}]发送到LLM的最终提示词:\n------\n{prompt}\n------")
         try:
-            content, _ = await self.llm.generate_response_async(prompt)
+            generation_result = await self.llm.generate_response(prompt)
+            content = generation_result.response
             logger.debug(f"[私聊][{self.private_name}]LLM (行动规划) 原始返回内容: {content}")
 
             # --- 初始行动规划解析 ---
@@ -427,7 +428,8 @@ class ActionPlanner:
             f"[私聊][{self.private_name}]发送到LLM的结束决策提示词:\n------\n{end_decision_prompt}\n------"
         )
         try:
-            end_content, _ = await self.llm.generate_response_async(end_decision_prompt) # 再次调用LLM
+            end_generation_result = await self.llm.generate_response(end_decision_prompt)
+            end_content = end_generation_result.response # 再次调用LLM
             logger.debug(f"[私聊][{self.private_name}]LLM (结束决策) 原始返回内容: {end_content}")
 
             # 解析结束决策的JSON
@@ -1,7 +1,7 @@
 from typing import List, Tuple, TYPE_CHECKING
 from src.common.logger import get_logger
-from src.llm_models.utils_model import LLMRequest
-from src.config.config import global_config, model_config
+from src.services.llm_service import LLMServiceClient
+from src.config.config import global_config
 import random
 from .chat_observer import ChatObserver
 from .pfc_utils import get_items_from_json
@@ -43,7 +43,9 @@ class GoalAnalyzer:
     """对话目标分析器"""
 
     def __init__(self, stream_id: str, private_name: str):
-        self.llm = LLMRequest(model_set=model_config.model_task_config.planner, request_type="conversation_goal")
+        self.llm = LLMServiceClient(
+            task_name="planner", request_type="conversation_goal"
+        )
 
         self.personality_info = self._get_personality_prompt()
         self.name = global_config.bot.nickname
@@ -157,7 +159,8 @@ class GoalAnalyzer:
 
         logger.debug(f"[私聊][{self.private_name}]发送到LLM的提示词: {prompt}")
         try:
-            content, _ = await self.llm.generate_response_async(prompt)
+            generation_result = await self.llm.generate_response(prompt)
+            content = generation_result.response
             logger.debug(f"[私聊][{self.private_name}]LLM原始返回内容: {content}")
         except Exception as e:
             logger.error(f"[私聊][{self.private_name}]分析对话目标时出错: {str(e)}")
@@ -271,7 +274,8 @@ class GoalAnalyzer:
         }}"""
 
         try:
-            content, _ = await self.llm.generate_response_async(prompt)
+            generation_result = await self.llm.generate_response(prompt)
+            content = generation_result.response
             logger.debug(f"[私聊][{self.private_name}]LLM原始返回内容: {content}")
 
             # 尝试解析JSON
@@ -3,8 +3,7 @@ from src.common.logger import get_logger
 
 # NOTE: HippocampusManager doesn't exist in v0.12.2 - memory system was redesigned
 # from src.plugins.memory_system.Hippocampus import HippocampusManager
-from src.llm_models.utils_model import LLMRequest
-from src.config.config import model_config
+from src.services.llm_service import LLMServiceClient
 from src.chat.knowledge import qa_manager
 
 logger = get_logger("knowledge_fetcher")
@@ -14,7 +13,7 @@ class KnowledgeFetcher:
     """知识调取器"""
 
     def __init__(self, private_name: str):
-        self.llm = LLMRequest(model_set=model_config.model_task_config.utils)
+        self.llm = LLMServiceClient(task_name="utils")
         self.private_name = private_name
 
     def _lpmm_get_knowledge(self, query: str) -> str:
@@ -2,8 +2,8 @@ import json
 import random
 from typing import Tuple, List, Dict, Any
 from src.common.logger import get_logger
-from src.llm_models.utils_model import LLMRequest
-from src.config.config import global_config, model_config
+from src.services.llm_service import LLMServiceClient
+from src.config.config import global_config
 from .chat_observer import ChatObserver
 from maim_message import UserInfo
 
@@ -14,7 +14,7 @@ class ReplyChecker:
     """回复检查器"""
 
    def __init__(self, stream_id: str, private_name: str):
-        self.llm = LLMRequest(model_set=model_config.model_task_config.utils, request_type="reply_check")
+        self.llm = LLMServiceClient(task_name="utils", request_type="reply_check")
         self.personality_info = self._get_personality_prompt()
         self.name = global_config.bot.nickname
         self.private_name = private_name
@@ -137,7 +137,8 @@ class ReplyChecker:
 注意:请严格按照JSON格式输出,不要包含任何其他内容。"""
 
         try:
-            content, _ = await self.llm.generate_response_async(prompt)
+            generation_result = await self.llm.generate_response(prompt)
+            content = generation_result.response
             logger.debug(f"[私聊][{self.private_name}]检查回复的原始返回: {content}")
 
             # 清理内容,尝试提取JSON部分
@@ -1,7 +1,7 @@
 from typing import Tuple, List, Dict, Any
 from src.common.logger import get_logger
-from src.llm_models.utils_model import LLMRequest
-from src.config.config import global_config, model_config
+from src.services.llm_service import LLMServiceClient
+from src.config.config import global_config
 import random
 from .chat_observer import ChatObserver
 from .reply_checker import ReplyChecker
@@ -87,8 +87,8 @@ class ReplyGenerator:
     """回复生成器"""
 
     def __init__(self, stream_id: str, private_name: str):
-        self.llm = LLMRequest(
-            model_set=model_config.model_task_config.replyer,
+        self.llm = LLMServiceClient(
+            task_name="replyer",
             request_type="reply_generation",
         )
         self.personality_info = self._get_personality_prompt()
@@ -223,7 +223,8 @@ class ReplyGenerator:
         # --- 调用 LLM 生成 ---
         logger.debug(f"[私聊][{self.private_name}]发送到LLM的生成提示词:\n------\n{prompt}\n------")
         try:
-            content, _ = await self.llm.generate_response_async(prompt)
+            generation_result = await self.llm.generate_response(prompt)
+            content = generation_result.response
             logger.debug(f"[私聊][{self.private_name}]生成的回复: {content}")
             # 移除旧的检查新消息逻辑,这应该由 conversation 控制流处理
             return content
@@ -17,9 +17,9 @@ from src.chat.utils.utils import get_chat_type_and_target_info
 from src.common.data_models.info_data_model import ActionPlannerInfo
 from src.common.logger import get_logger
 from src.common.utils.utils_action import ActionUtils
-from src.config.config import global_config, model_config
+from src.config.config import global_config
 from src.core.types import ActionActivationType, ActionInfo, ComponentType
-from src.llm_models.utils_model import LLMRequest
+from src.services.llm_service import LLMServiceClient
 from src.plugin_runtime.component_query import component_query_service
 from src.prompt.prompt_manager import prompt_manager
 from src.services.message_service import (
@@ -43,8 +43,8 @@ class BrainPlanner:
         self.log_prefix = f"[{_chat_manager.get_session_name(chat_id) or chat_id}]"
         self.action_manager = action_manager
         # LLM规划器配置
-        self.planner_llm = LLMRequest(
-            model_set=model_config.model_task_config.planner, request_type="planner"
+        self.planner_llm = LLMServiceClient(
+            task_name="planner", request_type="planner"
         ) # 用于动作规划
 
         self.last_obs_time_mark = 0.0
@@ -412,7 +412,9 @@ class BrainPlanner:
         try:
             # 调用LLM
             llm_start = time.perf_counter()
-            llm_content, (reasoning_content, _, _) = await self.planner_llm.generate_response_async(prompt=prompt)
+            generation_result = await self.planner_llm.generate_response(prompt=prompt)
+            llm_content = generation_result.response
+            reasoning_content = generation_result.reasoning
             llm_duration_ms = (time.perf_counter() - llm_start) * 1000
             llm_reasoning = reasoning_content
 
@@ -17,8 +17,9 @@ from src.common.database.database_model import Images, ImageType
 from src.common.database.database import get_db_session, get_db_session_manual
 from src.common.utils.utils_image import ImageUtils
 from src.prompt.prompt_manager import prompt_manager
-from src.config.config import config_manager, global_config, model_config
-from src.llm_models.utils_model import LLMRequest
+from src.config.config import config_manager, global_config
+from src.common.data_models.llm_service_data_models import LLMGenerationOptions, LLMImageOptions
+from src.services.llm_service import LLMServiceClient
 
 logger = get_logger("emoji")
@@ -38,8 +39,10 @@ def _ensure_directories() -> None:
 
 
 # TODO: 修改这个vlm为获取的vlm client,暂时使用这个VLM方法
-emoji_manager_vlm = LLMRequest(model_set=model_config.model_task_config.vlm, request_type="emoji.see")
-emoji_manager_emotion_judge_llm = LLMRequest(model_set=model_config.model_task_config.utils, request_type="emoji")
+emoji_manager_vlm = LLMServiceClient(task_name="vlm", request_type="emoji.see")
+emoji_manager_emotion_judge_llm = LLMServiceClient(
+    task_name="utils", request_type="emoji"
+)
 
 
 class EmojiManager:
@@ -461,9 +464,11 @@ class EmojiManager:
         emoji_replace_prompt_template.add_context("emoji_list", "\n".join(emoji_info_list))
         emoji_replace_prompt = await prompt_manager.render_prompt(emoji_replace_prompt_template)
 
-        decision, _ = await emoji_manager_emotion_judge_llm.generate_response_async(
-            emoji_replace_prompt, temperature=0.8, max_tokens=600
+        decision_result = await emoji_manager_emotion_judge_llm.generate_response(
+            emoji_replace_prompt,
+            options=LLMGenerationOptions(temperature=0.8, max_tokens=600),
         )
+        decision = decision_result.response
         logger.info(f"[决策] 结果: {decision}")
 
         # 解析决策结果
@@ -524,24 +529,36 @@ class EmojiManager:
                 return False, target_emoji
             prompt: str = "这是一个动态图表情包,每一张图代表了动态图的某一帧,黑色背景代表透明,简短描述一下表情包表达的情感和内容,从互联网梗、meme的角度去分析,精简回答"
             image_base64 = ImageUtils.image_bytes_to_base64(image_bytes)
-            description, _ = await emoji_manager_vlm.generate_response_for_image(
-                prompt, image_base64, "jpg", temperature=0.5
+            description_result = await emoji_manager_vlm.generate_response_for_image(
+                prompt,
+                image_base64,
+                "jpg",
+                options=LLMImageOptions(temperature=0.5),
             )
+            description = description_result.response
         else:
             prompt: str = "这是一个表情包,请详细描述一下表情包所表达的情感和内容,简短描述细节,从互联网梗、meme的角度去分析,精简回答"
             image_base64 = ImageUtils.image_bytes_to_base64(image_bytes)
-            description, _ = await emoji_manager_vlm.generate_response_for_image(
-                prompt, image_base64, image_format, temperature=0.5
+            description_result = await emoji_manager_vlm.generate_response_for_image(
+                prompt,
+                image_base64,
+                image_format,
+                options=LLMImageOptions(temperature=0.5),
             )
+            description = description_result.response
 
         # 表情包审查
         if global_config.emoji.content_filtration:
             filtration_prompt_template = prompt_manager.get_prompt("emoji_content_filtration")
             filtration_prompt_template.add_context("demand", global_config.emoji.filtration_prompt)
             filtration_prompt = await prompt_manager.render_prompt(filtration_prompt_template)
-            llm_response, _ = await emoji_manager_vlm.generate_response_for_image(
-                filtration_prompt, image_base64, image_format, temperature=0.3
+            filtration_result = await emoji_manager_vlm.generate_response_for_image(
+                filtration_prompt,
+                image_base64,
+                image_format,
+                options=LLMImageOptions(temperature=0.3),
             )
+            llm_response = filtration_result.response
             if "否" in llm_response:
                 logger.warning(f"[表情包审查] 表情包内容不符合要求,拒绝注册: {target_emoji.file_name}")
                 return False, target_emoji
@@ -567,9 +584,11 @@ class EmojiManager:
             emotion_prompt_template.add_context("description", target_emoji.description)
             emotion_prompt = await prompt_manager.render_prompt(emotion_prompt_template)
             # 调用LLM生成情感标签
-            emotion_result, _ = await emoji_manager_emotion_judge_llm.generate_response_async(
-                emotion_prompt, temperature=0.3, max_tokens=200
+            emotion_generation_result = await emoji_manager_emotion_judge_llm.generate_response(
+                emotion_prompt,
+                options=LLMGenerationOptions(temperature=0.3, max_tokens=200),
             )
+            emotion_result = emotion_generation_result.response
 
             # 解析情感标签结果
             emotions = [e.strip() for e in emotion_result.replace(",", ",").split(",") if e.strip()]
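Sampling parameters that were previously passed as loose keyword arguments now travel in option objects from src.common.data_models.llm_service_data_models. A hedged sketch of how the emoji flows above combine the two option types; only the fields visible in the diff (temperature, max_tokens) are assumed to exist on these models, everything else in the sketch is illustrative:

# Sketch: option objects replace ad-hoc kwargs at the call sites above.
from src.common.data_models.llm_service_data_models import LLMGenerationOptions, LLMImageOptions
from src.services.llm_service import LLMServiceClient

vlm = LLMServiceClient(task_name="vlm", request_type="emoji.see")
judge = LLMServiceClient(task_name="utils", request_type="emoji")

async def describe_and_tag(image_base64: str, image_format: str, prompt: str) -> tuple[str, str]:
    # Text generation takes LLMGenerationOptions (temperature/max_tokens as seen above).
    tags = await judge.generate_response(
        prompt, options=LLMGenerationOptions(temperature=0.3, max_tokens=200)
    )
    # Image-grounded generation takes LLMImageOptions (temperature as seen above).
    description = await vlm.generate_response_for_image(
        prompt, image_base64, image_format, options=LLMImageOptions(temperature=0.5)
    )
    return description.response, tags.response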
@@ -11,8 +11,9 @@ from src.common.logger import get_logger
 from src.common.database.database import get_db_session
 from src.common.database.database_model import Images, ImageType
 from src.common.data_models.image_data_model import MaiImage
-from src.config.config import global_config, model_config
-from src.llm_models.utils_model import LLMRequest
+from src.config.config import global_config
+from src.common.data_models.llm_service_data_models import LLMImageOptions
+from src.services.llm_service import LLMServiceClient
 
 install(extra_lines=3)
@@ -27,7 +28,7 @@ def _ensure_image_dir_exists():
     IMAGE_DIR.mkdir(parents=True, exist_ok=True)
 
 
-vlm = LLMRequest(model_set=model_config.model_task_config.vlm, request_type="image")
+vlm = LLMServiceClient(task_name="vlm", request_type="image")
 
 
 class ImageManager:
@@ -260,7 +261,13 @@ class ImageManager:
         prompt = global_config.personality.visual_style
         image_base64 = base64.b64encode(image_bytes).decode("utf-8")
 
-        description, _ = await vlm.generate_response_for_image(prompt, image_base64, image_format, 0.4)
+        generation_result = await vlm.generate_response_for_image(
+            prompt,
+            image_base64,
+            image_format,
+            options=LLMImageOptions(temperature=0.4),
+        )
+        description = generation_result.response
         if not description:
             logger.warning("VLM未能生成图片描述")
         return description or ""
@@ -139,14 +139,14 @@ class EmbeddingStore:
             asyncio.set_event_loop(loop)
 
             try:
-                # 创建新的LLMRequest实例
-                from src.llm_models.utils_model import LLMRequest
-                from src.config.config import model_config
+                # 创建新的服务层实例
+                from src.services.llm_service import LLMServiceClient
 
-                llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding")
+                llm = LLMServiceClient(task_name="embedding", request_type="embedding")
 
                 # 使用新的事件循环运行异步方法
-                embedding, _ = loop.run_until_complete(llm.get_embedding(s))
+                embedding_result = loop.run_until_complete(llm.embed_text(s))
+                embedding = embedding_result.embedding
 
                 if embedding and len(embedding) > 0:
                     return embedding
@@ -195,13 +195,12 @@ class EmbeddingStore:
             start_idx, chunk_strs = chunk_data
             chunk_results = []
 
-            # 为每个线程创建独立的LLMRequest实例
-            from src.llm_models.utils_model import LLMRequest
-            from src.config.config import model_config
+            # 为每个线程创建独立的服务层实例
+            from src.services.llm_service import LLMServiceClient
 
             try:
-                # 创建线程专用的LLM实例
-                llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding")
+                # 创建线程专用的服务层实例
+                llm = LLMServiceClient(task_name="embedding", request_type="embedding")
 
                 for i, s in enumerate(chunk_strs):
                     try:
@@ -209,7 +208,8 @@ class EmbeddingStore:
                         loop = asyncio.new_event_loop()
                         asyncio.set_event_loop(loop)
                         try:
-                            embedding = loop.run_until_complete(llm.get_embedding(s))
+                            embedding_result = loop.run_until_complete(llm.embed_text(s))
+                            embedding = embedding_result.embedding
                         finally:
                             loop.close()
 
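The embedding path changes from get_embedding() returning an (embedding, _) tuple to embed_text() returning a result with an .embedding field, while each worker still builds its own client and event loop to avoid loop conflicts. A minimal sketch of that per-thread pattern, assuming only what the hunks above show:

# Sketch of the per-thread embedding pattern used above; embed_text() and the
# .embedding field come from the diff, the surrounding helper is assumed.
import asyncio
from src.services.llm_service import LLMServiceClient

def embed_in_worker(text: str) -> list[float] | None:
    llm = LLMServiceClient(task_name="embedding", request_type="embedding")
    loop = asyncio.new_event_loop()  # dedicated loop per worker thread
    asyncio.set_event_loop(loop)
    try:
        result = loop.run_until_complete(llm.embed_text(text))
    finally:
        loop.close()
    return result.embedding if result.embedding else None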
@@ -1,18 +1,27 @@
 import asyncio
 import json
 import time
-from typing import List, Union
+from typing import Dict, List, Tuple, Union
 
-from .global_logger import logger
-from . import prompt_template
-from . import INVALID_ENTITY
-from src.llm_models.utils_model import LLMRequest
 from json_repair import repair_json
 
+from src.services.llm_service import LLMServiceClient
+
+from . import INVALID_ENTITY
+from . import prompt_template
+from .global_logger import logger
+
 
-def _extract_json_from_text(text: str):
+def _extract_json_from_text(text: str) -> List[str] | List[List[str]] | Dict[str, object]:
     # sourcery skip: assign-if-exp, extract-method
-    """从文本中提取JSON数据的高容错方法"""
+    """从文本中提取 JSON 数据。
+
+    Args:
+        text: 原始模型输出文本。
+
+    Returns:
+        List[str] | List[List[str]] | Dict[str, object]: 修复并解析后的 JSON 结果。
+    """
     if text is None:
         logger.error("输入文本为None")
         return []
@@ -46,20 +55,30 @@ def _extract_json_from_text(text: str):
         return []
 
 
-def _entity_extract(llm_req: LLMRequest, paragraph: str) -> List[str]:
+def _entity_extract(llm_req: LLMServiceClient, paragraph: str) -> List[str]:
     # sourcery skip: reintroduce-else, swap-if-else-branches, use-named-expression
-    """对段落进行实体提取,返回提取出的实体列表(JSON格式)"""
+    """对单段文本执行实体提取。
+
+    Args:
+        llm_req: LLM 服务门面实例。
+        paragraph: 待提取实体的原始段落文本。
+
+    Returns:
+        List[str]: 提取出的实体列表。
+    """
     entity_extract_context = prompt_template.build_entity_extract_context(paragraph)
 
     # 使用 asyncio.run 来运行异步方法
     try:
         # 如果当前已有事件循环在运行,使用它
         loop = asyncio.get_running_loop()
-        future = asyncio.run_coroutine_threadsafe(llm_req.generate_response_async(entity_extract_context), loop)
-        response, _ = future.result()
+        future = asyncio.run_coroutine_threadsafe(llm_req.generate_response(entity_extract_context), loop)
+        generation_result = future.result()
+        response = generation_result.response
     except RuntimeError:
         # 如果没有运行中的事件循环,直接使用 asyncio.run
-        response, _ = asyncio.run(llm_req.generate_response_async(entity_extract_context))
+        generation_result = asyncio.run(llm_req.generate_response(entity_extract_context))
+        response = generation_result.response
 
     # 添加调试日志
     logger.debug(f"LLM返回的原始响应: {response}")
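The extraction helpers stay synchronous, so they bridge to the async facade either through an already-running loop or asyncio.run. A condensed sketch of that bridge; the helper name is hypothetical, while generate_response and .response come from the hunk above:

# Sketch: calling the async facade from synchronous extraction code,
# mirroring the try/except RuntimeError pattern above.
import asyncio
from src.services.llm_service import LLMServiceClient

def run_generation(client: LLMServiceClient, context: str) -> str:
    try:
        loop = asyncio.get_running_loop()
        # A loop is already running in this thread's context: hand the
        # coroutine over and block on the returned future.
        future = asyncio.run_coroutine_threadsafe(client.generate_response(context), loop)
        result = future.result()
    except RuntimeError:
        # No running loop: drive the coroutine directly.
        result = asyncio.run(client.generate_response(context))
    return result.response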
@@ -92,8 +111,21 @@ def _entity_extract(llm_req: LLMRequest, paragraph: str) -> List[str]:
     return entity_extract_result
 
 
-def _rdf_triple_extract(llm_req: LLMRequest, paragraph: str, entities: list) -> List[List[str]]:
-    """对段落进行实体提取,返回提取出的实体列表(JSON格式)"""
+def _rdf_triple_extract(
+    llm_req: LLMServiceClient,
+    paragraph: str,
+    entities: List[str],
+) -> List[List[str]]:
+    """对单段文本执行 RDF 三元组提取。
+
+    Args:
+        llm_req: LLM 服务门面实例。
+        paragraph: 待提取的原始段落文本。
+        entities: 已识别出的实体列表。
+
+    Returns:
+        List[List[str]]: 提取出的三元组列表。
+    """
     rdf_extract_context = prompt_template.build_rdf_triple_extract_context(
         paragraph, entities=json.dumps(entities, ensure_ascii=False)
     )
@@ -102,11 +134,13 @@ def _rdf_triple_extract(llm_req: LLMRequest, paragraph: str, entities: list) ->
     try:
         # 如果当前已有事件循环在运行,使用它
         loop = asyncio.get_running_loop()
-        future = asyncio.run_coroutine_threadsafe(llm_req.generate_response_async(rdf_extract_context), loop)
-        response, _ = future.result()
+        future = asyncio.run_coroutine_threadsafe(llm_req.generate_response(rdf_extract_context), loop)
+        generation_result = future.result()
+        response = generation_result.response
     except RuntimeError:
         # 如果没有运行中的事件循环,直接使用 asyncio.run
-        response, _ = asyncio.run(llm_req.generate_response_async(rdf_extract_context))
+        generation_result = asyncio.run(llm_req.generate_response(rdf_extract_context))
+        response = generation_result.response
 
     # 添加调试日志
     logger.debug(f"RDF LLM返回的原始响应: {response}")
@@ -140,8 +174,21 @@ def _rdf_triple_extract(llm_req: LLMRequest, paragraph: str, entities: list) ->
 
 
 def info_extract_from_str(
-    llm_client_for_ner: LLMRequest, llm_client_for_rdf: LLMRequest, paragraph: str
-) -> Union[tuple[None, None], tuple[list[str], list[list[str]]]]:
+    llm_client_for_ner: LLMServiceClient,
+    llm_client_for_rdf: LLMServiceClient,
+    paragraph: str,
+) -> Union[Tuple[None, None], Tuple[List[str], List[List[str]]]]:
+    """从文本中提取实体与三元组信息。
+
+    Args:
+        llm_client_for_ner: 实体提取使用的 LLM 服务门面。
+        llm_client_for_rdf: RDF 三元组提取使用的 LLM 服务门面。
+        paragraph: 原始段落文本。
+
+    Returns:
+        Union[Tuple[None, None], Tuple[List[str], List[List[str]]]]: 成功时返回
+        ``(实体列表, 三元组列表)``,失败时返回 ``(None, None)``。
+    """
     try_count = 0
     while True:
         try:
@@ -176,17 +223,30 @@ def info_extract_from_str(
 
 
 class IEProcess:
-    """
-    信息抽取处理器类,提供更方便的批次处理接口。
-    """
+    """信息抽取处理器。"""
 
-    def __init__(self, llm_ner: LLMRequest, llm_rdf: LLMRequest = None):
+    def __init__(
+        self,
+        llm_ner: LLMServiceClient,
+        llm_rdf: LLMServiceClient | None = None,
+    ) -> None:
+        """初始化信息抽取处理器。
+
+        Args:
+            llm_ner: 实体提取使用的 LLM 服务门面。
+            llm_rdf: RDF 三元组提取使用的 LLM 服务门面;为空时复用 `llm_ner`。
+        """
         self.llm_ner = llm_ner
         self.llm_rdf = llm_rdf or llm_ner
 
-    async def process_paragraphs(self, paragraphs: List[str]) -> List[dict]:
-        """
-        异步处理多个段落。
+    async def process_paragraphs(self, paragraphs: List[str]) -> List[Dict[str, object]]:
+        """异步处理多个段落。
+
+        Args:
+            paragraphs: 待处理的段落列表。
+
+        Returns:
+            List[Dict[str, object]]: 每个成功段落对应的抽取结果。
+        """
         from .utils.hash import get_sha256
||||
@@ -91,13 +91,14 @@ class LPMMOperations:
|
||||
|
||||
# 2. 实体与三元组抽取 (内部调用大模型)
|
||||
from src.chat.knowledge.ie_process import IEProcess
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
|
||||
llm_ner = LLMRequest(
|
||||
model_set=model_config.model_task_config.lpmm_entity_extract, request_type="lpmm.entity_extract"
|
||||
llm_ner = LLMServiceClient(
|
||||
task_name="lpmm_entity_extract", request_type="lpmm.entity_extract"
|
||||
)
|
||||
llm_rdf = LLMServiceClient(
|
||||
task_name="lpmm_rdf_build", request_type="lpmm.rdf_build"
|
||||
)
|
||||
llm_rdf = LLMRequest(model_set=model_config.model_task_config.lpmm_rdf_build, request_type="lpmm.rdf_build")
|
||||
ie_process = IEProcess(llm_ner, llm_rdf)
|
||||
|
||||
logger.info(f"[Plugin API] 正在对 {len(paragraphs)} 段文本执行信息抽取...")
|
||||
|
||||
@@ -149,7 +149,7 @@ class ActionModifier:
         random.shuffle(actions_to_check)
 
         for action_name, action_info in actions_to_check:
-            activation_type = action_info.activation_type or action_info.focus_activation_type
+            activation_type = action_info.activation_type
 
             if activation_type == ActionActivationType.ALWAYS:
                 continue # 总是激活,无需处理
@@ -19,9 +19,9 @@ from src.chat.planner_actions.action_manager import ActionManager
 from src.chat.utils.utils import get_chat_type_and_target_info, is_bot_self
 from src.common.data_models.info_data_model import ActionPlannerInfo
 from src.common.logger import get_logger
-from src.config.config import global_config, model_config
+from src.config.config import global_config
 from src.core.types import ActionActivationType, ActionInfo, ComponentType
-from src.llm_models.utils_model import LLMRequest
+from src.services.llm_service import LLMServiceClient
 from src.person_info.person_info import Person
 from src.plugin_runtime.component_query import component_query_service
 from src.prompt.prompt_manager import prompt_manager
@@ -46,8 +46,8 @@ class ActionPlanner:
         self.log_prefix = f"[{_chat_manager.get_session_name(chat_id) or chat_id}]"
         self.action_manager = action_manager
         # LLM规划器配置
-        self.planner_llm = LLMRequest(
-            model_set=model_config.model_task_config.planner, request_type="planner"
+        self.planner_llm = LLMServiceClient(
+            task_name="planner", request_type="planner"
         ) # 用于动作规划
 
         self.last_obs_time_mark = 0.0
@@ -725,7 +725,9 @@ class ActionPlanner:
         try:
             # 调用LLM
             llm_start = time.perf_counter()
-            llm_content, (reasoning_content, _, _) = await self.planner_llm.generate_response_async(prompt=prompt)
+            generation_result = await self.planner_llm.generate_response(prompt=prompt)
+            llm_content = generation_result.response
+            reasoning_content = generation_result.reasoning
             llm_duration_ms = (time.perf_counter() - llm_start) * 1000
             llm_reasoning = reasoning_content
 
@@ -10,8 +10,8 @@ from datetime import datetime
 from src.common.logger import get_logger
 from src.common.data_models.info_data_model import ActionPlannerInfo
 from src.common.data_models.llm_data_model import LLMGenerationDataModel
-from src.config.config import global_config, model_config
-from src.llm_models.utils_model import LLMRequest
+from src.config.config import global_config
+from src.services.llm_service import LLMServiceClient
 from maim_message import BaseMessageInfo, MessageBase, Seg, UserInfo as MaimUserInfo
 
 from src.common.data_models.mai_message_data_model import MaiMessage
@@ -56,7 +56,9 @@ class DefaultReplyer:
             chat_stream: 当前绑定的聊天会话。
             request_type: LLM 请求类型标识。
         """
-        self.express_model = LLMRequest(model_set=model_config.model_task_config.replyer, request_type=request_type)
+        self.express_model = LLMServiceClient(
+            task_name="replyer", request_type=request_type
+        )
         self.chat_stream = chat_stream
         self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_stream.session_id)
 
@@ -1158,9 +1160,11 @@ class DefaultReplyer:
         # else:
         # logger.debug(f"\nreplyer_Prompt:{prompt}\n")
 
-        content, (reasoning_content, model_name, tool_calls) = await self.express_model.generate_response_async(
-            prompt
-        )
+        generation_result = await self.express_model.generate_response(prompt)
+        content = generation_result.response
+        reasoning_content = generation_result.reasoning
+        model_name = generation_result.model_name
+        tool_calls = generation_result.tool_calls
 
         # 移除 content 前后的换行符和空格
         content = content.strip()
@@ -1200,11 +1204,15 @@ class DefaultReplyer:
             template_prompt.add_context("sender", sender)
             template_prompt.add_context("target_message", target)
             prompt = await prompt_manager.render_prompt(template_prompt)
-            _, _, _, _, tool_calls = await llm_api.generate_with_model_with_tools(
-                prompt,
-                model_config=model_config.model_task_config.tool_use,
-                tool_options=[search_knowledge_tool.get_tool_definition()],
+            generation_result = await llm_api.generate(
+                llm_api.LLMServiceRequest(
+                    task_name="tool_use",
+                    request_type="replyer.lpmm_knowledge",
+                    prompt=prompt,
+                    tool_options=[search_knowledge_tool.get_tool_definition()],
+                )
             )
+            tool_calls = generation_result.completion.tool_calls
 
             # logger.info(f"工具调用提示词: {prompt}")
             # logger.info(f"工具调用: {tool_calls}")
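The knowledge-tool path now goes through a request data model rather than positional arguments: llm_api.generate takes an LLMServiceRequest and returns a result whose completion carries the tool calls. A hedged sketch based only on the fields visible in this hunk; the llm_api import itself is not shown here, so it is passed in rather than imported:

# Sketch of the llm_api.generate(LLMServiceRequest(...)) call shape used above.
async def query_knowledge(llm_api, prompt: str, tool_definition: dict):
    result = await llm_api.generate(
        llm_api.LLMServiceRequest(
            task_name="tool_use",
            request_type="replyer.lpmm_knowledge",
            prompt=prompt,
            tool_options=[tool_definition],
        )
    )
    # The structured completion replaces the old five-element tuple unpacking.
    return result.completion.tool_calls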
@@ -9,8 +9,8 @@ from datetime import datetime
 from src.common.logger import get_logger
 from src.common.data_models.info_data_model import ActionPlannerInfo
 from src.common.data_models.llm_data_model import LLMGenerationDataModel
-from src.config.config import global_config, model_config
-from src.llm_models.utils_model import LLMRequest
+from src.config.config import global_config
+from src.services.llm_service import LLMServiceClient
 from maim_message import BaseMessageInfo, MessageBase, Seg, UserInfo as MaimUserInfo
 
 from src.common.data_models.mai_message_data_model import MaiMessage
@@ -52,7 +52,9 @@ class PrivateReplyer:
             chat_stream: 当前绑定的聊天会话。
             request_type: LLM 请求类型标识。
         """
-        self.express_model = LLMRequest(model_set=model_config.model_task_config.replyer, request_type=request_type)
+        self.express_model = LLMServiceClient(
+            task_name="replyer", request_type=request_type
+        )
         self.chat_stream = chat_stream
         self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_stream.session_id)
         # self.memory_activator = MemoryActivator()
@@ -997,9 +999,11 @@ class PrivateReplyer:
         else:
             logger.debug(f"\n{prompt}\n")
 
-        content, (reasoning_content, model_name, tool_calls) = await self.express_model.generate_response_async(
-            prompt
-        )
+        generation_result = await self.express_model.generate_response(prompt)
+        content = generation_result.response
+        reasoning_content = generation_result.reasoning
+        model_name = generation_result.model_name
+        tool_calls = generation_result.tool_calls
 
         content = content.strip()
@@ -4,16 +4,18 @@
 自动判断并执行相应的工具,返回结构化的工具执行结果。
 """
 
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple, cast
 
 import hashlib
 import time
 
 from src.common.logger import get_logger
-from src.config.config import global_config, model_config
+from src.config.config import global_config
 from src.core.announcement_manager import global_announcement_manager
 from src.llm_models.payload_content import ToolCall
-from src.llm_models.utils_model import LLMRequest
+from src.llm_models.payload_content.tool_option import ToolDefinitionInput
+from src.common.data_models.llm_service_data_models import LLMGenerationOptions
+from src.services.llm_service import LLMServiceClient
 from src.plugin_runtime.component_query import component_query_service
 from src.prompt.prompt_manager import prompt_manager
@@ -33,7 +35,9 @@ class ToolExecutor:
         self.chat_stream = _chat_manager.get_session_by_session_id(self.chat_id)
         self.log_prefix = f"[{_chat_manager.get_session_name(self.chat_id) or self.chat_id}]"
 
-        self.llm_model = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="tool_executor")
+        self.llm_model = LLMServiceClient(
+            task_name="tool_use", request_type="tool_executor"
+        )
 
         self.enable_cache = enable_cache
         self.cache_ttl = cache_ttl
@@ -69,9 +73,11 @@ class ToolExecutor:
 
         logger.debug(f"{self.log_prefix}开始LLM工具调用分析")
 
-        response, (reasoning_content, model_name, tool_calls) = await self.llm_model.generate_response_async(
-            prompt=prompt, tools=tools, raise_when_empty=False
+        generation_result = await self.llm_model.generate_response(
+            prompt=prompt,
+            options=LLMGenerationOptions(tool_options=tools, raise_when_empty=False),
         )
+        tool_calls = generation_result.tool_calls
 
         tool_results, used_tools = await self.execute_tool_calls(tool_calls)
 
@@ -85,11 +91,15 @@ class ToolExecutor:
             return tool_results, used_tools, prompt
         return tool_results, [], ""
 
-    def _get_tool_definitions(self) -> List[Dict[str, Any]]:
+    def _get_tool_definitions(self) -> List[ToolDefinitionInput]:
         """获取 LLM 可用的工具定义列表"""
         all_tools = component_query_service.get_llm_available_tools()
         user_disabled_tools = global_announcement_manager.get_disabled_chat_tools(self.chat_id)
-        return [info.get_llm_definition() for name, info in all_tools.items() if name not in user_disabled_tools]
+        return [
+            cast(ToolDefinitionInput, info.get_llm_definition())
+            for name, info in all_tools.items()
+            if name not in user_disabled_tools
+        ]
 
     async def execute_tool_calls(self, tool_calls: Optional[List[ToolCall]]) -> Tuple[List[Dict[str, Any]], List[str]]:
         """执行工具调用列表"""
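Tool definitions are now typed as ToolDefinitionInput and passed through LLMGenerationOptions(tool_options=...) instead of a bare tools= keyword. A small sketch of that flow, assuming only the names visible in this file's hunks; the helper itself is illustrative:

# Sketch: tool calling through the facade, as in ToolExecutor above.
from typing import List, cast
from src.common.data_models.llm_service_data_models import LLMGenerationOptions
from src.llm_models.payload_content.tool_option import ToolDefinitionInput
from src.services.llm_service import LLMServiceClient

async def call_tools(prompt: str, raw_definitions: list) -> list:
    llm = LLMServiceClient(task_name="tool_use", request_type="tool_executor")
    tools: List[ToolDefinitionInput] = [cast(ToolDefinitionInput, d) for d in raw_definitions]
    result = await llm.generate_response(
        prompt=prompt,
        options=LLMGenerationOptions(tool_options=tools, raise_when_empty=False),
    )
    # result.tool_calls feeds execute_tool_calls() in the executor above.
    return result.tool_calls or []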
@@ -13,8 +13,8 @@ import jieba
 from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
 from src.chat.message_receive.message import SessionMessage
 from src.common.logger import get_logger
-from src.config.config import global_config, model_config
-from src.llm_models.utils_model import LLMRequest
+from src.config.config import global_config
+from src.services.llm_service import LLMServiceClient
 from src.person_info.person_info import Person
 
 from .typo_generator import ChineseTypoGenerator
@@ -235,10 +235,11 @@ def is_mentioned_bot_in_message(message: SessionMessage) -> tuple[bool, bool, fl
 
 async def get_embedding(text, request_type="embedding") -> Optional[List[float]]:
     """获取文本的embedding向量"""
-    # 每次都创建新的LLMRequest实例以避免事件循环冲突
-    llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type=request_type)
+    # 每次都创建新的服务层实例以避免事件循环冲突
+    llm = LLMServiceClient(task_name="embedding", request_type=request_type)
     try:
-        embedding, _ = await llm.get_embedding(text)
+        embedding_result = await llm.embed_text(text)
+        embedding = embedding_result.embedding
     except Exception as e:
         logger.error(f"获取embedding失败: {str(e)}")
         embedding = None