feat:改为单planner,并解析多个动作

This commit is contained in:
SengokuCola
2025-09-11 14:25:02 +08:00
parent 8ed94d1f26
commit a4285673aa
7 changed files with 1284 additions and 886 deletions

View File

@@ -0,0 +1,613 @@
#!/usr/bin/env python3
"""
基于Embedding的兴趣度计算测试脚本
使用MaiBot-Core的EmbeddingStore计算兴趣描述与目标文本的关联度
"""
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from typing import List, Dict, Tuple, Optional
import time
import json
import asyncio
from src.chat.knowledge.embedding_store import EmbeddingStore, cosine_similarity
from src.chat.knowledge.embedding_store import EMBEDDING_DATA_DIR_STR
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config
class InterestScorer:
"""基于Embedding的兴趣度计算器"""
def __init__(self, namespace: str = "interest_test"):
"""初始化兴趣度计算器"""
self.embedding_store = EmbeddingStore(namespace, EMBEDDING_DATA_DIR_STR)
async def get_embedding(self, text: str) -> Tuple[Optional[List[float]], float]:
"""获取文本的嵌入向量"""
start_time = time.time()
try:
# 直接使用异步方式获取嵌入
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config
llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding")
embedding, _ = await llm.get_embedding(text)
end_time = time.time()
elapsed = end_time - start_time
if embedding and len(embedding) > 0:
return embedding, elapsed
return None, elapsed
except Exception as e:
print(f"获取嵌入向量失败: {e}")
return None, 0.0
async def calculate_similarity(self, text1: str, text2: str) -> Tuple[float, float, float]:
"""计算两段文本的余弦相似度,返回(相似度, 文本1耗时, 文本2耗时)"""
emb1, time1 = await self.get_embedding(text1)
emb2, time2 = await self.get_embedding(text2)
if emb1 is None or emb2 is None:
return 0.0, time1, time2
return cosine_similarity(emb1, emb2), time1, time2
async def calculate_interest_score(self, interest_text: str, target_text: str) -> Dict:
"""
计算兴趣度分数
Args:
interest_text: 兴趣描述文本
target_text: 目标文本
Returns:
包含各种分数的字典
"""
# 只计算语义相似度(嵌入分数)
semantic_score, interest_time, target_time = await self.calculate_similarity(interest_text, target_text)
# 直接使用语义相似度作为最终分数
final_score = semantic_score
return {
"final_score": final_score,
"semantic_score": semantic_score,
"timing": {
"interest_embedding_time": interest_time,
"target_embedding_time": target_time,
"total_time": interest_time + target_time
}
}
async def batch_calculate(self, interest_text: str, target_texts: List[str]) -> List[Dict]:
"""批量计算兴趣度"""
results = []
total_start_time = time.time()
print(f"开始批量计算兴趣度...")
print(f"兴趣文本: {interest_text}")
print(f"目标文本数量: {len(target_texts)}")
# 获取兴趣文本的嵌入向量(只需要一次)
interest_embedding, interest_time = await self.get_embedding(interest_text)
if interest_embedding is None:
print("无法获取兴趣文本的嵌入向量")
return []
print(f"兴趣文本嵌入计算耗时: {interest_time:.3f}")
total_target_time = 0.0
for i, target_text in enumerate(target_texts):
print(f"处理第 {i+1}/{len(target_texts)} 个文本...")
# 获取目标文本的嵌入向量
target_embedding, target_time = await self.get_embedding(target_text)
total_target_time += target_time
if target_embedding is None:
semantic_score = 0.0
else:
semantic_score = cosine_similarity(interest_embedding, target_embedding)
# 直接使用语义相似度作为最终分数
final_score = semantic_score
results.append({
"target_text": target_text,
"final_score": final_score,
"semantic_score": semantic_score,
"timing": {
"target_embedding_time": target_time,
"item_total_time": target_time
}
})
# 按分数排序
results.sort(key=lambda x: x["final_score"], reverse=True)
total_time = time.time() - total_start_time
avg_target_time = total_target_time / len(target_texts) if target_texts else 0
print(f"\n=== 性能统计 ===")
print(f"兴趣文本嵌入计算耗时: {interest_time:.3f}")
print(f"目标文本嵌入计算总耗时: {total_target_time:.3f}")
print(f"目标文本嵌入计算平均耗时: {avg_target_time:.3f}")
print(f"总耗时: {total_time:.3f}")
print(f"平均每个目标文本处理耗时: {total_time / len(target_texts):.3f}")
return results
async def generate_paraphrases(self, original_text: str, num_sentences: int = 5) -> List[str]:
"""
使用LLM生成近义句子
Args:
original_text: 原始文本
num_sentences: 生成句子数量
Returns:
近义句子列表
"""
try:
# 创建LLM请求实例
llm_request = LLMRequest(
model_set=model_config.model_task_config.replyer,
request_type="paraphrase_generator"
)
# 构建生成近义句子的提示词
prompt = f"""请为以下兴趣描述生成{num_sentences}个意义相近但表达不同的句子:
原始兴趣描述:{original_text}
要求:
1. 保持原意不变,但尽量自由发挥,使用不同的表达方式,内容也可以有差异
2. 句子结构要有所变化
3. 可以适当调整语气和重点
4. 每个句子都要完整且自然
5. 只返回句子,不要编号,每行一个句子
生成的近义句子:"""
print(f"正在生成近义句子...")
content, (reasoning, model_name, tool_calls) = await llm_request.generate_response_async(prompt)
# 解析生成的句子
sentences = []
for line in content.strip().split('\n'):
line = line.strip()
if line and not line.startswith('生成') and not line.startswith('近义'):
sentences.append(line)
# 确保返回指定数量的句子
sentences = sentences[:num_sentences]
print(f"成功生成 {len(sentences)} 个近义句子")
print(f"使用的模型: {model_name}")
return sentences
except Exception as e:
print(f"生成近义句子失败: {e}")
return []
async def evaluate_all_paraphrases(self, original_text: str, target_texts: List[str], num_sentences: int = 5) -> Dict:
"""
评估原始文本和所有近义句子的兴趣度
Args:
original_text: 原始兴趣描述文本
target_texts: 目标文本列表
num_sentences: 生成近义句子数量
Returns:
包含所有评估结果的字典
"""
print(f"\n=== 开始近义句子兴趣度评估 ===")
print(f"原始兴趣描述: {original_text}")
print(f"目标文本数量: {len(target_texts)}")
print(f"生成近义句子数量: {num_sentences}")
# 生成近义句子
paraphrases = await self.generate_paraphrases(original_text, num_sentences)
if not paraphrases:
print("生成近义句子失败,使用原始文本进行评估")
paraphrases = []
# 所有待评估的文本(原始文本 + 近义句子)
all_texts = [original_text] + paraphrases
# 对每个文本进行兴趣度评估
evaluation_results = {}
for i, text in enumerate(all_texts):
text_type = "原始文本" if i == 0 else f"近义句子{i}"
print(f"\n--- 评估 {text_type} ---")
print(f"文本内容: {text}")
# 计算兴趣度
results = await self.batch_calculate(text, target_texts)
evaluation_results[text_type] = {
"text": text,
"results": results,
"top_score": results[0]["final_score"] if results else 0.0,
"average_score": sum(r["final_score"] for r in results) / len(results) if results else 0.0
}
return {
"original_text": original_text,
"paraphrases": paraphrases,
"evaluations": evaluation_results,
"summary": self._generate_summary(evaluation_results, target_texts)
}
def _generate_summary(self, evaluation_results: Dict, target_texts: List[str]) -> Dict:
"""生成评估摘要 - 关注目标句子的表现"""
summary = {
"best_performer": None,
"worst_performer": None,
"average_scores": {},
"max_scores": {},
"rankings": [],
"target_stats": {},
"target_rankings": []
}
scores = []
for text_type, data in evaluation_results.items():
scores.append({
"text_type": text_type,
"text": data["text"],
"top_score": data["top_score"],
"average_score": data["average_score"]
})
# 按top_score排序
scores.sort(key=lambda x: x["top_score"], reverse=True)
summary["rankings"] = scores
summary["best_performer"] = scores[0] if scores else None
summary["worst_performer"] = scores[-1] if scores else None
# 计算原始文本统计
original_score = next((s for s in scores if s["text_type"] == "原始文本"), None)
if original_score:
summary["average_scores"]["original"] = original_score["average_score"]
summary["max_scores"]["original"] = original_score["top_score"]
# 计算目标句子的统计信息
target_stats = {}
for i, target_text in enumerate(target_texts):
target_key = f"目标{i+1}"
scores_for_target = []
# 收集所有兴趣描述对该目标文本的分数
for text_type, data in evaluation_results.items():
for result in data["results"]:
if result["target_text"] == target_text:
scores_for_target.append(result["final_score"])
if scores_for_target:
target_stats[target_key] = {
"target_text": target_text,
"scores": scores_for_target,
"average": sum(scores_for_target) / len(scores_for_target),
"max": max(scores_for_target),
"min": min(scores_for_target),
"std": (sum((x - sum(scores_for_target) / len(scores_for_target)) ** 2 for x in scores_for_target) / len(scores_for_target)) ** 0.5
}
summary["target_stats"] = target_stats
# 按平均分对目标文本排序
target_rankings = []
for target_key, stats in target_stats.items():
target_rankings.append({
"target_key": target_key,
"target_text": stats["target_text"],
"average_score": stats["average"],
"max_score": stats["max"],
"min_score": stats["min"],
"std_score": stats["std"]
})
target_rankings.sort(key=lambda x: x["average_score"], reverse=True)
summary["target_rankings"] = target_rankings
# 计算目标文本的整体统计
if target_rankings:
all_target_averages = [t["average_score"] for t in target_rankings]
all_target_scores = []
for stats in target_stats.values():
all_target_scores.extend(stats["scores"])
summary["target_overall"] = {
"avg_of_averages": sum(all_target_averages) / len(all_target_averages),
"overall_max": max(all_target_scores),
"overall_min": min(all_target_scores),
"best_target": target_rankings[0]["target_text"],
"worst_target": target_rankings[-1]["target_text"]
}
return summary
async def run_single_test():
"""运行单个测试"""
print("单个兴趣度测试")
print("=" * 40)
# 输入兴趣文本
# interest_text = input("请输入兴趣描述文本: ").strip()
# if not interest_text:
# print("兴趣描述不能为空")
# return
interest_text ="对技术相关话题,游戏和动漫相关话题感兴趣,也对日常话题感兴趣,不喜欢太过沉重严肃的话题"
# 输入目标文本
print("请输入目标文本 (输入空行结束):")
import random
target_texts = [
"AveMujica非常好看你看了吗",
"明日方舟这个游戏挺好玩的",
"你能不能说点正经的",
"明日方舟挺好玩的",
"你的名字非常好看,你看了吗",
"《你的名字》非常好看,你看了吗",
"我们来聊聊苏联政治吧",
"轻音少女非常好看,你看了吗",
"我还挺喜欢打游戏的",
"我嘞个原神玩家啊",
"我心买了PlayStation5",
"直接Steam",
"有没有R"
]
random.shuffle(target_texts)
# while True:
# line = input().strip()
# if not line:
# break
# target_texts.append(line)
# if not target_texts:
# print("目标文本不能为空")
# return
# 计算兴趣度
scorer = InterestScorer()
results = await scorer.batch_calculate(interest_text, target_texts)
# 显示结果
print(f"\n兴趣度排序结果:")
print("-" * 80)
print(f"{'排名':<4} {'最终分数':<10} {'语义分数':<10} {'耗时(秒)':<10} {'目标文本'}")
print("-" * 80)
for j, result in enumerate(results):
target_text = result['target_text']
if len(target_text) > 40:
target_text = target_text[:37] + "..."
timing = result.get('timing', {})
item_time = timing.get('item_total_time', 0.0)
print(f"{j+1:<4} {result['final_score']:<10.3f} {result['semantic_score']:<10.3f} "
f"{item_time:<10.3f} {target_text}")
async def run_paraphrase_test():
"""运行近义句子测试"""
print("近义句子兴趣度对比测试")
print("=" * 40)
# 输入兴趣文本
interest_text = "对技术相关话题,游戏和动漫相关话题感兴趣,比如明日方舟和原神,也对日常话题感兴趣,不喜欢太过沉重严肃的话题"
# 输入目标文本
print("请输入目标文本 (输入空行结束):")
# target_texts = []
# while True:
# line = input().strip()
# if not line:
# break
# target_texts.append(line)
target_texts = [
"AveMujica非常好看你看了吗",
"明日方舟这个游戏挺好玩的",
"你能不能说点正经的",
"明日方舟挺好玩的",
"你的名字非常好看,你看了吗",
"《你的名字》非常好看,你看了吗",
"我们来聊聊苏联政治吧",
"轻音少女非常好看,你看了吗",
"我还挺喜欢打游戏的",
"刚加好友就视奸空间14条",
"可乐老大加我好友,我先日一遍空间",
"鸟一茬茬的",
"可乐可以是m群友可以是s"
]
if not target_texts:
print("目标文本不能为空")
return
# 创建评估器
scorer = InterestScorer()
# 运行评估
result = await scorer.evaluate_all_paraphrases(interest_text, target_texts, num_sentences=5)
# 显示结果
display_paraphrase_results(result, target_texts)
def display_paraphrase_results(result: Dict, target_texts: List[str]):
"""显示近义句子评估结果"""
print("\n" + "=" * 80)
print("近义句子兴趣度评估结果")
print("=" * 80)
# 显示目标文本
print(f"\n📋 目标文本列表:")
print("-" * 40)
for i, target in enumerate(target_texts):
print(f"{i+1}. {target}")
# 显示生成的近义句子
print(f"\n📝 生成的近义句子 (作为兴趣描述):")
print("-" * 40)
for i, paraphrase in enumerate(result["paraphrases"]):
print(f"{i+1}. {paraphrase}")
# 显示摘要
summary = result["summary"]
print(f"\n📊 评估摘要:")
print("-" * 40)
if summary["best_performer"]:
print(f"最佳表现: {summary['best_performer']['text_type']} (最高分: {summary['best_performer']['top_score']:.3f})")
if summary["worst_performer"]:
print(f"最差表现: {summary['worst_performer']['text_type']} (最高分: {summary['worst_performer']['top_score']:.3f})")
print(f"原始文本平均分: {summary['average_scores'].get('original', 0):.3f}")
# 显示目标文本的整体统计
if "target_overall" in summary:
overall = summary["target_overall"]
print(f"\n📈 目标文本整体统计:")
print("-" * 40)
print(f"目标文本数量: {len(summary['target_rankings'])}")
print(f"平均分的平均值: {overall['avg_of_averages']:.3f}")
print(f"所有匹配中的最高分: {overall['overall_max']:.3f}")
print(f"所有匹配中的最低分: {overall['overall_min']:.3f}")
print(f"最佳匹配目标: {overall['best_target'][:50]}...")
print(f"最差匹配目标: {overall['worst_target'][:50]}...")
# 显示目标文本排名
if "target_rankings" in summary and summary["target_rankings"]:
print(f"\n🏆 目标文本排名 (按平均分):")
print("-" * 80)
print(f"{'排名':<4} {'平均分':<8} {'最高分':<8} {'最低分':<8} {'标准差':<8} {'目标文本'}")
print("-" * 80)
for i, target in enumerate(summary["target_rankings"]):
target_text = target["target_text"][:40] + "..." if len(target["target_text"]) > 40 else target["target_text"]
print(f"{i+1:<4} {target['average_score']:<8.3f} {target['max_score']:<8.3f} {target['min_score']:<8.3f} {target['std_score']:<8.3f} {target_text}")
# 显示每个目标文本的详细分数分布
if "target_stats" in summary:
print(f"\n📊 目标文本详细分数分布:")
print("-" * 80)
for target_key, stats in summary["target_stats"].items():
print(f"\n{target_key}: {stats['target_text']}")
print(f" 平均分: {stats['average']:.3f}")
print(f" 最高分: {stats['max']:.3f}")
print(f" 最低分: {stats['min']:.3f}")
print(f" 标准差: {stats['std']:.3f}")
print(f" 所有分数: {[f'{s:.3f}' for s in stats['scores']]}")
# 显示最佳和最差兴趣描述的目标表现对比
if summary["best_performer"] and summary["worst_performer"]:
print(f"\n🔍 最佳 vs 最差兴趣描述对比:")
print("-" * 80)
best_data = result["evaluations"][summary["best_performer"]["text_type"]]
worst_data = result["evaluations"][summary["worst_performer"]["text_type"]]
print(f"最佳兴趣描述: {summary['best_performer']['text']}")
print(f"最差兴趣描述: {summary['worst_performer']['text']}")
print(f"")
print(f"{'目标文本':<30} {'最佳分数':<10} {'最差分数':<10} {'差值'}")
print("-" * 60)
for best_result, worst_result in zip(best_data["results"], worst_data["results"]):
if best_result["target_text"] == worst_result["target_text"]:
diff = best_result["final_score"] - worst_result["final_score"]
target_text = best_result["target_text"][:27] + "..." if len(best_result["target_text"]) > 30 else best_result["target_text"]
print(f"{target_text:<30} {best_result['final_score']:<10.3f} {worst_result['final_score']:<10.3f} {diff:+.3f}")
# 显示排名
print(f"\n🏆 兴趣描述性能排名:")
print("-" * 80)
print(f"{'排名':<4} {'文本类型':<10} {'最高分':<8} {'平均分':<8} {'兴趣描述内容'}")
print("-" * 80)
for i, item in enumerate(summary["rankings"]):
text_content = item["text"][:40] + "..." if len(item["text"]) > 40 else item["text"]
print(f"{i+1:<4} {item['text_type']:<10} {item['top_score']:<8.3f} {item['average_score']:<8.3f} {text_content}")
# 显示每个兴趣描述的详细结果
print(f"\n🔍 详细结果:")
print("-" * 80)
for text_type, data in result["evaluations"].items():
print(f"\n--- {text_type} ---")
print(f"兴趣描述: {data['text']}")
print(f"最高分: {data['top_score']:.3f}")
print(f"平均分: {data['average_score']:.3f}")
# 显示前3个匹配结果
top_results = data["results"][:3]
print(f"前3个匹配的目标文本:")
for j, result_item in enumerate(top_results):
print(f" {j+1}. 分数: {result_item['final_score']:.3f} - {result_item['target_text']}")
# 显示对比表格
print(f"\n📈 兴趣描述对比表格:")
print("-" * 100)
header = f"{'兴趣描述':<20}"
for i, target in enumerate(target_texts):
target_name = f"目标{i+1}"
header += f" {target_name:<12}"
print(header)
print("-" * 100)
# 原始文本行
original_line = f"{'原始文本':<20}"
original_data = result["evaluations"]["原始文本"]["results"]
for i in range(len(target_texts)):
if i < len(original_data):
original_line += f" {original_data[i]['final_score']:<12.3f}"
else:
original_line += f" {'-':<12}"
print(original_line)
# 近义句子行
for i, paraphrase in enumerate(result["paraphrases"]):
text_type = f"近义句子{i+1}"
line = f"{text_type:<20}"
paraphrase_data = result["evaluations"][text_type]["results"]
for j in range(len(target_texts)):
if j < len(paraphrase_data):
line += f" {paraphrase_data[j]['final_score']:<12.3f}"
else:
line += f" {'-':<12}"
print(line)
def main():
"""主函数"""
print("基于Embedding的兴趣度计算测试工具")
print("1. 单个兴趣度测试")
print("2. 近义句子兴趣度对比测试")
choice = input("\n请选择 (1/2): ").strip()
if choice == "1":
asyncio.run(run_single_test())
elif choice == "2":
asyncio.run(run_paraphrase_test())
else:
print("无效选择")
if __name__ == "__main__":
main()

View File

@@ -1,6 +1,9 @@
from typing import Optional from typing import Optional
from datetime import datetime, timedelta
import statistics
from src.config.config import global_config from src.config.config import global_config
from src.chat.frequency_control.utils import parse_stream_config_to_chat_id from src.chat.frequency_control.utils import parse_stream_config_to_chat_id
from src.common.database.database_model import Messages
def get_config_base_talk_frequency(chat_id: Optional[str] = None) -> float: def get_config_base_talk_frequency(chat_id: Optional[str] = None) -> float:
@@ -124,3 +127,146 @@ def get_global_frequency() -> Optional[float]:
return get_time_based_frequency(config_item[1:]) return get_time_based_frequency(config_item[1:])
return None return None
def get_weekly_hourly_message_stats(chat_id: str):
"""
计算指定聊天最近一周每个小时的消息数量和用户数量
Args:
chat_id: 聊天ID对应 Messages 表的 chat_id 字段)
Returns:
dict: 包含24个小时统计数据格式为:
{
"0": {"message_count": [5, 8, 3, 12, 6, 9, 7], "message_std_dev": 2.1},
"1": {"message_count": [10, 15, 8, 20, 12, 18, 14], "message_std_dev": 3.2},
...
}
"""
# 计算一周前的时间戳
one_week_ago = datetime.now() - timedelta(days=7)
one_week_ago_timestamp = one_week_ago.timestamp()
# 初始化数据结构:按小时存储每天的消息计数
hourly_data = {}
for hour in range(24):
hourly_data[f"hour_{hour}"] = {"daily_counts": []}
try:
# 查询指定聊天最近一周的消息
messages = Messages.select().where(
(Messages.time >= one_week_ago_timestamp) &
(Messages.chat_id == chat_id)
)
# 统计每个小时的数据
for message in messages:
# 将时间戳转换为datetime
msg_time = datetime.fromtimestamp(message.time)
hour = msg_time.hour
# 记录每天的消息计数(按日期分组)
day_key = msg_time.strftime("%Y-%m-%d")
hour_key = f"{hour}"
# 为该小时添加当天的消息计数
found = False
for day_count in hourly_data[hour_key]["daily_counts"]:
if day_count["date"] == day_key:
day_count["count"] += 1
found = True
break
if not found:
hourly_data[hour_key]["daily_counts"].append({"date": day_key, "count": 1})
except Exception as e:
# 如果查询失败,返回空的统计结果
print(f"Error getting weekly hourly message stats for chat {chat_id}: {e}")
hourly_stats = {}
for hour in range(24):
hourly_stats[f"hour_{hour}"] = {
"message_count": [],
"message_std_dev": 0.0
}
return hourly_stats
# 计算每个小时的统计结果
hourly_stats = {}
for hour in range(24):
hour_key = f"hour_{hour}"
daily_counts = [day["count"] for day in hourly_data[hour_key]["daily_counts"]]
# 计算总消息数
total_messages = sum(daily_counts)
# 计算标准差
message_std_dev = 0.0
if len(daily_counts) > 1:
message_std_dev = statistics.stdev(daily_counts)
elif len(daily_counts) == 1:
message_std_dev = 0.0
# 按日期排序每日消息计数
daily_counts_sorted = sorted(hourly_data[hour_key]["daily_counts"], key=lambda x: x["date"])
hourly_stats[hour_key] = {
"message_count": [day["count"] for day in daily_counts_sorted],
"message_std_dev": message_std_dev
}
return hourly_stats
def get_recent_15min_stats(chat_id: str):
"""
获取最近15分钟指定聊天的消息数量和发言人数
Args:
chat_id: 聊天ID对应 Messages 表的 chat_id 字段)
Returns:
dict: 包含消息数量和发言人数,格式为:
{
"message_count": 25,
"user_count": 8,
"time_range": "2025-01-01 14:30:00 - 2025-01-01 14:45:00"
}
"""
# 计算15分钟前的时间戳
fifteen_min_ago = datetime.now() - timedelta(minutes=15)
fifteen_min_ago_timestamp = fifteen_min_ago.timestamp()
current_time = datetime.now()
# 初始化统计结果
message_count = 0
user_set = set()
try:
# 查询最近15分钟的消息
messages = Messages.select().where(
(Messages.time >= fifteen_min_ago_timestamp) &
(Messages.chat_id == chat_id)
)
# 统计消息数量和用户
for message in messages:
message_count += 1
if message.user_id:
user_set.add(message.user_id)
except Exception as e:
# 如果查询失败,返回空结果
print(f"Error getting recent 15min stats for chat {chat_id}: {e}")
return {
"message_count": 0,
"user_count": 0,
"time_range": f"{fifteen_min_ago.strftime('%Y-%m-%d %H:%M:%S')} - {current_time.strftime('%Y-%m-%d %H:%M:%S')}"
}
return {
"message_count": message_count,
"user_count": len(user_set),
"time_range": f"{fifteen_min_ago.strftime('%Y-%m-%d %H:%M:%S')} - {current_time.strftime('%Y-%m-%d %H:%M:%S')}"
}

View File

@@ -18,7 +18,6 @@ from src.chat.planner_actions.action_modifier import ActionModifier
from src.chat.planner_actions.action_manager import ActionManager from src.chat.planner_actions.action_manager import ActionManager
from src.chat.heart_flow.hfc_utils import CycleDetail from src.chat.heart_flow.hfc_utils import CycleDetail
from src.chat.heart_flow.hfc_utils import send_typing, stop_typing from src.chat.heart_flow.hfc_utils import send_typing, stop_typing
from src.chat.frequency_control.frequency_control import frequency_control_manager
from src.chat.express.expression_learner import expression_learner_manager from src.chat.express.expression_learner import expression_learner_manager
from src.person_info.person_info import Person from src.person_info.person_info import Person
from src.plugin_system.base.component_types import ChatMode, EventType, ActionInfo from src.plugin_system.base.component_types import ChatMode, EventType, ActionInfo
@@ -52,6 +51,16 @@ ERROR_LOOP_INFO = {
}, },
} }
# ?什么时候发言:
# 1.聊天频率较低:与过去该时段发言频率进行比较
# 2.感兴趣的话题暂时使用Emb计算
# 3.感兴趣的人:认识次数
# 4.直接提及
# 什么时候不发言:
# 1.敏感话题:判断较难
# 2.发言频率太高:近时判断,例如发言频率> 1/人数 *2
# 3.明确被拒绝planner判断
install(extra_lines=3) install(extra_lines=3)
@@ -85,8 +94,6 @@ class HeartFChatting:
self.expression_learner = expression_learner_manager.get_expression_learner(self.stream_id) self.expression_learner = expression_learner_manager.get_expression_learner(self.stream_id)
self.frequency_control = frequency_control_manager.get_or_create_frequency_control(self.stream_id)
self.action_manager = ActionManager() self.action_manager = ActionManager()
self.action_planner = ActionPlanner(chat_id=self.stream_id, action_manager=self.action_manager) self.action_planner = ActionPlanner(chat_id=self.stream_id, action_manager=self.action_manager)
self.action_modifier = ActionModifier(action_manager=self.action_manager, chat_id=self.stream_id) self.action_modifier = ActionModifier(action_manager=self.action_manager, chat_id=self.stream_id)
@@ -101,6 +108,10 @@ class HeartFChatting:
self._current_cycle_detail: CycleDetail = None # type: ignore self._current_cycle_detail: CycleDetail = None # type: ignore
self.last_read_time = time.time() - 10 self.last_read_time = time.time() - 10
self.no_reply_until_call = False
async def start(self): async def start(self):
"""检查是否需要启动主循环,如果未激活则启动。""" """检查是否需要启动主循环,如果未激活则启动。"""
@@ -156,38 +167,12 @@ class HeartFChatting:
formatted_time = f"{elapsed * 1000:.2f}毫秒" if elapsed < 1 else f"{elapsed:.2f}" formatted_time = f"{elapsed * 1000:.2f}毫秒" if elapsed < 1 else f"{elapsed:.2f}"
timer_strings.append(f"{name}: {formatted_time}") timer_strings.append(f"{name}: {formatted_time}")
# 获取动作类型,兼容新旧格式
# 移除无用代码
# action_type = "未知动作"
# if hasattr(self, "_current_cycle_detail") and self._current_cycle_detail:
# loop_plan_info = self._current_cycle_detail.loop_plan_info
# if isinstance(loop_plan_info, dict):
# action_result = loop_plan_info.get("action_result", {})
# if isinstance(action_result, dict):
# # 旧格式action_result是字典
# action_type = action_result.get("action_type", "未知动作")
# elif isinstance(action_result, list) and action_result:
# # 新格式action_result是actions列表
# # TODO: 把这里写明白
# action_type = action_result[0].action_type or "未知动作"
# elif isinstance(loop_plan_info, list) and loop_plan_info:
# # 直接是actions列表的情况
# action_type = loop_plan_info[0].get("action_type", "未知动作")
logger.info( logger.info(
f"{self.log_prefix}{self._current_cycle_detail.cycle_id}次思考," f"{self.log_prefix}{self._current_cycle_detail.cycle_id}次思考,"
f"耗时: {self._current_cycle_detail.end_time - self._current_cycle_detail.start_time:.1f}" # type: ignore f"耗时: {self._current_cycle_detail.end_time - self._current_cycle_detail.start_time:.1f}" # type: ignore
+ (f"\n详情: {'; '.join(timer_strings)}" if timer_strings else "") + (f"\n详情: {'; '.join(timer_strings)}" if timer_strings else "")
) )
async def calculate_interest_value(self, recent_messages_list: List["DatabaseMessages"]) -> float:
total_interest = 0.0
for msg in recent_messages_list:
interest_value = msg.interest_value
if interest_value is not None and msg.processed_plain_text:
total_interest += float(interest_value)
return total_interest / len(recent_messages_list)
async def _loopbody(self): async def _loopbody(self):
recent_messages_list = message_api.get_messages_by_time_in_chat( recent_messages_list = message_api.get_messages_by_time_in_chat(
chat_id=self.stream_id, chat_id=self.stream_id,
@@ -200,9 +185,22 @@ class HeartFChatting:
) )
if recent_messages_list: if recent_messages_list:
# !处理no_reply_until_call逻辑
if self.no_reply_until_call:
for message in recent_messages_list:
if message.is_mentioned or message.is_at:
self.no_reply_until_call = False
break
# 没有提到,继续保持沉默
if self.no_reply_until_call:
logger.info(f"{self.log_prefix} 没有提到,继续保持沉默")
await asyncio.sleep(1)
return True
self.last_read_time = time.time() self.last_read_time = time.time()
await self._observe( await self._observe(
interest_value=await self.calculate_interest_value(recent_messages_list),
recent_messages_list=recent_messages_list, recent_messages_list=recent_messages_list,
) )
else: else:
@@ -262,190 +260,140 @@ class HeartFChatting:
return loop_info, reply_text, cycle_timers return loop_info, reply_text, cycle_timers
async def _observe( async def _observe(
self, interest_value: float = 0.0, recent_messages_list: Optional[List["DatabaseMessages"]] = None self, # interest_value: float = 0.0,
recent_messages_list: Optional[List["DatabaseMessages"]] = None
) -> bool: ) -> bool:
if recent_messages_list is None: if recent_messages_list is None:
recent_messages_list = [] recent_messages_list = []
reply_text = "" # 初始化reply_text变量避免UnboundLocalError reply_text = "" # 初始化reply_text变量避免UnboundLocalError
# 使用sigmoid函数将interest_value转换为概率
# 当interest_value为0时概率接近0使用Focus模式
# 当interest_value很高时概率接近1使用Normal模式
def calculate_normal_mode_probability(interest_val: float) -> float:
# 使用sigmoid函数调整参数使概率分布更合理
# 当interest_value = 0时概率约为0.1
# 当interest_value = 1时概率约为0.5
# 当interest_value = 2时概率约为0.8
# 当interest_value = 3时概率约为0.95
k = 2.0 # 控制曲线陡峭程度
x0 = 1.0 # 控制曲线中心点
return 1.0 / (1.0 + math.exp(-k * (interest_val - x0)))
normal_mode_probability = (
calculate_normal_mode_probability(interest_value) * 2 * self.frequency_control.get_final_talk_frequency()
)
# 对呼唤名字进行增幅
for msg in recent_messages_list:
if msg.reply_probability_boost is not None and msg.reply_probability_boost > 0.0:
normal_mode_probability += msg.reply_probability_boost
if global_config.chat.mentioned_bot_reply and msg.is_mentioned:
normal_mode_probability += global_config.chat.mentioned_bot_reply
if global_config.chat.at_bot_inevitable_reply and msg.is_at:
normal_mode_probability += global_config.chat.at_bot_inevitable_reply
# 根据概率决定使用直接回复
interest_triggered = False
focus_triggered = False
if random.random() < normal_mode_probability:
interest_triggered = True
logger.info(f"{self.log_prefix} 有新消息,在{normal_mode_probability * 100:.0f}%概率下选择回复")
if s4u_config.enable_s4u: if s4u_config.enable_s4u:
await send_typing() await send_typing()
async with global_prompt_manager.async_message_scope(self.chat_stream.context.get_template_name()): async with global_prompt_manager.async_message_scope(self.chat_stream.context.get_template_name()):
await self.expression_learner.trigger_learning_for_chat() await self.expression_learner.trigger_learning_for_chat()
cycle_timers, thinking_id = self.start_cycle()
logger.info(f"{self.log_prefix} 开始第{self._cycle_counter}次思考")
# 第一步:动作检查
available_actions: Dict[str, ActionInfo] = {} available_actions: Dict[str, ActionInfo] = {}
try:
await self.action_modifier.modify_actions()
available_actions = self.action_manager.get_using_actions()
except Exception as e:
logger.error(f"{self.log_prefix} 动作修改失败: {e}")
# 如果兴趣度不足以激活 # 执行planner
if not interest_triggered: is_group_chat, chat_target_info, _ = self.action_planner.get_necessary_info()
# 看看专注值够不够
if random.random() < self.frequency_control.get_final_focus_value():
# 专注值足够,仍然进入正式思考
focus_triggered = True # 都没触发,路边
# 任意一种触发都行 message_list_before_now = get_raw_msg_before_timestamp_with_chat(
if interest_triggered or focus_triggered: chat_id=self.stream_id,
# 进入正式思考模式 timestamp=time.time(),
cycle_timers, thinking_id = self.start_cycle() limit=int(global_config.chat.max_context_size * 0.6),
logger.info(f"{self.log_prefix} 开始第{self._cycle_counter}次思考") )
chat_content_block, message_id_list = build_readable_messages_with_id(
messages=message_list_before_now,
timestamp_mode="normal_no_YMD",
read_mark=self.action_planner.last_obs_time_mark,
truncate=True,
show_actions=True,
)
# 第一步:动作检查 prompt_info = await self.action_planner.build_planner_prompt(
try: is_group_chat=is_group_chat,
await self.action_modifier.modify_actions() chat_target_info=chat_target_info,
available_actions = self.action_manager.get_using_actions() current_available_actions=available_actions,
except Exception as e: chat_content_block=chat_content_block,
logger.error(f"{self.log_prefix} 动作修改失败: {e}") message_id_list=message_id_list,
interest=global_config.personality.interest,
# 执行planner )
is_group_chat, chat_target_info, _ = self.action_planner.get_necessary_info() continue_flag, modified_message = await events_manager.handle_mai_events(
EventType.ON_PLAN, None, prompt_info[0], None, self.chat_stream.stream_id
message_list_before_now = get_raw_msg_before_timestamp_with_chat( )
chat_id=self.stream_id, if not continue_flag:
timestamp=time.time(), return False
limit=int(global_config.chat.max_context_size * 0.6), if modified_message and modified_message._modify_flags.modify_llm_prompt:
) prompt_info = (modified_message.llm_prompt, prompt_info[1])
chat_content_block, message_id_list = build_readable_messages_with_id(
messages=message_list_before_now,
timestamp_mode="normal_no_YMD", with Timer("规划器", cycle_timers):
read_mark=self.action_planner.last_obs_time_mark, action_to_use_info, _ = await self.action_planner.plan(
truncate=True, loop_start_time=self.last_read_time,
show_actions=True, available_actions=available_actions,
) )
prompt_info = await self.action_planner.build_planner_prompt( # 3. 并行执行所有动作
is_group_chat=is_group_chat, action_tasks = [
chat_target_info=chat_target_info, asyncio.create_task(
# current_available_actions=planner_info[2], self._execute_action(action, action_to_use_info, thinking_id, available_actions, cycle_timers)
chat_content_block=chat_content_block,
# actions_before_now_block=actions_before_now_block,
message_id_list=message_id_list,
) )
continue_flag, modified_message = await events_manager.handle_mai_events( for action in action_to_use_info
EventType.ON_PLAN, None, prompt_info[0], None, self.chat_stream.stream_id ]
)
if not continue_flag: # 并行执行所有任务
return False results = await asyncio.gather(*action_tasks, return_exceptions=True)
if modified_message and modified_message._modify_flags.modify_llm_prompt:
prompt_info = (modified_message.llm_prompt, prompt_info[1]) # 处理执行结果
with Timer("规划器", cycle_timers): reply_loop_info = None
# 根据不同触发进入不同plan reply_text_from_reply = ""
if focus_triggered: action_success = False
mode = ChatMode.FOCUS action_reply_text = ""
action_command = ""
for i, result in enumerate(results):
if isinstance(result, BaseException):
logger.error(f"{self.log_prefix} 动作执行异常: {result}")
continue
_cur_action = action_to_use_info[i]
if result["action_type"] != "reply":
action_success = result["success"]
action_reply_text = result["reply_text"]
action_command = result.get("command", "")
elif result["action_type"] == "reply":
if result["success"]:
reply_loop_info = result["loop_info"]
reply_text_from_reply = result["reply_text"]
else: else:
mode = ChatMode.NORMAL logger.warning(f"{self.log_prefix} 回复动作执行失败")
action_to_use_info, _ = await self.action_planner.plan( # 构建最终的循环信息
mode=mode, if reply_loop_info:
loop_start_time=self.last_read_time, # 如果有回复信息使用回复的loop_info作为基础
available_actions=available_actions, loop_info = reply_loop_info
) # 更新动作执行信息
loop_info["loop_action_info"].update(
# 3. 并行执行所有动作 {
action_tasks = [ "action_taken": action_success,
asyncio.create_task( "command": action_command,
self._execute_action(action, action_to_use_info, thinking_id, available_actions, cycle_timers) "taken_time": time.time(),
)
for action in action_to_use_info
]
# 并行执行所有任务
results = await asyncio.gather(*action_tasks, return_exceptions=True)
# 处理执行结果
reply_loop_info = None
reply_text_from_reply = ""
action_success = False
action_reply_text = ""
action_command = ""
for i, result in enumerate(results):
if isinstance(result, BaseException):
logger.error(f"{self.log_prefix} 动作执行异常: {result}")
continue
_cur_action = action_to_use_info[i]
if result["action_type"] != "reply":
action_success = result["success"]
action_reply_text = result["reply_text"]
action_command = result.get("command", "")
elif result["action_type"] == "reply":
if result["success"]:
reply_loop_info = result["loop_info"]
reply_text_from_reply = result["reply_text"]
else:
logger.warning(f"{self.log_prefix} 回复动作执行失败")
# 构建最终的循环信息
if reply_loop_info:
# 如果有回复信息使用回复的loop_info作为基础
loop_info = reply_loop_info
# 更新动作执行信息
loop_info["loop_action_info"].update(
{
"action_taken": action_success,
"command": action_command,
"taken_time": time.time(),
}
)
reply_text = reply_text_from_reply
else:
# 没有回复信息构建纯动作的loop_info
loop_info = {
"loop_plan_info": {
"action_result": action_to_use_info,
},
"loop_action_info": {
"action_taken": action_success,
"reply_text": action_reply_text,
"command": action_command,
"taken_time": time.time(),
},
} }
reply_text = action_reply_text )
reply_text = reply_text_from_reply
else:
# 没有回复信息构建纯动作的loop_info
loop_info = {
"loop_plan_info": {
"action_result": action_to_use_info,
},
"loop_action_info": {
"action_taken": action_success,
"reply_text": action_reply_text,
"command": action_command,
"taken_time": time.time(),
},
}
reply_text = action_reply_text
self.end_cycle(loop_info, cycle_timers) self.end_cycle(loop_info, cycle_timers)
self.print_cycle_info(cycle_timers) self.print_cycle_info(cycle_timers)
"""S4U内容暂时保留""" """S4U内容暂时保留"""
if s4u_config.enable_s4u: if s4u_config.enable_s4u:
await stop_typing() await stop_typing()
await mai_thinking_manager.get_mai_think(self.stream_id).do_think_after_response(reply_text) await mai_thinking_manager.get_mai_think(self.stream_id).do_think_after_response(reply_text)
"""S4U内容暂时保留""" """S4U内容暂时保留"""
return True return True
@@ -579,7 +527,7 @@ class HeartFChatting:
): ):
"""执行单个动作的通用函数""" """执行单个动作的通用函数"""
try: try:
if action_planner_info.action_type == "no_action": if action_planner_info.action_type == "no_reply":
# 直接处理no_action逻辑不再通过动作系统 # 直接处理no_action逻辑不再通过动作系统
reason = action_planner_info.reasoning or "选择不回复" reason = action_planner_info.reasoning or "选择不回复"
logger.info(f"{self.log_prefix} 选择不回复,原因: {reason}") logger.info(f"{self.log_prefix} 选择不回复,原因: {reason}")
@@ -594,26 +542,19 @@ class HeartFChatting:
action_data={"reason": reason}, action_data={"reason": reason},
action_name="no_action", action_name="no_action",
) )
return {"action_type": "no_action", "success": True, "reply_text": "", "command": ""} return {"action_type": "no_action", "success": True, "reply_text": "", "command": ""}
elif action_planner_info.action_type != "reply":
# 执行普通动作 elif action_planner_info.action_type == "wait_time":
with Timer("动作执行", cycle_timers): logger.info(f"{self.log_prefix} 等待{action_planner_info.action_data['time']}秒后回复")
success, reply_text, command = await self._handle_action( await asyncio.sleep(action_planner_info.action_data["time"])
action_planner_info.action_type, return {"action_type": "wait_time", "success": True, "reply_text": "", "command": ""}
action_planner_info.reasoning or "",
action_planner_info.action_data or {}, elif action_planner_info.action_type == "no_reply_until_call":
cycle_timers, logger.info(f"{self.log_prefix} 保持沉默,直到有人直接叫的名字")
thinking_id, self.no_reply_until_call = True
action_planner_info.action_message, return {"action_type": "no_reply_until_call", "success": True, "reply_text": "", "command": ""}
)
return { elif action_planner_info.action_type == "reply":
"action_type": action_planner_info.action_type,
"success": success,
"reply_text": reply_text,
"command": command,
}
else:
try: try:
success, llm_response = await generator_api.generate_reply( success, llm_response = await generator_api.generate_reply(
chat_stream=self.chat_stream, chat_stream=self.chat_stream,
@@ -652,6 +593,26 @@ class HeartFChatting:
"reply_text": reply_text, "reply_text": reply_text,
"loop_info": loop_info, "loop_info": loop_info,
} }
# 其他动作
else:
# 执行普通动作
with Timer("动作执行", cycle_timers):
success, reply_text, command = await self._handle_action(
action_planner_info.action_type,
action_planner_info.reasoning or "",
action_planner_info.action_data or {},
cycle_timers,
thinking_id,
action_planner_info.action_message,
)
return {
"action_type": action_planner_info.action_type,
"success": success,
"reply_text": reply_text,
"command": command,
}
except Exception as e: except Exception as e:
logger.error(f"{self.log_prefix} 执行动作时出错: {e}") logger.error(f"{self.log_prefix} 执行动作时出错: {e}")
logger.error(f"{self.log_prefix} 错误信息: {traceback.format_exc()}") logger.error(f"{self.log_prefix} 错误信息: {traceback.format_exc()}")

File diff suppressed because it is too large Load Diff

View File

@@ -56,7 +56,7 @@ TEMPLATE_DIR = os.path.join(PROJECT_ROOT, "template")
# 考虑到实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码 # 考虑到实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码
# 对该字段的更新请严格参照语义化版本规范https://semver.org/lang/zh-CN/ # 对该字段的更新请严格参照语义化版本规范https://semver.org/lang/zh-CN/
MMC_VERSION = "0.10.3-snapshot.1" MMC_VERSION = "0.10.3-snapshot.2"
def get_key_comment(toml_table, key): def get_key_comment(toml_table, key):

View File

@@ -277,10 +277,12 @@ async def _default_stream_response_handler(
# 空 choices / usage-only 帧的防御 # 空 choices / usage-only 帧的防御
if not hasattr(event, "choices") or not event.choices: if not hasattr(event, "choices") or not event.choices:
if hasattr(event, "usage") and event.usage: if hasattr(event, "usage") and event.usage:
# 安全地获取usage属性处理不同API版本的差异
usage_obj = event.usage
_usage_record = ( _usage_record = (
event.usage.prompt_tokens or 0, getattr(usage_obj, 'prompt_tokens', 0) or 0,
event.usage.completion_tokens or 0, getattr(usage_obj, 'completion_tokens', 0) or 0,
event.usage.total_tokens or 0, getattr(usage_obj, 'total_tokens', 0) or 0,
) )
continue # 跳过本帧,避免访问 choices[0] continue # 跳过本帧,避免访问 choices[0]
delta = event.choices[0].delta # 获取当前块的delta内容 delta = event.choices[0].delta # 获取当前块的delta内容
@@ -300,10 +302,12 @@ async def _default_stream_response_handler(
if event.usage: if event.usage:
# 如果有使用情况则将其存储在APIResponse对象中 # 如果有使用情况则将其存储在APIResponse对象中
# 安全地获取usage属性处理不同API版本的差异
usage_obj = event.usage
_usage_record = ( _usage_record = (
event.usage.prompt_tokens or 0, getattr(usage_obj, 'prompt_tokens', 0) or 0,
event.usage.completion_tokens or 0, getattr(usage_obj, 'completion_tokens', 0) or 0,
event.usage.total_tokens or 0, getattr(usage_obj, 'total_tokens', 0) or 0,
) )
try: try:
@@ -370,10 +374,12 @@ def _default_normal_response_parser(
# 提取Usage信息 # 提取Usage信息
if resp.usage: if resp.usage:
# 安全地获取usage属性处理不同API版本的差异
usage_obj = resp.usage
_usage_record = ( _usage_record = (
resp.usage.prompt_tokens or 0, getattr(usage_obj, 'prompt_tokens', 0) or 0,
resp.usage.completion_tokens or 0, getattr(usage_obj, 'completion_tokens', 0) or 0,
resp.usage.total_tokens or 0, getattr(usage_obj, 'total_tokens', 0) or 0,
) )
else: else:
_usage_record = None _usage_record = None

View File

@@ -5,6 +5,7 @@ import time
from enum import Enum from enum import Enum
from rich.traceback import install from rich.traceback import install
from typing import Tuple, List, Dict, Optional, Callable, Any from typing import Tuple, List, Dict, Optional, Callable, Any
import traceback
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.config import model_config from src.config.config import model_config
@@ -391,6 +392,7 @@ class LLMRequest:
logger.debug(f"附加内容: {str(e.ext_info)}") logger.debug(f"附加内容: {str(e.ext_info)}")
return -1, None # 不再重试请求该模型 return -1, None # 不再重试请求该模型
else: else:
print(traceback.format_exc())
logger.error(f"任务-'{task_name}' 模型-'{model_name}': 未知异常,错误信息-{str(e)}") logger.error(f"任务-'{task_name}' 模型-'{model_name}': 未知异常,错误信息-{str(e)}")
return -1, None # 不再重试请求该模型 return -1, None # 不再重试请求该模型