Merge branch 'r-dev' of https://github.com/Mai-with-u/MaiBot into r-dev
This commit is contained in:
@@ -205,7 +205,7 @@ class HeartFChatting:
|
||||
# TODO: Planner逻辑
|
||||
# TODO: 动作执行逻辑
|
||||
|
||||
cycle_detail = self._end_cycle(current_cycle_detail)
|
||||
self._end_cycle(current_cycle_detail)
|
||||
await asyncio.sleep(0.1) # 最小等待时间,避免过快循环
|
||||
return True
|
||||
|
||||
|
||||
@@ -1,31 +1,32 @@
|
||||
from dataclasses import dataclass
|
||||
import json
|
||||
import os
|
||||
import math
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from typing import Dict, List, Tuple
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Dict, List, Tuple
|
||||
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
|
||||
import faiss
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
# import tqdm
|
||||
import faiss
|
||||
|
||||
from .utils.hash import get_sha256
|
||||
from .global_logger import logger
|
||||
from rich.traceback import install
|
||||
from rich.progress import (
|
||||
Progress,
|
||||
BarColumn,
|
||||
MofNCompleteColumn,
|
||||
Progress,
|
||||
SpinnerColumn,
|
||||
TextColumn,
|
||||
TimeElapsedColumn,
|
||||
TimeRemainingColumn,
|
||||
TaskProgressColumn,
|
||||
MofNCompleteColumn,
|
||||
SpinnerColumn,
|
||||
TextColumn,
|
||||
)
|
||||
|
||||
from src.config.config import global_config
|
||||
from src.services.embedding_service import EmbeddingServiceClient
|
||||
|
||||
from .global_logger import logger
|
||||
from .utils.hash import get_sha256
|
||||
|
||||
|
||||
install(extra_lines=3)
|
||||
@@ -133,19 +134,20 @@ class EmbeddingStore:
|
||||
return [f"{namespace}-{get_sha256(t)}" for t in texts]
|
||||
|
||||
def _get_embedding(self, s: str) -> List[float]:
|
||||
"""获取字符串的嵌入向量,使用完全同步的方式避免事件循环问题"""
|
||||
# 创建新的事件循环并在完成后立即关闭
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
"""以同步方式获取单条字符串的嵌入向量。
|
||||
|
||||
Args:
|
||||
s: 待编码的文本内容。
|
||||
|
||||
Returns:
|
||||
List[float]: 嵌入向量;失败时返回空列表。
|
||||
"""
|
||||
try:
|
||||
# 创建新的服务层实例
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
|
||||
llm = LLMServiceClient(task_name="embedding", request_type="embedding")
|
||||
|
||||
# 使用新的事件循环运行异步方法
|
||||
embedding_result = loop.run_until_complete(llm.embed_text(s))
|
||||
embedding_client = EmbeddingServiceClient(
|
||||
task_name="embedding",
|
||||
request_type="embedding",
|
||||
)
|
||||
embedding_result = embedding_client.embed_text_sync(s)
|
||||
embedding = embedding_result.embedding
|
||||
|
||||
if embedding and len(embedding) > 0:
|
||||
@@ -157,17 +159,15 @@ class EmbeddingStore:
|
||||
except Exception as e:
|
||||
logger.error(f"获取嵌入时发生异常: {s}, 错误: {e}")
|
||||
return []
|
||||
finally:
|
||||
# 确保事件循环被正确关闭
|
||||
try:
|
||||
loop.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _get_embeddings_batch_threaded(
|
||||
self, strs: List[str], chunk_size: int = 10, max_workers: int = 10, progress_callback=None
|
||||
self,
|
||||
strs: List[str],
|
||||
chunk_size: int = 10,
|
||||
max_workers: int = 10,
|
||||
progress_callback: Callable[[int], None] | None = None,
|
||||
) -> List[Tuple[str, List[float]]]:
|
||||
"""使用多线程批量获取嵌入向量
|
||||
"""使用多线程批量获取嵌入向量。
|
||||
|
||||
Args:
|
||||
strs: 要获取嵌入的字符串列表
|
||||
@@ -190,53 +190,42 @@ class EmbeddingStore:
|
||||
# 结果存储,使用字典按索引存储以保证顺序
|
||||
results = {}
|
||||
|
||||
def process_chunk(chunk_data):
|
||||
"""处理单个数据块的函数"""
|
||||
start_idx, chunk_strs = chunk_data
|
||||
chunk_results = []
|
||||
def process_chunk(chunk_data: Tuple[int, List[str]]) -> List[Tuple[int, str, List[float]]]:
|
||||
"""处理单个数据块。
|
||||
|
||||
# 为每个线程创建独立的服务层实例
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
Args:
|
||||
chunk_data: 数据块起始索引与字符串列表。
|
||||
|
||||
Returns:
|
||||
List[Tuple[int, str, List[float]]]: 带原始索引的处理结果。
|
||||
"""
|
||||
start_idx, chunk_strs = chunk_data
|
||||
chunk_results: List[Tuple[int, str, List[float]]] = []
|
||||
|
||||
try:
|
||||
# 创建线程专用的服务层实例
|
||||
llm = LLMServiceClient(task_name="embedding", request_type="embedding")
|
||||
|
||||
for i, s in enumerate(chunk_strs):
|
||||
try:
|
||||
# 在线程中创建独立的事件循环
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
embedding_result = loop.run_until_complete(llm.embed_text(s))
|
||||
embedding = embedding_result.embedding
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
if embedding and len(embedding) > 0:
|
||||
chunk_results.append((start_idx + i, s, embedding[0])) # embedding[0] 是实际的向量
|
||||
else:
|
||||
logger.error(f"获取嵌入失败: {s}")
|
||||
chunk_results.append((start_idx + i, s, []))
|
||||
|
||||
# 每完成一个嵌入立即更新进度
|
||||
if progress_callback:
|
||||
progress_callback(1)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取嵌入时发生异常: {s}, 错误: {e}")
|
||||
embedding_client = EmbeddingServiceClient(
|
||||
task_name="embedding",
|
||||
request_type="embedding",
|
||||
)
|
||||
embedding_results = embedding_client.embed_texts_sync(
|
||||
chunk_strs,
|
||||
max_concurrent=1,
|
||||
)
|
||||
for i, (s, embedding_result) in enumerate(zip(chunk_strs, embedding_results, strict=False)):
|
||||
embedding = embedding_result.embedding
|
||||
if embedding and len(embedding) > 0:
|
||||
chunk_results.append((start_idx + i, s, embedding))
|
||||
else:
|
||||
logger.error(f"获取嵌入失败: {s}")
|
||||
chunk_results.append((start_idx + i, s, []))
|
||||
|
||||
# 即使失败也要更新进度
|
||||
if progress_callback:
|
||||
progress_callback(1)
|
||||
if progress_callback:
|
||||
progress_callback(1)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"创建LLM实例失败: {e}")
|
||||
# 如果创建LLM实例失败,返回空结果
|
||||
logger.error(f"创建 EmbeddingService 实例失败: {e}")
|
||||
for i, s in enumerate(chunk_strs):
|
||||
chunk_results.append((start_idx + i, s, []))
|
||||
# 即使失败也要更新进度
|
||||
if progress_callback:
|
||||
progress_callback(1)
|
||||
|
||||
|
||||
@@ -12,7 +12,6 @@ from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiv
|
||||
from src.common.logger import get_logger
|
||||
from src.common.utils.utils_message import MessageUtils
|
||||
from src.common.utils.utils_session import SessionUtils
|
||||
from src.config.config import global_config
|
||||
from src.platform_io.route_key_factory import RouteKeyFactory
|
||||
from src.core.announcement_manager import global_announcement_manager
|
||||
from src.plugin_runtime.component_query import component_query_service
|
||||
|
||||
@@ -1135,9 +1135,6 @@ class DefaultReplyer:
|
||||
return content, reasoning_content, model_name, tool_calls
|
||||
|
||||
async def get_prompt_info(self, message: str, sender: str, target: str):
|
||||
del message
|
||||
del sender
|
||||
del target
|
||||
return ""
|
||||
related_info = ""
|
||||
start_time = time.time()
|
||||
|
||||
@@ -1058,12 +1058,9 @@ class StatisticOutputTask(AsyncTask):
|
||||
from src.chat.message_receive.chat_manager import chat_manager as _stat_chat_manager
|
||||
|
||||
if chat_id in _stat_chat_manager.sessions:
|
||||
session = _stat_chat_manager.sessions[chat_id]
|
||||
name = _stat_chat_manager.get_session_name(chat_id)
|
||||
if name and name.strip():
|
||||
return name.strip()
|
||||
if user_name and user_name.strip():
|
||||
return user_name.strip()
|
||||
|
||||
# 如果从chat_stream获取失败,尝试解析chat_id格式
|
||||
if chat_id.startswith("g"):
|
||||
|
||||
@@ -14,8 +14,8 @@ from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
from src.person_info.person_info import Person
|
||||
from src.services.embedding_service import EmbeddingServiceClient
|
||||
|
||||
from .typo_generator import ChineseTypoGenerator
|
||||
|
||||
@@ -233,12 +233,19 @@ def is_mentioned_bot_in_message(message: SessionMessage) -> tuple[bool, bool, fl
|
||||
return is_mentioned, is_at, reply_probability
|
||||
|
||||
|
||||
async def get_embedding(text, request_type="embedding") -> Optional[List[float]]:
|
||||
"""获取文本的embedding向量"""
|
||||
# 每次都创建新的服务层实例以避免事件循环冲突
|
||||
llm = LLMServiceClient(task_name="embedding", request_type=request_type)
|
||||
async def get_embedding(text: str, request_type: str = "embedding") -> Optional[List[float]]:
|
||||
"""获取文本的嵌入向量。
|
||||
|
||||
Args:
|
||||
text: 待编码的文本内容。
|
||||
request_type: 当前请求的业务类型标识。
|
||||
|
||||
Returns:
|
||||
Optional[List[float]]: 成功时返回嵌入向量,失败时返回 `None`。
|
||||
"""
|
||||
embedding_client = EmbeddingServiceClient(task_name="embedding", request_type=request_type)
|
||||
try:
|
||||
embedding_result = await llm.embed_text(text)
|
||||
embedding_result = await embedding_client.embed_text(text)
|
||||
embedding = embedding_result.embedding
|
||||
except Exception as e:
|
||||
logger.error(f"获取embedding失败: {str(e)}")
|
||||
|
||||
Reference in New Issue
Block a user