feat: 添加嵌入服务层和任务解析工具,重构文本嵌入逻辑
This commit is contained in:
@@ -14,8 +14,8 @@ from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
from src.person_info.person_info import Person
|
||||
from src.services.embedding_service import EmbeddingServiceClient
|
||||
|
||||
from .typo_generator import ChineseTypoGenerator
|
||||
|
||||
@@ -233,12 +233,19 @@ def is_mentioned_bot_in_message(message: SessionMessage) -> tuple[bool, bool, fl
|
||||
return is_mentioned, is_at, reply_probability
|
||||
|
||||
|
||||
async def get_embedding(text, request_type="embedding") -> Optional[List[float]]:
|
||||
"""获取文本的embedding向量"""
|
||||
# 每次都创建新的服务层实例以避免事件循环冲突
|
||||
llm = LLMServiceClient(task_name="embedding", request_type=request_type)
|
||||
async def get_embedding(text: str, request_type: str = "embedding") -> Optional[List[float]]:
|
||||
"""获取文本的嵌入向量。
|
||||
|
||||
Args:
|
||||
text: 待编码的文本内容。
|
||||
request_type: 当前请求的业务类型标识。
|
||||
|
||||
Returns:
|
||||
Optional[List[float]]: 成功时返回嵌入向量,失败时返回 `None`。
|
||||
"""
|
||||
embedding_client = EmbeddingServiceClient(task_name="embedding", request_type=request_type)
|
||||
try:
|
||||
embedding_result = await llm.embed_text(text)
|
||||
embedding_result = await embedding_client.embed_text(text)
|
||||
embedding = embedding_result.embedding
|
||||
except Exception as e:
|
||||
logger.error(f"获取embedding失败: {str(e)}")
|
||||
|
||||
Reference in New Issue
Block a user