Merge remote-tracking branch 'upstream/r-dev' into sync/pr-1564-upstream-20260331
# Conflicts: # src/chat/brain_chat/PFC/conversation.py # src/chat/brain_chat/PFC/pfc_KnowledgeFetcher.py # src/chat/knowledge/lpmm_ops.py
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@@ -278,6 +278,8 @@ logs
|
||||
.vscode
|
||||
|
||||
/config/*
|
||||
config/mcp_config.json
|
||||
!config/mcp_config.json.template
|
||||
config/old/bot_config_20250405_212257.toml
|
||||
temp/
|
||||
|
||||
|
||||
535
code_scripts/migrate_expression_jargon_db.py
Normal file
535
code_scripts/migrate_expression_jargon_db.py
Normal file
@@ -0,0 +1,535 @@
|
||||
from argparse import ArgumentParser, Namespace
|
||||
from collections.abc import Iterable
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from sys import path as sys_path
|
||||
from typing import Any, Optional
|
||||
|
||||
import json
|
||||
import sqlite3
|
||||
|
||||
from sqlalchemy import text
|
||||
from sqlmodel import Session, SQLModel, create_engine, delete
|
||||
|
||||
ROOT_PATH = Path(__file__).resolve().parent.parent
|
||||
if str(ROOT_PATH) not in sys_path:
|
||||
sys_path.insert(0, str(ROOT_PATH))
|
||||
|
||||
from src.common.database.database_model import Expression, Jargon, ModifiedBy
|
||||
|
||||
|
||||
def build_argument_parser() -> ArgumentParser:
|
||||
"""构建命令行参数解析器。"""
|
||||
parser = ArgumentParser(
|
||||
description="将旧版 expression/jargon 数据迁移到新版 expressions/jargons 数据库。"
|
||||
)
|
||||
parser.add_argument("--source-db", dest="source_db", help="旧版 SQLite 数据库路径")
|
||||
parser.add_argument("--target-db", dest="target_db", help="新版 SQLite 数据库路径")
|
||||
parser.add_argument(
|
||||
"--clear-target",
|
||||
dest="clear_target",
|
||||
action="store_true",
|
||||
help="迁移前清空目标库中的 expressions 和 jargons 表",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def prompt_path(prompt_text: str, current_value: Optional[str] = None) -> Path:
|
||||
"""读取数据库路径输入。"""
|
||||
while True:
|
||||
suffix = f" [{current_value}]" if current_value else ""
|
||||
raw_text = input(f"{prompt_text}{suffix}: ").strip()
|
||||
value = raw_text or current_value or ""
|
||||
if not value:
|
||||
print("路径不能为空,请重新输入。")
|
||||
continue
|
||||
return Path(value).expanduser().resolve()
|
||||
|
||||
|
||||
def prompt_yes_no(prompt_text: str, default: bool = False) -> bool:
|
||||
"""读取是否确认输入。"""
|
||||
default_hint = "Y/n" if default else "y/N"
|
||||
raw_text = input(f"{prompt_text} [{default_hint}]: ").strip().lower()
|
||||
if not raw_text:
|
||||
return default
|
||||
return raw_text in {"y", "yes"}
|
||||
|
||||
|
||||
def ensure_sqlite_file(path: Path, should_exist: bool) -> None:
|
||||
"""校验 SQLite 文件路径。"""
|
||||
if should_exist and not path.is_file():
|
||||
raise FileNotFoundError(f"数据库文件不存在:{path}")
|
||||
if not should_exist:
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
def connect_sqlite(path: Path) -> sqlite3.Connection:
|
||||
"""创建 SQLite 连接。"""
|
||||
connection = sqlite3.connect(path)
|
||||
connection.row_factory = sqlite3.Row
|
||||
return connection
|
||||
|
||||
|
||||
def table_exists(connection: sqlite3.Connection, table_name: str) -> bool:
|
||||
"""检查表是否存在。"""
|
||||
result = connection.execute(
|
||||
"SELECT 1 FROM sqlite_master WHERE type = 'table' AND name = ? LIMIT 1",
|
||||
(table_name,),
|
||||
).fetchone()
|
||||
return result is not None
|
||||
|
||||
|
||||
def resolve_source_table_name(connection: sqlite3.Connection, candidates: list[str]) -> str:
|
||||
"""从候选表名中解析实际存在的表名。"""
|
||||
for table_name in candidates:
|
||||
if table_exists(connection, table_name):
|
||||
return table_name
|
||||
raise ValueError(f"未找到候选表:{', '.join(candidates)}")
|
||||
|
||||
|
||||
def get_table_columns(connection: sqlite3.Connection, table_name: str) -> set[str]:
|
||||
"""获取表字段名集合。"""
|
||||
rows = connection.execute(f"PRAGMA table_info('{table_name}')").fetchall()
|
||||
return {str(row["name"]) for row in rows}
|
||||
|
||||
|
||||
def get_table_nullable_map(connection: sqlite3.Connection, table_name: str) -> dict[str, bool]:
|
||||
"""获取表字段是否允许 NULL 的映射。"""
|
||||
rows = connection.execute(f"PRAGMA table_info('{table_name}')").fetchall()
|
||||
return {str(row["name"]): not bool(row["notnull"]) for row in rows}
|
||||
|
||||
|
||||
def load_rows(connection: sqlite3.Connection, table_name: str) -> list[sqlite3.Row]:
|
||||
"""读取整张表的数据。"""
|
||||
return connection.execute(f"SELECT * FROM {table_name}").fetchall()
|
||||
|
||||
|
||||
def normalize_optional_text(raw_value: Any) -> Optional[str]:
|
||||
"""标准化可空文本字段。"""
|
||||
if raw_value is None:
|
||||
return None
|
||||
return str(raw_value)
|
||||
|
||||
|
||||
def ensure_nullable_compatibility(
|
||||
table_name: str,
|
||||
column_name: str,
|
||||
row_id: Any,
|
||||
value: Any,
|
||||
nullable_map: dict[str, bool],
|
||||
) -> None:
|
||||
"""检查待迁移值是否与目标表可空约束兼容。"""
|
||||
if value is None and not nullable_map.get(column_name, True):
|
||||
raise ValueError(
|
||||
f"目标表 {table_name}.{column_name} 不允许 NULL,但源记录 id={row_id} 的该字段为 NULL。"
|
||||
)
|
||||
|
||||
|
||||
def normalize_string_list(raw_value: Any) -> list[str]:
|
||||
"""将旧库中的 JSON/文本字段标准化为字符串列表。"""
|
||||
if raw_value is None:
|
||||
return []
|
||||
if isinstance(raw_value, list):
|
||||
return [str(item).strip() for item in raw_value if str(item).strip()]
|
||||
if isinstance(raw_value, str):
|
||||
raw_text = raw_value.strip()
|
||||
if not raw_text:
|
||||
return []
|
||||
try:
|
||||
parsed = json.loads(raw_text)
|
||||
except json.JSONDecodeError:
|
||||
return [raw_text]
|
||||
if isinstance(parsed, list):
|
||||
return [str(item).strip() for item in parsed if str(item).strip()]
|
||||
if isinstance(parsed, str):
|
||||
parsed_text = parsed.strip()
|
||||
return [parsed_text] if parsed_text else []
|
||||
if parsed is None:
|
||||
return []
|
||||
return [str(parsed).strip()]
|
||||
return [str(raw_value).strip()]
|
||||
|
||||
|
||||
def normalize_modified_by(raw_value: Any) -> Optional[ModifiedBy]:
|
||||
"""标准化审核来源字段。"""
|
||||
if raw_value is None:
|
||||
return None
|
||||
|
||||
normalized_raw_value = raw_value
|
||||
if isinstance(raw_value, str):
|
||||
raw_text = raw_value.strip()
|
||||
if raw_text.startswith('"') and raw_text.endswith('"'):
|
||||
try:
|
||||
normalized_raw_value = json.loads(raw_text)
|
||||
except json.JSONDecodeError:
|
||||
normalized_raw_value = raw_text
|
||||
else:
|
||||
normalized_raw_value = raw_text
|
||||
|
||||
value = str(normalized_raw_value).strip().lower()
|
||||
if value in {"", "none", "null"}:
|
||||
return None
|
||||
if value in {ModifiedBy.AI.value, ModifiedBy.AI.name.lower()}:
|
||||
return ModifiedBy.AI
|
||||
if value in {ModifiedBy.USER.value, ModifiedBy.USER.name.lower()}:
|
||||
return ModifiedBy.USER
|
||||
return None
|
||||
|
||||
|
||||
def parse_optional_bool(raw_value: Any) -> Optional[bool]:
|
||||
"""解析可空布尔值,兼容整数和字符串。"""
|
||||
if raw_value is None:
|
||||
return None
|
||||
if isinstance(raw_value, bool):
|
||||
return raw_value
|
||||
if isinstance(raw_value, int):
|
||||
return bool(raw_value)
|
||||
if isinstance(raw_value, float):
|
||||
return bool(int(raw_value))
|
||||
|
||||
value = str(raw_value).strip().lower()
|
||||
if value in {"", "none", "null"}:
|
||||
return None
|
||||
if value in {"1", "true", "t", "yes", "y"}:
|
||||
return True
|
||||
if value in {"0", "false", "f", "no", "n"}:
|
||||
return False
|
||||
raise ValueError(f"无法解析布尔值:{raw_value}")
|
||||
|
||||
|
||||
def parse_bool(raw_value: Any, default: bool = False) -> bool:
|
||||
"""解析非空布尔值。"""
|
||||
parsed = parse_optional_bool(raw_value)
|
||||
return default if parsed is None else parsed
|
||||
|
||||
|
||||
def timestamp_to_datetime(raw_value: Any, fallback_now: bool) -> Optional[datetime]:
|
||||
"""将旧库中的 Unix 时间戳转换为 datetime。"""
|
||||
if raw_value is None or raw_value == "":
|
||||
return datetime.now() if fallback_now else None
|
||||
if isinstance(raw_value, datetime):
|
||||
return raw_value
|
||||
try:
|
||||
return datetime.fromtimestamp(float(raw_value))
|
||||
except (TypeError, ValueError, OSError, OverflowError):
|
||||
return datetime.now() if fallback_now else None
|
||||
|
||||
|
||||
def build_session_id_dict(raw_chat_id: Any, fallback_count: int) -> str:
|
||||
"""将旧版 jargon.chat_id 转换为新版 session_id_dict。"""
|
||||
if raw_chat_id is None:
|
||||
return json.dumps({}, ensure_ascii=False)
|
||||
|
||||
if isinstance(raw_chat_id, str):
|
||||
raw_text = raw_chat_id.strip()
|
||||
else:
|
||||
raw_text = str(raw_chat_id).strip()
|
||||
|
||||
if not raw_text:
|
||||
return json.dumps({}, ensure_ascii=False)
|
||||
|
||||
try:
|
||||
parsed = json.loads(raw_text)
|
||||
except json.JSONDecodeError:
|
||||
return json.dumps({raw_text: max(fallback_count, 1)}, ensure_ascii=False)
|
||||
|
||||
if isinstance(parsed, str):
|
||||
parsed_text = parsed.strip()
|
||||
session_counts = {parsed_text: max(fallback_count, 1)} if parsed_text else {}
|
||||
return json.dumps(session_counts, ensure_ascii=False)
|
||||
|
||||
if not isinstance(parsed, list):
|
||||
return json.dumps({}, ensure_ascii=False)
|
||||
|
||||
session_counts: dict[str, int] = {}
|
||||
for item in parsed:
|
||||
if not isinstance(item, list) or not item:
|
||||
continue
|
||||
session_id = str(item[0]).strip()
|
||||
if not session_id:
|
||||
continue
|
||||
item_count = 1
|
||||
if len(item) > 1:
|
||||
try:
|
||||
item_count = int(item[1])
|
||||
except (TypeError, ValueError):
|
||||
item_count = 1
|
||||
session_counts[session_id] = max(item_count, 1)
|
||||
|
||||
return json.dumps(session_counts, ensure_ascii=False)
|
||||
|
||||
|
||||
def create_target_engine(target_db_path: Path):
|
||||
"""创建目标数据库引擎。"""
|
||||
return create_engine(
|
||||
f"sqlite:///{target_db_path.as_posix()}",
|
||||
echo=False,
|
||||
connect_args={"check_same_thread": False},
|
||||
)
|
||||
|
||||
|
||||
def clear_target_tables(session: Session) -> None:
|
||||
"""清空目标表。"""
|
||||
session.exec(delete(Expression))
|
||||
session.exec(delete(Jargon))
|
||||
|
||||
|
||||
def migrate_expressions(
|
||||
old_rows: Iterable[sqlite3.Row],
|
||||
target_session: Session,
|
||||
expression_columns: set[str],
|
||||
) -> int:
|
||||
"""迁移 expression 数据。"""
|
||||
migrated_count = 0
|
||||
modified_by_ai_count = 0
|
||||
modified_by_user_count = 0
|
||||
modified_by_null_count = 0
|
||||
unknown_modified_by_values: dict[str, int] = {}
|
||||
for row in old_rows:
|
||||
create_time = timestamp_to_datetime(row["create_date"] if "create_date" in expression_columns else None, True)
|
||||
last_active_time = timestamp_to_datetime(
|
||||
row["last_active_time"] if "last_active_time" in expression_columns else None,
|
||||
True,
|
||||
)
|
||||
content_list = normalize_string_list(row["content_list"] if "content_list" in expression_columns else None)
|
||||
raw_modified_by = row["modified_by"] if "modified_by" in expression_columns else None
|
||||
modified_by = normalize_modified_by(raw_modified_by)
|
||||
if modified_by == ModifiedBy.AI:
|
||||
modified_by_ai_count += 1
|
||||
elif modified_by == ModifiedBy.USER:
|
||||
modified_by_user_count += 1
|
||||
else:
|
||||
modified_by_null_count += 1
|
||||
if raw_modified_by not in (None, "", "null", "NULL", "None"):
|
||||
unknown_key = str(raw_modified_by)
|
||||
unknown_modified_by_values[unknown_key] = unknown_modified_by_values.get(unknown_key, 0) + 1
|
||||
|
||||
target_session.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO expressions (
|
||||
id,
|
||||
situation,
|
||||
style,
|
||||
content_list,
|
||||
count,
|
||||
last_active_time,
|
||||
create_time,
|
||||
session_id,
|
||||
checked,
|
||||
rejected,
|
||||
modified_by
|
||||
) VALUES (
|
||||
:id,
|
||||
:situation,
|
||||
:style,
|
||||
:content_list,
|
||||
:count,
|
||||
:last_active_time,
|
||||
:create_time,
|
||||
:session_id,
|
||||
:checked,
|
||||
:rejected,
|
||||
:modified_by
|
||||
)
|
||||
"""
|
||||
),
|
||||
{
|
||||
"id": int(row["id"]) if row["id"] is not None else None,
|
||||
"situation": str(row["situation"]).strip(),
|
||||
"style": str(row["style"]).strip(),
|
||||
"content_list": json.dumps(content_list, ensure_ascii=False),
|
||||
"count": int(row["count"]) if "count" in expression_columns and row["count"] is not None else 1,
|
||||
"last_active_time": last_active_time or datetime.now(),
|
||||
"create_time": create_time or datetime.now(),
|
||||
"session_id": str(row["chat_id"]).strip() if "chat_id" in expression_columns and row["chat_id"] else None,
|
||||
"checked": parse_bool(row["checked"] if "checked" in expression_columns else None, default=False),
|
||||
"rejected": parse_bool(row["rejected"] if "rejected" in expression_columns else None, default=False),
|
||||
"modified_by": modified_by.name if modified_by is not None else None,
|
||||
},
|
||||
)
|
||||
migrated_count += 1
|
||||
|
||||
print(
|
||||
"Expression modified_by 迁移统计:"
|
||||
f" AI={modified_by_ai_count}, USER={modified_by_user_count}, NULL={modified_by_null_count}"
|
||||
)
|
||||
if unknown_modified_by_values:
|
||||
preview_items = list(unknown_modified_by_values.items())[:10]
|
||||
preview_text = ", ".join(f"{value!r} x{count}" for value, count in preview_items)
|
||||
print(f"警告:以下旧 modified_by 值未识别,已按 NULL 迁移:{preview_text}")
|
||||
return migrated_count
|
||||
|
||||
|
||||
def migrate_jargons(
|
||||
old_rows: Iterable[sqlite3.Row],
|
||||
target_session: Session,
|
||||
jargon_columns: set[str],
|
||||
jargon_nullable_map: dict[str, bool],
|
||||
) -> int:
|
||||
"""迁移 jargon 数据。"""
|
||||
migrated_count = 0
|
||||
coerced_meaning_null_count = 0
|
||||
for row in old_rows:
|
||||
count = int(row["count"]) if "count" in jargon_columns and row["count"] is not None else 0
|
||||
raw_content_value = row["raw_content"] if "raw_content" in jargon_columns else None
|
||||
raw_content_list = normalize_string_list(raw_content_value)
|
||||
meaning_value = normalize_optional_text(row["meaning"] if "meaning" in jargon_columns else None)
|
||||
is_jargon_value = parse_optional_bool(row["is_jargon"] if "is_jargon" in jargon_columns else None)
|
||||
inference_content_key = (
|
||||
"inference_content_only"
|
||||
if "inference_content_only" in jargon_columns
|
||||
else "inference_with_content_only"
|
||||
if "inference_with_content_only" in jargon_columns
|
||||
else None
|
||||
)
|
||||
|
||||
ensure_nullable_compatibility("jargons", "is_jargon", row["id"], is_jargon_value, jargon_nullable_map)
|
||||
|
||||
if meaning_value is None and not jargon_nullable_map.get("meaning", True):
|
||||
meaning_value = ""
|
||||
coerced_meaning_null_count += 1
|
||||
|
||||
# 显式执行 SQL,避免 ORM 在 None 场景下回填模型默认值。
|
||||
target_session.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO jargons (
|
||||
id,
|
||||
content,
|
||||
raw_content,
|
||||
meaning,
|
||||
session_id_dict,
|
||||
count,
|
||||
is_jargon,
|
||||
is_complete,
|
||||
is_global,
|
||||
last_inference_count,
|
||||
inference_with_context,
|
||||
inference_with_content_only
|
||||
) VALUES (
|
||||
:id,
|
||||
:content,
|
||||
:raw_content,
|
||||
:meaning,
|
||||
:session_id_dict,
|
||||
:count,
|
||||
:is_jargon,
|
||||
:is_complete,
|
||||
:is_global,
|
||||
:last_inference_count,
|
||||
:inference_with_context,
|
||||
:inference_with_content_only
|
||||
)
|
||||
"""
|
||||
),
|
||||
{
|
||||
"id": int(row["id"]) if row["id"] is not None else None,
|
||||
"content": str(row["content"]).strip(),
|
||||
"raw_content": json.dumps(raw_content_list, ensure_ascii=False) if raw_content_value is not None else None,
|
||||
"meaning": meaning_value,
|
||||
"session_id_dict": build_session_id_dict(
|
||||
row["chat_id"] if "chat_id" in jargon_columns else None,
|
||||
fallback_count=count,
|
||||
),
|
||||
"count": count,
|
||||
"is_jargon": is_jargon_value,
|
||||
"is_complete": parse_bool(row["is_complete"] if "is_complete" in jargon_columns else None, default=False),
|
||||
"is_global": parse_bool(row["is_global"] if "is_global" in jargon_columns else None, default=False),
|
||||
"last_inference_count": (
|
||||
int(row["last_inference_count"])
|
||||
if "last_inference_count" in jargon_columns and row["last_inference_count"] is not None
|
||||
else 0
|
||||
),
|
||||
"inference_with_context": (
|
||||
str(row["inference_with_context"])
|
||||
if "inference_with_context" in jargon_columns and row["inference_with_context"] is not None
|
||||
else None
|
||||
),
|
||||
"inference_with_content_only": (
|
||||
str(row[inference_content_key])
|
||||
if inference_content_key and row[inference_content_key] is not None
|
||||
else None
|
||||
),
|
||||
},
|
||||
)
|
||||
migrated_count += 1
|
||||
|
||||
if coerced_meaning_null_count > 0:
|
||||
print(
|
||||
f"警告:目标表 jargons.meaning 不允许 NULL,已将 {coerced_meaning_null_count} 条旧记录的 NULL meaning 转为空字符串。"
|
||||
)
|
||||
return migrated_count
|
||||
|
||||
|
||||
def confirm_target_replacement(target_db_path: Path, clear_target: bool) -> bool:
|
||||
"""确认是否写入目标数据库。"""
|
||||
if clear_target:
|
||||
return prompt_yes_no(f"将清空目标库中的 expressions/jargons 后再迁移,确认继续吗?\n目标库:{target_db_path}")
|
||||
return prompt_yes_no(f"将写入目标库,若主键冲突会导致迁移失败,确认继续吗?\n目标库:{target_db_path}")
|
||||
|
||||
|
||||
def parse_arguments() -> Namespace:
|
||||
"""解析参数。"""
|
||||
return build_argument_parser().parse_args()
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""脚本入口。"""
|
||||
args = parse_arguments()
|
||||
|
||||
print("旧版 expression/jargon -> 新版 expressions/jargons 迁移工具")
|
||||
source_db_path = prompt_path("请输入旧版数据库路径", args.source_db)
|
||||
target_db_path = prompt_path("请输入新版数据库路径", args.target_db)
|
||||
clear_target = args.clear_target or prompt_yes_no("迁移前是否清空目标库中的 expressions 和 jargons 表?", False)
|
||||
|
||||
if source_db_path == target_db_path:
|
||||
raise ValueError("旧版数据库路径和新版数据库路径不能相同。")
|
||||
|
||||
ensure_sqlite_file(source_db_path, should_exist=True)
|
||||
ensure_sqlite_file(target_db_path, should_exist=False)
|
||||
|
||||
print(f"旧库:{source_db_path}")
|
||||
print(f"新库:{target_db_path}")
|
||||
print(f"清空目标表:{'是' if clear_target else '否'}")
|
||||
|
||||
if not confirm_target_replacement(target_db_path, clear_target):
|
||||
print("已取消迁移。")
|
||||
return
|
||||
|
||||
source_connection = connect_sqlite(source_db_path)
|
||||
try:
|
||||
expression_table_name = resolve_source_table_name(source_connection, ["expression", "expressions"])
|
||||
jargon_table_name = resolve_source_table_name(source_connection, ["jargon", "jargons"])
|
||||
expression_columns = get_table_columns(source_connection, expression_table_name)
|
||||
jargon_columns = get_table_columns(source_connection, jargon_table_name)
|
||||
expression_rows = load_rows(source_connection, expression_table_name)
|
||||
jargon_rows = load_rows(source_connection, jargon_table_name)
|
||||
finally:
|
||||
source_connection.close()
|
||||
|
||||
target_engine = create_target_engine(target_db_path)
|
||||
SQLModel.metadata.create_all(target_engine)
|
||||
|
||||
target_sqlite_connection = connect_sqlite(target_db_path)
|
||||
try:
|
||||
jargon_nullable_map = get_table_nullable_map(target_sqlite_connection, "jargons")
|
||||
finally:
|
||||
target_sqlite_connection.close()
|
||||
|
||||
with Session(target_engine) as target_session:
|
||||
if clear_target:
|
||||
clear_target_tables(target_session)
|
||||
target_session.commit()
|
||||
|
||||
expression_count = migrate_expressions(expression_rows, target_session, expression_columns)
|
||||
jargon_count = migrate_jargons(jargon_rows, target_session, jargon_columns, jargon_nullable_map)
|
||||
target_session.commit()
|
||||
|
||||
print("迁移完成。")
|
||||
print(f"已迁移 expression 记录:{expression_count}")
|
||||
print(f"已迁移 jargon 记录:{jargon_count}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -27,7 +27,7 @@
|
||||
"startup.main_error": "Main process encountered an exception: {error}",
|
||||
"startup.opensource_free_notice": " This project is completely free and open-source software, released under the GPL-3.0 license",
|
||||
"startup.opensource_group": " Official group chat: ",
|
||||
"startup.opensource_group_value": "1006149251",
|
||||
"startup.opensource_group_value": "766798517",
|
||||
"startup.opensource_repo": " Official repository: ",
|
||||
"startup.opensource_repo_value": "https://github.com/MaiM-with-u/MaiBot",
|
||||
"startup.opensource_resale_warning": " Reselling this software as a \"product\" or concealing its open-source nature violates the license!",
|
||||
|
||||
@@ -27,7 +27,7 @@
|
||||
"startup.main_error": "メインプロセスで例外が発生しました: {error}",
|
||||
"startup.opensource_free_notice": " 本プロジェクトは完全無料のオープンソースソフトウェアであり、GPL-3.0 ライセンスのもとで公開されています",
|
||||
"startup.opensource_group": " 公式グループ: ",
|
||||
"startup.opensource_group_value": "1006149251",
|
||||
"startup.opensource_group_value": "766798517",
|
||||
"startup.opensource_repo": " 公式リポジトリ: ",
|
||||
"startup.opensource_repo_value": "https://github.com/MaiM-with-u/MaiBot",
|
||||
"startup.opensource_resale_warning": " 本ソフトウェアを「商品」として転売したり、オープンソースであることを隠すことはライセンス違反です!",
|
||||
|
||||
@@ -27,7 +27,7 @@
|
||||
"startup.main_error": "메인 프로세스에서 예외 발생: {error}",
|
||||
"startup.opensource_free_notice": " 본 프로젝트는 완전 무료 오픈소스 소프트웨어이며, GPL-3.0 라이선스로 배포됩니다",
|
||||
"startup.opensource_group": " 공식 그룹: ",
|
||||
"startup.opensource_group_value": "1006149251",
|
||||
"startup.opensource_group_value": "766798517",
|
||||
"startup.opensource_repo": " 공식 저장소: ",
|
||||
"startup.opensource_repo_value": "https://github.com/MaiM-with-u/MaiBot",
|
||||
"startup.opensource_resale_warning": " 본 소프트웨어를 '상품'으로 재판매하거나 오픈소스임을 숨기는 행위는 라이선스 위반입니다!",
|
||||
|
||||
@@ -27,7 +27,7 @@
|
||||
"startup.main_error": "主程序发生异常: {error}",
|
||||
"startup.opensource_free_notice": " 本项目是完全免费的开源软件,基于 GPL-3.0 协议发布",
|
||||
"startup.opensource_group": " 官方群聊: ",
|
||||
"startup.opensource_group_value": "1006149251",
|
||||
"startup.opensource_group_value": "766798517",
|
||||
"startup.opensource_repo": " 官方仓库: ",
|
||||
"startup.opensource_repo_value": "https://github.com/MaiM-with-u/MaiBot",
|
||||
"startup.opensource_resale_warning": " 将本软件作为「商品」倒卖、隐瞒开源性质均违反协议!",
|
||||
|
||||
887
mai_knowledge/knowledge.json
Normal file
887
mai_knowledge/knowledge.json
Normal file
@@ -0,0 +1,887 @@
|
||||
{
|
||||
"1": [
|
||||
{
|
||||
"id": "know_1_1774770946.623486",
|
||||
"content": "备战中考",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T15:55:46.623486"
|
||||
},
|
||||
{
|
||||
"id": "know_1_1774771765.051286",
|
||||
"content": "性别为女性",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:09:25.051286"
|
||||
},
|
||||
{
|
||||
"id": "know_1_1774771851.333504",
|
||||
"content": "用户是I人(内向型人格)",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:10:51.333504"
|
||||
},
|
||||
{
|
||||
"id": "know_1_1774771894.517183",
|
||||
"content": "用户名为小千,被他人称为“宝宝”,结合语境推测为女性",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:11:34.517183"
|
||||
},
|
||||
{
|
||||
"id": "know_1_1774771923.859455",
|
||||
"content": "小千是I人(内向型人格)",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:12:03.859455"
|
||||
},
|
||||
{
|
||||
"id": "know_1_1774771993.479732",
|
||||
"content": "小千是女性",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:13:13.479732"
|
||||
},
|
||||
{
|
||||
"id": "know_1_1774772079.496335",
|
||||
"content": "用户名为小千,被他人称为“宝宝”,推测为女性或处于亲密社交语境中(注:性别非明确陈述,但基于昵称高频使用及语境,高置信度归纳为女性或女性化称呼偏好,若严格遵循“明确表达”则此项存疑。鉴于指令要求“高置信度可归纳”,且群内互动模式符合典型女性向昵称习惯,此处提取为倾向性事实)",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:14:39.496335"
|
||||
},
|
||||
{
|
||||
"id": "know_1_1774773435.68612",
|
||||
"content": "用户名为小千",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:37:15.686120"
|
||||
},
|
||||
{
|
||||
"id": "know_1_1774773676.69252",
|
||||
"content": "用户自称猫娘(二次元人设)",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:41:16.692520"
|
||||
}
|
||||
],
|
||||
"2": [
|
||||
{
|
||||
"id": "know_2_1774768612.298128",
|
||||
"content": "性格自信,常以“真理在我这边”自居",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T15:16:52.298128"
|
||||
},
|
||||
{
|
||||
"id": "know_2_1774768645.029561",
|
||||
"content": "性格自信且带有自嘲精神,喜欢用轻松调侃的方式应对他人评价",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T15:17:25.029561"
|
||||
},
|
||||
{
|
||||
"id": "know_2_1774771068.355999",
|
||||
"content": "喜欢用夸张、幽默或古风修辞表达观点",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T15:57:48.355999"
|
||||
},
|
||||
{
|
||||
"id": "know_2_1774771397.764996",
|
||||
"content": "性格幽默,喜欢使用夸张比喻和古风表达",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:03:17.764996"
|
||||
},
|
||||
{
|
||||
"id": "know_2_1774771471.03367",
|
||||
"content": "幽默风趣,喜欢使用夸张比喻和玩梗",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:04:31.033670"
|
||||
},
|
||||
{
|
||||
"id": "know_2_1774771765.052285",
|
||||
"content": "性格不孤僻,社交圈较广",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:09:25.052285"
|
||||
},
|
||||
{
|
||||
"id": "know_2_1774771851.33601",
|
||||
"content": "用户表现出社恐倾向,喜欢回避社交互动",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:10:51.336010"
|
||||
},
|
||||
{
|
||||
"id": "know_2_1774771894.520185",
|
||||
"content": "性格偏向内向(I人),有社恐倾向,喜欢回避社交压力",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:11:34.520185"
|
||||
},
|
||||
{
|
||||
"id": "know_2_1774771958.585244",
|
||||
"content": "小千是内向型人格(I人)",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:12:38.585244"
|
||||
},
|
||||
{
|
||||
"id": "know_2_1774771993.481732",
|
||||
"content": "小千性格内向(I人)",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:13:13.481732"
|
||||
}
|
||||
],
|
||||
"3": [
|
||||
{
|
||||
"id": "know_3_1774773676.695521",
|
||||
"content": "喜欢冰淇淋",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:41:16.695521"
|
||||
}
|
||||
],
|
||||
"4": [],
|
||||
"5": [],
|
||||
"6": [
|
||||
{
|
||||
"id": "know_6_1774768486.451792",
|
||||
"content": "正在搭建 RAG 测试集",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T15:14:46.451792"
|
||||
},
|
||||
{
|
||||
"id": "know_6_1774768517.122405",
|
||||
"content": "熟悉 NapCat、RAG 等技术工具及互联网梗文化",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T15:15:17.122405"
|
||||
},
|
||||
{
|
||||
"id": "know_6_1774769406.247087",
|
||||
"content": "喜欢动漫风格插画",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T15:30:06.247087"
|
||||
},
|
||||
{
|
||||
"id": "know_6_1774770487.207364",
|
||||
"content": "关注显卡硬件参数(如显存、型号)及深度学习/炼丹应用",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T15:48:07.207364"
|
||||
},
|
||||
{
|
||||
"id": "know_6_1774770487.209372",
|
||||
"content": "对游戏光影效果感兴趣",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T15:48:07.209372"
|
||||
},
|
||||
{
|
||||
"id": "know_6_1774770603.063873",
|
||||
"content": "喜欢玩《我的世界》和VRChat",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T15:50:03.063873"
|
||||
},
|
||||
{
|
||||
"id": "know_6_1774770654.654349",
|
||||
"content": "关注显卡硬件参数(如4090、48G显存、5090)",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T15:50:54.654349"
|
||||
},
|
||||
{
|
||||
"id": "know_6_1774770654.655356",
|
||||
"content": "使用VRChat进行社交娱乐",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T15:50:54.655356"
|
||||
},
|
||||
{
|
||||
"id": "know_6_1774770734.287947",
|
||||
"content": "关注显卡硬件(如4090、3050)及AI炼丹技术",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T15:52:14.287947"
|
||||
},
|
||||
{
|
||||
"id": "know_6_1774770734.289944",
|
||||
"content": "玩《我的世界》并配置光影效果",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T15:52:14.289944"
|
||||
},
|
||||
{
|
||||
"id": "know_6_1774770734.291944",
|
||||
"content": "计划游玩VRChat",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T15:52:14.291944"
|
||||
},
|
||||
{
|
||||
"id": "know_6_1774771033.111011",
|
||||
"content": "喜欢玩VRChat",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T15:57:13.111011"
|
||||
},
|
||||
{
|
||||
"id": "know_6_1774771068.358999",
|
||||
"content": "关注VRChat等虚拟现实游戏及硬件性能话题",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T15:57:48.358999"
|
||||
},
|
||||
{
|
||||
"id": "know_6_1774771233.980219",
|
||||
"content": "使用VRChat(VRC)",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:00:33.980219"
|
||||
},
|
||||
{
|
||||
"id": "know_6_1774771397.766996",
|
||||
"content": "对VRChat(VRC)及虚拟形象社交感兴趣",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:03:17.766996"
|
||||
},
|
||||
{
|
||||
"id": "know_6_1774771471.03567",
|
||||
"content": "对VRChat等虚拟社交游戏感兴趣",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:04:31.035670"
|
||||
},
|
||||
{
|
||||
"id": "know_6_1774771894.521183",
|
||||
"content": "熟悉二次元文化、动漫角色及互联网流行梗(Meme)",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:11:34.521183"
|
||||
},
|
||||
{
|
||||
"id": "know_6_1774771923.861534",
|
||||
"content": "小千玩CS:GO游戏",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:12:03.861534"
|
||||
},
|
||||
{
|
||||
"id": "know_6_1774771958.587243",
|
||||
"content": "回声者_Echoderd喜欢玩CS:GO游戏",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:12:38.587243"
|
||||
},
|
||||
{
|
||||
"id": "know_6_1774771993.483732",
|
||||
"content": "小千喜欢二次元文化及动漫游戏圈梗",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:13:13.483732"
|
||||
},
|
||||
{
|
||||
"id": "know_6_1774772079.499335",
|
||||
"content": "熟悉并喜爱二次元文化、动漫角色及互联网梗图(如阴间美学、病娇系、黑长直萌妹等风格)",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:14:39.499335"
|
||||
},
|
||||
{
|
||||
"id": "know_6_1774772112.716455",
|
||||
"content": "小千关注CS:GO游戏及中考备考话题",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:15:12.716455"
|
||||
},
|
||||
{
|
||||
"id": "know_6_1774772154.873237",
|
||||
"content": "用户玩CS:GO游戏",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:15:54.873237"
|
||||
},
|
||||
{
|
||||
"id": "know_6_1774772186.438797",
|
||||
"content": "玩CS:GO游戏",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:16:26.438797"
|
||||
},
|
||||
{
|
||||
"id": "know_6_1774772730.867535",
|
||||
"content": "熟悉《我的青春恋爱物语果然有问题》及二次元表情包文化",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:25:30.867535"
|
||||
},
|
||||
{
|
||||
"id": "know_6_1774773338.849271",
|
||||
"content": "熟悉《原神》等二次元游戏及网络梗文化",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:35:38.849271"
|
||||
},
|
||||
{
|
||||
"id": "know_6_1774773371.406209",
|
||||
"content": "关注高分屏字体显示效果",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:36:11.406209"
|
||||
},
|
||||
{
|
||||
"id": "know_6_1774773401.48921",
|
||||
"content": "熟悉电脑显示技术(如高分屏字体选择)",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:36:41.489210"
|
||||
},
|
||||
{
|
||||
"id": "know_6_1774773435.688119",
|
||||
"content": "关注高分屏显示效果与字体选择(无衬线/衬线体)",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:37:15.688119"
|
||||
},
|
||||
{
|
||||
"id": "know_6_1774773608.256103",
|
||||
"content": "关注屏幕字体与分辨率(无衬线/有衬线)",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:40:08.256103"
|
||||
},
|
||||
{
|
||||
"id": "know_6_1774773645.671546",
|
||||
"content": "关注屏幕分辨率与字体显示效果(高分屏/无衬线体)",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:40:45.671546"
|
||||
},
|
||||
{
|
||||
"id": "know_6_1774773676.698035",
|
||||
"content": "关注字体设计(无衬线体)",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:41:16.698035"
|
||||
},
|
||||
{
|
||||
"id": "know_6_1774773740.83822",
|
||||
"content": "喜欢二次元文化及 VTuber 风格内容",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:42:20.838220"
|
||||
}
|
||||
],
|
||||
"7": [
|
||||
{
|
||||
"id": "know_7_1774768517.120403",
|
||||
"content": "从事 RAG 测试集搭建或相关技术工作",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T15:15:17.120403"
|
||||
},
|
||||
{
|
||||
"id": "know_7_1774768573.741823",
|
||||
"content": "从事 RAG(检索增强生成)测试集搭建相关工作",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T15:16:13.741823"
|
||||
},
|
||||
{
|
||||
"id": "know_7_1774770603.062873",
|
||||
"content": "备战中考",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T15:50:03.062873"
|
||||
},
|
||||
{
|
||||
"id": "know_7_1774771471.036668",
|
||||
"content": "正在备战中考的学生",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:04:31.036668"
|
||||
},
|
||||
{
|
||||
"id": "know_7_1774771923.862535",
|
||||
"content": "小千正在备战中考",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:12:03.862535"
|
||||
},
|
||||
{
|
||||
"id": "know_7_1774771958.588749",
|
||||
"content": "回声者_Echoderd正在备战中考",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:12:38.588749"
|
||||
},
|
||||
{
|
||||
"id": "know_7_1774772112.714455",
|
||||
"content": "小千使用AI模型进行对话",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:15:12.714455"
|
||||
},
|
||||
{
|
||||
"id": "know_7_1774772154.870238",
|
||||
"content": "用户正在备战中考",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:15:54.870238"
|
||||
},
|
||||
{
|
||||
"id": "know_7_1774773185.194069",
|
||||
"content": "使用 NapCat 框架",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:33:05.194069"
|
||||
},
|
||||
{
|
||||
"id": "know_7_1774773338.851275",
|
||||
"content": "使用 NapCat 框架,具备技术平台认知能力",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:35:38.851275"
|
||||
},
|
||||
{
|
||||
"id": "know_7_1774773371.403696",
|
||||
"content": "熟悉 NapCat 框架",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:36:11.403696"
|
||||
}
|
||||
],
|
||||
"8": [
|
||||
{
|
||||
"id": "know_8_1774770946.624486",
|
||||
"content": "日常逛游戏地图",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T15:55:46.624486"
|
||||
},
|
||||
{
|
||||
"id": "know_8_1774771397.769034",
|
||||
"content": "备考中考期间仍保持日常游戏娱乐习惯",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:03:17.769034"
|
||||
},
|
||||
{
|
||||
"id": "know_8_1774771851.338018",
|
||||
"content": "用户有备考中考的学习任务",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:10:51.338018"
|
||||
},
|
||||
{
|
||||
"id": "know_8_1774771894.523189",
|
||||
"content": "备考中(备战中考)",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:11:34.523189"
|
||||
},
|
||||
{
|
||||
"id": "know_8_1774771993.484733",
|
||||
"content": "小千有打CS:GO的游戏习惯",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:13:13.484733"
|
||||
},
|
||||
{
|
||||
"id": "know_8_1774772079.501334",
|
||||
"content": "有在高压环境下(如中考前)进行游戏娱乐(CS:GO)的习惯,自称或认同“摆烂”的生活态度",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:14:39.501334"
|
||||
},
|
||||
{
|
||||
"id": "know_8_1774772154.875743",
|
||||
"content": "用户在备考期间有打游戏摸鱼的习惯",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:15:54.875743"
|
||||
},
|
||||
{
|
||||
"id": "know_8_1774773435.690121",
|
||||
"content": "习惯使用表情包表达情绪或进行网络互动",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:37:15.690121"
|
||||
},
|
||||
{
|
||||
"id": "know_8_1774773676.701034",
|
||||
"content": "备战中考",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:41:16.701034"
|
||||
}
|
||||
],
|
||||
"9": [],
|
||||
"10": [
|
||||
{
|
||||
"id": "know_10_1774768486.452792",
|
||||
"content": "沟通风格带有调侃和自信,习惯用反问句表达观点",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T15:14:46.452792"
|
||||
},
|
||||
{
|
||||
"id": "know_10_1774768517.121403",
|
||||
"content": "沟通风格带有较强的好胜心和防御性,习惯用反问和调侃回应质疑",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T15:15:17.121403"
|
||||
},
|
||||
{
|
||||
"id": "know_10_1774768573.742824",
|
||||
"content": "沟通风格幽默,擅长使用逻辑闭环和反问句式进行辩论或调侃",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T15:16:13.742824"
|
||||
},
|
||||
{
|
||||
"id": "know_10_1774768612.299126",
|
||||
"content": "沟通风格幽默风趣,擅长使用网络梗和表情包互动",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T15:16:52.299126"
|
||||
},
|
||||
{
|
||||
"id": "know_10_1774768612.299845",
|
||||
"content": "偶尔会文绉绉地表达(自称“文青病犯了”),但能迅速切换回口语化",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T15:16:52.299845"
|
||||
},
|
||||
{
|
||||
"id": "know_10_1774768645.028561",
|
||||
"content": "沟通风格幽默风趣,偶尔会文青病发作使用古风表达",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T15:17:25.028561"
|
||||
},
|
||||
{
|
||||
"id": "know_10_1774769406.249584",
|
||||
"content": "沟通中常使用文言文或半文言表达",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T15:30:06.249584"
|
||||
},
|
||||
{
|
||||
"id": "know_10_1774769406.251097",
|
||||
"content": "习惯用反问句和夸张语气进行互动",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T15:30:06.251097"
|
||||
},
|
||||
{
|
||||
"id": "know_10_1774770487.211056",
|
||||
"content": "沟通风格幽默,常使用网络梗和夸张表达",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T15:48:07.211056"
|
||||
},
|
||||
{
|
||||
"id": "know_10_1774771471.038677",
|
||||
"content": "沟通风格轻松随意,善于接话和调侃",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:04:31.038677"
|
||||
},
|
||||
{
|
||||
"id": "know_10_1774771765.053285",
|
||||
"content": "沟通风格活泼,喜欢使用语气词和表情符号撒娇",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:09:25.053285"
|
||||
},
|
||||
{
|
||||
"id": "know_10_1774772079.503333",
|
||||
"content": "沟通风格幽默调侃,擅长用反话(如“烦到了”)和夸张修辞(如“耳朵起茧子”、“要报警了”)表达情绪",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:14:39.503333"
|
||||
},
|
||||
{
|
||||
"id": "know_10_1774773338.853274",
|
||||
"content": "沟通风格幽默风趣,擅长玩梗与自嘲",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:35:38.853274"
|
||||
},
|
||||
{
|
||||
"id": "know_10_1774773371.408719",
|
||||
"content": "喜欢用幽默调侃的方式回应他人",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:36:11.408719"
|
||||
},
|
||||
{
|
||||
"id": "know_10_1774773401.491209",
|
||||
"content": "沟通风格幽默风趣,擅长玩梗和角色扮演",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:36:41.491209"
|
||||
},
|
||||
{
|
||||
"id": "know_10_1774773435.693121",
|
||||
"content": "沟通风格幽默、喜欢玩梗和自嘲,擅长接话茬",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:37:15.693121"
|
||||
},
|
||||
{
|
||||
"id": "know_10_1774773532.488374",
|
||||
"content": "沟通风格幽默,喜欢使用网络梗和表情包活跃气氛",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:38:52.488374"
|
||||
},
|
||||
{
|
||||
"id": "know_10_1774773532.490959",
|
||||
"content": "在争论中倾向于据理力争,并自嘲或调侃对方阅读理解能力",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:38:52.490959"
|
||||
},
|
||||
{
|
||||
"id": "know_10_1774773569.709356",
|
||||
"content": "喜欢用幽默、夸张和自嘲的方式活跃气氛",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:39:29.709356"
|
||||
}
|
||||
],
|
||||
"11": [
|
||||
{
|
||||
"id": "know_11_1774771068.360999",
|
||||
"content": "乐于接受并学习新的技术技巧(如加速器用法)",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T15:57:48.360999"
|
||||
}
|
||||
],
|
||||
"12": [
|
||||
{
|
||||
"id": "know_12_1774770654.657355",
|
||||
"content": "面对网络延迟问题倾向于寻找加速器解决方案",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T15:50:54.657355"
|
||||
},
|
||||
{
|
||||
"id": "know_12_1774773185.196068",
|
||||
"content": "备战中考",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:33:05.196068"
|
||||
},
|
||||
{
|
||||
"id": "know_12_1774773740.836223",
|
||||
"content": "面对压力或冲突时,倾向于通过撒娇、耍赖和寻求盟友支持来应对",
|
||||
"metadata": {
|
||||
"session_id": "628336b082552269377e9d0648e26c60",
|
||||
"source": "maisaka_learning"
|
||||
},
|
||||
"created_at": "2026-03-29T16:42:20.836223"
|
||||
}
|
||||
]
|
||||
}
|
||||
30
plugins/MaiBot_MCPBridgePlugin/.gitignore
vendored
30
plugins/MaiBot_MCPBridgePlugin/.gitignore
vendored
@@ -1,30 +0,0 @@
|
||||
# 运行时配置(包含用户敏感信息)
|
||||
config.toml
|
||||
|
||||
# 备份文件
|
||||
*.backup.*
|
||||
*.bak
|
||||
|
||||
# 日志
|
||||
logs/
|
||||
*.log
|
||||
*.jsonl
|
||||
|
||||
# Python 缓存
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
*.so
|
||||
|
||||
# 本地测试脚本(仓库不提交)
|
||||
test_*.py
|
||||
|
||||
# IDE
|
||||
.idea/
|
||||
.vscode/
|
||||
*.swp
|
||||
*.swo
|
||||
|
||||
# 系统文件
|
||||
.DS_Store
|
||||
Thumbs.db
|
||||
@@ -1,24 +0,0 @@
|
||||
# Changelog
|
||||
|
||||
本文件记录 `MaiBot_MCPBridgePlugin` 的用户可感知变更。
|
||||
|
||||
## 2.0.0
|
||||
|
||||
- 配置入口统一:MCP 服务器仅使用 Claude Desktop `mcpServers` JSON(`servers.claude_config_json`)
|
||||
- 兼容迁移:自动识别旧版 `servers.list` 并迁移为 `mcpServers`(需在 WebUI 保存一次固化)
|
||||
- 保持功能不变:保留 Workflow(硬流程/工具链)与 ReAct(软流程)双轨制能力
|
||||
- 精简实现:移除旧的 WebUI 导入导出/快速添加服务器实现与 `tomlkit` 依赖
|
||||
- 易用性:完善 Workflow 变量替换(支持数组下标与 bracket 写法),并优化 WebUI 配置区顺序
|
||||
|
||||
## 1.9.0
|
||||
|
||||
- 双轨制架构:ReAct(软流程)+ Workflow(硬流程/工具链)
|
||||
|
||||
## 1.8.0
|
||||
|
||||
- Workflow(工具链):多工具顺序执行、变量替换、自定义 Workflow 并注册为组合工具
|
||||
|
||||
## 1.7.0
|
||||
|
||||
- 断路器模式、状态刷新、工具搜索等易用性增强
|
||||
|
||||
@@ -1,356 +0,0 @@
|
||||
# MCP 桥接插件开发文档
|
||||
|
||||
本文档面向开发者,介绍插件的架构设计、核心模块和扩展方式。
|
||||
|
||||
## 架构概览
|
||||
|
||||
```
|
||||
MaiBot_MCPBridgePlugin/
|
||||
├── plugin.py # 主插件文件,包含所有核心逻辑
|
||||
├── mcp_client.py # MCP 客户端封装
|
||||
├── tool_chain.py # 工具链(Workflow)模块
|
||||
├── core/
|
||||
│ └── claude_config.py # Claude Desktop mcpServers 解析/迁移
|
||||
├── config.toml # 运行时配置
|
||||
└── _manifest.json # 插件元数据
|
||||
```
|
||||
|
||||
## 核心模块
|
||||
|
||||
### 1. MCP 客户端 (`mcp_client.py`)
|
||||
|
||||
封装了与 MCP 服务器的通信逻辑。
|
||||
|
||||
```python
|
||||
from .mcp_client import mcp_manager, MCPServerConfig, TransportType
|
||||
|
||||
# 添加服务器
|
||||
config = MCPServerConfig(
|
||||
name="my-server",
|
||||
transport=TransportType.STREAMABLE_HTTP,
|
||||
url="https://mcp.example.com/mcp"
|
||||
)
|
||||
await mcp_manager.add_server(config)
|
||||
|
||||
# 调用工具
|
||||
result = await mcp_manager.call_tool("server_tool_name", {"param": "value"})
|
||||
if result.success:
|
||||
print(result.content)
|
||||
```
|
||||
|
||||
**支持的传输类型:**
|
||||
- `STDIO`: 本地进程通信
|
||||
- `SSE`: Server-Sent Events
|
||||
- `HTTP`: HTTP 请求
|
||||
- `STREAMABLE_HTTP`: 流式 HTTP(推荐)
|
||||
|
||||
### 2. 工具注册系统
|
||||
|
||||
MCP 工具通过动态类创建注册到 MaiBot:
|
||||
|
||||
```python
|
||||
# 创建工具代理类
|
||||
class MCPToolProxy(BaseTool):
|
||||
name = "mcp_server_tool"
|
||||
description = "工具描述"
|
||||
parameters = [("param", ToolParamType.STRING, "参数描述", True, None)]
|
||||
available_for_llm = True
|
||||
|
||||
async def execute(self, function_args):
|
||||
result = await mcp_manager.call_tool(self._mcp_tool_key, function_args)
|
||||
return {"name": self.name, "content": result.content}
|
||||
```
|
||||
|
||||
### 3. 工具链模块 (`tool_chain.py`)
|
||||
|
||||
实现 Workflow 硬流程,支持多工具顺序执行。
|
||||
|
||||
```python
|
||||
from .tool_chain import ToolChainDefinition, ToolChainStep, tool_chain_manager
|
||||
|
||||
# 定义工具链
|
||||
chain = ToolChainDefinition(
|
||||
name="search_and_detail",
|
||||
description="搜索并获取详情",
|
||||
input_params={"query": "搜索关键词"},
|
||||
steps=[
|
||||
ToolChainStep(
|
||||
tool_name="mcp_server_search",
|
||||
args_template={"keyword": "${input.query}"},
|
||||
output_key="search_result"
|
||||
),
|
||||
ToolChainStep(
|
||||
tool_name="mcp_server_detail",
|
||||
args_template={"id": "${prev}"}
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
# 注册并执行
|
||||
tool_chain_manager.add_chain(chain)
|
||||
result = await tool_chain_manager.execute_chain("search_and_detail", {"query": "test"})
|
||||
```
|
||||
|
||||
**变量替换语法:**
|
||||
- `${input.参数名}`: 用户输入
|
||||
- `${step.输出键}`: 指定步骤的输出
|
||||
- `${prev}`: 上一步输出
|
||||
- `${prev.字段}`: 上一步输出(JSON)的字段
|
||||
- `${step.geo.return.0.location}` / `${step.geo.return[0].location}`: 数组下标访问
|
||||
- `${step.geo['return'][0]['location']}`: bracket 写法(最通用)
|
||||
|
||||
## 双轨制架构
|
||||
|
||||
### ReAct 软流程
|
||||
|
||||
将 MCP 工具注册到 MaiBot 的记忆检索 ReAct 系统,LLM 自主决策调用。
|
||||
|
||||
```python
|
||||
def _register_tools_to_react(self) -> int:
|
||||
from src.memory_system.retrieval_tools import register_memory_retrieval_tool
|
||||
|
||||
def make_execute_func(tool_key: str):
|
||||
async def execute_func(**kwargs) -> str:
|
||||
result = await mcp_manager.call_tool(tool_key, kwargs)
|
||||
return result.content if result.success else f"失败: {result.error}"
|
||||
return execute_func
|
||||
|
||||
register_memory_retrieval_tool(
|
||||
name="mcp_tool_name",
|
||||
description="工具描述",
|
||||
parameters=[{"name": "param", "type": "string", "required": True}],
|
||||
execute_func=make_execute_func("tool_key")
|
||||
)
|
||||
```
|
||||
|
||||
### Workflow 硬流程
|
||||
|
||||
用户预定义的固定执行流程,注册为组合工具。
|
||||
|
||||
```python
|
||||
def _register_tool_chains(self) -> None:
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
|
||||
for chain_name, chain in tool_chain_manager.get_enabled_chains().items():
|
||||
info, tool_class = tool_chain_registry.register_chain(chain)
|
||||
info.plugin_name = self.plugin_name
|
||||
component_registry.register_component(info, tool_class)
|
||||
```
|
||||
|
||||
## 配置系统
|
||||
|
||||
### MCP 服务器配置(Claude Desktop 规范)
|
||||
|
||||
插件只接受 Claude Desktop 的 `mcpServers` JSON(见 `core/claude_config.py`)。配置入口统一为:
|
||||
|
||||
- WebUI/配置文件:`[servers].claude_config_json`
|
||||
- 命令:`/mcp import`(合并 `mcpServers`)与 `/mcp export`(导出当前 `mcpServers`)
|
||||
|
||||
兼容迁移:
|
||||
- 若检测到旧版 `servers.list`,会自动迁移为 `servers.claude_config_json`(仅迁移到内存配置,需 WebUI 保存一次固化)。
|
||||
|
||||
### WebUI 配置 Schema
|
||||
|
||||
使用 `ConfigField` 定义 WebUI 配置项:
|
||||
|
||||
```python
|
||||
config_schema = {
|
||||
"section_name": {
|
||||
"field_name": ConfigField(
|
||||
type=str, # 类型: str, bool, int, float
|
||||
default="default_value", # 默认值
|
||||
description="字段描述",
|
||||
label="显示标签",
|
||||
input_type="textarea", # 输入类型: text, textarea, password
|
||||
rows=5, # textarea 行数
|
||||
disabled=True, # 只读
|
||||
choices=["a", "b"], # 下拉选项
|
||||
hint="提示信息",
|
||||
order=1, # 排序
|
||||
),
|
||||
},
|
||||
}
|
||||
```
|
||||
|
||||
### 配置读取
|
||||
|
||||
```python
|
||||
# 在组件中读取配置
|
||||
value = self.get_config("section.key", default="fallback")
|
||||
|
||||
# 在插件类中读取
|
||||
value = self.config.get("section", {}).get("key", "default")
|
||||
```
|
||||
|
||||
## 事件处理
|
||||
|
||||
### 启动事件
|
||||
|
||||
```python
|
||||
class MCPStartupHandler(BaseEventHandler):
|
||||
event_type = EventType.ON_START
|
||||
handler_name = "mcp_startup"
|
||||
|
||||
async def execute(self, message):
|
||||
global _plugin_instance
|
||||
if _plugin_instance:
|
||||
await _plugin_instance._async_connect_servers()
|
||||
return (True, True, None, None, None)
|
||||
```
|
||||
|
||||
### 停止事件
|
||||
|
||||
```python
|
||||
class MCPStopHandler(BaseEventHandler):
|
||||
event_type = EventType.ON_STOP
|
||||
handler_name = "mcp_stop"
|
||||
|
||||
async def execute(self, message):
|
||||
await mcp_manager.shutdown()
|
||||
return (True, True, None, None, None)
|
||||
```
|
||||
|
||||
## 命令系统
|
||||
|
||||
```python
|
||||
class MCPStatusCommand(BaseCommand):
|
||||
command_name = "mcp_status"
|
||||
command_pattern = r"^/mcp(?:\s+(?P<action>\S+))?(?:\s+(?P<arg>.+))?$"
|
||||
|
||||
async def execute(self) -> Tuple[bool, str, bool]:
|
||||
action = self.matched_groups.get("action", "")
|
||||
arg = self.matched_groups.get("arg", "")
|
||||
|
||||
if action == "tools":
|
||||
await self.send_text("工具列表...")
|
||||
elif action == "reconnect":
|
||||
await self._handle_reconnect(arg)
|
||||
|
||||
return (True, None, True) # (成功, 消息, 拦截)
|
||||
```
|
||||
|
||||
## 高级功能
|
||||
|
||||
### 调用追踪
|
||||
|
||||
```python
|
||||
from plugin import tool_call_tracer, ToolCallRecord
|
||||
|
||||
# 记录调用
|
||||
record = ToolCallRecord(
|
||||
call_id="xxx",
|
||||
timestamp=time.time(),
|
||||
tool_name="tool",
|
||||
server_name="server",
|
||||
arguments={"key": "value"},
|
||||
success=True,
|
||||
duration_ms=100.0
|
||||
)
|
||||
tool_call_tracer.record(record)
|
||||
|
||||
# 查询记录
|
||||
recent = tool_call_tracer.get_recent(10)
|
||||
by_tool = tool_call_tracer.get_by_tool("tool_name")
|
||||
```
|
||||
|
||||
### 调用缓存
|
||||
|
||||
```python
|
||||
from plugin import tool_call_cache
|
||||
|
||||
# 配置缓存
|
||||
tool_call_cache.configure(
|
||||
enabled=True,
|
||||
ttl=300, # 秒
|
||||
max_entries=200,
|
||||
exclude_tools="mcp_*_time_*" # 排除模式
|
||||
)
|
||||
|
||||
# 使用缓存
|
||||
cached = tool_call_cache.get("tool_name", {"param": "value"})
|
||||
if cached is None:
|
||||
result = await call_tool(...)
|
||||
tool_call_cache.set("tool_name", {"param": "value"}, result)
|
||||
```
|
||||
|
||||
### 权限控制
|
||||
|
||||
```python
|
||||
from plugin import permission_checker
|
||||
|
||||
# 配置权限
|
||||
permission_checker.configure(
|
||||
enabled=True,
|
||||
default_mode="allow_all", # 或 "deny_all"
|
||||
rules_json='[{"tool": "mcp_*_delete_*", "denied": ["qq:123:group"]}]',
|
||||
quick_deny_groups="123456789",
|
||||
quick_allow_users="111111111"
|
||||
)
|
||||
|
||||
# 检查权限
|
||||
allowed = permission_checker.check(
|
||||
tool_name="mcp_server_delete",
|
||||
chat_id="123456",
|
||||
user_id="789",
|
||||
is_group=True
|
||||
)
|
||||
```
|
||||
|
||||
### 断路器模式
|
||||
|
||||
MCP 客户端内置断路器,故障服务器快速失败:
|
||||
|
||||
- 连续失败 N 次后熔断
|
||||
- 熔断期间直接返回错误
|
||||
- 定期尝试恢复
|
||||
|
||||
## 扩展开发
|
||||
|
||||
### 添加新的传输类型
|
||||
|
||||
1. 在 `mcp_client.py` 中添加 `TransportType` 枚举值
|
||||
2. 实现对应的连接逻辑
|
||||
3. 更新 `_create_transport()` 方法
|
||||
|
||||
### 添加新的工具类型
|
||||
|
||||
1. 继承 `BaseTool` 创建新类
|
||||
2. 在 `get_plugin_components()` 中注册
|
||||
3. 实现 `execute()` 方法
|
||||
|
||||
### 添加新的命令
|
||||
|
||||
1. 在 `MCPStatusCommand.execute()` 中添加新的 action 分支
|
||||
2. 或创建新的 `BaseCommand` 子类
|
||||
|
||||
## 调试技巧
|
||||
|
||||
### 日志级别
|
||||
|
||||
```python
|
||||
from src.common.logger import get_logger
|
||||
logger = get_logger("mcp_bridge_plugin")
|
||||
|
||||
logger.debug("详细调试信息")
|
||||
logger.info("一般信息")
|
||||
logger.warning("警告")
|
||||
logger.error("错误")
|
||||
```
|
||||
|
||||
### 常用调试命令
|
||||
|
||||
```bash
|
||||
/mcp # 查看状态
|
||||
/mcp tools # 查看工具列表
|
||||
/mcp trace # 查看调用记录
|
||||
/mcp cache # 查看缓存状态
|
||||
/mcp chain # 查看工具链
|
||||
```
|
||||
|
||||
## 更新日志
|
||||
|
||||
见 `plugins/MaiBot_MCPBridgePlugin/CHANGELOG.md`
|
||||
|
||||
## 开发约定
|
||||
|
||||
- 本仓库不提交测试脚本/临时复现文件;如需本地验证,可自行在工作区创建未跟踪文件(建议放到 `.local/` 并加入 `.gitignore`)。
|
||||
@@ -1,357 +0,0 @@
|
||||
# MCP 桥接插件
|
||||
|
||||
将 [MCP (Model Context Protocol)](https://modelcontextprotocol.io/) 服务器的工具桥接到 MaiBot,使麦麦能够调用外部 MCP 工具。
|
||||
|
||||
<img width="3012" height="1794" alt="image" src="https://github.com/user-attachments/assets/ece56404-301a-4abf-b16d-87bd430fc977" />
|
||||
|
||||
## 🚀 快速开始
|
||||
|
||||
### 1. 安装
|
||||
|
||||
```bash
|
||||
# 克隆到 MaiBot 插件目录
|
||||
cd /path/to/MaiBot/plugins
|
||||
git clone https://github.com/CharTyr/MaiBot_MCPBridgePlugin.git MCPBridgePlugin
|
||||
|
||||
# 安装依赖
|
||||
pip install mcp
|
||||
|
||||
# 复制配置文件
|
||||
cd MCPBridgePlugin
|
||||
cp config.example.toml config.toml
|
||||
```
|
||||
|
||||
### 2. 添加服务器
|
||||
|
||||
编辑 `config.toml`,在 `[servers]` 的 `claude_config_json` 中填写 Claude Desktop 的 `mcpServers` JSON:
|
||||
|
||||
```toml
|
||||
[servers]
|
||||
claude_config_json = '''
|
||||
{
|
||||
"mcpServers": {
|
||||
"time": { "transport": "streamable_http", "url": "https://mcp.api-inference.modelscope.cn/server/mcp-server-time" },
|
||||
"my-server": { "transport": "streamable_http", "url": "https://mcp.xxx.com/mcp", "headers": { "Authorization": "Bearer 你的密钥" } },
|
||||
"fetch": { "command": "uvx", "args": ["mcp-server-fetch"] }
|
||||
}
|
||||
}
|
||||
'''
|
||||
```
|
||||
|
||||
### 3. 启动
|
||||
|
||||
重启 MaiBot,或发送 `/mcp reconnect`
|
||||
|
||||
---
|
||||
|
||||
## 📚 去哪找 MCP 服务器?
|
||||
|
||||
| 平台 | 说明 |
|
||||
|------|------|
|
||||
| [mcp.modelscope.cn](https://mcp.modelscope.cn/) | 魔搭 ModelScope,免费推荐 |
|
||||
| [smithery.ai](https://smithery.ai/) | MCP 服务器注册中心 |
|
||||
| [github.com/modelcontextprotocol/servers](https://github.com/modelcontextprotocol/servers) | 官方服务器列表 |
|
||||
|
||||
---
|
||||
|
||||
## 💡 常用命令
|
||||
|
||||
| 命令 | 说明 |
|
||||
|------|------|
|
||||
| `/mcp` | 查看连接状态 |
|
||||
| `/mcp tools` | 查看可用工具 |
|
||||
| `/mcp reconnect` | 重连服务器 |
|
||||
| `/mcp trace` | 查看调用记录 |
|
||||
| `/mcp cache` | 查看缓存状态 |
|
||||
| `/mcp perm` | 查看权限配置 |
|
||||
| `/mcp import <json>` | 🆕 导入 Claude Desktop 配置 |
|
||||
| `/mcp export` | 🆕 导出配置 |
|
||||
| `/mcp search <关键词>` | 🆕 搜索工具 |
|
||||
| `/mcp chain` | 🆕 查看工具链 |
|
||||
| `/mcp chain <名称>` | 🆕 查看工具链详情 |
|
||||
| `/mcp chain test <名称> <参数>` | 🆕 测试执行工具链 |
|
||||
|
||||
---
|
||||
|
||||
## ✨ 功能特性
|
||||
|
||||
### 核心功能
|
||||
- 🔌 多服务器同时连接
|
||||
- 📡 支持 stdio / SSE / HTTP / Streamable HTTP
|
||||
- 🔄 自动重试、心跳检测、断线重连
|
||||
- 🖥️ WebUI 完整配置支持
|
||||
|
||||
### 双轨制架构
|
||||
- 🔄 **ReAct(软流程)**:LLM 自主决策,多轮动态调用 MCP 工具(适合探索式场景)
|
||||
- 🔗 **Workflow(硬流程/工具链)**:用户预定义步骤顺序与参数传递(适合可控可复用场景)
|
||||
|
||||
### 高级功能
|
||||
- 📦 Resources 支持(实验性)
|
||||
- 📝 Prompts 支持(实验性)
|
||||
- 🔄 结果后处理(LLM 摘要提炼)
|
||||
- 🔍 调用追踪 / 🗄️ 调用缓存 / 🔐 权限控制 / 🚫 工具禁用
|
||||
|
||||
### 更新日志
|
||||
- 见 `plugins/MaiBot_MCPBridgePlugin/CHANGELOG.md`
|
||||
|
||||
---
|
||||
|
||||
## ⚙️ 配置说明
|
||||
|
||||
### 服务器配置
|
||||
|
||||
```json
|
||||
{
|
||||
"mcpServers": {
|
||||
"server_name": {
|
||||
"transport": "streamable_http",
|
||||
"url": "https://..."
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
| 字段 | 说明 |
|
||||
|------|------|
|
||||
| `mcpServers.<name>` | 服务器名称(唯一) |
|
||||
| `enabled` | 是否启用(可选,默认 true) |
|
||||
| `transport` | `stdio` / `sse` / `http` / `streamable_http` |
|
||||
| `url` | 远程服务器地址 |
|
||||
| `headers` | 🆕 鉴权头(如 `{"Authorization": "Bearer xxx"}`) |
|
||||
| `command` / `args` | 本地服务器启动命令 |
|
||||
|
||||
### 权限控制
|
||||
|
||||
**快捷配置(推荐):**
|
||||
```toml
|
||||
[permissions]
|
||||
perm_enabled = true
|
||||
quick_deny_groups = "123456789" # 禁用的群号
|
||||
quick_allow_users = "111111111" # 管理员白名单
|
||||
```
|
||||
|
||||
**高级规则:**
|
||||
```json
|
||||
[{"tool": "mcp_*_delete_*", "denied": ["qq:123456:group"]}]
|
||||
```
|
||||
|
||||
### 工具禁用
|
||||
|
||||
```toml
|
||||
[tools]
|
||||
disabled_tools = '''
|
||||
mcp_filesystem_delete_file
|
||||
mcp_filesystem_write_file
|
||||
'''
|
||||
```
|
||||
|
||||
### 调用缓存
|
||||
|
||||
```toml
|
||||
[settings]
|
||||
cache_enabled = true
|
||||
cache_ttl = 300
|
||||
cache_exclude_tools = "mcp_*_time_*"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## ❓ 常见问题
|
||||
|
||||
**Q: 工具没有注册?**
|
||||
- 检查 `enabled = true`
|
||||
- 检查 MaiBot 日志错误信息
|
||||
- 确认 `pip install mcp`
|
||||
|
||||
**Q: JSON 格式报错?**
|
||||
- 多行 JSON 用 `'''` 三引号包裹
|
||||
- 使用英文双引号 `"`
|
||||
|
||||
**Q: 如何手动重连?**
|
||||
- `/mcp reconnect` 或 `/mcp reconnect 服务器名`
|
||||
|
||||
---
|
||||
|
||||
## 📥 配置导入导出(Claude mcpServers)
|
||||
|
||||
### 从 Claude Desktop 导入
|
||||
|
||||
如果你已有 Claude Desktop 的 MCP 配置,可以直接导入:
|
||||
|
||||
```
|
||||
/mcp import {"mcpServers":{"time":{"command":"uvx","args":["mcp-server-time"]},"fetch":{"command":"uvx","args":["mcp-server-fetch"]}}}
|
||||
```
|
||||
|
||||
支持的格式:
|
||||
- Claude Desktop 格式(`mcpServers` 对象)
|
||||
- 兼容旧版:MaiBot servers 列表数组(将自动迁移为 `mcpServers`)
|
||||
|
||||
### 导出配置
|
||||
|
||||
```
|
||||
/mcp export # 导出为 Claude Desktop 格式(默认)
|
||||
/mcp export claude # 导出为 Claude Desktop 格式
|
||||
```
|
||||
|
||||
### 注意事项
|
||||
- 导入时会自动跳过同名服务器
|
||||
- 导入后需要发送 `/mcp reconnect` 使配置生效
|
||||
- 支持 stdio、sse、http、streamable_http 全部传输类型
|
||||
|
||||
---
|
||||
|
||||
## 🔗 Workflow(硬流程/工具链)
|
||||
|
||||
工具链允许你将多个 MCP 工具按顺序执行,后续工具可以使用前序工具的输出作为输入。
|
||||
|
||||
### 1 分钟上手(推荐 WebUI)
|
||||
1. 先完成 MCP 服务器配置并 `/mcp reconnect`
|
||||
2. 发送 `/mcp tools`,复制你要用的工具名
|
||||
3. 打开 WebUI → 「Workflow(硬流程/工具链)」→ 用“快速添加”表单填入:
|
||||
- 名称/描述
|
||||
- 输入参数(每行 `参数名=描述`)
|
||||
- 执行步骤(每行 `工具名|参数JSON|输出键`)
|
||||
4. 在“确认添加”中输入 `ADD` 并保存
|
||||
|
||||
### 快速添加工具链(推荐)
|
||||
|
||||
在 WebUI 的「工具链」配置区,使用表单快速添加:
|
||||
|
||||
1. **名称**: 填写工具链名称(英文,如 `search_and_detail`)
|
||||
2. **描述**: 填写工具链用途(供 LLM 理解何时使用)
|
||||
3. **输入参数**: 每行一个,格式 `参数名=描述`
|
||||
```
|
||||
query=搜索关键词
|
||||
max_results=最大结果数
|
||||
```
|
||||
4. **执行步骤**: 每行一个,格式 `工具名|参数JSON|输出键`
|
||||
```
|
||||
mcp_server_search|{"keyword":"${input.query}"}|search_result
|
||||
mcp_server_detail|{"id":"${prev}"}|
|
||||
```
|
||||
5. **确认添加**: 输入 `ADD` 并保存
|
||||
|
||||
### JSON 配置方式
|
||||
|
||||
也可以直接在「工具链列表」中编写 JSON:
|
||||
|
||||
```json
|
||||
[
|
||||
{
|
||||
"name": "search_and_detail",
|
||||
"description": "先搜索模组,再获取详情",
|
||||
"input_params": {
|
||||
"query": "搜索关键词"
|
||||
},
|
||||
"steps": [
|
||||
{
|
||||
"tool_name": "mcp_mcmod_search_mod",
|
||||
"args_template": {"keyword": "${input.query}", "limit": 1},
|
||||
"output_key": "search_result",
|
||||
"description": "搜索模组"
|
||||
},
|
||||
{
|
||||
"tool_name": "mcp_mcmod_get_mod_detail",
|
||||
"args_template": {"mod_id": "${prev}"},
|
||||
"description": "获取详情"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
### 变量替换
|
||||
|
||||
| 变量格式 | 说明 |
|
||||
|---------|------|
|
||||
| `${input.参数名}` | 用户输入的参数 |
|
||||
| `${step.输出键}` | 某个步骤的输出(通过 `output_key` 指定) |
|
||||
| `${prev}` | 上一步的输出 |
|
||||
| `${prev.字段}` | 上一步输出(JSON)的某个字段 |
|
||||
| `${step.geo.return.0.location}` | 数组下标访问(dot) |
|
||||
| `${step.geo.return[0].location}` | 数组下标访问([]) |
|
||||
| `${step.geo['return'][0]['location']}` | bracket 写法(最通用) |
|
||||
|
||||
### 工具链字段说明
|
||||
|
||||
| 字段 | 说明 |
|
||||
|------|------|
|
||||
| `name` | 工具链名称,将生成 `chain_xxx` 工具 |
|
||||
| `description` | 描述,供 LLM 理解何时使用 |
|
||||
| `input_params` | 输入参数定义 `{参数名: 描述}` |
|
||||
| `steps` | 执行步骤数组 |
|
||||
| `steps[].tool_name` | 要调用的工具名 |
|
||||
| `steps[].args_template` | 参数模板,支持变量替换 |
|
||||
| `steps[].output_key` | 输出存储键名(可选) |
|
||||
| `steps[].optional` | 是否可选,失败时继续执行(默认 false) |
|
||||
|
||||
### 命令
|
||||
|
||||
```bash
|
||||
/mcp chain # 查看所有工具链
|
||||
/mcp chain list # 列出工具链
|
||||
/mcp chain <名称> # 查看详情
|
||||
/mcp chain test <名称> {"query": "JEI"} # 测试执行
|
||||
/mcp chain reload # 重新加载配置
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 🔄 双轨制架构
|
||||
|
||||
MCP 桥接插件支持两种工具调用模式,可根据场景选择:
|
||||
|
||||
### ReAct 软流程
|
||||
|
||||
LLM 自主决策的多轮工具调用模式,适合复杂、不确定的场景。
|
||||
|
||||
**工作原理:**
|
||||
1. 用户提问 → LLM 分析需要什么信息
|
||||
2. LLM 选择调用工具 → 获取结果
|
||||
3. LLM 观察结果 → 决定是否需要更多信息
|
||||
4. 重复 2-3 直到信息足够 → 生成最终回答
|
||||
|
||||
**启用方式:**
|
||||
在 WebUI「ReAct (软流程)」配置区启用,MCP 工具将自动注册到 MaiBot 的记忆检索 ReAct 系统。
|
||||
|
||||
**适用场景:**
|
||||
- 复杂问题需要多步推理
|
||||
- 不确定需要调用哪些工具
|
||||
- 需要根据中间结果动态调整
|
||||
|
||||
### Workflow 硬流程
|
||||
|
||||
用户预定义的工作流,固定执行顺序,适合可靠、可控的场景。
|
||||
|
||||
**工作原理:**
|
||||
1. 用户定义步骤顺序和参数传递
|
||||
2. 按顺序执行每个步骤
|
||||
3. 后续步骤可使用前序步骤的输出
|
||||
4. 返回最终结果
|
||||
|
||||
**适用场景:**
|
||||
- 流程固定、可预测
|
||||
- 需要可靠、可重复的执行
|
||||
- 希望精确控制工具调用顺序
|
||||
|
||||
### 对比
|
||||
|
||||
| 特性 | ReAct 软流程 | Workflow 硬流程 |
|
||||
|------|-------------|----------------|
|
||||
| 决策者 | LLM 自主决策 | 用户预定义 |
|
||||
| 灵活性 | 高,动态调整 | 低,固定流程 |
|
||||
| 可预测性 | 低 | 高 |
|
||||
| 适用场景 | 复杂、探索性任务 | 固定、重复性任务 |
|
||||
| 配置方式 | 启用即可 | 需要定义步骤 |
|
||||
|
||||
---
|
||||
|
||||
## 📋 依赖
|
||||
|
||||
- MaiBot >= 0.11.6
|
||||
- Python >= 3.10
|
||||
- mcp >= 1.0.0
|
||||
|
||||
## 📄 许可证
|
||||
|
||||
AGPL-3.0
|
||||
@@ -1,44 +0,0 @@
|
||||
"""
|
||||
MCP 桥接插件
|
||||
将 MCP (Model Context Protocol) 服务器的工具桥接到 MaiBot
|
||||
|
||||
v1.1.0 新增功能:
|
||||
- 心跳检测和自动重连
|
||||
- 调用统计(次数、成功率、耗时)
|
||||
- 更好的错误处理
|
||||
|
||||
v1.2.0 新增功能:
|
||||
- Resources 支持(资源读取)
|
||||
- Prompts 支持(提示模板)
|
||||
"""
|
||||
|
||||
from .plugin import MCPBridgePlugin, mcp_tool_registry, MCPStartupHandler, MCPStopHandler
|
||||
from .mcp_client import (
|
||||
mcp_manager,
|
||||
MCPClientManager,
|
||||
MCPServerConfig,
|
||||
TransportType,
|
||||
MCPCallResult,
|
||||
MCPToolInfo,
|
||||
MCPResourceInfo,
|
||||
MCPPromptInfo,
|
||||
ToolCallStats,
|
||||
ServerStats,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"MCPBridgePlugin",
|
||||
"mcp_tool_registry",
|
||||
"mcp_manager",
|
||||
"MCPClientManager",
|
||||
"MCPServerConfig",
|
||||
"TransportType",
|
||||
"MCPCallResult",
|
||||
"MCPToolInfo",
|
||||
"MCPResourceInfo",
|
||||
"MCPPromptInfo",
|
||||
"ToolCallStats",
|
||||
"ServerStats",
|
||||
"MCPStartupHandler",
|
||||
"MCPStopHandler",
|
||||
]
|
||||
@@ -1,42 +0,0 @@
|
||||
{
|
||||
"manifest_version": 2,
|
||||
"version": "2.0.0",
|
||||
"name": "MCP桥接插件",
|
||||
"description": "将 MCP (Model Context Protocol) 服务器的工具桥接到 MaiBot,使麦麦能够调用外部 MCP 工具。",
|
||||
"author": {
|
||||
"name": "CharTyr",
|
||||
"url": "https://github.com/CharTyr"
|
||||
},
|
||||
"license": "AGPL-3.0",
|
||||
"urls": {
|
||||
"repository": "https://github.com/CharTyr/MaiBot_MCPBridgePlugin",
|
||||
"homepage": "https://github.com/CharTyr/MaiBot_MCPBridgePlugin",
|
||||
"documentation": "https://github.com/CharTyr/MaiBot_MCPBridgePlugin",
|
||||
"issues": "https://github.com/CharTyr/MaiBot_MCPBridgePlugin/issues"
|
||||
},
|
||||
"host_application": {
|
||||
"min_version": "0.11.6",
|
||||
"max_version": "1.0.0"
|
||||
},
|
||||
"sdk": {
|
||||
"min_version": "2.0.0",
|
||||
"max_version": "2.99.99"
|
||||
},
|
||||
"dependencies": [
|
||||
{
|
||||
"type": "python_package",
|
||||
"name": "mcp",
|
||||
"version_spec": ">=0.0.0"
|
||||
}
|
||||
],
|
||||
"capabilities": [
|
||||
"send.text"
|
||||
],
|
||||
"i18n": {
|
||||
"default_locale": "zh-CN",
|
||||
"supported_locales": [
|
||||
"zh-CN"
|
||||
]
|
||||
},
|
||||
"id": "chartyr.mcpbridge-plugin"
|
||||
}
|
||||
@@ -1,309 +0,0 @@
|
||||
# MCP桥接插件 - 配置文件示例
|
||||
# 将 MCP (Model Context Protocol) 服务器的工具桥接到 MaiBot
|
||||
#
|
||||
# 使用方法:复制此文件为 config.toml,然后根据需要修改配置
|
||||
#
|
||||
# ============================================================
|
||||
# 🎯 快速开始(三步)
|
||||
# ============================================================
|
||||
# 1. 在下方 [servers] 添加 MCP 服务器配置
|
||||
# 2. 将 enabled 改为 true 启用服务器
|
||||
# 3. 重启 MaiBot 或发送 /mcp reconnect
|
||||
#
|
||||
# ============================================================
|
||||
# 📚 去哪找 MCP 服务器?
|
||||
# ============================================================
|
||||
#
|
||||
# 【远程服务(推荐新手)】
|
||||
# - ModelScope: https://mcp.modelscope.cn/ (免费,推荐)
|
||||
# - Smithery: https://smithery.ai/
|
||||
# - Glama: https://glama.ai/mcp/servers
|
||||
#
|
||||
# 【本地服务(需要 npx 或 uvx)】
|
||||
# - 官方列表: https://github.com/modelcontextprotocol/servers
|
||||
#
|
||||
# ============================================================
|
||||
|
||||
# ============================================================
|
||||
# 🔌 MCP 服务器配置
|
||||
# ============================================================
|
||||
#
|
||||
# ⚠️ 重要:配置格式(Claude Desktop 规范)
|
||||
# ────────────────────────────────────────────────────────────
|
||||
# 统一使用 Claude Desktop 的 mcpServers JSON。
|
||||
#
|
||||
# claude_config_json 的内容应为 JSON 对象:
|
||||
# {
|
||||
# "mcpServers": {
|
||||
# "server_name": { ...server config... },
|
||||
# "another": { ... }
|
||||
# }
|
||||
# }
|
||||
#
|
||||
# 每个服务器支持字段:
|
||||
# transport - 传输方式: "stdio" / "sse" / "http" / "streamable_http"(可选)
|
||||
# url - 服务器地址(sse/http/streamable_http 模式)
|
||||
# command - 启动命令(stdio 模式,如 "npx" / "uvx")
|
||||
# args - 命令参数数组(stdio 模式)
|
||||
# env - 环境变量对象(stdio 模式,可选)
|
||||
# headers - 鉴权头(可选,如 {"Authorization": "Bearer xxx"})
|
||||
# enabled - 是否启用(可选,默认 true)
|
||||
# post_process - 服务器级别后处理配置(可选)
|
||||
#
|
||||
# ============================================================
|
||||
|
||||
[servers]
|
||||
claude_config_json = '''
|
||||
{
|
||||
"mcpServers": {
|
||||
"time-mcp-server": {
|
||||
"enabled": false,
|
||||
"transport": "streamable_http",
|
||||
"url": "https://mcp.api-inference.modelscope.cn/server/mcp-server-time"
|
||||
},
|
||||
"my-auth-server": {
|
||||
"enabled": false,
|
||||
"transport": "streamable_http",
|
||||
"url": "https://mcp.api-inference.modelscope.net/xxxxxx/mcp",
|
||||
"headers": {
|
||||
"Authorization": "Bearer ms-xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx"
|
||||
}
|
||||
},
|
||||
"fetch-local": {
|
||||
"enabled": false,
|
||||
"command": "uvx",
|
||||
"args": ["mcp-server-fetch"]
|
||||
}
|
||||
}
|
||||
}
|
||||
'''
|
||||
|
||||
# ============================================================
|
||||
# 插件基本信息
|
||||
# ============================================================
|
||||
[plugin]
|
||||
name = "mcp_bridge_plugin"
|
||||
version = "2.0.0"
|
||||
config_version = "2.0.0"
|
||||
enabled = false # 默认禁用,在 WebUI 中启用
|
||||
|
||||
# ============================================================
|
||||
# Workflow(硬流程/工具链)
|
||||
# ============================================================
|
||||
#
|
||||
# 作用:把多个工具按顺序执行;后续步骤可引用前序输出。
|
||||
#
|
||||
# ✅ 推荐配置方式:WebUI「Workflow(硬流程/工具链)」里用“快速添加”表单。
|
||||
# ✅ 也可以直接写 chains_list(JSON 数组)。
|
||||
#
|
||||
# 变量替换:
|
||||
# ${input.xxx} - 用户输入
|
||||
# ${step.<output_key>} - 指定步骤输出(需设置 output_key)
|
||||
# ${prev} - 上一步输出
|
||||
# ${prev.字段} - 上一步输出(JSON)的字段
|
||||
# ${step.geo.return.0.location} - 数组/下标访问(dot)
|
||||
# ${step.geo.return[0].location} - 数组/下标访问([])
|
||||
# ${step.geo['return'][0]['location']} - bracket 写法
|
||||
#
|
||||
# ============================================================
|
||||
|
||||
[tool_chains]
|
||||
chains_enabled = true
|
||||
|
||||
chains_list = '''
|
||||
[
|
||||
{
|
||||
"name": "search_and_detail",
|
||||
"description": "先搜索,再根据结果获取详情",
|
||||
"input_params": { "query": "搜索关键词" },
|
||||
"steps": [
|
||||
{ "tool_name": "把这里替换成你的搜索工具名", "args_template": { "keyword": "${input.query}" }, "output_key": "search" },
|
||||
{ "tool_name": "把这里替换成你的详情工具名", "args_template": { "id": "${prev}" } }
|
||||
]
|
||||
}
|
||||
]
|
||||
'''
|
||||
|
||||
# ============================================================
|
||||
# ReAct(软流程)
|
||||
# ============================================================
|
||||
#
|
||||
# 作用:把 MCP 工具注册到 MaiBot 的 ReAct 系统,LLM 可自主多轮调用。
|
||||
#
|
||||
# 注意:ReAct 适合“探索式/不确定”场景;Workflow 适合“固定/可控”场景。
|
||||
#
|
||||
# ============================================================
|
||||
|
||||
[react]
|
||||
react_enabled = false
|
||||
filter_mode = "whitelist" # whitelist / blacklist
|
||||
tool_filter = "" # 每行一个工具名,支持通配符 *
|
||||
|
||||
# ============================================================
|
||||
# 全局设置(高级设置建议保持默认)
|
||||
# ============================================================
|
||||
[settings]
|
||||
# 🏷️ 工具前缀 - 用于区分 MCP 工具和原生工具
|
||||
tool_prefix = "mcp"
|
||||
|
||||
# ⏱️ 连接超时(秒)
|
||||
connect_timeout = 30.0
|
||||
|
||||
# ⏱️ 调用超时(秒)
|
||||
call_timeout = 60.0
|
||||
|
||||
# 🔄 自动连接 - 启动时自动连接所有已启用的服务器
|
||||
auto_connect = true
|
||||
|
||||
# 🔁 重试次数 - 连接失败时的重试次数
|
||||
retry_attempts = 3
|
||||
|
||||
# ⏳ 重试间隔(秒)
|
||||
retry_interval = 5.0
|
||||
|
||||
# 💓 心跳检测 - 定期检测服务器连接状态
|
||||
heartbeat_enabled = true
|
||||
|
||||
# 💓 心跳间隔(秒)- 建议 30-120 秒
|
||||
heartbeat_interval = 60.0
|
||||
|
||||
# 🔄 自动重连 - 检测到断开时自动尝试重连
|
||||
auto_reconnect = true
|
||||
|
||||
# 🔄 最大重连次数 - 连续重连失败后暂停重连
|
||||
max_reconnect_attempts = 3
|
||||
|
||||
# ============================================================
|
||||
# 高级功能(实验性)
|
||||
# ============================================================
|
||||
# 📦 启用 Resources - 允许读取 MCP 服务器提供的资源
|
||||
enable_resources = false
|
||||
|
||||
# 📝 启用 Prompts - 允许使用 MCP 服务器提供的提示模板
|
||||
enable_prompts = false
|
||||
|
||||
# ============================================================
|
||||
# 结果后处理功能
|
||||
# ============================================================
|
||||
# 当 MCP 工具返回的内容过长时,使用 LLM 对结果进行摘要提炼
|
||||
|
||||
# 🔄 启用结果后处理
|
||||
post_process_enabled = false
|
||||
|
||||
# 📏 后处理阈值(字符数)- 结果长度超过此值才触发后处理
|
||||
post_process_threshold = 500
|
||||
|
||||
# 🔢 后处理输出限制 - LLM 摘要输出的最大 token 数
|
||||
post_process_max_tokens = 500
|
||||
|
||||
# 🤖 后处理模型(可选)- 留空则使用 utils 模型组
|
||||
post_process_model = ""
|
||||
|
||||
# 🧠 后处理提示词模板
|
||||
post_process_prompt = '''用户问题:{query}
|
||||
|
||||
工具返回内容:
|
||||
{result}
|
||||
|
||||
请从上述内容中提取与用户问题最相关的关键信息,简洁准确地输出:'''
|
||||
|
||||
# ============================================================
|
||||
# 调用链路追踪
|
||||
# ============================================================
|
||||
# 记录工具调用详情,便于调试和分析
|
||||
|
||||
# 🔍 启用调用追踪
|
||||
trace_enabled = true
|
||||
|
||||
# 📊 追踪记录上限 - 内存中保留的最大记录数
|
||||
trace_max_records = 50
|
||||
|
||||
# 📝 追踪日志文件 - 是否将追踪记录写入日志文件
|
||||
# 启用后记录写入 plugins/MaiBot_MCPBridgePlugin/logs/trace.jsonl
|
||||
trace_log_enabled = false
|
||||
|
||||
# ============================================================
|
||||
# 工具调用缓存
|
||||
# ============================================================
|
||||
# 缓存相同参数的调用结果,减少重复请求
|
||||
|
||||
# 🗄️ 启用调用缓存
|
||||
cache_enabled = false
|
||||
|
||||
# ⏱️ 缓存有效期(秒)
|
||||
cache_ttl = 300
|
||||
|
||||
# 📦 最大缓存条目 - 超出后 LRU 淘汰
|
||||
cache_max_entries = 200
|
||||
|
||||
# 🚫 缓存排除列表 - 即不缓存的工具(每行一个,支持通配符 *)
|
||||
# 时间类、随机类工具建议排除
|
||||
cache_exclude_tools = '''
|
||||
mcp_*_time_*
|
||||
mcp_*_random_*
|
||||
'''
|
||||
|
||||
# ============================================================
|
||||
# 工具管理
|
||||
# ============================================================
|
||||
[tools]
|
||||
# 📋 工具清单(只读)- 启动后自动生成
|
||||
tool_list = "(启动后自动生成)"
|
||||
|
||||
# 🚫 禁用工具列表 - 要禁用的工具名(每行一个)
|
||||
# 从上方工具清单复制工具名,禁用后该工具不会被 LLM 调用
|
||||
# 示例:
|
||||
# disabled_tools = '''
|
||||
# mcp_filesystem_delete_file
|
||||
# mcp_filesystem_write_file
|
||||
# '''
|
||||
disabled_tools = ""
|
||||
|
||||
# ============================================================
|
||||
# 权限控制
|
||||
# ============================================================
|
||||
[permissions]
|
||||
# 🔐 启用权限控制 - 按群/用户限制工具使用
|
||||
perm_enabled = false
|
||||
|
||||
# 📋 默认模式
|
||||
# allow_all: 未配置规则的工具默认允许
|
||||
# deny_all: 未配置规则的工具默认禁止
|
||||
perm_default_mode = "allow_all"
|
||||
|
||||
# ────────────────────────────────────────────────────────────
|
||||
# 🚀 快捷配置(推荐新手使用)
|
||||
# ────────────────────────────────────────────────────────────
|
||||
|
||||
# 🚫 禁用群列表 - 这些群无法使用任何 MCP 工具(每行一个群号)
|
||||
# 示例:
|
||||
# quick_deny_groups = '''
|
||||
# 123456789
|
||||
# 987654321
|
||||
# '''
|
||||
quick_deny_groups = ""
|
||||
|
||||
# ✅ 管理员白名单 - 这些用户始终可以使用所有工具(每行一个QQ号)
|
||||
# 示例:
|
||||
# quick_allow_users = '''
|
||||
# 111111111
|
||||
# '''
|
||||
quick_allow_users = ""
|
||||
|
||||
# ────────────────────────────────────────────────────────────
|
||||
# 📜 高级权限规则(可选,针对特定工具配置)
|
||||
# ────────────────────────────────────────────────────────────
|
||||
# 格式: qq:ID:group/private/user,工具名支持通配符 *
|
||||
# 示例:
|
||||
# perm_rules = '''
|
||||
# [
|
||||
# {"tool": "mcp_*_delete_*", "denied": ["qq:123456:group"]}
|
||||
# ]
|
||||
# '''
|
||||
perm_rules = "[]"
|
||||
|
||||
# ============================================================
|
||||
# 状态显示(只读)
|
||||
# ============================================================
|
||||
[status]
|
||||
connection_status = "未初始化"
|
||||
@@ -1 +0,0 @@
|
||||
"""Core helpers for MCP Bridge Plugin."""
|
||||
@@ -1,169 +0,0 @@
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
|
||||
|
||||
class ClaudeConfigError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
Transport = Literal["stdio", "sse", "http", "streamable_http"]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ClaudeMcpServer:
|
||||
name: str
|
||||
transport: Transport
|
||||
command: str = ""
|
||||
args: List[str] = field(default_factory=list)
|
||||
env: Dict[str, str] = field(default_factory=dict)
|
||||
url: str = ""
|
||||
headers: Dict[str, str] = field(default_factory=dict)
|
||||
enabled: bool = True
|
||||
|
||||
|
||||
def _normalize_transport(value: Optional[str]) -> Transport:
|
||||
if not value:
|
||||
return "streamable_http"
|
||||
v = value.strip().lower().replace("-", "_")
|
||||
if v in ("streamable_http", "streamablehttp", "streamable"):
|
||||
return "streamable_http"
|
||||
if v in ("http",):
|
||||
return "http"
|
||||
if v in ("sse",):
|
||||
return "sse"
|
||||
if v in ("stdio",):
|
||||
return "stdio"
|
||||
raise ClaudeConfigError(f"unsupported transport: {value}")
|
||||
|
||||
|
||||
def _coerce_str_list(value: Any, field_name: str) -> List[str]:
|
||||
if value is None:
|
||||
return []
|
||||
if isinstance(value, list):
|
||||
return [str(v) for v in value]
|
||||
raise ClaudeConfigError(f"{field_name} must be a list")
|
||||
|
||||
|
||||
def _coerce_str_dict(value: Any, field_name: str) -> Dict[str, str]:
|
||||
if value is None:
|
||||
return {}
|
||||
if isinstance(value, dict):
|
||||
return {str(k): str(v) for k, v in value.items()}
|
||||
raise ClaudeConfigError(f"{field_name} must be an object")
|
||||
|
||||
|
||||
def parse_claude_mcp_config(config_json: str) -> List[ClaudeMcpServer]:
|
||||
"""Parse Claude Desktop style MCP config JSON.
|
||||
|
||||
Supported:
|
||||
- Full object: {"mcpServers": {...}}
|
||||
- Direct mapping: {...} treated as mcpServers
|
||||
"""
|
||||
text = (config_json or "").strip()
|
||||
if not text:
|
||||
return []
|
||||
|
||||
try:
|
||||
data = json.loads(text)
|
||||
except json.JSONDecodeError as e:
|
||||
raise ClaudeConfigError(f"invalid JSON: {e}") from e
|
||||
|
||||
if not isinstance(data, dict):
|
||||
raise ClaudeConfigError("config must be a JSON object")
|
||||
|
||||
servers_obj = data.get("mcpServers", data)
|
||||
if not isinstance(servers_obj, dict):
|
||||
raise ClaudeConfigError("mcpServers must be an object")
|
||||
|
||||
servers: List[ClaudeMcpServer] = []
|
||||
for name, raw in servers_obj.items():
|
||||
if not isinstance(name, str) or not name.strip():
|
||||
raise ClaudeConfigError("server name must be a non-empty string")
|
||||
if not isinstance(raw, dict):
|
||||
raise ClaudeConfigError(f"server '{name}' must be an object")
|
||||
|
||||
enabled = bool(raw.get("enabled", True))
|
||||
command = str(raw.get("command", "") or "")
|
||||
url = str(raw.get("url", "") or "")
|
||||
args = _coerce_str_list(raw.get("args"), "args")
|
||||
env = _coerce_str_dict(raw.get("env"), "env")
|
||||
headers = _coerce_str_dict(raw.get("headers"), "headers")
|
||||
|
||||
transport_hint = raw.get("transport", raw.get("type"))
|
||||
|
||||
if command:
|
||||
transport: Transport = "stdio"
|
||||
elif url:
|
||||
try:
|
||||
transport = _normalize_transport(str(transport_hint) if transport_hint is not None else None)
|
||||
except ClaudeConfigError:
|
||||
transport = "streamable_http"
|
||||
else:
|
||||
raise ClaudeConfigError(f"server '{name}' must have either 'command' or 'url'")
|
||||
|
||||
servers.append(
|
||||
ClaudeMcpServer(
|
||||
name=name,
|
||||
transport=transport,
|
||||
command=command,
|
||||
args=args,
|
||||
env=env,
|
||||
url=url,
|
||||
headers=headers,
|
||||
enabled=enabled,
|
||||
)
|
||||
)
|
||||
|
||||
return servers
|
||||
|
||||
|
||||
def legacy_servers_list_to_claude_config(servers_list_json: str) -> str:
|
||||
"""Convert legacy v1.x servers list (JSON array) to Claude mcpServers JSON.
|
||||
|
||||
Legacy item schema:
|
||||
{"name","enabled","transport","url","headers","command","args","env"}
|
||||
"""
|
||||
text = (servers_list_json or "").strip()
|
||||
if not text:
|
||||
return ""
|
||||
try:
|
||||
data = json.loads(text)
|
||||
except json.JSONDecodeError:
|
||||
return ""
|
||||
if isinstance(data, dict):
|
||||
data = [data]
|
||||
if not isinstance(data, list):
|
||||
return ""
|
||||
|
||||
mcp_servers: Dict[str, Any] = {}
|
||||
for item in data:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
name = str(item.get("name", "") or "").strip()
|
||||
if not name:
|
||||
continue
|
||||
enabled = bool(item.get("enabled", True))
|
||||
transport = str(item.get("transport", "") or "").strip().lower().replace("-", "_")
|
||||
|
||||
if transport == "stdio" or item.get("command"):
|
||||
entry: Dict[str, Any] = {
|
||||
"enabled": enabled,
|
||||
"command": item.get("command", "") or "",
|
||||
"args": item.get("args", []) or [],
|
||||
}
|
||||
if item.get("env"):
|
||||
entry["env"] = item.get("env")
|
||||
mcp_servers[name] = entry
|
||||
continue
|
||||
|
||||
entry = {"enabled": enabled, "url": item.get("url", "") or ""}
|
||||
if item.get("headers"):
|
||||
entry["headers"] = item.get("headers")
|
||||
if transport:
|
||||
entry["transport"] = transport
|
||||
mcp_servers[name] = entry
|
||||
|
||||
if not mcp_servers:
|
||||
return ""
|
||||
return json.dumps({"mcpServers": mcp_servers}, ensure_ascii=False, indent=2)
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1,2 +0,0 @@
|
||||
# MCP 桥接插件依赖
|
||||
mcp>=1.0.0
|
||||
@@ -1,584 +0,0 @@
|
||||
"""
|
||||
MCP Workflow 模块 v1.9.0
|
||||
支持用户自定义工作流(硬流程),将多个 MCP 工具按顺序执行
|
||||
|
||||
双轨制架构:
|
||||
- 软流程 (ReAct): LLM 自主决策,动态多轮调用工具,灵活但不可预测
|
||||
- 硬流程 (Workflow): 用户预定义的工作流,固定流程,可靠可控
|
||||
|
||||
功能:
|
||||
- Workflow 定义和管理
|
||||
- 顺序执行多个工具(硬流程)
|
||||
- 支持变量替换(使用前序工具的输出)
|
||||
- 自动注册为组合工具供 LLM 调用
|
||||
- 与 ReAct 软流程互补,用户可选择合适的执行方式
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
try:
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("mcp_tool_chain")
|
||||
except ImportError:
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger("mcp_tool_chain")
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolChainStep:
|
||||
"""工具链步骤"""
|
||||
|
||||
tool_name: str # 要调用的工具名(如 mcp_server_tool)
|
||||
args_template: Dict[str, Any] = field(default_factory=dict) # 参数模板,支持变量替换
|
||||
output_key: str = "" # 输出存储的键名,供后续步骤引用
|
||||
description: str = "" # 步骤描述
|
||||
optional: bool = False # 是否可选(失败时继续执行)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"tool_name": self.tool_name,
|
||||
"args_template": self.args_template,
|
||||
"output_key": self.output_key,
|
||||
"description": self.description,
|
||||
"optional": self.optional,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "ToolChainStep":
|
||||
return cls(
|
||||
tool_name=data.get("tool_name", ""),
|
||||
args_template=data.get("args_template", {}),
|
||||
output_key=data.get("output_key", ""),
|
||||
description=data.get("description", ""),
|
||||
optional=data.get("optional", False),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolChainDefinition:
|
||||
"""工具链定义"""
|
||||
|
||||
name: str # 工具链名称(将作为组合工具的名称)
|
||||
description: str # 工具链描述(供 LLM 理解)
|
||||
steps: List[ToolChainStep] = field(default_factory=list) # 执行步骤
|
||||
input_params: Dict[str, str] = field(default_factory=dict) # 输入参数定义 {参数名: 描述}
|
||||
enabled: bool = True # 是否启用
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"steps": [step.to_dict() for step in self.steps],
|
||||
"input_params": self.input_params,
|
||||
"enabled": self.enabled,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "ToolChainDefinition":
|
||||
steps = [ToolChainStep.from_dict(s) for s in data.get("steps", [])]
|
||||
return cls(
|
||||
name=data.get("name", ""),
|
||||
description=data.get("description", ""),
|
||||
steps=steps,
|
||||
input_params=data.get("input_params", {}),
|
||||
enabled=data.get("enabled", True),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChainExecutionResult:
|
||||
"""工具链执行结果"""
|
||||
|
||||
success: bool
|
||||
final_output: str # 最终输出(最后一个步骤的结果)
|
||||
step_results: List[Dict[str, Any]] = field(default_factory=list) # 每个步骤的结果
|
||||
error: str = ""
|
||||
total_duration_ms: float = 0.0
|
||||
|
||||
def to_summary(self) -> str:
|
||||
"""生成执行摘要"""
|
||||
lines = []
|
||||
for i, step in enumerate(self.step_results):
|
||||
status = "✅" if step.get("success") else "❌"
|
||||
tool = step.get("tool_name", "unknown")
|
||||
duration = step.get("duration_ms", 0)
|
||||
lines.append(f"{status} 步骤{i + 1}: {tool} ({duration:.0f}ms)")
|
||||
if not step.get("success") and step.get("error"):
|
||||
lines.append(f" 错误: {step['error'][:50]}")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
class ToolChainExecutor:
|
||||
"""工具链执行器"""
|
||||
|
||||
# 变量替换模式: ${step.output_key} 或 ${input.param_name} 或 ${prev}
|
||||
VAR_PATTERN = re.compile(r"\$\{([^}]+)\}")
|
||||
|
||||
def __init__(self, mcp_manager):
|
||||
self._mcp_manager = mcp_manager
|
||||
|
||||
def _resolve_tool_key(self, tool_name: str) -> Optional[str]:
|
||||
"""解析工具名,返回有效的 tool_key
|
||||
|
||||
支持:
|
||||
- 直接使用 tool_key(如 mcp_server_tool)
|
||||
- 使用注册后的工具名(会自动转换 - 和 . 为 _)
|
||||
"""
|
||||
all_tools = self._mcp_manager.all_tools
|
||||
|
||||
# 直接匹配
|
||||
if tool_name in all_tools:
|
||||
return tool_name
|
||||
|
||||
# 尝试转换后匹配(用户可能使用了注册后的名称)
|
||||
normalized = tool_name.replace("-", "_").replace(".", "_")
|
||||
if normalized in all_tools:
|
||||
return normalized
|
||||
|
||||
# 尝试查找包含该名称的工具
|
||||
for key in all_tools.keys():
|
||||
if key.endswith(f"_{tool_name}") or key.endswith(f"_{normalized}"):
|
||||
return key
|
||||
|
||||
return None
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
chain: ToolChainDefinition,
|
||||
input_args: Dict[str, Any],
|
||||
) -> ChainExecutionResult:
|
||||
"""执行工具链
|
||||
|
||||
Args:
|
||||
chain: 工具链定义
|
||||
input_args: 用户输入的参数
|
||||
|
||||
Returns:
|
||||
ChainExecutionResult: 执行结果
|
||||
"""
|
||||
start_time = time.time()
|
||||
step_results = []
|
||||
context = {
|
||||
"input": input_args or {}, # 用户输入,确保不为 None
|
||||
"step": {}, # 各步骤输出,按 output_key 存储
|
||||
"prev": "", # 上一步的输出
|
||||
}
|
||||
|
||||
final_output = ""
|
||||
|
||||
# 验证必需的输入参数
|
||||
missing_params = []
|
||||
for param_name in chain.input_params.keys():
|
||||
if param_name not in context["input"]:
|
||||
missing_params.append(param_name)
|
||||
|
||||
if missing_params:
|
||||
return ChainExecutionResult(
|
||||
success=False,
|
||||
final_output="",
|
||||
error=f"缺少必需参数: {', '.join(missing_params)}",
|
||||
total_duration_ms=(time.time() - start_time) * 1000,
|
||||
)
|
||||
|
||||
for i, step in enumerate(chain.steps):
|
||||
step_start = time.time()
|
||||
step_result = {
|
||||
"step_index": i,
|
||||
"tool_name": step.tool_name,
|
||||
"success": False,
|
||||
"output": "",
|
||||
"error": "",
|
||||
"duration_ms": 0,
|
||||
}
|
||||
|
||||
try:
|
||||
# 替换参数中的变量
|
||||
resolved_args = self._resolve_args(step.args_template, context)
|
||||
step_result["resolved_args"] = resolved_args
|
||||
|
||||
# 解析工具名
|
||||
tool_key = self._resolve_tool_key(step.tool_name)
|
||||
if not tool_key:
|
||||
step_result["error"] = f"工具 {step.tool_name} 不存在"
|
||||
logger.warning(f"工具链步骤 {i + 1}: 工具 {step.tool_name} 不存在")
|
||||
|
||||
if not step.optional:
|
||||
step_results.append(step_result)
|
||||
return ChainExecutionResult(
|
||||
success=False,
|
||||
final_output="",
|
||||
step_results=step_results,
|
||||
error=f"步骤 {i + 1}: 工具 {step.tool_name} 不存在",
|
||||
total_duration_ms=(time.time() - start_time) * 1000,
|
||||
)
|
||||
step_results.append(step_result)
|
||||
continue
|
||||
|
||||
logger.debug(f"工具链步骤 {i + 1}: 调用 {tool_key},参数: {resolved_args}")
|
||||
|
||||
# 调用工具
|
||||
result = await self._mcp_manager.call_tool(tool_key, resolved_args)
|
||||
|
||||
step_duration = (time.time() - step_start) * 1000
|
||||
step_result["duration_ms"] = step_duration
|
||||
|
||||
if result.success:
|
||||
step_result["success"] = True
|
||||
# 确保 content 不为 None
|
||||
content = result.content if result.content is not None else ""
|
||||
step_result["output"] = content
|
||||
|
||||
# 更新上下文
|
||||
context["prev"] = content
|
||||
if step.output_key:
|
||||
context["step"][step.output_key] = content
|
||||
|
||||
final_output = content
|
||||
content_preview = content[:100] if content else "(空)"
|
||||
logger.debug(f"工具链步骤 {i + 1} 成功: {content_preview}...")
|
||||
else:
|
||||
step_result["error"] = result.error or "未知错误"
|
||||
logger.warning(f"工具链步骤 {i + 1} 失败: {result.error}")
|
||||
|
||||
if not step.optional:
|
||||
step_results.append(step_result)
|
||||
return ChainExecutionResult(
|
||||
success=False,
|
||||
final_output="",
|
||||
step_results=step_results,
|
||||
error=f"步骤 {i + 1} ({step.tool_name}) 失败: {result.error}",
|
||||
total_duration_ms=(time.time() - start_time) * 1000,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
step_duration = (time.time() - step_start) * 1000
|
||||
step_result["duration_ms"] = step_duration
|
||||
step_result["error"] = str(e)
|
||||
logger.error(f"工具链步骤 {i + 1} 异常: {e}")
|
||||
|
||||
if not step.optional:
|
||||
step_results.append(step_result)
|
||||
return ChainExecutionResult(
|
||||
success=False,
|
||||
final_output="",
|
||||
step_results=step_results,
|
||||
error=f"步骤 {i + 1} ({step.tool_name}) 异常: {e}",
|
||||
total_duration_ms=(time.time() - start_time) * 1000,
|
||||
)
|
||||
|
||||
step_results.append(step_result)
|
||||
|
||||
total_duration = (time.time() - start_time) * 1000
|
||||
|
||||
return ChainExecutionResult(
|
||||
success=True,
|
||||
final_output=final_output,
|
||||
step_results=step_results,
|
||||
total_duration_ms=total_duration,
|
||||
)
|
||||
|
||||
def _resolve_args(self, args_template: Dict[str, Any], context: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""解析参数模板,替换变量
|
||||
|
||||
支持的变量格式:
|
||||
- ${input.param_name}: 用户输入的参数
|
||||
- ${step.output_key}: 某个步骤的输出
|
||||
- ${prev}: 上一步的输出
|
||||
- ${prev.field}: 上一步输出(JSON)的某个字段
|
||||
"""
|
||||
resolved = {}
|
||||
|
||||
for key, value in args_template.items():
|
||||
if isinstance(value, str):
|
||||
resolved[key] = self._substitute_vars(value, context)
|
||||
elif isinstance(value, dict):
|
||||
resolved[key] = self._resolve_args(value, context)
|
||||
elif isinstance(value, list):
|
||||
resolved[key] = [self._substitute_vars(v, context) if isinstance(v, str) else v for v in value]
|
||||
else:
|
||||
resolved[key] = value
|
||||
|
||||
return resolved
|
||||
|
||||
def _substitute_vars(self, template: str, context: Dict[str, Any]) -> str:
|
||||
"""替换字符串中的变量"""
|
||||
|
||||
def replacer(match):
|
||||
var_path = match.group(1)
|
||||
return self._get_var_value(var_path, context)
|
||||
|
||||
return self.VAR_PATTERN.sub(replacer, template)
|
||||
|
||||
def _get_var_value(self, var_path: str, context: Dict[str, Any]) -> str:
|
||||
"""获取变量值
|
||||
|
||||
Args:
|
||||
var_path: 变量路径,如 "input.query", "step.search_result", "prev", "prev.id"
|
||||
context: 上下文
|
||||
"""
|
||||
parts = self._parse_var_path(var_path)
|
||||
|
||||
if not parts:
|
||||
return ""
|
||||
|
||||
# 获取根对象
|
||||
root = parts[0]
|
||||
if root not in context:
|
||||
logger.warning(f"变量 {var_path} 的根 '{root}' 不存在")
|
||||
return ""
|
||||
|
||||
value = context[root]
|
||||
|
||||
# 遍历路径
|
||||
for part in parts[1:]:
|
||||
if isinstance(value, str):
|
||||
parsed = self._try_parse_json(value)
|
||||
if parsed is not None:
|
||||
value = parsed
|
||||
|
||||
if isinstance(value, dict):
|
||||
value = value.get(part, "")
|
||||
elif isinstance(value, list):
|
||||
if part.isdigit():
|
||||
idx = int(part)
|
||||
value = value[idx] if 0 <= idx < len(value) else ""
|
||||
else:
|
||||
value = ""
|
||||
else:
|
||||
value = ""
|
||||
|
||||
# 确保返回字符串
|
||||
if isinstance(value, (dict, list)):
|
||||
return json.dumps(value, ensure_ascii=False)
|
||||
if value is None:
|
||||
return ""
|
||||
if value == "":
|
||||
return ""
|
||||
return str(value)
|
||||
|
||||
def _try_parse_json(self, value: str) -> Optional[Any]:
|
||||
"""尝试将字符串解析为 JSON 对象,失败则返回 None。"""
|
||||
if not value:
|
||||
return None
|
||||
try:
|
||||
return json.loads(value)
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
|
||||
def _parse_var_path(self, var_path: str) -> List[str]:
|
||||
"""解析变量路径,支持点号与下标写法。
|
||||
|
||||
支持:
|
||||
- step.geo.return.0.location
|
||||
- step.geo.return[0].location
|
||||
- step.geo['return'][0]['location']
|
||||
"""
|
||||
if not var_path:
|
||||
return []
|
||||
|
||||
tokens: List[str] = []
|
||||
buf: List[str] = []
|
||||
in_bracket = False
|
||||
in_quote = False
|
||||
quote_char = ""
|
||||
|
||||
def flush_buf() -> None:
|
||||
if buf:
|
||||
token = "".join(buf).strip()
|
||||
if token:
|
||||
tokens.append(token)
|
||||
buf.clear()
|
||||
|
||||
i = 0
|
||||
while i < len(var_path):
|
||||
ch = var_path[i]
|
||||
|
||||
if not in_bracket and ch == ".":
|
||||
flush_buf()
|
||||
i += 1
|
||||
continue
|
||||
|
||||
if not in_bracket and ch == "[":
|
||||
flush_buf()
|
||||
in_bracket = True
|
||||
in_quote = False
|
||||
quote_char = ""
|
||||
i += 1
|
||||
continue
|
||||
|
||||
if in_bracket and not in_quote and ch == "]":
|
||||
flush_buf()
|
||||
in_bracket = False
|
||||
i += 1
|
||||
continue
|
||||
|
||||
if in_bracket and ch in ("'", '"'):
|
||||
if not in_quote:
|
||||
in_quote = True
|
||||
quote_char = ch
|
||||
i += 1
|
||||
continue
|
||||
if quote_char == ch:
|
||||
in_quote = False
|
||||
quote_char = ""
|
||||
i += 1
|
||||
continue
|
||||
|
||||
if in_bracket and not in_quote:
|
||||
if ch.isspace():
|
||||
i += 1
|
||||
continue
|
||||
if ch == ",":
|
||||
i += 1
|
||||
continue
|
||||
|
||||
buf.append(ch)
|
||||
i += 1
|
||||
|
||||
flush_buf()
|
||||
|
||||
if in_bracket or in_quote:
|
||||
return [p for p in var_path.split(".") if p]
|
||||
|
||||
return tokens
|
||||
|
||||
|
||||
class ToolChainManager:
|
||||
"""工具链管理器"""
|
||||
|
||||
_instance: Optional["ToolChainManager"] = None
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
cls._instance._initialized = False
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
if self._initialized:
|
||||
return
|
||||
self._initialized = True
|
||||
self._chains: Dict[str, ToolChainDefinition] = {}
|
||||
self._executor: Optional[ToolChainExecutor] = None
|
||||
|
||||
def set_executor(self, mcp_manager) -> None:
|
||||
"""设置执行器"""
|
||||
self._executor = ToolChainExecutor(mcp_manager)
|
||||
|
||||
def add_chain(self, chain: ToolChainDefinition) -> bool:
|
||||
"""添加工具链"""
|
||||
if not chain.name:
|
||||
logger.error("工具链名称不能为空")
|
||||
return False
|
||||
|
||||
if chain.name in self._chains:
|
||||
logger.warning(f"工具链 {chain.name} 已存在,将被覆盖")
|
||||
|
||||
self._chains[chain.name] = chain
|
||||
logger.info(f"已添加工具链: {chain.name} ({len(chain.steps)} 个步骤)")
|
||||
return True
|
||||
|
||||
def remove_chain(self, name: str) -> bool:
|
||||
"""移除工具链"""
|
||||
if name in self._chains:
|
||||
del self._chains[name]
|
||||
logger.info(f"已移除工具链: {name}")
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_chain(self, name: str) -> Optional[ToolChainDefinition]:
|
||||
"""获取工具链"""
|
||||
return self._chains.get(name)
|
||||
|
||||
def get_all_chains(self) -> Dict[str, ToolChainDefinition]:
|
||||
"""获取所有工具链"""
|
||||
return self._chains.copy()
|
||||
|
||||
def get_enabled_chains(self) -> Dict[str, ToolChainDefinition]:
|
||||
"""获取所有启用的工具链"""
|
||||
return {name: chain for name, chain in self._chains.items() if chain.enabled}
|
||||
|
||||
async def execute_chain(
|
||||
self,
|
||||
chain_name: str,
|
||||
input_args: Dict[str, Any],
|
||||
) -> ChainExecutionResult:
|
||||
"""执行工具链"""
|
||||
chain = self._chains.get(chain_name)
|
||||
if not chain:
|
||||
return ChainExecutionResult(
|
||||
success=False,
|
||||
final_output="",
|
||||
error=f"工具链 {chain_name} 不存在",
|
||||
)
|
||||
|
||||
if not chain.enabled:
|
||||
return ChainExecutionResult(
|
||||
success=False,
|
||||
final_output="",
|
||||
error=f"工具链 {chain_name} 已禁用",
|
||||
)
|
||||
|
||||
if not self._executor:
|
||||
return ChainExecutionResult(
|
||||
success=False,
|
||||
final_output="",
|
||||
error="工具链执行器未初始化",
|
||||
)
|
||||
|
||||
return await self._executor.execute(chain, input_args)
|
||||
|
||||
def load_from_json(self, json_str: str) -> Tuple[int, List[str]]:
|
||||
"""从 JSON 字符串加载工具链配置
|
||||
|
||||
Returns:
|
||||
(成功加载数量, 错误列表)
|
||||
"""
|
||||
errors = []
|
||||
loaded = 0
|
||||
|
||||
try:
|
||||
data = json.loads(json_str) if json_str.strip() else []
|
||||
except json.JSONDecodeError as e:
|
||||
return 0, [f"JSON 解析失败: {e}"]
|
||||
|
||||
if not isinstance(data, list):
|
||||
data = [data]
|
||||
|
||||
for i, item in enumerate(data):
|
||||
try:
|
||||
chain = ToolChainDefinition.from_dict(item)
|
||||
if not chain.name:
|
||||
errors.append(f"第 {i + 1} 个工具链缺少名称")
|
||||
continue
|
||||
if not chain.steps:
|
||||
errors.append(f"工具链 {chain.name} 没有步骤")
|
||||
continue
|
||||
|
||||
self.add_chain(chain)
|
||||
loaded += 1
|
||||
except Exception as e:
|
||||
errors.append(f"第 {i + 1} 个工具链解析失败: {e}")
|
||||
|
||||
return loaded, errors
|
||||
|
||||
def export_to_json(self, pretty: bool = True) -> str:
|
||||
"""导出所有工具链为 JSON"""
|
||||
chains_data = [chain.to_dict() for chain in self._chains.values()]
|
||||
if pretty:
|
||||
return json.dumps(chains_data, ensure_ascii=False, indent=2)
|
||||
return json.dumps(chains_data, ensure_ascii=False)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""清空所有工具链"""
|
||||
self._chains.clear()
|
||||
|
||||
|
||||
# 全局工具链管理器实例
|
||||
tool_chain_manager = ToolChainManager()
|
||||
@@ -1,22 +0,0 @@
|
||||
你是一个对话节奏与时间感知分析模块,同时负责自我反思。你的任务是根据对话上下文和系统提供的时间戳信息,分析:
|
||||
|
||||
【时间感知分析】
|
||||
1. 对话持续时长:当前对话已经进行了多久
|
||||
2. 回复间隔:用户上次发言距今多久、用户的平均回复速度如何
|
||||
3. 建议等待时长:结合对话内容和时间规律,建议下次等待多少秒比较合适
|
||||
4. 时间相关洞察:
|
||||
- 用户是否可能正在忙(回复变慢)
|
||||
- 用户是否正在积极对话(回复很快)
|
||||
- 当前时段(深夜/早晨/工作时间等)是否适合继续聊
|
||||
- 对话是否已经持续太久,用户可能需要休息
|
||||
- 是否应该主动结束对话
|
||||
|
||||
【自我反思分析】
|
||||
1. 人设一致性:是否符合设定的人格特质、说话风格是否一致、是否有不符合身份的言论
|
||||
2. 回复合理性:是否有逻辑漏洞、是否回应了用户的核心诉求、是否有过当或不当言论
|
||||
3. 认知局限性:是否对某些情况理解不足、是否缺乏必要信息、是否做出了过度推断
|
||||
|
||||
要求:
|
||||
- 输出简洁(4-6 句话),时间感知分析和自我反思分析各占一半
|
||||
- 重点关注对话节奏的变化趋势和助手自身的人设一致性
|
||||
- 直接输出分析结果,不要有格式标题或分段标记
|
||||
@@ -1,22 +0,0 @@
|
||||
你是一个对话节奏与时间感知分析模块,同时负责自我反思。你的任务是根据对话上下文和系统提供的时间戳信息,分析:
|
||||
|
||||
【时间感知分析】
|
||||
1. 对话持续时长:当前对话已经进行了多久
|
||||
2. 回复间隔:用户上次发言距今多久、用户的平均回复速度如何
|
||||
3. 建议等待时长:结合对话内容和时间规律,建议下次等待多少秒比较合适
|
||||
4. 时间相关洞察:
|
||||
- 用户是否可能正在忙(回复变慢)
|
||||
- 用户是否正在积极对话(回复很快)
|
||||
- 当前时段(深夜/早晨/工作时间等)是否适合继续聊
|
||||
- 对话是否已经持续太久,用户可能需要休息
|
||||
- 是否应该主动结束对话
|
||||
|
||||
【自我反思分析】
|
||||
1. 人设一致性:是否符合设定的人格特质、说话风格是否一致、是否有不符合身份的言论
|
||||
2. 回复合理性:是否有逻辑漏洞、是否回应了用户的核心诉求、是否有过当或不当言论
|
||||
3. 认知局限性:是否对某些情况理解不足、是否缺乏必要信息、是否做出了过度推断
|
||||
|
||||
要求:
|
||||
- 输出简洁(4-6 句话),时间感知分析和自我反思分析各占一半
|
||||
- 重点关注对话节奏的变化趋势和助手自身的人设一致性
|
||||
- 直接输出分析结果,不要有格式标题或分段标记
|
||||
@@ -1,5 +0,0 @@
|
||||
{action_name}
|
||||
动作描述:{action_description}
|
||||
使用条件{parallel_text}:
|
||||
{action_require}
|
||||
{{"action":"{action_name}",{action_parameters}, "target_message_id":"消息id(m+数字)"}}
|
||||
@@ -1 +0,0 @@
|
||||
你正在qq群里聊天,下面是群里正在聊的内容:
|
||||
@@ -1 +0,0 @@
|
||||
正在群里聊天
|
||||
@@ -1 +0,0 @@
|
||||
你正在和{sender_name}聊天,这是你们之前聊的内容:
|
||||
@@ -1 +0,0 @@
|
||||
和{sender_name}聊天
|
||||
@@ -1,10 +0,0 @@
|
||||
你是一个专门获取知识的助手。你的名字是{bot_name}。现在是{time_now}。
|
||||
群里正在进行的聊天内容:
|
||||
{chat_history}
|
||||
|
||||
现在,{sender}发送了内容:{target_message},你想要回复ta。
|
||||
请仔细分析聊天内容,考虑以下几点:
|
||||
1. 内容中是否包含需要查询信息的问题
|
||||
2. 是否有明确的知识获取指令
|
||||
|
||||
If you need to use the search tool, please directly call the function "lpmm_search_knowledge". If you do not need to use any tool, simply output "No tool needed".
|
||||
@@ -1,5 +1,5 @@
|
||||
你的任务是分析聊天和聊天中的互动情况。
|
||||
你需要关注 {bot_name}(AI) 与不同用户的对话来为选择正确的动作和行为提供建议
|
||||
你需要关注 {bot_name}(AI) 与不同用户的对话来为选择正确的动作和行为以及搜集信息提供建议
|
||||
|
||||
【参考信息】
|
||||
{bot_name}的人设:{identity}
|
||||
@@ -8,28 +8,28 @@
|
||||
你需要根据提供的参考信息,当前场景和输出规则来进行分析
|
||||
在当前场景中,用户正在与AI麦麦进行聊天互动,你的任务不是生成对用户可见的发言,而是进行分析来指导AI进行回复。
|
||||
“分析”应该体现你对当前局面的判断、你的建议、你的下一步计划,以及你为什么这样想。
|
||||
你需要先搜集能够帮助{bot_name}回复的信息,然后再给出回复意见
|
||||
|
||||
|
||||
你可以使用这些工具:
|
||||
- wait(seconds) - 暂时停止对话,等待(seconds)秒,把话语权交给用户,等待对方新的发言。
|
||||
- stop() - 结束对话,不进行任何回复,直到对方有新消息。
|
||||
- reply():当你判断现在应该正式对用户发出一条可见回复时调用。调用后系统会基于你当前这轮的想法生成一条真正展示给用户的回复。
|
||||
- no_reply():当你判断现在不应该发言,应该继续内部思考时调用。这个工具不会做任何外部行为,只会继续下一轮循环。
|
||||
{file_tools_section}
|
||||
- stop() - 当你判断{bot_name}现在不应该发言,结束对话,不进行任何回复,直到对方有新消息。
|
||||
- reply():当你判断{bot_name}现在应该正式对用户发出一条可见回复时调用。调用后系统会基于你当前这轮的想法生成一条真正展示给用户的回复。
|
||||
- query_jargon():当你认为某些词的含义不明确,或用户询问某些词的含义,需要进行查询
|
||||
- 其他定义的工具,你可以视情况合适使用
|
||||
|
||||
工具使用规则:
|
||||
1.如果麦麦已经回复,但用户暂时没有新的回复,且没有新信息需要搜集,使用wait或者stop进行等待
|
||||
1.如果{bot_name}已经回复,但用户暂时没有新的回复,且没有新信息需要搜集,使用wait或者stop进行等待
|
||||
2.如果用户有新发言,但是你评估用户还有后续发言尚未发送,可以适当等待让用户说完
|
||||
3.在特定情况下也可以连续回复,例如想要追问,或者补充自己先前的发言,可以不使用stop或者wait
|
||||
4.如果你想指导麦麦直接发言,可以不使用任何工具
|
||||
4.你需要控制自己发言的频率,如果用户一对一聊天,可以以均匀地频率发言,如果用户较多,不要每句都回复,控制回复频率。当你决定暂时不发言,可以使用wait暂时等待一定时间或者stop等待新消息
|
||||
5.如果存在用户的疑问,或者对某些概念的不确定,你可以使用工具来搜集信息或者查询含义,你可以使用多个工具
|
||||
|
||||
你的输出规则:
|
||||
你的分析规则:
|
||||
1. 默认直接输出你当前的最新分析,不要重复之前的分析内容。
|
||||
2. 最新分析应尽量具体,贴近上下文,不要空泛重复。
|
||||
3. 如果你认为现在更适合等待用户补充,可以调用 `wait(seconds)`。
|
||||
4. 如果你认为应当结束当前对话,不回复任何内容,可以调用 `stop()`。
|
||||
5. 只有在确实需要等待或停止时才调用工具,否则优先直接输出分析想法。
|
||||
6. 如果你刚刚做了工具调用,下一轮应结合工具结果继续输出新的分析。
|
||||
7. 分析应服务于后续决策,而不是机械复述用户内容。
|
||||
3. 如果你刚刚做了工具调用,下一轮应结合工具结果继续输出新的分析。
|
||||
4. 你需要评估哪些话是对{bot_name}的发言,哪些是用户之间的交流或者自言自语,不要频繁插入无关的话题。
|
||||
5. 如果你上一轮没有发言,需要重新进行分析,输出新的分析内容,不要重复上一轮的分析内容
|
||||
|
||||
现在,请你输出你的分析:
|
||||
现在,请你输出你对{bot_name}发言的分析,你必须先输出文本内容的分析,然后再进行工具调用:
|
||||
|
||||
@@ -1,11 +0,0 @@
|
||||
你是一个认知感知分析模块。你的任务是根据对话上下文,分析对话中用户的:
|
||||
1. 核心意图(如:寻求帮助、纯粹聊天、请求任务、发泄情绪、获取信息、表达观点等)
|
||||
2. 认知状态(如:明确具体、模糊试探、犹豫不决、困惑迷茫、思路清晰、逻辑混乱等)
|
||||
3. 隐含目的(如:解决问题、获得安慰、打发时间、寻求认同、交换想法、表达自我等)
|
||||
|
||||
要求:
|
||||
- 只分析用户(对话中 role=user 的内容),不要分析助手自己
|
||||
- 根据用户最新发言重点分析,同时结合上下文理解深层动机
|
||||
- 输出简洁(2-4 句话),不要太长
|
||||
- 如果信息太少无法判断,就说信息不足,给出初步印象
|
||||
- 直接输出分析结果,不要有格式标题
|
||||
@@ -1,11 +0,0 @@
|
||||
你是一个情绪感知分析模块。你的任务是根据对话上下文,分析对话中用户的:
|
||||
1. 当前情绪状态(如:开心、沮丧、焦虑、平静、兴奋、愤怒等)
|
||||
2. 言语态度(如:友好、冷淡、热情、敷衍、试探、认真、调侃等)
|
||||
3. 潜在的情感需求(如:需要倾听、需要鼓励、想要倾诉、只是闲聊等)
|
||||
|
||||
要求:
|
||||
- 只分析用户(对话中 role=user 的内容),不要分析助手自己
|
||||
- 根据用户最新发言重点分析,同时结合上下文理解变化趋势
|
||||
- 输出简洁(2-4 句话),不要太长
|
||||
- 如果信息太少无法判断,就说信息不足,给出初步印象
|
||||
- 直接输出分析结果,不要有格式标题
|
||||
@@ -1,7 +1,10 @@
|
||||
你的名字是麦麦,也有人叫你牢麦,麦叠,你是一个大二在读女大学生,现在正在上网和群友聊天,有点攻击性;
|
||||
你正在群里聊天,现在请你读读之前的聊天记录,然后给出日常且口语化的回复,
|
||||
尽量简短一些。
|
||||
请注意把握聊天内容,不要回复的太有条理。
|
||||
你的风格平淡但不失讽刺,不过分兴奋,很简短。可以参考贴吧,知乎和微博的回复风格。很平淡和白话,不浮夸不长篇大论,b站评论风格,但一定注意不要过分修辞和复杂句。
|
||||
请注意不要输出多余内容(包括不必要的前后缀,冒号,括号,表情包,at或 @等 ),只输出发言内容就好。
|
||||
最好一次对一个话题进行回复,免得啰嗦或者回复内容太乱。
|
||||
你正在qq群里聊天,下面是群里正在聊的内容,其中包含聊天记录和聊天中的图片
|
||||
其中标注 {bot_name}(你) 的发言是你自己的发言,请注意区分:
|
||||
|
||||
{time_block}
|
||||
|
||||
{identity}
|
||||
你正在群里聊天,现在请你读读之前的聊天记录,把握当前的话题,然后给出日常且口语化的回复,
|
||||
尽量简短一些。最好一次对一个话题进行回复,免得啰嗦或者回复内容太乱。请注意把握聊天内容。
|
||||
{reply_style}
|
||||
请注意不要输出多余内容(包括不必要的前后缀,冒号,括号,表情包,at或 @等 ),只输出发言内容就好。
|
||||
@@ -1,22 +0,0 @@
|
||||
你是一个对话节奏与时间感知分析模块,同时负责自我反思。你的任务是根据对话上下文和系统提供的时间戳信息,分析:
|
||||
|
||||
【时间感知分析】
|
||||
1. 对话持续时长:当前对话已经进行了多久
|
||||
2. 回复间隔:用户上次发言距今多久、用户的平均回复速度如何
|
||||
3. 建议等待时长:结合对话内容和时间规律,建议下次等待多少秒比较合适
|
||||
4. 时间相关洞察:
|
||||
- 用户是否可能正在忙(回复变慢)
|
||||
- 用户是否正在积极对话(回复很快)
|
||||
- 当前时段(深夜/早晨/工作时间等)是否适合继续聊
|
||||
- 对话是否已经持续太久,用户可能需要休息
|
||||
- 是否应该主动结束对话
|
||||
|
||||
【自我反思分析】
|
||||
1. 人设一致性:是否符合设定的人格特质、说话风格是否一致、是否有不符合身份的言论
|
||||
2. 回复合理性:是否有逻辑漏洞、是否回应了用户的核心诉求、是否有过当或不当言论
|
||||
3. 认知局限性:是否对某些情况理解不足、是否缺乏必要信息、是否做出了过度推断
|
||||
|
||||
要求:
|
||||
- 输出简洁(4-6 句话),时间感知分析和自我反思分析各占一半
|
||||
- 重点关注对话节奏的变化趋势和助手自身的人设一致性
|
||||
- 直接输出分析结果,不要有格式标题或分段标记
|
||||
@@ -1,14 +0,0 @@
|
||||
{knowledge_prompt}{tool_info_block}{extra_info_block}
|
||||
{expression_habits_block}{memory_retrieval}{jargon_explanation}
|
||||
|
||||
你正在和{sender_name}聊天,这是你们之前聊的内容:
|
||||
{time_block}
|
||||
{dialogue_prompt}
|
||||
|
||||
你现在想补充说明你刚刚自己的发言内容:{target},原因是{reason}
|
||||
请你根据聊天内容,组织一条新回复。注意,{target} 是刚刚你自己的发言,你要在这基础上进一步发言,请按照你自己的角度来继续进行回复。注意保持上下文的连贯性。
|
||||
{identity}
|
||||
{chat_prompt}尽量简短一些。{keywords_reaction_prompt}请注意把握聊天内容,不要回复的太有条理,可以有个性。
|
||||
{reply_style}
|
||||
请注意不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只输出回复内容。
|
||||
{moderation_prompt}不要输出多余内容(包括冒号和引号,括号,表情包,at或 @等 )。
|
||||
@@ -1,18 +0,0 @@
|
||||
{knowledge_prompt}{tool_info_block}{extra_info_block}
|
||||
{expression_habits_block}{memory_retrieval}{jargon_explanation}
|
||||
|
||||
你正在qq群里聊天,下面是群里正在聊的内容,其中包含聊天记录和聊天中的图片
|
||||
其中标注 {bot_name}(你) 的发言是你自己的发言,请注意区分:
|
||||
{time_block}
|
||||
{dialogue_prompt}
|
||||
|
||||
{reply_target_block}。
|
||||
{planner_reasoning}
|
||||
{identity}
|
||||
{chat_prompt}你正在群里聊天,现在请你读读之前的聊天记录,然后给出日常且口语化的回复,
|
||||
尽量简短一些。{keywords_reaction_prompt}
|
||||
请注意把握聊天内容,不要回复的太有条理。
|
||||
{reply_style}
|
||||
请注意不要输出多余内容(包括不必要的前后缀,冒号,括号,表情包,at或 @等 ),只输出发言内容就好。
|
||||
最好一次对一个话题进行回复,免得啰嗦或者回复内容太乱。
|
||||
现在,你说:
|
||||
@@ -1,11 +0,0 @@
|
||||
你是一个专门执行工具的助手。你的名字是{bot_name}。现在是{time_now}。
|
||||
群里正在进行的聊天内容:
|
||||
{chat_history}
|
||||
|
||||
现在,{sender}发送了内容:{target_message},你想要回复ta。
|
||||
请仔细分析聊天内容,考虑以下几点:
|
||||
1. 内容中是否包含需要查询信息的问题
|
||||
2. 是否有明确的工具使用指令
|
||||
你可以选择多个动作
|
||||
|
||||
If you need to use tools, please directly call the corresponding tool function. If you do not need to use any tool, simply output "No tool needed".
|
||||
@@ -19,7 +19,8 @@ dependencies = [
|
||||
"jieba>=0.42.1",
|
||||
"json-repair>=0.47.6",
|
||||
"maim-message>=0.6.2",
|
||||
"maibot-plugin-sdk>=2.0.0",
|
||||
"maibot-plugin-sdk>=2.1.0",
|
||||
"mcp",
|
||||
"msgpack>=1.1.2",
|
||||
"numpy>=2.2.6",
|
||||
"openai>=1.95.0",
|
||||
|
||||
924
pytests/common_test/test_database_migration_foundation.py
Normal file
924
pytests/common_test/test_database_migration_foundation.py
Normal file
@@ -0,0 +1,924 @@
|
||||
"""数据库迁移基础设施测试。"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.engine import Connection, Engine
|
||||
from sqlmodel import SQLModel, create_engine
|
||||
|
||||
import json
|
||||
import msgpack
|
||||
import pytest
|
||||
|
||||
from src.common.database import database as database_module
|
||||
from src.common.database.migrations import (
|
||||
BaseSchemaVersionDetector,
|
||||
BaseMigrationProgressReporter,
|
||||
DatabaseSchemaSnapshot,
|
||||
DatabaseMigrationBootstrapper,
|
||||
DatabaseMigrationState,
|
||||
DatabaseMigrationManager,
|
||||
EMPTY_SCHEMA_VERSION,
|
||||
LATEST_SCHEMA_VERSION,
|
||||
LEGACY_V1_SCHEMA_VERSION,
|
||||
MigrationExecutionContext,
|
||||
MigrationPlan,
|
||||
MigrationRegistry,
|
||||
MigrationStep,
|
||||
ResolvedSchemaVersion,
|
||||
SchemaVersionResolver,
|
||||
SchemaVersionSource,
|
||||
SQLiteSchemaInspector,
|
||||
SQLiteUserVersionStore,
|
||||
build_default_migration_registry,
|
||||
build_default_schema_version_resolver,
|
||||
create_database_migration_bootstrapper,
|
||||
)
|
||||
|
||||
|
||||
class FixedVersionDetector(BaseSchemaVersionDetector):
|
||||
"""测试用固定版本探测器。"""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""返回测试探测器名称。
|
||||
|
||||
Returns:
|
||||
str: 探测器名称。
|
||||
"""
|
||||
return "fixed_version_detector"
|
||||
|
||||
def detect_version(self, snapshot: DatabaseSchemaSnapshot) -> Optional[int]:
|
||||
"""根据测试表是否存在返回固定版本。
|
||||
|
||||
Args:
|
||||
snapshot: 当前数据库结构快照。
|
||||
|
||||
Returns:
|
||||
Optional[int]: 若存在测试表则返回固定版本,否则返回 ``None``。
|
||||
"""
|
||||
if snapshot.has_table("legacy_records"):
|
||||
return 2
|
||||
return None
|
||||
|
||||
|
||||
class FakeMigrationProgressReporter(BaseMigrationProgressReporter):
|
||||
"""测试用迁移进度上报器。"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""初始化测试用进度上报器。"""
|
||||
self.events: List[Tuple[str, Optional[int], Optional[int], Optional[str]]] = []
|
||||
|
||||
def open(self) -> None:
|
||||
"""记录打开事件。"""
|
||||
self.events.append(("open", None, None, None))
|
||||
|
||||
def close(self) -> None:
|
||||
"""记录关闭事件。"""
|
||||
self.events.append(("close", None, None, None))
|
||||
|
||||
def start(
|
||||
self,
|
||||
total_records: int,
|
||||
total_tables: int,
|
||||
description: str = "总迁移进度",
|
||||
table_unit_name: str = "表",
|
||||
record_unit_name: str = "记录",
|
||||
) -> None:
|
||||
"""记录启动事件。
|
||||
|
||||
Args:
|
||||
total_records: 任务记录总数。
|
||||
total_tables: 任务表总数。
|
||||
description: 任务描述。
|
||||
table_unit_name: 表级进度单位名称。
|
||||
record_unit_name: 记录级进度单位名称。
|
||||
"""
|
||||
del table_unit_name, record_unit_name
|
||||
self.events.append(("start", total_records, total_tables, description))
|
||||
|
||||
def advance(
|
||||
self,
|
||||
records: int = 0,
|
||||
completed_tables: int = 0,
|
||||
item_name: Optional[str] = None,
|
||||
) -> None:
|
||||
"""记录推进事件。
|
||||
|
||||
Args:
|
||||
records: 推进的记录数。
|
||||
completed_tables: 已完成的表数。
|
||||
item_name: 当前完成的项目名称。
|
||||
"""
|
||||
self.events.append(("advance", records, completed_tables, item_name))
|
||||
|
||||
|
||||
def _create_sqlite_engine(database_file: Path) -> Engine:
|
||||
"""创建测试用 SQLite 引擎。
|
||||
|
||||
Args:
|
||||
database_file: 测试数据库文件路径。
|
||||
|
||||
Returns:
|
||||
Engine: SQLite 引擎实例。
|
||||
"""
|
||||
return create_engine(
|
||||
f"sqlite:///{database_file}",
|
||||
echo=False,
|
||||
connect_args={"check_same_thread": False},
|
||||
)
|
||||
|
||||
|
||||
def _create_current_schema(connection: Connection) -> None:
|
||||
"""创建当前最新版本的数据库结构。
|
||||
|
||||
Args:
|
||||
connection: 当前数据库连接。
|
||||
"""
|
||||
import src.common.database.database_model # noqa: F401
|
||||
|
||||
SQLModel.metadata.create_all(connection)
|
||||
|
||||
|
||||
def _create_legacy_v1_schema_with_sample_data(connection: Connection) -> None:
|
||||
"""创建带示例数据的旧版 ``0.x`` 数据库结构。
|
||||
|
||||
Args:
|
||||
connection: 当前数据库连接。
|
||||
"""
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
CREATE TABLE chat_streams (
|
||||
id INTEGER PRIMARY KEY,
|
||||
stream_id TEXT NOT NULL,
|
||||
create_time REAL NOT NULL,
|
||||
last_active_time REAL NOT NULL,
|
||||
platform TEXT NOT NULL,
|
||||
user_id TEXT,
|
||||
group_id TEXT,
|
||||
group_name TEXT
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
CREATE TABLE messages (
|
||||
id INTEGER PRIMARY KEY,
|
||||
message_id TEXT NOT NULL,
|
||||
time REAL NOT NULL,
|
||||
chat_id TEXT NOT NULL,
|
||||
chat_info_platform TEXT,
|
||||
user_id TEXT,
|
||||
user_nickname TEXT,
|
||||
chat_info_group_id TEXT,
|
||||
chat_info_group_name TEXT,
|
||||
is_mentioned INTEGER,
|
||||
is_at INTEGER,
|
||||
processed_plain_text TEXT,
|
||||
display_message TEXT,
|
||||
is_emoji INTEGER,
|
||||
is_picid INTEGER,
|
||||
is_command INTEGER,
|
||||
is_notify INTEGER,
|
||||
additional_config TEXT,
|
||||
priority_mode TEXT
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
CREATE TABLE action_records (
|
||||
id INTEGER PRIMARY KEY,
|
||||
action_id TEXT NOT NULL,
|
||||
time REAL NOT NULL,
|
||||
action_reasoning TEXT,
|
||||
action_name TEXT NOT NULL,
|
||||
action_data TEXT,
|
||||
action_prompt_display TEXT,
|
||||
chat_id TEXT
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
CREATE TABLE expression (
|
||||
id INTEGER PRIMARY KEY,
|
||||
situation TEXT NOT NULL,
|
||||
style TEXT NOT NULL,
|
||||
content_list TEXT,
|
||||
count INTEGER,
|
||||
last_active_time REAL NOT NULL,
|
||||
chat_id TEXT,
|
||||
create_date REAL,
|
||||
checked INTEGER,
|
||||
rejected INTEGER,
|
||||
modified_by TEXT
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
CREATE TABLE jargon (
|
||||
id INTEGER PRIMARY KEY,
|
||||
content TEXT NOT NULL,
|
||||
raw_content TEXT,
|
||||
meaning TEXT,
|
||||
chat_id TEXT,
|
||||
is_global INTEGER,
|
||||
count INTEGER,
|
||||
is_jargon INTEGER,
|
||||
last_inference_count INTEGER,
|
||||
is_complete INTEGER,
|
||||
inference_with_context TEXT,
|
||||
inference_content_only TEXT
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO chat_streams (
|
||||
id,
|
||||
stream_id,
|
||||
create_time,
|
||||
last_active_time,
|
||||
platform,
|
||||
user_id,
|
||||
group_id,
|
||||
group_name
|
||||
) VALUES (
|
||||
1,
|
||||
'session-1',
|
||||
1710000000.0,
|
||||
1710000300.0,
|
||||
'qq',
|
||||
'user-1',
|
||||
'group-1',
|
||||
'测试群'
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO messages (
|
||||
id,
|
||||
message_id,
|
||||
time,
|
||||
chat_id,
|
||||
chat_info_platform,
|
||||
user_id,
|
||||
user_nickname,
|
||||
chat_info_group_id,
|
||||
chat_info_group_name,
|
||||
is_mentioned,
|
||||
is_at,
|
||||
processed_plain_text,
|
||||
display_message,
|
||||
is_emoji,
|
||||
is_picid,
|
||||
is_command,
|
||||
is_notify,
|
||||
additional_config,
|
||||
priority_mode
|
||||
) VALUES (
|
||||
1,
|
||||
'msg-1',
|
||||
1710000010.0,
|
||||
'session-1',
|
||||
'qq',
|
||||
'user-1',
|
||||
'测试用户',
|
||||
'group-1',
|
||||
'测试群',
|
||||
1,
|
||||
0,
|
||||
'你好',
|
||||
'你好呀',
|
||||
0,
|
||||
1,
|
||||
0,
|
||||
1,
|
||||
'{"source":"legacy"}',
|
||||
'high'
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO action_records (
|
||||
id,
|
||||
action_id,
|
||||
time,
|
||||
action_reasoning,
|
||||
action_name,
|
||||
action_data,
|
||||
action_prompt_display,
|
||||
chat_id
|
||||
) VALUES (
|
||||
1,
|
||||
'action-1',
|
||||
1710000020.0,
|
||||
'需要调用工具',
|
||||
'search',
|
||||
'{"query":"MaiBot"}',
|
||||
'执行搜索',
|
||||
'session-1'
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO expression (
|
||||
id,
|
||||
situation,
|
||||
style,
|
||||
content_list,
|
||||
count,
|
||||
last_active_time,
|
||||
chat_id,
|
||||
create_date,
|
||||
checked,
|
||||
rejected,
|
||||
modified_by
|
||||
) VALUES (
|
||||
1,
|
||||
'打招呼',
|
||||
'可爱',
|
||||
'["你好呀","早上好"]',
|
||||
3,
|
||||
1710000030.0,
|
||||
'session-1',
|
||||
1710000040.0,
|
||||
1,
|
||||
0,
|
||||
'ai'
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO jargon (
|
||||
id,
|
||||
content,
|
||||
raw_content,
|
||||
meaning,
|
||||
chat_id,
|
||||
is_global,
|
||||
count,
|
||||
is_jargon,
|
||||
last_inference_count,
|
||||
is_complete,
|
||||
inference_with_context,
|
||||
inference_content_only
|
||||
) VALUES (
|
||||
1,
|
||||
'上分',
|
||||
'["上分"]',
|
||||
'提高排名',
|
||||
'session-1',
|
||||
0,
|
||||
5,
|
||||
1,
|
||||
2,
|
||||
1,
|
||||
'{"guess":"context"}',
|
||||
'{"guess":"content"}'
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def test_user_version_store_can_read_and_write_versions(tmp_path: Path) -> None:
|
||||
"""应支持读取与写入 SQLite ``user_version``。"""
|
||||
engine = _create_sqlite_engine(tmp_path / "version_store.db")
|
||||
version_store = SQLiteUserVersionStore()
|
||||
|
||||
with engine.begin() as connection:
|
||||
assert version_store.read_version(connection) == 0
|
||||
version_store.write_version(connection, 7)
|
||||
|
||||
with engine.connect() as connection:
|
||||
assert version_store.read_version(connection) == 7
|
||||
|
||||
|
||||
def test_schema_inspector_can_extract_tables_and_columns(tmp_path: Path) -> None:
|
||||
"""应能提取 SQLite 数据库的表与列结构。"""
|
||||
engine = _create_sqlite_engine(tmp_path / "schema_inspector.db")
|
||||
inspector = SQLiteSchemaInspector()
|
||||
|
||||
with engine.begin() as connection:
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
CREATE TABLE legacy_records (
|
||||
id INTEGER PRIMARY KEY,
|
||||
payload TEXT NOT NULL,
|
||||
created_at TEXT
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
with engine.connect() as connection:
|
||||
snapshot = inspector.inspect(connection)
|
||||
|
||||
assert snapshot.has_table("legacy_records")
|
||||
assert snapshot.has_column("legacy_records", "payload")
|
||||
assert not snapshot.has_column("legacy_records", "missing_column")
|
||||
table_schema = snapshot.get_table("legacy_records")
|
||||
|
||||
assert table_schema is not None
|
||||
assert table_schema.column_names() == ["created_at", "id", "payload"]
|
||||
|
||||
|
||||
def test_resolver_can_identify_empty_database(tmp_path: Path) -> None:
|
||||
"""空数据库应被解析为版本 ``0``。"""
|
||||
engine = _create_sqlite_engine(tmp_path / "empty_resolver.db")
|
||||
resolver = SchemaVersionResolver()
|
||||
|
||||
with engine.connect() as connection:
|
||||
resolved_version = resolver.resolve(connection)
|
||||
|
||||
assert resolved_version.version == 0
|
||||
assert resolved_version.source == SchemaVersionSource.EMPTY_DATABASE
|
||||
assert resolved_version.snapshot is not None
|
||||
assert resolved_version.snapshot.is_empty()
|
||||
|
||||
|
||||
def test_resolver_can_use_detector_for_unversioned_legacy_database(tmp_path: Path) -> None:
|
||||
"""未写入 ``user_version`` 的历史库应支持通过探测器识别版本。"""
|
||||
engine = _create_sqlite_engine(tmp_path / "legacy_resolver.db")
|
||||
resolver = SchemaVersionResolver(detectors=[FixedVersionDetector()])
|
||||
|
||||
with engine.begin() as connection:
|
||||
connection.execute(text("CREATE TABLE legacy_records (id INTEGER PRIMARY KEY, payload TEXT NOT NULL)"))
|
||||
|
||||
with engine.connect() as connection:
|
||||
resolved_version = resolver.resolve(connection)
|
||||
|
||||
assert resolved_version.version == 2
|
||||
assert resolved_version.source == SchemaVersionSource.DETECTOR
|
||||
assert resolved_version.detector_name == "fixed_version_detector"
|
||||
|
||||
|
||||
def test_registry_and_manager_can_execute_registered_steps(tmp_path: Path) -> None:
|
||||
"""迁移编排器应能按顺序执行已注册步骤并更新版本号。"""
|
||||
engine = _create_sqlite_engine(tmp_path / "manager.db")
|
||||
executed_steps: List[str] = []
|
||||
|
||||
def migrate_0_to_1(context: MigrationExecutionContext) -> None:
|
||||
"""测试迁移步骤 0 -> 1。
|
||||
|
||||
Args:
|
||||
context: 当前迁移步骤执行上下文。
|
||||
"""
|
||||
executed_steps.append(f"{context.current_version}->{context.target_version}:step_0_to_1")
|
||||
context.connection.execute(text("CREATE TABLE sample_records (id INTEGER PRIMARY KEY, name TEXT NOT NULL)"))
|
||||
|
||||
def migrate_1_to_2(context: MigrationExecutionContext) -> None:
|
||||
"""测试迁移步骤 1 -> 2。
|
||||
|
||||
Args:
|
||||
context: 当前迁移步骤执行上下文。
|
||||
"""
|
||||
executed_steps.append(f"{context.current_version}->{context.target_version}:step_1_to_2")
|
||||
context.connection.execute(text("ALTER TABLE sample_records ADD COLUMN email TEXT"))
|
||||
|
||||
registry = MigrationRegistry(
|
||||
steps=[
|
||||
MigrationStep(
|
||||
version_from=0,
|
||||
version_to=1,
|
||||
name="create_sample_records",
|
||||
description="创建示例表。",
|
||||
handler=migrate_0_to_1,
|
||||
),
|
||||
MigrationStep(
|
||||
version_from=1,
|
||||
version_to=2,
|
||||
name="add_sample_email",
|
||||
description="为示例表增加邮箱字段。",
|
||||
handler=migrate_1_to_2,
|
||||
),
|
||||
]
|
||||
)
|
||||
manager = DatabaseMigrationManager(engine=engine, registry=registry)
|
||||
|
||||
migration_plan = manager.migrate()
|
||||
|
||||
assert migration_plan.step_count() == 2
|
||||
assert executed_steps == ["0->2:step_0_to_1", "1->2:step_1_to_2"]
|
||||
|
||||
with engine.connect() as connection:
|
||||
version_store = SQLiteUserVersionStore()
|
||||
snapshot = SQLiteSchemaInspector().inspect(connection)
|
||||
recorded_version = version_store.read_version(connection)
|
||||
|
||||
assert recorded_version == 2
|
||||
assert snapshot.has_table("sample_records")
|
||||
assert snapshot.has_column("sample_records", "email")
|
||||
|
||||
|
||||
def test_manager_can_report_step_progress(tmp_path: Path) -> None:
|
||||
"""迁移编排器应支持通过上下文上报步骤进度。"""
|
||||
engine = _create_sqlite_engine(tmp_path / "manager_progress.db")
|
||||
reporter_instances: List[FakeMigrationProgressReporter] = []
|
||||
|
||||
def _build_reporter() -> BaseMigrationProgressReporter:
|
||||
"""构建测试用进度上报器。
|
||||
|
||||
Returns:
|
||||
BaseMigrationProgressReporter: 测试用进度上报器实例。
|
||||
"""
|
||||
reporter = FakeMigrationProgressReporter()
|
||||
reporter_instances.append(reporter)
|
||||
return reporter
|
||||
|
||||
def migrate_1_to_2(context: MigrationExecutionContext) -> None:
|
||||
"""测试迁移步骤 ``1 -> 2`` 的进度上报。
|
||||
|
||||
Args:
|
||||
context: 当前迁移步骤执行上下文。
|
||||
"""
|
||||
context.start_progress(total_tables=3, total_records=30, description="总迁移进度")
|
||||
context.advance_progress(records=10, completed_tables=1, item_name="chat_sessions")
|
||||
context.advance_progress(records=10, completed_tables=1, item_name="mai_messages")
|
||||
context.advance_progress(records=10, completed_tables=1, item_name="tool_records")
|
||||
context.connection.execute(text("CREATE TABLE progress_records (id INTEGER PRIMARY KEY, value TEXT NOT NULL)"))
|
||||
|
||||
with engine.begin() as connection:
|
||||
SQLiteUserVersionStore().write_version(connection, 1)
|
||||
|
||||
registry = MigrationRegistry(
|
||||
steps=[
|
||||
MigrationStep(
|
||||
version_from=1,
|
||||
version_to=2,
|
||||
name="progress_step",
|
||||
description="测试进度上报。",
|
||||
handler=migrate_1_to_2,
|
||||
)
|
||||
]
|
||||
)
|
||||
manager = DatabaseMigrationManager(
|
||||
engine=engine,
|
||||
registry=registry,
|
||||
progress_reporter_factory=_build_reporter,
|
||||
)
|
||||
|
||||
migration_plan = manager.migrate()
|
||||
|
||||
assert migration_plan.step_count() == 1
|
||||
assert len(reporter_instances) == 1
|
||||
assert reporter_instances[0].events == [
|
||||
("open", None, None, None),
|
||||
("start", 30, 3, "总迁移进度"),
|
||||
("advance", 10, 1, "chat_sessions"),
|
||||
("advance", 10, 1, "mai_messages"),
|
||||
("advance", 10, 1, "tool_records"),
|
||||
("close", None, None, None),
|
||||
]
|
||||
|
||||
|
||||
def test_default_resolver_can_identify_unversioned_latest_database(tmp_path: Path) -> None:
|
||||
"""默认解析器应能识别未写入版本号的最新结构数据库。"""
|
||||
engine = _create_sqlite_engine(tmp_path / "latest_resolver.db")
|
||||
resolver = build_default_schema_version_resolver()
|
||||
|
||||
with engine.begin() as connection:
|
||||
_create_current_schema(connection)
|
||||
|
||||
with engine.connect() as connection:
|
||||
resolved_version = resolver.resolve(connection)
|
||||
|
||||
assert resolved_version.version == LATEST_SCHEMA_VERSION
|
||||
assert resolved_version.source == SchemaVersionSource.DETECTOR
|
||||
assert resolved_version.detector_name == "latest_schema_detector"
|
||||
|
||||
|
||||
def test_default_resolver_can_identify_legacy_v1_database(tmp_path: Path) -> None:
|
||||
"""默认解析器应能识别未写版本号的旧版 ``0.x`` 数据库。"""
|
||||
engine = _create_sqlite_engine(tmp_path / "legacy_v1_resolver.db")
|
||||
resolver = build_default_schema_version_resolver()
|
||||
|
||||
with engine.begin() as connection:
|
||||
_create_legacy_v1_schema_with_sample_data(connection)
|
||||
|
||||
with engine.connect() as connection:
|
||||
resolved_version = resolver.resolve(connection)
|
||||
|
||||
assert resolved_version.version == LEGACY_V1_SCHEMA_VERSION
|
||||
assert resolved_version.source == SchemaVersionSource.DETECTOR
|
||||
assert resolved_version.detector_name == "legacy_v1_schema_detector"
|
||||
|
||||
|
||||
def test_bootstrapper_can_finalize_unversioned_latest_database(tmp_path: Path) -> None:
|
||||
"""已是最新结构但未写版本号的数据库应直接补写 ``user_version``。"""
|
||||
engine = _create_sqlite_engine(tmp_path / "latest_finalize.db")
|
||||
bootstrapper = create_database_migration_bootstrapper(engine)
|
||||
|
||||
with engine.begin() as connection:
|
||||
_create_current_schema(connection)
|
||||
|
||||
migration_state = bootstrapper.prepare_database()
|
||||
bootstrapper.finalize_database(migration_state)
|
||||
|
||||
assert not migration_state.requires_migration()
|
||||
assert migration_state.resolved_version.version == LATEST_SCHEMA_VERSION
|
||||
assert migration_state.resolved_version.source == SchemaVersionSource.DETECTOR
|
||||
|
||||
with engine.connect() as connection:
|
||||
recorded_version = SQLiteUserVersionStore().read_version(connection)
|
||||
|
||||
assert recorded_version == LATEST_SCHEMA_VERSION
|
||||
|
||||
|
||||
def test_bootstrapper_can_finalize_empty_database_to_latest_version(tmp_path: Path) -> None:
|
||||
"""空库在建表完成后应回写最新 ``user_version``。"""
|
||||
engine = _create_sqlite_engine(tmp_path / "bootstrap_empty.db")
|
||||
bootstrapper = create_database_migration_bootstrapper(engine)
|
||||
|
||||
migration_state = bootstrapper.prepare_database()
|
||||
|
||||
assert not migration_state.requires_migration()
|
||||
assert migration_state.resolved_version.version == EMPTY_SCHEMA_VERSION
|
||||
assert migration_state.target_version == LATEST_SCHEMA_VERSION
|
||||
|
||||
with engine.begin() as connection:
|
||||
_create_current_schema(connection)
|
||||
|
||||
bootstrapper.finalize_database(migration_state)
|
||||
|
||||
with engine.connect() as connection:
|
||||
recorded_version = SQLiteUserVersionStore().read_version(connection)
|
||||
|
||||
assert recorded_version == LATEST_SCHEMA_VERSION
|
||||
|
||||
|
||||
def test_bootstrapper_runs_registered_steps_for_versioned_database(tmp_path: Path) -> None:
|
||||
"""启动桥接器应在已登记旧版本数据库上执行注册迁移步骤。"""
|
||||
engine = _create_sqlite_engine(tmp_path / "bootstrap_registered.db")
|
||||
execution_marks: List[str] = []
|
||||
|
||||
def migrate_1_to_2(context: MigrationExecutionContext) -> None:
|
||||
"""测试桥接器迁移步骤 ``1 -> 2``。
|
||||
|
||||
Args:
|
||||
context: 当前迁移步骤执行上下文。
|
||||
"""
|
||||
execution_marks.append(f"step={context.step_name},index={context.step_index}")
|
||||
context.connection.execute(text("ALTER TABLE bootstrap_records ADD COLUMN email TEXT"))
|
||||
|
||||
with engine.begin() as connection:
|
||||
connection.execute(
|
||||
text("CREATE TABLE bootstrap_records (id INTEGER PRIMARY KEY, value TEXT NOT NULL)")
|
||||
)
|
||||
SQLiteUserVersionStore().write_version(connection, 1)
|
||||
|
||||
registry = MigrationRegistry(
|
||||
steps=[
|
||||
MigrationStep(
|
||||
version_from=1,
|
||||
version_to=2,
|
||||
name="bootstrap_add_email",
|
||||
description="为桥接器测试表增加邮箱字段。",
|
||||
handler=migrate_1_to_2,
|
||||
)
|
||||
]
|
||||
)
|
||||
bootstrapper = DatabaseMigrationBootstrapper(
|
||||
manager=DatabaseMigrationManager(engine=engine, registry=registry),
|
||||
latest_schema_version=2,
|
||||
)
|
||||
|
||||
migration_state = bootstrapper.prepare_database()
|
||||
|
||||
assert migration_state.resolved_version.version == 2
|
||||
assert migration_state.target_version == 2
|
||||
assert execution_marks == ["step=bootstrap_add_email,index=1"]
|
||||
|
||||
with engine.connect() as connection:
|
||||
snapshot = SQLiteSchemaInspector().inspect(connection)
|
||||
recorded_version = SQLiteUserVersionStore().read_version(connection)
|
||||
|
||||
assert recorded_version == 2
|
||||
assert snapshot.has_table("bootstrap_records")
|
||||
assert snapshot.has_column("bootstrap_records", "email")
|
||||
|
||||
|
||||
def test_default_bootstrapper_can_migrate_legacy_v1_database(tmp_path: Path) -> None:
|
||||
"""默认桥接器应能把旧版 ``0.x`` 数据库整体迁移到最新结构。"""
|
||||
engine = _create_sqlite_engine(tmp_path / "legacy_v1_to_v2.db")
|
||||
bootstrapper = create_database_migration_bootstrapper(engine)
|
||||
|
||||
with engine.begin() as connection:
|
||||
_create_legacy_v1_schema_with_sample_data(connection)
|
||||
|
||||
migration_state = bootstrapper.prepare_database()
|
||||
bootstrapper.finalize_database(migration_state)
|
||||
|
||||
assert not migration_state.requires_migration()
|
||||
assert migration_state.resolved_version.version == LATEST_SCHEMA_VERSION
|
||||
assert migration_state.resolved_version.source == SchemaVersionSource.PRAGMA
|
||||
|
||||
with engine.connect() as connection:
|
||||
recorded_version = SQLiteUserVersionStore().read_version(connection)
|
||||
snapshot = SQLiteSchemaInspector().inspect(connection)
|
||||
message_row = connection.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT session_id, processed_plain_text, additional_config, raw_content
|
||||
FROM mai_messages
|
||||
WHERE message_id = 'msg-1'
|
||||
"""
|
||||
)
|
||||
).mappings().one()
|
||||
action_row = connection.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT session_id, action_name, action_display_prompt
|
||||
FROM action_records
|
||||
WHERE action_id = 'action-1'
|
||||
"""
|
||||
)
|
||||
).mappings().one()
|
||||
tool_row = connection.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT session_id, tool_name, tool_display_prompt
|
||||
FROM tool_records
|
||||
WHERE tool_id = 'action-1'
|
||||
"""
|
||||
)
|
||||
).mappings().one()
|
||||
expression_row = connection.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT session_id, content_list, modified_by
|
||||
FROM expressions
|
||||
WHERE id = 1
|
||||
"""
|
||||
)
|
||||
).mappings().one()
|
||||
jargon_row = connection.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT session_id_dict, raw_content, inference_with_content_only
|
||||
FROM jargons
|
||||
WHERE id = 1
|
||||
"""
|
||||
)
|
||||
).mappings().one()
|
||||
|
||||
assert recorded_version == LATEST_SCHEMA_VERSION
|
||||
assert snapshot.has_table("__legacy_v1_messages")
|
||||
assert snapshot.has_table("chat_sessions")
|
||||
assert snapshot.has_table("mai_messages")
|
||||
assert snapshot.has_table("tool_records")
|
||||
|
||||
unpacked_raw_content = msgpack.unpackb(message_row["raw_content"], raw=False)
|
||||
additional_config = json.loads(message_row["additional_config"])
|
||||
expression_content_list = json.loads(expression_row["content_list"])
|
||||
jargon_session_id_dict = json.loads(jargon_row["session_id_dict"])
|
||||
jargon_raw_content = json.loads(jargon_row["raw_content"])
|
||||
|
||||
assert message_row["session_id"] == "session-1"
|
||||
assert message_row["processed_plain_text"] == "你好"
|
||||
assert unpacked_raw_content == [{"type": "text", "data": "你好呀"}]
|
||||
assert additional_config == {"priority_mode": "high", "source": "legacy"}
|
||||
assert action_row["session_id"] == "session-1"
|
||||
assert action_row["action_name"] == "search"
|
||||
assert action_row["action_display_prompt"] == "执行搜索"
|
||||
assert tool_row["session_id"] == "session-1"
|
||||
assert tool_row["tool_name"] == "search"
|
||||
assert tool_row["tool_display_prompt"] == "执行搜索"
|
||||
assert expression_row["session_id"] == "session-1"
|
||||
assert expression_row["modified_by"] == "AI"
|
||||
assert expression_content_list == ["你好呀", "早上好"]
|
||||
assert jargon_session_id_dict == {"session-1": 5}
|
||||
assert jargon_raw_content == ["上分"]
|
||||
assert jargon_row["inference_with_content_only"] == '{"guess":"content"}'
|
||||
|
||||
|
||||
def test_legacy_v1_migration_reports_table_progress(tmp_path: Path) -> None:
|
||||
"""旧版迁移步骤应按目标表数量推进总进度。"""
|
||||
engine = _create_sqlite_engine(tmp_path / "legacy_progress.db")
|
||||
reporter_instances: List[FakeMigrationProgressReporter] = []
|
||||
|
||||
def _build_reporter() -> BaseMigrationProgressReporter:
|
||||
"""构建测试用进度上报器。
|
||||
|
||||
Returns:
|
||||
BaseMigrationProgressReporter: 测试用进度上报器实例。
|
||||
"""
|
||||
reporter = FakeMigrationProgressReporter()
|
||||
reporter_instances.append(reporter)
|
||||
return reporter
|
||||
|
||||
with engine.begin() as connection:
|
||||
_create_legacy_v1_schema_with_sample_data(connection)
|
||||
|
||||
manager = DatabaseMigrationManager(
|
||||
engine=engine,
|
||||
registry=build_default_migration_registry(),
|
||||
resolver=build_default_schema_version_resolver(),
|
||||
progress_reporter_factory=_build_reporter,
|
||||
)
|
||||
|
||||
migration_plan = manager.migrate(target_version=LATEST_SCHEMA_VERSION)
|
||||
|
||||
assert migration_plan.step_count() == 1
|
||||
assert len(reporter_instances) == 1
|
||||
reporter_events = reporter_instances[0].events
|
||||
|
||||
assert reporter_events[0] == ("open", None, None, None)
|
||||
assert reporter_events[1] == ("start", 6, 12, "总迁移进度")
|
||||
assert reporter_events[-1] == ("close", None, None, None)
|
||||
assert reporter_events.count(("advance", 1, 0, None)) == 6
|
||||
assert reporter_events.count(("advance", 0, 1, "chat_sessions")) == 1
|
||||
assert reporter_events.count(("advance", 0, 1, "thinking_questions")) == 1
|
||||
assert len([event for event in reporter_events if event[0] == "advance"]) == 18
|
||||
|
||||
|
||||
def test_initialize_database_calls_bootstrapper_before_create_all(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""数据库初始化入口应先准备迁移,再建表、补迁移并收尾。"""
|
||||
call_order: List[str] = []
|
||||
|
||||
def _fake_prepare_database() -> DatabaseMigrationState:
|
||||
"""返回测试用迁移状态。
|
||||
|
||||
Returns:
|
||||
DatabaseMigrationState: 不包含迁移步骤的测试状态。
|
||||
"""
|
||||
call_order.append("prepare_database")
|
||||
return DatabaseMigrationState(
|
||||
resolved_version=ResolvedSchemaVersion(version=0, source=SchemaVersionSource.EMPTY_DATABASE),
|
||||
target_version=LATEST_SCHEMA_VERSION,
|
||||
plan=MigrationPlan(
|
||||
current_version=EMPTY_SCHEMA_VERSION,
|
||||
target_version=LATEST_SCHEMA_VERSION,
|
||||
steps=[],
|
||||
),
|
||||
)
|
||||
|
||||
def _fake_create_all(bind) -> None:
|
||||
"""记录建表调用。
|
||||
|
||||
Args:
|
||||
bind: 传入的数据库绑定对象。
|
||||
"""
|
||||
del bind
|
||||
call_order.append("create_all")
|
||||
|
||||
def _fake_migrate_action_records() -> None:
|
||||
"""记录轻量补迁移调用。"""
|
||||
call_order.append("migrate_action_records")
|
||||
|
||||
def _fake_finalize_database(migration_state: DatabaseMigrationState) -> None:
|
||||
"""记录迁移收尾调用。
|
||||
|
||||
Args:
|
||||
migration_state: 当前数据库迁移状态。
|
||||
"""
|
||||
del migration_state
|
||||
call_order.append("finalize_database")
|
||||
|
||||
monkeypatch.setattr(database_module, "_db_initialized", False)
|
||||
monkeypatch.setattr(database_module, "_DB_DIR", tmp_path / "data")
|
||||
monkeypatch.setattr(database_module._migration_bootstrapper, "prepare_database", _fake_prepare_database)
|
||||
monkeypatch.setattr(database_module._migration_bootstrapper, "finalize_database", _fake_finalize_database)
|
||||
monkeypatch.setattr(database_module.SQLModel.metadata, "create_all", _fake_create_all)
|
||||
monkeypatch.setattr(database_module, "_migrate_action_records_to_tool_records", _fake_migrate_action_records)
|
||||
|
||||
database_module.initialize_database()
|
||||
|
||||
assert call_order == [
|
||||
"prepare_database",
|
||||
"create_all",
|
||||
"migrate_action_records",
|
||||
"finalize_database",
|
||||
]
|
||||
55
pytests/test_maisaka_message_adapter.py
Normal file
55
pytests/test_maisaka_message_adapter.py
Normal file
@@ -0,0 +1,55 @@
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import sys
|
||||
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
from src.common.data_models.message_component_data_model import MessageSequence, TextComponent
|
||||
from src.llm_models.payload_content.tool_option import ToolCall
|
||||
from src.maisaka.message_adapter import build_message, get_message_kind, get_message_role, get_tool_call_id, get_tool_calls
|
||||
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
||||
if str(PROJECT_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
|
||||
def test_build_message_returns_session_message_with_maisaka_metadata() -> None:
|
||||
timestamp = datetime.now()
|
||||
tool_call = ToolCall(
|
||||
call_id="call-1",
|
||||
func_name="reply",
|
||||
args={"message_id": "msg-1"},
|
||||
)
|
||||
raw_message = MessageSequence(components=[TextComponent(text="内部消息内容")])
|
||||
|
||||
message = build_message(
|
||||
role="assistant",
|
||||
content="展示消息内容",
|
||||
message_kind="perception",
|
||||
source="assistant",
|
||||
tool_call_id="call-1",
|
||||
tool_calls=[tool_call],
|
||||
timestamp=timestamp,
|
||||
message_id="maisaka-msg-1",
|
||||
raw_message=raw_message,
|
||||
display_text="展示消息内容",
|
||||
)
|
||||
|
||||
assert isinstance(message, SessionMessage)
|
||||
assert message.initialized is True
|
||||
assert message.message_id == "maisaka-msg-1"
|
||||
assert message.timestamp == timestamp
|
||||
assert message.processed_plain_text == "展示消息内容"
|
||||
assert message.display_message == "展示消息内容"
|
||||
assert message.raw_message is raw_message
|
||||
|
||||
assert get_message_role(message) == "assistant"
|
||||
assert get_message_kind(message) == "perception"
|
||||
assert get_tool_call_id(message) == "call-1"
|
||||
|
||||
restored_tool_calls = get_tool_calls(message)
|
||||
assert len(restored_tool_calls) == 1
|
||||
assert restored_tool_calls[0].call_id == "call-1"
|
||||
assert restored_tool_calls[0].func_name == "reply"
|
||||
assert restored_tool_calls[0].args == {"message_id": "msg-1"}
|
||||
@@ -5,6 +5,7 @@
|
||||
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
@@ -1831,395 +1832,445 @@ class TestMaiMessages:
|
||||
assert msg.llm_response_content == "new response"
|
||||
|
||||
|
||||
# ─── WorkflowExecutor 测试 ────────────────────────────────
|
||||
class _FakeHookSupervisor:
|
||||
"""用于 Hook 分发测试的简化 Supervisor。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
group_name: str,
|
||||
component_registry: Any,
|
||||
handlers: Dict[str, Callable[[Dict[str, Any]], Awaitable[Dict[str, Any]] | Dict[str, Any]]],
|
||||
call_log: List[tuple[str, str]],
|
||||
) -> None:
|
||||
"""初始化测试用 Supervisor。
|
||||
|
||||
Args:
|
||||
group_name: 运行时分组名称。
|
||||
component_registry: 组件注册表实例。
|
||||
handlers: 处理器映射,键为 `plugin_id.component_name`。
|
||||
call_log: 记录调用顺序的列表。
|
||||
"""
|
||||
|
||||
self._group_name = group_name
|
||||
self.component_registry = component_registry
|
||||
self._handlers = handlers
|
||||
self._call_log = call_log
|
||||
|
||||
@property
|
||||
def group_name(self) -> str:
|
||||
"""返回当前测试 Supervisor 的分组名称。"""
|
||||
|
||||
return self._group_name
|
||||
|
||||
async def invoke_plugin(
|
||||
self,
|
||||
method: str,
|
||||
plugin_id: str,
|
||||
component_name: str,
|
||||
args: Optional[Dict[str, Any]] = None,
|
||||
timeout_ms: int = 30000,
|
||||
) -> SimpleNamespace:
|
||||
"""模拟调用插件组件。
|
||||
|
||||
Args:
|
||||
method: RPC 方法名。
|
||||
plugin_id: 目标插件 ID。
|
||||
component_name: 目标组件名称。
|
||||
args: 调用参数。
|
||||
timeout_ms: 超时配置,测试中仅用于保持接口一致。
|
||||
|
||||
Returns:
|
||||
SimpleNamespace: 仅包含 `payload` 字段的简化响应对象。
|
||||
"""
|
||||
|
||||
del method
|
||||
del timeout_ms
|
||||
|
||||
full_name = f"{plugin_id}.{component_name}"
|
||||
handler = self._handlers[full_name]
|
||||
self._call_log.append((plugin_id, component_name))
|
||||
result = handler(dict(args or {}))
|
||||
if asyncio.iscoroutine(result):
|
||||
result = await result
|
||||
return SimpleNamespace(payload=result)
|
||||
|
||||
|
||||
class TestWorkflowExecutor:
|
||||
"""Host-side Workflow 执行器测试(新 pipeline 模型)"""
|
||||
# ─── HookDispatcher 测试 ────────────────────────────────
|
||||
|
||||
|
||||
class TestHookDispatcher:
|
||||
"""命名 Hook 分发器测试。"""
|
||||
|
||||
@staticmethod
|
||||
def _import_dispatcher_modules(monkeypatch: pytest.MonkeyPatch) -> tuple[Any, Any]:
|
||||
"""导入 Hook 分发相关模块,并屏蔽配置初始化触发的退出。
|
||||
|
||||
Args:
|
||||
monkeypatch: pytest 的 monkeypatch 工具。
|
||||
|
||||
Returns:
|
||||
tuple[Any, Any]: `ComponentRegistry` 与 `HookDispatcher` 类型。
|
||||
"""
|
||||
|
||||
monkeypatch.setattr(sys, "exit", lambda code=0: None)
|
||||
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
||||
from src.plugin_runtime.host.hook_dispatcher import HookDispatcher
|
||||
|
||||
return ComponentRegistry, HookDispatcher
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_pipeline_completes(self):
|
||||
"""无任何 workflow_step 注册时,pipeline 全阶段跳过,状态 completed"""
|
||||
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
||||
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
|
||||
async def test_empty_hook_returns_original_kwargs(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""未注册处理器时应直接返回原始参数。"""
|
||||
|
||||
reg = ComponentRegistry()
|
||||
executor = WorkflowExecutor(reg)
|
||||
ComponentRegistry, HookDispatcher = self._import_dispatcher_modules(monkeypatch)
|
||||
|
||||
async def mock_invoke(plugin_id, comp_name, args):
|
||||
return {"hook_result": "continue"}
|
||||
dispatcher = HookDispatcher()
|
||||
supervisor = _FakeHookSupervisor("builtin", ComponentRegistry(), {}, [])
|
||||
|
||||
result, final_msg, ctx = await executor.execute(
|
||||
mock_invoke,
|
||||
message={"plain_text": "test"},
|
||||
)
|
||||
assert result.status == "completed"
|
||||
assert result.return_message == "workflow completed"
|
||||
assert len(ctx.timings) == 6 # 6 stages
|
||||
result = await dispatcher.invoke_hook("heart_fc.cycle_start", [supervisor], session_id="s-1")
|
||||
|
||||
assert result.hook_name == "heart_fc.cycle_start"
|
||||
assert result.kwargs == {"session_id": "s-1"}
|
||||
assert result.aborted is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_blocking_hook_modifies_message(self):
|
||||
"""blocking hook 可以修改消息"""
|
||||
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
||||
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
|
||||
async def test_blocking_hook_modifies_kwargs(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""blocking 处理器可以修改参数。"""
|
||||
|
||||
reg = ComponentRegistry()
|
||||
reg.register_component(
|
||||
ComponentRegistry, HookDispatcher = self._import_dispatcher_modules(monkeypatch)
|
||||
|
||||
registry = ComponentRegistry()
|
||||
registry.register_component(
|
||||
"upper",
|
||||
"workflow_step",
|
||||
"HOOK_HANDLER",
|
||||
"p1",
|
||||
{
|
||||
"stage": "pre_process",
|
||||
"priority": 10,
|
||||
"blocking": True,
|
||||
"hook": "heart_fc.cycle_start",
|
||||
"mode": "blocking",
|
||||
"order": "normal",
|
||||
},
|
||||
)
|
||||
executor = WorkflowExecutor(reg)
|
||||
|
||||
async def mock_invoke(plugin_id, comp_name, args):
|
||||
msg = args.get("message", {})
|
||||
return {
|
||||
"hook_result": "continue",
|
||||
"modified_message": {**msg, "plain_text": msg.get("plain_text", "").upper()},
|
||||
}
|
||||
|
||||
result, final_msg, ctx = await executor.execute(
|
||||
mock_invoke,
|
||||
message={"plain_text": "hello"},
|
||||
dispatcher = HookDispatcher()
|
||||
supervisor = _FakeHookSupervisor(
|
||||
"builtin",
|
||||
registry,
|
||||
{
|
||||
"p1.upper": lambda args: {
|
||||
"success": True,
|
||||
"action": "continue",
|
||||
"modified_kwargs": {
|
||||
"session_id": args["session_id"],
|
||||
"text": str(args["text"]).upper(),
|
||||
},
|
||||
}
|
||||
},
|
||||
[],
|
||||
)
|
||||
assert result.status == "completed"
|
||||
assert final_msg["plain_text"] == "HELLO"
|
||||
assert len(ctx.modification_log) == 1
|
||||
assert ctx.modification_log[0].stage == "pre_process"
|
||||
|
||||
result = await dispatcher.invoke_hook("heart_fc.cycle_start", [supervisor], session_id="s-1", text="hello")
|
||||
|
||||
assert result.kwargs["session_id"] == "s-1"
|
||||
assert result.kwargs["text"] == "HELLO"
|
||||
assert result.aborted is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_abort_stops_pipeline(self):
|
||||
"""HookResult.ABORT 立即终止 pipeline"""
|
||||
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
||||
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
|
||||
async def test_abort_stops_following_blocking_handlers(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""blocking 处理器的 abort 应阻止后续 blocking 处理器执行。"""
|
||||
|
||||
reg = ComponentRegistry()
|
||||
reg.register_component(
|
||||
"blocker",
|
||||
"workflow_step",
|
||||
ComponentRegistry, HookDispatcher = self._import_dispatcher_modules(monkeypatch)
|
||||
|
||||
registry = ComponentRegistry()
|
||||
registry.register_component(
|
||||
"stopper",
|
||||
"HOOK_HANDLER",
|
||||
"p1",
|
||||
{
|
||||
"stage": "pre_process",
|
||||
"priority": 10,
|
||||
"blocking": True,
|
||||
},
|
||||
{"hook": "heart_fc.cycle_start", "mode": "blocking", "order": "early"},
|
||||
)
|
||||
executor = WorkflowExecutor(reg)
|
||||
|
||||
async def mock_invoke(plugin_id, comp_name, args):
|
||||
return {"hook_result": "abort"}
|
||||
|
||||
result, _, ctx = await executor.execute(
|
||||
mock_invoke,
|
||||
message={"plain_text": "test"},
|
||||
)
|
||||
assert result.status == "aborted"
|
||||
assert result.stopped_at == "pre_process"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skip_stage(self):
|
||||
"""HookResult.SKIP_STAGE 跳过当前阶段剩余 hook"""
|
||||
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
||||
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
|
||||
|
||||
reg = ComponentRegistry()
|
||||
# high-priority hook 返回 skip_stage
|
||||
reg.register_component(
|
||||
"skipper",
|
||||
"workflow_step",
|
||||
"p1",
|
||||
{
|
||||
"stage": "ingress",
|
||||
"priority": 100,
|
||||
"blocking": True,
|
||||
},
|
||||
)
|
||||
# low-priority hook 不应被执行
|
||||
reg.register_component(
|
||||
"checker",
|
||||
"workflow_step",
|
||||
registry.register_component(
|
||||
"after_stop",
|
||||
"HOOK_HANDLER",
|
||||
"p2",
|
||||
{
|
||||
"stage": "ingress",
|
||||
"priority": 1,
|
||||
"blocking": True,
|
||||
},
|
||||
{"hook": "heart_fc.cycle_start", "mode": "blocking", "order": "normal"},
|
||||
)
|
||||
call_log: List[tuple[str, str]] = []
|
||||
dispatcher = HookDispatcher()
|
||||
supervisor = _FakeHookSupervisor(
|
||||
"builtin",
|
||||
registry,
|
||||
{
|
||||
"p1.stopper": lambda args: {"success": True, "action": "abort"},
|
||||
"p2.after_stop": lambda args: {"success": True, "action": "continue"},
|
||||
},
|
||||
call_log,
|
||||
)
|
||||
executor = WorkflowExecutor(reg)
|
||||
|
||||
call_log = []
|
||||
result = await dispatcher.invoke_hook("heart_fc.cycle_start", [supervisor], cycle_id="c-1")
|
||||
|
||||
async def mock_invoke(plugin_id, comp_name, args):
|
||||
call_log.append(comp_name)
|
||||
if comp_name == "skipper":
|
||||
return {"hook_result": "skip_stage"}
|
||||
return {"hook_result": "continue"}
|
||||
|
||||
result, _, _ = await executor.execute(mock_invoke, message={"plain_text": "test"})
|
||||
assert result.status == "completed"
|
||||
# 只有 skipper 被调用,checker 被跳过
|
||||
assert call_log == ["skipper"]
|
||||
assert result.aborted is True
|
||||
assert result.stopped_by == "p1.stopper"
|
||||
assert call_log == [("p1", "stopper")]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pre_filter(self):
|
||||
"""filter 条件不匹配时跳过 hook"""
|
||||
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
||||
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
|
||||
async def test_observe_handler_runs_in_background_without_mutation(
|
||||
self,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""observe 处理器应后台执行且不能影响主流程参数。"""
|
||||
|
||||
reg = ComponentRegistry()
|
||||
reg.register_component(
|
||||
"only_dm",
|
||||
"workflow_step",
|
||||
"p1",
|
||||
{
|
||||
"stage": "ingress",
|
||||
"priority": 10,
|
||||
"blocking": True,
|
||||
"filter": {"chat_type": "direct"},
|
||||
},
|
||||
)
|
||||
executor = WorkflowExecutor(reg)
|
||||
ComponentRegistry, HookDispatcher = self._import_dispatcher_modules(monkeypatch)
|
||||
|
||||
call_log = []
|
||||
|
||||
async def mock_invoke(plugin_id, comp_name, args):
|
||||
call_log.append(comp_name)
|
||||
return {"hook_result": "continue"}
|
||||
|
||||
# 不匹配 filter —— hook 不应被调用
|
||||
await executor.execute(mock_invoke, message={"plain_text": "hi", "chat_type": "group"})
|
||||
assert not call_log
|
||||
|
||||
# 匹配 filter —— hook 应被调用
|
||||
await executor.execute(mock_invoke, message={"plain_text": "hi", "chat_type": "direct"})
|
||||
assert call_log == ["only_dm"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_policy_skip(self):
|
||||
"""error_policy=skip 时跳过失败的 hook 继续执行"""
|
||||
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
||||
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
|
||||
|
||||
reg = ComponentRegistry()
|
||||
reg.register_component(
|
||||
"failer",
|
||||
"workflow_step",
|
||||
"p1",
|
||||
{
|
||||
"stage": "ingress",
|
||||
"priority": 100,
|
||||
"blocking": True,
|
||||
"error_policy": "skip",
|
||||
},
|
||||
)
|
||||
reg.register_component(
|
||||
"ok_step",
|
||||
"workflow_step",
|
||||
"p2",
|
||||
{
|
||||
"stage": "ingress",
|
||||
"priority": 1,
|
||||
"blocking": True,
|
||||
},
|
||||
)
|
||||
executor = WorkflowExecutor(reg)
|
||||
|
||||
call_log = []
|
||||
|
||||
async def mock_invoke(plugin_id, comp_name, args):
|
||||
call_log.append(comp_name)
|
||||
if comp_name == "failer":
|
||||
raise RuntimeError("boom")
|
||||
return {"hook_result": "continue"}
|
||||
|
||||
result, _, ctx = await executor.execute(mock_invoke, message={"plain_text": "test"})
|
||||
assert result.status == "completed"
|
||||
assert "failer" in call_log
|
||||
assert "ok_step" in call_log
|
||||
assert any("boom" in e for e in ctx.errors)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_policy_abort(self):
|
||||
"""error_policy=abort(默认)时 pipeline 失败"""
|
||||
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
||||
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
|
||||
|
||||
reg = ComponentRegistry()
|
||||
reg.register_component(
|
||||
"failer",
|
||||
"workflow_step",
|
||||
"p1",
|
||||
{
|
||||
"stage": "ingress",
|
||||
"priority": 10,
|
||||
"blocking": True,
|
||||
# error_policy defaults to "abort"
|
||||
},
|
||||
)
|
||||
executor = WorkflowExecutor(reg)
|
||||
|
||||
async def mock_invoke(plugin_id, comp_name, args):
|
||||
raise RuntimeError("fatal")
|
||||
|
||||
result, _, ctx = await executor.execute(mock_invoke, message={"plain_text": "test"})
|
||||
assert result.status == "failed"
|
||||
assert result.stopped_at == "ingress"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_nonblocking_hooks_concurrent(self):
|
||||
"""non-blocking hook 并发执行,不修改消息"""
|
||||
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
||||
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
|
||||
|
||||
reg = ComponentRegistry()
|
||||
for i in range(3):
|
||||
reg.register_component(
|
||||
f"nb_{i}",
|
||||
"workflow_step",
|
||||
f"p{i}",
|
||||
{
|
||||
"stage": "post_process",
|
||||
"priority": 0,
|
||||
"blocking": False,
|
||||
},
|
||||
)
|
||||
executor = WorkflowExecutor(reg)
|
||||
|
||||
call_log = []
|
||||
|
||||
async def mock_invoke(plugin_id, comp_name, args):
|
||||
call_log.append(comp_name)
|
||||
return {"hook_result": "continue", "modified_message": {"plain_text": "ignored"}}
|
||||
|
||||
result, final_msg, _ = await executor.execute(mock_invoke, message={"plain_text": "original"})
|
||||
# non-blocking 的 modified_message 被忽略
|
||||
assert final_msg["plain_text"] == "original"
|
||||
# 给异步 task 时间完成
|
||||
await asyncio.sleep(0.1)
|
||||
assert result.status == "completed"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_nonblocking_tasks_are_retained_until_completion(self):
|
||||
"""execute 返回后,non-blocking task 仍应保持强引用直到执行完成。"""
|
||||
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
||||
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
|
||||
|
||||
reg = ComponentRegistry()
|
||||
reg.register_component(
|
||||
registry = ComponentRegistry()
|
||||
registry.register_component(
|
||||
"observer",
|
||||
"workflow_step",
|
||||
"HOOK_HANDLER",
|
||||
"p1",
|
||||
{
|
||||
"stage": "post_process",
|
||||
"priority": 0,
|
||||
"blocking": False,
|
||||
},
|
||||
{"hook": "heart_fc.cycle_start", "mode": "observe", "order": "normal"},
|
||||
)
|
||||
executor = WorkflowExecutor(reg)
|
||||
|
||||
started = asyncio.Event()
|
||||
release = asyncio.Event()
|
||||
call_log: List[tuple[str, str]] = []
|
||||
|
||||
async def observe_handler(args: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""模拟耗时观察型处理器。"""
|
||||
|
||||
async def mock_invoke(plugin_id, comp_name, args):
|
||||
started.set()
|
||||
await release.wait()
|
||||
return {"hook_result": "continue"}
|
||||
return {
|
||||
"success": True,
|
||||
"action": "abort",
|
||||
"modified_kwargs": {"session_id": "changed"},
|
||||
"custom_result": args["session_id"],
|
||||
}
|
||||
|
||||
result, final_msg, _ = await executor.execute(mock_invoke, message={"plain_text": "original"})
|
||||
dispatcher = HookDispatcher()
|
||||
supervisor = _FakeHookSupervisor(
|
||||
"builtin",
|
||||
registry,
|
||||
{"p1.observer": observe_handler},
|
||||
call_log,
|
||||
)
|
||||
|
||||
result = await dispatcher.invoke_hook("heart_fc.cycle_start", [supervisor], session_id="s-1")
|
||||
|
||||
await asyncio.sleep(0)
|
||||
assert result.status == "completed"
|
||||
assert final_msg["plain_text"] == "original"
|
||||
assert result.aborted is False
|
||||
assert result.kwargs["session_id"] == "s-1"
|
||||
assert started.is_set()
|
||||
assert len(executor._background_tasks) == 1
|
||||
assert len(dispatcher._background_tasks) == 1
|
||||
|
||||
release.set()
|
||||
await asyncio.sleep(0)
|
||||
await asyncio.sleep(0)
|
||||
assert not executor._background_tasks
|
||||
assert call_log == [("p1", "observer")]
|
||||
assert not dispatcher._background_tasks
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_command_routing(self):
|
||||
"""PLAN 阶段内置命令路由"""
|
||||
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
||||
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
|
||||
async def test_global_order_prefers_order_slot_then_source(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""全局排序应先看 order,再看内置/第三方来源。"""
|
||||
|
||||
reg = ComponentRegistry()
|
||||
reg.register_component(
|
||||
"help",
|
||||
"command",
|
||||
"p1",
|
||||
{
|
||||
"command_pattern": r"^/help",
|
||||
},
|
||||
ComponentRegistry, HookDispatcher = self._import_dispatcher_modules(monkeypatch)
|
||||
|
||||
builtin_registry = ComponentRegistry()
|
||||
third_registry = ComponentRegistry()
|
||||
builtin_registry.register_component(
|
||||
"builtin_early",
|
||||
"HOOK_HANDLER",
|
||||
"b1",
|
||||
{"hook": "heart_fc.cycle_start", "mode": "blocking", "order": "early"},
|
||||
)
|
||||
builtin_registry.register_component(
|
||||
"builtin_normal",
|
||||
"HOOK_HANDLER",
|
||||
"b1",
|
||||
{"hook": "heart_fc.cycle_start", "mode": "blocking", "order": "normal"},
|
||||
)
|
||||
third_registry.register_component(
|
||||
"third_early",
|
||||
"HOOK_HANDLER",
|
||||
"t1",
|
||||
{"hook": "heart_fc.cycle_start", "mode": "blocking", "order": "early"},
|
||||
)
|
||||
third_registry.register_component(
|
||||
"third_normal",
|
||||
"HOOK_HANDLER",
|
||||
"t1",
|
||||
{"hook": "heart_fc.cycle_start", "mode": "blocking", "order": "normal"},
|
||||
)
|
||||
executor = WorkflowExecutor(reg)
|
||||
|
||||
async def mock_invoke(plugin_id, comp_name, args):
|
||||
if comp_name == "help":
|
||||
return {"output": "帮助信息"}
|
||||
return {"hook_result": "continue"}
|
||||
call_log: List[tuple[str, str]] = []
|
||||
dispatcher = HookDispatcher()
|
||||
builtin_supervisor = _FakeHookSupervisor(
|
||||
"builtin",
|
||||
builtin_registry,
|
||||
{
|
||||
"b1.builtin_early": lambda args: {"success": True, "action": "continue"},
|
||||
"b1.builtin_normal": lambda args: {"success": True, "action": "continue"},
|
||||
},
|
||||
call_log,
|
||||
)
|
||||
third_supervisor = _FakeHookSupervisor(
|
||||
"third_party",
|
||||
third_registry,
|
||||
{
|
||||
"t1.third_early": lambda args: {"success": True, "action": "continue"},
|
||||
"t1.third_normal": lambda args: {"success": True, "action": "continue"},
|
||||
},
|
||||
call_log,
|
||||
)
|
||||
|
||||
result, _, ctx = await executor.execute(mock_invoke, message={"plain_text": "/help topic"})
|
||||
assert result.status == "completed"
|
||||
assert ctx.matched_command == "p1.help"
|
||||
cmd_result = ctx.get_stage_output("plan", "command_result")
|
||||
assert cmd_result is not None
|
||||
assert cmd_result["output"] == "帮助信息"
|
||||
await dispatcher.invoke_hook(
|
||||
"heart_fc.cycle_start",
|
||||
[third_supervisor, builtin_supervisor],
|
||||
cycle_id="c-1",
|
||||
)
|
||||
|
||||
assert call_log == [
|
||||
("b1", "builtin_early"),
|
||||
("t1", "third_early"),
|
||||
("b1", "builtin_normal"),
|
||||
("t1", "third_normal"),
|
||||
]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stage_outputs(self):
|
||||
"""stage_outputs 数据在阶段间传递"""
|
||||
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
||||
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
|
||||
async def test_error_policy_abort_stops_dispatch(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""error_policy=abort 时应中止本次 Hook 调用。"""
|
||||
|
||||
reg = ComponentRegistry()
|
||||
# ingress 阶段写入数据
|
||||
reg.register_component(
|
||||
"writer",
|
||||
"workflow_step",
|
||||
ComponentRegistry, HookDispatcher = self._import_dispatcher_modules(monkeypatch)
|
||||
|
||||
registry = ComponentRegistry()
|
||||
registry.register_component(
|
||||
"failer",
|
||||
"HOOK_HANDLER",
|
||||
"p1",
|
||||
{
|
||||
"stage": "ingress",
|
||||
"priority": 10,
|
||||
"blocking": True,
|
||||
"hook": "heart_fc.cycle_start",
|
||||
"mode": "blocking",
|
||||
"order": "normal",
|
||||
"error_policy": "abort",
|
||||
},
|
||||
)
|
||||
# pre_process 阶段读取数据
|
||||
reg.register_component(
|
||||
"reader",
|
||||
"workflow_step",
|
||||
"p2",
|
||||
call_log: List[tuple[str, str]] = []
|
||||
|
||||
async def fail_handler(args: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""抛出异常以触发 abort 策略。"""
|
||||
|
||||
del args
|
||||
raise RuntimeError("boom")
|
||||
|
||||
dispatcher = HookDispatcher()
|
||||
supervisor = _FakeHookSupervisor("builtin", registry, {"p1.failer": fail_handler}, call_log)
|
||||
|
||||
result = await dispatcher.invoke_hook("heart_fc.cycle_start", [supervisor], session_id="s-1")
|
||||
|
||||
assert result.aborted is True
|
||||
assert result.stopped_by == "p1.failer"
|
||||
assert any("boom" in error for error in result.errors)
|
||||
assert call_log == [("p1", "failer")]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_timeout_respects_handler_timeout_ms(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""处理器超时应被记录为错误并继续。"""
|
||||
|
||||
ComponentRegistry, HookDispatcher = self._import_dispatcher_modules(monkeypatch)
|
||||
|
||||
registry = ComponentRegistry()
|
||||
registry.register_component(
|
||||
"slow",
|
||||
"HOOK_HANDLER",
|
||||
"p1",
|
||||
{
|
||||
"stage": "pre_process",
|
||||
"priority": 10,
|
||||
"blocking": True,
|
||||
"hook": "heart_fc.cycle_start",
|
||||
"mode": "blocking",
|
||||
"order": "normal",
|
||||
"timeout_ms": 10,
|
||||
},
|
||||
)
|
||||
executor = WorkflowExecutor(reg)
|
||||
call_log: List[tuple[str, str]] = []
|
||||
|
||||
async def mock_invoke(plugin_id, comp_name, args):
|
||||
if comp_name == "writer":
|
||||
return {
|
||||
"hook_result": "continue",
|
||||
"stage_output": {"parsed_intent": "greeting"},
|
||||
}
|
||||
if comp_name == "reader":
|
||||
# 验证 stage_outputs 被传递过来
|
||||
outputs = args.get("stage_outputs", {})
|
||||
ingress_data = outputs.get("ingress", {})
|
||||
assert ingress_data.get("parsed_intent") == "greeting"
|
||||
return {"hook_result": "continue"}
|
||||
return {"hook_result": "continue"}
|
||||
async def slow_handler(args: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""模拟超时处理器。"""
|
||||
|
||||
result, _, ctx = await executor.execute(mock_invoke, message={"plain_text": "hi"})
|
||||
assert result.status == "completed"
|
||||
assert ctx.get_stage_output("ingress", "parsed_intent") == "greeting"
|
||||
del args
|
||||
await asyncio.sleep(0.05)
|
||||
return {"success": True, "action": "continue"}
|
||||
|
||||
dispatcher = HookDispatcher()
|
||||
supervisor = _FakeHookSupervisor("builtin", registry, {"p1.slow": slow_handler}, call_log)
|
||||
|
||||
result = await dispatcher.invoke_hook("heart_fc.cycle_start", [supervisor], session_id="s-1")
|
||||
|
||||
assert result.aborted is False
|
||||
assert any("超时" in error for error in result.errors)
|
||||
assert call_log == [("p1", "slow")]
|
||||
|
||||
|
||||
class TestPluginRuntimeHookEntry:
|
||||
"""PluginRuntimeManager 命名 Hook 入口测试。"""
|
||||
|
||||
@staticmethod
|
||||
def _import_manager_modules(monkeypatch: pytest.MonkeyPatch) -> tuple[Any, Any]:
|
||||
"""导入运行时管理器相关模块,并屏蔽配置初始化触发的退出。
|
||||
|
||||
Args:
|
||||
monkeypatch: pytest 的 monkeypatch 工具。
|
||||
|
||||
Returns:
|
||||
tuple[Any, Any]: `ComponentRegistry` 与 `PluginRuntimeManager` 类型。
|
||||
"""
|
||||
|
||||
monkeypatch.setattr(sys, "exit", lambda code=0: None)
|
||||
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
||||
from src.plugin_runtime.integration import PluginRuntimeManager
|
||||
|
||||
return ComponentRegistry, PluginRuntimeManager
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_manager_invoke_hook_dispatches_across_supervisors(
|
||||
self,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""PluginRuntimeManager.invoke_hook() 应调用全局 Hook 分发器。"""
|
||||
|
||||
ComponentRegistry, PluginRuntimeManager = self._import_manager_modules(monkeypatch)
|
||||
|
||||
builtin_registry = ComponentRegistry()
|
||||
builtin_registry.register_component(
|
||||
"builtin_guard",
|
||||
"HOOK_HANDLER",
|
||||
"b1",
|
||||
{"hook": "heart_fc.cycle_start", "mode": "blocking", "order": "early"},
|
||||
)
|
||||
third_registry = ComponentRegistry()
|
||||
third_registry.register_component(
|
||||
"observer",
|
||||
"HOOK_HANDLER",
|
||||
"t1",
|
||||
{"hook": "heart_fc.cycle_start", "mode": "observe", "order": "normal"},
|
||||
)
|
||||
|
||||
call_log: List[tuple[str, str]] = []
|
||||
manager = PluginRuntimeManager()
|
||||
manager._started = True
|
||||
manager._builtin_supervisor = _FakeHookSupervisor(
|
||||
"builtin",
|
||||
builtin_registry,
|
||||
{"b1.builtin_guard": lambda args: {"success": True, "action": "continue"}},
|
||||
call_log,
|
||||
)
|
||||
manager._third_party_supervisor = _FakeHookSupervisor(
|
||||
"third_party",
|
||||
third_registry,
|
||||
{"t1.observer": lambda args: {"success": True, "action": "continue"}},
|
||||
call_log,
|
||||
)
|
||||
|
||||
result = await manager.invoke_dispatcher.invoke_hook("heart_fc.cycle_start", session_id="s-1")
|
||||
|
||||
await asyncio.sleep(0)
|
||||
assert manager.invoke_dispatcher is manager.hook_dispatcher
|
||||
assert result.aborted is False
|
||||
assert result.kwargs["session_id"] == "s-1"
|
||||
assert ("b1", "builtin_guard") in call_log
|
||||
|
||||
|
||||
class TestRPCServer:
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Any, Dict, List
|
||||
import pytest
|
||||
|
||||
from src.chat.message_receive.chat_manager import BotChatSession
|
||||
from src.common.data_models.message_component_data_model import MessageSequence, TextComponent
|
||||
from src.services import send_service
|
||||
|
||||
|
||||
@@ -13,42 +14,18 @@ class _FakePlatformIOManager:
|
||||
"""用于测试的 Platform IO 管理器假对象。"""
|
||||
|
||||
def __init__(self, delivery_batch: Any) -> None:
|
||||
"""初始化假 Platform IO 管理器。
|
||||
|
||||
Args:
|
||||
delivery_batch: 发送时返回的批量回执。
|
||||
"""
|
||||
self._delivery_batch = delivery_batch
|
||||
self.ensure_calls = 0
|
||||
self.sent_messages: List[Dict[str, Any]] = []
|
||||
|
||||
async def ensure_send_pipeline_ready(self) -> None:
|
||||
"""记录发送管线准备调用次数。"""
|
||||
self.ensure_calls += 1
|
||||
|
||||
def build_route_key_from_message(self, message: Any) -> Any:
|
||||
"""根据消息构造假的路由键。
|
||||
|
||||
Args:
|
||||
message: 待发送的内部消息对象。
|
||||
|
||||
Returns:
|
||||
Any: 简化后的路由键对象。
|
||||
"""
|
||||
del message
|
||||
return SimpleNamespace(platform="qq")
|
||||
|
||||
async def send_message(self, message: Any, route_key: Any, metadata: Dict[str, Any]) -> Any:
|
||||
"""记录发送请求并返回预设回执。
|
||||
|
||||
Args:
|
||||
message: 待发送的内部消息对象。
|
||||
route_key: 本次发送使用的路由键。
|
||||
metadata: 发送元数据。
|
||||
|
||||
Returns:
|
||||
Any: 预设的批量发送回执。
|
||||
"""
|
||||
self.sent_messages.append(
|
||||
{
|
||||
"message": message,
|
||||
@@ -59,12 +36,7 @@ class _FakePlatformIOManager:
|
||||
return self._delivery_batch
|
||||
|
||||
|
||||
def _build_target_stream() -> BotChatSession:
|
||||
"""构造一个最小可用的目标会话对象。
|
||||
|
||||
Returns:
|
||||
BotChatSession: 测试用会话对象。
|
||||
"""
|
||||
def _build_private_stream() -> BotChatSession:
|
||||
return BotChatSession(
|
||||
session_id="test-session",
|
||||
platform="qq",
|
||||
@@ -73,14 +45,21 @@ def _build_target_stream() -> BotChatSession:
|
||||
)
|
||||
|
||||
|
||||
def _build_group_stream() -> BotChatSession:
|
||||
return BotChatSession(
|
||||
session_id="group-session",
|
||||
platform="qq",
|
||||
user_id="target-user",
|
||||
group_id="target-group",
|
||||
)
|
||||
|
||||
|
||||
def test_inherit_platform_io_route_metadata_falls_back_to_bot_account(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""没有上下文消息时,也应回填当前平台账号用于账号级路由命中。"""
|
||||
|
||||
monkeypatch.setattr(send_service, "get_bot_account", lambda platform: "bot-qq" if platform == "qq" else "")
|
||||
|
||||
metadata = send_service._inherit_platform_io_route_metadata(_build_target_stream())
|
||||
metadata = send_service._inherit_platform_io_route_metadata(_build_private_stream())
|
||||
|
||||
assert metadata["platform_io_account_id"] == "bot-qq"
|
||||
assert metadata["platform_io_target_user_id"] == "target-user"
|
||||
@@ -88,7 +67,6 @@ def test_inherit_platform_io_route_metadata_falls_back_to_bot_account(
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_to_stream_delegates_to_platform_io(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""send service 应将发送职责统一交给 Platform IO。"""
|
||||
fake_manager = _FakePlatformIOManager(
|
||||
delivery_batch=SimpleNamespace(
|
||||
has_success=True,
|
||||
@@ -104,7 +82,7 @@ async def test_text_to_stream_delegates_to_platform_io(monkeypatch: pytest.Monke
|
||||
monkeypatch.setattr(
|
||||
send_service._chat_manager,
|
||||
"get_session_by_session_id",
|
||||
lambda stream_id: _build_target_stream() if stream_id == "test-session" else None,
|
||||
lambda stream_id: _build_private_stream() if stream_id == "test-session" else None,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
send_service.MessageUtils,
|
||||
@@ -123,7 +101,6 @@ async def test_text_to_stream_delegates_to_platform_io(monkeypatch: pytest.Monke
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_to_stream_returns_false_when_platform_io_fails(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Platform IO 批量发送全部失败时,应直接向上返回失败。"""
|
||||
fake_manager = _FakePlatformIOManager(
|
||||
delivery_batch=SimpleNamespace(
|
||||
has_success=False,
|
||||
@@ -144,7 +121,7 @@ async def test_text_to_stream_returns_false_when_platform_io_fails(monkeypatch:
|
||||
monkeypatch.setattr(
|
||||
send_service._chat_manager,
|
||||
"get_session_by_session_id",
|
||||
lambda stream_id: _build_target_stream() if stream_id == "test-session" else None,
|
||||
lambda stream_id: _build_private_stream() if stream_id == "test-session" else None,
|
||||
)
|
||||
|
||||
result = await send_service.text_to_stream(text="发送失败", stream_id="test-session")
|
||||
@@ -152,3 +129,63 @@ async def test_text_to_stream_returns_false_when_platform_io_fails(monkeypatch:
|
||||
assert result is False
|
||||
assert fake_manager.ensure_calls == 1
|
||||
assert len(fake_manager.sent_messages) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_private_outbound_message_preserves_bot_sender_and_receiver_user(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
monkeypatch.setattr(send_service, "get_bot_account", lambda platform: "bot-qq")
|
||||
monkeypatch.setattr(
|
||||
send_service._chat_manager,
|
||||
"get_session_by_session_id",
|
||||
lambda stream_id: _build_private_stream() if stream_id == "test-session" else None,
|
||||
)
|
||||
|
||||
outbound_message = send_service._build_outbound_session_message(
|
||||
message_sequence=MessageSequence(components=[TextComponent(text="你好")]),
|
||||
stream_id="test-session",
|
||||
display_message="你好",
|
||||
)
|
||||
|
||||
assert outbound_message is not None
|
||||
maim_message = await outbound_message.to_maim_message()
|
||||
|
||||
assert maim_message.message_info.user_info is not None
|
||||
assert maim_message.message_info.user_info.user_id == "bot-qq"
|
||||
assert maim_message.message_info.group_info is None
|
||||
assert maim_message.message_info.sender_info is not None
|
||||
assert maim_message.message_info.sender_info.user_info is not None
|
||||
assert maim_message.message_info.sender_info.user_info.user_id == "bot-qq"
|
||||
assert maim_message.message_info.receiver_info is not None
|
||||
assert maim_message.message_info.receiver_info.user_info is not None
|
||||
assert maim_message.message_info.receiver_info.user_info.user_id == "target-user"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_group_outbound_message_preserves_bot_sender_and_target_group(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
monkeypatch.setattr(send_service, "get_bot_account", lambda platform: "bot-qq")
|
||||
monkeypatch.setattr(
|
||||
send_service._chat_manager,
|
||||
"get_session_by_session_id",
|
||||
lambda stream_id: _build_group_stream() if stream_id == "group-session" else None,
|
||||
)
|
||||
|
||||
outbound_message = send_service._build_outbound_session_message(
|
||||
message_sequence=MessageSequence(components=[TextComponent(text="大家好")]),
|
||||
stream_id="group-session",
|
||||
display_message="大家好",
|
||||
)
|
||||
|
||||
assert outbound_message is not None
|
||||
maim_message = await outbound_message.to_maim_message()
|
||||
|
||||
assert maim_message.message_info.user_info is not None
|
||||
assert maim_message.message_info.user_info.user_id == "bot-qq"
|
||||
assert maim_message.message_info.group_info is not None
|
||||
assert maim_message.message_info.group_info.group_id == "target-group"
|
||||
assert maim_message.message_info.receiver_info is not None
|
||||
assert maim_message.message_info.receiver_info.group_info is not None
|
||||
assert maim_message.message_info.receiver_info.group_info.group_id == "target-group"
|
||||
|
||||
@@ -10,6 +10,7 @@ jieba>=0.42.1
|
||||
json-repair>=0.47.6
|
||||
maim-message>=0.6.2
|
||||
maibot-plugin-sdk>=1.2.3,<2.0.0
|
||||
mcp
|
||||
msgpack>=1.1.2
|
||||
numpy>=2.2.6
|
||||
openai>=1.95.0
|
||||
@@ -30,4 +31,4 @@ structlog>=25.4.0
|
||||
tomlkit>=0.13.3
|
||||
typing-extensions
|
||||
uvicorn>=0.35.0
|
||||
watchfiles>=1.1.1
|
||||
watchfiles>=1.1.1
|
||||
|
||||
@@ -189,7 +189,7 @@ def _run(non_interactive: bool = False) -> None: # sourcery skip: comprehension
|
||||
elif doc_item:
|
||||
with open_ie_doc_lock:
|
||||
open_ie_doc.append(doc_item)
|
||||
logger.info('已处理"%s"', doc_item.get("passage", ""))
|
||||
logger.info(f'已处理"{doc_item.get("passage", "")}"')
|
||||
progress.update(task, advance=1)
|
||||
except KeyboardInterrupt:
|
||||
logger.info("\n接收到中断信号,正在优雅地关闭程序...")
|
||||
|
||||
@@ -110,7 +110,7 @@ def run_action(action: str, extra_args: Optional[List[str]] = None) -> None:
|
||||
这里不重复解析子参数,而是直接调用各脚本的 main(),
|
||||
让子脚本保留原有的交互/参数行为。
|
||||
"""
|
||||
logger.info("开始执行操作: %s", action)
|
||||
logger.info(f"开始执行操作: {action}")
|
||||
|
||||
extra_args = extra_args or []
|
||||
|
||||
@@ -162,14 +162,14 @@ def run_action(action: str, extra_args: Optional[List[str]] = None) -> None:
|
||||
_warn_if_lpmm_disabled()
|
||||
_with_overridden_argv(extra_args, refresh_lpmm_knowledge_main)
|
||||
else:
|
||||
logger.error("未知操作: %s", action)
|
||||
logger.error(f"未知操作: {action}")
|
||||
except KeyboardInterrupt:
|
||||
logger.info("用户中断当前操作(Ctrl+C)")
|
||||
except SystemExit:
|
||||
# 子脚本里大量使用 sys.exit,直接透传即可
|
||||
raise
|
||||
except Exception as exc: # pragma: no cover - 防御性兜底
|
||||
logger.error("执行操作 %s 时发生未捕获异常: %s", action, exc)
|
||||
logger.error(f"执行操作 {action} 时发生未捕获异常: {exc}")
|
||||
raise
|
||||
|
||||
|
||||
@@ -442,7 +442,7 @@ def _run_embedding_helper() -> None:
|
||||
try:
|
||||
test_path.rename(archive_path)
|
||||
except Exception as exc: # pragma: no cover - 防御性兜底
|
||||
logger.error("归档 embedding_model_test.json 失败: %s", exc)
|
||||
logger.error(f"归档 embedding_model_test.json 失败: {exc}")
|
||||
print("[ERROR] 归档 embedding_model_test.json 失败,请检查文件权限与路径。错误详情已写入日志。")
|
||||
return
|
||||
|
||||
|
||||
@@ -1,499 +0,0 @@
|
||||
import time
|
||||
from typing import Tuple, Optional # 增加了 Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config, model_config
|
||||
import random
|
||||
from .chat_observer import ChatObserver
|
||||
from .pfc_utils import get_items_from_json
|
||||
from src.services.message_service import build_readable_messages
|
||||
|
||||
from .observation_info import ObservationInfo, dict_to_session_message
|
||||
from .conversation_info import ConversationInfo
|
||||
|
||||
|
||||
logger = get_logger("pfc_action_planner")
|
||||
|
||||
|
||||
# --- 定义 Prompt 模板 ---
|
||||
|
||||
# Prompt(1): 首次回复或非连续回复时的决策 Prompt
|
||||
PROMPT_INITIAL_REPLY = """{persona_text}。现在你在参与一场QQ私聊,请根据以下【所有信息】审慎且灵活的决策下一步行动,可以回复,可以倾听,可以调取知识,甚至可以屏蔽对方:
|
||||
|
||||
【当前对话目标】
|
||||
{goals_str}
|
||||
{knowledge_info_str}
|
||||
|
||||
【最近行动历史概要】
|
||||
{action_history_summary}
|
||||
【上一次行动的详细情况和结果】
|
||||
{last_action_context}
|
||||
【时间和超时提示】
|
||||
{time_since_last_bot_message_info}{timeout_context}
|
||||
【最近的对话记录】(包括你已成功发送的消息 和 新收到的消息)
|
||||
{chat_history_text}
|
||||
|
||||
------
|
||||
可选行动类型以及解释:
|
||||
fetch_knowledge: 需要调取知识或记忆,当需要专业知识或特定信息时选择,对方若提到你不太认识的人名或实体也可以尝试选择
|
||||
listening: 倾听对方发言,当你认为对方话才说到一半,发言明显未结束时选择
|
||||
direct_reply: 直接回复对方
|
||||
rethink_goal: 思考一个对话目标,当你觉得目前对话需要目标,或当前目标不再适用,或话题卡住时选择。注意私聊的环境是灵活的,有可能需要经常选择
|
||||
end_conversation: 结束对话,对方长时间没回复或者当你觉得对话告一段落时可以选择
|
||||
block_and_ignore: 更加极端的结束对话方式,直接结束对话并在一段时间内无视对方所有发言(屏蔽),当对话让你感到十分不适,或你遭到各类骚扰时选择
|
||||
|
||||
请以JSON格式输出你的决策:
|
||||
{{
|
||||
"action": "选择的行动类型 (必须是上面列表中的一个)",
|
||||
"reason": "选择该行动的详细原因 (必须有解释你是如何根据“上一次行动结果”、“对话记录”和自身设定人设做出合理判断的)"
|
||||
}}
|
||||
|
||||
注意:请严格按照JSON格式输出,不要包含任何其他内容。"""
|
||||
|
||||
# Prompt(2): 上一次成功回复后,决定继续发言时的决策 Prompt
|
||||
PROMPT_FOLLOW_UP = """{persona_text}。现在你在参与一场QQ私聊,刚刚你已经回复了对方,请根据以下【所有信息】审慎且灵活的决策下一步行动,可以继续发送新消息,可以等待,可以倾听,可以调取知识,甚至可以屏蔽对方:
|
||||
|
||||
【当前对话目标】
|
||||
{goals_str}
|
||||
{knowledge_info_str}
|
||||
|
||||
【最近行动历史概要】
|
||||
{action_history_summary}
|
||||
【上一次行动的详细情况和结果】
|
||||
{last_action_context}
|
||||
【时间和超时提示】
|
||||
{time_since_last_bot_message_info}{timeout_context}
|
||||
【最近的对话记录】(包括你已成功发送的消息 和 新收到的消息)
|
||||
{chat_history_text}
|
||||
|
||||
------
|
||||
可选行动类型以及解释:
|
||||
fetch_knowledge: 需要调取知识,当需要专业知识或特定信息时选择,对方若提到你不太认识的人名或实体也可以尝试选择
|
||||
wait: 暂时不说话,留给对方交互空间,等待对方回复(尤其是在你刚发言后、或上次发言因重复、发言过多被拒时、或不确定做什么时,这是不错的选择)
|
||||
listening: 倾听对方发言(虽然你刚发过言,但如果对方立刻回复且明显话没说完,可以选择这个)
|
||||
send_new_message: 发送一条新消息继续对话,允许适当的追问、补充、深入话题,或开启相关新话题。**但是避免在因重复被拒后立即使用,也不要在对方没有回复的情况下过多的“消息轰炸”或重复发言**
|
||||
rethink_goal: 思考一个对话目标,当你觉得目前对话需要目标,或当前目标不再适用,或话题卡住时选择。注意私聊的环境是灵活的,有可能需要经常选择
|
||||
end_conversation: 结束对话,对方长时间没回复或者当你觉得对话告一段落时可以选择
|
||||
block_and_ignore: 更加极端的结束对话方式,直接结束对话并在一段时间内无视对方所有发言(屏蔽),当对话让你感到十分不适,或你遭到各类骚扰时选择
|
||||
|
||||
请以JSON格式输出你的决策:
|
||||
{{
|
||||
"action": "选择的行动类型 (必须是上面列表中的一个)",
|
||||
"reason": "选择该行动的详细原因 (必须有解释你是如何根据“上一次行动结果”、“对话记录”和自身设定人设做出合理判断的。请说明你为什么选择继续发言而不是等待,以及打算发送什么类型的新消息连续发言,必须记录已经发言了几次)"
|
||||
}}
|
||||
|
||||
注意:请严格按照JSON格式输出,不要包含任何其他内容。"""
|
||||
|
||||
# 新增:Prompt(3): 决定是否在结束对话前发送告别语
|
||||
PROMPT_END_DECISION = """{persona_text}。刚刚你决定结束一场 QQ 私聊。
|
||||
|
||||
【你们之前的聊天记录】
|
||||
{chat_history_text}
|
||||
|
||||
你觉得你们的对话已经完整结束了吗?有时候,在对话自然结束后再说点什么可能会有点奇怪,但有时也可能需要一条简短的消息来圆满结束。
|
||||
如果觉得确实有必要再发一条简短、自然、符合你人设的告别消息(比如 "好,下次再聊~" 或 "嗯,先这样吧"),就输出 "yes"。
|
||||
如果觉得当前状态下直接结束对话更好,没有必要再发消息,就输出 "no"。
|
||||
|
||||
请以 JSON 格式输出你的选择:
|
||||
{{
|
||||
"say_bye": "yes/no",
|
||||
"reason": "选择 yes 或 no 的原因和内心想法 (简要说明)"
|
||||
}}
|
||||
|
||||
注意:请严格按照 JSON 格式输出,不要包含任何其他内容。"""
|
||||
|
||||
|
||||
# ActionPlanner 类定义,顶格
|
||||
class ActionPlanner:
|
||||
"""行动规划器"""
|
||||
|
||||
def __init__(self, stream_id: str, private_name: str):
|
||||
self.llm = LLMRequest(
|
||||
model_set=model_config.model_task_config.planner,
|
||||
request_type="action_planning",
|
||||
)
|
||||
self.personality_info = self._get_personality_prompt()
|
||||
self.name = global_config.bot.nickname
|
||||
self.private_name = private_name
|
||||
self.chat_observer = ChatObserver.get_instance(stream_id, private_name)
|
||||
# self.action_planner_info = ActionPlannerInfo() # 移除未使用的变量
|
||||
|
||||
def _get_personality_prompt(self) -> str:
|
||||
"""获取个性提示信息"""
|
||||
prompt_personality = global_config.personality.personality
|
||||
|
||||
# 检查是否需要随机替换为状态
|
||||
if (
|
||||
global_config.personality.states
|
||||
and global_config.personality.state_probability > 0
|
||||
and random.random() < global_config.personality.state_probability
|
||||
):
|
||||
prompt_personality = random.choice(global_config.personality.states)
|
||||
|
||||
bot_name = global_config.bot.nickname
|
||||
return f"你的名字是{bot_name},你{prompt_personality};"
|
||||
|
||||
# 修改 plan 方法签名,增加 last_successful_reply_action 参数
|
||||
async def plan(
|
||||
self,
|
||||
observation_info: ObservationInfo,
|
||||
conversation_info: ConversationInfo,
|
||||
last_successful_reply_action: Optional[str],
|
||||
) -> Tuple[str, str]:
|
||||
"""规划下一步行动
|
||||
|
||||
Args:
|
||||
observation_info: 决策信息
|
||||
conversation_info: 对话信息
|
||||
last_successful_reply_action: 上一次成功的回复动作类型 ('direct_reply' 或 'send_new_message' 或 None)
|
||||
|
||||
Returns:
|
||||
Tuple[str, str]: (行动类型, 行动原因)
|
||||
"""
|
||||
# --- 获取 Bot 上次发言时间信息 ---
|
||||
time_since_last_bot_message_info = ""
|
||||
try:
|
||||
bot_id = str(global_config.bot.qq_account)
|
||||
chat_history = getattr(observation_info, "chat_history", None)
|
||||
if chat_history and len(chat_history) > 0:
|
||||
for i in range(len(chat_history) - 1, -1, -1):
|
||||
msg = chat_history[i]
|
||||
if not isinstance(msg, dict):
|
||||
continue
|
||||
sender_info = msg.get("user_info", {})
|
||||
sender_id = str(sender_info.get("user_id")) if isinstance(sender_info, dict) else None
|
||||
msg_time = msg.get("time")
|
||||
if sender_id == bot_id and msg_time:
|
||||
time_diff = time.time() - msg_time
|
||||
if time_diff < 60.0:
|
||||
time_since_last_bot_message_info = (
|
||||
f"提示:你上一条成功发送的消息是在 {time_diff:.1f} 秒前。\n"
|
||||
)
|
||||
break
|
||||
else:
|
||||
logger.debug(f"[私聊][{self.private_name}]聊天历史为空或尚未加载,跳过 Bot 发言时间检查。")
|
||||
except Exception as e:
|
||||
logger.debug(f"[私聊][{self.private_name}]获取 Bot 上次发言时间时出错: {e}")
|
||||
|
||||
# --- 获取超时提示信息 ---
|
||||
# (这部分逻辑不变)
|
||||
timeout_context = ""
|
||||
try:
|
||||
if hasattr(conversation_info, "goal_list") and conversation_info.goal_list:
|
||||
last_goal_dict = conversation_info.goal_list[-1]
|
||||
if isinstance(last_goal_dict, dict) and "goal" in last_goal_dict:
|
||||
last_goal_text = last_goal_dict["goal"]
|
||||
if isinstance(last_goal_text, str) and "分钟,思考接下来要做什么" in last_goal_text:
|
||||
try:
|
||||
timeout_minutes_text = last_goal_text.split(",")[0].replace("你等待了", "")
|
||||
timeout_context = f"重要提示:对方已经长时间({timeout_minutes_text})没有回复你的消息了(这可能代表对方繁忙/不想回复/没注意到你的消息等情况,或在对方看来本次聊天已告一段落),请基于此情况规划下一步。\n"
|
||||
except Exception:
|
||||
timeout_context = "重要提示:对方已经长时间没有回复你的消息了(这可能代表对方繁忙/不想回复/没注意到你的消息等情况,或在对方看来本次聊天已告一段落),请基于此情况规划下一步。\n"
|
||||
else:
|
||||
logger.debug(
|
||||
f"[私聊][{self.private_name}]Conversation info goal_list is empty or not available for timeout check."
|
||||
)
|
||||
except AttributeError:
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]ConversationInfo object might not have goal_list attribute yet for timeout check."
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"[私聊][{self.private_name}]检查超时目标时出错: {e}")
|
||||
|
||||
# --- 构建通用 Prompt 参数 ---
|
||||
logger.debug(
|
||||
f"[私聊][{self.private_name}]开始规划行动:当前目标: {getattr(conversation_info, 'goal_list', '不可用')}"
|
||||
)
|
||||
|
||||
# 构建对话目标 (goals_str)
|
||||
goals_str = ""
|
||||
try:
|
||||
if hasattr(conversation_info, "goal_list") and conversation_info.goal_list:
|
||||
for goal_reason in conversation_info.goal_list:
|
||||
if isinstance(goal_reason, dict):
|
||||
goal = goal_reason.get("goal", "目标内容缺失")
|
||||
reasoning = goal_reason.get("reasoning", "没有明确原因")
|
||||
else:
|
||||
goal = str(goal_reason)
|
||||
reasoning = "没有明确原因"
|
||||
|
||||
goal = str(goal) if goal is not None else "目标内容缺失"
|
||||
reasoning = str(reasoning) if reasoning is not None else "没有明确原因"
|
||||
goals_str += f"- 目标:{goal}\n 原因:{reasoning}\n"
|
||||
|
||||
if not goals_str:
|
||||
goals_str = "- 目前没有明确对话目标,请考虑设定一个。\n"
|
||||
else:
|
||||
goals_str = "- 目前没有明确对话目标,请考虑设定一个。\n"
|
||||
except AttributeError:
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]ConversationInfo object might not have goal_list attribute yet."
|
||||
)
|
||||
goals_str = "- 获取对话目标时出错。\n"
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]构建对话目标字符串时出错: {e}")
|
||||
goals_str = "- 构建对话目标时出错。\n"
|
||||
|
||||
# --- 知识信息字符串构建开始 ---
|
||||
knowledge_info_str = "【已获取的相关知识和记忆】\n"
|
||||
try:
|
||||
# 检查 conversation_info 是否有 knowledge_list 并且不为空
|
||||
if hasattr(conversation_info, "knowledge_list") and conversation_info.knowledge_list:
|
||||
# 最多只显示最近的 5 条知识,防止 Prompt 过长
|
||||
recent_knowledge = conversation_info.knowledge_list[-5:]
|
||||
for i, knowledge_item in enumerate(recent_knowledge):
|
||||
if isinstance(knowledge_item, dict):
|
||||
query = knowledge_item.get("query", "未知查询")
|
||||
knowledge = knowledge_item.get("knowledge", "无知识内容")
|
||||
source = knowledge_item.get("source", "未知来源")
|
||||
# 只取知识内容的前 2000 个字,避免太长
|
||||
knowledge_snippet = knowledge[:2000] + "..." if len(knowledge) > 2000 else knowledge
|
||||
knowledge_info_str += (
|
||||
f"{i + 1}. 关于 '{query}' 的知识 (来源: {source}):\n {knowledge_snippet}\n"
|
||||
)
|
||||
else:
|
||||
# 处理列表里不是字典的异常情况
|
||||
knowledge_info_str += f"{i + 1}. 发现一条格式不正确的知识记录。\n"
|
||||
|
||||
if not recent_knowledge: # 如果 knowledge_list 存在但为空
|
||||
knowledge_info_str += "- 暂无相关知识和记忆。\n"
|
||||
|
||||
else:
|
||||
# 如果 conversation_info 没有 knowledge_list 属性,或者列表为空
|
||||
knowledge_info_str += "- 暂无相关知识记忆。\n"
|
||||
except AttributeError:
|
||||
logger.warning(f"[私聊][{self.private_name}]ConversationInfo 对象可能缺少 knowledge_list 属性。")
|
||||
knowledge_info_str += "- 获取知识列表时出错。\n"
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]构建知识信息字符串时出错: {e}")
|
||||
knowledge_info_str += "- 处理知识列表时出错。\n"
|
||||
# --- 知识信息字符串构建结束 ---
|
||||
|
||||
# 获取聊天历史记录 (chat_history_text)
|
||||
try:
|
||||
if hasattr(observation_info, "chat_history") and observation_info.chat_history:
|
||||
chat_history_text = observation_info.chat_history_str or "还没有聊天记录。\n"
|
||||
else:
|
||||
chat_history_text = "还没有聊天记录。\n"
|
||||
|
||||
if hasattr(observation_info, "new_messages_count") and observation_info.new_messages_count > 0:
|
||||
if hasattr(observation_info, "unprocessed_messages") and observation_info.unprocessed_messages:
|
||||
new_messages_list = observation_info.unprocessed_messages
|
||||
# Convert dict format to SessionMessage objects.
|
||||
session_messages = [dict_to_session_message(m) for m in new_messages_list]
|
||||
new_messages_str = build_readable_messages(
|
||||
session_messages,
|
||||
replace_bot_name=True,
|
||||
timestamp_mode="relative",
|
||||
read_mark=0.0,
|
||||
)
|
||||
chat_history_text += (
|
||||
f"\n--- 以下是 {observation_info.new_messages_count} 条新消息 ---\n{new_messages_str}"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]ObservationInfo has new_messages_count > 0 but unprocessed_messages is empty or missing."
|
||||
)
|
||||
except AttributeError:
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]ObservationInfo object might be missing expected attributes for chat history."
|
||||
)
|
||||
chat_history_text = "获取聊天记录时出错。\n"
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]处理聊天记录时发生未知错误: {e}")
|
||||
chat_history_text = "处理聊天记录时出错。\n"
|
||||
|
||||
# 构建 Persona 文本 (persona_text)
|
||||
persona_text = f"你的名字是{self.name},{self.personality_info}。"
|
||||
|
||||
# 构建行动历史和上一次行动结果 (action_history_summary, last_action_context)
|
||||
# (这部分逻辑不变)
|
||||
action_history_summary = "你最近执行的行动历史:\n"
|
||||
last_action_context = "关于你【上一次尝试】的行动:\n"
|
||||
action_history_list = []
|
||||
try:
|
||||
if hasattr(conversation_info, "done_action") and conversation_info.done_action:
|
||||
action_history_list = conversation_info.done_action[-5:]
|
||||
else:
|
||||
logger.debug(f"[私聊][{self.private_name}]Conversation info done_action is empty or not available.")
|
||||
except AttributeError:
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]ConversationInfo object might not have done_action attribute yet."
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]访问行动历史时出错: {e}")
|
||||
|
||||
if not action_history_list:
|
||||
action_history_summary += "- 还没有执行过行动。\n"
|
||||
last_action_context += "- 这是你规划的第一个行动。\n"
|
||||
else:
|
||||
for i, action_data in enumerate(action_history_list):
|
||||
action_type = "未知"
|
||||
plan_reason = "未知"
|
||||
status = "未知"
|
||||
final_reason = ""
|
||||
action_time = ""
|
||||
|
||||
if isinstance(action_data, dict):
|
||||
action_type = action_data.get("action", "未知")
|
||||
plan_reason = action_data.get("plan_reason", "未知规划原因")
|
||||
status = action_data.get("status", "未知")
|
||||
final_reason = action_data.get("final_reason", "")
|
||||
action_time = action_data.get("time", "")
|
||||
elif isinstance(action_data, tuple):
|
||||
# 假设旧格式兼容
|
||||
if len(action_data) > 0:
|
||||
action_type = action_data[0]
|
||||
if len(action_data) > 1:
|
||||
plan_reason = action_data[1] # 可能是规划原因或最终原因
|
||||
if len(action_data) > 2:
|
||||
status = action_data[2]
|
||||
if status == "recall" and len(action_data) > 3:
|
||||
final_reason = action_data[3]
|
||||
elif status == "done" and action_type in ["direct_reply", "send_new_message"]:
|
||||
plan_reason = "成功发送" # 简化显示
|
||||
|
||||
reason_text = f", 失败/取消原因: {final_reason}" if final_reason else ""
|
||||
summary_line = f"- 时间:{action_time}, 尝试行动:'{action_type}', 状态:{status}{reason_text}"
|
||||
action_history_summary += summary_line + "\n"
|
||||
|
||||
if i == len(action_history_list) - 1:
|
||||
last_action_context += f"- 上次【规划】的行动是: '{action_type}'\n"
|
||||
last_action_context += f"- 当时规划的【原因】是: {plan_reason}\n"
|
||||
if status == "done":
|
||||
last_action_context += "- 该行动已【成功执行】。\n"
|
||||
# 记录这次成功的行动类型,供下次决策
|
||||
# self.last_successful_action_type = action_type # 不在这里记录,由 conversation 控制
|
||||
elif status == "recall":
|
||||
last_action_context += "- 但该行动最终【未能执行/被取消】。\n"
|
||||
if final_reason:
|
||||
last_action_context += f"- 【重要】失败/取消的具体原因是: “{final_reason}”\n"
|
||||
else:
|
||||
last_action_context += "- 【重要】失败/取消原因未明确记录。\n"
|
||||
# self.last_successful_action_type = None # 行动失败,清除记录
|
||||
else:
|
||||
last_action_context += f"- 该行动当前状态: {status}\n"
|
||||
# self.last_successful_action_type = None # 非完成状态,清除记录
|
||||
|
||||
# --- 选择 Prompt ---
|
||||
if last_successful_reply_action in ["direct_reply", "send_new_message"]:
|
||||
prompt_template = PROMPT_FOLLOW_UP
|
||||
logger.debug(f"[私聊][{self.private_name}]使用 PROMPT_FOLLOW_UP (追问决策)")
|
||||
else:
|
||||
prompt_template = PROMPT_INITIAL_REPLY
|
||||
logger.debug(f"[私聊][{self.private_name}]使用 PROMPT_INITIAL_REPLY (首次/非连续回复决策)")
|
||||
|
||||
# --- 格式化最终的 Prompt ---
|
||||
prompt = prompt_template.format(
|
||||
persona_text=persona_text,
|
||||
goals_str=goals_str if goals_str.strip() else "- 目前没有明确对话目标,请考虑设定一个。",
|
||||
action_history_summary=action_history_summary,
|
||||
last_action_context=last_action_context,
|
||||
time_since_last_bot_message_info=time_since_last_bot_message_info,
|
||||
timeout_context=timeout_context,
|
||||
chat_history_text=chat_history_text if chat_history_text.strip() else "还没有聊天记录。",
|
||||
knowledge_info_str=knowledge_info_str,
|
||||
)
|
||||
|
||||
logger.debug(f"[私聊][{self.private_name}]发送到LLM的最终提示词:\n------\n{prompt}\n------")
|
||||
try:
|
||||
content, _ = await self.llm.generate_response_async(prompt)
|
||||
logger.debug(f"[私聊][{self.private_name}]LLM (行动规划) 原始返回内容: {content}")
|
||||
|
||||
# --- 初始行动规划解析 ---
|
||||
success, initial_result = get_items_from_json(
|
||||
content,
|
||||
self.private_name,
|
||||
"action",
|
||||
"reason",
|
||||
default_values={"action": "wait", "reason": "LLM返回格式错误或未提供原因,默认等待"},
|
||||
)
|
||||
|
||||
initial_action = initial_result.get("action", "wait")
|
||||
initial_reason = initial_result.get("reason", "LLM未提供原因,默认等待")
|
||||
|
||||
# 检查是否需要进行结束对话决策 ---
|
||||
if initial_action == "end_conversation":
|
||||
logger.info(f"[私聊][{self.private_name}]初步规划结束对话,进入告别决策...")
|
||||
|
||||
# 使用新的 PROMPT_END_DECISION
|
||||
end_decision_prompt = PROMPT_END_DECISION.format(
|
||||
persona_text=persona_text, # 复用之前的 persona_text
|
||||
chat_history_text=chat_history_text, # 复用之前的 chat_history_text
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"[私聊][{self.private_name}]发送到LLM的结束决策提示词:\n------\n{end_decision_prompt}\n------"
|
||||
)
|
||||
try:
|
||||
end_content, _ = await self.llm.generate_response_async(end_decision_prompt) # 再次调用LLM
|
||||
logger.debug(f"[私聊][{self.private_name}]LLM (结束决策) 原始返回内容: {end_content}")
|
||||
|
||||
# 解析结束决策的JSON
|
||||
end_success, end_result = get_items_from_json(
|
||||
end_content,
|
||||
self.private_name,
|
||||
"say_bye",
|
||||
"reason",
|
||||
default_values={"say_bye": "no", "reason": "结束决策LLM返回格式错误,默认不告别"},
|
||||
required_types={"say_bye": str, "reason": str}, # 明确类型
|
||||
)
|
||||
|
||||
say_bye_decision = end_result.get("say_bye", "no").lower() # 转小写方便比较
|
||||
end_decision_reason = end_result.get("reason", "未提供原因")
|
||||
|
||||
if end_success and say_bye_decision == "yes":
|
||||
# 决定要告别,返回新的 'say_goodbye' 动作
|
||||
logger.info(
|
||||
f"[私聊][{self.private_name}]结束决策: yes, 准备生成告别语. 原因: {end_decision_reason}"
|
||||
)
|
||||
# 注意:这里的 reason 可以考虑拼接初始原因和结束决策原因,或者只用结束决策原因
|
||||
final_action = "say_goodbye"
|
||||
final_reason = f"决定发送告别语。决策原因: {end_decision_reason} (原结束理由: {initial_reason})"
|
||||
return final_action, final_reason
|
||||
else:
|
||||
# 决定不告别 (包括解析失败或明确说no)
|
||||
logger.info(
|
||||
f"[私聊][{self.private_name}]结束决策: no, 直接结束对话. 原因: {end_decision_reason}"
|
||||
)
|
||||
# 返回原始的 'end_conversation' 动作
|
||||
final_action = "end_conversation"
|
||||
final_reason = initial_reason # 保持原始的结束理由
|
||||
return final_action, final_reason
|
||||
|
||||
except Exception as end_e:
|
||||
logger.error(f"[私聊][{self.private_name}]调用结束决策LLM或处理结果时出错: {str(end_e)}")
|
||||
# 出错时,默认执行原始的结束对话
|
||||
logger.warning(f"[私聊][{self.private_name}]结束决策出错,将按原计划执行 end_conversation")
|
||||
return "end_conversation", initial_reason # 返回原始动作和原因
|
||||
|
||||
else:
|
||||
action = initial_action
|
||||
reason = initial_reason
|
||||
|
||||
# 验证action类型 (保持不变)
|
||||
valid_actions = [
|
||||
"direct_reply",
|
||||
"send_new_message",
|
||||
"fetch_knowledge",
|
||||
"wait",
|
||||
"listening",
|
||||
"rethink_goal",
|
||||
"end_conversation", # 仍然需要验证,因为可能从上面决策后返回
|
||||
"block_and_ignore",
|
||||
"say_goodbye", # 也要验证这个新动作
|
||||
]
|
||||
if action not in valid_actions:
|
||||
logger.warning(f"[私聊][{self.private_name}]LLM返回了未知的行动类型: '{action}',强制改为 wait")
|
||||
reason = f"(原始行动'{action}'无效,已强制改为wait) {reason}"
|
||||
action = "wait"
|
||||
|
||||
logger.info(f"[私聊][{self.private_name}]规划的行动: {action}")
|
||||
logger.info(f"[私聊][{self.private_name}]行动原因: {reason}")
|
||||
return action, reason
|
||||
|
||||
except Exception as e:
|
||||
# 外层异常处理保持不变
|
||||
logger.error(f"[私聊][{self.private_name}]规划行动时调用 LLM 或处理结果出错: {str(e)}")
|
||||
return "wait", f"行动规划处理中发生错误,暂时等待: {str(e)}"
|
||||
@@ -1,394 +0,0 @@
|
||||
import time
|
||||
import asyncio
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, Any, List
|
||||
from src.common.logger import get_logger
|
||||
from sqlmodel import select, col
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import Messages
|
||||
from maim_message import UserInfo
|
||||
from src.config.config import global_config
|
||||
from .chat_states import NotificationManager, create_new_message_notification, create_cold_chat_notification
|
||||
from rich.traceback import install
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
logger = get_logger("chat_observer")
|
||||
|
||||
|
||||
def _message_to_dict(message: Messages) -> Dict[str, Any]:
|
||||
"""Convert Peewee Message model to dict for PFC compatibility
|
||||
|
||||
Args:
|
||||
message: Peewee Messages model instance
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Message dictionary
|
||||
"""
|
||||
message_timestamp = message.timestamp.timestamp() if isinstance(message.timestamp, datetime) else message.timestamp
|
||||
return {
|
||||
"message_id": message.message_id,
|
||||
"time": message_timestamp,
|
||||
"chat_id": message.session_id,
|
||||
"user_id": message.user_id,
|
||||
"user_nickname": message.user_nickname,
|
||||
"processed_plain_text": message.processed_plain_text,
|
||||
"display_message": message.display_message,
|
||||
"is_mentioned": message.is_mentioned,
|
||||
"is_command": message.is_command,
|
||||
# Add user_info dict for compatibility with existing code
|
||||
"user_info": {
|
||||
"user_id": message.user_id,
|
||||
"user_nickname": message.user_nickname,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class ChatObserver:
|
||||
"""聊天状态观察器"""
|
||||
|
||||
# 类级别的实例管理
|
||||
_instances: Dict[str, "ChatObserver"] = {}
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls, stream_id: str, private_name: str) -> "ChatObserver":
|
||||
"""获取或创建观察器实例
|
||||
|
||||
Args:
|
||||
stream_id: 聊天流ID
|
||||
private_name: 私聊名称
|
||||
|
||||
Returns:
|
||||
ChatObserver: 观察器实例
|
||||
"""
|
||||
if stream_id not in cls._instances:
|
||||
cls._instances[stream_id] = cls(stream_id, private_name)
|
||||
return cls._instances[stream_id]
|
||||
|
||||
def __init__(self, stream_id: str, private_name: str):
|
||||
"""初始化观察器
|
||||
|
||||
Args:
|
||||
stream_id: 聊天流ID
|
||||
"""
|
||||
self.last_check_time = None
|
||||
self.last_bot_speak_time = None
|
||||
self.last_user_speak_time = None
|
||||
if stream_id in self._instances:
|
||||
raise RuntimeError(f"ChatObserver for {stream_id} already exists. Use get_instance() instead.")
|
||||
|
||||
self.stream_id = stream_id
|
||||
self.private_name = private_name
|
||||
|
||||
self.last_message_read: Optional[Dict[str, Any]] = None
|
||||
self.last_message_time: float = time.time()
|
||||
|
||||
self.waiting_start_time: float = time.time() # 等待开始时间,初始化为当前时间
|
||||
|
||||
# 运行状态
|
||||
self._running: bool = False
|
||||
self._task: Optional[asyncio.Task] = None
|
||||
self._update_event = asyncio.Event() # 触发更新的事件
|
||||
self._update_complete = asyncio.Event() # 更新完成的事件
|
||||
|
||||
# 通知管理器
|
||||
self.notification_manager = NotificationManager()
|
||||
|
||||
# 冷场检查配置
|
||||
self.cold_chat_threshold: float = 60.0 # 60秒无消息判定为冷场
|
||||
self.last_cold_chat_check: float = time.time()
|
||||
self.is_cold_chat_state: bool = False
|
||||
|
||||
self.update_event = asyncio.Event()
|
||||
self.update_interval = 2 # 更新间隔(秒)
|
||||
self.message_cache = []
|
||||
self.update_running = False
|
||||
|
||||
async def check(self) -> bool:
|
||||
"""检查距离上一次观察之后是否有了新消息
|
||||
|
||||
Returns:
|
||||
bool: 是否有新消息
|
||||
"""
|
||||
logger.debug(f"[私聊][{self.private_name}]检查距离上一次观察之后是否有了新消息: {self.last_check_time}")
|
||||
|
||||
last_check_time = self.last_check_time or 0.0
|
||||
last_check_dt = datetime.fromtimestamp(last_check_time)
|
||||
with get_db_session() as session:
|
||||
statement = select(Messages).where(
|
||||
(col(Messages.session_id) == self.stream_id) & (col(Messages.timestamp) > last_check_dt)
|
||||
)
|
||||
new_message_exists = session.exec(statement).first() is not None
|
||||
|
||||
if new_message_exists:
|
||||
logger.debug(f"[私聊][{self.private_name}]发现新消息")
|
||||
self.last_check_time = time.time()
|
||||
|
||||
return new_message_exists
|
||||
|
||||
async def _add_message_to_history(self, message: Dict[str, Any]):
|
||||
"""添加消息到历史记录并发送通知
|
||||
|
||||
Args:
|
||||
message: 消息数据
|
||||
"""
|
||||
try:
|
||||
# 发送新消息通知
|
||||
notification = create_new_message_notification(
|
||||
sender="chat_observer", target="observation_info", message=message
|
||||
)
|
||||
# print(self.notification_manager)
|
||||
await self.notification_manager.send_notification(notification)
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]添加消息到历史记录时出错: {e}")
|
||||
print(traceback.format_exc())
|
||||
|
||||
# 检查并更新冷场状态
|
||||
await self._check_cold_chat()
|
||||
|
||||
async def _check_cold_chat(self):
|
||||
"""检查是否处于冷场状态并发送通知"""
|
||||
current_time = time.time()
|
||||
|
||||
# 每10秒检查一次冷场状态
|
||||
if current_time - self.last_cold_chat_check < 10:
|
||||
return
|
||||
|
||||
self.last_cold_chat_check = current_time
|
||||
|
||||
# 判断是否冷场
|
||||
is_cold = (
|
||||
True
|
||||
if self.last_message_time is None
|
||||
else (current_time - self.last_message_time) > self.cold_chat_threshold
|
||||
)
|
||||
|
||||
# 如果冷场状态发生变化,发送通知
|
||||
if is_cold != self.is_cold_chat_state:
|
||||
self.is_cold_chat_state = is_cold
|
||||
notification = create_cold_chat_notification(sender="chat_observer", target="pfc", is_cold=is_cold)
|
||||
await self.notification_manager.send_notification(notification)
|
||||
|
||||
def new_message_after(self, time_point: float) -> bool:
|
||||
"""判断是否在指定时间点后有新消息
|
||||
|
||||
Args:
|
||||
time_point: 时间戳
|
||||
|
||||
Returns:
|
||||
bool: 是否有新消息
|
||||
"""
|
||||
|
||||
if self.last_message_time is None:
|
||||
logger.debug(f"[私聊][{self.private_name}]没有最后消息时间,返回 False")
|
||||
return False
|
||||
|
||||
has_new = self.last_message_time > time_point
|
||||
logger.debug(
|
||||
f"[私聊][{self.private_name}]判断是否在指定时间点后有新消息: {self.last_message_time} > {time_point} = {has_new}"
|
||||
)
|
||||
return has_new
|
||||
|
||||
async def _fetch_new_messages(self) -> List[Dict[str, Any]]:
|
||||
"""获取新消息
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 新消息列表
|
||||
"""
|
||||
last_message_time = self.last_message_time or 0.0
|
||||
last_message_dt = datetime.fromtimestamp(last_message_time)
|
||||
with get_db_session() as session:
|
||||
statement = (
|
||||
select(Messages)
|
||||
.where((col(Messages.session_id) == self.stream_id) & (col(Messages.timestamp) > last_message_dt))
|
||||
.order_by(col(Messages.timestamp))
|
||||
)
|
||||
new_messages = [_message_to_dict(msg) for msg in session.exec(statement).all()]
|
||||
|
||||
if new_messages:
|
||||
self.last_message_read = new_messages[-1]
|
||||
self.last_message_time = new_messages[-1]["time"]
|
||||
|
||||
# print(f"获取数据库中找到的新消息: {new_messages}")
|
||||
|
||||
return new_messages
|
||||
|
||||
async def _fetch_new_messages_before(self, time_point: float) -> List[Dict[str, Any]]:
|
||||
"""获取指定时间点之前的消息
|
||||
|
||||
Args:
|
||||
time_point: 时间戳
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 最多5条消息
|
||||
"""
|
||||
time_point_dt = datetime.fromtimestamp(time_point)
|
||||
with get_db_session() as session:
|
||||
statement = (
|
||||
select(Messages)
|
||||
.where((col(Messages.session_id) == self.stream_id) & (col(Messages.timestamp) < time_point_dt))
|
||||
.order_by(col(Messages.timestamp))
|
||||
.limit(5)
|
||||
)
|
||||
messages = list(session.exec(statement).all())
|
||||
messages.reverse()
|
||||
new_messages = [_message_to_dict(msg) for msg in messages]
|
||||
|
||||
if new_messages:
|
||||
self.last_message_read = new_messages[-1]["message_id"]
|
||||
|
||||
logger.debug(f"[私聊][{self.private_name}]获取指定时间点111之前的消息: {new_messages}")
|
||||
|
||||
return new_messages
|
||||
|
||||
"""主要观察循环"""
|
||||
|
||||
async def _update_loop(self):
|
||||
"""更新循环"""
|
||||
# try:
|
||||
# start_time = time.time()
|
||||
# messages = await self._fetch_new_messages_before(start_time)
|
||||
# for message in messages:
|
||||
# await self._add_message_to_history(message)
|
||||
# logger.debug(f"[私聊][{self.private_name}]缓冲消息: {messages}")
|
||||
# except Exception as e:
|
||||
# logger.error(f"[私聊][{self.private_name}]缓冲消息出错: {e}")
|
||||
|
||||
while self._running:
|
||||
try:
|
||||
# 等待事件或超时(1秒)
|
||||
try:
|
||||
# print("等待事件")
|
||||
await asyncio.wait_for(self._update_event.wait(), timeout=1)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
# print("超时")
|
||||
pass # 超时后也执行一次检查
|
||||
|
||||
self._update_event.clear() # 重置触发事件
|
||||
self._update_complete.clear() # 重置完成事件
|
||||
|
||||
# 获取新消息
|
||||
new_messages = await self._fetch_new_messages()
|
||||
|
||||
if new_messages:
|
||||
# 处理新消息
|
||||
for message in new_messages:
|
||||
await self._add_message_to_history(message)
|
||||
|
||||
# 设置完成事件
|
||||
self._update_complete.set()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]更新循环出错: {e}")
|
||||
logger.error(f"[私聊][{self.private_name}]{traceback.format_exc()}")
|
||||
self._update_complete.set() # 即使出错也要设置完成事件
|
||||
|
||||
def trigger_update(self):
|
||||
"""触发一次立即更新"""
|
||||
self._update_event.set()
|
||||
|
||||
async def wait_for_update(self, timeout: float = 5.0) -> bool:
|
||||
"""等待更新完成
|
||||
|
||||
Args:
|
||||
timeout: 超时时间(秒)
|
||||
|
||||
Returns:
|
||||
bool: 是否成功完成更新(False表示超时)
|
||||
"""
|
||||
try:
|
||||
await asyncio.wait_for(self._update_complete.wait(), timeout=timeout)
|
||||
return True
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"[私聊][{self.private_name}]等待更新完成超时({timeout}秒)")
|
||||
return False
|
||||
|
||||
def start(self):
|
||||
"""启动观察器"""
|
||||
if self._running:
|
||||
return
|
||||
|
||||
self._running = True
|
||||
self._task = asyncio.create_task(self._update_loop())
|
||||
logger.debug(f"[私聊][{self.private_name}]ChatObserver for {self.stream_id} started")
|
||||
|
||||
def stop(self):
|
||||
"""停止观察器"""
|
||||
self._running = False
|
||||
self._update_event.set() # 设置事件以解除等待
|
||||
self._update_complete.set() # 设置完成事件以解除等待
|
||||
if self._task:
|
||||
self._task.cancel()
|
||||
logger.debug(f"[私聊][{self.private_name}]ChatObserver for {self.stream_id} stopped")
|
||||
|
||||
async def process_chat_history(self, messages: list):
|
||||
"""处理聊天历史
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
"""
|
||||
self.update_check_time()
|
||||
|
||||
for msg in messages:
|
||||
try:
|
||||
user_info = UserInfo.from_dict(msg.get("user_info", {}))
|
||||
if user_info.user_id == global_config.bot.qq_account:
|
||||
self.update_bot_speak_time(msg["time"])
|
||||
else:
|
||||
self.update_user_speak_time(msg["time"])
|
||||
except Exception as e:
|
||||
logger.warning(f"[私聊][{self.private_name}]处理消息时间时出错: {e}")
|
||||
continue
|
||||
|
||||
def update_check_time(self):
|
||||
"""更新查看时间"""
|
||||
self.last_check_time = time.time()
|
||||
|
||||
def update_bot_speak_time(self, speak_time: Optional[float] = None):
|
||||
"""更新机器人说话时间"""
|
||||
self.last_bot_speak_time = speak_time or time.time()
|
||||
|
||||
def update_user_speak_time(self, speak_time: Optional[float] = None):
|
||||
"""更新用户说话时间"""
|
||||
self.last_user_speak_time = speak_time or time.time()
|
||||
|
||||
def get_time_info(self) -> str:
|
||||
"""获取时间信息文本"""
|
||||
current_time = time.time()
|
||||
time_info = ""
|
||||
|
||||
if self.last_bot_speak_time:
|
||||
bot_speak_ago = current_time - self.last_bot_speak_time
|
||||
time_info += f"\n距离你上次发言已经过去了{int(bot_speak_ago)}秒"
|
||||
|
||||
if self.last_user_speak_time:
|
||||
user_speak_ago = current_time - self.last_user_speak_time
|
||||
time_info += f"\n距离对方上次发言已经过去了{int(user_speak_ago)}秒"
|
||||
|
||||
return time_info
|
||||
|
||||
def get_cached_messages(self, limit: int = 50) -> List[Dict[str, Any]]:
|
||||
"""获取缓存的消息历史
|
||||
|
||||
Args:
|
||||
limit: 获取的最大消息数量,默认50
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 缓存的消息历史列表
|
||||
"""
|
||||
return self.message_cache[-limit:]
|
||||
|
||||
def get_last_message(self) -> Optional[Dict[str, Any]]:
|
||||
"""获取最后一条消息
|
||||
|
||||
Returns:
|
||||
Optional[Dict[str, Any]]: 最后一条消息,如果没有则返回None
|
||||
"""
|
||||
if not self.message_cache:
|
||||
return None
|
||||
return self.message_cache[-1]
|
||||
|
||||
def __str__(self):
|
||||
return f"ChatObserver for {self.stream_id}"
|
||||
@@ -1,290 +0,0 @@
|
||||
from enum import Enum, auto
|
||||
from typing import Optional, Dict, Any, List, Set
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class ChatState(Enum):
|
||||
"""聊天状态枚举"""
|
||||
|
||||
NORMAL = auto() # 正常状态
|
||||
NEW_MESSAGE = auto() # 有新消息
|
||||
COLD_CHAT = auto() # 冷场状态
|
||||
ACTIVE_CHAT = auto() # 活跃状态
|
||||
BOT_SPEAKING = auto() # 机器人正在说话
|
||||
USER_SPEAKING = auto() # 用户正在说话
|
||||
SILENT = auto() # 沉默状态
|
||||
ERROR = auto() # 错误状态
|
||||
|
||||
|
||||
class NotificationType(Enum):
|
||||
"""通知类型枚举"""
|
||||
|
||||
NEW_MESSAGE = auto() # 新消息通知
|
||||
COLD_CHAT = auto() # 冷场通知
|
||||
ACTIVE_CHAT = auto() # 活跃通知
|
||||
BOT_SPEAKING = auto() # 机器人说话通知
|
||||
USER_SPEAKING = auto() # 用户说话通知
|
||||
MESSAGE_DELETED = auto() # 消息删除通知
|
||||
USER_JOINED = auto() # 用户加入通知
|
||||
USER_LEFT = auto() # 用户离开通知
|
||||
ERROR = auto() # 错误通知
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatStateInfo:
|
||||
"""聊天状态信息"""
|
||||
|
||||
state: ChatState
|
||||
last_message_time: Optional[float] = None
|
||||
last_message_content: Optional[str] = None
|
||||
last_speaker: Optional[str] = None
|
||||
message_count: int = 0
|
||||
cold_duration: float = 0.0 # 冷场持续时间(秒)
|
||||
active_duration: float = 0.0 # 活跃持续时间(秒)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Notification:
|
||||
"""通知基类"""
|
||||
|
||||
type: NotificationType
|
||||
timestamp: float
|
||||
sender: str # 发送者标识
|
||||
target: str # 接收者标识
|
||||
data: Dict[str, Any]
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""转换为字典格式"""
|
||||
return {"type": self.type.name, "timestamp": self.timestamp, "data": self.data}
|
||||
|
||||
|
||||
@dataclass
|
||||
class StateNotification(Notification):
|
||||
"""持续状态通知"""
|
||||
|
||||
is_active: bool = True
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
base_dict = super().to_dict()
|
||||
base_dict["is_active"] = self.is_active
|
||||
return base_dict
|
||||
|
||||
|
||||
class NotificationHandler(ABC):
|
||||
"""通知处理器接口"""
|
||||
|
||||
@abstractmethod
|
||||
async def handle_notification(self, notification: Notification):
|
||||
"""处理通知"""
|
||||
pass
|
||||
|
||||
|
||||
class NotificationManager:
|
||||
"""通知管理器"""
|
||||
|
||||
def __init__(self):
|
||||
# 按接收者和通知类型存储处理器
|
||||
self._handlers: Dict[str, Dict[NotificationType, List[NotificationHandler]]] = {}
|
||||
self._active_states: Set[NotificationType] = set()
|
||||
self._notification_history: List[Notification] = []
|
||||
|
||||
def register_handler(self, target: str, notification_type: NotificationType, handler: NotificationHandler):
|
||||
"""注册通知处理器
|
||||
|
||||
Args:
|
||||
target: 接收者标识(例如:"pfc")
|
||||
notification_type: 要处理的通知类型
|
||||
handler: 处理器实例
|
||||
"""
|
||||
if target not in self._handlers:
|
||||
self._handlers[target] = {}
|
||||
if notification_type not in self._handlers[target]:
|
||||
self._handlers[target][notification_type] = []
|
||||
# print(self._handlers[target][notification_type])
|
||||
self._handlers[target][notification_type].append(handler)
|
||||
# print(self._handlers[target][notification_type])
|
||||
|
||||
def unregister_handler(self, target: str, notification_type: NotificationType, handler: NotificationHandler):
|
||||
"""注销通知处理器
|
||||
|
||||
Args:
|
||||
target: 接收者标识
|
||||
notification_type: 通知类型
|
||||
handler: 要注销的处理器实例
|
||||
"""
|
||||
if target in self._handlers and notification_type in self._handlers[target]:
|
||||
handlers = self._handlers[target][notification_type]
|
||||
if handler in handlers:
|
||||
handlers.remove(handler)
|
||||
# 如果该类型的处理器列表为空,删除该类型
|
||||
if not handlers:
|
||||
del self._handlers[target][notification_type]
|
||||
# 如果该目标没有任何处理器,删除该目标
|
||||
if not self._handlers[target]:
|
||||
del self._handlers[target]
|
||||
|
||||
async def send_notification(self, notification: Notification):
|
||||
"""发送通知"""
|
||||
self._notification_history.append(notification)
|
||||
|
||||
# 如果是状态通知,更新活跃状态
|
||||
if isinstance(notification, StateNotification):
|
||||
if notification.is_active:
|
||||
self._active_states.add(notification.type)
|
||||
else:
|
||||
self._active_states.discard(notification.type)
|
||||
|
||||
# 调用目标接收者的处理器
|
||||
target = notification.target
|
||||
if target in self._handlers:
|
||||
handlers = self._handlers[target].get(notification.type, [])
|
||||
# print(handlers)
|
||||
for handler in handlers:
|
||||
# print(f"调用处理器: {handler}")
|
||||
await handler.handle_notification(notification)
|
||||
|
||||
def get_active_states(self) -> Set[NotificationType]:
|
||||
"""获取当前活跃的状态"""
|
||||
return self._active_states.copy()
|
||||
|
||||
def is_state_active(self, state_type: NotificationType) -> bool:
|
||||
"""检查特定状态是否活跃"""
|
||||
return state_type in self._active_states
|
||||
|
||||
def get_notification_history(
|
||||
self, sender: Optional[str] = None, target: Optional[str] = None, limit: Optional[int] = None
|
||||
) -> List[Notification]:
|
||||
"""获取通知历史
|
||||
|
||||
Args:
|
||||
sender: 过滤特定发送者的通知
|
||||
target: 过滤特定接收者的通知
|
||||
limit: 限制返回数量
|
||||
"""
|
||||
history = self._notification_history
|
||||
|
||||
if sender:
|
||||
history = [n for n in history if n.sender == sender]
|
||||
if target:
|
||||
history = [n for n in history if n.target == target]
|
||||
|
||||
if limit is not None:
|
||||
history = history[-limit:]
|
||||
|
||||
return history
|
||||
|
||||
def __str__(self):
|
||||
str = ""
|
||||
for target, handlers in self._handlers.items():
|
||||
for notification_type, handler_list in handlers.items():
|
||||
str += f"NotificationManager for {target} {notification_type} {handler_list}"
|
||||
return str
|
||||
|
||||
|
||||
# 一些常用的通知创建函数
|
||||
def create_new_message_notification(sender: str, target: str, message: Dict[str, Any]) -> Notification:
|
||||
"""创建新消息通知"""
|
||||
return Notification(
|
||||
type=NotificationType.NEW_MESSAGE,
|
||||
timestamp=datetime.now().timestamp(),
|
||||
sender=sender,
|
||||
target=target,
|
||||
data={
|
||||
"message_id": message.get("message_id"),
|
||||
"processed_plain_text": message.get("processed_plain_text"),
|
||||
"detailed_plain_text": message.get("detailed_plain_text"),
|
||||
"user_info": message.get("user_info"),
|
||||
"time": message.get("time"),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def create_cold_chat_notification(sender: str, target: str, is_cold: bool) -> StateNotification:
|
||||
"""创建冷场状态通知"""
|
||||
return StateNotification(
|
||||
type=NotificationType.COLD_CHAT,
|
||||
timestamp=datetime.now().timestamp(),
|
||||
sender=sender,
|
||||
target=target,
|
||||
data={"is_cold": is_cold},
|
||||
is_active=is_cold,
|
||||
)
|
||||
|
||||
|
||||
def create_active_chat_notification(sender: str, target: str, is_active: bool) -> StateNotification:
|
||||
"""创建活跃状态通知"""
|
||||
return StateNotification(
|
||||
type=NotificationType.ACTIVE_CHAT,
|
||||
timestamp=datetime.now().timestamp(),
|
||||
sender=sender,
|
||||
target=target,
|
||||
data={"is_active": is_active},
|
||||
is_active=is_active,
|
||||
)
|
||||
|
||||
|
||||
class ChatStateManager:
|
||||
"""聊天状态管理器"""
|
||||
|
||||
def __init__(self):
|
||||
self.current_state = ChatState.NORMAL
|
||||
self.state_info = ChatStateInfo(state=ChatState.NORMAL)
|
||||
self.state_history: list[ChatStateInfo] = []
|
||||
|
||||
def update_state(self, new_state: ChatState, **kwargs):
|
||||
"""更新聊天状态
|
||||
|
||||
Args:
|
||||
new_state: 新的状态
|
||||
**kwargs: 其他状态信息
|
||||
"""
|
||||
self.current_state = new_state
|
||||
self.state_info.state = new_state
|
||||
|
||||
# 更新其他状态信息
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(self.state_info, key):
|
||||
setattr(self.state_info, key, value)
|
||||
|
||||
# 记录状态历史
|
||||
self.state_history.append(self.state_info)
|
||||
|
||||
def get_current_state_info(self) -> ChatStateInfo:
|
||||
"""获取当前状态信息"""
|
||||
return self.state_info
|
||||
|
||||
def get_state_history(self) -> list[ChatStateInfo]:
|
||||
"""获取状态历史"""
|
||||
return self.state_history
|
||||
|
||||
def is_cold_chat(self, threshold: float = 60.0) -> bool:
|
||||
"""判断是否处于冷场状态
|
||||
|
||||
Args:
|
||||
threshold: 冷场阈值(秒)
|
||||
|
||||
Returns:
|
||||
bool: 是否冷场
|
||||
"""
|
||||
if not self.state_info.last_message_time:
|
||||
return True
|
||||
|
||||
current_time = datetime.now().timestamp()
|
||||
return (current_time - self.state_info.last_message_time) > threshold
|
||||
|
||||
def is_active_chat(self, threshold: float = 5.0) -> bool:
|
||||
"""判断是否处于活跃状态
|
||||
|
||||
Args:
|
||||
threshold: 活跃阈值(秒)
|
||||
|
||||
Returns:
|
||||
bool: 是否活跃
|
||||
"""
|
||||
if not self.state_info.last_message_time:
|
||||
return False
|
||||
|
||||
current_time = datetime.now().timestamp()
|
||||
return (current_time - self.state_info.last_message_time) <= threshold
|
||||
@@ -1,722 +0,0 @@
|
||||
import asyncio
|
||||
import datetime
|
||||
import time
|
||||
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
from src.common.data_models.mai_message_data_model import MaiMessage
|
||||
from src.services.message_service import build_readable_messages, get_messages_before_time_in_chat
|
||||
|
||||
# from .message_storage import MongoDBMessageStorage
|
||||
# from src.config.config import global_config
|
||||
from .pfc_types import ConversationState
|
||||
from .pfc import ChatObserver, GoalAnalyzer
|
||||
from .message_sender import DirectMessageSender
|
||||
from src.common.logger import get_logger
|
||||
from .action_planner import ActionPlanner
|
||||
from .observation_info import ObservationInfo
|
||||
from .conversation_info import ConversationInfo # 确保导入 ConversationInfo
|
||||
from .reply_generator import ReplyGenerator
|
||||
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
||||
from maim_message import UserInfo
|
||||
from .pfc_KnowledgeFetcher import KnowledgeFetcher
|
||||
from .waiter import Waiter
|
||||
|
||||
import traceback
|
||||
from rich.traceback import install
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
logger = get_logger("pfc")
|
||||
|
||||
|
||||
class Conversation:
|
||||
"""对话类,负责管理单个对话的状态和行为"""
|
||||
|
||||
def __init__(self, stream_id: str, private_name: str):
|
||||
"""初始化对话实例
|
||||
|
||||
Args:
|
||||
stream_id: 聊天流ID
|
||||
"""
|
||||
self.stream_id = stream_id
|
||||
self.private_name = private_name
|
||||
self.state = ConversationState.INIT
|
||||
self.should_continue = False
|
||||
self.ignore_until_timestamp: Optional[float] = None
|
||||
|
||||
# 回复相关
|
||||
self.generated_reply = ""
|
||||
|
||||
async def _initialize(self):
|
||||
"""初始化实例,注册所有组件"""
|
||||
|
||||
try:
|
||||
self.action_planner = ActionPlanner(self.stream_id, self.private_name)
|
||||
self.goal_analyzer = GoalAnalyzer(self.stream_id, self.private_name)
|
||||
self.reply_generator = ReplyGenerator(self.stream_id, self.private_name)
|
||||
self.knowledge_fetcher = KnowledgeFetcher(self.private_name, self.stream_id)
|
||||
self.waiter = Waiter(self.stream_id, self.private_name)
|
||||
self.direct_sender = DirectMessageSender(self.private_name)
|
||||
|
||||
# 获取聊天流信息
|
||||
self.chat_stream = _chat_manager.get_session_by_session_id(self.stream_id)
|
||||
|
||||
self.stop_action_planner = False
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]初始化对话实例:注册运行组件失败: {e}")
|
||||
logger.error(f"[私聊][{self.private_name}]{traceback.format_exc()}")
|
||||
raise
|
||||
|
||||
try:
|
||||
# 决策所需要的信息,包括自身自信和观察信息两部分
|
||||
# 注册观察器和观测信息
|
||||
self.chat_observer = ChatObserver.get_instance(self.stream_id, self.private_name)
|
||||
self.chat_observer.start()
|
||||
self.observation_info = ObservationInfo(self.private_name)
|
||||
self.observation_info.bind_to_chat_observer(self.chat_observer)
|
||||
# print(self.chat_observer.get_cached_messages(limit=)
|
||||
|
||||
self.conversation_info = ConversationInfo()
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]初始化对话实例:注册信息组件失败: {e}")
|
||||
logger.error(f"[私聊][{self.private_name}]{traceback.format_exc()}")
|
||||
raise
|
||||
try:
|
||||
logger.info(f"[私聊][{self.private_name}]为 {self.stream_id} 加载初始聊天记录...")
|
||||
initial_messages = get_messages_before_time_in_chat(
|
||||
chat_id=self.stream_id,
|
||||
timestamp=time.time(),
|
||||
limit=30, # 加载最近30条作为初始上下文,可以调整
|
||||
)
|
||||
chat_talking_prompt = build_readable_messages(
|
||||
initial_messages,
|
||||
replace_bot_name=True,
|
||||
timestamp_mode="relative",
|
||||
read_mark=0.0,
|
||||
)
|
||||
if initial_messages:
|
||||
# 将 SessionMessage 列表转换为 PFC 期望的 dict 格式(保持嵌套结构)
|
||||
initial_messages_dict: list[dict] = []
|
||||
for msg in initial_messages:
|
||||
user_info = msg.message_info.user_info
|
||||
msg_dict = {
|
||||
"message_id": msg.message_id,
|
||||
"time": msg.timestamp.timestamp(),
|
||||
"chat_id": msg.session_id,
|
||||
"processed_plain_text": msg.processed_plain_text,
|
||||
"display_message": msg.display_message,
|
||||
"is_mentioned": msg.is_mentioned,
|
||||
"is_command": msg.is_command,
|
||||
"user_info": {
|
||||
"user_id": user_info.user_id,
|
||||
"user_nickname": user_info.user_nickname,
|
||||
"user_cardname": user_info.user_cardname,
|
||||
"platform": msg.platform,
|
||||
},
|
||||
}
|
||||
initial_messages_dict.append(msg_dict)
|
||||
|
||||
# 将加载的消息填充到 ObservationInfo 的 chat_history
|
||||
self.observation_info.chat_history = initial_messages_dict
|
||||
self.observation_info.chat_history_str = chat_talking_prompt + "\n"
|
||||
self.observation_info.chat_history_count = len(initial_messages_dict)
|
||||
|
||||
# 更新 ObservationInfo 中的时间戳等信息
|
||||
last_msg_dict: dict = initial_messages_dict[-1]
|
||||
self.observation_info.last_message_time = last_msg_dict.get("time")
|
||||
last_user_info = UserInfo.from_dict(last_msg_dict.get("user_info", {}))
|
||||
self.observation_info.last_message_sender = last_user_info.user_id
|
||||
self.observation_info.last_message_content = last_msg_dict.get("processed_plain_text", "")
|
||||
|
||||
logger.info(
|
||||
f"[私聊][{self.private_name}]成功加载 {len(initial_messages_dict)} 条初始聊天记录。最后一条消息时间: {self.observation_info.last_message_time}"
|
||||
)
|
||||
|
||||
# 让 ChatObserver 从加载的最后一条消息之后开始同步
|
||||
if self.observation_info.last_message_time:
|
||||
self.chat_observer.last_message_time = self.observation_info.last_message_time
|
||||
self.chat_observer.last_message_read = last_msg_dict # 更新 observer 的最后读取记录
|
||||
else:
|
||||
logger.info(f"[私聊][{self.private_name}]没有找到初始聊天记录。")
|
||||
|
||||
except Exception as load_err:
|
||||
logger.error(f"[私聊][{self.private_name}]加载初始聊天记录时出错: {load_err}")
|
||||
# 出错也要继续,只是没有历史记录而已
|
||||
# 组件准备完成,启动该论对话
|
||||
self.should_continue = True
|
||||
asyncio.create_task(self.start())
|
||||
|
||||
async def start(self):
|
||||
"""开始对话流程"""
|
||||
try:
|
||||
logger.info(f"[私聊][{self.private_name}]对话系统启动中...")
|
||||
asyncio.create_task(self._plan_and_action_loop())
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]启动对话系统失败: {e}")
|
||||
raise
|
||||
|
||||
async def _plan_and_action_loop(self):
|
||||
"""思考步,PFC核心循环模块"""
|
||||
while self.should_continue:
|
||||
# 忽略逻辑
|
||||
if self.ignore_until_timestamp and time.time() < self.ignore_until_timestamp:
|
||||
await asyncio.sleep(30)
|
||||
continue
|
||||
elif self.ignore_until_timestamp and time.time() >= self.ignore_until_timestamp:
|
||||
logger.info(f"[私聊][{self.private_name}]忽略时间已到 {self.stream_id},准备结束对话。")
|
||||
self.ignore_until_timestamp = None
|
||||
self.should_continue = False
|
||||
continue
|
||||
try:
|
||||
# --- 在规划前记录当前新消息数量 ---
|
||||
initial_new_message_count = 0
|
||||
if hasattr(self.observation_info, "new_messages_count"):
|
||||
initial_new_message_count = self.observation_info.new_messages_count + 1 # 算上麦麦自己发的那一条
|
||||
else:
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]ObservationInfo missing 'new_messages_count' before planning."
|
||||
)
|
||||
|
||||
# --- 调用 Action Planner ---
|
||||
# 传递 self.conversation_info.last_successful_reply_action
|
||||
action, reason = await self.action_planner.plan(
|
||||
self.observation_info, self.conversation_info, self.conversation_info.last_successful_reply_action
|
||||
)
|
||||
|
||||
# --- 规划后检查是否有 *更多* 新消息到达 ---
|
||||
current_new_message_count = 0
|
||||
if hasattr(self.observation_info, "new_messages_count"):
|
||||
current_new_message_count = self.observation_info.new_messages_count
|
||||
else:
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]ObservationInfo missing 'new_messages_count' after planning."
|
||||
)
|
||||
|
||||
if current_new_message_count > initial_new_message_count + 2:
|
||||
logger.info(
|
||||
f"[私聊][{self.private_name}]规划期间发现新增消息 ({initial_new_message_count} -> {current_new_message_count}),跳过本次行动,重新规划"
|
||||
)
|
||||
# 如果规划期间有新消息,也应该重置上次回复状态,因为现在要响应新消息了
|
||||
self.conversation_info.last_successful_reply_action = None
|
||||
await asyncio.sleep(0.1)
|
||||
continue
|
||||
|
||||
# 包含 send_new_message
|
||||
if initial_new_message_count > 0 and action in ["direct_reply", "send_new_message"]:
|
||||
if hasattr(self.observation_info, "clear_unprocessed_messages"):
|
||||
logger.debug(
|
||||
f"[私聊][{self.private_name}]准备执行 {action},清理 {initial_new_message_count} 条规划时已知的新消息。"
|
||||
)
|
||||
await self.observation_info.clear_unprocessed_messages()
|
||||
if hasattr(self.observation_info, "new_messages_count"):
|
||||
self.observation_info.new_messages_count = 0
|
||||
else:
|
||||
logger.error(
|
||||
f"[私聊][{self.private_name}]无法清理未处理消息: ObservationInfo 缺少 clear_unprocessed_messages 方法!"
|
||||
)
|
||||
|
||||
await self._handle_action(action, reason, self.observation_info, self.conversation_info)
|
||||
|
||||
# 检查是否需要结束对话 (逻辑不变)
|
||||
goal_ended = False
|
||||
if hasattr(self.conversation_info, "goal_list") and self.conversation_info.goal_list:
|
||||
for goal_item in self.conversation_info.goal_list:
|
||||
if isinstance(goal_item, dict):
|
||||
current_goal = goal_item.get("goal")
|
||||
|
||||
if current_goal == "结束对话":
|
||||
goal_ended = True
|
||||
break
|
||||
|
||||
if goal_ended:
|
||||
self.should_continue = False
|
||||
logger.info(f"[私聊][{self.private_name}]检测到'结束对话'目标,停止循环。")
|
||||
|
||||
except Exception as loop_err:
|
||||
logger.error(f"[私聊][{self.private_name}]PFC主循环出错: {loop_err}")
|
||||
logger.error(f"[私聊][{self.private_name}]{traceback.format_exc()}")
|
||||
await asyncio.sleep(1)
|
||||
|
||||
if self.should_continue:
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
logger.info(f"[私聊][{self.private_name}]PFC 循环结束 for stream_id: {self.stream_id}")
|
||||
|
||||
def _check_new_messages_after_planning(self):
|
||||
"""检查在规划后是否有新消息"""
|
||||
# 检查 ObservationInfo 是否已初始化并且有 new_messages_count 属性
|
||||
if not hasattr(self, "observation_info") or not hasattr(self.observation_info, "new_messages_count"):
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]ObservationInfo 未初始化或缺少 'new_messages_count' 属性,无法检查新消息。"
|
||||
)
|
||||
return False # 或者根据需要抛出错误
|
||||
|
||||
if self.observation_info.new_messages_count > 2:
|
||||
logger.info(
|
||||
f"[私聊][{self.private_name}]生成/执行动作期间收到 {self.observation_info.new_messages_count} 条新消息,取消当前动作并重新规划"
|
||||
)
|
||||
# 如果有新消息,也应该重置上次回复状态
|
||||
if hasattr(self, "conversation_info"): # 确保 conversation_info 已初始化
|
||||
self.conversation_info.last_successful_reply_action = None
|
||||
else:
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]ConversationInfo 未初始化,无法重置 last_successful_reply_action。"
|
||||
)
|
||||
return True
|
||||
return False
|
||||
|
||||
def _convert_to_message(self, msg_dict: Dict[str, Any]) -> MaiMessage:
|
||||
"""将消息字典转换为MaiMessage对象"""
|
||||
from datetime import datetime as dt
|
||||
from src.common.data_models.mai_message_data_model import UserInfo as MaiUserInfo, MessageInfo
|
||||
from src.common.data_models.message_component_data_model import MessageSequence
|
||||
|
||||
try:
|
||||
user_info_dict = msg_dict.get("user_info", {})
|
||||
user_info = MaiUserInfo(
|
||||
user_id=user_info_dict.get("user_id", ""),
|
||||
user_nickname=user_info_dict.get("user_nickname", ""),
|
||||
user_cardname=user_info_dict.get("user_cardname"),
|
||||
)
|
||||
|
||||
msg = MaiMessage(
|
||||
message_id=msg_dict.get("message_id", f"gen_{time.time()}"),
|
||||
timestamp=dt.fromtimestamp(msg_dict.get("time", time.time())),
|
||||
)
|
||||
msg.message_info = MessageInfo(user_info=user_info)
|
||||
msg.platform = user_info_dict.get("platform", "")
|
||||
msg.session_id = self.stream_id
|
||||
msg.processed_plain_text = msg_dict.get("processed_plain_text", "")
|
||||
msg.raw_message = MessageSequence(components=[])
|
||||
msg.initialized = True
|
||||
return msg
|
||||
except Exception as e:
|
||||
logger.warning(f"[私聊][{self.private_name}]转换消息时出错: {e}")
|
||||
raise ValueError(f"无法将字典转换为 MaiMessage 对象: {e}") from e
|
||||
|
||||
async def _handle_action(
|
||||
self, action: str, reason: str, observation_info: ObservationInfo, conversation_info: ConversationInfo
|
||||
):
|
||||
"""处理规划的行动"""
|
||||
|
||||
logger.debug(f"[私聊][{self.private_name}]执行行动: {action}, 原因: {reason}")
|
||||
|
||||
# 记录action历史 (逻辑不变)
|
||||
current_action_record = {
|
||||
"action": action,
|
||||
"plan_reason": reason,
|
||||
"status": "start",
|
||||
"time": datetime.datetime.now().strftime("%H:%M:%S"),
|
||||
"final_reason": None,
|
||||
}
|
||||
# 确保 done_action 列表存在
|
||||
if not hasattr(conversation_info, "done_action"):
|
||||
conversation_info.done_action = []
|
||||
conversation_info.done_action.append(current_action_record)
|
||||
action_index = len(conversation_info.done_action) - 1
|
||||
|
||||
action_successful = False # 用于标记动作是否成功完成
|
||||
|
||||
# --- 根据不同的 action 执行 ---
|
||||
|
||||
# send_new_message 失败后执行 wait
|
||||
if action == "send_new_message":
|
||||
max_reply_attempts = 3
|
||||
reply_attempt_count = 0
|
||||
is_suitable = False
|
||||
need_replan = False
|
||||
check_reason = "未进行尝试"
|
||||
final_reply_to_send = ""
|
||||
|
||||
while reply_attempt_count < max_reply_attempts and not is_suitable:
|
||||
reply_attempt_count += 1
|
||||
logger.info(
|
||||
f"[私聊][{self.private_name}]尝试生成追问回复 (第 {reply_attempt_count}/{max_reply_attempts} 次)..."
|
||||
)
|
||||
self.state = ConversationState.GENERATING
|
||||
|
||||
# 1. 生成回复 (调用 generate 时传入 action_type)
|
||||
self.generated_reply = await self.reply_generator.generate(
|
||||
observation_info, conversation_info, action_type="send_new_message"
|
||||
)
|
||||
logger.info(
|
||||
f"[私聊][{self.private_name}]第 {reply_attempt_count} 次生成的追问回复: {self.generated_reply}"
|
||||
)
|
||||
|
||||
# 2. 检查回复 (逻辑不变)
|
||||
self.state = ConversationState.CHECKING
|
||||
try:
|
||||
current_goal_str = conversation_info.goal_list[0]["goal"] if conversation_info.goal_list else ""
|
||||
is_suitable, check_reason, need_replan = await self.reply_generator.check_reply(
|
||||
reply=self.generated_reply,
|
||||
goal=current_goal_str,
|
||||
chat_history=observation_info.chat_history,
|
||||
chat_history_str=observation_info.chat_history_str,
|
||||
retry_count=reply_attempt_count - 1,
|
||||
)
|
||||
logger.info(
|
||||
f"[私聊][{self.private_name}]第 {reply_attempt_count} 次追问检查结果: 合适={is_suitable}, 原因='{check_reason}', 需重新规划={need_replan}"
|
||||
)
|
||||
if is_suitable:
|
||||
final_reply_to_send = self.generated_reply
|
||||
break
|
||||
elif need_replan:
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]第 {reply_attempt_count} 次追问检查建议重新规划,停止尝试。原因: {check_reason}"
|
||||
)
|
||||
break
|
||||
except Exception as check_err:
|
||||
logger.error(
|
||||
f"[私聊][{self.private_name}]第 {reply_attempt_count} 次调用 ReplyChecker (追问) 时出错: {check_err}"
|
||||
)
|
||||
check_reason = f"第 {reply_attempt_count} 次检查过程出错: {check_err}"
|
||||
break
|
||||
|
||||
# 循环结束,处理最终结果
|
||||
if is_suitable:
|
||||
# 检查是否有新消息
|
||||
if self._check_new_messages_after_planning():
|
||||
logger.info(f"[私聊][{self.private_name}]生成追问回复期间收到新消息,取消发送,重新规划行动")
|
||||
conversation_info.done_action[action_index].update(
|
||||
{"status": "recall", "final_reason": f"有新消息,取消发送追问: {final_reply_to_send}"}
|
||||
)
|
||||
return # 直接返回,重新规划
|
||||
|
||||
# 发送合适的回复
|
||||
self.generated_reply = final_reply_to_send
|
||||
# --- 在这里调用 _send_reply ---
|
||||
await self._send_reply() # <--- 调用恢复后的函数
|
||||
|
||||
# 更新状态: 标记上次成功是 send_new_message
|
||||
self.conversation_info.last_successful_reply_action = "send_new_message"
|
||||
action_successful = True # 标记动作成功
|
||||
|
||||
elif need_replan:
|
||||
# 打回动作决策
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]经过 {reply_attempt_count} 次尝试,追问回复决定打回动作决策。打回原因: {check_reason}"
|
||||
)
|
||||
conversation_info.done_action[action_index].update(
|
||||
{"status": "recall", "final_reason": f"追问尝试{reply_attempt_count}次后打回: {check_reason}"}
|
||||
)
|
||||
|
||||
else:
|
||||
# 追问失败
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]经过 {reply_attempt_count} 次尝试,未能生成合适的追问回复。最终原因: {check_reason}"
|
||||
)
|
||||
conversation_info.done_action[action_index].update(
|
||||
{"status": "recall", "final_reason": f"追问尝试{reply_attempt_count}次后失败: {check_reason}"}
|
||||
)
|
||||
# 重置状态: 追问失败,下次用初始 prompt
|
||||
self.conversation_info.last_successful_reply_action = None
|
||||
|
||||
# 执行 Wait 操作
|
||||
logger.info(f"[私聊][{self.private_name}]由于无法生成合适追问回复,执行 'wait' 操作...")
|
||||
self.state = ConversationState.WAITING
|
||||
await self.waiter.wait(self.conversation_info)
|
||||
wait_action_record = {
|
||||
"action": "wait",
|
||||
"plan_reason": "因 send_new_message 多次尝试失败而执行的后备等待",
|
||||
"status": "done",
|
||||
"time": datetime.datetime.now().strftime("%H:%M:%S"),
|
||||
"final_reason": None,
|
||||
}
|
||||
conversation_info.done_action.append(wait_action_record)
|
||||
|
||||
elif action == "direct_reply":
|
||||
max_reply_attempts = 3
|
||||
reply_attempt_count = 0
|
||||
is_suitable = False
|
||||
need_replan = False
|
||||
check_reason = "未进行尝试"
|
||||
final_reply_to_send = ""
|
||||
|
||||
while reply_attempt_count < max_reply_attempts and not is_suitable:
|
||||
reply_attempt_count += 1
|
||||
logger.info(
|
||||
f"[私聊][{self.private_name}]尝试生成首次回复 (第 {reply_attempt_count}/{max_reply_attempts} 次)..."
|
||||
)
|
||||
self.state = ConversationState.GENERATING
|
||||
|
||||
# 1. 生成回复
|
||||
self.generated_reply = await self.reply_generator.generate(
|
||||
observation_info, conversation_info, action_type="direct_reply"
|
||||
)
|
||||
logger.info(
|
||||
f"[私聊][{self.private_name}]第 {reply_attempt_count} 次生成的首次回复: {self.generated_reply}"
|
||||
)
|
||||
|
||||
# 2. 检查回复
|
||||
self.state = ConversationState.CHECKING
|
||||
try:
|
||||
current_goal_str = conversation_info.goal_list[0]["goal"] if conversation_info.goal_list else ""
|
||||
is_suitable, check_reason, need_replan = await self.reply_generator.check_reply(
|
||||
reply=self.generated_reply,
|
||||
goal=current_goal_str,
|
||||
chat_history=observation_info.chat_history,
|
||||
chat_history_str=observation_info.chat_history_str,
|
||||
retry_count=reply_attempt_count - 1,
|
||||
)
|
||||
logger.info(
|
||||
f"[私聊][{self.private_name}]第 {reply_attempt_count} 次首次回复检查结果: 合适={is_suitable}, 原因='{check_reason}', 需重新规划={need_replan}"
|
||||
)
|
||||
if is_suitable:
|
||||
final_reply_to_send = self.generated_reply
|
||||
break
|
||||
elif need_replan:
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]第 {reply_attempt_count} 次首次回复检查建议重新规划,停止尝试。原因: {check_reason}"
|
||||
)
|
||||
break
|
||||
except Exception as check_err:
|
||||
logger.error(
|
||||
f"[私聊][{self.private_name}]第 {reply_attempt_count} 次调用 ReplyChecker (首次回复) 时出错: {check_err}"
|
||||
)
|
||||
check_reason = f"第 {reply_attempt_count} 次检查过程出错: {check_err}"
|
||||
break
|
||||
|
||||
# 循环结束,处理最终结果
|
||||
if is_suitable:
|
||||
# 检查是否有新消息
|
||||
if self._check_new_messages_after_planning():
|
||||
logger.info(f"[私聊][{self.private_name}]生成首次回复期间收到新消息,取消发送,重新规划行动")
|
||||
conversation_info.done_action[action_index].update(
|
||||
{"status": "recall", "final_reason": f"有新消息,取消发送首次回复: {final_reply_to_send}"}
|
||||
)
|
||||
return # 直接返回,重新规划
|
||||
|
||||
# 发送合适的回复
|
||||
self.generated_reply = final_reply_to_send
|
||||
# --- 在这里调用 _send_reply ---
|
||||
await self._send_reply() # <--- 调用恢复后的函数
|
||||
|
||||
# 更新状态: 标记上次成功是 direct_reply
|
||||
self.conversation_info.last_successful_reply_action = "direct_reply"
|
||||
action_successful = True # 标记动作成功
|
||||
|
||||
elif need_replan:
|
||||
# 打回动作决策
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]经过 {reply_attempt_count} 次尝试,首次回复决定打回动作决策。打回原因: {check_reason}"
|
||||
)
|
||||
conversation_info.done_action[action_index].update(
|
||||
{"status": "recall", "final_reason": f"首次回复尝试{reply_attempt_count}次后打回: {check_reason}"}
|
||||
)
|
||||
|
||||
else:
|
||||
# 首次回复失败
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]经过 {reply_attempt_count} 次尝试,未能生成合适的首次回复。最终原因: {check_reason}"
|
||||
)
|
||||
conversation_info.done_action[action_index].update(
|
||||
{"status": "recall", "final_reason": f"首次回复尝试{reply_attempt_count}次后失败: {check_reason}"}
|
||||
)
|
||||
# 重置状态: 首次回复失败,下次还是用初始 prompt
|
||||
self.conversation_info.last_successful_reply_action = None
|
||||
|
||||
# 执行 Wait 操作 (保持原有逻辑)
|
||||
logger.info(f"[私聊][{self.private_name}]由于无法生成合适首次回复,执行 'wait' 操作...")
|
||||
self.state = ConversationState.WAITING
|
||||
await self.waiter.wait(self.conversation_info)
|
||||
wait_action_record = {
|
||||
"action": "wait",
|
||||
"plan_reason": "因 direct_reply 多次尝试失败而执行的后备等待",
|
||||
"status": "done",
|
||||
"time": datetime.datetime.now().strftime("%H:%M:%S"),
|
||||
"final_reason": None,
|
||||
}
|
||||
conversation_info.done_action.append(wait_action_record)
|
||||
|
||||
elif action == "fetch_knowledge":
|
||||
self.state = ConversationState.FETCHING
|
||||
knowledge_query = reason
|
||||
try:
|
||||
# 检查 knowledge_fetcher 是否存在
|
||||
if not hasattr(self, "knowledge_fetcher"):
|
||||
logger.error(f"[私聊][{self.private_name}]KnowledgeFetcher 未初始化,无法获取知识。")
|
||||
raise AttributeError("KnowledgeFetcher not initialized")
|
||||
|
||||
knowledge, source = await self.knowledge_fetcher.fetch(knowledge_query, observation_info.chat_history)
|
||||
logger.info(f"[私聊][{self.private_name}]获取到知识: {knowledge[:100]}..., 来源: {source}")
|
||||
if knowledge:
|
||||
# 确保 knowledge_list 存在
|
||||
if not hasattr(conversation_info, "knowledge_list"):
|
||||
conversation_info.knowledge_list = []
|
||||
conversation_info.knowledge_list.append(
|
||||
{"query": knowledge_query, "knowledge": knowledge, "source": source}
|
||||
)
|
||||
action_successful = True
|
||||
except Exception as fetch_err:
|
||||
logger.error(f"[私聊][{self.private_name}]获取知识时出错: {str(fetch_err)}")
|
||||
conversation_info.done_action[action_index].update(
|
||||
{"status": "recall", "final_reason": f"获取知识失败: {str(fetch_err)}"}
|
||||
)
|
||||
self.conversation_info.last_successful_reply_action = None # 重置状态
|
||||
|
||||
elif action == "rethink_goal":
|
||||
self.state = ConversationState.RETHINKING
|
||||
try:
|
||||
# 检查 goal_analyzer 是否存在
|
||||
if not hasattr(self, "goal_analyzer"):
|
||||
logger.error(f"[私聊][{self.private_name}]GoalAnalyzer 未初始化,无法重新思考目标。")
|
||||
raise AttributeError("GoalAnalyzer not initialized")
|
||||
await self.goal_analyzer.analyze_goal(conversation_info, observation_info)
|
||||
action_successful = True
|
||||
except Exception as rethink_err:
|
||||
logger.error(f"[私聊][{self.private_name}]重新思考目标时出错: {rethink_err}")
|
||||
conversation_info.done_action[action_index].update(
|
||||
{"status": "recall", "final_reason": f"重新思考目标失败: {rethink_err}"}
|
||||
)
|
||||
self.conversation_info.last_successful_reply_action = None # 重置状态
|
||||
|
||||
elif action == "listening":
|
||||
self.state = ConversationState.LISTENING
|
||||
logger.info(f"[私聊][{self.private_name}]倾听对方发言...")
|
||||
try:
|
||||
# 检查 waiter 是否存在
|
||||
if not hasattr(self, "waiter"):
|
||||
logger.error(f"[私聊][{self.private_name}]Waiter 未初始化,无法倾听。")
|
||||
raise AttributeError("Waiter not initialized")
|
||||
await self.waiter.wait_listening(conversation_info)
|
||||
action_successful = True # Listening 完成就算成功
|
||||
except Exception as listen_err:
|
||||
logger.error(f"[私聊][{self.private_name}]倾听时出错: {listen_err}")
|
||||
conversation_info.done_action[action_index].update(
|
||||
{"status": "recall", "final_reason": f"倾听失败: {listen_err}"}
|
||||
)
|
||||
self.conversation_info.last_successful_reply_action = None # 重置状态
|
||||
|
||||
elif action == "say_goodbye":
|
||||
self.state = ConversationState.GENERATING # 也可以定义一个新的状态,如 ENDING
|
||||
logger.info(f"[私聊][{self.private_name}]执行行动: 生成并发送告别语...")
|
||||
try:
|
||||
# 1. 生成告别语 (使用 'say_goodbye' action_type)
|
||||
self.generated_reply = await self.reply_generator.generate(
|
||||
observation_info, conversation_info, action_type="say_goodbye"
|
||||
)
|
||||
logger.info(f"[私聊][{self.private_name}]生成的告别语: {self.generated_reply}")
|
||||
|
||||
# 2. 直接发送告别语 (不经过检查)
|
||||
if self.generated_reply: # 确保生成了内容
|
||||
await self._send_reply() # 调用发送方法
|
||||
# 发送成功后,标记动作成功
|
||||
action_successful = True
|
||||
logger.info(f"[私聊][{self.private_name}]告别语已发送。")
|
||||
else:
|
||||
logger.warning(f"[私聊][{self.private_name}]未能生成告别语内容,无法发送。")
|
||||
action_successful = False # 标记动作失败
|
||||
conversation_info.done_action[action_index].update(
|
||||
{"status": "recall", "final_reason": "未能生成告别语内容"}
|
||||
)
|
||||
|
||||
# 3. 无论是否发送成功,都准备结束对话
|
||||
self.should_continue = False
|
||||
logger.info(f"[私聊][{self.private_name}]发送告别语流程结束,即将停止对话实例。")
|
||||
|
||||
except Exception as goodbye_err:
|
||||
logger.error(f"[私聊][{self.private_name}]生成或发送告别语时出错: {goodbye_err}")
|
||||
logger.error(f"[私聊][{self.private_name}]{traceback.format_exc()}")
|
||||
# 即使出错,也结束对话
|
||||
self.should_continue = False
|
||||
action_successful = False # 标记动作失败
|
||||
conversation_info.done_action[action_index].update(
|
||||
{"status": "recall", "final_reason": f"生成或发送告别语时出错: {goodbye_err}"}
|
||||
)
|
||||
|
||||
elif action == "end_conversation":
|
||||
# 这个分支现在只会在 action_planner 最终决定不告别时被调用
|
||||
self.should_continue = False
|
||||
logger.info(f"[私聊][{self.private_name}]收到最终结束指令,停止对话...")
|
||||
action_successful = True # 标记这个指令本身是成功的
|
||||
|
||||
elif action == "block_and_ignore":
|
||||
logger.info(f"[私聊][{self.private_name}]不想再理你了...")
|
||||
ignore_duration_seconds = 10 * 60
|
||||
self.ignore_until_timestamp = time.time() + ignore_duration_seconds
|
||||
logger.info(
|
||||
f"[私聊][{self.private_name}]将忽略此对话直到: {datetime.datetime.fromtimestamp(self.ignore_until_timestamp)}"
|
||||
)
|
||||
self.state = ConversationState.IGNORED
|
||||
action_successful = True # 标记动作成功
|
||||
|
||||
else: # 对应 'wait' 动作
|
||||
self.state = ConversationState.WAITING
|
||||
logger.info(f"[私聊][{self.private_name}]等待更多信息...")
|
||||
try:
|
||||
# 检查 waiter 是否存在
|
||||
if not hasattr(self, "waiter"):
|
||||
logger.error(f"[私聊][{self.private_name}]Waiter 未初始化,无法等待。")
|
||||
raise AttributeError("Waiter not initialized")
|
||||
_timeout_occurred = await self.waiter.wait(self.conversation_info)
|
||||
action_successful = True # Wait 完成就算成功
|
||||
except Exception as wait_err:
|
||||
logger.error(f"[私聊][{self.private_name}]等待时出错: {wait_err}")
|
||||
conversation_info.done_action[action_index].update(
|
||||
{"status": "recall", "final_reason": f"等待失败: {wait_err}"}
|
||||
)
|
||||
self.conversation_info.last_successful_reply_action = None # 重置状态
|
||||
|
||||
# --- 更新 Action History 状态 ---
|
||||
# 只有当动作本身成功时,才更新状态为 done
|
||||
if action_successful:
|
||||
conversation_info.done_action[action_index].update(
|
||||
{
|
||||
"status": "done",
|
||||
"time": datetime.datetime.now().strftime("%H:%M:%S"),
|
||||
}
|
||||
)
|
||||
# 重置状态: 对于非回复类动作的成功,清除上次回复状态
|
||||
if action not in ["direct_reply", "send_new_message"]:
|
||||
self.conversation_info.last_successful_reply_action = None
|
||||
logger.debug(f"[私聊][{self.private_name}]动作 {action} 成功完成,重置 last_successful_reply_action")
|
||||
# 如果动作是 recall 状态,在各自的处理逻辑中已经更新了 done_action
|
||||
|
||||
async def _send_reply(self):
|
||||
"""发送回复"""
|
||||
if not self.generated_reply:
|
||||
logger.warning(f"[私聊][{self.private_name}]没有生成回复内容,无法发送。")
|
||||
return
|
||||
|
||||
try:
|
||||
_current_time = time.time()
|
||||
reply_content = self.generated_reply
|
||||
|
||||
# 发送消息 (确保 direct_sender 和 chat_stream 有效)
|
||||
if not hasattr(self, "direct_sender") or not self.direct_sender:
|
||||
logger.error(f"[私聊][{self.private_name}]DirectMessageSender 未初始化,无法发送回复。")
|
||||
return
|
||||
if not self.chat_stream:
|
||||
logger.error(f"[私聊][{self.private_name}]会话未初始化,无法发送回复。")
|
||||
return
|
||||
|
||||
await self.direct_sender.send_message(chat_stream=self.chat_stream, content=reply_content)
|
||||
|
||||
# 发送成功后,手动触发 observer 更新可能导致重复处理自己发送的消息
|
||||
# 更好的做法是依赖 observer 的自动轮询或数据库触发器(如果支持)
|
||||
# 暂时注释掉,观察是否影响 ObservationInfo 的更新
|
||||
# self.chat_observer.trigger_update()
|
||||
# if not await self.chat_observer.wait_for_update():
|
||||
# logger.warning(f"[私聊][{self.private_name}]等待 ChatObserver 更新完成超时")
|
||||
|
||||
self.state = ConversationState.ANALYZING # 更新状态
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]发送消息或更新状态时失败: {str(e)}")
|
||||
logger.error(f"[私聊][{self.private_name}]{traceback.format_exc()}")
|
||||
self.state = ConversationState.ANALYZING
|
||||
|
||||
async def _send_timeout_message(self):
|
||||
"""发送超时结束消息"""
|
||||
try:
|
||||
messages = self.chat_observer.get_cached_messages(limit=1)
|
||||
if not messages:
|
||||
return
|
||||
|
||||
latest_message = self._convert_to_message(messages[0])
|
||||
await self.direct_sender.send_message(
|
||||
chat_stream=self.chat_stream, content="TODO:超时消息", reply_to_message=latest_message
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]发送超时消息失败: {str(e)}")
|
||||
@@ -1,10 +0,0 @@
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class ConversationInfo:
|
||||
def __init__(self):
|
||||
self.done_action: list = []
|
||||
self.goal_list: list = []
|
||||
self.knowledge_list: list = []
|
||||
self.memory_list: list = []
|
||||
self.last_successful_reply_action: Optional[str] = None
|
||||
@@ -1,61 +0,0 @@
|
||||
"""PFC 侧消息发送封装。"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from rich.traceback import install
|
||||
|
||||
from src.chat.message_receive.chat_manager import BotChatSession
|
||||
from src.common.data_models.mai_message_data_model import MaiMessage
|
||||
from src.common.logger import get_logger
|
||||
from src.services import send_service as send_api
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
logger = get_logger("message_sender")
|
||||
|
||||
|
||||
class DirectMessageSender:
|
||||
"""直接消息发送器。"""
|
||||
|
||||
def __init__(self, private_name: str) -> None:
|
||||
"""初始化直接消息发送器。
|
||||
|
||||
Args:
|
||||
private_name: 当前私聊实例的名称。
|
||||
"""
|
||||
self.private_name = private_name
|
||||
|
||||
async def send_message(
|
||||
self,
|
||||
chat_stream: BotChatSession,
|
||||
content: str,
|
||||
reply_to_message: Optional[MaiMessage] = None,
|
||||
) -> None:
|
||||
"""发送文本消息到聊天流。
|
||||
|
||||
Args:
|
||||
chat_stream: 目标聊天会话。
|
||||
content: 待发送的文本内容。
|
||||
reply_to_message: 可选的引用回复锚点消息。
|
||||
|
||||
Raises:
|
||||
RuntimeError: 当消息发送失败时抛出。
|
||||
"""
|
||||
try:
|
||||
sent = await send_api.text_to_stream(
|
||||
text=content,
|
||||
stream_id=chat_stream.session_id,
|
||||
set_reply=reply_to_message is not None,
|
||||
reply_message=reply_to_message,
|
||||
storage_message=True,
|
||||
)
|
||||
|
||||
if sent:
|
||||
logger.info(f"[私聊][{self.private_name}]PFC消息已发送: {content}")
|
||||
return
|
||||
|
||||
logger.error(f"[私聊][{self.private_name}]PFC消息发送失败")
|
||||
raise RuntimeError("消息发送失败")
|
||||
except Exception as exc:
|
||||
logger.error(f"[私聊][{self.private_name}]PFC消息发送失败: {exc}")
|
||||
raise
|
||||
@@ -1,429 +0,0 @@
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
|
||||
from maim_message import UserInfo
|
||||
import time
|
||||
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
from src.common.logger import get_logger
|
||||
from src.common.data_models.mai_message_data_model import MessageInfo, UserInfo as MaiUserInfo
|
||||
from src.services.message_service import build_readable_messages
|
||||
|
||||
from .chat_observer import ChatObserver
|
||||
from .chat_states import NotificationHandler, NotificationType, Notification
|
||||
import traceback # 导入 traceback 用于调试
|
||||
|
||||
logger = get_logger("observation_info")
|
||||
|
||||
|
||||
def dict_to_session_message(msg_dict: Dict[str, Any]) -> SessionMessage:
|
||||
"""Convert PFC dict format to SessionMessage object.
|
||||
|
||||
Args:
|
||||
msg_dict: Message in PFC dict format with nested user_info
|
||||
|
||||
Returns:
|
||||
SessionMessage object compatible with build_readable_messages()
|
||||
"""
|
||||
user_info_dict: Dict[str, Any] = msg_dict.get("user_info", {})
|
||||
timestamp = msg_dict.get("time", 0.0)
|
||||
platform = user_info_dict.get("platform", "")
|
||||
message = SessionMessage(
|
||||
message_id=msg_dict.get("message_id", ""),
|
||||
timestamp=datetime.fromtimestamp(timestamp),
|
||||
platform=platform,
|
||||
)
|
||||
message.message_info = MessageInfo(
|
||||
user_info=MaiUserInfo(
|
||||
user_id=user_info_dict.get("user_id", ""),
|
||||
user_nickname=user_info_dict.get("user_nickname", ""),
|
||||
user_cardname=user_info_dict.get("user_cardname"),
|
||||
)
|
||||
)
|
||||
message.session_id = msg_dict.get("chat_id", "")
|
||||
message.processed_plain_text = msg_dict.get("processed_plain_text", "")
|
||||
message.display_message = msg_dict.get("display_message", "")
|
||||
message.is_mentioned = msg_dict.get("is_mentioned", False)
|
||||
message.is_command = msg_dict.get("is_command", False)
|
||||
message.initialized = True
|
||||
return message
|
||||
|
||||
|
||||
class ObservationInfoHandler(NotificationHandler):
|
||||
"""ObservationInfo的通知处理器"""
|
||||
|
||||
def __init__(self, observation_info: "ObservationInfo", private_name: str):
|
||||
"""初始化处理器
|
||||
|
||||
Args:
|
||||
observation_info: 要更新的ObservationInfo实例
|
||||
private_name: 私聊对象的名称,用于日志记录
|
||||
"""
|
||||
self.observation_info = observation_info
|
||||
# 将 private_name 存储在 handler 实例中
|
||||
self.private_name = private_name
|
||||
|
||||
async def handle_notification(self, notification: Notification): # 添加类型提示
|
||||
# 获取通知类型和数据
|
||||
notification_type = notification.type
|
||||
data = notification.data
|
||||
|
||||
try: # 添加错误处理块
|
||||
if notification_type == NotificationType.NEW_MESSAGE:
|
||||
# 处理新消息通知
|
||||
# logger.debug(f"[私聊][{self.private_name}]收到新消息通知data: {data}") # 可以在需要时取消注释
|
||||
message_id = data.get("message_id")
|
||||
processed_plain_text = data.get("processed_plain_text")
|
||||
detailed_plain_text = data.get("detailed_plain_text")
|
||||
user_info_dict = data.get("user_info") # 先获取字典
|
||||
time_value = data.get("time")
|
||||
|
||||
# 确保 user_info 是字典类型再创建 UserInfo 对象
|
||||
user_info = None
|
||||
if isinstance(user_info_dict, dict):
|
||||
try:
|
||||
user_info = UserInfo.from_dict(user_info_dict)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[私聊][{self.private_name}]从字典创建 UserInfo 时出错: {e}, 字典内容: {user_info_dict}"
|
||||
)
|
||||
# 可以选择在这里返回或记录错误,避免后续代码出错
|
||||
return
|
||||
elif user_info_dict is not None:
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]收到的 user_info 不是预期的字典类型: {type(user_info_dict)}"
|
||||
)
|
||||
# 根据需要处理非字典情况,这里暂时返回
|
||||
return
|
||||
|
||||
message = {
|
||||
"message_id": message_id,
|
||||
"processed_plain_text": processed_plain_text,
|
||||
"detailed_plain_text": detailed_plain_text,
|
||||
"user_info": user_info_dict, # 存储原始字典或 UserInfo 对象,取决于你的 update_from_message 如何处理
|
||||
"time": time_value,
|
||||
}
|
||||
# 传递 UserInfo 对象(如果成功创建)或原始字典
|
||||
await self.observation_info.update_from_message(message, user_info) # 修改:传递 user_info 对象
|
||||
|
||||
elif notification_type == NotificationType.COLD_CHAT:
|
||||
# 处理冷场通知
|
||||
is_cold = data.get("is_cold", False)
|
||||
await self.observation_info.update_cold_chat_status(is_cold, time.time()) # 修改:改为 await 调用
|
||||
|
||||
elif notification_type == NotificationType.ACTIVE_CHAT:
|
||||
# 处理活跃通知 (通常由 COLD_CHAT 的反向状态处理)
|
||||
is_active = data.get("is_active", False)
|
||||
self.observation_info.is_cold = not is_active
|
||||
|
||||
elif notification_type == NotificationType.BOT_SPEAKING:
|
||||
# 处理机器人说话通知 (按需实现)
|
||||
self.observation_info.is_typing = False
|
||||
self.observation_info.last_bot_speak_time = time.time()
|
||||
|
||||
elif notification_type == NotificationType.USER_SPEAKING:
|
||||
# 处理用户说话通知
|
||||
self.observation_info.is_typing = False
|
||||
self.observation_info.last_user_speak_time = time.time()
|
||||
|
||||
elif notification_type == NotificationType.MESSAGE_DELETED:
|
||||
# 处理消息删除通知
|
||||
message_id = data.get("message_id")
|
||||
# 从 unprocessed_messages 中移除被删除的消息
|
||||
original_count = len(self.observation_info.unprocessed_messages)
|
||||
self.observation_info.unprocessed_messages = [
|
||||
msg for msg in self.observation_info.unprocessed_messages if msg.get("message_id") != message_id
|
||||
]
|
||||
if len(self.observation_info.unprocessed_messages) < original_count:
|
||||
logger.info(f"[私聊][{self.private_name}]移除了未处理的消息 (ID: {message_id})")
|
||||
|
||||
elif notification_type == NotificationType.USER_JOINED:
|
||||
# 处理用户加入通知 (如果适用私聊场景)
|
||||
user_id = data.get("user_id")
|
||||
if user_id:
|
||||
self.observation_info.active_users.add(str(user_id)) # 确保是字符串
|
||||
|
||||
elif notification_type == NotificationType.USER_LEFT:
|
||||
# 处理用户离开通知 (如果适用私聊场景)
|
||||
user_id = data.get("user_id")
|
||||
if user_id:
|
||||
self.observation_info.active_users.discard(str(user_id)) # 确保是字符串
|
||||
|
||||
elif notification_type == NotificationType.ERROR:
|
||||
# 处理错误通知
|
||||
error_msg = data.get("error", "未提供错误信息")
|
||||
logger.error(f"[私聊][{self.private_name}]收到错误通知: {error_msg}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]处理通知时发生错误: {e}")
|
||||
logger.error(traceback.format_exc()) # 打印详细堆栈信息
|
||||
|
||||
|
||||
# @dataclass <-- 这个,不需要了(递黄瓜)
|
||||
class ObservationInfo:
|
||||
"""决策信息类,用于收集和管理来自chat_observer的通知信息 (手动实现 __init__)"""
|
||||
|
||||
# 类型提示保留,可用于文档和静态分析
|
||||
private_name: str
|
||||
chat_history: List[Dict[str, Any]]
|
||||
chat_history_str: str
|
||||
unprocessed_messages: List[Dict[str, Any]]
|
||||
active_users: Set[str]
|
||||
last_bot_speak_time: Optional[float]
|
||||
last_user_speak_time: Optional[float]
|
||||
last_message_time: Optional[float]
|
||||
last_message_id: Optional[str]
|
||||
last_message_content: str
|
||||
last_message_sender: Optional[str]
|
||||
bot_id: Optional[str]
|
||||
chat_history_count: int
|
||||
new_messages_count: int
|
||||
cold_chat_start_time: Optional[float]
|
||||
cold_chat_duration: float
|
||||
is_typing: bool
|
||||
is_cold_chat: bool
|
||||
changed: bool
|
||||
chat_observer: Optional[ChatObserver]
|
||||
handler: Optional[ObservationInfoHandler]
|
||||
|
||||
def __init__(self, private_name: str):
|
||||
"""
|
||||
手动初始化 ObservationInfo 的所有实例变量。
|
||||
"""
|
||||
|
||||
# 接收的参数
|
||||
self.private_name: str = private_name
|
||||
|
||||
# data_list
|
||||
self.chat_history: List[Dict[str, Any]] = []
|
||||
self.chat_history_str: str = ""
|
||||
self.unprocessed_messages: List[Dict[str, Any]] = []
|
||||
self.active_users: Set[str] = set()
|
||||
|
||||
# data
|
||||
self.last_bot_speak_time: Optional[float] = None
|
||||
self.last_user_speak_time: Optional[float] = None
|
||||
self.last_message_time: Optional[float] = None
|
||||
self.last_message_id: Optional[str] = None
|
||||
self.last_message_content: str = ""
|
||||
self.last_message_sender: Optional[str] = None
|
||||
self.bot_id: Optional[str] = None
|
||||
self.chat_history_count: int = 0
|
||||
self.new_messages_count: int = 0
|
||||
self.cold_chat_start_time: Optional[float] = None
|
||||
self.cold_chat_duration: float = 0.0
|
||||
|
||||
# state
|
||||
self.is_typing: bool = False
|
||||
self.is_cold_chat: bool = False
|
||||
self.changed: bool = False
|
||||
|
||||
# 关联对象
|
||||
self.chat_observer: Optional[ChatObserver] = None
|
||||
|
||||
self.handler: ObservationInfoHandler = ObservationInfoHandler(self, self.private_name)
|
||||
|
||||
def bind_to_chat_observer(self, chat_observer: ChatObserver):
|
||||
"""绑定到指定的chat_observer
|
||||
|
||||
Args:
|
||||
chat_observer: 要绑定的 ChatObserver 实例
|
||||
"""
|
||||
if self.chat_observer:
|
||||
logger.warning(f"[私聊][{self.private_name}]尝试重复绑定 ChatObserver")
|
||||
return
|
||||
|
||||
self.chat_observer = chat_observer
|
||||
try:
|
||||
if not self.handler: # 确保 handler 已经被创建
|
||||
logger.error(f"[私聊][{self.private_name}] 尝试绑定时 handler 未初始化!")
|
||||
self.chat_observer = None # 重置,防止后续错误
|
||||
return
|
||||
|
||||
# 注册关心的通知类型
|
||||
self.chat_observer.notification_manager.register_handler(
|
||||
target="observation_info", notification_type=NotificationType.NEW_MESSAGE, handler=self.handler
|
||||
)
|
||||
self.chat_observer.notification_manager.register_handler(
|
||||
target="observation_info", notification_type=NotificationType.COLD_CHAT, handler=self.handler
|
||||
)
|
||||
# 可以根据需要注册更多通知类型
|
||||
# self.chat_observer.notification_manager.register_handler(
|
||||
# target="observation_info", notification_type=NotificationType.MESSAGE_DELETED, handler=self.handler
|
||||
# )
|
||||
logger.info(f"[私聊][{self.private_name}]成功绑定到 ChatObserver")
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]绑定到 ChatObserver 时出错: {e}")
|
||||
self.chat_observer = None # 绑定失败,重置
|
||||
|
||||
def unbind_from_chat_observer(self):
|
||||
"""解除与chat_observer的绑定"""
|
||||
if (
|
||||
self.chat_observer and hasattr(self.chat_observer, "notification_manager") and self.handler
|
||||
): # 增加 handler 检查
|
||||
try:
|
||||
self.chat_observer.notification_manager.unregister_handler(
|
||||
target="observation_info", notification_type=NotificationType.NEW_MESSAGE, handler=self.handler
|
||||
)
|
||||
self.chat_observer.notification_manager.unregister_handler(
|
||||
target="observation_info", notification_type=NotificationType.COLD_CHAT, handler=self.handler
|
||||
)
|
||||
# 如果注册了其他类型,也要在这里注销
|
||||
# self.chat_observer.notification_manager.unregister_handler(
|
||||
# target="observation_info", notification_type=NotificationType.MESSAGE_DELETED, handler=self.handler
|
||||
# )
|
||||
logger.info(f"[私聊][{self.private_name}]成功从 ChatObserver 解绑")
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]从 ChatObserver 解绑时出错: {e}")
|
||||
finally: # 确保 chat_observer 被重置
|
||||
self.chat_observer = None
|
||||
else:
|
||||
logger.warning(f"[私聊][{self.private_name}]尝试解绑时 ChatObserver 不存在、无效或 handler 未设置")
|
||||
|
||||
# 修改:update_from_message 接收 UserInfo 对象
|
||||
async def update_from_message(self, message: Dict[str, Any], user_info: Optional[UserInfo]):
|
||||
"""从消息更新信息
|
||||
|
||||
Args:
|
||||
message: 消息数据字典
|
||||
user_info: 解析后的 UserInfo 对象 (可能为 None)
|
||||
"""
|
||||
message_time = message.get("time")
|
||||
message_id = message.get("message_id")
|
||||
processed_text = message.get("processed_plain_text", "")
|
||||
|
||||
# 只有在新消息到达时才更新 last_message 相关信息
|
||||
if message_time and message_time > (self.last_message_time or 0):
|
||||
self.last_message_time = message_time
|
||||
self.last_message_id = message_id
|
||||
self.last_message_content = processed_text
|
||||
# 重置冷场计时器
|
||||
self.is_cold_chat = False
|
||||
self.cold_chat_start_time = None
|
||||
self.cold_chat_duration = 0.0
|
||||
|
||||
if user_info:
|
||||
sender_id = str(user_info.user_id) # 确保是字符串
|
||||
self.last_message_sender = sender_id
|
||||
# 更新发言时间
|
||||
if sender_id == self.bot_id:
|
||||
self.last_bot_speak_time = message_time
|
||||
else:
|
||||
self.last_user_speak_time = message_time
|
||||
self.active_users.add(sender_id) # 用户发言则认为其活跃
|
||||
else:
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]处理消息更新时缺少有效的 UserInfo 对象, message_id: {message_id}"
|
||||
)
|
||||
self.last_message_sender = None # 发送者未知
|
||||
|
||||
# 将原始消息字典添加到未处理列表
|
||||
self.unprocessed_messages.append(message)
|
||||
self.new_messages_count = len(self.unprocessed_messages) # 直接用列表长度
|
||||
|
||||
# logger.debug(f"[私聊][{self.private_name}]消息更新: last_time={self.last_message_time}, new_count={self.new_messages_count}")
|
||||
self.update_changed() # 标记状态已改变
|
||||
else:
|
||||
# 如果消息时间戳不是最新的,可能不需要处理,或者记录一个警告
|
||||
pass
|
||||
# logger.warning(f"[私聊][{self.private_name}]收到过时或无效时间戳的消息: ID={message_id}, time={message_time}")
|
||||
|
||||
def update_changed(self):
|
||||
"""标记状态已改变,并重置标记"""
|
||||
# logger.debug(f"[私聊][{self.private_name}]状态标记为已改变 (changed=True)")
|
||||
self.changed = True
|
||||
|
||||
async def update_cold_chat_status(self, is_cold: bool, current_time: float):
|
||||
"""更新冷场状态
|
||||
|
||||
Args:
|
||||
is_cold: 是否处于冷场状态
|
||||
current_time: 当前时间戳
|
||||
"""
|
||||
if is_cold != self.is_cold_chat: # 仅在状态变化时更新
|
||||
self.is_cold_chat = is_cold
|
||||
if is_cold:
|
||||
# 进入冷场状态
|
||||
self.cold_chat_start_time = (
|
||||
self.last_message_time or current_time
|
||||
) # 从最后消息时间开始算,或从当前时间开始
|
||||
logger.info(f"[私聊][{self.private_name}]进入冷场状态,开始时间: {self.cold_chat_start_time}")
|
||||
else:
|
||||
# 结束冷场状态
|
||||
if self.cold_chat_start_time:
|
||||
self.cold_chat_duration = current_time - self.cold_chat_start_time
|
||||
logger.info(f"[私聊][{self.private_name}]结束冷场状态,持续时间: {self.cold_chat_duration:.2f} 秒")
|
||||
self.cold_chat_start_time = None # 重置开始时间
|
||||
self.update_changed() # 状态变化,标记改变
|
||||
|
||||
# 即使状态没变,如果是冷场状态,也更新持续时间
|
||||
if self.is_cold_chat and self.cold_chat_start_time:
|
||||
self.cold_chat_duration = current_time - self.cold_chat_start_time
|
||||
|
||||
def get_active_duration(self) -> float:
|
||||
"""获取当前活跃时长 (距离最后一条消息的时间)
|
||||
|
||||
Returns:
|
||||
float: 最后一条消息到现在的时长(秒)
|
||||
"""
|
||||
if not self.last_message_time:
|
||||
return 0.0
|
||||
return time.time() - self.last_message_time
|
||||
|
||||
def get_user_response_time(self) -> Optional[float]:
|
||||
"""获取用户最后响应时间 (距离用户最后发言的时间)
|
||||
|
||||
Returns:
|
||||
Optional[float]: 用户最后发言到现在的时长(秒),如果没有用户发言则返回None
|
||||
"""
|
||||
if not self.last_user_speak_time:
|
||||
return None
|
||||
return time.time() - self.last_user_speak_time
|
||||
|
||||
def get_bot_response_time(self) -> Optional[float]:
|
||||
"""获取机器人最后响应时间 (距离机器人最后发言的时间)
|
||||
|
||||
Returns:
|
||||
Optional[float]: 机器人最后发言到现在的时长(秒),如果没有机器人发言则返回None
|
||||
"""
|
||||
if not self.last_bot_speak_time:
|
||||
return None
|
||||
return time.time() - self.last_bot_speak_time
|
||||
|
||||
async def clear_unprocessed_messages(self):
|
||||
"""将未处理消息移入历史记录,并更新相关状态"""
|
||||
if not self.unprocessed_messages:
|
||||
return # 没有未处理消息,直接返回
|
||||
|
||||
# logger.debug(f"[私聊][{self.private_name}]处理 {len(self.unprocessed_messages)} 条未处理消息...")
|
||||
# 将未处理消息添加到历史记录中 (确保历史记录有长度限制,避免无限增长)
|
||||
max_history_len = 100 # 示例:最多保留100条历史记录
|
||||
self.chat_history.extend(self.unprocessed_messages)
|
||||
if len(self.chat_history) > max_history_len:
|
||||
self.chat_history = self.chat_history[-max_history_len:]
|
||||
|
||||
# 更新历史记录字符串 (只使用最近一部分生成,例如20条)
|
||||
history_slice_for_str = self.chat_history[-20:]
|
||||
try:
|
||||
# Convert dict format to SessionMessage objects.
|
||||
session_messages = [dict_to_session_message(m) for m in history_slice_for_str]
|
||||
self.chat_history_str = build_readable_messages(
|
||||
session_messages,
|
||||
replace_bot_name=True,
|
||||
timestamp_mode="relative",
|
||||
read_mark=0.0, # read_mark 可能需要根据逻辑调整
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]构建聊天记录字符串时出错: {e}")
|
||||
self.chat_history_str = "[构建聊天记录出错]" # 提供错误提示
|
||||
|
||||
# 清空未处理消息列表和计数
|
||||
# cleared_count = len(self.unprocessed_messages)
|
||||
self.unprocessed_messages.clear()
|
||||
self.new_messages_count = 0
|
||||
# self.has_unread_messages = False # 这个状态可以通过 new_messages_count 判断
|
||||
|
||||
self.chat_history_count = len(self.chat_history) # 更新历史记录总数
|
||||
# logger.debug(f"[私聊][{self.private_name}]已处理 {cleared_count} 条消息,当前历史记录 {self.chat_history_count} 条。")
|
||||
|
||||
self.update_changed() # 状态改变
|
||||
@@ -1,361 +0,0 @@
|
||||
from typing import List, Tuple, TYPE_CHECKING
|
||||
from src.common.logger import get_logger
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config, model_config
|
||||
import random
|
||||
from .chat_observer import ChatObserver
|
||||
from .pfc_utils import get_items_from_json
|
||||
from .conversation_info import ConversationInfo
|
||||
from src.services.message_service import build_readable_messages
|
||||
|
||||
from .observation_info import ObservationInfo, dict_to_session_message
|
||||
from rich.traceback import install
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
logger = get_logger("pfc")
|
||||
|
||||
|
||||
def _calculate_similarity(goal1: str, goal2: str) -> float:
|
||||
"""简单计算两个目标之间的相似度
|
||||
|
||||
这里使用一个简单的实现,实际可以使用更复杂的文本相似度算法
|
||||
|
||||
Args:
|
||||
goal1: 第一个目标
|
||||
goal2: 第二个目标
|
||||
|
||||
Returns:
|
||||
float: 相似度得分 (0-1)
|
||||
"""
|
||||
# 简单实现:检查重叠字数比例
|
||||
words1 = set(goal1)
|
||||
words2 = set(goal2)
|
||||
overlap = len(words1.intersection(words2))
|
||||
total = len(words1.union(words2))
|
||||
return overlap / total if total > 0 else 0
|
||||
|
||||
|
||||
class GoalAnalyzer:
|
||||
"""对话目标分析器"""
|
||||
|
||||
def __init__(self, stream_id: str, private_name: str):
|
||||
self.llm = LLMRequest(model_set=model_config.model_task_config.planner, request_type="conversation_goal")
|
||||
|
||||
self.personality_info = self._get_personality_prompt()
|
||||
self.name = global_config.bot.nickname
|
||||
self.nick_name = global_config.bot.alias_names
|
||||
self.private_name = private_name
|
||||
self.chat_observer = ChatObserver.get_instance(stream_id, private_name)
|
||||
|
||||
# 多目标存储结构
|
||||
self.goals = [] # 存储多个目标
|
||||
self.max_goals = 3 # 同时保持的最大目标数量
|
||||
self.current_goal_and_reason = None
|
||||
|
||||
def _get_personality_prompt(self) -> str:
|
||||
"""获取个性提示信息"""
|
||||
prompt_personality = global_config.personality.personality
|
||||
|
||||
# 检查是否需要随机替换为状态
|
||||
if (
|
||||
global_config.personality.states
|
||||
and global_config.personality.state_probability > 0
|
||||
and random.random() < global_config.personality.state_probability
|
||||
):
|
||||
prompt_personality = random.choice(global_config.personality.states)
|
||||
|
||||
bot_name = global_config.bot.nickname
|
||||
return f"你的名字是{bot_name},你{prompt_personality};"
|
||||
|
||||
async def analyze_goal(self, conversation_info: ConversationInfo, observation_info: ObservationInfo):
|
||||
"""分析对话历史并设定目标
|
||||
|
||||
Args:
|
||||
conversation_info: 对话信息
|
||||
observation_info: 观察信息
|
||||
|
||||
Returns:
|
||||
Tuple[str, str, str]: (目标, 方法, 原因)
|
||||
"""
|
||||
# 构建对话目标
|
||||
goals_str = ""
|
||||
if conversation_info.goal_list:
|
||||
for goal_reason in conversation_info.goal_list:
|
||||
if isinstance(goal_reason, dict):
|
||||
goal = goal_reason.get("goal", "目标内容缺失")
|
||||
reasoning = goal_reason.get("reasoning", "没有明确原因")
|
||||
else:
|
||||
goal = str(goal_reason)
|
||||
reasoning = "没有明确原因"
|
||||
|
||||
goal_str = f"目标:{goal},产生该对话目标的原因:{reasoning}\n"
|
||||
goals_str += goal_str
|
||||
else:
|
||||
goal = "目前没有明确对话目标"
|
||||
reasoning = "目前没有明确对话目标,最好思考一个对话目标"
|
||||
goals_str = f"目标:{goal},产生该对话目标的原因:{reasoning}\n"
|
||||
|
||||
# 获取聊天历史记录
|
||||
chat_history_text = observation_info.chat_history_str
|
||||
|
||||
if observation_info.new_messages_count > 0:
|
||||
new_messages_list = observation_info.unprocessed_messages
|
||||
session_messages = [dict_to_session_message(m) for m in new_messages_list]
|
||||
new_messages_str = build_readable_messages(
|
||||
session_messages,
|
||||
replace_bot_name=True,
|
||||
timestamp_mode="relative",
|
||||
read_mark=0.0,
|
||||
)
|
||||
chat_history_text += f"\n--- 以下是 {observation_info.new_messages_count} 条新消息 ---\n{new_messages_str}"
|
||||
|
||||
# await observation_info.clear_unprocessed_messages()
|
||||
|
||||
persona_text = f"你的名字是{self.name},{self.personality_info}。"
|
||||
# 构建action历史文本
|
||||
action_history_list = conversation_info.done_action
|
||||
action_history_text = "你之前做的事情是:"
|
||||
for action in action_history_list:
|
||||
action_history_text += f"{action}\n"
|
||||
|
||||
prompt = f"""{persona_text}。现在你在参与一场QQ聊天,请分析以下聊天记录,并根据你的性格特征确定多个明确的对话目标。
|
||||
这些目标应该反映出对话的不同方面和意图。
|
||||
|
||||
{action_history_text}
|
||||
当前对话目标:
|
||||
{goals_str}
|
||||
|
||||
聊天记录:
|
||||
{chat_history_text}
|
||||
|
||||
请分析当前对话并确定最适合的对话目标。你可以:
|
||||
1. 保持现有目标不变
|
||||
2. 修改现有目标
|
||||
3. 添加新目标
|
||||
4. 删除不再相关的目标
|
||||
5. 如果你想结束对话,请设置一个目标,目标goal为"结束对话",原因reasoning为你希望结束对话
|
||||
|
||||
请以JSON数组格式输出当前的所有对话目标,每个目标包含以下字段:
|
||||
1. goal: 对话目标(简短的一句话)
|
||||
2. reasoning: 对话原因,为什么设定这个目标(简要解释)
|
||||
|
||||
输出格式示例:
|
||||
[
|
||||
{{
|
||||
"goal": "回答用户关于Python编程的具体问题",
|
||||
"reasoning": "用户提出了关于Python的技术问题,需要专业且准确的解答"
|
||||
}},
|
||||
{{
|
||||
"goal": "回答用户关于python安装的具体问题",
|
||||
"reasoning": "用户提出了关于Python的技术问题,需要专业且准确的解答"
|
||||
}}
|
||||
]"""
|
||||
|
||||
logger.debug(f"[私聊][{self.private_name}]发送到LLM的提示词: {prompt}")
|
||||
try:
|
||||
content, _ = await self.llm.generate_response_async(prompt)
|
||||
logger.debug(f"[私聊][{self.private_name}]LLM原始返回内容: {content}")
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]分析对话目标时出错: {str(e)}")
|
||||
content = ""
|
||||
|
||||
# 使用改进后的get_items_from_json函数处理JSON数组
|
||||
success, result = get_items_from_json(
|
||||
content,
|
||||
self.private_name,
|
||||
"goal",
|
||||
"reasoning",
|
||||
required_types={"goal": str, "reasoning": str},
|
||||
allow_array=True,
|
||||
)
|
||||
|
||||
if success:
|
||||
# 判断结果是单个字典还是字典列表
|
||||
if isinstance(result, list):
|
||||
# 清空现有目标列表并添加新目标
|
||||
conversation_info.goal_list = []
|
||||
for item in result:
|
||||
conversation_info.goal_list.append(item)
|
||||
|
||||
# 返回第一个目标作为当前主要目标(如果有)
|
||||
if result:
|
||||
first_goal = result[0]
|
||||
return first_goal.get("goal", ""), "", first_goal.get("reasoning", "")
|
||||
else:
|
||||
# 单个目标的情况
|
||||
conversation_info.goal_list.append(result)
|
||||
goal_value = result.get("goal", "")
|
||||
reasoning_value = result.get("reasoning", "")
|
||||
return goal_value, "", reasoning_value
|
||||
|
||||
# 如果解析失败,返回默认值
|
||||
return "", "", ""
|
||||
|
||||
async def _update_goals(self, new_goal: str, method: str, reasoning: str):
|
||||
"""更新目标列表
|
||||
|
||||
Args:
|
||||
new_goal: 新的目标
|
||||
method: 实现目标的方法
|
||||
reasoning: 目标的原因
|
||||
"""
|
||||
# 检查新目标是否与现有目标相似
|
||||
for i, (existing_goal, _, _) in enumerate(self.goals):
|
||||
if _calculate_similarity(new_goal, existing_goal) > 0.7: # 相似度阈值
|
||||
# 更新现有目标
|
||||
self.goals[i] = (new_goal, method, reasoning)
|
||||
# 将此目标移到列表前面(最主要的位置)
|
||||
self.goals.insert(0, self.goals.pop(i))
|
||||
return
|
||||
|
||||
# 添加新目标到列表前面
|
||||
self.goals.insert(0, (new_goal, method, reasoning))
|
||||
|
||||
# 限制目标数量
|
||||
if len(self.goals) > self.max_goals:
|
||||
self.goals.pop() # 移除最老的目标
|
||||
|
||||
async def get_all_goals(self) -> List[Tuple[str, str, str]]:
|
||||
"""获取所有当前目标
|
||||
|
||||
Returns:
|
||||
List[Tuple[str, str, str]]: 目标列表,每项为(目标, 方法, 原因)
|
||||
"""
|
||||
return self.goals.copy()
|
||||
|
||||
async def get_alternative_goals(self) -> List[Tuple[str, str, str]]:
|
||||
"""获取除了当前主要目标外的其他备选目标
|
||||
|
||||
Returns:
|
||||
List[Tuple[str, str, str]]: 备选目标列表
|
||||
"""
|
||||
if len(self.goals) <= 1:
|
||||
return []
|
||||
return self.goals[1:].copy()
|
||||
|
||||
async def analyze_conversation(self, goal, reasoning):
|
||||
messages = self.chat_observer.get_cached_messages()
|
||||
session_messages = [dict_to_session_message(m) for m in messages]
|
||||
chat_history_text = build_readable_messages(
|
||||
session_messages,
|
||||
replace_bot_name=True,
|
||||
timestamp_mode="relative",
|
||||
read_mark=0.0,
|
||||
)
|
||||
|
||||
persona_text = f"你的名字是{self.name},{self.personality_info}。"
|
||||
# ===> Persona 文本构建结束 <===
|
||||
|
||||
# --- 修改 Prompt 字符串,使用 persona_text ---
|
||||
prompt = f"""{persona_text}。现在你在参与一场QQ聊天,
|
||||
当前对话目标:{goal}
|
||||
产生该对话目标的原因:{reasoning}
|
||||
|
||||
请分析以下聊天记录,并根据你的性格特征评估该目标是否已经达到,或者你是否希望停止该次对话。
|
||||
聊天记录:
|
||||
{chat_history_text}
|
||||
请以JSON格式输出,包含以下字段:
|
||||
1. goal_achieved: 对话目标是否已经达到(true/false)
|
||||
2. stop_conversation: 是否希望停止该次对话(true/false)
|
||||
3. reason: 为什么希望停止该次对话(简要解释)
|
||||
|
||||
输出格式示例:
|
||||
{{
|
||||
"goal_achieved": true,
|
||||
"stop_conversation": false,
|
||||
"reason": "虽然目标已达成,但对话仍然有继续的价值"
|
||||
}}"""
|
||||
|
||||
try:
|
||||
content, _ = await self.llm.generate_response_async(prompt)
|
||||
logger.debug(f"[私聊][{self.private_name}]LLM原始返回内容: {content}")
|
||||
|
||||
# 尝试解析JSON
|
||||
success, result = get_items_from_json(
|
||||
content,
|
||||
self.private_name,
|
||||
"goal_achieved",
|
||||
"stop_conversation",
|
||||
"reason",
|
||||
required_types={"goal_achieved": bool, "stop_conversation": bool, "reason": str},
|
||||
)
|
||||
|
||||
if not success:
|
||||
logger.error(f"[私聊][{self.private_name}]无法解析对话分析结果JSON")
|
||||
return False, False, "解析结果失败"
|
||||
|
||||
goal_achieved = result["goal_achieved"]
|
||||
stop_conversation = result["stop_conversation"]
|
||||
reason = result["reason"]
|
||||
|
||||
return goal_achieved, stop_conversation, reason
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]分析对话状态时出错: {str(e)}")
|
||||
return False, False, f"分析出错: {str(e)}"
|
||||
|
||||
|
||||
# 先注释掉,万一以后出问题了还能开回来(((
|
||||
# class DirectMessageSender:
|
||||
# """直接发送消息到平台的发送器"""
|
||||
|
||||
# def __init__(self, private_name: str):
|
||||
# self.logger = get_module_logger("direct_sender")
|
||||
# self.storage = MessageStorage()
|
||||
# self.private_name = private_name
|
||||
|
||||
# async def send_via_ws(self, message: MessageSending) -> None:
|
||||
# try:
|
||||
# await global_api.send_message(message)
|
||||
# except Exception as e:
|
||||
# raise ValueError(f"未找到平台:{message.message_info.platform} 的url配置,请检查配置文件") from e
|
||||
|
||||
# async def send_message(
|
||||
# self,
|
||||
# chat_stream: ChatStream,
|
||||
# content: str,
|
||||
# reply_to_message: Optional[Message] = None,
|
||||
# ) -> None:
|
||||
# """直接发送消息到平台
|
||||
|
||||
# Args:
|
||||
# chat_stream: 聊天流
|
||||
# content: 消息内容
|
||||
# reply_to_message: 要回复的消息
|
||||
# """
|
||||
# # 构建消息对象
|
||||
# message_segment = Seg(type="text", data=content)
|
||||
# bot_user_info = UserInfo(
|
||||
# user_id=global_config.BOT_QQ,
|
||||
# user_nickname=global_config.BOT_NICKNAME,
|
||||
# platform=chat_stream.platform,
|
||||
# )
|
||||
|
||||
# message = MessageSending(
|
||||
# message_id=f"dm{round(time.time(), 2)}",
|
||||
# chat_stream=chat_stream,
|
||||
# bot_user_info=bot_user_info,
|
||||
# sender_info=reply_to_message.message_info.user_info if reply_to_message else None,
|
||||
# message_segment=message_segment,
|
||||
# reply=reply_to_message,
|
||||
# is_head=True,
|
||||
# is_emoji=False,
|
||||
# thinking_start_time=time.time(),
|
||||
# )
|
||||
|
||||
# # 处理消息
|
||||
# await message.process()
|
||||
|
||||
# _message_json = message.to_dict()
|
||||
|
||||
# # 发送消息
|
||||
# try:
|
||||
# await self.send_via_ws(message)
|
||||
# await self.storage.store_message(message, chat_stream)
|
||||
# logger.success(f"[私聊][{self.private_name}]PFC消息已发送: {content}")
|
||||
# except Exception as e:
|
||||
# logger.error(f"[私聊][{self.private_name}]PFC消息发送失败: {str(e)}")
|
||||
@@ -1,115 +0,0 @@
|
||||
import time
|
||||
from typing import Dict, Optional
|
||||
from src.common.logger import get_logger
|
||||
from .conversation import Conversation
|
||||
import traceback
|
||||
|
||||
logger = get_logger("pfc_manager")
|
||||
|
||||
|
||||
class PFCManager:
|
||||
"""PFC对话管理器,负责管理所有对话实例"""
|
||||
|
||||
# 单例模式
|
||||
_instance = None
|
||||
|
||||
# 会话实例管理
|
||||
_instances: Dict[str, Conversation] = {}
|
||||
_initializing: Dict[str, bool] = {}
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls) -> "PFCManager":
|
||||
"""获取管理器单例
|
||||
|
||||
Returns:
|
||||
PFCManager: 管理器实例
|
||||
"""
|
||||
if cls._instance is None:
|
||||
cls._instance = PFCManager()
|
||||
return cls._instance
|
||||
|
||||
async def get_or_create_conversation(self, stream_id: str, private_name: str) -> Optional[Conversation]:
|
||||
"""获取或创建对话实例
|
||||
|
||||
Args:
|
||||
stream_id: 聊天流ID
|
||||
private_name: 私聊名称
|
||||
|
||||
Returns:
|
||||
Optional[Conversation]: 对话实例,创建失败则返回None
|
||||
"""
|
||||
# 检查是否已经有实例
|
||||
if stream_id in self._initializing and self._initializing[stream_id]:
|
||||
logger.debug(f"[私聊][{private_name}]会话实例正在初始化中: {stream_id}")
|
||||
return None
|
||||
|
||||
if stream_id in self._instances and self._instances[stream_id].should_continue:
|
||||
logger.debug(f"[私聊][{private_name}]使用现有会话实例: {stream_id}")
|
||||
return self._instances[stream_id]
|
||||
if stream_id in self._instances:
|
||||
instance = self._instances[stream_id]
|
||||
if (
|
||||
hasattr(instance, "ignore_until_timestamp")
|
||||
and instance.ignore_until_timestamp
|
||||
and time.time() < instance.ignore_until_timestamp
|
||||
):
|
||||
logger.debug(f"[私聊][{private_name}]会话实例当前处于忽略状态: {stream_id}")
|
||||
# 返回 None 阻止交互。或者可以返回实例但标记它被忽略了喵?
|
||||
# 还是返回 None 吧喵。
|
||||
return None
|
||||
|
||||
# 检查 should_continue 状态
|
||||
if instance.should_continue:
|
||||
logger.debug(f"[私聊][{private_name}]使用现有会话实例: {stream_id}")
|
||||
return instance
|
||||
# else: 实例存在但不应继续
|
||||
try:
|
||||
# 创建新实例
|
||||
logger.info(f"[私聊][{private_name}]创建新的对话实例: {stream_id}")
|
||||
self._initializing[stream_id] = True
|
||||
# 创建实例
|
||||
conversation_instance = Conversation(stream_id, private_name)
|
||||
self._instances[stream_id] = conversation_instance
|
||||
|
||||
# 启动实例初始化
|
||||
await self._initialize_conversation(conversation_instance)
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{private_name}]创建会话实例失败: {stream_id}, 错误: {e}")
|
||||
return None
|
||||
|
||||
return conversation_instance
|
||||
|
||||
async def _initialize_conversation(self, conversation: Conversation):
|
||||
"""初始化会话实例
|
||||
|
||||
Args:
|
||||
conversation: 要初始化的会话实例
|
||||
"""
|
||||
stream_id = conversation.stream_id
|
||||
private_name = conversation.private_name
|
||||
|
||||
try:
|
||||
logger.info(f"[私聊][{private_name}]开始初始化会话实例: {stream_id}")
|
||||
# 启动初始化流程
|
||||
await conversation._initialize()
|
||||
|
||||
# 标记初始化完成
|
||||
self._initializing[stream_id] = False
|
||||
|
||||
logger.info(f"[私聊][{private_name}]会话实例 {stream_id} 初始化完成")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{private_name}]管理器初始化会话实例失败: {stream_id}, 错误: {e}")
|
||||
logger.error(f"[私聊][{private_name}]{traceback.format_exc()}")
|
||||
# 清理失败的初始化
|
||||
|
||||
async def get_conversation(self, stream_id: str) -> Optional[Conversation]:
|
||||
"""获取已存在的会话实例
|
||||
|
||||
Args:
|
||||
stream_id: 聊天流ID
|
||||
|
||||
Returns:
|
||||
Optional[Conversation]: 会话实例,不存在则返回None
|
||||
"""
|
||||
return self._instances.get(stream_id)
|
||||
@@ -1,23 +0,0 @@
|
||||
from enum import Enum
|
||||
from typing import Literal
|
||||
|
||||
|
||||
class ConversationState(Enum):
|
||||
"""对话状态"""
|
||||
|
||||
INIT = "初始化"
|
||||
RETHINKING = "重新思考"
|
||||
ANALYZING = "分析历史"
|
||||
PLANNING = "规划目标"
|
||||
GENERATING = "生成回复"
|
||||
CHECKING = "检查回复"
|
||||
SENDING = "发送消息"
|
||||
FETCHING = "获取知识"
|
||||
WAITING = "等待"
|
||||
LISTENING = "倾听"
|
||||
ENDED = "结束"
|
||||
JUDGING = "判断"
|
||||
IGNORED = "屏蔽"
|
||||
|
||||
|
||||
ActionType = Literal["direct_reply", "fetch_knowledge", "wait"]
|
||||
@@ -1,127 +0,0 @@
|
||||
import json
|
||||
import re
|
||||
from typing import Dict, Any, Optional, Tuple, List, Union
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("pfc_utils")
|
||||
|
||||
|
||||
def get_items_from_json(
|
||||
content: str,
|
||||
private_name: str,
|
||||
*items: str,
|
||||
default_values: Optional[Dict[str, Any]] = None,
|
||||
required_types: Optional[Dict[str, type]] = None,
|
||||
allow_array: bool = True,
|
||||
) -> Tuple[bool, Union[Dict[str, Any], List[Dict[str, Any]]]]:
|
||||
"""从文本中提取JSON内容并获取指定字段
|
||||
|
||||
Args:
|
||||
content: 包含JSON的文本
|
||||
private_name: 私聊名称
|
||||
*items: 要提取的字段名
|
||||
default_values: 字段的默认值,格式为 {字段名: 默认值}
|
||||
required_types: 字段的必需类型,格式为 {字段名: 类型}
|
||||
allow_array: 是否允许解析JSON数组
|
||||
|
||||
Returns:
|
||||
Tuple[bool, Union[Dict[str, Any], List[Dict[str, Any]]]]: (是否成功, 提取的字段字典或字典列表)
|
||||
"""
|
||||
content = content.strip()
|
||||
result = {}
|
||||
|
||||
# 设置默认值
|
||||
if default_values:
|
||||
result.update(default_values)
|
||||
|
||||
# 首先尝试解析为JSON数组
|
||||
if allow_array:
|
||||
try:
|
||||
# 尝试找到文本中的JSON数组
|
||||
array_pattern = r"\[[\s\S]*\]"
|
||||
array_match = re.search(array_pattern, content)
|
||||
if array_match:
|
||||
array_content = array_match.group()
|
||||
json_array = json.loads(array_content)
|
||||
|
||||
# 确认是数组类型
|
||||
if isinstance(json_array, list):
|
||||
# 验证数组中的每个项目是否包含所有必需字段
|
||||
valid_items = []
|
||||
for item in json_array:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
|
||||
# 检查是否有所有必需字段
|
||||
if all(field in item for field in items):
|
||||
# 验证字段类型
|
||||
if required_types:
|
||||
type_valid = True
|
||||
for field, expected_type in required_types.items():
|
||||
if field in item and not isinstance(item[field], expected_type):
|
||||
type_valid = False
|
||||
break
|
||||
|
||||
if not type_valid:
|
||||
continue
|
||||
|
||||
# 验证字符串字段不为空
|
||||
string_valid = True
|
||||
for field in items:
|
||||
if isinstance(item[field], str) and not item[field].strip():
|
||||
string_valid = False
|
||||
break
|
||||
|
||||
if not string_valid:
|
||||
continue
|
||||
|
||||
valid_items.append(item)
|
||||
|
||||
if valid_items:
|
||||
return True, valid_items
|
||||
except json.JSONDecodeError:
|
||||
logger.debug(f"[私聊][{private_name}]JSON数组解析失败,尝试解析单个JSON对象")
|
||||
except Exception as e:
|
||||
logger.debug(f"[私聊][{private_name}]尝试解析JSON数组时出错: {str(e)}")
|
||||
|
||||
# 尝试解析JSON对象
|
||||
try:
|
||||
json_data = json.loads(content)
|
||||
except json.JSONDecodeError:
|
||||
# 如果直接解析失败,尝试查找和提取JSON部分
|
||||
json_pattern = r"\{[^{}]*\}"
|
||||
json_match = re.search(json_pattern, content)
|
||||
if json_match:
|
||||
try:
|
||||
json_data = json.loads(json_match.group())
|
||||
except json.JSONDecodeError:
|
||||
logger.error(f"[私聊][{private_name}]提取的JSON内容解析失败")
|
||||
return False, result
|
||||
else:
|
||||
logger.error(f"[私聊][{private_name}]无法在返回内容中找到有效的JSON")
|
||||
return False, result
|
||||
|
||||
# 提取字段
|
||||
for item in items:
|
||||
if item in json_data:
|
||||
result[item] = json_data[item]
|
||||
|
||||
# 验证必需字段
|
||||
if not all(item in result for item in items):
|
||||
logger.error(f"[私聊][{private_name}]JSON缺少必要字段,实际内容: {json_data}")
|
||||
return False, result
|
||||
|
||||
# 验证字段类型
|
||||
if required_types:
|
||||
for field, expected_type in required_types.items():
|
||||
if field in result and not isinstance(result[field], expected_type):
|
||||
logger.error(f"[私聊][{private_name}]{field} 必须是 {expected_type.__name__} 类型")
|
||||
return False, result
|
||||
|
||||
# 验证字符串字段不为空
|
||||
for field in items:
|
||||
if isinstance(result[field], str) and not result[field].strip():
|
||||
logger.error(f"[私聊][{private_name}]{field} 不能为空")
|
||||
return False, result
|
||||
|
||||
return True, result
|
||||
@@ -1,198 +0,0 @@
|
||||
import json
|
||||
import random
|
||||
from typing import Tuple, List, Dict, Any
|
||||
from src.common.logger import get_logger
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config, model_config
|
||||
from .chat_observer import ChatObserver
|
||||
from maim_message import UserInfo
|
||||
|
||||
logger = get_logger("reply_checker")
|
||||
|
||||
|
||||
class ReplyChecker:
|
||||
"""回复检查器"""
|
||||
|
||||
def __init__(self, stream_id: str, private_name: str):
|
||||
self.llm = LLMRequest(model_set=model_config.model_task_config.utils, request_type="reply_check")
|
||||
self.personality_info = self._get_personality_prompt()
|
||||
self.name = global_config.bot.nickname
|
||||
self.private_name = private_name
|
||||
self.chat_observer = ChatObserver.get_instance(stream_id, private_name)
|
||||
self.max_retries = 3 # 最大重试次数
|
||||
|
||||
def _get_personality_prompt(self) -> str:
|
||||
"""获取个性提示信息"""
|
||||
prompt_personality = global_config.personality.personality
|
||||
|
||||
# 检查是否需要随机替换为状态
|
||||
if (
|
||||
global_config.personality.states
|
||||
and global_config.personality.state_probability > 0
|
||||
and random.random() < global_config.personality.state_probability
|
||||
):
|
||||
prompt_personality = random.choice(global_config.personality.states)
|
||||
|
||||
bot_name = global_config.bot.nickname
|
||||
return f"你的名字是{bot_name},你{prompt_personality};"
|
||||
|
||||
async def check(
|
||||
self, reply: str, goal: str, chat_history: List[Dict[str, Any]], chat_history_text: str, retry_count: int = 0
|
||||
) -> Tuple[bool, str, bool]:
|
||||
"""检查生成的回复是否合适
|
||||
|
||||
Args:
|
||||
reply: 生成的回复
|
||||
goal: 对话目标
|
||||
chat_history: 对话历史记录
|
||||
chat_history_text: 对话历史记录文本
|
||||
retry_count: 当前重试次数
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str, bool]: (是否合适, 原因, 是否需要重新规划)
|
||||
"""
|
||||
# 不再从 observer 获取,直接使用传入的 chat_history
|
||||
# messages = self.chat_observer.get_cached_messages(limit=20)
|
||||
try:
|
||||
# 筛选出最近由 Bot 自己发送的消息
|
||||
bot_messages = []
|
||||
for msg in reversed(chat_history):
|
||||
user_info = UserInfo.from_dict(msg.get("user_info", {}))
|
||||
if str(user_info.user_id) == str(global_config.bot.qq_account):
|
||||
bot_messages.append(msg.get("processed_plain_text", ""))
|
||||
if len(bot_messages) >= 2: # 只和最近的两条比较
|
||||
break
|
||||
# 进行比较
|
||||
if bot_messages:
|
||||
# 可以用简单比较,或者更复杂的相似度库 (如 difflib)
|
||||
# 简单比较:是否完全相同
|
||||
if reply == bot_messages[0]: # 和最近一条完全一样
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]ReplyChecker 检测到回复与上一条 Bot 消息完全相同: '{reply}'"
|
||||
)
|
||||
return (
|
||||
False,
|
||||
"被逻辑检查拒绝:回复内容与你上一条发言完全相同,可以选择深入话题或寻找其它话题或等待",
|
||||
True,
|
||||
) # 不合适,需要返回至决策层
|
||||
# 2. 相似度检查 (如果精确匹配未通过)
|
||||
import difflib # 导入 difflib 库
|
||||
|
||||
# 计算编辑距离相似度,ratio() 返回 0 到 1 之间的浮点数
|
||||
similarity_ratio = difflib.SequenceMatcher(None, reply, bot_messages[0]).ratio()
|
||||
logger.debug(f"[私聊][{self.private_name}]ReplyChecker - 相似度: {similarity_ratio:.2f}")
|
||||
|
||||
# 设置一个相似度阈值
|
||||
similarity_threshold = 0.9
|
||||
if similarity_ratio > similarity_threshold:
|
||||
logger.warning(
|
||||
f"[私聊][{self.private_name}]ReplyChecker 检测到回复与上一条 Bot 消息高度相似 (相似度 {similarity_ratio:.2f}): '{reply}'"
|
||||
)
|
||||
return (
|
||||
False,
|
||||
f"被逻辑检查拒绝:回复内容与你上一条发言高度相似 (相似度 {similarity_ratio:.2f}),可以选择深入话题或寻找其它话题或等待。",
|
||||
True,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
logger.error(f"[私聊][{self.private_name}]检查回复时出错: 类型={type(e)}, 值={e}")
|
||||
logger.error(f"[私聊][{self.private_name}]{traceback.format_exc()}") # 打印详细的回溯信息
|
||||
|
||||
prompt = f"""你是一个聊天逻辑检查器,请检查以下回复或消息是否合适:
|
||||
|
||||
当前对话目标:{goal}
|
||||
最新的对话记录:
|
||||
{chat_history_text}
|
||||
|
||||
待检查的消息:
|
||||
{reply}
|
||||
|
||||
请结合聊天记录检查以下几点:
|
||||
1. 这条消息是否依然符合当前对话目标和实现方式
|
||||
2. 这条消息是否与最新的对话记录保持一致性
|
||||
3. 是否存在重复发言,或重复表达同质内容(尤其是只是换一种方式表达了相同的含义)
|
||||
4. 这条消息是否包含违规内容(例如血腥暴力,政治敏感等)
|
||||
5. 这条消息是否以发送者的角度发言(不要让发送者自己回复自己的消息)
|
||||
6. 这条消息是否通俗易懂
|
||||
7. 这条消息是否有些多余,例如在对方没有回复的情况下,依然连续多次“消息轰炸”(尤其是已经连续发送3条信息的情况,这很可能不合理,需要着重判断)
|
||||
8. 这条消息是否使用了完全没必要的修辞
|
||||
9. 这条消息是否逻辑通顺
|
||||
10. 这条消息是否太过冗长了(通常私聊的每条消息长度在20字以内,除非特殊情况)
|
||||
11. 在连续多次发送消息的情况下,这条消息是否衔接自然,会不会显得奇怪(例如连续两条消息中部分内容重叠)
|
||||
|
||||
请以JSON格式输出,包含以下字段:
|
||||
1. suitable: 是否合适 (true/false)
|
||||
2. reason: 原因说明
|
||||
3. need_replan: 是否需要重新决策 (true/false),当你认为此时已经不适合发消息,需要规划其它行动时,设为true
|
||||
|
||||
输出格式示例:
|
||||
{{
|
||||
"suitable": true,
|
||||
"reason": "回复符合要求,虽然有可能略微偏离目标,但是整体内容流畅得体",
|
||||
"need_replan": false
|
||||
}}
|
||||
|
||||
注意:请严格按照JSON格式输出,不要包含任何其他内容。"""
|
||||
|
||||
try:
|
||||
content, _ = await self.llm.generate_response_async(prompt)
|
||||
logger.debug(f"[私聊][{self.private_name}]检查回复的原始返回: {content}")
|
||||
|
||||
# 清理内容,尝试提取JSON部分
|
||||
content = content.strip()
|
||||
try:
|
||||
# 尝试直接解析
|
||||
result: dict = json.loads(content)
|
||||
except json.JSONDecodeError:
|
||||
# 如果直接解析失败,尝试查找和提取JSON部分
|
||||
import re
|
||||
|
||||
json_pattern = r"\{[^{}]*\}"
|
||||
json_match = re.search(json_pattern, content)
|
||||
if json_match:
|
||||
try:
|
||||
result: dict = json.loads(json_match.group())
|
||||
except json.JSONDecodeError:
|
||||
# 如果JSON解析失败,尝试从文本中提取结果
|
||||
is_suitable = "不合适" not in content.lower() and "违规" not in content.lower()
|
||||
reason = content[:100] if content else "无法解析响应"
|
||||
need_replan = "重新规划" in content.lower() or "目标不适合" in content.lower()
|
||||
return is_suitable, reason, need_replan
|
||||
else:
|
||||
# 如果找不到JSON,从文本中判断
|
||||
is_suitable = "不合适" not in content.lower() and "违规" not in content.lower()
|
||||
reason = content[:100] if content else "无法解析响应"
|
||||
need_replan = "重新规划" in content.lower() or "目标不适合" in content.lower()
|
||||
return is_suitable, reason, need_replan
|
||||
|
||||
# 验证JSON字段
|
||||
suitable = result.get("suitable", None)
|
||||
reason = result.get("reason", "未提供原因")
|
||||
need_replan = result.get("need_replan", False)
|
||||
|
||||
# 如果suitable字段是字符串,转换为布尔值
|
||||
if isinstance(suitable, str):
|
||||
suitable = suitable.lower() == "true"
|
||||
|
||||
# 如果suitable字段不存在或不是布尔值,从reason中判断
|
||||
if suitable is None:
|
||||
suitable = "不合适" not in reason.lower() and "违规" not in reason.lower()
|
||||
|
||||
# 如果不合适且未达到最大重试次数,返回需要重试
|
||||
if not suitable and retry_count < self.max_retries:
|
||||
return False, reason, False
|
||||
|
||||
# 如果不合适且已达到最大重试次数,返回需要重新规划
|
||||
if not suitable and retry_count >= self.max_retries:
|
||||
return False, f"多次重试后仍不合适: {reason}", True
|
||||
|
||||
return suitable, reason, need_replan
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]检查回复时出错: {e}")
|
||||
# 如果出错且已达到最大重试次数,建议重新规划
|
||||
if retry_count >= self.max_retries:
|
||||
return False, "多次检查失败,建议重新规划", True
|
||||
return False, f"检查过程出错,建议重试: {str(e)}", False
|
||||
@@ -1,242 +0,0 @@
|
||||
from typing import Tuple, List, Dict, Any
|
||||
from src.common.logger import get_logger
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config, model_config
|
||||
import random
|
||||
from .chat_observer import ChatObserver
|
||||
from .reply_checker import ReplyChecker
|
||||
from src.services.message_service import build_readable_messages
|
||||
|
||||
from .observation_info import ObservationInfo, dict_to_session_message
|
||||
from .conversation_info import ConversationInfo
|
||||
|
||||
logger = get_logger("reply_generator")
|
||||
|
||||
# --- 定义 Prompt 模板 ---
|
||||
|
||||
# Prompt for direct_reply (首次回复)
|
||||
PROMPT_DIRECT_REPLY = """{persona_text}。现在你在参与一场QQ私聊,请根据以下信息生成一条回复:
|
||||
|
||||
当前对话目标:{goals_str}
|
||||
|
||||
{knowledge_info_str}
|
||||
|
||||
最近的聊天记录:
|
||||
{chat_history_text}
|
||||
|
||||
|
||||
请根据上述信息,结合聊天记录,回复对方。该回复应该:
|
||||
1. 符合对话目标,以"你"的角度发言(不要自己与自己对话!)
|
||||
2. 符合你的性格特征和身份细节
|
||||
3. 通俗易懂,自然流畅,像正常聊天一样,简短(通常20字以内,除非特殊情况)
|
||||
4. 可以适当利用相关知识,但不要生硬引用
|
||||
5. 自然、得体,结合聊天记录逻辑合理,且没有重复表达同质内容
|
||||
|
||||
请注意把握聊天内容,不要回复的太有条理,可以有个性。请分清"你"和对方说的话,不要把"你"说的话当做对方说的话,这是你自己说的话。
|
||||
可以回复得自然随意自然一些,就像真人一样,注意把握聊天内容,整体风格可以平和、简短,不要刻意突出自身学科背景,不要说你说过的话,可以简短,多简短都可以,但是避免冗长。
|
||||
请你注意不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只输出回复内容。
|
||||
不要输出多余内容(包括前后缀,冒号和引号,括号,表情包,at或 @等 )。
|
||||
|
||||
请直接输出回复内容,不需要任何额外格式。"""
|
||||
|
||||
# Prompt for send_new_message (追问/补充)
|
||||
PROMPT_SEND_NEW_MESSAGE = """{persona_text}。现在你在参与一场QQ私聊,**刚刚你已经发送了一条或多条消息**,现在请根据以下信息再发一条新消息:
|
||||
|
||||
当前对话目标:{goals_str}
|
||||
|
||||
{knowledge_info_str}
|
||||
|
||||
最近的聊天记录:
|
||||
{chat_history_text}
|
||||
|
||||
|
||||
请根据上述信息,结合聊天记录,继续发一条新消息(例如对之前消息的补充,深入话题,或追问等等)。该消息应该:
|
||||
1. 符合对话目标,以"你"的角度发言(不要自己与自己对话!)
|
||||
2. 符合你的性格特征和身份细节
|
||||
3. 通俗易懂,自然流畅,像正常聊天一样,简短(通常20字以内,除非特殊情况)
|
||||
4. 可以适当利用相关知识,但不要生硬引用
|
||||
5. 跟之前你发的消息自然的衔接,逻辑合理,且没有重复表达同质内容或部分重叠内容
|
||||
|
||||
请注意把握聊天内容,不用太有条理,可以有个性。请分清"你"和对方说的话,不要把"你"说的话当做对方说的话,这是你自己说的话。
|
||||
这条消息可以自然随意自然一些,就像真人一样,注意把握聊天内容,整体风格可以平和、简短,不要刻意突出自身学科背景,不要说你说过的话,可以简短,多简短都可以,但是避免冗长。
|
||||
请你注意不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只输出消息内容。
|
||||
不要输出多余内容(包括前后缀,冒号和引号,括号,表情包,at或 @等 )。
|
||||
|
||||
请直接输出回复内容,不需要任何额外格式。"""
|
||||
|
||||
# Prompt for say_goodbye (告别语生成)
|
||||
PROMPT_FAREWELL = """{persona_text}。你在参与一场 QQ 私聊,现在对话似乎已经结束,你决定再发一条最后的消息来圆满结束。
|
||||
|
||||
最近的聊天记录:
|
||||
{chat_history_text}
|
||||
|
||||
请根据上述信息,结合聊天记录,构思一条**简短、自然、符合你人设**的最后的消息。
|
||||
这条消息应该:
|
||||
1. 从你自己的角度发言。
|
||||
2. 符合你的性格特征和身份细节。
|
||||
3. 通俗易懂,自然流畅,通常很简短。
|
||||
4. 自然地为这场对话画上句号,避免开启新话题或显得冗长、刻意。
|
||||
|
||||
请像真人一样随意自然,**简洁是关键**。
|
||||
不要输出多余内容(包括前后缀、冒号、引号、括号、表情包、at或@等)。
|
||||
|
||||
请直接输出最终的告别消息内容,不需要任何额外格式。"""
|
||||
|
||||
|
||||
class ReplyGenerator:
|
||||
"""回复生成器"""
|
||||
|
||||
def __init__(self, stream_id: str, private_name: str):
|
||||
self.llm = LLMRequest(
|
||||
model_set=model_config.model_task_config.replyer,
|
||||
request_type="reply_generation",
|
||||
)
|
||||
self.personality_info = self._get_personality_prompt()
|
||||
self.name = global_config.bot.nickname
|
||||
self.private_name = private_name
|
||||
self.chat_observer = ChatObserver.get_instance(stream_id, private_name)
|
||||
self.reply_checker = ReplyChecker(stream_id, private_name)
|
||||
|
||||
def _get_personality_prompt(self) -> str:
|
||||
"""获取个性提示信息"""
|
||||
prompt_personality = global_config.personality.personality
|
||||
|
||||
# 检查是否需要随机替换为状态
|
||||
if (
|
||||
global_config.personality.states
|
||||
and global_config.personality.state_probability > 0
|
||||
and random.random() < global_config.personality.state_probability
|
||||
):
|
||||
prompt_personality = random.choice(global_config.personality.states)
|
||||
|
||||
bot_name = global_config.bot.nickname
|
||||
return f"你的名字是{bot_name},你{prompt_personality};"
|
||||
|
||||
# 修改 generate 方法签名,增加 action_type 参数
|
||||
async def generate(
|
||||
self, observation_info: ObservationInfo, conversation_info: ConversationInfo, action_type: str
|
||||
) -> str:
|
||||
"""生成回复
|
||||
|
||||
Args:
|
||||
observation_info: 观察信息
|
||||
conversation_info: 对话信息
|
||||
action_type: 当前执行的动作类型 ('direct_reply' 或 'send_new_message')
|
||||
|
||||
Returns:
|
||||
str: 生成的回复
|
||||
"""
|
||||
# 构建提示词
|
||||
logger.debug(
|
||||
f"[私聊][{self.private_name}]开始生成回复 (动作类型: {action_type}):当前目标: {conversation_info.goal_list}"
|
||||
)
|
||||
|
||||
# --- 构建通用 Prompt 参数 ---
|
||||
# (这部分逻辑基本不变)
|
||||
|
||||
# 构建对话目标 (goals_str)
|
||||
goals_str = ""
|
||||
if conversation_info.goal_list:
|
||||
for goal_reason in conversation_info.goal_list:
|
||||
if isinstance(goal_reason, dict):
|
||||
goal = goal_reason.get("goal", "目标内容缺失")
|
||||
reasoning = goal_reason.get("reasoning", "没有明确原因")
|
||||
else:
|
||||
goal = str(goal_reason)
|
||||
reasoning = "没有明确原因"
|
||||
|
||||
goal = str(goal) if goal is not None else "目标内容缺失"
|
||||
reasoning = str(reasoning) if reasoning is not None else "没有明确原因"
|
||||
goals_str += f"- 目标:{goal}\n 原因:{reasoning}\n"
|
||||
else:
|
||||
goals_str = "- 目前没有明确对话目标\n" # 简化无目标情况
|
||||
|
||||
# --- 新增:构建知识信息字符串 ---
|
||||
knowledge_info_str = "【供参考的相关知识和记忆】\n" # 稍微改下标题,表明是供参考
|
||||
try:
|
||||
# 检查 conversation_info 是否有 knowledge_list 并且不为空
|
||||
if hasattr(conversation_info, "knowledge_list") and conversation_info.knowledge_list:
|
||||
# 最多只显示最近的 5 条知识
|
||||
recent_knowledge = conversation_info.knowledge_list[-5:]
|
||||
for i, knowledge_item in enumerate(recent_knowledge):
|
||||
if isinstance(knowledge_item, dict):
|
||||
query = knowledge_item.get("query", "未知查询")
|
||||
knowledge = knowledge_item.get("knowledge", "无知识内容")
|
||||
source = knowledge_item.get("source", "未知来源")
|
||||
# 只取知识内容的前 2000 个字
|
||||
knowledge_snippet = f"{knowledge[:2000]}..." if len(knowledge) > 2000 else knowledge
|
||||
knowledge_info_str += (
|
||||
f"{i + 1}. 关于 '{query}' (来源: {source}): {knowledge_snippet}\n" # 格式微调,更简洁
|
||||
)
|
||||
else:
|
||||
knowledge_info_str += f"{i + 1}. 发现一条格式不正确的知识记录。\n"
|
||||
|
||||
if not recent_knowledge:
|
||||
knowledge_info_str += "- 暂无。\n" # 更简洁的提示
|
||||
|
||||
else:
|
||||
knowledge_info_str += "- 暂无。\n"
|
||||
except AttributeError:
|
||||
logger.warning(f"[私聊][{self.private_name}]ConversationInfo 对象可能缺少 knowledge_list 属性。")
|
||||
knowledge_info_str += "- 获取知识列表时出错。\n"
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]构建知识信息字符串时出错: {e}")
|
||||
knowledge_info_str += "- 处理知识列表时出错。\n"
|
||||
|
||||
# 获取聊天历史记录 (chat_history_text)
|
||||
chat_history_text = observation_info.chat_history_str
|
||||
if observation_info.new_messages_count > 0 and observation_info.unprocessed_messages:
|
||||
new_messages_list = observation_info.unprocessed_messages
|
||||
session_messages = [dict_to_session_message(m) for m in new_messages_list]
|
||||
new_messages_str = build_readable_messages(
|
||||
session_messages,
|
||||
replace_bot_name=True,
|
||||
timestamp_mode="relative",
|
||||
read_mark=0.0,
|
||||
)
|
||||
chat_history_text += f"\n--- 以下是 {observation_info.new_messages_count} 条新消息 ---\n{new_messages_str}"
|
||||
elif not chat_history_text:
|
||||
chat_history_text = "还没有聊天记录。"
|
||||
|
||||
# 构建 Persona 文本 (persona_text)
|
||||
persona_text = f"你的名字是{self.name},{self.personality_info}。"
|
||||
|
||||
# --- 选择 Prompt ---
|
||||
if action_type == "send_new_message":
|
||||
prompt_template = PROMPT_SEND_NEW_MESSAGE
|
||||
logger.info(f"[私聊][{self.private_name}]使用 PROMPT_SEND_NEW_MESSAGE (追问生成)")
|
||||
elif action_type == "say_goodbye": # 处理告别动作
|
||||
prompt_template = PROMPT_FAREWELL
|
||||
logger.info(f"[私聊][{self.private_name}]使用 PROMPT_FAREWELL (告别语生成)")
|
||||
else: # 默认使用 direct_reply 的 prompt (包括 'direct_reply' 或其他未明确处理的类型)
|
||||
prompt_template = PROMPT_DIRECT_REPLY
|
||||
logger.info(f"[私聊][{self.private_name}]使用 PROMPT_DIRECT_REPLY (首次/非连续回复生成)")
|
||||
|
||||
# --- 格式化最终的 Prompt ---
|
||||
prompt = prompt_template.format(
|
||||
persona_text=persona_text,
|
||||
goals_str=goals_str,
|
||||
chat_history_text=chat_history_text,
|
||||
knowledge_info_str=knowledge_info_str,
|
||||
)
|
||||
|
||||
# --- 调用 LLM 生成 ---
|
||||
logger.debug(f"[私聊][{self.private_name}]发送到LLM的生成提示词:\n------\n{prompt}\n------")
|
||||
try:
|
||||
content, _ = await self.llm.generate_response_async(prompt)
|
||||
logger.debug(f"[私聊][{self.private_name}]生成的回复: {content}")
|
||||
# 移除旧的检查新消息逻辑,这应该由 conversation 控制流处理
|
||||
return content
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]生成回复时出错: {e}")
|
||||
return "抱歉,我现在有点混乱,让我重新思考一下..."
|
||||
|
||||
# check_reply 方法保持不变
|
||||
async def check_reply(
|
||||
self, reply: str, goal: str, chat_history: List[Dict[str, Any]], chat_history_str: str, retry_count: int = 0
|
||||
) -> Tuple[bool, str, bool]:
|
||||
"""检查回复是否合适
|
||||
(此方法逻辑保持不变)
|
||||
"""
|
||||
return await self.reply_checker.check(reply, goal, chat_history, chat_history_str, retry_count)
|
||||
@@ -1,79 +0,0 @@
|
||||
from src.common.logger import get_logger
|
||||
from .chat_observer import ChatObserver
|
||||
from .conversation_info import ConversationInfo
|
||||
|
||||
# from src.individuality.individuality import Individuality # 不再需要
|
||||
from src.config.config import global_config
|
||||
import time
|
||||
import asyncio
|
||||
|
||||
logger = get_logger("waiter")
|
||||
|
||||
# --- 在这里设定你想要的超时时间(秒) ---
|
||||
# 例如: 120 秒 = 2 分钟
|
||||
DESIRED_TIMEOUT_SECONDS = 300
|
||||
|
||||
|
||||
class Waiter:
|
||||
"""等待处理类"""
|
||||
|
||||
def __init__(self, stream_id: str, private_name: str):
|
||||
self.chat_observer = ChatObserver.get_instance(stream_id, private_name)
|
||||
self.name = global_config.bot.nickname
|
||||
self.private_name = private_name
|
||||
# self.wait_accumulated_time = 0 # 不再需要累加计时
|
||||
|
||||
async def wait(self, conversation_info: ConversationInfo) -> bool:
|
||||
"""等待用户新消息或超时"""
|
||||
wait_start_time = time.time()
|
||||
logger.info(f"[私聊][{self.private_name}]进入常规等待状态 (超时: {DESIRED_TIMEOUT_SECONDS} 秒)...")
|
||||
|
||||
while True:
|
||||
# 检查是否有新消息
|
||||
if self.chat_observer.new_message_after(wait_start_time):
|
||||
logger.info(f"[私聊][{self.private_name}]等待结束,收到新消息")
|
||||
return False # 返回 False 表示不是超时
|
||||
|
||||
# 检查是否超时
|
||||
elapsed_time = time.time() - wait_start_time
|
||||
if elapsed_time > DESIRED_TIMEOUT_SECONDS:
|
||||
logger.info(f"[私聊][{self.private_name}]等待超过 {DESIRED_TIMEOUT_SECONDS} 秒...添加思考目标。")
|
||||
wait_goal = {
|
||||
"goal": f"你等待了{elapsed_time / 60:.1f}分钟,注意可能在对方看来聊天已经结束,思考接下来要做什么",
|
||||
"reasoning": "对方很久没有回复你的消息了",
|
||||
}
|
||||
conversation_info.goal_list.append(wait_goal)
|
||||
logger.info(f"[私聊][{self.private_name}]添加目标: {wait_goal}")
|
||||
return True # 返回 True 表示超时
|
||||
|
||||
await asyncio.sleep(5) # 每 5 秒检查一次
|
||||
logger.debug(
|
||||
f"[私聊][{self.private_name}]等待中..."
|
||||
) # 可以考虑把这个频繁日志注释掉,只在超时或收到消息时输出
|
||||
|
||||
async def wait_listening(self, conversation_info: ConversationInfo) -> bool:
|
||||
"""倾听用户发言或超时"""
|
||||
wait_start_time = time.time()
|
||||
logger.info(f"[私聊][{self.private_name}]进入倾听等待状态 (超时: {DESIRED_TIMEOUT_SECONDS} 秒)...")
|
||||
|
||||
while True:
|
||||
# 检查是否有新消息
|
||||
if self.chat_observer.new_message_after(wait_start_time):
|
||||
logger.info(f"[私聊][{self.private_name}]倾听等待结束,收到新消息")
|
||||
return False # 返回 False 表示不是超时
|
||||
|
||||
# 检查是否超时
|
||||
elapsed_time = time.time() - wait_start_time
|
||||
if elapsed_time > DESIRED_TIMEOUT_SECONDS:
|
||||
logger.info(f"[私聊][{self.private_name}]倾听等待超过 {DESIRED_TIMEOUT_SECONDS} 秒...添加思考目标。")
|
||||
wait_goal = {
|
||||
# 保持 goal 文本一致
|
||||
"goal": f"你等待了{elapsed_time / 60:.1f}分钟,对方似乎话说一半突然消失了,可能忙去了?也可能忘记了回复?要问问吗?还是结束对话?或继续等待?思考接下来要做什么",
|
||||
"reasoning": "对方话说一半消失了,很久没有回复",
|
||||
}
|
||||
conversation_info.goal_list.append(wait_goal)
|
||||
logger.info(f"[私聊][{self.private_name}]添加目标: {wait_goal}")
|
||||
return True # 返回 True 表示超时
|
||||
|
||||
await asyncio.sleep(5) # 每 5 秒检查一次
|
||||
logger.debug(f"[私聊][{self.private_name}]倾听等待中...") # 同上,可以考虑注释掉
|
||||
@@ -1,800 +0,0 @@
|
||||
import asyncio
|
||||
import random
|
||||
import time
|
||||
import traceback
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
||||
from rich.traceback import install
|
||||
|
||||
from src.config.config import global_config
|
||||
from src.common.logger import get_logger
|
||||
from src.common.utils.utils_config import ExpressionConfigUtils
|
||||
from src.learners.expression_learner import ExpressionLearner
|
||||
from src.learners.jargon_miner import JargonMiner
|
||||
from src.chat.message_receive.chat_manager import BotChatSession
|
||||
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
from src.common.data_models.info_data_model import ActionPlannerInfo
|
||||
from src.common.data_models.message_component_data_model import MessageSequence, TextComponent
|
||||
from src.chat.utils.prompt_builder import global_prompt_manager
|
||||
from src.chat.utils.timer_calculator import Timer
|
||||
from src.chat.brain_chat.brain_planner import BrainPlanner
|
||||
from src.chat.planner_actions.action_modifier import ActionModifier
|
||||
from src.chat.planner_actions.action_manager import ActionManager
|
||||
from src.chat.heart_flow.heartFC_utils import CycleDetail
|
||||
from src.person_info.person_info import Person
|
||||
from src.core.types import ActionInfo, EventType
|
||||
from src.core.event_bus import event_bus
|
||||
from src.chat.event_helpers import build_event_message
|
||||
from src.services import (
|
||||
generator_service as generator_api,
|
||||
send_service as send_api,
|
||||
message_service as message_api,
|
||||
database_service as database_api,
|
||||
)
|
||||
from src.services.message_service import build_readable_messages_with_id, get_messages_before_time_in_chat
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
|
||||
|
||||
ERROR_LOOP_INFO = {
|
||||
"loop_plan_info": {
|
||||
"action_result": {
|
||||
"action_type": "error",
|
||||
"action_data": {},
|
||||
"reasoning": "循环处理失败",
|
||||
},
|
||||
},
|
||||
"loop_action_info": {
|
||||
"action_taken": False,
|
||||
"reply_text": "",
|
||||
"command": "",
|
||||
"taken_time": time.time(),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
# 注释:原来的动作修改超时常量已移除,因为改为顺序执行
|
||||
|
||||
logger = get_logger("bc") # Logger Name Changed
|
||||
|
||||
|
||||
class BrainChatting:
|
||||
"""
|
||||
管理一个连续的私聊Brain Chat循环
|
||||
用于在特定聊天流中生成回复。
|
||||
"""
|
||||
|
||||
def __init__(self, session_id: str):
|
||||
"""
|
||||
BrainChatting 初始化函数
|
||||
|
||||
参数:
|
||||
chat_id: 聊天流唯一标识符(如stream_id)
|
||||
on_stop_focus_chat: 当收到stop_focus_chat命令时调用的回调函数
|
||||
performance_version: 性能记录版本号,用于区分不同启动版本
|
||||
"""
|
||||
# 基础属性
|
||||
self.stream_id: str = session_id # 聊天流ID
|
||||
self.chat_stream: BotChatSession = _chat_manager.get_session_by_session_id(self.stream_id) # type: ignore[assignment]
|
||||
if not self.chat_stream:
|
||||
raise ValueError(f"无法找到聊天流: {self.stream_id}")
|
||||
self.log_prefix = f"[{_chat_manager.get_session_name(self.stream_id) or self.stream_id}]"
|
||||
|
||||
expr_use, jargon_learn, expr_learn = ExpressionConfigUtils.get_expression_config_for_chat(self.stream_id)
|
||||
self._enable_expression_use = expr_use
|
||||
self._enable_expression_learning = expr_learn
|
||||
self._enable_jargon_learning = jargon_learn
|
||||
self._expression_learner = ExpressionLearner(self.stream_id)
|
||||
self._jargon_miner = JargonMiner(self.stream_id, session_name=self.log_prefix.strip("[]"))
|
||||
self._min_messages_for_extraction = 30
|
||||
self._min_extraction_interval = 60
|
||||
self._last_extraction_time = 0.0
|
||||
|
||||
self.action_manager = ActionManager()
|
||||
self.action_planner = BrainPlanner(chat_id=self.stream_id, action_manager=self.action_manager)
|
||||
self.action_modifier = ActionModifier(action_manager=self.action_manager, chat_id=self.stream_id)
|
||||
|
||||
# 循环控制内部状态
|
||||
self.running: bool = False
|
||||
self._loop_task: Optional[asyncio.Task] = None # 主循环任务
|
||||
self._new_message_event = asyncio.Event() # 新消息事件,用于打断 wait
|
||||
|
||||
# 添加循环信息管理相关的属性
|
||||
self.history_loop: List[CycleDetail] = []
|
||||
self._cycle_counter = 0
|
||||
self._current_cycle_detail: CycleDetail = None # type: ignore
|
||||
|
||||
self.last_read_time = time.time() - 2
|
||||
|
||||
self.more_plan = False
|
||||
|
||||
# 最近一次是否成功进行了 reply,用于选择 BrainPlanner 的 Prompt
|
||||
self._last_successful_reply: bool = False
|
||||
|
||||
async def start(self):
|
||||
"""检查是否需要启动主循环,如果未激活则启动。"""
|
||||
|
||||
# 如果循环已经激活,直接返回
|
||||
if self.running:
|
||||
logger.debug(f"{self.log_prefix} BrainChatting 已激活,无需重复启动")
|
||||
return
|
||||
|
||||
try:
|
||||
# 标记为活动状态,防止重复启动
|
||||
self.running = True
|
||||
|
||||
self._loop_task = asyncio.create_task(self._main_chat_loop())
|
||||
self._loop_task.add_done_callback(self._handle_loop_completion)
|
||||
logger.info(f"{self.log_prefix} BrainChatting 启动完成")
|
||||
|
||||
except Exception as e:
|
||||
# 启动失败时重置状态
|
||||
self.running = False
|
||||
self._loop_task = None
|
||||
logger.error(f"{self.log_prefix} BrainChatting 启动失败: {e}")
|
||||
raise
|
||||
|
||||
def _handle_loop_completion(self, task: asyncio.Task):
|
||||
"""当 _hfc_loop 任务完成时执行的回调。"""
|
||||
try:
|
||||
if exception := task.exception():
|
||||
logger.error(f"{self.log_prefix} BrainChatting: 脱离了聊天(异常): {exception}")
|
||||
logger.error(traceback.format_exc()) # Log full traceback for exceptions
|
||||
else:
|
||||
logger.info(f"{self.log_prefix} BrainChatting: 脱离了聊天 (外部停止)")
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"{self.log_prefix} BrainChatting: 结束了聊天")
|
||||
|
||||
def start_cycle(self) -> Tuple[Dict[str, float], str]:
|
||||
self._cycle_counter += 1
|
||||
self._current_cycle_detail = CycleDetail(self._cycle_counter)
|
||||
self._current_cycle_detail.thinking_id = f"tid{str(round(time.time(), 2))}"
|
||||
cycle_timers = {}
|
||||
return cycle_timers, self._current_cycle_detail.thinking_id
|
||||
|
||||
def end_cycle(self, loop_info, cycle_timers):
|
||||
self._current_cycle_detail.set_loop_info(loop_info)
|
||||
self.history_loop.append(self._current_cycle_detail)
|
||||
self._current_cycle_detail.timers = cycle_timers
|
||||
self._current_cycle_detail.end_time = time.time()
|
||||
|
||||
def print_cycle_info(self, cycle_timers):
|
||||
# 记录循环信息和计时器结果
|
||||
timer_strings = []
|
||||
for name, elapsed in cycle_timers.items():
|
||||
formatted_time = f"{elapsed * 1000:.2f}毫秒" if elapsed < 1 else f"{elapsed:.2f}秒"
|
||||
timer_strings.append(f"{name}: {formatted_time}")
|
||||
|
||||
logger.info(
|
||||
f"{self.log_prefix} 第{self._current_cycle_detail.cycle_id}次思考,"
|
||||
f"耗时: {self._current_cycle_detail.end_time - self._current_cycle_detail.start_time:.1f}秒" # type: ignore
|
||||
+ (f"\n详情: {'; '.join(timer_strings)}" if timer_strings else "")
|
||||
)
|
||||
|
||||
async def _trigger_expression_learning(self, messages: List[SessionMessage]) -> None:
|
||||
if not messages:
|
||||
return
|
||||
|
||||
self._expression_learner.add_messages(messages)
|
||||
if time.time() - self._last_extraction_time < self._min_extraction_interval:
|
||||
return
|
||||
if self._expression_learner.get_cache_size() < self._min_messages_for_extraction:
|
||||
return
|
||||
if not self._enable_expression_learning:
|
||||
return
|
||||
|
||||
self._last_extraction_time = time.time()
|
||||
try:
|
||||
jargon_miner = self._jargon_miner if self._enable_jargon_learning else None
|
||||
await self._expression_learner.learn(jargon_miner)
|
||||
except Exception as exc:
|
||||
logger.error(f"{self.log_prefix} 表达学习失败: {exc}", exc_info=True)
|
||||
|
||||
async def _loopbody(self): # sourcery skip: hoist-if-from-if
|
||||
# 获取最新消息(用于上下文,但不影响是否调用 observe)
|
||||
recent_messages_list = message_api.get_messages_by_time_in_chat(
|
||||
chat_id=self.stream_id,
|
||||
start_time=self.last_read_time,
|
||||
end_time=time.time(),
|
||||
limit=20,
|
||||
limit_mode="latest",
|
||||
filter_mai=True,
|
||||
filter_command=False,
|
||||
filter_intercept_message_level=1,
|
||||
)
|
||||
|
||||
# 如果有新消息,更新 last_read_time 并触发事件以打断正在进行的 wait
|
||||
if len(recent_messages_list) >= 1:
|
||||
self.last_read_time = time.time()
|
||||
self._new_message_event.set() # 触发新消息事件,打断 wait
|
||||
|
||||
# 总是执行一次思考迭代(不管有没有新消息)
|
||||
# wait 动作会在其内部等待,不需要在这里处理
|
||||
should_continue = await self._observe(recent_messages_list=recent_messages_list)
|
||||
|
||||
if not should_continue:
|
||||
# 选择了 complete_talk,返回 False 表示需要等待新消息
|
||||
return False
|
||||
|
||||
# 继续下一次迭代(除非选择了 complete_talk)
|
||||
# 短暂等待后再继续,避免过于频繁的循环
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
return True
|
||||
|
||||
async def _send_and_store_reply(
|
||||
self,
|
||||
response_set: MessageSequence,
|
||||
action_message: SessionMessage,
|
||||
cycle_timers: Dict[str, float],
|
||||
thinking_id,
|
||||
actions,
|
||||
selected_expressions: Optional[List[int]] = None,
|
||||
) -> Tuple[Dict[str, Any], str, Dict[str, float]]:
|
||||
with Timer("回复发送", cycle_timers):
|
||||
reply_text = await self._send_response(
|
||||
reply_set=response_set,
|
||||
message_data=action_message,
|
||||
selected_expressions=selected_expressions,
|
||||
)
|
||||
|
||||
# 获取 platform,如果不存在则从 chat_stream 获取,如果还是 None 则使用默认值
|
||||
platform = action_message.platform
|
||||
if platform is None:
|
||||
platform = getattr(self.chat_stream, "platform", "unknown")
|
||||
|
||||
person = Person(platform=platform, user_id=action_message.message_info.user_info.user_id)
|
||||
person_name = person.person_name
|
||||
action_prompt_display = f"你对{person_name}进行了回复:{reply_text}"
|
||||
|
||||
await database_api.store_action_info(
|
||||
chat_stream=self.chat_stream,
|
||||
display_prompt=action_prompt_display,
|
||||
thinking_id=thinking_id,
|
||||
action_data={"reply_text": reply_text},
|
||||
action_name="reply",
|
||||
)
|
||||
|
||||
# 构建循环信息
|
||||
loop_info: Dict[str, Any] = {
|
||||
"loop_plan_info": {
|
||||
"action_result": actions,
|
||||
},
|
||||
"loop_action_info": {
|
||||
"action_taken": True,
|
||||
"reply_text": reply_text,
|
||||
"command": "",
|
||||
"taken_time": time.time(),
|
||||
},
|
||||
}
|
||||
|
||||
return loop_info, reply_text, cycle_timers
|
||||
|
||||
async def _observe(
|
||||
self, # interest_value: float = 0.0,
|
||||
recent_messages_list: Optional[List[SessionMessage]] = None,
|
||||
) -> bool: # sourcery skip: merge-else-if-into-elif, remove-redundant-if
|
||||
if recent_messages_list is None:
|
||||
recent_messages_list = []
|
||||
_reply_text = "" # 初始化reply_text变量,避免UnboundLocalError
|
||||
|
||||
async with global_prompt_manager.async_message_scope(self.chat_stream.context.get_template_name()):
|
||||
if recent_messages_list:
|
||||
asyncio.create_task(self._trigger_expression_learning(recent_messages_list))
|
||||
|
||||
cycle_timers, thinking_id = self.start_cycle()
|
||||
logger.info(f"{self.log_prefix} 开始第{self._cycle_counter}次思考")
|
||||
|
||||
# 第一步:动作检查
|
||||
available_actions: Dict[str, ActionInfo] = {}
|
||||
try:
|
||||
await self.action_modifier.modify_actions()
|
||||
available_actions = self.action_manager.get_using_actions()
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 动作修改失败: {e}")
|
||||
|
||||
# 获取必要信息
|
||||
is_group_chat, chat_target_info, _ = self.action_planner.get_necessary_info()
|
||||
|
||||
# 一次思考迭代:Think - Act - Observe
|
||||
# 获取聊天上下文
|
||||
message_list_before_now = get_messages_before_time_in_chat(
|
||||
chat_id=self.stream_id,
|
||||
timestamp=time.time(),
|
||||
limit=int(global_config.chat.max_context_size * 0.6),
|
||||
filter_intercept_message_level=1,
|
||||
)
|
||||
chat_content_block, message_id_list = build_readable_messages_with_id(
|
||||
messages=message_list_before_now,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
read_mark=self.action_planner.last_obs_time_mark,
|
||||
truncate=True,
|
||||
show_actions=True,
|
||||
)
|
||||
|
||||
prompt_info = await self.action_planner.build_planner_prompt(
|
||||
chat_target_info=chat_target_info,
|
||||
current_available_actions=available_actions,
|
||||
chat_content_block=chat_content_block,
|
||||
message_id_list=message_id_list,
|
||||
prompt_key="brain_planner",
|
||||
)
|
||||
_event_msg = build_event_message(
|
||||
EventType.ON_PLAN, llm_prompt=prompt_info[0], stream_id=self.chat_stream.session_id
|
||||
)
|
||||
continue_flag, modified_message = await event_bus.emit(EventType.ON_PLAN, _event_msg)
|
||||
if not continue_flag:
|
||||
return False
|
||||
if modified_message and modified_message._modify_flags.modify_llm_prompt:
|
||||
prompt_info = (modified_message.llm_prompt, prompt_info[1])
|
||||
|
||||
with Timer("规划器", cycle_timers):
|
||||
action_to_use_info = await self.action_planner.plan(
|
||||
loop_start_time=self.last_read_time,
|
||||
available_actions=available_actions,
|
||||
)
|
||||
|
||||
# 检查是否有 complete_talk 动作(会停止后续迭代)
|
||||
has_complete_talk = any(action.action_type == "complete_talk" for action in action_to_use_info)
|
||||
|
||||
# 并行执行所有动作
|
||||
action_tasks = [
|
||||
asyncio.create_task(
|
||||
self._execute_action(action, action_to_use_info, thinking_id, available_actions, cycle_timers)
|
||||
)
|
||||
for action in action_to_use_info
|
||||
]
|
||||
|
||||
# 并行执行所有任务
|
||||
results = await asyncio.gather(*action_tasks, return_exceptions=True)
|
||||
|
||||
# 处理执行结果
|
||||
reply_loop_info = None
|
||||
reply_text_from_reply = ""
|
||||
action_success = False
|
||||
action_reply_text = ""
|
||||
|
||||
for result in results:
|
||||
if isinstance(result, BaseException):
|
||||
logger.error(f"{self.log_prefix} 动作执行异常: {result}")
|
||||
continue
|
||||
|
||||
if result["action_type"] != "reply":
|
||||
action_success = result["success"]
|
||||
action_reply_text = result["reply_text"]
|
||||
elif result["action_type"] == "reply":
|
||||
if result["success"]:
|
||||
reply_loop_info = result["loop_info"]
|
||||
reply_text_from_reply = result["reply_text"]
|
||||
else:
|
||||
logger.warning(f"{self.log_prefix} 回复动作执行失败")
|
||||
|
||||
# 更新观察时间标记
|
||||
self.action_planner.last_obs_time_mark = time.time()
|
||||
|
||||
# 如果选择了 complete_talk,标记为完成,不再继续迭代
|
||||
if has_complete_talk:
|
||||
logger.info(f"{self.log_prefix} 检测到 complete_talk 动作,本次思考完成")
|
||||
|
||||
# 构建循环信息
|
||||
if reply_loop_info:
|
||||
# 如果有回复信息,使用回复的loop_info作为基础
|
||||
loop_info = reply_loop_info
|
||||
# 更新动作执行信息
|
||||
loop_info["loop_action_info"].update(
|
||||
{
|
||||
"action_taken": action_success,
|
||||
"taken_time": time.time(),
|
||||
}
|
||||
)
|
||||
_reply_text = reply_text_from_reply
|
||||
else:
|
||||
# 没有回复信息,构建纯动作的loop_info
|
||||
loop_info = {
|
||||
"loop_plan_info": {
|
||||
"action_result": action_to_use_info,
|
||||
},
|
||||
"loop_action_info": {
|
||||
"action_taken": action_success,
|
||||
"reply_text": action_reply_text,
|
||||
"taken_time": time.time(),
|
||||
},
|
||||
}
|
||||
_reply_text = action_reply_text
|
||||
|
||||
# 如果选择了 complete_talk,返回 False 以停止 _loopbody 的循环
|
||||
# 否则返回 True,让 _loopbody 继续下一次迭代
|
||||
should_continue = not has_complete_talk
|
||||
|
||||
self.end_cycle(loop_info, cycle_timers)
|
||||
self.print_cycle_info(cycle_timers)
|
||||
|
||||
# 如果选择了 complete_talk,返回 False 停止循环
|
||||
# 否则返回 True,继续下一次思考迭代
|
||||
return should_continue
|
||||
|
||||
async def _main_chat_loop(self):
|
||||
"""主循环,持续进行计划并可能回复消息,直到被外部取消。"""
|
||||
try:
|
||||
while self.running:
|
||||
# 主循环
|
||||
success = await self._loopbody()
|
||||
if not success:
|
||||
# 选择了 complete,等待新消息
|
||||
logger.info(f"{self.log_prefix} 选择了 complete,等待新消息...")
|
||||
await self._wait_for_new_message()
|
||||
# 有新消息后继续循环
|
||||
continue
|
||||
await asyncio.sleep(0.1)
|
||||
except asyncio.CancelledError:
|
||||
# 设置了关闭标志位后被取消是正常流程
|
||||
logger.info(f"{self.log_prefix} 麦麦已关闭聊天")
|
||||
except Exception:
|
||||
logger.error(f"{self.log_prefix} 麦麦聊天意外错误,将于3s后尝试重新启动")
|
||||
print(traceback.format_exc())
|
||||
await asyncio.sleep(3)
|
||||
self._loop_task = asyncio.create_task(self._main_chat_loop())
|
||||
logger.error(f"{self.log_prefix} 结束了当前聊天循环")
|
||||
|
||||
async def _wait_for_new_message(self):
|
||||
"""等待新消息到达"""
|
||||
last_check_time = self.last_read_time
|
||||
check_interval = 1.0 # 每秒检查一次
|
||||
|
||||
# 清除事件状态,准备等待新消息
|
||||
self._new_message_event.clear()
|
||||
|
||||
while self.running:
|
||||
# 检查是否有新消息
|
||||
recent_messages_list = message_api.get_messages_by_time_in_chat(
|
||||
chat_id=self.stream_id,
|
||||
start_time=last_check_time,
|
||||
end_time=time.time(),
|
||||
limit=20,
|
||||
limit_mode="latest",
|
||||
filter_mai=True,
|
||||
filter_command=False,
|
||||
filter_intercept_message_level=1,
|
||||
)
|
||||
|
||||
# 如果有新消息,更新 last_read_time 并返回
|
||||
if len(recent_messages_list) >= 1:
|
||||
self.last_read_time = time.time()
|
||||
logger.info(f"{self.log_prefix} 检测到新消息,恢复循环")
|
||||
return
|
||||
|
||||
# 等待新消息事件或超时后再次检查
|
||||
try:
|
||||
await asyncio.wait_for(self._new_message_event.wait(), timeout=check_interval)
|
||||
# 事件被触发,说明有新消息
|
||||
logger.info(f"{self.log_prefix} 检测到新消息事件,恢复循环")
|
||||
return
|
||||
except asyncio.TimeoutError:
|
||||
# 超时后继续检查
|
||||
continue
|
||||
|
||||
async def _handle_action(
|
||||
self,
|
||||
action: str,
|
||||
reasoning: str,
|
||||
action_data: dict,
|
||||
cycle_timers: Dict[str, float],
|
||||
thinking_id: str,
|
||||
action_message: Optional[SessionMessage] = None,
|
||||
) -> tuple[bool, str, str]:
|
||||
"""
|
||||
处理规划动作,使用动作工厂创建相应的动作处理器
|
||||
|
||||
参数:
|
||||
action: 动作类型
|
||||
reasoning: 决策理由
|
||||
action_data: 动作数据,包含不同动作需要的参数
|
||||
cycle_timers: 计时器字典
|
||||
thinking_id: 思考ID
|
||||
|
||||
返回:
|
||||
tuple[bool, str, str]: (是否执行了动作, 思考消息ID, 命令)
|
||||
"""
|
||||
try:
|
||||
# 使用工厂创建动作处理器实例
|
||||
try:
|
||||
action_handler = self.action_manager.create_action(
|
||||
action_name=action,
|
||||
action_data=action_data,
|
||||
action_reasoning=reasoning,
|
||||
cycle_timers=cycle_timers,
|
||||
thinking_id=thinking_id,
|
||||
chat_stream=self.chat_stream,
|
||||
log_prefix=self.log_prefix,
|
||||
action_message=action_message,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 创建动作处理器时出错: {e}")
|
||||
traceback.print_exc()
|
||||
return False, "", ""
|
||||
|
||||
if not action_handler:
|
||||
logger.warning(f"{self.log_prefix} 未能创建动作处理器: {action}")
|
||||
return False, "", ""
|
||||
|
||||
# 处理动作并获取结果(固定记录一次动作信息)
|
||||
# BaseAction 定义了异步方法 execute() 作为统一执行入口
|
||||
# 这里调用 execute() 以兼容所有 Action 实现
|
||||
result = await action_handler.execute()
|
||||
success, action_text = result
|
||||
command = ""
|
||||
|
||||
return success, action_text, command
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 处理{action}时出错: {e}")
|
||||
traceback.print_exc()
|
||||
return False, "", ""
|
||||
|
||||
async def _send_response(
|
||||
self,
|
||||
reply_set: MessageSequence,
|
||||
message_data: SessionMessage,
|
||||
selected_expressions: Optional[List[int]] = None,
|
||||
) -> str:
|
||||
new_message_count = message_api.count_new_messages(
|
||||
chat_id=self.chat_stream.session_id, start_time=self.last_read_time, end_time=time.time()
|
||||
)
|
||||
|
||||
need_reply = new_message_count >= random.randint(2, 4)
|
||||
|
||||
if need_reply:
|
||||
logger.info(f"{self.log_prefix} 从思考到回复,共有{new_message_count}条新消息,使用引用回复")
|
||||
|
||||
reply_text = ""
|
||||
first_replied = False
|
||||
for component in reply_set.components:
|
||||
if not isinstance(component, TextComponent):
|
||||
continue
|
||||
data = component.text
|
||||
if not first_replied:
|
||||
await send_api.text_to_stream(
|
||||
text=data,
|
||||
stream_id=self.chat_stream.session_id,
|
||||
reply_message=message_data,
|
||||
set_reply=need_reply,
|
||||
typing=False,
|
||||
selected_expressions=selected_expressions,
|
||||
)
|
||||
first_replied = True
|
||||
else:
|
||||
await send_api.text_to_stream(
|
||||
text=data,
|
||||
stream_id=self.chat_stream.session_id,
|
||||
reply_message=message_data,
|
||||
set_reply=False,
|
||||
typing=True,
|
||||
selected_expressions=selected_expressions,
|
||||
)
|
||||
reply_text += data
|
||||
|
||||
return reply_text
|
||||
|
||||
async def _execute_action(
|
||||
self,
|
||||
action_planner_info: ActionPlannerInfo,
|
||||
chosen_action_plan_infos: List[ActionPlannerInfo],
|
||||
thinking_id: str,
|
||||
available_actions: Dict[str, ActionInfo],
|
||||
cycle_timers: Dict[str, float],
|
||||
):
|
||||
"""执行单个动作的通用函数"""
|
||||
try:
|
||||
with Timer(f"动作{action_planner_info.action_type}", cycle_timers):
|
||||
if action_planner_info.action_type == "complete_talk":
|
||||
# 直接处理complete_talk逻辑,不再通过动作系统
|
||||
reason = action_planner_info.reasoning or "选择完成对话"
|
||||
logger.info(f"{self.log_prefix} 选择完成对话,原因: {reason}")
|
||||
|
||||
# 存储complete_talk信息到数据库
|
||||
await database_api.store_action_info(
|
||||
chat_stream=self.chat_stream,
|
||||
display_prompt=reason,
|
||||
thinking_id=thinking_id,
|
||||
action_data={"reason": reason},
|
||||
action_name="complete_talk",
|
||||
)
|
||||
return {"action_type": "complete_talk", "success": True, "reply_text": "", "command": ""}
|
||||
|
||||
elif action_planner_info.action_type == "reply":
|
||||
try:
|
||||
# 从 Planner 的 action_data 中提取未知词语列表(仅在 reply 时使用)
|
||||
unknown_words = None
|
||||
if isinstance(action_planner_info.action_data, dict):
|
||||
uw = action_planner_info.action_data.get("unknown_words")
|
||||
if isinstance(uw, list):
|
||||
cleaned_uw: List[str] = []
|
||||
for item in uw:
|
||||
if isinstance(item, str):
|
||||
if stripped_item := item.strip():
|
||||
cleaned_uw.append(stripped_item)
|
||||
if cleaned_uw:
|
||||
unknown_words = cleaned_uw
|
||||
|
||||
success, llm_response = await generator_api.generate_reply(
|
||||
chat_stream=self.chat_stream,
|
||||
reply_message=action_planner_info.action_message,
|
||||
available_actions=available_actions,
|
||||
chosen_actions=chosen_action_plan_infos,
|
||||
reply_reason=action_planner_info.reasoning or "",
|
||||
unknown_words=unknown_words,
|
||||
enable_tool=global_config.tool.enable_tool,
|
||||
request_type="replyer",
|
||||
from_plugin=False,
|
||||
)
|
||||
|
||||
if not success or not llm_response or not llm_response.reply_set:
|
||||
if action_planner_info.action_message:
|
||||
logger.info(
|
||||
f"对 {action_planner_info.action_message.processed_plain_text} 的回复生成失败"
|
||||
)
|
||||
else:
|
||||
logger.info("回复生成失败")
|
||||
return {
|
||||
"action_type": "reply",
|
||||
"success": False,
|
||||
"reply_text": "",
|
||||
"loop_info": None,
|
||||
}
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.debug(f"{self.log_prefix} 并行执行:回复生成任务已被取消")
|
||||
return {"action_type": "reply", "success": False, "reply_text": "", "loop_info": None}
|
||||
|
||||
response_set = llm_response.reply_set
|
||||
selected_expressions = llm_response.selected_expressions
|
||||
loop_info, reply_text, _ = await self._send_and_store_reply(
|
||||
response_set=response_set,
|
||||
action_message=action_planner_info.action_message, # type: ignore
|
||||
cycle_timers=cycle_timers,
|
||||
thinking_id=thinking_id,
|
||||
actions=chosen_action_plan_infos,
|
||||
selected_expressions=selected_expressions,
|
||||
)
|
||||
# 标记这次循环已经成功进行了回复
|
||||
self._last_successful_reply = True
|
||||
return {
|
||||
"action_type": "reply",
|
||||
"success": True,
|
||||
"reply_text": reply_text,
|
||||
"loop_info": loop_info,
|
||||
}
|
||||
|
||||
# 其他动作
|
||||
else:
|
||||
# 内建 wait / listening:不通过插件系统,直接在这里处理
|
||||
if action_planner_info.action_type in ["wait", "listening"]:
|
||||
reason = action_planner_info.reasoning or ""
|
||||
action_data = action_planner_info.action_data or {}
|
||||
|
||||
if action_planner_info.action_type == "wait":
|
||||
# 获取等待时间(必填)
|
||||
wait_seconds = action_data.get("wait_seconds")
|
||||
if wait_seconds is None:
|
||||
logger.warning(f"{self.log_prefix} wait 动作缺少 wait_seconds 参数,使用默认值 5 秒")
|
||||
wait_seconds = 5
|
||||
else:
|
||||
try:
|
||||
wait_seconds = float(wait_seconds)
|
||||
if wait_seconds < 0:
|
||||
logger.warning(f"{self.log_prefix} wait_seconds 不能为负数,使用默认值 5 秒")
|
||||
wait_seconds = 5
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(f"{self.log_prefix} wait_seconds 参数格式错误,使用默认值 5 秒")
|
||||
wait_seconds = 5
|
||||
|
||||
logger.info(f"{self.log_prefix} 执行 wait 动作,等待 {wait_seconds} 秒(可被新消息打断)")
|
||||
|
||||
# 清除事件状态,准备等待新消息
|
||||
self._new_message_event.clear()
|
||||
|
||||
# 记录动作信息
|
||||
await database_api.store_action_info(
|
||||
chat_stream=self.chat_stream,
|
||||
display_prompt=reason or f"等待 {wait_seconds} 秒",
|
||||
thinking_id=thinking_id,
|
||||
action_data={"reason": reason, "wait_seconds": wait_seconds},
|
||||
action_name="wait",
|
||||
)
|
||||
|
||||
# 等待指定时间,但可被新消息打断
|
||||
try:
|
||||
await asyncio.wait_for(self._new_message_event.wait(), timeout=wait_seconds)
|
||||
# 如果事件被触发,说明有新消息到达
|
||||
logger.info(f"{self.log_prefix} wait 动作被新消息打断,提前结束等待")
|
||||
except asyncio.TimeoutError:
|
||||
# 超时正常完成
|
||||
pass
|
||||
|
||||
logger.info(f"{self.log_prefix} wait 动作完成,继续下一次思考")
|
||||
|
||||
# 这些动作本身不产生文本回复
|
||||
self._last_successful_reply = False
|
||||
return {
|
||||
"action_type": "wait",
|
||||
"success": True,
|
||||
"reply_text": "",
|
||||
"command": "",
|
||||
}
|
||||
|
||||
# listening 已合并到 wait,如果遇到则转换为 wait(向后兼容)
|
||||
elif action_planner_info.action_type == "listening":
|
||||
logger.debug(f"{self.log_prefix} 检测到 listening 动作,已合并到 wait,自动转换")
|
||||
# 使用默认等待时间
|
||||
wait_seconds = 3
|
||||
|
||||
logger.info(
|
||||
f"{self.log_prefix} 执行 listening(转换为 wait)动作,等待 {wait_seconds} 秒(可被新消息打断)"
|
||||
)
|
||||
|
||||
# 清除事件状态,准备等待新消息
|
||||
self._new_message_event.clear()
|
||||
|
||||
# 记录动作信息
|
||||
await database_api.store_action_info(
|
||||
chat_stream=self.chat_stream,
|
||||
display_prompt=reason or f"倾听并等待 {wait_seconds} 秒",
|
||||
thinking_id=thinking_id,
|
||||
action_data={"reason": reason, "wait_seconds": wait_seconds},
|
||||
action_name="listening",
|
||||
)
|
||||
|
||||
# 等待指定时间,但可被新消息打断
|
||||
try:
|
||||
await asyncio.wait_for(self._new_message_event.wait(), timeout=wait_seconds)
|
||||
# 如果事件被触发,说明有新消息到达
|
||||
logger.info(f"{self.log_prefix} listening 动作被新消息打断,提前结束等待")
|
||||
except asyncio.TimeoutError:
|
||||
# 超时正常完成
|
||||
pass
|
||||
|
||||
logger.info(f"{self.log_prefix} listening 动作完成,继续下一次思考")
|
||||
|
||||
# 这些动作本身不产生文本回复
|
||||
self._last_successful_reply = False
|
||||
return {
|
||||
"action_type": "listening",
|
||||
"success": True,
|
||||
"reply_text": "",
|
||||
"command": "",
|
||||
}
|
||||
|
||||
# 其余动作:走原有插件 Action 体系
|
||||
with Timer("动作执行", cycle_timers):
|
||||
success, reply_text, command = await self._handle_action(
|
||||
action_planner_info.action_type,
|
||||
action_planner_info.reasoning or "",
|
||||
action_planner_info.action_data or {},
|
||||
cycle_timers,
|
||||
thinking_id,
|
||||
action_planner_info.action_message,
|
||||
)
|
||||
# 非 reply 类动作执行成功时,清空最近成功回复标记,让下一轮回到 initial Prompt
|
||||
if success and action_planner_info.action_type != "reply":
|
||||
self._last_successful_reply = False
|
||||
|
||||
return {
|
||||
"action_type": action_planner_info.action_type,
|
||||
"success": success,
|
||||
"reply_text": reply_text,
|
||||
"command": command,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 执行动作时出错: {e}")
|
||||
logger.error(f"{self.log_prefix} 错误信息: {traceback.format_exc()}")
|
||||
return {
|
||||
"action_type": action_planner_info.action_type,
|
||||
"success": False,
|
||||
"reply_text": "",
|
||||
"loop_info": None,
|
||||
"error": str(e),
|
||||
}
|
||||
@@ -1,620 +0,0 @@
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
|
||||
|
||||
import json
|
||||
import random
|
||||
import re
|
||||
import time
|
||||
import traceback
|
||||
|
||||
from json_repair import repair_json
|
||||
from rich.traceback import install
|
||||
|
||||
from src.chat.logger.plan_reply_logger import PlanReplyLogger
|
||||
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
||||
from src.chat.planner_actions.action_manager import ActionManager
|
||||
from src.chat.utils.utils import get_chat_type_and_target_info
|
||||
from src.common.data_models.info_data_model import ActionPlannerInfo
|
||||
from src.common.logger import get_logger
|
||||
from src.common.utils.utils_action import ActionUtils
|
||||
from src.config.config import global_config, model_config
|
||||
from src.core.types import ActionActivationType, ActionInfo, ComponentType
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.plugin_runtime.component_query import component_query_service
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
from src.services.message_service import (
|
||||
build_readable_messages_with_id,
|
||||
get_actions_by_timestamp_with_chat,
|
||||
get_messages_before_time_in_chat,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.common.data_models.info_data_model import TargetPersonInfo
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
|
||||
logger = get_logger("planner")
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
|
||||
class BrainPlanner:
|
||||
def __init__(self, chat_id: str, action_manager: ActionManager):
|
||||
self.chat_id = chat_id
|
||||
self.log_prefix = f"[{_chat_manager.get_session_name(chat_id) or chat_id}]"
|
||||
self.action_manager = action_manager
|
||||
# LLM规划器配置
|
||||
self.planner_llm = LLMRequest(
|
||||
model_set=model_config.model_task_config.planner, request_type="planner"
|
||||
) # 用于动作规划
|
||||
|
||||
self.last_obs_time_mark = 0.0
|
||||
|
||||
# 计划日志记录
|
||||
self.plan_log: List[Tuple[str, float, List[ActionPlannerInfo]]] = []
|
||||
|
||||
def find_message_by_id(
|
||||
self, message_id: str, message_id_list: List[Tuple[str, "SessionMessage"]]
|
||||
) -> Optional["SessionMessage"]:
|
||||
# sourcery skip: use-next
|
||||
"""
|
||||
根据message_id从message_id_list中查找对应的原始消息
|
||||
|
||||
Args:
|
||||
message_id: 要查找的消息ID
|
||||
message_id_list: 消息ID列表,格式为[{'id': str, 'message': dict}, ...]
|
||||
|
||||
Returns:
|
||||
找到的原始消息字典,如果未找到则返回None
|
||||
"""
|
||||
for item in message_id_list:
|
||||
if item[0] == message_id:
|
||||
return item[1]
|
||||
return None
|
||||
|
||||
def _parse_single_action(
|
||||
self,
|
||||
action_json: dict,
|
||||
message_id_list: List[Tuple[str, "SessionMessage"]],
|
||||
current_available_actions: List[Tuple[str, ActionInfo]],
|
||||
) -> List[ActionPlannerInfo]:
|
||||
"""解析单个action JSON并返回ActionPlannerInfo列表"""
|
||||
action_planner_infos = []
|
||||
|
||||
try:
|
||||
action = action_json.get("action", "complete_talk")
|
||||
logger.debug(f"{self.log_prefix}解析动作JSON: action={action}, json={action_json}")
|
||||
reasoning = action_json.get("reason", "未提供原因")
|
||||
action_data = {key: value for key, value in action_json.items() if key not in ["action", "reason"]}
|
||||
# 非complete_talk动作需要target_message_id
|
||||
target_message = None
|
||||
|
||||
if target_message_id := action_json.get("target_message_id"):
|
||||
# 根据target_message_id查找原始消息
|
||||
target_message = self.find_message_by_id(target_message_id, message_id_list)
|
||||
if target_message is None:
|
||||
logger.warning(f"{self.log_prefix}无法找到target_message_id '{target_message_id}' 对应的消息")
|
||||
# 选择最新消息作为target_message
|
||||
target_message = message_id_list[-1][1]
|
||||
else:
|
||||
target_message = message_id_list[-1][1]
|
||||
logger.debug(f"{self.log_prefix}动作'{action}'缺少target_message_id,使用最新消息作为target_message")
|
||||
|
||||
# 验证action是否可用
|
||||
available_action_names = [action_name for action_name, _ in current_available_actions]
|
||||
# 内部保留动作(不依赖插件系统)
|
||||
# 注意:listening 已合并到 wait 中,如果遇到 listening 则转换为 wait
|
||||
internal_action_names = ["complete_talk", "reply", "wait_time", "wait", "listening"]
|
||||
|
||||
logger.debug(
|
||||
f"{self.log_prefix}动作验证: action={action}, internal={internal_action_names}, available={available_action_names}"
|
||||
)
|
||||
|
||||
# 将 listening 转换为 wait(向后兼容)
|
||||
if action == "listening":
|
||||
logger.debug(f"{self.log_prefix}检测到 listening 动作,已合并到 wait,自动转换")
|
||||
action = "wait"
|
||||
|
||||
if action not in internal_action_names and action not in available_action_names:
|
||||
logger.warning(
|
||||
f"{self.log_prefix}LLM 返回了当前不可用或无效的动作: '{action}' (内部动作: {internal_action_names}, 可用插件动作: {available_action_names}),将强制使用 'complete_talk'"
|
||||
)
|
||||
reasoning = (
|
||||
f"LLM 返回了当前不可用的动作 '{action}' (可用: {available_action_names})。原始理由: {reasoning}"
|
||||
)
|
||||
action = "complete_talk"
|
||||
logger.warning(f"{self.log_prefix}动作已转换为 complete_talk")
|
||||
|
||||
# 创建ActionPlannerInfo对象
|
||||
# 将列表转换为字典格式
|
||||
available_actions_dict = dict(current_available_actions)
|
||||
action_planner_infos.append(
|
||||
ActionPlannerInfo(
|
||||
action_type=action,
|
||||
reasoning=reasoning,
|
||||
action_data=action_data,
|
||||
action_message=target_message,
|
||||
available_actions=available_actions_dict,
|
||||
)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix}解析单个action时出错: {e}")
|
||||
# 将列表转换为字典格式
|
||||
available_actions_dict = dict(current_available_actions)
|
||||
action_planner_infos.append(
|
||||
ActionPlannerInfo(
|
||||
action_type="complete_talk",
|
||||
reasoning=f"解析单个action时出错: {e}",
|
||||
action_data={},
|
||||
action_message=None,
|
||||
available_actions=available_actions_dict,
|
||||
)
|
||||
)
|
||||
|
||||
return action_planner_infos
|
||||
|
||||
async def plan(
|
||||
self,
|
||||
available_actions: Dict[str, ActionInfo],
|
||||
loop_start_time: float = 0.0,
|
||||
) -> List[ActionPlannerInfo]:
|
||||
# sourcery skip: use-named-expression
|
||||
"""
|
||||
规划器 (Planner): 使用LLM根据上下文决定做出什么动作(ReAct模式)。
|
||||
"""
|
||||
plan_start = time.perf_counter()
|
||||
|
||||
# 获取聊天上下文
|
||||
message_list_before_now = get_messages_before_time_in_chat(
|
||||
chat_id=self.chat_id,
|
||||
timestamp=time.time(),
|
||||
limit=int(global_config.chat.max_context_size * 0.6),
|
||||
filter_intercept_message_level=1,
|
||||
)
|
||||
message_id_list: list[Tuple[str, "SessionMessage"]] = []
|
||||
chat_content_block, message_id_list = build_readable_messages_with_id(
|
||||
messages=message_list_before_now,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
read_mark=self.last_obs_time_mark,
|
||||
truncate=True,
|
||||
show_actions=True,
|
||||
)
|
||||
|
||||
message_list_before_now_short = message_list_before_now[-int(global_config.chat.max_context_size * 0.3) :]
|
||||
chat_content_block_short, message_id_list_short = build_readable_messages_with_id(
|
||||
messages=message_list_before_now_short,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
truncate=False,
|
||||
show_actions=False,
|
||||
)
|
||||
|
||||
self.last_obs_time_mark = time.time()
|
||||
|
||||
# 获取必要信息
|
||||
is_group_chat, chat_target_info, current_available_actions = self.get_necessary_info()
|
||||
|
||||
# 提及/被@ 的处理由心流或统一判定模块驱动;Planner 不再做硬编码强制回复
|
||||
|
||||
# 应用激活类型过滤
|
||||
filtered_actions = self._filter_actions_by_activation_type(available_actions, chat_content_block_short)
|
||||
|
||||
logger.debug(f"{self.log_prefix}过滤后有{len(filtered_actions)}个可用动作")
|
||||
|
||||
prompt_build_start = time.perf_counter()
|
||||
# 构建包含所有动作的提示词:使用统一的 ReAct Prompt
|
||||
prompt_key = "brain_planner"
|
||||
# 这里不记录日志,避免重复打印,由调用方按需控制 log_prompt
|
||||
prompt, message_id_list = await self.build_planner_prompt(
|
||||
chat_target_info=chat_target_info,
|
||||
current_available_actions=filtered_actions,
|
||||
chat_content_block=chat_content_block,
|
||||
message_id_list=message_id_list,
|
||||
prompt_key=prompt_key,
|
||||
)
|
||||
prompt_build_ms = (time.perf_counter() - prompt_build_start) * 1000
|
||||
|
||||
# 调用LLM获取决策
|
||||
reasoning, actions, llm_raw_output, llm_reasoning, llm_duration_ms = await self._execute_main_planner(
|
||||
prompt=prompt,
|
||||
message_id_list=message_id_list,
|
||||
filtered_actions=filtered_actions,
|
||||
available_actions=available_actions,
|
||||
loop_start_time=loop_start_time,
|
||||
)
|
||||
|
||||
# 记录和展示计划日志
|
||||
logger.info(
|
||||
f"{self.log_prefix}Planner: {reasoning}。选择了{len(actions)}个动作: {' '.join([a.action_type for a in actions])}"
|
||||
)
|
||||
self.add_plan_log(reasoning, actions)
|
||||
|
||||
try:
|
||||
PlanReplyLogger.log_plan(
|
||||
chat_id=self.chat_id,
|
||||
prompt=prompt,
|
||||
reasoning=reasoning,
|
||||
raw_output=llm_raw_output,
|
||||
raw_reasoning=llm_reasoning,
|
||||
actions=actions,
|
||||
timing={
|
||||
"prompt_build_ms": round(prompt_build_ms, 2),
|
||||
"llm_duration_ms": round(llm_duration_ms, 2) if llm_duration_ms is not None else None,
|
||||
"total_plan_ms": round((time.perf_counter() - plan_start) * 1000, 2),
|
||||
"loop_start_time": loop_start_time,
|
||||
},
|
||||
extra=None,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(f"{self.log_prefix}记录plan日志失败")
|
||||
|
||||
return actions
|
||||
|
||||
async def build_planner_prompt(
|
||||
self,
|
||||
chat_target_info: Optional["TargetPersonInfo"],
|
||||
current_available_actions: Dict[str, ActionInfo],
|
||||
message_id_list: List[Tuple[str, "SessionMessage"]],
|
||||
chat_content_block: str = "",
|
||||
interest: str = "",
|
||||
prompt_key: str = "brain_planner",
|
||||
) -> tuple[str, List[Tuple[str, "SessionMessage"]]]:
|
||||
"""构建 Planner LLM 的提示词 (获取模板并填充数据)"""
|
||||
try:
|
||||
# 获取最近执行过的动作
|
||||
actions_before_now = get_actions_by_timestamp_with_chat(
|
||||
chat_id=self.chat_id,
|
||||
timestamp_start=time.time() - 600,
|
||||
timestamp_end=time.time(),
|
||||
limit=6,
|
||||
)
|
||||
actions_before_now_block = ActionUtils.build_readable_action_records(actions_before_now)
|
||||
if actions_before_now_block:
|
||||
actions_before_now_block = f"你刚刚选择并执行过的action是:\n{actions_before_now_block}"
|
||||
else:
|
||||
actions_before_now_block = ""
|
||||
|
||||
chat_context_description: str = ""
|
||||
if chat_target_info:
|
||||
# 构建聊天上下文描述
|
||||
chat_context_description = (
|
||||
f"你正在和 {chat_target_info.person_name or chat_target_info.user_nickname or '对方'} 聊天中"
|
||||
)
|
||||
|
||||
# 构建动作选项块
|
||||
action_options_block = await self._build_action_options_block(current_available_actions)
|
||||
|
||||
# 其他信息
|
||||
moderation_prompt_block = "请不要输出违法违规内容,不要输出色情,暴力,政治相关内容,如有敏感内容,请规避。"
|
||||
time_block = f"当前时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
|
||||
bot_name = global_config.bot.nickname
|
||||
bot_nickname = (
|
||||
f",也可以叫你{','.join(global_config.bot.alias_names)}" if global_config.bot.alias_names else ""
|
||||
)
|
||||
name_block = f"你的名字是{bot_name}{bot_nickname},请注意哪些是你自己的发言。"
|
||||
|
||||
# 获取主规划器模板并填充
|
||||
planner_prompt_template = prompt_manager.get_prompt(prompt_key)
|
||||
planner_prompt_template.add_context("time_block", time_block)
|
||||
planner_prompt_template.add_context("chat_context_description", chat_context_description)
|
||||
planner_prompt_template.add_context("chat_content_block", chat_content_block)
|
||||
planner_prompt_template.add_context("actions_before_now_block", actions_before_now_block)
|
||||
planner_prompt_template.add_context("action_options_text", action_options_block)
|
||||
planner_prompt_template.add_context("moderation_prompt", moderation_prompt_block)
|
||||
planner_prompt_template.add_context("name_block", name_block)
|
||||
planner_prompt_template.add_context("interest", interest)
|
||||
planner_prompt_template.add_context("plan_style", global_config.experimental.private_plan_style)
|
||||
prompt = await prompt_manager.render_prompt(planner_prompt_template)
|
||||
|
||||
return prompt, message_id_list
|
||||
except Exception as e:
|
||||
logger.error(f"构建 Planner 提示词时出错: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
return "构建 Planner Prompt 时出错", []
|
||||
|
||||
def get_necessary_info(self) -> Tuple[bool, Optional["TargetPersonInfo"], Dict[str, ActionInfo]]:
|
||||
"""
|
||||
获取 Planner 需要的必要信息
|
||||
"""
|
||||
is_group_chat = True
|
||||
is_group_chat, chat_target_info = get_chat_type_and_target_info(self.chat_id)
|
||||
logger.debug(f"{self.log_prefix}获取到聊天信息 - 群聊: {is_group_chat}, 目标信息: {chat_target_info}")
|
||||
|
||||
current_available_actions_dict = self.action_manager.get_using_actions()
|
||||
|
||||
# 获取完整的动作信息
|
||||
all_registered_actions: Dict[str, ActionInfo] = component_query_service.get_components_by_type( # type: ignore
|
||||
ComponentType.ACTION
|
||||
)
|
||||
current_available_actions = {}
|
||||
for action_name in current_available_actions_dict:
|
||||
if action_name in all_registered_actions:
|
||||
current_available_actions[action_name] = all_registered_actions[action_name]
|
||||
else:
|
||||
logger.warning(f"{self.log_prefix}使用中的动作 {action_name} 未在已注册动作中找到")
|
||||
|
||||
return is_group_chat, chat_target_info, current_available_actions
|
||||
|
||||
def _filter_actions_by_activation_type(
|
||||
self, available_actions: Dict[str, ActionInfo], chat_content_block: str
|
||||
) -> Dict[str, ActionInfo]:
|
||||
"""根据激活类型过滤动作"""
|
||||
filtered_actions = {}
|
||||
|
||||
for action_name, action_info in available_actions.items():
|
||||
if action_info.activation_type == ActionActivationType.NEVER:
|
||||
logger.debug(f"{self.log_prefix}动作 {action_name} 设置为 NEVER 激活类型,跳过")
|
||||
continue
|
||||
elif action_info.activation_type == ActionActivationType.ALWAYS:
|
||||
filtered_actions[action_name] = action_info
|
||||
elif action_info.activation_type == ActionActivationType.RANDOM:
|
||||
if random.random() < action_info.random_activation_probability:
|
||||
filtered_actions[action_name] = action_info
|
||||
elif action_info.activation_type == ActionActivationType.KEYWORD:
|
||||
if action_info.activation_keywords:
|
||||
for keyword in action_info.activation_keywords:
|
||||
if keyword in chat_content_block:
|
||||
filtered_actions[action_name] = action_info
|
||||
break
|
||||
else:
|
||||
logger.warning(f"{self.log_prefix}未知的激活类型: {action_info.activation_type},跳过处理")
|
||||
|
||||
return filtered_actions
|
||||
|
||||
async def _build_action_options_block(self, current_available_actions: Dict[str, ActionInfo]) -> str:
|
||||
# sourcery skip: use-join
|
||||
"""构建动作选项块"""
|
||||
if not current_available_actions:
|
||||
return ""
|
||||
|
||||
action_options_block = ""
|
||||
for action_name, action_info in current_available_actions.items():
|
||||
# 构建参数文本
|
||||
param_text = ""
|
||||
if action_info.action_parameters:
|
||||
param_text = "\n"
|
||||
for param_name, param_description in action_info.action_parameters.items():
|
||||
param_text += f' "{param_name}":"{param_description}"\n'
|
||||
param_text = param_text.rstrip("\n")
|
||||
|
||||
# 构建要求文本
|
||||
require_text = ""
|
||||
for require_item in action_info.action_require:
|
||||
require_text += f"- {require_item}\n"
|
||||
require_text = require_text.rstrip("\n")
|
||||
|
||||
# 获取动作提示模板并填充
|
||||
using_action_prompt_template = prompt_manager.get_prompt("brain_action")
|
||||
using_action_prompt_template.add_context("action_name", action_name)
|
||||
using_action_prompt_template.add_context("action_description", action_info.description)
|
||||
using_action_prompt_template.add_context("action_parameters", param_text)
|
||||
using_action_prompt_template.add_context("action_require", require_text)
|
||||
using_action_prompt = await prompt_manager.render_prompt(using_action_prompt_template)
|
||||
|
||||
action_options_block += using_action_prompt
|
||||
|
||||
return action_options_block
|
||||
|
||||
async def _execute_main_planner(
|
||||
self,
|
||||
prompt: str,
|
||||
message_id_list: List[Tuple[str, "SessionMessage"]],
|
||||
filtered_actions: Dict[str, ActionInfo],
|
||||
available_actions: Dict[str, ActionInfo],
|
||||
loop_start_time: float,
|
||||
) -> Tuple[str, List[ActionPlannerInfo], Optional[str], Optional[str], Optional[float]]:
|
||||
"""执行主规划器"""
|
||||
llm_content = None
|
||||
actions: List[ActionPlannerInfo] = []
|
||||
extracted_reasoning = ""
|
||||
llm_reasoning = None
|
||||
llm_duration_ms = None
|
||||
|
||||
try:
|
||||
# 调用LLM
|
||||
llm_start = time.perf_counter()
|
||||
llm_content, (reasoning_content, _, _) = await self.planner_llm.generate_response_async(prompt=prompt)
|
||||
llm_duration_ms = (time.perf_counter() - llm_start) * 1000
|
||||
llm_reasoning = reasoning_content
|
||||
|
||||
logger.info(f"{self.log_prefix}规划器原始提示词: {prompt}")
|
||||
logger.info(f"{self.log_prefix}规划器原始响应: {llm_content}")
|
||||
|
||||
if global_config.debug.show_planner_prompt:
|
||||
logger.info(f"{self.log_prefix}规划器原始提示词: {prompt}")
|
||||
logger.info(f"{self.log_prefix}规划器原始响应: {llm_content}")
|
||||
if reasoning_content:
|
||||
logger.info(f"{self.log_prefix}规划器推理: {reasoning_content}")
|
||||
else:
|
||||
logger.debug(f"{self.log_prefix}规划器原始提示词: {prompt}")
|
||||
logger.debug(f"{self.log_prefix}规划器原始响应: {llm_content}")
|
||||
if reasoning_content:
|
||||
logger.debug(f"{self.log_prefix}规划器推理: {reasoning_content}")
|
||||
|
||||
except Exception as req_e:
|
||||
logger.error(f"{self.log_prefix}LLM 请求执行失败: {req_e}")
|
||||
extracted_reasoning = f"LLM 请求失败,模型出现问题: {req_e}"
|
||||
return (
|
||||
extracted_reasoning,
|
||||
[
|
||||
ActionPlannerInfo(
|
||||
action_type="complete_talk",
|
||||
reasoning=extracted_reasoning,
|
||||
action_data={},
|
||||
action_message=None,
|
||||
available_actions=available_actions,
|
||||
)
|
||||
],
|
||||
llm_content,
|
||||
llm_reasoning,
|
||||
llm_duration_ms,
|
||||
)
|
||||
|
||||
# 解析LLM响应
|
||||
if llm_content:
|
||||
try:
|
||||
json_objects, extracted_reasoning = self._extract_json_from_markdown(llm_content)
|
||||
if json_objects:
|
||||
logger.info(f"{self.log_prefix}从响应中提取到{len(json_objects)}个JSON对象")
|
||||
for i, json_obj in enumerate(json_objects):
|
||||
logger.info(f"{self.log_prefix}解析第{i + 1}个JSON对象: {json_obj}")
|
||||
filtered_actions_list = list(filtered_actions.items())
|
||||
for json_obj in json_objects:
|
||||
parsed_actions = self._parse_single_action(json_obj, message_id_list, filtered_actions_list)
|
||||
logger.info(f"{self.log_prefix}解析后的动作: {[a.action_type for a in parsed_actions]}")
|
||||
actions.extend(parsed_actions)
|
||||
else:
|
||||
# 尝试解析为直接的JSON
|
||||
logger.warning(f"{self.log_prefix}LLM没有返回可用动作: {llm_content}")
|
||||
extracted_reasoning = extracted_reasoning or "LLM没有返回可用动作"
|
||||
actions = self._create_complete_talk(extracted_reasoning, available_actions)
|
||||
|
||||
except Exception as json_e:
|
||||
logger.warning(f"{self.log_prefix}解析LLM响应JSON失败 {json_e}. LLM原始输出: '{llm_content}'")
|
||||
extracted_reasoning = f"解析LLM响应JSON失败: {json_e}"
|
||||
actions = self._create_complete_talk(extracted_reasoning, available_actions)
|
||||
traceback.print_exc()
|
||||
else:
|
||||
extracted_reasoning = "规划器没有获得LLM响应"
|
||||
actions = self._create_complete_talk(extracted_reasoning, available_actions)
|
||||
|
||||
# 添加循环开始时间到所有动作
|
||||
for action in actions:
|
||||
action.action_data = action.action_data or {}
|
||||
action.action_data["loop_start_time"] = loop_start_time
|
||||
|
||||
logger.debug(
|
||||
f"{self.log_prefix}规划器决定执行{len(actions)}个动作: {' '.join([a.action_type for a in actions])}"
|
||||
)
|
||||
|
||||
return extracted_reasoning, actions, llm_content, llm_reasoning, llm_duration_ms
|
||||
|
||||
def _create_complete_talk(
|
||||
self, reasoning: str, available_actions: Dict[str, ActionInfo]
|
||||
) -> List[ActionPlannerInfo]:
|
||||
"""创建complete_talk"""
|
||||
return [
|
||||
ActionPlannerInfo(
|
||||
action_type="complete_talk",
|
||||
reasoning=reasoning,
|
||||
action_data={},
|
||||
action_message=None,
|
||||
available_actions=available_actions,
|
||||
)
|
||||
]
|
||||
|
||||
def add_plan_log(self, reasoning: str, actions: List[ActionPlannerInfo]):
|
||||
"""添加计划日志"""
|
||||
self.plan_log.append((reasoning, time.time(), actions))
|
||||
if len(self.plan_log) > 20:
|
||||
self.plan_log.pop(0)
|
||||
|
||||
def _extract_json_from_markdown(self, content: str) -> Tuple[List[dict], str]:
|
||||
# sourcery skip: for-append-to-extend
|
||||
"""从Markdown格式的内容中提取JSON对象和推理内容"""
|
||||
json_objects = []
|
||||
reasoning_content = ""
|
||||
|
||||
# 使用正则表达式查找```json包裹的JSON内容
|
||||
json_pattern = r"```json\s*(.*?)\s*```"
|
||||
markdown_matches = re.findall(json_pattern, content, re.DOTALL)
|
||||
|
||||
# 提取JSON之前的内容作为推理文本
|
||||
first_json_pos = len(content)
|
||||
if markdown_matches:
|
||||
# 找到第一个```json的位置
|
||||
first_json_pos = content.find("```json")
|
||||
if first_json_pos > 0:
|
||||
reasoning_content = content[:first_json_pos].strip()
|
||||
# 清理推理内容中的注释标记
|
||||
reasoning_content = re.sub(r"^//\s*", "", reasoning_content, flags=re.MULTILINE)
|
||||
reasoning_content = reasoning_content.strip()
|
||||
|
||||
# 处理```json包裹的JSON
|
||||
for match in markdown_matches:
|
||||
try:
|
||||
# 清理可能的注释和格式问题
|
||||
json_str = re.sub(r"//.*?\n", "\n", match) # 移除单行注释
|
||||
json_str = re.sub(r"/\*.*?\*/", "", json_str, flags=re.DOTALL) # 移除多行注释
|
||||
if json_str := json_str.strip():
|
||||
# 先尝试将整个块作为一个JSON对象或数组(适用于多行JSON)
|
||||
try:
|
||||
json_obj = json.loads(repair_json(json_str))
|
||||
if isinstance(json_obj, dict):
|
||||
json_objects.append(json_obj)
|
||||
elif isinstance(json_obj, list):
|
||||
for item in json_obj:
|
||||
if isinstance(item, dict):
|
||||
json_objects.append(item)
|
||||
except json.JSONDecodeError:
|
||||
# 如果整个块解析失败,尝试按行分割(适用于多个单行JSON对象)
|
||||
lines = [line.strip() for line in json_str.split("\n") if line.strip()]
|
||||
for line in lines:
|
||||
try:
|
||||
# 尝试解析每一行作为独立的JSON对象
|
||||
json_obj = json.loads(repair_json(line))
|
||||
if isinstance(json_obj, dict):
|
||||
json_objects.append(json_obj)
|
||||
elif isinstance(json_obj, list):
|
||||
for item in json_obj:
|
||||
if isinstance(item, dict):
|
||||
json_objects.append(item)
|
||||
except json.JSONDecodeError:
|
||||
# 单行解析失败,继续下一行
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.warning(f"{self.log_prefix}解析JSON块失败: {e}, 块内容: {match[:100]}...")
|
||||
continue
|
||||
|
||||
# 如果没有找到完整的```json```块,尝试查找不完整的代码块(缺少结尾```)
|
||||
if not json_objects:
|
||||
json_start_pos = content.find("```json")
|
||||
if json_start_pos != -1:
|
||||
# 找到```json之后的内容
|
||||
json_content_start = json_start_pos + 7 # ```json的长度
|
||||
# 提取从```json之后到内容结尾的所有内容
|
||||
incomplete_json_str = content[json_content_start:].strip()
|
||||
|
||||
# 提取JSON之前的内容作为推理文本
|
||||
if json_start_pos > 0:
|
||||
reasoning_content = content[:json_start_pos].strip()
|
||||
reasoning_content = re.sub(r"^//\s*", "", reasoning_content, flags=re.MULTILINE)
|
||||
reasoning_content = reasoning_content.strip()
|
||||
|
||||
if incomplete_json_str:
|
||||
try:
|
||||
# 清理可能的注释和格式问题
|
||||
json_str = re.sub(r"//.*?\n", "\n", incomplete_json_str)
|
||||
json_str = re.sub(r"/\*.*?\*/", "", json_str, flags=re.DOTALL)
|
||||
json_str = json_str.strip()
|
||||
|
||||
if json_str:
|
||||
# 尝试按行分割,每行可能是一个JSON对象
|
||||
lines = [line.strip() for line in json_str.split("\n") if line.strip()]
|
||||
for line in lines:
|
||||
try:
|
||||
json_obj = json.loads(repair_json(line))
|
||||
if isinstance(json_obj, dict):
|
||||
json_objects.append(json_obj)
|
||||
elif isinstance(json_obj, list):
|
||||
for item in json_obj:
|
||||
if isinstance(item, dict):
|
||||
json_objects.append(item)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# 如果按行解析没有成功,尝试将整个块作为一个JSON对象或数组
|
||||
if not json_objects:
|
||||
try:
|
||||
json_obj = json.loads(repair_json(json_str))
|
||||
if isinstance(json_obj, dict):
|
||||
json_objects.append(json_obj)
|
||||
elif isinstance(json_obj, list):
|
||||
for item in json_obj:
|
||||
if isinstance(item, dict):
|
||||
json_objects.append(item)
|
||||
except Exception as e:
|
||||
logger.debug(f"尝试解析不完整的JSON代码块失败: {e}")
|
||||
except Exception as e:
|
||||
logger.debug(f"处理不完整的JSON代码块时出错: {e}")
|
||||
|
||||
return json_objects, reasoning_content
|
||||
@@ -1,8 +1,9 @@
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from rich.traceback import install
|
||||
from sqlmodel import select
|
||||
from typing import Optional, Tuple, List
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
@@ -17,8 +18,9 @@ from src.common.database.database_model import Images, ImageType
|
||||
from src.common.database.database import get_db_session, get_db_session_manual
|
||||
from src.common.utils.utils_image import ImageUtils
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
from src.config.config import config_manager, global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import config_manager, global_config
|
||||
from src.common.data_models.llm_service_data_models import LLMGenerationOptions, LLMImageOptions
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
|
||||
logger = get_logger("emoji")
|
||||
|
||||
@@ -38,8 +40,10 @@ def _ensure_directories() -> None:
|
||||
|
||||
|
||||
# TODO: 修改这个vlm为获取的vlm client,暂时使用这个VLM方法
|
||||
emoji_manager_vlm = LLMRequest(model_set=model_config.model_task_config.vlm, request_type="emoji.see")
|
||||
emoji_manager_emotion_judge_llm = LLMRequest(model_set=model_config.model_task_config.utils, request_type="emoji")
|
||||
emoji_manager_vlm = LLMServiceClient(task_name="vlm", request_type="emoji.see")
|
||||
emoji_manager_emotion_judge_llm = LLMServiceClient(
|
||||
task_name="utils", request_type="emoji"
|
||||
)
|
||||
|
||||
|
||||
class EmojiManager:
|
||||
@@ -48,11 +52,13 @@ class EmojiManager:
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""初始化表情包管理器。"""
|
||||
_ensure_directories()
|
||||
|
||||
self._emoji_num: int = 0
|
||||
self.emojis: list[MaiEmoji] = []
|
||||
self.emojis: List[MaiEmoji] = []
|
||||
self._maintenance_wakeup_event: asyncio.Event = asyncio.Event()
|
||||
self._pending_description_tasks: Dict[str, asyncio.Task[None]] = {}
|
||||
self._reload_callback_registered: bool = False
|
||||
|
||||
config_manager.register_reload_callback(self.reload_runtime_config)
|
||||
@@ -75,7 +81,11 @@ class EmojiManager:
|
||||
logger.info("[关闭] Emoji 模块已注销配置热重载回调")
|
||||
|
||||
async def get_emoji_description(
|
||||
self, *, emoji_bytes: Optional[bytes] = None, emoji_hash: Optional[str] = None
|
||||
self,
|
||||
*,
|
||||
emoji_bytes: Optional[bytes] = None,
|
||||
emoji_hash: Optional[str] = None,
|
||||
wait_for_build: bool = True,
|
||||
) -> Optional[Tuple[str, List[str]]]:
|
||||
"""
|
||||
根据表情包哈希获取表情包描述和情感列表的封装方法
|
||||
@@ -83,6 +93,7 @@ class EmojiManager:
|
||||
Args:
|
||||
emoji_bytes (Optional[bytes]): 表情包的字节数据,如果提供了字节数据但数据库中没有找到对应记录,则会尝试构建表情包描述
|
||||
emoji_hash (Optional[str]): 表情包的哈希值,如果提供了哈希值则优先使用哈希值查找表情包描述
|
||||
wait_for_build (bool): 未命中缓存时是否同步等待描述构建完成
|
||||
Returns:
|
||||
return (Optional[Tuple[str, List[str]]]): 如果找到对应的表情包,则返回包含描述和情感标签的元组;若没找到,则尝试构建表情包描述并返回,如果构建失败则返回 None
|
||||
Raises:
|
||||
@@ -110,27 +121,88 @@ class EmojiManager:
|
||||
# 如果提供了字节数据但数据库中没有找到,尝试构建
|
||||
if not emoji_bytes:
|
||||
return None
|
||||
if not wait_for_build:
|
||||
self._schedule_description_build(emoji_hash, emoji_bytes)
|
||||
return None
|
||||
|
||||
# 找不到尝试构建
|
||||
return await self._build_and_cache_emoji_description(emoji_hash, emoji_bytes)
|
||||
|
||||
def _schedule_description_build(self, emoji_hash: str, emoji_bytes: bytes) -> None:
|
||||
"""调度表情包描述后台构建任务。
|
||||
|
||||
Args:
|
||||
emoji_hash: 表情包哈希值。
|
||||
emoji_bytes: 表情包字节数据。
|
||||
"""
|
||||
if emoji_hash in self._pending_description_tasks:
|
||||
return
|
||||
|
||||
task = asyncio.create_task(self._build_description_in_background(emoji_hash, emoji_bytes))
|
||||
self._pending_description_tasks[emoji_hash] = task
|
||||
task.add_done_callback(lambda finished_task: self._finalize_description_build(emoji_hash, finished_task))
|
||||
|
||||
async def _build_description_in_background(self, emoji_hash: str, emoji_bytes: bytes) -> None:
|
||||
"""在后台构建并缓存表情包描述。
|
||||
|
||||
Args:
|
||||
emoji_hash: 表情包哈希值。
|
||||
emoji_bytes: 表情包字节数据。
|
||||
"""
|
||||
try:
|
||||
logger.info(f"表情包描述后台构建已开始,哈希值: {emoji_hash}")
|
||||
await self._build_and_cache_emoji_description(emoji_hash, emoji_bytes)
|
||||
logger.info(f"表情包描述后台构建完成,哈希值: {emoji_hash}")
|
||||
except Exception as exc:
|
||||
logger.warning(f"表情包描述后台构建失败,哈希值: {emoji_hash},错误: {exc}")
|
||||
|
||||
def _finalize_description_build(self, emoji_hash: str, task: asyncio.Task[None]) -> None:
|
||||
"""回收表情包描述后台构建任务。
|
||||
|
||||
Args:
|
||||
emoji_hash: 表情包哈希值。
|
||||
task: 已完成的后台任务。
|
||||
"""
|
||||
self._pending_description_tasks.pop(emoji_hash, None)
|
||||
try:
|
||||
task.result()
|
||||
except Exception as exc:
|
||||
logger.debug(f"表情包描述后台任务结束时捕获异常,哈希值: {emoji_hash},错误: {exc}")
|
||||
|
||||
async def _build_and_cache_emoji_description(
|
||||
self,
|
||||
emoji_hash: str,
|
||||
emoji_bytes: bytes,
|
||||
) -> Optional[Tuple[str, List[str]]]:
|
||||
"""构建并缓存表情包描述与情感标签。
|
||||
|
||||
Args:
|
||||
emoji_hash: 表情包哈希值。
|
||||
emoji_bytes: 表情包字节数据。
|
||||
|
||||
Returns:
|
||||
Optional[Tuple[str, List[str]]]: 构建成功时返回描述和情感标签,否则返回 ``None``。
|
||||
"""
|
||||
logger.info(f"未找到哈希值为 {emoji_hash} 的表情包与其描述,尝试构建描述")
|
||||
full_path = EMOJI_DIR / f"{emoji_hash}.png"
|
||||
try:
|
||||
full_path.write_bytes(emoji_bytes)
|
||||
new_emoji = MaiEmoji(full_path=full_path, image_bytes=emoji_bytes)
|
||||
await new_emoji.calculate_hash_format()
|
||||
except Exception as e:
|
||||
logger.error(f"缓存表情包文件时出错: {e}")
|
||||
raise e
|
||||
except Exception as exc:
|
||||
logger.error(f"缓存表情包文件时出错: {exc}")
|
||||
raise exc
|
||||
|
||||
success_desc, new_emoji = await self.build_emoji_description(new_emoji)
|
||||
if not success_desc:
|
||||
logger.error("构建表情包描述失败")
|
||||
return None
|
||||
|
||||
success_emotion, new_emoji = await self.build_emoji_emotion(new_emoji)
|
||||
if not success_emotion:
|
||||
logger.error("构建表情包情感标签失败")
|
||||
return None
|
||||
|
||||
# 缓存结果到数据库
|
||||
with get_db_session() as session:
|
||||
try:
|
||||
image_record = new_emoji.to_db_instance()
|
||||
@@ -139,8 +211,8 @@ class EmojiManager:
|
||||
image_record.register_time = datetime.now()
|
||||
image_record.no_file_flag = True
|
||||
session.add(image_record)
|
||||
except Exception as e:
|
||||
logger.error(f"缓存表情包描述时出错: {e}")
|
||||
except Exception as exc:
|
||||
logger.error(f"缓存表情包描述时出错: {exc}")
|
||||
return new_emoji.description, new_emoji.emotion or []
|
||||
|
||||
def load_emojis_from_db(self) -> None:
|
||||
@@ -461,9 +533,11 @@ class EmojiManager:
|
||||
emoji_replace_prompt_template.add_context("emoji_list", "\n".join(emoji_info_list))
|
||||
emoji_replace_prompt = await prompt_manager.render_prompt(emoji_replace_prompt_template)
|
||||
|
||||
decision, _ = await emoji_manager_emotion_judge_llm.generate_response_async(
|
||||
emoji_replace_prompt, temperature=0.8, max_tokens=600
|
||||
decision_result = await emoji_manager_emotion_judge_llm.generate_response(
|
||||
emoji_replace_prompt,
|
||||
options=LLMGenerationOptions(temperature=0.8, max_tokens=600),
|
||||
)
|
||||
decision = decision_result.response
|
||||
logger.info(f"[决策] 结果: {decision}")
|
||||
|
||||
# 解析决策结果
|
||||
@@ -515,33 +589,56 @@ class EmojiManager:
|
||||
image_bytes = target_emoji.image_bytes or await asyncio.to_thread(
|
||||
target_emoji.read_image_bytes, target_emoji.full_path
|
||||
)
|
||||
image_base64 = ImageUtils.image_bytes_to_base64(image_bytes)
|
||||
try:
|
||||
if image_format == "gif":
|
||||
try:
|
||||
image_bytes = await asyncio.to_thread(ImageUtils.gif_2_static_image, image_bytes)
|
||||
except Exception as e:
|
||||
logger.error(f"[构建描述] 转换 GIF 图片时出错: {e}")
|
||||
return False, target_emoji
|
||||
prompt: str = "这是一个动态图表情包,每一张图代表了动态图的某一帧,黑色背景代表透明,简短描述一下表情包表达的情感和内容,从互联网梗、meme的角度去分析,精简回答"
|
||||
image_base64 = ImageUtils.image_bytes_to_base64(image_bytes)
|
||||
description_result = await emoji_manager_vlm.generate_response_for_image(
|
||||
prompt,
|
||||
image_base64,
|
||||
"jpg",
|
||||
options=LLMImageOptions(temperature=0.5),
|
||||
)
|
||||
description = description_result.response
|
||||
else:
|
||||
prompt: str = "这是一个表情包,请详细描述一下表情包所表达的情感和内容,简短描述细节,从互联网梗、meme的角度去分析,精简回答"
|
||||
description_result = await emoji_manager_vlm.generate_response_for_image(
|
||||
prompt,
|
||||
image_base64,
|
||||
image_format,
|
||||
options=LLMImageOptions(temperature=0.5),
|
||||
)
|
||||
description = description_result.response
|
||||
except Exception as e:
|
||||
logger.error(f"[构建描述] 调用视觉模型生成表情包描述时出错: {e}")
|
||||
return False, target_emoji
|
||||
|
||||
if image_format == "gif":
|
||||
try:
|
||||
image_bytes = await asyncio.to_thread(ImageUtils.gif_2_static_image, image_bytes)
|
||||
except Exception as e:
|
||||
logger.error(f"[构建描述] 转换 GIF 图片时出错: {e}")
|
||||
return False, target_emoji
|
||||
prompt: str = "这是一个动态图表情包,每一张图代表了动态图的某一帧,黑色背景代表透明,简短描述一下表情包表达的情感和内容,从互联网梗、meme的角度去分析,精简回答"
|
||||
image_base64 = ImageUtils.image_bytes_to_base64(image_bytes)
|
||||
description, _ = await emoji_manager_vlm.generate_response_for_image(
|
||||
prompt, image_base64, "jpg", temperature=0.5
|
||||
)
|
||||
else:
|
||||
prompt: str = "这是一个表情包,请详细描述一下表情包所表达的情感和内容,简短描述细节,从互联网梗、meme的角度去分析,精简回答"
|
||||
image_base64 = ImageUtils.image_bytes_to_base64(image_bytes)
|
||||
description, _ = await emoji_manager_vlm.generate_response_for_image(
|
||||
prompt, image_base64, image_format, temperature=0.5
|
||||
)
|
||||
if not description:
|
||||
logger.warning(f"[构建描述] 视觉模型返回空描述,跳过注册: {target_emoji.file_name}")
|
||||
return False, target_emoji
|
||||
|
||||
# 表情包审查
|
||||
if global_config.emoji.content_filtration:
|
||||
filtration_prompt_template = prompt_manager.get_prompt("emoji_content_filtration")
|
||||
filtration_prompt_template.add_context("demand", global_config.emoji.filtration_prompt)
|
||||
filtration_prompt = await prompt_manager.render_prompt(filtration_prompt_template)
|
||||
llm_response, _ = await emoji_manager_vlm.generate_response_for_image(
|
||||
filtration_prompt, image_base64, image_format, temperature=0.3
|
||||
)
|
||||
try:
|
||||
filtration_prompt_template = prompt_manager.get_prompt("emoji_content_filtration")
|
||||
filtration_prompt_template.add_context("demand", global_config.emoji.filtration_prompt)
|
||||
filtration_prompt = await prompt_manager.render_prompt(filtration_prompt_template)
|
||||
filtration_result = await emoji_manager_vlm.generate_response_for_image(
|
||||
filtration_prompt,
|
||||
image_base64,
|
||||
image_format,
|
||||
options=LLMImageOptions(temperature=0.3),
|
||||
)
|
||||
llm_response = filtration_result.response
|
||||
except Exception as e:
|
||||
logger.error(f"[表情包审查] 调用视觉模型审查表情包时出错: {e}")
|
||||
return False, target_emoji
|
||||
if "否" in llm_response:
|
||||
logger.warning(f"[表情包审查] 表情包内容不符合要求,拒绝注册: {target_emoji.file_name}")
|
||||
return False, target_emoji
|
||||
@@ -567,9 +664,19 @@ class EmojiManager:
|
||||
emotion_prompt_template.add_context("description", target_emoji.description)
|
||||
emotion_prompt = await prompt_manager.render_prompt(emotion_prompt_template)
|
||||
# 调用LLM生成情感标签
|
||||
emotion_result, _ = await emoji_manager_emotion_judge_llm.generate_response_async(
|
||||
emotion_prompt, temperature=0.3, max_tokens=200
|
||||
)
|
||||
try:
|
||||
emotion_generation_result = await emoji_manager_emotion_judge_llm.generate_response(
|
||||
emotion_prompt,
|
||||
options=LLMGenerationOptions(temperature=0.3, max_tokens=200),
|
||||
)
|
||||
emotion_result = emotion_generation_result.response
|
||||
except Exception as e:
|
||||
logger.error(f"[构建情感标签] 调用模型生成情感标签时出错: {e}")
|
||||
return False, target_emoji
|
||||
|
||||
if not emotion_result:
|
||||
logger.warning(f"[构建情感标签] 情感标签结果为空,跳过注册: {target_emoji.file_name}")
|
||||
return False, target_emoji
|
||||
|
||||
# 解析情感标签结果
|
||||
emotions = [e.strip() for e in emotion_result.replace(",", ",").split(",") if e.strip()]
|
||||
@@ -651,7 +758,12 @@ class EmojiManager:
|
||||
for emoji_file in EMOJI_DIR.iterdir():
|
||||
if not emoji_file.is_file():
|
||||
continue
|
||||
if await self.register_emoji_by_filename(emoji_file):
|
||||
try:
|
||||
register_success = await self.register_emoji_by_filename(emoji_file)
|
||||
except Exception as e:
|
||||
logger.error(f"[定期维护] 注册表情包 {emoji_file.name} 时发生未处理异常: {e}")
|
||||
register_success = False
|
||||
if register_success:
|
||||
break # 每次只注册一个表情包
|
||||
try:
|
||||
emoji_file.unlink()
|
||||
|
||||
@@ -198,7 +198,6 @@ class HeartFChatting:
|
||||
"""判定和生成回复"""
|
||||
asyncio.create_task(self._trigger_expression_learning(self.message_cache))
|
||||
# TODO: 完成反思器之后的逻辑
|
||||
start_time = time.time()
|
||||
current_cycle_detail = self._start_cycle()
|
||||
logger.info(f"{self.log_prefix} 开始第{self._cycle_counter}次思考")
|
||||
|
||||
@@ -207,10 +206,7 @@ class HeartFChatting:
|
||||
# TODO: 动作执行逻辑
|
||||
|
||||
cycle_detail = self._end_cycle(current_cycle_detail)
|
||||
if wait_time := global_config.chat.planner_smooth - (time.time() - start_time) > 0:
|
||||
await asyncio.sleep(wait_time)
|
||||
else:
|
||||
await asyncio.sleep(0.1) # 最小等待时间,避免过快循环
|
||||
await asyncio.sleep(0.1) # 最小等待时间,避免过快循环
|
||||
return True
|
||||
|
||||
def _handle_loop_completion(self, task: asyncio.Task):
|
||||
|
||||
@@ -1,47 +1,50 @@
|
||||
from typing import Dict
|
||||
|
||||
import asyncio
|
||||
import traceback
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from typing import Dict
|
||||
|
||||
from src.chat.message_receive.chat_manager import chat_manager
|
||||
from src.chat.heart_flow.heartFC_chat import HeartFChatting
|
||||
# from src.chat.brain_chat.brain_chat import BrainChatting
|
||||
from src.common.logger import get_logger
|
||||
from src.maisaka.runtime import MaisakaHeartFlowChatting
|
||||
|
||||
logger = get_logger("heartflow")
|
||||
|
||||
|
||||
# TODO: 恢复PFC,现在暂时禁用
|
||||
class HeartflowManager:
|
||||
"""主心流协调器,负责初始化并协调聊天,控制聊天属性"""
|
||||
"""管理 session 级别的 Maisaka 心流实例。"""
|
||||
|
||||
def __init__(self):
|
||||
# self.heartflow_chat_list: Dict[str, HeartFChatting | BrainChatting] = {}
|
||||
self.heartflow_chat_list: Dict[str, HeartFChatting] = {}
|
||||
def __init__(self) -> None:
|
||||
self.heartflow_chat_list: Dict[str, MaisakaHeartFlowChatting] = {}
|
||||
self._chat_create_locks: Dict[str, asyncio.Lock] = {}
|
||||
|
||||
async def get_or_create_heartflow_chat(self, session_id: str): # -> Optional[HeartFChatting | BrainChatting]:
|
||||
"""获取或创建一个新的HeartFChatting实例"""
|
||||
async def get_or_create_heartflow_chat(self, session_id: str) -> MaisakaHeartFlowChatting:
|
||||
"""获取或创建指定会话对应的 Maisaka runtime。"""
|
||||
try:
|
||||
if chat := self.heartflow_chat_list.get(session_id):
|
||||
return chat
|
||||
chat_session = chat_manager.get_session_by_session_id(session_id)
|
||||
if not chat_session:
|
||||
raise ValueError(f"未找到 session_id={session_id} 的聊天流")
|
||||
# new_chat = (
|
||||
# HeartFChatting(session_id=session_id) if chat_session.group_id else BrainChatting(session_id=session_id)
|
||||
# )
|
||||
new_chat = HeartFChatting(session_id=session_id)
|
||||
await new_chat.start()
|
||||
self.heartflow_chat_list[session_id] = new_chat
|
||||
return new_chat
|
||||
except Exception as e:
|
||||
logger.error(f"创建心流聊天 {session_id} 失败: {e}", exc_info=True)
|
||||
traceback.print_exc()
|
||||
raise e
|
||||
|
||||
def adjust_talk_frequency(self, session_id: str, frequency: float):
|
||||
"""调整指定聊天流的说话频率"""
|
||||
create_lock = self._chat_create_locks.setdefault(session_id, asyncio.Lock())
|
||||
async with create_lock:
|
||||
if chat := self.heartflow_chat_list.get(session_id):
|
||||
return chat
|
||||
|
||||
chat_session = chat_manager.get_session_by_session_id(session_id)
|
||||
if not chat_session:
|
||||
raise ValueError(f"未找到 session_id={session_id} 对应的聊天流")
|
||||
|
||||
new_chat = MaisakaHeartFlowChatting(session_id=session_id)
|
||||
await new_chat.start()
|
||||
self.heartflow_chat_list[session_id] = new_chat
|
||||
return new_chat
|
||||
except Exception as exc:
|
||||
logger.error(f"创建心流聊天 {session_id} 失败: {exc}", exc_info=True)
|
||||
traceback.print_exc()
|
||||
raise
|
||||
|
||||
def adjust_talk_frequency(self, session_id: str, frequency: float) -> None:
|
||||
"""调整指定聊天流的说话频率。"""
|
||||
chat = self.heartflow_chat_list.get(session_id)
|
||||
if chat and isinstance(chat, HeartFChatting):
|
||||
if chat:
|
||||
chat.adjust_talk_frequency(frequency)
|
||||
logger.info(f"已调整聊天 {session_id} 的说话频率为 {frequency}")
|
||||
else:
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional
|
||||
|
||||
from rich.traceback import install
|
||||
from sqlmodel import select
|
||||
from typing import Optional
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import hashlib
|
||||
|
||||
@@ -11,8 +13,9 @@ from src.common.logger import get_logger
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import Images, ImageType
|
||||
from src.common.data_models.image_data_model import MaiImage
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config
|
||||
from src.common.data_models.llm_service_data_models import LLMImageOptions
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
@@ -23,21 +26,30 @@ IMAGE_DIR = DATA_DIR / "images"
|
||||
logger = get_logger("image")
|
||||
|
||||
|
||||
def _ensure_image_dir_exists():
|
||||
def _ensure_image_dir_exists() -> None:
|
||||
"""确保图片缓存目录存在。"""
|
||||
IMAGE_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
vlm = LLMRequest(model_set=model_config.model_task_config.vlm, request_type="image")
|
||||
vlm = LLMServiceClient(task_name="vlm", request_type="image")
|
||||
|
||||
|
||||
class ImageManager:
|
||||
def __init__(self):
|
||||
"""图片描述管理器。"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""初始化图片管理器。"""
|
||||
_ensure_image_dir_exists()
|
||||
self._pending_description_tasks: Dict[str, asyncio.Task[None]] = {}
|
||||
|
||||
logger.info("图片管理器初始化完成")
|
||||
|
||||
async def get_image_description(
|
||||
self, *, image_hash: Optional[str] = None, image_bytes: Optional[bytes] = None
|
||||
self,
|
||||
*,
|
||||
image_hash: Optional[str] = None,
|
||||
image_bytes: Optional[bytes] = None,
|
||||
wait_for_build: bool = True,
|
||||
) -> str:
|
||||
"""
|
||||
获取图片描述的封装方法
|
||||
@@ -49,6 +61,7 @@ class ImageManager:
|
||||
Args:
|
||||
image_hash (Optional[str]): 图片的哈希值,如果提供则优先使用该
|
||||
image_bytes (Optional[bytes]): 图片的字节数据,如果提供则在数据库中找不到哈希值时使用该数据生成描述
|
||||
wait_for_build (bool): 未命中缓存时是否同步等待描述构建完成
|
||||
Returns:
|
||||
return (str): 图片描述,如果发生错误或无法生成描述则返回空字符串
|
||||
Raises:
|
||||
@@ -73,6 +86,9 @@ class ImageManager:
|
||||
if not image_bytes:
|
||||
logger.warning("图片哈希值未找到,且未提供图片字节数据,返回无描述")
|
||||
return ""
|
||||
if not wait_for_build:
|
||||
self._schedule_description_build(hash_str, image_bytes)
|
||||
return ""
|
||||
logger.info(f"图片描述未找到,哈希值: {hash_str},准备生成新描述")
|
||||
try:
|
||||
image = await self.save_image_and_process(image_bytes)
|
||||
@@ -81,6 +97,47 @@ class ImageManager:
|
||||
logger.error(f"生成图片描述时发生错误: {e}")
|
||||
return ""
|
||||
|
||||
def _schedule_description_build(self, image_hash: str, image_bytes: bytes) -> None:
|
||||
"""调度图片描述后台构建任务。
|
||||
|
||||
Args:
|
||||
image_hash: 图片哈希值。
|
||||
image_bytes: 图片字节数据。
|
||||
"""
|
||||
if image_hash in self._pending_description_tasks:
|
||||
return
|
||||
|
||||
task = asyncio.create_task(self._build_description_in_background(image_hash, image_bytes))
|
||||
self._pending_description_tasks[image_hash] = task
|
||||
task.add_done_callback(lambda finished_task: self._finalize_description_build(image_hash, finished_task))
|
||||
|
||||
async def _build_description_in_background(self, image_hash: str, image_bytes: bytes) -> None:
|
||||
"""在后台构建并缓存图片描述。
|
||||
|
||||
Args:
|
||||
image_hash: 图片哈希值。
|
||||
image_bytes: 图片字节数据。
|
||||
"""
|
||||
try:
|
||||
logger.info(f"图片描述后台构建已开始,哈希值: {image_hash}")
|
||||
await self.save_image_and_process(image_bytes)
|
||||
logger.info(f"图片描述后台构建完成,哈希值: {image_hash}")
|
||||
except Exception as exc:
|
||||
logger.warning(f"图片描述后台构建失败,哈希值: {image_hash},错误: {exc}")
|
||||
|
||||
def _finalize_description_build(self, image_hash: str, task: asyncio.Task[None]) -> None:
|
||||
"""回收图片描述后台构建任务。
|
||||
|
||||
Args:
|
||||
image_hash: 图片哈希值。
|
||||
task: 已完成的后台任务。
|
||||
"""
|
||||
self._pending_description_tasks.pop(image_hash, None)
|
||||
try:
|
||||
task.result()
|
||||
except Exception as exc:
|
||||
logger.debug(f"图片描述后台任务结束时捕获异常,哈希值: {image_hash},错误: {exc}")
|
||||
|
||||
def get_image_from_db(self, image_hash: str) -> Optional[MaiImage]:
|
||||
"""
|
||||
从数据库中根据图片哈希值获取图片记录
|
||||
@@ -260,7 +317,13 @@ class ImageManager:
|
||||
prompt = global_config.personality.visual_style
|
||||
image_base64 = base64.b64encode(image_bytes).decode("utf-8")
|
||||
|
||||
description, _ = await vlm.generate_response_for_image(prompt, image_base64, image_format, 0.4)
|
||||
generation_result = await vlm.generate_response_for_image(
|
||||
prompt,
|
||||
image_base64,
|
||||
image_format,
|
||||
options=LLMImageOptions(temperature=0.4),
|
||||
)
|
||||
description = generation_result.response
|
||||
if not description:
|
||||
logger.warning("VLM未能生成图片描述")
|
||||
return description or ""
|
||||
|
||||
@@ -139,14 +139,14 @@ class EmbeddingStore:
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
# 创建新的LLMRequest实例
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config
|
||||
# 创建新的服务层实例
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
|
||||
llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding")
|
||||
llm = LLMServiceClient(task_name="embedding", request_type="embedding")
|
||||
|
||||
# 使用新的事件循环运行异步方法
|
||||
embedding, _ = loop.run_until_complete(llm.get_embedding(s))
|
||||
embedding_result = loop.run_until_complete(llm.embed_text(s))
|
||||
embedding = embedding_result.embedding
|
||||
|
||||
if embedding and len(embedding) > 0:
|
||||
return embedding
|
||||
@@ -195,13 +195,12 @@ class EmbeddingStore:
|
||||
start_idx, chunk_strs = chunk_data
|
||||
chunk_results = []
|
||||
|
||||
# 为每个线程创建独立的LLMRequest实例
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config
|
||||
# 为每个线程创建独立的服务层实例
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
|
||||
try:
|
||||
# 创建线程专用的LLM实例
|
||||
llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding")
|
||||
# 创建线程专用的服务层实例
|
||||
llm = LLMServiceClient(task_name="embedding", request_type="embedding")
|
||||
|
||||
for i, s in enumerate(chunk_strs):
|
||||
try:
|
||||
@@ -209,7 +208,8 @@ class EmbeddingStore:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
embedding = loop.run_until_complete(llm.get_embedding(s))
|
||||
embedding_result = loop.run_until_complete(llm.embed_text(s))
|
||||
embedding = embedding_result.embedding
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
@@ -1,18 +1,27 @@
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from typing import List, Union
|
||||
from typing import Dict, List, Tuple, Union
|
||||
|
||||
from .global_logger import logger
|
||||
from . import prompt_template
|
||||
from . import INVALID_ENTITY
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from json_repair import repair_json
|
||||
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
|
||||
def _extract_json_from_text(text: str):
|
||||
from . import INVALID_ENTITY
|
||||
from . import prompt_template
|
||||
from .global_logger import logger
|
||||
|
||||
|
||||
def _extract_json_from_text(text: str) -> List[str] | List[List[str]] | Dict[str, object]:
|
||||
# sourcery skip: assign-if-exp, extract-method
|
||||
"""从文本中提取JSON数据的高容错方法"""
|
||||
"""从文本中提取 JSON 数据。
|
||||
|
||||
Args:
|
||||
text: 原始模型输出文本。
|
||||
|
||||
Returns:
|
||||
List[str] | List[List[str]] | Dict[str, object]: 修复并解析后的 JSON 结果。
|
||||
"""
|
||||
if text is None:
|
||||
logger.error("输入文本为None")
|
||||
return []
|
||||
@@ -46,20 +55,30 @@ def _extract_json_from_text(text: str):
|
||||
return []
|
||||
|
||||
|
||||
def _entity_extract(llm_req: LLMRequest, paragraph: str) -> List[str]:
|
||||
def _entity_extract(llm_req: LLMServiceClient, paragraph: str) -> List[str]:
|
||||
# sourcery skip: reintroduce-else, swap-if-else-branches, use-named-expression
|
||||
"""对段落进行实体提取,返回提取出的实体列表(JSON格式)"""
|
||||
"""对单段文本执行实体提取。
|
||||
|
||||
Args:
|
||||
llm_req: LLM 服务门面实例。
|
||||
paragraph: 待提取实体的原始段落文本。
|
||||
|
||||
Returns:
|
||||
List[str]: 提取出的实体列表。
|
||||
"""
|
||||
entity_extract_context = prompt_template.build_entity_extract_context(paragraph)
|
||||
|
||||
# 使用 asyncio.run 来运行异步方法
|
||||
try:
|
||||
# 如果当前已有事件循环在运行,使用它
|
||||
loop = asyncio.get_running_loop()
|
||||
future = asyncio.run_coroutine_threadsafe(llm_req.generate_response_async(entity_extract_context), loop)
|
||||
response, _ = future.result()
|
||||
future = asyncio.run_coroutine_threadsafe(llm_req.generate_response(entity_extract_context), loop)
|
||||
generation_result = future.result()
|
||||
response = generation_result.response
|
||||
except RuntimeError:
|
||||
# 如果没有运行中的事件循环,直接使用 asyncio.run
|
||||
response, _ = asyncio.run(llm_req.generate_response_async(entity_extract_context))
|
||||
generation_result = asyncio.run(llm_req.generate_response(entity_extract_context))
|
||||
response = generation_result.response
|
||||
|
||||
# 添加调试日志
|
||||
logger.debug(f"LLM返回的原始响应: {response}")
|
||||
@@ -92,8 +111,21 @@ def _entity_extract(llm_req: LLMRequest, paragraph: str) -> List[str]:
|
||||
return entity_extract_result
|
||||
|
||||
|
||||
def _rdf_triple_extract(llm_req: LLMRequest, paragraph: str, entities: list) -> List[List[str]]:
|
||||
"""对段落进行实体提取,返回提取出的实体列表(JSON格式)"""
|
||||
def _rdf_triple_extract(
|
||||
llm_req: LLMServiceClient,
|
||||
paragraph: str,
|
||||
entities: List[str],
|
||||
) -> List[List[str]]:
|
||||
"""对单段文本执行 RDF 三元组提取。
|
||||
|
||||
Args:
|
||||
llm_req: LLM 服务门面实例。
|
||||
paragraph: 待提取的原始段落文本。
|
||||
entities: 已识别出的实体列表。
|
||||
|
||||
Returns:
|
||||
List[List[str]]: 提取出的三元组列表。
|
||||
"""
|
||||
rdf_extract_context = prompt_template.build_rdf_triple_extract_context(
|
||||
paragraph, entities=json.dumps(entities, ensure_ascii=False)
|
||||
)
|
||||
@@ -102,11 +134,13 @@ def _rdf_triple_extract(llm_req: LLMRequest, paragraph: str, entities: list) ->
|
||||
try:
|
||||
# 如果当前已有事件循环在运行,使用它
|
||||
loop = asyncio.get_running_loop()
|
||||
future = asyncio.run_coroutine_threadsafe(llm_req.generate_response_async(rdf_extract_context), loop)
|
||||
response, _ = future.result()
|
||||
future = asyncio.run_coroutine_threadsafe(llm_req.generate_response(rdf_extract_context), loop)
|
||||
generation_result = future.result()
|
||||
response = generation_result.response
|
||||
except RuntimeError:
|
||||
# 如果没有运行中的事件循环,直接使用 asyncio.run
|
||||
response, _ = asyncio.run(llm_req.generate_response_async(rdf_extract_context))
|
||||
generation_result = asyncio.run(llm_req.generate_response(rdf_extract_context))
|
||||
response = generation_result.response
|
||||
|
||||
# 添加调试日志
|
||||
logger.debug(f"RDF LLM返回的原始响应: {response}")
|
||||
@@ -140,8 +174,21 @@ def _rdf_triple_extract(llm_req: LLMRequest, paragraph: str, entities: list) ->
|
||||
|
||||
|
||||
def info_extract_from_str(
|
||||
llm_client_for_ner: LLMRequest, llm_client_for_rdf: LLMRequest, paragraph: str
|
||||
) -> Union[tuple[None, None], tuple[list[str], list[list[str]]]]:
|
||||
llm_client_for_ner: LLMServiceClient,
|
||||
llm_client_for_rdf: LLMServiceClient,
|
||||
paragraph: str,
|
||||
) -> Union[Tuple[None, None], Tuple[List[str], List[List[str]]]]:
|
||||
"""从文本中提取实体与三元组信息。
|
||||
|
||||
Args:
|
||||
llm_client_for_ner: 实体提取使用的 LLM 服务门面。
|
||||
llm_client_for_rdf: RDF 三元组提取使用的 LLM 服务门面。
|
||||
paragraph: 原始段落文本。
|
||||
|
||||
Returns:
|
||||
Union[Tuple[None, None], Tuple[List[str], List[List[str]]]]: 成功时返回
|
||||
``(实体列表, 三元组列表)``,失败时返回 ``(None, None)``。
|
||||
"""
|
||||
try_count = 0
|
||||
while True:
|
||||
try:
|
||||
@@ -176,17 +223,30 @@ def info_extract_from_str(
|
||||
|
||||
|
||||
class IEProcess:
|
||||
"""
|
||||
信息抽取处理器类,提供更方便的批次处理接口。
|
||||
"""
|
||||
"""信息抽取处理器。"""
|
||||
|
||||
def __init__(self, llm_ner: LLMRequest, llm_rdf: LLMRequest = None):
|
||||
def __init__(
|
||||
self,
|
||||
llm_ner: LLMServiceClient,
|
||||
llm_rdf: LLMServiceClient | None = None,
|
||||
) -> None:
|
||||
"""初始化信息抽取处理器。
|
||||
|
||||
Args:
|
||||
llm_ner: 实体提取使用的 LLM 服务门面。
|
||||
llm_rdf: RDF 三元组提取使用的 LLM 服务门面;为空时复用 `llm_ner`。
|
||||
"""
|
||||
self.llm_ner = llm_ner
|
||||
self.llm_rdf = llm_rdf or llm_ner
|
||||
|
||||
async def process_paragraphs(self, paragraphs: List[str]) -> List[dict]:
|
||||
"""
|
||||
异步处理多个段落。
|
||||
async def process_paragraphs(self, paragraphs: List[str]) -> List[Dict[str, object]]:
|
||||
"""异步处理多个段落。
|
||||
|
||||
Args:
|
||||
paragraphs: 待处理的段落列表。
|
||||
|
||||
Returns:
|
||||
List[Dict[str, object]]: 每个成功段落对应的抽取结果。
|
||||
"""
|
||||
from .utils.hash import get_sha256
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from contextlib import suppress
|
||||
from copy import deepcopy
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import os
|
||||
@@ -10,9 +11,9 @@ from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiv
|
||||
from src.common.logger import get_logger
|
||||
from src.common.utils.utils_message import MessageUtils
|
||||
from src.common.utils.utils_session import SessionUtils
|
||||
from src.config.config import global_config
|
||||
from src.platform_io.route_key_factory import RouteKeyFactory
|
||||
|
||||
# from src.chat.brain_chat.PFC.pfc_manager import PFCManager
|
||||
from src.core.announcement_manager import global_announcement_manager
|
||||
from src.plugin_runtime.component_query import component_query_service
|
||||
|
||||
@@ -29,36 +30,20 @@ logger = get_logger("chat")
|
||||
|
||||
|
||||
class ChatBot:
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
"""初始化聊天机器人入口。"""
|
||||
|
||||
self.bot = None # bot 实例引用
|
||||
self._started = False
|
||||
self.heartflow_message_receiver = HeartFCMessageReceiver() # 新增
|
||||
# self.pfc_manager = PFCManager.get_instance() # PFC管理器 # TODO: PFC恢复
|
||||
self.heartflow_message_receiver = HeartFCMessageReceiver()
|
||||
|
||||
async def _ensure_started(self):
|
||||
"""确保所有任务已启动"""
|
||||
async def _ensure_started(self) -> None:
|
||||
"""确保所有后台任务已启动。"""
|
||||
if not self._started:
|
||||
logger.debug("确保ChatBot所有任务已启动")
|
||||
|
||||
self._started = True
|
||||
|
||||
async def _create_pfc_chat(self, message: SessionMessage):
|
||||
"""创建或获取PFC对话实例
|
||||
|
||||
Args:
|
||||
message: 消息对象
|
||||
"""
|
||||
try:
|
||||
chat_id = message.session_id
|
||||
private_name = str(message.message_info.user_info.user_nickname)
|
||||
|
||||
logger.debug(f"[私聊][{private_name}]创建或获取PFC对话: {chat_id}")
|
||||
await self.pfc_manager.get_or_create_conversation(chat_id, private_name)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"创建PFC聊天失败: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
async def _process_commands(self, message: SessionMessage) -> tuple[bool, Optional[str], bool]:
|
||||
"""使用统一组件注册表处理命令。
|
||||
|
||||
@@ -175,11 +160,12 @@ class ChatBot:
|
||||
recalled: Dict[str, Any] = {}
|
||||
recalled_id = None
|
||||
|
||||
if getattr(seg, "type", None) == "notify" and isinstance(getattr(seg, "data", None), dict):
|
||||
sub_type = seg.data.get("sub_type")
|
||||
scene = seg.data.get("scene")
|
||||
msg_id = seg.data.get("message_id")
|
||||
recalled = seg.data.get("recalled_user_info") or {}
|
||||
seg_data = getattr(seg, "data", None)
|
||||
if getattr(seg, "type", None) == "notify" and isinstance(seg_data, dict):
|
||||
sub_type = seg_data.get("sub_type")
|
||||
scene = seg_data.get("scene")
|
||||
msg_id = seg_data.get("message_id")
|
||||
recalled = seg_data.get("recalled_user_info") or {}
|
||||
if isinstance(recalled, dict):
|
||||
recalled_id = recalled.get("user_id")
|
||||
|
||||
@@ -301,7 +287,13 @@ class ChatBot:
|
||||
# pass
|
||||
|
||||
# 处理消息内容,识别表情包等二进制数据并转化为文本描述
|
||||
await message.process()
|
||||
if global_config.maisaka.direct_image_input:
|
||||
message.maisaka_original_raw_message = deepcopy(message.raw_message) # type: ignore[attr-defined]
|
||||
# 入站主链优先保证消息尽快入队,避免图片、表情包、语音分析阻塞适配器超时。
|
||||
await message.process(
|
||||
enable_heavy_media_analysis=False,
|
||||
enable_voice_transcription=False,
|
||||
)
|
||||
|
||||
# 平台层的 @ 检测由底层 is_mentioned_bot_in_message 统一处理;此处不做用户名硬编码匹配
|
||||
|
||||
@@ -335,14 +327,13 @@ class ChatBot:
|
||||
|
||||
# message.update_chat_stream(chat)
|
||||
|
||||
# 命令处理 - 使用新插件系统检查并处理命令
|
||||
# 注意:命令返回的 response 当前只用于日志记录和流程判断,
|
||||
# 不会在这里自动作为回复消息发送回会话。
|
||||
# is_command, cmd_result, continue_process = await self._process_commands(message)
|
||||
# 命令处理 - 使用新插件系统检查并处理命令。
|
||||
# 命令处理器内部自行决定是否回复消息,这里只负责流程分发与拦截。
|
||||
is_command, cmd_result, continue_process = await self._process_commands(message)
|
||||
|
||||
# # 如果是命令且不需要继续处理,则直接返回
|
||||
# if is_command and await self._handle_command_processing_result(message, cmd_result, continue_process):
|
||||
# return
|
||||
# 如果是命令且不需要继续处理,则直接返回,避免落入 HeartFlow / MaiSaka。
|
||||
if is_command and await self._handle_command_processing_result(message, cmd_result, continue_process):
|
||||
return
|
||||
|
||||
# continue_flag, modified_message = await events_manager.handle_mai_events(EventType.ON_MESSAGE, message)
|
||||
# if not continue_flag:
|
||||
@@ -362,31 +353,12 @@ class ChatBot:
|
||||
# else:
|
||||
# template_group_name = None
|
||||
|
||||
# async def preprocess():
|
||||
# # 根据聊天类型路由消息
|
||||
# if group_info is None:
|
||||
# # 私聊消息 -> PFC系统
|
||||
# logger.debug("[私聊]检测到私聊消息,路由到PFC系统")
|
||||
# await MessageStorage.store_message(message, chat)
|
||||
# await self._create_pfc_chat(message)
|
||||
# else:
|
||||
# # 群聊消息 -> HeartFlow系统
|
||||
# logger.debug("[群聊]检测到群聊消息,路由到HeartFlow系统")
|
||||
# await self.heartflow_message_receiver.process_message(message)
|
||||
|
||||
# if template_group_name:
|
||||
# async with global_prompt_manager.async_message_scope(template_group_name):
|
||||
# await preprocess()
|
||||
# else:
|
||||
# await preprocess()
|
||||
async def preprocess():
|
||||
if group_info is None:
|
||||
# logger.debug("[私聊]检测到私聊消息,路由到PFC系统")
|
||||
# MessageUtils.store_message_to_db(message) # 存储消息到数据库
|
||||
# await self._create_pfc_chat(message)
|
||||
logger.critical("暂时禁用私聊")
|
||||
logger.debug("[私聊]检测到私聊消息,路由到 Maisaka")
|
||||
await self.heartflow_message_receiver.process_message(message)
|
||||
else:
|
||||
logger.debug("[群聊]检测到群聊消息,路由到HeartFlow系统")
|
||||
logger.debug("[群聊]检测到群聊消息,路由到 Maisaka")
|
||||
await self.heartflow_message_receiver.process_message(message)
|
||||
|
||||
await preprocess()
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from asyncio import Task
|
||||
from typing import Dict, List, Sequence, Tuple
|
||||
|
||||
from rich.traceback import install
|
||||
from sqlmodel import select
|
||||
from typing import List, Dict, Tuple, Sequence
|
||||
|
||||
import asyncio
|
||||
|
||||
@@ -27,14 +28,36 @@ logger = get_logger("chat_message")
|
||||
|
||||
|
||||
class MsgIDMapping:
|
||||
def __init__(self):
|
||||
self.mapping: Dict[str, Tuple[str | Task, UserInfo]] = {}
|
||||
"""回复消息内容缓存。"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""初始化消息 ID 到内容的映射缓存。"""
|
||||
self.mapping: Dict[str, Tuple[str | Task[str], UserInfo]] = {}
|
||||
|
||||
|
||||
class SessionMessage(MaiMessage):
|
||||
async def process(self):
|
||||
"""处理消息内容,识别消息内容并转化为文本(会修改消息组件属性)"""
|
||||
tasks = [self.process_single_component(component, MsgIDMapping()) for component in self.raw_message.components]
|
||||
async def process(
|
||||
self,
|
||||
*,
|
||||
enable_heavy_media_analysis: bool = True,
|
||||
enable_voice_transcription: bool = True,
|
||||
) -> None:
|
||||
"""处理消息内容并转化为纯文本。
|
||||
|
||||
Args:
|
||||
enable_heavy_media_analysis: 是否同步执行图片与表情包描述生成。
|
||||
enable_voice_transcription: 是否同步执行语音转写。
|
||||
"""
|
||||
id_content_map = MsgIDMapping()
|
||||
tasks = [
|
||||
self.process_single_component(
|
||||
component,
|
||||
id_content_map,
|
||||
enable_heavy_media_analysis=enable_heavy_media_analysis,
|
||||
enable_voice_transcription=enable_voice_transcription,
|
||||
)
|
||||
for component in self.raw_message.components
|
||||
]
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
processed_texts: List[str] = []
|
||||
for result in results:
|
||||
@@ -45,50 +68,116 @@ class SessionMessage(MaiMessage):
|
||||
self.processed_plain_text = " ".join(processed_texts)
|
||||
|
||||
async def process_single_component(
|
||||
self, component: StandardMessageComponents, id_content_map: MsgIDMapping, recursion_depth: int = 0
|
||||
self,
|
||||
component: StandardMessageComponents,
|
||||
id_content_map: MsgIDMapping,
|
||||
recursion_depth: int = 0,
|
||||
*,
|
||||
enable_heavy_media_analysis: bool = True,
|
||||
enable_voice_transcription: bool = True,
|
||||
) -> str:
|
||||
"""按照类型处理单个消息组件,返回处理后的文本内容(会修改消息组件属性)"""
|
||||
"""按类型处理单个消息组件。
|
||||
|
||||
Args:
|
||||
component: 待处理的消息组件。
|
||||
id_content_map: 回复消息解析缓存。
|
||||
recursion_depth: 当前递归深度。
|
||||
enable_heavy_media_analysis: 是否同步执行图片与表情包描述生成。
|
||||
enable_voice_transcription: 是否同步执行语音转写。
|
||||
|
||||
Returns:
|
||||
str: 组件对应的文本表示。
|
||||
"""
|
||||
if isinstance(component, TextComponent):
|
||||
return component.text
|
||||
elif isinstance(component, ImageComponent):
|
||||
return await self.process_image_component(component)
|
||||
return await self.process_image_component(
|
||||
component,
|
||||
enable_heavy_media_analysis=enable_heavy_media_analysis,
|
||||
)
|
||||
elif isinstance(component, EmojiComponent):
|
||||
return await self.process_emoji_component(component)
|
||||
return await self.process_emoji_component(
|
||||
component,
|
||||
enable_heavy_media_analysis=enable_heavy_media_analysis,
|
||||
)
|
||||
elif isinstance(component, AtComponent):
|
||||
return await self.process_at_component(component)
|
||||
elif isinstance(component, VoiceComponent):
|
||||
return await self.process_voice_component(component)
|
||||
return await self.process_voice_component(
|
||||
component,
|
||||
enable_voice_transcription=enable_voice_transcription,
|
||||
)
|
||||
elif isinstance(component, ReplyComponent):
|
||||
return await self.process_reply_component(component, id_content_map)
|
||||
elif isinstance(component, ForwardNodeComponent):
|
||||
return await self.process_forward_component(component, id_content_map, recursion_depth=recursion_depth + 1)
|
||||
return await self.process_forward_component(
|
||||
component,
|
||||
id_content_map,
|
||||
recursion_depth=recursion_depth + 1,
|
||||
enable_heavy_media_analysis=enable_heavy_media_analysis,
|
||||
enable_voice_transcription=enable_voice_transcription,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"暂时不支持的消息组件类型: {type(component)}")
|
||||
|
||||
async def process_image_component(self, component: ImageComponent) -> str:
|
||||
async def process_image_component(
|
||||
self,
|
||||
component: ImageComponent,
|
||||
*,
|
||||
enable_heavy_media_analysis: bool = True,
|
||||
) -> str:
|
||||
"""处理图片组件。
|
||||
|
||||
Args:
|
||||
component: 图片组件。
|
||||
enable_heavy_media_analysis: 是否同步执行图片描述生成。
|
||||
|
||||
Returns:
|
||||
str: 图片组件对应的文本表示。
|
||||
"""
|
||||
if component.content: # 先检查是否处理过
|
||||
return component.content
|
||||
from src.chat.image_system.image_manager import image_manager
|
||||
|
||||
# 获取描述
|
||||
try:
|
||||
desc = await image_manager.get_image_description(image_bytes=component.binary_data)
|
||||
desc = await image_manager.get_image_description(
|
||||
image_bytes=component.binary_data,
|
||||
wait_for_build=enable_heavy_media_analysis,
|
||||
)
|
||||
except Exception:
|
||||
desc = None # 失败置空
|
||||
|
||||
content = f"[图片:{desc}]" if desc else "[一张图片,网卡了加载不出来]"
|
||||
content = f"[图片:{desc}]" if desc else "[图片]"
|
||||
component.content = content
|
||||
component.binary_data = b"" # 处理完就丢掉二进制数据,节省内存
|
||||
return content
|
||||
|
||||
async def process_emoji_component(self, component: EmojiComponent) -> str:
|
||||
async def process_emoji_component(
|
||||
self,
|
||||
component: EmojiComponent,
|
||||
*,
|
||||
enable_heavy_media_analysis: bool = True,
|
||||
) -> str:
|
||||
"""处理表情包组件。
|
||||
|
||||
Args:
|
||||
component: 表情包组件。
|
||||
enable_heavy_media_analysis: 是否同步执行表情包描述生成。
|
||||
|
||||
Returns:
|
||||
str: 表情包组件对应的文本表示。
|
||||
"""
|
||||
if component.content: # 先检查是否处理过
|
||||
return component.content
|
||||
from src.chat.emoji_system.emoji_manager import emoji_manager
|
||||
|
||||
# 获取表情包描述
|
||||
try:
|
||||
tuple_content = await emoji_manager.get_emoji_description(emoji_bytes=component.binary_data)
|
||||
tuple_content = await emoji_manager.get_emoji_description(
|
||||
emoji_bytes=component.binary_data,
|
||||
wait_for_build=enable_heavy_media_analysis,
|
||||
)
|
||||
except Exception:
|
||||
tuple_content = None # 失败置空
|
||||
|
||||
@@ -96,7 +185,7 @@ class SessionMessage(MaiMessage):
|
||||
desc, _ = tuple_content
|
||||
content = f"[表情包: {desc}]"
|
||||
else:
|
||||
content = "[一个表情,网卡了加载不出来]"
|
||||
content = "[表情包]"
|
||||
component.content = content
|
||||
component.binary_data = b"" # 处理完就丢掉二进制数据,节省内存
|
||||
return content
|
||||
@@ -124,9 +213,26 @@ class SessionMessage(MaiMessage):
|
||||
else: # 最后使用用户ID
|
||||
return f"@{component.target_user_id}"
|
||||
|
||||
async def process_voice_component(self, component: VoiceComponent) -> str:
|
||||
async def process_voice_component(
|
||||
self,
|
||||
component: VoiceComponent,
|
||||
*,
|
||||
enable_voice_transcription: bool = True,
|
||||
) -> str:
|
||||
"""处理语音组件。
|
||||
|
||||
Args:
|
||||
component: 语音组件。
|
||||
enable_voice_transcription: 是否同步执行语音转写。
|
||||
|
||||
Returns:
|
||||
str: 语音组件对应的文本表示。
|
||||
"""
|
||||
if component.content: # 先检查是否处理过
|
||||
return component.content
|
||||
if not enable_voice_transcription:
|
||||
component.content = "[语音消息]"
|
||||
return component.content
|
||||
from src.common.utils.utils_voice import get_voice_text
|
||||
|
||||
text = await get_voice_text(component.binary_data)
|
||||
@@ -169,13 +275,37 @@ class SessionMessage(MaiMessage):
|
||||
return "[回复了一条消息,但原消息已无法访问]"
|
||||
|
||||
async def process_forward_component(
|
||||
self, component: ForwardNodeComponent, id_content_map: MsgIDMapping, recursion_depth: int = 0
|
||||
self,
|
||||
component: ForwardNodeComponent,
|
||||
id_content_map: MsgIDMapping,
|
||||
recursion_depth: int = 0,
|
||||
*,
|
||||
enable_heavy_media_analysis: bool = True,
|
||||
enable_voice_transcription: bool = True,
|
||||
) -> str:
|
||||
"""处理合并转发组件。
|
||||
|
||||
Args:
|
||||
component: 合并转发组件。
|
||||
id_content_map: 回复消息解析缓存。
|
||||
recursion_depth: 当前递归深度。
|
||||
enable_heavy_media_analysis: 是否同步执行图片与表情包描述生成。
|
||||
enable_voice_transcription: 是否同步执行语音转写。
|
||||
|
||||
Returns:
|
||||
str: 合并转发组件对应的文本表示。
|
||||
"""
|
||||
task_list: List[Task] = []
|
||||
node_user_info_list: List[UserInfo] = []
|
||||
for node in component.forward_components:
|
||||
task = asyncio.create_task(
|
||||
self._process_multiple_components(node.content, id_content_map, recursion_depth + 1)
|
||||
self._process_multiple_components(
|
||||
node.content,
|
||||
id_content_map,
|
||||
recursion_depth + 1,
|
||||
enable_heavy_media_analysis=enable_heavy_media_analysis,
|
||||
enable_voice_transcription=enable_voice_transcription,
|
||||
)
|
||||
)
|
||||
node_user_info = UserInfo(node.user_id or "未知用户", node.user_nickname, node.user_cardname)
|
||||
# 传入ID缓存映射,方便Reply组件获取并等待处理结果
|
||||
@@ -196,9 +326,36 @@ class SessionMessage(MaiMessage):
|
||||
return "【合并转发消息: \n" + "\n".join(forward_texts) + "\n】"
|
||||
|
||||
async def _process_multiple_components(
|
||||
self, components: Sequence[StandardMessageComponents], id_content_map: MsgIDMapping, recursion_depth: int = 0
|
||||
self,
|
||||
components: Sequence[StandardMessageComponents],
|
||||
id_content_map: MsgIDMapping,
|
||||
recursion_depth: int = 0,
|
||||
*,
|
||||
enable_heavy_media_analysis: bool = True,
|
||||
enable_voice_transcription: bool = True,
|
||||
) -> str:
|
||||
tasks = [self.process_single_component(component, id_content_map, recursion_depth) for component in components]
|
||||
"""并行处理多个消息组件。
|
||||
|
||||
Args:
|
||||
components: 待处理的组件序列。
|
||||
id_content_map: 回复消息解析缓存。
|
||||
recursion_depth: 当前递归深度。
|
||||
enable_heavy_media_analysis: 是否同步执行图片与表情包描述生成。
|
||||
enable_voice_transcription: 是否同步执行语音转写。
|
||||
|
||||
Returns:
|
||||
str: 多个组件拼接后的文本表示。
|
||||
"""
|
||||
tasks = [
|
||||
self.process_single_component(
|
||||
component,
|
||||
id_content_map,
|
||||
recursion_depth,
|
||||
enable_heavy_media_analysis=enable_heavy_media_analysis,
|
||||
enable_voice_transcription=enable_voice_transcription,
|
||||
)
|
||||
for component in components
|
||||
]
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True) # 并行处理多个组件
|
||||
processed_texts: List[str] = []
|
||||
for result in results:
|
||||
|
||||
@@ -1,137 +0,0 @@
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
from src.chat.message_receive.chat_manager import BotChatSession
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
from src.common.logger import get_logger
|
||||
from src.core.types import ActionInfo
|
||||
from src.plugin_runtime.component_query import ActionExecutor, component_query_service
|
||||
|
||||
logger = get_logger("action_manager")
|
||||
|
||||
|
||||
class ActionHandle:
|
||||
"""Action 执行句柄
|
||||
|
||||
不依赖任何插件基类,内部持有 executor (async callable) 和绑定参数。
|
||||
brain_chat 调用 ``await handle.execute()`` 即可。
|
||||
"""
|
||||
|
||||
def __init__(self, executor: ActionExecutor, **kwargs):
|
||||
self._executor = executor
|
||||
self._kwargs = kwargs
|
||||
|
||||
async def execute(self) -> Tuple[bool, str]:
|
||||
return await self._executor(**self._kwargs)
|
||||
|
||||
|
||||
class ActionManager:
|
||||
"""
|
||||
动作管理器,用于管理各种类型的动作
|
||||
|
||||
使用插件运行时统一查询服务的 executor-based 模式。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""初始化动作管理器"""
|
||||
|
||||
# 当前正在使用的动作集合,默认加载默认动作
|
||||
self._using_actions: Dict[str, ActionInfo] = {}
|
||||
|
||||
# 初始化时将默认动作加载到使用中的动作
|
||||
self._using_actions = component_query_service.get_default_actions()
|
||||
|
||||
# === 执行Action方法 ===
|
||||
|
||||
def create_action(
|
||||
self,
|
||||
action_name: str,
|
||||
action_data: dict,
|
||||
action_reasoning: str,
|
||||
cycle_timers: dict,
|
||||
thinking_id: str,
|
||||
chat_stream: BotChatSession,
|
||||
log_prefix: str,
|
||||
shutting_down: bool = False,
|
||||
action_message: Optional[SessionMessage] = None,
|
||||
) -> Optional[ActionHandle]:
|
||||
"""
|
||||
创建动作执行句柄
|
||||
|
||||
Args:
|
||||
action_name: 动作名称
|
||||
action_data: 动作数据
|
||||
action_reasoning: 执行理由
|
||||
cycle_timers: 计时器字典
|
||||
thinking_id: 思考ID
|
||||
chat_stream: 聊天流
|
||||
log_prefix: 日志前缀
|
||||
shutting_down: 是否正在关闭
|
||||
action_message: 动作消息记录
|
||||
|
||||
Returns:
|
||||
Optional[ActionHandle]: 执行句柄,如果动作未注册则返回 None
|
||||
"""
|
||||
try:
|
||||
executor = component_query_service.get_action_executor(action_name)
|
||||
if not executor:
|
||||
logger.warning(f"{log_prefix} 未找到Action组件: {action_name}")
|
||||
return None
|
||||
|
||||
info = component_query_service.get_action_info(action_name)
|
||||
if not info:
|
||||
logger.warning(f"{log_prefix} 未找到Action组件信息: {action_name}")
|
||||
return None
|
||||
|
||||
plugin_config = component_query_service.get_plugin_config(info.plugin_name) or {}
|
||||
|
||||
handle = ActionHandle(
|
||||
executor,
|
||||
action_data=action_data,
|
||||
action_reasoning=action_reasoning,
|
||||
cycle_timers=cycle_timers,
|
||||
thinking_id=thinking_id,
|
||||
chat_stream=chat_stream,
|
||||
log_prefix=log_prefix,
|
||||
shutting_down=shutting_down,
|
||||
plugin_config=plugin_config,
|
||||
action_message=action_message,
|
||||
)
|
||||
|
||||
logger.debug(f"创建Action执行句柄成功: {action_name}")
|
||||
return handle
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"创建Action执行句柄失败 {action_name}: {e}")
|
||||
import traceback
|
||||
|
||||
logger.error(traceback.format_exc())
|
||||
return None
|
||||
|
||||
def get_using_actions(self) -> Dict[str, ActionInfo]:
|
||||
"""获取当前正在使用的动作集合"""
|
||||
return self._using_actions.copy()
|
||||
|
||||
# === Modify相关方法 ===
|
||||
def remove_action_from_using(self, action_name: str) -> bool:
|
||||
"""
|
||||
从当前使用的动作集中移除指定动作
|
||||
|
||||
Args:
|
||||
action_name: 动作名称
|
||||
|
||||
Returns:
|
||||
bool: 移除是否成功
|
||||
"""
|
||||
if action_name not in self._using_actions:
|
||||
logger.warning(f"移除失败: 动作 {action_name} 不在当前使用的动作集中")
|
||||
return False
|
||||
|
||||
del self._using_actions[action_name]
|
||||
logger.debug(f"已从使用集中移除动作 {action_name}")
|
||||
return True
|
||||
|
||||
def restore_actions(self) -> None:
|
||||
"""恢复到默认动作集"""
|
||||
actions_to_restore = list(self._using_actions.keys())
|
||||
self._using_actions = component_query_service.get_default_actions()
|
||||
logger.debug(f"恢复动作集: 从 {actions_to_restore} 恢复到默认动作集 {list(self._using_actions.keys())}")
|
||||
@@ -1,233 +0,0 @@
|
||||
import random
|
||||
import time
|
||||
from typing import List, Dict, Tuple
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.chat.message_receive.chat_manager import BotChatSession, chat_manager as _chat_manager
|
||||
from src.chat.planner_actions.action_manager import ActionManager
|
||||
from src.services.message_service import build_readable_messages, get_messages_before_time_in_chat
|
||||
from src.core.types import ActionActivationType, ActionInfo
|
||||
from src.core.announcement_manager import global_announcement_manager
|
||||
|
||||
logger = get_logger("action_manager")
|
||||
|
||||
|
||||
class ActionModifier:
|
||||
"""动作处理器
|
||||
|
||||
用于处理Observation对象和根据激活类型处理actions。
|
||||
集成了原有的modify_actions功能和新的激活类型处理功能。
|
||||
支持并行判定和智能缓存优化。
|
||||
"""
|
||||
|
||||
def __init__(self, action_manager: ActionManager, chat_id: str):
|
||||
"""初始化动作处理器"""
|
||||
self.chat_id = chat_id
|
||||
self.chat_stream: BotChatSession = _chat_manager.get_session_by_session_id(self.chat_id) # type: ignore
|
||||
self.log_prefix = f"[{_chat_manager.get_session_name(self.chat_id) or self.chat_id}]"
|
||||
|
||||
self.action_manager = action_manager
|
||||
|
||||
async def modify_actions(
|
||||
self,
|
||||
message_content: str = "",
|
||||
): # sourcery skip: use-named-expression
|
||||
"""
|
||||
动作修改流程,整合传统观察处理和新的激活类型判定
|
||||
|
||||
这个方法处理完整的动作管理流程:
|
||||
1. 基于观察的传统动作修改(循环历史分析、类型匹配等)
|
||||
2. 基于激活类型的智能动作判定,最终确定可用动作集
|
||||
|
||||
处理后,ActionManager 将包含最终的可用动作集,供规划器直接使用
|
||||
"""
|
||||
logger.debug(f"{self.log_prefix}开始完整动作修改流程")
|
||||
|
||||
removals_s1: List[Tuple[str, str]] = []
|
||||
removals_s2: List[Tuple[str, str]] = []
|
||||
# removals_s3: List[Tuple[str, str]] = []
|
||||
|
||||
self.action_manager.restore_actions()
|
||||
all_actions = self.action_manager.get_using_actions()
|
||||
|
||||
message_list_before_now_half = get_messages_before_time_in_chat(
|
||||
chat_id=self.chat_stream.session_id,
|
||||
timestamp=time.time(),
|
||||
limit=min(int(global_config.chat.max_context_size * 0.33), 10),
|
||||
filter_intercept_message_level=1,
|
||||
)
|
||||
|
||||
chat_content = build_readable_messages(
|
||||
message_list_before_now_half,
|
||||
replace_bot_name=True,
|
||||
timestamp_mode="relative",
|
||||
read_mark=0.0,
|
||||
show_actions=True,
|
||||
)
|
||||
|
||||
if message_content:
|
||||
chat_content = chat_content + "\n" + f"现在,最新的消息是:{message_content}"
|
||||
|
||||
# === 第一阶段:去除用户自行禁用的 ===
|
||||
disabled_actions = global_announcement_manager.get_disabled_chat_actions(self.chat_id)
|
||||
if disabled_actions:
|
||||
for disabled_action_name in disabled_actions:
|
||||
if disabled_action_name in all_actions:
|
||||
removals_s1.append((disabled_action_name, "用户自行禁用"))
|
||||
self.action_manager.remove_action_from_using(disabled_action_name)
|
||||
logger.debug(f"{self.log_prefix}阶段一移除动作: {disabled_action_name},原因: 用户自行禁用")
|
||||
|
||||
# === 第二阶段:检查动作的关联类型 ===
|
||||
chat_context = self.chat_stream.context
|
||||
type_mismatched_actions = self._check_action_associated_types(all_actions, chat_context)
|
||||
|
||||
if type_mismatched_actions:
|
||||
removals_s2.extend(type_mismatched_actions)
|
||||
|
||||
# 应用第二阶段的移除
|
||||
for action_name, reason in removals_s2:
|
||||
self.action_manager.remove_action_from_using(action_name)
|
||||
logger.debug(f"{self.log_prefix}阶段二移除动作: {action_name},原因: {reason}")
|
||||
|
||||
# === 第三阶段:激活类型判定 ===
|
||||
# if chat_content is not None:
|
||||
# logger.debug(f"{self.log_prefix}开始激活类型判定阶段")
|
||||
|
||||
# 获取当前使用的动作集(经过第一阶段处理)
|
||||
# current_using_actions = self.action_manager.get_using_actions()
|
||||
|
||||
# 获取因激活类型判定而需要移除的动作
|
||||
# removals_s3 = await self._get_deactivated_actions_by_type(
|
||||
# current_using_actions,
|
||||
# chat_content,
|
||||
# )
|
||||
|
||||
# 应用第三阶段的移除
|
||||
# for action_name, reason in removals_s3:
|
||||
# self.action_manager.remove_action_from_using(action_name)
|
||||
# logger.debug(f"{self.log_prefix}阶段三移除动作: {action_name},原因: {reason}")
|
||||
|
||||
# === 统一日志记录 ===
|
||||
all_removals = removals_s1 + removals_s2
|
||||
removals_summary: str = ""
|
||||
if all_removals:
|
||||
removals_summary = " | ".join([f"{name}({reason})" for name, reason in all_removals])
|
||||
|
||||
available_actions = list(self.action_manager.get_using_actions().keys())
|
||||
available_actions_text = "、".join(available_actions) if available_actions else "无"
|
||||
logger.debug(f"{self.log_prefix} 当前可用动作: {available_actions_text}||移除: {removals_summary}")
|
||||
|
||||
def _check_action_associated_types(self, all_actions: Dict[str, ActionInfo], chat_context: BotChatSession):
|
||||
type_mismatched_actions: List[Tuple[str, str]] = []
|
||||
for action_name, action_info in all_actions.items():
|
||||
if action_info.associated_types and not chat_context.check_types(action_info.associated_types):
|
||||
associated_types_str = ", ".join(action_info.associated_types)
|
||||
reason = f"适配器不支持(需要: {associated_types_str})"
|
||||
type_mismatched_actions.append((action_name, reason))
|
||||
logger.debug(f"{self.log_prefix}决定移除动作: {action_name},原因: {reason}")
|
||||
return type_mismatched_actions
|
||||
|
||||
async def _get_deactivated_actions_by_type(
|
||||
self,
|
||||
actions_with_info: Dict[str, ActionInfo],
|
||||
chat_content: str = "",
|
||||
) -> List[tuple[str, str]]:
|
||||
"""
|
||||
根据激活类型过滤,返回需要停用的动作列表及原因
|
||||
|
||||
Args:
|
||||
actions_with_info: 带完整信息的动作字典
|
||||
chat_content: 聊天内容
|
||||
|
||||
Returns:
|
||||
List[Tuple[str, str]]: 需要停用的 (action_name, reason) 元组列表
|
||||
"""
|
||||
deactivated_actions = []
|
||||
|
||||
actions_to_check = list(actions_with_info.items())
|
||||
random.shuffle(actions_to_check)
|
||||
|
||||
for action_name, action_info in actions_to_check:
|
||||
activation_type = action_info.activation_type or action_info.focus_activation_type
|
||||
|
||||
if activation_type == ActionActivationType.ALWAYS:
|
||||
continue # 总是激活,无需处理
|
||||
|
||||
elif activation_type == ActionActivationType.RANDOM:
|
||||
probability = action_info.random_activation_probability
|
||||
if random.random() >= probability:
|
||||
reason = f"RANDOM类型未触发(概率{probability})"
|
||||
deactivated_actions.append((action_name, reason))
|
||||
logger.debug(f"{self.log_prefix}未激活动作: {action_name},原因: {reason}")
|
||||
|
||||
elif activation_type == ActionActivationType.KEYWORD:
|
||||
if not self._check_keyword_activation(action_name, action_info, chat_content):
|
||||
keywords = action_info.activation_keywords
|
||||
reason = f"关键词未匹配(关键词: {keywords})"
|
||||
deactivated_actions.append((action_name, reason))
|
||||
logger.debug(f"{self.log_prefix}未激活动作: {action_name},原因: {reason}")
|
||||
|
||||
elif activation_type == ActionActivationType.NEVER:
|
||||
reason = "激活类型为never"
|
||||
deactivated_actions.append((action_name, reason))
|
||||
logger.debug(f"{self.log_prefix}未激活动作: {action_name},原因: 激活类型为never")
|
||||
|
||||
else:
|
||||
logger.warning(f"{self.log_prefix}未知的激活类型: {activation_type},跳过处理")
|
||||
|
||||
return deactivated_actions
|
||||
|
||||
def _check_keyword_activation(
|
||||
self,
|
||||
action_name: str,
|
||||
action_info: ActionInfo,
|
||||
chat_content: str = "",
|
||||
) -> bool:
|
||||
"""
|
||||
检查是否匹配关键词触发条件
|
||||
|
||||
Args:
|
||||
action_name: 动作名称
|
||||
action_info: 动作信息
|
||||
observed_messages_str: 观察到的聊天消息
|
||||
chat_context: 聊天上下文
|
||||
extra_context: 额外上下文
|
||||
|
||||
Returns:
|
||||
bool: 是否应该激活此action
|
||||
"""
|
||||
|
||||
activation_keywords = action_info.activation_keywords
|
||||
case_sensitive = action_info.keyword_case_sensitive
|
||||
|
||||
if not activation_keywords:
|
||||
logger.warning(f"{self.log_prefix}动作 {action_name} 设置为关键词触发但未配置关键词")
|
||||
return False
|
||||
|
||||
# 构建检索文本
|
||||
search_text = ""
|
||||
if chat_content:
|
||||
search_text += chat_content
|
||||
# if chat_context:
|
||||
# search_text += f" {chat_context}"
|
||||
# if extra_context:
|
||||
# search_text += f" {extra_context}"
|
||||
|
||||
# 如果不区分大小写,转换为小写
|
||||
if not case_sensitive:
|
||||
search_text = search_text.lower()
|
||||
|
||||
# 检查每个关键词
|
||||
matched_keywords = []
|
||||
for keyword in activation_keywords:
|
||||
check_keyword = keyword if case_sensitive else keyword.lower()
|
||||
if check_keyword in search_text:
|
||||
matched_keywords.append(keyword)
|
||||
|
||||
if matched_keywords:
|
||||
logger.debug(f"{self.log_prefix}动作 {action_name} 匹配到关键词: {matched_keywords}")
|
||||
return True
|
||||
else:
|
||||
logger.debug(f"{self.log_prefix}动作 {action_name} 未匹配到任何关键词: {activation_keywords}")
|
||||
return False
|
||||
@@ -1,933 +0,0 @@
|
||||
from collections import OrderedDict
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import contextlib
|
||||
import json
|
||||
import random
|
||||
import re
|
||||
import time
|
||||
import traceback
|
||||
|
||||
from json_repair import repair_json
|
||||
from rich.traceback import install
|
||||
|
||||
from src.chat.logger.plan_reply_logger import PlanReplyLogger
|
||||
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
from src.chat.planner_actions.action_manager import ActionManager
|
||||
from src.chat.utils.utils import get_chat_type_and_target_info, is_bot_self
|
||||
from src.common.data_models.info_data_model import ActionPlannerInfo
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.core.types import ActionActivationType, ActionInfo, ComponentType
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.person_info.person_info import Person
|
||||
from src.plugin_runtime.component_query import component_query_service
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
from src.services.message_service import (
|
||||
build_readable_messages_with_id,
|
||||
get_messages_before_time_in_chat,
|
||||
replace_user_references,
|
||||
translate_pid_to_description,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.common.data_models.info_data_model import TargetPersonInfo
|
||||
|
||||
logger = get_logger("planner")
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
|
||||
class ActionPlanner:
|
||||
def __init__(self, chat_id: str, action_manager: ActionManager):
|
||||
self.chat_id = chat_id
|
||||
self.log_prefix = f"[{_chat_manager.get_session_name(chat_id) or chat_id}]"
|
||||
self.action_manager = action_manager
|
||||
# LLM规划器配置
|
||||
self.planner_llm = LLMRequest(
|
||||
model_set=model_config.model_task_config.planner, request_type="planner"
|
||||
) # 用于动作规划
|
||||
|
||||
self.last_obs_time_mark = 0.0
|
||||
|
||||
self.plan_log: List[Tuple[str, float, Union[List[ActionPlannerInfo], str]]] = []
|
||||
|
||||
# 黑话缓存:使用 OrderedDict 实现 LRU,最多缓存10个
|
||||
self.unknown_words_cache: OrderedDict[str, None] = OrderedDict()
|
||||
self.unknown_words_cache_limit = 10
|
||||
|
||||
def find_message_by_id(
|
||||
self, message_id: str, message_id_list: List[Tuple[str, "SessionMessage"]]
|
||||
) -> Optional["SessionMessage"]:
|
||||
# sourcery skip: use-next
|
||||
"""
|
||||
根据message_id从message_id_list中查找对应的原始消息
|
||||
|
||||
Args:
|
||||
message_id: 要查找的消息ID
|
||||
message_id_list: 消息ID列表,格式为[{'id': str, 'message': dict}, ...]
|
||||
|
||||
Returns:
|
||||
找到的原始消息字典,如果未找到则返回None
|
||||
"""
|
||||
for item in message_id_list:
|
||||
if item[0] == message_id:
|
||||
return item[1]
|
||||
return None
|
||||
|
||||
def _replace_message_ids_with_text(
|
||||
self, text: Optional[str], message_id_list: List[Tuple[str, "SessionMessage"]]
|
||||
) -> Optional[str]:
|
||||
"""将文本中的 m+数字 消息ID替换为原消息内容,并添加双引号"""
|
||||
if not text:
|
||||
return text
|
||||
|
||||
id_to_message = dict(message_id_list)
|
||||
|
||||
# 匹配m后带2-4位数字,前后不是字母数字下划线
|
||||
pattern = r"(?<![A-Za-z0-9_])m\d{2,4}(?![A-Za-z0-9_])"
|
||||
|
||||
matches = re.findall(pattern, text)
|
||||
if matches:
|
||||
available_ids = set(id_to_message.keys())
|
||||
found_ids = set(matches)
|
||||
missing_ids = found_ids - available_ids
|
||||
if missing_ids:
|
||||
logger.info(
|
||||
f"{self.log_prefix}planner理由中引用的消息ID不在当前上下文中: {missing_ids}, 可用ID: {list(available_ids)[:10]}..."
|
||||
)
|
||||
logger.info(
|
||||
f"{self.log_prefix}planner理由替换: 找到{len(matches)}个消息ID引用,其中{len(found_ids & available_ids)}个在上下文中"
|
||||
)
|
||||
|
||||
def _replace(match: re.Match[str]) -> str:
|
||||
msg_id = match.group(0)
|
||||
message = id_to_message.get(msg_id)
|
||||
if not message:
|
||||
logger.warning(f"{self.log_prefix}planner理由引用 {msg_id} 未找到对应消息,保持原样")
|
||||
return msg_id
|
||||
|
||||
msg_text = (message.processed_plain_text or "").strip()
|
||||
if not msg_text:
|
||||
logger.warning(f"{self.log_prefix}planner理由引用 {msg_id} 的消息内容为空,保持原样")
|
||||
return msg_id
|
||||
|
||||
# 替换 [picid:xxx] 为 [图片:描述]
|
||||
pic_pattern = r"\[picid:([^\]]+)\]"
|
||||
|
||||
def replace_pic_id(pic_match: re.Match) -> str:
|
||||
pic_id = pic_match.group(1)
|
||||
description = translate_pid_to_description(pic_id)
|
||||
return f"[图片:{description}]"
|
||||
|
||||
msg_text = re.sub(pic_pattern, replace_pic_id, msg_text)
|
||||
|
||||
# 替换用户引用格式:回复<aaa:bbb> 和 @<aaa:bbb>
|
||||
platform = message.platform or ""
|
||||
if not platform:
|
||||
logger.warning(
|
||||
f"{self.log_prefix}planner: message {message.message_id} has no platform set, bot-self detection will be skipped"
|
||||
)
|
||||
msg_text = replace_user_references(msg_text, platform, replace_bot_name=True)
|
||||
|
||||
# 替换单独的 <用户名:用户ID> 格式(replace_user_references 已处理回复<和@<格式)
|
||||
# 匹配所有 <aaa:bbb> 格式,由于 replace_user_references 已经替换了回复<和@<格式,
|
||||
# 这里匹配到的应该都是单独的格式
|
||||
user_ref_pattern = r"<([^:<>]+):([^:<>]+)>"
|
||||
|
||||
def replace_user_ref(user_match: re.Match) -> str:
|
||||
user_name = user_match.group(1)
|
||||
user_id = user_match.group(2)
|
||||
try:
|
||||
# 检查是否是机器人自己
|
||||
if is_bot_self(platform, str(user_id)):
|
||||
return f"{global_config.bot.nickname}(你)"
|
||||
person = Person(platform=platform, user_id=user_id)
|
||||
return person.person_name or user_name
|
||||
except Exception:
|
||||
# 如果解析失败,使用原始昵称
|
||||
return user_name
|
||||
|
||||
msg_text = re.sub(user_ref_pattern, replace_user_ref, msg_text)
|
||||
|
||||
preview = msg_text if len(msg_text) <= 100 else f"{msg_text[:97]}..."
|
||||
logger.info(f"{self.log_prefix}planner理由引用 {msg_id} -> 消息({preview})")
|
||||
return f"消息({msg_text})"
|
||||
|
||||
return re.sub(pattern, _replace, text)
|
||||
|
||||
def _parse_single_action(
|
||||
self,
|
||||
action_json: dict,
|
||||
message_id_list: List[Tuple[str, "SessionMessage"]],
|
||||
current_available_actions: List[Tuple[str, ActionInfo]],
|
||||
extracted_reasoning: str = "",
|
||||
) -> List[ActionPlannerInfo]:
|
||||
"""解析单个action JSON并返回ActionPlannerInfo列表"""
|
||||
action_planner_infos = []
|
||||
|
||||
try:
|
||||
action = action_json.get("action", "no_reply")
|
||||
# 使用 extracted_reasoning(整体推理文本)作为 reasoning
|
||||
if extracted_reasoning:
|
||||
reasoning = self._replace_message_ids_with_text(extracted_reasoning, message_id_list)
|
||||
if reasoning is None:
|
||||
reasoning = extracted_reasoning
|
||||
else:
|
||||
reasoning = "未提供原因"
|
||||
action_data = {key: value for key, value in action_json.items() if key not in ["action"]}
|
||||
|
||||
# 非no_reply动作需要target_message_id
|
||||
target_message = None
|
||||
|
||||
target_message_id = action_json.get("target_message_id")
|
||||
if target_message_id:
|
||||
# 根据target_message_id查找原始消息
|
||||
target_message = self.find_message_by_id(target_message_id, message_id_list)
|
||||
if target_message is None:
|
||||
logger.warning(f"{self.log_prefix}无法找到target_message_id '{target_message_id}' 对应的消息")
|
||||
# 选择最新消息作为target_message
|
||||
target_message = message_id_list[-1][1]
|
||||
else:
|
||||
target_message = message_id_list[-1][1]
|
||||
logger.debug(f"{self.log_prefix}动作'{action}'缺少target_message_id,使用最新消息作为target_message")
|
||||
|
||||
if action != "no_reply" and target_message is not None and self._is_message_from_self(target_message):
|
||||
logger.info(
|
||||
f"{self.log_prefix}Planner选择了自己的消息 {target_message_id or target_message.message_id} 作为目标,强制使用 no_reply"
|
||||
)
|
||||
reasoning = f"目标消息 {target_message_id or target_message.message_id} 来自机器人自身,违反不回复自身消息规则。原始理由: {reasoning}"
|
||||
action = "no_reply"
|
||||
target_message = None
|
||||
|
||||
# 验证action是否可用
|
||||
available_action_names = [action_name for action_name, _ in current_available_actions]
|
||||
internal_action_names = ["no_reply", "reply", "wait_time"]
|
||||
|
||||
if action not in internal_action_names and action not in available_action_names:
|
||||
logger.warning(
|
||||
f"{self.log_prefix}LLM 返回了当前不可用或无效的动作: '{action}' (可用: {available_action_names}),将强制使用 'no_reply'"
|
||||
)
|
||||
reasoning = (
|
||||
f"LLM 返回了当前不可用的动作 '{action}' (可用: {available_action_names})。原始理由: {reasoning}"
|
||||
)
|
||||
action = "no_reply"
|
||||
|
||||
# 创建ActionPlannerInfo对象
|
||||
# 将列表转换为字典格式
|
||||
available_actions_dict = dict(current_available_actions)
|
||||
action_planner_infos.append(
|
||||
ActionPlannerInfo(
|
||||
action_type=action,
|
||||
reasoning=reasoning,
|
||||
action_data=action_data,
|
||||
action_message=target_message,
|
||||
available_actions=available_actions_dict,
|
||||
action_reasoning=extracted_reasoning or None,
|
||||
)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix}解析单个action时出错: {e}")
|
||||
# 将列表转换为字典格式
|
||||
available_actions_dict = dict(current_available_actions)
|
||||
action_planner_infos.append(
|
||||
ActionPlannerInfo(
|
||||
action_type="no_reply",
|
||||
reasoning=f"解析单个action时出错: {e}",
|
||||
action_data={},
|
||||
action_message=None,
|
||||
available_actions=available_actions_dict,
|
||||
action_reasoning=extracted_reasoning or None,
|
||||
)
|
||||
)
|
||||
|
||||
return action_planner_infos
|
||||
|
||||
def _is_message_from_self(self, message: "SessionMessage") -> bool:
|
||||
"""判断消息是否由机器人自身发送(支持多平台,包括 WebUI)"""
|
||||
try:
|
||||
return is_bot_self(message.platform or "", str(message.message_info.user_info.user_id))
|
||||
except AttributeError:
|
||||
logger.warning(f"{self.log_prefix}检测消息发送者失败,缺少必要字段")
|
||||
return False
|
||||
|
||||
def _update_unknown_words_cache(self, new_words: List[str]) -> None:
|
||||
"""
|
||||
更新黑话缓存,将新的黑话加入缓存
|
||||
|
||||
Args:
|
||||
new_words: 新提取的黑话列表
|
||||
"""
|
||||
for word in new_words:
|
||||
if not isinstance(word, str):
|
||||
continue
|
||||
word = word.strip()
|
||||
if not word:
|
||||
continue
|
||||
|
||||
# 如果已存在,移到末尾(LRU)
|
||||
if word in self.unknown_words_cache:
|
||||
self.unknown_words_cache.move_to_end(word)
|
||||
else:
|
||||
# 添加新词
|
||||
self.unknown_words_cache[word] = None
|
||||
# 如果超过限制,移除最老的
|
||||
if len(self.unknown_words_cache) > self.unknown_words_cache_limit:
|
||||
self.unknown_words_cache.popitem(last=False)
|
||||
logger.debug(f"{self.log_prefix}黑话缓存已满,移除最老的黑话")
|
||||
|
||||
def _merge_unknown_words_with_cache(self, new_words: Optional[List[str]]) -> List[str]:
|
||||
"""
|
||||
合并新提取的黑话和缓存中的黑话
|
||||
|
||||
Args:
|
||||
new_words: 新提取的黑话列表(可能为None)
|
||||
|
||||
Returns:
|
||||
合并后的黑话列表(去重)
|
||||
"""
|
||||
# 清理新提取的黑话
|
||||
cleaned_new_words: List[str] = []
|
||||
if new_words:
|
||||
for word in new_words:
|
||||
if isinstance(word, str):
|
||||
if word := word.strip():
|
||||
cleaned_new_words.append(word)
|
||||
|
||||
# 获取缓存中的黑话列表
|
||||
cached_words = list(self.unknown_words_cache.keys())
|
||||
|
||||
# 合并并去重(保留顺序:新提取的在前,缓存的在后)
|
||||
merged_words: List[str] = []
|
||||
seen = set()
|
||||
|
||||
# 先添加新提取的
|
||||
for word in cleaned_new_words:
|
||||
if word not in seen:
|
||||
merged_words.append(word)
|
||||
seen.add(word)
|
||||
|
||||
# 再添加缓存的(如果不在新提取的列表中)
|
||||
for word in cached_words:
|
||||
if word not in seen:
|
||||
merged_words.append(word)
|
||||
seen.add(word)
|
||||
|
||||
return merged_words
|
||||
|
||||
def _process_unknown_words_cache(self, actions: List[ActionPlannerInfo]) -> None:
|
||||
"""
|
||||
处理黑话缓存逻辑:
|
||||
1. 检查是否有 reply action 提取了 unknown_words
|
||||
2. 如果没有提取,移除最老的1个
|
||||
3. 如果缓存数量大于5,移除最老的2个
|
||||
4. 对于每个 reply action,合并缓存和新提取的黑话
|
||||
5. 更新缓存
|
||||
|
||||
Args:
|
||||
actions: 解析后的动作列表
|
||||
"""
|
||||
# 先检查缓存数量,如果大于5,移除最老的2个
|
||||
if len(self.unknown_words_cache) > 5:
|
||||
# 移除最老的2个
|
||||
removed_count = 0
|
||||
for _ in range(2):
|
||||
if len(self.unknown_words_cache) > 0:
|
||||
self.unknown_words_cache.popitem(last=False)
|
||||
removed_count += 1
|
||||
if removed_count > 0:
|
||||
logger.debug(f"{self.log_prefix}缓存数量大于5,移除最老的{removed_count}个缓存")
|
||||
|
||||
# 检查是否有 reply action 提取了 unknown_words
|
||||
has_extracted_unknown_words = False
|
||||
for action in actions:
|
||||
if action.action_type == "reply":
|
||||
action_data = action.action_data or {}
|
||||
unknown_words = action_data.get("unknown_words")
|
||||
if unknown_words and isinstance(unknown_words, list) and len(unknown_words) > 0:
|
||||
has_extracted_unknown_words = True
|
||||
break
|
||||
|
||||
# 如果当前 plan 的 reply 没有提取,移除最老的1个
|
||||
if not has_extracted_unknown_words and len(self.unknown_words_cache) > 0:
|
||||
self.unknown_words_cache.popitem(last=False)
|
||||
logger.debug(f"{self.log_prefix}当前 plan 的 reply 没有提取黑话,移除最老的1个缓存")
|
||||
|
||||
# 对于每个 reply action,合并缓存和新提取的黑话
|
||||
for action in actions:
|
||||
if action.action_type == "reply":
|
||||
action_data = action.action_data or {}
|
||||
new_words = action_data.get("unknown_words")
|
||||
|
||||
# 合并新提取的和缓存的黑话列表
|
||||
if merged_words := self._merge_unknown_words_with_cache(new_words):
|
||||
action_data["unknown_words"] = merged_words
|
||||
logger.debug(
|
||||
f"{self.log_prefix}合并黑话:新提取 {len(new_words) if new_words else 0} 个,"
|
||||
f"缓存 {len(self.unknown_words_cache)} 个,合并后 {len(merged_words)} 个"
|
||||
)
|
||||
else:
|
||||
# 如果没有合并后的黑话,移除 unknown_words 字段
|
||||
action_data.pop("unknown_words", None)
|
||||
|
||||
# 更新缓存(将新提取的黑话加入缓存)
|
||||
if new_words:
|
||||
self._update_unknown_words_cache(new_words)
|
||||
|
||||
async def plan(
|
||||
self,
|
||||
available_actions: Dict[str, ActionInfo],
|
||||
loop_start_time: float = 0.0,
|
||||
force_reply_message: Optional["SessionMessage"] = None,
|
||||
) -> List[ActionPlannerInfo]:
|
||||
# sourcery skip: use-named-expression
|
||||
"""
|
||||
规划器 (Planner): 使用LLM根据上下文决定做出什么动作。
|
||||
"""
|
||||
plan_start = time.perf_counter()
|
||||
|
||||
# 获取聊天上下文
|
||||
message_list_before_now = get_messages_before_time_in_chat(
|
||||
chat_id=self.chat_id,
|
||||
timestamp=time.time(),
|
||||
limit=int(global_config.chat.max_context_size * 0.6),
|
||||
filter_intercept_message_level=1,
|
||||
)
|
||||
message_id_list: list[Tuple[str, "SessionMessage"]] = []
|
||||
chat_content_block, message_id_list = build_readable_messages_with_id(
|
||||
messages=message_list_before_now,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
read_mark=self.last_obs_time_mark,
|
||||
truncate=True,
|
||||
show_actions=True,
|
||||
)
|
||||
|
||||
message_list_before_now_short = message_list_before_now[-int(global_config.chat.max_context_size * 0.3) :]
|
||||
chat_content_block_short, message_id_list_short = build_readable_messages_with_id(
|
||||
messages=message_list_before_now_short,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
truncate=False,
|
||||
show_actions=False,
|
||||
)
|
||||
|
||||
self.last_obs_time_mark = time.time()
|
||||
|
||||
# 获取必要信息
|
||||
is_group_chat, chat_target_info, current_available_actions = self.get_necessary_info()
|
||||
|
||||
# 应用激活类型过滤
|
||||
filtered_actions = self._filter_actions_by_activation_type(available_actions, chat_content_block_short)
|
||||
|
||||
logger.debug(f"{self.log_prefix}过滤后有{len(filtered_actions)}个可用动作")
|
||||
|
||||
prompt_build_start = time.perf_counter()
|
||||
# 构建包含所有动作的提示词
|
||||
prompt, message_id_list = await self.build_planner_prompt(
|
||||
is_group_chat=is_group_chat,
|
||||
chat_target_info=chat_target_info,
|
||||
current_available_actions=filtered_actions,
|
||||
chat_content_block=chat_content_block,
|
||||
message_id_list=message_id_list,
|
||||
)
|
||||
prompt_build_ms = (time.perf_counter() - prompt_build_start) * 1000
|
||||
|
||||
# 调用LLM获取决策
|
||||
reasoning, actions, llm_raw_output, llm_reasoning, llm_duration_ms = await self._execute_main_planner(
|
||||
prompt=prompt,
|
||||
message_id_list=message_id_list,
|
||||
filtered_actions=filtered_actions,
|
||||
available_actions=available_actions,
|
||||
loop_start_time=loop_start_time,
|
||||
)
|
||||
|
||||
# 如果有强制回复消息,确保回复该消息
|
||||
if force_reply_message:
|
||||
# 检查是否已经有回复该消息的 action
|
||||
has_reply_to_force_message = any(
|
||||
action.action_type == "reply"
|
||||
and action.action_message
|
||||
and action.action_message.message_id == force_reply_message.message_id
|
||||
for action in actions
|
||||
)
|
||||
|
||||
# 如果没有回复该消息,强制添加回复 action
|
||||
if not has_reply_to_force_message:
|
||||
# 移除所有 no_reply action(如果有)
|
||||
actions = [a for a in actions if a.action_type != "no_reply"]
|
||||
|
||||
# 创建强制回复 action
|
||||
available_actions_dict = dict(current_available_actions)
|
||||
force_reply_action = ActionPlannerInfo(
|
||||
action_type="reply",
|
||||
reasoning="用户提及了我,必须回复该消息",
|
||||
action_data={"loop_start_time": loop_start_time},
|
||||
action_message=force_reply_message,
|
||||
available_actions=available_actions_dict,
|
||||
action_reasoning=None,
|
||||
)
|
||||
# 将强制回复 action 放在最前面
|
||||
actions.insert(0, force_reply_action)
|
||||
logger.info(f"{self.log_prefix} 检测到强制回复消息,已添加回复动作")
|
||||
|
||||
logger.info(
|
||||
f"{self.log_prefix}Planner:{reasoning}。选择了{len(actions)}个动作: {' '.join([a.action_type for a in actions])}"
|
||||
)
|
||||
|
||||
self.add_plan_log(reasoning, actions)
|
||||
|
||||
try:
|
||||
PlanReplyLogger.log_plan(
|
||||
chat_id=self.chat_id,
|
||||
prompt=prompt,
|
||||
reasoning=reasoning,
|
||||
raw_output=llm_raw_output,
|
||||
raw_reasoning=llm_reasoning,
|
||||
actions=actions,
|
||||
timing={
|
||||
"prompt_build_ms": round(prompt_build_ms, 2),
|
||||
"llm_duration_ms": round(llm_duration_ms, 2) if llm_duration_ms is not None else None,
|
||||
"total_plan_ms": round((time.perf_counter() - plan_start) * 1000, 2),
|
||||
"loop_start_time": loop_start_time,
|
||||
},
|
||||
extra=None,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(f"{self.log_prefix}记录plan日志失败")
|
||||
|
||||
return actions
|
||||
|
||||
def add_plan_log(self, reasoning: str, actions: List[ActionPlannerInfo]):
|
||||
self.plan_log.append((reasoning, time.time(), actions))
|
||||
if len(self.plan_log) > 20:
|
||||
self.plan_log.pop(0)
|
||||
|
||||
def add_plan_excute_log(self, result: str):
|
||||
self.plan_log.append(("", time.time(), result))
|
||||
if len(self.plan_log) > 20:
|
||||
self.plan_log.pop(0)
|
||||
|
||||
def get_plan_log_str(self, max_action_records: int = 2, max_execution_records: int = 5) -> str:
|
||||
"""
|
||||
获取计划日志字符串
|
||||
|
||||
Args:
|
||||
max_action_records: 显示多少条最新的action记录,默认2
|
||||
max_execution_records: 显示多少条最新执行结果记录,默认8
|
||||
|
||||
Returns:
|
||||
格式化的日志字符串
|
||||
"""
|
||||
action_records = []
|
||||
execution_records = []
|
||||
|
||||
# 从后往前遍历,收集最新的记录
|
||||
for reasoning, timestamp, content in reversed(self.plan_log):
|
||||
if isinstance(content, list) and all(isinstance(action, ActionPlannerInfo) for action in content):
|
||||
if len(action_records) < max_action_records:
|
||||
action_records.append((reasoning, timestamp, content, "action"))
|
||||
elif len(execution_records) < max_execution_records:
|
||||
execution_records.append((reasoning, timestamp, content, "execution"))
|
||||
|
||||
# 合并所有记录并按时间戳排序
|
||||
all_records = action_records + execution_records
|
||||
all_records.sort(key=lambda x: x[1]) # 按时间戳排序
|
||||
|
||||
plan_log_str = ""
|
||||
|
||||
# 按时间顺序添加所有记录
|
||||
for reasoning, timestamp, content, record_type in all_records:
|
||||
time_str = datetime.fromtimestamp(timestamp).strftime("%H:%M:%S")
|
||||
if record_type == "action":
|
||||
# plan_log_str += f"{time_str}:{reasoning}|你使用了{','.join([action.action_type for action in content])}\n"
|
||||
plan_log_str += f"{time_str}:{reasoning}\n"
|
||||
else:
|
||||
plan_log_str += f"{time_str}:你执行了action:{content}\n"
|
||||
|
||||
return plan_log_str
|
||||
|
||||
async def build_planner_prompt(
|
||||
self,
|
||||
is_group_chat: bool,
|
||||
chat_target_info: Optional["TargetPersonInfo"],
|
||||
current_available_actions: Dict[str, ActionInfo],
|
||||
message_id_list: List[Tuple[str, "SessionMessage"]],
|
||||
chat_content_block: str = "",
|
||||
interest: str = "",
|
||||
) -> tuple[str, List[Tuple[str, "SessionMessage"]]]:
|
||||
"""构建 Planner LLM 的提示词 (获取模板并填充数据)"""
|
||||
try:
|
||||
actions_before_now_block = self.get_plan_log_str()
|
||||
|
||||
# 构建聊天上下文描述
|
||||
chat_context_description = "你现在正在一个群聊中"
|
||||
|
||||
# 构建动作选项块
|
||||
action_options_block = await self._build_action_options_block(current_available_actions)
|
||||
|
||||
# 其他信息
|
||||
moderation_prompt_block = "请不要输出违法违规内容,不要输出色情,暴力,政治相关内容,如有敏感内容,请规避。"
|
||||
time_block = f"当前时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
|
||||
bot_name = global_config.bot.nickname
|
||||
bot_nickname = (
|
||||
f",也有人叫你{','.join(global_config.bot.alias_names)}" if global_config.bot.alias_names else ""
|
||||
)
|
||||
name_block = f"你的名字是{bot_name}{bot_nickname},请注意哪些是你自己的发言。"
|
||||
|
||||
# 根据 think_mode 配置决定 reply action 的示例 JSON
|
||||
# 在 JSON 中直接作为 action 参数携带 unknown_words
|
||||
if global_config.chat.think_mode == "classic":
|
||||
reply_action_example = ""
|
||||
if global_config.chat.llm_quote:
|
||||
reply_action_example += (
|
||||
"5.如果要明确回复消息,使用quote,如果消息不多不需要明确回复,设置quote为false\n"
|
||||
)
|
||||
reply_action_example += (
|
||||
'{{"action":"reply", "target_message_id":"消息id(m+数字)", "unknown_words":["词语1","词语2"]'
|
||||
)
|
||||
if global_config.chat.llm_quote:
|
||||
reply_action_example += ', "quote":"如果需要引用该message,设置为true"'
|
||||
reply_action_example += "}"
|
||||
else:
|
||||
reply_action_example = (
|
||||
"5.think_level表示思考深度,0表示该回复不需要思考和回忆,1表示该回复需要进行回忆和思考\n"
|
||||
)
|
||||
if global_config.chat.llm_quote:
|
||||
reply_action_example += (
|
||||
"6.如果要明确回复消息,使用quote,如果消息不多不需要明确回复,设置quote为false\n"
|
||||
)
|
||||
reply_action_example += (
|
||||
'{{"action":"reply", "think_level":数值等级(0或1), '
|
||||
'"target_message_id":"消息id(m+数字)", '
|
||||
'"unknown_words":["词语1","词语2"]'
|
||||
)
|
||||
if global_config.chat.llm_quote:
|
||||
reply_action_example += ', "quote":"如果需要引用该message,设置为true"'
|
||||
reply_action_example += "}"
|
||||
|
||||
planner_prompt_template = prompt_manager.get_prompt("planner")
|
||||
planner_prompt_template.add_context("time_block", time_block)
|
||||
planner_prompt_template.add_context("chat_context_description", chat_context_description)
|
||||
planner_prompt_template.add_context("chat_content_block", chat_content_block)
|
||||
planner_prompt_template.add_context("actions_before_now_block", actions_before_now_block)
|
||||
planner_prompt_template.add_context("action_options_text", action_options_block)
|
||||
planner_prompt_template.add_context("moderation_prompt", moderation_prompt_block)
|
||||
planner_prompt_template.add_context("name_block", name_block)
|
||||
planner_prompt_template.add_context("interest", interest)
|
||||
planner_prompt_template.add_context("plan_style", global_config.personality.plan_style)
|
||||
planner_prompt_template.add_context("reply_action_example", reply_action_example)
|
||||
prompt = await prompt_manager.render_prompt(planner_prompt_template)
|
||||
|
||||
return prompt, message_id_list
|
||||
except Exception as e:
|
||||
logger.error(f"构建 Planner 提示词时出错: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
return "构建 Planner Prompt 时出错", []
|
||||
|
||||
def get_necessary_info(self) -> Tuple[bool, Optional["TargetPersonInfo"], Dict[str, ActionInfo]]:
|
||||
"""
|
||||
获取 Planner 需要的必要信息
|
||||
"""
|
||||
is_group_chat = True
|
||||
is_group_chat, chat_target_info = get_chat_type_and_target_info(self.chat_id)
|
||||
logger.debug(f"{self.log_prefix}获取到聊天信息 - 群聊: {is_group_chat}, 目标信息: {chat_target_info}")
|
||||
|
||||
current_available_actions_dict = self.action_manager.get_using_actions()
|
||||
|
||||
# 获取完整的动作信息
|
||||
all_registered_actions: Dict[str, ActionInfo] = component_query_service.get_components_by_type( # type: ignore
|
||||
ComponentType.ACTION
|
||||
)
|
||||
current_available_actions = {}
|
||||
for action_name in current_available_actions_dict:
|
||||
if action_name in all_registered_actions:
|
||||
current_available_actions[action_name] = all_registered_actions[action_name]
|
||||
else:
|
||||
logger.warning(f"{self.log_prefix}使用中的动作 {action_name} 未在已注册动作中找到")
|
||||
|
||||
return is_group_chat, chat_target_info, current_available_actions
|
||||
|
||||
def _filter_actions_by_activation_type(
|
||||
self, available_actions: Dict[str, ActionInfo], chat_content_block: str
|
||||
) -> Dict[str, ActionInfo]:
|
||||
"""根据激活类型过滤动作"""
|
||||
filtered_actions = {}
|
||||
|
||||
for action_name, action_info in available_actions.items():
|
||||
if action_info.activation_type == ActionActivationType.NEVER:
|
||||
logger.debug(f"{self.log_prefix}动作 {action_name} 设置为 NEVER 激活类型,跳过")
|
||||
continue
|
||||
elif action_info.activation_type == ActionActivationType.ALWAYS:
|
||||
filtered_actions[action_name] = action_info
|
||||
elif action_info.activation_type == ActionActivationType.RANDOM:
|
||||
if random.random() < action_info.random_activation_probability:
|
||||
filtered_actions[action_name] = action_info
|
||||
elif action_info.activation_type == ActionActivationType.KEYWORD:
|
||||
if action_info.activation_keywords:
|
||||
for keyword in action_info.activation_keywords:
|
||||
if keyword in chat_content_block:
|
||||
filtered_actions[action_name] = action_info
|
||||
break
|
||||
else:
|
||||
logger.warning(f"{self.log_prefix}未知的激活类型: {action_info.activation_type},跳过处理")
|
||||
|
||||
return filtered_actions
|
||||
|
||||
async def _build_action_options_block(self, current_available_actions: Dict[str, ActionInfo]) -> str:
|
||||
"""构建动作选项块"""
|
||||
if not current_available_actions:
|
||||
return ""
|
||||
|
||||
action_options_block = ""
|
||||
for action_name, action_info in current_available_actions.items():
|
||||
# 构建参数文本
|
||||
param_text = ""
|
||||
if action_info.action_parameters:
|
||||
param_text = "\n"
|
||||
for param_name, param_description in action_info.action_parameters.items():
|
||||
param_text += f' "{param_name}":"{param_description}"\n'
|
||||
param_text = param_text.rstrip("\n")
|
||||
|
||||
# 构建要求文本
|
||||
require_text = "\n".join(f"- {require_item}" for require_item in action_info.action_require)
|
||||
|
||||
parallel_text = "" if action_info.parallel_action else "(当选择这个动作时,请不要选择其他动作)"
|
||||
|
||||
# 获取动作提示模板并填充
|
||||
using_action_prompt = prompt_manager.get_prompt("action")
|
||||
using_action_prompt.add_context("action_name", action_name)
|
||||
using_action_prompt.add_context("action_description", action_info.description)
|
||||
using_action_prompt.add_context("action_parameters", param_text)
|
||||
using_action_prompt.add_context("action_require", require_text)
|
||||
using_action_prompt.add_context("parallel_text", parallel_text)
|
||||
using_action_rendered_prompt = await prompt_manager.render_prompt(using_action_prompt)
|
||||
|
||||
action_options_block += using_action_rendered_prompt
|
||||
|
||||
return action_options_block
|
||||
|
||||
async def _execute_main_planner(
|
||||
self,
|
||||
prompt: str,
|
||||
message_id_list: List[Tuple[str, "SessionMessage"]],
|
||||
filtered_actions: Dict[str, ActionInfo],
|
||||
available_actions: Dict[str, ActionInfo],
|
||||
loop_start_time: float,
|
||||
) -> Tuple[str, List[ActionPlannerInfo], Optional[str], Optional[str], Optional[float]]:
|
||||
"""执行主规划器"""
|
||||
llm_content = None
|
||||
actions: List[ActionPlannerInfo] = []
|
||||
llm_reasoning = None
|
||||
llm_duration_ms = None
|
||||
|
||||
try:
|
||||
# 调用LLM
|
||||
llm_start = time.perf_counter()
|
||||
llm_content, (reasoning_content, _, _) = await self.planner_llm.generate_response_async(prompt=prompt)
|
||||
llm_duration_ms = (time.perf_counter() - llm_start) * 1000
|
||||
llm_reasoning = reasoning_content
|
||||
|
||||
if global_config.debug.show_planner_prompt:
|
||||
logger.info(f"{self.log_prefix}规划器原始提示词: {prompt}")
|
||||
logger.info(f"{self.log_prefix}规划器原始响应: {llm_content}")
|
||||
if reasoning_content:
|
||||
logger.info(f"{self.log_prefix}规划器推理: {reasoning_content}")
|
||||
else:
|
||||
logger.debug(f"{self.log_prefix}规划器原始提示词: {prompt}")
|
||||
logger.debug(f"{self.log_prefix}规划器原始响应: {llm_content}")
|
||||
if reasoning_content:
|
||||
logger.debug(f"{self.log_prefix}规划器推理: {reasoning_content}")
|
||||
|
||||
except Exception as req_e:
|
||||
logger.error(f"{self.log_prefix}LLM 请求执行失败: {req_e}")
|
||||
return (
|
||||
f"LLM 请求失败,模型出现问题: {req_e}",
|
||||
[
|
||||
ActionPlannerInfo(
|
||||
action_type="no_reply",
|
||||
reasoning=f"LLM 请求失败,模型出现问题: {req_e}",
|
||||
action_data={},
|
||||
action_message=None,
|
||||
available_actions=available_actions,
|
||||
)
|
||||
],
|
||||
llm_content,
|
||||
llm_reasoning,
|
||||
llm_duration_ms,
|
||||
)
|
||||
|
||||
# 解析LLM响应
|
||||
extracted_reasoning = ""
|
||||
if llm_content:
|
||||
try:
|
||||
json_objects, extracted_reasoning = self._extract_json_from_markdown(llm_content)
|
||||
extracted_reasoning = self._replace_message_ids_with_text(extracted_reasoning, message_id_list) or ""
|
||||
if json_objects:
|
||||
logger.debug(f"{self.log_prefix}从响应中提取到{len(json_objects)}个JSON对象")
|
||||
filtered_actions_list = list(filtered_actions.items())
|
||||
for json_obj in json_objects:
|
||||
actions.extend(
|
||||
self._parse_single_action(
|
||||
json_obj, message_id_list, filtered_actions_list, extracted_reasoning
|
||||
)
|
||||
)
|
||||
else:
|
||||
# 尝试解析为直接的JSON
|
||||
logger.warning(f"{self.log_prefix}LLM没有返回可用动作: {llm_content}")
|
||||
extracted_reasoning = "LLM没有返回可用动作"
|
||||
actions = self._create_no_reply("LLM没有返回可用动作", available_actions)
|
||||
|
||||
except Exception as json_e:
|
||||
logger.warning(f"{self.log_prefix}解析LLM响应JSON失败 {json_e}. LLM原始输出: '{llm_content}'")
|
||||
extracted_reasoning = f"解析LLM响应JSON失败: {json_e}"
|
||||
actions = self._create_no_reply(f"解析LLM响应JSON失败: {json_e}", available_actions)
|
||||
traceback.print_exc()
|
||||
else:
|
||||
extracted_reasoning = "规划器没有获得LLM响应"
|
||||
actions = self._create_no_reply("规划器没有获得LLM响应", available_actions)
|
||||
|
||||
# 添加循环开始时间到所有非no_reply动作
|
||||
for action in actions:
|
||||
action.action_data = action.action_data or {}
|
||||
action.action_data["loop_start_time"] = loop_start_time
|
||||
|
||||
# 去重:如果同一个动作被选择了多次,随机选择其中一个
|
||||
if actions:
|
||||
shuffled = actions.copy()
|
||||
random.shuffle(shuffled)
|
||||
actions = list({a.action_type: a for a in shuffled}.values())
|
||||
|
||||
# 处理黑话缓存逻辑
|
||||
self._process_unknown_words_cache(actions)
|
||||
|
||||
logger.debug(f"{self.log_prefix}规划器选择了{len(actions)}个动作: {' '.join([a.action_type for a in actions])}")
|
||||
|
||||
return extracted_reasoning, actions, llm_content, llm_reasoning, llm_duration_ms
|
||||
|
||||
def _create_no_reply(self, reasoning: str, available_actions: Dict[str, ActionInfo]) -> List[ActionPlannerInfo]:
|
||||
"""创建no_reply"""
|
||||
return [
|
||||
ActionPlannerInfo(
|
||||
action_type="no_reply",
|
||||
reasoning=reasoning,
|
||||
action_data={},
|
||||
action_message=None,
|
||||
available_actions=available_actions,
|
||||
)
|
||||
]
|
||||
|
||||
def _extract_json_from_markdown(self, content: str) -> Tuple[List[dict], str]:
|
||||
# sourcery skip: for-append-to-extend
|
||||
"""从Markdown格式的内容中提取JSON对象和推理内容"""
|
||||
json_objects = []
|
||||
reasoning_content = ""
|
||||
|
||||
# 使用正则表达式查找```json包裹的JSON内容
|
||||
json_pattern = r"```json\s*(.*?)\s*```"
|
||||
markdown_matches = re.findall(json_pattern, content, re.DOTALL)
|
||||
|
||||
# 提取JSON之前的内容作为推理文本
|
||||
first_json_pos = len(content)
|
||||
if markdown_matches:
|
||||
# 找到第一个```json的位置
|
||||
first_json_pos = content.find("```json")
|
||||
if first_json_pos > 0:
|
||||
reasoning_content = content[:first_json_pos].strip()
|
||||
# 清理推理内容中的注释标记
|
||||
reasoning_content = re.sub(r"^//\s*", "", reasoning_content, flags=re.MULTILINE)
|
||||
reasoning_content = reasoning_content.strip()
|
||||
|
||||
# 处理```json包裹的JSON
|
||||
for match in markdown_matches:
|
||||
try:
|
||||
# 清理可能的注释和格式问题
|
||||
json_str = re.sub(r"//.*?\n", "\n", match) # 移除单行注释
|
||||
json_str = re.sub(r"/\*.*?\*/", "", json_str, flags=re.DOTALL) # 移除多行注释
|
||||
if json_str := json_str.strip():
|
||||
# 尝试按行分割,每行可能是一个JSON对象
|
||||
lines = [line.strip() for line in json_str.split("\n") if line.strip()]
|
||||
for line in lines:
|
||||
with contextlib.suppress(json.JSONDecodeError):
|
||||
json_obj = json.loads(repair_json(line))
|
||||
if isinstance(json_obj, dict):
|
||||
if json_obj:
|
||||
json_objects.append(json_obj)
|
||||
elif isinstance(json_obj, list):
|
||||
for item in json_obj:
|
||||
if isinstance(item, dict) and item:
|
||||
json_objects.append(item)
|
||||
|
||||
# 如果按行解析没有成功(或只得到空字典),尝试将整个块作为一个JSON对象或数组
|
||||
if not json_objects:
|
||||
json_obj = json.loads(repair_json(json_str))
|
||||
if isinstance(json_obj, dict):
|
||||
# 过滤掉空字典
|
||||
if json_obj:
|
||||
json_objects.append(json_obj)
|
||||
elif isinstance(json_obj, list):
|
||||
for item in json_obj:
|
||||
if isinstance(item, dict) and item:
|
||||
json_objects.append(item)
|
||||
except Exception as e:
|
||||
logger.warning(f"解析JSON块失败: {e}, 块内容: {match[:100]}...")
|
||||
continue
|
||||
|
||||
# 如果没有找到完整的```json```块,尝试查找不完整的代码块(缺少结尾```)
|
||||
if not json_objects:
|
||||
json_start_pos = content.find("```json")
|
||||
if json_start_pos != -1:
|
||||
# 找到```json之后的内容
|
||||
json_content_start = json_start_pos + 7 # ```json的长度
|
||||
# 提取从```json之后到内容结尾的所有内容
|
||||
incomplete_json_str = content[json_content_start:].strip()
|
||||
|
||||
# 提取JSON之前的内容作为推理文本
|
||||
if json_start_pos > 0:
|
||||
reasoning_content = content[:json_start_pos].strip()
|
||||
reasoning_content = re.sub(r"^//\s*", "", reasoning_content, flags=re.MULTILINE)
|
||||
reasoning_content = reasoning_content.strip()
|
||||
|
||||
if incomplete_json_str:
|
||||
try:
|
||||
# 清理可能的注释和格式问题
|
||||
json_str = re.sub(r"//.*?\n", "\n", incomplete_json_str)
|
||||
json_str = re.sub(r"/\*.*?\*/", "", json_str, flags=re.DOTALL)
|
||||
json_str = json_str.strip()
|
||||
|
||||
if json_str:
|
||||
# 尝试按行分割,每行可能是一个JSON对象
|
||||
lines = [line.strip() for line in json_str.split("\n") if line.strip()]
|
||||
for line in lines:
|
||||
try:
|
||||
json_obj = json.loads(repair_json(line))
|
||||
if isinstance(json_obj, dict):
|
||||
# 过滤掉空字典,避免单个 { 字符被错误修复为 {} 的情况
|
||||
if json_obj:
|
||||
json_objects.append(json_obj)
|
||||
elif isinstance(json_obj, list):
|
||||
for item in json_obj:
|
||||
if isinstance(item, dict) and item:
|
||||
json_objects.append(item)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# 如果按行解析没有成功(或只得到空字典),尝试将整个块作为一个JSON对象或数组
|
||||
if not json_objects:
|
||||
try:
|
||||
json_obj = json.loads(repair_json(json_str))
|
||||
if isinstance(json_obj, dict):
|
||||
# 过滤掉空字典
|
||||
if json_obj:
|
||||
json_objects.append(json_obj)
|
||||
elif isinstance(json_obj, list):
|
||||
for item in json_obj:
|
||||
if isinstance(item, dict) and item:
|
||||
json_objects.append(item)
|
||||
except Exception as e:
|
||||
logger.debug(f"尝试解析不完整的JSON代码块失败: {e}")
|
||||
except Exception as e:
|
||||
logger.debug(f"处理不完整的JSON代码块时出错: {e}")
|
||||
|
||||
return json_objects, reasoning_content
|
||||
@@ -10,8 +10,8 @@ from datetime import datetime
|
||||
from src.common.logger import get_logger
|
||||
from src.common.data_models.info_data_model import ActionPlannerInfo
|
||||
from src.common.data_models.llm_data_model import LLMGenerationDataModel
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
from maim_message import BaseMessageInfo, MessageBase, Seg, UserInfo as MaimUserInfo
|
||||
|
||||
from src.common.data_models.mai_message_data_model import MaiMessage
|
||||
@@ -56,14 +56,12 @@ class DefaultReplyer:
|
||||
chat_stream: 当前绑定的聊天会话。
|
||||
request_type: LLM 请求类型标识。
|
||||
"""
|
||||
self.express_model = LLMRequest(model_set=model_config.model_task_config.replyer, request_type=request_type)
|
||||
self.express_model = LLMServiceClient(
|
||||
task_name="replyer", request_type=request_type
|
||||
)
|
||||
self.chat_stream = chat_stream
|
||||
self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_stream.session_id)
|
||||
|
||||
from src.chat.tool_executor import ToolExecutor
|
||||
|
||||
self.tool_executor = ToolExecutor(chat_id=self.chat_stream.session_id, enable_cache=True, cache_ttl=3)
|
||||
|
||||
async def generate_reply_with_context(
|
||||
self,
|
||||
extra_info: str = "",
|
||||
@@ -397,6 +395,11 @@ class DefaultReplyer:
|
||||
return f"{expression_habits_title}\n{expression_habits_block}", selected_ids
|
||||
|
||||
async def build_tool_info(self, chat_history: str, sender: str, target: str, enable_tool: bool = True) -> str:
|
||||
del chat_history
|
||||
del sender
|
||||
del target
|
||||
del enable_tool
|
||||
return ""
|
||||
"""构建工具信息块
|
||||
|
||||
Args:
|
||||
@@ -413,9 +416,7 @@ class DefaultReplyer:
|
||||
|
||||
try:
|
||||
# 使用工具执行器获取信息
|
||||
tool_results, _, _ = await self.tool_executor.execute_from_chat_message(
|
||||
sender=sender, target_message=target, chat_history=chat_history, return_details=False
|
||||
)
|
||||
tool_results = []
|
||||
|
||||
if tool_results:
|
||||
tool_info_str = "以下是你通过工具获取到的实时信息:\n"
|
||||
@@ -576,29 +577,6 @@ class DefaultReplyer:
|
||||
duration = end_time - start_time
|
||||
return name, result, duration
|
||||
|
||||
async def _build_disabled_jargon_explanation(self) -> str:
|
||||
"""当关闭黑话解释时使用的占位协程,避免额外的LLM调用"""
|
||||
return ""
|
||||
|
||||
async def _build_unknown_words_jargon(self, unknown_words: Optional[List[str]], chat_id: str) -> str:
|
||||
"""针对 Planner 提供的未知词语列表执行黑话检索"""
|
||||
if not unknown_words:
|
||||
return ""
|
||||
# 清洗未知词语列表,只保留非空字符串
|
||||
concepts: List[str] = []
|
||||
for item in unknown_words:
|
||||
if isinstance(item, str):
|
||||
s = item.strip()
|
||||
if s:
|
||||
concepts.append(s)
|
||||
if not concepts:
|
||||
return ""
|
||||
try:
|
||||
return await retrieve_concepts_with_jargon(concepts, chat_id)
|
||||
except Exception as e:
|
||||
logger.error(f"未知词语黑话检索失败: {e}")
|
||||
return ""
|
||||
|
||||
async def _build_jargon_explanation(
|
||||
self,
|
||||
chat_id: str,
|
||||
@@ -608,19 +586,14 @@ class DefaultReplyer:
|
||||
) -> str:
|
||||
"""
|
||||
统一的黑话解释构建函数:
|
||||
- 根据 enable_jargon_explanation / jargon_mode 决定具体策略
|
||||
- 根据 enable_jargon_explanation 决定是否启用
|
||||
"""
|
||||
del unknown_words
|
||||
enable_jargon_explanation = getattr(global_config.expression, "enable_jargon_explanation", True)
|
||||
if not enable_jargon_explanation:
|
||||
return ""
|
||||
|
||||
jargon_mode = getattr(global_config.expression, "jargon_mode", "context")
|
||||
|
||||
# planner 模式:仅使用 Planner 的 unknown_words
|
||||
if jargon_mode == "planner":
|
||||
return await self._build_unknown_words_jargon(unknown_words, chat_id)
|
||||
|
||||
# 默认 / context 模式:使用上下文自动匹配黑话
|
||||
# 使用上下文自动匹配黑话
|
||||
try:
|
||||
return await explain_jargon_in_context(chat_id, messages_short, chat_talking_prompt_short) or ""
|
||||
except Exception as e:
|
||||
@@ -1158,9 +1131,11 @@ class DefaultReplyer:
|
||||
# else:
|
||||
# logger.debug(f"\nreplyer_Prompt:{prompt}\n")
|
||||
|
||||
content, (reasoning_content, model_name, tool_calls) = await self.express_model.generate_response_async(
|
||||
prompt
|
||||
)
|
||||
generation_result = await self.express_model.generate_response(prompt)
|
||||
content = generation_result.response
|
||||
reasoning_content = generation_result.reasoning
|
||||
model_name = generation_result.model_name
|
||||
tool_calls = generation_result.tool_calls
|
||||
|
||||
# 移除 content 前后的换行符和空格
|
||||
content = content.strip()
|
||||
@@ -1169,6 +1144,10 @@ class DefaultReplyer:
|
||||
return content, reasoning_content, model_name, tool_calls
|
||||
|
||||
async def get_prompt_info(self, message: str, sender: str, target: str):
|
||||
del message
|
||||
del sender
|
||||
del target
|
||||
return ""
|
||||
related_info = ""
|
||||
start_time = time.time()
|
||||
try:
|
||||
@@ -1200,17 +1179,21 @@ class DefaultReplyer:
|
||||
template_prompt.add_context("sender", sender)
|
||||
template_prompt.add_context("target_message", target)
|
||||
prompt = await prompt_manager.render_prompt(template_prompt)
|
||||
_, _, _, _, tool_calls = await llm_api.generate_with_model_with_tools(
|
||||
prompt,
|
||||
model_config=model_config.model_task_config.tool_use,
|
||||
tool_options=[search_knowledge_tool.get_tool_definition()],
|
||||
generation_result = await llm_api.generate(
|
||||
llm_api.LLMServiceRequest(
|
||||
task_name="utils",
|
||||
request_type="replyer.lpmm_knowledge",
|
||||
prompt=prompt,
|
||||
tool_options=[search_knowledge_tool.get_tool_definition()],
|
||||
)
|
||||
)
|
||||
tool_calls = generation_result.completion.tool_calls
|
||||
|
||||
# logger.info(f"工具调用提示词: {prompt}")
|
||||
# logger.info(f"工具调用: {tool_calls}")
|
||||
|
||||
if tool_calls:
|
||||
result = await self.tool_executor.execute_tool_call(tool_calls[0])
|
||||
result = None
|
||||
end_time = time.time()
|
||||
if not result or not result.get("content"):
|
||||
logger.debug("从LPMM知识库获取知识失败,返回空知识...")
|
||||
|
||||
419
src/chat/replyer/maisaka_generator.py
Normal file
419
src/chat/replyer/maisaka_generator.py
Normal file
@@ -0,0 +1,419 @@
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import random
|
||||
import time
|
||||
|
||||
from sqlmodel import select
|
||||
|
||||
from src.chat.message_receive.chat_manager import BotChatSession
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import Expression
|
||||
from src.common.data_models.reply_generation_data_models import (
|
||||
GenerationMetrics,
|
||||
LLMCompletionResult,
|
||||
ReplyGenerationResult,
|
||||
)
|
||||
from src.common.logger import get_logger
|
||||
from src.common.prompt_i18n import load_prompt
|
||||
from src.config.config import global_config
|
||||
from src.core.types import ActionInfo
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
from src.maisaka.context_messages import AssistantMessage, LLMContextMessage, ReferenceMessage, SessionBackedMessage, ToolResultMessage
|
||||
from src.maisaka.message_adapter import parse_speaker_content
|
||||
|
||||
logger = get_logger("replyer")
|
||||
|
||||
|
||||
@dataclass
|
||||
class MaisakaReplyContext:
|
||||
"""Maisaka replyer 使用的回复上下文。"""
|
||||
|
||||
expression_habits: str = ""
|
||||
selected_expression_ids: List[int] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _ExpressionRecord:
|
||||
"""表达方式的轻量记录。"""
|
||||
|
||||
expression_id: Optional[int]
|
||||
situation: str
|
||||
style: str
|
||||
|
||||
|
||||
class MaisakaReplyGenerator:
|
||||
"""生成 Maisaka 的最终可见回复。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
chat_stream: Optional[BotChatSession] = None,
|
||||
request_type: str = "maisaka_replyer",
|
||||
) -> None:
|
||||
self.chat_stream = chat_stream
|
||||
self.request_type = request_type
|
||||
self.express_model = LLMServiceClient(
|
||||
task_name="replyer",
|
||||
request_type=request_type,
|
||||
)
|
||||
self._personality_prompt = self._build_personality_prompt()
|
||||
|
||||
def _build_personality_prompt(self) -> str:
|
||||
"""构建 replyer 使用的人设描述。"""
|
||||
try:
|
||||
bot_name = global_config.bot.nickname
|
||||
alias_names = global_config.bot.alias_names
|
||||
bot_aliases = f",也有人叫你{','.join(alias_names)}" if alias_names else ""
|
||||
|
||||
prompt_personality = global_config.personality.personality
|
||||
if (
|
||||
hasattr(global_config.personality, "states")
|
||||
and global_config.personality.states
|
||||
and hasattr(global_config.personality, "state_probability")
|
||||
and global_config.personality.state_probability > 0
|
||||
and random.random() < global_config.personality.state_probability
|
||||
):
|
||||
prompt_personality = random.choice(global_config.personality.states)
|
||||
|
||||
return f"你的名字是{bot_name}{bot_aliases},你{prompt_personality};"
|
||||
except Exception as exc:
|
||||
logger.warning(f"构建 Maisaka 人设提示词失败: {exc}")
|
||||
return "你的名字是麦麦,你是一个活泼可爱的 AI 助手。"
|
||||
|
||||
@staticmethod
|
||||
def _normalize_content(content: str, limit: int = 500) -> str:
|
||||
normalized = " ".join((content or "").split())
|
||||
if len(normalized) > limit:
|
||||
return normalized[:limit] + "..."
|
||||
return normalized
|
||||
|
||||
@staticmethod
|
||||
def _format_message_time(message: LLMContextMessage) -> str:
|
||||
return message.timestamp.strftime("%H:%M:%S")
|
||||
|
||||
@staticmethod
|
||||
def _extract_visible_assistant_reply(message: AssistantMessage) -> str:
|
||||
del message
|
||||
return ""
|
||||
|
||||
def _extract_guided_bot_reply(self, message: SessionBackedMessage) -> str:
|
||||
speaker_name, body = parse_speaker_content(message.processed_plain_text.strip())
|
||||
bot_nickname = global_config.bot.nickname.strip() or "Bot"
|
||||
if speaker_name == bot_nickname:
|
||||
return self._normalize_content(body.strip())
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def _split_user_message_segments(raw_content: str) -> List[tuple[Optional[str], str]]:
|
||||
"""按说话人拆分用户消息。"""
|
||||
segments: List[tuple[Optional[str], str]] = []
|
||||
current_speaker: Optional[str] = None
|
||||
current_lines: List[str] = []
|
||||
|
||||
for raw_line in raw_content.splitlines():
|
||||
speaker_name, content_body = parse_speaker_content(raw_line)
|
||||
if speaker_name is not None:
|
||||
if current_lines:
|
||||
segments.append((current_speaker, "\n".join(current_lines)))
|
||||
current_speaker = speaker_name
|
||||
current_lines = [content_body]
|
||||
continue
|
||||
|
||||
current_lines.append(raw_line)
|
||||
|
||||
if current_lines:
|
||||
segments.append((current_speaker, "\n".join(current_lines)))
|
||||
|
||||
return segments
|
||||
|
||||
def _format_chat_history(self, messages: List[LLMContextMessage]) -> str:
|
||||
"""格式化 replyer 使用的可见聊天记录。"""
|
||||
bot_nickname = global_config.bot.nickname.strip() or "Bot"
|
||||
parts: List[str] = []
|
||||
|
||||
for message in messages:
|
||||
timestamp = self._format_message_time(message)
|
||||
|
||||
if isinstance(message, (ReferenceMessage, ToolResultMessage)):
|
||||
continue
|
||||
|
||||
if isinstance(message, SessionBackedMessage):
|
||||
guided_reply = self._extract_guided_bot_reply(message)
|
||||
if guided_reply:
|
||||
parts.append(f"{timestamp} {bot_nickname}(you): {guided_reply}")
|
||||
continue
|
||||
|
||||
raw_content = message.processed_plain_text
|
||||
for speaker_name, content_body in self._split_user_message_segments(raw_content):
|
||||
content = self._normalize_content(content_body)
|
||||
if not content:
|
||||
continue
|
||||
visible_speaker = speaker_name or global_config.maisaka.user_name.strip() or "User"
|
||||
parts.append(f"{timestamp} {visible_speaker}: {content}")
|
||||
continue
|
||||
|
||||
if isinstance(message, AssistantMessage):
|
||||
visible_reply = self._extract_visible_assistant_reply(message)
|
||||
if visible_reply:
|
||||
parts.append(f"{timestamp} {bot_nickname}(you): {visible_reply}")
|
||||
|
||||
return "\n".join(parts)
|
||||
|
||||
def _build_prompt(
|
||||
self,
|
||||
chat_history: List[LLMContextMessage],
|
||||
reply_reason: str,
|
||||
expression_habits: str = "",
|
||||
) -> str:
|
||||
"""构建 Maisaka replyer 提示词。"""
|
||||
current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
formatted_history = self._format_chat_history(chat_history)
|
||||
|
||||
try:
|
||||
system_prompt = load_prompt(
|
||||
"maidairy_replyer",
|
||||
bot_name=global_config.bot.nickname,
|
||||
time_block=f"当前时间:{current_time}",
|
||||
identity=self._personality_prompt,
|
||||
reply_style=global_config.personality.reply_style,
|
||||
)
|
||||
except Exception:
|
||||
system_prompt = "你是一个友好的 AI 助手,请根据聊天记录自然回复。"
|
||||
|
||||
extra_sections: List[str] = []
|
||||
if expression_habits.strip():
|
||||
extra_sections.append(expression_habits.strip())
|
||||
|
||||
user_sections = [
|
||||
f"当前时间:{current_time}",
|
||||
f"【聊天记录】\n{formatted_history}",
|
||||
]
|
||||
if extra_sections:
|
||||
user_sections.append("\n\n".join(extra_sections))
|
||||
user_sections.append(f"【你的想法】\n{reply_reason}")
|
||||
user_sections.append("现在,你说:")
|
||||
|
||||
user_prompt = "\n\n".join(user_sections)
|
||||
return f"System: {system_prompt}\n\nUser: {user_prompt}"
|
||||
|
||||
def _resolve_session_id(self, stream_id: Optional[str]) -> str:
|
||||
"""解析当前回复使用的会话 ID。"""
|
||||
if stream_id:
|
||||
return stream_id
|
||||
if self.chat_stream is not None:
|
||||
return self.chat_stream.session_id
|
||||
return ""
|
||||
|
||||
async def _build_reply_context(
|
||||
self,
|
||||
chat_history: List[LLMContextMessage],
|
||||
reply_message: Optional[SessionMessage],
|
||||
reply_reason: str,
|
||||
stream_id: Optional[str],
|
||||
) -> MaisakaReplyContext:
|
||||
"""在 replyer 内部构建表达习惯和黑话解释。"""
|
||||
session_id = self._resolve_session_id(stream_id)
|
||||
if not session_id:
|
||||
logger.warning("构建 Maisaka 回复上下文失败:缺少会话标识")
|
||||
return MaisakaReplyContext()
|
||||
|
||||
expression_habits, selected_expression_ids = self._build_expression_habits(
|
||||
session_id=session_id,
|
||||
chat_history=chat_history,
|
||||
reply_message=reply_message,
|
||||
reply_reason=reply_reason,
|
||||
)
|
||||
return MaisakaReplyContext(
|
||||
expression_habits=expression_habits,
|
||||
selected_expression_ids=selected_expression_ids,
|
||||
)
|
||||
|
||||
def _build_expression_habits(
|
||||
self,
|
||||
session_id: str,
|
||||
chat_history: List[LLMContextMessage],
|
||||
reply_message: Optional[SessionMessage],
|
||||
reply_reason: str,
|
||||
) -> tuple[str, List[int]]:
|
||||
"""查询并格式化适合当前会话的表达习惯。"""
|
||||
del chat_history
|
||||
del reply_message
|
||||
del reply_reason
|
||||
|
||||
expression_records = self._load_expression_records(session_id)
|
||||
if not expression_records:
|
||||
return "", []
|
||||
|
||||
lines: List[str] = []
|
||||
selected_ids: List[int] = []
|
||||
for expression in expression_records:
|
||||
if expression.expression_id is not None:
|
||||
selected_ids.append(expression.expression_id)
|
||||
lines.append(f"- 当{expression.situation}时,可以自然地用{expression.style}这种表达习惯。")
|
||||
|
||||
block = "【表达习惯参考】\n" + "\n".join(lines)
|
||||
logger.info(
|
||||
f"已构建 Maisaka 表达习惯: 会话标识={session_id} "
|
||||
f"数量={len(selected_ids)} 表达编号={selected_ids!r}"
|
||||
)
|
||||
return block, selected_ids
|
||||
|
||||
def _load_expression_records(self, session_id: str) -> List[_ExpressionRecord]:
|
||||
"""提取表达方式静态数据,避免 detached ORM 对象。"""
|
||||
with get_db_session(auto_commit=False) as session:
|
||||
query = select(Expression).where(Expression.rejected.is_(False)) # type: ignore[attr-defined]
|
||||
if global_config.expression.expression_checked_only:
|
||||
query = query.where(Expression.checked.is_(True)) # type: ignore[attr-defined]
|
||||
|
||||
query = query.where(
|
||||
(Expression.session_id == session_id) | (Expression.session_id.is_(None)) # type: ignore[attr-defined]
|
||||
).order_by(Expression.count.desc(), Expression.last_active_time.desc()) # type: ignore[attr-defined]
|
||||
|
||||
expressions = session.exec(query.limit(5)).all()
|
||||
return [
|
||||
_ExpressionRecord(
|
||||
expression_id=expression.id,
|
||||
situation=expression.situation,
|
||||
style=expression.style,
|
||||
)
|
||||
for expression in expressions
|
||||
]
|
||||
|
||||
async def generate_reply_with_context(
|
||||
self,
|
||||
extra_info: str = "",
|
||||
reply_reason: str = "",
|
||||
available_actions: Optional[Dict[str, ActionInfo]] = None,
|
||||
chosen_actions: Optional[List[object]] = None,
|
||||
enable_tool: bool = True,
|
||||
from_plugin: bool = True,
|
||||
stream_id: Optional[str] = None,
|
||||
reply_message: Optional[SessionMessage] = None,
|
||||
reply_time_point: Optional[float] = None,
|
||||
think_level: int = 1,
|
||||
unknown_words: Optional[List[str]] = None,
|
||||
log_reply: bool = True,
|
||||
chat_history: Optional[List[LLMContextMessage]] = None,
|
||||
expression_habits: str = "",
|
||||
selected_expression_ids: Optional[List[int]] = None,
|
||||
) -> Tuple[bool, ReplyGenerationResult]:
|
||||
"""结合上下文生成 Maisaka 的最终可见回复。"""
|
||||
del available_actions
|
||||
del chosen_actions
|
||||
del enable_tool
|
||||
del extra_info
|
||||
del from_plugin
|
||||
del log_reply
|
||||
del reply_time_point
|
||||
del think_level
|
||||
del unknown_words
|
||||
|
||||
result = ReplyGenerationResult()
|
||||
if chat_history is None:
|
||||
result.error_message = "聊天历史为空"
|
||||
return False, result
|
||||
|
||||
logger.info(
|
||||
f"Maisaka 回复器开始生成: 会话流标识={stream_id} 回复原因={reply_reason!r} "
|
||||
f"历史消息数={len(chat_history)} 目标消息编号="
|
||||
f"{reply_message.message_id if reply_message else None}"
|
||||
)
|
||||
|
||||
filtered_history = [
|
||||
message
|
||||
for message in chat_history
|
||||
if not isinstance(message, (ReferenceMessage, ToolResultMessage))
|
||||
]
|
||||
|
||||
logger.debug(f"Maisaka 回复器过滤后历史消息数={len(filtered_history)}")
|
||||
|
||||
# Validate that express_model is properly initialized
|
||||
if self.express_model is None:
|
||||
logger.error("Maisaka 回复器的回复模型未初始化")
|
||||
result.error_message = "回复模型尚未初始化"
|
||||
return False, result
|
||||
|
||||
try:
|
||||
reply_context = await self._build_reply_context(
|
||||
chat_history=filtered_history,
|
||||
reply_message=reply_message,
|
||||
reply_reason=reply_reason or "",
|
||||
stream_id=stream_id,
|
||||
)
|
||||
except Exception as exc:
|
||||
import traceback
|
||||
logger.error(f"Maisaka 回复器构建回复上下文失败: {exc}\n{traceback.format_exc()}")
|
||||
result.error_message = f"构建回复上下文失败: {exc}"
|
||||
return False, result
|
||||
|
||||
merged_expression_habits = expression_habits.strip() or reply_context.expression_habits
|
||||
result.selected_expression_ids = (
|
||||
list(selected_expression_ids)
|
||||
if selected_expression_ids is not None
|
||||
else list(reply_context.selected_expression_ids)
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Maisaka 回复上下文构建完成: 会话流标识={stream_id} "
|
||||
f"已选表达编号={result.selected_expression_ids!r}"
|
||||
)
|
||||
|
||||
try:
|
||||
prompt = self._build_prompt(
|
||||
chat_history=filtered_history,
|
||||
reply_reason=reply_reason or "",
|
||||
expression_habits=merged_expression_habits,
|
||||
)
|
||||
except Exception as exc:
|
||||
import traceback
|
||||
logger.error(f"Maisaka 回复器构建提示词失败: {exc}\n{traceback.format_exc()}")
|
||||
result.error_message = f"构建提示词失败: {exc}"
|
||||
return False, result
|
||||
|
||||
result.completion.request_prompt = prompt
|
||||
|
||||
if global_config.debug.show_replyer_prompt:
|
||||
logger.info(f"\nMaisaka 回复器提示词:\n{prompt}\n")
|
||||
|
||||
started_at = time.perf_counter()
|
||||
try:
|
||||
generation_result = await self.express_model.generate_response(prompt)
|
||||
except Exception as exc:
|
||||
logger.exception("Maisaka 回复器调用失败")
|
||||
result.error_message = str(exc)
|
||||
result.metrics = GenerationMetrics(
|
||||
overall_ms=round((time.perf_counter() - started_at) * 1000, 2),
|
||||
)
|
||||
return False, result
|
||||
|
||||
response_text = (generation_result.response or "").strip()
|
||||
result.success = bool(response_text)
|
||||
result.completion = LLMCompletionResult(
|
||||
request_prompt=prompt,
|
||||
response_text=response_text,
|
||||
reasoning_text=generation_result.reasoning or "",
|
||||
model_name=generation_result.model_name or "",
|
||||
tool_calls=generation_result.tool_calls or [],
|
||||
)
|
||||
result.metrics = GenerationMetrics(
|
||||
overall_ms=round((time.perf_counter() - started_at) * 1000, 2),
|
||||
)
|
||||
|
||||
if global_config.debug.show_replyer_reasoning and result.completion.reasoning_text:
|
||||
logger.info(f"Maisaka 回复器思考内容:\n{result.completion.reasoning_text}")
|
||||
|
||||
if not result.success:
|
||||
result.error_message = "回复器返回了空内容"
|
||||
logger.warning("Maisaka 回复器返回了空内容")
|
||||
return False, result
|
||||
|
||||
logger.info(
|
||||
f"Maisaka 回复器生成成功: 回复文本={response_text!r} "
|
||||
f"总耗时毫秒={result.metrics.overall_ms} "
|
||||
f"已选表达编号={result.selected_expression_ids!r}"
|
||||
)
|
||||
result.text_fragments = [response_text]
|
||||
return True, result
|
||||
@@ -9,8 +9,8 @@ from datetime import datetime
|
||||
from src.common.logger import get_logger
|
||||
from src.common.data_models.info_data_model import ActionPlannerInfo
|
||||
from src.common.data_models.llm_data_model import LLMGenerationDataModel
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
from maim_message import BaseMessageInfo, MessageBase, Seg, UserInfo as MaimUserInfo
|
||||
|
||||
from src.common.data_models.mai_message_data_model import MaiMessage
|
||||
@@ -52,15 +52,13 @@ class PrivateReplyer:
|
||||
chat_stream: 当前绑定的聊天会话。
|
||||
request_type: LLM 请求类型标识。
|
||||
"""
|
||||
self.express_model = LLMRequest(model_set=model_config.model_task_config.replyer, request_type=request_type)
|
||||
self.express_model = LLMServiceClient(
|
||||
task_name="replyer", request_type=request_type
|
||||
)
|
||||
self.chat_stream = chat_stream
|
||||
self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_stream.session_id)
|
||||
# self.memory_activator = MemoryActivator()
|
||||
|
||||
from src.chat.tool_executor import ToolExecutor
|
||||
|
||||
self.tool_executor = ToolExecutor(chat_id=self.chat_stream.session_id, enable_cache=True, cache_ttl=3)
|
||||
|
||||
async def generate_reply_with_context(
|
||||
self,
|
||||
extra_info: str = "",
|
||||
@@ -290,6 +288,11 @@ class PrivateReplyer:
|
||||
return f"{expression_habits_title}\n{expression_habits_block}", selected_ids
|
||||
|
||||
async def build_tool_info(self, chat_history: str, sender: str, target: str, enable_tool: bool = True) -> str:
|
||||
del chat_history
|
||||
del sender
|
||||
del target
|
||||
del enable_tool
|
||||
return ""
|
||||
"""构建工具信息块
|
||||
|
||||
Args:
|
||||
@@ -306,9 +309,7 @@ class PrivateReplyer:
|
||||
|
||||
try:
|
||||
# 使用工具执行器获取信息
|
||||
tool_results, _, _ = await self.tool_executor.execute_from_chat_message(
|
||||
sender=sender, target_message=target, chat_history=chat_history, return_details=False
|
||||
)
|
||||
tool_results = []
|
||||
|
||||
if tool_results:
|
||||
tool_info_str = "以下是你通过工具获取到的实时信息:\n"
|
||||
@@ -997,9 +998,11 @@ class PrivateReplyer:
|
||||
else:
|
||||
logger.debug(f"\n{prompt}\n")
|
||||
|
||||
content, (reasoning_content, model_name, tool_calls) = await self.express_model.generate_response_async(
|
||||
prompt
|
||||
)
|
||||
generation_result = await self.express_model.generate_response(prompt)
|
||||
content = generation_result.response
|
||||
reasoning_content = generation_result.reasoning
|
||||
model_name = generation_result.model_name
|
||||
tool_calls = generation_result.tool_calls
|
||||
|
||||
content = content.strip()
|
||||
|
||||
|
||||
@@ -1,65 +1,82 @@
|
||||
from typing import Dict, Optional
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.message_receive.chat_manager import BotChatSession, chat_manager as _chat_manager
|
||||
from src.chat.replyer.group_generator import DefaultReplyer
|
||||
from src.chat.replyer.private_generator import PrivateReplyer
|
||||
from src.common.logger import get_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.chat.replyer.group_generator import DefaultReplyer
|
||||
from src.chat.replyer.maisaka_generator import MaisakaReplyGenerator
|
||||
from src.chat.replyer.private_generator import PrivateReplyer
|
||||
|
||||
logger = get_logger("ReplyerManager")
|
||||
|
||||
|
||||
class ReplyerManager:
|
||||
def __init__(self):
|
||||
self._repliers: Dict[str, DefaultReplyer | PrivateReplyer] = {}
|
||||
"""统一管理不同类型的回复生成器。"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._repliers: Dict[str, Any] = {}
|
||||
|
||||
def get_replyer(
|
||||
self,
|
||||
chat_stream: Optional[BotChatSession] = None,
|
||||
chat_id: Optional[str] = None,
|
||||
request_type: str = "replyer",
|
||||
) -> Optional[DefaultReplyer | PrivateReplyer]:
|
||||
"""
|
||||
获取或创建回复器实例。
|
||||
|
||||
model_configs 仅在首次为某个 chat_id/stream_id 创建实例时有效。
|
||||
后续调用将返回已缓存的实例,忽略 model_configs 参数。
|
||||
"""
|
||||
replyer_type: str = "default",
|
||||
) -> Optional["DefaultReplyer | MaisakaReplyGenerator | PrivateReplyer"]:
|
||||
"""按会话和 replyer 类型获取实例。"""
|
||||
stream_id = chat_stream.session_id if chat_stream else chat_id
|
||||
if not stream_id:
|
||||
logger.warning("[ReplyerManager] 缺少 stream_id,无法获取回复器。")
|
||||
logger.warning("[ReplyerManager] 缺少 stream_id,无法获取 replyer")
|
||||
return None
|
||||
|
||||
# 如果已有缓存实例,直接返回
|
||||
if stream_id in self._repliers:
|
||||
logger.debug(f"[ReplyerManager] 为 stream_id '{stream_id}' 返回已存在的回复器实例。")
|
||||
return self._repliers[stream_id]
|
||||
cache_key = f"{replyer_type}:{stream_id}"
|
||||
if cache_key in self._repliers:
|
||||
logger.info(f"[ReplyerManager] 命中缓存 replyer: cache_key={cache_key}")
|
||||
return self._repliers[cache_key]
|
||||
|
||||
# 如果没有缓存,则创建新实例(首次初始化)
|
||||
logger.debug(f"[ReplyerManager] 为 stream_id '{stream_id}' 创建新的回复器实例并缓存。")
|
||||
|
||||
target_stream = chat_stream
|
||||
target_stream = chat_stream or _chat_manager.get_session_by_session_id(stream_id)
|
||||
if not target_stream:
|
||||
target_stream = _chat_manager.get_session_by_session_id(stream_id)
|
||||
|
||||
if not target_stream:
|
||||
logger.warning(f"[ReplyerManager] 未找到 stream_id='{stream_id}' 的聊天流,无法创建回复器。")
|
||||
logger.warning(f"[ReplyerManager] 未找到会话,stream_id={stream_id}")
|
||||
return None
|
||||
|
||||
# model_configs 只在此时(初始化时)生效
|
||||
if target_stream.is_group_session:
|
||||
replyer = DefaultReplyer(
|
||||
chat_stream=target_stream,
|
||||
request_type=request_type,
|
||||
)
|
||||
else:
|
||||
replyer = PrivateReplyer(
|
||||
chat_stream=target_stream,
|
||||
request_type=request_type,
|
||||
)
|
||||
logger.info(
|
||||
f"[ReplyerManager] 开始创建 replyer: cache_key={cache_key}, "
|
||||
f"replyer_type={replyer_type}, is_group_session={target_stream.is_group_session}"
|
||||
)
|
||||
|
||||
self._repliers[stream_id] = replyer
|
||||
try:
|
||||
if replyer_type == "maisaka":
|
||||
logger.info("[ReplyerManager] importing MaisakaReplyGenerator")
|
||||
from src.chat.replyer.maisaka_generator import MaisakaReplyGenerator
|
||||
|
||||
replyer = MaisakaReplyGenerator(
|
||||
chat_stream=target_stream,
|
||||
request_type=request_type,
|
||||
)
|
||||
elif target_stream.is_group_session:
|
||||
logger.info("[ReplyerManager] importing DefaultReplyer")
|
||||
from src.chat.replyer.group_generator import DefaultReplyer
|
||||
|
||||
replyer = DefaultReplyer(
|
||||
chat_stream=target_stream,
|
||||
request_type=request_type,
|
||||
)
|
||||
else:
|
||||
logger.info("[ReplyerManager] importing PrivateReplyer")
|
||||
from src.chat.replyer.private_generator import PrivateReplyer
|
||||
|
||||
replyer = PrivateReplyer(
|
||||
chat_stream=target_stream,
|
||||
request_type=request_type,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(f"[ReplyerManager] 创建 replyer 失败: cache_key={cache_key}")
|
||||
raise
|
||||
|
||||
self._repliers[cache_key] = replyer
|
||||
logger.info(f"[ReplyerManager] replyer 创建完成: cache_key={cache_key}")
|
||||
return replyer
|
||||
|
||||
|
||||
# 创建一个全局实例
|
||||
replyer_manager = ReplyerManager()
|
||||
|
||||
@@ -1,248 +0,0 @@
|
||||
"""工具执行器。
|
||||
|
||||
独立的工具执行组件,可以直接输入聊天消息内容,
|
||||
自动判断并执行相应的工具,返回结构化的工具执行结果。
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import hashlib
|
||||
import time
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.core.announcement_manager import global_announcement_manager
|
||||
from src.llm_models.payload_content import ToolCall
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.plugin_runtime.component_query import component_query_service
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
|
||||
logger = get_logger("tool_use")
|
||||
|
||||
|
||||
class ToolExecutor:
|
||||
"""独立的工具执行器组件
|
||||
|
||||
可以直接输入聊天消息内容,自动判断并执行相应的工具,返回结构化的工具执行结果。
|
||||
"""
|
||||
|
||||
def __init__(self, chat_id: str, enable_cache: bool = True, cache_ttl: int = 3):
|
||||
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
||||
|
||||
self.chat_id = chat_id
|
||||
self.chat_stream = _chat_manager.get_session_by_session_id(self.chat_id)
|
||||
self.log_prefix = f"[{_chat_manager.get_session_name(self.chat_id) or self.chat_id}]"
|
||||
|
||||
self.llm_model = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="tool_executor")
|
||||
|
||||
self.enable_cache = enable_cache
|
||||
self.cache_ttl = cache_ttl
|
||||
self.tool_cache: Dict[str, dict] = {}
|
||||
|
||||
logger.info(f"{self.log_prefix}工具执行器初始化完成,缓存{'启用' if enable_cache else '禁用'},TTL={cache_ttl}")
|
||||
|
||||
async def execute_from_chat_message(
|
||||
self, target_message: str, chat_history: str, sender: str, return_details: bool = False
|
||||
) -> Tuple[List[Dict[str, Any]], List[str], str]:
|
||||
"""从聊天消息执行工具"""
|
||||
|
||||
cache_key = self._generate_cache_key(target_message, chat_history, sender)
|
||||
if cached_result := self._get_from_cache(cache_key):
|
||||
logger.info(f"{self.log_prefix}使用缓存结果,跳过工具执行")
|
||||
if not return_details:
|
||||
return cached_result, [], ""
|
||||
used_tools = [result.get("tool_name", "unknown") for result in cached_result]
|
||||
return cached_result, used_tools, ""
|
||||
|
||||
tools = self._get_tool_definitions()
|
||||
if not tools:
|
||||
logger.debug(f"{self.log_prefix}没有可用工具,直接返回空内容")
|
||||
return [], [], ""
|
||||
|
||||
prompt_template = prompt_manager.get_prompt("tool_executor")
|
||||
prompt_template.add_context("target_message", target_message)
|
||||
prompt_template.add_context("chat_history", chat_history)
|
||||
prompt_template.add_context("sender", sender)
|
||||
prompt_template.add_context("bot_name", global_config.bot.nickname)
|
||||
prompt_template.add_context("time_now", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
|
||||
prompt = await prompt_manager.render_prompt(prompt_template)
|
||||
|
||||
logger.debug(f"{self.log_prefix}开始LLM工具调用分析")
|
||||
|
||||
response, (reasoning_content, model_name, tool_calls) = await self.llm_model.generate_response_async(
|
||||
prompt=prompt, tools=tools, raise_when_empty=False
|
||||
)
|
||||
|
||||
tool_results, used_tools = await self.execute_tool_calls(tool_calls)
|
||||
|
||||
if tool_results:
|
||||
self._set_cache(cache_key, tool_results)
|
||||
|
||||
if used_tools:
|
||||
logger.info(f"{self.log_prefix}工具执行完成,共执行{len(used_tools)}个工具: {used_tools}")
|
||||
|
||||
if return_details:
|
||||
return tool_results, used_tools, prompt
|
||||
return tool_results, [], ""
|
||||
|
||||
def _get_tool_definitions(self) -> List[Dict[str, Any]]:
|
||||
"""获取 LLM 可用的工具定义列表"""
|
||||
all_tools = component_query_service.get_llm_available_tools()
|
||||
user_disabled_tools = global_announcement_manager.get_disabled_chat_tools(self.chat_id)
|
||||
return [info.get_llm_definition() for name, info in all_tools.items() if name not in user_disabled_tools]
|
||||
|
||||
async def execute_tool_calls(self, tool_calls: Optional[List[ToolCall]]) -> Tuple[List[Dict[str, Any]], List[str]]:
|
||||
"""执行工具调用列表"""
|
||||
tool_results: List[Dict[str, Any]] = []
|
||||
used_tools: List[str] = []
|
||||
|
||||
if not tool_calls:
|
||||
logger.debug(f"{self.log_prefix}无需执行工具")
|
||||
return [], []
|
||||
|
||||
func_names = [call.func_name for call in tool_calls if call.func_name]
|
||||
logger.info(f"{self.log_prefix}开始执行工具调用: {func_names}")
|
||||
|
||||
for tool_call in tool_calls:
|
||||
tool_name = tool_call.func_name
|
||||
try:
|
||||
logger.debug(f"{self.log_prefix}执行工具: {tool_name}")
|
||||
result = await self.execute_tool_call(tool_call)
|
||||
|
||||
if result:
|
||||
tool_info = {
|
||||
"type": result.get("type", "unknown_type"),
|
||||
"id": result.get("id", f"tool_exec_{time.time()}"),
|
||||
"content": result.get("content", ""),
|
||||
"tool_name": tool_name,
|
||||
"timestamp": time.time(),
|
||||
}
|
||||
content = tool_info["content"]
|
||||
if not isinstance(content, (str, list, tuple)):
|
||||
tool_info["content"] = str(content)
|
||||
content_check = tool_info["content"]
|
||||
if (isinstance(content_check, str) and not content_check.strip()) or (
|
||||
isinstance(content_check, (list, tuple)) and len(content_check) == 0
|
||||
):
|
||||
logger.debug(f"{self.log_prefix}工具{tool_name}无有效内容,跳过展示")
|
||||
continue
|
||||
|
||||
tool_results.append(tool_info)
|
||||
used_tools.append(tool_name)
|
||||
preview = str(content)[:200]
|
||||
logger.debug(f"{self.log_prefix}工具{tool_name}结果内容: {preview}...")
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix}工具{tool_name}执行失败: {e}")
|
||||
error_info = {
|
||||
"type": "tool_error",
|
||||
"id": f"tool_error_{time.time()}",
|
||||
"content": f"工具{tool_name}执行失败: {str(e)}",
|
||||
"tool_name": tool_name,
|
||||
"timestamp": time.time(),
|
||||
}
|
||||
tool_results.append(error_info)
|
||||
|
||||
return tool_results, used_tools
|
||||
|
||||
async def execute_tool_call(self, tool_call: ToolCall) -> Optional[Dict[str, Any]]:
|
||||
"""执行单个工具调用"""
|
||||
function_name = tool_call.func_name
|
||||
function_args = tool_call.args or {}
|
||||
function_args["llm_called"] = True
|
||||
|
||||
executor = component_query_service.get_tool_executor(function_name)
|
||||
if not executor:
|
||||
logger.warning(f"未知工具名称: {function_name}")
|
||||
return None
|
||||
|
||||
result = await executor(function_args)
|
||||
if result:
|
||||
return {
|
||||
"tool_call_id": tool_call.call_id,
|
||||
"role": "tool",
|
||||
"name": function_name,
|
||||
"type": "function",
|
||||
"content": result["content"],
|
||||
}
|
||||
return None
|
||||
|
||||
async def execute_specific_tool_simple(self, tool_name: str, tool_args: Dict) -> Optional[Dict]:
|
||||
"""直接执行指定工具"""
|
||||
try:
|
||||
tool_call = ToolCall(
|
||||
call_id=f"direct_tool_{time.time()}",
|
||||
func_name=tool_name,
|
||||
args=tool_args,
|
||||
)
|
||||
|
||||
logger.info(f"{self.log_prefix}直接执行工具: {tool_name}")
|
||||
result = await self.execute_tool_call(tool_call)
|
||||
|
||||
if result:
|
||||
tool_info = {
|
||||
"type": result.get("type", "unknown_type"),
|
||||
"id": result.get("id", f"direct_tool_{time.time()}"),
|
||||
"content": result.get("content", ""),
|
||||
"tool_name": tool_name,
|
||||
"timestamp": time.time(),
|
||||
}
|
||||
logger.info(f"{self.log_prefix}直接工具执行成功: {tool_name}")
|
||||
return tool_info
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix}直接工具执行失败 {tool_name}: {e}")
|
||||
|
||||
return None
|
||||
|
||||
# === 缓存方法 ===
|
||||
|
||||
def _generate_cache_key(self, target_message: str, chat_history: str, sender: str) -> str:
|
||||
content = f"{target_message}_{chat_history}_{sender}"
|
||||
return hashlib.md5(content.encode()).hexdigest()
|
||||
|
||||
def _get_from_cache(self, cache_key: str) -> Optional[List[Dict]]:
|
||||
if not self.enable_cache or cache_key not in self.tool_cache:
|
||||
return None
|
||||
cache_item = self.tool_cache[cache_key]
|
||||
if cache_item["ttl"] <= 0:
|
||||
del self.tool_cache[cache_key]
|
||||
return None
|
||||
cache_item["ttl"] -= 1
|
||||
return cache_item["result"]
|
||||
|
||||
def _set_cache(self, cache_key: str, result: List[Dict]):
|
||||
if not self.enable_cache:
|
||||
return
|
||||
self.tool_cache[cache_key] = {"result": result, "ttl": self.cache_ttl, "timestamp": time.time()}
|
||||
|
||||
def _cleanup_expired_cache(self):
|
||||
if not self.enable_cache:
|
||||
return
|
||||
expired = [k for k, v in self.tool_cache.items() if v["ttl"] <= 0]
|
||||
for key in expired:
|
||||
del self.tool_cache[key]
|
||||
|
||||
def clear_cache(self):
|
||||
if self.enable_cache:
|
||||
self.tool_cache.clear()
|
||||
|
||||
def get_cache_status(self) -> Dict:
|
||||
if not self.enable_cache:
|
||||
return {"enabled": False, "cache_count": 0}
|
||||
self._cleanup_expired_cache()
|
||||
ttl_distribution: Dict[int, int] = {}
|
||||
for item in self.tool_cache.values():
|
||||
ttl = item["ttl"]
|
||||
ttl_distribution[ttl] = ttl_distribution.get(ttl, 0) + 1
|
||||
return {
|
||||
"enabled": True,
|
||||
"cache_count": len(self.tool_cache),
|
||||
"cache_ttl": self.cache_ttl,
|
||||
"ttl_distribution": ttl_distribution,
|
||||
}
|
||||
|
||||
def set_cache_config(self, enable_cache: Optional[bool] = None, cache_ttl: int = -1):
|
||||
if enable_cache is not None:
|
||||
self.enable_cache = enable_cache
|
||||
if cache_ttl > 0:
|
||||
self.cache_ttl = cache_ttl
|
||||
@@ -12,7 +12,7 @@ from sqlmodel import col, select
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import OnlineTime, ModelUsage, Messages, ActionRecord
|
||||
from src.common.database.database_model import Messages, ModelUsage, OnlineTime, ToolRecord
|
||||
from src.manager.async_task_manager import AsyncTask
|
||||
from src.manager.local_store_manager import local_storage
|
||||
from src.config.config import global_config
|
||||
@@ -648,7 +648,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
def _collect_message_count_for_period(
|
||||
self,
|
||||
collect_period: list[tuple[str, datetime]],
|
||||
) -> StatPeriodMapping:
|
||||
) -> dict[str, dict[str, object]]:
|
||||
"""
|
||||
收集指定时间段的消息统计数据
|
||||
|
||||
@@ -659,8 +659,13 @@ class StatisticOutputTask(AsyncTask):
|
||||
|
||||
collect_period.sort(key=lambda x: x[1], reverse=True)
|
||||
|
||||
stats: StatPeriodMapping = {
|
||||
period_key: StatisticOutputTask._build_stat_period_data() for period_key, _ in collect_period
|
||||
stats: dict[str, dict[str, object]] = {
|
||||
period_key: {
|
||||
TOTAL_MSG_CNT: 0,
|
||||
MSG_CNT_BY_CHAT: defaultdict(int),
|
||||
TOTAL_REPLY_CNT: 0,
|
||||
}
|
||||
for period_key, _ in collect_period
|
||||
}
|
||||
|
||||
query_start_timestamp = collect_period[-1][1]
|
||||
@@ -710,24 +715,24 @@ class StatisticOutputTask(AsyncTask):
|
||||
StatisticOutputTask._add_defaultdict_int(stats[period_key], MSG_CNT_BY_CHAT, chat_id, 1)
|
||||
break
|
||||
|
||||
# 使用 ActionRecords 中的 reply 动作次数作为回复数基准
|
||||
# 使用 ToolRecord 中的 reply 工具次数作为回复数基准
|
||||
try:
|
||||
action_query_start_timestamp = collect_period[-1][1]
|
||||
tool_query_start_timestamp = collect_period[-1][1]
|
||||
with get_db_session(auto_commit=False) as session:
|
||||
statement = select(ActionRecord).where(col(ActionRecord.timestamp) >= action_query_start_timestamp)
|
||||
actions = session.exec(statement).all()
|
||||
for action in actions:
|
||||
if action.action_name != "reply":
|
||||
statement = select(ToolRecord).where(col(ToolRecord.timestamp) >= tool_query_start_timestamp)
|
||||
tool_records = session.exec(statement).all()
|
||||
for tool_record in tool_records:
|
||||
if tool_record.tool_name != "reply":
|
||||
continue
|
||||
|
||||
action_time_ts = action.timestamp.timestamp()
|
||||
action_time_ts = tool_record.timestamp.timestamp()
|
||||
for idx, (_, period_start_dt) in enumerate(collect_period):
|
||||
if action_time_ts >= period_start_dt.timestamp():
|
||||
for period_key, _ in collect_period[idx:]:
|
||||
StatisticOutputTask._add_int_stat(stats[period_key], TOTAL_REPLY_CNT, 1)
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning(f"统计 reply 动作次数失败,将回复数视为 0,错误信息:{e}")
|
||||
logger.warning(f"统计 reply 工具次数失败,将回复数视为 0,错误信息:{e}")
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
@@ -13,8 +13,8 @@ import jieba
|
||||
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
from src.person_info.person_info import Person
|
||||
|
||||
from .typo_generator import ChineseTypoGenerator
|
||||
@@ -235,10 +235,11 @@ def is_mentioned_bot_in_message(message: SessionMessage) -> tuple[bool, bool, fl
|
||||
|
||||
async def get_embedding(text, request_type="embedding") -> Optional[List[float]]:
|
||||
"""获取文本的embedding向量"""
|
||||
# 每次都创建新的LLMRequest实例以避免事件循环冲突
|
||||
llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type=request_type)
|
||||
# 每次都创建新的服务层实例以避免事件循环冲突
|
||||
llm = LLMServiceClient(task_name="embedding", request_type=request_type)
|
||||
try:
|
||||
embedding, _ = await llm.get_embedding(text)
|
||||
embedding_result = await llm.embed_text(text)
|
||||
embedding = embedding_result.embedding
|
||||
except Exception as e:
|
||||
logger.error(f"获取embedding失败: {str(e)}")
|
||||
embedding = None
|
||||
|
||||
3
src/cli/__init__.py
Normal file
3
src/cli/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""
|
||||
CLI startup and interaction package.
|
||||
"""
|
||||
17
src/cli/console.py
Normal file
17
src/cli/console.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""MaiSaka terminal console helpers."""
|
||||
|
||||
from rich.console import Console
|
||||
from rich.theme import Theme
|
||||
|
||||
custom_theme = Theme(
|
||||
{
|
||||
"info": "cyan",
|
||||
"success": "green",
|
||||
"warning": "yellow",
|
||||
"error": "bold red",
|
||||
"muted": "dim",
|
||||
"accent": "bold magenta",
|
||||
}
|
||||
)
|
||||
|
||||
console = Console(theme=custom_theme)
|
||||
@@ -1,12 +1,12 @@
|
||||
"""
|
||||
MaiSaka - 异步输入读取器
|
||||
将阻塞的标准输入读取放到后台线程中,供 asyncio 循环安全消费。
|
||||
MaiSaka asynchronous stdin reader for CLI interaction.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
import threading
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class InputReader:
|
||||
463
src/cli/maisaka_cli.py
Normal file
463
src/cli/maisaka_cli.py
Normal file
@@ -0,0 +1,463 @@
|
||||
"""
|
||||
MaiSaka CLI and conversation loop.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import time
|
||||
|
||||
from rich import box
|
||||
from rich.markdown import Markdown
|
||||
from rich.panel import Panel
|
||||
from rich.text import Text
|
||||
|
||||
from src.know_u.knowledge import KnowledgeLearner, retrieve_relevant_knowledge
|
||||
from src.know_u.knowledge_store import get_knowledge_store
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
from src.chat.replyer.maisaka_generator import MaisakaReplyGenerator
|
||||
from src.config.config import config_manager, global_config
|
||||
from src.mcp_module import MCPManager
|
||||
from src.mcp_module.host_llm_bridge import MCPHostLLMBridge
|
||||
|
||||
from src.maisaka.chat_loop_service import MaisakaChatLoopService
|
||||
from src.maisaka.context_messages import (
|
||||
AssistantMessage,
|
||||
LLMContextMessage,
|
||||
SessionBackedMessage,
|
||||
ToolResultMessage,
|
||||
)
|
||||
from src.maisaka.message_adapter import format_speaker_content
|
||||
from src.maisaka.tool_handlers import (
|
||||
ToolHandlerContext,
|
||||
handle_mcp_tool,
|
||||
handle_stop,
|
||||
handle_unknown_tool,
|
||||
handle_wait,
|
||||
)
|
||||
|
||||
from .console import console
|
||||
from .input_reader import InputReader
|
||||
|
||||
|
||||
class BufferCLI:
|
||||
"""Maisaka 命令行交互入口。"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._chat_loop_service: Optional[MaisakaChatLoopService] = None
|
||||
self._reply_generator = MaisakaReplyGenerator()
|
||||
self._reader = InputReader()
|
||||
self._chat_history: Optional[list[LLMContextMessage]] = None
|
||||
self._knowledge_store = get_knowledge_store()
|
||||
self._knowledge_learner = KnowledgeLearner("maisaka_cli")
|
||||
self._knowledge_min_messages_for_extraction = 10
|
||||
self._knowledge_min_extraction_interval = 30
|
||||
self._last_knowledge_extraction_time = 0.0
|
||||
|
||||
knowledge_stats = self._knowledge_store.get_stats()
|
||||
if knowledge_stats["total_items"] > 0:
|
||||
console.print(f"[success]知识库中已有 {knowledge_stats['total_items']} 条数据[/success]")
|
||||
else:
|
||||
console.print("[muted]知识库已初始化,当前没有数据[/muted]")
|
||||
|
||||
self._chat_start_time: Optional[datetime] = None
|
||||
self._last_user_input_time: Optional[datetime] = None
|
||||
self._last_assistant_response_time: Optional[datetime] = None
|
||||
self._user_input_times: list[datetime] = []
|
||||
self._mcp_manager: Optional[MCPManager] = None
|
||||
self._mcp_host_bridge: Optional[MCPHostLLMBridge] = None
|
||||
self._init_llm()
|
||||
|
||||
def _init_llm(self) -> None:
|
||||
"""初始化 Maisaka 使用的聊天服务。"""
|
||||
thinking_env = os.getenv("ENABLE_THINKING", "").strip().lower()
|
||||
enable_thinking: Optional[bool] = True if thinking_env == "true" else False if thinking_env == "false" else None
|
||||
|
||||
_ = enable_thinking
|
||||
self._chat_loop_service = MaisakaChatLoopService()
|
||||
|
||||
model_name = self._get_current_model_name()
|
||||
console.print(f"[success]大模型服务已初始化[/success] [muted](模型: {model_name})[/muted]")
|
||||
|
||||
@staticmethod
|
||||
def _get_current_model_name() -> str:
|
||||
"""读取当前 planner 模型名。"""
|
||||
try:
|
||||
model_task_config = config_manager.get_model_config().model_task_config
|
||||
if model_task_config.planner.model_list:
|
||||
return model_task_config.planner.model_list[0]
|
||||
except Exception:
|
||||
pass
|
||||
return "未配置"
|
||||
|
||||
def _build_tool_context(self) -> ToolHandlerContext:
|
||||
"""构建工具处理的共享上下文。"""
|
||||
tool_context = ToolHandlerContext(
|
||||
reader=self._reader,
|
||||
user_input_times=self._user_input_times,
|
||||
)
|
||||
tool_context.last_user_input_time = self._last_user_input_time
|
||||
return tool_context
|
||||
|
||||
def _show_banner(self) -> None:
|
||||
"""渲染启动横幅。"""
|
||||
banner = Text()
|
||||
banner.append("MaiSaka", style="bold cyan")
|
||||
banner.append(" v2.0\n", style="muted")
|
||||
banner.append("输入内容开始对话 | Ctrl+C 退出", style="muted")
|
||||
|
||||
console.print(Panel(banner, box=box.DOUBLE_EDGE, border_style="cyan", padding=(1, 2)))
|
||||
console.print()
|
||||
|
||||
async def _start_chat(self, user_text: str) -> None:
|
||||
"""追加用户输入并继续内部循环。"""
|
||||
if self._chat_loop_service is None:
|
||||
console.print("[warning]大模型服务尚未初始化,已跳过本次对话。[/warning]")
|
||||
return
|
||||
|
||||
now = datetime.now()
|
||||
self._last_user_input_time = now
|
||||
self._user_input_times.append(now)
|
||||
|
||||
if self._chat_history is None:
|
||||
self._chat_start_time = now
|
||||
self._last_assistant_response_time = None
|
||||
self._chat_history = self._chat_loop_service.build_chat_context(user_text)
|
||||
self._trigger_knowledge_learning([self._build_cli_session_message(user_text, now)])
|
||||
else:
|
||||
self._chat_history.append(
|
||||
self._build_cli_context_message(
|
||||
user_text=user_text,
|
||||
timestamp=now,
|
||||
source_kind="user",
|
||||
)
|
||||
)
|
||||
self._trigger_knowledge_learning([self._build_cli_session_message(user_text, now)])
|
||||
|
||||
await self._run_llm_loop(self._chat_history)
|
||||
|
||||
@staticmethod
|
||||
def _build_cli_context_message(
|
||||
user_text: str,
|
||||
timestamp: datetime,
|
||||
source_kind: str = "user",
|
||||
speaker_name: Optional[str] = None,
|
||||
) -> SessionBackedMessage:
|
||||
"""为 CLI 构造新的上下文消息。"""
|
||||
resolved_speaker_name = speaker_name or global_config.maisaka.user_name.strip() or "用户"
|
||||
visible_text = format_speaker_content(
|
||||
resolved_speaker_name,
|
||||
user_text,
|
||||
timestamp,
|
||||
)
|
||||
planner_prefix = (
|
||||
f"[时间]{timestamp.strftime('%H:%M:%S')}\n"
|
||||
f"[用户]{resolved_speaker_name}\n"
|
||||
"[用户群昵称]\n"
|
||||
"[msg_id]\n"
|
||||
"[发言内容]"
|
||||
)
|
||||
from src.common.data_models.message_component_data_model import MessageSequence, TextComponent
|
||||
|
||||
return SessionBackedMessage(
|
||||
raw_message=MessageSequence([TextComponent(f"{planner_prefix}{user_text}")]),
|
||||
visible_text=visible_text,
|
||||
timestamp=timestamp,
|
||||
source_kind=source_kind,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _build_cli_session_message(user_text: str, timestamp: datetime) -> SessionMessage:
|
||||
"""为 CLI 的知识学习构造兼容 SessionMessage。"""
|
||||
from src.common.data_models.mai_message_data_model import MessageInfo, UserInfo
|
||||
from src.common.data_models.message_component_data_model import MessageSequence
|
||||
|
||||
message = SessionMessage(message_id=f"maisaka_cli_{int(timestamp.timestamp() * 1000)}", timestamp=timestamp, platform="maisaka")
|
||||
message.message_info = MessageInfo(
|
||||
user_info=UserInfo(
|
||||
user_id="maisaka_user",
|
||||
user_nickname=global_config.maisaka.user_name.strip() or "用户",
|
||||
user_cardname=None,
|
||||
),
|
||||
group_info=None,
|
||||
additional_config={},
|
||||
)
|
||||
message.session_id = "maisaka_cli"
|
||||
message.raw_message = MessageSequence([])
|
||||
visible_text = format_speaker_content(
|
||||
global_config.maisaka.user_name.strip() or "用户",
|
||||
user_text,
|
||||
timestamp,
|
||||
)
|
||||
message.raw_message.text(visible_text)
|
||||
message.processed_plain_text = visible_text
|
||||
message.display_message = visible_text
|
||||
message.initialized = True
|
||||
return message
|
||||
|
||||
def _trigger_knowledge_learning(self, messages: list[SessionMessage]) -> None:
|
||||
"""在 CLI 会话中按批次触发 knowledge 学习。"""
|
||||
if not global_config.maisaka.enable_knowledge_module:
|
||||
return
|
||||
|
||||
self._knowledge_learner.add_messages(messages)
|
||||
|
||||
elapsed = time.monotonic() - self._last_knowledge_extraction_time
|
||||
if elapsed < self._knowledge_min_extraction_interval:
|
||||
return
|
||||
|
||||
cache_size = self._knowledge_learner.get_cache_size()
|
||||
if cache_size < self._knowledge_min_messages_for_extraction:
|
||||
return
|
||||
|
||||
self._last_knowledge_extraction_time = time.monotonic()
|
||||
asyncio.create_task(self._run_knowledge_learning())
|
||||
|
||||
async def _run_knowledge_learning(self) -> None:
|
||||
"""后台执行 knowledge 学习,避免阻塞主对话。"""
|
||||
try:
|
||||
added_count = await self._knowledge_learner.learn()
|
||||
if added_count > 0 and global_config.maisaka.show_thinking:
|
||||
console.print(f"[muted]知识学习已完成,新增 {added_count} 条数据。[/muted]")
|
||||
except Exception as exc:
|
||||
console.print(f"[warning]知识学习失败:{exc}[/warning]")
|
||||
|
||||
async def _run_llm_loop(self, chat_history: list[LLMContextMessage]) -> None:
|
||||
"""
|
||||
Main inner loop for the Maisaka planner.
|
||||
|
||||
Each round may produce internal thoughts and optionally call tools:
|
||||
- reply(msg_id): generate a visible reply for the current round
|
||||
- no_reply(): skip visible output and continue the loop
|
||||
- wait(seconds): wait for new user input
|
||||
- stop(): stop the current inner loop and return to idle
|
||||
"""
|
||||
if self._chat_loop_service is None:
|
||||
return
|
||||
|
||||
consecutive_errors = 0
|
||||
last_had_tool_calls = True
|
||||
|
||||
while True:
|
||||
if last_had_tool_calls:
|
||||
tasks = []
|
||||
status_text_parts = []
|
||||
|
||||
if global_config.maisaka.enable_knowledge_module:
|
||||
tasks.append(("knowledge", retrieve_relevant_knowledge(self._chat_loop_service, chat_history)))
|
||||
status_text_parts.append("知识库")
|
||||
|
||||
with console.status(
|
||||
f"[info]{' + '.join(status_text_parts)} 分析中...[/info]",
|
||||
spinner="dots",
|
||||
):
|
||||
results = await asyncio.gather(*[task for _, task in tasks], return_exceptions=True)
|
||||
|
||||
knowledge_analysis = ""
|
||||
if global_config.maisaka.enable_knowledge_module:
|
||||
knowledge_result = results[0] if results else None
|
||||
if isinstance(knowledge_result, Exception):
|
||||
console.print(f"[warning]知识分析失败:{knowledge_result}[/warning]")
|
||||
elif isinstance(knowledge_result, str) and knowledge_result.strip():
|
||||
knowledge_analysis = knowledge_result
|
||||
if global_config.maisaka.show_thinking:
|
||||
console.print(
|
||||
Panel(
|
||||
Markdown(knowledge_analysis),
|
||||
title="知识",
|
||||
border_style="bright_magenta",
|
||||
padding=(0, 1),
|
||||
style="dim",
|
||||
)
|
||||
)
|
||||
|
||||
if chat_history and isinstance(chat_history[-1], AssistantMessage) and chat_history[-1].source == "perception":
|
||||
chat_history.pop()
|
||||
|
||||
perception_parts = []
|
||||
if knowledge_analysis:
|
||||
perception_parts.append(f"知识库\n{knowledge_analysis}")
|
||||
|
||||
if perception_parts:
|
||||
chat_history.append(
|
||||
AssistantMessage(
|
||||
content="\n\n".join(perception_parts),
|
||||
timestamp=datetime.now(),
|
||||
source_kind="perception",
|
||||
)
|
||||
)
|
||||
elif global_config.maisaka.show_thinking:
|
||||
console.print("[muted]上一轮没有使用工具,本轮跳过模块分析。[/muted]")
|
||||
|
||||
with console.status("[info]正在思考...[/info]", spinner="dots"):
|
||||
try:
|
||||
response = await self._chat_loop_service.chat_loop_step(chat_history)
|
||||
consecutive_errors = 0
|
||||
except Exception as exc:
|
||||
consecutive_errors += 1
|
||||
console.print(f"[error]大模型调用失败:{exc}[/error]")
|
||||
if consecutive_errors >= 3:
|
||||
console.print("[error]连续失败次数过多,结束对话。[/error]\n")
|
||||
break
|
||||
continue
|
||||
|
||||
chat_history.append(response.raw_message)
|
||||
self._last_assistant_response_time = datetime.now()
|
||||
|
||||
if global_config.maisaka.show_thinking and response.content:
|
||||
console.print(
|
||||
Panel(
|
||||
Markdown(response.content),
|
||||
title="思考",
|
||||
border_style="dim",
|
||||
padding=(1, 2),
|
||||
style="dim",
|
||||
)
|
||||
)
|
||||
|
||||
if response.content and not response.tool_calls:
|
||||
last_had_tool_calls = False
|
||||
continue
|
||||
|
||||
if not response.tool_calls:
|
||||
last_had_tool_calls = False
|
||||
continue
|
||||
|
||||
should_stop = False
|
||||
tool_context = self._build_tool_context()
|
||||
|
||||
for tool_call in response.tool_calls:
|
||||
if tool_call.func_name == "stop":
|
||||
await handle_stop(tool_call, chat_history)
|
||||
should_stop = True
|
||||
|
||||
elif tool_call.func_name == "reply":
|
||||
reply = await self._generate_visible_reply(chat_history, response.content or "")
|
||||
chat_history.append(
|
||||
ToolResultMessage(
|
||||
content="已生成并记录可见回复。",
|
||||
timestamp=datetime.now(),
|
||||
tool_call_id=tool_call.call_id,
|
||||
tool_name=tool_call.func_name,
|
||||
)
|
||||
)
|
||||
chat_history.append(
|
||||
self._build_cli_context_message(
|
||||
user_text=reply,
|
||||
timestamp=datetime.now(),
|
||||
source_kind="guided_reply",
|
||||
speaker_name=global_config.bot.nickname.strip() or "MaiSaka",
|
||||
)
|
||||
)
|
||||
|
||||
elif tool_call.func_name == "no_reply":
|
||||
if global_config.maisaka.show_thinking:
|
||||
console.print("[muted]本轮未发送可见回复。[/muted]")
|
||||
chat_history.append(
|
||||
ToolResultMessage(
|
||||
content="本轮未发送可见回复。",
|
||||
timestamp=datetime.now(),
|
||||
tool_call_id=tool_call.call_id,
|
||||
tool_name=tool_call.func_name,
|
||||
)
|
||||
)
|
||||
|
||||
elif tool_call.func_name == "wait":
|
||||
tool_result = await handle_wait(tool_call, chat_history, tool_context)
|
||||
if tool_context.last_user_input_time != self._last_user_input_time:
|
||||
self._last_user_input_time = tool_context.last_user_input_time
|
||||
if tool_result.startswith("[[QUIT]]"):
|
||||
should_stop = True
|
||||
|
||||
elif self._mcp_manager and self._mcp_manager.is_mcp_tool(tool_call.func_name):
|
||||
await handle_mcp_tool(tool_call, chat_history, self._mcp_manager)
|
||||
|
||||
else:
|
||||
await handle_unknown_tool(tool_call, chat_history)
|
||||
|
||||
if should_stop:
|
||||
console.print("[muted]对话已暂停,等待新的输入...[/muted]\n")
|
||||
break
|
||||
|
||||
last_had_tool_calls = True
|
||||
|
||||
async def _init_mcp(self) -> None:
|
||||
"""初始化 MCP 服务并注册暴露的工具。"""
|
||||
self._mcp_host_bridge = MCPHostLLMBridge(
|
||||
sampling_task_name=global_config.mcp.client.sampling.task_name,
|
||||
)
|
||||
self._mcp_manager = await MCPManager.from_app_config(
|
||||
global_config.mcp,
|
||||
host_callbacks=self._mcp_host_bridge.build_callbacks(),
|
||||
)
|
||||
|
||||
if self._mcp_manager and self._chat_loop_service:
|
||||
mcp_tools = self._mcp_manager.get_openai_tools()
|
||||
if mcp_tools:
|
||||
self._chat_loop_service.set_extra_tools(mcp_tools)
|
||||
summary = self._mcp_manager.get_feature_summary()
|
||||
console.print(
|
||||
Panel(
|
||||
f"已加载 {len(mcp_tools)} 个 MCP 工具。\n{summary}",
|
||||
title="MCP 能力",
|
||||
border_style="green",
|
||||
padding=(0, 1),
|
||||
)
|
||||
)
|
||||
|
||||
async def _generate_visible_reply(self, chat_history: list[LLMContextMessage], latest_thought: str) -> str:
|
||||
"""根据最新思考生成并输出可见回复。"""
|
||||
if not latest_thought:
|
||||
return ""
|
||||
|
||||
with console.status("[info]正在生成可见回复...[/info]", spinner="dots"):
|
||||
success, result = await self._reply_generator.generate_reply_with_context(
|
||||
reply_reason=latest_thought,
|
||||
chat_history=chat_history,
|
||||
)
|
||||
if success and result.text_fragments:
|
||||
reply = result.text_fragments[0]
|
||||
else:
|
||||
reply = "..."
|
||||
|
||||
console.print(
|
||||
Panel(
|
||||
Markdown(reply),
|
||||
title="MaiSaka",
|
||||
border_style="magenta",
|
||||
padding=(1, 2),
|
||||
)
|
||||
)
|
||||
|
||||
return reply
|
||||
|
||||
async def run(self) -> None:
|
||||
"""主交互循环。"""
|
||||
if global_config.mcp.enable:
|
||||
await self._init_mcp()
|
||||
else:
|
||||
console.print("[muted]MCP 已禁用(mcp.enable=false)[/muted]")
|
||||
|
||||
self._reader.start(asyncio.get_event_loop())
|
||||
self._show_banner()
|
||||
|
||||
try:
|
||||
while True:
|
||||
console.print("[bold cyan]> [/bold cyan]", end="")
|
||||
raw_input = await self._reader.get_line()
|
||||
|
||||
if raw_input is None:
|
||||
console.print("\n[muted]再见![/muted]")
|
||||
break
|
||||
|
||||
raw_input = raw_input.strip()
|
||||
if not raw_input:
|
||||
continue
|
||||
|
||||
await self._start_chat(raw_input)
|
||||
finally:
|
||||
if self._mcp_manager:
|
||||
await self._mcp_manager.close()
|
||||
self._mcp_host_bridge = None
|
||||
187
src/common/data_models/llm_service_data_models.py
Normal file
187
src/common/data_models/llm_service_data_models.py
Normal file
@@ -0,0 +1,187 @@
|
||||
"""LLM 服务层与编排层共享数据模型。
|
||||
|
||||
该模块集中定义 LLM 服务层与底层编排器共同使用的请求、选项与结果对象,
|
||||
用于替代散落在各层之间的复杂元组返回值。
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, TypeAlias
|
||||
|
||||
import asyncio
|
||||
|
||||
from src.common.data_models import BaseDataModel
|
||||
from src.llm_models.payload_content.resp_format import RespFormat
|
||||
from src.llm_models.payload_content.tool_option import ToolCall, ToolDefinitionInput
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.llm_models.model_client.base_client import BaseClient
|
||||
from src.llm_models.payload_content.message import Message
|
||||
|
||||
|
||||
PromptMessage: TypeAlias = Dict[str, Any]
|
||||
"""统一的原始提示消息结构。"""
|
||||
|
||||
PromptInput: TypeAlias = str | List[PromptMessage]
|
||||
"""统一的提示输入类型。"""
|
||||
|
||||
MessageFactory: TypeAlias = Callable[["BaseClient"], List["Message"]]
|
||||
"""统一的消息工厂类型。"""
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class LLMServiceRequest(BaseDataModel):
|
||||
"""LLM 服务层统一请求对象。"""
|
||||
|
||||
task_name: str
|
||||
request_type: str
|
||||
prompt: PromptInput | None = None
|
||||
message_factory: MessageFactory | None = None
|
||||
tool_options: List[ToolDefinitionInput] | None = None
|
||||
temperature: float | None = None
|
||||
max_tokens: int | None = None
|
||||
response_format: RespFormat | None = None
|
||||
interrupt_flag: asyncio.Event | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""校验请求对象的必要字段。
|
||||
|
||||
Raises:
|
||||
ValueError: 当 `task_name` 为空,或 `prompt` 与 `message_factory`
|
||||
的组合非法时抛出。
|
||||
"""
|
||||
self.task_name = self.task_name.strip()
|
||||
if not self.task_name:
|
||||
raise ValueError("`task_name` 不能为空")
|
||||
has_prompt = self.prompt is not None
|
||||
has_message_factory = self.message_factory is not None
|
||||
if has_prompt == has_message_factory:
|
||||
raise ValueError("`prompt` 与 `message_factory` 必须且只能提供一个")
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class LLMResponseResult(BaseDataModel):
|
||||
"""单次 LLM 响应结果。"""
|
||||
|
||||
response: str = field(default_factory=str)
|
||||
reasoning: str = field(default_factory=str)
|
||||
model_name: str = field(default_factory=str)
|
||||
tool_calls: List[ToolCall] | None = None
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class LLMServiceResult(BaseDataModel):
|
||||
"""LLM 服务层统一响应对象。"""
|
||||
|
||||
success: bool = False
|
||||
completion: LLMResponseResult = field(default_factory=LLMResponseResult)
|
||||
error: str | None = None
|
||||
|
||||
@classmethod
|
||||
def from_response_result(cls, completion: LLMResponseResult) -> "LLMServiceResult":
|
||||
"""从单次 LLM 响应结果构建服务响应。
|
||||
|
||||
Args:
|
||||
completion: 单次 LLM 响应结果。
|
||||
|
||||
Returns:
|
||||
LLMServiceResult: 标记为成功的服务响应对象。
|
||||
"""
|
||||
return cls(
|
||||
success=True,
|
||||
completion=completion,
|
||||
error=None,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_error(cls, error_message: str, error_detail: str | None = None) -> "LLMServiceResult":
|
||||
"""构建失败的服务响应对象。
|
||||
|
||||
Args:
|
||||
error_message: 对上层展示的错误消息。
|
||||
error_detail: 底层错误详情。
|
||||
|
||||
Returns:
|
||||
LLMServiceResult: 标记为失败的服务响应对象。
|
||||
"""
|
||||
return cls(
|
||||
success=False,
|
||||
completion=LLMResponseResult(response=error_message),
|
||||
error=error_detail or error_message,
|
||||
)
|
||||
|
||||
def to_capability_payload(self) -> Dict[str, Any]:
|
||||
"""转换为插件能力层可直接返回的结构。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 标准化后的能力返回值。
|
||||
"""
|
||||
payload: Dict[str, Any] = {
|
||||
"success": self.success,
|
||||
"response": self.completion.response,
|
||||
"reasoning": self.completion.reasoning,
|
||||
"model_name": self.completion.model_name,
|
||||
}
|
||||
if self.completion.tool_calls is not None:
|
||||
payload["tool_calls"] = [
|
||||
{
|
||||
"id": tool_call.call_id,
|
||||
"function": {
|
||||
"name": tool_call.func_name,
|
||||
"arguments": tool_call.args or {},
|
||||
},
|
||||
}
|
||||
for tool_call in self.completion.tool_calls
|
||||
]
|
||||
if self.error:
|
||||
payload["error"] = self.error
|
||||
return payload
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class LLMGenerationOptions(BaseDataModel):
|
||||
"""LLM 文本生成选项。"""
|
||||
|
||||
temperature: float | None = None
|
||||
max_tokens: int | None = None
|
||||
tool_options: List[ToolDefinitionInput] | None = None
|
||||
response_format: RespFormat | None = None
|
||||
interrupt_flag: asyncio.Event | None = None
|
||||
raise_when_empty: bool = True
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class LLMImageOptions(BaseDataModel):
|
||||
"""LLM 图像理解选项。"""
|
||||
|
||||
temperature: float | None = None
|
||||
max_tokens: int | None = None
|
||||
interrupt_flag: asyncio.Event | None = None
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class LLMAudioTranscriptionResult(BaseDataModel):
|
||||
"""LLM 音频转写结果。"""
|
||||
|
||||
text: str | None = None
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class LLMEmbeddingResult(BaseDataModel):
|
||||
"""LLM 向量生成结果。"""
|
||||
|
||||
embedding: List[float] = field(default_factory=list)
|
||||
model_name: str = field(default_factory=str)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"LLMAudioTranscriptionResult",
|
||||
"LLMEmbeddingResult",
|
||||
"LLMGenerationOptions",
|
||||
"LLMImageOptions",
|
||||
"LLMResponseResult",
|
||||
"LLMServiceRequest",
|
||||
"LLMServiceResult",
|
||||
"MessageFactory",
|
||||
"PromptInput",
|
||||
"PromptMessage",
|
||||
]
|
||||
@@ -1,15 +1,17 @@
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from maim_message import (
|
||||
MessageBase,
|
||||
UserInfo as MaimUserInfo,
|
||||
GroupInfo as MaimGroupInfo,
|
||||
BaseMessageInfo as MaimBaseMessageInfo,
|
||||
Seg,
|
||||
)
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from maim_message import (
|
||||
BaseMessageInfo as MaimBaseMessageInfo,
|
||||
GroupInfo as MaimGroupInfo,
|
||||
MessageBase,
|
||||
ReceiverInfo as MaimReceiverInfo,
|
||||
Seg,
|
||||
SenderInfo as MaimSenderInfo,
|
||||
UserInfo as MaimUserInfo,
|
||||
)
|
||||
|
||||
from src.common.database.database_model import Messages
|
||||
from src.common.data_models.message_component_data_model import MessageSequence
|
||||
@@ -41,34 +43,24 @@ class MessageInfo:
|
||||
class MaiMessage(BaseDatabaseDataModel[Messages]):
|
||||
def __init__(self, message_id: str, timestamp: datetime, platform: str):
|
||||
self.message_id: str = message_id
|
||||
self.timestamp: datetime = timestamp # 时间戳
|
||||
self.initialized = False # 用于标记是否已初始化其他属性
|
||||
self.timestamp: datetime = timestamp
|
||||
self.initialized = False
|
||||
self.platform: str = platform
|
||||
|
||||
# 定义其他属性
|
||||
self.message_info: MessageInfo # 初始化后赋值
|
||||
self.message_info: MessageInfo
|
||||
self.is_mentioned: bool = False
|
||||
"""机器人被提及标记,若被at,则提及也被标记"""
|
||||
self.is_at: bool = False
|
||||
"""机器人被at标记"""
|
||||
self.is_emoji: bool = False
|
||||
"""消息为纯表情包,在计算打字时长时候会被特殊处理"""
|
||||
self.is_picture: bool = False
|
||||
"""消息为纯图片,在计算打字时长时候会被特殊处理"""
|
||||
self.is_command: bool = False
|
||||
"""消息为命令消息,打字时长必定为0"""
|
||||
self.is_notify: bool = False
|
||||
"""消息为通知消息"""
|
||||
|
||||
self.session_id: str
|
||||
self.reply_to: Optional[str] = None
|
||||
|
||||
self.processed_plain_text: Optional[str] = None
|
||||
"""处理过后的纯文本内容"""
|
||||
self.display_message: Optional[str] = None
|
||||
"""最后显示给大模型的消息内容"""
|
||||
self.raw_message: MessageSequence
|
||||
"""原始消息数据"""
|
||||
|
||||
@classmethod
|
||||
def from_db_instance(cls, db_record: "Messages"):
|
||||
@@ -79,12 +71,12 @@ class MaiMessage(BaseDatabaseDataModel[Messages]):
|
||||
group_info = GroupInfo(db_record.group_id, db_record.group_name)
|
||||
else:
|
||||
group_info = None
|
||||
|
||||
obj.message_info = MessageInfo(
|
||||
user_info=user_info,
|
||||
group_info=group_info,
|
||||
additional_config=json.loads(db_record.additional_config) if db_record.additional_config else {},
|
||||
)
|
||||
|
||||
obj.is_mentioned = db_record.is_mentioned
|
||||
obj.is_at = db_record.is_at
|
||||
obj.is_emoji = db_record.is_emoji
|
||||
@@ -127,18 +119,22 @@ class MaiMessage(BaseDatabaseDataModel[Messages]):
|
||||
|
||||
@classmethod
|
||||
def from_maim_message(cls, message: MessageBase):
|
||||
"""从 maim_message.MessageBase 创建 MaiMessage 实例,解析消息内容并提取相关信息"""
|
||||
"""从 maim_message.MessageBase 创建 MaiMessage。"""
|
||||
msg_info = message.message_info
|
||||
assert msg_info, "MessageBase 的 message_info 不能为空"
|
||||
|
||||
platform = msg_info.platform
|
||||
assert isinstance(platform, str)
|
||||
|
||||
msg_id = str(msg_info.message_id)
|
||||
timestamp = msg_info.time
|
||||
assert isinstance(msg_id, str)
|
||||
assert msg_id
|
||||
assert timestamp
|
||||
|
||||
obj = cls(message_id=msg_id, timestamp=datetime.fromtimestamp(timestamp), platform=platform)
|
||||
obj.raw_message = MessageUtils.from_maim_message_segments_to_MaiSeq(message)
|
||||
|
||||
usr_info = msg_info.user_info
|
||||
assert usr_info
|
||||
assert isinstance(usr_info.user_id, str)
|
||||
@@ -148,40 +144,69 @@ class MaiMessage(BaseDatabaseDataModel[Messages]):
|
||||
user_nickname=usr_info.user_nickname,
|
||||
user_cardname=usr_info.user_cardname,
|
||||
)
|
||||
if grp_info := msg_info.group_info:
|
||||
|
||||
if msg_info.group_info:
|
||||
grp_info = msg_info.group_info
|
||||
assert isinstance(grp_info.group_id, str)
|
||||
assert isinstance(grp_info.group_name, str)
|
||||
group_info = GroupInfo(group_id=grp_info.group_id, group_name=grp_info.group_name)
|
||||
else:
|
||||
group_info = None
|
||||
|
||||
add_cfg = msg_info.additional_config or {}
|
||||
obj.message_info = MessageInfo(user_info=user_info, group_info=group_info, additional_config=add_cfg)
|
||||
return obj
|
||||
|
||||
async def to_maim_message(self) -> MessageBase:
|
||||
"""
|
||||
从 MaiMessage 实例转换为 maim_message.MessageBase,构建消息内容并设置相关信息
|
||||
"""
|
||||
maim_user_info = MaimUserInfo(
|
||||
"""将 MaiMessage 转换为 maim_message.MessageBase。"""
|
||||
sender_user_info = MaimUserInfo(
|
||||
user_id=self.message_info.user_info.user_id,
|
||||
user_nickname=self.message_info.user_info.user_nickname,
|
||||
user_cardname=self.message_info.user_info.user_cardname,
|
||||
platform=self.platform,
|
||||
)
|
||||
maim_group_info = None
|
||||
|
||||
sender_group_info = None
|
||||
if self.message_info.group_info:
|
||||
maim_group_info = MaimGroupInfo(
|
||||
sender_group_info = MaimGroupInfo(
|
||||
group_id=self.message_info.group_info.group_id,
|
||||
group_name=self.message_info.group_info.group_name,
|
||||
platform=self.platform,
|
||||
)
|
||||
|
||||
sender_info = MaimSenderInfo(
|
||||
group_info=sender_group_info,
|
||||
user_info=sender_user_info,
|
||||
)
|
||||
|
||||
receiver_group_info = sender_group_info
|
||||
receiver_user_info = None
|
||||
additional_config = self.message_info.additional_config or {}
|
||||
target_user_id = str(additional_config.get("platform_io_target_user_id") or "").strip()
|
||||
if receiver_group_info is None and target_user_id:
|
||||
receiver_user_info = MaimUserInfo(
|
||||
user_id=target_user_id,
|
||||
user_nickname=None,
|
||||
user_cardname=None,
|
||||
platform=self.platform,
|
||||
)
|
||||
|
||||
receiver_info = None
|
||||
if receiver_group_info or receiver_user_info:
|
||||
receiver_info = MaimReceiverInfo(
|
||||
group_info=receiver_group_info,
|
||||
user_info=receiver_user_info,
|
||||
)
|
||||
|
||||
maim_msg_info = MaimBaseMessageInfo(
|
||||
platform=self.platform,
|
||||
message_id=self.message_id,
|
||||
time=self.timestamp.timestamp(),
|
||||
group_info=maim_group_info,
|
||||
user_info=maim_user_info,
|
||||
group_info=receiver_group_info,
|
||||
user_info=sender_user_info,
|
||||
additional_config=self.message_info.additional_config,
|
||||
sender_info=sender_info,
|
||||
receiver_info=receiver_info,
|
||||
)
|
||||
msg_segments = await MessageUtils.from_MaiSeq_to_maim_message_segments(self.raw_message)
|
||||
return MessageBase(message_info=maim_msg_info, message_segment=Seg(type="seglist", data=msg_segments))
|
||||
|
||||
@@ -348,17 +348,11 @@ class MessageSequence:
|
||||
if isinstance(item, TextComponent):
|
||||
return {"type": "text", "data": item.text}
|
||||
elif isinstance(item, ImageComponent):
|
||||
if not item.content:
|
||||
raise RuntimeError("ImageComponent content 未初始化")
|
||||
return {"type": "image", "data": item.content, "hash": item.binary_hash}
|
||||
return {"type": "image", "data": self._ensure_binary_component_content(item, "[图片]"), "hash": item.binary_hash}
|
||||
elif isinstance(item, EmojiComponent):
|
||||
if not item.content:
|
||||
raise RuntimeError("EmojiComponent content 未初始化")
|
||||
return {"type": "emoji", "data": item.content, "hash": item.binary_hash}
|
||||
return {"type": "emoji", "data": self._ensure_binary_component_content(item, "[表情包]"), "hash": item.binary_hash}
|
||||
elif isinstance(item, VoiceComponent):
|
||||
if not item.content:
|
||||
raise RuntimeError("VoiceComponent content 未初始化")
|
||||
return {"type": "voice", "data": item.content, "hash": item.binary_hash}
|
||||
return {"type": "voice", "data": self._ensure_binary_component_content(item, "[语音消息]"), "hash": item.binary_hash}
|
||||
elif isinstance(item, AtComponent):
|
||||
return {
|
||||
"type": "at",
|
||||
@@ -388,6 +382,14 @@ class MessageSequence:
|
||||
logger.warning(f"Unofficial component type: {type(item)}, defaulting to DictComponent")
|
||||
return {"type": "dict", "data": item.data}
|
||||
|
||||
@staticmethod
|
||||
def _ensure_binary_component_content(item: ByteComponent, fallback_text: str) -> str:
|
||||
"""确保二进制组件在序列化时带有稳定的文本占位。"""
|
||||
if item.content:
|
||||
return item.content
|
||||
item.content = fallback_text
|
||||
return item.content
|
||||
|
||||
@classmethod
|
||||
def _dict_2_item(cls, item: Dict[str, Any]) -> StandardMessageComponents:
|
||||
"""内部方法:将单个消息组件的字典格式转换回组件对象"""
|
||||
|
||||
59
src/common/data_models/tool_record_data_model.py
Normal file
59
src/common/data_models/tool_record_data_model.py
Normal file
@@ -0,0 +1,59 @@
|
||||
from datetime import datetime
|
||||
from typing import Dict, Optional
|
||||
|
||||
import json
|
||||
|
||||
from src.common.database.database_model import ToolRecord
|
||||
|
||||
from . import BaseDatabaseDataModel
|
||||
|
||||
|
||||
class MaiToolRecord(BaseDatabaseDataModel[ToolRecord]):
|
||||
"""工具调用记录数据模型。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tool_id: str,
|
||||
timestamp: datetime,
|
||||
session_id: str,
|
||||
tool_name: str,
|
||||
tool_reasoning: Optional[str] = None,
|
||||
tool_data: Optional[Dict] = None,
|
||||
tool_builtin_prompt: Optional[str] = None,
|
||||
tool_display_prompt: Optional[str] = None,
|
||||
):
|
||||
self.tool_id = tool_id
|
||||
self.timestamp = timestamp
|
||||
self.session_id = session_id
|
||||
self.tool_name = tool_name
|
||||
self.tool_reasoning = tool_reasoning
|
||||
self.tool_data = tool_data or {}
|
||||
self.tool_builtin_prompt = tool_builtin_prompt
|
||||
self.tool_display_prompt = tool_display_prompt
|
||||
|
||||
@classmethod
|
||||
def from_db_instance(cls, db_record: ToolRecord):
|
||||
"""从数据库实例创建数据模型对象。"""
|
||||
return cls(
|
||||
tool_id=db_record.tool_id,
|
||||
timestamp=db_record.timestamp,
|
||||
session_id=db_record.session_id,
|
||||
tool_name=db_record.tool_name,
|
||||
tool_reasoning=db_record.tool_reasoning,
|
||||
tool_data=json.loads(db_record.tool_data) if db_record.tool_data else None,
|
||||
tool_builtin_prompt=db_record.tool_builtin_prompt,
|
||||
tool_display_prompt=db_record.tool_display_prompt,
|
||||
)
|
||||
|
||||
def to_db_instance(self):
|
||||
"""将数据模型对象转换为数据库实例。"""
|
||||
return ToolRecord(
|
||||
tool_id=self.tool_id,
|
||||
timestamp=self.timestamp,
|
||||
session_id=self.session_id,
|
||||
tool_name=self.tool_name,
|
||||
tool_reasoning=self.tool_reasoning,
|
||||
tool_data=json.dumps(self.tool_data) if self.tool_data else None,
|
||||
tool_builtin_prompt=self.tool_builtin_prompt,
|
||||
tool_display_prompt=self.tool_display_prompt,
|
||||
)
|
||||
@@ -1,18 +1,23 @@
|
||||
from rich.traceback import install
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Generator, TYPE_CHECKING
|
||||
from typing import ContextManager, Generator, TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import event
|
||||
from rich.traceback import install
|
||||
from sqlalchemy import event, text
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlmodel import SQLModel, Session, create_engine
|
||||
|
||||
from src.common.database.migrations import create_database_migration_bootstrapper
|
||||
from src.common.logger import get_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlite3 import Connection as SQLite3Connection
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
logger = get_logger("database")
|
||||
|
||||
|
||||
# 定义数据库文件路径
|
||||
ROOT_PATH = Path(__file__).parent.parent.parent.parent.absolute().resolve()
|
||||
@@ -53,18 +58,70 @@ SessionLocal = sessionmaker(
|
||||
bind=engine,
|
||||
class_=Session,
|
||||
)
|
||||
_migration_bootstrapper = create_database_migration_bootstrapper(engine)
|
||||
|
||||
_db_initialized = False
|
||||
|
||||
|
||||
def _migrate_action_records_to_tool_records() -> None:
|
||||
"""将旧的 ``action_records`` 历史数据迁移到 ``tool_records``。"""
|
||||
migration_sql = text(
|
||||
"""
|
||||
INSERT INTO tool_records (
|
||||
tool_id,
|
||||
timestamp,
|
||||
session_id,
|
||||
tool_name,
|
||||
tool_reasoning,
|
||||
tool_data,
|
||||
tool_builtin_prompt,
|
||||
tool_display_prompt
|
||||
)
|
||||
SELECT
|
||||
action_id,
|
||||
timestamp,
|
||||
session_id,
|
||||
action_name,
|
||||
action_reasoning,
|
||||
action_data,
|
||||
action_builtin_prompt,
|
||||
action_display_prompt
|
||||
FROM action_records
|
||||
WHERE NOT EXISTS (
|
||||
SELECT 1
|
||||
FROM tool_records
|
||||
WHERE tool_records.tool_id = action_records.action_id
|
||||
)
|
||||
"""
|
||||
)
|
||||
with engine.begin() as connection:
|
||||
connection.execute(migration_sql)
|
||||
|
||||
|
||||
def initialize_database() -> None:
|
||||
"""初始化数据库连接、结构与启动期迁移。
|
||||
|
||||
当前初始化流程遵循以下顺序:
|
||||
1. 确保数据库目录存在;
|
||||
2. 加载 SQLModel 模型定义;
|
||||
3. 执行已注册的启动期迁移;
|
||||
4. 兜底执行 ``create_all`` 确保当前模型定义已建表;
|
||||
5. 执行项目现有的轻量数据补迁移逻辑。
|
||||
"""
|
||||
global _db_initialized
|
||||
if _db_initialized:
|
||||
return
|
||||
_DB_DIR.mkdir(parents=True, exist_ok=True)
|
||||
import src.common.database.database_model # noqa: F401
|
||||
|
||||
migration_state = _migration_bootstrapper.prepare_database()
|
||||
logger.info(
|
||||
"数据库迁移准备完成,"
|
||||
f" 当前版本={migration_state.resolved_version.version},目标版本={migration_state.target_version}"
|
||||
)
|
||||
SQLModel.metadata.create_all(engine)
|
||||
_migrate_action_records_to_tool_records()
|
||||
_migration_bootstrapper.finalize_database(migration_state)
|
||||
_db_initialized = True
|
||||
|
||||
|
||||
@@ -114,8 +171,12 @@ def get_db_session(auto_commit: bool = True) -> Generator[Session, None, None]:
|
||||
session.close()
|
||||
|
||||
|
||||
def get_db_session_manual():
|
||||
"""获取数据库会话的上下文管理器 (手动提交模式)。"""
|
||||
def get_db_session_manual() -> ContextManager[Session]:
|
||||
"""获取数据库会话的上下文管理器 (手动提交模式)。
|
||||
|
||||
Returns:
|
||||
ContextManager[Session]: 手动提交模式的数据库会话上下文管理器。
|
||||
"""
|
||||
return get_db_session(auto_commit=False)
|
||||
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import Column, DateTime, Enum as SQLEnum, Float
|
||||
from sqlalchemy import Column, DateTime, Enum as SQLEnum, Float, Text
|
||||
from sqlmodel import Field, LargeBinary, SQLModel
|
||||
|
||||
|
||||
@@ -17,8 +17,8 @@ class ImageType(str, Enum):
|
||||
|
||||
|
||||
class ModifiedBy(str, Enum):
|
||||
AI = "ai"
|
||||
USER = "user"
|
||||
AI = "AI"
|
||||
USER = "USER"
|
||||
|
||||
|
||||
class Messages(SQLModel, table=True):
|
||||
@@ -134,6 +134,27 @@ class ActionRecord(SQLModel, table=True):
|
||||
action_display_prompt: Optional[str] = Field(default=None) # 最终输入到Prompt的内容
|
||||
|
||||
|
||||
class ToolRecord(SQLModel, table=True):
|
||||
"""存储工具调用记录"""
|
||||
|
||||
__tablename__ = "tool_records" # type: ignore
|
||||
|
||||
id: Optional[int] = Field(default=None, primary_key=True) # 自增主键
|
||||
|
||||
# 元信息
|
||||
tool_id: str = Field(index=True, max_length=255) # 工具调用ID
|
||||
timestamp: datetime = Field(default_factory=datetime.now, sa_column=Column(DateTime, index=True)) # 记录时间戳
|
||||
session_id: str = Field(index=True, max_length=255) # 对应的 ChatSession session_id
|
||||
|
||||
# 调用信息
|
||||
tool_name: str = Field(index=True, max_length=255) # 工具名称
|
||||
tool_reasoning: Optional[str] = Field(default=None) # 工具调用推理过程
|
||||
tool_data: Optional[str] = Field(default=None) # 工具数据,JSON格式存储
|
||||
|
||||
tool_builtin_prompt: Optional[str] = Field(default=None) # 内置工具提示
|
||||
tool_display_prompt: Optional[str] = Field(default=None) # 最终输入到 Prompt 的内容
|
||||
|
||||
|
||||
class CommandRecord(SQLModel, table=True):
|
||||
"""记录命令执行情况"""
|
||||
|
||||
@@ -202,18 +223,40 @@ class Jargon(SQLModel, table=True):
|
||||
id: Optional[int] = Field(default=None, primary_key=True) # 自增主键
|
||||
|
||||
content: str = Field(index=True, max_length=255) # 黑话内容
|
||||
raw_content: Optional[str] = Field(default=None, nullable=True) # 原始内容,未处理的黑话内容,为List[str]
|
||||
raw_content: Optional[str] = Field(
|
||||
default=None, sa_column=Column(Text, nullable=True)
|
||||
) # 原始内容,未处理的黑话内容,为List[str]
|
||||
|
||||
meaning: str # 黑话含义
|
||||
session_id_dict: str = Field(default=r"{}") # 会话ID列表,格式为{"session_id": session_count, ...}
|
||||
meaning: str = Field(sa_column=Column(Text, nullable=False)) # 黑话含义
|
||||
session_id_dict: str = Field(
|
||||
default=r"{}", sa_column=Column(Text, nullable=False)
|
||||
) # 会话ID列表,格式为{"session_id": session_count, ...}
|
||||
|
||||
count: int = Field(default=0) # 使用次数
|
||||
is_jargon: Optional[bool] = Field(default=True) # 是否为黑话,False表示为白话
|
||||
is_complete: bool = Field(default=False) # 是否为已经完成全部推断(count > 100后不再推断)
|
||||
is_global: bool = Field(default=False) # 是否为全局黑话(独立于session_id_dict)
|
||||
last_inference_count: int = Field(default=0) # 上一次进行推断时的count值,用于判断是否需要重新推断
|
||||
inference_with_context: Optional[str] = Field(default=None, nullable=True) # 带上下文的推断结果,JSON格式
|
||||
inference_with_content_only: Optional[str] = Field(default=None, nullable=True) # 只基于词条的推断结果,JSON格式
|
||||
inference_with_context: Optional[str] = Field(
|
||||
default=None, sa_column=Column(Text, nullable=True)
|
||||
) # 带上下文的推断结果,JSON格式
|
||||
inference_with_content_only: Optional[str] = Field(
|
||||
default=None, sa_column=Column(Text, nullable=True)
|
||||
) # 只基于词条的推断结果,JSON格式
|
||||
|
||||
|
||||
class MaiKnowledge(SQLModel, table=True):
|
||||
"""存储 Maisaka 的用户画像知识。"""
|
||||
|
||||
__tablename__ = "mai_knowledge" # type: ignore
|
||||
|
||||
id: Optional[int] = Field(default=None, primary_key=True)
|
||||
knowledge_id: str = Field(index=True, max_length=255)
|
||||
category_id: str = Field(index=True, max_length=32)
|
||||
content: str
|
||||
normalized_content: str = Field(index=True)
|
||||
metadata_json: Optional[str] = Field(default=None, nullable=True)
|
||||
created_at: datetime = Field(default_factory=datetime.now, sa_column=Column(DateTime, index=True))
|
||||
|
||||
|
||||
class ChatHistory(SQLModel, table=True):
|
||||
|
||||
79
src/common/database/migrations/__init__.py
Normal file
79
src/common/database/migrations/__init__.py
Normal file
@@ -0,0 +1,79 @@
|
||||
"""数据库迁移基础设施导出模块。"""
|
||||
|
||||
from .bootstrap import DatabaseMigrationBootstrapper, create_database_migration_bootstrapper
|
||||
from .builtin import (
|
||||
EMPTY_SCHEMA_VERSION,
|
||||
LATEST_SCHEMA_VERSION,
|
||||
LEGACY_V1_SCHEMA_VERSION,
|
||||
build_default_migration_registry,
|
||||
build_default_schema_version_resolver,
|
||||
)
|
||||
from .exceptions import (
|
||||
DatabaseMigrationConfigurationError,
|
||||
DatabaseMigrationError,
|
||||
DatabaseMigrationExecutionError,
|
||||
DatabaseMigrationPlanningError,
|
||||
DatabaseMigrationVersionError,
|
||||
MissingMigrationStepError,
|
||||
UnrecognizedDatabaseSchemaError,
|
||||
UnsupportedMigrationDirectionError,
|
||||
)
|
||||
from .manager import DatabaseMigrationManager
|
||||
from .models import (
|
||||
ColumnSchema,
|
||||
DatabaseMigrationState,
|
||||
DatabaseSchemaSnapshot,
|
||||
MigrationExecutionContext,
|
||||
MigrationPlan,
|
||||
MigrationStep,
|
||||
ResolvedSchemaVersion,
|
||||
SchemaVersionSource,
|
||||
TableSchema,
|
||||
)
|
||||
from .planner import MigrationPlanner
|
||||
from .progress import (
|
||||
BaseMigrationProgressReporter,
|
||||
RichMigrationProgressReporter,
|
||||
create_rich_migration_progress_reporter,
|
||||
)
|
||||
from .registry import MigrationRegistry
|
||||
from .resolver import BaseSchemaVersionDetector, SchemaVersionResolver
|
||||
from .schema import SQLiteSchemaInspector
|
||||
from .version_store import SQLiteUserVersionStore
|
||||
|
||||
__all__ = [
|
||||
"BaseSchemaVersionDetector",
|
||||
"BaseMigrationProgressReporter",
|
||||
"build_default_migration_registry",
|
||||
"build_default_schema_version_resolver",
|
||||
"ColumnSchema",
|
||||
"create_database_migration_bootstrapper",
|
||||
"create_rich_migration_progress_reporter",
|
||||
"DatabaseMigrationConfigurationError",
|
||||
"DatabaseMigrationError",
|
||||
"DatabaseMigrationBootstrapper",
|
||||
"DatabaseMigrationExecutionError",
|
||||
"DatabaseMigrationManager",
|
||||
"DatabaseMigrationPlanningError",
|
||||
"DatabaseMigrationState",
|
||||
"DatabaseMigrationVersionError",
|
||||
"DatabaseSchemaSnapshot",
|
||||
"EMPTY_SCHEMA_VERSION",
|
||||
"LATEST_SCHEMA_VERSION",
|
||||
"LEGACY_V1_SCHEMA_VERSION",
|
||||
"MigrationExecutionContext",
|
||||
"MigrationPlan",
|
||||
"MigrationPlanner",
|
||||
"MigrationRegistry",
|
||||
"MigrationStep",
|
||||
"MissingMigrationStepError",
|
||||
"ResolvedSchemaVersion",
|
||||
"RichMigrationProgressReporter",
|
||||
"SchemaVersionResolver",
|
||||
"SchemaVersionSource",
|
||||
"SQLiteSchemaInspector",
|
||||
"SQLiteUserVersionStore",
|
||||
"TableSchema",
|
||||
"UnrecognizedDatabaseSchemaError",
|
||||
"UnsupportedMigrationDirectionError",
|
||||
]
|
||||
171
src/common/database/migrations/bootstrap.py
Normal file
171
src/common/database/migrations/bootstrap.py
Normal file
@@ -0,0 +1,171 @@
|
||||
"""数据库迁移启动桥接层。"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy.engine import Engine
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
from .builtin import (
|
||||
LATEST_SCHEMA_VERSION,
|
||||
build_default_migration_registry,
|
||||
build_default_schema_version_resolver,
|
||||
)
|
||||
from .exceptions import DatabaseMigrationExecutionError
|
||||
from .manager import DatabaseMigrationManager
|
||||
from .models import DatabaseMigrationState, MigrationPlan, ResolvedSchemaVersion, SchemaVersionSource
|
||||
from .registry import MigrationRegistry
|
||||
from .resolver import SchemaVersionResolver
|
||||
from .version_store import SQLiteUserVersionStore
|
||||
|
||||
logger = get_logger("database_migration")
|
||||
|
||||
|
||||
class DatabaseMigrationBootstrapper:
|
||||
"""数据库迁移启动桥接器。
|
||||
|
||||
该桥接器负责把数据库迁移基础设施接入现有启动流程,同时保持如下约束:
|
||||
1. 若数据库为空,则直接交给当前模型定义建出最新结构;
|
||||
2. 若数据库版本高于当前代码支持的最新版本,则立即终止启动;
|
||||
3. 若存在待执行迁移步骤,则在正常建表流程之前先执行迁移;
|
||||
4. 若数据库已是最新结构但尚未写入 ``user_version``,则在建表后补写版本号。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
manager: DatabaseMigrationManager,
|
||||
latest_schema_version: int = LATEST_SCHEMA_VERSION,
|
||||
) -> None:
|
||||
"""初始化数据库迁移启动桥接器。
|
||||
|
||||
Args:
|
||||
manager: 数据库迁移编排器。
|
||||
latest_schema_version: 当前代码支持的最新 schema 版本号。
|
||||
"""
|
||||
self.manager = manager
|
||||
self.latest_schema_version = latest_schema_version
|
||||
|
||||
def prepare_database(self) -> DatabaseMigrationState:
|
||||
"""为数据库初始化阶段准备迁移状态。
|
||||
|
||||
Returns:
|
||||
DatabaseMigrationState: 迁移准备完成后的数据库状态。
|
||||
|
||||
Raises:
|
||||
DatabaseMigrationExecutionError: 当数据库版本高于当前代码支持版本时抛出。
|
||||
"""
|
||||
with self.manager.engine.connect() as connection:
|
||||
resolved_version = self.manager.resolver.resolve(connection)
|
||||
|
||||
if resolved_version.version > self.latest_schema_version:
|
||||
raise DatabaseMigrationExecutionError(
|
||||
"当前数据库版本高于代码内注册的最新迁移版本,已拒绝继续启动。"
|
||||
f" 数据库版本={resolved_version.version},代码支持版本={self.latest_schema_version}"
|
||||
)
|
||||
|
||||
if resolved_version.source == SchemaVersionSource.EMPTY_DATABASE:
|
||||
logger.info(
|
||||
"检测到空数据库,将直接根据当前模型创建最新结构。"
|
||||
f" 目标版本={self.latest_schema_version}"
|
||||
)
|
||||
return self._build_noop_state(
|
||||
current_version=resolved_version.version,
|
||||
target_version=self.latest_schema_version,
|
||||
resolved_state=resolved_version,
|
||||
)
|
||||
|
||||
migration_state = self.manager.describe_state(target_version=self.latest_schema_version)
|
||||
if not migration_state.requires_migration():
|
||||
logger.info(
|
||||
f"数据库 schema 已是目标版本,无需迁移。当前版本={migration_state.resolved_version.version}"
|
||||
)
|
||||
return migration_state
|
||||
|
||||
logger.info(
|
||||
"检测到数据库需要迁移,"
|
||||
f" 当前版本={migration_state.resolved_version.version},目标版本={migration_state.target_version}"
|
||||
)
|
||||
self.manager.migrate(target_version=self.latest_schema_version)
|
||||
return self.manager.describe_state(target_version=self.latest_schema_version)
|
||||
|
||||
def finalize_database(self, migration_state: DatabaseMigrationState) -> None:
|
||||
"""在数据库初始化末尾补写最终 schema 版本号。
|
||||
|
||||
该方法主要负责两类场景:
|
||||
1. 空库首次建表完成后,将 ``user_version`` 写入为最新版本;
|
||||
2. 已是最新结构但此前未写入 ``user_version`` 的数据库,补写版本号。
|
||||
|
||||
Args:
|
||||
migration_state: 初始化前解析得到的迁移状态。
|
||||
"""
|
||||
if migration_state.requires_migration():
|
||||
return
|
||||
if migration_state.target_version <= 0:
|
||||
return
|
||||
if migration_state.resolved_version.source == SchemaVersionSource.PRAGMA:
|
||||
return
|
||||
|
||||
with self.manager.engine.begin() as connection:
|
||||
self.manager.version_store.write_version(connection, migration_state.target_version)
|
||||
|
||||
logger.info(
|
||||
"数据库 schema 版本写入完成。"
|
||||
f" 来源={migration_state.resolved_version.source.value},"
|
||||
f" 写入版本={migration_state.target_version}"
|
||||
)
|
||||
|
||||
def _build_noop_state(
|
||||
self,
|
||||
current_version: int,
|
||||
target_version: int,
|
||||
resolved_state: ResolvedSchemaVersion,
|
||||
) -> DatabaseMigrationState:
|
||||
"""构建无迁移动作的数据库状态对象。
|
||||
|
||||
Args:
|
||||
current_version: 当前数据库版本号。
|
||||
target_version: 当前初始化流程期望达到的目标版本号。
|
||||
resolved_state: 已解析的数据库版本状态。
|
||||
|
||||
Returns:
|
||||
DatabaseMigrationState: 不包含迁移步骤的状态对象。
|
||||
"""
|
||||
return DatabaseMigrationState(
|
||||
resolved_version=resolved_state,
|
||||
target_version=target_version,
|
||||
plan=MigrationPlan(current_version=current_version, target_version=target_version, steps=[]),
|
||||
)
|
||||
|
||||
|
||||
def create_database_migration_bootstrapper(
|
||||
engine: Engine,
|
||||
registry: Optional[MigrationRegistry] = None,
|
||||
resolver: Optional[SchemaVersionResolver] = None,
|
||||
version_store: Optional[SQLiteUserVersionStore] = None,
|
||||
latest_schema_version: int = LATEST_SCHEMA_VERSION,
|
||||
) -> DatabaseMigrationBootstrapper:
|
||||
"""创建数据库迁移启动桥接器。
|
||||
|
||||
Args:
|
||||
engine: 目标数据库引擎。
|
||||
registry: 迁移步骤注册表;未提供时使用默认注册表。
|
||||
resolver: 数据库版本解析器;未提供时使用默认解析器。
|
||||
version_store: 版本存储器;未提供时使用默认存储器。
|
||||
latest_schema_version: 当前代码支持的最新 schema 版本号。
|
||||
|
||||
Returns:
|
||||
DatabaseMigrationBootstrapper: 配置完成的数据库迁移启动桥接器。
|
||||
"""
|
||||
migration_registry = registry or build_default_migration_registry()
|
||||
migration_resolver = resolver or build_default_schema_version_resolver()
|
||||
migration_version_store = version_store or SQLiteUserVersionStore()
|
||||
migration_manager = DatabaseMigrationManager(
|
||||
engine=engine,
|
||||
registry=migration_registry,
|
||||
resolver=migration_resolver,
|
||||
version_store=migration_version_store,
|
||||
)
|
||||
return DatabaseMigrationBootstrapper(
|
||||
manager=migration_manager,
|
||||
latest_schema_version=latest_schema_version,
|
||||
)
|
||||
159
src/common/database/migrations/builtin.py
Normal file
159
src/common/database/migrations/builtin.py
Normal file
@@ -0,0 +1,159 @@
|
||||
"""数据库迁移内置版本与默认注册表。"""
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from .legacy_v1_to_v2 import migrate_legacy_v1_to_v2
|
||||
from .models import DatabaseSchemaSnapshot, MigrationStep
|
||||
from .registry import MigrationRegistry
|
||||
from .resolver import BaseSchemaVersionDetector, SchemaVersionResolver
|
||||
from .version_store import SQLiteUserVersionStore
|
||||
from .schema import SQLiteSchemaInspector
|
||||
|
||||
EMPTY_SCHEMA_VERSION = 0
|
||||
LEGACY_V1_SCHEMA_VERSION = 1
|
||||
LATEST_SCHEMA_VERSION = 2
|
||||
|
||||
_LEGACY_V1_EXCLUSIVE_TABLES = (
|
||||
"chat_streams",
|
||||
"emoji",
|
||||
"emoji_description_cache",
|
||||
"expression",
|
||||
"group_info",
|
||||
"image_descriptions",
|
||||
"jargon",
|
||||
"messages",
|
||||
"thinking_back",
|
||||
)
|
||||
|
||||
|
||||
class LatestSchemaVersionDetector(BaseSchemaVersionDetector):
|
||||
"""当前最新 schema 结构探测器。"""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""返回探测器名称。
|
||||
|
||||
Returns:
|
||||
str: 当前探测器名称。
|
||||
"""
|
||||
return "latest_schema_detector"
|
||||
|
||||
def detect_version(self, snapshot: DatabaseSchemaSnapshot) -> Optional[int]:
|
||||
"""检测数据库是否已经是当前最新结构。
|
||||
|
||||
Args:
|
||||
snapshot: 当前数据库结构快照。
|
||||
|
||||
Returns:
|
||||
Optional[int]: 若识别为最新结构则返回最新版本号,否则返回 ``None``。
|
||||
"""
|
||||
if any(snapshot.has_table(table_name) for table_name in _LEGACY_V1_EXCLUSIVE_TABLES):
|
||||
return None
|
||||
|
||||
latest_marker_tables = (
|
||||
"mai_messages",
|
||||
"chat_sessions",
|
||||
"expressions",
|
||||
"jargons",
|
||||
"thinking_questions",
|
||||
"tool_records",
|
||||
)
|
||||
if not all(snapshot.has_table(table_name) for table_name in latest_marker_tables):
|
||||
return None
|
||||
if not snapshot.has_column("images", "image_hash"):
|
||||
return None
|
||||
if not snapshot.has_column("images", "full_path"):
|
||||
return None
|
||||
if not snapshot.has_column("images", "image_type"):
|
||||
return None
|
||||
if not snapshot.has_column("action_records", "session_id"):
|
||||
return None
|
||||
if not snapshot.has_column("chat_history", "session_id"):
|
||||
return None
|
||||
if not snapshot.has_column("person_info", "user_nickname"):
|
||||
return None
|
||||
return LATEST_SCHEMA_VERSION
|
||||
|
||||
|
||||
class LegacyV1SchemaDetector(BaseSchemaVersionDetector):
|
||||
"""旧版 ``0.x`` schema 结构探测器。"""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""返回探测器名称。
|
||||
|
||||
Returns:
|
||||
str: 当前探测器名称。
|
||||
"""
|
||||
return "legacy_v1_schema_detector"
|
||||
|
||||
def detect_version(self, snapshot: DatabaseSchemaSnapshot) -> Optional[int]:
|
||||
"""检测数据库是否为旧版 ``0.x`` 结构。
|
||||
|
||||
Args:
|
||||
snapshot: 当前数据库结构快照。
|
||||
|
||||
Returns:
|
||||
Optional[int]: 若识别为旧版结构则返回 ``1``,否则返回 ``None``。
|
||||
"""
|
||||
if any(snapshot.has_table(table_name) for table_name in _LEGACY_V1_EXCLUSIVE_TABLES):
|
||||
return LEGACY_V1_SCHEMA_VERSION
|
||||
|
||||
legacy_shared_markers = (
|
||||
("action_records", ("chat_id", "time")),
|
||||
("chat_history", ("chat_id", "original_text")),
|
||||
("images", ("emoji_hash", "path", "type")),
|
||||
("llm_usage", ("model_api_provider", "status")),
|
||||
("online_time", ("duration",)),
|
||||
("person_info", ("nickname", "group_nick_name")),
|
||||
)
|
||||
for table_name, required_columns in legacy_shared_markers:
|
||||
if snapshot.has_table(table_name) and all(
|
||||
snapshot.has_column(table_name, column_name) for column_name in required_columns
|
||||
):
|
||||
return LEGACY_V1_SCHEMA_VERSION
|
||||
return None
|
||||
|
||||
|
||||
def build_default_schema_version_detectors() -> List[BaseSchemaVersionDetector]:
|
||||
"""构建默认 schema 版本探测器链。
|
||||
|
||||
Returns:
|
||||
List[BaseSchemaVersionDetector]: 按优先级排序的探测器列表。
|
||||
"""
|
||||
return [
|
||||
LatestSchemaVersionDetector(),
|
||||
LegacyV1SchemaDetector(),
|
||||
]
|
||||
|
||||
|
||||
def build_default_schema_version_resolver() -> SchemaVersionResolver:
|
||||
"""构建默认 schema 版本解析器。
|
||||
|
||||
Returns:
|
||||
SchemaVersionResolver: 配置完成的 schema 版本解析器。
|
||||
"""
|
||||
return SchemaVersionResolver(
|
||||
version_store=SQLiteUserVersionStore(),
|
||||
schema_inspector=SQLiteSchemaInspector(),
|
||||
detectors=build_default_schema_version_detectors(),
|
||||
)
|
||||
|
||||
|
||||
def build_default_migration_registry() -> MigrationRegistry:
|
||||
"""构建默认迁移步骤注册表。
|
||||
|
||||
Returns:
|
||||
MigrationRegistry: 含默认迁移步骤的注册表实例。
|
||||
"""
|
||||
return MigrationRegistry(
|
||||
steps=[
|
||||
MigrationStep(
|
||||
version_from=LEGACY_V1_SCHEMA_VERSION,
|
||||
version_to=LATEST_SCHEMA_VERSION,
|
||||
name="legacy_v1_to_latest_v2",
|
||||
description="将旧版 0.x 数据库整体迁移到当前最新 schema。",
|
||||
handler=migrate_legacy_v1_to_v2,
|
||||
)
|
||||
]
|
||||
)
|
||||
33
src/common/database/migrations/exceptions.py
Normal file
33
src/common/database/migrations/exceptions.py
Normal file
@@ -0,0 +1,33 @@
|
||||
"""数据库迁移基础设施异常定义。"""
|
||||
|
||||
|
||||
class DatabaseMigrationError(Exception):
|
||||
"""数据库迁移基础异常。"""
|
||||
|
||||
|
||||
class DatabaseMigrationConfigurationError(DatabaseMigrationError):
|
||||
"""数据库迁移配置不合法。"""
|
||||
|
||||
|
||||
class DatabaseMigrationPlanningError(DatabaseMigrationError):
|
||||
"""数据库迁移计划生成失败。"""
|
||||
|
||||
|
||||
class DatabaseMigrationExecutionError(DatabaseMigrationError):
|
||||
"""数据库迁移执行失败。"""
|
||||
|
||||
|
||||
class DatabaseMigrationVersionError(DatabaseMigrationError):
|
||||
"""数据库版本读写或校验失败。"""
|
||||
|
||||
|
||||
class MissingMigrationStepError(DatabaseMigrationPlanningError):
|
||||
"""缺少某个版本区间所需的迁移步骤。"""
|
||||
|
||||
|
||||
class UnsupportedMigrationDirectionError(DatabaseMigrationPlanningError):
|
||||
"""当前迁移方向不被支持。"""
|
||||
|
||||
|
||||
class UnrecognizedDatabaseSchemaError(DatabaseMigrationVersionError):
|
||||
"""无法识别未标记版本数据库的结构。"""
|
||||
1489
src/common/database/migrations/legacy_v1_to_v2.py
Normal file
1489
src/common/database/migrations/legacy_v1_to_v2.py
Normal file
File diff suppressed because it is too large
Load Diff
205
src/common/database/migrations/manager.py
Normal file
205
src/common/database/migrations/manager.py
Normal file
@@ -0,0 +1,205 @@
|
||||
"""数据库迁移编排器。"""
|
||||
|
||||
from typing import Callable, Optional
|
||||
|
||||
from sqlalchemy.engine import Connection, Engine
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
from .exceptions import DatabaseMigrationExecutionError
|
||||
from .models import DatabaseMigrationState, MigrationExecutionContext, MigrationPlan
|
||||
from .planner import MigrationPlanner
|
||||
from .progress import BaseMigrationProgressReporter, create_rich_migration_progress_reporter
|
||||
from .registry import MigrationRegistry
|
||||
from .resolver import SchemaVersionResolver
|
||||
from .version_store import SQLiteUserVersionStore
|
||||
|
||||
logger = get_logger("database_migration")
|
||||
|
||||
|
||||
class DatabaseMigrationManager:
|
||||
"""数据库迁移编排器。
|
||||
|
||||
该类只负责基础设施层面的编排工作,包括:
|
||||
1. 解析当前数据库版本;
|
||||
2. 生成迁移计划;
|
||||
3. 顺序执行已注册迁移步骤;
|
||||
4. 在每一步成功后更新 ``user_version``。
|
||||
|
||||
当前模块不内置任何业务迁移步骤,也不会自动接入项目启动流程。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine: Engine,
|
||||
registry: Optional[MigrationRegistry] = None,
|
||||
planner: Optional[MigrationPlanner] = None,
|
||||
resolver: Optional[SchemaVersionResolver] = None,
|
||||
version_store: Optional[SQLiteUserVersionStore] = None,
|
||||
progress_reporter_factory: Optional[Callable[[], BaseMigrationProgressReporter]] = None,
|
||||
) -> None:
|
||||
"""初始化数据库迁移编排器。
|
||||
|
||||
Args:
|
||||
engine: 目标数据库引擎。
|
||||
registry: 迁移步骤注册表。
|
||||
planner: 迁移计划生成器。
|
||||
resolver: 数据库版本解析器。
|
||||
version_store: 版本存储器。
|
||||
progress_reporter_factory: 迁移进度上报器工厂。
|
||||
"""
|
||||
self.engine = engine
|
||||
self.registry = registry or MigrationRegistry()
|
||||
self.planner = planner or MigrationPlanner()
|
||||
self.resolver = resolver or SchemaVersionResolver()
|
||||
self.version_store = version_store or SQLiteUserVersionStore()
|
||||
self.progress_reporter_factory = progress_reporter_factory or create_rich_migration_progress_reporter
|
||||
|
||||
def describe_state(self, target_version: Optional[int] = None) -> DatabaseMigrationState:
|
||||
"""描述当前数据库的迁移状态。
|
||||
|
||||
Args:
|
||||
target_version: 目标数据库版本;未提供时使用注册表中声明的最新版本。
|
||||
|
||||
Returns:
|
||||
DatabaseMigrationState: 当前数据库迁移状态。
|
||||
"""
|
||||
with self.engine.connect() as connection:
|
||||
resolved_version = self.resolver.resolve(connection)
|
||||
|
||||
effective_target_version = self._resolve_target_version(target_version)
|
||||
migration_plan = self.planner.plan(
|
||||
current_version=resolved_version.version,
|
||||
target_version=effective_target_version,
|
||||
registry=self.registry,
|
||||
)
|
||||
return DatabaseMigrationState(
|
||||
resolved_version=resolved_version,
|
||||
target_version=effective_target_version,
|
||||
plan=migration_plan,
|
||||
)
|
||||
|
||||
def plan(self, target_version: Optional[int] = None) -> MigrationPlan:
|
||||
"""生成当前数据库的迁移计划。
|
||||
|
||||
Args:
|
||||
target_version: 目标数据库版本;未提供时使用注册表中声明的最新版本。
|
||||
|
||||
Returns:
|
||||
MigrationPlan: 当前数据库对应的迁移计划。
|
||||
"""
|
||||
return self.describe_state(target_version=target_version).plan
|
||||
|
||||
def migrate(self, target_version: Optional[int] = None) -> MigrationPlan:
|
||||
"""执行迁移计划。
|
||||
|
||||
注意:
|
||||
若当前数据库是通过结构探测得出的版本,且计划为空,本方法不会自动把该
|
||||
版本写回 ``user_version``。这样做是为了避免在尚未明确接入策略前引入隐式
|
||||
副作用。
|
||||
|
||||
Args:
|
||||
target_version: 目标数据库版本;未提供时使用注册表中声明的最新版本。
|
||||
|
||||
Returns:
|
||||
MigrationPlan: 已执行的迁移计划。
|
||||
|
||||
Raises:
|
||||
DatabaseMigrationExecutionError: 当迁移步骤执行失败时抛出。
|
||||
"""
|
||||
migration_state = self.describe_state(target_version=target_version)
|
||||
migration_plan = migration_state.plan
|
||||
if migration_plan.is_empty():
|
||||
logger.info("数据库迁移计划为空,跳过执行。")
|
||||
return migration_plan
|
||||
|
||||
current_version = migration_state.resolved_version.version
|
||||
total_steps = migration_plan.step_count()
|
||||
for step_index, step in enumerate(migration_plan.steps, start=1):
|
||||
logger.info(
|
||||
f"开始执行数据库迁移步骤: {step.name} ({step.version_from} -> {step.version_to})"
|
||||
)
|
||||
try:
|
||||
with self.progress_reporter_factory() as progress_reporter:
|
||||
if step.transactional:
|
||||
with self.engine.begin() as connection:
|
||||
execution_context = self._build_execution_context(
|
||||
connection=connection,
|
||||
current_version=current_version,
|
||||
migration_plan=migration_plan,
|
||||
step_index=step_index,
|
||||
step_name=step.name,
|
||||
total_steps=total_steps,
|
||||
progress_reporter=progress_reporter,
|
||||
)
|
||||
step.run(execution_context)
|
||||
self.version_store.write_version(connection, step.version_to)
|
||||
else:
|
||||
with self.engine.connect() as connection:
|
||||
execution_context = self._build_execution_context(
|
||||
connection=connection,
|
||||
current_version=current_version,
|
||||
migration_plan=migration_plan,
|
||||
step_index=step_index,
|
||||
step_name=step.name,
|
||||
total_steps=total_steps,
|
||||
progress_reporter=progress_reporter,
|
||||
)
|
||||
step.run(execution_context)
|
||||
self.version_store.write_version(connection, step.version_to)
|
||||
connection.commit()
|
||||
except Exception as exc:
|
||||
raise DatabaseMigrationExecutionError(
|
||||
f"执行迁移步骤 {step.name} ({step.version_from} -> {step.version_to}) 失败。"
|
||||
) from exc
|
||||
current_version = step.version_to
|
||||
logger.info(f"数据库迁移步骤执行完成: {step.name},当前版本已更新为 {current_version}")
|
||||
|
||||
return migration_plan
|
||||
|
||||
def _resolve_target_version(self, target_version: Optional[int]) -> int:
|
||||
"""解析最终目标版本号。
|
||||
|
||||
Args:
|
||||
target_version: 调用方显式指定的目标版本。
|
||||
|
||||
Returns:
|
||||
int: 最终用于规划和执行的目标版本号。
|
||||
"""
|
||||
if target_version is not None:
|
||||
return target_version
|
||||
return self.registry.latest_version()
|
||||
|
||||
def _build_execution_context(
|
||||
self,
|
||||
connection: Connection,
|
||||
current_version: int,
|
||||
migration_plan: MigrationPlan,
|
||||
step_index: int,
|
||||
step_name: str,
|
||||
total_steps: int,
|
||||
progress_reporter: BaseMigrationProgressReporter,
|
||||
) -> MigrationExecutionContext:
|
||||
"""构建单个迁移步骤的执行上下文。
|
||||
|
||||
Args:
|
||||
connection: 当前迁移步骤使用的数据库连接。
|
||||
current_version: 当前数据库版本。
|
||||
migration_plan: 当前迁移计划。
|
||||
step_index: 当前步骤序号,从 ``1`` 开始。
|
||||
step_name: 当前步骤名称。
|
||||
total_steps: 计划总步骤数。
|
||||
progress_reporter: 当前步骤使用的进度上报器。
|
||||
|
||||
Returns:
|
||||
MigrationExecutionContext: 当前步骤的执行上下文对象。
|
||||
"""
|
||||
return MigrationExecutionContext(
|
||||
connection=connection,
|
||||
current_version=current_version,
|
||||
target_version=migration_plan.target_version,
|
||||
step_index=step_index,
|
||||
step_name=step_name,
|
||||
total_steps=total_steps,
|
||||
progress_reporter=progress_reporter,
|
||||
)
|
||||
305
src/common/database/migrations/models.py
Normal file
305
src/common/database/migrations/models.py
Normal file
@@ -0,0 +1,305 @@
|
||||
"""数据库迁移基础设施核心数据模型。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
from typing import Callable, Dict, List, Optional, TYPE_CHECKING
|
||||
|
||||
from sqlalchemy.engine import Connection
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .progress import BaseMigrationProgressReporter
|
||||
|
||||
|
||||
def _utc_now() -> datetime:
|
||||
"""返回当前 UTC 时间。
|
||||
|
||||
Returns:
|
||||
datetime: 当前 UTC 时间。
|
||||
"""
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
class SchemaVersionSource(str, Enum):
|
||||
"""数据库版本来源。"""
|
||||
|
||||
PRAGMA = "pragma"
|
||||
DETECTOR = "detector"
|
||||
EMPTY_DATABASE = "empty_database"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ColumnSchema:
|
||||
"""数据库列结构快照。"""
|
||||
|
||||
name: str
|
||||
declared_type: str
|
||||
default_value: Optional[str]
|
||||
is_not_null: bool
|
||||
primary_key_position: int
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TableSchema:
|
||||
"""数据库表结构快照。"""
|
||||
|
||||
name: str
|
||||
columns: Dict[str, ColumnSchema]
|
||||
|
||||
def has_column(self, column_name: str) -> bool:
|
||||
"""判断表中是否存在指定列。
|
||||
|
||||
Args:
|
||||
column_name: 待检查的列名。
|
||||
|
||||
Returns:
|
||||
bool: 若列存在则返回 ``True``,否则返回 ``False``。
|
||||
"""
|
||||
return column_name in self.columns
|
||||
|
||||
def get_column(self, column_name: str) -> Optional[ColumnSchema]:
|
||||
"""获取指定列的结构信息。
|
||||
|
||||
Args:
|
||||
column_name: 待获取的列名。
|
||||
|
||||
Returns:
|
||||
Optional[ColumnSchema]: 列存在时返回列结构,否则返回 ``None``。
|
||||
"""
|
||||
return self.columns.get(column_name)
|
||||
|
||||
def column_names(self) -> List[str]:
|
||||
"""返回当前表中全部列名。
|
||||
|
||||
Returns:
|
||||
List[str]: 按字母顺序排列的列名列表。
|
||||
"""
|
||||
return sorted(self.columns)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DatabaseSchemaSnapshot:
|
||||
"""数据库结构快照。"""
|
||||
|
||||
tables: Dict[str, TableSchema]
|
||||
|
||||
def is_empty(self) -> bool:
|
||||
"""判断数据库是否没有任何用户表。
|
||||
|
||||
Returns:
|
||||
bool: 若数据库中没有用户表则返回 ``True``。
|
||||
"""
|
||||
return not self.tables
|
||||
|
||||
def has_table(self, table_name: str) -> bool:
|
||||
"""判断数据库是否存在指定表。
|
||||
|
||||
Args:
|
||||
table_name: 待检查的表名。
|
||||
|
||||
Returns:
|
||||
bool: 若表存在则返回 ``True``,否则返回 ``False``。
|
||||
"""
|
||||
return table_name in self.tables
|
||||
|
||||
def has_column(self, table_name: str, column_name: str) -> bool:
|
||||
"""判断数据库指定表中是否存在指定列。
|
||||
|
||||
Args:
|
||||
table_name: 待检查的表名。
|
||||
column_name: 待检查的列名。
|
||||
|
||||
Returns:
|
||||
bool: 若表和列均存在则返回 ``True``。
|
||||
"""
|
||||
table_schema = self.get_table(table_name)
|
||||
if table_schema is None:
|
||||
return False
|
||||
return table_schema.has_column(column_name)
|
||||
|
||||
def get_table(self, table_name: str) -> Optional[TableSchema]:
|
||||
"""获取指定表的结构信息。
|
||||
|
||||
Args:
|
||||
table_name: 待获取的表名。
|
||||
|
||||
Returns:
|
||||
Optional[TableSchema]: 表存在时返回对应结构,否则返回 ``None``。
|
||||
"""
|
||||
return self.tables.get(table_name)
|
||||
|
||||
def table_names(self) -> List[str]:
|
||||
"""返回当前数据库中的全部用户表名。
|
||||
|
||||
Returns:
|
||||
List[str]: 按字母顺序排列的表名列表。
|
||||
"""
|
||||
return sorted(self.tables)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ResolvedSchemaVersion:
|
||||
"""解析后的数据库版本信息。"""
|
||||
|
||||
version: int
|
||||
source: SchemaVersionSource
|
||||
detector_name: Optional[str] = None
|
||||
snapshot: Optional[DatabaseSchemaSnapshot] = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class MigrationExecutionContext:
|
||||
"""单个迁移步骤的执行上下文。"""
|
||||
|
||||
connection: Connection
|
||||
current_version: int
|
||||
target_version: int
|
||||
step_index: int
|
||||
step_name: str
|
||||
total_steps: int
|
||||
started_at: datetime = field(default_factory=_utc_now)
|
||||
progress_reporter: Optional["BaseMigrationProgressReporter"] = None
|
||||
|
||||
def is_last_step(self) -> bool:
|
||||
"""判断当前步骤是否为最后一步。
|
||||
|
||||
Returns:
|
||||
bool: 若当前步骤已是计划中的最后一步则返回 ``True``。
|
||||
"""
|
||||
return self.step_index >= self.total_steps
|
||||
|
||||
def start_progress(
|
||||
self,
|
||||
total_tables: int,
|
||||
total_records: int,
|
||||
description: str = "总迁移进度",
|
||||
table_unit_name: str = "表",
|
||||
record_unit_name: str = "记录",
|
||||
) -> None:
|
||||
"""启动当前迁移步骤的进度展示。
|
||||
|
||||
Args:
|
||||
total_tables: 当前步骤需要处理的总表数。
|
||||
total_records: 当前步骤需要处理的总记录数。
|
||||
description: 进度描述文本。
|
||||
table_unit_name: 表级进度单位名称。
|
||||
record_unit_name: 记录级进度单位名称。
|
||||
"""
|
||||
if self.progress_reporter is None:
|
||||
return
|
||||
self.progress_reporter.start(
|
||||
total_records=total_records,
|
||||
total_tables=total_tables,
|
||||
description=description,
|
||||
table_unit_name=table_unit_name,
|
||||
record_unit_name=record_unit_name,
|
||||
)
|
||||
|
||||
def advance_progress(
|
||||
self,
|
||||
records: int = 0,
|
||||
completed_tables: int = 0,
|
||||
item_name: Optional[str] = None,
|
||||
) -> None:
|
||||
"""推进当前迁移步骤的进度展示。
|
||||
|
||||
Args:
|
||||
records: 本次推进的记录数。
|
||||
completed_tables: 本次完成的表数。
|
||||
item_name: 当前完成的项目名称。
|
||||
"""
|
||||
if self.progress_reporter is None:
|
||||
return
|
||||
self.progress_reporter.advance(
|
||||
records=records,
|
||||
completed_tables=completed_tables,
|
||||
item_name=item_name,
|
||||
)
|
||||
|
||||
|
||||
MigrationHandler = Callable[[MigrationExecutionContext], None]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class MigrationStep:
|
||||
"""单个数据库迁移步骤定义。"""
|
||||
|
||||
version_from: int
|
||||
version_to: int
|
||||
name: str
|
||||
description: str
|
||||
handler: MigrationHandler
|
||||
transactional: bool = True
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""校验迁移步骤定义是否合法。
|
||||
|
||||
Raises:
|
||||
ValueError: 当版本号不合法或迁移方向错误时抛出。
|
||||
"""
|
||||
if self.version_from < 0:
|
||||
raise ValueError("迁移起始版本不能小于 0。")
|
||||
if self.version_to <= self.version_from:
|
||||
raise ValueError("迁移目标版本必须大于起始版本。")
|
||||
|
||||
def run(self, context: MigrationExecutionContext) -> None:
|
||||
"""执行当前迁移步骤。
|
||||
|
||||
Args:
|
||||
context: 当前迁移步骤的执行上下文。
|
||||
"""
|
||||
self.handler(context)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class MigrationPlan:
|
||||
"""数据库迁移执行计划。"""
|
||||
|
||||
current_version: int
|
||||
target_version: int
|
||||
steps: List[MigrationStep]
|
||||
|
||||
def is_empty(self) -> bool:
|
||||
"""判断迁移计划是否为空。
|
||||
|
||||
Returns:
|
||||
bool: 若无需执行任何迁移步骤则返回 ``True``。
|
||||
"""
|
||||
return not self.steps
|
||||
|
||||
def step_count(self) -> int:
|
||||
"""返回迁移计划中的步骤数量。
|
||||
|
||||
Returns:
|
||||
int: 当前计划中的迁移步骤数。
|
||||
"""
|
||||
return len(self.steps)
|
||||
|
||||
def latest_reachable_version(self) -> int:
|
||||
"""返回该计划执行后的最终版本。
|
||||
|
||||
Returns:
|
||||
int: 若计划为空则返回当前版本,否则返回最后一步的目标版本。
|
||||
"""
|
||||
if self.is_empty():
|
||||
return self.current_version
|
||||
return self.steps[-1].version_to
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DatabaseMigrationState:
|
||||
"""数据库迁移状态描述。"""
|
||||
|
||||
resolved_version: ResolvedSchemaVersion
|
||||
target_version: int
|
||||
plan: MigrationPlan
|
||||
|
||||
def requires_migration(self) -> bool:
|
||||
"""判断当前状态是否需要执行迁移。
|
||||
|
||||
Returns:
|
||||
bool: 若计划中存在待执行迁移步骤则返回 ``True``。
|
||||
"""
|
||||
return not self.plan.is_empty()
|
||||
108
src/common/database/migrations/planner.py
Normal file
108
src/common/database/migrations/planner.py
Normal file
@@ -0,0 +1,108 @@
|
||||
"""数据库迁移计划生成器。"""
|
||||
|
||||
from typing import List
|
||||
|
||||
from .exceptions import (
|
||||
DatabaseMigrationPlanningError,
|
||||
MissingMigrationStepError,
|
||||
UnsupportedMigrationDirectionError,
|
||||
)
|
||||
from .models import MigrationPlan, MigrationStep
|
||||
from .registry import MigrationRegistry
|
||||
|
||||
|
||||
class MigrationPlanner:
|
||||
"""数据库迁移计划生成器。"""
|
||||
|
||||
def plan(
|
||||
self,
|
||||
current_version: int,
|
||||
target_version: int,
|
||||
registry: MigrationRegistry,
|
||||
) -> MigrationPlan:
|
||||
"""根据当前版本与目标版本生成迁移计划。
|
||||
|
||||
Args:
|
||||
current_version: 当前数据库版本。
|
||||
target_version: 目标数据库版本。
|
||||
registry: 迁移步骤注册表。
|
||||
|
||||
Returns:
|
||||
MigrationPlan: 按顺序执行的迁移计划。
|
||||
|
||||
Raises:
|
||||
DatabaseMigrationPlanningError: 当版本号非法时抛出。
|
||||
MissingMigrationStepError: 当所需迁移步骤缺失时抛出。
|
||||
UnsupportedMigrationDirectionError: 当请求降级迁移时抛出。
|
||||
"""
|
||||
self._validate_version(current_version, "current_version")
|
||||
self._validate_version(target_version, "target_version")
|
||||
|
||||
if target_version < current_version:
|
||||
raise UnsupportedMigrationDirectionError(
|
||||
f"当前仅支持升级迁移,不支持从 {current_version} 降级到 {target_version}。"
|
||||
)
|
||||
if target_version == current_version:
|
||||
return MigrationPlan(current_version=current_version, target_version=target_version, steps=[])
|
||||
|
||||
steps = self._build_steps(current_version, target_version, registry)
|
||||
return MigrationPlan(current_version=current_version, target_version=target_version, steps=steps)
|
||||
|
||||
def plan_to_latest(self, current_version: int, registry: MigrationRegistry) -> MigrationPlan:
|
||||
"""生成迁移到注册表最新版本的执行计划。
|
||||
|
||||
Args:
|
||||
current_version: 当前数据库版本。
|
||||
registry: 迁移步骤注册表。
|
||||
|
||||
Returns:
|
||||
MigrationPlan: 指向最新版本的迁移计划。
|
||||
"""
|
||||
target_version = registry.latest_version()
|
||||
return self.plan(current_version=current_version, target_version=target_version, registry=registry)
|
||||
|
||||
def _build_steps(
|
||||
self,
|
||||
current_version: int,
|
||||
target_version: int,
|
||||
registry: MigrationRegistry,
|
||||
) -> List[MigrationStep]:
|
||||
"""按顺序拼装迁移步骤链。
|
||||
|
||||
Args:
|
||||
current_version: 当前数据库版本。
|
||||
target_version: 目标数据库版本。
|
||||
registry: 迁移步骤注册表。
|
||||
|
||||
Returns:
|
||||
List[MigrationStep]: 按顺序执行的迁移步骤列表。
|
||||
|
||||
Raises:
|
||||
MissingMigrationStepError: 当中间某一版本缺少迁移步骤时抛出。
|
||||
"""
|
||||
planned_steps: List[MigrationStep] = []
|
||||
next_version = current_version
|
||||
|
||||
while next_version < target_version:
|
||||
step = registry.get_step(next_version)
|
||||
if step is None:
|
||||
raise MissingMigrationStepError(
|
||||
f"缺少从版本 {next_version} 升级到版本 {next_version + 1} 的迁移步骤。"
|
||||
)
|
||||
planned_steps.append(step)
|
||||
next_version = step.version_to
|
||||
|
||||
return planned_steps
|
||||
|
||||
def _validate_version(self, version: int, field_name: str) -> None:
|
||||
"""校验版本号是否合法。
|
||||
|
||||
Args:
|
||||
version: 待校验的版本号。
|
||||
field_name: 当前版本号对应的字段名。
|
||||
|
||||
Raises:
|
||||
DatabaseMigrationPlanningError: 当版本号非法时抛出。
|
||||
"""
|
||||
if version < 0:
|
||||
raise DatabaseMigrationPlanningError(f"{field_name} 不能小于 0: {version}")
|
||||
319
src/common/database/migrations/progress.py
Normal file
319
src/common/database/migrations/progress.py
Normal file
@@ -0,0 +1,319 @@
|
||||
"""数据库迁移进度展示工具。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import timedelta
|
||||
from typing import Optional
|
||||
|
||||
from rich.console import Console
|
||||
from rich.progress import BarColumn, Progress, ProgressColumn, Task, TaskID
|
||||
from rich.text import Text
|
||||
|
||||
|
||||
def _format_duration(total_seconds: Optional[float]) -> str:
|
||||
"""将秒数格式化为适合展示的耗时文本。
|
||||
|
||||
Args:
|
||||
total_seconds: 总秒数;为空时表示暂不可用。
|
||||
|
||||
Returns:
|
||||
str: 格式化后的耗时文本。
|
||||
"""
|
||||
if total_seconds is None:
|
||||
return "--:--:--"
|
||||
safe_seconds = max(total_seconds, 0.0)
|
||||
return str(timedelta(seconds=int(safe_seconds)))
|
||||
|
||||
|
||||
class MigrationSummaryColumn(ProgressColumn):
|
||||
"""渲染数据库迁移总进度摘要列。"""
|
||||
|
||||
def render(self, task: Task) -> Text:
|
||||
"""渲染当前任务的总进度摘要。
|
||||
|
||||
Args:
|
||||
task: 当前进度任务对象。
|
||||
|
||||
Returns:
|
||||
Text: 渲染后的摘要文本。
|
||||
"""
|
||||
completed_tables = int(task.fields.get("completed_tables", 0))
|
||||
display_table_total = task.fields.get("display_table_total")
|
||||
total_text = "?" if display_table_total is None else str(int(display_table_total))
|
||||
completed_text = str(completed_tables)
|
||||
return Text(f"总迁移进度({completed_text}/{total_text})")
|
||||
|
||||
|
||||
class MigrationSpeedColumn(ProgressColumn):
|
||||
"""渲染数据库迁移速度列。"""
|
||||
|
||||
def render(self, task: Task) -> Text:
|
||||
"""渲染当前任务的速度信息。
|
||||
|
||||
Args:
|
||||
task: 当前进度任务对象。
|
||||
|
||||
Returns:
|
||||
Text: 渲染后的速度文本。
|
||||
"""
|
||||
unit_name = str(task.fields.get("progress_unit_name", "项"))
|
||||
if task.speed is None or task.speed <= 0:
|
||||
return Text(f"-- {unit_name}/s")
|
||||
return Text(f"{task.speed:.2f} {unit_name}/s")
|
||||
|
||||
|
||||
class MigrationElapsedColumn(ProgressColumn):
|
||||
"""渲染数据库迁移已用时间列。"""
|
||||
|
||||
def render(self, task: Task) -> Text:
|
||||
"""渲染当前任务的已用时间。
|
||||
|
||||
Args:
|
||||
task: 当前进度任务对象。
|
||||
|
||||
Returns:
|
||||
Text: 渲染后的已用时间文本。
|
||||
"""
|
||||
return Text(f"已用时间 {_format_duration(task.elapsed)}")
|
||||
|
||||
|
||||
class MigrationRemainingColumn(ProgressColumn):
|
||||
"""渲染数据库迁移预估剩余时间列。"""
|
||||
|
||||
def render(self, task: Task) -> Text:
|
||||
"""渲染当前任务的预估剩余时间。
|
||||
|
||||
Args:
|
||||
task: 当前进度任务对象。
|
||||
|
||||
Returns:
|
||||
Text: 渲染后的预估剩余时间文本。
|
||||
"""
|
||||
return Text(f"预估时间 {_format_duration(task.time_remaining)}")
|
||||
|
||||
|
||||
class BaseMigrationProgressReporter(ABC):
|
||||
"""数据库迁移进度上报器基类。"""
|
||||
|
||||
def __enter__(self) -> "BaseMigrationProgressReporter":
|
||||
"""进入进度上报上下文。
|
||||
|
||||
Returns:
|
||||
BaseMigrationProgressReporter: 当前上报器实例。
|
||||
"""
|
||||
self.open()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback) -> None:
|
||||
"""退出进度上报上下文。
|
||||
|
||||
Args:
|
||||
exc_type: 异常类型。
|
||||
exc_value: 异常实例。
|
||||
traceback: 异常追踪对象。
|
||||
"""
|
||||
del exc_type, exc_value, traceback
|
||||
self.close()
|
||||
|
||||
@abstractmethod
|
||||
def open(self) -> None:
|
||||
"""打开进度上报资源。"""
|
||||
|
||||
@abstractmethod
|
||||
def close(self) -> None:
|
||||
"""关闭进度上报资源。"""
|
||||
|
||||
@abstractmethod
|
||||
def start(
|
||||
self,
|
||||
total_records: int,
|
||||
total_tables: int,
|
||||
description: str = "总迁移进度",
|
||||
table_unit_name: str = "表",
|
||||
record_unit_name: str = "记录",
|
||||
) -> None:
|
||||
"""启动一个新的迁移进度任务。
|
||||
|
||||
Args:
|
||||
total_records: 任务记录总数。
|
||||
total_tables: 任务表总数。
|
||||
description: 任务描述。
|
||||
table_unit_name: 表级进度单位名称。
|
||||
record_unit_name: 记录级进度单位名称。
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def advance(
|
||||
self,
|
||||
records: int = 0,
|
||||
completed_tables: int = 0,
|
||||
item_name: Optional[str] = None,
|
||||
) -> None:
|
||||
"""推进当前迁移进度任务。
|
||||
|
||||
Args:
|
||||
records: 本次推进的记录数。
|
||||
completed_tables: 本次完成的表数。
|
||||
item_name: 当前完成的项目名称。
|
||||
"""
|
||||
|
||||
|
||||
class NullMigrationProgressReporter(BaseMigrationProgressReporter):
|
||||
"""不输出任何内容的空进度上报器。"""
|
||||
|
||||
def open(self) -> None:
|
||||
"""打开空进度上报器。"""
|
||||
|
||||
def close(self) -> None:
|
||||
"""关闭空进度上报器。"""
|
||||
|
||||
def start(
|
||||
self,
|
||||
total_records: int,
|
||||
total_tables: int,
|
||||
description: str = "总迁移进度",
|
||||
table_unit_name: str = "表",
|
||||
record_unit_name: str = "记录",
|
||||
) -> None:
|
||||
"""启动空进度任务。
|
||||
|
||||
Args:
|
||||
total_records: 任务记录总数。
|
||||
total_tables: 任务表总数。
|
||||
description: 任务描述。
|
||||
table_unit_name: 表级进度单位名称。
|
||||
record_unit_name: 记录级进度单位名称。
|
||||
"""
|
||||
del total_records, total_tables, description, table_unit_name, record_unit_name
|
||||
|
||||
def advance(
|
||||
self,
|
||||
records: int = 0,
|
||||
completed_tables: int = 0,
|
||||
item_name: Optional[str] = None,
|
||||
) -> None:
|
||||
"""推进空进度任务。
|
||||
|
||||
Args:
|
||||
records: 本次推进的记录数。
|
||||
completed_tables: 本次完成的表数。
|
||||
item_name: 当前完成的项目名称。
|
||||
"""
|
||||
del records, completed_tables, item_name
|
||||
|
||||
|
||||
class RichMigrationProgressReporter(BaseMigrationProgressReporter):
|
||||
"""基于 ``rich`` 的数据库迁移进度上报器。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
console: Optional[Console] = None,
|
||||
disable: Optional[bool] = None,
|
||||
refresh_per_second: int = 10,
|
||||
) -> None:
|
||||
"""初始化 ``rich`` 迁移进度上报器。
|
||||
|
||||
Args:
|
||||
console: 输出使用的 ``rich`` 控制台。
|
||||
disable: 是否禁用进度条;为空时根据终端能力自动判断。
|
||||
refresh_per_second: 每秒刷新次数。
|
||||
"""
|
||||
self.console = console or Console()
|
||||
self.disable = disable
|
||||
self.refresh_per_second = refresh_per_second
|
||||
self._progress: Optional[Progress] = None
|
||||
self._task_id: Optional[TaskID] = None
|
||||
|
||||
def open(self) -> None:
|
||||
"""打开 ``rich`` 进度条资源。"""
|
||||
effective_disable = not self.console.is_terminal if self.disable is None else self.disable
|
||||
self._progress = Progress(
|
||||
MigrationSummaryColumn(),
|
||||
BarColumn(),
|
||||
MigrationSpeedColumn(),
|
||||
MigrationElapsedColumn(),
|
||||
MigrationRemainingColumn(),
|
||||
console=self.console,
|
||||
transient=False,
|
||||
disable=effective_disable,
|
||||
refresh_per_second=self.refresh_per_second,
|
||||
expand=True,
|
||||
)
|
||||
self._progress.start()
|
||||
|
||||
def close(self) -> None:
|
||||
"""关闭 ``rich`` 进度条资源。"""
|
||||
if self._progress is None:
|
||||
return
|
||||
self._progress.stop()
|
||||
self._progress = None
|
||||
self._task_id = None
|
||||
|
||||
def start(
|
||||
self,
|
||||
total_records: int,
|
||||
total_tables: int,
|
||||
description: str = "总迁移进度",
|
||||
table_unit_name: str = "表",
|
||||
record_unit_name: str = "记录",
|
||||
) -> None:
|
||||
"""启动一个新的 ``rich`` 迁移进度任务。
|
||||
|
||||
Args:
|
||||
total_records: 任务记录总数。
|
||||
total_tables: 任务表总数。
|
||||
description: 任务描述。
|
||||
table_unit_name: 表级进度单位名称。
|
||||
record_unit_name: 记录级进度单位名称。
|
||||
"""
|
||||
if self._progress is None:
|
||||
self.open()
|
||||
assert self._progress is not None
|
||||
use_record_progress = total_records > 0
|
||||
effective_total = total_records if use_record_progress else total_tables
|
||||
effective_total = max(effective_total, 1)
|
||||
progress_unit_name = record_unit_name if use_record_progress else table_unit_name
|
||||
self._task_id = self._progress.add_task(
|
||||
description,
|
||||
total=effective_total,
|
||||
completed_tables=0,
|
||||
display_table_total=total_tables,
|
||||
progress_unit_name=progress_unit_name,
|
||||
use_record_progress=use_record_progress,
|
||||
)
|
||||
|
||||
def advance(
|
||||
self,
|
||||
records: int = 0,
|
||||
completed_tables: int = 0,
|
||||
item_name: Optional[str] = None,
|
||||
) -> None:
|
||||
"""推进当前 ``rich`` 迁移进度任务。
|
||||
|
||||
Args:
|
||||
records: 本次推进的记录数。
|
||||
completed_tables: 本次完成的表数。
|
||||
item_name: 当前完成的项目名称。
|
||||
"""
|
||||
del item_name
|
||||
if self._progress is None or self._task_id is None:
|
||||
return
|
||||
task = self._progress.tasks[self._task_id]
|
||||
use_record_progress = bool(task.fields.get("use_record_progress", False))
|
||||
progress_advance = records if use_record_progress else completed_tables
|
||||
updated_completed_tables = int(task.fields.get("completed_tables", 0)) + completed_tables
|
||||
self._progress.update(
|
||||
self._task_id,
|
||||
advance=progress_advance,
|
||||
completed_tables=updated_completed_tables,
|
||||
)
|
||||
|
||||
|
||||
def create_rich_migration_progress_reporter() -> BaseMigrationProgressReporter:
|
||||
"""创建默认的 ``rich`` 迁移进度上报器。
|
||||
|
||||
Returns:
|
||||
BaseMigrationProgressReporter: 默认迁移进度上报器实例。
|
||||
"""
|
||||
return RichMigrationProgressReporter()
|
||||
98
src/common/database/migrations/registry.py
Normal file
98
src/common/database/migrations/registry.py
Normal file
@@ -0,0 +1,98 @@
|
||||
"""数据库迁移步骤注册表。"""
|
||||
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from .exceptions import DatabaseMigrationConfigurationError
|
||||
from .models import MigrationStep
|
||||
|
||||
|
||||
class MigrationRegistry:
|
||||
"""数据库迁移步骤注册表。"""
|
||||
|
||||
def __init__(self, steps: Optional[List[MigrationStep]] = None) -> None:
|
||||
"""初始化迁移步骤注册表。
|
||||
|
||||
Args:
|
||||
steps: 初始化时要注册的迁移步骤列表。
|
||||
"""
|
||||
self._steps_by_from_version: Dict[int, MigrationStep] = {}
|
||||
if steps:
|
||||
self.register_many(steps)
|
||||
|
||||
def register(self, step: MigrationStep) -> None:
|
||||
"""注册单个迁移步骤。
|
||||
|
||||
当前注册表要求每个步骤只负责相邻版本间的升级,以确保迁移链路易于审计、
|
||||
易于回放,也便于后续生产问题排查。
|
||||
|
||||
Args:
|
||||
step: 待注册的迁移步骤定义。
|
||||
|
||||
Raises:
|
||||
DatabaseMigrationConfigurationError: 当步骤定义冲突或版本跨度不合法时抛出。
|
||||
"""
|
||||
if step.version_to != step.version_from + 1:
|
||||
raise DatabaseMigrationConfigurationError(
|
||||
"迁移步骤必须使用相邻版本号定义,例如 2 -> 3。"
|
||||
)
|
||||
if step.version_from in self._steps_by_from_version:
|
||||
existing_step = self._steps_by_from_version[step.version_from]
|
||||
raise DatabaseMigrationConfigurationError(
|
||||
f"版本 {step.version_from} 已存在迁移步骤: {existing_step.name}"
|
||||
)
|
||||
for registered_step in self._steps_by_from_version.values():
|
||||
if registered_step.version_to == step.version_to:
|
||||
raise DatabaseMigrationConfigurationError(
|
||||
f"目标版本 {step.version_to} 已由迁移步骤 {registered_step.name} 占用。"
|
||||
)
|
||||
self._steps_by_from_version[step.version_from] = step
|
||||
|
||||
def register_many(self, steps: List[MigrationStep]) -> None:
|
||||
"""批量注册多个迁移步骤。
|
||||
|
||||
Args:
|
||||
steps: 待注册的迁移步骤列表。
|
||||
"""
|
||||
for step in steps:
|
||||
self.register(step)
|
||||
|
||||
def get_step(self, version_from: int) -> Optional[MigrationStep]:
|
||||
"""获取指定起始版本的迁移步骤。
|
||||
|
||||
Args:
|
||||
version_from: 迁移步骤的起始版本号。
|
||||
|
||||
Returns:
|
||||
Optional[MigrationStep]: 若存在对应步骤则返回,否则返回 ``None``。
|
||||
"""
|
||||
return self._steps_by_from_version.get(version_from)
|
||||
|
||||
def has_step(self, version_from: int) -> bool:
|
||||
"""判断指定起始版本是否已注册迁移步骤。
|
||||
|
||||
Args:
|
||||
version_from: 待检查的起始版本号。
|
||||
|
||||
Returns:
|
||||
bool: 若已注册对应步骤则返回 ``True``。
|
||||
"""
|
||||
return version_from in self._steps_by_from_version
|
||||
|
||||
def latest_version(self) -> int:
|
||||
"""返回当前注册表支持到的最新 schema 版本。
|
||||
|
||||
Returns:
|
||||
int: 若注册表为空则返回 ``0``,否则返回最大目标版本号。
|
||||
"""
|
||||
if not self._steps_by_from_version:
|
||||
return 0
|
||||
return max(step.version_to for step in self._steps_by_from_version.values())
|
||||
|
||||
def list_steps(self) -> List[MigrationStep]:
|
||||
"""按起始版本顺序返回全部迁移步骤。
|
||||
|
||||
Returns:
|
||||
List[MigrationStep]: 已注册迁移步骤列表。
|
||||
"""
|
||||
ordered_versions = sorted(self._steps_by_from_version)
|
||||
return [self._steps_by_from_version[version] for version in ordered_versions]
|
||||
135
src/common/database/migrations/resolver.py
Normal file
135
src/common/database/migrations/resolver.py
Normal file
@@ -0,0 +1,135 @@
|
||||
"""数据库版本解析器。"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional
|
||||
|
||||
from sqlalchemy.engine import Connection
|
||||
|
||||
from .exceptions import DatabaseMigrationVersionError, UnrecognizedDatabaseSchemaError
|
||||
from .models import DatabaseSchemaSnapshot, ResolvedSchemaVersion, SchemaVersionSource
|
||||
from .schema import SQLiteSchemaInspector
|
||||
from .version_store import SQLiteUserVersionStore
|
||||
|
||||
|
||||
class BaseSchemaVersionDetector(ABC):
|
||||
"""未标记版本数据库的结构探测器基类。"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def name(self) -> str:
|
||||
"""返回当前探测器名称。
|
||||
|
||||
Returns:
|
||||
str: 当前探测器名称。
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def detect_version(self, snapshot: DatabaseSchemaSnapshot) -> Optional[int]:
|
||||
"""根据数据库结构快照推断版本号。
|
||||
|
||||
Args:
|
||||
snapshot: 当前数据库结构快照。
|
||||
|
||||
Returns:
|
||||
Optional[int]: 若识别成功则返回版本号,否则返回 ``None``。
|
||||
"""
|
||||
|
||||
|
||||
class SchemaVersionResolver:
|
||||
"""数据库版本解析器。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
version_store: Optional[SQLiteUserVersionStore] = None,
|
||||
schema_inspector: Optional[SQLiteSchemaInspector] = None,
|
||||
detectors: Optional[List[BaseSchemaVersionDetector]] = None,
|
||||
) -> None:
|
||||
"""初始化数据库版本解析器。
|
||||
|
||||
Args:
|
||||
version_store: 版本存储器;未提供时将使用默认实现。
|
||||
schema_inspector: 结构探测器;未提供时将使用默认实现。
|
||||
detectors: 未标记版本数据库的探测器列表。
|
||||
"""
|
||||
self.version_store = version_store or SQLiteUserVersionStore()
|
||||
self.schema_inspector = schema_inspector or SQLiteSchemaInspector()
|
||||
self.detectors: List[BaseSchemaVersionDetector] = list(detectors or [])
|
||||
|
||||
def add_detector(self, detector: BaseSchemaVersionDetector) -> None:
|
||||
"""注册一个未标记版本数据库探测器。
|
||||
|
||||
Args:
|
||||
detector: 待注册的探测器实例。
|
||||
"""
|
||||
self.detectors.append(detector)
|
||||
|
||||
def list_detectors(self) -> List[BaseSchemaVersionDetector]:
|
||||
"""返回当前已注册的全部探测器。
|
||||
|
||||
Returns:
|
||||
List[BaseSchemaVersionDetector]: 已注册探测器列表副本。
|
||||
"""
|
||||
return list(self.detectors)
|
||||
|
||||
def resolve(self, connection: Connection) -> ResolvedSchemaVersion:
|
||||
"""解析当前数据库的 schema 版本信息。
|
||||
|
||||
解析顺序如下:
|
||||
1. 优先读取 ``PRAGMA user_version``。
|
||||
2. 若其值为 0,则对数据库结构做快照。
|
||||
3. 若数据库为空,则返回空库版本。
|
||||
4. 若数据库非空,则交给探测器链进行识别。
|
||||
|
||||
Args:
|
||||
connection: 当前数据库连接。
|
||||
|
||||
Returns:
|
||||
ResolvedSchemaVersion: 解析后的数据库版本信息。
|
||||
|
||||
Raises:
|
||||
DatabaseMigrationVersionError: 当探测器返回非法版本号时抛出。
|
||||
UnrecognizedDatabaseSchemaError: 当数据库非空但无法识别版本时抛出。
|
||||
"""
|
||||
recorded_version = self.version_store.read_version(connection)
|
||||
if recorded_version > 0:
|
||||
return ResolvedSchemaVersion(version=recorded_version, source=SchemaVersionSource.PRAGMA)
|
||||
|
||||
snapshot = self.schema_inspector.inspect(connection)
|
||||
if snapshot.is_empty():
|
||||
return ResolvedSchemaVersion(
|
||||
version=0,
|
||||
source=SchemaVersionSource.EMPTY_DATABASE,
|
||||
snapshot=snapshot,
|
||||
)
|
||||
|
||||
return self._detect_unversioned_database(snapshot)
|
||||
|
||||
def _detect_unversioned_database(self, snapshot: DatabaseSchemaSnapshot) -> ResolvedSchemaVersion:
|
||||
"""识别未标记版本的历史数据库。
|
||||
|
||||
Args:
|
||||
snapshot: 当前数据库结构快照。
|
||||
|
||||
Returns:
|
||||
ResolvedSchemaVersion: 探测器识别出的版本信息。
|
||||
|
||||
Raises:
|
||||
DatabaseMigrationVersionError: 当探测器返回非法版本号时抛出。
|
||||
UnrecognizedDatabaseSchemaError: 当全部探测器都无法识别结构时抛出。
|
||||
"""
|
||||
for detector in self.detectors:
|
||||
detected_version = detector.detect_version(snapshot)
|
||||
if detected_version is None:
|
||||
continue
|
||||
if detected_version < 0:
|
||||
raise DatabaseMigrationVersionError(
|
||||
f"探测器 {detector.name!r} 返回了非法版本号: {detected_version}"
|
||||
)
|
||||
return ResolvedSchemaVersion(
|
||||
version=detected_version,
|
||||
source=SchemaVersionSource.DETECTOR,
|
||||
detector_name=detector.name,
|
||||
snapshot=snapshot,
|
||||
)
|
||||
|
||||
raise UnrecognizedDatabaseSchemaError("当前数据库未记录版本号,且现有探测器无法识别其结构。")
|
||||
98
src/common/database/migrations/schema.py
Normal file
98
src/common/database/migrations/schema.py
Normal file
@@ -0,0 +1,98 @@
|
||||
"""SQLite 数据库结构探测工具。"""
|
||||
|
||||
from typing import Dict, List
|
||||
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.engine import Connection
|
||||
|
||||
from .models import ColumnSchema, DatabaseSchemaSnapshot, TableSchema
|
||||
|
||||
|
||||
class SQLiteSchemaInspector:
|
||||
"""SQLite 数据库结构探测器。"""
|
||||
|
||||
def inspect(self, connection: Connection) -> DatabaseSchemaSnapshot:
|
||||
"""提取数据库中的全部用户表结构快照。
|
||||
|
||||
Args:
|
||||
connection: 当前数据库连接。
|
||||
|
||||
Returns:
|
||||
DatabaseSchemaSnapshot: 当前数据库结构快照。
|
||||
"""
|
||||
tables: Dict[str, TableSchema] = {}
|
||||
for table_name in self.list_user_tables(connection):
|
||||
table_schema = self.get_table_schema(connection, table_name)
|
||||
tables[table_name] = table_schema
|
||||
return DatabaseSchemaSnapshot(tables=tables)
|
||||
|
||||
def list_user_tables(self, connection: Connection) -> List[str]:
|
||||
"""列出数据库中的全部用户表。
|
||||
|
||||
Args:
|
||||
connection: 当前数据库连接。
|
||||
|
||||
Returns:
|
||||
List[str]: 按字母顺序排列的用户表名列表。
|
||||
"""
|
||||
statement = text(
|
||||
"""
|
||||
SELECT name
|
||||
FROM sqlite_master
|
||||
WHERE type = 'table'
|
||||
AND name NOT LIKE 'sqlite_%'
|
||||
ORDER BY name
|
||||
"""
|
||||
)
|
||||
rows = connection.execute(statement).fetchall()
|
||||
return [str(row[0]) for row in rows]
|
||||
|
||||
def get_table_schema(self, connection: Connection, table_name: str) -> TableSchema:
|
||||
"""获取指定表的结构信息。
|
||||
|
||||
Args:
|
||||
connection: 当前数据库连接。
|
||||
table_name: 待读取结构的表名。
|
||||
|
||||
Returns:
|
||||
TableSchema: 指定表的结构快照。
|
||||
"""
|
||||
quoted_table_name = self._quote_identifier(table_name)
|
||||
rows = connection.exec_driver_sql(f"PRAGMA table_info({quoted_table_name})").mappings().all()
|
||||
|
||||
columns: Dict[str, ColumnSchema] = {}
|
||||
for row in rows:
|
||||
column_schema = ColumnSchema(
|
||||
name=str(row["name"]),
|
||||
declared_type=str(row["type"] or ""),
|
||||
default_value=None if row["dflt_value"] is None else str(row["dflt_value"]),
|
||||
is_not_null=bool(row["notnull"]),
|
||||
primary_key_position=int(row["pk"]),
|
||||
)
|
||||
columns[column_schema.name] = column_schema
|
||||
|
||||
return TableSchema(name=table_name, columns=columns)
|
||||
|
||||
def table_exists(self, connection: Connection, table_name: str) -> bool:
|
||||
"""判断数据库中是否存在指定表。
|
||||
|
||||
Args:
|
||||
connection: 当前数据库连接。
|
||||
table_name: 待检查的表名。
|
||||
|
||||
Returns:
|
||||
bool: 若表存在则返回 ``True``。
|
||||
"""
|
||||
return table_name in self.list_user_tables(connection)
|
||||
|
||||
def _quote_identifier(self, identifier: str) -> str:
|
||||
"""为 SQLite 标识符添加安全引号。
|
||||
|
||||
Args:
|
||||
identifier: 待引用的 SQLite 标识符。
|
||||
|
||||
Returns:
|
||||
str: 可直接拼接到 PRAGMA 语句中的安全标识符。
|
||||
"""
|
||||
escaped_identifier = identifier.replace('"', '""')
|
||||
return f'"{escaped_identifier}"'
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user