超级Ruff

This commit is contained in:
墨梓柒
2025-09-09 19:25:12 +08:00
parent ac2936d5fc
commit 163dbb6b90
68 changed files with 1092 additions and 1043 deletions

3
bot.py
View File

@@ -62,9 +62,10 @@ def easter_egg():
async def graceful_shutdown(): # sourcery skip: use-named-expression
try:
logger.info("正在优雅关闭麦麦...")
from src.plugin_system.core.events_manager import events_manager
from src.plugin_system.base.component_types import EventType
# 触发 ON_STOP 事件
await events_manager.handle_mai_events(event_type=EventType.ON_STOP)

View File

@@ -5,12 +5,11 @@ from typing import Dict, List
# Add project root to Python path
from src.common.database.database_model import Expression, ChatStreams
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, project_root)
def get_chat_name(chat_id: str) -> str:
"""Get chat name from chat_id by querying ChatStreams table directly"""
try:
@@ -18,7 +17,7 @@ def get_chat_name(chat_id: str) -> str:
chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id)
if chat_stream is None:
return f"未知聊天 ({chat_id})"
# 如果有群组信息,显示群组名称
if chat_stream.group_name:
return f"{chat_stream.group_name} ({chat_id})"
@@ -35,117 +34,106 @@ def calculate_time_distribution(expressions) -> Dict[str, int]:
"""Calculate distribution of last active time in days"""
now = time.time()
distribution = {
'0-1天': 0,
'1-3天': 0,
'3-7天': 0,
'7-14天': 0,
'14-30天': 0,
'30-60天': 0,
'60-90天': 0,
'90+天': 0
"0-1天": 0,
"1-3天": 0,
"3-7天": 0,
"7-14天": 0,
"14-30天": 0,
"30-60天": 0,
"60-90天": 0,
"90+天": 0,
}
for expr in expressions:
diff_days = (now - expr.last_active_time) / (24*3600)
diff_days = (now - expr.last_active_time) / (24 * 3600)
if diff_days < 1:
distribution['0-1天'] += 1
distribution["0-1天"] += 1
elif diff_days < 3:
distribution['1-3天'] += 1
distribution["1-3天"] += 1
elif diff_days < 7:
distribution['3-7天'] += 1
distribution["3-7天"] += 1
elif diff_days < 14:
distribution['7-14天'] += 1
distribution["7-14天"] += 1
elif diff_days < 30:
distribution['14-30天'] += 1
distribution["14-30天"] += 1
elif diff_days < 60:
distribution['30-60天'] += 1
distribution["30-60天"] += 1
elif diff_days < 90:
distribution['60-90天'] += 1
distribution["60-90天"] += 1
else:
distribution['90+天'] += 1
distribution["90+天"] += 1
return distribution
def calculate_count_distribution(expressions) -> Dict[str, int]:
    """Calculate the distribution of ``count`` values across expressions.

    Args:
        expressions: Iterable of objects exposing a numeric ``count`` attribute.

    Returns:
        Mapping from bucket label to the number of expressions in that bucket,
        in ascending bucket order. Every bucket is present even when empty.
        A value lands in the first bucket whose upper bound it is strictly
        below (so ``"1-2"`` holds ``1 <= count < 2``); anything not below 10
        lands in ``"10+"``.
    """
    # (exclusive upper bound, label) pairs, ascending; order defines both the
    # lookup order and the key order of the returned dict.
    buckets = (
        (1, "0-1"),
        (2, "1-2"),
        (3, "2-3"),
        (4, "3-4"),
        (5, "4-5"),
        (10, "5-10"),
    )
    overflow_label = "10+"

    distribution = {label: 0 for _, label in buckets}
    distribution[overflow_label] = 0

    for expr in expressions:
        cnt = expr.count
        for upper, label in buckets:
            if cnt < upper:
                distribution[label] += 1
                break
        else:
            # No bound exceeded the value: it belongs to the open-ended bucket.
            distribution[overflow_label] += 1
    return distribution
def get_top_expressions_by_chat(chat_id: str, top_n: int = 5) -> List[Expression]:
    """Get the top N most used expressions for a specific chat_id.

    Args:
        chat_id: Value matched against ``Expression.chat_id``.
        top_n: Maximum number of rows to request.

    Returns:
        The query built from ``Expression.select()``, filtered by ``chat_id``,
        ordered by ``Expression.count`` descending and limited to ``top_n``.
        NOTE(review): the annotation says ``List[Expression]`` but the query
        object itself is returned, not a materialized list; callers that only
        iterate are presumably fine - confirm before indexing or calling len().
    """
    return (
        Expression.select()
        .where(Expression.chat_id == chat_id)
        .order_by(Expression.count.desc())
        .limit(top_n)
    )
def show_overall_statistics(expressions, total: int) -> None:
    """Print the time and count distributions for all expressions.

    Args:
        expressions: Iterable of expression records passed on to
            ``calculate_time_distribution`` and ``calculate_count_distribution``.
        total: Denominator used for the percentage column.

    A ``total`` of 0 is reported as 0.00% for every bucket instead of raising
    ``ZeroDivisionError``.
    """

    def _percent(part: int) -> float:
        # Guard the division so an empty data set still prints a full table.
        return part / total * 100 if total else 0.0

    time_dist = calculate_time_distribution(expressions)
    count_dist = calculate_count_distribution(expressions)

    print("\n=== 总体统计 ===")
    print(f"总表达式数量: {total}")

    print("\n上次激活时间分布:")
    for period, count in time_dist.items():
        print(f"{period}: {count} ({_percent(count):.2f}%)")

    print("\ncount分布:")
    for range_, count in count_dist.items():
        print(f"{range_}: {count} ({_percent(count):.2f}%)")
def show_chat_statistics(chat_id: str, chat_name: str) -> None:
"""Show statistics for a specific chat"""
chat_exprs = list(Expression.select().where(Expression.chat_id == chat_id))
chat_total = len(chat_exprs)
print(f"\n=== {chat_name} ===")
print(f"表达式数量: {chat_total}")
if chat_total == 0:
print("该聊天没有表达式数据")
return
# Time distribution for this chat
time_dist = calculate_time_distribution(chat_exprs)
print("\n上次激活时间分布:")
for period, count in time_dist.items():
if count > 0:
print(f"{period}: {count} ({count/chat_total*100:.2f}%)")
print(f"{period}: {count} ({count / chat_total * 100:.2f}%)")
# Count distribution for this chat
count_dist = calculate_count_distribution(chat_exprs)
print("\ncount分布:")
for range_, count in count_dist.items():
if count > 0:
print(f"{range_}: {count} ({count/chat_total*100:.2f}%)")
print(f"{range_}: {count} ({count / chat_total * 100:.2f}%)")
# Top expressions
print("\nTop 10使用最多的表达式:")
top_exprs = get_top_expressions_by_chat(chat_id, 10)
@@ -163,32 +151,32 @@ def interactive_menu() -> None:
if not expressions:
print("数据库中没有找到表达式")
return
total = len(expressions)
# Get unique chat_ids and their names
chat_ids = list(set(expr.chat_id for expr in expressions))
chat_info = [(chat_id, get_chat_name(chat_id)) for chat_id in chat_ids]
chat_info.sort(key=lambda x: x[1]) # Sort by chat name
while True:
print("\n" + "="*50)
print("\n" + "=" * 50)
print("表达式统计分析")
print("="*50)
print("=" * 50)
print("0. 显示总体统计")
for i, (chat_id, chat_name) in enumerate(chat_info, 1):
chat_count = sum(1 for expr in expressions if expr.chat_id == chat_id)
print(f"{i}. {chat_name} ({chat_count}个表达式)")
print("q. 退出")
choice = input("\n请选择要查看的统计 (输入序号): ").strip()
if choice.lower() == 'q':
if choice.lower() == "q":
print("再见!")
break
try:
choice_num = int(choice)
if choice_num == 0:
@@ -200,9 +188,9 @@ def interactive_menu() -> None:
print("无效的选择,请重新输入")
except ValueError:
print("请输入有效的数字")
input("\n按回车键继续...")
if __name__ == "__main__":
interactive_menu()
interactive_menu()

View File

@@ -23,6 +23,7 @@ OPENIE_DIR = os.path.join(ROOT_PATH, "data", "openie")
logger = get_logger("OpenIE导入")
def ensure_openie_dir():
"""确保OpenIE数据目录存在"""
if not os.path.exists(OPENIE_DIR):
@@ -253,7 +254,7 @@ def main():
# 没有运行的事件循环,创建新的
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
# 在新的事件循环中运行异步主函数
loop.run_until_complete(main_async())

View File

@@ -12,6 +12,7 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from rich.progress import Progress # 替换为 rich 进度条
from src.common.logger import get_logger
# from src.chat.knowledge.lpmmconfig import global_config
from src.chat.knowledge.ie_process import info_extract_from_str
from src.chat.knowledge.open_ie import OpenIE
@@ -36,6 +37,7 @@ TEMP_DIR = os.path.join(ROOT_PATH, "temp")
# IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, "data", "imported_lpmm_data")
OPENIE_OUTPUT_DIR = os.path.join(ROOT_PATH, "data", "openie")
def ensure_dirs():
"""确保临时目录和输出目录存在"""
if not os.path.exists(TEMP_DIR):
@@ -48,6 +50,7 @@ def ensure_dirs():
os.makedirs(RAW_DATA_PATH)
logger.info(f"已创建原始数据目录: {RAW_DATA_PATH}")
# 创建一个线程安全的锁,用于保护文件操作和共享数据
file_lock = Lock()
open_ie_doc_lock = Lock()
@@ -56,13 +59,11 @@ open_ie_doc_lock = Lock()
shutdown_event = Event()
lpmm_entity_extract_llm = LLMRequest(
model_set=model_config.model_task_config.lpmm_entity_extract,
request_type="lpmm.entity_extract"
)
lpmm_rdf_build_llm = LLMRequest(
model_set=model_config.model_task_config.lpmm_rdf_build,
request_type="lpmm.rdf_build"
model_set=model_config.model_task_config.lpmm_entity_extract, request_type="lpmm.entity_extract"
)
lpmm_rdf_build_llm = LLMRequest(model_set=model_config.model_task_config.lpmm_rdf_build, request_type="lpmm.rdf_build")
def process_single_text(pg_hash, raw_data):
"""处理单个文本的函数,用于线程池"""
temp_file_path = f"{TEMP_DIR}/{pg_hash}.json"

View File

@@ -3,12 +3,11 @@ import sys
import os
from typing import Dict, List, Tuple, Optional
from datetime import datetime
# Add project root to Python path
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, project_root)
from src.common.database.database_model import Messages, ChatStreams #noqa
from src.common.database.database_model import Messages, ChatStreams # noqa
def get_chat_name(chat_id: str) -> str:
@@ -17,7 +16,7 @@ def get_chat_name(chat_id: str) -> str:
chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id)
if chat_stream is None:
return f"未知聊天 ({chat_id})"
if chat_stream.group_name:
return f"{chat_stream.group_name} ({chat_id})"
elif chat_stream.user_nickname:
@@ -39,66 +38,62 @@ def format_timestamp(timestamp: float) -> str:
def calculate_interest_value_distribution(messages) -> Dict[str, int]:
    """Calculate the distribution of ``interest_value`` across messages.

    Messages whose ``interest_value`` is ``None`` or equal to ``0.0`` are
    skipped. Every other value is converted with ``float()`` and counted in
    the first bucket whose upper bound it is strictly below; values not below
    10 are counted in ``"10.000+"``. Negative values therefore fall into the
    lowest bucket.

    Args:
        messages: Iterable of objects exposing an ``interest_value`` attribute.

    Returns:
        Mapping from bucket label to message count, in ascending bucket order.
        Every bucket is present even when empty.
    """
    # (exclusive upper bound, label) pairs, ascending; order defines both the
    # lookup order and the key order of the returned dict.
    buckets = (
        (0.010, "0.000-0.010"),
        (0.050, "0.010-0.050"),
        (0.100, "0.050-0.100"),
        (0.500, "0.100-0.500"),
        (1.000, "0.500-1.000"),
        (2.000, "1.000-2.000"),
        (5.000, "2.000-5.000"),
        (10.000, "5.000-10.000"),
    )
    overflow_label = "10.000+"

    distribution = {label: 0 for _, label in buckets}
    distribution[overflow_label] = 0

    for msg in messages:
        if msg.interest_value is None or msg.interest_value == 0.0:
            continue

        value = float(msg.interest_value)
        for upper, label in buckets:
            if value < upper:
                distribution[label] += 1
                break
        else:
            # No bound exceeded the value: it belongs to the open-ended bucket.
            distribution[overflow_label] += 1
    return distribution
def get_interest_value_stats(messages) -> Dict[str, float]:
    """Calculate basic statistics for ``interest_value``.

    Messages whose ``interest_value`` is ``None`` or equal to ``0.0`` are
    ignored; the remaining values are converted with ``float()``.

    Args:
        messages: Iterable of objects exposing an ``interest_value`` attribute.

    Returns:
        Dict with keys ``count``, ``min``, ``max``, ``avg`` and ``median``.
        All five are 0 when no message carries a usable value.
    """
    values = sorted(
        float(msg.interest_value)
        for msg in messages
        if msg.interest_value is not None and msg.interest_value != 0.0
    )

    if not values:
        return {"count": 0, "min": 0, "max": 0, "avg": 0, "median": 0}

    count = len(values)
    mid = count // 2
    # Odd length: the middle element; even length: mean of the two middle ones.
    median = values[mid] if count % 2 == 1 else (values[mid - 1] + values[mid]) / 2

    return {
        "count": count,
        # The list is sorted, so the extremes sit at the ends.
        "min": values[0],
        "max": values[-1],
        "avg": sum(values) / count,
        "median": median,
    }
@@ -109,20 +104,24 @@ def get_available_chats() -> List[Tuple[str, str, int]]:
chat_counts = {}
for msg in Messages.select(Messages.chat_id).distinct():
chat_id = msg.chat_id
count = Messages.select().where(
(Messages.chat_id == chat_id) &
(Messages.interest_value.is_null(False)) &
(Messages.interest_value != 0.0)
).count()
count = (
Messages.select()
.where(
(Messages.chat_id == chat_id)
& (Messages.interest_value.is_null(False))
& (Messages.interest_value != 0.0)
)
.count()
)
if count > 0:
chat_counts[chat_id] = count
# 获取聊天名称
result = []
for chat_id, count in chat_counts.items():
chat_name = get_chat_name(chat_id)
result.append((chat_id, chat_name, count))
# 按消息数量排序
result.sort(key=lambda x: x[2], reverse=True)
return result
@@ -135,30 +134,30 @@ def get_time_range_input() -> Tuple[Optional[float], Optional[float]]:
"""Get time range input from user"""
print("\n时间范围选择:")
print("1. 最近1天")
print("2. 最近3天")
print("2. 最近3天")
print("3. 最近7天")
print("4. 最近30天")
print("5. 自定义时间范围")
print("6. 不限制时间")
choice = input("请选择时间范围 (1-6): ").strip()
now = time.time()
if choice == "1":
return now - 24*3600, now
return now - 24 * 3600, now
elif choice == "2":
return now - 3*24*3600, now
return now - 3 * 24 * 3600, now
elif choice == "3":
return now - 7*24*3600, now
return now - 7 * 24 * 3600, now
elif choice == "4":
return now - 30*24*3600, now
return now - 30 * 24 * 3600, now
elif choice == "5":
print("请输入开始时间 (格式: YYYY-MM-DD HH:MM:SS):")
start_str = input().strip()
print("请输入结束时间 (格式: YYYY-MM-DD HH:MM:SS):")
end_str = input().strip()
try:
start_time = datetime.strptime(start_str, "%Y-%m-%d %H:%M:%S").timestamp()
end_time = datetime.strptime(end_str, "%Y-%m-%d %H:%M:%S").timestamp()
@@ -170,41 +169,40 @@ def get_time_range_input() -> Tuple[Optional[float], Optional[float]]:
return None, None
def analyze_interest_values(chat_id: Optional[str] = None, start_time: Optional[float] = None, end_time: Optional[float] = None) -> None:
def analyze_interest_values(
chat_id: Optional[str] = None, start_time: Optional[float] = None, end_time: Optional[float] = None
) -> None:
"""Analyze interest values with optional filters"""
# 构建查询条件
query = Messages.select().where(
(Messages.interest_value.is_null(False)) &
(Messages.interest_value != 0.0)
)
query = Messages.select().where((Messages.interest_value.is_null(False)) & (Messages.interest_value != 0.0))
if chat_id:
query = query.where(Messages.chat_id == chat_id)
if start_time:
query = query.where(Messages.time >= start_time)
if end_time:
query = query.where(Messages.time <= end_time)
messages = list(query)
if not messages:
print("没有找到符合条件的消息")
return
# 计算统计信息
distribution = calculate_interest_value_distribution(messages)
stats = get_interest_value_stats(messages)
# 显示结果
print("\n=== Interest Value 分析结果 ===")
if chat_id:
print(f"聊天: {get_chat_name(chat_id)}")
else:
print("聊天: 全部聊天")
if start_time and end_time:
print(f"时间范围: {format_timestamp(start_time)}{format_timestamp(end_time)}")
elif start_time:
@@ -213,16 +211,16 @@ def analyze_interest_values(chat_id: Optional[str] = None, start_time: Optional[
print(f"时间范围: {format_timestamp(end_time)} 之前")
else:
print("时间范围: 不限制")
print("\n基本统计:")
print(f"有效消息数量: {stats['count']} (排除null和0值)")
print(f"最小值: {stats['min']:.3f}")
print(f"最大值: {stats['max']:.3f}")
print(f"平均值: {stats['avg']:.3f}")
print(f"中位数: {stats['median']:.3f}")
print("\nInterest Value 分布:")
total = stats['count']
total = stats["count"]
for range_name, count in distribution.items():
if count > 0:
percentage = count / total * 100
@@ -231,34 +229,34 @@ def analyze_interest_values(chat_id: Optional[str] = None, start_time: Optional[
def interactive_menu() -> None:
"""Interactive menu for interest value analysis"""
while True:
print("\n" + "="*50)
print("\n" + "=" * 50)
print("Interest Value 分析工具")
print("="*50)
print("=" * 50)
print("1. 分析全部聊天")
print("2. 选择特定聊天分析")
print("q. 退出")
choice = input("\n请选择分析模式 (1-2, q): ").strip()
if choice.lower() == 'q':
if choice.lower() == "q":
print("再见!")
break
chat_id = None
if choice == "2":
# 显示可用的聊天列表
chats = get_available_chats()
if not chats:
print("没有找到有interest_value数据的聊天")
continue
print(f"\n可用的聊天 (共{len(chats)}个):")
for i, (_cid, name, count) in enumerate(chats, 1):
print(f"{i}. {name} ({count}条有效消息)")
try:
chat_choice = int(input(f"\n请选择聊天 (1-{len(chats)}): ").strip())
if 1 <= chat_choice <= len(chats):
@@ -269,19 +267,19 @@ def interactive_menu() -> None:
except ValueError:
print("请输入有效数字")
continue
elif choice != "1":
print("无效选择")
continue
# 获取时间范围
start_time, end_time = get_time_range_input()
# 执行分析
analyze_interest_values(chat_id, start_time, end_time)
input("\n按回车键继续...")
if __name__ == "__main__":
interactive_menu()
interactive_menu()

View File

@@ -828,7 +828,7 @@ class LogViewer:
parts, tags = self.formatter.format_log_entry(log_entry)
line_text = " ".join(parts)
log_lines.append(line_text)
with open(filename, "w", encoding="utf-8") as f:
f.write("\n".join(log_lines))
messagebox.showinfo("导出成功", f"日志已导出到: {filename}")
@@ -1188,15 +1188,16 @@ class LogViewer:
line_count += 1
except json.JSONDecodeError:
continue
# 如果发现了新模块,在主线程中更新模块集合
if new_modules:
def update_modules():
self.modules.update(new_modules)
self.update_module_list()
self.root.after(0, update_modules)
return new_entries
def append_new_logs(self, new_entries):
@@ -1424,4 +1425,3 @@ def main():
if __name__ == "__main__":
main()

View File

@@ -2,6 +2,7 @@ import os
from pathlib import Path
import sys # 新增系统模块导入
from src.chat.knowledge.utils.hash import get_sha256
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from src.common.logger import get_logger
@@ -10,6 +11,7 @@ ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
RAW_DATA_PATH = os.path.join(ROOT_PATH, "data/lpmm_raw_data")
# IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, "data/imported_lpmm_data")
def _process_text_file(file_path):
"""处理单个文本文件,返回段落列表"""
with open(file_path, "r", encoding="utf-8") as f:
@@ -44,6 +46,7 @@ def _process_multi_files() -> list:
all_paragraphs.extend(paragraphs)
return all_paragraphs
def load_raw_data() -> tuple[list[str], list[str]]:
"""加载原始数据文件
@@ -72,4 +75,4 @@ def load_raw_data() -> tuple[list[str], list[str]]:
raw_data.append(item)
logger.info(f"共读取到{len(raw_data)}条数据")
return sha256_list, raw_data
return sha256_list, raw_data

View File

@@ -4,21 +4,22 @@ import os
import re
from typing import Dict, List, Tuple, Optional
from datetime import datetime
# Add project root to Python path
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, project_root)
from src.common.database.database_model import Messages, ChatStreams #noqa
from src.common.database.database_model import Messages, ChatStreams # noqa
def contains_emoji_or_image_tags(text: str) -> bool:
    """Check if text contains a ``[表情包...]`` or ``[图片...]`` tag.

    Args:
        text: Text to inspect; ``None`` and the empty string are accepted.

    Returns:
        True when at least one complete bracketed tag (opening ``[``, one of
        the two keywords, any non-``]`` characters, closing ``]``) is found;
        False otherwise, including for falsy input.
    """
    if not text:
        return False

    # One scan with an alternation covers both tag kinds.
    tag_pattern = r"\[(?:表情包|图片)[^\]]*\]"
    return re.search(tag_pattern, text) is not None
@@ -26,14 +27,14 @@ def clean_reply_text(text: str) -> str:
"""Remove reply references like [回复 xxxx...] from text"""
if not text:
return text
# 匹配 [回复 xxxx...] 格式的内容
# 使用非贪婪匹配,匹配到第一个 ] 就停止
cleaned_text = re.sub(r'\[回复[^\]]*\]', '', text)
cleaned_text = re.sub(r"\[回复[^\]]*\]", "", text)
# 去除多余的空白字符
cleaned_text = cleaned_text.strip()
return cleaned_text
@@ -43,7 +44,7 @@ def get_chat_name(chat_id: str) -> str:
chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id)
if chat_stream is None:
return f"未知聊天 ({chat_id})"
if chat_stream.group_name:
return f"{chat_stream.group_name} ({chat_id})"
elif chat_stream.user_nickname:
@@ -65,63 +66,63 @@ def format_timestamp(timestamp: float) -> str:
def calculate_text_length_distribution(messages) -> Dict[str, int]:
    """Calculate the distribution of ``processed_plain_text`` length.

    Messages are skipped when ``processed_plain_text`` is ``None`` or when
    ``contains_emoji_or_image_tags`` reports a tag in it. The remaining text
    is passed through ``clean_reply_text`` and its ``len()`` is bucketed.

    Args:
        messages: Iterable of objects exposing ``processed_plain_text``.

    Returns:
        Mapping from length-bucket label to message count, in ascending bucket
        order. Every bucket is present even when empty. Upper bounds are
        inclusive (``"1-5"`` holds lengths 1 through 5); lengths above 1000
        are counted in ``"1000+"``.
    """
    # (inclusive upper bound, label) pairs, ascending; order defines both the
    # lookup order and the key order of the returned dict.
    buckets = (
        (0, "0"),  # empty text
        (5, "1-5"),
        (10, "6-10"),
        (20, "11-20"),
        (30, "21-30"),
        (50, "31-50"),
        (70, "51-70"),
        (100, "71-100"),
        (150, "101-150"),
        (200, "151-200"),
        (300, "201-300"),
        (500, "301-500"),
        (1000, "501-1000"),
    )
    overflow_label = "1000+"

    distribution = {label: 0 for _, label in buckets}
    distribution[overflow_label] = 0

    for msg in messages:
        if msg.processed_plain_text is None:
            continue

        # Skip messages that carry emoji or image tags.
        if contains_emoji_or_image_tags(msg.processed_plain_text):
            continue

        # Strip reply references before measuring.
        cleaned_text = clean_reply_text(msg.processed_plain_text)
        length = len(cleaned_text)

        for upper, label in buckets:
            if length <= upper:
                distribution[label] += 1
                break
        else:
            # Longer than every bound: open-ended bucket.
            distribution[overflow_label] += 1
    return distribution
@@ -130,7 +131,7 @@ def get_text_length_stats(messages) -> Dict[str, float]:
lengths = []
null_count = 0
excluded_count = 0 # 被排除的消息数量
for msg in messages:
if msg.processed_plain_text is None:
null_count += 1
@@ -141,29 +142,29 @@ def get_text_length_stats(messages) -> Dict[str, float]:
# 清理文本中的回复引用
cleaned_text = clean_reply_text(msg.processed_plain_text)
lengths.append(len(cleaned_text))
if not lengths:
return {
'count': 0,
'null_count': null_count,
'excluded_count': excluded_count,
'min': 0,
'max': 0,
'avg': 0,
'median': 0
"count": 0,
"null_count": null_count,
"excluded_count": excluded_count,
"min": 0,
"max": 0,
"avg": 0,
"median": 0,
}
lengths.sort()
count = len(lengths)
return {
'count': count,
'null_count': null_count,
'excluded_count': excluded_count,
'min': min(lengths),
'max': max(lengths),
'avg': sum(lengths) / count,
'median': lengths[count // 2] if count % 2 == 1 else (lengths[count // 2 - 1] + lengths[count // 2]) / 2
"count": count,
"null_count": null_count,
"excluded_count": excluded_count,
"min": min(lengths),
"max": max(lengths),
"avg": sum(lengths) / count,
"median": lengths[count // 2] if count % 2 == 1 else (lengths[count // 2 - 1] + lengths[count // 2]) / 2,
}
@@ -174,21 +175,25 @@ def get_available_chats() -> List[Tuple[str, str, int]]:
chat_counts = {}
for msg in Messages.select(Messages.chat_id).distinct():
chat_id = msg.chat_id
count = Messages.select().where(
(Messages.chat_id == chat_id) &
(Messages.is_emoji != 1) &
(Messages.is_picid != 1) &
(Messages.is_command != 1)
).count()
count = (
Messages.select()
.where(
(Messages.chat_id == chat_id)
& (Messages.is_emoji != 1)
& (Messages.is_picid != 1)
& (Messages.is_command != 1)
)
.count()
)
if count > 0:
chat_counts[chat_id] = count
# 获取聊天名称
result = []
for chat_id, count in chat_counts.items():
chat_name = get_chat_name(chat_id)
result.append((chat_id, chat_name, count))
# 按消息数量排序
result.sort(key=lambda x: x[2], reverse=True)
return result
@@ -201,30 +206,30 @@ def get_time_range_input() -> Tuple[Optional[float], Optional[float]]:
"""Get time range input from user"""
print("\n时间范围选择:")
print("1. 最近1天")
print("2. 最近3天")
print("2. 最近3天")
print("3. 最近7天")
print("4. 最近30天")
print("5. 自定义时间范围")
print("6. 不限制时间")
choice = input("请选择时间范围 (1-6): ").strip()
now = time.time()
if choice == "1":
return now - 24*3600, now
return now - 24 * 3600, now
elif choice == "2":
return now - 3*24*3600, now
return now - 3 * 24 * 3600, now
elif choice == "3":
return now - 7*24*3600, now
return now - 7 * 24 * 3600, now
elif choice == "4":
return now - 30*24*3600, now
return now - 30 * 24 * 3600, now
elif choice == "5":
print("请输入开始时间 (格式: YYYY-MM-DD HH:MM:SS):")
start_str = input().strip()
print("请输入结束时间 (格式: YYYY-MM-DD HH:MM:SS):")
end_str = input().strip()
try:
start_time = datetime.strptime(start_str, "%Y-%m-%d %H:%M:%S").timestamp()
end_time = datetime.strptime(end_str, "%Y-%m-%d %H:%M:%S").timestamp()
@@ -239,13 +244,13 @@ def get_time_range_input() -> Tuple[Optional[float], Optional[float]]:
def get_top_longest_messages(messages, top_n: int = 10) -> List[Tuple[str, int, str, str]]:
"""Get top N longest messages"""
message_lengths = []
for msg in messages:
if msg.processed_plain_text is not None:
# 排除包含表情包或图片标记的消息
if contains_emoji_or_image_tags(msg.processed_plain_text):
continue
# 清理文本中的回复引用
cleaned_text = clean_reply_text(msg.processed_plain_text)
length = len(cleaned_text)
@@ -254,42 +259,40 @@ def get_top_longest_messages(messages, top_n: int = 10) -> List[Tuple[str, int,
# 截取前100个字符作为预览
preview = cleaned_text[:100] + "..." if len(cleaned_text) > 100 else cleaned_text
message_lengths.append((chat_name, length, time_str, preview))
# 按长度排序取前N个
message_lengths.sort(key=lambda x: x[1], reverse=True)
return message_lengths[:top_n]
def analyze_text_lengths(chat_id: Optional[str] = None, start_time: Optional[float] = None, end_time: Optional[float] = None) -> None:
def analyze_text_lengths(
chat_id: Optional[str] = None, start_time: Optional[float] = None, end_time: Optional[float] = None
) -> None:
"""Analyze processed_plain_text lengths with optional filters"""
# 构建查询条件,排除特殊类型的消息
query = Messages.select().where(
(Messages.is_emoji != 1) &
(Messages.is_picid != 1) &
(Messages.is_command != 1)
)
query = Messages.select().where((Messages.is_emoji != 1) & (Messages.is_picid != 1) & (Messages.is_command != 1))
if chat_id:
query = query.where(Messages.chat_id == chat_id)
if start_time:
query = query.where(Messages.time >= start_time)
if end_time:
query = query.where(Messages.time <= end_time)
messages = list(query)
if not messages:
print("没有找到符合条件的消息")
return
# 计算统计信息
distribution = calculate_text_length_distribution(messages)
stats = get_text_length_stats(messages)
top_longest = get_top_longest_messages(messages, 10)
# 显示结果
print("\n=== Processed Plain Text 长度分析结果 ===")
print("(已排除表情、图片ID、命令类型消息已排除[表情包]和[图片]标记消息,已清理回复引用)")
@@ -297,7 +300,7 @@ def analyze_text_lengths(chat_id: Optional[str] = None, start_time: Optional[flo
print(f"聊天: {get_chat_name(chat_id)}")
else:
print("聊天: 全部聊天")
if start_time and end_time:
print(f"时间范围: {format_timestamp(start_time)}{format_timestamp(end_time)}")
elif start_time:
@@ -306,26 +309,26 @@ def analyze_text_lengths(chat_id: Optional[str] = None, start_time: Optional[flo
print(f"时间范围: {format_timestamp(end_time)} 之前")
else:
print("时间范围: 不限制")
print("\n基本统计:")
print(f"总消息数量: {len(messages)}")
print(f"有文本消息数量: {stats['count']}")
print(f"空文本消息数量: {stats['null_count']}")
print(f"被排除的消息数量: {stats['excluded_count']}")
if stats['count'] > 0:
if stats["count"] > 0:
print(f"最短长度: {stats['min']} 字符")
print(f"最长长度: {stats['max']} 字符")
print(f"平均长度: {stats['avg']:.2f} 字符")
print(f"中位数长度: {stats['median']:.2f} 字符")
print("\n文本长度分布:")
total = stats['count']
total = stats["count"]
if total > 0:
for range_name, count in distribution.items():
if count > 0:
percentage = count / total * 100
print(f"{range_name} 字符: {count} ({percentage:.2f}%)")
# 显示最长的消息
if top_longest:
print(f"\n最长的 {len(top_longest)} 条消息:")
@@ -338,34 +341,34 @@ def analyze_text_lengths(chat_id: Optional[str] = None, start_time: Optional[flo
def interactive_menu() -> None:
"""Interactive menu for text length analysis"""
while True:
print("\n" + "="*50)
print("\n" + "=" * 50)
print("Processed Plain Text 长度分析工具")
print("="*50)
print("=" * 50)
print("1. 分析全部聊天")
print("2. 选择特定聊天分析")
print("q. 退出")
choice = input("\n请选择分析模式 (1-2, q): ").strip()
if choice.lower() == 'q':
if choice.lower() == "q":
print("再见!")
break
chat_id = None
if choice == "2":
# 显示可用的聊天列表
chats = get_available_chats()
if not chats:
print("没有找到聊天数据")
continue
print(f"\n可用的聊天 (共{len(chats)}个):")
for i, (_cid, name, count) in enumerate(chats, 1):
print(f"{i}. {name} ({count}条消息)")
try:
chat_choice = int(input(f"\n请选择聊天 (1-{len(chats)}): ").strip())
if 1 <= chat_choice <= len(chats):
@@ -376,19 +379,19 @@ def interactive_menu() -> None:
except ValueError:
print("请输入有效数字")
continue
elif choice != "1":
print("无效选择")
continue
# 获取时间范围
start_time, end_time = get_time_range_input()
# 执行分析
analyze_text_lengths(chat_id, start_time, end_time)
input("\n按回车键继续...")
if __name__ == "__main__":
interactive_menu()
interactive_menu()

View File

@@ -708,7 +708,7 @@ class EmojiManager:
if not emoji.is_deleted and emoji.hash == emoji_hash:
return emoji
return None # 如果循环结束还没找到,则返回 None
async def get_emoji_tag_by_hash(self, emoji_hash: str) -> Optional[List[str]]:
"""根据哈希值获取已注册表情包的情感标签列表
@@ -731,7 +731,7 @@ class EmojiManager:
emoji_record = Emoji.get_or_none(Emoji.emoji_hash == emoji_hash)
if emoji_record and emoji_record.emotion:
logger.info(f"[缓存命中] 从数据库获取表情包情感标签: {emoji_record.emotion[:50]}...")
return emoji_record.emotion.split(',')
return emoji_record.emotion.split(",")
except Exception as e:
logger.error(f"从数据库查询表情包情感标签时出错: {e}")

View File

@@ -77,10 +77,10 @@ class ExpressionSelector:
def can_use_expression_for_chat(self, chat_id: str) -> bool:
"""
检查指定聊天流是否允许使用表达
Args:
chat_id: 聊天流ID
Returns:
bool: 是否允许使用表达
"""
@@ -123,9 +123,7 @@ class ExpressionSelector:
return group_chat_ids
return [chat_id]
def get_random_expressions(
self, chat_id: str, total_num: int
) -> List[Dict[str, Any]]:
def get_random_expressions(self, chat_id: str, total_num: int) -> List[Dict[str, Any]]:
# sourcery skip: extract-duplicate-method, move-assign
# 支持多chat_id合并抽选
related_chat_ids = self.get_related_chat_ids(chat_id)
@@ -200,7 +198,7 @@ class ExpressionSelector:
) -> Tuple[List[Dict[str, Any]], List[int]]:
# sourcery skip: inline-variable, list-comprehension
"""使用LLM选择适合的表达方式"""
# 检查是否允许在此聊天流中使用表达
if not self.can_use_expression_for_chat(chat_id):
logger.debug(f"聊天流 {chat_id} 不允许使用表达,返回空列表")
@@ -208,7 +206,7 @@ class ExpressionSelector:
# 1. 获取20个随机表达方式现在按权重抽取
style_exprs = self.get_random_expressions(chat_id, 10)
if len(style_exprs) < 10:
logger.info(f"聊天流 {chat_id} 表达方式正在积累中")
return [], []
@@ -248,7 +246,6 @@ class ExpressionSelector:
# 4. 调用LLM
try:
# start_time = time.time()
content, (reasoning_content, model_name, _) = await self.llm_model.generate_response_async(prompt=prompt)
# logger.info(f"LLM请求时间: {model_name} {time.time() - start_time} \n{prompt}")
@@ -295,7 +292,6 @@ class ExpressionSelector:
except Exception as e:
logger.error(f"LLM处理表达方式选择时出错: {e}")
return [], []
init_prompt()

View File

@@ -119,4 +119,3 @@ def get_global_focus_value() -> Optional[float]:
return get_time_based_focus_value(config_item[1:])
return None

View File

@@ -124,5 +124,3 @@ def get_global_frequency() -> Optional[float]:
return get_time_based_frequency(config_item[1:])
return None

View File

@@ -34,4 +34,4 @@ def parse_stream_config_to_chat_id(stream_config_str: str) -> Optional[str]:
return hashlib.md5(key.encode()).hexdigest()
except (ValueError, IndexError):
return None
return None

View File

@@ -155,21 +155,21 @@ class HeartFChatting:
timer_strings.append(f"{name}: {formatted_time}")
# 获取动作类型,兼容新旧格式
action_type = "未知动作"
_action_type = "未知动作"
if hasattr(self, "_current_cycle_detail") and self._current_cycle_detail:
loop_plan_info = self._current_cycle_detail.loop_plan_info
if isinstance(loop_plan_info, dict):
action_result = loop_plan_info.get("action_result", {})
if isinstance(action_result, dict):
# 旧格式action_result是字典
action_type = action_result.get("action_type", "未知动作")
_action_type = action_result.get("action_type", "未知动作")
elif isinstance(action_result, list) and action_result:
# 新格式action_result是actions列表
# TODO: 把这里写明白
action_type = action_result[0].action_type or "未知动作"
_action_type = action_result[0].action_type or "未知动作"
elif isinstance(loop_plan_info, list) and loop_plan_info:
# 直接是actions列表的情况
action_type = loop_plan_info[0].get("action_type", "未知动作")
_action_type = loop_plan_info[0].get("action_type", "未知动作")
logger.info(
f"{self.log_prefix}{self._current_cycle_detail.cycle_id}次思考,"
@@ -258,7 +258,11 @@ class HeartFChatting:
return loop_info, reply_text, cycle_timers
async def _observe(self, interest_value: float = 0.0, recent_messages_list: List["DatabaseMessages"] = []) -> bool:
async def _observe(
self, interest_value: float = 0.0, recent_messages_list: Optional[List["DatabaseMessages"]] = None
) -> bool:
if recent_messages_list is None:
recent_messages_list = []
reply_text = "" # 初始化reply_text变量避免UnboundLocalError
# 使用sigmoid函数将interest_value转换为概率

View File

@@ -3,14 +3,16 @@ from typing import Any, Optional, Dict
from src.common.logger import get_logger
from src.chat.heart_flow.heartFC_chat import HeartFChatting
logger = get_logger("heartflow")
class Heartflow:
"""主心流协调器,负责初始化并协调聊天"""
def __init__(self):
self.heartflow_chat_list: Dict[Any, HeartFChatting] = {}
async def get_or_create_heartflow_chat(self, chat_id: Any) -> Optional[HeartFChatting]:
"""获取或创建一个新的HeartFChatting实例"""
try:
@@ -18,7 +20,7 @@ class Heartflow:
if chat := self.heartflow_chat_list.get(chat_id):
return chat
else:
new_chat = HeartFChatting(chat_id = chat_id)
new_chat = HeartFChatting(chat_id=chat_id)
await new_chat.start()
self.heartflow_chat_list[chat_id] = new_chat
return new_chat
@@ -27,4 +29,5 @@ class Heartflow:
traceback.print_exc()
return None
heartflow = Heartflow()

View File

@@ -23,6 +23,7 @@ if TYPE_CHECKING:
logger = get_logger("chat")
async def _calculate_interest(message: MessageRecv) -> Tuple[float, list[str]]:
"""计算消息的兴趣度
@@ -34,14 +35,14 @@ async def _calculate_interest(message: MessageRecv) -> Tuple[float, list[str]]:
"""
if message.is_picid or message.is_emoji:
return 0.0, []
is_mentioned,is_at,reply_probability_boost = is_mentioned_bot_in_message(message)
is_mentioned, is_at, reply_probability_boost = is_mentioned_bot_in_message(message)
interested_rate = 0.0
with Timer("记忆激活"):
interested_rate, keywords,keywords_lite = await hippocampus_manager.get_activate_from_text(
interested_rate, keywords, keywords_lite = await hippocampus_manager.get_activate_from_text(
message.processed_plain_text,
max_depth= 4,
max_depth=4,
fast_retrieval=global_config.chat.interest_rate_mode == "fast",
)
message.key_words = keywords
@@ -51,7 +52,7 @@ async def _calculate_interest(message: MessageRecv) -> Tuple[float, list[str]]:
text_len = len(message.processed_plain_text)
# 根据文本长度分布调整兴趣度,采用分段函数实现更精确的兴趣度计算
# 基于实际分布0-5字符(26.57%), 6-10字符(27.18%), 11-20字符(22.76%), 21-30字符(10.33%), 31+字符(13.86%)
if text_len == 0:
base_interest = 0.01 # 空消息最低兴趣度
elif text_len <= 5:
@@ -75,16 +76,15 @@ async def _calculate_interest(message: MessageRecv) -> Tuple[float, list[str]]:
else:
# 100+字符:对数增长 0.26 -> 0.3,增长率递减
base_interest = 0.26 + (0.3 - 0.26) * (math.log10(text_len - 99) / math.log10(901)) # 1000-99=901
# 确保在范围内
base_interest = min(max(base_interest, 0.01), 0.3)
message.interest_value = base_interest
message.is_mentioned = is_mentioned
message.is_at = is_at
message.reply_probability_boost = reply_probability_boost
return base_interest, keywords
@@ -115,14 +115,13 @@ class HeartFCMessageReceiver:
# 2. 兴趣度计算与更新
interested_rate, keywords = await _calculate_interest(message)
await self.storage.store_message(message, chat)
heartflow_chat: HeartFChatting = await heartflow.get_or_create_heartflow_chat(chat.stream_id) # type: ignore
# subheartflow.add_message_to_normal_chat_cache(message, interested_rate, is_mentioned)
if global_config.mood.enable_mood:
if global_config.mood.enable_mood:
chat_mood = mood_manager.get_mood_by_chat_id(heartflow_chat.stream_id)
asyncio.create_task(chat_mood.update_mood_by_message(message, interested_rate))
@@ -132,7 +131,7 @@ class HeartFCMessageReceiver:
# 用这个pattern截取出id部分picid是一个list并替换成对应的图片描述
picid_pattern = r"\[picid:([^\]]+)\]"
picid_list = re.findall(picid_pattern, message.processed_plain_text)
# 创建替换后的文本
processed_text = message.processed_plain_text
if picid_list:
@@ -145,18 +144,20 @@ class HeartFCMessageReceiver:
# 如果没有找到图片描述,则移除[picid:xxxx]标记
processed_text = processed_text.replace(f"[picid:{picid}]", "[图片:网络不好,图片无法加载]")
# 应用用户引用格式替换,将回复<aaa:bbb>和@<aaa:bbb>格式转换为可读格式
processed_plain_text = replace_user_references(
processed_text,
message.message_info.platform, # type: ignore
replace_bot_name=True
message.message_info.platform, # type: ignore
replace_bot_name=True,
)
logger.info(f"[{mes_name}]{userinfo.user_nickname}:{processed_plain_text}[{interested_rate:.2f}]") # type: ignore
_ = Person.register_person(platform=message.message_info.platform, user_id=message.message_info.user_info.user_id,nickname=userinfo.user_nickname) # type: ignore
_ = Person.register_person(
platform=message.message_info.platform,
user_id=message.message_info.user_info.user_id,
nickname=userinfo.user_nickname,
) # type: ignore
except Exception as e:
logger.error(f"消息处理失败: {e}")

View File

@@ -124,6 +124,7 @@ async def send_typing():
message_type="state", content="typing", stream_id=chat.stream_id, storage_message=False
)
async def stop_typing():
group_info = GroupInfo(platform="amaidesu_default", group_id="114514", group_name="内心")
@@ -135,4 +136,4 @@ async def stop_typing():
await send_api.custom_to_stream(
message_type="state", content="stop_typing", stream_id=chat.stream_id, storage_message=False
)
)

View File

@@ -30,6 +30,7 @@ DATA_PATH = os.path.join(ROOT_PATH, "data")
qa_manager = None
inspire_manager = None
def lpmm_start_up(): # sourcery skip: extract-duplicate-method
# 检查LPMM知识库是否启用
if global_config.lpmm_knowledge.enable:

View File

@@ -32,11 +32,11 @@ install(extra_lines=3)
# 多线程embedding配置常量
DEFAULT_MAX_WORKERS = 10 # 默认最大线程数
DEFAULT_CHUNK_SIZE = 10 # 默认每个线程处理的数据块大小
MIN_CHUNK_SIZE = 1 # 最小分块大小
MAX_CHUNK_SIZE = 50 # 最大分块大小
MIN_WORKERS = 1 # 最小线程数
MAX_WORKERS = 20 # 最大线程数
DEFAULT_CHUNK_SIZE = 10 # 默认每个线程处理的数据块大小
MIN_CHUNK_SIZE = 1 # 最小分块大小
MAX_CHUNK_SIZE = 50 # 最大分块大小
MIN_WORKERS = 1 # 最小线程数
MAX_WORKERS = 20 # 最大线程数
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
EMBEDDING_DATA_DIR = os.path.join(ROOT_PATH, "data", "embedding")
@@ -93,7 +93,13 @@ class EmbeddingStoreItem:
class EmbeddingStore:
def __init__(self, namespace: str, dir_path: str, max_workers: int = DEFAULT_MAX_WORKERS, chunk_size: int = DEFAULT_CHUNK_SIZE):
def __init__(
self,
namespace: str,
dir_path: str,
max_workers: int = DEFAULT_MAX_WORKERS,
chunk_size: int = DEFAULT_CHUNK_SIZE,
):
self.namespace = namespace
self.dir = dir_path
self.embedding_file_path = f"{dir_path}/{namespace}.parquet"
@@ -103,12 +109,16 @@ class EmbeddingStore:
# 多线程配置参数验证和设置
self.max_workers = max(MIN_WORKERS, min(MAX_WORKERS, max_workers))
self.chunk_size = max(MIN_CHUNK_SIZE, min(MAX_CHUNK_SIZE, chunk_size))
# 如果配置值被调整,记录日志
if self.max_workers != max_workers:
logger.warning(f"max_workers 已从 {max_workers} 调整为 {self.max_workers} (范围: {MIN_WORKERS}-{MAX_WORKERS})")
logger.warning(
f"max_workers 已从 {max_workers} 调整为 {self.max_workers} (范围: {MIN_WORKERS}-{MAX_WORKERS})"
)
if self.chunk_size != chunk_size:
logger.warning(f"chunk_size 已从 {chunk_size} 调整为 {self.chunk_size} (范围: {MIN_CHUNK_SIZE}-{MAX_CHUNK_SIZE})")
logger.warning(
f"chunk_size 已从 {chunk_size} 调整为 {self.chunk_size} (范围: {MIN_CHUNK_SIZE}-{MAX_CHUNK_SIZE})"
)
self.store = {}
@@ -120,23 +130,23 @@ class EmbeddingStore:
# 创建新的事件循环并在完成后立即关闭
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
# 创建新的LLMRequest实例
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config
llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding")
# 使用新的事件循环运行异步方法
embedding, _ = loop.run_until_complete(llm.get_embedding(s))
if embedding and len(embedding) > 0:
return embedding
else:
logger.error(f"获取嵌入失败: {s}")
return []
except Exception as e:
logger.error(f"获取嵌入时发生异常: {s}, 错误: {e}")
return []
@@ -147,43 +157,45 @@ class EmbeddingStore:
except Exception:
pass
def _get_embeddings_batch_threaded(self, strs: List[str], chunk_size: int = 10, max_workers: int = 10, progress_callback=None) -> List[Tuple[str, List[float]]]:
def _get_embeddings_batch_threaded(
self, strs: List[str], chunk_size: int = 10, max_workers: int = 10, progress_callback=None
) -> List[Tuple[str, List[float]]]:
"""使用多线程批量获取嵌入向量
Args:
strs: 要获取嵌入的字符串列表
chunk_size: 每个线程处理的数据块大小
max_workers: 最大线程数
progress_callback: 进度回调函数,接收一个参数表示完成的数量
Returns:
包含(原始字符串, 嵌入向量)的元组列表,保持与输入顺序一致
"""
if not strs:
return []
# 分块
chunks = []
for i in range(0, len(strs), chunk_size):
chunk = strs[i:i + chunk_size]
chunk = strs[i : i + chunk_size]
chunks.append((i, chunk)) # 保存起始索引以维持顺序
# 结果存储,使用字典按索引存储以保证顺序
results = {}
def process_chunk(chunk_data):
"""处理单个数据块的函数"""
start_idx, chunk_strs = chunk_data
chunk_results = []
# 为每个线程创建独立的LLMRequest实例
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config
try:
# 创建线程专用的LLM实例
llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding")
for i, s in enumerate(chunk_strs):
try:
# 在线程中创建独立的事件循环
@@ -193,25 +205,25 @@ class EmbeddingStore:
embedding = loop.run_until_complete(llm.get_embedding(s))
finally:
loop.close()
if embedding and len(embedding) > 0:
chunk_results.append((start_idx + i, s, embedding[0])) # embedding[0] 是实际的向量
else:
logger.error(f"获取嵌入失败: {s}")
chunk_results.append((start_idx + i, s, []))
# 每完成一个嵌入立即更新进度
if progress_callback:
progress_callback(1)
except Exception as e:
logger.error(f"获取嵌入时发生异常: {s}, 错误: {e}")
chunk_results.append((start_idx + i, s, []))
# 即使失败也要更新进度
if progress_callback:
progress_callback(1)
except Exception as e:
logger.error(f"创建LLM实例失败: {e}")
# 如果创建LLM实例失败返回空结果
@@ -220,14 +232,14 @@ class EmbeddingStore:
# 即使失败也要更新进度
if progress_callback:
progress_callback(1)
return chunk_results
# 使用线程池处理
with ThreadPoolExecutor(max_workers=max_workers) as executor:
# 提交所有任务
future_to_chunk = {executor.submit(process_chunk, chunk): chunk for chunk in chunks}
# 收集结果进度已在process_chunk中实时更新
for future in as_completed(future_to_chunk):
try:
@@ -241,7 +253,7 @@ class EmbeddingStore:
start_idx, chunk_strs = chunk
for i, s in enumerate(chunk_strs):
results[start_idx + i] = (s, [])
# 按原始顺序返回结果
ordered_results = []
for i in range(len(strs)):
@@ -250,7 +262,7 @@ class EmbeddingStore:
else:
# 防止遗漏
ordered_results.append((strs[i], []))
return ordered_results
def get_test_file_path(self):
@@ -259,14 +271,14 @@ class EmbeddingStore:
def save_embedding_test_vectors(self):
"""保存测试字符串的嵌入到本地(使用多线程优化)"""
logger.info("开始保存测试字符串的嵌入向量...")
# 使用多线程批量获取测试字符串的嵌入
embedding_results = self._get_embeddings_batch_threaded(
EMBEDDING_TEST_STRINGS,
chunk_size=min(self.chunk_size, len(EMBEDDING_TEST_STRINGS)),
max_workers=min(self.max_workers, len(EMBEDDING_TEST_STRINGS))
max_workers=min(self.max_workers, len(EMBEDDING_TEST_STRINGS)),
)
# 构建测试向量字典
test_vectors = {}
for idx, (s, embedding) in enumerate(embedding_results):
@@ -276,10 +288,10 @@ class EmbeddingStore:
logger.error(f"获取测试字符串嵌入失败: {s}")
# 使用原始单线程方法作为后备
test_vectors[str(idx)] = self._get_embedding(s)
with open(self.get_test_file_path(), "w", encoding="utf-8") as f:
json.dump(test_vectors, f, ensure_ascii=False, indent=2)
logger.info("测试字符串嵌入向量保存完成")
def load_embedding_test_vectors(self):
@@ -297,35 +309,35 @@ class EmbeddingStore:
logger.warning("未检测到本地嵌入模型测试文件,将保存当前模型的测试嵌入。")
self.save_embedding_test_vectors()
return True
# 检查本地向量完整性
for idx in range(len(EMBEDDING_TEST_STRINGS)):
if local_vectors.get(str(idx)) is None:
logger.warning("本地嵌入模型测试文件缺失部分测试字符串,将重新保存。")
self.save_embedding_test_vectors()
return True
logger.info("开始检验嵌入模型一致性...")
# 使用多线程批量获取当前模型的嵌入
embedding_results = self._get_embeddings_batch_threaded(
EMBEDDING_TEST_STRINGS,
chunk_size=min(self.chunk_size, len(EMBEDDING_TEST_STRINGS)),
max_workers=min(self.max_workers, len(EMBEDDING_TEST_STRINGS))
max_workers=min(self.max_workers, len(EMBEDDING_TEST_STRINGS)),
)
# 检查一致性
for idx, (s, new_emb) in enumerate(embedding_results):
local_emb = local_vectors.get(str(idx))
if not new_emb:
logger.error(f"获取测试字符串嵌入失败: {s}")
return False
sim = cosine_similarity(local_emb, new_emb)
if sim < EMBEDDING_SIM_THRESHOLD:
logger.error(f"嵌入模型一致性校验失败,字符串: {s}, 相似度: {sim:.4f}")
return False
logger.info("嵌入模型一致性校验通过。")
return True
@@ -333,22 +345,22 @@ class EmbeddingStore:
"""向库中存入字符串(使用多线程优化)"""
if not strs:
return
total = len(strs)
# 过滤已存在的字符串
new_strs = []
for s in strs:
item_hash = self.namespace + "-" + get_sha256(s)
if item_hash not in self.store:
new_strs.append(s)
if not new_strs:
logger.info(f"所有字符串已存在于{self.namespace}嵌入库中,跳过处理")
return
logger.info(f"需要处理 {len(new_strs)}/{total} 个新字符串")
with Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
@@ -362,31 +374,39 @@ class EmbeddingStore:
transient=False,
) as progress:
task = progress.add_task(f"存入嵌入库:({times}/{TOTAL_EMBEDDING_TIMES})", total=total)
# 首先更新已存在项的进度
already_processed = total - len(new_strs)
if already_processed > 0:
progress.update(task, advance=already_processed)
if new_strs:
# 使用实例配置的参数,智能调整分块和线程数
optimal_chunk_size = max(MIN_CHUNK_SIZE, min(self.chunk_size, len(new_strs) // self.max_workers if self.max_workers > 0 else self.chunk_size))
optimal_max_workers = min(self.max_workers, max(MIN_WORKERS, len(new_strs) // optimal_chunk_size if optimal_chunk_size > 0 else 1))
optimal_chunk_size = max(
MIN_CHUNK_SIZE,
min(
self.chunk_size, len(new_strs) // self.max_workers if self.max_workers > 0 else self.chunk_size
),
)
optimal_max_workers = min(
self.max_workers,
max(MIN_WORKERS, len(new_strs) // optimal_chunk_size if optimal_chunk_size > 0 else 1),
)
logger.debug(f"使用多线程处理: chunk_size={optimal_chunk_size}, max_workers={optimal_max_workers}")
# 定义进度更新回调函数
def update_progress(count):
progress.update(task, advance=count)
# 批量获取嵌入,并实时更新进度
embedding_results = self._get_embeddings_batch_threaded(
new_strs,
chunk_size=optimal_chunk_size,
new_strs,
chunk_size=optimal_chunk_size,
max_workers=optimal_max_workers,
progress_callback=update_progress
progress_callback=update_progress,
)
# 存入结果(不再需要在这里更新进度,因为已经在回调中更新了)
for s, embedding in embedding_results:
item_hash = self.namespace + "-" + get_sha256(s)
@@ -519,7 +539,7 @@ class EmbeddingManager:
def __init__(self, max_workers: int = DEFAULT_MAX_WORKERS, chunk_size: int = DEFAULT_CHUNK_SIZE):
"""
初始化EmbeddingManager
Args:
max_workers: 最大线程数
chunk_size: 每个线程处理的数据块大小

View File

@@ -426,9 +426,7 @@ class KGManager:
# 获取最终结果
# 从搜索结果中提取文段节点的结果
passage_node_res = [
(node_key, score)
for node_key, score in ppr_res.items()
if node_key.startswith("paragraph")
(node_key, score) for node_key, score in ppr_res.items() if node_key.startswith("paragraph")
]
del ppr_res

View File

@@ -1,8 +1,8 @@
raise DeprecationWarning("MemoryActiveManager is not used yet, please do not import it")
from .lpmmconfig import global_config
from .embedding_store import EmbeddingManager
from .llm_client import LLMClient
from .utils.dyn_topk import dyn_select_top_k
from .lpmmconfig import global_config # noqa
from .embedding_store import EmbeddingManager # noqa
from .llm_client import LLMClient # noqa
from .utils.dyn_topk import dyn_select_top_k # noqa
class MemoryActiveManager:

View File

@@ -8,7 +8,7 @@ def dyn_select_top_k(
# 检查输入列表是否为空
if not score:
return []
# 按照分数排序(降序)
sorted_score = sorted(score, key=lambda x: x[1], reverse=True)

View File

@@ -18,7 +18,7 @@ class MessageStorage:
if isinstance(keywords, list):
return json.dumps(keywords, ensure_ascii=False)
return "[]"
@staticmethod
def _deserialize_keywords(keywords_str: str) -> list:
"""将JSON字符串反序列化为关键词列表"""
@@ -85,7 +85,7 @@ class MessageStorage:
key_words = MessageStorage._serialize_keywords(message.key_words)
key_words_lite = MessageStorage._serialize_keywords(message.key_words_lite)
selected_expressions = ""
chat_info_dict = chat_stream.to_dict()
user_info_dict = message.message_info.user_info.to_dict() # type: ignore

View File

@@ -124,4 +124,4 @@ class ActionManager:
"""恢复到默认动作集"""
actions_to_restore = list(self._using_actions.keys())
self._using_actions = component_registry.get_default_actions()
logger.debug(f"恢复动作集: 从 {actions_to_restore} 恢复到默认动作集 {list(self._using_actions.keys())}")
logger.debug(f"恢复动作集: 从 {actions_to_restore} 恢复到默认动作集 {list(self._using_actions.keys())}")

View File

@@ -103,25 +103,23 @@ class ActionModifier:
self.action_manager.remove_action_from_using(action_name)
logger.debug(f"{self.log_prefix}阶段二移除动作: {action_name},原因: {reason}")
# === 第三阶段:激活类型判定 ===
# if chat_content is not None:
# logger.debug(f"{self.log_prefix}开始激活类型判定阶段")
# logger.debug(f"{self.log_prefix}开始激活类型判定阶段")
# 获取当前使用的动作集(经过第一阶段处理)
# current_using_actions = self.action_manager.get_using_actions()
# 获取当前使用的动作集(经过第一阶段处理)
# current_using_actions = self.action_manager.get_using_actions()
# 获取因激活类型判定而需要移除的动作
# removals_s3 = await self._get_deactivated_actions_by_type(
# current_using_actions,
# chat_content,
# )
# 获取因激活类型判定而需要移除的动作
# removals_s3 = await self._get_deactivated_actions_by_type(
# current_using_actions,
# chat_content,
# )
# 应用第三阶段的移除
# for action_name, reason in removals_s3:
# self.action_manager.remove_action_from_using(action_name)
# logger.debug(f"{self.log_prefix}阶段三移除动作: {action_name},原因: {reason}")
# 应用第三阶段的移除
# for action_name, reason in removals_s3:
# self.action_manager.remove_action_from_using(action_name)
# logger.debug(f"{self.log_prefix}阶段三移除动作: {action_name},原因: {reason}")
# === 统一日志记录 ===
all_removals = removals_s1 + removals_s2
@@ -131,9 +129,7 @@ class ActionModifier:
available_actions = list(self.action_manager.get_using_actions().keys())
available_actions_text = "".join(available_actions) if available_actions else ""
logger.debug(
f"{self.log_prefix} 当前可用动作: {available_actions_text}||移除: {removals_summary}"
)
logger.debug(f"{self.log_prefix} 当前可用动作: {available_actions_text}||移除: {removals_summary}")
def _check_action_associated_types(self, all_actions: Dict[str, ActionInfo], chat_context: ChatMessageContext):
type_mismatched_actions: List[Tuple[str, str]] = []

View File

@@ -385,18 +385,18 @@ class StatisticOutputTask(AsyncTask):
time_cost_key = f"time_costs_by_{category.split('_')[-1]}"
avg_key = f"avg_time_costs_by_{category.split('_')[-1]}"
std_key = f"std_time_costs_by_{category.split('_')[-1]}"
for item_name in stats[period_key][category]:
time_costs = stats[period_key][time_cost_key].get(item_name, [])
if time_costs:
# 计算平均耗时
avg_time_cost = sum(time_costs) / len(time_costs)
stats[period_key][avg_key][item_name] = round(avg_time_cost, 3)
# 计算标准差
if len(time_costs) > 1:
variance = sum((x - avg_time_cost) ** 2 for x in time_costs) / len(time_costs)
std_time_cost = variance ** 0.5
std_time_cost = variance**0.5
stats[period_key][std_key][item_name] = round(std_time_cost, 3)
else:
stats[period_key][std_key][item_name] = 0.0
@@ -506,8 +506,6 @@ class StatisticOutputTask(AsyncTask):
break
return stats
def _collect_all_statistics(self, now: datetime) -> Dict[str, Dict[str, Any]]:
"""
收集各时间段的统计数据
@@ -639,7 +637,9 @@ class StatisticOutputTask(AsyncTask):
cost = stats[COST_BY_MODEL][model_name]
avg_time_cost = stats[AVG_TIME_COST_BY_MODEL][model_name]
std_time_cost = stats[STD_TIME_COST_BY_MODEL][model_name]
output.append(data_fmt.format(name, count, in_tokens, out_tokens, tokens, cost, avg_time_cost, std_time_cost))
output.append(
data_fmt.format(name, count, in_tokens, out_tokens, tokens, cost, avg_time_cost, std_time_cost)
)
output.append("")
return "\n".join(output)
@@ -728,7 +728,9 @@ class StatisticOutputTask(AsyncTask):
f"<td>{stat_data[STD_TIME_COST_BY_MODEL][model_name]:.1f} 秒</td>"
f"</tr>"
for model_name, count in sorted(stat_data[REQ_CNT_BY_MODEL].items())
] if stat_data[REQ_CNT_BY_MODEL] else ["<tr><td colspan='8' style='text-align: center; color: #999;'>暂无数据</td></tr>"]
]
if stat_data[REQ_CNT_BY_MODEL]
else ["<tr><td colspan='8' style='text-align: center; color: #999;'>暂无数据</td></tr>"]
)
# 按请求类型分类统计
type_rows = "\n".join(
@@ -744,7 +746,9 @@ class StatisticOutputTask(AsyncTask):
f"<td>{stat_data[STD_TIME_COST_BY_TYPE][req_type]:.1f} 秒</td>"
f"</tr>"
for req_type, count in sorted(stat_data[REQ_CNT_BY_TYPE].items())
] if stat_data[REQ_CNT_BY_TYPE] else ["<tr><td colspan='8' style='text-align: center; color: #999;'>暂无数据</td></tr>"]
]
if stat_data[REQ_CNT_BY_TYPE]
else ["<tr><td colspan='8' style='text-align: center; color: #999;'>暂无数据</td></tr>"]
)
# 按模块分类统计
module_rows = "\n".join(
@@ -760,7 +764,9 @@ class StatisticOutputTask(AsyncTask):
f"<td>{stat_data[STD_TIME_COST_BY_MODULE][module_name]:.1f} 秒</td>"
f"</tr>"
for module_name, count in sorted(stat_data[REQ_CNT_BY_MODULE].items())
] if stat_data[REQ_CNT_BY_MODULE] else ["<tr><td colspan='8' style='text-align: center; color: #999;'>暂无数据</td></tr>"]
]
if stat_data[REQ_CNT_BY_MODULE]
else ["<tr><td colspan='8' style='text-align: center; color: #999;'>暂无数据</td></tr>"]
)
# 聊天消息统计
@@ -768,7 +774,9 @@ class StatisticOutputTask(AsyncTask):
[
f"<tr><td>{self.name_mapping[chat_id][0]}</td><td>{count}</td></tr>"
for chat_id, count in sorted(stat_data[MSG_CNT_BY_CHAT].items())
] if stat_data[MSG_CNT_BY_CHAT] else ["<tr><td colspan='2' style='text-align: center; color: #999;'>暂无数据</td></tr>"]
]
if stat_data[MSG_CNT_BY_CHAT]
else ["<tr><td colspan='2' style='text-align: center; color: #999;'>暂无数据</td></tr>"]
)
# 生成HTML
return f"""

View File

@@ -49,9 +49,9 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, bool, float
reply_probability = 0.0
is_at = False
is_mentioned = False
# 这部分怎么处理啊啊啊啊
#我觉得可以给消息加一个 reply_probability_boost字段
# 我觉得可以给消息加一个 reply_probability_boost字段
if (
message.message_info.additional_config is not None
and message.message_info.additional_config.get("is_mentioned") is not None
@@ -826,20 +826,48 @@ def parse_keywords_string(keywords_input) -> list[str]:
return [keywords_str] if keywords_str else []
def cut_key_words(concept_name: str) -> list[str]:
"""对概念名称进行jieba分词并过滤掉关键词列表中的关键词"""
concept_name_tokens = list(jieba.cut(concept_name))
# 定义常见连词、停用词与标点
conjunctions = {
"", "", "", "", "以及", "并且", "而且", "", "或者", ""
}
conjunctions = {"", "", "", "", "以及", "并且", "而且", "", "或者", ""}
stop_words = {
"", "", "", "", "", "", "", "", "", "", "", "",
"", "", "", "", "", "", "", "", "", "", "", "",
"", "", "", "", "", "", "", "而且", "或者", "", "以及"
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"而且",
"或者",
"",
"以及",
}
chinese_punctuations = set(",。!?、;:()【】《》“”‘’—…·-——,.!?;:()[]<>'\"/\\")
@@ -864,11 +892,16 @@ def cut_key_words(concept_name: str) -> list[str]:
left = merged_tokens[-1]
right = cleaned_tokens[i + 1]
# 左右都需要是有效词
if left and right \
and left not in conjunctions and right not in conjunctions \
and left not in stop_words and right not in stop_words \
and not all(ch in chinese_punctuations for ch in left) \
and not all(ch in chinese_punctuations for ch in right):
if (
left
and right
and left not in conjunctions
and right not in conjunctions
and left not in stop_words
and right not in stop_words
and not all(ch in chinese_punctuations for ch in left)
and not all(ch in chinese_punctuations for ch in right)
):
# 合并为一个新词,并替换掉左侧与跳过右侧
combined = f"{left}{tok}{right}"
merged_tokens[-1] = combined
@@ -889,7 +922,7 @@ def cut_key_words(concept_name: str) -> list[str]:
if tok in stop_words:
continue
# if tok in ban_words:
# continue
# continue
if all(ch in chinese_punctuations for ch in tok):
continue
if tok.strip() == "":
@@ -899,4 +932,4 @@ def cut_key_words(concept_name: str) -> list[str]:
result_tokens.append(tok)
filtered_concept_name_tokens = result_tokens
return filtered_concept_name_tokens
return filtered_concept_name_tokens

View File

@@ -91,9 +91,10 @@ class ImageManager:
desc_obj.save()
except Exception as e:
logger.error(f"保存描述到数据库失败 (Peewee): {str(e)}")
async def get_emoji_tag(self, image_base64: str) -> str:
from src.chat.emoji_system.emoji_manager import get_emoji_manager
emoji_manager = get_emoji_manager()
if isinstance(image_base64, str):
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
@@ -120,6 +121,7 @@ class ImageManager:
# 优先使用EmojiManager查询已注册表情包的描述
try:
from src.chat.emoji_system.emoji_manager import get_emoji_manager
emoji_manager = get_emoji_manager()
tags = await emoji_manager.get_emoji_tag_by_hash(image_hash)
if tags:

View File

@@ -6,6 +6,7 @@ class BaseDataModel:
def deepcopy(self):
return copy.deepcopy(self)
def temporarily_transform_class_to_dict(obj: Any) -> Any:
# sourcery skip: assign-if-exp, reintroduce-else
"""

View File

@@ -205,6 +205,7 @@ class DatabaseMessages(BaseDataModel):
"chat_info_user_cardname": self.chat_info.user_info.user_cardname,
}
@dataclass(init=False)
class DatabaseActionRecords(BaseDataModel):
def __init__(
@@ -232,4 +233,4 @@ class DatabaseActionRecords(BaseDataModel):
self.action_prompt_display = action_prompt_display
self.chat_id = chat_id
self.chat_info_stream_id = chat_info_stream_id
self.chat_info_platform = chat_info_platform
self.chat_info_platform = chat_info_platform

View File

@@ -2,9 +2,11 @@ from dataclasses import dataclass
from typing import Optional, List, Tuple, TYPE_CHECKING, Any
from . import BaseDataModel
if TYPE_CHECKING:
from src.llm_models.payload_content.tool_option import ToolCall
@dataclass
class LLMGenerationDataModel(BaseDataModel):
content: Optional[str] = None
@@ -13,4 +15,4 @@ class LLMGenerationDataModel(BaseDataModel):
tool_calls: Optional[List["ToolCall"]] = None
prompt: Optional[str] = None
selected_expressions: Optional[List[int]] = None
reply_set: Optional[List[Tuple[str, Any]]] = None
reply_set: Optional[List[Tuple[str, Any]]] = None

View File

@@ -135,7 +135,7 @@ class Messages(BaseModel):
interest_value = DoubleField(null=True)
key_words = TextField(null=True)
key_words_lite = TextField(null=True)
is_mentioned = BooleanField(null=True)
is_at = BooleanField(null=True)
reply_probability_boost = DoubleField(null=True)
@@ -169,7 +169,7 @@ class Messages(BaseModel):
is_picid = BooleanField(default=False)
is_command = BooleanField(default=False)
is_notify = BooleanField(default=False)
selected_expressions = TextField(null=True)
class Meta:
@@ -267,13 +267,10 @@ class PersonInfo(BaseModel):
know_times = FloatField(null=True) # 认识时间 (时间戳)
know_since = FloatField(null=True) # 首次印象总结时间
last_know = FloatField(null=True) # 最后一次印象总结时间
attitude_to_me = TextField(null=True) # 对bot的态度
attitude_to_me_confidence = FloatField(null=True) # 对bot的态度置信度
class Meta:
# database = db # 继承自 BaseModel
table_name = "person_info"
@@ -299,6 +296,7 @@ class GroupInfo(BaseModel):
# database = db # 继承自 BaseModel
table_name = "group_info"
class Expression(BaseModel):
"""
用于存储表达风格的模型。
@@ -315,6 +313,7 @@ class Expression(BaseModel):
class Meta:
table_name = "expression"
class GraphNodes(BaseModel):
"""
用于存储记忆图节点的模型
@@ -374,7 +373,7 @@ def initialize_database(sync_constraints=False):
"""
检查所有定义的表是否存在,如果不存在则创建它们。
检查所有表的所有字段是否存在,如果缺失则自动添加。
Args:
sync_constraints (bool): 是否同步字段约束。默认为 False。
如果为 True会检查并修复字段的 NULL 约束不一致问题。
@@ -456,13 +455,13 @@ def initialize_database(sync_constraints=False):
logger.info(f"字段 '{field_name}' 删除成功")
except Exception as e:
logger.error(f"删除字段 '{field_name}' 失败: {e}")
# 如果启用了约束同步,执行约束检查和修复
if sync_constraints:
logger.debug("开始同步数据库字段约束...")
sync_field_constraints()
logger.debug("数据库字段约束同步完成")
except Exception as e:
logger.exception(f"检查表或字段是否存在时出错: {e}")
# 如果检查失败(例如数据库不可用),则退出
@@ -476,7 +475,7 @@ def sync_field_constraints():
同步数据库字段约束,确保现有数据库字段的 NULL 约束与模型定义一致。
如果发现不一致,会自动修复字段约束。
"""
models = [
ChatStreams,
LLMUsage,
@@ -501,50 +500,55 @@ def sync_field_constraints():
continue
logger.debug(f"检查表 '{table_name}' 的字段约束...")
# 获取当前表结构信息
cursor = db.execute_sql(f"PRAGMA table_info('{table_name}')")
current_schema = {row[1]: {'type': row[2], 'notnull': bool(row[3]), 'default': row[4]}
for row in cursor.fetchall()}
current_schema = {
row[1]: {"type": row[2], "notnull": bool(row[3]), "default": row[4]} for row in cursor.fetchall()
}
# 检查每个模型字段的约束
constraints_to_fix = []
for field_name, field_obj in model._meta.fields.items():
if field_name not in current_schema:
continue # 字段不存在,跳过
current_notnull = current_schema[field_name]['notnull']
current_notnull = current_schema[field_name]["notnull"]
model_allows_null = field_obj.null
# 如果模型允许 null 但数据库字段不允许 null需要修复
if model_allows_null and current_notnull:
constraints_to_fix.append({
'field_name': field_name,
'field_obj': field_obj,
'action': 'allow_null',
'current_constraint': 'NOT NULL',
'target_constraint': 'NULL'
})
constraints_to_fix.append(
{
"field_name": field_name,
"field_obj": field_obj,
"action": "allow_null",
"current_constraint": "NOT NULL",
"target_constraint": "NULL",
}
)
logger.warning(f"字段 '{field_name}' 约束不一致: 模型允许NULL但数据库为NOT NULL")
# 如果模型不允许 null 但数据库字段允许 null也需要修复但要小心
elif not model_allows_null and not current_notnull:
constraints_to_fix.append({
'field_name': field_name,
'field_obj': field_obj,
'action': 'disallow_null',
'current_constraint': 'NULL',
'target_constraint': 'NOT NULL'
})
constraints_to_fix.append(
{
"field_name": field_name,
"field_obj": field_obj,
"action": "disallow_null",
"current_constraint": "NULL",
"target_constraint": "NOT NULL",
}
)
logger.warning(f"字段 '{field_name}' 约束不一致: 模型不允许NULL但数据库允许NULL")
# 修复约束不一致的字段
if constraints_to_fix:
logger.info(f"'{table_name}' 需要修复 {len(constraints_to_fix)} 个字段约束")
_fix_table_constraints(table_name, model, constraints_to_fix)
else:
logger.debug(f"'{table_name}' 的字段约束已同步")
except Exception as e:
logger.exception(f"同步字段约束时出错: {e}")
@@ -557,40 +561,39 @@ def _fix_table_constraints(table_name, model, constraints_to_fix):
try:
# 备份表名
backup_table = f"{table_name}_backup_{int(datetime.datetime.now().timestamp())}"
logger.info(f"开始修复表 '{table_name}' 的字段约束...")
# 1. 创建备份表
db.execute_sql(f"CREATE TABLE {backup_table} AS SELECT * FROM {table_name}")
logger.info(f"已创建备份表 '{backup_table}'")
# 2. 删除原表
db.execute_sql(f"DROP TABLE {table_name}")
logger.info(f"已删除原表 '{table_name}'")
# 3. 重新创建表(使用当前模型定义)
db.create_tables([model])
logger.info(f"已重新创建表 '{table_name}' 使用新的约束")
# 4. 从备份表恢复数据
# 获取字段列表
fields = list(model._meta.fields.keys())
fields_str = ', '.join(fields)
fields_str = ", ".join(fields)
# 对于需要从 NOT NULL 改为 NULL 的字段,直接复制数据
# 对于需要从 NULL 改为 NOT NULL 的字段,需要处理 NULL 值
insert_sql = f"INSERT INTO {table_name} ({fields_str}) SELECT {fields_str} FROM {backup_table}"
# 检查是否有字段需要从 NULL 改为 NOT NULL
null_to_notnull_fields = [
constraint['field_name'] for constraint in constraints_to_fix
if constraint['action'] == 'disallow_null'
constraint["field_name"] for constraint in constraints_to_fix if constraint["action"] == "disallow_null"
]
if null_to_notnull_fields:
# 需要处理 NULL 值,为这些字段设置默认值
logger.warning(f"字段 {null_to_notnull_fields} 将从允许NULL改为不允许NULL需要处理现有的NULL值")
# 构建更复杂的 SELECT 语句来处理 NULL 值
select_fields = []
for field_name in fields:
@@ -607,21 +610,21 @@ def _fix_table_constraints(table_name, model, constraints_to_fix):
default_value = f"'{datetime.datetime.now()}'"
else:
default_value = "''"
select_fields.append(f"COALESCE({field_name}, {default_value}) as {field_name}")
else:
select_fields.append(field_name)
select_str = ', '.join(select_fields)
select_str = ", ".join(select_fields)
insert_sql = f"INSERT INTO {table_name} ({fields_str}) SELECT {select_str} FROM {backup_table}"
db.execute_sql(insert_sql)
logger.info(f"已从备份表恢复数据到 '{table_name}'")
# 5. 验证数据完整性
original_count = db.execute_sql(f"SELECT COUNT(*) FROM {backup_table}").fetchone()[0]
new_count = db.execute_sql(f"SELECT COUNT(*) FROM {table_name}").fetchone()[0]
if original_count == new_count:
logger.info(f"数据完整性验证通过: {original_count} 行数据")
# 删除备份表
@@ -630,12 +633,14 @@ def _fix_table_constraints(table_name, model, constraints_to_fix):
else:
logger.error(f"数据完整性验证失败: 原始 {original_count} 行,新表 {new_count}")
logger.error(f"备份表 '{backup_table}' 已保留,请手动检查")
# 记录修复的约束
for constraint in constraints_to_fix:
logger.info(f"已修复字段 '{constraint['field_name']}': "
f"{constraint['current_constraint']} -> {constraint['target_constraint']}")
logger.info(
f"已修复字段 '{constraint['field_name']}': "
f"{constraint['current_constraint']} -> {constraint['target_constraint']}"
)
except Exception as e:
logger.exception(f"修复表 '{table_name}' 约束时出错: {e}")
# 尝试恢复
@@ -654,7 +659,7 @@ def check_field_constraints():
检查但不修复字段约束,返回不一致的字段信息。
用于在修复前预览需要修复的内容。
"""
models = [
ChatStreams,
LLMUsage,
@@ -669,9 +674,9 @@ def check_field_constraints():
GraphEdges,
ActionRecords,
]
inconsistencies = {}
try:
with db:
for model in models:
@@ -681,49 +686,49 @@ def check_field_constraints():
# 获取当前表结构信息
cursor = db.execute_sql(f"PRAGMA table_info('{table_name}')")
current_schema = {row[1]: {'type': row[2], 'notnull': bool(row[3]), 'default': row[4]}
for row in cursor.fetchall()}
current_schema = {
row[1]: {"type": row[2], "notnull": bool(row[3]), "default": row[4]} for row in cursor.fetchall()
}
table_inconsistencies = []
# 检查每个模型字段的约束
for field_name, field_obj in model._meta.fields.items():
if field_name not in current_schema:
continue
current_notnull = current_schema[field_name]['notnull']
current_notnull = current_schema[field_name]["notnull"]
model_allows_null = field_obj.null
if model_allows_null and current_notnull:
table_inconsistencies.append({
'field_name': field_name,
'issue': 'model_allows_null_but_db_not_null',
'model_constraint': 'NULL',
'db_constraint': 'NOT NULL',
'recommended_action': 'allow_null'
})
table_inconsistencies.append(
{
"field_name": field_name,
"issue": "model_allows_null_but_db_not_null",
"model_constraint": "NULL",
"db_constraint": "NOT NULL",
"recommended_action": "allow_null",
}
)
elif not model_allows_null and not current_notnull:
table_inconsistencies.append({
'field_name': field_name,
'issue': 'model_not_null_but_db_allows_null',
'model_constraint': 'NOT NULL',
'db_constraint': 'NULL',
'recommended_action': 'disallow_null'
})
table_inconsistencies.append(
{
"field_name": field_name,
"issue": "model_not_null_but_db_allows_null",
"model_constraint": "NOT NULL",
"db_constraint": "NULL",
"recommended_action": "disallow_null",
}
)
if table_inconsistencies:
inconsistencies[table_name] = table_inconsistencies
except Exception as e:
logger.exception(f"检查字段约束时出错: {e}")
return inconsistencies
return inconsistencies
# 模块加载时调用初始化函数
initialize_database(sync_constraints=True)

View File

@@ -339,24 +339,18 @@ MODULE_COLORS = {
# 67 具体的颜色编号0-255这里是较暗的蓝色
"sender": "\033[38;5;24m", # color 24: a darker blue, suited to unobtrusive logs
"send_api": "\033[38;5;24m", # color 24: same darker blue as "sender"
# 生成
"replyer": "\033[38;5;208m", # 橙色
"llm_api": "\033[38;5;208m", # 橙色
# 消息处理
"chat": "\033[38;5;82m", # bright green (xterm color 82)
"chat_image": "\033[38;5;68m", # 浅蓝色
#emoji
# emoji
"emoji": "\033[38;5;214m", # 橙黄色,偏向橙色
"emoji_api": "\033[38;5;214m", # 橙黄色,偏向橙色
# 核心模块
"main": "\033[1;97m", # 亮白色+粗体 (主程序)
"memory": "\033[38;5;34m", # green (xterm color 34)
"config": "\033[93m", # 亮黄色
"common": "\033[95m", # 亮紫色
"tools": "\033[96m", # 亮青色
@@ -367,9 +361,6 @@ MODULE_COLORS = {
"llm_models": "\033[36m", # 青色
"remote": "\033[38;5;242m", # 深灰色,更不显眼
"planner": "\033[36m",
"relation": "\033[38;5;139m", # 柔和的紫色,不刺眼
# 聊天相关模块
"normal_chat": "\033[38;5;81m", # 亮蓝绿色
@@ -379,11 +370,9 @@ MODULE_COLORS = {
"background_tasks": "\033[38;5;240m", # 灰色
"chat_message": "\033[38;5;45m", # 青色
"chat_stream": "\033[38;5;51m", # 亮青色
"message_storage": "\033[38;5;33m", # 深蓝色
"expressor": "\033[38;5;166m", # 橙色
# 专注聊天模块
"memory_activator": "\033[38;5;117m", # 天蓝色
# 插件系统
"plugins": "\033[31m", # 红色
@@ -412,7 +401,6 @@ MODULE_COLORS = {
# 工具和实用模块
"prompt_build": "\033[38;5;105m", # 紫色
"chat_utils": "\033[38;5;111m", # 蓝色
"maibot_statistic": "\033[38;5;129m", # 紫色
# 特殊功能插件
"mute_plugin": "\033[38;5;240m", # 灰色
@@ -447,10 +435,8 @@ MODULE_ALIASES = {
"llm_api": "生成API",
"emoji": "表情包",
"emoji_api": "表情包API",
"chat": "所见",
"chat_image": "识图",
"action_manager": "动作",
"memory_activator": "记忆",
"tool_use": "工具",
@@ -460,7 +446,6 @@ MODULE_ALIASES = {
"memory": "记忆",
"tool_executor": "工具",
"hfc": "聊天节奏",
"plugin_manager": "插件",
"relationship_builder": "关系",
"llm_models": "模型",

View File

@@ -114,7 +114,7 @@ def set_value_by_path(d, path, value):
if k not in d or not isinstance(d[k], dict):
d[k] = {}
d = d[k]
# 使用 tomlkit.item 来保持 TOML 格式
try:
d[path[-1]] = tomlkit.item(value)
@@ -253,7 +253,7 @@ def _update_config_generic(config_name: str, template_name: str):
f"已自动将{config_name}配置 {'.'.join(path)} 的值从旧默认值 {old_default} 更新为新默认值 {new_default}"
)
config_updated = True
# 如果配置有更新,立即保存到文件
if config_updated:
with open(old_config_path, "w", encoding="utf-8") as f:

View File

@@ -43,10 +43,11 @@ class PersonalityConfig(ConfigBase):
reply_style: str = ""
"""表达风格"""
interest: str = ""
"""兴趣"""
@dataclass
class RelationshipConfig(ConfigBase):
"""关系配置类"""
@@ -61,31 +62,30 @@ class ChatConfig(ConfigBase):
max_context_size: int = 18
"""上下文长度"""
interest_rate_mode: Literal["fast", "accurate"] = "fast"
"""兴趣值计算模式fast为快速计算accurate为精确计算"""
mentioned_bot_reply: float = 1
"""提及 bot 必然回复1为100%回复0为不额外增幅"""
planner_size: float = 1.5
"""副规划器大小越小麦麦的动作执行能力越精细但是消耗更多token调大可以缓解429类错误"""
at_bot_inevitable_reply: float = 1
"""@bot 必然回复1为100%回复0为不额外增幅"""
talk_frequency: float = 0.5
"""回复频率阈值"""
# 合并后的时段频率配置
talk_frequency_adjust: list[list[str]] = field(default_factory=lambda: [])
focus_value: float = 0.5
"""麦麦的专注思考能力越低越容易专注消耗token也越多"""
focus_value_adjust: list[list[str]] = field(default_factory=lambda: [])
"""
统一的活跃度和专注度配置
格式:[["platform:chat_id:type", "HH:MM,frequency", "HH:MM,frequency", ...], ...]
@@ -110,7 +110,6 @@ class ChatConfig(ConfigBase):
- talk_frequency_adjust 控制回复频率,数值越高回复越频繁
- focus_value_adjust 控制专注思考能力数值越低越容易专注消耗token也越多
"""
@dataclass
@@ -123,6 +122,7 @@ class MessageReceiveConfig(ConfigBase):
ban_msgs_regex: set[str] = field(default_factory=lambda: set())
"""过滤正则表达式列表"""
@dataclass
class ExpressionConfig(ConfigBase):
"""表达配置类"""

View File

@@ -174,7 +174,7 @@ class ClientRegistry:
return client_class(api_provider)
else:
raise KeyError(f"'{api_provider.client_type}' 类型的 Client 未注册")
# 正常的缓存逻辑
if api_provider.name not in self.client_instance_cache:
if client_class := self.client_registry.get(api_provider.client_type):

View File

@@ -531,7 +531,7 @@ class OpenaiClient(BaseClient):
# 添加详细的错误信息以便调试
logger.error(f"OpenAI API连接错误嵌入模型: {str(e)}")
logger.error(f"错误类型: {type(e)}")
if hasattr(e, '__cause__') and e.__cause__:
if hasattr(e, "__cause__") and e.__cause__:
logger.error(f"底层错误: {str(e.__cause__)}")
raise NetworkConnectionError() from e
except APIStatusError as e:

View File

@@ -1,3 +1,3 @@
from .tool_option import ToolCall
__all__ = ["ToolCall"]
__all__ = ["ToolCall"]

View File

@@ -48,8 +48,7 @@ def _json_schema_type_check(instance) -> str | None:
elif not isinstance(instance["name"], str) or instance["name"].strip() == "":
return "schema的'name'字段必须是非空字符串"
if "description" in instance and (
not isinstance(instance["description"], str)
or instance["description"].strip() == ""
not isinstance(instance["description"], str) or instance["description"].strip() == ""
):
return "schema的'description'字段只能填入非空字符串"
if "schema" not in instance:
@@ -101,9 +100,7 @@ def _link_definitions(schema: dict[str, Any]) -> dict[str, Any]:
# 如果当前Schema是列表则遍历每个元素
for i in range(len(sub_schema)):
if isinstance(sub_schema[i], dict):
sub_schema[i] = link_definitions_recursive(
f"{path}/{str(i)}", sub_schema[i], defs
)
sub_schema[i] = link_definitions_recursive(f"{path}/{str(i)}", sub_schema[i], defs)
else:
# 否则为字典
if "$defs" in sub_schema:
@@ -125,9 +122,7 @@ def _link_definitions(schema: dict[str, Any]) -> dict[str, Any]:
for key, value in sub_schema.items():
if isinstance(value, (dict, list)):
# 如果当前值是字典或列表,则递归调用
sub_schema[key] = link_definitions_recursive(
f"{path}/{key}", value, defs
)
sub_schema[key] = link_definitions_recursive(f"{path}/{key}", value, defs)
return sub_schema
@@ -163,9 +158,7 @@ class RespFormat:
def _generate_schema_from_model(schema):
json_schema = {
"name": schema.__name__,
"schema": _remove_defs(
_link_definitions(_remove_title(schema.model_json_schema()))
),
"schema": _remove_defs(_link_definitions(_remove_title(schema.model_json_schema()))),
"strict": False,
}
if schema.__doc__:

View File

@@ -155,7 +155,13 @@ class LLMUsageRecorder:
logger.error(f"创建 LLMUsage 表失败: {str(e)}")
def record_usage_to_database(
self, model_info: ModelInfo, model_usage: UsageRecord, user_id: str, request_type: str, endpoint: str, time_cost: float = 0.0
self,
model_info: ModelInfo,
model_usage: UsageRecord,
user_id: str,
request_type: str,
endpoint: str,
time_cost: float = 0.0,
):
input_cost = (model_usage.prompt_tokens / 1000000) * model_info.price_in
output_cost = (model_usage.completion_tokens / 1000000) * model_info.price_out
@@ -173,7 +179,7 @@ class LLMUsageRecorder:
completion_tokens=model_usage.completion_tokens or 0,
total_tokens=model_usage.total_tokens or 0,
cost=total_cost or 0.0,
time_cost = round(time_cost or 0.0, 3),
time_cost=round(time_cost or 0.0, 3),
status="success",
timestamp=datetime.now(), # Peewee 会处理 DateTimeField
)
@@ -186,4 +192,5 @@ class LLMUsageRecorder:
except Exception as e:
logger.error(f"记录token使用情况失败: {str(e)}")
llm_usage_recorder = LLMUsageRecorder()
llm_usage_recorder = LLMUsageRecorder()

View File

@@ -14,31 +14,31 @@ logger = get_logger("context_web")
class ContextMessage:
"""上下文消息类"""
def __init__(self, message: MessageRecv):
    """Snapshot the display-relevant fields of an incoming message.

    Copies the sender nickname/id and the processed plain text, stamps the
    entry with the current local time, and rewrites ``content`` for gift and
    super-chat messages so the web view shows a compact summary.
    """
    info = message.message_info
    sender = info.user_info

    self.user_name = sender.user_nickname
    self.user_id = sender.user_id
    self.content = message.processed_plain_text
    self.timestamp = datetime.now()

    # Messages without group info are labelled as private chats.
    if info.group_info:
        self.group_name = info.group_info.group_name
    else:
        self.group_name = "私聊"

    # Message kind flags; absent attributes are treated as False.
    self.is_gift = getattr(message, "is_gift", False)
    self.is_superchat = getattr(message, "is_superchat", False)

    if self.is_gift:
        # Gift: replace the text with "<gift name> x<count>".
        self.gift_name = getattr(message, "gift_name", "")
        self.gift_count = getattr(message, "gift_count", "1")
        self.content = f"送出了 {self.gift_name} x{self.gift_count}"
        return

    if self.is_superchat:
        # Super chat: prefix the price; fall back to the plain text when the
        # super chat carries no message of its own.
        self.superchat_price = getattr(message, "superchat_price", "0")
        self.superchat_message = getattr(message, "superchat_message_text", "")
        shown_text = self.superchat_message if self.superchat_message else self.content
        self.content = f"{self.superchat_price}] {shown_text}"
def to_dict(self):
return {
"user_name": self.user_name,
@@ -47,13 +47,13 @@ class ContextMessage:
"timestamp": self.timestamp.strftime("%m-%d %H:%M:%S"),
"group_name": self.group_name,
"is_gift": self.is_gift,
"is_superchat": self.is_superchat
"is_superchat": self.is_superchat,
}
class ContextWebManager:
"""上下文网页管理器"""
def __init__(self, max_messages: int = 10, port: int = 8765):
self.max_messages = max_messages
self.port = port
@@ -63,53 +63,53 @@ class ContextWebManager:
self.runner = None
self.site = None
self._server_starting = False # 添加启动标志防止并发
async def start_server(self):
"""启动web服务器"""
if self.site is not None:
logger.debug("Web服务器已经启动跳过重复启动")
return
if self._server_starting:
logger.debug("Web服务器正在启动中等待启动完成...")
# 等待启动完成
while self._server_starting and self.site is None:
await asyncio.sleep(0.1)
return
self._server_starting = True
try:
self.app = web.Application()
# 设置CORS
cors = aiohttp_cors.setup(self.app, defaults={
"*": aiohttp_cors.ResourceOptions(
allow_credentials=True,
expose_headers="*",
allow_headers="*",
allow_methods="*"
)
})
cors = aiohttp_cors.setup(
self.app,
defaults={
"*": aiohttp_cors.ResourceOptions(
allow_credentials=True, expose_headers="*", allow_headers="*", allow_methods="*"
)
},
)
# 添加路由
self.app.router.add_get('/', self.index_handler)
self.app.router.add_get('/ws', self.websocket_handler)
self.app.router.add_get('/api/contexts', self.get_contexts_handler)
self.app.router.add_get('/debug', self.debug_handler)
self.app.router.add_get("/", self.index_handler)
self.app.router.add_get("/ws", self.websocket_handler)
self.app.router.add_get("/api/contexts", self.get_contexts_handler)
self.app.router.add_get("/debug", self.debug_handler)
# 为所有路由添加CORS
for route in list(self.app.router.routes()):
cors.add(route)
self.runner = web.AppRunner(self.app)
await self.runner.setup()
self.site = web.TCPSite(self.runner, 'localhost', self.port)
self.site = web.TCPSite(self.runner, "localhost", self.port)
await self.site.start()
logger.info(f"🌐 上下文网页服务器启动成功在 http://localhost:{self.port}")
except Exception as e:
logger.error(f"❌ 启动Web服务器失败: {e}")
# 清理部分启动的资源
@@ -121,7 +121,7 @@ class ContextWebManager:
raise
finally:
self._server_starting = False
async def stop_server(self):
"""停止web服务器"""
if self.site:
@@ -132,10 +132,11 @@ class ContextWebManager:
self.runner = None
self.site = None
self._server_starting = False
async def index_handler(self, request):
"""主页处理器"""
html_content = '''
html_content = (
"""
<!DOCTYPE html>
<html>
<head>
@@ -286,7 +287,9 @@ class ContextWebManager:
function connectWebSocket() {
console.log('正在连接WebSocket...');
ws = new WebSocket('ws://localhost:''' + str(self.port) + '''/ws');
ws = new WebSocket('ws://localhost:"""
+ str(self.port)
+ """/ws');
ws.onopen = function() {
console.log('WebSocket连接已建立');
@@ -470,47 +473,48 @@ class ContextWebManager:
</script>
</body>
</html>
'''
return web.Response(text=html_content, content_type='text/html')
"""
)
return web.Response(text=html_content, content_type="text/html")
async def websocket_handler(self, request):
    """Serve a single WebSocket client until it disconnects.

    The socket is registered in ``self.websockets``, receives the current
    context snapshot once, and is then read until the peer closes or an
    ERROR frame arrives.

    Returns:
        The prepared ``web.WebSocketResponse``.
    """
    ws = web.WebSocketResponse()
    await ws.prepare(request)

    self.websockets.append(ws)
    logger.debug(f"WebSocket连接建立当前连接数: {len(self.websockets)}")

    try:
        # Push the initial snapshot so a fresh page is not empty.
        await self.send_contexts_to_websocket(ws)

        async for msg in ws:
            if msg.type == WSMsgType.ERROR:
                logger.error(f"WebSocket错误: {ws.exception()}")
                break
    finally:
        # Always unregister the socket: without this, an exception or a
        # cancellation while sending/reading would leave a stale entry in
        # self.websockets.
        if ws in self.websockets:
            self.websockets.remove(ws)
        logger.debug(f"WebSocket连接断开当前连接数: {len(self.websockets)}")

    return ws
async def get_contexts_handler(self, request):
    """Return the most recent context messages across all chats as JSON.

    Messages from every chat queue are merged, ordered by timestamp (oldest
    first), and trimmed to the last ``self.max_messages`` entries.
    """
    merged = [entry for queue in self.contexts.values() for entry in queue]

    # Stable sort by timestamp so the newest entries end up last.
    ordered = sorted(merged, key=lambda entry: entry.timestamp)

    contexts_data = []
    for entry in ordered[-self.max_messages :]:
        contexts_data.append(entry.to_dict())

    logger.debug(f"返回上下文数据共 {len(contexts_data)} 条消息")
    return web.json_response({"contexts": contexts_data})
async def debug_handler(self, request):
"""调试信息处理器"""
debug_info = {
@@ -519,7 +523,7 @@ class ContextWebManager:
"total_chats": len(self.contexts),
"total_messages": sum(len(contexts) for contexts in self.contexts.values()),
}
# 构建聊天详情HTML
chats_html = ""
for chat_id, contexts in self.contexts.items():
@@ -528,15 +532,15 @@ class ContextWebManager:
timestamp = msg.timestamp.strftime("%H:%M:%S")
content = msg.content[:50] + "..." if len(msg.content) > 50 else msg.content
messages_html += f'<div class="message">[{timestamp}] {msg.user_name}: {content}</div>'
chats_html += f'''
chats_html += f"""
<div class="chat">
<h3>聊天 {chat_id} ({len(contexts)} 条消息)</h3>
{messages_html}
</div>
'''
html_content = f'''
"""
html_content = f"""
<!DOCTYPE html>
<html>
<head>
@@ -578,74 +582,78 @@ class ContextWebManager:
</script>
</body>
</html>
'''
return web.Response(text=html_content, content_type='text/html')
"""
return web.Response(text=html_content, content_type="text/html")
async def add_message(self, chat_id: str, message: MessageRecv):
    """Record a new message for ``chat_id`` and push the update to clients.

    A bounded queue (``maxlen=self.max_messages``) is created for the chat on
    first use. After appending, the full per-chat contents are logged and the
    refreshed context is broadcast to every WebSocket connection.
    """
    if chat_id not in self.contexts:
        self.contexts[chat_id] = deque(maxlen=self.max_messages)
        logger.debug(f"为聊天 {chat_id} 创建新的上下文队列")

    entry = ContextMessage(message)
    self.contexts[chat_id].append(entry)

    # Total number of stored messages over all chats.
    total_messages = sum(map(len, self.contexts.values()))
    logger.info(
        f"✅ 添加消息到上下文 [总数: {total_messages}]: [{entry.group_name}] {entry.user_name}: {entry.content}"
    )

    # Debug aid: dump everything currently held, chat by chat.
    logger.info("📝 当前上下文中的所有消息:")
    for cid, queue in self.contexts.items():
        logger.info(f" 聊天 {cid}: {len(queue)} 条消息")
        for position, stored in enumerate(queue, start=1):
            clock = stored.timestamp.strftime("%H:%M:%S")
            logger.info(f" {position}. [{clock}] {stored.user_name}: {stored.content[:30]}...")

    # Notify all connected WebSocket clients.
    await self.broadcast_contexts()
async def send_contexts_to_websocket(self, ws: web.WebSocketResponse):
    """Send the current context snapshot to one WebSocket client.

    The payload is a JSON object ``{"contexts": [...]}`` holding the last
    ``self.max_messages`` messages over all chats, oldest first.
    """
    merged = [entry for queue in self.contexts.values() for entry in queue]

    # Stable sort by timestamp so the newest entries end up last.
    ordered = sorted(merged, key=lambda entry: entry.timestamp)

    recent = ordered[-self.max_messages :]
    payload = {"contexts": [entry.to_dict() for entry in recent]}

    # ensure_ascii=False keeps non-ASCII text readable in the payload.
    await ws.send_str(json.dumps(payload, ensure_ascii=False))
async def broadcast_contexts(self):
"""向所有WebSocket连接广播上下文更新"""
if not self.websockets:
logger.debug("没有WebSocket连接跳过广播")
return
all_context_msgs = []
for _chat_id, contexts in self.contexts.items():
all_context_msgs.extend(list(contexts))
# 按时间排序,最新的在最后
all_context_msgs.sort(key=lambda x: x.timestamp)
# 转换为字典格式
contexts_data = [msg.to_dict() for msg in all_context_msgs[-self.max_messages:]]
contexts_data = [msg.to_dict() for msg in all_context_msgs[-self.max_messages :]]
data = {"contexts": contexts_data}
message = json.dumps(data, ensure_ascii=False)
logger.info(f"广播 {len(contexts_data)} 条消息到 {len(self.websockets)} 个WebSocket连接")
# 创建WebSocket列表的副本避免在遍历时修改
websockets_copy = self.websockets.copy()
removed_count = 0
for ws in websockets_copy:
if ws.closed:
if ws in self.websockets:
@@ -660,7 +668,7 @@ class ContextWebManager:
if ws in self.websockets:
self.websockets.remove(ws)
removed_count += 1
if removed_count > 0:
logger.debug(f"清理了 {removed_count} 个断开的WebSocket连接")
@@ -681,5 +689,4 @@ async def init_context_web_manager():
"""初始化上下文网页管理器"""
manager = get_context_web_manager()
await manager.start_server()
return manager
return manager

View File

@@ -11,6 +11,7 @@ logger = get_logger("gift_manager")
@dataclass
class PendingGift:
"""等待中的礼物消息"""
message: MessageRecvS4U
total_count: int
timer_task: asyncio.Task
@@ -19,71 +20,68 @@ class PendingGift:
class GiftManager:
"""礼物管理器,提供防抖功能"""
def __init__(self):
"""初始化礼物管理器"""
self.pending_gifts: Dict[Tuple[str, str], PendingGift] = {}
self.debounce_timeout = 5.0 # 3秒防抖时间
async def handle_gift(
    self, message: MessageRecvS4U, callback: Optional[Callable[[MessageRecvS4U], None]] = None
) -> bool:
    """Debounce a gift message.

    Args:
        message: The incoming message.
        callback: Invoked with the merged message once the debounce window
            for this gift elapses.

    Returns:
        bool: True when the message is not a gift and should be processed
        right away; False when it was held back for debouncing.
    """
    if not message.is_gift:
        return True

    # Gifts are grouped by (sender id, gift name).
    gift_key = (message.message_info.user_info.user_id, message.gift_name)

    if gift_key not in self.pending_gifts:
        # First gift of this kind from this sender: start a new window.
        await self._create_pending_gift(gift_key, message, callback)
    else:
        # Same gift already waiting: fold this one into it.
        await self._merge_gift(gift_key, message)

    return False
async def _merge_gift(self, gift_key: Tuple[str, str], new_message: MessageRecvS4U) -> None:
    """Fold ``new_message`` into the pending gift stored under ``gift_key``.

    The running debounce timer is cancelled, the gift count is accumulated
    onto the pending entry (which then adopts the newest message), and a fresh
    timer is started.
    """
    entry = self.pending_gifts[gift_key]

    # Stop the previous debounce timer.
    timer = entry.timer_task
    if not timer.cancelled():
        timer.cancel()

    try:
        entry.total_count += int(new_message.gift_count)
        # Keep the latest message but make it carry the accumulated total.
        entry.message = new_message
        new_message.gift_count = str(entry.total_count)
        new_message.gift_info = f"{new_message.gift_name}:{entry.total_count}"
    except ValueError:
        # Unparseable count: leave the accumulated total untouched.
        logger.warning(f"无法解析礼物数量: {new_message.gift_count}")

    # Restart the debounce window.
    entry.timer_task = asyncio.create_task(self._gift_timeout(gift_key))

    logger.debug(f"合并礼物: {gift_key}, 总数量: {entry.total_count}")
async def _create_pending_gift(
self,
gift_key: Tuple[str, str],
message: MessageRecvS4U,
callback: Optional[Callable[[MessageRecvS4U], None]]
self, gift_key: Tuple[str, str], message: MessageRecvS4U, callback: Optional[Callable[[MessageRecvS4U], None]]
) -> None:
"""创建新的等待礼物"""
try:
@@ -91,56 +89,51 @@ class GiftManager:
except ValueError:
initial_count = 1
logger.warning(f"无法解析礼物数量: {message.gift_count}默认设为1")
# 创建定时器任务
timer_task = asyncio.create_task(self._gift_timeout(gift_key))
# 创建等待礼物对象
pending_gift = PendingGift(
message=message,
total_count=initial_count,
timer_task=timer_task,
callback=callback
)
pending_gift = PendingGift(message=message, total_count=initial_count, timer_task=timer_task, callback=callback)
self.pending_gifts[gift_key] = pending_gift
logger.debug(f"创建等待礼物: {gift_key}, 初始数量: {initial_count}")
async def _gift_timeout(self, gift_key: Tuple[str, str]) -> None:
    """Finish the debounce window for ``gift_key`` and deliver the gift.

    Sleeps for ``self.debounce_timeout`` seconds, removes the pending entry,
    rewrites the message's ``processed_plain_text`` with the merged total and
    invokes the stored callback. Cancellation (the timer being replaced by a
    merge) is expected and ignored; any other error is logged, not raised.
    """
    try:
        # Wait out the debounce window.
        await asyncio.sleep(self.debounce_timeout)

        # Claim the pending entry; nothing to do if it is already gone.
        pending_gift = self.pending_gifts.pop(gift_key, None)
        if pending_gift is None:
            return

        logger.info(f"礼物防抖完成: {gift_key}, 最终数量: {pending_gift.total_count}")

        message = pending_gift.message
        message.processed_plain_text = f"用户{message.message_info.user_info.user_nickname}送出了礼物{message.gift_name} x{pending_gift.total_count}"

        if pending_gift.callback:
            try:
                outcome = pending_gift.callback(message)
                # A callback defined with ``async def`` returns a coroutine;
                # previously it was dropped without ever running. Await it so
                # both sync and async callbacks work.
                if asyncio.iscoroutine(outcome):
                    await outcome
            except Exception as e:
                logger.error(f"礼物回调执行失败: {e}", exc_info=True)

    except asyncio.CancelledError:
        # Timer was cancelled; nothing to do.
        pass
    except Exception as e:
        logger.error(f"礼物防抖处理异常: {e}", exc_info=True)
def get_pending_count(self) -> int:
    """Return how many gifts are currently waiting in the debounce table."""
    waiting = self.pending_gifts
    return len(waiting)
async def flush_all(self) -> None:
"""立即处理所有等待中的礼物"""
for gift_key in list(self.pending_gifts.keys()):
@@ -152,4 +145,3 @@ class GiftManager:
# 创建全局礼物管理器实例
gift_manager = GiftManager()

View File

@@ -1,14 +1,15 @@
class InternalManager:
    """Holds the most recently set internal-state text."""

    def __init__(self):
        # Starts out as the empty string.
        self.now_internal_state = ""

    def set_internal_state(self, internal_state: str):
        """Replace the stored internal state."""
        self.now_internal_state = internal_state

    def get_internal_state(self):
        """Return the stored internal state as-is."""
        return self.now_internal_state

    def get_internal_state_str(self):
        """Return the stored state embedded in a fixed prompt sentence."""
        current = self.now_internal_state
        return f"你今天的直播内容是直播QQ水群你正在一边回复弹幕一边在QQ群聊天你在QQ群聊天中产生的想法是{current}"
internal_manager = InternalManager()
internal_manager = InternalManager()

View File

@@ -16,7 +16,6 @@ import json
from .s4u_mood_manager import mood_manager
from src.mais4u.s4u_config import s4u_config
from src.person_info.person_info import get_person_id
from .super_chat_manager import get_super_chat_manager
from .yes_or_no import yes_or_no_head
logger = get_logger("S4U_chat")
@@ -33,15 +32,12 @@ class MessageSenderContainer:
self._task: Optional[asyncio.Task] = None
self._paused_event = asyncio.Event()
self._paused_event.set() # 默认设置为非暂停状态
self.msg_id = ""
self.last_msg_id = ""
self.voice_done = ""
self.msg_id = ""
self.last_msg_id = ""
self.voice_done = ""
async def add_message(self, chunk: str):
"""向队列中添加一个消息块。"""
@@ -131,7 +127,7 @@ class MessageSenderContainer:
reply_to=f"{self.original_message.message_info.user_info.platform}:{self.original_message.message_info.user_info.user_id}",
)
await bot_message.process()
await self.storage.store_message(bot_message, self.chat_stream)
except Exception as e:
@@ -198,12 +194,12 @@ class S4UChat:
self.gpt = S4UStreamGenerator()
self.gpt.chat_stream = self.chat_stream
self.interest_dict: Dict[str, float] = {} # 用户兴趣分
self.internal_message :List[MessageRecvS4U] = []
self.internal_message: List[MessageRecvS4U] = []
self.msg_id = ""
self.voice_done = ""
logger.info(f"[{self.stream_name}] S4UChat with two-queue system initialized.")
def _get_priority_info(self, message: MessageRecv) -> dict:
@@ -226,7 +222,7 @@ class S4UChat:
def _get_interest_score(self, user_id: str) -> float:
"""获取用户的兴趣分默认为1.0"""
return self.interest_dict.get(user_id, 1.0)
def go_processing(self):
if self.voice_done == self.last_msg_id:
return True
@@ -237,14 +233,14 @@ class S4UChat:
为消息计算基础优先级分数。分数越高,优先级越高。
"""
score = 0.0
# 加上消息自带的优先级
score += priority_info.get("message_priority", 0.0)
# 加上用户的固有兴趣分
score += self._get_interest_score(message.message_info.user_info.user_id)
return score
def decay_interest_score(self):
for person_id, score in self.interest_dict.items():
if score > 0:
@@ -252,15 +248,14 @@ class S4UChat:
else:
self.interest_dict[person_id] = 0
async def add_message(self, message: MessageRecvS4U|MessageRecv) -> None:
async def add_message(self, message: MessageRecvS4U | MessageRecv) -> None:
self.decay_interest_score()
"""根据VIP状态和中断逻辑将消息放入相应队列。"""
user_id = message.message_info.user_info.user_id
platform = message.message_info.platform
person_id = get_person_id(platform, user_id)
_person_id = get_person_id(platform, user_id)
# try:
# is_gift = message.is_gift
# is_superchat = message.is_superchat
@@ -276,7 +271,7 @@ class S4UChat:
# # 安全地增加兴趣分如果person_id不存在则先初始化为1.0
# current_score = self.interest_dict.get(person_id, 1.0)
# self.interest_dict[person_id] = current_score + 0.1 * float(message.superchat_price)
# # 添加SuperChat到管理器
# super_chat_manager = get_super_chat_manager()
# await super_chat_manager.add_superchat(message)
@@ -284,16 +279,19 @@ class S4UChat:
# await self.relationship_builder.build_relation(20)
# except Exception:
# traceback.print_exc()
logger.info(f"[{self.stream_name}] 消息处理完毕,消息内容:{message.processed_plain_text}")
priority_info = self._get_priority_info(message)
is_vip = self._is_vip(priority_info)
new_priority_score = self._calculate_base_priority_score(message, priority_info)
should_interrupt = False
if (s4u_config.enable_message_interruption and
self._current_generation_task and not self._current_generation_task.done()):
if (
s4u_config.enable_message_interruption
and self._current_generation_task
and not self._current_generation_task.done()
):
if self._current_message_being_replied:
current_queue, current_priority, _, current_msg = self._current_message_being_replied
@@ -344,39 +342,45 @@ class S4UChat:
"""清理普通队列中不在最近N条消息范围内的消息"""
if not s4u_config.enable_old_message_cleanup or self._normal_queue.empty():
return
# 计算阈值:保留最近 recent_message_keep_count 条消息
cutoff_counter = max(0, self._entry_counter - s4u_config.recent_message_keep_count)
# 临时存储需要保留的消息
temp_messages = []
removed_count = 0
# 取出所有普通队列中的消息
while not self._normal_queue.empty():
try:
item = self._normal_queue.get_nowait()
neg_priority, entry_count, timestamp, message = item
# 如果消息在最近N条消息范围内保留它
logger.info(f"检查消息:{message.processed_plain_text},entry_count:{entry_count} cutoff_counter:{cutoff_counter}")
logger.info(
f"检查消息:{message.processed_plain_text},entry_count:{entry_count} cutoff_counter:{cutoff_counter}"
)
if entry_count >= cutoff_counter:
temp_messages.append(item)
else:
removed_count += 1
self._normal_queue.task_done() # 标记被移除的任务为完成
except asyncio.QueueEmpty:
break
# 将保留的消息重新放入队列
for item in temp_messages:
self._normal_queue.put_nowait(item)
if removed_count > 0:
logger.info(f"消息{message.processed_plain_text}超过{s4u_config.recent_message_keep_count}现在counter:{self._entry_counter}被移除")
logger.info(f"[{self.stream_name}] Cleaned up {removed_count} old normal messages outside recent {s4u_config.recent_message_keep_count} range.")
logger.info(
f"消息{message.processed_plain_text}超过{s4u_config.recent_message_keep_count}现在counter:{self._entry_counter}被移除"
)
logger.info(
f"[{self.stream_name}] Cleaned up {removed_count} old normal messages outside recent {s4u_config.recent_message_keep_count} range."
)
async def _message_processor(self):
"""调度器优先处理VIP队列然后处理普通队列。"""
@@ -385,7 +389,7 @@ class S4UChat:
# 等待有新消息的信号,避免空转
await self._new_message_event.wait()
self._new_message_event.clear()
# 清理普通队列中的过旧消息
self._cleanup_old_normal_messages()
@@ -396,7 +400,6 @@ class S4UChat:
queue_name = "vip"
# 其次处理普通队列
elif not self._normal_queue.empty():
neg_priority, entry_count, timestamp, message = self._normal_queue.get_nowait()
priority = -neg_priority
# 检查普通消息是否超时
@@ -411,13 +414,15 @@ class S4UChat:
if self.internal_message:
message = self.internal_message[-1]
self.internal_message = []
priority = 0
neg_priority = 0
entry_count = 0
queue_name = "internal"
logger.info(f"[{self.stream_name}] normal/vip 队列都空,触发 internal_message 回复: {getattr(message, 'processed_plain_text', str(message))[:20]}...")
logger.info(
f"[{self.stream_name}] normal/vip 队列都空,触发 internal_message 回复: {getattr(message, 'processed_plain_text', str(message))[:20]}..."
)
else:
continue # 没有消息了,回去等事件
@@ -457,23 +462,21 @@ class S4UChat:
except Exception as e:
logger.error(f"[{self.stream_name}] Message processor main loop error: {e}", exc_info=True)
await asyncio.sleep(1)
def get_processing_message_id(self):
    """Rotate the message ids: the current id becomes ``last_msg_id``.

    The new ``msg_id`` is ``"<unix time>_<random 4-digit number>"``.
    """
    previous_id = self.msg_id
    stamp = time.time()
    suffix = random.randint(1000, 9999)
    self.last_msg_id = previous_id
    self.msg_id = f"{stamp}_{suffix}"
async def _generate_and_send(self, message: MessageRecv):
"""为单个消息生成文本回复。整个过程可以被中断。"""
self._is_replying = True
total_chars_sent = 0 # 跟踪发送的总字符数
self.get_processing_message_id()
# 视线管理:开始生成回复时切换视线状态
chat_watching = watching_manager.get_watching_by_chat_id(self.stream_id)
if message.is_internal:
await chat_watching.on_internal_message_start()
else:
@@ -516,16 +519,19 @@ class S4UChat:
total_chars_sent = len("麦麦不知道哦")
mood = mood_manager.get_mood_by_chat_id(self.stream_id)
await yes_or_no_head(text = total_chars_sent,emotion = mood.mood_state,chat_history=message.processed_plain_text,chat_id=self.stream_id)
await yes_or_no_head(
text=total_chars_sent,
emotion=mood.mood_state,
chat_history=message.processed_plain_text,
chat_id=self.stream_id,
)
# 等待所有文本消息发送完成
await sender_container.close()
await sender_container.join()
await chat_watching.on_thinking_finished()
start_time = time.time()
logged = False
while not self.go_processing():
@@ -536,7 +542,7 @@ class S4UChat:
logger.info(f"[{self.stream_name}] 等待消息发送完成...")
logged = True
await asyncio.sleep(0.2)
logger.info(f"[{self.stream_name}] 所有文本块处理完毕。")
except asyncio.CancelledError:
@@ -548,11 +554,11 @@ class S4UChat:
# 回复生成实时展示:清空内容(出错时)
finally:
self._is_replying = False
# 视线管理:回复结束时切换视线状态
chat_watching = watching_manager.get_watching_by_chat_id(self.stream_id)
await chat_watching.on_reply_finished()
# 确保发送器被妥善关闭(即使已关闭,再次调用也是安全的)
sender_container.resume()
if not sender_container._task.done():
@@ -576,4 +582,3 @@ class S4UChat:
await self._processing_task
except asyncio.CancelledError:
logger.info(f"处理任务已成功取消: {self.stream_name}")

View File

@@ -40,7 +40,7 @@ async def _calculate_interest(message: MessageRecv) -> Tuple[float, bool]:
if global_config.memory.enable_memory:
with Timer("记忆激活"):
interested_rate,_ ,_= await hippocampus_manager.get_activate_from_text(
interested_rate, _, _ = await hippocampus_manager.get_activate_from_text(
message.processed_plain_text,
fast_retrieval=True,
)
@@ -49,7 +49,7 @@ async def _calculate_interest(message: MessageRecv) -> Tuple[float, bool]:
text_len = len(message.processed_plain_text)
# 根据文本长度分布调整兴趣度,采用分段函数实现更精确的兴趣度计算
# 基于实际分布0-5字符(26.57%), 6-10字符(27.18%), 11-20字符(22.76%), 21-30字符(10.33%), 31+字符(13.86%)
if text_len == 0:
base_interest = 0.01 # 空消息最低兴趣度
elif text_len <= 5:
@@ -73,7 +73,7 @@ async def _calculate_interest(message: MessageRecv) -> Tuple[float, bool]:
else:
# 100+字符:对数增长 0.26 -> 0.3,增长率递减
base_interest = 0.26 + (0.3 - 0.26) * (math.log10(text_len - 99) / math.log10(901)) # 1000-99=901
# 确保在范围内
base_interest = min(max(base_interest, 0.01), 0.3)
@@ -117,36 +117,32 @@ class S4UMessageProcessor:
user_info=userinfo,
group_info=groupinfo,
)
if await self.handle_internal_message(message):
return
if await self.hadle_if_voice_done(message):
return
# 处理礼物消息,如果消息被暂存则停止当前处理流程
if not skip_gift_debounce and not await self.handle_if_gift(message):
return
await self.check_if_fake_gift(message)
# 处理屏幕消息
if await self.handle_screen_message(message):
return
await self.storage.store_message(message, chat)
s4u_chat = get_s4u_chat_manager().get_or_create_chat(chat)
await s4u_chat.add_message(message)
_interested_rate, _ = await _calculate_interest(message)
await mood_manager.start()
# 一系列llm驱动的前处理
chat_mood = mood_manager.get_mood_by_chat_id(chat.stream_id)
asyncio.create_task(chat_mood.update_mood_by_message(message))
@@ -164,61 +160,56 @@ class S4UMessageProcessor:
logger.info(f"[S4U-礼物] {userinfo.user_nickname} 送出了 {message.gift_name} x{message.gift_count}")
else:
logger.info(f"[S4U]{userinfo.user_nickname}:{message.processed_plain_text}")
async def handle_internal_message(self, message: MessageRecvS4U) -> bool:
    """Queue an internal message on the dedicated internal chat stream.

    Args:
        message: Incoming S4U message; only handled when ``message.is_internal`` is truthy.

    Returns:
        True when the message was internal and has been queued (the caller stops
        processing it), False otherwise.
    """
    if not message.is_internal:
        return False

    # Internal messages are always attached to this fixed group on the default platform.
    group_info = GroupInfo(platform="amaidesu_default", group_id=660154, group_name="内心")
    chat = await get_chat_manager().get_or_create_stream(
        platform="amaidesu_default", user_info=message.message_info.user_info, group_info=group_info
    )
    s4u_chat = get_s4u_chat_manager().get_or_create_chat(chat)

    # Re-home the message so its metadata matches the stream it is queued on.
    message.message_info.group_info = s4u_chat.chat_stream.group_info
    message.message_info.platform = s4u_chat.chat_stream.platform

    s4u_chat.internal_message.append(message)
    # Signal the chat's new-message event after queueing.
    s4u_chat._new_message_event.set()

    logger.info(
        f"[{s4u_chat.stream_name}] 添加内部消息-------------------------------------------------------: {message.processed_plain_text}"
    )
    return True
async def handle_screen_message(self, message: MessageRecvS4U) -> bool:
    """Store screen content carried by a screen message.

    Args:
        message: Incoming S4U message; only handled when ``message.is_screen`` is truthy.

    Returns:
        True when the message was a screen message and ``message.screen_info`` was
        handed to ``screen_manager``, False otherwise.
    """
    if not message.is_screen:
        return False
    screen_manager.set_screen(message.screen_info)
    return True
async def hadle_if_voice_done(self, message: MessageRecvS4U) -> bool:
    """Record a voice-done notification on the message's chat.

    NOTE: the method name is a misspelling of ``handle_if_voice_done``; it is kept
    because ``process_message`` calls it under this spelling.

    Args:
        message: Incoming S4U message; only handled when ``message.voice_done`` is truthy.

    Returns:
        True when the notification was recorded, False otherwise.
    """
    if not message.voice_done:
        return False
    s4u_chat = get_s4u_chat_manager().get_or_create_chat(message.chat_stream)
    # The truthy value itself is stored, not a plain flag.
    s4u_chat.voice_done = message.voice_done
    return True
async def check_if_fake_gift(self, message: MessageRecvS4U) -> bool:
    """Flag a non-gift message whose text contains gift wording.

    Args:
        message: Incoming S4U message. When it is not a gift and its
            ``processed_plain_text`` contains any gift keyword,
            ``message.is_fake_gift`` is set to True.

    Returns:
        True when the message was flagged as a fake gift, False otherwise
        (including for real gift messages).
    """
    if message.is_gift:
        return False

    gift_keywords = ("送出了礼物", "礼物", "送出了", "投喂")
    if any(keyword in message.processed_plain_text for keyword in gift_keywords):
        message.is_fake_gift = True
        return True
    return False
async def handle_if_gift(self, message: MessageRecvS4U) -> bool:
"""处理礼物消息
Returns:
bool: True表示应该继续处理消息False表示消息已被暂存不需要继续处理
"""
@@ -228,37 +219,37 @@ class S4UMessageProcessor:
"""礼物防抖完成后的回调"""
# 创建异步任务来处理合并后的礼物消息,跳过防抖处理
asyncio.create_task(self.process_message(merged_message, skip_gift_debounce=True))
# 交给礼物管理器处理,并传入回调函数
# 对于礼物消息handle_gift 总是返回 False消息被暂存
await gift_manager.handle_gift(message, gift_callback)
return False # 消息被暂存,不继续处理
return True # 非礼物消息,继续正常处理
async def _handle_context_web_update(self, chat_id: str, message: MessageRecv) -> None:
    """Push one message to the context web page; meant to run as its own task.

    Any exception is logged with its traceback and swallowed.

    Args:
        chat_id: Chat ID the message belongs to.
        message: Message object to add to the context page.
    """
    try:
        logger.debug(f"🔄 开始处理上下文网页更新: {message.message_info.user_info.user_nickname}")

        context_manager = get_context_web_manager()

        # Start the server only when it is not running yet, to avoid starting it twice.
        if context_manager.site is None:
            logger.info("🚀 首次启动上下文网页服务器...")
            await context_manager.start_server()

        # Add the message to the context and refresh the page.
        # NOTE(review): the reason for the fixed 1.5 s delay is not visible here — confirm.
        await asyncio.sleep(1.5)
        await context_manager.add_message(chat_id, message)

        logger.debug(f"✅ 上下文网页更新完成: {message.message_info.user_info.user_nickname}")
    except Exception as e:
        logger.error(f"❌ 处理上下文网页更新失败: {e}", exc_info=True)

View File

@@ -176,7 +176,7 @@ class PromptBuilder:
message_list_before_now = get_raw_msg_before_timestamp_with_chat(
chat_id=chat_stream.stream_id,
timestamp=time.time(),
# sourcery skip: lift-duplicated-conditional, merge-duplicate-blocks, remove-redundant-if
# sourcery skip: lift-duplicated-conditional, merge-duplicate-blocks, remove-redundant-if
limit=300,
)
@@ -228,13 +228,17 @@ class PromptBuilder:
last_speaking_user_id = start_speaking_user_id
msg_seg_str = "对方的发言:\n"
msg_seg_str += f"{time.strftime('%H:%M:%S', time.localtime(first_msg.time))}: {first_msg.processed_plain_text}\n"
msg_seg_str += (
f"{time.strftime('%H:%M:%S', time.localtime(first_msg.time))}: {first_msg.processed_plain_text}\n"
)
all_msg_seg_list = []
for msg in core_dialogue_list[1:]:
speaker = msg.user_info.user_id
if speaker == last_speaking_user_id:
msg_seg_str += f"{time.strftime('%H:%M:%S', time.localtime(msg.time))}: {msg.processed_plain_text}\n"
msg_seg_str += (
f"{time.strftime('%H:%M:%S', time.localtime(msg.time))}: {msg.processed_plain_text}\n"
)
else:
msg_seg_str = f"{msg_seg_str}\n"
all_msg_seg_list.append(msg_seg_str)

View File

@@ -14,11 +14,8 @@ logger = get_logger("s4u_stream_generator")
class S4UStreamGenerator:
def __init__(self):
# 使用LLMRequest替代AsyncOpenAIClient
self.llm_request = LLMRequest(
model_set=model_config.model_task_config.replyer,
request_type="s4u_replyer"
)
self.llm_request = LLMRequest(model_set=model_config.model_task_config.replyer, request_type="s4u_replyer")
self.current_model_name = "unknown model"
self.partial_response = ""
@@ -89,16 +86,16 @@ class S4UStreamGenerator:
async def _generate_response_with_llm_request(self, prompt: str) -> AsyncGenerator[str, None]:
"""使用LLMRequest进行流式响应生成"""
# 构建消息
message_builder = MessageBuilder()
message_builder.add_text_content(prompt)
messages = [message_builder.build()]
# 选择模型
model_info, api_provider, client = self.llm_request._select_model()
self.current_model_name = model_info.name
# 如果模型支持强制流式模式,使用真正的流式处理
if model_info.force_stream_mode:
# 简化流式处理直接使用LLMRequest的流式功能
@@ -111,14 +108,14 @@ class S4UStreamGenerator:
model_info=model_info,
message_list=messages,
)
# 处理响应内容
content = response.content or ""
if content:
# 将内容按句子分割并输出
async for chunk in self._process_content_streaming(content):
yield chunk
except Exception as e:
logger.error(f"流式请求执行失败: {e}")
# 如果流式请求失败,回退到普通模式
@@ -132,7 +129,7 @@ class S4UStreamGenerator:
content = response.content or ""
async for chunk in self._process_content_streaming(content):
yield chunk
else:
# 如果不支持流式,使用普通方式然后模拟流式输出
response = await self.llm_request._execute_request(
@@ -142,7 +139,7 @@ class S4UStreamGenerator:
model_info=model_info,
message_list=messages,
)
content = response.content or ""
async for chunk in self._process_content_streaming(content):
yield chunk
@@ -163,7 +160,7 @@ class S4UStreamGenerator:
"""处理内容进行流式输出(用于非流式模型的模拟流式输出)"""
buffer = content
punctuation_buffer = ""
# 使用正则表达式匹配句子
last_match_end = 0
for match in self.sentence_split_pattern.finditer(buffer):

View File

@@ -1,4 +1,3 @@
from src.common.logger import get_logger
from src.plugin_system.apis import send_api
@@ -47,6 +46,7 @@ HEAD_CODE = {
"看向正前方": "(0,0,0)",
}
class ChatWatching:
def __init__(self, chat_id: str):
self.chat_id: str = chat_id
@@ -56,13 +56,13 @@ class ChatWatching:
await send_api.custom_to_stream(
message_type="state", content="start_thinking", stream_id=self.chat_id, storage_message=False
)
async def on_reply_finished(self) -> None:
    """Send the ``finish_reply`` state to this chat's stream once a reply has been generated.

    The state message is sent with ``storage_message=False``.
    """
    await send_api.custom_to_stream(
        message_type="state", content="finish_reply", stream_id=self.chat_id, storage_message=False
    )
async def on_thinking_finished(self):
"""思考完毕时调用"""
await send_api.custom_to_stream(
@@ -74,14 +74,14 @@ class ChatWatching:
await send_api.custom_to_stream(
message_type="state", content="start_viewing", stream_id=self.chat_id, storage_message=False
)
async def on_internal_message_start(self) -> None:
    """Send the ``start_internal_thinking`` state to this chat's stream.

    The state message is sent with ``storage_message=False``.
    """
    await send_api.custom_to_stream(
        message_type="state", content="start_internal_thinking", stream_id=self.chat_id, storage_message=False
    )
class WatchingManager:
def __init__(self):
self.watching_list: list[ChatWatching] = []
@@ -100,6 +100,7 @@ class WatchingManager:
return new_watching
# 全局视线管理器实例
watching_manager = WatchingManager()
"""全局视线管理器"""

View File

@@ -1,14 +1,15 @@
class ScreenManager:
    """Hold the most recently set screen content as a string."""

    def __init__(self) -> None:
        # Empty until set_screen() is called for the first time.
        self.now_screen: str = ""

    def set_screen(self, screen_str: str) -> None:
        """Replace the stored screen content with ``screen_str``."""
        self.now_screen = screen_str

    def get_screen(self) -> str:
        """Return the stored screen content unchanged."""
        return self.now_screen

    def get_screen_str(self) -> str:
        """Return the stored screen content embedded in the fixed prompt sentence."""
        return f"你可以看见面前的屏幕,目前屏幕的内容是:现在千石可乐在和你一起直播,这是他正在操作的屏幕内容:{self.now_screen}"
screen_manager = ScreenManager()
screen_manager = ScreenManager()

View File

@@ -4,6 +4,7 @@ from dataclasses import dataclass
from typing import Dict, List, Optional
from src.common.logger import get_logger
from src.chat.message_receive.message import MessageRecvS4U
# 全局SuperChat管理器实例
from src.mais4u.s4u_config import s4u_config
@@ -13,7 +14,7 @@ logger = get_logger("super_chat_manager")
@dataclass
class SuperChatRecord:
"""SuperChat记录数据类"""
user_id: str
user_nickname: str
platform: str
@@ -23,15 +24,15 @@ class SuperChatRecord:
timestamp: float
expire_time: float
group_name: Optional[str] = None
def is_expired(self) -> bool:
    """Return True once the current time is strictly past ``expire_time``."""
    return time.time() > self.expire_time
def remaining_time(self) -> float:
    """Return the seconds left until ``expire_time``, never below zero."""
    # 0.0 rather than 0 so the result is a float on both branches.
    return max(0.0, self.expire_time - time.time())
def to_dict(self) -> dict:
"""转换为字典格式"""
return {
@@ -44,19 +45,19 @@ class SuperChatRecord:
"timestamp": self.timestamp,
"expire_time": self.expire_time,
"group_name": self.group_name,
"remaining_time": self.remaining_time()
"remaining_time": self.remaining_time(),
}
class SuperChatManager:
"""SuperChat管理器负责管理和跟踪SuperChat消息"""
def __init__(self) -> None:
    """Create an empty manager; the cleanup task is not started here."""
    # chat_id -> list of SuperChat records
    self.super_chats: Dict[str, List[SuperChatRecord]] = {}
    # Handle of the background cleanup task; None until one is created.
    self._cleanup_task: Optional[asyncio.Task] = None
    self._is_initialized = False
    logger.info("SuperChat管理器已初始化")
def _ensure_cleanup_task_started(self):
"""确保清理任务已启动(延迟启动)"""
if self._cleanup_task is None or self._cleanup_task.done():
@@ -68,7 +69,7 @@ class SuperChatManager:
except RuntimeError:
# 没有运行的事件循环,稍后再启动
logger.debug("当前没有运行的事件循环,将在需要时启动清理任务")
def _start_cleanup_task(self):
"""启动清理任务(已弃用,保留向后兼容)"""
self._ensure_cleanup_task_started()
@@ -78,39 +79,36 @@ class SuperChatManager:
while True:
try:
total_removed = 0
for chat_id in list(self.super_chats.keys()):
original_count = len(self.super_chats[chat_id])
# 移除过期的SuperChat
self.super_chats[chat_id] = [
sc for sc in self.super_chats[chat_id]
if not sc.is_expired()
]
self.super_chats[chat_id] = [sc for sc in self.super_chats[chat_id] if not sc.is_expired()]
removed_count = original_count - len(self.super_chats[chat_id])
total_removed += removed_count
if removed_count > 0:
logger.info(f"从聊天 {chat_id} 中清理了 {removed_count} 个过期的SuperChat")
# 如果列表为空,删除该聊天的记录
if not self.super_chats[chat_id]:
del self.super_chats[chat_id]
if total_removed > 0:
logger.info(f"总共清理了 {total_removed} 个过期的SuperChat")
# 每30秒检查一次
await asyncio.sleep(30)
except Exception as e:
logger.error(f"清理过期SuperChat时出错: {e}", exc_info=True)
await asyncio.sleep(60) # 出错时等待更长时间
def _calculate_expire_time(self, price: float) -> float:
"""根据SuperChat金额计算过期时间"""
current_time = time.time()
# 根据金额阶梯设置不同的存活时间
if price >= 500:
# 500元以上保持4小时
@@ -133,27 +131,27 @@ class SuperChatManager:
else:
# 10元以下保持5分钟
duration = 5 * 60
return current_time + duration
async def add_superchat(self, message: MessageRecvS4U) -> None:
"""添加新的SuperChat记录"""
# 确保清理任务已启动
self._ensure_cleanup_task_started()
if not message.is_superchat or not message.superchat_price:
logger.warning("尝试添加非SuperChat消息到SuperChat管理器")
return
try:
price = float(message.superchat_price)
except (ValueError, TypeError):
logger.error(f"无效的SuperChat价格: {message.superchat_price}")
return
user_info = message.message_info.user_info
group_info = message.message_info.group_info
chat_id = getattr(message, 'chat_stream', None)
chat_id = getattr(message, "chat_stream", None)
if chat_id:
chat_id = chat_id.stream_id
else:
@@ -161,9 +159,9 @@ class SuperChatManager:
chat_id = f"{message.message_info.platform}_{user_info.user_id}"
if group_info:
chat_id = f"{message.message_info.platform}_{group_info.group_id}"
expire_time = self._calculate_expire_time(price)
record = SuperChatRecord(
user_id=user_info.user_id,
user_nickname=user_info.user_nickname,
@@ -173,44 +171,44 @@ class SuperChatManager:
message_text=message.superchat_message_text or "",
timestamp=message.message_info.time,
expire_time=expire_time,
group_name=group_info.group_name if group_info else None
group_name=group_info.group_name if group_info else None,
)
# 添加到对应聊天的SuperChat列表
if chat_id not in self.super_chats:
self.super_chats[chat_id] = []
self.super_chats[chat_id].append(record)
# 按价格降序排序(价格高的在前)
self.super_chats[chat_id].sort(key=lambda x: x.price, reverse=True)
logger.info(f"添加SuperChat记录: {user_info.user_nickname} - {price}元 - {message.superchat_message_text}")
def get_superchats_by_chat(self, chat_id: str) -> List[SuperChatRecord]:
    """Return the non-expired SuperChats recorded for ``chat_id``.

    Args:
        chat_id: Chat whose SuperChats are requested.

    Returns:
        A new list in stored order; empty when the chat has no records.
    """
    # Make sure the cleanup task has been started.
    self._ensure_cleanup_task_started()

    # Expired records are filtered out here as well as by the cleanup loop.
    return [sc for sc in self.super_chats.get(chat_id, []) if not sc.is_expired()]
def get_all_valid_superchats(self) -> Dict[str, List[SuperChatRecord]]:
    """Return every chat's non-expired SuperChats.

    Returns:
        A new dict mapping chat_id to its non-expired records; chats with no
        remaining records are left out.
    """
    # Make sure the cleanup task has been started.
    self._ensure_cleanup_task_started()

    live_by_chat = {
        chat_id: [sc for sc in superchats if not sc.is_expired()]
        for chat_id, superchats in self.super_chats.items()
    }
    return {chat_id: live for chat_id, live in live_by_chat.items() if live}
def build_superchat_display_string(self, chat_id: str, max_count: int = 10) -> str:
"""构建SuperChat显示字符串"""
superchats = self.get_superchats_by_chat(chat_id)
@@ -226,7 +224,9 @@ class SuperChatManager:
remaining_minutes = int(sc.remaining_time() / 60)
remaining_seconds = int(sc.remaining_time() % 60)
time_display = f"{remaining_minutes}{remaining_seconds}" if remaining_minutes > 0 else f"{remaining_seconds}"
time_display = (
f"{remaining_minutes}{remaining_seconds}" if remaining_minutes > 0 else f"{remaining_seconds}"
)
line = f"{i}. 【{sc.price}元】{sc.user_nickname}: {sc.message_text}"
if len(line) > 100: # 限制单行长度
@@ -238,7 +238,7 @@ class SuperChatManager:
lines.append(f"... 还有{len(superchats) - max_count}条SuperChat")
return "\n".join(lines)
def build_superchat_summary_string(self, chat_id: str) -> str:
"""构建SuperChat摘要字符串"""
superchats = self.get_superchats_by_chat(chat_id)
@@ -261,30 +261,24 @@ class SuperChatManager:
if lines:
final_str += "\n" + "\n".join(lines)
return final_str
def get_superchat_statistics(self, chat_id: str) -> dict:
    """Summarise the prices of the chat's currently valid SuperChats.

    Args:
        chat_id: Chat whose SuperChats are summarised.

    Returns:
        Dict with ``count``, ``total_amount``, ``average_amount``,
        ``highest_amount`` and ``lowest_amount``; all values are 0 when the chat
        has no valid SuperChats.
    """
    superchats = self.get_superchats_by_chat(chat_id)
    if not superchats:
        return {"count": 0, "total_amount": 0, "average_amount": 0, "highest_amount": 0, "lowest_amount": 0}

    amounts = [sc.price for sc in superchats]
    total = sum(amounts)
    return {
        "count": len(superchats),
        "total_amount": total,
        "average_amount": total / len(amounts),
        "highest_amount": max(amounts),
        "lowest_amount": min(amounts),
    }
async def shutdown(self): # sourcery skip: use-contextlib-suppress
"""关闭管理器,清理资源"""
if self._cleanup_task and not self._cleanup_task.done():
@@ -296,15 +290,14 @@ class SuperChatManager:
logger.info("SuperChat管理器已关闭")
# sourcery skip: assign-if-exp
if s4u_config.enable_s4u:
super_chat_manager = SuperChatManager()
else:
super_chat_manager = None
def get_super_chat_manager() -> Optional[SuperChatManager]:
    """Return the global SuperChat manager instance.

    Returns:
        The module-level ``super_chat_manager``, which is ``None`` when
        ``s4u_config.enable_s4u`` is false, so callers must handle ``None``.
    """
    return super_chat_manager

View File

@@ -10,10 +10,12 @@ from src.common.logger import get_logger
logger = get_logger("s4u_config")
# 新增兼容dict和tomlkit Table
def is_dict_like(obj) -> bool:
    """Return True when ``obj`` is a ``dict`` or a tomlkit ``Table``."""
    return isinstance(obj, (dict, Table))
# 新增递归将Table转为dict
def table_to_dict(obj):
if isinstance(obj, Table):
@@ -25,6 +27,7 @@ def table_to_dict(obj):
else:
return obj
# 获取mais4u模块目录
MAIS4U_ROOT = os.path.dirname(__file__)
CONFIG_DIR = os.path.join(MAIS4U_ROOT, "config")
@@ -190,7 +193,7 @@ class S4UModelConfig(S4UConfigBase):
@dataclass
class S4UConfig(S4UConfigBase):
"""S4U聊天系统配置类"""
enable_s4u: bool = False
"""是否启用S4U聊天系统"""
@@ -229,12 +232,12 @@ class S4UConfig(S4UConfigBase):
enable_streaming_output: bool = True
"""是否启用流式输出false时全部生成后一次性发送"""
max_context_message_length: int = 20
"""上下文消息最大长度"""
max_core_message_length: int = 30
"""核心消息最大长度"""
"""核心消息最大长度"""
# 模型配置
models: S4UModelConfig = field(default_factory=S4UModelConfig)
@@ -243,7 +246,6 @@ class S4UConfig(S4UConfigBase):
# 兼容性字段,保持向后兼容
@dataclass
class S4UGlobalConfig(S4UConfigBase):
"""S4U总配置类"""
@@ -256,7 +258,7 @@ def update_s4u_config():
"""更新S4U配置文件"""
# 创建配置目录(如果不存在)
os.makedirs(CONFIG_DIR, exist_ok=True)
# 检查模板文件是否存在
if not os.path.exists(TEMPLATE_PATH):
logger.error(f"S4U配置模板文件不存在: {TEMPLATE_PATH}")
@@ -354,13 +356,13 @@ def load_s4u_config(config_path: str) -> S4UGlobalConfig:
logger.critical("S4U配置文件解析失败")
raise e
# 初始化S4U配置
logger.info(f"S4U当前版本: {S4U_VERSION}")
update_s4u_config()
s4u_config_main = load_s4u_config(config_path=CONFIG_PATH)
logger.info("S4U配置文件加载完成")
s4u_config: S4UConfig = s4u_config_main.s4u
s4u_config: S4UConfig = s4u_config_main.s4u

View File

@@ -13,7 +13,7 @@ async def migrate_memory_items_to_string():
并根据原始list的项目数量设置weight值
"""
logger.info("开始迁移记忆节点格式...")
migration_stats = {
"total_nodes": 0,
"converted_nodes": 0,
@@ -21,72 +21,74 @@ async def migrate_memory_items_to_string():
"empty_nodes": 0,
"error_nodes": 0,
"weight_updated_nodes": 0,
"truncated_nodes": 0
"truncated_nodes": 0,
}
try:
# 获取所有图节点
all_nodes = GraphNodes.select()
migration_stats["total_nodes"] = all_nodes.count()
logger.info(f"找到 {migration_stats['total_nodes']} 个记忆节点")
for node in all_nodes:
try:
concept = node.concept
memory_items_raw = node.memory_items.strip() if node.memory_items else ""
original_weight = node.weight if hasattr(node, 'weight') and node.weight is not None else 1.0
original_weight = node.weight if hasattr(node, "weight") and node.weight is not None else 1.0
# 如果为空,跳过
if not memory_items_raw:
migration_stats["empty_nodes"] += 1
logger.debug(f"跳过空节点: {concept}")
continue
try:
# 尝试解析JSON
parsed_data = json.loads(memory_items_raw)
if isinstance(parsed_data, list):
# 如果是list格式需要转换
if parsed_data:
# 转换为字符串格式
new_memory_items = " | ".join(str(item) for item in parsed_data)
original_length = len(new_memory_items)
# 检查长度并截断
if len(new_memory_items) > 100:
new_memory_items = new_memory_items[:100]
migration_stats["truncated_nodes"] += 1
logger.debug(f"节点 '{concept}' 内容过长,从 {original_length} 字符截断到 100 字符")
new_weight = float(len(parsed_data)) # weight = list项目数量
# 更新数据库
node.memory_items = new_memory_items
node.weight = new_weight
node.save()
migration_stats["converted_nodes"] += 1
migration_stats["weight_updated_nodes"] += 1
length_info = f" (截断: {original_length}→100)" if original_length > 100 else ""
logger.info(f"转换节点 '{concept}': {len(parsed_data)} 项 -> 字符串{length_info}, weight: {original_weight} -> {new_weight}")
logger.info(
f"转换节点 '{concept}': {len(parsed_data)} 项 -> 字符串{length_info}, weight: {original_weight} -> {new_weight}"
)
else:
# 空list设置为空字符串
node.memory_items = ""
node.weight = 1.0
node.save()
migration_stats["converted_nodes"] += 1
logger.debug(f"转换空list节点: {concept}")
elif isinstance(parsed_data, str):
# 已经是字符串格式检查长度和weight
current_content = parsed_data
original_length = len(current_content)
content_truncated = False
# 检查长度并截断
if len(current_content) > 100:
current_content = current_content[:100]
@@ -94,19 +96,21 @@ async def migrate_memory_items_to_string():
migration_stats["truncated_nodes"] += 1
node.memory_items = current_content
logger.debug(f"节点 '{concept}' 字符串内容过长,从 {original_length} 字符截断到 100 字符")
# 检查weight是否需要更新
update_needed = False
if original_weight == 1.0:
# 如果weight还是默认值可以根据内容复杂度估算
content_parts = current_content.split(" | ") if " | " in current_content else [current_content]
content_parts = (
current_content.split(" | ") if " | " in current_content else [current_content]
)
estimated_weight = max(1.0, float(len(content_parts)))
if estimated_weight != original_weight:
node.weight = estimated_weight
update_needed = True
logger.debug(f"更新字符串节点权重 '{concept}': {original_weight} -> {estimated_weight}")
# 如果内容被截断或权重需要更新,保存到数据库
if content_truncated or update_needed:
node.save()
@@ -118,26 +122,26 @@ async def migrate_memory_items_to_string():
migration_stats["already_string_nodes"] += 1
else:
migration_stats["already_string_nodes"] += 1
else:
# 其他JSON类型转换为字符串
new_memory_items = str(parsed_data) if parsed_data else ""
original_length = len(new_memory_items)
# 检查长度并截断
if len(new_memory_items) > 100:
new_memory_items = new_memory_items[:100]
migration_stats["truncated_nodes"] += 1
logger.debug(f"节点 '{concept}' 其他类型内容过长,从 {original_length} 字符截断到 100 字符")
node.memory_items = new_memory_items
node.weight = 1.0
node.save()
migration_stats["converted_nodes"] += 1
length_info = f" (截断: {original_length}→100)" if original_length > 100 else ""
logger.debug(f"转换其他类型节点: {concept}{length_info}")
except json.JSONDecodeError:
# 不是JSON格式假设已经是纯字符串
# 检查是否是带引号的字符串
@@ -145,16 +149,16 @@ async def migrate_memory_items_to_string():
# 去掉引号
clean_content = memory_items_raw[1:-1]
original_length = len(clean_content)
# 检查长度并截断
if len(clean_content) > 100:
clean_content = clean_content[:100]
migration_stats["truncated_nodes"] += 1
logger.debug(f"节点 '{concept}' 去引号内容过长,从 {original_length} 字符截断到 100 字符")
node.memory_items = clean_content
node.save()
migration_stats["converted_nodes"] += 1
length_info = f" (截断: {original_length}→100)" if original_length > 100 else ""
logger.debug(f"去除引号节点: {concept}{length_info}")
@@ -162,29 +166,29 @@ async def migrate_memory_items_to_string():
# 已经是纯字符串格式,检查长度
current_content = memory_items_raw
original_length = len(current_content)
# 检查长度并截断
if len(current_content) > 100:
current_content = current_content[:100]
node.memory_items = current_content
node.save()
migration_stats["converted_nodes"] += 1 # 算作转换节点
migration_stats["truncated_nodes"] += 1
logger.debug(f"节点 '{concept}' 纯字符串内容过长,从 {original_length} 字符截断到 100 字符")
else:
migration_stats["already_string_nodes"] += 1
logger.debug(f"已是字符串格式节点: {concept}")
except Exception as e:
migration_stats["error_nodes"] += 1
logger.error(f"处理节点 {concept} 时发生错误: {e}")
continue
except Exception as e:
logger.error(f"迁移过程中发生严重错误: {e}")
raise
# 输出迁移统计
logger.info("=== 记忆节点迁移完成 ===")
logger.info(f"总节点数: {migration_stats['total_nodes']}")
@@ -194,101 +198,105 @@ async def migrate_memory_items_to_string():
logger.info(f"错误节点: {migration_stats['error_nodes']}")
logger.info(f"权重更新节点: {migration_stats['weight_updated_nodes']}")
logger.info(f"内容截断节点: {migration_stats['truncated_nodes']}")
success_rate = (migration_stats['converted_nodes'] + migration_stats['already_string_nodes']) / migration_stats['total_nodes'] * 100 if migration_stats['total_nodes'] > 0 else 0
success_rate = (
(migration_stats["converted_nodes"] + migration_stats["already_string_nodes"])
/ migration_stats["total_nodes"]
* 100
if migration_stats["total_nodes"] > 0
else 0
)
logger.info(f"迁移成功率: {success_rate:.1f}%")
return migration_stats
async def set_all_person_known():
    """
    Set ``is_known`` to True on every record in the person_info table.

    Records whose ``user_id`` or ``platform`` is NULL or an empty string are
    deleted first.

    Returns:
        dict with ``total`` (records before cleanup), ``deleted``, ``updated``
        and ``known_count`` (records with ``is_known`` set afterwards).

    Raises:
        Exception: any error is logged and re-raised.
    """
    logger.info("开始设置所有person_info记录为已认识...")

    try:
        from src.common.database.database_model import PersonInfo

        # Count all PersonInfo records.
        total_count = PersonInfo.select().count()
        logger.info(f"找到 {total_count} 个人员记录")

        if total_count == 0:
            logger.info("没有找到任何人员记录")
            return {"total": 0, "deleted": 0, "updated": 0, "known_count": 0}

        # One predicate shared by the logging query and the delete, so the two
        # can never drift apart.
        invalid_filter = (
            (PersonInfo.user_id.is_null())
            | (PersonInfo.user_id == "")
            | (PersonInfo.platform.is_null())
            | (PersonInfo.platform == "")
        )

        # Log every record that is about to be deleted.
        for record in PersonInfo.select().where(invalid_filter):
            user_id_info = f"'{record.user_id}'" if record.user_id else "NULL"
            platform_info = f"'{record.platform}'" if record.platform else "NULL"
            person_name_info = f"'{record.person_name}'" if record.person_name else "无名称"
            logger.debug(
                f"删除无效记录: person_id={record.person_id}, user_id={user_id_info}, platform={platform_info}, person_name={person_name_info}"
            )

        # Delete the invalid records.
        deleted_count = PersonInfo.delete().where(invalid_filter).execute()

        if deleted_count > 0:
            logger.info(f"删除了 {deleted_count} 个user_id或platform为空的记录")
        else:
            logger.info("没有发现user_id或platform为空的记录")

        # Re-count what is left after the cleanup.
        remaining_count = PersonInfo.select().count()
        logger.info(f"清理后剩余 {remaining_count} 个有效记录")

        if remaining_count == 0:
            logger.info("清理后没有剩余记录")
            return {"total": total_count, "deleted": deleted_count, "updated": 0, "known_count": 0}

        # Bulk-update is_known to True on the remaining records.
        updated_count = PersonInfo.update(is_known=True).execute()

        logger.info(f"成功更新 {updated_count} 个人员记录的is_known字段为True")

        # Verify the update by counting records with is_known set.
        known_count = PersonInfo.select().where(PersonInfo.is_known).count()

        result = {"total": total_count, "deleted": deleted_count, "updated": updated_count, "known_count": known_count}

        logger.info("=== person_info更新完成 ===")
        logger.info(f"原始记录数: {result['total']}")
        logger.info(f"删除记录数: {result['deleted']}")
        logger.info(f"更新记录数: {result['updated']}")
        logger.info(f"已认识记录数: {result['known_count']}")

        return result

    except Exception as e:
        logger.error(f"更新person_info过程中发生错误: {e}")
        raise
async def check_and_run_migrations():
# 获取根目录
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
@@ -309,4 +317,3 @@ async def check_and_run_migrations():
# 创建done.mem文件
with open(done_file, "w", encoding="utf-8") as f:
f.write("done")

View File

@@ -282,7 +282,7 @@ class Person:
memory_category = parts[0].strip()
memory_text = parts[1].strip()
memory_weight = parts[2].strip()
_memory_weight = parts[2].strip()
# 检查分类是否匹配
if memory_category != category:

View File

@@ -1,11 +1,5 @@
import json
from json_repair import repair_json
from datetime import datetime
from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from .person_info import Person
from src.chat.utils.prompt_builder import Prompt
logger = get_logger("relation")
@@ -43,4 +37,3 @@ def init_prompt():
""",
"attitude_to_me_prompt",
)

View File

@@ -121,5 +121,5 @@ __all__ = [
"DatabaseChatInfo",
"TargetPersonInfo",
"ActionPlannerInfo",
"LLMGenerationDataModel"
"LLMGenerationDataModel",
]

View File

@@ -7,22 +7,24 @@ logger = get_logger("frequency_api")
def get_current_focus_value(chat_id: str) -> float:
    """Return the final focus value of the chat's frequency control (created on first use)."""
    control = frequency_control_manager.get_or_create_frequency_control(chat_id)
    return control.get_final_focus_value()
def get_current_talk_frequency(chat_id: str) -> float:
    """Return the final talk frequency of the chat's frequency control (created on first use)."""
    control = frequency_control_manager.get_or_create_frequency_control(chat_id)
    return control.get_final_talk_frequency()
def set_focus_value_adjust(chat_id: str, focus_value_adjust: float) -> None:
    """Store ``focus_value_adjust`` as the chat's external focus-value adjustment."""
    control = frequency_control_manager.get_or_create_frequency_control(chat_id)
    control.focus_value_external_adjust = focus_value_adjust
def set_talk_frequency_adjust(chat_id: str, talk_frequency_adjust: float) -> None:
    """Store ``talk_frequency_adjust`` as the chat's external talk-frequency adjustment."""
    control = frequency_control_manager.get_or_create_frequency_control(chat_id)
    control.talk_frequency_external_adjust = talk_frequency_adjust
def get_focus_value_adjust(chat_id: str) -> float:
    """Return the chat's external focus-value adjustment."""
    control = frequency_control_manager.get_or_create_frequency_control(chat_id)
    return control.focus_value_external_adjust
def get_talk_frequency_adjust(chat_id: str) -> float:
    """Return the chat's external talk-frequency adjustment."""
    control = frequency_control_manager.get_or_create_frequency_control(chat_id)
    return control.talk_frequency_external_adjust

View File

@@ -159,6 +159,7 @@ async def generate_reply(
logger.error(traceback.format_exc())
return False, None
async def rewrite_reply(
chat_stream: Optional[ChatStream] = None,
reply_data: Optional[Dict[str, Any]] = None,

View File

@@ -72,7 +72,9 @@ async def generate_with_model(
llm_request = LLMRequest(model_set=model_config, request_type=request_type)
response, (reasoning_content, model_name, _) = await llm_request.generate_response_async(prompt, temperature=temperature, max_tokens=max_tokens)
response, (reasoning_content, model_name, _) = await llm_request.generate_response_async(
prompt, temperature=temperature, max_tokens=max_tokens
)
return True, response, reasoning_content, model_name
except Exception as e:
@@ -80,6 +82,7 @@ async def generate_with_model(
logger.error(f"[LLMAPI] {error_msg}")
return False, error_msg, "", ""
async def generate_with_model_with_tools(
prompt: str,
model_config: TaskConfig,
@@ -109,10 +112,7 @@ async def generate_with_model_with_tools(
llm_request = LLMRequest(model_set=model_config, request_type=request_type)
response, (reasoning_content, model_name, tool_call) = await llm_request.generate_response_async(
prompt,
tools=tool_options,
temperature=temperature,
max_tokens=max_tokens
prompt, tools=tool_options, temperature=temperature, max_tokens=max_tokens
)
return True, response, reasoning_content, model_name, tool_call

View File

@@ -435,9 +435,7 @@ def build_readable_messages_to_str(
Returns:
格式化后的可读字符串
"""
return build_readable_messages(
messages, replace_bot_name, timestamp_mode, read_mark, truncate, show_actions
)
return build_readable_messages(messages, replace_bot_name, timestamp_mode, read_mark, truncate, show_actions)
async def build_readable_messages_with_details(
@@ -491,8 +489,6 @@ def filter_mai_messages(messages: List[DatabaseMessages]) -> List[DatabaseMessag
return [msg for msg in messages if msg.user_info.user_id != str(global_config.bot.qq_account)]
def translate_pid_to_description(pid: str) -> str:
image = Images.get_or_none(Images.image_id == pid)
description = ""
@@ -500,4 +496,4 @@ def translate_pid_to_description(pid: str) -> str:
description = image.description
else:
description = "[图片]"
return description
return description

View File

@@ -34,7 +34,7 @@ def get_plugin_path(plugin_name: str) -> str:
Returns:
str: 插件目录的绝对路径。
Raises:
ValueError: 如果插件不存在。
"""

View File

@@ -2,7 +2,7 @@ from pathlib import Path
from src.common.logger import get_logger
logger = get_logger("plugin_manager") # 复用plugin_manager名称
logger = get_logger("plugin_manager") # 复用plugin_manager名称
def register_plugin(cls):

View File

@@ -88,7 +88,7 @@ class GlobalAnnouncementManager:
return False
self._user_disabled_tools[chat_id].append(tool_name)
return True
def enable_specific_chat_tool(self, chat_id: str, tool_name: str) -> bool:
"""启用特定聊天的某个工具"""
if chat_id in self._user_disabled_tools:
@@ -111,7 +111,7 @@ class GlobalAnnouncementManager:
def get_disabled_chat_event_handlers(self, chat_id: str) -> List[str]:
    """Return all event handlers disabled for the given chat.

    Returns:
        A copy of the stored list, so callers cannot mutate the manager's
        state; empty when nothing is disabled for ``chat_id``.
    """
    disabled = self._user_disabled_event_handlers.get(chat_id, [])
    return disabled.copy()
def get_disabled_chat_tools(self, chat_id: str) -> List[str]:
    """Return all tools disabled for the given chat.

    Returns:
        A copy of the stored list, so callers cannot mutate the manager's
        state; empty when nothing is disabled for ``chat_id``.
    """
    disabled = self._user_disabled_tools.get(chat_id, [])
    return disabled.copy()

View File

@@ -224,7 +224,7 @@ class PluginManager:
list: 已注册的插件类名称列表。
"""
return list(self.plugin_classes.keys())
def get_plugin_path(self, plugin_name: str) -> Optional[str]:
"""
获取指定插件的路径。
@@ -401,9 +401,7 @@ class PluginManager:
command_components = [
c for c in plugin_info.components if c.component_type == ComponentType.COMMAND
]
tool_components = [
c for c in plugin_info.components if c.component_type == ComponentType.TOOL
]
tool_components = [c for c in plugin_info.components if c.component_type == ComponentType.TOOL]
event_handler_components = [
c for c in plugin_info.components if c.component_type == ComponentType.EVENT_HANDLER
]

View File

@@ -149,10 +149,10 @@ class ToolExecutor:
if not tool_calls:
logger.debug(f"{self.log_prefix}无需执行工具")
return [], []
# 提取tool_calls中的函数名称
func_names = [call.func_name for call in tool_calls if call.func_name]
logger.info(f"{self.log_prefix}开始执行工具调用: {func_names}")
# 执行每个工具调用
@@ -195,7 +195,9 @@ class ToolExecutor:
return tool_results, used_tools
async def execute_tool_call(self, tool_call: ToolCall, tool_instance: Optional[BaseTool] = None) -> Optional[Dict[str, Any]]:
async def execute_tool_call(
self, tool_call: ToolCall, tool_instance: Optional[BaseTool] = None
) -> Optional[Dict[str, Any]]:
# sourcery skip: use-assigned-variable
"""执行单个工具调用

View File

@@ -63,5 +63,4 @@ class CoreActionsPlugin(BasePlugin):
if self.get_config("components.enable_emoji", True):
components.append((EmojiAction.get_action_info(), EmojiAction))
return components

View File

@@ -74,7 +74,9 @@ class BuildMemoryAction(BaseAction):
# 动作基本信息
action_name = "build_memory"
action_description = "了解对于某个概念或者某件事的记忆,并存储下来,在之后的聊天中,你可以根据这条记忆来获取相关信息"
action_description = (
"了解对于某个概念或者某件事的记忆,并存储下来,在之后的聊天中,你可以根据这条记忆来获取相关信息"
)
# 动作参数定义
action_parameters = {
@@ -103,31 +105,34 @@ class BuildMemoryAction(BaseAction):
concept_name = self.action_data.get("concept_name", "")
# 2. 获取目标用户信息
# 对 concept_name 进行jieba分词
concept_name_tokens = cut_key_words(concept_name)
# logger.info(f"{self.log_prefix} 对 concept_name 进行分词结果: {concept_name_tokens}")
filtered_concept_name_tokens = [
token for token in concept_name_tokens if all(keyword not in token for keyword in global_config.memory.memory_ban_words)
token
for token in concept_name_tokens
if all(keyword not in token for keyword in global_config.memory.memory_ban_words)
]
if not filtered_concept_name_tokens:
logger.warning(f"{self.log_prefix} 过滤后的概念名称列表为空,跳过添加记忆")
return False, "过滤后的概念名称列表为空,跳过添加记忆"
similar_topics_dict = hippocampus_manager.get_hippocampus().parahippocampal_gyrus.get_similar_topics_from_keywords(filtered_concept_name_tokens)
await hippocampus_manager.get_hippocampus().parahippocampal_gyrus.add_memory_with_similar(concept_description, similar_topics_dict)
similar_topics_dict = (
hippocampus_manager.get_hippocampus().parahippocampal_gyrus.get_similar_topics_from_keywords(
filtered_concept_name_tokens
)
)
await hippocampus_manager.get_hippocampus().parahippocampal_gyrus.add_memory_with_similar(
concept_description, similar_topics_dict
)
return True, f"成功添加记忆: {concept_name}"
except Exception as e:
logger.error(f"{self.log_prefix} 构建记忆时出错: {e}")
return False, f"构建记忆时出错: {e}"
# 还缺一个关系的太多遗忘和对应的提取

View File

@@ -425,7 +425,7 @@ class ManagementCommand(BaseCommand):
await self._send_message(f"本地禁用组件成功: {component_name}")
else:
await self._send_message(f"本地禁用组件失败: {component_name}")
async def _send_message(self, message: str):
await send_api.text_to_stream(message, self.stream_id, typing=False, storage_message=False)

View File

@@ -107,7 +107,7 @@ class BuildRelationAction(BaseAction):
if not person.is_known:
logger.warning(f"{self.log_prefix} 用户 {person_name} 不存在,跳过添加记忆")
return False, f"用户 {person_name} 不存在,跳过添加记忆"
person.last_know = time.time()
person.know_times += 1
person.sync_to_database()
@@ -178,7 +178,9 @@ class BuildRelationAction(BaseAction):
chat_model_config = models.get("utils")
success, update_memory, _, _ = await llm_api.generate_with_model(
prompt, model_config=chat_model_config, request_type="relation.category.update" # type: ignore
prompt,
model_config=chat_model_config,
request_type="relation.category.update", # type: ignore
)
update_memory_data = json.loads(repair_json(update_memory))
@@ -190,7 +192,7 @@ class BuildRelationAction(BaseAction):
# 新记忆
person.memory_points.append(f"{category}:{new_memory}:1.0")
person.sync_to_database()
logger.info(f"{self.log_prefix}{person.person_name}新增记忆点: {new_memory}")
return True, f"{person.person_name}新增记忆点: {new_memory}"
@@ -207,14 +209,15 @@ class BuildRelationAction(BaseAction):
person.memory_points.append(f"{category}:{integrate_memory}:{memory_weight + 1.0}")
person.sync_to_database()
logger.info(f"{self.log_prefix} 更新{person.person_name}的记忆点: {memory_content} -> {integrate_memory}")
logger.info(
f"{self.log_prefix} 更新{person.person_name}的记忆点: {memory_content} -> {integrate_memory}"
)
return True, f"更新{person.person_name}的记忆点: {memory_content} -> {integrate_memory}"
else:
logger.warning(f"{self.log_prefix} 删除记忆点失败: {memory_content}")
return False, f"删除{person.person_name}的记忆点失败: {memory_content}"
return True, "关系动作执行成功"