From 163dbb6b9039946534410c2208aaf49bca70578f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A2=A8=E6=A2=93=E6=9F=92?= <1787882683@qq.com> Date: Tue, 9 Sep 2025 19:25:12 +0800 Subject: [PATCH] =?UTF-8?q?=E8=B6=85=E7=BA=A7Ruff?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- bot.py | 3 +- scripts/expression_stats.py | 118 ++++----- scripts/import_openie.py | 3 +- scripts/info_extraction.py | 13 +- scripts/interest_value_analysis.py | 176 +++++++------ scripts/log_viewer_optimized.py | 10 +- scripts/raw_data_preprocessor.py | 5 +- scripts/text_length_analysis.py | 235 +++++++++--------- src/chat/emoji_system/emoji_manager.py | 4 +- src/chat/express/expression_selector.py | 14 +- .../frequency_control/focus_value_control.py | 1 - .../talk_frequency_control.py | 2 - src/chat/frequency_control/utils.py | 2 +- src/chat/heart_flow/heartFC_chat.py | 14 +- src/chat/heart_flow/heartflow.py | 7 +- .../heart_flow/heartflow_message_processor.py | 33 +-- src/chat/heart_flow/hfc_utils.py | 3 +- src/chat/knowledge/__init__.py | 1 + src/chat/knowledge/embedding_store.py | 146 ++++++----- src/chat/knowledge/kg_manager.py | 4 +- src/chat/knowledge/mem_active_manager.py | 8 +- src/chat/knowledge/utils/dyn_topk.py | 2 +- src/chat/message_receive/storage.py | 4 +- src/chat/planner_actions/action_manager.py | 2 +- src/chat/planner_actions/action_modifier.py | 30 +-- src/chat/utils/statistic.py | 28 ++- src/chat/utils/utils.py | 67 +++-- src/chat/utils/utils_image.py | 4 +- src/common/data_models/__init__.py | 1 + src/common/data_models/database_data_model.py | 3 +- src/common/data_models/llm_data_model.py | 4 +- src/common/database/database_model.py | 181 +++++++------- src/common/logger.py | 17 +- src/config/config.py | 4 +- src/config/official_configs.py | 16 +- src/llm_models/model_client/base_client.py | 2 +- src/llm_models/model_client/openai_client.py | 2 +- src/llm_models/payload_content/__init__.py | 2 +- src/llm_models/payload_content/resp_format.py | 15 +- src/llm_models/utils.py | 13 +- src/mais4u/mais4u_chat/context_web_manager.py | 185 +++++++------- src/mais4u/mais4u_chat/gift_manager.py | 78 +++--- src/mais4u/mais4u_chat/internal_manager.py | 11 +- src/mais4u/mais4u_chat/s4u_chat.py | 119 ++++----- src/mais4u/mais4u_chat/s4u_msg_processor.py | 79 +++--- src/mais4u/mais4u_chat/s4u_prompt.py | 10 +- .../mais4u_chat/s4u_stream_generator.py | 23 +- .../mais4u_chat/s4u_watching_manager.py | 11 +- src/mais4u/mais4u_chat/screen_manager.py | 11 +- src/mais4u/mais4u_chat/super_chat_manager.py | 109 ++++---- src/mais4u/s4u_config.py | 20 +- src/migrate_helper/migrate.py | 159 ++++++------ src/person_info/person_info.py | 2 +- src/person_info/relationship_manager.py | 9 +- src/plugin_system/__init__.py | 2 +- src/plugin_system/apis/frequency_api.py | 18 +- src/plugin_system/apis/generator_api.py | 1 + src/plugin_system/apis/llm_api.py | 10 +- src/plugin_system/apis/message_api.py | 8 +- src/plugin_system/apis/plugin_manage_api.py | 2 +- src/plugin_system/apis/plugin_register_api.py | 2 +- .../core/global_announcement_manager.py | 4 +- src/plugin_system/core/plugin_manager.py | 6 +- src/plugin_system/core/tool_use.py | 8 +- src/plugins/built_in/emoji_plugin/plugin.py | 1 - src/plugins/built_in/memory/build_memory.py | 33 +-- .../built_in/plugin_management/plugin.py | 2 +- src/plugins/built_in/relation/relation.py | 13 +- 68 files changed, 1092 insertions(+), 1043 deletions(-) diff --git a/bot.py b/bot.py index ea5244f2..bb7d72ea 100644 --- a/bot.py +++ b/bot.py @@ -62,9 +62,10 @@ def easter_egg(): async def graceful_shutdown(): # sourcery skip: use-named-expression try: logger.info("正在优雅关闭麦麦...") - + from src.plugin_system.core.events_manager import events_manager from src.plugin_system.base.component_types import EventType + # 触发 ON_STOP 事件 await events_manager.handle_mai_events(event_type=EventType.ON_STOP) diff --git a/scripts/expression_stats.py b/scripts/expression_stats.py index 4e761d8d..133f3d73 100644 --- a/scripts/expression_stats.py +++ b/scripts/expression_stats.py @@ -5,12 +5,11 @@ from typing import Dict, List # Add project root to Python path from src.common.database.database_model import Expression, ChatStreams + project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) sys.path.insert(0, project_root) - - def get_chat_name(chat_id: str) -> str: """Get chat name from chat_id by querying ChatStreams table directly""" try: @@ -18,7 +17,7 @@ def get_chat_name(chat_id: str) -> str: chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id) if chat_stream is None: return f"未知聊天 ({chat_id})" - + # 如果有群组信息,显示群组名称 if chat_stream.group_name: return f"{chat_stream.group_name} ({chat_id})" @@ -35,117 +34,106 @@ def calculate_time_distribution(expressions) -> Dict[str, int]: """Calculate distribution of last active time in days""" now = time.time() distribution = { - '0-1天': 0, - '1-3天': 0, - '3-7天': 0, - '7-14天': 0, - '14-30天': 0, - '30-60天': 0, - '60-90天': 0, - '90+天': 0 + "0-1天": 0, + "1-3天": 0, + "3-7天": 0, + "7-14天": 0, + "14-30天": 0, + "30-60天": 0, + "60-90天": 0, + "90+天": 0, } for expr in expressions: - diff_days = (now - expr.last_active_time) / (24*3600) + diff_days = (now - expr.last_active_time) / (24 * 3600) if diff_days < 1: - distribution['0-1天'] += 1 + distribution["0-1天"] += 1 elif diff_days < 3: - distribution['1-3天'] += 1 + distribution["1-3天"] += 1 elif diff_days < 7: - distribution['3-7天'] += 1 + distribution["3-7天"] += 1 elif diff_days < 14: - distribution['7-14天'] += 1 + distribution["7-14天"] += 1 elif diff_days < 30: - distribution['14-30天'] += 1 + distribution["14-30天"] += 1 elif diff_days < 60: - distribution['30-60天'] += 1 + distribution["30-60天"] += 1 elif diff_days < 90: - distribution['60-90天'] += 1 + distribution["60-90天"] += 1 else: - distribution['90+天'] += 1 + distribution["90+天"] += 1 return distribution def calculate_count_distribution(expressions) -> Dict[str, int]: """Calculate distribution of count values""" - distribution = { - '0-1': 0, - '1-2': 0, - '2-3': 0, - '3-4': 0, - '4-5': 0, - '5-10': 0, - '10+': 0 - } + distribution = {"0-1": 0, "1-2": 0, "2-3": 0, "3-4": 0, "4-5": 0, "5-10": 0, "10+": 0} for expr in expressions: cnt = expr.count if cnt < 1: - distribution['0-1'] += 1 + distribution["0-1"] += 1 elif cnt < 2: - distribution['1-2'] += 1 + distribution["1-2"] += 1 elif cnt < 3: - distribution['2-3'] += 1 + distribution["2-3"] += 1 elif cnt < 4: - distribution['3-4'] += 1 + distribution["3-4"] += 1 elif cnt < 5: - distribution['4-5'] += 1 + distribution["4-5"] += 1 elif cnt < 10: - distribution['5-10'] += 1 + distribution["5-10"] += 1 else: - distribution['10+'] += 1 + distribution["10+"] += 1 return distribution def get_top_expressions_by_chat(chat_id: str, top_n: int = 5) -> List[Expression]: """Get top N most used expressions for a specific chat_id""" - return (Expression.select() - .where(Expression.chat_id == chat_id) - .order_by(Expression.count.desc()) - .limit(top_n)) + return Expression.select().where(Expression.chat_id == chat_id).order_by(Expression.count.desc()).limit(top_n) def show_overall_statistics(expressions, total: int) -> None: """Show overall statistics""" time_dist = calculate_time_distribution(expressions) count_dist = calculate_count_distribution(expressions) - + print("\n=== 总体统计 ===") print(f"总表达式数量: {total}") - + print("\n上次激活时间分布:") for period, count in time_dist.items(): - print(f"{period}: {count} ({count/total*100:.2f}%)") - + print(f"{period}: {count} ({count / total * 100:.2f}%)") + print("\ncount分布:") for range_, count in count_dist.items(): - print(f"{range_}: {count} ({count/total*100:.2f}%)") + print(f"{range_}: {count} ({count / total * 100:.2f}%)") def show_chat_statistics(chat_id: str, chat_name: str) -> None: """Show statistics for a specific chat""" chat_exprs = list(Expression.select().where(Expression.chat_id == chat_id)) chat_total = len(chat_exprs) - + print(f"\n=== {chat_name} ===") print(f"表达式数量: {chat_total}") - + if chat_total == 0: print("该聊天没有表达式数据") return - + # Time distribution for this chat time_dist = calculate_time_distribution(chat_exprs) print("\n上次激活时间分布:") for period, count in time_dist.items(): if count > 0: - print(f"{period}: {count} ({count/chat_total*100:.2f}%)") - + print(f"{period}: {count} ({count / chat_total * 100:.2f}%)") + # Count distribution for this chat count_dist = calculate_count_distribution(chat_exprs) print("\ncount分布:") for range_, count in count_dist.items(): if count > 0: - print(f"{range_}: {count} ({count/chat_total*100:.2f}%)") - + print(f"{range_}: {count} ({count / chat_total * 100:.2f}%)") + # Top expressions print("\nTop 10使用最多的表达式:") top_exprs = get_top_expressions_by_chat(chat_id, 10) @@ -163,32 +151,32 @@ def interactive_menu() -> None: if not expressions: print("数据库中没有找到表达式") return - + total = len(expressions) - + # Get unique chat_ids and their names chat_ids = list(set(expr.chat_id for expr in expressions)) chat_info = [(chat_id, get_chat_name(chat_id)) for chat_id in chat_ids] chat_info.sort(key=lambda x: x[1]) # Sort by chat name - + while True: - print("\n" + "="*50) + print("\n" + "=" * 50) print("表达式统计分析") - print("="*50) + print("=" * 50) print("0. 显示总体统计") - + for i, (chat_id, chat_name) in enumerate(chat_info, 1): chat_count = sum(1 for expr in expressions if expr.chat_id == chat_id) print(f"{i}. {chat_name} ({chat_count}个表达式)") - + print("q. 退出") - + choice = input("\n请选择要查看的统计 (输入序号): ").strip() - - if choice.lower() == 'q': + + if choice.lower() == "q": print("再见!") break - + try: choice_num = int(choice) if choice_num == 0: @@ -200,9 +188,9 @@ def interactive_menu() -> None: print("无效的选择,请重新输入") except ValueError: print("请输入有效的数字") - + input("\n按回车键继续...") if __name__ == "__main__": - interactive_menu() \ No newline at end of file + interactive_menu() diff --git a/scripts/import_openie.py b/scripts/import_openie.py index c4367892..f9405f59 100644 --- a/scripts/import_openie.py +++ b/scripts/import_openie.py @@ -23,6 +23,7 @@ OPENIE_DIR = os.path.join(ROOT_PATH, "data", "openie") logger = get_logger("OpenIE导入") + def ensure_openie_dir(): """确保OpenIE数据目录存在""" if not os.path.exists(OPENIE_DIR): @@ -253,7 +254,7 @@ def main(): # 没有运行的事件循环,创建新的 loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) - + try: # 在新的事件循环中运行异步主函数 loop.run_until_complete(main_async()) diff --git a/scripts/info_extraction.py b/scripts/info_extraction.py index 47ad55a8..391c3470 100644 --- a/scripts/info_extraction.py +++ b/scripts/info_extraction.py @@ -12,6 +12,7 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) from rich.progress import Progress # 替换为 rich 进度条 from src.common.logger import get_logger + # from src.chat.knowledge.lpmmconfig import global_config from src.chat.knowledge.ie_process import info_extract_from_str from src.chat.knowledge.open_ie import OpenIE @@ -36,6 +37,7 @@ TEMP_DIR = os.path.join(ROOT_PATH, "temp") # IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, "data", "imported_lpmm_data") OPENIE_OUTPUT_DIR = os.path.join(ROOT_PATH, "data", "openie") + def ensure_dirs(): """确保临时目录和输出目录存在""" if not os.path.exists(TEMP_DIR): @@ -48,6 +50,7 @@ def ensure_dirs(): os.makedirs(RAW_DATA_PATH) logger.info(f"已创建原始数据目录: {RAW_DATA_PATH}") + # 创建一个线程安全的锁,用于保护文件操作和共享数据 file_lock = Lock() open_ie_doc_lock = Lock() @@ -56,13 +59,11 @@ open_ie_doc_lock = Lock() shutdown_event = Event() lpmm_entity_extract_llm = LLMRequest( - model_set=model_config.model_task_config.lpmm_entity_extract, - request_type="lpmm.entity_extract" -) -lpmm_rdf_build_llm = LLMRequest( - model_set=model_config.model_task_config.lpmm_rdf_build, - request_type="lpmm.rdf_build" + model_set=model_config.model_task_config.lpmm_entity_extract, request_type="lpmm.entity_extract" ) +lpmm_rdf_build_llm = LLMRequest(model_set=model_config.model_task_config.lpmm_rdf_build, request_type="lpmm.rdf_build") + + def process_single_text(pg_hash, raw_data): """处理单个文本的函数,用于线程池""" temp_file_path = f"{TEMP_DIR}/{pg_hash}.json" diff --git a/scripts/interest_value_analysis.py b/scripts/interest_value_analysis.py index fba1f160..bce37b4a 100644 --- a/scripts/interest_value_analysis.py +++ b/scripts/interest_value_analysis.py @@ -3,12 +3,11 @@ import sys import os from typing import Dict, List, Tuple, Optional from datetime import datetime + # Add project root to Python path project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) sys.path.insert(0, project_root) -from src.common.database.database_model import Messages, ChatStreams #noqa - - +from src.common.database.database_model import Messages, ChatStreams # noqa def get_chat_name(chat_id: str) -> str: @@ -17,7 +16,7 @@ def get_chat_name(chat_id: str) -> str: chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id) if chat_stream is None: return f"未知聊天 ({chat_id})" - + if chat_stream.group_name: return f"{chat_stream.group_name} ({chat_id})" elif chat_stream.user_nickname: @@ -39,66 +38,62 @@ def format_timestamp(timestamp: float) -> str: def calculate_interest_value_distribution(messages) -> Dict[str, int]: """Calculate distribution of interest_value""" distribution = { - '0.000-0.010': 0, - '0.010-0.050': 0, - '0.050-0.100': 0, - '0.100-0.500': 0, - '0.500-1.000': 0, - '1.000-2.000': 0, - '2.000-5.000': 0, - '5.000-10.000': 0, - '10.000+': 0 + "0.000-0.010": 0, + "0.010-0.050": 0, + "0.050-0.100": 0, + "0.100-0.500": 0, + "0.500-1.000": 0, + "1.000-2.000": 0, + "2.000-5.000": 0, + "5.000-10.000": 0, + "10.000+": 0, } - + for msg in messages: if msg.interest_value is None or msg.interest_value == 0.0: continue - + value = float(msg.interest_value) if value < 0.010: - distribution['0.000-0.010'] += 1 + distribution["0.000-0.010"] += 1 elif value < 0.050: - distribution['0.010-0.050'] += 1 + distribution["0.010-0.050"] += 1 elif value < 0.100: - distribution['0.050-0.100'] += 1 + distribution["0.050-0.100"] += 1 elif value < 0.500: - distribution['0.100-0.500'] += 1 + distribution["0.100-0.500"] += 1 elif value < 1.000: - distribution['0.500-1.000'] += 1 + distribution["0.500-1.000"] += 1 elif value < 2.000: - distribution['1.000-2.000'] += 1 + distribution["1.000-2.000"] += 1 elif value < 5.000: - distribution['2.000-5.000'] += 1 + distribution["2.000-5.000"] += 1 elif value < 10.000: - distribution['5.000-10.000'] += 1 + distribution["5.000-10.000"] += 1 else: - distribution['10.000+'] += 1 - + distribution["10.000+"] += 1 + return distribution def get_interest_value_stats(messages) -> Dict[str, float]: """Calculate basic statistics for interest_value""" - values = [float(msg.interest_value) for msg in messages if msg.interest_value is not None and msg.interest_value != 0.0] - + values = [ + float(msg.interest_value) for msg in messages if msg.interest_value is not None and msg.interest_value != 0.0 + ] + if not values: - return { - 'count': 0, - 'min': 0, - 'max': 0, - 'avg': 0, - 'median': 0 - } - + return {"count": 0, "min": 0, "max": 0, "avg": 0, "median": 0} + values.sort() count = len(values) - + return { - 'count': count, - 'min': min(values), - 'max': max(values), - 'avg': sum(values) / count, - 'median': values[count // 2] if count % 2 == 1 else (values[count // 2 - 1] + values[count // 2]) / 2 + "count": count, + "min": min(values), + "max": max(values), + "avg": sum(values) / count, + "median": values[count // 2] if count % 2 == 1 else (values[count // 2 - 1] + values[count // 2]) / 2, } @@ -109,20 +104,24 @@ def get_available_chats() -> List[Tuple[str, str, int]]: chat_counts = {} for msg in Messages.select(Messages.chat_id).distinct(): chat_id = msg.chat_id - count = Messages.select().where( - (Messages.chat_id == chat_id) & - (Messages.interest_value.is_null(False)) & - (Messages.interest_value != 0.0) - ).count() + count = ( + Messages.select() + .where( + (Messages.chat_id == chat_id) + & (Messages.interest_value.is_null(False)) + & (Messages.interest_value != 0.0) + ) + .count() + ) if count > 0: chat_counts[chat_id] = count - + # 获取聊天名称 result = [] for chat_id, count in chat_counts.items(): chat_name = get_chat_name(chat_id) result.append((chat_id, chat_name, count)) - + # 按消息数量排序 result.sort(key=lambda x: x[2], reverse=True) return result @@ -135,30 +134,30 @@ def get_time_range_input() -> Tuple[Optional[float], Optional[float]]: """Get time range input from user""" print("\n时间范围选择:") print("1. 最近1天") - print("2. 最近3天") + print("2. 最近3天") print("3. 最近7天") print("4. 最近30天") print("5. 自定义时间范围") print("6. 不限制时间") - + choice = input("请选择时间范围 (1-6): ").strip() - + now = time.time() - + if choice == "1": - return now - 24*3600, now + return now - 24 * 3600, now elif choice == "2": - return now - 3*24*3600, now + return now - 3 * 24 * 3600, now elif choice == "3": - return now - 7*24*3600, now + return now - 7 * 24 * 3600, now elif choice == "4": - return now - 30*24*3600, now + return now - 30 * 24 * 3600, now elif choice == "5": print("请输入开始时间 (格式: YYYY-MM-DD HH:MM:SS):") start_str = input().strip() print("请输入结束时间 (格式: YYYY-MM-DD HH:MM:SS):") end_str = input().strip() - + try: start_time = datetime.strptime(start_str, "%Y-%m-%d %H:%M:%S").timestamp() end_time = datetime.strptime(end_str, "%Y-%m-%d %H:%M:%S").timestamp() @@ -170,41 +169,40 @@ def get_time_range_input() -> Tuple[Optional[float], Optional[float]]: return None, None -def analyze_interest_values(chat_id: Optional[str] = None, start_time: Optional[float] = None, end_time: Optional[float] = None) -> None: +def analyze_interest_values( + chat_id: Optional[str] = None, start_time: Optional[float] = None, end_time: Optional[float] = None +) -> None: """Analyze interest values with optional filters""" - + # 构建查询条件 - query = Messages.select().where( - (Messages.interest_value.is_null(False)) & - (Messages.interest_value != 0.0) - ) - + query = Messages.select().where((Messages.interest_value.is_null(False)) & (Messages.interest_value != 0.0)) + if chat_id: query = query.where(Messages.chat_id == chat_id) - + if start_time: query = query.where(Messages.time >= start_time) - + if end_time: query = query.where(Messages.time <= end_time) - + messages = list(query) - + if not messages: print("没有找到符合条件的消息") return - + # 计算统计信息 distribution = calculate_interest_value_distribution(messages) stats = get_interest_value_stats(messages) - + # 显示结果 print("\n=== Interest Value 分析结果 ===") if chat_id: print(f"聊天: {get_chat_name(chat_id)}") else: print("聊天: 全部聊天") - + if start_time and end_time: print(f"时间范围: {format_timestamp(start_time)} 到 {format_timestamp(end_time)}") elif start_time: @@ -213,16 +211,16 @@ def analyze_interest_values(chat_id: Optional[str] = None, start_time: Optional[ print(f"时间范围: {format_timestamp(end_time)} 之前") else: print("时间范围: 不限制") - + print("\n基本统计:") print(f"有效消息数量: {stats['count']} (排除null和0值)") print(f"最小值: {stats['min']:.3f}") print(f"最大值: {stats['max']:.3f}") print(f"平均值: {stats['avg']:.3f}") print(f"中位数: {stats['median']:.3f}") - + print("\nInterest Value 分布:") - total = stats['count'] + total = stats["count"] for range_name, count in distribution.items(): if count > 0: percentage = count / total * 100 @@ -231,34 +229,34 @@ def analyze_interest_values(chat_id: Optional[str] = None, start_time: Optional[ def interactive_menu() -> None: """Interactive menu for interest value analysis""" - + while True: - print("\n" + "="*50) + print("\n" + "=" * 50) print("Interest Value 分析工具") - print("="*50) + print("=" * 50) print("1. 分析全部聊天") print("2. 选择特定聊天分析") print("q. 退出") - + choice = input("\n请选择分析模式 (1-2, q): ").strip() - - if choice.lower() == 'q': + + if choice.lower() == "q": print("再见!") break - + chat_id = None - + if choice == "2": # 显示可用的聊天列表 chats = get_available_chats() if not chats: print("没有找到有interest_value数据的聊天") continue - + print(f"\n可用的聊天 (共{len(chats)}个):") for i, (_cid, name, count) in enumerate(chats, 1): print(f"{i}. {name} ({count}条有效消息)") - + try: chat_choice = int(input(f"\n请选择聊天 (1-{len(chats)}): ").strip()) if 1 <= chat_choice <= len(chats): @@ -269,19 +267,19 @@ def interactive_menu() -> None: except ValueError: print("请输入有效数字") continue - + elif choice != "1": print("无效选择") continue - + # 获取时间范围 start_time, end_time = get_time_range_input() - + # 执行分析 analyze_interest_values(chat_id, start_time, end_time) - + input("\n按回车键继续...") if __name__ == "__main__": - interactive_menu() \ No newline at end of file + interactive_menu() diff --git a/scripts/log_viewer_optimized.py b/scripts/log_viewer_optimized.py index b11db1ba..8dd14d35 100644 --- a/scripts/log_viewer_optimized.py +++ b/scripts/log_viewer_optimized.py @@ -828,7 +828,7 @@ class LogViewer: parts, tags = self.formatter.format_log_entry(log_entry) line_text = " ".join(parts) log_lines.append(line_text) - + with open(filename, "w", encoding="utf-8") as f: f.write("\n".join(log_lines)) messagebox.showinfo("导出成功", f"日志已导出到: {filename}") @@ -1188,15 +1188,16 @@ class LogViewer: line_count += 1 except json.JSONDecodeError: continue - + # 如果发现了新模块,在主线程中更新模块集合 if new_modules: + def update_modules(): self.modules.update(new_modules) self.update_module_list() - + self.root.after(0, update_modules) - + return new_entries def append_new_logs(self, new_entries): @@ -1424,4 +1425,3 @@ def main(): if __name__ == "__main__": main() - diff --git a/scripts/raw_data_preprocessor.py b/scripts/raw_data_preprocessor.py index 42a99133..b5762198 100644 --- a/scripts/raw_data_preprocessor.py +++ b/scripts/raw_data_preprocessor.py @@ -2,6 +2,7 @@ import os from pathlib import Path import sys # 新增系统模块导入 from src.chat.knowledge.utils.hash import get_sha256 + sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) from src.common.logger import get_logger @@ -10,6 +11,7 @@ ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) RAW_DATA_PATH = os.path.join(ROOT_PATH, "data/lpmm_raw_data") # IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, "data/imported_lpmm_data") + def _process_text_file(file_path): """处理单个文本文件,返回段落列表""" with open(file_path, "r", encoding="utf-8") as f: @@ -44,6 +46,7 @@ def _process_multi_files() -> list: all_paragraphs.extend(paragraphs) return all_paragraphs + def load_raw_data() -> tuple[list[str], list[str]]: """加载原始数据文件 @@ -72,4 +75,4 @@ def load_raw_data() -> tuple[list[str], list[str]]: raw_data.append(item) logger.info(f"共读取到{len(raw_data)}条数据") - return sha256_list, raw_data \ No newline at end of file + return sha256_list, raw_data diff --git a/scripts/text_length_analysis.py b/scripts/text_length_analysis.py index 2ca596e2..5a329b93 100644 --- a/scripts/text_length_analysis.py +++ b/scripts/text_length_analysis.py @@ -4,21 +4,22 @@ import os import re from typing import Dict, List, Tuple, Optional from datetime import datetime + # Add project root to Python path project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) sys.path.insert(0, project_root) -from src.common.database.database_model import Messages, ChatStreams #noqa +from src.common.database.database_model import Messages, ChatStreams # noqa def contains_emoji_or_image_tags(text: str) -> bool: """Check if text contains [表情包xxxxx] or [图片xxxxx] tags""" if not text: return False - + # 检查是否包含 [表情包] 或 [图片] 标记 - emoji_pattern = r'\[表情包[^\]]*\]' - image_pattern = r'\[图片[^\]]*\]' - + emoji_pattern = r"\[表情包[^\]]*\]" + image_pattern = r"\[图片[^\]]*\]" + return bool(re.search(emoji_pattern, text) or re.search(image_pattern, text)) @@ -26,14 +27,14 @@ def clean_reply_text(text: str) -> str: """Remove reply references like [回复 xxxx...] from text""" if not text: return text - + # 匹配 [回复 xxxx...] 格式的内容 # 使用非贪婪匹配,匹配到第一个 ] 就停止 - cleaned_text = re.sub(r'\[回复[^\]]*\]', '', text) - + cleaned_text = re.sub(r"\[回复[^\]]*\]", "", text) + # 去除多余的空白字符 cleaned_text = cleaned_text.strip() - + return cleaned_text @@ -43,7 +44,7 @@ def get_chat_name(chat_id: str) -> str: chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id) if chat_stream is None: return f"未知聊天 ({chat_id})" - + if chat_stream.group_name: return f"{chat_stream.group_name} ({chat_id})" elif chat_stream.user_nickname: @@ -65,63 +66,63 @@ def format_timestamp(timestamp: float) -> str: def calculate_text_length_distribution(messages) -> Dict[str, int]: """Calculate distribution of processed_plain_text length""" distribution = { - '0': 0, # 空文本 - '1-5': 0, # 极短文本 - '6-10': 0, # 很短文本 - '11-20': 0, # 短文本 - '21-30': 0, # 较短文本 - '31-50': 0, # 中短文本 - '51-70': 0, # 中等文本 - '71-100': 0, # 较长文本 - '101-150': 0, # 长文本 - '151-200': 0, # 很长文本 - '201-300': 0, # 超长文本 - '301-500': 0, # 极长文本 - '501-1000': 0, # 巨长文本 - '1000+': 0 # 超巨长文本 + "0": 0, # 空文本 + "1-5": 0, # 极短文本 + "6-10": 0, # 很短文本 + "11-20": 0, # 短文本 + "21-30": 0, # 较短文本 + "31-50": 0, # 中短文本 + "51-70": 0, # 中等文本 + "71-100": 0, # 较长文本 + "101-150": 0, # 长文本 + "151-200": 0, # 很长文本 + "201-300": 0, # 超长文本 + "301-500": 0, # 极长文本 + "501-1000": 0, # 巨长文本 + "1000+": 0, # 超巨长文本 } - + for msg in messages: if msg.processed_plain_text is None: continue - + # 排除包含表情包或图片标记的消息 if contains_emoji_or_image_tags(msg.processed_plain_text): continue - + # 清理文本中的回复引用 cleaned_text = clean_reply_text(msg.processed_plain_text) length = len(cleaned_text) - + if length == 0: - distribution['0'] += 1 + distribution["0"] += 1 elif length <= 5: - distribution['1-5'] += 1 + distribution["1-5"] += 1 elif length <= 10: - distribution['6-10'] += 1 + distribution["6-10"] += 1 elif length <= 20: - distribution['11-20'] += 1 + distribution["11-20"] += 1 elif length <= 30: - distribution['21-30'] += 1 + distribution["21-30"] += 1 elif length <= 50: - distribution['31-50'] += 1 + distribution["31-50"] += 1 elif length <= 70: - distribution['51-70'] += 1 + distribution["51-70"] += 1 elif length <= 100: - distribution['71-100'] += 1 + distribution["71-100"] += 1 elif length <= 150: - distribution['101-150'] += 1 + distribution["101-150"] += 1 elif length <= 200: - distribution['151-200'] += 1 + distribution["151-200"] += 1 elif length <= 300: - distribution['201-300'] += 1 + distribution["201-300"] += 1 elif length <= 500: - distribution['301-500'] += 1 + distribution["301-500"] += 1 elif length <= 1000: - distribution['501-1000'] += 1 + distribution["501-1000"] += 1 else: - distribution['1000+'] += 1 - + distribution["1000+"] += 1 + return distribution @@ -130,7 +131,7 @@ def get_text_length_stats(messages) -> Dict[str, float]: lengths = [] null_count = 0 excluded_count = 0 # 被排除的消息数量 - + for msg in messages: if msg.processed_plain_text is None: null_count += 1 @@ -141,29 +142,29 @@ def get_text_length_stats(messages) -> Dict[str, float]: # 清理文本中的回复引用 cleaned_text = clean_reply_text(msg.processed_plain_text) lengths.append(len(cleaned_text)) - + if not lengths: return { - 'count': 0, - 'null_count': null_count, - 'excluded_count': excluded_count, - 'min': 0, - 'max': 0, - 'avg': 0, - 'median': 0 + "count": 0, + "null_count": null_count, + "excluded_count": excluded_count, + "min": 0, + "max": 0, + "avg": 0, + "median": 0, } - + lengths.sort() count = len(lengths) - + return { - 'count': count, - 'null_count': null_count, - 'excluded_count': excluded_count, - 'min': min(lengths), - 'max': max(lengths), - 'avg': sum(lengths) / count, - 'median': lengths[count // 2] if count % 2 == 1 else (lengths[count // 2 - 1] + lengths[count // 2]) / 2 + "count": count, + "null_count": null_count, + "excluded_count": excluded_count, + "min": min(lengths), + "max": max(lengths), + "avg": sum(lengths) / count, + "median": lengths[count // 2] if count % 2 == 1 else (lengths[count // 2 - 1] + lengths[count // 2]) / 2, } @@ -174,21 +175,25 @@ def get_available_chats() -> List[Tuple[str, str, int]]: chat_counts = {} for msg in Messages.select(Messages.chat_id).distinct(): chat_id = msg.chat_id - count = Messages.select().where( - (Messages.chat_id == chat_id) & - (Messages.is_emoji != 1) & - (Messages.is_picid != 1) & - (Messages.is_command != 1) - ).count() + count = ( + Messages.select() + .where( + (Messages.chat_id == chat_id) + & (Messages.is_emoji != 1) + & (Messages.is_picid != 1) + & (Messages.is_command != 1) + ) + .count() + ) if count > 0: chat_counts[chat_id] = count - + # 获取聊天名称 result = [] for chat_id, count in chat_counts.items(): chat_name = get_chat_name(chat_id) result.append((chat_id, chat_name, count)) - + # 按消息数量排序 result.sort(key=lambda x: x[2], reverse=True) return result @@ -201,30 +206,30 @@ def get_time_range_input() -> Tuple[Optional[float], Optional[float]]: """Get time range input from user""" print("\n时间范围选择:") print("1. 最近1天") - print("2. 最近3天") + print("2. 最近3天") print("3. 最近7天") print("4. 最近30天") print("5. 自定义时间范围") print("6. 不限制时间") - + choice = input("请选择时间范围 (1-6): ").strip() - + now = time.time() - + if choice == "1": - return now - 24*3600, now + return now - 24 * 3600, now elif choice == "2": - return now - 3*24*3600, now + return now - 3 * 24 * 3600, now elif choice == "3": - return now - 7*24*3600, now + return now - 7 * 24 * 3600, now elif choice == "4": - return now - 30*24*3600, now + return now - 30 * 24 * 3600, now elif choice == "5": print("请输入开始时间 (格式: YYYY-MM-DD HH:MM:SS):") start_str = input().strip() print("请输入结束时间 (格式: YYYY-MM-DD HH:MM:SS):") end_str = input().strip() - + try: start_time = datetime.strptime(start_str, "%Y-%m-%d %H:%M:%S").timestamp() end_time = datetime.strptime(end_str, "%Y-%m-%d %H:%M:%S").timestamp() @@ -239,13 +244,13 @@ def get_time_range_input() -> Tuple[Optional[float], Optional[float]]: def get_top_longest_messages(messages, top_n: int = 10) -> List[Tuple[str, int, str, str]]: """Get top N longest messages""" message_lengths = [] - + for msg in messages: if msg.processed_plain_text is not None: # 排除包含表情包或图片标记的消息 if contains_emoji_or_image_tags(msg.processed_plain_text): continue - + # 清理文本中的回复引用 cleaned_text = clean_reply_text(msg.processed_plain_text) length = len(cleaned_text) @@ -254,42 +259,40 @@ def get_top_longest_messages(messages, top_n: int = 10) -> List[Tuple[str, int, # 截取前100个字符作为预览 preview = cleaned_text[:100] + "..." if len(cleaned_text) > 100 else cleaned_text message_lengths.append((chat_name, length, time_str, preview)) - + # 按长度排序,取前N个 message_lengths.sort(key=lambda x: x[1], reverse=True) return message_lengths[:top_n] -def analyze_text_lengths(chat_id: Optional[str] = None, start_time: Optional[float] = None, end_time: Optional[float] = None) -> None: +def analyze_text_lengths( + chat_id: Optional[str] = None, start_time: Optional[float] = None, end_time: Optional[float] = None +) -> None: """Analyze processed_plain_text lengths with optional filters""" - + # 构建查询条件,排除特殊类型的消息 - query = Messages.select().where( - (Messages.is_emoji != 1) & - (Messages.is_picid != 1) & - (Messages.is_command != 1) - ) - + query = Messages.select().where((Messages.is_emoji != 1) & (Messages.is_picid != 1) & (Messages.is_command != 1)) + if chat_id: query = query.where(Messages.chat_id == chat_id) - + if start_time: query = query.where(Messages.time >= start_time) - + if end_time: query = query.where(Messages.time <= end_time) - + messages = list(query) - + if not messages: print("没有找到符合条件的消息") return - + # 计算统计信息 distribution = calculate_text_length_distribution(messages) stats = get_text_length_stats(messages) top_longest = get_top_longest_messages(messages, 10) - + # 显示结果 print("\n=== Processed Plain Text 长度分析结果 ===") print("(已排除表情、图片ID、命令类型消息,已排除[表情包]和[图片]标记消息,已清理回复引用)") @@ -297,7 +300,7 @@ def analyze_text_lengths(chat_id: Optional[str] = None, start_time: Optional[flo print(f"聊天: {get_chat_name(chat_id)}") else: print("聊天: 全部聊天") - + if start_time and end_time: print(f"时间范围: {format_timestamp(start_time)} 到 {format_timestamp(end_time)}") elif start_time: @@ -306,26 +309,26 @@ def analyze_text_lengths(chat_id: Optional[str] = None, start_time: Optional[flo print(f"时间范围: {format_timestamp(end_time)} 之前") else: print("时间范围: 不限制") - + print("\n基本统计:") print(f"总消息数量: {len(messages)}") print(f"有文本消息数量: {stats['count']}") print(f"空文本消息数量: {stats['null_count']}") print(f"被排除的消息数量: {stats['excluded_count']}") - if stats['count'] > 0: + if stats["count"] > 0: print(f"最短长度: {stats['min']} 字符") print(f"最长长度: {stats['max']} 字符") print(f"平均长度: {stats['avg']:.2f} 字符") print(f"中位数长度: {stats['median']:.2f} 字符") - + print("\n文本长度分布:") - total = stats['count'] + total = stats["count"] if total > 0: for range_name, count in distribution.items(): if count > 0: percentage = count / total * 100 print(f"{range_name} 字符: {count} ({percentage:.2f}%)") - + # 显示最长的消息 if top_longest: print(f"\n最长的 {len(top_longest)} 条消息:") @@ -338,34 +341,34 @@ def analyze_text_lengths(chat_id: Optional[str] = None, start_time: Optional[flo def interactive_menu() -> None: """Interactive menu for text length analysis""" - + while True: - print("\n" + "="*50) + print("\n" + "=" * 50) print("Processed Plain Text 长度分析工具") - print("="*50) + print("=" * 50) print("1. 分析全部聊天") print("2. 选择特定聊天分析") print("q. 退出") - + choice = input("\n请选择分析模式 (1-2, q): ").strip() - - if choice.lower() == 'q': + + if choice.lower() == "q": print("再见!") break - + chat_id = None - + if choice == "2": # 显示可用的聊天列表 chats = get_available_chats() if not chats: print("没有找到聊天数据") continue - + print(f"\n可用的聊天 (共{len(chats)}个):") for i, (_cid, name, count) in enumerate(chats, 1): print(f"{i}. {name} ({count}条消息)") - + try: chat_choice = int(input(f"\n请选择聊天 (1-{len(chats)}): ").strip()) if 1 <= chat_choice <= len(chats): @@ -376,19 +379,19 @@ def interactive_menu() -> None: except ValueError: print("请输入有效数字") continue - + elif choice != "1": print("无效选择") continue - + # 获取时间范围 start_time, end_time = get_time_range_input() - + # 执行分析 analyze_text_lengths(chat_id, start_time, end_time) - + input("\n按回车键继续...") if __name__ == "__main__": - interactive_menu() \ No newline at end of file + interactive_menu() diff --git a/src/chat/emoji_system/emoji_manager.py b/src/chat/emoji_system/emoji_manager.py index 47a50865..b143f0f7 100644 --- a/src/chat/emoji_system/emoji_manager.py +++ b/src/chat/emoji_system/emoji_manager.py @@ -708,7 +708,7 @@ class EmojiManager: if not emoji.is_deleted and emoji.hash == emoji_hash: return emoji return None # 如果循环结束还没找到,则返回 None - + async def get_emoji_tag_by_hash(self, emoji_hash: str) -> Optional[List[str]]: """根据哈希值获取已注册表情包的情感标签列表 @@ -731,7 +731,7 @@ class EmojiManager: emoji_record = Emoji.get_or_none(Emoji.emoji_hash == emoji_hash) if emoji_record and emoji_record.emotion: logger.info(f"[缓存命中] 从数据库获取表情包情感标签: {emoji_record.emotion[:50]}...") - return emoji_record.emotion.split(',') + return emoji_record.emotion.split(",") except Exception as e: logger.error(f"从数据库查询表情包情感标签时出错: {e}") diff --git a/src/chat/express/expression_selector.py b/src/chat/express/expression_selector.py index 8716d6bc..5ab5115b 100644 --- a/src/chat/express/expression_selector.py +++ b/src/chat/express/expression_selector.py @@ -77,10 +77,10 @@ class ExpressionSelector: def can_use_expression_for_chat(self, chat_id: str) -> bool: """ 检查指定聊天流是否允许使用表达 - + Args: chat_id: 聊天流ID - + Returns: bool: 是否允许使用表达 """ @@ -123,9 +123,7 @@ class ExpressionSelector: return group_chat_ids return [chat_id] - def get_random_expressions( - self, chat_id: str, total_num: int - ) -> List[Dict[str, Any]]: + def get_random_expressions(self, chat_id: str, total_num: int) -> List[Dict[str, Any]]: # sourcery skip: extract-duplicate-method, move-assign # 支持多chat_id合并抽选 related_chat_ids = self.get_related_chat_ids(chat_id) @@ -200,7 +198,7 @@ class ExpressionSelector: ) -> Tuple[List[Dict[str, Any]], List[int]]: # sourcery skip: inline-variable, list-comprehension """使用LLM选择适合的表达方式""" - + # 检查是否允许在此聊天流中使用表达 if not self.can_use_expression_for_chat(chat_id): logger.debug(f"聊天流 {chat_id} 不允许使用表达,返回空列表") @@ -208,7 +206,7 @@ class ExpressionSelector: # 1. 获取20个随机表达方式(现在按权重抽取) style_exprs = self.get_random_expressions(chat_id, 10) - + if len(style_exprs) < 10: logger.info(f"聊天流 {chat_id} 表达方式正在积累中") return [], [] @@ -248,7 +246,6 @@ class ExpressionSelector: # 4. 调用LLM try: - # start_time = time.time() content, (reasoning_content, model_name, _) = await self.llm_model.generate_response_async(prompt=prompt) # logger.info(f"LLM请求时间: {model_name} {time.time() - start_time} \n{prompt}") @@ -295,7 +292,6 @@ class ExpressionSelector: except Exception as e: logger.error(f"LLM处理表达方式选择时出错: {e}") return [], [] - init_prompt() diff --git a/src/chat/frequency_control/focus_value_control.py b/src/chat/frequency_control/focus_value_control.py index be820760..e6fdf1f3 100644 --- a/src/chat/frequency_control/focus_value_control.py +++ b/src/chat/frequency_control/focus_value_control.py @@ -119,4 +119,3 @@ def get_global_focus_value() -> Optional[float]: return get_time_based_focus_value(config_item[1:]) return None - diff --git a/src/chat/frequency_control/talk_frequency_control.py b/src/chat/frequency_control/talk_frequency_control.py index 11728e26..ccf7ed66 100644 --- a/src/chat/frequency_control/talk_frequency_control.py +++ b/src/chat/frequency_control/talk_frequency_control.py @@ -124,5 +124,3 @@ def get_global_frequency() -> Optional[float]: return get_time_based_frequency(config_item[1:]) return None - - diff --git a/src/chat/frequency_control/utils.py b/src/chat/frequency_control/utils.py index 4cbd7979..037e02a5 100644 --- a/src/chat/frequency_control/utils.py +++ b/src/chat/frequency_control/utils.py @@ -34,4 +34,4 @@ def parse_stream_config_to_chat_id(stream_config_str: str) -> Optional[str]: return hashlib.md5(key.encode()).hexdigest() except (ValueError, IndexError): - return None \ No newline at end of file + return None diff --git a/src/chat/heart_flow/heartFC_chat.py b/src/chat/heart_flow/heartFC_chat.py index f3d9becb..86add908 100644 --- a/src/chat/heart_flow/heartFC_chat.py +++ b/src/chat/heart_flow/heartFC_chat.py @@ -155,21 +155,21 @@ class HeartFChatting: timer_strings.append(f"{name}: {formatted_time}") # 获取动作类型,兼容新旧格式 - action_type = "未知动作" + _action_type = "未知动作" if hasattr(self, "_current_cycle_detail") and self._current_cycle_detail: loop_plan_info = self._current_cycle_detail.loop_plan_info if isinstance(loop_plan_info, dict): action_result = loop_plan_info.get("action_result", {}) if isinstance(action_result, dict): # 旧格式:action_result是字典 - action_type = action_result.get("action_type", "未知动作") + _action_type = action_result.get("action_type", "未知动作") elif isinstance(action_result, list) and action_result: # 新格式:action_result是actions列表 # TODO: 把这里写明白 - action_type = action_result[0].action_type or "未知动作" + _action_type = action_result[0].action_type or "未知动作" elif isinstance(loop_plan_info, list) and loop_plan_info: # 直接是actions列表的情况 - action_type = loop_plan_info[0].get("action_type", "未知动作") + _action_type = loop_plan_info[0].get("action_type", "未知动作") logger.info( f"{self.log_prefix} 第{self._current_cycle_detail.cycle_id}次思考," @@ -258,7 +258,11 @@ class HeartFChatting: return loop_info, reply_text, cycle_timers - async def _observe(self, interest_value: float = 0.0, recent_messages_list: List["DatabaseMessages"] = []) -> bool: + async def _observe( + self, interest_value: float = 0.0, recent_messages_list: Optional[List["DatabaseMessages"]] = None + ) -> bool: + if recent_messages_list is None: + recent_messages_list = [] reply_text = "" # 初始化reply_text变量,避免UnboundLocalError # 使用sigmoid函数将interest_value转换为概率 diff --git a/src/chat/heart_flow/heartflow.py b/src/chat/heart_flow/heartflow.py index 9f5c0423..7354b9ac 100644 --- a/src/chat/heart_flow/heartflow.py +++ b/src/chat/heart_flow/heartflow.py @@ -3,14 +3,16 @@ from typing import Any, Optional, Dict from src.common.logger import get_logger from src.chat.heart_flow.heartFC_chat import HeartFChatting + logger = get_logger("heartflow") + class Heartflow: """主心流协调器,负责初始化并协调聊天""" def __init__(self): self.heartflow_chat_list: Dict[Any, HeartFChatting] = {} - + async def get_or_create_heartflow_chat(self, chat_id: Any) -> Optional[HeartFChatting]: """获取或创建一个新的HeartFChatting实例""" try: @@ -18,7 +20,7 @@ class Heartflow: if chat := self.heartflow_chat_list.get(chat_id): return chat else: - new_chat = HeartFChatting(chat_id = chat_id) + new_chat = HeartFChatting(chat_id=chat_id) await new_chat.start() self.heartflow_chat_list[chat_id] = new_chat return new_chat @@ -27,4 +29,5 @@ class Heartflow: traceback.print_exc() return None + heartflow = Heartflow() diff --git a/src/chat/heart_flow/heartflow_message_processor.py b/src/chat/heart_flow/heartflow_message_processor.py index ac424c66..326d7516 100644 --- a/src/chat/heart_flow/heartflow_message_processor.py +++ b/src/chat/heart_flow/heartflow_message_processor.py @@ -23,6 +23,7 @@ if TYPE_CHECKING: logger = get_logger("chat") + async def _calculate_interest(message: MessageRecv) -> Tuple[float, list[str]]: """计算消息的兴趣度 @@ -34,14 +35,14 @@ async def _calculate_interest(message: MessageRecv) -> Tuple[float, list[str]]: """ if message.is_picid or message.is_emoji: return 0.0, [] - - is_mentioned,is_at,reply_probability_boost = is_mentioned_bot_in_message(message) + + is_mentioned, is_at, reply_probability_boost = is_mentioned_bot_in_message(message) interested_rate = 0.0 with Timer("记忆激活"): - interested_rate, keywords,keywords_lite = await hippocampus_manager.get_activate_from_text( + interested_rate, keywords, keywords_lite = await hippocampus_manager.get_activate_from_text( message.processed_plain_text, - max_depth= 4, + max_depth=4, fast_retrieval=global_config.chat.interest_rate_mode == "fast", ) message.key_words = keywords @@ -51,7 +52,7 @@ async def _calculate_interest(message: MessageRecv) -> Tuple[float, list[str]]: text_len = len(message.processed_plain_text) # 根据文本长度分布调整兴趣度,采用分段函数实现更精确的兴趣度计算 # 基于实际分布:0-5字符(26.57%), 6-10字符(27.18%), 11-20字符(22.76%), 21-30字符(10.33%), 31+字符(13.86%) - + if text_len == 0: base_interest = 0.01 # 空消息最低兴趣度 elif text_len <= 5: @@ -75,16 +76,15 @@ async def _calculate_interest(message: MessageRecv) -> Tuple[float, list[str]]: else: # 100+字符:对数增长 0.26 -> 0.3,增长率递减 base_interest = 0.26 + (0.3 - 0.26) * (math.log10(text_len - 99) / math.log10(901)) # 1000-99=901 - + # 确保在范围内 base_interest = min(max(base_interest, 0.01), 0.3) - message.interest_value = base_interest message.is_mentioned = is_mentioned message.is_at = is_at message.reply_probability_boost = reply_probability_boost - + return base_interest, keywords @@ -115,14 +115,13 @@ class HeartFCMessageReceiver: # 2. 兴趣度计算与更新 interested_rate, keywords = await _calculate_interest(message) - await self.storage.store_message(message, chat) heartflow_chat: HeartFChatting = await heartflow.get_or_create_heartflow_chat(chat.stream_id) # type: ignore # subheartflow.add_message_to_normal_chat_cache(message, interested_rate, is_mentioned) - if global_config.mood.enable_mood: + if global_config.mood.enable_mood: chat_mood = mood_manager.get_mood_by_chat_id(heartflow_chat.stream_id) asyncio.create_task(chat_mood.update_mood_by_message(message, interested_rate)) @@ -132,7 +131,7 @@ class HeartFCMessageReceiver: # 用这个pattern截取出id部分,picid是一个list,并替换成对应的图片描述 picid_pattern = r"\[picid:([^\]]+)\]" picid_list = re.findall(picid_pattern, message.processed_plain_text) - + # 创建替换后的文本 processed_text = message.processed_plain_text if picid_list: @@ -145,18 +144,20 @@ class HeartFCMessageReceiver: # 如果没有找到图片描述,则移除[picid:xxxx]标记 processed_text = processed_text.replace(f"[picid:{picid}]", "[图片:网络不好,图片无法加载]") - # 应用用户引用格式替换,将回复和@格式转换为可读格式 processed_plain_text = replace_user_references( processed_text, - message.message_info.platform, # type: ignore - replace_bot_name=True + message.message_info.platform, # type: ignore + replace_bot_name=True, ) - logger.info(f"[{mes_name}]{userinfo.user_nickname}:{processed_plain_text}[{interested_rate:.2f}]") # type: ignore - _ = Person.register_person(platform=message.message_info.platform, user_id=message.message_info.user_info.user_id,nickname=userinfo.user_nickname) # type: ignore + _ = Person.register_person( + platform=message.message_info.platform, + user_id=message.message_info.user_info.user_id, + nickname=userinfo.user_nickname, + ) # type: ignore except Exception as e: logger.error(f"消息处理失败: {e}") diff --git a/src/chat/heart_flow/hfc_utils.py b/src/chat/heart_flow/hfc_utils.py index 973c4f94..9a715a2d 100644 --- a/src/chat/heart_flow/hfc_utils.py +++ b/src/chat/heart_flow/hfc_utils.py @@ -124,6 +124,7 @@ async def send_typing(): message_type="state", content="typing", stream_id=chat.stream_id, storage_message=False ) + async def stop_typing(): group_info = GroupInfo(platform="amaidesu_default", group_id="114514", group_name="内心") @@ -135,4 +136,4 @@ async def stop_typing(): await send_api.custom_to_stream( message_type="state", content="stop_typing", stream_id=chat.stream_id, storage_message=False - ) \ No newline at end of file + ) diff --git a/src/chat/knowledge/__init__.py b/src/chat/knowledge/__init__.py index 38f88e10..324320f2 100644 --- a/src/chat/knowledge/__init__.py +++ b/src/chat/knowledge/__init__.py @@ -30,6 +30,7 @@ DATA_PATH = os.path.join(ROOT_PATH, "data") qa_manager = None inspire_manager = None + def lpmm_start_up(): # sourcery skip: extract-duplicate-method # 检查LPMM知识库是否启用 if global_config.lpmm_knowledge.enable: diff --git a/src/chat/knowledge/embedding_store.py b/src/chat/knowledge/embedding_store.py index 7e3695fe..768373cf 100644 --- a/src/chat/knowledge/embedding_store.py +++ b/src/chat/knowledge/embedding_store.py @@ -32,11 +32,11 @@ install(extra_lines=3) # 多线程embedding配置常量 DEFAULT_MAX_WORKERS = 10 # 默认最大线程数 -DEFAULT_CHUNK_SIZE = 10 # 默认每个线程处理的数据块大小 -MIN_CHUNK_SIZE = 1 # 最小分块大小 -MAX_CHUNK_SIZE = 50 # 最大分块大小 -MIN_WORKERS = 1 # 最小线程数 -MAX_WORKERS = 20 # 最大线程数 +DEFAULT_CHUNK_SIZE = 10 # 默认每个线程处理的数据块大小 +MIN_CHUNK_SIZE = 1 # 最小分块大小 +MAX_CHUNK_SIZE = 50 # 最大分块大小 +MIN_WORKERS = 1 # 最小线程数 +MAX_WORKERS = 20 # 最大线程数 ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..")) EMBEDDING_DATA_DIR = os.path.join(ROOT_PATH, "data", "embedding") @@ -93,7 +93,13 @@ class EmbeddingStoreItem: class EmbeddingStore: - def __init__(self, namespace: str, dir_path: str, max_workers: int = DEFAULT_MAX_WORKERS, chunk_size: int = DEFAULT_CHUNK_SIZE): + def __init__( + self, + namespace: str, + dir_path: str, + max_workers: int = DEFAULT_MAX_WORKERS, + chunk_size: int = DEFAULT_CHUNK_SIZE, + ): self.namespace = namespace self.dir = dir_path self.embedding_file_path = f"{dir_path}/{namespace}.parquet" @@ -103,12 +109,16 @@ class EmbeddingStore: # 多线程配置参数验证和设置 self.max_workers = max(MIN_WORKERS, min(MAX_WORKERS, max_workers)) self.chunk_size = max(MIN_CHUNK_SIZE, min(MAX_CHUNK_SIZE, chunk_size)) - + # 如果配置值被调整,记录日志 if self.max_workers != max_workers: - logger.warning(f"max_workers 已从 {max_workers} 调整为 {self.max_workers} (范围: {MIN_WORKERS}-{MAX_WORKERS})") + logger.warning( + f"max_workers 已从 {max_workers} 调整为 {self.max_workers} (范围: {MIN_WORKERS}-{MAX_WORKERS})" + ) if self.chunk_size != chunk_size: - logger.warning(f"chunk_size 已从 {chunk_size} 调整为 {self.chunk_size} (范围: {MIN_CHUNK_SIZE}-{MAX_CHUNK_SIZE})") + logger.warning( + f"chunk_size 已从 {chunk_size} 调整为 {self.chunk_size} (范围: {MIN_CHUNK_SIZE}-{MAX_CHUNK_SIZE})" + ) self.store = {} @@ -120,23 +130,23 @@ class EmbeddingStore: # 创建新的事件循环并在完成后立即关闭 loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) - + try: # 创建新的LLMRequest实例 from src.llm_models.utils_model import LLMRequest from src.config.config import model_config - + llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding") - + # 使用新的事件循环运行异步方法 embedding, _ = loop.run_until_complete(llm.get_embedding(s)) - + if embedding and len(embedding) > 0: return embedding else: logger.error(f"获取嵌入失败: {s}") return [] - + except Exception as e: logger.error(f"获取嵌入时发生异常: {s}, 错误: {e}") return [] @@ -147,43 +157,45 @@ class EmbeddingStore: except Exception: pass - def _get_embeddings_batch_threaded(self, strs: List[str], chunk_size: int = 10, max_workers: int = 10, progress_callback=None) -> List[Tuple[str, List[float]]]: + def _get_embeddings_batch_threaded( + self, strs: List[str], chunk_size: int = 10, max_workers: int = 10, progress_callback=None + ) -> List[Tuple[str, List[float]]]: """使用多线程批量获取嵌入向量 - + Args: strs: 要获取嵌入的字符串列表 chunk_size: 每个线程处理的数据块大小 max_workers: 最大线程数 progress_callback: 进度回调函数,接收一个参数表示完成的数量 - + Returns: 包含(原始字符串, 嵌入向量)的元组列表,保持与输入顺序一致 """ if not strs: return [] - + # 分块 chunks = [] for i in range(0, len(strs), chunk_size): - chunk = strs[i:i + chunk_size] + chunk = strs[i : i + chunk_size] chunks.append((i, chunk)) # 保存起始索引以维持顺序 - + # 结果存储,使用字典按索引存储以保证顺序 results = {} - + def process_chunk(chunk_data): """处理单个数据块的函数""" start_idx, chunk_strs = chunk_data chunk_results = [] - + # 为每个线程创建独立的LLMRequest实例 from src.llm_models.utils_model import LLMRequest from src.config.config import model_config - + try: # 创建线程专用的LLM实例 llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding") - + for i, s in enumerate(chunk_strs): try: # 在线程中创建独立的事件循环 @@ -193,25 +205,25 @@ class EmbeddingStore: embedding = loop.run_until_complete(llm.get_embedding(s)) finally: loop.close() - + if embedding and len(embedding) > 0: chunk_results.append((start_idx + i, s, embedding[0])) # embedding[0] 是实际的向量 else: logger.error(f"获取嵌入失败: {s}") chunk_results.append((start_idx + i, s, [])) - + # 每完成一个嵌入立即更新进度 if progress_callback: progress_callback(1) - + except Exception as e: logger.error(f"获取嵌入时发生异常: {s}, 错误: {e}") chunk_results.append((start_idx + i, s, [])) - + # 即使失败也要更新进度 if progress_callback: progress_callback(1) - + except Exception as e: logger.error(f"创建LLM实例失败: {e}") # 如果创建LLM实例失败,返回空结果 @@ -220,14 +232,14 @@ class EmbeddingStore: # 即使失败也要更新进度 if progress_callback: progress_callback(1) - + return chunk_results - + # 使用线程池处理 with ThreadPoolExecutor(max_workers=max_workers) as executor: # 提交所有任务 future_to_chunk = {executor.submit(process_chunk, chunk): chunk for chunk in chunks} - + # 收集结果(进度已在process_chunk中实时更新) for future in as_completed(future_to_chunk): try: @@ -241,7 +253,7 @@ class EmbeddingStore: start_idx, chunk_strs = chunk for i, s in enumerate(chunk_strs): results[start_idx + i] = (s, []) - + # 按原始顺序返回结果 ordered_results = [] for i in range(len(strs)): @@ -250,7 +262,7 @@ class EmbeddingStore: else: # 防止遗漏 ordered_results.append((strs[i], [])) - + return ordered_results def get_test_file_path(self): @@ -259,14 +271,14 @@ class EmbeddingStore: def save_embedding_test_vectors(self): """保存测试字符串的嵌入到本地(使用多线程优化)""" logger.info("开始保存测试字符串的嵌入向量...") - + # 使用多线程批量获取测试字符串的嵌入 embedding_results = self._get_embeddings_batch_threaded( EMBEDDING_TEST_STRINGS, chunk_size=min(self.chunk_size, len(EMBEDDING_TEST_STRINGS)), - max_workers=min(self.max_workers, len(EMBEDDING_TEST_STRINGS)) + max_workers=min(self.max_workers, len(EMBEDDING_TEST_STRINGS)), ) - + # 构建测试向量字典 test_vectors = {} for idx, (s, embedding) in enumerate(embedding_results): @@ -276,10 +288,10 @@ class EmbeddingStore: logger.error(f"获取测试字符串嵌入失败: {s}") # 使用原始单线程方法作为后备 test_vectors[str(idx)] = self._get_embedding(s) - + with open(self.get_test_file_path(), "w", encoding="utf-8") as f: json.dump(test_vectors, f, ensure_ascii=False, indent=2) - + logger.info("测试字符串嵌入向量保存完成") def load_embedding_test_vectors(self): @@ -297,35 +309,35 @@ class EmbeddingStore: logger.warning("未检测到本地嵌入模型测试文件,将保存当前模型的测试嵌入。") self.save_embedding_test_vectors() return True - + # 检查本地向量完整性 for idx in range(len(EMBEDDING_TEST_STRINGS)): if local_vectors.get(str(idx)) is None: logger.warning("本地嵌入模型测试文件缺失部分测试字符串,将重新保存。") self.save_embedding_test_vectors() return True - + logger.info("开始检验嵌入模型一致性...") - + # 使用多线程批量获取当前模型的嵌入 embedding_results = self._get_embeddings_batch_threaded( EMBEDDING_TEST_STRINGS, chunk_size=min(self.chunk_size, len(EMBEDDING_TEST_STRINGS)), - max_workers=min(self.max_workers, len(EMBEDDING_TEST_STRINGS)) + max_workers=min(self.max_workers, len(EMBEDDING_TEST_STRINGS)), ) - + # 检查一致性 for idx, (s, new_emb) in enumerate(embedding_results): local_emb = local_vectors.get(str(idx)) if not new_emb: logger.error(f"获取测试字符串嵌入失败: {s}") return False - + sim = cosine_similarity(local_emb, new_emb) if sim < EMBEDDING_SIM_THRESHOLD: logger.error(f"嵌入模型一致性校验失败,字符串: {s}, 相似度: {sim:.4f}") return False - + logger.info("嵌入模型一致性校验通过。") return True @@ -333,22 +345,22 @@ class EmbeddingStore: """向库中存入字符串(使用多线程优化)""" if not strs: return - + total = len(strs) - + # 过滤已存在的字符串 new_strs = [] for s in strs: item_hash = self.namespace + "-" + get_sha256(s) if item_hash not in self.store: new_strs.append(s) - + if not new_strs: logger.info(f"所有字符串已存在于{self.namespace}嵌入库中,跳过处理") return - + logger.info(f"需要处理 {len(new_strs)}/{total} 个新字符串") - + with Progress( SpinnerColumn(), TextColumn("[progress.description]{task.description}"), @@ -362,31 +374,39 @@ class EmbeddingStore: transient=False, ) as progress: task = progress.add_task(f"存入嵌入库:({times}/{TOTAL_EMBEDDING_TIMES})", total=total) - + # 首先更新已存在项的进度 already_processed = total - len(new_strs) if already_processed > 0: progress.update(task, advance=already_processed) - + if new_strs: # 使用实例配置的参数,智能调整分块和线程数 - optimal_chunk_size = max(MIN_CHUNK_SIZE, min(self.chunk_size, len(new_strs) // self.max_workers if self.max_workers > 0 else self.chunk_size)) - optimal_max_workers = min(self.max_workers, max(MIN_WORKERS, len(new_strs) // optimal_chunk_size if optimal_chunk_size > 0 else 1)) - + optimal_chunk_size = max( + MIN_CHUNK_SIZE, + min( + self.chunk_size, len(new_strs) // self.max_workers if self.max_workers > 0 else self.chunk_size + ), + ) + optimal_max_workers = min( + self.max_workers, + max(MIN_WORKERS, len(new_strs) // optimal_chunk_size if optimal_chunk_size > 0 else 1), + ) + logger.debug(f"使用多线程处理: chunk_size={optimal_chunk_size}, max_workers={optimal_max_workers}") - + # 定义进度更新回调函数 def update_progress(count): progress.update(task, advance=count) - + # 批量获取嵌入,并实时更新进度 embedding_results = self._get_embeddings_batch_threaded( - new_strs, - chunk_size=optimal_chunk_size, + new_strs, + chunk_size=optimal_chunk_size, max_workers=optimal_max_workers, - progress_callback=update_progress + progress_callback=update_progress, ) - + # 存入结果(不再需要在这里更新进度,因为已经在回调中更新了) for s, embedding in embedding_results: item_hash = self.namespace + "-" + get_sha256(s) @@ -519,7 +539,7 @@ class EmbeddingManager: def __init__(self, max_workers: int = DEFAULT_MAX_WORKERS, chunk_size: int = DEFAULT_CHUNK_SIZE): """ 初始化EmbeddingManager - + Args: max_workers: 最大线程数 chunk_size: 每个线程处理的数据块大小 diff --git a/src/chat/knowledge/kg_manager.py b/src/chat/knowledge/kg_manager.py index da082e39..ac86fa20 100644 --- a/src/chat/knowledge/kg_manager.py +++ b/src/chat/knowledge/kg_manager.py @@ -426,9 +426,7 @@ class KGManager: # 获取最终结果 # 从搜索结果中提取文段节点的结果 passage_node_res = [ - (node_key, score) - for node_key, score in ppr_res.items() - if node_key.startswith("paragraph") + (node_key, score) for node_key, score in ppr_res.items() if node_key.startswith("paragraph") ] del ppr_res diff --git a/src/chat/knowledge/mem_active_manager.py b/src/chat/knowledge/mem_active_manager.py index a55b929f..2f294139 100644 --- a/src/chat/knowledge/mem_active_manager.py +++ b/src/chat/knowledge/mem_active_manager.py @@ -1,8 +1,8 @@ raise DeprecationWarning("MemoryActiveManager is not used yet, please do not import it") -from .lpmmconfig import global_config -from .embedding_store import EmbeddingManager -from .llm_client import LLMClient -from .utils.dyn_topk import dyn_select_top_k +from .lpmmconfig import global_config # noqa +from .embedding_store import EmbeddingManager # noqa +from .llm_client import LLMClient # noqa +from .utils.dyn_topk import dyn_select_top_k # noqa class MemoryActiveManager: diff --git a/src/chat/knowledge/utils/dyn_topk.py b/src/chat/knowledge/utils/dyn_topk.py index 5304934f..df9e470d 100644 --- a/src/chat/knowledge/utils/dyn_topk.py +++ b/src/chat/knowledge/utils/dyn_topk.py @@ -8,7 +8,7 @@ def dyn_select_top_k( # 检查输入列表是否为空 if not score: return [] - + # 按照分数排序(降序) sorted_score = sorted(score, key=lambda x: x[1], reverse=True) diff --git a/src/chat/message_receive/storage.py b/src/chat/message_receive/storage.py index 3d84f270..37c9d188 100644 --- a/src/chat/message_receive/storage.py +++ b/src/chat/message_receive/storage.py @@ -18,7 +18,7 @@ class MessageStorage: if isinstance(keywords, list): return json.dumps(keywords, ensure_ascii=False) return "[]" - + @staticmethod def _deserialize_keywords(keywords_str: str) -> list: """将JSON字符串反序列化为关键词列表""" @@ -85,7 +85,7 @@ class MessageStorage: key_words = MessageStorage._serialize_keywords(message.key_words) key_words_lite = MessageStorage._serialize_keywords(message.key_words_lite) selected_expressions = "" - + chat_info_dict = chat_stream.to_dict() user_info_dict = message.message_info.user_info.to_dict() # type: ignore diff --git a/src/chat/planner_actions/action_manager.py b/src/chat/planner_actions/action_manager.py index 1de033bf..013d78e1 100644 --- a/src/chat/planner_actions/action_manager.py +++ b/src/chat/planner_actions/action_manager.py @@ -124,4 +124,4 @@ class ActionManager: """恢复到默认动作集""" actions_to_restore = list(self._using_actions.keys()) self._using_actions = component_registry.get_default_actions() - logger.debug(f"恢复动作集: 从 {actions_to_restore} 恢复到默认动作集 {list(self._using_actions.keys())}") \ No newline at end of file + logger.debug(f"恢复动作集: 从 {actions_to_restore} 恢复到默认动作集 {list(self._using_actions.keys())}") diff --git a/src/chat/planner_actions/action_modifier.py b/src/chat/planner_actions/action_modifier.py index 024d7011..def8322a 100644 --- a/src/chat/planner_actions/action_modifier.py +++ b/src/chat/planner_actions/action_modifier.py @@ -103,25 +103,23 @@ class ActionModifier: self.action_manager.remove_action_from_using(action_name) logger.debug(f"{self.log_prefix}阶段二移除动作: {action_name},原因: {reason}") - - # === 第三阶段:激活类型判定 === # if chat_content is not None: - # logger.debug(f"{self.log_prefix}开始激活类型判定阶段") + # logger.debug(f"{self.log_prefix}开始激活类型判定阶段") - # 获取当前使用的动作集(经过第一阶段处理) - # current_using_actions = self.action_manager.get_using_actions() + # 获取当前使用的动作集(经过第一阶段处理) + # current_using_actions = self.action_manager.get_using_actions() - # 获取因激活类型判定而需要移除的动作 - # removals_s3 = await self._get_deactivated_actions_by_type( - # current_using_actions, - # chat_content, - # ) + # 获取因激活类型判定而需要移除的动作 + # removals_s3 = await self._get_deactivated_actions_by_type( + # current_using_actions, + # chat_content, + # ) - # 应用第三阶段的移除 - # for action_name, reason in removals_s3: - # self.action_manager.remove_action_from_using(action_name) - # logger.debug(f"{self.log_prefix}阶段三移除动作: {action_name},原因: {reason}") + # 应用第三阶段的移除 + # for action_name, reason in removals_s3: + # self.action_manager.remove_action_from_using(action_name) + # logger.debug(f"{self.log_prefix}阶段三移除动作: {action_name},原因: {reason}") # === 统一日志记录 === all_removals = removals_s1 + removals_s2 @@ -131,9 +129,7 @@ class ActionModifier: available_actions = list(self.action_manager.get_using_actions().keys()) available_actions_text = "、".join(available_actions) if available_actions else "无" - logger.debug( - f"{self.log_prefix} 当前可用动作: {available_actions_text}||移除: {removals_summary}" - ) + logger.debug(f"{self.log_prefix} 当前可用动作: {available_actions_text}||移除: {removals_summary}") def _check_action_associated_types(self, all_actions: Dict[str, ActionInfo], chat_context: ChatMessageContext): type_mismatched_actions: List[Tuple[str, str]] = [] diff --git a/src/chat/utils/statistic.py b/src/chat/utils/statistic.py index 1aaa9461..97ef1cc0 100644 --- a/src/chat/utils/statistic.py +++ b/src/chat/utils/statistic.py @@ -385,18 +385,18 @@ class StatisticOutputTask(AsyncTask): time_cost_key = f"time_costs_by_{category.split('_')[-1]}" avg_key = f"avg_time_costs_by_{category.split('_')[-1]}" std_key = f"std_time_costs_by_{category.split('_')[-1]}" - + for item_name in stats[period_key][category]: time_costs = stats[period_key][time_cost_key].get(item_name, []) if time_costs: # 计算平均耗时 avg_time_cost = sum(time_costs) / len(time_costs) stats[period_key][avg_key][item_name] = round(avg_time_cost, 3) - + # 计算标准差 if len(time_costs) > 1: variance = sum((x - avg_time_cost) ** 2 for x in time_costs) / len(time_costs) - std_time_cost = variance ** 0.5 + std_time_cost = variance**0.5 stats[period_key][std_key][item_name] = round(std_time_cost, 3) else: stats[period_key][std_key][item_name] = 0.0 @@ -506,8 +506,6 @@ class StatisticOutputTask(AsyncTask): break return stats - - def _collect_all_statistics(self, now: datetime) -> Dict[str, Dict[str, Any]]: """ 收集各时间段的统计数据 @@ -639,7 +637,9 @@ class StatisticOutputTask(AsyncTask): cost = stats[COST_BY_MODEL][model_name] avg_time_cost = stats[AVG_TIME_COST_BY_MODEL][model_name] std_time_cost = stats[STD_TIME_COST_BY_MODEL][model_name] - output.append(data_fmt.format(name, count, in_tokens, out_tokens, tokens, cost, avg_time_cost, std_time_cost)) + output.append( + data_fmt.format(name, count, in_tokens, out_tokens, tokens, cost, avg_time_cost, std_time_cost) + ) output.append("") return "\n".join(output) @@ -728,7 +728,9 @@ class StatisticOutputTask(AsyncTask): f"{stat_data[STD_TIME_COST_BY_MODEL][model_name]:.1f} 秒" f"" for model_name, count in sorted(stat_data[REQ_CNT_BY_MODEL].items()) - ] if stat_data[REQ_CNT_BY_MODEL] else ["暂无数据"] + ] + if stat_data[REQ_CNT_BY_MODEL] + else ["暂无数据"] ) # 按请求类型分类统计 type_rows = "\n".join( @@ -744,7 +746,9 @@ class StatisticOutputTask(AsyncTask): f"{stat_data[STD_TIME_COST_BY_TYPE][req_type]:.1f} 秒" f"" for req_type, count in sorted(stat_data[REQ_CNT_BY_TYPE].items()) - ] if stat_data[REQ_CNT_BY_TYPE] else ["暂无数据"] + ] + if stat_data[REQ_CNT_BY_TYPE] + else ["暂无数据"] ) # 按模块分类统计 module_rows = "\n".join( @@ -760,7 +764,9 @@ class StatisticOutputTask(AsyncTask): f"{stat_data[STD_TIME_COST_BY_MODULE][module_name]:.1f} 秒" f"" for module_name, count in sorted(stat_data[REQ_CNT_BY_MODULE].items()) - ] if stat_data[REQ_CNT_BY_MODULE] else ["暂无数据"] + ] + if stat_data[REQ_CNT_BY_MODULE] + else ["暂无数据"] ) # 聊天消息统计 @@ -768,7 +774,9 @@ class StatisticOutputTask(AsyncTask): [ f"{self.name_mapping[chat_id][0]}{count}" for chat_id, count in sorted(stat_data[MSG_CNT_BY_CHAT].items()) - ] if stat_data[MSG_CNT_BY_CHAT] else ["暂无数据"] + ] + if stat_data[MSG_CNT_BY_CHAT] + else ["暂无数据"] ) # 生成HTML return f""" diff --git a/src/chat/utils/utils.py b/src/chat/utils/utils.py index 549ee421..2fb24245 100644 --- a/src/chat/utils/utils.py +++ b/src/chat/utils/utils.py @@ -49,9 +49,9 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, bool, float reply_probability = 0.0 is_at = False is_mentioned = False - + # 这部分怎么处理啊啊啊啊 - #我觉得可以给消息加一个 reply_probability_boost字段 + # 我觉得可以给消息加一个 reply_probability_boost字段 if ( message.message_info.additional_config is not None and message.message_info.additional_config.get("is_mentioned") is not None @@ -826,20 +826,48 @@ def parse_keywords_string(keywords_input) -> list[str]: return [keywords_str] if keywords_str else [] - - def cut_key_words(concept_name: str) -> list[str]: """对概念名称进行jieba分词,并过滤掉关键词列表中的关键词""" concept_name_tokens = list(jieba.cut(concept_name)) # 定义常见连词、停用词与标点 - conjunctions = { - "和", "与", "及", "跟", "以及", "并且", "而且", "或", "或者", "并" - } + conjunctions = {"和", "与", "及", "跟", "以及", "并且", "而且", "或", "或者", "并"} stop_words = { - "的", "了", "呢", "吗", "吧", "啊", "哦", "恩", "嗯", "呀", "嘛", "哇", - "在", "是", "很", "也", "又", "就", "都", "还", "更", "最", "被", "把", - "给", "对", "和", "与", "及", "跟", "并", "而且", "或者", "或", "以及" + "的", + "了", + "呢", + "吗", + "吧", + "啊", + "哦", + "恩", + "嗯", + "呀", + "嘛", + "哇", + "在", + "是", + "很", + "也", + "又", + "就", + "都", + "还", + "更", + "最", + "被", + "把", + "给", + "对", + "和", + "与", + "及", + "跟", + "并", + "而且", + "或者", + "或", + "以及", } chinese_punctuations = set(",。!?、;:()【】《》“”‘’—…·-——,.!?;:()[]<>'\"/\\") @@ -864,11 +892,16 @@ def cut_key_words(concept_name: str) -> list[str]: left = merged_tokens[-1] right = cleaned_tokens[i + 1] # 左右都需要是有效词 - if left and right \ - and left not in conjunctions and right not in conjunctions \ - and left not in stop_words and right not in stop_words \ - and not all(ch in chinese_punctuations for ch in left) \ - and not all(ch in chinese_punctuations for ch in right): + if ( + left + and right + and left not in conjunctions + and right not in conjunctions + and left not in stop_words + and right not in stop_words + and not all(ch in chinese_punctuations for ch in left) + and not all(ch in chinese_punctuations for ch in right) + ): # 合并为一个新词,并替换掉左侧与跳过右侧 combined = f"{left}{tok}{right}" merged_tokens[-1] = combined @@ -889,7 +922,7 @@ def cut_key_words(concept_name: str) -> list[str]: if tok in stop_words: continue # if tok in ban_words: - # continue + # continue if all(ch in chinese_punctuations for ch in tok): continue if tok.strip() == "": @@ -899,4 +932,4 @@ def cut_key_words(concept_name: str) -> list[str]: result_tokens.append(tok) filtered_concept_name_tokens = result_tokens - return filtered_concept_name_tokens \ No newline at end of file + return filtered_concept_name_tokens diff --git a/src/chat/utils/utils_image.py b/src/chat/utils/utils_image.py index 3c9c51e9..2b1b5366 100644 --- a/src/chat/utils/utils_image.py +++ b/src/chat/utils/utils_image.py @@ -91,9 +91,10 @@ class ImageManager: desc_obj.save() except Exception as e: logger.error(f"保存描述到数据库失败 (Peewee): {str(e)}") - + async def get_emoji_tag(self, image_base64: str) -> str: from src.chat.emoji_system.emoji_manager import get_emoji_manager + emoji_manager = get_emoji_manager() if isinstance(image_base64, str): image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii") @@ -120,6 +121,7 @@ class ImageManager: # 优先使用EmojiManager查询已注册表情包的描述 try: from src.chat.emoji_system.emoji_manager import get_emoji_manager + emoji_manager = get_emoji_manager() tags = await emoji_manager.get_emoji_tag_by_hash(image_hash) if tags: diff --git a/src/common/data_models/__init__.py b/src/common/data_models/__init__.py index 222ff59c..d104eec9 100644 --- a/src/common/data_models/__init__.py +++ b/src/common/data_models/__init__.py @@ -6,6 +6,7 @@ class BaseDataModel: def deepcopy(self): return copy.deepcopy(self) + def temporarily_transform_class_to_dict(obj: Any) -> Any: # sourcery skip: assign-if-exp, reintroduce-else """ diff --git a/src/common/data_models/database_data_model.py b/src/common/data_models/database_data_model.py index bf4a5f52..18465b00 100644 --- a/src/common/data_models/database_data_model.py +++ b/src/common/data_models/database_data_model.py @@ -205,6 +205,7 @@ class DatabaseMessages(BaseDataModel): "chat_info_user_cardname": self.chat_info.user_info.user_cardname, } + @dataclass(init=False) class DatabaseActionRecords(BaseDataModel): def __init__( @@ -232,4 +233,4 @@ class DatabaseActionRecords(BaseDataModel): self.action_prompt_display = action_prompt_display self.chat_id = chat_id self.chat_info_stream_id = chat_info_stream_id - self.chat_info_platform = chat_info_platform \ No newline at end of file + self.chat_info_platform = chat_info_platform diff --git a/src/common/data_models/llm_data_model.py b/src/common/data_models/llm_data_model.py index 1d5b75e0..d862e9b5 100644 --- a/src/common/data_models/llm_data_model.py +++ b/src/common/data_models/llm_data_model.py @@ -2,9 +2,11 @@ from dataclasses import dataclass from typing import Optional, List, Tuple, TYPE_CHECKING, Any from . import BaseDataModel + if TYPE_CHECKING: from src.llm_models.payload_content.tool_option import ToolCall + @dataclass class LLMGenerationDataModel(BaseDataModel): content: Optional[str] = None @@ -13,4 +15,4 @@ class LLMGenerationDataModel(BaseDataModel): tool_calls: Optional[List["ToolCall"]] = None prompt: Optional[str] = None selected_expressions: Optional[List[int]] = None - reply_set: Optional[List[Tuple[str, Any]]] = None \ No newline at end of file + reply_set: Optional[List[Tuple[str, Any]]] = None diff --git a/src/common/database/database_model.py b/src/common/database/database_model.py index 14ce741d..72d5252e 100644 --- a/src/common/database/database_model.py +++ b/src/common/database/database_model.py @@ -135,7 +135,7 @@ class Messages(BaseModel): interest_value = DoubleField(null=True) key_words = TextField(null=True) key_words_lite = TextField(null=True) - + is_mentioned = BooleanField(null=True) is_at = BooleanField(null=True) reply_probability_boost = DoubleField(null=True) @@ -169,7 +169,7 @@ class Messages(BaseModel): is_picid = BooleanField(default=False) is_command = BooleanField(default=False) is_notify = BooleanField(default=False) - + selected_expressions = TextField(null=True) class Meta: @@ -267,13 +267,10 @@ class PersonInfo(BaseModel): know_times = FloatField(null=True) # 认识时间 (时间戳) know_since = FloatField(null=True) # 首次印象总结时间 last_know = FloatField(null=True) # 最后一次印象总结时间 - + attitude_to_me = TextField(null=True) # 对bot的态度 attitude_to_me_confidence = FloatField(null=True) # 对bot的态度置信度 - - - class Meta: # database = db # 继承自 BaseModel table_name = "person_info" @@ -299,6 +296,7 @@ class GroupInfo(BaseModel): # database = db # 继承自 BaseModel table_name = "group_info" + class Expression(BaseModel): """ 用于存储表达风格的模型。 @@ -315,6 +313,7 @@ class Expression(BaseModel): class Meta: table_name = "expression" + class GraphNodes(BaseModel): """ 用于存储记忆图节点的模型 @@ -374,7 +373,7 @@ def initialize_database(sync_constraints=False): """ 检查所有定义的表是否存在,如果不存在则创建它们。 检查所有表的所有字段是否存在,如果缺失则自动添加。 - + Args: sync_constraints (bool): 是否同步字段约束。默认为 False。 如果为 True,会检查并修复字段的 NULL 约束不一致问题。 @@ -456,13 +455,13 @@ def initialize_database(sync_constraints=False): logger.info(f"字段 '{field_name}' 删除成功") except Exception as e: logger.error(f"删除字段 '{field_name}' 失败: {e}") - + # 如果启用了约束同步,执行约束检查和修复 if sync_constraints: logger.debug("开始同步数据库字段约束...") sync_field_constraints() logger.debug("数据库字段约束同步完成") - + except Exception as e: logger.exception(f"检查表或字段是否存在时出错: {e}") # 如果检查失败(例如数据库不可用),则退出 @@ -476,7 +475,7 @@ def sync_field_constraints(): 同步数据库字段约束,确保现有数据库字段的 NULL 约束与模型定义一致。 如果发现不一致,会自动修复字段约束。 """ - + models = [ ChatStreams, LLMUsage, @@ -501,50 +500,55 @@ def sync_field_constraints(): continue logger.debug(f"检查表 '{table_name}' 的字段约束...") - + # 获取当前表结构信息 cursor = db.execute_sql(f"PRAGMA table_info('{table_name}')") - current_schema = {row[1]: {'type': row[2], 'notnull': bool(row[3]), 'default': row[4]} - for row in cursor.fetchall()} - + current_schema = { + row[1]: {"type": row[2], "notnull": bool(row[3]), "default": row[4]} for row in cursor.fetchall() + } + # 检查每个模型字段的约束 constraints_to_fix = [] for field_name, field_obj in model._meta.fields.items(): if field_name not in current_schema: continue # 字段不存在,跳过 - - current_notnull = current_schema[field_name]['notnull'] + + current_notnull = current_schema[field_name]["notnull"] model_allows_null = field_obj.null - + # 如果模型允许 null 但数据库字段不允许 null,需要修复 if model_allows_null and current_notnull: - constraints_to_fix.append({ - 'field_name': field_name, - 'field_obj': field_obj, - 'action': 'allow_null', - 'current_constraint': 'NOT NULL', - 'target_constraint': 'NULL' - }) + constraints_to_fix.append( + { + "field_name": field_name, + "field_obj": field_obj, + "action": "allow_null", + "current_constraint": "NOT NULL", + "target_constraint": "NULL", + } + ) logger.warning(f"字段 '{field_name}' 约束不一致: 模型允许NULL,但数据库为NOT NULL") - + # 如果模型不允许 null 但数据库字段允许 null,也需要修复(但要小心) elif not model_allows_null and not current_notnull: - constraints_to_fix.append({ - 'field_name': field_name, - 'field_obj': field_obj, - 'action': 'disallow_null', - 'current_constraint': 'NULL', - 'target_constraint': 'NOT NULL' - }) + constraints_to_fix.append( + { + "field_name": field_name, + "field_obj": field_obj, + "action": "disallow_null", + "current_constraint": "NULL", + "target_constraint": "NOT NULL", + } + ) logger.warning(f"字段 '{field_name}' 约束不一致: 模型不允许NULL,但数据库允许NULL") - + # 修复约束不一致的字段 if constraints_to_fix: logger.info(f"表 '{table_name}' 需要修复 {len(constraints_to_fix)} 个字段约束") _fix_table_constraints(table_name, model, constraints_to_fix) else: logger.debug(f"表 '{table_name}' 的字段约束已同步") - + except Exception as e: logger.exception(f"同步字段约束时出错: {e}") @@ -557,40 +561,39 @@ def _fix_table_constraints(table_name, model, constraints_to_fix): try: # 备份表名 backup_table = f"{table_name}_backup_{int(datetime.datetime.now().timestamp())}" - + logger.info(f"开始修复表 '{table_name}' 的字段约束...") - + # 1. 创建备份表 db.execute_sql(f"CREATE TABLE {backup_table} AS SELECT * FROM {table_name}") logger.info(f"已创建备份表 '{backup_table}'") - + # 2. 删除原表 db.execute_sql(f"DROP TABLE {table_name}") logger.info(f"已删除原表 '{table_name}'") - + # 3. 重新创建表(使用当前模型定义) db.create_tables([model]) logger.info(f"已重新创建表 '{table_name}' 使用新的约束") - + # 4. 从备份表恢复数据 # 获取字段列表 fields = list(model._meta.fields.keys()) - fields_str = ', '.join(fields) - + fields_str = ", ".join(fields) + # 对于需要从 NOT NULL 改为 NULL 的字段,直接复制数据 # 对于需要从 NULL 改为 NOT NULL 的字段,需要处理 NULL 值 insert_sql = f"INSERT INTO {table_name} ({fields_str}) SELECT {fields_str} FROM {backup_table}" - + # 检查是否有字段需要从 NULL 改为 NOT NULL null_to_notnull_fields = [ - constraint['field_name'] for constraint in constraints_to_fix - if constraint['action'] == 'disallow_null' + constraint["field_name"] for constraint in constraints_to_fix if constraint["action"] == "disallow_null" ] - + if null_to_notnull_fields: # 需要处理 NULL 值,为这些字段设置默认值 logger.warning(f"字段 {null_to_notnull_fields} 将从允许NULL改为不允许NULL,需要处理现有的NULL值") - + # 构建更复杂的 SELECT 语句来处理 NULL 值 select_fields = [] for field_name in fields: @@ -607,21 +610,21 @@ def _fix_table_constraints(table_name, model, constraints_to_fix): default_value = f"'{datetime.datetime.now()}'" else: default_value = "''" - + select_fields.append(f"COALESCE({field_name}, {default_value}) as {field_name}") else: select_fields.append(field_name) - - select_str = ', '.join(select_fields) + + select_str = ", ".join(select_fields) insert_sql = f"INSERT INTO {table_name} ({fields_str}) SELECT {select_str} FROM {backup_table}" - + db.execute_sql(insert_sql) logger.info(f"已从备份表恢复数据到 '{table_name}'") - + # 5. 验证数据完整性 original_count = db.execute_sql(f"SELECT COUNT(*) FROM {backup_table}").fetchone()[0] new_count = db.execute_sql(f"SELECT COUNT(*) FROM {table_name}").fetchone()[0] - + if original_count == new_count: logger.info(f"数据完整性验证通过: {original_count} 行数据") # 删除备份表 @@ -630,12 +633,14 @@ def _fix_table_constraints(table_name, model, constraints_to_fix): else: logger.error(f"数据完整性验证失败: 原始 {original_count} 行,新表 {new_count} 行") logger.error(f"备份表 '{backup_table}' 已保留,请手动检查") - + # 记录修复的约束 for constraint in constraints_to_fix: - logger.info(f"已修复字段 '{constraint['field_name']}': " - f"{constraint['current_constraint']} -> {constraint['target_constraint']}") - + logger.info( + f"已修复字段 '{constraint['field_name']}': " + f"{constraint['current_constraint']} -> {constraint['target_constraint']}" + ) + except Exception as e: logger.exception(f"修复表 '{table_name}' 约束时出错: {e}") # 尝试恢复 @@ -654,7 +659,7 @@ def check_field_constraints(): 检查但不修复字段约束,返回不一致的字段信息。 用于在修复前预览需要修复的内容。 """ - + models = [ ChatStreams, LLMUsage, @@ -669,9 +674,9 @@ def check_field_constraints(): GraphEdges, ActionRecords, ] - + inconsistencies = {} - + try: with db: for model in models: @@ -681,49 +686,49 @@ def check_field_constraints(): # 获取当前表结构信息 cursor = db.execute_sql(f"PRAGMA table_info('{table_name}')") - current_schema = {row[1]: {'type': row[2], 'notnull': bool(row[3]), 'default': row[4]} - for row in cursor.fetchall()} - + current_schema = { + row[1]: {"type": row[2], "notnull": bool(row[3]), "default": row[4]} for row in cursor.fetchall() + } + table_inconsistencies = [] - + # 检查每个模型字段的约束 for field_name, field_obj in model._meta.fields.items(): if field_name not in current_schema: continue - - current_notnull = current_schema[field_name]['notnull'] + + current_notnull = current_schema[field_name]["notnull"] model_allows_null = field_obj.null - + if model_allows_null and current_notnull: - table_inconsistencies.append({ - 'field_name': field_name, - 'issue': 'model_allows_null_but_db_not_null', - 'model_constraint': 'NULL', - 'db_constraint': 'NOT NULL', - 'recommended_action': 'allow_null' - }) + table_inconsistencies.append( + { + "field_name": field_name, + "issue": "model_allows_null_but_db_not_null", + "model_constraint": "NULL", + "db_constraint": "NOT NULL", + "recommended_action": "allow_null", + } + ) elif not model_allows_null and not current_notnull: - table_inconsistencies.append({ - 'field_name': field_name, - 'issue': 'model_not_null_but_db_allows_null', - 'model_constraint': 'NOT NULL', - 'db_constraint': 'NULL', - 'recommended_action': 'disallow_null' - }) - + table_inconsistencies.append( + { + "field_name": field_name, + "issue": "model_not_null_but_db_allows_null", + "model_constraint": "NOT NULL", + "db_constraint": "NULL", + "recommended_action": "disallow_null", + } + ) + if table_inconsistencies: inconsistencies[table_name] = table_inconsistencies - + except Exception as e: logger.exception(f"检查字段约束时出错: {e}") - - return inconsistencies + return inconsistencies # 模块加载时调用初始化函数 initialize_database(sync_constraints=True) - - - - diff --git a/src/common/logger.py b/src/common/logger.py index ab0fd849..f980064f 100644 --- a/src/common/logger.py +++ b/src/common/logger.py @@ -339,24 +339,18 @@ MODULE_COLORS = { # 67 :具体的颜色编号(0-255),这里是较暗的蓝色 "sender": "\033[38;5;24m", # 67号色,较暗的蓝色,适合不显眼的日志 "send_api": "\033[38;5;24m", # 208号色,橙色,适合突出显示 - # 生成 "replyer": "\033[38;5;208m", # 橙色 "llm_api": "\033[38;5;208m", # 橙色 - # 消息处理 "chat": "\033[38;5;82m", # 亮蓝色 "chat_image": "\033[38;5;68m", # 浅蓝色 - - #emoji + # emoji "emoji": "\033[38;5;214m", # 橙黄色,偏向橙色 "emoji_api": "\033[38;5;214m", # 橙黄色,偏向橙色 - # 核心模块 "main": "\033[1;97m", # 亮白色+粗体 (主程序) - "memory": "\033[38;5;34m", # 天蓝色 - "config": "\033[93m", # 亮黄色 "common": "\033[95m", # 亮紫色 "tools": "\033[96m", # 亮青色 @@ -367,9 +361,6 @@ MODULE_COLORS = { "llm_models": "\033[36m", # 青色 "remote": "\033[38;5;242m", # 深灰色,更不显眼 "planner": "\033[36m", - - - "relation": "\033[38;5;139m", # 柔和的紫色,不刺眼 # 聊天相关模块 "normal_chat": "\033[38;5;81m", # 亮蓝绿色 @@ -379,11 +370,9 @@ MODULE_COLORS = { "background_tasks": "\033[38;5;240m", # 灰色 "chat_message": "\033[38;5;45m", # 青色 "chat_stream": "\033[38;5;51m", # 亮青色 - "message_storage": "\033[38;5;33m", # 深蓝色 "expressor": "\033[38;5;166m", # 橙色 # 专注聊天模块 - "memory_activator": "\033[38;5;117m", # 天蓝色 # 插件系统 "plugins": "\033[31m", # 红色 @@ -412,7 +401,6 @@ MODULE_COLORS = { # 工具和实用模块 "prompt_build": "\033[38;5;105m", # 紫色 "chat_utils": "\033[38;5;111m", # 蓝色 - "maibot_statistic": "\033[38;5;129m", # 紫色 # 特殊功能插件 "mute_plugin": "\033[38;5;240m", # 灰色 @@ -447,10 +435,8 @@ MODULE_ALIASES = { "llm_api": "生成API", "emoji": "表情包", "emoji_api": "表情包API", - "chat": "所见", "chat_image": "识图", - "action_manager": "动作", "memory_activator": "记忆", "tool_use": "工具", @@ -460,7 +446,6 @@ MODULE_ALIASES = { "memory": "记忆", "tool_executor": "工具", "hfc": "聊天节奏", - "plugin_manager": "插件", "relationship_builder": "关系", "llm_models": "模型", diff --git a/src/config/config.py b/src/config/config.py index f3827855..a35ba7b7 100644 --- a/src/config/config.py +++ b/src/config/config.py @@ -114,7 +114,7 @@ def set_value_by_path(d, path, value): if k not in d or not isinstance(d[k], dict): d[k] = {} d = d[k] - + # 使用 tomlkit.item 来保持 TOML 格式 try: d[path[-1]] = tomlkit.item(value) @@ -253,7 +253,7 @@ def _update_config_generic(config_name: str, template_name: str): f"已自动将{config_name}配置 {'.'.join(path)} 的值从旧默认值 {old_default} 更新为新默认值 {new_default}" ) config_updated = True - + # 如果配置有更新,立即保存到文件 if config_updated: with open(old_config_path, "w", encoding="utf-8") as f: diff --git a/src/config/official_configs.py b/src/config/official_configs.py index 61eba986..3da87dee 100644 --- a/src/config/official_configs.py +++ b/src/config/official_configs.py @@ -43,10 +43,11 @@ class PersonalityConfig(ConfigBase): reply_style: str = "" """表达风格""" - + interest: str = "" """兴趣""" + @dataclass class RelationshipConfig(ConfigBase): """关系配置类""" @@ -61,31 +62,30 @@ class ChatConfig(ConfigBase): max_context_size: int = 18 """上下文长度""" - + interest_rate_mode: Literal["fast", "accurate"] = "fast" """兴趣值计算模式,fast为快速计算,accurate为精确计算""" mentioned_bot_reply: float = 1 """提及 bot 必然回复,1为100%回复,0为不额外增幅""" - + planner_size: float = 1.5 """副规划器大小,越小,麦麦的动作执行能力越精细,但是消耗更多token,调大可以缓解429类错误""" at_bot_inevitable_reply: float = 1 """@bot 必然回复,1为100%回复,0为不额外增幅""" - + talk_frequency: float = 0.5 """回复频率阈值""" # 合并后的时段频率配置 talk_frequency_adjust: list[list[str]] = field(default_factory=lambda: []) - focus_value: float = 0.5 """麦麦的专注思考能力,越低越容易专注,消耗token也越多""" - + focus_value_adjust: list[list[str]] = field(default_factory=lambda: []) - + """ 统一的活跃度和专注度配置 格式:[["platform:chat_id:type", "HH:MM,frequency", "HH:MM,frequency", ...], ...] @@ -110,7 +110,6 @@ class ChatConfig(ConfigBase): - talk_frequency_adjust 控制回复频率,数值越高回复越频繁 - focus_value_adjust 控制专注思考能力,数值越低越容易专注,消耗token也越多 """ - @dataclass @@ -123,6 +122,7 @@ class MessageReceiveConfig(ConfigBase): ban_msgs_regex: set[str] = field(default_factory=lambda: set()) """过滤正则表达式列表""" + @dataclass class ExpressionConfig(ConfigBase): """表达配置类""" diff --git a/src/llm_models/model_client/base_client.py b/src/llm_models/model_client/base_client.py index 807f6484..eb74b0df 100644 --- a/src/llm_models/model_client/base_client.py +++ b/src/llm_models/model_client/base_client.py @@ -174,7 +174,7 @@ class ClientRegistry: return client_class(api_provider) else: raise KeyError(f"'{api_provider.client_type}' 类型的 Client 未注册") - + # 正常的缓存逻辑 if api_provider.name not in self.client_instance_cache: if client_class := self.client_registry.get(api_provider.client_type): diff --git a/src/llm_models/model_client/openai_client.py b/src/llm_models/model_client/openai_client.py index 51bb692f..1287dbec 100644 --- a/src/llm_models/model_client/openai_client.py +++ b/src/llm_models/model_client/openai_client.py @@ -531,7 +531,7 @@ class OpenaiClient(BaseClient): # 添加详细的错误信息以便调试 logger.error(f"OpenAI API连接错误(嵌入模型): {str(e)}") logger.error(f"错误类型: {type(e)}") - if hasattr(e, '__cause__') and e.__cause__: + if hasattr(e, "__cause__") and e.__cause__: logger.error(f"底层错误: {str(e.__cause__)}") raise NetworkConnectionError() from e except APIStatusError as e: diff --git a/src/llm_models/payload_content/__init__.py b/src/llm_models/payload_content/__init__.py index 33e43c5e..f33921f6 100644 --- a/src/llm_models/payload_content/__init__.py +++ b/src/llm_models/payload_content/__init__.py @@ -1,3 +1,3 @@ from .tool_option import ToolCall -__all__ = ["ToolCall"] \ No newline at end of file +__all__ = ["ToolCall"] diff --git a/src/llm_models/payload_content/resp_format.py b/src/llm_models/payload_content/resp_format.py index ab2e2edf..e1baa374 100644 --- a/src/llm_models/payload_content/resp_format.py +++ b/src/llm_models/payload_content/resp_format.py @@ -48,8 +48,7 @@ def _json_schema_type_check(instance) -> str | None: elif not isinstance(instance["name"], str) or instance["name"].strip() == "": return "schema的'name'字段必须是非空字符串" if "description" in instance and ( - not isinstance(instance["description"], str) - or instance["description"].strip() == "" + not isinstance(instance["description"], str) or instance["description"].strip() == "" ): return "schema的'description'字段只能填入非空字符串" if "schema" not in instance: @@ -101,9 +100,7 @@ def _link_definitions(schema: dict[str, Any]) -> dict[str, Any]: # 如果当前Schema是列表,则遍历每个元素 for i in range(len(sub_schema)): if isinstance(sub_schema[i], dict): - sub_schema[i] = link_definitions_recursive( - f"{path}/{str(i)}", sub_schema[i], defs - ) + sub_schema[i] = link_definitions_recursive(f"{path}/{str(i)}", sub_schema[i], defs) else: # 否则为字典 if "$defs" in sub_schema: @@ -125,9 +122,7 @@ def _link_definitions(schema: dict[str, Any]) -> dict[str, Any]: for key, value in sub_schema.items(): if isinstance(value, (dict, list)): # 如果当前值是字典或列表,则递归调用 - sub_schema[key] = link_definitions_recursive( - f"{path}/{key}", value, defs - ) + sub_schema[key] = link_definitions_recursive(f"{path}/{key}", value, defs) return sub_schema @@ -163,9 +158,7 @@ class RespFormat: def _generate_schema_from_model(schema): json_schema = { "name": schema.__name__, - "schema": _remove_defs( - _link_definitions(_remove_title(schema.model_json_schema())) - ), + "schema": _remove_defs(_link_definitions(_remove_title(schema.model_json_schema()))), "strict": False, } if schema.__doc__: diff --git a/src/llm_models/utils.py b/src/llm_models/utils.py index cf047654..5c760252 100644 --- a/src/llm_models/utils.py +++ b/src/llm_models/utils.py @@ -155,7 +155,13 @@ class LLMUsageRecorder: logger.error(f"创建 LLMUsage 表失败: {str(e)}") def record_usage_to_database( - self, model_info: ModelInfo, model_usage: UsageRecord, user_id: str, request_type: str, endpoint: str, time_cost: float = 0.0 + self, + model_info: ModelInfo, + model_usage: UsageRecord, + user_id: str, + request_type: str, + endpoint: str, + time_cost: float = 0.0, ): input_cost = (model_usage.prompt_tokens / 1000000) * model_info.price_in output_cost = (model_usage.completion_tokens / 1000000) * model_info.price_out @@ -173,7 +179,7 @@ class LLMUsageRecorder: completion_tokens=model_usage.completion_tokens or 0, total_tokens=model_usage.total_tokens or 0, cost=total_cost or 0.0, - time_cost = round(time_cost or 0.0, 3), + time_cost=round(time_cost or 0.0, 3), status="success", timestamp=datetime.now(), # Peewee 会处理 DateTimeField ) @@ -186,4 +192,5 @@ class LLMUsageRecorder: except Exception as e: logger.error(f"记录token使用情况失败: {str(e)}") -llm_usage_recorder = LLMUsageRecorder() \ No newline at end of file + +llm_usage_recorder = LLMUsageRecorder() diff --git a/src/mais4u/mais4u_chat/context_web_manager.py b/src/mais4u/mais4u_chat/context_web_manager.py index 8c6cde2c..1e11f725 100644 --- a/src/mais4u/mais4u_chat/context_web_manager.py +++ b/src/mais4u/mais4u_chat/context_web_manager.py @@ -14,31 +14,31 @@ logger = get_logger("context_web") class ContextMessage: """上下文消息类""" - + def __init__(self, message: MessageRecv): self.user_name = message.message_info.user_info.user_nickname self.user_id = message.message_info.user_info.user_id self.content = message.processed_plain_text self.timestamp = datetime.now() self.group_name = message.message_info.group_info.group_name if message.message_info.group_info else "私聊" - + # 识别消息类型 - self.is_gift = getattr(message, 'is_gift', False) - self.is_superchat = getattr(message, 'is_superchat', False) - + self.is_gift = getattr(message, "is_gift", False) + self.is_superchat = getattr(message, "is_superchat", False) + # 添加礼物和SC相关信息 if self.is_gift: - self.gift_name = getattr(message, 'gift_name', '') - self.gift_count = getattr(message, 'gift_count', '1') + self.gift_name = getattr(message, "gift_name", "") + self.gift_count = getattr(message, "gift_count", "1") self.content = f"送出了 {self.gift_name} x{self.gift_count}" elif self.is_superchat: - self.superchat_price = getattr(message, 'superchat_price', '0') - self.superchat_message = getattr(message, 'superchat_message_text', '') + self.superchat_price = getattr(message, "superchat_price", "0") + self.superchat_message = getattr(message, "superchat_message_text", "") if self.superchat_message: self.content = f"[¥{self.superchat_price}] {self.superchat_message}" else: self.content = f"[¥{self.superchat_price}] {self.content}" - + def to_dict(self): return { "user_name": self.user_name, @@ -47,13 +47,13 @@ class ContextMessage: "timestamp": self.timestamp.strftime("%m-%d %H:%M:%S"), "group_name": self.group_name, "is_gift": self.is_gift, - "is_superchat": self.is_superchat + "is_superchat": self.is_superchat, } class ContextWebManager: """上下文网页管理器""" - + def __init__(self, max_messages: int = 10, port: int = 8765): self.max_messages = max_messages self.port = port @@ -63,53 +63,53 @@ class ContextWebManager: self.runner = None self.site = None self._server_starting = False # 添加启动标志防止并发 - + async def start_server(self): """启动web服务器""" if self.site is not None: logger.debug("Web服务器已经启动,跳过重复启动") return - + if self._server_starting: logger.debug("Web服务器正在启动中,等待启动完成...") # 等待启动完成 while self._server_starting and self.site is None: await asyncio.sleep(0.1) return - + self._server_starting = True - + try: self.app = web.Application() - + # 设置CORS - cors = aiohttp_cors.setup(self.app, defaults={ - "*": aiohttp_cors.ResourceOptions( - allow_credentials=True, - expose_headers="*", - allow_headers="*", - allow_methods="*" - ) - }) - + cors = aiohttp_cors.setup( + self.app, + defaults={ + "*": aiohttp_cors.ResourceOptions( + allow_credentials=True, expose_headers="*", allow_headers="*", allow_methods="*" + ) + }, + ) + # 添加路由 - self.app.router.add_get('/', self.index_handler) - self.app.router.add_get('/ws', self.websocket_handler) - self.app.router.add_get('/api/contexts', self.get_contexts_handler) - self.app.router.add_get('/debug', self.debug_handler) - + self.app.router.add_get("/", self.index_handler) + self.app.router.add_get("/ws", self.websocket_handler) + self.app.router.add_get("/api/contexts", self.get_contexts_handler) + self.app.router.add_get("/debug", self.debug_handler) + # 为所有路由添加CORS for route in list(self.app.router.routes()): cors.add(route) - + self.runner = web.AppRunner(self.app) await self.runner.setup() - - self.site = web.TCPSite(self.runner, 'localhost', self.port) + + self.site = web.TCPSite(self.runner, "localhost", self.port) await self.site.start() - + logger.info(f"🌐 上下文网页服务器启动成功在 http://localhost:{self.port}") - + except Exception as e: logger.error(f"❌ 启动Web服务器失败: {e}") # 清理部分启动的资源 @@ -121,7 +121,7 @@ class ContextWebManager: raise finally: self._server_starting = False - + async def stop_server(self): """停止web服务器""" if self.site: @@ -132,10 +132,11 @@ class ContextWebManager: self.runner = None self.site = None self._server_starting = False - + async def index_handler(self, request): """主页处理器""" - html_content = ''' + html_content = ( + """ @@ -286,7 +287,9 @@ class ContextWebManager: function connectWebSocket() { console.log('正在连接WebSocket...'); - ws = new WebSocket('ws://localhost:''' + str(self.port) + '''/ws'); + ws = new WebSocket('ws://localhost:""" + + str(self.port) + + """/ws'); ws.onopen = function() { console.log('WebSocket连接已建立'); @@ -470,47 +473,48 @@ class ContextWebManager: - ''' - return web.Response(text=html_content, content_type='text/html') - + """ + ) + return web.Response(text=html_content, content_type="text/html") + async def websocket_handler(self, request): """WebSocket处理器""" ws = web.WebSocketResponse() await ws.prepare(request) - + self.websockets.append(ws) logger.debug(f"WebSocket连接建立,当前连接数: {len(self.websockets)}") - + # 发送初始数据 await self.send_contexts_to_websocket(ws) - + async for msg in ws: if msg.type == WSMsgType.ERROR: - logger.error(f'WebSocket错误: {ws.exception()}') + logger.error(f"WebSocket错误: {ws.exception()}") break - + # 清理断开的连接 if ws in self.websockets: self.websockets.remove(ws) logger.debug(f"WebSocket连接断开,当前连接数: {len(self.websockets)}") - + return ws - + async def get_contexts_handler(self, request): """获取上下文API""" all_context_msgs = [] for _chat_id, contexts in self.contexts.items(): all_context_msgs.extend(list(contexts)) - + # 按时间排序,最新的在最后 all_context_msgs.sort(key=lambda x: x.timestamp) - + # 转换为字典格式 - contexts_data = [msg.to_dict() for msg in all_context_msgs[-self.max_messages:]] - + contexts_data = [msg.to_dict() for msg in all_context_msgs[-self.max_messages :]] + logger.debug(f"返回上下文数据,共 {len(contexts_data)} 条消息") return web.json_response({"contexts": contexts_data}) - + async def debug_handler(self, request): """调试信息处理器""" debug_info = { @@ -519,7 +523,7 @@ class ContextWebManager: "total_chats": len(self.contexts), "total_messages": sum(len(contexts) for contexts in self.contexts.values()), } - + # 构建聊天详情HTML chats_html = "" for chat_id, contexts in self.contexts.items(): @@ -528,15 +532,15 @@ class ContextWebManager: timestamp = msg.timestamp.strftime("%H:%M:%S") content = msg.content[:50] + "..." if len(msg.content) > 50 else msg.content messages_html += f'
[{timestamp}] {msg.user_name}: {content}
' - - chats_html += f''' + + chats_html += f"""

聊天 {chat_id} ({len(contexts)} 条消息)

{messages_html}
- ''' - - html_content = f''' + """ + + html_content = f""" @@ -578,74 +582,78 @@ class ContextWebManager: - ''' - - return web.Response(text=html_content, content_type='text/html') - + """ + + return web.Response(text=html_content, content_type="text/html") + async def add_message(self, chat_id: str, message: MessageRecv): """添加新消息到上下文""" if chat_id not in self.contexts: self.contexts[chat_id] = deque(maxlen=self.max_messages) logger.debug(f"为聊天 {chat_id} 创建新的上下文队列") - + context_msg = ContextMessage(message) self.contexts[chat_id].append(context_msg) - + # 统计当前总消息数 total_messages = sum(len(contexts) for contexts in self.contexts.values()) - - logger.info(f"✅ 添加消息到上下文 [总数: {total_messages}]: [{context_msg.group_name}] {context_msg.user_name}: {context_msg.content}") - + + logger.info( + f"✅ 添加消息到上下文 [总数: {total_messages}]: [{context_msg.group_name}] {context_msg.user_name}: {context_msg.content}" + ) + # 调试:打印当前所有消息 logger.info("📝 当前上下文中的所有消息:") for cid, contexts in self.contexts.items(): logger.info(f" 聊天 {cid}: {len(contexts)} 条消息") for i, msg in enumerate(contexts): - logger.info(f" {i+1}. [{msg.timestamp.strftime('%H:%M:%S')}] {msg.user_name}: {msg.content[:30]}...") - + logger.info( + f" {i + 1}. [{msg.timestamp.strftime('%H:%M:%S')}] {msg.user_name}: {msg.content[:30]}..." + ) + # 广播更新给所有WebSocket连接 await self.broadcast_contexts() - + async def send_contexts_to_websocket(self, ws: web.WebSocketResponse): """向单个WebSocket发送上下文数据""" all_context_msgs = [] for _chat_id, contexts in self.contexts.items(): all_context_msgs.extend(list(contexts)) - + # 按时间排序,最新的在最后 all_context_msgs.sort(key=lambda x: x.timestamp) - + # 转换为字典格式 - contexts_data = [msg.to_dict() for msg in all_context_msgs[-self.max_messages:]] - + contexts_data = [msg.to_dict() for msg in all_context_msgs[-self.max_messages :]] + data = {"contexts": contexts_data} await ws.send_str(json.dumps(data, ensure_ascii=False)) - + async def broadcast_contexts(self): """向所有WebSocket连接广播上下文更新""" if not self.websockets: logger.debug("没有WebSocket连接,跳过广播") return - + all_context_msgs = [] for _chat_id, contexts in self.contexts.items(): all_context_msgs.extend(list(contexts)) - + # 按时间排序,最新的在最后 all_context_msgs.sort(key=lambda x: x.timestamp) - + # 转换为字典格式 - contexts_data = [msg.to_dict() for msg in all_context_msgs[-self.max_messages:]] - + contexts_data = [msg.to_dict() for msg in all_context_msgs[-self.max_messages :]] + data = {"contexts": contexts_data} message = json.dumps(data, ensure_ascii=False) - + logger.info(f"广播 {len(contexts_data)} 条消息到 {len(self.websockets)} 个WebSocket连接") - + # 创建WebSocket列表的副本,避免在遍历时修改 websockets_copy = self.websockets.copy() removed_count = 0 - + for ws in websockets_copy: if ws.closed: if ws in self.websockets: @@ -660,7 +668,7 @@ class ContextWebManager: if ws in self.websockets: self.websockets.remove(ws) removed_count += 1 - + if removed_count > 0: logger.debug(f"清理了 {removed_count} 个断开的WebSocket连接") @@ -681,5 +689,4 @@ async def init_context_web_manager(): """初始化上下文网页管理器""" manager = get_context_web_manager() await manager.start_server() - return manager - + return manager diff --git a/src/mais4u/mais4u_chat/gift_manager.py b/src/mais4u/mais4u_chat/gift_manager.py index b75882dc..d489550c 100644 --- a/src/mais4u/mais4u_chat/gift_manager.py +++ b/src/mais4u/mais4u_chat/gift_manager.py @@ -11,6 +11,7 @@ logger = get_logger("gift_manager") @dataclass class PendingGift: """等待中的礼物消息""" + message: MessageRecvS4U total_count: int timer_task: asyncio.Task @@ -19,71 +20,68 @@ class PendingGift: class GiftManager: """礼物管理器,提供防抖功能""" - + def __init__(self): """初始化礼物管理器""" self.pending_gifts: Dict[Tuple[str, str], PendingGift] = {} self.debounce_timeout = 5.0 # 3秒防抖时间 - - async def handle_gift(self, message: MessageRecvS4U, callback: Optional[Callable[[MessageRecvS4U], None]] = None) -> bool: + + async def handle_gift( + self, message: MessageRecvS4U, callback: Optional[Callable[[MessageRecvS4U], None]] = None + ) -> bool: """处理礼物消息,返回是否应该立即处理 - + Args: message: 礼物消息 callback: 防抖完成后的回调函数 - + Returns: bool: False表示消息被暂存等待防抖,True表示应该立即处理 """ if not message.is_gift: return True - + # 构建礼物的唯一键:(发送人ID, 礼物名称) gift_key = (message.message_info.user_info.user_id, message.gift_name) - + # 如果已经有相同的礼物在等待中,则合并 if gift_key in self.pending_gifts: await self._merge_gift(gift_key, message) return False - + # 创建新的等待礼物 await self._create_pending_gift(gift_key, message, callback) return False - + async def _merge_gift(self, gift_key: Tuple[str, str], new_message: MessageRecvS4U) -> None: """合并礼物消息""" pending_gift = self.pending_gifts[gift_key] - + # 取消之前的定时器 if not pending_gift.timer_task.cancelled(): pending_gift.timer_task.cancel() - + # 累加礼物数量 try: new_count = int(new_message.gift_count) pending_gift.total_count += new_count - + # 更新消息为最新的(保留最新的消息,但累加数量) pending_gift.message = new_message pending_gift.message.gift_count = str(pending_gift.total_count) pending_gift.message.gift_info = f"{pending_gift.message.gift_name}:{pending_gift.total_count}" - + except ValueError: logger.warning(f"无法解析礼物数量: {new_message.gift_count}") # 如果无法解析数量,保持原有数量不变 - + # 重新创建定时器 - pending_gift.timer_task = asyncio.create_task( - self._gift_timeout(gift_key) - ) - + pending_gift.timer_task = asyncio.create_task(self._gift_timeout(gift_key)) + logger.debug(f"合并礼物: {gift_key}, 总数量: {pending_gift.total_count}") - + async def _create_pending_gift( - self, - gift_key: Tuple[str, str], - message: MessageRecvS4U, - callback: Optional[Callable[[MessageRecvS4U], None]] + self, gift_key: Tuple[str, str], message: MessageRecvS4U, callback: Optional[Callable[[MessageRecvS4U], None]] ) -> None: """创建新的等待礼物""" try: @@ -91,56 +89,51 @@ class GiftManager: except ValueError: initial_count = 1 logger.warning(f"无法解析礼物数量: {message.gift_count},默认设为1") - + # 创建定时器任务 timer_task = asyncio.create_task(self._gift_timeout(gift_key)) - + # 创建等待礼物对象 - pending_gift = PendingGift( - message=message, - total_count=initial_count, - timer_task=timer_task, - callback=callback - ) - + pending_gift = PendingGift(message=message, total_count=initial_count, timer_task=timer_task, callback=callback) + self.pending_gifts[gift_key] = pending_gift - + logger.debug(f"创建等待礼物: {gift_key}, 初始数量: {initial_count}") - + async def _gift_timeout(self, gift_key: Tuple[str, str]) -> None: """礼物防抖超时处理""" try: # 等待防抖时间 await asyncio.sleep(self.debounce_timeout) - + # 获取等待中的礼物 if gift_key not in self.pending_gifts: return - + pending_gift = self.pending_gifts.pop(gift_key) - + logger.info(f"礼物防抖完成: {gift_key}, 最终数量: {pending_gift.total_count}") - + message = pending_gift.message message.processed_plain_text = f"用户{message.message_info.user_info.user_nickname}送出了礼物{message.gift_name} x{pending_gift.total_count}" - + # 执行回调 if pending_gift.callback: try: pending_gift.callback(message) except Exception as e: logger.error(f"礼物回调执行失败: {e}", exc_info=True) - + except asyncio.CancelledError: # 定时器被取消,不需要处理 pass except Exception as e: logger.error(f"礼物防抖处理异常: {e}", exc_info=True) - + def get_pending_count(self) -> int: """获取当前等待中的礼物数量""" return len(self.pending_gifts) - + async def flush_all(self) -> None: """立即处理所有等待中的礼物""" for gift_key in list(self.pending_gifts.keys()): @@ -152,4 +145,3 @@ class GiftManager: # 创建全局礼物管理器实例 gift_manager = GiftManager() - \ No newline at end of file diff --git a/src/mais4u/mais4u_chat/internal_manager.py b/src/mais4u/mais4u_chat/internal_manager.py index 695b0772..4b3db326 100644 --- a/src/mais4u/mais4u_chat/internal_manager.py +++ b/src/mais4u/mais4u_chat/internal_manager.py @@ -1,14 +1,15 @@ class InternalManager: def __init__(self): self.now_internal_state = str() - - def set_internal_state(self,internal_state:str): + + def set_internal_state(self, internal_state: str): self.now_internal_state = internal_state - + def get_internal_state(self): return self.now_internal_state - + def get_internal_state_str(self): return f"你今天的直播内容是直播QQ水群,你正在一边回复弹幕,一边在QQ群聊天,你在QQ群聊天中产生的想法是:{self.now_internal_state}" -internal_manager = InternalManager() \ No newline at end of file + +internal_manager = InternalManager() diff --git a/src/mais4u/mais4u_chat/s4u_chat.py b/src/mais4u/mais4u_chat/s4u_chat.py index f98c6fdb..8d749697 100644 --- a/src/mais4u/mais4u_chat/s4u_chat.py +++ b/src/mais4u/mais4u_chat/s4u_chat.py @@ -16,7 +16,6 @@ import json from .s4u_mood_manager import mood_manager from src.mais4u.s4u_config import s4u_config from src.person_info.person_info import get_person_id -from .super_chat_manager import get_super_chat_manager from .yes_or_no import yes_or_no_head logger = get_logger("S4U_chat") @@ -33,15 +32,12 @@ class MessageSenderContainer: self._task: Optional[asyncio.Task] = None self._paused_event = asyncio.Event() self._paused_event.set() # 默认设置为非暂停状态 - - self.msg_id = "" - - self.last_msg_id = "" - - self.voice_done = "" - - + self.msg_id = "" + + self.last_msg_id = "" + + self.voice_done = "" async def add_message(self, chunk: str): """向队列中添加一个消息块。""" @@ -131,7 +127,7 @@ class MessageSenderContainer: reply_to=f"{self.original_message.message_info.user_info.platform}:{self.original_message.message_info.user_info.user_id}", ) await bot_message.process() - + await self.storage.store_message(bot_message, self.chat_stream) except Exception as e: @@ -198,12 +194,12 @@ class S4UChat: self.gpt = S4UStreamGenerator() self.gpt.chat_stream = self.chat_stream self.interest_dict: Dict[str, float] = {} # 用户兴趣分 - - self.internal_message :List[MessageRecvS4U] = [] - + + self.internal_message: List[MessageRecvS4U] = [] + self.msg_id = "" self.voice_done = "" - + logger.info(f"[{self.stream_name}] S4UChat with two-queue system initialized.") def _get_priority_info(self, message: MessageRecv) -> dict: @@ -226,7 +222,7 @@ class S4UChat: def _get_interest_score(self, user_id: str) -> float: """获取用户的兴趣分,默认为1.0""" return self.interest_dict.get(user_id, 1.0) - + def go_processing(self): if self.voice_done == self.last_msg_id: return True @@ -237,14 +233,14 @@ class S4UChat: 为消息计算基础优先级分数。分数越高,优先级越高。 """ score = 0.0 - + # 加上消息自带的优先级 score += priority_info.get("message_priority", 0.0) # 加上用户的固有兴趣分 score += self._get_interest_score(message.message_info.user_info.user_id) return score - + def decay_interest_score(self): for person_id, score in self.interest_dict.items(): if score > 0: @@ -252,15 +248,14 @@ class S4UChat: else: self.interest_dict[person_id] = 0 - async def add_message(self, message: MessageRecvS4U|MessageRecv) -> None: - + async def add_message(self, message: MessageRecvS4U | MessageRecv) -> None: self.decay_interest_score() - + """根据VIP状态和中断逻辑将消息放入相应队列。""" user_id = message.message_info.user_info.user_id platform = message.message_info.platform - person_id = get_person_id(platform, user_id) - + _person_id = get_person_id(platform, user_id) + # try: # is_gift = message.is_gift # is_superchat = message.is_superchat @@ -276,7 +271,7 @@ class S4UChat: # # 安全地增加兴趣分,如果person_id不存在则先初始化为1.0 # current_score = self.interest_dict.get(person_id, 1.0) # self.interest_dict[person_id] = current_score + 0.1 * float(message.superchat_price) - + # # 添加SuperChat到管理器 # super_chat_manager = get_super_chat_manager() # await super_chat_manager.add_superchat(message) @@ -284,16 +279,19 @@ class S4UChat: # await self.relationship_builder.build_relation(20) # except Exception: # traceback.print_exc() - + logger.info(f"[{self.stream_name}] 消息处理完毕,消息内容:{message.processed_plain_text}") - + priority_info = self._get_priority_info(message) is_vip = self._is_vip(priority_info) new_priority_score = self._calculate_base_priority_score(message, priority_info) should_interrupt = False - if (s4u_config.enable_message_interruption and - self._current_generation_task and not self._current_generation_task.done()): + if ( + s4u_config.enable_message_interruption + and self._current_generation_task + and not self._current_generation_task.done() + ): if self._current_message_being_replied: current_queue, current_priority, _, current_msg = self._current_message_being_replied @@ -344,39 +342,45 @@ class S4UChat: """清理普通队列中不在最近N条消息范围内的消息""" if not s4u_config.enable_old_message_cleanup or self._normal_queue.empty(): return - + # 计算阈值:保留最近 recent_message_keep_count 条消息 cutoff_counter = max(0, self._entry_counter - s4u_config.recent_message_keep_count) - + # 临时存储需要保留的消息 temp_messages = [] removed_count = 0 - + # 取出所有普通队列中的消息 while not self._normal_queue.empty(): try: item = self._normal_queue.get_nowait() neg_priority, entry_count, timestamp, message = item - + # 如果消息在最近N条消息范围内,保留它 - logger.info(f"检查消息:{message.processed_plain_text},entry_count:{entry_count} cutoff_counter:{cutoff_counter}") - + logger.info( + f"检查消息:{message.processed_plain_text},entry_count:{entry_count} cutoff_counter:{cutoff_counter}" + ) + if entry_count >= cutoff_counter: temp_messages.append(item) else: removed_count += 1 self._normal_queue.task_done() # 标记被移除的任务为完成 - + except asyncio.QueueEmpty: break - + # 将保留的消息重新放入队列 for item in temp_messages: self._normal_queue.put_nowait(item) - + if removed_count > 0: - logger.info(f"消息{message.processed_plain_text}超过{s4u_config.recent_message_keep_count}条,现在counter:{self._entry_counter}被移除") - logger.info(f"[{self.stream_name}] Cleaned up {removed_count} old normal messages outside recent {s4u_config.recent_message_keep_count} range.") + logger.info( + f"消息{message.processed_plain_text}超过{s4u_config.recent_message_keep_count}条,现在counter:{self._entry_counter}被移除" + ) + logger.info( + f"[{self.stream_name}] Cleaned up {removed_count} old normal messages outside recent {s4u_config.recent_message_keep_count} range." + ) async def _message_processor(self): """调度器:优先处理VIP队列,然后处理普通队列。""" @@ -385,7 +389,7 @@ class S4UChat: # 等待有新消息的信号,避免空转 await self._new_message_event.wait() self._new_message_event.clear() - + # 清理普通队列中的过旧消息 self._cleanup_old_normal_messages() @@ -396,7 +400,6 @@ class S4UChat: queue_name = "vip" # 其次处理普通队列 elif not self._normal_queue.empty(): - neg_priority, entry_count, timestamp, message = self._normal_queue.get_nowait() priority = -neg_priority # 检查普通消息是否超时 @@ -411,13 +414,15 @@ class S4UChat: if self.internal_message: message = self.internal_message[-1] self.internal_message = [] - + priority = 0 neg_priority = 0 entry_count = 0 queue_name = "internal" - logger.info(f"[{self.stream_name}] normal/vip 队列都空,触发 internal_message 回复: {getattr(message, 'processed_plain_text', str(message))[:20]}...") + logger.info( + f"[{self.stream_name}] normal/vip 队列都空,触发 internal_message 回复: {getattr(message, 'processed_plain_text', str(message))[:20]}..." + ) else: continue # 没有消息了,回去等事件 @@ -457,23 +462,21 @@ class S4UChat: except Exception as e: logger.error(f"[{self.stream_name}] Message processor main loop error: {e}", exc_info=True) await asyncio.sleep(1) - - + def get_processing_message_id(self): self.last_msg_id = self.msg_id self.msg_id = f"{time.time()}_{random.randint(1000, 9999)}" - async def _generate_and_send(self, message: MessageRecv): """为单个消息生成文本回复。整个过程可以被中断。""" self._is_replying = True total_chars_sent = 0 # 跟踪发送的总字符数 - + self.get_processing_message_id() - + # 视线管理:开始生成回复时切换视线状态 chat_watching = watching_manager.get_watching_by_chat_id(self.stream_id) - + if message.is_internal: await chat_watching.on_internal_message_start() else: @@ -516,16 +519,19 @@ class S4UChat: total_chars_sent = len("麦麦不知道哦") mood = mood_manager.get_mood_by_chat_id(self.stream_id) - await yes_or_no_head(text = total_chars_sent,emotion = mood.mood_state,chat_history=message.processed_plain_text,chat_id=self.stream_id) + await yes_or_no_head( + text=total_chars_sent, + emotion=mood.mood_state, + chat_history=message.processed_plain_text, + chat_id=self.stream_id, + ) # 等待所有文本消息发送完成 await sender_container.close() await sender_container.join() - + await chat_watching.on_thinking_finished() - - - + start_time = time.time() logged = False while not self.go_processing(): @@ -536,7 +542,7 @@ class S4UChat: logger.info(f"[{self.stream_name}] 等待消息发送完成...") logged = True await asyncio.sleep(0.2) - + logger.info(f"[{self.stream_name}] 所有文本块处理完毕。") except asyncio.CancelledError: @@ -548,11 +554,11 @@ class S4UChat: # 回复生成实时展示:清空内容(出错时) finally: self._is_replying = False - + # 视线管理:回复结束时切换视线状态 chat_watching = watching_manager.get_watching_by_chat_id(self.stream_id) await chat_watching.on_reply_finished() - + # 确保发送器被妥善关闭(即使已关闭,再次调用也是安全的) sender_container.resume() if not sender_container._task.done(): @@ -576,4 +582,3 @@ class S4UChat: await self._processing_task except asyncio.CancelledError: logger.info(f"处理任务已成功取消: {self.stream_name}") - diff --git a/src/mais4u/mais4u_chat/s4u_msg_processor.py b/src/mais4u/mais4u_chat/s4u_msg_processor.py index 315d0500..4263194b 100644 --- a/src/mais4u/mais4u_chat/s4u_msg_processor.py +++ b/src/mais4u/mais4u_chat/s4u_msg_processor.py @@ -40,7 +40,7 @@ async def _calculate_interest(message: MessageRecv) -> Tuple[float, bool]: if global_config.memory.enable_memory: with Timer("记忆激活"): - interested_rate,_ ,_= await hippocampus_manager.get_activate_from_text( + interested_rate, _, _ = await hippocampus_manager.get_activate_from_text( message.processed_plain_text, fast_retrieval=True, ) @@ -49,7 +49,7 @@ async def _calculate_interest(message: MessageRecv) -> Tuple[float, bool]: text_len = len(message.processed_plain_text) # 根据文本长度分布调整兴趣度,采用分段函数实现更精确的兴趣度计算 # 基于实际分布:0-5字符(26.57%), 6-10字符(27.18%), 11-20字符(22.76%), 21-30字符(10.33%), 31+字符(13.86%) - + if text_len == 0: base_interest = 0.01 # 空消息最低兴趣度 elif text_len <= 5: @@ -73,7 +73,7 @@ async def _calculate_interest(message: MessageRecv) -> Tuple[float, bool]: else: # 100+字符:对数增长 0.26 -> 0.3,增长率递减 base_interest = 0.26 + (0.3 - 0.26) * (math.log10(text_len - 99) / math.log10(901)) # 1000-99=901 - + # 确保在范围内 base_interest = min(max(base_interest, 0.01), 0.3) @@ -117,36 +117,32 @@ class S4UMessageProcessor: user_info=userinfo, group_info=groupinfo, ) - + if await self.handle_internal_message(message): return - + if await self.hadle_if_voice_done(message): return - + # 处理礼物消息,如果消息被暂存则停止当前处理流程 if not skip_gift_debounce and not await self.handle_if_gift(message): return await self.check_if_fake_gift(message) - + # 处理屏幕消息 if await self.handle_screen_message(message): return - await self.storage.store_message(message, chat) s4u_chat = get_s4u_chat_manager().get_or_create_chat(chat) - await s4u_chat.add_message(message) _interested_rate, _ = await _calculate_interest(message) - + await mood_manager.start() - - # 一系列llm驱动的前处理 chat_mood = mood_manager.get_mood_by_chat_id(chat.stream_id) asyncio.create_task(chat_mood.update_mood_by_message(message)) @@ -164,61 +160,56 @@ class S4UMessageProcessor: logger.info(f"[S4U-礼物] {userinfo.user_nickname} 送出了 {message.gift_name} x{message.gift_count}") else: logger.info(f"[S4U]{userinfo.user_nickname}:{message.processed_plain_text}") - + async def handle_internal_message(self, message: MessageRecvS4U): if message.is_internal: - - group_info = GroupInfo(platform = "amaidesu_default",group_id = 660154,group_name = "内心") - - chat = await get_chat_manager().get_or_create_stream( - platform = "amaidesu_default", - user_info = message.message_info.user_info, - group_info = group_info + group_info = GroupInfo(platform="amaidesu_default", group_id=660154, group_name="内心") + + chat = await get_chat_manager().get_or_create_stream( + platform="amaidesu_default", user_info=message.message_info.user_info, group_info=group_info ) s4u_chat = get_s4u_chat_manager().get_or_create_chat(chat) message.message_info.group_info = s4u_chat.chat_stream.group_info message.message_info.platform = s4u_chat.chat_stream.platform - - + s4u_chat.internal_message.append(message) s4u_chat._new_message_event.set() - - - logger.info(f"[{s4u_chat.stream_name}] 添加内部消息-------------------------------------------------------: {message.processed_plain_text}") - - + + logger.info( + f"[{s4u_chat.stream_name}] 添加内部消息-------------------------------------------------------: {message.processed_plain_text}" + ) + return True return False - - + async def handle_screen_message(self, message: MessageRecvS4U): if message.is_screen: screen_manager.set_screen(message.screen_info) return True return False - + async def hadle_if_voice_done(self, message: MessageRecvS4U): if message.voice_done: s4u_chat = get_s4u_chat_manager().get_or_create_chat(message.chat_stream) s4u_chat.voice_done = message.voice_done return True return False - + async def check_if_fake_gift(self, message: MessageRecvS4U) -> bool: """检查消息是否为假礼物""" if message.is_gift: return False - - gift_keywords = ["送出了礼物", "礼物", "送出了","投喂"] + + gift_keywords = ["送出了礼物", "礼物", "送出了", "投喂"] if any(keyword in message.processed_plain_text for keyword in gift_keywords): message.is_fake_gift = True return True return False - + async def handle_if_gift(self, message: MessageRecvS4U) -> bool: """处理礼物消息 - + Returns: bool: True表示应该继续处理消息,False表示消息已被暂存不需要继续处理 """ @@ -228,37 +219,37 @@ class S4UMessageProcessor: """礼物防抖完成后的回调""" # 创建异步任务来处理合并后的礼物消息,跳过防抖处理 asyncio.create_task(self.process_message(merged_message, skip_gift_debounce=True)) - + # 交给礼物管理器处理,并传入回调函数 # 对于礼物消息,handle_gift 总是返回 False(消息被暂存) await gift_manager.handle_gift(message, gift_callback) return False # 消息被暂存,不继续处理 - + return True # 非礼物消息,继续正常处理 async def _handle_context_web_update(self, chat_id: str, message: MessageRecv): """处理上下文网页更新的独立task - + Args: chat_id: 聊天ID message: 消息对象 """ try: logger.debug(f"🔄 开始处理上下文网页更新: {message.message_info.user_info.user_nickname}") - + context_manager = get_context_web_manager() - + # 只在服务器未启动时启动(避免重复启动) if context_manager.site is None: logger.info("🚀 首次启动上下文网页服务器...") await context_manager.start_server() - + # 添加消息到上下文并更新网页 await asyncio.sleep(1.5) - + await context_manager.add_message(chat_id, message) - + logger.debug(f"✅ 上下文网页更新完成: {message.message_info.user_info.user_nickname}") - + except Exception as e: logger.error(f"❌ 处理上下文网页更新失败: {e}", exc_info=True) diff --git a/src/mais4u/mais4u_chat/s4u_prompt.py b/src/mais4u/mais4u_chat/s4u_prompt.py index 86447e27..15e4d729 100644 --- a/src/mais4u/mais4u_chat/s4u_prompt.py +++ b/src/mais4u/mais4u_chat/s4u_prompt.py @@ -176,7 +176,7 @@ class PromptBuilder: message_list_before_now = get_raw_msg_before_timestamp_with_chat( chat_id=chat_stream.stream_id, timestamp=time.time(), - # sourcery skip: lift-duplicated-conditional, merge-duplicate-blocks, remove-redundant-if + # sourcery skip: lift-duplicated-conditional, merge-duplicate-blocks, remove-redundant-if limit=300, ) @@ -228,13 +228,17 @@ class PromptBuilder: last_speaking_user_id = start_speaking_user_id msg_seg_str = "对方的发言:\n" - msg_seg_str += f"{time.strftime('%H:%M:%S', time.localtime(first_msg.time))}: {first_msg.processed_plain_text}\n" + msg_seg_str += ( + f"{time.strftime('%H:%M:%S', time.localtime(first_msg.time))}: {first_msg.processed_plain_text}\n" + ) all_msg_seg_list = [] for msg in core_dialogue_list[1:]: speaker = msg.user_info.user_id if speaker == last_speaking_user_id: - msg_seg_str += f"{time.strftime('%H:%M:%S', time.localtime(msg.time))}: {msg.processed_plain_text}\n" + msg_seg_str += ( + f"{time.strftime('%H:%M:%S', time.localtime(msg.time))}: {msg.processed_plain_text}\n" + ) else: msg_seg_str = f"{msg_seg_str}\n" all_msg_seg_list.append(msg_seg_str) diff --git a/src/mais4u/mais4u_chat/s4u_stream_generator.py b/src/mais4u/mais4u_chat/s4u_stream_generator.py index 607470cd..3d7db3f3 100644 --- a/src/mais4u/mais4u_chat/s4u_stream_generator.py +++ b/src/mais4u/mais4u_chat/s4u_stream_generator.py @@ -14,11 +14,8 @@ logger = get_logger("s4u_stream_generator") class S4UStreamGenerator: def __init__(self): # 使用LLMRequest替代AsyncOpenAIClient - self.llm_request = LLMRequest( - model_set=model_config.model_task_config.replyer, - request_type="s4u_replyer" - ) - + self.llm_request = LLMRequest(model_set=model_config.model_task_config.replyer, request_type="s4u_replyer") + self.current_model_name = "unknown model" self.partial_response = "" @@ -89,16 +86,16 @@ class S4UStreamGenerator: async def _generate_response_with_llm_request(self, prompt: str) -> AsyncGenerator[str, None]: """使用LLMRequest进行流式响应生成""" - + # 构建消息 message_builder = MessageBuilder() message_builder.add_text_content(prompt) messages = [message_builder.build()] - + # 选择模型 model_info, api_provider, client = self.llm_request._select_model() self.current_model_name = model_info.name - + # 如果模型支持强制流式模式,使用真正的流式处理 if model_info.force_stream_mode: # 简化流式处理:直接使用LLMRequest的流式功能 @@ -111,14 +108,14 @@ class S4UStreamGenerator: model_info=model_info, message_list=messages, ) - + # 处理响应内容 content = response.content or "" if content: # 将内容按句子分割并输出 async for chunk in self._process_content_streaming(content): yield chunk - + except Exception as e: logger.error(f"流式请求执行失败: {e}") # 如果流式请求失败,回退到普通模式 @@ -132,7 +129,7 @@ class S4UStreamGenerator: content = response.content or "" async for chunk in self._process_content_streaming(content): yield chunk - + else: # 如果不支持流式,使用普通方式然后模拟流式输出 response = await self.llm_request._execute_request( @@ -142,7 +139,7 @@ class S4UStreamGenerator: model_info=model_info, message_list=messages, ) - + content = response.content or "" async for chunk in self._process_content_streaming(content): yield chunk @@ -163,7 +160,7 @@ class S4UStreamGenerator: """处理内容进行流式输出(用于非流式模型的模拟流式输出)""" buffer = content punctuation_buffer = "" - + # 使用正则表达式匹配句子 last_match_end = 0 for match in self.sentence_split_pattern.finditer(buffer): diff --git a/src/mais4u/mais4u_chat/s4u_watching_manager.py b/src/mais4u/mais4u_chat/s4u_watching_manager.py index 62ef6d86..f079501c 100644 --- a/src/mais4u/mais4u_chat/s4u_watching_manager.py +++ b/src/mais4u/mais4u_chat/s4u_watching_manager.py @@ -1,4 +1,3 @@ - from src.common.logger import get_logger from src.plugin_system.apis import send_api @@ -47,6 +46,7 @@ HEAD_CODE = { "看向正前方": "(0,0,0)", } + class ChatWatching: def __init__(self, chat_id: str): self.chat_id: str = chat_id @@ -56,13 +56,13 @@ class ChatWatching: await send_api.custom_to_stream( message_type="state", content="start_thinking", stream_id=self.chat_id, storage_message=False ) - + async def on_reply_finished(self): """生成回复完毕时调用""" await send_api.custom_to_stream( message_type="state", content="finish_reply", stream_id=self.chat_id, storage_message=False ) - + async def on_thinking_finished(self): """思考完毕时调用""" await send_api.custom_to_stream( @@ -74,14 +74,14 @@ class ChatWatching: await send_api.custom_to_stream( message_type="state", content="start_viewing", stream_id=self.chat_id, storage_message=False ) - - + async def on_internal_message_start(self): """收到消息时调用""" await send_api.custom_to_stream( message_type="state", content="start_internal_thinking", stream_id=self.chat_id, storage_message=False ) + class WatchingManager: def __init__(self): self.watching_list: list[ChatWatching] = [] @@ -100,6 +100,7 @@ class WatchingManager: return new_watching + # 全局视线管理器实例 watching_manager = WatchingManager() """全局视线管理器""" diff --git a/src/mais4u/mais4u_chat/screen_manager.py b/src/mais4u/mais4u_chat/screen_manager.py index 63ed06c2..996e6399 100644 --- a/src/mais4u/mais4u_chat/screen_manager.py +++ b/src/mais4u/mais4u_chat/screen_manager.py @@ -1,14 +1,15 @@ class ScreenManager: def __init__(self): self.now_screen = str() - - def set_screen(self,screen_str:str): + + def set_screen(self, screen_str: str): self.now_screen = screen_str - + def get_screen(self): return self.now_screen - + def get_screen_str(self): return f"你可以看见面前的屏幕,目前屏幕的内容是:现在千石可乐在和你一起直播,这是他正在操作的屏幕内容:{self.now_screen}" -screen_manager = ScreenManager() \ No newline at end of file + +screen_manager = ScreenManager() diff --git a/src/mais4u/mais4u_chat/super_chat_manager.py b/src/mais4u/mais4u_chat/super_chat_manager.py index 0fd9b231..ef86a6ba 100644 --- a/src/mais4u/mais4u_chat/super_chat_manager.py +++ b/src/mais4u/mais4u_chat/super_chat_manager.py @@ -4,6 +4,7 @@ from dataclasses import dataclass from typing import Dict, List, Optional from src.common.logger import get_logger from src.chat.message_receive.message import MessageRecvS4U + # 全局SuperChat管理器实例 from src.mais4u.s4u_config import s4u_config @@ -13,7 +14,7 @@ logger = get_logger("super_chat_manager") @dataclass class SuperChatRecord: """SuperChat记录数据类""" - + user_id: str user_nickname: str platform: str @@ -23,15 +24,15 @@ class SuperChatRecord: timestamp: float expire_time: float group_name: Optional[str] = None - + def is_expired(self) -> bool: """检查SuperChat是否已过期""" return time.time() > self.expire_time - + def remaining_time(self) -> float: """获取剩余时间(秒)""" return max(0, self.expire_time - time.time()) - + def to_dict(self) -> dict: """转换为字典格式""" return { @@ -44,19 +45,19 @@ class SuperChatRecord: "timestamp": self.timestamp, "expire_time": self.expire_time, "group_name": self.group_name, - "remaining_time": self.remaining_time() + "remaining_time": self.remaining_time(), } class SuperChatManager: """SuperChat管理器,负责管理和跟踪SuperChat消息""" - + def __init__(self): self.super_chats: Dict[str, List[SuperChatRecord]] = {} # chat_id -> SuperChat列表 self._cleanup_task: Optional[asyncio.Task] = None self._is_initialized = False logger.info("SuperChat管理器已初始化") - + def _ensure_cleanup_task_started(self): """确保清理任务已启动(延迟启动)""" if self._cleanup_task is None or self._cleanup_task.done(): @@ -68,7 +69,7 @@ class SuperChatManager: except RuntimeError: # 没有运行的事件循环,稍后再启动 logger.debug("当前没有运行的事件循环,将在需要时启动清理任务") - + def _start_cleanup_task(self): """启动清理任务(已弃用,保留向后兼容)""" self._ensure_cleanup_task_started() @@ -78,39 +79,36 @@ class SuperChatManager: while True: try: total_removed = 0 - + for chat_id in list(self.super_chats.keys()): original_count = len(self.super_chats[chat_id]) # 移除过期的SuperChat - self.super_chats[chat_id] = [ - sc for sc in self.super_chats[chat_id] - if not sc.is_expired() - ] - + self.super_chats[chat_id] = [sc for sc in self.super_chats[chat_id] if not sc.is_expired()] + removed_count = original_count - len(self.super_chats[chat_id]) total_removed += removed_count - + if removed_count > 0: logger.info(f"从聊天 {chat_id} 中清理了 {removed_count} 个过期的SuperChat") - + # 如果列表为空,删除该聊天的记录 if not self.super_chats[chat_id]: del self.super_chats[chat_id] - + if total_removed > 0: logger.info(f"总共清理了 {total_removed} 个过期的SuperChat") - + # 每30秒检查一次 await asyncio.sleep(30) - + except Exception as e: logger.error(f"清理过期SuperChat时出错: {e}", exc_info=True) await asyncio.sleep(60) # 出错时等待更长时间 - + def _calculate_expire_time(self, price: float) -> float: """根据SuperChat金额计算过期时间""" current_time = time.time() - + # 根据金额阶梯设置不同的存活时间 if price >= 500: # 500元以上:保持4小时 @@ -133,27 +131,27 @@ class SuperChatManager: else: # 10元以下:保持5分钟 duration = 5 * 60 - + return current_time + duration - + async def add_superchat(self, message: MessageRecvS4U) -> None: """添加新的SuperChat记录""" # 确保清理任务已启动 self._ensure_cleanup_task_started() - + if not message.is_superchat or not message.superchat_price: logger.warning("尝试添加非SuperChat消息到SuperChat管理器") return - + try: price = float(message.superchat_price) except (ValueError, TypeError): logger.error(f"无效的SuperChat价格: {message.superchat_price}") return - + user_info = message.message_info.user_info group_info = message.message_info.group_info - chat_id = getattr(message, 'chat_stream', None) + chat_id = getattr(message, "chat_stream", None) if chat_id: chat_id = chat_id.stream_id else: @@ -161,9 +159,9 @@ class SuperChatManager: chat_id = f"{message.message_info.platform}_{user_info.user_id}" if group_info: chat_id = f"{message.message_info.platform}_{group_info.group_id}" - + expire_time = self._calculate_expire_time(price) - + record = SuperChatRecord( user_id=user_info.user_id, user_nickname=user_info.user_nickname, @@ -173,44 +171,44 @@ class SuperChatManager: message_text=message.superchat_message_text or "", timestamp=message.message_info.time, expire_time=expire_time, - group_name=group_info.group_name if group_info else None + group_name=group_info.group_name if group_info else None, ) - + # 添加到对应聊天的SuperChat列表 if chat_id not in self.super_chats: self.super_chats[chat_id] = [] - + self.super_chats[chat_id].append(record) - + # 按价格降序排序(价格高的在前) self.super_chats[chat_id].sort(key=lambda x: x.price, reverse=True) - + logger.info(f"添加SuperChat记录: {user_info.user_nickname} - {price}元 - {message.superchat_message_text}") - + def get_superchats_by_chat(self, chat_id: str) -> List[SuperChatRecord]: """获取指定聊天的所有有效SuperChat""" # 确保清理任务已启动 self._ensure_cleanup_task_started() - + if chat_id not in self.super_chats: return [] - + # 过滤掉过期的SuperChat valid_superchats = [sc for sc in self.super_chats[chat_id] if not sc.is_expired()] return valid_superchats - + def get_all_valid_superchats(self) -> Dict[str, List[SuperChatRecord]]: """获取所有有效的SuperChat""" # 确保清理任务已启动 self._ensure_cleanup_task_started() - + result = {} for chat_id, superchats in self.super_chats.items(): valid_superchats = [sc for sc in superchats if not sc.is_expired()] if valid_superchats: result[chat_id] = valid_superchats return result - + def build_superchat_display_string(self, chat_id: str, max_count: int = 10) -> str: """构建SuperChat显示字符串""" superchats = self.get_superchats_by_chat(chat_id) @@ -226,7 +224,9 @@ class SuperChatManager: remaining_minutes = int(sc.remaining_time() / 60) remaining_seconds = int(sc.remaining_time() % 60) - time_display = f"{remaining_minutes}分{remaining_seconds}秒" if remaining_minutes > 0 else f"{remaining_seconds}秒" + time_display = ( + f"{remaining_minutes}分{remaining_seconds}秒" if remaining_minutes > 0 else f"{remaining_seconds}秒" + ) line = f"{i}. 【{sc.price}元】{sc.user_nickname}: {sc.message_text}" if len(line) > 100: # 限制单行长度 @@ -238,7 +238,7 @@ class SuperChatManager: lines.append(f"... 还有{len(superchats) - max_count}条SuperChat") return "\n".join(lines) - + def build_superchat_summary_string(self, chat_id: str) -> str: """构建SuperChat摘要字符串""" superchats = self.get_superchats_by_chat(chat_id) @@ -261,30 +261,24 @@ class SuperChatManager: if lines: final_str += "\n" + "\n".join(lines) return final_str - + def get_superchat_statistics(self, chat_id: str) -> dict: """获取SuperChat统计信息""" superchats = self.get_superchats_by_chat(chat_id) - + if not superchats: - return { - "count": 0, - "total_amount": 0, - "average_amount": 0, - "highest_amount": 0, - "lowest_amount": 0 - } - + return {"count": 0, "total_amount": 0, "average_amount": 0, "highest_amount": 0, "lowest_amount": 0} + amounts = [sc.price for sc in superchats] - + return { "count": len(superchats), "total_amount": sum(amounts), "average_amount": sum(amounts) / len(amounts), "highest_amount": max(amounts), - "lowest_amount": min(amounts) + "lowest_amount": min(amounts), } - + async def shutdown(self): # sourcery skip: use-contextlib-suppress """关闭管理器,清理资源""" if self._cleanup_task and not self._cleanup_task.done(): @@ -296,15 +290,14 @@ class SuperChatManager: logger.info("SuperChat管理器已关闭") - - # sourcery skip: assign-if-exp if s4u_config.enable_s4u: super_chat_manager = SuperChatManager() else: super_chat_manager = None + def get_super_chat_manager() -> SuperChatManager: """获取全局SuperChat管理器实例""" - return super_chat_manager \ No newline at end of file + return super_chat_manager diff --git a/src/mais4u/s4u_config.py b/src/mais4u/s4u_config.py index f6a153c5..cbb686a4 100644 --- a/src/mais4u/s4u_config.py +++ b/src/mais4u/s4u_config.py @@ -10,10 +10,12 @@ from src.common.logger import get_logger logger = get_logger("s4u_config") + # 新增:兼容dict和tomlkit Table def is_dict_like(obj): return isinstance(obj, (dict, Table)) + # 新增:递归将Table转为dict def table_to_dict(obj): if isinstance(obj, Table): @@ -25,6 +27,7 @@ def table_to_dict(obj): else: return obj + # 获取mais4u模块目录 MAIS4U_ROOT = os.path.dirname(__file__) CONFIG_DIR = os.path.join(MAIS4U_ROOT, "config") @@ -190,7 +193,7 @@ class S4UModelConfig(S4UConfigBase): @dataclass class S4UConfig(S4UConfigBase): """S4U聊天系统配置类""" - + enable_s4u: bool = False """是否启用S4U聊天系统""" @@ -229,12 +232,12 @@ class S4UConfig(S4UConfigBase): enable_streaming_output: bool = True """是否启用流式输出,false时全部生成后一次性发送""" - + max_context_message_length: int = 20 """上下文消息最大长度""" - + max_core_message_length: int = 30 - """核心消息最大长度""" + """核心消息最大长度""" # 模型配置 models: S4UModelConfig = field(default_factory=S4UModelConfig) @@ -243,7 +246,6 @@ class S4UConfig(S4UConfigBase): # 兼容性字段,保持向后兼容 - @dataclass class S4UGlobalConfig(S4UConfigBase): """S4U总配置类""" @@ -256,7 +258,7 @@ def update_s4u_config(): """更新S4U配置文件""" # 创建配置目录(如果不存在) os.makedirs(CONFIG_DIR, exist_ok=True) - + # 检查模板文件是否存在 if not os.path.exists(TEMPLATE_PATH): logger.error(f"S4U配置模板文件不存在: {TEMPLATE_PATH}") @@ -354,13 +356,13 @@ def load_s4u_config(config_path: str) -> S4UGlobalConfig: logger.critical("S4U配置文件解析失败") raise e - - # 初始化S4U配置 + + logger.info(f"S4U当前版本: {S4U_VERSION}") update_s4u_config() s4u_config_main = load_s4u_config(config_path=CONFIG_PATH) logger.info("S4U配置文件加载完成!") -s4u_config: S4UConfig = s4u_config_main.s4u \ No newline at end of file +s4u_config: S4UConfig = s4u_config_main.s4u diff --git a/src/migrate_helper/migrate.py b/src/migrate_helper/migrate.py index 6d60dae0..5a565cae 100644 --- a/src/migrate_helper/migrate.py +++ b/src/migrate_helper/migrate.py @@ -13,7 +13,7 @@ async def migrate_memory_items_to_string(): 并根据原始list的项目数量设置weight值 """ logger.info("开始迁移记忆节点格式...") - + migration_stats = { "total_nodes": 0, "converted_nodes": 0, @@ -21,72 +21,74 @@ async def migrate_memory_items_to_string(): "empty_nodes": 0, "error_nodes": 0, "weight_updated_nodes": 0, - "truncated_nodes": 0 + "truncated_nodes": 0, } - + try: # 获取所有图节点 all_nodes = GraphNodes.select() migration_stats["total_nodes"] = all_nodes.count() - + logger.info(f"找到 {migration_stats['total_nodes']} 个记忆节点") - + for node in all_nodes: try: concept = node.concept memory_items_raw = node.memory_items.strip() if node.memory_items else "" - original_weight = node.weight if hasattr(node, 'weight') and node.weight is not None else 1.0 - + original_weight = node.weight if hasattr(node, "weight") and node.weight is not None else 1.0 + # 如果为空,跳过 if not memory_items_raw: migration_stats["empty_nodes"] += 1 logger.debug(f"跳过空节点: {concept}") continue - + try: # 尝试解析JSON parsed_data = json.loads(memory_items_raw) - + if isinstance(parsed_data, list): # 如果是list格式,需要转换 if parsed_data: # 转换为字符串格式 new_memory_items = " | ".join(str(item) for item in parsed_data) original_length = len(new_memory_items) - + # 检查长度并截断 if len(new_memory_items) > 100: new_memory_items = new_memory_items[:100] migration_stats["truncated_nodes"] += 1 logger.debug(f"节点 '{concept}' 内容过长,从 {original_length} 字符截断到 100 字符") - + new_weight = float(len(parsed_data)) # weight = list项目数量 - + # 更新数据库 node.memory_items = new_memory_items node.weight = new_weight node.save() - + migration_stats["converted_nodes"] += 1 migration_stats["weight_updated_nodes"] += 1 - + length_info = f" (截断: {original_length}→100)" if original_length > 100 else "" - logger.info(f"转换节点 '{concept}': {len(parsed_data)} 项 -> 字符串{length_info}, weight: {original_weight} -> {new_weight}") + logger.info( + f"转换节点 '{concept}': {len(parsed_data)} 项 -> 字符串{length_info}, weight: {original_weight} -> {new_weight}" + ) else: # 空list,设置为空字符串 node.memory_items = "" node.weight = 1.0 node.save() - + migration_stats["converted_nodes"] += 1 logger.debug(f"转换空list节点: {concept}") - + elif isinstance(parsed_data, str): # 已经是字符串格式,检查长度和weight current_content = parsed_data original_length = len(current_content) content_truncated = False - + # 检查长度并截断 if len(current_content) > 100: current_content = current_content[:100] @@ -94,19 +96,21 @@ async def migrate_memory_items_to_string(): migration_stats["truncated_nodes"] += 1 node.memory_items = current_content logger.debug(f"节点 '{concept}' 字符串内容过长,从 {original_length} 字符截断到 100 字符") - + # 检查weight是否需要更新 update_needed = False if original_weight == 1.0: # 如果weight还是默认值,可以根据内容复杂度估算 - content_parts = current_content.split(" | ") if " | " in current_content else [current_content] + content_parts = ( + current_content.split(" | ") if " | " in current_content else [current_content] + ) estimated_weight = max(1.0, float(len(content_parts))) - + if estimated_weight != original_weight: node.weight = estimated_weight update_needed = True logger.debug(f"更新字符串节点权重 '{concept}': {original_weight} -> {estimated_weight}") - + # 如果内容被截断或权重需要更新,保存到数据库 if content_truncated or update_needed: node.save() @@ -118,26 +122,26 @@ async def migrate_memory_items_to_string(): migration_stats["already_string_nodes"] += 1 else: migration_stats["already_string_nodes"] += 1 - + else: # 其他JSON类型,转换为字符串 new_memory_items = str(parsed_data) if parsed_data else "" original_length = len(new_memory_items) - + # 检查长度并截断 if len(new_memory_items) > 100: new_memory_items = new_memory_items[:100] migration_stats["truncated_nodes"] += 1 logger.debug(f"节点 '{concept}' 其他类型内容过长,从 {original_length} 字符截断到 100 字符") - + node.memory_items = new_memory_items node.weight = 1.0 node.save() - + migration_stats["converted_nodes"] += 1 length_info = f" (截断: {original_length}→100)" if original_length > 100 else "" logger.debug(f"转换其他类型节点: {concept}{length_info}") - + except json.JSONDecodeError: # 不是JSON格式,假设已经是纯字符串 # 检查是否是带引号的字符串 @@ -145,16 +149,16 @@ async def migrate_memory_items_to_string(): # 去掉引号 clean_content = memory_items_raw[1:-1] original_length = len(clean_content) - + # 检查长度并截断 if len(clean_content) > 100: clean_content = clean_content[:100] migration_stats["truncated_nodes"] += 1 logger.debug(f"节点 '{concept}' 去引号内容过长,从 {original_length} 字符截断到 100 字符") - + node.memory_items = clean_content node.save() - + migration_stats["converted_nodes"] += 1 length_info = f" (截断: {original_length}→100)" if original_length > 100 else "" logger.debug(f"去除引号节点: {concept}{length_info}") @@ -162,29 +166,29 @@ async def migrate_memory_items_to_string(): # 已经是纯字符串格式,检查长度 current_content = memory_items_raw original_length = len(current_content) - + # 检查长度并截断 if len(current_content) > 100: current_content = current_content[:100] node.memory_items = current_content node.save() - + migration_stats["converted_nodes"] += 1 # 算作转换节点 migration_stats["truncated_nodes"] += 1 logger.debug(f"节点 '{concept}' 纯字符串内容过长,从 {original_length} 字符截断到 100 字符") else: migration_stats["already_string_nodes"] += 1 logger.debug(f"已是字符串格式节点: {concept}") - + except Exception as e: migration_stats["error_nodes"] += 1 logger.error(f"处理节点 {concept} 时发生错误: {e}") continue - + except Exception as e: logger.error(f"迁移过程中发生严重错误: {e}") raise - + # 输出迁移统计 logger.info("=== 记忆节点迁移完成 ===") logger.info(f"总节点数: {migration_stats['total_nodes']}") @@ -194,101 +198,105 @@ async def migrate_memory_items_to_string(): logger.info(f"错误节点: {migration_stats['error_nodes']}") logger.info(f"权重更新节点: {migration_stats['weight_updated_nodes']}") logger.info(f"内容截断节点: {migration_stats['truncated_nodes']}") - - success_rate = (migration_stats['converted_nodes'] + migration_stats['already_string_nodes']) / migration_stats['total_nodes'] * 100 if migration_stats['total_nodes'] > 0 else 0 + + success_rate = ( + (migration_stats["converted_nodes"] + migration_stats["already_string_nodes"]) + / migration_stats["total_nodes"] + * 100 + if migration_stats["total_nodes"] > 0 + else 0 + ) logger.info(f"迁移成功率: {success_rate:.1f}%") - + return migration_stats - - async def set_all_person_known(): """ 将person_info库中所有记录的is_known字段设置为True 在设置之前,先清理掉user_id或platform为空的记录 """ logger.info("开始设置所有person_info记录为已认识...") - + try: from src.common.database.database_model import PersonInfo - + # 获取所有PersonInfo记录 all_persons = PersonInfo.select() total_count = all_persons.count() - + logger.info(f"找到 {total_count} 个人员记录") - + if total_count == 0: logger.info("没有找到任何人员记录") return {"total": 0, "deleted": 0, "updated": 0, "known_count": 0} - + # 删除user_id或platform为空的记录 deleted_count = 0 invalid_records = PersonInfo.select().where( - (PersonInfo.user_id.is_null()) | - (PersonInfo.user_id == '') | - (PersonInfo.platform.is_null()) | - (PersonInfo.platform == '') + (PersonInfo.user_id.is_null()) + | (PersonInfo.user_id == "") + | (PersonInfo.platform.is_null()) + | (PersonInfo.platform == "") ) - + # 记录要删除的记录信息 for record in invalid_records: user_id_info = f"'{record.user_id}'" if record.user_id else "NULL" platform_info = f"'{record.platform}'" if record.platform else "NULL" person_name_info = f"'{record.person_name}'" if record.person_name else "无名称" - logger.debug(f"删除无效记录: person_id={record.person_id}, user_id={user_id_info}, platform={platform_info}, person_name={person_name_info}") - + logger.debug( + f"删除无效记录: person_id={record.person_id}, user_id={user_id_info}, platform={platform_info}, person_name={person_name_info}" + ) + # 执行删除操作 - deleted_count = PersonInfo.delete().where( - (PersonInfo.user_id.is_null()) | - (PersonInfo.user_id == '') | - (PersonInfo.platform.is_null()) | - (PersonInfo.platform == '') - ).execute() - + deleted_count = ( + PersonInfo.delete() + .where( + (PersonInfo.user_id.is_null()) + | (PersonInfo.user_id == "") + | (PersonInfo.platform.is_null()) + | (PersonInfo.platform == "") + ) + .execute() + ) + if deleted_count > 0: logger.info(f"删除了 {deleted_count} 个user_id或platform为空的记录") else: logger.info("没有发现user_id或platform为空的记录") - + # 重新获取剩余记录数量 remaining_count = PersonInfo.select().count() logger.info(f"清理后剩余 {remaining_count} 个有效记录") - + if remaining_count == 0: logger.info("清理后没有剩余记录") return {"total": total_count, "deleted": deleted_count, "updated": 0, "known_count": 0} - + # 批量更新剩余记录的is_known字段为True updated_count = PersonInfo.update(is_known=True).execute() - + logger.info(f"成功更新 {updated_count} 个人员记录的is_known字段为True") - + # 验证更新结果 known_count = PersonInfo.select().where(PersonInfo.is_known).count() - - result = { - "total": total_count, - "deleted": deleted_count, - "updated": updated_count, - "known_count": known_count - } - + + result = {"total": total_count, "deleted": deleted_count, "updated": updated_count, "known_count": known_count} + logger.info("=== person_info更新完成 ===") logger.info(f"原始记录数: {result['total']}") logger.info(f"删除记录数: {result['deleted']}") logger.info(f"更新记录数: {result['updated']}") logger.info(f"已认识记录数: {result['known_count']}") - + return result - + except Exception as e: logger.error(f"更新person_info过程中发生错误: {e}") raise - async def check_and_run_migrations(): # 获取根目录 project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) @@ -309,4 +317,3 @@ async def check_and_run_migrations(): # 创建done.mem文件 with open(done_file, "w", encoding="utf-8") as f: f.write("done") - \ No newline at end of file diff --git a/src/person_info/person_info.py b/src/person_info/person_info.py index 584af8b8..f8a1e463 100644 --- a/src/person_info/person_info.py +++ b/src/person_info/person_info.py @@ -282,7 +282,7 @@ class Person: memory_category = parts[0].strip() memory_text = parts[1].strip() - memory_weight = parts[2].strip() + _memory_weight = parts[2].strip() # 检查分类是否匹配 if memory_category != category: diff --git a/src/person_info/relationship_manager.py b/src/person_info/relationship_manager.py index 15b65ed0..3fb2e6d2 100644 --- a/src/person_info/relationship_manager.py +++ b/src/person_info/relationship_manager.py @@ -1,11 +1,5 @@ -import json -from json_repair import repair_json -from datetime import datetime from src.common.logger import get_logger -from src.llm_models.utils_model import LLMRequest -from src.config.config import global_config, model_config -from src.chat.utils.prompt_builder import Prompt, global_prompt_manager -from .person_info import Person +from src.chat.utils.prompt_builder import Prompt logger = get_logger("relation") @@ -43,4 +37,3 @@ def init_prompt(): """, "attitude_to_me_prompt", ) - diff --git a/src/plugin_system/__init__.py b/src/plugin_system/__init__.py index 70faaba6..e20fc0af 100644 --- a/src/plugin_system/__init__.py +++ b/src/plugin_system/__init__.py @@ -121,5 +121,5 @@ __all__ = [ "DatabaseChatInfo", "TargetPersonInfo", "ActionPlannerInfo", - "LLMGenerationDataModel" + "LLMGenerationDataModel", ] diff --git a/src/plugin_system/apis/frequency_api.py b/src/plugin_system/apis/frequency_api.py index 448050b9..a688f43f 100644 --- a/src/plugin_system/apis/frequency_api.py +++ b/src/plugin_system/apis/frequency_api.py @@ -7,22 +7,24 @@ logger = get_logger("frequency_api") def get_current_focus_value(chat_id: str) -> float: return frequency_control_manager.get_or_create_frequency_control(chat_id).get_final_focus_value() + def get_current_talk_frequency(chat_id: str) -> float: return frequency_control_manager.get_or_create_frequency_control(chat_id).get_final_talk_frequency() + def set_focus_value_adjust(chat_id: str, focus_value_adjust: float) -> None: frequency_control_manager.get_or_create_frequency_control(chat_id).focus_value_external_adjust = focus_value_adjust - + + def set_talk_frequency_adjust(chat_id: str, talk_frequency_adjust: float) -> None: - frequency_control_manager.get_or_create_frequency_control(chat_id).talk_frequency_external_adjust = talk_frequency_adjust + frequency_control_manager.get_or_create_frequency_control( + chat_id + ).talk_frequency_external_adjust = talk_frequency_adjust + def get_focus_value_adjust(chat_id: str) -> float: return frequency_control_manager.get_or_create_frequency_control(chat_id).focus_value_external_adjust - + + def get_talk_frequency_adjust(chat_id: str) -> float: return frequency_control_manager.get_or_create_frequency_control(chat_id).talk_frequency_external_adjust - - - - - diff --git a/src/plugin_system/apis/generator_api.py b/src/plugin_system/apis/generator_api.py index 257c60fa..f251cede 100644 --- a/src/plugin_system/apis/generator_api.py +++ b/src/plugin_system/apis/generator_api.py @@ -159,6 +159,7 @@ async def generate_reply( logger.error(traceback.format_exc()) return False, None + async def rewrite_reply( chat_stream: Optional[ChatStream] = None, reply_data: Optional[Dict[str, Any]] = None, diff --git a/src/plugin_system/apis/llm_api.py b/src/plugin_system/apis/llm_api.py index 1c65d099..debb67d7 100644 --- a/src/plugin_system/apis/llm_api.py +++ b/src/plugin_system/apis/llm_api.py @@ -72,7 +72,9 @@ async def generate_with_model( llm_request = LLMRequest(model_set=model_config, request_type=request_type) - response, (reasoning_content, model_name, _) = await llm_request.generate_response_async(prompt, temperature=temperature, max_tokens=max_tokens) + response, (reasoning_content, model_name, _) = await llm_request.generate_response_async( + prompt, temperature=temperature, max_tokens=max_tokens + ) return True, response, reasoning_content, model_name except Exception as e: @@ -80,6 +82,7 @@ async def generate_with_model( logger.error(f"[LLMAPI] {error_msg}") return False, error_msg, "", "" + async def generate_with_model_with_tools( prompt: str, model_config: TaskConfig, @@ -109,10 +112,7 @@ async def generate_with_model_with_tools( llm_request = LLMRequest(model_set=model_config, request_type=request_type) response, (reasoning_content, model_name, tool_call) = await llm_request.generate_response_async( - prompt, - tools=tool_options, - temperature=temperature, - max_tokens=max_tokens + prompt, tools=tool_options, temperature=temperature, max_tokens=max_tokens ) return True, response, reasoning_content, model_name, tool_call diff --git a/src/plugin_system/apis/message_api.py b/src/plugin_system/apis/message_api.py index c5a6a101..f4ba0b71 100644 --- a/src/plugin_system/apis/message_api.py +++ b/src/plugin_system/apis/message_api.py @@ -435,9 +435,7 @@ def build_readable_messages_to_str( Returns: 格式化后的可读字符串 """ - return build_readable_messages( - messages, replace_bot_name, timestamp_mode, read_mark, truncate, show_actions - ) + return build_readable_messages(messages, replace_bot_name, timestamp_mode, read_mark, truncate, show_actions) async def build_readable_messages_with_details( @@ -491,8 +489,6 @@ def filter_mai_messages(messages: List[DatabaseMessages]) -> List[DatabaseMessag return [msg for msg in messages if msg.user_info.user_id != str(global_config.bot.qq_account)] - - def translate_pid_to_description(pid: str) -> str: image = Images.get_or_none(Images.image_id == pid) description = "" @@ -500,4 +496,4 @@ def translate_pid_to_description(pid: str) -> str: description = image.description else: description = "[图片]" - return description \ No newline at end of file + return description diff --git a/src/plugin_system/apis/plugin_manage_api.py b/src/plugin_system/apis/plugin_manage_api.py index 693e42b4..d428eb28 100644 --- a/src/plugin_system/apis/plugin_manage_api.py +++ b/src/plugin_system/apis/plugin_manage_api.py @@ -34,7 +34,7 @@ def get_plugin_path(plugin_name: str) -> str: Returns: str: 插件目录的绝对路径。 - + Raises: ValueError: 如果插件不存在。 """ diff --git a/src/plugin_system/apis/plugin_register_api.py b/src/plugin_system/apis/plugin_register_api.py index e4ba2ee4..2e14b0c8 100644 --- a/src/plugin_system/apis/plugin_register_api.py +++ b/src/plugin_system/apis/plugin_register_api.py @@ -2,7 +2,7 @@ from pathlib import Path from src.common.logger import get_logger -logger = get_logger("plugin_manager") # 复用plugin_manager名称 +logger = get_logger("plugin_manager") # 复用plugin_manager名称 def register_plugin(cls): diff --git a/src/plugin_system/core/global_announcement_manager.py b/src/plugin_system/core/global_announcement_manager.py index bb6f06b4..05abf0b7 100644 --- a/src/plugin_system/core/global_announcement_manager.py +++ b/src/plugin_system/core/global_announcement_manager.py @@ -88,7 +88,7 @@ class GlobalAnnouncementManager: return False self._user_disabled_tools[chat_id].append(tool_name) return True - + def enable_specific_chat_tool(self, chat_id: str, tool_name: str) -> bool: """启用特定聊天的某个工具""" if chat_id in self._user_disabled_tools: @@ -111,7 +111,7 @@ class GlobalAnnouncementManager: def get_disabled_chat_event_handlers(self, chat_id: str) -> List[str]: """获取特定聊天禁用的所有事件处理器""" return self._user_disabled_event_handlers.get(chat_id, []).copy() - + def get_disabled_chat_tools(self, chat_id: str) -> List[str]: """获取特定聊天禁用的所有工具""" return self._user_disabled_tools.get(chat_id, []).copy() diff --git a/src/plugin_system/core/plugin_manager.py b/src/plugin_system/core/plugin_manager.py index 014b7a0c..122a9ea2 100644 --- a/src/plugin_system/core/plugin_manager.py +++ b/src/plugin_system/core/plugin_manager.py @@ -224,7 +224,7 @@ class PluginManager: list: 已注册的插件类名称列表。 """ return list(self.plugin_classes.keys()) - + def get_plugin_path(self, plugin_name: str) -> Optional[str]: """ 获取指定插件的路径。 @@ -401,9 +401,7 @@ class PluginManager: command_components = [ c for c in plugin_info.components if c.component_type == ComponentType.COMMAND ] - tool_components = [ - c for c in plugin_info.components if c.component_type == ComponentType.TOOL - ] + tool_components = [c for c in plugin_info.components if c.component_type == ComponentType.TOOL] event_handler_components = [ c for c in plugin_info.components if c.component_type == ComponentType.EVENT_HANDLER ] diff --git a/src/plugin_system/core/tool_use.py b/src/plugin_system/core/tool_use.py index 17e23685..131adc6a 100644 --- a/src/plugin_system/core/tool_use.py +++ b/src/plugin_system/core/tool_use.py @@ -149,10 +149,10 @@ class ToolExecutor: if not tool_calls: logger.debug(f"{self.log_prefix}无需执行工具") return [], [] - + # 提取tool_calls中的函数名称 func_names = [call.func_name for call in tool_calls if call.func_name] - + logger.info(f"{self.log_prefix}开始执行工具调用: {func_names}") # 执行每个工具调用 @@ -195,7 +195,9 @@ class ToolExecutor: return tool_results, used_tools - async def execute_tool_call(self, tool_call: ToolCall, tool_instance: Optional[BaseTool] = None) -> Optional[Dict[str, Any]]: + async def execute_tool_call( + self, tool_call: ToolCall, tool_instance: Optional[BaseTool] = None + ) -> Optional[Dict[str, Any]]: # sourcery skip: use-assigned-variable """执行单个工具调用 diff --git a/src/plugins/built_in/emoji_plugin/plugin.py b/src/plugins/built_in/emoji_plugin/plugin.py index 94a8b7d1..b7afc522 100644 --- a/src/plugins/built_in/emoji_plugin/plugin.py +++ b/src/plugins/built_in/emoji_plugin/plugin.py @@ -63,5 +63,4 @@ class CoreActionsPlugin(BasePlugin): if self.get_config("components.enable_emoji", True): components.append((EmojiAction.get_action_info(), EmojiAction)) - return components diff --git a/src/plugins/built_in/memory/build_memory.py b/src/plugins/built_in/memory/build_memory.py index 939f6c23..e53b57fe 100644 --- a/src/plugins/built_in/memory/build_memory.py +++ b/src/plugins/built_in/memory/build_memory.py @@ -74,7 +74,9 @@ class BuildMemoryAction(BaseAction): # 动作基本信息 action_name = "build_memory" - action_description = "了解对于某个概念或者某件事的记忆,并存储下来,在之后的聊天中,你可以根据这条记忆来获取相关信息" + action_description = ( + "了解对于某个概念或者某件事的记忆,并存储下来,在之后的聊天中,你可以根据这条记忆来获取相关信息" + ) # 动作参数定义 action_parameters = { @@ -103,31 +105,34 @@ class BuildMemoryAction(BaseAction): concept_name = self.action_data.get("concept_name", "") # 2. 获取目标用户信息 - - # 对 concept_name 进行jieba分词 concept_name_tokens = cut_key_words(concept_name) # logger.info(f"{self.log_prefix} 对 concept_name 进行分词结果: {concept_name_tokens}") - + filtered_concept_name_tokens = [ - token for token in concept_name_tokens if all(keyword not in token for keyword in global_config.memory.memory_ban_words) + token + for token in concept_name_tokens + if all(keyword not in token for keyword in global_config.memory.memory_ban_words) ] - + if not filtered_concept_name_tokens: logger.warning(f"{self.log_prefix} 过滤后的概念名称列表为空,跳过添加记忆") return False, "过滤后的概念名称列表为空,跳过添加记忆" - - similar_topics_dict = hippocampus_manager.get_hippocampus().parahippocampal_gyrus.get_similar_topics_from_keywords(filtered_concept_name_tokens) - await hippocampus_manager.get_hippocampus().parahippocampal_gyrus.add_memory_with_similar(concept_description, similar_topics_dict) - - - + + similar_topics_dict = ( + hippocampus_manager.get_hippocampus().parahippocampal_gyrus.get_similar_topics_from_keywords( + filtered_concept_name_tokens + ) + ) + await hippocampus_manager.get_hippocampus().parahippocampal_gyrus.add_memory_with_similar( + concept_description, similar_topics_dict + ) + return True, f"成功添加记忆: {concept_name}" - + except Exception as e: logger.error(f"{self.log_prefix} 构建记忆时出错: {e}") return False, f"构建记忆时出错: {e}" - # 还缺一个关系的太多遗忘和对应的提取 diff --git a/src/plugins/built_in/plugin_management/plugin.py b/src/plugins/built_in/plugin_management/plugin.py index c2489a38..ba60f451 100644 --- a/src/plugins/built_in/plugin_management/plugin.py +++ b/src/plugins/built_in/plugin_management/plugin.py @@ -425,7 +425,7 @@ class ManagementCommand(BaseCommand): await self._send_message(f"本地禁用组件成功: {component_name}") else: await self._send_message(f"本地禁用组件失败: {component_name}") - + async def _send_message(self, message: str): await send_api.text_to_stream(message, self.stream_id, typing=False, storage_message=False) diff --git a/src/plugins/built_in/relation/relation.py b/src/plugins/built_in/relation/relation.py index 1f6f0d0f..bc65b1aa 100644 --- a/src/plugins/built_in/relation/relation.py +++ b/src/plugins/built_in/relation/relation.py @@ -107,7 +107,7 @@ class BuildRelationAction(BaseAction): if not person.is_known: logger.warning(f"{self.log_prefix} 用户 {person_name} 不存在,跳过添加记忆") return False, f"用户 {person_name} 不存在,跳过添加记忆" - + person.last_know = time.time() person.know_times += 1 person.sync_to_database() @@ -178,7 +178,9 @@ class BuildRelationAction(BaseAction): chat_model_config = models.get("utils") success, update_memory, _, _ = await llm_api.generate_with_model( - prompt, model_config=chat_model_config, request_type="relation.category.update" # type: ignore + prompt, + model_config=chat_model_config, + request_type="relation.category.update", # type: ignore ) update_memory_data = json.loads(repair_json(update_memory)) @@ -190,7 +192,7 @@ class BuildRelationAction(BaseAction): # 新记忆 person.memory_points.append(f"{category}:{new_memory}:1.0") person.sync_to_database() - + logger.info(f"{self.log_prefix} 为{person.person_name}新增记忆点: {new_memory}") return True, f"为{person.person_name}新增记忆点: {new_memory}" @@ -207,14 +209,15 @@ class BuildRelationAction(BaseAction): person.memory_points.append(f"{category}:{integrate_memory}:{memory_weight + 1.0}") person.sync_to_database() - logger.info(f"{self.log_prefix} 更新{person.person_name}的记忆点: {memory_content} -> {integrate_memory}") + logger.info( + f"{self.log_prefix} 更新{person.person_name}的记忆点: {memory_content} -> {integrate_memory}" + ) return True, f"更新{person.person_name}的记忆点: {memory_content} -> {integrate_memory}" else: logger.warning(f"{self.log_prefix} 删除记忆点失败: {memory_content}") return False, f"删除{person.person_name}的记忆点失败: {memory_content}" - return True, "关系动作执行成功"