diff --git a/.gitignore b/.gitignore index 1a5ca0a7..6d8b249f 100644 --- a/.gitignore +++ b/.gitignore @@ -278,6 +278,8 @@ logs .vscode /config/* +config/mcp_config.json +!config/mcp_config.json.template config/old/bot_config_20250405_212257.toml temp/ diff --git a/code_scripts/migrate_expression_jargon_db.py b/code_scripts/migrate_expression_jargon_db.py new file mode 100644 index 00000000..8816df6c --- /dev/null +++ b/code_scripts/migrate_expression_jargon_db.py @@ -0,0 +1,535 @@ +from argparse import ArgumentParser, Namespace +from collections.abc import Iterable +from datetime import datetime +from pathlib import Path +from sys import path as sys_path +from typing import Any, Optional + +import json +import sqlite3 + +from sqlalchemy import text +from sqlmodel import Session, SQLModel, create_engine, delete + +ROOT_PATH = Path(__file__).resolve().parent.parent +if str(ROOT_PATH) not in sys_path: + sys_path.insert(0, str(ROOT_PATH)) + +from src.common.database.database_model import Expression, Jargon, ModifiedBy + + +def build_argument_parser() -> ArgumentParser: + """构建命令行参数解析器。""" + parser = ArgumentParser( + description="将旧版 expression/jargon 数据迁移到新版 expressions/jargons 数据库。" + ) + parser.add_argument("--source-db", dest="source_db", help="旧版 SQLite 数据库路径") + parser.add_argument("--target-db", dest="target_db", help="新版 SQLite 数据库路径") + parser.add_argument( + "--clear-target", + dest="clear_target", + action="store_true", + help="迁移前清空目标库中的 expressions 和 jargons 表", + ) + return parser + + +def prompt_path(prompt_text: str, current_value: Optional[str] = None) -> Path: + """读取数据库路径输入。""" + while True: + suffix = f" [{current_value}]" if current_value else "" + raw_text = input(f"{prompt_text}{suffix}: ").strip() + value = raw_text or current_value or "" + if not value: + print("路径不能为空,请重新输入。") + continue + return Path(value).expanduser().resolve() + + +def prompt_yes_no(prompt_text: str, default: bool = False) -> bool: + """读取是否确认输入。""" + default_hint = "Y/n" if default else "y/N" + raw_text = input(f"{prompt_text} [{default_hint}]: ").strip().lower() + if not raw_text: + return default + return raw_text in {"y", "yes"} + + +def ensure_sqlite_file(path: Path, should_exist: bool) -> None: + """校验 SQLite 文件路径。""" + if should_exist and not path.is_file(): + raise FileNotFoundError(f"数据库文件不存在:{path}") + if not should_exist: + path.parent.mkdir(parents=True, exist_ok=True) + + +def connect_sqlite(path: Path) -> sqlite3.Connection: + """创建 SQLite 连接。""" + connection = sqlite3.connect(path) + connection.row_factory = sqlite3.Row + return connection + + +def table_exists(connection: sqlite3.Connection, table_name: str) -> bool: + """检查表是否存在。""" + result = connection.execute( + "SELECT 1 FROM sqlite_master WHERE type = 'table' AND name = ? LIMIT 1", + (table_name,), + ).fetchone() + return result is not None + + +def resolve_source_table_name(connection: sqlite3.Connection, candidates: list[str]) -> str: + """从候选表名中解析实际存在的表名。""" + for table_name in candidates: + if table_exists(connection, table_name): + return table_name + raise ValueError(f"未找到候选表:{', '.join(candidates)}") + + +def get_table_columns(connection: sqlite3.Connection, table_name: str) -> set[str]: + """获取表字段名集合。""" + rows = connection.execute(f"PRAGMA table_info('{table_name}')").fetchall() + return {str(row["name"]) for row in rows} + + +def get_table_nullable_map(connection: sqlite3.Connection, table_name: str) -> dict[str, bool]: + """获取表字段是否允许 NULL 的映射。""" + rows = connection.execute(f"PRAGMA table_info('{table_name}')").fetchall() + return {str(row["name"]): not bool(row["notnull"]) for row in rows} + + +def load_rows(connection: sqlite3.Connection, table_name: str) -> list[sqlite3.Row]: + """读取整张表的数据。""" + return connection.execute(f"SELECT * FROM {table_name}").fetchall() + + +def normalize_optional_text(raw_value: Any) -> Optional[str]: + """标准化可空文本字段。""" + if raw_value is None: + return None + return str(raw_value) + + +def ensure_nullable_compatibility( + table_name: str, + column_name: str, + row_id: Any, + value: Any, + nullable_map: dict[str, bool], +) -> None: + """检查待迁移值是否与目标表可空约束兼容。""" + if value is None and not nullable_map.get(column_name, True): + raise ValueError( + f"目标表 {table_name}.{column_name} 不允许 NULL,但源记录 id={row_id} 的该字段为 NULL。" + ) + + +def normalize_string_list(raw_value: Any) -> list[str]: + """将旧库中的 JSON/文本字段标准化为字符串列表。""" + if raw_value is None: + return [] + if isinstance(raw_value, list): + return [str(item).strip() for item in raw_value if str(item).strip()] + if isinstance(raw_value, str): + raw_text = raw_value.strip() + if not raw_text: + return [] + try: + parsed = json.loads(raw_text) + except json.JSONDecodeError: + return [raw_text] + if isinstance(parsed, list): + return [str(item).strip() for item in parsed if str(item).strip()] + if isinstance(parsed, str): + parsed_text = parsed.strip() + return [parsed_text] if parsed_text else [] + if parsed is None: + return [] + return [str(parsed).strip()] + return [str(raw_value).strip()] + + +def normalize_modified_by(raw_value: Any) -> Optional[ModifiedBy]: + """标准化审核来源字段。""" + if raw_value is None: + return None + + normalized_raw_value = raw_value + if isinstance(raw_value, str): + raw_text = raw_value.strip() + if raw_text.startswith('"') and raw_text.endswith('"'): + try: + normalized_raw_value = json.loads(raw_text) + except json.JSONDecodeError: + normalized_raw_value = raw_text + else: + normalized_raw_value = raw_text + + value = str(normalized_raw_value).strip().lower() + if value in {"", "none", "null"}: + return None + if value in {ModifiedBy.AI.value, ModifiedBy.AI.name.lower()}: + return ModifiedBy.AI + if value in {ModifiedBy.USER.value, ModifiedBy.USER.name.lower()}: + return ModifiedBy.USER + return None + + +def parse_optional_bool(raw_value: Any) -> Optional[bool]: + """解析可空布尔值,兼容整数和字符串。""" + if raw_value is None: + return None + if isinstance(raw_value, bool): + return raw_value + if isinstance(raw_value, int): + return bool(raw_value) + if isinstance(raw_value, float): + return bool(int(raw_value)) + + value = str(raw_value).strip().lower() + if value in {"", "none", "null"}: + return None + if value in {"1", "true", "t", "yes", "y"}: + return True + if value in {"0", "false", "f", "no", "n"}: + return False + raise ValueError(f"无法解析布尔值:{raw_value}") + + +def parse_bool(raw_value: Any, default: bool = False) -> bool: + """解析非空布尔值。""" + parsed = parse_optional_bool(raw_value) + return default if parsed is None else parsed + + +def timestamp_to_datetime(raw_value: Any, fallback_now: bool) -> Optional[datetime]: + """将旧库中的 Unix 时间戳转换为 datetime。""" + if raw_value is None or raw_value == "": + return datetime.now() if fallback_now else None + if isinstance(raw_value, datetime): + return raw_value + try: + return datetime.fromtimestamp(float(raw_value)) + except (TypeError, ValueError, OSError, OverflowError): + return datetime.now() if fallback_now else None + + +def build_session_id_dict(raw_chat_id: Any, fallback_count: int) -> str: + """将旧版 jargon.chat_id 转换为新版 session_id_dict。""" + if raw_chat_id is None: + return json.dumps({}, ensure_ascii=False) + + if isinstance(raw_chat_id, str): + raw_text = raw_chat_id.strip() + else: + raw_text = str(raw_chat_id).strip() + + if not raw_text: + return json.dumps({}, ensure_ascii=False) + + try: + parsed = json.loads(raw_text) + except json.JSONDecodeError: + return json.dumps({raw_text: max(fallback_count, 1)}, ensure_ascii=False) + + if isinstance(parsed, str): + parsed_text = parsed.strip() + session_counts = {parsed_text: max(fallback_count, 1)} if parsed_text else {} + return json.dumps(session_counts, ensure_ascii=False) + + if not isinstance(parsed, list): + return json.dumps({}, ensure_ascii=False) + + session_counts: dict[str, int] = {} + for item in parsed: + if not isinstance(item, list) or not item: + continue + session_id = str(item[0]).strip() + if not session_id: + continue + item_count = 1 + if len(item) > 1: + try: + item_count = int(item[1]) + except (TypeError, ValueError): + item_count = 1 + session_counts[session_id] = max(item_count, 1) + + return json.dumps(session_counts, ensure_ascii=False) + + +def create_target_engine(target_db_path: Path): + """创建目标数据库引擎。""" + return create_engine( + f"sqlite:///{target_db_path.as_posix()}", + echo=False, + connect_args={"check_same_thread": False}, + ) + + +def clear_target_tables(session: Session) -> None: + """清空目标表。""" + session.exec(delete(Expression)) + session.exec(delete(Jargon)) + + +def migrate_expressions( + old_rows: Iterable[sqlite3.Row], + target_session: Session, + expression_columns: set[str], +) -> int: + """迁移 expression 数据。""" + migrated_count = 0 + modified_by_ai_count = 0 + modified_by_user_count = 0 + modified_by_null_count = 0 + unknown_modified_by_values: dict[str, int] = {} + for row in old_rows: + create_time = timestamp_to_datetime(row["create_date"] if "create_date" in expression_columns else None, True) + last_active_time = timestamp_to_datetime( + row["last_active_time"] if "last_active_time" in expression_columns else None, + True, + ) + content_list = normalize_string_list(row["content_list"] if "content_list" in expression_columns else None) + raw_modified_by = row["modified_by"] if "modified_by" in expression_columns else None + modified_by = normalize_modified_by(raw_modified_by) + if modified_by == ModifiedBy.AI: + modified_by_ai_count += 1 + elif modified_by == ModifiedBy.USER: + modified_by_user_count += 1 + else: + modified_by_null_count += 1 + if raw_modified_by not in (None, "", "null", "NULL", "None"): + unknown_key = str(raw_modified_by) + unknown_modified_by_values[unknown_key] = unknown_modified_by_values.get(unknown_key, 0) + 1 + + target_session.execute( + text( + """ + INSERT INTO expressions ( + id, + situation, + style, + content_list, + count, + last_active_time, + create_time, + session_id, + checked, + rejected, + modified_by + ) VALUES ( + :id, + :situation, + :style, + :content_list, + :count, + :last_active_time, + :create_time, + :session_id, + :checked, + :rejected, + :modified_by + ) + """ + ), + { + "id": int(row["id"]) if row["id"] is not None else None, + "situation": str(row["situation"]).strip(), + "style": str(row["style"]).strip(), + "content_list": json.dumps(content_list, ensure_ascii=False), + "count": int(row["count"]) if "count" in expression_columns and row["count"] is not None else 1, + "last_active_time": last_active_time or datetime.now(), + "create_time": create_time or datetime.now(), + "session_id": str(row["chat_id"]).strip() if "chat_id" in expression_columns and row["chat_id"] else None, + "checked": parse_bool(row["checked"] if "checked" in expression_columns else None, default=False), + "rejected": parse_bool(row["rejected"] if "rejected" in expression_columns else None, default=False), + "modified_by": modified_by.name if modified_by is not None else None, + }, + ) + migrated_count += 1 + + print( + "Expression modified_by 迁移统计:" + f" AI={modified_by_ai_count}, USER={modified_by_user_count}, NULL={modified_by_null_count}" + ) + if unknown_modified_by_values: + preview_items = list(unknown_modified_by_values.items())[:10] + preview_text = ", ".join(f"{value!r} x{count}" for value, count in preview_items) + print(f"警告:以下旧 modified_by 值未识别,已按 NULL 迁移:{preview_text}") + return migrated_count + + +def migrate_jargons( + old_rows: Iterable[sqlite3.Row], + target_session: Session, + jargon_columns: set[str], + jargon_nullable_map: dict[str, bool], +) -> int: + """迁移 jargon 数据。""" + migrated_count = 0 + coerced_meaning_null_count = 0 + for row in old_rows: + count = int(row["count"]) if "count" in jargon_columns and row["count"] is not None else 0 + raw_content_value = row["raw_content"] if "raw_content" in jargon_columns else None + raw_content_list = normalize_string_list(raw_content_value) + meaning_value = normalize_optional_text(row["meaning"] if "meaning" in jargon_columns else None) + is_jargon_value = parse_optional_bool(row["is_jargon"] if "is_jargon" in jargon_columns else None) + inference_content_key = ( + "inference_content_only" + if "inference_content_only" in jargon_columns + else "inference_with_content_only" + if "inference_with_content_only" in jargon_columns + else None + ) + + ensure_nullable_compatibility("jargons", "is_jargon", row["id"], is_jargon_value, jargon_nullable_map) + + if meaning_value is None and not jargon_nullable_map.get("meaning", True): + meaning_value = "" + coerced_meaning_null_count += 1 + + # 显式执行 SQL,避免 ORM 在 None 场景下回填模型默认值。 + target_session.execute( + text( + """ + INSERT INTO jargons ( + id, + content, + raw_content, + meaning, + session_id_dict, + count, + is_jargon, + is_complete, + is_global, + last_inference_count, + inference_with_context, + inference_with_content_only + ) VALUES ( + :id, + :content, + :raw_content, + :meaning, + :session_id_dict, + :count, + :is_jargon, + :is_complete, + :is_global, + :last_inference_count, + :inference_with_context, + :inference_with_content_only + ) + """ + ), + { + "id": int(row["id"]) if row["id"] is not None else None, + "content": str(row["content"]).strip(), + "raw_content": json.dumps(raw_content_list, ensure_ascii=False) if raw_content_value is not None else None, + "meaning": meaning_value, + "session_id_dict": build_session_id_dict( + row["chat_id"] if "chat_id" in jargon_columns else None, + fallback_count=count, + ), + "count": count, + "is_jargon": is_jargon_value, + "is_complete": parse_bool(row["is_complete"] if "is_complete" in jargon_columns else None, default=False), + "is_global": parse_bool(row["is_global"] if "is_global" in jargon_columns else None, default=False), + "last_inference_count": ( + int(row["last_inference_count"]) + if "last_inference_count" in jargon_columns and row["last_inference_count"] is not None + else 0 + ), + "inference_with_context": ( + str(row["inference_with_context"]) + if "inference_with_context" in jargon_columns and row["inference_with_context"] is not None + else None + ), + "inference_with_content_only": ( + str(row[inference_content_key]) + if inference_content_key and row[inference_content_key] is not None + else None + ), + }, + ) + migrated_count += 1 + + if coerced_meaning_null_count > 0: + print( + f"警告:目标表 jargons.meaning 不允许 NULL,已将 {coerced_meaning_null_count} 条旧记录的 NULL meaning 转为空字符串。" + ) + return migrated_count + + +def confirm_target_replacement(target_db_path: Path, clear_target: bool) -> bool: + """确认是否写入目标数据库。""" + if clear_target: + return prompt_yes_no(f"将清空目标库中的 expressions/jargons 后再迁移,确认继续吗?\n目标库:{target_db_path}") + return prompt_yes_no(f"将写入目标库,若主键冲突会导致迁移失败,确认继续吗?\n目标库:{target_db_path}") + + +def parse_arguments() -> Namespace: + """解析参数。""" + return build_argument_parser().parse_args() + + +def main() -> None: + """脚本入口。""" + args = parse_arguments() + + print("旧版 expression/jargon -> 新版 expressions/jargons 迁移工具") + source_db_path = prompt_path("请输入旧版数据库路径", args.source_db) + target_db_path = prompt_path("请输入新版数据库路径", args.target_db) + clear_target = args.clear_target or prompt_yes_no("迁移前是否清空目标库中的 expressions 和 jargons 表?", False) + + if source_db_path == target_db_path: + raise ValueError("旧版数据库路径和新版数据库路径不能相同。") + + ensure_sqlite_file(source_db_path, should_exist=True) + ensure_sqlite_file(target_db_path, should_exist=False) + + print(f"旧库:{source_db_path}") + print(f"新库:{target_db_path}") + print(f"清空目标表:{'是' if clear_target else '否'}") + + if not confirm_target_replacement(target_db_path, clear_target): + print("已取消迁移。") + return + + source_connection = connect_sqlite(source_db_path) + try: + expression_table_name = resolve_source_table_name(source_connection, ["expression", "expressions"]) + jargon_table_name = resolve_source_table_name(source_connection, ["jargon", "jargons"]) + expression_columns = get_table_columns(source_connection, expression_table_name) + jargon_columns = get_table_columns(source_connection, jargon_table_name) + expression_rows = load_rows(source_connection, expression_table_name) + jargon_rows = load_rows(source_connection, jargon_table_name) + finally: + source_connection.close() + + target_engine = create_target_engine(target_db_path) + SQLModel.metadata.create_all(target_engine) + + target_sqlite_connection = connect_sqlite(target_db_path) + try: + jargon_nullable_map = get_table_nullable_map(target_sqlite_connection, "jargons") + finally: + target_sqlite_connection.close() + + with Session(target_engine) as target_session: + if clear_target: + clear_target_tables(target_session) + target_session.commit() + + expression_count = migrate_expressions(expression_rows, target_session, expression_columns) + jargon_count = migrate_jargons(jargon_rows, target_session, jargon_columns, jargon_nullable_map) + target_session.commit() + + print("迁移完成。") + print(f"已迁移 expression 记录:{expression_count}") + print(f"已迁移 jargon 记录:{jargon_count}") + + +if __name__ == "__main__": + main() diff --git a/locales/en-US/startup.json b/locales/en-US/startup.json index 8ad1a8c4..7482666b 100644 --- a/locales/en-US/startup.json +++ b/locales/en-US/startup.json @@ -27,7 +27,7 @@ "startup.main_error": "Main process encountered an exception: {error}", "startup.opensource_free_notice": " This project is completely free and open-source software, released under the GPL-3.0 license", "startup.opensource_group": " Official group chat: ", - "startup.opensource_group_value": "1006149251", + "startup.opensource_group_value": "766798517", "startup.opensource_repo": " Official repository: ", "startup.opensource_repo_value": "https://github.com/MaiM-with-u/MaiBot", "startup.opensource_resale_warning": " Reselling this software as a \"product\" or concealing its open-source nature violates the license!", diff --git a/locales/ja-JP/startup.json b/locales/ja-JP/startup.json index 94ec95ec..6a855dc6 100644 --- a/locales/ja-JP/startup.json +++ b/locales/ja-JP/startup.json @@ -27,7 +27,7 @@ "startup.main_error": "メインプロセスで例外が発生しました: {error}", "startup.opensource_free_notice": " 本プロジェクトは完全無料のオープンソースソフトウェアであり、GPL-3.0 ライセンスのもとで公開されています", "startup.opensource_group": " 公式グループ: ", - "startup.opensource_group_value": "1006149251", + "startup.opensource_group_value": "766798517", "startup.opensource_repo": " 公式リポジトリ: ", "startup.opensource_repo_value": "https://github.com/MaiM-with-u/MaiBot", "startup.opensource_resale_warning": " 本ソフトウェアを「商品」として転売したり、オープンソースであることを隠すことはライセンス違反です!", diff --git a/locales/ko/startup.json b/locales/ko/startup.json index 1a31a17d..2f7ee595 100644 --- a/locales/ko/startup.json +++ b/locales/ko/startup.json @@ -27,7 +27,7 @@ "startup.main_error": "메인 프로세스에서 예외 발생: {error}", "startup.opensource_free_notice": " 본 프로젝트는 완전 무료 오픈소스 소프트웨어이며, GPL-3.0 라이선스로 배포됩니다", "startup.opensource_group": " 공식 그룹: ", - "startup.opensource_group_value": "1006149251", + "startup.opensource_group_value": "766798517", "startup.opensource_repo": " 공식 저장소: ", "startup.opensource_repo_value": "https://github.com/MaiM-with-u/MaiBot", "startup.opensource_resale_warning": " 본 소프트웨어를 '상품'으로 재판매하거나 오픈소스임을 숨기는 행위는 라이선스 위반입니다!", diff --git a/locales/zh-CN/startup.json b/locales/zh-CN/startup.json index c70441df..2290b652 100644 --- a/locales/zh-CN/startup.json +++ b/locales/zh-CN/startup.json @@ -27,7 +27,7 @@ "startup.main_error": "主程序发生异常: {error}", "startup.opensource_free_notice": " 本项目是完全免费的开源软件,基于 GPL-3.0 协议发布", "startup.opensource_group": " 官方群聊: ", - "startup.opensource_group_value": "1006149251", + "startup.opensource_group_value": "766798517", "startup.opensource_repo": " 官方仓库: ", "startup.opensource_repo_value": "https://github.com/MaiM-with-u/MaiBot", "startup.opensource_resale_warning": " 将本软件作为「商品」倒卖、隐瞒开源性质均违反协议!", diff --git a/mai_knowledge/knowledge.json b/mai_knowledge/knowledge.json new file mode 100644 index 00000000..feae33c6 --- /dev/null +++ b/mai_knowledge/knowledge.json @@ -0,0 +1,887 @@ +{ + "1": [ + { + "id": "know_1_1774770946.623486", + "content": "备战中考", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T15:55:46.623486" + }, + { + "id": "know_1_1774771765.051286", + "content": "性别为女性", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T16:09:25.051286" + }, + { + "id": "know_1_1774771851.333504", + "content": "用户是I人(内向型人格)", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T16:10:51.333504" + }, + { + "id": "know_1_1774771894.517183", + "content": "用户名为小千,被他人称为“宝宝”,结合语境推测为女性", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T16:11:34.517183" + }, + { + "id": "know_1_1774771923.859455", + "content": "小千是I人(内向型人格)", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T16:12:03.859455" + }, + { + "id": "know_1_1774771993.479732", + "content": "小千是女性", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T16:13:13.479732" + }, + { + "id": "know_1_1774772079.496335", + "content": "用户名为小千,被他人称为“宝宝”,推测为女性或处于亲密社交语境中(注:性别非明确陈述,但基于昵称高频使用及语境,高置信度归纳为女性或女性化称呼偏好,若严格遵循“明确表达”则此项存疑。鉴于指令要求“高置信度可归纳”,且群内互动模式符合典型女性向昵称习惯,此处提取为倾向性事实)", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T16:14:39.496335" + }, + { + "id": "know_1_1774773435.68612", + "content": "用户名为小千", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T16:37:15.686120" + }, + { + "id": "know_1_1774773676.69252", + "content": "用户自称猫娘(二次元人设)", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T16:41:16.692520" + } + ], + "2": [ + { + "id": "know_2_1774768612.298128", + "content": "性格自信,常以“真理在我这边”自居", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T15:16:52.298128" + }, + { + "id": "know_2_1774768645.029561", + "content": "性格自信且带有自嘲精神,喜欢用轻松调侃的方式应对他人评价", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T15:17:25.029561" + }, + { + "id": "know_2_1774771068.355999", + "content": "喜欢用夸张、幽默或古风修辞表达观点", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T15:57:48.355999" + }, + { + "id": "know_2_1774771397.764996", + "content": "性格幽默,喜欢使用夸张比喻和古风表达", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T16:03:17.764996" + }, + { + "id": "know_2_1774771471.03367", + "content": "幽默风趣,喜欢使用夸张比喻和玩梗", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T16:04:31.033670" + }, + { + "id": "know_2_1774771765.052285", + "content": "性格不孤僻,社交圈较广", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T16:09:25.052285" + }, + { + "id": "know_2_1774771851.33601", + "content": "用户表现出社恐倾向,喜欢回避社交互动", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T16:10:51.336010" + }, + { + "id": "know_2_1774771894.520185", + "content": "性格偏向内向(I人),有社恐倾向,喜欢回避社交压力", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T16:11:34.520185" + }, + { + "id": "know_2_1774771958.585244", + "content": "小千是内向型人格(I人)", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T16:12:38.585244" + }, + { + "id": "know_2_1774771993.481732", + "content": "小千性格内向(I人)", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T16:13:13.481732" + } + ], + "3": [ + { + "id": "know_3_1774773676.695521", + "content": "喜欢冰淇淋", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T16:41:16.695521" + } + ], + "4": [], + "5": [], + "6": [ + { + "id": "know_6_1774768486.451792", + "content": "正在搭建 RAG 测试集", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T15:14:46.451792" + }, + { + "id": "know_6_1774768517.122405", + "content": "熟悉 NapCat、RAG 等技术工具及互联网梗文化", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T15:15:17.122405" + }, + { + "id": "know_6_1774769406.247087", + "content": "喜欢动漫风格插画", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T15:30:06.247087" + }, + { + "id": "know_6_1774770487.207364", + "content": "关注显卡硬件参数(如显存、型号)及深度学习/炼丹应用", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T15:48:07.207364" + }, + { + "id": "know_6_1774770487.209372", + "content": "对游戏光影效果感兴趣", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T15:48:07.209372" + }, + { + "id": "know_6_1774770603.063873", + "content": "喜欢玩《我的世界》和VRChat", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T15:50:03.063873" + }, + { + "id": "know_6_1774770654.654349", + "content": "关注显卡硬件参数(如4090、48G显存、5090)", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T15:50:54.654349" + }, + { + "id": "know_6_1774770654.655356", + "content": "使用VRChat进行社交娱乐", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T15:50:54.655356" + }, + { + "id": "know_6_1774770734.287947", + "content": "关注显卡硬件(如4090、3050)及AI炼丹技术", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T15:52:14.287947" + }, + { + "id": "know_6_1774770734.289944", + "content": "玩《我的世界》并配置光影效果", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T15:52:14.289944" + }, + { + "id": "know_6_1774770734.291944", + "content": "计划游玩VRChat", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T15:52:14.291944" + }, + { + "id": "know_6_1774771033.111011", + "content": "喜欢玩VRChat", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T15:57:13.111011" + }, + { + "id": "know_6_1774771068.358999", + "content": "关注VRChat等虚拟现实游戏及硬件性能话题", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T15:57:48.358999" + }, + { + "id": "know_6_1774771233.980219", + "content": "使用VRChat(VRC)", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T16:00:33.980219" + }, + { + "id": "know_6_1774771397.766996", + "content": "对VRChat(VRC)及虚拟形象社交感兴趣", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T16:03:17.766996" + }, + { + "id": "know_6_1774771471.03567", + "content": "对VRChat等虚拟社交游戏感兴趣", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T16:04:31.035670" + }, + { + "id": "know_6_1774771894.521183", + "content": "熟悉二次元文化、动漫角色及互联网流行梗(Meme)", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T16:11:34.521183" + }, + { + "id": "know_6_1774771923.861534", + "content": "小千玩CS:GO游戏", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T16:12:03.861534" + }, + { + "id": "know_6_1774771958.587243", + "content": "回声者_Echoderd喜欢玩CS:GO游戏", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T16:12:38.587243" + }, + { + "id": "know_6_1774771993.483732", + "content": "小千喜欢二次元文化及动漫游戏圈梗", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T16:13:13.483732" + }, + { + "id": "know_6_1774772079.499335", + "content": "熟悉并喜爱二次元文化、动漫角色及互联网梗图(如阴间美学、病娇系、黑长直萌妹等风格)", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T16:14:39.499335" + }, + { + "id": "know_6_1774772112.716455", + "content": "小千关注CS:GO游戏及中考备考话题", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T16:15:12.716455" + }, + { + "id": "know_6_1774772154.873237", + "content": "用户玩CS:GO游戏", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T16:15:54.873237" + }, + { + "id": "know_6_1774772186.438797", + "content": "玩CS:GO游戏", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T16:16:26.438797" + }, + { + "id": "know_6_1774772730.867535", + "content": "熟悉《我的青春恋爱物语果然有问题》及二次元表情包文化", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T16:25:30.867535" + }, + { + "id": "know_6_1774773338.849271", + "content": "熟悉《原神》等二次元游戏及网络梗文化", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T16:35:38.849271" + }, + { + "id": "know_6_1774773371.406209", + "content": "关注高分屏字体显示效果", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T16:36:11.406209" + }, + { + "id": "know_6_1774773401.48921", + "content": "熟悉电脑显示技术(如高分屏字体选择)", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T16:36:41.489210" + }, + { + "id": "know_6_1774773435.688119", + "content": "关注高分屏显示效果与字体选择(无衬线/衬线体)", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T16:37:15.688119" + }, + { + "id": "know_6_1774773608.256103", + "content": "关注屏幕字体与分辨率(无衬线/有衬线)", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T16:40:08.256103" + }, + { + "id": "know_6_1774773645.671546", + "content": "关注屏幕分辨率与字体显示效果(高分屏/无衬线体)", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T16:40:45.671546" + }, + { + "id": "know_6_1774773676.698035", + "content": "关注字体设计(无衬线体)", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T16:41:16.698035" + }, + { + "id": "know_6_1774773740.83822", + "content": "喜欢二次元文化及 VTuber 风格内容", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T16:42:20.838220" + } + ], + "7": [ + { + "id": "know_7_1774768517.120403", + "content": "从事 RAG 测试集搭建或相关技术工作", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T15:15:17.120403" + }, + { + "id": "know_7_1774768573.741823", + "content": "从事 RAG(检索增强生成)测试集搭建相关工作", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T15:16:13.741823" + }, + { + "id": "know_7_1774770603.062873", + "content": "备战中考", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T15:50:03.062873" + }, + { + "id": "know_7_1774771471.036668", + "content": "正在备战中考的学生", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T16:04:31.036668" + }, + { + "id": "know_7_1774771923.862535", + "content": "小千正在备战中考", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T16:12:03.862535" + }, + { + "id": "know_7_1774771958.588749", + "content": "回声者_Echoderd正在备战中考", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T16:12:38.588749" + }, + { + "id": "know_7_1774772112.714455", + "content": "小千使用AI模型进行对话", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T16:15:12.714455" + }, + { + "id": "know_7_1774772154.870238", + "content": "用户正在备战中考", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T16:15:54.870238" + }, + { + "id": "know_7_1774773185.194069", + "content": "使用 NapCat 框架", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T16:33:05.194069" + }, + { + "id": "know_7_1774773338.851275", + "content": "使用 NapCat 框架,具备技术平台认知能力", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T16:35:38.851275" + }, + { + "id": "know_7_1774773371.403696", + "content": "熟悉 NapCat 框架", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T16:36:11.403696" + } + ], + "8": [ + { + "id": "know_8_1774770946.624486", + "content": "日常逛游戏地图", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T15:55:46.624486" + }, + { + "id": "know_8_1774771397.769034", + "content": "备考中考期间仍保持日常游戏娱乐习惯", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T16:03:17.769034" + }, + { + "id": "know_8_1774771851.338018", + "content": "用户有备考中考的学习任务", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T16:10:51.338018" + }, + { + "id": "know_8_1774771894.523189", + "content": "备考中(备战中考)", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T16:11:34.523189" + }, + { + "id": "know_8_1774771993.484733", + "content": "小千有打CS:GO的游戏习惯", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T16:13:13.484733" + }, + { + "id": "know_8_1774772079.501334", + "content": "有在高压环境下(如中考前)进行游戏娱乐(CS:GO)的习惯,自称或认同“摆烂”的生活态度", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T16:14:39.501334" + }, + { + "id": "know_8_1774772154.875743", + "content": "用户在备考期间有打游戏摸鱼的习惯", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T16:15:54.875743" + }, + { + "id": "know_8_1774773435.690121", + "content": "习惯使用表情包表达情绪或进行网络互动", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T16:37:15.690121" + }, + { + "id": "know_8_1774773676.701034", + "content": "备战中考", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T16:41:16.701034" + } + ], + "9": [], + "10": [ + { + "id": "know_10_1774768486.452792", + "content": "沟通风格带有调侃和自信,习惯用反问句表达观点", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T15:14:46.452792" + }, + { + "id": "know_10_1774768517.121403", + "content": "沟通风格带有较强的好胜心和防御性,习惯用反问和调侃回应质疑", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T15:15:17.121403" + }, + { + "id": "know_10_1774768573.742824", + "content": "沟通风格幽默,擅长使用逻辑闭环和反问句式进行辩论或调侃", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T15:16:13.742824" + }, + { + "id": "know_10_1774768612.299126", + "content": "沟通风格幽默风趣,擅长使用网络梗和表情包互动", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T15:16:52.299126" + }, + { + "id": "know_10_1774768612.299845", + "content": "偶尔会文绉绉地表达(自称“文青病犯了”),但能迅速切换回口语化", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T15:16:52.299845" + }, + { + "id": "know_10_1774768645.028561", + "content": "沟通风格幽默风趣,偶尔会文青病发作使用古风表达", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T15:17:25.028561" + }, + { + "id": "know_10_1774769406.249584", + "content": "沟通中常使用文言文或半文言表达", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T15:30:06.249584" + }, + { + "id": "know_10_1774769406.251097", + "content": "习惯用反问句和夸张语气进行互动", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T15:30:06.251097" + }, + { + "id": "know_10_1774770487.211056", + "content": "沟通风格幽默,常使用网络梗和夸张表达", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T15:48:07.211056" + }, + { + "id": "know_10_1774771471.038677", + "content": "沟通风格轻松随意,善于接话和调侃", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T16:04:31.038677" + }, + { + "id": "know_10_1774771765.053285", + "content": "沟通风格活泼,喜欢使用语气词和表情符号撒娇", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T16:09:25.053285" + }, + { + "id": "know_10_1774772079.503333", + "content": "沟通风格幽默调侃,擅长用反话(如“烦到了”)和夸张修辞(如“耳朵起茧子”、“要报警了”)表达情绪", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T16:14:39.503333" + }, + { + "id": "know_10_1774773338.853274", + "content": "沟通风格幽默风趣,擅长玩梗与自嘲", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T16:35:38.853274" + }, + { + "id": "know_10_1774773371.408719", + "content": "喜欢用幽默调侃的方式回应他人", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T16:36:11.408719" + }, + { + "id": "know_10_1774773401.491209", + "content": "沟通风格幽默风趣,擅长玩梗和角色扮演", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T16:36:41.491209" + }, + { + "id": "know_10_1774773435.693121", + "content": "沟通风格幽默、喜欢玩梗和自嘲,擅长接话茬", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T16:37:15.693121" + }, + { + "id": "know_10_1774773532.488374", + "content": "沟通风格幽默,喜欢使用网络梗和表情包活跃气氛", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T16:38:52.488374" + }, + { + "id": "know_10_1774773532.490959", + "content": "在争论中倾向于据理力争,并自嘲或调侃对方阅读理解能力", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T16:38:52.490959" + }, + { + "id": "know_10_1774773569.709356", + "content": "喜欢用幽默、夸张和自嘲的方式活跃气氛", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T16:39:29.709356" + } + ], + "11": [ + { + "id": "know_11_1774771068.360999", + "content": "乐于接受并学习新的技术技巧(如加速器用法)", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T15:57:48.360999" + } + ], + "12": [ + { + "id": "know_12_1774770654.657355", + "content": "面对网络延迟问题倾向于寻找加速器解决方案", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T15:50:54.657355" + }, + { + "id": "know_12_1774773185.196068", + "content": "备战中考", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T16:33:05.196068" + }, + { + "id": "know_12_1774773740.836223", + "content": "面对压力或冲突时,倾向于通过撒娇、耍赖和寻求盟友支持来应对", + "metadata": { + "session_id": "628336b082552269377e9d0648e26c60", + "source": "maisaka_learning" + }, + "created_at": "2026-03-29T16:42:20.836223" + } + ] +} \ No newline at end of file diff --git a/plugins/MaiBot_MCPBridgePlugin/.gitignore b/plugins/MaiBot_MCPBridgePlugin/.gitignore deleted file mode 100644 index ebef83b0..00000000 --- a/plugins/MaiBot_MCPBridgePlugin/.gitignore +++ /dev/null @@ -1,30 +0,0 @@ -# 运行时配置(包含用户敏感信息) -config.toml - -# 备份文件 -*.backup.* -*.bak - -# 日志 -logs/ -*.log -*.jsonl - -# Python 缓存 -__pycache__/ -*.py[cod] -*$py.class -*.so - -# 本地测试脚本(仓库不提交) -test_*.py - -# IDE -.idea/ -.vscode/ -*.swp -*.swo - -# 系统文件 -.DS_Store -Thumbs.db diff --git a/plugins/MaiBot_MCPBridgePlugin/CHANGELOG.md b/plugins/MaiBot_MCPBridgePlugin/CHANGELOG.md deleted file mode 100644 index 0c3feb46..00000000 --- a/plugins/MaiBot_MCPBridgePlugin/CHANGELOG.md +++ /dev/null @@ -1,24 +0,0 @@ -# Changelog - -本文件记录 `MaiBot_MCPBridgePlugin` 的用户可感知变更。 - -## 2.0.0 - -- 配置入口统一:MCP 服务器仅使用 Claude Desktop `mcpServers` JSON(`servers.claude_config_json`) -- 兼容迁移:自动识别旧版 `servers.list` 并迁移为 `mcpServers`(需在 WebUI 保存一次固化) -- 保持功能不变:保留 Workflow(硬流程/工具链)与 ReAct(软流程)双轨制能力 -- 精简实现:移除旧的 WebUI 导入导出/快速添加服务器实现与 `tomlkit` 依赖 -- 易用性:完善 Workflow 变量替换(支持数组下标与 bracket 写法),并优化 WebUI 配置区顺序 - -## 1.9.0 - -- 双轨制架构:ReAct(软流程)+ Workflow(硬流程/工具链) - -## 1.8.0 - -- Workflow(工具链):多工具顺序执行、变量替换、自定义 Workflow 并注册为组合工具 - -## 1.7.0 - -- 断路器模式、状态刷新、工具搜索等易用性增强 - diff --git a/plugins/MaiBot_MCPBridgePlugin/DEVELOPMENT.md b/plugins/MaiBot_MCPBridgePlugin/DEVELOPMENT.md deleted file mode 100644 index 7299fe13..00000000 --- a/plugins/MaiBot_MCPBridgePlugin/DEVELOPMENT.md +++ /dev/null @@ -1,356 +0,0 @@ -# MCP 桥接插件开发文档 - -本文档面向开发者,介绍插件的架构设计、核心模块和扩展方式。 - -## 架构概览 - -``` -MaiBot_MCPBridgePlugin/ -├── plugin.py # 主插件文件,包含所有核心逻辑 -├── mcp_client.py # MCP 客户端封装 -├── tool_chain.py # 工具链(Workflow)模块 -├── core/ -│ └── claude_config.py # Claude Desktop mcpServers 解析/迁移 -├── config.toml # 运行时配置 -└── _manifest.json # 插件元数据 -``` - -## 核心模块 - -### 1. MCP 客户端 (`mcp_client.py`) - -封装了与 MCP 服务器的通信逻辑。 - -```python -from .mcp_client import mcp_manager, MCPServerConfig, TransportType - -# 添加服务器 -config = MCPServerConfig( - name="my-server", - transport=TransportType.STREAMABLE_HTTP, - url="https://mcp.example.com/mcp" -) -await mcp_manager.add_server(config) - -# 调用工具 -result = await mcp_manager.call_tool("server_tool_name", {"param": "value"}) -if result.success: - print(result.content) -``` - -**支持的传输类型:** -- `STDIO`: 本地进程通信 -- `SSE`: Server-Sent Events -- `HTTP`: HTTP 请求 -- `STREAMABLE_HTTP`: 流式 HTTP(推荐) - -### 2. 工具注册系统 - -MCP 工具通过动态类创建注册到 MaiBot: - -```python -# 创建工具代理类 -class MCPToolProxy(BaseTool): - name = "mcp_server_tool" - description = "工具描述" - parameters = [("param", ToolParamType.STRING, "参数描述", True, None)] - available_for_llm = True - - async def execute(self, function_args): - result = await mcp_manager.call_tool(self._mcp_tool_key, function_args) - return {"name": self.name, "content": result.content} -``` - -### 3. 工具链模块 (`tool_chain.py`) - -实现 Workflow 硬流程,支持多工具顺序执行。 - -```python -from .tool_chain import ToolChainDefinition, ToolChainStep, tool_chain_manager - -# 定义工具链 -chain = ToolChainDefinition( - name="search_and_detail", - description="搜索并获取详情", - input_params={"query": "搜索关键词"}, - steps=[ - ToolChainStep( - tool_name="mcp_server_search", - args_template={"keyword": "${input.query}"}, - output_key="search_result" - ), - ToolChainStep( - tool_name="mcp_server_detail", - args_template={"id": "${prev}"} - ) - ] -) - -# 注册并执行 -tool_chain_manager.add_chain(chain) -result = await tool_chain_manager.execute_chain("search_and_detail", {"query": "test"}) -``` - -**变量替换语法:** -- `${input.参数名}`: 用户输入 -- `${step.输出键}`: 指定步骤的输出 -- `${prev}`: 上一步输出 -- `${prev.字段}`: 上一步输出(JSON)的字段 -- `${step.geo.return.0.location}` / `${step.geo.return[0].location}`: 数组下标访问 -- `${step.geo['return'][0]['location']}`: bracket 写法(最通用) - -## 双轨制架构 - -### ReAct 软流程 - -将 MCP 工具注册到 MaiBot 的记忆检索 ReAct 系统,LLM 自主决策调用。 - -```python -def _register_tools_to_react(self) -> int: - from src.memory_system.retrieval_tools import register_memory_retrieval_tool - - def make_execute_func(tool_key: str): - async def execute_func(**kwargs) -> str: - result = await mcp_manager.call_tool(tool_key, kwargs) - return result.content if result.success else f"失败: {result.error}" - return execute_func - - register_memory_retrieval_tool( - name="mcp_tool_name", - description="工具描述", - parameters=[{"name": "param", "type": "string", "required": True}], - execute_func=make_execute_func("tool_key") - ) -``` - -### Workflow 硬流程 - -用户预定义的固定执行流程,注册为组合工具。 - -```python -def _register_tool_chains(self) -> None: - from src.plugin_system.core.component_registry import component_registry - - for chain_name, chain in tool_chain_manager.get_enabled_chains().items(): - info, tool_class = tool_chain_registry.register_chain(chain) - info.plugin_name = self.plugin_name - component_registry.register_component(info, tool_class) -``` - -## 配置系统 - -### MCP 服务器配置(Claude Desktop 规范) - -插件只接受 Claude Desktop 的 `mcpServers` JSON(见 `core/claude_config.py`)。配置入口统一为: - -- WebUI/配置文件:`[servers].claude_config_json` -- 命令:`/mcp import`(合并 `mcpServers`)与 `/mcp export`(导出当前 `mcpServers`) - -兼容迁移: -- 若检测到旧版 `servers.list`,会自动迁移为 `servers.claude_config_json`(仅迁移到内存配置,需 WebUI 保存一次固化)。 - -### WebUI 配置 Schema - -使用 `ConfigField` 定义 WebUI 配置项: - -```python -config_schema = { - "section_name": { - "field_name": ConfigField( - type=str, # 类型: str, bool, int, float - default="default_value", # 默认值 - description="字段描述", - label="显示标签", - input_type="textarea", # 输入类型: text, textarea, password - rows=5, # textarea 行数 - disabled=True, # 只读 - choices=["a", "b"], # 下拉选项 - hint="提示信息", - order=1, # 排序 - ), - }, -} -``` - -### 配置读取 - -```python -# 在组件中读取配置 -value = self.get_config("section.key", default="fallback") - -# 在插件类中读取 -value = self.config.get("section", {}).get("key", "default") -``` - -## 事件处理 - -### 启动事件 - -```python -class MCPStartupHandler(BaseEventHandler): - event_type = EventType.ON_START - handler_name = "mcp_startup" - - async def execute(self, message): - global _plugin_instance - if _plugin_instance: - await _plugin_instance._async_connect_servers() - return (True, True, None, None, None) -``` - -### 停止事件 - -```python -class MCPStopHandler(BaseEventHandler): - event_type = EventType.ON_STOP - handler_name = "mcp_stop" - - async def execute(self, message): - await mcp_manager.shutdown() - return (True, True, None, None, None) -``` - -## 命令系统 - -```python -class MCPStatusCommand(BaseCommand): - command_name = "mcp_status" - command_pattern = r"^/mcp(?:\s+(?P\S+))?(?:\s+(?P.+))?$" - - async def execute(self) -> Tuple[bool, str, bool]: - action = self.matched_groups.get("action", "") - arg = self.matched_groups.get("arg", "") - - if action == "tools": - await self.send_text("工具列表...") - elif action == "reconnect": - await self._handle_reconnect(arg) - - return (True, None, True) # (成功, 消息, 拦截) -``` - -## 高级功能 - -### 调用追踪 - -```python -from plugin import tool_call_tracer, ToolCallRecord - -# 记录调用 -record = ToolCallRecord( - call_id="xxx", - timestamp=time.time(), - tool_name="tool", - server_name="server", - arguments={"key": "value"}, - success=True, - duration_ms=100.0 -) -tool_call_tracer.record(record) - -# 查询记录 -recent = tool_call_tracer.get_recent(10) -by_tool = tool_call_tracer.get_by_tool("tool_name") -``` - -### 调用缓存 - -```python -from plugin import tool_call_cache - -# 配置缓存 -tool_call_cache.configure( - enabled=True, - ttl=300, # 秒 - max_entries=200, - exclude_tools="mcp_*_time_*" # 排除模式 -) - -# 使用缓存 -cached = tool_call_cache.get("tool_name", {"param": "value"}) -if cached is None: - result = await call_tool(...) - tool_call_cache.set("tool_name", {"param": "value"}, result) -``` - -### 权限控制 - -```python -from plugin import permission_checker - -# 配置权限 -permission_checker.configure( - enabled=True, - default_mode="allow_all", # 或 "deny_all" - rules_json='[{"tool": "mcp_*_delete_*", "denied": ["qq:123:group"]}]', - quick_deny_groups="123456789", - quick_allow_users="111111111" -) - -# 检查权限 -allowed = permission_checker.check( - tool_name="mcp_server_delete", - chat_id="123456", - user_id="789", - is_group=True -) -``` - -### 断路器模式 - -MCP 客户端内置断路器,故障服务器快速失败: - -- 连续失败 N 次后熔断 -- 熔断期间直接返回错误 -- 定期尝试恢复 - -## 扩展开发 - -### 添加新的传输类型 - -1. 在 `mcp_client.py` 中添加 `TransportType` 枚举值 -2. 实现对应的连接逻辑 -3. 更新 `_create_transport()` 方法 - -### 添加新的工具类型 - -1. 继承 `BaseTool` 创建新类 -2. 在 `get_plugin_components()` 中注册 -3. 实现 `execute()` 方法 - -### 添加新的命令 - -1. 在 `MCPStatusCommand.execute()` 中添加新的 action 分支 -2. 或创建新的 `BaseCommand` 子类 - -## 调试技巧 - -### 日志级别 - -```python -from src.common.logger import get_logger -logger = get_logger("mcp_bridge_plugin") - -logger.debug("详细调试信息") -logger.info("一般信息") -logger.warning("警告") -logger.error("错误") -``` - -### 常用调试命令 - -```bash -/mcp # 查看状态 -/mcp tools # 查看工具列表 -/mcp trace # 查看调用记录 -/mcp cache # 查看缓存状态 -/mcp chain # 查看工具链 -``` - -## 更新日志 - -见 `plugins/MaiBot_MCPBridgePlugin/CHANGELOG.md` - -## 开发约定 - -- 本仓库不提交测试脚本/临时复现文件;如需本地验证,可自行在工作区创建未跟踪文件(建议放到 `.local/` 并加入 `.gitignore`)。 diff --git a/plugins/MaiBot_MCPBridgePlugin/README.md b/plugins/MaiBot_MCPBridgePlugin/README.md deleted file mode 100644 index 61aca8f5..00000000 --- a/plugins/MaiBot_MCPBridgePlugin/README.md +++ /dev/null @@ -1,357 +0,0 @@ -# MCP 桥接插件 - -将 [MCP (Model Context Protocol)](https://modelcontextprotocol.io/) 服务器的工具桥接到 MaiBot,使麦麦能够调用外部 MCP 工具。 - -image - -## 🚀 快速开始 - -### 1. 安装 - -```bash -# 克隆到 MaiBot 插件目录 -cd /path/to/MaiBot/plugins -git clone https://github.com/CharTyr/MaiBot_MCPBridgePlugin.git MCPBridgePlugin - -# 安装依赖 -pip install mcp - -# 复制配置文件 -cd MCPBridgePlugin -cp config.example.toml config.toml -``` - -### 2. 添加服务器 - -编辑 `config.toml`,在 `[servers]` 的 `claude_config_json` 中填写 Claude Desktop 的 `mcpServers` JSON: - -```toml -[servers] -claude_config_json = ''' -{ - "mcpServers": { - "time": { "transport": "streamable_http", "url": "https://mcp.api-inference.modelscope.cn/server/mcp-server-time" }, - "my-server": { "transport": "streamable_http", "url": "https://mcp.xxx.com/mcp", "headers": { "Authorization": "Bearer 你的密钥" } }, - "fetch": { "command": "uvx", "args": ["mcp-server-fetch"] } - } -} -''' -``` - -### 3. 启动 - -重启 MaiBot,或发送 `/mcp reconnect` - ---- - -## 📚 去哪找 MCP 服务器? - -| 平台 | 说明 | -|------|------| -| [mcp.modelscope.cn](https://mcp.modelscope.cn/) | 魔搭 ModelScope,免费推荐 | -| [smithery.ai](https://smithery.ai/) | MCP 服务器注册中心 | -| [github.com/modelcontextprotocol/servers](https://github.com/modelcontextprotocol/servers) | 官方服务器列表 | - ---- - -## 💡 常用命令 - -| 命令 | 说明 | -|------|------| -| `/mcp` | 查看连接状态 | -| `/mcp tools` | 查看可用工具 | -| `/mcp reconnect` | 重连服务器 | -| `/mcp trace` | 查看调用记录 | -| `/mcp cache` | 查看缓存状态 | -| `/mcp perm` | 查看权限配置 | -| `/mcp import ` | 🆕 导入 Claude Desktop 配置 | -| `/mcp export` | 🆕 导出配置 | -| `/mcp search <关键词>` | 🆕 搜索工具 | -| `/mcp chain` | 🆕 查看工具链 | -| `/mcp chain <名称>` | 🆕 查看工具链详情 | -| `/mcp chain test <名称> <参数>` | 🆕 测试执行工具链 | - ---- - -## ✨ 功能特性 - -### 核心功能 -- 🔌 多服务器同时连接 -- 📡 支持 stdio / SSE / HTTP / Streamable HTTP -- 🔄 自动重试、心跳检测、断线重连 -- 🖥️ WebUI 完整配置支持 - -### 双轨制架构 -- 🔄 **ReAct(软流程)**:LLM 自主决策,多轮动态调用 MCP 工具(适合探索式场景) -- 🔗 **Workflow(硬流程/工具链)**:用户预定义步骤顺序与参数传递(适合可控可复用场景) - -### 高级功能 -- 📦 Resources 支持(实验性) -- 📝 Prompts 支持(实验性) -- 🔄 结果后处理(LLM 摘要提炼) -- 🔍 调用追踪 / 🗄️ 调用缓存 / 🔐 权限控制 / 🚫 工具禁用 - -### 更新日志 -- 见 `plugins/MaiBot_MCPBridgePlugin/CHANGELOG.md` - ---- - -## ⚙️ 配置说明 - -### 服务器配置 - -```json -{ - "mcpServers": { - "server_name": { - "transport": "streamable_http", - "url": "https://..." - } - } -} -``` - -| 字段 | 说明 | -|------|------| -| `mcpServers.` | 服务器名称(唯一) | -| `enabled` | 是否启用(可选,默认 true) | -| `transport` | `stdio` / `sse` / `http` / `streamable_http` | -| `url` | 远程服务器地址 | -| `headers` | 🆕 鉴权头(如 `{"Authorization": "Bearer xxx"}`) | -| `command` / `args` | 本地服务器启动命令 | - -### 权限控制 - -**快捷配置(推荐):** -```toml -[permissions] -perm_enabled = true -quick_deny_groups = "123456789" # 禁用的群号 -quick_allow_users = "111111111" # 管理员白名单 -``` - -**高级规则:** -```json -[{"tool": "mcp_*_delete_*", "denied": ["qq:123456:group"]}] -``` - -### 工具禁用 - -```toml -[tools] -disabled_tools = ''' -mcp_filesystem_delete_file -mcp_filesystem_write_file -''' -``` - -### 调用缓存 - -```toml -[settings] -cache_enabled = true -cache_ttl = 300 -cache_exclude_tools = "mcp_*_time_*" -``` - ---- - -## ❓ 常见问题 - -**Q: 工具没有注册?** -- 检查 `enabled = true` -- 检查 MaiBot 日志错误信息 -- 确认 `pip install mcp` - -**Q: JSON 格式报错?** -- 多行 JSON 用 `'''` 三引号包裹 -- 使用英文双引号 `"` - -**Q: 如何手动重连?** -- `/mcp reconnect` 或 `/mcp reconnect 服务器名` - ---- - -## 📥 配置导入导出(Claude mcpServers) - -### 从 Claude Desktop 导入 - -如果你已有 Claude Desktop 的 MCP 配置,可以直接导入: - -``` -/mcp import {"mcpServers":{"time":{"command":"uvx","args":["mcp-server-time"]},"fetch":{"command":"uvx","args":["mcp-server-fetch"]}}} -``` - -支持的格式: -- Claude Desktop 格式(`mcpServers` 对象) -- 兼容旧版:MaiBot servers 列表数组(将自动迁移为 `mcpServers`) - -### 导出配置 - -``` -/mcp export # 导出为 Claude Desktop 格式(默认) -/mcp export claude # 导出为 Claude Desktop 格式 -``` - -### 注意事项 -- 导入时会自动跳过同名服务器 -- 导入后需要发送 `/mcp reconnect` 使配置生效 -- 支持 stdio、sse、http、streamable_http 全部传输类型 - ---- - -## 🔗 Workflow(硬流程/工具链) - -工具链允许你将多个 MCP 工具按顺序执行,后续工具可以使用前序工具的输出作为输入。 - -### 1 分钟上手(推荐 WebUI) -1. 先完成 MCP 服务器配置并 `/mcp reconnect` -2. 发送 `/mcp tools`,复制你要用的工具名 -3. 打开 WebUI → 「Workflow(硬流程/工具链)」→ 用“快速添加”表单填入: - - 名称/描述 - - 输入参数(每行 `参数名=描述`) - - 执行步骤(每行 `工具名|参数JSON|输出键`) -4. 在“确认添加”中输入 `ADD` 并保存 - -### 快速添加工具链(推荐) - -在 WebUI 的「工具链」配置区,使用表单快速添加: - -1. **名称**: 填写工具链名称(英文,如 `search_and_detail`) -2. **描述**: 填写工具链用途(供 LLM 理解何时使用) -3. **输入参数**: 每行一个,格式 `参数名=描述` - ``` - query=搜索关键词 - max_results=最大结果数 - ``` -4. **执行步骤**: 每行一个,格式 `工具名|参数JSON|输出键` - ``` - mcp_server_search|{"keyword":"${input.query}"}|search_result - mcp_server_detail|{"id":"${prev}"}| - ``` -5. **确认添加**: 输入 `ADD` 并保存 - -### JSON 配置方式 - -也可以直接在「工具链列表」中编写 JSON: - -```json -[ - { - "name": "search_and_detail", - "description": "先搜索模组,再获取详情", - "input_params": { - "query": "搜索关键词" - }, - "steps": [ - { - "tool_name": "mcp_mcmod_search_mod", - "args_template": {"keyword": "${input.query}", "limit": 1}, - "output_key": "search_result", - "description": "搜索模组" - }, - { - "tool_name": "mcp_mcmod_get_mod_detail", - "args_template": {"mod_id": "${prev}"}, - "description": "获取详情" - } - ] - } -] -``` - -### 变量替换 - -| 变量格式 | 说明 | -|---------|------| -| `${input.参数名}` | 用户输入的参数 | -| `${step.输出键}` | 某个步骤的输出(通过 `output_key` 指定) | -| `${prev}` | 上一步的输出 | -| `${prev.字段}` | 上一步输出(JSON)的某个字段 | -| `${step.geo.return.0.location}` | 数组下标访问(dot) | -| `${step.geo.return[0].location}` | 数组下标访问([]) | -| `${step.geo['return'][0]['location']}` | bracket 写法(最通用) | - -### 工具链字段说明 - -| 字段 | 说明 | -|------|------| -| `name` | 工具链名称,将生成 `chain_xxx` 工具 | -| `description` | 描述,供 LLM 理解何时使用 | -| `input_params` | 输入参数定义 `{参数名: 描述}` | -| `steps` | 执行步骤数组 | -| `steps[].tool_name` | 要调用的工具名 | -| `steps[].args_template` | 参数模板,支持变量替换 | -| `steps[].output_key` | 输出存储键名(可选) | -| `steps[].optional` | 是否可选,失败时继续执行(默认 false) | - -### 命令 - -```bash -/mcp chain # 查看所有工具链 -/mcp chain list # 列出工具链 -/mcp chain <名称> # 查看详情 -/mcp chain test <名称> {"query": "JEI"} # 测试执行 -/mcp chain reload # 重新加载配置 -``` - ---- - -## 🔄 双轨制架构 - -MCP 桥接插件支持两种工具调用模式,可根据场景选择: - -### ReAct 软流程 - -LLM 自主决策的多轮工具调用模式,适合复杂、不确定的场景。 - -**工作原理:** -1. 用户提问 → LLM 分析需要什么信息 -2. LLM 选择调用工具 → 获取结果 -3. LLM 观察结果 → 决定是否需要更多信息 -4. 重复 2-3 直到信息足够 → 生成最终回答 - -**启用方式:** -在 WebUI「ReAct (软流程)」配置区启用,MCP 工具将自动注册到 MaiBot 的记忆检索 ReAct 系统。 - -**适用场景:** -- 复杂问题需要多步推理 -- 不确定需要调用哪些工具 -- 需要根据中间结果动态调整 - -### Workflow 硬流程 - -用户预定义的工作流,固定执行顺序,适合可靠、可控的场景。 - -**工作原理:** -1. 用户定义步骤顺序和参数传递 -2. 按顺序执行每个步骤 -3. 后续步骤可使用前序步骤的输出 -4. 返回最终结果 - -**适用场景:** -- 流程固定、可预测 -- 需要可靠、可重复的执行 -- 希望精确控制工具调用顺序 - -### 对比 - -| 特性 | ReAct 软流程 | Workflow 硬流程 | -|------|-------------|----------------| -| 决策者 | LLM 自主决策 | 用户预定义 | -| 灵活性 | 高,动态调整 | 低,固定流程 | -| 可预测性 | 低 | 高 | -| 适用场景 | 复杂、探索性任务 | 固定、重复性任务 | -| 配置方式 | 启用即可 | 需要定义步骤 | - ---- - -## 📋 依赖 - -- MaiBot >= 0.11.6 -- Python >= 3.10 -- mcp >= 1.0.0 - -## 📄 许可证 - -AGPL-3.0 diff --git a/plugins/MaiBot_MCPBridgePlugin/__init__.py b/plugins/MaiBot_MCPBridgePlugin/__init__.py deleted file mode 100644 index 80e2ae47..00000000 --- a/plugins/MaiBot_MCPBridgePlugin/__init__.py +++ /dev/null @@ -1,44 +0,0 @@ -""" -MCP 桥接插件 -将 MCP (Model Context Protocol) 服务器的工具桥接到 MaiBot - -v1.1.0 新增功能: -- 心跳检测和自动重连 -- 调用统计(次数、成功率、耗时) -- 更好的错误处理 - -v1.2.0 新增功能: -- Resources 支持(资源读取) -- Prompts 支持(提示模板) -""" - -from .plugin import MCPBridgePlugin, mcp_tool_registry, MCPStartupHandler, MCPStopHandler -from .mcp_client import ( - mcp_manager, - MCPClientManager, - MCPServerConfig, - TransportType, - MCPCallResult, - MCPToolInfo, - MCPResourceInfo, - MCPPromptInfo, - ToolCallStats, - ServerStats, -) - -__all__ = [ - "MCPBridgePlugin", - "mcp_tool_registry", - "mcp_manager", - "MCPClientManager", - "MCPServerConfig", - "TransportType", - "MCPCallResult", - "MCPToolInfo", - "MCPResourceInfo", - "MCPPromptInfo", - "ToolCallStats", - "ServerStats", - "MCPStartupHandler", - "MCPStopHandler", -] diff --git a/plugins/MaiBot_MCPBridgePlugin/_manifest.json b/plugins/MaiBot_MCPBridgePlugin/_manifest.json deleted file mode 100644 index d2e08ab4..00000000 --- a/plugins/MaiBot_MCPBridgePlugin/_manifest.json +++ /dev/null @@ -1,42 +0,0 @@ -{ - "manifest_version": 2, - "version": "2.0.0", - "name": "MCP桥接插件", - "description": "将 MCP (Model Context Protocol) 服务器的工具桥接到 MaiBot,使麦麦能够调用外部 MCP 工具。", - "author": { - "name": "CharTyr", - "url": "https://github.com/CharTyr" - }, - "license": "AGPL-3.0", - "urls": { - "repository": "https://github.com/CharTyr/MaiBot_MCPBridgePlugin", - "homepage": "https://github.com/CharTyr/MaiBot_MCPBridgePlugin", - "documentation": "https://github.com/CharTyr/MaiBot_MCPBridgePlugin", - "issues": "https://github.com/CharTyr/MaiBot_MCPBridgePlugin/issues" - }, - "host_application": { - "min_version": "0.11.6", - "max_version": "1.0.0" - }, - "sdk": { - "min_version": "2.0.0", - "max_version": "2.99.99" - }, - "dependencies": [ - { - "type": "python_package", - "name": "mcp", - "version_spec": ">=0.0.0" - } - ], - "capabilities": [ - "send.text" - ], - "i18n": { - "default_locale": "zh-CN", - "supported_locales": [ - "zh-CN" - ] - }, - "id": "chartyr.mcpbridge-plugin" -} diff --git a/plugins/MaiBot_MCPBridgePlugin/config.example.toml b/plugins/MaiBot_MCPBridgePlugin/config.example.toml deleted file mode 100644 index 4edac27a..00000000 --- a/plugins/MaiBot_MCPBridgePlugin/config.example.toml +++ /dev/null @@ -1,309 +0,0 @@ -# MCP桥接插件 - 配置文件示例 -# 将 MCP (Model Context Protocol) 服务器的工具桥接到 MaiBot -# -# 使用方法:复制此文件为 config.toml,然后根据需要修改配置 -# -# ============================================================ -# 🎯 快速开始(三步) -# ============================================================ -# 1. 在下方 [servers] 添加 MCP 服务器配置 -# 2. 将 enabled 改为 true 启用服务器 -# 3. 重启 MaiBot 或发送 /mcp reconnect -# -# ============================================================ -# 📚 去哪找 MCP 服务器? -# ============================================================ -# -# 【远程服务(推荐新手)】 -# - ModelScope: https://mcp.modelscope.cn/ (免费,推荐) -# - Smithery: https://smithery.ai/ -# - Glama: https://glama.ai/mcp/servers -# -# 【本地服务(需要 npx 或 uvx)】 -# - 官方列表: https://github.com/modelcontextprotocol/servers -# -# ============================================================ - -# ============================================================ -# 🔌 MCP 服务器配置 -# ============================================================ -# -# ⚠️ 重要:配置格式(Claude Desktop 规范) -# ──────────────────────────────────────────────────────────── -# 统一使用 Claude Desktop 的 mcpServers JSON。 -# -# claude_config_json 的内容应为 JSON 对象: -# { -# "mcpServers": { -# "server_name": { ...server config... }, -# "another": { ... } -# } -# } -# -# 每个服务器支持字段: -# transport - 传输方式: "stdio" / "sse" / "http" / "streamable_http"(可选) -# url - 服务器地址(sse/http/streamable_http 模式) -# command - 启动命令(stdio 模式,如 "npx" / "uvx") -# args - 命令参数数组(stdio 模式) -# env - 环境变量对象(stdio 模式,可选) -# headers - 鉴权头(可选,如 {"Authorization": "Bearer xxx"}) -# enabled - 是否启用(可选,默认 true) -# post_process - 服务器级别后处理配置(可选) -# -# ============================================================ - -[servers] -claude_config_json = ''' -{ - "mcpServers": { - "time-mcp-server": { - "enabled": false, - "transport": "streamable_http", - "url": "https://mcp.api-inference.modelscope.cn/server/mcp-server-time" - }, - "my-auth-server": { - "enabled": false, - "transport": "streamable_http", - "url": "https://mcp.api-inference.modelscope.net/xxxxxx/mcp", - "headers": { - "Authorization": "Bearer ms-xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx" - } - }, - "fetch-local": { - "enabled": false, - "command": "uvx", - "args": ["mcp-server-fetch"] - } - } -} -''' - -# ============================================================ -# 插件基本信息 -# ============================================================ -[plugin] -name = "mcp_bridge_plugin" -version = "2.0.0" -config_version = "2.0.0" -enabled = false # 默认禁用,在 WebUI 中启用 - -# ============================================================ -# Workflow(硬流程/工具链) -# ============================================================ -# -# 作用:把多个工具按顺序执行;后续步骤可引用前序输出。 -# -# ✅ 推荐配置方式:WebUI「Workflow(硬流程/工具链)」里用“快速添加”表单。 -# ✅ 也可以直接写 chains_list(JSON 数组)。 -# -# 变量替换: -# ${input.xxx} - 用户输入 -# ${step.} - 指定步骤输出(需设置 output_key) -# ${prev} - 上一步输出 -# ${prev.字段} - 上一步输出(JSON)的字段 -# ${step.geo.return.0.location} - 数组/下标访问(dot) -# ${step.geo.return[0].location} - 数组/下标访问([]) -# ${step.geo['return'][0]['location']} - bracket 写法 -# -# ============================================================ - -[tool_chains] -chains_enabled = true - -chains_list = ''' -[ - { - "name": "search_and_detail", - "description": "先搜索,再根据结果获取详情", - "input_params": { "query": "搜索关键词" }, - "steps": [ - { "tool_name": "把这里替换成你的搜索工具名", "args_template": { "keyword": "${input.query}" }, "output_key": "search" }, - { "tool_name": "把这里替换成你的详情工具名", "args_template": { "id": "${prev}" } } - ] - } -] -''' - -# ============================================================ -# ReAct(软流程) -# ============================================================ -# -# 作用:把 MCP 工具注册到 MaiBot 的 ReAct 系统,LLM 可自主多轮调用。 -# -# 注意:ReAct 适合“探索式/不确定”场景;Workflow 适合“固定/可控”场景。 -# -# ============================================================ - -[react] -react_enabled = false -filter_mode = "whitelist" # whitelist / blacklist -tool_filter = "" # 每行一个工具名,支持通配符 * - -# ============================================================ -# 全局设置(高级设置建议保持默认) -# ============================================================ -[settings] -# 🏷️ 工具前缀 - 用于区分 MCP 工具和原生工具 -tool_prefix = "mcp" - -# ⏱️ 连接超时(秒) -connect_timeout = 30.0 - -# ⏱️ 调用超时(秒) -call_timeout = 60.0 - -# 🔄 自动连接 - 启动时自动连接所有已启用的服务器 -auto_connect = true - -# 🔁 重试次数 - 连接失败时的重试次数 -retry_attempts = 3 - -# ⏳ 重试间隔(秒) -retry_interval = 5.0 - -# 💓 心跳检测 - 定期检测服务器连接状态 -heartbeat_enabled = true - -# 💓 心跳间隔(秒)- 建议 30-120 秒 -heartbeat_interval = 60.0 - -# 🔄 自动重连 - 检测到断开时自动尝试重连 -auto_reconnect = true - -# 🔄 最大重连次数 - 连续重连失败后暂停重连 -max_reconnect_attempts = 3 - -# ============================================================ -# 高级功能(实验性) -# ============================================================ -# 📦 启用 Resources - 允许读取 MCP 服务器提供的资源 -enable_resources = false - -# 📝 启用 Prompts - 允许使用 MCP 服务器提供的提示模板 -enable_prompts = false - -# ============================================================ -# 结果后处理功能 -# ============================================================ -# 当 MCP 工具返回的内容过长时,使用 LLM 对结果进行摘要提炼 - -# 🔄 启用结果后处理 -post_process_enabled = false - -# 📏 后处理阈值(字符数)- 结果长度超过此值才触发后处理 -post_process_threshold = 500 - -# 🔢 后处理输出限制 - LLM 摘要输出的最大 token 数 -post_process_max_tokens = 500 - -# 🤖 后处理模型(可选)- 留空则使用 utils 模型组 -post_process_model = "" - -# 🧠 后处理提示词模板 -post_process_prompt = '''用户问题:{query} - -工具返回内容: -{result} - -请从上述内容中提取与用户问题最相关的关键信息,简洁准确地输出:''' - -# ============================================================ -# 调用链路追踪 -# ============================================================ -# 记录工具调用详情,便于调试和分析 - -# 🔍 启用调用追踪 -trace_enabled = true - -# 📊 追踪记录上限 - 内存中保留的最大记录数 -trace_max_records = 50 - -# 📝 追踪日志文件 - 是否将追踪记录写入日志文件 -# 启用后记录写入 plugins/MaiBot_MCPBridgePlugin/logs/trace.jsonl -trace_log_enabled = false - -# ============================================================ -# 工具调用缓存 -# ============================================================ -# 缓存相同参数的调用结果,减少重复请求 - -# 🗄️ 启用调用缓存 -cache_enabled = false - -# ⏱️ 缓存有效期(秒) -cache_ttl = 300 - -# 📦 最大缓存条目 - 超出后 LRU 淘汰 -cache_max_entries = 200 - -# 🚫 缓存排除列表 - 即不缓存的工具(每行一个,支持通配符 *) -# 时间类、随机类工具建议排除 -cache_exclude_tools = ''' -mcp_*_time_* -mcp_*_random_* -''' - -# ============================================================ -# 工具管理 -# ============================================================ -[tools] -# 📋 工具清单(只读)- 启动后自动生成 -tool_list = "(启动后自动生成)" - -# 🚫 禁用工具列表 - 要禁用的工具名(每行一个) -# 从上方工具清单复制工具名,禁用后该工具不会被 LLM 调用 -# 示例: -# disabled_tools = ''' -# mcp_filesystem_delete_file -# mcp_filesystem_write_file -# ''' -disabled_tools = "" - -# ============================================================ -# 权限控制 -# ============================================================ -[permissions] -# 🔐 启用权限控制 - 按群/用户限制工具使用 -perm_enabled = false - -# 📋 默认模式 -# allow_all: 未配置规则的工具默认允许 -# deny_all: 未配置规则的工具默认禁止 -perm_default_mode = "allow_all" - -# ──────────────────────────────────────────────────────────── -# 🚀 快捷配置(推荐新手使用) -# ──────────────────────────────────────────────────────────── - -# 🚫 禁用群列表 - 这些群无法使用任何 MCP 工具(每行一个群号) -# 示例: -# quick_deny_groups = ''' -# 123456789 -# 987654321 -# ''' -quick_deny_groups = "" - -# ✅ 管理员白名单 - 这些用户始终可以使用所有工具(每行一个QQ号) -# 示例: -# quick_allow_users = ''' -# 111111111 -# ''' -quick_allow_users = "" - -# ──────────────────────────────────────────────────────────── -# 📜 高级权限规则(可选,针对特定工具配置) -# ──────────────────────────────────────────────────────────── -# 格式: qq:ID:group/private/user,工具名支持通配符 * -# 示例: -# perm_rules = ''' -# [ -# {"tool": "mcp_*_delete_*", "denied": ["qq:123456:group"]} -# ] -# ''' -perm_rules = "[]" - -# ============================================================ -# 状态显示(只读) -# ============================================================ -[status] -connection_status = "未初始化" diff --git a/plugins/MaiBot_MCPBridgePlugin/core/__init__.py b/plugins/MaiBot_MCPBridgePlugin/core/__init__.py deleted file mode 100644 index d5656a8e..00000000 --- a/plugins/MaiBot_MCPBridgePlugin/core/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Core helpers for MCP Bridge Plugin.""" diff --git a/plugins/MaiBot_MCPBridgePlugin/core/claude_config.py b/plugins/MaiBot_MCPBridgePlugin/core/claude_config.py deleted file mode 100644 index f2a6f011..00000000 --- a/plugins/MaiBot_MCPBridgePlugin/core/claude_config.py +++ /dev/null @@ -1,169 +0,0 @@ -import json -from dataclasses import dataclass, field -from typing import Any, Dict, List, Literal, Optional - - -class ClaudeConfigError(ValueError): - pass - - -Transport = Literal["stdio", "sse", "http", "streamable_http"] - - -@dataclass(frozen=True) -class ClaudeMcpServer: - name: str - transport: Transport - command: str = "" - args: List[str] = field(default_factory=list) - env: Dict[str, str] = field(default_factory=dict) - url: str = "" - headers: Dict[str, str] = field(default_factory=dict) - enabled: bool = True - - -def _normalize_transport(value: Optional[str]) -> Transport: - if not value: - return "streamable_http" - v = value.strip().lower().replace("-", "_") - if v in ("streamable_http", "streamablehttp", "streamable"): - return "streamable_http" - if v in ("http",): - return "http" - if v in ("sse",): - return "sse" - if v in ("stdio",): - return "stdio" - raise ClaudeConfigError(f"unsupported transport: {value}") - - -def _coerce_str_list(value: Any, field_name: str) -> List[str]: - if value is None: - return [] - if isinstance(value, list): - return [str(v) for v in value] - raise ClaudeConfigError(f"{field_name} must be a list") - - -def _coerce_str_dict(value: Any, field_name: str) -> Dict[str, str]: - if value is None: - return {} - if isinstance(value, dict): - return {str(k): str(v) for k, v in value.items()} - raise ClaudeConfigError(f"{field_name} must be an object") - - -def parse_claude_mcp_config(config_json: str) -> List[ClaudeMcpServer]: - """Parse Claude Desktop style MCP config JSON. - - Supported: - - Full object: {"mcpServers": {...}} - - Direct mapping: {...} treated as mcpServers - """ - text = (config_json or "").strip() - if not text: - return [] - - try: - data = json.loads(text) - except json.JSONDecodeError as e: - raise ClaudeConfigError(f"invalid JSON: {e}") from e - - if not isinstance(data, dict): - raise ClaudeConfigError("config must be a JSON object") - - servers_obj = data.get("mcpServers", data) - if not isinstance(servers_obj, dict): - raise ClaudeConfigError("mcpServers must be an object") - - servers: List[ClaudeMcpServer] = [] - for name, raw in servers_obj.items(): - if not isinstance(name, str) or not name.strip(): - raise ClaudeConfigError("server name must be a non-empty string") - if not isinstance(raw, dict): - raise ClaudeConfigError(f"server '{name}' must be an object") - - enabled = bool(raw.get("enabled", True)) - command = str(raw.get("command", "") or "") - url = str(raw.get("url", "") or "") - args = _coerce_str_list(raw.get("args"), "args") - env = _coerce_str_dict(raw.get("env"), "env") - headers = _coerce_str_dict(raw.get("headers"), "headers") - - transport_hint = raw.get("transport", raw.get("type")) - - if command: - transport: Transport = "stdio" - elif url: - try: - transport = _normalize_transport(str(transport_hint) if transport_hint is not None else None) - except ClaudeConfigError: - transport = "streamable_http" - else: - raise ClaudeConfigError(f"server '{name}' must have either 'command' or 'url'") - - servers.append( - ClaudeMcpServer( - name=name, - transport=transport, - command=command, - args=args, - env=env, - url=url, - headers=headers, - enabled=enabled, - ) - ) - - return servers - - -def legacy_servers_list_to_claude_config(servers_list_json: str) -> str: - """Convert legacy v1.x servers list (JSON array) to Claude mcpServers JSON. - - Legacy item schema: - {"name","enabled","transport","url","headers","command","args","env"} - """ - text = (servers_list_json or "").strip() - if not text: - return "" - try: - data = json.loads(text) - except json.JSONDecodeError: - return "" - if isinstance(data, dict): - data = [data] - if not isinstance(data, list): - return "" - - mcp_servers: Dict[str, Any] = {} - for item in data: - if not isinstance(item, dict): - continue - name = str(item.get("name", "") or "").strip() - if not name: - continue - enabled = bool(item.get("enabled", True)) - transport = str(item.get("transport", "") or "").strip().lower().replace("-", "_") - - if transport == "stdio" or item.get("command"): - entry: Dict[str, Any] = { - "enabled": enabled, - "command": item.get("command", "") or "", - "args": item.get("args", []) or [], - } - if item.get("env"): - entry["env"] = item.get("env") - mcp_servers[name] = entry - continue - - entry = {"enabled": enabled, "url": item.get("url", "") or ""} - if item.get("headers"): - entry["headers"] = item.get("headers") - if transport: - entry["transport"] = transport - mcp_servers[name] = entry - - if not mcp_servers: - return "" - return json.dumps({"mcpServers": mcp_servers}, ensure_ascii=False, indent=2) diff --git a/plugins/MaiBot_MCPBridgePlugin/mcp_client.py b/plugins/MaiBot_MCPBridgePlugin/mcp_client.py deleted file mode 100644 index de5abab2..00000000 --- a/plugins/MaiBot_MCPBridgePlugin/mcp_client.py +++ /dev/null @@ -1,1485 +0,0 @@ -""" -MCP 客户端封装模块 -负责与 MCP 服务器建立连接、获取工具列表、执行工具调用 - -v1.7.0 稳定性优化: -- 断路器模式:连续失败 5 次后熔断,60 秒后试探恢复 -- 熔断期间快速失败,避免等待超时 -- 连接成功时自动重置断路器 - -v1.5.2 性能优化: -- 智能心跳间隔:根据服务器稳定性动态调整心跳频率 -- 稳定服务器逐渐增加间隔(最高 3x),减少不必要的检测 -- 断开的服务器使用较短间隔快速重连 - -v1.1.0 新增功能: -- 调用统计(次数、成功率、耗时) -- 心跳检测 -- 自动重连 -- 更好的错误处理 - -v1.2.0 新增功能: -- Resources 支持(资源读取) -- Prompts 支持(提示模板) -- 新增配置项: enable_resources, enable_prompts -""" - -import asyncio -import time -import logging -from typing import Any, Dict, List, Optional, Tuple -from dataclasses import dataclass, field -from enum import Enum - -# 尝试导入 MaiBot 的 logger,如果失败则使用标准 logging -try: - from src.common.logger import get_logger - - logger = get_logger("mcp_client") -except ImportError: - # Fallback: 使用标准 logging - logger = logging.getLogger("mcp_client") - if not logger.handlers: - handler = logging.StreamHandler() - handler.setFormatter(logging.Formatter("[%(levelname)s] %(name)s: %(message)s")) - logger.addHandler(handler) - logger.setLevel(logging.INFO) - - -class TransportType(Enum): - """MCP 传输类型""" - - STDIO = "stdio" # 本地进程通信 - SSE = "sse" # Server-Sent Events (旧版 HTTP) - HTTP = "http" # HTTP Streamable (新版,推荐) - STREAMABLE_HTTP = "streamable_http" # HTTP Streamable 的别名 - - -@dataclass -class MCPToolInfo: - """MCP 工具信息""" - - name: str - description: str - input_schema: Dict[str, Any] - server_name: str - - -@dataclass -class MCPResourceInfo: - """MCP 资源信息""" - - uri: str - name: str - description: str - mime_type: Optional[str] - server_name: str - - -@dataclass -class MCPPromptInfo: - """MCP 提示模板信息""" - - name: str - description: str - arguments: List[Dict[str, Any]] # [{name, description, required}] - server_name: str - - -@dataclass -class MCPServerConfig: - """MCP 服务器配置""" - - name: str - enabled: bool = True - transport: TransportType = TransportType.STDIO - # stdio 配置 - command: str = "" - args: List[str] = field(default_factory=list) - env: Dict[str, str] = field(default_factory=dict) - # http/sse 配置 - url: str = "" - headers: Dict[str, str] = field(default_factory=dict) # v1.4.2: 鉴权头支持 - - -@dataclass -class MCPCallResult: - """MCP 工具调用结果""" - - success: bool - content: Any - error: Optional[str] = None - duration_ms: float = 0.0 # 调用耗时(毫秒) - circuit_broken: bool = False # v1.7.0: 是否被断路器拦截 - - -class CircuitState(Enum): - """断路器状态""" - - CLOSED = "closed" # 正常状态,允许请求 - OPEN = "open" # 熔断状态,拒绝请求 - HALF_OPEN = "half_open" # 半开状态,允许少量试探请求 - - -@dataclass -class CircuitBreaker: - """v1.7.0: 断路器 - 防止对故障服务器持续请求 - - 状态转换: - - CLOSED -> OPEN: 连续失败次数达到阈值 - - OPEN -> HALF_OPEN: 熔断时间到期 - - HALF_OPEN -> CLOSED: 试探请求成功 - - HALF_OPEN -> OPEN: 试探请求失败 - """ - - # 配置 - failure_threshold: int = 5 # 连续失败多少次后熔断 - recovery_timeout: float = 60.0 # 熔断后多久尝试恢复(秒) - half_open_max_calls: int = 1 # 半开状态最多允许几次试探调用 - - # 状态 - state: CircuitState = field(default=CircuitState.CLOSED) - failure_count: int = 0 - success_count: int = 0 - last_failure_time: float = 0.0 - last_state_change: float = field(default_factory=time.time) - half_open_calls: int = 0 - - def can_execute(self) -> Tuple[bool, Optional[str]]: - """检查是否允许执行请求 - - Returns: - (是否允许, 拒绝原因) - """ - current_time = time.time() - - if self.state == CircuitState.CLOSED: - return True, None - - if self.state == CircuitState.OPEN: - # 检查是否到了恢复时间 - time_since_failure = current_time - self.last_failure_time - if time_since_failure >= self.recovery_timeout: - # 转换到半开状态 - self._transition_to(CircuitState.HALF_OPEN) - return True, None - else: - remaining = self.recovery_timeout - time_since_failure - return False, f"断路器熔断中,{remaining:.0f}秒后重试" - - if self.state == CircuitState.HALF_OPEN: - # 半开状态,检查是否还有试探配额 - if self.half_open_calls < self.half_open_max_calls: - return True, None - else: - return False, "断路器半开状态,等待试探结果" - - return True, None - - def record_success(self) -> None: - """记录成功调用""" - self.success_count += 1 - - if self.state == CircuitState.HALF_OPEN: - # 半开状态下成功,恢复到关闭状态 - self._transition_to(CircuitState.CLOSED) - logger.info("断路器恢复正常(试探成功)") - elif self.state == CircuitState.CLOSED: - # 正常状态下成功,重置失败计数 - self.failure_count = 0 - - def record_failure(self) -> None: - """记录失败调用""" - self.failure_count += 1 - self.last_failure_time = time.time() - - if self.state == CircuitState.HALF_OPEN: - # 半开状态下失败,重新熔断 - self._transition_to(CircuitState.OPEN) - logger.warning("断路器重新熔断(试探失败)") - elif self.state == CircuitState.CLOSED: - # 检查是否达到熔断阈值 - if self.failure_count >= self.failure_threshold: - self._transition_to(CircuitState.OPEN) - logger.warning(f"断路器熔断(连续失败 {self.failure_count} 次)") - - def _transition_to(self, new_state: CircuitState) -> None: - """状态转换""" - old_state = self.state - self.state = new_state - self.last_state_change = time.time() - - if new_state == CircuitState.CLOSED: - self.failure_count = 0 - self.half_open_calls = 0 - elif new_state == CircuitState.HALF_OPEN: - self.half_open_calls = 0 - - logger.debug(f"断路器状态: {old_state.value} -> {new_state.value}") - - def reset(self) -> None: - """重置断路器""" - self.state = CircuitState.CLOSED - self.failure_count = 0 - self.success_count = 0 - self.half_open_calls = 0 - self.last_state_change = time.time() - - def get_status(self) -> Dict[str, Any]: - """获取断路器状态""" - return { - "state": self.state.value, - "failure_count": self.failure_count, - "success_count": self.success_count, - "failure_threshold": self.failure_threshold, - "recovery_timeout": self.recovery_timeout, - "time_since_last_failure": time.time() - self.last_failure_time if self.last_failure_time > 0 else None, - } - - -@dataclass -class ToolCallStats: - """工具调用统计""" - - tool_key: str - total_calls: int = 0 - success_calls: int = 0 - failed_calls: int = 0 - total_duration_ms: float = 0.0 - last_call_time: Optional[float] = None - last_error: Optional[str] = None - - @property - def success_rate(self) -> float: - """成功率(0-100)""" - if self.total_calls == 0: - return 0.0 - return (self.success_calls / self.total_calls) * 100 - - @property - def avg_duration_ms(self) -> float: - """平均耗时(毫秒)""" - if self.success_calls == 0: - return 0.0 - return self.total_duration_ms / self.success_calls - - def record_call(self, success: bool, duration_ms: float, error: Optional[str] = None) -> None: - """记录一次调用""" - self.total_calls += 1 - self.last_call_time = time.time() - if success: - self.success_calls += 1 - self.total_duration_ms += duration_ms - else: - self.failed_calls += 1 - self.last_error = error - - def to_dict(self) -> Dict[str, Any]: - """转换为字典""" - return { - "tool_key": self.tool_key, - "total_calls": self.total_calls, - "success_calls": self.success_calls, - "failed_calls": self.failed_calls, - "success_rate": round(self.success_rate, 2), - "avg_duration_ms": round(self.avg_duration_ms, 2), - "last_call_time": self.last_call_time, - "last_error": self.last_error, - } - - -@dataclass -class ServerStats: - """服务器统计""" - - server_name: str - connect_count: int = 0 # 连接次数 - disconnect_count: int = 0 # 断开次数 - reconnect_count: int = 0 # 重连次数 - last_connect_time: Optional[float] = None - last_disconnect_time: Optional[float] = None - last_heartbeat_time: Optional[float] = None - consecutive_failures: int = 0 # 连续失败次数 - - def record_connect(self) -> None: - self.connect_count += 1 - self.last_connect_time = time.time() - self.consecutive_failures = 0 - - def record_disconnect(self) -> None: - self.disconnect_count += 1 - self.last_disconnect_time = time.time() - - def record_reconnect(self) -> None: - self.reconnect_count += 1 - self.consecutive_failures = 0 - - def record_failure(self) -> None: - self.consecutive_failures += 1 - - def record_heartbeat(self) -> None: - self.last_heartbeat_time = time.time() - - def to_dict(self) -> Dict[str, Any]: - return { - "server_name": self.server_name, - "connect_count": self.connect_count, - "disconnect_count": self.disconnect_count, - "reconnect_count": self.reconnect_count, - "last_connect_time": self.last_connect_time, - "last_disconnect_time": self.last_disconnect_time, - "last_heartbeat_time": self.last_heartbeat_time, - "consecutive_failures": self.consecutive_failures, - } - - -class MCPClientSession: - """MCP 客户端会话,管理与单个 MCP 服务器的连接""" - - def __init__(self, config: MCPServerConfig, call_timeout: float = 60.0): - self.config = config - self.call_timeout = call_timeout - self._session = None - self._read_stream = None - self._write_stream = None - self._process: Optional[asyncio.subprocess.Process] = None - self._tools: List[MCPToolInfo] = [] - self._resources: List[MCPResourceInfo] = [] # v1.2.0: Resources 支持 - self._prompts: List[MCPPromptInfo] = [] # v1.2.0: Prompts 支持 - self._connected = False - self._lock = asyncio.Lock() - - # 功能支持标记(服务器可能不支持某些功能) - self._supports_resources: bool = False - self._supports_prompts: bool = False - - # 统计信息 - self.stats = ServerStats(server_name=config.name) - self._tool_stats: Dict[str, ToolCallStats] = {} - - # v1.7.0: 断路器 - self._circuit_breaker = CircuitBreaker() - - @property - def is_connected(self) -> bool: - return self._connected - - @property - def tools(self) -> List[MCPToolInfo]: - return self._tools.copy() - - @property - def resources(self) -> List[MCPResourceInfo]: - """v1.2.0: 获取资源列表""" - return self._resources.copy() - - @property - def prompts(self) -> List[MCPPromptInfo]: - """v1.2.0: 获取提示模板列表""" - return self._prompts.copy() - - @property - def supports_resources(self) -> bool: - """v1.2.0: 服务器是否支持 Resources""" - return self._supports_resources - - @property - def supports_prompts(self) -> bool: - """v1.2.0: 服务器是否支持 Prompts""" - return self._supports_prompts - - @property - def server_name(self) -> str: - return self.config.name - - def get_tool_stats(self, tool_name: str) -> Optional[ToolCallStats]: - """获取工具统计""" - return self._tool_stats.get(tool_name) - - def get_circuit_breaker_status(self) -> Dict[str, Any]: - """v1.7.0: 获取断路器状态""" - return self._circuit_breaker.get_status() - - def reset_circuit_breaker(self) -> None: - """v1.7.0: 重置断路器""" - self._circuit_breaker.reset() - logger.info(f"[{self.server_name}] 断路器已重置") - - def get_all_tool_stats(self) -> Dict[str, ToolCallStats]: - """获取所有工具统计""" - return self._tool_stats.copy() - - async def connect(self) -> bool: - """连接到 MCP 服务器""" - async with self._lock: - if self._connected: - return True - - try: - success = False - if self.config.transport == TransportType.STDIO: - success = await self._connect_stdio() - elif self.config.transport == TransportType.SSE: - success = await self._connect_sse() - elif self.config.transport in (TransportType.HTTP, TransportType.STREAMABLE_HTTP): - success = await self._connect_http() - else: - logger.error(f"[{self.server_name}] 不支持的传输类型: {self.config.transport}") - return False - - if success: - self.stats.record_connect() - # v1.7.0: 连接成功时重置断路器 - self._circuit_breaker.reset() - else: - self.stats.record_failure() - return success - - except Exception as e: - logger.error(f"[{self.server_name}] 连接失败: {e}") - self._connected = False - self.stats.record_failure() - return False - - async def _connect_stdio(self) -> bool: - """通过 stdio 连接 MCP 服务器""" - try: - try: - from mcp import ClientSession, StdioServerParameters - from mcp.client.stdio import stdio_client - except ImportError: - logger.error(f"[{self.server_name}] 未安装 mcp 库,请运行: pip install mcp") - return False - - server_params = StdioServerParameters( - command=self.config.command, args=self.config.args, env=self.config.env if self.config.env else None - ) - - self._stdio_context = stdio_client(server_params) - self._read_stream, self._write_stream = await self._stdio_context.__aenter__() - - self._session_context = ClientSession(self._read_stream, self._write_stream) - self._session = await self._session_context.__aenter__() - - await self._session.initialize() - await self._fetch_tools() - - self._connected = True - logger.info(f"[{self.server_name}] stdio 连接成功,发现 {len(self._tools)} 个工具") - return True - - except Exception as e: - logger.error(f"[{self.server_name}] stdio 连接失败: {e}") - await self._cleanup() - return False - - async def _connect_sse(self) -> bool: - """通过 SSE 连接 MCP 服务器""" - try: - try: - from mcp import ClientSession - from mcp.client.sse import sse_client - except ImportError: - logger.error(f"[{self.server_name}] 未安装 mcp 库,请运行: pip install mcp") - return False - - if not self.config.url: - logger.error(f"[{self.server_name}] SSE 传输需要配置 url") - return False - - logger.debug(f"[{self.server_name}] 正在连接 SSE MCP 服务器: {self.config.url}") - - # v1.4.2: 支持 headers 鉴权 - sse_kwargs = { - "url": self.config.url, - "timeout": 60.0, - "sse_read_timeout": 300.0, - } - if self.config.headers: - sse_kwargs["headers"] = self.config.headers - - self._sse_context = sse_client(**sse_kwargs) - self._read_stream, self._write_stream = await self._sse_context.__aenter__() - - self._session_context = ClientSession(self._read_stream, self._write_stream) - self._session = await self._session_context.__aenter__() - - await self._session.initialize() - await self._fetch_tools() - - self._connected = True - logger.info(f"[{self.server_name}] SSE 连接成功,发现 {len(self._tools)} 个工具") - return True - - except Exception as e: - logger.error(f"[{self.server_name}] SSE 连接失败: {e}") - import traceback - - logger.debug(f"[{self.server_name}] 详细错误: {traceback.format_exc()}") - await self._cleanup() - return False - - async def _connect_http(self) -> bool: - """通过 HTTP Streamable 连接 MCP 服务器""" - try: - try: - from mcp import ClientSession - from mcp.client.streamable_http import streamablehttp_client - except ImportError: - logger.error(f"[{self.server_name}] 未安装 mcp 库,请运行: pip install mcp") - return False - - if not self.config.url: - logger.error(f"[{self.server_name}] HTTP 传输需要配置 url") - return False - - logger.debug(f"[{self.server_name}] 正在连接 HTTP MCP 服务器: {self.config.url}") - - # v1.4.2: 支持 headers 鉴权 - http_kwargs = { - "url": self.config.url, - "timeout": 60.0, - "sse_read_timeout": 300.0, - } - if self.config.headers: - http_kwargs["headers"] = self.config.headers - - self._http_context = streamablehttp_client(**http_kwargs) - self._read_stream, self._write_stream, self._get_session_id = await self._http_context.__aenter__() - - self._session_context = ClientSession(self._read_stream, self._write_stream) - self._session = await self._session_context.__aenter__() - - await self._session.initialize() - await self._fetch_tools() - - self._connected = True - logger.info(f"[{self.server_name}] HTTP 连接成功,发现 {len(self._tools)} 个工具") - return True - - except Exception as e: - logger.error(f"[{self.server_name}] HTTP 连接失败: {e}") - import traceback - - logger.debug(f"[{self.server_name}] 详细错误: {traceback.format_exc()}") - await self._cleanup() - return False - - async def _fetch_tools(self) -> None: - """获取 MCP 服务器的工具列表""" - if not self._session: - return - - try: - result = await self._session.list_tools() - self._tools = [] - - for tool in result.tools: - tool_info = MCPToolInfo( - name=tool.name, - description=tool.description or f"MCP tool: {tool.name}", - input_schema=tool.inputSchema if hasattr(tool, "inputSchema") else {}, - server_name=self.server_name, - ) - self._tools.append(tool_info) - # 初始化工具统计 - if tool.name not in self._tool_stats: - self._tool_stats[tool.name] = ToolCallStats(tool_key=tool.name) - logger.debug(f"[{self.server_name}] 发现工具: {tool.name}") - - except Exception as e: - logger.error(f"[{self.server_name}] 获取工具列表失败: {e}") - self._tools = [] - - async def fetch_resources(self) -> bool: - """v1.2.0: 获取 MCP 服务器的资源列表 - - Returns: - bool: 是否成功获取(服务器不支持时返回 False) - """ - if not self._session: - return False - - try: - result = await asyncio.wait_for(self._session.list_resources(), timeout=self.call_timeout) - self._resources = [] - - for resource in result.resources: - resource_info = MCPResourceInfo( - uri=str(resource.uri), - name=resource.name or str(resource.uri), - description=resource.description or "", - mime_type=resource.mimeType if hasattr(resource, "mimeType") else None, - server_name=self.server_name, - ) - self._resources.append(resource_info) - logger.debug(f"[{self.server_name}] 发现资源: {resource_info.uri}") - - self._supports_resources = True - logger.info(f"[{self.server_name}] 获取到 {len(self._resources)} 个资源") - return True - - except Exception as e: - # 服务器可能不支持 resources,这不是错误 - error_str = str(e).lower() - if "not supported" in error_str or "not implemented" in error_str or "method not found" in error_str: - logger.debug(f"[{self.server_name}] 服务器不支持 Resources 功能") - else: - logger.warning(f"[{self.server_name}] 获取资源列表失败: {e}") - self._supports_resources = False - self._resources = [] - return False - - async def fetch_prompts(self) -> bool: - """v1.2.0: 获取 MCP 服务器的提示模板列表 - - Returns: - bool: 是否成功获取(服务器不支持时返回 False) - """ - if not self._session: - return False - - try: - result = await asyncio.wait_for(self._session.list_prompts(), timeout=self.call_timeout) - self._prompts = [] - - for prompt in result.prompts: - # 解析参数 - arguments = [] - if hasattr(prompt, "arguments") and prompt.arguments: - for arg in prompt.arguments: - arguments.append( - { - "name": arg.name, - "description": arg.description or "", - "required": arg.required if hasattr(arg, "required") else False, - } - ) - - prompt_info = MCPPromptInfo( - name=prompt.name, - description=prompt.description or f"MCP prompt: {prompt.name}", - arguments=arguments, - server_name=self.server_name, - ) - self._prompts.append(prompt_info) - logger.debug(f"[{self.server_name}] 发现提示模板: {prompt.name}") - - self._supports_prompts = True - logger.info(f"[{self.server_name}] 获取到 {len(self._prompts)} 个提示模板") - return True - - except Exception as e: - # 服务器可能不支持 prompts,这不是错误 - error_str = str(e).lower() - if "not supported" in error_str or "not implemented" in error_str or "method not found" in error_str: - logger.debug(f"[{self.server_name}] 服务器不支持 Prompts 功能") - else: - logger.warning(f"[{self.server_name}] 获取提示模板列表失败: {e}") - self._supports_prompts = False - self._prompts = [] - return False - - async def read_resource(self, uri: str) -> MCPCallResult: - """v1.2.0: 读取指定资源的内容 - - Args: - uri: 资源 URI - - Returns: - MCPCallResult: 包含资源内容的结果 - """ - start_time = time.time() - - if not self._connected or not self._session: - return MCPCallResult(success=False, content=None, error=f"服务器 {self.server_name} 未连接") - - if not self._supports_resources: - return MCPCallResult(success=False, content=None, error=f"服务器 {self.server_name} 不支持 Resources 功能") - - try: - result = await asyncio.wait_for(self._session.read_resource(uri), timeout=self.call_timeout) - - duration_ms = (time.time() - start_time) * 1000 - - # 处理返回内容 - content_parts = [] - for content in result.contents: - if hasattr(content, "text"): - content_parts.append(content.text) - elif hasattr(content, "blob"): - # 二进制数据,返回 base64 或提示 - import base64 - - blob_data = content.blob - if len(blob_data) < 10000: # 小于 10KB 返回 base64 - content_parts.append(f"[base64]{base64.b64encode(blob_data).decode()}") - else: - content_parts.append(f"[二进制数据: {len(blob_data)} bytes]") - else: - content_parts.append(str(content)) - - return MCPCallResult( - success=True, content="\n".join(content_parts) if content_parts else "", duration_ms=duration_ms - ) - - except asyncio.TimeoutError: - duration_ms = (time.time() - start_time) * 1000 - return MCPCallResult( - success=False, content=None, error=f"读取资源超时({self.call_timeout}秒)", duration_ms=duration_ms - ) - except Exception as e: - duration_ms = (time.time() - start_time) * 1000 - logger.error(f"[{self.server_name}] 读取资源 {uri} 失败: {e}") - return MCPCallResult(success=False, content=None, error=str(e), duration_ms=duration_ms) - - async def get_prompt(self, name: str, arguments: Optional[Dict[str, str]] = None) -> MCPCallResult: - """v1.2.0: 获取提示模板的内容 - - Args: - name: 提示模板名称 - arguments: 模板参数 - - Returns: - MCPCallResult: 包含提示内容的结果 - """ - start_time = time.time() - - if not self._connected or not self._session: - return MCPCallResult(success=False, content=None, error=f"服务器 {self.server_name} 未连接") - - if not self._supports_prompts: - return MCPCallResult(success=False, content=None, error=f"服务器 {self.server_name} 不支持 Prompts 功能") - - try: - result = await asyncio.wait_for( - self._session.get_prompt(name, arguments=arguments or {}), timeout=self.call_timeout - ) - - duration_ms = (time.time() - start_time) * 1000 - - # 处理返回的消息 - messages = [] - for msg in result.messages: - role = msg.role if hasattr(msg, "role") else "unknown" - content_text = "" - if hasattr(msg, "content"): - if hasattr(msg.content, "text"): - content_text = msg.content.text - elif isinstance(msg.content, str): - content_text = msg.content - else: - content_text = str(msg.content) - messages.append(f"[{role}]: {content_text}") - - return MCPCallResult( - success=True, content="\n\n".join(messages) if messages else "", duration_ms=duration_ms - ) - - except asyncio.TimeoutError: - duration_ms = (time.time() - start_time) * 1000 - return MCPCallResult( - success=False, content=None, error=f"获取提示模板超时({self.call_timeout}秒)", duration_ms=duration_ms - ) - except Exception as e: - duration_ms = (time.time() - start_time) * 1000 - logger.error(f"[{self.server_name}] 获取提示模板 {name} 失败: {e}") - return MCPCallResult(success=False, content=None, error=str(e), duration_ms=duration_ms) - - async def check_health(self) -> bool: - """检查连接健康状态(心跳检测) - - 通过调用 list_tools 来验证连接是否正常 - """ - if not self._connected or not self._session: - return False - - try: - # 使用 list_tools 作为心跳检测 - await asyncio.wait_for(self._session.list_tools(), timeout=10.0) - self.stats.record_heartbeat() - return True - except Exception as e: - logger.warning(f"[{self.server_name}] 心跳检测失败: {e}") - # 标记为断开 - self._connected = False - self.stats.record_disconnect() - return False - - async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> MCPCallResult: - """调用 MCP 工具""" - start_time = time.time() - - # v1.7.0: 断路器检查 - can_execute, reject_reason = self._circuit_breaker.can_execute() - if not can_execute: - return MCPCallResult(success=False, content=None, error=f"⚡ {reject_reason}", circuit_broken=True) - - # 半开状态下增加试探计数 - if self._circuit_breaker.state == CircuitState.HALF_OPEN: - self._circuit_breaker.half_open_calls += 1 - - if not self._connected or not self._session: - error_msg = f"服务器 {self.server_name} 未连接" - # 记录失败 - if tool_name in self._tool_stats: - self._tool_stats[tool_name].record_call(False, 0, error_msg) - self._circuit_breaker.record_failure() - return MCPCallResult(success=False, content=None, error=error_msg) - - try: - result = await asyncio.wait_for( - self._session.call_tool(tool_name, arguments=arguments), timeout=self.call_timeout - ) - - duration_ms = (time.time() - start_time) * 1000 - - # 处理返回内容 - content_parts = [] - for content in result.content: - if hasattr(content, "text"): - content_parts.append(content.text) - elif hasattr(content, "data"): - content_parts.append(f"[二进制数据: {len(content.data)} bytes]") - else: - content_parts.append(str(content)) - - # 记录成功 - if tool_name in self._tool_stats: - self._tool_stats[tool_name].record_call(True, duration_ms) - - # v1.7.0: 断路器记录成功 - self._circuit_breaker.record_success() - - return MCPCallResult( - success=True, - content="\n".join(content_parts) if content_parts else "执行成功(无返回内容)", - duration_ms=duration_ms, - ) - - except asyncio.TimeoutError: - duration_ms = (time.time() - start_time) * 1000 - error_msg = f"工具调用超时({self.call_timeout}秒)" - if tool_name in self._tool_stats: - self._tool_stats[tool_name].record_call(False, duration_ms, error_msg) - # v1.7.0: 断路器记录失败 - self._circuit_breaker.record_failure() - return MCPCallResult(success=False, content=None, error=error_msg, duration_ms=duration_ms) - - except Exception as e: - duration_ms = (time.time() - start_time) * 1000 - error_msg = str(e) - logger.error(f"[{self.server_name}] 调用工具 {tool_name} 失败: {e}") - if tool_name in self._tool_stats: - self._tool_stats[tool_name].record_call(False, duration_ms, error_msg) - # v1.7.0: 断路器记录失败 - self._circuit_breaker.record_failure() - # 检查是否是连接问题 - if "connection" in error_msg.lower() or "closed" in error_msg.lower(): - self._connected = False - self.stats.record_disconnect() - return MCPCallResult(success=False, content=None, error=error_msg, duration_ms=duration_ms) - - async def disconnect(self) -> None: - """断开连接""" - async with self._lock: - if self._connected: - self.stats.record_disconnect() - await self._cleanup() - - async def _cleanup(self) -> None: - """清理资源""" - self._connected = False - self._tools = [] - self._resources = [] # v1.2.0 - self._prompts = [] # v1.2.0 - self._supports_resources = False # v1.2.0 - self._supports_prompts = False # v1.2.0 - - try: - if hasattr(self, "_session_context") and self._session_context: - await self._session_context.__aexit__(None, None, None) - except Exception as e: - logger.debug(f"[{self.server_name}] 关闭会话时出错: {e}") - - try: - if hasattr(self, "_stdio_context") and self._stdio_context: - await self._stdio_context.__aexit__(None, None, None) - except Exception as e: - logger.debug(f"[{self.server_name}] 关闭 stdio 连接时出错: {e}") - - try: - if hasattr(self, "_http_context") and self._http_context: - await self._http_context.__aexit__(None, None, None) - except Exception as e: - logger.debug(f"[{self.server_name}] 关闭 HTTP 连接时出错: {e}") - - try: - if hasattr(self, "_sse_context") and self._sse_context: - await self._sse_context.__aexit__(None, None, None) - except Exception as e: - logger.debug(f"[{self.server_name}] 关闭 SSE 连接时出错: {e}") - - self._session = None - self._session_context = None - self._stdio_context = None - self._http_context = None - self._sse_context = None - self._read_stream = None - self._write_stream = None - - logger.debug(f"[{self.server_name}] 连接已关闭") - - -class MCPClientManager: - """MCP 客户端管理器,管理多个 MCP 服务器连接 - - 功能: - - 管理多个 MCP 服务器连接 - - 心跳检测和自动重连 - - 调用统计 - """ - - _instance: Optional["MCPClientManager"] = None - - def __new__(cls): - if cls._instance is None: - cls._instance = super().__new__(cls) - cls._instance._initialized = False - return cls._instance - - def __init__(self): - if self._initialized: - return - self._initialized = True - self._clients: Dict[str, MCPClientSession] = {} - self._all_tools: Dict[str, Tuple[MCPToolInfo, MCPClientSession]] = {} - self._all_resources: Dict[str, Tuple[MCPResourceInfo, MCPClientSession]] = {} # v1.2.0 - self._all_prompts: Dict[str, Tuple[MCPPromptInfo, MCPClientSession]] = {} # v1.2.0 - self._settings: Dict[str, Any] = {} - self._lock = asyncio.Lock() - - # 心跳检测任务 - self._heartbeat_task: Optional[asyncio.Task] = None - self._heartbeat_running = False - - # 状态变化回调 - self._on_status_change: Optional[callable] = None - - # 全局统计 - self._global_stats = { - "total_tool_calls": 0, - "successful_calls": 0, - "failed_calls": 0, - "start_time": time.time(), - } - - def configure(self, settings: Dict[str, Any]) -> None: - """配置管理器""" - self._settings = settings - - def set_status_change_callback(self, callback: callable) -> None: - """设置状态变化回调函数""" - self._on_status_change = callback - - def _notify_status_change(self) -> None: - """通知状态变化""" - if self._on_status_change: - try: - self._on_status_change() - except Exception as e: - logger.debug(f"状态变化回调出错: {e}") - - @property - def all_tools(self) -> Dict[str, Tuple[MCPToolInfo, MCPClientSession]]: - """获取所有已注册的工具""" - return self._all_tools.copy() - - @property - def all_resources(self) -> Dict[str, Tuple[MCPResourceInfo, MCPClientSession]]: - """v1.2.0: 获取所有已注册的资源""" - return self._all_resources.copy() - - @property - def all_prompts(self) -> Dict[str, Tuple[MCPPromptInfo, MCPClientSession]]: - """v1.2.0: 获取所有已注册的提示模板""" - return self._all_prompts.copy() - - @property - def connected_servers(self) -> List[str]: - """获取已连接的服务器列表""" - return [name for name, client in self._clients.items() if client.is_connected] - - @property - def disconnected_servers(self) -> List[str]: - """获取已断开的服务器列表""" - return [name for name, client in self._clients.items() if not client.is_connected and client.config.enabled] - - async def add_server(self, config: MCPServerConfig) -> bool: - """添加并连接 MCP 服务器""" - async with self._lock: - if config.name in self._clients: - logger.warning(f"服务器 {config.name} 已存在") - return False - - call_timeout = self._settings.get("call_timeout", 60.0) - client = MCPClientSession(config, call_timeout) - self._clients[config.name] = client - - if not config.enabled: - logger.info(f"服务器 {config.name} 已添加但未启用") - return True - - # 尝试连接 - retry_attempts = self._settings.get("retry_attempts", 3) - retry_interval = self._settings.get("retry_interval", 5.0) - - for attempt in range(1, retry_attempts + 1): - if await client.connect(): - self._register_tools(client) - return True - - if attempt < retry_attempts: - logger.warning( - f"服务器 {config.name} 连接失败,{retry_interval}秒后重试 ({attempt}/{retry_attempts})" - ) - await asyncio.sleep(retry_interval) - - logger.error(f"服务器 {config.name} 连接失败,已达最大重试次数 ({retry_attempts})") - # 连接失败,但保留在 _clients 中以便后续重连 - return False - - def _register_tools(self, client: MCPClientSession) -> None: - """注册客户端的工具""" - tool_prefix = self._settings.get("tool_prefix", "mcp") - - for tool in client.tools: - if tool.name.startswith(f"{tool_prefix}_{client.server_name}_"): - tool_key = tool.name - else: - tool_key = f"{tool_prefix}_{client.server_name}_{tool.name}" - self._all_tools[tool_key] = (tool, client) - logger.debug(f"注册 MCP 工具: {tool_key}") - - def _unregister_tools(self, server_name: str) -> List[str]: - """注销服务器的工具,返回被注销的工具键列表""" - tool_prefix = self._settings.get("tool_prefix", "mcp") - prefix = f"{tool_prefix}_{server_name}_" - - keys_to_remove = [k for k in self._all_tools.keys() if k.startswith(prefix)] - for key in keys_to_remove: - del self._all_tools[key] - logger.debug(f"注销 MCP 工具: {key}") - return keys_to_remove - - def _register_resources(self, client: MCPClientSession) -> None: - """v1.2.0: 注册客户端的资源""" - tool_prefix = self._settings.get("tool_prefix", "mcp") - - for resource in client.resources: - # 资源键格式: mcp_{server}_{uri_safe_name} - # 将 URI 转换为安全的键名 - safe_uri = resource.uri.replace("://", "_").replace("/", "_").replace(".", "_") - resource_key = f"{tool_prefix}_{client.server_name}_res_{safe_uri}" - self._all_resources[resource_key] = (resource, client) - logger.debug(f"注册 MCP 资源: {resource_key}") - - def _unregister_resources(self, server_name: str) -> List[str]: - """v1.2.0: 注销服务器的资源""" - tool_prefix = self._settings.get("tool_prefix", "mcp") - prefix = f"{tool_prefix}_{server_name}_res_" - - keys_to_remove = [k for k in self._all_resources.keys() if k.startswith(prefix)] - for key in keys_to_remove: - del self._all_resources[key] - logger.debug(f"注销 MCP 资源: {key}") - return keys_to_remove - - def _register_prompts(self, client: MCPClientSession) -> None: - """v1.2.0: 注册客户端的提示模板""" - tool_prefix = self._settings.get("tool_prefix", "mcp") - - for prompt in client.prompts: - prompt_key = f"{tool_prefix}_{client.server_name}_prompt_{prompt.name}" - self._all_prompts[prompt_key] = (prompt, client) - logger.debug(f"注册 MCP 提示模板: {prompt_key}") - - def _unregister_prompts(self, server_name: str) -> List[str]: - """v1.2.0: 注销服务器的提示模板""" - tool_prefix = self._settings.get("tool_prefix", "mcp") - prefix = f"{tool_prefix}_{server_name}_prompt_" - - keys_to_remove = [k for k in self._all_prompts.keys() if k.startswith(prefix)] - for key in keys_to_remove: - del self._all_prompts[key] - logger.debug(f"注销 MCP 提示模板: {key}") - return keys_to_remove - - async def remove_server(self, server_name: str) -> bool: - """移除 MCP 服务器""" - async with self._lock: - if server_name not in self._clients: - return False - - client = self._clients[server_name] - await client.disconnect() - self._unregister_tools(server_name) - self._unregister_resources(server_name) # v1.2.0 - self._unregister_prompts(server_name) # v1.2.0 - del self._clients[server_name] - - logger.info(f"服务器 {server_name} 已移除") - return True - - async def reconnect_server(self, server_name: str) -> bool: - """重新连接服务器""" - if server_name not in self._clients: - return False - - client = self._clients[server_name] - - async with self._lock: - self._unregister_tools(server_name) - self._unregister_resources(server_name) # v1.2.0 - self._unregister_prompts(server_name) # v1.2.0 - await client.disconnect() - - # 尝试重连 - retry_attempts = self._settings.get("retry_attempts", 3) - retry_interval = self._settings.get("retry_interval", 5.0) - - for attempt in range(1, retry_attempts + 1): - if await client.connect(): - async with self._lock: - self._register_tools(client) - # v1.2.0: 重连后也尝试获取 resources 和 prompts - if self._settings.get("enable_resources", False): - await client.fetch_resources() - self._register_resources(client) - if self._settings.get("enable_prompts", False): - await client.fetch_prompts() - self._register_prompts(client) - client.stats.record_reconnect() - logger.info(f"服务器 {server_name} 重连成功") - return True - - if attempt < retry_attempts: - logger.warning(f"服务器 {server_name} 重连失败,{retry_interval}秒后重试 ({attempt}/{retry_attempts})") - await asyncio.sleep(retry_interval) - - logger.error(f"服务器 {server_name} 重连失败") - return False - - async def call_tool(self, tool_key: str, arguments: Dict[str, Any]) -> MCPCallResult: - """调用 MCP 工具""" - if tool_key not in self._all_tools: - return MCPCallResult(success=False, content=None, error=f"工具 {tool_key} 不存在") - - tool_info, client = self._all_tools[tool_key] - - # 更新全局统计 - self._global_stats["total_tool_calls"] += 1 - - result = await client.call_tool(tool_info.name, arguments) - - if result.success: - self._global_stats["successful_calls"] += 1 - else: - self._global_stats["failed_calls"] += 1 - - return result - - async def fetch_resources_for_server(self, server_name: str) -> bool: - """v1.2.0: 获取指定服务器的资源列表""" - if server_name not in self._clients: - return False - - client = self._clients[server_name] - if not client.is_connected: - return False - - success = await client.fetch_resources() - if success: - async with self._lock: - self._register_resources(client) - return success - - async def fetch_prompts_for_server(self, server_name: str) -> bool: - """v1.2.0: 获取指定服务器的提示模板列表""" - if server_name not in self._clients: - return False - - client = self._clients[server_name] - if not client.is_connected: - return False - - success = await client.fetch_prompts() - if success: - async with self._lock: - self._register_prompts(client) - return success - - async def read_resource(self, uri: str, server_name: Optional[str] = None) -> MCPCallResult: - """v1.2.0: 读取资源内容 - - Args: - uri: 资源 URI - server_name: 指定服务器名称(可选,不指定则自动查找) - """ - # 如果指定了服务器 - if server_name: - if server_name not in self._clients: - return MCPCallResult(success=False, content=None, error=f"服务器 {server_name} 不存在") - client = self._clients[server_name] - return await client.read_resource(uri) - - # 自动查找拥有该资源的服务器 - for _resource_key, (resource_info, client) in self._all_resources.items(): - if resource_info.uri == uri: - return await client.read_resource(uri) - - # 尝试在所有支持 resources 的服务器上查找 - for client in self._clients.values(): - if client.is_connected and client.supports_resources: - result = await client.read_resource(uri) - if result.success: - return result - - return MCPCallResult(success=False, content=None, error=f"未找到资源: {uri}") - - async def get_prompt( - self, name: str, arguments: Optional[Dict[str, str]] = None, server_name: Optional[str] = None - ) -> MCPCallResult: - """v1.2.0: 获取提示模板内容 - - Args: - name: 提示模板名称 - arguments: 模板参数 - server_name: 指定服务器名称(可选) - """ - # 如果指定了服务器 - if server_name: - if server_name not in self._clients: - return MCPCallResult(success=False, content=None, error=f"服务器 {server_name} 不存在") - client = self._clients[server_name] - return await client.get_prompt(name, arguments) - - # 自动查找拥有该提示模板的服务器 - for _prompt_key, (prompt_info, client) in self._all_prompts.items(): - if prompt_info.name == name: - return await client.get_prompt(name, arguments) - - return MCPCallResult(success=False, content=None, error=f"未找到提示模板: {name}") - - # ==================== 心跳检测 ==================== - - async def start_heartbeat(self) -> None: - """启动心跳检测任务""" - if self._heartbeat_running: - logger.warning("心跳检测任务已在运行") - return - - heartbeat_enabled = self._settings.get("heartbeat_enabled", True) - if not heartbeat_enabled: - logger.info("心跳检测已禁用") - return - - self._heartbeat_running = True - self._heartbeat_task = asyncio.create_task(self._heartbeat_loop()) - logger.info("心跳检测任务已启动") - - async def stop_heartbeat(self) -> None: - """停止心跳检测任务""" - self._heartbeat_running = False - if self._heartbeat_task: - self._heartbeat_task.cancel() - try: - await self._heartbeat_task - except asyncio.CancelledError: - pass - self._heartbeat_task = None - logger.info("心跳检测任务已停止") - - async def _heartbeat_loop(self) -> None: - """心跳检测循环(v1.5.2: 智能心跳间隔)""" - base_interval = self._settings.get("heartbeat_interval", 60.0) - auto_reconnect = self._settings.get("auto_reconnect", True) - max_reconnect_attempts = self._settings.get("max_reconnect_attempts", 3) - - # v1.5.2: 智能心跳配置 - adaptive_enabled = self._settings.get("heartbeat_adaptive", True) - max_multiplier = self._settings.get("heartbeat_max_multiplier", 3.0) - - # 每个服务器独立的心跳间隔(根据稳定性动态调整) - server_intervals: Dict[str, float] = {} - min_interval = max(base_interval * 0.5, 30.0) # 最小间隔 - max_interval = base_interval * max_multiplier # 最大间隔 - - mode_str = "智能" if adaptive_enabled else "固定" - logger.info(f"心跳检测循环启动,{mode_str}模式,基准间隔: {base_interval}秒") - - while self._heartbeat_running: - try: - # 使用最小的服务器间隔作为循环间隔 - current_interval = min(server_intervals.values()) if server_intervals else base_interval - current_interval = max(current_interval, min_interval) - - await asyncio.sleep(current_interval) - - if not self._heartbeat_running: - break - - current_time = time.time() - - # 检查所有已启用的服务器 - for server_name, client in list(self._clients.items()): - if not client.config.enabled: - continue - - # 初始化服务器间隔 - if server_name not in server_intervals: - server_intervals[server_name] = base_interval - - # 检查是否到达该服务器的心跳时间 - last_heartbeat = client.stats.last_heartbeat_time or 0 - if current_time - last_heartbeat < server_intervals[server_name] * 0.9: - continue # 还没到心跳时间 - - if client.is_connected: - # 检查健康状态 - healthy = await client.check_health() - if healthy: - # v1.5.2: 智能心跳 - 稳定服务器逐渐增加间隔 - if adaptive_enabled and client.stats.consecutive_failures == 0: - new_interval = min(server_intervals[server_name] * 1.2, max_interval) - if new_interval != server_intervals[server_name]: - server_intervals[server_name] = new_interval - logger.debug(f"[{server_name}] 稳定,心跳间隔调整为 {new_interval:.0f}s") - else: - logger.warning(f"[{server_name}] 心跳检测失败,连接可能已断开") - # 失败后重置为基准间隔 - if adaptive_enabled: - server_intervals[server_name] = base_interval - self._notify_status_change() - if auto_reconnect: - await self._try_reconnect(server_name, max_reconnect_attempts) - else: - # 服务器未连接,尝试重连 - if adaptive_enabled: - # 智能心跳:断开的服务器使用较短间隔 - server_intervals[server_name] = min_interval - if auto_reconnect and client.stats.consecutive_failures < max_reconnect_attempts: - logger.info(f"[{server_name}] 检测到断开,尝试重连...") - await self._try_reconnect(server_name, max_reconnect_attempts) - elif client.stats.consecutive_failures >= max_reconnect_attempts: - if adaptive_enabled: - # 达到最大重连次数,降低检测频率 - server_intervals[server_name] = max_interval - logger.debug(f"[{server_name}] 已达最大重连次数,降低检测频率") - - except asyncio.CancelledError: - break - except Exception as e: - logger.error(f"心跳检测循环出错: {e}") - await asyncio.sleep(5) - - async def _try_reconnect(self, server_name: str, max_attempts: int) -> bool: - """尝试重连服务器""" - client = self._clients.get(server_name) - if not client: - return False - - if client.stats.consecutive_failures >= max_attempts: - logger.warning(f"[{server_name}] 连续失败次数已达上限 ({max_attempts}),暂停重连") - return False - - logger.info(f"[{server_name}] 尝试重连 (失败次数: {client.stats.consecutive_failures}/{max_attempts})") - - success = await self.reconnect_server(server_name) - if not success: - client.stats.record_failure() - - self._notify_status_change() # 重连后更新状态 - return success - - # ==================== 统计和状态 ==================== - - def get_tool_stats(self, tool_key: str) -> Optional[Dict[str, Any]]: - """获取指定工具的统计信息""" - if tool_key not in self._all_tools: - return None - - tool_info, client = self._all_tools[tool_key] - stats = client.get_tool_stats(tool_info.name) - return stats.to_dict() if stats else None - - def get_all_stats(self) -> Dict[str, Any]: - """获取所有统计信息""" - server_stats = {} - tool_stats = {} - - for server_name, client in self._clients.items(): - server_stats[server_name] = client.stats.to_dict() - for tool_name, stats in client.get_all_tool_stats().items(): - full_key = f"{self._settings.get('tool_prefix', 'mcp')}_{server_name}_{tool_name}" - tool_stats[full_key] = stats.to_dict() - - uptime = time.time() - self._global_stats["start_time"] - - return { - "global": { - **self._global_stats, - "uptime_seconds": round(uptime, 2), - "calls_per_minute": round(self._global_stats["total_tool_calls"] / (uptime / 60), 2) - if uptime > 0 - else 0, - }, - "servers": server_stats, - "tools": tool_stats, - } - - async def shutdown(self) -> None: - """关闭所有连接""" - # 停止心跳检测 - await self.stop_heartbeat() - - async with self._lock: - for client in self._clients.values(): - await client.disconnect() - self._clients.clear() - self._all_tools.clear() - self._all_resources.clear() # v1.2.0 - self._all_prompts.clear() # v1.2.0 - logger.info("MCP 客户端管理器已关闭") - - def get_status(self) -> Dict[str, Any]: - """获取状态信息""" - return { - "total_servers": len(self._clients), - "connected_servers": len(self.connected_servers), - "disconnected_servers": len(self.disconnected_servers), - "total_tools": len(self._all_tools), - "total_resources": len(self._all_resources), # v1.2.0 - "total_prompts": len(self._all_prompts), # v1.2.0 - "heartbeat_running": self._heartbeat_running, - "servers": { - name: { - "connected": client.is_connected, - "enabled": client.config.enabled, - "tools_count": len(client.tools), - "resources_count": len(client.resources), # v1.2.0 - "prompts_count": len(client.prompts), # v1.2.0 - "supports_resources": client.supports_resources, # v1.2.0 - "supports_prompts": client.supports_prompts, # v1.2.0 - "transport": client.config.transport.value, - "consecutive_failures": client.stats.consecutive_failures, - "circuit_breaker": client.get_circuit_breaker_status(), # v1.7.0 - } - for name, client in self._clients.items() - }, - "global_stats": self._global_stats, - } - - -# 全局单例 -mcp_manager = MCPClientManager() diff --git a/plugins/MaiBot_MCPBridgePlugin/plugin.py b/plugins/MaiBot_MCPBridgePlugin/plugin.py deleted file mode 100644 index 1d965e25..00000000 --- a/plugins/MaiBot_MCPBridgePlugin/plugin.py +++ /dev/null @@ -1,3733 +0,0 @@ -""" -MCP 桥接插件 v2.0.0 -将 MCP (Model Context Protocol) 服务器的工具桥接到 MaiBot - -v2.0.0 配置与架构精简(功能保持不变): -- MCP 服务器配置统一为 Claude Desktop 的 mcpServers JSON(WebUI / config.toml 同一入口) -- 兼容迁移:检测到旧版 servers.list 时自动迁移为 mcpServers(仅迁移,避免多入口混淆) -- 移除 WebUI 导入导出/快速添加服务器的旧实现(避免 tomlkit 依赖与格式混乱) - -v1.9.0 双轨制架构: -- 软流程 (ReAct): LLM 自主决策,动态多轮调用 MCP 工具,灵活应对复杂场景 -- 硬流程 (Workflow): 用户预定义的工作流,固定执行顺序,可靠可控 -- 工具链重命名为 Workflow,更清晰地表达其"预定义流程"的本质 -- 命令更新:/mcp workflow 替代 /mcp chain - -v1.8.1 工具链易用性优化: -- 快速添加工具链:WebUI 表单式配置,无需手写 JSON -- 工具链模板:提供常用工具链配置模板参考 -- 使用指南:内置变量语法和命令说明 -- 状态显示优化:详细展示工具链步骤和参数信息 - -v1.8.0 工具链支持: -- 工具链:将多个工具按顺序执行,后续工具可使用前序工具的输出 -- 自定义工具链:在 WebUI 配置工具链,自动注册为组合工具供 LLM 调用 -- 变量替换:支持 ${input.参数}、${step.输出键}、${prev} 变量 -- 工具链命令:/mcp chain 查看、测试、管理工具链 - -v1.7.0 稳定性与易用性优化: -- 断路器模式:故障服务器快速失败,避免拖慢整体响应 -- 状态实时刷新:WebUI 每 10 秒自动更新连接状态 -- 断路器状态显示:在状态面板显示熔断/试探状态 - -v1.6.0 配置导入导出: -- 新增 /mcp import 命令,支持从 Claude Desktop 格式导入配置 -- 新增 /mcp export 命令,导出为 Claude Desktop (mcpServers) 格式 -- 支持 stdio、sse、http、streamable_http 全部传输类型 -- 自动跳过同名服务器,防止重复导入 - -v1.5.4 易用性优化: -- 新增 MCP 服务器获取快捷入口(魔搭、Smithery、Glama 等) -- 优化快速入门指南,提供配置示例 -- 帮助新用户快速上手 MCP - -v1.5.3 配置优化: -- 新增智能心跳 WebUI 配置项:启用开关、最大间隔倍数 -- 支持在 WebUI 中开启/关闭智能心跳功能 - -v1.5.2 性能优化: -- 智能心跳间隔:根据服务器稳定性动态调整心跳频率 -- 稳定服务器逐渐增加间隔,减少不必要的网络请求 -- 断开的服务器使用较短间隔快速重连 - -v1.5.1 易用性优化(v2.0.0 起已移除): -- 「快速添加服务器」表单式配置(已统一为 Claude mcpServers JSON,避免多入口混淆) - -v1.5.0 性能优化: -- 服务器并行连接:多个服务器同时连接,大幅减少启动时间 -- 连接耗时统计:日志显示并行连接总耗时 - -v1.4.4 修复: -- 修复首次生成默认配置文件时多行字符串导致 TOML 解析失败的问题 -- 简化 config_schema 默认值,避免主程序 json.dumps 产生无效 TOML - -v1.4.3 修复: -- 修复 WebUI 保存配置后多行字符串格式错误导致配置文件无法读取的问题 -- 清理未使用的导入 - -v1.4.0 新增功能: -- 工具禁用管理 -- 调用链路追踪 -- 工具调用缓存 -- 工具权限控制 -""" - -import asyncio -import fnmatch -import hashlib -import json -import re -import time -import uuid -from collections import OrderedDict, deque -from dataclasses import asdict, dataclass, field -from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple, Type - -from src.common.logger import get_logger -from src.plugin_system import ( - BasePlugin, - register_plugin, - BaseTool, - BaseCommand, - ComponentInfo, - ConfigField, - ToolParamType, -) -from src.plugin_system.base.config_types import section_meta -from src.plugin_system.base.component_types import ToolInfo, ComponentType, EventType -from src.plugin_system.base.base_events_handler import BaseEventHandler - -from .mcp_client import ( - MCPServerConfig, - MCPToolInfo, - MCPResourceInfo, - MCPPromptInfo, - TransportType, - mcp_manager, -) -from .core.claude_config import ( - ClaudeConfigError, - legacy_servers_list_to_claude_config, - parse_claude_mcp_config, -) -from .tool_chain import ( - ToolChainDefinition, - tool_chain_manager, -) - -logger = get_logger("mcp_bridge_plugin") - - -# ============================================================================ -# v1.4.0: 调用链路追踪 -# ============================================================================ - - -@dataclass -class ToolCallRecord: - """工具调用记录""" - - call_id: str - timestamp: float - tool_name: str - server_name: str - chat_id: str = "" - user_id: str = "" - user_query: str = "" - arguments: Dict = field(default_factory=dict) - raw_result: str = "" - processed_result: str = "" - duration_ms: float = 0.0 - success: bool = True - error: str = "" - post_processed: bool = False - cache_hit: bool = False - - -class ToolCallTracer: - """工具调用追踪器""" - - def __init__(self, max_records: int = 100): - self._records: deque[ToolCallRecord] = deque(maxlen=max_records) - self._enabled: bool = True - self._log_enabled: bool = False - self._log_path: Optional[Path] = None - - def configure(self, enabled: bool, max_records: int, log_enabled: bool, log_path: Optional[Path] = None) -> None: - """配置追踪器""" - self._enabled = enabled - self._records = deque(self._records, maxlen=max_records) - self._log_enabled = log_enabled - self._log_path = log_path - - def record(self, record: ToolCallRecord) -> None: - """添加调用记录""" - if not self._enabled: - return - - self._records.append(record) - - if self._log_enabled and self._log_path: - self._write_to_log(record) - - def get_recent(self, n: int = 10) -> List[ToolCallRecord]: - """获取最近 N 条记录""" - return list(self._records)[-n:] - - def get_by_tool(self, tool_name: str) -> List[ToolCallRecord]: - """按工具名筛选记录""" - return [r for r in self._records if r.tool_name == tool_name] - - def get_by_server(self, server_name: str) -> List[ToolCallRecord]: - """按服务器名筛选记录""" - return [r for r in self._records if r.server_name == server_name] - - def clear(self) -> None: - """清空记录""" - self._records.clear() - - def _write_to_log(self, record: ToolCallRecord) -> None: - """写入 JSONL 日志文件""" - try: - if self._log_path: - self._log_path.parent.mkdir(parents=True, exist_ok=True) - with open(self._log_path, "a", encoding="utf-8") as f: - f.write(json.dumps(asdict(record), ensure_ascii=False) + "\n") - except Exception as e: - logger.warning(f"写入追踪日志失败: {e}") - - @property - def total_records(self) -> int: - return len(self._records) - - -# 全局追踪器实例 -tool_call_tracer = ToolCallTracer() - - -# ============================================================================ -# v1.4.0: 工具调用缓存 -# ============================================================================ - - -@dataclass -class CacheEntry: - """缓存条目""" - - tool_name: str - args_hash: str - result: str - created_at: float - expires_at: float - hit_count: int = 0 - - -class ToolCallCache: - """工具调用缓存(LRU)""" - - def __init__(self, max_entries: int = 200, ttl: int = 300): - self._cache: OrderedDict[str, CacheEntry] = OrderedDict() - self._max_entries = max_entries - self._ttl = ttl - self._enabled = False - self._exclude_patterns: List[str] = [] - self._stats = {"hits": 0, "misses": 0} - - def configure(self, enabled: bool, ttl: int, max_entries: int, exclude_tools: str) -> None: - """配置缓存""" - self._enabled = enabled - self._ttl = ttl - self._max_entries = max_entries - self._exclude_patterns = [p.strip() for p in exclude_tools.strip().split("\n") if p.strip()] - - def get(self, tool_name: str, args: Dict) -> Optional[str]: - """获取缓存""" - if not self._enabled: - return None - - if self._is_excluded(tool_name): - return None - - key = self._generate_key(tool_name, args) - - if key not in self._cache: - self._stats["misses"] += 1 - return None - - entry = self._cache[key] - - # 检查是否过期 - if time.time() > entry.expires_at: - del self._cache[key] - self._stats["misses"] += 1 - return None - - # LRU: 移到末尾 - self._cache.move_to_end(key) - entry.hit_count += 1 - self._stats["hits"] += 1 - - return entry.result - - def set(self, tool_name: str, args: Dict, result: str) -> None: - """设置缓存""" - if not self._enabled: - return - - if self._is_excluded(tool_name): - return - - key = self._generate_key(tool_name, args) - now = time.time() - - entry = CacheEntry( - tool_name=tool_name, - args_hash=key, - result=result, - created_at=now, - expires_at=now + self._ttl, - ) - - # 如果已存在,更新 - if key in self._cache: - self._cache[key] = entry - self._cache.move_to_end(key) - else: - # 检查容量 - self._evict_if_needed() - self._cache[key] = entry - - def clear(self) -> None: - """清空缓存""" - self._cache.clear() - self._stats = {"hits": 0, "misses": 0} - - def _generate_key(self, tool_name: str, args: Dict) -> str: - """生成缓存键""" - args_str = json.dumps(args, sort_keys=True, ensure_ascii=False) - content = f"{tool_name}:{args_str}" - return hashlib.md5(content.encode()).hexdigest() - - def _is_excluded(self, tool_name: str) -> bool: - """检查是否在排除列表中""" - for pattern in self._exclude_patterns: - if fnmatch.fnmatch(tool_name, pattern): - return True - return False - - def _evict_if_needed(self) -> None: - """必要时淘汰条目""" - # 先清理过期的 - now = time.time() - expired_keys = [k for k, v in self._cache.items() if now > v.expires_at] - for k in expired_keys: - del self._cache[k] - - # LRU 淘汰 - while len(self._cache) >= self._max_entries: - self._cache.popitem(last=False) - - def get_stats(self) -> Dict[str, Any]: - """获取缓存统计""" - total = self._stats["hits"] + self._stats["misses"] - hit_rate = (self._stats["hits"] / total * 100) if total > 0 else 0 - return { - "enabled": self._enabled, - "entries": len(self._cache), - "max_entries": self._max_entries, - "ttl": self._ttl, - "hits": self._stats["hits"], - "misses": self._stats["misses"], - "hit_rate": f"{hit_rate:.1f}%", - } - - -# 全局缓存实例 -tool_call_cache = ToolCallCache() - - -# ============================================================================ -# v1.4.0: 工具权限控制 -# ============================================================================ - - -class PermissionChecker: - """工具权限检查器""" - - def __init__(self): - self._enabled = False - self._default_mode = "allow_all" # allow_all 或 deny_all - self._rules: List[Dict] = [] - self._quick_deny_groups: set = set() - self._quick_allow_users: set = set() - - def configure( - self, - enabled: bool, - default_mode: str, - rules_json: str, - quick_deny_groups: str = "", - quick_allow_users: str = "", - ) -> None: - """配置权限检查器""" - self._enabled = enabled - self._default_mode = default_mode if default_mode in ("allow_all", "deny_all") else "allow_all" - - # 解析快捷配置 - self._quick_deny_groups = {g.strip() for g in quick_deny_groups.strip().split("\n") if g.strip()} - self._quick_allow_users = {u.strip() for u in quick_allow_users.strip().split("\n") if u.strip()} - - try: - self._rules = json.loads(rules_json) if rules_json.strip() else [] - except json.JSONDecodeError as e: - logger.warning(f"权限规则 JSON 解析失败: {e}") - self._rules = [] - - def check(self, tool_name: str, chat_id: str, user_id: str, is_group: bool) -> bool: - """检查权限 - - Args: - tool_name: 工具名称 - chat_id: 聊天 ID(群号或私聊 ID) - user_id: 用户 ID - is_group: 是否为群聊 - - Returns: - True 表示允许,False 表示拒绝 - """ - if not self._enabled: - return True - - # 快捷配置优先级最高 - # 1. 管理员白名单(始终允许) - if user_id and user_id in self._quick_allow_users: - return True - - # 2. 禁用群列表(始终拒绝) - if is_group and chat_id and chat_id in self._quick_deny_groups: - return False - - # 查找匹配的规则 - for rule in self._rules: - tool_pattern = rule.get("tool", "") - if not self._match_tool(tool_pattern, tool_name): - continue - - # 找到匹配的规则 - mode = rule.get("mode", "") - allowed = rule.get("allowed", []) - denied = rule.get("denied", []) - - # 构建当前上下文的 ID 列表 - context_ids = self._build_context_ids(chat_id, user_id, is_group) - - # 检查 denied 列表(优先级最高) - if denied: - for ctx_id in context_ids: - if self._match_id_list(denied, ctx_id): - return False - - # 检查 allowed 列表 - if allowed: - for ctx_id in context_ids: - if self._match_id_list(allowed, ctx_id): - return True - # 如果是 whitelist 模式且不在 allowed 中,拒绝 - if mode == "whitelist": - return False - - # 规则匹配但没有明确允许/拒绝,继续检查下一条规则 - - # 没有匹配的规则,使用默认模式 - return self._default_mode == "allow_all" - - def _match_tool(self, pattern: str, tool_name: str) -> bool: - """工具名通配符匹配""" - if not pattern: - return False - return fnmatch.fnmatch(tool_name, pattern) - - def _build_context_ids(self, chat_id: str, user_id: str, is_group: bool) -> List[str]: - """构建上下文 ID 列表""" - ids = [] - - # 用户级别(任何场景生效) - if user_id: - ids.append(f"qq:{user_id}:user") - - # 场景级别 - if is_group and chat_id: - ids.append(f"qq:{chat_id}:group") - elif chat_id: - ids.append(f"qq:{chat_id}:private") - - return ids - - def _match_id_list(self, id_list: List[str], context_id: str) -> bool: - """检查 ID 是否在列表中""" - for rule_id in id_list: - if fnmatch.fnmatch(context_id, rule_id): - return True - return False - - def get_rules_for_tool(self, tool_name: str) -> List[Dict]: - """获取特定工具的权限规则""" - return [r for r in self._rules if self._match_tool(r.get("tool", ""), tool_name)] - - -# 全局权限检查器实例 -permission_checker = PermissionChecker() - - -# ============================================================================ -# 工具类型转换 -# ============================================================================ - - -def convert_json_type_to_tool_param_type(json_type: str) -> ToolParamType: - """将 JSON Schema 类型转换为 MaiBot 的 ToolParamType""" - type_mapping = { - "string": ToolParamType.STRING, - "integer": ToolParamType.INTEGER, - "number": ToolParamType.FLOAT, - "boolean": ToolParamType.BOOLEAN, - "array": ToolParamType.STRING, - "object": ToolParamType.STRING, - } - return type_mapping.get(json_type, ToolParamType.STRING) - - -def parse_mcp_parameters( - input_schema: Dict[str, Any], -) -> List[Tuple[str, ToolParamType, str, bool, Optional[List[str]]]]: - """解析 MCP 工具的参数 schema,转换为 MaiBot 的参数格式""" - parameters = [] - - if not input_schema: - # 为无参数的工具添加占位参数,避免某些模型报错 - parameters.append(("_placeholder", ToolParamType.STRING, "占位参数,无需填写", False, None)) - return parameters - - properties = input_schema.get("properties", {}) - required = input_schema.get("required", []) - - # 如果没有任何参数,添加占位参数 - if not properties: - parameters.append(("_placeholder", ToolParamType.STRING, "占位参数,无需填写", False, None)) - return parameters - - for param_name, param_info in properties.items(): - json_type = param_info.get("type", "string") - param_type = convert_json_type_to_tool_param_type(json_type) - description = param_info.get("description", f"参数 {param_name}") - - if json_type == "array": - description = f"{description} (JSON 数组格式)" - elif json_type == "object": - description = f"{description} (JSON 对象格式)" - - is_required = param_name in required - enum_values = param_info.get("enum") - - if enum_values is not None: - enum_values = [str(v) for v in enum_values] - - parameters.append((param_name, param_type, description, is_required, enum_values)) - - return parameters - - -# ============================================================================ -# MCP 工具代理 -# ============================================================================ - - -class MCPToolProxy(BaseTool): - """MCP 工具代理基类""" - - name: str = "" - description: str = "" - parameters: List[Tuple[str, ToolParamType, str, bool, Optional[List[str]]]] = [] - available_for_llm: bool = True - - _mcp_tool_key: str = "" - _mcp_original_name: str = "" - _mcp_server_name: str = "" - - async def execute(self, function_args: Dict[str, Any]) -> Dict[str, Any]: - """执行 MCP 工具调用""" - global _plugin_instance - - call_id = str(uuid.uuid4())[:8] - start_time = time.time() - - # 移除 MaiBot 内部标记 - args = {k: v for k, v in function_args.items() if k != "llm_called"} - - # 解析 JSON 字符串参数 - parsed_args = {} - for key, value in args.items(): - if isinstance(value, str): - try: - if value.startswith(("[", "{")): - parsed_args[key] = json.loads(value) - else: - parsed_args[key] = value - except json.JSONDecodeError: - parsed_args[key] = value - else: - parsed_args[key] = value - - # 获取上下文信息 - chat_id, user_id, is_group, user_query = self._get_context_info() - - # v1.4.0: 权限检查 - if not permission_checker.check(self.name, chat_id, user_id, is_group): - logger.warning(f"权限拒绝: 工具 {self.name}, chat={chat_id}, user={user_id}") - return {"name": self.name, "content": f"⛔ 权限不足:工具 {self.name} 在当前场景下不可用"} - - logger.debug(f"调用 MCP 工具: {self._mcp_tool_key}, 参数: {parsed_args}") - - # v1.4.0: 检查缓存 - cache_hit = False - cached_result = tool_call_cache.get(self.name, parsed_args) - - if cached_result is not None: - cache_hit = True - content = cached_result - raw_result = cached_result - success = True - error = "" - logger.debug(f"MCP 工具 {self.name} 命中缓存") - else: - # 调用 MCP - result = await mcp_manager.call_tool(self._mcp_tool_key, parsed_args) - - if result.success: - content = result.content - raw_result = content - success = True - error = "" - - # 存入缓存 - tool_call_cache.set(self.name, parsed_args, content) - else: - content = self._format_error_message(result.error, result.duration_ms) - raw_result = result.error - success = False - error = result.error - logger.warning(f"MCP 工具 {self.name} 调用失败: {result.error}") - - # v1.3.0: 后处理 - post_processed = False - processed_result = content - if success: - processed_content = await self._post_process_result(content) - if processed_content != content: - post_processed = True - processed_result = processed_content - content = processed_content - - duration_ms = (time.time() - start_time) * 1000 - - # v1.4.0: 记录调用追踪 - record = ToolCallRecord( - call_id=call_id, - timestamp=start_time, - tool_name=self.name, - server_name=self._mcp_server_name, - chat_id=chat_id, - user_id=user_id, - user_query=user_query, - arguments=parsed_args, - raw_result=raw_result[:1000] if raw_result else "", - processed_result=processed_result[:1000] if processed_result else "", - duration_ms=duration_ms, - success=success, - error=error, - post_processed=post_processed, - cache_hit=cache_hit, - ) - tool_call_tracer.record(record) - - return {"name": self.name, "content": content} - - def _get_context_info(self) -> Tuple[str, str, bool, str]: - """获取上下文信息""" - chat_id = "" - user_id = "" - is_group = False - user_query = "" - - if self.chat_stream and hasattr(self.chat_stream, "context") and self.chat_stream.context: - try: - ctx = self.chat_stream.context - if hasattr(ctx, "chat_id"): - chat_id = str(ctx.chat_id) if ctx.chat_id else "" - if hasattr(ctx, "user_id"): - user_id = str(ctx.user_id) if ctx.user_id else "" - if hasattr(ctx, "is_group"): - is_group = bool(ctx.is_group) - - last_message = ctx.get_last_message() - if last_message and hasattr(last_message, "processed_plain_text"): - user_query = last_message.processed_plain_text or "" - except Exception as e: - logger.debug(f"获取上下文信息失败: {e}") - - return chat_id, user_id, is_group, user_query - - async def _post_process_result(self, content: str) -> str: - """v1.3.0: 对工具返回结果进行后处理(摘要提炼)""" - global _plugin_instance - - if _plugin_instance is None: - return content - - settings = _plugin_instance.config.get("settings", {}) - - if not settings.get("post_process_enabled", False): - return content - - server_post_config = self._get_server_post_process_config() - - if server_post_config is not None: - if not server_post_config.get("enabled", True): - return content - - threshold = settings.get("post_process_threshold", 500) - if server_post_config and "threshold" in server_post_config: - threshold = server_post_config["threshold"] - - content_length = len(content) if content else 0 - if content_length <= threshold: - return content - - user_query = self._get_context_info()[3] - if not user_query: - return content - - max_tokens = settings.get("post_process_max_tokens", 500) - if server_post_config and "max_tokens" in server_post_config: - max_tokens = server_post_config["max_tokens"] - - prompt_template = settings.get("post_process_prompt", "") - if server_post_config and "prompt" in server_post_config: - prompt_template = server_post_config["prompt"] - - if not prompt_template: - prompt_template = """用户问题:{query} - -工具返回内容: -{result} - -请从上述内容中提取与用户问题最相关的关键信息,简洁准确地输出:""" - - try: - prompt = prompt_template.format(query=user_query, result=content) - except KeyError as e: - logger.warning(f"后处理 prompt 模板格式错误: {e}") - return content - - try: - processed_content = await self._call_post_process_llm(prompt, max_tokens, settings, server_post_config) - if processed_content: - logger.info(f"MCP 工具 {self.name} 后处理完成: {content_length} -> {len(processed_content)} 字符") - return processed_content - return content - except Exception as e: - logger.error(f"MCP 工具 {self.name} 后处理失败: {e}") - return content - - def _get_server_post_process_config(self) -> Optional[Dict[str, Any]]: - """获取当前服务器的后处理配置""" - global _plugin_instance - - if _plugin_instance is None: - return None - - servers = _plugin_instance._load_mcp_servers_config() - for server_conf in servers: - if server_conf.get("name") == self._mcp_server_name: - return server_conf.get("post_process") - - return None - - async def _call_post_process_llm( - self, prompt: str, max_tokens: int, settings: Dict[str, Any], server_config: Optional[Dict[str, Any]] - ) -> Optional[str]: - """调用 LLM 进行后处理""" - from src.config.config import model_config - from src.config.model_configs import TaskConfig - from src.llm_models.utils_model import LLMRequest - - model_name = settings.get("post_process_model", "") - if server_config and "model" in server_config: - model_name = server_config["model"] - - if model_name: - task_config = TaskConfig( - model_list=[model_name], - max_tokens=max_tokens, - temperature=0.3, - slow_threshold=30.0, - ) - else: - task_config = model_config.model_task_config.utils - - llm_request = LLMRequest(model_set=task_config, request_type="mcp_post_process") - - response, (reasoning, model_used, _) = await llm_request.generate_response_async( - prompt=prompt, - max_tokens=max_tokens, - temperature=0.3, - ) - - return response.strip() if response else None - - def _format_error_message(self, error: str, duration_ms: float) -> str: - """格式化友好的错误消息""" - if not error: - return "工具调用失败(未知错误)" - - error_lower = error.lower() - - if "未连接" in error or "not connected" in error_lower: - return f"⚠️ MCP 服务器 [{self._mcp_server_name}] 未连接,请检查服务器状态或等待自动重连" - - if "超时" in error or "timeout" in error_lower: - return f"⏱️ 工具调用超时(耗时 {duration_ms:.0f}ms),服务器响应过慢,请稍后重试" - - if "connection" in error_lower and ("closed" in error_lower or "reset" in error_lower): - return f"🔌 与 MCP 服务器 [{self._mcp_server_name}] 的连接已断开,正在尝试重连..." - - if "invalid" in error_lower and "argument" in error_lower: - return f"❌ 参数错误: {error}" - - return f"❌ 工具调用失败: {error}" - - async def direct_execute(self, **function_args) -> Dict[str, Any]: - """直接执行(供其他插件调用)""" - return await self.execute(function_args) - - -def create_mcp_tool_class( - tool_key: str, tool_info: MCPToolInfo, tool_prefix: str, disabled: bool = False -) -> Type[MCPToolProxy]: - """根据 MCP 工具信息动态创建 BaseTool 子类""" - parameters = parse_mcp_parameters(tool_info.input_schema) - - class_name = f"MCPTool_{tool_info.server_name}_{tool_info.name}".replace("-", "_").replace(".", "_") - tool_name = tool_key.replace("-", "_").replace(".", "_") - - description = tool_info.description - if not description.endswith(f"[来自 MCP 服务器: {tool_info.server_name}]"): - description = f"{description} [来自 MCP 服务器: {tool_info.server_name}]" - - tool_class = type( - class_name, - (MCPToolProxy,), - { - "name": tool_name, - "description": description, - "parameters": parameters, - "available_for_llm": not disabled, # v1.4.0: 禁用的工具不可被 LLM 调用 - "_mcp_tool_key": tool_key, - "_mcp_original_name": tool_info.name, - "_mcp_server_name": tool_info.server_name, - }, - ) - - return tool_class - - -class MCPToolRegistry: - """MCP 工具注册表""" - - def __init__(self): - self._tool_classes: Dict[str, Type[MCPToolProxy]] = {} - self._tool_infos: Dict[str, ToolInfo] = {} - - def register_tool( - self, tool_key: str, tool_info: MCPToolInfo, tool_prefix: str, disabled: bool = False - ) -> Tuple[ToolInfo, Type[MCPToolProxy]]: - """注册 MCP 工具""" - tool_class = create_mcp_tool_class(tool_key, tool_info, tool_prefix, disabled) - - self._tool_classes[tool_key] = tool_class - - info = ToolInfo( - name=tool_class.name, - tool_description=tool_class.description, - enabled=True, - tool_parameters=tool_class.parameters, - component_type=ComponentType.TOOL, - ) - self._tool_infos[tool_key] = info - - return info, tool_class - - def unregister_tool(self, tool_key: str) -> bool: - """注销工具""" - if tool_key in self._tool_classes: - del self._tool_classes[tool_key] - del self._tool_infos[tool_key] - return True - return False - - def get_all_components(self) -> List[Tuple[ComponentInfo, Type]]: - """获取所有工具组件""" - return [(self._tool_infos[key], self._tool_classes[key]) for key in self._tool_classes.keys()] - - def clear(self) -> None: - """清空所有注册""" - self._tool_classes.clear() - self._tool_infos.clear() - - -# 全局工具注册表 -mcp_tool_registry = MCPToolRegistry() - -# 全局插件实例引用 -_plugin_instance: Optional["MCPBridgePlugin"] = None - - -# ============================================================================ -# 内置工具 -# ============================================================================ - - -class MCPReadResourceTool(BaseTool): - """v1.2.0: MCP 资源读取工具""" - - name = "mcp_read_resource" - description = "读取 MCP 服务器提供的资源内容(如文件、数据库记录等)。使用前请先用 mcp_status 查看可用资源。" - parameters = [ - ("uri", ToolParamType.STRING, "资源 URI(如 file:///path/to/file 或自定义 URI)", True, None), - ("server_name", ToolParamType.STRING, "指定服务器名称(可选,不指定则自动查找)", False, None), - ] - available_for_llm = True - - async def execute(self, function_args: Dict[str, Any]) -> Dict[str, Any]: - uri = function_args.get("uri", "") - server_name = function_args.get("server_name") - - if not uri: - return {"name": self.name, "content": "❌ 请提供资源 URI"} - - result = await mcp_manager.read_resource(uri, server_name) - - if result.success: - return {"name": self.name, "content": result.content} - else: - return {"name": self.name, "content": f"❌ 读取资源失败: {result.error}"} - - async def direct_execute(self, **function_args) -> Dict[str, Any]: - return await self.execute(function_args) - - -class MCPGetPromptTool(BaseTool): - """v1.2.0: MCP 提示模板工具""" - - name = "mcp_get_prompt" - description = "获取 MCP 服务器提供的提示模板内容。使用前请先用 mcp_status 查看可用模板。" - parameters = [ - ("name", ToolParamType.STRING, "提示模板名称", True, None), - ("arguments", ToolParamType.STRING, "模板参数(JSON 对象格式)", False, None), - ("server_name", ToolParamType.STRING, "指定服务器名称(可选)", False, None), - ] - available_for_llm = True - - async def execute(self, function_args: Dict[str, Any]) -> Dict[str, Any]: - prompt_name = function_args.get("name", "") - arguments_str = function_args.get("arguments", "") - server_name = function_args.get("server_name") - - if not prompt_name: - return {"name": self.name, "content": "❌ 请提供提示模板名称"} - - arguments = None - if arguments_str: - try: - arguments = json.loads(arguments_str) - except json.JSONDecodeError: - return {"name": self.name, "content": "❌ 参数格式错误,请使用 JSON 对象格式"} - - result = await mcp_manager.get_prompt(prompt_name, arguments, server_name) - - if result.success: - return {"name": self.name, "content": result.content} - else: - return {"name": self.name, "content": f"❌ 获取提示模板失败: {result.error}"} - - async def direct_execute(self, **function_args) -> Dict[str, Any]: - return await self.execute(function_args) - - -# ============================================================================ -# v1.8.0: 工具链代理工具 -# ============================================================================ - - -class ToolChainProxyBase(BaseTool): - """工具链代理基类""" - - name: str = "" - description: str = "" - parameters: List[Tuple[str, ToolParamType, str, bool, Optional[List[str]]]] = [] - available_for_llm: bool = True - - _chain_name: str = "" - - async def execute(self, function_args: Dict[str, Any]) -> Dict[str, Any]: - """执行工具链""" - # 移除内部标记 - args = {k: v for k, v in function_args.items() if k != "llm_called"} - - logger.debug(f"执行工具链 {self._chain_name},参数: {args}") - - result = await tool_chain_manager.execute_chain(self._chain_name, args) - - if result.success: - # 构建输出 - output_parts = [] - output_parts.append(result.final_output) - - # 可选:添加执行摘要 - # output_parts.append(f"\n\n---\n执行摘要:\n{result.to_summary()}") - - return {"name": self.name, "content": "\n".join(output_parts)} - else: - error_msg = f"⚠️ 工具链执行失败: {result.error}" - if result.step_results: - error_msg += f"\n\n执行详情:\n{result.to_summary()}" - return {"name": self.name, "content": error_msg} - - async def direct_execute(self, **function_args) -> Dict[str, Any]: - return await self.execute(function_args) - - -def create_chain_tool_class(chain: ToolChainDefinition) -> Type[ToolChainProxyBase]: - """根据工具链定义动态创建工具类""" - # 构建参数列表 - parameters = [] - for param_name, param_desc in chain.input_params.items(): - parameters.append((param_name, ToolParamType.STRING, param_desc, True, None)) - - # 生成类名和工具名 - class_name = f"ToolChain_{chain.name}".replace("-", "_").replace(".", "_") - tool_name = f"chain_{chain.name}".replace("-", "_").replace(".", "_") - - # 构建描述 - description = chain.description - if chain.steps: - step_names = [s.tool_name.split("_")[-1] for s in chain.steps[:3]] - description += f" (执行流程: {' → '.join(step_names)}{'...' if len(chain.steps) > 3 else ''})" - - tool_class = type( - class_name, - (ToolChainProxyBase,), - { - "name": tool_name, - "description": description, - "parameters": parameters, - "available_for_llm": True, - "_chain_name": chain.name, - }, - ) - - return tool_class - - -class ToolChainRegistry: - """工具链注册表""" - - def __init__(self): - self._tool_classes: Dict[str, Type[ToolChainProxyBase]] = {} - self._tool_infos: Dict[str, ToolInfo] = {} - - def register_chain(self, chain: ToolChainDefinition) -> Tuple[ToolInfo, Type[ToolChainProxyBase]]: - """注册工具链为组合工具""" - tool_class = create_chain_tool_class(chain) - - self._tool_classes[chain.name] = tool_class - - info = ToolInfo( - name=tool_class.name, - tool_description=tool_class.description, - enabled=True, - tool_parameters=tool_class.parameters, - component_type=ComponentType.TOOL, - ) - self._tool_infos[chain.name] = info - - return info, tool_class - - def unregister_chain(self, chain_name: str) -> bool: - """注销工具链""" - if chain_name in self._tool_classes: - del self._tool_classes[chain_name] - del self._tool_infos[chain_name] - return True - return False - - def get_all_components(self) -> List[Tuple[ComponentInfo, Type]]: - """获取所有工具链组件""" - return [(self._tool_infos[key], self._tool_classes[key]) for key in self._tool_classes.keys()] - - def clear(self) -> None: - """清空所有注册""" - self._tool_classes.clear() - self._tool_infos.clear() - - -# 全局工具链注册表 -tool_chain_registry = ToolChainRegistry() - - -class MCPStatusTool(BaseTool): - """MCP 状态查询工具""" - - name = "mcp_status" - description = "查询 MCP 桥接插件的状态,包括服务器连接状态、可用工具列表、工具链列表、资源列表、提示模板列表、调用统计、追踪记录等信息" - parameters = [ - ( - "query_type", - ToolParamType.STRING, - "查询类型", - False, - ["status", "tools", "chains", "resources", "prompts", "stats", "trace", "cache", "all"], - ), - ("server_name", ToolParamType.STRING, "指定服务器名称(可选)", False, None), - ] - available_for_llm = True - - async def execute(self, function_args: Dict[str, Any]) -> Dict[str, Any]: - query_type = function_args.get("query_type", "status") - server_name = function_args.get("server_name") - - result_parts = [] - - if query_type in ("status", "all"): - result_parts.append(self._format_status(server_name)) - - if query_type in ("tools", "all"): - result_parts.append(self._format_tools(server_name)) - - if query_type in ("chains", "all"): - result_parts.append(self._format_chains()) - - if query_type in ("resources", "all"): - result_parts.append(self._format_resources(server_name)) - - if query_type in ("prompts", "all"): - result_parts.append(self._format_prompts(server_name)) - - if query_type in ("stats", "all"): - result_parts.append(self._format_stats(server_name)) - - # v1.4.0: 追踪记录 - if query_type in ("trace",): - result_parts.append(self._format_trace()) - - # v1.4.0: 缓存状态 - if query_type in ("cache",): - result_parts.append(self._format_cache()) - - return {"name": self.name, "content": "\n\n".join(result_parts) if result_parts else "未知的查询类型"} - - def _format_status(self, server_name: Optional[str] = None) -> str: - status = mcp_manager.get_status() - lines = ["📊 MCP 桥接插件状态"] - lines.append(f" 总服务器数: {status['total_servers']}") - lines.append(f" 已连接: {status['connected_servers']}") - lines.append(f" 已断开: {status['disconnected_servers']}") - lines.append(f" 可用工具数: {status['total_tools']}") - lines.append(f" 心跳检测: {'运行中' if status['heartbeat_running'] else '已停止'}") - - lines.append("\n🔌 服务器详情:") - for name, info in status["servers"].items(): - if server_name and name != server_name: - continue - status_icon = "✅" if info["connected"] else "❌" - enabled_text = "" if info["enabled"] else " (已禁用)" - lines.append(f" {status_icon} {name}{enabled_text}") - lines.append(f" 传输: {info['transport']}, 工具数: {info['tools_count']}") - if info["consecutive_failures"] > 0: - lines.append(f" ⚠️ 连续失败: {info['consecutive_failures']} 次") - - return "\n".join(lines) - - def _format_tools(self, server_name: Optional[str] = None) -> str: - tools = mcp_manager.all_tools - lines = ["🔧 可用 MCP 工具"] - - by_server: Dict[str, List[str]] = {} - for tool_key, (tool_info, _) in tools.items(): - if server_name and tool_info.server_name != server_name: - continue - if tool_info.server_name not in by_server: - by_server[tool_info.server_name] = [] - by_server[tool_info.server_name].append(f" • {tool_key}: {tool_info.description[:50]}...") - - for srv_name, tool_list in by_server.items(): - lines.append(f"\n📦 {srv_name} ({len(tool_list)} 个工具):") - lines.extend(tool_list) - - if not by_server: - lines.append(" (无可用工具)") - - return "\n".join(lines) - - def _format_stats(self, server_name: Optional[str] = None) -> str: - stats = mcp_manager.get_all_stats() - lines = ["📈 调用统计"] - - g = stats["global"] - lines.append(f" 总调用次数: {g['total_tool_calls']}") - lines.append(f" 成功: {g['successful_calls']}, 失败: {g['failed_calls']}") - if g["total_tool_calls"] > 0: - success_rate = (g["successful_calls"] / g["total_tool_calls"]) * 100 - lines.append(f" 成功率: {success_rate:.1f}%") - lines.append(f" 运行时间: {g['uptime_seconds']:.0f} 秒") - - return "\n".join(lines) - - def _format_resources(self, server_name: Optional[str] = None) -> str: - resources = mcp_manager.all_resources - if not resources: - return "📦 当前没有可用的 MCP 资源" - - lines = ["📦 可用 MCP 资源"] - by_server: Dict[str, List[MCPResourceInfo]] = {} - for _key, (resource_info, _) in resources.items(): - if server_name and resource_info.server_name != server_name: - continue - if resource_info.server_name not in by_server: - by_server[resource_info.server_name] = [] - by_server[resource_info.server_name].append(resource_info) - - for srv_name, resource_list in by_server.items(): - lines.append(f"\n🔌 {srv_name} ({len(resource_list)} 个资源):") - for res in resource_list: - lines.append(f" • {res.name}: {res.uri}") - - return "\n".join(lines) - - def _format_prompts(self, server_name: Optional[str] = None) -> str: - prompts = mcp_manager.all_prompts - if not prompts: - return "📝 当前没有可用的 MCP 提示模板" - - lines = ["📝 可用 MCP 提示模板"] - by_server: Dict[str, List[MCPPromptInfo]] = {} - for _key, (prompt_info, _) in prompts.items(): - if server_name and prompt_info.server_name != server_name: - continue - if prompt_info.server_name not in by_server: - by_server[prompt_info.server_name] = [] - by_server[prompt_info.server_name].append(prompt_info) - - for srv_name, prompt_list in by_server.items(): - lines.append(f"\n🔌 {srv_name} ({len(prompt_list)} 个模板):") - for prompt in prompt_list: - lines.append(f" • {prompt.name}") - - return "\n".join(lines) - - def _format_trace(self) -> str: - """v1.4.0: 格式化追踪记录""" - records = tool_call_tracer.get_recent(10) - if not records: - return "🔍 暂无调用追踪记录" - - lines = ["🔍 最近调用追踪记录"] - for r in reversed(records): - status = "✅" if r.success else "❌" - cache = "📦" if r.cache_hit else "" - post = "🔄" if r.post_processed else "" - lines.append(f" {status}{cache}{post} {r.tool_name} ({r.duration_ms:.0f}ms)") - if r.error: - lines.append(f" 错误: {r.error[:50]}") - - return "\n".join(lines) - - def _format_cache(self) -> str: - """v1.4.0: 格式化缓存状态""" - stats = tool_call_cache.get_stats() - lines = ["🗄️ 缓存状态"] - lines.append(f" 启用: {'是' if stats['enabled'] else '否'}") - lines.append(f" 条目数: {stats['entries']}/{stats['max_entries']}") - lines.append(f" TTL: {stats['ttl']}秒") - lines.append(f" 命中: {stats['hits']}, 未命中: {stats['misses']}") - lines.append(f" 命中率: {stats['hit_rate']}") - return "\n".join(lines) - - def _format_chains(self) -> str: - """v1.8.0: 格式化工具链列表""" - chains = tool_chain_manager.get_all_chains() - if not chains: - return "🔗 当前没有配置工具链" - - lines = ["🔗 工具链列表"] - for name, chain in chains.items(): - status = "✅" if chain.enabled else "❌" - lines.append(f"\n{status} {name}") - lines.append(f" 描述: {chain.description[:50]}...") - lines.append(f" 步骤: {len(chain.steps)} 个") - for i, step in enumerate(chain.steps[:3]): - lines.append(f" {i + 1}. {step.tool_name}") - if len(chain.steps) > 3: - lines.append(f" ... 还有 {len(chain.steps) - 3} 个步骤") - if chain.input_params: - params = ", ".join(chain.input_params.keys()) - lines.append(f" 参数: {params}") - - return "\n".join(lines) - - async def direct_execute(self, **function_args) -> Dict[str, Any]: - return await self.execute(function_args) - - -# ============================================================================ -# 命令处理 -# ============================================================================ - - -class MCPStatusCommand(BaseCommand): - """MCP 状态查询命令 - 通过 /mcp 命令查看服务器状态""" - - command_name = "mcp_status_command" - command_description = "查看 MCP 服务器连接状态和统计信息" - command_pattern = r"^[//]mcp(?:\s+(?Pstatus|tools|stats|reconnect|trace|cache|perm|export|search|chain))?(?:\s+(?P.+))?$" - - async def execute(self) -> Tuple[bool, Optional[str], bool]: - """执行命令""" - subcommand = self.matched_groups.get("subcommand", "status") or "status" - arg = self.matched_groups.get("arg") - - if subcommand == "reconnect": - return await self._handle_reconnect(arg) - - # v1.4.0: 追踪命令 - if subcommand == "trace": - return await self._handle_trace(arg) - - # v1.4.0: 缓存命令 - if subcommand == "cache": - return await self._handle_cache(arg) - - # v1.4.0: 权限命令 - if subcommand == "perm": - return await self._handle_perm(arg) - - # v1.6.0: 导出命令 - if subcommand == "export": - return await self._handle_export(arg) - - # v1.7.0: 工具搜索命令 - if subcommand == "search": - return await self._handle_search(arg) - - # v1.8.0: 工具链命令 - if subcommand == "chain": - return await self._handle_chain(arg) - - result = self._format_output(subcommand, arg) - await self.send_text(result) - return (True, None, True) - - def _find_similar_servers(self, name: str, max_results: int = 3) -> List[str]: - """查找相似的服务器名称""" - name_lower = name.lower() - all_servers = list(mcp_manager._clients.keys()) - - # 简单的相似度匹配:包含关系或前缀匹配 - similar = [] - for srv in all_servers: - srv_lower = srv.lower() - if name_lower in srv_lower or srv_lower in name_lower: - similar.append(srv) - elif srv_lower.startswith(name_lower[:3]) if len(name_lower) >= 3 else False: - similar.append(srv) - - return similar[:max_results] - - async def _handle_reconnect(self, server_name: Optional[str] = None) -> Tuple[bool, Optional[str], bool]: - """处理重连请求""" - if server_name: - if server_name not in mcp_manager._clients: - # 提示相似的服务器名 - similar = self._find_similar_servers(server_name) - msg = f"❌ 服务器 '{server_name}' 不存在" - if similar: - msg += f"\n💡 你是不是想找: {', '.join(similar)}" - await self.send_text(msg) - return (True, None, True) - - await self.send_text(f"🔄 正在重连服务器 {server_name}...") - success = await mcp_manager.reconnect_server(server_name) - if success: - await self.send_text(f"✅ 服务器 {server_name} 重连成功") - else: - await self.send_text(f"❌ 服务器 {server_name} 重连失败") - else: - disconnected = mcp_manager.disconnected_servers - if not disconnected: - await self.send_text("✅ 所有服务器都已连接") - return (True, None, True) - - await self.send_text(f"🔄 正在重连 {len(disconnected)} 个断开的服务器...") - for srv in disconnected: - success = await mcp_manager.reconnect_server(srv) - status = "✅" if success else "❌" - await self.send_text(f"{status} {srv}") - - return (True, None, True) - - async def _handle_trace(self, arg: Optional[str] = None) -> Tuple[bool, Optional[str], bool]: - """v1.4.0: 处理追踪命令""" - if arg and arg.isdigit(): - # /mcp trace 20 - 最近 N 条 - n = int(arg) - records = tool_call_tracer.get_recent(n) - elif arg: - # /mcp trace - 特定工具 - records = tool_call_tracer.get_by_tool(arg) - else: - # /mcp trace - 最近 10 条 - records = tool_call_tracer.get_recent(10) - - if not records: - await self.send_text("🔍 暂无调用追踪记录\n\n用法: /mcp trace [数量|工具名]") - return (True, None, True) - - lines = [f"🔍 调用追踪记录 ({len(records)} 条)"] - lines.append("-" * 30) - for i, r in enumerate(reversed(records)): - status_icon = "✅" if r.success else "❌" - cache_tag = " [缓存]" if r.cache_hit else "" - post_tag = " [后处理]" if r.post_processed else "" - ts = time.strftime("%H:%M:%S", time.localtime(r.timestamp)) - lines.append(f"{status_icon} [{ts}] {r.tool_name}") - lines.append(f" {r.duration_ms:.0f}ms | {r.server_name}{cache_tag}{post_tag}") - if r.error: - lines.append(f" 错误: {r.error[:50]}") - if i < len(records) - 1: - lines.append("") - - await self.send_text("\n".join(lines)) - return (True, None, True) - - async def _handle_cache(self, arg: Optional[str] = None) -> Tuple[bool, Optional[str], bool]: - """v1.4.0: 处理缓存命令""" - if arg == "clear": - tool_call_cache.clear() - await self.send_text("✅ 缓存已清空") - return (True, None, True) - - stats = tool_call_cache.get_stats() - lines = ["🗄️ 缓存状态"] - lines.append(f"├ 启用: {'是' if stats['enabled'] else '否'}") - lines.append(f"├ 条目: {stats['entries']}/{stats['max_entries']}") - lines.append(f"├ TTL: {stats['ttl']}秒") - lines.append(f"├ 命中: {stats['hits']}") - lines.append(f"├ 未命中: {stats['misses']}") - lines.append(f"└ 命中率: {stats['hit_rate']}") - - await self.send_text("\n".join(lines)) - return (True, None, True) - - async def _handle_perm(self, arg: Optional[str] = None) -> Tuple[bool, Optional[str], bool]: - """v1.4.0: 处理权限命令""" - global _plugin_instance - - if _plugin_instance is None: - await self.send_text("❌ 插件未初始化") - return (True, None, True) - - perm_config = _plugin_instance.config.get("permissions", {}) - enabled = perm_config.get("perm_enabled", False) - default_mode = perm_config.get("perm_default_mode", "allow_all") - - if arg: - # 查看特定工具的权限 - rules = permission_checker.get_rules_for_tool(arg) - if not rules: - await self.send_text(f"🔐 工具 {arg} 无特定权限规则\n默认模式: {default_mode}") - else: - lines = [f"🔐 工具 {arg} 的权限规则:"] - for r in rules: - lines.append(f" • 模式: {r.get('mode', 'default')}") - if r.get("allowed"): - lines.append(f" 允许: {', '.join(r['allowed'][:3])}...") - if r.get("denied"): - lines.append(f" 拒绝: {', '.join(r['denied'][:3])}...") - await self.send_text("\n".join(lines)) - else: - # 查看权限配置概览 - lines = ["🔐 权限控制配置"] - lines.append(f"├ 启用: {'是' if enabled else '否'}") - lines.append(f"├ 默认模式: {default_mode}") - # 快捷配置 - deny_count = len(permission_checker._quick_deny_groups) - allow_count = len(permission_checker._quick_allow_users) - if deny_count > 0: - lines.append(f"├ 禁用群: {deny_count} 个") - if allow_count > 0: - lines.append(f"├ 管理员白名单: {allow_count} 人") - lines.append(f"└ 高级规则: {len(permission_checker._rules)} 条") - await self.send_text("\n".join(lines)) - - return (True, None, True) - - async def _handle_export(self, format_type: Optional[str] = None) -> Tuple[bool, Optional[str], bool]: - """v1.6.0: 处理导出命令""" - global _plugin_instance - - if _plugin_instance is None: - await self.send_text("❌ 插件未初始化") - return (True, None, True) - - servers_section = _plugin_instance.config.get("servers", {}) - if not isinstance(servers_section, dict): - servers_section = {} - - claude_json = str(servers_section.get("claude_config_json", "") or "") - if not claude_json.strip(): - legacy_list = str(servers_section.get("list", "") or "") - claude_json = legacy_servers_list_to_claude_config(legacy_list) or "" - - if not claude_json.strip(): - await self.send_text("📤 当前没有配置任何服务器") - return (True, None, True) - - try: - pretty = json.dumps(json.loads(claude_json), ensure_ascii=False, indent=2) - except Exception: - pretty = claude_json - - lines = ["📤 导出为 Claude Desktop 格式(mcpServers):"] - if format_type and format_type.strip() and format_type.strip().lower() != "claude": - lines.append("(v2.0 已精简为仅 Claude 格式,忽略其他格式参数)") - lines.append("") - lines.append(pretty) - await self.send_text("\n".join(lines)) - - return (True, None, True) - - async def _handle_search(self, query: Optional[str] = None) -> Tuple[bool, Optional[str], bool]: - """v1.7.0: 处理工具搜索命令""" - if not query or not query.strip(): - # 显示使用帮助 - help_text = """🔍 工具搜索 - -用法: /mcp search <关键词> - -示例: - /mcp search time 搜索包含 time 的工具 - /mcp search fetch 搜索包含 fetch 的工具 - /mcp search * 列出所有工具 - -支持模糊匹配工具名称和描述""" - await self.send_text(help_text) - return (True, None, True) - - query = query.strip().lower() - tools = mcp_manager.all_tools - - if not tools: - await self.send_text("🔍 当前没有可用的 MCP 工具") - return (True, None, True) - - # 搜索匹配的工具 - matched = [] - for tool_key, (tool_info, client) in tools.items(): - tool_name = tool_key.lower() - tool_desc = (tool_info.description or "").lower() - - # * 表示列出所有 - if query == "*": - matched.append((tool_key, tool_info, client)) - elif query in tool_name or query in tool_desc: - matched.append((tool_key, tool_info, client)) - - if not matched: - await self.send_text(f"🔍 未找到匹配 '{query}' 的工具") - return (True, None, True) - - # 按服务器分组显示 - by_server: Dict[str, List[Tuple[str, Any]]] = {} - for tool_key, tool_info, _client in matched: - server_name = tool_info.server_name - if server_name not in by_server: - by_server[server_name] = [] - by_server[server_name].append((tool_key, tool_info)) - - # 如果只有一个服务器或结果较少,显示全部;否则折叠 - single_server = len(by_server) == 1 - lines = [f"🔍 搜索结果: {len(matched)} 个工具匹配 '{query}'"] - - for srv_name, tool_list in by_server.items(): - lines.append(f"\n📦 {srv_name} ({len(tool_list)} 个):") - - # 单服务器或结果少于 15 个时显示全部 - show_all = single_server or len(matched) <= 15 - display_limit = len(tool_list) if show_all else 5 - - for tool_key, tool_info in tool_list[:display_limit]: - desc = tool_info.description[:40] + "..." if len(tool_info.description) > 40 else tool_info.description - lines.append(f" • {tool_key}") - lines.append(f" {desc}") - if len(tool_list) > display_limit: - lines.append(f" ... 还有 {len(tool_list) - display_limit} 个,用 /mcp search {query} {srv_name} 筛选") - - await self.send_text("\n".join(lines)) - return (True, None, True) - - async def _handle_chain(self, arg: Optional[str] = None) -> Tuple[bool, Optional[str], bool]: - """v1.8.0: 处理工具链命令""" - if not arg or not arg.strip(): - # 显示工具链列表和帮助 - chains = tool_chain_manager.get_all_chains() - - lines = ["🔗 工具链管理"] - lines.append("") - - if chains: - lines.append(f"已配置 {len(chains)} 个工具链:") - for name, chain in chains.items(): - status = "✅" if chain.enabled else "❌" - steps_count = len(chain.steps) - lines.append(f" {status} {name} ({steps_count} 步)") - else: - lines.append("当前没有配置工具链") - - lines.append("") - lines.append("命令:") - lines.append(" /mcp chain list 查看所有工具链") - lines.append(" /mcp chain <名称> 查看工具链详情") - lines.append(" /mcp chain test <名称> <参数JSON> 测试执行") - lines.append(" /mcp chain reload 重新加载配置") - lines.append("") - lines.append("💡 在 WebUI「工具链」配置区编辑工具链") - - await self.send_text("\n".join(lines)) - return (True, None, True) - - parts = arg.strip().split(maxsplit=2) - sub_action = parts[0].lower() - - if sub_action == "list": - # 列出所有工具链 - chains = tool_chain_manager.get_all_chains() - if not chains: - await self.send_text("🔗 当前没有配置工具链") - return (True, None, True) - - lines = [f"🔗 工具链列表 ({len(chains)} 个)"] - for name, chain in chains.items(): - status = "✅" if chain.enabled else "❌" - lines.append(f"\n{status} {name}") - lines.append(f" {chain.description[:60]}...") - lines.append(f" 步骤: {' → '.join([s.tool_name.split('_')[-1] for s in chain.steps[:4]])}") - if chain.input_params: - lines.append(f" 参数: {', '.join(chain.input_params.keys())}") - - await self.send_text("\n".join(lines)) - return (True, None, True) - - elif sub_action == "reload": - # 重新加载工具链配置 - global _plugin_instance - if _plugin_instance: - _plugin_instance._load_tool_chains() - chains = tool_chain_manager.get_all_chains() - from src.plugin_system.core.component_registry import component_registry - - registered = 0 - for name, _chain in tool_chain_manager.get_enabled_chains().items(): - tool_name = f"chain_{name}".replace("-", "_").replace(".", "_") - if component_registry.get_component_info(tool_name, ComponentType.TOOL): - registered += 1 - lines = ["✅ 已重新加载工具链配置"] - lines.append(f"📋 配置数: {len(chains)} 个") - lines.append(f"🔧 已注册: {registered} 个(可被 LLM 调用)") - if chains: - lines.append("") - lines.append("工具链列表:") - for name, chain in chains.items(): - status = "✅" if chain.enabled else "❌" - lines.append(f" {status} chain_{name}") - await self.send_text("\n".join(lines)) - else: - await self.send_text("❌ 插件未初始化") - return (True, None, True) - - elif sub_action == "test" and len(parts) >= 2: - # 测试执行工具链 - chain_name = parts[1] - args_json = parts[2] if len(parts) > 2 else "{}" - - chain = tool_chain_manager.get_chain(chain_name) - if not chain: - await self.send_text(f"❌ 工具链 '{chain_name}' 不存在") - return (True, None, True) - - try: - input_args = json.loads(args_json) - except json.JSONDecodeError: - await self.send_text("❌ 参数 JSON 格式错误") - return (True, None, True) - - await self.send_text(f"🔄 正在执行工具链 {chain_name}...") - - result = await tool_chain_manager.execute_chain(chain_name, input_args) - - lines = [] - if result.success: - lines.append(f"✅ 工具链执行成功 ({result.total_duration_ms:.0f}ms)") - lines.append("") - lines.append("执行详情:") - lines.append(result.to_summary()) - lines.append("") - lines.append("最终输出:") - output_preview = result.final_output[:500] - if len(result.final_output) > 500: - output_preview += "..." - lines.append(output_preview) - else: - lines.append("❌ 工具链执行失败") - lines.append(f"错误: {result.error}") - if result.step_results: - lines.append("") - lines.append("执行详情:") - lines.append(result.to_summary()) - - await self.send_text("\n".join(lines)) - return (True, None, True) - - else: - # 查看特定工具链详情 - chain_name = sub_action - chain = tool_chain_manager.get_chain(chain_name) - - if not chain: - # 尝试模糊匹配 - all_chains = tool_chain_manager.get_all_chains() - similar = [n for n in all_chains.keys() if chain_name.lower() in n.lower()] - msg = f"❌ 工具链 '{chain_name}' 不存在" - if similar: - msg += f"\n💡 你是不是想找: {', '.join(similar[:3])}" - await self.send_text(msg) - return (True, None, True) - - lines = [f"🔗 工具链: {chain.name}"] - lines.append(f"状态: {'✅ 启用' if chain.enabled else '❌ 禁用'}") - lines.append(f"描述: {chain.description}") - lines.append("") - - if chain.input_params: - lines.append("📥 输入参数:") - for param, desc in chain.input_params.items(): - lines.append(f" • {param}: {desc}") - lines.append("") - - lines.append(f"📋 执行步骤 ({len(chain.steps)} 个):") - for i, step in enumerate(chain.steps): - optional_tag = " (可选)" if step.optional else "" - lines.append(f" {i + 1}. {step.tool_name}{optional_tag}") - if step.description: - lines.append(f" {step.description}") - if step.output_key: - lines.append(f" 输出键: {step.output_key}") - if step.args_template: - args_preview = json.dumps(step.args_template, ensure_ascii=False)[:60] - lines.append(f" 参数: {args_preview}...") - - lines.append("") - lines.append(f"💡 测试: /mcp chain test {chain.name} " + '{"参数": "值"}') - - await self.send_text("\n".join(lines)) - return (True, None, True) - - def _format_output(self, subcommand: str, server_name: str = None) -> str: - """格式化输出""" - status = mcp_manager.get_status() - stats = mcp_manager.get_all_stats() - lines = [] - - if subcommand in ("status", "all"): - lines.append("📊 MCP 桥接插件状态") - lines.append(f"├ 服务器: {status['connected_servers']}/{status['total_servers']} 已连接") - lines.append(f"├ 工具数: {status['total_tools']}") - lines.append(f"└ 心跳: {'运行中' if status['heartbeat_running'] else '已停止'}") - - if status["servers"]: - lines.append("\n🔌 服务器列表:") - for name, info in status["servers"].items(): - if server_name and name != server_name: - continue - icon = "✅" if info["connected"] else "❌" - enabled = "" if info["enabled"] else " (禁用)" - lines.append(f" {icon} {name}{enabled}") - lines.append(f" {info['transport']} | {info['tools_count']} 工具") - # 显示断路器状态 - cb = info.get("circuit_breaker", {}) - cb_state = cb.get("state", "closed") - if cb_state == "open": - lines.append(" ⚡ 断路器熔断中") - elif cb_state == "half_open": - lines.append(" ⚡ 断路器试探中") - if info["consecutive_failures"] > 0: - lines.append(f" ⚠️ 连续失败 {info['consecutive_failures']} 次") - - if subcommand in ("tools", "all"): - tools = mcp_manager.all_tools - if tools: - lines.append("\n🔧 可用工具:") - by_server = {} - for _key, (info, _) in tools.items(): - if server_name and info.server_name != server_name: - continue - by_server.setdefault(info.server_name, []).append(info.name) - - # 如果指定了服务器名,显示全部工具;否则折叠显示 - show_all = server_name is not None - - for srv, tool_list in by_server.items(): - lines.append(f" 📦 {srv} ({len(tool_list)})") - if show_all: - # 指定服务器时显示全部 - for t in tool_list: - lines.append(f" • {t}") - else: - # 未指定时折叠显示 - for t in tool_list[:5]: - lines.append(f" • {t}") - if len(tool_list) > 5: - lines.append(f" ... 还有 {len(tool_list) - 5} 个,用 /mcp tools {srv} 查看全部") - - if subcommand in ("stats", "all"): - g = stats["global"] - lines.append("\n📈 调用统计:") - lines.append(f" 总调用: {g['total_tool_calls']}") - if g["total_tool_calls"] > 0: - rate = (g["successful_calls"] / g["total_tool_calls"]) * 100 - lines.append(f" 成功率: {rate:.1f}%") - lines.append(f" 运行: {g['uptime_seconds']:.0f}秒") - - if not lines: - lines.append("📖 MCP 桥接插件命令帮助") - lines.append("") - lines.append("状态查询:") - lines.append(" /mcp 查看连接状态") - lines.append(" /mcp tools 查看所有工具") - lines.append(" /mcp tools <服务器> 查看指定服务器工具") - lines.append(" /mcp stats 查看调用统计") - lines.append("") - lines.append("工具搜索:") - lines.append(" /mcp search <关键词> 搜索工具") - lines.append(" /mcp search * 列出所有工具") - lines.append("") - lines.append("服务器管理:") - lines.append(" /mcp reconnect 重连断开的服务器") - lines.append(" /mcp reconnect <名称> 重连指定服务器") - lines.append("") - lines.append("服务器配置(Claude):") - lines.append(" /mcp import 合并 Claude mcpServers 配置") - lines.append(" /mcp export 导出当前 mcpServers 配置") - lines.append("") - lines.append("工具链:") - lines.append(" /mcp chain 查看工具链列表") - lines.append(" /mcp chain <名称> 查看工具链详情") - lines.append(" /mcp chain test <名称> <参数> 测试执行") - lines.append("") - lines.append("其他:") - lines.append(" /mcp trace 查看调用追踪") - lines.append(" /mcp cache 查看缓存状态") - lines.append(" /mcp perm 查看权限配置") - - return "\n".join(lines) - - -class MCPImportCommand(BaseCommand): - """v1.6.0: MCP 配置导入命令 - 支持从 Claude Desktop 格式导入""" - - command_name = "mcp_import_command" - command_description = "从 Claude Desktop 或其他格式导入 MCP 服务器配置" - # 匹配 /mcp import 后面的所有内容(包括多行 JSON) - command_pattern = r"^[//]mcp\s+import(?:\s+(?P.+))?$" - - async def execute(self) -> Tuple[bool, Optional[str], bool]: - """执行导入命令""" - global _plugin_instance - - if _plugin_instance is None: - await self.send_text("❌ 插件未初始化") - return (True, None, True) - - content = self.matched_groups.get("content", "") - - if not content or not content.strip(): - # 显示使用帮助 - help_text = """📥 MCP 配置导入 - -用法: /mcp import - -支持的格式: -• Claude Desktop 格式 (mcpServers 对象) -• 兼容旧版:MaiBot servers 列表数组(将自动迁移为 mcpServers) - -示例: -/mcp import {"mcpServers":{"time":{"command":"uvx","args":["mcp-server-time"]}}} - -/mcp import {"mcpServers":{"api":{"url":"https://example.com/mcp","transport":"sse"}}}""" - await self.send_text(help_text) - return (True, None, True) - - raw_text = content.strip() - - # 解析输入:支持 Claude mcpServers 或旧版 servers 列表数组 - try: - data = json.loads(raw_text) - except json.JSONDecodeError as e: - await self.send_text(f"❌ JSON 解析失败: {e}") - return (True, None, True) - - if isinstance(data, list): - migrated = legacy_servers_list_to_claude_config(raw_text) - if not migrated: - await self.send_text("❌ 旧版 servers 列表解析失败,无法迁移") - return (True, None, True) - data = json.loads(migrated) - - if not isinstance(data, dict): - await self.send_text("❌ 配置必须是 JSON 对象(包含 mcpServers)") - return (True, None, True) - - incoming_mapping = data.get("mcpServers", data) - if not isinstance(incoming_mapping, dict): - await self.send_text("❌ mcpServers 必须是 JSON 对象") - return (True, None, True) - - # 校验输入配置 - try: - parse_claude_mcp_config(json.dumps({"mcpServers": incoming_mapping}, ensure_ascii=False)) - except ClaudeConfigError as e: - await self.send_text(f"❌ 配置校验失败: {e}") - return (True, None, True) - - servers_section = _plugin_instance.config.get("servers", {}) - if not isinstance(servers_section, dict): - servers_section = {} - - existing_json = str(servers_section.get("claude_config_json", "") or "") - if not existing_json.strip(): - legacy_list = str(servers_section.get("list", "") or "") - existing_json = legacy_servers_list_to_claude_config(legacy_list) or "" - - existing_mapping: Dict[str, Any] = {} - if existing_json.strip(): - try: - parsed = json.loads(existing_json) - mapping = parsed.get("mcpServers", parsed) - if isinstance(mapping, dict): - existing_mapping = mapping - except Exception: - existing_mapping = {} - - added: List[str] = [] - skipped: List[str] = [] - - for name, conf in incoming_mapping.items(): - if name in existing_mapping: - skipped.append(str(name)) - continue - existing_mapping[str(name)] = conf - added.append(str(name)) - - if "servers" not in _plugin_instance.config: - _plugin_instance.config["servers"] = {} - - _plugin_instance.config["servers"]["claude_config_json"] = json.dumps( - {"mcpServers": existing_mapping}, ensure_ascii=False, indent=2 - ) - - # 持久化到配置文件(使用插件基类的写入逻辑) - try: - config_path = Path(_plugin_instance.plugin_dir) / _plugin_instance.config_file_name - _plugin_instance._save_config_to_file(_plugin_instance.config, str(config_path)) - except Exception as e: - logger.warning(f"保存配置文件失败: {e}") - - lines = [] - if added: - lines.append(f"✅ 成功导入 {len(added)} 个服务器:") - for n in added[:20]: - lines.append(f" • {n}") - if len(added) > 20: - lines.append(f" ... 还有 {len(added) - 20} 个") - else: - lines.append("⚠️ 没有新服务器可导入") - - if skipped: - lines.append(f"\n⏭️ 跳过 {len(skipped)} 个已存在的服务器") - - lines.append("\n💡 发送 /mcp reconnect 使配置生效") - - await self.send_text("\n".join(lines)) - return (True, None, True) - - -# ============================================================================ -# 事件处理器 -# ============================================================================ - - -class MCPStartupHandler(BaseEventHandler): - """MCP 启动事件处理器""" - - event_type = EventType.ON_START - handler_name = "mcp_startup_handler" - handler_description = "MCP 桥接插件启动处理器" - weight = 0 - intercept_message = False - - async def execute(self, message: Optional[Any]) -> Tuple[bool, bool, Optional[str], None, None]: - """处理启动事件""" - global _plugin_instance - - if _plugin_instance is None: - logger.warning("MCP 桥接插件实例未初始化") - return (False, True, None, None, None) - - logger.info("MCP 桥接插件收到 ON_START 事件,开始连接 MCP 服务器...") - await _plugin_instance._async_connect_servers() - - await mcp_manager.start_heartbeat() - - return (True, True, None, None, None) - - -class MCPStopHandler(BaseEventHandler): - """MCP 停止事件处理器""" - - event_type = EventType.ON_STOP - handler_name = "mcp_stop_handler" - handler_description = "MCP 桥接插件停止处理器" - weight = 0 - intercept_message = False - - async def execute(self, message: Optional[Any]) -> Tuple[bool, bool, Optional[str], None, None]: - """处理停止事件""" - global _plugin_instance - - logger.info("MCP 桥接插件收到 ON_STOP 事件,正在关闭...") - - if _plugin_instance is not None: - await _plugin_instance._stop_status_refresher() - - await mcp_manager.shutdown() - mcp_tool_registry.clear() - - logger.info("MCP 桥接插件已关闭所有连接") - return (True, True, None, None, None) - - -# ============================================================================ -# 主插件类 -# ============================================================================ - - -@register_plugin -class MCPBridgePlugin(BasePlugin): - """MCP 桥接插件 v2.0.0 - 将 MCP 服务器的工具桥接到 MaiBot""" - - plugin_name: str = "mcp_bridge_plugin" - enable_plugin: bool = False # 默认禁用,用户需在 WebUI 手动启用 - dependencies: List[str] = [] - python_dependencies: List[str] = ["mcp"] - config_file_name: str = "config.toml" - - config_section_descriptions = { - "guide": section_meta("📖 快速入门", order=1), - "plugin": section_meta("🔘 插件开关", order=2), - "servers": section_meta("🔌 MCP Servers(Claude)", order=3), - "tool_chains": section_meta("🔗 Workflow(硬流程/工具链)", order=4), - "react": section_meta("🔄 ReAct(软流程)", collapsed=True, order=5), - "status": section_meta("📊 运行状态", order=10), - "tools": section_meta("🔧 工具管理", collapsed=True, order=20), - "permissions": section_meta("🔐 权限控制", collapsed=True, order=21), - "settings": section_meta("⚙️ 高级设置", collapsed=True, order=30), - } - - config_schema: dict = { - # 新手引导区(只读) - "guide": { - "quick_start": ConfigField( - type=str, - default="1. 获取 MCP 服务器 2. 在「MCP Servers(Claude)」粘贴 mcpServers 配置 3. 保存后发送 /mcp reconnect 4. (可选)在「Workflow/ ReAct」配置流程", - description="三步开始使用", - label="🚀 快速入门", - disabled=True, - order=1, - ), - "mcp_sources": ConfigField( - type=str, - default="https://modelscope.cn/mcp (魔搭·推荐) | https://smithery.ai | https://glama.ai | https://mcp.so", - description="复制链接到浏览器打开,获取免费 MCP 服务器", - label="🌐 获取 MCP 服务器", - disabled=True, - hint="魔搭 ModelScope 国内免费推荐,将 mcpServers 配置粘贴到「MCP Servers(Claude)」即可", - order=2, - ), - "example_config": ConfigField( - type=str, - default='{"mcpServers":{"time":{"url":"https://mcp.api-inference.modelscope.cn/server/mcp-server-time"}}}', - description="复制到 MCP Servers(Claude)可直接使用(免费时间服务器)", - label="📝 配置示例", - disabled=True, - order=3, - ), - }, - "plugin": { - "enabled": ConfigField( - type=bool, - default=False, - description="是否启用插件(默认关闭)", - label="启用插件", - ), - }, - "settings": { - "tool_prefix": ConfigField( - type=str, - default="mcp", - description="🏷️ 工具前缀 - 生成的工具名格式: {前缀}_{服务器名}_{工具名}", - label="🏷️ 工具前缀", - placeholder="mcp", - order=1, - ), - "connect_timeout": ConfigField( - type=float, - default=30.0, - description="⏱️ 连接超时(秒)", - label="⏱️ 连接超时(秒)", - min=5.0, - max=120.0, - step=5.0, - order=2, - ), - "call_timeout": ConfigField( - type=float, - default=60.0, - description="⏱️ 调用超时(秒)", - label="⏱️ 调用超时(秒)", - min=10.0, - max=300.0, - step=10.0, - order=3, - ), - "auto_connect": ConfigField( - type=bool, - default=True, - description="🔄 启动时自动连接所有已启用的服务器", - label="🔄 自动连接", - order=4, - ), - "retry_attempts": ConfigField( - type=int, - default=3, - description="🔁 连接失败时的重试次数", - label="🔁 重试次数", - min=0, - max=10, - order=5, - ), - "retry_interval": ConfigField( - type=float, - default=5.0, - description="⏳ 重试间隔(秒)", - label="⏳ 重试间隔(秒)", - min=1.0, - max=60.0, - step=1.0, - order=6, - ), - "heartbeat_enabled": ConfigField( - type=bool, - default=True, - description="💓 定期检测服务器连接状态", - label="💓 启用心跳检测", - order=7, - ), - "heartbeat_interval": ConfigField( - type=float, - default=60.0, - description="💓 基准心跳间隔(秒)", - label="💓 心跳间隔(秒)", - min=10.0, - max=300.0, - step=10.0, - hint="智能心跳会根据服务器稳定性自动调整", - order=8, - ), - "heartbeat_adaptive": ConfigField( - type=bool, - default=True, - description="🧠 根据服务器稳定性自动调整心跳间隔", - label="🧠 智能心跳", - hint="稳定服务器逐渐增加间隔,断开的服务器缩短间隔", - order=9, - ), - "heartbeat_max_multiplier": ConfigField( - type=float, - default=3.0, - description="稳定服务器的最大间隔倍数", - label="📈 最大间隔倍数", - min=1.5, - max=5.0, - step=0.5, - hint="稳定服务器心跳间隔最高可达 基准间隔 × 此值", - order=10, - ), - "auto_reconnect": ConfigField( - type=bool, - default=True, - description="🔄 检测到断开时自动尝试重连", - label="🔄 自动重连", - order=11, - ), - "max_reconnect_attempts": ConfigField( - type=int, - default=3, - description="🔄 连续重连失败后暂停重连", - label="🔄 最大重连次数", - min=1, - max=10, - order=12, - ), - # v1.7.0: 状态刷新配置 - "status_refresh_enabled": ConfigField( - type=bool, - default=True, - description="📊 定期更新 WebUI 状态显示", - label="📊 启用状态实时刷新", - hint="关闭后 WebUI 状态仅在启动时更新", - order=13, - ), - "status_refresh_interval": ConfigField( - type=float, - default=10.0, - description="📊 状态刷新间隔(秒)", - label="📊 状态刷新间隔(秒)", - min=5.0, - max=60.0, - step=5.0, - hint="值越小刷新越频繁,但会增加少量 CPU 消耗", - order=14, - ), - "enable_resources": ConfigField( - type=bool, - default=False, - description="📦 允许读取 MCP 服务器提供的资源", - label="📦 启用 Resources(实验性)", - order=11, - ), - "enable_prompts": ConfigField( - type=bool, - default=False, - description="📝 允许使用 MCP 服务器提供的提示模板", - label="📝 启用 Prompts(实验性)", - order=12, - ), - # v1.3.0 后处理配置 - "post_process_enabled": ConfigField( - type=bool, - default=False, - description="🔄 使用 LLM 对长结果进行摘要提炼", - label="🔄 启用结果后处理", - order=20, - ), - "post_process_threshold": ConfigField( - type=int, - default=500, - description="📏 结果长度超过此值才触发后处理", - label="📏 后处理阈值(字符)", - min=100, - max=5000, - step=100, - order=21, - ), - "post_process_max_tokens": ConfigField( - type=int, - default=500, - description="📝 LLM 摘要输出的最大 token 数", - label="📝 后处理最大输出 token", - min=100, - max=2000, - step=50, - order=22, - ), - "post_process_model": ConfigField( - type=str, - default="", - description="🤖 指定用于后处理的模型名称", - label="🤖 后处理模型(可选)", - placeholder="留空则使用 Utils 模型组", - order=23, - ), - "post_process_prompt": ConfigField( - type=str, - default="用户问题:{query}\\n\\n工具返回内容:\\n{result}\\n\\n请从上述内容中提取与用户问题最相关的关键信息,简洁准确地输出:", - description="📋 后处理提示词模板", - label="📋 后处理提示词模板", - input_type="textarea", - rows=8, - order=24, - ), - # v1.4.0 追踪配置 - "trace_enabled": ConfigField( - type=bool, - default=True, - description="🔍 记录工具调用详情", - label="🔍 启用调用追踪", - order=30, - ), - "trace_max_records": ConfigField( - type=int, - default=100, - description="内存中保留的最大记录数", - label="📊 追踪记录上限", - min=10, - max=1000, - order=31, - ), - "trace_log_enabled": ConfigField( - type=bool, - default=False, - description="是否将追踪记录写入日志文件", - label="📝 追踪日志文件", - hint="启用后记录写入 plugins/MaiBot_MCPBridgePlugin/logs/trace.jsonl", - order=32, - ), - # v1.4.0 缓存配置 - "cache_enabled": ConfigField( - type=bool, - default=False, - description="🗄️ 缓存相同参数的调用结果", - label="🗄️ 启用调用缓存", - hint="相同参数的调用会返回缓存结果,减少重复请求", - order=40, - ), - "cache_ttl": ConfigField( - type=int, - default=300, - description="缓存有效期(秒)", - label="⏱️ 缓存有效期(秒)", - min=60, - max=3600, - order=41, - ), - "cache_max_entries": ConfigField( - type=int, - default=200, - description="最大缓存条目数(超出后 LRU 淘汰)", - label="📦 最大缓存条目", - min=50, - max=1000, - order=42, - ), - "cache_exclude_tools": ConfigField( - type=str, - default="", - description="不缓存的工具(每行一个,支持通配符 *)", - label="🚫 缓存排除列表", - input_type="textarea", - rows=4, - hint="时间类、随机类工具建议排除,如 mcp_time_*", - order=43, - ), - }, - # v1.4.0 工具管理 - "tools": { - "tool_list": ConfigField( - type=str, - default="(启动后自动生成)", - description="当前已注册的 MCP 工具列表(只读)", - label="📋 工具清单", - input_type="textarea", - disabled=True, - rows=12, - hint="从此处复制工具名到下方禁用列表或工具链配置", - order=1, - ), - "disabled_tools": ConfigField( - type=str, - default="", - description="要禁用的工具名(每行一个)", - label="🚫 禁用工具列表", - input_type="textarea", - rows=6, - hint="从上方工具清单复制工具名,每行一个。禁用后该工具不会被 LLM 调用", - order=2, - ), - }, - # v1.8.0 工具链配置 - "tool_chains": { - "chains_enabled": ConfigField( - type=bool, - default=True, - description="🔗 启用工具链功能", - label="🔗 启用工具链", - hint="工具链可将多个工具按顺序执行,后续工具可使用前序工具的输出", - order=1, - ), - # 工具链使用指南 - "chains_guide": ConfigField( - type=str, - default="""工具链将多个 MCP 工具串联执行,后续步骤可使用前序步骤的输出 - -📌 变量语法: - ${input.参数名} - 用户输入的参数 - ${step.输出键} - 某步骤的输出(需设置 output_key) - ${prev} - 上一步的输出 - ${prev.字段} - 上一步输出(JSON)的某字段 - ${step.输出键.0.字段} / ${step.输出键[0].字段} - 访问数组下标 - ${step.输出键['return'][0]['location']} - 支持 bracket 写法 - -📌 测试命令: - /mcp chain list - 查看所有工具链 - /mcp chain 链名 {"参数":"值"} - 测试执行""", - description="工具链使用说明", - label="📖 使用指南", - input_type="textarea", - disabled=True, - rows=10, - order=2, - ), - # 快速添加工具链(表单式) - "quick_chain_name": ConfigField( - type=str, - default="", - description="工具链名称(英文,如 search_and_summarize)", - label="➕ 快速添加 - 名称", - placeholder="my_tool_chain", - hint="必填,将作为 LLM 可调用的工具名", - order=10, - ), - "quick_chain_desc": ConfigField( - type=str, - default="", - description="工具链描述(供 LLM 理解何时使用)", - label="➕ 快速添加 - 描述", - placeholder="先搜索内容,再获取详情并总结", - hint="必填,清晰描述工具链的用途", - order=11, - ), - "quick_chain_params": ConfigField( - type=str, - default="", - description="输入参数(每行一个,格式: 参数名=描述)", - label="➕ 快速添加 - 输入参数", - input_type="textarea", - rows=3, - placeholder="query=搜索关键词\nmax_results=最大结果数", - hint="定义用户需要提供的参数", - order=12, - ), - "quick_chain_steps": ConfigField( - type=str, - default="", - description="执行步骤(每行一个,格式: 工具名|参数JSON|输出键)", - label="➕ 快速添加 - 执行步骤", - input_type="textarea", - rows=5, - placeholder='mcp_server_search|{"keyword":"${input.query}"}|search_result\nmcp_server_detail|{"id":"${prev}"}|\n# 访问数组示例:\n# mcp_geo|{"q":"${input.query}"}|geo\n# mcp_next|{"location":"${step.geo.return.0.location}"}|', - hint="格式: 工具名|参数模板|输出键(输出键可选,用于后续步骤引用 ${step.xxx})", - order=13, - ), - "quick_chain_add": ConfigField( - type=str, - default="", - description="填写上方信息后,在此输入 ADD 并保存即可添加", - label="➕ 确认添加", - placeholder="输入 ADD 并保存", - hint="添加后会自动合并到下方工具链列表", - order=14, - ), - # 工具链模板 - "chains_templates": ConfigField( - type=str, - default="""📋 常用工具链模板(复制到下方列表使用): - -1️⃣ 搜索+详情模板: -{ - "name": "search_and_detail", - "description": "搜索内容并获取详情", - "input_params": {"query": "搜索关键词"}, - "steps": [ - {"tool_name": "搜索工具名", "args_template": {"keyword": "${input.query}"}, "output_key": "results"}, - {"tool_name": "详情工具名", "args_template": {"id": "${prev}"}} - ] -} - -2️⃣ 获取+处理模板: -{ - "name": "fetch_and_process", - "description": "获取数据并处理", - "input_params": {"url": "目标URL"}, - "steps": [ - {"tool_name": "获取工具名", "args_template": {"url": "${input.url}"}, "output_key": "data"}, - {"tool_name": "处理工具名", "args_template": {"content": "${step.data}"}} - ] -} - -3️⃣ 多步骤可选模板: -{ - "name": "multi_step_chain", - "description": "多步骤处理,部分可选", - "input_params": {"input": "输入内容"}, - "steps": [ - {"tool_name": "步骤1工具", "args_template": {"data": "${input.input}"}, "output_key": "step1"}, - {"tool_name": "步骤2工具", "args_template": {"data": "${prev}"}, "output_key": "step2", "optional": true}, - {"tool_name": "步骤3工具", "args_template": {"data": "${step.step1}"}} - ] -}""", - description="工具链配置模板参考", - label="📝 配置模板", - input_type="textarea", - disabled=True, - rows=15, - order=20, - ), - "chains_list": ConfigField( - type=str, - default="[]", - description="工具链配置(JSON 数组格式)", - label="📋 工具链列表", - input_type="textarea", - rows=20, - placeholder="""[ - { - "name": "search_and_detail", - "description": "先搜索再获取详情", - "input_params": {"query": "搜索关键词"}, - "steps": [ - {"tool_name": "mcp_server_search", "args_template": {"keyword": "${input.query}"}, "output_key": "search_result"}, - {"tool_name": "mcp_server_get_detail", "args_template": {"id": "${step.search_result}"}} - ] - } -]""", - hint="每个工具链包含 name、description、input_params、steps", - order=30, - ), - "chains_status": ConfigField( - type=str, - default="(启动后自动生成)", - description="当前已注册的工具链状态(只读)", - label="📊 工具链状态", - input_type="textarea", - disabled=True, - rows=8, - order=40, - ), - }, - # v1.9.0 ReAct 软流程配置 - "react": { - "react_enabled": ConfigField( - type=bool, - default=False, - description="🔄 将 MCP 工具注册到记忆检索 ReAct 系统", - label="🔄 启用 ReAct 集成", - hint="启用后,MaiBot 的 ReAct Agent 可在记忆检索时调用 MCP 工具", - order=1, - ), - "react_guide": ConfigField( - type=str, - default="""ReAct 软流程说明: - -📌 什么是 ReAct? -ReAct (Reasoning + Acting) 是 LLM 自主决策的多轮工具调用模式。 -与 Workflow 硬流程不同,ReAct 由 LLM 动态决定调用哪些工具。 - -📌 工作原理: -1. 用户提问 → LLM 分析需要什么信息 -2. LLM 选择调用工具 → 获取结果 -3. LLM 观察结果 → 决定是否需要更多信息 -4. 重复 2-3 直到信息足够 → 生成最终回答 - -📌 与 Workflow 的区别: -- ReAct (软流程): LLM 自主决策,灵活但不可预测 -- Workflow (硬流程): 用户预定义,固定流程,可靠可控 - -📌 使用场景: -- 复杂问题需要多步推理 -- 不确定需要调用哪些工具 -- 需要根据中间结果动态调整""", - description="ReAct 软流程使用说明", - label="📖 使用指南", - input_type="textarea", - disabled=True, - rows=15, - order=2, - ), - "filter_mode": ConfigField( - type=str, - default="whitelist", - description="过滤模式", - label="📋 过滤模式", - choices=["whitelist", "blacklist"], - hint="whitelist: 只注册列出的工具;blacklist: 排除列出的工具", - order=3, - ), - "tool_filter": ConfigField( - type=str, - default="", - description="工具过滤列表(每行一个,支持通配符 * 和精确匹配)", - label="🔍 工具过滤列表", - input_type="textarea", - rows=6, - placeholder="""# 精确匹配示例: -mcp_bing_web_search_bing_search -mcp_mcmod_search_mod - -# 通配符示例: -mcp_*_search_* -mcp_bing_*""", - hint="白名单模式: 只注册列出的工具;黑名单模式: 排除列出的工具。支持 # 注释", - order=4, - ), - "react_status": ConfigField( - type=str, - default="(启动后自动生成)", - description="当前已注册到 ReAct 的工具状态(只读)", - label="📊 ReAct 工具状态", - input_type="textarea", - disabled=True, - rows=6, - order=10, - ), - }, - # v1.4.0 权限控制 - "permissions": { - "perm_enabled": ConfigField( - type=bool, - default=False, - description="🔐 按群/用户限制工具使用", - label="🔐 启用权限控制", - order=1, - ), - "perm_default_mode": ConfigField( - type=str, - default="allow_all", - description="默认模式:allow_all(默认允许)或 deny_all(默认禁止)", - label="📋 默认模式", - placeholder="allow_all", - hint="allow_all: 未配置的默认允许;deny_all: 未配置的默认禁止", - order=2, - ), - # 快捷配置(简化版) - "quick_deny_groups": ConfigField( - type=str, - default="", - description="禁止使用所有 MCP 工具的群号(每行一个)", - label="🚫 禁用群列表(快捷)", - input_type="textarea", - rows=4, - hint="填入群号,该群将无法使用任何 MCP 工具", - order=3, - ), - "quick_allow_users": ConfigField( - type=str, - default="", - description="始终允许使用所有工具的用户 QQ 号(管理员白名单,每行一个)", - label="✅ 管理员白名单(快捷)", - input_type="textarea", - rows=3, - hint="填入 QQ 号,该用户在任何场景都可使用 MCP 工具", - order=4, - ), - # 高级配置 - "perm_rules": ConfigField( - type=str, - default="[]", - description="高级权限规则(JSON 格式,可针对特定工具配置)", - label="📜 高级权限规则(可选)", - input_type="textarea", - rows=10, - placeholder="""[ - {"tool": "mcp_*_delete_*", "denied": ["qq:123456:group"]} -]""", - hint="格式: qq:ID:group/private/user,工具名支持通配符 *", - order=10, - ), - }, - # v2.0: 服务器配置统一为 Claude Desktop mcpServers 规范(JSON) - "servers": { - "claude_config_json": ConfigField( - type=str, - default='{"mcpServers":{}}', - description="Claude Desktop 规范的 MCP 配置(JSON)", - label="🔌 MCP Servers(Claude 规范)", - input_type="textarea", - rows=18, - hint="仅支持 Claude Desktop 的 mcpServers JSON。每个服务器需包含 command(stdio) 或 url(remote)。", - order=1, - ), - "claude_config_guide": ConfigField( - type=str, - default="""示例: -{ - "mcpServers": { - "fetch": { "command": "uvx", "args": ["mcp-server-fetch"] }, - "time": { "url": "https://mcp.api-inference.modelscope.cn/server/mcp-server-time" } - } -} - -可选字段: -- enabled: true/false -- headers: {"Authorization":"Bearer ..."} -- env: {"KEY":"VALUE"} -- transport/type: "streamable_http" | "http" | "sse"(remote 可选,默认 streamable_http) -""", - description="配置说明(只读)", - label="📖 配置说明", - input_type="textarea", - disabled=True, - rows=12, - order=2, - ), - }, - "status": { - "connection_status": ConfigField( - type=str, - default="未初始化", - description="当前 MCP 服务器连接状态和工具列表", - label="📊 连接状态", - input_type="textarea", - disabled=True, - rows=15, - hint="此状态仅在插件启动时更新。查询实时状态请发送 /mcp 命令", - order=1, - ), - }, - } - - @staticmethod - def _fix_config_multiline_strings(config_path: Path) -> bool: - """修复配置文件中的多行字符串格式问题 - - 处理两种情况: - 1. 带转义 \\n 的单行字符串(json.dumps 生成) - 2. 跨越多行但使用普通双引号的字符串(控制字符错误) - - Returns: - bool: 是否进行了修复 - """ - if not config_path.exists(): - return False - - try: - content = config_path.read_text(encoding="utf-8") - - # 情况1: 修复带转义 \n 的单行字符串 - # 匹配: key = "内容包含\n的字符串" - pattern1 = r'^(\s*\w+\s*=\s*)"((?:[^"\\]|\\.)*\\n(?:[^"\\]|\\.)*)"(\s*)$' - - # 情况2: 修复跨越多行的普通双引号字符串 - # 匹配: key = "第一行 - # 第二行 - # 第三行" - pattern2_start = r'^(\s*\w+\s*=\s*)"([^"]*?)$' # 开始行 - pattern2_end = r'^([^"]*)"(\s*)$' # 结束行 - - lines = content.split("\n") - fixed_lines = [] - modified = False - - i = 0 - while i < len(lines): - line = lines[i] - - # 情况1: 单行带转义换行符 - match1 = re.match(pattern1, line) - if match1: - prefix = match1.group(1) - value = match1.group(2) - suffix = match1.group(3) - # 将转义的换行符还原为实际换行符 - unescaped = ( - value.replace("\\n", "\n").replace("\\t", "\t").replace('\\"', '"').replace("\\\\", "\\") - ) - fixed_line = f'{prefix}"""{unescaped}"""{suffix}' - fixed_lines.append(fixed_line) - modified = True - i += 1 - continue - - # 情况2: 跨越多行的字符串 - match2_start = re.match(pattern2_start, line) - if match2_start: - prefix = match2_start.group(1) - first_part = match2_start.group(2) - - # 收集后续行直到找到结束引号 - multiline_parts = [first_part] - j = i + 1 - found_end = False - - while j < len(lines): - next_line = lines[j] - match2_end = re.match(pattern2_end, next_line) - if match2_end: - multiline_parts.append(match2_end.group(1)) - suffix = match2_end.group(2) - found_end = True - j += 1 - break - else: - multiline_parts.append(next_line) - j += 1 - - if found_end and len(multiline_parts) > 1: - # 合并为三引号字符串 - full_value = "\n".join(multiline_parts) - fixed_line = f'{prefix}"""{full_value}"""{suffix}' - fixed_lines.append(fixed_line) - modified = True - i = j - continue - - fixed_lines.append(line) - i += 1 - - if modified: - config_path.write_text("\n".join(fixed_lines), encoding="utf-8") - logger.info("已自动修复配置文件中的多行字符串格式") - return True - - return False - except Exception as e: - logger.warning(f"修复配置文件格式失败: {e}") - return False - - def __init__(self, *args, **kwargs): - global _plugin_instance - - # 在父类初始化前尝试修复配置文件格式 - config_path = Path(__file__).parent / "config.toml" - self._fix_config_multiline_strings(config_path) - - super().__init__(*args, **kwargs) - self._initialized = False - self._status_refresh_running = False - self._status_refresh_task: Optional[asyncio.Task] = None - self._last_persisted_display_hash: str = "" - self._last_servers_config_error: str = "" - _plugin_instance = self - - # 配置 MCP 管理器 - settings = self.config.get("settings", {}) - mcp_manager.configure(settings) - - # v1.4.0: 配置追踪器 - trace_log_path = Path(__file__).parent / "logs" / "trace.jsonl" - tool_call_tracer.configure( - enabled=settings.get("trace_enabled", True), - max_records=settings.get("trace_max_records", 100), - log_enabled=settings.get("trace_log_enabled", False), - log_path=trace_log_path, - ) - - # v1.4.0: 配置缓存 - tool_call_cache.configure( - enabled=settings.get("cache_enabled", False), - ttl=settings.get("cache_ttl", 300), - max_entries=settings.get("cache_max_entries", 200), - exclude_tools=settings.get("cache_exclude_tools", ""), - ) - - # v1.4.0: 配置权限检查器 - perm_config = self.config.get("permissions", {}) - permission_checker.configure( - enabled=perm_config.get("perm_enabled", False), - default_mode=perm_config.get("perm_default_mode", "allow_all"), - rules_json=perm_config.get("perm_rules", "[]"), - quick_deny_groups=perm_config.get("quick_deny_groups", ""), - quick_allow_users=perm_config.get("quick_allow_users", ""), - ) - - # 注册状态变化回调 - mcp_manager.set_status_change_callback(self._update_status_display) - - # v2.0: 服务器配置统一由 servers.claude_config_json 提供(不再通过 WebUI 导入/快速添加写入旧 servers.list) - - # v1.8.0: 初始化工具链管理器 - tool_chain_manager.set_executor(mcp_manager) - self._load_tool_chains() - - def _persist_runtime_displays(self) -> None: - """将 WebUI 只读展示字段写回配置文件,使 WebUI 能正确显示运行状态。""" - try: - config_path = Path(self.plugin_dir) / self.config_file_name - - payload = { - "status.connection_status": str(self.config.get("status", {}).get("connection_status", "") or ""), - "tools.tool_list": str(self.config.get("tools", {}).get("tool_list", "") or ""), - "tool_chains.chains_status": str(self.config.get("tool_chains", {}).get("chains_status", "") or ""), - "react.react_status": str(self.config.get("react", {}).get("react_status", "") or ""), - } - digest = hashlib.sha256(json.dumps(payload, ensure_ascii=False).encode("utf-8")).hexdigest() - if digest == self._last_persisted_display_hash: - return - - self._save_config_to_file(self.config, str(config_path)) - self._last_persisted_display_hash = digest - except Exception as e: - logger.debug(f"写回运行状态到配置文件失败: {e}") - - def _process_quick_add_chain(self) -> None: - """v1.8.0: 处理快速添加工具链表单""" - chains_config = self.config.get("tool_chains", {}) - - # 检查是否触发添加 - add_trigger = chains_config.get("quick_chain_add", "").strip().upper() - if add_trigger != "ADD": - return - - # 获取表单数据 - chain_name = chains_config.get("quick_chain_name", "").strip() - chain_desc = chains_config.get("quick_chain_desc", "").strip() - params_str = chains_config.get("quick_chain_params", "").strip() - steps_str = chains_config.get("quick_chain_steps", "").strip() - - # 验证必填字段 - if not chain_name: - logger.warning("快速添加工具链: 名称不能为空") - self._clear_quick_chain_fields() - return - - if not chain_desc: - logger.warning("快速添加工具链: 描述不能为空") - self._clear_quick_chain_fields() - return - - if not steps_str: - logger.warning("快速添加工具链: 步骤不能为空") - self._clear_quick_chain_fields() - return - - # 解析输入参数 - input_params = {} - if params_str: - for line in params_str.split("\n"): - line = line.strip() - if not line or "=" not in line: - continue - parts = line.split("=", 1) - param_name = parts[0].strip() - param_desc = parts[1].strip() if len(parts) > 1 else param_name - input_params[param_name] = param_desc - - # 解析步骤 - steps = [] - for line in steps_str.split("\n"): - line = line.strip() - if not line: - continue - - parts = line.split("|") - if len(parts) < 2: - logger.warning(f"快速添加工具链: 步骤格式错误: {line}") - continue - - tool_name = parts[0].strip() - args_str = parts[1].strip() if len(parts) > 1 else "{}" - output_key = parts[2].strip() if len(parts) > 2 else "" - - # 解析参数 JSON - try: - args_template = json.loads(args_str) if args_str else {} - except json.JSONDecodeError: - logger.warning(f"快速添加工具链: 参数 JSON 格式错误: {args_str}") - args_template = {} - - steps.append( - { - "tool_name": tool_name, - "args_template": args_template, - "output_key": output_key, - } - ) - - if not steps: - logger.warning("快速添加工具链: 没有有效的步骤") - self._clear_quick_chain_fields() - return - - # 构建新工具链 - new_chain = { - "name": chain_name, - "description": chain_desc, - "input_params": input_params, - "steps": steps, - "enabled": True, - } - - # 获取现有工具链列表 - chains_json = chains_config.get("chains_list", "[]") - try: - chains_list = json.loads(chains_json) if chains_json.strip() else [] - except json.JSONDecodeError: - chains_list = [] - - # 检查是否已存在同名工具链 - for existing in chains_list: - if existing.get("name") == chain_name: - logger.info(f"快速添加: 工具链 {chain_name} 已存在,将更新") - chains_list.remove(existing) - break - - # 添加新工具链 - chains_list.append(new_chain) - new_chains_json = json.dumps(chains_list, ensure_ascii=False, indent=2) - - # 更新配置 - self.config["tool_chains"]["chains_list"] = new_chains_json - - # 清空表单字段 - self._clear_quick_chain_fields() - - # 保存到配置文件 - self._save_chains_list(new_chains_json) - - logger.info(f"快速添加: 已添加工具链 {chain_name} ({len(steps)} 个步骤)") - - def _clear_quick_chain_fields(self) -> None: - """清空快速添加工具链表单字段""" - if "tool_chains" not in self.config: - self.config["tool_chains"] = {} - self.config["tool_chains"]["quick_chain_name"] = "" - self.config["tool_chains"]["quick_chain_desc"] = "" - self.config["tool_chains"]["quick_chain_params"] = "" - self.config["tool_chains"]["quick_chain_steps"] = "" - self.config["tool_chains"]["quick_chain_add"] = "" - - def _save_chains_list(self, chains_json: str) -> None: - """保存工具链列表到配置文件""" - try: - config_path = Path(self.plugin_dir) / self.config_file_name - self._save_config_to_file(self.config, str(config_path)) - logger.info("工具链列表已保存到配置文件") - except Exception as e: - logger.warning(f"保存工具链列表失败: {e}") - - def _load_tool_chains(self) -> None: - """v1.8.0: 加载工具链配置""" - # 先处理快速添加 - self._process_quick_add_chain() - - chains_config = self.config.get("tool_chains", {}) - if not isinstance(chains_config, dict): - chains_config = {} - - # 兼容旧版本:部分版本可能使用 tool_chain 或其他字段名 - if not chains_config: - legacy_section = self.config.get("tool_chain") - if isinstance(legacy_section, dict): - chains_config = legacy_section - self.config["tool_chains"] = legacy_section - - # 兼容旧版本:chains_list 字段名变化 - chains_json = str(chains_config.get("chains_list", "") or "") - if not chains_json.strip(): - for legacy_key in ("list", "chains", "workflow_list", "workflows", "toolchains"): - legacy_val = chains_config.get(legacy_key) - if legacy_val is None: - continue - - if isinstance(legacy_val, str) and legacy_val.strip(): - chains_json = legacy_val - break - - if isinstance(legacy_val, list): - chains_json = json.dumps(legacy_val, ensure_ascii=False, indent=2) - break - - if isinstance(legacy_val, dict): - chains_json = json.dumps([legacy_val], ensure_ascii=False, indent=2) - break - - if chains_json.strip(): - if "tool_chains" not in self.config or not isinstance(self.config.get("tool_chains"), dict): - self.config["tool_chains"] = {} - self.config["tool_chains"]["chains_list"] = chains_json - logger.info( - "检测到旧版 Workflow 配置字段,已自动迁移为 tool_chains.chains_list(请在 WebUI 保存一次以固化)" - ) - - chains_config = self.config.get("tool_chains", {}) - if not isinstance(chains_config, dict): - chains_config = {} - - if not chains_config.get("chains_enabled", True): - logger.info("工具链功能已禁用") - return - - chains_json = str(chains_config.get("chains_list", "[]") or "") - if not chains_json or not chains_json.strip(): - return - - # 清空现有工具链 - tool_chain_manager.clear() - tool_chain_registry.clear() - - # 加载新配置 - loaded, errors = tool_chain_manager.load_from_json(chains_json) - - if errors: - for err in errors: - logger.warning(f"工具链配置错误: {err}") - - if loaded > 0: - logger.info(f"已加载 {loaded} 个工具链") - # 注册工具链到组件系统 - self._register_tool_chains() - self._update_chains_status_display() - - def _register_tool_chains(self) -> None: - """v1.8.1: 将工具链注册到 MaiBot 组件系统,使 LLM 可调用""" - from src.plugin_system.core.component_registry import component_registry - - chain_count = 0 - for chain_name, chain in tool_chain_manager.get_enabled_chains().items(): - try: - expected_tool_name = f"chain_{chain.name}".replace("-", "_").replace(".", "_") - if component_registry.get_component_info(expected_tool_name, ComponentType.TOOL): - chain_count += 1 - logger.debug(f"🔗 工具链已存在,跳过重复注册: {expected_tool_name}") - continue - - info, tool_class = tool_chain_registry.register_chain(chain) - info.plugin_name = self.plugin_name - - if component_registry.register_component(info, tool_class): - chain_count += 1 - logger.info(f"🔗 注册工具链: {tool_class.name}") - else: - logger.warning(f"⚠️ 工具链注册被跳过(可能已存在): {tool_class.name}") - except Exception as e: - logger.error(f"注册工具链 {chain_name} 失败: {e}") - - if chain_count > 0: - logger.info(f"已注册 {chain_count} 个工具链到组件系统") - - def _register_tools_to_react(self) -> int: - """v1.9.0: 将 MCP 工具注册到记忆检索 ReAct 系统(软流程) - - 这样 MaiBot 的 ReAct Agent 在检索记忆时可以调用 MCP 工具, - 实现 LLM 自主决策的多轮工具调用。 - - Returns: - int: 成功注册的工具数量 - """ - try: - from src.memory_system.retrieval_tools import register_memory_retrieval_tool - except ImportError: - logger.warning("无法导入记忆检索工具注册模块,跳过 ReAct 工具注册") - return 0 - - react_config = self.config.get("react", {}) - filter_mode = react_config.get("filter_mode", "whitelist") - tool_filter = react_config.get("tool_filter", "").strip() - - # 解析过滤列表(支持 # 注释) - filter_patterns = [] - for line in tool_filter.split("\n"): - line = line.strip() - if line and not line.startswith("#"): - filter_patterns.append(line) - - registered_count = 0 - disabled_tools = self._get_disabled_tools() - registered_tools = [] # 记录已注册的工具名 - - for tool_key, (tool_info, _) in mcp_manager.all_tools.items(): - tool_name = tool_key.replace("-", "_").replace(".", "_") - - # 跳过禁用的工具 - if tool_name in disabled_tools: - continue - - # 应用过滤器 - if filter_patterns: - matched = any(fnmatch.fnmatch(tool_name, p) or tool_name == p for p in filter_patterns) - - if filter_mode == "whitelist": - # 白名单模式:只注册匹配的 - if not matched: - continue - else: - # 黑名单模式:排除匹配的 - if matched: - continue - - try: - # 转换参数格式 - parameters = self._convert_mcp_params_to_react_format(tool_info.input_schema) - - # 创建异步执行函数(使用闭包捕获 tool_key) - def make_execute_func(tk: str): - async def _execute_func(**kwargs) -> str: - result = await mcp_manager.call_tool(tk, kwargs) - if result.success: - return result.content or "(无返回内容)" - else: - return f"工具调用失败: {result.error}" - - return _execute_func - - execute_func = make_execute_func(tool_key) - - # 注册到 ReAct 系统 - register_memory_retrieval_tool( - name=f"mcp_{tool_name}", - description=f"{tool_info.description} [MCP: {tool_info.server_name}]", - parameters=parameters, - execute_func=execute_func, - ) - - registered_count += 1 - registered_tools.append(f"mcp_{tool_name}") - logger.debug(f"🔄 注册 ReAct 工具: mcp_{tool_name}") - - except Exception as e: - logger.warning(f"注册 ReAct 工具 {tool_name} 失败: {e}") - - if registered_count > 0: - mode_str = "白名单" if filter_mode == "whitelist" else "黑名单" - logger.info(f"已注册 {registered_count} 个 MCP 工具到 ReAct 系统 (过滤模式: {mode_str})") - - # 更新状态显示 - self._update_react_status_display(registered_tools, filter_mode, filter_patterns) - - return registered_count - - def _update_react_status_display( - self, registered_tools: List[str], filter_mode: str, filter_patterns: List[str] - ) -> None: - """更新 ReAct 工具状态显示""" - if not registered_tools: - status_text = "(未注册任何工具)" - else: - mode_str = "白名单" if filter_mode == "whitelist" else "黑名单" - lines = [f"📊 已注册 {len(registered_tools)} 个工具 (模式: {mode_str})"] - if filter_patterns: - lines.append(f"过滤规则: {len(filter_patterns)} 条") - lines.append("") - for tool in registered_tools[:20]: - lines.append(f" • {tool}") - if len(registered_tools) > 20: - lines.append(f" ... 还有 {len(registered_tools) - 20} 个") - status_text = "\n".join(lines) - - # 更新内存配置 - if "react" not in self.config: - self.config["react"] = {} - self.config["react"]["react_status"] = status_text - - def _convert_mcp_params_to_react_format(self, input_schema: Dict) -> List[Dict[str, Any]]: - """将 MCP 工具参数转换为 ReAct 工具参数格式""" - parameters = [] - - if not input_schema: - return parameters - - properties = input_schema.get("properties", {}) - required = input_schema.get("required", []) - - for param_name, param_info in properties.items(): - param_type = param_info.get("type", "string") - description = param_info.get("description", f"参数 {param_name}") - is_required = param_name in required - - parameters.append( - { - "name": param_name, - "type": param_type, - "description": description, - "required": is_required, - } - ) - - return parameters - - def _update_chains_status_display(self) -> None: - """v1.8.0: 更新工具链状态显示""" - chains = tool_chain_manager.get_all_chains() - - if not chains: - status_text = "(无工具链配置)" - else: - lines = [f"📊 已配置 {len(chains)} 个工具链:\n"] - for name, chain in chains.items(): - status = "✅" if chain.enabled else "❌" - # 显示工具链基本信息 - lines.append(f"{status} chain_{name}") - lines.append(f" 描述: {chain.description[:40]}{'...' if len(chain.description) > 40 else ''}") - - # 显示输入参数 - if chain.input_params: - params = ", ".join(chain.input_params.keys()) - lines.append(f" 参数: {params}") - - # 显示步骤 - lines.append(f" 步骤: {len(chain.steps)} 个") - for i, step in enumerate(chain.steps): - opt = " (可选)" if step.optional else "" - out = f" → {step.output_key}" if step.output_key else "" - lines.append(f" {i + 1}. {step.tool_name}{out}{opt}") - lines.append("") - - status_text = "\n".join(lines) - - # 更新内存配置 - if "tool_chains" not in self.config: - self.config["tool_chains"] = {} - self.config["tool_chains"]["chains_status"] = status_text - - def _get_disabled_tools(self) -> set: - """v1.4.0: 获取禁用的工具列表""" - tools_config = self.config.get("tools", {}) - disabled_str = tools_config.get("disabled_tools", "") - return {t.strip() for t in disabled_str.strip().split("\n") if t.strip()} - - async def _async_connect_servers(self) -> None: - """异步连接所有配置的 MCP 服务器(v1.5.0: 并行连接优化)""" - import asyncio - - settings = self.config.get("settings", {}) - - servers_config = self._load_mcp_servers_config() - - if not servers_config: - logger.warning("未配置任何 MCP 服务器") - self._initialized = True - self._update_status_display() - self._update_tool_list_display() - self._update_chains_status_display() - self._start_status_refresher() - self._persist_runtime_displays() - return - - auto_connect = settings.get("auto_connect", True) - if not auto_connect: - logger.info("auto_connect 已禁用,跳过自动连接") - self._initialized = True - self._update_status_display() - self._update_tool_list_display() - self._update_chains_status_display() - self._start_status_refresher() - self._persist_runtime_displays() - return - - tool_prefix = settings.get("tool_prefix", "mcp") - disabled_tools = self._get_disabled_tools() - enable_resources = settings.get("enable_resources", False) - enable_prompts = settings.get("enable_prompts", False) - - # 解析所有服务器配置 - enabled_configs: List[MCPServerConfig] = [] - for idx, server_conf in enumerate(servers_config): - server_name = server_conf.get("name", f"unknown_{idx}") - - if not server_conf.get("enabled", True): - logger.info(f"服务器 {server_name} 已禁用,跳过") - continue - - try: - config = self._parse_server_config(server_conf) - enabled_configs.append(config) - except Exception as e: - logger.error(f"解析服务器 {server_name} 配置失败: {e}") - - if not enabled_configs: - logger.warning("没有已启用的 MCP 服务器") - self._initialized = True - self._update_status_display() - self._update_tool_list_display() - self._update_chains_status_display() - self._start_status_refresher() - self._persist_runtime_displays() - return - - logger.info(f"准备并行连接 {len(enabled_configs)} 个 MCP 服务器") - - # v1.5.0: 并行连接所有服务器 - async def connect_single_server(config: MCPServerConfig) -> Tuple[MCPServerConfig, bool]: - """连接单个服务器""" - logger.info(f"正在连接服务器: {config.name} ({config.transport.value})") - try: - success = await mcp_manager.add_server(config) - if success: - logger.info(f"✅ 服务器 {config.name} 连接成功") - # 获取资源和提示模板 - if enable_resources: - try: - await mcp_manager.fetch_resources_for_server(config.name) - except Exception as e: - logger.warning(f"服务器 {config.name} 获取资源列表失败: {e}") - if enable_prompts: - try: - await mcp_manager.fetch_prompts_for_server(config.name) - except Exception as e: - logger.warning(f"服务器 {config.name} 获取提示模板列表失败: {e}") - else: - logger.warning(f"❌ 服务器 {config.name} 连接失败") - return config, success - except Exception as e: - logger.error(f"❌ 服务器 {config.name} 连接异常: {e}") - return config, False - - # 并行执行所有连接 - start_time = time.time() - results = await asyncio.gather(*[connect_single_server(cfg) for cfg in enabled_configs], return_exceptions=True) - connect_duration = time.time() - start_time - - # 统计连接结果 - success_count = 0 - failed_count = 0 - for result in results: - if isinstance(result, Exception): - failed_count += 1 - logger.error(f"连接任务异常: {result}") - elif isinstance(result, tuple): - _, success = result - if success: - success_count += 1 - else: - failed_count += 1 - - logger.info(f"并行连接完成: {success_count} 成功, {failed_count} 失败, 耗时 {connect_duration:.2f}s") - - # 注册所有工具 - from src.plugin_system.core.component_registry import component_registry - - registered_count = 0 - - for tool_key, (tool_info, _) in mcp_manager.all_tools.items(): - tool_name = tool_key.replace("-", "_").replace(".", "_") - is_disabled = tool_name in disabled_tools - - info, tool_class = mcp_tool_registry.register_tool(tool_key, tool_info, tool_prefix, disabled=is_disabled) - info.plugin_name = self.plugin_name - - if component_registry.register_component(info, tool_class): - registered_count += 1 - status = "🚫" if is_disabled else "✅" - logger.info(f"{status} 注册 MCP 工具: {tool_class.name}") - else: - logger.warning(f"❌ 注册 MCP 工具失败: {tool_class.name}") - - chains_config = self.config.get("tool_chains", {}) - chains_enabled = bool(chains_config.get("chains_enabled", True)) if isinstance(chains_config, dict) else True - chain_count = len(tool_chain_manager.get_enabled_chains()) if chains_enabled else 0 - - # v1.9.0: 注册 MCP 工具到记忆检索 ReAct 系统(软流程) - react_count = 0 - react_config = self.config.get("react", {}) - if react_config.get("react_enabled", False): - react_count = self._register_tools_to_react() - - self._initialized = True - logger.info( - f"MCP 桥接插件初始化完成,已注册 {registered_count} 个工具,{chain_count} 个工具链,{react_count} 个 ReAct 工具" - ) - - # 更新状态显示 - self._update_status_display() - self._update_tool_list_display() - self._update_chains_status_display() - self._start_status_refresher() - self._persist_runtime_displays() - - def _start_status_refresher(self) -> None: - """启动 WebUI 状态刷新任务(不写入磁盘)""" - task = getattr(self, "_status_refresh_task", None) - if task and not task.done(): - return - - self._status_refresh_running = True - self._status_refresh_task = asyncio.create_task(self._status_refresh_loop()) - - async def _stop_status_refresher(self) -> None: - """停止 WebUI 状态刷新任务""" - self._status_refresh_running = False - task = getattr(self, "_status_refresh_task", None) - if task: - task.cancel() - try: - await task - except asyncio.CancelledError: - pass - self._status_refresh_task = None - - async def _status_refresh_loop(self) -> None: - """定期刷新 WebUI 展示字段(状态/工具列表/工具链状态)""" - while getattr(self, "_status_refresh_running", False): - try: - settings = self.config.get("settings", {}) - enabled = bool(settings.get("status_refresh_enabled", True)) - interval = float(settings.get("status_refresh_interval", 10.0) or 10.0) - interval = max(5.0, min(interval, 60.0)) - - if enabled and self._initialized: - self._update_status_display() - self._update_tool_list_display() - self._update_chains_status_display() - self._persist_runtime_displays() - - await asyncio.sleep(interval if enabled else 5.0) - except asyncio.CancelledError: - break - except Exception as e: - logger.debug(f"状态刷新任务异常: {e}") - await asyncio.sleep(5.0) - - def _load_mcp_servers_config(self) -> List[Dict[str, Any]]: - """v2.0: 从 Claude mcpServers JSON 加载服务器配置。 - - - 唯一主入口:config.servers.claude_config_json - - 兼容:若旧版 servers.list 存在且 claude_config_json 为空,会自动迁移并写回内存配置 - """ - servers_section = self.config.get("servers", {}) - if not isinstance(servers_section, dict): - servers_section = {} - - claude_json = str(servers_section.get("claude_config_json", "") or "") - - if not claude_json.strip(): - legacy_list = str(servers_section.get("list", "") or "") - migrated = legacy_servers_list_to_claude_config(legacy_list) - if migrated: - claude_json = migrated - if "servers" not in self.config: - self.config["servers"] = {} - self.config["servers"]["claude_config_json"] = migrated - logger.info("检测到旧版 servers.list,已自动迁移为 Claude mcpServers(请在 WebUI 保存一次以固化)") - - if not claude_json.strip(): - self._last_servers_config_error = ( - "未配置任何 MCP 服务器(请在 WebUI 的「MCP Servers(Claude)」粘贴 mcpServers JSON)" - ) - return [] - - try: - servers = parse_claude_mcp_config(claude_json) - except ClaudeConfigError as e: - self._last_servers_config_error = str(e) - logger.error(f"Claude mcpServers 配置解析失败: {e}") - return [] - except Exception as e: - self._last_servers_config_error = str(e) - logger.error(f"Claude mcpServers 配置解析异常: {e}") - return [] - - self._last_servers_config_error = "" - - # 保留未知字段(如 post_process)供旧功能使用 - raw_mapping: Dict[str, Any] = {} - try: - parsed = json.loads(claude_json) - mapping = parsed.get("mcpServers", parsed) - if isinstance(mapping, dict): - raw_mapping = mapping - except Exception: - raw_mapping = {} - - configs: List[Dict[str, Any]] = [] - for srv in servers: - raw = raw_mapping.get(srv.name, {}) - cfg: Dict[str, Any] = raw.copy() if isinstance(raw, dict) else {} - cfg.update( - { - "name": srv.name, - "enabled": srv.enabled, - "transport": srv.transport, - "command": srv.command, - "args": srv.args, - "env": srv.env, - "url": srv.url, - "headers": srv.headers, - } - ) - configs.append(cfg) - - return configs - - def _parse_server_config(self, conf: Dict) -> MCPServerConfig: - """解析服务器配置字典""" - transport_str = conf.get("transport", "stdio").lower() - - transport_map = { - "stdio": TransportType.STDIO, - "sse": TransportType.SSE, - "http": TransportType.HTTP, - "streamable_http": TransportType.STREAMABLE_HTTP, - } - transport = transport_map.get(transport_str, TransportType.STDIO) - - return MCPServerConfig( - name=conf.get("name", "unnamed"), - enabled=conf.get("enabled", True), - transport=transport, - command=conf.get("command", ""), - args=conf.get("args", []), - env=conf.get("env", {}), - url=conf.get("url", ""), - headers=conf.get("headers", {}), # v1.4.2: 鉴权头支持 - ) - - def _update_tool_list_display(self) -> None: - """v1.4.0: 更新工具列表显示""" - tools = mcp_manager.all_tools - disabled_tools = self._get_disabled_tools() - - lines = [] - by_server: Dict[str, List[str]] = {} - - for tool_key, (tool_info, _) in tools.items(): - tool_name = tool_key.replace("-", "_").replace(".", "_") - if tool_info.server_name not in by_server: - by_server[tool_info.server_name] = [] - - is_disabled = tool_name in disabled_tools - status = " ❌" if is_disabled else "" - by_server[tool_info.server_name].append(f" • {tool_name}{status}") - - for srv_name, tool_list in by_server.items(): - lines.append(f"📦 {srv_name} ({len(tool_list)}个工具):") - lines.extend(tool_list) - lines.append("") - - if not by_server: - lines.append("(无已注册工具)") - - tool_list_text = "\n".join(lines) - - # 更新内存配置 - if "tools" not in self.config: - self.config["tools"] = {} - self.config["tools"]["tool_list"] = tool_list_text - - def _update_status_display(self) -> None: - """更新配置文件中的状态显示字段""" - status = mcp_manager.get_status() - settings = self.config.get("settings", {}) - lines = [] - - cfg_err = str(getattr(self, "_last_servers_config_error", "") or "").strip() - if cfg_err: - lines.append(f"⚠️ 配置: {cfg_err}") - lines.append("") - - lines.append(f"服务器: {status['connected_servers']}/{status['total_servers']} 已连接") - lines.append(f"工具数: {status['total_tools']}") - if settings.get("enable_resources", False): - lines.append(f"资源数: {status.get('total_resources', 0)}") - if settings.get("enable_prompts", False): - lines.append(f"模板数: {status.get('total_prompts', 0)}") - lines.append(f"心跳: {'运行中' if status['heartbeat_running'] else '已停止'}") - lines.append("") - - tools = mcp_manager.all_tools - - for name, info in status.get("servers", {}).items(): - icon = "✅" if info["connected"] else "❌" - lines.append(f"{icon} {name} ({info['transport']})") - - # v1.7.0: 显示断路器状态 - cb_status = info.get("circuit_breaker", {}) - cb_state = cb_status.get("state", "closed") - if cb_state == "open": - lines.append(" ⚡ 断路器: 熔断中") - elif cb_state == "half_open": - lines.append(" ⚡ 断路器: 试探中") - - server_tools = [t.name for key, (t, _) in tools.items() if t.server_name == name] - if server_tools: - for tool_name in server_tools: - lines.append(f" • {tool_name}") - else: - lines.append(" (无工具)") - - if not status.get("servers"): - lines.append("(无服务器)") - - status_text = "\n".join(lines) - - if "status" not in self.config: - self.config["status"] = {} - self.config["status"]["connection_status"] = status_text - - def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]: - """返回插件的所有组件""" - components: List[Tuple[ComponentInfo, Type]] = [] - - # 事件处理器 - components.append((MCPStartupHandler.get_handler_info(), MCPStartupHandler)) - components.append((MCPStopHandler.get_handler_info(), MCPStopHandler)) - - # 命令 - components.append((MCPStatusCommand.get_command_info(), MCPStatusCommand)) - components.append((MCPImportCommand.get_command_info(), MCPImportCommand)) - - # 内置工具 - status_tool_info = ToolInfo( - name=MCPStatusTool.name, - tool_description=MCPStatusTool.description, - enabled=True, - tool_parameters=MCPStatusTool.parameters, - component_type=ComponentType.TOOL, - ) - components.append((status_tool_info, MCPStatusTool)) - - settings = self.config.get("settings", {}) - - if settings.get("enable_resources", False): - read_resource_info = ToolInfo( - name=MCPReadResourceTool.name, - tool_description=MCPReadResourceTool.description, - enabled=True, - tool_parameters=MCPReadResourceTool.parameters, - component_type=ComponentType.TOOL, - ) - components.append((read_resource_info, MCPReadResourceTool)) - - if settings.get("enable_prompts", False): - get_prompt_info = ToolInfo( - name=MCPGetPromptTool.name, - tool_description=MCPGetPromptTool.description, - enabled=True, - tool_parameters=MCPGetPromptTool.parameters, - component_type=ComponentType.TOOL, - ) - components.append((get_prompt_info, MCPGetPromptTool)) - - return components - - def get_status(self) -> Dict[str, Any]: - """获取插件状态""" - return { - "initialized": self._initialized, - "mcp_manager": mcp_manager.get_status(), - "registered_tools": len(mcp_tool_registry._tool_classes), - "trace_records": tool_call_tracer.total_records, - "cache_stats": tool_call_cache.get_stats(), - } - - def get_stats(self) -> Dict[str, Any]: - """获取详细统计信息""" - return mcp_manager.get_all_stats() diff --git a/plugins/MaiBot_MCPBridgePlugin/requirements.txt b/plugins/MaiBot_MCPBridgePlugin/requirements.txt deleted file mode 100644 index 7580f09e..00000000 --- a/plugins/MaiBot_MCPBridgePlugin/requirements.txt +++ /dev/null @@ -1,2 +0,0 @@ -# MCP 桥接插件依赖 -mcp>=1.0.0 diff --git a/plugins/MaiBot_MCPBridgePlugin/tool_chain.py b/plugins/MaiBot_MCPBridgePlugin/tool_chain.py deleted file mode 100644 index 6a1530cc..00000000 --- a/plugins/MaiBot_MCPBridgePlugin/tool_chain.py +++ /dev/null @@ -1,584 +0,0 @@ -""" -MCP Workflow 模块 v1.9.0 -支持用户自定义工作流(硬流程),将多个 MCP 工具按顺序执行 - -双轨制架构: -- 软流程 (ReAct): LLM 自主决策,动态多轮调用工具,灵活但不可预测 -- 硬流程 (Workflow): 用户预定义的工作流,固定流程,可靠可控 - -功能: -- Workflow 定义和管理 -- 顺序执行多个工具(硬流程) -- 支持变量替换(使用前序工具的输出) -- 自动注册为组合工具供 LLM 调用 -- 与 ReAct 软流程互补,用户可选择合适的执行方式 -""" - -import json -import re -import time -from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional, Tuple - -try: - from src.common.logger import get_logger - - logger = get_logger("mcp_tool_chain") -except ImportError: - import logging - - logger = logging.getLogger("mcp_tool_chain") - - -@dataclass -class ToolChainStep: - """工具链步骤""" - - tool_name: str # 要调用的工具名(如 mcp_server_tool) - args_template: Dict[str, Any] = field(default_factory=dict) # 参数模板,支持变量替换 - output_key: str = "" # 输出存储的键名,供后续步骤引用 - description: str = "" # 步骤描述 - optional: bool = False # 是否可选(失败时继续执行) - - def to_dict(self) -> Dict[str, Any]: - return { - "tool_name": self.tool_name, - "args_template": self.args_template, - "output_key": self.output_key, - "description": self.description, - "optional": self.optional, - } - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "ToolChainStep": - return cls( - tool_name=data.get("tool_name", ""), - args_template=data.get("args_template", {}), - output_key=data.get("output_key", ""), - description=data.get("description", ""), - optional=data.get("optional", False), - ) - - -@dataclass -class ToolChainDefinition: - """工具链定义""" - - name: str # 工具链名称(将作为组合工具的名称) - description: str # 工具链描述(供 LLM 理解) - steps: List[ToolChainStep] = field(default_factory=list) # 执行步骤 - input_params: Dict[str, str] = field(default_factory=dict) # 输入参数定义 {参数名: 描述} - enabled: bool = True # 是否启用 - - def to_dict(self) -> Dict[str, Any]: - return { - "name": self.name, - "description": self.description, - "steps": [step.to_dict() for step in self.steps], - "input_params": self.input_params, - "enabled": self.enabled, - } - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "ToolChainDefinition": - steps = [ToolChainStep.from_dict(s) for s in data.get("steps", [])] - return cls( - name=data.get("name", ""), - description=data.get("description", ""), - steps=steps, - input_params=data.get("input_params", {}), - enabled=data.get("enabled", True), - ) - - -@dataclass -class ChainExecutionResult: - """工具链执行结果""" - - success: bool - final_output: str # 最终输出(最后一个步骤的结果) - step_results: List[Dict[str, Any]] = field(default_factory=list) # 每个步骤的结果 - error: str = "" - total_duration_ms: float = 0.0 - - def to_summary(self) -> str: - """生成执行摘要""" - lines = [] - for i, step in enumerate(self.step_results): - status = "✅" if step.get("success") else "❌" - tool = step.get("tool_name", "unknown") - duration = step.get("duration_ms", 0) - lines.append(f"{status} 步骤{i + 1}: {tool} ({duration:.0f}ms)") - if not step.get("success") and step.get("error"): - lines.append(f" 错误: {step['error'][:50]}") - return "\n".join(lines) - - -class ToolChainExecutor: - """工具链执行器""" - - # 变量替换模式: ${step.output_key} 或 ${input.param_name} 或 ${prev} - VAR_PATTERN = re.compile(r"\$\{([^}]+)\}") - - def __init__(self, mcp_manager): - self._mcp_manager = mcp_manager - - def _resolve_tool_key(self, tool_name: str) -> Optional[str]: - """解析工具名,返回有效的 tool_key - - 支持: - - 直接使用 tool_key(如 mcp_server_tool) - - 使用注册后的工具名(会自动转换 - 和 . 为 _) - """ - all_tools = self._mcp_manager.all_tools - - # 直接匹配 - if tool_name in all_tools: - return tool_name - - # 尝试转换后匹配(用户可能使用了注册后的名称) - normalized = tool_name.replace("-", "_").replace(".", "_") - if normalized in all_tools: - return normalized - - # 尝试查找包含该名称的工具 - for key in all_tools.keys(): - if key.endswith(f"_{tool_name}") or key.endswith(f"_{normalized}"): - return key - - return None - - async def execute( - self, - chain: ToolChainDefinition, - input_args: Dict[str, Any], - ) -> ChainExecutionResult: - """执行工具链 - - Args: - chain: 工具链定义 - input_args: 用户输入的参数 - - Returns: - ChainExecutionResult: 执行结果 - """ - start_time = time.time() - step_results = [] - context = { - "input": input_args or {}, # 用户输入,确保不为 None - "step": {}, # 各步骤输出,按 output_key 存储 - "prev": "", # 上一步的输出 - } - - final_output = "" - - # 验证必需的输入参数 - missing_params = [] - for param_name in chain.input_params.keys(): - if param_name not in context["input"]: - missing_params.append(param_name) - - if missing_params: - return ChainExecutionResult( - success=False, - final_output="", - error=f"缺少必需参数: {', '.join(missing_params)}", - total_duration_ms=(time.time() - start_time) * 1000, - ) - - for i, step in enumerate(chain.steps): - step_start = time.time() - step_result = { - "step_index": i, - "tool_name": step.tool_name, - "success": False, - "output": "", - "error": "", - "duration_ms": 0, - } - - try: - # 替换参数中的变量 - resolved_args = self._resolve_args(step.args_template, context) - step_result["resolved_args"] = resolved_args - - # 解析工具名 - tool_key = self._resolve_tool_key(step.tool_name) - if not tool_key: - step_result["error"] = f"工具 {step.tool_name} 不存在" - logger.warning(f"工具链步骤 {i + 1}: 工具 {step.tool_name} 不存在") - - if not step.optional: - step_results.append(step_result) - return ChainExecutionResult( - success=False, - final_output="", - step_results=step_results, - error=f"步骤 {i + 1}: 工具 {step.tool_name} 不存在", - total_duration_ms=(time.time() - start_time) * 1000, - ) - step_results.append(step_result) - continue - - logger.debug(f"工具链步骤 {i + 1}: 调用 {tool_key},参数: {resolved_args}") - - # 调用工具 - result = await self._mcp_manager.call_tool(tool_key, resolved_args) - - step_duration = (time.time() - step_start) * 1000 - step_result["duration_ms"] = step_duration - - if result.success: - step_result["success"] = True - # 确保 content 不为 None - content = result.content if result.content is not None else "" - step_result["output"] = content - - # 更新上下文 - context["prev"] = content - if step.output_key: - context["step"][step.output_key] = content - - final_output = content - content_preview = content[:100] if content else "(空)" - logger.debug(f"工具链步骤 {i + 1} 成功: {content_preview}...") - else: - step_result["error"] = result.error or "未知错误" - logger.warning(f"工具链步骤 {i + 1} 失败: {result.error}") - - if not step.optional: - step_results.append(step_result) - return ChainExecutionResult( - success=False, - final_output="", - step_results=step_results, - error=f"步骤 {i + 1} ({step.tool_name}) 失败: {result.error}", - total_duration_ms=(time.time() - start_time) * 1000, - ) - - except Exception as e: - step_duration = (time.time() - step_start) * 1000 - step_result["duration_ms"] = step_duration - step_result["error"] = str(e) - logger.error(f"工具链步骤 {i + 1} 异常: {e}") - - if not step.optional: - step_results.append(step_result) - return ChainExecutionResult( - success=False, - final_output="", - step_results=step_results, - error=f"步骤 {i + 1} ({step.tool_name}) 异常: {e}", - total_duration_ms=(time.time() - start_time) * 1000, - ) - - step_results.append(step_result) - - total_duration = (time.time() - start_time) * 1000 - - return ChainExecutionResult( - success=True, - final_output=final_output, - step_results=step_results, - total_duration_ms=total_duration, - ) - - def _resolve_args(self, args_template: Dict[str, Any], context: Dict[str, Any]) -> Dict[str, Any]: - """解析参数模板,替换变量 - - 支持的变量格式: - - ${input.param_name}: 用户输入的参数 - - ${step.output_key}: 某个步骤的输出 - - ${prev}: 上一步的输出 - - ${prev.field}: 上一步输出(JSON)的某个字段 - """ - resolved = {} - - for key, value in args_template.items(): - if isinstance(value, str): - resolved[key] = self._substitute_vars(value, context) - elif isinstance(value, dict): - resolved[key] = self._resolve_args(value, context) - elif isinstance(value, list): - resolved[key] = [self._substitute_vars(v, context) if isinstance(v, str) else v for v in value] - else: - resolved[key] = value - - return resolved - - def _substitute_vars(self, template: str, context: Dict[str, Any]) -> str: - """替换字符串中的变量""" - - def replacer(match): - var_path = match.group(1) - return self._get_var_value(var_path, context) - - return self.VAR_PATTERN.sub(replacer, template) - - def _get_var_value(self, var_path: str, context: Dict[str, Any]) -> str: - """获取变量值 - - Args: - var_path: 变量路径,如 "input.query", "step.search_result", "prev", "prev.id" - context: 上下文 - """ - parts = self._parse_var_path(var_path) - - if not parts: - return "" - - # 获取根对象 - root = parts[0] - if root not in context: - logger.warning(f"变量 {var_path} 的根 '{root}' 不存在") - return "" - - value = context[root] - - # 遍历路径 - for part in parts[1:]: - if isinstance(value, str): - parsed = self._try_parse_json(value) - if parsed is not None: - value = parsed - - if isinstance(value, dict): - value = value.get(part, "") - elif isinstance(value, list): - if part.isdigit(): - idx = int(part) - value = value[idx] if 0 <= idx < len(value) else "" - else: - value = "" - else: - value = "" - - # 确保返回字符串 - if isinstance(value, (dict, list)): - return json.dumps(value, ensure_ascii=False) - if value is None: - return "" - if value == "": - return "" - return str(value) - - def _try_parse_json(self, value: str) -> Optional[Any]: - """尝试将字符串解析为 JSON 对象,失败则返回 None。""" - if not value: - return None - try: - return json.loads(value) - except json.JSONDecodeError: - return None - - def _parse_var_path(self, var_path: str) -> List[str]: - """解析变量路径,支持点号与下标写法。 - - 支持: - - step.geo.return.0.location - - step.geo.return[0].location - - step.geo['return'][0]['location'] - """ - if not var_path: - return [] - - tokens: List[str] = [] - buf: List[str] = [] - in_bracket = False - in_quote = False - quote_char = "" - - def flush_buf() -> None: - if buf: - token = "".join(buf).strip() - if token: - tokens.append(token) - buf.clear() - - i = 0 - while i < len(var_path): - ch = var_path[i] - - if not in_bracket and ch == ".": - flush_buf() - i += 1 - continue - - if not in_bracket and ch == "[": - flush_buf() - in_bracket = True - in_quote = False - quote_char = "" - i += 1 - continue - - if in_bracket and not in_quote and ch == "]": - flush_buf() - in_bracket = False - i += 1 - continue - - if in_bracket and ch in ("'", '"'): - if not in_quote: - in_quote = True - quote_char = ch - i += 1 - continue - if quote_char == ch: - in_quote = False - quote_char = "" - i += 1 - continue - - if in_bracket and not in_quote: - if ch.isspace(): - i += 1 - continue - if ch == ",": - i += 1 - continue - - buf.append(ch) - i += 1 - - flush_buf() - - if in_bracket or in_quote: - return [p for p in var_path.split(".") if p] - - return tokens - - -class ToolChainManager: - """工具链管理器""" - - _instance: Optional["ToolChainManager"] = None - - def __new__(cls): - if cls._instance is None: - cls._instance = super().__new__(cls) - cls._instance._initialized = False - return cls._instance - - def __init__(self): - if self._initialized: - return - self._initialized = True - self._chains: Dict[str, ToolChainDefinition] = {} - self._executor: Optional[ToolChainExecutor] = None - - def set_executor(self, mcp_manager) -> None: - """设置执行器""" - self._executor = ToolChainExecutor(mcp_manager) - - def add_chain(self, chain: ToolChainDefinition) -> bool: - """添加工具链""" - if not chain.name: - logger.error("工具链名称不能为空") - return False - - if chain.name in self._chains: - logger.warning(f"工具链 {chain.name} 已存在,将被覆盖") - - self._chains[chain.name] = chain - logger.info(f"已添加工具链: {chain.name} ({len(chain.steps)} 个步骤)") - return True - - def remove_chain(self, name: str) -> bool: - """移除工具链""" - if name in self._chains: - del self._chains[name] - logger.info(f"已移除工具链: {name}") - return True - return False - - def get_chain(self, name: str) -> Optional[ToolChainDefinition]: - """获取工具链""" - return self._chains.get(name) - - def get_all_chains(self) -> Dict[str, ToolChainDefinition]: - """获取所有工具链""" - return self._chains.copy() - - def get_enabled_chains(self) -> Dict[str, ToolChainDefinition]: - """获取所有启用的工具链""" - return {name: chain for name, chain in self._chains.items() if chain.enabled} - - async def execute_chain( - self, - chain_name: str, - input_args: Dict[str, Any], - ) -> ChainExecutionResult: - """执行工具链""" - chain = self._chains.get(chain_name) - if not chain: - return ChainExecutionResult( - success=False, - final_output="", - error=f"工具链 {chain_name} 不存在", - ) - - if not chain.enabled: - return ChainExecutionResult( - success=False, - final_output="", - error=f"工具链 {chain_name} 已禁用", - ) - - if not self._executor: - return ChainExecutionResult( - success=False, - final_output="", - error="工具链执行器未初始化", - ) - - return await self._executor.execute(chain, input_args) - - def load_from_json(self, json_str: str) -> Tuple[int, List[str]]: - """从 JSON 字符串加载工具链配置 - - Returns: - (成功加载数量, 错误列表) - """ - errors = [] - loaded = 0 - - try: - data = json.loads(json_str) if json_str.strip() else [] - except json.JSONDecodeError as e: - return 0, [f"JSON 解析失败: {e}"] - - if not isinstance(data, list): - data = [data] - - for i, item in enumerate(data): - try: - chain = ToolChainDefinition.from_dict(item) - if not chain.name: - errors.append(f"第 {i + 1} 个工具链缺少名称") - continue - if not chain.steps: - errors.append(f"工具链 {chain.name} 没有步骤") - continue - - self.add_chain(chain) - loaded += 1 - except Exception as e: - errors.append(f"第 {i + 1} 个工具链解析失败: {e}") - - return loaded, errors - - def export_to_json(self, pretty: bool = True) -> str: - """导出所有工具链为 JSON""" - chains_data = [chain.to_dict() for chain in self._chains.values()] - if pretty: - return json.dumps(chains_data, ensure_ascii=False, indent=2) - return json.dumps(chains_data, ensure_ascii=False) - - def clear(self) -> None: - """清空所有工具链""" - self._chains.clear() - - -# 全局工具链管理器实例 -tool_chain_manager = ToolChainManager() diff --git a/prompts/en-US/maidairy_timing.prompt b/prompts/en-US/maidairy_timing.prompt deleted file mode 100644 index e7b785be..00000000 --- a/prompts/en-US/maidairy_timing.prompt +++ /dev/null @@ -1,22 +0,0 @@ -你是一个对话节奏与时间感知分析模块,同时负责自我反思。你的任务是根据对话上下文和系统提供的时间戳信息,分析: - -【时间感知分析】 -1. 对话持续时长:当前对话已经进行了多久 -2. 回复间隔:用户上次发言距今多久、用户的平均回复速度如何 -3. 建议等待时长:结合对话内容和时间规律,建议下次等待多少秒比较合适 -4. 时间相关洞察: - - 用户是否可能正在忙(回复变慢) - - 用户是否正在积极对话(回复很快) - - 当前时段(深夜/早晨/工作时间等)是否适合继续聊 - - 对话是否已经持续太久,用户可能需要休息 - - 是否应该主动结束对话 - -【自我反思分析】 -1. 人设一致性:是否符合设定的人格特质、说话风格是否一致、是否有不符合身份的言论 -2. 回复合理性:是否有逻辑漏洞、是否回应了用户的核心诉求、是否有过当或不当言论 -3. 认知局限性:是否对某些情况理解不足、是否缺乏必要信息、是否做出了过度推断 - -要求: -- 输出简洁(4-6 句话),时间感知分析和自我反思分析各占一半 -- 重点关注对话节奏的变化趋势和助手自身的人设一致性 -- 直接输出分析结果,不要有格式标题或分段标记 diff --git a/prompts/ja-JP/maidairy_timing.prompt b/prompts/ja-JP/maidairy_timing.prompt deleted file mode 100644 index e7b785be..00000000 --- a/prompts/ja-JP/maidairy_timing.prompt +++ /dev/null @@ -1,22 +0,0 @@ -你是一个对话节奏与时间感知分析模块,同时负责自我反思。你的任务是根据对话上下文和系统提供的时间戳信息,分析: - -【时间感知分析】 -1. 对话持续时长:当前对话已经进行了多久 -2. 回复间隔:用户上次发言距今多久、用户的平均回复速度如何 -3. 建议等待时长:结合对话内容和时间规律,建议下次等待多少秒比较合适 -4. 时间相关洞察: - - 用户是否可能正在忙(回复变慢) - - 用户是否正在积极对话(回复很快) - - 当前时段(深夜/早晨/工作时间等)是否适合继续聊 - - 对话是否已经持续太久,用户可能需要休息 - - 是否应该主动结束对话 - -【自我反思分析】 -1. 人设一致性:是否符合设定的人格特质、说话风格是否一致、是否有不符合身份的言论 -2. 回复合理性:是否有逻辑漏洞、是否回应了用户的核心诉求、是否有过当或不当言论 -3. 认知局限性:是否对某些情况理解不足、是否缺乏必要信息、是否做出了过度推断 - -要求: -- 输出简洁(4-6 句话),时间感知分析和自我反思分析各占一半 -- 重点关注对话节奏的变化趋势和助手自身的人设一致性 -- 直接输出分析结果,不要有格式标题或分段标记 diff --git a/prompts/zh-CN/action.prompt b/prompts/zh-CN/action.prompt deleted file mode 100644 index 91831b2a..00000000 --- a/prompts/zh-CN/action.prompt +++ /dev/null @@ -1,5 +0,0 @@ -{action_name} -动作描述:{action_description} -使用条件{parallel_text}: -{action_require} -{{"action":"{action_name}",{action_parameters}, "target_message_id":"消息id(m+数字)"}} \ No newline at end of file diff --git a/prompts/zh-CN/chat_target_group1.prompt b/prompts/zh-CN/chat_target_group1.prompt deleted file mode 100644 index 77e89bcc..00000000 --- a/prompts/zh-CN/chat_target_group1.prompt +++ /dev/null @@ -1 +0,0 @@ -你正在qq群里聊天,下面是群里正在聊的内容: \ No newline at end of file diff --git a/prompts/zh-CN/chat_target_group2.prompt b/prompts/zh-CN/chat_target_group2.prompt deleted file mode 100644 index 5b71bace..00000000 --- a/prompts/zh-CN/chat_target_group2.prompt +++ /dev/null @@ -1 +0,0 @@ -正在群里聊天 \ No newline at end of file diff --git a/prompts/zh-CN/chat_target_private1.prompt b/prompts/zh-CN/chat_target_private1.prompt deleted file mode 100644 index 3e86c71f..00000000 --- a/prompts/zh-CN/chat_target_private1.prompt +++ /dev/null @@ -1 +0,0 @@ -你正在和{sender_name}聊天,这是你们之前聊的内容: \ No newline at end of file diff --git a/prompts/zh-CN/chat_target_private2.prompt b/prompts/zh-CN/chat_target_private2.prompt deleted file mode 100644 index 9225ec82..00000000 --- a/prompts/zh-CN/chat_target_private2.prompt +++ /dev/null @@ -1 +0,0 @@ -和{sender_name}聊天 \ No newline at end of file diff --git a/prompts/zh-CN/lpmm_get_knowledge.prompt b/prompts/zh-CN/lpmm_get_knowledge.prompt deleted file mode 100644 index 2ade0d0f..00000000 --- a/prompts/zh-CN/lpmm_get_knowledge.prompt +++ /dev/null @@ -1,10 +0,0 @@ -你是一个专门获取知识的助手。你的名字是{bot_name}。现在是{time_now}。 -群里正在进行的聊天内容: -{chat_history} - -现在,{sender}发送了内容:{target_message},你想要回复ta。 -请仔细分析聊天内容,考虑以下几点: -1. 内容中是否包含需要查询信息的问题 -2. 是否有明确的知识获取指令 - -If you need to use the search tool, please directly call the function "lpmm_search_knowledge". If you do not need to use any tool, simply output "No tool needed". \ No newline at end of file diff --git a/prompts/zh-CN/maidairy_chat.prompt b/prompts/zh-CN/maidairy_chat.prompt index 043b6dc1..aac8e8ac 100644 --- a/prompts/zh-CN/maidairy_chat.prompt +++ b/prompts/zh-CN/maidairy_chat.prompt @@ -1,5 +1,5 @@ 你的任务是分析聊天和聊天中的互动情况。 -你需要关注 {bot_name}(AI) 与不同用户的对话来为选择正确的动作和行为提供建议 +你需要关注 {bot_name}(AI) 与不同用户的对话来为选择正确的动作和行为以及搜集信息提供建议 【参考信息】 {bot_name}的人设:{identity} @@ -8,28 +8,28 @@ 你需要根据提供的参考信息,当前场景和输出规则来进行分析 在当前场景中,用户正在与AI麦麦进行聊天互动,你的任务不是生成对用户可见的发言,而是进行分析来指导AI进行回复。 “分析”应该体现你对当前局面的判断、你的建议、你的下一步计划,以及你为什么这样想。 +你需要先搜集能够帮助{bot_name}回复的信息,然后再给出回复意见 你可以使用这些工具: - wait(seconds) - 暂时停止对话,等待(seconds)秒,把话语权交给用户,等待对方新的发言。 -- stop() - 结束对话,不进行任何回复,直到对方有新消息。 -- reply():当你判断现在应该正式对用户发出一条可见回复时调用。调用后系统会基于你当前这轮的想法生成一条真正展示给用户的回复。 -- no_reply():当你判断现在不应该发言,应该继续内部思考时调用。这个工具不会做任何外部行为,只会继续下一轮循环。 -{file_tools_section} +- stop() - 当你判断{bot_name}现在不应该发言,结束对话,不进行任何回复,直到对方有新消息。 +- reply():当你判断{bot_name}现在应该正式对用户发出一条可见回复时调用。调用后系统会基于你当前这轮的想法生成一条真正展示给用户的回复。 +- query_jargon():当你认为某些词的含义不明确,或用户询问某些词的含义,需要进行查询 +- 其他定义的工具,你可以视情况合适使用 工具使用规则: -1.如果麦麦已经回复,但用户暂时没有新的回复,且没有新信息需要搜集,使用wait或者stop进行等待 +1.如果{bot_name}已经回复,但用户暂时没有新的回复,且没有新信息需要搜集,使用wait或者stop进行等待 2.如果用户有新发言,但是你评估用户还有后续发言尚未发送,可以适当等待让用户说完 3.在特定情况下也可以连续回复,例如想要追问,或者补充自己先前的发言,可以不使用stop或者wait -4.如果你想指导麦麦直接发言,可以不使用任何工具 +4.你需要控制自己发言的频率,如果用户一对一聊天,可以以均匀地频率发言,如果用户较多,不要每句都回复,控制回复频率。当你决定暂时不发言,可以使用wait暂时等待一定时间或者stop等待新消息 +5.如果存在用户的疑问,或者对某些概念的不确定,你可以使用工具来搜集信息或者查询含义,你可以使用多个工具 -你的输出规则: +你的分析规则: 1. 默认直接输出你当前的最新分析,不要重复之前的分析内容。 2. 最新分析应尽量具体,贴近上下文,不要空泛重复。 -3. 如果你认为现在更适合等待用户补充,可以调用 `wait(seconds)`。 -4. 如果你认为应当结束当前对话,不回复任何内容,可以调用 `stop()`。 -5. 只有在确实需要等待或停止时才调用工具,否则优先直接输出分析想法。 -6. 如果你刚刚做了工具调用,下一轮应结合工具结果继续输出新的分析。 -7. 分析应服务于后续决策,而不是机械复述用户内容。 +3. 如果你刚刚做了工具调用,下一轮应结合工具结果继续输出新的分析。 +4. 你需要评估哪些话是对{bot_name}的发言,哪些是用户之间的交流或者自言自语,不要频繁插入无关的话题。 +5. 如果你上一轮没有发言,需要重新进行分析,输出新的分析内容,不要重复上一轮的分析内容 -现在,请你输出你的分析: +现在,请你输出你对{bot_name}发言的分析,你必须先输出文本内容的分析,然后再进行工具调用: diff --git a/prompts/zh-CN/maidairy_cognition.prompt b/prompts/zh-CN/maidairy_cognition.prompt deleted file mode 100644 index 7c5c814a..00000000 --- a/prompts/zh-CN/maidairy_cognition.prompt +++ /dev/null @@ -1,11 +0,0 @@ -你是一个认知感知分析模块。你的任务是根据对话上下文,分析对话中用户的: -1. 核心意图(如:寻求帮助、纯粹聊天、请求任务、发泄情绪、获取信息、表达观点等) -2. 认知状态(如:明确具体、模糊试探、犹豫不决、困惑迷茫、思路清晰、逻辑混乱等) -3. 隐含目的(如:解决问题、获得安慰、打发时间、寻求认同、交换想法、表达自我等) - -要求: -- 只分析用户(对话中 role=user 的内容),不要分析助手自己 -- 根据用户最新发言重点分析,同时结合上下文理解深层动机 -- 输出简洁(2-4 句话),不要太长 -- 如果信息太少无法判断,就说信息不足,给出初步印象 -- 直接输出分析结果,不要有格式标题 diff --git a/prompts/zh-CN/maidairy_emotion.prompt b/prompts/zh-CN/maidairy_emotion.prompt deleted file mode 100644 index b8440527..00000000 --- a/prompts/zh-CN/maidairy_emotion.prompt +++ /dev/null @@ -1,11 +0,0 @@ -你是一个情绪感知分析模块。你的任务是根据对话上下文,分析对话中用户的: -1. 当前情绪状态(如:开心、沮丧、焦虑、平静、兴奋、愤怒等) -2. 言语态度(如:友好、冷淡、热情、敷衍、试探、认真、调侃等) -3. 潜在的情感需求(如:需要倾听、需要鼓励、想要倾诉、只是闲聊等) - -要求: -- 只分析用户(对话中 role=user 的内容),不要分析助手自己 -- 根据用户最新发言重点分析,同时结合上下文理解变化趋势 -- 输出简洁(2-4 句话),不要太长 -- 如果信息太少无法判断,就说信息不足,给出初步印象 -- 直接输出分析结果,不要有格式标题 diff --git a/prompts/zh-CN/maidairy_replyer.prompt b/prompts/zh-CN/maidairy_replyer.prompt index f06f086f..9e13f45b 100644 --- a/prompts/zh-CN/maidairy_replyer.prompt +++ b/prompts/zh-CN/maidairy_replyer.prompt @@ -1,7 +1,10 @@ -你的名字是麦麦,也有人叫你牢麦,麦叠,你是一个大二在读女大学生,现在正在上网和群友聊天,有点攻击性; -你正在群里聊天,现在请你读读之前的聊天记录,然后给出日常且口语化的回复, -尽量简短一些。 -请注意把握聊天内容,不要回复的太有条理。 -你的风格平淡但不失讽刺,不过分兴奋,很简短。可以参考贴吧,知乎和微博的回复风格。很平淡和白话,不浮夸不长篇大论,b站评论风格,但一定注意不要过分修辞和复杂句。 -请注意不要输出多余内容(包括不必要的前后缀,冒号,括号,表情包,at或 @等 ),只输出发言内容就好。 -最好一次对一个话题进行回复,免得啰嗦或者回复内容太乱。 +你正在qq群里聊天,下面是群里正在聊的内容,其中包含聊天记录和聊天中的图片 +其中标注 {bot_name}(你) 的发言是你自己的发言,请注意区分: + +{time_block} + +{identity} +你正在群里聊天,现在请你读读之前的聊天记录,把握当前的话题,然后给出日常且口语化的回复, +尽量简短一些。最好一次对一个话题进行回复,免得啰嗦或者回复内容太乱。请注意把握聊天内容。 +{reply_style} +请注意不要输出多余内容(包括不必要的前后缀,冒号,括号,表情包,at或 @等 ),只输出发言内容就好。 \ No newline at end of file diff --git a/prompts/zh-CN/maidairy_timing.prompt b/prompts/zh-CN/maidairy_timing.prompt deleted file mode 100644 index e7b785be..00000000 --- a/prompts/zh-CN/maidairy_timing.prompt +++ /dev/null @@ -1,22 +0,0 @@ -你是一个对话节奏与时间感知分析模块,同时负责自我反思。你的任务是根据对话上下文和系统提供的时间戳信息,分析: - -【时间感知分析】 -1. 对话持续时长:当前对话已经进行了多久 -2. 回复间隔:用户上次发言距今多久、用户的平均回复速度如何 -3. 建议等待时长:结合对话内容和时间规律,建议下次等待多少秒比较合适 -4. 时间相关洞察: - - 用户是否可能正在忙(回复变慢) - - 用户是否正在积极对话(回复很快) - - 当前时段(深夜/早晨/工作时间等)是否适合继续聊 - - 对话是否已经持续太久,用户可能需要休息 - - 是否应该主动结束对话 - -【自我反思分析】 -1. 人设一致性:是否符合设定的人格特质、说话风格是否一致、是否有不符合身份的言论 -2. 回复合理性:是否有逻辑漏洞、是否回应了用户的核心诉求、是否有过当或不当言论 -3. 认知局限性:是否对某些情况理解不足、是否缺乏必要信息、是否做出了过度推断 - -要求: -- 输出简洁(4-6 句话),时间感知分析和自我反思分析各占一半 -- 重点关注对话节奏的变化趋势和助手自身的人设一致性 -- 直接输出分析结果,不要有格式标题或分段标记 diff --git a/prompts/zh-CN/private_replyer_self.prompt b/prompts/zh-CN/private_replyer_self.prompt deleted file mode 100644 index f58136ef..00000000 --- a/prompts/zh-CN/private_replyer_self.prompt +++ /dev/null @@ -1,14 +0,0 @@ -{knowledge_prompt}{tool_info_block}{extra_info_block} -{expression_habits_block}{memory_retrieval}{jargon_explanation} - -你正在和{sender_name}聊天,这是你们之前聊的内容: -{time_block} -{dialogue_prompt} - -你现在想补充说明你刚刚自己的发言内容:{target},原因是{reason} -请你根据聊天内容,组织一条新回复。注意,{target} 是刚刚你自己的发言,你要在这基础上进一步发言,请按照你自己的角度来继续进行回复。注意保持上下文的连贯性。 -{identity} -{chat_prompt}尽量简短一些。{keywords_reaction_prompt}请注意把握聊天内容,不要回复的太有条理,可以有个性。 -{reply_style} -请注意不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只输出回复内容。 -{moderation_prompt}不要输出多余内容(包括冒号和引号,括号,表情包,at或 @等 )。 \ No newline at end of file diff --git a/prompts/zh-CN/replyer_light.prompt b/prompts/zh-CN/replyer_light.prompt deleted file mode 100644 index 8e3a425a..00000000 --- a/prompts/zh-CN/replyer_light.prompt +++ /dev/null @@ -1,18 +0,0 @@ -{knowledge_prompt}{tool_info_block}{extra_info_block} -{expression_habits_block}{memory_retrieval}{jargon_explanation} - -你正在qq群里聊天,下面是群里正在聊的内容,其中包含聊天记录和聊天中的图片 -其中标注 {bot_name}(你) 的发言是你自己的发言,请注意区分: -{time_block} -{dialogue_prompt} - -{reply_target_block}。 -{planner_reasoning} -{identity} -{chat_prompt}你正在群里聊天,现在请你读读之前的聊天记录,然后给出日常且口语化的回复, -尽量简短一些。{keywords_reaction_prompt} -请注意把握聊天内容,不要回复的太有条理。 -{reply_style} -请注意不要输出多余内容(包括不必要的前后缀,冒号,括号,表情包,at或 @等 ),只输出发言内容就好。 -最好一次对一个话题进行回复,免得啰嗦或者回复内容太乱。 -现在,你说: \ No newline at end of file diff --git a/prompts/zh-CN/tool_executor.prompt b/prompts/zh-CN/tool_executor.prompt deleted file mode 100644 index 23f2b043..00000000 --- a/prompts/zh-CN/tool_executor.prompt +++ /dev/null @@ -1,11 +0,0 @@ -你是一个专门执行工具的助手。你的名字是{bot_name}。现在是{time_now}。 -群里正在进行的聊天内容: -{chat_history} - -现在,{sender}发送了内容:{target_message},你想要回复ta。 -请仔细分析聊天内容,考虑以下几点: -1. 内容中是否包含需要查询信息的问题 -2. 是否有明确的工具使用指令 -你可以选择多个动作 - -If you need to use tools, please directly call the corresponding tool function. If you do not need to use any tool, simply output "No tool needed". \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index fa006ff0..22db5201 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,7 +19,8 @@ dependencies = [ "jieba>=0.42.1", "json-repair>=0.47.6", "maim-message>=0.6.2", - "maibot-plugin-sdk>=2.0.0", + "maibot-plugin-sdk>=2.1.0", + "mcp", "msgpack>=1.1.2", "numpy>=2.2.6", "openai>=1.95.0", diff --git a/pytests/common_test/test_database_migration_foundation.py b/pytests/common_test/test_database_migration_foundation.py new file mode 100644 index 00000000..ffec2b6a --- /dev/null +++ b/pytests/common_test/test_database_migration_foundation.py @@ -0,0 +1,924 @@ +"""数据库迁移基础设施测试。""" + +from pathlib import Path +from typing import List, Optional, Tuple + +from sqlalchemy import text +from sqlalchemy.engine import Connection, Engine +from sqlmodel import SQLModel, create_engine + +import json +import msgpack +import pytest + +from src.common.database import database as database_module +from src.common.database.migrations import ( + BaseSchemaVersionDetector, + BaseMigrationProgressReporter, + DatabaseSchemaSnapshot, + DatabaseMigrationBootstrapper, + DatabaseMigrationState, + DatabaseMigrationManager, + EMPTY_SCHEMA_VERSION, + LATEST_SCHEMA_VERSION, + LEGACY_V1_SCHEMA_VERSION, + MigrationExecutionContext, + MigrationPlan, + MigrationRegistry, + MigrationStep, + ResolvedSchemaVersion, + SchemaVersionResolver, + SchemaVersionSource, + SQLiteSchemaInspector, + SQLiteUserVersionStore, + build_default_migration_registry, + build_default_schema_version_resolver, + create_database_migration_bootstrapper, +) + + +class FixedVersionDetector(BaseSchemaVersionDetector): + """测试用固定版本探测器。""" + + @property + def name(self) -> str: + """返回测试探测器名称。 + + Returns: + str: 探测器名称。 + """ + return "fixed_version_detector" + + def detect_version(self, snapshot: DatabaseSchemaSnapshot) -> Optional[int]: + """根据测试表是否存在返回固定版本。 + + Args: + snapshot: 当前数据库结构快照。 + + Returns: + Optional[int]: 若存在测试表则返回固定版本,否则返回 ``None``。 + """ + if snapshot.has_table("legacy_records"): + return 2 + return None + + +class FakeMigrationProgressReporter(BaseMigrationProgressReporter): + """测试用迁移进度上报器。""" + + def __init__(self) -> None: + """初始化测试用进度上报器。""" + self.events: List[Tuple[str, Optional[int], Optional[int], Optional[str]]] = [] + + def open(self) -> None: + """记录打开事件。""" + self.events.append(("open", None, None, None)) + + def close(self) -> None: + """记录关闭事件。""" + self.events.append(("close", None, None, None)) + + def start( + self, + total_records: int, + total_tables: int, + description: str = "总迁移进度", + table_unit_name: str = "表", + record_unit_name: str = "记录", + ) -> None: + """记录启动事件。 + + Args: + total_records: 任务记录总数。 + total_tables: 任务表总数。 + description: 任务描述。 + table_unit_name: 表级进度单位名称。 + record_unit_name: 记录级进度单位名称。 + """ + del table_unit_name, record_unit_name + self.events.append(("start", total_records, total_tables, description)) + + def advance( + self, + records: int = 0, + completed_tables: int = 0, + item_name: Optional[str] = None, + ) -> None: + """记录推进事件。 + + Args: + records: 推进的记录数。 + completed_tables: 已完成的表数。 + item_name: 当前完成的项目名称。 + """ + self.events.append(("advance", records, completed_tables, item_name)) + + +def _create_sqlite_engine(database_file: Path) -> Engine: + """创建测试用 SQLite 引擎。 + + Args: + database_file: 测试数据库文件路径。 + + Returns: + Engine: SQLite 引擎实例。 + """ + return create_engine( + f"sqlite:///{database_file}", + echo=False, + connect_args={"check_same_thread": False}, + ) + + +def _create_current_schema(connection: Connection) -> None: + """创建当前最新版本的数据库结构。 + + Args: + connection: 当前数据库连接。 + """ + import src.common.database.database_model # noqa: F401 + + SQLModel.metadata.create_all(connection) + + +def _create_legacy_v1_schema_with_sample_data(connection: Connection) -> None: + """创建带示例数据的旧版 ``0.x`` 数据库结构。 + + Args: + connection: 当前数据库连接。 + """ + connection.execute( + text( + """ + CREATE TABLE chat_streams ( + id INTEGER PRIMARY KEY, + stream_id TEXT NOT NULL, + create_time REAL NOT NULL, + last_active_time REAL NOT NULL, + platform TEXT NOT NULL, + user_id TEXT, + group_id TEXT, + group_name TEXT + ) + """ + ) + ) + connection.execute( + text( + """ + CREATE TABLE messages ( + id INTEGER PRIMARY KEY, + message_id TEXT NOT NULL, + time REAL NOT NULL, + chat_id TEXT NOT NULL, + chat_info_platform TEXT, + user_id TEXT, + user_nickname TEXT, + chat_info_group_id TEXT, + chat_info_group_name TEXT, + is_mentioned INTEGER, + is_at INTEGER, + processed_plain_text TEXT, + display_message TEXT, + is_emoji INTEGER, + is_picid INTEGER, + is_command INTEGER, + is_notify INTEGER, + additional_config TEXT, + priority_mode TEXT + ) + """ + ) + ) + connection.execute( + text( + """ + CREATE TABLE action_records ( + id INTEGER PRIMARY KEY, + action_id TEXT NOT NULL, + time REAL NOT NULL, + action_reasoning TEXT, + action_name TEXT NOT NULL, + action_data TEXT, + action_prompt_display TEXT, + chat_id TEXT + ) + """ + ) + ) + connection.execute( + text( + """ + CREATE TABLE expression ( + id INTEGER PRIMARY KEY, + situation TEXT NOT NULL, + style TEXT NOT NULL, + content_list TEXT, + count INTEGER, + last_active_time REAL NOT NULL, + chat_id TEXT, + create_date REAL, + checked INTEGER, + rejected INTEGER, + modified_by TEXT + ) + """ + ) + ) + connection.execute( + text( + """ + CREATE TABLE jargon ( + id INTEGER PRIMARY KEY, + content TEXT NOT NULL, + raw_content TEXT, + meaning TEXT, + chat_id TEXT, + is_global INTEGER, + count INTEGER, + is_jargon INTEGER, + last_inference_count INTEGER, + is_complete INTEGER, + inference_with_context TEXT, + inference_content_only TEXT + ) + """ + ) + ) + + connection.execute( + text( + """ + INSERT INTO chat_streams ( + id, + stream_id, + create_time, + last_active_time, + platform, + user_id, + group_id, + group_name + ) VALUES ( + 1, + 'session-1', + 1710000000.0, + 1710000300.0, + 'qq', + 'user-1', + 'group-1', + '测试群' + ) + """ + ) + ) + connection.execute( + text( + """ + INSERT INTO messages ( + id, + message_id, + time, + chat_id, + chat_info_platform, + user_id, + user_nickname, + chat_info_group_id, + chat_info_group_name, + is_mentioned, + is_at, + processed_plain_text, + display_message, + is_emoji, + is_picid, + is_command, + is_notify, + additional_config, + priority_mode + ) VALUES ( + 1, + 'msg-1', + 1710000010.0, + 'session-1', + 'qq', + 'user-1', + '测试用户', + 'group-1', + '测试群', + 1, + 0, + '你好', + '你好呀', + 0, + 1, + 0, + 1, + '{"source":"legacy"}', + 'high' + ) + """ + ) + ) + connection.execute( + text( + """ + INSERT INTO action_records ( + id, + action_id, + time, + action_reasoning, + action_name, + action_data, + action_prompt_display, + chat_id + ) VALUES ( + 1, + 'action-1', + 1710000020.0, + '需要调用工具', + 'search', + '{"query":"MaiBot"}', + '执行搜索', + 'session-1' + ) + """ + ) + ) + connection.execute( + text( + """ + INSERT INTO expression ( + id, + situation, + style, + content_list, + count, + last_active_time, + chat_id, + create_date, + checked, + rejected, + modified_by + ) VALUES ( + 1, + '打招呼', + '可爱', + '["你好呀","早上好"]', + 3, + 1710000030.0, + 'session-1', + 1710000040.0, + 1, + 0, + 'ai' + ) + """ + ) + ) + connection.execute( + text( + """ + INSERT INTO jargon ( + id, + content, + raw_content, + meaning, + chat_id, + is_global, + count, + is_jargon, + last_inference_count, + is_complete, + inference_with_context, + inference_content_only + ) VALUES ( + 1, + '上分', + '["上分"]', + '提高排名', + 'session-1', + 0, + 5, + 1, + 2, + 1, + '{"guess":"context"}', + '{"guess":"content"}' + ) + """ + ) + ) + + +def test_user_version_store_can_read_and_write_versions(tmp_path: Path) -> None: + """应支持读取与写入 SQLite ``user_version``。""" + engine = _create_sqlite_engine(tmp_path / "version_store.db") + version_store = SQLiteUserVersionStore() + + with engine.begin() as connection: + assert version_store.read_version(connection) == 0 + version_store.write_version(connection, 7) + + with engine.connect() as connection: + assert version_store.read_version(connection) == 7 + + +def test_schema_inspector_can_extract_tables_and_columns(tmp_path: Path) -> None: + """应能提取 SQLite 数据库的表与列结构。""" + engine = _create_sqlite_engine(tmp_path / "schema_inspector.db") + inspector = SQLiteSchemaInspector() + + with engine.begin() as connection: + connection.execute( + text( + """ + CREATE TABLE legacy_records ( + id INTEGER PRIMARY KEY, + payload TEXT NOT NULL, + created_at TEXT + ) + """ + ) + ) + + with engine.connect() as connection: + snapshot = inspector.inspect(connection) + + assert snapshot.has_table("legacy_records") + assert snapshot.has_column("legacy_records", "payload") + assert not snapshot.has_column("legacy_records", "missing_column") + table_schema = snapshot.get_table("legacy_records") + + assert table_schema is not None + assert table_schema.column_names() == ["created_at", "id", "payload"] + + +def test_resolver_can_identify_empty_database(tmp_path: Path) -> None: + """空数据库应被解析为版本 ``0``。""" + engine = _create_sqlite_engine(tmp_path / "empty_resolver.db") + resolver = SchemaVersionResolver() + + with engine.connect() as connection: + resolved_version = resolver.resolve(connection) + + assert resolved_version.version == 0 + assert resolved_version.source == SchemaVersionSource.EMPTY_DATABASE + assert resolved_version.snapshot is not None + assert resolved_version.snapshot.is_empty() + + +def test_resolver_can_use_detector_for_unversioned_legacy_database(tmp_path: Path) -> None: + """未写入 ``user_version`` 的历史库应支持通过探测器识别版本。""" + engine = _create_sqlite_engine(tmp_path / "legacy_resolver.db") + resolver = SchemaVersionResolver(detectors=[FixedVersionDetector()]) + + with engine.begin() as connection: + connection.execute(text("CREATE TABLE legacy_records (id INTEGER PRIMARY KEY, payload TEXT NOT NULL)")) + + with engine.connect() as connection: + resolved_version = resolver.resolve(connection) + + assert resolved_version.version == 2 + assert resolved_version.source == SchemaVersionSource.DETECTOR + assert resolved_version.detector_name == "fixed_version_detector" + + +def test_registry_and_manager_can_execute_registered_steps(tmp_path: Path) -> None: + """迁移编排器应能按顺序执行已注册步骤并更新版本号。""" + engine = _create_sqlite_engine(tmp_path / "manager.db") + executed_steps: List[str] = [] + + def migrate_0_to_1(context: MigrationExecutionContext) -> None: + """测试迁移步骤 0 -> 1。 + + Args: + context: 当前迁移步骤执行上下文。 + """ + executed_steps.append(f"{context.current_version}->{context.target_version}:step_0_to_1") + context.connection.execute(text("CREATE TABLE sample_records (id INTEGER PRIMARY KEY, name TEXT NOT NULL)")) + + def migrate_1_to_2(context: MigrationExecutionContext) -> None: + """测试迁移步骤 1 -> 2。 + + Args: + context: 当前迁移步骤执行上下文。 + """ + executed_steps.append(f"{context.current_version}->{context.target_version}:step_1_to_2") + context.connection.execute(text("ALTER TABLE sample_records ADD COLUMN email TEXT")) + + registry = MigrationRegistry( + steps=[ + MigrationStep( + version_from=0, + version_to=1, + name="create_sample_records", + description="创建示例表。", + handler=migrate_0_to_1, + ), + MigrationStep( + version_from=1, + version_to=2, + name="add_sample_email", + description="为示例表增加邮箱字段。", + handler=migrate_1_to_2, + ), + ] + ) + manager = DatabaseMigrationManager(engine=engine, registry=registry) + + migration_plan = manager.migrate() + + assert migration_plan.step_count() == 2 + assert executed_steps == ["0->2:step_0_to_1", "1->2:step_1_to_2"] + + with engine.connect() as connection: + version_store = SQLiteUserVersionStore() + snapshot = SQLiteSchemaInspector().inspect(connection) + recorded_version = version_store.read_version(connection) + + assert recorded_version == 2 + assert snapshot.has_table("sample_records") + assert snapshot.has_column("sample_records", "email") + + +def test_manager_can_report_step_progress(tmp_path: Path) -> None: + """迁移编排器应支持通过上下文上报步骤进度。""" + engine = _create_sqlite_engine(tmp_path / "manager_progress.db") + reporter_instances: List[FakeMigrationProgressReporter] = [] + + def _build_reporter() -> BaseMigrationProgressReporter: + """构建测试用进度上报器。 + + Returns: + BaseMigrationProgressReporter: 测试用进度上报器实例。 + """ + reporter = FakeMigrationProgressReporter() + reporter_instances.append(reporter) + return reporter + + def migrate_1_to_2(context: MigrationExecutionContext) -> None: + """测试迁移步骤 ``1 -> 2`` 的进度上报。 + + Args: + context: 当前迁移步骤执行上下文。 + """ + context.start_progress(total_tables=3, total_records=30, description="总迁移进度") + context.advance_progress(records=10, completed_tables=1, item_name="chat_sessions") + context.advance_progress(records=10, completed_tables=1, item_name="mai_messages") + context.advance_progress(records=10, completed_tables=1, item_name="tool_records") + context.connection.execute(text("CREATE TABLE progress_records (id INTEGER PRIMARY KEY, value TEXT NOT NULL)")) + + with engine.begin() as connection: + SQLiteUserVersionStore().write_version(connection, 1) + + registry = MigrationRegistry( + steps=[ + MigrationStep( + version_from=1, + version_to=2, + name="progress_step", + description="测试进度上报。", + handler=migrate_1_to_2, + ) + ] + ) + manager = DatabaseMigrationManager( + engine=engine, + registry=registry, + progress_reporter_factory=_build_reporter, + ) + + migration_plan = manager.migrate() + + assert migration_plan.step_count() == 1 + assert len(reporter_instances) == 1 + assert reporter_instances[0].events == [ + ("open", None, None, None), + ("start", 30, 3, "总迁移进度"), + ("advance", 10, 1, "chat_sessions"), + ("advance", 10, 1, "mai_messages"), + ("advance", 10, 1, "tool_records"), + ("close", None, None, None), + ] + + +def test_default_resolver_can_identify_unversioned_latest_database(tmp_path: Path) -> None: + """默认解析器应能识别未写入版本号的最新结构数据库。""" + engine = _create_sqlite_engine(tmp_path / "latest_resolver.db") + resolver = build_default_schema_version_resolver() + + with engine.begin() as connection: + _create_current_schema(connection) + + with engine.connect() as connection: + resolved_version = resolver.resolve(connection) + + assert resolved_version.version == LATEST_SCHEMA_VERSION + assert resolved_version.source == SchemaVersionSource.DETECTOR + assert resolved_version.detector_name == "latest_schema_detector" + + +def test_default_resolver_can_identify_legacy_v1_database(tmp_path: Path) -> None: + """默认解析器应能识别未写版本号的旧版 ``0.x`` 数据库。""" + engine = _create_sqlite_engine(tmp_path / "legacy_v1_resolver.db") + resolver = build_default_schema_version_resolver() + + with engine.begin() as connection: + _create_legacy_v1_schema_with_sample_data(connection) + + with engine.connect() as connection: + resolved_version = resolver.resolve(connection) + + assert resolved_version.version == LEGACY_V1_SCHEMA_VERSION + assert resolved_version.source == SchemaVersionSource.DETECTOR + assert resolved_version.detector_name == "legacy_v1_schema_detector" + + +def test_bootstrapper_can_finalize_unversioned_latest_database(tmp_path: Path) -> None: + """已是最新结构但未写版本号的数据库应直接补写 ``user_version``。""" + engine = _create_sqlite_engine(tmp_path / "latest_finalize.db") + bootstrapper = create_database_migration_bootstrapper(engine) + + with engine.begin() as connection: + _create_current_schema(connection) + + migration_state = bootstrapper.prepare_database() + bootstrapper.finalize_database(migration_state) + + assert not migration_state.requires_migration() + assert migration_state.resolved_version.version == LATEST_SCHEMA_VERSION + assert migration_state.resolved_version.source == SchemaVersionSource.DETECTOR + + with engine.connect() as connection: + recorded_version = SQLiteUserVersionStore().read_version(connection) + + assert recorded_version == LATEST_SCHEMA_VERSION + + +def test_bootstrapper_can_finalize_empty_database_to_latest_version(tmp_path: Path) -> None: + """空库在建表完成后应回写最新 ``user_version``。""" + engine = _create_sqlite_engine(tmp_path / "bootstrap_empty.db") + bootstrapper = create_database_migration_bootstrapper(engine) + + migration_state = bootstrapper.prepare_database() + + assert not migration_state.requires_migration() + assert migration_state.resolved_version.version == EMPTY_SCHEMA_VERSION + assert migration_state.target_version == LATEST_SCHEMA_VERSION + + with engine.begin() as connection: + _create_current_schema(connection) + + bootstrapper.finalize_database(migration_state) + + with engine.connect() as connection: + recorded_version = SQLiteUserVersionStore().read_version(connection) + + assert recorded_version == LATEST_SCHEMA_VERSION + + +def test_bootstrapper_runs_registered_steps_for_versioned_database(tmp_path: Path) -> None: + """启动桥接器应在已登记旧版本数据库上执行注册迁移步骤。""" + engine = _create_sqlite_engine(tmp_path / "bootstrap_registered.db") + execution_marks: List[str] = [] + + def migrate_1_to_2(context: MigrationExecutionContext) -> None: + """测试桥接器迁移步骤 ``1 -> 2``。 + + Args: + context: 当前迁移步骤执行上下文。 + """ + execution_marks.append(f"step={context.step_name},index={context.step_index}") + context.connection.execute(text("ALTER TABLE bootstrap_records ADD COLUMN email TEXT")) + + with engine.begin() as connection: + connection.execute( + text("CREATE TABLE bootstrap_records (id INTEGER PRIMARY KEY, value TEXT NOT NULL)") + ) + SQLiteUserVersionStore().write_version(connection, 1) + + registry = MigrationRegistry( + steps=[ + MigrationStep( + version_from=1, + version_to=2, + name="bootstrap_add_email", + description="为桥接器测试表增加邮箱字段。", + handler=migrate_1_to_2, + ) + ] + ) + bootstrapper = DatabaseMigrationBootstrapper( + manager=DatabaseMigrationManager(engine=engine, registry=registry), + latest_schema_version=2, + ) + + migration_state = bootstrapper.prepare_database() + + assert migration_state.resolved_version.version == 2 + assert migration_state.target_version == 2 + assert execution_marks == ["step=bootstrap_add_email,index=1"] + + with engine.connect() as connection: + snapshot = SQLiteSchemaInspector().inspect(connection) + recorded_version = SQLiteUserVersionStore().read_version(connection) + + assert recorded_version == 2 + assert snapshot.has_table("bootstrap_records") + assert snapshot.has_column("bootstrap_records", "email") + + +def test_default_bootstrapper_can_migrate_legacy_v1_database(tmp_path: Path) -> None: + """默认桥接器应能把旧版 ``0.x`` 数据库整体迁移到最新结构。""" + engine = _create_sqlite_engine(tmp_path / "legacy_v1_to_v2.db") + bootstrapper = create_database_migration_bootstrapper(engine) + + with engine.begin() as connection: + _create_legacy_v1_schema_with_sample_data(connection) + + migration_state = bootstrapper.prepare_database() + bootstrapper.finalize_database(migration_state) + + assert not migration_state.requires_migration() + assert migration_state.resolved_version.version == LATEST_SCHEMA_VERSION + assert migration_state.resolved_version.source == SchemaVersionSource.PRAGMA + + with engine.connect() as connection: + recorded_version = SQLiteUserVersionStore().read_version(connection) + snapshot = SQLiteSchemaInspector().inspect(connection) + message_row = connection.execute( + text( + """ + SELECT session_id, processed_plain_text, additional_config, raw_content + FROM mai_messages + WHERE message_id = 'msg-1' + """ + ) + ).mappings().one() + action_row = connection.execute( + text( + """ + SELECT session_id, action_name, action_display_prompt + FROM action_records + WHERE action_id = 'action-1' + """ + ) + ).mappings().one() + tool_row = connection.execute( + text( + """ + SELECT session_id, tool_name, tool_display_prompt + FROM tool_records + WHERE tool_id = 'action-1' + """ + ) + ).mappings().one() + expression_row = connection.execute( + text( + """ + SELECT session_id, content_list, modified_by + FROM expressions + WHERE id = 1 + """ + ) + ).mappings().one() + jargon_row = connection.execute( + text( + """ + SELECT session_id_dict, raw_content, inference_with_content_only + FROM jargons + WHERE id = 1 + """ + ) + ).mappings().one() + + assert recorded_version == LATEST_SCHEMA_VERSION + assert snapshot.has_table("__legacy_v1_messages") + assert snapshot.has_table("chat_sessions") + assert snapshot.has_table("mai_messages") + assert snapshot.has_table("tool_records") + + unpacked_raw_content = msgpack.unpackb(message_row["raw_content"], raw=False) + additional_config = json.loads(message_row["additional_config"]) + expression_content_list = json.loads(expression_row["content_list"]) + jargon_session_id_dict = json.loads(jargon_row["session_id_dict"]) + jargon_raw_content = json.loads(jargon_row["raw_content"]) + + assert message_row["session_id"] == "session-1" + assert message_row["processed_plain_text"] == "你好" + assert unpacked_raw_content == [{"type": "text", "data": "你好呀"}] + assert additional_config == {"priority_mode": "high", "source": "legacy"} + assert action_row["session_id"] == "session-1" + assert action_row["action_name"] == "search" + assert action_row["action_display_prompt"] == "执行搜索" + assert tool_row["session_id"] == "session-1" + assert tool_row["tool_name"] == "search" + assert tool_row["tool_display_prompt"] == "执行搜索" + assert expression_row["session_id"] == "session-1" + assert expression_row["modified_by"] == "AI" + assert expression_content_list == ["你好呀", "早上好"] + assert jargon_session_id_dict == {"session-1": 5} + assert jargon_raw_content == ["上分"] + assert jargon_row["inference_with_content_only"] == '{"guess":"content"}' + + +def test_legacy_v1_migration_reports_table_progress(tmp_path: Path) -> None: + """旧版迁移步骤应按目标表数量推进总进度。""" + engine = _create_sqlite_engine(tmp_path / "legacy_progress.db") + reporter_instances: List[FakeMigrationProgressReporter] = [] + + def _build_reporter() -> BaseMigrationProgressReporter: + """构建测试用进度上报器。 + + Returns: + BaseMigrationProgressReporter: 测试用进度上报器实例。 + """ + reporter = FakeMigrationProgressReporter() + reporter_instances.append(reporter) + return reporter + + with engine.begin() as connection: + _create_legacy_v1_schema_with_sample_data(connection) + + manager = DatabaseMigrationManager( + engine=engine, + registry=build_default_migration_registry(), + resolver=build_default_schema_version_resolver(), + progress_reporter_factory=_build_reporter, + ) + + migration_plan = manager.migrate(target_version=LATEST_SCHEMA_VERSION) + + assert migration_plan.step_count() == 1 + assert len(reporter_instances) == 1 + reporter_events = reporter_instances[0].events + + assert reporter_events[0] == ("open", None, None, None) + assert reporter_events[1] == ("start", 6, 12, "总迁移进度") + assert reporter_events[-1] == ("close", None, None, None) + assert reporter_events.count(("advance", 1, 0, None)) == 6 + assert reporter_events.count(("advance", 0, 1, "chat_sessions")) == 1 + assert reporter_events.count(("advance", 0, 1, "thinking_questions")) == 1 + assert len([event for event in reporter_events if event[0] == "advance"]) == 18 + + +def test_initialize_database_calls_bootstrapper_before_create_all( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + """数据库初始化入口应先准备迁移,再建表、补迁移并收尾。""" + call_order: List[str] = [] + + def _fake_prepare_database() -> DatabaseMigrationState: + """返回测试用迁移状态。 + + Returns: + DatabaseMigrationState: 不包含迁移步骤的测试状态。 + """ + call_order.append("prepare_database") + return DatabaseMigrationState( + resolved_version=ResolvedSchemaVersion(version=0, source=SchemaVersionSource.EMPTY_DATABASE), + target_version=LATEST_SCHEMA_VERSION, + plan=MigrationPlan( + current_version=EMPTY_SCHEMA_VERSION, + target_version=LATEST_SCHEMA_VERSION, + steps=[], + ), + ) + + def _fake_create_all(bind) -> None: + """记录建表调用。 + + Args: + bind: 传入的数据库绑定对象。 + """ + del bind + call_order.append("create_all") + + def _fake_migrate_action_records() -> None: + """记录轻量补迁移调用。""" + call_order.append("migrate_action_records") + + def _fake_finalize_database(migration_state: DatabaseMigrationState) -> None: + """记录迁移收尾调用。 + + Args: + migration_state: 当前数据库迁移状态。 + """ + del migration_state + call_order.append("finalize_database") + + monkeypatch.setattr(database_module, "_db_initialized", False) + monkeypatch.setattr(database_module, "_DB_DIR", tmp_path / "data") + monkeypatch.setattr(database_module._migration_bootstrapper, "prepare_database", _fake_prepare_database) + monkeypatch.setattr(database_module._migration_bootstrapper, "finalize_database", _fake_finalize_database) + monkeypatch.setattr(database_module.SQLModel.metadata, "create_all", _fake_create_all) + monkeypatch.setattr(database_module, "_migrate_action_records_to_tool_records", _fake_migrate_action_records) + + database_module.initialize_database() + + assert call_order == [ + "prepare_database", + "create_all", + "migrate_action_records", + "finalize_database", + ] diff --git a/pytests/test_maisaka_message_adapter.py b/pytests/test_maisaka_message_adapter.py new file mode 100644 index 00000000..d872253c --- /dev/null +++ b/pytests/test_maisaka_message_adapter.py @@ -0,0 +1,55 @@ +from datetime import datetime +from pathlib import Path + +import sys + +from src.chat.message_receive.message import SessionMessage +from src.common.data_models.message_component_data_model import MessageSequence, TextComponent +from src.llm_models.payload_content.tool_option import ToolCall +from src.maisaka.message_adapter import build_message, get_message_kind, get_message_role, get_tool_call_id, get_tool_calls + + +PROJECT_ROOT = Path(__file__).resolve().parents[1] +if str(PROJECT_ROOT) not in sys.path: + sys.path.insert(0, str(PROJECT_ROOT)) + + +def test_build_message_returns_session_message_with_maisaka_metadata() -> None: + timestamp = datetime.now() + tool_call = ToolCall( + call_id="call-1", + func_name="reply", + args={"message_id": "msg-1"}, + ) + raw_message = MessageSequence(components=[TextComponent(text="内部消息内容")]) + + message = build_message( + role="assistant", + content="展示消息内容", + message_kind="perception", + source="assistant", + tool_call_id="call-1", + tool_calls=[tool_call], + timestamp=timestamp, + message_id="maisaka-msg-1", + raw_message=raw_message, + display_text="展示消息内容", + ) + + assert isinstance(message, SessionMessage) + assert message.initialized is True + assert message.message_id == "maisaka-msg-1" + assert message.timestamp == timestamp + assert message.processed_plain_text == "展示消息内容" + assert message.display_message == "展示消息内容" + assert message.raw_message is raw_message + + assert get_message_role(message) == "assistant" + assert get_message_kind(message) == "perception" + assert get_tool_call_id(message) == "call-1" + + restored_tool_calls = get_tool_calls(message) + assert len(restored_tool_calls) == 1 + assert restored_tool_calls[0].call_id == "call-1" + assert restored_tool_calls[0].func_name == "reply" + assert restored_tool_calls[0].args == {"message_id": "msg-1"} diff --git a/pytests/test_plugin_runtime.py b/pytests/test_plugin_runtime.py index e3247f05..5c9f39b0 100644 --- a/pytests/test_plugin_runtime.py +++ b/pytests/test_plugin_runtime.py @@ -5,6 +5,7 @@ from pathlib import Path from types import SimpleNamespace +from typing import Any, Awaitable, Callable, Dict, List, Optional import asyncio import json @@ -1831,395 +1832,445 @@ class TestMaiMessages: assert msg.llm_response_content == "new response" -# ─── WorkflowExecutor 测试 ──────────────────────────────── +class _FakeHookSupervisor: + """用于 Hook 分发测试的简化 Supervisor。""" + + def __init__( + self, + group_name: str, + component_registry: Any, + handlers: Dict[str, Callable[[Dict[str, Any]], Awaitable[Dict[str, Any]] | Dict[str, Any]]], + call_log: List[tuple[str, str]], + ) -> None: + """初始化测试用 Supervisor。 + + Args: + group_name: 运行时分组名称。 + component_registry: 组件注册表实例。 + handlers: 处理器映射,键为 `plugin_id.component_name`。 + call_log: 记录调用顺序的列表。 + """ + + self._group_name = group_name + self.component_registry = component_registry + self._handlers = handlers + self._call_log = call_log + + @property + def group_name(self) -> str: + """返回当前测试 Supervisor 的分组名称。""" + + return self._group_name + + async def invoke_plugin( + self, + method: str, + plugin_id: str, + component_name: str, + args: Optional[Dict[str, Any]] = None, + timeout_ms: int = 30000, + ) -> SimpleNamespace: + """模拟调用插件组件。 + + Args: + method: RPC 方法名。 + plugin_id: 目标插件 ID。 + component_name: 目标组件名称。 + args: 调用参数。 + timeout_ms: 超时配置,测试中仅用于保持接口一致。 + + Returns: + SimpleNamespace: 仅包含 `payload` 字段的简化响应对象。 + """ + + del method + del timeout_ms + + full_name = f"{plugin_id}.{component_name}" + handler = self._handlers[full_name] + self._call_log.append((plugin_id, component_name)) + result = handler(dict(args or {})) + if asyncio.iscoroutine(result): + result = await result + return SimpleNamespace(payload=result) -class TestWorkflowExecutor: - """Host-side Workflow 执行器测试(新 pipeline 模型)""" +# ─── HookDispatcher 测试 ──────────────────────────────── + + +class TestHookDispatcher: + """命名 Hook 分发器测试。""" + + @staticmethod + def _import_dispatcher_modules(monkeypatch: pytest.MonkeyPatch) -> tuple[Any, Any]: + """导入 Hook 分发相关模块,并屏蔽配置初始化触发的退出。 + + Args: + monkeypatch: pytest 的 monkeypatch 工具。 + + Returns: + tuple[Any, Any]: `ComponentRegistry` 与 `HookDispatcher` 类型。 + """ + + monkeypatch.setattr(sys, "exit", lambda code=0: None) + from src.plugin_runtime.host.component_registry import ComponentRegistry + from src.plugin_runtime.host.hook_dispatcher import HookDispatcher + + return ComponentRegistry, HookDispatcher @pytest.mark.asyncio - async def test_empty_pipeline_completes(self): - """无任何 workflow_step 注册时,pipeline 全阶段跳过,状态 completed""" - from src.plugin_runtime.host.component_registry import ComponentRegistry - from src.plugin_runtime.host.workflow_executor import WorkflowExecutor + async def test_empty_hook_returns_original_kwargs(self, monkeypatch: pytest.MonkeyPatch) -> None: + """未注册处理器时应直接返回原始参数。""" - reg = ComponentRegistry() - executor = WorkflowExecutor(reg) + ComponentRegistry, HookDispatcher = self._import_dispatcher_modules(monkeypatch) - async def mock_invoke(plugin_id, comp_name, args): - return {"hook_result": "continue"} + dispatcher = HookDispatcher() + supervisor = _FakeHookSupervisor("builtin", ComponentRegistry(), {}, []) - result, final_msg, ctx = await executor.execute( - mock_invoke, - message={"plain_text": "test"}, - ) - assert result.status == "completed" - assert result.return_message == "workflow completed" - assert len(ctx.timings) == 6 # 6 stages + result = await dispatcher.invoke_hook("heart_fc.cycle_start", [supervisor], session_id="s-1") + + assert result.hook_name == "heart_fc.cycle_start" + assert result.kwargs == {"session_id": "s-1"} + assert result.aborted is False @pytest.mark.asyncio - async def test_blocking_hook_modifies_message(self): - """blocking hook 可以修改消息""" - from src.plugin_runtime.host.component_registry import ComponentRegistry - from src.plugin_runtime.host.workflow_executor import WorkflowExecutor + async def test_blocking_hook_modifies_kwargs(self, monkeypatch: pytest.MonkeyPatch) -> None: + """blocking 处理器可以修改参数。""" - reg = ComponentRegistry() - reg.register_component( + ComponentRegistry, HookDispatcher = self._import_dispatcher_modules(monkeypatch) + + registry = ComponentRegistry() + registry.register_component( "upper", - "workflow_step", + "HOOK_HANDLER", "p1", { - "stage": "pre_process", - "priority": 10, - "blocking": True, + "hook": "heart_fc.cycle_start", + "mode": "blocking", + "order": "normal", }, ) - executor = WorkflowExecutor(reg) - - async def mock_invoke(plugin_id, comp_name, args): - msg = args.get("message", {}) - return { - "hook_result": "continue", - "modified_message": {**msg, "plain_text": msg.get("plain_text", "").upper()}, - } - - result, final_msg, ctx = await executor.execute( - mock_invoke, - message={"plain_text": "hello"}, + dispatcher = HookDispatcher() + supervisor = _FakeHookSupervisor( + "builtin", + registry, + { + "p1.upper": lambda args: { + "success": True, + "action": "continue", + "modified_kwargs": { + "session_id": args["session_id"], + "text": str(args["text"]).upper(), + }, + } + }, + [], ) - assert result.status == "completed" - assert final_msg["plain_text"] == "HELLO" - assert len(ctx.modification_log) == 1 - assert ctx.modification_log[0].stage == "pre_process" + + result = await dispatcher.invoke_hook("heart_fc.cycle_start", [supervisor], session_id="s-1", text="hello") + + assert result.kwargs["session_id"] == "s-1" + assert result.kwargs["text"] == "HELLO" + assert result.aborted is False @pytest.mark.asyncio - async def test_abort_stops_pipeline(self): - """HookResult.ABORT 立即终止 pipeline""" - from src.plugin_runtime.host.component_registry import ComponentRegistry - from src.plugin_runtime.host.workflow_executor import WorkflowExecutor + async def test_abort_stops_following_blocking_handlers(self, monkeypatch: pytest.MonkeyPatch) -> None: + """blocking 处理器的 abort 应阻止后续 blocking 处理器执行。""" - reg = ComponentRegistry() - reg.register_component( - "blocker", - "workflow_step", + ComponentRegistry, HookDispatcher = self._import_dispatcher_modules(monkeypatch) + + registry = ComponentRegistry() + registry.register_component( + "stopper", + "HOOK_HANDLER", "p1", - { - "stage": "pre_process", - "priority": 10, - "blocking": True, - }, + {"hook": "heart_fc.cycle_start", "mode": "blocking", "order": "early"}, ) - executor = WorkflowExecutor(reg) - - async def mock_invoke(plugin_id, comp_name, args): - return {"hook_result": "abort"} - - result, _, ctx = await executor.execute( - mock_invoke, - message={"plain_text": "test"}, - ) - assert result.status == "aborted" - assert result.stopped_at == "pre_process" - - @pytest.mark.asyncio - async def test_skip_stage(self): - """HookResult.SKIP_STAGE 跳过当前阶段剩余 hook""" - from src.plugin_runtime.host.component_registry import ComponentRegistry - from src.plugin_runtime.host.workflow_executor import WorkflowExecutor - - reg = ComponentRegistry() - # high-priority hook 返回 skip_stage - reg.register_component( - "skipper", - "workflow_step", - "p1", - { - "stage": "ingress", - "priority": 100, - "blocking": True, - }, - ) - # low-priority hook 不应被执行 - reg.register_component( - "checker", - "workflow_step", + registry.register_component( + "after_stop", + "HOOK_HANDLER", "p2", - { - "stage": "ingress", - "priority": 1, - "blocking": True, - }, + {"hook": "heart_fc.cycle_start", "mode": "blocking", "order": "normal"}, + ) + call_log: List[tuple[str, str]] = [] + dispatcher = HookDispatcher() + supervisor = _FakeHookSupervisor( + "builtin", + registry, + { + "p1.stopper": lambda args: {"success": True, "action": "abort"}, + "p2.after_stop": lambda args: {"success": True, "action": "continue"}, + }, + call_log, ) - executor = WorkflowExecutor(reg) - call_log = [] + result = await dispatcher.invoke_hook("heart_fc.cycle_start", [supervisor], cycle_id="c-1") - async def mock_invoke(plugin_id, comp_name, args): - call_log.append(comp_name) - if comp_name == "skipper": - return {"hook_result": "skip_stage"} - return {"hook_result": "continue"} - - result, _, _ = await executor.execute(mock_invoke, message={"plain_text": "test"}) - assert result.status == "completed" - # 只有 skipper 被调用,checker 被跳过 - assert call_log == ["skipper"] + assert result.aborted is True + assert result.stopped_by == "p1.stopper" + assert call_log == [("p1", "stopper")] @pytest.mark.asyncio - async def test_pre_filter(self): - """filter 条件不匹配时跳过 hook""" - from src.plugin_runtime.host.component_registry import ComponentRegistry - from src.plugin_runtime.host.workflow_executor import WorkflowExecutor + async def test_observe_handler_runs_in_background_without_mutation( + self, + monkeypatch: pytest.MonkeyPatch, + ) -> None: + """observe 处理器应后台执行且不能影响主流程参数。""" - reg = ComponentRegistry() - reg.register_component( - "only_dm", - "workflow_step", - "p1", - { - "stage": "ingress", - "priority": 10, - "blocking": True, - "filter": {"chat_type": "direct"}, - }, - ) - executor = WorkflowExecutor(reg) + ComponentRegistry, HookDispatcher = self._import_dispatcher_modules(monkeypatch) - call_log = [] - - async def mock_invoke(plugin_id, comp_name, args): - call_log.append(comp_name) - return {"hook_result": "continue"} - - # 不匹配 filter —— hook 不应被调用 - await executor.execute(mock_invoke, message={"plain_text": "hi", "chat_type": "group"}) - assert not call_log - - # 匹配 filter —— hook 应被调用 - await executor.execute(mock_invoke, message={"plain_text": "hi", "chat_type": "direct"}) - assert call_log == ["only_dm"] - - @pytest.mark.asyncio - async def test_error_policy_skip(self): - """error_policy=skip 时跳过失败的 hook 继续执行""" - from src.plugin_runtime.host.component_registry import ComponentRegistry - from src.plugin_runtime.host.workflow_executor import WorkflowExecutor - - reg = ComponentRegistry() - reg.register_component( - "failer", - "workflow_step", - "p1", - { - "stage": "ingress", - "priority": 100, - "blocking": True, - "error_policy": "skip", - }, - ) - reg.register_component( - "ok_step", - "workflow_step", - "p2", - { - "stage": "ingress", - "priority": 1, - "blocking": True, - }, - ) - executor = WorkflowExecutor(reg) - - call_log = [] - - async def mock_invoke(plugin_id, comp_name, args): - call_log.append(comp_name) - if comp_name == "failer": - raise RuntimeError("boom") - return {"hook_result": "continue"} - - result, _, ctx = await executor.execute(mock_invoke, message={"plain_text": "test"}) - assert result.status == "completed" - assert "failer" in call_log - assert "ok_step" in call_log - assert any("boom" in e for e in ctx.errors) - - @pytest.mark.asyncio - async def test_error_policy_abort(self): - """error_policy=abort(默认)时 pipeline 失败""" - from src.plugin_runtime.host.component_registry import ComponentRegistry - from src.plugin_runtime.host.workflow_executor import WorkflowExecutor - - reg = ComponentRegistry() - reg.register_component( - "failer", - "workflow_step", - "p1", - { - "stage": "ingress", - "priority": 10, - "blocking": True, - # error_policy defaults to "abort" - }, - ) - executor = WorkflowExecutor(reg) - - async def mock_invoke(plugin_id, comp_name, args): - raise RuntimeError("fatal") - - result, _, ctx = await executor.execute(mock_invoke, message={"plain_text": "test"}) - assert result.status == "failed" - assert result.stopped_at == "ingress" - - @pytest.mark.asyncio - async def test_nonblocking_hooks_concurrent(self): - """non-blocking hook 并发执行,不修改消息""" - from src.plugin_runtime.host.component_registry import ComponentRegistry - from src.plugin_runtime.host.workflow_executor import WorkflowExecutor - - reg = ComponentRegistry() - for i in range(3): - reg.register_component( - f"nb_{i}", - "workflow_step", - f"p{i}", - { - "stage": "post_process", - "priority": 0, - "blocking": False, - }, - ) - executor = WorkflowExecutor(reg) - - call_log = [] - - async def mock_invoke(plugin_id, comp_name, args): - call_log.append(comp_name) - return {"hook_result": "continue", "modified_message": {"plain_text": "ignored"}} - - result, final_msg, _ = await executor.execute(mock_invoke, message={"plain_text": "original"}) - # non-blocking 的 modified_message 被忽略 - assert final_msg["plain_text"] == "original" - # 给异步 task 时间完成 - await asyncio.sleep(0.1) - assert result.status == "completed" - - @pytest.mark.asyncio - async def test_nonblocking_tasks_are_retained_until_completion(self): - """execute 返回后,non-blocking task 仍应保持强引用直到执行完成。""" - from src.plugin_runtime.host.component_registry import ComponentRegistry - from src.plugin_runtime.host.workflow_executor import WorkflowExecutor - - reg = ComponentRegistry() - reg.register_component( + registry = ComponentRegistry() + registry.register_component( "observer", - "workflow_step", + "HOOK_HANDLER", "p1", - { - "stage": "post_process", - "priority": 0, - "blocking": False, - }, + {"hook": "heart_fc.cycle_start", "mode": "observe", "order": "normal"}, ) - executor = WorkflowExecutor(reg) - started = asyncio.Event() release = asyncio.Event() + call_log: List[tuple[str, str]] = [] + + async def observe_handler(args: Dict[str, Any]) -> Dict[str, Any]: + """模拟耗时观察型处理器。""" - async def mock_invoke(plugin_id, comp_name, args): started.set() await release.wait() - return {"hook_result": "continue"} + return { + "success": True, + "action": "abort", + "modified_kwargs": {"session_id": "changed"}, + "custom_result": args["session_id"], + } - result, final_msg, _ = await executor.execute(mock_invoke, message={"plain_text": "original"}) + dispatcher = HookDispatcher() + supervisor = _FakeHookSupervisor( + "builtin", + registry, + {"p1.observer": observe_handler}, + call_log, + ) + + result = await dispatcher.invoke_hook("heart_fc.cycle_start", [supervisor], session_id="s-1") await asyncio.sleep(0) - assert result.status == "completed" - assert final_msg["plain_text"] == "original" + assert result.aborted is False + assert result.kwargs["session_id"] == "s-1" assert started.is_set() - assert len(executor._background_tasks) == 1 + assert len(dispatcher._background_tasks) == 1 release.set() await asyncio.sleep(0) await asyncio.sleep(0) - assert not executor._background_tasks + assert call_log == [("p1", "observer")] + assert not dispatcher._background_tasks @pytest.mark.asyncio - async def test_command_routing(self): - """PLAN 阶段内置命令路由""" - from src.plugin_runtime.host.component_registry import ComponentRegistry - from src.plugin_runtime.host.workflow_executor import WorkflowExecutor + async def test_global_order_prefers_order_slot_then_source(self, monkeypatch: pytest.MonkeyPatch) -> None: + """全局排序应先看 order,再看内置/第三方来源。""" - reg = ComponentRegistry() - reg.register_component( - "help", - "command", - "p1", - { - "command_pattern": r"^/help", - }, + ComponentRegistry, HookDispatcher = self._import_dispatcher_modules(monkeypatch) + + builtin_registry = ComponentRegistry() + third_registry = ComponentRegistry() + builtin_registry.register_component( + "builtin_early", + "HOOK_HANDLER", + "b1", + {"hook": "heart_fc.cycle_start", "mode": "blocking", "order": "early"}, + ) + builtin_registry.register_component( + "builtin_normal", + "HOOK_HANDLER", + "b1", + {"hook": "heart_fc.cycle_start", "mode": "blocking", "order": "normal"}, + ) + third_registry.register_component( + "third_early", + "HOOK_HANDLER", + "t1", + {"hook": "heart_fc.cycle_start", "mode": "blocking", "order": "early"}, + ) + third_registry.register_component( + "third_normal", + "HOOK_HANDLER", + "t1", + {"hook": "heart_fc.cycle_start", "mode": "blocking", "order": "normal"}, ) - executor = WorkflowExecutor(reg) - async def mock_invoke(plugin_id, comp_name, args): - if comp_name == "help": - return {"output": "帮助信息"} - return {"hook_result": "continue"} + call_log: List[tuple[str, str]] = [] + dispatcher = HookDispatcher() + builtin_supervisor = _FakeHookSupervisor( + "builtin", + builtin_registry, + { + "b1.builtin_early": lambda args: {"success": True, "action": "continue"}, + "b1.builtin_normal": lambda args: {"success": True, "action": "continue"}, + }, + call_log, + ) + third_supervisor = _FakeHookSupervisor( + "third_party", + third_registry, + { + "t1.third_early": lambda args: {"success": True, "action": "continue"}, + "t1.third_normal": lambda args: {"success": True, "action": "continue"}, + }, + call_log, + ) - result, _, ctx = await executor.execute(mock_invoke, message={"plain_text": "/help topic"}) - assert result.status == "completed" - assert ctx.matched_command == "p1.help" - cmd_result = ctx.get_stage_output("plan", "command_result") - assert cmd_result is not None - assert cmd_result["output"] == "帮助信息" + await dispatcher.invoke_hook( + "heart_fc.cycle_start", + [third_supervisor, builtin_supervisor], + cycle_id="c-1", + ) + + assert call_log == [ + ("b1", "builtin_early"), + ("t1", "third_early"), + ("b1", "builtin_normal"), + ("t1", "third_normal"), + ] @pytest.mark.asyncio - async def test_stage_outputs(self): - """stage_outputs 数据在阶段间传递""" - from src.plugin_runtime.host.component_registry import ComponentRegistry - from src.plugin_runtime.host.workflow_executor import WorkflowExecutor + async def test_error_policy_abort_stops_dispatch(self, monkeypatch: pytest.MonkeyPatch) -> None: + """error_policy=abort 时应中止本次 Hook 调用。""" - reg = ComponentRegistry() - # ingress 阶段写入数据 - reg.register_component( - "writer", - "workflow_step", + ComponentRegistry, HookDispatcher = self._import_dispatcher_modules(monkeypatch) + + registry = ComponentRegistry() + registry.register_component( + "failer", + "HOOK_HANDLER", "p1", { - "stage": "ingress", - "priority": 10, - "blocking": True, + "hook": "heart_fc.cycle_start", + "mode": "blocking", + "order": "normal", + "error_policy": "abort", }, ) - # pre_process 阶段读取数据 - reg.register_component( - "reader", - "workflow_step", - "p2", + call_log: List[tuple[str, str]] = [] + + async def fail_handler(args: Dict[str, Any]) -> Dict[str, Any]: + """抛出异常以触发 abort 策略。""" + + del args + raise RuntimeError("boom") + + dispatcher = HookDispatcher() + supervisor = _FakeHookSupervisor("builtin", registry, {"p1.failer": fail_handler}, call_log) + + result = await dispatcher.invoke_hook("heart_fc.cycle_start", [supervisor], session_id="s-1") + + assert result.aborted is True + assert result.stopped_by == "p1.failer" + assert any("boom" in error for error in result.errors) + assert call_log == [("p1", "failer")] + + @pytest.mark.asyncio + async def test_timeout_respects_handler_timeout_ms(self, monkeypatch: pytest.MonkeyPatch) -> None: + """处理器超时应被记录为错误并继续。""" + + ComponentRegistry, HookDispatcher = self._import_dispatcher_modules(monkeypatch) + + registry = ComponentRegistry() + registry.register_component( + "slow", + "HOOK_HANDLER", + "p1", { - "stage": "pre_process", - "priority": 10, - "blocking": True, + "hook": "heart_fc.cycle_start", + "mode": "blocking", + "order": "normal", + "timeout_ms": 10, }, ) - executor = WorkflowExecutor(reg) + call_log: List[tuple[str, str]] = [] - async def mock_invoke(plugin_id, comp_name, args): - if comp_name == "writer": - return { - "hook_result": "continue", - "stage_output": {"parsed_intent": "greeting"}, - } - if comp_name == "reader": - # 验证 stage_outputs 被传递过来 - outputs = args.get("stage_outputs", {}) - ingress_data = outputs.get("ingress", {}) - assert ingress_data.get("parsed_intent") == "greeting" - return {"hook_result": "continue"} - return {"hook_result": "continue"} + async def slow_handler(args: Dict[str, Any]) -> Dict[str, Any]: + """模拟超时处理器。""" - result, _, ctx = await executor.execute(mock_invoke, message={"plain_text": "hi"}) - assert result.status == "completed" - assert ctx.get_stage_output("ingress", "parsed_intent") == "greeting" + del args + await asyncio.sleep(0.05) + return {"success": True, "action": "continue"} + + dispatcher = HookDispatcher() + supervisor = _FakeHookSupervisor("builtin", registry, {"p1.slow": slow_handler}, call_log) + + result = await dispatcher.invoke_hook("heart_fc.cycle_start", [supervisor], session_id="s-1") + + assert result.aborted is False + assert any("超时" in error for error in result.errors) + assert call_log == [("p1", "slow")] + + +class TestPluginRuntimeHookEntry: + """PluginRuntimeManager 命名 Hook 入口测试。""" + + @staticmethod + def _import_manager_modules(monkeypatch: pytest.MonkeyPatch) -> tuple[Any, Any]: + """导入运行时管理器相关模块,并屏蔽配置初始化触发的退出。 + + Args: + monkeypatch: pytest 的 monkeypatch 工具。 + + Returns: + tuple[Any, Any]: `ComponentRegistry` 与 `PluginRuntimeManager` 类型。 + """ + + monkeypatch.setattr(sys, "exit", lambda code=0: None) + from src.plugin_runtime.host.component_registry import ComponentRegistry + from src.plugin_runtime.integration import PluginRuntimeManager + + return ComponentRegistry, PluginRuntimeManager + + @pytest.mark.asyncio + async def test_manager_invoke_hook_dispatches_across_supervisors( + self, + monkeypatch: pytest.MonkeyPatch, + ) -> None: + """PluginRuntimeManager.invoke_hook() 应调用全局 Hook 分发器。""" + + ComponentRegistry, PluginRuntimeManager = self._import_manager_modules(monkeypatch) + + builtin_registry = ComponentRegistry() + builtin_registry.register_component( + "builtin_guard", + "HOOK_HANDLER", + "b1", + {"hook": "heart_fc.cycle_start", "mode": "blocking", "order": "early"}, + ) + third_registry = ComponentRegistry() + third_registry.register_component( + "observer", + "HOOK_HANDLER", + "t1", + {"hook": "heart_fc.cycle_start", "mode": "observe", "order": "normal"}, + ) + + call_log: List[tuple[str, str]] = [] + manager = PluginRuntimeManager() + manager._started = True + manager._builtin_supervisor = _FakeHookSupervisor( + "builtin", + builtin_registry, + {"b1.builtin_guard": lambda args: {"success": True, "action": "continue"}}, + call_log, + ) + manager._third_party_supervisor = _FakeHookSupervisor( + "third_party", + third_registry, + {"t1.observer": lambda args: {"success": True, "action": "continue"}}, + call_log, + ) + + result = await manager.invoke_dispatcher.invoke_hook("heart_fc.cycle_start", session_id="s-1") + + await asyncio.sleep(0) + assert manager.invoke_dispatcher is manager.hook_dispatcher + assert result.aborted is False + assert result.kwargs["session_id"] == "s-1" + assert ("b1", "builtin_guard") in call_log class TestRPCServer: diff --git a/pytests/test_send_service.py b/pytests/test_send_service.py index 16aad080..44f77090 100644 --- a/pytests/test_send_service.py +++ b/pytests/test_send_service.py @@ -6,6 +6,7 @@ from typing import Any, Dict, List import pytest from src.chat.message_receive.chat_manager import BotChatSession +from src.common.data_models.message_component_data_model import MessageSequence, TextComponent from src.services import send_service @@ -13,42 +14,18 @@ class _FakePlatformIOManager: """用于测试的 Platform IO 管理器假对象。""" def __init__(self, delivery_batch: Any) -> None: - """初始化假 Platform IO 管理器。 - - Args: - delivery_batch: 发送时返回的批量回执。 - """ self._delivery_batch = delivery_batch self.ensure_calls = 0 self.sent_messages: List[Dict[str, Any]] = [] async def ensure_send_pipeline_ready(self) -> None: - """记录发送管线准备调用次数。""" self.ensure_calls += 1 def build_route_key_from_message(self, message: Any) -> Any: - """根据消息构造假的路由键。 - - Args: - message: 待发送的内部消息对象。 - - Returns: - Any: 简化后的路由键对象。 - """ del message return SimpleNamespace(platform="qq") async def send_message(self, message: Any, route_key: Any, metadata: Dict[str, Any]) -> Any: - """记录发送请求并返回预设回执。 - - Args: - message: 待发送的内部消息对象。 - route_key: 本次发送使用的路由键。 - metadata: 发送元数据。 - - Returns: - Any: 预设的批量发送回执。 - """ self.sent_messages.append( { "message": message, @@ -59,12 +36,7 @@ class _FakePlatformIOManager: return self._delivery_batch -def _build_target_stream() -> BotChatSession: - """构造一个最小可用的目标会话对象。 - - Returns: - BotChatSession: 测试用会话对象。 - """ +def _build_private_stream() -> BotChatSession: return BotChatSession( session_id="test-session", platform="qq", @@ -73,14 +45,21 @@ def _build_target_stream() -> BotChatSession: ) +def _build_group_stream() -> BotChatSession: + return BotChatSession( + session_id="group-session", + platform="qq", + user_id="target-user", + group_id="target-group", + ) + + def test_inherit_platform_io_route_metadata_falls_back_to_bot_account( monkeypatch: pytest.MonkeyPatch, ) -> None: - """没有上下文消息时,也应回填当前平台账号用于账号级路由命中。""" - monkeypatch.setattr(send_service, "get_bot_account", lambda platform: "bot-qq" if platform == "qq" else "") - metadata = send_service._inherit_platform_io_route_metadata(_build_target_stream()) + metadata = send_service._inherit_platform_io_route_metadata(_build_private_stream()) assert metadata["platform_io_account_id"] == "bot-qq" assert metadata["platform_io_target_user_id"] == "target-user" @@ -88,7 +67,6 @@ def test_inherit_platform_io_route_metadata_falls_back_to_bot_account( @pytest.mark.asyncio async def test_text_to_stream_delegates_to_platform_io(monkeypatch: pytest.MonkeyPatch) -> None: - """send service 应将发送职责统一交给 Platform IO。""" fake_manager = _FakePlatformIOManager( delivery_batch=SimpleNamespace( has_success=True, @@ -104,7 +82,7 @@ async def test_text_to_stream_delegates_to_platform_io(monkeypatch: pytest.Monke monkeypatch.setattr( send_service._chat_manager, "get_session_by_session_id", - lambda stream_id: _build_target_stream() if stream_id == "test-session" else None, + lambda stream_id: _build_private_stream() if stream_id == "test-session" else None, ) monkeypatch.setattr( send_service.MessageUtils, @@ -123,7 +101,6 @@ async def test_text_to_stream_delegates_to_platform_io(monkeypatch: pytest.Monke @pytest.mark.asyncio async def test_text_to_stream_returns_false_when_platform_io_fails(monkeypatch: pytest.MonkeyPatch) -> None: - """Platform IO 批量发送全部失败时,应直接向上返回失败。""" fake_manager = _FakePlatformIOManager( delivery_batch=SimpleNamespace( has_success=False, @@ -144,7 +121,7 @@ async def test_text_to_stream_returns_false_when_platform_io_fails(monkeypatch: monkeypatch.setattr( send_service._chat_manager, "get_session_by_session_id", - lambda stream_id: _build_target_stream() if stream_id == "test-session" else None, + lambda stream_id: _build_private_stream() if stream_id == "test-session" else None, ) result = await send_service.text_to_stream(text="发送失败", stream_id="test-session") @@ -152,3 +129,63 @@ async def test_text_to_stream_returns_false_when_platform_io_fails(monkeypatch: assert result is False assert fake_manager.ensure_calls == 1 assert len(fake_manager.sent_messages) == 1 + + +@pytest.mark.asyncio +async def test_private_outbound_message_preserves_bot_sender_and_receiver_user( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setattr(send_service, "get_bot_account", lambda platform: "bot-qq") + monkeypatch.setattr( + send_service._chat_manager, + "get_session_by_session_id", + lambda stream_id: _build_private_stream() if stream_id == "test-session" else None, + ) + + outbound_message = send_service._build_outbound_session_message( + message_sequence=MessageSequence(components=[TextComponent(text="你好")]), + stream_id="test-session", + display_message="你好", + ) + + assert outbound_message is not None + maim_message = await outbound_message.to_maim_message() + + assert maim_message.message_info.user_info is not None + assert maim_message.message_info.user_info.user_id == "bot-qq" + assert maim_message.message_info.group_info is None + assert maim_message.message_info.sender_info is not None + assert maim_message.message_info.sender_info.user_info is not None + assert maim_message.message_info.sender_info.user_info.user_id == "bot-qq" + assert maim_message.message_info.receiver_info is not None + assert maim_message.message_info.receiver_info.user_info is not None + assert maim_message.message_info.receiver_info.user_info.user_id == "target-user" + + +@pytest.mark.asyncio +async def test_group_outbound_message_preserves_bot_sender_and_target_group( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setattr(send_service, "get_bot_account", lambda platform: "bot-qq") + monkeypatch.setattr( + send_service._chat_manager, + "get_session_by_session_id", + lambda stream_id: _build_group_stream() if stream_id == "group-session" else None, + ) + + outbound_message = send_service._build_outbound_session_message( + message_sequence=MessageSequence(components=[TextComponent(text="大家好")]), + stream_id="group-session", + display_message="大家好", + ) + + assert outbound_message is not None + maim_message = await outbound_message.to_maim_message() + + assert maim_message.message_info.user_info is not None + assert maim_message.message_info.user_info.user_id == "bot-qq" + assert maim_message.message_info.group_info is not None + assert maim_message.message_info.group_info.group_id == "target-group" + assert maim_message.message_info.receiver_info is not None + assert maim_message.message_info.receiver_info.group_info is not None + assert maim_message.message_info.receiver_info.group_info.group_id == "target-group" diff --git a/requirements.txt b/requirements.txt index 7f869361..1d0a9b0b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,6 +10,7 @@ jieba>=0.42.1 json-repair>=0.47.6 maim-message>=0.6.2 maibot-plugin-sdk>=1.2.3,<2.0.0 +mcp msgpack>=1.1.2 numpy>=2.2.6 openai>=1.95.0 @@ -30,4 +31,4 @@ structlog>=25.4.0 tomlkit>=0.13.3 typing-extensions uvicorn>=0.35.0 -watchfiles>=1.1.1 \ No newline at end of file +watchfiles>=1.1.1 diff --git a/scripts/info_extraction.py b/scripts/info_extraction.py index bd52536e..ab9d295b 100644 --- a/scripts/info_extraction.py +++ b/scripts/info_extraction.py @@ -189,7 +189,7 @@ def _run(non_interactive: bool = False) -> None: # sourcery skip: comprehension elif doc_item: with open_ie_doc_lock: open_ie_doc.append(doc_item) - logger.info('已处理"%s"', doc_item.get("passage", "")) + logger.info(f'已处理"{doc_item.get("passage", "")}"') progress.update(task, advance=1) except KeyboardInterrupt: logger.info("\n接收到中断信号,正在优雅地关闭程序...") diff --git a/scripts/lpmm_manager.py b/scripts/lpmm_manager.py index 2f935c51..868d4b14 100644 --- a/scripts/lpmm_manager.py +++ b/scripts/lpmm_manager.py @@ -110,7 +110,7 @@ def run_action(action: str, extra_args: Optional[List[str]] = None) -> None: 这里不重复解析子参数,而是直接调用各脚本的 main(), 让子脚本保留原有的交互/参数行为。 """ - logger.info("开始执行操作: %s", action) + logger.info(f"开始执行操作: {action}") extra_args = extra_args or [] @@ -162,14 +162,14 @@ def run_action(action: str, extra_args: Optional[List[str]] = None) -> None: _warn_if_lpmm_disabled() _with_overridden_argv(extra_args, refresh_lpmm_knowledge_main) else: - logger.error("未知操作: %s", action) + logger.error(f"未知操作: {action}") except KeyboardInterrupt: logger.info("用户中断当前操作(Ctrl+C)") except SystemExit: # 子脚本里大量使用 sys.exit,直接透传即可 raise except Exception as exc: # pragma: no cover - 防御性兜底 - logger.error("执行操作 %s 时发生未捕获异常: %s", action, exc) + logger.error(f"执行操作 {action} 时发生未捕获异常: {exc}") raise @@ -442,7 +442,7 @@ def _run_embedding_helper() -> None: try: test_path.rename(archive_path) except Exception as exc: # pragma: no cover - 防御性兜底 - logger.error("归档 embedding_model_test.json 失败: %s", exc) + logger.error(f"归档 embedding_model_test.json 失败: {exc}") print("[ERROR] 归档 embedding_model_test.json 失败,请检查文件权限与路径。错误详情已写入日志。") return diff --git a/src/chat/brain_chat/PFC/action_planner.py b/src/chat/brain_chat/PFC/action_planner.py deleted file mode 100644 index 94f68585..00000000 --- a/src/chat/brain_chat/PFC/action_planner.py +++ /dev/null @@ -1,499 +0,0 @@ -import time -from typing import Tuple, Optional # 增加了 Optional - -from src.common.logger import get_logger -from src.llm_models.utils_model import LLMRequest -from src.config.config import global_config, model_config -import random -from .chat_observer import ChatObserver -from .pfc_utils import get_items_from_json -from src.services.message_service import build_readable_messages - -from .observation_info import ObservationInfo, dict_to_session_message -from .conversation_info import ConversationInfo - - -logger = get_logger("pfc_action_planner") - - -# --- 定义 Prompt 模板 --- - -# Prompt(1): 首次回复或非连续回复时的决策 Prompt -PROMPT_INITIAL_REPLY = """{persona_text}。现在你在参与一场QQ私聊,请根据以下【所有信息】审慎且灵活的决策下一步行动,可以回复,可以倾听,可以调取知识,甚至可以屏蔽对方: - -【当前对话目标】 -{goals_str} -{knowledge_info_str} - -【最近行动历史概要】 -{action_history_summary} -【上一次行动的详细情况和结果】 -{last_action_context} -【时间和超时提示】 -{time_since_last_bot_message_info}{timeout_context} -【最近的对话记录】(包括你已成功发送的消息 和 新收到的消息) -{chat_history_text} - ------- -可选行动类型以及解释: -fetch_knowledge: 需要调取知识或记忆,当需要专业知识或特定信息时选择,对方若提到你不太认识的人名或实体也可以尝试选择 -listening: 倾听对方发言,当你认为对方话才说到一半,发言明显未结束时选择 -direct_reply: 直接回复对方 -rethink_goal: 思考一个对话目标,当你觉得目前对话需要目标,或当前目标不再适用,或话题卡住时选择。注意私聊的环境是灵活的,有可能需要经常选择 -end_conversation: 结束对话,对方长时间没回复或者当你觉得对话告一段落时可以选择 -block_and_ignore: 更加极端的结束对话方式,直接结束对话并在一段时间内无视对方所有发言(屏蔽),当对话让你感到十分不适,或你遭到各类骚扰时选择 - -请以JSON格式输出你的决策: -{{ - "action": "选择的行动类型 (必须是上面列表中的一个)", - "reason": "选择该行动的详细原因 (必须有解释你是如何根据“上一次行动结果”、“对话记录”和自身设定人设做出合理判断的)" -}} - -注意:请严格按照JSON格式输出,不要包含任何其他内容。""" - -# Prompt(2): 上一次成功回复后,决定继续发言时的决策 Prompt -PROMPT_FOLLOW_UP = """{persona_text}。现在你在参与一场QQ私聊,刚刚你已经回复了对方,请根据以下【所有信息】审慎且灵活的决策下一步行动,可以继续发送新消息,可以等待,可以倾听,可以调取知识,甚至可以屏蔽对方: - -【当前对话目标】 -{goals_str} -{knowledge_info_str} - -【最近行动历史概要】 -{action_history_summary} -【上一次行动的详细情况和结果】 -{last_action_context} -【时间和超时提示】 -{time_since_last_bot_message_info}{timeout_context} -【最近的对话记录】(包括你已成功发送的消息 和 新收到的消息) -{chat_history_text} - ------- -可选行动类型以及解释: -fetch_knowledge: 需要调取知识,当需要专业知识或特定信息时选择,对方若提到你不太认识的人名或实体也可以尝试选择 -wait: 暂时不说话,留给对方交互空间,等待对方回复(尤其是在你刚发言后、或上次发言因重复、发言过多被拒时、或不确定做什么时,这是不错的选择) -listening: 倾听对方发言(虽然你刚发过言,但如果对方立刻回复且明显话没说完,可以选择这个) -send_new_message: 发送一条新消息继续对话,允许适当的追问、补充、深入话题,或开启相关新话题。**但是避免在因重复被拒后立即使用,也不要在对方没有回复的情况下过多的“消息轰炸”或重复发言** -rethink_goal: 思考一个对话目标,当你觉得目前对话需要目标,或当前目标不再适用,或话题卡住时选择。注意私聊的环境是灵活的,有可能需要经常选择 -end_conversation: 结束对话,对方长时间没回复或者当你觉得对话告一段落时可以选择 -block_and_ignore: 更加极端的结束对话方式,直接结束对话并在一段时间内无视对方所有发言(屏蔽),当对话让你感到十分不适,或你遭到各类骚扰时选择 - -请以JSON格式输出你的决策: -{{ - "action": "选择的行动类型 (必须是上面列表中的一个)", - "reason": "选择该行动的详细原因 (必须有解释你是如何根据“上一次行动结果”、“对话记录”和自身设定人设做出合理判断的。请说明你为什么选择继续发言而不是等待,以及打算发送什么类型的新消息连续发言,必须记录已经发言了几次)" -}} - -注意:请严格按照JSON格式输出,不要包含任何其他内容。""" - -# 新增:Prompt(3): 决定是否在结束对话前发送告别语 -PROMPT_END_DECISION = """{persona_text}。刚刚你决定结束一场 QQ 私聊。 - -【你们之前的聊天记录】 -{chat_history_text} - -你觉得你们的对话已经完整结束了吗?有时候,在对话自然结束后再说点什么可能会有点奇怪,但有时也可能需要一条简短的消息来圆满结束。 -如果觉得确实有必要再发一条简短、自然、符合你人设的告别消息(比如 "好,下次再聊~" 或 "嗯,先这样吧"),就输出 "yes"。 -如果觉得当前状态下直接结束对话更好,没有必要再发消息,就输出 "no"。 - -请以 JSON 格式输出你的选择: -{{ - "say_bye": "yes/no", - "reason": "选择 yes 或 no 的原因和内心想法 (简要说明)" -}} - -注意:请严格按照 JSON 格式输出,不要包含任何其他内容。""" - - -# ActionPlanner 类定义,顶格 -class ActionPlanner: - """行动规划器""" - - def __init__(self, stream_id: str, private_name: str): - self.llm = LLMRequest( - model_set=model_config.model_task_config.planner, - request_type="action_planning", - ) - self.personality_info = self._get_personality_prompt() - self.name = global_config.bot.nickname - self.private_name = private_name - self.chat_observer = ChatObserver.get_instance(stream_id, private_name) - # self.action_planner_info = ActionPlannerInfo() # 移除未使用的变量 - - def _get_personality_prompt(self) -> str: - """获取个性提示信息""" - prompt_personality = global_config.personality.personality - - # 检查是否需要随机替换为状态 - if ( - global_config.personality.states - and global_config.personality.state_probability > 0 - and random.random() < global_config.personality.state_probability - ): - prompt_personality = random.choice(global_config.personality.states) - - bot_name = global_config.bot.nickname - return f"你的名字是{bot_name},你{prompt_personality};" - - # 修改 plan 方法签名,增加 last_successful_reply_action 参数 - async def plan( - self, - observation_info: ObservationInfo, - conversation_info: ConversationInfo, - last_successful_reply_action: Optional[str], - ) -> Tuple[str, str]: - """规划下一步行动 - - Args: - observation_info: 决策信息 - conversation_info: 对话信息 - last_successful_reply_action: 上一次成功的回复动作类型 ('direct_reply' 或 'send_new_message' 或 None) - - Returns: - Tuple[str, str]: (行动类型, 行动原因) - """ - # --- 获取 Bot 上次发言时间信息 --- - time_since_last_bot_message_info = "" - try: - bot_id = str(global_config.bot.qq_account) - chat_history = getattr(observation_info, "chat_history", None) - if chat_history and len(chat_history) > 0: - for i in range(len(chat_history) - 1, -1, -1): - msg = chat_history[i] - if not isinstance(msg, dict): - continue - sender_info = msg.get("user_info", {}) - sender_id = str(sender_info.get("user_id")) if isinstance(sender_info, dict) else None - msg_time = msg.get("time") - if sender_id == bot_id and msg_time: - time_diff = time.time() - msg_time - if time_diff < 60.0: - time_since_last_bot_message_info = ( - f"提示:你上一条成功发送的消息是在 {time_diff:.1f} 秒前。\n" - ) - break - else: - logger.debug(f"[私聊][{self.private_name}]聊天历史为空或尚未加载,跳过 Bot 发言时间检查。") - except Exception as e: - logger.debug(f"[私聊][{self.private_name}]获取 Bot 上次发言时间时出错: {e}") - - # --- 获取超时提示信息 --- - # (这部分逻辑不变) - timeout_context = "" - try: - if hasattr(conversation_info, "goal_list") and conversation_info.goal_list: - last_goal_dict = conversation_info.goal_list[-1] - if isinstance(last_goal_dict, dict) and "goal" in last_goal_dict: - last_goal_text = last_goal_dict["goal"] - if isinstance(last_goal_text, str) and "分钟,思考接下来要做什么" in last_goal_text: - try: - timeout_minutes_text = last_goal_text.split(",")[0].replace("你等待了", "") - timeout_context = f"重要提示:对方已经长时间({timeout_minutes_text})没有回复你的消息了(这可能代表对方繁忙/不想回复/没注意到你的消息等情况,或在对方看来本次聊天已告一段落),请基于此情况规划下一步。\n" - except Exception: - timeout_context = "重要提示:对方已经长时间没有回复你的消息了(这可能代表对方繁忙/不想回复/没注意到你的消息等情况,或在对方看来本次聊天已告一段落),请基于此情况规划下一步。\n" - else: - logger.debug( - f"[私聊][{self.private_name}]Conversation info goal_list is empty or not available for timeout check." - ) - except AttributeError: - logger.warning( - f"[私聊][{self.private_name}]ConversationInfo object might not have goal_list attribute yet for timeout check." - ) - except Exception as e: - logger.warning(f"[私聊][{self.private_name}]检查超时目标时出错: {e}") - - # --- 构建通用 Prompt 参数 --- - logger.debug( - f"[私聊][{self.private_name}]开始规划行动:当前目标: {getattr(conversation_info, 'goal_list', '不可用')}" - ) - - # 构建对话目标 (goals_str) - goals_str = "" - try: - if hasattr(conversation_info, "goal_list") and conversation_info.goal_list: - for goal_reason in conversation_info.goal_list: - if isinstance(goal_reason, dict): - goal = goal_reason.get("goal", "目标内容缺失") - reasoning = goal_reason.get("reasoning", "没有明确原因") - else: - goal = str(goal_reason) - reasoning = "没有明确原因" - - goal = str(goal) if goal is not None else "目标内容缺失" - reasoning = str(reasoning) if reasoning is not None else "没有明确原因" - goals_str += f"- 目标:{goal}\n 原因:{reasoning}\n" - - if not goals_str: - goals_str = "- 目前没有明确对话目标,请考虑设定一个。\n" - else: - goals_str = "- 目前没有明确对话目标,请考虑设定一个。\n" - except AttributeError: - logger.warning( - f"[私聊][{self.private_name}]ConversationInfo object might not have goal_list attribute yet." - ) - goals_str = "- 获取对话目标时出错。\n" - except Exception as e: - logger.error(f"[私聊][{self.private_name}]构建对话目标字符串时出错: {e}") - goals_str = "- 构建对话目标时出错。\n" - - # --- 知识信息字符串构建开始 --- - knowledge_info_str = "【已获取的相关知识和记忆】\n" - try: - # 检查 conversation_info 是否有 knowledge_list 并且不为空 - if hasattr(conversation_info, "knowledge_list") and conversation_info.knowledge_list: - # 最多只显示最近的 5 条知识,防止 Prompt 过长 - recent_knowledge = conversation_info.knowledge_list[-5:] - for i, knowledge_item in enumerate(recent_knowledge): - if isinstance(knowledge_item, dict): - query = knowledge_item.get("query", "未知查询") - knowledge = knowledge_item.get("knowledge", "无知识内容") - source = knowledge_item.get("source", "未知来源") - # 只取知识内容的前 2000 个字,避免太长 - knowledge_snippet = knowledge[:2000] + "..." if len(knowledge) > 2000 else knowledge - knowledge_info_str += ( - f"{i + 1}. 关于 '{query}' 的知识 (来源: {source}):\n {knowledge_snippet}\n" - ) - else: - # 处理列表里不是字典的异常情况 - knowledge_info_str += f"{i + 1}. 发现一条格式不正确的知识记录。\n" - - if not recent_knowledge: # 如果 knowledge_list 存在但为空 - knowledge_info_str += "- 暂无相关知识和记忆。\n" - - else: - # 如果 conversation_info 没有 knowledge_list 属性,或者列表为空 - knowledge_info_str += "- 暂无相关知识记忆。\n" - except AttributeError: - logger.warning(f"[私聊][{self.private_name}]ConversationInfo 对象可能缺少 knowledge_list 属性。") - knowledge_info_str += "- 获取知识列表时出错。\n" - except Exception as e: - logger.error(f"[私聊][{self.private_name}]构建知识信息字符串时出错: {e}") - knowledge_info_str += "- 处理知识列表时出错。\n" - # --- 知识信息字符串构建结束 --- - - # 获取聊天历史记录 (chat_history_text) - try: - if hasattr(observation_info, "chat_history") and observation_info.chat_history: - chat_history_text = observation_info.chat_history_str or "还没有聊天记录。\n" - else: - chat_history_text = "还没有聊天记录。\n" - - if hasattr(observation_info, "new_messages_count") and observation_info.new_messages_count > 0: - if hasattr(observation_info, "unprocessed_messages") and observation_info.unprocessed_messages: - new_messages_list = observation_info.unprocessed_messages - # Convert dict format to SessionMessage objects. - session_messages = [dict_to_session_message(m) for m in new_messages_list] - new_messages_str = build_readable_messages( - session_messages, - replace_bot_name=True, - timestamp_mode="relative", - read_mark=0.0, - ) - chat_history_text += ( - f"\n--- 以下是 {observation_info.new_messages_count} 条新消息 ---\n{new_messages_str}" - ) - else: - logger.warning( - f"[私聊][{self.private_name}]ObservationInfo has new_messages_count > 0 but unprocessed_messages is empty or missing." - ) - except AttributeError: - logger.warning( - f"[私聊][{self.private_name}]ObservationInfo object might be missing expected attributes for chat history." - ) - chat_history_text = "获取聊天记录时出错。\n" - except Exception as e: - logger.error(f"[私聊][{self.private_name}]处理聊天记录时发生未知错误: {e}") - chat_history_text = "处理聊天记录时出错。\n" - - # 构建 Persona 文本 (persona_text) - persona_text = f"你的名字是{self.name},{self.personality_info}。" - - # 构建行动历史和上一次行动结果 (action_history_summary, last_action_context) - # (这部分逻辑不变) - action_history_summary = "你最近执行的行动历史:\n" - last_action_context = "关于你【上一次尝试】的行动:\n" - action_history_list = [] - try: - if hasattr(conversation_info, "done_action") and conversation_info.done_action: - action_history_list = conversation_info.done_action[-5:] - else: - logger.debug(f"[私聊][{self.private_name}]Conversation info done_action is empty or not available.") - except AttributeError: - logger.warning( - f"[私聊][{self.private_name}]ConversationInfo object might not have done_action attribute yet." - ) - except Exception as e: - logger.error(f"[私聊][{self.private_name}]访问行动历史时出错: {e}") - - if not action_history_list: - action_history_summary += "- 还没有执行过行动。\n" - last_action_context += "- 这是你规划的第一个行动。\n" - else: - for i, action_data in enumerate(action_history_list): - action_type = "未知" - plan_reason = "未知" - status = "未知" - final_reason = "" - action_time = "" - - if isinstance(action_data, dict): - action_type = action_data.get("action", "未知") - plan_reason = action_data.get("plan_reason", "未知规划原因") - status = action_data.get("status", "未知") - final_reason = action_data.get("final_reason", "") - action_time = action_data.get("time", "") - elif isinstance(action_data, tuple): - # 假设旧格式兼容 - if len(action_data) > 0: - action_type = action_data[0] - if len(action_data) > 1: - plan_reason = action_data[1] # 可能是规划原因或最终原因 - if len(action_data) > 2: - status = action_data[2] - if status == "recall" and len(action_data) > 3: - final_reason = action_data[3] - elif status == "done" and action_type in ["direct_reply", "send_new_message"]: - plan_reason = "成功发送" # 简化显示 - - reason_text = f", 失败/取消原因: {final_reason}" if final_reason else "" - summary_line = f"- 时间:{action_time}, 尝试行动:'{action_type}', 状态:{status}{reason_text}" - action_history_summary += summary_line + "\n" - - if i == len(action_history_list) - 1: - last_action_context += f"- 上次【规划】的行动是: '{action_type}'\n" - last_action_context += f"- 当时规划的【原因】是: {plan_reason}\n" - if status == "done": - last_action_context += "- 该行动已【成功执行】。\n" - # 记录这次成功的行动类型,供下次决策 - # self.last_successful_action_type = action_type # 不在这里记录,由 conversation 控制 - elif status == "recall": - last_action_context += "- 但该行动最终【未能执行/被取消】。\n" - if final_reason: - last_action_context += f"- 【重要】失败/取消的具体原因是: “{final_reason}”\n" - else: - last_action_context += "- 【重要】失败/取消原因未明确记录。\n" - # self.last_successful_action_type = None # 行动失败,清除记录 - else: - last_action_context += f"- 该行动当前状态: {status}\n" - # self.last_successful_action_type = None # 非完成状态,清除记录 - - # --- 选择 Prompt --- - if last_successful_reply_action in ["direct_reply", "send_new_message"]: - prompt_template = PROMPT_FOLLOW_UP - logger.debug(f"[私聊][{self.private_name}]使用 PROMPT_FOLLOW_UP (追问决策)") - else: - prompt_template = PROMPT_INITIAL_REPLY - logger.debug(f"[私聊][{self.private_name}]使用 PROMPT_INITIAL_REPLY (首次/非连续回复决策)") - - # --- 格式化最终的 Prompt --- - prompt = prompt_template.format( - persona_text=persona_text, - goals_str=goals_str if goals_str.strip() else "- 目前没有明确对话目标,请考虑设定一个。", - action_history_summary=action_history_summary, - last_action_context=last_action_context, - time_since_last_bot_message_info=time_since_last_bot_message_info, - timeout_context=timeout_context, - chat_history_text=chat_history_text if chat_history_text.strip() else "还没有聊天记录。", - knowledge_info_str=knowledge_info_str, - ) - - logger.debug(f"[私聊][{self.private_name}]发送到LLM的最终提示词:\n------\n{prompt}\n------") - try: - content, _ = await self.llm.generate_response_async(prompt) - logger.debug(f"[私聊][{self.private_name}]LLM (行动规划) 原始返回内容: {content}") - - # --- 初始行动规划解析 --- - success, initial_result = get_items_from_json( - content, - self.private_name, - "action", - "reason", - default_values={"action": "wait", "reason": "LLM返回格式错误或未提供原因,默认等待"}, - ) - - initial_action = initial_result.get("action", "wait") - initial_reason = initial_result.get("reason", "LLM未提供原因,默认等待") - - # 检查是否需要进行结束对话决策 --- - if initial_action == "end_conversation": - logger.info(f"[私聊][{self.private_name}]初步规划结束对话,进入告别决策...") - - # 使用新的 PROMPT_END_DECISION - end_decision_prompt = PROMPT_END_DECISION.format( - persona_text=persona_text, # 复用之前的 persona_text - chat_history_text=chat_history_text, # 复用之前的 chat_history_text - ) - - logger.debug( - f"[私聊][{self.private_name}]发送到LLM的结束决策提示词:\n------\n{end_decision_prompt}\n------" - ) - try: - end_content, _ = await self.llm.generate_response_async(end_decision_prompt) # 再次调用LLM - logger.debug(f"[私聊][{self.private_name}]LLM (结束决策) 原始返回内容: {end_content}") - - # 解析结束决策的JSON - end_success, end_result = get_items_from_json( - end_content, - self.private_name, - "say_bye", - "reason", - default_values={"say_bye": "no", "reason": "结束决策LLM返回格式错误,默认不告别"}, - required_types={"say_bye": str, "reason": str}, # 明确类型 - ) - - say_bye_decision = end_result.get("say_bye", "no").lower() # 转小写方便比较 - end_decision_reason = end_result.get("reason", "未提供原因") - - if end_success and say_bye_decision == "yes": - # 决定要告别,返回新的 'say_goodbye' 动作 - logger.info( - f"[私聊][{self.private_name}]结束决策: yes, 准备生成告别语. 原因: {end_decision_reason}" - ) - # 注意:这里的 reason 可以考虑拼接初始原因和结束决策原因,或者只用结束决策原因 - final_action = "say_goodbye" - final_reason = f"决定发送告别语。决策原因: {end_decision_reason} (原结束理由: {initial_reason})" - return final_action, final_reason - else: - # 决定不告别 (包括解析失败或明确说no) - logger.info( - f"[私聊][{self.private_name}]结束决策: no, 直接结束对话. 原因: {end_decision_reason}" - ) - # 返回原始的 'end_conversation' 动作 - final_action = "end_conversation" - final_reason = initial_reason # 保持原始的结束理由 - return final_action, final_reason - - except Exception as end_e: - logger.error(f"[私聊][{self.private_name}]调用结束决策LLM或处理结果时出错: {str(end_e)}") - # 出错时,默认执行原始的结束对话 - logger.warning(f"[私聊][{self.private_name}]结束决策出错,将按原计划执行 end_conversation") - return "end_conversation", initial_reason # 返回原始动作和原因 - - else: - action = initial_action - reason = initial_reason - - # 验证action类型 (保持不变) - valid_actions = [ - "direct_reply", - "send_new_message", - "fetch_knowledge", - "wait", - "listening", - "rethink_goal", - "end_conversation", # 仍然需要验证,因为可能从上面决策后返回 - "block_and_ignore", - "say_goodbye", # 也要验证这个新动作 - ] - if action not in valid_actions: - logger.warning(f"[私聊][{self.private_name}]LLM返回了未知的行动类型: '{action}',强制改为 wait") - reason = f"(原始行动'{action}'无效,已强制改为wait) {reason}" - action = "wait" - - logger.info(f"[私聊][{self.private_name}]规划的行动: {action}") - logger.info(f"[私聊][{self.private_name}]行动原因: {reason}") - return action, reason - - except Exception as e: - # 外层异常处理保持不变 - logger.error(f"[私聊][{self.private_name}]规划行动时调用 LLM 或处理结果出错: {str(e)}") - return "wait", f"行动规划处理中发生错误,暂时等待: {str(e)}" diff --git a/src/chat/brain_chat/PFC/chat_observer.py b/src/chat/brain_chat/PFC/chat_observer.py deleted file mode 100644 index 60426d4c..00000000 --- a/src/chat/brain_chat/PFC/chat_observer.py +++ /dev/null @@ -1,394 +0,0 @@ -import time -import asyncio -import traceback -from datetime import datetime -from typing import Optional, Dict, Any, List -from src.common.logger import get_logger -from sqlmodel import select, col -from src.common.database.database import get_db_session -from src.common.database.database_model import Messages -from maim_message import UserInfo -from src.config.config import global_config -from .chat_states import NotificationManager, create_new_message_notification, create_cold_chat_notification -from rich.traceback import install - -install(extra_lines=3) - -logger = get_logger("chat_observer") - - -def _message_to_dict(message: Messages) -> Dict[str, Any]: - """Convert Peewee Message model to dict for PFC compatibility - - Args: - message: Peewee Messages model instance - - Returns: - Dict[str, Any]: Message dictionary - """ - message_timestamp = message.timestamp.timestamp() if isinstance(message.timestamp, datetime) else message.timestamp - return { - "message_id": message.message_id, - "time": message_timestamp, - "chat_id": message.session_id, - "user_id": message.user_id, - "user_nickname": message.user_nickname, - "processed_plain_text": message.processed_plain_text, - "display_message": message.display_message, - "is_mentioned": message.is_mentioned, - "is_command": message.is_command, - # Add user_info dict for compatibility with existing code - "user_info": { - "user_id": message.user_id, - "user_nickname": message.user_nickname, - }, - } - - -class ChatObserver: - """聊天状态观察器""" - - # 类级别的实例管理 - _instances: Dict[str, "ChatObserver"] = {} - - @classmethod - def get_instance(cls, stream_id: str, private_name: str) -> "ChatObserver": - """获取或创建观察器实例 - - Args: - stream_id: 聊天流ID - private_name: 私聊名称 - - Returns: - ChatObserver: 观察器实例 - """ - if stream_id not in cls._instances: - cls._instances[stream_id] = cls(stream_id, private_name) - return cls._instances[stream_id] - - def __init__(self, stream_id: str, private_name: str): - """初始化观察器 - - Args: - stream_id: 聊天流ID - """ - self.last_check_time = None - self.last_bot_speak_time = None - self.last_user_speak_time = None - if stream_id in self._instances: - raise RuntimeError(f"ChatObserver for {stream_id} already exists. Use get_instance() instead.") - - self.stream_id = stream_id - self.private_name = private_name - - self.last_message_read: Optional[Dict[str, Any]] = None - self.last_message_time: float = time.time() - - self.waiting_start_time: float = time.time() # 等待开始时间,初始化为当前时间 - - # 运行状态 - self._running: bool = False - self._task: Optional[asyncio.Task] = None - self._update_event = asyncio.Event() # 触发更新的事件 - self._update_complete = asyncio.Event() # 更新完成的事件 - - # 通知管理器 - self.notification_manager = NotificationManager() - - # 冷场检查配置 - self.cold_chat_threshold: float = 60.0 # 60秒无消息判定为冷场 - self.last_cold_chat_check: float = time.time() - self.is_cold_chat_state: bool = False - - self.update_event = asyncio.Event() - self.update_interval = 2 # 更新间隔(秒) - self.message_cache = [] - self.update_running = False - - async def check(self) -> bool: - """检查距离上一次观察之后是否有了新消息 - - Returns: - bool: 是否有新消息 - """ - logger.debug(f"[私聊][{self.private_name}]检查距离上一次观察之后是否有了新消息: {self.last_check_time}") - - last_check_time = self.last_check_time or 0.0 - last_check_dt = datetime.fromtimestamp(last_check_time) - with get_db_session() as session: - statement = select(Messages).where( - (col(Messages.session_id) == self.stream_id) & (col(Messages.timestamp) > last_check_dt) - ) - new_message_exists = session.exec(statement).first() is not None - - if new_message_exists: - logger.debug(f"[私聊][{self.private_name}]发现新消息") - self.last_check_time = time.time() - - return new_message_exists - - async def _add_message_to_history(self, message: Dict[str, Any]): - """添加消息到历史记录并发送通知 - - Args: - message: 消息数据 - """ - try: - # 发送新消息通知 - notification = create_new_message_notification( - sender="chat_observer", target="observation_info", message=message - ) - # print(self.notification_manager) - await self.notification_manager.send_notification(notification) - except Exception as e: - logger.error(f"[私聊][{self.private_name}]添加消息到历史记录时出错: {e}") - print(traceback.format_exc()) - - # 检查并更新冷场状态 - await self._check_cold_chat() - - async def _check_cold_chat(self): - """检查是否处于冷场状态并发送通知""" - current_time = time.time() - - # 每10秒检查一次冷场状态 - if current_time - self.last_cold_chat_check < 10: - return - - self.last_cold_chat_check = current_time - - # 判断是否冷场 - is_cold = ( - True - if self.last_message_time is None - else (current_time - self.last_message_time) > self.cold_chat_threshold - ) - - # 如果冷场状态发生变化,发送通知 - if is_cold != self.is_cold_chat_state: - self.is_cold_chat_state = is_cold - notification = create_cold_chat_notification(sender="chat_observer", target="pfc", is_cold=is_cold) - await self.notification_manager.send_notification(notification) - - def new_message_after(self, time_point: float) -> bool: - """判断是否在指定时间点后有新消息 - - Args: - time_point: 时间戳 - - Returns: - bool: 是否有新消息 - """ - - if self.last_message_time is None: - logger.debug(f"[私聊][{self.private_name}]没有最后消息时间,返回 False") - return False - - has_new = self.last_message_time > time_point - logger.debug( - f"[私聊][{self.private_name}]判断是否在指定时间点后有新消息: {self.last_message_time} > {time_point} = {has_new}" - ) - return has_new - - async def _fetch_new_messages(self) -> List[Dict[str, Any]]: - """获取新消息 - - Returns: - List[Dict[str, Any]]: 新消息列表 - """ - last_message_time = self.last_message_time or 0.0 - last_message_dt = datetime.fromtimestamp(last_message_time) - with get_db_session() as session: - statement = ( - select(Messages) - .where((col(Messages.session_id) == self.stream_id) & (col(Messages.timestamp) > last_message_dt)) - .order_by(col(Messages.timestamp)) - ) - new_messages = [_message_to_dict(msg) for msg in session.exec(statement).all()] - - if new_messages: - self.last_message_read = new_messages[-1] - self.last_message_time = new_messages[-1]["time"] - - # print(f"获取数据库中找到的新消息: {new_messages}") - - return new_messages - - async def _fetch_new_messages_before(self, time_point: float) -> List[Dict[str, Any]]: - """获取指定时间点之前的消息 - - Args: - time_point: 时间戳 - - Returns: - List[Dict[str, Any]]: 最多5条消息 - """ - time_point_dt = datetime.fromtimestamp(time_point) - with get_db_session() as session: - statement = ( - select(Messages) - .where((col(Messages.session_id) == self.stream_id) & (col(Messages.timestamp) < time_point_dt)) - .order_by(col(Messages.timestamp)) - .limit(5) - ) - messages = list(session.exec(statement).all()) - messages.reverse() - new_messages = [_message_to_dict(msg) for msg in messages] - - if new_messages: - self.last_message_read = new_messages[-1]["message_id"] - - logger.debug(f"[私聊][{self.private_name}]获取指定时间点111之前的消息: {new_messages}") - - return new_messages - - """主要观察循环""" - - async def _update_loop(self): - """更新循环""" - # try: - # start_time = time.time() - # messages = await self._fetch_new_messages_before(start_time) - # for message in messages: - # await self._add_message_to_history(message) - # logger.debug(f"[私聊][{self.private_name}]缓冲消息: {messages}") - # except Exception as e: - # logger.error(f"[私聊][{self.private_name}]缓冲消息出错: {e}") - - while self._running: - try: - # 等待事件或超时(1秒) - try: - # print("等待事件") - await asyncio.wait_for(self._update_event.wait(), timeout=1) - - except asyncio.TimeoutError: - # print("超时") - pass # 超时后也执行一次检查 - - self._update_event.clear() # 重置触发事件 - self._update_complete.clear() # 重置完成事件 - - # 获取新消息 - new_messages = await self._fetch_new_messages() - - if new_messages: - # 处理新消息 - for message in new_messages: - await self._add_message_to_history(message) - - # 设置完成事件 - self._update_complete.set() - - except Exception as e: - logger.error(f"[私聊][{self.private_name}]更新循环出错: {e}") - logger.error(f"[私聊][{self.private_name}]{traceback.format_exc()}") - self._update_complete.set() # 即使出错也要设置完成事件 - - def trigger_update(self): - """触发一次立即更新""" - self._update_event.set() - - async def wait_for_update(self, timeout: float = 5.0) -> bool: - """等待更新完成 - - Args: - timeout: 超时时间(秒) - - Returns: - bool: 是否成功完成更新(False表示超时) - """ - try: - await asyncio.wait_for(self._update_complete.wait(), timeout=timeout) - return True - except asyncio.TimeoutError: - logger.warning(f"[私聊][{self.private_name}]等待更新完成超时({timeout}秒)") - return False - - def start(self): - """启动观察器""" - if self._running: - return - - self._running = True - self._task = asyncio.create_task(self._update_loop()) - logger.debug(f"[私聊][{self.private_name}]ChatObserver for {self.stream_id} started") - - def stop(self): - """停止观察器""" - self._running = False - self._update_event.set() # 设置事件以解除等待 - self._update_complete.set() # 设置完成事件以解除等待 - if self._task: - self._task.cancel() - logger.debug(f"[私聊][{self.private_name}]ChatObserver for {self.stream_id} stopped") - - async def process_chat_history(self, messages: list): - """处理聊天历史 - - Args: - messages: 消息列表 - """ - self.update_check_time() - - for msg in messages: - try: - user_info = UserInfo.from_dict(msg.get("user_info", {})) - if user_info.user_id == global_config.bot.qq_account: - self.update_bot_speak_time(msg["time"]) - else: - self.update_user_speak_time(msg["time"]) - except Exception as e: - logger.warning(f"[私聊][{self.private_name}]处理消息时间时出错: {e}") - continue - - def update_check_time(self): - """更新查看时间""" - self.last_check_time = time.time() - - def update_bot_speak_time(self, speak_time: Optional[float] = None): - """更新机器人说话时间""" - self.last_bot_speak_time = speak_time or time.time() - - def update_user_speak_time(self, speak_time: Optional[float] = None): - """更新用户说话时间""" - self.last_user_speak_time = speak_time or time.time() - - def get_time_info(self) -> str: - """获取时间信息文本""" - current_time = time.time() - time_info = "" - - if self.last_bot_speak_time: - bot_speak_ago = current_time - self.last_bot_speak_time - time_info += f"\n距离你上次发言已经过去了{int(bot_speak_ago)}秒" - - if self.last_user_speak_time: - user_speak_ago = current_time - self.last_user_speak_time - time_info += f"\n距离对方上次发言已经过去了{int(user_speak_ago)}秒" - - return time_info - - def get_cached_messages(self, limit: int = 50) -> List[Dict[str, Any]]: - """获取缓存的消息历史 - - Args: - limit: 获取的最大消息数量,默认50 - - Returns: - List[Dict[str, Any]]: 缓存的消息历史列表 - """ - return self.message_cache[-limit:] - - def get_last_message(self) -> Optional[Dict[str, Any]]: - """获取最后一条消息 - - Returns: - Optional[Dict[str, Any]]: 最后一条消息,如果没有则返回None - """ - if not self.message_cache: - return None - return self.message_cache[-1] - - def __str__(self): - return f"ChatObserver for {self.stream_id}" diff --git a/src/chat/brain_chat/PFC/chat_states.py b/src/chat/brain_chat/PFC/chat_states.py deleted file mode 100644 index 4b839b7b..00000000 --- a/src/chat/brain_chat/PFC/chat_states.py +++ /dev/null @@ -1,290 +0,0 @@ -from enum import Enum, auto -from typing import Optional, Dict, Any, List, Set -from dataclasses import dataclass -from datetime import datetime -from abc import ABC, abstractmethod - - -class ChatState(Enum): - """聊天状态枚举""" - - NORMAL = auto() # 正常状态 - NEW_MESSAGE = auto() # 有新消息 - COLD_CHAT = auto() # 冷场状态 - ACTIVE_CHAT = auto() # 活跃状态 - BOT_SPEAKING = auto() # 机器人正在说话 - USER_SPEAKING = auto() # 用户正在说话 - SILENT = auto() # 沉默状态 - ERROR = auto() # 错误状态 - - -class NotificationType(Enum): - """通知类型枚举""" - - NEW_MESSAGE = auto() # 新消息通知 - COLD_CHAT = auto() # 冷场通知 - ACTIVE_CHAT = auto() # 活跃通知 - BOT_SPEAKING = auto() # 机器人说话通知 - USER_SPEAKING = auto() # 用户说话通知 - MESSAGE_DELETED = auto() # 消息删除通知 - USER_JOINED = auto() # 用户加入通知 - USER_LEFT = auto() # 用户离开通知 - ERROR = auto() # 错误通知 - - -@dataclass -class ChatStateInfo: - """聊天状态信息""" - - state: ChatState - last_message_time: Optional[float] = None - last_message_content: Optional[str] = None - last_speaker: Optional[str] = None - message_count: int = 0 - cold_duration: float = 0.0 # 冷场持续时间(秒) - active_duration: float = 0.0 # 活跃持续时间(秒) - - -@dataclass -class Notification: - """通知基类""" - - type: NotificationType - timestamp: float - sender: str # 发送者标识 - target: str # 接收者标识 - data: Dict[str, Any] - - def to_dict(self) -> Dict[str, Any]: - """转换为字典格式""" - return {"type": self.type.name, "timestamp": self.timestamp, "data": self.data} - - -@dataclass -class StateNotification(Notification): - """持续状态通知""" - - is_active: bool = True - - def to_dict(self) -> Dict[str, Any]: - base_dict = super().to_dict() - base_dict["is_active"] = self.is_active - return base_dict - - -class NotificationHandler(ABC): - """通知处理器接口""" - - @abstractmethod - async def handle_notification(self, notification: Notification): - """处理通知""" - pass - - -class NotificationManager: - """通知管理器""" - - def __init__(self): - # 按接收者和通知类型存储处理器 - self._handlers: Dict[str, Dict[NotificationType, List[NotificationHandler]]] = {} - self._active_states: Set[NotificationType] = set() - self._notification_history: List[Notification] = [] - - def register_handler(self, target: str, notification_type: NotificationType, handler: NotificationHandler): - """注册通知处理器 - - Args: - target: 接收者标识(例如:"pfc") - notification_type: 要处理的通知类型 - handler: 处理器实例 - """ - if target not in self._handlers: - self._handlers[target] = {} - if notification_type not in self._handlers[target]: - self._handlers[target][notification_type] = [] - # print(self._handlers[target][notification_type]) - self._handlers[target][notification_type].append(handler) - # print(self._handlers[target][notification_type]) - - def unregister_handler(self, target: str, notification_type: NotificationType, handler: NotificationHandler): - """注销通知处理器 - - Args: - target: 接收者标识 - notification_type: 通知类型 - handler: 要注销的处理器实例 - """ - if target in self._handlers and notification_type in self._handlers[target]: - handlers = self._handlers[target][notification_type] - if handler in handlers: - handlers.remove(handler) - # 如果该类型的处理器列表为空,删除该类型 - if not handlers: - del self._handlers[target][notification_type] - # 如果该目标没有任何处理器,删除该目标 - if not self._handlers[target]: - del self._handlers[target] - - async def send_notification(self, notification: Notification): - """发送通知""" - self._notification_history.append(notification) - - # 如果是状态通知,更新活跃状态 - if isinstance(notification, StateNotification): - if notification.is_active: - self._active_states.add(notification.type) - else: - self._active_states.discard(notification.type) - - # 调用目标接收者的处理器 - target = notification.target - if target in self._handlers: - handlers = self._handlers[target].get(notification.type, []) - # print(handlers) - for handler in handlers: - # print(f"调用处理器: {handler}") - await handler.handle_notification(notification) - - def get_active_states(self) -> Set[NotificationType]: - """获取当前活跃的状态""" - return self._active_states.copy() - - def is_state_active(self, state_type: NotificationType) -> bool: - """检查特定状态是否活跃""" - return state_type in self._active_states - - def get_notification_history( - self, sender: Optional[str] = None, target: Optional[str] = None, limit: Optional[int] = None - ) -> List[Notification]: - """获取通知历史 - - Args: - sender: 过滤特定发送者的通知 - target: 过滤特定接收者的通知 - limit: 限制返回数量 - """ - history = self._notification_history - - if sender: - history = [n for n in history if n.sender == sender] - if target: - history = [n for n in history if n.target == target] - - if limit is not None: - history = history[-limit:] - - return history - - def __str__(self): - str = "" - for target, handlers in self._handlers.items(): - for notification_type, handler_list in handlers.items(): - str += f"NotificationManager for {target} {notification_type} {handler_list}" - return str - - -# 一些常用的通知创建函数 -def create_new_message_notification(sender: str, target: str, message: Dict[str, Any]) -> Notification: - """创建新消息通知""" - return Notification( - type=NotificationType.NEW_MESSAGE, - timestamp=datetime.now().timestamp(), - sender=sender, - target=target, - data={ - "message_id": message.get("message_id"), - "processed_plain_text": message.get("processed_plain_text"), - "detailed_plain_text": message.get("detailed_plain_text"), - "user_info": message.get("user_info"), - "time": message.get("time"), - }, - ) - - -def create_cold_chat_notification(sender: str, target: str, is_cold: bool) -> StateNotification: - """创建冷场状态通知""" - return StateNotification( - type=NotificationType.COLD_CHAT, - timestamp=datetime.now().timestamp(), - sender=sender, - target=target, - data={"is_cold": is_cold}, - is_active=is_cold, - ) - - -def create_active_chat_notification(sender: str, target: str, is_active: bool) -> StateNotification: - """创建活跃状态通知""" - return StateNotification( - type=NotificationType.ACTIVE_CHAT, - timestamp=datetime.now().timestamp(), - sender=sender, - target=target, - data={"is_active": is_active}, - is_active=is_active, - ) - - -class ChatStateManager: - """聊天状态管理器""" - - def __init__(self): - self.current_state = ChatState.NORMAL - self.state_info = ChatStateInfo(state=ChatState.NORMAL) - self.state_history: list[ChatStateInfo] = [] - - def update_state(self, new_state: ChatState, **kwargs): - """更新聊天状态 - - Args: - new_state: 新的状态 - **kwargs: 其他状态信息 - """ - self.current_state = new_state - self.state_info.state = new_state - - # 更新其他状态信息 - for key, value in kwargs.items(): - if hasattr(self.state_info, key): - setattr(self.state_info, key, value) - - # 记录状态历史 - self.state_history.append(self.state_info) - - def get_current_state_info(self) -> ChatStateInfo: - """获取当前状态信息""" - return self.state_info - - def get_state_history(self) -> list[ChatStateInfo]: - """获取状态历史""" - return self.state_history - - def is_cold_chat(self, threshold: float = 60.0) -> bool: - """判断是否处于冷场状态 - - Args: - threshold: 冷场阈值(秒) - - Returns: - bool: 是否冷场 - """ - if not self.state_info.last_message_time: - return True - - current_time = datetime.now().timestamp() - return (current_time - self.state_info.last_message_time) > threshold - - def is_active_chat(self, threshold: float = 5.0) -> bool: - """判断是否处于活跃状态 - - Args: - threshold: 活跃阈值(秒) - - Returns: - bool: 是否活跃 - """ - if not self.state_info.last_message_time: - return False - - current_time = datetime.now().timestamp() - return (current_time - self.state_info.last_message_time) <= threshold diff --git a/src/chat/brain_chat/PFC/conversation.py b/src/chat/brain_chat/PFC/conversation.py deleted file mode 100644 index ab5a7b3d..00000000 --- a/src/chat/brain_chat/PFC/conversation.py +++ /dev/null @@ -1,722 +0,0 @@ -import asyncio -import datetime -import time - -from typing import Dict, Any, Optional - -from src.common.data_models.mai_message_data_model import MaiMessage -from src.services.message_service import build_readable_messages, get_messages_before_time_in_chat - -# from .message_storage import MongoDBMessageStorage -# from src.config.config import global_config -from .pfc_types import ConversationState -from .pfc import ChatObserver, GoalAnalyzer -from .message_sender import DirectMessageSender -from src.common.logger import get_logger -from .action_planner import ActionPlanner -from .observation_info import ObservationInfo -from .conversation_info import ConversationInfo # 确保导入 ConversationInfo -from .reply_generator import ReplyGenerator -from src.chat.message_receive.chat_manager import chat_manager as _chat_manager -from maim_message import UserInfo -from .pfc_KnowledgeFetcher import KnowledgeFetcher -from .waiter import Waiter - -import traceback -from rich.traceback import install - -install(extra_lines=3) - -logger = get_logger("pfc") - - -class Conversation: - """对话类,负责管理单个对话的状态和行为""" - - def __init__(self, stream_id: str, private_name: str): - """初始化对话实例 - - Args: - stream_id: 聊天流ID - """ - self.stream_id = stream_id - self.private_name = private_name - self.state = ConversationState.INIT - self.should_continue = False - self.ignore_until_timestamp: Optional[float] = None - - # 回复相关 - self.generated_reply = "" - - async def _initialize(self): - """初始化实例,注册所有组件""" - - try: - self.action_planner = ActionPlanner(self.stream_id, self.private_name) - self.goal_analyzer = GoalAnalyzer(self.stream_id, self.private_name) - self.reply_generator = ReplyGenerator(self.stream_id, self.private_name) - self.knowledge_fetcher = KnowledgeFetcher(self.private_name, self.stream_id) - self.waiter = Waiter(self.stream_id, self.private_name) - self.direct_sender = DirectMessageSender(self.private_name) - - # 获取聊天流信息 - self.chat_stream = _chat_manager.get_session_by_session_id(self.stream_id) - - self.stop_action_planner = False - except Exception as e: - logger.error(f"[私聊][{self.private_name}]初始化对话实例:注册运行组件失败: {e}") - logger.error(f"[私聊][{self.private_name}]{traceback.format_exc()}") - raise - - try: - # 决策所需要的信息,包括自身自信和观察信息两部分 - # 注册观察器和观测信息 - self.chat_observer = ChatObserver.get_instance(self.stream_id, self.private_name) - self.chat_observer.start() - self.observation_info = ObservationInfo(self.private_name) - self.observation_info.bind_to_chat_observer(self.chat_observer) - # print(self.chat_observer.get_cached_messages(limit=) - - self.conversation_info = ConversationInfo() - except Exception as e: - logger.error(f"[私聊][{self.private_name}]初始化对话实例:注册信息组件失败: {e}") - logger.error(f"[私聊][{self.private_name}]{traceback.format_exc()}") - raise - try: - logger.info(f"[私聊][{self.private_name}]为 {self.stream_id} 加载初始聊天记录...") - initial_messages = get_messages_before_time_in_chat( - chat_id=self.stream_id, - timestamp=time.time(), - limit=30, # 加载最近30条作为初始上下文,可以调整 - ) - chat_talking_prompt = build_readable_messages( - initial_messages, - replace_bot_name=True, - timestamp_mode="relative", - read_mark=0.0, - ) - if initial_messages: - # 将 SessionMessage 列表转换为 PFC 期望的 dict 格式(保持嵌套结构) - initial_messages_dict: list[dict] = [] - for msg in initial_messages: - user_info = msg.message_info.user_info - msg_dict = { - "message_id": msg.message_id, - "time": msg.timestamp.timestamp(), - "chat_id": msg.session_id, - "processed_plain_text": msg.processed_plain_text, - "display_message": msg.display_message, - "is_mentioned": msg.is_mentioned, - "is_command": msg.is_command, - "user_info": { - "user_id": user_info.user_id, - "user_nickname": user_info.user_nickname, - "user_cardname": user_info.user_cardname, - "platform": msg.platform, - }, - } - initial_messages_dict.append(msg_dict) - - # 将加载的消息填充到 ObservationInfo 的 chat_history - self.observation_info.chat_history = initial_messages_dict - self.observation_info.chat_history_str = chat_talking_prompt + "\n" - self.observation_info.chat_history_count = len(initial_messages_dict) - - # 更新 ObservationInfo 中的时间戳等信息 - last_msg_dict: dict = initial_messages_dict[-1] - self.observation_info.last_message_time = last_msg_dict.get("time") - last_user_info = UserInfo.from_dict(last_msg_dict.get("user_info", {})) - self.observation_info.last_message_sender = last_user_info.user_id - self.observation_info.last_message_content = last_msg_dict.get("processed_plain_text", "") - - logger.info( - f"[私聊][{self.private_name}]成功加载 {len(initial_messages_dict)} 条初始聊天记录。最后一条消息时间: {self.observation_info.last_message_time}" - ) - - # 让 ChatObserver 从加载的最后一条消息之后开始同步 - if self.observation_info.last_message_time: - self.chat_observer.last_message_time = self.observation_info.last_message_time - self.chat_observer.last_message_read = last_msg_dict # 更新 observer 的最后读取记录 - else: - logger.info(f"[私聊][{self.private_name}]没有找到初始聊天记录。") - - except Exception as load_err: - logger.error(f"[私聊][{self.private_name}]加载初始聊天记录时出错: {load_err}") - # 出错也要继续,只是没有历史记录而已 - # 组件准备完成,启动该论对话 - self.should_continue = True - asyncio.create_task(self.start()) - - async def start(self): - """开始对话流程""" - try: - logger.info(f"[私聊][{self.private_name}]对话系统启动中...") - asyncio.create_task(self._plan_and_action_loop()) - except Exception as e: - logger.error(f"[私聊][{self.private_name}]启动对话系统失败: {e}") - raise - - async def _plan_and_action_loop(self): - """思考步,PFC核心循环模块""" - while self.should_continue: - # 忽略逻辑 - if self.ignore_until_timestamp and time.time() < self.ignore_until_timestamp: - await asyncio.sleep(30) - continue - elif self.ignore_until_timestamp and time.time() >= self.ignore_until_timestamp: - logger.info(f"[私聊][{self.private_name}]忽略时间已到 {self.stream_id},准备结束对话。") - self.ignore_until_timestamp = None - self.should_continue = False - continue - try: - # --- 在规划前记录当前新消息数量 --- - initial_new_message_count = 0 - if hasattr(self.observation_info, "new_messages_count"): - initial_new_message_count = self.observation_info.new_messages_count + 1 # 算上麦麦自己发的那一条 - else: - logger.warning( - f"[私聊][{self.private_name}]ObservationInfo missing 'new_messages_count' before planning." - ) - - # --- 调用 Action Planner --- - # 传递 self.conversation_info.last_successful_reply_action - action, reason = await self.action_planner.plan( - self.observation_info, self.conversation_info, self.conversation_info.last_successful_reply_action - ) - - # --- 规划后检查是否有 *更多* 新消息到达 --- - current_new_message_count = 0 - if hasattr(self.observation_info, "new_messages_count"): - current_new_message_count = self.observation_info.new_messages_count - else: - logger.warning( - f"[私聊][{self.private_name}]ObservationInfo missing 'new_messages_count' after planning." - ) - - if current_new_message_count > initial_new_message_count + 2: - logger.info( - f"[私聊][{self.private_name}]规划期间发现新增消息 ({initial_new_message_count} -> {current_new_message_count}),跳过本次行动,重新规划" - ) - # 如果规划期间有新消息,也应该重置上次回复状态,因为现在要响应新消息了 - self.conversation_info.last_successful_reply_action = None - await asyncio.sleep(0.1) - continue - - # 包含 send_new_message - if initial_new_message_count > 0 and action in ["direct_reply", "send_new_message"]: - if hasattr(self.observation_info, "clear_unprocessed_messages"): - logger.debug( - f"[私聊][{self.private_name}]准备执行 {action},清理 {initial_new_message_count} 条规划时已知的新消息。" - ) - await self.observation_info.clear_unprocessed_messages() - if hasattr(self.observation_info, "new_messages_count"): - self.observation_info.new_messages_count = 0 - else: - logger.error( - f"[私聊][{self.private_name}]无法清理未处理消息: ObservationInfo 缺少 clear_unprocessed_messages 方法!" - ) - - await self._handle_action(action, reason, self.observation_info, self.conversation_info) - - # 检查是否需要结束对话 (逻辑不变) - goal_ended = False - if hasattr(self.conversation_info, "goal_list") and self.conversation_info.goal_list: - for goal_item in self.conversation_info.goal_list: - if isinstance(goal_item, dict): - current_goal = goal_item.get("goal") - - if current_goal == "结束对话": - goal_ended = True - break - - if goal_ended: - self.should_continue = False - logger.info(f"[私聊][{self.private_name}]检测到'结束对话'目标,停止循环。") - - except Exception as loop_err: - logger.error(f"[私聊][{self.private_name}]PFC主循环出错: {loop_err}") - logger.error(f"[私聊][{self.private_name}]{traceback.format_exc()}") - await asyncio.sleep(1) - - if self.should_continue: - await asyncio.sleep(0.1) - - logger.info(f"[私聊][{self.private_name}]PFC 循环结束 for stream_id: {self.stream_id}") - - def _check_new_messages_after_planning(self): - """检查在规划后是否有新消息""" - # 检查 ObservationInfo 是否已初始化并且有 new_messages_count 属性 - if not hasattr(self, "observation_info") or not hasattr(self.observation_info, "new_messages_count"): - logger.warning( - f"[私聊][{self.private_name}]ObservationInfo 未初始化或缺少 'new_messages_count' 属性,无法检查新消息。" - ) - return False # 或者根据需要抛出错误 - - if self.observation_info.new_messages_count > 2: - logger.info( - f"[私聊][{self.private_name}]生成/执行动作期间收到 {self.observation_info.new_messages_count} 条新消息,取消当前动作并重新规划" - ) - # 如果有新消息,也应该重置上次回复状态 - if hasattr(self, "conversation_info"): # 确保 conversation_info 已初始化 - self.conversation_info.last_successful_reply_action = None - else: - logger.warning( - f"[私聊][{self.private_name}]ConversationInfo 未初始化,无法重置 last_successful_reply_action。" - ) - return True - return False - - def _convert_to_message(self, msg_dict: Dict[str, Any]) -> MaiMessage: - """将消息字典转换为MaiMessage对象""" - from datetime import datetime as dt - from src.common.data_models.mai_message_data_model import UserInfo as MaiUserInfo, MessageInfo - from src.common.data_models.message_component_data_model import MessageSequence - - try: - user_info_dict = msg_dict.get("user_info", {}) - user_info = MaiUserInfo( - user_id=user_info_dict.get("user_id", ""), - user_nickname=user_info_dict.get("user_nickname", ""), - user_cardname=user_info_dict.get("user_cardname"), - ) - - msg = MaiMessage( - message_id=msg_dict.get("message_id", f"gen_{time.time()}"), - timestamp=dt.fromtimestamp(msg_dict.get("time", time.time())), - ) - msg.message_info = MessageInfo(user_info=user_info) - msg.platform = user_info_dict.get("platform", "") - msg.session_id = self.stream_id - msg.processed_plain_text = msg_dict.get("processed_plain_text", "") - msg.raw_message = MessageSequence(components=[]) - msg.initialized = True - return msg - except Exception as e: - logger.warning(f"[私聊][{self.private_name}]转换消息时出错: {e}") - raise ValueError(f"无法将字典转换为 MaiMessage 对象: {e}") from e - - async def _handle_action( - self, action: str, reason: str, observation_info: ObservationInfo, conversation_info: ConversationInfo - ): - """处理规划的行动""" - - logger.debug(f"[私聊][{self.private_name}]执行行动: {action}, 原因: {reason}") - - # 记录action历史 (逻辑不变) - current_action_record = { - "action": action, - "plan_reason": reason, - "status": "start", - "time": datetime.datetime.now().strftime("%H:%M:%S"), - "final_reason": None, - } - # 确保 done_action 列表存在 - if not hasattr(conversation_info, "done_action"): - conversation_info.done_action = [] - conversation_info.done_action.append(current_action_record) - action_index = len(conversation_info.done_action) - 1 - - action_successful = False # 用于标记动作是否成功完成 - - # --- 根据不同的 action 执行 --- - - # send_new_message 失败后执行 wait - if action == "send_new_message": - max_reply_attempts = 3 - reply_attempt_count = 0 - is_suitable = False - need_replan = False - check_reason = "未进行尝试" - final_reply_to_send = "" - - while reply_attempt_count < max_reply_attempts and not is_suitable: - reply_attempt_count += 1 - logger.info( - f"[私聊][{self.private_name}]尝试生成追问回复 (第 {reply_attempt_count}/{max_reply_attempts} 次)..." - ) - self.state = ConversationState.GENERATING - - # 1. 生成回复 (调用 generate 时传入 action_type) - self.generated_reply = await self.reply_generator.generate( - observation_info, conversation_info, action_type="send_new_message" - ) - logger.info( - f"[私聊][{self.private_name}]第 {reply_attempt_count} 次生成的追问回复: {self.generated_reply}" - ) - - # 2. 检查回复 (逻辑不变) - self.state = ConversationState.CHECKING - try: - current_goal_str = conversation_info.goal_list[0]["goal"] if conversation_info.goal_list else "" - is_suitable, check_reason, need_replan = await self.reply_generator.check_reply( - reply=self.generated_reply, - goal=current_goal_str, - chat_history=observation_info.chat_history, - chat_history_str=observation_info.chat_history_str, - retry_count=reply_attempt_count - 1, - ) - logger.info( - f"[私聊][{self.private_name}]第 {reply_attempt_count} 次追问检查结果: 合适={is_suitable}, 原因='{check_reason}', 需重新规划={need_replan}" - ) - if is_suitable: - final_reply_to_send = self.generated_reply - break - elif need_replan: - logger.warning( - f"[私聊][{self.private_name}]第 {reply_attempt_count} 次追问检查建议重新规划,停止尝试。原因: {check_reason}" - ) - break - except Exception as check_err: - logger.error( - f"[私聊][{self.private_name}]第 {reply_attempt_count} 次调用 ReplyChecker (追问) 时出错: {check_err}" - ) - check_reason = f"第 {reply_attempt_count} 次检查过程出错: {check_err}" - break - - # 循环结束,处理最终结果 - if is_suitable: - # 检查是否有新消息 - if self._check_new_messages_after_planning(): - logger.info(f"[私聊][{self.private_name}]生成追问回复期间收到新消息,取消发送,重新规划行动") - conversation_info.done_action[action_index].update( - {"status": "recall", "final_reason": f"有新消息,取消发送追问: {final_reply_to_send}"} - ) - return # 直接返回,重新规划 - - # 发送合适的回复 - self.generated_reply = final_reply_to_send - # --- 在这里调用 _send_reply --- - await self._send_reply() # <--- 调用恢复后的函数 - - # 更新状态: 标记上次成功是 send_new_message - self.conversation_info.last_successful_reply_action = "send_new_message" - action_successful = True # 标记动作成功 - - elif need_replan: - # 打回动作决策 - logger.warning( - f"[私聊][{self.private_name}]经过 {reply_attempt_count} 次尝试,追问回复决定打回动作决策。打回原因: {check_reason}" - ) - conversation_info.done_action[action_index].update( - {"status": "recall", "final_reason": f"追问尝试{reply_attempt_count}次后打回: {check_reason}"} - ) - - else: - # 追问失败 - logger.warning( - f"[私聊][{self.private_name}]经过 {reply_attempt_count} 次尝试,未能生成合适的追问回复。最终原因: {check_reason}" - ) - conversation_info.done_action[action_index].update( - {"status": "recall", "final_reason": f"追问尝试{reply_attempt_count}次后失败: {check_reason}"} - ) - # 重置状态: 追问失败,下次用初始 prompt - self.conversation_info.last_successful_reply_action = None - - # 执行 Wait 操作 - logger.info(f"[私聊][{self.private_name}]由于无法生成合适追问回复,执行 'wait' 操作...") - self.state = ConversationState.WAITING - await self.waiter.wait(self.conversation_info) - wait_action_record = { - "action": "wait", - "plan_reason": "因 send_new_message 多次尝试失败而执行的后备等待", - "status": "done", - "time": datetime.datetime.now().strftime("%H:%M:%S"), - "final_reason": None, - } - conversation_info.done_action.append(wait_action_record) - - elif action == "direct_reply": - max_reply_attempts = 3 - reply_attempt_count = 0 - is_suitable = False - need_replan = False - check_reason = "未进行尝试" - final_reply_to_send = "" - - while reply_attempt_count < max_reply_attempts and not is_suitable: - reply_attempt_count += 1 - logger.info( - f"[私聊][{self.private_name}]尝试生成首次回复 (第 {reply_attempt_count}/{max_reply_attempts} 次)..." - ) - self.state = ConversationState.GENERATING - - # 1. 生成回复 - self.generated_reply = await self.reply_generator.generate( - observation_info, conversation_info, action_type="direct_reply" - ) - logger.info( - f"[私聊][{self.private_name}]第 {reply_attempt_count} 次生成的首次回复: {self.generated_reply}" - ) - - # 2. 检查回复 - self.state = ConversationState.CHECKING - try: - current_goal_str = conversation_info.goal_list[0]["goal"] if conversation_info.goal_list else "" - is_suitable, check_reason, need_replan = await self.reply_generator.check_reply( - reply=self.generated_reply, - goal=current_goal_str, - chat_history=observation_info.chat_history, - chat_history_str=observation_info.chat_history_str, - retry_count=reply_attempt_count - 1, - ) - logger.info( - f"[私聊][{self.private_name}]第 {reply_attempt_count} 次首次回复检查结果: 合适={is_suitable}, 原因='{check_reason}', 需重新规划={need_replan}" - ) - if is_suitable: - final_reply_to_send = self.generated_reply - break - elif need_replan: - logger.warning( - f"[私聊][{self.private_name}]第 {reply_attempt_count} 次首次回复检查建议重新规划,停止尝试。原因: {check_reason}" - ) - break - except Exception as check_err: - logger.error( - f"[私聊][{self.private_name}]第 {reply_attempt_count} 次调用 ReplyChecker (首次回复) 时出错: {check_err}" - ) - check_reason = f"第 {reply_attempt_count} 次检查过程出错: {check_err}" - break - - # 循环结束,处理最终结果 - if is_suitable: - # 检查是否有新消息 - if self._check_new_messages_after_planning(): - logger.info(f"[私聊][{self.private_name}]生成首次回复期间收到新消息,取消发送,重新规划行动") - conversation_info.done_action[action_index].update( - {"status": "recall", "final_reason": f"有新消息,取消发送首次回复: {final_reply_to_send}"} - ) - return # 直接返回,重新规划 - - # 发送合适的回复 - self.generated_reply = final_reply_to_send - # --- 在这里调用 _send_reply --- - await self._send_reply() # <--- 调用恢复后的函数 - - # 更新状态: 标记上次成功是 direct_reply - self.conversation_info.last_successful_reply_action = "direct_reply" - action_successful = True # 标记动作成功 - - elif need_replan: - # 打回动作决策 - logger.warning( - f"[私聊][{self.private_name}]经过 {reply_attempt_count} 次尝试,首次回复决定打回动作决策。打回原因: {check_reason}" - ) - conversation_info.done_action[action_index].update( - {"status": "recall", "final_reason": f"首次回复尝试{reply_attempt_count}次后打回: {check_reason}"} - ) - - else: - # 首次回复失败 - logger.warning( - f"[私聊][{self.private_name}]经过 {reply_attempt_count} 次尝试,未能生成合适的首次回复。最终原因: {check_reason}" - ) - conversation_info.done_action[action_index].update( - {"status": "recall", "final_reason": f"首次回复尝试{reply_attempt_count}次后失败: {check_reason}"} - ) - # 重置状态: 首次回复失败,下次还是用初始 prompt - self.conversation_info.last_successful_reply_action = None - - # 执行 Wait 操作 (保持原有逻辑) - logger.info(f"[私聊][{self.private_name}]由于无法生成合适首次回复,执行 'wait' 操作...") - self.state = ConversationState.WAITING - await self.waiter.wait(self.conversation_info) - wait_action_record = { - "action": "wait", - "plan_reason": "因 direct_reply 多次尝试失败而执行的后备等待", - "status": "done", - "time": datetime.datetime.now().strftime("%H:%M:%S"), - "final_reason": None, - } - conversation_info.done_action.append(wait_action_record) - - elif action == "fetch_knowledge": - self.state = ConversationState.FETCHING - knowledge_query = reason - try: - # 检查 knowledge_fetcher 是否存在 - if not hasattr(self, "knowledge_fetcher"): - logger.error(f"[私聊][{self.private_name}]KnowledgeFetcher 未初始化,无法获取知识。") - raise AttributeError("KnowledgeFetcher not initialized") - - knowledge, source = await self.knowledge_fetcher.fetch(knowledge_query, observation_info.chat_history) - logger.info(f"[私聊][{self.private_name}]获取到知识: {knowledge[:100]}..., 来源: {source}") - if knowledge: - # 确保 knowledge_list 存在 - if not hasattr(conversation_info, "knowledge_list"): - conversation_info.knowledge_list = [] - conversation_info.knowledge_list.append( - {"query": knowledge_query, "knowledge": knowledge, "source": source} - ) - action_successful = True - except Exception as fetch_err: - logger.error(f"[私聊][{self.private_name}]获取知识时出错: {str(fetch_err)}") - conversation_info.done_action[action_index].update( - {"status": "recall", "final_reason": f"获取知识失败: {str(fetch_err)}"} - ) - self.conversation_info.last_successful_reply_action = None # 重置状态 - - elif action == "rethink_goal": - self.state = ConversationState.RETHINKING - try: - # 检查 goal_analyzer 是否存在 - if not hasattr(self, "goal_analyzer"): - logger.error(f"[私聊][{self.private_name}]GoalAnalyzer 未初始化,无法重新思考目标。") - raise AttributeError("GoalAnalyzer not initialized") - await self.goal_analyzer.analyze_goal(conversation_info, observation_info) - action_successful = True - except Exception as rethink_err: - logger.error(f"[私聊][{self.private_name}]重新思考目标时出错: {rethink_err}") - conversation_info.done_action[action_index].update( - {"status": "recall", "final_reason": f"重新思考目标失败: {rethink_err}"} - ) - self.conversation_info.last_successful_reply_action = None # 重置状态 - - elif action == "listening": - self.state = ConversationState.LISTENING - logger.info(f"[私聊][{self.private_name}]倾听对方发言...") - try: - # 检查 waiter 是否存在 - if not hasattr(self, "waiter"): - logger.error(f"[私聊][{self.private_name}]Waiter 未初始化,无法倾听。") - raise AttributeError("Waiter not initialized") - await self.waiter.wait_listening(conversation_info) - action_successful = True # Listening 完成就算成功 - except Exception as listen_err: - logger.error(f"[私聊][{self.private_name}]倾听时出错: {listen_err}") - conversation_info.done_action[action_index].update( - {"status": "recall", "final_reason": f"倾听失败: {listen_err}"} - ) - self.conversation_info.last_successful_reply_action = None # 重置状态 - - elif action == "say_goodbye": - self.state = ConversationState.GENERATING # 也可以定义一个新的状态,如 ENDING - logger.info(f"[私聊][{self.private_name}]执行行动: 生成并发送告别语...") - try: - # 1. 生成告别语 (使用 'say_goodbye' action_type) - self.generated_reply = await self.reply_generator.generate( - observation_info, conversation_info, action_type="say_goodbye" - ) - logger.info(f"[私聊][{self.private_name}]生成的告别语: {self.generated_reply}") - - # 2. 直接发送告别语 (不经过检查) - if self.generated_reply: # 确保生成了内容 - await self._send_reply() # 调用发送方法 - # 发送成功后,标记动作成功 - action_successful = True - logger.info(f"[私聊][{self.private_name}]告别语已发送。") - else: - logger.warning(f"[私聊][{self.private_name}]未能生成告别语内容,无法发送。") - action_successful = False # 标记动作失败 - conversation_info.done_action[action_index].update( - {"status": "recall", "final_reason": "未能生成告别语内容"} - ) - - # 3. 无论是否发送成功,都准备结束对话 - self.should_continue = False - logger.info(f"[私聊][{self.private_name}]发送告别语流程结束,即将停止对话实例。") - - except Exception as goodbye_err: - logger.error(f"[私聊][{self.private_name}]生成或发送告别语时出错: {goodbye_err}") - logger.error(f"[私聊][{self.private_name}]{traceback.format_exc()}") - # 即使出错,也结束对话 - self.should_continue = False - action_successful = False # 标记动作失败 - conversation_info.done_action[action_index].update( - {"status": "recall", "final_reason": f"生成或发送告别语时出错: {goodbye_err}"} - ) - - elif action == "end_conversation": - # 这个分支现在只会在 action_planner 最终决定不告别时被调用 - self.should_continue = False - logger.info(f"[私聊][{self.private_name}]收到最终结束指令,停止对话...") - action_successful = True # 标记这个指令本身是成功的 - - elif action == "block_and_ignore": - logger.info(f"[私聊][{self.private_name}]不想再理你了...") - ignore_duration_seconds = 10 * 60 - self.ignore_until_timestamp = time.time() + ignore_duration_seconds - logger.info( - f"[私聊][{self.private_name}]将忽略此对话直到: {datetime.datetime.fromtimestamp(self.ignore_until_timestamp)}" - ) - self.state = ConversationState.IGNORED - action_successful = True # 标记动作成功 - - else: # 对应 'wait' 动作 - self.state = ConversationState.WAITING - logger.info(f"[私聊][{self.private_name}]等待更多信息...") - try: - # 检查 waiter 是否存在 - if not hasattr(self, "waiter"): - logger.error(f"[私聊][{self.private_name}]Waiter 未初始化,无法等待。") - raise AttributeError("Waiter not initialized") - _timeout_occurred = await self.waiter.wait(self.conversation_info) - action_successful = True # Wait 完成就算成功 - except Exception as wait_err: - logger.error(f"[私聊][{self.private_name}]等待时出错: {wait_err}") - conversation_info.done_action[action_index].update( - {"status": "recall", "final_reason": f"等待失败: {wait_err}"} - ) - self.conversation_info.last_successful_reply_action = None # 重置状态 - - # --- 更新 Action History 状态 --- - # 只有当动作本身成功时,才更新状态为 done - if action_successful: - conversation_info.done_action[action_index].update( - { - "status": "done", - "time": datetime.datetime.now().strftime("%H:%M:%S"), - } - ) - # 重置状态: 对于非回复类动作的成功,清除上次回复状态 - if action not in ["direct_reply", "send_new_message"]: - self.conversation_info.last_successful_reply_action = None - logger.debug(f"[私聊][{self.private_name}]动作 {action} 成功完成,重置 last_successful_reply_action") - # 如果动作是 recall 状态,在各自的处理逻辑中已经更新了 done_action - - async def _send_reply(self): - """发送回复""" - if not self.generated_reply: - logger.warning(f"[私聊][{self.private_name}]没有生成回复内容,无法发送。") - return - - try: - _current_time = time.time() - reply_content = self.generated_reply - - # 发送消息 (确保 direct_sender 和 chat_stream 有效) - if not hasattr(self, "direct_sender") or not self.direct_sender: - logger.error(f"[私聊][{self.private_name}]DirectMessageSender 未初始化,无法发送回复。") - return - if not self.chat_stream: - logger.error(f"[私聊][{self.private_name}]会话未初始化,无法发送回复。") - return - - await self.direct_sender.send_message(chat_stream=self.chat_stream, content=reply_content) - - # 发送成功后,手动触发 observer 更新可能导致重复处理自己发送的消息 - # 更好的做法是依赖 observer 的自动轮询或数据库触发器(如果支持) - # 暂时注释掉,观察是否影响 ObservationInfo 的更新 - # self.chat_observer.trigger_update() - # if not await self.chat_observer.wait_for_update(): - # logger.warning(f"[私聊][{self.private_name}]等待 ChatObserver 更新完成超时") - - self.state = ConversationState.ANALYZING # 更新状态 - - except Exception as e: - logger.error(f"[私聊][{self.private_name}]发送消息或更新状态时失败: {str(e)}") - logger.error(f"[私聊][{self.private_name}]{traceback.format_exc()}") - self.state = ConversationState.ANALYZING - - async def _send_timeout_message(self): - """发送超时结束消息""" - try: - messages = self.chat_observer.get_cached_messages(limit=1) - if not messages: - return - - latest_message = self._convert_to_message(messages[0]) - await self.direct_sender.send_message( - chat_stream=self.chat_stream, content="TODO:超时消息", reply_to_message=latest_message - ) - except Exception as e: - logger.error(f"[私聊][{self.private_name}]发送超时消息失败: {str(e)}") diff --git a/src/chat/brain_chat/PFC/conversation_info.py b/src/chat/brain_chat/PFC/conversation_info.py deleted file mode 100644 index d9afd6ac..00000000 --- a/src/chat/brain_chat/PFC/conversation_info.py +++ /dev/null @@ -1,10 +0,0 @@ -from typing import Optional - - -class ConversationInfo: - def __init__(self): - self.done_action: list = [] - self.goal_list: list = [] - self.knowledge_list: list = [] - self.memory_list: list = [] - self.last_successful_reply_action: Optional[str] = None diff --git a/src/chat/brain_chat/PFC/message_sender.py b/src/chat/brain_chat/PFC/message_sender.py deleted file mode 100644 index b9da905c..00000000 --- a/src/chat/brain_chat/PFC/message_sender.py +++ /dev/null @@ -1,61 +0,0 @@ -"""PFC 侧消息发送封装。""" - -from typing import Optional - -from rich.traceback import install - -from src.chat.message_receive.chat_manager import BotChatSession -from src.common.data_models.mai_message_data_model import MaiMessage -from src.common.logger import get_logger -from src.services import send_service as send_api - -install(extra_lines=3) - -logger = get_logger("message_sender") - - -class DirectMessageSender: - """直接消息发送器。""" - - def __init__(self, private_name: str) -> None: - """初始化直接消息发送器。 - - Args: - private_name: 当前私聊实例的名称。 - """ - self.private_name = private_name - - async def send_message( - self, - chat_stream: BotChatSession, - content: str, - reply_to_message: Optional[MaiMessage] = None, - ) -> None: - """发送文本消息到聊天流。 - - Args: - chat_stream: 目标聊天会话。 - content: 待发送的文本内容。 - reply_to_message: 可选的引用回复锚点消息。 - - Raises: - RuntimeError: 当消息发送失败时抛出。 - """ - try: - sent = await send_api.text_to_stream( - text=content, - stream_id=chat_stream.session_id, - set_reply=reply_to_message is not None, - reply_message=reply_to_message, - storage_message=True, - ) - - if sent: - logger.info(f"[私聊][{self.private_name}]PFC消息已发送: {content}") - return - - logger.error(f"[私聊][{self.private_name}]PFC消息发送失败") - raise RuntimeError("消息发送失败") - except Exception as exc: - logger.error(f"[私聊][{self.private_name}]PFC消息发送失败: {exc}") - raise diff --git a/src/chat/brain_chat/PFC/observation_info.py b/src/chat/brain_chat/PFC/observation_info.py deleted file mode 100644 index 3d3b235a..00000000 --- a/src/chat/brain_chat/PFC/observation_info.py +++ /dev/null @@ -1,429 +0,0 @@ -from datetime import datetime -from typing import Any, Dict, List, Optional, Set - -from maim_message import UserInfo -import time - -from src.chat.message_receive.message import SessionMessage -from src.common.logger import get_logger -from src.common.data_models.mai_message_data_model import MessageInfo, UserInfo as MaiUserInfo -from src.services.message_service import build_readable_messages - -from .chat_observer import ChatObserver -from .chat_states import NotificationHandler, NotificationType, Notification -import traceback # 导入 traceback 用于调试 - -logger = get_logger("observation_info") - - -def dict_to_session_message(msg_dict: Dict[str, Any]) -> SessionMessage: - """Convert PFC dict format to SessionMessage object. - - Args: - msg_dict: Message in PFC dict format with nested user_info - - Returns: - SessionMessage object compatible with build_readable_messages() - """ - user_info_dict: Dict[str, Any] = msg_dict.get("user_info", {}) - timestamp = msg_dict.get("time", 0.0) - platform = user_info_dict.get("platform", "") - message = SessionMessage( - message_id=msg_dict.get("message_id", ""), - timestamp=datetime.fromtimestamp(timestamp), - platform=platform, - ) - message.message_info = MessageInfo( - user_info=MaiUserInfo( - user_id=user_info_dict.get("user_id", ""), - user_nickname=user_info_dict.get("user_nickname", ""), - user_cardname=user_info_dict.get("user_cardname"), - ) - ) - message.session_id = msg_dict.get("chat_id", "") - message.processed_plain_text = msg_dict.get("processed_plain_text", "") - message.display_message = msg_dict.get("display_message", "") - message.is_mentioned = msg_dict.get("is_mentioned", False) - message.is_command = msg_dict.get("is_command", False) - message.initialized = True - return message - - -class ObservationInfoHandler(NotificationHandler): - """ObservationInfo的通知处理器""" - - def __init__(self, observation_info: "ObservationInfo", private_name: str): - """初始化处理器 - - Args: - observation_info: 要更新的ObservationInfo实例 - private_name: 私聊对象的名称,用于日志记录 - """ - self.observation_info = observation_info - # 将 private_name 存储在 handler 实例中 - self.private_name = private_name - - async def handle_notification(self, notification: Notification): # 添加类型提示 - # 获取通知类型和数据 - notification_type = notification.type - data = notification.data - - try: # 添加错误处理块 - if notification_type == NotificationType.NEW_MESSAGE: - # 处理新消息通知 - # logger.debug(f"[私聊][{self.private_name}]收到新消息通知data: {data}") # 可以在需要时取消注释 - message_id = data.get("message_id") - processed_plain_text = data.get("processed_plain_text") - detailed_plain_text = data.get("detailed_plain_text") - user_info_dict = data.get("user_info") # 先获取字典 - time_value = data.get("time") - - # 确保 user_info 是字典类型再创建 UserInfo 对象 - user_info = None - if isinstance(user_info_dict, dict): - try: - user_info = UserInfo.from_dict(user_info_dict) - except Exception as e: - logger.error( - f"[私聊][{self.private_name}]从字典创建 UserInfo 时出错: {e}, 字典内容: {user_info_dict}" - ) - # 可以选择在这里返回或记录错误,避免后续代码出错 - return - elif user_info_dict is not None: - logger.warning( - f"[私聊][{self.private_name}]收到的 user_info 不是预期的字典类型: {type(user_info_dict)}" - ) - # 根据需要处理非字典情况,这里暂时返回 - return - - message = { - "message_id": message_id, - "processed_plain_text": processed_plain_text, - "detailed_plain_text": detailed_plain_text, - "user_info": user_info_dict, # 存储原始字典或 UserInfo 对象,取决于你的 update_from_message 如何处理 - "time": time_value, - } - # 传递 UserInfo 对象(如果成功创建)或原始字典 - await self.observation_info.update_from_message(message, user_info) # 修改:传递 user_info 对象 - - elif notification_type == NotificationType.COLD_CHAT: - # 处理冷场通知 - is_cold = data.get("is_cold", False) - await self.observation_info.update_cold_chat_status(is_cold, time.time()) # 修改:改为 await 调用 - - elif notification_type == NotificationType.ACTIVE_CHAT: - # 处理活跃通知 (通常由 COLD_CHAT 的反向状态处理) - is_active = data.get("is_active", False) - self.observation_info.is_cold = not is_active - - elif notification_type == NotificationType.BOT_SPEAKING: - # 处理机器人说话通知 (按需实现) - self.observation_info.is_typing = False - self.observation_info.last_bot_speak_time = time.time() - - elif notification_type == NotificationType.USER_SPEAKING: - # 处理用户说话通知 - self.observation_info.is_typing = False - self.observation_info.last_user_speak_time = time.time() - - elif notification_type == NotificationType.MESSAGE_DELETED: - # 处理消息删除通知 - message_id = data.get("message_id") - # 从 unprocessed_messages 中移除被删除的消息 - original_count = len(self.observation_info.unprocessed_messages) - self.observation_info.unprocessed_messages = [ - msg for msg in self.observation_info.unprocessed_messages if msg.get("message_id") != message_id - ] - if len(self.observation_info.unprocessed_messages) < original_count: - logger.info(f"[私聊][{self.private_name}]移除了未处理的消息 (ID: {message_id})") - - elif notification_type == NotificationType.USER_JOINED: - # 处理用户加入通知 (如果适用私聊场景) - user_id = data.get("user_id") - if user_id: - self.observation_info.active_users.add(str(user_id)) # 确保是字符串 - - elif notification_type == NotificationType.USER_LEFT: - # 处理用户离开通知 (如果适用私聊场景) - user_id = data.get("user_id") - if user_id: - self.observation_info.active_users.discard(str(user_id)) # 确保是字符串 - - elif notification_type == NotificationType.ERROR: - # 处理错误通知 - error_msg = data.get("error", "未提供错误信息") - logger.error(f"[私聊][{self.private_name}]收到错误通知: {error_msg}") - - except Exception as e: - logger.error(f"[私聊][{self.private_name}]处理通知时发生错误: {e}") - logger.error(traceback.format_exc()) # 打印详细堆栈信息 - - -# @dataclass <-- 这个,不需要了(递黄瓜) -class ObservationInfo: - """决策信息类,用于收集和管理来自chat_observer的通知信息 (手动实现 __init__)""" - - # 类型提示保留,可用于文档和静态分析 - private_name: str - chat_history: List[Dict[str, Any]] - chat_history_str: str - unprocessed_messages: List[Dict[str, Any]] - active_users: Set[str] - last_bot_speak_time: Optional[float] - last_user_speak_time: Optional[float] - last_message_time: Optional[float] - last_message_id: Optional[str] - last_message_content: str - last_message_sender: Optional[str] - bot_id: Optional[str] - chat_history_count: int - new_messages_count: int - cold_chat_start_time: Optional[float] - cold_chat_duration: float - is_typing: bool - is_cold_chat: bool - changed: bool - chat_observer: Optional[ChatObserver] - handler: Optional[ObservationInfoHandler] - - def __init__(self, private_name: str): - """ - 手动初始化 ObservationInfo 的所有实例变量。 - """ - - # 接收的参数 - self.private_name: str = private_name - - # data_list - self.chat_history: List[Dict[str, Any]] = [] - self.chat_history_str: str = "" - self.unprocessed_messages: List[Dict[str, Any]] = [] - self.active_users: Set[str] = set() - - # data - self.last_bot_speak_time: Optional[float] = None - self.last_user_speak_time: Optional[float] = None - self.last_message_time: Optional[float] = None - self.last_message_id: Optional[str] = None - self.last_message_content: str = "" - self.last_message_sender: Optional[str] = None - self.bot_id: Optional[str] = None - self.chat_history_count: int = 0 - self.new_messages_count: int = 0 - self.cold_chat_start_time: Optional[float] = None - self.cold_chat_duration: float = 0.0 - - # state - self.is_typing: bool = False - self.is_cold_chat: bool = False - self.changed: bool = False - - # 关联对象 - self.chat_observer: Optional[ChatObserver] = None - - self.handler: ObservationInfoHandler = ObservationInfoHandler(self, self.private_name) - - def bind_to_chat_observer(self, chat_observer: ChatObserver): - """绑定到指定的chat_observer - - Args: - chat_observer: 要绑定的 ChatObserver 实例 - """ - if self.chat_observer: - logger.warning(f"[私聊][{self.private_name}]尝试重复绑定 ChatObserver") - return - - self.chat_observer = chat_observer - try: - if not self.handler: # 确保 handler 已经被创建 - logger.error(f"[私聊][{self.private_name}] 尝试绑定时 handler 未初始化!") - self.chat_observer = None # 重置,防止后续错误 - return - - # 注册关心的通知类型 - self.chat_observer.notification_manager.register_handler( - target="observation_info", notification_type=NotificationType.NEW_MESSAGE, handler=self.handler - ) - self.chat_observer.notification_manager.register_handler( - target="observation_info", notification_type=NotificationType.COLD_CHAT, handler=self.handler - ) - # 可以根据需要注册更多通知类型 - # self.chat_observer.notification_manager.register_handler( - # target="observation_info", notification_type=NotificationType.MESSAGE_DELETED, handler=self.handler - # ) - logger.info(f"[私聊][{self.private_name}]成功绑定到 ChatObserver") - except Exception as e: - logger.error(f"[私聊][{self.private_name}]绑定到 ChatObserver 时出错: {e}") - self.chat_observer = None # 绑定失败,重置 - - def unbind_from_chat_observer(self): - """解除与chat_observer的绑定""" - if ( - self.chat_observer and hasattr(self.chat_observer, "notification_manager") and self.handler - ): # 增加 handler 检查 - try: - self.chat_observer.notification_manager.unregister_handler( - target="observation_info", notification_type=NotificationType.NEW_MESSAGE, handler=self.handler - ) - self.chat_observer.notification_manager.unregister_handler( - target="observation_info", notification_type=NotificationType.COLD_CHAT, handler=self.handler - ) - # 如果注册了其他类型,也要在这里注销 - # self.chat_observer.notification_manager.unregister_handler( - # target="observation_info", notification_type=NotificationType.MESSAGE_DELETED, handler=self.handler - # ) - logger.info(f"[私聊][{self.private_name}]成功从 ChatObserver 解绑") - except Exception as e: - logger.error(f"[私聊][{self.private_name}]从 ChatObserver 解绑时出错: {e}") - finally: # 确保 chat_observer 被重置 - self.chat_observer = None - else: - logger.warning(f"[私聊][{self.private_name}]尝试解绑时 ChatObserver 不存在、无效或 handler 未设置") - - # 修改:update_from_message 接收 UserInfo 对象 - async def update_from_message(self, message: Dict[str, Any], user_info: Optional[UserInfo]): - """从消息更新信息 - - Args: - message: 消息数据字典 - user_info: 解析后的 UserInfo 对象 (可能为 None) - """ - message_time = message.get("time") - message_id = message.get("message_id") - processed_text = message.get("processed_plain_text", "") - - # 只有在新消息到达时才更新 last_message 相关信息 - if message_time and message_time > (self.last_message_time or 0): - self.last_message_time = message_time - self.last_message_id = message_id - self.last_message_content = processed_text - # 重置冷场计时器 - self.is_cold_chat = False - self.cold_chat_start_time = None - self.cold_chat_duration = 0.0 - - if user_info: - sender_id = str(user_info.user_id) # 确保是字符串 - self.last_message_sender = sender_id - # 更新发言时间 - if sender_id == self.bot_id: - self.last_bot_speak_time = message_time - else: - self.last_user_speak_time = message_time - self.active_users.add(sender_id) # 用户发言则认为其活跃 - else: - logger.warning( - f"[私聊][{self.private_name}]处理消息更新时缺少有效的 UserInfo 对象, message_id: {message_id}" - ) - self.last_message_sender = None # 发送者未知 - - # 将原始消息字典添加到未处理列表 - self.unprocessed_messages.append(message) - self.new_messages_count = len(self.unprocessed_messages) # 直接用列表长度 - - # logger.debug(f"[私聊][{self.private_name}]消息更新: last_time={self.last_message_time}, new_count={self.new_messages_count}") - self.update_changed() # 标记状态已改变 - else: - # 如果消息时间戳不是最新的,可能不需要处理,或者记录一个警告 - pass - # logger.warning(f"[私聊][{self.private_name}]收到过时或无效时间戳的消息: ID={message_id}, time={message_time}") - - def update_changed(self): - """标记状态已改变,并重置标记""" - # logger.debug(f"[私聊][{self.private_name}]状态标记为已改变 (changed=True)") - self.changed = True - - async def update_cold_chat_status(self, is_cold: bool, current_time: float): - """更新冷场状态 - - Args: - is_cold: 是否处于冷场状态 - current_time: 当前时间戳 - """ - if is_cold != self.is_cold_chat: # 仅在状态变化时更新 - self.is_cold_chat = is_cold - if is_cold: - # 进入冷场状态 - self.cold_chat_start_time = ( - self.last_message_time or current_time - ) # 从最后消息时间开始算,或从当前时间开始 - logger.info(f"[私聊][{self.private_name}]进入冷场状态,开始时间: {self.cold_chat_start_time}") - else: - # 结束冷场状态 - if self.cold_chat_start_time: - self.cold_chat_duration = current_time - self.cold_chat_start_time - logger.info(f"[私聊][{self.private_name}]结束冷场状态,持续时间: {self.cold_chat_duration:.2f} 秒") - self.cold_chat_start_time = None # 重置开始时间 - self.update_changed() # 状态变化,标记改变 - - # 即使状态没变,如果是冷场状态,也更新持续时间 - if self.is_cold_chat and self.cold_chat_start_time: - self.cold_chat_duration = current_time - self.cold_chat_start_time - - def get_active_duration(self) -> float: - """获取当前活跃时长 (距离最后一条消息的时间) - - Returns: - float: 最后一条消息到现在的时长(秒) - """ - if not self.last_message_time: - return 0.0 - return time.time() - self.last_message_time - - def get_user_response_time(self) -> Optional[float]: - """获取用户最后响应时间 (距离用户最后发言的时间) - - Returns: - Optional[float]: 用户最后发言到现在的时长(秒),如果没有用户发言则返回None - """ - if not self.last_user_speak_time: - return None - return time.time() - self.last_user_speak_time - - def get_bot_response_time(self) -> Optional[float]: - """获取机器人最后响应时间 (距离机器人最后发言的时间) - - Returns: - Optional[float]: 机器人最后发言到现在的时长(秒),如果没有机器人发言则返回None - """ - if not self.last_bot_speak_time: - return None - return time.time() - self.last_bot_speak_time - - async def clear_unprocessed_messages(self): - """将未处理消息移入历史记录,并更新相关状态""" - if not self.unprocessed_messages: - return # 没有未处理消息,直接返回 - - # logger.debug(f"[私聊][{self.private_name}]处理 {len(self.unprocessed_messages)} 条未处理消息...") - # 将未处理消息添加到历史记录中 (确保历史记录有长度限制,避免无限增长) - max_history_len = 100 # 示例:最多保留100条历史记录 - self.chat_history.extend(self.unprocessed_messages) - if len(self.chat_history) > max_history_len: - self.chat_history = self.chat_history[-max_history_len:] - - # 更新历史记录字符串 (只使用最近一部分生成,例如20条) - history_slice_for_str = self.chat_history[-20:] - try: - # Convert dict format to SessionMessage objects. - session_messages = [dict_to_session_message(m) for m in history_slice_for_str] - self.chat_history_str = build_readable_messages( - session_messages, - replace_bot_name=True, - timestamp_mode="relative", - read_mark=0.0, # read_mark 可能需要根据逻辑调整 - ) - except Exception as e: - logger.error(f"[私聊][{self.private_name}]构建聊天记录字符串时出错: {e}") - self.chat_history_str = "[构建聊天记录出错]" # 提供错误提示 - - # 清空未处理消息列表和计数 - # cleared_count = len(self.unprocessed_messages) - self.unprocessed_messages.clear() - self.new_messages_count = 0 - # self.has_unread_messages = False # 这个状态可以通过 new_messages_count 判断 - - self.chat_history_count = len(self.chat_history) # 更新历史记录总数 - # logger.debug(f"[私聊][{self.private_name}]已处理 {cleared_count} 条消息,当前历史记录 {self.chat_history_count} 条。") - - self.update_changed() # 状态改变 diff --git a/src/chat/brain_chat/PFC/pfc.py b/src/chat/brain_chat/PFC/pfc.py deleted file mode 100644 index 5d051716..00000000 --- a/src/chat/brain_chat/PFC/pfc.py +++ /dev/null @@ -1,361 +0,0 @@ -from typing import List, Tuple, TYPE_CHECKING -from src.common.logger import get_logger -from src.llm_models.utils_model import LLMRequest -from src.config.config import global_config, model_config -import random -from .chat_observer import ChatObserver -from .pfc_utils import get_items_from_json -from .conversation_info import ConversationInfo -from src.services.message_service import build_readable_messages - -from .observation_info import ObservationInfo, dict_to_session_message -from rich.traceback import install - -install(extra_lines=3) - -if TYPE_CHECKING: - pass - -logger = get_logger("pfc") - - -def _calculate_similarity(goal1: str, goal2: str) -> float: - """简单计算两个目标之间的相似度 - - 这里使用一个简单的实现,实际可以使用更复杂的文本相似度算法 - - Args: - goal1: 第一个目标 - goal2: 第二个目标 - - Returns: - float: 相似度得分 (0-1) - """ - # 简单实现:检查重叠字数比例 - words1 = set(goal1) - words2 = set(goal2) - overlap = len(words1.intersection(words2)) - total = len(words1.union(words2)) - return overlap / total if total > 0 else 0 - - -class GoalAnalyzer: - """对话目标分析器""" - - def __init__(self, stream_id: str, private_name: str): - self.llm = LLMRequest(model_set=model_config.model_task_config.planner, request_type="conversation_goal") - - self.personality_info = self._get_personality_prompt() - self.name = global_config.bot.nickname - self.nick_name = global_config.bot.alias_names - self.private_name = private_name - self.chat_observer = ChatObserver.get_instance(stream_id, private_name) - - # 多目标存储结构 - self.goals = [] # 存储多个目标 - self.max_goals = 3 # 同时保持的最大目标数量 - self.current_goal_and_reason = None - - def _get_personality_prompt(self) -> str: - """获取个性提示信息""" - prompt_personality = global_config.personality.personality - - # 检查是否需要随机替换为状态 - if ( - global_config.personality.states - and global_config.personality.state_probability > 0 - and random.random() < global_config.personality.state_probability - ): - prompt_personality = random.choice(global_config.personality.states) - - bot_name = global_config.bot.nickname - return f"你的名字是{bot_name},你{prompt_personality};" - - async def analyze_goal(self, conversation_info: ConversationInfo, observation_info: ObservationInfo): - """分析对话历史并设定目标 - - Args: - conversation_info: 对话信息 - observation_info: 观察信息 - - Returns: - Tuple[str, str, str]: (目标, 方法, 原因) - """ - # 构建对话目标 - goals_str = "" - if conversation_info.goal_list: - for goal_reason in conversation_info.goal_list: - if isinstance(goal_reason, dict): - goal = goal_reason.get("goal", "目标内容缺失") - reasoning = goal_reason.get("reasoning", "没有明确原因") - else: - goal = str(goal_reason) - reasoning = "没有明确原因" - - goal_str = f"目标:{goal},产生该对话目标的原因:{reasoning}\n" - goals_str += goal_str - else: - goal = "目前没有明确对话目标" - reasoning = "目前没有明确对话目标,最好思考一个对话目标" - goals_str = f"目标:{goal},产生该对话目标的原因:{reasoning}\n" - - # 获取聊天历史记录 - chat_history_text = observation_info.chat_history_str - - if observation_info.new_messages_count > 0: - new_messages_list = observation_info.unprocessed_messages - session_messages = [dict_to_session_message(m) for m in new_messages_list] - new_messages_str = build_readable_messages( - session_messages, - replace_bot_name=True, - timestamp_mode="relative", - read_mark=0.0, - ) - chat_history_text += f"\n--- 以下是 {observation_info.new_messages_count} 条新消息 ---\n{new_messages_str}" - - # await observation_info.clear_unprocessed_messages() - - persona_text = f"你的名字是{self.name},{self.personality_info}。" - # 构建action历史文本 - action_history_list = conversation_info.done_action - action_history_text = "你之前做的事情是:" - for action in action_history_list: - action_history_text += f"{action}\n" - - prompt = f"""{persona_text}。现在你在参与一场QQ聊天,请分析以下聊天记录,并根据你的性格特征确定多个明确的对话目标。 -这些目标应该反映出对话的不同方面和意图。 - -{action_history_text} -当前对话目标: -{goals_str} - -聊天记录: -{chat_history_text} - -请分析当前对话并确定最适合的对话目标。你可以: -1. 保持现有目标不变 -2. 修改现有目标 -3. 添加新目标 -4. 删除不再相关的目标 -5. 如果你想结束对话,请设置一个目标,目标goal为"结束对话",原因reasoning为你希望结束对话 - -请以JSON数组格式输出当前的所有对话目标,每个目标包含以下字段: -1. goal: 对话目标(简短的一句话) -2. reasoning: 对话原因,为什么设定这个目标(简要解释) - -输出格式示例: -[ -{{ - "goal": "回答用户关于Python编程的具体问题", - "reasoning": "用户提出了关于Python的技术问题,需要专业且准确的解答" -}}, -{{ - "goal": "回答用户关于python安装的具体问题", - "reasoning": "用户提出了关于Python的技术问题,需要专业且准确的解答" -}} -]""" - - logger.debug(f"[私聊][{self.private_name}]发送到LLM的提示词: {prompt}") - try: - content, _ = await self.llm.generate_response_async(prompt) - logger.debug(f"[私聊][{self.private_name}]LLM原始返回内容: {content}") - except Exception as e: - logger.error(f"[私聊][{self.private_name}]分析对话目标时出错: {str(e)}") - content = "" - - # 使用改进后的get_items_from_json函数处理JSON数组 - success, result = get_items_from_json( - content, - self.private_name, - "goal", - "reasoning", - required_types={"goal": str, "reasoning": str}, - allow_array=True, - ) - - if success: - # 判断结果是单个字典还是字典列表 - if isinstance(result, list): - # 清空现有目标列表并添加新目标 - conversation_info.goal_list = [] - for item in result: - conversation_info.goal_list.append(item) - - # 返回第一个目标作为当前主要目标(如果有) - if result: - first_goal = result[0] - return first_goal.get("goal", ""), "", first_goal.get("reasoning", "") - else: - # 单个目标的情况 - conversation_info.goal_list.append(result) - goal_value = result.get("goal", "") - reasoning_value = result.get("reasoning", "") - return goal_value, "", reasoning_value - - # 如果解析失败,返回默认值 - return "", "", "" - - async def _update_goals(self, new_goal: str, method: str, reasoning: str): - """更新目标列表 - - Args: - new_goal: 新的目标 - method: 实现目标的方法 - reasoning: 目标的原因 - """ - # 检查新目标是否与现有目标相似 - for i, (existing_goal, _, _) in enumerate(self.goals): - if _calculate_similarity(new_goal, existing_goal) > 0.7: # 相似度阈值 - # 更新现有目标 - self.goals[i] = (new_goal, method, reasoning) - # 将此目标移到列表前面(最主要的位置) - self.goals.insert(0, self.goals.pop(i)) - return - - # 添加新目标到列表前面 - self.goals.insert(0, (new_goal, method, reasoning)) - - # 限制目标数量 - if len(self.goals) > self.max_goals: - self.goals.pop() # 移除最老的目标 - - async def get_all_goals(self) -> List[Tuple[str, str, str]]: - """获取所有当前目标 - - Returns: - List[Tuple[str, str, str]]: 目标列表,每项为(目标, 方法, 原因) - """ - return self.goals.copy() - - async def get_alternative_goals(self) -> List[Tuple[str, str, str]]: - """获取除了当前主要目标外的其他备选目标 - - Returns: - List[Tuple[str, str, str]]: 备选目标列表 - """ - if len(self.goals) <= 1: - return [] - return self.goals[1:].copy() - - async def analyze_conversation(self, goal, reasoning): - messages = self.chat_observer.get_cached_messages() - session_messages = [dict_to_session_message(m) for m in messages] - chat_history_text = build_readable_messages( - session_messages, - replace_bot_name=True, - timestamp_mode="relative", - read_mark=0.0, - ) - - persona_text = f"你的名字是{self.name},{self.personality_info}。" - # ===> Persona 文本构建结束 <=== - - # --- 修改 Prompt 字符串,使用 persona_text --- - prompt = f"""{persona_text}。现在你在参与一场QQ聊天, - 当前对话目标:{goal} - 产生该对话目标的原因:{reasoning} - - 请分析以下聊天记录,并根据你的性格特征评估该目标是否已经达到,或者你是否希望停止该次对话。 - 聊天记录: - {chat_history_text} - 请以JSON格式输出,包含以下字段: - 1. goal_achieved: 对话目标是否已经达到(true/false) - 2. stop_conversation: 是否希望停止该次对话(true/false) - 3. reason: 为什么希望停止该次对话(简要解释) - -输出格式示例: -{{ - "goal_achieved": true, - "stop_conversation": false, - "reason": "虽然目标已达成,但对话仍然有继续的价值" -}}""" - - try: - content, _ = await self.llm.generate_response_async(prompt) - logger.debug(f"[私聊][{self.private_name}]LLM原始返回内容: {content}") - - # 尝试解析JSON - success, result = get_items_from_json( - content, - self.private_name, - "goal_achieved", - "stop_conversation", - "reason", - required_types={"goal_achieved": bool, "stop_conversation": bool, "reason": str}, - ) - - if not success: - logger.error(f"[私聊][{self.private_name}]无法解析对话分析结果JSON") - return False, False, "解析结果失败" - - goal_achieved = result["goal_achieved"] - stop_conversation = result["stop_conversation"] - reason = result["reason"] - - return goal_achieved, stop_conversation, reason - - except Exception as e: - logger.error(f"[私聊][{self.private_name}]分析对话状态时出错: {str(e)}") - return False, False, f"分析出错: {str(e)}" - - -# 先注释掉,万一以后出问题了还能开回来((( -# class DirectMessageSender: -# """直接发送消息到平台的发送器""" - -# def __init__(self, private_name: str): -# self.logger = get_module_logger("direct_sender") -# self.storage = MessageStorage() -# self.private_name = private_name - -# async def send_via_ws(self, message: MessageSending) -> None: -# try: -# await global_api.send_message(message) -# except Exception as e: -# raise ValueError(f"未找到平台:{message.message_info.platform} 的url配置,请检查配置文件") from e - -# async def send_message( -# self, -# chat_stream: ChatStream, -# content: str, -# reply_to_message: Optional[Message] = None, -# ) -> None: -# """直接发送消息到平台 - -# Args: -# chat_stream: 聊天流 -# content: 消息内容 -# reply_to_message: 要回复的消息 -# """ -# # 构建消息对象 -# message_segment = Seg(type="text", data=content) -# bot_user_info = UserInfo( -# user_id=global_config.BOT_QQ, -# user_nickname=global_config.BOT_NICKNAME, -# platform=chat_stream.platform, -# ) - -# message = MessageSending( -# message_id=f"dm{round(time.time(), 2)}", -# chat_stream=chat_stream, -# bot_user_info=bot_user_info, -# sender_info=reply_to_message.message_info.user_info if reply_to_message else None, -# message_segment=message_segment, -# reply=reply_to_message, -# is_head=True, -# is_emoji=False, -# thinking_start_time=time.time(), -# ) - -# # 处理消息 -# await message.process() - -# _message_json = message.to_dict() - -# # 发送消息 -# try: -# await self.send_via_ws(message) -# await self.storage.store_message(message, chat_stream) -# logger.success(f"[私聊][{self.private_name}]PFC消息已发送: {content}") -# except Exception as e: -# logger.error(f"[私聊][{self.private_name}]PFC消息发送失败: {str(e)}") diff --git a/src/chat/brain_chat/PFC/pfc_manager.py b/src/chat/brain_chat/PFC/pfc_manager.py deleted file mode 100644 index 174be78b..00000000 --- a/src/chat/brain_chat/PFC/pfc_manager.py +++ /dev/null @@ -1,115 +0,0 @@ -import time -from typing import Dict, Optional -from src.common.logger import get_logger -from .conversation import Conversation -import traceback - -logger = get_logger("pfc_manager") - - -class PFCManager: - """PFC对话管理器,负责管理所有对话实例""" - - # 单例模式 - _instance = None - - # 会话实例管理 - _instances: Dict[str, Conversation] = {} - _initializing: Dict[str, bool] = {} - - @classmethod - def get_instance(cls) -> "PFCManager": - """获取管理器单例 - - Returns: - PFCManager: 管理器实例 - """ - if cls._instance is None: - cls._instance = PFCManager() - return cls._instance - - async def get_or_create_conversation(self, stream_id: str, private_name: str) -> Optional[Conversation]: - """获取或创建对话实例 - - Args: - stream_id: 聊天流ID - private_name: 私聊名称 - - Returns: - Optional[Conversation]: 对话实例,创建失败则返回None - """ - # 检查是否已经有实例 - if stream_id in self._initializing and self._initializing[stream_id]: - logger.debug(f"[私聊][{private_name}]会话实例正在初始化中: {stream_id}") - return None - - if stream_id in self._instances and self._instances[stream_id].should_continue: - logger.debug(f"[私聊][{private_name}]使用现有会话实例: {stream_id}") - return self._instances[stream_id] - if stream_id in self._instances: - instance = self._instances[stream_id] - if ( - hasattr(instance, "ignore_until_timestamp") - and instance.ignore_until_timestamp - and time.time() < instance.ignore_until_timestamp - ): - logger.debug(f"[私聊][{private_name}]会话实例当前处于忽略状态: {stream_id}") - # 返回 None 阻止交互。或者可以返回实例但标记它被忽略了喵? - # 还是返回 None 吧喵。 - return None - - # 检查 should_continue 状态 - if instance.should_continue: - logger.debug(f"[私聊][{private_name}]使用现有会话实例: {stream_id}") - return instance - # else: 实例存在但不应继续 - try: - # 创建新实例 - logger.info(f"[私聊][{private_name}]创建新的对话实例: {stream_id}") - self._initializing[stream_id] = True - # 创建实例 - conversation_instance = Conversation(stream_id, private_name) - self._instances[stream_id] = conversation_instance - - # 启动实例初始化 - await self._initialize_conversation(conversation_instance) - except Exception as e: - logger.error(f"[私聊][{private_name}]创建会话实例失败: {stream_id}, 错误: {e}") - return None - - return conversation_instance - - async def _initialize_conversation(self, conversation: Conversation): - """初始化会话实例 - - Args: - conversation: 要初始化的会话实例 - """ - stream_id = conversation.stream_id - private_name = conversation.private_name - - try: - logger.info(f"[私聊][{private_name}]开始初始化会话实例: {stream_id}") - # 启动初始化流程 - await conversation._initialize() - - # 标记初始化完成 - self._initializing[stream_id] = False - - logger.info(f"[私聊][{private_name}]会话实例 {stream_id} 初始化完成") - - except Exception as e: - logger.error(f"[私聊][{private_name}]管理器初始化会话实例失败: {stream_id}, 错误: {e}") - logger.error(f"[私聊][{private_name}]{traceback.format_exc()}") - # 清理失败的初始化 - - async def get_conversation(self, stream_id: str) -> Optional[Conversation]: - """获取已存在的会话实例 - - Args: - stream_id: 聊天流ID - - Returns: - Optional[Conversation]: 会话实例,不存在则返回None - """ - return self._instances.get(stream_id) diff --git a/src/chat/brain_chat/PFC/pfc_types.py b/src/chat/brain_chat/PFC/pfc_types.py deleted file mode 100644 index 0ea5eda6..00000000 --- a/src/chat/brain_chat/PFC/pfc_types.py +++ /dev/null @@ -1,23 +0,0 @@ -from enum import Enum -from typing import Literal - - -class ConversationState(Enum): - """对话状态""" - - INIT = "初始化" - RETHINKING = "重新思考" - ANALYZING = "分析历史" - PLANNING = "规划目标" - GENERATING = "生成回复" - CHECKING = "检查回复" - SENDING = "发送消息" - FETCHING = "获取知识" - WAITING = "等待" - LISTENING = "倾听" - ENDED = "结束" - JUDGING = "判断" - IGNORED = "屏蔽" - - -ActionType = Literal["direct_reply", "fetch_knowledge", "wait"] diff --git a/src/chat/brain_chat/PFC/pfc_utils.py b/src/chat/brain_chat/PFC/pfc_utils.py deleted file mode 100644 index b9e93ee5..00000000 --- a/src/chat/brain_chat/PFC/pfc_utils.py +++ /dev/null @@ -1,127 +0,0 @@ -import json -import re -from typing import Dict, Any, Optional, Tuple, List, Union -from src.common.logger import get_logger - -logger = get_logger("pfc_utils") - - -def get_items_from_json( - content: str, - private_name: str, - *items: str, - default_values: Optional[Dict[str, Any]] = None, - required_types: Optional[Dict[str, type]] = None, - allow_array: bool = True, -) -> Tuple[bool, Union[Dict[str, Any], List[Dict[str, Any]]]]: - """从文本中提取JSON内容并获取指定字段 - - Args: - content: 包含JSON的文本 - private_name: 私聊名称 - *items: 要提取的字段名 - default_values: 字段的默认值,格式为 {字段名: 默认值} - required_types: 字段的必需类型,格式为 {字段名: 类型} - allow_array: 是否允许解析JSON数组 - - Returns: - Tuple[bool, Union[Dict[str, Any], List[Dict[str, Any]]]]: (是否成功, 提取的字段字典或字典列表) - """ - content = content.strip() - result = {} - - # 设置默认值 - if default_values: - result.update(default_values) - - # 首先尝试解析为JSON数组 - if allow_array: - try: - # 尝试找到文本中的JSON数组 - array_pattern = r"\[[\s\S]*\]" - array_match = re.search(array_pattern, content) - if array_match: - array_content = array_match.group() - json_array = json.loads(array_content) - - # 确认是数组类型 - if isinstance(json_array, list): - # 验证数组中的每个项目是否包含所有必需字段 - valid_items = [] - for item in json_array: - if not isinstance(item, dict): - continue - - # 检查是否有所有必需字段 - if all(field in item for field in items): - # 验证字段类型 - if required_types: - type_valid = True - for field, expected_type in required_types.items(): - if field in item and not isinstance(item[field], expected_type): - type_valid = False - break - - if not type_valid: - continue - - # 验证字符串字段不为空 - string_valid = True - for field in items: - if isinstance(item[field], str) and not item[field].strip(): - string_valid = False - break - - if not string_valid: - continue - - valid_items.append(item) - - if valid_items: - return True, valid_items - except json.JSONDecodeError: - logger.debug(f"[私聊][{private_name}]JSON数组解析失败,尝试解析单个JSON对象") - except Exception as e: - logger.debug(f"[私聊][{private_name}]尝试解析JSON数组时出错: {str(e)}") - - # 尝试解析JSON对象 - try: - json_data = json.loads(content) - except json.JSONDecodeError: - # 如果直接解析失败,尝试查找和提取JSON部分 - json_pattern = r"\{[^{}]*\}" - json_match = re.search(json_pattern, content) - if json_match: - try: - json_data = json.loads(json_match.group()) - except json.JSONDecodeError: - logger.error(f"[私聊][{private_name}]提取的JSON内容解析失败") - return False, result - else: - logger.error(f"[私聊][{private_name}]无法在返回内容中找到有效的JSON") - return False, result - - # 提取字段 - for item in items: - if item in json_data: - result[item] = json_data[item] - - # 验证必需字段 - if not all(item in result for item in items): - logger.error(f"[私聊][{private_name}]JSON缺少必要字段,实际内容: {json_data}") - return False, result - - # 验证字段类型 - if required_types: - for field, expected_type in required_types.items(): - if field in result and not isinstance(result[field], expected_type): - logger.error(f"[私聊][{private_name}]{field} 必须是 {expected_type.__name__} 类型") - return False, result - - # 验证字符串字段不为空 - for field in items: - if isinstance(result[field], str) and not result[field].strip(): - logger.error(f"[私聊][{private_name}]{field} 不能为空") - return False, result - - return True, result diff --git a/src/chat/brain_chat/PFC/reply_checker.py b/src/chat/brain_chat/PFC/reply_checker.py deleted file mode 100644 index c6304b30..00000000 --- a/src/chat/brain_chat/PFC/reply_checker.py +++ /dev/null @@ -1,198 +0,0 @@ -import json -import random -from typing import Tuple, List, Dict, Any -from src.common.logger import get_logger -from src.llm_models.utils_model import LLMRequest -from src.config.config import global_config, model_config -from .chat_observer import ChatObserver -from maim_message import UserInfo - -logger = get_logger("reply_checker") - - -class ReplyChecker: - """回复检查器""" - - def __init__(self, stream_id: str, private_name: str): - self.llm = LLMRequest(model_set=model_config.model_task_config.utils, request_type="reply_check") - self.personality_info = self._get_personality_prompt() - self.name = global_config.bot.nickname - self.private_name = private_name - self.chat_observer = ChatObserver.get_instance(stream_id, private_name) - self.max_retries = 3 # 最大重试次数 - - def _get_personality_prompt(self) -> str: - """获取个性提示信息""" - prompt_personality = global_config.personality.personality - - # 检查是否需要随机替换为状态 - if ( - global_config.personality.states - and global_config.personality.state_probability > 0 - and random.random() < global_config.personality.state_probability - ): - prompt_personality = random.choice(global_config.personality.states) - - bot_name = global_config.bot.nickname - return f"你的名字是{bot_name},你{prompt_personality};" - - async def check( - self, reply: str, goal: str, chat_history: List[Dict[str, Any]], chat_history_text: str, retry_count: int = 0 - ) -> Tuple[bool, str, bool]: - """检查生成的回复是否合适 - - Args: - reply: 生成的回复 - goal: 对话目标 - chat_history: 对话历史记录 - chat_history_text: 对话历史记录文本 - retry_count: 当前重试次数 - - Returns: - Tuple[bool, str, bool]: (是否合适, 原因, 是否需要重新规划) - """ - # 不再从 observer 获取,直接使用传入的 chat_history - # messages = self.chat_observer.get_cached_messages(limit=20) - try: - # 筛选出最近由 Bot 自己发送的消息 - bot_messages = [] - for msg in reversed(chat_history): - user_info = UserInfo.from_dict(msg.get("user_info", {})) - if str(user_info.user_id) == str(global_config.bot.qq_account): - bot_messages.append(msg.get("processed_plain_text", "")) - if len(bot_messages) >= 2: # 只和最近的两条比较 - break - # 进行比较 - if bot_messages: - # 可以用简单比较,或者更复杂的相似度库 (如 difflib) - # 简单比较:是否完全相同 - if reply == bot_messages[0]: # 和最近一条完全一样 - logger.warning( - f"[私聊][{self.private_name}]ReplyChecker 检测到回复与上一条 Bot 消息完全相同: '{reply}'" - ) - return ( - False, - "被逻辑检查拒绝:回复内容与你上一条发言完全相同,可以选择深入话题或寻找其它话题或等待", - True, - ) # 不合适,需要返回至决策层 - # 2. 相似度检查 (如果精确匹配未通过) - import difflib # 导入 difflib 库 - - # 计算编辑距离相似度,ratio() 返回 0 到 1 之间的浮点数 - similarity_ratio = difflib.SequenceMatcher(None, reply, bot_messages[0]).ratio() - logger.debug(f"[私聊][{self.private_name}]ReplyChecker - 相似度: {similarity_ratio:.2f}") - - # 设置一个相似度阈值 - similarity_threshold = 0.9 - if similarity_ratio > similarity_threshold: - logger.warning( - f"[私聊][{self.private_name}]ReplyChecker 检测到回复与上一条 Bot 消息高度相似 (相似度 {similarity_ratio:.2f}): '{reply}'" - ) - return ( - False, - f"被逻辑检查拒绝:回复内容与你上一条发言高度相似 (相似度 {similarity_ratio:.2f}),可以选择深入话题或寻找其它话题或等待。", - True, - ) - - except Exception as e: - import traceback - - logger.error(f"[私聊][{self.private_name}]检查回复时出错: 类型={type(e)}, 值={e}") - logger.error(f"[私聊][{self.private_name}]{traceback.format_exc()}") # 打印详细的回溯信息 - - prompt = f"""你是一个聊天逻辑检查器,请检查以下回复或消息是否合适: - -当前对话目标:{goal} -最新的对话记录: -{chat_history_text} - -待检查的消息: -{reply} - -请结合聊天记录检查以下几点: -1. 这条消息是否依然符合当前对话目标和实现方式 -2. 这条消息是否与最新的对话记录保持一致性 -3. 是否存在重复发言,或重复表达同质内容(尤其是只是换一种方式表达了相同的含义) -4. 这条消息是否包含违规内容(例如血腥暴力,政治敏感等) -5. 这条消息是否以发送者的角度发言(不要让发送者自己回复自己的消息) -6. 这条消息是否通俗易懂 -7. 这条消息是否有些多余,例如在对方没有回复的情况下,依然连续多次“消息轰炸”(尤其是已经连续发送3条信息的情况,这很可能不合理,需要着重判断) -8. 这条消息是否使用了完全没必要的修辞 -9. 这条消息是否逻辑通顺 -10. 这条消息是否太过冗长了(通常私聊的每条消息长度在20字以内,除非特殊情况) -11. 在连续多次发送消息的情况下,这条消息是否衔接自然,会不会显得奇怪(例如连续两条消息中部分内容重叠) - -请以JSON格式输出,包含以下字段: -1. suitable: 是否合适 (true/false) -2. reason: 原因说明 -3. need_replan: 是否需要重新决策 (true/false),当你认为此时已经不适合发消息,需要规划其它行动时,设为true - -输出格式示例: -{{ - "suitable": true, - "reason": "回复符合要求,虽然有可能略微偏离目标,但是整体内容流畅得体", - "need_replan": false -}} - -注意:请严格按照JSON格式输出,不要包含任何其他内容。""" - - try: - content, _ = await self.llm.generate_response_async(prompt) - logger.debug(f"[私聊][{self.private_name}]检查回复的原始返回: {content}") - - # 清理内容,尝试提取JSON部分 - content = content.strip() - try: - # 尝试直接解析 - result: dict = json.loads(content) - except json.JSONDecodeError: - # 如果直接解析失败,尝试查找和提取JSON部分 - import re - - json_pattern = r"\{[^{}]*\}" - json_match = re.search(json_pattern, content) - if json_match: - try: - result: dict = json.loads(json_match.group()) - except json.JSONDecodeError: - # 如果JSON解析失败,尝试从文本中提取结果 - is_suitable = "不合适" not in content.lower() and "违规" not in content.lower() - reason = content[:100] if content else "无法解析响应" - need_replan = "重新规划" in content.lower() or "目标不适合" in content.lower() - return is_suitable, reason, need_replan - else: - # 如果找不到JSON,从文本中判断 - is_suitable = "不合适" not in content.lower() and "违规" not in content.lower() - reason = content[:100] if content else "无法解析响应" - need_replan = "重新规划" in content.lower() or "目标不适合" in content.lower() - return is_suitable, reason, need_replan - - # 验证JSON字段 - suitable = result.get("suitable", None) - reason = result.get("reason", "未提供原因") - need_replan = result.get("need_replan", False) - - # 如果suitable字段是字符串,转换为布尔值 - if isinstance(suitable, str): - suitable = suitable.lower() == "true" - - # 如果suitable字段不存在或不是布尔值,从reason中判断 - if suitable is None: - suitable = "不合适" not in reason.lower() and "违规" not in reason.lower() - - # 如果不合适且未达到最大重试次数,返回需要重试 - if not suitable and retry_count < self.max_retries: - return False, reason, False - - # 如果不合适且已达到最大重试次数,返回需要重新规划 - if not suitable and retry_count >= self.max_retries: - return False, f"多次重试后仍不合适: {reason}", True - - return suitable, reason, need_replan - - except Exception as e: - logger.error(f"[私聊][{self.private_name}]检查回复时出错: {e}") - # 如果出错且已达到最大重试次数,建议重新规划 - if retry_count >= self.max_retries: - return False, "多次检查失败,建议重新规划", True - return False, f"检查过程出错,建议重试: {str(e)}", False diff --git a/src/chat/brain_chat/PFC/reply_generator.py b/src/chat/brain_chat/PFC/reply_generator.py deleted file mode 100644 index 95853e26..00000000 --- a/src/chat/brain_chat/PFC/reply_generator.py +++ /dev/null @@ -1,242 +0,0 @@ -from typing import Tuple, List, Dict, Any -from src.common.logger import get_logger -from src.llm_models.utils_model import LLMRequest -from src.config.config import global_config, model_config -import random -from .chat_observer import ChatObserver -from .reply_checker import ReplyChecker -from src.services.message_service import build_readable_messages - -from .observation_info import ObservationInfo, dict_to_session_message -from .conversation_info import ConversationInfo - -logger = get_logger("reply_generator") - -# --- 定义 Prompt 模板 --- - -# Prompt for direct_reply (首次回复) -PROMPT_DIRECT_REPLY = """{persona_text}。现在你在参与一场QQ私聊,请根据以下信息生成一条回复: - -当前对话目标:{goals_str} - -{knowledge_info_str} - -最近的聊天记录: -{chat_history_text} - - -请根据上述信息,结合聊天记录,回复对方。该回复应该: -1. 符合对话目标,以"你"的角度发言(不要自己与自己对话!) -2. 符合你的性格特征和身份细节 -3. 通俗易懂,自然流畅,像正常聊天一样,简短(通常20字以内,除非特殊情况) -4. 可以适当利用相关知识,但不要生硬引用 -5. 自然、得体,结合聊天记录逻辑合理,且没有重复表达同质内容 - -请注意把握聊天内容,不要回复的太有条理,可以有个性。请分清"你"和对方说的话,不要把"你"说的话当做对方说的话,这是你自己说的话。 -可以回复得自然随意自然一些,就像真人一样,注意把握聊天内容,整体风格可以平和、简短,不要刻意突出自身学科背景,不要说你说过的话,可以简短,多简短都可以,但是避免冗长。 -请你注意不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只输出回复内容。 -不要输出多余内容(包括前后缀,冒号和引号,括号,表情包,at或 @等 )。 - -请直接输出回复内容,不需要任何额外格式。""" - -# Prompt for send_new_message (追问/补充) -PROMPT_SEND_NEW_MESSAGE = """{persona_text}。现在你在参与一场QQ私聊,**刚刚你已经发送了一条或多条消息**,现在请根据以下信息再发一条新消息: - -当前对话目标:{goals_str} - -{knowledge_info_str} - -最近的聊天记录: -{chat_history_text} - - -请根据上述信息,结合聊天记录,继续发一条新消息(例如对之前消息的补充,深入话题,或追问等等)。该消息应该: -1. 符合对话目标,以"你"的角度发言(不要自己与自己对话!) -2. 符合你的性格特征和身份细节 -3. 通俗易懂,自然流畅,像正常聊天一样,简短(通常20字以内,除非特殊情况) -4. 可以适当利用相关知识,但不要生硬引用 -5. 跟之前你发的消息自然的衔接,逻辑合理,且没有重复表达同质内容或部分重叠内容 - -请注意把握聊天内容,不用太有条理,可以有个性。请分清"你"和对方说的话,不要把"你"说的话当做对方说的话,这是你自己说的话。 -这条消息可以自然随意自然一些,就像真人一样,注意把握聊天内容,整体风格可以平和、简短,不要刻意突出自身学科背景,不要说你说过的话,可以简短,多简短都可以,但是避免冗长。 -请你注意不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只输出消息内容。 -不要输出多余内容(包括前后缀,冒号和引号,括号,表情包,at或 @等 )。 - -请直接输出回复内容,不需要任何额外格式。""" - -# Prompt for say_goodbye (告别语生成) -PROMPT_FAREWELL = """{persona_text}。你在参与一场 QQ 私聊,现在对话似乎已经结束,你决定再发一条最后的消息来圆满结束。 - -最近的聊天记录: -{chat_history_text} - -请根据上述信息,结合聊天记录,构思一条**简短、自然、符合你人设**的最后的消息。 -这条消息应该: -1. 从你自己的角度发言。 -2. 符合你的性格特征和身份细节。 -3. 通俗易懂,自然流畅,通常很简短。 -4. 自然地为这场对话画上句号,避免开启新话题或显得冗长、刻意。 - -请像真人一样随意自然,**简洁是关键**。 -不要输出多余内容(包括前后缀、冒号、引号、括号、表情包、at或@等)。 - -请直接输出最终的告别消息内容,不需要任何额外格式。""" - - -class ReplyGenerator: - """回复生成器""" - - def __init__(self, stream_id: str, private_name: str): - self.llm = LLMRequest( - model_set=model_config.model_task_config.replyer, - request_type="reply_generation", - ) - self.personality_info = self._get_personality_prompt() - self.name = global_config.bot.nickname - self.private_name = private_name - self.chat_observer = ChatObserver.get_instance(stream_id, private_name) - self.reply_checker = ReplyChecker(stream_id, private_name) - - def _get_personality_prompt(self) -> str: - """获取个性提示信息""" - prompt_personality = global_config.personality.personality - - # 检查是否需要随机替换为状态 - if ( - global_config.personality.states - and global_config.personality.state_probability > 0 - and random.random() < global_config.personality.state_probability - ): - prompt_personality = random.choice(global_config.personality.states) - - bot_name = global_config.bot.nickname - return f"你的名字是{bot_name},你{prompt_personality};" - - # 修改 generate 方法签名,增加 action_type 参数 - async def generate( - self, observation_info: ObservationInfo, conversation_info: ConversationInfo, action_type: str - ) -> str: - """生成回复 - - Args: - observation_info: 观察信息 - conversation_info: 对话信息 - action_type: 当前执行的动作类型 ('direct_reply' 或 'send_new_message') - - Returns: - str: 生成的回复 - """ - # 构建提示词 - logger.debug( - f"[私聊][{self.private_name}]开始生成回复 (动作类型: {action_type}):当前目标: {conversation_info.goal_list}" - ) - - # --- 构建通用 Prompt 参数 --- - # (这部分逻辑基本不变) - - # 构建对话目标 (goals_str) - goals_str = "" - if conversation_info.goal_list: - for goal_reason in conversation_info.goal_list: - if isinstance(goal_reason, dict): - goal = goal_reason.get("goal", "目标内容缺失") - reasoning = goal_reason.get("reasoning", "没有明确原因") - else: - goal = str(goal_reason) - reasoning = "没有明确原因" - - goal = str(goal) if goal is not None else "目标内容缺失" - reasoning = str(reasoning) if reasoning is not None else "没有明确原因" - goals_str += f"- 目标:{goal}\n 原因:{reasoning}\n" - else: - goals_str = "- 目前没有明确对话目标\n" # 简化无目标情况 - - # --- 新增:构建知识信息字符串 --- - knowledge_info_str = "【供参考的相关知识和记忆】\n" # 稍微改下标题,表明是供参考 - try: - # 检查 conversation_info 是否有 knowledge_list 并且不为空 - if hasattr(conversation_info, "knowledge_list") and conversation_info.knowledge_list: - # 最多只显示最近的 5 条知识 - recent_knowledge = conversation_info.knowledge_list[-5:] - for i, knowledge_item in enumerate(recent_knowledge): - if isinstance(knowledge_item, dict): - query = knowledge_item.get("query", "未知查询") - knowledge = knowledge_item.get("knowledge", "无知识内容") - source = knowledge_item.get("source", "未知来源") - # 只取知识内容的前 2000 个字 - knowledge_snippet = f"{knowledge[:2000]}..." if len(knowledge) > 2000 else knowledge - knowledge_info_str += ( - f"{i + 1}. 关于 '{query}' (来源: {source}): {knowledge_snippet}\n" # 格式微调,更简洁 - ) - else: - knowledge_info_str += f"{i + 1}. 发现一条格式不正确的知识记录。\n" - - if not recent_knowledge: - knowledge_info_str += "- 暂无。\n" # 更简洁的提示 - - else: - knowledge_info_str += "- 暂无。\n" - except AttributeError: - logger.warning(f"[私聊][{self.private_name}]ConversationInfo 对象可能缺少 knowledge_list 属性。") - knowledge_info_str += "- 获取知识列表时出错。\n" - except Exception as e: - logger.error(f"[私聊][{self.private_name}]构建知识信息字符串时出错: {e}") - knowledge_info_str += "- 处理知识列表时出错。\n" - - # 获取聊天历史记录 (chat_history_text) - chat_history_text = observation_info.chat_history_str - if observation_info.new_messages_count > 0 and observation_info.unprocessed_messages: - new_messages_list = observation_info.unprocessed_messages - session_messages = [dict_to_session_message(m) for m in new_messages_list] - new_messages_str = build_readable_messages( - session_messages, - replace_bot_name=True, - timestamp_mode="relative", - read_mark=0.0, - ) - chat_history_text += f"\n--- 以下是 {observation_info.new_messages_count} 条新消息 ---\n{new_messages_str}" - elif not chat_history_text: - chat_history_text = "还没有聊天记录。" - - # 构建 Persona 文本 (persona_text) - persona_text = f"你的名字是{self.name},{self.personality_info}。" - - # --- 选择 Prompt --- - if action_type == "send_new_message": - prompt_template = PROMPT_SEND_NEW_MESSAGE - logger.info(f"[私聊][{self.private_name}]使用 PROMPT_SEND_NEW_MESSAGE (追问生成)") - elif action_type == "say_goodbye": # 处理告别动作 - prompt_template = PROMPT_FAREWELL - logger.info(f"[私聊][{self.private_name}]使用 PROMPT_FAREWELL (告别语生成)") - else: # 默认使用 direct_reply 的 prompt (包括 'direct_reply' 或其他未明确处理的类型) - prompt_template = PROMPT_DIRECT_REPLY - logger.info(f"[私聊][{self.private_name}]使用 PROMPT_DIRECT_REPLY (首次/非连续回复生成)") - - # --- 格式化最终的 Prompt --- - prompt = prompt_template.format( - persona_text=persona_text, - goals_str=goals_str, - chat_history_text=chat_history_text, - knowledge_info_str=knowledge_info_str, - ) - - # --- 调用 LLM 生成 --- - logger.debug(f"[私聊][{self.private_name}]发送到LLM的生成提示词:\n------\n{prompt}\n------") - try: - content, _ = await self.llm.generate_response_async(prompt) - logger.debug(f"[私聊][{self.private_name}]生成的回复: {content}") - # 移除旧的检查新消息逻辑,这应该由 conversation 控制流处理 - return content - - except Exception as e: - logger.error(f"[私聊][{self.private_name}]生成回复时出错: {e}") - return "抱歉,我现在有点混乱,让我重新思考一下..." - - # check_reply 方法保持不变 - async def check_reply( - self, reply: str, goal: str, chat_history: List[Dict[str, Any]], chat_history_str: str, retry_count: int = 0 - ) -> Tuple[bool, str, bool]: - """检查回复是否合适 - (此方法逻辑保持不变) - """ - return await self.reply_checker.check(reply, goal, chat_history, chat_history_str, retry_count) diff --git a/src/chat/brain_chat/PFC/waiter.py b/src/chat/brain_chat/PFC/waiter.py deleted file mode 100644 index b93b84d9..00000000 --- a/src/chat/brain_chat/PFC/waiter.py +++ /dev/null @@ -1,79 +0,0 @@ -from src.common.logger import get_logger -from .chat_observer import ChatObserver -from .conversation_info import ConversationInfo - -# from src.individuality.individuality import Individuality # 不再需要 -from src.config.config import global_config -import time -import asyncio - -logger = get_logger("waiter") - -# --- 在这里设定你想要的超时时间(秒) --- -# 例如: 120 秒 = 2 分钟 -DESIRED_TIMEOUT_SECONDS = 300 - - -class Waiter: - """等待处理类""" - - def __init__(self, stream_id: str, private_name: str): - self.chat_observer = ChatObserver.get_instance(stream_id, private_name) - self.name = global_config.bot.nickname - self.private_name = private_name - # self.wait_accumulated_time = 0 # 不再需要累加计时 - - async def wait(self, conversation_info: ConversationInfo) -> bool: - """等待用户新消息或超时""" - wait_start_time = time.time() - logger.info(f"[私聊][{self.private_name}]进入常规等待状态 (超时: {DESIRED_TIMEOUT_SECONDS} 秒)...") - - while True: - # 检查是否有新消息 - if self.chat_observer.new_message_after(wait_start_time): - logger.info(f"[私聊][{self.private_name}]等待结束,收到新消息") - return False # 返回 False 表示不是超时 - - # 检查是否超时 - elapsed_time = time.time() - wait_start_time - if elapsed_time > DESIRED_TIMEOUT_SECONDS: - logger.info(f"[私聊][{self.private_name}]等待超过 {DESIRED_TIMEOUT_SECONDS} 秒...添加思考目标。") - wait_goal = { - "goal": f"你等待了{elapsed_time / 60:.1f}分钟,注意可能在对方看来聊天已经结束,思考接下来要做什么", - "reasoning": "对方很久没有回复你的消息了", - } - conversation_info.goal_list.append(wait_goal) - logger.info(f"[私聊][{self.private_name}]添加目标: {wait_goal}") - return True # 返回 True 表示超时 - - await asyncio.sleep(5) # 每 5 秒检查一次 - logger.debug( - f"[私聊][{self.private_name}]等待中..." - ) # 可以考虑把这个频繁日志注释掉,只在超时或收到消息时输出 - - async def wait_listening(self, conversation_info: ConversationInfo) -> bool: - """倾听用户发言或超时""" - wait_start_time = time.time() - logger.info(f"[私聊][{self.private_name}]进入倾听等待状态 (超时: {DESIRED_TIMEOUT_SECONDS} 秒)...") - - while True: - # 检查是否有新消息 - if self.chat_observer.new_message_after(wait_start_time): - logger.info(f"[私聊][{self.private_name}]倾听等待结束,收到新消息") - return False # 返回 False 表示不是超时 - - # 检查是否超时 - elapsed_time = time.time() - wait_start_time - if elapsed_time > DESIRED_TIMEOUT_SECONDS: - logger.info(f"[私聊][{self.private_name}]倾听等待超过 {DESIRED_TIMEOUT_SECONDS} 秒...添加思考目标。") - wait_goal = { - # 保持 goal 文本一致 - "goal": f"你等待了{elapsed_time / 60:.1f}分钟,对方似乎话说一半突然消失了,可能忙去了?也可能忘记了回复?要问问吗?还是结束对话?或继续等待?思考接下来要做什么", - "reasoning": "对方话说一半消失了,很久没有回复", - } - conversation_info.goal_list.append(wait_goal) - logger.info(f"[私聊][{self.private_name}]添加目标: {wait_goal}") - return True # 返回 True 表示超时 - - await asyncio.sleep(5) # 每 5 秒检查一次 - logger.debug(f"[私聊][{self.private_name}]倾听等待中...") # 同上,可以考虑注释掉 diff --git a/src/chat/brain_chat/brain_chat.py b/src/chat/brain_chat/brain_chat.py deleted file mode 100644 index 1e9e648a..00000000 --- a/src/chat/brain_chat/brain_chat.py +++ /dev/null @@ -1,800 +0,0 @@ -import asyncio -import random -import time -import traceback -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple -from rich.traceback import install - -from src.config.config import global_config -from src.common.logger import get_logger -from src.common.utils.utils_config import ExpressionConfigUtils -from src.learners.expression_learner import ExpressionLearner -from src.learners.jargon_miner import JargonMiner -from src.chat.message_receive.chat_manager import BotChatSession -from src.chat.message_receive.chat_manager import chat_manager as _chat_manager -from src.chat.message_receive.message import SessionMessage -from src.common.data_models.info_data_model import ActionPlannerInfo -from src.common.data_models.message_component_data_model import MessageSequence, TextComponent -from src.chat.utils.prompt_builder import global_prompt_manager -from src.chat.utils.timer_calculator import Timer -from src.chat.brain_chat.brain_planner import BrainPlanner -from src.chat.planner_actions.action_modifier import ActionModifier -from src.chat.planner_actions.action_manager import ActionManager -from src.chat.heart_flow.heartFC_utils import CycleDetail -from src.person_info.person_info import Person -from src.core.types import ActionInfo, EventType -from src.core.event_bus import event_bus -from src.chat.event_helpers import build_event_message -from src.services import ( - generator_service as generator_api, - send_service as send_api, - message_service as message_api, - database_service as database_api, -) -from src.services.message_service import build_readable_messages_with_id, get_messages_before_time_in_chat - -if TYPE_CHECKING: - from src.chat.message_receive.message import SessionMessage - - -ERROR_LOOP_INFO = { - "loop_plan_info": { - "action_result": { - "action_type": "error", - "action_data": {}, - "reasoning": "循环处理失败", - }, - }, - "loop_action_info": { - "action_taken": False, - "reply_text": "", - "command": "", - "taken_time": time.time(), - }, -} - - -install(extra_lines=3) - -# 注释:原来的动作修改超时常量已移除,因为改为顺序执行 - -logger = get_logger("bc") # Logger Name Changed - - -class BrainChatting: - """ - 管理一个连续的私聊Brain Chat循环 - 用于在特定聊天流中生成回复。 - """ - - def __init__(self, session_id: str): - """ - BrainChatting 初始化函数 - - 参数: - chat_id: 聊天流唯一标识符(如stream_id) - on_stop_focus_chat: 当收到stop_focus_chat命令时调用的回调函数 - performance_version: 性能记录版本号,用于区分不同启动版本 - """ - # 基础属性 - self.stream_id: str = session_id # 聊天流ID - self.chat_stream: BotChatSession = _chat_manager.get_session_by_session_id(self.stream_id) # type: ignore[assignment] - if not self.chat_stream: - raise ValueError(f"无法找到聊天流: {self.stream_id}") - self.log_prefix = f"[{_chat_manager.get_session_name(self.stream_id) or self.stream_id}]" - - expr_use, jargon_learn, expr_learn = ExpressionConfigUtils.get_expression_config_for_chat(self.stream_id) - self._enable_expression_use = expr_use - self._enable_expression_learning = expr_learn - self._enable_jargon_learning = jargon_learn - self._expression_learner = ExpressionLearner(self.stream_id) - self._jargon_miner = JargonMiner(self.stream_id, session_name=self.log_prefix.strip("[]")) - self._min_messages_for_extraction = 30 - self._min_extraction_interval = 60 - self._last_extraction_time = 0.0 - - self.action_manager = ActionManager() - self.action_planner = BrainPlanner(chat_id=self.stream_id, action_manager=self.action_manager) - self.action_modifier = ActionModifier(action_manager=self.action_manager, chat_id=self.stream_id) - - # 循环控制内部状态 - self.running: bool = False - self._loop_task: Optional[asyncio.Task] = None # 主循环任务 - self._new_message_event = asyncio.Event() # 新消息事件,用于打断 wait - - # 添加循环信息管理相关的属性 - self.history_loop: List[CycleDetail] = [] - self._cycle_counter = 0 - self._current_cycle_detail: CycleDetail = None # type: ignore - - self.last_read_time = time.time() - 2 - - self.more_plan = False - - # 最近一次是否成功进行了 reply,用于选择 BrainPlanner 的 Prompt - self._last_successful_reply: bool = False - - async def start(self): - """检查是否需要启动主循环,如果未激活则启动。""" - - # 如果循环已经激活,直接返回 - if self.running: - logger.debug(f"{self.log_prefix} BrainChatting 已激活,无需重复启动") - return - - try: - # 标记为活动状态,防止重复启动 - self.running = True - - self._loop_task = asyncio.create_task(self._main_chat_loop()) - self._loop_task.add_done_callback(self._handle_loop_completion) - logger.info(f"{self.log_prefix} BrainChatting 启动完成") - - except Exception as e: - # 启动失败时重置状态 - self.running = False - self._loop_task = None - logger.error(f"{self.log_prefix} BrainChatting 启动失败: {e}") - raise - - def _handle_loop_completion(self, task: asyncio.Task): - """当 _hfc_loop 任务完成时执行的回调。""" - try: - if exception := task.exception(): - logger.error(f"{self.log_prefix} BrainChatting: 脱离了聊天(异常): {exception}") - logger.error(traceback.format_exc()) # Log full traceback for exceptions - else: - logger.info(f"{self.log_prefix} BrainChatting: 脱离了聊天 (外部停止)") - except asyncio.CancelledError: - logger.info(f"{self.log_prefix} BrainChatting: 结束了聊天") - - def start_cycle(self) -> Tuple[Dict[str, float], str]: - self._cycle_counter += 1 - self._current_cycle_detail = CycleDetail(self._cycle_counter) - self._current_cycle_detail.thinking_id = f"tid{str(round(time.time(), 2))}" - cycle_timers = {} - return cycle_timers, self._current_cycle_detail.thinking_id - - def end_cycle(self, loop_info, cycle_timers): - self._current_cycle_detail.set_loop_info(loop_info) - self.history_loop.append(self._current_cycle_detail) - self._current_cycle_detail.timers = cycle_timers - self._current_cycle_detail.end_time = time.time() - - def print_cycle_info(self, cycle_timers): - # 记录循环信息和计时器结果 - timer_strings = [] - for name, elapsed in cycle_timers.items(): - formatted_time = f"{elapsed * 1000:.2f}毫秒" if elapsed < 1 else f"{elapsed:.2f}秒" - timer_strings.append(f"{name}: {formatted_time}") - - logger.info( - f"{self.log_prefix} 第{self._current_cycle_detail.cycle_id}次思考," - f"耗时: {self._current_cycle_detail.end_time - self._current_cycle_detail.start_time:.1f}秒" # type: ignore - + (f"\n详情: {'; '.join(timer_strings)}" if timer_strings else "") - ) - - async def _trigger_expression_learning(self, messages: List[SessionMessage]) -> None: - if not messages: - return - - self._expression_learner.add_messages(messages) - if time.time() - self._last_extraction_time < self._min_extraction_interval: - return - if self._expression_learner.get_cache_size() < self._min_messages_for_extraction: - return - if not self._enable_expression_learning: - return - - self._last_extraction_time = time.time() - try: - jargon_miner = self._jargon_miner if self._enable_jargon_learning else None - await self._expression_learner.learn(jargon_miner) - except Exception as exc: - logger.error(f"{self.log_prefix} 表达学习失败: {exc}", exc_info=True) - - async def _loopbody(self): # sourcery skip: hoist-if-from-if - # 获取最新消息(用于上下文,但不影响是否调用 observe) - recent_messages_list = message_api.get_messages_by_time_in_chat( - chat_id=self.stream_id, - start_time=self.last_read_time, - end_time=time.time(), - limit=20, - limit_mode="latest", - filter_mai=True, - filter_command=False, - filter_intercept_message_level=1, - ) - - # 如果有新消息,更新 last_read_time 并触发事件以打断正在进行的 wait - if len(recent_messages_list) >= 1: - self.last_read_time = time.time() - self._new_message_event.set() # 触发新消息事件,打断 wait - - # 总是执行一次思考迭代(不管有没有新消息) - # wait 动作会在其内部等待,不需要在这里处理 - should_continue = await self._observe(recent_messages_list=recent_messages_list) - - if not should_continue: - # 选择了 complete_talk,返回 False 表示需要等待新消息 - return False - - # 继续下一次迭代(除非选择了 complete_talk) - # 短暂等待后再继续,避免过于频繁的循环 - await asyncio.sleep(0.1) - - return True - - async def _send_and_store_reply( - self, - response_set: MessageSequence, - action_message: SessionMessage, - cycle_timers: Dict[str, float], - thinking_id, - actions, - selected_expressions: Optional[List[int]] = None, - ) -> Tuple[Dict[str, Any], str, Dict[str, float]]: - with Timer("回复发送", cycle_timers): - reply_text = await self._send_response( - reply_set=response_set, - message_data=action_message, - selected_expressions=selected_expressions, - ) - - # 获取 platform,如果不存在则从 chat_stream 获取,如果还是 None 则使用默认值 - platform = action_message.platform - if platform is None: - platform = getattr(self.chat_stream, "platform", "unknown") - - person = Person(platform=platform, user_id=action_message.message_info.user_info.user_id) - person_name = person.person_name - action_prompt_display = f"你对{person_name}进行了回复:{reply_text}" - - await database_api.store_action_info( - chat_stream=self.chat_stream, - display_prompt=action_prompt_display, - thinking_id=thinking_id, - action_data={"reply_text": reply_text}, - action_name="reply", - ) - - # 构建循环信息 - loop_info: Dict[str, Any] = { - "loop_plan_info": { - "action_result": actions, - }, - "loop_action_info": { - "action_taken": True, - "reply_text": reply_text, - "command": "", - "taken_time": time.time(), - }, - } - - return loop_info, reply_text, cycle_timers - - async def _observe( - self, # interest_value: float = 0.0, - recent_messages_list: Optional[List[SessionMessage]] = None, - ) -> bool: # sourcery skip: merge-else-if-into-elif, remove-redundant-if - if recent_messages_list is None: - recent_messages_list = [] - _reply_text = "" # 初始化reply_text变量,避免UnboundLocalError - - async with global_prompt_manager.async_message_scope(self.chat_stream.context.get_template_name()): - if recent_messages_list: - asyncio.create_task(self._trigger_expression_learning(recent_messages_list)) - - cycle_timers, thinking_id = self.start_cycle() - logger.info(f"{self.log_prefix} 开始第{self._cycle_counter}次思考") - - # 第一步:动作检查 - available_actions: Dict[str, ActionInfo] = {} - try: - await self.action_modifier.modify_actions() - available_actions = self.action_manager.get_using_actions() - except Exception as e: - logger.error(f"{self.log_prefix} 动作修改失败: {e}") - - # 获取必要信息 - is_group_chat, chat_target_info, _ = self.action_planner.get_necessary_info() - - # 一次思考迭代:Think - Act - Observe - # 获取聊天上下文 - message_list_before_now = get_messages_before_time_in_chat( - chat_id=self.stream_id, - timestamp=time.time(), - limit=int(global_config.chat.max_context_size * 0.6), - filter_intercept_message_level=1, - ) - chat_content_block, message_id_list = build_readable_messages_with_id( - messages=message_list_before_now, - timestamp_mode="normal_no_YMD", - read_mark=self.action_planner.last_obs_time_mark, - truncate=True, - show_actions=True, - ) - - prompt_info = await self.action_planner.build_planner_prompt( - chat_target_info=chat_target_info, - current_available_actions=available_actions, - chat_content_block=chat_content_block, - message_id_list=message_id_list, - prompt_key="brain_planner", - ) - _event_msg = build_event_message( - EventType.ON_PLAN, llm_prompt=prompt_info[0], stream_id=self.chat_stream.session_id - ) - continue_flag, modified_message = await event_bus.emit(EventType.ON_PLAN, _event_msg) - if not continue_flag: - return False - if modified_message and modified_message._modify_flags.modify_llm_prompt: - prompt_info = (modified_message.llm_prompt, prompt_info[1]) - - with Timer("规划器", cycle_timers): - action_to_use_info = await self.action_planner.plan( - loop_start_time=self.last_read_time, - available_actions=available_actions, - ) - - # 检查是否有 complete_talk 动作(会停止后续迭代) - has_complete_talk = any(action.action_type == "complete_talk" for action in action_to_use_info) - - # 并行执行所有动作 - action_tasks = [ - asyncio.create_task( - self._execute_action(action, action_to_use_info, thinking_id, available_actions, cycle_timers) - ) - for action in action_to_use_info - ] - - # 并行执行所有任务 - results = await asyncio.gather(*action_tasks, return_exceptions=True) - - # 处理执行结果 - reply_loop_info = None - reply_text_from_reply = "" - action_success = False - action_reply_text = "" - - for result in results: - if isinstance(result, BaseException): - logger.error(f"{self.log_prefix} 动作执行异常: {result}") - continue - - if result["action_type"] != "reply": - action_success = result["success"] - action_reply_text = result["reply_text"] - elif result["action_type"] == "reply": - if result["success"]: - reply_loop_info = result["loop_info"] - reply_text_from_reply = result["reply_text"] - else: - logger.warning(f"{self.log_prefix} 回复动作执行失败") - - # 更新观察时间标记 - self.action_planner.last_obs_time_mark = time.time() - - # 如果选择了 complete_talk,标记为完成,不再继续迭代 - if has_complete_talk: - logger.info(f"{self.log_prefix} 检测到 complete_talk 动作,本次思考完成") - - # 构建循环信息 - if reply_loop_info: - # 如果有回复信息,使用回复的loop_info作为基础 - loop_info = reply_loop_info - # 更新动作执行信息 - loop_info["loop_action_info"].update( - { - "action_taken": action_success, - "taken_time": time.time(), - } - ) - _reply_text = reply_text_from_reply - else: - # 没有回复信息,构建纯动作的loop_info - loop_info = { - "loop_plan_info": { - "action_result": action_to_use_info, - }, - "loop_action_info": { - "action_taken": action_success, - "reply_text": action_reply_text, - "taken_time": time.time(), - }, - } - _reply_text = action_reply_text - - # 如果选择了 complete_talk,返回 False 以停止 _loopbody 的循环 - # 否则返回 True,让 _loopbody 继续下一次迭代 - should_continue = not has_complete_talk - - self.end_cycle(loop_info, cycle_timers) - self.print_cycle_info(cycle_timers) - - # 如果选择了 complete_talk,返回 False 停止循环 - # 否则返回 True,继续下一次思考迭代 - return should_continue - - async def _main_chat_loop(self): - """主循环,持续进行计划并可能回复消息,直到被外部取消。""" - try: - while self.running: - # 主循环 - success = await self._loopbody() - if not success: - # 选择了 complete,等待新消息 - logger.info(f"{self.log_prefix} 选择了 complete,等待新消息...") - await self._wait_for_new_message() - # 有新消息后继续循环 - continue - await asyncio.sleep(0.1) - except asyncio.CancelledError: - # 设置了关闭标志位后被取消是正常流程 - logger.info(f"{self.log_prefix} 麦麦已关闭聊天") - except Exception: - logger.error(f"{self.log_prefix} 麦麦聊天意外错误,将于3s后尝试重新启动") - print(traceback.format_exc()) - await asyncio.sleep(3) - self._loop_task = asyncio.create_task(self._main_chat_loop()) - logger.error(f"{self.log_prefix} 结束了当前聊天循环") - - async def _wait_for_new_message(self): - """等待新消息到达""" - last_check_time = self.last_read_time - check_interval = 1.0 # 每秒检查一次 - - # 清除事件状态,准备等待新消息 - self._new_message_event.clear() - - while self.running: - # 检查是否有新消息 - recent_messages_list = message_api.get_messages_by_time_in_chat( - chat_id=self.stream_id, - start_time=last_check_time, - end_time=time.time(), - limit=20, - limit_mode="latest", - filter_mai=True, - filter_command=False, - filter_intercept_message_level=1, - ) - - # 如果有新消息,更新 last_read_time 并返回 - if len(recent_messages_list) >= 1: - self.last_read_time = time.time() - logger.info(f"{self.log_prefix} 检测到新消息,恢复循环") - return - - # 等待新消息事件或超时后再次检查 - try: - await asyncio.wait_for(self._new_message_event.wait(), timeout=check_interval) - # 事件被触发,说明有新消息 - logger.info(f"{self.log_prefix} 检测到新消息事件,恢复循环") - return - except asyncio.TimeoutError: - # 超时后继续检查 - continue - - async def _handle_action( - self, - action: str, - reasoning: str, - action_data: dict, - cycle_timers: Dict[str, float], - thinking_id: str, - action_message: Optional[SessionMessage] = None, - ) -> tuple[bool, str, str]: - """ - 处理规划动作,使用动作工厂创建相应的动作处理器 - - 参数: - action: 动作类型 - reasoning: 决策理由 - action_data: 动作数据,包含不同动作需要的参数 - cycle_timers: 计时器字典 - thinking_id: 思考ID - - 返回: - tuple[bool, str, str]: (是否执行了动作, 思考消息ID, 命令) - """ - try: - # 使用工厂创建动作处理器实例 - try: - action_handler = self.action_manager.create_action( - action_name=action, - action_data=action_data, - action_reasoning=reasoning, - cycle_timers=cycle_timers, - thinking_id=thinking_id, - chat_stream=self.chat_stream, - log_prefix=self.log_prefix, - action_message=action_message, - ) - except Exception as e: - logger.error(f"{self.log_prefix} 创建动作处理器时出错: {e}") - traceback.print_exc() - return False, "", "" - - if not action_handler: - logger.warning(f"{self.log_prefix} 未能创建动作处理器: {action}") - return False, "", "" - - # 处理动作并获取结果(固定记录一次动作信息) - # BaseAction 定义了异步方法 execute() 作为统一执行入口 - # 这里调用 execute() 以兼容所有 Action 实现 - result = await action_handler.execute() - success, action_text = result - command = "" - - return success, action_text, command - - except Exception as e: - logger.error(f"{self.log_prefix} 处理{action}时出错: {e}") - traceback.print_exc() - return False, "", "" - - async def _send_response( - self, - reply_set: MessageSequence, - message_data: SessionMessage, - selected_expressions: Optional[List[int]] = None, - ) -> str: - new_message_count = message_api.count_new_messages( - chat_id=self.chat_stream.session_id, start_time=self.last_read_time, end_time=time.time() - ) - - need_reply = new_message_count >= random.randint(2, 4) - - if need_reply: - logger.info(f"{self.log_prefix} 从思考到回复,共有{new_message_count}条新消息,使用引用回复") - - reply_text = "" - first_replied = False - for component in reply_set.components: - if not isinstance(component, TextComponent): - continue - data = component.text - if not first_replied: - await send_api.text_to_stream( - text=data, - stream_id=self.chat_stream.session_id, - reply_message=message_data, - set_reply=need_reply, - typing=False, - selected_expressions=selected_expressions, - ) - first_replied = True - else: - await send_api.text_to_stream( - text=data, - stream_id=self.chat_stream.session_id, - reply_message=message_data, - set_reply=False, - typing=True, - selected_expressions=selected_expressions, - ) - reply_text += data - - return reply_text - - async def _execute_action( - self, - action_planner_info: ActionPlannerInfo, - chosen_action_plan_infos: List[ActionPlannerInfo], - thinking_id: str, - available_actions: Dict[str, ActionInfo], - cycle_timers: Dict[str, float], - ): - """执行单个动作的通用函数""" - try: - with Timer(f"动作{action_planner_info.action_type}", cycle_timers): - if action_planner_info.action_type == "complete_talk": - # 直接处理complete_talk逻辑,不再通过动作系统 - reason = action_planner_info.reasoning or "选择完成对话" - logger.info(f"{self.log_prefix} 选择完成对话,原因: {reason}") - - # 存储complete_talk信息到数据库 - await database_api.store_action_info( - chat_stream=self.chat_stream, - display_prompt=reason, - thinking_id=thinking_id, - action_data={"reason": reason}, - action_name="complete_talk", - ) - return {"action_type": "complete_talk", "success": True, "reply_text": "", "command": ""} - - elif action_planner_info.action_type == "reply": - try: - # 从 Planner 的 action_data 中提取未知词语列表(仅在 reply 时使用) - unknown_words = None - if isinstance(action_planner_info.action_data, dict): - uw = action_planner_info.action_data.get("unknown_words") - if isinstance(uw, list): - cleaned_uw: List[str] = [] - for item in uw: - if isinstance(item, str): - if stripped_item := item.strip(): - cleaned_uw.append(stripped_item) - if cleaned_uw: - unknown_words = cleaned_uw - - success, llm_response = await generator_api.generate_reply( - chat_stream=self.chat_stream, - reply_message=action_planner_info.action_message, - available_actions=available_actions, - chosen_actions=chosen_action_plan_infos, - reply_reason=action_planner_info.reasoning or "", - unknown_words=unknown_words, - enable_tool=global_config.tool.enable_tool, - request_type="replyer", - from_plugin=False, - ) - - if not success or not llm_response or not llm_response.reply_set: - if action_planner_info.action_message: - logger.info( - f"对 {action_planner_info.action_message.processed_plain_text} 的回复生成失败" - ) - else: - logger.info("回复生成失败") - return { - "action_type": "reply", - "success": False, - "reply_text": "", - "loop_info": None, - } - - except asyncio.CancelledError: - logger.debug(f"{self.log_prefix} 并行执行:回复生成任务已被取消") - return {"action_type": "reply", "success": False, "reply_text": "", "loop_info": None} - - response_set = llm_response.reply_set - selected_expressions = llm_response.selected_expressions - loop_info, reply_text, _ = await self._send_and_store_reply( - response_set=response_set, - action_message=action_planner_info.action_message, # type: ignore - cycle_timers=cycle_timers, - thinking_id=thinking_id, - actions=chosen_action_plan_infos, - selected_expressions=selected_expressions, - ) - # 标记这次循环已经成功进行了回复 - self._last_successful_reply = True - return { - "action_type": "reply", - "success": True, - "reply_text": reply_text, - "loop_info": loop_info, - } - - # 其他动作 - else: - # 内建 wait / listening:不通过插件系统,直接在这里处理 - if action_planner_info.action_type in ["wait", "listening"]: - reason = action_planner_info.reasoning or "" - action_data = action_planner_info.action_data or {} - - if action_planner_info.action_type == "wait": - # 获取等待时间(必填) - wait_seconds = action_data.get("wait_seconds") - if wait_seconds is None: - logger.warning(f"{self.log_prefix} wait 动作缺少 wait_seconds 参数,使用默认值 5 秒") - wait_seconds = 5 - else: - try: - wait_seconds = float(wait_seconds) - if wait_seconds < 0: - logger.warning(f"{self.log_prefix} wait_seconds 不能为负数,使用默认值 5 秒") - wait_seconds = 5 - except (ValueError, TypeError): - logger.warning(f"{self.log_prefix} wait_seconds 参数格式错误,使用默认值 5 秒") - wait_seconds = 5 - - logger.info(f"{self.log_prefix} 执行 wait 动作,等待 {wait_seconds} 秒(可被新消息打断)") - - # 清除事件状态,准备等待新消息 - self._new_message_event.clear() - - # 记录动作信息 - await database_api.store_action_info( - chat_stream=self.chat_stream, - display_prompt=reason or f"等待 {wait_seconds} 秒", - thinking_id=thinking_id, - action_data={"reason": reason, "wait_seconds": wait_seconds}, - action_name="wait", - ) - - # 等待指定时间,但可被新消息打断 - try: - await asyncio.wait_for(self._new_message_event.wait(), timeout=wait_seconds) - # 如果事件被触发,说明有新消息到达 - logger.info(f"{self.log_prefix} wait 动作被新消息打断,提前结束等待") - except asyncio.TimeoutError: - # 超时正常完成 - pass - - logger.info(f"{self.log_prefix} wait 动作完成,继续下一次思考") - - # 这些动作本身不产生文本回复 - self._last_successful_reply = False - return { - "action_type": "wait", - "success": True, - "reply_text": "", - "command": "", - } - - # listening 已合并到 wait,如果遇到则转换为 wait(向后兼容) - elif action_planner_info.action_type == "listening": - logger.debug(f"{self.log_prefix} 检测到 listening 动作,已合并到 wait,自动转换") - # 使用默认等待时间 - wait_seconds = 3 - - logger.info( - f"{self.log_prefix} 执行 listening(转换为 wait)动作,等待 {wait_seconds} 秒(可被新消息打断)" - ) - - # 清除事件状态,准备等待新消息 - self._new_message_event.clear() - - # 记录动作信息 - await database_api.store_action_info( - chat_stream=self.chat_stream, - display_prompt=reason or f"倾听并等待 {wait_seconds} 秒", - thinking_id=thinking_id, - action_data={"reason": reason, "wait_seconds": wait_seconds}, - action_name="listening", - ) - - # 等待指定时间,但可被新消息打断 - try: - await asyncio.wait_for(self._new_message_event.wait(), timeout=wait_seconds) - # 如果事件被触发,说明有新消息到达 - logger.info(f"{self.log_prefix} listening 动作被新消息打断,提前结束等待") - except asyncio.TimeoutError: - # 超时正常完成 - pass - - logger.info(f"{self.log_prefix} listening 动作完成,继续下一次思考") - - # 这些动作本身不产生文本回复 - self._last_successful_reply = False - return { - "action_type": "listening", - "success": True, - "reply_text": "", - "command": "", - } - - # 其余动作:走原有插件 Action 体系 - with Timer("动作执行", cycle_timers): - success, reply_text, command = await self._handle_action( - action_planner_info.action_type, - action_planner_info.reasoning or "", - action_planner_info.action_data or {}, - cycle_timers, - thinking_id, - action_planner_info.action_message, - ) - # 非 reply 类动作执行成功时,清空最近成功回复标记,让下一轮回到 initial Prompt - if success and action_planner_info.action_type != "reply": - self._last_successful_reply = False - - return { - "action_type": action_planner_info.action_type, - "success": success, - "reply_text": reply_text, - "command": command, - } - - except Exception as e: - logger.error(f"{self.log_prefix} 执行动作时出错: {e}") - logger.error(f"{self.log_prefix} 错误信息: {traceback.format_exc()}") - return { - "action_type": action_planner_info.action_type, - "success": False, - "reply_text": "", - "loop_info": None, - "error": str(e), - } diff --git a/src/chat/brain_chat/brain_planner.py b/src/chat/brain_chat/brain_planner.py deleted file mode 100644 index 709be8ee..00000000 --- a/src/chat/brain_chat/brain_planner.py +++ /dev/null @@ -1,620 +0,0 @@ -from datetime import datetime -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple - -import json -import random -import re -import time -import traceback - -from json_repair import repair_json -from rich.traceback import install - -from src.chat.logger.plan_reply_logger import PlanReplyLogger -from src.chat.message_receive.chat_manager import chat_manager as _chat_manager -from src.chat.planner_actions.action_manager import ActionManager -from src.chat.utils.utils import get_chat_type_and_target_info -from src.common.data_models.info_data_model import ActionPlannerInfo -from src.common.logger import get_logger -from src.common.utils.utils_action import ActionUtils -from src.config.config import global_config, model_config -from src.core.types import ActionActivationType, ActionInfo, ComponentType -from src.llm_models.utils_model import LLMRequest -from src.plugin_runtime.component_query import component_query_service -from src.prompt.prompt_manager import prompt_manager -from src.services.message_service import ( - build_readable_messages_with_id, - get_actions_by_timestamp_with_chat, - get_messages_before_time_in_chat, -) - -if TYPE_CHECKING: - from src.common.data_models.info_data_model import TargetPersonInfo - from src.chat.message_receive.message import SessionMessage - -logger = get_logger("planner") - -install(extra_lines=3) - - -class BrainPlanner: - def __init__(self, chat_id: str, action_manager: ActionManager): - self.chat_id = chat_id - self.log_prefix = f"[{_chat_manager.get_session_name(chat_id) or chat_id}]" - self.action_manager = action_manager - # LLM规划器配置 - self.planner_llm = LLMRequest( - model_set=model_config.model_task_config.planner, request_type="planner" - ) # 用于动作规划 - - self.last_obs_time_mark = 0.0 - - # 计划日志记录 - self.plan_log: List[Tuple[str, float, List[ActionPlannerInfo]]] = [] - - def find_message_by_id( - self, message_id: str, message_id_list: List[Tuple[str, "SessionMessage"]] - ) -> Optional["SessionMessage"]: - # sourcery skip: use-next - """ - 根据message_id从message_id_list中查找对应的原始消息 - - Args: - message_id: 要查找的消息ID - message_id_list: 消息ID列表,格式为[{'id': str, 'message': dict}, ...] - - Returns: - 找到的原始消息字典,如果未找到则返回None - """ - for item in message_id_list: - if item[0] == message_id: - return item[1] - return None - - def _parse_single_action( - self, - action_json: dict, - message_id_list: List[Tuple[str, "SessionMessage"]], - current_available_actions: List[Tuple[str, ActionInfo]], - ) -> List[ActionPlannerInfo]: - """解析单个action JSON并返回ActionPlannerInfo列表""" - action_planner_infos = [] - - try: - action = action_json.get("action", "complete_talk") - logger.debug(f"{self.log_prefix}解析动作JSON: action={action}, json={action_json}") - reasoning = action_json.get("reason", "未提供原因") - action_data = {key: value for key, value in action_json.items() if key not in ["action", "reason"]} - # 非complete_talk动作需要target_message_id - target_message = None - - if target_message_id := action_json.get("target_message_id"): - # 根据target_message_id查找原始消息 - target_message = self.find_message_by_id(target_message_id, message_id_list) - if target_message is None: - logger.warning(f"{self.log_prefix}无法找到target_message_id '{target_message_id}' 对应的消息") - # 选择最新消息作为target_message - target_message = message_id_list[-1][1] - else: - target_message = message_id_list[-1][1] - logger.debug(f"{self.log_prefix}动作'{action}'缺少target_message_id,使用最新消息作为target_message") - - # 验证action是否可用 - available_action_names = [action_name for action_name, _ in current_available_actions] - # 内部保留动作(不依赖插件系统) - # 注意:listening 已合并到 wait 中,如果遇到 listening 则转换为 wait - internal_action_names = ["complete_talk", "reply", "wait_time", "wait", "listening"] - - logger.debug( - f"{self.log_prefix}动作验证: action={action}, internal={internal_action_names}, available={available_action_names}" - ) - - # 将 listening 转换为 wait(向后兼容) - if action == "listening": - logger.debug(f"{self.log_prefix}检测到 listening 动作,已合并到 wait,自动转换") - action = "wait" - - if action not in internal_action_names and action not in available_action_names: - logger.warning( - f"{self.log_prefix}LLM 返回了当前不可用或无效的动作: '{action}' (内部动作: {internal_action_names}, 可用插件动作: {available_action_names}),将强制使用 'complete_talk'" - ) - reasoning = ( - f"LLM 返回了当前不可用的动作 '{action}' (可用: {available_action_names})。原始理由: {reasoning}" - ) - action = "complete_talk" - logger.warning(f"{self.log_prefix}动作已转换为 complete_talk") - - # 创建ActionPlannerInfo对象 - # 将列表转换为字典格式 - available_actions_dict = dict(current_available_actions) - action_planner_infos.append( - ActionPlannerInfo( - action_type=action, - reasoning=reasoning, - action_data=action_data, - action_message=target_message, - available_actions=available_actions_dict, - ) - ) - - except Exception as e: - logger.error(f"{self.log_prefix}解析单个action时出错: {e}") - # 将列表转换为字典格式 - available_actions_dict = dict(current_available_actions) - action_planner_infos.append( - ActionPlannerInfo( - action_type="complete_talk", - reasoning=f"解析单个action时出错: {e}", - action_data={}, - action_message=None, - available_actions=available_actions_dict, - ) - ) - - return action_planner_infos - - async def plan( - self, - available_actions: Dict[str, ActionInfo], - loop_start_time: float = 0.0, - ) -> List[ActionPlannerInfo]: - # sourcery skip: use-named-expression - """ - 规划器 (Planner): 使用LLM根据上下文决定做出什么动作(ReAct模式)。 - """ - plan_start = time.perf_counter() - - # 获取聊天上下文 - message_list_before_now = get_messages_before_time_in_chat( - chat_id=self.chat_id, - timestamp=time.time(), - limit=int(global_config.chat.max_context_size * 0.6), - filter_intercept_message_level=1, - ) - message_id_list: list[Tuple[str, "SessionMessage"]] = [] - chat_content_block, message_id_list = build_readable_messages_with_id( - messages=message_list_before_now, - timestamp_mode="normal_no_YMD", - read_mark=self.last_obs_time_mark, - truncate=True, - show_actions=True, - ) - - message_list_before_now_short = message_list_before_now[-int(global_config.chat.max_context_size * 0.3) :] - chat_content_block_short, message_id_list_short = build_readable_messages_with_id( - messages=message_list_before_now_short, - timestamp_mode="normal_no_YMD", - truncate=False, - show_actions=False, - ) - - self.last_obs_time_mark = time.time() - - # 获取必要信息 - is_group_chat, chat_target_info, current_available_actions = self.get_necessary_info() - - # 提及/被@ 的处理由心流或统一判定模块驱动;Planner 不再做硬编码强制回复 - - # 应用激活类型过滤 - filtered_actions = self._filter_actions_by_activation_type(available_actions, chat_content_block_short) - - logger.debug(f"{self.log_prefix}过滤后有{len(filtered_actions)}个可用动作") - - prompt_build_start = time.perf_counter() - # 构建包含所有动作的提示词:使用统一的 ReAct Prompt - prompt_key = "brain_planner" - # 这里不记录日志,避免重复打印,由调用方按需控制 log_prompt - prompt, message_id_list = await self.build_planner_prompt( - chat_target_info=chat_target_info, - current_available_actions=filtered_actions, - chat_content_block=chat_content_block, - message_id_list=message_id_list, - prompt_key=prompt_key, - ) - prompt_build_ms = (time.perf_counter() - prompt_build_start) * 1000 - - # 调用LLM获取决策 - reasoning, actions, llm_raw_output, llm_reasoning, llm_duration_ms = await self._execute_main_planner( - prompt=prompt, - message_id_list=message_id_list, - filtered_actions=filtered_actions, - available_actions=available_actions, - loop_start_time=loop_start_time, - ) - - # 记录和展示计划日志 - logger.info( - f"{self.log_prefix}Planner: {reasoning}。选择了{len(actions)}个动作: {' '.join([a.action_type for a in actions])}" - ) - self.add_plan_log(reasoning, actions) - - try: - PlanReplyLogger.log_plan( - chat_id=self.chat_id, - prompt=prompt, - reasoning=reasoning, - raw_output=llm_raw_output, - raw_reasoning=llm_reasoning, - actions=actions, - timing={ - "prompt_build_ms": round(prompt_build_ms, 2), - "llm_duration_ms": round(llm_duration_ms, 2) if llm_duration_ms is not None else None, - "total_plan_ms": round((time.perf_counter() - plan_start) * 1000, 2), - "loop_start_time": loop_start_time, - }, - extra=None, - ) - except Exception: - logger.exception(f"{self.log_prefix}记录plan日志失败") - - return actions - - async def build_planner_prompt( - self, - chat_target_info: Optional["TargetPersonInfo"], - current_available_actions: Dict[str, ActionInfo], - message_id_list: List[Tuple[str, "SessionMessage"]], - chat_content_block: str = "", - interest: str = "", - prompt_key: str = "brain_planner", - ) -> tuple[str, List[Tuple[str, "SessionMessage"]]]: - """构建 Planner LLM 的提示词 (获取模板并填充数据)""" - try: - # 获取最近执行过的动作 - actions_before_now = get_actions_by_timestamp_with_chat( - chat_id=self.chat_id, - timestamp_start=time.time() - 600, - timestamp_end=time.time(), - limit=6, - ) - actions_before_now_block = ActionUtils.build_readable_action_records(actions_before_now) - if actions_before_now_block: - actions_before_now_block = f"你刚刚选择并执行过的action是:\n{actions_before_now_block}" - else: - actions_before_now_block = "" - - chat_context_description: str = "" - if chat_target_info: - # 构建聊天上下文描述 - chat_context_description = ( - f"你正在和 {chat_target_info.person_name or chat_target_info.user_nickname or '对方'} 聊天中" - ) - - # 构建动作选项块 - action_options_block = await self._build_action_options_block(current_available_actions) - - # 其他信息 - moderation_prompt_block = "请不要输出违法违规内容,不要输出色情,暴力,政治相关内容,如有敏感内容,请规避。" - time_block = f"当前时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}" - bot_name = global_config.bot.nickname - bot_nickname = ( - f",也可以叫你{','.join(global_config.bot.alias_names)}" if global_config.bot.alias_names else "" - ) - name_block = f"你的名字是{bot_name}{bot_nickname},请注意哪些是你自己的发言。" - - # 获取主规划器模板并填充 - planner_prompt_template = prompt_manager.get_prompt(prompt_key) - planner_prompt_template.add_context("time_block", time_block) - planner_prompt_template.add_context("chat_context_description", chat_context_description) - planner_prompt_template.add_context("chat_content_block", chat_content_block) - planner_prompt_template.add_context("actions_before_now_block", actions_before_now_block) - planner_prompt_template.add_context("action_options_text", action_options_block) - planner_prompt_template.add_context("moderation_prompt", moderation_prompt_block) - planner_prompt_template.add_context("name_block", name_block) - planner_prompt_template.add_context("interest", interest) - planner_prompt_template.add_context("plan_style", global_config.experimental.private_plan_style) - prompt = await prompt_manager.render_prompt(planner_prompt_template) - - return prompt, message_id_list - except Exception as e: - logger.error(f"构建 Planner 提示词时出错: {e}") - logger.error(traceback.format_exc()) - return "构建 Planner Prompt 时出错", [] - - def get_necessary_info(self) -> Tuple[bool, Optional["TargetPersonInfo"], Dict[str, ActionInfo]]: - """ - 获取 Planner 需要的必要信息 - """ - is_group_chat = True - is_group_chat, chat_target_info = get_chat_type_and_target_info(self.chat_id) - logger.debug(f"{self.log_prefix}获取到聊天信息 - 群聊: {is_group_chat}, 目标信息: {chat_target_info}") - - current_available_actions_dict = self.action_manager.get_using_actions() - - # 获取完整的动作信息 - all_registered_actions: Dict[str, ActionInfo] = component_query_service.get_components_by_type( # type: ignore - ComponentType.ACTION - ) - current_available_actions = {} - for action_name in current_available_actions_dict: - if action_name in all_registered_actions: - current_available_actions[action_name] = all_registered_actions[action_name] - else: - logger.warning(f"{self.log_prefix}使用中的动作 {action_name} 未在已注册动作中找到") - - return is_group_chat, chat_target_info, current_available_actions - - def _filter_actions_by_activation_type( - self, available_actions: Dict[str, ActionInfo], chat_content_block: str - ) -> Dict[str, ActionInfo]: - """根据激活类型过滤动作""" - filtered_actions = {} - - for action_name, action_info in available_actions.items(): - if action_info.activation_type == ActionActivationType.NEVER: - logger.debug(f"{self.log_prefix}动作 {action_name} 设置为 NEVER 激活类型,跳过") - continue - elif action_info.activation_type == ActionActivationType.ALWAYS: - filtered_actions[action_name] = action_info - elif action_info.activation_type == ActionActivationType.RANDOM: - if random.random() < action_info.random_activation_probability: - filtered_actions[action_name] = action_info - elif action_info.activation_type == ActionActivationType.KEYWORD: - if action_info.activation_keywords: - for keyword in action_info.activation_keywords: - if keyword in chat_content_block: - filtered_actions[action_name] = action_info - break - else: - logger.warning(f"{self.log_prefix}未知的激活类型: {action_info.activation_type},跳过处理") - - return filtered_actions - - async def _build_action_options_block(self, current_available_actions: Dict[str, ActionInfo]) -> str: - # sourcery skip: use-join - """构建动作选项块""" - if not current_available_actions: - return "" - - action_options_block = "" - for action_name, action_info in current_available_actions.items(): - # 构建参数文本 - param_text = "" - if action_info.action_parameters: - param_text = "\n" - for param_name, param_description in action_info.action_parameters.items(): - param_text += f' "{param_name}":"{param_description}"\n' - param_text = param_text.rstrip("\n") - - # 构建要求文本 - require_text = "" - for require_item in action_info.action_require: - require_text += f"- {require_item}\n" - require_text = require_text.rstrip("\n") - - # 获取动作提示模板并填充 - using_action_prompt_template = prompt_manager.get_prompt("brain_action") - using_action_prompt_template.add_context("action_name", action_name) - using_action_prompt_template.add_context("action_description", action_info.description) - using_action_prompt_template.add_context("action_parameters", param_text) - using_action_prompt_template.add_context("action_require", require_text) - using_action_prompt = await prompt_manager.render_prompt(using_action_prompt_template) - - action_options_block += using_action_prompt - - return action_options_block - - async def _execute_main_planner( - self, - prompt: str, - message_id_list: List[Tuple[str, "SessionMessage"]], - filtered_actions: Dict[str, ActionInfo], - available_actions: Dict[str, ActionInfo], - loop_start_time: float, - ) -> Tuple[str, List[ActionPlannerInfo], Optional[str], Optional[str], Optional[float]]: - """执行主规划器""" - llm_content = None - actions: List[ActionPlannerInfo] = [] - extracted_reasoning = "" - llm_reasoning = None - llm_duration_ms = None - - try: - # 调用LLM - llm_start = time.perf_counter() - llm_content, (reasoning_content, _, _) = await self.planner_llm.generate_response_async(prompt=prompt) - llm_duration_ms = (time.perf_counter() - llm_start) * 1000 - llm_reasoning = reasoning_content - - logger.info(f"{self.log_prefix}规划器原始提示词: {prompt}") - logger.info(f"{self.log_prefix}规划器原始响应: {llm_content}") - - if global_config.debug.show_planner_prompt: - logger.info(f"{self.log_prefix}规划器原始提示词: {prompt}") - logger.info(f"{self.log_prefix}规划器原始响应: {llm_content}") - if reasoning_content: - logger.info(f"{self.log_prefix}规划器推理: {reasoning_content}") - else: - logger.debug(f"{self.log_prefix}规划器原始提示词: {prompt}") - logger.debug(f"{self.log_prefix}规划器原始响应: {llm_content}") - if reasoning_content: - logger.debug(f"{self.log_prefix}规划器推理: {reasoning_content}") - - except Exception as req_e: - logger.error(f"{self.log_prefix}LLM 请求执行失败: {req_e}") - extracted_reasoning = f"LLM 请求失败,模型出现问题: {req_e}" - return ( - extracted_reasoning, - [ - ActionPlannerInfo( - action_type="complete_talk", - reasoning=extracted_reasoning, - action_data={}, - action_message=None, - available_actions=available_actions, - ) - ], - llm_content, - llm_reasoning, - llm_duration_ms, - ) - - # 解析LLM响应 - if llm_content: - try: - json_objects, extracted_reasoning = self._extract_json_from_markdown(llm_content) - if json_objects: - logger.info(f"{self.log_prefix}从响应中提取到{len(json_objects)}个JSON对象") - for i, json_obj in enumerate(json_objects): - logger.info(f"{self.log_prefix}解析第{i + 1}个JSON对象: {json_obj}") - filtered_actions_list = list(filtered_actions.items()) - for json_obj in json_objects: - parsed_actions = self._parse_single_action(json_obj, message_id_list, filtered_actions_list) - logger.info(f"{self.log_prefix}解析后的动作: {[a.action_type for a in parsed_actions]}") - actions.extend(parsed_actions) - else: - # 尝试解析为直接的JSON - logger.warning(f"{self.log_prefix}LLM没有返回可用动作: {llm_content}") - extracted_reasoning = extracted_reasoning or "LLM没有返回可用动作" - actions = self._create_complete_talk(extracted_reasoning, available_actions) - - except Exception as json_e: - logger.warning(f"{self.log_prefix}解析LLM响应JSON失败 {json_e}. LLM原始输出: '{llm_content}'") - extracted_reasoning = f"解析LLM响应JSON失败: {json_e}" - actions = self._create_complete_talk(extracted_reasoning, available_actions) - traceback.print_exc() - else: - extracted_reasoning = "规划器没有获得LLM响应" - actions = self._create_complete_talk(extracted_reasoning, available_actions) - - # 添加循环开始时间到所有动作 - for action in actions: - action.action_data = action.action_data or {} - action.action_data["loop_start_time"] = loop_start_time - - logger.debug( - f"{self.log_prefix}规划器决定执行{len(actions)}个动作: {' '.join([a.action_type for a in actions])}" - ) - - return extracted_reasoning, actions, llm_content, llm_reasoning, llm_duration_ms - - def _create_complete_talk( - self, reasoning: str, available_actions: Dict[str, ActionInfo] - ) -> List[ActionPlannerInfo]: - """创建complete_talk""" - return [ - ActionPlannerInfo( - action_type="complete_talk", - reasoning=reasoning, - action_data={}, - action_message=None, - available_actions=available_actions, - ) - ] - - def add_plan_log(self, reasoning: str, actions: List[ActionPlannerInfo]): - """添加计划日志""" - self.plan_log.append((reasoning, time.time(), actions)) - if len(self.plan_log) > 20: - self.plan_log.pop(0) - - def _extract_json_from_markdown(self, content: str) -> Tuple[List[dict], str]: - # sourcery skip: for-append-to-extend - """从Markdown格式的内容中提取JSON对象和推理内容""" - json_objects = [] - reasoning_content = "" - - # 使用正则表达式查找```json包裹的JSON内容 - json_pattern = r"```json\s*(.*?)\s*```" - markdown_matches = re.findall(json_pattern, content, re.DOTALL) - - # 提取JSON之前的内容作为推理文本 - first_json_pos = len(content) - if markdown_matches: - # 找到第一个```json的位置 - first_json_pos = content.find("```json") - if first_json_pos > 0: - reasoning_content = content[:first_json_pos].strip() - # 清理推理内容中的注释标记 - reasoning_content = re.sub(r"^//\s*", "", reasoning_content, flags=re.MULTILINE) - reasoning_content = reasoning_content.strip() - - # 处理```json包裹的JSON - for match in markdown_matches: - try: - # 清理可能的注释和格式问题 - json_str = re.sub(r"//.*?\n", "\n", match) # 移除单行注释 - json_str = re.sub(r"/\*.*?\*/", "", json_str, flags=re.DOTALL) # 移除多行注释 - if json_str := json_str.strip(): - # 先尝试将整个块作为一个JSON对象或数组(适用于多行JSON) - try: - json_obj = json.loads(repair_json(json_str)) - if isinstance(json_obj, dict): - json_objects.append(json_obj) - elif isinstance(json_obj, list): - for item in json_obj: - if isinstance(item, dict): - json_objects.append(item) - except json.JSONDecodeError: - # 如果整个块解析失败,尝试按行分割(适用于多个单行JSON对象) - lines = [line.strip() for line in json_str.split("\n") if line.strip()] - for line in lines: - try: - # 尝试解析每一行作为独立的JSON对象 - json_obj = json.loads(repair_json(line)) - if isinstance(json_obj, dict): - json_objects.append(json_obj) - elif isinstance(json_obj, list): - for item in json_obj: - if isinstance(item, dict): - json_objects.append(item) - except json.JSONDecodeError: - # 单行解析失败,继续下一行 - continue - except Exception as e: - logger.warning(f"{self.log_prefix}解析JSON块失败: {e}, 块内容: {match[:100]}...") - continue - - # 如果没有找到完整的```json```块,尝试查找不完整的代码块(缺少结尾```) - if not json_objects: - json_start_pos = content.find("```json") - if json_start_pos != -1: - # 找到```json之后的内容 - json_content_start = json_start_pos + 7 # ```json的长度 - # 提取从```json之后到内容结尾的所有内容 - incomplete_json_str = content[json_content_start:].strip() - - # 提取JSON之前的内容作为推理文本 - if json_start_pos > 0: - reasoning_content = content[:json_start_pos].strip() - reasoning_content = re.sub(r"^//\s*", "", reasoning_content, flags=re.MULTILINE) - reasoning_content = reasoning_content.strip() - - if incomplete_json_str: - try: - # 清理可能的注释和格式问题 - json_str = re.sub(r"//.*?\n", "\n", incomplete_json_str) - json_str = re.sub(r"/\*.*?\*/", "", json_str, flags=re.DOTALL) - json_str = json_str.strip() - - if json_str: - # 尝试按行分割,每行可能是一个JSON对象 - lines = [line.strip() for line in json_str.split("\n") if line.strip()] - for line in lines: - try: - json_obj = json.loads(repair_json(line)) - if isinstance(json_obj, dict): - json_objects.append(json_obj) - elif isinstance(json_obj, list): - for item in json_obj: - if isinstance(item, dict): - json_objects.append(item) - except json.JSONDecodeError: - pass - - # 如果按行解析没有成功,尝试将整个块作为一个JSON对象或数组 - if not json_objects: - try: - json_obj = json.loads(repair_json(json_str)) - if isinstance(json_obj, dict): - json_objects.append(json_obj) - elif isinstance(json_obj, list): - for item in json_obj: - if isinstance(item, dict): - json_objects.append(item) - except Exception as e: - logger.debug(f"尝试解析不完整的JSON代码块失败: {e}") - except Exception as e: - logger.debug(f"处理不完整的JSON代码块时出错: {e}") - - return json_objects, reasoning_content diff --git a/src/chat/emoji_system/emoji_manager.py b/src/chat/emoji_system/emoji_manager.py index 1e8b5479..594f33b4 100644 --- a/src/chat/emoji_system/emoji_manager.py +++ b/src/chat/emoji_system/emoji_manager.py @@ -1,8 +1,9 @@ from datetime import datetime from pathlib import Path +from typing import Dict, List, Optional, Tuple + from rich.traceback import install from sqlmodel import select -from typing import Optional, Tuple, List import asyncio import hashlib @@ -17,8 +18,9 @@ from src.common.database.database_model import Images, ImageType from src.common.database.database import get_db_session, get_db_session_manual from src.common.utils.utils_image import ImageUtils from src.prompt.prompt_manager import prompt_manager -from src.config.config import config_manager, global_config, model_config -from src.llm_models.utils_model import LLMRequest +from src.config.config import config_manager, global_config +from src.common.data_models.llm_service_data_models import LLMGenerationOptions, LLMImageOptions +from src.services.llm_service import LLMServiceClient logger = get_logger("emoji") @@ -38,8 +40,10 @@ def _ensure_directories() -> None: # TODO: 修改这个vlm为获取的vlm client,暂时使用这个VLM方法 -emoji_manager_vlm = LLMRequest(model_set=model_config.model_task_config.vlm, request_type="emoji.see") -emoji_manager_emotion_judge_llm = LLMRequest(model_set=model_config.model_task_config.utils, request_type="emoji") +emoji_manager_vlm = LLMServiceClient(task_name="vlm", request_type="emoji.see") +emoji_manager_emotion_judge_llm = LLMServiceClient( + task_name="utils", request_type="emoji" +) class EmojiManager: @@ -48,11 +52,13 @@ class EmojiManager: """ def __init__(self) -> None: + """初始化表情包管理器。""" _ensure_directories() self._emoji_num: int = 0 - self.emojis: list[MaiEmoji] = [] + self.emojis: List[MaiEmoji] = [] self._maintenance_wakeup_event: asyncio.Event = asyncio.Event() + self._pending_description_tasks: Dict[str, asyncio.Task[None]] = {} self._reload_callback_registered: bool = False config_manager.register_reload_callback(self.reload_runtime_config) @@ -75,7 +81,11 @@ class EmojiManager: logger.info("[关闭] Emoji 模块已注销配置热重载回调") async def get_emoji_description( - self, *, emoji_bytes: Optional[bytes] = None, emoji_hash: Optional[str] = None + self, + *, + emoji_bytes: Optional[bytes] = None, + emoji_hash: Optional[str] = None, + wait_for_build: bool = True, ) -> Optional[Tuple[str, List[str]]]: """ 根据表情包哈希获取表情包描述和情感列表的封装方法 @@ -83,6 +93,7 @@ class EmojiManager: Args: emoji_bytes (Optional[bytes]): 表情包的字节数据,如果提供了字节数据但数据库中没有找到对应记录,则会尝试构建表情包描述 emoji_hash (Optional[str]): 表情包的哈希值,如果提供了哈希值则优先使用哈希值查找表情包描述 + wait_for_build (bool): 未命中缓存时是否同步等待描述构建完成 Returns: return (Optional[Tuple[str, List[str]]]): 如果找到对应的表情包,则返回包含描述和情感标签的元组;若没找到,则尝试构建表情包描述并返回,如果构建失败则返回 None Raises: @@ -110,27 +121,88 @@ class EmojiManager: # 如果提供了字节数据但数据库中没有找到,尝试构建 if not emoji_bytes: return None + if not wait_for_build: + self._schedule_description_build(emoji_hash, emoji_bytes) + return None # 找不到尝试构建 + return await self._build_and_cache_emoji_description(emoji_hash, emoji_bytes) + + def _schedule_description_build(self, emoji_hash: str, emoji_bytes: bytes) -> None: + """调度表情包描述后台构建任务。 + + Args: + emoji_hash: 表情包哈希值。 + emoji_bytes: 表情包字节数据。 + """ + if emoji_hash in self._pending_description_tasks: + return + + task = asyncio.create_task(self._build_description_in_background(emoji_hash, emoji_bytes)) + self._pending_description_tasks[emoji_hash] = task + task.add_done_callback(lambda finished_task: self._finalize_description_build(emoji_hash, finished_task)) + + async def _build_description_in_background(self, emoji_hash: str, emoji_bytes: bytes) -> None: + """在后台构建并缓存表情包描述。 + + Args: + emoji_hash: 表情包哈希值。 + emoji_bytes: 表情包字节数据。 + """ + try: + logger.info(f"表情包描述后台构建已开始,哈希值: {emoji_hash}") + await self._build_and_cache_emoji_description(emoji_hash, emoji_bytes) + logger.info(f"表情包描述后台构建完成,哈希值: {emoji_hash}") + except Exception as exc: + logger.warning(f"表情包描述后台构建失败,哈希值: {emoji_hash},错误: {exc}") + + def _finalize_description_build(self, emoji_hash: str, task: asyncio.Task[None]) -> None: + """回收表情包描述后台构建任务。 + + Args: + emoji_hash: 表情包哈希值。 + task: 已完成的后台任务。 + """ + self._pending_description_tasks.pop(emoji_hash, None) + try: + task.result() + except Exception as exc: + logger.debug(f"表情包描述后台任务结束时捕获异常,哈希值: {emoji_hash},错误: {exc}") + + async def _build_and_cache_emoji_description( + self, + emoji_hash: str, + emoji_bytes: bytes, + ) -> Optional[Tuple[str, List[str]]]: + """构建并缓存表情包描述与情感标签。 + + Args: + emoji_hash: 表情包哈希值。 + emoji_bytes: 表情包字节数据。 + + Returns: + Optional[Tuple[str, List[str]]]: 构建成功时返回描述和情感标签,否则返回 ``None``。 + """ logger.info(f"未找到哈希值为 {emoji_hash} 的表情包与其描述,尝试构建描述") full_path = EMOJI_DIR / f"{emoji_hash}.png" try: full_path.write_bytes(emoji_bytes) new_emoji = MaiEmoji(full_path=full_path, image_bytes=emoji_bytes) await new_emoji.calculate_hash_format() - except Exception as e: - logger.error(f"缓存表情包文件时出错: {e}") - raise e + except Exception as exc: + logger.error(f"缓存表情包文件时出错: {exc}") + raise exc + success_desc, new_emoji = await self.build_emoji_description(new_emoji) if not success_desc: logger.error("构建表情包描述失败") return None + success_emotion, new_emoji = await self.build_emoji_emotion(new_emoji) if not success_emotion: logger.error("构建表情包情感标签失败") return None - # 缓存结果到数据库 with get_db_session() as session: try: image_record = new_emoji.to_db_instance() @@ -139,8 +211,8 @@ class EmojiManager: image_record.register_time = datetime.now() image_record.no_file_flag = True session.add(image_record) - except Exception as e: - logger.error(f"缓存表情包描述时出错: {e}") + except Exception as exc: + logger.error(f"缓存表情包描述时出错: {exc}") return new_emoji.description, new_emoji.emotion or [] def load_emojis_from_db(self) -> None: @@ -461,9 +533,11 @@ class EmojiManager: emoji_replace_prompt_template.add_context("emoji_list", "\n".join(emoji_info_list)) emoji_replace_prompt = await prompt_manager.render_prompt(emoji_replace_prompt_template) - decision, _ = await emoji_manager_emotion_judge_llm.generate_response_async( - emoji_replace_prompt, temperature=0.8, max_tokens=600 + decision_result = await emoji_manager_emotion_judge_llm.generate_response( + emoji_replace_prompt, + options=LLMGenerationOptions(temperature=0.8, max_tokens=600), ) + decision = decision_result.response logger.info(f"[决策] 结果: {decision}") # 解析决策结果 @@ -515,33 +589,56 @@ class EmojiManager: image_bytes = target_emoji.image_bytes or await asyncio.to_thread( target_emoji.read_image_bytes, target_emoji.full_path ) + image_base64 = ImageUtils.image_bytes_to_base64(image_bytes) + try: + if image_format == "gif": + try: + image_bytes = await asyncio.to_thread(ImageUtils.gif_2_static_image, image_bytes) + except Exception as e: + logger.error(f"[构建描述] 转换 GIF 图片时出错: {e}") + return False, target_emoji + prompt: str = "这是一个动态图表情包,每一张图代表了动态图的某一帧,黑色背景代表透明,简短描述一下表情包表达的情感和内容,从互联网梗、meme的角度去分析,精简回答" + image_base64 = ImageUtils.image_bytes_to_base64(image_bytes) + description_result = await emoji_manager_vlm.generate_response_for_image( + prompt, + image_base64, + "jpg", + options=LLMImageOptions(temperature=0.5), + ) + description = description_result.response + else: + prompt: str = "这是一个表情包,请详细描述一下表情包所表达的情感和内容,简短描述细节,从互联网梗、meme的角度去分析,精简回答" + description_result = await emoji_manager_vlm.generate_response_for_image( + prompt, + image_base64, + image_format, + options=LLMImageOptions(temperature=0.5), + ) + description = description_result.response + except Exception as e: + logger.error(f"[构建描述] 调用视觉模型生成表情包描述时出错: {e}") + return False, target_emoji - if image_format == "gif": - try: - image_bytes = await asyncio.to_thread(ImageUtils.gif_2_static_image, image_bytes) - except Exception as e: - logger.error(f"[构建描述] 转换 GIF 图片时出错: {e}") - return False, target_emoji - prompt: str = "这是一个动态图表情包,每一张图代表了动态图的某一帧,黑色背景代表透明,简短描述一下表情包表达的情感和内容,从互联网梗、meme的角度去分析,精简回答" - image_base64 = ImageUtils.image_bytes_to_base64(image_bytes) - description, _ = await emoji_manager_vlm.generate_response_for_image( - prompt, image_base64, "jpg", temperature=0.5 - ) - else: - prompt: str = "这是一个表情包,请详细描述一下表情包所表达的情感和内容,简短描述细节,从互联网梗、meme的角度去分析,精简回答" - image_base64 = ImageUtils.image_bytes_to_base64(image_bytes) - description, _ = await emoji_manager_vlm.generate_response_for_image( - prompt, image_base64, image_format, temperature=0.5 - ) + if not description: + logger.warning(f"[构建描述] 视觉模型返回空描述,跳过注册: {target_emoji.file_name}") + return False, target_emoji # 表情包审查 if global_config.emoji.content_filtration: - filtration_prompt_template = prompt_manager.get_prompt("emoji_content_filtration") - filtration_prompt_template.add_context("demand", global_config.emoji.filtration_prompt) - filtration_prompt = await prompt_manager.render_prompt(filtration_prompt_template) - llm_response, _ = await emoji_manager_vlm.generate_response_for_image( - filtration_prompt, image_base64, image_format, temperature=0.3 - ) + try: + filtration_prompt_template = prompt_manager.get_prompt("emoji_content_filtration") + filtration_prompt_template.add_context("demand", global_config.emoji.filtration_prompt) + filtration_prompt = await prompt_manager.render_prompt(filtration_prompt_template) + filtration_result = await emoji_manager_vlm.generate_response_for_image( + filtration_prompt, + image_base64, + image_format, + options=LLMImageOptions(temperature=0.3), + ) + llm_response = filtration_result.response + except Exception as e: + logger.error(f"[表情包审查] 调用视觉模型审查表情包时出错: {e}") + return False, target_emoji if "否" in llm_response: logger.warning(f"[表情包审查] 表情包内容不符合要求,拒绝注册: {target_emoji.file_name}") return False, target_emoji @@ -567,9 +664,19 @@ class EmojiManager: emotion_prompt_template.add_context("description", target_emoji.description) emotion_prompt = await prompt_manager.render_prompt(emotion_prompt_template) # 调用LLM生成情感标签 - emotion_result, _ = await emoji_manager_emotion_judge_llm.generate_response_async( - emotion_prompt, temperature=0.3, max_tokens=200 - ) + try: + emotion_generation_result = await emoji_manager_emotion_judge_llm.generate_response( + emotion_prompt, + options=LLMGenerationOptions(temperature=0.3, max_tokens=200), + ) + emotion_result = emotion_generation_result.response + except Exception as e: + logger.error(f"[构建情感标签] 调用模型生成情感标签时出错: {e}") + return False, target_emoji + + if not emotion_result: + logger.warning(f"[构建情感标签] 情感标签结果为空,跳过注册: {target_emoji.file_name}") + return False, target_emoji # 解析情感标签结果 emotions = [e.strip() for e in emotion_result.replace(",", ",").split(",") if e.strip()] @@ -651,7 +758,12 @@ class EmojiManager: for emoji_file in EMOJI_DIR.iterdir(): if not emoji_file.is_file(): continue - if await self.register_emoji_by_filename(emoji_file): + try: + register_success = await self.register_emoji_by_filename(emoji_file) + except Exception as e: + logger.error(f"[定期维护] 注册表情包 {emoji_file.name} 时发生未处理异常: {e}") + register_success = False + if register_success: break # 每次只注册一个表情包 try: emoji_file.unlink() diff --git a/src/chat/heart_flow/heartFC_chat.py b/src/chat/heart_flow/heartFC_chat.py index 74d94773..2696a420 100644 --- a/src/chat/heart_flow/heartFC_chat.py +++ b/src/chat/heart_flow/heartFC_chat.py @@ -198,7 +198,6 @@ class HeartFChatting: """判定和生成回复""" asyncio.create_task(self._trigger_expression_learning(self.message_cache)) # TODO: 完成反思器之后的逻辑 - start_time = time.time() current_cycle_detail = self._start_cycle() logger.info(f"{self.log_prefix} 开始第{self._cycle_counter}次思考") @@ -207,10 +206,7 @@ class HeartFChatting: # TODO: 动作执行逻辑 cycle_detail = self._end_cycle(current_cycle_detail) - if wait_time := global_config.chat.planner_smooth - (time.time() - start_time) > 0: - await asyncio.sleep(wait_time) - else: - await asyncio.sleep(0.1) # 最小等待时间,避免过快循环 + await asyncio.sleep(0.1) # 最小等待时间,避免过快循环 return True def _handle_loop_completion(self, task: asyncio.Task): diff --git a/src/chat/heart_flow/heartflow_manager.py b/src/chat/heart_flow/heartflow_manager.py index 5b3ece9b..3bbc6ec3 100644 --- a/src/chat/heart_flow/heartflow_manager.py +++ b/src/chat/heart_flow/heartflow_manager.py @@ -1,47 +1,50 @@ -from typing import Dict - +import asyncio import traceback -from src.common.logger import get_logger +from typing import Dict + from src.chat.message_receive.chat_manager import chat_manager -from src.chat.heart_flow.heartFC_chat import HeartFChatting -# from src.chat.brain_chat.brain_chat import BrainChatting +from src.common.logger import get_logger +from src.maisaka.runtime import MaisakaHeartFlowChatting logger = get_logger("heartflow") -# TODO: 恢复PFC,现在暂时禁用 class HeartflowManager: - """主心流协调器,负责初始化并协调聊天,控制聊天属性""" + """管理 session 级别的 Maisaka 心流实例。""" - def __init__(self): - # self.heartflow_chat_list: Dict[str, HeartFChatting | BrainChatting] = {} - self.heartflow_chat_list: Dict[str, HeartFChatting] = {} + def __init__(self) -> None: + self.heartflow_chat_list: Dict[str, MaisakaHeartFlowChatting] = {} + self._chat_create_locks: Dict[str, asyncio.Lock] = {} - async def get_or_create_heartflow_chat(self, session_id: str): # -> Optional[HeartFChatting | BrainChatting]: - """获取或创建一个新的HeartFChatting实例""" + async def get_or_create_heartflow_chat(self, session_id: str) -> MaisakaHeartFlowChatting: + """获取或创建指定会话对应的 Maisaka runtime。""" try: if chat := self.heartflow_chat_list.get(session_id): return chat - chat_session = chat_manager.get_session_by_session_id(session_id) - if not chat_session: - raise ValueError(f"未找到 session_id={session_id} 的聊天流") - # new_chat = ( - # HeartFChatting(session_id=session_id) if chat_session.group_id else BrainChatting(session_id=session_id) - # ) - new_chat = HeartFChatting(session_id=session_id) - await new_chat.start() - self.heartflow_chat_list[session_id] = new_chat - return new_chat - except Exception as e: - logger.error(f"创建心流聊天 {session_id} 失败: {e}", exc_info=True) - traceback.print_exc() - raise e - def adjust_talk_frequency(self, session_id: str, frequency: float): - """调整指定聊天流的说话频率""" + create_lock = self._chat_create_locks.setdefault(session_id, asyncio.Lock()) + async with create_lock: + if chat := self.heartflow_chat_list.get(session_id): + return chat + + chat_session = chat_manager.get_session_by_session_id(session_id) + if not chat_session: + raise ValueError(f"未找到 session_id={session_id} 对应的聊天流") + + new_chat = MaisakaHeartFlowChatting(session_id=session_id) + await new_chat.start() + self.heartflow_chat_list[session_id] = new_chat + return new_chat + except Exception as exc: + logger.error(f"创建心流聊天 {session_id} 失败: {exc}", exc_info=True) + traceback.print_exc() + raise + + def adjust_talk_frequency(self, session_id: str, frequency: float) -> None: + """调整指定聊天流的说话频率。""" chat = self.heartflow_chat_list.get(session_id) - if chat and isinstance(chat, HeartFChatting): + if chat: chat.adjust_talk_frequency(frequency) logger.info(f"已调整聊天 {session_id} 的说话频率为 {frequency}") else: diff --git a/src/chat/image_system/image_manager.py b/src/chat/image_system/image_manager.py index ec545894..492886d4 100644 --- a/src/chat/image_system/image_manager.py +++ b/src/chat/image_system/image_manager.py @@ -1,9 +1,11 @@ from datetime import datetime from pathlib import Path +from typing import Dict, Optional + from rich.traceback import install from sqlmodel import select -from typing import Optional +import asyncio import base64 import hashlib @@ -11,8 +13,9 @@ from src.common.logger import get_logger from src.common.database.database import get_db_session from src.common.database.database_model import Images, ImageType from src.common.data_models.image_data_model import MaiImage -from src.config.config import global_config, model_config -from src.llm_models.utils_model import LLMRequest +from src.config.config import global_config +from src.common.data_models.llm_service_data_models import LLMImageOptions +from src.services.llm_service import LLMServiceClient install(extra_lines=3) @@ -23,21 +26,30 @@ IMAGE_DIR = DATA_DIR / "images" logger = get_logger("image") -def _ensure_image_dir_exists(): +def _ensure_image_dir_exists() -> None: + """确保图片缓存目录存在。""" IMAGE_DIR.mkdir(parents=True, exist_ok=True) -vlm = LLMRequest(model_set=model_config.model_task_config.vlm, request_type="image") +vlm = LLMServiceClient(task_name="vlm", request_type="image") class ImageManager: - def __init__(self): + """图片描述管理器。""" + + def __init__(self) -> None: + """初始化图片管理器。""" _ensure_image_dir_exists() + self._pending_description_tasks: Dict[str, asyncio.Task[None]] = {} logger.info("图片管理器初始化完成") async def get_image_description( - self, *, image_hash: Optional[str] = None, image_bytes: Optional[bytes] = None + self, + *, + image_hash: Optional[str] = None, + image_bytes: Optional[bytes] = None, + wait_for_build: bool = True, ) -> str: """ 获取图片描述的封装方法 @@ -49,6 +61,7 @@ class ImageManager: Args: image_hash (Optional[str]): 图片的哈希值,如果提供则优先使用该 image_bytes (Optional[bytes]): 图片的字节数据,如果提供则在数据库中找不到哈希值时使用该数据生成描述 + wait_for_build (bool): 未命中缓存时是否同步等待描述构建完成 Returns: return (str): 图片描述,如果发生错误或无法生成描述则返回空字符串 Raises: @@ -73,6 +86,9 @@ class ImageManager: if not image_bytes: logger.warning("图片哈希值未找到,且未提供图片字节数据,返回无描述") return "" + if not wait_for_build: + self._schedule_description_build(hash_str, image_bytes) + return "" logger.info(f"图片描述未找到,哈希值: {hash_str},准备生成新描述") try: image = await self.save_image_and_process(image_bytes) @@ -81,6 +97,47 @@ class ImageManager: logger.error(f"生成图片描述时发生错误: {e}") return "" + def _schedule_description_build(self, image_hash: str, image_bytes: bytes) -> None: + """调度图片描述后台构建任务。 + + Args: + image_hash: 图片哈希值。 + image_bytes: 图片字节数据。 + """ + if image_hash in self._pending_description_tasks: + return + + task = asyncio.create_task(self._build_description_in_background(image_hash, image_bytes)) + self._pending_description_tasks[image_hash] = task + task.add_done_callback(lambda finished_task: self._finalize_description_build(image_hash, finished_task)) + + async def _build_description_in_background(self, image_hash: str, image_bytes: bytes) -> None: + """在后台构建并缓存图片描述。 + + Args: + image_hash: 图片哈希值。 + image_bytes: 图片字节数据。 + """ + try: + logger.info(f"图片描述后台构建已开始,哈希值: {image_hash}") + await self.save_image_and_process(image_bytes) + logger.info(f"图片描述后台构建完成,哈希值: {image_hash}") + except Exception as exc: + logger.warning(f"图片描述后台构建失败,哈希值: {image_hash},错误: {exc}") + + def _finalize_description_build(self, image_hash: str, task: asyncio.Task[None]) -> None: + """回收图片描述后台构建任务。 + + Args: + image_hash: 图片哈希值。 + task: 已完成的后台任务。 + """ + self._pending_description_tasks.pop(image_hash, None) + try: + task.result() + except Exception as exc: + logger.debug(f"图片描述后台任务结束时捕获异常,哈希值: {image_hash},错误: {exc}") + def get_image_from_db(self, image_hash: str) -> Optional[MaiImage]: """ 从数据库中根据图片哈希值获取图片记录 @@ -260,7 +317,13 @@ class ImageManager: prompt = global_config.personality.visual_style image_base64 = base64.b64encode(image_bytes).decode("utf-8") - description, _ = await vlm.generate_response_for_image(prompt, image_base64, image_format, 0.4) + generation_result = await vlm.generate_response_for_image( + prompt, + image_base64, + image_format, + options=LLMImageOptions(temperature=0.4), + ) + description = generation_result.response if not description: logger.warning("VLM未能生成图片描述") return description or "" diff --git a/src/chat/knowledge/embedding_store.py b/src/chat/knowledge/embedding_store.py index 6d041af6..026c72ee 100644 --- a/src/chat/knowledge/embedding_store.py +++ b/src/chat/knowledge/embedding_store.py @@ -139,14 +139,14 @@ class EmbeddingStore: asyncio.set_event_loop(loop) try: - # 创建新的LLMRequest实例 - from src.llm_models.utils_model import LLMRequest - from src.config.config import model_config + # 创建新的服务层实例 + from src.services.llm_service import LLMServiceClient - llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding") + llm = LLMServiceClient(task_name="embedding", request_type="embedding") # 使用新的事件循环运行异步方法 - embedding, _ = loop.run_until_complete(llm.get_embedding(s)) + embedding_result = loop.run_until_complete(llm.embed_text(s)) + embedding = embedding_result.embedding if embedding and len(embedding) > 0: return embedding @@ -195,13 +195,12 @@ class EmbeddingStore: start_idx, chunk_strs = chunk_data chunk_results = [] - # 为每个线程创建独立的LLMRequest实例 - from src.llm_models.utils_model import LLMRequest - from src.config.config import model_config + # 为每个线程创建独立的服务层实例 + from src.services.llm_service import LLMServiceClient try: - # 创建线程专用的LLM实例 - llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding") + # 创建线程专用的服务层实例 + llm = LLMServiceClient(task_name="embedding", request_type="embedding") for i, s in enumerate(chunk_strs): try: @@ -209,7 +208,8 @@ class EmbeddingStore: loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) try: - embedding = loop.run_until_complete(llm.get_embedding(s)) + embedding_result = loop.run_until_complete(llm.embed_text(s)) + embedding = embedding_result.embedding finally: loop.close() diff --git a/src/chat/knowledge/ie_process.py b/src/chat/knowledge/ie_process.py index d7413bdc..91ba83dc 100644 --- a/src/chat/knowledge/ie_process.py +++ b/src/chat/knowledge/ie_process.py @@ -1,18 +1,27 @@ import asyncio import json import time -from typing import List, Union +from typing import Dict, List, Tuple, Union -from .global_logger import logger -from . import prompt_template -from . import INVALID_ENTITY -from src.llm_models.utils_model import LLMRequest from json_repair import repair_json +from src.services.llm_service import LLMServiceClient -def _extract_json_from_text(text: str): +from . import INVALID_ENTITY +from . import prompt_template +from .global_logger import logger + + +def _extract_json_from_text(text: str) -> List[str] | List[List[str]] | Dict[str, object]: # sourcery skip: assign-if-exp, extract-method - """从文本中提取JSON数据的高容错方法""" + """从文本中提取 JSON 数据。 + + Args: + text: 原始模型输出文本。 + + Returns: + List[str] | List[List[str]] | Dict[str, object]: 修复并解析后的 JSON 结果。 + """ if text is None: logger.error("输入文本为None") return [] @@ -46,20 +55,30 @@ def _extract_json_from_text(text: str): return [] -def _entity_extract(llm_req: LLMRequest, paragraph: str) -> List[str]: +def _entity_extract(llm_req: LLMServiceClient, paragraph: str) -> List[str]: # sourcery skip: reintroduce-else, swap-if-else-branches, use-named-expression - """对段落进行实体提取,返回提取出的实体列表(JSON格式)""" + """对单段文本执行实体提取。 + + Args: + llm_req: LLM 服务门面实例。 + paragraph: 待提取实体的原始段落文本。 + + Returns: + List[str]: 提取出的实体列表。 + """ entity_extract_context = prompt_template.build_entity_extract_context(paragraph) # 使用 asyncio.run 来运行异步方法 try: # 如果当前已有事件循环在运行,使用它 loop = asyncio.get_running_loop() - future = asyncio.run_coroutine_threadsafe(llm_req.generate_response_async(entity_extract_context), loop) - response, _ = future.result() + future = asyncio.run_coroutine_threadsafe(llm_req.generate_response(entity_extract_context), loop) + generation_result = future.result() + response = generation_result.response except RuntimeError: # 如果没有运行中的事件循环,直接使用 asyncio.run - response, _ = asyncio.run(llm_req.generate_response_async(entity_extract_context)) + generation_result = asyncio.run(llm_req.generate_response(entity_extract_context)) + response = generation_result.response # 添加调试日志 logger.debug(f"LLM返回的原始响应: {response}") @@ -92,8 +111,21 @@ def _entity_extract(llm_req: LLMRequest, paragraph: str) -> List[str]: return entity_extract_result -def _rdf_triple_extract(llm_req: LLMRequest, paragraph: str, entities: list) -> List[List[str]]: - """对段落进行实体提取,返回提取出的实体列表(JSON格式)""" +def _rdf_triple_extract( + llm_req: LLMServiceClient, + paragraph: str, + entities: List[str], +) -> List[List[str]]: + """对单段文本执行 RDF 三元组提取。 + + Args: + llm_req: LLM 服务门面实例。 + paragraph: 待提取的原始段落文本。 + entities: 已识别出的实体列表。 + + Returns: + List[List[str]]: 提取出的三元组列表。 + """ rdf_extract_context = prompt_template.build_rdf_triple_extract_context( paragraph, entities=json.dumps(entities, ensure_ascii=False) ) @@ -102,11 +134,13 @@ def _rdf_triple_extract(llm_req: LLMRequest, paragraph: str, entities: list) -> try: # 如果当前已有事件循环在运行,使用它 loop = asyncio.get_running_loop() - future = asyncio.run_coroutine_threadsafe(llm_req.generate_response_async(rdf_extract_context), loop) - response, _ = future.result() + future = asyncio.run_coroutine_threadsafe(llm_req.generate_response(rdf_extract_context), loop) + generation_result = future.result() + response = generation_result.response except RuntimeError: # 如果没有运行中的事件循环,直接使用 asyncio.run - response, _ = asyncio.run(llm_req.generate_response_async(rdf_extract_context)) + generation_result = asyncio.run(llm_req.generate_response(rdf_extract_context)) + response = generation_result.response # 添加调试日志 logger.debug(f"RDF LLM返回的原始响应: {response}") @@ -140,8 +174,21 @@ def _rdf_triple_extract(llm_req: LLMRequest, paragraph: str, entities: list) -> def info_extract_from_str( - llm_client_for_ner: LLMRequest, llm_client_for_rdf: LLMRequest, paragraph: str -) -> Union[tuple[None, None], tuple[list[str], list[list[str]]]]: + llm_client_for_ner: LLMServiceClient, + llm_client_for_rdf: LLMServiceClient, + paragraph: str, +) -> Union[Tuple[None, None], Tuple[List[str], List[List[str]]]]: + """从文本中提取实体与三元组信息。 + + Args: + llm_client_for_ner: 实体提取使用的 LLM 服务门面。 + llm_client_for_rdf: RDF 三元组提取使用的 LLM 服务门面。 + paragraph: 原始段落文本。 + + Returns: + Union[Tuple[None, None], Tuple[List[str], List[List[str]]]]: 成功时返回 + ``(实体列表, 三元组列表)``,失败时返回 ``(None, None)``。 + """ try_count = 0 while True: try: @@ -176,17 +223,30 @@ def info_extract_from_str( class IEProcess: - """ - 信息抽取处理器类,提供更方便的批次处理接口。 - """ + """信息抽取处理器。""" - def __init__(self, llm_ner: LLMRequest, llm_rdf: LLMRequest = None): + def __init__( + self, + llm_ner: LLMServiceClient, + llm_rdf: LLMServiceClient | None = None, + ) -> None: + """初始化信息抽取处理器。 + + Args: + llm_ner: 实体提取使用的 LLM 服务门面。 + llm_rdf: RDF 三元组提取使用的 LLM 服务门面;为空时复用 `llm_ner`。 + """ self.llm_ner = llm_ner self.llm_rdf = llm_rdf or llm_ner - async def process_paragraphs(self, paragraphs: List[str]) -> List[dict]: - """ - 异步处理多个段落。 + async def process_paragraphs(self, paragraphs: List[str]) -> List[Dict[str, object]]: + """异步处理多个段落。 + + Args: + paragraphs: 待处理的段落列表。 + + Returns: + List[Dict[str, object]]: 每个成功段落对应的抽取结果。 """ from .utils.hash import get_sha256 diff --git a/src/chat/message_receive/bot.py b/src/chat/message_receive/bot.py index 1fc4ef53..33a66ffc 100644 --- a/src/chat/message_receive/bot.py +++ b/src/chat/message_receive/bot.py @@ -1,4 +1,5 @@ from contextlib import suppress +from copy import deepcopy from typing import Any, Dict, Optional import os @@ -10,9 +11,9 @@ from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiv from src.common.logger import get_logger from src.common.utils.utils_message import MessageUtils from src.common.utils.utils_session import SessionUtils +from src.config.config import global_config from src.platform_io.route_key_factory import RouteKeyFactory -# from src.chat.brain_chat.PFC.pfc_manager import PFCManager from src.core.announcement_manager import global_announcement_manager from src.plugin_runtime.component_query import component_query_service @@ -29,36 +30,20 @@ logger = get_logger("chat") class ChatBot: - def __init__(self): + def __init__(self) -> None: + """初始化聊天机器人入口。""" + self.bot = None # bot 实例引用 self._started = False - self.heartflow_message_receiver = HeartFCMessageReceiver() # 新增 - # self.pfc_manager = PFCManager.get_instance() # PFC管理器 # TODO: PFC恢复 + self.heartflow_message_receiver = HeartFCMessageReceiver() - async def _ensure_started(self): - """确保所有任务已启动""" + async def _ensure_started(self) -> None: + """确保所有后台任务已启动。""" if not self._started: logger.debug("确保ChatBot所有任务已启动") self._started = True - async def _create_pfc_chat(self, message: SessionMessage): - """创建或获取PFC对话实例 - - Args: - message: 消息对象 - """ - try: - chat_id = message.session_id - private_name = str(message.message_info.user_info.user_nickname) - - logger.debug(f"[私聊][{private_name}]创建或获取PFC对话: {chat_id}") - await self.pfc_manager.get_or_create_conversation(chat_id, private_name) - - except Exception as e: - logger.error(f"创建PFC聊天失败: {e}") - logger.error(traceback.format_exc()) - async def _process_commands(self, message: SessionMessage) -> tuple[bool, Optional[str], bool]: """使用统一组件注册表处理命令。 @@ -175,11 +160,12 @@ class ChatBot: recalled: Dict[str, Any] = {} recalled_id = None - if getattr(seg, "type", None) == "notify" and isinstance(getattr(seg, "data", None), dict): - sub_type = seg.data.get("sub_type") - scene = seg.data.get("scene") - msg_id = seg.data.get("message_id") - recalled = seg.data.get("recalled_user_info") or {} + seg_data = getattr(seg, "data", None) + if getattr(seg, "type", None) == "notify" and isinstance(seg_data, dict): + sub_type = seg_data.get("sub_type") + scene = seg_data.get("scene") + msg_id = seg_data.get("message_id") + recalled = seg_data.get("recalled_user_info") or {} if isinstance(recalled, dict): recalled_id = recalled.get("user_id") @@ -301,7 +287,13 @@ class ChatBot: # pass # 处理消息内容,识别表情包等二进制数据并转化为文本描述 - await message.process() + if global_config.maisaka.direct_image_input: + message.maisaka_original_raw_message = deepcopy(message.raw_message) # type: ignore[attr-defined] + # 入站主链优先保证消息尽快入队,避免图片、表情包、语音分析阻塞适配器超时。 + await message.process( + enable_heavy_media_analysis=False, + enable_voice_transcription=False, + ) # 平台层的 @ 检测由底层 is_mentioned_bot_in_message 统一处理;此处不做用户名硬编码匹配 @@ -335,14 +327,13 @@ class ChatBot: # message.update_chat_stream(chat) - # 命令处理 - 使用新插件系统检查并处理命令 - # 注意:命令返回的 response 当前只用于日志记录和流程判断, - # 不会在这里自动作为回复消息发送回会话。 - # is_command, cmd_result, continue_process = await self._process_commands(message) + # 命令处理 - 使用新插件系统检查并处理命令。 + # 命令处理器内部自行决定是否回复消息,这里只负责流程分发与拦截。 + is_command, cmd_result, continue_process = await self._process_commands(message) - # # 如果是命令且不需要继续处理,则直接返回 - # if is_command and await self._handle_command_processing_result(message, cmd_result, continue_process): - # return + # 如果是命令且不需要继续处理,则直接返回,避免落入 HeartFlow / MaiSaka。 + if is_command and await self._handle_command_processing_result(message, cmd_result, continue_process): + return # continue_flag, modified_message = await events_manager.handle_mai_events(EventType.ON_MESSAGE, message) # if not continue_flag: @@ -362,31 +353,12 @@ class ChatBot: # else: # template_group_name = None - # async def preprocess(): - # # 根据聊天类型路由消息 - # if group_info is None: - # # 私聊消息 -> PFC系统 - # logger.debug("[私聊]检测到私聊消息,路由到PFC系统") - # await MessageStorage.store_message(message, chat) - # await self._create_pfc_chat(message) - # else: - # # 群聊消息 -> HeartFlow系统 - # logger.debug("[群聊]检测到群聊消息,路由到HeartFlow系统") - # await self.heartflow_message_receiver.process_message(message) - - # if template_group_name: - # async with global_prompt_manager.async_message_scope(template_group_name): - # await preprocess() - # else: - # await preprocess() async def preprocess(): if group_info is None: - # logger.debug("[私聊]检测到私聊消息,路由到PFC系统") - # MessageUtils.store_message_to_db(message) # 存储消息到数据库 - # await self._create_pfc_chat(message) - logger.critical("暂时禁用私聊") + logger.debug("[私聊]检测到私聊消息,路由到 Maisaka") + await self.heartflow_message_receiver.process_message(message) else: - logger.debug("[群聊]检测到群聊消息,路由到HeartFlow系统") + logger.debug("[群聊]检测到群聊消息,路由到 Maisaka") await self.heartflow_message_receiver.process_message(message) await preprocess() diff --git a/src/chat/message_receive/message.py b/src/chat/message_receive/message.py index be2ef026..3cf5fdf5 100644 --- a/src/chat/message_receive/message.py +++ b/src/chat/message_receive/message.py @@ -1,7 +1,8 @@ from asyncio import Task +from typing import Dict, List, Sequence, Tuple + from rich.traceback import install from sqlmodel import select -from typing import List, Dict, Tuple, Sequence import asyncio @@ -27,14 +28,36 @@ logger = get_logger("chat_message") class MsgIDMapping: - def __init__(self): - self.mapping: Dict[str, Tuple[str | Task, UserInfo]] = {} + """回复消息内容缓存。""" + + def __init__(self) -> None: + """初始化消息 ID 到内容的映射缓存。""" + self.mapping: Dict[str, Tuple[str | Task[str], UserInfo]] = {} class SessionMessage(MaiMessage): - async def process(self): - """处理消息内容,识别消息内容并转化为文本(会修改消息组件属性)""" - tasks = [self.process_single_component(component, MsgIDMapping()) for component in self.raw_message.components] + async def process( + self, + *, + enable_heavy_media_analysis: bool = True, + enable_voice_transcription: bool = True, + ) -> None: + """处理消息内容并转化为纯文本。 + + Args: + enable_heavy_media_analysis: 是否同步执行图片与表情包描述生成。 + enable_voice_transcription: 是否同步执行语音转写。 + """ + id_content_map = MsgIDMapping() + tasks = [ + self.process_single_component( + component, + id_content_map, + enable_heavy_media_analysis=enable_heavy_media_analysis, + enable_voice_transcription=enable_voice_transcription, + ) + for component in self.raw_message.components + ] results = await asyncio.gather(*tasks, return_exceptions=True) processed_texts: List[str] = [] for result in results: @@ -45,50 +68,116 @@ class SessionMessage(MaiMessage): self.processed_plain_text = " ".join(processed_texts) async def process_single_component( - self, component: StandardMessageComponents, id_content_map: MsgIDMapping, recursion_depth: int = 0 + self, + component: StandardMessageComponents, + id_content_map: MsgIDMapping, + recursion_depth: int = 0, + *, + enable_heavy_media_analysis: bool = True, + enable_voice_transcription: bool = True, ) -> str: - """按照类型处理单个消息组件,返回处理后的文本内容(会修改消息组件属性)""" + """按类型处理单个消息组件。 + + Args: + component: 待处理的消息组件。 + id_content_map: 回复消息解析缓存。 + recursion_depth: 当前递归深度。 + enable_heavy_media_analysis: 是否同步执行图片与表情包描述生成。 + enable_voice_transcription: 是否同步执行语音转写。 + + Returns: + str: 组件对应的文本表示。 + """ if isinstance(component, TextComponent): return component.text elif isinstance(component, ImageComponent): - return await self.process_image_component(component) + return await self.process_image_component( + component, + enable_heavy_media_analysis=enable_heavy_media_analysis, + ) elif isinstance(component, EmojiComponent): - return await self.process_emoji_component(component) + return await self.process_emoji_component( + component, + enable_heavy_media_analysis=enable_heavy_media_analysis, + ) elif isinstance(component, AtComponent): return await self.process_at_component(component) elif isinstance(component, VoiceComponent): - return await self.process_voice_component(component) + return await self.process_voice_component( + component, + enable_voice_transcription=enable_voice_transcription, + ) elif isinstance(component, ReplyComponent): return await self.process_reply_component(component, id_content_map) elif isinstance(component, ForwardNodeComponent): - return await self.process_forward_component(component, id_content_map, recursion_depth=recursion_depth + 1) + return await self.process_forward_component( + component, + id_content_map, + recursion_depth=recursion_depth + 1, + enable_heavy_media_analysis=enable_heavy_media_analysis, + enable_voice_transcription=enable_voice_transcription, + ) else: raise NotImplementedError(f"暂时不支持的消息组件类型: {type(component)}") - async def process_image_component(self, component: ImageComponent) -> str: + async def process_image_component( + self, + component: ImageComponent, + *, + enable_heavy_media_analysis: bool = True, + ) -> str: + """处理图片组件。 + + Args: + component: 图片组件。 + enable_heavy_media_analysis: 是否同步执行图片描述生成。 + + Returns: + str: 图片组件对应的文本表示。 + """ if component.content: # 先检查是否处理过 return component.content from src.chat.image_system.image_manager import image_manager # 获取描述 try: - desc = await image_manager.get_image_description(image_bytes=component.binary_data) + desc = await image_manager.get_image_description( + image_bytes=component.binary_data, + wait_for_build=enable_heavy_media_analysis, + ) except Exception: desc = None # 失败置空 - content = f"[图片:{desc}]" if desc else "[一张图片,网卡了加载不出来]" + content = f"[图片:{desc}]" if desc else "[图片]" component.content = content component.binary_data = b"" # 处理完就丢掉二进制数据,节省内存 return content - async def process_emoji_component(self, component: EmojiComponent) -> str: + async def process_emoji_component( + self, + component: EmojiComponent, + *, + enable_heavy_media_analysis: bool = True, + ) -> str: + """处理表情包组件。 + + Args: + component: 表情包组件。 + enable_heavy_media_analysis: 是否同步执行表情包描述生成。 + + Returns: + str: 表情包组件对应的文本表示。 + """ if component.content: # 先检查是否处理过 return component.content from src.chat.emoji_system.emoji_manager import emoji_manager # 获取表情包描述 try: - tuple_content = await emoji_manager.get_emoji_description(emoji_bytes=component.binary_data) + tuple_content = await emoji_manager.get_emoji_description( + emoji_bytes=component.binary_data, + wait_for_build=enable_heavy_media_analysis, + ) except Exception: tuple_content = None # 失败置空 @@ -96,7 +185,7 @@ class SessionMessage(MaiMessage): desc, _ = tuple_content content = f"[表情包: {desc}]" else: - content = "[一个表情,网卡了加载不出来]" + content = "[表情包]" component.content = content component.binary_data = b"" # 处理完就丢掉二进制数据,节省内存 return content @@ -124,9 +213,26 @@ class SessionMessage(MaiMessage): else: # 最后使用用户ID return f"@{component.target_user_id}" - async def process_voice_component(self, component: VoiceComponent) -> str: + async def process_voice_component( + self, + component: VoiceComponent, + *, + enable_voice_transcription: bool = True, + ) -> str: + """处理语音组件。 + + Args: + component: 语音组件。 + enable_voice_transcription: 是否同步执行语音转写。 + + Returns: + str: 语音组件对应的文本表示。 + """ if component.content: # 先检查是否处理过 return component.content + if not enable_voice_transcription: + component.content = "[语音消息]" + return component.content from src.common.utils.utils_voice import get_voice_text text = await get_voice_text(component.binary_data) @@ -169,13 +275,37 @@ class SessionMessage(MaiMessage): return "[回复了一条消息,但原消息已无法访问]" async def process_forward_component( - self, component: ForwardNodeComponent, id_content_map: MsgIDMapping, recursion_depth: int = 0 + self, + component: ForwardNodeComponent, + id_content_map: MsgIDMapping, + recursion_depth: int = 0, + *, + enable_heavy_media_analysis: bool = True, + enable_voice_transcription: bool = True, ) -> str: + """处理合并转发组件。 + + Args: + component: 合并转发组件。 + id_content_map: 回复消息解析缓存。 + recursion_depth: 当前递归深度。 + enable_heavy_media_analysis: 是否同步执行图片与表情包描述生成。 + enable_voice_transcription: 是否同步执行语音转写。 + + Returns: + str: 合并转发组件对应的文本表示。 + """ task_list: List[Task] = [] node_user_info_list: List[UserInfo] = [] for node in component.forward_components: task = asyncio.create_task( - self._process_multiple_components(node.content, id_content_map, recursion_depth + 1) + self._process_multiple_components( + node.content, + id_content_map, + recursion_depth + 1, + enable_heavy_media_analysis=enable_heavy_media_analysis, + enable_voice_transcription=enable_voice_transcription, + ) ) node_user_info = UserInfo(node.user_id or "未知用户", node.user_nickname, node.user_cardname) # 传入ID缓存映射,方便Reply组件获取并等待处理结果 @@ -196,9 +326,36 @@ class SessionMessage(MaiMessage): return "【合并转发消息: \n" + "\n".join(forward_texts) + "\n】" async def _process_multiple_components( - self, components: Sequence[StandardMessageComponents], id_content_map: MsgIDMapping, recursion_depth: int = 0 + self, + components: Sequence[StandardMessageComponents], + id_content_map: MsgIDMapping, + recursion_depth: int = 0, + *, + enable_heavy_media_analysis: bool = True, + enable_voice_transcription: bool = True, ) -> str: - tasks = [self.process_single_component(component, id_content_map, recursion_depth) for component in components] + """并行处理多个消息组件。 + + Args: + components: 待处理的组件序列。 + id_content_map: 回复消息解析缓存。 + recursion_depth: 当前递归深度。 + enable_heavy_media_analysis: 是否同步执行图片与表情包描述生成。 + enable_voice_transcription: 是否同步执行语音转写。 + + Returns: + str: 多个组件拼接后的文本表示。 + """ + tasks = [ + self.process_single_component( + component, + id_content_map, + recursion_depth, + enable_heavy_media_analysis=enable_heavy_media_analysis, + enable_voice_transcription=enable_voice_transcription, + ) + for component in components + ] results = await asyncio.gather(*tasks, return_exceptions=True) # 并行处理多个组件 processed_texts: List[str] = [] for result in results: diff --git a/src/chat/planner_actions/action_manager.py b/src/chat/planner_actions/action_manager.py deleted file mode 100644 index 8133ac18..00000000 --- a/src/chat/planner_actions/action_manager.py +++ /dev/null @@ -1,137 +0,0 @@ -from typing import Dict, Optional, Tuple - -from src.chat.message_receive.chat_manager import BotChatSession -from src.chat.message_receive.message import SessionMessage -from src.common.logger import get_logger -from src.core.types import ActionInfo -from src.plugin_runtime.component_query import ActionExecutor, component_query_service - -logger = get_logger("action_manager") - - -class ActionHandle: - """Action 执行句柄 - - 不依赖任何插件基类,内部持有 executor (async callable) 和绑定参数。 - brain_chat 调用 ``await handle.execute()`` 即可。 - """ - - def __init__(self, executor: ActionExecutor, **kwargs): - self._executor = executor - self._kwargs = kwargs - - async def execute(self) -> Tuple[bool, str]: - return await self._executor(**self._kwargs) - - -class ActionManager: - """ - 动作管理器,用于管理各种类型的动作 - - 使用插件运行时统一查询服务的 executor-based 模式。 - """ - - def __init__(self): - """初始化动作管理器""" - - # 当前正在使用的动作集合,默认加载默认动作 - self._using_actions: Dict[str, ActionInfo] = {} - - # 初始化时将默认动作加载到使用中的动作 - self._using_actions = component_query_service.get_default_actions() - - # === 执行Action方法 === - - def create_action( - self, - action_name: str, - action_data: dict, - action_reasoning: str, - cycle_timers: dict, - thinking_id: str, - chat_stream: BotChatSession, - log_prefix: str, - shutting_down: bool = False, - action_message: Optional[SessionMessage] = None, - ) -> Optional[ActionHandle]: - """ - 创建动作执行句柄 - - Args: - action_name: 动作名称 - action_data: 动作数据 - action_reasoning: 执行理由 - cycle_timers: 计时器字典 - thinking_id: 思考ID - chat_stream: 聊天流 - log_prefix: 日志前缀 - shutting_down: 是否正在关闭 - action_message: 动作消息记录 - - Returns: - Optional[ActionHandle]: 执行句柄,如果动作未注册则返回 None - """ - try: - executor = component_query_service.get_action_executor(action_name) - if not executor: - logger.warning(f"{log_prefix} 未找到Action组件: {action_name}") - return None - - info = component_query_service.get_action_info(action_name) - if not info: - logger.warning(f"{log_prefix} 未找到Action组件信息: {action_name}") - return None - - plugin_config = component_query_service.get_plugin_config(info.plugin_name) or {} - - handle = ActionHandle( - executor, - action_data=action_data, - action_reasoning=action_reasoning, - cycle_timers=cycle_timers, - thinking_id=thinking_id, - chat_stream=chat_stream, - log_prefix=log_prefix, - shutting_down=shutting_down, - plugin_config=plugin_config, - action_message=action_message, - ) - - logger.debug(f"创建Action执行句柄成功: {action_name}") - return handle - - except Exception as e: - logger.error(f"创建Action执行句柄失败 {action_name}: {e}") - import traceback - - logger.error(traceback.format_exc()) - return None - - def get_using_actions(self) -> Dict[str, ActionInfo]: - """获取当前正在使用的动作集合""" - return self._using_actions.copy() - - # === Modify相关方法 === - def remove_action_from_using(self, action_name: str) -> bool: - """ - 从当前使用的动作集中移除指定动作 - - Args: - action_name: 动作名称 - - Returns: - bool: 移除是否成功 - """ - if action_name not in self._using_actions: - logger.warning(f"移除失败: 动作 {action_name} 不在当前使用的动作集中") - return False - - del self._using_actions[action_name] - logger.debug(f"已从使用集中移除动作 {action_name}") - return True - - def restore_actions(self) -> None: - """恢复到默认动作集""" - actions_to_restore = list(self._using_actions.keys()) - self._using_actions = component_query_service.get_default_actions() - logger.debug(f"恢复动作集: 从 {actions_to_restore} 恢复到默认动作集 {list(self._using_actions.keys())}") diff --git a/src/chat/planner_actions/action_modifier.py b/src/chat/planner_actions/action_modifier.py deleted file mode 100644 index 0d81c18f..00000000 --- a/src/chat/planner_actions/action_modifier.py +++ /dev/null @@ -1,233 +0,0 @@ -import random -import time -from typing import List, Dict, Tuple - -from src.common.logger import get_logger -from src.config.config import global_config -from src.chat.message_receive.chat_manager import BotChatSession, chat_manager as _chat_manager -from src.chat.planner_actions.action_manager import ActionManager -from src.services.message_service import build_readable_messages, get_messages_before_time_in_chat -from src.core.types import ActionActivationType, ActionInfo -from src.core.announcement_manager import global_announcement_manager - -logger = get_logger("action_manager") - - -class ActionModifier: - """动作处理器 - - 用于处理Observation对象和根据激活类型处理actions。 - 集成了原有的modify_actions功能和新的激活类型处理功能。 - 支持并行判定和智能缓存优化。 - """ - - def __init__(self, action_manager: ActionManager, chat_id: str): - """初始化动作处理器""" - self.chat_id = chat_id - self.chat_stream: BotChatSession = _chat_manager.get_session_by_session_id(self.chat_id) # type: ignore - self.log_prefix = f"[{_chat_manager.get_session_name(self.chat_id) or self.chat_id}]" - - self.action_manager = action_manager - - async def modify_actions( - self, - message_content: str = "", - ): # sourcery skip: use-named-expression - """ - 动作修改流程,整合传统观察处理和新的激活类型判定 - - 这个方法处理完整的动作管理流程: - 1. 基于观察的传统动作修改(循环历史分析、类型匹配等) - 2. 基于激活类型的智能动作判定,最终确定可用动作集 - - 处理后,ActionManager 将包含最终的可用动作集,供规划器直接使用 - """ - logger.debug(f"{self.log_prefix}开始完整动作修改流程") - - removals_s1: List[Tuple[str, str]] = [] - removals_s2: List[Tuple[str, str]] = [] - # removals_s3: List[Tuple[str, str]] = [] - - self.action_manager.restore_actions() - all_actions = self.action_manager.get_using_actions() - - message_list_before_now_half = get_messages_before_time_in_chat( - chat_id=self.chat_stream.session_id, - timestamp=time.time(), - limit=min(int(global_config.chat.max_context_size * 0.33), 10), - filter_intercept_message_level=1, - ) - - chat_content = build_readable_messages( - message_list_before_now_half, - replace_bot_name=True, - timestamp_mode="relative", - read_mark=0.0, - show_actions=True, - ) - - if message_content: - chat_content = chat_content + "\n" + f"现在,最新的消息是:{message_content}" - - # === 第一阶段:去除用户自行禁用的 === - disabled_actions = global_announcement_manager.get_disabled_chat_actions(self.chat_id) - if disabled_actions: - for disabled_action_name in disabled_actions: - if disabled_action_name in all_actions: - removals_s1.append((disabled_action_name, "用户自行禁用")) - self.action_manager.remove_action_from_using(disabled_action_name) - logger.debug(f"{self.log_prefix}阶段一移除动作: {disabled_action_name},原因: 用户自行禁用") - - # === 第二阶段:检查动作的关联类型 === - chat_context = self.chat_stream.context - type_mismatched_actions = self._check_action_associated_types(all_actions, chat_context) - - if type_mismatched_actions: - removals_s2.extend(type_mismatched_actions) - - # 应用第二阶段的移除 - for action_name, reason in removals_s2: - self.action_manager.remove_action_from_using(action_name) - logger.debug(f"{self.log_prefix}阶段二移除动作: {action_name},原因: {reason}") - - # === 第三阶段:激活类型判定 === - # if chat_content is not None: - # logger.debug(f"{self.log_prefix}开始激活类型判定阶段") - - # 获取当前使用的动作集(经过第一阶段处理) - # current_using_actions = self.action_manager.get_using_actions() - - # 获取因激活类型判定而需要移除的动作 - # removals_s3 = await self._get_deactivated_actions_by_type( - # current_using_actions, - # chat_content, - # ) - - # 应用第三阶段的移除 - # for action_name, reason in removals_s3: - # self.action_manager.remove_action_from_using(action_name) - # logger.debug(f"{self.log_prefix}阶段三移除动作: {action_name},原因: {reason}") - - # === 统一日志记录 === - all_removals = removals_s1 + removals_s2 - removals_summary: str = "" - if all_removals: - removals_summary = " | ".join([f"{name}({reason})" for name, reason in all_removals]) - - available_actions = list(self.action_manager.get_using_actions().keys()) - available_actions_text = "、".join(available_actions) if available_actions else "无" - logger.debug(f"{self.log_prefix} 当前可用动作: {available_actions_text}||移除: {removals_summary}") - - def _check_action_associated_types(self, all_actions: Dict[str, ActionInfo], chat_context: BotChatSession): - type_mismatched_actions: List[Tuple[str, str]] = [] - for action_name, action_info in all_actions.items(): - if action_info.associated_types and not chat_context.check_types(action_info.associated_types): - associated_types_str = ", ".join(action_info.associated_types) - reason = f"适配器不支持(需要: {associated_types_str})" - type_mismatched_actions.append((action_name, reason)) - logger.debug(f"{self.log_prefix}决定移除动作: {action_name},原因: {reason}") - return type_mismatched_actions - - async def _get_deactivated_actions_by_type( - self, - actions_with_info: Dict[str, ActionInfo], - chat_content: str = "", - ) -> List[tuple[str, str]]: - """ - 根据激活类型过滤,返回需要停用的动作列表及原因 - - Args: - actions_with_info: 带完整信息的动作字典 - chat_content: 聊天内容 - - Returns: - List[Tuple[str, str]]: 需要停用的 (action_name, reason) 元组列表 - """ - deactivated_actions = [] - - actions_to_check = list(actions_with_info.items()) - random.shuffle(actions_to_check) - - for action_name, action_info in actions_to_check: - activation_type = action_info.activation_type or action_info.focus_activation_type - - if activation_type == ActionActivationType.ALWAYS: - continue # 总是激活,无需处理 - - elif activation_type == ActionActivationType.RANDOM: - probability = action_info.random_activation_probability - if random.random() >= probability: - reason = f"RANDOM类型未触发(概率{probability})" - deactivated_actions.append((action_name, reason)) - logger.debug(f"{self.log_prefix}未激活动作: {action_name},原因: {reason}") - - elif activation_type == ActionActivationType.KEYWORD: - if not self._check_keyword_activation(action_name, action_info, chat_content): - keywords = action_info.activation_keywords - reason = f"关键词未匹配(关键词: {keywords})" - deactivated_actions.append((action_name, reason)) - logger.debug(f"{self.log_prefix}未激活动作: {action_name},原因: {reason}") - - elif activation_type == ActionActivationType.NEVER: - reason = "激活类型为never" - deactivated_actions.append((action_name, reason)) - logger.debug(f"{self.log_prefix}未激活动作: {action_name},原因: 激活类型为never") - - else: - logger.warning(f"{self.log_prefix}未知的激活类型: {activation_type},跳过处理") - - return deactivated_actions - - def _check_keyword_activation( - self, - action_name: str, - action_info: ActionInfo, - chat_content: str = "", - ) -> bool: - """ - 检查是否匹配关键词触发条件 - - Args: - action_name: 动作名称 - action_info: 动作信息 - observed_messages_str: 观察到的聊天消息 - chat_context: 聊天上下文 - extra_context: 额外上下文 - - Returns: - bool: 是否应该激活此action - """ - - activation_keywords = action_info.activation_keywords - case_sensitive = action_info.keyword_case_sensitive - - if not activation_keywords: - logger.warning(f"{self.log_prefix}动作 {action_name} 设置为关键词触发但未配置关键词") - return False - - # 构建检索文本 - search_text = "" - if chat_content: - search_text += chat_content - # if chat_context: - # search_text += f" {chat_context}" - # if extra_context: - # search_text += f" {extra_context}" - - # 如果不区分大小写,转换为小写 - if not case_sensitive: - search_text = search_text.lower() - - # 检查每个关键词 - matched_keywords = [] - for keyword in activation_keywords: - check_keyword = keyword if case_sensitive else keyword.lower() - if check_keyword in search_text: - matched_keywords.append(keyword) - - if matched_keywords: - logger.debug(f"{self.log_prefix}动作 {action_name} 匹配到关键词: {matched_keywords}") - return True - else: - logger.debug(f"{self.log_prefix}动作 {action_name} 未匹配到任何关键词: {activation_keywords}") - return False diff --git a/src/chat/planner_actions/planner.py b/src/chat/planner_actions/planner.py deleted file mode 100644 index b21efa6b..00000000 --- a/src/chat/planner_actions/planner.py +++ /dev/null @@ -1,933 +0,0 @@ -from collections import OrderedDict -from datetime import datetime -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union - -import contextlib -import json -import random -import re -import time -import traceback - -from json_repair import repair_json -from rich.traceback import install - -from src.chat.logger.plan_reply_logger import PlanReplyLogger -from src.chat.message_receive.chat_manager import chat_manager as _chat_manager -from src.chat.message_receive.message import SessionMessage -from src.chat.planner_actions.action_manager import ActionManager -from src.chat.utils.utils import get_chat_type_and_target_info, is_bot_self -from src.common.data_models.info_data_model import ActionPlannerInfo -from src.common.logger import get_logger -from src.config.config import global_config, model_config -from src.core.types import ActionActivationType, ActionInfo, ComponentType -from src.llm_models.utils_model import LLMRequest -from src.person_info.person_info import Person -from src.plugin_runtime.component_query import component_query_service -from src.prompt.prompt_manager import prompt_manager -from src.services.message_service import ( - build_readable_messages_with_id, - get_messages_before_time_in_chat, - replace_user_references, - translate_pid_to_description, -) - -if TYPE_CHECKING: - from src.common.data_models.info_data_model import TargetPersonInfo - -logger = get_logger("planner") - -install(extra_lines=3) - - -class ActionPlanner: - def __init__(self, chat_id: str, action_manager: ActionManager): - self.chat_id = chat_id - self.log_prefix = f"[{_chat_manager.get_session_name(chat_id) or chat_id}]" - self.action_manager = action_manager - # LLM规划器配置 - self.planner_llm = LLMRequest( - model_set=model_config.model_task_config.planner, request_type="planner" - ) # 用于动作规划 - - self.last_obs_time_mark = 0.0 - - self.plan_log: List[Tuple[str, float, Union[List[ActionPlannerInfo], str]]] = [] - - # 黑话缓存:使用 OrderedDict 实现 LRU,最多缓存10个 - self.unknown_words_cache: OrderedDict[str, None] = OrderedDict() - self.unknown_words_cache_limit = 10 - - def find_message_by_id( - self, message_id: str, message_id_list: List[Tuple[str, "SessionMessage"]] - ) -> Optional["SessionMessage"]: - # sourcery skip: use-next - """ - 根据message_id从message_id_list中查找对应的原始消息 - - Args: - message_id: 要查找的消息ID - message_id_list: 消息ID列表,格式为[{'id': str, 'message': dict}, ...] - - Returns: - 找到的原始消息字典,如果未找到则返回None - """ - for item in message_id_list: - if item[0] == message_id: - return item[1] - return None - - def _replace_message_ids_with_text( - self, text: Optional[str], message_id_list: List[Tuple[str, "SessionMessage"]] - ) -> Optional[str]: - """将文本中的 m+数字 消息ID替换为原消息内容,并添加双引号""" - if not text: - return text - - id_to_message = dict(message_id_list) - - # 匹配m后带2-4位数字,前后不是字母数字下划线 - pattern = r"(? str: - msg_id = match.group(0) - message = id_to_message.get(msg_id) - if not message: - logger.warning(f"{self.log_prefix}planner理由引用 {msg_id} 未找到对应消息,保持原样") - return msg_id - - msg_text = (message.processed_plain_text or "").strip() - if not msg_text: - logger.warning(f"{self.log_prefix}planner理由引用 {msg_id} 的消息内容为空,保持原样") - return msg_id - - # 替换 [picid:xxx] 为 [图片:描述] - pic_pattern = r"\[picid:([^\]]+)\]" - - def replace_pic_id(pic_match: re.Match) -> str: - pic_id = pic_match.group(1) - description = translate_pid_to_description(pic_id) - return f"[图片:{description}]" - - msg_text = re.sub(pic_pattern, replace_pic_id, msg_text) - - # 替换用户引用格式:回复 和 @ - platform = message.platform or "" - if not platform: - logger.warning( - f"{self.log_prefix}planner: message {message.message_id} has no platform set, bot-self detection will be skipped" - ) - msg_text = replace_user_references(msg_text, platform, replace_bot_name=True) - - # 替换单独的 <用户名:用户ID> 格式(replace_user_references 已处理回复<和@<格式) - # 匹配所有 格式,由于 replace_user_references 已经替换了回复<和@<格式, - # 这里匹配到的应该都是单独的格式 - user_ref_pattern = r"<([^:<>]+):([^:<>]+)>" - - def replace_user_ref(user_match: re.Match) -> str: - user_name = user_match.group(1) - user_id = user_match.group(2) - try: - # 检查是否是机器人自己 - if is_bot_self(platform, str(user_id)): - return f"{global_config.bot.nickname}(你)" - person = Person(platform=platform, user_id=user_id) - return person.person_name or user_name - except Exception: - # 如果解析失败,使用原始昵称 - return user_name - - msg_text = re.sub(user_ref_pattern, replace_user_ref, msg_text) - - preview = msg_text if len(msg_text) <= 100 else f"{msg_text[:97]}..." - logger.info(f"{self.log_prefix}planner理由引用 {msg_id} -> 消息({preview})") - return f"消息({msg_text})" - - return re.sub(pattern, _replace, text) - - def _parse_single_action( - self, - action_json: dict, - message_id_list: List[Tuple[str, "SessionMessage"]], - current_available_actions: List[Tuple[str, ActionInfo]], - extracted_reasoning: str = "", - ) -> List[ActionPlannerInfo]: - """解析单个action JSON并返回ActionPlannerInfo列表""" - action_planner_infos = [] - - try: - action = action_json.get("action", "no_reply") - # 使用 extracted_reasoning(整体推理文本)作为 reasoning - if extracted_reasoning: - reasoning = self._replace_message_ids_with_text(extracted_reasoning, message_id_list) - if reasoning is None: - reasoning = extracted_reasoning - else: - reasoning = "未提供原因" - action_data = {key: value for key, value in action_json.items() if key not in ["action"]} - - # 非no_reply动作需要target_message_id - target_message = None - - target_message_id = action_json.get("target_message_id") - if target_message_id: - # 根据target_message_id查找原始消息 - target_message = self.find_message_by_id(target_message_id, message_id_list) - if target_message is None: - logger.warning(f"{self.log_prefix}无法找到target_message_id '{target_message_id}' 对应的消息") - # 选择最新消息作为target_message - target_message = message_id_list[-1][1] - else: - target_message = message_id_list[-1][1] - logger.debug(f"{self.log_prefix}动作'{action}'缺少target_message_id,使用最新消息作为target_message") - - if action != "no_reply" and target_message is not None and self._is_message_from_self(target_message): - logger.info( - f"{self.log_prefix}Planner选择了自己的消息 {target_message_id or target_message.message_id} 作为目标,强制使用 no_reply" - ) - reasoning = f"目标消息 {target_message_id or target_message.message_id} 来自机器人自身,违反不回复自身消息规则。原始理由: {reasoning}" - action = "no_reply" - target_message = None - - # 验证action是否可用 - available_action_names = [action_name for action_name, _ in current_available_actions] - internal_action_names = ["no_reply", "reply", "wait_time"] - - if action not in internal_action_names and action not in available_action_names: - logger.warning( - f"{self.log_prefix}LLM 返回了当前不可用或无效的动作: '{action}' (可用: {available_action_names}),将强制使用 'no_reply'" - ) - reasoning = ( - f"LLM 返回了当前不可用的动作 '{action}' (可用: {available_action_names})。原始理由: {reasoning}" - ) - action = "no_reply" - - # 创建ActionPlannerInfo对象 - # 将列表转换为字典格式 - available_actions_dict = dict(current_available_actions) - action_planner_infos.append( - ActionPlannerInfo( - action_type=action, - reasoning=reasoning, - action_data=action_data, - action_message=target_message, - available_actions=available_actions_dict, - action_reasoning=extracted_reasoning or None, - ) - ) - - except Exception as e: - logger.error(f"{self.log_prefix}解析单个action时出错: {e}") - # 将列表转换为字典格式 - available_actions_dict = dict(current_available_actions) - action_planner_infos.append( - ActionPlannerInfo( - action_type="no_reply", - reasoning=f"解析单个action时出错: {e}", - action_data={}, - action_message=None, - available_actions=available_actions_dict, - action_reasoning=extracted_reasoning or None, - ) - ) - - return action_planner_infos - - def _is_message_from_self(self, message: "SessionMessage") -> bool: - """判断消息是否由机器人自身发送(支持多平台,包括 WebUI)""" - try: - return is_bot_self(message.platform or "", str(message.message_info.user_info.user_id)) - except AttributeError: - logger.warning(f"{self.log_prefix}检测消息发送者失败,缺少必要字段") - return False - - def _update_unknown_words_cache(self, new_words: List[str]) -> None: - """ - 更新黑话缓存,将新的黑话加入缓存 - - Args: - new_words: 新提取的黑话列表 - """ - for word in new_words: - if not isinstance(word, str): - continue - word = word.strip() - if not word: - continue - - # 如果已存在,移到末尾(LRU) - if word in self.unknown_words_cache: - self.unknown_words_cache.move_to_end(word) - else: - # 添加新词 - self.unknown_words_cache[word] = None - # 如果超过限制,移除最老的 - if len(self.unknown_words_cache) > self.unknown_words_cache_limit: - self.unknown_words_cache.popitem(last=False) - logger.debug(f"{self.log_prefix}黑话缓存已满,移除最老的黑话") - - def _merge_unknown_words_with_cache(self, new_words: Optional[List[str]]) -> List[str]: - """ - 合并新提取的黑话和缓存中的黑话 - - Args: - new_words: 新提取的黑话列表(可能为None) - - Returns: - 合并后的黑话列表(去重) - """ - # 清理新提取的黑话 - cleaned_new_words: List[str] = [] - if new_words: - for word in new_words: - if isinstance(word, str): - if word := word.strip(): - cleaned_new_words.append(word) - - # 获取缓存中的黑话列表 - cached_words = list(self.unknown_words_cache.keys()) - - # 合并并去重(保留顺序:新提取的在前,缓存的在后) - merged_words: List[str] = [] - seen = set() - - # 先添加新提取的 - for word in cleaned_new_words: - if word not in seen: - merged_words.append(word) - seen.add(word) - - # 再添加缓存的(如果不在新提取的列表中) - for word in cached_words: - if word not in seen: - merged_words.append(word) - seen.add(word) - - return merged_words - - def _process_unknown_words_cache(self, actions: List[ActionPlannerInfo]) -> None: - """ - 处理黑话缓存逻辑: - 1. 检查是否有 reply action 提取了 unknown_words - 2. 如果没有提取,移除最老的1个 - 3. 如果缓存数量大于5,移除最老的2个 - 4. 对于每个 reply action,合并缓存和新提取的黑话 - 5. 更新缓存 - - Args: - actions: 解析后的动作列表 - """ - # 先检查缓存数量,如果大于5,移除最老的2个 - if len(self.unknown_words_cache) > 5: - # 移除最老的2个 - removed_count = 0 - for _ in range(2): - if len(self.unknown_words_cache) > 0: - self.unknown_words_cache.popitem(last=False) - removed_count += 1 - if removed_count > 0: - logger.debug(f"{self.log_prefix}缓存数量大于5,移除最老的{removed_count}个缓存") - - # 检查是否有 reply action 提取了 unknown_words - has_extracted_unknown_words = False - for action in actions: - if action.action_type == "reply": - action_data = action.action_data or {} - unknown_words = action_data.get("unknown_words") - if unknown_words and isinstance(unknown_words, list) and len(unknown_words) > 0: - has_extracted_unknown_words = True - break - - # 如果当前 plan 的 reply 没有提取,移除最老的1个 - if not has_extracted_unknown_words and len(self.unknown_words_cache) > 0: - self.unknown_words_cache.popitem(last=False) - logger.debug(f"{self.log_prefix}当前 plan 的 reply 没有提取黑话,移除最老的1个缓存") - - # 对于每个 reply action,合并缓存和新提取的黑话 - for action in actions: - if action.action_type == "reply": - action_data = action.action_data or {} - new_words = action_data.get("unknown_words") - - # 合并新提取的和缓存的黑话列表 - if merged_words := self._merge_unknown_words_with_cache(new_words): - action_data["unknown_words"] = merged_words - logger.debug( - f"{self.log_prefix}合并黑话:新提取 {len(new_words) if new_words else 0} 个," - f"缓存 {len(self.unknown_words_cache)} 个,合并后 {len(merged_words)} 个" - ) - else: - # 如果没有合并后的黑话,移除 unknown_words 字段 - action_data.pop("unknown_words", None) - - # 更新缓存(将新提取的黑话加入缓存) - if new_words: - self._update_unknown_words_cache(new_words) - - async def plan( - self, - available_actions: Dict[str, ActionInfo], - loop_start_time: float = 0.0, - force_reply_message: Optional["SessionMessage"] = None, - ) -> List[ActionPlannerInfo]: - # sourcery skip: use-named-expression - """ - 规划器 (Planner): 使用LLM根据上下文决定做出什么动作。 - """ - plan_start = time.perf_counter() - - # 获取聊天上下文 - message_list_before_now = get_messages_before_time_in_chat( - chat_id=self.chat_id, - timestamp=time.time(), - limit=int(global_config.chat.max_context_size * 0.6), - filter_intercept_message_level=1, - ) - message_id_list: list[Tuple[str, "SessionMessage"]] = [] - chat_content_block, message_id_list = build_readable_messages_with_id( - messages=message_list_before_now, - timestamp_mode="normal_no_YMD", - read_mark=self.last_obs_time_mark, - truncate=True, - show_actions=True, - ) - - message_list_before_now_short = message_list_before_now[-int(global_config.chat.max_context_size * 0.3) :] - chat_content_block_short, message_id_list_short = build_readable_messages_with_id( - messages=message_list_before_now_short, - timestamp_mode="normal_no_YMD", - truncate=False, - show_actions=False, - ) - - self.last_obs_time_mark = time.time() - - # 获取必要信息 - is_group_chat, chat_target_info, current_available_actions = self.get_necessary_info() - - # 应用激活类型过滤 - filtered_actions = self._filter_actions_by_activation_type(available_actions, chat_content_block_short) - - logger.debug(f"{self.log_prefix}过滤后有{len(filtered_actions)}个可用动作") - - prompt_build_start = time.perf_counter() - # 构建包含所有动作的提示词 - prompt, message_id_list = await self.build_planner_prompt( - is_group_chat=is_group_chat, - chat_target_info=chat_target_info, - current_available_actions=filtered_actions, - chat_content_block=chat_content_block, - message_id_list=message_id_list, - ) - prompt_build_ms = (time.perf_counter() - prompt_build_start) * 1000 - - # 调用LLM获取决策 - reasoning, actions, llm_raw_output, llm_reasoning, llm_duration_ms = await self._execute_main_planner( - prompt=prompt, - message_id_list=message_id_list, - filtered_actions=filtered_actions, - available_actions=available_actions, - loop_start_time=loop_start_time, - ) - - # 如果有强制回复消息,确保回复该消息 - if force_reply_message: - # 检查是否已经有回复该消息的 action - has_reply_to_force_message = any( - action.action_type == "reply" - and action.action_message - and action.action_message.message_id == force_reply_message.message_id - for action in actions - ) - - # 如果没有回复该消息,强制添加回复 action - if not has_reply_to_force_message: - # 移除所有 no_reply action(如果有) - actions = [a for a in actions if a.action_type != "no_reply"] - - # 创建强制回复 action - available_actions_dict = dict(current_available_actions) - force_reply_action = ActionPlannerInfo( - action_type="reply", - reasoning="用户提及了我,必须回复该消息", - action_data={"loop_start_time": loop_start_time}, - action_message=force_reply_message, - available_actions=available_actions_dict, - action_reasoning=None, - ) - # 将强制回复 action 放在最前面 - actions.insert(0, force_reply_action) - logger.info(f"{self.log_prefix} 检测到强制回复消息,已添加回复动作") - - logger.info( - f"{self.log_prefix}Planner:{reasoning}。选择了{len(actions)}个动作: {' '.join([a.action_type for a in actions])}" - ) - - self.add_plan_log(reasoning, actions) - - try: - PlanReplyLogger.log_plan( - chat_id=self.chat_id, - prompt=prompt, - reasoning=reasoning, - raw_output=llm_raw_output, - raw_reasoning=llm_reasoning, - actions=actions, - timing={ - "prompt_build_ms": round(prompt_build_ms, 2), - "llm_duration_ms": round(llm_duration_ms, 2) if llm_duration_ms is not None else None, - "total_plan_ms": round((time.perf_counter() - plan_start) * 1000, 2), - "loop_start_time": loop_start_time, - }, - extra=None, - ) - except Exception: - logger.exception(f"{self.log_prefix}记录plan日志失败") - - return actions - - def add_plan_log(self, reasoning: str, actions: List[ActionPlannerInfo]): - self.plan_log.append((reasoning, time.time(), actions)) - if len(self.plan_log) > 20: - self.plan_log.pop(0) - - def add_plan_excute_log(self, result: str): - self.plan_log.append(("", time.time(), result)) - if len(self.plan_log) > 20: - self.plan_log.pop(0) - - def get_plan_log_str(self, max_action_records: int = 2, max_execution_records: int = 5) -> str: - """ - 获取计划日志字符串 - - Args: - max_action_records: 显示多少条最新的action记录,默认2 - max_execution_records: 显示多少条最新执行结果记录,默认8 - - Returns: - 格式化的日志字符串 - """ - action_records = [] - execution_records = [] - - # 从后往前遍历,收集最新的记录 - for reasoning, timestamp, content in reversed(self.plan_log): - if isinstance(content, list) and all(isinstance(action, ActionPlannerInfo) for action in content): - if len(action_records) < max_action_records: - action_records.append((reasoning, timestamp, content, "action")) - elif len(execution_records) < max_execution_records: - execution_records.append((reasoning, timestamp, content, "execution")) - - # 合并所有记录并按时间戳排序 - all_records = action_records + execution_records - all_records.sort(key=lambda x: x[1]) # 按时间戳排序 - - plan_log_str = "" - - # 按时间顺序添加所有记录 - for reasoning, timestamp, content, record_type in all_records: - time_str = datetime.fromtimestamp(timestamp).strftime("%H:%M:%S") - if record_type == "action": - # plan_log_str += f"{time_str}:{reasoning}|你使用了{','.join([action.action_type for action in content])}\n" - plan_log_str += f"{time_str}:{reasoning}\n" - else: - plan_log_str += f"{time_str}:你执行了action:{content}\n" - - return plan_log_str - - async def build_planner_prompt( - self, - is_group_chat: bool, - chat_target_info: Optional["TargetPersonInfo"], - current_available_actions: Dict[str, ActionInfo], - message_id_list: List[Tuple[str, "SessionMessage"]], - chat_content_block: str = "", - interest: str = "", - ) -> tuple[str, List[Tuple[str, "SessionMessage"]]]: - """构建 Planner LLM 的提示词 (获取模板并填充数据)""" - try: - actions_before_now_block = self.get_plan_log_str() - - # 构建聊天上下文描述 - chat_context_description = "你现在正在一个群聊中" - - # 构建动作选项块 - action_options_block = await self._build_action_options_block(current_available_actions) - - # 其他信息 - moderation_prompt_block = "请不要输出违法违规内容,不要输出色情,暴力,政治相关内容,如有敏感内容,请规避。" - time_block = f"当前时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}" - bot_name = global_config.bot.nickname - bot_nickname = ( - f",也有人叫你{','.join(global_config.bot.alias_names)}" if global_config.bot.alias_names else "" - ) - name_block = f"你的名字是{bot_name}{bot_nickname},请注意哪些是你自己的发言。" - - # 根据 think_mode 配置决定 reply action 的示例 JSON - # 在 JSON 中直接作为 action 参数携带 unknown_words - if global_config.chat.think_mode == "classic": - reply_action_example = "" - if global_config.chat.llm_quote: - reply_action_example += ( - "5.如果要明确回复消息,使用quote,如果消息不多不需要明确回复,设置quote为false\n" - ) - reply_action_example += ( - '{{"action":"reply", "target_message_id":"消息id(m+数字)", "unknown_words":["词语1","词语2"]' - ) - if global_config.chat.llm_quote: - reply_action_example += ', "quote":"如果需要引用该message,设置为true"' - reply_action_example += "}" - else: - reply_action_example = ( - "5.think_level表示思考深度,0表示该回复不需要思考和回忆,1表示该回复需要进行回忆和思考\n" - ) - if global_config.chat.llm_quote: - reply_action_example += ( - "6.如果要明确回复消息,使用quote,如果消息不多不需要明确回复,设置quote为false\n" - ) - reply_action_example += ( - '{{"action":"reply", "think_level":数值等级(0或1), ' - '"target_message_id":"消息id(m+数字)", ' - '"unknown_words":["词语1","词语2"]' - ) - if global_config.chat.llm_quote: - reply_action_example += ', "quote":"如果需要引用该message,设置为true"' - reply_action_example += "}" - - planner_prompt_template = prompt_manager.get_prompt("planner") - planner_prompt_template.add_context("time_block", time_block) - planner_prompt_template.add_context("chat_context_description", chat_context_description) - planner_prompt_template.add_context("chat_content_block", chat_content_block) - planner_prompt_template.add_context("actions_before_now_block", actions_before_now_block) - planner_prompt_template.add_context("action_options_text", action_options_block) - planner_prompt_template.add_context("moderation_prompt", moderation_prompt_block) - planner_prompt_template.add_context("name_block", name_block) - planner_prompt_template.add_context("interest", interest) - planner_prompt_template.add_context("plan_style", global_config.personality.plan_style) - planner_prompt_template.add_context("reply_action_example", reply_action_example) - prompt = await prompt_manager.render_prompt(planner_prompt_template) - - return prompt, message_id_list - except Exception as e: - logger.error(f"构建 Planner 提示词时出错: {e}") - logger.error(traceback.format_exc()) - return "构建 Planner Prompt 时出错", [] - - def get_necessary_info(self) -> Tuple[bool, Optional["TargetPersonInfo"], Dict[str, ActionInfo]]: - """ - 获取 Planner 需要的必要信息 - """ - is_group_chat = True - is_group_chat, chat_target_info = get_chat_type_and_target_info(self.chat_id) - logger.debug(f"{self.log_prefix}获取到聊天信息 - 群聊: {is_group_chat}, 目标信息: {chat_target_info}") - - current_available_actions_dict = self.action_manager.get_using_actions() - - # 获取完整的动作信息 - all_registered_actions: Dict[str, ActionInfo] = component_query_service.get_components_by_type( # type: ignore - ComponentType.ACTION - ) - current_available_actions = {} - for action_name in current_available_actions_dict: - if action_name in all_registered_actions: - current_available_actions[action_name] = all_registered_actions[action_name] - else: - logger.warning(f"{self.log_prefix}使用中的动作 {action_name} 未在已注册动作中找到") - - return is_group_chat, chat_target_info, current_available_actions - - def _filter_actions_by_activation_type( - self, available_actions: Dict[str, ActionInfo], chat_content_block: str - ) -> Dict[str, ActionInfo]: - """根据激活类型过滤动作""" - filtered_actions = {} - - for action_name, action_info in available_actions.items(): - if action_info.activation_type == ActionActivationType.NEVER: - logger.debug(f"{self.log_prefix}动作 {action_name} 设置为 NEVER 激活类型,跳过") - continue - elif action_info.activation_type == ActionActivationType.ALWAYS: - filtered_actions[action_name] = action_info - elif action_info.activation_type == ActionActivationType.RANDOM: - if random.random() < action_info.random_activation_probability: - filtered_actions[action_name] = action_info - elif action_info.activation_type == ActionActivationType.KEYWORD: - if action_info.activation_keywords: - for keyword in action_info.activation_keywords: - if keyword in chat_content_block: - filtered_actions[action_name] = action_info - break - else: - logger.warning(f"{self.log_prefix}未知的激活类型: {action_info.activation_type},跳过处理") - - return filtered_actions - - async def _build_action_options_block(self, current_available_actions: Dict[str, ActionInfo]) -> str: - """构建动作选项块""" - if not current_available_actions: - return "" - - action_options_block = "" - for action_name, action_info in current_available_actions.items(): - # 构建参数文本 - param_text = "" - if action_info.action_parameters: - param_text = "\n" - for param_name, param_description in action_info.action_parameters.items(): - param_text += f' "{param_name}":"{param_description}"\n' - param_text = param_text.rstrip("\n") - - # 构建要求文本 - require_text = "\n".join(f"- {require_item}" for require_item in action_info.action_require) - - parallel_text = "" if action_info.parallel_action else "(当选择这个动作时,请不要选择其他动作)" - - # 获取动作提示模板并填充 - using_action_prompt = prompt_manager.get_prompt("action") - using_action_prompt.add_context("action_name", action_name) - using_action_prompt.add_context("action_description", action_info.description) - using_action_prompt.add_context("action_parameters", param_text) - using_action_prompt.add_context("action_require", require_text) - using_action_prompt.add_context("parallel_text", parallel_text) - using_action_rendered_prompt = await prompt_manager.render_prompt(using_action_prompt) - - action_options_block += using_action_rendered_prompt - - return action_options_block - - async def _execute_main_planner( - self, - prompt: str, - message_id_list: List[Tuple[str, "SessionMessage"]], - filtered_actions: Dict[str, ActionInfo], - available_actions: Dict[str, ActionInfo], - loop_start_time: float, - ) -> Tuple[str, List[ActionPlannerInfo], Optional[str], Optional[str], Optional[float]]: - """执行主规划器""" - llm_content = None - actions: List[ActionPlannerInfo] = [] - llm_reasoning = None - llm_duration_ms = None - - try: - # 调用LLM - llm_start = time.perf_counter() - llm_content, (reasoning_content, _, _) = await self.planner_llm.generate_response_async(prompt=prompt) - llm_duration_ms = (time.perf_counter() - llm_start) * 1000 - llm_reasoning = reasoning_content - - if global_config.debug.show_planner_prompt: - logger.info(f"{self.log_prefix}规划器原始提示词: {prompt}") - logger.info(f"{self.log_prefix}规划器原始响应: {llm_content}") - if reasoning_content: - logger.info(f"{self.log_prefix}规划器推理: {reasoning_content}") - else: - logger.debug(f"{self.log_prefix}规划器原始提示词: {prompt}") - logger.debug(f"{self.log_prefix}规划器原始响应: {llm_content}") - if reasoning_content: - logger.debug(f"{self.log_prefix}规划器推理: {reasoning_content}") - - except Exception as req_e: - logger.error(f"{self.log_prefix}LLM 请求执行失败: {req_e}") - return ( - f"LLM 请求失败,模型出现问题: {req_e}", - [ - ActionPlannerInfo( - action_type="no_reply", - reasoning=f"LLM 请求失败,模型出现问题: {req_e}", - action_data={}, - action_message=None, - available_actions=available_actions, - ) - ], - llm_content, - llm_reasoning, - llm_duration_ms, - ) - - # 解析LLM响应 - extracted_reasoning = "" - if llm_content: - try: - json_objects, extracted_reasoning = self._extract_json_from_markdown(llm_content) - extracted_reasoning = self._replace_message_ids_with_text(extracted_reasoning, message_id_list) or "" - if json_objects: - logger.debug(f"{self.log_prefix}从响应中提取到{len(json_objects)}个JSON对象") - filtered_actions_list = list(filtered_actions.items()) - for json_obj in json_objects: - actions.extend( - self._parse_single_action( - json_obj, message_id_list, filtered_actions_list, extracted_reasoning - ) - ) - else: - # 尝试解析为直接的JSON - logger.warning(f"{self.log_prefix}LLM没有返回可用动作: {llm_content}") - extracted_reasoning = "LLM没有返回可用动作" - actions = self._create_no_reply("LLM没有返回可用动作", available_actions) - - except Exception as json_e: - logger.warning(f"{self.log_prefix}解析LLM响应JSON失败 {json_e}. LLM原始输出: '{llm_content}'") - extracted_reasoning = f"解析LLM响应JSON失败: {json_e}" - actions = self._create_no_reply(f"解析LLM响应JSON失败: {json_e}", available_actions) - traceback.print_exc() - else: - extracted_reasoning = "规划器没有获得LLM响应" - actions = self._create_no_reply("规划器没有获得LLM响应", available_actions) - - # 添加循环开始时间到所有非no_reply动作 - for action in actions: - action.action_data = action.action_data or {} - action.action_data["loop_start_time"] = loop_start_time - - # 去重:如果同一个动作被选择了多次,随机选择其中一个 - if actions: - shuffled = actions.copy() - random.shuffle(shuffled) - actions = list({a.action_type: a for a in shuffled}.values()) - - # 处理黑话缓存逻辑 - self._process_unknown_words_cache(actions) - - logger.debug(f"{self.log_prefix}规划器选择了{len(actions)}个动作: {' '.join([a.action_type for a in actions])}") - - return extracted_reasoning, actions, llm_content, llm_reasoning, llm_duration_ms - - def _create_no_reply(self, reasoning: str, available_actions: Dict[str, ActionInfo]) -> List[ActionPlannerInfo]: - """创建no_reply""" - return [ - ActionPlannerInfo( - action_type="no_reply", - reasoning=reasoning, - action_data={}, - action_message=None, - available_actions=available_actions, - ) - ] - - def _extract_json_from_markdown(self, content: str) -> Tuple[List[dict], str]: - # sourcery skip: for-append-to-extend - """从Markdown格式的内容中提取JSON对象和推理内容""" - json_objects = [] - reasoning_content = "" - - # 使用正则表达式查找```json包裹的JSON内容 - json_pattern = r"```json\s*(.*?)\s*```" - markdown_matches = re.findall(json_pattern, content, re.DOTALL) - - # 提取JSON之前的内容作为推理文本 - first_json_pos = len(content) - if markdown_matches: - # 找到第一个```json的位置 - first_json_pos = content.find("```json") - if first_json_pos > 0: - reasoning_content = content[:first_json_pos].strip() - # 清理推理内容中的注释标记 - reasoning_content = re.sub(r"^//\s*", "", reasoning_content, flags=re.MULTILINE) - reasoning_content = reasoning_content.strip() - - # 处理```json包裹的JSON - for match in markdown_matches: - try: - # 清理可能的注释和格式问题 - json_str = re.sub(r"//.*?\n", "\n", match) # 移除单行注释 - json_str = re.sub(r"/\*.*?\*/", "", json_str, flags=re.DOTALL) # 移除多行注释 - if json_str := json_str.strip(): - # 尝试按行分割,每行可能是一个JSON对象 - lines = [line.strip() for line in json_str.split("\n") if line.strip()] - for line in lines: - with contextlib.suppress(json.JSONDecodeError): - json_obj = json.loads(repair_json(line)) - if isinstance(json_obj, dict): - if json_obj: - json_objects.append(json_obj) - elif isinstance(json_obj, list): - for item in json_obj: - if isinstance(item, dict) and item: - json_objects.append(item) - - # 如果按行解析没有成功(或只得到空字典),尝试将整个块作为一个JSON对象或数组 - if not json_objects: - json_obj = json.loads(repair_json(json_str)) - if isinstance(json_obj, dict): - # 过滤掉空字典 - if json_obj: - json_objects.append(json_obj) - elif isinstance(json_obj, list): - for item in json_obj: - if isinstance(item, dict) and item: - json_objects.append(item) - except Exception as e: - logger.warning(f"解析JSON块失败: {e}, 块内容: {match[:100]}...") - continue - - # 如果没有找到完整的```json```块,尝试查找不完整的代码块(缺少结尾```) - if not json_objects: - json_start_pos = content.find("```json") - if json_start_pos != -1: - # 找到```json之后的内容 - json_content_start = json_start_pos + 7 # ```json的长度 - # 提取从```json之后到内容结尾的所有内容 - incomplete_json_str = content[json_content_start:].strip() - - # 提取JSON之前的内容作为推理文本 - if json_start_pos > 0: - reasoning_content = content[:json_start_pos].strip() - reasoning_content = re.sub(r"^//\s*", "", reasoning_content, flags=re.MULTILINE) - reasoning_content = reasoning_content.strip() - - if incomplete_json_str: - try: - # 清理可能的注释和格式问题 - json_str = re.sub(r"//.*?\n", "\n", incomplete_json_str) - json_str = re.sub(r"/\*.*?\*/", "", json_str, flags=re.DOTALL) - json_str = json_str.strip() - - if json_str: - # 尝试按行分割,每行可能是一个JSON对象 - lines = [line.strip() for line in json_str.split("\n") if line.strip()] - for line in lines: - try: - json_obj = json.loads(repair_json(line)) - if isinstance(json_obj, dict): - # 过滤掉空字典,避免单个 { 字符被错误修复为 {} 的情况 - if json_obj: - json_objects.append(json_obj) - elif isinstance(json_obj, list): - for item in json_obj: - if isinstance(item, dict) and item: - json_objects.append(item) - except json.JSONDecodeError: - pass - - # 如果按行解析没有成功(或只得到空字典),尝试将整个块作为一个JSON对象或数组 - if not json_objects: - try: - json_obj = json.loads(repair_json(json_str)) - if isinstance(json_obj, dict): - # 过滤掉空字典 - if json_obj: - json_objects.append(json_obj) - elif isinstance(json_obj, list): - for item in json_obj: - if isinstance(item, dict) and item: - json_objects.append(item) - except Exception as e: - logger.debug(f"尝试解析不完整的JSON代码块失败: {e}") - except Exception as e: - logger.debug(f"处理不完整的JSON代码块时出错: {e}") - - return json_objects, reasoning_content diff --git a/src/chat/replyer/group_generator.py b/src/chat/replyer/group_generator.py index 4ffa14a7..10630ecc 100644 --- a/src/chat/replyer/group_generator.py +++ b/src/chat/replyer/group_generator.py @@ -10,8 +10,8 @@ from datetime import datetime from src.common.logger import get_logger from src.common.data_models.info_data_model import ActionPlannerInfo from src.common.data_models.llm_data_model import LLMGenerationDataModel -from src.config.config import global_config, model_config -from src.llm_models.utils_model import LLMRequest +from src.config.config import global_config +from src.services.llm_service import LLMServiceClient from maim_message import BaseMessageInfo, MessageBase, Seg, UserInfo as MaimUserInfo from src.common.data_models.mai_message_data_model import MaiMessage @@ -56,14 +56,12 @@ class DefaultReplyer: chat_stream: 当前绑定的聊天会话。 request_type: LLM 请求类型标识。 """ - self.express_model = LLMRequest(model_set=model_config.model_task_config.replyer, request_type=request_type) + self.express_model = LLMServiceClient( + task_name="replyer", request_type=request_type + ) self.chat_stream = chat_stream self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_stream.session_id) - from src.chat.tool_executor import ToolExecutor - - self.tool_executor = ToolExecutor(chat_id=self.chat_stream.session_id, enable_cache=True, cache_ttl=3) - async def generate_reply_with_context( self, extra_info: str = "", @@ -397,6 +395,11 @@ class DefaultReplyer: return f"{expression_habits_title}\n{expression_habits_block}", selected_ids async def build_tool_info(self, chat_history: str, sender: str, target: str, enable_tool: bool = True) -> str: + del chat_history + del sender + del target + del enable_tool + return "" """构建工具信息块 Args: @@ -413,9 +416,7 @@ class DefaultReplyer: try: # 使用工具执行器获取信息 - tool_results, _, _ = await self.tool_executor.execute_from_chat_message( - sender=sender, target_message=target, chat_history=chat_history, return_details=False - ) + tool_results = [] if tool_results: tool_info_str = "以下是你通过工具获取到的实时信息:\n" @@ -576,29 +577,6 @@ class DefaultReplyer: duration = end_time - start_time return name, result, duration - async def _build_disabled_jargon_explanation(self) -> str: - """当关闭黑话解释时使用的占位协程,避免额外的LLM调用""" - return "" - - async def _build_unknown_words_jargon(self, unknown_words: Optional[List[str]], chat_id: str) -> str: - """针对 Planner 提供的未知词语列表执行黑话检索""" - if not unknown_words: - return "" - # 清洗未知词语列表,只保留非空字符串 - concepts: List[str] = [] - for item in unknown_words: - if isinstance(item, str): - s = item.strip() - if s: - concepts.append(s) - if not concepts: - return "" - try: - return await retrieve_concepts_with_jargon(concepts, chat_id) - except Exception as e: - logger.error(f"未知词语黑话检索失败: {e}") - return "" - async def _build_jargon_explanation( self, chat_id: str, @@ -608,19 +586,14 @@ class DefaultReplyer: ) -> str: """ 统一的黑话解释构建函数: - - 根据 enable_jargon_explanation / jargon_mode 决定具体策略 + - 根据 enable_jargon_explanation 决定是否启用 """ + del unknown_words enable_jargon_explanation = getattr(global_config.expression, "enable_jargon_explanation", True) if not enable_jargon_explanation: return "" - jargon_mode = getattr(global_config.expression, "jargon_mode", "context") - - # planner 模式:仅使用 Planner 的 unknown_words - if jargon_mode == "planner": - return await self._build_unknown_words_jargon(unknown_words, chat_id) - - # 默认 / context 模式:使用上下文自动匹配黑话 + # 使用上下文自动匹配黑话 try: return await explain_jargon_in_context(chat_id, messages_short, chat_talking_prompt_short) or "" except Exception as e: @@ -1158,9 +1131,11 @@ class DefaultReplyer: # else: # logger.debug(f"\nreplyer_Prompt:{prompt}\n") - content, (reasoning_content, model_name, tool_calls) = await self.express_model.generate_response_async( - prompt - ) + generation_result = await self.express_model.generate_response(prompt) + content = generation_result.response + reasoning_content = generation_result.reasoning + model_name = generation_result.model_name + tool_calls = generation_result.tool_calls # 移除 content 前后的换行符和空格 content = content.strip() @@ -1169,6 +1144,10 @@ class DefaultReplyer: return content, reasoning_content, model_name, tool_calls async def get_prompt_info(self, message: str, sender: str, target: str): + del message + del sender + del target + return "" related_info = "" start_time = time.time() try: @@ -1200,17 +1179,21 @@ class DefaultReplyer: template_prompt.add_context("sender", sender) template_prompt.add_context("target_message", target) prompt = await prompt_manager.render_prompt(template_prompt) - _, _, _, _, tool_calls = await llm_api.generate_with_model_with_tools( - prompt, - model_config=model_config.model_task_config.tool_use, - tool_options=[search_knowledge_tool.get_tool_definition()], + generation_result = await llm_api.generate( + llm_api.LLMServiceRequest( + task_name="utils", + request_type="replyer.lpmm_knowledge", + prompt=prompt, + tool_options=[search_knowledge_tool.get_tool_definition()], + ) ) + tool_calls = generation_result.completion.tool_calls # logger.info(f"工具调用提示词: {prompt}") # logger.info(f"工具调用: {tool_calls}") if tool_calls: - result = await self.tool_executor.execute_tool_call(tool_calls[0]) + result = None end_time = time.time() if not result or not result.get("content"): logger.debug("从LPMM知识库获取知识失败,返回空知识...") diff --git a/src/chat/replyer/maisaka_generator.py b/src/chat/replyer/maisaka_generator.py new file mode 100644 index 00000000..7b1a1043 --- /dev/null +++ b/src/chat/replyer/maisaka_generator.py @@ -0,0 +1,419 @@ +from dataclasses import dataclass, field +from datetime import datetime +from typing import Dict, List, Optional, Tuple + +import random +import time + +from sqlmodel import select + +from src.chat.message_receive.chat_manager import BotChatSession +from src.common.database.database import get_db_session +from src.common.database.database_model import Expression +from src.common.data_models.reply_generation_data_models import ( + GenerationMetrics, + LLMCompletionResult, + ReplyGenerationResult, +) +from src.common.logger import get_logger +from src.common.prompt_i18n import load_prompt +from src.config.config import global_config +from src.core.types import ActionInfo +from src.services.llm_service import LLMServiceClient + +from src.chat.message_receive.message import SessionMessage +from src.maisaka.context_messages import AssistantMessage, LLMContextMessage, ReferenceMessage, SessionBackedMessage, ToolResultMessage +from src.maisaka.message_adapter import parse_speaker_content + +logger = get_logger("replyer") + + +@dataclass +class MaisakaReplyContext: + """Maisaka replyer 使用的回复上下文。""" + + expression_habits: str = "" + selected_expression_ids: List[int] = field(default_factory=list) + + +@dataclass +class _ExpressionRecord: + """表达方式的轻量记录。""" + + expression_id: Optional[int] + situation: str + style: str + + +class MaisakaReplyGenerator: + """生成 Maisaka 的最终可见回复。""" + + def __init__( + self, + chat_stream: Optional[BotChatSession] = None, + request_type: str = "maisaka_replyer", + ) -> None: + self.chat_stream = chat_stream + self.request_type = request_type + self.express_model = LLMServiceClient( + task_name="replyer", + request_type=request_type, + ) + self._personality_prompt = self._build_personality_prompt() + + def _build_personality_prompt(self) -> str: + """构建 replyer 使用的人设描述。""" + try: + bot_name = global_config.bot.nickname + alias_names = global_config.bot.alias_names + bot_aliases = f",也有人叫你{','.join(alias_names)}" if alias_names else "" + + prompt_personality = global_config.personality.personality + if ( + hasattr(global_config.personality, "states") + and global_config.personality.states + and hasattr(global_config.personality, "state_probability") + and global_config.personality.state_probability > 0 + and random.random() < global_config.personality.state_probability + ): + prompt_personality = random.choice(global_config.personality.states) + + return f"你的名字是{bot_name}{bot_aliases},你{prompt_personality};" + except Exception as exc: + logger.warning(f"构建 Maisaka 人设提示词失败: {exc}") + return "你的名字是麦麦,你是一个活泼可爱的 AI 助手。" + + @staticmethod + def _normalize_content(content: str, limit: int = 500) -> str: + normalized = " ".join((content or "").split()) + if len(normalized) > limit: + return normalized[:limit] + "..." + return normalized + + @staticmethod + def _format_message_time(message: LLMContextMessage) -> str: + return message.timestamp.strftime("%H:%M:%S") + + @staticmethod + def _extract_visible_assistant_reply(message: AssistantMessage) -> str: + del message + return "" + + def _extract_guided_bot_reply(self, message: SessionBackedMessage) -> str: + speaker_name, body = parse_speaker_content(message.processed_plain_text.strip()) + bot_nickname = global_config.bot.nickname.strip() or "Bot" + if speaker_name == bot_nickname: + return self._normalize_content(body.strip()) + return "" + + @staticmethod + def _split_user_message_segments(raw_content: str) -> List[tuple[Optional[str], str]]: + """按说话人拆分用户消息。""" + segments: List[tuple[Optional[str], str]] = [] + current_speaker: Optional[str] = None + current_lines: List[str] = [] + + for raw_line in raw_content.splitlines(): + speaker_name, content_body = parse_speaker_content(raw_line) + if speaker_name is not None: + if current_lines: + segments.append((current_speaker, "\n".join(current_lines))) + current_speaker = speaker_name + current_lines = [content_body] + continue + + current_lines.append(raw_line) + + if current_lines: + segments.append((current_speaker, "\n".join(current_lines))) + + return segments + + def _format_chat_history(self, messages: List[LLMContextMessage]) -> str: + """格式化 replyer 使用的可见聊天记录。""" + bot_nickname = global_config.bot.nickname.strip() or "Bot" + parts: List[str] = [] + + for message in messages: + timestamp = self._format_message_time(message) + + if isinstance(message, (ReferenceMessage, ToolResultMessage)): + continue + + if isinstance(message, SessionBackedMessage): + guided_reply = self._extract_guided_bot_reply(message) + if guided_reply: + parts.append(f"{timestamp} {bot_nickname}(you): {guided_reply}") + continue + + raw_content = message.processed_plain_text + for speaker_name, content_body in self._split_user_message_segments(raw_content): + content = self._normalize_content(content_body) + if not content: + continue + visible_speaker = speaker_name or global_config.maisaka.user_name.strip() or "User" + parts.append(f"{timestamp} {visible_speaker}: {content}") + continue + + if isinstance(message, AssistantMessage): + visible_reply = self._extract_visible_assistant_reply(message) + if visible_reply: + parts.append(f"{timestamp} {bot_nickname}(you): {visible_reply}") + + return "\n".join(parts) + + def _build_prompt( + self, + chat_history: List[LLMContextMessage], + reply_reason: str, + expression_habits: str = "", + ) -> str: + """构建 Maisaka replyer 提示词。""" + current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + formatted_history = self._format_chat_history(chat_history) + + try: + system_prompt = load_prompt( + "maidairy_replyer", + bot_name=global_config.bot.nickname, + time_block=f"当前时间:{current_time}", + identity=self._personality_prompt, + reply_style=global_config.personality.reply_style, + ) + except Exception: + system_prompt = "你是一个友好的 AI 助手,请根据聊天记录自然回复。" + + extra_sections: List[str] = [] + if expression_habits.strip(): + extra_sections.append(expression_habits.strip()) + + user_sections = [ + f"当前时间:{current_time}", + f"【聊天记录】\n{formatted_history}", + ] + if extra_sections: + user_sections.append("\n\n".join(extra_sections)) + user_sections.append(f"【你的想法】\n{reply_reason}") + user_sections.append("现在,你说:") + + user_prompt = "\n\n".join(user_sections) + return f"System: {system_prompt}\n\nUser: {user_prompt}" + + def _resolve_session_id(self, stream_id: Optional[str]) -> str: + """解析当前回复使用的会话 ID。""" + if stream_id: + return stream_id + if self.chat_stream is not None: + return self.chat_stream.session_id + return "" + + async def _build_reply_context( + self, + chat_history: List[LLMContextMessage], + reply_message: Optional[SessionMessage], + reply_reason: str, + stream_id: Optional[str], + ) -> MaisakaReplyContext: + """在 replyer 内部构建表达习惯和黑话解释。""" + session_id = self._resolve_session_id(stream_id) + if not session_id: + logger.warning("构建 Maisaka 回复上下文失败:缺少会话标识") + return MaisakaReplyContext() + + expression_habits, selected_expression_ids = self._build_expression_habits( + session_id=session_id, + chat_history=chat_history, + reply_message=reply_message, + reply_reason=reply_reason, + ) + return MaisakaReplyContext( + expression_habits=expression_habits, + selected_expression_ids=selected_expression_ids, + ) + + def _build_expression_habits( + self, + session_id: str, + chat_history: List[LLMContextMessage], + reply_message: Optional[SessionMessage], + reply_reason: str, + ) -> tuple[str, List[int]]: + """查询并格式化适合当前会话的表达习惯。""" + del chat_history + del reply_message + del reply_reason + + expression_records = self._load_expression_records(session_id) + if not expression_records: + return "", [] + + lines: List[str] = [] + selected_ids: List[int] = [] + for expression in expression_records: + if expression.expression_id is not None: + selected_ids.append(expression.expression_id) + lines.append(f"- 当{expression.situation}时,可以自然地用{expression.style}这种表达习惯。") + + block = "【表达习惯参考】\n" + "\n".join(lines) + logger.info( + f"已构建 Maisaka 表达习惯: 会话标识={session_id} " + f"数量={len(selected_ids)} 表达编号={selected_ids!r}" + ) + return block, selected_ids + + def _load_expression_records(self, session_id: str) -> List[_ExpressionRecord]: + """提取表达方式静态数据,避免 detached ORM 对象。""" + with get_db_session(auto_commit=False) as session: + query = select(Expression).where(Expression.rejected.is_(False)) # type: ignore[attr-defined] + if global_config.expression.expression_checked_only: + query = query.where(Expression.checked.is_(True)) # type: ignore[attr-defined] + + query = query.where( + (Expression.session_id == session_id) | (Expression.session_id.is_(None)) # type: ignore[attr-defined] + ).order_by(Expression.count.desc(), Expression.last_active_time.desc()) # type: ignore[attr-defined] + + expressions = session.exec(query.limit(5)).all() + return [ + _ExpressionRecord( + expression_id=expression.id, + situation=expression.situation, + style=expression.style, + ) + for expression in expressions + ] + + async def generate_reply_with_context( + self, + extra_info: str = "", + reply_reason: str = "", + available_actions: Optional[Dict[str, ActionInfo]] = None, + chosen_actions: Optional[List[object]] = None, + enable_tool: bool = True, + from_plugin: bool = True, + stream_id: Optional[str] = None, + reply_message: Optional[SessionMessage] = None, + reply_time_point: Optional[float] = None, + think_level: int = 1, + unknown_words: Optional[List[str]] = None, + log_reply: bool = True, + chat_history: Optional[List[LLMContextMessage]] = None, + expression_habits: str = "", + selected_expression_ids: Optional[List[int]] = None, + ) -> Tuple[bool, ReplyGenerationResult]: + """结合上下文生成 Maisaka 的最终可见回复。""" + del available_actions + del chosen_actions + del enable_tool + del extra_info + del from_plugin + del log_reply + del reply_time_point + del think_level + del unknown_words + + result = ReplyGenerationResult() + if chat_history is None: + result.error_message = "聊天历史为空" + return False, result + + logger.info( + f"Maisaka 回复器开始生成: 会话流标识={stream_id} 回复原因={reply_reason!r} " + f"历史消息数={len(chat_history)} 目标消息编号=" + f"{reply_message.message_id if reply_message else None}" + ) + + filtered_history = [ + message + for message in chat_history + if not isinstance(message, (ReferenceMessage, ToolResultMessage)) + ] + + logger.debug(f"Maisaka 回复器过滤后历史消息数={len(filtered_history)}") + + # Validate that express_model is properly initialized + if self.express_model is None: + logger.error("Maisaka 回复器的回复模型未初始化") + result.error_message = "回复模型尚未初始化" + return False, result + + try: + reply_context = await self._build_reply_context( + chat_history=filtered_history, + reply_message=reply_message, + reply_reason=reply_reason or "", + stream_id=stream_id, + ) + except Exception as exc: + import traceback + logger.error(f"Maisaka 回复器构建回复上下文失败: {exc}\n{traceback.format_exc()}") + result.error_message = f"构建回复上下文失败: {exc}" + return False, result + + merged_expression_habits = expression_habits.strip() or reply_context.expression_habits + result.selected_expression_ids = ( + list(selected_expression_ids) + if selected_expression_ids is not None + else list(reply_context.selected_expression_ids) + ) + + logger.info( + f"Maisaka 回复上下文构建完成: 会话流标识={stream_id} " + f"已选表达编号={result.selected_expression_ids!r}" + ) + + try: + prompt = self._build_prompt( + chat_history=filtered_history, + reply_reason=reply_reason or "", + expression_habits=merged_expression_habits, + ) + except Exception as exc: + import traceback + logger.error(f"Maisaka 回复器构建提示词失败: {exc}\n{traceback.format_exc()}") + result.error_message = f"构建提示词失败: {exc}" + return False, result + + result.completion.request_prompt = prompt + + if global_config.debug.show_replyer_prompt: + logger.info(f"\nMaisaka 回复器提示词:\n{prompt}\n") + + started_at = time.perf_counter() + try: + generation_result = await self.express_model.generate_response(prompt) + except Exception as exc: + logger.exception("Maisaka 回复器调用失败") + result.error_message = str(exc) + result.metrics = GenerationMetrics( + overall_ms=round((time.perf_counter() - started_at) * 1000, 2), + ) + return False, result + + response_text = (generation_result.response or "").strip() + result.success = bool(response_text) + result.completion = LLMCompletionResult( + request_prompt=prompt, + response_text=response_text, + reasoning_text=generation_result.reasoning or "", + model_name=generation_result.model_name or "", + tool_calls=generation_result.tool_calls or [], + ) + result.metrics = GenerationMetrics( + overall_ms=round((time.perf_counter() - started_at) * 1000, 2), + ) + + if global_config.debug.show_replyer_reasoning and result.completion.reasoning_text: + logger.info(f"Maisaka 回复器思考内容:\n{result.completion.reasoning_text}") + + if not result.success: + result.error_message = "回复器返回了空内容" + logger.warning("Maisaka 回复器返回了空内容") + return False, result + + logger.info( + f"Maisaka 回复器生成成功: 回复文本={response_text!r} " + f"总耗时毫秒={result.metrics.overall_ms} " + f"已选表达编号={result.selected_expression_ids!r}" + ) + result.text_fragments = [response_text] + return True, result diff --git a/src/chat/replyer/private_generator.py b/src/chat/replyer/private_generator.py index c125a42f..bd1c7bbc 100644 --- a/src/chat/replyer/private_generator.py +++ b/src/chat/replyer/private_generator.py @@ -9,8 +9,8 @@ from datetime import datetime from src.common.logger import get_logger from src.common.data_models.info_data_model import ActionPlannerInfo from src.common.data_models.llm_data_model import LLMGenerationDataModel -from src.config.config import global_config, model_config -from src.llm_models.utils_model import LLMRequest +from src.config.config import global_config +from src.services.llm_service import LLMServiceClient from maim_message import BaseMessageInfo, MessageBase, Seg, UserInfo as MaimUserInfo from src.common.data_models.mai_message_data_model import MaiMessage @@ -52,15 +52,13 @@ class PrivateReplyer: chat_stream: 当前绑定的聊天会话。 request_type: LLM 请求类型标识。 """ - self.express_model = LLMRequest(model_set=model_config.model_task_config.replyer, request_type=request_type) + self.express_model = LLMServiceClient( + task_name="replyer", request_type=request_type + ) self.chat_stream = chat_stream self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_stream.session_id) # self.memory_activator = MemoryActivator() - from src.chat.tool_executor import ToolExecutor - - self.tool_executor = ToolExecutor(chat_id=self.chat_stream.session_id, enable_cache=True, cache_ttl=3) - async def generate_reply_with_context( self, extra_info: str = "", @@ -290,6 +288,11 @@ class PrivateReplyer: return f"{expression_habits_title}\n{expression_habits_block}", selected_ids async def build_tool_info(self, chat_history: str, sender: str, target: str, enable_tool: bool = True) -> str: + del chat_history + del sender + del target + del enable_tool + return "" """构建工具信息块 Args: @@ -306,9 +309,7 @@ class PrivateReplyer: try: # 使用工具执行器获取信息 - tool_results, _, _ = await self.tool_executor.execute_from_chat_message( - sender=sender, target_message=target, chat_history=chat_history, return_details=False - ) + tool_results = [] if tool_results: tool_info_str = "以下是你通过工具获取到的实时信息:\n" @@ -997,9 +998,11 @@ class PrivateReplyer: else: logger.debug(f"\n{prompt}\n") - content, (reasoning_content, model_name, tool_calls) = await self.express_model.generate_response_async( - prompt - ) + generation_result = await self.express_model.generate_response(prompt) + content = generation_result.response + reasoning_content = generation_result.reasoning + model_name = generation_result.model_name + tool_calls = generation_result.tool_calls content = content.strip() diff --git a/src/chat/replyer/replyer_manager.py b/src/chat/replyer/replyer_manager.py index eb430585..6ba9ce02 100644 --- a/src/chat/replyer/replyer_manager.py +++ b/src/chat/replyer/replyer_manager.py @@ -1,65 +1,82 @@ -from typing import Dict, Optional +from typing import TYPE_CHECKING, Any, Dict, Optional -from src.common.logger import get_logger from src.chat.message_receive.chat_manager import BotChatSession, chat_manager as _chat_manager -from src.chat.replyer.group_generator import DefaultReplyer -from src.chat.replyer.private_generator import PrivateReplyer +from src.common.logger import get_logger + +if TYPE_CHECKING: + from src.chat.replyer.group_generator import DefaultReplyer + from src.chat.replyer.maisaka_generator import MaisakaReplyGenerator + from src.chat.replyer.private_generator import PrivateReplyer logger = get_logger("ReplyerManager") class ReplyerManager: - def __init__(self): - self._repliers: Dict[str, DefaultReplyer | PrivateReplyer] = {} + """统一管理不同类型的回复生成器。""" + + def __init__(self) -> None: + self._repliers: Dict[str, Any] = {} def get_replyer( self, chat_stream: Optional[BotChatSession] = None, chat_id: Optional[str] = None, request_type: str = "replyer", - ) -> Optional[DefaultReplyer | PrivateReplyer]: - """ - 获取或创建回复器实例。 - - model_configs 仅在首次为某个 chat_id/stream_id 创建实例时有效。 - 后续调用将返回已缓存的实例,忽略 model_configs 参数。 - """ + replyer_type: str = "default", + ) -> Optional["DefaultReplyer | MaisakaReplyGenerator | PrivateReplyer"]: + """按会话和 replyer 类型获取实例。""" stream_id = chat_stream.session_id if chat_stream else chat_id if not stream_id: - logger.warning("[ReplyerManager] 缺少 stream_id,无法获取回复器。") + logger.warning("[ReplyerManager] 缺少 stream_id,无法获取 replyer") return None - # 如果已有缓存实例,直接返回 - if stream_id in self._repliers: - logger.debug(f"[ReplyerManager] 为 stream_id '{stream_id}' 返回已存在的回复器实例。") - return self._repliers[stream_id] + cache_key = f"{replyer_type}:{stream_id}" + if cache_key in self._repliers: + logger.info(f"[ReplyerManager] 命中缓存 replyer: cache_key={cache_key}") + return self._repliers[cache_key] - # 如果没有缓存,则创建新实例(首次初始化) - logger.debug(f"[ReplyerManager] 为 stream_id '{stream_id}' 创建新的回复器实例并缓存。") - - target_stream = chat_stream + target_stream = chat_stream or _chat_manager.get_session_by_session_id(stream_id) if not target_stream: - target_stream = _chat_manager.get_session_by_session_id(stream_id) - - if not target_stream: - logger.warning(f"[ReplyerManager] 未找到 stream_id='{stream_id}' 的聊天流,无法创建回复器。") + logger.warning(f"[ReplyerManager] 未找到会话,stream_id={stream_id}") return None - # model_configs 只在此时(初始化时)生效 - if target_stream.is_group_session: - replyer = DefaultReplyer( - chat_stream=target_stream, - request_type=request_type, - ) - else: - replyer = PrivateReplyer( - chat_stream=target_stream, - request_type=request_type, - ) + logger.info( + f"[ReplyerManager] 开始创建 replyer: cache_key={cache_key}, " + f"replyer_type={replyer_type}, is_group_session={target_stream.is_group_session}" + ) - self._repliers[stream_id] = replyer + try: + if replyer_type == "maisaka": + logger.info("[ReplyerManager] importing MaisakaReplyGenerator") + from src.chat.replyer.maisaka_generator import MaisakaReplyGenerator + + replyer = MaisakaReplyGenerator( + chat_stream=target_stream, + request_type=request_type, + ) + elif target_stream.is_group_session: + logger.info("[ReplyerManager] importing DefaultReplyer") + from src.chat.replyer.group_generator import DefaultReplyer + + replyer = DefaultReplyer( + chat_stream=target_stream, + request_type=request_type, + ) + else: + logger.info("[ReplyerManager] importing PrivateReplyer") + from src.chat.replyer.private_generator import PrivateReplyer + + replyer = PrivateReplyer( + chat_stream=target_stream, + request_type=request_type, + ) + except Exception: + logger.exception(f"[ReplyerManager] 创建 replyer 失败: cache_key={cache_key}") + raise + + self._repliers[cache_key] = replyer + logger.info(f"[ReplyerManager] replyer 创建完成: cache_key={cache_key}") return replyer -# 创建一个全局实例 replyer_manager = ReplyerManager() diff --git a/src/chat/tool_executor.py b/src/chat/tool_executor.py deleted file mode 100644 index aa99fce8..00000000 --- a/src/chat/tool_executor.py +++ /dev/null @@ -1,248 +0,0 @@ -"""工具执行器。 - -独立的工具执行组件,可以直接输入聊天消息内容, -自动判断并执行相应的工具,返回结构化的工具执行结果。 -""" - -from typing import Any, Dict, List, Optional, Tuple - -import hashlib -import time - -from src.common.logger import get_logger -from src.config.config import global_config, model_config -from src.core.announcement_manager import global_announcement_manager -from src.llm_models.payload_content import ToolCall -from src.llm_models.utils_model import LLMRequest -from src.plugin_runtime.component_query import component_query_service -from src.prompt.prompt_manager import prompt_manager - -logger = get_logger("tool_use") - - -class ToolExecutor: - """独立的工具执行器组件 - - 可以直接输入聊天消息内容,自动判断并执行相应的工具,返回结构化的工具执行结果。 - """ - - def __init__(self, chat_id: str, enable_cache: bool = True, cache_ttl: int = 3): - from src.chat.message_receive.chat_manager import chat_manager as _chat_manager - - self.chat_id = chat_id - self.chat_stream = _chat_manager.get_session_by_session_id(self.chat_id) - self.log_prefix = f"[{_chat_manager.get_session_name(self.chat_id) or self.chat_id}]" - - self.llm_model = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="tool_executor") - - self.enable_cache = enable_cache - self.cache_ttl = cache_ttl - self.tool_cache: Dict[str, dict] = {} - - logger.info(f"{self.log_prefix}工具执行器初始化完成,缓存{'启用' if enable_cache else '禁用'},TTL={cache_ttl}") - - async def execute_from_chat_message( - self, target_message: str, chat_history: str, sender: str, return_details: bool = False - ) -> Tuple[List[Dict[str, Any]], List[str], str]: - """从聊天消息执行工具""" - - cache_key = self._generate_cache_key(target_message, chat_history, sender) - if cached_result := self._get_from_cache(cache_key): - logger.info(f"{self.log_prefix}使用缓存结果,跳过工具执行") - if not return_details: - return cached_result, [], "" - used_tools = [result.get("tool_name", "unknown") for result in cached_result] - return cached_result, used_tools, "" - - tools = self._get_tool_definitions() - if not tools: - logger.debug(f"{self.log_prefix}没有可用工具,直接返回空内容") - return [], [], "" - - prompt_template = prompt_manager.get_prompt("tool_executor") - prompt_template.add_context("target_message", target_message) - prompt_template.add_context("chat_history", chat_history) - prompt_template.add_context("sender", sender) - prompt_template.add_context("bot_name", global_config.bot.nickname) - prompt_template.add_context("time_now", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())) - prompt = await prompt_manager.render_prompt(prompt_template) - - logger.debug(f"{self.log_prefix}开始LLM工具调用分析") - - response, (reasoning_content, model_name, tool_calls) = await self.llm_model.generate_response_async( - prompt=prompt, tools=tools, raise_when_empty=False - ) - - tool_results, used_tools = await self.execute_tool_calls(tool_calls) - - if tool_results: - self._set_cache(cache_key, tool_results) - - if used_tools: - logger.info(f"{self.log_prefix}工具执行完成,共执行{len(used_tools)}个工具: {used_tools}") - - if return_details: - return tool_results, used_tools, prompt - return tool_results, [], "" - - def _get_tool_definitions(self) -> List[Dict[str, Any]]: - """获取 LLM 可用的工具定义列表""" - all_tools = component_query_service.get_llm_available_tools() - user_disabled_tools = global_announcement_manager.get_disabled_chat_tools(self.chat_id) - return [info.get_llm_definition() for name, info in all_tools.items() if name not in user_disabled_tools] - - async def execute_tool_calls(self, tool_calls: Optional[List[ToolCall]]) -> Tuple[List[Dict[str, Any]], List[str]]: - """执行工具调用列表""" - tool_results: List[Dict[str, Any]] = [] - used_tools: List[str] = [] - - if not tool_calls: - logger.debug(f"{self.log_prefix}无需执行工具") - return [], [] - - func_names = [call.func_name for call in tool_calls if call.func_name] - logger.info(f"{self.log_prefix}开始执行工具调用: {func_names}") - - for tool_call in tool_calls: - tool_name = tool_call.func_name - try: - logger.debug(f"{self.log_prefix}执行工具: {tool_name}") - result = await self.execute_tool_call(tool_call) - - if result: - tool_info = { - "type": result.get("type", "unknown_type"), - "id": result.get("id", f"tool_exec_{time.time()}"), - "content": result.get("content", ""), - "tool_name": tool_name, - "timestamp": time.time(), - } - content = tool_info["content"] - if not isinstance(content, (str, list, tuple)): - tool_info["content"] = str(content) - content_check = tool_info["content"] - if (isinstance(content_check, str) and not content_check.strip()) or ( - isinstance(content_check, (list, tuple)) and len(content_check) == 0 - ): - logger.debug(f"{self.log_prefix}工具{tool_name}无有效内容,跳过展示") - continue - - tool_results.append(tool_info) - used_tools.append(tool_name) - preview = str(content)[:200] - logger.debug(f"{self.log_prefix}工具{tool_name}结果内容: {preview}...") - except Exception as e: - logger.error(f"{self.log_prefix}工具{tool_name}执行失败: {e}") - error_info = { - "type": "tool_error", - "id": f"tool_error_{time.time()}", - "content": f"工具{tool_name}执行失败: {str(e)}", - "tool_name": tool_name, - "timestamp": time.time(), - } - tool_results.append(error_info) - - return tool_results, used_tools - - async def execute_tool_call(self, tool_call: ToolCall) -> Optional[Dict[str, Any]]: - """执行单个工具调用""" - function_name = tool_call.func_name - function_args = tool_call.args or {} - function_args["llm_called"] = True - - executor = component_query_service.get_tool_executor(function_name) - if not executor: - logger.warning(f"未知工具名称: {function_name}") - return None - - result = await executor(function_args) - if result: - return { - "tool_call_id": tool_call.call_id, - "role": "tool", - "name": function_name, - "type": "function", - "content": result["content"], - } - return None - - async def execute_specific_tool_simple(self, tool_name: str, tool_args: Dict) -> Optional[Dict]: - """直接执行指定工具""" - try: - tool_call = ToolCall( - call_id=f"direct_tool_{time.time()}", - func_name=tool_name, - args=tool_args, - ) - - logger.info(f"{self.log_prefix}直接执行工具: {tool_name}") - result = await self.execute_tool_call(tool_call) - - if result: - tool_info = { - "type": result.get("type", "unknown_type"), - "id": result.get("id", f"direct_tool_{time.time()}"), - "content": result.get("content", ""), - "tool_name": tool_name, - "timestamp": time.time(), - } - logger.info(f"{self.log_prefix}直接工具执行成功: {tool_name}") - return tool_info - - except Exception as e: - logger.error(f"{self.log_prefix}直接工具执行失败 {tool_name}: {e}") - - return None - - # === 缓存方法 === - - def _generate_cache_key(self, target_message: str, chat_history: str, sender: str) -> str: - content = f"{target_message}_{chat_history}_{sender}" - return hashlib.md5(content.encode()).hexdigest() - - def _get_from_cache(self, cache_key: str) -> Optional[List[Dict]]: - if not self.enable_cache or cache_key not in self.tool_cache: - return None - cache_item = self.tool_cache[cache_key] - if cache_item["ttl"] <= 0: - del self.tool_cache[cache_key] - return None - cache_item["ttl"] -= 1 - return cache_item["result"] - - def _set_cache(self, cache_key: str, result: List[Dict]): - if not self.enable_cache: - return - self.tool_cache[cache_key] = {"result": result, "ttl": self.cache_ttl, "timestamp": time.time()} - - def _cleanup_expired_cache(self): - if not self.enable_cache: - return - expired = [k for k, v in self.tool_cache.items() if v["ttl"] <= 0] - for key in expired: - del self.tool_cache[key] - - def clear_cache(self): - if self.enable_cache: - self.tool_cache.clear() - - def get_cache_status(self) -> Dict: - if not self.enable_cache: - return {"enabled": False, "cache_count": 0} - self._cleanup_expired_cache() - ttl_distribution: Dict[int, int] = {} - for item in self.tool_cache.values(): - ttl = item["ttl"] - ttl_distribution[ttl] = ttl_distribution.get(ttl, 0) + 1 - return { - "enabled": True, - "cache_count": len(self.tool_cache), - "cache_ttl": self.cache_ttl, - "ttl_distribution": ttl_distribution, - } - - def set_cache_config(self, enable_cache: Optional[bool] = None, cache_ttl: int = -1): - if enable_cache is not None: - self.enable_cache = enable_cache - if cache_ttl > 0: - self.cache_ttl = cache_ttl diff --git a/src/chat/utils/statistic.py b/src/chat/utils/statistic.py index 51e5e643..25add4bf 100644 --- a/src/chat/utils/statistic.py +++ b/src/chat/utils/statistic.py @@ -12,7 +12,7 @@ from sqlmodel import col, select from src.common.logger import get_logger from src.common.database.database import get_db_session -from src.common.database.database_model import OnlineTime, ModelUsage, Messages, ActionRecord +from src.common.database.database_model import Messages, ModelUsage, OnlineTime, ToolRecord from src.manager.async_task_manager import AsyncTask from src.manager.local_store_manager import local_storage from src.config.config import global_config @@ -648,7 +648,7 @@ class StatisticOutputTask(AsyncTask): def _collect_message_count_for_period( self, collect_period: list[tuple[str, datetime]], - ) -> StatPeriodMapping: + ) -> dict[str, dict[str, object]]: """ 收集指定时间段的消息统计数据 @@ -659,8 +659,13 @@ class StatisticOutputTask(AsyncTask): collect_period.sort(key=lambda x: x[1], reverse=True) - stats: StatPeriodMapping = { - period_key: StatisticOutputTask._build_stat_period_data() for period_key, _ in collect_period + stats: dict[str, dict[str, object]] = { + period_key: { + TOTAL_MSG_CNT: 0, + MSG_CNT_BY_CHAT: defaultdict(int), + TOTAL_REPLY_CNT: 0, + } + for period_key, _ in collect_period } query_start_timestamp = collect_period[-1][1] @@ -710,24 +715,24 @@ class StatisticOutputTask(AsyncTask): StatisticOutputTask._add_defaultdict_int(stats[period_key], MSG_CNT_BY_CHAT, chat_id, 1) break - # 使用 ActionRecords 中的 reply 动作次数作为回复数基准 + # 使用 ToolRecord 中的 reply 工具次数作为回复数基准 try: - action_query_start_timestamp = collect_period[-1][1] + tool_query_start_timestamp = collect_period[-1][1] with get_db_session(auto_commit=False) as session: - statement = select(ActionRecord).where(col(ActionRecord.timestamp) >= action_query_start_timestamp) - actions = session.exec(statement).all() - for action in actions: - if action.action_name != "reply": + statement = select(ToolRecord).where(col(ToolRecord.timestamp) >= tool_query_start_timestamp) + tool_records = session.exec(statement).all() + for tool_record in tool_records: + if tool_record.tool_name != "reply": continue - action_time_ts = action.timestamp.timestamp() + action_time_ts = tool_record.timestamp.timestamp() for idx, (_, period_start_dt) in enumerate(collect_period): if action_time_ts >= period_start_dt.timestamp(): for period_key, _ in collect_period[idx:]: StatisticOutputTask._add_int_stat(stats[period_key], TOTAL_REPLY_CNT, 1) break except Exception as e: - logger.warning(f"统计 reply 动作次数失败,将回复数视为 0,错误信息:{e}") + logger.warning(f"统计 reply 工具次数失败,将回复数视为 0,错误信息:{e}") return stats diff --git a/src/chat/utils/utils.py b/src/chat/utils/utils.py index 07cec0b4..aa14e790 100644 --- a/src/chat/utils/utils.py +++ b/src/chat/utils/utils.py @@ -13,8 +13,8 @@ import jieba from src.chat.message_receive.chat_manager import chat_manager as _chat_manager from src.chat.message_receive.message import SessionMessage from src.common.logger import get_logger -from src.config.config import global_config, model_config -from src.llm_models.utils_model import LLMRequest +from src.config.config import global_config +from src.services.llm_service import LLMServiceClient from src.person_info.person_info import Person from .typo_generator import ChineseTypoGenerator @@ -235,10 +235,11 @@ def is_mentioned_bot_in_message(message: SessionMessage) -> tuple[bool, bool, fl async def get_embedding(text, request_type="embedding") -> Optional[List[float]]: """获取文本的embedding向量""" - # 每次都创建新的LLMRequest实例以避免事件循环冲突 - llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type=request_type) + # 每次都创建新的服务层实例以避免事件循环冲突 + llm = LLMServiceClient(task_name="embedding", request_type=request_type) try: - embedding, _ = await llm.get_embedding(text) + embedding_result = await llm.embed_text(text) + embedding = embedding_result.embedding except Exception as e: logger.error(f"获取embedding失败: {str(e)}") embedding = None diff --git a/src/cli/__init__.py b/src/cli/__init__.py new file mode 100644 index 00000000..28eb1a3b --- /dev/null +++ b/src/cli/__init__.py @@ -0,0 +1,3 @@ +""" +CLI startup and interaction package. +""" diff --git a/src/cli/console.py b/src/cli/console.py new file mode 100644 index 00000000..3dbfde36 --- /dev/null +++ b/src/cli/console.py @@ -0,0 +1,17 @@ +"""MaiSaka terminal console helpers.""" + +from rich.console import Console +from rich.theme import Theme + +custom_theme = Theme( + { + "info": "cyan", + "success": "green", + "warning": "yellow", + "error": "bold red", + "muted": "dim", + "accent": "bold magenta", + } +) + +console = Console(theme=custom_theme) diff --git a/src/maisaka/input_reader.py b/src/cli/input_reader.py similarity index 93% rename from src/maisaka/input_reader.py rename to src/cli/input_reader.py index eff2525c..f1ac6b44 100644 --- a/src/maisaka/input_reader.py +++ b/src/cli/input_reader.py @@ -1,12 +1,12 @@ """ -MaiSaka - 异步输入读取器 -将阻塞的标准输入读取放到后台线程中,供 asyncio 循环安全消费。 +MaiSaka asynchronous stdin reader for CLI interaction. """ +from typing import Optional + import asyncio import sys import threading -from typing import Optional class InputReader: diff --git a/src/cli/maisaka_cli.py b/src/cli/maisaka_cli.py new file mode 100644 index 00000000..1174ea67 --- /dev/null +++ b/src/cli/maisaka_cli.py @@ -0,0 +1,463 @@ +""" +MaiSaka CLI and conversation loop. +""" + +from datetime import datetime +from typing import Optional + +import asyncio +import os +import time + +from rich import box +from rich.markdown import Markdown +from rich.panel import Panel +from rich.text import Text + +from src.know_u.knowledge import KnowledgeLearner, retrieve_relevant_knowledge +from src.know_u.knowledge_store import get_knowledge_store +from src.chat.message_receive.message import SessionMessage +from src.chat.replyer.maisaka_generator import MaisakaReplyGenerator +from src.config.config import config_manager, global_config +from src.mcp_module import MCPManager +from src.mcp_module.host_llm_bridge import MCPHostLLMBridge + +from src.maisaka.chat_loop_service import MaisakaChatLoopService +from src.maisaka.context_messages import ( + AssistantMessage, + LLMContextMessage, + SessionBackedMessage, + ToolResultMessage, +) +from src.maisaka.message_adapter import format_speaker_content +from src.maisaka.tool_handlers import ( + ToolHandlerContext, + handle_mcp_tool, + handle_stop, + handle_unknown_tool, + handle_wait, +) + +from .console import console +from .input_reader import InputReader + + +class BufferCLI: + """Maisaka 命令行交互入口。""" + + def __init__(self) -> None: + self._chat_loop_service: Optional[MaisakaChatLoopService] = None + self._reply_generator = MaisakaReplyGenerator() + self._reader = InputReader() + self._chat_history: Optional[list[LLMContextMessage]] = None + self._knowledge_store = get_knowledge_store() + self._knowledge_learner = KnowledgeLearner("maisaka_cli") + self._knowledge_min_messages_for_extraction = 10 + self._knowledge_min_extraction_interval = 30 + self._last_knowledge_extraction_time = 0.0 + + knowledge_stats = self._knowledge_store.get_stats() + if knowledge_stats["total_items"] > 0: + console.print(f"[success]知识库中已有 {knowledge_stats['total_items']} 条数据[/success]") + else: + console.print("[muted]知识库已初始化,当前没有数据[/muted]") + + self._chat_start_time: Optional[datetime] = None + self._last_user_input_time: Optional[datetime] = None + self._last_assistant_response_time: Optional[datetime] = None + self._user_input_times: list[datetime] = [] + self._mcp_manager: Optional[MCPManager] = None + self._mcp_host_bridge: Optional[MCPHostLLMBridge] = None + self._init_llm() + + def _init_llm(self) -> None: + """初始化 Maisaka 使用的聊天服务。""" + thinking_env = os.getenv("ENABLE_THINKING", "").strip().lower() + enable_thinking: Optional[bool] = True if thinking_env == "true" else False if thinking_env == "false" else None + + _ = enable_thinking + self._chat_loop_service = MaisakaChatLoopService() + + model_name = self._get_current_model_name() + console.print(f"[success]大模型服务已初始化[/success] [muted](模型: {model_name})[/muted]") + + @staticmethod + def _get_current_model_name() -> str: + """读取当前 planner 模型名。""" + try: + model_task_config = config_manager.get_model_config().model_task_config + if model_task_config.planner.model_list: + return model_task_config.planner.model_list[0] + except Exception: + pass + return "未配置" + + def _build_tool_context(self) -> ToolHandlerContext: + """构建工具处理的共享上下文。""" + tool_context = ToolHandlerContext( + reader=self._reader, + user_input_times=self._user_input_times, + ) + tool_context.last_user_input_time = self._last_user_input_time + return tool_context + + def _show_banner(self) -> None: + """渲染启动横幅。""" + banner = Text() + banner.append("MaiSaka", style="bold cyan") + banner.append(" v2.0\n", style="muted") + banner.append("输入内容开始对话 | Ctrl+C 退出", style="muted") + + console.print(Panel(banner, box=box.DOUBLE_EDGE, border_style="cyan", padding=(1, 2))) + console.print() + + async def _start_chat(self, user_text: str) -> None: + """追加用户输入并继续内部循环。""" + if self._chat_loop_service is None: + console.print("[warning]大模型服务尚未初始化,已跳过本次对话。[/warning]") + return + + now = datetime.now() + self._last_user_input_time = now + self._user_input_times.append(now) + + if self._chat_history is None: + self._chat_start_time = now + self._last_assistant_response_time = None + self._chat_history = self._chat_loop_service.build_chat_context(user_text) + self._trigger_knowledge_learning([self._build_cli_session_message(user_text, now)]) + else: + self._chat_history.append( + self._build_cli_context_message( + user_text=user_text, + timestamp=now, + source_kind="user", + ) + ) + self._trigger_knowledge_learning([self._build_cli_session_message(user_text, now)]) + + await self._run_llm_loop(self._chat_history) + + @staticmethod + def _build_cli_context_message( + user_text: str, + timestamp: datetime, + source_kind: str = "user", + speaker_name: Optional[str] = None, + ) -> SessionBackedMessage: + """为 CLI 构造新的上下文消息。""" + resolved_speaker_name = speaker_name or global_config.maisaka.user_name.strip() or "用户" + visible_text = format_speaker_content( + resolved_speaker_name, + user_text, + timestamp, + ) + planner_prefix = ( + f"[时间]{timestamp.strftime('%H:%M:%S')}\n" + f"[用户]{resolved_speaker_name}\n" + "[用户群昵称]\n" + "[msg_id]\n" + "[发言内容]" + ) + from src.common.data_models.message_component_data_model import MessageSequence, TextComponent + + return SessionBackedMessage( + raw_message=MessageSequence([TextComponent(f"{planner_prefix}{user_text}")]), + visible_text=visible_text, + timestamp=timestamp, + source_kind=source_kind, + ) + + @staticmethod + def _build_cli_session_message(user_text: str, timestamp: datetime) -> SessionMessage: + """为 CLI 的知识学习构造兼容 SessionMessage。""" + from src.common.data_models.mai_message_data_model import MessageInfo, UserInfo + from src.common.data_models.message_component_data_model import MessageSequence + + message = SessionMessage(message_id=f"maisaka_cli_{int(timestamp.timestamp() * 1000)}", timestamp=timestamp, platform="maisaka") + message.message_info = MessageInfo( + user_info=UserInfo( + user_id="maisaka_user", + user_nickname=global_config.maisaka.user_name.strip() or "用户", + user_cardname=None, + ), + group_info=None, + additional_config={}, + ) + message.session_id = "maisaka_cli" + message.raw_message = MessageSequence([]) + visible_text = format_speaker_content( + global_config.maisaka.user_name.strip() or "用户", + user_text, + timestamp, + ) + message.raw_message.text(visible_text) + message.processed_plain_text = visible_text + message.display_message = visible_text + message.initialized = True + return message + + def _trigger_knowledge_learning(self, messages: list[SessionMessage]) -> None: + """在 CLI 会话中按批次触发 knowledge 学习。""" + if not global_config.maisaka.enable_knowledge_module: + return + + self._knowledge_learner.add_messages(messages) + + elapsed = time.monotonic() - self._last_knowledge_extraction_time + if elapsed < self._knowledge_min_extraction_interval: + return + + cache_size = self._knowledge_learner.get_cache_size() + if cache_size < self._knowledge_min_messages_for_extraction: + return + + self._last_knowledge_extraction_time = time.monotonic() + asyncio.create_task(self._run_knowledge_learning()) + + async def _run_knowledge_learning(self) -> None: + """后台执行 knowledge 学习,避免阻塞主对话。""" + try: + added_count = await self._knowledge_learner.learn() + if added_count > 0 and global_config.maisaka.show_thinking: + console.print(f"[muted]知识学习已完成,新增 {added_count} 条数据。[/muted]") + except Exception as exc: + console.print(f"[warning]知识学习失败:{exc}[/warning]") + + async def _run_llm_loop(self, chat_history: list[LLMContextMessage]) -> None: + """ + Main inner loop for the Maisaka planner. + + Each round may produce internal thoughts and optionally call tools: + - reply(msg_id): generate a visible reply for the current round + - no_reply(): skip visible output and continue the loop + - wait(seconds): wait for new user input + - stop(): stop the current inner loop and return to idle + """ + if self._chat_loop_service is None: + return + + consecutive_errors = 0 + last_had_tool_calls = True + + while True: + if last_had_tool_calls: + tasks = [] + status_text_parts = [] + + if global_config.maisaka.enable_knowledge_module: + tasks.append(("knowledge", retrieve_relevant_knowledge(self._chat_loop_service, chat_history))) + status_text_parts.append("知识库") + + with console.status( + f"[info]{' + '.join(status_text_parts)} 分析中...[/info]", + spinner="dots", + ): + results = await asyncio.gather(*[task for _, task in tasks], return_exceptions=True) + + knowledge_analysis = "" + if global_config.maisaka.enable_knowledge_module: + knowledge_result = results[0] if results else None + if isinstance(knowledge_result, Exception): + console.print(f"[warning]知识分析失败:{knowledge_result}[/warning]") + elif isinstance(knowledge_result, str) and knowledge_result.strip(): + knowledge_analysis = knowledge_result + if global_config.maisaka.show_thinking: + console.print( + Panel( + Markdown(knowledge_analysis), + title="知识", + border_style="bright_magenta", + padding=(0, 1), + style="dim", + ) + ) + + if chat_history and isinstance(chat_history[-1], AssistantMessage) and chat_history[-1].source == "perception": + chat_history.pop() + + perception_parts = [] + if knowledge_analysis: + perception_parts.append(f"知识库\n{knowledge_analysis}") + + if perception_parts: + chat_history.append( + AssistantMessage( + content="\n\n".join(perception_parts), + timestamp=datetime.now(), + source_kind="perception", + ) + ) + elif global_config.maisaka.show_thinking: + console.print("[muted]上一轮没有使用工具,本轮跳过模块分析。[/muted]") + + with console.status("[info]正在思考...[/info]", spinner="dots"): + try: + response = await self._chat_loop_service.chat_loop_step(chat_history) + consecutive_errors = 0 + except Exception as exc: + consecutive_errors += 1 + console.print(f"[error]大模型调用失败:{exc}[/error]") + if consecutive_errors >= 3: + console.print("[error]连续失败次数过多,结束对话。[/error]\n") + break + continue + + chat_history.append(response.raw_message) + self._last_assistant_response_time = datetime.now() + + if global_config.maisaka.show_thinking and response.content: + console.print( + Panel( + Markdown(response.content), + title="思考", + border_style="dim", + padding=(1, 2), + style="dim", + ) + ) + + if response.content and not response.tool_calls: + last_had_tool_calls = False + continue + + if not response.tool_calls: + last_had_tool_calls = False + continue + + should_stop = False + tool_context = self._build_tool_context() + + for tool_call in response.tool_calls: + if tool_call.func_name == "stop": + await handle_stop(tool_call, chat_history) + should_stop = True + + elif tool_call.func_name == "reply": + reply = await self._generate_visible_reply(chat_history, response.content or "") + chat_history.append( + ToolResultMessage( + content="已生成并记录可见回复。", + timestamp=datetime.now(), + tool_call_id=tool_call.call_id, + tool_name=tool_call.func_name, + ) + ) + chat_history.append( + self._build_cli_context_message( + user_text=reply, + timestamp=datetime.now(), + source_kind="guided_reply", + speaker_name=global_config.bot.nickname.strip() or "MaiSaka", + ) + ) + + elif tool_call.func_name == "no_reply": + if global_config.maisaka.show_thinking: + console.print("[muted]本轮未发送可见回复。[/muted]") + chat_history.append( + ToolResultMessage( + content="本轮未发送可见回复。", + timestamp=datetime.now(), + tool_call_id=tool_call.call_id, + tool_name=tool_call.func_name, + ) + ) + + elif tool_call.func_name == "wait": + tool_result = await handle_wait(tool_call, chat_history, tool_context) + if tool_context.last_user_input_time != self._last_user_input_time: + self._last_user_input_time = tool_context.last_user_input_time + if tool_result.startswith("[[QUIT]]"): + should_stop = True + + elif self._mcp_manager and self._mcp_manager.is_mcp_tool(tool_call.func_name): + await handle_mcp_tool(tool_call, chat_history, self._mcp_manager) + + else: + await handle_unknown_tool(tool_call, chat_history) + + if should_stop: + console.print("[muted]对话已暂停,等待新的输入...[/muted]\n") + break + + last_had_tool_calls = True + + async def _init_mcp(self) -> None: + """初始化 MCP 服务并注册暴露的工具。""" + self._mcp_host_bridge = MCPHostLLMBridge( + sampling_task_name=global_config.mcp.client.sampling.task_name, + ) + self._mcp_manager = await MCPManager.from_app_config( + global_config.mcp, + host_callbacks=self._mcp_host_bridge.build_callbacks(), + ) + + if self._mcp_manager and self._chat_loop_service: + mcp_tools = self._mcp_manager.get_openai_tools() + if mcp_tools: + self._chat_loop_service.set_extra_tools(mcp_tools) + summary = self._mcp_manager.get_feature_summary() + console.print( + Panel( + f"已加载 {len(mcp_tools)} 个 MCP 工具。\n{summary}", + title="MCP 能力", + border_style="green", + padding=(0, 1), + ) + ) + + async def _generate_visible_reply(self, chat_history: list[LLMContextMessage], latest_thought: str) -> str: + """根据最新思考生成并输出可见回复。""" + if not latest_thought: + return "" + + with console.status("[info]正在生成可见回复...[/info]", spinner="dots"): + success, result = await self._reply_generator.generate_reply_with_context( + reply_reason=latest_thought, + chat_history=chat_history, + ) + if success and result.text_fragments: + reply = result.text_fragments[0] + else: + reply = "..." + + console.print( + Panel( + Markdown(reply), + title="MaiSaka", + border_style="magenta", + padding=(1, 2), + ) + ) + + return reply + + async def run(self) -> None: + """主交互循环。""" + if global_config.mcp.enable: + await self._init_mcp() + else: + console.print("[muted]MCP 已禁用(mcp.enable=false)[/muted]") + + self._reader.start(asyncio.get_event_loop()) + self._show_banner() + + try: + while True: + console.print("[bold cyan]> [/bold cyan]", end="") + raw_input = await self._reader.get_line() + + if raw_input is None: + console.print("\n[muted]再见![/muted]") + break + + raw_input = raw_input.strip() + if not raw_input: + continue + + await self._start_chat(raw_input) + finally: + if self._mcp_manager: + await self._mcp_manager.close() + self._mcp_host_bridge = None diff --git a/src/common/data_models/llm_service_data_models.py b/src/common/data_models/llm_service_data_models.py new file mode 100644 index 00000000..15b530ca --- /dev/null +++ b/src/common/data_models/llm_service_data_models.py @@ -0,0 +1,187 @@ +"""LLM 服务层与编排层共享数据模型。 + +该模块集中定义 LLM 服务层与底层编排器共同使用的请求、选项与结果对象, +用于替代散落在各层之间的复杂元组返回值。 +""" + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Callable, Dict, List, TypeAlias + +import asyncio + +from src.common.data_models import BaseDataModel +from src.llm_models.payload_content.resp_format import RespFormat +from src.llm_models.payload_content.tool_option import ToolCall, ToolDefinitionInput + +if TYPE_CHECKING: + from src.llm_models.model_client.base_client import BaseClient + from src.llm_models.payload_content.message import Message + + +PromptMessage: TypeAlias = Dict[str, Any] +"""统一的原始提示消息结构。""" + +PromptInput: TypeAlias = str | List[PromptMessage] +"""统一的提示输入类型。""" + +MessageFactory: TypeAlias = Callable[["BaseClient"], List["Message"]] +"""统一的消息工厂类型。""" + + +@dataclass(slots=True) +class LLMServiceRequest(BaseDataModel): + """LLM 服务层统一请求对象。""" + + task_name: str + request_type: str + prompt: PromptInput | None = None + message_factory: MessageFactory | None = None + tool_options: List[ToolDefinitionInput] | None = None + temperature: float | None = None + max_tokens: int | None = None + response_format: RespFormat | None = None + interrupt_flag: asyncio.Event | None = None + + def __post_init__(self) -> None: + """校验请求对象的必要字段。 + + Raises: + ValueError: 当 `task_name` 为空,或 `prompt` 与 `message_factory` + 的组合非法时抛出。 + """ + self.task_name = self.task_name.strip() + if not self.task_name: + raise ValueError("`task_name` 不能为空") + has_prompt = self.prompt is not None + has_message_factory = self.message_factory is not None + if has_prompt == has_message_factory: + raise ValueError("`prompt` 与 `message_factory` 必须且只能提供一个") + + +@dataclass(slots=True) +class LLMResponseResult(BaseDataModel): + """单次 LLM 响应结果。""" + + response: str = field(default_factory=str) + reasoning: str = field(default_factory=str) + model_name: str = field(default_factory=str) + tool_calls: List[ToolCall] | None = None + + +@dataclass(slots=True) +class LLMServiceResult(BaseDataModel): + """LLM 服务层统一响应对象。""" + + success: bool = False + completion: LLMResponseResult = field(default_factory=LLMResponseResult) + error: str | None = None + + @classmethod + def from_response_result(cls, completion: LLMResponseResult) -> "LLMServiceResult": + """从单次 LLM 响应结果构建服务响应。 + + Args: + completion: 单次 LLM 响应结果。 + + Returns: + LLMServiceResult: 标记为成功的服务响应对象。 + """ + return cls( + success=True, + completion=completion, + error=None, + ) + + @classmethod + def from_error(cls, error_message: str, error_detail: str | None = None) -> "LLMServiceResult": + """构建失败的服务响应对象。 + + Args: + error_message: 对上层展示的错误消息。 + error_detail: 底层错误详情。 + + Returns: + LLMServiceResult: 标记为失败的服务响应对象。 + """ + return cls( + success=False, + completion=LLMResponseResult(response=error_message), + error=error_detail or error_message, + ) + + def to_capability_payload(self) -> Dict[str, Any]: + """转换为插件能力层可直接返回的结构。 + + Returns: + Dict[str, Any]: 标准化后的能力返回值。 + """ + payload: Dict[str, Any] = { + "success": self.success, + "response": self.completion.response, + "reasoning": self.completion.reasoning, + "model_name": self.completion.model_name, + } + if self.completion.tool_calls is not None: + payload["tool_calls"] = [ + { + "id": tool_call.call_id, + "function": { + "name": tool_call.func_name, + "arguments": tool_call.args or {}, + }, + } + for tool_call in self.completion.tool_calls + ] + if self.error: + payload["error"] = self.error + return payload + + +@dataclass(slots=True) +class LLMGenerationOptions(BaseDataModel): + """LLM 文本生成选项。""" + + temperature: float | None = None + max_tokens: int | None = None + tool_options: List[ToolDefinitionInput] | None = None + response_format: RespFormat | None = None + interrupt_flag: asyncio.Event | None = None + raise_when_empty: bool = True + + +@dataclass(slots=True) +class LLMImageOptions(BaseDataModel): + """LLM 图像理解选项。""" + + temperature: float | None = None + max_tokens: int | None = None + interrupt_flag: asyncio.Event | None = None + + +@dataclass(slots=True) +class LLMAudioTranscriptionResult(BaseDataModel): + """LLM 音频转写结果。""" + + text: str | None = None + + +@dataclass(slots=True) +class LLMEmbeddingResult(BaseDataModel): + """LLM 向量生成结果。""" + + embedding: List[float] = field(default_factory=list) + model_name: str = field(default_factory=str) + + +__all__ = [ + "LLMAudioTranscriptionResult", + "LLMEmbeddingResult", + "LLMGenerationOptions", + "LLMImageOptions", + "LLMResponseResult", + "LLMServiceRequest", + "LLMServiceResult", + "MessageFactory", + "PromptInput", + "PromptMessage", +] diff --git a/src/common/data_models/mai_message_data_model.py b/src/common/data_models/mai_message_data_model.py index 4396201a..814f642b 100644 --- a/src/common/data_models/mai_message_data_model.py +++ b/src/common/data_models/mai_message_data_model.py @@ -1,15 +1,17 @@ +import json from dataclasses import dataclass, field -from maim_message import ( - MessageBase, - UserInfo as MaimUserInfo, - GroupInfo as MaimGroupInfo, - BaseMessageInfo as MaimBaseMessageInfo, - Seg, -) +from datetime import datetime from typing import Optional -import json -from datetime import datetime +from maim_message import ( + BaseMessageInfo as MaimBaseMessageInfo, + GroupInfo as MaimGroupInfo, + MessageBase, + ReceiverInfo as MaimReceiverInfo, + Seg, + SenderInfo as MaimSenderInfo, + UserInfo as MaimUserInfo, +) from src.common.database.database_model import Messages from src.common.data_models.message_component_data_model import MessageSequence @@ -41,34 +43,24 @@ class MessageInfo: class MaiMessage(BaseDatabaseDataModel[Messages]): def __init__(self, message_id: str, timestamp: datetime, platform: str): self.message_id: str = message_id - self.timestamp: datetime = timestamp # 时间戳 - self.initialized = False # 用于标记是否已初始化其他属性 + self.timestamp: datetime = timestamp + self.initialized = False self.platform: str = platform - # 定义其他属性 - self.message_info: MessageInfo # 初始化后赋值 + self.message_info: MessageInfo self.is_mentioned: bool = False - """机器人被提及标记,若被at,则提及也被标记""" self.is_at: bool = False - """机器人被at标记""" self.is_emoji: bool = False - """消息为纯表情包,在计算打字时长时候会被特殊处理""" self.is_picture: bool = False - """消息为纯图片,在计算打字时长时候会被特殊处理""" self.is_command: bool = False - """消息为命令消息,打字时长必定为0""" self.is_notify: bool = False - """消息为通知消息""" self.session_id: str self.reply_to: Optional[str] = None self.processed_plain_text: Optional[str] = None - """处理过后的纯文本内容""" self.display_message: Optional[str] = None - """最后显示给大模型的消息内容""" self.raw_message: MessageSequence - """原始消息数据""" @classmethod def from_db_instance(cls, db_record: "Messages"): @@ -79,12 +71,12 @@ class MaiMessage(BaseDatabaseDataModel[Messages]): group_info = GroupInfo(db_record.group_id, db_record.group_name) else: group_info = None + obj.message_info = MessageInfo( user_info=user_info, group_info=group_info, additional_config=json.loads(db_record.additional_config) if db_record.additional_config else {}, ) - obj.is_mentioned = db_record.is_mentioned obj.is_at = db_record.is_at obj.is_emoji = db_record.is_emoji @@ -127,18 +119,22 @@ class MaiMessage(BaseDatabaseDataModel[Messages]): @classmethod def from_maim_message(cls, message: MessageBase): - """从 maim_message.MessageBase 创建 MaiMessage 实例,解析消息内容并提取相关信息""" + """从 maim_message.MessageBase 创建 MaiMessage。""" msg_info = message.message_info assert msg_info, "MessageBase 的 message_info 不能为空" + platform = msg_info.platform assert isinstance(platform, str) + msg_id = str(msg_info.message_id) timestamp = msg_info.time assert isinstance(msg_id, str) assert msg_id assert timestamp + obj = cls(message_id=msg_id, timestamp=datetime.fromtimestamp(timestamp), platform=platform) obj.raw_message = MessageUtils.from_maim_message_segments_to_MaiSeq(message) + usr_info = msg_info.user_info assert usr_info assert isinstance(usr_info.user_id, str) @@ -148,40 +144,69 @@ class MaiMessage(BaseDatabaseDataModel[Messages]): user_nickname=usr_info.user_nickname, user_cardname=usr_info.user_cardname, ) - if grp_info := msg_info.group_info: + + if msg_info.group_info: + grp_info = msg_info.group_info assert isinstance(grp_info.group_id, str) assert isinstance(grp_info.group_name, str) group_info = GroupInfo(group_id=grp_info.group_id, group_name=grp_info.group_name) else: group_info = None + add_cfg = msg_info.additional_config or {} obj.message_info = MessageInfo(user_info=user_info, group_info=group_info, additional_config=add_cfg) return obj async def to_maim_message(self) -> MessageBase: - """ - 从 MaiMessage 实例转换为 maim_message.MessageBase,构建消息内容并设置相关信息 - """ - maim_user_info = MaimUserInfo( + """将 MaiMessage 转换为 maim_message.MessageBase。""" + sender_user_info = MaimUserInfo( user_id=self.message_info.user_info.user_id, user_nickname=self.message_info.user_info.user_nickname, user_cardname=self.message_info.user_info.user_cardname, platform=self.platform, ) - maim_group_info = None + + sender_group_info = None if self.message_info.group_info: - maim_group_info = MaimGroupInfo( + sender_group_info = MaimGroupInfo( group_id=self.message_info.group_info.group_id, group_name=self.message_info.group_info.group_name, platform=self.platform, ) + + sender_info = MaimSenderInfo( + group_info=sender_group_info, + user_info=sender_user_info, + ) + + receiver_group_info = sender_group_info + receiver_user_info = None + additional_config = self.message_info.additional_config or {} + target_user_id = str(additional_config.get("platform_io_target_user_id") or "").strip() + if receiver_group_info is None and target_user_id: + receiver_user_info = MaimUserInfo( + user_id=target_user_id, + user_nickname=None, + user_cardname=None, + platform=self.platform, + ) + + receiver_info = None + if receiver_group_info or receiver_user_info: + receiver_info = MaimReceiverInfo( + group_info=receiver_group_info, + user_info=receiver_user_info, + ) + maim_msg_info = MaimBaseMessageInfo( platform=self.platform, message_id=self.message_id, time=self.timestamp.timestamp(), - group_info=maim_group_info, - user_info=maim_user_info, + group_info=receiver_group_info, + user_info=sender_user_info, additional_config=self.message_info.additional_config, + sender_info=sender_info, + receiver_info=receiver_info, ) msg_segments = await MessageUtils.from_MaiSeq_to_maim_message_segments(self.raw_message) return MessageBase(message_info=maim_msg_info, message_segment=Seg(type="seglist", data=msg_segments)) diff --git a/src/common/data_models/message_component_data_model.py b/src/common/data_models/message_component_data_model.py index 995e54ce..d766cfcf 100644 --- a/src/common/data_models/message_component_data_model.py +++ b/src/common/data_models/message_component_data_model.py @@ -348,17 +348,11 @@ class MessageSequence: if isinstance(item, TextComponent): return {"type": "text", "data": item.text} elif isinstance(item, ImageComponent): - if not item.content: - raise RuntimeError("ImageComponent content 未初始化") - return {"type": "image", "data": item.content, "hash": item.binary_hash} + return {"type": "image", "data": self._ensure_binary_component_content(item, "[图片]"), "hash": item.binary_hash} elif isinstance(item, EmojiComponent): - if not item.content: - raise RuntimeError("EmojiComponent content 未初始化") - return {"type": "emoji", "data": item.content, "hash": item.binary_hash} + return {"type": "emoji", "data": self._ensure_binary_component_content(item, "[表情包]"), "hash": item.binary_hash} elif isinstance(item, VoiceComponent): - if not item.content: - raise RuntimeError("VoiceComponent content 未初始化") - return {"type": "voice", "data": item.content, "hash": item.binary_hash} + return {"type": "voice", "data": self._ensure_binary_component_content(item, "[语音消息]"), "hash": item.binary_hash} elif isinstance(item, AtComponent): return { "type": "at", @@ -388,6 +382,14 @@ class MessageSequence: logger.warning(f"Unofficial component type: {type(item)}, defaulting to DictComponent") return {"type": "dict", "data": item.data} + @staticmethod + def _ensure_binary_component_content(item: ByteComponent, fallback_text: str) -> str: + """确保二进制组件在序列化时带有稳定的文本占位。""" + if item.content: + return item.content + item.content = fallback_text + return item.content + @classmethod def _dict_2_item(cls, item: Dict[str, Any]) -> StandardMessageComponents: """内部方法:将单个消息组件的字典格式转换回组件对象""" diff --git a/src/common/data_models/tool_record_data_model.py b/src/common/data_models/tool_record_data_model.py new file mode 100644 index 00000000..90b594d5 --- /dev/null +++ b/src/common/data_models/tool_record_data_model.py @@ -0,0 +1,59 @@ +from datetime import datetime +from typing import Dict, Optional + +import json + +from src.common.database.database_model import ToolRecord + +from . import BaseDatabaseDataModel + + +class MaiToolRecord(BaseDatabaseDataModel[ToolRecord]): + """工具调用记录数据模型。""" + + def __init__( + self, + tool_id: str, + timestamp: datetime, + session_id: str, + tool_name: str, + tool_reasoning: Optional[str] = None, + tool_data: Optional[Dict] = None, + tool_builtin_prompt: Optional[str] = None, + tool_display_prompt: Optional[str] = None, + ): + self.tool_id = tool_id + self.timestamp = timestamp + self.session_id = session_id + self.tool_name = tool_name + self.tool_reasoning = tool_reasoning + self.tool_data = tool_data or {} + self.tool_builtin_prompt = tool_builtin_prompt + self.tool_display_prompt = tool_display_prompt + + @classmethod + def from_db_instance(cls, db_record: ToolRecord): + """从数据库实例创建数据模型对象。""" + return cls( + tool_id=db_record.tool_id, + timestamp=db_record.timestamp, + session_id=db_record.session_id, + tool_name=db_record.tool_name, + tool_reasoning=db_record.tool_reasoning, + tool_data=json.loads(db_record.tool_data) if db_record.tool_data else None, + tool_builtin_prompt=db_record.tool_builtin_prompt, + tool_display_prompt=db_record.tool_display_prompt, + ) + + def to_db_instance(self): + """将数据模型对象转换为数据库实例。""" + return ToolRecord( + tool_id=self.tool_id, + timestamp=self.timestamp, + session_id=self.session_id, + tool_name=self.tool_name, + tool_reasoning=self.tool_reasoning, + tool_data=json.dumps(self.tool_data) if self.tool_data else None, + tool_builtin_prompt=self.tool_builtin_prompt, + tool_display_prompt=self.tool_display_prompt, + ) diff --git a/src/common/database/database.py b/src/common/database/database.py index e88be9ec..2b22475a 100644 --- a/src/common/database/database.py +++ b/src/common/database/database.py @@ -1,18 +1,23 @@ -from rich.traceback import install from contextlib import contextmanager from pathlib import Path -from typing import Generator, TYPE_CHECKING +from typing import ContextManager, Generator, TYPE_CHECKING -from sqlalchemy import event +from rich.traceback import install +from sqlalchemy import event, text from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker from sqlmodel import SQLModel, Session, create_engine +from src.common.database.migrations import create_database_migration_bootstrapper +from src.common.logger import get_logger + if TYPE_CHECKING: from sqlite3 import Connection as SQLite3Connection install(extra_lines=3) +logger = get_logger("database") + # 定义数据库文件路径 ROOT_PATH = Path(__file__).parent.parent.parent.parent.absolute().resolve() @@ -53,18 +58,70 @@ SessionLocal = sessionmaker( bind=engine, class_=Session, ) +_migration_bootstrapper = create_database_migration_bootstrapper(engine) _db_initialized = False +def _migrate_action_records_to_tool_records() -> None: + """将旧的 ``action_records`` 历史数据迁移到 ``tool_records``。""" + migration_sql = text( + """ + INSERT INTO tool_records ( + tool_id, + timestamp, + session_id, + tool_name, + tool_reasoning, + tool_data, + tool_builtin_prompt, + tool_display_prompt + ) + SELECT + action_id, + timestamp, + session_id, + action_name, + action_reasoning, + action_data, + action_builtin_prompt, + action_display_prompt + FROM action_records + WHERE NOT EXISTS ( + SELECT 1 + FROM tool_records + WHERE tool_records.tool_id = action_records.action_id + ) + """ + ) + with engine.begin() as connection: + connection.execute(migration_sql) + + def initialize_database() -> None: + """初始化数据库连接、结构与启动期迁移。 + + 当前初始化流程遵循以下顺序: + 1. 确保数据库目录存在; + 2. 加载 SQLModel 模型定义; + 3. 执行已注册的启动期迁移; + 4. 兜底执行 ``create_all`` 确保当前模型定义已建表; + 5. 执行项目现有的轻量数据补迁移逻辑。 + """ global _db_initialized if _db_initialized: return _DB_DIR.mkdir(parents=True, exist_ok=True) import src.common.database.database_model # noqa: F401 + migration_state = _migration_bootstrapper.prepare_database() + logger.info( + "数据库迁移准备完成," + f" 当前版本={migration_state.resolved_version.version},目标版本={migration_state.target_version}" + ) SQLModel.metadata.create_all(engine) + _migrate_action_records_to_tool_records() + _migration_bootstrapper.finalize_database(migration_state) _db_initialized = True @@ -114,8 +171,12 @@ def get_db_session(auto_commit: bool = True) -> Generator[Session, None, None]: session.close() -def get_db_session_manual(): - """获取数据库会话的上下文管理器 (手动提交模式)。""" +def get_db_session_manual() -> ContextManager[Session]: + """获取数据库会话的上下文管理器 (手动提交模式)。 + + Returns: + ContextManager[Session]: 手动提交模式的数据库会话上下文管理器。 + """ return get_db_session(auto_commit=False) diff --git a/src/common/database/database_model.py b/src/common/database/database_model.py index 5b274c43..9874af67 100644 --- a/src/common/database/database_model.py +++ b/src/common/database/database_model.py @@ -2,7 +2,7 @@ from datetime import datetime from enum import Enum from typing import Optional -from sqlalchemy import Column, DateTime, Enum as SQLEnum, Float +from sqlalchemy import Column, DateTime, Enum as SQLEnum, Float, Text from sqlmodel import Field, LargeBinary, SQLModel @@ -17,8 +17,8 @@ class ImageType(str, Enum): class ModifiedBy(str, Enum): - AI = "ai" - USER = "user" + AI = "AI" + USER = "USER" class Messages(SQLModel, table=True): @@ -134,6 +134,27 @@ class ActionRecord(SQLModel, table=True): action_display_prompt: Optional[str] = Field(default=None) # 最终输入到Prompt的内容 +class ToolRecord(SQLModel, table=True): + """存储工具调用记录""" + + __tablename__ = "tool_records" # type: ignore + + id: Optional[int] = Field(default=None, primary_key=True) # 自增主键 + + # 元信息 + tool_id: str = Field(index=True, max_length=255) # 工具调用ID + timestamp: datetime = Field(default_factory=datetime.now, sa_column=Column(DateTime, index=True)) # 记录时间戳 + session_id: str = Field(index=True, max_length=255) # 对应的 ChatSession session_id + + # 调用信息 + tool_name: str = Field(index=True, max_length=255) # 工具名称 + tool_reasoning: Optional[str] = Field(default=None) # 工具调用推理过程 + tool_data: Optional[str] = Field(default=None) # 工具数据,JSON格式存储 + + tool_builtin_prompt: Optional[str] = Field(default=None) # 内置工具提示 + tool_display_prompt: Optional[str] = Field(default=None) # 最终输入到 Prompt 的内容 + + class CommandRecord(SQLModel, table=True): """记录命令执行情况""" @@ -202,18 +223,40 @@ class Jargon(SQLModel, table=True): id: Optional[int] = Field(default=None, primary_key=True) # 自增主键 content: str = Field(index=True, max_length=255) # 黑话内容 - raw_content: Optional[str] = Field(default=None, nullable=True) # 原始内容,未处理的黑话内容,为List[str] + raw_content: Optional[str] = Field( + default=None, sa_column=Column(Text, nullable=True) + ) # 原始内容,未处理的黑话内容,为List[str] - meaning: str # 黑话含义 - session_id_dict: str = Field(default=r"{}") # 会话ID列表,格式为{"session_id": session_count, ...} + meaning: str = Field(sa_column=Column(Text, nullable=False)) # 黑话含义 + session_id_dict: str = Field( + default=r"{}", sa_column=Column(Text, nullable=False) + ) # 会话ID列表,格式为{"session_id": session_count, ...} count: int = Field(default=0) # 使用次数 is_jargon: Optional[bool] = Field(default=True) # 是否为黑话,False表示为白话 is_complete: bool = Field(default=False) # 是否为已经完成全部推断(count > 100后不再推断) is_global: bool = Field(default=False) # 是否为全局黑话(独立于session_id_dict) last_inference_count: int = Field(default=0) # 上一次进行推断时的count值,用于判断是否需要重新推断 - inference_with_context: Optional[str] = Field(default=None, nullable=True) # 带上下文的推断结果,JSON格式 - inference_with_content_only: Optional[str] = Field(default=None, nullable=True) # 只基于词条的推断结果,JSON格式 + inference_with_context: Optional[str] = Field( + default=None, sa_column=Column(Text, nullable=True) + ) # 带上下文的推断结果,JSON格式 + inference_with_content_only: Optional[str] = Field( + default=None, sa_column=Column(Text, nullable=True) + ) # 只基于词条的推断结果,JSON格式 + + +class MaiKnowledge(SQLModel, table=True): + """存储 Maisaka 的用户画像知识。""" + + __tablename__ = "mai_knowledge" # type: ignore + + id: Optional[int] = Field(default=None, primary_key=True) + knowledge_id: str = Field(index=True, max_length=255) + category_id: str = Field(index=True, max_length=32) + content: str + normalized_content: str = Field(index=True) + metadata_json: Optional[str] = Field(default=None, nullable=True) + created_at: datetime = Field(default_factory=datetime.now, sa_column=Column(DateTime, index=True)) class ChatHistory(SQLModel, table=True): diff --git a/src/common/database/migrations/__init__.py b/src/common/database/migrations/__init__.py new file mode 100644 index 00000000..e9a69bd1 --- /dev/null +++ b/src/common/database/migrations/__init__.py @@ -0,0 +1,79 @@ +"""数据库迁移基础设施导出模块。""" + +from .bootstrap import DatabaseMigrationBootstrapper, create_database_migration_bootstrapper +from .builtin import ( + EMPTY_SCHEMA_VERSION, + LATEST_SCHEMA_VERSION, + LEGACY_V1_SCHEMA_VERSION, + build_default_migration_registry, + build_default_schema_version_resolver, +) +from .exceptions import ( + DatabaseMigrationConfigurationError, + DatabaseMigrationError, + DatabaseMigrationExecutionError, + DatabaseMigrationPlanningError, + DatabaseMigrationVersionError, + MissingMigrationStepError, + UnrecognizedDatabaseSchemaError, + UnsupportedMigrationDirectionError, +) +from .manager import DatabaseMigrationManager +from .models import ( + ColumnSchema, + DatabaseMigrationState, + DatabaseSchemaSnapshot, + MigrationExecutionContext, + MigrationPlan, + MigrationStep, + ResolvedSchemaVersion, + SchemaVersionSource, + TableSchema, +) +from .planner import MigrationPlanner +from .progress import ( + BaseMigrationProgressReporter, + RichMigrationProgressReporter, + create_rich_migration_progress_reporter, +) +from .registry import MigrationRegistry +from .resolver import BaseSchemaVersionDetector, SchemaVersionResolver +from .schema import SQLiteSchemaInspector +from .version_store import SQLiteUserVersionStore + +__all__ = [ + "BaseSchemaVersionDetector", + "BaseMigrationProgressReporter", + "build_default_migration_registry", + "build_default_schema_version_resolver", + "ColumnSchema", + "create_database_migration_bootstrapper", + "create_rich_migration_progress_reporter", + "DatabaseMigrationConfigurationError", + "DatabaseMigrationError", + "DatabaseMigrationBootstrapper", + "DatabaseMigrationExecutionError", + "DatabaseMigrationManager", + "DatabaseMigrationPlanningError", + "DatabaseMigrationState", + "DatabaseMigrationVersionError", + "DatabaseSchemaSnapshot", + "EMPTY_SCHEMA_VERSION", + "LATEST_SCHEMA_VERSION", + "LEGACY_V1_SCHEMA_VERSION", + "MigrationExecutionContext", + "MigrationPlan", + "MigrationPlanner", + "MigrationRegistry", + "MigrationStep", + "MissingMigrationStepError", + "ResolvedSchemaVersion", + "RichMigrationProgressReporter", + "SchemaVersionResolver", + "SchemaVersionSource", + "SQLiteSchemaInspector", + "SQLiteUserVersionStore", + "TableSchema", + "UnrecognizedDatabaseSchemaError", + "UnsupportedMigrationDirectionError", +] diff --git a/src/common/database/migrations/bootstrap.py b/src/common/database/migrations/bootstrap.py new file mode 100644 index 00000000..a7a0a779 --- /dev/null +++ b/src/common/database/migrations/bootstrap.py @@ -0,0 +1,171 @@ +"""数据库迁移启动桥接层。""" + +from typing import Optional + +from sqlalchemy.engine import Engine + +from src.common.logger import get_logger + +from .builtin import ( + LATEST_SCHEMA_VERSION, + build_default_migration_registry, + build_default_schema_version_resolver, +) +from .exceptions import DatabaseMigrationExecutionError +from .manager import DatabaseMigrationManager +from .models import DatabaseMigrationState, MigrationPlan, ResolvedSchemaVersion, SchemaVersionSource +from .registry import MigrationRegistry +from .resolver import SchemaVersionResolver +from .version_store import SQLiteUserVersionStore + +logger = get_logger("database_migration") + + +class DatabaseMigrationBootstrapper: + """数据库迁移启动桥接器。 + + 该桥接器负责把数据库迁移基础设施接入现有启动流程,同时保持如下约束: + 1. 若数据库为空,则直接交给当前模型定义建出最新结构; + 2. 若数据库版本高于当前代码支持的最新版本,则立即终止启动; + 3. 若存在待执行迁移步骤,则在正常建表流程之前先执行迁移; + 4. 若数据库已是最新结构但尚未写入 ``user_version``,则在建表后补写版本号。 + """ + + def __init__( + self, + manager: DatabaseMigrationManager, + latest_schema_version: int = LATEST_SCHEMA_VERSION, + ) -> None: + """初始化数据库迁移启动桥接器。 + + Args: + manager: 数据库迁移编排器。 + latest_schema_version: 当前代码支持的最新 schema 版本号。 + """ + self.manager = manager + self.latest_schema_version = latest_schema_version + + def prepare_database(self) -> DatabaseMigrationState: + """为数据库初始化阶段准备迁移状态。 + + Returns: + DatabaseMigrationState: 迁移准备完成后的数据库状态。 + + Raises: + DatabaseMigrationExecutionError: 当数据库版本高于当前代码支持版本时抛出。 + """ + with self.manager.engine.connect() as connection: + resolved_version = self.manager.resolver.resolve(connection) + + if resolved_version.version > self.latest_schema_version: + raise DatabaseMigrationExecutionError( + "当前数据库版本高于代码内注册的最新迁移版本,已拒绝继续启动。" + f" 数据库版本={resolved_version.version},代码支持版本={self.latest_schema_version}" + ) + + if resolved_version.source == SchemaVersionSource.EMPTY_DATABASE: + logger.info( + "检测到空数据库,将直接根据当前模型创建最新结构。" + f" 目标版本={self.latest_schema_version}" + ) + return self._build_noop_state( + current_version=resolved_version.version, + target_version=self.latest_schema_version, + resolved_state=resolved_version, + ) + + migration_state = self.manager.describe_state(target_version=self.latest_schema_version) + if not migration_state.requires_migration(): + logger.info( + f"数据库 schema 已是目标版本,无需迁移。当前版本={migration_state.resolved_version.version}" + ) + return migration_state + + logger.info( + "检测到数据库需要迁移," + f" 当前版本={migration_state.resolved_version.version},目标版本={migration_state.target_version}" + ) + self.manager.migrate(target_version=self.latest_schema_version) + return self.manager.describe_state(target_version=self.latest_schema_version) + + def finalize_database(self, migration_state: DatabaseMigrationState) -> None: + """在数据库初始化末尾补写最终 schema 版本号。 + + 该方法主要负责两类场景: + 1. 空库首次建表完成后,将 ``user_version`` 写入为最新版本; + 2. 已是最新结构但此前未写入 ``user_version`` 的数据库,补写版本号。 + + Args: + migration_state: 初始化前解析得到的迁移状态。 + """ + if migration_state.requires_migration(): + return + if migration_state.target_version <= 0: + return + if migration_state.resolved_version.source == SchemaVersionSource.PRAGMA: + return + + with self.manager.engine.begin() as connection: + self.manager.version_store.write_version(connection, migration_state.target_version) + + logger.info( + "数据库 schema 版本写入完成。" + f" 来源={migration_state.resolved_version.source.value}," + f" 写入版本={migration_state.target_version}" + ) + + def _build_noop_state( + self, + current_version: int, + target_version: int, + resolved_state: ResolvedSchemaVersion, + ) -> DatabaseMigrationState: + """构建无迁移动作的数据库状态对象。 + + Args: + current_version: 当前数据库版本号。 + target_version: 当前初始化流程期望达到的目标版本号。 + resolved_state: 已解析的数据库版本状态。 + + Returns: + DatabaseMigrationState: 不包含迁移步骤的状态对象。 + """ + return DatabaseMigrationState( + resolved_version=resolved_state, + target_version=target_version, + plan=MigrationPlan(current_version=current_version, target_version=target_version, steps=[]), + ) + + +def create_database_migration_bootstrapper( + engine: Engine, + registry: Optional[MigrationRegistry] = None, + resolver: Optional[SchemaVersionResolver] = None, + version_store: Optional[SQLiteUserVersionStore] = None, + latest_schema_version: int = LATEST_SCHEMA_VERSION, +) -> DatabaseMigrationBootstrapper: + """创建数据库迁移启动桥接器。 + + Args: + engine: 目标数据库引擎。 + registry: 迁移步骤注册表;未提供时使用默认注册表。 + resolver: 数据库版本解析器;未提供时使用默认解析器。 + version_store: 版本存储器;未提供时使用默认存储器。 + latest_schema_version: 当前代码支持的最新 schema 版本号。 + + Returns: + DatabaseMigrationBootstrapper: 配置完成的数据库迁移启动桥接器。 + """ + migration_registry = registry or build_default_migration_registry() + migration_resolver = resolver or build_default_schema_version_resolver() + migration_version_store = version_store or SQLiteUserVersionStore() + migration_manager = DatabaseMigrationManager( + engine=engine, + registry=migration_registry, + resolver=migration_resolver, + version_store=migration_version_store, + ) + return DatabaseMigrationBootstrapper( + manager=migration_manager, + latest_schema_version=latest_schema_version, + ) diff --git a/src/common/database/migrations/builtin.py b/src/common/database/migrations/builtin.py new file mode 100644 index 00000000..5b16780b --- /dev/null +++ b/src/common/database/migrations/builtin.py @@ -0,0 +1,159 @@ +"""数据库迁移内置版本与默认注册表。""" + +from typing import List, Optional + +from .legacy_v1_to_v2 import migrate_legacy_v1_to_v2 +from .models import DatabaseSchemaSnapshot, MigrationStep +from .registry import MigrationRegistry +from .resolver import BaseSchemaVersionDetector, SchemaVersionResolver +from .version_store import SQLiteUserVersionStore +from .schema import SQLiteSchemaInspector + +EMPTY_SCHEMA_VERSION = 0 +LEGACY_V1_SCHEMA_VERSION = 1 +LATEST_SCHEMA_VERSION = 2 + +_LEGACY_V1_EXCLUSIVE_TABLES = ( + "chat_streams", + "emoji", + "emoji_description_cache", + "expression", + "group_info", + "image_descriptions", + "jargon", + "messages", + "thinking_back", +) + + +class LatestSchemaVersionDetector(BaseSchemaVersionDetector): + """当前最新 schema 结构探测器。""" + + @property + def name(self) -> str: + """返回探测器名称。 + + Returns: + str: 当前探测器名称。 + """ + return "latest_schema_detector" + + def detect_version(self, snapshot: DatabaseSchemaSnapshot) -> Optional[int]: + """检测数据库是否已经是当前最新结构。 + + Args: + snapshot: 当前数据库结构快照。 + + Returns: + Optional[int]: 若识别为最新结构则返回最新版本号,否则返回 ``None``。 + """ + if any(snapshot.has_table(table_name) for table_name in _LEGACY_V1_EXCLUSIVE_TABLES): + return None + + latest_marker_tables = ( + "mai_messages", + "chat_sessions", + "expressions", + "jargons", + "thinking_questions", + "tool_records", + ) + if not all(snapshot.has_table(table_name) for table_name in latest_marker_tables): + return None + if not snapshot.has_column("images", "image_hash"): + return None + if not snapshot.has_column("images", "full_path"): + return None + if not snapshot.has_column("images", "image_type"): + return None + if not snapshot.has_column("action_records", "session_id"): + return None + if not snapshot.has_column("chat_history", "session_id"): + return None + if not snapshot.has_column("person_info", "user_nickname"): + return None + return LATEST_SCHEMA_VERSION + + +class LegacyV1SchemaDetector(BaseSchemaVersionDetector): + """旧版 ``0.x`` schema 结构探测器。""" + + @property + def name(self) -> str: + """返回探测器名称。 + + Returns: + str: 当前探测器名称。 + """ + return "legacy_v1_schema_detector" + + def detect_version(self, snapshot: DatabaseSchemaSnapshot) -> Optional[int]: + """检测数据库是否为旧版 ``0.x`` 结构。 + + Args: + snapshot: 当前数据库结构快照。 + + Returns: + Optional[int]: 若识别为旧版结构则返回 ``1``,否则返回 ``None``。 + """ + if any(snapshot.has_table(table_name) for table_name in _LEGACY_V1_EXCLUSIVE_TABLES): + return LEGACY_V1_SCHEMA_VERSION + + legacy_shared_markers = ( + ("action_records", ("chat_id", "time")), + ("chat_history", ("chat_id", "original_text")), + ("images", ("emoji_hash", "path", "type")), + ("llm_usage", ("model_api_provider", "status")), + ("online_time", ("duration",)), + ("person_info", ("nickname", "group_nick_name")), + ) + for table_name, required_columns in legacy_shared_markers: + if snapshot.has_table(table_name) and all( + snapshot.has_column(table_name, column_name) for column_name in required_columns + ): + return LEGACY_V1_SCHEMA_VERSION + return None + + +def build_default_schema_version_detectors() -> List[BaseSchemaVersionDetector]: + """构建默认 schema 版本探测器链。 + + Returns: + List[BaseSchemaVersionDetector]: 按优先级排序的探测器列表。 + """ + return [ + LatestSchemaVersionDetector(), + LegacyV1SchemaDetector(), + ] + + +def build_default_schema_version_resolver() -> SchemaVersionResolver: + """构建默认 schema 版本解析器。 + + Returns: + SchemaVersionResolver: 配置完成的 schema 版本解析器。 + """ + return SchemaVersionResolver( + version_store=SQLiteUserVersionStore(), + schema_inspector=SQLiteSchemaInspector(), + detectors=build_default_schema_version_detectors(), + ) + + +def build_default_migration_registry() -> MigrationRegistry: + """构建默认迁移步骤注册表。 + + Returns: + MigrationRegistry: 含默认迁移步骤的注册表实例。 + """ + return MigrationRegistry( + steps=[ + MigrationStep( + version_from=LEGACY_V1_SCHEMA_VERSION, + version_to=LATEST_SCHEMA_VERSION, + name="legacy_v1_to_latest_v2", + description="将旧版 0.x 数据库整体迁移到当前最新 schema。", + handler=migrate_legacy_v1_to_v2, + ) + ] + ) diff --git a/src/common/database/migrations/exceptions.py b/src/common/database/migrations/exceptions.py new file mode 100644 index 00000000..7f0a667d --- /dev/null +++ b/src/common/database/migrations/exceptions.py @@ -0,0 +1,33 @@ +"""数据库迁移基础设施异常定义。""" + + +class DatabaseMigrationError(Exception): + """数据库迁移基础异常。""" + + +class DatabaseMigrationConfigurationError(DatabaseMigrationError): + """数据库迁移配置不合法。""" + + +class DatabaseMigrationPlanningError(DatabaseMigrationError): + """数据库迁移计划生成失败。""" + + +class DatabaseMigrationExecutionError(DatabaseMigrationError): + """数据库迁移执行失败。""" + + +class DatabaseMigrationVersionError(DatabaseMigrationError): + """数据库版本读写或校验失败。""" + + +class MissingMigrationStepError(DatabaseMigrationPlanningError): + """缺少某个版本区间所需的迁移步骤。""" + + +class UnsupportedMigrationDirectionError(DatabaseMigrationPlanningError): + """当前迁移方向不被支持。""" + + +class UnrecognizedDatabaseSchemaError(DatabaseMigrationVersionError): + """无法识别未标记版本数据库的结构。""" diff --git a/src/common/database/migrations/legacy_v1_to_v2.py b/src/common/database/migrations/legacy_v1_to_v2.py new file mode 100644 index 00000000..c1f88dd0 --- /dev/null +++ b/src/common/database/migrations/legacy_v1_to_v2.py @@ -0,0 +1,1489 @@ +"""旧版 ``0.x`` 数据库升级到最新 schema 的迁移逻辑。""" + +from __future__ import annotations + +from collections.abc import Mapping +from dataclasses import dataclass +from datetime import datetime +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, cast + +from sqlalchemy import text +from sqlalchemy.engine import Connection + +import json +import msgpack + +from src.common.logger import get_logger + +from .exceptions import DatabaseMigrationExecutionError +from .models import DatabaseSchemaSnapshot, MigrationExecutionContext +from .schema import SQLiteSchemaInspector + +logger = get_logger("database_migration") + +_LEGACY_V1_BACKUP_PREFIX = "__legacy_v1_" +_LEGACY_V1_TABLE_NAMES = ( + "action_records", + "chat_history", + "chat_streams", + "emoji", + "emoji_description_cache", + "expression", + "group_info", + "image_descriptions", + "images", + "jargon", + "llm_usage", + "messages", + "online_time", + "person_info", + "thinking_back", +) +_EMPTY_MESSAGE_SEQUENCE_BYTES = msgpack.packb([], use_bin_type=True) + + +@dataclass(frozen=True) +class LegacyTableData: + """旧版表数据快照。""" + + source_table_name: str + columns: Set[str] + rows: List[Dict[str, Any]] + + +def migrate_legacy_v1_to_v2(context: MigrationExecutionContext) -> None: + """执行旧版 ``0.x`` 数据库到最新 schema 的迁移。 + + Args: + context: 当前迁移步骤执行上下文。 + """ + from sqlmodel import SQLModel + + import src.common.database.database_model # noqa: F401 + + schema_inspector = SQLiteSchemaInspector() + snapshot = schema_inspector.inspect(context.connection) + _rename_legacy_v1_tables(context.connection, snapshot) + SQLModel.metadata.create_all(context.connection) + + table_migration_jobs: List[Tuple[str, Callable[[MigrationExecutionContext], int]]] = [ + ("chat_sessions", _migrate_chat_sessions), + ("llm_usage", _migrate_model_usage), + ("images", _migrate_images), + ("mai_messages", _migrate_messages), + ("action_records", _migrate_action_records), + ("tool_records", _migrate_tool_records), + ("online_time", _migrate_online_time), + ("person_info", _migrate_person_info), + ("expressions", _migrate_expressions), + ("jargons", _migrate_jargons), + ("chat_history", _migrate_chat_history), + ("thinking_questions", _migrate_thinking_questions), + ] + migrated_counts: Dict[str, int] = {} + total_record_count = _estimate_total_record_count(context.connection) + context.start_progress( + total_tables=len(table_migration_jobs), + total_records=total_record_count, + description="总迁移进度", + table_unit_name="表", + record_unit_name="记录", + ) + for table_name, migration_handler in table_migration_jobs: + migrated_counts[table_name] = migration_handler(context) + + summary_text = ", ".join(f"{table_name}={count}" for table_name, count in migrated_counts.items()) + logger.info(f"旧版数据库迁移完成: {summary_text}") + + +def _legacy_backup_table_name(table_name: str) -> str: + """构建旧版表的备份表名。 + + Args: + table_name: 旧版原始表名。 + + Returns: + str: 带前缀的备份表名。 + """ + return f"{_LEGACY_V1_BACKUP_PREFIX}{table_name}" + + +def _quote_identifier(identifier: str) -> str: + """为 SQLite 标识符添加安全引号。 + + Args: + identifier: 待引用的标识符。 + + Returns: + str: 可安全拼接到 SQL 中的标识符。 + """ + escaped_identifier = identifier.replace('"', '""') + return f'"{escaped_identifier}"' + + +def _rename_legacy_v1_tables(connection: Connection, snapshot: DatabaseSchemaSnapshot) -> None: + """将旧版表统一改名为带备份前缀的表名。 + + Args: + connection: 当前数据库连接。 + snapshot: 当前数据库结构快照。 + + Raises: + DatabaseMigrationExecutionError: 当发现同名旧表与备份表同时存在时抛出。 + """ + for table_name in _LEGACY_V1_TABLE_NAMES: + if not snapshot.has_table(table_name): + continue + backup_table_name = _legacy_backup_table_name(table_name) + if snapshot.has_table(backup_table_name): + raise DatabaseMigrationExecutionError( + "检测到旧版表与迁移备份表同时存在,无法安全继续迁移。" + f" 冲突表={table_name},备份表={backup_table_name}" + ) + connection.execute( + text( + f"ALTER TABLE {_quote_identifier(table_name)} " + f"RENAME TO {_quote_identifier(backup_table_name)}" + ) + ) + + +def _load_legacy_table_data(connection: Connection, original_table_name: str) -> Optional[LegacyTableData]: + """加载单张旧版备份表的数据快照。 + + Args: + connection: 当前数据库连接。 + original_table_name: 旧版原始表名。 + + Returns: + Optional[LegacyTableData]: 若备份表存在则返回其数据快照,否则返回 ``None``。 + """ + backup_table_name = _legacy_backup_table_name(original_table_name) + schema_inspector = SQLiteSchemaInspector() + if not schema_inspector.table_exists(connection, backup_table_name): + return None + + table_schema = schema_inspector.get_table_schema(connection, backup_table_name) + rows = connection.execute(text(f"SELECT * FROM {_quote_identifier(backup_table_name)}")).mappings().all() + return LegacyTableData( + source_table_name=backup_table_name, + columns=set(table_schema.columns), + rows=[dict(row) for row in rows], + ) + + +def _normalize_optional_text(value: Any) -> Optional[str]: + """将任意值标准化为可空字符串。 + + Args: + value: 待标准化的原始值。 + + Returns: + Optional[str]: 标准化后的文本;若值为空则返回 ``None``。 + """ + if value is None: + return None + text_value = str(value).strip() + return text_value or None + + +def _normalize_required_text(value: Any, default: str = "") -> str: + """将任意值标准化为非空字符串。 + + Args: + value: 待标准化的原始值。 + default: 为空时使用的默认值。 + + Returns: + str: 标准化后的字符串。 + """ + normalized_value = _normalize_optional_text(value) + if normalized_value is None: + return default + return normalized_value + + +def _normalize_int(value: Any, default: int = 0) -> int: + """将任意值标准化为整数。 + + Args: + value: 待标准化的原始值。 + default: 转换失败时的默认值。 + + Returns: + int: 标准化后的整数。 + """ + if value is None or value == "": + return default + try: + return int(value) + except (TypeError, ValueError): + return default + + +def _normalize_float(value: Any, default: float = 0.0) -> float: + """将任意值标准化为浮点数。 + + Args: + value: 待标准化的原始值。 + default: 转换失败时的默认值。 + + Returns: + float: 标准化后的浮点数。 + """ + if value is None or value == "": + return default + try: + return float(value) + except (TypeError, ValueError): + return default + + +def _normalize_optional_bool(value: Any) -> Optional[bool]: + """将任意值标准化为可空布尔值。 + + Args: + value: 待标准化的原始值。 + + Returns: + Optional[bool]: 标准化后的布尔值;若无法确定则返回 ``None``。 + """ + if value is None: + return None + if isinstance(value, bool): + return value + if isinstance(value, (int, float)): + return bool(int(value)) + + normalized_text = str(value).strip().lower() + if normalized_text in {"", "null", "none"}: + return None + if normalized_text in {"1", "true", "t", "yes", "y"}: + return True + if normalized_text in {"0", "false", "f", "no", "n"}: + return False + return None + + +def _normalize_bool(value: Any, default: bool = False) -> bool: + """将任意值标准化为布尔值。 + + Args: + value: 待标准化的原始值。 + default: 无法识别时的默认值。 + + Returns: + bool: 标准化后的布尔值。 + """ + parsed_value = _normalize_optional_bool(value) + return default if parsed_value is None else parsed_value + + +def _coerce_datetime(value: Any, fallback_now: bool = False) -> Optional[datetime]: + """将旧版时间字段标准化为 ``datetime``。 + + Args: + value: 待转换的原始值。 + fallback_now: 转换失败时是否回退到当前时间。 + + Returns: + Optional[datetime]: 转换后的时间对象。 + """ + if value is None or value == "": + return datetime.now() if fallback_now else None + if isinstance(value, datetime): + return value + if isinstance(value, (int, float)): + try: + return datetime.fromtimestamp(float(value)) + except (OSError, OverflowError, ValueError): + return datetime.now() if fallback_now else None + + normalized_text = str(value).strip() + if not normalized_text: + return datetime.now() if fallback_now else None + try: + return datetime.fromtimestamp(float(normalized_text)) + except (TypeError, ValueError, OSError, OverflowError): + pass + try: + return datetime.fromisoformat(normalized_text.replace("Z", "+00:00")) + except ValueError: + return datetime.now() if fallback_now else None + + +def _normalize_string_list(value: Any) -> List[str]: + """将旧版文本或 JSON 字段规范化为字符串列表。 + + Args: + value: 待标准化的原始值。 + + Returns: + List[str]: 规范化后的字符串列表。 + """ + if value is None: + return [] + if isinstance(value, list): + return [str(item).strip() for item in value if str(item).strip()] + + normalized_text = str(value).strip() + if not normalized_text: + return [] + try: + parsed_value = json.loads(normalized_text) + except json.JSONDecodeError: + return [normalized_text] + + if isinstance(parsed_value, list): + return [str(item).strip() for item in parsed_value if str(item).strip()] + if isinstance(parsed_value, str): + parsed_text = parsed_value.strip() + return [parsed_text] if parsed_text else [] + if parsed_value is None: + return [] + return [str(parsed_value).strip()] + + +def _normalize_json_dict_text(value: Any) -> Optional[str]: + """将旧版附加配置标准化为 JSON 字典字符串。 + + Args: + value: 待标准化的原始值。 + + Returns: + Optional[str]: 合法的 JSON 字典字符串;若无内容则返回 ``None``。 + """ + if value is None: + return None + if isinstance(value, dict): + return json.dumps(value, ensure_ascii=False) + + normalized_text = str(value).strip() + if not normalized_text: + return None + try: + parsed_value = json.loads(normalized_text) + except json.JSONDecodeError: + return json.dumps({"_legacy_additional_config_raw": normalized_text}, ensure_ascii=False) + + if isinstance(parsed_value, dict): + return json.dumps(parsed_value, ensure_ascii=False) + return json.dumps({"_legacy_additional_config_raw": parsed_value}, ensure_ascii=False) + + +def _normalize_group_cardname_json(value: Any) -> Optional[str]: + """将旧版群昵称字段转换为当前使用的 JSON 结构。 + + Args: + value: 旧版 ``group_nick_name`` 字段值。 + + Returns: + Optional[str]: 新版 ``group_cardname`` JSON 字符串。 + """ + if value is None: + return None + + normalized_text = str(value).strip() + if not normalized_text: + return None + try: + parsed_value = json.loads(normalized_text) + except json.JSONDecodeError: + return None + + if not isinstance(parsed_value, list): + return None + + normalized_items: List[Dict[str, str]] = [] + for item in parsed_value: + if not isinstance(item, Mapping): + continue + group_id = _normalize_required_text(item.get("group_id")) + group_cardname = _normalize_required_text(item.get("group_cardname") or item.get("group_nick_name")) + if not group_id or not group_cardname: + continue + normalized_items.append( + { + "group_id": group_id, + "group_cardname": group_cardname, + } + ) + if not normalized_items: + return None + return json.dumps(normalized_items, ensure_ascii=False) + + +def _normalize_modified_by(value: Any) -> Optional[str]: + """将旧版审核来源字段标准化为当前枚举名称。 + + Args: + value: 待标准化的原始值。 + + Returns: + Optional[str]: 若能识别则返回 ``AI`` / ``USER``,否则返回 ``None``。 + """ + normalized_text = _normalize_required_text(value).lower() + if normalized_text in {"", "null", "none"}: + return None + if normalized_text in {"ai"}: + return "AI" + if normalized_text in {"user"}: + return "USER" + return None + + +def _build_session_id_dict(value: Any, fallback_count: int) -> str: + """将旧版 ``chat_id`` 字段转换为新版 ``session_id_dict``。 + + Args: + value: 旧版 ``chat_id`` 字段值。 + fallback_count: 默认引用次数。 + + Returns: + str: 新版 ``session_id_dict`` JSON 字符串。 + """ + if value is None: + return json.dumps({}, ensure_ascii=False) + + normalized_text = str(value).strip() + if not normalized_text: + return json.dumps({}, ensure_ascii=False) + try: + parsed_value = json.loads(normalized_text) + except json.JSONDecodeError: + return json.dumps({normalized_text: max(fallback_count, 1)}, ensure_ascii=False) + + if isinstance(parsed_value, str): + parsed_text = parsed_value.strip() + if not parsed_text: + return json.dumps({}, ensure_ascii=False) + return json.dumps({parsed_text: max(fallback_count, 1)}, ensure_ascii=False) + if not isinstance(parsed_value, list): + return json.dumps({}, ensure_ascii=False) + + session_counts: Dict[str, int] = {} + for item in parsed_value: + if not isinstance(item, list) or not item: + continue + session_id = _normalize_required_text(item[0]) + if not session_id: + continue + session_count = fallback_count + if len(item) > 1: + session_count = _normalize_int(item[1], default=fallback_count) + session_counts[session_id] = max(session_count, 1) + return json.dumps(session_counts, ensure_ascii=False) + + +def _build_legacy_message_additional_config(row: Mapping[str, Any]) -> Optional[str]: + """构建新版消息表使用的附加配置 JSON。 + + Args: + row: 旧版消息表行数据。 + + Returns: + Optional[str]: 新版消息表 ``additional_config`` 字段内容。 + """ + additional_config_text = _normalize_json_dict_text(row.get("additional_config")) + if additional_config_text: + merged_config = json.loads(additional_config_text) + else: + merged_config = {} + + legacy_fields = { + "intercept_message_level": row.get("intercept_message_level"), + "interest_value": row.get("interest_value"), + "key_words": row.get("key_words"), + "key_words_lite": row.get("key_words_lite"), + "priority_info": row.get("priority_info"), + "priority_mode": row.get("priority_mode"), + "selected_expressions": row.get("selected_expressions"), + } + for field_name, field_value in legacy_fields.items(): + if field_value is None: + continue + merged_config[field_name] = field_value + + if not merged_config: + return None + return json.dumps(merged_config, ensure_ascii=False) + + +def _build_message_raw_content(processed_plain_text: Optional[str], display_message: Optional[str]) -> bytes: + """为旧版消息构造一个可被当前代码读取的占位 ``raw_content``。 + + Args: + processed_plain_text: 旧版消息的处理后文本。 + display_message: 旧版消息的展示文本。 + + Returns: + bytes: 可被当前消息模型安全反序列化的 msgpack 字节串。 + """ + message_text = _normalize_optional_text(display_message) or _normalize_optional_text(processed_plain_text) + if not message_text: + return cast(bytes, _EMPTY_MESSAGE_SEQUENCE_BYTES) + serialized_payload = [{"type": "text", "data": message_text}] + return cast(bytes, msgpack.packb(serialized_payload, use_bin_type=True)) + + +def _deduce_image_type_name(value: Any) -> str: + """将旧版图片类型转换为当前枚举名称。 + + Args: + value: 旧版图片类型字段值。 + + Returns: + str: 当前 ``ImageType`` 枚举在数据库中的文本值。 + """ + normalized_text = _normalize_required_text(value, default="image").lower() + if normalized_text == "emoji": + return "EMOJI" + return "IMAGE" + + +def _count_legacy_table_rows(connection: Connection, original_table_name: str) -> int: + """统计单张旧版备份表中的记录总数。 + + Args: + connection: 当前数据库连接。 + original_table_name: 旧版原始表名。 + + Returns: + int: 备份表中的记录数;若表不存在则返回 ``0``。 + """ + backup_table_name = _legacy_backup_table_name(original_table_name) + schema_inspector = SQLiteSchemaInspector() + if not schema_inspector.table_exists(connection, backup_table_name): + return 0 + row = connection.execute( + text(f"SELECT COUNT(*) FROM {_quote_identifier(backup_table_name)}") + ).first() + if row is None: + return 0 + return _normalize_int(row[0], default=0) + + +def _estimate_total_record_count(connection: Connection) -> int: + """估算旧版迁移步骤需要处理的总记录数。 + + Args: + connection: 当前数据库连接。 + + Returns: + int: 本次迁移预计处理的总记录数。 + """ + return ( + _count_legacy_table_rows(connection, "chat_streams") + + _count_legacy_table_rows(connection, "llm_usage") + + _count_legacy_table_rows(connection, "emoji") + + _count_legacy_table_rows(connection, "images") + + _count_legacy_table_rows(connection, "messages") + + _count_legacy_table_rows(connection, "action_records") + + _count_legacy_table_rows(connection, "action_records") + + _count_legacy_table_rows(connection, "online_time") + + _count_legacy_table_rows(connection, "person_info") + + _count_legacy_table_rows(connection, "expression") + + _count_legacy_table_rows(connection, "jargon") + + _count_legacy_table_rows(connection, "chat_history") + + _count_legacy_table_rows(connection, "thinking_back") + ) + + +def _complete_table_progress(context: MigrationExecutionContext, table_name: str) -> None: + """标记单张表的迁移已经完成。 + + Args: + context: 当前迁移步骤执行上下文。 + table_name: 已完成迁移的表名。 + """ + context.advance_progress(completed_tables=1, item_name=table_name) + + +def _migrate_chat_sessions(context: MigrationExecutionContext) -> int: + """迁移旧版 ``chat_streams`` 到新版 ``chat_sessions``。 + + Args: + context: 当前迁移步骤执行上下文。 + + Returns: + int: 迁移成功的记录数。 + """ + connection = context.connection + legacy_table = _load_legacy_table_data(connection, "chat_streams") + if legacy_table is None: + _complete_table_progress(context, "chat_sessions") + return 0 + + migrated_count = 0 + insert_sql = text( + """ + INSERT OR IGNORE INTO chat_sessions ( + session_id, + created_timestamp, + last_active_timestamp, + user_id, + group_id, + platform + ) VALUES ( + :session_id, + :created_timestamp, + :last_active_timestamp, + :user_id, + :group_id, + :platform + ) + """ + ) + for row in legacy_table.rows: + session_id = _normalize_required_text(row.get("stream_id")) + if session_id: + connection.execute( + insert_sql, + { + "session_id": session_id, + "created_timestamp": _coerce_datetime(row.get("create_time"), fallback_now=True), + "last_active_timestamp": _coerce_datetime(row.get("last_active_time"), fallback_now=True), + "user_id": _normalize_optional_text(row.get("user_id")), + "group_id": _normalize_optional_text(row.get("group_id")), + "platform": _normalize_required_text(row.get("platform"), default="unknown"), + }, + ) + migrated_count += 1 + context.advance_progress(records=1) + _complete_table_progress(context, "chat_sessions") + return migrated_count + + +def _migrate_model_usage(context: MigrationExecutionContext) -> int: + """迁移旧版 ``llm_usage`` 到新版 ``llm_usage``。 + + Args: + context: 当前迁移步骤执行上下文。 + + Returns: + int: 迁移成功的记录数。 + """ + connection = context.connection + legacy_table = _load_legacy_table_data(connection, "llm_usage") + if legacy_table is None: + _complete_table_progress(context, "llm_usage") + return 0 + + migrated_count = 0 + insert_sql = text( + """ + INSERT OR IGNORE INTO llm_usage ( + id, + model_name, + model_assign_name, + model_api_provider_name, + endpoint, + user_type, + request_type, + time_cost, + timestamp, + prompt_tokens, + completion_tokens, + total_tokens, + cost + ) VALUES ( + :id, + :model_name, + :model_assign_name, + :model_api_provider_name, + :endpoint, + :user_type, + :request_type, + :time_cost, + :timestamp, + :prompt_tokens, + :completion_tokens, + :total_tokens, + :cost + ) + """ + ) + for row in legacy_table.rows: + connection.execute( + insert_sql, + { + "id": row.get("id"), + "model_name": _normalize_required_text(row.get("model_name"), default="unknown"), + "model_assign_name": _normalize_optional_text(row.get("model_assign_name")), + "model_api_provider_name": _normalize_required_text(row.get("model_api_provider"), default="unknown"), + "endpoint": _normalize_optional_text(row.get("endpoint")), + "user_type": "SYSTEM", + "request_type": _normalize_required_text(row.get("request_type"), default="unknown"), + "time_cost": _normalize_float(row.get("time_cost"), default=0.0), + "timestamp": _coerce_datetime(row.get("timestamp"), fallback_now=True), + "prompt_tokens": _normalize_int(row.get("prompt_tokens"), default=0), + "completion_tokens": _normalize_int(row.get("completion_tokens"), default=0), + "total_tokens": _normalize_int(row.get("total_tokens"), default=0), + "cost": _normalize_float(row.get("cost"), default=0.0), + }, + ) + migrated_count += 1 + context.advance_progress(records=1) + _complete_table_progress(context, "llm_usage") + return migrated_count + + +def _migrate_images(context: MigrationExecutionContext) -> int: + """迁移旧版 ``emoji`` 与 ``images`` 到新版 ``images``。 + + Args: + context: 当前迁移步骤执行上下文。 + + Returns: + int: 迁移成功的记录数。 + """ + connection = context.connection + migrated_count = 0 + existing_keys: Set[Tuple[str, str, str]] = set() + existing_rows = connection.execute( + text("SELECT full_path, image_hash, image_type FROM images") + ).mappings().all() + for row in existing_rows: + existing_keys.add( + ( + _normalize_required_text(row.get("full_path")), + _normalize_required_text(row.get("image_hash")), + _normalize_required_text(row.get("image_type")), + ) + ) + insert_sql = text( + """ + INSERT INTO images ( + image_hash, + description, + full_path, + image_type, + emotion, + query_count, + is_registered, + is_banned, + no_file_flag, + record_time, + register_time, + last_used_time, + vlm_processed + ) VALUES ( + :image_hash, + :description, + :full_path, + :image_type, + :emotion, + :query_count, + :is_registered, + :is_banned, + :no_file_flag, + :record_time, + :register_time, + :last_used_time, + :vlm_processed + ) + """ + ) + + legacy_emoji_table = _load_legacy_table_data(connection, "emoji") + if legacy_emoji_table is not None: + for row in legacy_emoji_table.rows: + full_path = _normalize_required_text(row.get("full_path")) + image_hash = _normalize_required_text(row.get("emoji_hash")) + dedupe_key = (full_path, image_hash, "EMOJI") + if full_path and dedupe_key not in existing_keys: + connection.execute( + insert_sql, + { + "image_hash": image_hash, + "description": _normalize_required_text(row.get("description")), + "full_path": full_path, + "image_type": "EMOJI", + "emotion": _normalize_optional_text(row.get("emotion")), + "query_count": _normalize_int(row.get("query_count"), default=0), + "is_registered": _normalize_bool(row.get("is_registered"), default=False), + "is_banned": _normalize_bool(row.get("is_banned"), default=False), + "no_file_flag": False, + "record_time": _coerce_datetime(row.get("record_time"), fallback_now=True), + "register_time": _coerce_datetime(row.get("register_time")), + "last_used_time": _coerce_datetime(row.get("last_used_time")), + "vlm_processed": False, + }, + ) + existing_keys.add(dedupe_key) + migrated_count += 1 + context.advance_progress(records=1) + + legacy_images_table = _load_legacy_table_data(connection, "images") + if legacy_images_table is not None: + for row in legacy_images_table.rows: + full_path = _normalize_required_text(row.get("path")) + image_hash = _normalize_required_text(row.get("emoji_hash")) + image_type = _deduce_image_type_name(row.get("type")) + dedupe_key = (full_path, image_hash, image_type) + if full_path and dedupe_key not in existing_keys: + connection.execute( + insert_sql, + { + "image_hash": image_hash, + "description": _normalize_required_text(row.get("description")), + "full_path": full_path, + "image_type": image_type, + "emotion": None, + "query_count": _normalize_int(row.get("count"), default=0), + "is_registered": False, + "is_banned": False, + "no_file_flag": False, + "record_time": _coerce_datetime(row.get("timestamp"), fallback_now=True), + "register_time": None, + "last_used_time": None, + "vlm_processed": _normalize_bool(row.get("vlm_processed"), default=False), + }, + ) + existing_keys.add(dedupe_key) + migrated_count += 1 + context.advance_progress(records=1) + + _complete_table_progress(context, "images") + return migrated_count + + +def _migrate_messages(context: MigrationExecutionContext) -> int: + """迁移旧版 ``messages`` 到新版 ``mai_messages``。 + + Args: + context: 当前迁移步骤执行上下文。 + + Returns: + int: 迁移成功的记录数。 + """ + connection = context.connection + legacy_table = _load_legacy_table_data(connection, "messages") + if legacy_table is None: + _complete_table_progress(context, "mai_messages") + return 0 + + migrated_count = 0 + insert_sql = text( + """ + INSERT OR IGNORE INTO mai_messages ( + id, + message_id, + timestamp, + platform, + user_id, + user_nickname, + user_cardname, + group_id, + group_name, + is_mentioned, + is_at, + session_id, + reply_to, + is_emoji, + is_picture, + is_command, + is_notify, + raw_content, + processed_plain_text, + display_message, + additional_config + ) VALUES ( + :id, + :message_id, + :timestamp, + :platform, + :user_id, + :user_nickname, + :user_cardname, + :group_id, + :group_name, + :is_mentioned, + :is_at, + :session_id, + :reply_to, + :is_emoji, + :is_picture, + :is_command, + :is_notify, + :raw_content, + :processed_plain_text, + :display_message, + :additional_config + ) + """ + ) + for row in legacy_table.rows: + session_id = _normalize_optional_text(row.get("chat_id")) or _normalize_optional_text(row.get("chat_info_stream_id")) + if session_id: + processed_plain_text = _normalize_optional_text(row.get("processed_plain_text")) + display_message = _normalize_optional_text(row.get("display_message")) + connection.execute( + insert_sql, + { + "id": row.get("id"), + "message_id": _normalize_required_text(row.get("message_id"), default=""), + "timestamp": _coerce_datetime(row.get("time"), fallback_now=True), + "platform": _normalize_required_text( + row.get("chat_info_platform") or row.get("user_platform"), + default="unknown", + ), + "user_id": _normalize_required_text( + row.get("user_id") or row.get("chat_info_user_id"), + default="", + ), + "user_nickname": _normalize_required_text( + row.get("user_nickname") or row.get("chat_info_user_nickname"), + default="", + ), + "user_cardname": _normalize_optional_text( + row.get("user_cardname") or row.get("chat_info_user_cardname") + ), + "group_id": _normalize_optional_text(row.get("chat_info_group_id")), + "group_name": _normalize_optional_text(row.get("chat_info_group_name")), + "is_mentioned": _normalize_bool(row.get("is_mentioned"), default=False), + "is_at": _normalize_bool(row.get("is_at"), default=False), + "session_id": session_id, + "reply_to": _normalize_optional_text(row.get("reply_to")), + "is_emoji": _normalize_bool(row.get("is_emoji"), default=False), + "is_picture": _normalize_bool(row.get("is_picid"), default=False), + "is_command": _normalize_bool(row.get("is_command"), default=False), + "is_notify": _normalize_bool(row.get("is_notify"), default=False), + "raw_content": _build_message_raw_content(processed_plain_text, display_message), + "processed_plain_text": processed_plain_text, + "display_message": display_message, + "additional_config": _build_legacy_message_additional_config(row), + }, + ) + migrated_count += 1 + context.advance_progress(records=1) + _complete_table_progress(context, "mai_messages") + return migrated_count + + +def _migrate_action_records(context: MigrationExecutionContext) -> int: + """迁移旧版 ``action_records`` 到新版 ``action_records``。 + + Args: + context: 当前迁移步骤执行上下文。 + + Returns: + int: 迁移成功的记录数。 + """ + connection = context.connection + legacy_table = _load_legacy_table_data(connection, "action_records") + if legacy_table is None: + _complete_table_progress(context, "action_records") + return 0 + + migrated_count = 0 + insert_sql = text( + """ + INSERT OR IGNORE INTO action_records ( + id, + action_id, + timestamp, + session_id, + action_name, + action_reasoning, + action_data, + action_builtin_prompt, + action_display_prompt + ) VALUES ( + :id, + :action_id, + :timestamp, + :session_id, + :action_name, + :action_reasoning, + :action_data, + :action_builtin_prompt, + :action_display_prompt + ) + """ + ) + for row in legacy_table.rows: + session_id = _normalize_optional_text(row.get("chat_id")) or _normalize_optional_text(row.get("chat_info_stream_id")) + if session_id: + connection.execute( + insert_sql, + { + "id": row.get("id"), + "action_id": _normalize_required_text(row.get("action_id")), + "timestamp": _coerce_datetime(row.get("time"), fallback_now=True), + "session_id": session_id, + "action_name": _normalize_required_text(row.get("action_name"), default="unknown"), + "action_reasoning": _normalize_optional_text(row.get("action_reasoning")), + "action_data": _normalize_optional_text(row.get("action_data")), + "action_builtin_prompt": None, + "action_display_prompt": _normalize_optional_text(row.get("action_prompt_display")), + }, + ) + migrated_count += 1 + context.advance_progress(records=1) + _complete_table_progress(context, "action_records") + return migrated_count + + +def _migrate_tool_records(context: MigrationExecutionContext) -> int: + """迁移旧版 ``action_records`` 到新版 ``tool_records``。 + + Args: + context: 当前迁移步骤执行上下文。 + + Returns: + int: 迁移成功的记录数。 + """ + connection = context.connection + legacy_table = _load_legacy_table_data(connection, "action_records") + if legacy_table is None: + _complete_table_progress(context, "tool_records") + return 0 + + migrated_count = 0 + insert_sql = text( + """ + INSERT OR IGNORE INTO tool_records ( + id, + tool_id, + timestamp, + session_id, + tool_name, + tool_reasoning, + tool_data, + tool_builtin_prompt, + tool_display_prompt + ) VALUES ( + :id, + :tool_id, + :timestamp, + :session_id, + :tool_name, + :tool_reasoning, + :tool_data, + :tool_builtin_prompt, + :tool_display_prompt + ) + """ + ) + for row in legacy_table.rows: + session_id = _normalize_optional_text(row.get("chat_id")) or _normalize_optional_text(row.get("chat_info_stream_id")) + if session_id: + connection.execute( + insert_sql, + { + "id": row.get("id"), + "tool_id": _normalize_required_text(row.get("action_id")), + "timestamp": _coerce_datetime(row.get("time"), fallback_now=True), + "session_id": session_id, + "tool_name": _normalize_required_text(row.get("action_name"), default="unknown"), + "tool_reasoning": _normalize_optional_text(row.get("action_reasoning")), + "tool_data": _normalize_optional_text(row.get("action_data")), + "tool_builtin_prompt": None, + "tool_display_prompt": _normalize_optional_text(row.get("action_prompt_display")), + }, + ) + migrated_count += 1 + context.advance_progress(records=1) + _complete_table_progress(context, "tool_records") + return migrated_count + + +def _migrate_online_time(context: MigrationExecutionContext) -> int: + """迁移旧版 ``online_time`` 到新版 ``online_time``。 + + Args: + context: 当前迁移步骤执行上下文。 + + Returns: + int: 迁移成功的记录数。 + """ + connection = context.connection + legacy_table = _load_legacy_table_data(connection, "online_time") + if legacy_table is None: + _complete_table_progress(context, "online_time") + return 0 + + migrated_count = 0 + insert_sql = text( + """ + INSERT OR IGNORE INTO online_time ( + id, + timestamp, + duration_minutes, + start_timestamp, + end_timestamp + ) VALUES ( + :id, + :timestamp, + :duration_minutes, + :start_timestamp, + :end_timestamp + ) + """ + ) + for row in legacy_table.rows: + connection.execute( + insert_sql, + { + "id": row.get("id"), + "timestamp": _coerce_datetime(row.get("timestamp"), fallback_now=True), + "duration_minutes": _normalize_int(row.get("duration"), default=0), + "start_timestamp": _coerce_datetime(row.get("start_timestamp"), fallback_now=True), + "end_timestamp": _coerce_datetime(row.get("end_timestamp"), fallback_now=True), + }, + ) + migrated_count += 1 + context.advance_progress(records=1) + _complete_table_progress(context, "online_time") + return migrated_count + + +def _migrate_person_info(context: MigrationExecutionContext) -> int: + """迁移旧版 ``person_info`` 到新版 ``person_info``。 + + Args: + context: 当前迁移步骤执行上下文。 + + Returns: + int: 迁移成功的记录数。 + """ + connection = context.connection + legacy_table = _load_legacy_table_data(connection, "person_info") + if legacy_table is None: + _complete_table_progress(context, "person_info") + return 0 + + migrated_count = 0 + insert_sql = text( + """ + INSERT OR IGNORE INTO person_info ( + id, + is_known, + person_id, + person_name, + name_reason, + platform, + user_id, + user_nickname, + group_cardname, + memory_points, + know_counts, + first_known_time, + last_known_time + ) VALUES ( + :id, + :is_known, + :person_id, + :person_name, + :name_reason, + :platform, + :user_id, + :user_nickname, + :group_cardname, + :memory_points, + :know_counts, + :first_known_time, + :last_known_time + ) + """ + ) + for row in legacy_table.rows: + first_known_time = _coerce_datetime(row.get("know_times")) or _coerce_datetime(row.get("know_since")) + last_known_time = _coerce_datetime(row.get("last_know")) or _coerce_datetime(row.get("know_since")) + memory_points = _normalize_string_list(row.get("memory_points")) + connection.execute( + insert_sql, + { + "id": row.get("id"), + "is_known": _normalize_bool(row.get("is_known"), default=False), + "person_id": _normalize_required_text(row.get("person_id")), + "person_name": _normalize_optional_text(row.get("person_name")), + "name_reason": _normalize_optional_text(row.get("name_reason")), + "platform": _normalize_required_text(row.get("platform"), default="unknown"), + "user_id": _normalize_required_text(row.get("user_id"), default=""), + "user_nickname": _normalize_required_text(row.get("nickname"), default=""), + "group_cardname": _normalize_group_cardname_json(row.get("group_nick_name")), + "memory_points": json.dumps(memory_points, ensure_ascii=False) if memory_points else None, + "know_counts": 1 if _normalize_bool(row.get("is_known"), default=False) else 0, + "first_known_time": first_known_time, + "last_known_time": last_known_time, + }, + ) + migrated_count += 1 + context.advance_progress(records=1) + _complete_table_progress(context, "person_info") + return migrated_count + + +def _migrate_expressions(context: MigrationExecutionContext) -> int: + """迁移旧版 ``expression`` 到新版 ``expressions``。 + + Args: + context: 当前迁移步骤执行上下文。 + + Returns: + int: 迁移成功的记录数。 + """ + connection = context.connection + legacy_table = _load_legacy_table_data(connection, "expression") + if legacy_table is None: + _complete_table_progress(context, "expressions") + return 0 + + migrated_count = 0 + insert_sql = text( + """ + INSERT OR IGNORE INTO expressions ( + id, + situation, + style, + content_list, + count, + last_active_time, + create_time, + session_id, + checked, + rejected, + modified_by + ) VALUES ( + :id, + :situation, + :style, + :content_list, + :count, + :last_active_time, + :create_time, + :session_id, + :checked, + :rejected, + :modified_by + ) + """ + ) + for row in legacy_table.rows: + connection.execute( + insert_sql, + { + "id": row.get("id"), + "situation": _normalize_required_text(row.get("situation"), default=""), + "style": _normalize_required_text(row.get("style"), default=""), + "content_list": json.dumps(_normalize_string_list(row.get("content_list")), ensure_ascii=False), + "count": _normalize_int(row.get("count"), default=1), + "last_active_time": _coerce_datetime(row.get("last_active_time"), fallback_now=True), + "create_time": _coerce_datetime(row.get("create_date"), fallback_now=True), + "session_id": _normalize_optional_text(row.get("chat_id")), + "checked": _normalize_bool(row.get("checked"), default=False), + "rejected": _normalize_bool(row.get("rejected"), default=False), + "modified_by": _normalize_modified_by(row.get("modified_by")), + }, + ) + migrated_count += 1 + context.advance_progress(records=1) + _complete_table_progress(context, "expressions") + return migrated_count + + +def _migrate_jargons(context: MigrationExecutionContext) -> int: + """迁移旧版 ``jargon`` 到新版 ``jargons``。 + + Args: + context: 当前迁移步骤执行上下文。 + + Returns: + int: 迁移成功的记录数。 + """ + connection = context.connection + legacy_table = _load_legacy_table_data(connection, "jargon") + if legacy_table is None: + _complete_table_progress(context, "jargons") + return 0 + + migrated_count = 0 + insert_sql = text( + """ + INSERT OR IGNORE INTO jargons ( + id, + content, + raw_content, + meaning, + session_id_dict, + count, + is_jargon, + is_complete, + is_global, + last_inference_count, + inference_with_context, + inference_with_content_only + ) VALUES ( + :id, + :content, + :raw_content, + :meaning, + :session_id_dict, + :count, + :is_jargon, + :is_complete, + :is_global, + :last_inference_count, + :inference_with_context, + :inference_with_content_only + ) + """ + ) + for row in legacy_table.rows: + count = _normalize_int(row.get("count"), default=0) + connection.execute( + insert_sql, + { + "id": row.get("id"), + "content": _normalize_required_text(row.get("content"), default=""), + "raw_content": json.dumps(_normalize_string_list(row.get("raw_content")), ensure_ascii=False) + if row.get("raw_content") is not None + else None, + "meaning": _normalize_required_text(row.get("meaning")), + "session_id_dict": _build_session_id_dict(row.get("chat_id"), fallback_count=max(count, 1)), + "count": count, + "is_jargon": _normalize_optional_bool(row.get("is_jargon")), + "is_complete": _normalize_bool(row.get("is_complete"), default=False), + "is_global": _normalize_bool(row.get("is_global"), default=False), + "last_inference_count": _normalize_int(row.get("last_inference_count"), default=0), + "inference_with_context": _normalize_optional_text(row.get("inference_with_context")), + "inference_with_content_only": _normalize_optional_text( + row.get("inference_content_only") or row.get("inference_with_content_only") + ), + }, + ) + migrated_count += 1 + context.advance_progress(records=1) + _complete_table_progress(context, "jargons") + return migrated_count + + +def _migrate_chat_history(context: MigrationExecutionContext) -> int: + """迁移旧版 ``chat_history`` 到新版 ``chat_history``。 + + Args: + context: 当前迁移步骤执行上下文。 + + Returns: + int: 迁移成功的记录数。 + """ + connection = context.connection + legacy_table = _load_legacy_table_data(connection, "chat_history") + if legacy_table is None: + _complete_table_progress(context, "chat_history") + return 0 + + migrated_count = 0 + insert_sql = text( + """ + INSERT OR IGNORE INTO chat_history ( + id, + session_id, + start_timestamp, + end_timestamp, + query_count, + query_forget_count, + original_messages, + participants, + theme, + keywords, + summary + ) VALUES ( + :id, + :session_id, + :start_timestamp, + :end_timestamp, + :query_count, + :query_forget_count, + :original_messages, + :participants, + :theme, + :keywords, + :summary + ) + """ + ) + for row in legacy_table.rows: + session_id = _normalize_required_text(row.get("chat_id")) + if session_id: + connection.execute( + insert_sql, + { + "id": row.get("id"), + "session_id": session_id, + "start_timestamp": _coerce_datetime(row.get("start_time"), fallback_now=True), + "end_timestamp": _coerce_datetime(row.get("end_time"), fallback_now=True), + "query_count": _normalize_int(row.get("count"), default=0), + "query_forget_count": _normalize_int(row.get("forget_times"), default=0), + "original_messages": _normalize_required_text(row.get("original_text")), + "participants": _normalize_required_text(row.get("participants"), default="[]"), + "theme": _normalize_required_text(row.get("theme"), default=""), + "keywords": _normalize_required_text(row.get("keywords"), default="[]"), + "summary": _normalize_required_text(row.get("summary"), default=""), + }, + ) + migrated_count += 1 + context.advance_progress(records=1) + _complete_table_progress(context, "chat_history") + return migrated_count + + +def _migrate_thinking_questions(context: MigrationExecutionContext) -> int: + """迁移旧版 ``thinking_back`` 到新版 ``thinking_questions``。 + + Args: + context: 当前迁移步骤执行上下文。 + + Returns: + int: 迁移成功的记录数。 + """ + connection = context.connection + legacy_table = _load_legacy_table_data(connection, "thinking_back") + if legacy_table is None: + _complete_table_progress(context, "thinking_questions") + return 0 + + migrated_count = 0 + insert_sql = text( + """ + INSERT OR IGNORE INTO thinking_questions ( + id, + question, + context, + found_answer, + answer, + thinking_steps, + created_timestamp, + updated_timestamp + ) VALUES ( + :id, + :question, + :context, + :found_answer, + :answer, + :thinking_steps, + :created_timestamp, + :updated_timestamp + ) + """ + ) + for row in legacy_table.rows: + connection.execute( + insert_sql, + { + "id": row.get("id"), + "question": _normalize_required_text(row.get("question"), default=""), + "context": _normalize_optional_text(row.get("context")), + "found_answer": _normalize_bool(row.get("found_answer"), default=False), + "answer": _normalize_optional_text(row.get("answer")), + "thinking_steps": _normalize_optional_text(row.get("thinking_steps")), + "created_timestamp": _coerce_datetime(row.get("create_time"), fallback_now=True), + "updated_timestamp": _coerce_datetime(row.get("update_time"), fallback_now=True), + }, + ) + migrated_count += 1 + context.advance_progress(records=1) + _complete_table_progress(context, "thinking_questions") + return migrated_count diff --git a/src/common/database/migrations/manager.py b/src/common/database/migrations/manager.py new file mode 100644 index 00000000..d33e6926 --- /dev/null +++ b/src/common/database/migrations/manager.py @@ -0,0 +1,205 @@ +"""数据库迁移编排器。""" + +from typing import Callable, Optional + +from sqlalchemy.engine import Connection, Engine + +from src.common.logger import get_logger + +from .exceptions import DatabaseMigrationExecutionError +from .models import DatabaseMigrationState, MigrationExecutionContext, MigrationPlan +from .planner import MigrationPlanner +from .progress import BaseMigrationProgressReporter, create_rich_migration_progress_reporter +from .registry import MigrationRegistry +from .resolver import SchemaVersionResolver +from .version_store import SQLiteUserVersionStore + +logger = get_logger("database_migration") + + +class DatabaseMigrationManager: + """数据库迁移编排器。 + + 该类只负责基础设施层面的编排工作,包括: + 1. 解析当前数据库版本; + 2. 生成迁移计划; + 3. 顺序执行已注册迁移步骤; + 4. 在每一步成功后更新 ``user_version``。 + + 当前模块不内置任何业务迁移步骤,也不会自动接入项目启动流程。 + """ + + def __init__( + self, + engine: Engine, + registry: Optional[MigrationRegistry] = None, + planner: Optional[MigrationPlanner] = None, + resolver: Optional[SchemaVersionResolver] = None, + version_store: Optional[SQLiteUserVersionStore] = None, + progress_reporter_factory: Optional[Callable[[], BaseMigrationProgressReporter]] = None, + ) -> None: + """初始化数据库迁移编排器。 + + Args: + engine: 目标数据库引擎。 + registry: 迁移步骤注册表。 + planner: 迁移计划生成器。 + resolver: 数据库版本解析器。 + version_store: 版本存储器。 + progress_reporter_factory: 迁移进度上报器工厂。 + """ + self.engine = engine + self.registry = registry or MigrationRegistry() + self.planner = planner or MigrationPlanner() + self.resolver = resolver or SchemaVersionResolver() + self.version_store = version_store or SQLiteUserVersionStore() + self.progress_reporter_factory = progress_reporter_factory or create_rich_migration_progress_reporter + + def describe_state(self, target_version: Optional[int] = None) -> DatabaseMigrationState: + """描述当前数据库的迁移状态。 + + Args: + target_version: 目标数据库版本;未提供时使用注册表中声明的最新版本。 + + Returns: + DatabaseMigrationState: 当前数据库迁移状态。 + """ + with self.engine.connect() as connection: + resolved_version = self.resolver.resolve(connection) + + effective_target_version = self._resolve_target_version(target_version) + migration_plan = self.planner.plan( + current_version=resolved_version.version, + target_version=effective_target_version, + registry=self.registry, + ) + return DatabaseMigrationState( + resolved_version=resolved_version, + target_version=effective_target_version, + plan=migration_plan, + ) + + def plan(self, target_version: Optional[int] = None) -> MigrationPlan: + """生成当前数据库的迁移计划。 + + Args: + target_version: 目标数据库版本;未提供时使用注册表中声明的最新版本。 + + Returns: + MigrationPlan: 当前数据库对应的迁移计划。 + """ + return self.describe_state(target_version=target_version).plan + + def migrate(self, target_version: Optional[int] = None) -> MigrationPlan: + """执行迁移计划。 + + 注意: + 若当前数据库是通过结构探测得出的版本,且计划为空,本方法不会自动把该 + 版本写回 ``user_version``。这样做是为了避免在尚未明确接入策略前引入隐式 + 副作用。 + + Args: + target_version: 目标数据库版本;未提供时使用注册表中声明的最新版本。 + + Returns: + MigrationPlan: 已执行的迁移计划。 + + Raises: + DatabaseMigrationExecutionError: 当迁移步骤执行失败时抛出。 + """ + migration_state = self.describe_state(target_version=target_version) + migration_plan = migration_state.plan + if migration_plan.is_empty(): + logger.info("数据库迁移计划为空,跳过执行。") + return migration_plan + + current_version = migration_state.resolved_version.version + total_steps = migration_plan.step_count() + for step_index, step in enumerate(migration_plan.steps, start=1): + logger.info( + f"开始执行数据库迁移步骤: {step.name} ({step.version_from} -> {step.version_to})" + ) + try: + with self.progress_reporter_factory() as progress_reporter: + if step.transactional: + with self.engine.begin() as connection: + execution_context = self._build_execution_context( + connection=connection, + current_version=current_version, + migration_plan=migration_plan, + step_index=step_index, + step_name=step.name, + total_steps=total_steps, + progress_reporter=progress_reporter, + ) + step.run(execution_context) + self.version_store.write_version(connection, step.version_to) + else: + with self.engine.connect() as connection: + execution_context = self._build_execution_context( + connection=connection, + current_version=current_version, + migration_plan=migration_plan, + step_index=step_index, + step_name=step.name, + total_steps=total_steps, + progress_reporter=progress_reporter, + ) + step.run(execution_context) + self.version_store.write_version(connection, step.version_to) + connection.commit() + except Exception as exc: + raise DatabaseMigrationExecutionError( + f"执行迁移步骤 {step.name} ({step.version_from} -> {step.version_to}) 失败。" + ) from exc + current_version = step.version_to + logger.info(f"数据库迁移步骤执行完成: {step.name},当前版本已更新为 {current_version}") + + return migration_plan + + def _resolve_target_version(self, target_version: Optional[int]) -> int: + """解析最终目标版本号。 + + Args: + target_version: 调用方显式指定的目标版本。 + + Returns: + int: 最终用于规划和执行的目标版本号。 + """ + if target_version is not None: + return target_version + return self.registry.latest_version() + + def _build_execution_context( + self, + connection: Connection, + current_version: int, + migration_plan: MigrationPlan, + step_index: int, + step_name: str, + total_steps: int, + progress_reporter: BaseMigrationProgressReporter, + ) -> MigrationExecutionContext: + """构建单个迁移步骤的执行上下文。 + + Args: + connection: 当前迁移步骤使用的数据库连接。 + current_version: 当前数据库版本。 + migration_plan: 当前迁移计划。 + step_index: 当前步骤序号,从 ``1`` 开始。 + step_name: 当前步骤名称。 + total_steps: 计划总步骤数。 + progress_reporter: 当前步骤使用的进度上报器。 + + Returns: + MigrationExecutionContext: 当前步骤的执行上下文对象。 + """ + return MigrationExecutionContext( + connection=connection, + current_version=current_version, + target_version=migration_plan.target_version, + step_index=step_index, + step_name=step_name, + total_steps=total_steps, + progress_reporter=progress_reporter, + ) diff --git a/src/common/database/migrations/models.py b/src/common/database/migrations/models.py new file mode 100644 index 00000000..1bf39346 --- /dev/null +++ b/src/common/database/migrations/models.py @@ -0,0 +1,305 @@ +"""数据库迁移基础设施核心数据模型。""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from datetime import datetime, timezone +from enum import Enum +from typing import Callable, Dict, List, Optional, TYPE_CHECKING + +from sqlalchemy.engine import Connection + +if TYPE_CHECKING: + from .progress import BaseMigrationProgressReporter + + +def _utc_now() -> datetime: + """返回当前 UTC 时间。 + + Returns: + datetime: 当前 UTC 时间。 + """ + return datetime.now(timezone.utc) + + +class SchemaVersionSource(str, Enum): + """数据库版本来源。""" + + PRAGMA = "pragma" + DETECTOR = "detector" + EMPTY_DATABASE = "empty_database" + + +@dataclass(frozen=True) +class ColumnSchema: + """数据库列结构快照。""" + + name: str + declared_type: str + default_value: Optional[str] + is_not_null: bool + primary_key_position: int + + +@dataclass(frozen=True) +class TableSchema: + """数据库表结构快照。""" + + name: str + columns: Dict[str, ColumnSchema] + + def has_column(self, column_name: str) -> bool: + """判断表中是否存在指定列。 + + Args: + column_name: 待检查的列名。 + + Returns: + bool: 若列存在则返回 ``True``,否则返回 ``False``。 + """ + return column_name in self.columns + + def get_column(self, column_name: str) -> Optional[ColumnSchema]: + """获取指定列的结构信息。 + + Args: + column_name: 待获取的列名。 + + Returns: + Optional[ColumnSchema]: 列存在时返回列结构,否则返回 ``None``。 + """ + return self.columns.get(column_name) + + def column_names(self) -> List[str]: + """返回当前表中全部列名。 + + Returns: + List[str]: 按字母顺序排列的列名列表。 + """ + return sorted(self.columns) + + +@dataclass(frozen=True) +class DatabaseSchemaSnapshot: + """数据库结构快照。""" + + tables: Dict[str, TableSchema] + + def is_empty(self) -> bool: + """判断数据库是否没有任何用户表。 + + Returns: + bool: 若数据库中没有用户表则返回 ``True``。 + """ + return not self.tables + + def has_table(self, table_name: str) -> bool: + """判断数据库是否存在指定表。 + + Args: + table_name: 待检查的表名。 + + Returns: + bool: 若表存在则返回 ``True``,否则返回 ``False``。 + """ + return table_name in self.tables + + def has_column(self, table_name: str, column_name: str) -> bool: + """判断数据库指定表中是否存在指定列。 + + Args: + table_name: 待检查的表名。 + column_name: 待检查的列名。 + + Returns: + bool: 若表和列均存在则返回 ``True``。 + """ + table_schema = self.get_table(table_name) + if table_schema is None: + return False + return table_schema.has_column(column_name) + + def get_table(self, table_name: str) -> Optional[TableSchema]: + """获取指定表的结构信息。 + + Args: + table_name: 待获取的表名。 + + Returns: + Optional[TableSchema]: 表存在时返回对应结构,否则返回 ``None``。 + """ + return self.tables.get(table_name) + + def table_names(self) -> List[str]: + """返回当前数据库中的全部用户表名。 + + Returns: + List[str]: 按字母顺序排列的表名列表。 + """ + return sorted(self.tables) + + +@dataclass(frozen=True) +class ResolvedSchemaVersion: + """解析后的数据库版本信息。""" + + version: int + source: SchemaVersionSource + detector_name: Optional[str] = None + snapshot: Optional[DatabaseSchemaSnapshot] = None + + +@dataclass(frozen=True) +class MigrationExecutionContext: + """单个迁移步骤的执行上下文。""" + + connection: Connection + current_version: int + target_version: int + step_index: int + step_name: str + total_steps: int + started_at: datetime = field(default_factory=_utc_now) + progress_reporter: Optional["BaseMigrationProgressReporter"] = None + + def is_last_step(self) -> bool: + """判断当前步骤是否为最后一步。 + + Returns: + bool: 若当前步骤已是计划中的最后一步则返回 ``True``。 + """ + return self.step_index >= self.total_steps + + def start_progress( + self, + total_tables: int, + total_records: int, + description: str = "总迁移进度", + table_unit_name: str = "表", + record_unit_name: str = "记录", + ) -> None: + """启动当前迁移步骤的进度展示。 + + Args: + total_tables: 当前步骤需要处理的总表数。 + total_records: 当前步骤需要处理的总记录数。 + description: 进度描述文本。 + table_unit_name: 表级进度单位名称。 + record_unit_name: 记录级进度单位名称。 + """ + if self.progress_reporter is None: + return + self.progress_reporter.start( + total_records=total_records, + total_tables=total_tables, + description=description, + table_unit_name=table_unit_name, + record_unit_name=record_unit_name, + ) + + def advance_progress( + self, + records: int = 0, + completed_tables: int = 0, + item_name: Optional[str] = None, + ) -> None: + """推进当前迁移步骤的进度展示。 + + Args: + records: 本次推进的记录数。 + completed_tables: 本次完成的表数。 + item_name: 当前完成的项目名称。 + """ + if self.progress_reporter is None: + return + self.progress_reporter.advance( + records=records, + completed_tables=completed_tables, + item_name=item_name, + ) + + +MigrationHandler = Callable[[MigrationExecutionContext], None] + + +@dataclass(frozen=True) +class MigrationStep: + """单个数据库迁移步骤定义。""" + + version_from: int + version_to: int + name: str + description: str + handler: MigrationHandler + transactional: bool = True + + def __post_init__(self) -> None: + """校验迁移步骤定义是否合法。 + + Raises: + ValueError: 当版本号不合法或迁移方向错误时抛出。 + """ + if self.version_from < 0: + raise ValueError("迁移起始版本不能小于 0。") + if self.version_to <= self.version_from: + raise ValueError("迁移目标版本必须大于起始版本。") + + def run(self, context: MigrationExecutionContext) -> None: + """执行当前迁移步骤。 + + Args: + context: 当前迁移步骤的执行上下文。 + """ + self.handler(context) + + +@dataclass(frozen=True) +class MigrationPlan: + """数据库迁移执行计划。""" + + current_version: int + target_version: int + steps: List[MigrationStep] + + def is_empty(self) -> bool: + """判断迁移计划是否为空。 + + Returns: + bool: 若无需执行任何迁移步骤则返回 ``True``。 + """ + return not self.steps + + def step_count(self) -> int: + """返回迁移计划中的步骤数量。 + + Returns: + int: 当前计划中的迁移步骤数。 + """ + return len(self.steps) + + def latest_reachable_version(self) -> int: + """返回该计划执行后的最终版本。 + + Returns: + int: 若计划为空则返回当前版本,否则返回最后一步的目标版本。 + """ + if self.is_empty(): + return self.current_version + return self.steps[-1].version_to + + +@dataclass(frozen=True) +class DatabaseMigrationState: + """数据库迁移状态描述。""" + + resolved_version: ResolvedSchemaVersion + target_version: int + plan: MigrationPlan + + def requires_migration(self) -> bool: + """判断当前状态是否需要执行迁移。 + + Returns: + bool: 若计划中存在待执行迁移步骤则返回 ``True``。 + """ + return not self.plan.is_empty() diff --git a/src/common/database/migrations/planner.py b/src/common/database/migrations/planner.py new file mode 100644 index 00000000..eca98c27 --- /dev/null +++ b/src/common/database/migrations/planner.py @@ -0,0 +1,108 @@ +"""数据库迁移计划生成器。""" + +from typing import List + +from .exceptions import ( + DatabaseMigrationPlanningError, + MissingMigrationStepError, + UnsupportedMigrationDirectionError, +) +from .models import MigrationPlan, MigrationStep +from .registry import MigrationRegistry + + +class MigrationPlanner: + """数据库迁移计划生成器。""" + + def plan( + self, + current_version: int, + target_version: int, + registry: MigrationRegistry, + ) -> MigrationPlan: + """根据当前版本与目标版本生成迁移计划。 + + Args: + current_version: 当前数据库版本。 + target_version: 目标数据库版本。 + registry: 迁移步骤注册表。 + + Returns: + MigrationPlan: 按顺序执行的迁移计划。 + + Raises: + DatabaseMigrationPlanningError: 当版本号非法时抛出。 + MissingMigrationStepError: 当所需迁移步骤缺失时抛出。 + UnsupportedMigrationDirectionError: 当请求降级迁移时抛出。 + """ + self._validate_version(current_version, "current_version") + self._validate_version(target_version, "target_version") + + if target_version < current_version: + raise UnsupportedMigrationDirectionError( + f"当前仅支持升级迁移,不支持从 {current_version} 降级到 {target_version}。" + ) + if target_version == current_version: + return MigrationPlan(current_version=current_version, target_version=target_version, steps=[]) + + steps = self._build_steps(current_version, target_version, registry) + return MigrationPlan(current_version=current_version, target_version=target_version, steps=steps) + + def plan_to_latest(self, current_version: int, registry: MigrationRegistry) -> MigrationPlan: + """生成迁移到注册表最新版本的执行计划。 + + Args: + current_version: 当前数据库版本。 + registry: 迁移步骤注册表。 + + Returns: + MigrationPlan: 指向最新版本的迁移计划。 + """ + target_version = registry.latest_version() + return self.plan(current_version=current_version, target_version=target_version, registry=registry) + + def _build_steps( + self, + current_version: int, + target_version: int, + registry: MigrationRegistry, + ) -> List[MigrationStep]: + """按顺序拼装迁移步骤链。 + + Args: + current_version: 当前数据库版本。 + target_version: 目标数据库版本。 + registry: 迁移步骤注册表。 + + Returns: + List[MigrationStep]: 按顺序执行的迁移步骤列表。 + + Raises: + MissingMigrationStepError: 当中间某一版本缺少迁移步骤时抛出。 + """ + planned_steps: List[MigrationStep] = [] + next_version = current_version + + while next_version < target_version: + step = registry.get_step(next_version) + if step is None: + raise MissingMigrationStepError( + f"缺少从版本 {next_version} 升级到版本 {next_version + 1} 的迁移步骤。" + ) + planned_steps.append(step) + next_version = step.version_to + + return planned_steps + + def _validate_version(self, version: int, field_name: str) -> None: + """校验版本号是否合法。 + + Args: + version: 待校验的版本号。 + field_name: 当前版本号对应的字段名。 + + Raises: + DatabaseMigrationPlanningError: 当版本号非法时抛出。 + """ + if version < 0: + raise DatabaseMigrationPlanningError(f"{field_name} 不能小于 0: {version}") diff --git a/src/common/database/migrations/progress.py b/src/common/database/migrations/progress.py new file mode 100644 index 00000000..4aff8d38 --- /dev/null +++ b/src/common/database/migrations/progress.py @@ -0,0 +1,319 @@ +"""数据库迁移进度展示工具。""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from datetime import timedelta +from typing import Optional + +from rich.console import Console +from rich.progress import BarColumn, Progress, ProgressColumn, Task, TaskID +from rich.text import Text + + +def _format_duration(total_seconds: Optional[float]) -> str: + """将秒数格式化为适合展示的耗时文本。 + + Args: + total_seconds: 总秒数;为空时表示暂不可用。 + + Returns: + str: 格式化后的耗时文本。 + """ + if total_seconds is None: + return "--:--:--" + safe_seconds = max(total_seconds, 0.0) + return str(timedelta(seconds=int(safe_seconds))) + + +class MigrationSummaryColumn(ProgressColumn): + """渲染数据库迁移总进度摘要列。""" + + def render(self, task: Task) -> Text: + """渲染当前任务的总进度摘要。 + + Args: + task: 当前进度任务对象。 + + Returns: + Text: 渲染后的摘要文本。 + """ + completed_tables = int(task.fields.get("completed_tables", 0)) + display_table_total = task.fields.get("display_table_total") + total_text = "?" if display_table_total is None else str(int(display_table_total)) + completed_text = str(completed_tables) + return Text(f"总迁移进度({completed_text}/{total_text})") + + +class MigrationSpeedColumn(ProgressColumn): + """渲染数据库迁移速度列。""" + + def render(self, task: Task) -> Text: + """渲染当前任务的速度信息。 + + Args: + task: 当前进度任务对象。 + + Returns: + Text: 渲染后的速度文本。 + """ + unit_name = str(task.fields.get("progress_unit_name", "项")) + if task.speed is None or task.speed <= 0: + return Text(f"-- {unit_name}/s") + return Text(f"{task.speed:.2f} {unit_name}/s") + + +class MigrationElapsedColumn(ProgressColumn): + """渲染数据库迁移已用时间列。""" + + def render(self, task: Task) -> Text: + """渲染当前任务的已用时间。 + + Args: + task: 当前进度任务对象。 + + Returns: + Text: 渲染后的已用时间文本。 + """ + return Text(f"已用时间 {_format_duration(task.elapsed)}") + + +class MigrationRemainingColumn(ProgressColumn): + """渲染数据库迁移预估剩余时间列。""" + + def render(self, task: Task) -> Text: + """渲染当前任务的预估剩余时间。 + + Args: + task: 当前进度任务对象。 + + Returns: + Text: 渲染后的预估剩余时间文本。 + """ + return Text(f"预估时间 {_format_duration(task.time_remaining)}") + + +class BaseMigrationProgressReporter(ABC): + """数据库迁移进度上报器基类。""" + + def __enter__(self) -> "BaseMigrationProgressReporter": + """进入进度上报上下文。 + + Returns: + BaseMigrationProgressReporter: 当前上报器实例。 + """ + self.open() + return self + + def __exit__(self, exc_type, exc_value, traceback) -> None: + """退出进度上报上下文。 + + Args: + exc_type: 异常类型。 + exc_value: 异常实例。 + traceback: 异常追踪对象。 + """ + del exc_type, exc_value, traceback + self.close() + + @abstractmethod + def open(self) -> None: + """打开进度上报资源。""" + + @abstractmethod + def close(self) -> None: + """关闭进度上报资源。""" + + @abstractmethod + def start( + self, + total_records: int, + total_tables: int, + description: str = "总迁移进度", + table_unit_name: str = "表", + record_unit_name: str = "记录", + ) -> None: + """启动一个新的迁移进度任务。 + + Args: + total_records: 任务记录总数。 + total_tables: 任务表总数。 + description: 任务描述。 + table_unit_name: 表级进度单位名称。 + record_unit_name: 记录级进度单位名称。 + """ + + @abstractmethod + def advance( + self, + records: int = 0, + completed_tables: int = 0, + item_name: Optional[str] = None, + ) -> None: + """推进当前迁移进度任务。 + + Args: + records: 本次推进的记录数。 + completed_tables: 本次完成的表数。 + item_name: 当前完成的项目名称。 + """ + + +class NullMigrationProgressReporter(BaseMigrationProgressReporter): + """不输出任何内容的空进度上报器。""" + + def open(self) -> None: + """打开空进度上报器。""" + + def close(self) -> None: + """关闭空进度上报器。""" + + def start( + self, + total_records: int, + total_tables: int, + description: str = "总迁移进度", + table_unit_name: str = "表", + record_unit_name: str = "记录", + ) -> None: + """启动空进度任务。 + + Args: + total_records: 任务记录总数。 + total_tables: 任务表总数。 + description: 任务描述。 + table_unit_name: 表级进度单位名称。 + record_unit_name: 记录级进度单位名称。 + """ + del total_records, total_tables, description, table_unit_name, record_unit_name + + def advance( + self, + records: int = 0, + completed_tables: int = 0, + item_name: Optional[str] = None, + ) -> None: + """推进空进度任务。 + + Args: + records: 本次推进的记录数。 + completed_tables: 本次完成的表数。 + item_name: 当前完成的项目名称。 + """ + del records, completed_tables, item_name + + +class RichMigrationProgressReporter(BaseMigrationProgressReporter): + """基于 ``rich`` 的数据库迁移进度上报器。""" + + def __init__( + self, + console: Optional[Console] = None, + disable: Optional[bool] = None, + refresh_per_second: int = 10, + ) -> None: + """初始化 ``rich`` 迁移进度上报器。 + + Args: + console: 输出使用的 ``rich`` 控制台。 + disable: 是否禁用进度条;为空时根据终端能力自动判断。 + refresh_per_second: 每秒刷新次数。 + """ + self.console = console or Console() + self.disable = disable + self.refresh_per_second = refresh_per_second + self._progress: Optional[Progress] = None + self._task_id: Optional[TaskID] = None + + def open(self) -> None: + """打开 ``rich`` 进度条资源。""" + effective_disable = not self.console.is_terminal if self.disable is None else self.disable + self._progress = Progress( + MigrationSummaryColumn(), + BarColumn(), + MigrationSpeedColumn(), + MigrationElapsedColumn(), + MigrationRemainingColumn(), + console=self.console, + transient=False, + disable=effective_disable, + refresh_per_second=self.refresh_per_second, + expand=True, + ) + self._progress.start() + + def close(self) -> None: + """关闭 ``rich`` 进度条资源。""" + if self._progress is None: + return + self._progress.stop() + self._progress = None + self._task_id = None + + def start( + self, + total_records: int, + total_tables: int, + description: str = "总迁移进度", + table_unit_name: str = "表", + record_unit_name: str = "记录", + ) -> None: + """启动一个新的 ``rich`` 迁移进度任务。 + + Args: + total_records: 任务记录总数。 + total_tables: 任务表总数。 + description: 任务描述。 + table_unit_name: 表级进度单位名称。 + record_unit_name: 记录级进度单位名称。 + """ + if self._progress is None: + self.open() + assert self._progress is not None + use_record_progress = total_records > 0 + effective_total = total_records if use_record_progress else total_tables + effective_total = max(effective_total, 1) + progress_unit_name = record_unit_name if use_record_progress else table_unit_name + self._task_id = self._progress.add_task( + description, + total=effective_total, + completed_tables=0, + display_table_total=total_tables, + progress_unit_name=progress_unit_name, + use_record_progress=use_record_progress, + ) + + def advance( + self, + records: int = 0, + completed_tables: int = 0, + item_name: Optional[str] = None, + ) -> None: + """推进当前 ``rich`` 迁移进度任务。 + + Args: + records: 本次推进的记录数。 + completed_tables: 本次完成的表数。 + item_name: 当前完成的项目名称。 + """ + del item_name + if self._progress is None or self._task_id is None: + return + task = self._progress.tasks[self._task_id] + use_record_progress = bool(task.fields.get("use_record_progress", False)) + progress_advance = records if use_record_progress else completed_tables + updated_completed_tables = int(task.fields.get("completed_tables", 0)) + completed_tables + self._progress.update( + self._task_id, + advance=progress_advance, + completed_tables=updated_completed_tables, + ) + + +def create_rich_migration_progress_reporter() -> BaseMigrationProgressReporter: + """创建默认的 ``rich`` 迁移进度上报器。 + + Returns: + BaseMigrationProgressReporter: 默认迁移进度上报器实例。 + """ + return RichMigrationProgressReporter() diff --git a/src/common/database/migrations/registry.py b/src/common/database/migrations/registry.py new file mode 100644 index 00000000..fb9d893b --- /dev/null +++ b/src/common/database/migrations/registry.py @@ -0,0 +1,98 @@ +"""数据库迁移步骤注册表。""" + +from typing import Dict, List, Optional + +from .exceptions import DatabaseMigrationConfigurationError +from .models import MigrationStep + + +class MigrationRegistry: + """数据库迁移步骤注册表。""" + + def __init__(self, steps: Optional[List[MigrationStep]] = None) -> None: + """初始化迁移步骤注册表。 + + Args: + steps: 初始化时要注册的迁移步骤列表。 + """ + self._steps_by_from_version: Dict[int, MigrationStep] = {} + if steps: + self.register_many(steps) + + def register(self, step: MigrationStep) -> None: + """注册单个迁移步骤。 + + 当前注册表要求每个步骤只负责相邻版本间的升级,以确保迁移链路易于审计、 + 易于回放,也便于后续生产问题排查。 + + Args: + step: 待注册的迁移步骤定义。 + + Raises: + DatabaseMigrationConfigurationError: 当步骤定义冲突或版本跨度不合法时抛出。 + """ + if step.version_to != step.version_from + 1: + raise DatabaseMigrationConfigurationError( + "迁移步骤必须使用相邻版本号定义,例如 2 -> 3。" + ) + if step.version_from in self._steps_by_from_version: + existing_step = self._steps_by_from_version[step.version_from] + raise DatabaseMigrationConfigurationError( + f"版本 {step.version_from} 已存在迁移步骤: {existing_step.name}" + ) + for registered_step in self._steps_by_from_version.values(): + if registered_step.version_to == step.version_to: + raise DatabaseMigrationConfigurationError( + f"目标版本 {step.version_to} 已由迁移步骤 {registered_step.name} 占用。" + ) + self._steps_by_from_version[step.version_from] = step + + def register_many(self, steps: List[MigrationStep]) -> None: + """批量注册多个迁移步骤。 + + Args: + steps: 待注册的迁移步骤列表。 + """ + for step in steps: + self.register(step) + + def get_step(self, version_from: int) -> Optional[MigrationStep]: + """获取指定起始版本的迁移步骤。 + + Args: + version_from: 迁移步骤的起始版本号。 + + Returns: + Optional[MigrationStep]: 若存在对应步骤则返回,否则返回 ``None``。 + """ + return self._steps_by_from_version.get(version_from) + + def has_step(self, version_from: int) -> bool: + """判断指定起始版本是否已注册迁移步骤。 + + Args: + version_from: 待检查的起始版本号。 + + Returns: + bool: 若已注册对应步骤则返回 ``True``。 + """ + return version_from in self._steps_by_from_version + + def latest_version(self) -> int: + """返回当前注册表支持到的最新 schema 版本。 + + Returns: + int: 若注册表为空则返回 ``0``,否则返回最大目标版本号。 + """ + if not self._steps_by_from_version: + return 0 + return max(step.version_to for step in self._steps_by_from_version.values()) + + def list_steps(self) -> List[MigrationStep]: + """按起始版本顺序返回全部迁移步骤。 + + Returns: + List[MigrationStep]: 已注册迁移步骤列表。 + """ + ordered_versions = sorted(self._steps_by_from_version) + return [self._steps_by_from_version[version] for version in ordered_versions] diff --git a/src/common/database/migrations/resolver.py b/src/common/database/migrations/resolver.py new file mode 100644 index 00000000..fb66a57d --- /dev/null +++ b/src/common/database/migrations/resolver.py @@ -0,0 +1,135 @@ +"""数据库版本解析器。""" + +from abc import ABC, abstractmethod +from typing import List, Optional + +from sqlalchemy.engine import Connection + +from .exceptions import DatabaseMigrationVersionError, UnrecognizedDatabaseSchemaError +from .models import DatabaseSchemaSnapshot, ResolvedSchemaVersion, SchemaVersionSource +from .schema import SQLiteSchemaInspector +from .version_store import SQLiteUserVersionStore + + +class BaseSchemaVersionDetector(ABC): + """未标记版本数据库的结构探测器基类。""" + + @property + @abstractmethod + def name(self) -> str: + """返回当前探测器名称。 + + Returns: + str: 当前探测器名称。 + """ + + @abstractmethod + def detect_version(self, snapshot: DatabaseSchemaSnapshot) -> Optional[int]: + """根据数据库结构快照推断版本号。 + + Args: + snapshot: 当前数据库结构快照。 + + Returns: + Optional[int]: 若识别成功则返回版本号,否则返回 ``None``。 + """ + + +class SchemaVersionResolver: + """数据库版本解析器。""" + + def __init__( + self, + version_store: Optional[SQLiteUserVersionStore] = None, + schema_inspector: Optional[SQLiteSchemaInspector] = None, + detectors: Optional[List[BaseSchemaVersionDetector]] = None, + ) -> None: + """初始化数据库版本解析器。 + + Args: + version_store: 版本存储器;未提供时将使用默认实现。 + schema_inspector: 结构探测器;未提供时将使用默认实现。 + detectors: 未标记版本数据库的探测器列表。 + """ + self.version_store = version_store or SQLiteUserVersionStore() + self.schema_inspector = schema_inspector or SQLiteSchemaInspector() + self.detectors: List[BaseSchemaVersionDetector] = list(detectors or []) + + def add_detector(self, detector: BaseSchemaVersionDetector) -> None: + """注册一个未标记版本数据库探测器。 + + Args: + detector: 待注册的探测器实例。 + """ + self.detectors.append(detector) + + def list_detectors(self) -> List[BaseSchemaVersionDetector]: + """返回当前已注册的全部探测器。 + + Returns: + List[BaseSchemaVersionDetector]: 已注册探测器列表副本。 + """ + return list(self.detectors) + + def resolve(self, connection: Connection) -> ResolvedSchemaVersion: + """解析当前数据库的 schema 版本信息。 + + 解析顺序如下: + 1. 优先读取 ``PRAGMA user_version``。 + 2. 若其值为 0,则对数据库结构做快照。 + 3. 若数据库为空,则返回空库版本。 + 4. 若数据库非空,则交给探测器链进行识别。 + + Args: + connection: 当前数据库连接。 + + Returns: + ResolvedSchemaVersion: 解析后的数据库版本信息。 + + Raises: + DatabaseMigrationVersionError: 当探测器返回非法版本号时抛出。 + UnrecognizedDatabaseSchemaError: 当数据库非空但无法识别版本时抛出。 + """ + recorded_version = self.version_store.read_version(connection) + if recorded_version > 0: + return ResolvedSchemaVersion(version=recorded_version, source=SchemaVersionSource.PRAGMA) + + snapshot = self.schema_inspector.inspect(connection) + if snapshot.is_empty(): + return ResolvedSchemaVersion( + version=0, + source=SchemaVersionSource.EMPTY_DATABASE, + snapshot=snapshot, + ) + + return self._detect_unversioned_database(snapshot) + + def _detect_unversioned_database(self, snapshot: DatabaseSchemaSnapshot) -> ResolvedSchemaVersion: + """识别未标记版本的历史数据库。 + + Args: + snapshot: 当前数据库结构快照。 + + Returns: + ResolvedSchemaVersion: 探测器识别出的版本信息。 + + Raises: + DatabaseMigrationVersionError: 当探测器返回非法版本号时抛出。 + UnrecognizedDatabaseSchemaError: 当全部探测器都无法识别结构时抛出。 + """ + for detector in self.detectors: + detected_version = detector.detect_version(snapshot) + if detected_version is None: + continue + if detected_version < 0: + raise DatabaseMigrationVersionError( + f"探测器 {detector.name!r} 返回了非法版本号: {detected_version}" + ) + return ResolvedSchemaVersion( + version=detected_version, + source=SchemaVersionSource.DETECTOR, + detector_name=detector.name, + snapshot=snapshot, + ) + + raise UnrecognizedDatabaseSchemaError("当前数据库未记录版本号,且现有探测器无法识别其结构。") diff --git a/src/common/database/migrations/schema.py b/src/common/database/migrations/schema.py new file mode 100644 index 00000000..150b8cb7 --- /dev/null +++ b/src/common/database/migrations/schema.py @@ -0,0 +1,98 @@ +"""SQLite 数据库结构探测工具。""" + +from typing import Dict, List + +from sqlalchemy import text +from sqlalchemy.engine import Connection + +from .models import ColumnSchema, DatabaseSchemaSnapshot, TableSchema + + +class SQLiteSchemaInspector: + """SQLite 数据库结构探测器。""" + + def inspect(self, connection: Connection) -> DatabaseSchemaSnapshot: + """提取数据库中的全部用户表结构快照。 + + Args: + connection: 当前数据库连接。 + + Returns: + DatabaseSchemaSnapshot: 当前数据库结构快照。 + """ + tables: Dict[str, TableSchema] = {} + for table_name in self.list_user_tables(connection): + table_schema = self.get_table_schema(connection, table_name) + tables[table_name] = table_schema + return DatabaseSchemaSnapshot(tables=tables) + + def list_user_tables(self, connection: Connection) -> List[str]: + """列出数据库中的全部用户表。 + + Args: + connection: 当前数据库连接。 + + Returns: + List[str]: 按字母顺序排列的用户表名列表。 + """ + statement = text( + """ + SELECT name + FROM sqlite_master + WHERE type = 'table' + AND name NOT LIKE 'sqlite_%' + ORDER BY name + """ + ) + rows = connection.execute(statement).fetchall() + return [str(row[0]) for row in rows] + + def get_table_schema(self, connection: Connection, table_name: str) -> TableSchema: + """获取指定表的结构信息。 + + Args: + connection: 当前数据库连接。 + table_name: 待读取结构的表名。 + + Returns: + TableSchema: 指定表的结构快照。 + """ + quoted_table_name = self._quote_identifier(table_name) + rows = connection.exec_driver_sql(f"PRAGMA table_info({quoted_table_name})").mappings().all() + + columns: Dict[str, ColumnSchema] = {} + for row in rows: + column_schema = ColumnSchema( + name=str(row["name"]), + declared_type=str(row["type"] or ""), + default_value=None if row["dflt_value"] is None else str(row["dflt_value"]), + is_not_null=bool(row["notnull"]), + primary_key_position=int(row["pk"]), + ) + columns[column_schema.name] = column_schema + + return TableSchema(name=table_name, columns=columns) + + def table_exists(self, connection: Connection, table_name: str) -> bool: + """判断数据库中是否存在指定表。 + + Args: + connection: 当前数据库连接。 + table_name: 待检查的表名。 + + Returns: + bool: 若表存在则返回 ``True``。 + """ + return table_name in self.list_user_tables(connection) + + def _quote_identifier(self, identifier: str) -> str: + """为 SQLite 标识符添加安全引号。 + + Args: + identifier: 待引用的 SQLite 标识符。 + + Returns: + str: 可直接拼接到 PRAGMA 语句中的安全标识符。 + """ + escaped_identifier = identifier.replace('"', '""') + return f'"{escaped_identifier}"' diff --git a/src/common/database/migrations/version_store.py b/src/common/database/migrations/version_store.py new file mode 100644 index 00000000..ea1e5077 --- /dev/null +++ b/src/common/database/migrations/version_store.py @@ -0,0 +1,57 @@ +"""SQLite 数据库版本存储实现。""" + +from sqlalchemy.engine import Connection + +from .exceptions import DatabaseMigrationVersionError + + +class SQLiteUserVersionStore: + """基于 ``PRAGMA user_version`` 的 SQLite 版本存储器。""" + + def read_version(self, connection: Connection) -> int: + """读取当前数据库的 schema 版本号。 + + Args: + connection: 当前数据库连接。 + + Returns: + int: 数据库记录的 schema 版本号。 + + Raises: + DatabaseMigrationVersionError: 当读取结果异常或版本号非法时抛出。 + """ + row = connection.exec_driver_sql("PRAGMA user_version").first() + if row is None or len(row) == 0: + raise DatabaseMigrationVersionError("读取 SQLite user_version 失败,返回结果为空。") + + version = row[0] + if not isinstance(version, int): + raise DatabaseMigrationVersionError(f"读取到的 SQLite user_version 不是整数: {version!r}") + if version < 0: + raise DatabaseMigrationVersionError(f"读取到的 SQLite user_version 不能为负数: {version}") + return version + + def write_version(self, connection: Connection, version: int) -> None: + """写入新的 schema 版本号。 + + Args: + connection: 当前数据库连接。 + version: 待写入的 schema 版本号。 + + Raises: + DatabaseMigrationVersionError: 当版本号非法时抛出。 + """ + self._validate_version(version) + connection.exec_driver_sql(f"PRAGMA user_version = {version}") + + def _validate_version(self, version: int) -> None: + """校验版本号是否合法。 + + Args: + version: 待校验的版本号。 + + Raises: + DatabaseMigrationVersionError: 当版本号非法时抛出。 + """ + if version < 0: + raise DatabaseMigrationVersionError(f"SQLite user_version 不能小于 0: {version}") diff --git a/src/common/i18n/manager.py b/src/common/i18n/manager.py index 4d6a1cc8..938d7ef6 100644 --- a/src/common/i18n/manager.py +++ b/src/common/i18n/manager.py @@ -46,9 +46,7 @@ class I18nManager: self._log_once( ("invalid_env_locale", "env", env_locale), logging.WARNING, - "检测到非法 MAIBOT_LOCALE=%s,已回退到默认 locale %s", - env_locale, - self._default_locale, + f"检测到非法 MAIBOT_LOCALE={env_locale},已回退到默认 locale {self._default_locale}", ) return self._default_locale @@ -84,15 +82,14 @@ class I18nManager: self._log_once( ("non_plural_key", translation_locale, key), logging.WARNING, - "翻译 key '%s' 不是 plural 节点,已回退到普通 t()", - key, + f"翻译 key '{key}' 不是 plural 节点,已回退到普通 t()", ) return self.t(key, locale=translation_locale, count=count, **kwargs) try: plural_category = select_plural_category(translation_locale, count) except Exception as exc: - logger.warning("为 key '%s' 选择 plural category 失败: %s,已回退到 other", key, exc) + logger.warning(f"为 key '{key}' 选择 plural category 失败: {exc},已回退到 other") plural_category = "other" template = translation_value.get(plural_category) or translation_value.get("other") @@ -100,8 +97,7 @@ class I18nManager: self._log_once( ("plural_missing_template", translation_locale, key), logging.WARNING, - "翻译 key '%s' 缺少 plural 模板,已回退到 key 本身", - key, + f"翻译 key '{key}' 缺少 plural 模板,已回退到 key 本身", ) return key @@ -125,8 +121,7 @@ class I18nManager: self._log_once( ("plural_missing_other", translation_locale, key), logging.WARNING, - "翻译 key '%s' 缺少 other plural category,已回退到 key 本身", - key, + f"翻译 key '{key}' 缺少 other plural category,已回退到 key 本身", ) return template @@ -134,7 +129,7 @@ class I18nManager: try: return format_template(template, **kwargs) except Exception as exc: - logger.error("翻译 key '%s' 格式化失败: %s", key, exc) + logger.error(f"翻译 key '{key}' 格式化失败: {exc}") return template def _get_translation_value(self, key: str, locale: str | None) -> tuple[TranslationValue | None, str]: @@ -149,20 +144,15 @@ class I18nManager: self._log_once( ("missing_key_fallback", target_locale, key), logging.WARNING, - "翻译 key '%s' 在 locale '%s' 中缺失,已回退到默认 locale '%s'", - key, - target_locale, - self._default_locale, + f"翻译 key '{key}' 在 locale '{target_locale}' 中缺失," + f"已回退到默认 locale '{self._default_locale}'", ) return default_catalog[key], self._default_locale self._log_once( ("missing_key", target_locale, key), logging.WARNING, - "翻译 key '%s' 缺失,locale='%s',默认 locale='%s'", - key, - target_locale, - self._default_locale, + f"翻译 key '{key}' 缺失,locale='{target_locale}',默认 locale='{self._default_locale}'", ) return None, target_locale @@ -177,9 +167,7 @@ class I18nManager: self._log_once( ("invalid_locale", "explicit", locale), logging.WARNING, - "检测到非法 locale='%s',已回退到当前默认 locale %s", - locale, - current_locale, + f"检测到非法 locale='{locale}',已回退到当前默认 locale {current_locale}", ) return current_locale @@ -195,9 +183,7 @@ class I18nManager: self._log_once( ("load_failed", normalized_locale, exc.__class__.__name__), logging.WARNING, - "加载 locale '%s' 失败: %s", - normalized_locale, - exc, + f"加载 locale '{normalized_locale}' 失败: {exc}", ) return {} diff --git a/src/common/prompt_i18n.py b/src/common/prompt_i18n.py index 46b6d70b..358833d1 100644 --- a/src/common/prompt_i18n.py +++ b/src/common/prompt_i18n.py @@ -170,7 +170,7 @@ def _format_prompt_template(name: str, template: str, **kwargs: object) -> str: error = KeyError(t("prompt.missing_placeholder", name=name, placeholder=missing_placeholder)) if is_strict_prompt_i18n_mode(): raise error from exc - logger.error("%s", error) + logger.error(f"{error}") return template except Exception as exc: logger.error(t("prompt.format_failed", name=name, error=exc)) diff --git a/src/common/utils/utils_action.py b/src/common/utils/utils_action.py index c1fe7c28..382957c8 100644 --- a/src/common/utils/utils_action.py +++ b/src/common/utils/utils_action.py @@ -3,12 +3,12 @@ from typing import TYPE_CHECKING, List from src.common.utils.math_utils import translate_timestamp_to_human_readable, TimestampMode if TYPE_CHECKING: - from src.common.data_models.action_record_data_model import MaiActionRecord + from src.common.data_models.tool_record_data_model import MaiToolRecord class ActionUtils: @staticmethod - def build_readable_action_records(action_records: List["MaiActionRecord"], timestamp_mode: str | TimestampMode): + def build_readable_action_records(action_records: List["MaiToolRecord"], timestamp_mode: str | TimestampMode): """ 将动作列表转换为可读的文本格式。 @@ -27,6 +27,6 @@ class ActionUtils: output_lines = [] for record in action_records: timestamp_str = translate_timestamp_to_human_readable(record.timestamp.timestamp(), mode=timestamp_mode) - line = f"在{timestamp_str},你使用了{record.action_name},具体内容是:{record.action_display_prompt}" + line = f"在{timestamp_str},你使用了{record.tool_name},具体内容是:{record.tool_display_prompt}" output_lines.append(line) return "\n".join(output_lines) diff --git a/src/common/utils/utils_message.py b/src/common/utils/utils_message.py index 6b3b5f4e..e1db1d29 100644 --- a/src/common/utils/utils_message.py +++ b/src/common/utils/utils_message.py @@ -579,26 +579,26 @@ class MessageUtils: List[Tuple[float, str]]: 按时间排序的动作文本列表,每个元素为 (timestamp, action_text) """ from src.common.database.database import get_db_session - from src.common.database.database_model import ActionRecord + from src.common.database.database_model import ToolRecord # 获取这个时间范围内的动作记录,并匹配session_id try: with get_db_session() as session: actions_in_range = session.exec( - select(ActionRecord) - .where(col(ActionRecord.timestamp) >= datetime.fromtimestamp(min_time)) - .where(col(ActionRecord.timestamp) <= datetime.fromtimestamp(max_time)) - .where(col(ActionRecord.session_id) == session_id) - .order_by(col(ActionRecord.timestamp)) + select(ToolRecord) + .where(col(ToolRecord.timestamp) >= datetime.fromtimestamp(min_time)) + .where(col(ToolRecord.timestamp) <= datetime.fromtimestamp(max_time)) + .where(col(ToolRecord.session_id) == session_id) + .order_by(col(ToolRecord.timestamp)) ).all() # 获取最新消息之后的第一个动作记录 with get_db_session() as session: action_after_latest = session.exec( - select(ActionRecord) - .where(col(ActionRecord.timestamp) > datetime.fromtimestamp(max_time)) - .where(col(ActionRecord.session_id) == session_id) - .order_by(col(ActionRecord.timestamp)) + select(ToolRecord) + .where(col(ToolRecord.timestamp) > datetime.fromtimestamp(max_time)) + .where(col(ToolRecord.session_id) == session_id) + .order_by(col(ToolRecord.timestamp)) .limit(1) ).all() except Exception as e: @@ -611,7 +611,7 @@ class MessageUtils: # 构建动作文本列表 action_messages: List[Tuple[float, str]] = [] for action in actions: - if action_display_prompt := action.action_display_prompt or "": + if action_display_prompt := action.tool_display_prompt or "": action_time = action.timestamp.timestamp() action_messages.append((action_time, action_display_prompt)) diff --git a/src/common/utils/utils_voice.py b/src/common/utils/utils_voice.py index 651febf0..cef30119 100644 --- a/src/common/utils/utils_voice.py +++ b/src/common/utils/utils_voice.py @@ -4,16 +4,15 @@ from typing import Optional import base64 from src.common.logger import get_logger -from src.config.config import global_config, model_config -from src.llm_models.utils_model import LLMRequest +from src.config.config import global_config +from src.services.llm_service import LLMServiceClient install(extra_lines=3) logger = get_logger("voice_utils") -# TODO: 在LLMRequest重构后修改这里 -asr_model = LLMRequest(model_set=model_config.model_task_config.voice, request_type="audio") +asr_model = LLMServiceClient(task_name="voice", request_type="audio") async def get_voice_text(voice_bytes: bytes) -> Optional[str]: @@ -30,7 +29,8 @@ async def get_voice_text(voice_bytes: bytes) -> Optional[str]: return None try: voice_base64 = base64.b64encode(voice_bytes).decode("utf-8") - text = await asr_model.generate_response_for_voice(voice_base64) + transcription_result = await asr_model.transcribe_audio(voice_base64) + text = transcription_result.text if not text: logger.warning("语音转文字结果为空") diff --git a/src/config/config.py b/src/config/config.py index bce391c3..318c987f 100644 --- a/src/config/config.py +++ b/src/config/config.py @@ -1,6 +1,6 @@ from datetime import datetime from pathlib import Path -from typing import Any, Callable, Mapping, Sequence, TypeVar +from typing import Any, Callable, Mapping, Sequence, TypeVar, cast import asyncio import copy @@ -27,6 +27,7 @@ from .official_configs import ( LPMMKnowledgeConfig, MaiSakaConfig, MaimMessageConfig, + MCPConfig, PluginRuntimeConfig, MemoryConfig, MessageReceiveConfig, @@ -56,8 +57,8 @@ CONFIG_DIR: Path = PROJECT_ROOT / "config" BOT_CONFIG_PATH: Path = (CONFIG_DIR / "bot_config.toml").resolve().absolute() MODEL_CONFIG_PATH: Path = (CONFIG_DIR / "model_config.toml").resolve().absolute() MMC_VERSION: str = "1.0.0" -CONFIG_VERSION: str = "8.1.4" -MODEL_CONFIG_VERSION: str = "1.12.0" +CONFIG_VERSION: str = "8.2.0" +MODEL_CONFIG_VERSION: str = "1.13.1" logger = get_logger("config") @@ -134,6 +135,9 @@ class Config(ConfigBase): maisaka: MaiSakaConfig = Field(default_factory=MaiSakaConfig) """MaiSaka对话系统配置类""" + mcp: MCPConfig = Field(default_factory=MCPConfig) + """MCP 配置类""" + plugin_runtime: PluginRuntimeConfig = Field(default_factory=PluginRuntimeConfig) """插件运行时配置类""" @@ -332,7 +336,12 @@ class ConfigManager: changed_scopes: 本次热重载命中的配置范围。 """ - result = callback(changed_scopes) if self._callback_accepts_scopes(callback) else callback() + if self._callback_accepts_scopes(callback): + callback_with_scopes = cast(Callable[[Sequence[str]], object], callback) + result = callback_with_scopes(changed_scopes) + else: + callback_without_scopes = cast(Callable[[], object], callback) + result = callback_without_scopes() if asyncio.iscoroutine(result): await result diff --git a/src/config/model_configs.py b/src/config/model_configs.py index 6f10ff83..a501be66 100644 --- a/src/config/model_configs.py +++ b/src/config/model_configs.py @@ -1,7 +1,35 @@ +from enum import Enum from typing import Any -from .config_base import ConfigBase, Field from src.common.i18n import t +from .config_base import ConfigBase, Field + + +class OpenAICompatibleAuthType(str, Enum): + """OpenAI 兼容接口的鉴权方式。""" + + BEARER = "bearer" + HEADER = "header" + QUERY = "query" + NONE = "none" + + +class ReasoningParseMode(str, Enum): + """推理内容解析策略。""" + + AUTO = "auto" + NATIVE = "native" + THINK_TAG = "think_tag" + NONE = "none" + + +class ToolArgumentParseMode(str, Enum): + """工具调用参数的解析策略。""" + + AUTO = "auto" + STRICT = "strict" + REPAIR = "repair" + DOUBLE_DECODE = "double_decode" class APIProvider(ConfigBase): @@ -33,7 +61,7 @@ class APIProvider(ConfigBase): "x-icon": "key", }, ) - """API密钥""" + """API密钥。对于不需要鉴权的兼容端点,可将 `auth_type` 设为 `none`。""" client_type: str = Field( default="openai", @@ -44,6 +72,105 @@ class APIProvider(ConfigBase): ) """客户端类型 (可选: openai/google, 默认为openai)""" + auth_type: str = Field( + default=OpenAICompatibleAuthType.BEARER.value, + json_schema_extra={ + "x-widget": "select", + "x-icon": "shield", + }, + ) + """OpenAI 兼容接口的鉴权方式。可选值:`bearer`、`header`、`query`、`none`。""" + + auth_header_name: str = Field( + default="Authorization", + json_schema_extra={ + "x-widget": "input", + "x-icon": "header", + }, + ) + """当 `auth_type` 为 `header` 时使用的请求头名称。""" + + auth_header_prefix: str = Field( + default="Bearer", + json_schema_extra={ + "x-widget": "input", + "x-icon": "shield-check", + }, + ) + """当 `auth_type` 为 `header` 时使用的请求头前缀。留空表示直接发送原始密钥。""" + + auth_query_name: str = Field( + default="api_key", + json_schema_extra={ + "x-widget": "input", + "x-icon": "link", + }, + ) + """当 `auth_type` 为 `query` 时使用的查询参数名称。""" + + default_headers: dict[str, str] = Field( + default_factory=dict, + json_schema_extra={ + "x-widget": "custom", + "x-icon": "header", + }, + ) + """所有请求默认附带的 HTTP Header。""" + + default_query: dict[str, str] = Field( + default_factory=dict, + json_schema_extra={ + "x-widget": "custom", + "x-icon": "list-filter", + }, + ) + """所有请求默认附带的查询参数。""" + + organization: str | None = Field( + default=None, + json_schema_extra={ + "x-widget": "input", + "x-icon": "building-2", + }, + ) + """OpenAI 官方接口可选的 `organization`。""" + + project: str | None = Field( + default=None, + json_schema_extra={ + "x-widget": "input", + "x-icon": "folder-kanban", + }, + ) + """OpenAI 官方接口可选的 `project`。""" + + model_list_endpoint: str = Field( + default="/models", + json_schema_extra={ + "x-widget": "input", + "x-icon": "list", + }, + ) + """模型列表端点路径。适用于 OpenAI 兼容接口的探测与管理。""" + + reasoning_parse_mode: str = Field( + default=ReasoningParseMode.AUTO.value, + json_schema_extra={ + "x-widget": "select", + "x-icon": "brain", + }, + ) + """推理内容解析模式。可选值:`auto`、`native`、`think_tag`、`none`。""" + + tool_argument_parse_mode: str = Field( + default=ToolArgumentParseMode.AUTO.value, + json_schema_extra={ + "x-widget": "select", + "x-icon": "braces", + }, + ) + """工具参数解析模式。可选值:`auto`、`strict`、`repair`、`double_decode`。""" + max_retry: int = Field( default=2, ge=0, @@ -76,15 +203,26 @@ class APIProvider(ConfigBase): ) """重试间隔 (如果API调用失败, 重试的间隔时间, 单位: 秒)""" - def model_post_init(self, context: Any = None): - """确保api_key在repr中不被显示""" - if not self.api_key: + def model_post_init(self, context: Any = None) -> None: + """执行 API 提供商配置的后置校验。 + + Args: + context: Pydantic 传入的上下文对象。 + + Raises: + ValueError: 当配置项缺失或组合不合法时抛出。 + """ + if self.auth_type != OpenAICompatibleAuthType.NONE and not self.api_key: raise ValueError(t("config.api_key_empty")) if not self.base_url and self.client_type != "gemini": # TODO: 允许gemini使用base_url raise ValueError(t("config.api_base_url_empty")) if not self.name: raise ValueError(t("config.api_provider_name_empty")) - return super().model_post_init(context) + if self.auth_type == OpenAICompatibleAuthType.HEADER and not self.auth_header_name.strip(): + raise ValueError("当 auth_type=header 时,auth_header_name 不能为空") + if self.auth_type == OpenAICompatibleAuthType.QUERY and not self.auth_query_name.strip(): + raise ValueError("当 auth_type=query 时,auth_query_name 不能为空") + super().model_post_init(context) class ModelInfo(ConfigBase): @@ -264,6 +402,15 @@ class ModelTaskConfig(ConfigBase): }, ) """首要回复模型配置, 还用于表达器和表达方式学习""" + + planner: TaskConfig = Field( + default_factory=TaskConfig, + json_schema_extra={ + "x-widget": "custom", + "x-icon": "map", + }, + ) + """规划模型配置""" vlm: TaskConfig = Field( default_factory=TaskConfig, @@ -283,24 +430,6 @@ class ModelTaskConfig(ConfigBase): ) """语音识别模型配置""" - tool_use: TaskConfig = Field( - default_factory=TaskConfig, - json_schema_extra={ - "x-widget": "custom", - "x-icon": "tools", - }, - ) - """工具使用模型配置, 需要使用支持工具调用的模型""" - - planner: TaskConfig = Field( - default_factory=TaskConfig, - json_schema_extra={ - "x-widget": "custom", - "x-icon": "map", - }, - ) - """规划模型配置""" - embedding: TaskConfig = Field( default_factory=TaskConfig, json_schema_extra={ @@ -308,22 +437,4 @@ class ModelTaskConfig(ConfigBase): "x-icon": "database", }, ) - """嵌入模型配置""" - - lpmm_entity_extract: TaskConfig = Field( - default_factory=TaskConfig, - json_schema_extra={ - "x-widget": "custom", - "x-icon": "filter", - }, - ) - """LPMM实体提取模型配置""" - - lpmm_rdf_build: TaskConfig = Field( - default_factory=TaskConfig, - json_schema_extra={ - "x-widget": "custom", - "x-icon": "network", - }, - ) - """LPMM RDF构建模型配置""" + """嵌入模型配置""" \ No newline at end of file diff --git a/src/config/official_configs.py b/src/config/official_configs.py index a7470fb3..e72abf49 100644 --- a/src/config/official_configs.py +++ b/src/config/official_configs.py @@ -1,6 +1,8 @@ -from .config_base import ConfigBase, Field +from typing import Literal, Optional + import re -from typing import Optional, Literal + +from .config_base import ConfigBase, Field """ 须知: @@ -234,17 +236,6 @@ class ChatConfig(ConfigBase): ) """上下文长度""" - planner_smooth: float = Field( - default=3, - ge=0, - json_schema_extra={ - "x-widget": "slider", - "x-icon": "gauge", - "step": 0.5, - }, - ) - """规划器平滑,增大数值会减小planner负荷,略微降低反应速度,推荐1-5,0为关闭,必须大于等于0""" - think_mode: Literal["classic", "deep", "dynamic"] = Field( default="dynamic", json_schema_extra={ @@ -677,21 +668,6 @@ class ExpressionConfig(ConfigBase): ) """是否在回复前尝试对上下文中的黑话进行解释(关闭可减少一次LLM调用,仅影响回复前的黑话匹配与解释,不影响黑话学习)""" - jargon_mode: Literal["context", "planner"] = Field( - default="planner", - json_schema_extra={ - "x-widget": "select", - "x-icon": "settings", - }, - ) - """ - 黑话解释来源模式 - - 可选: - - "context":使用上下文自动匹配黑话 - - "planner":仅使用Planner在reply动作中给出的unknown_words列表 - """ - class ToolConfig(ConfigBase): """工具配置类""" @@ -1528,33 +1504,6 @@ class MaiSakaConfig(ConfigBase): __ui_icon__ = "message-circle" __ui_parent__ = "experimental" - enable_emotion_module: bool = Field( - default=True, - json_schema_extra={ - "x-widget": "switch", - "x-icon": "heart", - }, - ) - """启用情绪感知模块""" - - enable_cognition_module: bool = Field( - default=True, - json_schema_extra={ - "x-widget": "switch", - "x-icon": "brain", - }, - ) - """启用认知分析模块""" - - enable_timing_module: bool = Field( - default=True, - json_schema_extra={ - "x-widget": "switch", - "x-icon": "clock", - }, - ) - """启用时间感知模块(含自我反思功能)""" - enable_knowledge_module: bool = Field( default=True, json_schema_extra={ @@ -1564,42 +1513,6 @@ class MaiSakaConfig(ConfigBase): ) """启用知识库模块""" - enable_mcp: bool = Field( - default=True, - json_schema_extra={ - "x-widget": "switch", - "x-icon": "zap", - }, - ) - """启用 MCP (Model Context Protocol) 支持""" - - enable_write_file: bool = Field( - default=True, - json_schema_extra={ - "x-widget": "switch", - "x-icon": "file-plus", - }, - ) - """启用文件写入工具""" - - enable_read_file: bool = Field( - default=True, - json_schema_extra={ - "x-widget": "switch", - "x-icon": "file-text", - }, - ) - """启用文件读取工具""" - - enable_list_files: bool = Field( - default=True, - json_schema_extra={ - "x-widget": "switch", - "x-icon": "list", - }, - ) - """启用文件列表工具""" - show_analyze_cognition_prompt: bool = Field( default=False, json_schema_extra={ @@ -1609,15 +1522,6 @@ class MaiSakaConfig(ConfigBase): ) """是否在 CLI 中显示 analyze_cognition 的 Prompt""" - show_analyze_timing_prompt: bool = Field( - default=False, - json_schema_extra={ - "x-widget": "switch", - "x-icon": "terminal", - }, - ) - """是否在 CLI 中显示 analyze_timing 的 Prompt""" - show_thinking: bool = Field( default=True, json_schema_extra={ @@ -1625,7 +1529,7 @@ class MaiSakaConfig(ConfigBase): "x-icon": "brain", }, ) - """鏄惁鍦?CLI 涓樉绀哄唴蹇冩€濊€冨拰瀹屾暣 Prompt""" + """是否显示MaiSaka思考过程""" user_name: str = Field( default="用户", @@ -1634,7 +1538,465 @@ class MaiSakaConfig(ConfigBase): "x-icon": "user", }, ) - """MaiSaka 涓敤鎴风殑鏄剧ず鍚嶇О""" + """MaiSaka 使用的用户名称""" + + direct_image_input: bool = Field( + default=True, + json_schema_extra={ + "x-widget": "switch", + "x-icon": "image", + }, + ) + """是否直接输入图片""" + + merge_user_messages: bool = Field( + default=True, + json_schema_extra={ + "x-widget": "switch", + "x-icon": "merge", + }, + ) + """是否将新接收的用户发言合并为单个用户消息""" + + max_internal_rounds: int = Field( + default=6, + ge=1, + json_schema_extra={ + "x-widget": "input", + "x-icon": "repeat", + }, + ) + """每个入站消息的最大内部规划轮数""" + + tool_filter_task_name: str = Field( + default="utils", + json_schema_extra={ + "x-widget": "input", + "x-icon": "sparkles", + }, + ) + """工具筛选预判使用的模型任务名""" + + tool_filter_threshold: int = Field( + default=20, + ge=1, + json_schema_extra={ + "x-widget": "input", + "x-icon": "filter", + }, + ) + """当可用工具总数超过该阈值时,先进行一轮工具筛选""" + + tool_filter_max_keep: int = Field( + default=5, + ge=1, + json_schema_extra={ + "x-widget": "input", + "x-icon": "list-filter", + }, + ) + """工具筛选阶段最多保留的非内置工具数量""" + + terminal_image_preview: bool = Field( + default=False, + json_schema_extra={ + "x-widget": "switch", + "x-icon": "image", + }, + ) + """是否渲染低分辨率终端预览图片""" + + terminal_image_preview_width: int = Field( + default=24, + ge=8, + json_schema_extra={ + "x-widget": "input", + "x-icon": "columns", + }, + ) + """Maisaka终端图片预览的字符宽度""" + + +class MCPAuthorizationConfig(ConfigBase): + """MCP HTTP 认证配置。""" + + mode: Literal["none", "bearer"] = Field( + default="none", + json_schema_extra={ + "x-widget": "select", + "x-icon": "shield", + }, + ) + """认证模式,当前支持无认证和静态 Bearer Token""" + + bearer_token: str = Field( + default="", + json_schema_extra={ + "x-widget": "password", + "x-icon": "key", + }, + ) + """静态 Bearer Token,仅在 `mode=\"bearer\"` 时使用""" + + def model_post_init(self, context: Optional[dict] = None) -> None: + """验证 MCP 认证配置。 + + Args: + context: Pydantic 传入的上下文对象。 + + Returns: + None + """ + + if self.mode == "bearer" and not self.bearer_token.strip(): + raise ValueError("MCP 使用 bearer 认证时必须填写 bearer_token") + return super().model_post_init(context) + + +class MCPRootItemConfig(ConfigBase): + """单个 MCP Root 配置。""" + + enabled: bool = Field( + default=True, + json_schema_extra={ + "x-widget": "switch", + "x-icon": "power", + }, + ) + """是否启用当前 Root""" + + uri: str = Field( + default="", + json_schema_extra={ + "x-widget": "input", + "x-icon": "folder", + }, + ) + """Root URI,通常为 `file://` 路径 URI""" + + name: str = Field( + default="", + json_schema_extra={ + "x-widget": "input", + "x-icon": "tag", + }, + ) + """Root 的显示名称""" + + def model_post_init(self, context: Optional[dict] = None) -> None: + """验证单个 Root 配置。 + + Args: + context: Pydantic 传入的上下文对象。 + + Returns: + None + """ + + if self.enabled and not self.uri.strip(): + raise ValueError("启用的 MCP Root 必须填写 uri") + return super().model_post_init(context) + + +class MCPRootsConfig(ConfigBase): + """MCP Roots 能力配置。""" + + enable: bool = Field( + default=False, + json_schema_extra={ + "x-widget": "switch", + "x-icon": "folder-tree", + }, + ) + """是否向 MCP 服务器暴露 Roots 能力""" + + items: list[MCPRootItemConfig] = Field( + default_factory=lambda: [], + json_schema_extra={ + "x-widget": "custom", + "x-icon": "folder", + }, + ) + """Roots 列表""" + + +class MCPSamplingConfig(ConfigBase): + """MCP Sampling 能力配置。""" + + enable: bool = Field( + default=False, + json_schema_extra={ + "x-widget": "switch", + "x-icon": "brain", + }, + ) + """是否启用 Sampling 能力声明""" + + task_name: str = Field( + default="planner", + json_schema_extra={ + "x-widget": "input", + "x-icon": "sparkles", + }, + ) + """执行 Sampling 请求时使用的主程序模型任务名""" + + include_context_support: bool = Field( + default=False, + json_schema_extra={ + "x-widget": "switch", + "x-icon": "layers", + }, + ) + """是否声明支持 `includeContext` 非 `none` 语义""" + + tool_support: bool = Field( + default=False, + json_schema_extra={ + "x-widget": "switch", + "x-icon": "wrench", + }, + ) + """是否声明支持在 Sampling 中继续使用工具""" + + +class MCPElicitationConfig(ConfigBase): + """MCP Elicitation 能力配置。""" + + enable: bool = Field( + default=False, + json_schema_extra={ + "x-widget": "switch", + "x-icon": "message-circle-question", + }, + ) + """是否启用 Elicitation 能力声明""" + + allow_form: bool = Field( + default=True, + json_schema_extra={ + "x-widget": "switch", + "x-icon": "form-input", + }, + ) + """是否允许表单模式 Elicitation""" + + allow_url: bool = Field( + default=False, + json_schema_extra={ + "x-widget": "switch", + "x-icon": "link", + }, + ) + """是否允许 URL 模式 Elicitation""" + + def model_post_init(self, context: Optional[dict] = None) -> None: + """验证 Elicitation 配置。 + + Args: + context: Pydantic 传入的上下文对象。 + + Returns: + None + """ + + if self.enable and not (self.allow_form or self.allow_url): + raise ValueError("启用 MCP Elicitation 时至少需要允许一种模式") + return super().model_post_init(context) + + +class MCPClientConfig(ConfigBase): + """MCP 客户端宿主能力配置。""" + + client_name: str = Field( + default="MaiBot", + json_schema_extra={ + "x-widget": "input", + "x-icon": "bot", + }, + ) + """MCP 客户端实现名称""" + + client_version: str = Field( + default="1.0.0", + json_schema_extra={ + "x-widget": "input", + "x-icon": "info", + }, + ) + """MCP 客户端实现版本""" + + roots: MCPRootsConfig = Field(default_factory=MCPRootsConfig) + """Roots 能力配置""" + + sampling: MCPSamplingConfig = Field(default_factory=MCPSamplingConfig) + """Sampling 能力配置""" + + elicitation: MCPElicitationConfig = Field(default_factory=MCPElicitationConfig) + """Elicitation 能力配置""" + + +class MCPServerItemConfig(ConfigBase): + """单个 MCP 服务器配置。""" + + name: str = Field( + default="", + json_schema_extra={ + "x-widget": "input", + "x-icon": "tag", + }, + ) + """服务器名称,必须唯一""" + + enabled: bool = Field( + default=True, + json_schema_extra={ + "x-widget": "switch", + "x-icon": "power", + }, + ) + """是否启用当前 MCP 服务器""" + + transport: Literal["stdio", "streamable_http"] = Field( + default="stdio", + json_schema_extra={ + "x-widget": "select", + "x-icon": "shuffle", + }, + ) + """传输方式,可选 `stdio` 或 `streamable_http`""" + + command: str = Field( + default="", + json_schema_extra={ + "x-widget": "input", + "x-icon": "terminal", + }, + ) + """stdio 模式下启动服务器的命令""" + + args: list[str] = Field( + default_factory=lambda: [], + json_schema_extra={ + "x-widget": "custom", + "x-icon": "list", + }, + ) + """stdio 模式下的命令参数列表""" + + env: dict[str, str] = Field( + default_factory=lambda: {}, + json_schema_extra={ + "x-widget": "custom", + "x-icon": "variable", + }, + ) + """stdio 模式下附加的环境变量""" + + url: str = Field( + default="", + json_schema_extra={ + "x-widget": "input", + "x-icon": "link", + }, + ) + """`streamable_http` 模式下的 MCP 端点地址""" + + headers: dict[str, str] = Field( + default_factory=lambda: {}, + json_schema_extra={ + "x-widget": "custom", + "x-icon": "file-json", + }, + ) + """HTTP 模式下附加的请求头""" + + http_timeout_seconds: float = Field( + default=30.0, + gt=0, + json_schema_extra={ + "x-widget": "number", + "x-icon": "clock-3", + }, + ) + """HTTP 请求超时时间,单位秒""" + + read_timeout_seconds: float = Field( + default=300.0, + gt=0, + json_schema_extra={ + "x-widget": "number", + "x-icon": "timer", + }, + ) + """会话读取超时时间,单位秒""" + + authorization: MCPAuthorizationConfig = Field(default_factory=MCPAuthorizationConfig) + """HTTP 认证配置""" + + def model_post_init(self, context: Optional[dict] = None) -> None: + """验证 MCP 服务器配置。 + + Args: + context: Pydantic 传入的上下文对象。 + + Returns: + None + """ + + if not self.name.strip(): + raise ValueError("MCPServerItemConfig.name 不能为空") + + if self.transport == "stdio" and not self.command.strip(): + raise ValueError(f"MCP 服务器 {self.name} 使用 stdio 时必须填写 command") + + if self.transport == "streamable_http" and not self.url.strip(): + raise ValueError(f"MCP 服务器 {self.name} 使用 streamable_http 时必须填写 url") + + return super().model_post_init(context) + + +class MCPConfig(ConfigBase): + """MCP 总配置。""" + + __ui_parent__ = "maisaka" + + enable: bool = Field( + default=True, + json_schema_extra={ + "x-widget": "switch", + "x-icon": "zap", + }, + ) + """是否启用 MCP(Model Context Protocol)""" + + client: MCPClientConfig = Field(default_factory=MCPClientConfig) + """MCP 客户端宿主能力配置""" + + servers: list[MCPServerItemConfig] = Field( + default_factory=lambda: [], + json_schema_extra={ + "x-widget": "custom", + "x-icon": "server", + }, + ) + """_wrap_MCP 服务器配置列表""" + + def model_post_init(self, context: Optional[dict] = None) -> None: + """验证 MCP 总配置。 + + Args: + context: Pydantic 传入的上下文对象。 + + Returns: + None + """ + + server_names = [server.name.strip() for server in self.servers if server.name.strip()] + if len(server_names) != len(set(server_names)): + raise ValueError("MCP 配置中的服务器名称不能重复") + return super().model_post_init(context) + class PluginRuntimeConfig(ConfigBase): """插件运行时配置类""" diff --git a/src/core/tooling.py b/src/core/tooling.py new file mode 100644 index 00000000..f9c6ec62 --- /dev/null +++ b/src/core/tooling.py @@ -0,0 +1,404 @@ +"""统一工具抽象。 + +该模块定义主程序内部统一使用的工具声明、调用与执行结果模型, +用于收敛插件 Tool、兼容旧 Action、MaiSaka 内置 Tool 与 MCP Tool。 +""" + +from __future__ import annotations + +from copy import deepcopy +from dataclasses import dataclass, field +import json +from typing import Any, Dict, Literal, Optional, Protocol, runtime_checkable + +from src.common.logger import get_logger +from src.llm_models.payload_content.tool_option import ToolDefinitionInput + +logger = get_logger("core.tooling") + + +def _normalize_schema_type(raw_type: Any) -> str: + """将原始 Schema 类型值规范化为可读字符串。 + + Args: + raw_type: 原始类型值。 + + Returns: + str: 规范化后的类型名称。 + """ + + normalized_type = str(raw_type or "").strip().lower() + if not normalized_type: + return "string" + if normalized_type == "number": + return "number" + if normalized_type == "integer": + return "integer" + if normalized_type == "boolean": + return "boolean" + if normalized_type == "array": + return "array" + if normalized_type == "object": + return "object" + return normalized_type + + +def build_tool_detailed_description( + parameters_schema: Optional[Dict[str, Any]], + fallback_description: str = "", +) -> str: + """根据参数 Schema 构建工具详细描述。 + + Args: + parameters_schema: 工具参数对象级 Schema。 + fallback_description: 无法从 Schema 解析时使用的兜底说明。 + + Returns: + str: 生成后的详细描述文本。 + """ + + if not parameters_schema: + return fallback_description.strip() + + properties = parameters_schema.get("properties") + if not isinstance(properties, dict) or not properties: + return fallback_description.strip() + + required_names = { + str(name).strip() + for name in parameters_schema.get("required", []) + if str(name).strip() + } + + lines = ["参数说明:"] + for parameter_name, parameter_schema in properties.items(): + if not isinstance(parameter_schema, dict): + continue + + normalized_name = str(parameter_name).strip() + parameter_type = _normalize_schema_type(parameter_schema.get("type")) + required_text = "必填" if normalized_name in required_names else "可选" + parameter_description = str(parameter_schema.get("description", "") or "").strip() or "无额外说明" + line = f"- {normalized_name}:{parameter_type},{required_text}。{parameter_description}" + + if isinstance(parameter_schema.get("enum"), list) and parameter_schema["enum"]: + enum_values = "、".join(str(item) for item in parameter_schema["enum"]) + line += f" 可选值:{enum_values}。" + + if "default" in parameter_schema: + line += f" 默认值:{parameter_schema['default']}。" + + lines.append(line) + + if len(lines) == 1: + return fallback_description.strip() + + if fallback_description.strip(): + lines.append("") + lines.append(fallback_description.strip()) + return "\n".join(lines).strip() + + +@dataclass(slots=True) +class ToolIcon: + """统一工具图标信息。""" + + src: str + mime_type: str = "" + sizes: list[str] = field(default_factory=list) + + +@dataclass(slots=True) +class ToolAnnotation: + """统一工具注解信息。""" + + audience: list[str] = field(default_factory=list) + priority: float | None = None + metadata: Dict[str, Any] = field(default_factory=dict) + + +@dataclass(slots=True) +class ToolContentItem: + """统一工具内容项。""" + + content_type: Literal["text", "image", "audio", "resource_link", "resource", "binary", "unknown"] + text: str = "" + data: str = "" + mime_type: str = "" + uri: str = "" + name: str = "" + description: str = "" + annotation: ToolAnnotation | None = None + metadata: Dict[str, Any] = field(default_factory=dict) + + def build_history_text(self) -> str: + """生成适合写入历史消息的文本摘要。 + + Returns: + str: 当前内容项对应的历史摘要文本。 + """ + + if self.content_type == "text" and self.text.strip(): + return self.text.strip() + if self.content_type == "image": + return f"[图片内容 {self.mime_type or 'unknown'}]" + if self.content_type == "audio": + return f"[音频内容 {self.mime_type or 'unknown'}]" + if self.content_type == "resource_link": + label = self.name or self.uri or "资源链接" + return f"[资源链接] {label}" + if self.content_type == "resource": + if self.text.strip(): + return self.text.strip() + label = self.name or self.uri or "嵌入资源" + return f"[嵌入资源] {label}" + if self.content_type == "binary": + return f"[二进制内容 {self.mime_type or 'unknown'}]" + return f"[{self.content_type} 内容]" + + +@dataclass(slots=True) +class ToolSpec: + """统一工具声明。""" + + name: str + brief_description: str + detailed_description: str = "" + title: str = "" + parameters_schema: Dict[str, Any] | None = None + output_schema: Dict[str, Any] | None = None + provider_name: str = "" + provider_type: str = "" + enabled: bool = True + icons: list[ToolIcon] = field(default_factory=list) + annotation: ToolAnnotation | None = None + metadata: Dict[str, Any] = field(default_factory=dict) + + def build_llm_description(self) -> str: + """构建供 LLM 使用的描述文本。 + + Returns: + str: 合并后的单段工具描述。 + """ + + parts = [self.brief_description.strip()] + if self.detailed_description.strip(): + parts.append(self.detailed_description.strip()) + return "\n\n".join(part for part in parts if part).strip() + + def to_llm_definition(self) -> ToolDefinitionInput: + """转换为统一的 LLM 工具定义。 + + Returns: + ToolDefinitionInput: 可直接交给模型层的工具定义。 + """ + + definition: Dict[str, Any] = { + "name": self.name, + "description": self.build_llm_description(), + } + if self.parameters_schema is not None: + definition["parameters_schema"] = deepcopy(self.parameters_schema) + return definition + + +@dataclass(slots=True) +class ToolInvocation: + """统一工具调用请求。""" + + tool_name: str + arguments: Dict[str, Any] = field(default_factory=dict) + call_id: str = "" + session_id: str = "" + stream_id: str = "" + reasoning: str = "" + metadata: Dict[str, Any] = field(default_factory=dict) + + +@dataclass(slots=True) +class ToolExecutionContext: + """统一工具执行上下文。""" + + session_id: str = "" + stream_id: str = "" + reasoning: str = "" + metadata: Dict[str, Any] = field(default_factory=dict) + + +@dataclass(slots=True) +class ToolExecutionResult: + """统一工具执行结果。""" + + tool_name: str + success: bool + content: str = "" + error_message: str = "" + structured_content: Any = None + content_items: list[ToolContentItem] = field(default_factory=list) + metadata: Dict[str, Any] = field(default_factory=dict) + + def get_history_content(self) -> str: + """获取适合写入对话历史的结果文本。 + + Returns: + str: 优先使用文本内容,其次使用错误信息。 + """ + + if self.content.strip(): + return self.content.strip() + if self.content_items: + parts = [item.build_history_text() for item in self.content_items if item.build_history_text().strip()] + if parts: + return "\n".join(parts).strip() + if self.structured_content is not None: + if isinstance(self.structured_content, str): + return self.structured_content.strip() + try: + return json.dumps(self.structured_content, ensure_ascii=False) + except (TypeError, ValueError): + return str(self.structured_content).strip() + return self.error_message.strip() + + +@runtime_checkable +class ToolProvider(Protocol): + """统一工具提供者协议。""" + + provider_name: str + provider_type: str + + async def list_tools(self) -> list[ToolSpec]: + """列出当前 Provider 暴露的全部工具。""" + ... + + async def invoke( + self, + invocation: ToolInvocation, + context: Optional[ToolExecutionContext] = None, + ) -> ToolExecutionResult: + """执行指定工具调用。""" + ... + + async def close(self) -> None: + """释放 Provider 资源。""" + ... + + +class ToolRegistry: + """统一工具注册表。""" + + def __init__(self) -> None: + """初始化统一工具注册表。""" + + self._providers: list[ToolProvider] = [] + + def register_provider(self, provider: ToolProvider) -> None: + """注册一个工具提供者。 + + Args: + provider: 待注册的工具提供者。 + """ + + self._providers = [item for item in self._providers if item.provider_name != provider.provider_name] + self._providers.append(provider) + + def unregister_provider(self, provider_name: str) -> None: + """注销指定名称的工具提供者。 + + Args: + provider_name: 待移除的 Provider 名称。 + """ + + self._providers = [item for item in self._providers if item.provider_name != provider_name] + + async def list_tools(self) -> list[ToolSpec]: + """按 Provider 顺序列出全部去重后的工具。 + + Returns: + list[ToolSpec]: 去重后的工具列表。 + """ + + collected_specs: list[ToolSpec] = [] + seen_names: set[str] = set() + + for provider in self._providers: + provider_specs = await provider.list_tools() + for spec in provider_specs: + if not spec.enabled: + continue + if spec.name in seen_names: + logger.warning( + f"检测到重复工具名 {spec.name},保留先注册的工具,跳过 provider={provider.provider_name}" + ) + continue + seen_names.add(spec.name) + collected_specs.append(spec) + return collected_specs + + async def get_tool_spec(self, tool_name: str) -> Optional[ToolSpec]: + """查询指定工具声明。 + + Args: + tool_name: 工具名称。 + + Returns: + Optional[ToolSpec]: 匹配到的工具声明。 + """ + + for spec in await self.list_tools(): + if spec.name == tool_name: + return spec + return None + + async def has_tool(self, tool_name: str) -> bool: + """判断指定工具是否存在。 + + Args: + tool_name: 工具名称。 + + Returns: + bool: 是否存在。 + """ + + return await self.get_tool_spec(tool_name) is not None + + async def get_llm_definitions(self) -> list[ToolDefinitionInput]: + """获取供 LLM 使用的工具定义列表。 + + Returns: + list[ToolDefinitionInput]: 统一工具定义列表。 + """ + + return [spec.to_llm_definition() for spec in await self.list_tools()] + + async def invoke( + self, + invocation: ToolInvocation, + context: Optional[ToolExecutionContext] = None, + ) -> ToolExecutionResult: + """执行一次工具调用。 + + Args: + invocation: 工具调用请求。 + context: 执行上下文。 + + Returns: + ToolExecutionResult: 工具执行结果。 + """ + + for provider in self._providers: + provider_specs = await provider.list_tools() + if any(spec.name == invocation.tool_name and spec.enabled for spec in provider_specs): + return await provider.invoke(invocation, context) + + return ToolExecutionResult( + tool_name=invocation.tool_name, + success=False, + error_message=f"未找到工具:{invocation.tool_name}", + ) + + async def close(self) -> None: + """关闭全部 Provider。""" + + for provider in self._providers: + await provider.close() diff --git a/src/core/types.py b/src/core/types.py index 535352f3..aff857a3 100644 --- a/src/core/types.py +++ b/src/core/types.py @@ -1,12 +1,13 @@ -import copy -import warnings from dataclasses import dataclass, field, fields from enum import Enum -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional + +import copy +import warnings + from maim_message import Seg -from src.llm_models.payload_content.tool_option import ToolParamType as ToolParamType -from src.llm_models.payload_content.tool_option import ToolCall as ToolCall +from src.llm_models.payload_content.tool_option import ToolCall # from src.common.data_models.message_data_model import ReplyContentType as ReplyContentType # from src.common.data_models.message_data_model import ReplyContent as ReplyContent # from src.common.data_models.message_data_model import ForwardNode as ForwardNode @@ -15,49 +16,42 @@ from src.llm_models.payload_content.tool_option import ToolCall as ToolCall # 组件类型枚举 class ComponentType(Enum): - """组件类型枚举""" + """Host 内部使用的组件类型枚举。""" ACTION = "action" # 动作组件 COMMAND = "command" # 命令组件 - TOOL = "tool" # 服务组件(预留) - SCHEDULER = "scheduler" # 定时任务组件(预留) - EVENT_HANDLER = "event_handler" # 事件处理组件(预留) + TOOL = "tool" # 工具组件 def __str__(self) -> str: + """返回枚举值字符串。 + + Returns: + str: 当前组件类型对应的字符串值。 + """ return self.value # 动作激活类型枚举 class ActionActivationType(Enum): - """动作激活类型枚举""" + """动作激活类型枚举。""" NEVER = "never" # 从不激活(默认关闭) ALWAYS = "always" # 默认参与到planner RANDOM = "random" # 随机启用action到planner KEYWORD = "keyword" # 关键词触发启用action到planner - def __str__(self): - return self.value + def __str__(self) -> str: + """返回枚举值字符串。 - -# 聊天模式枚举 -class ChatMode(Enum): - """聊天模式枚举""" - - FOCUS = "focus" # Focus聊天模式 - NORMAL = "normal" # Normal聊天模式 - PRIORITY = "priority" # 优先级聊天模式 - ALL = "all" # 所有聊天模式 - - def __str__(self): + Returns: + str: 当前激活类型对应的字符串值。 + """ return self.value # 事件类型枚举 class EventType(Enum): - """ - 事件类型枚举类 - """ + """事件类型枚举。""" ON_START = "on_start" # 启动事件,用于调用按时任务 ON_STOP = "on_stop" # 停止事件,用于调用按时任务 @@ -72,185 +66,96 @@ class EventType(Enum): UNKNOWN = "unknown" # 未知事件类型 def __str__(self) -> str: + """返回枚举值字符串。 + + Returns: + str: 当前事件类型对应的字符串值。 + """ return self.value -@dataclass -class PythonDependency: - """Python包依赖信息""" - - package_name: str # 包名称 - version: str = "" # 版本要求,例如: ">=1.0.0", "==2.1.3", ""表示任意版本 - optional: bool = False # 是否为可选依赖 - description: str = "" # 依赖描述 - install_name: str = "" # 安装时的包名(如果与import名不同) - - def __post_init__(self): - if not self.install_name: - self.install_name = self.package_name - - def get_pip_requirement(self) -> str: - """获取pip安装格式的依赖字符串""" - if self.version: - return f"{self.install_name}{self.version}" - return self.install_name - - -@dataclass +@dataclass(slots=True) class ComponentInfo: - """组件信息""" + """Host 内部使用的组件信息快照。""" - name: str # 组件名称 - component_type: ComponentType # 组件类型 - description: str = "" # 组件描述 - enabled: bool = True # 是否启用 - plugin_name: str = "" # 所属插件名称 - is_built_in: bool = False # 是否为内置组件 - metadata: Dict[str, Any] = field(default_factory=dict) # 额外元数据 + name: str + """组件名称。""" - def __post_init__(self): - if self.metadata is None: - self.metadata = {} + description: str = "" + """组件描述。""" + + enabled: bool = True + """组件是否启用。""" + + plugin_name: str = "" + """所属插件 ID。""" + + component_type: ComponentType = field(init=False) + """组件类型。""" -@dataclass +@dataclass(slots=True) class ActionInfo(ComponentInfo): - """动作组件信息""" + """供 Planner 与回复链使用的动作信息快照。""" action_parameters: Dict[str, str] = field( default_factory=dict ) # 动作参数与描述,例如 {"param1": "描述1", "param2": "描述2"} action_require: List[str] = field(default_factory=list) # 动作需求说明 associated_types: List[str] = field(default_factory=list) # 关联的消息类型 - # 激活类型相关 - focus_activation_type: ActionActivationType = ActionActivationType.ALWAYS # 已弃用 - normal_activation_type: ActionActivationType = ActionActivationType.ALWAYS # 已弃用 activation_type: ActionActivationType = ActionActivationType.ALWAYS random_activation_probability: float = 0.0 activation_keywords: List[str] = field(default_factory=list) # 激活关键词列表 keyword_case_sensitive: bool = False - # 模式和并行设置 parallel_action: bool = False + component_type: ComponentType = field(init=False, default=ComponentType.ACTION) + """组件类型。""" - def __post_init__(self): - super().__post_init__() - if self.activation_keywords is None: - self.activation_keywords = [] - if self.action_parameters is None: - self.action_parameters = {} - if self.action_require is None: - self.action_require = [] - if self.associated_types is None: - self.associated_types = [] - self.component_type = ComponentType.ACTION + def __post_init__(self) -> None: + """归一化动作快照中的集合字段。""" + self.action_parameters = dict(self.action_parameters or {}) + self.action_require = list(self.action_require or []) + self.associated_types = list(self.associated_types or []) + self.activation_keywords = list(self.activation_keywords or []) -@dataclass +@dataclass(slots=True) class CommandInfo(ComponentInfo): - """命令组件信息""" + """供命令处理链使用的命令信息快照。""" - command_pattern: str = "" # 命令匹配模式(正则表达式) - - def __post_init__(self): - super().__post_init__() - self.component_type = ComponentType.COMMAND + component_type: ComponentType = field(init=False, default=ComponentType.COMMAND) + """组件类型。""" -@dataclass +@dataclass(slots=True) class ToolInfo(ComponentInfo): - """工具组件信息""" + """供工具执行链使用的工具信息快照。""" - tool_parameters: List[Tuple[str, ToolParamType, str, bool, List[str] | None]] = field( - default_factory=list - ) # 工具参数定义 - tool_description: str = "" # 工具描述 + parameters_schema: Dict[str, Any] | None = None + """对象级工具参数 Schema。""" - def __post_init__(self): - super().__post_init__() - self.component_type = ComponentType.TOOL + component_type: ComponentType = field(init=False, default=ComponentType.TOOL) + """组件类型。""" - def get_llm_definition(self) -> dict: - """生成 LLM function-calling 所需的工具定义""" - return { + def get_llm_definition(self) -> Dict[str, Any]: + """生成供 LLM 使用的规范化工具定义。 + + Returns: + Dict[str, Any]: 统一工具定义字典。 + """ + definition: Dict[str, Any] = { "name": self.name, - "description": self.tool_description, - "parameters": self.tool_parameters, + "description": self.description, } + if self.parameters_schema is not None: + definition["parameters_schema"] = copy.deepcopy(self.parameters_schema) + return definition -@dataclass -class EventHandlerInfo(ComponentInfo): - """事件处理器组件信息""" - - event_type: EventType | str = EventType.ON_MESSAGE # 监听事件类型 - intercept_message: bool = False # 是否拦截消息处理(默认不拦截) - weight: int = 0 # 事件处理器权重,决定执行顺序 - - def __post_init__(self): - super().__post_init__() - self.component_type = ComponentType.EVENT_HANDLER - - -@dataclass -class PluginInfo: - """插件信息""" - - display_name: str # 插件显示名称 - name: str # 插件名称 - description: str # 插件描述 - version: str = "1.0.0" # 插件版本 - author: str = "" # 插件作者 - enabled: bool = True # 是否启用 - is_built_in: bool = False # 是否为内置插件 - components: List[ComponentInfo] = field(default_factory=list) # 包含的组件列表 - dependencies: List[str] = field(default_factory=list) # 依赖的其他插件 - python_dependencies: List[PythonDependency] = field(default_factory=list) # Python包依赖 - config_file: str = "" # 配置文件路径 - metadata: Dict[str, Any] = field(default_factory=dict) # 额外元数据 - # 新增:manifest相关信息 - manifest_data: Dict[str, Any] = field(default_factory=dict) # manifest文件数据 - license: str = "" # 插件许可证 - homepage_url: str = "" # 插件主页 - repository_url: str = "" # 插件仓库地址 - keywords: List[str] = field(default_factory=list) # 插件关键词 - categories: List[str] = field(default_factory=list) # 插件分类 - min_host_version: str = "" # 最低主机版本要求 - max_host_version: str = "" # 最高主机版本要求 - - def __post_init__(self): - if self.components is None: - self.components = [] - if self.dependencies is None: - self.dependencies = [] - if self.python_dependencies is None: - self.python_dependencies = [] - if self.metadata is None: - self.metadata = {} - if self.manifest_data is None: - self.manifest_data = {} - if self.keywords is None: - self.keywords = [] - if self.categories is None: - self.categories = [] - - def get_missing_packages(self) -> List[PythonDependency]: - """检查缺失的Python包""" - missing = [] - for dep in self.python_dependencies: - try: - __import__(dep.package_name) - except ImportError: - if not dep.optional: - missing.append(dep) - return missing - - def get_pip_requirements(self) -> List[str]: - """获取所有pip安装格式的依赖""" - return [dep.get_pip_requirement() for dep in self.python_dependencies] - - -@dataclass +@dataclass(slots=True) class ModifyFlag: + """消息修改标记集合。""" + modify_message_segments: bool = False modify_plain_text: bool = False modify_llm_prompt: bool = False @@ -258,9 +163,9 @@ class ModifyFlag: modify_llm_response_reasoning: bool = False -@dataclass +@dataclass(slots=True) class MaiMessages: - """MaiM插件消息""" + """核心事件系统使用的统一消息模型。""" message_segments: List[Seg] = field(default_factory=list) """消息段列表,支持多段消息""" @@ -306,11 +211,17 @@ class MaiMessages: _modify_flags: ModifyFlag = field(default_factory=ModifyFlag) - def __post_init__(self): + def __post_init__(self) -> None: + """归一化消息段列表。""" if self.message_segments is None: self.message_segments = [] - def deepcopy(self): + def deepcopy(self) -> "MaiMessages": + """深拷贝当前消息对象。 + + Returns: + MaiMessages: 深拷贝后的消息对象。 + """ return copy.deepcopy(self) def to_transport_dict(self) -> Dict[str, Any]: @@ -347,6 +258,14 @@ class MaiMessages: @staticmethod def _serialize_transport_value(value: Any) -> Any: + """递归序列化字段值为可传输结构。 + + Args: + value: 任意字段值。 + + Returns: + Any: 可用于 IPC 传输的纯 Python 值。 + """ if isinstance(value, (str, int, float, bool)) or value is None: return value if isinstance(value, Enum): @@ -367,13 +286,22 @@ class MaiMessages: @staticmethod def _deserialize_transport_field(field_name: str, value: Any) -> Any: + """反序列化特定字段的传输值。 + + Args: + field_name: 字段名称。 + value: 传输层返回的字段值。 + + Returns: + Any: 反序列化后的字段值。 + """ if field_name == "message_segments" and isinstance(value, list): deserialized_segments: List[Seg] = [] for segment in value: if isinstance(segment, Seg): deserialized_segments.append(segment) elif isinstance(segment, dict) and "type" in segment: - deserialized_segments.append(Seg(type=segment.get("type", "text"), data=segment.get("data"))) + deserialized_segments.append(Seg(type=segment.get("type", "text"), data=segment.get("data", ""))) return deserialized_segments if field_name == "llm_response_tool_call" and isinstance(value, list): @@ -393,15 +321,15 @@ class MaiMessages: return value - def modify_message_segments(self, new_segments: List[Seg], suppress_warning: bool = False): - """ - 修改消息段列表 + def modify_message_segments(self, new_segments: List[Seg], suppress_warning: bool = False) -> None: + """修改消息段列表。 Warning: - 在生成了plain_text的情况下调用此方法,可能会导致plain_text内容与消息段不一致 + 在生成了 ``plain_text`` 的情况下调用此方法,可能会导致文本与消息段不一致。 Args: - new_segments (List[Seg]): 新的消息段列表 + new_segments: 新的消息段列表。 + suppress_warning: 是否抑制潜在不一致警告。 """ if self.plain_text and not suppress_warning: warnings.warn( @@ -412,15 +340,15 @@ class MaiMessages: self.message_segments = new_segments self._modify_flags.modify_message_segments = True - def modify_llm_prompt(self, new_prompt: str, suppress_warning: bool = False): - """ - 修改LLM提示词 + def modify_llm_prompt(self, new_prompt: str, suppress_warning: bool = False) -> None: + """修改 LLM 提示词。 Warning: - 在没有生成llm_prompt的情况下调用此方法,可能会导致修改无效 + 在没有生成 ``llm_prompt`` 的情况下调用此方法,可能会导致修改无效。 Args: - new_prompt (str): 新的提示词内容 + new_prompt: 新的提示词内容。 + suppress_warning: 是否抑制潜在无效修改警告。 """ if self.llm_prompt is None and not suppress_warning: warnings.warn( @@ -431,15 +359,15 @@ class MaiMessages: self.llm_prompt = new_prompt self._modify_flags.modify_llm_prompt = True - def modify_plain_text(self, new_text: str, suppress_warning: bool = False): - """ - 修改生成的plain_text内容 + def modify_plain_text(self, new_text: str, suppress_warning: bool = False) -> None: + """修改生成的纯文本内容。 Warning: - 在未生成plain_text的情况下调用此方法,可能会导致plain_text为空或者修改无效 + 在未生成 ``plain_text`` 的情况下调用此方法,可能会导致修改无效。 Args: - new_text (str): 新的纯文本内容 + new_text: 新的纯文本内容。 + suppress_warning: 是否抑制潜在无效修改警告。 """ if not self.plain_text and not suppress_warning: warnings.warn( @@ -450,15 +378,15 @@ class MaiMessages: self.plain_text = new_text self._modify_flags.modify_plain_text = True - def modify_llm_response_content(self, new_content: str, suppress_warning: bool = False): - """ - 修改生成的llm_response_content内容 + def modify_llm_response_content(self, new_content: str, suppress_warning: bool = False) -> None: + """修改生成的 LLM 响应正文。 Warning: - 在未生成llm_response_content的情况下调用此方法,可能会导致llm_response_content为空或者修改无效 + 在未生成 ``llm_response_content`` 的情况下调用此方法,可能会导致修改无效。 Args: - new_content (str): 新的LLM响应内容 + new_content: 新的 LLM 响应内容。 + suppress_warning: 是否抑制潜在无效修改警告。 """ if not self.llm_response_content and not suppress_warning: warnings.warn( @@ -469,15 +397,15 @@ class MaiMessages: self.llm_response_content = new_content self._modify_flags.modify_llm_response_content = True - def modify_llm_response_reasoning(self, new_reasoning: str, suppress_warning: bool = False): - """ - 修改生成的llm_response_reasoning内容 + def modify_llm_response_reasoning(self, new_reasoning: str, suppress_warning: bool = False) -> None: + """修改生成的 LLM 推理内容。 Warning: - 在未生成llm_response_reasoning的情况下调用此方法,可能会导致llm_response_reasoning为空或者修改无效 + 在未生成 ``llm_response_reasoning`` 的情况下调用此方法,可能会导致修改无效。 Args: - new_reasoning (str): 新的LLM响应推理内容 + new_reasoning: 新的 LLM 推理内容。 + suppress_warning: 是否抑制潜在无效修改警告。 """ if not self.llm_response_reasoning and not suppress_warning: warnings.warn( @@ -487,10 +415,3 @@ class MaiMessages: ) self.llm_response_reasoning = new_reasoning self._modify_flags.modify_llm_response_reasoning = True - - -@dataclass -class CustomEventHandlerResult: - message: str = "" - timestamp: float = 0.0 - extra_info: Optional[Dict] = None diff --git a/src/know_u/__init__.py b/src/know_u/__init__.py new file mode 100644 index 00000000..9945120b --- /dev/null +++ b/src/know_u/__init__.py @@ -0,0 +1,3 @@ +""" +Knowledge utilities package for Maisaka. +""" diff --git a/src/know_u/knowledge.py b/src/know_u/knowledge.py new file mode 100644 index 00000000..05573008 --- /dev/null +++ b/src/know_u/knowledge.py @@ -0,0 +1,365 @@ +""" +Maisaka knowledge retrieval and learning helpers. +""" + +from typing import Any, Dict, List + +import asyncio +import json + +from src.chat.message_receive.message import SessionMessage +from src.chat.utils.utils import is_bot_self +from src.common.data_models.llm_service_data_models import LLMGenerationOptions +from src.common.logger import get_logger +from src.know_u.knowledge_store import KNOWLEDGE_CATEGORIES, get_knowledge_store +from src.maisaka.context_messages import AssistantMessage, LLMContextMessage, SessionBackedMessage, ToolResultMessage +from src.maisaka.message_adapter import parse_speaker_content +from src.person_info.person_info import Person +from src.services.llm_service import LLMServiceClient + +logger = get_logger("maisaka_knowledge") + +NO_RESULT_KEYWORDS = [ + "无", + "没有", + "不适用", + "无需", + "无相关", +] + + +def extract_category_ids_from_result(result: str) -> List[str]: + """Extract valid category ids from an LLM result string.""" + if not result: + return [] + + normalized = result.strip() + if not normalized: + return [] + + lowered = normalized.lower() + if any(keyword in lowered for keyword in ["none", "no relevant", "no_need", "no need"]): + return [] + if any(keyword in normalized for keyword in NO_RESULT_KEYWORDS): + return [] + + category_ids: List[str] = [] + for part in normalized.replace(",", " ").replace(",", " ").replace("\n", " ").split(): + candidate = part.strip() + if candidate in KNOWLEDGE_CATEGORIES and candidate not in category_ids: + category_ids.append(candidate) + + return category_ids + + +async def retrieve_relevant_knowledge( + knowledge_analyzer: Any, + chat_history: List[LLMContextMessage], +) -> str: + """Retrieve formatted knowledge snippets relevant to the current chat history.""" + store = get_knowledge_store() + categories_summary = store.get_categories_summary() + + try: + category_ids = await knowledge_analyzer.analyze_knowledge_need(chat_history, categories_summary) + if not category_ids: + return "" + return store.get_formatted_knowledge(category_ids) + except Exception: + logger.exception("检索相关知识失败") + return "" + + +class KnowledgeLearner: + """ + 从最近对话中提取用户画像类知识并写入知识库。 + """ + + def __init__(self, session_id: str) -> None: + self._session_id = session_id + self._store = get_knowledge_store() + self._llm = LLMServiceClient(task_name="utils", request_type="maisaka.knowledge.learn") + self._learning_lock = asyncio.Lock() + self._messages_cache: List[SessionMessage] = [] + + def add_messages(self, messages: List[SessionMessage]) -> None: + """缓存待学习的消息。""" + self._messages_cache.extend(messages) + + def get_cache_size(self) -> int: + """获取缓存消息数量。""" + return len(self._messages_cache) + + async def learn(self) -> int: + """ + 从缓存消息中提取知识并落库。 + + Returns: + 新增入库的知识条数 + """ + if not self._messages_cache: + return 0 + + async with self._learning_lock: + chat_excerpt = self._build_chat_excerpt() + if not chat_excerpt: + return 0 + + prompt = self._build_learning_prompt(chat_excerpt) + try: + result = await self._llm.generate_response( + prompt=prompt, + options=LLMGenerationOptions( + temperature=0.1, + max_tokens=512, + ), + ) + except Exception: + logger.exception("知识学习模型调用失败") + return 0 + + knowledge_items = self._parse_learning_result(result.response or "") + if not knowledge_items: + logger.debug("知识学习已完成,但未提取到有效条目") + return 0 + + added_count = 0 + for item in knowledge_items: + category_id = str(item.get("category_id", "")).strip() + content = str(item.get("content", "")).strip() + if not category_id or not content: + continue + + metadata = { + "session_id": self._session_id, + "source": "maisaka_learning", + } + for field_name in ("platform", "user_id", "user_nickname", "person_name"): + field_value = str(item.get(field_name, "")).strip() + if field_value: + metadata[field_name] = field_value + + if self._store.add_knowledge( + category_id=category_id, + content=content, + metadata=metadata, + ): + added_count += 1 + + if added_count > 0: + logger.info( + f"Maisaka 知识学习已完成: 会话标识={self._session_id} 新增条数={added_count}" + ) + else: + logger.debug( + f"Maisaka 知识学习已完成,但没有新增条目: 会话标识={self._session_id}" + ) + + return added_count + + def _build_chat_excerpt(self) -> str: + """ + 构建适合画像提取的对话片段,只保留用户可见文本。 + """ + lines: List[str] = [] + for message in self._messages_cache[-30:]: + if isinstance(message, (AssistantMessage, ToolResultMessage)): + continue + if isinstance(message, SessionBackedMessage): + if message.original_message and is_bot_self( + message.original_message.platform, + message.original_message.message_info.user_info.user_id, + ): + continue + raw_text = message.processed_plain_text.strip() + fallback_speaker = ( + message.original_message.message_info.user_info.user_nickname + if message.original_message is not None + else "用户" + ) + else: + if is_bot_self(message.platform, message.message_info.user_info.user_id): + continue + raw_text = message.processed_plain_text.strip() + fallback_speaker = message.message_info.user_info.user_nickname or "用户" + + if not raw_text: + continue + + speaker_name, body = parse_speaker_content(raw_text) + visible_text = (body or raw_text).strip() + if not visible_text: + continue + + speaker = speaker_name or fallback_speaker + user_metadata = self._extract_message_user_metadata(message) + metadata_parts = [ + f"platform={user_metadata['platform'] or 'unknown'}", + f"user_id={user_metadata['user_id'] or 'unknown'}", + f"user_nickname={user_metadata['user_nickname'] or speaker}", + f"person_name={user_metadata['person_name'] or ''}", + ] + lines.append( + f"[用户信息] {'; '.join(metadata_parts)}\n" + f"[发言] {speaker}: {visible_text}" + ) + + return "\n".join(lines) + + @staticmethod + def _extract_message_user_metadata(message: SessionMessage) -> Dict[str, str]: + """提取消息对应的用户元信息。""" + source_message = message.original_message if isinstance(message, SessionBackedMessage) else message + platform = str(getattr(source_message, "platform", "") or "").strip() + user_info = getattr(getattr(source_message, "message_info", None), "user_info", None) + user_id = str(getattr(user_info, "user_id", "") or "").strip() + user_nickname = str(getattr(user_info, "user_nickname", "") or "").strip() + + person_name = "" + if platform and user_id: + try: + person = Person(platform=platform, user_id=user_id) + if person.is_known and person.person_name: + person_name = str(person.person_name).strip() + except Exception: + person_name = "" + + return { + "platform": platform, + "user_id": user_id, + "user_nickname": user_nickname, + "person_name": person_name, + } + + def _build_learning_prompt(self, chat_excerpt: str) -> str: + """构建知识提取提示词。""" + categories_text = "\n".join( + f"{category_id}. {category_name}" for category_id, category_name in KNOWLEDGE_CATEGORIES.items() + ) + return ( + "你是一个用户画像知识提取器,需要从聊天记录里提取稳定、可复用的用户事实。\n" + "只提取用户明确表达或高置信度可归纳的信息,不要猜测,不要提取一次性情绪,不要重复表述。\n" + "如果没有可提取内容,返回空数组 []。\n" + "输出必须是 JSON 数组,每项格式为 " + '{"category_id":"分类编号","content":"简洁中文陈述"}。\n' + "分类如下:\n" + f"{categories_text}\n\n" + "聊天记录:\n" + f"{chat_excerpt}" + ) + + def _parse_learning_result(self, result: str) -> List[Dict[str, str]]: + """解析模型返回的知识条目。""" + normalized = result.strip() + if not normalized: + return [] + + if "```" in normalized: + normalized = normalized.replace("```json", "").replace("```JSON", "").replace("```", "").strip() + + try: + parsed = json.loads(normalized) + except json.JSONDecodeError: + logger.warning("知识学习结果不是有效的 JSON") + return [] + + if not isinstance(parsed, list): + return [] + + normalized_items: List[Dict[str, str]] = [] + seen_pairs: set[tuple[str, str]] = set() + for item in parsed: + if not isinstance(item, dict): + continue + + category_id = str(item.get("category_id", "")).strip() + content = " ".join(str(item.get("content", "")).strip().split()) + if category_id not in KNOWLEDGE_CATEGORIES: + continue + if not content: + continue + + pair = (category_id, content) + if pair in seen_pairs: + continue + seen_pairs.add(pair) + normalized_items.append( + { + "category_id": category_id, + "content": content, + } + ) + + return normalized_items + + def _build_learning_prompt(self, chat_excerpt: str) -> str: + """构建知识提取提示词。""" + categories_text = "\n".join( + f"{category_id}. {category_name}" for category_id, category_name in KNOWLEDGE_CATEGORIES.items() + ) + return ( + "你是一个用户画像知识提取器,需要从聊天记录里提取稳定、可复用的用户事实。\n" + "聊天记录每条发言前都带有用户元信息,你必须明确判断这些特征属于哪个用户。\n" + "只提取用户明确表达或高置信度可归纳的信息,不要猜测,不要提取一次性情绪,不要重复表达。\n" + "如果没有可提取内容,返回空数组[]。\n" + "输出必须是 JSON 数组,每项格式为 " + '{"category_id":"分类编号","content":"简洁中文陈述","platform":"平台","user_id":"用户ID","user_nickname":"用户昵称","person_name":"人物名或空字符串"}。\n' + "其中 platform 和 user_id 必填;user_nickname 尽量填写;person_name 仅在用户信息中明确给出时填写,否则填空字符串。\n" + "同一条知识只能归属到一个用户,不要混合不同人的信息。\n" + "分类如下:\n" + f"{categories_text}\n\n" + "聊天记录:\n" + f"{chat_excerpt}" + ) + + def _parse_learning_result(self, result: str) -> List[Dict[str, str]]: + """解析模型返回的知识条目。""" + normalized = result.strip() + if not normalized: + return [] + + if "```" in normalized: + normalized = normalized.replace("```json", "").replace("```JSON", "").replace("```", "").strip() + + try: + parsed = json.loads(normalized) + except json.JSONDecodeError: + logger.warning("知识学习结果不是有效的 JSON") + return [] + + if not isinstance(parsed, list): + return [] + + normalized_items: List[Dict[str, str]] = [] + seen_pairs: set[tuple[str, str, str, str]] = set() + for item in parsed: + if not isinstance(item, dict): + continue + + category_id = str(item.get("category_id", "")).strip() + content = " ".join(str(item.get("content", "")).strip().split()) + platform = str(item.get("platform", "")).strip() + user_id = str(item.get("user_id", "")).strip() + user_nickname = str(item.get("user_nickname", "")).strip() + person_name = str(item.get("person_name", "")).strip() + if category_id not in KNOWLEDGE_CATEGORIES: + continue + if not content or not platform or not user_id: + continue + + pair = (category_id, content, platform, user_id) + if pair in seen_pairs: + continue + seen_pairs.add(pair) + normalized_items.append( + { + "category_id": category_id, + "content": content, + "platform": platform, + "user_id": user_id, + "user_nickname": user_nickname, + "person_name": person_name, + } + ) + + return normalized_items diff --git a/src/know_u/knowledge_store.py b/src/know_u/knowledge_store.py new file mode 100644 index 00000000..4ca56814 --- /dev/null +++ b/src/know_u/knowledge_store.py @@ -0,0 +1,370 @@ +""" +MaiSaka knowledge store. +""" + +from datetime import datetime +from pathlib import Path +from typing import Any, Dict, List, Optional + +import json + +from sqlmodel import col, select + +from src.common.database.database import DATABASE_URL, get_db_session +from src.common.database.database_model import MaiKnowledge + +PROJECT_ROOT = Path(__file__).resolve().parents[2] +KNOWLEDGE_DATA_DIR = PROJECT_ROOT / "mai_knowledge" +KNOWLEDGE_FILE = KNOWLEDGE_DATA_DIR / "knowledge.json" + + +KNOWLEDGE_CATEGORIES = { + "1": "性别", + "2": "性格", + "3": "饮食口味", + "4": "交友偏好", + "5": "情绪/理性倾向", + "6": "兴趣爱好", + "7": "职业/专业", + "8": "生活习惯", + "9": "价值观", + "10": "沟通风格", + "11": "学习方式", + "12": "压力应对方式", +} + + +class KnowledgeStore: + """存储 Maisaka 的用户画像知识。""" + + def __init__(self) -> None: + """初始化知识存储,并在需要时迁移旧版 JSON 数据。""" + self._ensure_legacy_data_dir() + self._migrate_legacy_file_if_needed() + + def _ensure_legacy_data_dir(self) -> None: + """确保旧版知识目录存在,便于兼容历史数据。""" + KNOWLEDGE_DATA_DIR.mkdir(parents=True, exist_ok=True) + + @staticmethod + def _normalize_content(content: str) -> str: + """标准化知识内容,便于去重。""" + return " ".join(str(content).strip().split()) + + @staticmethod + def _serialize_metadata(metadata: Optional[Dict[str, Any]]) -> Optional[str]: + """将元数据序列化为 JSON 文本。""" + if not metadata: + return None + return json.dumps(metadata, ensure_ascii=False, sort_keys=True) + + @staticmethod + def _deserialize_metadata(raw_text: Optional[str]) -> Dict[str, Any]: + """将 JSON 文本反序列化为元数据字典。""" + if not raw_text: + return {} + try: + parsed = json.loads(raw_text) + except json.JSONDecodeError: + return {} + return parsed if isinstance(parsed, dict) else {} + + @staticmethod + def _parse_created_at(raw_value: Any) -> datetime: + """解析旧版数据中的创建时间。""" + if isinstance(raw_value, datetime): + return raw_value + if isinstance(raw_value, str): + raw_text = raw_value.strip() + if raw_text: + try: + return datetime.fromisoformat(raw_text) + except ValueError: + pass + return datetime.now() + + @classmethod + def _build_item_dict(cls, record: MaiKnowledge) -> Dict[str, Any]: + """将数据库记录转换为兼容旧接口的字典。""" + return { + "id": record.knowledge_id, + "content": record.content, + "metadata": cls._deserialize_metadata(record.metadata_json), + "created_at": record.created_at.isoformat(), + } + + def _load_legacy_knowledge_file(self) -> Dict[str, List[Dict[str, Any]]]: + """读取旧版 JSON 知识文件。""" + if not KNOWLEDGE_FILE.exists(): + return {} + + try: + with open(KNOWLEDGE_FILE, "r", encoding="utf-8") as file: + loaded = json.load(file) + except Exception: + return {} + + if not isinstance(loaded, dict): + return {} + + normalized_knowledge: Dict[str, List[Dict[str, Any]]] = {} + for category_id in KNOWLEDGE_CATEGORIES: + category_items = loaded.get(category_id, []) + if isinstance(category_items, list): + normalized_knowledge[category_id] = [ + item for item in category_items if isinstance(item, dict) + ] + return normalized_knowledge + + def _migrate_legacy_file_if_needed(self) -> None: + """在数据库为空时,将旧版 JSON 中的知识导入数据库。""" + legacy_knowledge = self._load_legacy_knowledge_file() + if not legacy_knowledge: + return + + with get_db_session(auto_commit=False) as session: + existing_record = session.exec(select(MaiKnowledge.id).limit(1)).first() + if existing_record is not None: + return + + for category_id, items in legacy_knowledge.items(): + if category_id not in KNOWLEDGE_CATEGORIES: + continue + + for item in items: + content = self._normalize_content(str(item.get("content", ""))) + if not content: + continue + + metadata = item.get("metadata") + session.add( + MaiKnowledge( + knowledge_id=str(item.get("id") or f"know_{category_id}_{datetime.now().timestamp()}"), + category_id=category_id, + content=content, + normalized_content=content, + metadata_json=self._serialize_metadata(metadata if isinstance(metadata, dict) else None), + created_at=self._parse_created_at(item.get("created_at")), + ) + ) + + session.commit() + + def add_knowledge( + self, + category_id: str, + content: str, + metadata: Optional[Dict[str, Any]] = None, + ) -> bool: + """添加一条知识信息。""" + if category_id not in KNOWLEDGE_CATEGORIES: + return False + + normalized_content = self._normalize_content(content) + if not normalized_content: + return False + + user_platform = str((metadata or {}).get("platform", "")).strip() + user_id = str((metadata or {}).get("user_id", "")).strip() + with get_db_session(auto_commit=False) as session: + existing_records = session.exec( + select(MaiKnowledge).where( + MaiKnowledge.category_id == category_id, + MaiKnowledge.normalized_content == normalized_content, + ) + ).all() + for existing_record in existing_records: + existing_metadata = self._deserialize_metadata(existing_record.metadata_json) + existing_platform = str(existing_metadata.get("platform", "")).strip() + existing_user_id = str(existing_metadata.get("user_id", "")).strip() + if user_platform and user_id: + if existing_platform == user_platform and existing_user_id == user_id: + return False + continue + if not existing_platform and not existing_user_id: + return False + + session.add( + MaiKnowledge( + knowledge_id=f"know_{category_id}_{datetime.now().timestamp()}", + category_id=category_id, + content=normalized_content, + normalized_content=normalized_content, + metadata_json=self._serialize_metadata(metadata), + created_at=datetime.now(), + ) + ) + session.commit() + return True + + def search_knowledge( + self, + keyword: str, + limit: int = 10, + ) -> List[Dict[str, Any]]: + """按关键词搜索知识内容。""" + normalized_keyword = self._normalize_content(keyword) + if not normalized_keyword: + return [] + + limit_value = max(1, int(limit)) + with get_db_session() as session: + records = session.exec( + select(MaiKnowledge) + .where( + col(MaiKnowledge.content).contains(normalized_keyword) + | col(MaiKnowledge.normalized_content).contains(normalized_keyword) + ) + .order_by(MaiKnowledge.created_at.desc(), MaiKnowledge.id.desc()) + .limit(limit_value) + ).all() + + results: List[Dict[str, Any]] = [] + for record in records: + item = self._build_item_dict(record) + item["category_id"] = record.category_id + item["category_name"] = self.get_category_name(record.category_id) + results.append(item) + return results + + def get_knowledge_by_user( + self, + *, + platform: str = "", + user_id: str = "", + user_nickname: str = "", + person_name: str = "", + limit: int = 10, + ) -> List[Dict[str, Any]]: + """按用户元信息筛选知识条目。""" + platform = str(platform).strip() + user_id = str(user_id).strip() + user_nickname = str(user_nickname).strip() + person_name = str(person_name).strip() + if not any((platform, user_id, user_nickname, person_name)): + return [] + + limit_value = max(1, int(limit)) + with get_db_session() as session: + records = session.exec( + select(MaiKnowledge).order_by(MaiKnowledge.created_at.desc(), MaiKnowledge.id.desc()) + ).all() + + results: List[Dict[str, Any]] = [] + for record in records: + metadata = self._deserialize_metadata(record.metadata_json) + if user_id and str(metadata.get("user_id", "")).strip() != user_id: + continue + if platform and str(metadata.get("platform", "")).strip() != platform: + continue + if user_nickname and str(metadata.get("user_nickname", "")).strip() != user_nickname: + continue + if person_name and str(metadata.get("person_name", "")).strip() != person_name: + continue + + item = self._build_item_dict(record) + item["category_id"] = record.category_id + item["category_name"] = self.get_category_name(record.category_id) + results.append(item) + if len(results) >= limit_value: + break + + return results + + def get_category_knowledge(self, category_id: str) -> List[Dict[str, Any]]: + """获取某个分类下的所有知识。""" + if category_id not in KNOWLEDGE_CATEGORIES: + return [] + + with get_db_session() as session: + records = session.exec( + select(MaiKnowledge) + .where(MaiKnowledge.category_id == category_id) + .order_by(MaiKnowledge.created_at.asc(), MaiKnowledge.id.asc()) + ).all() + return [self._build_item_dict(record) for record in records] + + def get_all_knowledge(self) -> Dict[str, List[Dict[str, Any]]]: + """获取全部知识。""" + all_knowledge: Dict[str, List[Dict[str, Any]]] = { + category_id: [] for category_id in KNOWLEDGE_CATEGORIES + } + with get_db_session() as session: + records = session.exec( + select(MaiKnowledge).order_by( + MaiKnowledge.category_id.asc(), + MaiKnowledge.created_at.asc(), + MaiKnowledge.id.asc(), + ) + ).all() + + for record in records: + all_knowledge.setdefault(record.category_id, []).append(self._build_item_dict(record)) + return all_knowledge + + def get_category_name(self, category_id: str) -> str: + """获取分类名称。""" + return KNOWLEDGE_CATEGORIES.get(category_id, "未知分类") + + def get_categories_summary(self) -> str: + """获取分类摘要,供模型判断是否需要检索。""" + counts: Dict[str, int] = {category_id: 0 for category_id in KNOWLEDGE_CATEGORIES} + with get_db_session() as session: + records = session.exec(select(MaiKnowledge.category_id)).all() + + for category_id in records: + if category_id in counts: + counts[category_id] += 1 + + lines: List[str] = [] + for category_id, category_name in KNOWLEDGE_CATEGORIES.items(): + count = counts.get(category_id, 0) + count_text = f"{count}条" if count > 0 else "无数据" + lines.append(f"{category_id}. {category_name} ({count_text})") + return "\n".join(lines) + + def get_formatted_knowledge(self, category_ids: List[str], limit_per_category: int = 5) -> str: + """获取指定分类的格式化知识内容。""" + parts: List[str] = [] + for category_id in category_ids: + items = self.get_category_knowledge(category_id) + if not items: + continue + + category_name = self.get_category_name(category_id) + parts.append(f"【{category_name}】") + + recent_items = items[-limit_per_category:] + for item in recent_items: + content = str(item.get("content", "")).strip() + if content: + parts.append(f"- {content}") + + return "\n".join(parts) + + def get_stats(self) -> Dict[str, Any]: + """获取知识数据统计。""" + with get_db_session() as session: + total_items = len(session.exec(select(MaiKnowledge.id)).all()) + + return { + "total_categories": len(KNOWLEDGE_CATEGORIES), + "total_items": total_items, + "data_file": DATABASE_URL, + "data_exists": True, + "data_size_kb": 0, + "legacy_data_file": str(KNOWLEDGE_FILE), + "legacy_data_exists": KNOWLEDGE_FILE.exists(), + "storage_type": "database", + } + + +_knowledge_store_instance: Optional[KnowledgeStore] = None + + +def get_knowledge_store() -> KnowledgeStore: + """获取知识存储单例。""" + global _knowledge_store_instance + if _knowledge_store_instance is None: + _knowledge_store_instance = KnowledgeStore() + return _knowledge_store_instance diff --git a/src/learners/expression_auto_check_task.py b/src/learners/expression_auto_check_task.py index e5af1057..44141118 100644 --- a/src/learners/expression_auto_check_task.py +++ b/src/learners/expression_auto_check_task.py @@ -20,8 +20,8 @@ from src.common.database.database import get_db_session from src.common.database.database_model import Expression from src.common.logger import get_logger from src.config.config import global_config -from src.config.config import model_config -from src.llm_models.utils_model import LLMRequest +from src.common.data_models.llm_service_data_models import LLMGenerationOptions +from src.services.llm_service import LLMServiceClient from src.manager.async_task_manager import AsyncTask logger = get_logger("expression_auto_check_task") @@ -76,7 +76,7 @@ def create_evaluation_prompt(situation: str, style: str) -> str: return prompt -judge_llm = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="expression_check") +judge_llm = LLMServiceClient(task_name="utils", request_type="expression_check") async def single_expression_check(situation: str, style: str) -> tuple[bool, str, str | None]: @@ -94,10 +94,11 @@ async def single_expression_check(situation: str, style: str) -> tuple[bool, str prompt = create_evaluation_prompt(situation, style) logger.debug(f"正在评估表达方式: situation={situation}, style={style}") - response, (reasoning, model_name, _) = await judge_llm.generate_response_async( - prompt=prompt, temperature=0.6, max_tokens=1024 + generation_result = await judge_llm.generate_response( + prompt=prompt, + options=LLMGenerationOptions(temperature=0.6, max_tokens=1024), ) - + response = generation_result.response logger.debug(f"LLM响应: {response}") # 解析JSON响应 diff --git a/src/learners/expression_learner.py b/src/learners/expression_learner.py index b82ae1fa..579fb5ea 100644 --- a/src/learners/expression_learner.py +++ b/src/learners/expression_learner.py @@ -7,8 +7,9 @@ import difflib import json import re -from src.llm_models.utils_model import LLMRequest -from src.config.config import model_config, global_config +from src.common.data_models.llm_service_data_models import LLMGenerationOptions +from src.services.llm_service import LLMServiceClient +from src.config.config import global_config from src.prompt.prompt_manager import prompt_manager from src.common.logger import get_logger from src.common.database.database_model import Expression @@ -26,10 +27,11 @@ if TYPE_CHECKING: logger = get_logger("expressor") -# TODO: 重构完LLM相关内容后,替换成新的模型调用方式 -express_learn_model = LLMRequest(model_set=model_config.model_task_config.utils, request_type="expression.learner") -summary_model = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="expression.summary") -check_model = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="expression.check") +express_learn_model = LLMServiceClient( + task_name="utils", request_type="expression.learner" +) +summary_model = LLMServiceClient(task_name="utils", request_type="expression.summary") +check_model = LLMServiceClient(task_name="utils", request_type="expression.check") class ExpressionLearner: @@ -74,7 +76,10 @@ class ExpressionLearner: # 调用 LLM 学习表达方式 try: - response, _ = await express_learn_model.generate_response_async(prompt, temperature=0.3) + generation_result = await express_learn_model.generate_response( + prompt, options=LLMGenerationOptions(temperature=0.3) + ) + response = generation_result.response except Exception as e: logger.error(f"学习表达方式失败,模型生成出错:{e}") return @@ -413,7 +418,10 @@ class ExpressionLearner: "只输出概括内容。" ) try: - summary, _ = await summary_model.generate_response_async(prompt, temperature=0.2) + summary_result = await summary_model.generate_response( + prompt, options=LLMGenerationOptions(temperature=0.2) + ) + summary = summary_result.response if summary := summary.strip(): return summary except Exception as e: diff --git a/src/learners/expression_selector.py b/src/learners/expression_selector.py index c96e84cf..30e2f154 100644 --- a/src/learners/expression_selector.py +++ b/src/learners/expression_selector.py @@ -4,10 +4,11 @@ import time from typing import List, Dict, Optional, Any, Tuple from json_repair import repair_json -from src.llm_models.utils_model import LLMRequest -from src.config.config import global_config, model_config +from src.services.llm_service import LLMServiceClient +from src.config.config import global_config from src.common.logger import get_logger from src.common.database.database_model import Expression +from src.common.utils.utils_session import SessionUtils from src.prompt.prompt_manager import prompt_manager from src.learners.learner_utils_old import weighted_sample from src.chat.utils.common_utils import TempMethodsExpression @@ -17,8 +18,8 @@ logger = get_logger("expression_selector") class ExpressionSelector: def __init__(self): - self.llm_model = LLMRequest( - model_set=model_config.model_task_config.tool_use, request_type="expression.selector" + self.llm_model = LLMServiceClient( + task_name="utils", request_type="expression.selector" ) def can_use_expression_for_chat(self, chat_id: str) -> bool: @@ -383,8 +384,8 @@ class ExpressionSelector: prompt = await prompt_manager.render_prompt(prompt_template) # 4. 调用LLM - content, (reasoning_content, model_name, _) = await self.llm_model.generate_response_async(prompt=prompt) - + generation_result = await self.llm_model.generate_response(prompt=prompt) + content = generation_result.response # print(prompt) # print(content) diff --git a/src/learners/expression_utils.py b/src/learners/expression_utils.py index 88237e57..23c41c39 100644 --- a/src/learners/expression_utils.py +++ b/src/learners/expression_utils.py @@ -1,19 +1,40 @@ from json_repair import repair_json -from typing import Tuple, Optional, List +from typing import Any, List, Optional, Tuple import json import re -from src.config.config import model_config from src.config.config import global_config -from src.llm_models.utils_model import LLMRequest +from src.common.data_models.llm_service_data_models import LLMGenerationOptions +from src.services.llm_service import LLMServiceClient from src.prompt.prompt_manager import prompt_manager from src.common.logger import get_logger logger = get_logger("expression_utils") -# TODO: 重构完LLM相关内容后,替换成新的模型调用方式 -judge_llm = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="expression_check") +judge_llm = LLMServiceClient(task_name="utils", request_type="expression_check") + + +def _normalize_repair_json_result(repaired_result: Any) -> str: + """将 repair_json 的返回值规范化为 JSON 字符串。 + + Args: + repaired_result: `repair_json` 的返回值,可能是字符串或带附加信息的元组。 + + Returns: + str: 可供 `json.loads` 继续解析的 JSON 字符串。 + + Raises: + TypeError: 当返回值无法规范化为字符串时抛出。 + """ + if isinstance(repaired_result, str): + return repaired_result + if isinstance(repaired_result, tuple) and repaired_result: + first_item = repaired_result[0] + if isinstance(first_item, str): + return first_item + return json.dumps(first_item, ensure_ascii=False) + raise TypeError(f"repair_json 返回了无法处理的结果类型: {type(repaired_result)}") async def check_expression_suitability(situation: str, style: str) -> Tuple[bool, str, Optional[str]]: @@ -51,7 +72,11 @@ async def check_expression_suitability(situation: str, style: str) -> Tuple[bool logger.info(f"正在评估表达方式: situation={situation}, style={style}") - response, _ = await judge_llm.generate_response_async(prompt=prompt, temperature=0.6, max_tokens=1024) + generation_result = await judge_llm.generate_response( + prompt=prompt, + options=LLMGenerationOptions(temperature=0.6, max_tokens=1024), + ) + response = generation_result.response logger.debug(f"评估结果: {response}") @@ -59,7 +84,7 @@ async def check_expression_suitability(situation: str, style: str) -> Tuple[bool evaluation = json.loads(response) except json.JSONDecodeError: try: - response_repaired = repair_json(response) + response_repaired = _normalize_repair_json_result(repair_json(response)) evaluation = json.loads(response_repaired) except Exception as e: raise ValueError(f"无法解析LLM响应为JSON: {response}") from e @@ -74,7 +99,7 @@ async def check_expression_suitability(situation: str, style: str) -> Tuple[bool return False, f"评估结果格式错误: {e}", str(e) -def fix_chinese_quotes_in_json(text): +def fix_chinese_quotes_in_json(text: str) -> str: """使用状态机修复 JSON 字符串值中的中文引号""" result = [] i = 0 @@ -201,12 +226,12 @@ def is_single_char_jargon(content: str) -> bool: ) -def _try_parse(text): +def _try_parse(text: str) -> Any: try: return json.loads(text) except Exception: try: - repaired = repair_json(text) + repaired = _normalize_repair_json_result(repair_json(text)) return json.loads(repaired) except Exception: return None diff --git a/src/learners/jargon_explainer_old.py b/src/learners/jargon_explainer_old.py index 0cfafa82..876b4539 100644 --- a/src/learners/jargon_explainer_old.py +++ b/src/learners/jargon_explainer_old.py @@ -4,8 +4,9 @@ from typing import List, Dict, Optional, Any from src.common.logger import get_logger from src.common.database.database_model import Jargon -from src.llm_models.utils_model import LLMRequest -from src.config.config import model_config, global_config +from src.common.data_models.llm_service_data_models import LLMGenerationOptions +from src.services.llm_service import LLMServiceClient +from src.config.config import global_config from src.prompt.prompt_manager import prompt_manager from src.learners.jargon_miner_old import search_jargon from src.learners.learner_utils_old import ( @@ -23,8 +24,8 @@ class JargonExplainer: def __init__(self, chat_id: str) -> None: self.chat_id = chat_id - self.llm = LLMRequest( - model_set=model_config.model_task_config.tool_use, + self.llm = LLMServiceClient( + task_name="utils", request_type="jargon.explain", ) @@ -206,7 +207,10 @@ class JargonExplainer: prompt_of_summarize.add_context("jargon_explanations", lambda _: explanations_text) summarize_prompt = await prompt_manager.render_prompt(prompt_of_summarize) - summary, _ = await self.llm.generate_response_async(summarize_prompt, temperature=0.3) + summary_result = await self.llm.generate_response( + summarize_prompt, options=LLMGenerationOptions(temperature=0.3) + ) + summary = summary_result.response if not summary: # 如果LLM概括失败,直接返回原始解释 return f"上下文中的黑话解释:\n{explanations_text}" diff --git a/src/learners/jargon_miner.py b/src/learners/jargon_miner.py index 32926894..b0b4ae33 100644 --- a/src/learners/jargon_miner.py +++ b/src/learners/jargon_miner.py @@ -12,17 +12,17 @@ from src.common.data_models.jargon_data_model import MaiJargon from src.common.database.database import get_db_session from src.common.database.database_model import Jargon from src.common.logger import get_logger -from src.config.config import global_config, model_config -from src.llm_models.utils_model import LLMRequest +from src.config.config import global_config +from src.common.data_models.llm_service_data_models import LLMGenerationOptions +from src.services.llm_service import LLMServiceClient from src.prompt.prompt_manager import prompt_manager from .expression_utils import is_single_char_jargon logger = get_logger("jargon") -# TODO: 重构完LLM相关内容后,替换成新的模型调用方式 -llm_extract = LLMRequest(model_set=model_config.model_task_config.utils, request_type="jargon.extract") -llm_inference = LLMRequest(model_set=model_config.model_task_config.utils, request_type="jargon.inference") +llm_extract = LLMServiceClient(task_name="utils", request_type="jargon.extract") +llm_inference = LLMServiceClient(task_name="utils", request_type="jargon.inference") class JargonEntry(TypedDict): @@ -100,7 +100,10 @@ class JargonMiner: prompt1_template.add_context("previous_meaning_instruction", previous_meaning_instruction) prompt1 = await prompt_manager.render_prompt(prompt1_template) - llm_response_1, _ = await llm_inference.generate_response_async(prompt1, temperature=0.3) + generation_result_1 = await llm_inference.generate_response( + prompt1, options=LLMGenerationOptions(temperature=0.3) + ) + llm_response_1 = generation_result_1.response if not llm_response_1: logger.warning(f"jargon {content} 推断1失败:无响应") return @@ -129,7 +132,10 @@ class JargonMiner: prompt2_template.add_context("content", content) prompt2 = await prompt_manager.render_prompt(prompt2_template) - llm_response_2, _ = await llm_inference.generate_response_async(prompt2, temperature=0.3) + generation_result_2 = await llm_inference.generate_response( + prompt2, options=LLMGenerationOptions(temperature=0.3) + ) + llm_response_2 = generation_result_2.response if not llm_response_2: logger.warning(f"jargon {content} 推断2失败:无响应") return @@ -153,7 +159,10 @@ class JargonMiner: if global_config.debug.show_jargon_prompt: logger.info(f"jargon {content} 比较提示词: {prompt3}") - llm_response_3, _ = await llm_inference.generate_response_async(prompt3, temperature=0.3) + generation_result_3 = await llm_inference.generate_response( + prompt3, options=LLMGenerationOptions(temperature=0.3) + ) + llm_response_3 = generation_result_3.response if not llm_response_3: logger.warning(f"jargon {content} 比较失败:无响应") return diff --git a/src/llm_models/model_client/adapter_base.py b/src/llm_models/model_client/adapter_base.py new file mode 100644 index 00000000..660a286d --- /dev/null +++ b/src/llm_models/model_client/adapter_base.py @@ -0,0 +1,265 @@ +from abc import ABC, abstractmethod +from typing import Any, Callable, Coroutine, Generic, Tuple, TypeVar, cast + +import asyncio + +from src.common.logger import get_logger +from src.config.model_configs import ModelInfo + +from .base_client import ( + APIResponse, + AudioTranscriptionRequest, + BaseClient, + EmbeddingRequest, + ResponseRequest, + UsageRecord, + UsageTuple, +) + +RawStreamT = TypeVar("RawStreamT") +"""流式原始响应类型变量。""" + +RawResponseT = TypeVar("RawResponseT") +"""非流式原始响应类型变量。""" + +TaskResultT = TypeVar("TaskResultT") +"""异步任务返回值类型变量。""" + +ProviderStreamResponseHandler = Callable[ + [RawStreamT, asyncio.Event | None], + Coroutine[Any, Any, Tuple[APIResponse, UsageTuple | None]], +] +"""Provider 专用流式响应处理函数类型。""" + +ProviderResponseParser = Callable[[RawResponseT], Tuple[APIResponse, UsageTuple | None]] +"""Provider 专用非流式响应解析函数类型。""" + +logger = get_logger("llm_adapter_base") + + +async def await_task_with_interrupt( + task: asyncio.Task[TaskResultT], + interrupt_flag: asyncio.Event | None, + *, + interval_seconds: float = 0.02, +) -> TaskResultT: + """在支持外部中断的前提下等待异步任务完成。 + + Args: + task: 待等待的异步任务。 + interrupt_flag: 外部中断标记。 + interval_seconds: 轮询检查间隔,单位秒。 + + Returns: + TaskResultT: 任务执行结果。 + + Raises: + ReqAbortException: 等待期间收到外部中断信号时抛出。 + """ + from src.llm_models.exceptions import ReqAbortException + + started_at = asyncio.get_running_loop().time() + while not task.done(): + if interrupt_flag and interrupt_flag.is_set(): + elapsed = asyncio.get_running_loop().time() - started_at + logger.info(f"LLM 请求检测到中断信号,准备取消底层任务,elapsed={elapsed:.3f}s") + task.cancel() + raise ReqAbortException("请求被外部信号中断") + await asyncio.sleep(interval_seconds) + return await task + + +class AdapterClient(BaseClient, ABC, Generic[RawStreamT, RawResponseT]): + """提供统一请求执行骨架的 Provider 适配基类。""" + + async def get_response(self, request: ResponseRequest) -> APIResponse: + """获取对话响应。 + + Args: + request: 统一响应请求对象。 + + Returns: + APIResponse: 解析完成的统一响应对象。 + """ + stream_response_handler = self._resolve_stream_response_handler(request) + response_parser = self._resolve_response_parser(request) + response, usage_record = await self._execute_response_request( + request, + stream_response_handler, + response_parser, + ) + return self._attach_usage_record(response, request.model_info, usage_record) + + async def get_embedding(self, request: EmbeddingRequest) -> APIResponse: + """获取文本嵌入。 + + Args: + request: 统一嵌入请求对象。 + + Returns: + APIResponse: 解析完成的统一嵌入响应。 + """ + response, usage_record = await self._execute_embedding_request(request) + return self._attach_usage_record(response, request.model_info, usage_record) + + async def get_audio_transcriptions(self, request: AudioTranscriptionRequest) -> APIResponse: + """获取音频转录。 + + Args: + request: 统一音频转录请求对象。 + + Returns: + APIResponse: 解析完成的统一音频转录响应。 + """ + response, usage_record = await self._execute_audio_transcription_request(request) + return self._attach_usage_record(response, request.model_info, usage_record) + + def _resolve_stream_response_handler( + self, + request: ResponseRequest, + ) -> ProviderStreamResponseHandler[RawStreamT]: + """解析实际使用的流式响应处理器。 + + Args: + request: 统一响应请求对象。 + + Returns: + ProviderStreamResponseHandler[RawStreamT]: 流式响应处理器。 + """ + if request.stream_response_handler is not None: + return cast(ProviderStreamResponseHandler[RawStreamT], request.stream_response_handler) + return self._build_default_stream_response_handler(request) + + def _resolve_response_parser( + self, + request: ResponseRequest, + ) -> ProviderResponseParser[RawResponseT]: + """解析实际使用的非流式响应解析器。 + + Args: + request: 统一响应请求对象。 + + Returns: + ProviderResponseParser[RawResponseT]: 非流式响应解析器。 + """ + if request.async_response_parser is not None: + return cast(ProviderResponseParser[RawResponseT], request.async_response_parser) + return self._build_default_response_parser(request) + + @staticmethod + def _build_usage_record(model_info: ModelInfo, usage_record: UsageTuple) -> UsageRecord: + """根据统一使用量三元组构建 `UsageRecord`。 + + Args: + model_info: 模型信息。 + usage_record: 使用量三元组。 + + Returns: + UsageRecord: 可直接挂载到 `APIResponse` 的使用记录对象。 + """ + return UsageRecord( + model_name=model_info.name, + provider_name=model_info.api_provider, + prompt_tokens=usage_record[0], + completion_tokens=usage_record[1], + total_tokens=usage_record[2], + ) + + def _attach_usage_record( + self, + response: APIResponse, + model_info: ModelInfo, + usage_record: UsageTuple | None, + ) -> APIResponse: + """在响应对象上附加统一使用量信息。 + + Args: + response: 已解析的统一响应对象。 + model_info: 模型信息。 + usage_record: 可选的使用量三元组。 + + Returns: + APIResponse: 附加使用量后的响应对象。 + """ + if usage_record is not None: + response.usage = self._build_usage_record(model_info, usage_record) + return response + + @abstractmethod + def _build_default_stream_response_handler( + self, + request: ResponseRequest, + ) -> ProviderStreamResponseHandler[RawStreamT]: + """构建默认流式响应处理器。 + + Args: + request: 统一响应请求对象。 + + Returns: + ProviderStreamResponseHandler[RawStreamT]: 默认流式处理器。 + """ + raise NotImplementedError + + @abstractmethod + def _build_default_response_parser( + self, + request: ResponseRequest, + ) -> ProviderResponseParser[RawResponseT]: + """构建默认非流式响应解析器。 + + Args: + request: 统一响应请求对象。 + + Returns: + ProviderResponseParser[RawResponseT]: 默认非流式解析器。 + """ + raise NotImplementedError + + @abstractmethod + async def _execute_response_request( + self, + request: ResponseRequest, + stream_response_handler: ProviderStreamResponseHandler[RawStreamT], + response_parser: ProviderResponseParser[RawResponseT], + ) -> Tuple[APIResponse, UsageTuple | None]: + """执行 Provider 的文本/多模态响应请求。 + + Args: + request: 统一响应请求对象。 + stream_response_handler: 流式响应处理器。 + response_parser: 非流式响应解析器。 + + Returns: + Tuple[APIResponse, UsageTuple | None]: 统一响应对象与可选使用量信息。 + """ + raise NotImplementedError + + @abstractmethod + async def _execute_embedding_request( + self, + request: EmbeddingRequest, + ) -> Tuple[APIResponse, UsageTuple | None]: + """执行 Provider 的嵌入请求。 + + Args: + request: 统一嵌入请求对象。 + + Returns: + Tuple[APIResponse, UsageTuple | None]: 统一响应对象与可选使用量信息。 + """ + raise NotImplementedError + + @abstractmethod + async def _execute_audio_transcription_request( + self, + request: AudioTranscriptionRequest, + ) -> Tuple[APIResponse, UsageTuple | None]: + """执行 Provider 的音频转录请求。 + + Args: + request: 统一音频转录请求对象。 + + Returns: + Tuple[APIResponse, UsageTuple | None]: 统一响应对象与可选使用量信息。 + """ + raise NotImplementedError diff --git a/src/llm_models/model_client/base_client.py b/src/llm_models/model_client/base_client.py index 226c725f..fc03ac02 100644 --- a/src/llm_models/model_client/base_client.py +++ b/src/llm_models/model_client/base_client.py @@ -1,14 +1,15 @@ -import asyncio -from dataclasses import dataclass from abc import ABC, abstractmethod -from typing import Callable, Any, Optional +from dataclasses import dataclass, field +from typing import Any, Callable, Coroutine, Dict, List, Tuple, Type + +import asyncio from src.common.logger import get_logger from src.config.config import config_manager -from src.config.model_configs import ModelInfo, APIProvider -from ..payload_content.message import Message -from ..payload_content.resp_format import RespFormat -from ..payload_content.tool_option import ToolOption, ToolCall +from src.config.model_configs import APIProvider, ModelInfo +from src.llm_models.payload_content.message import Message +from src.llm_models.payload_content.resp_format import RespFormat +from src.llm_models.payload_content.tool_option import ToolCall, ToolOption logger = get_logger("model_client_registry") @@ -47,10 +48,10 @@ class APIResponse: reasoning_content: str | None = None """推理内容""" - tool_calls: list[ToolCall] | None = None + tool_calls: List[ToolCall] | None = None """工具调用 [(工具名称, 工具参数), ...]""" - embedding: list[float] | None = None + embedding: List[float] | None = None """嵌入向量""" usage: UsageRecord | None = None @@ -60,6 +61,82 @@ class APIResponse: """响应原始数据""" +UsageTuple = Tuple[int, int, int] +"""统一的使用量三元组类型,顺序为 `(prompt_tokens, completion_tokens, total_tokens)`。""" + +StreamResponseHandler = Callable[ + [Any, asyncio.Event | None], + Coroutine[Any, Any, Tuple["APIResponse", UsageTuple | None]], +] +"""统一的流式响应处理函数类型。""" + +ResponseParser = Callable[[Any], Tuple["APIResponse", UsageTuple | None]] +"""统一的非流式响应解析函数类型。""" + + +@dataclass(slots=True) +class ResponseRequest: + """统一的文本/多模态响应请求。""" + + model_info: ModelInfo + message_list: List[Message] + tool_options: List[ToolOption] | None = None + max_tokens: int | None = None + temperature: float | None = None + response_format: RespFormat | None = None + stream_response_handler: StreamResponseHandler | None = None + async_response_parser: ResponseParser | None = None + interrupt_flag: asyncio.Event | None = None + extra_params: Dict[str, Any] = field(default_factory=dict) + + def copy_with(self, **changes: Any) -> "ResponseRequest": + """基于当前请求创建一个带局部变更的新请求。 + + Args: + **changes: 需要覆盖的字段值。 + + Returns: + ResponseRequest: 复制后的请求对象。 + """ + payload = { + "model_info": self.model_info, + "message_list": list(self.message_list), + "tool_options": None if self.tool_options is None else list(self.tool_options), + "max_tokens": self.max_tokens, + "temperature": self.temperature, + "response_format": self.response_format, + "stream_response_handler": self.stream_response_handler, + "async_response_parser": self.async_response_parser, + "interrupt_flag": self.interrupt_flag, + "extra_params": dict(self.extra_params), + } + payload.update(changes) + return ResponseRequest(**payload) + + +@dataclass(slots=True) +class EmbeddingRequest: + """统一的嵌入请求。""" + + model_info: ModelInfo + embedding_input: str + extra_params: Dict[str, Any] = field(default_factory=dict) + + +@dataclass(slots=True) +class AudioTranscriptionRequest: + """统一的音频转录请求。""" + + model_info: ModelInfo + audio_base64: str + max_tokens: int | None = None + extra_params: Dict[str, Any] = field(default_factory=dict) + + +ClientRequest = ResponseRequest | EmbeddingRequest | AudioTranscriptionRequest +"""统一客户端请求类型。""" + + class BaseClient(ABC): """ 基础客户端 @@ -67,97 +144,82 @@ class BaseClient(ABC): api_provider: APIProvider - def __init__(self, api_provider: APIProvider): + def __init__(self, api_provider: APIProvider) -> None: + """初始化基础客户端。 + + Args: + api_provider: API 提供商配置。 + """ self.api_provider = api_provider @abstractmethod - async def get_response( - self, - model_info: ModelInfo, - message_list: list[Message], - tool_options: list[ToolOption] | None = None, - max_tokens: Optional[int] = None, - temperature: Optional[float] = None, - response_format: RespFormat | None = None, - stream_response_handler: Optional[ - Callable[[Any, asyncio.Event | None], tuple[APIResponse, tuple[int, int, int]]] - ] = None, - async_response_parser: Callable[[Any], tuple[APIResponse, tuple[int, int, int]]] | None = None, - interrupt_flag: asyncio.Event | None = None, - extra_params: dict[str, Any] | None = None, - ) -> APIResponse: - """ - 获取对话响应 - :param model_info: 模型信息 - :param message_list: 对话体 - :param tool_options: 工具选项(可选,默认为None) - :param max_tokens: 最大token数(可选,默认为1024) - :param temperature: 温度(可选,默认为0.7) - :param response_format: 响应格式(可选,默认为 NotGiven ) - :param stream_response_handler: 流式响应处理函数(可选) - :param async_response_parser: 响应解析函数(可选) - :param interrupt_flag: 中断信号量(可选,默认为None) - :return: (响应文本, 推理文本, 工具调用, 其他数据) + async def get_response(self, request: ResponseRequest) -> APIResponse: + """获取对话响应。 + + Args: + request: 统一响应请求对象。 + + Returns: + APIResponse: 统一响应对象。 """ raise NotImplementedError("'get_response' method should be overridden in subclasses") @abstractmethod - async def get_embedding( - self, - model_info: ModelInfo, - embedding_input: str, - extra_params: dict[str, Any] | None = None, - ) -> APIResponse: - """ - 获取文本嵌入 - :param model_info: 模型信息 - :param embedding_input: 嵌入输入文本 - :return: 嵌入响应 + async def get_embedding(self, request: EmbeddingRequest) -> APIResponse: + """获取文本嵌入。 + + Args: + request: 统一嵌入请求对象。 + + Returns: + APIResponse: 嵌入响应。 """ raise NotImplementedError("'get_embedding' method should be overridden in subclasses") @abstractmethod - async def get_audio_transcriptions( - self, - model_info: ModelInfo, - audio_base64: str, - max_tokens: Optional[int] = None, - extra_params: dict[str, Any] | None = None, - ) -> APIResponse: - """ - 获取音频转录 - :param model_info: 模型信息 - :param audio_base64: base64编码的音频数据 - :extra_params: 附加的请求参数 - :return: 音频转录响应 + async def get_audio_transcriptions(self, request: AudioTranscriptionRequest) -> APIResponse: + """获取音频转录。 + + Args: + request: 统一音频转录请求对象。 + + Returns: + APIResponse: 音频转录响应。 """ raise NotImplementedError("'get_audio_transcriptions' method should be overridden in subclasses") @abstractmethod - def get_support_image_formats(self) -> list[str]: - """ - 获取支持的图片格式 - :return: 支持的图片格式列表 + def get_support_image_formats(self) -> List[str]: + """获取支持的图片格式。 + + Returns: + List[str]: 支持的图片格式列表。 """ raise NotImplementedError("'get_support_image_formats' method should be overridden in subclasses") class ClientRegistry: + """客户端注册表。""" + def __init__(self) -> None: - self.client_registry: dict[str, type[BaseClient]] = {} + """初始化注册表并绑定配置重载回调。""" + self.client_registry: Dict[str, Type[BaseClient]] = {} """APIProvider.type -> BaseClient的映射表""" - self.client_instance_cache: dict[str, BaseClient] = {} + self.client_instance_cache: Dict[str, BaseClient] = {} """APIProvider.name -> BaseClient的映射表""" config_manager.register_reload_callback(self.clear_client_instance_cache) - def register_client_class(self, client_type: str): - """ - 注册API客户端类 + def register_client_class(self, client_type: str) -> Callable[[Type[BaseClient]], Type[BaseClient]]: + """注册 API 客户端类。 + Args: - client_class: API客户端类 + client_type: 客户端类型标识。 + + Returns: + Callable[[Type[BaseClient]], Type[BaseClient]]: 装饰器函数。 """ - def decorator(cls: type[BaseClient]) -> type[BaseClient]: + def decorator(cls: Type[BaseClient]) -> Type[BaseClient]: if not issubclass(cls, BaseClient): raise TypeError(f"{cls.__name__} is not a subclass of BaseClient") self.client_registry[client_type] = cls @@ -165,14 +227,15 @@ class ClientRegistry: return decorator - def get_client_class_instance(self, api_provider: APIProvider, force_new=False) -> BaseClient: - """ - 获取注册的API客户端实例 + def get_client_class_instance(self, api_provider: APIProvider, force_new: bool = False) -> BaseClient: + """获取注册的 API 客户端实例。 + Args: - api_provider: APIProvider实例 - force_new: 是否强制创建新实例(用于解决事件循环问题) + api_provider: APIProvider 实例。 + force_new: 是否强制创建新实例。 + Returns: - BaseClient: 注册的API客户端实例 + BaseClient: 注册的 API 客户端实例。 """ from . import ensure_client_type_loaded @@ -194,6 +257,7 @@ class ClientRegistry: return self.client_instance_cache[api_provider.name] def clear_client_instance_cache(self) -> None: + """清空客户端实例缓存。""" self.client_instance_cache.clear() logger.info("检测到配置重载,已清空LLM客户端实例缓存") diff --git a/src/llm_models/model_client/gemini_client.py b/src/llm_models/model_client/gemini_client.py index f63707d9..0cad4fb4 100644 --- a/src/llm_models/model_client/gemini_client.py +++ b/src/llm_models/model_client/gemini_client.py @@ -1,771 +1,973 @@ +from typing import Any, AsyncIterator, Callable, Coroutine, Dict, List, Optional, Tuple, cast + import asyncio -import io import base64 -from typing import Callable, AsyncIterator, Optional, Coroutine, Any, List, Dict +import io +import json from google import genai -from google.genai.types import ( - Content, - Part, - FunctionDeclaration, - GenerateContentResponse, - ContentListUnion, - ContentUnion, - ThinkingConfig, - Tool, - GoogleSearch, - GenerateContentConfig, - EmbedContentResponse, - EmbedContentConfig, - SafetySetting, - HttpOptions, - HarmCategory, - HarmBlockThreshold, -) from google.genai.errors import ( ClientError, + FunctionInvocationError, ServerError, UnknownFunctionCallArgumentError, UnsupportedFunctionError, - FunctionInvocationError, +) +from google.genai.types import ( + Candidate, + Content, + ContentListUnion, + ContentUnion, + EmbedContentConfig, + EmbedContentResponse, + FunctionDeclaration, + GenerateContentConfig, + GenerateContentResponse, + GoogleSearch, + HarmBlockThreshold, + HarmCategory, + HttpOptions, + Part, + SafetySetting, + ThinkingConfig, + Tool, ) -from src.config.model_configs import ModelInfo, APIProvider from src.common.logger import get_logger - -from .base_client import APIResponse, UsageRecord, BaseClient, client_registry -from ..exceptions import ( - RespParseException, - NetworkConnectionError, - RespNotOkException, - ReqAbortException, +from src.config.model_configs import APIProvider +from src.llm_models.exceptions import ( EmptyResponseException, + NetworkConnectionError, + ReqAbortException, + RespNotOkException, + RespParseException, +) +from src.llm_models.payload_content.message import ImageMessagePart, Message, RoleType, TextMessagePart +from src.llm_models.payload_content.resp_format import RespFormat, RespFormatType +from src.llm_models.payload_content.tool_option import ToolCall, ToolOption + +from .adapter_base import ( + AdapterClient, + ProviderResponseParser, + ProviderStreamResponseHandler, + await_task_with_interrupt, +) +from .base_client import ( + APIResponse, + AudioTranscriptionRequest, + EmbeddingRequest, + ResponseRequest, + UsageTuple, + client_registry, ) -from ..payload_content.message import Message, RoleType -from ..payload_content.resp_format import RespFormat, RespFormatType -from ..payload_content.tool_option import ToolOption, ToolParam, ToolCall logger = get_logger("Gemini客户端") -# gemini_thinking参数(默认范围) -# 不同模型的思考预算范围配置 -THINKING_BUDGET_LIMITS = { +GeminiStreamResponseHandler = Callable[ + [AsyncIterator[GenerateContentResponse], asyncio.Event | None], + Coroutine[Any, Any, Tuple[APIResponse, Optional[UsageTuple]]], +] +"""Gemini 流式响应处理函数类型。""" + +GeminiResponseParser = Callable[[GenerateContentResponse], Tuple[APIResponse, Optional[UsageTuple]]] +"""Gemini 非流式响应解析函数类型。""" + +THINKING_BUDGET_LIMITS: Dict[str, Dict[str, int | bool]] = { "gemini-2.5-flash": {"min": 1, "max": 24576, "can_disable": True}, "gemini-2.5-flash-lite": {"min": 512, "max": 24576, "can_disable": True}, "gemini-2.5-pro": {"min": 128, "max": 32768, "can_disable": False}, } -# 思维预算特殊值 -THINKING_BUDGET_AUTO = -1 # 自动调整思考预算,由模型决定 -THINKING_BUDGET_DISABLED = 0 # 禁用思考预算(如果模型允许禁用) +"""不同 Gemini 模型允许的思考预算范围。""" -gemini_safe_settings = [ +THINKING_BUDGET_AUTO = -1 +"""自动思考预算模式,由模型自行决定。""" + +THINKING_BUDGET_DISABLED = 0 +"""禁用思考预算模式。仅部分模型支持。""" + +GEMINI_SAFE_SETTINGS: List[SafetySetting] = [ SafetySetting(category=HarmCategory.HARM_CATEGORY_HATE_SPEECH, threshold=HarmBlockThreshold.BLOCK_NONE), SafetySetting(category=HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, threshold=HarmBlockThreshold.BLOCK_NONE), SafetySetting(category=HarmCategory.HARM_CATEGORY_HARASSMENT, threshold=HarmBlockThreshold.BLOCK_NONE), SafetySetting(category=HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, threshold=HarmBlockThreshold.BLOCK_NONE), SafetySetting(category=HarmCategory.HARM_CATEGORY_CIVIC_INTEGRITY, threshold=HarmBlockThreshold.BLOCK_NONE), ] +"""默认安全策略,避免 Gemini 在部分内容上返回空响应。""" + +GENERATE_CONFIG_RESERVED_EXTRA_PARAMS = { + "thinking_budget", + "include_thoughts", + "enable_google_search", + "transcription_prompt", + "audio_mime_type", +} +"""由当前客户端自行处理、不再直接透传给 `GenerateContentConfig` 的额外参数。""" + +EMBED_CONFIG_SUPPORTED_EXTRA_PARAMS = { + "task_type", + "title", + "output_dimensionality", + "mime_type", + "auto_truncate", +} +"""可透传给 `EmbedContentConfig` 的额外参数字段。""" -def _convert_messages( - messages: list[Message], -) -> tuple[ContentListUnion, list[str] | None]: +def _normalize_image_mime_type(image_format: str) -> str: + """将图片格式名称转换为标准 MIME 类型。 + + Args: + image_format: 图片格式名,例如 `png`、`jpg`。 + + Returns: + str: 规范化后的图片 MIME 类型。 """ - 转换消息格式 - 将消息转换为Gemini API所需的格式 - :param messages: 消息列表 - :return: 转换后的消息列表(和可能存在的system消息) + normalized_image_format = image_format.lower() + if normalized_image_format in {"jpg", "jpeg"}: + return "image/jpeg" + return f"image/{normalized_image_format}" + + +def _build_non_tool_parts(message: Message) -> List[Part]: + """将消息中的文本与图片片段转换为 Gemini `Part` 列表。 + + Args: + message: 内部统一消息对象。 + + Returns: + List[Part]: Gemini 所需的内容片段列表。 """ + converted_parts: List[Part] = [] + for message_part in message.parts: + if isinstance(message_part, TextMessagePart): + converted_parts.append(Part.from_text(text=message_part.text)) + continue + if isinstance(message_part, ImageMessagePart): + converted_parts.append( + Part.from_bytes( + data=base64.b64decode(message_part.image_base64), + mime_type=_normalize_image_mime_type(message_part.normalized_image_format), + ) + ) + return converted_parts - def _convert_message_item(message: Message) -> Content: - """ - 转换单个消息格式,除了system和tool类型的消息 - :param message: 消息对象 - :return: 转换后的消息字典 - """ - # 将openai格式的角色重命名为gemini格式的角色 - if message.role == RoleType.Assistant: - role = "model" - elif message.role == RoleType.User: - role = "user" - else: - raise ValueError(f"Unsupported role: {message.role}") +def _normalize_function_response_payload(message: Message) -> Dict[str, Any]: + """将内部工具结果消息转换为 Gemini 函数响应负载。 - # 添加Content - if isinstance(message.content, str): - content = [Part.from_text(text=message.content)] - elif isinstance(message.content, list): - content: List[Part] = [] - for item in message.content: - if isinstance(item, tuple): - image_format = item[0].lower() - # 规范 JPEG MIME 类型后缀,统一使用 image/jpeg - if image_format in ("jpg", "jpeg"): - image_format = "jpeg" - content.append(Part.from_bytes(data=base64.b64decode(item[1]), mime_type=f"image/{image_format}")) - elif isinstance(item, str): - content.append(Part.from_text(text=item)) - else: - raise RuntimeError("无法触及的代码:请使用MessageBuilder类构建消息对象") + Args: + message: 工具结果消息。 - return Content(role=role, parts=content) + Returns: + Dict[str, Any]: 可用于 `Part.from_function_response()` 的响应对象。 + """ + content = message.content + if isinstance(content, str): + stripped_content = content.strip() + if not stripped_content: + return {} + try: + parsed_content = json.loads(stripped_content) + except json.JSONDecodeError: + return {"result": content} + if isinstance(parsed_content, dict): + return parsed_content + return {"result": parsed_content} + + return {"result": content} + + +def _get_candidates(response: GenerateContentResponse) -> List[Candidate]: + """安全获取 Gemini 响应中的候选列表。 + + Args: + response: Gemini 响应对象。 + + Returns: + List[Candidate]: 非空时返回原候选列表,否则返回空列表。 + """ + return response.candidates or [] + + +def _extract_response_json_schema(response_format: RespFormat) -> Dict[str, object] | None: + """从内部响应格式中提取可供 Gemini 使用的 JSON Schema。 + + Args: + response_format: 输出格式定义。 + + Returns: + Dict[str, object] | None: 可直接传给 `response_json_schema` 的 JSON Schema。 + """ + schema_payload = response_format.get_schema_object() + if schema_payload is None: + return None + return cast(Dict[str, object], schema_payload) + + +def _convert_messages(messages: List[Message]) -> Tuple[ContentListUnion, str | None]: + """将内部统一消息列表转换为 Gemini 内容结构。 + + Args: + messages: 内部统一消息列表。 + + Returns: + Tuple[ContentListUnion, str | None]: `contents` 与可选的 `system_instruction`。 + + Raises: + ValueError: 当消息结构无法映射到 Gemini 内容模型时抛出。 + """ + contents: List[ContentUnion] = [] + system_instruction_chunks: List[str] = [] + tool_name_by_call_id: Dict[str, str] = {} - temp_list: list[ContentUnion] = [] - system_instructions: list[str] = [] for message in messages: if message.role == RoleType.System: - if isinstance(message.content, str): - system_instructions.append(message.content) - else: - raise ValueError("你tm怎么往system里面塞图片base64?") - elif message.role == RoleType.Tool: + system_text = message.get_text_content().strip() + if not system_text: + raise ValueError("Gemini 的 system message 必须为非空文本") + system_instruction_chunks.append(system_text) + continue + + if message.role == RoleType.User: + contents.append(Content(role="user", parts=_build_non_tool_parts(message))) + continue + + if message.role == RoleType.Assistant: + assistant_parts = _build_non_tool_parts(message) + if message.tool_calls: + for tool_call in message.tool_calls: + assistant_parts.append( + Part.from_function_call( + name=tool_call.func_name, + args=tool_call.args or {}, + ) + ) + tool_name_by_call_id[tool_call.call_id] = tool_call.func_name + contents.append(Content(role="model", parts=assistant_parts)) + continue + + if message.role == RoleType.Tool: if not message.tool_call_id: - raise ValueError("无法触及的代码:请使用MessageBuilder类构建消息对象") - else: - temp_list.append(_convert_message_item(message)) - if system_instructions: - # 如果有system消息,就把它加上去 - ret: tuple = (temp_list, system_instructions) - else: - # 如果没有system消息,就直接返回 - ret: tuple = (temp_list, None) + raise ValueError("Gemini 工具结果消息缺少 tool_call_id") + tool_name = tool_name_by_call_id.get(message.tool_call_id) + if not tool_name: + raise ValueError(f"Gemini 无法根据 tool_call_id={message.tool_call_id} 找到对应的工具名称") + function_response_part = Part.from_function_response( + name=tool_name, + response=_normalize_function_response_payload(message), + ) + contents.append(Content(role="tool", parts=[function_response_part])) + continue - return ret + raise ValueError(f"不支持的消息角色: {message.role}") + + system_instruction = "\n\n".join(chunk for chunk in system_instruction_chunks if chunk.strip()) or None + return contents, system_instruction -def _convert_tool_options(tool_options: list[ToolOption]) -> list[FunctionDeclaration]: +def _build_tools(tool_options: List[ToolOption]) -> List[Tool]: + """将内部工具定义转换为 Gemini `Tool` 列表。 + + Args: + tool_options: 内部统一工具定义列表。 + + Returns: + List[Tool]: Gemini 所需工具列表。 """ - 转换工具选项格式 - 将工具选项转换为Gemini API所需的格式 - :param tool_options: 工具选项列表 - :return: 转换后的工具对象列表 - """ - - def _convert_tool_param(tool_option_param: ToolParam) -> dict: - """ - 转换单个工具参数格式 - :param tool_option_param: 工具参数对象 - :return: 转换后的工具参数字典 - """ - # JSON Schema 类型名称修正: - # - 布尔类型使用 "boolean" 而不是 "bool" - # - 浮点数使用 "number" 而不是 "float" - param_type_value = tool_option_param.param_type.value - if param_type_value == "bool": - param_type_value = "boolean" - elif param_type_value == "float": - param_type_value = "number" - - return_dict: dict[str, Any] = { - "type": param_type_value, - "description": tool_option_param.description, - } - if tool_option_param.enum_values: - return_dict["enum"] = tool_option_param.enum_values - return return_dict - - def _convert_tool_option_item(tool_option: ToolOption) -> FunctionDeclaration: - """ - 转换单个工具项格式 - :param tool_option: 工具选项对象 - :return: 转换后的Gemini工具选项对象 - """ - ret: dict[str, Any] = { + function_declarations: List[FunctionDeclaration] = [] + for tool_option in tool_options: + payload: Dict[str, Any] = { "name": tool_option.name, "description": tool_option.description, } - if tool_option.params: - ret["parameters"] = { - "type": "object", - "properties": {param.name: _convert_tool_param(param) for param in tool_option.params}, - "required": [param.name for param in tool_option.params if param.required], - } - ret1 = FunctionDeclaration(**ret) - return ret1 - - return [_convert_tool_option_item(tool_option) for tool_option in tool_options] + if tool_option.parameters_schema is not None: + payload["parameters_json_schema"] = tool_option.parameters_schema + function_declarations.append(FunctionDeclaration(**payload)) + return [Tool(function_declarations=function_declarations)] if function_declarations else [] -def _process_delta( - delta: GenerateContentResponse, - fc_delta_buffer: io.StringIO, - tool_calls_buffer: list[tuple[str, str, dict[str, Any]]], - resp: APIResponse | None = None, -): - if not hasattr(delta, "candidates") or not delta.candidates: - raise RespParseException(delta, "响应解析失败,缺失candidates字段") +def _extract_usage_record(response: GenerateContentResponse) -> Optional[UsageTuple]: + """从 Gemini 响应中提取使用量信息。 - # 处理 thought(Gemini 的特殊字段) - for c in getattr(delta, "candidates", []): - if c.content and getattr(c.content, "parts", None): - for p in c.content.parts: - if getattr(p, "thought", False) and getattr(p, "text", None): - # 保存到 reasoning_content - if resp is not None: - resp.reasoning_content = (resp.reasoning_content or "") + p.text - elif getattr(p, "text", None): - # 正常输出写入 buffer - fc_delta_buffer.write(p.text) + Args: + response: Gemini 响应对象。 - if delta.function_calls: # 为什么不用hasattr呢,是因为这个属性一定有,即使是个空的 - for call in delta.function_calls: - try: - if not isinstance(call.args, dict): # gemini返回的function call参数就是dict格式的了 - raise RespParseException(delta, "响应解析失败,工具调用参数无法解析为字典类型") - if not call.id or not call.name: - raise RespParseException(delta, "响应解析失败,工具调用缺失id或name字段") - tool_calls_buffer.append( - ( - call.id, - call.name, - call.args or {}, # 如果args是None,则转换为一个空字典 - ) - ) - except Exception as e: - raise RespParseException(delta, "响应解析失败,无法解析工具调用参数") from e + Returns: + Optional[UsageTuple]: 统一的使用量三元组;缺失时返回 `None`。 + """ + usage_metadata = getattr(response, "usage_metadata", None) + if usage_metadata is None: + return None + prompt_tokens = getattr(usage_metadata, "prompt_token_count", 0) or 0 + completion_tokens = ( + (getattr(usage_metadata, "candidates_token_count", 0) or 0) + + (getattr(usage_metadata, "thoughts_token_count", 0) or 0) + ) + total_tokens = getattr(usage_metadata, "total_token_count", 0) or 0 + return prompt_tokens, completion_tokens, total_tokens -def _build_stream_api_resp( - _fc_delta_buffer: io.StringIO, - _tool_calls_buffer: list[tuple[str, str, dict]], - last_resp: GenerateContentResponse | None = None, # 传入 last_resp - resp: APIResponse | None = None, -) -> APIResponse: - # sourcery skip: simplify-len-comparison, use-assigned-variable - if resp is None: - resp = APIResponse() +def _extract_finish_reason(response: GenerateContentResponse | None) -> str | None: + """提取 Gemini 响应的结束原因。 - if _fc_delta_buffer.tell() > 0: - # 如果正式内容缓冲区不为空,则将其写入APIResponse对象 - resp.content = _fc_delta_buffer.getvalue() - _fc_delta_buffer.close() - if len(_tool_calls_buffer) > 0: - # 如果工具调用缓冲区不为空,则将其解析为ToolCall对象列表 - resp.tool_calls = [] - for call_id, function_name, arguments_buffer in _tool_calls_buffer: - if arguments_buffer is not None: - arguments = arguments_buffer - if not isinstance(arguments, dict): - raise RespParseException( - None, - f"响应解析失败,工具调用参数无法解析为字典类型。工具调用参数原始响应:\n{arguments_buffer}", - ) + Args: + response: Gemini 响应对象。 + + Returns: + str | None: 结束原因字符串;获取失败时返回 `None`。 + """ + if response is None: + return None + candidates = _get_candidates(response) + if not candidates: + return None + for candidate in candidates: + finish_reason = getattr(candidate, "finish_reason", None) or getattr(candidate, "finishReason", None) + if finish_reason: + return str(finish_reason) + return None + + +def _warn_if_max_tokens_truncated( + response: GenerateContentResponse | None, + content: str | None, + tool_calls: List[ToolCall] | None, +) -> None: + """在 Gemini 因 token 限制截断时输出警告。 + + Args: + response: Gemini 响应对象。 + content: 已解析的可见文本内容。 + tool_calls: 已解析的工具调用列表。 + """ + finish_reason = _extract_finish_reason(response) + if finish_reason is None or "MAX_TOKENS" not in finish_reason: + return + has_visible_output = bool((content and content.strip()) or tool_calls) + if has_visible_output: + logger.warning( + "Gemini 响应因达到 max_tokens 限制被部分截断,可能影响回复完整性,建议调整模型 max_tokens 配置。" + ) + return + logger.warning("Gemini 响应因达到 max_tokens 限制被截断,且未返回可见输出,请检查模型 max_tokens 配置。") + + +def _collect_function_calls(response: GenerateContentResponse) -> List[ToolCall]: + """从 Gemini 响应中提取工具调用列表。 + + Args: + response: Gemini 响应对象。 + + Returns: + List[ToolCall]: 规范化后的工具调用列表。 + + Raises: + RespParseException: 当函数调用结构不合法时抛出。 + """ + raw_function_calls = getattr(response, "function_calls", None) + candidates = _get_candidates(response) + if not raw_function_calls and candidates: + raw_function_calls = [] + for candidate in candidates: + content = getattr(candidate, "content", None) + parts = getattr(content, "parts", None) or [] + for part in parts: + function_call = getattr(part, "function_call", None) + if function_call is not None: + raw_function_calls.append(function_call) + + if not raw_function_calls: + return [] + + tool_calls: List[ToolCall] = [] + for index, function_call in enumerate(raw_function_calls, start=1): + call_name = getattr(function_call, "name", None) + call_id = getattr(function_call, "id", None) or f"gemini-tool-call-{index}" + call_args = getattr(function_call, "args", None) or {} + if not isinstance(call_name, str) or not call_name: + raise RespParseException(response, "响应解析失败,Gemini 工具调用缺少 name 字段") + if not isinstance(call_args, dict): + raise RespParseException(response, "响应解析失败,Gemini 工具调用参数无法解析为字典") + tool_calls.append(ToolCall(call_id=call_id, func_name=call_name, args=call_args)) + return tool_calls + + +def _process_stream_chunk( + chunk: GenerateContentResponse, + content_buffer: io.StringIO, + tool_calls_buffer: List[ToolCall], + response: APIResponse, +) -> None: + """处理单个 Gemini 流式响应块。 + + Args: + chunk: 当前流式响应块。 + content_buffer: 正文缓冲区。 + tool_calls_buffer: 工具调用缓冲区。 + response: 当前累积的统一响应对象。 + """ + candidates = _get_candidates(chunk) + for candidate in candidates: + content = getattr(candidate, "content", None) + parts = getattr(content, "parts", None) or [] + for part in parts: + part_text = getattr(part, "text", None) + if not part_text: + continue + if getattr(part, "thought", False): + response.reasoning_content = (response.reasoning_content or "") + part_text else: - arguments = None + content_buffer.write(part_text) - resp.tool_calls.append(ToolCall(call_id, function_name, arguments)) + tool_calls_buffer.extend(_collect_function_calls(chunk)) - # 检查是否因为 max_tokens 截断 - reason = None - if last_resp and getattr(last_resp, "candidates", None): - for c in last_resp.candidates: - fr = getattr(c, "finish_reason", None) or getattr(c, "finishReason", None) - if fr: - reason = str(fr) - break - if str(reason).endswith("MAX_TOKENS"): - has_visible_output = bool(resp.content and resp.content.strip()) - if has_visible_output: - logger.warning( - "⚠ Gemini 响应因达到 max_tokens 限制被部分截断,\n" - " 可能会对回复内容造成影响,建议修改模型 max_tokens 配置!" - ) - else: - logger.warning("⚠ Gemini 响应因达到 max_tokens 限制被截断,\n 请修改模型 max_tokens 配置!") +def _build_stream_api_response( + content_buffer: io.StringIO, + tool_calls_buffer: List[ToolCall], + last_response: GenerateContentResponse | None, + response: APIResponse, +) -> APIResponse: + """根据流式缓冲区内容构建统一响应对象。 - if not resp.content and not resp.tool_calls: - if not getattr(resp, "reasoning_content", None): - raise EmptyResponseException() + Args: + content_buffer: 正文缓冲区。 + tool_calls_buffer: 工具调用缓冲区。 + last_response: 最后一个 Gemini 响应块。 + response: 已累积的响应对象。 - return resp + Returns: + APIResponse: 构建完成的统一响应对象。 + + Raises: + EmptyResponseException: 响应中既无正文也无工具调用且无思考内容时抛出。 + """ + if content_buffer.tell() > 0: + response.content = content_buffer.getvalue() + content_buffer.close() + + if tool_calls_buffer: + response.tool_calls = list(tool_calls_buffer) + response.raw_data = last_response + + _warn_if_max_tokens_truncated(last_response, response.content, response.tool_calls) + if not response.content and not response.tool_calls and not response.reasoning_content: + raise EmptyResponseException() + return response async def _default_stream_response_handler( - resp_stream: AsyncIterator[GenerateContentResponse], + response_stream: AsyncIterator[GenerateContentResponse], interrupt_flag: asyncio.Event | None, -) -> tuple[APIResponse, Optional[tuple[int, int, int]]]: +) -> Tuple[APIResponse, Optional[UsageTuple]]: + """处理 Gemini 流式响应。 + + Args: + response_stream: Gemini 异步流式响应迭代器。 + interrupt_flag: 外部中断标记。 + + Returns: + Tuple[APIResponse, Optional[UsageTuple]]: 统一响应对象与可选的使用量信息。 """ - 流式响应处理函数 - 处理Gemini API的流式响应 - :param resp_stream: 流式响应对象,是一个神秘的iterator,我完全不知道这个玩意能不能跑,不过遍历一遍之后它就空了,如果跑不了一点的话可以考虑改成别的东西 - :return: APIResponse对象 - """ - _fc_delta_buffer = io.StringIO() # 正式内容缓冲区,用于存储接收到的正式内容 - _tool_calls_buffer: list[tuple[str, str, dict]] = [] # 工具调用缓冲区,用于存储接收到的工具调用 - _usage_record = None # 使用情况记录 - last_resp: GenerateContentResponse | None = None # 保存最后一个 chunk - resp = APIResponse() - - def _insure_buffer_closed(): - if _fc_delta_buffer and not _fc_delta_buffer.closed: - _fc_delta_buffer.close() - - async for chunk in resp_stream: - last_resp = chunk # 保存最后一个响应 - # 检查是否有中断量 - if interrupt_flag and interrupt_flag.is_set(): - # 如果中断量被设置,则抛出ReqAbortException - raise ReqAbortException("请求被外部信号中断") - - _process_delta( - chunk, - _fc_delta_buffer, - _tool_calls_buffer, - resp=resp, - ) - - if chunk.usage_metadata: - # 如果有使用情况,则将其存储在APIResponse对象中 - _usage_record = ( - chunk.usage_metadata.prompt_token_count or 0, - (chunk.usage_metadata.candidates_token_count or 0) + (chunk.usage_metadata.thoughts_token_count or 0), - chunk.usage_metadata.total_token_count or 0, - ) + content_buffer = io.StringIO() + tool_calls_buffer: List[ToolCall] = [] + api_response = APIResponse() + usage_record: Optional[UsageTuple] = None + last_response: GenerateContentResponse | None = None try: - return _build_stream_api_resp( - _fc_delta_buffer, - _tool_calls_buffer, - last_resp=last_resp, - resp=resp, - ), _usage_record + async for chunk in response_stream: + last_response = chunk + if interrupt_flag and interrupt_flag.is_set(): + raise ReqAbortException("请求被外部信号中断") + _process_stream_chunk(chunk, content_buffer, tool_calls_buffer, api_response) + usage_record = _extract_usage_record(chunk) or usage_record + return _build_stream_api_response(content_buffer, tool_calls_buffer, last_response, api_response), usage_record except Exception: - # 确保缓冲区被关闭 - _insure_buffer_closed() + if not content_buffer.closed: + content_buffer.close() raise def _default_normal_response_parser( - resp: GenerateContentResponse, -) -> tuple[APIResponse, Optional[tuple[int, int, int]]]: + response: GenerateContentResponse, +) -> Tuple[APIResponse, Optional[UsageTuple]]: + """解析 Gemini 非流式响应。 + + Args: + response: Gemini 响应对象。 + + Returns: + Tuple[APIResponse, Optional[UsageTuple]]: 统一响应对象与可选的使用量信息。 + + Raises: + EmptyResponseException: 响应中既无正文也无工具调用且无思考内容时抛出。 """ - 解析对话补全响应 - 将Gemini API响应解析为APIResponse对象 - :param resp: 响应对象 - :return: APIResponse对象 - """ - api_response = APIResponse() + api_response = APIResponse(raw_data=response) + visible_parts: List[str] = [] - # 解析思考内容 - try: - if candidates := resp.candidates: - if candidates[0].content and candidates[0].content.parts: - for part in candidates[0].content.parts: - if not part.text: - continue - if part.thought: - api_response.reasoning_content = ( - api_response.reasoning_content + part.text if api_response.reasoning_content else part.text - ) - except Exception as e: - logger.warning(f"解析思考内容时发生错误: {e},跳过解析") + for candidate in _get_candidates(response): + content = getattr(candidate, "content", None) + parts = getattr(content, "parts", None) or [] + for part in parts: + part_text = getattr(part, "text", None) + if not part_text: + continue + if getattr(part, "thought", False): + api_response.reasoning_content = (api_response.reasoning_content or "") + part_text + else: + visible_parts.append(part_text) - # 解析响应内容 - api_response.content = resp.text + api_response.content = "".join(visible_parts).strip() or getattr(response, "text", None) - # 解析工具调用 - if function_calls := resp.function_calls: - api_response.tool_calls = [] - for call in function_calls: - try: - if not isinstance(call.args, dict): - raise RespParseException(resp, "响应解析失败,工具调用参数无法解析为字典类型") - if not call.name: - raise RespParseException(resp, "响应解析失败,工具调用缺失name字段") - api_response.tool_calls.append(ToolCall(call.id or "gemini-tool_call", call.name, call.args or {})) - except Exception as e: - raise RespParseException(resp, "响应解析失败,无法解析工具调用参数") from e + tool_calls = _collect_function_calls(response) + if tool_calls: + api_response.tool_calls = tool_calls - # 解析使用情况 - if usage_metadata := resp.usage_metadata: - _usage_record = ( - usage_metadata.prompt_token_count or 0, - (usage_metadata.candidates_token_count or 0) + (usage_metadata.thoughts_token_count or 0), - usage_metadata.total_token_count or 0, - ) - else: - _usage_record = None - - api_response.raw_data = resp - - # 检查是否因为 max_tokens 截断 - try: - if resp.candidates: - c0 = resp.candidates[0] - reason = getattr(c0, "finish_reason", None) or getattr(c0, "finishReason", None) - if reason and "MAX_TOKENS" in str(reason): - # 检查第二个及之后的 parts 是否有内容 - has_real_output = False - if getattr(c0, "content", None) and getattr(c0.content, "parts", None): - for p in c0.content.parts[1:]: # 跳过第一个 thought - if getattr(p, "text", None) and p.text.strip(): - has_real_output = True - break - - if not has_real_output and getattr(resp, "text", None): - has_real_output = True - - if has_real_output: - logger.warning( - "⚠ Gemini 响应因达到 max_tokens 限制被部分截断,\n" - " 可能会对回复内容造成影响,建议修改模型 max_tokens 配置!" - ) - else: - logger.warning("⚠ Gemini 响应因达到 max_tokens 限制被截断,\n 请修改模型 max_tokens 配置!") - - return api_response, _usage_record - except Exception as e: - logger.debug(f"检查 MAX_TOKENS 截断时异常: {e}") - - # 最终的、唯一的空响应检查 - if not api_response.content and not api_response.tool_calls: + usage_record = _extract_usage_record(response) + _warn_if_max_tokens_truncated(response, api_response.content, api_response.tool_calls) + if not api_response.content and not api_response.tool_calls and not api_response.reasoning_content: raise EmptyResponseException("响应中既无文本内容也无工具调用") + return api_response, usage_record - return api_response, _usage_record + +def _build_http_options(api_provider: APIProvider) -> HttpOptions: + """根据 Provider 配置构建 Gemini SDK 的 `HttpOptions`。 + + Args: + api_provider: API 提供商配置。 + + Returns: + HttpOptions: Gemini SDK HTTP 选项对象。 + """ + http_options_payload: Dict[str, Any] = {} + if api_provider.timeout is not None: + http_options_payload["timeout"] = int(api_provider.timeout * 1000) + + base_url = api_provider.base_url.strip() + if base_url: + normalized_base_url = base_url.rstrip("/") + version_candidate = normalized_base_url.rsplit("/", 1) + if len(version_candidate) == 2 and version_candidate[1].startswith("v"): + http_options_payload["base_url"] = f"{version_candidate[0]}/" + http_options_payload["api_version"] = version_candidate[1] + else: + http_options_payload["base_url"] = f"{normalized_base_url}/" + + return HttpOptions(**http_options_payload) + + +def _filter_generate_content_extra_params(extra_params: Dict[str, Any]) -> Dict[str, Any]: + """筛选可透传给 `GenerateContentConfig` 的额外参数。 + + Args: + extra_params: 模型级额外参数。 + + Returns: + Dict[str, Any]: 可直接透传到 `GenerateContentConfig` 的字段字典。 + """ + filtered_params: Dict[str, Any] = {} + for key, value in extra_params.items(): + if key in GENERATE_CONFIG_RESERVED_EXTRA_PARAMS: + continue + if key in GenerateContentConfig.model_fields: + filtered_params[key] = value + return filtered_params + + +def _build_embed_content_config(extra_params: Dict[str, Any]) -> EmbedContentConfig: + """构建 Gemini 嵌入配置。 + + Args: + extra_params: 模型级额外参数。 + + Returns: + EmbedContentConfig: Gemini 嵌入配置对象。 + """ + config_payload: Dict[str, Any] = {"task_type": extra_params.get("task_type", "SEMANTIC_SIMILARITY")} + for key in EMBED_CONFIG_SUPPORTED_EXTRA_PARAMS: + if key == "task_type": + continue + if key in extra_params: + config_payload[key] = extra_params[key] + return EmbedContentConfig(**config_payload) @client_registry.register_client_class("gemini") -class GeminiClient(BaseClient): +class GeminiClient(AdapterClient[AsyncIterator[GenerateContentResponse], GenerateContentResponse]): + """Gemini 官方 SDK 客户端适配器。""" + client: genai.Client - def __init__(self, api_provider: APIProvider): + def __init__(self, api_provider: APIProvider) -> None: + """初始化 Gemini 客户端。 + + Args: + api_provider: API 提供商配置。 + """ super().__init__(api_provider) - - # 增加传入参数处理 - http_options_kwargs: Dict[str, Any] = {} - - # 秒转换为毫秒传入 - if api_provider.timeout is not None: - http_options_kwargs["timeout"] = int(api_provider.timeout * 1000) - - # 传入并处理地址和版本(必须为Gemini格式) - if api_provider.base_url: - parts = api_provider.base_url.rstrip("/").rsplit("/", 1) - if len(parts) == 2 and parts[1].startswith("v"): - http_options_kwargs["base_url"] = f"{parts[0]}/" - http_options_kwargs["api_version"] = parts[1] - else: - http_options_kwargs["base_url"] = api_provider.base_url - http_options_kwargs["api_version"] = None self.client = genai.Client( - http_options=HttpOptions(**http_options_kwargs), api_key=api_provider.api_key, - ) # 这里和openai不一样,gemini会自己决定自己是否需要retry + http_options=_build_http_options(api_provider), + ) @staticmethod - def clamp_thinking_budget(extra_params: dict[str, Any] | None, model_id: str) -> int: - """ - 按模型限制思考预算范围,仅支持指定的模型(支持带数字后缀的新版本) - """ - limits = None + def clamp_thinking_budget(extra_params: Dict[str, Any] | None, model_id: str) -> int: + """将思考预算裁剪到模型允许的范围内。 - # 参数传入处理 - tb = THINKING_BUDGET_AUTO + Args: + extra_params: 请求额外参数。 + model_id: 当前模型标识。 + + Returns: + int: 裁剪后的思考预算值。 + """ + thinking_budget = THINKING_BUDGET_AUTO if extra_params and "thinking_budget" in extra_params: try: - tb = int(extra_params["thinking_budget"]) - except (ValueError, TypeError): - logger.warning( - f"无效的 thinking_budget 值 {extra_params['thinking_budget']},将使用模型自动预算模式 {tb}" - ) + thinking_budget = int(extra_params["thinking_budget"]) + except (TypeError, ValueError): + logger.warning(f"无效的 thinking_budget={extra_params['thinking_budget']},已回退为自动模式") - # 优先尝试精确匹配 + limits: Dict[str, int | bool] | None = None if model_id in THINKING_BUDGET_LIMITS: limits = THINKING_BUDGET_LIMITS[model_id] else: - # 按 key 长度倒序,保证更长的(更具体的,如 -lite)优先 - sorted_keys = sorted(THINKING_BUDGET_LIMITS.keys(), key=len, reverse=True) - for key in sorted_keys: - # 必须满足:完全等于 或者 前缀匹配(带 "-" 边界) - if model_id == key or model_id.startswith(f"{key}-"): - limits = THINKING_BUDGET_LIMITS[key] + for candidate_prefix in sorted(THINKING_BUDGET_LIMITS.keys(), key=len, reverse=True): + if model_id == candidate_prefix or model_id.startswith(f"{candidate_prefix}-"): + limits = THINKING_BUDGET_LIMITS[candidate_prefix] break - # 预算值处理 - if tb == THINKING_BUDGET_AUTO: + if thinking_budget == THINKING_BUDGET_AUTO: return THINKING_BUDGET_AUTO - if tb == THINKING_BUDGET_DISABLED: - if limits and limits.get("can_disable", False): + + if thinking_budget == THINKING_BUDGET_DISABLED: + if limits and bool(limits.get("can_disable", False)): return THINKING_BUDGET_DISABLED if limits: - logger.warning(f"模型 {model_id} 不支持禁用思考预算,已回退到最小值 {limits['min']}") - return limits["min"] + minimum_value = int(limits["min"]) + logger.warning(f"模型 {model_id} 不支持禁用思考预算,已回退为最小值 {minimum_value}") + return minimum_value return THINKING_BUDGET_AUTO - # 已知模型范围裁剪 + 提示 - if limits: - if tb < limits["min"]: - logger.warning(f"模型 {model_id} 的 thinking_budget={tb} 过小,已调整为最小值 {limits['min']}") - return limits["min"] - if tb > limits["max"]: - logger.warning(f"模型 {model_id} 的 thinking_budget={tb} 过大,已调整为最大值 {limits['max']}") - return limits["max"] - return tb + if limits is None: + logger.warning(f"模型 {model_id} 未配置思考预算范围,已回退为自动模式") + return THINKING_BUDGET_AUTO - # 未知模型 → 默认自动模式 - logger.warning(f"模型 {model_id} 未在 THINKING_BUDGET_LIMITS 中定义,已启用模型自动预算兼容") - return THINKING_BUDGET_AUTO + minimum_value = int(limits["min"]) + maximum_value = int(limits["max"]) + if thinking_budget < minimum_value: + logger.warning(f"模型 {model_id} 的 thinking_budget={thinking_budget} 过小,已调整为 {minimum_value}") + return minimum_value + if thinking_budget > maximum_value: + logger.warning(f"模型 {model_id} 的 thinking_budget={thinking_budget} 过大,已调整为 {maximum_value}") + return maximum_value + return thinking_budget + + @staticmethod + def _resolve_model_identifier(model_identifier: str, extra_params: Dict[str, Any]) -> Tuple[str, bool]: + """解析请求实际使用的 Gemini 模型标识。 - async def get_response( - self, - model_info: ModelInfo, - message_list: list[Message], - tool_options: list[ToolOption] | None = None, - max_tokens: Optional[int] = 1024, - temperature: Optional[float] = 0.4, - response_format: RespFormat | None = None, - stream_response_handler: Optional[ - Callable[ - [AsyncIterator[GenerateContentResponse], asyncio.Event | None], - Coroutine[Any, Any, tuple[APIResponse, Optional[tuple[int, int, int]]]], - ] - ] = None, - async_response_parser: Optional[ - Callable[[GenerateContentResponse], tuple[APIResponse, Optional[tuple[int, int, int]]]] - ] = None, - interrupt_flag: asyncio.Event | None = None, - extra_params: dict[str, Any] | None = None, - ) -> APIResponse: - """ - 获取对话响应 Args: - model_info: 模型信息 - message_list: 对话体 - tool_options: 工具选项(可选,默认为None) - max_tokens: 最大token数(可选,默认为1024) - temperature: 温度(可选,默认为0.7) - response_format: 响应格式(默认为text/plain,如果是输入的JSON Schema则必须遵守OpenAPI3.0格式,理论上和openai是一样的,暂不支持其它相应格式输入) - stream_response_handler: 流式响应处理函数(可选,默认为default_stream_response_handler) - async_response_parser: 响应解析函数(可选,默认为default_response_parser) - interrupt_flag: 中断信号量(可选,默认为None) + model_identifier: 原始模型标识。 + extra_params: 模型级额外参数。 + Returns: - APIResponse对象,包含响应内容、推理内容、工具调用等信息 + Tuple[str, bool]: `(实际模型标识, 是否启用 Google Search)`。 """ - if stream_response_handler is None: - stream_response_handler = _default_stream_response_handler - - if async_response_parser is None: - async_response_parser = _default_normal_response_parser - - # 将messages构造为Gemini API所需的格式 - messages = _convert_messages(message_list) - # 将tool_options转换为Gemini API所需的格式 - tools = _convert_tool_options(tool_options) if tool_options else None - # 解析并裁剪 thinking_budget - tb = self.clamp_thinking_budget(extra_params, model_info.model_identifier) - # 检测是否为带 -search 的模型 - enable_google_search = False - model_identifier = model_info.model_identifier - if model_identifier.endswith("-search"): + enable_google_search = bool(extra_params.get("enable_google_search", False)) + resolved_model_identifier = model_identifier + if resolved_model_identifier.endswith("-search"): + resolved_model_identifier = resolved_model_identifier.removesuffix("-search") enable_google_search = True - # 去掉后缀并更新模型ID - model_identifier = model_identifier.removesuffix("-search") - model_info.model_identifier = model_identifier - logger.info(f"模型已启用 GoogleSearch 功能:{model_identifier}") + return resolved_model_identifier, enable_google_search - # 将response_format转换为Gemini API所需的格式 - generation_config_dict = { - "max_output_tokens": max_tokens, - "temperature": temperature, - "response_modalities": ["TEXT"], - "thinking_config": ThinkingConfig( - include_thoughts=True, - thinking_budget=tb, - ), - "safety_settings": gemini_safe_settings, # 防止空回复问题 - } - if tools: - generation_config_dict["tools"] = Tool(function_declarations=tools) - if messages[1]: - # 如果有system消息,则将其添加到配置中 - generation_config_dict["system_instructions"] = messages[1] - if response_format and response_format.format_type == RespFormatType.TEXT: - generation_config_dict["response_mime_type"] = "text/plain" - elif response_format and response_format.format_type in (RespFormatType.JSON_OBJ, RespFormatType.JSON_SCHEMA): - generation_config_dict["response_mime_type"] = "application/json" - generation_config_dict["response_schema"] = response_format.to_dict() - # 自动启用 GoogleSearch grounding_tool + def _build_generation_config( + self, + *, + model_identifier: str, + system_instruction: str | None, + tool_options: List[ToolOption] | None, + response_format: RespFormat | None, + max_tokens: int | None, + temperature: float | None, + extra_params: Dict[str, Any], + enable_google_search: bool, + ) -> GenerateContentConfig: + """构建 Gemini 生成配置。 + + Args: + model_identifier: 当前请求实际使用的模型标识。 + system_instruction: 系统指令文本。 + tool_options: 内部工具定义列表。 + response_format: 输出格式定义。 + max_tokens: 最大输出 token 数。 + temperature: 温度参数。 + extra_params: 模型级额外参数。 + enable_google_search: 是否自动追加 Google Search 工具。 + + Returns: + GenerateContentConfig: Gemini 生成配置对象。 + """ + config_payload = _filter_generate_content_extra_params(extra_params) + + if max_tokens is not None and "max_output_tokens" not in config_payload: + config_payload["max_output_tokens"] = max_tokens + if temperature is not None and "temperature" not in config_payload: + config_payload["temperature"] = temperature + if system_instruction and "system_instruction" not in config_payload: + config_payload["system_instruction"] = system_instruction + if "response_modalities" not in config_payload: + config_payload["response_modalities"] = ["TEXT"] + if "safety_settings" not in config_payload: + config_payload["safety_settings"] = GEMINI_SAFE_SETTINGS + if "thinking_config" not in config_payload: + config_payload["thinking_config"] = ThinkingConfig( + include_thoughts=bool(extra_params.get("include_thoughts", True)), + thinking_budget=self.clamp_thinking_budget(extra_params, model_identifier), + ) + + tools = _build_tools(tool_options) if tool_options else [] if enable_google_search: - grounding_tool = Tool(google_search=GoogleSearch()) - if "tools" in generation_config_dict: - existing = generation_config_dict["tools"] - if isinstance(existing, list): - existing.append(grounding_tool) + tools.append(Tool(google_search=GoogleSearch())) + if tools: + if "tools" in config_payload: + existing_tools = config_payload["tools"] + if isinstance(existing_tools, list): + config_payload["tools"] = [*existing_tools, *tools] else: - generation_config_dict["tools"] = [existing, grounding_tool] + config_payload["tools"] = [existing_tools, *tools] else: - generation_config_dict["tools"] = [grounding_tool] + config_payload["tools"] = tools - generation_config = GenerateContentConfig(**generation_config_dict) + if response_format is not None: + if response_format.format_type == RespFormatType.TEXT: + config_payload.setdefault("response_mime_type", "text/plain") + elif response_format.format_type == RespFormatType.JSON_OBJ: + config_payload.setdefault("response_mime_type", "application/json") + elif response_format.format_type == RespFormatType.JSON_SCHEMA: + config_payload.setdefault("response_mime_type", "application/json") + response_json_schema = _extract_response_json_schema(response_format) + if ( + response_json_schema is not None + and "response_json_schema" not in config_payload + and "response_schema" not in config_payload + ): + config_payload["response_json_schema"] = response_json_schema + + return GenerateContentConfig(**config_payload) + + def _build_default_stream_response_handler( + self, + request: ResponseRequest, + ) -> ProviderStreamResponseHandler[AsyncIterator[GenerateContentResponse]]: + """构建 Gemini 默认流式响应处理器。 + + Args: + request: 统一响应请求对象。 + + Returns: + ProviderStreamResponseHandler[AsyncIterator[GenerateContentResponse]]: 默认流式处理器。 + """ + del request + return _default_stream_response_handler + + def _build_default_response_parser( + self, + request: ResponseRequest, + ) -> ProviderResponseParser[GenerateContentResponse]: + """构建 Gemini 默认非流式响应解析器。 + + Args: + request: 统一响应请求对象。 + + Returns: + ProviderResponseParser[GenerateContentResponse]: 默认非流式解析器。 + """ + del request + return _default_normal_response_parser + + async def _execute_response_request( + self, + request: ResponseRequest, + stream_response_handler: ProviderStreamResponseHandler[AsyncIterator[GenerateContentResponse]], + response_parser: ProviderResponseParser[GenerateContentResponse], + ) -> Tuple[APIResponse, UsageTuple | None]: + """执行 Gemini 的文本/多模态响应请求。 + + Args: + request: 统一响应请求对象。 + stream_response_handler: 流式响应处理器。 + response_parser: 非流式响应解析器。 + + Returns: + Tuple[APIResponse, UsageTuple | None]: 统一响应对象与可选使用量信息。 + """ + model_info = request.model_info + contents, system_instruction = _convert_messages(request.message_list) + model_identifier, enable_google_search = self._resolve_model_identifier( + model_info.model_identifier, + request.extra_params, + ) + generation_config = self._build_generation_config( + model_identifier=model_identifier, + system_instruction=system_instruction, + tool_options=request.tool_options, + response_format=request.response_format, + max_tokens=request.max_tokens, + temperature=request.temperature, + extra_params=request.extra_params, + enable_google_search=enable_google_search, + ) try: if model_info.force_stream_mode: - req_task = asyncio.create_task( + stream_task: asyncio.Task[AsyncIterator[GenerateContentResponse]] = asyncio.create_task( self.client.aio.models.generate_content_stream( - model=model_info.model_identifier, - contents=messages[0], + model=model_identifier, + contents=contents, config=generation_config, ) ) - while not req_task.done(): - if interrupt_flag and interrupt_flag.is_set(): - # 如果中断量存在且被设置,则取消任务并抛出异常 - req_task.cancel() - raise ReqAbortException("请求被外部信号中断") - await asyncio.sleep(0.1) # 等待0.1秒后再次检查任务&中断信号量状态 - resp, usage_record = await stream_response_handler(req_task.result(), interrupt_flag) - else: - req_task = asyncio.create_task( - self.client.aio.models.generate_content( - model=model_info.model_identifier, - contents=messages[0], - config=generation_config, - ) + raw_response_stream = cast( + AsyncIterator[GenerateContentResponse], + await await_task_with_interrupt(stream_task, request.interrupt_flag), ) - while not req_task.done(): - if interrupt_flag and interrupt_flag.is_set(): - # 如果中断量存在且被设置,则取消任务并抛出异常 - req_task.cancel() - raise ReqAbortException("请求被外部信号中断") - await asyncio.sleep(0.5) # 等待0.5秒后再次检查任务&中断信号量状态 + return await stream_response_handler(raw_response_stream, request.interrupt_flag) - resp, usage_record = async_response_parser(req_task.result()) - except (ClientError, ServerError) as e: - # 重封装 ClientError 和 ServerError 为 RespNotOkException - raise RespNotOkException(e.code, e.message) from None - except ( - UnknownFunctionCallArgumentError, - UnsupportedFunctionError, - FunctionInvocationError, - ) as e: - # 工具调用相关错误 - raise RespParseException(None, f"工具调用参数错误: {str(e)}") from None - except EmptyResponseException as e: - # 保持原始异常,便于区分“空响应”和网络异常 - raise e - except Exception as e: - # 其他未预料的错误,才归为网络连接类 - raise NetworkConnectionError() from e - - if usage_record: - resp.usage = UsageRecord( - model_name=model_info.name, - provider_name=model_info.api_provider, - prompt_tokens=usage_record[0], - completion_tokens=usage_record[1], - total_tokens=usage_record[2], + completion_task: asyncio.Task[GenerateContentResponse] = asyncio.create_task( + self.client.aio.models.generate_content( + model=model_identifier, + contents=contents, + config=generation_config, + ) ) + raw_response = cast( + GenerateContentResponse, + await await_task_with_interrupt(completion_task, request.interrupt_flag), + ) + return response_parser(raw_response) + except ReqAbortException: + raise + except (ClientError, ServerError) as exc: + status_code = int(getattr(exc, "code", 500) or 500) + raise RespNotOkException(status_code, str(exc)) from exc + except (UnknownFunctionCallArgumentError, UnsupportedFunctionError, FunctionInvocationError) as exc: + raise RespParseException(None, f"Gemini 工具调用参数错误: {exc}") from exc + except EmptyResponseException: + raise + except Exception as exc: + raise NetworkConnectionError(str(exc)) from exc - return resp - - async def get_embedding( + async def _execute_embedding_request( self, - model_info: ModelInfo, - embedding_input: str, - extra_params: dict[str, Any] | None = None, - ) -> APIResponse: - """ - 获取文本嵌入 - :param model_info: 模型信息 - :param embedding_input: 嵌入输入文本 - :return: 嵌入响应 + request: EmbeddingRequest, + ) -> Tuple[APIResponse, UsageTuple | None]: + """执行 Gemini 文本嵌入请求。 + + Args: + request: 统一嵌入请求对象。 + + Returns: + Tuple[APIResponse, UsageTuple | None]: 统一响应对象与可选使用量信息。 """ + model_info = request.model_info + embedding_input = request.embedding_input + extra_params = request.extra_params + embed_config = _build_embed_content_config(extra_params) + try: raw_response: EmbedContentResponse = await self.client.aio.models.embed_content( model=model_info.model_identifier, contents=embedding_input, - config=EmbedContentConfig(task_type="SEMANTIC_SIMILARITY"), + config=embed_config, ) - except (ClientError, ServerError) as e: - # 重封装ClientError和ServerError为RespNotOkException - raise RespNotOkException(e.code) from None - except Exception as e: - raise NetworkConnectionError() from e + except (ClientError, ServerError) as exc: + status_code = int(getattr(exc, "code", 500) or 500) + raise RespNotOkException(status_code, str(exc)) from exc + except Exception as exc: + raise NetworkConnectionError(str(exc)) from exc - response = APIResponse() - - # 解析嵌入响应和使用情况 - if hasattr(raw_response, "embeddings") and raw_response.embeddings: + response = APIResponse(raw_data=raw_response) + if raw_response.embeddings: response.embedding = raw_response.embeddings[0].values else: - raise RespParseException(raw_response, "响应解析失败,缺失embeddings字段") + raise RespParseException(raw_response, "响应解析失败,缺失 embeddings 字段") - response.usage = UsageRecord( - model_name=model_info.name, - provider_name=model_info.api_provider, - prompt_tokens=len(embedding_input), - completion_tokens=0, - total_tokens=len(embedding_input), + billable_character_count = 0 + if raw_response.metadata is not None: + billable_character_count = getattr(raw_response.metadata, "billable_character_count", 0) or 0 + usage_record: UsageTuple = ( + billable_character_count or len(embedding_input), + 0, + billable_character_count or len(embedding_input), ) + return response, usage_record - return response - - async def get_audio_transcriptions( + async def _execute_audio_transcription_request( self, - model_info: ModelInfo, - audio_base64: str, - max_tokens: Optional[int] = 2048, - extra_params: dict[str, Any] | None = None, - ) -> APIResponse: - """ - 获取音频转录 - :param model_info: 模型信息 - :param audio_base64: 音频文件的Base64编码字符串 - :param max_tokens: 最大输出token数(默认2048) - :param extra_params: 额外参数(可选) - :return: 转录响应 - """ - # 解析并裁剪 thinking_budget - tb = self.clamp_thinking_budget(extra_params, model_info.model_identifier) + request: AudioTranscriptionRequest, + ) -> Tuple[APIResponse, UsageTuple | None]: + """执行 Gemini 音频转录请求。 - # 构造 prompt + 音频输入 - prompt = "Generate a transcript of the speech. The language of the transcript should **match the language of the speech**." - contents = [ + Args: + request: 统一音频转录请求对象。 + + Returns: + Tuple[APIResponse, UsageTuple | None]: 统一响应对象与可选使用量信息。 + """ + model_info = request.model_info + audio_base64 = request.audio_base64 + max_tokens = request.max_tokens + extra_params = request.extra_params + model_identifier, _ = self._resolve_model_identifier(model_info.model_identifier, extra_params) + + transcription_prompt = str( + extra_params.get( + "transcription_prompt", + "Generate a transcript of the speech. The language of the transcript should match the speech.", + ) + ) + audio_mime_type = str(extra_params.get("audio_mime_type", "audio/wav")) + contents: List[ContentUnion] = [ Content( role="user", parts=[ - Part.from_text(text=prompt), - Part.from_bytes(data=base64.b64decode(audio_base64), mime_type="audio/wav"), + Part.from_text(text=transcription_prompt), + Part.from_bytes(data=base64.b64decode(audio_base64), mime_type=audio_mime_type), ], ) ] - - generation_config_dict = { - "max_output_tokens": max_tokens, - "response_modalities": ["TEXT"], - "thinking_config": ThinkingConfig( - include_thoughts=True, - thinking_budget=tb, - ), - "safety_settings": gemini_safe_settings, - } - generate_content_config = GenerateContentConfig(**generation_config_dict) + generation_config = self._build_generation_config( + model_identifier=model_identifier, + system_instruction=None, + tool_options=None, + response_format=None, + max_tokens=max_tokens, + temperature=None, + extra_params=extra_params, + enable_google_search=False, + ) try: raw_response: GenerateContentResponse = await self.client.aio.models.generate_content( - model=model_info.model_identifier, + model=model_identifier, contents=contents, - config=generate_content_config, + config=generation_config, ) - resp, usage_record = _default_normal_response_parser(raw_response) - except (ClientError, ServerError) as e: - # 重封装ClientError和ServerError为RespNotOkException - raise RespNotOkException(e.code) from None - except Exception as e: - raise NetworkConnectionError() from e + response, usage_record = _default_normal_response_parser(raw_response) + except (ClientError, ServerError) as exc: + status_code = int(getattr(exc, "code", 500) or 500) + raise RespNotOkException(status_code, str(exc)) from exc + except Exception as exc: + raise NetworkConnectionError(str(exc)) from exc - if usage_record: - resp.usage = UsageRecord( - model_name=model_info.name, - provider_name=model_info.api_provider, - prompt_tokens=usage_record[0], - completion_tokens=usage_record[1], - total_tokens=usage_record[2], - ) + return response, usage_record - return resp + def get_support_image_formats(self) -> List[str]: + """获取 Gemini 当前支持的图片格式列表。 - def get_support_image_formats(self) -> list[str]: - """ - 获取支持的图片格式 - :return: 支持的图片格式列表 + Returns: + List[str]: 当前客户端支持的图片格式列表。 """ return ["png", "jpg", "jpeg", "webp", "heic", "heif"] diff --git a/src/llm_models/model_client/openai_client.py b/src/llm_models/model_client/openai_client.py index 99efe8d9..44e085eb 100644 --- a/src/llm_models/model_client/openai_client.py +++ b/src/llm_models/model_client/openai_client.py @@ -1,742 +1,1000 @@ +from collections.abc import Iterable +from dataclasses import dataclass, field +from typing import Any, Callable, Coroutine, Dict, List, Tuple, cast + import asyncio +import base64 import io import json import re -import base64 -from collections.abc import Iterable -from typing import Callable, Any, Coroutine, Optional -from json_repair import repair_json -from openai import ( - AsyncOpenAI, - APIConnectionError, - APIStatusError, - NOT_GIVEN, - AsyncStream, -) +from json_repair import repair_json +from openai import APIConnectionError, APIStatusError, AsyncOpenAI, AsyncStream +from openai._types import FileTypes, Omit, omit from openai.types.chat import ( ChatCompletion, + ChatCompletionAssistantMessageParam, ChatCompletionChunk, + ChatCompletionContentPartImageParam, + ChatCompletionContentPartParam, + ChatCompletionContentPartTextParam, + ChatCompletionMessageFunctionToolCallParam, ChatCompletionMessageParam, + ChatCompletionSystemMessageParam, + ChatCompletionToolMessageParam, ChatCompletionToolParam, + ChatCompletionUserMessageParam, ) +from openai.types.shared_params.function_definition import FunctionDefinition from openai.types.chat.chat_completion_chunk import ChoiceDelta -from src.config.model_configs import ModelInfo, APIProvider from src.common.logger import get_logger -from .base_client import APIResponse, UsageRecord, BaseClient, client_registry -from ..exceptions import ( - RespParseException, - NetworkConnectionError, - RespNotOkException, - ReqAbortException, +from src.config.model_configs import APIProvider, ReasoningParseMode, ToolArgumentParseMode +from src.llm_models.exceptions import ( EmptyResponseException, + NetworkConnectionError, + ReqAbortException, + RespNotOkException, + RespParseException, +) +from src.llm_models.openai_compat import ( + build_openai_compatible_client_config, + split_openai_request_overrides, +) +from src.llm_models.payload_content.message import ImageMessagePart, Message, RoleType, TextMessagePart +from src.llm_models.payload_content.resp_format import RespFormat, RespFormatType +from src.llm_models.payload_content.tool_option import ToolCall, ToolOption + +from .adapter_base import ( + AdapterClient, + ProviderResponseParser, + ProviderStreamResponseHandler, + await_task_with_interrupt, +) +from .base_client import ( + APIResponse, + AudioTranscriptionRequest, + EmbeddingRequest, + ResponseRequest, + UsageTuple, + client_registry, ) -from ..payload_content.message import Message, RoleType -from ..payload_content.resp_format import RespFormat, RespFormatType -from ..payload_content.tool_option import ToolOption, ToolParam, ToolCall logger = get_logger("llm_models") +THINK_CONTENT_PATTERN = re.compile( + r"(?P.*?)(?P.*)|(?P.*)|(?P.+)", + re.DOTALL, +) +"""用于解析 `` 推理块的正则表达式。""" + +CHAT_COMPLETIONS_RESERVED_EXTRA_BODY_KEYS = { + "max_tokens", + "messages", + "model", + "response_format", + "stream", + "temperature", + "tools", +} +"""由当前客户端显式承载、不应再落入 `extra_body` 的字段集合。""" + +OpenAIStreamResponseHandler = Callable[ + [AsyncStream[ChatCompletionChunk], asyncio.Event | None], + Coroutine[Any, Any, Tuple[APIResponse, UsageTuple | None]], +] +"""OpenAI 流式响应处理函数类型。""" + +OpenAIResponseParser = Callable[[ChatCompletion], Tuple[APIResponse, UsageTuple | None]] +"""OpenAI 非流式响应解析函数类型。""" + + +def _normalize_reasoning_parse_mode(parse_mode: str | ReasoningParseMode) -> ReasoningParseMode: + """将配置中的推理解析模式收敛为枚举值。 + + Args: + parse_mode: 原始解析模式配置。 + + Returns: + ReasoningParseMode: 规范化后的解析模式;未知值会回退为 `AUTO`。 + """ + if isinstance(parse_mode, ReasoningParseMode): + return parse_mode + try: + return ReasoningParseMode(parse_mode) + except ValueError: + logger.warning(f"未识别的推理解析模式 {parse_mode},已回退为 auto") + return ReasoningParseMode.AUTO + + +def _normalize_tool_argument_parse_mode(parse_mode: str | ToolArgumentParseMode) -> ToolArgumentParseMode: + """将配置中的工具参数解析模式收敛为枚举值。 + + Args: + parse_mode: 原始解析模式配置。 + + Returns: + ToolArgumentParseMode: 规范化后的解析模式;未知值会回退为 `AUTO`。 + """ + if isinstance(parse_mode, ToolArgumentParseMode): + return parse_mode + try: + return ToolArgumentParseMode(parse_mode) + except ValueError: + logger.warning(f"未识别的工具参数解析模式 {parse_mode},已回退为 auto") + return ToolArgumentParseMode.AUTO + + +def _build_text_content_part(text: str) -> ChatCompletionContentPartTextParam: + """构建文本内容片段。 + + Args: + text: 文本内容。 + + Returns: + ChatCompletionContentPartTextParam: OpenAI 兼容的文本片段。 + """ + return { + "type": "text", + "text": text, + } + + +def _build_image_content_part(part: ImageMessagePart) -> ChatCompletionContentPartImageParam: + """构建图片内容片段。 + + Args: + part: 内部图片片段。 + + Returns: + ChatCompletionContentPartImageParam: OpenAI 兼容的图片片段。 + """ + return { + "type": "image_url", + "image_url": { + "url": f"data:image/{part.normalized_image_format};base64,{part.image_base64}", + }, + } + def _convert_response_format(response_format: RespFormat | None) -> Any: - """ - 转换响应格式 - 将内部RespFormat转换为OpenAI API所需格式 - """ - if response_format is None: - return NOT_GIVEN + """将内部响应格式转换为 OpenAI 兼容结构。 - if response_format.format_type == RespFormatType.TEXT: - return NOT_GIVEN + Args: + response_format: 内部响应格式定义。 + Returns: + Any: OpenAI SDK 可接受的响应格式参数;未指定时返回 `omit`。 + """ + if response_format is None or response_format.format_type == RespFormatType.TEXT: + return omit if response_format.format_type == RespFormatType.JSON_OBJ: return {"type": "json_object"} - if response_format.format_type == RespFormatType.JSON_SCHEMA: return { "type": "json_schema", "json_schema": response_format.schema, } - - return NOT_GIVEN + return omit -def _convert_messages(messages: list[Message]) -> list[ChatCompletionMessageParam]: +def _convert_text_only_message_content( + message: Message, +) -> str | List[ChatCompletionContentPartTextParam]: + """将仅允许文本的消息转换为 OpenAI 兼容内容。 + + Args: + message: 内部统一消息对象。 + + Returns: + str | List[ChatCompletionContentPartTextParam]: 文本内容结构。 + + Raises: + ValueError: 当消息中包含非文本片段时抛出。 """ - 转换消息格式 - 将消息转换为OpenAI API所需的格式 - :param messages: 消息列表 - :return: 转换后的消息列表 + if not message.parts: + return "" + if len(message.parts) == 1 and isinstance(message.parts[0], TextMessagePart): + return message.parts[0].text + + content: List[ChatCompletionContentPartTextParam] = [] + for part in message.parts: + if not isinstance(part, TextMessagePart): + raise ValueError(f"{message.role.value} 消息仅支持文本片段") + content.append(_build_text_content_part(part.text)) + return content + + +def _convert_user_message_content(message: Message) -> str | List[ChatCompletionContentPartParam]: + """将用户消息转换为 OpenAI 兼容内容。 + + Args: + message: 内部统一消息对象。 + + Returns: + str | List[ChatCompletionContentPartParam]: 用户消息内容结构。 """ + if len(message.parts) == 1 and isinstance(message.parts[0], TextMessagePart): + return message.parts[0].text - def _convert_message_item(message: Message) -> ChatCompletionMessageParam: - """ - 转换单个消息格式 - :param message: 消息对象 - :return: 转换后的消息字典 - """ + content: List[ChatCompletionContentPartParam] = [] + for part in message.parts: + if isinstance(part, TextMessagePart): + content.append(_build_text_content_part(part.text)) + continue + content.append(_build_image_content_part(part)) + return content - # 添加Content - content: str | list[dict[str, Any]] - if isinstance(message.content, str): - content = message.content - elif isinstance(message.content, list): - content = [] - for item in message.content: - if isinstance(item, tuple): - image_format = item[0].lower() - # 规范 JPEG MIME 类型后缀,统一使用 image/jpeg - if image_format in ("jpg", "jpeg"): - mime_suffix = "jpeg" - else: - mime_suffix = image_format - content.append( - { - "type": "image_url", - "image_url": {"url": f"data:image/{mime_suffix};base64,{item[1]}"}, - } - ) - elif isinstance(item, str): - content.append({"type": "text", "text": item}) - else: - raise RuntimeError("无法触及的代码:请使用MessageBuilder类构建消息对象") - ret = { - "role": message.role.value, - "content": content, - } +def _convert_assistant_tool_calls(tool_calls: List[ToolCall]) -> List[ChatCompletionMessageFunctionToolCallParam]: + """将内部工具调用转换为 OpenAI assistant tool_calls 结构。 - if message.role == RoleType.Assistant and getattr(message, "tool_calls", None): - tool_calls_payload: list[dict[str, Any]] = [] - for call in message.tool_calls or []: - tool_calls_payload.append( - { - "id": call.call_id, - "type": "function", - "function": { - "name": call.func_name, - "arguments": json.dumps(call.args or {}, ensure_ascii=False), - }, - } - ) - ret["tool_calls"] = tool_calls_payload - if ret["content"] == []: - ret["content"] = "" + Args: + tool_calls: 内部工具调用列表。 + + Returns: + List[ChatCompletionMessageFunctionToolCallParam]: OpenAI 兼容工具调用结构。 + """ + converted_tool_calls: List[ChatCompletionMessageFunctionToolCallParam] = [] + for tool_call in tool_calls: + converted_tool_calls.append( + { + "id": tool_call.call_id, + "type": "function", + "function": { + "name": tool_call.func_name, + "arguments": json.dumps(tool_call.args or {}, ensure_ascii=False), + }, + } + ) + return converted_tool_calls + + +def _convert_messages(messages: List[Message]) -> List[ChatCompletionMessageParam]: + """将内部消息列表转换为 OpenAI 兼容消息列表。 + + Args: + messages: 内部统一消息列表。 + + Returns: + List[ChatCompletionMessageParam]: OpenAI SDK 所需的消息结构列表。 + """ + converted_messages: List[ChatCompletionMessageParam] = [] + for message in messages: + if message.role == RoleType.System: + system_payload: ChatCompletionSystemMessageParam = { + "role": "system", + "content": _convert_text_only_message_content(message), + } + converted_messages.append(system_payload) + continue + + if message.role == RoleType.User: + user_payload: ChatCompletionUserMessageParam = { + "role": "user", + "content": _convert_user_message_content(message), + } + converted_messages.append(user_payload) + continue + + if message.role == RoleType.Assistant: + assistant_payload: ChatCompletionAssistantMessageParam = { + "role": "assistant", + "content": None if not message.parts and message.tool_calls else _convert_text_only_message_content(message), + } + if message.tool_calls: + assistant_payload["tool_calls"] = _convert_assistant_tool_calls(message.tool_calls) + converted_messages.append(assistant_payload) + continue - # 添加工具调用ID if message.role == RoleType.Tool: - if not message.tool_call_id: - raise ValueError("无法触及的代码:请使用MessageBuilder类构建消息对象") - ret["tool_call_id"] = message.tool_call_id + if message.tool_call_id is None: + raise ValueError("Tool 消息缺少 tool_call_id") + tool_payload: ChatCompletionToolMessageParam = { + "role": "tool", + "content": _convert_text_only_message_content(message), + "tool_call_id": message.tool_call_id, + } + converted_messages.append(tool_payload) + continue - return ret # type: ignore + raise ValueError(f"不支持的消息角色:{message.role}") - return [_convert_message_item(message) for message in messages] + return converted_messages -def _convert_tool_options(tool_options: list[ToolOption]) -> list[dict[str, Any]]: +def _convert_tool_options(tool_options: List[ToolOption]) -> List[ChatCompletionToolParam]: + """将工具定义转换为 OpenAI 兼容的工具列表。 + + Args: + tool_options: 内部统一工具定义列表。 + + Returns: + List[ChatCompletionToolParam]: OpenAI SDK 所需的工具定义列表。 """ - 转换工具选项格式 - 将工具选项转换为OpenAI API所需的格式 - :param tool_options: 工具选项列表 - :return: 转换后的工具选项列表 - """ - - def _convert_tool_param(tool_option_param: ToolParam) -> dict[str, Any]: - """ - 转换单个工具参数格式 - :param tool_option_param: 工具参数对象 - :return: 转换后的工具参数字典 - """ - # JSON Schema 类型名称修正: - # - 布尔类型使用 "boolean" 而不是 "bool" - # - 浮点数使用 "number" 而不是 "float" - param_type_value = tool_option_param.param_type.value - if param_type_value == "bool": - param_type_value = "boolean" - elif param_type_value == "float": - param_type_value = "number" - - return_dict: dict[str, Any] = { - "type": param_type_value, - "description": tool_option_param.description, - } - if tool_option_param.enum_values: - return_dict["enum"] = tool_option_param.enum_values - return return_dict - - def _convert_tool_option_item(tool_option: ToolOption) -> dict[str, Any]: - """ - 转换单个工具项格式 - :param tool_option: 工具选项对象 - :return: 转换后的工具选项字典 - """ - ret: dict[str, Any] = { + converted_tools: List[ChatCompletionToolParam] = [] + for tool_option in tool_options: + function_schema: FunctionDefinition = { "name": tool_option.name, "description": tool_option.description, } - if tool_option.params: - ret["parameters"] = { - "type": "object", - "properties": {param.name: _convert_tool_param(param) for param in tool_option.params}, - "required": [param.name for param in tool_option.params if param.required], + parameters_schema = tool_option.parameters_schema + if parameters_schema is not None: + function_schema["parameters"] = cast(Dict[str, object], parameters_schema) + converted_tools.append( + { + "type": "function", + "function": function_schema, } - return ret - - return [ - { - "type": "function", - "function": _convert_tool_option_item(tool_option), - } - for tool_option in tool_options - ] + ) + return converted_tools -def _process_delta( - delta: ChoiceDelta, - has_rc_attr_flag: bool, - in_rc_flag: bool, - rc_delta_buffer: io.StringIO, - fc_delta_buffer: io.StringIO, - tool_calls_buffer: list[tuple[str, str, io.StringIO]], -) -> bool: - # 接收content - if has_rc_attr_flag: - # 有独立的推理内容块,则无需考虑content内容的判读 - if hasattr(delta, "reasoning_content") and delta.reasoning_content: # type: ignore - # 如果有推理内容,则将其写入推理内容缓冲区 - assert isinstance(delta.reasoning_content, str) # type: ignore - rc_delta_buffer.write(delta.reasoning_content) # type: ignore - elif delta.content: - # 如果有正式内容,则将其写入正式内容缓冲区 - fc_delta_buffer.write(delta.content) - elif hasattr(delta, "content") and delta.content is not None: - # 没有独立的推理内容块,但有正式内容 - if in_rc_flag: - # 当前在推理内容块中 - if delta.content == "": - # 如果当前内容是,则将其视为推理内容的结束标记,退出推理内容块 - in_rc_flag = False - else: - # 其他情况视为推理内容,加入推理内容缓冲区 - rc_delta_buffer.write(delta.content) - elif delta.content == "" and not fc_delta_buffer.getvalue(): - # 如果当前内容是,且正式内容缓冲区为空,说明为输出的首个token - # 则将其视为推理内容的开始标记,进入推理内容块 - in_rc_flag = True +def _extract_usage_record(usage: Any) -> UsageTuple | None: + """从响应对象中提取 usage 三元组。 + + Args: + usage: OpenAI SDK 返回的 usage 对象。 + + Returns: + UsageTuple | None: `(prompt_tokens, completion_tokens, total_tokens)`。 + """ + if usage is None: + return None + return ( + getattr(usage, "prompt_tokens", 0) or 0, + getattr(usage, "completion_tokens", 0) or 0, + getattr(usage, "total_tokens", 0) or 0, + ) + + +def _parse_tool_arguments( + raw_arguments: str, + parse_mode: ToolArgumentParseMode, + response: Any, +) -> Dict[str, Any]: + """解析工具调用参数字符串。 + + Args: + raw_arguments: 工具调用参数原始字符串。 + parse_mode: 参数解析模式。 + response: 原始响应对象,用于异常上下文。 + + Returns: + Dict[str, Any]: 解析后的参数字典。 + + Raises: + RespParseException: 当参数无法解析为字典时抛出。 + """ + try: + if parse_mode == ToolArgumentParseMode.STRICT: + arguments: Any = json.loads(raw_arguments) + elif parse_mode == ToolArgumentParseMode.REPAIR: + arguments = repair_json(raw_arguments, return_objects=True, logging=False) else: - # 其他情况视为正式内容,加入正式内容缓冲区 - fc_delta_buffer.write(delta.content) - # 接收tool_calls - if hasattr(delta, "tool_calls") and delta.tool_calls: - tool_call_delta = delta.tool_calls[0] + arguments = repair_json(raw_arguments, return_objects=True, logging=False) + if isinstance(arguments, str) and parse_mode in { + ToolArgumentParseMode.AUTO, + ToolArgumentParseMode.DOUBLE_DECODE, + }: + arguments = repair_json(arguments, return_objects=True, logging=False) + except json.JSONDecodeError as exc: + raise RespParseException(response, f"响应解析失败,无法解析工具调用参数。原始参数:{raw_arguments}") from exc - if tool_call_delta.index >= len(tool_calls_buffer): - # 调用索引号大于等于缓冲区长度,说明是新的工具调用 - if tool_call_delta.id and tool_call_delta.function and tool_call_delta.function.name: - tool_calls_buffer.append( - ( - tool_call_delta.id, - tool_call_delta.function.name, - io.StringIO(), - ) + if not isinstance(arguments, dict): + raise RespParseException( + response, + f"响应解析失败,工具调用参数必须解析为字典,实际类型为 {type(arguments).__name__}。", + ) + return arguments + + +def _extract_reasoning_and_content( + content: str, + parse_mode: ReasoningParseMode, +) -> Tuple[str | None, str | None]: + """从文本内容中提取推理内容与正式输出。 + + Args: + content: 模型返回的文本内容。 + parse_mode: 推理解析模式。 + + Returns: + Tuple[str | None, str | None]: `(reasoning_content, content)`。 + """ + if parse_mode in {ReasoningParseMode.NATIVE, ReasoningParseMode.NONE}: + return None, content + + match = THINK_CONTENT_PATTERN.match(content) + if not match: + return None, content + if match.group("think") is not None: + reasoning_content = match.group("think").strip() or None + final_content = match.group("content").strip() or None + return reasoning_content, final_content + if match.group("think_unclosed") is not None: + return match.group("think_unclosed").strip() or None, None + return None, match.group("content_only").strip() or None + + +def _log_length_truncation(finish_reason: str | None, model_name: str | None) -> None: + """记录因长度截断导致的告警日志。 + + Args: + finish_reason: OpenAI 兼容接口返回的完成原因。 + model_name: 上游返回的模型标识。 + """ + if finish_reason == "length": + logger.info(f"模型{model_name or ''}因为超过最大 max_token 限制,可能仅输出部分内容,可视情况调整") + + +def _coerce_openai_argument(value: Any) -> Any | Omit: + """将可选参数转换为 OpenAI SDK 期望的值。 + + Args: + value: 原始参数值。 + + Returns: + Any | Omit: `None` 会被转换为 `omit`,其余值原样返回。 + """ + if value is None: + return omit + return value + + +def _build_api_status_message(error: APIStatusError) -> str: + """构建更适合记录和展示的状态错误信息。 + + Args: + error: OpenAI SDK 抛出的状态错误。 + + Returns: + str: 拼装后的错误信息。 + """ + message_parts: List[str] = [] + if getattr(error, "message", None): + message_parts.append(str(error.message)) + response_text = getattr(getattr(error, "response", None), "text", None) + if response_text: + message_parts.append(str(response_text)[:300]) + if message_parts: + return " | ".join(message_parts) + return f"上游接口返回状态码 {error.status_code}" + + +@dataclass(slots=True) +class _StreamedToolCallState: + """流式工具调用累积状态。""" + + index: int + call_id: str = "" + function_name: str = "" + arguments_buffer: io.StringIO = field(default_factory=io.StringIO) + + def append_arguments(self, arguments_chunk: str) -> None: + """追加一段工具调用参数字符串。 + + Args: + arguments_chunk: 参数增量片段。 + """ + self.arguments_buffer.write(arguments_chunk) + + def close(self) -> None: + """关闭内部缓存。""" + if not self.arguments_buffer.closed: + self.arguments_buffer.close() + + +class _OpenAIStreamAccumulator: + """OpenAI 兼容流式响应累积器。""" + + def __init__( + self, + reasoning_parse_mode: ReasoningParseMode, + tool_argument_parse_mode: ToolArgumentParseMode, + ) -> None: + """初始化累积器。 + + Args: + reasoning_parse_mode: 推理内容解析模式。 + tool_argument_parse_mode: 工具参数解析模式。 + """ + self.reasoning_parse_mode = reasoning_parse_mode + self.tool_argument_parse_mode = tool_argument_parse_mode + self.reasoning_buffer = io.StringIO() + self.content_buffer = io.StringIO() + self.tool_call_states: Dict[int, _StreamedToolCallState] = {} + self.finish_reason: str | None = None + self.model_name: str | None = None + self._using_native_reasoning = False + + def capture_event_metadata(self, event: ChatCompletionChunk) -> None: + """捕获事件中的完成原因和模型名。 + + Args: + event: 当前流式事件。 + """ + if getattr(event, "model", None) and not self.model_name: + self.model_name = event.model + if getattr(event, "choices", None): + finish_reason = getattr(event.choices[0], "finish_reason", None) + if finish_reason: + self.finish_reason = finish_reason + + def process_delta(self, delta: ChoiceDelta) -> None: + """处理一个增量块。 + + Args: + delta: 当前增量对象。 + """ + self._process_reasoning_delta(delta) + self._process_tool_call_delta(delta) + + def _process_reasoning_delta(self, delta: ChoiceDelta) -> None: + """处理推理内容与正式内容。 + + Args: + delta: 当前增量对象。 + """ + native_reasoning = getattr(delta, "reasoning_content", None) + if isinstance(native_reasoning, str) and native_reasoning: + self._using_native_reasoning = True + if self.reasoning_parse_mode != ReasoningParseMode.NONE: + self.reasoning_buffer.write(native_reasoning) + return + + content_chunk = getattr(delta, "content", None) + if not isinstance(content_chunk, str) or content_chunk == "": + return + + if self.reasoning_parse_mode == ReasoningParseMode.NONE: + self.content_buffer.write(content_chunk) + return + + if self.reasoning_parse_mode == ReasoningParseMode.NATIVE: + self.content_buffer.write(content_chunk) + return + + self.content_buffer.write(content_chunk) + + def _process_tool_call_delta(self, delta: ChoiceDelta) -> None: + """处理工具调用增量。 + + Args: + delta: 当前增量对象。 + """ + tool_call_deltas = getattr(delta, "tool_calls", None) or [] + for tool_call_delta in tool_call_deltas: + state = self.tool_call_states.setdefault(tool_call_delta.index, _StreamedToolCallState(index=tool_call_delta.index)) + if tool_call_delta.id: + state.call_id = tool_call_delta.id + function = tool_call_delta.function + if function is not None and function.name: + state.function_name = function.name + if function is not None and function.arguments: + state.append_arguments(function.arguments) + + def build_response(self) -> APIResponse: + """构建最终 APIResponse 对象。 + + Returns: + APIResponse: 累积完成的响应对象。 + + Raises: + EmptyResponseException: 当响应中既无可见内容也无工具调用时抛出。 + RespParseException: 当工具调用结构不完整时抛出。 + """ + response = APIResponse() + + content = self.content_buffer.getvalue().strip() + reasoning_content = self.reasoning_buffer.getvalue().strip() + if not self._using_native_reasoning and self.reasoning_parse_mode != ReasoningParseMode.NONE and content: + parsed_reasoning_content, parsed_content = _extract_reasoning_and_content( + content=content, + parse_mode=self.reasoning_parse_mode, + ) + if parsed_reasoning_content: + reasoning_content = parsed_reasoning_content + content = parsed_content or "" + if reasoning_content: + response.reasoning_content = reasoning_content + if content: + response.content = content + + if self.tool_call_states: + response.tool_calls = [] + for index in sorted(self.tool_call_states): + state = self.tool_call_states[index] + if not state.function_name: + raise RespParseException(None, f"响应解析失败,工具调用 {index} 缺少函数名。") + raw_arguments = state.arguments_buffer.getvalue().strip() + arguments = ( + _parse_tool_arguments(raw_arguments, self.tool_argument_parse_mode, None) + if raw_arguments + else None ) - else: - logger.warning("工具调用索引号大于等于缓冲区长度,但缺少ID或函数信息。") + call_id = state.call_id or f"tool_call_{index}" + response.tool_calls.append(ToolCall(call_id=call_id, func_name=state.function_name, args=arguments)) - if tool_call_delta.function and tool_call_delta.function.arguments: - # 如果有工具调用参数,则添加到对应的工具调用的参数串缓冲区中 - tool_calls_buffer[tool_call_delta.index][2].write(tool_call_delta.function.arguments) + response.raw_data = {"model": self.model_name} if self.model_name else None - return in_rc_flag + if not response.content and not response.tool_calls: + raise EmptyResponseException() + return response -def _build_stream_api_resp( - _fc_delta_buffer: io.StringIO, - _rc_delta_buffer: io.StringIO, - _tool_calls_buffer: list[tuple[str, str, io.StringIO]], - finish_reason: str | None = None, -) -> APIResponse: - resp = APIResponse() - - if _rc_delta_buffer.tell() > 0: - # 如果推理内容缓冲区不为空,则将其写入APIResponse对象 - resp.reasoning_content = _rc_delta_buffer.getvalue() - _rc_delta_buffer.close() - if _fc_delta_buffer.tell() > 0: - # 如果正式内容缓冲区不为空,则将其写入APIResponse对象 - resp.content = _fc_delta_buffer.getvalue() - _fc_delta_buffer.close() - if _tool_calls_buffer: - # 如果工具调用缓冲区不为空,则将其解析为ToolCall对象列表 - resp.tool_calls = [] - for call_id, function_name, arguments_buffer in _tool_calls_buffer: - if arguments_buffer.tell() > 0: - # 如果参数串缓冲区不为空,则解析为JSON对象 - raw_arg_data = arguments_buffer.getvalue() - arguments_buffer.close() - try: - arguments = json.loads(repair_json(raw_arg_data)) - if not isinstance(arguments, dict): - raise RespParseException( - None, - f"响应解析失败,工具调用参数无法解析为字典类型。工具调用参数原始响应:\n{raw_arg_data}", - ) - except json.JSONDecodeError as e: - raise RespParseException( - None, - f"响应解析失败,无法解析工具调用参数。工具调用参数原始响应:{raw_arg_data}", - ) from e - else: - arguments_buffer.close() - arguments = None - - resp.tool_calls.append(ToolCall(call_id, function_name, arguments)) - - # 检查 max_tokens 截断(流式的告警改由处理函数统一输出,这里不再输出) - # 保留 finish_reason 仅用于上层判断 - - if not resp.content and not resp.tool_calls: - raise EmptyResponseException() - - return resp + def close(self) -> None: + """关闭内部缓冲区。""" + if not self.reasoning_buffer.closed: + self.reasoning_buffer.close() + if not self.content_buffer.closed: + self.content_buffer.close() + for state in self.tool_call_states.values(): + state.close() async def _default_stream_response_handler( resp_stream: AsyncStream[ChatCompletionChunk], interrupt_flag: asyncio.Event | None, -) -> tuple[APIResponse, Optional[tuple[int, int, int]]]: + *, + reasoning_parse_mode: ReasoningParseMode, + tool_argument_parse_mode: ToolArgumentParseMode, +) -> Tuple[APIResponse, UsageTuple | None]: + """处理 OpenAI 兼容流式响应。 + + Args: + resp_stream: OpenAI SDK 返回的流式响应对象。 + interrupt_flag: 外部中断标记。 + reasoning_parse_mode: 推理内容解析模式。 + tool_argument_parse_mode: 工具参数解析模式。 + + Returns: + Tuple[APIResponse, UsageTuple | None]: 解析后的响应与 usage 统计。 """ - 流式响应处理函数 - 处理OpenAI API的流式响应 - :param resp_stream: 流式响应对象 - :return: APIResponse对象 - """ - - _has_rc_attr_flag = False # 标记是否有独立的推理内容块 - _in_rc_flag = False # 标记是否在推理内容块中 - _rc_delta_buffer = io.StringIO() # 推理内容缓冲区,用于存储接收到的推理内容 - _fc_delta_buffer = io.StringIO() # 正式内容缓冲区,用于存储接收到的正式内容 - _tool_calls_buffer: list[tuple[str, str, io.StringIO]] = [] # 工具调用缓冲区,用于存储接收到的工具调用 - _usage_record = None # 使用情况记录 - finish_reason: str | None = None # 记录最后的 finish_reason - _model_name: str | None = None # 记录模型名 - - def _insure_buffer_closed(): - # 确保缓冲区被关闭 - if _rc_delta_buffer and not _rc_delta_buffer.closed: - _rc_delta_buffer.close() - if _fc_delta_buffer and not _fc_delta_buffer.closed: - _fc_delta_buffer.close() - for _, _, buffer in _tool_calls_buffer: - if buffer and not buffer.closed: - buffer.close() - - async for event in resp_stream: - if interrupt_flag and interrupt_flag.is_set(): - # 如果中断量被设置,则抛出ReqAbortException - _insure_buffer_closed() - raise ReqAbortException("请求被外部信号中断") - # 空 choices / usage-only 帧的防御 - if not hasattr(event, "choices") or not event.choices: - if hasattr(event, "usage") and event.usage: - _usage_record = ( - event.usage.prompt_tokens or 0, - event.usage.completion_tokens or 0, - event.usage.total_tokens or 0, - ) - continue # 跳过本帧,避免访问 choices[0] - delta = event.choices[0].delta # 获取当前块的delta内容 - - if hasattr(event.choices[0], "finish_reason") and event.choices[0].finish_reason: - finish_reason = event.choices[0].finish_reason - - if hasattr(event, "model") and event.model and not _model_name: - _model_name = event.model # 记录模型名 - - if hasattr(delta, "reasoning_content") and delta.reasoning_content: # type: ignore - # 标记:有独立的推理内容块 - _has_rc_attr_flag = True - - _in_rc_flag = _process_delta( - delta, - _has_rc_attr_flag, - _in_rc_flag, - _rc_delta_buffer, - _fc_delta_buffer, - _tool_calls_buffer, - ) - - if event.usage: - # 如果有使用情况,则将其存储在APIResponse对象中 - _usage_record = ( - event.usage.prompt_tokens or 0, - event.usage.completion_tokens or 0, - event.usage.total_tokens or 0, - ) + accumulator = _OpenAIStreamAccumulator( + reasoning_parse_mode=reasoning_parse_mode, + tool_argument_parse_mode=tool_argument_parse_mode, + ) + usage_record: UsageTuple | None = None try: - resp = _build_stream_api_resp( - _fc_delta_buffer, - _rc_delta_buffer, - _tool_calls_buffer, - finish_reason=finish_reason, - ) - # 统一在这里输出 max_tokens 截断的警告,并从 resp 中读取 - if finish_reason == "length": - # 把模型名塞到 resp.raw_data,后续严格“从 resp 提取” - try: - if _model_name: - resp.raw_data = {"model": _model_name} - except Exception: - pass - model_dbg = None - try: - if isinstance(resp.raw_data, dict): - model_dbg = resp.raw_data.get("model") - except Exception: - model_dbg = None + async for event in resp_stream: + if interrupt_flag and interrupt_flag.is_set(): + raise ReqAbortException("请求被外部信号中断") - # 统一日志格式 - logger.info("模型%s因为超过最大max_token限制,可能仅输出部分内容,可视情况调整" % (model_dbg or "")) + accumulator.capture_event_metadata(event) + event_usage = _extract_usage_record(getattr(event, "usage", None)) + if event_usage is not None: + usage_record = event_usage - return resp, _usage_record - except Exception: - # 确保缓冲区被关闭 - _insure_buffer_closed() - raise + if not getattr(event, "choices", None): + continue + accumulator.process_delta(event.choices[0].delta) -pattern = re.compile( - r"(?P.*?)(?P.*)|(?P.*)|(?P.+)", - re.DOTALL, -) -"""用于解析推理内容的正则表达式""" + response = accumulator.build_response() + model_name = None + if isinstance(response.raw_data, dict): + model_name = response.raw_data.get("model") + _log_length_truncation(accumulator.finish_reason, model_name) + return response, usage_record + finally: + accumulator.close() def _default_normal_response_parser( resp: ChatCompletion, -) -> tuple[APIResponse, Optional[tuple[int, int, int]]]: - """ - 解析对话补全响应 - 将OpenAI API响应解析为APIResponse对象 - :param resp: 响应对象 - :return: APIResponse对象 - """ - api_response = APIResponse() + *, + reasoning_parse_mode: ReasoningParseMode, + tool_argument_parse_mode: ToolArgumentParseMode, +) -> Tuple[APIResponse, UsageTuple | None]: + """解析 OpenAI 兼容的非流式响应。 - # 兼容部分 OpenAI 兼容服务在空回复时返回 choices=None 的情况 + Args: + resp: OpenAI SDK 返回的聊天补全响应。 + reasoning_parse_mode: 推理内容解析模式。 + tool_argument_parse_mode: 工具参数解析模式。 + + Returns: + Tuple[APIResponse, UsageTuple | None]: 解析后的响应与 usage 统计。 + + Raises: + EmptyResponseException: 当 choices 为空或响应内容为空时抛出。 + """ choices = getattr(resp, "choices", None) if not choices: - try: - model_dbg = getattr(resp, "model", None) - id_dbg = getattr(resp, "id", None) - usage_dbg = None - if hasattr(resp, "usage") and resp.usage: - usage_dbg = { - "prompt": getattr(resp.usage, "prompt_tokens", None), - "completion": getattr(resp.usage, "completion_tokens", None), - "total": getattr(resp.usage, "total_tokens", None), - } - try: - raw_snippet = str(resp)[:300] - except Exception: - raw_snippet = "" - logger.debug(f"empty choices: model={model_dbg} id={id_dbg} usage={usage_dbg} raw≈{raw_snippet}") - except Exception: - # 日志采集失败不应影响控制流 - pass - # 统一抛出可重试的 EmptyResponseException,触发上层重试逻辑 raise EmptyResponseException("响应解析失败,choices 为空或缺失") + + api_response = APIResponse() message_part = choices[0].message + native_reasoning = getattr(message_part, "reasoning_content", None) + message_content = message_part.content if isinstance(message_part.content, str) else None - if hasattr(message_part, "reasoning_content") and message_part.reasoning_content: # type: ignore - # 有有效的推理字段 - api_response.content = message_part.content - api_response.reasoning_content = message_part.reasoning_content # type: ignore - elif message_part.content: - # 提取推理和内容 - match = pattern.match(message_part.content) - if not match: - raise RespParseException(resp, "响应解析失败,无法捕获推理内容和输出内容") - if match.group("think") is not None: - result = match.group("think").strip(), match.group("content").strip() - elif match.group("think_unclosed") is not None: - result = match.group("think_unclosed").strip(), None - else: - result = None, match.group("content_only").strip() - api_response.reasoning_content, api_response.content = result - - # 提取工具调用 - if message_part.tool_calls: - api_response.tool_calls = [] - for call in message_part.tool_calls: - try: - arguments = json.loads(repair_json(call.function.arguments)) - # 【新增修复逻辑】如果解析出来还是字符串,说明发生了双重编码,尝试二次解析 - if isinstance(arguments, str): - try: - # 尝试对字符串内容再次进行修复和解析 - arguments = json.loads(repair_json(arguments)) - except Exception: - # 如果二次解析失败,保留原值,让下方的 isinstance(dict) 抛出更具体的错误 - pass - if not isinstance(arguments, dict): - # 此时为了调试方便,建议打印出 arguments 的类型 - raise RespParseException( - resp, - f"响应解析失败,工具调用参数无法解析为字典类型 type={type(arguments)} arguments={arguments}", - ) - api_response.tool_calls.append(ToolCall(call.id, call.function.name, arguments)) - except json.JSONDecodeError as e: - raise RespParseException(resp, "响应解析失败,无法解析工具调用参数") from e - - # 提取Usage信息 - if resp.usage: - _usage_record = ( - resp.usage.prompt_tokens or 0, - resp.usage.completion_tokens or 0, - resp.usage.total_tokens or 0, + if isinstance(native_reasoning, str) and native_reasoning and reasoning_parse_mode != ReasoningParseMode.NONE: + api_response.reasoning_content = native_reasoning + api_response.content = message_content + elif isinstance(message_content, str) and message_content: + reasoning_content, final_content = _extract_reasoning_and_content( + content=message_content, + parse_mode=reasoning_parse_mode, ) - else: - _usage_record = None + api_response.reasoning_content = reasoning_content + api_response.content = final_content - # 将原始响应存储在原始数据中 + tool_calls = getattr(message_part, "tool_calls", None) or [] + if tool_calls: + api_response.tool_calls = [] + for tool_call in tool_calls: + if tool_call.type != "function": + raise RespParseException(resp, f"响应解析失败,暂不支持工具调用类型 {tool_call.type}。") + raw_arguments = tool_call.function.arguments or "" + arguments = _parse_tool_arguments(raw_arguments, tool_argument_parse_mode, resp) + api_response.tool_calls.append( + ToolCall( + call_id=tool_call.id, + func_name=tool_call.function.name, + args=arguments, + ) + ) + + usage_record = _extract_usage_record(getattr(resp, "usage", None)) api_response.raw_data = resp - # 检查 max_tokens 截断 - try: - choice0 = resp.choices[0] - reason = getattr(choice0, "finish_reason", None) - if reason and reason == "length": - # print(resp) - _model_name = resp.model - # 统一日志格式 - logger.info("模型%s因为超过最大max_token限制,可能仅输出部分内容,可视情况调整" % (_model_name or "")) - return api_response, _usage_record - except Exception as e: - logger.debug(f"检查 MAX_TOKENS 截断时异常: {e}") + finish_reason = getattr(resp.choices[0], "finish_reason", None) + _log_length_truncation(finish_reason, getattr(resp, "model", None)) if not api_response.content and not api_response.tool_calls: raise EmptyResponseException() - return api_response, _usage_record + return api_response, usage_record @client_registry.register_client_class("openai") -class OpenaiClient(BaseClient): - def __init__(self, api_provider: APIProvider): +class OpenaiClient(AdapterClient[AsyncStream[ChatCompletionChunk], ChatCompletion]): + """OpenAI 兼容客户端。""" + + client: AsyncOpenAI + reasoning_parse_mode: ReasoningParseMode + tool_argument_parse_mode: ToolArgumentParseMode + + def __init__(self, api_provider: APIProvider) -> None: + """初始化 OpenAI 兼容客户端。 + + Args: + api_provider: API 提供商配置。 + """ super().__init__(api_provider) - self.client: AsyncOpenAI = AsyncOpenAI( - base_url=api_provider.base_url, - api_key=api_provider.api_key, - max_retries=0, + client_config = build_openai_compatible_client_config(api_provider) + self.reasoning_parse_mode = _normalize_reasoning_parse_mode(api_provider.reasoning_parse_mode) + self.tool_argument_parse_mode = _normalize_tool_argument_parse_mode(api_provider.tool_argument_parse_mode) + self.client = AsyncOpenAI( + api_key=client_config.api_key, + organization=api_provider.organization, + project=api_provider.project, + base_url=client_config.base_url, timeout=api_provider.timeout, + max_retries=api_provider.max_retry, + default_headers=client_config.default_headers or None, + default_query=client_config.default_query or None, ) - async def get_response( + def _build_default_stream_response_handler( self, - model_info: ModelInfo, - message_list: list[Message], - tool_options: list[ToolOption] | None = None, - max_tokens: Optional[int] = 1024, - temperature: Optional[float] = 0.7, - response_format: RespFormat | None = None, - stream_response_handler: Optional[ - Callable[ - [AsyncStream[ChatCompletionChunk], asyncio.Event | None], - Coroutine[Any, Any, tuple[APIResponse, Optional[tuple[int, int, int]]]], - ] - ] = None, - async_response_parser: Optional[ - Callable[[ChatCompletion], tuple[APIResponse, Optional[tuple[int, int, int]]]] - ] = None, - interrupt_flag: asyncio.Event | None = None, - extra_params: dict[str, Any] | None = None, - ) -> APIResponse: - """ - 获取对话响应 + request: ResponseRequest, + ) -> ProviderStreamResponseHandler[AsyncStream[ChatCompletionChunk]]: + """构建 OpenAI 默认流式响应处理器。 + Args: - model_info: 模型信息 - message_list: 对话体 - tool_options: 工具选项(可选,默认为None) - max_tokens: 最大token数(可选,默认为1024) - temperature: 温度(可选,默认为0.7) - response_format: 响应格式(可选,默认为 NotGiven ) - stream_response_handler: 流式响应处理函数(可选,默认为default_stream_response_handler) - async_response_parser: 响应解析函数(可选,默认为default_response_parser) - interrupt_flag: 中断信号量(可选,默认为None) + request: 统一响应请求对象。 + Returns: - (响应文本, 推理文本, 工具调用, 其他数据) + ProviderStreamResponseHandler[AsyncStream[ChatCompletionChunk]]: 默认流式处理器。 """ - if stream_response_handler is None: - stream_response_handler = _default_stream_response_handler + del request - if async_response_parser is None: - async_response_parser = _default_normal_response_parser + async def default_stream_handler( + resp_stream: AsyncStream[ChatCompletionChunk], + flag: asyncio.Event | None, + ) -> Tuple[APIResponse, UsageTuple | None]: + """包装默认流式解析器。""" + return await _default_stream_response_handler( + resp_stream, + flag, + reasoning_parse_mode=self.reasoning_parse_mode, + tool_argument_parse_mode=self.tool_argument_parse_mode, + ) - # 将messages构造为OpenAI API所需的格式 - messages: Iterable[ChatCompletionMessageParam] = _convert_messages(message_list) - # 将tool_options转换为OpenAI API所需的格式 - tools: Iterable[ChatCompletionToolParam] = _convert_tool_options(tool_options) if tool_options else NOT_GIVEN # type: ignore - openai_response_format = _convert_response_format(response_format) + return default_stream_handler + + def _build_default_response_parser( + self, + request: ResponseRequest, + ) -> ProviderResponseParser[ChatCompletion]: + """构建 OpenAI 默认非流式响应解析器。 + + Args: + request: 统一响应请求对象。 + + Returns: + ProviderResponseParser[ChatCompletion]: 默认非流式解析器。 + """ + del request + + def default_response_parser( + response: ChatCompletion, + ) -> Tuple[APIResponse, UsageTuple | None]: + """包装默认非流式解析器。""" + return _default_normal_response_parser( + response, + reasoning_parse_mode=self.reasoning_parse_mode, + tool_argument_parse_mode=self.tool_argument_parse_mode, + ) + + return default_response_parser + + async def _execute_response_request( + self, + request: ResponseRequest, + stream_response_handler: ProviderStreamResponseHandler[AsyncStream[ChatCompletionChunk]], + response_parser: ProviderResponseParser[ChatCompletion], + ) -> Tuple[APIResponse, UsageTuple | None]: + """执行 OpenAI 兼容的文本/多模态响应请求。 + + Args: + request: 统一响应请求对象。 + stream_response_handler: 流式响应处理器。 + response_parser: 非流式响应解析器。 + + Returns: + Tuple[APIResponse, UsageTuple | None]: 统一响应对象与可选使用量信息。 + """ + model_info = request.model_info + messages: Iterable[ChatCompletionMessageParam] = _convert_messages(request.message_list) + tools: Iterable[ChatCompletionToolParam] | Omit = ( + _convert_tool_options(request.tool_options) if request.tool_options else omit + ) + openai_response_format = _convert_response_format(request.response_format) + request_overrides = split_openai_request_overrides( + request.extra_params, + reserved_body_keys=CHAT_COMPLETIONS_RESERVED_EXTRA_BODY_KEYS, + ) + + temperature_argument = ( + omit if "temperature" in request_overrides.extra_body else _coerce_openai_argument(request.temperature) + ) + max_tokens_argument = ( + omit + if "max_tokens" in request_overrides.extra_body or "max_completion_tokens" in request_overrides.extra_body + else _coerce_openai_argument(request.max_tokens) + ) try: if model_info.force_stream_mode: - req_task = asyncio.create_task( + stream_task: asyncio.Task[AsyncStream[ChatCompletionChunk]] = asyncio.create_task( self.client.chat.completions.create( model=model_info.model_identifier, messages=messages, tools=tools, - temperature=temperature, - max_tokens=max_tokens, + temperature=temperature_argument, + max_tokens=max_tokens_argument, stream=True, response_format=openai_response_format, - extra_body=extra_params, + extra_headers=request_overrides.extra_headers or None, + extra_query=request_overrides.extra_query or None, + extra_body=request_overrides.extra_body or None, ) ) - while not req_task.done(): - if interrupt_flag and interrupt_flag.is_set(): - # 如果中断量存在且被设置,则取消任务并抛出异常 - req_task.cancel() - raise ReqAbortException("请求被外部信号中断") - await asyncio.sleep(0.1) # 等待0.1秒后再次检查任务&中断信号量状态 - - resp, usage_record = await stream_response_handler(req_task.result(), interrupt_flag) - else: - # 发送请求并获取响应 - # start_time = time.time() - req_task = asyncio.create_task( - self.client.chat.completions.create( - model=model_info.model_identifier, - messages=messages, - tools=tools, - temperature=temperature, - max_tokens=max_tokens, - stream=False, - response_format=openai_response_format, - extra_body=extra_params, - ) + raw_response = cast( + AsyncStream[ChatCompletionChunk], + await await_task_with_interrupt(stream_task, request.interrupt_flag), ) - while not req_task.done(): - if interrupt_flag and interrupt_flag.is_set(): - # 如果中断量存在且被设置,则取消任务并抛出异常 - req_task.cancel() - raise ReqAbortException("请求被外部信号中断") - await asyncio.sleep(0.1) # 等待0.5秒后再次检查任务&中断信号量状态 + return await stream_response_handler(raw_response, request.interrupt_flag) - # logger. - # logger.debug(f"OpenAI API响应(非流式): {req_task.result()}") - - # logger.info(f"OpenAI请求时间: {model_info.model_identifier} {time.time() - start_time} \n{messages}") - - resp, usage_record = async_response_parser(req_task.result()) - except APIConnectionError as e: - # 重封装APIConnectionError为NetworkConnectionError - raise NetworkConnectionError() from e - except APIStatusError as e: - # 重封装APIError为RespNotOkException - raise RespNotOkException(e.status_code, e.message) from e - - if usage_record: - resp.usage = UsageRecord( - model_name=model_info.name, - provider_name=model_info.api_provider, - prompt_tokens=usage_record[0], - completion_tokens=usage_record[1], - total_tokens=usage_record[2], + completion_task: asyncio.Task[ChatCompletion] = asyncio.create_task( + self.client.chat.completions.create( + model=model_info.model_identifier, + messages=messages, + tools=tools, + temperature=temperature_argument, + max_tokens=max_tokens_argument, + stream=False, + response_format=openai_response_format, + extra_headers=request_overrides.extra_headers or None, + extra_query=request_overrides.extra_query or None, + extra_body=request_overrides.extra_body or None, + ) ) + raw_response = cast( + ChatCompletion, + await await_task_with_interrupt(completion_task, request.interrupt_flag), + ) + return response_parser(raw_response) + except APIConnectionError as exc: + raise NetworkConnectionError(str(exc)) from exc + except APIStatusError as exc: + raise RespNotOkException(exc.status_code, _build_api_status_message(exc)) from exc - # logger.debug(f"OpenAI API响应: {resp}") - - return resp - - async def get_embedding( + async def _execute_embedding_request( self, - model_info: ModelInfo, - embedding_input: str, - extra_params: dict[str, Any] | None = None, - ) -> APIResponse: - """ - 获取文本嵌入 - :param model_info: 模型信息 - :param embedding_input: 嵌入输入文本 - :return: 嵌入响应 + request: EmbeddingRequest, + ) -> Tuple[APIResponse, UsageTuple | None]: + """执行 OpenAI 兼容的文本嵌入请求。 + + Args: + request: 统一嵌入请求对象。 + + Returns: + Tuple[APIResponse, UsageTuple | None]: 统一响应对象与可选使用量信息。 """ + model_info = request.model_info + embedding_input = request.embedding_input + extra_params = request.extra_params + request_overrides = split_openai_request_overrides(extra_params) + try: raw_response = await self.client.embeddings.create( model=model_info.model_identifier, input=embedding_input, - extra_body=extra_params, + extra_headers=request_overrides.extra_headers or None, + extra_query=request_overrides.extra_query or None, + extra_body=request_overrides.extra_body or None, ) - except APIConnectionError as e: - # 添加详细的错误信息以便调试 - logger.error(f"OpenAI API连接错误(嵌入模型): {str(e)}") - logger.error(f"错误类型: {type(e)}") - if hasattr(e, "__cause__") and e.__cause__: - logger.error(f"底层错误: {str(e.__cause__)}") - raise NetworkConnectionError() from e - except APIStatusError as e: - # 重封装APIError为RespNotOkException - raise RespNotOkException(e.status_code) from e + except APIConnectionError as exc: + raise NetworkConnectionError(str(exc)) from exc + except APIStatusError as exc: + raise RespNotOkException(exc.status_code, _build_api_status_message(exc)) from exc response = APIResponse() - - # 解析嵌入响应 - if len(raw_response.data) > 0: + if raw_response.data: response.embedding = raw_response.data[0].embedding else: - raise RespParseException( - raw_response, - "响应解析失败,缺失嵌入数据。", - ) + raise RespParseException(raw_response, "响应解析失败,缺失嵌入数据。") - # 解析使用情况 - if hasattr(raw_response, "usage"): - response.usage = UsageRecord( - model_name=model_info.name, - provider_name=model_info.api_provider, - prompt_tokens=raw_response.usage.prompt_tokens or 0, - completion_tokens=getattr(raw_response.usage, "completion_tokens", 0), - total_tokens=raw_response.usage.total_tokens or 0, - ) + usage_record = _extract_usage_record(getattr(raw_response, "usage", None)) + return response, usage_record - return response - - async def get_audio_transcriptions( + async def _execute_audio_transcription_request( self, - model_info: ModelInfo, - audio_base64: str, - extra_params: dict[str, Any] | None = None, - ) -> APIResponse: - """ - 获取音频转录 - :param model_info: 模型信息 - :param audio_base64: base64编码的音频数据 - :extra_params: 附加的请求参数 - :return: 音频转录响应 + request: AudioTranscriptionRequest, + ) -> Tuple[APIResponse, UsageTuple | None]: + """执行 OpenAI 兼容的音频转录请求。 + + Args: + request: 统一音频转录请求对象。 + + Returns: + Tuple[APIResponse, UsageTuple | None]: 统一响应对象与可选使用量信息。 """ + model_info = request.model_info + audio_base64 = request.audio_base64 + extra_params = request.extra_params + request_overrides = split_openai_request_overrides(extra_params) + audio_file: FileTypes = ("audio.wav", io.BytesIO(base64.b64decode(audio_base64))) + try: raw_response = await self.client.audio.transcriptions.create( model=model_info.model_identifier, - file=("audio.wav", io.BytesIO(base64.b64decode(audio_base64))), - extra_body=extra_params, + file=audio_file, + extra_headers=request_overrides.extra_headers or None, + extra_query=request_overrides.extra_query or None, + extra_body=request_overrides.extra_body or None, ) - except APIConnectionError as e: - raise NetworkConnectionError() from e - except APIStatusError as e: - # 重封装APIError为RespNotOkException - raise RespNotOkException(e.status_code) from e - response = APIResponse() - # 解析转录响应 - if hasattr(raw_response, "text"): - response.content = raw_response.text - else: - raise RespParseException( - raw_response, - "响应解析失败,缺失转录文本。", - ) - return response + except APIConnectionError as exc: + raise NetworkConnectionError(str(exc)) from exc + except APIStatusError as exc: + raise RespNotOkException(exc.status_code, _build_api_status_message(exc)) from exc - def get_support_image_formats(self) -> list[str]: - """ - 获取支持的图片格式 - :return: 支持的图片格式列表 + response = APIResponse() + transcription_text = raw_response if isinstance(raw_response, str) else getattr(raw_response, "text", None) + if isinstance(transcription_text, str): + response.content = transcription_text + return response, None + raise RespParseException(raw_response, "响应解析失败,缺失转录文本。") + + def get_support_image_formats(self) -> List[str]: + """获取支持的图片格式列表。 + + Returns: + List[str]: 当前客户端支持的图片格式列表。 """ return ["jpg", "jpeg", "png", "webp", "gif"] diff --git a/src/llm_models/openai_compat.py b/src/llm_models/openai_compat.py new file mode 100644 index 00000000..19190e0a --- /dev/null +++ b/src/llm_models/openai_compat.py @@ -0,0 +1,140 @@ +from dataclasses import dataclass, field +from typing import Any, Mapping + +from src.config.model_configs import APIProvider, OpenAICompatibleAuthType + + +@dataclass(slots=True) +class OpenAICompatibleClientConfig: + """OpenAI 兼容客户端的基础配置。""" + + api_key: str + base_url: str + default_headers: dict[str, str] = field(default_factory=dict) + default_query: dict[str, object] = field(default_factory=dict) + + +@dataclass(slots=True) +class OpenAICompatibleRequestOverrides: + """单次请求级别的附加配置。""" + + extra_headers: dict[str, str] = field(default_factory=dict) + extra_query: dict[str, object] = field(default_factory=dict) + extra_body: dict[str, Any] = field(default_factory=dict) + + +def normalize_openai_base_url(base_url: str) -> str: + """规范化 OpenAI 兼容接口的基础地址。 + + Args: + base_url: 原始基础地址。 + + Returns: + str: 去掉尾部斜杠后的地址。 + """ + return base_url.rstrip("/") + + +def _build_auth_header_value(prefix: str, api_key: str) -> str: + """构造鉴权请求头的值。 + + Args: + prefix: 请求头前缀。 + api_key: 实际密钥。 + + Returns: + str: 拼接完成的请求头值。 + """ + normalized_prefix = prefix.strip() + if not normalized_prefix: + return api_key + return f"{normalized_prefix} {api_key}" + + +def build_openai_compatible_client_config(api_provider: APIProvider) -> OpenAICompatibleClientConfig: + """构建 OpenAI 兼容客户端配置。 + + Args: + api_provider: API 提供商配置。 + + Returns: + OpenAICompatibleClientConfig: 可直接用于初始化 SDK 客户端的配置。 + """ + default_headers = dict(api_provider.default_headers) + default_query: dict[str, object] = dict(api_provider.default_query) + client_api_key = api_provider.api_key + + if api_provider.auth_type == OpenAICompatibleAuthType.BEARER: + if ( + api_provider.auth_header_name != "Authorization" + or api_provider.auth_header_prefix.strip() != "Bearer" + ): + client_api_key = "" + default_headers[api_provider.auth_header_name] = _build_auth_header_value( + prefix=api_provider.auth_header_prefix, + api_key=api_provider.api_key, + ) + elif api_provider.auth_type == OpenAICompatibleAuthType.HEADER: + client_api_key = "" + default_headers[api_provider.auth_header_name] = _build_auth_header_value( + prefix=api_provider.auth_header_prefix, + api_key=api_provider.api_key, + ) + elif api_provider.auth_type == OpenAICompatibleAuthType.QUERY: + client_api_key = "" + default_query[api_provider.auth_query_name] = api_provider.api_key + elif api_provider.auth_type == OpenAICompatibleAuthType.NONE: + client_api_key = "" + + return OpenAICompatibleClientConfig( + api_key=client_api_key, + base_url=normalize_openai_base_url(api_provider.base_url), + default_headers=default_headers, + default_query=default_query, + ) + + +def _extract_mapping(value: Any) -> dict[str, Any]: + """将任意映射值规范化为普通字典。 + + Args: + value: 原始输入值。 + + Returns: + dict[str, Any]: 规范化后的字典。非映射值时返回空字典。 + """ + if isinstance(value, Mapping): + return {str(key): item for key, item in value.items()} + return {} + + +def split_openai_request_overrides( + extra_params: Mapping[str, Any] | None, + *, + reserved_body_keys: set[str] | None = None, +) -> OpenAICompatibleRequestOverrides: + """拆分单次请求中的头、查询参数和请求体扩展字段。 + + Args: + extra_params: 模型级别或请求级别的附加参数。 + reserved_body_keys: 由 SDK 原生参数承载、因此不应再进入 `extra_body` 的字段集合。 + + Returns: + OpenAICompatibleRequestOverrides: 拆分后的请求覆盖配置。 + """ + raw_params = dict(extra_params or {}) + extra_headers = _extract_mapping(raw_params.pop("headers", None)) + extra_query = _extract_mapping(raw_params.pop("query", None)) + extra_body = _extract_mapping(raw_params.pop("body", None)) + blocked_body_keys = reserved_body_keys or set() + + for key, value in raw_params.items(): + if key in blocked_body_keys: + continue + extra_body[key] = value + + return OpenAICompatibleRequestOverrides( + extra_headers={key: str(value) for key, value in extra_headers.items()}, + extra_query=extra_query, + extra_body=extra_body, + ) diff --git a/src/llm_models/payload_content/message.py b/src/llm_models/payload_content/message.py index 960de08b..8ed392ef 100644 --- a/src/llm_models/payload_content/message.py +++ b/src/llm_models/payload_content/message.py @@ -1,133 +1,280 @@ +from dataclasses import dataclass, field from enum import Enum -from typing import List, Optional +from typing import List, Tuple from .tool_option import ToolCall -# 设计这系列类的目的是为未来可能的扩展做准备 +class RoleType(str, Enum): + """消息角色类型。""" - -class RoleType(Enum): System = "system" User = "user" Assistant = "assistant" Tool = "tool" -SUPPORTED_IMAGE_FORMATS = ["jpg", "jpeg", "png", "webp", "gif"] # openai支持的图片格式 +SUPPORTED_IMAGE_FORMATS = ["jpg", "jpeg", "png", "webp", "gif"] +"""默认支持的图片格式列表。""" +@dataclass(slots=True) +class TextMessagePart: + """文本消息片段。""" + + text: str + + def __post_init__(self) -> None: + """执行文本片段的基础校验。 + + Raises: + ValueError: 当文本为空时抛出。 + """ + if self.text == "": + raise ValueError("文本消息片段不能为空字符串") + + +@dataclass(slots=True) +class ImageMessagePart: + """Base64 图片消息片段。""" + + image_format: str + image_base64: str + + def __post_init__(self) -> None: + """执行图片片段的基础校验。 + + Raises: + ValueError: 当图片格式或 Base64 数据无效时抛出。 + """ + if self.image_format.lower() not in SUPPORTED_IMAGE_FORMATS: + raise ValueError("不受支持的图片格式") + if not self.image_base64: + raise ValueError("图片的 base64 编码不能为空") + + @property + def normalized_image_format(self) -> str: + """获取规范化后的图片格式。 + + Returns: + str: 规范化后的图片格式。`jpg` 会被统一为 `jpeg`。 + """ + image_format = self.image_format.lower() + if image_format in {"jpg", "jpeg"}: + return "jpeg" + return image_format + + +MessagePart = TextMessagePart | ImageMessagePart + + +@dataclass(slots=True) class Message: - def __init__( - self, - role: RoleType, - content: str | list[tuple[str, str] | str], - tool_call_id: str | None = None, - tool_calls: Optional[List[ToolCall]] = None, - ): + """统一消息模型。""" + + role: RoleType + parts: List[MessagePart] = field(default_factory=list) + tool_call_id: str | None = None + tool_calls: List[ToolCall] | None = None + + def __post_init__(self) -> None: + """执行消息对象的基础校验。 + + Raises: + ValueError: 当消息内容或工具调用信息不完整时抛出。 """ - 初始化消息对象 - (不应直接修改Message类,而应使用MessageBuilder类来构建对象) + if not self.parts and not (self.role == RoleType.Assistant and self.tool_calls): + raise ValueError("消息内容不能为空") + if self.role == RoleType.Tool and not self.tool_call_id: + raise ValueError("Tool 角色的工具调用 ID 不能为空") + + @property + def content(self) -> str | List[Tuple[str, str] | str]: + """获取兼容旧逻辑的内容视图。 + + Returns: + str | List[Tuple[str, str] | str]: 当仅包含一个文本片段时返回字符串, + 否则返回混合列表,其中图片片段表示为 `(format, base64)` 元组。 """ - self.role: RoleType = role - self.content: str | list[tuple[str, str] | str] = content - self.tool_call_id: str | None = tool_call_id - self.tool_calls: Optional[List[ToolCall]] = tool_calls + if len(self.parts) == 1 and isinstance(self.parts[0], TextMessagePart): + return self.parts[0].text + content_items: List[Tuple[str, str] | str] = [] + for part in self.parts: + if isinstance(part, TextMessagePart): + content_items.append(part.text) + else: + content_items.append((part.image_format, part.image_base64)) + return content_items + + def get_text_content(self) -> str: + """提取消息中的所有文本片段。 + + Returns: + str: 以原始顺序拼接后的文本内容。 + """ + return "".join(part.text for part in self.parts if isinstance(part, TextMessagePart)) def __str__(self) -> str: + """生成便于调试的字符串表示。 + + Returns: + str: 当前消息对象的可读摘要。 + """ return ( - f"Role: {self.role}, Content: {self.content}, " + f"Role: {self.role}, Parts: {self.parts}, " f"Tool Call ID: {self.tool_call_id}, Tool Calls: {self.tool_calls}" ) class MessageBuilder: - def __init__(self): + """消息构建器。""" + + def __init__(self) -> None: + """初始化构建器。""" self.__role: RoleType = RoleType.User - self.__content: list[tuple[str, str] | str] = [] + self.__parts: List[MessagePart] = [] self.__tool_call_id: str | None = None - self.__tool_calls: Optional[List[ToolCall]] = None + self.__tool_calls: List[ToolCall] | None = None def set_role(self, role: RoleType = RoleType.User) -> "MessageBuilder": - """ - 设置角色(默认为User) - :param role: 角色 - :return: MessageBuilder对象 + """设置消息角色。 + + Args: + role: 目标角色,默认为 `user`。 + + Returns: + MessageBuilder: 当前构建器实例。 """ self.__role = role return self + def add_text_part(self, text: str) -> "MessageBuilder": + """追加文本片段。 + + Args: + text: 文本内容。 + + Returns: + MessageBuilder: 当前构建器实例。 + """ + self.__parts.append(TextMessagePart(text=text)) + return self + def add_text_content(self, text: str) -> "MessageBuilder": + """追加文本片段。 + + Args: + text: 文本内容。 + + Returns: + MessageBuilder: 当前构建器实例。 """ - 添加文本内容 - :param text: 文本内容 - :return: MessageBuilder对象 + return self.add_text_part(text) + + def add_image_base64_part( + self, + image_format: str, + image_base64: str, + support_formats: List[str] = SUPPORTED_IMAGE_FORMATS, + ) -> "MessageBuilder": + """追加 Base64 图片片段。 + + Args: + image_format: 图片格式。 + image_base64: 图片的 Base64 编码。 + support_formats: 允许的图片格式列表。 + + Returns: + MessageBuilder: 当前构建器实例。 + + Raises: + ValueError: 当图片格式不被支持时抛出。 """ - self.__content.append(text) + if image_format.lower() not in support_formats: + raise ValueError("不受支持的图片格式") + self.__parts.append(ImageMessagePart(image_format=image_format, image_base64=image_base64)) return self def add_image_content( self, image_format: str, image_base64: str, - support_formats: list[str] = SUPPORTED_IMAGE_FORMATS, # 默认支持格式 + support_formats: List[str] = SUPPORTED_IMAGE_FORMATS, ) -> "MessageBuilder": - """ - 添加图片内容 - :param image_format: 图片格式 - :param image_base64: 图片的base64编码 - :return: MessageBuilder对象 - """ - if image_format.lower() not in support_formats: - raise ValueError("不受支持的图片格式") - if not image_base64: - raise ValueError("图片的base64编码不能为空") - self.__content.append((image_format, image_base64)) - return self + """追加 Base64 图片片段。 - def add_tool_call(self, tool_call_id: str) -> "MessageBuilder": + Args: + image_format: 图片格式。 + image_base64: 图片的 Base64 编码。 + support_formats: 允许的图片格式列表。 + + Returns: + MessageBuilder: 当前构建器实例。 """ - 添加工具调用指令(调用时请确保已设置为Tool角色) - :param tool_call_id: 工具调用指令的id - :return: MessageBuilder对象 + return self.add_image_base64_part( + image_format=image_format, + image_base64=image_base64, + support_formats=support_formats, + ) + + def set_tool_call_id(self, tool_call_id: str) -> "MessageBuilder": + """设置工具结果消息引用的工具调用 ID。 + + Args: + tool_call_id: 工具调用 ID。 + + Returns: + MessageBuilder: 当前构建器实例。 + + Raises: + ValueError: 当当前角色不是 `tool` 或 ID 为空时抛出。 """ if self.__role != RoleType.Tool: - raise ValueError("仅当角色为Tool时才能添加工具调用ID") + raise ValueError("仅当角色为 Tool 时才能设置工具调用 ID") if not tool_call_id: - raise ValueError("工具调用ID不能为空") + raise ValueError("工具调用 ID 不能为空") self.__tool_call_id = tool_call_id return self - def set_tool_calls(self, tool_calls: List[ToolCall]) -> "MessageBuilder": + def add_tool_call(self, tool_call_id: str) -> "MessageBuilder": + """设置工具结果消息引用的工具调用 ID。 + + Args: + tool_call_id: 工具调用 ID。 + + Returns: + MessageBuilder: 当前构建器实例。 """ - 设置助手消息的工具调用列表 - :param tool_calls: 工具调用列表 - :return: MessageBuilder对象 + return self.set_tool_call_id(tool_call_id) + + def set_tool_calls(self, tool_calls: List[ToolCall]) -> "MessageBuilder": + """设置助手消息中的工具调用列表。 + + Args: + tool_calls: 工具调用列表。 + + Returns: + MessageBuilder: 当前构建器实例。 + + Raises: + ValueError: 当当前角色不是 `assistant` 或列表为空时抛出。 """ if self.__role != RoleType.Assistant: - raise ValueError("仅当角色为Assistant时才能设置工具调用列表") + raise ValueError("仅当角色为 Assistant 时才能设置工具调用列表") if not tool_calls: raise ValueError("工具调用列表不能为空") - self.__tool_calls = tool_calls + self.__tool_calls = list(tool_calls) return self def build(self) -> Message: - """ - 构建消息对象 - :return: Message对象 - """ - if len(self.__content) == 0 and not (self.__role == RoleType.Assistant and self.__tool_calls): - raise ValueError("内容不能为空") - if self.__role == RoleType.Tool and self.__tool_call_id is None: - raise ValueError("Tool角色的工具调用ID不能为空") + """构建消息对象。 + Returns: + Message: 构建完成的消息对象。 + """ return Message( role=self.__role, - content=( - self.__content[0] - if (len(self.__content) == 1 and isinstance(self.__content[0], str)) - else self.__content - ), + parts=list(self.__parts), tool_call_id=self.__tool_call_id, - tool_calls=self.__tool_calls, + tool_calls=list(self.__tool_calls) if self.__tool_calls else None, ) diff --git a/src/llm_models/payload_content/resp_format.py b/src/llm_models/payload_content/resp_format.py index e1baa374..4319b03d 100644 --- a/src/llm_models/payload_content/resp_format.py +++ b/src/llm_models/payload_content/resp_format.py @@ -1,51 +1,40 @@ +from copy import deepcopy from enum import Enum -from typing import Optional, Any +from typing import Any, Dict, List, Mapping, Optional, Type, cast from pydantic import BaseModel -from typing_extensions import TypedDict, Required +from typing_extensions import Required, TypedDict class RespFormatType(Enum): - TEXT = "text" # 文本 - JSON_OBJ = "json_object" # JSON - JSON_SCHEMA = "json_schema" # JSON Schema + """响应格式类型。""" + + TEXT = "text" + JSON_OBJ = "json_object" + JSON_SCHEMA = "json_schema" class JsonSchema(TypedDict, total=False): + """内部使用的 JSON Schema 包装结构。""" + name: Required[str] - """ - The name of the response format. - - Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length - of 64. - """ - description: Optional[str] - """ - A description of what the response format is for, used by the model to determine - how to respond in the format. - """ - - schema: dict[str, object] - """ - The schema for the response format, described as a JSON Schema object. Learn how - to build JSON schemas [here](https://json-schema.org/). - """ - + schema: Dict[str, Any] strict: Optional[bool] - """ - Whether to enable strict schema adherence when generating the output. If set to - true, the model will always follow the exact schema defined in the `schema` - field. Only a subset of JSON Schema is supported when `strict` is `true`. To - learn more, read the - [Structured Outputs guide](https://platform.openai.com/docs/guides/structured-outputs). - """ -def _json_schema_type_check(instance) -> str | None: +def _json_schema_type_check(instance: Mapping[str, Any]) -> str | None: + """检查 JSON Schema 包装结构是否合法。 + + Args: + instance: 待检查的 JSON Schema 包装字典。 + + Returns: + str | None: 不合法时返回错误信息,合法时返回 `None`。 + """ if "name" not in instance: return "schema必须包含'name'字段" - elif not isinstance(instance["name"], str) or instance["name"].strip() == "": + if not isinstance(instance["name"], str) or instance["name"].strip() == "": return "schema的'name'字段必须是非空字符串" if "description" in instance and ( not isinstance(instance["description"], str) or instance["description"].strip() == "" @@ -53,164 +42,198 @@ def _json_schema_type_check(instance) -> str | None: return "schema的'description'字段只能填入非空字符串" if "schema" not in instance: return "schema必须包含'schema'字段" - elif not isinstance(instance["schema"], dict): + if not isinstance(instance["schema"], dict): return "schema的'schema'字段必须是字典,详见https://json-schema.org/" if "strict" in instance and not isinstance(instance["strict"], bool): return "schema的'strict'字段只能填入布尔值" - return None -def _remove_title(schema: dict[str, Any] | list[Any]) -> dict[str, Any] | list[Any]: - """ - 递归移除JSON Schema中的title字段 +def _remove_title(schema: Dict[str, Any] | List[Any]) -> Dict[str, Any] | List[Any]: + """递归移除 JSON Schema 中的 `title` 字段。 + + Args: + schema: 待处理的 Schema 树。 + + Returns: + Dict[str, Any] | List[Any]: 处理后的 Schema 树。 """ if isinstance(schema, list): - # 如果当前Schema是列表,则对所有dict/list子元素递归调用 - for idx, item in enumerate(schema): + for index, item in enumerate(schema): if isinstance(item, (dict, list)): - schema[idx] = _remove_title(item) - elif isinstance(schema, dict): - # 是字典,移除title字段,并对所有dict/list子元素递归调用 - if "title" in schema: - del schema["title"] - for key, value in schema.items(): - if isinstance(value, (dict, list)): - schema[key] = _remove_title(value) + schema[index] = _remove_title(item) + return schema + if "title" in schema: + del schema["title"] + for key, value in schema.items(): + if isinstance(value, (dict, list)): + schema[key] = _remove_title(value) return schema -def _link_definitions(schema: dict[str, Any]) -> dict[str, Any]: - """ - 链接JSON Schema中的definitions字段 +def _link_definitions(schema: Dict[str, Any]) -> Dict[str, Any]: + """展开 Schema 中的本地 `$defs`/`$ref` 引用。 + + Args: + schema: 待处理的根 Schema。 + + Returns: + Dict[str, Any]: 展开后的 Schema。 """ def link_definitions_recursive( - path: str, sub_schema: list[Any] | dict[str, Any], defs: dict[str, Any] - ) -> dict[str, Any]: - """ - 递归链接JSON Schema中的definitions字段 - :param path: 当前路径 - :param sub_schema: 子Schema - :param defs: Schema定义集 - :return: + path: str, + sub_schema: Dict[str, Any] | List[Any], + definitions: Dict[str, Any], + ) -> Dict[str, Any] | List[Any]: + """递归展开局部定义。 + + Args: + path: 当前递归路径。 + sub_schema: 当前子 Schema。 + definitions: 已收集的定义字典。 + + Returns: + Dict[str, Any] | List[Any]: 展开后的子 Schema。 """ if isinstance(sub_schema, list): - # 如果当前Schema是列表,则遍历每个元素 - for i in range(len(sub_schema)): - if isinstance(sub_schema[i], dict): - sub_schema[i] = link_definitions_recursive(f"{path}/{str(i)}", sub_schema[i], defs) - else: - # 否则为字典 - if "$defs" in sub_schema: - # 如果当前Schema有$def字段,则将其添加到defs中 - key_prefix = f"{path}/$defs/" - for key, value in sub_schema["$defs"].items(): - def_key = key_prefix + key - if def_key not in defs: - defs[def_key] = value - del sub_schema["$defs"] - if "$ref" in sub_schema: - # 如果当前Schema有$ref字段,则将其替换为defs中的定义 - def_key = sub_schema["$ref"] - if def_key in defs: - sub_schema = defs[def_key] - else: - raise ValueError(f"Schema中引用的定义'{def_key}'不存在") - # 遍历键值对 - for key, value in sub_schema.items(): - if isinstance(value, (dict, list)): - # 如果当前值是字典或列表,则递归调用 - sub_schema[key] = link_definitions_recursive(f"{path}/{key}", value, defs) + for index, item in enumerate(sub_schema): + if isinstance(item, (dict, list)): + sub_schema[index] = link_definitions_recursive(f"{path}/{index}", item, definitions) + return sub_schema + if "$defs" in sub_schema: + key_prefix = f"{path}/$defs/" + defs_payload = cast(Dict[str, Any], sub_schema["$defs"]) + for key, value in defs_payload.items(): + definition_key = key_prefix + key + if definition_key not in definitions: + definitions[definition_key] = value + del sub_schema["$defs"] + + if "$ref" in sub_schema: + definition_key = cast(str, sub_schema["$ref"]) + if definition_key in definitions: + return definitions[definition_key] + raise ValueError(f"Schema中引用的定义'{definition_key}'不存在") + + for key, value in sub_schema.items(): + if isinstance(value, (dict, list)): + sub_schema[key] = link_definitions_recursive(f"{path}/{key}", value, definitions) return sub_schema - return link_definitions_recursive("#", schema, {}) + return cast(Dict[str, Any], link_definitions_recursive("#", schema, {})) -def _remove_defs(schema: dict[str, Any]) -> dict[str, Any]: - """ - 递归移除JSON Schema中的$defs字段 +def _remove_defs(schema: Dict[str, Any] | List[Any]) -> Dict[str, Any] | List[Any]: + """递归移除 JSON Schema 中的 `$defs` 字段。 + + Args: + schema: 待处理的 Schema 树。 + + Returns: + Dict[str, Any] | List[Any]: 处理后的 Schema 树。 """ if isinstance(schema, list): - # 如果当前Schema是列表,则对所有dict/list子元素递归调用 - for idx, item in enumerate(schema): + for index, item in enumerate(schema): if isinstance(item, (dict, list)): - schema[idx] = _remove_title(item) - elif isinstance(schema, dict): - # 是字典,移除title字段,并对所有dict/list子元素递归调用 - if "$defs" in schema: - del schema["$defs"] - for key, value in schema.items(): - if isinstance(value, (dict, list)): - schema[key] = _remove_title(value) + schema[index] = _remove_defs(item) + return schema + if "$defs" in schema: + del schema["$defs"] + for key, value in schema.items(): + if isinstance(value, (dict, list)): + schema[key] = _remove_defs(value) return schema class RespFormat: - """ - 响应格式 - """ + """统一响应格式定义。""" @staticmethod - def _generate_schema_from_model(schema): - json_schema = { - "name": schema.__name__, - "schema": _remove_defs(_link_definitions(_remove_title(schema.model_json_schema()))), + def _generate_schema_from_model(schema_model: Type[BaseModel]) -> JsonSchema: + """从 Pydantic 模型生成内部 JSON Schema 包装结构。 + + Args: + schema_model: Pydantic 模型类。 + + Returns: + JsonSchema: 内部统一 JSON Schema 包装结构。 + """ + schema_tree = deepcopy(schema_model.model_json_schema()) + json_schema: JsonSchema = { + "name": schema_model.__name__, + "schema": cast( + Dict[str, Any], + _remove_defs(_link_definitions(cast(Dict[str, Any], _remove_title(schema_tree)))), + ), "strict": False, } - if schema.__doc__: - json_schema["description"] = schema.__doc__ + if schema_model.__doc__: + json_schema["description"] = schema_model.__doc__ return json_schema def __init__( self, format_type: RespFormatType = RespFormatType.TEXT, - schema: type | JsonSchema | None = None, - ): - """ - 响应格式 - :param format_type: 响应格式类型(默认为文本) - :param schema: 模板类或JsonSchema(仅当format_type为JSON Schema时有效) + schema: Type[BaseModel] | JsonSchema | None = None, + ) -> None: + """初始化响应格式对象。 + + Args: + format_type: 响应格式类型。 + schema: 模型类或 JSON Schema 包装结构,仅 `JSON_SCHEMA` 模式使用。 """ self.format_type: RespFormatType = format_type + self.schema_source: Type[BaseModel] | JsonSchema | None = schema + self.schema: JsonSchema | None = None - if format_type == RespFormatType.JSON_SCHEMA: - if schema is None: - raise ValueError("当format_type为'JSON_SCHEMA'时,schema不能为空") - if isinstance(schema, dict): - if check_msg := _json_schema_type_check(schema): - raise ValueError(f"schema格式不正确,{check_msg}") + if format_type != RespFormatType.JSON_SCHEMA: + return + if schema is None: + raise ValueError("当format_type为'JSON_SCHEMA'时,schema不能为空") + if isinstance(schema, dict): + if check_msg := _json_schema_type_check(schema): + raise ValueError(f"schema格式不正确,{check_msg}") + self.schema = cast(JsonSchema, deepcopy(schema)) + return + if isinstance(schema, type) and issubclass(schema, BaseModel): + try: + self.schema = self._generate_schema_from_model(schema) + except Exception as exc: + raise ValueError( + f"自动生成JSON Schema时发生异常,请检查模型类{schema.__name__}的定义,详细信息:\n" + f"{schema.__name__}:\n" + ) from exc + return + raise ValueError("schema必须是BaseModel的子类或JsonSchema") - self.schema = schema - elif issubclass(schema, BaseModel): - try: - json_schema = self._generate_schema_from_model(schema) + def get_schema_object(self) -> Dict[str, Any] | None: + """获取内部包装中的对象级 JSON Schema。 - self.schema = json_schema - except Exception as e: - raise ValueError( - f"自动生成JSON Schema时发生异常,请检查模型类{schema.__name__}的定义,详细信息:\n" - f"{schema.__name__}:\n" - ) from e - else: - raise ValueError("schema必须是BaseModel的子类或JsonSchema") - else: - self.schema = None - - def to_dict(self): + Returns: + Dict[str, Any] | None: 对象级 JSON Schema;不存在时返回 `None`。 """ - 将响应格式转换为字典 - :return: 字典 + if self.schema is None: + return None + schema_payload = self.schema.get("schema") + if isinstance(schema_payload, dict): + return cast(Dict[str, Any], deepcopy(schema_payload)) + return None + + def to_dict(self) -> Dict[str, Any]: + """将响应格式转换为字典。 + + Returns: + Dict[str, Any]: 序列化后的响应格式字典。 """ if self.schema: return { "format_type": self.format_type.value, "schema": self.schema, } - else: - return { - "format_type": self.format_type.value, - } + return { + "format_type": self.format_type.value, + } diff --git a/src/llm_models/payload_content/tool_option.py b/src/llm_models/payload_content/tool_option.py index 9fedbc86..ac5224cc 100644 --- a/src/llm_models/payload_content/tool_option.py +++ b/src/llm_models/payload_content/tool_option.py @@ -1,83 +1,368 @@ +from copy import deepcopy +from dataclasses import dataclass, field from enum import Enum +from typing import Any, Dict, List, Tuple, TypeAlias, cast -class ToolParamType(Enum): +class ToolParamType(str, Enum): + """工具参数类型。""" + + STRING = "string" + INTEGER = "integer" + NUMBER = "number" + FLOAT = "number" + BOOLEAN = "boolean" + ARRAY = "array" + OBJECT = "object" + + +LegacyToolParameterTuple = Tuple[str, ToolParamType, str, bool, List[str] | None] +"""旧版工具参数元组格式。""" + + +def normalize_tool_param_type(raw_value: ToolParamType | str | None) -> ToolParamType: + """将任意输入值规范化为内部工具参数类型。 + + Args: + raw_value: 原始参数类型值。 + + Returns: + ToolParamType: 规范化后的参数类型。未知值会回退为 `STRING`。 """ - 工具调用参数类型 + if isinstance(raw_value, ToolParamType): + return raw_value + + normalized_value = str(raw_value or "").strip().lower() + if normalized_value in {"integer", "int"}: + return ToolParamType.INTEGER + if normalized_value in {"number", "float"}: + return ToolParamType.NUMBER + if normalized_value in {"boolean", "bool"}: + return ToolParamType.BOOLEAN + if normalized_value == "array": + return ToolParamType.ARRAY + if normalized_value == "object": + return ToolParamType.OBJECT + return ToolParamType.STRING + + +def _is_object_schema(schema: Dict[str, Any]) -> bool: + """判断输入字典是否已经是对象级 JSON Schema。 + + Args: + schema: 待判断的字典。 + + Returns: + bool: 为对象级 JSON Schema 时返回 `True`。 """ - - STRING = "string" # 字符串 - INTEGER = "integer" # 整型 - FLOAT = "float" # 浮点型 - BOOLEAN = "bool" # 布尔型 + return schema.get("type") == "object" or "properties" in schema or "required" in schema +def _build_parameters_schema_from_property_map(property_map: Dict[str, Any]) -> Dict[str, Any]: + """将属性映射转换为对象级 JSON Schema。 + + Args: + property_map: 仅包含属性定义的映射。 + + Returns: + Dict[str, Any]: 对象级 JSON Schema。 + """ + required_names: List[str] = [] + normalized_properties: Dict[str, Any] = {} + for property_name, property_schema in property_map.items(): + if not isinstance(property_schema, dict): + continue + + property_schema_copy = deepcopy(property_schema) + is_required = bool(property_schema_copy.pop("required", False)) + if is_required: + required_names.append(str(property_name)) + normalized_properties[str(property_name)] = property_schema_copy + + parameters_schema: Dict[str, Any] = { + "type": "object", + "properties": normalized_properties, + } + if required_names: + parameters_schema["required"] = required_names + return parameters_schema + + +@dataclass(slots=True) class ToolParam: - """ - 工具调用参数 - """ + """工具参数定义。""" - def __init__( - self, + name: str + param_type: ToolParamType + description: str + required: bool + enum_values: List[Any] | None = None + items_schema: Dict[str, Any] | None = None + properties: Dict[str, Dict[str, Any]] | None = None + required_properties: List[str] = field(default_factory=list) + additional_properties: bool | Dict[str, Any] | None = None + default: Any = None + + def __post_init__(self) -> None: + """执行参数定义的基础校验。 + + Raises: + ValueError: 当参数名称或复杂类型定义不合法时抛出。 + """ + if not self.name: + raise ValueError("参数名称不能为空") + if self.param_type == ToolParamType.ARRAY and self.items_schema is None: + raise ValueError("数组参数必须提供 items_schema") + if self.param_type == ToolParamType.OBJECT and self.properties is None: + self.properties = {} + + @classmethod + def from_legacy_tuple(cls, parameter: LegacyToolParameterTuple) -> "ToolParam": + """从旧版五元组参数定义构建工具参数。 + + Args: + parameter: 旧版参数元组。 + + Returns: + ToolParam: 规范化后的工具参数对象。 + """ + return cls( + name=parameter[0], + param_type=parameter[1], + description=parameter[2], + required=parameter[3], + enum_values=parameter[4], + ) + + @classmethod + def from_dict( + cls, name: str, - param_type: ToolParamType, - description: str, - required: bool, - enum_values: list[str] | None = None, - ): + parameter_schema: Dict[str, Any], + *, + required: bool = False, + ) -> "ToolParam": + """从属性级 JSON Schema 或结构化参数字典构建工具参数。 + + Args: + name: 参数名称。 + parameter_schema: 参数对应的 Schema 或结构化定义。 + required: 参数是否必填。 + + Returns: + ToolParam: 规范化后的工具参数对象。 """ - 初始化工具调用参数 - (不应直接修改ToolParam类,而应使用ToolOptionBuilder类来构建对象) - :param name: 参数名称 - :param param_type: 参数类型 - :param description: 参数描述 - :param required: 是否必填 + raw_required_properties = parameter_schema.get("required_properties") + if raw_required_properties is None and isinstance(parameter_schema.get("required"), list): + raw_required_properties = parameter_schema.get("required") + return cls( + name=name, + param_type=normalize_tool_param_type(parameter_schema.get("param_type") or parameter_schema.get("type")), + description=str(parameter_schema.get("description", "") or ""), + required=required, + enum_values=deepcopy(parameter_schema.get("enum_values") or parameter_schema.get("enum")), + items_schema=deepcopy(parameter_schema.get("items_schema") or parameter_schema.get("items")), + properties=deepcopy(parameter_schema.get("properties")), + required_properties=list(raw_required_properties or []), + additional_properties=deepcopy( + parameter_schema["additional_properties"] + if "additional_properties" in parameter_schema + else parameter_schema.get("additionalProperties") + ), + default=deepcopy(parameter_schema.get("default")), + ) + + def to_json_schema(self) -> Dict[str, Any]: + """将参数定义转换为 JSON Schema。 + + Returns: + Dict[str, Any]: 参数对应的 JSON Schema 片段。 """ - self.name: str = name - self.param_type: ToolParamType = param_type - self.description: str = description - self.required: bool = required - self.enum_values: list[str] | None = enum_values + schema: Dict[str, Any] = { + "type": self.param_type.value, + "description": self.description, + } + if self.enum_values: + schema["enum"] = list(self.enum_values) + if self.default is not None: + schema["default"] = deepcopy(self.default) + if self.param_type == ToolParamType.ARRAY and self.items_schema is not None: + schema["items"] = deepcopy(self.items_schema) + if self.param_type == ToolParamType.OBJECT: + schema["properties"] = deepcopy(self.properties or {}) + if self.required_properties: + schema["required"] = list(self.required_properties) + if self.additional_properties is not None: + schema["additionalProperties"] = deepcopy(self.additional_properties) + return schema +@dataclass(slots=True) class ToolOption: - """ - 工具调用项 - """ + """工具定义。""" - def __init__( - self, - name: str, - description: str, - params: list[ToolParam] | None = None, - ): + name: str + description: str + params: List[ToolParam] | None = None + parameters_schema_override: Dict[str, Any] | None = None + + def __post_init__(self) -> None: + """执行工具定义的基础校验。 + + Raises: + ValueError: 当工具名称、描述或参数 Schema 不合法时抛出。 """ - 初始化工具调用项 - (不应直接修改ToolOption类,而应使用ToolOptionBuilder类来构建对象) - :param name: 工具名称 - :param description: 工具描述 - :param params: 工具参数列表 + if not self.name: + raise ValueError("工具名称不能为空") + if not self.description: + raise ValueError("工具描述不能为空") + if self.parameters_schema_override is not None: + schema_type = self.parameters_schema_override.get("type") + if schema_type != "object": + raise ValueError("工具参数 Schema 必须是 object 类型") + + @classmethod + def from_definition(cls, definition: Dict[str, Any]) -> "ToolOption": + """从任意支持的工具定义字典构建内部工具对象。 + + 支持以下输入形状: + - `{"name", "description", "parameters_schema"}` + - `{"name", "description", "parameters"}` + - OpenAI function tool:`{"type": "function", "function": {...}}` + - 仅属性映射的对象参数定义:`{"query": {"type": "string"}}` + + Args: + definition: 原始工具定义字典。 + + Returns: + ToolOption: 规范化后的工具定义对象。 + + Raises: + ValueError: 当工具定义缺少必要字段时抛出。 """ - self.name: str = name - self.description: str = description - self.params: list[ToolParam] | None = params + if definition.get("type") == "function" and isinstance(definition.get("function"), dict): + function_definition = cast(Dict[str, Any], definition["function"]) + return cls.from_definition( + { + "name": function_definition.get("name", ""), + "description": function_definition.get("description", ""), + "parameters_schema": function_definition.get("parameters"), + } + ) + + name = str(definition.get("name", "") or "").strip() + description = str(definition.get("description", "") or "").strip() + if not name: + raise ValueError("工具定义缺少 name") + if not description: + description = f"工具 {name}" + + parameters_schema = definition.get("parameters_schema") + if isinstance(parameters_schema, dict): + normalized_schema = deepcopy(parameters_schema) + if not _is_object_schema(normalized_schema): + normalized_schema = _build_parameters_schema_from_property_map(normalized_schema) + return cls( + name=name, + description=description, + params=None, + parameters_schema_override=normalized_schema, + ) + + raw_parameters = definition.get("parameters") + if isinstance(raw_parameters, dict): + normalized_schema = deepcopy(raw_parameters) + if not _is_object_schema(normalized_schema): + normalized_schema = _build_parameters_schema_from_property_map(normalized_schema) + return cls( + name=name, + description=description, + params=None, + parameters_schema_override=normalized_schema, + ) + + if isinstance(raw_parameters, list): + params: List[ToolParam] = [] + for raw_parameter in raw_parameters: + if isinstance(raw_parameter, tuple) and len(raw_parameter) == 5: + params.append(ToolParam.from_legacy_tuple(raw_parameter)) + continue + if isinstance(raw_parameter, dict): + parameter_name = str(raw_parameter.get("name", "") or "").strip() + if not parameter_name: + continue + params.append( + ToolParam.from_dict( + parameter_name, + raw_parameter, + required=bool(raw_parameter.get("required", False)), + ) + ) + return cls( + name=name, + description=description, + params=params or None, + parameters_schema_override=None, + ) + + return cls(name=name, description=description, params=None, parameters_schema_override=None) + + @property + def parameters_schema(self) -> Dict[str, Any] | None: + """获取工具参数的对象级 JSON Schema。 + + Returns: + Dict[str, Any] | None: 工具参数 Schema。无参数工具时返回 `None`。 + """ + if self.parameters_schema_override is not None: + return deepcopy(self.parameters_schema_override) + if not self.params: + return None + return { + "type": "object", + "properties": {param.name: param.to_json_schema() for param in self.params}, + "required": [param.name for param in self.params if param.required], + } + + def to_openai_function_schema(self) -> Dict[str, Any]: + """转换为 OpenAI function calling 结构。 + + Returns: + Dict[str, Any]: OpenAI 兼容的工具定义。 + """ + function_schema: Dict[str, Any] = { + "name": self.name, + "description": self.description, + } + if self.parameters_schema is not None: + function_schema["parameters"] = self.parameters_schema + return { + "type": "function", + "function": function_schema, + } class ToolOptionBuilder: - """ - 工具调用项构建器 - """ + """工具定义构建器。""" - def __init__(self): + def __init__(self) -> None: + """初始化构建器。""" self.__name: str = "" self.__description: str = "" - self.__params: list[ToolParam] = [] + self.__params: List[ToolParam] = [] + self.__parameters_schema_override: Dict[str, Any] | None = None def set_name(self, name: str) -> "ToolOptionBuilder": - """ - 设置工具名称 - :param name: 工具名称 - :return: ToolBuilder实例 + """设置工具名称。 + + Args: + name: 工具名称。 + + Returns: + ToolOptionBuilder: 当前构建器实例。 + + Raises: + ValueError: 当名称为空时抛出。 """ if not name: raise ValueError("工具名称不能为空") @@ -85,35 +370,76 @@ class ToolOptionBuilder: return self def set_description(self, description: str) -> "ToolOptionBuilder": - """ - 设置工具描述 - :param description: 工具描述 - :return: ToolBuilder实例 + """设置工具描述。 + + Args: + description: 工具描述。 + + Returns: + ToolOptionBuilder: 当前构建器实例。 + + Raises: + ValueError: 当描述为空时抛出。 """ if not description: raise ValueError("工具描述不能为空") self.__description = description return self + def set_parameters_schema(self, schema: Dict[str, Any]) -> "ToolOptionBuilder": + """直接设置完整的参数对象 Schema。 + + Args: + schema: 完整的对象级 JSON Schema。 + + Returns: + ToolOptionBuilder: 当前构建器实例。 + + Raises: + ValueError: 当 schema 不是 object 类型时抛出。 + """ + if schema.get("type") != "object": + raise ValueError("工具参数 Schema 必须是 object 类型") + self.__parameters_schema_override = deepcopy(schema) + self.__params.clear() + return self + def add_param( self, name: str, param_type: ToolParamType, description: str, required: bool = False, - enum_values: list[str] | None = None, + enum_values: List[Any] | None = None, + *, + items_schema: Dict[str, Any] | None = None, + properties: Dict[str, Dict[str, Any]] | None = None, + required_properties: List[str] | None = None, + additional_properties: bool | Dict[str, Any] | None = None, + default: Any = None, ) -> "ToolOptionBuilder": - """ - 添加工具参数 - :param name: 参数名称 - :param param_type: 参数类型 - :param description: 参数描述 - :param required: 是否必填(默认为False) - :return: ToolBuilder实例 - """ - if not name or not description: - raise ValueError("参数名称/描述不能为空") + """添加一个参数定义。 + Args: + name: 参数名称。 + param_type: 参数类型。 + description: 参数描述。 + required: 参数是否必填。 + enum_values: 可选的枚举值列表。 + items_schema: 数组参数的元素 Schema。 + properties: 对象参数的属性定义。 + required_properties: 对象参数内部的必填字段。 + additional_properties: 对象参数是否允许额外字段。 + default: 参数默认值。 + + Returns: + ToolOptionBuilder: 当前构建器实例。 + + Raises: + ValueError: 当构建器已经设置完整 Schema 时抛出。 + """ + if self.__parameters_schema_override is not None: + raise ValueError("已设置完整参数 Schema,不能再逐项添加参数") self.__params.append( ToolParam( name=name, @@ -121,43 +447,83 @@ class ToolOptionBuilder: description=description, required=required, enum_values=enum_values, + items_schema=deepcopy(items_schema), + properties=deepcopy(properties), + required_properties=list(required_properties or []), + additional_properties=deepcopy(additional_properties), + default=deepcopy(default), ) ) - return self - def build(self): - """ - 构建工具调用项 - :return: 工具调用项 - """ - if self.__name == "" or self.__description == "": - raise ValueError("工具名称/描述不能为空") + def build(self) -> ToolOption: + """构建工具定义。 + Returns: + ToolOption: 构建完成的工具定义。 + + Raises: + ValueError: 当工具名称或描述缺失时抛出。 + """ + if not self.__name or not self.__description: + raise ValueError("工具名称和描述不能为空") return ToolOption( name=self.__name, description=self.__description, - params=None if len(self.__params) == 0 else self.__params, + params=None if not self.__params else list(self.__params), + parameters_schema_override=deepcopy(self.__parameters_schema_override), ) -class ToolCall: - """ - 来自模型反馈的工具调用 - """ +ToolDefinitionInput: TypeAlias = ToolOption | Dict[str, Any] +"""统一的工具定义输入类型。""" - def __init__( - self, - call_id: str, - func_name: str, - args: dict | None = None, - ): + +def normalize_tool_option(tool_definition: ToolDefinitionInput) -> ToolOption: + """将任意支持的工具输入规范化为内部 `ToolOption`。 + + Args: + tool_definition: 原始工具定义输入。 + + Returns: + ToolOption: 规范化后的工具定义对象。 + """ + if isinstance(tool_definition, ToolOption): + return tool_definition + return ToolOption.from_definition(tool_definition) + + +def normalize_tool_options( + tool_definitions: List[ToolDefinitionInput] | None, +) -> List[ToolOption] | None: + """批量规范化工具定义列表。 + + Args: + tool_definitions: 原始工具定义列表。 + + Returns: + List[ToolOption] | None: 规范化后的工具列表;输入为空时返回 `None`。 + """ + if not tool_definitions: + return None + return [normalize_tool_option(tool_definition) for tool_definition in tool_definitions] + + +@dataclass(slots=True) +class ToolCall: + """来自模型输出的工具调用。""" + + call_id: str + func_name: str + args: Dict[str, Any] | None = None + + def __post_init__(self) -> None: + """执行工具调用的基础校验。 + + Raises: + ValueError: 当工具调用标识或函数名缺失时抛出。 """ - 初始化工具调用 - :param call_id: 工具调用ID - :param func_name: 要调用的函数名称 - :param args: 工具调用参数 - """ - self.call_id: str = call_id - self.func_name: str = func_name - self.args: dict | None = args + if not self.call_id: + raise ValueError("工具调用 ID 不能为空") + if not self.func_name: + raise ValueError("工具函数名称不能为空") diff --git a/src/llm_models/utils_model.py b/src/llm_models/utils_model.py index a3bfb74f..775fa663 100644 --- a/src/llm_models/utils_model.py +++ b/src/llm_models/utils_model.py @@ -1,29 +1,50 @@ -import re -import asyncio -import time -import random -import json - +from dataclasses import dataclass from enum import Enum +from typing import Any, Callable, Dict, List, Optional, Set, Tuple + from rich.traceback import install -from typing import Tuple, List, Dict, Optional, Callable, Any, Set + +import asyncio +import random +import re +import time import traceback from src.common.logger import get_logger +from src.common.data_models.llm_service_data_models import ( + LLMAudioTranscriptionResult, + LLMEmbeddingResult, + LLMResponseResult, +) from src.config.config import config_manager from src.config.model_configs import APIProvider, ModelInfo, TaskConfig -from .payload_content.message import MessageBuilder, Message -from .payload_content.resp_format import RespFormat, RespFormatType -from .payload_content.tool_option import ToolOption, ToolCall, ToolOptionBuilder, ToolParamType -from .model_client.base_client import BaseClient, APIResponse, client_registry -from .model_client import ensure_configured_clients_loaded -from .utils import compress_messages, llm_usage_recorder -from .exceptions import ( - NetworkConnectionError, - RespNotOkException, +from src.llm_models.exceptions import ( EmptyResponseException, ModelAttemptFailed, + NetworkConnectionError, + ReqAbortException, + RespNotOkException, + RespParseException, ) +from src.llm_models.model_client import ensure_configured_clients_loaded +from src.llm_models.model_client.base_client import ( + APIResponse, + AudioTranscriptionRequest, + BaseClient, + ClientRequest, + EmbeddingRequest, + ResponseRequest, + client_registry, +) +from src.llm_models.payload_content.message import Message, MessageBuilder +from src.llm_models.payload_content.resp_format import RespFormat +from src.llm_models.payload_content.tool_option import ( + ToolCall, + ToolDefinitionInput, + ToolOption, + normalize_tool_options, +) +from src.llm_models.utils import compress_messages, llm_usage_recorder install(extra_lines=3) @@ -38,106 +59,69 @@ class RequestType(Enum): AUDIO = "audio" -class LLMRequest: - """LLM请求类""" +@dataclass(slots=True) +class LLMExecutionResult: + """单次模型执行结果。""" - def __init__(self, model_set: TaskConfig, request_type: str = "") -> None: - self.task_name = request_type - self.model_for_task = model_set + api_response: APIResponse + model_info: ModelInfo + + +class LLMOrchestrator: + """LLM 编排调度器。""" + + def __init__(self, task_name: str, request_type: str = "") -> None: + """初始化 LLM 请求调度器。 + + Args: + task_name: 任务配置名称,对应 `model_task_config` 下的字段名。 + request_type: 当前请求的业务类型标识。 + """ + self.task_name = task_name.strip() self.request_type = request_type - self._task_config_signature = self._build_task_config_signature(model_set) - self._task_config_name = self._resolve_task_config_name(model_set) + self.model_for_task = self._get_task_config_or_raise() self.model_usage: Dict[str, Tuple[int, int, int]] = { model: (0, 0, 0) for model in self.model_for_task.model_list } """模型使用量记录,用于进行负载均衡,对应为(total_tokens, penalty, usage_penalty),惩罚值是为了能在某个模型请求不给力或正在被使用的时候进行调整""" - @staticmethod - def _build_task_config_signature(model_set: TaskConfig) -> tuple: - return ( - tuple(model_set.model_list), - model_set.selection_strategy, - model_set.temperature, - model_set.max_tokens, - model_set.slow_threshold, - ) + def _get_task_config_or_raise(self) -> TaskConfig: + """获取当前任务名对应的最新任务配置。 - @staticmethod - def _iter_task_config_items(model_task_config: Any) -> list[tuple[str, TaskConfig]]: - cls = type(model_task_config) - if hasattr(cls, "model_fields"): - attrs = [name for name in cls.model_fields.keys() if not name.startswith("__")] - else: - attrs = [name for name in dir(model_task_config) if not name.startswith("__")] + Returns: + TaskConfig: 当前任务对应的最新任务配置对象。 - items: list[tuple[str, TaskConfig]] = [] - for attr in attrs: - value = getattr(model_task_config, attr, None) - if isinstance(value, TaskConfig): - items.append((attr, value)) - return items + Raises: + ValueError: 当任务名为空或对应配置不存在时抛出。 + """ + if not self.task_name: + raise ValueError("任务配置名称不能为空") - def _resolve_task_config_by_signature(self, model_set: TaskConfig) -> Optional[str]: - target_signature = self._build_task_config_signature(model_set) model_task_config = config_manager.get_model_config().model_task_config - return next( - ( - attr - for attr, value in self._iter_task_config_items(model_task_config) - if self._build_task_config_signature(value) == target_signature - ), - None, - ) - - def _resolve_task_config_name(self, model_set: TaskConfig) -> Optional[str]: - try: - model_task_config = config_manager.get_model_config().model_task_config - except Exception: - return None - for attr, value in self._iter_task_config_items(model_task_config): - if value is model_set: - return attr - try: - return self._resolve_task_config_by_signature(model_set) - except Exception: - return None - return None - - def _get_latest_task_config(self) -> TaskConfig: - if self._task_config_name: - try: - model_task_config = config_manager.get_model_config().model_task_config - value = getattr(model_task_config, self._task_config_name, None) - if isinstance(value, TaskConfig): - return value - except Exception: - return self.model_for_task - try: - if resolved_name := self._resolve_task_config_by_signature(self.model_for_task): - self._task_config_name = resolved_name - model_task_config = config_manager.get_model_config().model_task_config - value = getattr(model_task_config, resolved_name, None) - if isinstance(value, TaskConfig): - return value - except Exception: - return self.model_for_task - return self.model_for_task + task_config = getattr(model_task_config, self.task_name, None) + if not isinstance(task_config, TaskConfig): + raise ValueError(f"未找到名为 '{self.task_name}' 的任务配置") + return task_config def _refresh_task_config(self) -> TaskConfig: - latest = self._get_latest_task_config() + """刷新并同步任务配置缓存。 + + Returns: + TaskConfig: 刷新后的任务配置对象。 + """ + latest = self._get_task_config_or_raise() if latest is not self.model_for_task: self.model_for_task = latest - self._task_config_signature = self._build_task_config_signature(latest) if list(self.model_usage.keys()) != latest.model_list: self.model_usage = {model: self.model_usage.get(model, (0, 0, 0)) for model in latest.model_list} return self.model_for_task def _check_slow_request(self, time_cost: float, model_name: str) -> None: - """检查请求是否过慢并输出警告日志 + """检查请求是否过慢并输出警告日志。 Args: - time_cost: 请求耗时(秒) - model_name: 使用的模型名称 + time_cost: 请求耗时(秒)。 + model_name: 使用的模型名称。 """ threshold = self.model_for_task.slow_threshold if time_cost > threshold: @@ -147,6 +131,31 @@ class LLMRequest: f" 如果你认为该警告出现得过于频繁,请调整model_config.toml中对应任务的slow_threshold至符合你实际情况的合理值" ) + @staticmethod + def _build_generation_result( + content: str, + reasoning_content: str, + model_name: str, + tool_calls: List[ToolCall] | None, + ) -> LLMResponseResult: + """构建统一的文本响应结果。 + + Args: + content: 模型返回的正文内容。 + reasoning_content: 模型返回的推理内容。 + model_name: 实际使用的模型名称。 + tool_calls: 模型返回的工具调用列表。 + + Returns: + LLMResponseResult: 统一文本响应结果对象。 + """ + return LLMResponseResult( + response=content, + reasoning=reasoning_content, + model_name=model_name, + tool_calls=tool_calls, + ) + async def generate_response_for_image( self, prompt: str, @@ -154,15 +163,20 @@ class LLMRequest: image_format: str, temperature: Optional[float] = None, max_tokens: Optional[int] = None, - ) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]: - """ - 为图像生成响应 + interrupt_flag: asyncio.Event | None = None, + ) -> LLMResponseResult: + """为图像生成响应。 + Args: - prompt (str): 提示词 - image_base64 (str): 图像的Base64编码字符串 - image_format (str): 图像格式(如 'png', 'jpeg' 等) + prompt: 文本提示词。 + image_base64: 图像的 Base64 编码字符串。 + image_format: 图像格式,例如 `png`、`jpeg`。 + temperature: 显式指定的温度参数。 + max_tokens: 显式指定的最大输出 token 数。 + interrupt_flag: 外部中断标记;被设置时会尽快终止请求。 + Returns: - (Tuple[str, str, str, Optional[List[ToolCall]]]): 响应内容、推理内容、模型名称、工具调用列表 + LLMResponseResult: 统一文本响应结果对象。 """ self._refresh_task_config() start_time = time.time() @@ -175,12 +189,15 @@ class LLMRequest: ) return [message_builder.build()] - response, model_info = await self._execute_request( + execution_result = await self._execute_request( request_type=RequestType.RESPONSE, message_factory=message_factory, temperature=temperature, max_tokens=max_tokens, + interrupt_flag=interrupt_flag, ) + response = execution_result.api_response + model_info = execution_result.model_info content = response.content or "" reasoning_content = response.reasoning_content or "" tool_calls = response.tool_calls @@ -198,44 +215,49 @@ class LLMRequest: endpoint="/chat/completions", time_cost=time_cost, ) - return content, (reasoning_content, model_info.name, tool_calls) + return self._build_generation_result(content, reasoning_content, model_info.name, tool_calls) + + async def generate_response_for_voice(self, voice_base64: str) -> LLMAudioTranscriptionResult: + """为语音生成转录响应。 - async def generate_response_for_voice(self, voice_base64: str) -> Optional[str]: - """ - 为语音生成响应 Args: - voice_base64 (str): 语音的Base64编码字符串 + voice_base64: 语音的 Base64 编码字符串。 + Returns: - (Optional[str]): 生成的文本描述或None + LLMAudioTranscriptionResult: 语音转写结果对象。 """ self._refresh_task_config() - response, _ = await self._execute_request( + execution_result = await self._execute_request( request_type=RequestType.AUDIO, audio_base64=voice_base64, ) - return response.content or None + return LLMAudioTranscriptionResult(text=execution_result.api_response.content or None) async def generate_response_async( self, prompt: str, temperature: Optional[float] = None, max_tokens: Optional[int] = None, - tools: Optional[List[Dict[str, Any]]] = None, + tools: List[ToolDefinitionInput] | None = None, response_format: RespFormat | None = None, raise_when_empty: bool = True, - ) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]: - """ - 异步生成响应 + interrupt_flag: asyncio.Event | None = None, + ) -> LLMResponseResult: + """异步生成文本响应。 + Args: - prompt (str): 提示词 - temperature (float, optional): 温度参数 - max_tokens (int, optional): 最大token数 - tools (Optional[List[Dict[str, Any]]]): 工具列表 - response_format (RespFormat | None): 响应格式 - raise_when_empty (bool): 当响应为空时是否抛出异常 + prompt: 提示词。 + temperature: 显式指定的温度参数。 + max_tokens: 显式指定的最大输出 token 数。 + tools: 原始工具定义列表。 + response_format: 响应格式约束。 + raise_when_empty: 保留字段,当前版本暂未单独使用。 + interrupt_flag: 外部中断标记;被设置时会尽快终止请求。 + Returns: - (Tuple[str, str, str, Optional[List[ToolCall]]]): 响应内容、推理内容、模型名称、工具调用列表 + LLMResponseResult: 统一文本响应结果对象。 """ + del raise_when_empty self._refresh_task_config() start_time = time.time() @@ -246,14 +268,17 @@ class LLMRequest: tool_built = self._build_tool_options(tools) - response, model_info = await self._execute_request( + execution_result = await self._execute_request( request_type=RequestType.RESPONSE, message_factory=message_factory, temperature=temperature, max_tokens=max_tokens, tool_options=tool_built, response_format=response_format, + interrupt_flag=interrupt_flag, ) + response = execution_result.api_response + model_info = execution_result.model_info logger.debug(f"LLM请求总耗时: {time.time() - start_time}") logger.debug(f"LLM生成内容: {response}") @@ -273,42 +298,56 @@ class LLMRequest: endpoint="/chat/completions", time_cost=time.time() - start_time, ) - return content or "", (reasoning_content, model_info.name, tool_calls) + return self._build_generation_result(content or "", reasoning_content, model_info.name, tool_calls) async def generate_response_with_message_async( self, message_factory: Callable[[BaseClient], List[Message]], temperature: Optional[float] = None, max_tokens: Optional[int] = None, - tools: Optional[List[Dict[str, Any]]] = None, + tools: List[ToolDefinitionInput] | None = None, response_format: RespFormat | None = None, raise_when_empty: bool = True, - ) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]: - """ - 异步生成响应 + interrupt_flag: asyncio.Event | None = None, + ) -> LLMResponseResult: + """基于外部消息工厂异步生成响应。 + Args: - message_factory (Callable[[BaseClient], List[Message]]): 已构建好的消息工厂 - temperature (float, optional): 温度参数 - max_tokens (int, optional): 最大token数 - tools (Optional[List[Dict[str, Any]]]): 工具列表 - response_format (RespFormat | None): 响应格式 - raise_when_empty (bool): 当响应为空时是否抛出异常 + message_factory: 消息工厂,会根据客户端能力构建消息列表。 + temperature: 显式指定的温度参数。 + max_tokens: 显式指定的最大输出 token 数。 + tools: 原始工具定义列表。 + response_format: 响应格式约束。 + raise_when_empty: 保留字段,当前版本暂未单独使用。 + interrupt_flag: 外部中断标记;被设置时会尽快终止请求。 + Returns: - (Tuple[str, str, str, Optional[List[ToolCall]]]): 响应内容、推理内容、模型名称、工具调用列表 + LLMResponseResult: 统一文本响应结果对象。 """ + del raise_when_empty self._refresh_task_config() start_time = time.time() tool_built = self._build_tool_options(tools) + if self.request_type.startswith("maisaka_"): + logger.info(f"LLMOrchestrator[{self.request_type}] 已构建 {len(tool_built or [])} 个内部工具选项") - response, model_info = await self._execute_request( + execution_result = await self._execute_request( request_type=RequestType.RESPONSE, message_factory=message_factory, temperature=temperature, max_tokens=max_tokens, tool_options=tool_built, response_format=response_format, + interrupt_flag=interrupt_flag, ) + response = execution_result.api_response + model_info = execution_result.model_info + if self.request_type.startswith("maisaka_"): + logger.info( + f"LLMOrchestrator[{self.request_type}] generate_response_with_message_async 执行完成 " + f"(model={model_info.name}, time_cost={time.time() - start_time:.2f}s)" + ) time_cost = time.time() - start_time logger.debug(f"LLM请求总耗时: {time_cost}") @@ -330,116 +369,25 @@ class LLMRequest: endpoint="/chat/completions", time_cost=time_cost, ) - return content or "", (reasoning_content, model_info.name, tool_calls) + return self._build_generation_result(content or "", reasoning_content, model_info.name, tool_calls) - async def generate_structured_response_async( - self, - prompt: str, - schema: type | dict[str, Any], - fallback_result: dict[str, Any] | None = None, - temperature: Optional[float] = 0.0, - max_tokens: Optional[int] = None, - ) -> Tuple[dict[str, Any], Tuple[str, str, Optional[List[ToolCall]]], bool]: - """ - 结构化输出快速接口: - - 默认启用 JSON_SCHEMA 严格模式 - - 单模型单次尝试(不重试、不切换模型) - - 失败时立即返回 fallback_result + async def get_embedding(self, embedding_input: str) -> LLMEmbeddingResult: + """获取嵌入向量。 - Returns: - (结构化结果, (推理内容, 模型名, 工具调用), 是否成功) - """ - self._refresh_task_config() - start_time = time.time() - - message_builder = MessageBuilder() - message_builder.add_text_content(prompt) - message_list = [message_builder.build()] - - response_format = RespFormat(schema=schema, format_type=RespFormatType.JSON_SCHEMA) - if response_format.schema: - response_format.schema["strict"] = True - - model_info, api_provider, client = self._select_model() - fallback_data = fallback_result or {} - - try: - response = await self._attempt_request_on_model( - model_info=model_info, - api_provider=api_provider, - client=client, - request_type=RequestType.RESPONSE, - message_list=message_list, - tool_options=None, - response_format=response_format, - stream_response_handler=None, - async_response_parser=None, - temperature=temperature, - max_tokens=max_tokens, - embedding_input=None, - audio_base64=None, - retry_limit=1, - ) - - time_cost = time.time() - start_time - self._check_slow_request(time_cost, model_info.name) - - reasoning_content = response.reasoning_content or "" - tool_calls = response.tool_calls - - parsed_result: dict[str, Any] | None = None - if response.content: - try: - parsed = json.loads(response.content) - if isinstance(parsed, dict): - parsed_result = parsed - except json.JSONDecodeError: - parsed_result = None - - if parsed_result is None: - logger.warning(f"结构化输出解析失败,使用降级结果。模型: {model_info.name}") - total_tokens, penalty, usage_penalty = self.model_usage[model_info.name] - self.model_usage[model_info.name] = (total_tokens, penalty + 1, max(usage_penalty - 1, 0)) - return fallback_data, (reasoning_content, model_info.name, tool_calls), False - - total_tokens, penalty, usage_penalty = self.model_usage[model_info.name] - if response_usage := response.usage: - total_tokens += response_usage.total_tokens - llm_usage_recorder.record_usage_to_database( - model_info=model_info, - model_usage=response_usage, - user_id="system", - request_type=self.request_type, - endpoint="/chat/completions", - time_cost=time_cost, - ) - self.model_usage[model_info.name] = (total_tokens, penalty, max(usage_penalty - 1, 0)) - return parsed_result, (reasoning_content, model_info.name, tool_calls), True - - except Exception as e: - time_cost = time.time() - start_time - self._check_slow_request(time_cost, model_info.name) - logger.warning(f"结构化输出请求失败,直接降级。模型: {model_info.name}, 错误: {e}") - - total_tokens, penalty, usage_penalty = self.model_usage[model_info.name] - self.model_usage[model_info.name] = (total_tokens, penalty + 1, max(usage_penalty - 1, 0)) - - return fallback_data, ("", model_info.name, None), False - - async def get_embedding(self, embedding_input: str) -> Tuple[List[float], str]: - """ - 获取嵌入向量 Args: - embedding_input (str): 获取嵌入的目标 + embedding_input: 待编码的文本。 + Returns: - (Tuple[List[float], str]): (嵌入向量,使用的模型名称) + LLMEmbeddingResult: 向量生成结果对象。 """ self._refresh_task_config() start_time = time.time() - response, model_info = await self._execute_request( + execution_result = await self._execute_request( request_type=RequestType.EMBEDDING, embedding_input=embedding_input, ) + response = execution_result.api_response + model_info = execution_result.model_info embedding = response.embedding if usage := response.usage: llm_usage_recorder.record_usage_to_database( @@ -452,11 +400,207 @@ class LLMRequest: ) if not embedding: raise RuntimeError("获取embedding失败") - return embedding, model_info.name + return LLMEmbeddingResult(embedding=embedding, model_name=model_info.name) + + def _resolve_effective_temperature( + self, + model_info: ModelInfo, + temperature: Optional[float], + ) -> Optional[float]: + """解析响应请求最终使用的温度参数。 + + Args: + model_info: 当前模型信息。 + temperature: 调用方显式传入的温度。 + + Returns: + Optional[float]: 最终生效的温度参数。 + """ + if temperature is not None: + return temperature + if model_info.temperature is not None: + return model_info.temperature + if "temperature" in model_info.extra_params: + return model_info.extra_params["temperature"] + return self.model_for_task.temperature + + def _resolve_effective_max_tokens( + self, + model_info: ModelInfo, + max_tokens: Optional[int], + ) -> Optional[int]: + """解析响应请求最终使用的最大输出 token 数。 + + Args: + model_info: 当前模型信息。 + max_tokens: 调用方显式传入的最大 token 数。 + + Returns: + Optional[int]: 最终生效的最大 token 数。 + """ + if max_tokens is not None: + return max_tokens + if model_info.max_tokens is not None: + return model_info.max_tokens + if "max_tokens" in model_info.extra_params: + return model_info.extra_params["max_tokens"] + return self.model_for_task.max_tokens + + def _build_response_request( + self, + model_info: ModelInfo, + message_list: List[Message], + tool_options: List[ToolOption] | None, + response_format: RespFormat | None, + stream_response_handler: Optional[Callable[..., Any]], + async_response_parser: Optional[Callable[..., Any]], + interrupt_flag: asyncio.Event | None, + temperature: Optional[float], + max_tokens: Optional[int], + ) -> ResponseRequest: + """构建统一响应请求对象。 + + Args: + model_info: 当前模型信息。 + message_list: 请求消息列表。 + tool_options: 工具定义列表。 + response_format: 输出格式定义。 + stream_response_handler: 流式响应处理函数。 + async_response_parser: 非流式响应解析函数。 + interrupt_flag: 外部中断标记。 + temperature: 调用方显式传入的温度。 + max_tokens: 调用方显式传入的最大 token 数。 + + Returns: + ResponseRequest: 统一响应请求对象。 + """ + return ResponseRequest( + model_info=model_info, + message_list=list(message_list), + tool_options=None if tool_options is None else list(tool_options), + max_tokens=self._resolve_effective_max_tokens(model_info, max_tokens), + temperature=self._resolve_effective_temperature(model_info, temperature), + response_format=response_format, + stream_response_handler=stream_response_handler, + async_response_parser=async_response_parser, + interrupt_flag=interrupt_flag, + extra_params=dict(model_info.extra_params), + ) + + @staticmethod + def _build_embedding_request( + model_info: ModelInfo, + embedding_input: str, + ) -> EmbeddingRequest: + """构建统一嵌入请求对象。 + + Args: + model_info: 当前模型信息。 + embedding_input: 嵌入输入文本。 + + Returns: + EmbeddingRequest: 统一嵌入请求对象。 + """ + return EmbeddingRequest( + model_info=model_info, + embedding_input=embedding_input, + extra_params=dict(model_info.extra_params), + ) + + @staticmethod + def _build_audio_transcription_request( + model_info: ModelInfo, + audio_base64: str, + max_tokens: Optional[int] = None, + ) -> AudioTranscriptionRequest: + """构建统一音频转录请求对象。 + + Args: + model_info: 当前模型信息。 + audio_base64: Base64 编码的音频数据。 + max_tokens: 调用方显式传入的最大 token 数。 + + Returns: + AudioTranscriptionRequest: 统一音频转录请求对象。 + """ + return AudioTranscriptionRequest( + model_info=model_info, + audio_base64=audio_base64, + max_tokens=max_tokens, + extra_params=dict(model_info.extra_params), + ) + + def _build_client_request( + self, + request_type: RequestType, + model_info: ModelInfo, + message_list: List[Message], + tool_options: List[ToolOption] | None, + response_format: RespFormat | None, + stream_response_handler: Optional[Callable[..., Any]], + async_response_parser: Optional[Callable[..., Any]], + interrupt_flag: asyncio.Event | None, + temperature: Optional[float], + max_tokens: Optional[int], + embedding_input: str | None, + audio_base64: str | None, + ) -> ClientRequest: + """按请求类型构建统一客户端请求对象。 + + Args: + request_type: 请求类型。 + model_info: 当前模型信息。 + message_list: 请求消息列表。 + tool_options: 工具定义列表。 + response_format: 响应格式定义。 + stream_response_handler: 流式响应处理函数。 + async_response_parser: 非流式响应解析函数。 + interrupt_flag: 外部中断标记。 + temperature: 调用方显式传入的温度。 + max_tokens: 调用方显式传入的最大 token 数。 + embedding_input: 嵌入输入文本。 + audio_base64: Base64 编码的音频数据。 + + Returns: + ClientRequest: 对应请求类型的统一请求对象。 + + Raises: + ValueError: 请求类型未知或缺少必需字段时抛出。 + """ + if request_type == RequestType.RESPONSE: + return self._build_response_request( + model_info=model_info, + message_list=message_list, + tool_options=tool_options, + response_format=response_format, + stream_response_handler=stream_response_handler, + async_response_parser=async_response_parser, + interrupt_flag=interrupt_flag, + temperature=temperature, + max_tokens=max_tokens, + ) + if request_type == RequestType.EMBEDDING: + if embedding_input is None: + raise ValueError("嵌入输入不能为空") + return self._build_embedding_request(model_info=model_info, embedding_input=embedding_input) + if request_type == RequestType.AUDIO: + if audio_base64 is None: + raise ValueError("音频 Base64 不能为空") + return self._build_audio_transcription_request( + model_info=model_info, + audio_base64=audio_base64, + max_tokens=max_tokens, + ) + raise ValueError(f"不支持的请求类型: {request_type}") def _select_model(self, exclude_models: Optional[Set[str]] = None) -> Tuple[ModelInfo, APIProvider, BaseClient]: - """ - 根据配置的策略选择模型:balance(负载均衡)或 random(随机选择) + """根据策略选择一个可用模型。 + + Args: + exclude_models: 本次请求中需要排除的模型名称集合。 + + Returns: + Tuple[ModelInfo, APIProvider, BaseClient]: 选中的模型、提供商与客户端实例。 """ self._refresh_task_config() available_models = { @@ -499,75 +643,38 @@ class LLMRequest: async def _attempt_request_on_model( self, - model_info: ModelInfo, api_provider: APIProvider, client: BaseClient, - request_type: RequestType, - message_list: List[Message], - tool_options: list[ToolOption] | None, - response_format: RespFormat | None, - stream_response_handler: Optional[Callable[..., Any]], - async_response_parser: Optional[Callable[..., Any]], - temperature: Optional[float], - max_tokens: Optional[int], - embedding_input: str | None, - audio_base64: str | None, + request: ClientRequest, retry_limit: Optional[int] = None, ) -> APIResponse: - """ - 在单个模型上执行请求,包含针对临时错误的重试逻辑。 - 如果成功,返回APIResponse。如果失败(重试耗尽或硬错误),则抛出ModelAttemptFailed异常。 + """在单个模型上执行请求,并处理重试逻辑。 + + Args: + api_provider: 当前请求对应的 API 提供商配置。 + client: 已初始化的客户端实例。 + request: 统一客户端请求对象。 + retry_limit: 显式指定的重试次数;未指定时使用 Provider 配置。 + + Returns: + APIResponse: 统一响应对象。 + + Raises: + ModelAttemptFailed: 当当前模型重试耗尽或遇到硬错误时抛出。 """ retry_remain = retry_limit if retry_limit is not None else api_provider.max_retry retry_remain = max(1, retry_remain) - compressed_messages: Optional[List[Message]] = None + model_info = request.model_info + original_response_request = request if isinstance(request, ResponseRequest) else None + active_request: ClientRequest = request while retry_remain > 0: try: - if request_type == RequestType.RESPONSE: - # 温度优先级:参数传入 > 模型级别配置 > extra_params > 任务配置 - effective_temperature = temperature - if effective_temperature is None: - effective_temperature = model_info.temperature - if effective_temperature is None: - effective_temperature = (model_info.extra_params or {}).get("temperature") - if effective_temperature is None: - effective_temperature = self.model_for_task.temperature - - # max_tokens 优先级:参数传入 > 模型级别配置 > extra_params > 任务配置 - effective_max_tokens = max_tokens - if effective_max_tokens is None: - effective_max_tokens = model_info.max_tokens - if effective_max_tokens is None: - effective_max_tokens = (model_info.extra_params or {}).get("max_tokens") - if effective_max_tokens is None: - effective_max_tokens = self.model_for_task.max_tokens - - return await client.get_response( - model_info=model_info, - message_list=(compressed_messages or message_list), - tool_options=tool_options, - max_tokens=effective_max_tokens, - temperature=effective_temperature, - response_format=response_format, - stream_response_handler=stream_response_handler, - async_response_parser=async_response_parser, - extra_params=model_info.extra_params, - ) - elif request_type == RequestType.EMBEDDING: - assert embedding_input is not None, "嵌入输入不能为空" - return await client.get_embedding( - model_info=model_info, - embedding_input=embedding_input, - extra_params=model_info.extra_params, - ) - elif request_type == RequestType.AUDIO: - assert audio_base64 is not None, "音频Base64不能为空" - return await client.get_audio_transcriptions( - model_info=model_info, - audio_base64=audio_base64, - extra_params=model_info.extra_params, - ) + if isinstance(active_request, ResponseRequest): + return await client.get_response(active_request) + if isinstance(active_request, EmbeddingRequest): + return await client.get_embedding(active_request) + return await client.get_audio_transcriptions(active_request) except EmptyResponseException as e: # 空回复:通常为临时问题,单独记录并重试 original_error_info = self._get_original_error_info(e) @@ -625,12 +732,19 @@ class LLMRequest: continue # 特殊处理413,尝试压缩 - if e.status_code == 413 and message_list and not compressed_messages: + if ( + e.status_code == 413 + and isinstance(active_request, ResponseRequest) + and active_request.message_list + and original_response_request is not None + and active_request.message_list == original_response_request.message_list + ): logger.warning( f"任务 '{task_display}' 的模型 '{model_info.name}' 返回413请求体过大,尝试压缩后重试..." ) # 压缩消息本身不消耗重试次数 - compressed_messages = compress_messages(message_list) + compressed_messages = compress_messages(active_request.message_list) + active_request = active_request.copy_with(message_list=compressed_messages) continue # 不可重试的HTTP错误 @@ -639,6 +753,25 @@ class LLMRequest: ) raise ModelAttemptFailed(f"模型 '{model_info.name}' 遇到硬错误", original_exception=e) from e + except RespParseException as e: + original_error_info = self._get_original_error_info(e) + retry_remain -= 1 + task_display = self.request_type or "未知任务" + if retry_remain <= 0: + logger.error( + f"任务 '{task_display}' 的模型 '{model_info.name}' 在响应解析多次失败后仍然失败。{original_error_info}" + ) + raise ModelAttemptFailed(f"模型 '{model_info.name}' 重试耗尽", original_exception=e) from e + + logger.warning( + f"任务 '{task_display}' 的模型 '{model_info.name}' 返回内容解析失败(可重试): {str(e)}{original_error_info}。" + f"剩余重试次数: {retry_remain}" + ) + await asyncio.sleep(api_provider.retry_interval) + + except ReqAbortException: + raise + except Exception as e: logger.error(traceback.format_exc()) @@ -658,7 +791,7 @@ class LLMRequest: self, request_type: RequestType, message_factory: Optional[Callable[[BaseClient], List[Message]]] = None, - tool_options: list[ToolOption] | None = None, + tool_options: List[ToolOption] | None = None, response_format: RespFormat | None = None, stream_response_handler: Optional[Callable[..., Any]] = None, async_response_parser: Optional[Callable[..., Any]] = None, @@ -666,9 +799,25 @@ class LLMRequest: max_tokens: Optional[int] = None, embedding_input: str | None = None, audio_base64: str | None = None, - ) -> Tuple[APIResponse, ModelInfo]: - """ - 调度器函数,负责模型选择、故障切换。 + interrupt_flag: asyncio.Event | None = None, + ) -> LLMExecutionResult: + """执行一次完整的模型调度请求。 + + Args: + request_type: 请求类型。 + message_factory: 消息工厂,仅在响应请求中使用。 + tool_options: 工具定义列表。 + response_format: 响应格式定义。 + stream_response_handler: 流式响应处理函数。 + async_response_parser: 非流式响应解析函数。 + temperature: 显式指定的温度参数。 + max_tokens: 显式指定的最大输出 token 数。 + embedding_input: 嵌入输入文本。 + audio_base64: Base64 编码的音频数据。 + interrupt_flag: 外部中断标记。 + + Returns: + LLMExecutionResult: 单次模型执行结果对象。 """ failed_models_this_request: Set[str] = set() max_attempts = len(self.model_for_task.model_list) @@ -676,32 +825,65 @@ class LLMRequest: for _ in range(max_attempts): model_info, api_provider, client = self._select_model(exclude_models=failed_models_this_request) + if self.request_type.startswith("maisaka_"): + logger.info( + f"LLMOrchestrator[{self.request_type}] 已选择模型 model={model_info.name} " + f"provider={api_provider.name} request_type={request_type.value}" + ) message_list = [] if message_factory: + if self.request_type.startswith("maisaka_"): + logger.info(f"LLMOrchestrator[{self.request_type}] 正在通过 message_factory 构建消息列表") message_list = message_factory(client) + if self.request_type.startswith("maisaka_"): + logger.info( + f"LLMOrchestrator[{self.request_type}] message_factory 返回了 {len(message_list)} 条消息" + ) try: - response = await self._attempt_request_on_model( - model_info, - api_provider, - client, - request_type, + request = self._build_client_request( + request_type=request_type, + model_info=model_info, message_list=message_list, tool_options=tool_options, response_format=response_format, stream_response_handler=stream_response_handler, async_response_parser=async_response_parser, + interrupt_flag=interrupt_flag, temperature=temperature, max_tokens=max_tokens, embedding_input=embedding_input, audio_base64=audio_base64, ) + if self.request_type.startswith("maisaka_"): + logger.info( + f"LLMOrchestrator[{self.request_type}] 正在向模型 model={model_info.name} 发送请求 " + f"(tool_options={len(tool_options or [])})" + ) + response = await self._attempt_request_on_model( + api_provider, + client, + request=request, + ) + if self.request_type.startswith("maisaka_"): + logger.info( + f"LLMOrchestrator[{self.request_type}] 模型 model={model_info.name} 已返回 API 响应" + ) total_tokens, penalty, usage_penalty = self.model_usage[model_info.name] if response_usage := response.usage: total_tokens += response_usage.total_tokens self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty - 1) - return response, model_info + return LLMExecutionResult(api_response=response, model_info=model_info) + + except ReqAbortException as e: + total_tokens, penalty, usage_penalty = self.model_usage[model_info.name] + self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty - 1) + if self.request_type.startswith("maisaka_"): + logger.info( + f"LLMOrchestrator[{self.request_type}] 模型 model={model_info.name} 的请求已被外部信号中断" + ) + raise e except ModelAttemptFailed as e: last_exception = e.original_exception or e @@ -719,46 +901,27 @@ class LLMRequest: raise last_exception raise RuntimeError("请求失败,所有可用模型均已尝试失败。") - def _build_tool_options(self, tools: Optional[List[Dict[str, Any]]]) -> Optional[List[ToolOption]]: - # sourcery skip: extract-method - """构建工具选项列表""" - if not tools: - return None - tool_options: List[ToolOption] = [] - for tool in tools: - tool_legal = True - tool_options_builder = ToolOptionBuilder() - tool_options_builder.set_name(tool.get("name", "")) - tool_options_builder.set_description(tool.get("description", "")) - parameters: List[Tuple[str, str, str, bool, List[str] | None]] = tool.get("parameters", []) - for param in parameters: - try: - assert isinstance(param, tuple) and len(param) == 5, "参数必须是包含5个元素的元组" - assert isinstance(param[0], str), "参数名称必须是字符串" - assert isinstance(param[1], ToolParamType), "参数类型必须是ToolParamType枚举" - assert isinstance(param[2], str), "参数描述必须是字符串" - assert isinstance(param[3], bool), "参数是否必填必须是布尔值" - assert isinstance(param[4], list) or param[4] is None, "参数枚举值必须是列表或None" - tool_options_builder.add_param( - name=param[0], - param_type=param[1], - description=param[2], - required=param[3], - enum_values=param[4], - ) - except AssertionError as ae: - tool_legal = False - logger.error(f"{param[0]} 参数定义错误: {str(ae)}") - except Exception as e: - tool_legal = False - logger.error(f"构建工具参数失败: {str(e)}") - if tool_legal: - tool_options.append(tool_options_builder.build()) - return tool_options or None + def _build_tool_options(self, tools: List[ToolDefinitionInput] | None) -> List[ToolOption] | None: + """将任意输入工具定义列表规范化为内部工具选项。 + + Args: + tools: 原始工具定义列表。 + + Returns: + List[ToolOption] | None: 规范化后的工具选项列表。 + """ + return normalize_tool_options(tools) @staticmethod def _extract_reasoning(content: str) -> Tuple[str, str]: - """CoT思维链提取,向后兼容""" + """提取 `` 思维链内容。 + + Args: + content: 原始模型输出文本。 + + Returns: + Tuple[str, str]: `(正文内容, 推理内容)`。 + """ match = re.search(r"(?:)?(.*?)", content, re.DOTALL) content = re.sub(r"(?:)?.*?", "", content, flags=re.DOTALL, count=1).strip() reasoning = match[1].strip() if match else "" @@ -766,7 +929,14 @@ class LLMRequest: @staticmethod def _get_original_error_info(e: Exception) -> str: - """获取原始错误信息""" + """提取底层异常信息。 + + Args: + e: 当前捕获的异常对象。 + + Returns: + str: 可直接拼接到日志中的底层异常描述。 + """ if e.__cause__: original_error_type = type(e.__cause__).__name__ original_error_msg = str(e.__cause__) @@ -777,17 +947,16 @@ class LLMRequest: class TempMethodsLLMUtils: @staticmethod def get_model_info_by_name(model_name: str) -> ModelInfo: - """根据模型名称获取模型信息 + """根据模型名称获取模型信息。 Args: - model_config: ModelConfig实例 model_name: 模型名称 Returns: - ModelInfo: 模型信息 + ModelInfo: 模型信息。 Raises: - ValueError: 未找到指定模型 + ValueError: 未找到指定模型。 """ for model in config_manager.get_model_config().models: if model.name == model_name: @@ -796,17 +965,16 @@ class TempMethodsLLMUtils: @staticmethod def get_provider_by_name(provider_name: str) -> APIProvider: - """根据提供商名称获取提供商信息 + """根据提供商名称获取提供商信息。 Args: - model_config: ModelConfig实例 provider_name: 提供商名称 Returns: - APIProvider: API提供商信息 + APIProvider: API 提供商信息。 Raises: - ValueError: 未找到指定提供商 + ValueError: 未找到指定提供商。 """ for provider in config_manager.get_model_config().api_providers: if provider.name == provider_name: diff --git a/src/maisaka/LICENSE b/src/maisaka/LICENSE deleted file mode 100644 index cb1ae897..00000000 --- a/src/maisaka/LICENSE +++ /dev/null @@ -1,21 +0,0 @@ -MIT License - -Copyright (c) 2026 SengokuCola - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. diff --git a/src/maisaka/builtin_tools.py b/src/maisaka/builtin_tools.py index 0017f1fb..6afeb68d 100644 --- a/src/maisaka/builtin_tools.py +++ b/src/maisaka/builtin_tools.py @@ -1,48 +1,163 @@ -""" -MaiSaka built-in tool definitions. -""" +"""Maisaka 内置工具声明。""" -from typing import List +from copy import deepcopy +from typing import Any, Dict, List -from src.llm_models.payload_content.tool_option import ToolOption, ToolParamType +from src.core.tooling import ToolSpec, build_tool_detailed_description +from src.llm_models.payload_content.tool_option import ToolDefinitionInput -def create_builtin_tools() -> List[ToolOption]: - """Create built-in tools exposed to the main chat-loop model.""" - from src.llm_models.payload_content.tool_option import ToolOptionBuilder +def _build_tool_spec( + name: str, + brief_description: str, + parameters_schema: Dict[str, Any] | None = None, + detailed_description: str = "", +) -> ToolSpec: + """构建单个内置工具声明。 - tools: List[ToolOption] = [] + Args: + name: 工具名称。 + brief_description: 简要描述。 + parameters_schema: 参数 Schema。 + detailed_description: 详细描述;为空时自动根据参数生成。 - wait_builder = ToolOptionBuilder() - wait_builder.set_name("wait") - wait_builder.set_description("Pause speaking and wait for the user to provide more input.") - wait_builder.add_param( - name="seconds", - param_type=ToolParamType.INTEGER, - description="How many seconds to wait before timing out.", - required=True, - enum_values=None, + Returns: + ToolSpec: 构建完成的工具声明。 + """ + + normalized_schema = deepcopy(parameters_schema) if parameters_schema is not None else None + return ToolSpec( + name=name, + brief_description=brief_description, + detailed_description=( + detailed_description.strip() + or build_tool_detailed_description(normalized_schema) + ), + parameters_schema=normalized_schema, + provider_name="maisaka_builtin", + provider_type="builtin", ) - tools.append(wait_builder.build()) - - reply_builder = ToolOptionBuilder() - reply_builder.set_name("reply") - reply_builder.set_description("Generate and emit a visible reply based on the current thought.") - tools.append(reply_builder.build()) - - no_reply_builder = ToolOptionBuilder() - no_reply_builder.set_name("no_reply") - no_reply_builder.set_description("Do not emit a visible reply this round and continue thinking.") - tools.append(no_reply_builder.build()) - - stop_builder = ToolOptionBuilder() - stop_builder.set_name("stop") - stop_builder.set_description("Stop the current inner loop and return control to the outer chat flow.") - tools.append(stop_builder.build()) - - return tools -def get_builtin_tools() -> List[ToolOption]: - """Return built-in tools.""" - return create_builtin_tools() +def create_builtin_tool_specs() -> List[ToolSpec]: + """创建 Maisaka 内置工具声明列表。 + + Returns: + List[ToolSpec]: 内置工具声明列表。 + """ + + return [ + _build_tool_spec( + name="wait", + brief_description="暂停当前对话并等待用户新的输入。", + parameters_schema={ + "type": "object", + "properties": { + "seconds": { + "type": "integer", + "description": "等待的秒数。", + }, + }, + "required": ["seconds"], + }, + ), + _build_tool_spec( + name="reply", + brief_description="根据当前思考生成并发送一条可见回复。", + parameters_schema={ + "type": "object", + "properties": { + "msg_id": { + "type": "string", + "description": "要回复的目标用户消息编号。", + }, + "quote": { + "type": "boolean", + "description": "是否以引用回复的方式发送。", + "default": True, + }, + "unknown_words": { + "type": "array", + "description": "回复前可能需要查询的黑话或词条列表。", + "items": {"type": "string"}, + }, + }, + "required": ["msg_id"], + }, + ), + _build_tool_spec( + name="query_jargon", + brief_description="查询当前聊天上下文中的黑话或词条含义。", + parameters_schema={ + "type": "object", + "properties": { + "words": { + "type": "array", + "description": "要查询的词条列表。", + "items": {"type": "string"}, + }, + }, + "required": ["words"], + }, + ), + _build_tool_spec( + name="query_person_info", + brief_description="查询某个人的档案和相关记忆信息。", + parameters_schema={ + "type": "object", + "properties": { + "person_name": { + "type": "string", + "description": "人物名称、昵称或用户 ID。", + }, + "limit": { + "type": "integer", + "description": "最多返回多少条匹配记录。", + "default": 3, + }, + }, + "required": ["person_name"], + }, + ), + _build_tool_spec( + name="no_reply", + brief_description="本轮不发送可见回复,继续下一步思考。", + ), + _build_tool_spec( + name="stop", + brief_description="暂停当前内部循环,等待新的外部消息。", + ), + _build_tool_spec( + name="send_emoji", + brief_description="发送一个合适的表情包来辅助表达情绪。", + parameters_schema={ + "type": "object", + "properties": { + "emotion": { + "type": "string", + "description": "希望表达的情绪,例如 happy、sad、angry 等。", + }, + }, + }, + ), + ] + + +def get_builtin_tool_specs() -> List[ToolSpec]: + """获取 Maisaka 内置工具声明。 + + Returns: + List[ToolSpec]: 内置工具声明列表。 + """ + + return create_builtin_tool_specs() + + +def get_builtin_tools() -> List[ToolDefinitionInput]: + """获取兼容旧模型层的内置工具定义。 + + Returns: + List[ToolDefinitionInput]: 可直接传给模型层的工具定义。 + """ + + return [tool_spec.to_llm_definition() for tool_spec in create_builtin_tool_specs()] diff --git a/src/maisaka/chat_loop_service.py b/src/maisaka/chat_loop_service.py new file mode 100644 index 00000000..9525f299 --- /dev/null +++ b/src/maisaka/chat_loop_service.py @@ -0,0 +1,866 @@ +"""Maisaka 对话循环服务。""" + +from base64 import b64decode +from dataclasses import dataclass +from datetime import datetime +from io import BytesIO +from time import perf_counter +from typing import Any, Dict, List, Optional, Sequence + +import asyncio +import json +import random + +from PIL import Image as PILImage +from pydantic import BaseModel, Field as PydanticField +from rich.console import Group, RenderableType +from rich.panel import Panel +from rich.pretty import Pretty +from rich.text import Text + +from src.cli.console import console +from src.common.data_models.llm_service_data_models import LLMGenerationOptions +from src.common.data_models.message_component_data_model import MessageSequence, TextComponent +from src.common.logger import get_logger +from src.common.prompt_i18n import load_prompt +from src.config.config import global_config +from src.core.tooling import ToolRegistry, ToolSpec +from src.know_u.knowledge import extract_category_ids_from_result +from src.llm_models.model_client.base_client import BaseClient +from src.llm_models.payload_content.message import Message, MessageBuilder, RoleType +from src.llm_models.payload_content.resp_format import RespFormat, RespFormatType +from src.llm_models.payload_content.tool_option import ToolCall, ToolDefinitionInput, ToolOption, normalize_tool_options +from src.services.llm_service import LLMServiceClient + +from .builtin_tools import get_builtin_tools +from .context_messages import AssistantMessage, LLMContextMessage, SessionBackedMessage +from .message_adapter import format_speaker_content + + +@dataclass(slots=True) +class ChatResponse: + """LLM 对话循环单步响应。""" + + content: Optional[str] + tool_calls: List[ToolCall] + raw_message: AssistantMessage + + +class ToolFilterSelection(BaseModel): + """工具筛选响应。""" + + selected_tool_names: list[str] = PydanticField(default_factory=list) + """经过预筛后保留的候选工具名称列表。""" + + +logger = get_logger("maisaka_chat_loop") + + +class MaisakaChatLoopService: + """负责 Maisaka 主对话循环、系统提示词和终端渲染。""" + + def __init__( + self, + chat_system_prompt: Optional[str] = None, + temperature: float = 0.5, + max_tokens: int = 2048, + ) -> None: + """初始化 Maisaka 对话循环服务。 + + Args: + chat_system_prompt: 可选的系统提示词。 + temperature: 规划器温度参数。 + max_tokens: 规划器最大输出长度。 + """ + + self._temperature = temperature + self._max_tokens = max_tokens + self._extra_tools: List[ToolOption] = [] + self._interrupt_flag: asyncio.Event | None = None + self._tool_registry: ToolRegistry | None = None + self._prompts_loaded = False + self._prompt_load_lock = asyncio.Lock() + self._personality_prompt = self._build_personality_prompt() + if chat_system_prompt is None: + self._chat_system_prompt = f"{self._personality_prompt}\n\nYou are a helpful AI assistant." + else: + self._chat_system_prompt = chat_system_prompt + self._llm_chat = LLMServiceClient(task_name="planner", request_type="maisaka_planner") + self._tool_filter_llm = LLMServiceClient( + task_name=global_config.maisaka.tool_filter_task_name, + request_type="maisaka_tool_filter", + ) + + @property + def personality_prompt(self) -> str: + """返回当前人格提示词。""" + + return self._personality_prompt + + def _build_personality_prompt(self) -> str: + """构造人格提示词。""" + + try: + bot_name = global_config.bot.nickname + if global_config.bot.alias_names: + bot_nickname = f", also known as {','.join(global_config.bot.alias_names)}" + else: + bot_nickname = "" + + prompt_personality = global_config.personality.personality + if ( + hasattr(global_config.personality, "states") + and global_config.personality.states + and hasattr(global_config.personality, "state_probability") + and global_config.personality.state_probability > 0 + and random.random() < global_config.personality.state_probability + ): + prompt_personality = random.choice(global_config.personality.states) + + return f"Your name is {bot_name}{bot_nickname}; persona: {prompt_personality};" + except Exception: + return "Your name is MaiMai; persona: lively and cute AI assistant." + + async def ensure_chat_prompt_loaded(self, tools_section: str = "") -> None: + """确保主聊天提示词已经加载完成。 + + Args: + tools_section: 额外注入到提示词中的工具说明片段。 + """ + + if self._prompts_loaded: + return + + async with self._prompt_load_lock: + if self._prompts_loaded: + return + + try: + self._chat_system_prompt = load_prompt( + "maidairy_chat", + file_tools_section=tools_section, + bot_name=global_config.bot.nickname, + identity=self._personality_prompt, + ) + except Exception: + self._chat_system_prompt = f"{self._personality_prompt}\n\nYou are a helpful AI assistant." + + self._prompts_loaded = True + + def set_extra_tools(self, tools: Sequence[ToolDefinitionInput]) -> None: + """设置额外工具定义。 + + Args: + tools: 兼容旧接口的额外工具定义列表。 + """ + + self._extra_tools = normalize_tool_options(list(tools)) or [] + + def set_tool_registry(self, tool_registry: ToolRegistry | None) -> None: + """设置统一工具注册表。 + + Args: + tool_registry: 统一工具注册表;传入 ``None`` 时退回旧工具列表模式。 + """ + + self._tool_registry = tool_registry + + def set_interrupt_flag(self, interrupt_flag: asyncio.Event | None) -> None: + """设置当前 planner 请求使用的中断标记。""" + self._interrupt_flag = interrupt_flag + + def _build_request_messages(self, selected_history: List[LLMContextMessage]) -> List[Message]: + """构造发给大模型的消息列表。 + + Args: + selected_history: 已选中的上下文消息列表。 + + Returns: + List[Message]: 发送给大模型的消息列表。 + """ + + messages: List[Message] = [] + system_msg = MessageBuilder().set_role(RoleType.System) + system_msg.add_text_content(self._chat_system_prompt) + messages.append(system_msg.build()) + + for msg in selected_history: + llm_message = msg.to_llm_message() + if llm_message is not None: + messages.append(llm_message) + + return messages + + @staticmethod + def _is_builtin_tool_spec(tool_spec: ToolSpec) -> bool: + """判断一个工具是否属于默认内置工具。 + + Args: + tool_spec: 待判断的工具声明。 + + Returns: + bool: 是否为默认内置工具。 + """ + + return tool_spec.provider_type == "builtin" or tool_spec.provider_name == "maisaka_builtin" + + @classmethod + def _split_builtin_and_candidate_tools( + cls, + tool_specs: List[ToolSpec], + ) -> tuple[List[ToolSpec], List[ToolSpec]]: + """拆分内置工具与可筛选工具列表。 + + Args: + tool_specs: 当前全部工具声明。 + + Returns: + tuple[List[ToolSpec], List[ToolSpec]]: `(内置工具, 可筛选工具)`。 + """ + + builtin_tool_specs: List[ToolSpec] = [] + candidate_tool_specs: List[ToolSpec] = [] + for tool_spec in tool_specs: + if cls._is_builtin_tool_spec(tool_spec): + builtin_tool_specs.append(tool_spec) + else: + candidate_tool_specs.append(tool_spec) + return builtin_tool_specs, candidate_tool_specs + + @staticmethod + def _truncate_tool_filter_text(text: str, max_length: int = 180) -> str: + """截断工具筛选阶段展示的文本。 + + Args: + text: 原始文本。 + max_length: 最长保留字符数。 + + Returns: + str: 截断后的文本。 + """ + + normalized_text = text.strip() + if len(normalized_text) <= max_length: + return normalized_text + return f"{normalized_text[: max_length - 1]}…" + + def _build_tool_filter_prompt( + self, + selected_history: List[LLMContextMessage], + candidate_tool_specs: List[ToolSpec], + max_keep: int, + ) -> str: + """构造小模型工具预筛选提示词。 + + Args: + selected_history: 已选中的对话上下文。 + candidate_tool_specs: 非内置候选工具列表。 + max_keep: 最多保留的候选工具数量。 + + Returns: + str: 用于工具预筛的小模型提示词。 + """ + + history_lines: List[str] = [] + for message in selected_history[-10:]: + plain_text = message.processed_plain_text.strip() + if not plain_text: + continue + history_lines.append( + f"- {message.role}: {self._truncate_tool_filter_text(plain_text, max_length=200)}" + ) + + if history_lines: + history_section = "\n".join(history_lines) + else: + history_section = "- 当前没有可用的对话上下文。" + + tool_lines = [ + f"- {tool_spec.name}: {tool_spec.brief_description.strip() or '无简要描述'}" + for tool_spec in candidate_tool_specs + ] + tool_section = "\n".join(tool_lines) if tool_lines else "- 当前没有候选工具。" + + return ( + "你是 Maisaka 的工具预筛选器。\n" + "你的任务是在正式进入 planner 前,根据当前情景从候选工具中挑出最可能马上会用到的工具。\n" + "默认内置工具已经自动保留,不在候选列表中,你不需要再次选择它们。\n" + "你只能参考工具的简要描述,不要假设未描述的隐藏能力。\n" + f"最多保留 {max_keep} 个候选工具;如果都不合适,可以返回空数组。\n" + "请严格返回 JSON 对象,格式为:" + '{"selected_tool_names":["工具名1","工具名2"]}\n\n' + f"【最近对话】\n{history_section}\n\n" + f"【候选工具(仅简要描述)】\n{tool_section}" + ) + + @staticmethod + def _parse_tool_filter_response( + response_text: str, + candidate_tool_specs: List[ToolSpec], + max_keep: int, + ) -> List[ToolSpec] | None: + """解析工具预筛选响应。 + + Args: + response_text: 小模型返回的原始文本。 + candidate_tool_specs: 非内置候选工具列表。 + max_keep: 最多保留的候选工具数量。 + + Returns: + List[ToolSpec] | None: 成功解析时返回筛选后的工具列表;解析失败时返回 ``None``。 + """ + + normalized_response = response_text.strip() + if not normalized_response: + return None + + selected_tool_names: List[str] + try: + selected_tool_names = ToolFilterSelection.model_validate_json(normalized_response).selected_tool_names + except Exception: + try: + parsed_payload = json.loads(normalized_response) + except json.JSONDecodeError: + return None + + if isinstance(parsed_payload, dict): + raw_tool_names = parsed_payload.get("selected_tool_names", []) + elif isinstance(parsed_payload, list): + raw_tool_names = parsed_payload + else: + return None + + if not isinstance(raw_tool_names, list): + return None + + selected_tool_names = [] + for item in raw_tool_names: + normalized_name = str(item).strip() + if normalized_name: + selected_tool_names.append(normalized_name) + + candidate_map = {tool_spec.name: tool_spec for tool_spec in candidate_tool_specs} + filtered_tool_specs: List[ToolSpec] = [] + seen_names: set[str] = set() + for tool_name in selected_tool_names: + normalized_name = tool_name.strip() + if not normalized_name or normalized_name in seen_names: + continue + tool_spec = candidate_map.get(normalized_name) + if tool_spec is None: + continue + + seen_names.add(normalized_name) + filtered_tool_specs.append(tool_spec) + if len(filtered_tool_specs) >= max_keep: + break + + return filtered_tool_specs + + async def _filter_tool_specs_for_planner( + self, + selected_history: List[LLMContextMessage], + tool_specs: List[ToolSpec], + ) -> List[ToolSpec]: + """在将工具交给 planner 前进行快速预筛选。 + + Args: + selected_history: 已选中的对话上下文。 + tool_specs: 当前全部可用工具声明。 + + Returns: + List[ToolSpec]: 最终交给 planner 的工具声明列表。 + """ + + threshold = max(1, int(global_config.maisaka.tool_filter_threshold)) + max_keep = max(1, int(global_config.maisaka.tool_filter_max_keep)) + if len(tool_specs) <= threshold: + return tool_specs + + builtin_tool_specs, candidate_tool_specs = self._split_builtin_and_candidate_tools(tool_specs) + if not candidate_tool_specs: + return tool_specs + if len(candidate_tool_specs) <= max_keep: + return [*builtin_tool_specs, *candidate_tool_specs] + + filter_prompt = self._build_tool_filter_prompt(selected_history, candidate_tool_specs, max_keep) + logger.info( + "工具预筛选开始: " + f"总工具数={len(tool_specs)} " + f"内置工具数={len(builtin_tool_specs)} " + f"候选工具数={len(candidate_tool_specs)} " + f"最多保留候选数={max_keep}" + ) + + try: + generation_result = await self._tool_filter_llm.generate_response( + prompt=filter_prompt, + options=LLMGenerationOptions( + temperature=0.0, + max_tokens=256, + response_format=RespFormat( + format_type=RespFormatType.JSON_SCHEMA, + schema=ToolFilterSelection, + ), + ), + ) + except Exception as exc: + logger.warning(f"工具预筛选失败,保留全部工具。错误={exc}") + return tool_specs + + filtered_candidate_tool_specs = self._parse_tool_filter_response( + generation_result.response or "", + candidate_tool_specs, + max_keep, + ) + if filtered_candidate_tool_specs is None: + logger.warning( + "工具预筛选返回结果无法解析,保留全部工具。" + f" 原始返回={generation_result.response or ''!r}" + ) + return tool_specs + + filtered_tool_specs = [*builtin_tool_specs, *filtered_candidate_tool_specs] + if not filtered_tool_specs: + logger.warning("工具预筛选得到空结果,保留全部工具以避免主流程失去工具能力。") + return tool_specs + + logger.info( + "工具预筛选完成: " + f"筛选前总数={len(tool_specs)} " + f"筛选后总数={len(filtered_tool_specs)} " + f"保留候选工具={[tool_spec.name for tool_spec in filtered_candidate_tool_specs]}" + ) + return filtered_tool_specs + + async def analyze_knowledge_need( + self, + chat_history: List[LLMContextMessage], + categories_summary: str, + ) -> List[str]: + """分析当前对话是否需要检索知识库分类。""" + visible_history: List[str] = [] + for message in chat_history[-8:]: + if not message.processed_plain_text: + continue + visible_history.append(f"{message.role}: {message.processed_plain_text}") + + if not visible_history or not categories_summary.strip(): + return [] + + prompt = ( + "你需要判断当前对话是否需要查询知识库。\n" + "请只返回最相关的分类编号,多个编号用空格分隔;如果完全不需要,返回 none。\n\n" + f"【可用分类】\n{categories_summary}\n\n" + f"【最近对话】\n{chr(10).join(visible_history)}" + ) + + try: + generation_result = await self._llm_chat.generate_response( + prompt=prompt, + options=LLMGenerationOptions( + temperature=0.1, + max_tokens=64, + ), + ) + except Exception: + return [] + + return extract_category_ids_from_result(generation_result.response or "") + + @staticmethod + def _get_role_badge_style(role: str) -> str: + """返回终端中角色标签的样式。 + + Args: + role: 消息角色名称。 + + Returns: + str: Rich 可识别的样式字符串。 + """ + + if role == "system": + return "bold white on blue" + if role == "user": + return "bold black on green" + if role == "assistant": + return "bold black on yellow" + if role == "tool": + return "bold white on magenta" + return "bold white on bright_black" + + @staticmethod + def _get_role_badge_label(role: str) -> str: + """返回终端中角色标签的中文名称。 + + Args: + role: 消息角色名称。 + + Returns: + str: 用于展示的中文角色名称。 + """ + + if role == "system": + return "系统" + if role == "user": + return "用户" + if role == "assistant": + return "助手" + if role == "tool": + return "工具" + return "未知" + + @staticmethod + def _build_terminal_image_preview(image_base64: str) -> Optional[str]: + """构造终端图片预览字符画。 + + Args: + image_base64: 图片的 Base64 编码。 + + Returns: + Optional[str]: 生成成功时返回字符画文本,否则返回 ``None``。 + """ + + ascii_chars = " .:-=+*#%@" + + try: + image_bytes = b64decode(image_base64) + with PILImage.open(BytesIO(image_bytes)) as image: + grayscale = image.convert("L") + width, height = grayscale.size + if width <= 0 or height <= 0: + return None + + preview_width = max(8, int(global_config.maisaka.terminal_image_preview_width)) + preview_height = max(1, int(height * (preview_width / width) * 0.5)) + resized = grayscale.resize((preview_width, preview_height)) + pixels = list(resized.tobytes()) + except Exception: + return None + + rows: List[str] = [] + for row_index in range(preview_height): + row_pixels = pixels[row_index * preview_width : (row_index + 1) * preview_width] + row = "".join(ascii_chars[min(len(ascii_chars) - 1, pixel * len(ascii_chars) // 256)] for pixel in row_pixels) + rows.append(row) + + return "\n".join(rows) + + @classmethod + def _render_message_content(cls, content: Any) -> RenderableType: + """将消息内容渲染为终端可展示对象。 + + Args: + content: 原始消息内容。 + + Returns: + RenderableType: Rich 可渲染对象。 + """ + + if isinstance(content, str): + return Text(content) + + if isinstance(content, list): + parts: List[RenderableType] = [] + for item in content: + if isinstance(item, str): + parts.append(Text(item)) + continue + if isinstance(item, tuple) and len(item) == 2: + image_format, image_base64 = item + if isinstance(image_format, str) and isinstance(image_base64, str): + approx_size = max(0, len(image_base64) * 3 // 4) + size_text = f"{approx_size / 1024:.1f} KB" if approx_size >= 1024 else f"{approx_size} B" + preview_parts: List[RenderableType] = [ + Text(f"图片格式 image/{image_format} {size_text}\nbase64 内容已省略", style="magenta") + ] + if global_config.maisaka.terminal_image_preview: + preview_text = cls._build_terminal_image_preview(image_base64) + if preview_text: + preview_parts.append(Text(preview_text, style="white")) + parts.append( + Panel( + Group(*preview_parts), + border_style="magenta", + padding=(0, 1), + ) + ) + continue + if isinstance(item, dict) and item.get("type") == "text" and isinstance(item.get("text"), str): + parts.append(Text(item["text"])) + else: + parts.append(Pretty(item, expand_all=True)) + return Group(*parts) if parts else Text("") + + if content is None: + return Text("") + + return Pretty(content, expand_all=True) + + @staticmethod + def _format_tool_call_for_display(tool_call: Any) -> Dict[str, Any]: + """将工具调用对象格式化为易读字典。 + + Args: + tool_call: 原始工具调用对象或字典。 + + Returns: + Dict[str, Any]: 适合终端展示的工具调用字典。 + """ + + if isinstance(tool_call, dict): + function_info = tool_call.get("function", {}) + return { + "id": tool_call.get("id"), + "name": function_info.get("name", tool_call.get("name")), + "arguments": function_info.get("arguments", tool_call.get("arguments")), + } + + return { + "id": getattr(tool_call, "call_id", getattr(tool_call, "id", None)), + "name": getattr(tool_call, "func_name", getattr(tool_call, "name", None)), + "arguments": getattr(tool_call, "args", getattr(tool_call, "arguments", None)), + } + + def _render_tool_call_panel(self, tool_call: Any, index: int, parent_index: int) -> Panel: + """渲染单个工具调用面板。 + + Args: + tool_call: 原始工具调用对象。 + index: 工具调用在当前消息中的序号。 + parent_index: 所属消息的序号。 + + Returns: + Panel: 工具调用展示面板。 + """ + + title = Text.assemble( + Text(" 工具调用 ", style="bold white on magenta"), + Text(f" #{parent_index}.{index}", style="muted"), + ) + return Panel( + Pretty(self._format_tool_call_for_display(tool_call), expand_all=True), + title=title, + border_style="magenta", + padding=(0, 1), + ) + + def _render_message_panel(self, message: Any, index: int) -> Panel: + """渲染单条消息面板。 + + Args: + message: 原始消息对象或字典。 + index: 消息序号。 + + Returns: + Panel: 终端展示面板。 + """ + + if isinstance(message, dict): + raw_role = message.get("role", "unknown") + content = message.get("content") + tool_call_id = message.get("tool_call_id") + else: + raw_role = getattr(message, "role", "unknown") + content = getattr(message, "content", None) + tool_call_id = getattr(message, "tool_call_id", None) + + role = raw_role.value if isinstance(raw_role, RoleType) else str(raw_role) + title = Text.assemble( + Text(f" {self._get_role_badge_label(role)} ", style=self._get_role_badge_style(role)), + Text(f" #{index}", style="muted"), + ) + + parts: List[RenderableType] = [] + if content not in (None, "", []): + parts.append(Text(" 消息 ", style="bold cyan")) + parts.append(self._render_message_content(content)) + + if tool_call_id: + parts.append( + Text.assemble( + Text(" 工具调用编号 ", style="bold magenta"), + Text(" "), + Text(str(tool_call_id), style="magenta"), + ) + ) + + if not parts: + parts.append(Text("[空消息]", style="muted")) + + return Panel( + Group(*parts), + title=title, + border_style="dim", + padding=(0, 1), + ) + + async def chat_loop_step(self, chat_history: List[LLMContextMessage]) -> ChatResponse: + """执行一轮 Maisaka 规划器请求。 + + Args: + chat_history: 当前对话历史。 + + Returns: + ChatResponse: 本轮规划器返回结果。 + """ + + await self.ensure_chat_prompt_loaded() + selected_history, selection_reason = self._select_llm_context_messages(chat_history) + built_messages = self._build_request_messages(selected_history) + + def message_factory(_client: BaseClient) -> List[Message]: + """返回当前轮次已经构建好的请求消息。 + + Args: + _client: 当前模型客户端;此处不依赖客户端能力。 + + Returns: + List[Message]: 已经构建好的消息列表。 + """ + + del _client + return built_messages + + all_tools: List[ToolDefinitionInput] + if self._tool_registry is not None: + tool_specs = await self._tool_registry.list_tools() + filtered_tool_specs = await self._filter_tool_specs_for_planner(selected_history, tool_specs) + all_tools = [tool_spec.to_llm_definition() for tool_spec in filtered_tool_specs] + else: + all_tools = [*get_builtin_tools(), *self._extra_tools] + + ordered_panels: List[Panel] = [] + for index, msg in enumerate(built_messages, start=1): + ordered_panels.append(self._render_message_panel(msg, index)) + tool_calls = getattr(msg, "tool_calls", None) + if tool_calls: + for tool_call_index, tool_call in enumerate(tool_calls, start=1): + ordered_panels.append(self._render_tool_call_panel(tool_call, tool_call_index, index)) + + if global_config.maisaka.show_thinking and ordered_panels: + console.print( + Panel( + Group(*ordered_panels), + title="MaiSaka 大模型请求 - 对话单步", + subtitle=selection_reason, + border_style="cyan", + padding=(0, 1), + ) + ) + + request_started_at = perf_counter() + logger.info( + "规划器请求开始: " + f"已选上下文消息数={len(selected_history)} " + f"大模型消息数={len(built_messages)} " + f"工具数={len(all_tools)} " + f"启用打断={self._interrupt_flag is not None}" + ) + generation_result = await self._llm_chat.generate_response_with_messages( + message_factory=message_factory, + options=LLMGenerationOptions( + tool_options=all_tools if all_tools else None, + temperature=self._temperature, + max_tokens=self._max_tokens, + interrupt_flag=self._interrupt_flag, + ), + ) + request_elapsed = perf_counter() - request_started_at + logger.info(f"规划器请求完成,耗时={request_elapsed:.3f} 秒") + + tool_call_summaries = [ + { + "调用编号": getattr(tool_call, "call_id", getattr(tool_call, "id", None)), + "工具名": getattr(tool_call, "func_name", getattr(tool_call, "name", None)), + "参数": getattr(tool_call, "args", getattr(tool_call, "arguments", None)), + } + for tool_call in (generation_result.tool_calls or []) + ] + logger.info( + f"Maisaka 规划器返回结果: 内容={generation_result.response or ''!r} " + f"工具调用={tool_call_summaries}" + ) + + raw_message = AssistantMessage( + content=generation_result.response or "", + timestamp=datetime.now(), + tool_calls=generation_result.tool_calls or [], + ) + return ChatResponse( + content=generation_result.response, + tool_calls=generation_result.tool_calls or [], + raw_message=raw_message, + ) + + @staticmethod + def _select_llm_context_messages(chat_history: List[LLMContextMessage]) -> tuple[List[LLMContextMessage], str]: + """选择真正发送给 LLM 的上下文消息。 + + Args: + chat_history: 当前全部对话历史。 + + Returns: + tuple[List[LLMContextMessage], str]: `(已选上下文, 选择说明)`。 + """ + + max_context_size = max(1, int(global_config.chat.max_context_size)) + selected_indices: List[int] = [] + counted_message_count = 0 + + for index in range(len(chat_history) - 1, -1, -1): + message = chat_history[index] + if message.to_llm_message() is None: + continue + + selected_indices.append(index) + if message.count_in_context: + counted_message_count += 1 + if counted_message_count >= max_context_size: + break + + if not selected_indices: + return [], f"上下文判定:最近 {max_context_size} 条 user/assistant(当前 0 条)" + + selected_indices.reverse() + selected_history = [chat_history[index] for index in selected_indices] + return ( + selected_history, + ( + f"上下文判定:最近 {max_context_size} 条 user/assistant;" + f"展示并发送窗口内消息 {len(selected_history)} 条" + ), + ) + + @staticmethod + def build_chat_context(user_text: str) -> List[LLMContextMessage]: + """根据用户输入构造最小对话上下文。 + + Args: + user_text: 用户输入文本。 + + Returns: + List[LLMContextMessage]: 构造好的上下文消息列表。 + """ + + timestamp = datetime.now() + visible_text = format_speaker_content( + global_config.maisaka.user_name.strip() or "用户", + user_text, + timestamp, + ) + planner_prefix = ( + f"[时间]{timestamp.strftime('%H:%M:%S')}\n" + f"[用户]{global_config.maisaka.user_name.strip() or '用户'}\n" + "[用户群昵称]\n" + "[msg_id]\n" + "[发言内容]" + ) + return [ + SessionBackedMessage( + raw_message=MessageSequence([TextComponent(f"{planner_prefix}{user_text}")]), + visible_text=visible_text, + timestamp=timestamp, + source_kind="user", + ) + ] diff --git a/src/maisaka/cli.py b/src/maisaka/cli.py deleted file mode 100644 index ba4c85b1..00000000 --- a/src/maisaka/cli.py +++ /dev/null @@ -1,448 +0,0 @@ -""" -MaiSaka CLI and conversation loop. -""" - -from datetime import datetime -from typing import Optional - -import asyncio -import os - -from rich import box -from rich.markdown import Markdown -from rich.panel import Panel -from rich.text import Text - -from src.common.data_models.mai_message_data_model import MaiMessage -from src.config.config import global_config - -from .config import ( - ENABLE_COGNITION_MODULE, - ENABLE_EMOTION_MODULE, - ENABLE_KNOWLEDGE_MODULE, - ENABLE_MCP, - ENABLE_TIMING_MODULE, - SHOW_THINKING, - USER_NAME, - console, -) -from .input_reader import InputReader -from .knowledge import retrieve_relevant_knowledge -from .knowledge_store import get_knowledge_store -from .llm_service import MaiSakaLLMService, build_message, remove_last_perception -from .message_adapter import format_speaker_content -from .mcp_client import MCPManager -from .timing import build_timing_info -from .tool_handlers import ( - ToolHandlerContext, - handle_list_files, - handle_mcp_tool, - handle_read_file, - handle_stop, - handle_unknown_tool, - handle_wait, - handle_write_file, -) - - -class BufferCLI: - """Command line interface for Maisaka.""" - - def __init__(self): - self.llm_service: Optional[MaiSakaLLMService] = None - self._reader = InputReader() - self._chat_history: Optional[list[MaiMessage]] = None - self._knowledge_store = get_knowledge_store() - - knowledge_stats = self._knowledge_store.get_stats() - if knowledge_stats["total_items"] > 0: - console.print(f"[success][OK] Knowledge store: {knowledge_stats['total_items']} item(s)[/success]") - else: - console.print("[muted][OK] Knowledge store: initialized with no data[/muted]") - - self._chat_start_time: Optional[datetime] = None - self._last_user_input_time: Optional[datetime] = None - self._last_assistant_response_time: Optional[datetime] = None - self._user_input_times: list[datetime] = [] - self._mcp_manager: Optional[MCPManager] = None - self._init_llm() - - def _init_llm(self): - """Initialize the LLM service from the main project config.""" - thinking_env = os.getenv("ENABLE_THINKING", "").strip().lower() - enable_thinking: Optional[bool] = True if thinking_env == "true" else False if thinking_env == "false" else None - - self.llm_service = MaiSakaLLMService( - api_key="", - base_url=None, - model="", - enable_thinking=enable_thinking, - ) - - model_name = self.llm_service._model_name - console.print(f"[success][OK] LLM service initialized[/success] [muted](model: {model_name})[/muted]") - - def _build_tool_context(self) -> ToolHandlerContext: - """Build the shared tool handler context.""" - ctx = ToolHandlerContext( - llm_service=self.llm_service, - reader=self._reader, - user_input_times=self._user_input_times, - ) - ctx.last_user_input_time = self._last_user_input_time - return ctx - - def _show_banner(self): - """Render the startup banner.""" - banner = Text() - banner.append("MaiSaka", style="bold cyan") - banner.append(" v2.0\n", style="muted") - banner.append("Type to chat | Ctrl+C to exit", style="muted") - - console.print(Panel(banner, box=box.DOUBLE_EDGE, border_style="cyan", padding=(1, 2))) - console.print() - - async def _start_chat(self, user_text: str): - """Append user input and continue the inner loop.""" - if not self.llm_service: - console.print("[warning]LLM service is not initialized; skipping chat.[/warning]") - return - - now = datetime.now() - self._last_user_input_time = now - self._user_input_times.append(now) - - if self._chat_history is None: - self._chat_start_time = now - self._last_assistant_response_time = None - self._chat_history = self.llm_service.build_chat_context(user_text) - else: - self._chat_history.append(build_message(role="user", content=format_speaker_content(USER_NAME, user_text))) - - await self._run_llm_loop(self._chat_history) - - async def _run_llm_loop(self, chat_history: list[MaiMessage]): - """ - Main inner loop for the Maisaka planner. - - Each round may produce internal thoughts and optionally call tools: - - reply(): generate a visible reply for the current round - - no_reply(): skip visible output and continue the loop - - wait(seconds): wait for new user input - - stop(): stop the current inner loop and return to idle - - Per round: - 1. Run enabled analysis modules in parallel when the previous round used tools. - 2. Call the planner model with the current history. - 3. Append the assistant thought and execute any requested tools. - """ - consecutive_errors = 0 - last_had_tool_calls = True - - while True: - if last_had_tool_calls: - timing_info = build_timing_info( - self._chat_start_time, - self._last_user_input_time, - self._last_assistant_response_time, - self._user_input_times, - ) - - tasks = [] - status_text_parts = [] - - if ENABLE_EMOTION_MODULE: - tasks.append(("eq", self.llm_service.analyze_emotion(chat_history))) - status_text_parts.append("emotion") - if ENABLE_COGNITION_MODULE: - tasks.append(("cognition", self.llm_service.analyze_cognition(chat_history))) - status_text_parts.append("cognition") - if ENABLE_TIMING_MODULE: - tasks.append(("timing", self.llm_service.analyze_timing(chat_history, timing_info))) - status_text_parts.append("timing") - if ENABLE_KNOWLEDGE_MODULE: - tasks.append(("knowledge", retrieve_relevant_knowledge(self.llm_service, chat_history))) - status_text_parts.append("knowledge") - - with console.status( - f"[info]{' + '.join(status_text_parts)} analyzing...[/info]", - spinner="dots", - ): - results = await asyncio.gather(*[task for _, task in tasks], return_exceptions=True) - - eq_result, cognition_result, timing_result, knowledge_result = None, None, None, None - result_idx = 0 - if ENABLE_EMOTION_MODULE: - eq_result = results[result_idx] - result_idx += 1 - if ENABLE_COGNITION_MODULE: - cognition_result = results[result_idx] - result_idx += 1 - if ENABLE_TIMING_MODULE: - timing_result = results[result_idx] - result_idx += 1 - if ENABLE_KNOWLEDGE_MODULE: - knowledge_result = results[result_idx] - result_idx += 1 - - eq_analysis = "" - if ENABLE_EMOTION_MODULE: - if isinstance(eq_result, Exception): - console.print(f"[warning]Emotion analysis failed: {eq_result}[/warning]") - elif eq_result: - eq_analysis = eq_result - if SHOW_THINKING: - console.print( - Panel( - Markdown(eq_analysis), - title="Emotion", - border_style="bright_yellow", - padding=(0, 1), - style="dim", - ) - ) - - cognition_analysis = "" - if ENABLE_COGNITION_MODULE: - if isinstance(cognition_result, Exception): - console.print(f"[warning]Cognition analysis failed: {cognition_result}[/warning]") - elif cognition_result: - cognition_analysis = cognition_result - if SHOW_THINKING: - console.print( - Panel( - Markdown(cognition_analysis), - title="Cognition", - border_style="bright_cyan", - padding=(0, 1), - style="dim", - ) - ) - - timing_analysis = "" - if ENABLE_TIMING_MODULE: - if isinstance(timing_result, Exception): - console.print(f"[warning]Timing analysis failed: {timing_result}[/warning]") - elif timing_result: - timing_analysis = timing_result - if SHOW_THINKING: - console.print( - Panel( - Markdown(timing_analysis), - title="Timing", - border_style="bright_blue", - padding=(0, 1), - style="dim", - ) - ) - - knowledge_analysis = "" - if ENABLE_KNOWLEDGE_MODULE: - if isinstance(knowledge_result, Exception): - console.print(f"[warning]Knowledge analysis failed: {knowledge_result}[/warning]") - elif knowledge_result: - knowledge_analysis = knowledge_result - if SHOW_THINKING: - console.print( - Panel( - Markdown(knowledge_analysis), - title="Knowledge", - border_style="bright_magenta", - padding=(0, 1), - style="dim", - ) - ) - - remove_last_perception(chat_history) - - perception_parts = [] - if eq_analysis: - perception_parts.append(f"Emotion\n{eq_analysis}") - if cognition_analysis: - perception_parts.append(f"Cognition\n{cognition_analysis}") - if timing_analysis: - perception_parts.append(f"Timing\n{timing_analysis}") - if knowledge_analysis: - perception_parts.append(f"Knowledge\n{knowledge_analysis}") - - if perception_parts: - chat_history.append( - build_message( - role="assistant", - content="\n\n".join(perception_parts), - message_kind="perception", - source="assistant", - ) - ) - else: - if SHOW_THINKING: - console.print("[muted]Skipping module analysis because the last round used no tools.[/muted]") - - with console.status("[info]AI is thinking...[/info]", spinner="dots"): - try: - response = await self.llm_service.chat_loop_step(chat_history) - consecutive_errors = 0 - except Exception as exc: - consecutive_errors += 1 - console.print(f"[error]LLM call failed: {exc}[/error]") - if consecutive_errors >= 3: - console.print("[error]Too many consecutive errors. Exiting chat.[/error]\n") - break - continue - - chat_history.append(response.raw_message) - self._last_assistant_response_time = datetime.now() - - if SHOW_THINKING and response.content: - console.print( - Panel( - Markdown(response.content), - title="Thought", - border_style="dim", - padding=(1, 2), - style="dim", - ) - ) - - if response.content and not response.tool_calls: - last_had_tool_calls = False - continue - - if response.tool_calls: - should_stop = False - ctx = self._build_tool_context() - - for tc in response.tool_calls: - if tc.func_name == "stop": - await handle_stop(tc, chat_history) - should_stop = True - - elif tc.func_name == "reply": - reply = await self._generate_visible_reply(chat_history, response.content) - chat_history.append( - build_message( - role="tool", - content="Visible reply generated and recorded.", - source="tool", - tool_call_id=tc.call_id, - ) - ) - chat_history.append( - build_message( - role="user", - content=format_speaker_content(global_config.bot.nickname.strip() or "MaiSaka", reply), - source="guided_reply", - ) - ) - - elif tc.func_name == "no_reply": - if SHOW_THINKING: - console.print("[muted]No visible reply this round.[/muted]") - chat_history.append( - build_message( - role="tool", - content="No visible reply was sent for this round.", - source="tool", - tool_call_id=tc.call_id, - ) - ) - - elif tc.func_name == "wait": - tool_result = await handle_wait(tc, chat_history, ctx) - if ctx.last_user_input_time != self._last_user_input_time: - self._last_user_input_time = ctx.last_user_input_time - if tool_result.startswith("[[QUIT]]"): - should_stop = True - - elif tc.func_name == "write_file": - await handle_write_file(tc, chat_history) - - elif tc.func_name == "read_file": - await handle_read_file(tc, chat_history) - - elif tc.func_name == "list_files": - await handle_list_files(tc, chat_history) - - elif self._mcp_manager and self._mcp_manager.is_mcp_tool(tc.func_name): - await handle_mcp_tool(tc, chat_history, self._mcp_manager) - - else: - await handle_unknown_tool(tc, chat_history) - - if should_stop: - console.print("[muted]Conversation paused. Waiting for new input...[/muted]\n") - break - - last_had_tool_calls = True - else: - last_had_tool_calls = False - continue - - async def _init_mcp(self): - """Initialize MCP servers and register exposed tools.""" - config_path = os.path.join( - os.path.dirname(os.path.abspath(__file__)), - "mcp_config.json", - ) - self._mcp_manager = await MCPManager.from_config(config_path) - - if self._mcp_manager and self.llm_service: - mcp_tools = self._mcp_manager.get_openai_tools() - if mcp_tools: - self.llm_service.set_extra_tools(mcp_tools) - summary = self._mcp_manager.get_tool_summary() - console.print( - Panel( - f"Loaded {len(mcp_tools)} MCP tool(s):\n{summary}", - title="MCP Tools", - border_style="green", - padding=(0, 1), - ) - ) - - async def _generate_visible_reply(self, chat_history: list[MaiMessage], latest_thought: str) -> str: - """Generate and emit a visible reply based on the latest thought.""" - if not self.llm_service or not latest_thought: - return "" - - with console.status("[info]Generating visible reply...[/info]", spinner="dots"): - reply = await self.llm_service.generate_reply(latest_thought, chat_history) - - console.print( - Panel( - Markdown(reply), - title="MaiSaka", - border_style="magenta", - padding=(1, 2), - ) - ) - - return reply - - async def run(self): - """Main interactive loop.""" - if ENABLE_MCP: - await self._init_mcp() - else: - console.print("[muted]MCP is disabled (ENABLE_MCP=false)[/muted]") - - self._reader.start(asyncio.get_event_loop()) - self._show_banner() - - try: - while True: - console.print("[bold cyan]> [/bold cyan]", end="") - raw_input = await self._reader.get_line() - - if raw_input is None: - console.print("\n[muted]Goodbye![/muted]") - break - - raw_input = raw_input.strip() - if not raw_input: - continue - - await self._start_chat(raw_input) - finally: - if self._mcp_manager: - await self._mcp_manager.close() diff --git a/src/maisaka/config.py b/src/maisaka/config.py deleted file mode 100644 index b454ddfd..00000000 --- a/src/maisaka/config.py +++ /dev/null @@ -1,47 +0,0 @@ -""" -MaiSaka - 全局配置 -从主项目配置系统读取配置、Rich Console 实例、主题定义。 -""" - -from pathlib import Path -import sys - -from rich.console import Console -from rich.theme import Theme - -from src.config.config import global_config - -# 添加项目根目录到路径以导入主配置 -_root = Path(__file__).parent.parent.parent.absolute() -if str(_root) not in sys.path: - sys.path.insert(0, str(_root)) - -# ──────────────────── 模块开关配置 ──────────────────── -ENABLE_EMOTION_MODULE = global_config.maisaka.enable_emotion_module -ENABLE_COGNITION_MODULE = global_config.maisaka.enable_cognition_module -ENABLE_TIMING_MODULE = global_config.maisaka.enable_timing_module -ENABLE_KNOWLEDGE_MODULE = global_config.maisaka.enable_knowledge_module -ENABLE_MCP = global_config.maisaka.enable_mcp -ENABLE_WRITE_FILE = global_config.maisaka.enable_write_file -ENABLE_READ_FILE = global_config.maisaka.enable_read_file -ENABLE_LIST_FILES = global_config.maisaka.enable_list_files -SHOW_ANALYZE_COGNITION_PROMPT = global_config.maisaka.show_analyze_cognition_prompt -SHOW_ANALYZE_TIMING_PROMPT = global_config.maisaka.show_analyze_timing_prompt -SHOW_THINKING = global_config.maisaka.show_thinking -USER_NAME = global_config.maisaka.user_name.strip() or "用户" - - -# ──────────────────── Rich 主题 & Console ──────────────────── - -custom_theme = Theme( - { - "info": "cyan", - "success": "green", - "warning": "yellow", - "error": "bold red", - "muted": "dim", - "accent": "bold magenta", - } -) - -console = Console(theme=custom_theme) diff --git a/src/maisaka/context_messages.py b/src/maisaka/context_messages.py new file mode 100644 index 00000000..8da06a23 --- /dev/null +++ b/src/maisaka/context_messages.py @@ -0,0 +1,275 @@ +"""Maisaka 内部上下文消息抽象。""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from datetime import datetime +from enum import Enum +from io import BytesIO +from typing import Optional +import base64 + +from PIL import Image as PILImage + +from src.chat.message_receive.message import SessionMessage +from src.common.data_models.message_component_data_model import EmojiComponent, ImageComponent, MessageSequence, TextComponent +from src.llm_models.payload_content.message import Message, MessageBuilder, RoleType +from src.llm_models.payload_content.tool_option import ToolCall + + +def _guess_image_format(image_bytes: bytes) -> Optional[str]: + if not image_bytes: + return None + + try: + with PILImage.open(BytesIO(image_bytes)) as image: + return image.format.lower() if image.format else None + except Exception: + return None + + +def _build_message_from_sequence( + role: RoleType, + message_sequence: MessageSequence, + fallback_text: str, + *, + tool_call_id: Optional[str] = None, + tool_calls: Optional[list[ToolCall]] = None, +) -> Optional[Message]: + """根据消息片段构造统一 LLM 消息。""" + builder = MessageBuilder().set_role(role) + if role == RoleType.Assistant and tool_calls: + builder.set_tool_calls(tool_calls) + if role == RoleType.Tool and tool_call_id: + builder.add_tool_call(tool_call_id) + + has_content = False + for component in message_sequence.components: + if isinstance(component, TextComponent): + if component.text: + builder.add_text_content(component.text) + has_content = True + continue + + if isinstance(component, (EmojiComponent, ImageComponent)): + image_format = _guess_image_format(component.binary_data) + if image_format and component.binary_data: + builder.add_image_content(image_format, base64.b64encode(component.binary_data).decode("utf-8")) + has_content = True + continue + + if component.content: + builder.add_text_content(component.content) + has_content = True + + if not has_content and fallback_text: + builder.add_text_content(fallback_text) + has_content = True + + if not has_content and not (role == RoleType.Assistant and tool_calls): + return None + return builder.build() + + +class ReferenceMessageType(str, Enum): + """参考消息类型。""" + + CUSTOM = "custom" + JARGON = "jargon" + KNOWLEDGE = "knowledge" + MEMORY = "memory" + TOOL_HINT = "tool_hint" + + +class LLMContextMessage(ABC): + """Maisaka 内部用于组织 LLM 上下文的统一消息抽象。""" + + timestamp: datetime + + @property + @abstractmethod + def role(self) -> str: + """返回 LLM 消息角色。""" + + @property + @abstractmethod + def processed_plain_text(self) -> str: + """返回可读的纯文本内容。""" + + @property + def count_in_context(self) -> bool: + """是否占用普通 user/assistant 上下文窗口。""" + return True + + @property + def source(self) -> str: + """返回消息来源。""" + return self.__class__.__name__ + + @abstractmethod + def to_llm_message(self) -> Optional[Message]: + """转换为统一 LLM 消息。""" + + def consume_once(self) -> bool: + """消费一次生命周期,返回是否继续保留。""" + return True + + +@dataclass(slots=True) +class SessionBackedMessage(LLMContextMessage): + """真实会话上下文消息。""" + + raw_message: MessageSequence + visible_text: str + timestamp: datetime + message_id: Optional[str] = None + original_message: Optional[SessionMessage] = None + source_kind: str = "user" + + @property + def role(self) -> str: + return RoleType.User.value + + @property + def processed_plain_text(self) -> str: + return self.visible_text + + @property + def source(self) -> str: + return self.source_kind + + def to_llm_message(self) -> Optional[Message]: + return _build_message_from_sequence( + RoleType.User, + self.raw_message, + self.processed_plain_text, + ) + + @classmethod + def from_session_message( + cls, + session_message: SessionMessage, + *, + raw_message: MessageSequence, + visible_text: str, + source_kind: str = "user", + ) -> "SessionBackedMessage": + """从真实 SessionMessage 构造上下文消息。""" + return cls( + raw_message=raw_message, + visible_text=visible_text, + timestamp=session_message.timestamp, + message_id=session_message.message_id, + original_message=session_message, + source_kind=source_kind, + ) + + +@dataclass(slots=True) +class ReferenceMessage(LLMContextMessage): + """参考消息。""" + + content: str + timestamp: datetime + reference_type: ReferenceMessageType = ReferenceMessageType.CUSTOM + remaining_uses_value: Optional[int] = 1 + display_prefix: str = "[参考消息]" + + @property + def role(self) -> str: + return RoleType.User.value + + @property + def processed_plain_text(self) -> str: + return f"{self.display_prefix}\n{self.content}".strip() + + @property + def count_in_context(self) -> bool: + return False + + @property + def source(self) -> str: + return self.reference_type.value + + def to_llm_message(self) -> Optional[Message]: + message_sequence = MessageSequence([TextComponent(self.processed_plain_text)]) + return _build_message_from_sequence(RoleType.User, message_sequence, self.processed_plain_text) + + def consume_once(self) -> bool: + if self.remaining_uses_value is None: + return True + + self.remaining_uses_value -= 1 + return self.remaining_uses_value > 0 + + +@dataclass(slots=True) +class AssistantMessage(LLMContextMessage): + """内部 assistant 消息。""" + + content: str + timestamp: datetime + tool_calls: list[ToolCall] = field(default_factory=list) + source_kind: str = "assistant" + + @property + def role(self) -> str: + return RoleType.Assistant.value + + @property + def processed_plain_text(self) -> str: + return self.content + + @property + def count_in_context(self) -> bool: + return self.source_kind != "perception" + + @property + def source(self) -> str: + return self.source_kind + + def to_llm_message(self) -> Optional[Message]: + message_sequence = MessageSequence([]) + if self.content: + message_sequence.text(self.content) + return _build_message_from_sequence( + RoleType.Assistant, + message_sequence, + self.content, + tool_calls=self.tool_calls or None, + ) + + +@dataclass(slots=True) +class ToolResultMessage(LLMContextMessage): + """工具返回结果消息。""" + + content: str + timestamp: datetime + tool_call_id: str + tool_name: str = "" + success: bool = True + + @property + def role(self) -> str: + return RoleType.Tool.value + + @property + def processed_plain_text(self) -> str: + return self.content + + @property + def count_in_context(self) -> bool: + return False + + @property + def source(self) -> str: + return self.tool_name or "tool" + + def to_llm_message(self) -> Optional[Message]: + message_sequence = MessageSequence([TextComponent(self.content)]) + return _build_message_from_sequence( + RoleType.Tool, + message_sequence, + self.content, + tool_call_id=self.tool_call_id, + ) diff --git a/src/maisaka/emotion.py b/src/maisaka/emotion.py deleted file mode 100644 index d2a4f657..00000000 --- a/src/maisaka/emotion.py +++ /dev/null @@ -1,59 +0,0 @@ -""" -MaiSaka - Emotion 模块 -情绪感知分析,分析用户的情绪状态和言语态度。 - -注意:emotion.prompt 已迁移至主项目 prompts/ 目录 -使用 prompt_manager.get_prompt("maidairy_emotion") 加载。 -""" - -from typing import List, Optional - -from src.common.data_models.mai_message_data_model import MaiMessage - -from .config import USER_NAME -from .message_adapter import get_message_role, get_message_text - - -def extract_user_messages(chat_history: List[MaiMessage], limit: Optional[int] = None) -> List[MaiMessage]: - """ - 从对话历史中提取用户消息。 - - Args: - chat_history: 完整的对话历史 - limit: 最多提取多少条用户消息,None 表示不限制 - - Returns: - 只包含用户消息的列表 - """ - user_messages = [msg for msg in chat_history if get_message_role(msg) == "user"] - if limit and len(user_messages) > limit: - return user_messages[-limit:] - return user_messages - - -def build_emotion_context(chat_history: List[MaiMessage]) -> str: - """ - 构建用于情绪分析的对话上下文文本。 - - Args: - chat_history: 完整的对话历史 - - Returns: - 格式化后的对话上下文文本 - """ - # 获取最近的对话(约 8-10 条消息) - recent_messages = chat_history[-10:] if len(chat_history) > 10 else chat_history - - context_parts = [] - for msg in recent_messages: - role = get_message_role(msg) - content = get_message_text(msg) - - if role == "user": - context_parts.append(f"{USER_NAME}: {content}") - elif role == "assistant": - # 只显示 assistant 的实际发言,跳过感知信息 - if "【AI 感知】" not in content: - context_parts.append(f"助手: {content}") - - return "\n".join(context_parts) diff --git a/src/maisaka/knowledge.py b/src/maisaka/knowledge.py deleted file mode 100644 index f56fbdc5..00000000 --- a/src/maisaka/knowledge.py +++ /dev/null @@ -1,58 +0,0 @@ -""" -MaiSaka knowledge retrieval helpers. -""" - -from typing import List - -from src.common.data_models.mai_message_data_model import MaiMessage - -from .knowledge_store import KNOWLEDGE_CATEGORIES, get_knowledge_store - -NO_RESULT_KEYWORDS = [ - "\u65e0", - "\u6ca1\u6709", - "\u4e0d\u9002\u7528", - "\u65e0\u9700", - "\u65e0\u76f8\u5173", -] - - -def extract_category_ids_from_result(result: str) -> List[str]: - """Extract valid category ids from an LLM result string.""" - if not result: - return [] - - normalized = result.strip() - if not normalized: - return [] - - lowered = normalized.lower() - if any(keyword in lowered for keyword in ["none", "no relevant", "no_need", "no need"]): - return [] - if any(keyword in normalized for keyword in NO_RESULT_KEYWORDS): - return [] - - category_ids: List[str] = [] - for part in normalized.replace(",", " ").replace("\uff0c", " ").replace("\n", " ").split(): - candidate = part.strip() - if candidate in KNOWLEDGE_CATEGORIES and candidate not in category_ids: - category_ids.append(candidate) - - return category_ids - - -async def retrieve_relevant_knowledge( - llm_service, - chat_history: List[MaiMessage], -) -> str: - """Retrieve formatted knowledge snippets relevant to the current chat history.""" - store = get_knowledge_store() - categories_summary = store.get_categories_summary() - - try: - category_ids = await llm_service.analyze_knowledge_need(chat_history, categories_summary) - if not category_ids: - return "" - return store.get_formatted_knowledge(category_ids) - except Exception: - return "" diff --git a/src/maisaka/knowledge_store.py b/src/maisaka/knowledge_store.py deleted file mode 100644 index f91573d2..00000000 --- a/src/maisaka/knowledge_store.py +++ /dev/null @@ -1,190 +0,0 @@ -""" -MaiSaka - 了解列表持久化存储 -存储用户个人特征信息,支持层级结构和本地持久化。 -""" - -import json -import os -from pathlib import Path -from typing import Dict, List, Optional, Any -from datetime import datetime - -# 数据目录 - 项目根目录下的 mai_knowledge -PROJECT_ROOT = Path(os.path.dirname(os.path.abspath(__file__))) -KNOWLEDGE_DATA_DIR = PROJECT_ROOT / "mai_knowledge" -KNOWLEDGE_FILE = KNOWLEDGE_DATA_DIR / "knowledge.json" - - -# 个人特征分类列表(预定义) -KNOWLEDGE_CATEGORIES = { - "1": "性别", - "2": "性格", - "3": "饮食口味", - "4": "交友喜好", - "5": "情绪/理性倾向", - "6": "兴趣爱好", - "7": "职业/专业", - "8": "生活习惯", - "9": "价值观", - "10": "沟通风格", - "11": "学习方式", - "12": "压力应对方式", -} - - -class KnowledgeStore: - """ - 了解列表存储。 - - 特性: - - 持久化到 JSON 文件 - - 层级结构存储(按分类) - - 支持增量更新 - - 启动时自动加载 - """ - - def __init__(self): - """初始化了解存储""" - self._knowledge: Dict[str, List[Dict[str, Any]]] = {category_id: [] for category_id in KNOWLEDGE_CATEGORIES} - self._ensure_data_dir() - self._load() - - def _ensure_data_dir(self): - """确保数据目录存在""" - KNOWLEDGE_DATA_DIR.mkdir(parents=True, exist_ok=True) - - def _load(self): - """从文件加载了解数据""" - if not KNOWLEDGE_FILE.exists(): - self._knowledge = {category_id: [] for category_id in KNOWLEDGE_CATEGORIES} - return - - try: - with open(KNOWLEDGE_FILE, "r", encoding="utf-8") as f: - loaded = json.load(f) - # 确保所有分类都存在 - for category_id in KNOWLEDGE_CATEGORIES: - if category_id not in loaded: - loaded[category_id] = [] - self._knowledge = loaded - except Exception as e: - print(f"[warning]加载了解数据失败: {e}[/warning]") - self._knowledge = {category_id: [] for category_id in KNOWLEDGE_CATEGORIES} - - def _save(self): - """保存了解数据到文件""" - try: - with open(KNOWLEDGE_FILE, "w", encoding="utf-8") as f: - json.dump(self._knowledge, f, ensure_ascii=False, indent=2) - except Exception as e: - print(f"[warning]保存了解数据失败: {e}[/warning]") - - def add_knowledge( - self, - category_id: str, - content: str, - metadata: Optional[Dict[str, Any]] = None, - ) -> bool: - """ - 添加一条了解信息。 - - Args: - category_id: 分类编号 - content: 了解内容 - metadata: 元数据 - - Returns: - 是否添加成功 - """ - if category_id not in KNOWLEDGE_CATEGORIES: - return False - - try: - knowledge_item = { - "id": f"know_{category_id}_{datetime.now().timestamp()}", - "content": content, - "metadata": metadata or {}, - "created_at": datetime.now().isoformat(), - } - self._knowledge[category_id].append(knowledge_item) - self._save() - return True - except Exception: - return False - - def get_category_knowledge(self, category_id: str) -> List[Dict[str, Any]]: - """ - 获取某个分类的所有了解信息。 - - Args: - category_id: 分类编号 - - Returns: - 该分类的所有了解信息 - """ - return self._knowledge.get(category_id, []) - - def get_all_knowledge(self) -> Dict[str, List[Dict[str, Any]]]: - """获取所有了解信息""" - return self._knowledge - - def get_category_name(self, category_id: str) -> str: - """获取分类名称""" - return KNOWLEDGE_CATEGORIES.get(category_id, "未知分类") - - def get_categories_summary(self) -> str: - """获取所有分类的摘要(用于 LLM 展示)""" - lines = [] - for category_id, category_name in KNOWLEDGE_CATEGORIES.items(): - count = len(self._knowledge.get(category_id, [])) - if count > 0: - lines.append(f"{category_id}. {category_name} ({count}条)") - else: - lines.append(f"{category_id}. {category_name} (无数据)") - return "\n".join(lines) - - def get_formatted_knowledge(self, category_ids: List[str]) -> str: - """ - 获取指定分类的了解内容,格式化为文本。 - - Args: - category_ids: 分类编号列表 - - Returns: - 格式化后的了解内容文本 - """ - parts = [] - for category_id in category_ids: - category_name = self.get_category_name(category_id) - items = self.get_category_knowledge(category_id) - - if items: - parts.append(f"【{category_name}】") - for item in items: - content = item.get("content", "") - parts.append(f" - {content}") - - return "\n".join(parts) if parts else "暂无相关了解信息" - - def get_stats(self) -> Dict[str, Any]: - """获取了解数据统计信息""" - total_items = sum(len(items) for items in self._knowledge.values()) - return { - "total_categories": len(KNOWLEDGE_CATEGORIES), - "total_items": total_items, - "data_file": str(KNOWLEDGE_FILE), - "data_exists": KNOWLEDGE_FILE.exists(), - "data_size_kb": KNOWLEDGE_FILE.stat().st_size / 1024 if KNOWLEDGE_FILE.exists() else 0, - } - - -# 全局单例 -_knowledge_store_instance: Optional[KnowledgeStore] = None - - -def get_knowledge_store() -> KnowledgeStore: - """获取了解存储实例(单例模式)""" - global _knowledge_store_instance - if _knowledge_store_instance is None: - _knowledge_store_instance = KnowledgeStore() - return _knowledge_store_instance diff --git a/src/maisaka/llm_service.py b/src/maisaka/llm_service.py deleted file mode 100644 index 63b6d505..00000000 --- a/src/maisaka/llm_service.py +++ /dev/null @@ -1,600 +0,0 @@ -""" -MaiSaka LLM 服务 - 使用主项目 LLM 系统 -将主项目的 LLMRequest 适配为 MaiSaka 需要的接口 -""" - -from datetime import datetime - -import random -from dataclasses import dataclass -from typing import Any, List, Optional - -from rich.console import Group -from rich.panel import Panel -from rich.pretty import Pretty -from rich.text import Text - -from src.common.data_models.mai_message_data_model import MaiMessage -from src.common.logger import get_logger -from src.config.config import config_manager, global_config -from src.llm_models.payload_content.message import Message, MessageBuilder, RoleType -from src.llm_models.payload_content.tool_option import ToolCall, ToolOption -from src.llm_models.utils_model import LLMRequest -from src.prompt.prompt_manager import prompt_manager - -from . import config -from .config import console -from .builtin_tools import get_builtin_tools -from .message_adapter import ( - build_message, - format_speaker_content, - get_message_kind, - get_message_role, - get_message_text, - get_tool_call_id, - get_tool_calls, - remove_last_perception, - to_llm_message, -) - -logger = get_logger("maisaka_llm") - -@dataclass -class ChatResponse: - """LLM 对话循环单步响应""" - - content: Optional[str] - tool_calls: List[ToolCall] - raw_message: MaiMessage - - -class MaiSakaLLMService: - """MaiSaka LLM 服务 - 适配主项目 LLM 系统""" - - def __init__( - self, - api_key: Optional[str] = None, - base_url: Optional[str] = None, - model: Optional[str] = None, - chat_system_prompt: Optional[str] = None, - temperature: float = 0.5, - max_tokens: int = 2048, - enable_thinking: Optional[bool] = None, - ): - """ - 初始化 LLM 服务 - - 参数仅为兼容性保留,实际使用主项目配置 - """ - self._temperature = temperature - self._max_tokens = max_tokens - self._enable_thinking = enable_thinking - self._extra_tools: List[dict] = [] - - # 获取主项目模型配置 - try: - model_config = config_manager.get_model_config() - self._model_configs = model_config.model_task_config - except Exception: - # 如果配置加载失败,使用默认配置 - from src.config.model_configs import ModelTaskConfig - - self._model_configs = ModelTaskConfig() - logger.warning("无法加载主项目模型配置,使用默认配置") - - # 初始化 LLMRequest 实例(只使用 tool_use 和 replyer) - self._llm_tool_use = LLMRequest(model_set=self._model_configs.tool_use, request_type="maisaka_tool_use") - # 主对话也使用 tool_use 模型(因为需要工具调用支持) - self._llm_planner = LLMRequest(model_set=self._model_configs.planner, request_type="maisaka_planner") - self._llm_chat = self._llm_planner - self._llm_utils = self._llm_tool_use - # 回复生成使用 replyer 模型 - self._llm_replyer = LLMRequest(model_set=self._model_configs.replyer, request_type="maisaka_replyer") - - # 尝试修复数据库 schema(忽略错误) - self._try_fix_database_schema() - - # 构建人设信息 - personality_prompt = self._build_personality_prompt() - - # 加载系统提示词 - if chat_system_prompt is None: - try: - chat_prompt = prompt_manager.get_prompt("maidairy_chat") - tools_section = "" - if config.ENABLE_WRITE_FILE: - tools_section += "\n• write_file(filename, content) — 在 mai_files 目录下写入文件。" - if config.ENABLE_READ_FILE: - tools_section += "\n• read_file(filename) — 读取 mai_files 目录下的文件内容。" - if config.ENABLE_LIST_FILES: - tools_section += "\n• list_files() — 获取 mai_files 目录下所有文件的元信息列表。" - - chat_prompt.add_context("file_tools_section", tools_section if tools_section else "") - chat_prompt.add_context("bot_name", global_config.bot.nickname) - chat_prompt.add_context("identity", personality_prompt) - import asyncio - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - self._chat_system_prompt = loop.run_until_complete(prompt_manager.render_prompt(chat_prompt)) - logger.info(f"系统提示词已渲染,长度: {len(self._chat_system_prompt)}") - finally: - loop.close() - except Exception as e: - logger.error(f"加载系统提示词失败: {e}") - self._chat_system_prompt = f"{personality_prompt}\n\n你是一个友好的 AI 助手。" - else: - self._chat_system_prompt = chat_system_prompt - - self._model_name = ( - self._model_configs.planner.model_list[0] if self._model_configs.planner.model_list else "未配置" - ) - - - # 加载子模块提示词 - self._emotion_prompt: Optional[str] = None - self._cognition_prompt: Optional[str] = None - self._timing_prompt: Optional[str] = None - try: - import asyncio - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - self._emotion_prompt = loop.run_until_complete( - prompt_manager.render_prompt(prompt_manager.get_prompt("maidairy_emotion")) - ) - self._cognition_prompt = loop.run_until_complete( - prompt_manager.render_prompt(prompt_manager.get_prompt("maidairy_cognition")) - ) - self._timing_prompt = loop.run_until_complete( - prompt_manager.render_prompt(prompt_manager.get_prompt("maidairy_timing")) - ) - logger.info("成功加载 MaiSaka 子模块提示词") - finally: - loop.close() - except Exception as e: - logger.warning(f"加载子模块提示词失败,将使用默认提示词: {e}") - - def _try_fix_database_schema(self) -> None: - """尝试修复数据库 schema,添加缺失的列""" - try: - from src.common.database.database_client import get_db_session - from sqlalchemy import text - - with get_db_session() as session: - # 检查 model_api_provider_name 列是否存在 - result = session.execute(text("PRAGMA table_info(llm_usage)")) - columns = [row[1] for row in result.fetchall()] - - if "model_api_provider_name" not in columns: - # 添加缺失的列 - session.execute(text("ALTER TABLE llm_usage ADD COLUMN model_api_provider_name VARCHAR(255)")) - session.commit() - logger.info("数据库 schema 已修复:添加 model_api_provider_name 列") - except Exception: - # 静默忽略任何错误,不影响正常流程 - pass - - def _build_personality_prompt(self) -> str: - """构建人设信息,参考 replyer 的做法""" - try: - bot_name = global_config.bot.nickname - if global_config.bot.alias_names: - bot_nickname = f",也有人叫你{','.join(global_config.bot.alias_names)}" - else: - bot_nickname = "" - - # 获取基础personality - prompt_personality = global_config.personality.personality - - # 检查是否需要随机替换为状态(personality 本体) - if ( - hasattr(global_config.personality, "states") - and global_config.personality.states - and hasattr(global_config.personality, "state_probability") - and global_config.personality.state_probability > 0 - and random.random() < global_config.personality.state_probability - ): - # 随机选择一个状态替换personality - selected_state = random.choice(global_config.personality.states) - prompt_personality = selected_state - - prompt_personality = f"{prompt_personality};" - return f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}" - except Exception as e: - logger.warning(f"构建人设信息失败: {e}") - # 返回默认人设 - return "你的名字是麦麦,你是一个活泼可爱的AI助手。" - - def set_extra_tools(self, tools: List[dict]) -> None: - """设置额外的工具定义(如 MCP 工具)""" - self._extra_tools = list(tools) - - @staticmethod - def _get_role_badge_style(role: str) -> str: - """为不同 role 返回不同的标签样式。""" - if role == "system": - return "bold white on blue" - if role == "user": - return "bold black on green" - if role == "assistant": - return "bold black on yellow" - if role == "tool": - return "bold white on magenta" - return "bold white on bright_black" - - @staticmethod - def _render_message_content(content: Any) -> object: - """把消息内容转成适合 Rich 输出的 renderable。""" - if isinstance(content, str): - return Text(content) - - if isinstance(content, list): - parts: list[object] = [] - for item in content: - if isinstance(item, dict) and item.get("type") == "text" and isinstance(item.get("text"), str): - parts.append(Text(item["text"])) - else: - parts.append(Pretty(item, expand_all=True)) - return Group(*parts) if parts else Text("") - - if content is None: - return Text("") - - return Pretty(content, expand_all=True) - - @staticmethod - def _format_tool_call_for_display(tool_call: Any) -> dict[str, Any]: - """将 tool call 转成适合 CLI 展示的结构。""" - if isinstance(tool_call, dict): - function_info = tool_call.get("function", {}) - return { - "id": tool_call.get("id"), - "name": function_info.get("name", tool_call.get("name")), - "arguments": function_info.get("arguments", tool_call.get("arguments")), - } - - return { - "id": getattr(tool_call, "call_id", getattr(tool_call, "id", None)), - "name": getattr(tool_call, "func_name", getattr(tool_call, "name", None)), - "arguments": getattr(tool_call, "args", getattr(tool_call, "arguments", None)), - } - - def _render_message_panel(self, message: Any, index: int) -> Panel: - """渲染主循环 prompt 中的一条消息。""" - if isinstance(message, dict): - raw_role = message.get("role", "unknown") - content = message.get("content") - tool_calls = message.get("tool_calls") - tool_call_id = message.get("tool_call_id") - else: - raw_role = getattr(message, "role", "unknown") - content = getattr(message, "content", None) - tool_calls = getattr(message, "tool_calls", None) - tool_call_id = getattr(message, "tool_call_id", None) - - role = raw_role.value if hasattr(raw_role, "value") else str(raw_role) - title = Text.assemble( - Text(f" {role.upper()} ", style=self._get_role_badge_style(role)), - Text(f" #{index}", style="muted"), - ) - - parts: list[object] = [] - if content not in (None, "", []): - parts.append(Text(" message ", style="bold cyan")) - parts.append(self._render_message_content(content)) - - if tool_calls: - parts.append(Text(" tool_calls ", style="bold magenta")) - parts.append( - Pretty( - [self._format_tool_call_for_display(tool_call) for tool_call in tool_calls], - expand_all=True, - ) - ) - - if tool_call_id: - parts.append( - Text.assemble( - Text(" tool_call_id ", style="bold magenta"), - Text(" "), - Text(str(tool_call_id), style="magenta"), - ) - ) - - if not parts: - parts.append(Text("[empty message]", style="muted")) - - return Panel( - Group(*parts), - title=title, - border_style="dim", - padding=(0, 1), - ) - - @staticmethod - def _tool_option_to_dict(tool: "ToolOption") -> dict: - """将 ToolOption 对象转换为主项目期望的 dict 格式 - - 主项目的 _build_tool_options() 期望的格式: - { - "name": str, - "description": str, - "parameters": List[Tuple[name, ToolParamType, description, required, enum_values]] - } - """ - params = [] - if tool.params: - for param in tool.params: - params.append((param.name, param.param_type, param.description, param.required, param.enum_values)) - return {"name": tool.name, "description": tool.description, "parameters": params} - - async def chat_loop_step(self, chat_history: list[MaiMessage]) -> ChatResponse: - """执行对话循环的一步 - 使用 tool_use 模型""" - - def message_factory(client) -> list[Message]: - """将 MaiSaka 的 chat_history 转换为主项目的 Message 格式""" - messages: list[Message] = [] - - # 首先添加系统提示词 - system_msg = MessageBuilder().set_role(RoleType.System) - system_msg.add_text_content(self._chat_system_prompt) - messages.append(system_msg.build()) - - # 然后添加对话历史 - for msg in chat_history: - llm_message = to_llm_message(msg) - if llm_message is not None: - messages.append(llm_message) - - return messages - - # 调用 LLM(使用带消息的接口) - # 合并内置工具和额外工具(将 ToolOption 对象转换为 dict) - all_tools = [self._tool_option_to_dict(t) for t in get_builtin_tools()] + ( - self._extra_tools if self._extra_tools else [] - ) - - # 打印消息列表 - built_messages = message_factory(None) - - ordered_panels = [self._render_message_panel(msg, index + 1) for index, msg in enumerate(built_messages)] - - if config.SHOW_THINKING and ordered_panels: - console.print( - Panel( - Group(*ordered_panels), - title="MaiSaka LLM Request - chat_loop_step", - border_style="cyan", - padding=(0, 1), - ) - ) - - - response, (reasoning, model, tool_calls) = await self._llm_chat.generate_response_with_message_async( - message_factory=message_factory, - tools=all_tools if all_tools else None, - temperature=self._temperature, - max_tokens=self._max_tokens, - ) - raw_message = build_message( - role=RoleType.Assistant.value, - content=response or "", - source="assistant", - tool_calls=tool_calls or None, - ) - - return ChatResponse( - content=response, - tool_calls=tool_calls or [], - raw_message=raw_message, - ) - - def _filter_for_api(self, chat_history: list[MaiMessage]) -> str: - """过滤对话历史为 API 格式""" - parts = [] - for msg in chat_history: - role = get_message_role(msg) - content = get_message_text(msg) - - # 跳过内部字段 - if get_message_kind(msg) == "perception" or role == RoleType.Tool.value: - continue - - if role == RoleType.System.value: - parts.append(f"System: {content}") - elif role == RoleType.User.value: - parts.append(f"User: {content}") - elif role == RoleType.Assistant.value: - # 处理工具调用 - tool_calls = get_tool_calls(msg) - if tool_calls: - tool_desc = ", ".join([tc.func_name for tc in tool_calls if tc.func_name]) - parts.append(f"Assistant (called tools: {tool_desc})") - else: - parts.append(f"Assistant: {content}") - - return "\n\n".join(parts) - - def build_chat_context(self, user_text: str) -> list[MaiMessage]: - """构建对话上下文""" - return [ - build_message( - role=RoleType.User.value, - content=format_speaker_content(config.USER_NAME, user_text), - source="user", - ) - ] - - # ──────── 分析模块(使用 utils 模型) ──────── - - async def analyze_emotion(self, chat_history: list[MaiMessage]) -> str: - """情绪分析 - 使用 utils 模型""" - filtered = [m for m in chat_history if get_message_kind(m) != "perception"] - recent = filtered[-10:] if len(filtered) > 10 else filtered - - # 使用加载的系统提示词 - system_prompt = self._emotion_prompt or "请分析以下对话中用户的情绪状态和言语态度:" - - prompt_parts = [f"{system_prompt}\n\n【对话内容】\n"] - for msg in recent: - role = get_message_role(msg) - content = get_message_text(msg) - if role == RoleType.User.value: - prompt_parts.append(f"{config.USER_NAME}: {content}") - elif role == RoleType.Assistant.value: - prompt_parts.append(f"助手: {content}") - - prompt = "\n".join(prompt_parts) - - if config.SHOW_THINKING: - print("\n" + "=" * 60) - print("MaiSaka LLM Request - analyze_emotion:") - print(f" {prompt}") - print("=" * 60 + "\n") - - try: - response, _ = await self._llm_utils.generate_response_async( - prompt=prompt, - temperature=0.3, - max_tokens=512, - ) - - return response - except Exception as e: - logger.error(f"情绪分析 LLM 调用出错: {e}") - return "" - - async def analyze_cognition(self, chat_history: list[MaiMessage]) -> str: - """认知分析 - 使用 utils 模型""" - filtered = [m for m in chat_history if get_message_kind(m) != "perception"] - recent = filtered[-10:] if len(filtered) > 10 else filtered - - # 使用加载的系统提示词 - system_prompt = self._cognition_prompt or "请分析以下对话中用户的意图、认知状态和目的:" - - prompt_parts = [f"{system_prompt}\n\n【对话内容】\n"] - for msg in recent: - role = get_message_role(msg) - content = get_message_text(msg) - if role == RoleType.User.value: - prompt_parts.append(f"{config.USER_NAME}: {content}") - elif role == RoleType.Assistant.value: - prompt_parts.append(f"助手: {content}") - - prompt = "\n".join(prompt_parts) - - if config.SHOW_THINKING and config.SHOW_ANALYZE_COGNITION_PROMPT: - print("\n" + "=" * 60) - print("MaiSaka LLM Request - analyze_cognition:") - print(f" {prompt}") - print("=" * 60 + "\n") - - try: - response, _ = await self._llm_utils.generate_response_async( - prompt=prompt, - temperature=0.3, - max_tokens=512, - ) - - return response - except Exception as e: - logger.error(f"认知分析 LLM 调用出错: {e}") - return "" - - async def analyze_timing(self, chat_history: list[MaiMessage], timing_info: str) -> str: - """时间分析 - 使用 utils 模型""" - filtered = [ - m - for m in chat_history - if get_message_kind(m) != "perception" and get_message_role(m) != RoleType.System.value - ] - - # 使用加载的系统提示词 - system_prompt = self._timing_prompt or "请分析以下对话的时间节奏和用户状态:" - - prompt_parts = [f"{system_prompt}\n\n【系统时间戳信息】\n{timing_info}\n\n【当前对话记录】\n"] - for msg in filtered: - role = get_message_role(msg) - content = get_message_text(msg) - if role == RoleType.User.value: - prompt_parts.append(f"{config.USER_NAME}: {content}") - elif role == RoleType.Assistant.value: - prompt_parts.append(f"助手: {content}") - - prompt = "\n".join(prompt_parts) - - if config.SHOW_THINKING and config.SHOW_ANALYZE_TIMING_PROMPT: - print("\n" + "=" * 60) - print("MaiSaka LLM Request - analyze_timing:") - print(f" {prompt}") - print("=" * 60 + "\n") - - try: - response, _ = await self._llm_utils.generate_response_async( - prompt=prompt, - temperature=0.3, - max_tokens=512, - ) - - return response - except Exception as e: - logger.error(f"时间分析 LLM 调用出错: {e}") - return "" - - # ──────── 回复生成(使用 replyer 模型) ──────── - - async def generate_reply(self, reason: str, chat_history: list[MaiMessage]) -> str: - """ - 生成回复 - 使用 replyer 模型 - 可供 Replyer 类直接调用 - """ - from datetime import datetime - from .replyer import format_chat_history - - current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") - - # 格式化对话历史 - filtered_history = [ - msg - for msg in chat_history - if get_message_role(msg) != RoleType.System.value and get_message_kind(msg) != "perception" - ] - formatted_history = format_chat_history(filtered_history) - - # 获取回复提示词 - try: - replyer_prompt = prompt_manager.get_prompt("maidairy_replyer") - system_prompt = await prompt_manager.render_prompt(replyer_prompt) - except Exception: - system_prompt = "你是一个友好的 AI 助手,请根据用户的想法生成自然的回复。" - - user_prompt = ( - f"当前时间:{current_time}\n\n【聊天记录】\n{formatted_history}\n\n【你的想法】\n{reason}\n\n现在,你说:" - ) - - messages = f"System: {system_prompt}\n\nUser: {user_prompt}" - - if config.SHOW_THINKING: - print("\n" + "=" * 60) - print("MaiSaka LLM Request - generate_reply:") - print(f" {messages}") - print("=" * 60 + "\n") - - try: - response, _ = await self._llm_replyer.generate_response_async( - prompt=messages, - temperature=0.8, - max_tokens=512, - ) - return response.strip() if response else "..." - except Exception as e: - logger.error(f"回复生成 LLM 调用出错: {e}") - return "..." - - - - - diff --git a/src/maisaka/mcp_client/config.py b/src/maisaka/mcp_client/config.py deleted file mode 100644 index 742d3218..00000000 --- a/src/maisaka/mcp_client/config.py +++ /dev/null @@ -1,103 +0,0 @@ -""" -MaiSaka - MCP 配置加载与验证 -从 mcp_config.json 读取 MCP 服务器定义,解析为结构化配置对象。 - -配置格式示例: -{ - "mcpServers": { - "filesystem": { - "command": "npx", - "args": ["-y", "@modelcontextprotocol/server-filesystem", "C:/Users"], - "env": {} - }, - "remote-api": { - "url": "http://localhost:8080/sse", - "headers": {"Authorization": "Bearer xxx"} - } - } -} - -- command + args: Stdio 传输(启动子进程) -- url: SSE 传输(连接远程服务器) -""" - -import json -import os -from dataclasses import dataclass, field -from typing import Optional - -from ..config import console - - -@dataclass -class MCPServerConfig: - """单个 MCP 服务器配置。""" - - name: str - - # ── Stdio 传输 ── - command: Optional[str] = None - args: list[str] = field(default_factory=list) - env: Optional[dict[str, str]] = None - - # ── SSE 传输 ── - url: Optional[str] = None - headers: dict[str, str] = field(default_factory=dict) - - @property - def transport_type(self) -> str: - """返回传输类型: 'stdio' / 'sse' / 'unknown'。""" - if self.command: - return "stdio" - if self.url: - return "sse" - return "unknown" - - -def load_mcp_config(config_path: str = "mcp_config.json") -> list[MCPServerConfig]: - """ - 从配置文件加载 MCP 服务器列表。 - - Args: - config_path: 配置文件路径 - - Returns: - 解析后的 MCPServerConfig 列表;文件不存在或为空时返回空列表。 - """ - if not os.path.isfile(config_path): - return [] - - try: - with open(config_path, "r", encoding="utf-8") as f: - data = json.load(f) - except (json.JSONDecodeError, OSError) as e: - console.print(f"[warning]⚠️ 读取 MCP 配置失败: {e}[/warning]") - return [] - - mcp_servers = data.get("mcpServers", {}) - if not isinstance(mcp_servers, dict): - console.print("[warning]⚠️ mcp_config.json 中 mcpServers 格式无效[/warning]") - return [] - - configs: list[MCPServerConfig] = [] - for name, cfg in mcp_servers.items(): - if not isinstance(cfg, dict): - console.print(f"[warning]⚠️ MCP 服务器 '{name}' 配置格式无效,已跳过[/warning]") - continue - - server = MCPServerConfig( - name=name, - command=cfg.get("command"), - args=cfg.get("args", []), - env=cfg.get("env"), - url=cfg.get("url"), - headers=cfg.get("headers", {}), - ) - - if server.transport_type == "unknown": - console.print(f"[warning]⚠️ MCP 服务器 '{name}' 缺少 command 或 url,已跳过[/warning]") - continue - - configs.append(server) - - return configs diff --git a/src/maisaka/mcp_client/connection.py b/src/maisaka/mcp_client/connection.py deleted file mode 100644 index 9f489402..00000000 --- a/src/maisaka/mcp_client/connection.py +++ /dev/null @@ -1,147 +0,0 @@ -""" -MaiSaka - 单个 MCP 服务器连接管理 -封装单个 MCP 服务器的连接生命周期:连接 → 发现工具 → 调用工具 → 断开。 -""" - -from contextlib import AsyncExitStack -from typing import Any, Optional - -from ..config import console -from .config import MCPServerConfig - -# ──────────────────── MCP SDK 可选导入 ──────────────────── -# -# mcp 是可选依赖。如果未安装,MCP_AVAILABLE = False, -# MCPManager.from_config() 会检测到并返回 None,不影响主程序运行。 - -try: - from mcp import ClientSession - - try: - from mcp.client.stdio import StdioServerParameters - except ImportError: - from mcp import StdioServerParameters # type: ignore[attr-defined] - - from mcp.client.stdio import stdio_client - - MCP_AVAILABLE = True -except ImportError: - MCP_AVAILABLE = False - ClientSession = None # type: ignore[assignment,misc] - StdioServerParameters = None # type: ignore[assignment,misc] - stdio_client = None # type: ignore[assignment] - -try: - from mcp.client.sse import sse_client - - SSE_AVAILABLE = True -except ImportError: - SSE_AVAILABLE = False - sse_client = None # type: ignore[assignment] - - -class MCPConnection: - """ - 管理单个 MCP 服务器的连接生命周期。 - - 支持两种传输方式: - - Stdio: 启动子进程,通过 stdin/stdout 通信 - - SSE: 连接远程 HTTP SSE 端点 - """ - - def __init__(self, config: MCPServerConfig): - self.config = config - self.session: Optional[Any] = None # mcp.ClientSession - self.tools: list = [] # mcp Tool objects - self._exit_stack = AsyncExitStack() - - async def connect(self) -> bool: - """ - 连接到 MCP 服务器并发现可用工具。 - - Returns: - True 表示连接成功,False 表示失败。 - """ - if not MCP_AVAILABLE: - console.print("[warning]⚠️ 未安装 mcp SDK,请运行: pip install mcp[/warning]") - return False - - try: - await self._exit_stack.__aenter__() - - if self.config.transport_type == "stdio": - read_stream, write_stream = await self._connect_stdio() - elif self.config.transport_type == "sse": - read_stream, write_stream = await self._connect_sse() - else: - console.print(f"[warning]MCP '{self.config.name}': 未知传输类型[/warning]") - return False - - # 创建并初始化 MCP 会话 - self.session = await self._exit_stack.enter_async_context(ClientSession(read_stream, write_stream)) - await self.session.initialize() - - # 发现工具 - result = await self.session.list_tools() - self.tools = result.tools if hasattr(result, "tools") else [] - - return True - - except Exception as e: - console.print(f"[warning]⚠️ MCP 服务器 '{self.config.name}' 连接失败: {e}[/warning]") - await self.close() - return False - - async def _connect_stdio(self): - """建立 Stdio 传输连接。""" - params = StdioServerParameters( - command=self.config.command, - args=self.config.args, - env=self.config.env, - ) - return await self._exit_stack.enter_async_context(stdio_client(params)) - - async def _connect_sse(self): - """建立 SSE 传输连接。""" - if not SSE_AVAILABLE: - raise ImportError("SSE 传输需要额外依赖,请运行: pip install mcp[sse]") - return await self._exit_stack.enter_async_context(sse_client(url=self.config.url, headers=self.config.headers)) - - async def call_tool(self, tool_name: str, arguments: dict) -> str: - """ - 调用 MCP 工具并返回结果文本。 - - Args: - tool_name: 工具名称 - arguments: 工具参数字典 - - Returns: - 工具执行结果的文本表示。 - """ - if not self.session: - return f"MCP 服务器 '{self.config.name}' 未连接" - - result = await self.session.call_tool(tool_name, arguments=arguments) - - # 将结果内容转换为文本 - parts: list[str] = [] - for content in result.content: - if hasattr(content, "text"): - parts.append(content.text) - elif hasattr(content, "data"): - # 二进制/图片内容,展示类型信息 - content_type = getattr(content, "mimeType", "unknown") - parts.append(f"[{content_type} 二进制内容]") - elif hasattr(content, "type"): - parts.append(f"[{content.type} 内容]") - - return "\n".join(parts) if parts else "工具执行成功(无输出)" - - async def close(self): - """关闭连接并释放资源。""" - try: - await self._exit_stack.aclose() - except Exception: - pass - self.session = None - self.tools = [] diff --git a/src/maisaka/mcp_client/manager.py b/src/maisaka/mcp_client/manager.py deleted file mode 100644 index 5409a39d..00000000 --- a/src/maisaka/mcp_client/manager.py +++ /dev/null @@ -1,212 +0,0 @@ -""" -MaiSaka - MCP 管理器 -管理所有 MCP 服务器连接,提供统一的工具发现与调用接口。 -""" - -from typing import Optional - -from ..config import console -from .config import MCPServerConfig, load_mcp_config -from .connection import MCPConnection, MCP_AVAILABLE - -# 内置工具名称集合 —— MCP 工具不允许与这些名称冲突 -BUILTIN_TOOL_NAMES = frozenset( - { - "wait", - "stop", - "create_table", - "list_tables", - "view_table", - } -) - - -class MCPManager: - """ - MCP 服务器连接管理器。 - - 职责: - - 根据配置文件连接所有 MCP 服务器 - - 将 MCP 工具转换为 OpenAI function calling 格式 - - 路由工具调用到正确的 MCP 服务器 - - 统一管理连接生命周期 - """ - - def __init__(self): - self._connections: dict[str, MCPConnection] = {} # server_name → connection - self._tool_to_server: dict[str, str] = {} # tool_name → server_name - - # ──────── 工厂方法 ──────── - - @classmethod - async def from_config( - cls, - config_path: str = "mcp_config.json", - ) -> Optional["MCPManager"]: - """ - 从配置文件创建并初始化 MCPManager。 - - Args: - config_path: mcp_config.json 文件路径 - - Returns: - 初始化完成的 MCPManager;无配置或全部连接失败时返回 None。 - """ - configs = load_mcp_config(config_path) - if not configs: - return None - - if not MCP_AVAILABLE: - console.print("[warning]⚠️ 发现 MCP 配置但未安装 mcp SDK,请运行: pip install mcp[/warning]") - return None - - manager = cls() - await manager._connect_all(configs) - - if not manager._connections: - console.print("[warning]⚠️ 所有 MCP 服务器连接失败[/warning]") - return None - - return manager - - # ──────── 连接管理 ──────── - - async def _connect_all(self, configs: list[MCPServerConfig]): - """连接所有配置的 MCP 服务器,跳过失败的连接。""" - for cfg in configs: - conn = MCPConnection(cfg) - success = await conn.connect() - if not success: - continue - - self._connections[cfg.name] = conn - - # 注册工具,检查冲突 - registered = 0 - for tool in conn.tools: - tool_name = tool.name - - if tool_name in BUILTIN_TOOL_NAMES: - console.print( - f"[warning]⚠️ MCP 工具 '{tool_name}' (来自 {cfg.name}) 与内置工具冲突,已跳过[/warning]" - ) - continue - - if tool_name in self._tool_to_server: - existing_server = self._tool_to_server[tool_name] - console.print( - f"[warning]⚠️ MCP 工具 '{tool_name}' " - f"(来自 {cfg.name}) 与 {existing_server} 冲突,已跳过[/warning]" - ) - continue - - self._tool_to_server[tool_name] = cfg.name - registered += 1 - - console.print( - f"[success]✓ MCP 服务器 '{cfg.name}' 已连接[/success] [muted]({registered} 个工具已注册)[/muted]" - ) - - # ──────── 工具发现 ──────── - - def get_openai_tools(self) -> list[dict]: - """ - 将所有已注册的 MCP 工具转换为 OpenAI function calling 格式。 - - Returns: - OpenAI tools 格式的工具定义列表。 - """ - tools: list[dict] = [] - - for server_name, conn in self._connections.items(): - for tool in conn.tools: - # 只包含成功注册的工具 - if tool.name not in self._tool_to_server: - continue - if self._tool_to_server[tool.name] != server_name: - continue - - # MCP inputSchema → OpenAI parameters - parameters = ( - dict(tool.inputSchema) - if hasattr(tool, "inputSchema") and tool.inputSchema - else {"type": "object", "properties": {}} - ) - # 移除 $schema 字段(部分 MCP 服务器会带上,OpenAI 不接受) - parameters.pop("$schema", None) - - tools.append( - { - "type": "function", - "function": { - "name": tool.name, - "description": (tool.description or f"MCP tool from {server_name}"), - "parameters": parameters, - }, - } - ) - - return tools - - # ──────── 工具调用 ──────── - - def is_mcp_tool(self, tool_name: str) -> bool: - """判断工具名是否为已注册的 MCP 工具。""" - return tool_name in self._tool_to_server - - async def call_tool(self, tool_name: str, arguments: dict) -> str: - """ - 调用指定的 MCP 工具。 - - 自动路由到正确的 MCP 服务器。 - - Args: - tool_name: 工具名称 - arguments: 工具参数 - - Returns: - 工具执行结果文本。 - """ - server_name = self._tool_to_server.get(tool_name) - if not server_name or server_name not in self._connections: - return f"MCP 工具 '{tool_name}' 未找到" - - conn = self._connections[server_name] - try: - return await conn.call_tool(tool_name, arguments) - except Exception as e: - return f"MCP 工具 '{tool_name}' 执行失败: {e}" - - # ──────── 信息展示 ──────── - - def get_tool_summary(self) -> str: - """获取所有已注册 MCP 工具的摘要信息。""" - parts: list[str] = [] - for server_name, conn in self._connections.items(): - tool_names = [ - t.name - for t in conn.tools - if t.name in self._tool_to_server and self._tool_to_server[t.name] == server_name - ] - if tool_names: - parts.append(f" • {server_name}: {', '.join(tool_names)}") - return "\n".join(parts) - - @property - def server_count(self) -> int: - """已连接的 MCP 服务器数量。""" - return len(self._connections) - - @property - def tool_count(self) -> int: - """已注册的 MCP 工具总数。""" - return len(self._tool_to_server) - - # ──────── 生命周期 ──────── - - async def close(self): - """关闭所有 MCP 服务器连接。""" - for conn in self._connections.values(): - await conn.close() - self._connections.clear() - self._tool_to_server.clear() diff --git a/src/maisaka/mcp_config.json b/src/maisaka/mcp_config.json deleted file mode 100644 index 959b4eed..00000000 --- a/src/maisaka/mcp_config.json +++ /dev/null @@ -1,13 +0,0 @@ -{ - "mcpServers": { - "tavily": { - "command": "npx", - "args": [ - "-y", - "mcp-remote", - "https://mcp.tavily.com/mcp/?tavilyApiKey=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" - ], - "env": {} - } - } -} \ No newline at end of file diff --git a/src/maisaka/mcp_config.json.template b/src/maisaka/mcp_config.json.template deleted file mode 100644 index 89207601..00000000 --- a/src/maisaka/mcp_config.json.template +++ /dev/null @@ -1,13 +0,0 @@ -{ - "mcpServers": { - "tavily": { - "command": "npx", - "args": [ - "-y", - "mcp-remote", - "https://mcp.tavily.com/mcp/?tavilyApiKey=YOUR_API_KEY_HERE" - ], - "env": {} - } - } -} diff --git a/src/maisaka/message_adapter.py b/src/maisaka/message_adapter.py index caa9d6dd..b52d1baa 100644 --- a/src/maisaka/message_adapter.py +++ b/src/maisaka/message_adapter.py @@ -1,181 +1,69 @@ -""" -MaiSaka message adapters built on top of the main project's MaiMessage model. -""" +"""Maisaka 文本与消息片段适配工具。""" +from copy import deepcopy from datetime import datetime -import re from typing import Optional -from uuid import uuid4 +import re -from src.common.data_models.mai_message_data_model import MaiMessage, MessageInfo, UserInfo -from src.common.data_models.message_component_data_model import MessageSequence -from src.config.config import global_config -from src.llm_models.payload_content.message import Message, MessageBuilder, RoleType -from src.llm_models.payload_content.tool_option import ToolCall +from src.common.data_models.message_component_data_model import EmojiComponent, ImageComponent, MessageSequence, TextComponent -from .config import USER_NAME - -MAISAKA_PLATFORM = "maisaka" -MAISAKA_SESSION_ID = "maisaka_cli" -MESSAGE_KIND_KEY = "maisaka_message_kind" -SOURCE_KEY = "maisaka_source" -LLM_ROLE_KEY = "maisaka_llm_role" -TOOL_CALL_ID_KEY = "maisaka_tool_call_id" -TOOL_CALLS_KEY = "maisaka_tool_calls" -SPEAKER_PREFIX_PATTERN = re.compile(r"^\[(?P[^\]]+)\](?P.*)$", re.DOTALL) +SPEAKER_PREFIX_PATTERN = re.compile( + r"^(?:(?P\d{2}:\d{2}:\d{2}))?(?:\[msg_id:(?P[^\]]+)\])?\[(?P[^\]]+)\](?P.*)$", + re.DOTALL, +) -def _build_user_info_for_role(role: str) -> UserInfo: - if role == RoleType.User.value: - return UserInfo(user_id="maisaka_user", user_nickname=USER_NAME, user_cardname=None) - if role == RoleType.Tool.value: - return UserInfo(user_id="maisaka_tool", user_nickname="tool", user_cardname=None) - return UserInfo( - user_id="maisaka_assistant", - user_nickname=global_config.bot.nickname.strip() or "MaiSaka", - user_cardname=None, - ) - - -def _serialize_tool_call(tool_call: ToolCall) -> dict: - return { - "call_id": tool_call.call_id, - "func_name": tool_call.func_name, - "args": tool_call.args or {}, - } - - -def _deserialize_tool_call(data: dict) -> ToolCall: - return ToolCall( - call_id=str(data.get("call_id", "")), - func_name=str(data.get("func_name", "")), - args=data.get("args", {}) or {}, - ) - - -def build_message( - role: str, +def format_speaker_content( + speaker_name: str, content: str, - *, - message_kind: str = "normal", - source: Optional[str] = None, - tool_call_id: Optional[str] = None, - tool_calls: Optional[list[ToolCall]] = None, timestamp: Optional[datetime] = None, message_id: Optional[str] = None, -) -> MaiMessage: - """Build a MaiMessage for the Maisaka session history.""" - resolved_timestamp = timestamp or datetime.now() - resolved_role = role.value if isinstance(role, RoleType) else role - message = MaiMessage( - message_id=message_id or f"maisaka_{uuid4().hex}", - timestamp=resolved_timestamp, - platform=MAISAKA_PLATFORM, - ) - message.message_info = MessageInfo( - user_info=_build_user_info_for_role(resolved_role), - group_info=None, - additional_config={ - LLM_ROLE_KEY: resolved_role, - MESSAGE_KIND_KEY: message_kind, - SOURCE_KEY: source or resolved_role, - TOOL_CALL_ID_KEY: tool_call_id, - TOOL_CALLS_KEY: [_serialize_tool_call(tool_call) for tool_call in (tool_calls or [])], - }, - ) - message.session_id = MAISAKA_SESSION_ID - message.raw_message = MessageSequence([]) - if content: - message.raw_message.text(content) - message.processed_plain_text = content - message.display_message = content - return message - - -def format_speaker_content(speaker_name: str, content: str) -> str: - """Format visible conversation content with an explicit speaker label.""" - return f"[{speaker_name}]{content}" +) -> str: + """将可见文本格式化为带说话人前缀的样式。""" + time_prefix = timestamp.strftime("%H:%M:%S") if timestamp is not None else "" + message_id_prefix = f"[msg_id:{message_id}]" if message_id else "" + return f"{time_prefix}{message_id_prefix}[{speaker_name}]{content}" def parse_speaker_content(content: str) -> tuple[Optional[str], str]: - """Parse content formatted as [speaker]message.""" + """解析形如 [speaker]message 的可见文本。""" match = SPEAKER_PREFIX_PATTERN.match(content or "") if not match: return None, content or "" return match.group("speaker"), match.group("content") -def get_message_text(message: MaiMessage) -> str: - if message.processed_plain_text is not None: - return message.processed_plain_text - if message.display_message is not None: - return message.display_message +def clone_message_sequence(message_sequence: MessageSequence) -> MessageSequence: + """复制消息片段序列。""" + return MessageSequence([deepcopy(component) for component in message_sequence.components]) + +def build_visible_text_from_sequence(message_sequence: MessageSequence) -> str: + """从消息片段序列提取可见文本。""" parts: list[str] = [] - for component in message.raw_message.components: - text = getattr(component, "text", None) - if isinstance(text, str): - parts.append(text) + for component in message_sequence.components: + if isinstance(component, TextComponent): + match = SPEAKER_PREFIX_PATTERN.match(component.text or "") + if not match: + parts.append(component.text) + continue + + normalized_parts: list[str] = [] + if match.group("timestamp"): + normalized_parts.append(match.group("timestamp")) + message_id = match.group("message_id") + if message_id: + normalized_parts.append(f"[msg_id:{message_id}]") + normalized_parts.append(f"[{match.group('speaker')}]") + normalized_parts.append(match.group("content")) + parts.append("".join(normalized_parts)) + continue + + if isinstance(component, EmojiComponent): + parts.append("[表情包]") + continue + + if isinstance(component, ImageComponent): + parts.append("[图片]") + return "".join(parts) - - -def get_message_role(message: MaiMessage) -> str: - return str(message.message_info.additional_config.get(LLM_ROLE_KEY, RoleType.User.value)) - - -def get_message_kind(message: MaiMessage) -> str: - return str(message.message_info.additional_config.get(MESSAGE_KIND_KEY, "normal")) - - -def get_message_source(message: MaiMessage) -> str: - return str(message.message_info.additional_config.get(SOURCE_KEY, get_message_role(message))) - - -def is_perception_message(message: MaiMessage) -> bool: - return get_message_kind(message) == "perception" - - -def get_tool_call_id(message: MaiMessage) -> Optional[str]: - value = message.message_info.additional_config.get(TOOL_CALL_ID_KEY) - return str(value) if value else None - - -def get_tool_calls(message: MaiMessage) -> list[ToolCall]: - raw_tool_calls = message.message_info.additional_config.get(TOOL_CALLS_KEY, []) - if not isinstance(raw_tool_calls, list): - return [] - return [_deserialize_tool_call(item) for item in raw_tool_calls if isinstance(item, dict)] - - -def remove_last_perception(messages: list[MaiMessage]) -> None: - for index in range(len(messages) - 1, -1, -1): - if is_perception_message(messages[index]): - messages.pop(index) - break - - -def to_llm_message(message: MaiMessage) -> Optional[Message]: - role = get_message_role(message) - content = get_message_text(message) - tool_call_id = get_tool_call_id(message) - tool_calls = get_tool_calls(message) - - if role == RoleType.System.value: - role_type = RoleType.System - elif role == RoleType.User.value: - role_type = RoleType.User - elif role == RoleType.Assistant.value: - role_type = RoleType.Assistant - elif role == RoleType.Tool.value: - role_type = RoleType.Tool - else: - return None - - builder = MessageBuilder().set_role(role_type) - if role_type == RoleType.Assistant and tool_calls: - builder.set_tool_calls(tool_calls) - if role_type == RoleType.Tool and tool_call_id: - builder.add_tool_call(tool_call_id) - if content: - builder.add_text_content(content) - return builder.build() diff --git a/src/maisaka/reasoning_engine.py b/src/maisaka/reasoning_engine.py new file mode 100644 index 00000000..dd806e4b --- /dev/null +++ b/src/maisaka/reasoning_engine.py @@ -0,0 +1,1387 @@ +"""Maisaka 推理引擎。""" + +from datetime import datetime +from typing import TYPE_CHECKING, Any, Optional, cast + +import asyncio +import difflib +import json +import time +import traceback + +from sqlmodel import col, select + +from src.chat.heart_flow.heartFC_utils import CycleDetail +from src.chat.message_receive.message import SessionMessage +from src.chat.replyer.replyer_manager import replyer_manager +from src.chat.utils.utils import process_llm_response +from src.common.data_models.message_component_data_model import MessageSequence, TextComponent +from src.common.database.database import get_db_session +from src.common.database.database_model import PersonInfo +from src.common.logger import get_logger +from src.config.config import global_config +from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolSpec +from src.know_u.knowledge_store import get_knowledge_store +from src.learners.jargon_explainer import search_jargon +from src.llm_models.exceptions import ReqAbortException +from src.llm_models.payload_content.tool_option import ToolCall +from src.services import database_service as database_api, send_service + +from .context_messages import ( + AssistantMessage, + LLMContextMessage, + SessionBackedMessage, + ToolResultMessage, +) +from .message_adapter import ( + build_visible_text_from_sequence, + clone_message_sequence, + format_speaker_content, +) + +if TYPE_CHECKING: + from .runtime import MaisakaHeartFlowChatting + from .tool_provider import BuiltinToolHandler + +logger = get_logger("maisaka_reasoning_engine") + + +class MaisakaReasoningEngine: + """负责内部思考、推理与工具执行。""" + + def __init__(self, runtime: "MaisakaHeartFlowChatting") -> None: + self._runtime = runtime + self._last_reasoning_content: str = "" + + def build_builtin_tool_handlers(self) -> dict[str, "BuiltinToolHandler"]: + """构造 Maisaka 内置工具处理器映射。 + + Returns: + dict[str, BuiltinToolHandler]: 工具名到处理器的映射。 + """ + + return { + "reply": self._invoke_reply_tool, + "no_reply": self._invoke_no_reply_tool, + "query_jargon": self._invoke_query_jargon_tool, + "query_person_info": self._invoke_query_person_info_tool, + "wait": self._invoke_wait_tool, + "stop": self._invoke_stop_tool, + "send_emoji": self._invoke_send_emoji_tool, + } + + async def run_loop(self) -> None: + """独立消费消息批次,并执行对应的内部思考轮次。""" + try: + while self._runtime._running: + cached_messages = await self._runtime._internal_turn_queue.get() + timeout_triggered = cached_messages is None + if not timeout_triggered and not cached_messages: + self._runtime._internal_turn_queue.task_done() + continue + + self._runtime._agent_state = self._runtime._STATE_RUNNING + if cached_messages: + self._append_wait_interrupted_message_if_needed() + await self._ingest_messages(cached_messages) + anchor_message = cached_messages[-1] + else: + anchor_message = self._get_timeout_anchor_message() + if anchor_message is None: + logger.warning( + f"{self._runtime.log_prefix} 等待超时后缺少可复用的锚点消息,跳过本轮继续思考" + ) + self._runtime._internal_turn_queue.task_done() + continue + logger.info(f"{self._runtime.log_prefix} 等待超时后开始新一轮思考") + self._runtime._chat_history.append(self._build_wait_timeout_message()) + self._trim_chat_history() + try: + for round_index in range(self._runtime._max_internal_rounds): + cycle_detail = self._start_cycle() + self._runtime._log_cycle_started(cycle_detail, round_index) + planner_started_at = time.time() + try: + logger.info( + f"{self._runtime.log_prefix} 规划器开始执行: " + f"回合={round_index + 1} " + f"历史消息数={len(self._runtime._chat_history)} " + f"开始时间={planner_started_at:.3f}" + ) + interrupt_flag = asyncio.Event() + self._runtime._planner_interrupt_flag = interrupt_flag + self._runtime._chat_loop_service.set_interrupt_flag(interrupt_flag) + try: + response = await self._runtime._chat_loop_service.chat_loop_step(self._runtime._chat_history) + finally: + if self._runtime._planner_interrupt_flag is interrupt_flag: + self._runtime._planner_interrupt_flag = None + self._runtime._chat_loop_service.set_interrupt_flag(None) + cycle_detail.time_records["planner"] = time.time() - planner_started_at + logger.info( + f"{self._runtime.log_prefix} 规划器执行完成: " + f"回合={round_index + 1} " + f"耗时={cycle_detail.time_records['planner']:.3f} 秒" + ) + + reasoning_content = response.content or "" + if self._should_replace_reasoning(reasoning_content): + response.content = "让我根据新情况重新思考:" + response.raw_message.content = "让我根据新情况重新思考:" + logger.info(f"{self._runtime.log_prefix} 当前思考与上一轮过于相似,已替换为重新思考提示") + + self._last_reasoning_content = reasoning_content + self._runtime._chat_history.append(response.raw_message) + + if response.tool_calls: + tool_started_at = time.time() + should_pause = await self._handle_tool_calls( + response.tool_calls, + response.content or "", + anchor_message, + ) + cycle_detail.time_records["tool_calls"] = time.time() - tool_started_at + if should_pause: + break + continue + + if response.content: + continue + + break + except ReqAbortException: + interrupted_at = time.time() + logger.info( + f"{self._runtime.log_prefix} 规划器打断成功: " + f"回合={round_index + 1} " + f"开始时间={planner_started_at:.3f} " + f"打断时间={interrupted_at:.3f} " + f"耗时={interrupted_at - planner_started_at:.3f} 秒" + ) + break + finally: + self._end_cycle(cycle_detail) + finally: + if self._runtime._agent_state == self._runtime._STATE_RUNNING: + self._runtime._agent_state = self._runtime._STATE_STOP + self._runtime._internal_turn_queue.task_done() + except asyncio.CancelledError: + self._runtime._log_internal_loop_cancelled() + raise + except Exception: + logger.exception(f"{self._runtime.log_prefix} Maisaka 内部循环发生异常") + logger.error(traceback.format_exc()) + raise + + def _get_timeout_anchor_message(self) -> Optional[SessionMessage]: + """在 wait 超时后复用最近一条真实用户消息作为锚点。""" + if self._runtime.message_cache: + return self._runtime.message_cache[-1] + return None + + def _build_wait_timeout_message(self) -> ToolResultMessage: + """构造 wait 超时后的工具结果消息。""" + tool_call_id = self._runtime._pending_wait_tool_call_id or "wait_timeout" + self._runtime._pending_wait_tool_call_id = None + return ToolResultMessage( + content="等待已超时,期间没有收到新的用户输入。请基于现有上下文继续下一轮思考。", + timestamp=datetime.now(), + tool_call_id=tool_call_id, + tool_name="wait", + ) + + def _append_wait_interrupted_message_if_needed(self) -> None: + """如果 wait 被新消息打断,则补一条对应的工具结果消息。""" + tool_call_id = self._runtime._pending_wait_tool_call_id + if not tool_call_id: + return + + self._runtime._pending_wait_tool_call_id = None + self._runtime._chat_history.append( + ToolResultMessage( + content="等待过程被新的用户输入打断,已继续处理最新消息。", + timestamp=datetime.now(), + tool_call_id=tool_call_id, + tool_name="wait", + ) + ) + + async def _ingest_messages(self, messages: list[SessionMessage]) -> None: + """处理传入消息列表,将其转换为历史消息并加入聊天历史缓存。""" + for message in messages: + # 构建用户消息序列 + user_sequence, visible_text = await self._build_message_sequence(message) + if not user_sequence.components: + continue + + history_message = SessionBackedMessage.from_session_message( + message, + raw_message=user_sequence, + visible_text=visible_text, + source_kind="user", + ) + self._insert_chat_history_message(history_message) + self._trim_chat_history() + + async def _build_message_sequence(self, message: SessionMessage) -> tuple[MessageSequence, str]: + message_sequence = MessageSequence([]) + planner_prefix = self._build_planner_user_prefix(message) + + appended_component = False + if global_config.maisaka.direct_image_input: + source_sequence = getattr(message, "maisaka_original_raw_message", message.raw_message) + else: + source_sequence = message.raw_message + + planner_components = clone_message_sequence(source_sequence).components + if planner_components and isinstance(planner_components[0], TextComponent): + planner_components[0].text = planner_prefix + planner_components[0].text + else: + planner_components.insert(0, TextComponent(planner_prefix)) + + for component in planner_components: + message_sequence.components.append(component) + appended_component = True + + legacy_visible_text = self._build_legacy_visible_text(message, source_sequence) + if not appended_component: + if not message.processed_plain_text: + await message.process() + content = (message.processed_plain_text or "").strip() + if content: + message_sequence.text(planner_prefix + content) + legacy_visible_text = self._build_legacy_visible_text_from_text(message, content) + + return message_sequence, legacy_visible_text + + @staticmethod + def _build_planner_user_prefix(message: SessionMessage) -> str: + user_info = message.message_info.user_info + timestamp_text = message.timestamp.strftime("%H:%M:%S") + user_name = user_info.user_nickname or user_info.user_id + group_card = user_info.user_cardname or "" + message_id = message.message_id or "" + return ( + f"[时间]{timestamp_text}\n" + f"[用户]{user_name}\n" + f"[用户群昵称]{group_card}\n" + f"[msg_id]{message_id}\n" + "[发言内容]" + ) + + def _build_legacy_visible_text(self, message: SessionMessage, source_sequence: MessageSequence) -> str: + user_info = message.message_info.user_info + speaker_name = user_info.user_cardname or user_info.user_nickname or user_info.user_id + legacy_sequence = MessageSequence([]) + legacy_sequence.text(format_speaker_content(speaker_name, "", message.timestamp, message.message_id)) + for component in clone_message_sequence(source_sequence).components: + legacy_sequence.components.append(component) + return build_visible_text_from_sequence(legacy_sequence).strip() + + def _build_legacy_visible_text_from_text(self, message: SessionMessage, content: str) -> str: + user_info = message.message_info.user_info + speaker_name = user_info.user_cardname or user_info.user_nickname or user_info.user_id + return format_speaker_content(speaker_name, content, message.timestamp, message.message_id).strip() + + def _insert_chat_history_message(self, message: LLMContextMessage) -> int: + """将消息按处理顺序追加到聊天历史末尾。""" + self._runtime._chat_history.append(message) + return len(self._runtime._chat_history) - 1 + + def _start_cycle(self) -> CycleDetail: + """开始一轮 Maisaka 思考循环。""" + self._runtime._cycle_counter += 1 + self._runtime._current_cycle_detail = CycleDetail(cycle_id=self._runtime._cycle_counter) + self._runtime._current_cycle_detail.thinking_id = f"maisaka_tid{round(time.time(), 2)}" + return self._runtime._current_cycle_detail + + def _end_cycle(self, cycle_detail: CycleDetail, only_long_execution: bool = True) -> CycleDetail: + """结束并记录一轮 Maisaka 思考循环。""" + cycle_detail.end_time = time.time() + self._runtime.history_loop.append(cycle_detail) + + timer_strings = [ + f"{name}: {duration:.2f}s" + for name, duration in cycle_detail.time_records.items() + if not only_long_execution or duration >= 0.1 + ] + self._runtime._log_cycle_completed(cycle_detail, timer_strings) + return cycle_detail + + def _trim_chat_history(self) -> None: + """裁剪聊天历史,保证用户消息数量不超过配置限制。""" + conversation_message_count = sum(1 for message in self._runtime._chat_history if message.count_in_context) + if conversation_message_count <= self._runtime._max_context_size: + return + + trimmed_history = list(self._runtime._chat_history) + removed_count = 0 + + while conversation_message_count >= self._runtime._max_context_size and trimmed_history: + removed_message = trimmed_history.pop(0) + removed_count += 1 + if removed_message.count_in_context: + conversation_message_count -= 1 + + self._runtime._chat_history = trimmed_history + self._runtime._log_history_trimmed(removed_count, conversation_message_count) + + @staticmethod + def _calculate_similarity(text1: str, text2: str) -> float: + """计算两个文本之间的相似度。 + + Args: + text1: 第一个文本 + text2: 第二个文本 + + Returns: + float: 相似度值,范围 0-1,1 表示完全相同 + """ + return difflib.SequenceMatcher(None, text1, text2).ratio() + + def _should_replace_reasoning(self, current_content: str) -> bool: + """判断是否需要替换推理内容。 + + 当当前推理内容与上一次相似度大于90%时,返回True。 + + Args: + current_content: 当前的推理内容 + + Returns: + bool: 是否需要替换 + """ + if not self._last_reasoning_content or not current_content: + logger.info( + f"{self._runtime.log_prefix} 跳过思考相似度判定: " + f"上一轮为空={not bool(self._last_reasoning_content)} " + f"当前为空={not bool(current_content)} 相似度=0.00" + ) + return False + + similarity = self._calculate_similarity(current_content, self._last_reasoning_content) + logger.info(f"{self._runtime.log_prefix} 思考内容相似度: {similarity:.2f}") + return similarity > 0.9 + + @staticmethod + def _post_process_reply_text(reply_text: str) -> list[str]: + """沿用旧回复链的文本后处理,执行分段与错别字注入。""" + processed_segments: list[str] = [] + for segment in process_llm_response(reply_text): + normalized_segment = segment.strip() + if normalized_segment: + processed_segments.append(normalized_segment) + + if processed_segments: + return processed_segments + return [reply_text.strip()] + + def _build_tool_invocation(self, tool_call: ToolCall, latest_thought: str) -> ToolInvocation: + """将模型输出的工具调用转换为统一调用对象。 + + Args: + tool_call: 模型返回的工具调用。 + latest_thought: 当前轮的最新思考文本。 + + Returns: + ToolInvocation: 统一工具调用对象。 + """ + + return ToolInvocation( + tool_name=tool_call.func_name, + arguments=dict(tool_call.args or {}), + call_id=tool_call.call_id, + session_id=self._runtime.session_id, + stream_id=self._runtime.session_id, + reasoning=latest_thought, + ) + + def _build_tool_execution_context( + self, + latest_thought: str, + anchor_message: SessionMessage, + ) -> ToolExecutionContext: + """构造统一工具执行上下文。 + + Args: + latest_thought: 当前轮的最新思考文本。 + anchor_message: 当前轮的锚点消息。 + + Returns: + ToolExecutionContext: 统一工具执行上下文。 + """ + + return ToolExecutionContext( + session_id=self._runtime.session_id, + stream_id=self._runtime.session_id, + reasoning=latest_thought, + metadata={"anchor_message": anchor_message}, + ) + + @staticmethod + def _normalize_tool_record_value(value: Any) -> Any: + """将工具记录中的任意值规范化为可序列化结构。 + + Args: + value: 原始值。 + + Returns: + Any: 适合写入 JSON 的规范化结果。 + """ + + if value is None or isinstance(value, (str, int, float, bool)): + return value + if isinstance(value, datetime): + return value.isoformat() + if isinstance(value, dict): + normalized_dict: dict[str, Any] = {} + for key, item in value.items(): + normalized_dict[str(key)] = MaisakaReasoningEngine._normalize_tool_record_value(item) + return normalized_dict + if isinstance(value, (list, tuple, set)): + return [MaisakaReasoningEngine._normalize_tool_record_value(item) for item in value] + if isinstance(value, bytes): + return f"" + if hasattr(value, "model_dump"): + try: + return MaisakaReasoningEngine._normalize_tool_record_value(value.model_dump()) + except Exception: + return str(value) + if hasattr(value, "__dict__"): + try: + return MaisakaReasoningEngine._normalize_tool_record_value(dict(value.__dict__)) + except Exception: + return str(value) + return str(value) + + @staticmethod + def _truncate_tool_record_text(text: str, max_length: int = 180) -> str: + """截断工具记录中的展示文本。 + + Args: + text: 原始文本。 + max_length: 最长保留字符数。 + + Returns: + str: 截断后的文本。 + """ + + normalized_text = text.strip() + if len(normalized_text) <= max_length: + return normalized_text + return f"{normalized_text[: max_length - 1]}…" + + def _build_tool_record_payload( + self, + invocation: ToolInvocation, + result: ToolExecutionResult, + tool_spec: Optional[ToolSpec], + ) -> dict[str, Any]: + """构造统一工具落库数据。 + + Args: + invocation: 工具调用对象。 + result: 工具执行结果。 + tool_spec: 对应的工具声明。 + + Returns: + dict[str, Any]: 可直接写入数据库的工具记录数据。 + """ + + payload: dict[str, Any] = { + "call_id": invocation.call_id, + "session_id": invocation.session_id, + "stream_id": invocation.stream_id, + "arguments": self._normalize_tool_record_value(invocation.arguments), + "success": result.success, + "content": result.content, + "error_message": result.error_message, + "history_content": result.get_history_content(), + "structured_content": self._normalize_tool_record_value(result.structured_content), + "metadata": self._normalize_tool_record_value(result.metadata), + } + if tool_spec is not None: + payload["provider_name"] = tool_spec.provider_name + payload["provider_type"] = tool_spec.provider_type + payload["brief_description"] = tool_spec.brief_description + payload["detailed_description"] = tool_spec.detailed_description + payload["title"] = tool_spec.title + return payload + + def _build_tool_display_prompt( + self, + invocation: ToolInvocation, + result: ToolExecutionResult, + tool_spec: Optional[ToolSpec], + ) -> str: + """构造展示给历史回放与 UI 的工具摘要。 + + Args: + invocation: 工具调用对象。 + result: 工具执行结果。 + tool_spec: 对应的工具声明。 + + Returns: + str: 用于展示的工具摘要文本。 + """ + + custom_display_prompt = result.metadata.get("record_display_prompt") + if isinstance(custom_display_prompt, str) and custom_display_prompt.strip(): + return custom_display_prompt.strip() + + structured_content = ( + result.structured_content + if isinstance(result.structured_content, dict) + else {} + ) + history_content = self._truncate_tool_record_text(result.get_history_content(), max_length=200) + normalized_args = self._normalize_tool_record_value(invocation.arguments) + + if invocation.tool_name == "reply": + target_user_name = str(structured_content.get("target_user_name") or "对方").strip() or "对方" + reply_text = str(structured_content.get("reply_text") or "").strip() + if result.success and reply_text: + return f"你对{target_user_name}进行了回复:{reply_text}" + target_message_id = str(invocation.arguments.get("msg_id") or "").strip() + error_text = self._truncate_tool_record_text(result.error_message or history_content, max_length=120) + return f"你尝试回复消息 {target_message_id or 'unknown'},但失败了:{error_text}" + + if invocation.tool_name == "send_emoji": + description = str(structured_content.get("description") or "").strip() + emotion_list = structured_content.get("emotion") + if isinstance(emotion_list, list): + emotion_text = "、".join(str(item).strip() for item in emotion_list if str(item).strip()) + else: + emotion_text = "" + if result.success and description: + if emotion_text: + return f"你发送了表情包:{description}(情绪:{emotion_text})" + return f"你发送了表情包:{description}" + return f"你尝试发送表情包,但失败了:{self._truncate_tool_record_text(result.error_message or history_content, 120)}" + + if invocation.tool_name == "wait": + wait_seconds = invocation.arguments.get("seconds", 30) + return f"你让当前对话先等待 {wait_seconds} 秒。" + + if invocation.tool_name == "stop": + return "你暂停了当前对话循环,等待新的外部消息。" + + if invocation.tool_name == "query_jargon": + words = invocation.arguments.get("words", []) + if isinstance(words, list): + words_text = "、".join(str(item).strip() for item in words if str(item).strip()) + else: + words_text = "" + if words_text: + return f"你查询了这些黑话或词条:{words_text}" + return "你查询了一次黑话或词条信息。" + + if invocation.tool_name == "query_person_info": + person_name = str(invocation.arguments.get("person_name") or "").strip() + if person_name: + return f"你查询了人物信息:{person_name}" + return "你查询了一次人物信息。" + + brief_description = "" + if tool_spec is not None: + brief_description = tool_spec.brief_description.strip() + + if normalized_args: + arguments_text = self._truncate_tool_record_text( + json.dumps(normalized_args, ensure_ascii=False), + max_length=160, + ) + else: + arguments_text = "{}" + + if result.success: + if brief_description: + return f"{brief_description} 参数={arguments_text};结果:{history_content or '执行成功'}" + return f"你调用了工具 {invocation.tool_name},参数={arguments_text};结果:{history_content or '执行成功'}" + + error_text = self._truncate_tool_record_text(result.error_message or history_content, max_length=160) + return f"你调用了工具 {invocation.tool_name},参数={arguments_text};执行失败:{error_text}" + + async def _store_tool_execution_record( + self, + invocation: ToolInvocation, + result: ToolExecutionResult, + tool_spec: Optional[ToolSpec], + ) -> None: + """将工具执行结果落库到统一工具记录表。 + + Args: + invocation: 工具调用对象。 + result: 工具执行结果。 + tool_spec: 对应的工具声明。 + """ + + if self._runtime.chat_stream is None: + logger.debug( + f"{self._runtime.log_prefix} 当前没有 chat_stream,跳过工具记录存储: " + f"工具={invocation.tool_name}" + ) + return + + builtin_prompt = "" + if tool_spec is not None: + builtin_prompt = tool_spec.build_llm_description() + + try: + await database_api.store_tool_info( + chat_stream=self._runtime.chat_stream, + builtin_prompt=builtin_prompt, + display_prompt=self._build_tool_display_prompt(invocation, result, tool_spec), + tool_id=invocation.call_id, + tool_data=self._build_tool_record_payload(invocation, result, tool_spec), + tool_name=invocation.tool_name, + tool_reasoning=invocation.reasoning, + ) + except Exception: + logger.exception( + f"{self._runtime.log_prefix} 写入工具记录失败: 工具={invocation.tool_name} 调用编号={invocation.call_id}" + ) + + def _append_tool_execution_result(self, tool_call: ToolCall, result: ToolExecutionResult) -> None: + """将统一工具执行结果写回 Maisaka 历史。 + + Args: + tool_call: 原始工具调用对象。 + result: 统一工具执行结果。 + """ + + history_content = result.get_history_content() + if not history_content: + history_content = "工具执行成功。" if result.success else f"工具 {tool_call.func_name} 执行失败。" + + self._runtime._chat_history.append( + ToolResultMessage( + content=history_content, + timestamp=datetime.now(), + tool_call_id=tool_call.call_id, + tool_name=tool_call.func_name, + success=result.success, + ) + ) + + @staticmethod + def _build_tool_call_from_invocation(invocation: ToolInvocation) -> ToolCall: + """将统一工具调用对象恢复为 `ToolCall` 兼容对象。 + + Args: + invocation: 统一工具调用对象。 + + Returns: + ToolCall: 兼容旧内部逻辑的工具调用对象。 + """ + + return ToolCall( + call_id=invocation.call_id or f"{invocation.tool_name}_call", + func_name=invocation.tool_name, + args=dict(invocation.arguments), + ) + + @staticmethod + def _build_tool_success_result( + tool_name: str, + content: str = "", + structured_content: Any = None, + metadata: Optional[dict[str, Any]] = None, + ) -> ToolExecutionResult: + """构造统一工具成功结果。 + + Args: + tool_name: 工具名称。 + content: 结果文本。 + structured_content: 结构化结果。 + metadata: 附加元数据。 + + Returns: + ToolExecutionResult: 统一工具成功结果。 + """ + + return ToolExecutionResult( + tool_name=tool_name, + success=True, + content=content, + structured_content=structured_content, + metadata=dict(metadata or {}), + ) + + @staticmethod + def _build_tool_failure_result( + tool_name: str, + error_message: str, + structured_content: Any = None, + metadata: Optional[dict[str, Any]] = None, + ) -> ToolExecutionResult: + """构造统一工具失败结果。 + + Args: + tool_name: 工具名称。 + error_message: 错误信息。 + structured_content: 结构化结果。 + metadata: 附加元数据。 + + Returns: + ToolExecutionResult: 统一工具失败结果。 + """ + + return ToolExecutionResult( + tool_name=tool_name, + success=False, + error_message=error_message, + structured_content=structured_content, + metadata=dict(metadata or {}), + ) + + async def _invoke_reply_tool( + self, + invocation: ToolInvocation, + context: Optional[ToolExecutionContext] = None, + ) -> ToolExecutionResult: + """执行 reply 内置工具。""" + + latest_thought = context.reasoning if context is not None else invocation.reasoning + return await self._handle_reply(self._build_tool_call_from_invocation(invocation), latest_thought) + + async def _invoke_no_reply_tool( + self, + invocation: ToolInvocation, + context: Optional[ToolExecutionContext] = None, + ) -> ToolExecutionResult: + """执行 no_reply 内置工具。""" + + del context + return self._build_tool_success_result(invocation.tool_name, "本轮未发送可见回复。") + + async def _invoke_query_jargon_tool( + self, + invocation: ToolInvocation, + context: Optional[ToolExecutionContext] = None, + ) -> ToolExecutionResult: + """执行 query_jargon 内置工具。""" + + del context + return await self._handle_query_jargon(self._build_tool_call_from_invocation(invocation)) + + async def _invoke_query_person_info_tool( + self, + invocation: ToolInvocation, + context: Optional[ToolExecutionContext] = None, + ) -> ToolExecutionResult: + """执行 query_person_info 内置工具。""" + + del context + return await self._handle_query_person_info(self._build_tool_call_from_invocation(invocation)) + + async def _invoke_wait_tool( + self, + invocation: ToolInvocation, + context: Optional[ToolExecutionContext] = None, + ) -> ToolExecutionResult: + """执行 wait 内置工具。""" + + del context + seconds = invocation.arguments.get("seconds", 30) + try: + wait_seconds = int(seconds) + except (TypeError, ValueError): + wait_seconds = 30 + wait_seconds = max(0, wait_seconds) + self._runtime._enter_wait_state(seconds=wait_seconds, tool_call_id=invocation.call_id) + return self._build_tool_success_result( + invocation.tool_name, + f"当前对话循环进入等待状态,最长等待 {wait_seconds} 秒。", + metadata={"pause_execution": True}, + ) + + async def _invoke_stop_tool( + self, + invocation: ToolInvocation, + context: Optional[ToolExecutionContext] = None, + ) -> ToolExecutionResult: + """执行 stop 内置工具。""" + + del context + self._runtime._enter_stop_state() + return self._build_tool_success_result( + invocation.tool_name, + "当前对话循环已暂停,等待新消息到来。", + metadata={"pause_execution": True}, + ) + + async def _invoke_send_emoji_tool( + self, + invocation: ToolInvocation, + context: Optional[ToolExecutionContext] = None, + ) -> ToolExecutionResult: + """执行 send_emoji 内置工具。""" + + del context + return await self._handle_send_emoji(self._build_tool_call_from_invocation(invocation)) + + async def _handle_tool_calls( + self, + tool_calls: list[ToolCall], + latest_thought: str, + anchor_message: SessionMessage, + ) -> bool: + """执行一批统一工具调用。 + + Args: + tool_calls: 模型返回的工具调用列表。 + latest_thought: 当前轮的最新思考文本。 + anchor_message: 当前轮的锚点消息。 + + Returns: + bool: 是否需要暂停当前思考循环。 + """ + + if self._runtime._tool_registry is None: + for tool_call in tool_calls: + invocation = self._build_tool_invocation(tool_call, latest_thought) + result = ToolExecutionResult( + tool_name=tool_call.func_name, + success=False, + error_message="统一工具注册表尚未初始化。", + ) + await self._store_tool_execution_record(invocation, result, None) + self._append_tool_execution_result(tool_call, result) + return False + + execution_context = self._build_tool_execution_context(latest_thought, anchor_message) + tool_spec_map = { + tool_spec.name: tool_spec + for tool_spec in await self._runtime._tool_registry.list_tools() + } + for tool_call in tool_calls: + invocation = self._build_tool_invocation(tool_call, latest_thought) + result = await self._runtime._tool_registry.invoke(invocation, execution_context) + await self._store_tool_execution_record( + invocation, + result, + tool_spec_map.get(invocation.tool_name), + ) + self._append_tool_execution_result(tool_call, result) + + if not result.success and tool_call.func_name == "reply": + logger.warning(f"{self._runtime.log_prefix} 回复工具未生成可见消息,将继续下一轮循环") + + if bool(result.metadata.get("pause_execution", False)): + return True + + return False + + async def _handle_query_jargon(self, tool_call: ToolCall) -> ToolExecutionResult: + """查询黑话解释并返回统一工具结果。 + + Args: + tool_call: 当前工具调用。 + + Returns: + ToolExecutionResult: 统一工具执行结果。 + """ + + tool_args = tool_call.args or {} + raw_words = tool_args.get("words") + + if not isinstance(raw_words, list): + return self._build_tool_failure_result( + tool_call.func_name, + "查询黑话工具需要提供 `words` 数组参数。", + ) + + words: list[str] = [] + seen_words: set[str] = set() + for item in raw_words: + if not isinstance(item, str): + continue + word = item.strip() + if not word or word in seen_words: + continue + seen_words.add(word) + words.append(word) + + if not words: + return self._build_tool_failure_result( + tool_call.func_name, + "查询黑话工具至少需要一个非空词条。", + ) + + logger.info(f"{self._runtime.log_prefix} 已触发黑话查询: 词条={words!r}") + + results: list[dict[str, object]] = [] + for word in words: + exact_matches = search_jargon( + keyword=word, + chat_id=self._runtime.session_id, + limit=5, + case_sensitive=False, + fuzzy=False, + ) + matched_entries = exact_matches or search_jargon( + keyword=word, + chat_id=self._runtime.session_id, + limit=5, + case_sensitive=False, + fuzzy=True, + ) + + results.append( + { + "word": word, + "found": bool(matched_entries), + "matches": matched_entries, + } + ) + + logger.info(f"{self._runtime.log_prefix} 黑话查询完成: 结果={results!r}") + return self._build_tool_success_result( + tool_call.func_name, + json.dumps({"results": results}, ensure_ascii=False), + structured_content={"results": results}, + ) + + async def _handle_query_person_info(self, tool_call: ToolCall) -> ToolExecutionResult: + """查询指定人物的档案和相关知识。 + + Args: + tool_call: 当前工具调用。 + + Returns: + ToolExecutionResult: 统一工具执行结果。 + """ + + tool_args = tool_call.args or {} + raw_person_name = tool_args.get("person_name") + raw_limit = tool_args.get("limit", 3) + + if not isinstance(raw_person_name, str): + return self._build_tool_failure_result( + tool_call.func_name, + "查询人物信息工具需要提供字符串类型的 `person_name` 参数。", + ) + + person_name = raw_person_name.strip() + if not person_name: + return self._build_tool_failure_result( + tool_call.func_name, + "查询人物信息工具需要提供非空的 `person_name` 参数。", + ) + + try: + limit = max(1, min(int(raw_limit), 10)) + except (TypeError, ValueError): + limit = 3 + + logger.info( + f"{self._runtime.log_prefix} 已触发人物信息查询: " + f"人物名={person_name!r} 限制条数={limit}" + ) + + persons = self._query_person_records(person_name, limit) + result = { + "query": person_name, + "persons": persons, + "related_knowledge": self._query_related_knowledge(person_name, persons, limit), + } + + logger.info( + f"{self._runtime.log_prefix} 人物信息查询完成: " + f"人物记录数={len(result['persons'])} 相关知识数={len(result['related_knowledge'])}" + ) + return self._build_tool_success_result( + tool_call.func_name, + json.dumps(result, ensure_ascii=False), + structured_content=result, + ) + + def _query_person_records(self, person_name: str, limit: int) -> list[dict[str, Any]]: + """按名称、昵称或用户 ID 查询人物档案。""" + with get_db_session() as session: + records = session.exec( + select(PersonInfo) + .where( + col(PersonInfo.person_name).contains(person_name) + | col(PersonInfo.user_nickname).contains(person_name) + | col(PersonInfo.user_id).contains(person_name) + ) + .order_by(col(PersonInfo.last_known_time).desc(), col(PersonInfo.id).desc()) + .limit(limit) + ).all() + + persons: list[dict[str, Any]] = [] + for record in records: + memory_points: list[str] = [] + if record.memory_points: + try: + parsed_points = json.loads(record.memory_points) + if isinstance(parsed_points, list): + memory_points = [str(point).strip() for point in parsed_points if str(point).strip()] + except (json.JSONDecodeError, TypeError, ValueError): + memory_points = [] + + persons.append( + { + "person_id": record.person_id, + "person_name": record.person_name or "", + "user_nickname": record.user_nickname, + "user_id": record.user_id, + "platform": record.platform, + "name_reason": record.name_reason or "", + "is_known": record.is_known, + "know_counts": record.know_counts, + "memory_points": memory_points[:20], + "last_known_time": ( + record.last_known_time.isoformat() if record.last_known_time is not None else None + ), + } + ) + + return persons + + def _query_related_knowledge( + self, + person_name: str, + persons: list[dict[str, Any]], + limit: int, + ) -> list[dict[str, Any]]: + """从 Maisaka knowledge 中补充检索与该人物相关的条目。""" + store = get_knowledge_store() + knowledge_items: list[dict[str, Any]] = [] + seen_ids: set[str] = set() + + for person in persons: + matched_items = store.get_knowledge_by_user( + platform=str(person.get("platform", "")).strip(), + user_id=str(person.get("user_id", "")).strip(), + user_nickname=str(person.get("user_nickname", "")).strip(), + person_name=str(person.get("person_name", "")).strip(), + limit=max(limit, 5), + ) + for item in matched_items: + item_id = str(item.get("id", "")).strip() + if item_id and item_id in seen_ids: + continue + if item_id: + seen_ids.add(item_id) + knowledge_items.append(item) + + if not knowledge_items: + fallback_items = store.search_knowledge(person_name, limit=max(limit, 5)) + for item in fallback_items: + item_id = str(item.get("id", "")).strip() + if item_id and item_id in seen_ids: + continue + if item_id: + seen_ids.add(item_id) + knowledge_items.append(item) + + results: list[dict[str, Any]] = [] + for item in knowledge_items: + results.append( + { + "id": str(item.get("id", "")).strip(), + "category_id": str(item.get("category_id", "")).strip(), + "category_name": str(item.get("category_name", "")).strip(), + "content": str(item.get("content", "")).strip(), + "metadata": item.get("metadata", {}), + "created_at": item.get("created_at"), + } + ) + return results + + async def _handle_reply( + self, + tool_call: ToolCall, + latest_thought: str, + ) -> ToolExecutionResult: + """执行 reply 工具并生成可见回复。 + + Args: + tool_call: 当前工具调用。 + latest_thought: 当前轮的最新思考文本。 + + Returns: + ToolExecutionResult: 统一工具执行结果。 + """ + + tool_args = tool_call.args or {} + target_message_id = str(tool_args.get("msg_id") or "").strip() + quote_reply = bool(tool_args.get("quote", True)) + raw_unknown_words = tool_args.get("unknown_words") + unknown_words = raw_unknown_words if isinstance(raw_unknown_words, list) else None + if not target_message_id: + return self._build_tool_failure_result( + tool_call.func_name, + "回复工具需要提供有效的 `msg_id` 参数。", + ) + + target_message = self._runtime._source_messages_by_id.get(target_message_id) + if target_message is None: + return self._build_tool_failure_result( + tool_call.func_name, + f"未找到要回复的目标消息,msg_id={target_message_id}", + ) + + logger.info( + f"{self._runtime.log_prefix} 已触发回复工具: " + f"目标消息编号={target_message_id} 引用回复={quote_reply} 最新思考={latest_thought!r}" + ) + logger.info(f"{self._runtime.log_prefix} 正在获取 Maisaka 回复生成器") + try: + replyer = replyer_manager.get_replyer( + chat_stream=self._runtime.chat_stream, + request_type="maisaka_replyer", + replyer_type="maisaka", + ) + except Exception: + logger.exception( + f"{self._runtime.log_prefix} 获取回复生成器时发生异常: " + f"目标消息编号={target_message_id}" + ) + return self._build_tool_failure_result( + tool_call.func_name, + "获取 Maisaka 回复生成器时发生异常。", + ) + + if replyer is None: + logger.error(f"{self._runtime.log_prefix} 获取 Maisaka 回复生成器失败") + return self._build_tool_failure_result( + tool_call.func_name, + "Maisaka 回复生成器当前不可用。", + ) + + from src.chat.replyer.maisaka_generator import MaisakaReplyGenerator + + replyer = cast(MaisakaReplyGenerator, replyer) + logger.info(f"{self._runtime.log_prefix} 已成功获取 Maisaka 回复生成器") + + logger.info(f"{self._runtime.log_prefix} 正在调用回复生成接口: 目标消息编号={target_message_id}") + try: + success, reply_result = await replyer.generate_reply_with_context( + reply_reason=latest_thought, + stream_id=self._runtime.session_id, + reply_message=target_message, + chat_history=self._runtime._chat_history, + unknown_words=unknown_words, + log_reply=False, + ) + except Exception as exc: + import traceback + logger.error( + f"{self._runtime.log_prefix} 回复生成器执行异常: 目标消息编号={target_message_id} " + f"异常类型={type(exc).__name__} 异常信息={str(exc)}\n{traceback.format_exc()}" + ) + return self._build_tool_failure_result( + tool_call.func_name, + "生成可见回复时发生异常。", + ) + + logger.info( + f"{self._runtime.log_prefix} 回复生成完成: " + f"成功={success} 回复文本={reply_result.completion.response_text!r} " + f"错误信息={reply_result.error_message!r}" + ) + reply_text = reply_result.completion.response_text.strip() if success else "" + if not reply_text: + logger.warning( + f"{self._runtime.log_prefix} 回复生成器返回空文本: " + f"目标消息编号={target_message_id} 错误信息={reply_result.error_message!r}" + ) + return self._build_tool_failure_result( + tool_call.func_name, + "生成可见回复失败。", + ) + + reply_segments = self._post_process_reply_text(reply_text) + combined_reply_text = "".join(reply_segments) + logger.info( + f"{self._runtime.log_prefix} 回复后处理完成: " + f"目标消息编号={target_message_id} 分段数={len(reply_segments)} " + f"分段内容={reply_segments!r}" + ) + + logger.info( + f"{self._runtime.log_prefix} 正在发送引导回复: " + f"目标消息编号={target_message_id} 引用回复={quote_reply} 回复分段={reply_segments!r}" + ) + try: + sent = False + for index, segment in enumerate(reply_segments): + sent = await send_service.text_to_stream( + text=segment, + stream_id=self._runtime.session_id, + set_reply=quote_reply if index == 0 else False, + reply_message=target_message if quote_reply and index == 0 else None, + selected_expressions=reply_result.selected_expression_ids or None, + typing=index > 0, + ) + if not sent: + break + except Exception: + logger.exception( + f"{self._runtime.log_prefix} 发送文字消息时发生异常,目标消息编号={target_message_id}" + ) + return self._build_tool_failure_result( + tool_call.func_name, + "发送可见回复时发生异常。", + ) + + logger.info( + f"{self._runtime.log_prefix} 引导回复发送结果: " + f"目标消息编号={target_message_id} 发送成功={sent}" + ) + if not sent: + return self._build_tool_failure_result( + tool_call.func_name, + "可见回复生成成功,但发送失败。", + structured_content={ + "msg_id": target_message_id, + "quote": quote_reply, + "reply_segments": reply_segments, + }, + ) + + target_user_info = target_message.message_info.user_info + target_user_name = ( + target_user_info.user_cardname + or target_user_info.user_nickname + or target_user_info.user_id + ) + + bot_name = global_config.bot.nickname.strip() or "MaiSaka" + reply_timestamp = datetime.now() + planner_prefix = ( + f"[时间]{reply_timestamp.strftime('%H:%M:%S')}\n" + f"[用户]{bot_name}\n" + "[用户群昵称]\n" + "[msg_id]\n" + "[发言内容]" + ) + history_message = SessionBackedMessage( + raw_message=MessageSequence([TextComponent(f"{planner_prefix}{combined_reply_text}")]), + visible_text="", + timestamp=reply_timestamp, + source_kind="guided_reply", + ) + visible_reply_text = format_speaker_content( + bot_name, + combined_reply_text, + reply_timestamp, + ) + history_message.visible_text = visible_reply_text + self._runtime._chat_history.append(history_message) + return self._build_tool_success_result( + tool_call.func_name, + "可见回复已生成并发送。", + structured_content={ + "msg_id": target_message_id, + "quote": quote_reply, + "reply_text": combined_reply_text, + "reply_segments": reply_segments, + "target_user_name": target_user_name, + }, + ) + + async def _handle_send_emoji(self, tool_call: ToolCall) -> ToolExecutionResult: + """处理发送表情包的工具调用。 + + Args: + tool_call: 工具调用对象。 + + Returns: + ToolExecutionResult: 统一工具执行结果。 + """ + from src.chat.emoji_system.emoji_manager import emoji_manager + from src.common.utils.utils_image import ImageUtils + import random + + tool_args = tool_call.args or {} + emotion = str(tool_args.get("emotion") or "").strip() + + logger.info(f"{self._runtime.log_prefix} 已触发表情包发送工具: 情绪={emotion!r}") + + # 获取表情包列表 + if not emoji_manager.emojis: + return self._build_tool_failure_result( + tool_call.func_name, + "当前表情包库中没有可用表情。", + ) + + # 根据情感选择表情包 + selected_emoji = None + if emotion: + # 尝试找到匹配情感的表情包 + matching_emojis = [ + emoji for emoji in emoji_manager.emojis + if emotion.lower() in (e.lower() for e in emoji.emotion) + ] + if matching_emojis: + selected_emoji = random.choice(matching_emojis) + logger.info( + f"{self._runtime.log_prefix} 找到 {len(matching_emojis)} 个匹配情绪 {emotion!r} 的表情包," + f"已选择:{selected_emoji.description}" + ) + + # 如果没有找到匹配的情感表情包,随机选择一个 + if selected_emoji is None: + selected_emoji = random.choice(emoji_manager.emojis) + logger.info( + f"{self._runtime.log_prefix} 没有表情包匹配情绪 {emotion!r}," + f"已随机选择:{selected_emoji.description}" + ) + + # 更新表情包使用次数 + emoji_manager.update_emoji_usage(selected_emoji) + + # 获取表情包的 base64 数据 + try: + emoji_base64 = ImageUtils.image_path_to_base64(str(selected_emoji.full_path)) + if not emoji_base64: + raise ValueError("表情图片转换为 base64 失败") + except Exception as exc: + logger.error( + f"{self._runtime.log_prefix} 表情图片转换为 base64 失败: {exc}" + ) + return self._build_tool_failure_result( + tool_call.func_name, + f"发送表情包失败:{exc}", + ) + + # 发送表情包 + try: + sent = await send_service.emoji_to_stream( + emoji_base64=emoji_base64, + stream_id=self._runtime.session_id, + storage_message=True, + set_reply=False, + reply_message=None, + ) + except Exception as exc: + logger.exception( + f"{self._runtime.log_prefix} 发送表情包时发生异常: {exc}" + ) + return self._build_tool_failure_result( + tool_call.func_name, + f"发送表情包时发生异常:{exc}", + ) + + if sent: + logger.info( + f"{self._runtime.log_prefix} 表情包发送成功: " + f"描述={selected_emoji.description!r} 情绪标签={selected_emoji.emotion}" + ) + return self._build_tool_success_result( + tool_call.func_name, + f"已发送表情包:{selected_emoji.description}(情绪:{', '.join(selected_emoji.emotion)})", + structured_content={ + "description": selected_emoji.description, + "emotion": list(selected_emoji.emotion), + }, + ) + logger.warning(f"{self._runtime.log_prefix} 表情包发送失败") + return self._build_tool_failure_result( + tool_call.func_name, + "发送表情包失败。", + ) diff --git a/src/maisaka/replyer.py b/src/maisaka/replyer.py deleted file mode 100644 index 9483f2ab..00000000 --- a/src/maisaka/replyer.py +++ /dev/null @@ -1,86 +0,0 @@ -""" -MaiSaka reply helper. -""" - -from typing import Optional - -from src.common.data_models.mai_message_data_model import MaiMessage -from src.config.config import global_config - -from .config import USER_NAME -from .llm_service import MaiSakaLLMService -from .message_adapter import get_message_role, get_message_text, is_perception_message, parse_speaker_content - - -def _normalize_content(content: str, limit: int = 500) -> str: - normalized = " ".join((content or "").split()) - if len(normalized) > limit: - return normalized[:limit] + "..." - return normalized - - -def _format_message_time(message: MaiMessage) -> str: - return message.timestamp.strftime("%H:%M:%S") - - -def _extract_visible_assistant_reply(message: MaiMessage) -> str: - if is_perception_message(message): - return "" - return "" - - -def _extract_guided_bot_reply(message: MaiMessage) -> str: - speaker_name, body = parse_speaker_content(get_message_text(message).strip()) - bot_nickname = global_config.bot.nickname.strip() or "Bot" - if speaker_name == bot_nickname: - return _normalize_content(body.strip()) - return "" - - -def format_chat_history(messages: list[MaiMessage]) -> str: - """Format visible chat history for reply generation.""" - bot_nickname = global_config.bot.nickname.strip() or "Bot" - parts: list[str] = [] - - for message in messages: - role = get_message_role(message) - timestamp = _format_message_time(message) - - if role == "user": - guided_reply = _extract_guided_bot_reply(message) - if guided_reply: - parts.append(f"{timestamp} {bot_nickname}(you): {guided_reply}") - continue - - _, content_body = parse_speaker_content(get_message_text(message)) - content = _normalize_content(content_body) - if content: - parts.append(f"{timestamp} {USER_NAME}: {content}") - continue - - if role == "assistant": - visible_reply = _extract_visible_assistant_reply(message) - if visible_reply: - parts.append(f"{timestamp} {bot_nickname}(you): {visible_reply}") - - return "\n".join(parts) - - -class Replyer: - """Generate visible replies from thoughts and context.""" - - def __init__(self, llm_service: Optional[MaiSakaLLMService] = None): - self._llm_service = llm_service - self._enabled = True - - def set_llm_service(self, llm_service: MaiSakaLLMService) -> None: - self._llm_service = llm_service - - def set_enabled(self, enabled: bool) -> None: - self._enabled = enabled - - async def reply(self, reason: str, chat_history: list[MaiMessage]) -> str: - if not self._enabled or not reason or self._llm_service is None: - return "..." - - return await self._llm_service.generate_reply(reason, chat_history) diff --git a/src/maisaka/runtime.py b/src/maisaka/runtime.py new file mode 100644 index 00000000..21c03a06 --- /dev/null +++ b/src/maisaka/runtime.py @@ -0,0 +1,456 @@ +"""Maisaka 非 CLI 运行时。""" + +from typing import Literal, Optional + +import asyncio +import time + +from src.chat.heart_flow.heartFC_utils import CycleDetail +from src.chat.message_receive.chat_manager import BotChatSession, chat_manager +from src.chat.message_receive.message import SessionMessage +from src.common.data_models.mai_message_data_model import GroupInfo, UserInfo +from src.common.logger import get_logger +from src.common.utils.utils_config import ExpressionConfigUtils +from src.config.config import global_config +from src.core.tooling import ToolRegistry +from src.know_u.knowledge import KnowledgeLearner +from src.learners.expression_learner import ExpressionLearner +from src.learners.jargon_miner import JargonMiner +from src.mcp_module import MCPManager +from src.mcp_module.host_llm_bridge import MCPHostLLMBridge +from src.mcp_module.provider import MCPToolProvider +from src.plugin_runtime.tool_provider import PluginToolProvider + +from .chat_loop_service import MaisakaChatLoopService +from .context_messages import LLMContextMessage +from .reasoning_engine import MaisakaReasoningEngine +from .tool_provider import MaisakaBuiltinToolProvider + +logger = get_logger("maisaka_runtime") + + +class MaisakaHeartFlowChatting: + """会话级别的 Maisaka 运行时。""" + + _STATE_RUNNING: Literal["running"] = "running" + _STATE_WAIT: Literal["wait"] = "wait" + _STATE_STOP: Literal["stop"] = "stop" + + def __init__(self, session_id: str): + self.session_id = session_id + chat_stream = chat_manager.get_session_by_session_id(session_id) + if chat_stream is None: + raise ValueError(f"未找到会话 {session_id} 对应的 Maisaka 运行时") + self.chat_stream: BotChatSession = chat_stream + + session_name = chat_manager.get_session_name(session_id) or session_id + self.log_prefix = f"[{session_name}]" + self._chat_loop_service = MaisakaChatLoopService() + self._chat_history: list[LLMContextMessage] = [] + self.history_loop: list[CycleDetail] = [] + + # Keep all original messages for batching and later learning. + self.message_cache: list[SessionMessage] = [] + self._last_processed_index = 0 + self._internal_turn_queue: asyncio.Queue[Optional[list[SessionMessage]]] = asyncio.Queue() + + self._mcp_manager: Optional[MCPManager] = None + self._mcp_host_bridge: Optional[MCPHostLLMBridge] = None + self._current_cycle_detail: Optional[CycleDetail] = None + self._source_messages_by_id: dict[str, SessionMessage] = {} + self._running = False + self._cycle_counter = 0 + self._internal_loop_task: Optional[asyncio.Task] = None + self._loop_task: Optional[asyncio.Task] = None + self._new_message_event = asyncio.Event() + self._max_internal_rounds = global_config.maisaka.max_internal_rounds + self._max_context_size = max(1, int(global_config.chat.max_context_size)) + self._agent_state: Literal["running", "wait", "stop"] = self._STATE_STOP + self._wait_until: Optional[float] = None + self._pending_wait_tool_call_id: Optional[str] = None + self._planner_interrupt_flag: Optional[asyncio.Event] = None + + expr_use, jargon_learn, expr_learn = ExpressionConfigUtils.get_expression_config_for_chat(session_id) + self._enable_expression_use = expr_use + self._enable_expression_learning = expr_learn + self._enable_jargon_learning = jargon_learn + self._min_messages_for_extraction = 10 + self._min_extraction_interval = 30 + self._last_expression_extraction_time = 0.0 + self._last_knowledge_extraction_time = 0.0 + self._expression_learner = ExpressionLearner(session_id) + self._jargon_miner = JargonMiner(session_id, session_name=session_name) + self._knowledge_learner = KnowledgeLearner(session_id) + + self._reasoning_engine = MaisakaReasoningEngine(self) + self._tool_registry = ToolRegistry() + self._register_tool_providers() + + async def start(self) -> None: + """启动运行时主循环。""" + if self._running: + self._ensure_background_tasks_running() + return + + if global_config.mcp.enable: + await self._init_mcp() + + self._running = True + self._ensure_background_tasks_running() + logger.info(f"{self.log_prefix} Maisaka 运行时已启动") + + async def stop(self) -> None: + """停止运行时主循环。""" + if not self._running: + return + + self._running = False + self._new_message_event.set() + while not self._internal_turn_queue.empty(): + _ = self._internal_turn_queue.get_nowait() + + if self._loop_task is not None: + self._loop_task.cancel() + try: + await self._loop_task + except asyncio.CancelledError: + pass + finally: + self._loop_task = None + + if self._internal_loop_task is not None: + self._internal_loop_task.cancel() + try: + await self._internal_loop_task + except asyncio.CancelledError: + pass + finally: + self._internal_loop_task = None + + await self._tool_registry.close() + self._mcp_manager = None + self._mcp_host_bridge = None + + logger.info(f"{self.log_prefix} Maisaka 运行时已停止") + + def adjust_talk_frequency(self, frequency: float) -> None: + """兼容现有管理器接口的占位方法。""" + _ = frequency + + async def register_message(self, message: SessionMessage) -> None: + """缓存一条新消息并唤醒主循环。""" + if self._running: + self._ensure_background_tasks_running() + self.message_cache.append(message) + self._source_messages_by_id[message.message_id] = message + if self._agent_state == self._STATE_RUNNING and self._planner_interrupt_flag is not None: + logger.info( + f"{self.log_prefix} 收到新消息,发起规划器打断; " + f"消息编号={message.message_id} 缓存条数={len(self.message_cache)} " + f"时间戳={time.time():.3f}" + ) + self._planner_interrupt_flag.set() + if self._agent_state in (self._STATE_WAIT, self._STATE_STOP): + self._agent_state = self._STATE_RUNNING + self._new_message_event.set() + + def _ensure_background_tasks_running(self) -> None: + """确保后台任务仍在运行,若崩溃则自动拉起。""" + if not self._running: + return + + if self._internal_loop_task is None or self._internal_loop_task.done(): + if self._internal_loop_task is not None and not self._internal_loop_task.cancelled(): + try: + exc = self._internal_loop_task.exception() + except Exception: + exc = None + if exc is not None: + logger.error(f"{self.log_prefix} 内部循环任务异常退出: {exc}") + self._internal_loop_task = asyncio.create_task(self._reasoning_engine.run_loop()) + logger.warning(f"{self.log_prefix} 已重新拉起 Maisaka 内部循环任务") + + if self._loop_task is None or self._loop_task.done(): + if self._loop_task is not None and not self._loop_task.cancelled(): + try: + exc = self._loop_task.exception() + except Exception: + exc = None + if exc is not None: + logger.error(f"{self.log_prefix} 主循环任务异常退出: {exc}") + self._loop_task = asyncio.create_task(self._main_loop()) + logger.warning(f"{self.log_prefix} 已重新拉起 Maisaka 主循环任务") + + def _register_tool_providers(self) -> None: + """注册 Maisaka 运行时默认启用的工具 Provider。""" + + self._tool_registry.register_provider( + MaisakaBuiltinToolProvider(self._reasoning_engine.build_builtin_tool_handlers()) + ) + self._tool_registry.register_provider(PluginToolProvider()) + self._chat_loop_service.set_tool_registry(self._tool_registry) + + async def _main_loop(self) -> None: + try: + while self._running: + if not self._has_pending_messages(): + if self._agent_state == self._STATE_WAIT: + trigger_reason = await self._wait_for_trigger() + else: + self._new_message_event.clear() + await self._new_message_event.wait() + trigger_reason: Literal["message", "timeout", "stop"] = "message" if self._running else "stop" + else: + trigger_reason = "message" + + if not self._running: + return + if trigger_reason == "stop": + self._agent_state = self._STATE_STOP + continue + + self._new_message_event.clear() + + if trigger_reason == "timeout": + # 等待超时后继续下一轮内部思考,但不要重复注入旧消息。 + logger.info(f"{self.log_prefix} 等待超时后已投递继续思考触发信号") + await self._internal_turn_queue.put(None) + continue + + while self._has_pending_messages(): + cached_messages = self._collect_pending_messages() + if not cached_messages: + break + await self._internal_turn_queue.put(cached_messages) + asyncio.create_task(self._trigger_batch_learning(cached_messages)) + except asyncio.CancelledError: + logger.info(f"{self.log_prefix} Maisaka 运行时主循环已取消") + + def _has_pending_messages(self) -> bool: + return self._last_processed_index < len(self.message_cache) + + def _collect_pending_messages(self) -> list[SessionMessage]: + """从消息缓存中收集一批尚未处理的消息。""" + start_index = self._last_processed_index + pending_messages = self.message_cache[start_index:] + if not pending_messages: + return [] + + unique_messages: list[SessionMessage] = [] + seen_message_ids: set[str] = set() + for message in pending_messages: + message_id = message.message_id + if message_id in seen_message_ids: + continue + seen_message_ids.add(message_id) + unique_messages.append(message) + + self._last_processed_index = len(self.message_cache) + logger.info( + f"{self.log_prefix} 已从消息缓存区[{start_index}:{self._last_processed_index}] " + f"收集 {len(unique_messages)} 条新消息" + ) + return unique_messages + + async def _wait_for_trigger(self) -> Literal["message", "timeout", "stop"]: + """等待 wait 状态的触发结果。""" + if self._agent_state != self._STATE_WAIT: + await self._new_message_event.wait() + return "message" + + if self._wait_until is None: + await self._new_message_event.wait() + return "message" + + timeout = self._wait_until - time.time() + if timeout <= 0: + logger.info(f"{self.log_prefix} Maisaka 等待已超时") + self._agent_state = self._STATE_RUNNING + self._wait_until = None + return "timeout" + + try: + await asyncio.wait_for(self._new_message_event.wait(), timeout=timeout) + return "message" + except asyncio.TimeoutError: + logger.info(f"{self.log_prefix} Maisaka 等待已超时") + self._agent_state = self._STATE_RUNNING + self._wait_until = None + return "timeout" + + def _enter_wait_state(self, seconds: Optional[float] = None, tool_call_id: Optional[str] = None) -> None: + """切换到等待状态。""" + self._agent_state = self._STATE_WAIT + self._wait_until = None if seconds is None else time.time() + seconds + self._pending_wait_tool_call_id = tool_call_id + + def _enter_stop_state(self) -> None: + """切换到停止状态。""" + self._agent_state = self._STATE_STOP + self._wait_until = None + self._pending_wait_tool_call_id = None + + async def _trigger_batch_learning(self, messages: list[SessionMessage]) -> None: + """按同一批消息触发表达方式、黑话和 knowledge 学习。""" + expression_result, knowledge_result = await asyncio.gather( + self._trigger_expression_learning(messages), + self._trigger_knowledge_learning(messages), + return_exceptions=True, + ) + if isinstance(expression_result, Exception): + logger.error(f"{self.log_prefix} 表达学习任务异常退出: {expression_result}") + if isinstance(knowledge_result, Exception): + logger.error(f"{self.log_prefix} 知识学习任务异常退出: {knowledge_result}") + + async def _trigger_expression_learning(self, messages: list[SessionMessage]) -> None: + """基于新收集的一批消息触发表达学习。""" + self._expression_learner.add_messages(messages) + + if not self._enable_expression_learning: + logger.debug(f"{self.log_prefix} 表达学习未启用,跳过当前批次") + return + + elapsed = time.time() - self._last_expression_extraction_time + if elapsed < self._min_extraction_interval: + logger.debug( + f"{self.log_prefix} 表达学习尚未达到触发间隔: " + f"已过={elapsed:.2f} 秒 阈值={self._min_extraction_interval} 秒" + ) + return + + cache_size = self._expression_learner.get_cache_size() + if cache_size < self._min_messages_for_extraction: + logger.debug( + f"{self.log_prefix} 表达学习因缓存数量不足而跳过: " + f"学习器缓存={cache_size} 阈值={self._min_messages_for_extraction} " + f"消息总缓存={len(self.message_cache)}" + ) + return + + self._last_expression_extraction_time = time.time() + logger.info( + f"{self.log_prefix} 开始表达学习: " + f"新批次消息数={len(messages)} 学习器缓存={cache_size} " + f"消息总缓存={len(self.message_cache)} " + f"启用黑话学习={self._enable_jargon_learning}" + ) + + try: + jargon_miner = self._jargon_miner if self._enable_jargon_learning else None + learnt_style = await self._expression_learner.learn(jargon_miner) + if learnt_style: + logger.info(f"{self.log_prefix} 表达学习已完成") + else: + logger.debug(f"{self.log_prefix} 表达学习已完成,但没有可用结果") + except Exception: + logger.exception(f"{self.log_prefix} 表达学习失败") + + async def _trigger_knowledge_learning(self, messages: list[SessionMessage]) -> None: + """基于新收集的一批消息触发知识学习。""" + self._knowledge_learner.add_messages(messages) + + if not global_config.maisaka.enable_knowledge_module: + logger.debug(f"{self.log_prefix} 知识学习未启用,跳过当前批次") + return + + elapsed = time.time() - self._last_knowledge_extraction_time + if elapsed < self._min_extraction_interval: + logger.debug( + f"{self.log_prefix} 知识学习尚未达到触发间隔: " + f"已过={elapsed:.2f} 秒 阈值={self._min_extraction_interval} 秒" + ) + return + + cache_size = self._knowledge_learner.get_cache_size() + if cache_size < self._min_messages_for_extraction: + logger.debug( + f"{self.log_prefix} 知识学习因缓存数量不足而跳过: " + f"学习器缓存={cache_size} 阈值={self._min_messages_for_extraction} " + f"消息总缓存={len(self.message_cache)}" + ) + return + + self._last_knowledge_extraction_time = time.time() + logger.info( + f"{self.log_prefix} 开始知识学习: " + f"新批次消息数={len(messages)} 学习器缓存={cache_size} " + f"消息总缓存={len(self.message_cache)}" + ) + + try: + added_count = await self._knowledge_learner.learn() + if added_count > 0: + logger.info(f"{self.log_prefix} 知识学习已完成: 新增条目数={added_count}") + else: + logger.debug(f"{self.log_prefix} 知识学习已完成,但没有可用结果") + except Exception: + logger.exception(f"{self.log_prefix} 知识学习失败") + + async def _init_mcp(self) -> None: + """初始化 MCP 工具并注册到统一工具层。""" + self._mcp_host_bridge = MCPHostLLMBridge( + sampling_task_name=global_config.mcp.client.sampling.task_name, + ) + self._mcp_manager = await MCPManager.from_app_config( + global_config.mcp, + host_callbacks=self._mcp_host_bridge.build_callbacks(), + ) + if self._mcp_manager is None: + logger.info(f"{self.log_prefix} MCP 管理器不可用") + return + + mcp_tool_specs = self._mcp_manager.get_tool_specs() + if not mcp_tool_specs: + logger.info(f"{self.log_prefix} 没有可供 Maisaka 使用的 MCP 工具") + return + + self._tool_registry.register_provider(MCPToolProvider(self._mcp_manager)) + logger.info( + f"{self.log_prefix} 已向 Maisaka 加载 {len(mcp_tool_specs)} 个 MCP 工具。\n" + f"{self._mcp_manager.get_feature_summary()}" + ) + + def _build_runtime_user_info(self) -> UserInfo: + if self.chat_stream.user_id: + return UserInfo( + user_id=self.chat_stream.user_id, + user_nickname=global_config.maisaka.user_name.strip() or "用户", + user_cardname=None, + ) + return UserInfo(user_id="maisaka_user", user_nickname="用户", user_cardname=None) + + def _build_group_info(self, message: Optional[SessionMessage] = None) -> Optional[GroupInfo]: + group_info = None + if message is not None: + group_info = message.message_info.group_info + elif self.chat_stream.context and self.chat_stream.context.message: + group_info = self.chat_stream.context.message.message_info.group_info + + if group_info is None: + return None + + return GroupInfo(group_id=group_info.group_id, group_name=group_info.group_name) + + def _log_cycle_started(self, cycle_detail: CycleDetail, round_index: int) -> None: + logger.info( + f"{self.log_prefix} MaiSaka 轮次开始: 循环编号={cycle_detail.cycle_id} " + f"回合={round_index + 1}/{self._max_internal_rounds} " + f"上下文消息数={len(self._chat_history)}" + ) + + def _log_cycle_completed(self, cycle_detail: CycleDetail, timer_strings: list[str]) -> None: + end_time = cycle_detail.end_time if cycle_detail.end_time is not None else cycle_detail.start_time + logger.info( + f"{self.log_prefix} MaiSaka 轮次结束: 循环编号={cycle_detail.cycle_id} " + f"总耗时={end_time - cycle_detail.start_time:.2f} 秒; " + f"阶段耗时={', '.join(timer_strings) if timer_strings else '无'}" + ) + + def _log_history_trimmed(self, removed_count: int, user_message_count: int) -> None: + logger.info( + f"{self.log_prefix} 已裁剪 {removed_count} 条历史消息; " + f"剩余计入上下文的消息数={user_message_count}" + ) + + def _log_internal_loop_cancelled(self) -> None: + logger.info(f"{self.log_prefix} Maisaka 内部循环已取消") diff --git a/src/maisaka/timing.py b/src/maisaka/timing.py deleted file mode 100644 index 1709506f..00000000 --- a/src/maisaka/timing.py +++ /dev/null @@ -1,67 +0,0 @@ -""" -MaiSaka timing helpers. -""" - -from datetime import datetime -from typing import Optional - - -def _format_duration(total_seconds: int) -> str: - hours, remainder = divmod(total_seconds, 3600) - minutes, seconds = divmod(remainder, 60) - if hours > 0: - return f"{hours}h {minutes}m {seconds}s" - if minutes > 0: - return f"{minutes}m {seconds}s" - return f"{seconds}s" - - -def _get_time_period_label(hour: int) -> str: - if 0 <= hour < 6: - return "late_night" - if 6 <= hour < 9: - return "morning" - if 9 <= hour < 12: - return "late_morning" - if 12 <= hour < 14: - return "noon" - if 14 <= hour < 18: - return "afternoon" - if 18 <= hour < 22: - return "evening" - return "night" - - -def build_timing_info( - chat_start_time: Optional[datetime], - last_user_input_time: Optional[datetime], - last_assistant_response_time: Optional[datetime], - user_input_times: list[datetime], -) -> str: - """Build readable timing context for the timing analysis prompt.""" - now = datetime.now() - parts: list[str] = [f"Current time: {now.strftime('%Y-%m-%d %H:%M:%S')}"] - - if chat_start_time: - elapsed_seconds = int((now - chat_start_time).total_seconds()) - parts.append(f"Conversation duration: {_format_duration(elapsed_seconds)}") - - if last_user_input_time: - since_user_seconds = int((now - last_user_input_time).total_seconds()) - parts.append(f"Seconds since last user input: {since_user_seconds}") - - if last_assistant_response_time: - since_assistant_seconds = int((now - last_assistant_response_time).total_seconds()) - parts.append(f"Seconds since last Maisaka reply: {since_assistant_seconds}") - - if len(user_input_times) >= 2: - intervals = [ - int((user_input_times[index] - user_input_times[index - 1]).total_seconds()) - for index in range(1, len(user_input_times)) - ] - average_interval = sum(intervals) / len(intervals) - parts.append(f"Average user input interval: {int(average_interval)}s") - parts.append(f"Total user input count: {len(user_input_times)}") - - parts.append(f"Current time period: {_get_time_period_label(now.hour)}") - return "\n".join(parts) diff --git a/src/maisaka/tool_handlers.py b/src/maisaka/tool_handlers.py index 68d00f22..57d98c9d 100644 --- a/src/maisaka/tool_handlers.py +++ b/src/maisaka/tool_handlers.py @@ -1,239 +1,130 @@ """ -MaiSaka tool handlers. +MaiSaka 工具处理器。 """ from datetime import datetime -from pathlib import Path -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Optional import json as _json -import os from rich.panel import Panel -from src.common.data_models.mai_message_data_model import MaiMessage +from src.cli.console import console +from src.cli.input_reader import InputReader from src.llm_models.payload_content.tool_option import ToolCall -from .config import console -from .input_reader import InputReader -from .llm_service import MaiSakaLLMService -from .message_adapter import build_message +from .context_messages import LLMContextMessage, ToolResultMessage if TYPE_CHECKING: - from .mcp_client import MCPManager - - -MAI_FILES_DIR = Path(os.path.join(os.path.dirname(os.path.abspath(__file__)), "mai_files")) + from src.mcp_module import MCPManager class ToolHandlerContext: - """Shared context for tool handlers.""" + """工具处理器共享上下文。""" def __init__( self, - llm_service: MaiSakaLLMService, reader: InputReader, user_input_times: list[datetime], ) -> None: - self.llm_service = llm_service self.reader = reader self.user_input_times = user_input_times self.last_user_input_time: Optional[datetime] = None -async def handle_stop(tc: ToolCall, chat_history: list[MaiMessage]) -> None: - """Handle the stop tool.""" - console.print("[accent]Calling tool: stop()[/accent]") +async def handle_stop(tc: ToolCall, chat_history: list[LLMContextMessage]) -> None: + """处理 stop 工具。""" + console.print("[accent]调用工具: stop()[/accent]") chat_history.append( - build_message(role="tool", content="Conversation loop will stop after this round.", tool_call_id=tc.call_id) + ToolResultMessage( + content="当前轮次结束后将停止对话循环。", + timestamp=datetime.now(), + tool_call_id=tc.call_id, + tool_name=tc.func_name, + ) ) -async def handle_wait(tc: ToolCall, chat_history: list[MaiMessage], ctx: ToolHandlerContext) -> str: - """Handle the wait tool.""" +async def handle_wait(tc: ToolCall, chat_history: list[LLMContextMessage], ctx: ToolHandlerContext) -> str: + """处理 wait 工具。""" seconds = (tc.args or {}).get("seconds", 30) seconds = max(5, min(seconds, 300)) - console.print(f"[accent]Calling tool: wait({seconds})[/accent]") + console.print(f"[accent]调用工具: wait({seconds})[/accent]") tool_result = await _do_wait(seconds, ctx) - chat_history.append(build_message(role="tool", content=tool_result, tool_call_id=tc.call_id)) + chat_history.append( + ToolResultMessage( + content=tool_result, + timestamp=datetime.now(), + tool_call_id=tc.call_id, + tool_name=tc.func_name, + ) + ) return tool_result async def _do_wait(seconds: int, ctx: ToolHandlerContext) -> str: - """Wait for user input with a timeout.""" - console.print(f"[muted]Waiting for user input (timeout: {seconds}s)...[/muted]") + """等待用户输入,支持超时。""" + console.print(f"[muted]等待用户输入中(超时: {seconds} 秒)...[/muted]") console.print("[bold magenta]> [/bold magenta]", end="") user_input = await ctx.reader.get_line(timeout=seconds) if user_input is None: console.print() - console.print("[muted]Wait timeout[/muted]") - return "Wait timed out; no user input received." + console.print("[muted]等待超时[/muted]") + return "等待超时,未收到用户输入。" user_input = user_input.strip() if not user_input: - return "User submitted an empty input." + return "用户提交了空输入。" now = datetime.now() ctx.last_user_input_time = now ctx.user_input_times.append(now) if user_input.lower() in ("/quit", "/exit", "/q"): - return "[[QUIT]] User requested to exit." + return "[[QUIT]] 用户请求退出。" - return f"User input received: {user_input}" + return f"已收到用户输入: {user_input}" -async def handle_mcp_tool(tc: ToolCall, chat_history: list[MaiMessage], mcp_manager: "MCPManager") -> None: - """Handle an MCP tool call.""" +async def handle_mcp_tool(tc: ToolCall, chat_history: list[LLMContextMessage], mcp_manager: "MCPManager") -> None: + """处理 MCP 工具调用。""" args_str = _json.dumps(tc.args or {}, ensure_ascii=False) args_preview = args_str if len(args_str) <= 120 else args_str[:120] + "..." - console.print(f"[accent]Calling MCP tool: {tc.func_name}({args_preview})[/accent]") + console.print(f"[accent]调用 MCP 工具: {tc.func_name}({args_preview})[/accent]") - with console.status(f"[info]Running MCP tool {tc.func_name}...[/info]", spinner="dots"): + with console.status(f"[info]正在执行 MCP 工具 {tc.func_name}...[/info]", spinner="dots"): result = await mcp_manager.call_tool(tc.func_name, tc.args or {}) - display_text = result if len(result) <= 800 else result[:800] + "\n... (truncated)" + display_text = result if len(result) <= 800 else result[:800] + "\n...(已截断)" console.print( Panel( display_text, - title=f"MCP: {tc.func_name}", + title=f"MCP 工具:{tc.func_name}", border_style="bright_green", padding=(0, 1), ) ) - chat_history.append(build_message(role="tool", content=result, tool_call_id=tc.call_id)) - - -async def handle_unknown_tool(tc: ToolCall, chat_history: list[MaiMessage]) -> None: - """Handle an unknown tool call.""" - console.print(f"[accent]Calling unknown tool: {tc.func_name}({tc.args})[/accent]") - chat_history.append(build_message(role="tool", content=f"Unknown tool: {tc.func_name}", tool_call_id=tc.call_id)) - - -async def handle_write_file(tc: ToolCall, chat_history: list[MaiMessage]) -> None: - """Write a file under the local mai_files workspace.""" - filename = (tc.args or {}).get("filename", "") - content = (tc.args or {}).get("content", "") - console.print(f'[accent]Calling tool: write_file("{filename}")[/accent]') - - MAI_FILES_DIR.mkdir(parents=True, exist_ok=True) - file_path = MAI_FILES_DIR / filename - - try: - file_path.parent.mkdir(parents=True, exist_ok=True) - with open(file_path, "w", encoding="utf-8") as file: - file.write(content) - - file_size = file_path.stat().st_size - console.print( - Panel( - f"Path: {filename}\nSize: {file_size} bytes", - title="File Written", - border_style="green", - padding=(0, 1), - ) + chat_history.append( + ToolResultMessage( + content=result, + timestamp=datetime.now(), + tool_call_id=tc.call_id, + tool_name=tc.func_name, ) - chat_history.append( - build_message( - role="tool", - content=f"File written successfully: {filename} ({file_size} bytes)", - tool_call_id=tc.call_id, - ) + ) + + +async def handle_unknown_tool(tc: ToolCall, chat_history: list[LLMContextMessage]) -> None: + """处理未知工具调用。""" + console.print(f"[accent]调用未知工具: {tc.func_name}({tc.args})[/accent]") + chat_history.append( + ToolResultMessage( + content=f"未知工具: {tc.func_name}", + timestamp=datetime.now(), + tool_call_id=tc.call_id, + tool_name=tc.func_name, ) - except Exception as exc: - error_msg = f"Failed to write file: {exc}" - console.print(f"[error]{error_msg}[/error]") - chat_history.append(build_message(role="tool", content=error_msg, tool_call_id=tc.call_id)) - - -async def handle_read_file(tc: ToolCall, chat_history: list[MaiMessage]) -> None: - """Read a file from the local mai_files workspace.""" - filename = (tc.args or {}).get("filename", "") - console.print(f'[accent]Calling tool: read_file("{filename}")[/accent]') - - file_path = MAI_FILES_DIR / filename - - try: - if not file_path.exists(): - error_msg = f"File does not exist: {filename}" - console.print(f"[warning]{error_msg}[/warning]") - chat_history.append(build_message(role="tool", content=error_msg, tool_call_id=tc.call_id)) - return - - if not file_path.is_file(): - error_msg = f"Path is not a file: {filename}" - console.print(f"[warning]{error_msg}[/warning]") - chat_history.append(build_message(role="tool", content=error_msg, tool_call_id=tc.call_id)) - return - - with open(file_path, "r", encoding="utf-8") as file: - file_content = file.read() - - display_content = file_content if len(file_content) <= 1000 else file_content[:1000] + "\n... (truncated)" - console.print( - Panel( - display_content, - title=f"Read File: {filename}", - border_style="blue", - padding=(0, 1), - ) - ) - chat_history.append( - build_message(role="tool", content=f"File content of {filename}:\n{file_content}", tool_call_id=tc.call_id) - ) - except Exception as exc: - error_msg = f"Failed to read file: {exc}" - console.print(f"[error]{error_msg}[/error]") - chat_history.append(build_message(role="tool", content=error_msg, tool_call_id=tc.call_id)) - - -async def handle_list_files(tc: ToolCall, chat_history: list[MaiMessage]) -> None: - """List files under the local mai_files workspace.""" - console.print("[accent]Calling tool: list_files()[/accent]") - - try: - MAI_FILES_DIR.mkdir(parents=True, exist_ok=True) - - files_info: list[dict[str, Any]] = [] - for item in MAI_FILES_DIR.rglob("*"): - if item.is_file(): - stat = item.stat() - files_info.append( - { - "name": str(item.relative_to(MAI_FILES_DIR)), - "size": stat.st_size, - "modified": datetime.fromtimestamp(stat.st_mtime).strftime("%Y-%m-%d %H:%M:%S"), - } - ) - - if not files_info: - result_text = "No files found under mai_files." - else: - files_info.sort(key=lambda item: item["name"]) - lines = [f"Found {len(files_info)} file(s):\n"] - for item in files_info: - lines.append(f"- {item['name']} ({item['size']} bytes, modified {item['modified']})") - result_text = "\n".join(lines) - - console.print( - Panel( - result_text, - title="File List", - border_style="cyan", - padding=(0, 1), - ) - ) - chat_history.append(build_message(role="tool", content=result_text, tool_call_id=tc.call_id)) - except Exception as exc: - error_msg = f"Failed to list files: {exc}" - console.print(f"[error]{error_msg}[/error]") - chat_history.append(build_message(role="tool", content=error_msg, tool_call_id=tc.call_id)) - - -try: - MAI_FILES_DIR.mkdir(parents=True, exist_ok=True) -except Exception as exc: - console.print(f"[warning]Failed to initialize mai_files directory: {exc}[/warning]") + ) diff --git a/src/maisaka/tool_provider.py b/src/maisaka/tool_provider.py new file mode 100644 index 00000000..273fd4bd --- /dev/null +++ b/src/maisaka/tool_provider.py @@ -0,0 +1,64 @@ +"""Maisaka 内置工具 Provider。""" + +from __future__ import annotations + +from collections.abc import Awaitable, Callable +from typing import Dict, Optional + +from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolProvider, ToolSpec + +from .builtin_tools import get_builtin_tool_specs + +BuiltinToolHandler = Callable[[ToolInvocation, Optional[ToolExecutionContext]], Awaitable[ToolExecutionResult]] + + +class MaisakaBuiltinToolProvider(ToolProvider): + """Maisaka 内置工具提供者。""" + + provider_name = "maisaka_builtin" + provider_type = "builtin" + + def __init__(self, handlers: Optional[Dict[str, BuiltinToolHandler]] = None) -> None: + """初始化内置工具 Provider。 + + Args: + handlers: 工具名到异步处理器的映射。 + """ + + self._handlers = dict(handlers or {}) + + async def list_tools(self) -> list[ToolSpec]: + """列出全部内置工具。""" + + return list(get_builtin_tool_specs()) + + async def invoke( + self, + invocation: ToolInvocation, + context: Optional[ToolExecutionContext] = None, + ) -> ToolExecutionResult: + """执行指定内置工具。 + + Args: + invocation: 工具调用请求。 + context: 执行上下文。 + + Returns: + ToolExecutionResult: 工具执行结果。 + """ + + handler = self._handlers.get(invocation.tool_name) + if handler is None: + return ToolExecutionResult( + tool_name=invocation.tool_name, + success=False, + error_message=f"未找到内置工具处理器:{invocation.tool_name}", + ) + return await handler(invocation, context) + + async def close(self) -> None: + """关闭 Provider。 + + 内置 Provider 无需释放额外资源。 + """ + diff --git a/src/maisaka/mcp_client/__init__.py b/src/mcp_module/__init__.py similarity index 73% rename from src/maisaka/mcp_client/__init__.py rename to src/mcp_module/__init__.py index bd996975..0fd5bee7 100644 --- a/src/maisaka/mcp_client/__init__.py +++ b/src/mcp_module/__init__.py @@ -1,12 +1,13 @@ """ -MaiSaka - MCP (Model Context Protocol) 客户端包 +MCP (Model Context Protocol) 客户端包。 提供 MCPManager 用于管理 MCP 服务器连接、发现工具、调用工具。 用法: + from src.config.config import global_config from .manager import MCPManager - manager = await MCPManager.from_config("mcp_config.json") + manager = await MCPManager.from_app_config(global_config.mcp) if manager: tools = manager.get_openai_tools() # 获取 OpenAI 格式工具列表 result = await manager.call_tool(name, args) # 调用工具 diff --git a/src/mcp_module/config.py b/src/mcp_module/config.py new file mode 100644 index 00000000..4d4d73af --- /dev/null +++ b/src/mcp_module/config.py @@ -0,0 +1,160 @@ +"""MCP 运行时配置转换。 + +负责将主程序官方配置中的 MCP 配置转换为运行时使用的结构化对象。 +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Literal + +if TYPE_CHECKING: + from src.config.official_configs import MCPConfig + + +@dataclass(slots=True) +class MCPAuthorizationRuntimeConfig: + """MCP HTTP 认证运行时配置。""" + + mode: Literal["none", "bearer"] = "none" + bearer_token: str = "" + + +@dataclass(slots=True) +class MCPRootRuntimeConfig: + """MCP Root 运行时配置。""" + + uri: str + name: str = "" + + +@dataclass(slots=True) +class MCPClientRuntimeConfig: + """MCP 客户端宿主能力运行时配置。""" + + client_name: str = "MaiBot" + client_version: str = "1.0.0" + enable_roots: bool = False + roots: list[MCPRootRuntimeConfig] = field(default_factory=list) + enable_sampling: bool = False + sampling_task_name: str = "planner" + sampling_include_context_support: bool = False + sampling_tool_support: bool = False + enable_elicitation: bool = False + elicitation_allow_form: bool = True + elicitation_allow_url: bool = False + + +@dataclass(slots=True) +class MCPServerRuntimeConfig: + """单个 MCP 服务器的运行时配置。""" + + name: str + transport: Literal["stdio", "streamable_http"] = "stdio" + command: str = "" + args: list[str] = field(default_factory=list) + env: dict[str, str] = field(default_factory=dict) + url: str = "" + headers: dict[str, str] = field(default_factory=dict) + http_timeout_seconds: float = 30.0 + read_timeout_seconds: float = 300.0 + authorization: MCPAuthorizationRuntimeConfig = field(default_factory=MCPAuthorizationRuntimeConfig) + + @property + def transport_type(self) -> str: + """返回当前服务器的传输类型。 + + Returns: + str: ``stdio``、``streamable_http`` 或 ``unknown``。 + """ + + if self.transport == "stdio" and self.command: + return "stdio" + if self.transport == "streamable_http" and self.url: + return "streamable_http" + return "unknown" + + def build_http_headers(self) -> dict[str, str]: + """构建远程 HTTP 连接需要附加的请求头。 + + Returns: + dict[str, str]: 归一化后的请求头集合。 + """ + + headers = {str(key): str(value) for key, value in self.headers.items()} + if self.authorization.mode == "bearer" and self.authorization.bearer_token.strip(): + headers["Authorization"] = f"Bearer {self.authorization.bearer_token.strip()}" + return headers + + +def build_mcp_client_runtime_config(mcp_config: "MCPConfig") -> MCPClientRuntimeConfig: + """将官方 MCP 客户端配置转换为运行时结构。 + + Args: + mcp_config: 主程序中的 MCP 官方配置对象。 + + Returns: + MCPClientRuntimeConfig: MCP 客户端宿主能力运行时配置。 + """ + + roots = [ + MCPRootRuntimeConfig( + uri=root.uri.strip(), + name=root.name.strip(), + ) + for root in mcp_config.client.roots.items + if root.enabled and root.uri.strip() + ] + + return MCPClientRuntimeConfig( + client_name=mcp_config.client.client_name.strip() or "MaiBot", + client_version=mcp_config.client.client_version.strip() or "1.0.0", + enable_roots=mcp_config.client.roots.enable and bool(roots), + roots=roots, + enable_sampling=mcp_config.client.sampling.enable, + sampling_task_name=mcp_config.client.sampling.task_name.strip() or "planner", + sampling_include_context_support=mcp_config.client.sampling.include_context_support, + sampling_tool_support=mcp_config.client.sampling.tool_support, + enable_elicitation=mcp_config.client.elicitation.enable, + elicitation_allow_form=mcp_config.client.elicitation.allow_form, + elicitation_allow_url=mcp_config.client.elicitation.allow_url, + ) + + +def build_mcp_server_runtime_configs(mcp_config: "MCPConfig") -> list[MCPServerRuntimeConfig]: + """将官方 MCP 配置转换为运行时配置列表。 + + Args: + mcp_config: 主程序中的 MCP 官方配置对象。 + + Returns: + list[MCPServerRuntimeConfig]: 启用且配置完整的 MCP 服务器列表。 + """ + + if not mcp_config.enable: + return [] + + runtime_configs: list[MCPServerRuntimeConfig] = [] + for server in mcp_config.servers: + if not server.enabled: + continue + + runtime_configs.append( + MCPServerRuntimeConfig( + name=server.name.strip(), + transport=server.transport, + command=server.command.strip(), + args=[str(argument) for argument in server.args], + env={str(key): str(value) for key, value in server.env.items()}, + url=server.url.strip(), + headers={str(key): str(value) for key, value in server.headers.items()}, + http_timeout_seconds=float(server.http_timeout_seconds), + read_timeout_seconds=float(server.read_timeout_seconds), + authorization=MCPAuthorizationRuntimeConfig( + mode=server.authorization.mode, + bearer_token=server.authorization.bearer_token.strip(), + ), + ) + ) + + return runtime_configs diff --git a/src/mcp_module/connection.py b/src/mcp_module/connection.py new file mode 100644 index 00000000..c598e8bc --- /dev/null +++ b/src/mcp_module/connection.py @@ -0,0 +1,559 @@ +""" +MaiSaka - 单个 MCP 服务器连接管理 +封装单个 MCP 服务器的连接生命周期:连接 → 发现能力 → 调用工具/读取资源 → 断开。 +""" + +from __future__ import annotations + +from contextlib import AsyncExitStack +from datetime import timedelta +from typing import TYPE_CHECKING, Any, Callable, Optional, cast + +import httpx + +from src.cli.console import console +from src.core.tooling import ToolExecutionResult + +from .config import MCPClientRuntimeConfig, MCPRootRuntimeConfig, MCPServerRuntimeConfig +from .hooks import MCPHostCallbacks +from .models import ( + MCPPromptResult, + MCPResourceReadResult, + build_prompt_result, + build_resource_read_result, + build_tool_content_items, +) + +if TYPE_CHECKING: + from mcp.client.session import ElicitationFnT, ListRootsFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT + +# ──────────────────── MCP SDK 可选导入 ──────────────────── +# +# mcp 是可选依赖。如果未安装,MCP_AVAILABLE = False, +# MCPManager.from_app_config() 会检测到并返回 None,不影响主程序运行。 + +try: + from mcp import ClientSession, types as mcp_types + + try: + from mcp.client.stdio import StdioServerParameters + except ImportError: + from mcp import StdioServerParameters # type: ignore[attr-defined] + + from mcp.client.stdio import stdio_client + from mcp.client.streamable_http import streamable_http_client + + MCP_AVAILABLE = True + STREAMABLE_HTTP_AVAILABLE = True +except ImportError: + MCP_AVAILABLE = False + STREAMABLE_HTTP_AVAILABLE = False + ClientSession = None # type: ignore[assignment,misc] + StdioServerParameters = None # type: ignore[assignment,misc] + mcp_types = None # type: ignore[assignment] + stdio_client = None # type: ignore[assignment] + streamable_http_client = None # type: ignore[assignment] + + +class MCPConnection: + """管理单个 MCP 服务器的连接生命周期。""" + + def __init__( + self, + config: MCPServerRuntimeConfig, + client_config: MCPClientRuntimeConfig, + host_callbacks: Optional[MCPHostCallbacks] = None, + ) -> None: + """初始化单个 MCP 连接。 + + Args: + config: 当前服务器的运行时配置。 + client_config: MCP 客户端宿主能力运行时配置。 + host_callbacks: 宿主侧能力回调集合。 + """ + + self.config = config + self.client_config = client_config + self.host_callbacks = host_callbacks or MCPHostCallbacks() + + self.session: Optional[Any] = None + self.server_capabilities: Optional[Any] = None + self.tools: list[Any] = [] + self.prompts: list[Any] = [] + self.resources: list[Any] = [] + self.resource_templates: list[Any] = [] + self.protocol_version: str = "" + + self._http_client: Optional[httpx.AsyncClient] = None + self._session_id_getter: Optional[Callable[[], str | None]] = None + self._exit_stack = AsyncExitStack() + + @property + def session_id(self) -> str: + """返回当前连接协商得到的 MCP 会话标识。 + + Returns: + str: 当前会话 ID;无会话时返回空字符串。 + """ + + if self._session_id_getter is None: + return "" + return self._session_id_getter() or "" + + async def connect(self) -> bool: + """连接到 MCP 服务器并发现可用能力。 + + Returns: + bool: `True` 表示连接成功,`False` 表示失败。 + """ + + if not MCP_AVAILABLE: + console.print("[warning]⚠️ 未安装 mcp SDK,请运行: pip install mcp[/warning]") + return False + + try: + await self._exit_stack.__aenter__() + read_stream, write_stream = await self._connect_transport() + session = await self._create_client_session(read_stream, write_stream) + self.session = session + initialize_result = await session.initialize() + self.server_capabilities = getattr(initialize_result, "capabilities", None) + self.protocol_version = str(getattr(initialize_result, "protocolVersion", "") or "") + + await self._load_server_features() + return True + + except Exception as exc: + console.print(f"[warning]⚠️ MCP 服务器 '{self.config.name}' 连接失败: {exc}[/warning]") + await self.close() + return False + + async def _connect_transport(self) -> tuple[Any, Any]: + """根据配置建立底层传输连接。 + + Returns: + tuple[Any, Any]: 读写流对象。 + """ + + if self.config.transport_type == "stdio": + return await self._connect_stdio() + if self.config.transport_type == "streamable_http": + return await self._connect_streamable_http() + + raise ValueError(f"MCP 服务器 '{self.config.name}' 使用了未知传输类型: {self.config.transport}") + + async def _connect_stdio(self) -> tuple[Any, Any]: + """建立 stdio 传输连接。 + + Returns: + tuple[Any, Any]: 读写流对象。 + """ + + if StdioServerParameters is None or stdio_client is None: + raise RuntimeError("当前环境未安装可用的 MCP stdio 客户端") + if not self.config.command: + raise ValueError(f"MCP 服务器 '{self.config.name}' 缺少 stdio command 配置") + + params = StdioServerParameters( + command=self.config.command, + args=self.config.args, + env=self.config.env, + ) + return await self._exit_stack.enter_async_context(stdio_client(params)) + + async def _connect_streamable_http(self) -> tuple[Any, Any]: + """建立 Streamable HTTP 传输连接。 + + Returns: + tuple[Any, Any]: 读写流对象。 + """ + + if not STREAMABLE_HTTP_AVAILABLE or streamable_http_client is None: + raise ImportError("当前环境未安装可用的 MCP Streamable HTTP 客户端") + if not self.config.url: + raise ValueError(f"MCP 服务器 '{self.config.name}' 缺少 Streamable HTTP url 配置") + + self._http_client = await self._exit_stack.enter_async_context(self._build_http_client()) + read_stream, write_stream, session_id_getter = await self._exit_stack.enter_async_context( + streamable_http_client( + url=self.config.url, + http_client=self._http_client, + terminate_on_close=True, + ) + ) + self._session_id_getter = session_id_getter + return read_stream, write_stream + + def _build_http_client(self) -> httpx.AsyncClient: + """构建 Streamable HTTP 使用的 `httpx` 客户端。 + + Returns: + httpx.AsyncClient: 预配置的异步 HTTP 客户端。 + """ + + return httpx.AsyncClient( + headers=self.config.build_http_headers(), + timeout=httpx.Timeout(self.config.http_timeout_seconds), + ) + + async def _create_client_session(self, read_stream: Any, write_stream: Any) -> Any: + """创建并返回 MCP `ClientSession`。 + + Args: + read_stream: 底层读取流。 + write_stream: 底层写入流。 + + Returns: + Any: 已初始化的 MCP `ClientSession` 实例。 + """ + + if ClientSession is None: + raise RuntimeError("当前环境未安装可用的 MCP ClientSession") + + list_roots_callback = self._build_list_roots_callback() + sampling_callback = ( + self.host_callbacks.sampling_callback + if self.client_config.enable_sampling and self.host_callbacks.sampling_callback is not None + else None + ) + elicitation_callback = ( + self.host_callbacks.elicitation_callback + if self.client_config.enable_elicitation and self.host_callbacks.elicitation_callback is not None + else None + ) + logging_callback = cast(Optional["LoggingFnT"], self.host_callbacks.logging_callback) + message_handler = cast(Optional["MessageHandlerFnT"], self.host_callbacks.message_handler) + + if self.client_config.enable_sampling and sampling_callback is None: + console.print( + f"[warning]⚠️ MCP 服务器 '{self.config.name}' 已启用 sampling 配置,但宿主未提供 sampling 回调,当前不会声明该能力[/warning]" + ) + if self.client_config.enable_elicitation and elicitation_callback is None: + console.print( + f"[warning]⚠️ MCP 服务器 '{self.config.name}' 已启用 elicitation 配置,但宿主未提供 elicitation 回调,当前不会声明该能力[/warning]" + ) + + session = await self._exit_stack.enter_async_context( + ClientSession( + read_stream, + write_stream, + read_timeout_seconds=timedelta(seconds=self.config.read_timeout_seconds), + sampling_callback=cast(Optional["SamplingFnT"], sampling_callback), + elicitation_callback=cast(Optional["ElicitationFnT"], elicitation_callback), + list_roots_callback=cast(Optional["ListRootsFnT"], list_roots_callback), + logging_callback=logging_callback, + message_handler=message_handler, + client_info=self._build_client_info(), + sampling_capabilities=self._build_sampling_capabilities(sampling_callback), + ) + ) + return session + + def _build_client_info(self) -> Any: + """构建 MCP 客户端实现信息。 + + Returns: + Any: MCP SDK 的 `Implementation` 对象。 + """ + + if mcp_types is None: + raise RuntimeError("当前环境未安装可用的 MCP types 模块") + + return mcp_types.Implementation( + name=self.client_config.client_name, + version=self.client_config.client_version, + ) + + def _build_sampling_capabilities(self, sampling_callback: Any) -> Any | None: + """构建 Sampling 能力声明。 + + Args: + sampling_callback: 当前宿主侧的 Sampling 回调。 + + Returns: + Any | None: Sampling 能力对象;未启用时返回 ``None``。 + """ + + if mcp_types is None: + return None + if sampling_callback is None: + return None + + context_capability = ( + mcp_types.SamplingContextCapability() + if self.client_config.sampling_include_context_support + else None + ) + tools_capability = ( + mcp_types.SamplingToolsCapability() + if self.client_config.sampling_tool_support + else None + ) + return mcp_types.SamplingCapability( + context=context_capability, + tools=tools_capability, + ) + + def _build_list_roots_callback(self) -> Any | None: + """构建 Roots 列表回调。 + + Returns: + Any | None: 符合 MCP SDK 要求的回调;未启用时返回 ``None``。 + """ + + if mcp_types is None: + return None + if not self.client_config.enable_roots or not self.client_config.roots: + return None + + async def _list_roots(context: Any) -> Any: + """返回当前客户端声明的 Roots 列表。 + + Args: + context: MCP 请求上下文。 + + Returns: + Any: MCP `ListRootsResult` 对象。 + """ + + del context + types_module = mcp_types + if types_module is None: + raise RuntimeError("当前环境未安装可用的 MCP types 模块") + roots = [ + types_module.Root(uri=cast(Any, root.uri), name=root.name or None) + for root in self.client_config.roots + ] + return types_module.ListRootsResult(roots=roots) + + return _list_roots + + async def _load_server_features(self) -> None: + """根据服务端能力声明加载工具、Prompt 与 Resource。""" + + self.tools = await self._list_tools() if self.supports_tools() else [] + self.prompts = await self._list_prompts() if self.supports_prompts() else [] + self.resources = await self._list_resources() if self.supports_resources() else [] + self.resource_templates = ( + await self._list_resource_templates() if self.supports_resources() else [] + ) + + def supports_tools(self) -> bool: + """判断服务端是否声明支持 Tools。 + + Returns: + bool: 是否支持 Tools。 + """ + + return bool(self.server_capabilities is not None and getattr(self.server_capabilities, "tools", None) is not None) + + def supports_prompts(self) -> bool: + """判断服务端是否声明支持 Prompts。 + + Returns: + bool: 是否支持 Prompts。 + """ + + return bool( + self.server_capabilities is not None and getattr(self.server_capabilities, "prompts", None) is not None + ) + + def supports_resources(self) -> bool: + """判断服务端是否声明支持 Resources。 + + Returns: + bool: 是否支持 Resources。 + """ + + return bool( + self.server_capabilities is not None and getattr(self.server_capabilities, "resources", None) is not None + ) + + async def _list_tools(self) -> list[Any]: + """分页加载服务端暴露的全部工具。 + + Returns: + list[Any]: MCP SDK 的原始工具对象列表。 + """ + + if self.session is None: + return [] + + tools: list[Any] = [] + cursor: Optional[str] = None + while True: + result = await self.session.list_tools(cursor=cursor) + tools.extend(list(getattr(result, "tools", []) or [])) + cursor = getattr(result, "nextCursor", None) + if not cursor: + break + return tools + + async def _list_prompts(self) -> list[Any]: + """分页加载服务端暴露的全部 Prompt。 + + Returns: + list[Any]: MCP SDK 的原始 Prompt 对象列表。 + """ + + if self.session is None: + return [] + + prompts: list[Any] = [] + cursor: Optional[str] = None + while True: + result = await self.session.list_prompts(cursor=cursor) + prompts.extend(list(getattr(result, "prompts", []) or [])) + cursor = getattr(result, "nextCursor", None) + if not cursor: + break + return prompts + + async def _list_resources(self) -> list[Any]: + """分页加载服务端暴露的全部 Resource。 + + Returns: + list[Any]: MCP SDK 的原始 Resource 对象列表。 + """ + + if self.session is None: + return [] + + resources: list[Any] = [] + cursor: Optional[str] = None + while True: + result = await self.session.list_resources(cursor=cursor) + resources.extend(list(getattr(result, "resources", []) or [])) + cursor = getattr(result, "nextCursor", None) + if not cursor: + break + return resources + + async def _list_resource_templates(self) -> list[Any]: + """分页加载服务端暴露的全部 Resource Template。 + + Returns: + list[Any]: MCP SDK 的原始 Resource Template 对象列表。 + """ + + if self.session is None: + return [] + + resource_templates: list[Any] = [] + cursor: Optional[str] = None + while True: + result = await self.session.list_resource_templates(cursor=cursor) + resource_templates.extend(list(getattr(result, "resourceTemplates", []) or [])) + cursor = getattr(result, "nextCursor", None) + if not cursor: + break + return resource_templates + + async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> ToolExecutionResult: + """调用 MCP 工具并返回统一执行结果。 + + Args: + tool_name: 工具名称。 + arguments: 工具参数字典。 + + Returns: + ToolExecutionResult: 统一执行结果。 + """ + + if self.session is None: + return ToolExecutionResult( + tool_name=tool_name, + success=False, + error_message=f"MCP 服务器 '{self.config.name}' 未连接", + metadata={"server_name": self.config.name}, + ) + + try: + result = await self.session.call_tool( + tool_name, + arguments=arguments, + read_timeout_seconds=timedelta(seconds=self.config.read_timeout_seconds), + ) + except Exception as exc: + return ToolExecutionResult( + tool_name=tool_name, + success=False, + error_message=f"MCP 工具 '{tool_name}' 执行失败: {exc}", + metadata={"server_name": self.config.name}, + ) + + content_items = build_tool_content_items(list(getattr(result, "content", []) or [])) + text_parts = [item.text.strip() for item in content_items if item.content_type == "text" and item.text.strip()] + structured_content = getattr(result, "structuredContent", None) + is_error = bool(getattr(result, "isError", False)) + history_content = "\n".join(text_parts).strip() + error_message = history_content if is_error else "" + + return ToolExecutionResult( + tool_name=tool_name, + success=not is_error, + content=history_content if not is_error else "", + error_message=error_message, + structured_content=structured_content, + content_items=content_items, + metadata={ + "server_name": self.config.name, + "protocol_version": self.protocol_version, + "session_id": self.session_id, + }, + ) + + async def get_prompt( + self, + prompt_name: str, + arguments: Optional[dict[str, str]] = None, + ) -> MCPPromptResult: + """读取指定 MCP Prompt 的内容。 + + Args: + prompt_name: Prompt 名称。 + arguments: Prompt 参数字典。 + + Returns: + MCPPromptResult: 统一 Prompt 结果。 + """ + + if self.session is None: + raise RuntimeError(f"MCP 服务器 '{self.config.name}' 未连接") + + result = await self.session.get_prompt(prompt_name, arguments=arguments) + return build_prompt_result(result, prompt_name=prompt_name, server_name=self.config.name) + + async def read_resource(self, uri: str) -> MCPResourceReadResult: + """读取指定 MCP Resource 的内容。 + + Args: + uri: 资源 URI。 + + Returns: + MCPResourceReadResult: 统一资源读取结果。 + """ + + if self.session is None: + raise RuntimeError(f"MCP 服务器 '{self.config.name}' 未连接") + + result = await self.session.read_resource(uri) + return build_resource_read_result(result, uri=uri, server_name=self.config.name) + + async def close(self) -> None: + """关闭连接并释放资源。""" + + try: + await self._exit_stack.aclose() + except Exception: + pass + + self.session = None + self.server_capabilities = None + self.tools = [] + self.prompts = [] + self.resources = [] + self.resource_templates = [] + self.protocol_version = "" + self._http_client = None + self._session_id_getter = None diff --git a/src/mcp_module/hooks.py b/src/mcp_module/hooks.py new file mode 100644 index 00000000..c1890390 --- /dev/null +++ b/src/mcp_module/hooks.py @@ -0,0 +1,20 @@ +"""MCP 宿主回调声明。""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Awaitable, Callable + + +@dataclass(slots=True) +class MCPHostCallbacks: + """MCP 宿主回调集合。 + + 该对象用于向 `MCPConnection` 注入宿主侧可选能力, + 例如 Sampling、Elicitation、日志消费和自定义消息处理。 + """ + + sampling_callback: Callable[..., Awaitable[Any]] | None = None + elicitation_callback: Callable[..., Awaitable[Any]] | None = None + logging_callback: Callable[..., Awaitable[None]] | None = None + message_handler: Callable[..., Awaitable[None]] | None = None diff --git a/src/mcp_module/host_llm_bridge.py b/src/mcp_module/host_llm_bridge.py new file mode 100644 index 00000000..1b8bc10d --- /dev/null +++ b/src/mcp_module/host_llm_bridge.py @@ -0,0 +1,597 @@ +"""MCP 宿主侧大模型桥接服务。""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Optional + +import json + +from src.common.data_models.llm_service_data_models import LLMGenerationOptions, LLMResponseResult +from src.common.logger import get_logger +from src.core.tooling import build_tool_detailed_description +from src.llm_models.payload_content.message import Message, MessageBuilder, RoleType +from src.llm_models.payload_content.tool_option import ToolCall, ToolDefinitionInput +from src.services.llm_service import LLMServiceClient + +from .hooks import MCPHostCallbacks +from .models import build_tool_content_items + +if TYPE_CHECKING: + from src.llm_models.model_client.base_client import BaseClient + +try: + from mcp import types as mcp_types + + MCP_TYPES_AVAILABLE = True +except ImportError: + mcp_types = None # type: ignore[assignment] + MCP_TYPES_AVAILABLE = False + +logger = get_logger("mcp_host_llm_bridge") + + +class MCPHostLLMBridge: + """将 MCP Sampling 请求桥接到主程序大模型调用链。""" + + def __init__(self, sampling_task_name: str = "planner") -> None: + """初始化 MCP 宿主侧大模型桥接服务。 + + Args: + sampling_task_name: 执行 Sampling 请求时使用的模型任务名。 + """ + + self._sampling_task_name = sampling_task_name.strip() or "planner" + self._sampling_client = LLMServiceClient( + task_name=self._sampling_task_name, + request_type="mcp_sampling", + ) + + def build_callbacks(self) -> MCPHostCallbacks: + """构建可注入给 MCP 连接层的宿主回调集合。 + + Returns: + MCPHostCallbacks: 包含 Sampling 回调的宿主回调集合。 + """ + + return MCPHostCallbacks( + sampling_callback=self.handle_sampling_request, + ) + + async def handle_sampling_request(self, context: Any, params: Any) -> Any: + """处理服务端发起的 MCP Sampling 请求。 + + Args: + context: MCP SDK 传入的请求上下文。 + params: `sampling/createMessage` 请求参数。 + + Returns: + Any: MCP `CreateMessageResult`、`CreateMessageResultWithTools` 或 `ErrorData`。 + """ + + del context + if not MCP_TYPES_AVAILABLE or mcp_types is None: + raise RuntimeError("当前环境未安装可用的 MCP types 模块") + + try: + tool_choice_mode = self._get_tool_choice_mode(params) + tool_definitions = self._build_tool_definitions( + raw_tools=getattr(params, "tools", None), + tool_choice_mode=tool_choice_mode, + ) + message_factory = self._build_message_factory( + raw_messages=list(getattr(params, "messages", []) or []), + system_prompt=self._build_system_prompt( + raw_system_prompt=str(getattr(params, "systemPrompt", "") or ""), + tool_choice_mode=tool_choice_mode, + tool_definitions=tool_definitions, + ), + ) + + generation_result = await self._sampling_client.generate_response_with_messages( + message_factory=message_factory, + options=LLMGenerationOptions( + temperature=self._coerce_float(getattr(params, "temperature", None)), + max_tokens=int(getattr(params, "maxTokens", 1024) or 1024), + tool_options=tool_definitions, + ), + ) + + if tool_choice_mode == "required" and tool_definitions and not generation_result.tool_calls: + return mcp_types.ErrorData( + code=mcp_types.INTERNAL_ERROR, + message="Sampling 要求必须调用工具,但模型未返回任何工具调用", + ) + + return self._build_sampling_result( + generation_result=generation_result, + tools_enabled=bool(tool_definitions), + ) + + except Exception as exc: + logger.exception(f"MCP Sampling 调用失败: {exc}") + return mcp_types.ErrorData( + code=mcp_types.INTERNAL_ERROR, + message=f"MCP Sampling 调用失败: {exc}", + ) + + @staticmethod + def _coerce_float(raw_value: Any) -> float | None: + """将任意原始值转换为浮点数。 + + Args: + raw_value: 原始输入值。 + + Returns: + float | None: 转换后的浮点数;无法转换时返回 ``None``。 + """ + + if raw_value is None: + return None + if isinstance(raw_value, int | float): + return float(raw_value) + return None + + @staticmethod + def _get_tool_choice_mode(params: Any) -> str: + """读取 Sampling 请求中的工具选择模式。 + + Args: + params: Sampling 请求参数对象。 + + Returns: + str: `auto`、`required` 或 `none`;缺省时返回 `auto`。 + """ + + tool_choice = getattr(params, "toolChoice", None) + mode = str(getattr(tool_choice, "mode", "") or "").strip().lower() + if mode in {"required", "none"}: + return mode + return "auto" + + def _build_system_prompt( + self, + raw_system_prompt: str, + tool_choice_mode: str, + tool_definitions: list[ToolDefinitionInput] | None, + ) -> str: + """构建发送给主程序大模型的系统提示词。 + + Args: + raw_system_prompt: 服务端请求中的系统提示词。 + tool_choice_mode: 当前工具选择模式。 + tool_definitions: 参与本次 Sampling 的工具定义。 + + Returns: + str: 最终系统提示词。 + """ + + prompt_parts: list[str] = [] + if raw_system_prompt.strip(): + prompt_parts.append(raw_system_prompt.strip()) + if tool_choice_mode == "required" and tool_definitions: + prompt_parts.append("本轮回答必须至少调用一个工具;不要直接结束回答。") + return "\n\n".join(part for part in prompt_parts if part).strip() + + def _build_message_factory( + self, + raw_messages: list[Any], + system_prompt: str, + ) -> Any: + """构建 MCP Sampling 使用的消息工厂。 + + Args: + raw_messages: MCP Sampling 原始消息列表。 + system_prompt: 规范化后的系统提示词。 + + Returns: + Any: 供 `LLMServiceClient` 使用的消息工厂。 + """ + + def _message_factory(client: "BaseClient") -> list[Message]: + """延迟构建内部消息列表。 + + Args: + client: 当前被选中的底层模型客户端。 + + Returns: + list[Message]: 内部统一消息列表。 + """ + + messages: list[Message] = [] + if system_prompt.strip(): + messages.append( + MessageBuilder() + .set_role(RoleType.System) + .add_text_content(system_prompt.strip()) + .build() + ) + + for raw_message in raw_messages: + messages.extend(self._convert_sampling_message(raw_message, client)) + return messages + + return _message_factory + + def _convert_sampling_message(self, raw_message: Any, client: "BaseClient") -> list[Message]: + """将单条 MCP Sampling 消息转换为内部消息列表。 + + Args: + raw_message: MCP Sampling 原始消息对象。 + client: 当前底层模型客户端。 + + Returns: + list[Message]: 转换后的内部消息列表。 + """ + + role = str(getattr(raw_message, "role", "") or "").strip().lower() + content_blocks = self._get_content_blocks(getattr(raw_message, "content", None)) + + if role == "assistant": + assistant_message = self._build_assistant_message(content_blocks, client) + return [assistant_message] if assistant_message is not None else [] + + if role == "user": + return self._build_user_messages(content_blocks, client) + + raise ValueError(f"不支持的 MCP Sampling 消息角色: {role}") + + @staticmethod + def _get_content_blocks(raw_content: Any) -> list[Any]: + """将 MCP Sampling 消息内容统一为列表。 + + Args: + raw_content: 原始内容字段。 + + Returns: + list[Any]: 统一后的内容块列表。 + """ + + if raw_content is None: + return [] + if isinstance(raw_content, list): + return list(raw_content) + return [raw_content] + + def _build_assistant_message(self, content_blocks: list[Any], client: "BaseClient") -> Optional[Message]: + """构建内部 assistant 消息。 + + Args: + content_blocks: MCP assistant 内容块列表。 + client: 当前底层模型客户端。 + + Returns: + Optional[Message]: 转换后的内部 assistant 消息;无有效内容时返回 ``None``。 + """ + + message_builder = MessageBuilder().set_role(RoleType.Assistant) + tool_calls: list[ToolCall] = [] + has_visible_content = False + + for content_block in content_blocks: + content_type = self._get_content_type(content_block) + if content_type == "tool_use": + tool_calls.append( + ToolCall( + call_id=str(getattr(content_block, "id", "") or ""), + func_name=str(getattr(content_block, "name", "") or ""), + args=self._normalize_tool_call_arguments(getattr(content_block, "input", None)), + ) + ) + continue + + has_visible_content = self._append_sampling_content_to_builder( + message_builder=message_builder, + content_block=content_block, + client=client, + ) or has_visible_content + + if tool_calls: + message_builder.set_tool_calls(tool_calls) + + if not has_visible_content and not tool_calls: + return None + return message_builder.build() + + def _build_user_messages(self, content_blocks: list[Any], client: "BaseClient") -> list[Message]: + """构建内部 user/tool 消息序列。 + + Args: + content_blocks: MCP user 内容块列表。 + client: 当前底层模型客户端。 + + Returns: + list[Message]: 转换后的内部消息序列。 + """ + + messages: list[Message] = [] + message_builder = MessageBuilder().set_role(RoleType.User) + has_user_content = False + + def flush_user_message() -> None: + """在当前存在用户可见内容时落盘一条 user 消息。""" + + nonlocal message_builder, has_user_content + if not has_user_content: + return + messages.append(message_builder.build()) + message_builder = MessageBuilder().set_role(RoleType.User) + has_user_content = False + + for content_block in content_blocks: + content_type = self._get_content_type(content_block) + if content_type == "tool_result": + flush_user_message() + messages.append(self._build_tool_result_message(content_block)) + continue + + has_user_content = self._append_sampling_content_to_builder( + message_builder=message_builder, + content_block=content_block, + client=client, + ) or has_user_content + + flush_user_message() + return messages + + @staticmethod + def _get_content_type(content_block: Any) -> str: + """读取 MCP 内容块类型。 + + Args: + content_block: MCP 内容块对象。 + + Returns: + str: 规范化后的内容块类型。 + """ + + return str(getattr(content_block, "type", "text") or "text").strip().lower() + + def _append_sampling_content_to_builder( + self, + message_builder: MessageBuilder, + content_block: Any, + client: "BaseClient", + ) -> bool: + """将 MCP 普通内容块追加到内部消息构建器。 + + Args: + message_builder: 内部消息构建器。 + content_block: MCP 内容块对象。 + client: 当前底层模型客户端。 + + Returns: + bool: 是否成功追加了可见内容。 + """ + + content_type = self._get_content_type(content_block) + if content_type == "text": + text_content = str(getattr(content_block, "text", "") or "") + if text_content.strip(): + message_builder.add_text_content(text_content) + return True + return False + + if content_type == "image": + image_data = str(getattr(content_block, "data", "") or "") + image_mime_type = str(getattr(content_block, "mimeType", "") or "") + image_format = self._normalize_image_format(image_mime_type) + if image_data and image_format: + message_builder.add_image_content( + image_format=image_format, + image_base64=image_data, + support_formats=client.get_support_image_formats(), + ) + return True + + message_builder.add_text_content( + f"[图片内容:mime_type={image_mime_type or 'unknown'},当前客户端无法直接透传]" + ) + return True + + if content_type == "audio": + audio_mime_type = str(getattr(content_block, "mimeType", "") or "") + message_builder.add_text_content(f"[音频内容:mime_type={audio_mime_type or 'unknown'}]") + return True + + return False + + @staticmethod + def _normalize_image_format(mime_type: str) -> str: + """将图片 MIME 类型转换为内部图片格式名称。 + + Args: + mime_type: MCP 图片 MIME 类型。 + + Returns: + str: 内部支持的图片格式名;不支持时返回空字符串。 + """ + + normalized_mime_type = mime_type.strip().lower() + if normalized_mime_type == "image/png": + return "png" + if normalized_mime_type in {"image/jpeg", "image/jpg"}: + return "jpeg" + if normalized_mime_type == "image/webp": + return "webp" + if normalized_mime_type == "image/gif": + return "gif" + return "" + + def _build_tool_result_message(self, content_block: Any) -> Message: + """将 MCP `tool_result` 内容块转换为内部 Tool 消息。 + + Args: + content_block: MCP `tool_result` 内容块对象。 + + Returns: + Message: 转换后的内部 Tool 消息。 + """ + + message_builder = MessageBuilder().set_role(RoleType.Tool) + message_builder.set_tool_call_id(str(getattr(content_block, "toolUseId", "") or "tool_result")) + summary_text = self._summarize_tool_result_content(content_block) + message_builder.add_text_content(summary_text or "工具执行完成。") + return message_builder.build() + + def _summarize_tool_result_content(self, content_block: Any) -> str: + """汇总 MCP `tool_result` 内容块中的结果文本。 + + Args: + content_block: MCP `tool_result` 内容块对象。 + + Returns: + str: 适合发送给主程序模型的工具结果摘要文本。 + """ + + raw_contents = list(getattr(content_block, "content", []) or []) + content_items = build_tool_content_items(raw_contents) + parts = [item.build_history_text().strip() for item in content_items if item.build_history_text().strip()] + + structured_content = getattr(content_block, "structuredContent", None) + if structured_content is not None: + try: + parts.append(json.dumps(structured_content, ensure_ascii=False)) + except (TypeError, ValueError): + parts.append(str(structured_content)) + + summary_text = "\n".join(part for part in parts if part).strip() + if bool(getattr(content_block, "isError", False)) and summary_text: + return f"工具执行失败:\n{summary_text}" + if bool(getattr(content_block, "isError", False)): + return "工具执行失败。" + return summary_text + + @staticmethod + def _normalize_tool_call_arguments(raw_arguments: Any) -> dict[str, Any]: + """将原始工具调用参数规范化为字典。 + + Args: + raw_arguments: 原始工具参数。 + + Returns: + dict[str, Any]: 规范化后的参数字典。 + """ + + if isinstance(raw_arguments, dict): + return dict(raw_arguments) + if raw_arguments is None: + return {} + return {"value": raw_arguments} + + def _build_tool_definitions( + self, + raw_tools: Any, + tool_choice_mode: str, + ) -> list[ToolDefinitionInput] | None: + """将 MCP Sampling 工具定义转换为主程序内部工具定义。 + + Args: + raw_tools: MCP Sampling 请求中的工具列表。 + tool_choice_mode: 当前工具选择模式。 + + Returns: + list[ToolDefinitionInput] | None: 可传给主程序模型层的工具定义列表。 + """ + + if tool_choice_mode == "none": + return None + if not isinstance(raw_tools, list) or not raw_tools: + return None + + tool_definitions: list[ToolDefinitionInput] = [] + for raw_tool in raw_tools: + tool_name = str(getattr(raw_tool, "name", "") or "").strip() + if not tool_name: + continue + + parameters_schema = ( + dict(getattr(raw_tool, "inputSchema", {}) or {}) if getattr(raw_tool, "inputSchema", None) else {} + ) + if "$schema" in parameters_schema: + parameters_schema.pop("$schema") + + title = str(getattr(raw_tool, "title", "") or "").strip() + description = str(getattr(raw_tool, "description", "") or "").strip() + brief_description = description or title or f"工具 {tool_name}" + detailed_description = build_tool_detailed_description( + parameters_schema, + fallback_description=f"工具名称:{tool_name}", + ) + + tool_definitions.append( + { + "name": tool_name, + "description": "\n\n".join( + part for part in [brief_description, detailed_description] if part.strip() + ).strip(), + "parameters_schema": parameters_schema or {"type": "object", "properties": {}}, + } + ) + + return tool_definitions or None + + def _build_sampling_result( + self, + generation_result: LLMResponseResult, + tools_enabled: bool, + ) -> Any: + """将主程序模型响应转换为 MCP Sampling 结果。 + + Args: + generation_result: 主程序统一大模型响应结果。 + tools_enabled: 当前是否允许模型使用工具。 + + Returns: + Any: MCP `CreateMessageResult` 或 `CreateMessageResultWithTools`。 + """ + + if not MCP_TYPES_AVAILABLE or mcp_types is None: + raise RuntimeError("当前环境未安装可用的 MCP types 模块") + + text_content = str(generation_result.response or "") + tool_calls = list(generation_result.tool_calls or []) + model_name = generation_result.model_name or self._sampling_task_name + + if tools_enabled: + content_blocks: list[Any] = [] + if text_content.strip(): + content_blocks.append( + mcp_types.TextContent( + type="text", + text=text_content, + ) + ) + for tool_call in tool_calls: + content_blocks.append( + mcp_types.ToolUseContent( + type="tool_use", + name=tool_call.func_name, + id=tool_call.call_id, + input=dict(tool_call.args or {}), + ) + ) + + if not content_blocks: + content_blocks.append( + mcp_types.TextContent( + type="text", + text="", + ) + ) + + return mcp_types.CreateMessageResultWithTools( + role="assistant", + content=content_blocks[0] if len(content_blocks) == 1 else content_blocks, + model=model_name, + stopReason="toolUse" if tool_calls else "endTurn", + ) + + return mcp_types.CreateMessageResult( + role="assistant", + content=mcp_types.TextContent( + type="text", + text=text_content, + ), + model=model_name, + stopReason="endTurn", + ) diff --git a/src/mcp_module/manager.py b/src/mcp_module/manager.py new file mode 100644 index 00000000..53c4dbc4 --- /dev/null +++ b/src/mcp_module/manager.py @@ -0,0 +1,591 @@ +""" +MaiSaka - MCP 管理器 +管理所有 MCP 服务器连接,提供统一的工具、Prompt 与 Resource 访问入口。 +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Optional + +from src.cli.console import console +from src.core.tooling import ( + ToolExecutionResult, + ToolInvocation, + ToolSpec, + build_tool_detailed_description, +) + +from .config import ( + MCPClientRuntimeConfig, + MCPServerRuntimeConfig, + build_mcp_client_runtime_config, + build_mcp_server_runtime_configs, +) +from .connection import MCPConnection, MCP_AVAILABLE +from .hooks import MCPHostCallbacks +from .models import ( + MCPPromptResult, + MCPPromptSpec, + MCPResourceReadResult, + MCPResourceSpec, + MCPResourceTemplateSpec, + build_prompt_spec, + build_resource_spec, + build_resource_template_spec, + build_tool_annotation, + build_tool_icon, +) + +if TYPE_CHECKING: + from src.config.official_configs import MCPConfig + +# 内置工具名称集合 —— MCP 工具不允许与这些名称冲突 +BUILTIN_TOOL_NAMES = frozenset( + { + "reply", + "no_reply", + "wait", + "stop", + "create_table", + "list_tables", + "view_table", + } +) + + +class MCPManager: + """MCP 服务器连接管理器。""" + + def __init__( + self, + client_config: MCPClientRuntimeConfig, + host_callbacks: Optional[MCPHostCallbacks] = None, + ) -> None: + """初始化 MCP 管理器。 + + Args: + client_config: MCP 客户端宿主能力运行时配置。 + host_callbacks: 宿主侧能力回调集合。 + """ + + self._client_config = client_config + self._host_callbacks = host_callbacks or MCPHostCallbacks() + self._connections: dict[str, MCPConnection] = {} + self._tool_to_server: dict[str, str] = {} + self._prompt_to_server: dict[str, str] = {} + self._resource_to_server: dict[str, str] = {} + self._resource_template_to_server: dict[str, str] = {} + + @classmethod + async def from_app_config( + cls, + mcp_config: "MCPConfig", + host_callbacks: Optional[MCPHostCallbacks] = None, + ) -> Optional["MCPManager"]: + """从官方配置创建并初始化 `MCPManager`。 + + Args: + mcp_config: 主程序中的 MCP 配置对象。 + host_callbacks: 宿主侧能力回调集合。 + + Returns: + Optional[MCPManager]: 初始化完成的管理器;无可用配置或全部连接失败时返回 ``None``。 + """ + + configs = build_mcp_server_runtime_configs(mcp_config) + if not configs: + return None + + if not MCP_AVAILABLE: + console.print("[warning]⚠️ 发现 MCP 配置但未安装 mcp SDK,请运行: pip install mcp[/warning]") + return None + + manager = cls( + client_config=build_mcp_client_runtime_config(mcp_config), + host_callbacks=host_callbacks, + ) + await manager._connect_all(configs) + + if not manager._connections: + console.print("[warning]⚠️ 所有 MCP 服务器连接失败[/warning]") + return None + + return manager + + async def _connect_all(self, configs: list[MCPServerRuntimeConfig]) -> None: + """连接全部已配置的 MCP 服务器。 + + Args: + configs: 服务器运行时配置列表。 + + Returns: + None + """ + + for config in configs: + connection = MCPConnection(config, self._client_config, self._host_callbacks) + success = await connection.connect() + if not success: + continue + + self._connections[config.name] = connection + registered_tool_count = self._register_tools(config.name, connection) + registered_prompt_count = self._register_prompts(config.name, connection) + registered_resource_count = self._register_resources(config.name, connection) + registered_template_count = self._register_resource_templates(config.name, connection) + console.print( + "[success]✓ MCP 服务器 " + f"'{config.name}' 已连接[/success] " + f"[muted](工具 {registered_tool_count} / Prompt {registered_prompt_count} / " + f"资源 {registered_resource_count} / 模板 {registered_template_count})[/muted]" + ) + + def _register_tools(self, server_name: str, connection: MCPConnection) -> int: + """注册单个服务器暴露的 MCP 工具。 + + Args: + server_name: 服务器名称。 + connection: 对应连接对象。 + + Returns: + int: 成功注册的工具数量。 + """ + + registered_count = 0 + for tool in connection.tools: + tool_name = str(tool.name) + + if tool_name in BUILTIN_TOOL_NAMES: + console.print( + f"[warning]⚠️ MCP 工具 '{tool_name}' (来自 {server_name}) 与内置工具冲突,已跳过[/warning]" + ) + continue + + if tool_name in self._tool_to_server: + existing_server = self._tool_to_server[tool_name] + console.print( + f"[warning]⚠️ MCP 工具 '{tool_name}' (来自 {server_name}) 与 {existing_server} 冲突,已跳过[/warning]" + ) + continue + + self._tool_to_server[tool_name] = server_name + registered_count += 1 + return registered_count + + def _register_prompts(self, server_name: str, connection: MCPConnection) -> int: + """注册单个服务器暴露的 MCP Prompt。 + + Args: + server_name: 服务器名称。 + connection: 对应连接对象。 + + Returns: + int: 成功注册的 Prompt 数量。 + """ + + registered_count = 0 + for prompt in connection.prompts: + prompt_name = str(prompt.name) + if prompt_name in self._prompt_to_server: + existing_server = self._prompt_to_server[prompt_name] + console.print( + f"[warning]⚠️ MCP Prompt '{prompt_name}' (来自 {server_name}) 与 {existing_server} 冲突,已跳过[/warning]" + ) + continue + self._prompt_to_server[prompt_name] = server_name + registered_count += 1 + return registered_count + + def _register_resources(self, server_name: str, connection: MCPConnection) -> int: + """注册单个服务器暴露的 MCP Resource。 + + Args: + server_name: 服务器名称。 + connection: 对应连接对象。 + + Returns: + int: 成功注册的 Resource 数量。 + """ + + registered_count = 0 + for resource in connection.resources: + resource_uri = str(resource.uri) + if resource_uri in self._resource_to_server: + existing_server = self._resource_to_server[resource_uri] + console.print( + f"[warning]⚠️ MCP Resource '{resource_uri}' (来自 {server_name}) 与 {existing_server} 冲突,已跳过[/warning]" + ) + continue + self._resource_to_server[resource_uri] = server_name + registered_count += 1 + return registered_count + + def _register_resource_templates(self, server_name: str, connection: MCPConnection) -> int: + """注册单个服务器暴露的 MCP Resource Template。 + + Args: + server_name: 服务器名称。 + connection: 对应连接对象。 + + Returns: + int: 成功注册的模板数量。 + """ + + registered_count = 0 + for resource_template in connection.resource_templates: + uri_template = str(resource_template.uriTemplate) + if uri_template in self._resource_template_to_server: + existing_server = self._resource_template_to_server[uri_template] + console.print( + "[warning]⚠️ MCP Resource Template " + f"'{uri_template}' (来自 {server_name}) 与 {existing_server} 冲突,已跳过[/warning]" + ) + continue + self._resource_template_to_server[uri_template] = server_name + registered_count += 1 + return registered_count + + def _build_tool_parameters_schema(self, tool: Any) -> dict[str, Any] | None: + """构造单个 MCP 工具的参数 Schema。 + + Args: + tool: MCP SDK 返回的原始工具对象。 + + Returns: + dict[str, Any] | None: 参数 Schema。 + """ + + parameters_schema = ( + dict(tool.inputSchema) + if hasattr(tool, "inputSchema") and tool.inputSchema + else {"type": "object", "properties": {}} + ) + parameters_schema.pop("$schema", None) + return parameters_schema + + def _build_tool_output_schema(self, tool: Any) -> dict[str, Any] | None: + """构造单个 MCP 工具的输出 Schema。 + + Args: + tool: MCP SDK 返回的原始工具对象。 + + Returns: + dict[str, Any] | None: 输出 Schema。 + """ + + output_schema = dict(tool.outputSchema) if hasattr(tool, "outputSchema") and tool.outputSchema else None + if isinstance(output_schema, dict): + output_schema.pop("$schema", None) + return output_schema + + def get_tool_specs(self) -> list[ToolSpec]: + """获取全部已注册 MCP 工具的统一声明。 + + Returns: + list[ToolSpec]: MCP 工具声明列表。 + """ + + tool_specs: list[ToolSpec] = [] + for server_name, connection in self._connections.items(): + for tool in connection.tools: + if self._tool_to_server.get(tool.name) != server_name: + continue + + parameters_schema = self._build_tool_parameters_schema(tool) + output_schema = self._build_tool_output_schema(tool) + brief_description = str(tool.description or f"来自 {server_name} 的 MCP 工具").strip() + tool_specs.append( + ToolSpec( + name=str(tool.name), + title=str(getattr(tool, "title", "") or ""), + brief_description=brief_description, + detailed_description=build_tool_detailed_description( + parameters_schema, + fallback_description=f"工具来源:MCP 服务 {server_name}。", + ), + parameters_schema=parameters_schema, + output_schema=output_schema, + provider_name="mcp", + provider_type="mcp", + icons=[build_tool_icon(item) for item in getattr(tool, "icons", []) or []], + annotation=build_tool_annotation(getattr(tool, "annotations", None)), + metadata={"server_name": server_name} | getattr(tool, "meta", {}), + ) + ) + return tool_specs + + def get_prompt_specs(self) -> list[MCPPromptSpec]: + """获取全部已注册 MCP Prompt 声明。 + + Returns: + list[MCPPromptSpec]: Prompt 声明列表。 + """ + + prompt_specs: list[MCPPromptSpec] = [] + for server_name, connection in self._connections.items(): + for prompt in connection.prompts: + if self._prompt_to_server.get(prompt.name) != server_name: + continue + prompt_specs.append(build_prompt_spec(prompt, server_name)) + return prompt_specs + + def get_resource_specs(self) -> list[MCPResourceSpec]: + """获取全部已注册 MCP Resource 声明。 + + Returns: + list[MCPResourceSpec]: Resource 声明列表。 + """ + + resource_specs: list[MCPResourceSpec] = [] + for server_name, connection in self._connections.items(): + for resource in connection.resources: + if self._resource_to_server.get(resource.uri) != server_name: + continue + resource_specs.append(build_resource_spec(resource, server_name)) + return resource_specs + + def get_resource_template_specs(self) -> list[MCPResourceTemplateSpec]: + """获取全部已注册 MCP Resource Template 声明。 + + Returns: + list[MCPResourceTemplateSpec]: Resource Template 声明列表。 + """ + + resource_template_specs: list[MCPResourceTemplateSpec] = [] + for server_name, connection in self._connections.items(): + for resource_template in connection.resource_templates: + if self._resource_template_to_server.get(resource_template.uriTemplate) != server_name: + continue + resource_template_specs.append(build_resource_template_spec(resource_template, server_name)) + return resource_template_specs + + def get_openai_tools(self) -> list[dict[str, Any]]: + """获取兼容旧模型层的 MCP 工具定义。 + + Returns: + list[dict[str, Any]]: OpenAI function tool 格式列表。 + """ + + return [ + { + "type": "function", + "function": { + "name": tool_spec.name, + "description": tool_spec.build_llm_description(), + "parameters": tool_spec.parameters_schema or {"type": "object", "properties": {}}, + }, + } + for tool_spec in self.get_tool_specs() + ] + + def is_mcp_tool(self, tool_name: str) -> bool: + """判断给定名称是否为已注册 MCP 工具。 + + Args: + tool_name: 工具名称。 + + Returns: + bool: 是否存在。 + """ + + return tool_name in self._tool_to_server + + def is_mcp_prompt(self, prompt_name: str) -> bool: + """判断给定名称是否为已注册 MCP Prompt。 + + Args: + prompt_name: Prompt 名称。 + + Returns: + bool: 是否存在。 + """ + + return prompt_name in self._prompt_to_server + + def is_mcp_resource(self, uri: str) -> bool: + """判断给定 URI 是否为已注册 MCP Resource。 + + Args: + uri: 资源 URI。 + + Returns: + bool: 是否存在。 + """ + + return uri in self._resource_to_server + + async def call_tool_invocation(self, invocation: ToolInvocation) -> ToolExecutionResult: + """执行统一的 MCP 工具调用。 + + Args: + invocation: 统一工具调用请求。 + + Returns: + ToolExecutionResult: 统一工具执行结果。 + """ + + tool_name = invocation.tool_name + server_name = self._tool_to_server.get(tool_name) + if not server_name or server_name not in self._connections: + return ToolExecutionResult( + tool_name=tool_name, + success=False, + error_message=f"MCP 工具 '{tool_name}' 未找到", + ) + + connection = self._connections[server_name] + return await connection.call_tool(tool_name, invocation.arguments) + + async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> str: + """兼容旧接口,返回 MCP 工具的文本结果。 + + Args: + tool_name: 工具名称。 + arguments: 工具参数。 + + Returns: + str: 工具结果文本。 + """ + + result = await self.call_tool_invocation( + ToolInvocation( + tool_name=tool_name, + arguments=arguments, + ) + ) + return result.get_history_content() + + async def get_prompt( + self, + prompt_name: str, + arguments: Optional[dict[str, str]] = None, + ) -> MCPPromptResult: + """读取指定 Prompt 的内容。 + + Args: + prompt_name: Prompt 名称。 + arguments: Prompt 参数字典。 + + Returns: + MCPPromptResult: Prompt 获取结果。 + """ + + server_name = self._prompt_to_server.get(prompt_name) + if not server_name or server_name not in self._connections: + raise KeyError(f"MCP Prompt '{prompt_name}' 未找到") + + connection = self._connections[server_name] + return await connection.get_prompt(prompt_name, arguments=arguments) + + async def read_resource(self, uri: str) -> MCPResourceReadResult: + """读取指定 Resource 的内容。 + + Args: + uri: 资源 URI。 + + Returns: + MCPResourceReadResult: 资源读取结果。 + """ + + server_name = self._resource_to_server.get(uri) + if not server_name or server_name not in self._connections: + raise KeyError(f"MCP Resource '{uri}' 未找到") + + connection = self._connections[server_name] + return await connection.read_resource(uri) + + def get_tool_summary(self) -> str: + """获取所有已注册 MCP 工具的摘要信息。 + + Returns: + str: 工具摘要文本。 + """ + + parts: list[str] = [] + for server_name, connection in self._connections.items(): + tool_names = [ + str(tool.name) + for tool in connection.tools + if self._tool_to_server.get(tool.name) == server_name + ] + if tool_names: + parts.append(f" • {server_name}: {', '.join(tool_names)}") + return "\n".join(parts) + + def get_feature_summary(self) -> str: + """获取所有服务器能力的总体摘要。 + + Returns: + str: 多行摘要文本。 + """ + + parts: list[str] = [] + for server_name, connection in self._connections.items(): + tool_count = sum(1 for tool in connection.tools if self._tool_to_server.get(tool.name) == server_name) + prompt_count = sum( + 1 for prompt in connection.prompts if self._prompt_to_server.get(prompt.name) == server_name + ) + resource_count = sum( + 1 for resource in connection.resources if self._resource_to_server.get(resource.uri) == server_name + ) + template_count = sum( + 1 + for resource_template in connection.resource_templates + if self._resource_template_to_server.get(resource_template.uriTemplate) == server_name + ) + parts.append( + f" • {server_name}: 工具 {tool_count} / Prompt {prompt_count} / " + f"资源 {resource_count} / 模板 {template_count}" + ) + return "\n".join(parts) + + @property + def server_count(self) -> int: + """返回已连接 MCP 服务器数量。 + + Returns: + int: 服务器数量。 + """ + + return len(self._connections) + + @property + def tool_count(self) -> int: + """返回已注册 MCP 工具总数。 + + Returns: + int: 工具数量。 + """ + + return len(self._tool_to_server) + + @property + def prompt_count(self) -> int: + """返回已注册 MCP Prompt 总数。 + + Returns: + int: Prompt 数量。 + """ + + return len(self._prompt_to_server) + + @property + def resource_count(self) -> int: + """返回已注册 MCP Resource 总数。 + + Returns: + int: Resource 数量。 + """ + + return len(self._resource_to_server) + + async def close(self) -> None: + """关闭所有 MCP 服务器连接。""" + + for connection in self._connections.values(): + await connection.close() + self._connections.clear() + self._tool_to_server.clear() + self._prompt_to_server.clear() + self._resource_to_server.clear() + self._resource_template_to_server.clear() diff --git a/src/mcp_module/models.py b/src/mcp_module/models.py new file mode 100644 index 00000000..5550b8df --- /dev/null +++ b/src/mcp_module/models.py @@ -0,0 +1,418 @@ +"""MCP 结构化模型与转换工具。 + +负责在 MCP SDK 原始对象与主程序内部数据模型之间进行转换, +避免连接层和管理器层直接操作大量弱类型字段。 +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Optional + +from src.core.tooling import ToolAnnotation, ToolContentItem, ToolIcon + + +def _dump_model_metadata(raw_value: Any) -> dict[str, Any]: + """提取任意 MCP 模型对象中的元数据字典。 + + Args: + raw_value: MCP SDK 返回的原始对象。 + + Returns: + dict[str, Any]: 归一化后的元数据字典。 + """ + + metadata = getattr(raw_value, "meta", None) + if isinstance(metadata, dict): + return dict(metadata) + return {} + + +def build_tool_icon(raw_icon: Any) -> ToolIcon: + """将 MCP 图标对象转换为统一图标模型。 + + Args: + raw_icon: MCP SDK 返回的图标对象。 + + Returns: + ToolIcon: 统一图标模型。 + """ + + sizes_value = getattr(raw_icon, "sizes", None) + sizes = [str(item) for item in sizes_value] if isinstance(sizes_value, list) else [] + return ToolIcon( + src=str(getattr(raw_icon, "src", "") or ""), + mime_type=str(getattr(raw_icon, "mimeType", "") or ""), + sizes=sizes, + ) + + +def build_tool_annotation(raw_annotation: Any) -> Optional[ToolAnnotation]: + """将 MCP 注解对象转换为统一注解模型。 + + Args: + raw_annotation: MCP SDK 返回的注解对象。 + + Returns: + Optional[ToolAnnotation]: 统一注解模型;无有效内容时返回 ``None``。 + """ + + if raw_annotation is None: + return None + + audience_value = getattr(raw_annotation, "audience", None) + audience = [str(item) for item in audience_value] if isinstance(audience_value, list) else [] + priority_value = getattr(raw_annotation, "priority", None) + priority = float(priority_value) if isinstance(priority_value, int | float) else None + metadata = _dump_model_metadata(raw_annotation) + + if not audience and priority is None and not metadata: + return None + + return ToolAnnotation( + audience=audience, + priority=priority, + metadata=metadata, + ) + + +def build_tool_content_item(raw_content: Any) -> ToolContentItem: + """将 MCP 内容块转换为统一工具内容项。 + + Args: + raw_content: MCP SDK 返回的内容块对象。 + + Returns: + ToolContentItem: 统一工具内容项。 + """ + + content_type = str(getattr(raw_content, "type", "") or "").strip().lower() + annotation = build_tool_annotation(getattr(raw_content, "annotations", None)) + metadata = _dump_model_metadata(raw_content) + + if content_type == "text" or hasattr(raw_content, "text"): + return ToolContentItem( + content_type="text", + text=str(getattr(raw_content, "text", "") or ""), + annotation=annotation, + metadata=metadata, + ) + + if content_type == "image": + return ToolContentItem( + content_type="image", + data=str(getattr(raw_content, "data", "") or ""), + mime_type=str(getattr(raw_content, "mimeType", "") or ""), + annotation=annotation, + metadata=metadata, + ) + + if content_type == "audio": + return ToolContentItem( + content_type="audio", + data=str(getattr(raw_content, "data", "") or ""), + mime_type=str(getattr(raw_content, "mimeType", "") or ""), + annotation=annotation, + metadata=metadata, + ) + + if content_type == "resource_link": + return ToolContentItem( + content_type="resource_link", + uri=str(getattr(raw_content, "uri", "") or ""), + name=str(getattr(raw_content, "name", "") or ""), + description=str(getattr(raw_content, "description", "") or ""), + mime_type=str(getattr(raw_content, "mimeType", "") or ""), + annotation=annotation, + metadata=metadata, + ) + + if content_type == "resource" or hasattr(raw_content, "resource"): + resource = getattr(raw_content, "resource", None) + resource_metadata = metadata | _dump_model_metadata(resource) + return ToolContentItem( + content_type="resource", + text=str(getattr(resource, "text", "") or ""), + data=str(getattr(resource, "blob", "") or ""), + mime_type=str(getattr(resource, "mimeType", "") or ""), + uri=str(getattr(resource, "uri", "") or ""), + name=str(getattr(resource, "name", "") or ""), + annotation=annotation, + metadata=resource_metadata, + ) + + if hasattr(raw_content, "data"): + return ToolContentItem( + content_type="binary", + data=str(getattr(raw_content, "data", "") or ""), + mime_type=str(getattr(raw_content, "mimeType", "") or ""), + annotation=annotation, + metadata=metadata, + ) + + return ToolContentItem( + content_type="unknown", + text=str(raw_content), + annotation=annotation, + metadata=metadata, + ) + + +def build_tool_content_items(raw_contents: list[Any] | None) -> list[ToolContentItem]: + """批量转换 MCP 内容块列表。 + + Args: + raw_contents: MCP SDK 返回的内容块列表。 + + Returns: + list[ToolContentItem]: 转换后的统一内容项列表。 + """ + + if not raw_contents: + return [] + return [build_tool_content_item(item) for item in raw_contents] + + +@dataclass(slots=True) +class MCPPromptArgumentSpec: + """MCP Prompt 参数声明。""" + + name: str + description: str = "" + required: bool = False + + +@dataclass(slots=True) +class MCPPromptSpec: + """MCP Prompt 声明。""" + + name: str + server_name: str + title: str = "" + description: str = "" + arguments: list[MCPPromptArgumentSpec] = field(default_factory=list) + icons: list[ToolIcon] = field(default_factory=list) + metadata: dict[str, Any] = field(default_factory=dict) + + +@dataclass(slots=True) +class MCPPromptMessage: + """MCP Prompt 消息。""" + + role: str + content: ToolContentItem + + +@dataclass(slots=True) +class MCPPromptResult: + """MCP Prompt 获取结果。""" + + prompt_name: str + server_name: str + description: str = "" + messages: list[MCPPromptMessage] = field(default_factory=list) + metadata: dict[str, Any] = field(default_factory=dict) + + +@dataclass(slots=True) +class MCPResourceSpec: + """MCP Resource 声明。""" + + uri: str + server_name: str + name: str + title: str = "" + description: str = "" + mime_type: str = "" + size: int | None = None + icons: list[ToolIcon] = field(default_factory=list) + annotation: ToolAnnotation | None = None + metadata: dict[str, Any] = field(default_factory=dict) + + +@dataclass(slots=True) +class MCPResourceTemplateSpec: + """MCP Resource Template 声明。""" + + uri_template: str + server_name: str + name: str + title: str = "" + description: str = "" + mime_type: str = "" + icons: list[ToolIcon] = field(default_factory=list) + annotation: ToolAnnotation | None = None + metadata: dict[str, Any] = field(default_factory=dict) + + +@dataclass(slots=True) +class MCPResourceReadResult: + """MCP Resource 读取结果。""" + + uri: str + server_name: str + contents: list[ToolContentItem] = field(default_factory=list) + metadata: dict[str, Any] = field(default_factory=dict) + + +def build_prompt_argument_spec(raw_argument: Any) -> MCPPromptArgumentSpec: + """将 MCP Prompt 参数对象转换为统一结构。 + + Args: + raw_argument: MCP SDK 返回的 Prompt 参数对象。 + + Returns: + MCPPromptArgumentSpec: 统一 Prompt 参数结构。 + """ + + return MCPPromptArgumentSpec( + name=str(getattr(raw_argument, "name", "") or ""), + description=str(getattr(raw_argument, "description", "") or ""), + required=bool(getattr(raw_argument, "required", False)), + ) + + +def build_prompt_spec(raw_prompt: Any, server_name: str) -> MCPPromptSpec: + """将 MCP Prompt 定义转换为统一结构。 + + Args: + raw_prompt: MCP SDK 返回的 Prompt 对象。 + server_name: Prompt 所属的服务器名称。 + + Returns: + MCPPromptSpec: 统一 Prompt 定义。 + """ + + raw_arguments = getattr(raw_prompt, "arguments", None) + raw_icons = getattr(raw_prompt, "icons", None) + return MCPPromptSpec( + name=str(getattr(raw_prompt, "name", "") or ""), + server_name=server_name, + title=str(getattr(raw_prompt, "title", "") or ""), + description=str(getattr(raw_prompt, "description", "") or ""), + arguments=[build_prompt_argument_spec(item) for item in raw_arguments] if isinstance(raw_arguments, list) else [], + icons=[build_tool_icon(item) for item in raw_icons] if isinstance(raw_icons, list) else [], + metadata=_dump_model_metadata(raw_prompt), + ) + + +def build_prompt_result(raw_result: Any, prompt_name: str, server_name: str) -> MCPPromptResult: + """将 MCP Prompt 获取结果转换为统一结构。 + + Args: + raw_result: MCP SDK 返回的 Prompt 结果对象。 + prompt_name: Prompt 名称。 + server_name: Prompt 所属服务器名称。 + + Returns: + MCPPromptResult: 统一 Prompt 获取结果。 + """ + + messages: list[MCPPromptMessage] = [] + raw_messages = getattr(raw_result, "messages", None) + if isinstance(raw_messages, list): + for raw_message in raw_messages: + messages.append( + MCPPromptMessage( + role=str(getattr(raw_message, "role", "") or ""), + content=build_tool_content_item(getattr(raw_message, "content", None)), + ) + ) + + return MCPPromptResult( + prompt_name=prompt_name, + server_name=server_name, + description=str(getattr(raw_result, "description", "") or ""), + messages=messages, + metadata=_dump_model_metadata(raw_result), + ) + + +def build_resource_spec(raw_resource: Any, server_name: str) -> MCPResourceSpec: + """将 MCP Resource 定义转换为统一结构。 + + Args: + raw_resource: MCP SDK 返回的 Resource 对象。 + server_name: Resource 所属服务器名称。 + + Returns: + MCPResourceSpec: 统一 Resource 定义。 + """ + + raw_icons = getattr(raw_resource, "icons", None) + size_value = getattr(raw_resource, "size", None) + size = int(size_value) if isinstance(size_value, int | float) else None + return MCPResourceSpec( + uri=str(getattr(raw_resource, "uri", "") or ""), + server_name=server_name, + name=str(getattr(raw_resource, "name", "") or ""), + title=str(getattr(raw_resource, "title", "") or ""), + description=str(getattr(raw_resource, "description", "") or ""), + mime_type=str(getattr(raw_resource, "mimeType", "") or ""), + size=size, + icons=[build_tool_icon(item) for item in raw_icons] if isinstance(raw_icons, list) else [], + annotation=build_tool_annotation(getattr(raw_resource, "annotations", None)), + metadata=_dump_model_metadata(raw_resource), + ) + + +def build_resource_template_spec(raw_template: Any, server_name: str) -> MCPResourceTemplateSpec: + """将 MCP Resource Template 定义转换为统一结构。 + + Args: + raw_template: MCP SDK 返回的 ResourceTemplate 对象。 + server_name: 模板所属服务器名称。 + + Returns: + MCPResourceTemplateSpec: 统一模板定义。 + """ + + raw_icons = getattr(raw_template, "icons", None) + return MCPResourceTemplateSpec( + uri_template=str(getattr(raw_template, "uriTemplate", "") or ""), + server_name=server_name, + name=str(getattr(raw_template, "name", "") or ""), + title=str(getattr(raw_template, "title", "") or ""), + description=str(getattr(raw_template, "description", "") or ""), + mime_type=str(getattr(raw_template, "mimeType", "") or ""), + icons=[build_tool_icon(item) for item in raw_icons] if isinstance(raw_icons, list) else [], + annotation=build_tool_annotation(getattr(raw_template, "annotations", None)), + metadata=_dump_model_metadata(raw_template), + ) + + +def build_resource_read_result(raw_result: Any, uri: str, server_name: str) -> MCPResourceReadResult: + """将 MCP Resource 读取结果转换为统一结构。 + + Args: + raw_result: MCP SDK 返回的读取结果对象。 + uri: 被读取的资源 URI。 + server_name: 资源所属服务器名称。 + + Returns: + MCPResourceReadResult: 统一资源读取结果。 + """ + + contents: list[ToolContentItem] = [] + raw_contents = getattr(raw_result, "contents", None) + if isinstance(raw_contents, list): + for raw_content in raw_contents: + metadata = _dump_model_metadata(raw_content) + contents.append( + ToolContentItem( + content_type="resource", + text=str(getattr(raw_content, "text", "") or ""), + data=str(getattr(raw_content, "blob", "") or ""), + mime_type=str(getattr(raw_content, "mimeType", "") or ""), + uri=str(getattr(raw_content, "uri", "") or uri), + annotation=None, + metadata=metadata, + ) + ) + + return MCPResourceReadResult( + uri=uri, + server_name=server_name, + contents=contents, + metadata=_dump_model_metadata(raw_result), + ) diff --git a/src/mcp_module/provider.py b/src/mcp_module/provider.py new file mode 100644 index 00000000..84065eb8 --- /dev/null +++ b/src/mcp_module/provider.py @@ -0,0 +1,54 @@ +"""MCP 工具 Provider。""" + +from __future__ import annotations + +from typing import Optional + +from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolProvider, ToolSpec + +from .manager import MCPManager + + +class MCPToolProvider(ToolProvider): + """基于 MCPManager 的工具 Provider。""" + + provider_name = "mcp" + provider_type = "mcp" + + def __init__(self, manager: MCPManager) -> None: + """初始化 MCP 工具 Provider。 + + Args: + manager: MCP 管理器实例。 + """ + + self._manager = manager + + async def list_tools(self) -> list[ToolSpec]: + """列出全部 MCP 工具。""" + + return self._manager.get_tool_specs() + + async def invoke( + self, + invocation: ToolInvocation, + context: Optional[ToolExecutionContext] = None, + ) -> ToolExecutionResult: + """执行指定 MCP 工具。 + + Args: + invocation: 工具调用请求。 + context: 执行上下文。 + + Returns: + ToolExecutionResult: 工具执行结果。 + """ + + del context + return await self._manager.call_tool_invocation(invocation) + + async def close(self) -> None: + """关闭 Provider 并释放 MCP 连接。""" + + await self._manager.close() + diff --git a/src/memory_system/chat_history_summarizer.py b/src/memory_system/chat_history_summarizer.py index cedf971f..94f4390f 100644 --- a/src/memory_system/chat_history_summarizer.py +++ b/src/memory_system/chat_history_summarizer.py @@ -16,8 +16,9 @@ from json_repair import repair_json from src.chat.message_receive.message import SessionMessage from src.common.logger import get_logger -from src.config.config import model_config, global_config -from src.llm_models.utils_model import LLMRequest +from src.config.config import global_config +from src.common.data_models.llm_service_data_models import LLMGenerationOptions +from src.services.llm_service import LLMServiceClient from src.services import message_service as message_api from src.chat.utils.utils import is_bot_self from src.person_info.person_info import Person @@ -88,8 +89,8 @@ class ChatHistorySummarizer: # 注意:批次加载需要异步查询消息,所以在 start() 中调用 # LLM请求器,用于压缩聊天内容 - self.summarizer_llm = LLMRequest( - model_set=model_config.model_task_config.utils, request_type="chat_history_summarizer" + self.summarizer_llm = LLMServiceClient( + task_name="utils", request_type="chat_history_summarizer" ) # 后台循环相关 @@ -656,10 +657,11 @@ class ChatHistorySummarizer: prompt = await prompt_manager.render_prompt(prompt_template) try: - response, _ = await self.summarizer_llm.generate_response_async( + generation_result = await self.summarizer_llm.generate_response( prompt=prompt, - temperature=0.3, + options=LLMGenerationOptions(temperature=0.3), ) + response = generation_result.response logger.info(f"{self.log_prefix} 话题识别LLM Prompt: {prompt}") logger.info(f"{self.log_prefix} 话题识别LLM Response: {response}") @@ -812,7 +814,8 @@ class ChatHistorySummarizer: prompt = await prompt_manager.render_prompt(prompt_template) try: - response, _ = await self.summarizer_llm.generate_response_async(prompt=prompt) + generation_result = await self.summarizer_llm.generate_response(prompt=prompt) + response = generation_result.response # 解析JSON响应 json_str = response.strip() diff --git a/src/memory_system/memory_retrieval.py b/src/memory_system/memory_retrieval.py index 2eadd05a..2554ebd1 100644 --- a/src/memory_system/memory_retrieval.py +++ b/src/memory_system/memory_retrieval.py @@ -5,7 +5,7 @@ import asyncio from datetime import datetime from typing import List, Dict, Any, Optional, Tuple, Callable from src.common.logger import get_logger -from src.config.config import global_config, model_config +from src.config.config import global_config from src.prompt.prompt_manager import prompt_manager from src.services import llm_service as llm_api from sqlmodel import select, col @@ -269,18 +269,18 @@ async def _react_agent_solve_question( return messages message_factory_fn: Callable[..., List[Message]] = _build_messages # pyright: ignore[reportGeneralTypeIssues] - ( - success, - response, - reasoning_content, - model_name, - tool_calls, - ) = await llm_api.generate_with_model_with_tools_by_message_factory( - message_factory_fn, # type: ignore[arg-type] - model_config=model_config.model_task_config.tool_use, - tool_options=tool_definitions, - request_type="memory.react", + generation_result = await llm_api.generate( + llm_api.LLMServiceRequest( + task_name="utils", + request_type="memory.react", + message_factory=message_factory_fn, # type: ignore[arg-type] + tool_options=tool_definitions, + ) ) + success = generation_result.success + response = generation_result.completion.response + reasoning_content = generation_result.completion.reasoning + tool_calls = generation_result.completion.tool_calls # logger.info( # f"ReAct Agent 第 {iteration + 1} 次迭代 模型: {model_name} ,调用工具数量: {len(tool_calls) if tool_calls else 0} ,调用工具响应: {response}" @@ -679,18 +679,16 @@ async def _react_agent_solve_question( evaluation_prompt_template.add_context("max_iterations", str(max_iterations)) evaluation_prompt = await prompt_manager.render_prompt(evaluation_prompt_template) - ( - eval_success, - eval_response, - eval_reasoning_content, - eval_model_name, - eval_tool_calls, - ) = await llm_api.generate_with_model_with_tools( - evaluation_prompt, - model_config=model_config.model_task_config.tool_use, - tool_options=[], # 最终评估阶段不提供工具 - request_type="memory.react.final", + evaluation_result = await llm_api.generate( + llm_api.LLMServiceRequest( + task_name="utils", + request_type="memory.react.final", + prompt=evaluation_prompt, + tool_options=[], + ) ) + eval_success = evaluation_result.success + eval_response = evaluation_result.completion.response if not eval_success: logger.error(f"ReAct Agent 最终评估阶段 LLM调用失败: {eval_response}") diff --git a/src/memory_system/retrieval_tools/tool_registry.py b/src/memory_system/retrieval_tools/tool_registry.py index 1e1fa62b..f2dd1f0d 100644 --- a/src/memory_system/retrieval_tools/tool_registry.py +++ b/src/memory_system/retrieval_tools/tool_registry.py @@ -1,11 +1,12 @@ -""" -工具注册系统 -提供统一的工具注册和管理接口 +"""工具注册系统。 + +提供统一的工具注册和管理接口。 """ -from typing import List, Dict, Any, Optional, Callable, Awaitable +from typing import Any, Awaitable, Callable, Dict, List, Optional + from src.common.logger import get_logger -from src.llm_models.payload_content.tool_option import ToolParamType +from src.llm_models.payload_content.tool_option import ToolParamType, normalize_tool_option logger = get_logger("memory_retrieval_tools") @@ -14,16 +15,19 @@ class MemoryRetrievalTool: """记忆检索工具基类""" def __init__( - self, name: str, description: str, parameters: List[Dict[str, Any]], execute_func: Callable[..., Awaitable[str]] - ): - """ - 初始化工具 + self, + name: str, + description: str, + parameters: List[Dict[str, Any]], + execute_func: Callable[..., Awaitable[str]], + ) -> None: + """初始化工具。 Args: - name: 工具名称 - description: 工具描述 - parameters: 参数定义列表,格式:[{"name": "param_name", "type": "string", "description": "参数描述", "required": True}] - execute_func: 执行函数,必须是异步函数 + name: 工具名称。 + description: 工具描述。 + parameters: 参数定义列表。 + execute_func: 执行函数,必须是异步函数。 """ self.name = name self.description = description @@ -44,20 +48,17 @@ class MemoryRetrievalTool: params_str = "\n".join(param_descriptions) if param_descriptions else " 无参数" return f"{self.name}({', '.join([p['name'] for p in self.parameters])}): {self.description}\n{params_str}" - async def execute(self, **kwargs) -> str: - """执行工具""" + async def execute(self, **kwargs: Any) -> str: + """执行工具。""" return await self.execute_func(**kwargs) def get_tool_definition(self) -> Dict[str, Any]: - """获取工具定义,用于LLM function calling + """获取规范化的工具定义。 Returns: - Dict[str, Any]: 工具定义字典,格式与BaseTool一致 - 格式: {"name": str, "description": str, "parameters": List[Tuple]} + Dict[str, Any]: 统一工具定义字典。 """ - # 转换参数格式为元组列表,格式与BaseTool一致 - # 格式: [("param_name", ToolParamType, "description", required, enum_values)] - param_tuples = [] + legacy_parameters: list[tuple[str, ToolParamType, str, bool, list[str] | None]] = [] for param in self.parameters: param_name = param.get("name", "") @@ -77,20 +78,27 @@ class MemoryRetrievalTool: } param_type = type_mapping.get(param_type_str, ToolParamType.STRING) - # 构建参数元组 - param_tuple = (param_name, param_type, param_desc, is_required, enum_values) - param_tuples.append(param_tuple) + legacy_parameters.append((param_name, param_type, param_desc, is_required, enum_values)) - # 构建工具定义,格式与BaseTool.get_tool_definition()一致 - tool_def = {"name": self.name, "description": self.description, "parameters": param_tuples} - - return tool_def + normalized_option = normalize_tool_option( + { + "name": self.name, + "description": self.description, + "parameters": legacy_parameters, + } + ) + return { + "name": normalized_option.name, + "description": normalized_option.description, + "parameters_schema": normalized_option.parameters_schema, + } class MemoryRetrievalToolRegistry: """工具注册器""" - def __init__(self): + def __init__(self) -> None: + """初始化工具注册器。""" self.tools: Dict[str, MemoryRetrievalTool] = {} def register_tool(self, tool: MemoryRetrievalTool) -> None: @@ -137,15 +145,18 @@ _tool_registry = MemoryRetrievalToolRegistry() def register_memory_retrieval_tool( - name: str, description: str, parameters: List[Dict[str, Any]], execute_func: Callable[..., Awaitable[str]] + name: str, + description: str, + parameters: List[Dict[str, Any]], + execute_func: Callable[..., Awaitable[str]], ) -> None: - """注册记忆检索工具的便捷函数 + """注册记忆检索工具的便捷函数。 Args: - name: 工具名称 - description: 工具描述 - parameters: 参数定义列表 - execute_func: 执行函数 + name: 工具名称。 + description: 工具描述。 + parameters: 参数定义列表。 + execute_func: 执行函数。 """ tool = MemoryRetrievalTool(name, description, parameters, execute_func) _tool_registry.register_tool(tool) diff --git a/src/person_info/person_info.py b/src/person_info/person_info.py index 15ef0049..c603f4b7 100644 --- a/src/person_info/person_info.py +++ b/src/person_info/person_info.py @@ -17,14 +17,14 @@ from src.common.data_models.person_info_data_model import dump_group_cardname_re from src.common.database.database import get_db_session from src.common.database.database_model import PersonInfo from src.common.logger import get_logger -from src.config.config import global_config, model_config -from src.llm_models.utils_model import LLMRequest +from src.config.config import global_config +from src.services.llm_service import LLMServiceClient logger = get_logger("person_info") -relation_selection_model = LLMRequest( - model_set=model_config.model_task_config.tool_use, request_type="relation_selection" +relation_selection_model = LLMServiceClient( + task_name="utils", request_type="relation_selection" ) @@ -578,7 +578,8 @@ class Person: <分类1><分类2><分类3>...... 如果没有相关的分类,请输出""" - response, _ = await relation_selection_model.generate_response_async(prompt) + generation_result = await relation_selection_model.generate_response(prompt) + response = generation_result.response # print(prompt) # print(response) category_list = extract_categories_from_response(response) @@ -600,7 +601,8 @@ class Person: 例如: <分类1><分类2><分类3>...... 如果没有相关的分类,请输出""" - response, _ = await relation_selection_model.generate_response_async(prompt) + generation_result = await relation_selection_model.generate_response(prompt) + response = generation_result.response # print(prompt) # print(response) category_list = extract_categories_from_response(response) @@ -634,7 +636,9 @@ class Person: class PersonInfoManager: def __init__(self): self.person_name_list = {} - self.qv_name_llm = LLMRequest(model_set=model_config.model_task_config.utils, request_type="relation.qv_name") + self.qv_name_llm = LLMServiceClient( + task_name="utils", request_type="relation.qv_name" + ) try: with get_db_session() as _: pass @@ -737,7 +741,8 @@ class PersonInfoManager: "nickname": "昵称", "reason": "理由" }""" - response, _ = await self.qv_name_llm.generate_response_async(qv_name_prompt) + generation_result = await self.qv_name_llm.generate_response(qv_name_prompt) + response = generation_result.response # logger.info(f"取名提示词:{qv_name_prompt}\n取名回复:{response}") result = self._extract_json_from_text(response) diff --git a/src/plugin_runtime/capabilities/components.py b/src/plugin_runtime/capabilities/components.py index 33b54c64..1e1827cf 100644 --- a/src/plugin_runtime/capabilities/components.py +++ b/src/plugin_runtime/capabilities/components.py @@ -29,6 +29,19 @@ class _RuntimeComponentManagerProtocol(Protocol): def _build_api_unavailable_error(self, entry: "APIEntry") -> str: ... + def _collect_api_reference_matches( + self, + caller_plugin_id: str, + normalized_api_name: str, + normalized_version: str, + ) -> tuple[List[tuple["PluginSupervisor", "APIEntry"]], List[tuple["PluginSupervisor", "APIEntry"]], bool]: ... + + def _collect_api_toggle_reference_matches( + self, + normalized_name: str, + normalized_version: str, + ) -> List[tuple["PluginSupervisor", "APIEntry"]]: ... + def _get_supervisor_for_plugin(self, plugin_id: str) -> Optional["PluginSupervisor"]: ... def _resolve_api_target( @@ -58,6 +71,73 @@ class _RuntimeComponentManagerProtocol(Protocol): class RuntimeComponentCapabilityMixin: + def _collect_api_reference_matches( + self: _RuntimeComponentManagerProtocol, + caller_plugin_id: str, + normalized_api_name: str, + normalized_version: str, + ) -> tuple[List[tuple["PluginSupervisor", "APIEntry"]], List[tuple["PluginSupervisor", "APIEntry"]], bool]: + """按 API 完整名或短名精确收集匹配项。 + + 该辅助方法用于兼容名字中本身包含 ``.`` 的 API。对于这类 API, + 不能简单按最后一个点号拆成 ``plugin_id.api_name``。 + + Args: + caller_plugin_id: 调用方插件 ID。 + normalized_api_name: 已规范化的 API 名称。 + normalized_version: 已规范化的版本号。 + + Returns: + tuple[List[tuple[PluginSupervisor, APIEntry]], List[tuple[PluginSupervisor, APIEntry]], bool]: + 依次为可见且启用的匹配项、可见但已禁用的匹配项、是否存在不可见匹配项。 + """ + + visible_enabled_matches: List[tuple["PluginSupervisor", "APIEntry"]] = [] + visible_disabled_matches: List[tuple["PluginSupervisor", "APIEntry"]] = [] + hidden_match_exists = False + + for supervisor in self.supervisors: + for entry in supervisor.api_registry.get_apis( + version=normalized_version, + enabled_only=False, + ): + if entry.name != normalized_api_name and entry.full_name != normalized_api_name: + continue + if self._is_api_visible_to_plugin(entry, caller_plugin_id): + if entry.enabled: + visible_enabled_matches.append((supervisor, entry)) + else: + visible_disabled_matches.append((supervisor, entry)) + else: + hidden_match_exists = True + + return visible_enabled_matches, visible_disabled_matches, hidden_match_exists + + def _collect_api_toggle_reference_matches( + self: _RuntimeComponentManagerProtocol, + normalized_name: str, + normalized_version: str, + ) -> List[tuple["PluginSupervisor", "APIEntry"]]: + """按 API 完整名或短名精确收集启停操作匹配项。 + + Args: + normalized_name: 已规范化的 API 名称。 + normalized_version: 已规范化的版本号。 + + Returns: + List[tuple[PluginSupervisor, APIEntry]]: 匹配到的 API 条目列表。 + """ + + matches: List[tuple["PluginSupervisor", "APIEntry"]] = [] + for supervisor in self.supervisors: + for entry in supervisor.api_registry.get_apis( + version=normalized_version, + enabled_only=False, + ): + if entry.name == normalized_name or entry.full_name == normalized_name: + matches.append((supervisor, entry)) + return matches + @staticmethod def _normalize_component_type(component_type: str) -> str: """规范化组件类型名称。 @@ -69,7 +149,10 @@ class RuntimeComponentCapabilityMixin: str: 统一转为大写后的组件类型名。 """ - return str(component_type or "").strip().upper() + normalized_component_type = str(component_type or "").strip().upper() + if normalized_component_type == "ACTION": + return "TOOL" + return normalized_component_type @classmethod def _is_api_component_type(cls, component_type: str) -> bool: @@ -190,6 +273,20 @@ class RuntimeComponentCapabilityMixin: if not normalized_api_name: return None, None, "缺少必要参数 api_name" + exact_visible_enabled_matches, exact_visible_disabled_matches, exact_hidden_match_exists = ( + self._collect_api_reference_matches(caller_plugin_id, normalized_api_name, normalized_version) + ) + if len(exact_visible_enabled_matches) == 1: + return exact_visible_enabled_matches[0][0], exact_visible_enabled_matches[0][1], None + if len(exact_visible_enabled_matches) > 1: + return None, None, f"API 名称不唯一: {normalized_api_name},请显式指定 version" + if exact_visible_disabled_matches: + if len(exact_visible_disabled_matches) == 1: + return None, None, self._build_api_unavailable_error(exact_visible_disabled_matches[0][1]) + return None, None, f"API {normalized_api_name} 存在多个已下线版本,请显式指定 version" + if exact_hidden_match_exists: + return None, None, f"API {normalized_api_name} 未公开,禁止跨插件调用" + if "." in normalized_api_name: target_plugin_id, target_api_name = normalized_api_name.rsplit(".", 1) try: @@ -207,9 +304,7 @@ class RuntimeComponentCapabilityMixin: enabled_only=False, ) visible_enabled_entries = [ - entry - for entry in entries - if self._is_api_visible_to_plugin(entry, caller_plugin_id) and entry.enabled + entry for entry in entries if self._is_api_visible_to_plugin(entry, caller_plugin_id) and entry.enabled ] visible_disabled_entries = [ entry @@ -281,6 +376,12 @@ class RuntimeComponentCapabilityMixin: if not normalized_name: return None, None, "缺少必要参数 name" + exact_matches = self._collect_api_toggle_reference_matches(normalized_name, normalized_version) + if len(exact_matches) == 1: + return exact_matches[0][0], exact_matches[0][1], None + if len(exact_matches) > 1: + return None, None, f"API 名称不唯一: {normalized_name},请显式指定 version" + if "." in normalized_name: plugin_id, api_name = normalized_name.rsplit(".", 1) try: diff --git a/src/plugin_runtime/capabilities/core.py b/src/plugin_runtime/capabilities/core.py index 9bb1755b..843b8ce0 100644 --- a/src/plugin_runtime/capabilities/core.py +++ b/src/plugin_runtime/capabilities/core.py @@ -1,33 +1,80 @@ -from typing import Any, Dict +from typing import Any, Dict, List from src.common.logger import get_logger from src.config.config import global_config -from src.llm_models.payload_content.tool_option import ToolCall logger = get_logger("plugin_runtime.integration") def _get_nested_config_value(source: Any, key: str, default: Any = None) -> Any: + """从嵌套对象或字典中读取配置值。 + + Args: + source: 配置对象或字典。 + key: 以点号分隔的路径。 + default: 未命中时返回的默认值。 + + Returns: + Any: 命中的值;读取失败时返回默认值。 + """ current = source try: for part in key.split("."): if isinstance(current, dict) and part in current: current = current[part] - elif hasattr(current, part): + continue + if hasattr(current, part): current = getattr(current, part) - else: - raise KeyError(part) + continue + raise KeyError(part) return current except Exception: return default +def _normalize_prompt_arg(prompt: Any) -> str | List[Dict[str, Any]]: + """校验并规范化插件传入的提示参数。 + + Args: + prompt: 原始提示参数。 + + Returns: + str | List[Dict[str, Any]]: 规范化后的提示输入。 + + Raises: + ValueError: 提示参数缺失或结构不受支持时抛出。 + """ + if isinstance(prompt, str): + if not prompt.strip(): + raise ValueError("缺少必要参数 prompt") + return prompt + if isinstance(prompt, list) and prompt: + for index, prompt_message in enumerate(prompt, start=1): + if not isinstance(prompt_message, dict): + raise ValueError(f"prompt 第 {index} 项必须为字典") + return prompt + raise ValueError("缺少必要参数 prompt") + + class RuntimeCoreCapabilityMixin: + """插件运行时的核心能力混入。""" + async def _cap_send_text(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: + """向指定流发送文本消息。 + + Args: + plugin_id: 插件标识。 + capability: 能力名称。 + args: 能力调用参数。 + + Returns: + Any: 能力执行结果。 + """ + del plugin_id, capability from src.services import send_service as send_api - text: str = args.get("text", "") - stream_id: str = args.get("stream_id", "") + text = str(args.get("text", "")) + stream_id = str(args.get("stream_id", "")) if not text or not stream_id: return {"success": False, "error": "缺少必要参数 text 或 stream_id"} @@ -35,20 +82,31 @@ class RuntimeCoreCapabilityMixin: result = await send_api.text_to_stream( text=text, stream_id=stream_id, - typing=args.get("typing", False), - set_reply=args.get("set_reply", False), - storage_message=args.get("storage_message", True), + typing=bool(args.get("typing", False)), + set_reply=bool(args.get("set_reply", False)), + storage_message=bool(args.get("storage_message", True)), ) return {"success": result} - except Exception as e: - logger.error(f"[cap.send.text] 执行失败: {e}", exc_info=True) - return {"success": False, "error": str(e)} + except Exception as exc: + logger.error(f"[cap.send.text] 执行失败: {exc}", exc_info=True) + return {"success": False, "error": str(exc)} async def _cap_send_emoji(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: + """向指定流发送表情图片。 + + Args: + plugin_id: 插件标识。 + capability: 能力名称。 + args: 能力调用参数。 + + Returns: + Any: 能力执行结果。 + """ + del plugin_id, capability from src.services import send_service as send_api - emoji_base64: str = args.get("emoji_base64", "") - stream_id: str = args.get("stream_id", "") + emoji_base64 = str(args.get("emoji_base64", "")) + stream_id = str(args.get("stream_id", "")) if not emoji_base64 or not stream_id: return {"success": False, "error": "缺少必要参数 emoji_base64 或 stream_id"} @@ -56,18 +114,29 @@ class RuntimeCoreCapabilityMixin: result = await send_api.emoji_to_stream( emoji_base64=emoji_base64, stream_id=stream_id, - storage_message=args.get("storage_message", True), + storage_message=bool(args.get("storage_message", True)), ) return {"success": result} - except Exception as e: - logger.error(f"[cap.send.emoji] 执行失败: {e}", exc_info=True) - return {"success": False, "error": str(e)} + except Exception as exc: + logger.error(f"[cap.send.emoji] 执行失败: {exc}", exc_info=True) + return {"success": False, "error": str(exc)} async def _cap_send_image(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: + """向指定流发送图片。 + + Args: + plugin_id: 插件标识。 + capability: 能力名称。 + args: 能力调用参数。 + + Returns: + Any: 能力执行结果。 + """ + del plugin_id, capability from src.services import send_service as send_api - image_base64: str = args.get("image_base64", "") - stream_id: str = args.get("stream_id", "") + image_base64 = str(args.get("image_base64", "")) + stream_id = str(args.get("stream_id", "")) if not image_base64 or not stream_id: return {"success": False, "error": "缺少必要参数 image_base64 或 stream_id"} @@ -75,18 +144,29 @@ class RuntimeCoreCapabilityMixin: result = await send_api.image_to_stream( image_base64=image_base64, stream_id=stream_id, - storage_message=args.get("storage_message", True), + storage_message=bool(args.get("storage_message", True)), ) return {"success": result} - except Exception as e: - logger.error(f"[cap.send.image] 执行失败: {e}", exc_info=True) - return {"success": False, "error": str(e)} + except Exception as exc: + logger.error(f"[cap.send.image] 执行失败: {exc}", exc_info=True) + return {"success": False, "error": str(exc)} async def _cap_send_command(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: + """向指定流发送命令消息。 + + Args: + plugin_id: 插件标识。 + capability: 能力名称。 + args: 能力调用参数。 + + Returns: + Any: 能力执行结果。 + """ + del plugin_id, capability from src.services import send_service as send_api - command = args.get("command", "") - stream_id: str = args.get("stream_id", "") + command = str(args.get("command", "")) + stream_id = str(args.get("stream_id", "")) if not command or not stream_id: return {"success": False, "error": "缺少必要参数 command 或 stream_id"} @@ -95,22 +175,33 @@ class RuntimeCoreCapabilityMixin: message_type="command", content=command, stream_id=stream_id, - storage_message=args.get("storage_message", True), - display_message=args.get("display_message", ""), + storage_message=bool(args.get("storage_message", True)), + display_message=str(args.get("display_message", "")), ) return {"success": result} - except Exception as e: - logger.error(f"[cap.send.command] 执行失败: {e}", exc_info=True) - return {"success": False, "error": str(e)} + except Exception as exc: + logger.error(f"[cap.send.command] 执行失败: {exc}", exc_info=True) + return {"success": False, "error": str(exc)} async def _cap_send_custom(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: + """向指定流发送自定义消息。 + + Args: + plugin_id: 插件标识。 + capability: 能力名称。 + args: 能力调用参数。 + + Returns: + Any: 能力执行结果。 + """ + del plugin_id, capability from src.services import send_service as send_api - message_type: str = args.get("message_type", "") or args.get("custom_type", "") + message_type = str(args.get("message_type", "") or args.get("custom_type", "")) content = args.get("content") if content is None: content = args.get("data", "") - stream_id: str = args.get("stream_id", "") + stream_id = str(args.get("stream_id", "")) if not message_type or not stream_id: return {"success": False, "error": "缺少必要参数 message_type 或 stream_id"} @@ -119,114 +210,116 @@ class RuntimeCoreCapabilityMixin: message_type=message_type, content=content, stream_id=stream_id, - display_message=args.get("display_message", ""), - typing=args.get("typing", False), - storage_message=args.get("storage_message", True), + display_message=str(args.get("display_message", "")), + typing=bool(args.get("typing", False)), + storage_message=bool(args.get("storage_message", True)), ) return {"success": result} - except Exception as e: - logger.error(f"[cap.send.custom] 执行失败: {e}", exc_info=True) - return {"success": False, "error": str(e)} + except Exception as exc: + logger.error(f"[cap.send.custom] 执行失败: {exc}", exc_info=True) + return {"success": False, "error": str(exc)} async def _cap_llm_generate(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: + """执行无工具的 LLM 生成能力。 + + Args: + plugin_id: 插件标识。 + capability: 能力名称。 + args: 能力调用参数。 + + Returns: + Any: 标准化后的 LLM 响应结构。 + """ + del capability from src.services import llm_service as llm_api - prompt: str = args.get("prompt", "") - if not prompt: - return {"success": False, "error": "缺少必要参数 prompt"} - - model_name: str = args.get("model", "") or args.get("model_name", "") - temperature = args.get("temperature") - max_tokens = args.get("max_tokens") - try: - models = llm_api.get_available_models() - if model_name and model_name in models: - model_config = models[model_name] - else: - if not models: - return {"success": False, "error": "没有可用的模型配置"} - model_config = next(iter(models.values())) - - success, response, reasoning, used_model = await llm_api.generate_with_model( - prompt=prompt, - model_config=model_config, - request_type=f"plugin.{plugin_id}", - temperature=temperature, - max_tokens=max_tokens, + prompt = _normalize_prompt_arg(args.get("prompt")) + task_name = llm_api.resolve_task_name(str(args.get("model", "") or args.get("model_name", ""))) + result = await llm_api.generate( + llm_api.LLMServiceRequest( + task_name=task_name, + request_type=f"plugin.{plugin_id}", + prompt=prompt, + temperature=args.get("temperature"), + max_tokens=args.get("max_tokens"), + ) ) - return { - "success": success, - "response": response, - "reasoning": reasoning, - "model_name": used_model, - } - except Exception as e: - logger.error(f"[cap.llm.generate] 执行失败: {e}", exc_info=True) - return {"success": False, "error": str(e)} + return result.to_capability_payload() + except Exception as exc: + logger.error(f"[cap.llm.generate] 执行失败: {exc}", exc_info=True) + return {"success": False, "error": str(exc)} async def _cap_llm_generate_with_tools(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: + """执行带工具的 LLM 生成能力。 + + Args: + plugin_id: 插件标识。 + capability: 能力名称。 + args: 能力调用参数。 + + Returns: + Any: 标准化后的 LLM 响应结构。 + """ + del capability from src.services import llm_service as llm_api - prompt: str = args.get("prompt", "") - if not prompt: - return {"success": False, "error": "缺少必要参数 prompt"} - - model_name: str = args.get("model", "") or args.get("model_name", "") tool_options = args.get("tools") or args.get("tool_options") - temperature = args.get("temperature") - max_tokens = args.get("max_tokens") + if tool_options is not None and not isinstance(tool_options, list): + return {"success": False, "error": "tools 必须为列表"} try: - models = llm_api.get_available_models() - if model_name and model_name in models: - model_config = models[model_name] - else: - if not models: - return {"success": False, "error": "没有可用的模型配置"} - model_config = next(iter(models.values())) - - success, response, reasoning, used_model, tool_calls = await llm_api.generate_with_model_with_tools( - prompt=prompt, - model_config=model_config, - tool_options=tool_options, - request_type=f"plugin.{plugin_id}", - temperature=temperature, - max_tokens=max_tokens, + prompt = _normalize_prompt_arg(args.get("prompt")) + task_name = llm_api.resolve_task_name(str(args.get("model", "") or args.get("model_name", ""))) + result = await llm_api.generate( + llm_api.LLMServiceRequest( + task_name=task_name, + request_type=f"plugin.{plugin_id}", + prompt=prompt, + tool_options=tool_options, + temperature=args.get("temperature"), + max_tokens=args.get("max_tokens"), + ) ) - serialized_tool_calls = None - if tool_calls: - serialized_tool_calls = [ - { - "id": tool_call.call_id, - "function": {"name": tool_call.func_name, "arguments": tool_call.args or {}}, - } - for tool_call in tool_calls - if isinstance(tool_call, ToolCall) - ] - return { - "success": success, - "response": response, - "reasoning": reasoning, - "model_name": used_model, - "tool_calls": serialized_tool_calls, - } - except Exception as e: - logger.error(f"[cap.llm.generate_with_tools] 执行失败: {e}", exc_info=True) - return {"success": False, "error": str(e)} + return result.to_capability_payload() + except Exception as exc: + logger.error(f"[cap.llm.generate_with_tools] 执行失败: {exc}", exc_info=True) + return {"success": False, "error": str(exc)} async def _cap_llm_get_available_models(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: + """获取当前宿主可用的模型任务列表。 + + Args: + plugin_id: 插件标识。 + capability: 能力名称。 + args: 能力调用参数。 + + Returns: + Any: 可用模型列表。 + """ + del plugin_id, capability, args from src.services import llm_service as llm_api try: models = llm_api.get_available_models() return {"success": True, "models": list(models.keys())} - except Exception as e: - logger.error(f"[cap.llm.get_available_models] 执行失败: {e}", exc_info=True) - return {"success": False, "error": str(e)} + except Exception as exc: + logger.error(f"[cap.llm.get_available_models] 执行失败: {exc}", exc_info=True) + return {"success": False, "error": str(exc)} async def _cap_config_get(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: - key: str = args.get("key", "") + """读取宿主全局配置中的单个字段。 + + Args: + plugin_id: 插件标识。 + capability: 能力名称。 + args: 能力调用参数。 + + Returns: + Any: 配置读取结果。 + """ + del plugin_id, capability + key = str(args.get("key", "")) default = args.get("default") if not key: return {"success": False, "value": None, "error": "缺少必要参数 key"} @@ -234,37 +327,57 @@ class RuntimeCoreCapabilityMixin: try: value = _get_nested_config_value(global_config, key, default) return {"success": True, "value": value} - except Exception as e: - return {"success": False, "value": None, "error": str(e)} + except Exception as exc: + return {"success": False, "value": None, "error": str(exc)} async def _cap_config_get_plugin(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: + """读取指定插件的配置。 + + Args: + plugin_id: 当前插件标识。 + capability: 能力名称。 + args: 能力调用参数。 + + Returns: + Any: 配置读取结果。 + """ + del capability from src.plugin_runtime.component_query import component_query_service - plugin_name: str = args.get("plugin_name", plugin_id) - key: str = args.get("key", "") + plugin_name = str(args.get("plugin_name", plugin_id)) + key = str(args.get("key", "")) default = args.get("default") try: config = component_query_service.get_plugin_config(plugin_name) if config is None: return {"success": False, "value": default, "error": f"未找到插件 {plugin_name} 的配置"} - if key: value = _get_nested_config_value(config, key, default) return {"success": True, "value": value} - return {"success": True, "value": config} - except Exception as e: - return {"success": False, "value": default, "error": str(e)} + except Exception as exc: + return {"success": False, "value": default, "error": str(exc)} async def _cap_config_get_all(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: + """读取指定插件的全部配置。 + + Args: + plugin_id: 当前插件标识。 + capability: 能力名称。 + args: 能力调用参数。 + + Returns: + Any: 配置读取结果。 + """ + del capability from src.plugin_runtime.component_query import component_query_service - plugin_name: str = args.get("plugin_name", plugin_id) + plugin_name = str(args.get("plugin_name", plugin_id)) try: config = component_query_service.get_plugin_config(plugin_name) if config is None: return {"success": True, "value": {}} return {"success": True, "value": config} - except Exception as e: - return {"success": False, "value": {}, "error": str(e)} + except Exception as exc: + return {"success": False, "value": {}, "error": str(exc)} diff --git a/src/plugin_runtime/component_query.py b/src/plugin_runtime/component_query.py index 7d23d202..e2ba7366 100644 --- a/src/plugin_runtime/component_query.py +++ b/src/plugin_runtime/component_query.py @@ -1,16 +1,23 @@ """插件运行时统一组件查询服务。 该模块统一从插件运行时的 Host ComponentRegistry 中聚合只读视图, -供 HFC/PFC、Planner、ToolExecutor 和运行时能力层查询与调用。 +供 HFC、ToolExecutor 和运行时能力层查询与调用。 """ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional, Tuple +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional, Tuple, cast from src.common.logger import get_logger +from src.core.tooling import ( + ToolExecutionContext, + ToolExecutionResult, + ToolInvocation, + ToolSpec, + build_tool_detailed_description, +) from src.core.types import ActionActivationType, ActionInfo, CommandInfo, ComponentInfo, ComponentType, ToolInfo -from src.llm_models.payload_content.tool_option import ToolParamType +from src.llm_models.payload_content.tool_option import normalize_tool_option if TYPE_CHECKING: from src.plugin_runtime.host.component_registry import ActionEntry, CommandEntry, ComponentEntry, ToolEntry @@ -28,13 +35,6 @@ _HOST_COMPONENT_TYPE_MAP: Dict[ComponentType, str] = { ComponentType.COMMAND: "COMMAND", ComponentType.TOOL: "TOOL", } -_TOOL_PARAM_TYPE_MAP: Dict[str, ToolParamType] = { - "string": ToolParamType.STRING, - "integer": ToolParamType.INTEGER, - "float": ToolParamType.FLOAT, - "boolean": ToolParamType.BOOLEAN, - "bool": ToolParamType.BOOLEAN, -} class ComponentQueryService: @@ -146,36 +146,25 @@ class ComponentQueryService: metadata = dict(entry.metadata) raw_action_parameters = metadata.get("action_parameters") action_parameters = ( - { - str(param_name): str(param_description) - for param_name, param_description in raw_action_parameters.items() - } + {str(param_name): str(param_description) for param_name, param_description in raw_action_parameters.items()} if isinstance(raw_action_parameters, dict) else {} ) action_require = [ - str(item) - for item in (metadata.get("action_require") or []) - if item is not None and str(item).strip() + str(item) for item in (metadata.get("action_require") or []) if item is not None and str(item).strip() ] associated_types = [ - str(item) - for item in (metadata.get("associated_types") or []) - if item is not None and str(item).strip() + str(item) for item in (metadata.get("associated_types") or []) if item is not None and str(item).strip() ] activation_keywords = [ - str(item) - for item in (metadata.get("activation_keywords") or []) - if item is not None and str(item).strip() + str(item) for item in (metadata.get("activation_keywords") or []) if item is not None and str(item).strip() ] return ActionInfo( name=entry.name, - component_type=ComponentType.ACTION, description=str(metadata.get("description", "") or ""), enabled=bool(entry.enabled), plugin_name=entry.plugin_id, - metadata=metadata, action_parameters=action_parameters, action_require=action_require, associated_types=associated_types, @@ -202,72 +191,48 @@ class ComponentQueryService: metadata = dict(entry.metadata) return CommandInfo( name=entry.name, - component_type=ComponentType.COMMAND, description=str(metadata.get("description", "") or ""), enabled=bool(entry.enabled), plugin_name=entry.plugin_id, - metadata=metadata, - command_pattern=str(metadata.get("command_pattern", "") or ""), ) @staticmethod - def _coerce_tool_param_type(raw_value: Any) -> ToolParamType: - """规范化工具参数类型。 - - Args: - raw_value: 原始工具参数类型值。 - - Returns: - ToolParamType: 规范化后的工具参数类型。 - """ - - normalized_value = str(raw_value or "").strip().lower() - return _TOOL_PARAM_TYPE_MAP.get(normalized_value, ToolParamType.STRING) - - @staticmethod - def _build_tool_parameters(entry: "ToolEntry") -> list[tuple[str, ToolParamType, str, bool, list[str] | None]]: - """将运行时工具参数元数据转换为核心 ToolInfo 参数列表。 + def _build_tool_definition(entry: "ToolEntry") -> dict[str, Any]: + """将运行时 Tool 条目转换为原始工具定义字典。 Args: entry: 插件运行时中的 Tool 条目。 Returns: - list[tuple[str, ToolParamType, str, bool, list[str] | None]]: 转换后的参数列表。 + dict[str, Any]: 可交给 `normalize_tool_option()` 的原始工具定义。 """ + raw_definition: dict[str, Any] = { + "name": entry.name, + "description": entry.description, + } + if isinstance(entry.parameters_raw, dict) and entry.parameters_raw: + raw_definition["parameters_schema"] = entry.parameters_raw + return raw_definition + if isinstance(entry.parameters, list) and entry.parameters: + raw_definition["parameters"] = entry.parameters + return raw_definition + if isinstance(entry.parameters_raw, list) and entry.parameters_raw: + raw_definition["parameters"] = entry.parameters_raw + return raw_definition + return raw_definition - structured_parameters = entry.parameters if isinstance(entry.parameters, list) else [] - if not structured_parameters and isinstance(entry.parameters_raw, dict): - structured_parameters = [ - {"name": key, **value} - for key, value in entry.parameters_raw.items() - if isinstance(value, dict) - ] + @staticmethod + def _build_tool_parameters_schema(entry: "ToolEntry") -> dict[str, Any] | None: + """将运行时 Tool 条目转换为对象级参数 Schema。 - normalized_parameters: list[tuple[str, ToolParamType, str, bool, list[str] | None]] = [] - for parameter in structured_parameters: - if not isinstance(parameter, dict): - continue + Args: + entry: 插件运行时中的 Tool 条目。 - parameter_name = str(parameter.get("name", "") or "").strip() - if not parameter_name: - continue - - enum_values = parameter.get("enum") - normalized_enum_values = ( - [str(item) for item in enum_values if item is not None] - if isinstance(enum_values, list) - else None - ) - normalized_parameters.append( - ( - parameter_name, - ComponentQueryService._coerce_tool_param_type(parameter.get("param_type") or parameter.get("type")), - str(parameter.get("description", "") or ""), - bool(parameter.get("required", True)), - normalized_enum_values, - ) - ) - return normalized_parameters + Returns: + dict[str, Any] | None: 规范化后的对象级参数 Schema。 + """ + normalized_option = normalize_tool_option(ComponentQueryService._build_tool_definition(entry)) + return normalized_option.parameters_schema @staticmethod def _build_tool_info(entry: "ToolEntry") -> ToolInfo: @@ -282,13 +247,36 @@ class ComponentQueryService: return ToolInfo( name=entry.name, - component_type=ComponentType.TOOL, - description=entry.description, + description=entry.brief_description or entry.description, enabled=bool(entry.enabled), plugin_name=entry.plugin_id, - metadata=dict(entry.metadata), - tool_parameters=ComponentQueryService._build_tool_parameters(entry), - tool_description=entry.description, + parameters_schema=ComponentQueryService._build_tool_parameters_schema(entry), + ) + + @staticmethod + def _build_tool_spec(entry: "ToolEntry") -> ToolSpec: + """将运行时 Tool 条目转换为统一工具声明。 + + Args: + entry: 插件运行时中的 Tool 条目。 + + Returns: + ToolSpec: 统一工具声明。 + """ + + parameters_schema = ComponentQueryService._build_tool_parameters_schema(entry) + return ToolSpec( + name=entry.name, + brief_description=entry.brief_description or entry.description or f"工具 {entry.name}", + detailed_description=entry.detailed_description or build_tool_detailed_description(parameters_schema), + parameters_schema=parameters_schema, + provider_name=entry.plugin_id, + provider_type="plugin", + metadata={ + "plugin_id": entry.plugin_id, + "invoke_method": entry.invoke_method, + "legacy_component_type": entry.legacy_component_type, + }, ) @staticmethod @@ -478,9 +466,14 @@ class ComponentQueryService: message = kwargs.get("message") matched_groups = kwargs.get("matched_groups") plugin_config = kwargs.get("plugin_config") + message_info = getattr(message, "message_info", None) + group_info = getattr(message_info, "group_info", None) + user_info = getattr(message_info, "user_info", None) invoke_args: Dict[str, Any] = { "text": str(getattr(message, "processed_plain_text", "") or ""), "stream_id": str(getattr(message, "session_id", "") or ""), + "group_id": str(getattr(group_info, "group_id", "") or ""), + "user_id": str(getattr(user_info, "user_id", "") or ""), "matched_groups": matched_groups if isinstance(matched_groups, dict) else {}, } if isinstance(plugin_config, dict): @@ -515,7 +508,12 @@ class ComponentQueryService: return _executor @staticmethod - def _build_tool_executor(supervisor: "PluginSupervisor", plugin_id: str, component_name: str) -> ToolExecutor: + def _build_tool_executor( + supervisor: "PluginSupervisor", + plugin_id: str, + component_name: str, + invoke_method: str = "plugin.invoke_tool", + ) -> ToolExecutor: """构造工具执行 RPC 闭包。 Args: @@ -539,7 +537,7 @@ class ComponentQueryService: try: response = await supervisor.invoke_plugin( - method="plugin.invoke_tool", + method=invoke_method, plugin_id=plugin_id, component_name=component_name, args=function_args, @@ -655,7 +653,162 @@ class ComponentQueryService: if matched_entry is None: return None supervisor, entry = matched_entry - return self._build_tool_executor(supervisor, entry.plugin_id, entry.name) + tool_entry = cast("ToolEntry", entry) + return self._build_tool_executor(supervisor, tool_entry.plugin_id, tool_entry.name, tool_entry.invoke_method) + + def get_llm_available_tool_specs(self) -> Dict[str, ToolSpec]: + """获取当前可供 LLM 使用的统一工具声明集合。 + + Returns: + Dict[str, ToolSpec]: 工具名到工具声明的映射。 + """ + + collected_specs: Dict[str, ToolSpec] = {} + for _supervisor, entry in self._iter_component_entries(ComponentType.TOOL): + if entry.name in collected_specs: + self._log_duplicate_component(ComponentType.TOOL, entry.name) + continue + collected_specs[entry.name] = self._build_tool_spec(entry) # type: ignore[arg-type] + return collected_specs + + @staticmethod + def _build_tool_invocation_payload( + entry: "ToolEntry", + invocation: ToolInvocation, + context: Optional[ToolExecutionContext], + ) -> Dict[str, Any]: + """构造插件工具执行时发送给 Runner 的参数。 + + Args: + entry: 目标工具条目。 + invocation: 统一工具调用请求。 + context: 统一工具执行上下文。 + + Returns: + Dict[str, Any]: 发往 Runner 的参数字典。 + """ + + payload = dict(invocation.arguments) + if entry.invoke_method == "plugin.invoke_action": + stream_id = context.stream_id if context is not None else invocation.stream_id + reasoning = context.reasoning if context is not None else invocation.reasoning + payload = { + **payload, + "stream_id": stream_id, + "chat_id": stream_id, + "reasoning": reasoning, + "action_data": dict(invocation.arguments), + } + return payload + + @staticmethod + def _parse_tool_invoke_result( + entry: "ToolEntry", + result: Any, + ) -> ToolExecutionResult: + """将插件组件返回值转换为统一工具执行结果。 + + Args: + entry: 目标工具条目。 + result: 插件组件原始返回值。 + + Returns: + ToolExecutionResult: 统一执行结果。 + """ + + if isinstance(result, dict): + success = bool(result.get("success", True)) + content = str(result.get("content", result.get("message", "")) or "").strip() + error_message = "" + if not success: + error_message = str(result.get("error", result.get("message", "插件工具执行失败")) or "").strip() + return ToolExecutionResult( + tool_name=entry.name, + success=success, + content=content, + error_message=error_message, + structured_content=result, + metadata={"plugin_id": entry.plugin_id}, + ) + + if isinstance(result, (list, tuple)) and result: + if isinstance(result[0], bool): + success = bool(result[0]) + message = "" if len(result) < 2 or result[1] is None else str(result[1]).strip() + return ToolExecutionResult( + tool_name=entry.name, + success=success, + content=message if success else "", + error_message="" if success else message, + structured_content=list(result), + metadata={"plugin_id": entry.plugin_id}, + ) + + normalized_content = "" if result is None else str(result).strip() + return ToolExecutionResult( + tool_name=entry.name, + success=True, + content=normalized_content, + structured_content=result, + metadata={"plugin_id": entry.plugin_id}, + ) + + async def invoke_tool_as_tool( + self, + invocation: ToolInvocation, + context: Optional[ToolExecutionContext] = None, + ) -> ToolExecutionResult: + """按统一工具语义执行插件工具。 + + Args: + invocation: 统一工具调用请求。 + context: 执行上下文。 + + Returns: + ToolExecutionResult: 统一工具执行结果。 + """ + + matched_entry = self._get_unique_component_entry(ComponentType.TOOL, invocation.tool_name) + if matched_entry is None: + return ToolExecutionResult( + tool_name=invocation.tool_name, + success=False, + error_message=f"未找到插件工具:{invocation.tool_name}", + ) + + supervisor, entry = matched_entry + tool_entry = cast("ToolEntry", entry) + invoke_payload = self._build_tool_invocation_payload(tool_entry, invocation, context) + + try: + response = await supervisor.invoke_plugin( + method=tool_entry.invoke_method, + plugin_id=tool_entry.plugin_id, + component_name=tool_entry.name, + args=invoke_payload, + timeout_ms=30000, + ) + except Exception as exc: + logger.error(f"运行时工具 {tool_entry.plugin_id}.{tool_entry.name} 执行失败: {exc}", exc_info=True) + return ToolExecutionResult( + tool_name=tool_entry.name, + success=False, + error_message=str(exc), + metadata={"plugin_id": tool_entry.plugin_id}, + ) + + payload = response.payload if isinstance(response.payload, dict) else {} + transport_success = bool(payload.get("success", False)) + result = payload.get("result") + if not transport_success: + return ToolExecutionResult( + tool_name=tool_entry.name, + success=False, + error_message="" if result is None else str(result), + structured_content=result, + metadata={"plugin_id": tool_entry.plugin_id}, + ) + return self._parse_tool_invoke_result(tool_entry, result) def get_llm_available_tools(self) -> Dict[str, ToolInfo]: """获取当前可供 LLM 选择的工具集合。 diff --git a/src/plugin_runtime/host/component_registry.py b/src/plugin_runtime/host/component_registry.py index 97fdca30..c91574e5 100644 --- a/src/plugin_runtime/host/component_registry.py +++ b/src/plugin_runtime/host/component_registry.py @@ -1,7 +1,7 @@ -"""Host-side ComponentRegistry +"""Host 侧组件注册表。 对齐旧系统 component_registry.py 的核心能力: -- 按类型注册组件(action / command / tool / event_handler / workflow_handler / message_gateway) +- 按类型注册组件(action / command / tool / event_handler / hook_handler / message_gateway) - 命名空间 (plugin_id.component_name) - 命令正则匹配 - 组件启用/禁用 @@ -16,6 +16,7 @@ import contextlib import re from src.common.logger import get_logger +from src.core.tooling import build_tool_detailed_description logger = get_logger("plugin_runtime.host.component_registry") @@ -89,11 +90,81 @@ class ToolEntry(ComponentEntry): """Tool 组件条目""" def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None: - self.description: str = metadata.get("description", "") + self.description: str = str(metadata.get("description", "") or "").strip() + self.brief_description: str = str( + metadata.get("brief_description", self.description) or self.description or f"工具 {name}" + ).strip() self.parameters: List[Dict[str, Any]] = metadata.get("parameters", []) - self.parameters_raw: List[Dict[str, Any]] = metadata.get("parameters_raw", []) + self.parameters_raw: Dict[str, Any] | List[Dict[str, Any]] = metadata.get("parameters_raw", {}) + detailed_description = str(metadata.get("detailed_description", "") or "").strip() + self.detailed_description: str = detailed_description + self.invoke_method: str = str(metadata.get("invoke_method", "plugin.invoke_tool") or "plugin.invoke_tool").strip() + self.legacy_component_type: str = str(metadata.get("legacy_component_type", "") or "").strip() super().__init__(name, component_type, plugin_id, metadata) + if not self.detailed_description: + parameters_schema = self._get_parameters_schema() + self.detailed_description = build_tool_detailed_description(parameters_schema) + + def _get_parameters_schema(self) -> Dict[str, Any] | None: + """获取当前工具条目的对象级参数 Schema。 + + Returns: + Dict[str, Any] | None: 归一化后的参数 Schema。 + """ + + if isinstance(self.parameters_raw, dict) and self.parameters_raw: + if self.parameters_raw.get("type") == "object" or "properties" in self.parameters_raw: + return dict(self.parameters_raw) + + required_names: List[str] = [] + normalized_properties: Dict[str, Any] = {} + for property_name, property_schema in self.parameters_raw.items(): + if not isinstance(property_schema, dict): + continue + property_schema_copy = dict(property_schema) + if bool(property_schema_copy.pop("required", False)): + required_names.append(str(property_name)) + normalized_properties[str(property_name)] = property_schema_copy + + schema: Dict[str, Any] = { + "type": "object", + "properties": normalized_properties, + } + if required_names: + schema["required"] = required_names + return schema + + if isinstance(self.parameters, list) and self.parameters: + properties: Dict[str, Any] = {} + required_names: List[str] = [] + for parameter in self.parameters: + if not isinstance(parameter, dict): + continue + parameter_name = str(parameter.get("name", "") or "").strip() + if not parameter_name: + continue + if bool(parameter.get("required", False)): + required_names.append(parameter_name) + properties[parameter_name] = { + key: value + for key, value in parameter.items() + if key not in {"name", "required", "param_type"} + } + properties[parameter_name]["type"] = str( + parameter.get("type", parameter.get("param_type", "string")) or "string" + ) + + schema = { + "type": "object", + "properties": properties, + } + if required_names: + schema["required"] = required_names + return schema + + return None + class EventHandlerEntry(ComponentEntry): """EventHandler 组件条目""" @@ -106,14 +177,129 @@ class EventHandlerEntry(ComponentEntry): class HookHandlerEntry(ComponentEntry): - """WorkflowHandler 组件条目""" + """HookHandler 组件条目。""" def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None: - self.stage: str = metadata.get("stage", "") - self.priority: int = metadata.get("priority", 0) - self.blocking: bool = metadata.get("blocking", False) + self.hook: str = self._normalize_hook_name(metadata.get("hook", "")) + self.mode: str = self._normalize_mode(metadata.get("mode", "blocking")) + self.order: str = self._normalize_order(metadata.get("order", "normal")) + self.timeout_ms: int = self._normalize_timeout_ms(metadata.get("timeout_ms", 0)) + self.error_policy: str = self._normalize_error_policy(metadata.get("error_policy", "skip")) super().__init__(name, component_type, plugin_id, metadata) + @staticmethod + def _normalize_error_policy(raw_value: Any) -> str: + """规范化 Hook 异常处理策略。 + + Args: + raw_value: 原始异常处理策略值。 + + Returns: + str: 规范化后的异常处理策略。 + + Raises: + ValueError: 当异常处理策略不受支持时抛出。 + """ + + normalized_source = getattr(raw_value, "value", raw_value) + normalized_value = str(normalized_source or "").strip().lower() or "skip" + if normalized_value not in {"abort", "skip", "log"}: + raise ValueError(f"HookHandler 异常处理策略不合法: {raw_value}") + return normalized_value + + @staticmethod + def _normalize_hook_name(raw_value: Any) -> str: + """规范化命名 Hook 名称。 + + Args: + raw_value: 原始 Hook 名称。 + + Returns: + str: 去空白后的 Hook 名称。 + + Raises: + ValueError: 当 Hook 名称为空时抛出。 + """ + + normalized_source = getattr(raw_value, "value", raw_value) + if not (normalized_value := str(normalized_source or "").strip()): + raise ValueError("HookHandler 的 hook 名称不能为空") + return normalized_value + + @staticmethod + def _normalize_mode(raw_value: Any) -> str: + """规范化 Hook 处理模式。 + + Args: + raw_value: 原始模式值。 + + Returns: + str: 规范化后的模式。 + + Raises: + ValueError: 当模式不受支持时抛出。 + """ + + normalized_source = getattr(raw_value, "value", raw_value) + normalized_value = str(normalized_source or "").strip().lower() or "blocking" + if normalized_value not in {"blocking", "observe"}: + raise ValueError(f"HookHandler 模式不合法: {raw_value}") + return normalized_value + + @staticmethod + def _normalize_order(raw_value: Any) -> str: + """规范化 Hook 顺序槽位。 + + Args: + raw_value: 原始顺序值。 + + Returns: + str: 规范化后的顺序槽位。 + + Raises: + ValueError: 当顺序值不受支持时抛出。 + """ + + normalized_source = getattr(raw_value, "value", raw_value) + normalized_value = str(normalized_source or "").strip().lower() or "normal" + if normalized_value not in {"early", "normal", "late"}: + raise ValueError(f"HookHandler 顺序槽位不合法: {raw_value}") + return normalized_value + + @staticmethod + def _normalize_timeout_ms(raw_value: Any) -> int: + """规范化 Hook 超时配置。 + + Args: + raw_value: 原始超时值。 + + Returns: + int: 规范化后的超时毫秒数。 + + Raises: + ValueError: 当超时值为负数或无法转换为整数时抛出。 + """ + + try: + timeout_ms = int(raw_value or 0) + except (TypeError, ValueError) as exc: + raise ValueError(f"HookHandler 超时配置不合法: {raw_value}") from exc + if timeout_ms < 0: + raise ValueError(f"HookHandler 超时配置不能为负数: {raw_value}") + return timeout_ms + + @property + def is_blocking(self) -> bool: + """返回当前 Hook 是否为阻塞模式。""" + + return self.mode == "blocking" + + @property + def is_observe(self) -> bool: + """返回当前 Hook 是否为观察模式。""" + + return self.mode == "observe" + class MessageGatewayEntry(ComponentEntry): """MessageGateway 组件条目""" @@ -167,7 +353,7 @@ class MessageGatewayEntry(ComponentEntry): class ComponentRegistry: - """Host-side 组件注册表 + """Host 侧组件注册表。 由 Supervisor 在收到 plugin.register_components 时调用。 供业务层查询可用组件、匹配命令、调度 action/event 等。 @@ -185,6 +371,86 @@ class ComponentRegistry: # 按插件索引 self._by_plugin: Dict[str, List[ComponentEntry]] = {} + @staticmethod + def _convert_action_metadata_to_tool_metadata( + name: str, + metadata: Dict[str, Any], + ) -> Dict[str, Any]: + """将旧 Action 元数据转换为统一 Tool 元数据。 + + Args: + name: 组件名称。 + metadata: Action 原始元数据。 + + Returns: + Dict[str, Any]: 转换后的 Tool 元数据。 + """ + + action_parameters = metadata.get("action_parameters") + parameters_schema: Dict[str, Any] | None = None + if isinstance(action_parameters, dict) and action_parameters: + properties: Dict[str, Any] = {} + for parameter_name, parameter_description in action_parameters.items(): + normalized_name = str(parameter_name or "").strip() + if not normalized_name: + continue + properties[normalized_name] = { + "type": "string", + "description": str(parameter_description or "").strip() or "兼容旧 Action 参数", + } + if properties: + parameters_schema = { + "type": "object", + "properties": properties, + } + + detailed_parts: List[str] = [] + if parameters_schema is not None: + parameter_description = build_tool_detailed_description(parameters_schema) + if parameter_description: + detailed_parts.append(parameter_description) + + action_require = [ + str(item).strip() + for item in (metadata.get("action_require") or []) + if str(item).strip() + ] + if action_require: + detailed_parts.append("使用建议:\n" + "\n".join(f"- {item}" for item in action_require)) + + associated_types = [ + str(item).strip() + for item in (metadata.get("associated_types") or []) + if str(item).strip() + ] + if associated_types: + detailed_parts.append(f"适用消息类型:{'、'.join(associated_types)}。") + + activation_type = str(metadata.get("activation_type", "always") or "always").strip() + activation_keywords = [ + str(item).strip() + for item in (metadata.get("activation_keywords") or []) + if str(item).strip() + ] + activation_lines = [f"兼容旧 Action 激活方式:{activation_type}。"] + if activation_keywords: + activation_lines.append(f"激活关键词:{'、'.join(activation_keywords)}。") + if str(metadata.get("action_prompt", "") or "").strip(): + activation_lines.append(f"原始 Action 提示语:{str(metadata['action_prompt']).strip()}。") + detailed_parts.append("\n".join(activation_lines)) + + brief_description = str(metadata.get("brief_description", metadata.get("description", "") or f"工具 {name}")).strip() + return { + **metadata, + "description": brief_description, + "brief_description": brief_description, + "detailed_description": "\n\n".join(part for part in detailed_parts if part).strip(), + "parameters_raw": parameters_schema or {}, + "invoke_method": "plugin.invoke_action", + "legacy_action": True, + "legacy_component_type": "ACTION", + } + @staticmethod def _normalize_component_type(component_type: str) -> ComponentTypes: """规范化组件类型输入。 @@ -223,18 +489,20 @@ class ComponentRegistry: """ try: normalized_type = self._normalize_component_type(component_type) + normalized_metadata = dict(metadata) if normalized_type == ComponentTypes.ACTION: - comp = ActionEntry(name, normalized_type.value, plugin_id, metadata) + normalized_metadata = self._convert_action_metadata_to_tool_metadata(name, normalized_metadata) + comp = ToolEntry(name, ComponentTypes.TOOL.value, plugin_id, normalized_metadata) elif normalized_type == ComponentTypes.COMMAND: - comp = CommandEntry(name, normalized_type.value, plugin_id, metadata) + comp = CommandEntry(name, normalized_type.value, plugin_id, normalized_metadata) elif normalized_type == ComponentTypes.TOOL: - comp = ToolEntry(name, normalized_type.value, plugin_id, metadata) + comp = ToolEntry(name, normalized_type.value, plugin_id, normalized_metadata) elif normalized_type == ComponentTypes.EVENT_HANDLER: - comp = EventHandlerEntry(name, normalized_type.value, plugin_id, metadata) + comp = EventHandlerEntry(name, normalized_type.value, plugin_id, normalized_metadata) elif normalized_type == ComponentTypes.HOOK_HANDLER: - comp = HookHandlerEntry(name, normalized_type.value, plugin_id, metadata) + comp = HookHandlerEntry(name, normalized_type.value, plugin_id, normalized_metadata) elif normalized_type == ComponentTypes.MESSAGE_GATEWAY: - comp = MessageGatewayEntry(name, normalized_type.value, plugin_id, metadata) + comp = MessageGatewayEntry(name, normalized_type.value, plugin_id, normalized_metadata) else: raise ValueError(f"组件类型 {component_type} 不存在") except ValueError: @@ -454,16 +722,17 @@ class ComponentRegistry: return handlers def get_hook_handlers( - self, stage: str, *, enabled_only: bool = True, session_id: Optional[str] = None + self, hook_name: str, *, enabled_only: bool = True, session_id: Optional[str] = None ) -> List[HookHandlerEntry]: - """获取特定 hook 阶段的所有步骤,按 priority 降序。 + """获取订阅指定命名 Hook 的全部处理器。 Args: - stage: hook 名称 - enabled_only: 是否仅返回启用的组件 - session_id: 可选的会话ID,若提供则考虑会话禁用状态 + hook_name: 目标 Hook 名称。 + enabled_only: 是否仅返回启用的组件。 + session_id: 可选的会话 ID,若提供则考虑会话禁用状态。 + Returns: - handlers (List[HookHandlerEntry]): 符合条件的 HookHandler 组件列表,按 priority 降序排序 + List[HookHandlerEntry]: 符合条件的 HookHandler 组件列表。 """ handlers: List[HookHandlerEntry] = [] for comp in self._by_type.get(ComponentTypes.HOOK_HANDLER, {}).values(): @@ -471,11 +740,37 @@ class ComponentRegistry: continue if not isinstance(comp, HookHandlerEntry): continue - if comp.stage == stage: + if comp.hook == hook_name: handlers.append(comp) - handlers.sort(key=lambda c: c.priority, reverse=True) + handlers.sort(key=lambda comp: (self._get_hook_mode_rank(comp.mode), self._get_hook_order_rank(comp.order), comp.plugin_id, comp.name)) return handlers + @staticmethod + def _get_hook_mode_rank(mode: str) -> int: + """返回 Hook 模式的排序权重。 + + Args: + mode: Hook 模式字符串。 + + Returns: + int: 越小表示越靠前。 + """ + + return {"blocking": 0, "observe": 1}.get(mode, 99) + + @staticmethod + def _get_hook_order_rank(order: str) -> int: + """返回 Hook 顺序槽位的排序权重。 + + Args: + order: Hook 顺序槽位字符串。 + + Returns: + int: 越小表示越靠前。 + """ + + return {"early": 0, "normal": 1, "late": 2}.get(order, 99) + def get_message_gateway( self, plugin_id: str, @@ -566,8 +861,13 @@ class ComponentRegistry: Returns: stats (StatusDict): 组件统计信息,包括总数、各类型数量、插件数量等 """ - stats: StatusDict = {"total": len(self._components)} # type: ignore - for comp_type, type_dict in self._by_type.items(): - stats[comp_type.value.lower()] = len(type_dict) - stats["plugins"] = len(self._by_plugin) - return stats + return StatusDict( + total=len(self._components), + action=len(self._by_type[ComponentTypes.ACTION]), + command=len(self._by_type[ComponentTypes.COMMAND]), + tool=len(self._by_type[ComponentTypes.TOOL]), + event_handler=len(self._by_type[ComponentTypes.EVENT_HANDLER]), + hook_handler=len(self._by_type[ComponentTypes.HOOK_HANDLER]), + message_gateway=len(self._by_type[ComponentTypes.MESSAGE_GATEWAY]), + plugins=len(self._by_plugin), + ) diff --git a/src/plugin_runtime/host/hook_dispatcher.py b/src/plugin_runtime/host/hook_dispatcher.py index d5e88448..f2979f29 100644 --- a/src/plugin_runtime/host/hook_dispatcher.py +++ b/src/plugin_runtime/host/hook_dispatcher.py @@ -1,166 +1,670 @@ -""" -Hook Dispatch 系统 +"""命名 Hook 分发系统。 -插件可以注册自己的Hook,当特定函数被调用时,Hook Dispatch系统会将调用转发给插件的Hook处理函数。 -每个Hook的参数随Hook点位确定,因此参数是易变的。插件开发者需要根据Hook点位的定义来编写Hook处理函数。 -在参数/返回值匹配的情况下允许修改参数/返回值。 +主程序可以在任意执行点触发一个命名 Hook,Host 会收集所有订阅该 Hook 的 +插件处理器,并按照固定的全局顺序调度执行。 -HookDispatcher 负责: -1. 按 stage 查询已注册的 hook_handler(通过 ComponentRegistry) -2. 按 priority 排序,区分 blocking 和非 blocking 模式 -3. blocking 模式:依次同步调用,支持修改参数/提前终止 -4. 非 blocking 模式:异步调用,不阻塞主流程 -5. 支持通过 global_config.plugin_runtime.hook_blocking_timeout_sec 设置超时上限 +排序规则如下: + +1. `blocking` 先于 `observe` +2. `early` 先于 `normal` 先于 `late` +3. 内置插件先于第三方插件 +4. `plugin_id` +5. `handler_name` + +其中: + +- `blocking` 处理器串行执行,可修改 `kwargs`,也可中止本次 Hook 调用。 +- `observe` 处理器后台并发执行,只允许旁路观察,不参与主流程控制。 """ +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence, Set + import asyncio -from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING +import contextlib from src.common.logger import get_logger from src.config.config import global_config - if TYPE_CHECKING: + from .component_registry import HookHandlerEntry from .supervisor import PluginRunnerSupervisor - from .component_registry import ComponentRegistry, HookHandlerEntry logger = get_logger("plugin_runtime.host.hook_dispatcher") -@dataclass -class HookResult: - """单个 HookHandler 的执行结果""" +@dataclass(slots=True) +class HookSpec: + """命名 Hook 的静态规格定义。 + + Attributes: + name: Hook 的唯一名称。 + description: Hook 描述。 + default_timeout_ms: 默认超时毫秒数;为 `0` 时退回系统默认值。 + allow_blocking: 是否允许注册阻塞处理器。 + allow_observe: 是否允许注册观察处理器。 + allow_abort: 是否允许处理器中止当前 Hook 调用。 + allow_kwargs_mutation: 是否允许阻塞处理器修改 `kwargs`。 + """ + + name: str + description: str = "" + default_timeout_ms: int = 0 + allow_blocking: bool = True + allow_observe: bool = True + allow_abort: bool = True + allow_kwargs_mutation: bool = True + + +@dataclass(slots=True) +class HookHandlerExecutionResult: + """单个 HookHandler 的执行结果。 + + Attributes: + handler_name: 完整处理器名称,格式通常为 `plugin_id.component_name`。 + plugin_id: 处理器所属插件 ID。 + success: 本次调用是否成功。 + action: 当前处理器要求的控制动作,仅支持 `continue` 或 `abort`。 + modified_kwargs: 处理器返回的修改后参数字典。 + custom_result: 处理器返回的附加结果。 + error_message: 失败时的错误描述。 + """ handler_name: str - success: bool = field(default=True) - continue_processing: bool = field(default=True) - modified_kwargs: Optional[Dict[str, Any]] = field(default=None) - custom_result: Any = field(default=None) + plugin_id: str + success: bool = True + action: str = "continue" + modified_kwargs: Optional[Dict[str, Any]] = None + custom_result: Any = None + error_message: str = "" + + +@dataclass(slots=True) +class HookDispatchResult: + """一次命名 Hook 调用的聚合结果。 + + Attributes: + hook_name: 本次调用的 Hook 名称。 + kwargs: 经阻塞处理器串行处理后的最终参数字典。 + aborted: 是否被某个处理器中止。 + stopped_by: 若被中止,记录触发中止的完整处理器名称。 + custom_results: 阻塞处理器返回的附加结果列表。 + errors: 本次调用中记录到的错误信息列表。 + """ + + hook_name: str + kwargs: Dict[str, Any] = field(default_factory=dict) + aborted: bool = False + stopped_by: Optional[str] = None + custom_results: List[Any] = field(default_factory=list) + errors: List[str] = field(default_factory=list) + + +@dataclass(slots=True) +class _HookInvocationTarget: + """内部使用的 Hook 调度目标。 + + Attributes: + supervisor: 负责该处理器的 Supervisor。 + entry: Hook 处理器条目。 + source_rank: 插件来源权重,内置插件为 `0`,第三方插件为 `1`。 + """ + + supervisor: "PluginRunnerSupervisor" + entry: "HookHandlerEntry" + source_rank: int class HookDispatcher: - """Host-side Hook 分发器 + """命名 Hook 分发器。""" - 由业务层调用 hook_dispatch(), - 内部通过 ComponentRegistry 查询 handler, - 再通过提供的 invoke_fn 回调 RPC 到 Runner 执行。 - """ - - def __init__(self, component_registry: "ComponentRegistry") -> None: - """初始化 HookDispatcher + def __init__( + self, + supervisors_provider: Optional[Callable[[], Sequence["PluginRunnerSupervisor"]]] = None, + ) -> None: + """初始化 Hook 分发器。 Args: - component_registry: ComponentRegistry 实例,用于查询已注册的 hook_handler + supervisors_provider: 可选的 Supervisor 提供器。若调用 `invoke_hook()` + 时未显式传入 `supervisors`,则使用该回调获取目标 Supervisor 列表。 """ - self._component_registry: "ComponentRegistry" = component_registry - self._background_tasks: Set[asyncio.Task] = set() + + self._background_tasks: Set[asyncio.Task[Any]] = set() + self._hook_specs: Dict[str, HookSpec] = {} + self._supervisors_provider = supervisors_provider async def stop(self) -> None: - """停止 HookDispatcher,取消所有未完成的后台任务""" + """停止分发器并取消所有未完成的观察任务。""" + for task in self._background_tasks: task.cancel() await asyncio.gather(*self._background_tasks, return_exceptions=True) self._background_tasks.clear() - async def hook_dispatch( - self, - stage: str, - supervisor: "PluginRunnerSupervisor", - **kwargs: Any, - ) -> Dict[str, Any]: - """分发 hook 到所有对应 handler 的便捷方法。 - - 内置了通过 PluginRunnerSupervisor.invoke_plugin 调用 plugin 的逻辑, - 无需调用方手动构造 invoke_fn 闭包。 + def register_hook_spec(self, spec: HookSpec) -> None: + """注册单个命名 Hook 规格。 Args: - stage: hook 名称 - supervisor: PluginRunnerSupervisor 实例,用于调用 invoke_plugin - **kwargs: 关键字参数,会展开传递给 handler + spec: 需要注册的 Hook 规格。 + """ + + normalized_name = self._normalize_hook_name(spec.name) + self._hook_specs[normalized_name] = HookSpec( + name=normalized_name, + description=spec.description, + default_timeout_ms=max(int(spec.default_timeout_ms), 0), + allow_blocking=bool(spec.allow_blocking), + allow_observe=bool(spec.allow_observe), + allow_abort=bool(spec.allow_abort), + allow_kwargs_mutation=bool(spec.allow_kwargs_mutation), + ) + + def register_hook_specs(self, specs: Sequence[HookSpec]) -> None: + """批量注册命名 Hook 规格。 + + Args: + specs: 需要注册的 Hook 规格序列。 + """ + + for spec in specs: + self.register_hook_spec(spec) + + def get_hook_spec(self, hook_name: str) -> HookSpec: + """获取指定 Hook 的规格定义。 + + Args: + hook_name: Hook 名称。 Returns: - modified_kwargs (Dict[str, Any]): 经过所有 handler 修改后的关键字参数 + HookSpec: 若未显式注册,则返回按系统默认值生成的运行时规格。 """ - handler_entries = self._component_registry.get_hook_handlers(stage) - if not handler_entries: - return kwargs - current_kwargs = kwargs.copy() - blocking_handlers: List["HookHandlerEntry"] = [] - non_blocking_handlers: List["HookHandlerEntry"] = [] + normalized_name = self._normalize_hook_name(hook_name) + if normalized_name in self._hook_specs: + return self._hook_specs[normalized_name] - # 分离 blocking 和非 blocking handler - for entry in handler_entries: - if entry.blocking: - blocking_handlers.append(entry) - else: - non_blocking_handlers.append(entry) + return HookSpec( + name=normalized_name, + default_timeout_ms=self._get_default_timeout_ms(), + ) - # 处理 blocking handlers(同步调用,支持修改参数/提前终止) - timeout = global_config.plugin_runtime.hook_blocking_timeout_sec or 30.0 - for entry in blocking_handlers: - hook_args = {"stage": stage, **current_kwargs} - try: - # 应用超时控制 - result = await asyncio.wait_for( - self._invoke_handler(supervisor, entry, hook_args), - timeout=timeout, + async def invoke_hook( + self, + hook_name: str, + supervisors: Optional[Sequence["PluginRunnerSupervisor"]] = None, + **kwargs: Any, + ) -> HookDispatchResult: + """触发一次命名 Hook 调用。 + + Args: + hook_name: 本次触发的 Hook 名称。 + supervisors: 当前运行时中所有可参与分发的 Supervisor;留空时使用绑定的提供器。 + **kwargs: 传递给 Hook 处理器的关键字参数。 + + Returns: + HookDispatchResult: 聚合后的 Hook 调用结果。 + """ + + resolved_supervisors = list(supervisors) if supervisors is not None else list(self._resolve_supervisors()) + normalized_hook_name = self._normalize_hook_name(hook_name) + hook_spec = self.get_hook_spec(normalized_hook_name) + current_kwargs: Dict[str, Any] = dict(kwargs) + dispatch_result = HookDispatchResult(hook_name=normalized_hook_name, kwargs=dict(current_kwargs)) + invocation_targets = self._collect_invocation_targets(normalized_hook_name, resolved_supervisors) + + if not invocation_targets: + return dispatch_result + + for target in invocation_targets: + if target.entry.is_observe: + self._schedule_observe_handler( + hook_name=normalized_hook_name, + hook_spec=hook_spec, + target=target, + kwargs=current_kwargs, ) - except asyncio.TimeoutError: - logger.error(f"Blocking HookHandler {entry.full_name} 执行超时 (>{timeout}秒),跳过") - result = HookResult(handler_name=entry.full_name, success=False, continue_processing=True) + continue - if result: - if result.modified_kwargs is not None: - current_kwargs = result.modified_kwargs - if not result.continue_processing: - logger.info(f"HookHandler {entry.full_name} 终止了后续处理") - break + if not hook_spec.allow_blocking: + error_message = ( + f"Hook {normalized_hook_name} 不允许 blocking 处理器," + f"已跳过 {target.entry.full_name}" + ) + logger.warning(error_message) + dispatch_result.errors.append(error_message) + continue - # 处理 non-blocking handlers(异步调用,不阻塞主流程) - for entry in non_blocking_handlers: - async_kwargs = current_kwargs.copy() - hook_args = {"stage": stage, **async_kwargs} - task = asyncio.create_task( - asyncio.wait_for(self._invoke_handler(supervisor, entry, hook_args), timeout=timeout) + execution_result = await self._invoke_handler( + hook_name=normalized_hook_name, + hook_spec=hook_spec, + target=target, + kwargs=current_kwargs, + ) + self._merge_blocking_result( + hook_spec=hook_spec, + target=target, + execution_result=execution_result, + dispatch_result=dispatch_result, ) - self._background_tasks.add(task) - task.add_done_callback(self._background_tasks.discard) - return current_kwargs + current_kwargs = dict(dispatch_result.kwargs) + if dispatch_result.aborted: + break + + return dispatch_result + + def _resolve_supervisors(self) -> Sequence["PluginRunnerSupervisor"]: + """解析当前调用应使用的 Supervisor 列表。 + + Returns: + Sequence[PluginRunnerSupervisor]: 可参与本次 Hook 调度的 Supervisor 序列。 + + Raises: + ValueError: 当未传入 `supervisors` 且分发器也未绑定提供器时抛出。 + """ + + if self._supervisors_provider is None: + raise ValueError("当前 HookDispatcher 未绑定 supervisors_provider,请显式传入 supervisors") + return self._supervisors_provider() + + def _collect_invocation_targets( + self, + hook_name: str, + supervisors: Sequence["PluginRunnerSupervisor"], + ) -> List[_HookInvocationTarget]: + """收集并排序本次 Hook 调用的全部处理器目标。 + + Args: + hook_name: 目标 Hook 名称。 + supervisors: 当前参与调度的 Supervisor 序列。 + + Returns: + List[_HookInvocationTarget]: 已完成全局排序的处理器目标列表。 + """ + + invocation_targets: List[_HookInvocationTarget] = [] + for supervisor in supervisors: + source_rank = self._get_supervisor_source_rank(supervisor) + for entry in supervisor.component_registry.get_hook_handlers(hook_name): + invocation_targets.append( + _HookInvocationTarget( + supervisor=supervisor, + entry=entry, + source_rank=source_rank, + ) + ) + + invocation_targets.sort(key=self._build_sort_key) + return invocation_targets + + @staticmethod + def _build_sort_key(target: _HookInvocationTarget) -> tuple[int, int, int, str, str]: + """构造 Hook 处理器的全局排序键。 + + Args: + target: 待排序的处理器目标。 + + Returns: + tuple[int, int, int, str, str]: 全局排序键。 + """ + + return ( + HookDispatcher._get_mode_rank(target.entry.mode), + HookDispatcher._get_order_rank(target.entry.order), + target.source_rank, + target.entry.plugin_id, + target.entry.name, + ) + + @staticmethod + def _get_default_timeout_ms() -> int: + """读取系统级默认 Hook 超时。 + + Returns: + int: 默认超时毫秒数。 + """ + + timeout_seconds = float(global_config.plugin_runtime.hook_blocking_timeout_sec or 30.0) + return max(int(timeout_seconds * 1000), 1) + + @staticmethod + def _get_mode_rank(mode: str) -> int: + """返回 Hook 模式的排序权重。 + + Args: + mode: Hook 模式。 + + Returns: + int: 越小表示越靠前。 + """ + + return {"blocking": 0, "observe": 1}.get(mode, 99) + + @staticmethod + def _get_order_rank(order: str) -> int: + """返回 Hook 顺序槽位的排序权重。 + + Args: + order: Hook 顺序槽位。 + + Returns: + int: 越小表示越靠前。 + """ + + return {"early": 0, "normal": 1, "late": 2}.get(order, 99) + + @staticmethod + def _get_supervisor_source_rank(supervisor: "PluginRunnerSupervisor") -> int: + """返回 Supervisor 的来源排序权重。 + + Args: + supervisor: 目标 Supervisor。 + + Returns: + int: 内置插件返回 `0`,第三方插件返回 `1`。 + """ + + return 0 if supervisor.group_name == "builtin" else 1 + + @staticmethod + def _normalize_hook_name(hook_name: str) -> str: + """规范化命名 Hook 名称。 + + Args: + hook_name: 原始 Hook 名称。 + + Returns: + str: 规范化后的 Hook 名称。 + + Raises: + ValueError: 当 Hook 名称为空时抛出。 + """ + + normalized_name = str(hook_name or "").strip() + if not normalized_name: + raise ValueError("Hook 名称不能为空") + return normalized_name + + def _resolve_timeout_ms(self, hook_spec: HookSpec, target: _HookInvocationTarget) -> int: + """计算单个处理器的实际超时。 + + Args: + hook_spec: 当前 Hook 的规格定义。 + target: 当前执行目标。 + + Returns: + int: 最终生效的超时毫秒数。 + """ + + if target.entry.timeout_ms > 0: + return target.entry.timeout_ms + if hook_spec.default_timeout_ms > 0: + return hook_spec.default_timeout_ms + return self._get_default_timeout_ms() async def _invoke_handler( self, - supervisor: "PluginRunnerSupervisor", - handler_entry: "HookHandlerEntry", - args: Dict[str, Any], - ) -> Optional[HookResult]: - """调用单个 handler 并收集结果。 + hook_name: str, + hook_spec: HookSpec, + target: _HookInvocationTarget, + kwargs: Dict[str, Any], + ) -> HookHandlerExecutionResult: + """执行单个 Hook 处理器。 Args: - supervisor: PluginRunnerSupervisor 实例 - handler_entry: HookHandlerEntry 实例 - args: 传递给 handler 的参数字典 - stage: hook 名称 + hook_name: 当前 Hook 名称。 + hook_spec: 当前 Hook 规格。 + target: 当前执行目标。 + kwargs: 当前参数字典。 Returns: - Optional[HookResult]: 执行结果,如果执行失败则返回 None + HookHandlerExecutionResult: 处理器执行结果。 """ - try: - resp_envelope = await supervisor.invoke_plugin( - "plugin.invoke_hook", handler_entry.plugin_id, handler_entry.name, args - ) - resp = resp_envelope.payload - result = HookResult( - handler_name=handler_entry.full_name, - success=resp.get("success", True), - continue_processing=resp.get("continue_processing", True), - modified_kwargs=resp.get("modified_kwargs"), - custom_result=resp.get("custom_result"), - ) - except Exception as e: - logger.error(f"HookHandler {handler_entry.full_name} 执行失败:{e}", exc_info=True) - result = HookResult(handler_name=handler_entry.full_name, success=False, continue_processing=True) - return result + timeout_ms = self._resolve_timeout_ms(hook_spec, target) + request_args: Dict[str, Any] = {"hook_name": hook_name, **dict(kwargs)} + + try: + response_envelope = await asyncio.wait_for( + target.supervisor.invoke_plugin( + "plugin.invoke_hook", + target.entry.plugin_id, + target.entry.name, + request_args, + timeout_ms=timeout_ms, + ), + timeout=max(timeout_ms / 1000.0, 0.001), + ) + except asyncio.TimeoutError: + error_message = ( + f"HookHandler {target.entry.full_name} 执行超时,已超过 {timeout_ms}ms" + ) + logger.error(error_message) + return HookHandlerExecutionResult( + handler_name=target.entry.full_name, + plugin_id=target.entry.plugin_id, + success=False, + error_message=error_message, + ) + except Exception as exc: + error_message = f"HookHandler {target.entry.full_name} 执行失败: {exc}" + logger.error(error_message, exc_info=True) + return HookHandlerExecutionResult( + handler_name=target.entry.full_name, + plugin_id=target.entry.plugin_id, + success=False, + error_message=error_message, + ) + + response_payload = response_envelope.payload + if not isinstance(response_payload, dict): + return HookHandlerExecutionResult( + handler_name=target.entry.full_name, + plugin_id=target.entry.plugin_id, + custom_result=response_payload, + ) + + return HookHandlerExecutionResult( + handler_name=target.entry.full_name, + plugin_id=target.entry.plugin_id, + success=bool(response_payload.get("success", True)), + action=self._normalize_action(response_payload.get("action", "continue")), + modified_kwargs=self._extract_modified_kwargs(response_payload.get("modified_kwargs")), + custom_result=response_payload.get("custom_result"), + error_message=str(response_payload.get("error_message", "") or ""), + ) + + @staticmethod + def _extract_modified_kwargs(raw_value: Any) -> Optional[Dict[str, Any]]: + """提取并校验处理器返回的 `modified_kwargs`。 + + Args: + raw_value: 原始返回值。 + + Returns: + Optional[Dict[str, Any]]: 合法时返回字典,否则返回 `None`。 + """ + + if raw_value is None: + return None + if isinstance(raw_value, dict): + return dict(raw_value) + logger.warning("HookHandler 返回的 modified_kwargs 不是字典,已忽略") + return None + + @staticmethod + def _normalize_action(raw_value: Any) -> str: + """规范化处理器动作返回值。 + + Args: + raw_value: 原始动作值。 + + Returns: + str: 规范化后的动作值,仅支持 `continue` 或 `abort`。 + """ + + normalized_value = str(raw_value or "").strip().lower() or "continue" + if normalized_value not in {"continue", "abort"}: + logger.warning(f"未知的 Hook action: {raw_value},已按 continue 处理") + return "continue" + return normalized_value + + def _merge_blocking_result( + self, + hook_spec: HookSpec, + target: _HookInvocationTarget, + execution_result: HookHandlerExecutionResult, + dispatch_result: HookDispatchResult, + ) -> None: + """合并阻塞处理器结果到聚合结果。 + + Args: + hook_spec: 当前 Hook 规格。 + target: 当前执行目标。 + execution_result: 当前处理器执行结果。 + dispatch_result: 当前聚合结果对象。 + """ + + if execution_result.custom_result is not None: + dispatch_result.custom_results.append(execution_result.custom_result) + + if not execution_result.success: + error_message = execution_result.error_message or f"HookHandler {target.entry.full_name} 执行失败" + dispatch_result.errors.append(error_message) + self._apply_error_policy(target, hook_spec, dispatch_result, error_message) + return + + if execution_result.modified_kwargs is not None: + if hook_spec.allow_kwargs_mutation: + dispatch_result.kwargs = dict(execution_result.modified_kwargs) + else: + error_message = ( + f"Hook {dispatch_result.hook_name} 不允许修改 kwargs," + f"已忽略 {target.entry.full_name} 的 modified_kwargs" + ) + logger.warning(error_message) + dispatch_result.errors.append(error_message) + + if execution_result.action == "abort": + if hook_spec.allow_abort: + dispatch_result.aborted = True + dispatch_result.stopped_by = target.entry.full_name + logger.info(f"HookHandler {target.entry.full_name} 中止了 Hook {dispatch_result.hook_name}") + else: + error_message = ( + f"Hook {dispatch_result.hook_name} 不允许 abort," + f"已忽略 {target.entry.full_name} 的 abort 请求" + ) + logger.warning(error_message) + dispatch_result.errors.append(error_message) + + def _apply_error_policy( + self, + target: _HookInvocationTarget, + hook_spec: HookSpec, + dispatch_result: HookDispatchResult, + error_message: str, + ) -> None: + """根据错误策略处理阻塞处理器失败。 + + Args: + target: 触发错误的处理器目标。 + hook_spec: 当前 Hook 规格。 + dispatch_result: 当前聚合结果对象。 + error_message: 需要记录的错误描述。 + """ + + if target.entry.error_policy != "abort": + return + if not hook_spec.allow_abort: + logger.warning( + f"Hook {dispatch_result.hook_name} 禁止 abort," + f"已将 {target.entry.full_name} 的错误策略按 skip 处理" + ) + return + + dispatch_result.aborted = True + dispatch_result.stopped_by = target.entry.full_name + logger.warning( + f"HookHandler {target.entry.full_name} 因错误策略 abort " + f"中止了 Hook {dispatch_result.hook_name}: {error_message}" + ) + + def _schedule_observe_handler( + self, + hook_name: str, + hook_spec: HookSpec, + target: _HookInvocationTarget, + kwargs: Dict[str, Any], + ) -> None: + """后台调度观察型处理器。 + + Args: + hook_name: 当前 Hook 名称。 + hook_spec: 当前 Hook 规格。 + target: 当前观察型处理器目标。 + kwargs: 调用参数快照。 + """ + + if not hook_spec.allow_observe: + logger.warning(f"Hook {hook_name} 不允许 observe 处理器,已跳过 {target.entry.full_name}") + return + + task = asyncio.create_task( + self._run_observe_handler( + hook_name=hook_name, + hook_spec=hook_spec, + target=target, + kwargs=dict(kwargs), + ) + ) + self._background_tasks.add(task) + task.add_done_callback(self._handle_background_task_done) + + async def _run_observe_handler( + self, + hook_name: str, + hook_spec: HookSpec, + target: _HookInvocationTarget, + kwargs: Dict[str, Any], + ) -> None: + """执行观察型处理器并吞掉控制流副作用。 + + Args: + hook_name: 当前 Hook 名称。 + hook_spec: 当前 Hook 规格。 + target: 当前观察型处理器目标。 + kwargs: 调用参数快照。 + """ + + execution_result = await self._invoke_handler( + hook_name=hook_name, + hook_spec=hook_spec, + target=target, + kwargs=kwargs, + ) + + if not execution_result.success: + logger.warning( + f"观察型 HookHandler {target.entry.full_name} 执行失败: " + f"{execution_result.error_message or '未知错误'}" + ) + return + + if execution_result.modified_kwargs is not None: + logger.warning(f"观察型 HookHandler {target.entry.full_name} 返回了 modified_kwargs,已忽略") + if execution_result.action == "abort": + logger.warning(f"观察型 HookHandler {target.entry.full_name} 请求 abort,已忽略") + + def _handle_background_task_done(self, task: asyncio.Task[Any]) -> None: + """处理观察任务完成回调。 + + Args: + task: 已完成的后台任务。 + """ + + self._background_tasks.discard(task) + with contextlib.suppress(asyncio.CancelledError): + exception = task.exception() + if exception is not None: + logger.error(f"观察型 Hook 后台任务执行失败: {exception}") diff --git a/src/plugin_runtime/host/supervisor.py b/src/plugin_runtime/host/supervisor.py index 08638d16..c94fcb3f 100644 --- a/src/plugin_runtime/host/supervisor.py +++ b/src/plugin_runtime/host/supervisor.py @@ -49,7 +49,7 @@ from .api_registry import APIRegistry from .capability_service import CapabilityService from .component_registry import ComponentRegistry from .event_dispatcher import EventDispatcher -from .hook_dispatcher import HookDispatcher +from .hook_dispatcher import HookDispatchResult, HookDispatcher from .logger_bridge import RunnerLogBridge from .message_gateway import MessageGateway from .rpc_server import RPCServer @@ -80,6 +80,7 @@ class PluginRunnerSupervisor: def __init__( self, plugin_dirs: Optional[List[Path]] = None, + group_name: str = "third_party", socket_path: Optional[str] = None, health_check_interval_sec: Optional[float] = None, max_restart_attempts: Optional[int] = None, @@ -89,12 +90,14 @@ class PluginRunnerSupervisor: Args: plugin_dirs: 由当前 Runner 负责加载的插件目录列表。 + group_name: 当前 Supervisor 所属运行时分组名称。 socket_path: 自定义 IPC 地址;留空时由传输层自动生成。 health_check_interval_sec: 健康检查间隔,单位秒。 max_restart_attempts: 自动重启 Runner 的最大次数。 runner_spawn_timeout_sec: 等待 Runner 建连并就绪的超时时间,单位秒。 """ runtime_config = global_config.plugin_runtime + self._group_name: str = str(group_name or "third_party").strip() or "third_party" self._plugin_dirs: List[Path] = plugin_dirs or [] self._health_interval: float = health_check_interval_sec or runtime_config.health_check_interval_sec or 30.0 self._runner_spawn_timeout: float = ( @@ -108,7 +111,7 @@ class PluginRunnerSupervisor: self._api_registry = APIRegistry() self._component_registry = ComponentRegistry() self._event_dispatcher = EventDispatcher(self._component_registry) - self._hook_dispatcher = HookDispatcher(self._component_registry) + self._hook_dispatcher = HookDispatcher(lambda: [self]) self._message_gateway = MessageGateway(self._component_registry) self._log_bridge = RunnerLogBridge() @@ -133,6 +136,12 @@ class PluginRunnerSupervisor: """返回授权管理器。""" return self._authorization + @property + def group_name(self) -> str: + """返回当前 Supervisor 的运行时分组名称。""" + + return self._group_name + @property def capability_service(self) -> CapabilityService: """返回能力服务。""" @@ -243,17 +252,18 @@ class PluginRunnerSupervisor: """ return await self._event_dispatcher.dispatch_event(event_type, self, message, extra_args) - async def dispatch_hook(self, stage: str, **kwargs: Any) -> Dict[str, Any]: - """分发 Hook 到已注册的 Hook 处理器。 + async def invoke_hook(self, hook_name: str, **kwargs: Any) -> HookDispatchResult: + """在当前 Supervisor 内触发一次命名 Hook 调用。 Args: - stage: Hook 阶段名称。 - **kwargs: 传递给 Hook 的关键字参数。 + hook_name: 本次触发的 Hook 名称。 + **kwargs: 传递给 Hook 处理器的关键字参数。 Returns: - Dict[str, Any]: 经 Hook 修改后的参数字典。 + HookDispatchResult: 聚合后的 Hook 调用结果。 """ - return await self._hook_dispatcher.hook_dispatch(stage, self, **kwargs) + + return await self._hook_dispatcher.invoke_hook(hook_name, **kwargs) async def send_message_to_external( self, diff --git a/src/plugin_runtime/integration.py b/src/plugin_runtime/integration.py index c34f5ef5..264c8ed2 100644 --- a/src/plugin_runtime/integration.py +++ b/src/plugin_runtime/integration.py @@ -3,8 +3,9 @@ 提供 PluginRuntimeManager 单例,负责: 1. 管理双 PluginSupervisor 的生命周期(内置插件 / 第三方插件各一个子进程) 2. 将 EventType 桥接到运行时的 event dispatch -3. 在运行时的 ComponentRegistry 中查找命令 -4. 提供统一的能力实现注册接口,使插件可以调用主程序功能 +3. 触发跨 Supervisor 的命名 Hook 调用 +4. 在运行时的 ComponentRegistry 中查找命令 +5. 提供统一的能力实现注册接口,使插件可以调用主程序功能 """ from pathlib import Path @@ -24,6 +25,7 @@ from src.plugin_runtime.capabilities import ( RuntimeDataCapabilityMixin, ) from src.plugin_runtime.capabilities.registry import register_capability_impls +from src.plugin_runtime.host.hook_dispatcher import HookDispatchResult, HookDispatcher, HookSpec from src.plugin_runtime.host.message_utils import MessageDict, PluginMessageUtils from src.plugin_runtime.runner.manifest_validator import ManifestValidator @@ -72,6 +74,7 @@ class PluginRuntimeManager( self._manifest_validator: ManifestValidator = ManifestValidator() self._config_reload_callback: Callable[[Sequence[str]], Awaitable[None]] = self._handle_main_config_reload self._config_reload_callback_registered: bool = False + self._hook_dispatcher: HookDispatcher = HookDispatcher(lambda: self.supervisors) async def _dispatch_platform_inbound(self, envelope: InboundMessageEnvelope) -> None: """接收 Platform IO 审核后的入站消息并送入主消息链。 @@ -182,6 +185,7 @@ class PluginRuntimeManager( if builtin_dirs: self._builtin_supervisor = PluginSupervisor( plugin_dirs=builtin_dirs, + group_name="builtin", socket_path=builtin_socket, ) self._register_capability_impls(self._builtin_supervisor) @@ -189,6 +193,7 @@ class PluginRuntimeManager( if third_party_dirs: self._third_party_supervisor = PluginSupervisor( plugin_dirs=third_party_dirs, + group_name="third_party", socket_path=third_party_socket, ) self._register_capability_impls(self._third_party_supervisor) @@ -235,6 +240,7 @@ class PluginRuntimeManager( await platform_io_manager.stop() except Exception as platform_io_exc: logger.warning(f"Platform IO 停止失败: {platform_io_exc}") + await self._hook_dispatcher.stop() self._started = False self._builtin_supervisor = None self._third_party_supervisor = None @@ -274,6 +280,7 @@ class PluginRuntimeManager( else: logger.info("插件运行时已停止") finally: + await self._hook_dispatcher.stop() self._started = False self._builtin_supervisor = None self._third_party_supervisor = None @@ -284,11 +291,41 @@ class PluginRuntimeManager( """返回插件运行时是否处于启动状态。""" return self._started + @property + def hook_dispatcher(self) -> HookDispatcher: + """返回跨 Supervisor 的命名 Hook 分发器。""" + + return self._hook_dispatcher + + @property + def invoke_dispatcher(self) -> HookDispatcher: + """返回命名 Hook 分发器的兼容别名。""" + + return self._hook_dispatcher + @property def supervisors(self) -> List["PluginSupervisor"]: """获取所有活跃的 Supervisor""" return [s for s in (self._builtin_supervisor, self._third_party_supervisor) if s is not None] + def register_hook_spec(self, spec: HookSpec) -> None: + """注册单个命名 Hook 规格。 + + Args: + spec: 需要注册的 Hook 规格。 + """ + + self._hook_dispatcher.register_hook_spec(spec) + + def register_hook_specs(self, specs: Sequence[HookSpec]) -> None: + """批量注册命名 Hook 规格。 + + Args: + specs: 需要注册的 Hook 规格序列。 + """ + + self._hook_dispatcher.register_hook_specs(specs) + def _build_registered_dependency_map(self) -> Dict[str, Set[str]]: """根据当前已注册插件构建全局依赖图。""" @@ -588,6 +625,19 @@ class PluginRuntimeManager( return True, modified + async def invoke_hook(self, hook_name: str, **kwargs: Any) -> HookDispatchResult: + """触发一次跨 Supervisor 的命名 Hook 调用。 + + Args: + hook_name: 本次触发的 Hook 名称。 + **kwargs: 传递给 Hook 处理器的关键字参数。 + + Returns: + HookDispatchResult: 聚合后的 Hook 调用结果。 + """ + + return await self._hook_dispatcher.invoke_hook(hook_name, **kwargs) + # ─── 命令查找 ────────────────────────────────────────────── def find_command_by_text(self, text: str) -> Optional[Dict[str, Any]]: diff --git a/src/plugin_runtime/runner/log_handler.py b/src/plugin_runtime/runner/log_handler.py index 6f42940f..03f2db4d 100644 --- a/src/plugin_runtime/runner/log_handler.py +++ b/src/plugin_runtime/runner/log_handler.py @@ -164,7 +164,7 @@ class RunnerIPCLogHandler(logging.Handler): return f"{event_text} {' '.join(extras)}".strip() return event_text - # format() 会处理 %s 参数替换和 exc_info 文本拼接。 + # format() 会处理占位参数替换和 exc_info 文本拼接。 return self.format(record) @staticmethod diff --git a/src/plugin_runtime/runner/runner_main.py b/src/plugin_runtime/runner/runner_main.py index d1ebc064..9de5d977 100644 --- a/src/plugin_runtime/runner/runner_main.py +++ b/src/plugin_runtime/runner/runner_main.py @@ -330,7 +330,6 @@ class PluginRunner: self._rpc_client.register_method("plugin.invoke_message_gateway", self._handle_invoke) self._rpc_client.register_method("plugin.emit_event", self._handle_event_invoke) self._rpc_client.register_method("plugin.invoke_hook", self._handle_hook_invoke) - self._rpc_client.register_method("plugin.invoke_workflow_step", self._handle_workflow_step) self._rpc_client.register_method("plugin.health", self._handle_health) self._rpc_client.register_method("plugin.prepare_shutdown", self._handle_prepare_shutdown) self._rpc_client.register_method("plugin.shutdown", self._handle_shutdown) @@ -1053,73 +1052,28 @@ class PluginRunner: ) except Exception as exc: logger.error(f"插件 {plugin_id} hook_handler {component_name} 执行异常: {exc}", exc_info=True) - return envelope.make_response(payload={"success": False, "continue_processing": True}) + return envelope.make_response( + payload={ + "success": False, + "action": "continue", + "error_message": str(exc), + } + ) if raw is None: - result = {"success": True, "continue_processing": True} + result = {"success": True, "action": "continue"} elif isinstance(raw, dict): result = { "success": True, - "continue_processing": raw.get("continue_processing", True), + "action": str(raw.get("action", "continue") or "continue").strip().lower() or "continue", "modified_kwargs": raw.get("modified_kwargs"), "custom_result": raw.get("custom_result"), } else: - result = {"success": True, "continue_processing": True, "custom_result": raw} + result = {"success": True, "action": "continue", "custom_result": raw} return envelope.make_response(payload=result) - async def _handle_workflow_step(self, envelope: Envelope) -> Envelope: - """处理 WorkflowStep 调用请求 - - 与通用 invoke 不同,会将返回值规范化为 - {hook_result, modified_message, stage_output} 格式。 - """ - try: - invoke = InvokePayload.model_validate(envelope.payload) - except Exception as e: - return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(e)) - - plugin_id = envelope.plugin_id - meta = self._loader.get_plugin(plugin_id) - if meta is None: - return envelope.make_error_response( - ErrorCode.E_PLUGIN_NOT_FOUND.value, - f"插件 {plugin_id} 未加载", - ) - - component_name = invoke.component_name - handler_method = self._resolve_component_handler(meta, component_name) - - if handler_method is None or not callable(handler_method): - return envelope.make_error_response( - ErrorCode.E_METHOD_NOT_ALLOWED.value, - f"插件 {plugin_id} 无组件: {component_name}", - ) - - try: - raw = ( - await handler_method(**invoke.args) - if inspect.iscoroutinefunction(handler_method) - else handler_method(**invoke.args) - ) - - # 规范化返回值 - if isinstance(raw, str): - result = {"hook_result": raw} - elif isinstance(raw, dict): - result = raw - result.setdefault("hook_result", "continue") - else: - result = {"hook_result": "continue"} - - resp_payload = InvokeResultPayload(success=True, result=result) - return envelope.make_response(payload=resp_payload.model_dump()) - except Exception as e: - logger.error(f"插件 {plugin_id} workflow_step {component_name} 执行异常: {e}", exc_info=True) - resp_payload = InvokeResultPayload(success=False, result=str(e)) - return envelope.make_response(payload=resp_payload.model_dump()) - async def _handle_health(self, envelope: Envelope) -> Envelope: """处理健康检查""" uptime_ms = int((time.monotonic() - self._start_time) * 1000) diff --git a/src/plugin_runtime/tool_provider.py b/src/plugin_runtime/tool_provider.py new file mode 100644 index 00000000..84bed06e --- /dev/null +++ b/src/plugin_runtime/tool_provider.py @@ -0,0 +1,48 @@ +"""插件运行时工具 Provider。""" + +from __future__ import annotations + +from typing import Optional + +from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolProvider, ToolSpec + +from .component_query import component_query_service + + +class PluginToolProvider(ToolProvider): + """将插件 Tool 与兼容旧 Action 暴露为统一工具 Provider。""" + + provider_name = "plugin_runtime" + provider_type = "plugin" + + async def list_tools(self) -> list[ToolSpec]: + """列出插件运行时当前可用的工具声明。""" + + return list(component_query_service.get_llm_available_tool_specs().values()) + + async def invoke( + self, + invocation: ToolInvocation, + context: Optional[ToolExecutionContext] = None, + ) -> ToolExecutionResult: + """执行插件工具或兼容旧 Action 的工具调用。 + + Args: + invocation: 工具调用请求。 + context: 执行上下文。 + + Returns: + ToolExecutionResult: 工具执行结果。 + """ + + return await component_query_service.invoke_tool_as_tool( + invocation=invocation, + context=context, + ) + + async def close(self) -> None: + """关闭 Provider。 + + 插件运行时工具 Provider 不持有独立资源。 + """ + diff --git a/src/plugins/built_in/plugin_management/plugin.py b/src/plugins/built_in/plugin_management/plugin.py index aa2da795..00e7578c 100644 --- a/src/plugins/built_in/plugin_management/plugin.py +++ b/src/plugins/built_in/plugin_management/plugin.py @@ -6,7 +6,7 @@ from maibot_sdk import Command, MaiBotPlugin -_VALID_COMPONENT_TYPES = ("action", "command", "event_handler") +_VALID_COMPONENT_TYPES = ("tool", "command", "event_handler") HELP_ALL = ( "管理命令帮助\n" @@ -37,7 +37,7 @@ HELP_COMPONENT = ( "/pm component enable local 本聊天启用组件\n" "/pm component disable global 全局禁用组件\n" "/pm component disable local 本聊天禁用组件\n" - " - 可选项: action, command, event_handler\n" + " - 可选项: tool, command, event_handler\n" ) diff --git a/src/services/database_service.py b/src/services/database_service.py index 5b8b716f..5e41f2c6 100644 --- a/src/services/database_service.py +++ b/src/services/database_service.py @@ -4,14 +4,14 @@ import json import time import traceback from datetime import datetime -from typing import Any, Optional +from typing import Any, Optional, cast from sqlalchemy import delete, func, select from sqlmodel import SQLModel from src.chat.message_receive.chat_manager import BotChatSession from src.common.database.database import get_db_session -from src.common.database.database_model import ActionRecord +from src.common.database.database_model import ToolRecord from src.common.logger import get_logger logger = get_logger("database_service") @@ -65,7 +65,7 @@ async def db_save( record = None if key_field and key_value is not None: key_column = _get_model_field(model_class, key_field) - record = session.exec(select(model_class).where(key_column == key_value)).first() + record = session.exec(cast(Any, select(model_class).where(key_column == key_value))).first() if record is None: record = model_class(**data) @@ -99,7 +99,7 @@ async def db_get( statement = _apply_order_by(statement, model_class, order_by) if limit: statement = statement.limit(limit) - results = session.exec(statement).all() + results = session.exec(cast(Any, statement)).all() data = [_to_dict(item) for item in results] if single_result: return data[0] if data else None @@ -116,7 +116,7 @@ async def db_update(model_class: type[SQLModel], data: dict[str, Any], filters: statement = select(model_class) if conditions := _build_filters(model_class, filters): statement = statement.where(*conditions) - records = session.exec(statement).all() + records = session.exec(cast(Any, statement)).all() for record in records: for field_name, value in data.items(): _get_model_field(model_class, field_name) @@ -149,7 +149,7 @@ async def db_count(model_class: type[SQLModel], filters: Optional[dict[str, Any] statement = select(func.count()).select_from(model_class) if conditions := _build_filters(model_class, filters): statement = statement.where(*conditions) - result = session.exec(statement).one() + result = session.exec(cast(Any, statement)).one() return int(result or 0) except Exception as e: logger.error(f"[DatabaseService] 统计数据库记录出错: {e}") @@ -157,6 +157,39 @@ async def db_count(model_class: type[SQLModel], filters: Optional[dict[str, Any] return 0 +async def store_tool_info( + chat_stream: BotChatSession, + builtin_prompt: Optional[str] = None, + display_prompt: str = "", + tool_id: str = "", + tool_data: Optional[dict[str, Any]] = None, + tool_name: str = "", + tool_reasoning: str = "", +) -> Optional[dict[str, Any]]: + try: + record_data = { + "tool_id": tool_id or str(int(time.time() * 1000000)), + "timestamp": datetime.now(), + "session_id": chat_stream.session_id, + "tool_name": tool_name, + "tool_data": json.dumps(tool_data or {}, ensure_ascii=False), + "tool_reasoning": tool_reasoning, + "tool_builtin_prompt": builtin_prompt, + "tool_display_prompt": display_prompt, + } + + saved_record = await db_save(ToolRecord, data=record_data, key_field="tool_id", key_value=record_data["tool_id"]) + if saved_record: + logger.debug(f"[DatabaseService] 成功存储工具信息: {tool_name} (ID: {record_data['tool_id']})") + else: + logger.error(f"[DatabaseService] 存储工具信息失败: {tool_name}") + return saved_record + except Exception as e: + logger.error(f"[DatabaseService] 存储工具信息时发生错误: {e}") + traceback.print_exc() + return None + + async def store_action_info( chat_stream: BotChatSession, builtin_prompt: Optional[str] = None, @@ -166,27 +199,13 @@ async def store_action_info( action_name: str = "", action_reasoning: str = "", ) -> Optional[dict[str, Any]]: - try: - record_data = { - "action_id": thinking_id or str(int(time.time() * 1000000)), - "timestamp": datetime.now(), - "session_id": chat_stream.session_id, - "action_name": action_name, - "action_data": json.dumps(action_data or {}, ensure_ascii=False), - "action_reasoning": action_reasoning, - "action_builtin_prompt": builtin_prompt, - "action_display_prompt": display_prompt, - } - - saved_record = await db_save( - ActionRecord, data=record_data, key_field="action_id", key_value=record_data["action_id"] - ) - if saved_record: - logger.debug(f"[DatabaseService] 成功存储动作信息: {action_name} (ID: {record_data['action_id']})") - else: - logger.error(f"[DatabaseService] 存储动作信息失败: {action_name}") - return saved_record - except Exception as e: - logger.error(f"[DatabaseService] 存储动作信息时发生错误: {e}") - traceback.print_exc() - return None + """兼容旧接口,内部转发到 ``store_tool_info``。""" + return await store_tool_info( + chat_stream=chat_stream, + builtin_prompt=builtin_prompt, + display_prompt=display_prompt, + tool_id=thinking_id, + tool_data=action_data, + tool_name=action_name, + tool_reasoning=action_reasoning, + ) diff --git a/src/services/llm_service.py b/src/services/llm_service.py index 2927b5c1..de116507 100644 --- a/src/services/llm_service.py +++ b/src/services/llm_service.py @@ -1,191 +1,492 @@ -"""LLM 服务模块 +"""LLM 服务层。 -提供与 LLM 模型交互的核心功能。 +该模块负责在宿主侧收口统一的 LLM 服务请求模型,并将其转发到 +`src.llm_models` 中的底层请求调度器。 """ -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Tuple +import json + +from src.common.data_models.llm_service_data_models import ( + LLMAudioTranscriptionResult, + LLMEmbeddingResult, + LLMGenerationOptions, + LLMImageOptions, + LLMResponseResult, + LLMServiceRequest, + LLMServiceResult, + MessageFactory, + PromptInput, + PromptMessage, +) from src.common.logger import get_logger from src.config.config import config_manager from src.config.model_configs import TaskConfig from src.llm_models.model_client.base_client import BaseClient -from src.llm_models.payload_content.message import Message +from src.llm_models.payload_content.message import Message, MessageBuilder, RoleType from src.llm_models.payload_content.tool_option import ToolCall -from src.llm_models.utils_model import LLMRequest +from src.llm_models.utils_model import LLMOrchestrator logger = get_logger("llm_service") +class LLMServiceClient: + """面向上层模块的 LLM 服务对象式门面。 -async def _generate_response( - model_config: TaskConfig, - request_type: str, - prompt: Optional[str] = None, - message_factory: Optional[Callable[[BaseClient], List[Message]]] = None, - tool_options: Optional[List[Dict[str, Any]]] = None, - temperature: Optional[float] = None, - max_tokens: Optional[int] = None, -) -> Tuple[str, str, str, List[ToolCall] | None]: - llm_request = LLMRequest(model_set=model_config, request_type=request_type) + 当前推荐优先使用以下正式接口: + - `generate_response` + - `generate_response_with_messages` + - `generate_response_for_image` + - `transcribe_audio` + - `embed_text` + """ - if message_factory is not None: - response, (reasoning_content, model_name, tool_call) = await llm_request.generate_response_with_message_async( - message_factory=message_factory, - tools=tool_options, - temperature=temperature, - max_tokens=max_tokens, + def __init__(self, task_name: str, request_type: str = "") -> None: + """初始化 LLM 服务门面。 + + Args: + task_name: 任务配置名称,对应 `model_task_config` 下的字段名。 + request_type: 当前请求的业务类型标识。 + """ + self.task_name = resolve_task_name(task_name) + self.request_type = request_type + self._orchestrator = LLMOrchestrator(task_name=self.task_name, request_type=request_type) + + @staticmethod + def _normalize_generation_options(options: LLMGenerationOptions | None = None) -> LLMGenerationOptions: + """规范化文本生成选项。 + + Args: + options: 原始生成选项。 + + Returns: + LLMGenerationOptions: 可直接用于执行请求的完整选项对象。 + """ + if options is None: + return LLMGenerationOptions() + return options + + @staticmethod + def _normalize_image_options(options: LLMImageOptions | None = None) -> LLMImageOptions: + """规范化图像理解选项。 + + Args: + options: 原始图像理解选项。 + + Returns: + LLMImageOptions: 可直接用于执行请求的完整选项对象。 + """ + if options is None: + return LLMImageOptions() + return options + + async def generate_response( + self, + prompt: str, + options: LLMGenerationOptions | None = None, + ) -> LLMResponseResult: + """生成单轮文本响应。 + + Args: + prompt: 文本提示词。 + options: 文本生成选项。 + + Returns: + LLMResponseResult: 统一文本生成结果。 + """ + active_options = self._normalize_generation_options(options) + return await self._orchestrator.generate_response_async( + prompt=prompt, + temperature=active_options.temperature, + max_tokens=active_options.max_tokens, + tools=active_options.tool_options, + response_format=active_options.response_format, + raise_when_empty=active_options.raise_when_empty, + interrupt_flag=active_options.interrupt_flag, ) - return response, reasoning_content, model_name, tool_call - if prompt is None: - raise ValueError("prompt 与 message_factory 不能同时为空") + async def generate_response_with_messages( + self, + message_factory: MessageFactory, + options: LLMGenerationOptions | None = None, + ) -> LLMResponseResult: + """基于消息工厂生成响应。 - response, (reasoning_content, model_name, tool_call) = await llm_request.generate_response_async( - prompt, - tools=tool_options, - temperature=temperature, - max_tokens=max_tokens, - ) - return response, reasoning_content, model_name, tool_call + Args: + message_factory: 消息工厂,会根据客户端能力构建消息列表。 + options: 文本生成选项。 + + Returns: + LLMResponseResult: 统一文本生成结果。 + """ + active_options = self._normalize_generation_options(options) + return await self._orchestrator.generate_response_with_message_async( + message_factory=message_factory, + temperature=active_options.temperature, + max_tokens=active_options.max_tokens, + tools=active_options.tool_options, + response_format=active_options.response_format, + raise_when_empty=active_options.raise_when_empty, + interrupt_flag=active_options.interrupt_flag, + ) + + async def generate_response_for_image( + self, + prompt: str, + image_base64: str, + image_format: str, + options: LLMImageOptions | None = None, + ) -> LLMResponseResult: + """为图像内容生成响应。 + + Args: + prompt: 文本提示词。 + image_base64: 图像的 Base64 编码字符串。 + image_format: 图像格式,例如 ``png``、``jpeg``。 + options: 图像理解选项。 + + Returns: + LLMResponseResult: 统一文本生成结果。 + """ + active_options = self._normalize_image_options(options) + return await self._orchestrator.generate_response_for_image( + prompt=prompt, + image_base64=image_base64, + image_format=image_format, + temperature=active_options.temperature, + max_tokens=active_options.max_tokens, + interrupt_flag=active_options.interrupt_flag, + ) + + async def transcribe_audio(self, voice_base64: str) -> LLMAudioTranscriptionResult: + """执行音频转写请求。 + + Args: + voice_base64: 音频的 Base64 编码字符串。 + + Returns: + LLMAudioTranscriptionResult: 音频转写结果对象。 + """ + return await self._orchestrator.generate_response_for_voice(voice_base64) + + async def embed_text(self, embedding_input: str) -> LLMEmbeddingResult: + """生成文本嵌入向量。 + + Args: + embedding_input: 待编码的文本。 + + Returns: + LLMEmbeddingResult: 向量生成结果对象。 + """ + return await self._orchestrator.get_embedding(embedding_input) def get_available_models() -> Dict[str, TaskConfig]: - """获取所有可用的模型配置 + """获取所有可用模型配置。 Returns: - Dict[str, Any]: 模型配置字典,key为模型名称,value为模型配置 + Dict[str, TaskConfig]: 以模型任务名为键的配置映射。 """ try: models = config_manager.get_model_config().model_task_config - attrs = dir(models) - rets: Dict[str, TaskConfig] = {} - for attr in attrs: - if not attr.startswith("__"): - try: - value = getattr(models, attr) - if not callable(value) and isinstance(value, TaskConfig): - rets[attr] = value - except Exception as e: - logger.debug(f"[LLMService] 获取属性 {attr} 失败: {e}") - continue - return rets - - except Exception as e: - logger.error(f"[LLMService] 获取可用模型失败: {e}") + available_models: Dict[str, TaskConfig] = {} + for attr_name in dir(models): + if attr_name.startswith("__"): + continue + try: + attr_value = getattr(models, attr_name) + except Exception as exc: + logger.debug(f"[LLMService] 获取属性 {attr_name} 失败: {exc}") + continue + if not callable(attr_value) and isinstance(attr_value, TaskConfig): + available_models[attr_name] = attr_value + return available_models + except Exception as exc: + logger.error(f"[LLMService] 获取可用模型失败: {exc}") return {} -async def generate_with_model( - prompt: str, - model_config: TaskConfig, - request_type: str = "plugin.generate", - temperature: Optional[float] = None, - max_tokens: Optional[int] = None, -) -> Tuple[bool, str, str, str]: - """使用指定模型生成内容 +def resolve_task_name(task_name: str = "") -> str: + """根据名称解析任务配置名。 Args: - prompt: 提示词 - model_config: 模型配置(从 get_available_models 获取的模型配置) - request_type: 请求类型标识 + task_name: 目标任务配置名;为空时返回首个可用任务名。 Returns: - Tuple[bool, str, str, str]: (是否成功, 生成的内容, 推理过程, 模型名称) + str: 解析得到的任务配置名。 + + Raises: + RuntimeError: 当前没有任何可用模型配置。 + ValueError: 指定名称不存在时抛出。 """ - try: - logger.debug(f"[LLMService] 完整提示词: {prompt}") - response, reasoning_content, model_name, _ = await _generate_response( - model_config=model_config, - request_type=request_type, - prompt=prompt, - temperature=temperature, - max_tokens=max_tokens, - ) - return True, response, reasoning_content, model_name - - except Exception as e: - error_msg = f"生成内容时出错: {str(e)}" - logger.error(f"[LLMService] {error_msg}") - return False, error_msg, "", "" + models = get_available_models() + if not models: + raise RuntimeError("没有可用的模型配置") + normalized_task_name = task_name.strip() + if not normalized_task_name: + return next(iter(models.keys())) + if normalized_task_name not in models: + raise ValueError(f"未找到名为 `{normalized_task_name}` 的模型配置") + return normalized_task_name -async def generate_with_model_with_tools( - prompt: str, - model_config: TaskConfig, - tool_options: List[Dict[str, Any]] | None = None, - request_type: str = "plugin.generate", - temperature: Optional[float] = None, - max_tokens: Optional[int] = None, -) -> Tuple[bool, str, str, str, List[ToolCall] | None]: - """使用指定模型和工具生成内容 +def _normalize_role(role_name: str) -> RoleType: + """将原始角色字符串转换为内部角色枚举。 Args: - prompt: 提示词 - model_config: 模型配置(从 get_available_models 获取的模型配置) - tool_options: 工具选项列表 - request_type: 请求类型标识 - temperature: 温度参数 - max_tokens: 最大token数 + role_name: 原始角色名称。 Returns: - Tuple[bool, str, str, str, List[ToolCall] | None]: (是否成功, 生成的内容, 推理过程, 模型名称, 工具调用列表) + RoleType: 规范化后的角色枚举。 + + Raises: + ValueError: 角色类型不受支持时抛出。 """ + normalized_role_name = role_name.strip().lower() try: - model_name_list = model_config.model_list - logger.info(f"使用模型{model_name_list}生成内容") - logger.debug(f"完整提示词: {prompt}") - - response, reasoning_content, model_name, tool_call = await _generate_response( - model_config=model_config, - request_type=request_type, - prompt=prompt, - tool_options=tool_options, - temperature=temperature, - max_tokens=max_tokens, - ) - return True, response, reasoning_content, model_name, tool_call - - except Exception as e: - error_msg = f"生成内容时出错: {str(e)}" - logger.error(f"[LLMService] {error_msg}") - return False, error_msg, "", "", None + return RoleType(normalized_role_name) + except ValueError as exc: + raise ValueError(f"不支持的消息角色: {role_name}") from exc -async def generate_with_model_with_tools_by_message_factory( - message_factory: Callable[[BaseClient], List[Message]], - model_config: TaskConfig, - tool_options: List[Dict[str, Any]] | None = None, - request_type: str = "plugin.generate", - temperature: Optional[float] = None, - max_tokens: Optional[int] = None, -) -> Tuple[bool, str, str, str, List[ToolCall] | None]: - """使用指定模型和工具生成内容(通过消息工厂构建消息列表) +def _parse_data_url_image(image_url: str) -> Tuple[str, str]: + """解析 Data URL 形式的图片内容。 Args: - message_factory: 消息工厂函数 - model_config: 模型配置 - tool_options: 工具选项列表 - request_type: 请求类型标识 - temperature: 温度参数 - max_tokens: 最大token数 + image_url: 图片 URL。 Returns: - Tuple[bool, str, str, str, List[ToolCall] | None]: (是否成功, 生成的内容, 推理过程, 模型名称, 工具调用列表) + Tuple[str, str]: `(图片格式, Base64 数据)`。 + + Raises: + ValueError: 输入不是受支持的 Data URL 时抛出。 """ - try: - model_name_list = model_config.model_list - logger.info(f"使用模型 {model_name_list} 生成内容") + if not image_url.startswith("data:image/") or ";base64," not in image_url: + raise ValueError("仅支持 Data URL 形式的图片输入") + prefix, image_base64 = image_url.split(";base64,", maxsplit=1) + image_format = prefix.removeprefix("data:image/") + if not image_format or not image_base64: + raise ValueError("图片 Data URL 不完整") + return image_format, image_base64 - response, reasoning_content, model_name, tool_call = await _generate_response( - model_config=model_config, - request_type=request_type, - message_factory=message_factory, - tool_options=tool_options, - temperature=temperature, - max_tokens=max_tokens, + +def _append_content_parts(message_builder: MessageBuilder, content: Any) -> None: + """将原始消息内容追加到内部消息构建器。 + + Args: + message_builder: 目标消息构建器。 + content: 原始消息内容。 + + Raises: + ValueError: 消息内容结构不受支持时抛出。 + """ + if isinstance(content, str): + message_builder.add_text_content(content) + return + + content_items: List[Any] + if isinstance(content, list): + content_items = content + elif isinstance(content, dict): + content_items = [content] + else: + raise ValueError("消息内容必须为字符串、字典或列表") + + for content_item in content_items: + if isinstance(content_item, str): + message_builder.add_text_content(content_item) + continue + if not isinstance(content_item, dict): + raise ValueError("消息内容列表中仅支持字符串或字典片段") + + part_type = str(content_item.get("type", "text")).strip().lower() + if part_type == "text": + text_content = content_item.get("text") + if not isinstance(text_content, str): + raise ValueError("文本片段缺少 `text` 字段") + message_builder.add_text_content(text_content) + continue + + if part_type in {"image", "image_url", "input_image"}: + image_url = content_item.get("image_url") + if isinstance(image_url, dict): + image_url = image_url.get("url") + if isinstance(image_url, str): + image_format, image_base64 = _parse_data_url_image(image_url) + message_builder.add_image_content(image_format=image_format, image_base64=image_base64) + continue + + image_format = content_item.get("image_format") + image_base64 = content_item.get("image_base64") + if isinstance(image_format, str) and isinstance(image_base64, str): + message_builder.add_image_content(image_format=image_format, image_base64=image_base64) + continue + raise ValueError("图片片段缺少可识别的图片数据") + + raise ValueError(f"不支持的消息片段类型: {part_type}") + + +def _normalize_tool_arguments(arguments: Any) -> Dict[str, Any] | None: + """将原始工具参数规范化为字典。 + + Args: + arguments: 原始工具参数。 + + Returns: + Dict[str, Any] | None: 规范化后的参数字典。 + """ + if arguments is None: + return None + if isinstance(arguments, dict): + return arguments + if isinstance(arguments, str): + stripped_arguments = arguments.strip() + if not stripped_arguments: + return {} + try: + parsed_arguments = json.loads(stripped_arguments) + except json.JSONDecodeError: + return {"raw_arguments": arguments} + if isinstance(parsed_arguments, dict): + return parsed_arguments + return {"value": parsed_arguments} + return {"value": arguments} + + +def _build_tool_calls(raw_tool_calls: Any) -> List[ToolCall] | None: + """从原始消息中提取工具调用列表。 + + Args: + raw_tool_calls: 原始工具调用结构。 + + Returns: + List[ToolCall] | None: 规范化后的工具调用列表。 + + Raises: + ValueError: 工具调用结构缺失必要字段时抛出。 + """ + if raw_tool_calls is None: + return None + if not isinstance(raw_tool_calls, list): + raise ValueError("`tool_calls` 必须为列表") + + tool_calls: List[ToolCall] = [] + for raw_tool_call in raw_tool_calls: + if not isinstance(raw_tool_call, dict): + raise ValueError("工具调用项必须为字典") + + function_info = raw_tool_call.get("function") + if isinstance(function_info, dict): + func_name = function_info.get("name") + arguments = function_info.get("arguments") + else: + func_name = raw_tool_call.get("name") or raw_tool_call.get("func_name") + arguments = raw_tool_call.get("arguments") or raw_tool_call.get("args") + + call_id = raw_tool_call.get("id") or raw_tool_call.get("call_id") + if not isinstance(call_id, str) or not isinstance(func_name, str): + raise ValueError("工具调用缺少 `id` 或函数名称") + + tool_calls.append( + ToolCall( + call_id=call_id, + func_name=func_name, + args=_normalize_tool_arguments(arguments), + ) ) - return True, response, reasoning_content, model_name, tool_call - except Exception as e: - error_msg = f"生成内容时出错: {str(e)}" - logger.error(f"[LLMService] {error_msg}") - return False, error_msg, "", "", None + return tool_calls or None + + +def _build_message_from_dict(raw_message: PromptMessage) -> Message: + """将原始消息字典转换为内部消息对象。 + + Args: + raw_message: 原始消息字典。 + + Returns: + Message: 规范化后的消息对象。 + + Raises: + ValueError: 原始消息结构不合法时抛出。 + """ + raw_role = raw_message.get("role") + if not isinstance(raw_role, str): + raise ValueError("消息缺少字符串类型的 `role` 字段") + + role = _normalize_role(raw_role) + message_builder = MessageBuilder().set_role(role) + + tool_calls = _build_tool_calls(raw_message.get("tool_calls")) + if tool_calls is not None: + message_builder.set_tool_calls(tool_calls) + + tool_call_id = raw_message.get("tool_call_id") + if isinstance(tool_call_id, str) and role == RoleType.Tool: + message_builder.set_tool_call_id(tool_call_id) + + if "content" in raw_message and raw_message["content"] not in (None, "", []): + _append_content_parts(message_builder, raw_message["content"]) + + return message_builder.build() + + +def _build_prompt_message_factory(prompt: PromptInput) -> MessageFactory: + """将统一提示输入转换为消息工厂。 + + Args: + prompt: 原始提示输入。 + + Returns: + MessageFactory: 惰性构建消息列表的工厂函数。 + """ + if isinstance(prompt, str): + def build_messages(_: BaseClient) -> List[Message]: + """构建单条用户消息。""" + message_builder = MessageBuilder() + message_builder.add_text_content(prompt) + return [message_builder.build()] + + return build_messages + + def build_messages(_: BaseClient) -> List[Message]: + """构建多消息对话输入。""" + return [_build_message_from_dict(raw_message) for raw_message in prompt] + + return build_messages + + +async def generate(request: LLMServiceRequest) -> LLMServiceResult: + """执行统一的 LLM 服务请求。 + + Args: + request: 服务层统一请求对象。 + + Returns: + LLMServiceResult: 统一响应对象;失败时 `success=False`。 + """ + llm_client = LLMServiceClient(task_name=request.task_name, request_type=request.request_type) + if request.message_factory is not None: + active_message_factory = request.message_factory + else: + prompt = request.prompt + if prompt is None: + raise ValueError("`prompt` 与 `message_factory` 必须且只能提供一个") + active_message_factory = _build_prompt_message_factory(prompt) + + try: + generation_result = await llm_client.generate_response_with_messages( + message_factory=active_message_factory, + options=LLMGenerationOptions( + temperature=request.temperature, + max_tokens=request.max_tokens, + tool_options=request.tool_options, + response_format=request.response_format, + interrupt_flag=request.interrupt_flag, + ), + ) + return LLMServiceResult.from_response_result(generation_result) + except Exception as exc: + error_message = f"生成内容时出错: {exc}" + logger.error(f"[LLMService] {error_message}") + return LLMServiceResult.from_error(error_message, str(exc)) diff --git a/src/services/message_service.py b/src/services/message_service.py index d918b177..5291bf04 100644 --- a/src/services/message_service.py +++ b/src/services/message_service.py @@ -7,9 +7,9 @@ from typing import List, Optional, Tuple from sqlmodel import col, select from src.chat.message_receive.message import SessionMessage -from src.common.data_models.action_record_data_model import MaiActionRecord +from src.common.data_models.tool_record_data_model import MaiToolRecord from src.common.database.database import get_db_session -from src.common.database.database_model import ActionRecord, Images, ImageType +from src.common.database.database_model import Images, ImageType, ToolRecord from src.common.message_repository import count_messages, find_messages from src.common.utils.math_utils import translate_timestamp_to_human_readable from src.common.utils.utils_action import ActionUtils @@ -238,18 +238,18 @@ def get_actions_by_timestamp_with_chat( timestamp_start: float, timestamp_end: float, limit: Optional[int] = None, -) -> List[MaiActionRecord]: +) -> List[MaiToolRecord]: with get_db_session() as session: statement = ( - select(ActionRecord) - .where(col(ActionRecord.session_id) == chat_id) - .where(col(ActionRecord.timestamp) >= datetime.fromtimestamp(timestamp_start)) - .where(col(ActionRecord.timestamp) <= datetime.fromtimestamp(timestamp_end)) - .order_by(col(ActionRecord.timestamp)) + select(ToolRecord) + .where(col(ToolRecord.session_id) == chat_id) + .where(col(ToolRecord.timestamp) >= datetime.fromtimestamp(timestamp_start)) + .where(col(ToolRecord.timestamp) <= datetime.fromtimestamp(timestamp_end)) + .order_by(col(ToolRecord.timestamp)) ) if limit is not None: statement = statement.limit(limit) - return [MaiActionRecord.from_db_instance(item) for item in session.exec(statement).all()] + return [MaiToolRecord.from_db_instance(item) for item in session.exec(statement).all()] def replace_user_references(text: str, platform: str, replace_bot_name: bool = False) -> str: diff --git a/src/services/send_service.py b/src/services/send_service.py index 134fb15e..d7f17563 100644 --- a/src/services/send_service.py +++ b/src/services/send_service.py @@ -281,6 +281,26 @@ def _build_processed_plain_text(message: SessionMessage) -> str: return " ".join(part for part in processed_parts if part) +def _build_outbound_log_preview(message: SessionMessage, max_length: int = 160) -> str: + """构造出站消息的日志预览文本。 + + Args: + message: 待发送的内部消息对象。 + max_length: 预览文本最大长度。 + + Returns: + str: 适用于日志展示的消息摘要。 + """ + preview_text = (message.processed_plain_text or message.display_message or "").strip() + if not preview_text: + preview_text = f"[{_describe_message_sequence(message.raw_message)}]" + + normalized_preview = " ".join(preview_text.split()) + if len(normalized_preview) <= max_length: + return normalized_preview + return f"{normalized_preview[:max_length]}..." + + def _build_outbound_session_message( message_sequence: MessageSequence, stream_id: str, @@ -424,11 +444,7 @@ def _log_platform_io_failures(delivery_batch: DeliveryBatch) -> None: f"driver={receipt.driver_id} status={receipt.status} error={receipt.error}" for receipt in delivery_batch.failed_receipts ) or "未命中任何发送路由" - logger.warning( - "[SendService] Platform IO 发送失败: platform=%s %s", - delivery_batch.route_key.platform, - failed_details, - ) + logger.warning(f"[SendService] Platform IO 发送失败: platform={delivery_batch.route_key.platform} {failed_details}") async def _send_via_platform_io( @@ -493,9 +509,9 @@ async def _send_via_platform_io( for receipt in delivery_batch.sent_receipts ] logger.info( - "[SendService] 已通过 Platform IO 将消息发往平台 '%s' (drivers: %s)", - route_key.platform, - ", ".join(successful_driver_ids), + f"[SendService] 已通过 Platform IO 将消息发往平台 '{route_key.platform}' " + f"(drivers: {', '.join(successful_driver_ids)}) " + f"message={_build_outbound_log_preview(message)}" ) return True diff --git a/src/webui/routers/model.py b/src/webui/routers/model.py index 2f67aca5..fad701ba 100644 --- a/src/webui/routers/model.py +++ b/src/webui/routers/model.py @@ -13,6 +13,8 @@ from fastapi import APIRouter, Depends, HTTPException, Query from src.common.logger import get_logger from src.config.config import CONFIG_DIR +from src.config.model_configs import APIProvider +from src.llm_models.openai_compat import build_openai_compatible_client_config, normalize_openai_base_url from src.webui.dependencies import require_auth from src.webui.utils.network_security import validate_public_url @@ -35,8 +37,8 @@ MODEL_FETCHER_CONFIG = { def _normalize_url(url: str) -> str: - """规范化 URL(去掉尾部斜杠)""" - return url.rstrip("/") if url else "" + """规范化 URL(去掉尾部斜杠)。""" + return normalize_openai_base_url(url) if url else "" def _parse_openai_response(data: Dict) -> List[Dict]: @@ -89,19 +91,30 @@ async def _fetch_models_from_provider( endpoint: str, parser: str, client_type: str = "openai", + auth_type: str = "bearer", + auth_header_name: str = "Authorization", + auth_header_prefix: str = "Bearer", + auth_query_name: str = "api_key", + default_headers: Optional[Dict[str, str]] = None, + default_query: Optional[Dict[str, str]] = None, ) -> List[Dict]: - """ - 从提供商 API 获取模型列表 + """从提供商 API 获取模型列表。 Args: - base_url: 提供商的基础 URL - api_key: API 密钥 - endpoint: 获取模型列表的端点 - parser: 响应解析器类型 ('openai' | 'gemini') - client_type: 客户端类型 ('openai' | 'gemini') + base_url: 提供商的基础 URL。 + api_key: API 密钥。 + endpoint: 获取模型列表的端点。 + parser: 响应解析器类型。 + client_type: 客户端类型。 + auth_type: OpenAI 兼容接口的鉴权方式。 + auth_header_name: Header 鉴权时使用的请求头名称。 + auth_header_prefix: Header 鉴权时使用的请求头前缀。 + auth_query_name: Query 鉴权时使用的查询参数名称。 + default_headers: 默认附带的请求头。 + default_query: 默认附带的查询参数。 Returns: - 模型列表 + List[Dict]: 解析后的模型列表。 """ try: base_url = validate_public_url(_normalize_url(base_url)) @@ -118,8 +131,21 @@ async def _fetch_models_from_provider( # Gemini 使用 URL 参数传递 API Key params["key"] = api_key else: - # OpenAI 兼容格式使用 Authorization 头 - headers["Authorization"] = f"Bearer {api_key}" + provider = APIProvider( + name="webui-openai-compatible-fetcher", + base_url=base_url, + api_key=api_key, + client_type="openai", + auth_type=auth_type, + auth_header_name=auth_header_name, + auth_header_prefix=auth_header_prefix, + auth_query_name=auth_query_name, + default_headers=default_headers or {}, + default_query=default_query or {}, + ) + client_config = build_openai_compatible_client_config(provider) + headers.update(client_config.default_headers) + params.update(client_config.default_query) try: async with httpx.AsyncClient(timeout=30.0) as client: @@ -186,10 +212,9 @@ async def get_provider_models( parser: str = Query("openai", description="响应解析器类型 (openai | gemini)"), endpoint: str = Query("/models", description="获取模型列表的端点"), ): - """ - 获取指定提供商的可用模型列表 + """获取指定提供商的可用模型列表。 - 通过提供商名称查找配置,然后请求对应的模型列表端点 + 通过提供商名称查找配置,然后请求对应的模型列表端点。 """ # 获取提供商配置 provider_config = _get_provider_config(provider_name) @@ -205,13 +230,21 @@ async def get_provider_models( if not api_key: raise HTTPException(status_code=400, detail="提供商配置缺少 api_key") + resolved_endpoint = provider_config.get("model_list_endpoint", endpoint) if endpoint == "/models" else endpoint + # 获取模型列表 models = await _fetch_models_from_provider( base_url=base_url, api_key=api_key, - endpoint=endpoint, + endpoint=resolved_endpoint, parser=parser, client_type=client_type, + auth_type=provider_config.get("auth_type", "bearer"), + auth_header_name=provider_config.get("auth_header_name", "Authorization"), + auth_header_prefix=provider_config.get("auth_header_prefix", "Bearer"), + auth_query_name=provider_config.get("auth_query_name", "api_key"), + default_headers=provider_config.get("default_headers", {}), + default_query=provider_config.get("default_query", {}), ) return { @@ -229,16 +262,22 @@ async def get_models_by_url( parser: str = Query("openai", description="响应解析器类型 (openai | gemini)"), endpoint: str = Query("/models", description="获取模型列表的端点"), client_type: str = Query("openai", description="客户端类型 (openai | gemini)"), + auth_type: str = Query("bearer", description="鉴权方式 (bearer | header | query | none)"), + auth_header_name: str = Query("Authorization", description="Header 鉴权名称"), + auth_header_prefix: str = Query("Bearer", description="Header 鉴权前缀"), + auth_query_name: str = Query("api_key", description="Query 鉴权参数名"), ): - """ - 通过 URL 直接获取模型列表(用于自定义提供商) - """ + """通过 URL 直接获取模型列表。""" models = await _fetch_models_from_provider( base_url=base_url, api_key=api_key, endpoint=endpoint, parser=parser, client_type=client_type, + auth_type=auth_type, + auth_header_name=auth_header_name, + auth_header_prefix=auth_header_prefix, + auth_query_name=auth_query_name, ) return {