ruff reformatted
This commit is contained in:
@@ -27,6 +27,7 @@ chat_config = LogConfig(
|
||||
|
||||
logger = get_module_logger("reasoning_chat", config=chat_config)
|
||||
|
||||
|
||||
class ReasoningChat:
|
||||
def __init__(self):
|
||||
self.storage = MessageStorage()
|
||||
@@ -224,13 +225,13 @@ class ReasoningChat:
|
||||
do_reply = False
|
||||
if random() < reply_probability:
|
||||
do_reply = True
|
||||
|
||||
|
||||
# 创建思考消息
|
||||
timer1 = time.time()
|
||||
thinking_id = await self._create_thinking_message(message, chat, userinfo, messageinfo)
|
||||
timer2 = time.time()
|
||||
timing_results["创建思考消息"] = timer2 - timer1
|
||||
|
||||
|
||||
# 生成回复
|
||||
timer1 = time.time()
|
||||
response_set = await self.gpt.generate_response(message)
|
||||
|
||||
@@ -40,7 +40,7 @@ class ResponseGenerator:
|
||||
|
||||
async def generate_response(self, message: MessageThinking) -> Optional[Union[str, List[str]]]:
|
||||
"""根据当前模型类型选择对应的生成函数"""
|
||||
#从global_config中获取模型概率值并选择模型
|
||||
# 从global_config中获取模型概率值并选择模型
|
||||
if random.random() < global_config.MODEL_R1_PROBABILITY:
|
||||
self.current_model_type = "深深地"
|
||||
current_model = self.model_reasoning
|
||||
@@ -51,7 +51,6 @@ class ResponseGenerator:
|
||||
logger.info(
|
||||
f"{self.current_model_type}思考:{message.processed_plain_text[:30] + '...' if len(message.processed_plain_text) > 30 else message.processed_plain_text}"
|
||||
) # noqa: E501
|
||||
|
||||
|
||||
model_response = await self._generate_response_with_model(message, current_model)
|
||||
|
||||
@@ -189,4 +188,4 @@ class ResponseGenerator:
|
||||
|
||||
# print(f"得到了处理后的llm返回{processed_response}")
|
||||
|
||||
return processed_response
|
||||
return processed_response
|
||||
|
||||
@@ -24,35 +24,32 @@ class PromptBuilder:
|
||||
async def _build_prompt(
|
||||
self, chat_stream, message_txt: str, sender_name: str = "某人", stream_id: Optional[int] = None
|
||||
) -> tuple[str, str]:
|
||||
|
||||
# 开始构建prompt
|
||||
prompt_personality = "你"
|
||||
#person
|
||||
# person
|
||||
individuality = Individuality.get_instance()
|
||||
|
||||
|
||||
personality_core = individuality.personality.personality_core
|
||||
prompt_personality += personality_core
|
||||
|
||||
|
||||
personality_sides = individuality.personality.personality_sides
|
||||
random.shuffle(personality_sides)
|
||||
prompt_personality += f",{personality_sides[0]}"
|
||||
|
||||
|
||||
identity_detail = individuality.identity.identity_detail
|
||||
random.shuffle(identity_detail)
|
||||
prompt_personality += f",{identity_detail[0]}"
|
||||
|
||||
|
||||
|
||||
|
||||
# 关系
|
||||
who_chat_in_group = [(chat_stream.user_info.platform,
|
||||
chat_stream.user_info.user_id,
|
||||
chat_stream.user_info.user_nickname)]
|
||||
who_chat_in_group = [
|
||||
(chat_stream.user_info.platform, chat_stream.user_info.user_id, chat_stream.user_info.user_nickname)
|
||||
]
|
||||
who_chat_in_group += get_recent_group_speaker(
|
||||
stream_id,
|
||||
(chat_stream.user_info.platform, chat_stream.user_info.user_id),
|
||||
limit=global_config.MAX_CONTEXT_SIZE,
|
||||
)
|
||||
|
||||
|
||||
relation_prompt = ""
|
||||
for person in who_chat_in_group:
|
||||
relation_prompt += await relationship_manager.build_relationship_info(person)
|
||||
@@ -67,7 +64,7 @@ class PromptBuilder:
|
||||
mood_prompt = mood_manager.get_prompt()
|
||||
|
||||
# logger.info(f"心情prompt: {mood_prompt}")
|
||||
|
||||
|
||||
# 调取记忆
|
||||
memory_prompt = ""
|
||||
related_memory = await HippocampusManager.get_instance().get_memory_from_text(
|
||||
@@ -84,7 +81,7 @@ class PromptBuilder:
|
||||
# print(f"相关记忆:{related_memory_info}")
|
||||
|
||||
# 日程构建
|
||||
schedule_prompt = f'''你现在正在做的事情是:{bot_schedule.get_current_num_task(num = 1,time_info = False)}'''
|
||||
schedule_prompt = f"""你现在正在做的事情是:{bot_schedule.get_current_num_task(num=1, time_info=False)}"""
|
||||
|
||||
# 获取聊天上下文
|
||||
chat_in_group = True
|
||||
@@ -143,7 +140,7 @@ class PromptBuilder:
|
||||
涉及政治敏感以及违法违规的内容请规避。"""
|
||||
|
||||
logger.info("开始构建prompt")
|
||||
|
||||
|
||||
prompt = f"""
|
||||
{relation_prompt_all}
|
||||
{memory_prompt}
|
||||
@@ -165,7 +162,7 @@ class PromptBuilder:
|
||||
start_time = time.time()
|
||||
related_info = ""
|
||||
logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}")
|
||||
|
||||
|
||||
# 1. 先从LLM获取主题,类似于记忆系统的做法
|
||||
topics = []
|
||||
# try:
|
||||
@@ -173,7 +170,7 @@ class PromptBuilder:
|
||||
# hippocampus = HippocampusManager.get_instance()._hippocampus
|
||||
# topic_num = min(5, max(1, int(len(message) * 0.1)))
|
||||
# topics_response = await hippocampus.llm_topic_judge.generate_response(hippocampus.find_topic_llm(message, topic_num))
|
||||
|
||||
|
||||
# # 提取关键词
|
||||
# topics = re.findall(r"<([^>]+)>", topics_response[0])
|
||||
# if not topics:
|
||||
@@ -184,7 +181,7 @@ class PromptBuilder:
|
||||
# for topic in ",".join(topics).replace(",", ",").replace("、", ",").replace(" ", ",").split(",")
|
||||
# if topic.strip()
|
||||
# ]
|
||||
|
||||
|
||||
# logger.info(f"从LLM提取的主题: {', '.join(topics)}")
|
||||
# except Exception as e:
|
||||
# logger.error(f"从LLM提取主题失败: {str(e)}")
|
||||
@@ -192,7 +189,7 @@ class PromptBuilder:
|
||||
# words = jieba.cut(message)
|
||||
# topics = [word for word in words if len(word) > 1][:5]
|
||||
# logger.info(f"使用jieba提取的主题: {', '.join(topics)}")
|
||||
|
||||
|
||||
# 如果无法提取到主题,直接使用整个消息
|
||||
if not topics:
|
||||
logger.info("未能提取到任何主题,使用整个消息进行查询")
|
||||
@@ -200,26 +197,26 @@ class PromptBuilder:
|
||||
if not embedding:
|
||||
logger.error("获取消息嵌入向量失败")
|
||||
return ""
|
||||
|
||||
|
||||
related_info = self.get_info_from_db(embedding, limit=3, threshold=threshold)
|
||||
logger.info(f"知识库检索完成,总耗时: {time.time() - start_time:.3f}秒")
|
||||
return related_info
|
||||
|
||||
|
||||
# 2. 对每个主题进行知识库查询
|
||||
logger.info(f"开始处理{len(topics)}个主题的知识库查询")
|
||||
|
||||
|
||||
# 优化:批量获取嵌入向量,减少API调用
|
||||
embeddings = {}
|
||||
topics_batch = [topic for topic in topics if len(topic) > 0]
|
||||
if message: # 确保消息非空
|
||||
topics_batch.append(message)
|
||||
|
||||
|
||||
# 批量获取嵌入向量
|
||||
embed_start_time = time.time()
|
||||
for text in topics_batch:
|
||||
if not text or len(text.strip()) == 0:
|
||||
continue
|
||||
|
||||
|
||||
try:
|
||||
embedding = await get_embedding(text, request_type="prompt_build")
|
||||
if embedding:
|
||||
@@ -228,17 +225,17 @@ class PromptBuilder:
|
||||
logger.warning(f"获取'{text}'的嵌入向量失败")
|
||||
except Exception as e:
|
||||
logger.error(f"获取'{text}'的嵌入向量时发生错误: {str(e)}")
|
||||
|
||||
|
||||
logger.info(f"批量获取嵌入向量完成,耗时: {time.time() - embed_start_time:.3f}秒")
|
||||
|
||||
|
||||
if not embeddings:
|
||||
logger.error("所有嵌入向量获取失败")
|
||||
return ""
|
||||
|
||||
|
||||
# 3. 对每个主题进行知识库查询
|
||||
all_results = []
|
||||
query_start_time = time.time()
|
||||
|
||||
|
||||
# 首先添加原始消息的查询结果
|
||||
if message in embeddings:
|
||||
original_results = self.get_info_from_db(embeddings[message], limit=3, threshold=threshold, return_raw=True)
|
||||
@@ -247,12 +244,12 @@ class PromptBuilder:
|
||||
result["topic"] = "原始消息"
|
||||
all_results.extend(original_results)
|
||||
logger.info(f"原始消息查询到{len(original_results)}条结果")
|
||||
|
||||
|
||||
# 然后添加每个主题的查询结果
|
||||
for topic in topics:
|
||||
if not topic or topic not in embeddings:
|
||||
continue
|
||||
|
||||
|
||||
try:
|
||||
topic_results = self.get_info_from_db(embeddings[topic], limit=3, threshold=threshold, return_raw=True)
|
||||
if topic_results:
|
||||
@@ -263,9 +260,9 @@ class PromptBuilder:
|
||||
logger.info(f"主题'{topic}'查询到{len(topic_results)}条结果")
|
||||
except Exception as e:
|
||||
logger.error(f"查询主题'{topic}'时发生错误: {str(e)}")
|
||||
|
||||
|
||||
logger.info(f"知识库查询完成,耗时: {time.time() - query_start_time:.3f}秒,共获取{len(all_results)}条结果")
|
||||
|
||||
|
||||
# 4. 去重和过滤
|
||||
process_start_time = time.time()
|
||||
unique_contents = set()
|
||||
@@ -275,14 +272,16 @@ class PromptBuilder:
|
||||
if content not in unique_contents:
|
||||
unique_contents.add(content)
|
||||
filtered_results.append(result)
|
||||
|
||||
|
||||
# 5. 按相似度排序
|
||||
filtered_results.sort(key=lambda x: x["similarity"], reverse=True)
|
||||
|
||||
|
||||
# 6. 限制总数量(最多10条)
|
||||
filtered_results = filtered_results[:10]
|
||||
logger.info(f"结果处理完成,耗时: {time.time() - process_start_time:.3f}秒,过滤后剩余{len(filtered_results)}条结果")
|
||||
|
||||
logger.info(
|
||||
f"结果处理完成,耗时: {time.time() - process_start_time:.3f}秒,过滤后剩余{len(filtered_results)}条结果"
|
||||
)
|
||||
|
||||
# 7. 格式化输出
|
||||
if filtered_results:
|
||||
format_start_time = time.time()
|
||||
@@ -292,7 +291,7 @@ class PromptBuilder:
|
||||
if topic not in grouped_results:
|
||||
grouped_results[topic] = []
|
||||
grouped_results[topic].append(result)
|
||||
|
||||
|
||||
# 按主题组织输出
|
||||
for topic, results in grouped_results.items():
|
||||
related_info += f"【主题: {topic}】\n"
|
||||
@@ -303,13 +302,15 @@ class PromptBuilder:
|
||||
# related_info += f"{i}. [{similarity:.2f}] {content}\n"
|
||||
related_info += f"{content}\n"
|
||||
related_info += "\n"
|
||||
|
||||
|
||||
logger.info(f"格式化输出完成,耗时: {time.time() - format_start_time:.3f}秒")
|
||||
|
||||
|
||||
logger.info(f"知识库检索总耗时: {time.time() - start_time:.3f}秒")
|
||||
return related_info
|
||||
|
||||
def get_info_from_db(self, query_embedding: list, limit: int = 1, threshold: float = 0.5, return_raw: bool = False) -> Union[str, list]:
|
||||
def get_info_from_db(
|
||||
self, query_embedding: list, limit: int = 1, threshold: float = 0.5, return_raw: bool = False
|
||||
) -> Union[str, list]:
|
||||
if not query_embedding:
|
||||
return "" if not return_raw else []
|
||||
# 使用余弦相似度计算
|
||||
|
||||
Reference in New Issue
Block a user