Merge branch 'dev' of github.com:MaiM-with-u/MaiBot into dev

This commit is contained in:
UnCLAS-Prommer
2025-09-09 22:36:09 +08:00
66 changed files with 1085 additions and 1038 deletions

3
bot.py
View File

@@ -62,9 +62,10 @@ def easter_egg():
async def graceful_shutdown(): # sourcery skip: use-named-expression async def graceful_shutdown(): # sourcery skip: use-named-expression
try: try:
logger.info("正在优雅关闭麦麦...") logger.info("正在优雅关闭麦麦...")
from src.plugin_system.core.events_manager import events_manager from src.plugin_system.core.events_manager import events_manager
from src.plugin_system.base.component_types import EventType from src.plugin_system.base.component_types import EventType
# 触发 ON_STOP 事件 # 触发 ON_STOP 事件
await events_manager.handle_mai_events(event_type=EventType.ON_STOP) await events_manager.handle_mai_events(event_type=EventType.ON_STOP)

View File

@@ -5,12 +5,11 @@ from typing import Dict, List
# Add project root to Python path # Add project root to Python path
from src.common.database.database_model import Expression, ChatStreams from src.common.database.database_model import Expression, ChatStreams
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, project_root) sys.path.insert(0, project_root)
def get_chat_name(chat_id: str) -> str: def get_chat_name(chat_id: str) -> str:
"""Get chat name from chat_id by querying ChatStreams table directly""" """Get chat name from chat_id by querying ChatStreams table directly"""
try: try:
@@ -18,7 +17,7 @@ def get_chat_name(chat_id: str) -> str:
chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id) chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id)
if chat_stream is None: if chat_stream is None:
return f"未知聊天 ({chat_id})" return f"未知聊天 ({chat_id})"
# 如果有群组信息,显示群组名称 # 如果有群组信息,显示群组名称
if chat_stream.group_name: if chat_stream.group_name:
return f"{chat_stream.group_name} ({chat_id})" return f"{chat_stream.group_name} ({chat_id})"
@@ -35,117 +34,106 @@ def calculate_time_distribution(expressions) -> Dict[str, int]:
"""Calculate distribution of last active time in days""" """Calculate distribution of last active time in days"""
now = time.time() now = time.time()
distribution = { distribution = {
'0-1天': 0, "0-1天": 0,
'1-3天': 0, "1-3天": 0,
'3-7天': 0, "3-7天": 0,
'7-14天': 0, "7-14天": 0,
'14-30天': 0, "14-30天": 0,
'30-60天': 0, "30-60天": 0,
'60-90天': 0, "60-90天": 0,
'90+天': 0 "90+天": 0,
} }
for expr in expressions: for expr in expressions:
diff_days = (now - expr.last_active_time) / (24*3600) diff_days = (now - expr.last_active_time) / (24 * 3600)
if diff_days < 1: if diff_days < 1:
distribution['0-1天'] += 1 distribution["0-1天"] += 1
elif diff_days < 3: elif diff_days < 3:
distribution['1-3天'] += 1 distribution["1-3天"] += 1
elif diff_days < 7: elif diff_days < 7:
distribution['3-7天'] += 1 distribution["3-7天"] += 1
elif diff_days < 14: elif diff_days < 14:
distribution['7-14天'] += 1 distribution["7-14天"] += 1
elif diff_days < 30: elif diff_days < 30:
distribution['14-30天'] += 1 distribution["14-30天"] += 1
elif diff_days < 60: elif diff_days < 60:
distribution['30-60天'] += 1 distribution["30-60天"] += 1
elif diff_days < 90: elif diff_days < 90:
distribution['60-90天'] += 1 distribution["60-90天"] += 1
else: else:
distribution['90+天'] += 1 distribution["90+天"] += 1
return distribution return distribution
def calculate_count_distribution(expressions) -> Dict[str, int]: def calculate_count_distribution(expressions) -> Dict[str, int]:
"""Calculate distribution of count values""" """Calculate distribution of count values"""
distribution = { distribution = {"0-1": 0, "1-2": 0, "2-3": 0, "3-4": 0, "4-5": 0, "5-10": 0, "10+": 0}
'0-1': 0,
'1-2': 0,
'2-3': 0,
'3-4': 0,
'4-5': 0,
'5-10': 0,
'10+': 0
}
for expr in expressions: for expr in expressions:
cnt = expr.count cnt = expr.count
if cnt < 1: if cnt < 1:
distribution['0-1'] += 1 distribution["0-1"] += 1
elif cnt < 2: elif cnt < 2:
distribution['1-2'] += 1 distribution["1-2"] += 1
elif cnt < 3: elif cnt < 3:
distribution['2-3'] += 1 distribution["2-3"] += 1
elif cnt < 4: elif cnt < 4:
distribution['3-4'] += 1 distribution["3-4"] += 1
elif cnt < 5: elif cnt < 5:
distribution['4-5'] += 1 distribution["4-5"] += 1
elif cnt < 10: elif cnt < 10:
distribution['5-10'] += 1 distribution["5-10"] += 1
else: else:
distribution['10+'] += 1 distribution["10+"] += 1
return distribution return distribution
def get_top_expressions_by_chat(chat_id: str, top_n: int = 5) -> List[Expression]: def get_top_expressions_by_chat(chat_id: str, top_n: int = 5) -> List[Expression]:
"""Get top N most used expressions for a specific chat_id""" """Get top N most used expressions for a specific chat_id"""
return (Expression.select() return Expression.select().where(Expression.chat_id == chat_id).order_by(Expression.count.desc()).limit(top_n)
.where(Expression.chat_id == chat_id)
.order_by(Expression.count.desc())
.limit(top_n))
def show_overall_statistics(expressions, total: int) -> None: def show_overall_statistics(expressions, total: int) -> None:
"""Show overall statistics""" """Show overall statistics"""
time_dist = calculate_time_distribution(expressions) time_dist = calculate_time_distribution(expressions)
count_dist = calculate_count_distribution(expressions) count_dist = calculate_count_distribution(expressions)
print("\n=== 总体统计 ===") print("\n=== 总体统计 ===")
print(f"总表达式数量: {total}") print(f"总表达式数量: {total}")
print("\n上次激活时间分布:") print("\n上次激活时间分布:")
for period, count in time_dist.items(): for period, count in time_dist.items():
print(f"{period}: {count} ({count/total*100:.2f}%)") print(f"{period}: {count} ({count / total * 100:.2f}%)")
print("\ncount分布:") print("\ncount分布:")
for range_, count in count_dist.items(): for range_, count in count_dist.items():
print(f"{range_}: {count} ({count/total*100:.2f}%)") print(f"{range_}: {count} ({count / total * 100:.2f}%)")
def show_chat_statistics(chat_id: str, chat_name: str) -> None: def show_chat_statistics(chat_id: str, chat_name: str) -> None:
"""Show statistics for a specific chat""" """Show statistics for a specific chat"""
chat_exprs = list(Expression.select().where(Expression.chat_id == chat_id)) chat_exprs = list(Expression.select().where(Expression.chat_id == chat_id))
chat_total = len(chat_exprs) chat_total = len(chat_exprs)
print(f"\n=== {chat_name} ===") print(f"\n=== {chat_name} ===")
print(f"表达式数量: {chat_total}") print(f"表达式数量: {chat_total}")
if chat_total == 0: if chat_total == 0:
print("该聊天没有表达式数据") print("该聊天没有表达式数据")
return return
# Time distribution for this chat # Time distribution for this chat
time_dist = calculate_time_distribution(chat_exprs) time_dist = calculate_time_distribution(chat_exprs)
print("\n上次激活时间分布:") print("\n上次激活时间分布:")
for period, count in time_dist.items(): for period, count in time_dist.items():
if count > 0: if count > 0:
print(f"{period}: {count} ({count/chat_total*100:.2f}%)") print(f"{period}: {count} ({count / chat_total * 100:.2f}%)")
# Count distribution for this chat # Count distribution for this chat
count_dist = calculate_count_distribution(chat_exprs) count_dist = calculate_count_distribution(chat_exprs)
print("\ncount分布:") print("\ncount分布:")
for range_, count in count_dist.items(): for range_, count in count_dist.items():
if count > 0: if count > 0:
print(f"{range_}: {count} ({count/chat_total*100:.2f}%)") print(f"{range_}: {count} ({count / chat_total * 100:.2f}%)")
# Top expressions # Top expressions
print("\nTop 10使用最多的表达式:") print("\nTop 10使用最多的表达式:")
top_exprs = get_top_expressions_by_chat(chat_id, 10) top_exprs = get_top_expressions_by_chat(chat_id, 10)
@@ -163,32 +151,32 @@ def interactive_menu() -> None:
if not expressions: if not expressions:
print("数据库中没有找到表达式") print("数据库中没有找到表达式")
return return
total = len(expressions) total = len(expressions)
# Get unique chat_ids and their names # Get unique chat_ids and their names
chat_ids = list(set(expr.chat_id for expr in expressions)) chat_ids = list(set(expr.chat_id for expr in expressions))
chat_info = [(chat_id, get_chat_name(chat_id)) for chat_id in chat_ids] chat_info = [(chat_id, get_chat_name(chat_id)) for chat_id in chat_ids]
chat_info.sort(key=lambda x: x[1]) # Sort by chat name chat_info.sort(key=lambda x: x[1]) # Sort by chat name
while True: while True:
print("\n" + "="*50) print("\n" + "=" * 50)
print("表达式统计分析") print("表达式统计分析")
print("="*50) print("=" * 50)
print("0. 显示总体统计") print("0. 显示总体统计")
for i, (chat_id, chat_name) in enumerate(chat_info, 1): for i, (chat_id, chat_name) in enumerate(chat_info, 1):
chat_count = sum(1 for expr in expressions if expr.chat_id == chat_id) chat_count = sum(1 for expr in expressions if expr.chat_id == chat_id)
print(f"{i}. {chat_name} ({chat_count}个表达式)") print(f"{i}. {chat_name} ({chat_count}个表达式)")
print("q. 退出") print("q. 退出")
choice = input("\n请选择要查看的统计 (输入序号): ").strip() choice = input("\n请选择要查看的统计 (输入序号): ").strip()
if choice.lower() == 'q': if choice.lower() == "q":
print("再见!") print("再见!")
break break
try: try:
choice_num = int(choice) choice_num = int(choice)
if choice_num == 0: if choice_num == 0:
@@ -200,9 +188,9 @@ def interactive_menu() -> None:
print("无效的选择,请重新输入") print("无效的选择,请重新输入")
except ValueError: except ValueError:
print("请输入有效的数字") print("请输入有效的数字")
input("\n按回车键继续...") input("\n按回车键继续...")
if __name__ == "__main__": if __name__ == "__main__":
interactive_menu() interactive_menu()

View File

@@ -23,6 +23,7 @@ OPENIE_DIR = os.path.join(ROOT_PATH, "data", "openie")
logger = get_logger("OpenIE导入") logger = get_logger("OpenIE导入")
def ensure_openie_dir(): def ensure_openie_dir():
"""确保OpenIE数据目录存在""" """确保OpenIE数据目录存在"""
if not os.path.exists(OPENIE_DIR): if not os.path.exists(OPENIE_DIR):
@@ -253,7 +254,7 @@ def main():
# 没有运行的事件循环,创建新的 # 没有运行的事件循环,创建新的
loop = asyncio.new_event_loop() loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop) asyncio.set_event_loop(loop)
try: try:
# 在新的事件循环中运行异步主函数 # 在新的事件循环中运行异步主函数
loop.run_until_complete(main_async()) loop.run_until_complete(main_async())

View File

@@ -12,6 +12,7 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from rich.progress import Progress # 替换为 rich 进度条 from rich.progress import Progress # 替换为 rich 进度条
from src.common.logger import get_logger from src.common.logger import get_logger
# from src.chat.knowledge.lpmmconfig import global_config # from src.chat.knowledge.lpmmconfig import global_config
from src.chat.knowledge.ie_process import info_extract_from_str from src.chat.knowledge.ie_process import info_extract_from_str
from src.chat.knowledge.open_ie import OpenIE from src.chat.knowledge.open_ie import OpenIE
@@ -36,6 +37,7 @@ TEMP_DIR = os.path.join(ROOT_PATH, "temp")
# IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, "data", "imported_lpmm_data") # IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, "data", "imported_lpmm_data")
OPENIE_OUTPUT_DIR = os.path.join(ROOT_PATH, "data", "openie") OPENIE_OUTPUT_DIR = os.path.join(ROOT_PATH, "data", "openie")
def ensure_dirs(): def ensure_dirs():
"""确保临时目录和输出目录存在""" """确保临时目录和输出目录存在"""
if not os.path.exists(TEMP_DIR): if not os.path.exists(TEMP_DIR):
@@ -48,6 +50,7 @@ def ensure_dirs():
os.makedirs(RAW_DATA_PATH) os.makedirs(RAW_DATA_PATH)
logger.info(f"已创建原始数据目录: {RAW_DATA_PATH}") logger.info(f"已创建原始数据目录: {RAW_DATA_PATH}")
# 创建一个线程安全的锁,用于保护文件操作和共享数据 # 创建一个线程安全的锁,用于保护文件操作和共享数据
file_lock = Lock() file_lock = Lock()
open_ie_doc_lock = Lock() open_ie_doc_lock = Lock()
@@ -56,13 +59,11 @@ open_ie_doc_lock = Lock()
shutdown_event = Event() shutdown_event = Event()
lpmm_entity_extract_llm = LLMRequest( lpmm_entity_extract_llm = LLMRequest(
model_set=model_config.model_task_config.lpmm_entity_extract, model_set=model_config.model_task_config.lpmm_entity_extract, request_type="lpmm.entity_extract"
request_type="lpmm.entity_extract"
)
lpmm_rdf_build_llm = LLMRequest(
model_set=model_config.model_task_config.lpmm_rdf_build,
request_type="lpmm.rdf_build"
) )
lpmm_rdf_build_llm = LLMRequest(model_set=model_config.model_task_config.lpmm_rdf_build, request_type="lpmm.rdf_build")
def process_single_text(pg_hash, raw_data): def process_single_text(pg_hash, raw_data):
"""处理单个文本的函数,用于线程池""" """处理单个文本的函数,用于线程池"""
temp_file_path = f"{TEMP_DIR}/{pg_hash}.json" temp_file_path = f"{TEMP_DIR}/{pg_hash}.json"

View File

@@ -3,12 +3,11 @@ import sys
import os import os
from typing import Dict, List, Tuple, Optional from typing import Dict, List, Tuple, Optional
from datetime import datetime from datetime import datetime
# Add project root to Python path # Add project root to Python path
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, project_root) sys.path.insert(0, project_root)
from src.common.database.database_model import Messages, ChatStreams #noqa from src.common.database.database_model import Messages, ChatStreams # noqa
def get_chat_name(chat_id: str) -> str: def get_chat_name(chat_id: str) -> str:
@@ -17,7 +16,7 @@ def get_chat_name(chat_id: str) -> str:
chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id) chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id)
if chat_stream is None: if chat_stream is None:
return f"未知聊天 ({chat_id})" return f"未知聊天 ({chat_id})"
if chat_stream.group_name: if chat_stream.group_name:
return f"{chat_stream.group_name} ({chat_id})" return f"{chat_stream.group_name} ({chat_id})"
elif chat_stream.user_nickname: elif chat_stream.user_nickname:
@@ -39,66 +38,62 @@ def format_timestamp(timestamp: float) -> str:
def calculate_interest_value_distribution(messages) -> Dict[str, int]: def calculate_interest_value_distribution(messages) -> Dict[str, int]:
"""Calculate distribution of interest_value""" """Calculate distribution of interest_value"""
distribution = { distribution = {
'0.000-0.010': 0, "0.000-0.010": 0,
'0.010-0.050': 0, "0.010-0.050": 0,
'0.050-0.100': 0, "0.050-0.100": 0,
'0.100-0.500': 0, "0.100-0.500": 0,
'0.500-1.000': 0, "0.500-1.000": 0,
'1.000-2.000': 0, "1.000-2.000": 0,
'2.000-5.000': 0, "2.000-5.000": 0,
'5.000-10.000': 0, "5.000-10.000": 0,
'10.000+': 0 "10.000+": 0,
} }
for msg in messages: for msg in messages:
if msg.interest_value is None or msg.interest_value == 0.0: if msg.interest_value is None or msg.interest_value == 0.0:
continue continue
value = float(msg.interest_value) value = float(msg.interest_value)
if value < 0.010: if value < 0.010:
distribution['0.000-0.010'] += 1 distribution["0.000-0.010"] += 1
elif value < 0.050: elif value < 0.050:
distribution['0.010-0.050'] += 1 distribution["0.010-0.050"] += 1
elif value < 0.100: elif value < 0.100:
distribution['0.050-0.100'] += 1 distribution["0.050-0.100"] += 1
elif value < 0.500: elif value < 0.500:
distribution['0.100-0.500'] += 1 distribution["0.100-0.500"] += 1
elif value < 1.000: elif value < 1.000:
distribution['0.500-1.000'] += 1 distribution["0.500-1.000"] += 1
elif value < 2.000: elif value < 2.000:
distribution['1.000-2.000'] += 1 distribution["1.000-2.000"] += 1
elif value < 5.000: elif value < 5.000:
distribution['2.000-5.000'] += 1 distribution["2.000-5.000"] += 1
elif value < 10.000: elif value < 10.000:
distribution['5.000-10.000'] += 1 distribution["5.000-10.000"] += 1
else: else:
distribution['10.000+'] += 1 distribution["10.000+"] += 1
return distribution return distribution
def get_interest_value_stats(messages) -> Dict[str, float]: def get_interest_value_stats(messages) -> Dict[str, float]:
"""Calculate basic statistics for interest_value""" """Calculate basic statistics for interest_value"""
values = [float(msg.interest_value) for msg in messages if msg.interest_value is not None and msg.interest_value != 0.0] values = [
float(msg.interest_value) for msg in messages if msg.interest_value is not None and msg.interest_value != 0.0
]
if not values: if not values:
return { return {"count": 0, "min": 0, "max": 0, "avg": 0, "median": 0}
'count': 0,
'min': 0,
'max': 0,
'avg': 0,
'median': 0
}
values.sort() values.sort()
count = len(values) count = len(values)
return { return {
'count': count, "count": count,
'min': min(values), "min": min(values),
'max': max(values), "max": max(values),
'avg': sum(values) / count, "avg": sum(values) / count,
'median': values[count // 2] if count % 2 == 1 else (values[count // 2 - 1] + values[count // 2]) / 2 "median": values[count // 2] if count % 2 == 1 else (values[count // 2 - 1] + values[count // 2]) / 2,
} }
@@ -109,20 +104,24 @@ def get_available_chats() -> List[Tuple[str, str, int]]:
chat_counts = {} chat_counts = {}
for msg in Messages.select(Messages.chat_id).distinct(): for msg in Messages.select(Messages.chat_id).distinct():
chat_id = msg.chat_id chat_id = msg.chat_id
count = Messages.select().where( count = (
(Messages.chat_id == chat_id) & Messages.select()
(Messages.interest_value.is_null(False)) & .where(
(Messages.interest_value != 0.0) (Messages.chat_id == chat_id)
).count() & (Messages.interest_value.is_null(False))
& (Messages.interest_value != 0.0)
)
.count()
)
if count > 0: if count > 0:
chat_counts[chat_id] = count chat_counts[chat_id] = count
# 获取聊天名称 # 获取聊天名称
result = [] result = []
for chat_id, count in chat_counts.items(): for chat_id, count in chat_counts.items():
chat_name = get_chat_name(chat_id) chat_name = get_chat_name(chat_id)
result.append((chat_id, chat_name, count)) result.append((chat_id, chat_name, count))
# 按消息数量排序 # 按消息数量排序
result.sort(key=lambda x: x[2], reverse=True) result.sort(key=lambda x: x[2], reverse=True)
return result return result
@@ -135,30 +134,30 @@ def get_time_range_input() -> Tuple[Optional[float], Optional[float]]:
"""Get time range input from user""" """Get time range input from user"""
print("\n时间范围选择:") print("\n时间范围选择:")
print("1. 最近1天") print("1. 最近1天")
print("2. 最近3天") print("2. 最近3天")
print("3. 最近7天") print("3. 最近7天")
print("4. 最近30天") print("4. 最近30天")
print("5. 自定义时间范围") print("5. 自定义时间范围")
print("6. 不限制时间") print("6. 不限制时间")
choice = input("请选择时间范围 (1-6): ").strip() choice = input("请选择时间范围 (1-6): ").strip()
now = time.time() now = time.time()
if choice == "1": if choice == "1":
return now - 24*3600, now return now - 24 * 3600, now
elif choice == "2": elif choice == "2":
return now - 3*24*3600, now return now - 3 * 24 * 3600, now
elif choice == "3": elif choice == "3":
return now - 7*24*3600, now return now - 7 * 24 * 3600, now
elif choice == "4": elif choice == "4":
return now - 30*24*3600, now return now - 30 * 24 * 3600, now
elif choice == "5": elif choice == "5":
print("请输入开始时间 (格式: YYYY-MM-DD HH:MM:SS):") print("请输入开始时间 (格式: YYYY-MM-DD HH:MM:SS):")
start_str = input().strip() start_str = input().strip()
print("请输入结束时间 (格式: YYYY-MM-DD HH:MM:SS):") print("请输入结束时间 (格式: YYYY-MM-DD HH:MM:SS):")
end_str = input().strip() end_str = input().strip()
try: try:
start_time = datetime.strptime(start_str, "%Y-%m-%d %H:%M:%S").timestamp() start_time = datetime.strptime(start_str, "%Y-%m-%d %H:%M:%S").timestamp()
end_time = datetime.strptime(end_str, "%Y-%m-%d %H:%M:%S").timestamp() end_time = datetime.strptime(end_str, "%Y-%m-%d %H:%M:%S").timestamp()
@@ -170,41 +169,40 @@ def get_time_range_input() -> Tuple[Optional[float], Optional[float]]:
return None, None return None, None
def analyze_interest_values(chat_id: Optional[str] = None, start_time: Optional[float] = None, end_time: Optional[float] = None) -> None: def analyze_interest_values(
chat_id: Optional[str] = None, start_time: Optional[float] = None, end_time: Optional[float] = None
) -> None:
"""Analyze interest values with optional filters""" """Analyze interest values with optional filters"""
# 构建查询条件 # 构建查询条件
query = Messages.select().where( query = Messages.select().where((Messages.interest_value.is_null(False)) & (Messages.interest_value != 0.0))
(Messages.interest_value.is_null(False)) &
(Messages.interest_value != 0.0)
)
if chat_id: if chat_id:
query = query.where(Messages.chat_id == chat_id) query = query.where(Messages.chat_id == chat_id)
if start_time: if start_time:
query = query.where(Messages.time >= start_time) query = query.where(Messages.time >= start_time)
if end_time: if end_time:
query = query.where(Messages.time <= end_time) query = query.where(Messages.time <= end_time)
messages = list(query) messages = list(query)
if not messages: if not messages:
print("没有找到符合条件的消息") print("没有找到符合条件的消息")
return return
# 计算统计信息 # 计算统计信息
distribution = calculate_interest_value_distribution(messages) distribution = calculate_interest_value_distribution(messages)
stats = get_interest_value_stats(messages) stats = get_interest_value_stats(messages)
# 显示结果 # 显示结果
print("\n=== Interest Value 分析结果 ===") print("\n=== Interest Value 分析结果 ===")
if chat_id: if chat_id:
print(f"聊天: {get_chat_name(chat_id)}") print(f"聊天: {get_chat_name(chat_id)}")
else: else:
print("聊天: 全部聊天") print("聊天: 全部聊天")
if start_time and end_time: if start_time and end_time:
print(f"时间范围: {format_timestamp(start_time)}{format_timestamp(end_time)}") print(f"时间范围: {format_timestamp(start_time)}{format_timestamp(end_time)}")
elif start_time: elif start_time:
@@ -213,16 +211,16 @@ def analyze_interest_values(chat_id: Optional[str] = None, start_time: Optional[
print(f"时间范围: {format_timestamp(end_time)} 之前") print(f"时间范围: {format_timestamp(end_time)} 之前")
else: else:
print("时间范围: 不限制") print("时间范围: 不限制")
print("\n基本统计:") print("\n基本统计:")
print(f"有效消息数量: {stats['count']} (排除null和0值)") print(f"有效消息数量: {stats['count']} (排除null和0值)")
print(f"最小值: {stats['min']:.3f}") print(f"最小值: {stats['min']:.3f}")
print(f"最大值: {stats['max']:.3f}") print(f"最大值: {stats['max']:.3f}")
print(f"平均值: {stats['avg']:.3f}") print(f"平均值: {stats['avg']:.3f}")
print(f"中位数: {stats['median']:.3f}") print(f"中位数: {stats['median']:.3f}")
print("\nInterest Value 分布:") print("\nInterest Value 分布:")
total = stats['count'] total = stats["count"]
for range_name, count in distribution.items(): for range_name, count in distribution.items():
if count > 0: if count > 0:
percentage = count / total * 100 percentage = count / total * 100
@@ -231,34 +229,34 @@ def analyze_interest_values(chat_id: Optional[str] = None, start_time: Optional[
def interactive_menu() -> None: def interactive_menu() -> None:
"""Interactive menu for interest value analysis""" """Interactive menu for interest value analysis"""
while True: while True:
print("\n" + "="*50) print("\n" + "=" * 50)
print("Interest Value 分析工具") print("Interest Value 分析工具")
print("="*50) print("=" * 50)
print("1. 分析全部聊天") print("1. 分析全部聊天")
print("2. 选择特定聊天分析") print("2. 选择特定聊天分析")
print("q. 退出") print("q. 退出")
choice = input("\n请选择分析模式 (1-2, q): ").strip() choice = input("\n请选择分析模式 (1-2, q): ").strip()
if choice.lower() == 'q': if choice.lower() == "q":
print("再见!") print("再见!")
break break
chat_id = None chat_id = None
if choice == "2": if choice == "2":
# 显示可用的聊天列表 # 显示可用的聊天列表
chats = get_available_chats() chats = get_available_chats()
if not chats: if not chats:
print("没有找到有interest_value数据的聊天") print("没有找到有interest_value数据的聊天")
continue continue
print(f"\n可用的聊天 (共{len(chats)}个):") print(f"\n可用的聊天 (共{len(chats)}个):")
for i, (_cid, name, count) in enumerate(chats, 1): for i, (_cid, name, count) in enumerate(chats, 1):
print(f"{i}. {name} ({count}条有效消息)") print(f"{i}. {name} ({count}条有效消息)")
try: try:
chat_choice = int(input(f"\n请选择聊天 (1-{len(chats)}): ").strip()) chat_choice = int(input(f"\n请选择聊天 (1-{len(chats)}): ").strip())
if 1 <= chat_choice <= len(chats): if 1 <= chat_choice <= len(chats):
@@ -269,19 +267,19 @@ def interactive_menu() -> None:
except ValueError: except ValueError:
print("请输入有效数字") print("请输入有效数字")
continue continue
elif choice != "1": elif choice != "1":
print("无效选择") print("无效选择")
continue continue
# 获取时间范围 # 获取时间范围
start_time, end_time = get_time_range_input() start_time, end_time = get_time_range_input()
# 执行分析 # 执行分析
analyze_interest_values(chat_id, start_time, end_time) analyze_interest_values(chat_id, start_time, end_time)
input("\n按回车键继续...") input("\n按回车键继续...")
if __name__ == "__main__": if __name__ == "__main__":
interactive_menu() interactive_menu()

View File

@@ -828,7 +828,7 @@ class LogViewer:
parts, tags = self.formatter.format_log_entry(log_entry) parts, tags = self.formatter.format_log_entry(log_entry)
line_text = " ".join(parts) line_text = " ".join(parts)
log_lines.append(line_text) log_lines.append(line_text)
with open(filename, "w", encoding="utf-8") as f: with open(filename, "w", encoding="utf-8") as f:
f.write("\n".join(log_lines)) f.write("\n".join(log_lines))
messagebox.showinfo("导出成功", f"日志已导出到: {filename}") messagebox.showinfo("导出成功", f"日志已导出到: {filename}")
@@ -1188,15 +1188,16 @@ class LogViewer:
line_count += 1 line_count += 1
except json.JSONDecodeError: except json.JSONDecodeError:
continue continue
# 如果发现了新模块,在主线程中更新模块集合 # 如果发现了新模块,在主线程中更新模块集合
if new_modules: if new_modules:
def update_modules(): def update_modules():
self.modules.update(new_modules) self.modules.update(new_modules)
self.update_module_list() self.update_module_list()
self.root.after(0, update_modules) self.root.after(0, update_modules)
return new_entries return new_entries
def append_new_logs(self, new_entries): def append_new_logs(self, new_entries):
@@ -1424,4 +1425,3 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@@ -2,6 +2,7 @@ import os
from pathlib import Path from pathlib import Path
import sys # 新增系统模块导入 import sys # 新增系统模块导入
from src.chat.knowledge.utils.hash import get_sha256 from src.chat.knowledge.utils.hash import get_sha256
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from src.common.logger import get_logger from src.common.logger import get_logger
@@ -10,6 +11,7 @@ ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
RAW_DATA_PATH = os.path.join(ROOT_PATH, "data/lpmm_raw_data") RAW_DATA_PATH = os.path.join(ROOT_PATH, "data/lpmm_raw_data")
# IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, "data/imported_lpmm_data") # IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, "data/imported_lpmm_data")
def _process_text_file(file_path): def _process_text_file(file_path):
"""处理单个文本文件,返回段落列表""" """处理单个文本文件,返回段落列表"""
with open(file_path, "r", encoding="utf-8") as f: with open(file_path, "r", encoding="utf-8") as f:
@@ -44,6 +46,7 @@ def _process_multi_files() -> list:
all_paragraphs.extend(paragraphs) all_paragraphs.extend(paragraphs)
return all_paragraphs return all_paragraphs
def load_raw_data() -> tuple[list[str], list[str]]: def load_raw_data() -> tuple[list[str], list[str]]:
"""加载原始数据文件 """加载原始数据文件
@@ -72,4 +75,4 @@ def load_raw_data() -> tuple[list[str], list[str]]:
raw_data.append(item) raw_data.append(item)
logger.info(f"共读取到{len(raw_data)}条数据") logger.info(f"共读取到{len(raw_data)}条数据")
return sha256_list, raw_data return sha256_list, raw_data

View File

@@ -4,21 +4,22 @@ import os
import re import re
from typing import Dict, List, Tuple, Optional from typing import Dict, List, Tuple, Optional
from datetime import datetime from datetime import datetime
# Add project root to Python path # Add project root to Python path
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, project_root) sys.path.insert(0, project_root)
from src.common.database.database_model import Messages, ChatStreams #noqa from src.common.database.database_model import Messages, ChatStreams # noqa
def contains_emoji_or_image_tags(text: str) -> bool: def contains_emoji_or_image_tags(text: str) -> bool:
"""Check if text contains [表情包xxxxx] or [图片xxxxx] tags""" """Check if text contains [表情包xxxxx] or [图片xxxxx] tags"""
if not text: if not text:
return False return False
# 检查是否包含 [表情包] 或 [图片] 标记 # 检查是否包含 [表情包] 或 [图片] 标记
emoji_pattern = r'\[表情包[^\]]*\]' emoji_pattern = r"\[表情包[^\]]*\]"
image_pattern = r'\[图片[^\]]*\]' image_pattern = r"\[图片[^\]]*\]"
return bool(re.search(emoji_pattern, text) or re.search(image_pattern, text)) return bool(re.search(emoji_pattern, text) or re.search(image_pattern, text))
@@ -26,14 +27,14 @@ def clean_reply_text(text: str) -> str:
"""Remove reply references like [回复 xxxx...] from text""" """Remove reply references like [回复 xxxx...] from text"""
if not text: if not text:
return text return text
# 匹配 [回复 xxxx...] 格式的内容 # 匹配 [回复 xxxx...] 格式的内容
# 使用非贪婪匹配,匹配到第一个 ] 就停止 # 使用非贪婪匹配,匹配到第一个 ] 就停止
cleaned_text = re.sub(r'\[回复[^\]]*\]', '', text) cleaned_text = re.sub(r"\[回复[^\]]*\]", "", text)
# 去除多余的空白字符 # 去除多余的空白字符
cleaned_text = cleaned_text.strip() cleaned_text = cleaned_text.strip()
return cleaned_text return cleaned_text
@@ -43,7 +44,7 @@ def get_chat_name(chat_id: str) -> str:
chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id) chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id)
if chat_stream is None: if chat_stream is None:
return f"未知聊天 ({chat_id})" return f"未知聊天 ({chat_id})"
if chat_stream.group_name: if chat_stream.group_name:
return f"{chat_stream.group_name} ({chat_id})" return f"{chat_stream.group_name} ({chat_id})"
elif chat_stream.user_nickname: elif chat_stream.user_nickname:
@@ -65,63 +66,63 @@ def format_timestamp(timestamp: float) -> str:
def calculate_text_length_distribution(messages) -> Dict[str, int]: def calculate_text_length_distribution(messages) -> Dict[str, int]:
"""Calculate distribution of processed_plain_text length""" """Calculate distribution of processed_plain_text length"""
distribution = { distribution = {
'0': 0, # 空文本 "0": 0, # 空文本
'1-5': 0, # 极短文本 "1-5": 0, # 极短文本
'6-10': 0, # 很短文本 "6-10": 0, # 很短文本
'11-20': 0, # 短文本 "11-20": 0, # 短文本
'21-30': 0, # 较短文本 "21-30": 0, # 较短文本
'31-50': 0, # 中短文本 "31-50": 0, # 中短文本
'51-70': 0, # 中等文本 "51-70": 0, # 中等文本
'71-100': 0, # 较长文本 "71-100": 0, # 较长文本
'101-150': 0, # 长文本 "101-150": 0, # 长文本
'151-200': 0, # 很长文本 "151-200": 0, # 很长文本
'201-300': 0, # 超长文本 "201-300": 0, # 超长文本
'301-500': 0, # 极长文本 "301-500": 0, # 极长文本
'501-1000': 0, # 巨长文本 "501-1000": 0, # 巨长文本
'1000+': 0 # 超巨长文本 "1000+": 0, # 超巨长文本
} }
for msg in messages: for msg in messages:
if msg.processed_plain_text is None: if msg.processed_plain_text is None:
continue continue
# 排除包含表情包或图片标记的消息 # 排除包含表情包或图片标记的消息
if contains_emoji_or_image_tags(msg.processed_plain_text): if contains_emoji_or_image_tags(msg.processed_plain_text):
continue continue
# 清理文本中的回复引用 # 清理文本中的回复引用
cleaned_text = clean_reply_text(msg.processed_plain_text) cleaned_text = clean_reply_text(msg.processed_plain_text)
length = len(cleaned_text) length = len(cleaned_text)
if length == 0: if length == 0:
distribution['0'] += 1 distribution["0"] += 1
elif length <= 5: elif length <= 5:
distribution['1-5'] += 1 distribution["1-5"] += 1
elif length <= 10: elif length <= 10:
distribution['6-10'] += 1 distribution["6-10"] += 1
elif length <= 20: elif length <= 20:
distribution['11-20'] += 1 distribution["11-20"] += 1
elif length <= 30: elif length <= 30:
distribution['21-30'] += 1 distribution["21-30"] += 1
elif length <= 50: elif length <= 50:
distribution['31-50'] += 1 distribution["31-50"] += 1
elif length <= 70: elif length <= 70:
distribution['51-70'] += 1 distribution["51-70"] += 1
elif length <= 100: elif length <= 100:
distribution['71-100'] += 1 distribution["71-100"] += 1
elif length <= 150: elif length <= 150:
distribution['101-150'] += 1 distribution["101-150"] += 1
elif length <= 200: elif length <= 200:
distribution['151-200'] += 1 distribution["151-200"] += 1
elif length <= 300: elif length <= 300:
distribution['201-300'] += 1 distribution["201-300"] += 1
elif length <= 500: elif length <= 500:
distribution['301-500'] += 1 distribution["301-500"] += 1
elif length <= 1000: elif length <= 1000:
distribution['501-1000'] += 1 distribution["501-1000"] += 1
else: else:
distribution['1000+'] += 1 distribution["1000+"] += 1
return distribution return distribution
@@ -130,7 +131,7 @@ def get_text_length_stats(messages) -> Dict[str, float]:
lengths = [] lengths = []
null_count = 0 null_count = 0
excluded_count = 0 # 被排除的消息数量 excluded_count = 0 # 被排除的消息数量
for msg in messages: for msg in messages:
if msg.processed_plain_text is None: if msg.processed_plain_text is None:
null_count += 1 null_count += 1
@@ -141,29 +142,29 @@ def get_text_length_stats(messages) -> Dict[str, float]:
# 清理文本中的回复引用 # 清理文本中的回复引用
cleaned_text = clean_reply_text(msg.processed_plain_text) cleaned_text = clean_reply_text(msg.processed_plain_text)
lengths.append(len(cleaned_text)) lengths.append(len(cleaned_text))
if not lengths: if not lengths:
return { return {
'count': 0, "count": 0,
'null_count': null_count, "null_count": null_count,
'excluded_count': excluded_count, "excluded_count": excluded_count,
'min': 0, "min": 0,
'max': 0, "max": 0,
'avg': 0, "avg": 0,
'median': 0 "median": 0,
} }
lengths.sort() lengths.sort()
count = len(lengths) count = len(lengths)
return { return {
'count': count, "count": count,
'null_count': null_count, "null_count": null_count,
'excluded_count': excluded_count, "excluded_count": excluded_count,
'min': min(lengths), "min": min(lengths),
'max': max(lengths), "max": max(lengths),
'avg': sum(lengths) / count, "avg": sum(lengths) / count,
'median': lengths[count // 2] if count % 2 == 1 else (lengths[count // 2 - 1] + lengths[count // 2]) / 2 "median": lengths[count // 2] if count % 2 == 1 else (lengths[count // 2 - 1] + lengths[count // 2]) / 2,
} }
@@ -174,21 +175,25 @@ def get_available_chats() -> List[Tuple[str, str, int]]:
chat_counts = {} chat_counts = {}
for msg in Messages.select(Messages.chat_id).distinct(): for msg in Messages.select(Messages.chat_id).distinct():
chat_id = msg.chat_id chat_id = msg.chat_id
count = Messages.select().where( count = (
(Messages.chat_id == chat_id) & Messages.select()
(Messages.is_emoji != 1) & .where(
(Messages.is_picid != 1) & (Messages.chat_id == chat_id)
(Messages.is_command != 1) & (Messages.is_emoji != 1)
).count() & (Messages.is_picid != 1)
& (Messages.is_command != 1)
)
.count()
)
if count > 0: if count > 0:
chat_counts[chat_id] = count chat_counts[chat_id] = count
# 获取聊天名称 # 获取聊天名称
result = [] result = []
for chat_id, count in chat_counts.items(): for chat_id, count in chat_counts.items():
chat_name = get_chat_name(chat_id) chat_name = get_chat_name(chat_id)
result.append((chat_id, chat_name, count)) result.append((chat_id, chat_name, count))
# 按消息数量排序 # 按消息数量排序
result.sort(key=lambda x: x[2], reverse=True) result.sort(key=lambda x: x[2], reverse=True)
return result return result
@@ -201,30 +206,30 @@ def get_time_range_input() -> Tuple[Optional[float], Optional[float]]:
"""Get time range input from user""" """Get time range input from user"""
print("\n时间范围选择:") print("\n时间范围选择:")
print("1. 最近1天") print("1. 最近1天")
print("2. 最近3天") print("2. 最近3天")
print("3. 最近7天") print("3. 最近7天")
print("4. 最近30天") print("4. 最近30天")
print("5. 自定义时间范围") print("5. 自定义时间范围")
print("6. 不限制时间") print("6. 不限制时间")
choice = input("请选择时间范围 (1-6): ").strip() choice = input("请选择时间范围 (1-6): ").strip()
now = time.time() now = time.time()
if choice == "1": if choice == "1":
return now - 24*3600, now return now - 24 * 3600, now
elif choice == "2": elif choice == "2":
return now - 3*24*3600, now return now - 3 * 24 * 3600, now
elif choice == "3": elif choice == "3":
return now - 7*24*3600, now return now - 7 * 24 * 3600, now
elif choice == "4": elif choice == "4":
return now - 30*24*3600, now return now - 30 * 24 * 3600, now
elif choice == "5": elif choice == "5":
print("请输入开始时间 (格式: YYYY-MM-DD HH:MM:SS):") print("请输入开始时间 (格式: YYYY-MM-DD HH:MM:SS):")
start_str = input().strip() start_str = input().strip()
print("请输入结束时间 (格式: YYYY-MM-DD HH:MM:SS):") print("请输入结束时间 (格式: YYYY-MM-DD HH:MM:SS):")
end_str = input().strip() end_str = input().strip()
try: try:
start_time = datetime.strptime(start_str, "%Y-%m-%d %H:%M:%S").timestamp() start_time = datetime.strptime(start_str, "%Y-%m-%d %H:%M:%S").timestamp()
end_time = datetime.strptime(end_str, "%Y-%m-%d %H:%M:%S").timestamp() end_time = datetime.strptime(end_str, "%Y-%m-%d %H:%M:%S").timestamp()
@@ -239,13 +244,13 @@ def get_time_range_input() -> Tuple[Optional[float], Optional[float]]:
def get_top_longest_messages(messages, top_n: int = 10) -> List[Tuple[str, int, str, str]]: def get_top_longest_messages(messages, top_n: int = 10) -> List[Tuple[str, int, str, str]]:
"""Get top N longest messages""" """Get top N longest messages"""
message_lengths = [] message_lengths = []
for msg in messages: for msg in messages:
if msg.processed_plain_text is not None: if msg.processed_plain_text is not None:
# 排除包含表情包或图片标记的消息 # 排除包含表情包或图片标记的消息
if contains_emoji_or_image_tags(msg.processed_plain_text): if contains_emoji_or_image_tags(msg.processed_plain_text):
continue continue
# 清理文本中的回复引用 # 清理文本中的回复引用
cleaned_text = clean_reply_text(msg.processed_plain_text) cleaned_text = clean_reply_text(msg.processed_plain_text)
length = len(cleaned_text) length = len(cleaned_text)
@@ -254,42 +259,40 @@ def get_top_longest_messages(messages, top_n: int = 10) -> List[Tuple[str, int,
# 截取前100个字符作为预览 # 截取前100个字符作为预览
preview = cleaned_text[:100] + "..." if len(cleaned_text) > 100 else cleaned_text preview = cleaned_text[:100] + "..." if len(cleaned_text) > 100 else cleaned_text
message_lengths.append((chat_name, length, time_str, preview)) message_lengths.append((chat_name, length, time_str, preview))
# 按长度排序取前N个 # 按长度排序取前N个
message_lengths.sort(key=lambda x: x[1], reverse=True) message_lengths.sort(key=lambda x: x[1], reverse=True)
return message_lengths[:top_n] return message_lengths[:top_n]
def analyze_text_lengths(chat_id: Optional[str] = None, start_time: Optional[float] = None, end_time: Optional[float] = None) -> None: def analyze_text_lengths(
chat_id: Optional[str] = None, start_time: Optional[float] = None, end_time: Optional[float] = None
) -> None:
"""Analyze processed_plain_text lengths with optional filters""" """Analyze processed_plain_text lengths with optional filters"""
# 构建查询条件,排除特殊类型的消息 # 构建查询条件,排除特殊类型的消息
query = Messages.select().where( query = Messages.select().where((Messages.is_emoji != 1) & (Messages.is_picid != 1) & (Messages.is_command != 1))
(Messages.is_emoji != 1) &
(Messages.is_picid != 1) &
(Messages.is_command != 1)
)
if chat_id: if chat_id:
query = query.where(Messages.chat_id == chat_id) query = query.where(Messages.chat_id == chat_id)
if start_time: if start_time:
query = query.where(Messages.time >= start_time) query = query.where(Messages.time >= start_time)
if end_time: if end_time:
query = query.where(Messages.time <= end_time) query = query.where(Messages.time <= end_time)
messages = list(query) messages = list(query)
if not messages: if not messages:
print("没有找到符合条件的消息") print("没有找到符合条件的消息")
return return
# 计算统计信息 # 计算统计信息
distribution = calculate_text_length_distribution(messages) distribution = calculate_text_length_distribution(messages)
stats = get_text_length_stats(messages) stats = get_text_length_stats(messages)
top_longest = get_top_longest_messages(messages, 10) top_longest = get_top_longest_messages(messages, 10)
# 显示结果 # 显示结果
print("\n=== Processed Plain Text 长度分析结果 ===") print("\n=== Processed Plain Text 长度分析结果 ===")
print("(已排除表情、图片ID、命令类型消息已排除[表情包]和[图片]标记消息,已清理回复引用)") print("(已排除表情、图片ID、命令类型消息已排除[表情包]和[图片]标记消息,已清理回复引用)")
@@ -297,7 +300,7 @@ def analyze_text_lengths(chat_id: Optional[str] = None, start_time: Optional[flo
print(f"聊天: {get_chat_name(chat_id)}") print(f"聊天: {get_chat_name(chat_id)}")
else: else:
print("聊天: 全部聊天") print("聊天: 全部聊天")
if start_time and end_time: if start_time and end_time:
print(f"时间范围: {format_timestamp(start_time)}{format_timestamp(end_time)}") print(f"时间范围: {format_timestamp(start_time)}{format_timestamp(end_time)}")
elif start_time: elif start_time:
@@ -306,26 +309,26 @@ def analyze_text_lengths(chat_id: Optional[str] = None, start_time: Optional[flo
print(f"时间范围: {format_timestamp(end_time)} 之前") print(f"时间范围: {format_timestamp(end_time)} 之前")
else: else:
print("时间范围: 不限制") print("时间范围: 不限制")
print("\n基本统计:") print("\n基本统计:")
print(f"总消息数量: {len(messages)}") print(f"总消息数量: {len(messages)}")
print(f"有文本消息数量: {stats['count']}") print(f"有文本消息数量: {stats['count']}")
print(f"空文本消息数量: {stats['null_count']}") print(f"空文本消息数量: {stats['null_count']}")
print(f"被排除的消息数量: {stats['excluded_count']}") print(f"被排除的消息数量: {stats['excluded_count']}")
if stats['count'] > 0: if stats["count"] > 0:
print(f"最短长度: {stats['min']} 字符") print(f"最短长度: {stats['min']} 字符")
print(f"最长长度: {stats['max']} 字符") print(f"最长长度: {stats['max']} 字符")
print(f"平均长度: {stats['avg']:.2f} 字符") print(f"平均长度: {stats['avg']:.2f} 字符")
print(f"中位数长度: {stats['median']:.2f} 字符") print(f"中位数长度: {stats['median']:.2f} 字符")
print("\n文本长度分布:") print("\n文本长度分布:")
total = stats['count'] total = stats["count"]
if total > 0: if total > 0:
for range_name, count in distribution.items(): for range_name, count in distribution.items():
if count > 0: if count > 0:
percentage = count / total * 100 percentage = count / total * 100
print(f"{range_name} 字符: {count} ({percentage:.2f}%)") print(f"{range_name} 字符: {count} ({percentage:.2f}%)")
# 显示最长的消息 # 显示最长的消息
if top_longest: if top_longest:
print(f"\n最长的 {len(top_longest)} 条消息:") print(f"\n最长的 {len(top_longest)} 条消息:")
@@ -338,34 +341,34 @@ def analyze_text_lengths(chat_id: Optional[str] = None, start_time: Optional[flo
def interactive_menu() -> None: def interactive_menu() -> None:
"""Interactive menu for text length analysis""" """Interactive menu for text length analysis"""
while True: while True:
print("\n" + "="*50) print("\n" + "=" * 50)
print("Processed Plain Text 长度分析工具") print("Processed Plain Text 长度分析工具")
print("="*50) print("=" * 50)
print("1. 分析全部聊天") print("1. 分析全部聊天")
print("2. 选择特定聊天分析") print("2. 选择特定聊天分析")
print("q. 退出") print("q. 退出")
choice = input("\n请选择分析模式 (1-2, q): ").strip() choice = input("\n请选择分析模式 (1-2, q): ").strip()
if choice.lower() == 'q': if choice.lower() == "q":
print("再见!") print("再见!")
break break
chat_id = None chat_id = None
if choice == "2": if choice == "2":
# 显示可用的聊天列表 # 显示可用的聊天列表
chats = get_available_chats() chats = get_available_chats()
if not chats: if not chats:
print("没有找到聊天数据") print("没有找到聊天数据")
continue continue
print(f"\n可用的聊天 (共{len(chats)}个):") print(f"\n可用的聊天 (共{len(chats)}个):")
for i, (_cid, name, count) in enumerate(chats, 1): for i, (_cid, name, count) in enumerate(chats, 1):
print(f"{i}. {name} ({count}条消息)") print(f"{i}. {name} ({count}条消息)")
try: try:
chat_choice = int(input(f"\n请选择聊天 (1-{len(chats)}): ").strip()) chat_choice = int(input(f"\n请选择聊天 (1-{len(chats)}): ").strip())
if 1 <= chat_choice <= len(chats): if 1 <= chat_choice <= len(chats):
@@ -376,19 +379,19 @@ def interactive_menu() -> None:
except ValueError: except ValueError:
print("请输入有效数字") print("请输入有效数字")
continue continue
elif choice != "1": elif choice != "1":
print("无效选择") print("无效选择")
continue continue
# 获取时间范围 # 获取时间范围
start_time, end_time = get_time_range_input() start_time, end_time = get_time_range_input()
# 执行分析 # 执行分析
analyze_text_lengths(chat_id, start_time, end_time) analyze_text_lengths(chat_id, start_time, end_time)
input("\n按回车键继续...") input("\n按回车键继续...")
if __name__ == "__main__": if __name__ == "__main__":
interactive_menu() interactive_menu()

View File

@@ -708,7 +708,7 @@ class EmojiManager:
if not emoji.is_deleted and emoji.hash == emoji_hash: if not emoji.is_deleted and emoji.hash == emoji_hash:
return emoji return emoji
return None # 如果循环结束还没找到,则返回 None return None # 如果循环结束还没找到,则返回 None
async def get_emoji_tag_by_hash(self, emoji_hash: str) -> Optional[List[str]]: async def get_emoji_tag_by_hash(self, emoji_hash: str) -> Optional[List[str]]:
"""根据哈希值获取已注册表情包的情感标签列表 """根据哈希值获取已注册表情包的情感标签列表
@@ -731,7 +731,7 @@ class EmojiManager:
emoji_record = Emoji.get_or_none(Emoji.emoji_hash == emoji_hash) emoji_record = Emoji.get_or_none(Emoji.emoji_hash == emoji_hash)
if emoji_record and emoji_record.emotion: if emoji_record and emoji_record.emotion:
logger.info(f"[缓存命中] 从数据库获取表情包情感标签: {emoji_record.emotion[:50]}...") logger.info(f"[缓存命中] 从数据库获取表情包情感标签: {emoji_record.emotion[:50]}...")
return emoji_record.emotion.split(',') return emoji_record.emotion.split(",")
except Exception as e: except Exception as e:
logger.error(f"从数据库查询表情包情感标签时出错: {e}") logger.error(f"从数据库查询表情包情感标签时出错: {e}")

View File

@@ -77,10 +77,10 @@ class ExpressionSelector:
def can_use_expression_for_chat(self, chat_id: str) -> bool: def can_use_expression_for_chat(self, chat_id: str) -> bool:
""" """
检查指定聊天流是否允许使用表达 检查指定聊天流是否允许使用表达
Args: Args:
chat_id: 聊天流ID chat_id: 聊天流ID
Returns: Returns:
bool: 是否允许使用表达 bool: 是否允许使用表达
""" """
@@ -123,9 +123,7 @@ class ExpressionSelector:
return group_chat_ids return group_chat_ids
return [chat_id] return [chat_id]
def get_random_expressions( def get_random_expressions(self, chat_id: str, total_num: int) -> List[Dict[str, Any]]:
self, chat_id: str, total_num: int
) -> List[Dict[str, Any]]:
# sourcery skip: extract-duplicate-method, move-assign # sourcery skip: extract-duplicate-method, move-assign
# 支持多chat_id合并抽选 # 支持多chat_id合并抽选
related_chat_ids = self.get_related_chat_ids(chat_id) related_chat_ids = self.get_related_chat_ids(chat_id)
@@ -200,7 +198,7 @@ class ExpressionSelector:
) -> Tuple[List[Dict[str, Any]], List[int]]: ) -> Tuple[List[Dict[str, Any]], List[int]]:
# sourcery skip: inline-variable, list-comprehension # sourcery skip: inline-variable, list-comprehension
"""使用LLM选择适合的表达方式""" """使用LLM选择适合的表达方式"""
# 检查是否允许在此聊天流中使用表达 # 检查是否允许在此聊天流中使用表达
if not self.can_use_expression_for_chat(chat_id): if not self.can_use_expression_for_chat(chat_id):
logger.debug(f"聊天流 {chat_id} 不允许使用表达,返回空列表") logger.debug(f"聊天流 {chat_id} 不允许使用表达,返回空列表")
@@ -208,7 +206,7 @@ class ExpressionSelector:
# 1. 获取20个随机表达方式现在按权重抽取 # 1. 获取20个随机表达方式现在按权重抽取
style_exprs = self.get_random_expressions(chat_id, 10) style_exprs = self.get_random_expressions(chat_id, 10)
if len(style_exprs) < 10: if len(style_exprs) < 10:
logger.info(f"聊天流 {chat_id} 表达方式正在积累中") logger.info(f"聊天流 {chat_id} 表达方式正在积累中")
return [], [] return [], []
@@ -248,7 +246,6 @@ class ExpressionSelector:
# 4. 调用LLM # 4. 调用LLM
try: try:
# start_time = time.time() # start_time = time.time()
content, (reasoning_content, model_name, _) = await self.llm_model.generate_response_async(prompt=prompt) content, (reasoning_content, model_name, _) = await self.llm_model.generate_response_async(prompt=prompt)
# logger.info(f"LLM请求时间: {model_name} {time.time() - start_time} \n{prompt}") # logger.info(f"LLM请求时间: {model_name} {time.time() - start_time} \n{prompt}")
@@ -295,7 +292,6 @@ class ExpressionSelector:
except Exception as e: except Exception as e:
logger.error(f"LLM处理表达方式选择时出错: {e}") logger.error(f"LLM处理表达方式选择时出错: {e}")
return [], [] return [], []
init_prompt() init_prompt()

View File

@@ -119,4 +119,3 @@ def get_global_focus_value() -> Optional[float]:
return get_time_based_focus_value(config_item[1:]) return get_time_based_focus_value(config_item[1:])
return None return None

View File

@@ -124,5 +124,3 @@ def get_global_frequency() -> Optional[float]:
return get_time_based_frequency(config_item[1:]) return get_time_based_frequency(config_item[1:])
return None return None

View File

@@ -34,4 +34,4 @@ def parse_stream_config_to_chat_id(stream_config_str: str) -> Optional[str]:
return hashlib.md5(key.encode()).hexdigest() return hashlib.md5(key.encode()).hexdigest()
except (ValueError, IndexError): except (ValueError, IndexError):
return None return None

View File

@@ -261,7 +261,11 @@ class HeartFChatting:
return loop_info, reply_text, cycle_timers return loop_info, reply_text, cycle_timers
async def _observe(self, interest_value: float = 0.0, recent_messages_list: List["DatabaseMessages"] = []) -> bool: async def _observe(
self, interest_value: float = 0.0, recent_messages_list: Optional[List["DatabaseMessages"]] = None
) -> bool:
if recent_messages_list is None:
recent_messages_list = []
reply_text = "" # 初始化reply_text变量避免UnboundLocalError reply_text = "" # 初始化reply_text变量避免UnboundLocalError
# 使用sigmoid函数将interest_value转换为概率 # 使用sigmoid函数将interest_value转换为概率

View File

@@ -3,14 +3,16 @@ from typing import Any, Optional, Dict
from src.common.logger import get_logger from src.common.logger import get_logger
from src.chat.heart_flow.heartFC_chat import HeartFChatting from src.chat.heart_flow.heartFC_chat import HeartFChatting
logger = get_logger("heartflow") logger = get_logger("heartflow")
class Heartflow: class Heartflow:
"""主心流协调器,负责初始化并协调聊天""" """主心流协调器,负责初始化并协调聊天"""
def __init__(self): def __init__(self):
self.heartflow_chat_list: Dict[Any, HeartFChatting] = {} self.heartflow_chat_list: Dict[Any, HeartFChatting] = {}
async def get_or_create_heartflow_chat(self, chat_id: Any) -> Optional[HeartFChatting]: async def get_or_create_heartflow_chat(self, chat_id: Any) -> Optional[HeartFChatting]:
"""获取或创建一个新的HeartFChatting实例""" """获取或创建一个新的HeartFChatting实例"""
try: try:
@@ -18,7 +20,7 @@ class Heartflow:
if chat := self.heartflow_chat_list.get(chat_id): if chat := self.heartflow_chat_list.get(chat_id):
return chat return chat
else: else:
new_chat = HeartFChatting(chat_id = chat_id) new_chat = HeartFChatting(chat_id=chat_id)
await new_chat.start() await new_chat.start()
self.heartflow_chat_list[chat_id] = new_chat self.heartflow_chat_list[chat_id] = new_chat
return new_chat return new_chat
@@ -27,4 +29,5 @@ class Heartflow:
traceback.print_exc() traceback.print_exc()
return None return None
heartflow = Heartflow() heartflow = Heartflow()

View File

@@ -23,6 +23,7 @@ if TYPE_CHECKING:
logger = get_logger("chat") logger = get_logger("chat")
async def _calculate_interest(message: MessageRecv) -> Tuple[float, list[str]]: async def _calculate_interest(message: MessageRecv) -> Tuple[float, list[str]]:
"""计算消息的兴趣度 """计算消息的兴趣度
@@ -34,14 +35,14 @@ async def _calculate_interest(message: MessageRecv) -> Tuple[float, list[str]]:
""" """
if message.is_picid or message.is_emoji: if message.is_picid or message.is_emoji:
return 0.0, [] return 0.0, []
is_mentioned,is_at,reply_probability_boost = is_mentioned_bot_in_message(message) is_mentioned, is_at, reply_probability_boost = is_mentioned_bot_in_message(message)
interested_rate = 0.0 interested_rate = 0.0
with Timer("记忆激活"): with Timer("记忆激活"):
interested_rate, keywords,keywords_lite = await hippocampus_manager.get_activate_from_text( interested_rate, keywords, keywords_lite = await hippocampus_manager.get_activate_from_text(
message.processed_plain_text, message.processed_plain_text,
max_depth= 4, max_depth=4,
fast_retrieval=global_config.chat.interest_rate_mode == "fast", fast_retrieval=global_config.chat.interest_rate_mode == "fast",
) )
message.key_words = keywords message.key_words = keywords
@@ -51,7 +52,7 @@ async def _calculate_interest(message: MessageRecv) -> Tuple[float, list[str]]:
text_len = len(message.processed_plain_text) text_len = len(message.processed_plain_text)
# 根据文本长度分布调整兴趣度,采用分段函数实现更精确的兴趣度计算 # 根据文本长度分布调整兴趣度,采用分段函数实现更精确的兴趣度计算
# 基于实际分布0-5字符(26.57%), 6-10字符(27.18%), 11-20字符(22.76%), 21-30字符(10.33%), 31+字符(13.86%) # 基于实际分布0-5字符(26.57%), 6-10字符(27.18%), 11-20字符(22.76%), 21-30字符(10.33%), 31+字符(13.86%)
if text_len == 0: if text_len == 0:
base_interest = 0.01 # 空消息最低兴趣度 base_interest = 0.01 # 空消息最低兴趣度
elif text_len <= 5: elif text_len <= 5:
@@ -75,16 +76,15 @@ async def _calculate_interest(message: MessageRecv) -> Tuple[float, list[str]]:
else: else:
# 100+字符:对数增长 0.26 -> 0.3,增长率递减 # 100+字符:对数增长 0.26 -> 0.3,增长率递减
base_interest = 0.26 + (0.3 - 0.26) * (math.log10(text_len - 99) / math.log10(901)) # 1000-99=901 base_interest = 0.26 + (0.3 - 0.26) * (math.log10(text_len - 99) / math.log10(901)) # 1000-99=901
# 确保在范围内 # 确保在范围内
base_interest = min(max(base_interest, 0.01), 0.3) base_interest = min(max(base_interest, 0.01), 0.3)
message.interest_value = base_interest message.interest_value = base_interest
message.is_mentioned = is_mentioned message.is_mentioned = is_mentioned
message.is_at = is_at message.is_at = is_at
message.reply_probability_boost = reply_probability_boost message.reply_probability_boost = reply_probability_boost
return base_interest, keywords return base_interest, keywords
@@ -115,14 +115,13 @@ class HeartFCMessageReceiver:
# 2. 兴趣度计算与更新 # 2. 兴趣度计算与更新
interested_rate, keywords = await _calculate_interest(message) interested_rate, keywords = await _calculate_interest(message)
await self.storage.store_message(message, chat) await self.storage.store_message(message, chat)
heartflow_chat: HeartFChatting = await heartflow.get_or_create_heartflow_chat(chat.stream_id) # type: ignore heartflow_chat: HeartFChatting = await heartflow.get_or_create_heartflow_chat(chat.stream_id) # type: ignore
# subheartflow.add_message_to_normal_chat_cache(message, interested_rate, is_mentioned) # subheartflow.add_message_to_normal_chat_cache(message, interested_rate, is_mentioned)
if global_config.mood.enable_mood: if global_config.mood.enable_mood:
chat_mood = mood_manager.get_mood_by_chat_id(heartflow_chat.stream_id) chat_mood = mood_manager.get_mood_by_chat_id(heartflow_chat.stream_id)
asyncio.create_task(chat_mood.update_mood_by_message(message, interested_rate)) asyncio.create_task(chat_mood.update_mood_by_message(message, interested_rate))
@@ -132,7 +131,7 @@ class HeartFCMessageReceiver:
# 用这个pattern截取出id部分picid是一个list并替换成对应的图片描述 # 用这个pattern截取出id部分picid是一个list并替换成对应的图片描述
picid_pattern = r"\[picid:([^\]]+)\]" picid_pattern = r"\[picid:([^\]]+)\]"
picid_list = re.findall(picid_pattern, message.processed_plain_text) picid_list = re.findall(picid_pattern, message.processed_plain_text)
# 创建替换后的文本 # 创建替换后的文本
processed_text = message.processed_plain_text processed_text = message.processed_plain_text
if picid_list: if picid_list:
@@ -145,18 +144,20 @@ class HeartFCMessageReceiver:
# 如果没有找到图片描述,则移除[picid:xxxx]标记 # 如果没有找到图片描述,则移除[picid:xxxx]标记
processed_text = processed_text.replace(f"[picid:{picid}]", "[图片:网络不好,图片无法加载]") processed_text = processed_text.replace(f"[picid:{picid}]", "[图片:网络不好,图片无法加载]")
# 应用用户引用格式替换,将回复<aaa:bbb>和@<aaa:bbb>格式转换为可读格式 # 应用用户引用格式替换,将回复<aaa:bbb>和@<aaa:bbb>格式转换为可读格式
processed_plain_text = replace_user_references( processed_plain_text = replace_user_references(
processed_text, processed_text,
message.message_info.platform, # type: ignore message.message_info.platform, # type: ignore
replace_bot_name=True replace_bot_name=True,
) )
logger.info(f"[{mes_name}]{userinfo.user_nickname}:{processed_plain_text}[{interested_rate:.2f}]") # type: ignore logger.info(f"[{mes_name}]{userinfo.user_nickname}:{processed_plain_text}[{interested_rate:.2f}]") # type: ignore
_ = Person.register_person(platform=message.message_info.platform, user_id=message.message_info.user_info.user_id,nickname=userinfo.user_nickname) # type: ignore _ = Person.register_person(
platform=message.message_info.platform,
user_id=message.message_info.user_info.user_id,
nickname=userinfo.user_nickname,
) # type: ignore
except Exception as e: except Exception as e:
logger.error(f"消息处理失败: {e}") logger.error(f"消息处理失败: {e}")

View File

@@ -124,6 +124,7 @@ async def send_typing():
message_type="state", content="typing", stream_id=chat.stream_id, storage_message=False message_type="state", content="typing", stream_id=chat.stream_id, storage_message=False
) )
async def stop_typing(): async def stop_typing():
group_info = GroupInfo(platform="amaidesu_default", group_id="114514", group_name="内心") group_info = GroupInfo(platform="amaidesu_default", group_id="114514", group_name="内心")
@@ -135,4 +136,4 @@ async def stop_typing():
await send_api.custom_to_stream( await send_api.custom_to_stream(
message_type="state", content="stop_typing", stream_id=chat.stream_id, storage_message=False message_type="state", content="stop_typing", stream_id=chat.stream_id, storage_message=False
) )

View File

@@ -30,6 +30,7 @@ DATA_PATH = os.path.join(ROOT_PATH, "data")
qa_manager = None qa_manager = None
inspire_manager = None inspire_manager = None
def lpmm_start_up(): # sourcery skip: extract-duplicate-method def lpmm_start_up(): # sourcery skip: extract-duplicate-method
# 检查LPMM知识库是否启用 # 检查LPMM知识库是否启用
if global_config.lpmm_knowledge.enable: if global_config.lpmm_knowledge.enable:

View File

@@ -32,11 +32,11 @@ install(extra_lines=3)
# 多线程embedding配置常量 # 多线程embedding配置常量
DEFAULT_MAX_WORKERS = 10 # 默认最大线程数 DEFAULT_MAX_WORKERS = 10 # 默认最大线程数
DEFAULT_CHUNK_SIZE = 10 # 默认每个线程处理的数据块大小 DEFAULT_CHUNK_SIZE = 10 # 默认每个线程处理的数据块大小
MIN_CHUNK_SIZE = 1 # 最小分块大小 MIN_CHUNK_SIZE = 1 # 最小分块大小
MAX_CHUNK_SIZE = 50 # 最大分块大小 MAX_CHUNK_SIZE = 50 # 最大分块大小
MIN_WORKERS = 1 # 最小线程数 MIN_WORKERS = 1 # 最小线程数
MAX_WORKERS = 20 # 最大线程数 MAX_WORKERS = 20 # 最大线程数
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..")) ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
EMBEDDING_DATA_DIR = os.path.join(ROOT_PATH, "data", "embedding") EMBEDDING_DATA_DIR = os.path.join(ROOT_PATH, "data", "embedding")
@@ -93,7 +93,13 @@ class EmbeddingStoreItem:
class EmbeddingStore: class EmbeddingStore:
def __init__(self, namespace: str, dir_path: str, max_workers: int = DEFAULT_MAX_WORKERS, chunk_size: int = DEFAULT_CHUNK_SIZE): def __init__(
self,
namespace: str,
dir_path: str,
max_workers: int = DEFAULT_MAX_WORKERS,
chunk_size: int = DEFAULT_CHUNK_SIZE,
):
self.namespace = namespace self.namespace = namespace
self.dir = dir_path self.dir = dir_path
self.embedding_file_path = f"{dir_path}/{namespace}.parquet" self.embedding_file_path = f"{dir_path}/{namespace}.parquet"
@@ -103,12 +109,16 @@ class EmbeddingStore:
# 多线程配置参数验证和设置 # 多线程配置参数验证和设置
self.max_workers = max(MIN_WORKERS, min(MAX_WORKERS, max_workers)) self.max_workers = max(MIN_WORKERS, min(MAX_WORKERS, max_workers))
self.chunk_size = max(MIN_CHUNK_SIZE, min(MAX_CHUNK_SIZE, chunk_size)) self.chunk_size = max(MIN_CHUNK_SIZE, min(MAX_CHUNK_SIZE, chunk_size))
# 如果配置值被调整,记录日志 # 如果配置值被调整,记录日志
if self.max_workers != max_workers: if self.max_workers != max_workers:
logger.warning(f"max_workers 已从 {max_workers} 调整为 {self.max_workers} (范围: {MIN_WORKERS}-{MAX_WORKERS})") logger.warning(
f"max_workers 已从 {max_workers} 调整为 {self.max_workers} (范围: {MIN_WORKERS}-{MAX_WORKERS})"
)
if self.chunk_size != chunk_size: if self.chunk_size != chunk_size:
logger.warning(f"chunk_size 已从 {chunk_size} 调整为 {self.chunk_size} (范围: {MIN_CHUNK_SIZE}-{MAX_CHUNK_SIZE})") logger.warning(
f"chunk_size 已从 {chunk_size} 调整为 {self.chunk_size} (范围: {MIN_CHUNK_SIZE}-{MAX_CHUNK_SIZE})"
)
self.store = {} self.store = {}
@@ -120,23 +130,23 @@ class EmbeddingStore:
# 创建新的事件循环并在完成后立即关闭 # 创建新的事件循环并在完成后立即关闭
loop = asyncio.new_event_loop() loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop) asyncio.set_event_loop(loop)
try: try:
# 创建新的LLMRequest实例 # 创建新的LLMRequest实例
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config from src.config.config import model_config
llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding") llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding")
# 使用新的事件循环运行异步方法 # 使用新的事件循环运行异步方法
embedding, _ = loop.run_until_complete(llm.get_embedding(s)) embedding, _ = loop.run_until_complete(llm.get_embedding(s))
if embedding and len(embedding) > 0: if embedding and len(embedding) > 0:
return embedding return embedding
else: else:
logger.error(f"获取嵌入失败: {s}") logger.error(f"获取嵌入失败: {s}")
return [] return []
except Exception as e: except Exception as e:
logger.error(f"获取嵌入时发生异常: {s}, 错误: {e}") logger.error(f"获取嵌入时发生异常: {s}, 错误: {e}")
return [] return []
@@ -147,43 +157,45 @@ class EmbeddingStore:
except Exception: except Exception:
pass pass
def _get_embeddings_batch_threaded(self, strs: List[str], chunk_size: int = 10, max_workers: int = 10, progress_callback=None) -> List[Tuple[str, List[float]]]: def _get_embeddings_batch_threaded(
self, strs: List[str], chunk_size: int = 10, max_workers: int = 10, progress_callback=None
) -> List[Tuple[str, List[float]]]:
"""使用多线程批量获取嵌入向量 """使用多线程批量获取嵌入向量
Args: Args:
strs: 要获取嵌入的字符串列表 strs: 要获取嵌入的字符串列表
chunk_size: 每个线程处理的数据块大小 chunk_size: 每个线程处理的数据块大小
max_workers: 最大线程数 max_workers: 最大线程数
progress_callback: 进度回调函数,接收一个参数表示完成的数量 progress_callback: 进度回调函数,接收一个参数表示完成的数量
Returns: Returns:
包含(原始字符串, 嵌入向量)的元组列表,保持与输入顺序一致 包含(原始字符串, 嵌入向量)的元组列表,保持与输入顺序一致
""" """
if not strs: if not strs:
return [] return []
# 分块 # 分块
chunks = [] chunks = []
for i in range(0, len(strs), chunk_size): for i in range(0, len(strs), chunk_size):
chunk = strs[i:i + chunk_size] chunk = strs[i : i + chunk_size]
chunks.append((i, chunk)) # 保存起始索引以维持顺序 chunks.append((i, chunk)) # 保存起始索引以维持顺序
# 结果存储,使用字典按索引存储以保证顺序 # 结果存储,使用字典按索引存储以保证顺序
results = {} results = {}
def process_chunk(chunk_data): def process_chunk(chunk_data):
"""处理单个数据块的函数""" """处理单个数据块的函数"""
start_idx, chunk_strs = chunk_data start_idx, chunk_strs = chunk_data
chunk_results = [] chunk_results = []
# 为每个线程创建独立的LLMRequest实例 # 为每个线程创建独立的LLMRequest实例
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config from src.config.config import model_config
try: try:
# 创建线程专用的LLM实例 # 创建线程专用的LLM实例
llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding") llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding")
for i, s in enumerate(chunk_strs): for i, s in enumerate(chunk_strs):
try: try:
# 在线程中创建独立的事件循环 # 在线程中创建独立的事件循环
@@ -193,25 +205,25 @@ class EmbeddingStore:
embedding = loop.run_until_complete(llm.get_embedding(s)) embedding = loop.run_until_complete(llm.get_embedding(s))
finally: finally:
loop.close() loop.close()
if embedding and len(embedding) > 0: if embedding and len(embedding) > 0:
chunk_results.append((start_idx + i, s, embedding[0])) # embedding[0] 是实际的向量 chunk_results.append((start_idx + i, s, embedding[0])) # embedding[0] 是实际的向量
else: else:
logger.error(f"获取嵌入失败: {s}") logger.error(f"获取嵌入失败: {s}")
chunk_results.append((start_idx + i, s, [])) chunk_results.append((start_idx + i, s, []))
# 每完成一个嵌入立即更新进度 # 每完成一个嵌入立即更新进度
if progress_callback: if progress_callback:
progress_callback(1) progress_callback(1)
except Exception as e: except Exception as e:
logger.error(f"获取嵌入时发生异常: {s}, 错误: {e}") logger.error(f"获取嵌入时发生异常: {s}, 错误: {e}")
chunk_results.append((start_idx + i, s, [])) chunk_results.append((start_idx + i, s, []))
# 即使失败也要更新进度 # 即使失败也要更新进度
if progress_callback: if progress_callback:
progress_callback(1) progress_callback(1)
except Exception as e: except Exception as e:
logger.error(f"创建LLM实例失败: {e}") logger.error(f"创建LLM实例失败: {e}")
# 如果创建LLM实例失败返回空结果 # 如果创建LLM实例失败返回空结果
@@ -220,14 +232,14 @@ class EmbeddingStore:
# 即使失败也要更新进度 # 即使失败也要更新进度
if progress_callback: if progress_callback:
progress_callback(1) progress_callback(1)
return chunk_results return chunk_results
# 使用线程池处理 # 使用线程池处理
with ThreadPoolExecutor(max_workers=max_workers) as executor: with ThreadPoolExecutor(max_workers=max_workers) as executor:
# 提交所有任务 # 提交所有任务
future_to_chunk = {executor.submit(process_chunk, chunk): chunk for chunk in chunks} future_to_chunk = {executor.submit(process_chunk, chunk): chunk for chunk in chunks}
# 收集结果进度已在process_chunk中实时更新 # 收集结果进度已在process_chunk中实时更新
for future in as_completed(future_to_chunk): for future in as_completed(future_to_chunk):
try: try:
@@ -241,7 +253,7 @@ class EmbeddingStore:
start_idx, chunk_strs = chunk start_idx, chunk_strs = chunk
for i, s in enumerate(chunk_strs): for i, s in enumerate(chunk_strs):
results[start_idx + i] = (s, []) results[start_idx + i] = (s, [])
# 按原始顺序返回结果 # 按原始顺序返回结果
ordered_results = [] ordered_results = []
for i in range(len(strs)): for i in range(len(strs)):
@@ -250,7 +262,7 @@ class EmbeddingStore:
else: else:
# 防止遗漏 # 防止遗漏
ordered_results.append((strs[i], [])) ordered_results.append((strs[i], []))
return ordered_results return ordered_results
def get_test_file_path(self): def get_test_file_path(self):
@@ -259,14 +271,14 @@ class EmbeddingStore:
def save_embedding_test_vectors(self): def save_embedding_test_vectors(self):
"""保存测试字符串的嵌入到本地(使用多线程优化)""" """保存测试字符串的嵌入到本地(使用多线程优化)"""
logger.info("开始保存测试字符串的嵌入向量...") logger.info("开始保存测试字符串的嵌入向量...")
# 使用多线程批量获取测试字符串的嵌入 # 使用多线程批量获取测试字符串的嵌入
embedding_results = self._get_embeddings_batch_threaded( embedding_results = self._get_embeddings_batch_threaded(
EMBEDDING_TEST_STRINGS, EMBEDDING_TEST_STRINGS,
chunk_size=min(self.chunk_size, len(EMBEDDING_TEST_STRINGS)), chunk_size=min(self.chunk_size, len(EMBEDDING_TEST_STRINGS)),
max_workers=min(self.max_workers, len(EMBEDDING_TEST_STRINGS)) max_workers=min(self.max_workers, len(EMBEDDING_TEST_STRINGS)),
) )
# 构建测试向量字典 # 构建测试向量字典
test_vectors = {} test_vectors = {}
for idx, (s, embedding) in enumerate(embedding_results): for idx, (s, embedding) in enumerate(embedding_results):
@@ -276,10 +288,10 @@ class EmbeddingStore:
logger.error(f"获取测试字符串嵌入失败: {s}") logger.error(f"获取测试字符串嵌入失败: {s}")
# 使用原始单线程方法作为后备 # 使用原始单线程方法作为后备
test_vectors[str(idx)] = self._get_embedding(s) test_vectors[str(idx)] = self._get_embedding(s)
with open(self.get_test_file_path(), "w", encoding="utf-8") as f: with open(self.get_test_file_path(), "w", encoding="utf-8") as f:
json.dump(test_vectors, f, ensure_ascii=False, indent=2) json.dump(test_vectors, f, ensure_ascii=False, indent=2)
logger.info("测试字符串嵌入向量保存完成") logger.info("测试字符串嵌入向量保存完成")
def load_embedding_test_vectors(self): def load_embedding_test_vectors(self):
@@ -297,35 +309,35 @@ class EmbeddingStore:
logger.warning("未检测到本地嵌入模型测试文件,将保存当前模型的测试嵌入。") logger.warning("未检测到本地嵌入模型测试文件,将保存当前模型的测试嵌入。")
self.save_embedding_test_vectors() self.save_embedding_test_vectors()
return True return True
# 检查本地向量完整性 # 检查本地向量完整性
for idx in range(len(EMBEDDING_TEST_STRINGS)): for idx in range(len(EMBEDDING_TEST_STRINGS)):
if local_vectors.get(str(idx)) is None: if local_vectors.get(str(idx)) is None:
logger.warning("本地嵌入模型测试文件缺失部分测试字符串,将重新保存。") logger.warning("本地嵌入模型测试文件缺失部分测试字符串,将重新保存。")
self.save_embedding_test_vectors() self.save_embedding_test_vectors()
return True return True
logger.info("开始检验嵌入模型一致性...") logger.info("开始检验嵌入模型一致性...")
# 使用多线程批量获取当前模型的嵌入 # 使用多线程批量获取当前模型的嵌入
embedding_results = self._get_embeddings_batch_threaded( embedding_results = self._get_embeddings_batch_threaded(
EMBEDDING_TEST_STRINGS, EMBEDDING_TEST_STRINGS,
chunk_size=min(self.chunk_size, len(EMBEDDING_TEST_STRINGS)), chunk_size=min(self.chunk_size, len(EMBEDDING_TEST_STRINGS)),
max_workers=min(self.max_workers, len(EMBEDDING_TEST_STRINGS)) max_workers=min(self.max_workers, len(EMBEDDING_TEST_STRINGS)),
) )
# 检查一致性 # 检查一致性
for idx, (s, new_emb) in enumerate(embedding_results): for idx, (s, new_emb) in enumerate(embedding_results):
local_emb = local_vectors.get(str(idx)) local_emb = local_vectors.get(str(idx))
if not new_emb: if not new_emb:
logger.error(f"获取测试字符串嵌入失败: {s}") logger.error(f"获取测试字符串嵌入失败: {s}")
return False return False
sim = cosine_similarity(local_emb, new_emb) sim = cosine_similarity(local_emb, new_emb)
if sim < EMBEDDING_SIM_THRESHOLD: if sim < EMBEDDING_SIM_THRESHOLD:
logger.error(f"嵌入模型一致性校验失败,字符串: {s}, 相似度: {sim:.4f}") logger.error(f"嵌入模型一致性校验失败,字符串: {s}, 相似度: {sim:.4f}")
return False return False
logger.info("嵌入模型一致性校验通过。") logger.info("嵌入模型一致性校验通过。")
return True return True
@@ -333,22 +345,22 @@ class EmbeddingStore:
"""向库中存入字符串(使用多线程优化)""" """向库中存入字符串(使用多线程优化)"""
if not strs: if not strs:
return return
total = len(strs) total = len(strs)
# 过滤已存在的字符串 # 过滤已存在的字符串
new_strs = [] new_strs = []
for s in strs: for s in strs:
item_hash = self.namespace + "-" + get_sha256(s) item_hash = self.namespace + "-" + get_sha256(s)
if item_hash not in self.store: if item_hash not in self.store:
new_strs.append(s) new_strs.append(s)
if not new_strs: if not new_strs:
logger.info(f"所有字符串已存在于{self.namespace}嵌入库中,跳过处理") logger.info(f"所有字符串已存在于{self.namespace}嵌入库中,跳过处理")
return return
logger.info(f"需要处理 {len(new_strs)}/{total} 个新字符串") logger.info(f"需要处理 {len(new_strs)}/{total} 个新字符串")
with Progress( with Progress(
SpinnerColumn(), SpinnerColumn(),
TextColumn("[progress.description]{task.description}"), TextColumn("[progress.description]{task.description}"),
@@ -362,31 +374,39 @@ class EmbeddingStore:
transient=False, transient=False,
) as progress: ) as progress:
task = progress.add_task(f"存入嵌入库:({times}/{TOTAL_EMBEDDING_TIMES})", total=total) task = progress.add_task(f"存入嵌入库:({times}/{TOTAL_EMBEDDING_TIMES})", total=total)
# 首先更新已存在项的进度 # 首先更新已存在项的进度
already_processed = total - len(new_strs) already_processed = total - len(new_strs)
if already_processed > 0: if already_processed > 0:
progress.update(task, advance=already_processed) progress.update(task, advance=already_processed)
if new_strs: if new_strs:
# 使用实例配置的参数,智能调整分块和线程数 # 使用实例配置的参数,智能调整分块和线程数
optimal_chunk_size = max(MIN_CHUNK_SIZE, min(self.chunk_size, len(new_strs) // self.max_workers if self.max_workers > 0 else self.chunk_size)) optimal_chunk_size = max(
optimal_max_workers = min(self.max_workers, max(MIN_WORKERS, len(new_strs) // optimal_chunk_size if optimal_chunk_size > 0 else 1)) MIN_CHUNK_SIZE,
min(
self.chunk_size, len(new_strs) // self.max_workers if self.max_workers > 0 else self.chunk_size
),
)
optimal_max_workers = min(
self.max_workers,
max(MIN_WORKERS, len(new_strs) // optimal_chunk_size if optimal_chunk_size > 0 else 1),
)
logger.debug(f"使用多线程处理: chunk_size={optimal_chunk_size}, max_workers={optimal_max_workers}") logger.debug(f"使用多线程处理: chunk_size={optimal_chunk_size}, max_workers={optimal_max_workers}")
# 定义进度更新回调函数 # 定义进度更新回调函数
def update_progress(count): def update_progress(count):
progress.update(task, advance=count) progress.update(task, advance=count)
# 批量获取嵌入,并实时更新进度 # 批量获取嵌入,并实时更新进度
embedding_results = self._get_embeddings_batch_threaded( embedding_results = self._get_embeddings_batch_threaded(
new_strs, new_strs,
chunk_size=optimal_chunk_size, chunk_size=optimal_chunk_size,
max_workers=optimal_max_workers, max_workers=optimal_max_workers,
progress_callback=update_progress progress_callback=update_progress,
) )
# 存入结果(不再需要在这里更新进度,因为已经在回调中更新了) # 存入结果(不再需要在这里更新进度,因为已经在回调中更新了)
for s, embedding in embedding_results: for s, embedding in embedding_results:
item_hash = self.namespace + "-" + get_sha256(s) item_hash = self.namespace + "-" + get_sha256(s)
@@ -519,7 +539,7 @@ class EmbeddingManager:
def __init__(self, max_workers: int = DEFAULT_MAX_WORKERS, chunk_size: int = DEFAULT_CHUNK_SIZE): def __init__(self, max_workers: int = DEFAULT_MAX_WORKERS, chunk_size: int = DEFAULT_CHUNK_SIZE):
""" """
初始化EmbeddingManager 初始化EmbeddingManager
Args: Args:
max_workers: 最大线程数 max_workers: 最大线程数
chunk_size: 每个线程处理的数据块大小 chunk_size: 每个线程处理的数据块大小

View File

@@ -426,9 +426,7 @@ class KGManager:
# 获取最终结果 # 获取最终结果
# 从搜索结果中提取文段节点的结果 # 从搜索结果中提取文段节点的结果
passage_node_res = [ passage_node_res = [
(node_key, score) (node_key, score) for node_key, score in ppr_res.items() if node_key.startswith("paragraph")
for node_key, score in ppr_res.items()
if node_key.startswith("paragraph")
] ]
del ppr_res del ppr_res

View File

@@ -1,8 +1,8 @@
raise DeprecationWarning("MemoryActiveManager is not used yet, please do not import it") raise DeprecationWarning("MemoryActiveManager is not used yet, please do not import it")
from .lpmmconfig import global_config from .lpmmconfig import global_config # noqa
from .embedding_store import EmbeddingManager from .embedding_store import EmbeddingManager # noqa
from .llm_client import LLMClient from .llm_client import LLMClient # noqa
from .utils.dyn_topk import dyn_select_top_k from .utils.dyn_topk import dyn_select_top_k # noqa
class MemoryActiveManager: class MemoryActiveManager:

View File

@@ -8,7 +8,7 @@ def dyn_select_top_k(
# 检查输入列表是否为空 # 检查输入列表是否为空
if not score: if not score:
return [] return []
# 按照分数排序(降序) # 按照分数排序(降序)
sorted_score = sorted(score, key=lambda x: x[1], reverse=True) sorted_score = sorted(score, key=lambda x: x[1], reverse=True)

View File

@@ -18,7 +18,7 @@ class MessageStorage:
if isinstance(keywords, list): if isinstance(keywords, list):
return json.dumps(keywords, ensure_ascii=False) return json.dumps(keywords, ensure_ascii=False)
return "[]" return "[]"
@staticmethod @staticmethod
def _deserialize_keywords(keywords_str: str) -> list: def _deserialize_keywords(keywords_str: str) -> list:
"""将JSON字符串反序列化为关键词列表""" """将JSON字符串反序列化为关键词列表"""
@@ -85,7 +85,7 @@ class MessageStorage:
key_words = MessageStorage._serialize_keywords(message.key_words) key_words = MessageStorage._serialize_keywords(message.key_words)
key_words_lite = MessageStorage._serialize_keywords(message.key_words_lite) key_words_lite = MessageStorage._serialize_keywords(message.key_words_lite)
selected_expressions = "" selected_expressions = ""
chat_info_dict = chat_stream.to_dict() chat_info_dict = chat_stream.to_dict()
user_info_dict = message.message_info.user_info.to_dict() # type: ignore user_info_dict = message.message_info.user_info.to_dict() # type: ignore

View File

@@ -124,4 +124,4 @@ class ActionManager:
"""恢复到默认动作集""" """恢复到默认动作集"""
actions_to_restore = list(self._using_actions.keys()) actions_to_restore = list(self._using_actions.keys())
self._using_actions = component_registry.get_default_actions() self._using_actions = component_registry.get_default_actions()
logger.debug(f"恢复动作集: 从 {actions_to_restore} 恢复到默认动作集 {list(self._using_actions.keys())}") logger.debug(f"恢复动作集: 从 {actions_to_restore} 恢复到默认动作集 {list(self._using_actions.keys())}")

View File

@@ -103,25 +103,23 @@ class ActionModifier:
self.action_manager.remove_action_from_using(action_name) self.action_manager.remove_action_from_using(action_name)
logger.debug(f"{self.log_prefix}阶段二移除动作: {action_name},原因: {reason}") logger.debug(f"{self.log_prefix}阶段二移除动作: {action_name},原因: {reason}")
# === 第三阶段:激活类型判定 === # === 第三阶段:激活类型判定 ===
# if chat_content is not None: # if chat_content is not None:
# logger.debug(f"{self.log_prefix}开始激活类型判定阶段") # logger.debug(f"{self.log_prefix}开始激活类型判定阶段")
# 获取当前使用的动作集(经过第一阶段处理) # 获取当前使用的动作集(经过第一阶段处理)
# current_using_actions = self.action_manager.get_using_actions() # current_using_actions = self.action_manager.get_using_actions()
# 获取因激活类型判定而需要移除的动作 # 获取因激活类型判定而需要移除的动作
# removals_s3 = await self._get_deactivated_actions_by_type( # removals_s3 = await self._get_deactivated_actions_by_type(
# current_using_actions, # current_using_actions,
# chat_content, # chat_content,
# ) # )
# 应用第三阶段的移除 # 应用第三阶段的移除
# for action_name, reason in removals_s3: # for action_name, reason in removals_s3:
# self.action_manager.remove_action_from_using(action_name) # self.action_manager.remove_action_from_using(action_name)
# logger.debug(f"{self.log_prefix}阶段三移除动作: {action_name},原因: {reason}") # logger.debug(f"{self.log_prefix}阶段三移除动作: {action_name},原因: {reason}")
# === 统一日志记录 === # === 统一日志记录 ===
all_removals = removals_s1 + removals_s2 all_removals = removals_s1 + removals_s2
@@ -131,9 +129,7 @@ class ActionModifier:
available_actions = list(self.action_manager.get_using_actions().keys()) available_actions = list(self.action_manager.get_using_actions().keys())
available_actions_text = "".join(available_actions) if available_actions else "" available_actions_text = "".join(available_actions) if available_actions else ""
logger.debug( logger.debug(f"{self.log_prefix} 当前可用动作: {available_actions_text}||移除: {removals_summary}")
f"{self.log_prefix} 当前可用动作: {available_actions_text}||移除: {removals_summary}"
)
def _check_action_associated_types(self, all_actions: Dict[str, ActionInfo], chat_context: ChatMessageContext): def _check_action_associated_types(self, all_actions: Dict[str, ActionInfo], chat_context: ChatMessageContext):
type_mismatched_actions: List[Tuple[str, str]] = [] type_mismatched_actions: List[Tuple[str, str]] = []

View File

@@ -385,18 +385,18 @@ class StatisticOutputTask(AsyncTask):
time_cost_key = f"time_costs_by_{category.split('_')[-1]}" time_cost_key = f"time_costs_by_{category.split('_')[-1]}"
avg_key = f"avg_time_costs_by_{category.split('_')[-1]}" avg_key = f"avg_time_costs_by_{category.split('_')[-1]}"
std_key = f"std_time_costs_by_{category.split('_')[-1]}" std_key = f"std_time_costs_by_{category.split('_')[-1]}"
for item_name in stats[period_key][category]: for item_name in stats[period_key][category]:
time_costs = stats[period_key][time_cost_key].get(item_name, []) time_costs = stats[period_key][time_cost_key].get(item_name, [])
if time_costs: if time_costs:
# 计算平均耗时 # 计算平均耗时
avg_time_cost = sum(time_costs) / len(time_costs) avg_time_cost = sum(time_costs) / len(time_costs)
stats[period_key][avg_key][item_name] = round(avg_time_cost, 3) stats[period_key][avg_key][item_name] = round(avg_time_cost, 3)
# 计算标准差 # 计算标准差
if len(time_costs) > 1: if len(time_costs) > 1:
variance = sum((x - avg_time_cost) ** 2 for x in time_costs) / len(time_costs) variance = sum((x - avg_time_cost) ** 2 for x in time_costs) / len(time_costs)
std_time_cost = variance ** 0.5 std_time_cost = variance**0.5
stats[period_key][std_key][item_name] = round(std_time_cost, 3) stats[period_key][std_key][item_name] = round(std_time_cost, 3)
else: else:
stats[period_key][std_key][item_name] = 0.0 stats[period_key][std_key][item_name] = 0.0
@@ -506,8 +506,6 @@ class StatisticOutputTask(AsyncTask):
break break
return stats return stats
def _collect_all_statistics(self, now: datetime) -> Dict[str, Dict[str, Any]]: def _collect_all_statistics(self, now: datetime) -> Dict[str, Dict[str, Any]]:
""" """
收集各时间段的统计数据 收集各时间段的统计数据
@@ -639,7 +637,9 @@ class StatisticOutputTask(AsyncTask):
cost = stats[COST_BY_MODEL][model_name] cost = stats[COST_BY_MODEL][model_name]
avg_time_cost = stats[AVG_TIME_COST_BY_MODEL][model_name] avg_time_cost = stats[AVG_TIME_COST_BY_MODEL][model_name]
std_time_cost = stats[STD_TIME_COST_BY_MODEL][model_name] std_time_cost = stats[STD_TIME_COST_BY_MODEL][model_name]
output.append(data_fmt.format(name, count, in_tokens, out_tokens, tokens, cost, avg_time_cost, std_time_cost)) output.append(
data_fmt.format(name, count, in_tokens, out_tokens, tokens, cost, avg_time_cost, std_time_cost)
)
output.append("") output.append("")
return "\n".join(output) return "\n".join(output)
@@ -728,7 +728,9 @@ class StatisticOutputTask(AsyncTask):
f"<td>{stat_data[STD_TIME_COST_BY_MODEL][model_name]:.1f} 秒</td>" f"<td>{stat_data[STD_TIME_COST_BY_MODEL][model_name]:.1f} 秒</td>"
f"</tr>" f"</tr>"
for model_name, count in sorted(stat_data[REQ_CNT_BY_MODEL].items()) for model_name, count in sorted(stat_data[REQ_CNT_BY_MODEL].items())
] if stat_data[REQ_CNT_BY_MODEL] else ["<tr><td colspan='8' style='text-align: center; color: #999;'>暂无数据</td></tr>"] ]
if stat_data[REQ_CNT_BY_MODEL]
else ["<tr><td colspan='8' style='text-align: center; color: #999;'>暂无数据</td></tr>"]
) )
# 按请求类型分类统计 # 按请求类型分类统计
type_rows = "\n".join( type_rows = "\n".join(
@@ -744,7 +746,9 @@ class StatisticOutputTask(AsyncTask):
f"<td>{stat_data[STD_TIME_COST_BY_TYPE][req_type]:.1f} 秒</td>" f"<td>{stat_data[STD_TIME_COST_BY_TYPE][req_type]:.1f} 秒</td>"
f"</tr>" f"</tr>"
for req_type, count in sorted(stat_data[REQ_CNT_BY_TYPE].items()) for req_type, count in sorted(stat_data[REQ_CNT_BY_TYPE].items())
] if stat_data[REQ_CNT_BY_TYPE] else ["<tr><td colspan='8' style='text-align: center; color: #999;'>暂无数据</td></tr>"] ]
if stat_data[REQ_CNT_BY_TYPE]
else ["<tr><td colspan='8' style='text-align: center; color: #999;'>暂无数据</td></tr>"]
) )
# 按模块分类统计 # 按模块分类统计
module_rows = "\n".join( module_rows = "\n".join(
@@ -760,7 +764,9 @@ class StatisticOutputTask(AsyncTask):
f"<td>{stat_data[STD_TIME_COST_BY_MODULE][module_name]:.1f} 秒</td>" f"<td>{stat_data[STD_TIME_COST_BY_MODULE][module_name]:.1f} 秒</td>"
f"</tr>" f"</tr>"
for module_name, count in sorted(stat_data[REQ_CNT_BY_MODULE].items()) for module_name, count in sorted(stat_data[REQ_CNT_BY_MODULE].items())
] if stat_data[REQ_CNT_BY_MODULE] else ["<tr><td colspan='8' style='text-align: center; color: #999;'>暂无数据</td></tr>"] ]
if stat_data[REQ_CNT_BY_MODULE]
else ["<tr><td colspan='8' style='text-align: center; color: #999;'>暂无数据</td></tr>"]
) )
# 聊天消息统计 # 聊天消息统计
@@ -768,7 +774,9 @@ class StatisticOutputTask(AsyncTask):
[ [
f"<tr><td>{self.name_mapping[chat_id][0]}</td><td>{count}</td></tr>" f"<tr><td>{self.name_mapping[chat_id][0]}</td><td>{count}</td></tr>"
for chat_id, count in sorted(stat_data[MSG_CNT_BY_CHAT].items()) for chat_id, count in sorted(stat_data[MSG_CNT_BY_CHAT].items())
] if stat_data[MSG_CNT_BY_CHAT] else ["<tr><td colspan='2' style='text-align: center; color: #999;'>暂无数据</td></tr>"] ]
if stat_data[MSG_CNT_BY_CHAT]
else ["<tr><td colspan='2' style='text-align: center; color: #999;'>暂无数据</td></tr>"]
) )
# 生成HTML # 生成HTML
return f""" return f"""

View File

@@ -49,9 +49,9 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, bool, float
reply_probability = 0.0 reply_probability = 0.0
is_at = False is_at = False
is_mentioned = False is_mentioned = False
# 这部分怎么处理啊啊啊啊 # 这部分怎么处理啊啊啊啊
#我觉得可以给消息加一个 reply_probability_boost字段 # 我觉得可以给消息加一个 reply_probability_boost字段
if ( if (
message.message_info.additional_config is not None message.message_info.additional_config is not None
and message.message_info.additional_config.get("is_mentioned") is not None and message.message_info.additional_config.get("is_mentioned") is not None
@@ -826,20 +826,48 @@ def parse_keywords_string(keywords_input) -> list[str]:
return [keywords_str] if keywords_str else [] return [keywords_str] if keywords_str else []
def cut_key_words(concept_name: str) -> list[str]: def cut_key_words(concept_name: str) -> list[str]:
"""对概念名称进行jieba分词并过滤掉关键词列表中的关键词""" """对概念名称进行jieba分词并过滤掉关键词列表中的关键词"""
concept_name_tokens = list(jieba.cut(concept_name)) concept_name_tokens = list(jieba.cut(concept_name))
# 定义常见连词、停用词与标点 # 定义常见连词、停用词与标点
conjunctions = { conjunctions = {"", "", "", "", "以及", "并且", "而且", "", "或者", ""}
"", "", "", "", "以及", "并且", "而且", "", "或者", ""
}
stop_words = { stop_words = {
"", "", "", "", "", "", "", "", "", "", "", "", "",
"", "", "", "", "", "", "", "", "", "", "", "", "",
"", "", "", "", "", "", "", "而且", "或者", "", "以及" "",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"而且",
"或者",
"",
"以及",
} }
chinese_punctuations = set(",。!?、;:()【】《》“”‘’—…·-——,.!?;:()[]<>'\"/\\") chinese_punctuations = set(",。!?、;:()【】《》“”‘’—…·-——,.!?;:()[]<>'\"/\\")
@@ -864,11 +892,16 @@ def cut_key_words(concept_name: str) -> list[str]:
left = merged_tokens[-1] left = merged_tokens[-1]
right = cleaned_tokens[i + 1] right = cleaned_tokens[i + 1]
# 左右都需要是有效词 # 左右都需要是有效词
if left and right \ if (
and left not in conjunctions and right not in conjunctions \ left
and left not in stop_words and right not in stop_words \ and right
and not all(ch in chinese_punctuations for ch in left) \ and left not in conjunctions
and not all(ch in chinese_punctuations for ch in right): and right not in conjunctions
and left not in stop_words
and right not in stop_words
and not all(ch in chinese_punctuations for ch in left)
and not all(ch in chinese_punctuations for ch in right)
):
# 合并为一个新词,并替换掉左侧与跳过右侧 # 合并为一个新词,并替换掉左侧与跳过右侧
combined = f"{left}{tok}{right}" combined = f"{left}{tok}{right}"
merged_tokens[-1] = combined merged_tokens[-1] = combined
@@ -889,7 +922,7 @@ def cut_key_words(concept_name: str) -> list[str]:
if tok in stop_words: if tok in stop_words:
continue continue
# if tok in ban_words: # if tok in ban_words:
# continue # continue
if all(ch in chinese_punctuations for ch in tok): if all(ch in chinese_punctuations for ch in tok):
continue continue
if tok.strip() == "": if tok.strip() == "":
@@ -899,4 +932,4 @@ def cut_key_words(concept_name: str) -> list[str]:
result_tokens.append(tok) result_tokens.append(tok)
filtered_concept_name_tokens = result_tokens filtered_concept_name_tokens = result_tokens
return filtered_concept_name_tokens return filtered_concept_name_tokens

View File

@@ -91,9 +91,10 @@ class ImageManager:
desc_obj.save() desc_obj.save()
except Exception as e: except Exception as e:
logger.error(f"保存描述到数据库失败 (Peewee): {str(e)}") logger.error(f"保存描述到数据库失败 (Peewee): {str(e)}")
async def get_emoji_tag(self, image_base64: str) -> str: async def get_emoji_tag(self, image_base64: str) -> str:
from src.chat.emoji_system.emoji_manager import get_emoji_manager from src.chat.emoji_system.emoji_manager import get_emoji_manager
emoji_manager = get_emoji_manager() emoji_manager = get_emoji_manager()
if isinstance(image_base64, str): if isinstance(image_base64, str):
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii") image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
@@ -120,6 +121,7 @@ class ImageManager:
# 优先使用EmojiManager查询已注册表情包的描述 # 优先使用EmojiManager查询已注册表情包的描述
try: try:
from src.chat.emoji_system.emoji_manager import get_emoji_manager from src.chat.emoji_system.emoji_manager import get_emoji_manager
emoji_manager = get_emoji_manager() emoji_manager = get_emoji_manager()
tags = await emoji_manager.get_emoji_tag_by_hash(image_hash) tags = await emoji_manager.get_emoji_tag_by_hash(image_hash)
if tags: if tags:

View File

@@ -205,6 +205,7 @@ class DatabaseMessages(BaseDataModel):
"chat_info_user_cardname": self.chat_info.user_info.user_cardname, "chat_info_user_cardname": self.chat_info.user_info.user_cardname,
} }
@dataclass(init=False) @dataclass(init=False)
class DatabaseActionRecords(BaseDataModel): class DatabaseActionRecords(BaseDataModel):
def __init__( def __init__(
@@ -232,4 +233,4 @@ class DatabaseActionRecords(BaseDataModel):
self.action_prompt_display = action_prompt_display self.action_prompt_display = action_prompt_display
self.chat_id = chat_id self.chat_id = chat_id
self.chat_info_stream_id = chat_info_stream_id self.chat_info_stream_id = chat_info_stream_id
self.chat_info_platform = chat_info_platform self.chat_info_platform = chat_info_platform

View File

@@ -2,10 +2,12 @@ from dataclasses import dataclass
from typing import Optional, List, TYPE_CHECKING from typing import Optional, List, TYPE_CHECKING
from . import BaseDataModel from . import BaseDataModel
if TYPE_CHECKING: if TYPE_CHECKING:
from src.common.data_models.message_data_model import ReplySetModel from src.common.data_models.message_data_model import ReplySetModel
from src.llm_models.payload_content.tool_option import ToolCall from src.llm_models.payload_content.tool_option import ToolCall
@dataclass @dataclass
class LLMGenerationDataModel(BaseDataModel): class LLMGenerationDataModel(BaseDataModel):
content: Optional[str] = None content: Optional[str] = None

View File

@@ -135,7 +135,7 @@ class Messages(BaseModel):
interest_value = DoubleField(null=True) interest_value = DoubleField(null=True)
key_words = TextField(null=True) key_words = TextField(null=True)
key_words_lite = TextField(null=True) key_words_lite = TextField(null=True)
is_mentioned = BooleanField(null=True) is_mentioned = BooleanField(null=True)
is_at = BooleanField(null=True) is_at = BooleanField(null=True)
reply_probability_boost = DoubleField(null=True) reply_probability_boost = DoubleField(null=True)
@@ -169,7 +169,7 @@ class Messages(BaseModel):
is_picid = BooleanField(default=False) is_picid = BooleanField(default=False)
is_command = BooleanField(default=False) is_command = BooleanField(default=False)
is_notify = BooleanField(default=False) is_notify = BooleanField(default=False)
selected_expressions = TextField(null=True) selected_expressions = TextField(null=True)
class Meta: class Meta:
@@ -267,13 +267,10 @@ class PersonInfo(BaseModel):
know_times = FloatField(null=True) # 认识时间 (时间戳) know_times = FloatField(null=True) # 认识时间 (时间戳)
know_since = FloatField(null=True) # 首次印象总结时间 know_since = FloatField(null=True) # 首次印象总结时间
last_know = FloatField(null=True) # 最后一次印象总结时间 last_know = FloatField(null=True) # 最后一次印象总结时间
attitude_to_me = TextField(null=True) # 对bot的态度 attitude_to_me = TextField(null=True) # 对bot的态度
attitude_to_me_confidence = FloatField(null=True) # 对bot的态度置信度 attitude_to_me_confidence = FloatField(null=True) # 对bot的态度置信度
class Meta: class Meta:
# database = db # 继承自 BaseModel # database = db # 继承自 BaseModel
table_name = "person_info" table_name = "person_info"
@@ -299,6 +296,7 @@ class GroupInfo(BaseModel):
# database = db # 继承自 BaseModel # database = db # 继承自 BaseModel
table_name = "group_info" table_name = "group_info"
class Expression(BaseModel): class Expression(BaseModel):
""" """
用于存储表达风格的模型。 用于存储表达风格的模型。
@@ -315,6 +313,7 @@ class Expression(BaseModel):
class Meta: class Meta:
table_name = "expression" table_name = "expression"
class GraphNodes(BaseModel): class GraphNodes(BaseModel):
""" """
用于存储记忆图节点的模型 用于存储记忆图节点的模型
@@ -374,7 +373,7 @@ def initialize_database(sync_constraints=False):
""" """
检查所有定义的表是否存在,如果不存在则创建它们。 检查所有定义的表是否存在,如果不存在则创建它们。
检查所有表的所有字段是否存在,如果缺失则自动添加。 检查所有表的所有字段是否存在,如果缺失则自动添加。
Args: Args:
sync_constraints (bool): 是否同步字段约束。默认为 False。 sync_constraints (bool): 是否同步字段约束。默认为 False。
如果为 True会检查并修复字段的 NULL 约束不一致问题。 如果为 True会检查并修复字段的 NULL 约束不一致问题。
@@ -456,13 +455,13 @@ def initialize_database(sync_constraints=False):
logger.info(f"字段 '{field_name}' 删除成功") logger.info(f"字段 '{field_name}' 删除成功")
except Exception as e: except Exception as e:
logger.error(f"删除字段 '{field_name}' 失败: {e}") logger.error(f"删除字段 '{field_name}' 失败: {e}")
# 如果启用了约束同步,执行约束检查和修复 # 如果启用了约束同步,执行约束检查和修复
if sync_constraints: if sync_constraints:
logger.debug("开始同步数据库字段约束...") logger.debug("开始同步数据库字段约束...")
sync_field_constraints() sync_field_constraints()
logger.debug("数据库字段约束同步完成") logger.debug("数据库字段约束同步完成")
except Exception as e: except Exception as e:
logger.exception(f"检查表或字段是否存在时出错: {e}") logger.exception(f"检查表或字段是否存在时出错: {e}")
# 如果检查失败(例如数据库不可用),则退出 # 如果检查失败(例如数据库不可用),则退出
@@ -476,7 +475,7 @@ def sync_field_constraints():
同步数据库字段约束,确保现有数据库字段的 NULL 约束与模型定义一致。 同步数据库字段约束,确保现有数据库字段的 NULL 约束与模型定义一致。
如果发现不一致,会自动修复字段约束。 如果发现不一致,会自动修复字段约束。
""" """
models = [ models = [
ChatStreams, ChatStreams,
LLMUsage, LLMUsage,
@@ -501,50 +500,55 @@ def sync_field_constraints():
continue continue
logger.debug(f"检查表 '{table_name}' 的字段约束...") logger.debug(f"检查表 '{table_name}' 的字段约束...")
# 获取当前表结构信息 # 获取当前表结构信息
cursor = db.execute_sql(f"PRAGMA table_info('{table_name}')") cursor = db.execute_sql(f"PRAGMA table_info('{table_name}')")
current_schema = {row[1]: {'type': row[2], 'notnull': bool(row[3]), 'default': row[4]} current_schema = {
for row in cursor.fetchall()} row[1]: {"type": row[2], "notnull": bool(row[3]), "default": row[4]} for row in cursor.fetchall()
}
# 检查每个模型字段的约束 # 检查每个模型字段的约束
constraints_to_fix = [] constraints_to_fix = []
for field_name, field_obj in model._meta.fields.items(): for field_name, field_obj in model._meta.fields.items():
if field_name not in current_schema: if field_name not in current_schema:
continue # 字段不存在,跳过 continue # 字段不存在,跳过
current_notnull = current_schema[field_name]['notnull'] current_notnull = current_schema[field_name]["notnull"]
model_allows_null = field_obj.null model_allows_null = field_obj.null
# 如果模型允许 null 但数据库字段不允许 null需要修复 # 如果模型允许 null 但数据库字段不允许 null需要修复
if model_allows_null and current_notnull: if model_allows_null and current_notnull:
constraints_to_fix.append({ constraints_to_fix.append(
'field_name': field_name, {
'field_obj': field_obj, "field_name": field_name,
'action': 'allow_null', "field_obj": field_obj,
'current_constraint': 'NOT NULL', "action": "allow_null",
'target_constraint': 'NULL' "current_constraint": "NOT NULL",
}) "target_constraint": "NULL",
}
)
logger.warning(f"字段 '{field_name}' 约束不一致: 模型允许NULL但数据库为NOT NULL") logger.warning(f"字段 '{field_name}' 约束不一致: 模型允许NULL但数据库为NOT NULL")
# 如果模型不允许 null 但数据库字段允许 null也需要修复但要小心 # 如果模型不允许 null 但数据库字段允许 null也需要修复但要小心
elif not model_allows_null and not current_notnull: elif not model_allows_null and not current_notnull:
constraints_to_fix.append({ constraints_to_fix.append(
'field_name': field_name, {
'field_obj': field_obj, "field_name": field_name,
'action': 'disallow_null', "field_obj": field_obj,
'current_constraint': 'NULL', "action": "disallow_null",
'target_constraint': 'NOT NULL' "current_constraint": "NULL",
}) "target_constraint": "NOT NULL",
}
)
logger.warning(f"字段 '{field_name}' 约束不一致: 模型不允许NULL但数据库允许NULL") logger.warning(f"字段 '{field_name}' 约束不一致: 模型不允许NULL但数据库允许NULL")
# 修复约束不一致的字段 # 修复约束不一致的字段
if constraints_to_fix: if constraints_to_fix:
logger.info(f"'{table_name}' 需要修复 {len(constraints_to_fix)} 个字段约束") logger.info(f"'{table_name}' 需要修复 {len(constraints_to_fix)} 个字段约束")
_fix_table_constraints(table_name, model, constraints_to_fix) _fix_table_constraints(table_name, model, constraints_to_fix)
else: else:
logger.debug(f"'{table_name}' 的字段约束已同步") logger.debug(f"'{table_name}' 的字段约束已同步")
except Exception as e: except Exception as e:
logger.exception(f"同步字段约束时出错: {e}") logger.exception(f"同步字段约束时出错: {e}")
@@ -557,40 +561,39 @@ def _fix_table_constraints(table_name, model, constraints_to_fix):
try: try:
# 备份表名 # 备份表名
backup_table = f"{table_name}_backup_{int(datetime.datetime.now().timestamp())}" backup_table = f"{table_name}_backup_{int(datetime.datetime.now().timestamp())}"
logger.info(f"开始修复表 '{table_name}' 的字段约束...") logger.info(f"开始修复表 '{table_name}' 的字段约束...")
# 1. 创建备份表 # 1. 创建备份表
db.execute_sql(f"CREATE TABLE {backup_table} AS SELECT * FROM {table_name}") db.execute_sql(f"CREATE TABLE {backup_table} AS SELECT * FROM {table_name}")
logger.info(f"已创建备份表 '{backup_table}'") logger.info(f"已创建备份表 '{backup_table}'")
# 2. 删除原表 # 2. 删除原表
db.execute_sql(f"DROP TABLE {table_name}") db.execute_sql(f"DROP TABLE {table_name}")
logger.info(f"已删除原表 '{table_name}'") logger.info(f"已删除原表 '{table_name}'")
# 3. 重新创建表(使用当前模型定义) # 3. 重新创建表(使用当前模型定义)
db.create_tables([model]) db.create_tables([model])
logger.info(f"已重新创建表 '{table_name}' 使用新的约束") logger.info(f"已重新创建表 '{table_name}' 使用新的约束")
# 4. 从备份表恢复数据 # 4. 从备份表恢复数据
# 获取字段列表 # 获取字段列表
fields = list(model._meta.fields.keys()) fields = list(model._meta.fields.keys())
fields_str = ', '.join(fields) fields_str = ", ".join(fields)
# 对于需要从 NOT NULL 改为 NULL 的字段,直接复制数据 # 对于需要从 NOT NULL 改为 NULL 的字段,直接复制数据
# 对于需要从 NULL 改为 NOT NULL 的字段,需要处理 NULL 值 # 对于需要从 NULL 改为 NOT NULL 的字段,需要处理 NULL 值
insert_sql = f"INSERT INTO {table_name} ({fields_str}) SELECT {fields_str} FROM {backup_table}" insert_sql = f"INSERT INTO {table_name} ({fields_str}) SELECT {fields_str} FROM {backup_table}"
# 检查是否有字段需要从 NULL 改为 NOT NULL # 检查是否有字段需要从 NULL 改为 NOT NULL
null_to_notnull_fields = [ null_to_notnull_fields = [
constraint['field_name'] for constraint in constraints_to_fix constraint["field_name"] for constraint in constraints_to_fix if constraint["action"] == "disallow_null"
if constraint['action'] == 'disallow_null'
] ]
if null_to_notnull_fields: if null_to_notnull_fields:
# 需要处理 NULL 值,为这些字段设置默认值 # 需要处理 NULL 值,为这些字段设置默认值
logger.warning(f"字段 {null_to_notnull_fields} 将从允许NULL改为不允许NULL需要处理现有的NULL值") logger.warning(f"字段 {null_to_notnull_fields} 将从允许NULL改为不允许NULL需要处理现有的NULL值")
# 构建更复杂的 SELECT 语句来处理 NULL 值 # 构建更复杂的 SELECT 语句来处理 NULL 值
select_fields = [] select_fields = []
for field_name in fields: for field_name in fields:
@@ -607,21 +610,21 @@ def _fix_table_constraints(table_name, model, constraints_to_fix):
default_value = f"'{datetime.datetime.now()}'" default_value = f"'{datetime.datetime.now()}'"
else: else:
default_value = "''" default_value = "''"
select_fields.append(f"COALESCE({field_name}, {default_value}) as {field_name}") select_fields.append(f"COALESCE({field_name}, {default_value}) as {field_name}")
else: else:
select_fields.append(field_name) select_fields.append(field_name)
select_str = ', '.join(select_fields) select_str = ", ".join(select_fields)
insert_sql = f"INSERT INTO {table_name} ({fields_str}) SELECT {select_str} FROM {backup_table}" insert_sql = f"INSERT INTO {table_name} ({fields_str}) SELECT {select_str} FROM {backup_table}"
db.execute_sql(insert_sql) db.execute_sql(insert_sql)
logger.info(f"已从备份表恢复数据到 '{table_name}'") logger.info(f"已从备份表恢复数据到 '{table_name}'")
# 5. 验证数据完整性 # 5. 验证数据完整性
original_count = db.execute_sql(f"SELECT COUNT(*) FROM {backup_table}").fetchone()[0] original_count = db.execute_sql(f"SELECT COUNT(*) FROM {backup_table}").fetchone()[0]
new_count = db.execute_sql(f"SELECT COUNT(*) FROM {table_name}").fetchone()[0] new_count = db.execute_sql(f"SELECT COUNT(*) FROM {table_name}").fetchone()[0]
if original_count == new_count: if original_count == new_count:
logger.info(f"数据完整性验证通过: {original_count} 行数据") logger.info(f"数据完整性验证通过: {original_count} 行数据")
# 删除备份表 # 删除备份表
@@ -630,12 +633,14 @@ def _fix_table_constraints(table_name, model, constraints_to_fix):
else: else:
logger.error(f"数据完整性验证失败: 原始 {original_count} 行,新表 {new_count}") logger.error(f"数据完整性验证失败: 原始 {original_count} 行,新表 {new_count}")
logger.error(f"备份表 '{backup_table}' 已保留,请手动检查") logger.error(f"备份表 '{backup_table}' 已保留,请手动检查")
# 记录修复的约束 # 记录修复的约束
for constraint in constraints_to_fix: for constraint in constraints_to_fix:
logger.info(f"已修复字段 '{constraint['field_name']}': " logger.info(
f"{constraint['current_constraint']} -> {constraint['target_constraint']}") f"已修复字段 '{constraint['field_name']}': "
f"{constraint['current_constraint']} -> {constraint['target_constraint']}"
)
except Exception as e: except Exception as e:
logger.exception(f"修复表 '{table_name}' 约束时出错: {e}") logger.exception(f"修复表 '{table_name}' 约束时出错: {e}")
# 尝试恢复 # 尝试恢复
@@ -654,7 +659,7 @@ def check_field_constraints():
检查但不修复字段约束,返回不一致的字段信息。 检查但不修复字段约束,返回不一致的字段信息。
用于在修复前预览需要修复的内容。 用于在修复前预览需要修复的内容。
""" """
models = [ models = [
ChatStreams, ChatStreams,
LLMUsage, LLMUsage,
@@ -669,9 +674,9 @@ def check_field_constraints():
GraphEdges, GraphEdges,
ActionRecords, ActionRecords,
] ]
inconsistencies = {} inconsistencies = {}
try: try:
with db: with db:
for model in models: for model in models:
@@ -681,49 +686,49 @@ def check_field_constraints():
# 获取当前表结构信息 # 获取当前表结构信息
cursor = db.execute_sql(f"PRAGMA table_info('{table_name}')") cursor = db.execute_sql(f"PRAGMA table_info('{table_name}')")
current_schema = {row[1]: {'type': row[2], 'notnull': bool(row[3]), 'default': row[4]} current_schema = {
for row in cursor.fetchall()} row[1]: {"type": row[2], "notnull": bool(row[3]), "default": row[4]} for row in cursor.fetchall()
}
table_inconsistencies = [] table_inconsistencies = []
# 检查每个模型字段的约束 # 检查每个模型字段的约束
for field_name, field_obj in model._meta.fields.items(): for field_name, field_obj in model._meta.fields.items():
if field_name not in current_schema: if field_name not in current_schema:
continue continue
current_notnull = current_schema[field_name]['notnull'] current_notnull = current_schema[field_name]["notnull"]
model_allows_null = field_obj.null model_allows_null = field_obj.null
if model_allows_null and current_notnull: if model_allows_null and current_notnull:
table_inconsistencies.append({ table_inconsistencies.append(
'field_name': field_name, {
'issue': 'model_allows_null_but_db_not_null', "field_name": field_name,
'model_constraint': 'NULL', "issue": "model_allows_null_but_db_not_null",
'db_constraint': 'NOT NULL', "model_constraint": "NULL",
'recommended_action': 'allow_null' "db_constraint": "NOT NULL",
}) "recommended_action": "allow_null",
}
)
elif not model_allows_null and not current_notnull: elif not model_allows_null and not current_notnull:
table_inconsistencies.append({ table_inconsistencies.append(
'field_name': field_name, {
'issue': 'model_not_null_but_db_allows_null', "field_name": field_name,
'model_constraint': 'NOT NULL', "issue": "model_not_null_but_db_allows_null",
'db_constraint': 'NULL', "model_constraint": "NOT NULL",
'recommended_action': 'disallow_null' "db_constraint": "NULL",
}) "recommended_action": "disallow_null",
}
)
if table_inconsistencies: if table_inconsistencies:
inconsistencies[table_name] = table_inconsistencies inconsistencies[table_name] = table_inconsistencies
except Exception as e: except Exception as e:
logger.exception(f"检查字段约束时出错: {e}") logger.exception(f"检查字段约束时出错: {e}")
return inconsistencies
return inconsistencies
# 模块加载时调用初始化函数 # 模块加载时调用初始化函数
initialize_database(sync_constraints=True) initialize_database(sync_constraints=True)

View File

@@ -339,24 +339,18 @@ MODULE_COLORS = {
# 67 具体的颜色编号0-255这里是较暗的蓝色 # 67 具体的颜色编号0-255这里是较暗的蓝色
"sender": "\033[38;5;24m", # 67号色较暗的蓝色适合不显眼的日志 "sender": "\033[38;5;24m", # 67号色较暗的蓝色适合不显眼的日志
"send_api": "\033[38;5;24m", # 208号色橙色适合突出显示 "send_api": "\033[38;5;24m", # 208号色橙色适合突出显示
# 生成 # 生成
"replyer": "\033[38;5;208m", # 橙色 "replyer": "\033[38;5;208m", # 橙色
"llm_api": "\033[38;5;208m", # 橙色 "llm_api": "\033[38;5;208m", # 橙色
# 消息处理 # 消息处理
"chat": "\033[38;5;82m", # 亮蓝色 "chat": "\033[38;5;82m", # 亮蓝色
"chat_image": "\033[38;5;68m", # 浅蓝色 "chat_image": "\033[38;5;68m", # 浅蓝色
# emoji
#emoji
"emoji": "\033[38;5;214m", # 橙黄色,偏向橙色 "emoji": "\033[38;5;214m", # 橙黄色,偏向橙色
"emoji_api": "\033[38;5;214m", # 橙黄色,偏向橙色 "emoji_api": "\033[38;5;214m", # 橙黄色,偏向橙色
# 核心模块 # 核心模块
"main": "\033[1;97m", # 亮白色+粗体 (主程序) "main": "\033[1;97m", # 亮白色+粗体 (主程序)
"memory": "\033[38;5;34m", # 天蓝色 "memory": "\033[38;5;34m", # 天蓝色
"config": "\033[93m", # 亮黄色 "config": "\033[93m", # 亮黄色
"common": "\033[95m", # 亮紫色 "common": "\033[95m", # 亮紫色
"tools": "\033[96m", # 亮青色 "tools": "\033[96m", # 亮青色
@@ -367,9 +361,6 @@ MODULE_COLORS = {
"llm_models": "\033[36m", # 青色 "llm_models": "\033[36m", # 青色
"remote": "\033[38;5;242m", # 深灰色,更不显眼 "remote": "\033[38;5;242m", # 深灰色,更不显眼
"planner": "\033[36m", "planner": "\033[36m",
"relation": "\033[38;5;139m", # 柔和的紫色,不刺眼 "relation": "\033[38;5;139m", # 柔和的紫色,不刺眼
# 聊天相关模块 # 聊天相关模块
"normal_chat": "\033[38;5;81m", # 亮蓝绿色 "normal_chat": "\033[38;5;81m", # 亮蓝绿色
@@ -379,11 +370,9 @@ MODULE_COLORS = {
"background_tasks": "\033[38;5;240m", # 灰色 "background_tasks": "\033[38;5;240m", # 灰色
"chat_message": "\033[38;5;45m", # 青色 "chat_message": "\033[38;5;45m", # 青色
"chat_stream": "\033[38;5;51m", # 亮青色 "chat_stream": "\033[38;5;51m", # 亮青色
"message_storage": "\033[38;5;33m", # 深蓝色 "message_storage": "\033[38;5;33m", # 深蓝色
"expressor": "\033[38;5;166m", # 橙色 "expressor": "\033[38;5;166m", # 橙色
# 专注聊天模块 # 专注聊天模块
"memory_activator": "\033[38;5;117m", # 天蓝色 "memory_activator": "\033[38;5;117m", # 天蓝色
# 插件系统 # 插件系统
"plugins": "\033[31m", # 红色 "plugins": "\033[31m", # 红色
@@ -412,7 +401,6 @@ MODULE_COLORS = {
# 工具和实用模块 # 工具和实用模块
"prompt_build": "\033[38;5;105m", # 紫色 "prompt_build": "\033[38;5;105m", # 紫色
"chat_utils": "\033[38;5;111m", # 蓝色 "chat_utils": "\033[38;5;111m", # 蓝色
"maibot_statistic": "\033[38;5;129m", # 紫色 "maibot_statistic": "\033[38;5;129m", # 紫色
# 特殊功能插件 # 特殊功能插件
"mute_plugin": "\033[38;5;240m", # 灰色 "mute_plugin": "\033[38;5;240m", # 灰色
@@ -447,10 +435,8 @@ MODULE_ALIASES = {
"llm_api": "生成API", "llm_api": "生成API",
"emoji": "表情包", "emoji": "表情包",
"emoji_api": "表情包API", "emoji_api": "表情包API",
"chat": "所见", "chat": "所见",
"chat_image": "识图", "chat_image": "识图",
"action_manager": "动作", "action_manager": "动作",
"memory_activator": "记忆", "memory_activator": "记忆",
"tool_use": "工具", "tool_use": "工具",
@@ -460,7 +446,6 @@ MODULE_ALIASES = {
"memory": "记忆", "memory": "记忆",
"tool_executor": "工具", "tool_executor": "工具",
"hfc": "聊天节奏", "hfc": "聊天节奏",
"plugin_manager": "插件", "plugin_manager": "插件",
"relationship_builder": "关系", "relationship_builder": "关系",
"llm_models": "模型", "llm_models": "模型",

View File

@@ -114,7 +114,7 @@ def set_value_by_path(d, path, value):
if k not in d or not isinstance(d[k], dict): if k not in d or not isinstance(d[k], dict):
d[k] = {} d[k] = {}
d = d[k] d = d[k]
# 使用 tomlkit.item 来保持 TOML 格式 # 使用 tomlkit.item 来保持 TOML 格式
try: try:
d[path[-1]] = tomlkit.item(value) d[path[-1]] = tomlkit.item(value)
@@ -253,7 +253,7 @@ def _update_config_generic(config_name: str, template_name: str):
f"已自动将{config_name}配置 {'.'.join(path)} 的值从旧默认值 {old_default} 更新为新默认值 {new_default}" f"已自动将{config_name}配置 {'.'.join(path)} 的值从旧默认值 {old_default} 更新为新默认值 {new_default}"
) )
config_updated = True config_updated = True
# 如果配置有更新,立即保存到文件 # 如果配置有更新,立即保存到文件
if config_updated: if config_updated:
with open(old_config_path, "w", encoding="utf-8") as f: with open(old_config_path, "w", encoding="utf-8") as f:

View File

@@ -43,10 +43,11 @@ class PersonalityConfig(ConfigBase):
reply_style: str = "" reply_style: str = ""
"""表达风格""" """表达风格"""
interest: str = "" interest: str = ""
"""兴趣""" """兴趣"""
@dataclass @dataclass
class RelationshipConfig(ConfigBase): class RelationshipConfig(ConfigBase):
"""关系配置类""" """关系配置类"""
@@ -61,31 +62,30 @@ class ChatConfig(ConfigBase):
max_context_size: int = 18 max_context_size: int = 18
"""上下文长度""" """上下文长度"""
interest_rate_mode: Literal["fast", "accurate"] = "fast" interest_rate_mode: Literal["fast", "accurate"] = "fast"
"""兴趣值计算模式fast为快速计算accurate为精确计算""" """兴趣值计算模式fast为快速计算accurate为精确计算"""
mentioned_bot_reply: float = 1 mentioned_bot_reply: float = 1
"""提及 bot 必然回复1为100%回复0为不额外增幅""" """提及 bot 必然回复1为100%回复0为不额外增幅"""
planner_size: float = 1.5 planner_size: float = 1.5
"""副规划器大小越小麦麦的动作执行能力越精细但是消耗更多token调大可以缓解429类错误""" """副规划器大小越小麦麦的动作执行能力越精细但是消耗更多token调大可以缓解429类错误"""
at_bot_inevitable_reply: float = 1 at_bot_inevitable_reply: float = 1
"""@bot 必然回复1为100%回复0为不额外增幅""" """@bot 必然回复1为100%回复0为不额外增幅"""
talk_frequency: float = 0.5 talk_frequency: float = 0.5
"""回复频率阈值""" """回复频率阈值"""
# 合并后的时段频率配置 # 合并后的时段频率配置
talk_frequency_adjust: list[list[str]] = field(default_factory=lambda: []) talk_frequency_adjust: list[list[str]] = field(default_factory=lambda: [])
focus_value: float = 0.5 focus_value: float = 0.5
"""麦麦的专注思考能力越低越容易专注消耗token也越多""" """麦麦的专注思考能力越低越容易专注消耗token也越多"""
focus_value_adjust: list[list[str]] = field(default_factory=lambda: []) focus_value_adjust: list[list[str]] = field(default_factory=lambda: [])
""" """
统一的活跃度和专注度配置 统一的活跃度和专注度配置
格式:[["platform:chat_id:type", "HH:MM,frequency", "HH:MM,frequency", ...], ...] 格式:[["platform:chat_id:type", "HH:MM,frequency", "HH:MM,frequency", ...], ...]
@@ -110,7 +110,6 @@ class ChatConfig(ConfigBase):
- talk_frequency_adjust 控制回复频率,数值越高回复越频繁 - talk_frequency_adjust 控制回复频率,数值越高回复越频繁
- focus_value_adjust 控制专注思考能力数值越低越容易专注消耗token也越多 - focus_value_adjust 控制专注思考能力数值越低越容易专注消耗token也越多
""" """
@dataclass @dataclass
@@ -123,6 +122,7 @@ class MessageReceiveConfig(ConfigBase):
ban_msgs_regex: set[str] = field(default_factory=lambda: set()) ban_msgs_regex: set[str] = field(default_factory=lambda: set())
"""过滤正则表达式列表""" """过滤正则表达式列表"""
@dataclass @dataclass
class ExpressionConfig(ConfigBase): class ExpressionConfig(ConfigBase):
"""表达配置类""" """表达配置类"""

View File

@@ -174,7 +174,7 @@ class ClientRegistry:
return client_class(api_provider) return client_class(api_provider)
else: else:
raise KeyError(f"'{api_provider.client_type}' 类型的 Client 未注册") raise KeyError(f"'{api_provider.client_type}' 类型的 Client 未注册")
# 正常的缓存逻辑 # 正常的缓存逻辑
if api_provider.name not in self.client_instance_cache: if api_provider.name not in self.client_instance_cache:
if client_class := self.client_registry.get(api_provider.client_type): if client_class := self.client_registry.get(api_provider.client_type):

View File

@@ -531,7 +531,7 @@ class OpenaiClient(BaseClient):
# 添加详细的错误信息以便调试 # 添加详细的错误信息以便调试
logger.error(f"OpenAI API连接错误嵌入模型: {str(e)}") logger.error(f"OpenAI API连接错误嵌入模型: {str(e)}")
logger.error(f"错误类型: {type(e)}") logger.error(f"错误类型: {type(e)}")
if hasattr(e, '__cause__') and e.__cause__: if hasattr(e, "__cause__") and e.__cause__:
logger.error(f"底层错误: {str(e.__cause__)}") logger.error(f"底层错误: {str(e.__cause__)}")
raise NetworkConnectionError() from e raise NetworkConnectionError() from e
except APIStatusError as e: except APIStatusError as e:

View File

@@ -1,3 +1,3 @@
from .tool_option import ToolCall from .tool_option import ToolCall
__all__ = ["ToolCall"] __all__ = ["ToolCall"]

View File

@@ -48,8 +48,7 @@ def _json_schema_type_check(instance) -> str | None:
elif not isinstance(instance["name"], str) or instance["name"].strip() == "": elif not isinstance(instance["name"], str) or instance["name"].strip() == "":
return "schema的'name'字段必须是非空字符串" return "schema的'name'字段必须是非空字符串"
if "description" in instance and ( if "description" in instance and (
not isinstance(instance["description"], str) not isinstance(instance["description"], str) or instance["description"].strip() == ""
or instance["description"].strip() == ""
): ):
return "schema的'description'字段只能填入非空字符串" return "schema的'description'字段只能填入非空字符串"
if "schema" not in instance: if "schema" not in instance:
@@ -101,9 +100,7 @@ def _link_definitions(schema: dict[str, Any]) -> dict[str, Any]:
# 如果当前Schema是列表则遍历每个元素 # 如果当前Schema是列表则遍历每个元素
for i in range(len(sub_schema)): for i in range(len(sub_schema)):
if isinstance(sub_schema[i], dict): if isinstance(sub_schema[i], dict):
sub_schema[i] = link_definitions_recursive( sub_schema[i] = link_definitions_recursive(f"{path}/{str(i)}", sub_schema[i], defs)
f"{path}/{str(i)}", sub_schema[i], defs
)
else: else:
# 否则为字典 # 否则为字典
if "$defs" in sub_schema: if "$defs" in sub_schema:
@@ -125,9 +122,7 @@ def _link_definitions(schema: dict[str, Any]) -> dict[str, Any]:
for key, value in sub_schema.items(): for key, value in sub_schema.items():
if isinstance(value, (dict, list)): if isinstance(value, (dict, list)):
# 如果当前值是字典或列表,则递归调用 # 如果当前值是字典或列表,则递归调用
sub_schema[key] = link_definitions_recursive( sub_schema[key] = link_definitions_recursive(f"{path}/{key}", value, defs)
f"{path}/{key}", value, defs
)
return sub_schema return sub_schema
@@ -163,9 +158,7 @@ class RespFormat:
def _generate_schema_from_model(schema): def _generate_schema_from_model(schema):
json_schema = { json_schema = {
"name": schema.__name__, "name": schema.__name__,
"schema": _remove_defs( "schema": _remove_defs(_link_definitions(_remove_title(schema.model_json_schema()))),
_link_definitions(_remove_title(schema.model_json_schema()))
),
"strict": False, "strict": False,
} }
if schema.__doc__: if schema.__doc__:

View File

@@ -155,7 +155,13 @@ class LLMUsageRecorder:
logger.error(f"创建 LLMUsage 表失败: {str(e)}") logger.error(f"创建 LLMUsage 表失败: {str(e)}")
def record_usage_to_database( def record_usage_to_database(
self, model_info: ModelInfo, model_usage: UsageRecord, user_id: str, request_type: str, endpoint: str, time_cost: float = 0.0 self,
model_info: ModelInfo,
model_usage: UsageRecord,
user_id: str,
request_type: str,
endpoint: str,
time_cost: float = 0.0,
): ):
input_cost = (model_usage.prompt_tokens / 1000000) * model_info.price_in input_cost = (model_usage.prompt_tokens / 1000000) * model_info.price_in
output_cost = (model_usage.completion_tokens / 1000000) * model_info.price_out output_cost = (model_usage.completion_tokens / 1000000) * model_info.price_out
@@ -173,7 +179,7 @@ class LLMUsageRecorder:
completion_tokens=model_usage.completion_tokens or 0, completion_tokens=model_usage.completion_tokens or 0,
total_tokens=model_usage.total_tokens or 0, total_tokens=model_usage.total_tokens or 0,
cost=total_cost or 0.0, cost=total_cost or 0.0,
time_cost = round(time_cost or 0.0, 3), time_cost=round(time_cost or 0.0, 3),
status="success", status="success",
timestamp=datetime.now(), # Peewee 会处理 DateTimeField timestamp=datetime.now(), # Peewee 会处理 DateTimeField
) )
@@ -186,4 +192,5 @@ class LLMUsageRecorder:
except Exception as e: except Exception as e:
logger.error(f"记录token使用情况失败: {str(e)}") logger.error(f"记录token使用情况失败: {str(e)}")
llm_usage_recorder = LLMUsageRecorder()
llm_usage_recorder = LLMUsageRecorder()

View File

@@ -14,31 +14,31 @@ logger = get_logger("context_web")
class ContextMessage: class ContextMessage:
"""上下文消息类""" """上下文消息类"""
def __init__(self, message: MessageRecv):
    """Build a display-ready snapshot of a received message.

    Copies the sender name/id, the processed text and the group name from
    ``message``, stamps the snapshot with the current local time, and
    rewrites ``content`` for gift and superchat messages.

    Args:
        message: The received message to snapshot.
    """
    info = message.message_info
    sender = info.user_info
    self.user_name = sender.user_nickname
    self.user_id = sender.user_id
    self.content = message.processed_plain_text
    self.timestamp = datetime.now()

    # Messages without group info are labelled as private chats.
    group = info.group_info
    self.group_name = group.group_name if group else "私聊"

    # Message-type flags; a message object without these attributes counts as a normal message.
    self.is_gift = getattr(message, "is_gift", False)
    self.is_superchat = getattr(message, "is_superchat", False)

    if self.is_gift:
        # Gift: replace the text with "<gift name> x<count>".
        self.gift_name = getattr(message, "gift_name", "")
        self.gift_count = getattr(message, "gift_count", "1")
        self.content = f"送出了 {self.gift_name} x{self.gift_count}"
        return

    if not self.is_superchat:
        return

    # Superchat: prefix the text with the price, preferring the dedicated
    # superchat text over the plain processed text when it is non-empty.
    # NOTE(review): the "]" has no visible opening bracket — confirm the intended prefix format.
    self.superchat_price = getattr(message, "superchat_price", "0")
    self.superchat_message = getattr(message, "superchat_message_text", "")
    shown_text = self.superchat_message or self.content
    self.content = f"{self.superchat_price}] {shown_text}"
def to_dict(self): def to_dict(self):
return { return {
"user_name": self.user_name, "user_name": self.user_name,
@@ -47,13 +47,13 @@ class ContextMessage:
"timestamp": self.timestamp.strftime("%m-%d %H:%M:%S"), "timestamp": self.timestamp.strftime("%m-%d %H:%M:%S"),
"group_name": self.group_name, "group_name": self.group_name,
"is_gift": self.is_gift, "is_gift": self.is_gift,
"is_superchat": self.is_superchat "is_superchat": self.is_superchat,
} }
class ContextWebManager: class ContextWebManager:
"""上下文网页管理器""" """上下文网页管理器"""
def __init__(self, max_messages: int = 10, port: int = 8765): def __init__(self, max_messages: int = 10, port: int = 8765):
self.max_messages = max_messages self.max_messages = max_messages
self.port = port self.port = port
@@ -63,53 +63,53 @@ class ContextWebManager:
self.runner = None self.runner = None
self.site = None self.site = None
self._server_starting = False # 添加启动标志防止并发 self._server_starting = False # 添加启动标志防止并发
async def start_server(self): async def start_server(self):
"""启动web服务器""" """启动web服务器"""
if self.site is not None: if self.site is not None:
logger.debug("Web服务器已经启动跳过重复启动") logger.debug("Web服务器已经启动跳过重复启动")
return return
if self._server_starting: if self._server_starting:
logger.debug("Web服务器正在启动中等待启动完成...") logger.debug("Web服务器正在启动中等待启动完成...")
# 等待启动完成 # 等待启动完成
while self._server_starting and self.site is None: while self._server_starting and self.site is None:
await asyncio.sleep(0.1) await asyncio.sleep(0.1)
return return
self._server_starting = True self._server_starting = True
try: try:
self.app = web.Application() self.app = web.Application()
# 设置CORS # 设置CORS
cors = aiohttp_cors.setup(self.app, defaults={ cors = aiohttp_cors.setup(
"*": aiohttp_cors.ResourceOptions( self.app,
allow_credentials=True, defaults={
expose_headers="*", "*": aiohttp_cors.ResourceOptions(
allow_headers="*", allow_credentials=True, expose_headers="*", allow_headers="*", allow_methods="*"
allow_methods="*" )
) },
}) )
# 添加路由 # 添加路由
self.app.router.add_get('/', self.index_handler) self.app.router.add_get("/", self.index_handler)
self.app.router.add_get('/ws', self.websocket_handler) self.app.router.add_get("/ws", self.websocket_handler)
self.app.router.add_get('/api/contexts', self.get_contexts_handler) self.app.router.add_get("/api/contexts", self.get_contexts_handler)
self.app.router.add_get('/debug', self.debug_handler) self.app.router.add_get("/debug", self.debug_handler)
# 为所有路由添加CORS # 为所有路由添加CORS
for route in list(self.app.router.routes()): for route in list(self.app.router.routes()):
cors.add(route) cors.add(route)
self.runner = web.AppRunner(self.app) self.runner = web.AppRunner(self.app)
await self.runner.setup() await self.runner.setup()
self.site = web.TCPSite(self.runner, 'localhost', self.port) self.site = web.TCPSite(self.runner, "localhost", self.port)
await self.site.start() await self.site.start()
logger.info(f"🌐 上下文网页服务器启动成功在 http://localhost:{self.port}") logger.info(f"🌐 上下文网页服务器启动成功在 http://localhost:{self.port}")
except Exception as e: except Exception as e:
logger.error(f"❌ 启动Web服务器失败: {e}") logger.error(f"❌ 启动Web服务器失败: {e}")
# 清理部分启动的资源 # 清理部分启动的资源
@@ -121,7 +121,7 @@ class ContextWebManager:
raise raise
finally: finally:
self._server_starting = False self._server_starting = False
async def stop_server(self): async def stop_server(self):
"""停止web服务器""" """停止web服务器"""
if self.site: if self.site:
@@ -132,10 +132,11 @@ class ContextWebManager:
self.runner = None self.runner = None
self.site = None self.site = None
self._server_starting = False self._server_starting = False
async def index_handler(self, request): async def index_handler(self, request):
"""主页处理器""" """主页处理器"""
html_content = ''' html_content = (
"""
<!DOCTYPE html> <!DOCTYPE html>
<html> <html>
<head> <head>
@@ -286,7 +287,9 @@ class ContextWebManager:
function connectWebSocket() { function connectWebSocket() {
console.log('正在连接WebSocket...'); console.log('正在连接WebSocket...');
ws = new WebSocket('ws://localhost:''' + str(self.port) + '''/ws'); ws = new WebSocket('ws://localhost:"""
+ str(self.port)
+ """/ws');
ws.onopen = function() { ws.onopen = function() {
console.log('WebSocket连接已建立'); console.log('WebSocket连接已建立');
@@ -470,47 +473,48 @@ class ContextWebManager:
</script> </script>
</body> </body>
</html> </html>
''' """
return web.Response(text=html_content, content_type='text/html') )
return web.Response(text=html_content, content_type="text/html")
async def websocket_handler(self, request):
    """Serve one WebSocket client until it disconnects or reports an error.

    The connection is registered in ``self.websockets`` so that broadcasts
    reach it, receives the current contexts once, and is then read until
    the peer closes or an ERROR frame arrives.

    Args:
        request: The incoming HTTP request to upgrade to a WebSocket.

    Returns:
        The ``web.WebSocketResponse`` that handled the connection.
    """
    ws = web.WebSocketResponse()
    await ws.prepare(request)

    self.websockets.append(ws)
    logger.debug(f"WebSocket连接建立当前连接数: {len(self.websockets)}")

    try:
        # Send the initial data.
        await self.send_contexts_to_websocket(ws)

        async for msg in ws:
            if msg.type == WSMsgType.ERROR:
                logger.error(f"WebSocket错误: {ws.exception()}")
                break
    finally:
        # Always unregister the connection, even when the initial send or the
        # receive loop raises (or the handler task is cancelled); otherwise a
        # dead socket would stay in the broadcast list.
        if ws in self.websockets:
            self.websockets.remove(ws)
        logger.debug(f"WebSocket连接断开当前连接数: {len(self.websockets)}")

    return ws
async def get_contexts_handler(self, request):
    """Return the newest context messages across all chats as JSON.

    Messages from every chat are merged, ordered by timestamp (oldest
    first), and the last ``self.max_messages`` of them are serialised.

    Args:
        request: The incoming HTTP request (unused).

    Returns:
        A JSON response of the form ``{"contexts": [...]}``.
    """
    # Merge every chat's queue and order by time, newest last.
    merged = sorted(
        (entry for chat_queue in self.contexts.values() for entry in chat_queue),
        key=lambda entry: entry.timestamp,
    )

    # Keep only the tail and convert each message to its dict form.
    contexts_data = [entry.to_dict() for entry in merged[-self.max_messages :]]

    logger.debug(f"返回上下文数据,共 {len(contexts_data)} 条消息")
    return web.json_response({"contexts": contexts_data})
async def debug_handler(self, request): async def debug_handler(self, request):
"""调试信息处理器""" """调试信息处理器"""
debug_info = { debug_info = {
@@ -519,7 +523,7 @@ class ContextWebManager:
"total_chats": len(self.contexts), "total_chats": len(self.contexts),
"total_messages": sum(len(contexts) for contexts in self.contexts.values()), "total_messages": sum(len(contexts) for contexts in self.contexts.values()),
} }
# 构建聊天详情HTML # 构建聊天详情HTML
chats_html = "" chats_html = ""
for chat_id, contexts in self.contexts.items(): for chat_id, contexts in self.contexts.items():
@@ -528,15 +532,15 @@ class ContextWebManager:
timestamp = msg.timestamp.strftime("%H:%M:%S") timestamp = msg.timestamp.strftime("%H:%M:%S")
content = msg.content[:50] + "..." if len(msg.content) > 50 else msg.content content = msg.content[:50] + "..." if len(msg.content) > 50 else msg.content
messages_html += f'<div class="message">[{timestamp}] {msg.user_name}: {content}</div>' messages_html += f'<div class="message">[{timestamp}] {msg.user_name}: {content}</div>'
chats_html += f''' chats_html += f"""
<div class="chat"> <div class="chat">
<h3>聊天 {chat_id} ({len(contexts)} 条消息)</h3> <h3>聊天 {chat_id} ({len(contexts)} 条消息)</h3>
{messages_html} {messages_html}
</div> </div>
''' """
html_content = f''' html_content = f"""
<!DOCTYPE html> <!DOCTYPE html>
<html> <html>
<head> <head>
@@ -578,74 +582,78 @@ class ContextWebManager:
</script> </script>
</body> </body>
</html> </html>
''' """
return web.Response(text=html_content, content_type='text/html') return web.Response(text=html_content, content_type="text/html")
async def add_message(self, chat_id: str, message: MessageRecv):
    """Store a new message for ``chat_id`` and broadcast the updated contexts.

    A bounded queue (``maxlen=self.max_messages``) is created for the chat
    on first use. After appending, the full context state is logged and
    every connected WebSocket is notified.

    Args:
        chat_id: Identifier of the chat the message belongs to.
        message: The received message to record.
    """
    chat_queue = self.contexts.get(chat_id)
    if chat_queue is None:
        chat_queue = deque(maxlen=self.max_messages)
        self.contexts[chat_id] = chat_queue
        logger.debug(f"为聊天 {chat_id} 创建新的上下文队列")

    entry = ContextMessage(message)
    chat_queue.append(entry)

    # Count the messages currently held across all chats.
    total_messages = sum(map(len, self.contexts.values()))

    logger.info(
        f"✅ 添加消息到上下文 [总数: {total_messages}]: [{entry.group_name}] {entry.user_name}: {entry.content}"
    )

    # Debug aid: dump every message currently held in the contexts.
    logger.info("📝 当前上下文中的所有消息:")
    for cid, queue in self.contexts.items():
        logger.info(f" 聊天 {cid}: {len(queue)} 条消息")
        for position, msg in enumerate(queue, start=1):
            logger.info(
                f" {position}. [{msg.timestamp.strftime('%H:%M:%S')}] {msg.user_name}: {msg.content[:30]}..."
            )

    # Push the update to all WebSocket connections.
    await self.broadcast_contexts()
async def send_contexts_to_websocket(self, ws: web.WebSocketResponse):
    """Send the newest context messages to a single WebSocket.

    Messages from all chats are merged, ordered by timestamp (oldest
    first), trimmed to the last ``self.max_messages`` and sent as one
    JSON string of the form ``{"contexts": [...]}``.

    Args:
        ws: The WebSocket connection to send to.
    """
    merged = [entry for chat_queue in self.contexts.values() for entry in chat_queue]

    # Order by time, newest last.
    merged.sort(key=lambda entry: entry.timestamp)

    # Keep only the tail and convert each message to its dict form.
    latest = merged[-self.max_messages :]
    payload = {"contexts": [entry.to_dict() for entry in latest]}

    # ensure_ascii=False keeps non-ASCII text readable in the payload.
    await ws.send_str(json.dumps(payload, ensure_ascii=False))
async def broadcast_contexts(self): async def broadcast_contexts(self):
"""向所有WebSocket连接广播上下文更新""" """向所有WebSocket连接广播上下文更新"""
if not self.websockets: if not self.websockets:
logger.debug("没有WebSocket连接跳过广播") logger.debug("没有WebSocket连接跳过广播")
return return
all_context_msgs = [] all_context_msgs = []
for _chat_id, contexts in self.contexts.items(): for _chat_id, contexts in self.contexts.items():
all_context_msgs.extend(list(contexts)) all_context_msgs.extend(list(contexts))
# 按时间排序,最新的在最后 # 按时间排序,最新的在最后
all_context_msgs.sort(key=lambda x: x.timestamp) all_context_msgs.sort(key=lambda x: x.timestamp)
# 转换为字典格式 # 转换为字典格式
contexts_data = [msg.to_dict() for msg in all_context_msgs[-self.max_messages:]] contexts_data = [msg.to_dict() for msg in all_context_msgs[-self.max_messages :]]
data = {"contexts": contexts_data} data = {"contexts": contexts_data}
message = json.dumps(data, ensure_ascii=False) message = json.dumps(data, ensure_ascii=False)
logger.info(f"广播 {len(contexts_data)} 条消息到 {len(self.websockets)} 个WebSocket连接") logger.info(f"广播 {len(contexts_data)} 条消息到 {len(self.websockets)} 个WebSocket连接")
# 创建WebSocket列表的副本避免在遍历时修改 # 创建WebSocket列表的副本避免在遍历时修改
websockets_copy = self.websockets.copy() websockets_copy = self.websockets.copy()
removed_count = 0 removed_count = 0
for ws in websockets_copy: for ws in websockets_copy:
if ws.closed: if ws.closed:
if ws in self.websockets: if ws in self.websockets:
@@ -660,7 +668,7 @@ class ContextWebManager:
if ws in self.websockets: if ws in self.websockets:
self.websockets.remove(ws) self.websockets.remove(ws)
removed_count += 1 removed_count += 1
if removed_count > 0: if removed_count > 0:
logger.debug(f"清理了 {removed_count} 个断开的WebSocket连接") logger.debug(f"清理了 {removed_count} 个断开的WebSocket连接")
@@ -681,5 +689,4 @@ async def init_context_web_manager():
"""初始化上下文网页管理器""" """初始化上下文网页管理器"""
manager = get_context_web_manager() manager = get_context_web_manager()
await manager.start_server() await manager.start_server()
return manager return manager

View File

@@ -11,6 +11,7 @@ logger = get_logger("gift_manager")
@dataclass @dataclass
class PendingGift: class PendingGift:
"""等待中的礼物消息""" """等待中的礼物消息"""
message: MessageRecvS4U message: MessageRecvS4U
total_count: int total_count: int
timer_task: asyncio.Task timer_task: asyncio.Task
@@ -19,71 +20,68 @@ class PendingGift:
class GiftManager: class GiftManager:
"""礼物管理器,提供防抖功能""" """礼物管理器,提供防抖功能"""
def __init__(self): def __init__(self):
"""初始化礼物管理器""" """初始化礼物管理器"""
self.pending_gifts: Dict[Tuple[str, str], PendingGift] = {} self.pending_gifts: Dict[Tuple[str, str], PendingGift] = {}
self.debounce_timeout = 5.0 # 3秒防抖时间 self.debounce_timeout = 5.0 # 3秒防抖时间
async def handle_gift(
    self, message: MessageRecvS4U, callback: Optional[Callable[[MessageRecvS4U], None]] = None
) -> bool:
    """Debounce a gift message and report whether it should be handled now.

    Args:
        message: The gift message.
        callback: Callback to run once the debounce window has elapsed.
            It is only stored when a new pending gift is created; it is
            not passed along when the message is merged into an existing one.

    Returns:
        bool: ``True`` when the message is not a gift and should be handled
        immediately; ``False`` when it was held back for debouncing.
    """
    if not message.is_gift:
        return True

    # Unique key of a gift: (sender id, gift name).
    gift_key = (message.message_info.user_info.user_id, message.gift_name)

    if gift_key in self.pending_gifts:
        # The same gift is already waiting: merge into it.
        await self._merge_gift(gift_key, message)
    else:
        # First occurrence: start a new pending gift.
        await self._create_pending_gift(gift_key, message, callback)

    return False
async def _merge_gift(self, gift_key: Tuple[str, str], new_message: MessageRecvS4U) -> None:
    """Merge a repeated gift into the pending one and restart its debounce timer.

    Args:
        gift_key: ``(sender id, gift name)`` key of an entry that already
            exists in ``self.pending_gifts``.
        new_message: The newly received gift message.
    """
    pending_gift = self.pending_gifts[gift_key]

    # Stop the previous timer; cancel() does nothing on a task that is
    # already finished or cancelled.
    pending_gift.timer_task.cancel()

    # Accumulate the gift count.
    try:
        new_count = int(new_message.gift_count)
    except (TypeError, ValueError):
        # TypeError covers a missing (None) or otherwise non-string count.
        # Either way the timer below must still be re-created, or this gift
        # would stay pending forever. The existing count is left unchanged.
        logger.warning(f"无法解析礼物数量: {new_message.gift_count}")
    else:
        pending_gift.total_count += new_count

        # Keep the newest message, but carry the accumulated count on it.
        pending_gift.message = new_message
        pending_gift.message.gift_count = str(pending_gift.total_count)
        pending_gift.message.gift_info = f"{pending_gift.message.gift_name}:{pending_gift.total_count}"

    # Start a fresh debounce timer.
    pending_gift.timer_task = asyncio.create_task(self._gift_timeout(gift_key))

    logger.debug(f"合并礼物: {gift_key}, 总数量: {pending_gift.total_count}")
async def _create_pending_gift( async def _create_pending_gift(
self, self, gift_key: Tuple[str, str], message: MessageRecvS4U, callback: Optional[Callable[[MessageRecvS4U], None]]
gift_key: Tuple[str, str],
message: MessageRecvS4U,
callback: Optional[Callable[[MessageRecvS4U], None]]
) -> None: ) -> None:
"""创建新的等待礼物""" """创建新的等待礼物"""
try: try:
@@ -91,56 +89,51 @@ class GiftManager:
except ValueError: except ValueError:
initial_count = 1 initial_count = 1
logger.warning(f"无法解析礼物数量: {message.gift_count}默认设为1") logger.warning(f"无法解析礼物数量: {message.gift_count}默认设为1")
# 创建定时器任务 # 创建定时器任务
timer_task = asyncio.create_task(self._gift_timeout(gift_key)) timer_task = asyncio.create_task(self._gift_timeout(gift_key))
# 创建等待礼物对象 # 创建等待礼物对象
pending_gift = PendingGift( pending_gift = PendingGift(message=message, total_count=initial_count, timer_task=timer_task, callback=callback)
message=message,
total_count=initial_count,
timer_task=timer_task,
callback=callback
)
self.pending_gifts[gift_key] = pending_gift self.pending_gifts[gift_key] = pending_gift
logger.debug(f"创建等待礼物: {gift_key}, 初始数量: {initial_count}") logger.debug(f"创建等待礼物: {gift_key}, 初始数量: {initial_count}")
async def _gift_timeout(self, gift_key: Tuple[str, str]) -> None:
    """Debounce timer body for one pending gift.

    Sleeps for ``self.debounce_timeout``; if the entry for ``gift_key`` is still
    pending afterwards, removes it, rewrites the stored message's
    ``processed_plain_text`` with the final total, and invokes the entry's
    callback (if any) with that message.

    Cancellation ends the coroutine quietly; any other exception is logged.
    """
    try:
        await asyncio.sleep(self.debounce_timeout)

        entry = self.pending_gifts.pop(gift_key, None)
        if entry is None:
            # Nothing is waiting under this key any more.
            return

        final_count = entry.total_count
        logger.info(f"礼物防抖完成: {gift_key}, 最终数量: {final_count}")

        merged = entry.message
        sender = merged.message_info.user_info.user_nickname
        merged.processed_plain_text = f"用户{sender}送出了礼物{merged.gift_name} x{final_count}"

        notify = entry.callback
        if notify:
            try:
                notify(merged)
            except Exception as err:
                # A failing callback is logged and does not propagate.
                logger.error(f"礼物回调执行失败: {err}", exc_info=True)
    except asyncio.CancelledError:
        # The timer was cancelled; nothing to do.
        return
    except Exception as err:
        logger.error(f"礼物防抖处理异常: {err}", exc_info=True)
def get_pending_count(self) -> int:
    """Return how many entries are currently held in ``self.pending_gifts``."""
    waiting = self.pending_gifts
    return len(waiting)
async def flush_all(self) -> None: async def flush_all(self) -> None:
"""立即处理所有等待中的礼物""" """立即处理所有等待中的礼物"""
for gift_key in list(self.pending_gifts.keys()): for gift_key in list(self.pending_gifts.keys()):
@@ -152,4 +145,3 @@ class GiftManager:
# 创建全局礼物管理器实例 # 创建全局礼物管理器实例
gift_manager = GiftManager() gift_manager = GiftManager()

View File

@@ -1,14 +1,15 @@
class InternalManager:
    """Stores one "internal state" string and renders it into a fixed prompt sentence."""

    def __init__(self):
        # Empty until set_internal_state() is called.
        self.now_internal_state = ""

    def set_internal_state(self, internal_state: str):
        """Replace the stored state text with ``internal_state``."""
        self.now_internal_state = internal_state

    def get_internal_state(self):
        """Return the stored state text as-is."""
        return self.now_internal_state

    def get_internal_state_str(self):
        """Return the fixed prompt sentence with the stored state appended."""
        state = self.now_internal_state
        return f"你今天的直播内容是直播QQ水群你正在一边回复弹幕一边在QQ群聊天你在QQ群聊天中产生的想法是{state}"
internal_manager = InternalManager()
internal_manager = InternalManager()

View File

@@ -16,7 +16,6 @@ import json
from .s4u_mood_manager import mood_manager from .s4u_mood_manager import mood_manager
from src.mais4u.s4u_config import s4u_config from src.mais4u.s4u_config import s4u_config
from src.person_info.person_info import get_person_id from src.person_info.person_info import get_person_id
from .super_chat_manager import get_super_chat_manager
from .yes_or_no import yes_or_no_head from .yes_or_no import yes_or_no_head
logger = get_logger("S4U_chat") logger = get_logger("S4U_chat")
@@ -33,15 +32,12 @@ class MessageSenderContainer:
self._task: Optional[asyncio.Task] = None self._task: Optional[asyncio.Task] = None
self._paused_event = asyncio.Event() self._paused_event = asyncio.Event()
self._paused_event.set() # 默认设置为非暂停状态 self._paused_event.set() # 默认设置为非暂停状态
self.msg_id = ""
self.last_msg_id = ""
self.voice_done = ""
self.msg_id = ""
self.last_msg_id = ""
self.voice_done = ""
async def add_message(self, chunk: str): async def add_message(self, chunk: str):
"""向队列中添加一个消息块。""" """向队列中添加一个消息块。"""
@@ -131,7 +127,7 @@ class MessageSenderContainer:
reply_to=f"{self.original_message.message_info.user_info.platform}:{self.original_message.message_info.user_info.user_id}", reply_to=f"{self.original_message.message_info.user_info.platform}:{self.original_message.message_info.user_info.user_id}",
) )
await bot_message.process() await bot_message.process()
await self.storage.store_message(bot_message, self.chat_stream) await self.storage.store_message(bot_message, self.chat_stream)
except Exception as e: except Exception as e:
@@ -198,12 +194,12 @@ class S4UChat:
self.gpt = S4UStreamGenerator() self.gpt = S4UStreamGenerator()
self.gpt.chat_stream = self.chat_stream self.gpt.chat_stream = self.chat_stream
self.interest_dict: Dict[str, float] = {} # 用户兴趣分 self.interest_dict: Dict[str, float] = {} # 用户兴趣分
self.internal_message :List[MessageRecvS4U] = [] self.internal_message: List[MessageRecvS4U] = []
self.msg_id = "" self.msg_id = ""
self.voice_done = "" self.voice_done = ""
logger.info(f"[{self.stream_name}] S4UChat with two-queue system initialized.") logger.info(f"[{self.stream_name}] S4UChat with two-queue system initialized.")
def _get_priority_info(self, message: MessageRecv) -> dict: def _get_priority_info(self, message: MessageRecv) -> dict:
@@ -226,7 +222,7 @@ class S4UChat:
def _get_interest_score(self, user_id: str) -> float: def _get_interest_score(self, user_id: str) -> float:
"""获取用户的兴趣分默认为1.0""" """获取用户的兴趣分默认为1.0"""
return self.interest_dict.get(user_id, 1.0) return self.interest_dict.get(user_id, 1.0)
def go_processing(self): def go_processing(self):
if self.voice_done == self.last_msg_id: if self.voice_done == self.last_msg_id:
return True return True
@@ -237,14 +233,14 @@ class S4UChat:
为消息计算基础优先级分数。分数越高,优先级越高。 为消息计算基础优先级分数。分数越高,优先级越高。
""" """
score = 0.0 score = 0.0
# 加上消息自带的优先级 # 加上消息自带的优先级
score += priority_info.get("message_priority", 0.0) score += priority_info.get("message_priority", 0.0)
# 加上用户的固有兴趣分 # 加上用户的固有兴趣分
score += self._get_interest_score(message.message_info.user_info.user_id) score += self._get_interest_score(message.message_info.user_info.user_id)
return score return score
def decay_interest_score(self): def decay_interest_score(self):
for person_id, score in self.interest_dict.items(): for person_id, score in self.interest_dict.items():
if score > 0: if score > 0:
@@ -252,15 +248,14 @@ class S4UChat:
else: else:
self.interest_dict[person_id] = 0 self.interest_dict[person_id] = 0
async def add_message(self, message: MessageRecvS4U|MessageRecv) -> None: async def add_message(self, message: MessageRecvS4U | MessageRecv) -> None:
self.decay_interest_score() self.decay_interest_score()
"""根据VIP状态和中断逻辑将消息放入相应队列。""" """根据VIP状态和中断逻辑将消息放入相应队列。"""
user_id = message.message_info.user_info.user_id user_id = message.message_info.user_info.user_id
platform = message.message_info.platform platform = message.message_info.platform
person_id = get_person_id(platform, user_id) _person_id = get_person_id(platform, user_id)
# try: # try:
# is_gift = message.is_gift # is_gift = message.is_gift
# is_superchat = message.is_superchat # is_superchat = message.is_superchat
@@ -276,7 +271,7 @@ class S4UChat:
# # 安全地增加兴趣分如果person_id不存在则先初始化为1.0 # # 安全地增加兴趣分如果person_id不存在则先初始化为1.0
# current_score = self.interest_dict.get(person_id, 1.0) # current_score = self.interest_dict.get(person_id, 1.0)
# self.interest_dict[person_id] = current_score + 0.1 * float(message.superchat_price) # self.interest_dict[person_id] = current_score + 0.1 * float(message.superchat_price)
# # 添加SuperChat到管理器 # # 添加SuperChat到管理器
# super_chat_manager = get_super_chat_manager() # super_chat_manager = get_super_chat_manager()
# await super_chat_manager.add_superchat(message) # await super_chat_manager.add_superchat(message)
@@ -284,16 +279,19 @@ class S4UChat:
# await self.relationship_builder.build_relation(20) # await self.relationship_builder.build_relation(20)
# except Exception: # except Exception:
# traceback.print_exc() # traceback.print_exc()
logger.info(f"[{self.stream_name}] 消息处理完毕,消息内容:{message.processed_plain_text}") logger.info(f"[{self.stream_name}] 消息处理完毕,消息内容:{message.processed_plain_text}")
priority_info = self._get_priority_info(message) priority_info = self._get_priority_info(message)
is_vip = self._is_vip(priority_info) is_vip = self._is_vip(priority_info)
new_priority_score = self._calculate_base_priority_score(message, priority_info) new_priority_score = self._calculate_base_priority_score(message, priority_info)
should_interrupt = False should_interrupt = False
if (s4u_config.enable_message_interruption and if (
self._current_generation_task and not self._current_generation_task.done()): s4u_config.enable_message_interruption
and self._current_generation_task
and not self._current_generation_task.done()
):
if self._current_message_being_replied: if self._current_message_being_replied:
current_queue, current_priority, _, current_msg = self._current_message_being_replied current_queue, current_priority, _, current_msg = self._current_message_being_replied
@@ -344,39 +342,45 @@ class S4UChat:
"""清理普通队列中不在最近N条消息范围内的消息""" """清理普通队列中不在最近N条消息范围内的消息"""
if not s4u_config.enable_old_message_cleanup or self._normal_queue.empty(): if not s4u_config.enable_old_message_cleanup or self._normal_queue.empty():
return return
# 计算阈值:保留最近 recent_message_keep_count 条消息 # 计算阈值:保留最近 recent_message_keep_count 条消息
cutoff_counter = max(0, self._entry_counter - s4u_config.recent_message_keep_count) cutoff_counter = max(0, self._entry_counter - s4u_config.recent_message_keep_count)
# 临时存储需要保留的消息 # 临时存储需要保留的消息
temp_messages = [] temp_messages = []
removed_count = 0 removed_count = 0
# 取出所有普通队列中的消息 # 取出所有普通队列中的消息
while not self._normal_queue.empty(): while not self._normal_queue.empty():
try: try:
item = self._normal_queue.get_nowait() item = self._normal_queue.get_nowait()
neg_priority, entry_count, timestamp, message = item neg_priority, entry_count, timestamp, message = item
# 如果消息在最近N条消息范围内保留它 # 如果消息在最近N条消息范围内保留它
logger.info(f"检查消息:{message.processed_plain_text},entry_count:{entry_count} cutoff_counter:{cutoff_counter}") logger.info(
f"检查消息:{message.processed_plain_text},entry_count:{entry_count} cutoff_counter:{cutoff_counter}"
)
if entry_count >= cutoff_counter: if entry_count >= cutoff_counter:
temp_messages.append(item) temp_messages.append(item)
else: else:
removed_count += 1 removed_count += 1
self._normal_queue.task_done() # 标记被移除的任务为完成 self._normal_queue.task_done() # 标记被移除的任务为完成
except asyncio.QueueEmpty: except asyncio.QueueEmpty:
break break
# 将保留的消息重新放入队列 # 将保留的消息重新放入队列
for item in temp_messages: for item in temp_messages:
self._normal_queue.put_nowait(item) self._normal_queue.put_nowait(item)
if removed_count > 0: if removed_count > 0:
logger.info(f"消息{message.processed_plain_text}超过{s4u_config.recent_message_keep_count}现在counter:{self._entry_counter}被移除") logger.info(
logger.info(f"[{self.stream_name}] Cleaned up {removed_count} old normal messages outside recent {s4u_config.recent_message_keep_count} range.") f"消息{message.processed_plain_text}超过{s4u_config.recent_message_keep_count}现在counter:{self._entry_counter}被移除"
)
logger.info(
f"[{self.stream_name}] Cleaned up {removed_count} old normal messages outside recent {s4u_config.recent_message_keep_count} range."
)
async def _message_processor(self): async def _message_processor(self):
"""调度器优先处理VIP队列然后处理普通队列。""" """调度器优先处理VIP队列然后处理普通队列。"""
@@ -385,7 +389,7 @@ class S4UChat:
# 等待有新消息的信号,避免空转 # 等待有新消息的信号,避免空转
await self._new_message_event.wait() await self._new_message_event.wait()
self._new_message_event.clear() self._new_message_event.clear()
# 清理普通队列中的过旧消息 # 清理普通队列中的过旧消息
self._cleanup_old_normal_messages() self._cleanup_old_normal_messages()
@@ -396,7 +400,6 @@ class S4UChat:
queue_name = "vip" queue_name = "vip"
# 其次处理普通队列 # 其次处理普通队列
elif not self._normal_queue.empty(): elif not self._normal_queue.empty():
neg_priority, entry_count, timestamp, message = self._normal_queue.get_nowait() neg_priority, entry_count, timestamp, message = self._normal_queue.get_nowait()
priority = -neg_priority priority = -neg_priority
# 检查普通消息是否超时 # 检查普通消息是否超时
@@ -411,13 +414,15 @@ class S4UChat:
if self.internal_message: if self.internal_message:
message = self.internal_message[-1] message = self.internal_message[-1]
self.internal_message = [] self.internal_message = []
priority = 0 priority = 0
neg_priority = 0 neg_priority = 0
entry_count = 0 entry_count = 0
queue_name = "internal" queue_name = "internal"
logger.info(f"[{self.stream_name}] normal/vip 队列都空,触发 internal_message 回复: {getattr(message, 'processed_plain_text', str(message))[:20]}...") logger.info(
f"[{self.stream_name}] normal/vip 队列都空,触发 internal_message 回复: {getattr(message, 'processed_plain_text', str(message))[:20]}..."
)
else: else:
continue # 没有消息了,回去等事件 continue # 没有消息了,回去等事件
@@ -457,23 +462,21 @@ class S4UChat:
except Exception as e: except Exception as e:
logger.error(f"[{self.stream_name}] Message processor main loop error: {e}", exc_info=True) logger.error(f"[{self.stream_name}] Message processor main loop error: {e}", exc_info=True)
await asyncio.sleep(1) await asyncio.sleep(1)
def get_processing_message_id(self):
    """Rotate the message ids: remember the current one and generate a new one.

    ``self.last_msg_id`` receives the previous ``self.msg_id``; ``self.msg_id``
    becomes ``"<time.time()>_<random 4-digit int>"``. Nothing is returned.
    """
    previous_id = self.msg_id
    stamp = time.time()
    suffix = random.randint(1000, 9999)
    self.last_msg_id = previous_id
    self.msg_id = f"{stamp}_{suffix}"
async def _generate_and_send(self, message: MessageRecv): async def _generate_and_send(self, message: MessageRecv):
"""为单个消息生成文本回复。整个过程可以被中断。""" """为单个消息生成文本回复。整个过程可以被中断。"""
self._is_replying = True self._is_replying = True
total_chars_sent = 0 # 跟踪发送的总字符数 total_chars_sent = 0 # 跟踪发送的总字符数
self.get_processing_message_id() self.get_processing_message_id()
# 视线管理:开始生成回复时切换视线状态 # 视线管理:开始生成回复时切换视线状态
chat_watching = watching_manager.get_watching_by_chat_id(self.stream_id) chat_watching = watching_manager.get_watching_by_chat_id(self.stream_id)
if message.is_internal: if message.is_internal:
await chat_watching.on_internal_message_start() await chat_watching.on_internal_message_start()
else: else:
@@ -516,16 +519,19 @@ class S4UChat:
total_chars_sent = len("麦麦不知道哦") total_chars_sent = len("麦麦不知道哦")
mood = mood_manager.get_mood_by_chat_id(self.stream_id) mood = mood_manager.get_mood_by_chat_id(self.stream_id)
await yes_or_no_head(text = total_chars_sent,emotion = mood.mood_state,chat_history=message.processed_plain_text,chat_id=self.stream_id) await yes_or_no_head(
text=total_chars_sent,
emotion=mood.mood_state,
chat_history=message.processed_plain_text,
chat_id=self.stream_id,
)
# 等待所有文本消息发送完成 # 等待所有文本消息发送完成
await sender_container.close() await sender_container.close()
await sender_container.join() await sender_container.join()
await chat_watching.on_thinking_finished() await chat_watching.on_thinking_finished()
start_time = time.time() start_time = time.time()
logged = False logged = False
while not self.go_processing(): while not self.go_processing():
@@ -536,7 +542,7 @@ class S4UChat:
logger.info(f"[{self.stream_name}] 等待消息发送完成...") logger.info(f"[{self.stream_name}] 等待消息发送完成...")
logged = True logged = True
await asyncio.sleep(0.2) await asyncio.sleep(0.2)
logger.info(f"[{self.stream_name}] 所有文本块处理完毕。") logger.info(f"[{self.stream_name}] 所有文本块处理完毕。")
except asyncio.CancelledError: except asyncio.CancelledError:
@@ -548,11 +554,11 @@ class S4UChat:
# 回复生成实时展示:清空内容(出错时) # 回复生成实时展示:清空内容(出错时)
finally: finally:
self._is_replying = False self._is_replying = False
# 视线管理:回复结束时切换视线状态 # 视线管理:回复结束时切换视线状态
chat_watching = watching_manager.get_watching_by_chat_id(self.stream_id) chat_watching = watching_manager.get_watching_by_chat_id(self.stream_id)
await chat_watching.on_reply_finished() await chat_watching.on_reply_finished()
# 确保发送器被妥善关闭(即使已关闭,再次调用也是安全的) # 确保发送器被妥善关闭(即使已关闭,再次调用也是安全的)
sender_container.resume() sender_container.resume()
if not sender_container._task.done(): if not sender_container._task.done():
@@ -576,4 +582,3 @@ class S4UChat:
await self._processing_task await self._processing_task
except asyncio.CancelledError: except asyncio.CancelledError:
logger.info(f"处理任务已成功取消: {self.stream_name}") logger.info(f"处理任务已成功取消: {self.stream_name}")

View File

@@ -40,7 +40,7 @@ async def _calculate_interest(message: MessageRecv) -> Tuple[float, bool]:
if global_config.memory.enable_memory: if global_config.memory.enable_memory:
with Timer("记忆激活"): with Timer("记忆激活"):
interested_rate,_ ,_= await hippocampus_manager.get_activate_from_text( interested_rate, _, _ = await hippocampus_manager.get_activate_from_text(
message.processed_plain_text, message.processed_plain_text,
fast_retrieval=True, fast_retrieval=True,
) )
@@ -49,7 +49,7 @@ async def _calculate_interest(message: MessageRecv) -> Tuple[float, bool]:
text_len = len(message.processed_plain_text) text_len = len(message.processed_plain_text)
# 根据文本长度分布调整兴趣度,采用分段函数实现更精确的兴趣度计算 # 根据文本长度分布调整兴趣度,采用分段函数实现更精确的兴趣度计算
# 基于实际分布0-5字符(26.57%), 6-10字符(27.18%), 11-20字符(22.76%), 21-30字符(10.33%), 31+字符(13.86%) # 基于实际分布0-5字符(26.57%), 6-10字符(27.18%), 11-20字符(22.76%), 21-30字符(10.33%), 31+字符(13.86%)
if text_len == 0: if text_len == 0:
base_interest = 0.01 # 空消息最低兴趣度 base_interest = 0.01 # 空消息最低兴趣度
elif text_len <= 5: elif text_len <= 5:
@@ -73,7 +73,7 @@ async def _calculate_interest(message: MessageRecv) -> Tuple[float, bool]:
else: else:
# 100+字符:对数增长 0.26 -> 0.3,增长率递减 # 100+字符:对数增长 0.26 -> 0.3,增长率递减
base_interest = 0.26 + (0.3 - 0.26) * (math.log10(text_len - 99) / math.log10(901)) # 1000-99=901 base_interest = 0.26 + (0.3 - 0.26) * (math.log10(text_len - 99) / math.log10(901)) # 1000-99=901
# 确保在范围内 # 确保在范围内
base_interest = min(max(base_interest, 0.01), 0.3) base_interest = min(max(base_interest, 0.01), 0.3)
@@ -117,36 +117,32 @@ class S4UMessageProcessor:
user_info=userinfo, user_info=userinfo,
group_info=groupinfo, group_info=groupinfo,
) )
if await self.handle_internal_message(message): if await self.handle_internal_message(message):
return return
if await self.hadle_if_voice_done(message): if await self.hadle_if_voice_done(message):
return return
# 处理礼物消息,如果消息被暂存则停止当前处理流程 # 处理礼物消息,如果消息被暂存则停止当前处理流程
if not skip_gift_debounce and not await self.handle_if_gift(message): if not skip_gift_debounce and not await self.handle_if_gift(message):
return return
await self.check_if_fake_gift(message) await self.check_if_fake_gift(message)
# 处理屏幕消息 # 处理屏幕消息
if await self.handle_screen_message(message): if await self.handle_screen_message(message):
return return
await self.storage.store_message(message, chat) await self.storage.store_message(message, chat)
s4u_chat = get_s4u_chat_manager().get_or_create_chat(chat) s4u_chat = get_s4u_chat_manager().get_or_create_chat(chat)
await s4u_chat.add_message(message) await s4u_chat.add_message(message)
_interested_rate, _ = await _calculate_interest(message) _interested_rate, _ = await _calculate_interest(message)
await mood_manager.start() await mood_manager.start()
# 一系列llm驱动的前处理 # 一系列llm驱动的前处理
chat_mood = mood_manager.get_mood_by_chat_id(chat.stream_id) chat_mood = mood_manager.get_mood_by_chat_id(chat.stream_id)
asyncio.create_task(chat_mood.update_mood_by_message(message)) asyncio.create_task(chat_mood.update_mood_by_message(message))
@@ -164,61 +160,56 @@ class S4UMessageProcessor:
logger.info(f"[S4U-礼物] {userinfo.user_nickname} 送出了 {message.gift_name} x{message.gift_count}") logger.info(f"[S4U-礼物] {userinfo.user_nickname} 送出了 {message.gift_name} x{message.gift_count}")
else: else:
logger.info(f"[S4U]{userinfo.user_nickname}:{message.processed_plain_text}") logger.info(f"[S4U]{userinfo.user_nickname}:{message.processed_plain_text}")
async def handle_internal_message(self, message: MessageRecvS4U):
    """Queue an internal message on the dedicated internal chat.

    Args:
        message: Incoming message; only handled when ``message.is_internal`` is truthy.

    Returns:
        True when the message was internal and has been queued, otherwise False.
    """
    if not message.is_internal:
        return False

    platform_name = "amaidesu_default"
    inner_group = GroupInfo(platform=platform_name, group_id=660154, group_name="内心")
    stream = await get_chat_manager().get_or_create_stream(
        platform=platform_name,
        user_info=message.message_info.user_info,
        group_info=inner_group,
    )
    target_chat = get_s4u_chat_manager().get_or_create_chat(stream)

    # Re-home the message onto the internal chat's group and platform.
    info = message.message_info
    info.group_info = target_chat.chat_stream.group_info
    info.platform = target_chat.chat_stream.platform

    # Hand the message over and signal that something new is available.
    target_chat.internal_message.append(message)
    target_chat._new_message_event.set()

    logger.info(
        f"[{target_chat.stream_name}] 添加内部消息-------------------------------------------------------: {message.processed_plain_text}"
    )
    return True
async def handle_screen_message(self, message: MessageRecvS4U):
    """Store screen content carried by a screen message.

    Returns:
        True when ``message.is_screen`` is truthy (its ``screen_info`` is passed
        to ``screen_manager.set_screen``), otherwise False.
    """
    if not message.is_screen:
        return False
    screen_manager.set_screen(message.screen_info)
    return True
async def hadle_if_voice_done(self, message: MessageRecvS4U):
    """Record a voice-done marker on the chat that owns the message.

    Returns:
        True when ``message.voice_done`` is truthy (the value is copied to the
        chat's ``voice_done`` attribute), otherwise False.
    """
    if not message.voice_done:
        return False
    owner_chat = get_s4u_chat_manager().get_or_create_chat(message.chat_stream)
    owner_chat.voice_done = message.voice_done
    return True
async def check_if_fake_gift(self, message: MessageRecvS4U) -> bool:
    """Flag a non-gift message whose text contains gift-like wording.

    Args:
        message: Message to inspect; real gifts (``message.is_gift`` truthy) are never flagged.

    Returns:
        True when a keyword was found (``message.is_fake_gift`` is then set to
        True), otherwise False.
    """
    if message.is_gift:
        return False

    for marker in ("送出了礼物", "礼物", "送出了", "投喂"):
        if marker in message.processed_plain_text:
            message.is_fake_gift = True
            return True
    return False
async def handle_if_gift(self, message: MessageRecvS4U) -> bool: async def handle_if_gift(self, message: MessageRecvS4U) -> bool:
"""处理礼物消息 """处理礼物消息
Returns: Returns:
bool: True表示应该继续处理消息False表示消息已被暂存不需要继续处理 bool: True表示应该继续处理消息False表示消息已被暂存不需要继续处理
""" """
@@ -228,37 +219,37 @@ class S4UMessageProcessor:
"""礼物防抖完成后的回调""" """礼物防抖完成后的回调"""
# 创建异步任务来处理合并后的礼物消息,跳过防抖处理 # 创建异步任务来处理合并后的礼物消息,跳过防抖处理
asyncio.create_task(self.process_message(merged_message, skip_gift_debounce=True)) asyncio.create_task(self.process_message(merged_message, skip_gift_debounce=True))
# 交给礼物管理器处理,并传入回调函数 # 交给礼物管理器处理,并传入回调函数
# 对于礼物消息handle_gift 总是返回 False消息被暂存 # 对于礼物消息handle_gift 总是返回 False消息被暂存
await gift_manager.handle_gift(message, gift_callback) await gift_manager.handle_gift(message, gift_callback)
return False # 消息被暂存,不继续处理 return False # 消息被暂存,不继续处理
return True # 非礼物消息,继续正常处理 return True # 非礼物消息,继续正常处理
async def _handle_context_web_update(self, chat_id: str, message: MessageRecv):
    """Push one message to the context web page.

    Starts the context web server first if its ``site`` attribute is still
    ``None``, waits 1.5 seconds, then adds the message. Any exception is
    logged and not re-raised.

    Args:
        chat_id: Chat ID the message belongs to.
        message: Message object to add.
    """
    try:
        logger.debug(f"🔄 开始处理上下文网页更新: {message.message_info.user_info.user_nickname}")

        web_ctx = get_context_web_manager()
        server_missing = web_ctx.site is None
        if server_missing:
            # Start only once; an existing site means the server is already up.
            logger.info("🚀 首次启动上下文网页服务器...")
            await web_ctx.start_server()

        await asyncio.sleep(1.5)
        await web_ctx.add_message(chat_id, message)

        logger.debug(f"✅ 上下文网页更新完成: {message.message_info.user_info.user_nickname}")
    except Exception as err:
        logger.error(f"❌ 处理上下文网页更新失败: {err}", exc_info=True)

View File

@@ -176,7 +176,7 @@ class PromptBuilder:
message_list_before_now = get_raw_msg_before_timestamp_with_chat( message_list_before_now = get_raw_msg_before_timestamp_with_chat(
chat_id=chat_stream.stream_id, chat_id=chat_stream.stream_id,
timestamp=time.time(), timestamp=time.time(),
# sourcery skip: lift-duplicated-conditional, merge-duplicate-blocks, remove-redundant-if # sourcery skip: lift-duplicated-conditional, merge-duplicate-blocks, remove-redundant-if
limit=300, limit=300,
) )
@@ -228,13 +228,17 @@ class PromptBuilder:
last_speaking_user_id = start_speaking_user_id last_speaking_user_id = start_speaking_user_id
msg_seg_str = "对方的发言:\n" msg_seg_str = "对方的发言:\n"
msg_seg_str += f"{time.strftime('%H:%M:%S', time.localtime(first_msg.time))}: {first_msg.processed_plain_text}\n" msg_seg_str += (
f"{time.strftime('%H:%M:%S', time.localtime(first_msg.time))}: {first_msg.processed_plain_text}\n"
)
all_msg_seg_list = [] all_msg_seg_list = []
for msg in core_dialogue_list[1:]: for msg in core_dialogue_list[1:]:
speaker = msg.user_info.user_id speaker = msg.user_info.user_id
if speaker == last_speaking_user_id: if speaker == last_speaking_user_id:
msg_seg_str += f"{time.strftime('%H:%M:%S', time.localtime(msg.time))}: {msg.processed_plain_text}\n" msg_seg_str += (
f"{time.strftime('%H:%M:%S', time.localtime(msg.time))}: {msg.processed_plain_text}\n"
)
else: else:
msg_seg_str = f"{msg_seg_str}\n" msg_seg_str = f"{msg_seg_str}\n"
all_msg_seg_list.append(msg_seg_str) all_msg_seg_list.append(msg_seg_str)

View File

@@ -14,11 +14,8 @@ logger = get_logger("s4u_stream_generator")
class S4UStreamGenerator: class S4UStreamGenerator:
def __init__(self): def __init__(self):
# 使用LLMRequest替代AsyncOpenAIClient # 使用LLMRequest替代AsyncOpenAIClient
self.llm_request = LLMRequest( self.llm_request = LLMRequest(model_set=model_config.model_task_config.replyer, request_type="s4u_replyer")
model_set=model_config.model_task_config.replyer,
request_type="s4u_replyer"
)
self.current_model_name = "unknown model" self.current_model_name = "unknown model"
self.partial_response = "" self.partial_response = ""
@@ -89,16 +86,16 @@ class S4UStreamGenerator:
async def _generate_response_with_llm_request(self, prompt: str) -> AsyncGenerator[str, None]: async def _generate_response_with_llm_request(self, prompt: str) -> AsyncGenerator[str, None]:
"""使用LLMRequest进行流式响应生成""" """使用LLMRequest进行流式响应生成"""
# 构建消息 # 构建消息
message_builder = MessageBuilder() message_builder = MessageBuilder()
message_builder.add_text_content(prompt) message_builder.add_text_content(prompt)
messages = [message_builder.build()] messages = [message_builder.build()]
# 选择模型 # 选择模型
model_info, api_provider, client = self.llm_request._select_model() model_info, api_provider, client = self.llm_request._select_model()
self.current_model_name = model_info.name self.current_model_name = model_info.name
# 如果模型支持强制流式模式,使用真正的流式处理 # 如果模型支持强制流式模式,使用真正的流式处理
if model_info.force_stream_mode: if model_info.force_stream_mode:
# 简化流式处理直接使用LLMRequest的流式功能 # 简化流式处理直接使用LLMRequest的流式功能
@@ -111,14 +108,14 @@ class S4UStreamGenerator:
model_info=model_info, model_info=model_info,
message_list=messages, message_list=messages,
) )
# 处理响应内容 # 处理响应内容
content = response.content or "" content = response.content or ""
if content: if content:
# 将内容按句子分割并输出 # 将内容按句子分割并输出
async for chunk in self._process_content_streaming(content): async for chunk in self._process_content_streaming(content):
yield chunk yield chunk
except Exception as e: except Exception as e:
logger.error(f"流式请求执行失败: {e}") logger.error(f"流式请求执行失败: {e}")
# 如果流式请求失败,回退到普通模式 # 如果流式请求失败,回退到普通模式
@@ -132,7 +129,7 @@ class S4UStreamGenerator:
content = response.content or "" content = response.content or ""
async for chunk in self._process_content_streaming(content): async for chunk in self._process_content_streaming(content):
yield chunk yield chunk
else: else:
# 如果不支持流式,使用普通方式然后模拟流式输出 # 如果不支持流式,使用普通方式然后模拟流式输出
response = await self.llm_request._execute_request( response = await self.llm_request._execute_request(
@@ -142,7 +139,7 @@ class S4UStreamGenerator:
model_info=model_info, model_info=model_info,
message_list=messages, message_list=messages,
) )
content = response.content or "" content = response.content or ""
async for chunk in self._process_content_streaming(content): async for chunk in self._process_content_streaming(content):
yield chunk yield chunk
@@ -163,7 +160,7 @@ class S4UStreamGenerator:
"""处理内容进行流式输出(用于非流式模型的模拟流式输出)""" """处理内容进行流式输出(用于非流式模型的模拟流式输出)"""
buffer = content buffer = content
punctuation_buffer = "" punctuation_buffer = ""
# 使用正则表达式匹配句子 # 使用正则表达式匹配句子
last_match_end = 0 last_match_end = 0
for match in self.sentence_split_pattern.finditer(buffer): for match in self.sentence_split_pattern.finditer(buffer):

View File

@@ -1,4 +1,3 @@
from src.common.logger import get_logger from src.common.logger import get_logger
from src.plugin_system.apis import send_api from src.plugin_system.apis import send_api
@@ -47,6 +46,7 @@ HEAD_CODE = {
"看向正前方": "(0,0,0)", "看向正前方": "(0,0,0)",
} }
class ChatWatching: class ChatWatching:
def __init__(self, chat_id: str): def __init__(self, chat_id: str):
self.chat_id: str = chat_id self.chat_id: str = chat_id
@@ -56,13 +56,13 @@ class ChatWatching:
await send_api.custom_to_stream( await send_api.custom_to_stream(
message_type="state", content="start_thinking", stream_id=self.chat_id, storage_message=False message_type="state", content="start_thinking", stream_id=self.chat_id, storage_message=False
) )
async def on_reply_finished(self):
    """Send the ``finish_reply`` state message to this chat's stream."""
    state_kwargs = {
        "message_type": "state",
        "content": "finish_reply",
        "stream_id": self.chat_id,
        "storage_message": False,
    }
    await send_api.custom_to_stream(**state_kwargs)
async def on_thinking_finished(self): async def on_thinking_finished(self):
"""思考完毕时调用""" """思考完毕时调用"""
await send_api.custom_to_stream( await send_api.custom_to_stream(
@@ -74,14 +74,14 @@ class ChatWatching:
await send_api.custom_to_stream( await send_api.custom_to_stream(
message_type="state", content="start_viewing", stream_id=self.chat_id, storage_message=False message_type="state", content="start_viewing", stream_id=self.chat_id, storage_message=False
) )
async def on_internal_message_start(self): async def on_internal_message_start(self):
"""收到消息时调用""" """收到消息时调用"""
await send_api.custom_to_stream( await send_api.custom_to_stream(
message_type="state", content="start_internal_thinking", stream_id=self.chat_id, storage_message=False message_type="state", content="start_internal_thinking", stream_id=self.chat_id, storage_message=False
) )
class WatchingManager: class WatchingManager:
def __init__(self): def __init__(self):
self.watching_list: list[ChatWatching] = [] self.watching_list: list[ChatWatching] = []
@@ -100,6 +100,7 @@ class WatchingManager:
return new_watching return new_watching
# 全局视线管理器实例 # 全局视线管理器实例
watching_manager = WatchingManager() watching_manager = WatchingManager()
"""全局视线管理器""" """全局视线管理器"""

View File

@@ -1,14 +1,15 @@
class ScreenManager: class ScreenManager:
def __init__(self): def __init__(self):
self.now_screen = str() self.now_screen = str()
def set_screen(self,screen_str:str): def set_screen(self, screen_str: str):
self.now_screen = screen_str self.now_screen = screen_str
def get_screen(self): def get_screen(self):
return self.now_screen return self.now_screen
def get_screen_str(self): def get_screen_str(self):
return f"你可以看见面前的屏幕,目前屏幕的内容是:现在千石可乐在和你一起直播,这是他正在操作的屏幕内容:{self.now_screen}" return f"你可以看见面前的屏幕,目前屏幕的内容是:现在千石可乐在和你一起直播,这是他正在操作的屏幕内容:{self.now_screen}"
screen_manager = ScreenManager()
screen_manager = ScreenManager()

View File

@@ -4,6 +4,7 @@ from dataclasses import dataclass
from typing import Dict, List, Optional from typing import Dict, List, Optional
from src.common.logger import get_logger from src.common.logger import get_logger
from src.chat.message_receive.message import MessageRecvS4U from src.chat.message_receive.message import MessageRecvS4U
# 全局SuperChat管理器实例 # 全局SuperChat管理器实例
from src.mais4u.s4u_config import s4u_config from src.mais4u.s4u_config import s4u_config
@@ -13,7 +14,7 @@ logger = get_logger("super_chat_manager")
@dataclass @dataclass
class SuperChatRecord: class SuperChatRecord:
"""SuperChat记录数据类""" """SuperChat记录数据类"""
user_id: str user_id: str
user_nickname: str user_nickname: str
platform: str platform: str
@@ -23,15 +24,15 @@ class SuperChatRecord:
timestamp: float timestamp: float
expire_time: float expire_time: float
group_name: Optional[str] = None group_name: Optional[str] = None
def is_expired(self) -> bool: def is_expired(self) -> bool:
"""检查SuperChat是否已过期""" """检查SuperChat是否已过期"""
return time.time() > self.expire_time return time.time() > self.expire_time
def remaining_time(self) -> float: def remaining_time(self) -> float:
"""获取剩余时间(秒)""" """获取剩余时间(秒)"""
return max(0, self.expire_time - time.time()) return max(0, self.expire_time - time.time())
def to_dict(self) -> dict: def to_dict(self) -> dict:
"""转换为字典格式""" """转换为字典格式"""
return { return {
@@ -44,19 +45,19 @@ class SuperChatRecord:
"timestamp": self.timestamp, "timestamp": self.timestamp,
"expire_time": self.expire_time, "expire_time": self.expire_time,
"group_name": self.group_name, "group_name": self.group_name,
"remaining_time": self.remaining_time() "remaining_time": self.remaining_time(),
} }
class SuperChatManager: class SuperChatManager:
"""SuperChat管理器负责管理和跟踪SuperChat消息""" """SuperChat管理器负责管理和跟踪SuperChat消息"""
def __init__(self): def __init__(self):
self.super_chats: Dict[str, List[SuperChatRecord]] = {} # chat_id -> SuperChat列表 self.super_chats: Dict[str, List[SuperChatRecord]] = {} # chat_id -> SuperChat列表
self._cleanup_task: Optional[asyncio.Task] = None self._cleanup_task: Optional[asyncio.Task] = None
self._is_initialized = False self._is_initialized = False
logger.info("SuperChat管理器已初始化") logger.info("SuperChat管理器已初始化")
def _ensure_cleanup_task_started(self): def _ensure_cleanup_task_started(self):
"""确保清理任务已启动(延迟启动)""" """确保清理任务已启动(延迟启动)"""
if self._cleanup_task is None or self._cleanup_task.done(): if self._cleanup_task is None or self._cleanup_task.done():
@@ -68,7 +69,7 @@ class SuperChatManager:
except RuntimeError: except RuntimeError:
# 没有运行的事件循环,稍后再启动 # 没有运行的事件循环,稍后再启动
logger.debug("当前没有运行的事件循环,将在需要时启动清理任务") logger.debug("当前没有运行的事件循环,将在需要时启动清理任务")
def _start_cleanup_task(self): def _start_cleanup_task(self):
"""启动清理任务(已弃用,保留向后兼容)""" """启动清理任务(已弃用,保留向后兼容)"""
self._ensure_cleanup_task_started() self._ensure_cleanup_task_started()
@@ -78,39 +79,36 @@ class SuperChatManager:
while True: while True:
try: try:
total_removed = 0 total_removed = 0
for chat_id in list(self.super_chats.keys()): for chat_id in list(self.super_chats.keys()):
original_count = len(self.super_chats[chat_id]) original_count = len(self.super_chats[chat_id])
# 移除过期的SuperChat # 移除过期的SuperChat
self.super_chats[chat_id] = [ self.super_chats[chat_id] = [sc for sc in self.super_chats[chat_id] if not sc.is_expired()]
sc for sc in self.super_chats[chat_id]
if not sc.is_expired()
]
removed_count = original_count - len(self.super_chats[chat_id]) removed_count = original_count - len(self.super_chats[chat_id])
total_removed += removed_count total_removed += removed_count
if removed_count > 0: if removed_count > 0:
logger.info(f"从聊天 {chat_id} 中清理了 {removed_count} 个过期的SuperChat") logger.info(f"从聊天 {chat_id} 中清理了 {removed_count} 个过期的SuperChat")
# 如果列表为空,删除该聊天的记录 # 如果列表为空,删除该聊天的记录
if not self.super_chats[chat_id]: if not self.super_chats[chat_id]:
del self.super_chats[chat_id] del self.super_chats[chat_id]
if total_removed > 0: if total_removed > 0:
logger.info(f"总共清理了 {total_removed} 个过期的SuperChat") logger.info(f"总共清理了 {total_removed} 个过期的SuperChat")
# 每30秒检查一次 # 每30秒检查一次
await asyncio.sleep(30) await asyncio.sleep(30)
except Exception as e: except Exception as e:
logger.error(f"清理过期SuperChat时出错: {e}", exc_info=True) logger.error(f"清理过期SuperChat时出错: {e}", exc_info=True)
await asyncio.sleep(60) # 出错时等待更长时间 await asyncio.sleep(60) # 出错时等待更长时间
def _calculate_expire_time(self, price: float) -> float: def _calculate_expire_time(self, price: float) -> float:
"""根据SuperChat金额计算过期时间""" """根据SuperChat金额计算过期时间"""
current_time = time.time() current_time = time.time()
# 根据金额阶梯设置不同的存活时间 # 根据金额阶梯设置不同的存活时间
if price >= 500: if price >= 500:
# 500元以上保持4小时 # 500元以上保持4小时
@@ -133,27 +131,27 @@ class SuperChatManager:
else: else:
# 10元以下保持5分钟 # 10元以下保持5分钟
duration = 5 * 60 duration = 5 * 60
return current_time + duration return current_time + duration
async def add_superchat(self, message: MessageRecvS4U) -> None: async def add_superchat(self, message: MessageRecvS4U) -> None:
"""添加新的SuperChat记录""" """添加新的SuperChat记录"""
# 确保清理任务已启动 # 确保清理任务已启动
self._ensure_cleanup_task_started() self._ensure_cleanup_task_started()
if not message.is_superchat or not message.superchat_price: if not message.is_superchat or not message.superchat_price:
logger.warning("尝试添加非SuperChat消息到SuperChat管理器") logger.warning("尝试添加非SuperChat消息到SuperChat管理器")
return return
try: try:
price = float(message.superchat_price) price = float(message.superchat_price)
except (ValueError, TypeError): except (ValueError, TypeError):
logger.error(f"无效的SuperChat价格: {message.superchat_price}") logger.error(f"无效的SuperChat价格: {message.superchat_price}")
return return
user_info = message.message_info.user_info user_info = message.message_info.user_info
group_info = message.message_info.group_info group_info = message.message_info.group_info
chat_id = getattr(message, 'chat_stream', None) chat_id = getattr(message, "chat_stream", None)
if chat_id: if chat_id:
chat_id = chat_id.stream_id chat_id = chat_id.stream_id
else: else:
@@ -161,9 +159,9 @@ class SuperChatManager:
chat_id = f"{message.message_info.platform}_{user_info.user_id}" chat_id = f"{message.message_info.platform}_{user_info.user_id}"
if group_info: if group_info:
chat_id = f"{message.message_info.platform}_{group_info.group_id}" chat_id = f"{message.message_info.platform}_{group_info.group_id}"
expire_time = self._calculate_expire_time(price) expire_time = self._calculate_expire_time(price)
record = SuperChatRecord( record = SuperChatRecord(
user_id=user_info.user_id, user_id=user_info.user_id,
user_nickname=user_info.user_nickname, user_nickname=user_info.user_nickname,
@@ -173,44 +171,44 @@ class SuperChatManager:
message_text=message.superchat_message_text or "", message_text=message.superchat_message_text or "",
timestamp=message.message_info.time, timestamp=message.message_info.time,
expire_time=expire_time, expire_time=expire_time,
group_name=group_info.group_name if group_info else None group_name=group_info.group_name if group_info else None,
) )
# 添加到对应聊天的SuperChat列表 # 添加到对应聊天的SuperChat列表
if chat_id not in self.super_chats: if chat_id not in self.super_chats:
self.super_chats[chat_id] = [] self.super_chats[chat_id] = []
self.super_chats[chat_id].append(record) self.super_chats[chat_id].append(record)
# 按价格降序排序(价格高的在前) # 按价格降序排序(价格高的在前)
self.super_chats[chat_id].sort(key=lambda x: x.price, reverse=True) self.super_chats[chat_id].sort(key=lambda x: x.price, reverse=True)
logger.info(f"添加SuperChat记录: {user_info.user_nickname} - {price}元 - {message.superchat_message_text}") logger.info(f"添加SuperChat记录: {user_info.user_nickname} - {price}元 - {message.superchat_message_text}")
def get_superchats_by_chat(self, chat_id: str) -> List[SuperChatRecord]: def get_superchats_by_chat(self, chat_id: str) -> List[SuperChatRecord]:
"""获取指定聊天的所有有效SuperChat""" """获取指定聊天的所有有效SuperChat"""
# 确保清理任务已启动 # 确保清理任务已启动
self._ensure_cleanup_task_started() self._ensure_cleanup_task_started()
if chat_id not in self.super_chats: if chat_id not in self.super_chats:
return [] return []
# 过滤掉过期的SuperChat # 过滤掉过期的SuperChat
valid_superchats = [sc for sc in self.super_chats[chat_id] if not sc.is_expired()] valid_superchats = [sc for sc in self.super_chats[chat_id] if not sc.is_expired()]
return valid_superchats return valid_superchats
def get_all_valid_superchats(self) -> Dict[str, List[SuperChatRecord]]: def get_all_valid_superchats(self) -> Dict[str, List[SuperChatRecord]]:
"""获取所有有效的SuperChat""" """获取所有有效的SuperChat"""
# 确保清理任务已启动 # 确保清理任务已启动
self._ensure_cleanup_task_started() self._ensure_cleanup_task_started()
result = {} result = {}
for chat_id, superchats in self.super_chats.items(): for chat_id, superchats in self.super_chats.items():
valid_superchats = [sc for sc in superchats if not sc.is_expired()] valid_superchats = [sc for sc in superchats if not sc.is_expired()]
if valid_superchats: if valid_superchats:
result[chat_id] = valid_superchats result[chat_id] = valid_superchats
return result return result
def build_superchat_display_string(self, chat_id: str, max_count: int = 10) -> str: def build_superchat_display_string(self, chat_id: str, max_count: int = 10) -> str:
"""构建SuperChat显示字符串""" """构建SuperChat显示字符串"""
superchats = self.get_superchats_by_chat(chat_id) superchats = self.get_superchats_by_chat(chat_id)
@@ -226,7 +224,9 @@ class SuperChatManager:
remaining_minutes = int(sc.remaining_time() / 60) remaining_minutes = int(sc.remaining_time() / 60)
remaining_seconds = int(sc.remaining_time() % 60) remaining_seconds = int(sc.remaining_time() % 60)
time_display = f"{remaining_minutes}{remaining_seconds}" if remaining_minutes > 0 else f"{remaining_seconds}" time_display = (
f"{remaining_minutes}{remaining_seconds}" if remaining_minutes > 0 else f"{remaining_seconds}"
)
line = f"{i}. 【{sc.price}元】{sc.user_nickname}: {sc.message_text}" line = f"{i}. 【{sc.price}元】{sc.user_nickname}: {sc.message_text}"
if len(line) > 100: # 限制单行长度 if len(line) > 100: # 限制单行长度
@@ -238,7 +238,7 @@ class SuperChatManager:
lines.append(f"... 还有{len(superchats) - max_count}条SuperChat") lines.append(f"... 还有{len(superchats) - max_count}条SuperChat")
return "\n".join(lines) return "\n".join(lines)
def build_superchat_summary_string(self, chat_id: str) -> str: def build_superchat_summary_string(self, chat_id: str) -> str:
"""构建SuperChat摘要字符串""" """构建SuperChat摘要字符串"""
superchats = self.get_superchats_by_chat(chat_id) superchats = self.get_superchats_by_chat(chat_id)
@@ -261,30 +261,24 @@ class SuperChatManager:
if lines: if lines:
final_str += "\n" + "\n".join(lines) final_str += "\n" + "\n".join(lines)
return final_str return final_str
def get_superchat_statistics(self, chat_id: str) -> dict: def get_superchat_statistics(self, chat_id: str) -> dict:
"""获取SuperChat统计信息""" """获取SuperChat统计信息"""
superchats = self.get_superchats_by_chat(chat_id) superchats = self.get_superchats_by_chat(chat_id)
if not superchats: if not superchats:
return { return {"count": 0, "total_amount": 0, "average_amount": 0, "highest_amount": 0, "lowest_amount": 0}
"count": 0,
"total_amount": 0,
"average_amount": 0,
"highest_amount": 0,
"lowest_amount": 0
}
amounts = [sc.price for sc in superchats] amounts = [sc.price for sc in superchats]
return { return {
"count": len(superchats), "count": len(superchats),
"total_amount": sum(amounts), "total_amount": sum(amounts),
"average_amount": sum(amounts) / len(amounts), "average_amount": sum(amounts) / len(amounts),
"highest_amount": max(amounts), "highest_amount": max(amounts),
"lowest_amount": min(amounts) "lowest_amount": min(amounts),
} }
async def shutdown(self): # sourcery skip: use-contextlib-suppress async def shutdown(self): # sourcery skip: use-contextlib-suppress
"""关闭管理器,清理资源""" """关闭管理器,清理资源"""
if self._cleanup_task and not self._cleanup_task.done(): if self._cleanup_task and not self._cleanup_task.done():
@@ -296,15 +290,14 @@ class SuperChatManager:
logger.info("SuperChat管理器已关闭") logger.info("SuperChat管理器已关闭")
# sourcery skip: assign-if-exp # sourcery skip: assign-if-exp
if s4u_config.enable_s4u: if s4u_config.enable_s4u:
super_chat_manager = SuperChatManager() super_chat_manager = SuperChatManager()
else: else:
super_chat_manager = None super_chat_manager = None
def get_super_chat_manager() -> SuperChatManager: def get_super_chat_manager() -> SuperChatManager:
"""获取全局SuperChat管理器实例""" """获取全局SuperChat管理器实例"""
return super_chat_manager return super_chat_manager

View File

@@ -10,10 +10,12 @@ from src.common.logger import get_logger
logger = get_logger("s4u_config") logger = get_logger("s4u_config")
# 新增兼容dict和tomlkit Table # 新增兼容dict和tomlkit Table
def is_dict_like(obj): def is_dict_like(obj):
return isinstance(obj, (dict, Table)) return isinstance(obj, (dict, Table))
# 新增递归将Table转为dict # 新增递归将Table转为dict
def table_to_dict(obj): def table_to_dict(obj):
if isinstance(obj, Table): if isinstance(obj, Table):
@@ -25,6 +27,7 @@ def table_to_dict(obj):
else: else:
return obj return obj
# 获取mais4u模块目录 # 获取mais4u模块目录
MAIS4U_ROOT = os.path.dirname(__file__) MAIS4U_ROOT = os.path.dirname(__file__)
CONFIG_DIR = os.path.join(MAIS4U_ROOT, "config") CONFIG_DIR = os.path.join(MAIS4U_ROOT, "config")
@@ -190,7 +193,7 @@ class S4UModelConfig(S4UConfigBase):
@dataclass @dataclass
class S4UConfig(S4UConfigBase): class S4UConfig(S4UConfigBase):
"""S4U聊天系统配置类""" """S4U聊天系统配置类"""
enable_s4u: bool = False enable_s4u: bool = False
"""是否启用S4U聊天系统""" """是否启用S4U聊天系统"""
@@ -229,12 +232,12 @@ class S4UConfig(S4UConfigBase):
enable_streaming_output: bool = True enable_streaming_output: bool = True
"""是否启用流式输出false时全部生成后一次性发送""" """是否启用流式输出false时全部生成后一次性发送"""
max_context_message_length: int = 20 max_context_message_length: int = 20
"""上下文消息最大长度""" """上下文消息最大长度"""
max_core_message_length: int = 30 max_core_message_length: int = 30
"""核心消息最大长度""" """核心消息最大长度"""
# 模型配置 # 模型配置
models: S4UModelConfig = field(default_factory=S4UModelConfig) models: S4UModelConfig = field(default_factory=S4UModelConfig)
@@ -243,7 +246,6 @@ class S4UConfig(S4UConfigBase):
# 兼容性字段,保持向后兼容 # 兼容性字段,保持向后兼容
@dataclass @dataclass
class S4UGlobalConfig(S4UConfigBase): class S4UGlobalConfig(S4UConfigBase):
"""S4U总配置类""" """S4U总配置类"""
@@ -256,7 +258,7 @@ def update_s4u_config():
"""更新S4U配置文件""" """更新S4U配置文件"""
# 创建配置目录(如果不存在) # 创建配置目录(如果不存在)
os.makedirs(CONFIG_DIR, exist_ok=True) os.makedirs(CONFIG_DIR, exist_ok=True)
# 检查模板文件是否存在 # 检查模板文件是否存在
if not os.path.exists(TEMPLATE_PATH): if not os.path.exists(TEMPLATE_PATH):
logger.error(f"S4U配置模板文件不存在: {TEMPLATE_PATH}") logger.error(f"S4U配置模板文件不存在: {TEMPLATE_PATH}")
@@ -354,13 +356,13 @@ def load_s4u_config(config_path: str) -> S4UGlobalConfig:
logger.critical("S4U配置文件解析失败") logger.critical("S4U配置文件解析失败")
raise e raise e
# 初始化S4U配置 # 初始化S4U配置
logger.info(f"S4U当前版本: {S4U_VERSION}") logger.info(f"S4U当前版本: {S4U_VERSION}")
update_s4u_config() update_s4u_config()
s4u_config_main = load_s4u_config(config_path=CONFIG_PATH) s4u_config_main = load_s4u_config(config_path=CONFIG_PATH)
logger.info("S4U配置文件加载完成") logger.info("S4U配置文件加载完成")
s4u_config: S4UConfig = s4u_config_main.s4u s4u_config: S4UConfig = s4u_config_main.s4u

View File

@@ -13,7 +13,7 @@ async def migrate_memory_items_to_string():
并根据原始list的项目数量设置weight值 并根据原始list的项目数量设置weight值
""" """
logger.info("开始迁移记忆节点格式...") logger.info("开始迁移记忆节点格式...")
migration_stats = { migration_stats = {
"total_nodes": 0, "total_nodes": 0,
"converted_nodes": 0, "converted_nodes": 0,
@@ -21,72 +21,74 @@ async def migrate_memory_items_to_string():
"empty_nodes": 0, "empty_nodes": 0,
"error_nodes": 0, "error_nodes": 0,
"weight_updated_nodes": 0, "weight_updated_nodes": 0,
"truncated_nodes": 0 "truncated_nodes": 0,
} }
try: try:
# 获取所有图节点 # 获取所有图节点
all_nodes = GraphNodes.select() all_nodes = GraphNodes.select()
migration_stats["total_nodes"] = all_nodes.count() migration_stats["total_nodes"] = all_nodes.count()
logger.info(f"找到 {migration_stats['total_nodes']} 个记忆节点") logger.info(f"找到 {migration_stats['total_nodes']} 个记忆节点")
for node in all_nodes: for node in all_nodes:
try: try:
concept = node.concept concept = node.concept
memory_items_raw = node.memory_items.strip() if node.memory_items else "" memory_items_raw = node.memory_items.strip() if node.memory_items else ""
original_weight = node.weight if hasattr(node, 'weight') and node.weight is not None else 1.0 original_weight = node.weight if hasattr(node, "weight") and node.weight is not None else 1.0
# 如果为空,跳过 # 如果为空,跳过
if not memory_items_raw: if not memory_items_raw:
migration_stats["empty_nodes"] += 1 migration_stats["empty_nodes"] += 1
logger.debug(f"跳过空节点: {concept}") logger.debug(f"跳过空节点: {concept}")
continue continue
try: try:
# 尝试解析JSON # 尝试解析JSON
parsed_data = json.loads(memory_items_raw) parsed_data = json.loads(memory_items_raw)
if isinstance(parsed_data, list): if isinstance(parsed_data, list):
# 如果是list格式需要转换 # 如果是list格式需要转换
if parsed_data: if parsed_data:
# 转换为字符串格式 # 转换为字符串格式
new_memory_items = " | ".join(str(item) for item in parsed_data) new_memory_items = " | ".join(str(item) for item in parsed_data)
original_length = len(new_memory_items) original_length = len(new_memory_items)
# 检查长度并截断 # 检查长度并截断
if len(new_memory_items) > 100: if len(new_memory_items) > 100:
new_memory_items = new_memory_items[:100] new_memory_items = new_memory_items[:100]
migration_stats["truncated_nodes"] += 1 migration_stats["truncated_nodes"] += 1
logger.debug(f"节点 '{concept}' 内容过长,从 {original_length} 字符截断到 100 字符") logger.debug(f"节点 '{concept}' 内容过长,从 {original_length} 字符截断到 100 字符")
new_weight = float(len(parsed_data)) # weight = list项目数量 new_weight = float(len(parsed_data)) # weight = list项目数量
# 更新数据库 # 更新数据库
node.memory_items = new_memory_items node.memory_items = new_memory_items
node.weight = new_weight node.weight = new_weight
node.save() node.save()
migration_stats["converted_nodes"] += 1 migration_stats["converted_nodes"] += 1
migration_stats["weight_updated_nodes"] += 1 migration_stats["weight_updated_nodes"] += 1
length_info = f" (截断: {original_length}→100)" if original_length > 100 else "" length_info = f" (截断: {original_length}→100)" if original_length > 100 else ""
logger.info(f"转换节点 '{concept}': {len(parsed_data)} 项 -> 字符串{length_info}, weight: {original_weight} -> {new_weight}") logger.info(
f"转换节点 '{concept}': {len(parsed_data)} 项 -> 字符串{length_info}, weight: {original_weight} -> {new_weight}"
)
else: else:
# 空list设置为空字符串 # 空list设置为空字符串
node.memory_items = "" node.memory_items = ""
node.weight = 1.0 node.weight = 1.0
node.save() node.save()
migration_stats["converted_nodes"] += 1 migration_stats["converted_nodes"] += 1
logger.debug(f"转换空list节点: {concept}") logger.debug(f"转换空list节点: {concept}")
elif isinstance(parsed_data, str): elif isinstance(parsed_data, str):
# 已经是字符串格式检查长度和weight # 已经是字符串格式检查长度和weight
current_content = parsed_data current_content = parsed_data
original_length = len(current_content) original_length = len(current_content)
content_truncated = False content_truncated = False
# 检查长度并截断 # 检查长度并截断
if len(current_content) > 100: if len(current_content) > 100:
current_content = current_content[:100] current_content = current_content[:100]
@@ -94,19 +96,21 @@ async def migrate_memory_items_to_string():
migration_stats["truncated_nodes"] += 1 migration_stats["truncated_nodes"] += 1
node.memory_items = current_content node.memory_items = current_content
logger.debug(f"节点 '{concept}' 字符串内容过长,从 {original_length} 字符截断到 100 字符") logger.debug(f"节点 '{concept}' 字符串内容过长,从 {original_length} 字符截断到 100 字符")
# 检查weight是否需要更新 # 检查weight是否需要更新
update_needed = False update_needed = False
if original_weight == 1.0: if original_weight == 1.0:
# 如果weight还是默认值可以根据内容复杂度估算 # 如果weight还是默认值可以根据内容复杂度估算
content_parts = current_content.split(" | ") if " | " in current_content else [current_content] content_parts = (
current_content.split(" | ") if " | " in current_content else [current_content]
)
estimated_weight = max(1.0, float(len(content_parts))) estimated_weight = max(1.0, float(len(content_parts)))
if estimated_weight != original_weight: if estimated_weight != original_weight:
node.weight = estimated_weight node.weight = estimated_weight
update_needed = True update_needed = True
logger.debug(f"更新字符串节点权重 '{concept}': {original_weight} -> {estimated_weight}") logger.debug(f"更新字符串节点权重 '{concept}': {original_weight} -> {estimated_weight}")
# 如果内容被截断或权重需要更新,保存到数据库 # 如果内容被截断或权重需要更新,保存到数据库
if content_truncated or update_needed: if content_truncated or update_needed:
node.save() node.save()
@@ -118,26 +122,26 @@ async def migrate_memory_items_to_string():
migration_stats["already_string_nodes"] += 1 migration_stats["already_string_nodes"] += 1
else: else:
migration_stats["already_string_nodes"] += 1 migration_stats["already_string_nodes"] += 1
else: else:
# 其他JSON类型转换为字符串 # 其他JSON类型转换为字符串
new_memory_items = str(parsed_data) if parsed_data else "" new_memory_items = str(parsed_data) if parsed_data else ""
original_length = len(new_memory_items) original_length = len(new_memory_items)
# 检查长度并截断 # 检查长度并截断
if len(new_memory_items) > 100: if len(new_memory_items) > 100:
new_memory_items = new_memory_items[:100] new_memory_items = new_memory_items[:100]
migration_stats["truncated_nodes"] += 1 migration_stats["truncated_nodes"] += 1
logger.debug(f"节点 '{concept}' 其他类型内容过长,从 {original_length} 字符截断到 100 字符") logger.debug(f"节点 '{concept}' 其他类型内容过长,从 {original_length} 字符截断到 100 字符")
node.memory_items = new_memory_items node.memory_items = new_memory_items
node.weight = 1.0 node.weight = 1.0
node.save() node.save()
migration_stats["converted_nodes"] += 1 migration_stats["converted_nodes"] += 1
length_info = f" (截断: {original_length}→100)" if original_length > 100 else "" length_info = f" (截断: {original_length}→100)" if original_length > 100 else ""
logger.debug(f"转换其他类型节点: {concept}{length_info}") logger.debug(f"转换其他类型节点: {concept}{length_info}")
except json.JSONDecodeError: except json.JSONDecodeError:
# 不是JSON格式假设已经是纯字符串 # 不是JSON格式假设已经是纯字符串
# 检查是否是带引号的字符串 # 检查是否是带引号的字符串
@@ -145,16 +149,16 @@ async def migrate_memory_items_to_string():
# 去掉引号 # 去掉引号
clean_content = memory_items_raw[1:-1] clean_content = memory_items_raw[1:-1]
original_length = len(clean_content) original_length = len(clean_content)
# 检查长度并截断 # 检查长度并截断
if len(clean_content) > 100: if len(clean_content) > 100:
clean_content = clean_content[:100] clean_content = clean_content[:100]
migration_stats["truncated_nodes"] += 1 migration_stats["truncated_nodes"] += 1
logger.debug(f"节点 '{concept}' 去引号内容过长,从 {original_length} 字符截断到 100 字符") logger.debug(f"节点 '{concept}' 去引号内容过长,从 {original_length} 字符截断到 100 字符")
node.memory_items = clean_content node.memory_items = clean_content
node.save() node.save()
migration_stats["converted_nodes"] += 1 migration_stats["converted_nodes"] += 1
length_info = f" (截断: {original_length}→100)" if original_length > 100 else "" length_info = f" (截断: {original_length}→100)" if original_length > 100 else ""
logger.debug(f"去除引号节点: {concept}{length_info}") logger.debug(f"去除引号节点: {concept}{length_info}")
@@ -162,29 +166,29 @@ async def migrate_memory_items_to_string():
# 已经是纯字符串格式,检查长度 # 已经是纯字符串格式,检查长度
current_content = memory_items_raw current_content = memory_items_raw
original_length = len(current_content) original_length = len(current_content)
# 检查长度并截断 # 检查长度并截断
if len(current_content) > 100: if len(current_content) > 100:
current_content = current_content[:100] current_content = current_content[:100]
node.memory_items = current_content node.memory_items = current_content
node.save() node.save()
migration_stats["converted_nodes"] += 1 # 算作转换节点 migration_stats["converted_nodes"] += 1 # 算作转换节点
migration_stats["truncated_nodes"] += 1 migration_stats["truncated_nodes"] += 1
logger.debug(f"节点 '{concept}' 纯字符串内容过长,从 {original_length} 字符截断到 100 字符") logger.debug(f"节点 '{concept}' 纯字符串内容过长,从 {original_length} 字符截断到 100 字符")
else: else:
migration_stats["already_string_nodes"] += 1 migration_stats["already_string_nodes"] += 1
logger.debug(f"已是字符串格式节点: {concept}") logger.debug(f"已是字符串格式节点: {concept}")
except Exception as e: except Exception as e:
migration_stats["error_nodes"] += 1 migration_stats["error_nodes"] += 1
logger.error(f"处理节点 {concept} 时发生错误: {e}") logger.error(f"处理节点 {concept} 时发生错误: {e}")
continue continue
except Exception as e: except Exception as e:
logger.error(f"迁移过程中发生严重错误: {e}") logger.error(f"迁移过程中发生严重错误: {e}")
raise raise
# 输出迁移统计 # 输出迁移统计
logger.info("=== 记忆节点迁移完成 ===") logger.info("=== 记忆节点迁移完成 ===")
logger.info(f"总节点数: {migration_stats['total_nodes']}") logger.info(f"总节点数: {migration_stats['total_nodes']}")
@@ -194,101 +198,105 @@ async def migrate_memory_items_to_string():
logger.info(f"错误节点: {migration_stats['error_nodes']}") logger.info(f"错误节点: {migration_stats['error_nodes']}")
logger.info(f"权重更新节点: {migration_stats['weight_updated_nodes']}") logger.info(f"权重更新节点: {migration_stats['weight_updated_nodes']}")
logger.info(f"内容截断节点: {migration_stats['truncated_nodes']}") logger.info(f"内容截断节点: {migration_stats['truncated_nodes']}")
success_rate = (migration_stats['converted_nodes'] + migration_stats['already_string_nodes']) / migration_stats['total_nodes'] * 100 if migration_stats['total_nodes'] > 0 else 0 success_rate = (
(migration_stats["converted_nodes"] + migration_stats["already_string_nodes"])
/ migration_stats["total_nodes"]
* 100
if migration_stats["total_nodes"] > 0
else 0
)
logger.info(f"迁移成功率: {success_rate:.1f}%") logger.info(f"迁移成功率: {success_rate:.1f}%")
return migration_stats return migration_stats
async def set_all_person_known(): async def set_all_person_known():
""" """
将person_info库中所有记录的is_known字段设置为True 将person_info库中所有记录的is_known字段设置为True
在设置之前先清理掉user_id或platform为空的记录 在设置之前先清理掉user_id或platform为空的记录
""" """
logger.info("开始设置所有person_info记录为已认识...") logger.info("开始设置所有person_info记录为已认识...")
try: try:
from src.common.database.database_model import PersonInfo from src.common.database.database_model import PersonInfo
# 获取所有PersonInfo记录 # 获取所有PersonInfo记录
all_persons = PersonInfo.select() all_persons = PersonInfo.select()
total_count = all_persons.count() total_count = all_persons.count()
logger.info(f"找到 {total_count} 个人员记录") logger.info(f"找到 {total_count} 个人员记录")
if total_count == 0: if total_count == 0:
logger.info("没有找到任何人员记录") logger.info("没有找到任何人员记录")
return {"total": 0, "deleted": 0, "updated": 0, "known_count": 0} return {"total": 0, "deleted": 0, "updated": 0, "known_count": 0}
# 删除user_id或platform为空的记录 # 删除user_id或platform为空的记录
deleted_count = 0 deleted_count = 0
invalid_records = PersonInfo.select().where( invalid_records = PersonInfo.select().where(
(PersonInfo.user_id.is_null()) | (PersonInfo.user_id.is_null())
(PersonInfo.user_id == '') | | (PersonInfo.user_id == "")
(PersonInfo.platform.is_null()) | | (PersonInfo.platform.is_null())
(PersonInfo.platform == '') | (PersonInfo.platform == "")
) )
# 记录要删除的记录信息 # 记录要删除的记录信息
for record in invalid_records: for record in invalid_records:
user_id_info = f"'{record.user_id}'" if record.user_id else "NULL" user_id_info = f"'{record.user_id}'" if record.user_id else "NULL"
platform_info = f"'{record.platform}'" if record.platform else "NULL" platform_info = f"'{record.platform}'" if record.platform else "NULL"
person_name_info = f"'{record.person_name}'" if record.person_name else "无名称" person_name_info = f"'{record.person_name}'" if record.person_name else "无名称"
logger.debug(f"删除无效记录: person_id={record.person_id}, user_id={user_id_info}, platform={platform_info}, person_name={person_name_info}") logger.debug(
f"删除无效记录: person_id={record.person_id}, user_id={user_id_info}, platform={platform_info}, person_name={person_name_info}"
)
# 执行删除操作 # 执行删除操作
deleted_count = PersonInfo.delete().where( deleted_count = (
(PersonInfo.user_id.is_null()) | PersonInfo.delete()
(PersonInfo.user_id == '') | .where(
(PersonInfo.platform.is_null()) | (PersonInfo.user_id.is_null())
(PersonInfo.platform == '') | (PersonInfo.user_id == "")
).execute() | (PersonInfo.platform.is_null())
| (PersonInfo.platform == "")
)
.execute()
)
if deleted_count > 0: if deleted_count > 0:
logger.info(f"删除了 {deleted_count} 个user_id或platform为空的记录") logger.info(f"删除了 {deleted_count} 个user_id或platform为空的记录")
else: else:
logger.info("没有发现user_id或platform为空的记录") logger.info("没有发现user_id或platform为空的记录")
# 重新获取剩余记录数量 # 重新获取剩余记录数量
remaining_count = PersonInfo.select().count() remaining_count = PersonInfo.select().count()
logger.info(f"清理后剩余 {remaining_count} 个有效记录") logger.info(f"清理后剩余 {remaining_count} 个有效记录")
if remaining_count == 0: if remaining_count == 0:
logger.info("清理后没有剩余记录") logger.info("清理后没有剩余记录")
return {"total": total_count, "deleted": deleted_count, "updated": 0, "known_count": 0} return {"total": total_count, "deleted": deleted_count, "updated": 0, "known_count": 0}
# 批量更新剩余记录的is_known字段为True # 批量更新剩余记录的is_known字段为True
updated_count = PersonInfo.update(is_known=True).execute() updated_count = PersonInfo.update(is_known=True).execute()
logger.info(f"成功更新 {updated_count} 个人员记录的is_known字段为True") logger.info(f"成功更新 {updated_count} 个人员记录的is_known字段为True")
# 验证更新结果 # 验证更新结果
known_count = PersonInfo.select().where(PersonInfo.is_known).count() known_count = PersonInfo.select().where(PersonInfo.is_known).count()
result = { result = {"total": total_count, "deleted": deleted_count, "updated": updated_count, "known_count": known_count}
"total": total_count,
"deleted": deleted_count,
"updated": updated_count,
"known_count": known_count
}
logger.info("=== person_info更新完成 ===") logger.info("=== person_info更新完成 ===")
logger.info(f"原始记录数: {result['total']}") logger.info(f"原始记录数: {result['total']}")
logger.info(f"删除记录数: {result['deleted']}") logger.info(f"删除记录数: {result['deleted']}")
logger.info(f"更新记录数: {result['updated']}") logger.info(f"更新记录数: {result['updated']}")
logger.info(f"已认识记录数: {result['known_count']}") logger.info(f"已认识记录数: {result['known_count']}")
return result return result
except Exception as e: except Exception as e:
logger.error(f"更新person_info过程中发生错误: {e}") logger.error(f"更新person_info过程中发生错误: {e}")
raise raise
async def check_and_run_migrations(): async def check_and_run_migrations():
# 获取根目录 # 获取根目录
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
@@ -309,4 +317,3 @@ async def check_and_run_migrations():
# 创建done.mem文件 # 创建done.mem文件
with open(done_file, "w", encoding="utf-8") as f: with open(done_file, "w", encoding="utf-8") as f:
f.write("done") f.write("done")

View File

@@ -282,7 +282,7 @@ class Person:
memory_category = parts[0].strip() memory_category = parts[0].strip()
memory_text = parts[1].strip() memory_text = parts[1].strip()
memory_weight = parts[2].strip() _memory_weight = parts[2].strip()
# 检查分类是否匹配 # 检查分类是否匹配
if memory_category != category: if memory_category != category:

View File

@@ -1,11 +1,5 @@
import json
from json_repair import repair_json
from datetime import datetime
from src.common.logger import get_logger from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest from src.chat.utils.prompt_builder import Prompt
from src.config.config import global_config, model_config
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from .person_info import Person
logger = get_logger("relation") logger = get_logger("relation")
@@ -43,4 +37,3 @@ def init_prompt():
""", """,
"attitude_to_me_prompt", "attitude_to_me_prompt",
) )

View File

@@ -121,5 +121,5 @@ __all__ = [
"DatabaseChatInfo", "DatabaseChatInfo",
"TargetPersonInfo", "TargetPersonInfo",
"ActionPlannerInfo", "ActionPlannerInfo",
"LLMGenerationDataModel" "LLMGenerationDataModel",
] ]

View File

@@ -7,22 +7,24 @@ logger = get_logger("frequency_api")
def get_current_focus_value(chat_id: str) -> float: def get_current_focus_value(chat_id: str) -> float:
return frequency_control_manager.get_or_create_frequency_control(chat_id).get_final_focus_value() return frequency_control_manager.get_or_create_frequency_control(chat_id).get_final_focus_value()
def get_current_talk_frequency(chat_id: str) -> float: def get_current_talk_frequency(chat_id: str) -> float:
return frequency_control_manager.get_or_create_frequency_control(chat_id).get_final_talk_frequency() return frequency_control_manager.get_or_create_frequency_control(chat_id).get_final_talk_frequency()
def set_focus_value_adjust(chat_id: str, focus_value_adjust: float) -> None: def set_focus_value_adjust(chat_id: str, focus_value_adjust: float) -> None:
frequency_control_manager.get_or_create_frequency_control(chat_id).focus_value_external_adjust = focus_value_adjust frequency_control_manager.get_or_create_frequency_control(chat_id).focus_value_external_adjust = focus_value_adjust
def set_talk_frequency_adjust(chat_id: str, talk_frequency_adjust: float) -> None: def set_talk_frequency_adjust(chat_id: str, talk_frequency_adjust: float) -> None:
frequency_control_manager.get_or_create_frequency_control(chat_id).talk_frequency_external_adjust = talk_frequency_adjust frequency_control_manager.get_or_create_frequency_control(
chat_id
).talk_frequency_external_adjust = talk_frequency_adjust
def get_focus_value_adjust(chat_id: str) -> float: def get_focus_value_adjust(chat_id: str) -> float:
return frequency_control_manager.get_or_create_frequency_control(chat_id).focus_value_external_adjust return frequency_control_manager.get_or_create_frequency_control(chat_id).focus_value_external_adjust
def get_talk_frequency_adjust(chat_id: str) -> float: def get_talk_frequency_adjust(chat_id: str) -> float:
return frequency_control_manager.get_or_create_frequency_control(chat_id).talk_frequency_external_adjust return frequency_control_manager.get_or_create_frequency_control(chat_id).talk_frequency_external_adjust

View File

@@ -72,7 +72,9 @@ async def generate_with_model(
llm_request = LLMRequest(model_set=model_config, request_type=request_type) llm_request = LLMRequest(model_set=model_config, request_type=request_type)
response, (reasoning_content, model_name, _) = await llm_request.generate_response_async(prompt, temperature=temperature, max_tokens=max_tokens) response, (reasoning_content, model_name, _) = await llm_request.generate_response_async(
prompt, temperature=temperature, max_tokens=max_tokens
)
return True, response, reasoning_content, model_name return True, response, reasoning_content, model_name
except Exception as e: except Exception as e:
@@ -80,6 +82,7 @@ async def generate_with_model(
logger.error(f"[LLMAPI] {error_msg}") logger.error(f"[LLMAPI] {error_msg}")
return False, error_msg, "", "" return False, error_msg, "", ""
async def generate_with_model_with_tools( async def generate_with_model_with_tools(
prompt: str, prompt: str,
model_config: TaskConfig, model_config: TaskConfig,
@@ -109,10 +112,7 @@ async def generate_with_model_with_tools(
llm_request = LLMRequest(model_set=model_config, request_type=request_type) llm_request = LLMRequest(model_set=model_config, request_type=request_type)
response, (reasoning_content, model_name, tool_call) = await llm_request.generate_response_async( response, (reasoning_content, model_name, tool_call) = await llm_request.generate_response_async(
prompt, prompt, tools=tool_options, temperature=temperature, max_tokens=max_tokens
tools=tool_options,
temperature=temperature,
max_tokens=max_tokens
) )
return True, response, reasoning_content, model_name, tool_call return True, response, reasoning_content, model_name, tool_call

View File

@@ -435,9 +435,7 @@ def build_readable_messages_to_str(
Returns: Returns:
格式化后的可读字符串 格式化后的可读字符串
""" """
return build_readable_messages( return build_readable_messages(messages, replace_bot_name, timestamp_mode, read_mark, truncate, show_actions)
messages, replace_bot_name, timestamp_mode, read_mark, truncate, show_actions
)
async def build_readable_messages_with_details( async def build_readable_messages_with_details(
@@ -491,8 +489,6 @@ def filter_mai_messages(messages: List[DatabaseMessages]) -> List[DatabaseMessag
return [msg for msg in messages if msg.user_info.user_id != str(global_config.bot.qq_account)] return [msg for msg in messages if msg.user_info.user_id != str(global_config.bot.qq_account)]
def translate_pid_to_description(pid: str) -> str: def translate_pid_to_description(pid: str) -> str:
image = Images.get_or_none(Images.image_id == pid) image = Images.get_or_none(Images.image_id == pid)
description = "" description = ""
@@ -500,4 +496,4 @@ def translate_pid_to_description(pid: str) -> str:
description = image.description description = image.description
else: else:
description = "[图片]" description = "[图片]"
return description return description

View File

@@ -34,7 +34,7 @@ def get_plugin_path(plugin_name: str) -> str:
Returns: Returns:
str: 插件目录的绝对路径。 str: 插件目录的绝对路径。
Raises: Raises:
ValueError: 如果插件不存在。 ValueError: 如果插件不存在。
""" """

View File

@@ -2,7 +2,7 @@ from pathlib import Path
from src.common.logger import get_logger from src.common.logger import get_logger
logger = get_logger("plugin_manager") # 复用plugin_manager名称 logger = get_logger("plugin_manager") # 复用plugin_manager名称
def register_plugin(cls): def register_plugin(cls):

View File

@@ -88,7 +88,7 @@ class GlobalAnnouncementManager:
return False return False
self._user_disabled_tools[chat_id].append(tool_name) self._user_disabled_tools[chat_id].append(tool_name)
return True return True
def enable_specific_chat_tool(self, chat_id: str, tool_name: str) -> bool: def enable_specific_chat_tool(self, chat_id: str, tool_name: str) -> bool:
"""启用特定聊天的某个工具""" """启用特定聊天的某个工具"""
if chat_id in self._user_disabled_tools: if chat_id in self._user_disabled_tools:
@@ -111,7 +111,7 @@ class GlobalAnnouncementManager:
def get_disabled_chat_event_handlers(self, chat_id: str) -> List[str]: def get_disabled_chat_event_handlers(self, chat_id: str) -> List[str]:
"""获取特定聊天禁用的所有事件处理器""" """获取特定聊天禁用的所有事件处理器"""
return self._user_disabled_event_handlers.get(chat_id, []).copy() return self._user_disabled_event_handlers.get(chat_id, []).copy()
def get_disabled_chat_tools(self, chat_id: str) -> List[str]: def get_disabled_chat_tools(self, chat_id: str) -> List[str]:
"""获取特定聊天禁用的所有工具""" """获取特定聊天禁用的所有工具"""
return self._user_disabled_tools.get(chat_id, []).copy() return self._user_disabled_tools.get(chat_id, []).copy()

View File

@@ -224,7 +224,7 @@ class PluginManager:
list: 已注册的插件类名称列表。 list: 已注册的插件类名称列表。
""" """
return list(self.plugin_classes.keys()) return list(self.plugin_classes.keys())
def get_plugin_path(self, plugin_name: str) -> Optional[str]: def get_plugin_path(self, plugin_name: str) -> Optional[str]:
""" """
获取指定插件的路径。 获取指定插件的路径。
@@ -401,9 +401,7 @@ class PluginManager:
command_components = [ command_components = [
c for c in plugin_info.components if c.component_type == ComponentType.COMMAND c for c in plugin_info.components if c.component_type == ComponentType.COMMAND
] ]
tool_components = [ tool_components = [c for c in plugin_info.components if c.component_type == ComponentType.TOOL]
c for c in plugin_info.components if c.component_type == ComponentType.TOOL
]
event_handler_components = [ event_handler_components = [
c for c in plugin_info.components if c.component_type == ComponentType.EVENT_HANDLER c for c in plugin_info.components if c.component_type == ComponentType.EVENT_HANDLER
] ]

View File

@@ -149,10 +149,10 @@ class ToolExecutor:
if not tool_calls: if not tool_calls:
logger.debug(f"{self.log_prefix}无需执行工具") logger.debug(f"{self.log_prefix}无需执行工具")
return [], [] return [], []
# 提取tool_calls中的函数名称 # 提取tool_calls中的函数名称
func_names = [call.func_name for call in tool_calls if call.func_name] func_names = [call.func_name for call in tool_calls if call.func_name]
logger.info(f"{self.log_prefix}开始执行工具调用: {func_names}") logger.info(f"{self.log_prefix}开始执行工具调用: {func_names}")
# 执行每个工具调用 # 执行每个工具调用
@@ -195,7 +195,9 @@ class ToolExecutor:
return tool_results, used_tools return tool_results, used_tools
async def execute_tool_call(self, tool_call: ToolCall, tool_instance: Optional[BaseTool] = None) -> Optional[Dict[str, Any]]: async def execute_tool_call(
self, tool_call: ToolCall, tool_instance: Optional[BaseTool] = None
) -> Optional[Dict[str, Any]]:
# sourcery skip: use-assigned-variable # sourcery skip: use-assigned-variable
"""执行单个工具调用 """执行单个工具调用

View File

@@ -63,5 +63,4 @@ class CoreActionsPlugin(BasePlugin):
if self.get_config("components.enable_emoji", True): if self.get_config("components.enable_emoji", True):
components.append((EmojiAction.get_action_info(), EmojiAction)) components.append((EmojiAction.get_action_info(), EmojiAction))
return components return components

View File

@@ -74,7 +74,9 @@ class BuildMemoryAction(BaseAction):
# 动作基本信息 # 动作基本信息
action_name = "build_memory" action_name = "build_memory"
action_description = "了解对于某个概念或者某件事的记忆,并存储下来,在之后的聊天中,你可以根据这条记忆来获取相关信息" action_description = (
"了解对于某个概念或者某件事的记忆,并存储下来,在之后的聊天中,你可以根据这条记忆来获取相关信息"
)
# 动作参数定义 # 动作参数定义
action_parameters = { action_parameters = {
@@ -103,31 +105,34 @@ class BuildMemoryAction(BaseAction):
concept_name = self.action_data.get("concept_name", "") concept_name = self.action_data.get("concept_name", "")
# 2. 获取目标用户信息 # 2. 获取目标用户信息
# 对 concept_name 进行jieba分词 # 对 concept_name 进行jieba分词
concept_name_tokens = cut_key_words(concept_name) concept_name_tokens = cut_key_words(concept_name)
# logger.info(f"{self.log_prefix} 对 concept_name 进行分词结果: {concept_name_tokens}") # logger.info(f"{self.log_prefix} 对 concept_name 进行分词结果: {concept_name_tokens}")
filtered_concept_name_tokens = [ filtered_concept_name_tokens = [
token for token in concept_name_tokens if all(keyword not in token for keyword in global_config.memory.memory_ban_words) token
for token in concept_name_tokens
if all(keyword not in token for keyword in global_config.memory.memory_ban_words)
] ]
if not filtered_concept_name_tokens: if not filtered_concept_name_tokens:
logger.warning(f"{self.log_prefix} 过滤后的概念名称列表为空,跳过添加记忆") logger.warning(f"{self.log_prefix} 过滤后的概念名称列表为空,跳过添加记忆")
return False, "过滤后的概念名称列表为空,跳过添加记忆" return False, "过滤后的概念名称列表为空,跳过添加记忆"
similar_topics_dict = hippocampus_manager.get_hippocampus().parahippocampal_gyrus.get_similar_topics_from_keywords(filtered_concept_name_tokens) similar_topics_dict = (
await hippocampus_manager.get_hippocampus().parahippocampal_gyrus.add_memory_with_similar(concept_description, similar_topics_dict) hippocampus_manager.get_hippocampus().parahippocampal_gyrus.get_similar_topics_from_keywords(
filtered_concept_name_tokens
)
)
await hippocampus_manager.get_hippocampus().parahippocampal_gyrus.add_memory_with_similar(
concept_description, similar_topics_dict
)
return True, f"成功添加记忆: {concept_name}" return True, f"成功添加记忆: {concept_name}"
except Exception as e: except Exception as e:
logger.error(f"{self.log_prefix} 构建记忆时出错: {e}") logger.error(f"{self.log_prefix} 构建记忆时出错: {e}")
return False, f"构建记忆时出错: {e}" return False, f"构建记忆时出错: {e}"
# 还缺一个关系的太多遗忘和对应的提取 # 还缺一个关系的太多遗忘和对应的提取

View File

@@ -425,7 +425,7 @@ class ManagementCommand(BaseCommand):
await self._send_message(f"本地禁用组件成功: {component_name}") await self._send_message(f"本地禁用组件成功: {component_name}")
else: else:
await self._send_message(f"本地禁用组件失败: {component_name}") await self._send_message(f"本地禁用组件失败: {component_name}")
async def _send_message(self, message: str): async def _send_message(self, message: str):
await send_api.text_to_stream(message, self.stream_id, typing=False, storage_message=False) await send_api.text_to_stream(message, self.stream_id, typing=False, storage_message=False)

View File

@@ -107,7 +107,7 @@ class BuildRelationAction(BaseAction):
if not person.is_known: if not person.is_known:
logger.warning(f"{self.log_prefix} 用户 {person_name} 不存在,跳过添加记忆") logger.warning(f"{self.log_prefix} 用户 {person_name} 不存在,跳过添加记忆")
return False, f"用户 {person_name} 不存在,跳过添加记忆" return False, f"用户 {person_name} 不存在,跳过添加记忆"
person.last_know = time.time() person.last_know = time.time()
person.know_times += 1 person.know_times += 1
person.sync_to_database() person.sync_to_database()
@@ -178,7 +178,9 @@ class BuildRelationAction(BaseAction):
chat_model_config = models.get("utils") chat_model_config = models.get("utils")
success, update_memory, _, _ = await llm_api.generate_with_model( success, update_memory, _, _ = await llm_api.generate_with_model(
prompt, model_config=chat_model_config, request_type="relation.category.update" # type: ignore prompt,
model_config=chat_model_config,
request_type="relation.category.update", # type: ignore
) )
update_memory_data = json.loads(repair_json(update_memory)) update_memory_data = json.loads(repair_json(update_memory))
@@ -190,7 +192,7 @@ class BuildRelationAction(BaseAction):
# 新记忆 # 新记忆
person.memory_points.append(f"{category}:{new_memory}:1.0") person.memory_points.append(f"{category}:{new_memory}:1.0")
person.sync_to_database() person.sync_to_database()
logger.info(f"{self.log_prefix}{person.person_name}新增记忆点: {new_memory}") logger.info(f"{self.log_prefix}{person.person_name}新增记忆点: {new_memory}")
return True, f"{person.person_name}新增记忆点: {new_memory}" return True, f"{person.person_name}新增记忆点: {new_memory}"
@@ -207,14 +209,15 @@ class BuildRelationAction(BaseAction):
person.memory_points.append(f"{category}:{integrate_memory}:{memory_weight + 1.0}") person.memory_points.append(f"{category}:{integrate_memory}:{memory_weight + 1.0}")
person.sync_to_database() person.sync_to_database()
logger.info(f"{self.log_prefix} 更新{person.person_name}的记忆点: {memory_content} -> {integrate_memory}") logger.info(
f"{self.log_prefix} 更新{person.person_name}的记忆点: {memory_content} -> {integrate_memory}"
)
return True, f"更新{person.person_name}的记忆点: {memory_content} -> {integrate_memory}" return True, f"更新{person.person_name}的记忆点: {memory_content} -> {integrate_memory}"
else: else:
logger.warning(f"{self.log_prefix} 删除记忆点失败: {memory_content}") logger.warning(f"{self.log_prefix} 删除记忆点失败: {memory_content}")
return False, f"删除{person.person_name}的记忆点失败: {memory_content}" return False, f"删除{person.person_name}的记忆点失败: {memory_content}"
return True, "关系动作执行成功" return True, "关系动作执行成功"