Merge branch 'dev' of github.com:MaiM-with-u/MaiBot into dev
This commit is contained in:
@@ -708,7 +708,7 @@ class EmojiManager:
|
||||
if not emoji.is_deleted and emoji.hash == emoji_hash:
|
||||
return emoji
|
||||
return None # 如果循环结束还没找到,则返回 None
|
||||
|
||||
|
||||
async def get_emoji_tag_by_hash(self, emoji_hash: str) -> Optional[List[str]]:
|
||||
"""根据哈希值获取已注册表情包的情感标签列表
|
||||
|
||||
@@ -731,7 +731,7 @@ class EmojiManager:
|
||||
emoji_record = Emoji.get_or_none(Emoji.emoji_hash == emoji_hash)
|
||||
if emoji_record and emoji_record.emotion:
|
||||
logger.info(f"[缓存命中] 从数据库获取表情包情感标签: {emoji_record.emotion[:50]}...")
|
||||
return emoji_record.emotion.split(',')
|
||||
return emoji_record.emotion.split(",")
|
||||
except Exception as e:
|
||||
logger.error(f"从数据库查询表情包情感标签时出错: {e}")
|
||||
|
||||
|
||||
@@ -77,10 +77,10 @@ class ExpressionSelector:
|
||||
def can_use_expression_for_chat(self, chat_id: str) -> bool:
|
||||
"""
|
||||
检查指定聊天流是否允许使用表达
|
||||
|
||||
|
||||
Args:
|
||||
chat_id: 聊天流ID
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 是否允许使用表达
|
||||
"""
|
||||
@@ -123,9 +123,7 @@ class ExpressionSelector:
|
||||
return group_chat_ids
|
||||
return [chat_id]
|
||||
|
||||
def get_random_expressions(
|
||||
self, chat_id: str, total_num: int
|
||||
) -> List[Dict[str, Any]]:
|
||||
def get_random_expressions(self, chat_id: str, total_num: int) -> List[Dict[str, Any]]:
|
||||
# sourcery skip: extract-duplicate-method, move-assign
|
||||
# 支持多chat_id合并抽选
|
||||
related_chat_ids = self.get_related_chat_ids(chat_id)
|
||||
@@ -200,7 +198,7 @@ class ExpressionSelector:
|
||||
) -> Tuple[List[Dict[str, Any]], List[int]]:
|
||||
# sourcery skip: inline-variable, list-comprehension
|
||||
"""使用LLM选择适合的表达方式"""
|
||||
|
||||
|
||||
# 检查是否允许在此聊天流中使用表达
|
||||
if not self.can_use_expression_for_chat(chat_id):
|
||||
logger.debug(f"聊天流 {chat_id} 不允许使用表达,返回空列表")
|
||||
@@ -208,7 +206,7 @@ class ExpressionSelector:
|
||||
|
||||
# 1. 获取20个随机表达方式(现在按权重抽取)
|
||||
style_exprs = self.get_random_expressions(chat_id, 10)
|
||||
|
||||
|
||||
if len(style_exprs) < 10:
|
||||
logger.info(f"聊天流 {chat_id} 表达方式正在积累中")
|
||||
return [], []
|
||||
@@ -248,7 +246,6 @@ class ExpressionSelector:
|
||||
|
||||
# 4. 调用LLM
|
||||
try:
|
||||
|
||||
# start_time = time.time()
|
||||
content, (reasoning_content, model_name, _) = await self.llm_model.generate_response_async(prompt=prompt)
|
||||
# logger.info(f"LLM请求时间: {model_name} {time.time() - start_time} \n{prompt}")
|
||||
@@ -295,7 +292,6 @@ class ExpressionSelector:
|
||||
except Exception as e:
|
||||
logger.error(f"LLM处理表达方式选择时出错: {e}")
|
||||
return [], []
|
||||
|
||||
|
||||
|
||||
init_prompt()
|
||||
|
||||
@@ -119,4 +119,3 @@ def get_global_focus_value() -> Optional[float]:
|
||||
return get_time_based_focus_value(config_item[1:])
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@@ -124,5 +124,3 @@ def get_global_frequency() -> Optional[float]:
|
||||
return get_time_based_frequency(config_item[1:])
|
||||
|
||||
return None
|
||||
|
||||
|
||||
|
||||
@@ -34,4 +34,4 @@ def parse_stream_config_to_chat_id(stream_config_str: str) -> Optional[str]:
|
||||
return hashlib.md5(key.encode()).hexdigest()
|
||||
|
||||
except (ValueError, IndexError):
|
||||
return None
|
||||
return None
|
||||
|
||||
@@ -261,7 +261,11 @@ class HeartFChatting:
|
||||
|
||||
return loop_info, reply_text, cycle_timers
|
||||
|
||||
async def _observe(self, interest_value: float = 0.0, recent_messages_list: List["DatabaseMessages"] = []) -> bool:
|
||||
async def _observe(
|
||||
self, interest_value: float = 0.0, recent_messages_list: Optional[List["DatabaseMessages"]] = None
|
||||
) -> bool:
|
||||
if recent_messages_list is None:
|
||||
recent_messages_list = []
|
||||
reply_text = "" # 初始化reply_text变量,避免UnboundLocalError
|
||||
|
||||
# 使用sigmoid函数将interest_value转换为概率
|
||||
|
||||
@@ -3,14 +3,16 @@ from typing import Any, Optional, Dict
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.heart_flow.heartFC_chat import HeartFChatting
|
||||
|
||||
logger = get_logger("heartflow")
|
||||
|
||||
|
||||
class Heartflow:
|
||||
"""主心流协调器,负责初始化并协调聊天"""
|
||||
|
||||
def __init__(self):
|
||||
self.heartflow_chat_list: Dict[Any, HeartFChatting] = {}
|
||||
|
||||
|
||||
async def get_or_create_heartflow_chat(self, chat_id: Any) -> Optional[HeartFChatting]:
|
||||
"""获取或创建一个新的HeartFChatting实例"""
|
||||
try:
|
||||
@@ -18,7 +20,7 @@ class Heartflow:
|
||||
if chat := self.heartflow_chat_list.get(chat_id):
|
||||
return chat
|
||||
else:
|
||||
new_chat = HeartFChatting(chat_id = chat_id)
|
||||
new_chat = HeartFChatting(chat_id=chat_id)
|
||||
await new_chat.start()
|
||||
self.heartflow_chat_list[chat_id] = new_chat
|
||||
return new_chat
|
||||
@@ -27,4 +29,5 @@ class Heartflow:
|
||||
traceback.print_exc()
|
||||
return None
|
||||
|
||||
|
||||
heartflow = Heartflow()
|
||||
|
||||
@@ -23,6 +23,7 @@ if TYPE_CHECKING:
|
||||
|
||||
logger = get_logger("chat")
|
||||
|
||||
|
||||
async def _calculate_interest(message: MessageRecv) -> Tuple[float, list[str]]:
|
||||
"""计算消息的兴趣度
|
||||
|
||||
@@ -34,14 +35,14 @@ async def _calculate_interest(message: MessageRecv) -> Tuple[float, list[str]]:
|
||||
"""
|
||||
if message.is_picid or message.is_emoji:
|
||||
return 0.0, []
|
||||
|
||||
is_mentioned,is_at,reply_probability_boost = is_mentioned_bot_in_message(message)
|
||||
|
||||
is_mentioned, is_at, reply_probability_boost = is_mentioned_bot_in_message(message)
|
||||
interested_rate = 0.0
|
||||
|
||||
with Timer("记忆激活"):
|
||||
interested_rate, keywords,keywords_lite = await hippocampus_manager.get_activate_from_text(
|
||||
interested_rate, keywords, keywords_lite = await hippocampus_manager.get_activate_from_text(
|
||||
message.processed_plain_text,
|
||||
max_depth= 4,
|
||||
max_depth=4,
|
||||
fast_retrieval=global_config.chat.interest_rate_mode == "fast",
|
||||
)
|
||||
message.key_words = keywords
|
||||
@@ -51,7 +52,7 @@ async def _calculate_interest(message: MessageRecv) -> Tuple[float, list[str]]:
|
||||
text_len = len(message.processed_plain_text)
|
||||
# 根据文本长度分布调整兴趣度,采用分段函数实现更精确的兴趣度计算
|
||||
# 基于实际分布:0-5字符(26.57%), 6-10字符(27.18%), 11-20字符(22.76%), 21-30字符(10.33%), 31+字符(13.86%)
|
||||
|
||||
|
||||
if text_len == 0:
|
||||
base_interest = 0.01 # 空消息最低兴趣度
|
||||
elif text_len <= 5:
|
||||
@@ -75,16 +76,15 @@ async def _calculate_interest(message: MessageRecv) -> Tuple[float, list[str]]:
|
||||
else:
|
||||
# 100+字符:对数增长 0.26 -> 0.3,增长率递减
|
||||
base_interest = 0.26 + (0.3 - 0.26) * (math.log10(text_len - 99) / math.log10(901)) # 1000-99=901
|
||||
|
||||
|
||||
# 确保在范围内
|
||||
base_interest = min(max(base_interest, 0.01), 0.3)
|
||||
|
||||
|
||||
message.interest_value = base_interest
|
||||
message.is_mentioned = is_mentioned
|
||||
message.is_at = is_at
|
||||
message.reply_probability_boost = reply_probability_boost
|
||||
|
||||
|
||||
return base_interest, keywords
|
||||
|
||||
|
||||
@@ -115,14 +115,13 @@ class HeartFCMessageReceiver:
|
||||
|
||||
# 2. 兴趣度计算与更新
|
||||
interested_rate, keywords = await _calculate_interest(message)
|
||||
|
||||
|
||||
await self.storage.store_message(message, chat)
|
||||
|
||||
heartflow_chat: HeartFChatting = await heartflow.get_or_create_heartflow_chat(chat.stream_id) # type: ignore
|
||||
|
||||
# subheartflow.add_message_to_normal_chat_cache(message, interested_rate, is_mentioned)
|
||||
if global_config.mood.enable_mood:
|
||||
if global_config.mood.enable_mood:
|
||||
chat_mood = mood_manager.get_mood_by_chat_id(heartflow_chat.stream_id)
|
||||
asyncio.create_task(chat_mood.update_mood_by_message(message, interested_rate))
|
||||
|
||||
@@ -132,7 +131,7 @@ class HeartFCMessageReceiver:
|
||||
# 用这个pattern截取出id部分,picid是一个list,并替换成对应的图片描述
|
||||
picid_pattern = r"\[picid:([^\]]+)\]"
|
||||
picid_list = re.findall(picid_pattern, message.processed_plain_text)
|
||||
|
||||
|
||||
# 创建替换后的文本
|
||||
processed_text = message.processed_plain_text
|
||||
if picid_list:
|
||||
@@ -145,18 +144,20 @@ class HeartFCMessageReceiver:
|
||||
# 如果没有找到图片描述,则移除[picid:xxxx]标记
|
||||
processed_text = processed_text.replace(f"[picid:{picid}]", "[图片:网络不好,图片无法加载]")
|
||||
|
||||
|
||||
# 应用用户引用格式替换,将回复<aaa:bbb>和@<aaa:bbb>格式转换为可读格式
|
||||
processed_plain_text = replace_user_references(
|
||||
processed_text,
|
||||
message.message_info.platform, # type: ignore
|
||||
replace_bot_name=True
|
||||
message.message_info.platform, # type: ignore
|
||||
replace_bot_name=True,
|
||||
)
|
||||
|
||||
|
||||
logger.info(f"[{mes_name}]{userinfo.user_nickname}:{processed_plain_text}[{interested_rate:.2f}]") # type: ignore
|
||||
|
||||
_ = Person.register_person(platform=message.message_info.platform, user_id=message.message_info.user_info.user_id,nickname=userinfo.user_nickname) # type: ignore
|
||||
_ = Person.register_person(
|
||||
platform=message.message_info.platform,
|
||||
user_id=message.message_info.user_info.user_id,
|
||||
nickname=userinfo.user_nickname,
|
||||
) # type: ignore
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"消息处理失败: {e}")
|
||||
|
||||
@@ -124,6 +124,7 @@ async def send_typing():
|
||||
message_type="state", content="typing", stream_id=chat.stream_id, storage_message=False
|
||||
)
|
||||
|
||||
|
||||
async def stop_typing():
|
||||
group_info = GroupInfo(platform="amaidesu_default", group_id="114514", group_name="内心")
|
||||
|
||||
@@ -135,4 +136,4 @@ async def stop_typing():
|
||||
|
||||
await send_api.custom_to_stream(
|
||||
message_type="state", content="stop_typing", stream_id=chat.stream_id, storage_message=False
|
||||
)
|
||||
)
|
||||
|
||||
@@ -30,6 +30,7 @@ DATA_PATH = os.path.join(ROOT_PATH, "data")
|
||||
qa_manager = None
|
||||
inspire_manager = None
|
||||
|
||||
|
||||
def lpmm_start_up(): # sourcery skip: extract-duplicate-method
|
||||
# 检查LPMM知识库是否启用
|
||||
if global_config.lpmm_knowledge.enable:
|
||||
|
||||
@@ -32,11 +32,11 @@ install(extra_lines=3)
|
||||
|
||||
# 多线程embedding配置常量
|
||||
DEFAULT_MAX_WORKERS = 10 # 默认最大线程数
|
||||
DEFAULT_CHUNK_SIZE = 10 # 默认每个线程处理的数据块大小
|
||||
MIN_CHUNK_SIZE = 1 # 最小分块大小
|
||||
MAX_CHUNK_SIZE = 50 # 最大分块大小
|
||||
MIN_WORKERS = 1 # 最小线程数
|
||||
MAX_WORKERS = 20 # 最大线程数
|
||||
DEFAULT_CHUNK_SIZE = 10 # 默认每个线程处理的数据块大小
|
||||
MIN_CHUNK_SIZE = 1 # 最小分块大小
|
||||
MAX_CHUNK_SIZE = 50 # 最大分块大小
|
||||
MIN_WORKERS = 1 # 最小线程数
|
||||
MAX_WORKERS = 20 # 最大线程数
|
||||
|
||||
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
|
||||
EMBEDDING_DATA_DIR = os.path.join(ROOT_PATH, "data", "embedding")
|
||||
@@ -93,7 +93,13 @@ class EmbeddingStoreItem:
|
||||
|
||||
|
||||
class EmbeddingStore:
|
||||
def __init__(self, namespace: str, dir_path: str, max_workers: int = DEFAULT_MAX_WORKERS, chunk_size: int = DEFAULT_CHUNK_SIZE):
|
||||
def __init__(
|
||||
self,
|
||||
namespace: str,
|
||||
dir_path: str,
|
||||
max_workers: int = DEFAULT_MAX_WORKERS,
|
||||
chunk_size: int = DEFAULT_CHUNK_SIZE,
|
||||
):
|
||||
self.namespace = namespace
|
||||
self.dir = dir_path
|
||||
self.embedding_file_path = f"{dir_path}/{namespace}.parquet"
|
||||
@@ -103,12 +109,16 @@ class EmbeddingStore:
|
||||
# 多线程配置参数验证和设置
|
||||
self.max_workers = max(MIN_WORKERS, min(MAX_WORKERS, max_workers))
|
||||
self.chunk_size = max(MIN_CHUNK_SIZE, min(MAX_CHUNK_SIZE, chunk_size))
|
||||
|
||||
|
||||
# 如果配置值被调整,记录日志
|
||||
if self.max_workers != max_workers:
|
||||
logger.warning(f"max_workers 已从 {max_workers} 调整为 {self.max_workers} (范围: {MIN_WORKERS}-{MAX_WORKERS})")
|
||||
logger.warning(
|
||||
f"max_workers 已从 {max_workers} 调整为 {self.max_workers} (范围: {MIN_WORKERS}-{MAX_WORKERS})"
|
||||
)
|
||||
if self.chunk_size != chunk_size:
|
||||
logger.warning(f"chunk_size 已从 {chunk_size} 调整为 {self.chunk_size} (范围: {MIN_CHUNK_SIZE}-{MAX_CHUNK_SIZE})")
|
||||
logger.warning(
|
||||
f"chunk_size 已从 {chunk_size} 调整为 {self.chunk_size} (范围: {MIN_CHUNK_SIZE}-{MAX_CHUNK_SIZE})"
|
||||
)
|
||||
|
||||
self.store = {}
|
||||
|
||||
@@ -120,23 +130,23 @@ class EmbeddingStore:
|
||||
# 创建新的事件循环并在完成后立即关闭
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
|
||||
try:
|
||||
# 创建新的LLMRequest实例
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config
|
||||
|
||||
|
||||
llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding")
|
||||
|
||||
|
||||
# 使用新的事件循环运行异步方法
|
||||
embedding, _ = loop.run_until_complete(llm.get_embedding(s))
|
||||
|
||||
|
||||
if embedding and len(embedding) > 0:
|
||||
return embedding
|
||||
else:
|
||||
logger.error(f"获取嵌入失败: {s}")
|
||||
return []
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取嵌入时发生异常: {s}, 错误: {e}")
|
||||
return []
|
||||
@@ -147,43 +157,45 @@ class EmbeddingStore:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _get_embeddings_batch_threaded(self, strs: List[str], chunk_size: int = 10, max_workers: int = 10, progress_callback=None) -> List[Tuple[str, List[float]]]:
|
||||
def _get_embeddings_batch_threaded(
|
||||
self, strs: List[str], chunk_size: int = 10, max_workers: int = 10, progress_callback=None
|
||||
) -> List[Tuple[str, List[float]]]:
|
||||
"""使用多线程批量获取嵌入向量
|
||||
|
||||
|
||||
Args:
|
||||
strs: 要获取嵌入的字符串列表
|
||||
chunk_size: 每个线程处理的数据块大小
|
||||
max_workers: 最大线程数
|
||||
progress_callback: 进度回调函数,接收一个参数表示完成的数量
|
||||
|
||||
|
||||
Returns:
|
||||
包含(原始字符串, 嵌入向量)的元组列表,保持与输入顺序一致
|
||||
"""
|
||||
if not strs:
|
||||
return []
|
||||
|
||||
|
||||
# 分块
|
||||
chunks = []
|
||||
for i in range(0, len(strs), chunk_size):
|
||||
chunk = strs[i:i + chunk_size]
|
||||
chunk = strs[i : i + chunk_size]
|
||||
chunks.append((i, chunk)) # 保存起始索引以维持顺序
|
||||
|
||||
|
||||
# 结果存储,使用字典按索引存储以保证顺序
|
||||
results = {}
|
||||
|
||||
|
||||
def process_chunk(chunk_data):
|
||||
"""处理单个数据块的函数"""
|
||||
start_idx, chunk_strs = chunk_data
|
||||
chunk_results = []
|
||||
|
||||
|
||||
# 为每个线程创建独立的LLMRequest实例
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config
|
||||
|
||||
|
||||
try:
|
||||
# 创建线程专用的LLM实例
|
||||
llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding")
|
||||
|
||||
|
||||
for i, s in enumerate(chunk_strs):
|
||||
try:
|
||||
# 在线程中创建独立的事件循环
|
||||
@@ -193,25 +205,25 @@ class EmbeddingStore:
|
||||
embedding = loop.run_until_complete(llm.get_embedding(s))
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
if embedding and len(embedding) > 0:
|
||||
chunk_results.append((start_idx + i, s, embedding[0])) # embedding[0] 是实际的向量
|
||||
else:
|
||||
logger.error(f"获取嵌入失败: {s}")
|
||||
chunk_results.append((start_idx + i, s, []))
|
||||
|
||||
|
||||
# 每完成一个嵌入立即更新进度
|
||||
if progress_callback:
|
||||
progress_callback(1)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取嵌入时发生异常: {s}, 错误: {e}")
|
||||
chunk_results.append((start_idx + i, s, []))
|
||||
|
||||
|
||||
# 即使失败也要更新进度
|
||||
if progress_callback:
|
||||
progress_callback(1)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"创建LLM实例失败: {e}")
|
||||
# 如果创建LLM实例失败,返回空结果
|
||||
@@ -220,14 +232,14 @@ class EmbeddingStore:
|
||||
# 即使失败也要更新进度
|
||||
if progress_callback:
|
||||
progress_callback(1)
|
||||
|
||||
|
||||
return chunk_results
|
||||
|
||||
|
||||
# 使用线程池处理
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
# 提交所有任务
|
||||
future_to_chunk = {executor.submit(process_chunk, chunk): chunk for chunk in chunks}
|
||||
|
||||
|
||||
# 收集结果(进度已在process_chunk中实时更新)
|
||||
for future in as_completed(future_to_chunk):
|
||||
try:
|
||||
@@ -241,7 +253,7 @@ class EmbeddingStore:
|
||||
start_idx, chunk_strs = chunk
|
||||
for i, s in enumerate(chunk_strs):
|
||||
results[start_idx + i] = (s, [])
|
||||
|
||||
|
||||
# 按原始顺序返回结果
|
||||
ordered_results = []
|
||||
for i in range(len(strs)):
|
||||
@@ -250,7 +262,7 @@ class EmbeddingStore:
|
||||
else:
|
||||
# 防止遗漏
|
||||
ordered_results.append((strs[i], []))
|
||||
|
||||
|
||||
return ordered_results
|
||||
|
||||
def get_test_file_path(self):
|
||||
@@ -259,14 +271,14 @@ class EmbeddingStore:
|
||||
def save_embedding_test_vectors(self):
|
||||
"""保存测试字符串的嵌入到本地(使用多线程优化)"""
|
||||
logger.info("开始保存测试字符串的嵌入向量...")
|
||||
|
||||
|
||||
# 使用多线程批量获取测试字符串的嵌入
|
||||
embedding_results = self._get_embeddings_batch_threaded(
|
||||
EMBEDDING_TEST_STRINGS,
|
||||
chunk_size=min(self.chunk_size, len(EMBEDDING_TEST_STRINGS)),
|
||||
max_workers=min(self.max_workers, len(EMBEDDING_TEST_STRINGS))
|
||||
max_workers=min(self.max_workers, len(EMBEDDING_TEST_STRINGS)),
|
||||
)
|
||||
|
||||
|
||||
# 构建测试向量字典
|
||||
test_vectors = {}
|
||||
for idx, (s, embedding) in enumerate(embedding_results):
|
||||
@@ -276,10 +288,10 @@ class EmbeddingStore:
|
||||
logger.error(f"获取测试字符串嵌入失败: {s}")
|
||||
# 使用原始单线程方法作为后备
|
||||
test_vectors[str(idx)] = self._get_embedding(s)
|
||||
|
||||
|
||||
with open(self.get_test_file_path(), "w", encoding="utf-8") as f:
|
||||
json.dump(test_vectors, f, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
logger.info("测试字符串嵌入向量保存完成")
|
||||
|
||||
def load_embedding_test_vectors(self):
|
||||
@@ -297,35 +309,35 @@ class EmbeddingStore:
|
||||
logger.warning("未检测到本地嵌入模型测试文件,将保存当前模型的测试嵌入。")
|
||||
self.save_embedding_test_vectors()
|
||||
return True
|
||||
|
||||
|
||||
# 检查本地向量完整性
|
||||
for idx in range(len(EMBEDDING_TEST_STRINGS)):
|
||||
if local_vectors.get(str(idx)) is None:
|
||||
logger.warning("本地嵌入模型测试文件缺失部分测试字符串,将重新保存。")
|
||||
self.save_embedding_test_vectors()
|
||||
return True
|
||||
|
||||
|
||||
logger.info("开始检验嵌入模型一致性...")
|
||||
|
||||
|
||||
# 使用多线程批量获取当前模型的嵌入
|
||||
embedding_results = self._get_embeddings_batch_threaded(
|
||||
EMBEDDING_TEST_STRINGS,
|
||||
chunk_size=min(self.chunk_size, len(EMBEDDING_TEST_STRINGS)),
|
||||
max_workers=min(self.max_workers, len(EMBEDDING_TEST_STRINGS))
|
||||
max_workers=min(self.max_workers, len(EMBEDDING_TEST_STRINGS)),
|
||||
)
|
||||
|
||||
|
||||
# 检查一致性
|
||||
for idx, (s, new_emb) in enumerate(embedding_results):
|
||||
local_emb = local_vectors.get(str(idx))
|
||||
if not new_emb:
|
||||
logger.error(f"获取测试字符串嵌入失败: {s}")
|
||||
return False
|
||||
|
||||
|
||||
sim = cosine_similarity(local_emb, new_emb)
|
||||
if sim < EMBEDDING_SIM_THRESHOLD:
|
||||
logger.error(f"嵌入模型一致性校验失败,字符串: {s}, 相似度: {sim:.4f}")
|
||||
return False
|
||||
|
||||
|
||||
logger.info("嵌入模型一致性校验通过。")
|
||||
return True
|
||||
|
||||
@@ -333,22 +345,22 @@ class EmbeddingStore:
|
||||
"""向库中存入字符串(使用多线程优化)"""
|
||||
if not strs:
|
||||
return
|
||||
|
||||
|
||||
total = len(strs)
|
||||
|
||||
|
||||
# 过滤已存在的字符串
|
||||
new_strs = []
|
||||
for s in strs:
|
||||
item_hash = self.namespace + "-" + get_sha256(s)
|
||||
if item_hash not in self.store:
|
||||
new_strs.append(s)
|
||||
|
||||
|
||||
if not new_strs:
|
||||
logger.info(f"所有字符串已存在于{self.namespace}嵌入库中,跳过处理")
|
||||
return
|
||||
|
||||
|
||||
logger.info(f"需要处理 {len(new_strs)}/{total} 个新字符串")
|
||||
|
||||
|
||||
with Progress(
|
||||
SpinnerColumn(),
|
||||
TextColumn("[progress.description]{task.description}"),
|
||||
@@ -362,31 +374,39 @@ class EmbeddingStore:
|
||||
transient=False,
|
||||
) as progress:
|
||||
task = progress.add_task(f"存入嵌入库:({times}/{TOTAL_EMBEDDING_TIMES})", total=total)
|
||||
|
||||
|
||||
# 首先更新已存在项的进度
|
||||
already_processed = total - len(new_strs)
|
||||
if already_processed > 0:
|
||||
progress.update(task, advance=already_processed)
|
||||
|
||||
|
||||
if new_strs:
|
||||
# 使用实例配置的参数,智能调整分块和线程数
|
||||
optimal_chunk_size = max(MIN_CHUNK_SIZE, min(self.chunk_size, len(new_strs) // self.max_workers if self.max_workers > 0 else self.chunk_size))
|
||||
optimal_max_workers = min(self.max_workers, max(MIN_WORKERS, len(new_strs) // optimal_chunk_size if optimal_chunk_size > 0 else 1))
|
||||
|
||||
optimal_chunk_size = max(
|
||||
MIN_CHUNK_SIZE,
|
||||
min(
|
||||
self.chunk_size, len(new_strs) // self.max_workers if self.max_workers > 0 else self.chunk_size
|
||||
),
|
||||
)
|
||||
optimal_max_workers = min(
|
||||
self.max_workers,
|
||||
max(MIN_WORKERS, len(new_strs) // optimal_chunk_size if optimal_chunk_size > 0 else 1),
|
||||
)
|
||||
|
||||
logger.debug(f"使用多线程处理: chunk_size={optimal_chunk_size}, max_workers={optimal_max_workers}")
|
||||
|
||||
|
||||
# 定义进度更新回调函数
|
||||
def update_progress(count):
|
||||
progress.update(task, advance=count)
|
||||
|
||||
|
||||
# 批量获取嵌入,并实时更新进度
|
||||
embedding_results = self._get_embeddings_batch_threaded(
|
||||
new_strs,
|
||||
chunk_size=optimal_chunk_size,
|
||||
new_strs,
|
||||
chunk_size=optimal_chunk_size,
|
||||
max_workers=optimal_max_workers,
|
||||
progress_callback=update_progress
|
||||
progress_callback=update_progress,
|
||||
)
|
||||
|
||||
|
||||
# 存入结果(不再需要在这里更新进度,因为已经在回调中更新了)
|
||||
for s, embedding in embedding_results:
|
||||
item_hash = self.namespace + "-" + get_sha256(s)
|
||||
@@ -519,7 +539,7 @@ class EmbeddingManager:
|
||||
def __init__(self, max_workers: int = DEFAULT_MAX_WORKERS, chunk_size: int = DEFAULT_CHUNK_SIZE):
|
||||
"""
|
||||
初始化EmbeddingManager
|
||||
|
||||
|
||||
Args:
|
||||
max_workers: 最大线程数
|
||||
chunk_size: 每个线程处理的数据块大小
|
||||
|
||||
@@ -426,9 +426,7 @@ class KGManager:
|
||||
# 获取最终结果
|
||||
# 从搜索结果中提取文段节点的结果
|
||||
passage_node_res = [
|
||||
(node_key, score)
|
||||
for node_key, score in ppr_res.items()
|
||||
if node_key.startswith("paragraph")
|
||||
(node_key, score) for node_key, score in ppr_res.items() if node_key.startswith("paragraph")
|
||||
]
|
||||
del ppr_res
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
raise DeprecationWarning("MemoryActiveManager is not used yet, please do not import it")
|
||||
from .lpmmconfig import global_config
|
||||
from .embedding_store import EmbeddingManager
|
||||
from .llm_client import LLMClient
|
||||
from .utils.dyn_topk import dyn_select_top_k
|
||||
from .lpmmconfig import global_config # noqa
|
||||
from .embedding_store import EmbeddingManager # noqa
|
||||
from .llm_client import LLMClient # noqa
|
||||
from .utils.dyn_topk import dyn_select_top_k # noqa
|
||||
|
||||
|
||||
class MemoryActiveManager:
|
||||
|
||||
@@ -8,7 +8,7 @@ def dyn_select_top_k(
|
||||
# 检查输入列表是否为空
|
||||
if not score:
|
||||
return []
|
||||
|
||||
|
||||
# 按照分数排序(降序)
|
||||
sorted_score = sorted(score, key=lambda x: x[1], reverse=True)
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ class MessageStorage:
|
||||
if isinstance(keywords, list):
|
||||
return json.dumps(keywords, ensure_ascii=False)
|
||||
return "[]"
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _deserialize_keywords(keywords_str: str) -> list:
|
||||
"""将JSON字符串反序列化为关键词列表"""
|
||||
@@ -85,7 +85,7 @@ class MessageStorage:
|
||||
key_words = MessageStorage._serialize_keywords(message.key_words)
|
||||
key_words_lite = MessageStorage._serialize_keywords(message.key_words_lite)
|
||||
selected_expressions = ""
|
||||
|
||||
|
||||
chat_info_dict = chat_stream.to_dict()
|
||||
user_info_dict = message.message_info.user_info.to_dict() # type: ignore
|
||||
|
||||
|
||||
@@ -124,4 +124,4 @@ class ActionManager:
|
||||
"""恢复到默认动作集"""
|
||||
actions_to_restore = list(self._using_actions.keys())
|
||||
self._using_actions = component_registry.get_default_actions()
|
||||
logger.debug(f"恢复动作集: 从 {actions_to_restore} 恢复到默认动作集 {list(self._using_actions.keys())}")
|
||||
logger.debug(f"恢复动作集: 从 {actions_to_restore} 恢复到默认动作集 {list(self._using_actions.keys())}")
|
||||
|
||||
@@ -103,25 +103,23 @@ class ActionModifier:
|
||||
self.action_manager.remove_action_from_using(action_name)
|
||||
logger.debug(f"{self.log_prefix}阶段二移除动作: {action_name},原因: {reason}")
|
||||
|
||||
|
||||
|
||||
# === 第三阶段:激活类型判定 ===
|
||||
# if chat_content is not None:
|
||||
# logger.debug(f"{self.log_prefix}开始激活类型判定阶段")
|
||||
# logger.debug(f"{self.log_prefix}开始激活类型判定阶段")
|
||||
|
||||
# 获取当前使用的动作集(经过第一阶段处理)
|
||||
# current_using_actions = self.action_manager.get_using_actions()
|
||||
# 获取当前使用的动作集(经过第一阶段处理)
|
||||
# current_using_actions = self.action_manager.get_using_actions()
|
||||
|
||||
# 获取因激活类型判定而需要移除的动作
|
||||
# removals_s3 = await self._get_deactivated_actions_by_type(
|
||||
# current_using_actions,
|
||||
# chat_content,
|
||||
# )
|
||||
# 获取因激活类型判定而需要移除的动作
|
||||
# removals_s3 = await self._get_deactivated_actions_by_type(
|
||||
# current_using_actions,
|
||||
# chat_content,
|
||||
# )
|
||||
|
||||
# 应用第三阶段的移除
|
||||
# for action_name, reason in removals_s3:
|
||||
# self.action_manager.remove_action_from_using(action_name)
|
||||
# logger.debug(f"{self.log_prefix}阶段三移除动作: {action_name},原因: {reason}")
|
||||
# 应用第三阶段的移除
|
||||
# for action_name, reason in removals_s3:
|
||||
# self.action_manager.remove_action_from_using(action_name)
|
||||
# logger.debug(f"{self.log_prefix}阶段三移除动作: {action_name},原因: {reason}")
|
||||
|
||||
# === 统一日志记录 ===
|
||||
all_removals = removals_s1 + removals_s2
|
||||
@@ -131,9 +129,7 @@ class ActionModifier:
|
||||
|
||||
available_actions = list(self.action_manager.get_using_actions().keys())
|
||||
available_actions_text = "、".join(available_actions) if available_actions else "无"
|
||||
logger.debug(
|
||||
f"{self.log_prefix} 当前可用动作: {available_actions_text}||移除: {removals_summary}"
|
||||
)
|
||||
logger.debug(f"{self.log_prefix} 当前可用动作: {available_actions_text}||移除: {removals_summary}")
|
||||
|
||||
def _check_action_associated_types(self, all_actions: Dict[str, ActionInfo], chat_context: ChatMessageContext):
|
||||
type_mismatched_actions: List[Tuple[str, str]] = []
|
||||
|
||||
@@ -385,18 +385,18 @@ class StatisticOutputTask(AsyncTask):
|
||||
time_cost_key = f"time_costs_by_{category.split('_')[-1]}"
|
||||
avg_key = f"avg_time_costs_by_{category.split('_')[-1]}"
|
||||
std_key = f"std_time_costs_by_{category.split('_')[-1]}"
|
||||
|
||||
|
||||
for item_name in stats[period_key][category]:
|
||||
time_costs = stats[period_key][time_cost_key].get(item_name, [])
|
||||
if time_costs:
|
||||
# 计算平均耗时
|
||||
avg_time_cost = sum(time_costs) / len(time_costs)
|
||||
stats[period_key][avg_key][item_name] = round(avg_time_cost, 3)
|
||||
|
||||
|
||||
# 计算标准差
|
||||
if len(time_costs) > 1:
|
||||
variance = sum((x - avg_time_cost) ** 2 for x in time_costs) / len(time_costs)
|
||||
std_time_cost = variance ** 0.5
|
||||
std_time_cost = variance**0.5
|
||||
stats[period_key][std_key][item_name] = round(std_time_cost, 3)
|
||||
else:
|
||||
stats[period_key][std_key][item_name] = 0.0
|
||||
@@ -506,8 +506,6 @@ class StatisticOutputTask(AsyncTask):
|
||||
break
|
||||
return stats
|
||||
|
||||
|
||||
|
||||
def _collect_all_statistics(self, now: datetime) -> Dict[str, Dict[str, Any]]:
|
||||
"""
|
||||
收集各时间段的统计数据
|
||||
@@ -639,7 +637,9 @@ class StatisticOutputTask(AsyncTask):
|
||||
cost = stats[COST_BY_MODEL][model_name]
|
||||
avg_time_cost = stats[AVG_TIME_COST_BY_MODEL][model_name]
|
||||
std_time_cost = stats[STD_TIME_COST_BY_MODEL][model_name]
|
||||
output.append(data_fmt.format(name, count, in_tokens, out_tokens, tokens, cost, avg_time_cost, std_time_cost))
|
||||
output.append(
|
||||
data_fmt.format(name, count, in_tokens, out_tokens, tokens, cost, avg_time_cost, std_time_cost)
|
||||
)
|
||||
|
||||
output.append("")
|
||||
return "\n".join(output)
|
||||
@@ -728,7 +728,9 @@ class StatisticOutputTask(AsyncTask):
|
||||
f"<td>{stat_data[STD_TIME_COST_BY_MODEL][model_name]:.1f} 秒</td>"
|
||||
f"</tr>"
|
||||
for model_name, count in sorted(stat_data[REQ_CNT_BY_MODEL].items())
|
||||
] if stat_data[REQ_CNT_BY_MODEL] else ["<tr><td colspan='8' style='text-align: center; color: #999;'>暂无数据</td></tr>"]
|
||||
]
|
||||
if stat_data[REQ_CNT_BY_MODEL]
|
||||
else ["<tr><td colspan='8' style='text-align: center; color: #999;'>暂无数据</td></tr>"]
|
||||
)
|
||||
# 按请求类型分类统计
|
||||
type_rows = "\n".join(
|
||||
@@ -744,7 +746,9 @@ class StatisticOutputTask(AsyncTask):
|
||||
f"<td>{stat_data[STD_TIME_COST_BY_TYPE][req_type]:.1f} 秒</td>"
|
||||
f"</tr>"
|
||||
for req_type, count in sorted(stat_data[REQ_CNT_BY_TYPE].items())
|
||||
] if stat_data[REQ_CNT_BY_TYPE] else ["<tr><td colspan='8' style='text-align: center; color: #999;'>暂无数据</td></tr>"]
|
||||
]
|
||||
if stat_data[REQ_CNT_BY_TYPE]
|
||||
else ["<tr><td colspan='8' style='text-align: center; color: #999;'>暂无数据</td></tr>"]
|
||||
)
|
||||
# 按模块分类统计
|
||||
module_rows = "\n".join(
|
||||
@@ -760,7 +764,9 @@ class StatisticOutputTask(AsyncTask):
|
||||
f"<td>{stat_data[STD_TIME_COST_BY_MODULE][module_name]:.1f} 秒</td>"
|
||||
f"</tr>"
|
||||
for module_name, count in sorted(stat_data[REQ_CNT_BY_MODULE].items())
|
||||
] if stat_data[REQ_CNT_BY_MODULE] else ["<tr><td colspan='8' style='text-align: center; color: #999;'>暂无数据</td></tr>"]
|
||||
]
|
||||
if stat_data[REQ_CNT_BY_MODULE]
|
||||
else ["<tr><td colspan='8' style='text-align: center; color: #999;'>暂无数据</td></tr>"]
|
||||
)
|
||||
|
||||
# 聊天消息统计
|
||||
@@ -768,7 +774,9 @@ class StatisticOutputTask(AsyncTask):
|
||||
[
|
||||
f"<tr><td>{self.name_mapping[chat_id][0]}</td><td>{count}</td></tr>"
|
||||
for chat_id, count in sorted(stat_data[MSG_CNT_BY_CHAT].items())
|
||||
] if stat_data[MSG_CNT_BY_CHAT] else ["<tr><td colspan='2' style='text-align: center; color: #999;'>暂无数据</td></tr>"]
|
||||
]
|
||||
if stat_data[MSG_CNT_BY_CHAT]
|
||||
else ["<tr><td colspan='2' style='text-align: center; color: #999;'>暂无数据</td></tr>"]
|
||||
)
|
||||
# 生成HTML
|
||||
return f"""
|
||||
|
||||
@@ -49,9 +49,9 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, bool, float
|
||||
reply_probability = 0.0
|
||||
is_at = False
|
||||
is_mentioned = False
|
||||
|
||||
|
||||
# 这部分怎么处理啊啊啊啊
|
||||
#我觉得可以给消息加一个 reply_probability_boost字段
|
||||
# 我觉得可以给消息加一个 reply_probability_boost字段
|
||||
if (
|
||||
message.message_info.additional_config is not None
|
||||
and message.message_info.additional_config.get("is_mentioned") is not None
|
||||
@@ -826,20 +826,48 @@ def parse_keywords_string(keywords_input) -> list[str]:
|
||||
return [keywords_str] if keywords_str else []
|
||||
|
||||
|
||||
|
||||
|
||||
def cut_key_words(concept_name: str) -> list[str]:
|
||||
"""对概念名称进行jieba分词,并过滤掉关键词列表中的关键词"""
|
||||
concept_name_tokens = list(jieba.cut(concept_name))
|
||||
|
||||
# 定义常见连词、停用词与标点
|
||||
conjunctions = {
|
||||
"和", "与", "及", "跟", "以及", "并且", "而且", "或", "或者", "并"
|
||||
}
|
||||
conjunctions = {"和", "与", "及", "跟", "以及", "并且", "而且", "或", "或者", "并"}
|
||||
stop_words = {
|
||||
"的", "了", "呢", "吗", "吧", "啊", "哦", "恩", "嗯", "呀", "嘛", "哇",
|
||||
"在", "是", "很", "也", "又", "就", "都", "还", "更", "最", "被", "把",
|
||||
"给", "对", "和", "与", "及", "跟", "并", "而且", "或者", "或", "以及"
|
||||
"的",
|
||||
"了",
|
||||
"呢",
|
||||
"吗",
|
||||
"吧",
|
||||
"啊",
|
||||
"哦",
|
||||
"恩",
|
||||
"嗯",
|
||||
"呀",
|
||||
"嘛",
|
||||
"哇",
|
||||
"在",
|
||||
"是",
|
||||
"很",
|
||||
"也",
|
||||
"又",
|
||||
"就",
|
||||
"都",
|
||||
"还",
|
||||
"更",
|
||||
"最",
|
||||
"被",
|
||||
"把",
|
||||
"给",
|
||||
"对",
|
||||
"和",
|
||||
"与",
|
||||
"及",
|
||||
"跟",
|
||||
"并",
|
||||
"而且",
|
||||
"或者",
|
||||
"或",
|
||||
"以及",
|
||||
}
|
||||
chinese_punctuations = set(",。!?、;:()【】《》“”‘’—…·-——,.!?;:()[]<>'\"/\\")
|
||||
|
||||
@@ -864,11 +892,16 @@ def cut_key_words(concept_name: str) -> list[str]:
|
||||
left = merged_tokens[-1]
|
||||
right = cleaned_tokens[i + 1]
|
||||
# 左右都需要是有效词
|
||||
if left and right \
|
||||
and left not in conjunctions and right not in conjunctions \
|
||||
and left not in stop_words and right not in stop_words \
|
||||
and not all(ch in chinese_punctuations for ch in left) \
|
||||
and not all(ch in chinese_punctuations for ch in right):
|
||||
if (
|
||||
left
|
||||
and right
|
||||
and left not in conjunctions
|
||||
and right not in conjunctions
|
||||
and left not in stop_words
|
||||
and right not in stop_words
|
||||
and not all(ch in chinese_punctuations for ch in left)
|
||||
and not all(ch in chinese_punctuations for ch in right)
|
||||
):
|
||||
# 合并为一个新词,并替换掉左侧与跳过右侧
|
||||
combined = f"{left}{tok}{right}"
|
||||
merged_tokens[-1] = combined
|
||||
@@ -889,7 +922,7 @@ def cut_key_words(concept_name: str) -> list[str]:
|
||||
if tok in stop_words:
|
||||
continue
|
||||
# if tok in ban_words:
|
||||
# continue
|
||||
# continue
|
||||
if all(ch in chinese_punctuations for ch in tok):
|
||||
continue
|
||||
if tok.strip() == "":
|
||||
@@ -899,4 +932,4 @@ def cut_key_words(concept_name: str) -> list[str]:
|
||||
result_tokens.append(tok)
|
||||
|
||||
filtered_concept_name_tokens = result_tokens
|
||||
return filtered_concept_name_tokens
|
||||
return filtered_concept_name_tokens
|
||||
|
||||
@@ -91,9 +91,10 @@ class ImageManager:
|
||||
desc_obj.save()
|
||||
except Exception as e:
|
||||
logger.error(f"保存描述到数据库失败 (Peewee): {str(e)}")
|
||||
|
||||
|
||||
async def get_emoji_tag(self, image_base64: str) -> str:
|
||||
from src.chat.emoji_system.emoji_manager import get_emoji_manager
|
||||
|
||||
emoji_manager = get_emoji_manager()
|
||||
if isinstance(image_base64, str):
|
||||
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
|
||||
@@ -120,6 +121,7 @@ class ImageManager:
|
||||
# 优先使用EmojiManager查询已注册表情包的描述
|
||||
try:
|
||||
from src.chat.emoji_system.emoji_manager import get_emoji_manager
|
||||
|
||||
emoji_manager = get_emoji_manager()
|
||||
tags = await emoji_manager.get_emoji_tag_by_hash(image_hash)
|
||||
if tags:
|
||||
|
||||
Reference in New Issue
Block a user