fix: resolve upstream merge conflicts
Merge upstream/r-dev in a temporary integration branch: keep the person-fact writeback and feedback-correction configuration, remove the discontinued chat-summary configuration, and sync the tests.
@@ -1,148 +0,0 @@
from types import SimpleNamespace

import pytest

from src.memory_system import chat_history_summarizer as summarizer_module


def _build_summarizer() -> summarizer_module.ChatHistorySummarizer:
    summarizer = summarizer_module.ChatHistorySummarizer.__new__(summarizer_module.ChatHistorySummarizer)
    summarizer.session_id = "session-1"
    summarizer.log_prefix = "[session-1]"
    return summarizer


@pytest.mark.asyncio
async def test_import_to_long_term_memory_uses_summary_payload(monkeypatch):
    calls = []
    summarizer = _build_summarizer()

    async def fake_ingest_summary(**kwargs):
        calls.append(kwargs)
        return SimpleNamespace(success=True, detail="", stored_ids=["p1"])

    monkeypatch.setattr(
        summarizer_module,
        "_chat_manager",
        SimpleNamespace(get_session_by_session_id=lambda session_id: SimpleNamespace(user_id="user-1", group_id="")),
    )
    monkeypatch.setattr(summarizer_module, "global_config", SimpleNamespace(memory=SimpleNamespace(chat_history_topic_check_message_threshold=8)))
    monkeypatch.setattr("src.services.memory_service.memory_service.ingest_summary", fake_ingest_summary)

    await summarizer._import_to_long_term_memory(
        record_id=1,
        theme="旅行计划",
        summary="我们讨论了春游安排",
        participants=["Alice", "Bob"],
        start_time=1.0,
        end_time=2.0,
        original_text="long text",
    )

    assert len(calls) == 1
    payload = calls[0]
    assert payload["external_id"] == "chat_history:1"
    assert payload["chat_id"] == "session-1"
    assert payload["participants"] == ["Alice", "Bob"]
    assert payload["respect_filter"] is True
    assert payload["user_id"] == "user-1"
    assert payload["group_id"] == ""
    assert "主题:旅行计划" in payload["text"]
    assert "概括:我们讨论了春游安排" in payload["text"]


@pytest.mark.asyncio
async def test_import_to_long_term_memory_falls_back_when_content_empty(monkeypatch):
    summarizer = _build_summarizer()
    fallback_calls = []

    async def fake_fallback(**kwargs):
        fallback_calls.append(kwargs)

    summarizer._fallback_import_to_long_term_memory = fake_fallback
    monkeypatch.setattr(
        summarizer_module,
        "_chat_manager",
        SimpleNamespace(get_session_by_session_id=lambda session_id: SimpleNamespace(user_id="user-1", group_id="")),
    )

    await summarizer._import_to_long_term_memory(
        record_id=2,
        theme="",
        summary="",
        participants=[],
        start_time=10.0,
        end_time=20.0,
        original_text="raw chat",
    )

    assert len(fallback_calls) == 1
    assert fallback_calls[0]["record_id"] == 2
    assert fallback_calls[0]["original_text"] == "raw chat"


@pytest.mark.asyncio
async def test_import_to_long_term_memory_falls_back_when_ingest_fails(monkeypatch):
    summarizer = _build_summarizer()
    fallback_calls = []

    async def fake_ingest_summary(**kwargs):
        return SimpleNamespace(success=False, detail="boom", stored_ids=[])

    async def fake_fallback(**kwargs):
        fallback_calls.append(kwargs)

    summarizer._fallback_import_to_long_term_memory = fake_fallback
    monkeypatch.setattr(
        summarizer_module,
        "_chat_manager",
        SimpleNamespace(get_session_by_session_id=lambda session_id: SimpleNamespace(user_id="user-1", group_id="group-1")),
    )
    monkeypatch.setattr("src.services.memory_service.memory_service.ingest_summary", fake_ingest_summary)

    await summarizer._import_to_long_term_memory(
        record_id=3,
        theme="电影",
        summary="聊了电影推荐",
        participants=["Alice"],
        start_time=3.0,
        end_time=4.0,
        original_text="raw",
    )

    assert len(fallback_calls) == 1
    assert fallback_calls[0]["theme"] == "电影"


@pytest.mark.asyncio
async def test_fallback_import_to_long_term_memory_sets_generate_from_chat(monkeypatch):
    calls = []
    summarizer = _build_summarizer()

    async def fake_ingest_summary(**kwargs):
        calls.append(kwargs)
        return SimpleNamespace(success=True, detail="chat_filtered", stored_ids=[])

    monkeypatch.setattr(
        summarizer_module,
        "_chat_manager",
        SimpleNamespace(get_session_by_session_id=lambda session_id: SimpleNamespace(user_id="user-2", group_id="group-2")),
    )
    monkeypatch.setattr(summarizer_module, "global_config", SimpleNamespace(memory=SimpleNamespace(chat_history_topic_check_message_threshold=12)))
    monkeypatch.setattr("src.services.memory_service.memory_service.ingest_summary", fake_ingest_summary)

    await summarizer._fallback_import_to_long_term_memory(
        record_id=4,
        theme="工作",
        participants=["Alice"],
        start_time=5.0,
        end_time=6.0,
        original_text="a" * 128,
    )

    assert len(calls) == 1
    metadata = calls[0]["metadata"]
    assert metadata["generate_from_chat"] is True
    assert metadata["context_length"] == 12
    assert calls[0]["respect_filter"] is True
File diff suppressed because it is too large
@@ -1,687 +0,0 @@
from __future__ import annotations

import asyncio
import inspect
import json
from datetime import datetime
from pathlib import Path
from types import SimpleNamespace
from typing import Any, Dict, List

import numpy as np
import pytest
import pytest_asyncio

from A_memorix.core.runtime import sdk_memory_kernel as kernel_module
from A_memorix.core.runtime.sdk_memory_kernel import KernelSearchRequest, SDKMemoryKernel
from src.chat.brain_chat.PFC import pfc_KnowledgeFetcher as knowledge_module
from src.memory_system import chat_history_summarizer as summarizer_module
from src.memory_system.retrieval_tools.query_long_term_memory import query_long_term_memory
from src.person_info import person_info as person_info_module
from src.services import memory_service as memory_service_module
from src.services.memory_service import MemorySearchResult, memory_service


DATA_FILE = Path(__file__).parent / "data" / "benchmarks" / "long_novel_memory_benchmark.json"
REPORT_FILE = Path(__file__).parent / "data" / "benchmarks" / "results" / "long_novel_memory_benchmark_report.json"


def _load_benchmark_fixture() -> Dict[str, Any]:
    return json.loads(DATA_FILE.read_text(encoding="utf-8"))


class _FakeEmbeddingAdapter:
    def __init__(self, dimension: int = 32) -> None:
        self.dimension = dimension

    async def _detect_dimension(self) -> int:
        return self.dimension

    async def encode(self, texts, dimensions=None):
        dim = int(dimensions or self.dimension)
        if isinstance(texts, str):
            sequence = [texts]
            single = True
        else:
            sequence = list(texts)
            single = False

        rows = []
        for text in sequence:
            vec = np.zeros(dim, dtype=np.float32)
            for ch in str(text or ""):
                code = ord(ch)
                vec[code % dim] += 1.0
                vec[(code * 7) % dim] += 0.5
            if not vec.any():
                vec[0] = 1.0
            norm = np.linalg.norm(vec)
            if norm > 0:
                vec = vec / norm
            rows.append(vec)
        payload = np.vstack(rows)
        return payload[0] if single else payload


class _KnownPerson:
    def __init__(self, person_id: str, registry: Dict[str, str], reverse_registry: Dict[str, str]) -> None:
        self.person_id = person_id
        self.is_known = person_id in reverse_registry
        self.person_name = reverse_registry.get(person_id, "")
        self._registry = registry


class _KernelBackedRuntimeManager:
    def __init__(self, kernel: SDKMemoryKernel) -> None:
        self.kernel = kernel

    async def invoke(self, component_name: str, args: Dict[str, Any] | None, *, timeout_ms: int = 30000):
        del timeout_ms
        payload = args or {}
        if component_name == "search_memory":
            return await self.kernel.search_memory(
                KernelSearchRequest(
                    query=str(payload.get("query", "") or ""),
                    limit=int(payload.get("limit", 5) or 5),
                    mode=str(payload.get("mode", "hybrid") or "hybrid"),
                    chat_id=str(payload.get("chat_id", "") or ""),
                    person_id=str(payload.get("person_id", "") or ""),
                    time_start=payload.get("time_start"),
                    time_end=payload.get("time_end"),
                    respect_filter=bool(payload.get("respect_filter", True)),
                    user_id=str(payload.get("user_id", "") or ""),
                    group_id=str(payload.get("group_id", "") or ""),
                )
            )

        handler = getattr(self.kernel, component_name)
        result = handler(**payload)
        return await result if inspect.isawaitable(result) else result


async def _wait_for_import_task(task_id: str, *, max_rounds: int = 200, sleep_seconds: float = 0.05) -> Dict[str, Any]:
    for _ in range(max_rounds):
        detail = await memory_service.import_admin(action="get", task_id=task_id, include_chunks=True)
        task = detail.get("task") or {}
        status = str(task.get("status", "") or "")
        if status in {"completed", "completed_with_errors", "failed", "cancelled"}:
            return detail
        await asyncio.sleep(max(0.01, float(sleep_seconds)))
    raise AssertionError(f"导入任务在等待窗口内未结束: {task_id}")


def _join_hit_content(search_result: MemorySearchResult) -> str:
    return "\n".join(hit.content for hit in search_result.hits)


def _keyword_hits(text: str, keywords: List[str]) -> int:
    haystack = str(text or "")
    return sum(1 for keyword in keywords if keyword in haystack)


def _keyword_recall(text: str, keywords: List[str]) -> float:
    if not keywords:
        return 1.0
    return _keyword_hits(text, keywords) / float(len(keywords))


def _hit_blob(hit) -> str:
    meta = hit.metadata if isinstance(hit.metadata, dict) else {}
    return "\n".join(
        [
            str(hit.content or ""),
            str(hit.title or ""),
            str(hit.source or ""),
            json.dumps(meta, ensure_ascii=False),
        ]
    )


def _first_relevant_rank(search_result: MemorySearchResult, keywords: List[str], minimum_keyword_hits: int) -> int:
    for index, hit in enumerate(search_result.hits[:5], start=1):
        if _keyword_hits(_hit_blob(hit), keywords) >= max(1, int(minimum_keyword_hits or len(keywords))):
            return index
    return 0


def _episode_blob_from_items(items: List[Dict[str, Any]]) -> str:
    return "\n".join(
        (
            f"{item.get('title', '')}\n"
            f"{item.get('summary', '')}\n"
            f"{json.dumps(item.get('keywords', []), ensure_ascii=False)}\n"
            f"{json.dumps(item.get('participants', []), ensure_ascii=False)}"
        )
        for item in items
    )


def _episode_blob_from_hits(search_result: MemorySearchResult) -> str:
    chunks = []
    for hit in search_result.hits:
        meta = hit.metadata if isinstance(hit.metadata, dict) else {}
        chunks.append(
            "\n".join(
                [
                    str(hit.title or ""),
                    str(hit.content or ""),
                    json.dumps(meta.get("keywords", []) or [], ensure_ascii=False),
                    json.dumps(meta.get("participants", []) or [], ensure_ascii=False),
                ]
            )
        )
    return "\n".join(chunks)


async def _evaluate_episode_generation(*, session_id: str, episode_cases: List[Dict[str, Any]]) -> Dict[str, Any]:
    episode_source = f"chat_summary:{session_id}"
    payload = await memory_service.episode_admin(
        action="query",
        source=episode_source,
        limit=20,
    )
    items = payload.get("items") or []
    blob = _episode_blob_from_items(items)
    reports: List[Dict[str, Any]] = []
    success_rate = 0.0
    keyword_recall = 0.0

    for case in episode_cases:
        recall = _keyword_recall(blob, case["expected_keywords"])
        success = bool(items) and recall >= float(case.get("minimum_keyword_recall", 1.0))
        success_rate += 1.0 if success else 0.0
        keyword_recall += recall
        reports.append(
            {
                "query": case["query"],
                "success": success,
                "keyword_recall": recall,
                "episode_count": len(items),
                "top_episode": items[0] if items else None,
            }
        )

    total = max(1, len(episode_cases))
    return {
        "success_rate": round(success_rate / total, 4),
        "keyword_recall": round(keyword_recall / total, 4),
        "episode_count": len(items),
        "reports": reports,
    }


async def _evaluate_episode_admin_query(*, session_id: str, episode_cases: List[Dict[str, Any]]) -> Dict[str, Any]:
    reports: List[Dict[str, Any]] = []
    success_rate = 0.0
    keyword_recall = 0.0
    episode_source = f"chat_summary:{session_id}"

    for case in episode_cases:
        payload = await memory_service.episode_admin(
            action="query",
            source=episode_source,
            query=case["query"],
            limit=5,
        )
        items = payload.get("items") or []
        blob = "\n".join(
            f"{item.get('title', '')}\n{item.get('summary', '')}\n{json.dumps(item.get('keywords', []), ensure_ascii=False)}"
            for item in items
        )
        recall = _keyword_recall(blob, case["expected_keywords"])
        success = bool(items) and recall >= float(case.get("minimum_keyword_recall", 1.0))
        success_rate += 1.0 if success else 0.0
        keyword_recall += recall
        reports.append(
            {
                "query": case["query"],
                "success": success,
                "keyword_recall": recall,
                "episode_count": len(items),
                "top_episode": items[0] if items else None,
            }
        )

    total = max(1, len(episode_cases))
    return {
        "success_rate": round(success_rate / total, 4),
        "keyword_recall": round(keyword_recall / total, 4),
        "reports": reports,
    }


async def _evaluate_episode_search_mode(*, session_id: str, episode_cases: List[Dict[str, Any]]) -> Dict[str, Any]:
    reports: List[Dict[str, Any]] = []
    success_rate = 0.0
    keyword_recall = 0.0

    for case in episode_cases:
        result = await memory_service.search(
            case["query"],
            mode="episode",
            chat_id=session_id,
            respect_filter=False,
            limit=5,
        )
        blob = _episode_blob_from_hits(result)
        recall = _keyword_recall(blob, case["expected_keywords"])
        success = bool(result.hits) and recall >= float(case.get("minimum_keyword_recall", 1.0))
        success_rate += 1.0 if success else 0.0
        keyword_recall += recall
        reports.append(
            {
                "query": case["query"],
                "success": success,
                "keyword_recall": recall,
                "episode_count": len(result.hits),
                "top_episode": result.hits[0].to_dict() if result.hits else None,
            }
        )

    total = max(1, len(episode_cases))
    return {
        "success_rate": round(success_rate / total, 4),
        "keyword_recall": round(keyword_recall / total, 4),
        "reports": reports,
    }


async def _evaluate_tool_modes(*, session_id: str, dataset: Dict[str, Any]) -> Dict[str, Any]:
    search_case = dataset["search_cases"][0]
    episode_case = dataset["episode_cases"][0]
    aggregate_case = dataset["knowledge_fetcher_cases"][0]
    first_record = (dataset.get("chat_history_records") or [{}])[0]
    reference_ts = first_record.get("end_time") or first_record.get("start_time") or 0
    if reference_ts:
        time_expression = datetime.fromtimestamp(float(reference_ts)).strftime("%Y/%m/%d")
    else:
        time_expression = "最近7天"
    tool_cases = [
        {
            "name": "search",
            "kwargs": {
                "query": "蓝漆铁盒 北塔木梯",
                "mode": "search",
                "chat_id": session_id,
                "limit": 5,
            },
            "expected_keywords": ["蓝漆铁盒", "北塔木梯", "海潮图"],
            "minimum_keyword_recall": 0.67,
        },
        {
            "name": "time",
            "kwargs": {
                "query": "蓝漆铁盒 北塔",
                "mode": "time",
                "chat_id": session_id,
                "limit": 5,
                "time_expression": time_expression,
            },
            "expected_keywords": ["蓝漆铁盒", "北塔木梯"],
            "minimum_keyword_recall": 0.67,
        },
        {
            "name": "episode",
            "kwargs": {
                "query": episode_case["query"],
                "mode": "episode",
                "chat_id": session_id,
                "limit": 5,
            },
            "expected_keywords": episode_case["expected_keywords"],
            "minimum_keyword_recall": 0.67,
        },
        {
            "name": "aggregate",
            "kwargs": {
                "query": aggregate_case["query"],
                "mode": "aggregate",
                "chat_id": session_id,
                "limit": 5,
            },
            "expected_keywords": aggregate_case["expected_keywords"],
            "minimum_keyword_recall": 0.67,
        },
    ]
    reports: List[Dict[str, Any]] = []
    success_rate = 0.0
    keyword_recall = 0.0

    for case in tool_cases:
        text = await query_long_term_memory(**case["kwargs"])
        recall = _keyword_recall(text, case["expected_keywords"])
        success = (
            "失败" not in text
            and "无法解析" not in text
            and "未找到" not in text
            and recall >= float(case["minimum_keyword_recall"])
        )
        success_rate += 1.0 if success else 0.0
        keyword_recall += recall
        reports.append(
            {
                "name": case["name"],
                "success": success,
                "keyword_recall": recall,
                "preview": text[:320],
            }
        )

    total = max(1, len(tool_cases))
    return {
        "success_rate": round(success_rate / total, 4),
        "keyword_recall": round(keyword_recall / total, 4),
        "reports": reports,
    }


@pytest_asyncio.fixture
async def benchmark_env(monkeypatch, tmp_path):
    dataset = _load_benchmark_fixture()
    session_cfg = dataset["session"]
    session = SimpleNamespace(
        session_id=session_cfg["session_id"],
        platform=session_cfg["platform"],
        user_id=session_cfg["user_id"],
        group_id=session_cfg["group_id"],
    )
    fake_chat_manager = SimpleNamespace(
        get_session_by_session_id=lambda session_id: session if session_id == session.session_id else None,
        get_session_name=lambda session_id: session_cfg["display_name"] if session_id == session.session_id else session_id,
    )

    registry = {item["person_name"]: item["person_id"] for item in dataset["person_writebacks"]}
    reverse_registry = {value: key for key, value in registry.items()}

    monkeypatch.setattr(kernel_module, "create_embedding_api_adapter", lambda **kwargs: _FakeEmbeddingAdapter())

    async def fake_self_check(**kwargs):
        return {"ok": True, "message": "ok", "encoded_dimension": 32}

    monkeypatch.setattr(kernel_module, "run_embedding_runtime_self_check", fake_self_check)
    monkeypatch.setattr(summarizer_module, "_chat_manager", fake_chat_manager)
    monkeypatch.setattr(knowledge_module, "_chat_manager", fake_chat_manager)
    monkeypatch.setattr(person_info_module, "_chat_manager", fake_chat_manager)
    monkeypatch.setattr(person_info_module, "get_person_id_by_person_name", lambda person_name: registry.get(str(person_name or "").strip(), ""))
    monkeypatch.setattr(
        person_info_module,
        "Person",
        lambda person_id: _KnownPerson(person_id=str(person_id or ""), registry=registry, reverse_registry=reverse_registry),
    )

    data_dir = (tmp_path / "a_memorix_benchmark_data").resolve()
    kernel = SDKMemoryKernel(
        plugin_root=tmp_path / "plugin_root",
        config={
            "storage": {"data_dir": str(data_dir)},
            "advanced": {"enable_auto_save": False},
            "memory": {"base_decay_interval_hours": 24},
            "person_profile": {"refresh_interval_minutes": 5},
        },
    )
    manager = _KernelBackedRuntimeManager(kernel)
    monkeypatch.setattr(memory_service_module, "a_memorix_host_service", manager)

    await kernel.initialize()
    try:
        yield {
            "dataset": dataset,
            "kernel": kernel,
            "session": session,
            "person_registry": registry,
        }
    finally:
        await kernel.shutdown()


@pytest.mark.asyncio
async def test_long_novel_memory_benchmark(benchmark_env):
    dataset = benchmark_env["dataset"]
    session_id = benchmark_env["session"].session_id

    created = await memory_service.import_admin(
        action="create_paste",
        name="long_novel_memory_benchmark.json",
        input_mode="json",
        llm_enabled=False,
        content=json.dumps(dataset["import_payload"], ensure_ascii=False),
    )
    assert created["success"] is True

    import_detail = await _wait_for_import_task(created["task"]["task_id"])
    assert import_detail["task"]["status"] == "completed"

    for record in dataset["chat_history_records"]:
        summarizer = summarizer_module.ChatHistorySummarizer(session_id)
        await summarizer._import_to_long_term_memory(
            record_id=record["record_id"],
            theme=record["theme"],
            summary=record["summary"],
            participants=record["participants"],
            start_time=record["start_time"],
            end_time=record["end_time"],
            original_text=record["original_text"],
        )

    for payload in dataset["person_writebacks"]:
        await person_info_module.store_person_memory_from_answer(
            payload["person_name"],
            payload["memory_content"],
            session_id,
        )

    await memory_service.episode_admin(action="process_pending", limit=100, max_retry=2)

    search_case_reports: List[Dict[str, Any]] = []
    search_accuracy_at_1 = 0.0
    search_recall_at_5 = 0.0
    search_precision_at_5 = 0.0
    search_mrr = 0.0
    search_keyword_recall = 0.0

    for case in dataset["search_cases"]:
        result = await memory_service.search(case["query"], mode="search", respect_filter=False, limit=5)
        joined = _join_hit_content(result)
        rank = _first_relevant_rank(result, case["expected_keywords"], case.get("minimum_keyword_hits", len(case["expected_keywords"])))
        relevant_hits = sum(
            1
            for hit in result.hits[:5]
            if _keyword_hits(_hit_blob(hit), case["expected_keywords"]) >= max(1, int(case.get("minimum_keyword_hits", len(case["expected_keywords"]))))
        )
        keyword_recall = _keyword_recall(joined, case["expected_keywords"])
        search_accuracy_at_1 += 1.0 if rank == 1 else 0.0
        search_recall_at_5 += 1.0 if rank > 0 else 0.0
        search_precision_at_5 += relevant_hits / float(max(1, min(5, len(result.hits))))
        search_mrr += 1.0 / float(rank) if rank > 0 else 0.0
        search_keyword_recall += keyword_recall
        search_case_reports.append(
            {
                "query": case["query"],
                "rank_of_first_relevant": rank,
                "relevant_hits_top5": relevant_hits,
                "keyword_recall_top5": keyword_recall,
                "top_hit": result.hits[0].to_dict() if result.hits else None,
            }
        )

    search_total = max(1, len(dataset["search_cases"]))

    writeback_reports: List[Dict[str, Any]] = []
    writeback_success_rate = 0.0
    writeback_keyword_recall = 0.0
    for payload in dataset["person_writebacks"]:
        query = " ".join(payload["expected_keywords"])
        result = await memory_service.search(
            query,
            mode="search",
            chat_id=session_id,
            person_id=payload["person_id"],
            respect_filter=False,
            limit=5,
        )
        joined = _join_hit_content(result)
        recall = _keyword_recall(joined, payload["expected_keywords"])
        success = bool(result.hits) and recall >= 0.67
        writeback_success_rate += 1.0 if success else 0.0
        writeback_keyword_recall += recall
        writeback_reports.append(
            {
                "person_id": payload["person_id"],
                "success": success,
                "keyword_recall": recall,
                "hit_count": len(result.hits),
            }
        )
    writeback_total = max(1, len(dataset["person_writebacks"]))

    knowledge_reports: List[Dict[str, Any]] = []
    knowledge_success_rate = 0.0
    knowledge_keyword_recall = 0.0
    fetcher = knowledge_module.KnowledgeFetcher(
        private_name=dataset["session"]["display_name"],
        stream_id=session_id,
    )
    for case in dataset["knowledge_fetcher_cases"]:
        knowledge_text, _ = await fetcher.fetch(case["query"], [])
        recall = _keyword_recall(knowledge_text, case["expected_keywords"])
        success = recall >= float(case.get("minimum_keyword_recall", 1.0))
        knowledge_success_rate += 1.0 if success else 0.0
        knowledge_keyword_recall += recall
        knowledge_reports.append(
            {
                "query": case["query"],
                "success": success,
                "keyword_recall": recall,
                "preview": knowledge_text[:300],
            }
        )
    knowledge_total = max(1, len(dataset["knowledge_fetcher_cases"]))

    profile_reports: List[Dict[str, Any]] = []
    profile_success_rate = 0.0
    profile_keyword_recall = 0.0
    profile_evidence_rate = 0.0
    for case in dataset["profile_cases"]:
        profile = await memory_service.get_person_profile(case["person_id"], chat_id=session_id)
        recall = _keyword_recall(profile.summary, case["expected_keywords"])
        has_evidence = bool(profile.evidence)
        success = recall >= float(case.get("minimum_keyword_recall", 1.0)) and has_evidence
        profile_success_rate += 1.0 if success else 0.0
        profile_keyword_recall += recall
        profile_evidence_rate += 1.0 if has_evidence else 0.0
        profile_reports.append(
            {
                "person_id": case["person_id"],
                "success": success,
                "keyword_recall": recall,
                "evidence_count": len(profile.evidence),
                "summary_preview": profile.summary[:240],
            }
        )
    profile_total = max(1, len(dataset["profile_cases"]))

    episode_generation_auto = await _evaluate_episode_generation(session_id=session_id, episode_cases=dataset["episode_cases"])
    episode_admin_query_auto = await _evaluate_episode_admin_query(session_id=session_id, episode_cases=dataset["episode_cases"])
    episode_search_mode_auto = await _evaluate_episode_search_mode(session_id=session_id, episode_cases=dataset["episode_cases"])
    episode_rebuild = await memory_service.episode_admin(
        action="rebuild",
        source=f"chat_summary:{session_id}",
    )
    episode_generation_after_rebuild = await _evaluate_episode_generation(session_id=session_id, episode_cases=dataset["episode_cases"])
    episode_admin_query_after_rebuild = await _evaluate_episode_admin_query(session_id=session_id, episode_cases=dataset["episode_cases"])
    episode_search_mode_after_rebuild = await _evaluate_episode_search_mode(session_id=session_id, episode_cases=dataset["episode_cases"])
    tool_modes = await _evaluate_tool_modes(session_id=session_id, dataset=dataset)

    report = {
        "dataset": dataset["meta"],
        "import": {
            "task_id": created["task"]["task_id"],
            "status": import_detail["task"]["status"],
            "paragraph_count": len(dataset["import_payload"]["paragraphs"]),
        },
        "metrics": {
            "search": {
                "accuracy_at_1": round(search_accuracy_at_1 / search_total, 4),
                "recall_at_5": round(search_recall_at_5 / search_total, 4),
                "precision_at_5": round(search_precision_at_5 / search_total, 4),
                "mrr": round(search_mrr / search_total, 4),
                "keyword_recall_at_5": round(search_keyword_recall / search_total, 4),
            },
            "writeback": {
                "success_rate": round(writeback_success_rate / writeback_total, 4),
                "keyword_recall": round(writeback_keyword_recall / writeback_total, 4),
            },
            "knowledge_fetcher": {
                "success_rate": round(knowledge_success_rate / knowledge_total, 4),
                "keyword_recall": round(knowledge_keyword_recall / knowledge_total, 4),
            },
            "profile": {
                "success_rate": round(profile_success_rate / profile_total, 4),
                "keyword_recall": round(profile_keyword_recall / profile_total, 4),
                "evidence_rate": round(profile_evidence_rate / profile_total, 4),
            },
            "tool_modes": {
                "success_rate": tool_modes["success_rate"],
                "keyword_recall": tool_modes["keyword_recall"],
            },
            "episode_generation_auto": {
                "success_rate": episode_generation_auto["success_rate"],
                "keyword_recall": episode_generation_auto["keyword_recall"],
                "episode_count": episode_generation_auto["episode_count"],
            },
            "episode_generation_after_rebuild": {
                "success_rate": episode_generation_after_rebuild["success_rate"],
                "keyword_recall": episode_generation_after_rebuild["keyword_recall"],
                "episode_count": episode_generation_after_rebuild["episode_count"],
                "rebuild_success": bool(episode_rebuild.get("success", False)),
            },
            "episode_admin_query_auto": {
                "success_rate": episode_admin_query_auto["success_rate"],
                "keyword_recall": episode_admin_query_auto["keyword_recall"],
            },
            "episode_admin_query_after_rebuild": {
                "success_rate": episode_admin_query_after_rebuild["success_rate"],
                "keyword_recall": episode_admin_query_after_rebuild["keyword_recall"],
                "rebuild_success": bool(episode_rebuild.get("success", False)),
            },
            "episode_search_mode_auto": {
                "success_rate": episode_search_mode_auto["success_rate"],
                "keyword_recall": episode_search_mode_auto["keyword_recall"],
            },
            "episode_search_mode_after_rebuild": {
                "success_rate": episode_search_mode_after_rebuild["success_rate"],
                "keyword_recall": episode_search_mode_after_rebuild["keyword_recall"],
                "rebuild_success": bool(episode_rebuild.get("success", False)),
            },
        },
        "cases": {
            "search": search_case_reports,
            "writeback": writeback_reports,
            "knowledge_fetcher": knowledge_reports,
            "profile": profile_reports,
            "tool_modes": tool_modes["reports"],
            "episode_generation_auto": episode_generation_auto["reports"],
            "episode_generation_after_rebuild": episode_generation_after_rebuild["reports"],
            "episode_admin_query_auto": episode_admin_query_auto["reports"],
            "episode_admin_query_after_rebuild": episode_admin_query_after_rebuild["reports"],
            "episode_search_mode_auto": episode_search_mode_auto["reports"],
            "episode_search_mode_after_rebuild": episode_search_mode_after_rebuild["reports"],
        },
    }

    REPORT_FILE.parent.mkdir(parents=True, exist_ok=True)
    REPORT_FILE.write_text(json.dumps(report, ensure_ascii=False, indent=2), encoding="utf-8")
    print(json.dumps(report["metrics"], ensure_ascii=False, indent=2))

    assert report["import"]["status"] == "completed"
    assert report["metrics"]["search"]["accuracy_at_1"] >= 0.35
    assert report["metrics"]["search"]["recall_at_5"] >= 0.6
    assert report["metrics"]["search"]["keyword_recall_at_5"] >= 0.8
    assert report["metrics"]["writeback"]["success_rate"] >= 0.66
    assert report["metrics"]["knowledge_fetcher"]["success_rate"] >= 0.66
    assert report["metrics"]["knowledge_fetcher"]["keyword_recall"] >= 0.75
    assert report["metrics"]["profile"]["success_rate"] >= 0.66
    assert report["metrics"]["profile"]["evidence_rate"] >= 1.0
    assert report["metrics"]["tool_modes"]["success_rate"] >= 0.75
    assert report["metrics"]["episode_generation_after_rebuild"]["rebuild_success"] is True
    assert report["metrics"]["episode_generation_after_rebuild"]["episode_count"] >= report["metrics"]["episode_generation_auto"]["episode_count"]
@@ -1,342 +0,0 @@
from __future__ import annotations

import json
import os
from pathlib import Path
from types import SimpleNamespace
from typing import Any, Dict, List

import pytest
import pytest_asyncio

from A_memorix.core.runtime.sdk_memory_kernel import SDKMemoryKernel
from pytests.A_memorix_test.test_long_novel_memory_benchmark import (
    _evaluate_episode_admin_query,
    _evaluate_episode_generation,
    _evaluate_episode_search_mode,
    _evaluate_tool_modes,
    _KernelBackedRuntimeManager,
    _KnownPerson,
    _first_relevant_rank,
    _hit_blob,
    _join_hit_content,
    _keyword_hits,
    _keyword_recall,
    _load_benchmark_fixture,
    _wait_for_import_task,
)
from src.chat.brain_chat.PFC import pfc_KnowledgeFetcher as knowledge_module
from src.memory_system import chat_history_summarizer as summarizer_module
from src.person_info import person_info as person_info_module
from src.services import memory_service as memory_service_module
from src.services.memory_service import memory_service


pytestmark = pytest.mark.skipif(
    os.getenv("MAIBOT_RUN_LIVE_MEMORY_TESTS") != "1",
    reason="需要显式开启真实 external embedding benchmark",
)

REPORT_FILE = Path(__file__).parent / "data" / "benchmarks" / "results" / "long_novel_memory_benchmark_live_report.json"


@pytest_asyncio.fixture
async def benchmark_live_env(monkeypatch, tmp_path):
    dataset = _load_benchmark_fixture()
    session_cfg = dataset["session"]
    session = SimpleNamespace(
        session_id=session_cfg["session_id"],
        platform=session_cfg["platform"],
        user_id=session_cfg["user_id"],
        group_id=session_cfg["group_id"],
    )
    fake_chat_manager = SimpleNamespace(
        get_session_by_session_id=lambda session_id: session if session_id == session.session_id else None,
        get_session_name=lambda session_id: session_cfg["display_name"] if session_id == session.session_id else session_id,
    )

    registry = {item["person_name"]: item["person_id"] for item in dataset["person_writebacks"]}
    reverse_registry = {value: key for key, value in registry.items()}

    monkeypatch.setattr(summarizer_module, "_chat_manager", fake_chat_manager)
    monkeypatch.setattr(knowledge_module, "_chat_manager", fake_chat_manager)
    monkeypatch.setattr(person_info_module, "_chat_manager", fake_chat_manager)
    monkeypatch.setattr(person_info_module, "get_person_id_by_person_name", lambda person_name: registry.get(str(person_name or "").strip(), ""))
    monkeypatch.setattr(
        person_info_module,
        "Person",
        lambda person_id: _KnownPerson(person_id=str(person_id or ""), registry=registry, reverse_registry=reverse_registry),
    )

    data_dir = (tmp_path / "a_memorix_live_benchmark_data").resolve()
    kernel = SDKMemoryKernel(
        plugin_root=tmp_path / "plugin_root",
        config={
            "storage": {"data_dir": str(data_dir)},
            "advanced": {"enable_auto_save": False},
            "memory": {"base_decay_interval_hours": 24},
            "person_profile": {"refresh_interval_minutes": 5},
        },
    )
    manager = _KernelBackedRuntimeManager(kernel)
    monkeypatch.setattr(memory_service_module, "a_memorix_host_service", manager)

    await kernel.initialize()
    try:
        yield {
            "dataset": dataset,
            "kernel": kernel,
            "session": session,
        }
    finally:
        await kernel.shutdown()


@pytest.mark.asyncio
async def test_long_novel_memory_benchmark_live(benchmark_live_env):
    dataset = benchmark_live_env["dataset"]
    session_id = benchmark_live_env["session"].session_id

    self_check = await memory_service.runtime_admin(action="refresh_self_check")
    assert self_check["success"] is True
    assert self_check["report"]["ok"] is True

    created = await memory_service.import_admin(
        action="create_paste",
        name="long_novel_memory_benchmark.live.json",
        input_mode="json",
        llm_enabled=False,
        content=json.dumps(dataset["import_payload"], ensure_ascii=False),
    )
    assert created["success"] is True

    import_detail = await _wait_for_import_task(
        created["task"]["task_id"],
        max_rounds=2400,
        sleep_seconds=0.25,
    )
    assert import_detail["task"]["status"] == "completed"

    for record in dataset["chat_history_records"]:
        summarizer = summarizer_module.ChatHistorySummarizer(session_id)
        await summarizer._import_to_long_term_memory(
            record_id=record["record_id"],
            theme=record["theme"],
            summary=record["summary"],
            participants=record["participants"],
            start_time=record["start_time"],
            end_time=record["end_time"],
            original_text=record["original_text"],
        )

    for payload in dataset["person_writebacks"]:
        await person_info_module.store_person_memory_from_answer(
            payload["person_name"],
            payload["memory_content"],
            session_id,
        )

    await memory_service.episode_admin(action="process_pending", limit=100, max_retry=2)

    search_case_reports: List[Dict[str, Any]] = []
    search_accuracy_at_1 = 0.0
    search_recall_at_5 = 0.0
    search_precision_at_5 = 0.0
    search_mrr = 0.0
    search_keyword_recall = 0.0
    for case in dataset["search_cases"]:
        result = await memory_service.search(case["query"], mode="search", respect_filter=False, limit=5)
        joined = _join_hit_content(result)
        rank = _first_relevant_rank(result, case["expected_keywords"], case.get("minimum_keyword_hits", len(case["expected_keywords"])))
        relevant_hits = sum(
            1
            for hit in result.hits[:5]
            if _keyword_hits(_hit_blob(hit), case["expected_keywords"]) >= max(1, int(case.get("minimum_keyword_hits", len(case["expected_keywords"]))))
        )
        keyword_recall = _keyword_recall(joined, case["expected_keywords"])
        search_accuracy_at_1 += 1.0 if rank == 1 else 0.0
        search_recall_at_5 += 1.0 if rank > 0 else 0.0
        search_precision_at_5 += relevant_hits / float(max(1, min(5, len(result.hits))))
        search_mrr += 1.0 / float(rank) if rank > 0 else 0.0
        search_keyword_recall += keyword_recall
        search_case_reports.append(
            {
                "query": case["query"],
                "rank_of_first_relevant": rank,
                "relevant_hits_top5": relevant_hits,
                "keyword_recall_top5": keyword_recall,
                "top_hit": result.hits[0].to_dict() if result.hits else None,
            }
        )
    search_total = max(1, len(dataset["search_cases"]))

    writeback_reports: List[Dict[str, Any]] = []
    writeback_success_rate = 0.0
    writeback_keyword_recall = 0.0
    for payload in dataset["person_writebacks"]:
        query = " ".join(payload["expected_keywords"])
        result = await memory_service.search(
            query,
            mode="search",
            chat_id=session_id,
            person_id=payload["person_id"],
            respect_filter=False,
            limit=5,
        )
        joined = _join_hit_content(result)
        recall = _keyword_recall(joined, payload["expected_keywords"])
        success = bool(result.hits) and recall >= 0.67
        writeback_success_rate += 1.0 if success else 0.0
        writeback_keyword_recall += recall
        writeback_reports.append(
            {
                "person_id": payload["person_id"],
                "success": success,
                "keyword_recall": recall,
                "hit_count": len(result.hits),
            }
        )
    writeback_total = max(1, len(dataset["person_writebacks"]))

    knowledge_reports: List[Dict[str, Any]] = []
    knowledge_success_rate = 0.0
    knowledge_keyword_recall = 0.0
    fetcher = knowledge_module.KnowledgeFetcher(
        private_name=dataset["session"]["display_name"],
        stream_id=session_id,
    )
    for case in dataset["knowledge_fetcher_cases"]:
        knowledge_text, _ = await fetcher.fetch(case["query"], [])
        recall = _keyword_recall(knowledge_text, case["expected_keywords"])
        success = recall >= float(case.get("minimum_keyword_recall", 1.0))
        knowledge_success_rate += 1.0 if success else 0.0
        knowledge_keyword_recall += recall
        knowledge_reports.append(
            {
                "query": case["query"],
                "success": success,
                "keyword_recall": recall,
                "preview": knowledge_text[:300],
            }
        )
    knowledge_total = max(1, len(dataset["knowledge_fetcher_cases"]))

    profile_reports: List[Dict[str, Any]] = []
    profile_success_rate = 0.0
    profile_keyword_recall = 0.0
    profile_evidence_rate = 0.0
    for case in dataset["profile_cases"]:
        profile = await memory_service.get_person_profile(case["person_id"], chat_id=session_id)
        recall = _keyword_recall(profile.summary, case["expected_keywords"])
        has_evidence = bool(profile.evidence)
        success = recall >= float(case.get("minimum_keyword_recall", 1.0)) and has_evidence
        profile_success_rate += 1.0 if success else 0.0
        profile_keyword_recall += recall
        profile_evidence_rate += 1.0 if has_evidence else 0.0
        profile_reports.append(
            {
                "person_id": case["person_id"],
                "success": success,
                "keyword_recall": recall,
                "evidence_count": len(profile.evidence),
                "summary_preview": profile.summary[:240],
            }
        )
    profile_total = max(1, len(dataset["profile_cases"]))

    episode_generation_auto = await _evaluate_episode_generation(session_id=session_id, episode_cases=dataset["episode_cases"])
    episode_admin_query_auto = await _evaluate_episode_admin_query(session_id=session_id, episode_cases=dataset["episode_cases"])
    episode_search_mode_auto = await _evaluate_episode_search_mode(session_id=session_id, episode_cases=dataset["episode_cases"])
    episode_rebuild = await memory_service.episode_admin(
        action="rebuild",
        source=f"chat_summary:{session_id}",
    )
    episode_generation_after_rebuild = await _evaluate_episode_generation(session_id=session_id, episode_cases=dataset["episode_cases"])
    episode_admin_query_after_rebuild = await _evaluate_episode_admin_query(session_id=session_id, episode_cases=dataset["episode_cases"])
    episode_search_mode_after_rebuild = await _evaluate_episode_search_mode(session_id=session_id, episode_cases=dataset["episode_cases"])
    tool_modes = await _evaluate_tool_modes(session_id=session_id, dataset=dataset)

    report = {
        "dataset": dataset["meta"],
        "runtime_self_check": self_check["report"],
        "import": {
            "task_id": created["task"]["task_id"],
            "status": import_detail["task"]["status"],
            "paragraph_count": len(dataset["import_payload"]["paragraphs"]),
        },
        "metrics": {
            "search": {
                "accuracy_at_1": round(search_accuracy_at_1 / search_total, 4),
                "recall_at_5": round(search_recall_at_5 / search_total, 4),
                "precision_at_5": round(search_precision_at_5 / search_total, 4),
                "mrr": round(search_mrr / search_total, 4),
                "keyword_recall_at_5": round(search_keyword_recall / search_total, 4),
            },
            "writeback": {
                "success_rate": round(writeback_success_rate / writeback_total, 4),
                "keyword_recall": round(writeback_keyword_recall / writeback_total, 4),
            },
            "knowledge_fetcher": {
                "success_rate": round(knowledge_success_rate / knowledge_total, 4),
                "keyword_recall": round(knowledge_keyword_recall / knowledge_total, 4),
            },
            "profile": {
                "success_rate": round(profile_success_rate / profile_total, 4),
                "keyword_recall": round(profile_keyword_recall / profile_total, 4),
                "evidence_rate": round(profile_evidence_rate / profile_total, 4),
            },
            "tool_modes": {
                "success_rate": tool_modes["success_rate"],
                "keyword_recall": tool_modes["keyword_recall"],
            },
            "episode_generation_auto": {
                "success_rate": episode_generation_auto["success_rate"],
                "keyword_recall": episode_generation_auto["keyword_recall"],
                "episode_count": episode_generation_auto["episode_count"],
            },
            "episode_generation_after_rebuild": {
                "success_rate": episode_generation_after_rebuild["success_rate"],
                "keyword_recall": episode_generation_after_rebuild["keyword_recall"],
                "episode_count": episode_generation_after_rebuild["episode_count"],
                "rebuild_success": bool(episode_rebuild.get("success", False)),
            },
            "episode_admin_query_auto": {
                "success_rate": episode_admin_query_auto["success_rate"],
                "keyword_recall": episode_admin_query_auto["keyword_recall"],
            },
            "episode_admin_query_after_rebuild": {
                "success_rate": episode_admin_query_after_rebuild["success_rate"],
                "keyword_recall": episode_admin_query_after_rebuild["keyword_recall"],
                "rebuild_success": bool(episode_rebuild.get("success", False)),
            },
            "episode_search_mode_auto": {
                "success_rate": episode_search_mode_auto["success_rate"],
                "keyword_recall": episode_search_mode_auto["keyword_recall"],
            },
            "episode_search_mode_after_rebuild": {
                "success_rate": episode_search_mode_after_rebuild["success_rate"],
                "keyword_recall": episode_search_mode_after_rebuild["keyword_recall"],
                "rebuild_success": bool(episode_rebuild.get("success", False)),
            },
        },
        "cases": {
            "search": search_case_reports,
            "writeback": writeback_reports,
            "knowledge_fetcher": knowledge_reports,
            "profile": profile_reports,
            "tool_modes": tool_modes["reports"],
            "episode_generation_auto": episode_generation_auto["reports"],
            "episode_generation_after_rebuild": episode_generation_after_rebuild["reports"],
            "episode_admin_query_auto": episode_admin_query_auto["reports"],
            "episode_admin_query_after_rebuild": episode_admin_query_after_rebuild["reports"],
            "episode_search_mode_auto": episode_search_mode_auto["reports"],
            "episode_search_mode_after_rebuild": episode_search_mode_after_rebuild["reports"],
        },
    }

    REPORT_FILE.parent.mkdir(parents=True, exist_ok=True)
    REPORT_FILE.write_text(json.dumps(report, ensure_ascii=False, indent=2), encoding="utf-8")
    print(json.dumps(report["metrics"], ensure_ascii=False, indent=2))

    assert report["import"]["status"] == "completed"
    assert report["runtime_self_check"]["ok"] is True
@@ -5,68 +5,6 @@ import pytest
from src.services import memory_flow_service as memory_flow_module


@pytest.mark.asyncio
async def test_long_term_memory_session_manager_reuses_single_summarizer(monkeypatch):
    starts: list[str] = []
    summarizers: list[object] = []

    class FakeSummarizer:
        def __init__(self, session_id: str):
            self.session_id = session_id
            summarizers.append(self)

        async def start(self):
            starts.append(self.session_id)

        async def stop(self):
            starts.append(f"stop:{self.session_id}")

    monkeypatch.setattr(
        memory_flow_module,
        "global_config",
        SimpleNamespace(memory=SimpleNamespace(long_term_auto_summary_enabled=True)),
    )
    monkeypatch.setattr(memory_flow_module, "ChatHistorySummarizer", FakeSummarizer)

    manager = memory_flow_module.LongTermMemorySessionManager()
    message = SimpleNamespace(session_id="session-1")

    await manager.on_message(message)
    await manager.on_message(message)

    assert len(summarizers) == 1
    assert starts == ["session-1"]


@pytest.mark.asyncio
async def test_long_term_memory_session_manager_shutdown_stops_all(monkeypatch):
    stopped: list[str] = []

    class FakeSummarizer:
        def __init__(self, session_id: str):
            self.session_id = session_id

        async def start(self):
            return None

        async def stop(self):
            stopped.append(self.session_id)

    monkeypatch.setattr(
        memory_flow_module,
        "global_config",
        SimpleNamespace(memory=SimpleNamespace(long_term_auto_summary_enabled=True)),
    )
    monkeypatch.setattr(memory_flow_module, "ChatHistorySummarizer", FakeSummarizer)

    manager = memory_flow_module.LongTermMemorySessionManager()
    await manager.on_message(SimpleNamespace(session_id="session-a"))
    await manager.on_message(SimpleNamespace(session_id="session-b"))
    await manager.shutdown()

    assert stopped == ["session-a", "session-b"]


def test_person_fact_parse_fact_list_deduplicates_and_filters_short_items():
    raw = '["他喜欢猫", "他喜欢猫", "好", "", "他会弹吉他"]'

@@ -101,16 +39,9 @@ def test_person_fact_resolve_target_person_for_private_chat(monkeypatch):

@pytest.mark.asyncio
-async def test_memory_automation_service_auto_starts_and_delegates(monkeypatch):
+async def test_memory_automation_service_auto_starts_and_delegates():
    events: list[tuple[str, str]] = []

    class FakeSessionManager:
        async def on_message(self, message):
            events.append(("incoming", message.session_id))

        async def shutdown(self):
            events.append(("shutdown", "session"))

    class FakeFactWriteback:
        async def start(self):
            events.append(("start", "fact"))
@@ -122,17 +53,13 @@ async def test_memory_automation_service_auto_starts_and_delegates(monkeypatch):
            events.append(("shutdown", "fact"))

    service = memory_flow_module.MemoryAutomationService()
    service.session_manager = FakeSessionManager()
    service.fact_writeback = FakeFactWriteback()

    await service.on_incoming_message(SimpleNamespace(session_id="session-1"))
    await service.on_message_sent(SimpleNamespace(session_id="session-1"))
    await service.shutdown()

    assert events == [
        ("start", "fact"),
        ("incoming", "session-1"),
        ("sent", "session-1"),
        ("shutdown", "session"),
        ("shutdown", "fact"),
    ]
@@ -1,324 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import json
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, Dict
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from A_memorix.core.runtime import sdk_memory_kernel as kernel_module
|
||||
from A_memorix.core.runtime.sdk_memory_kernel import KernelSearchRequest, SDKMemoryKernel
|
||||
from src.chat.brain_chat.PFC import pfc_KnowledgeFetcher as knowledge_module
|
||||
from src.memory_system import chat_history_summarizer as summarizer_module
|
||||
from src.person_info import person_info as person_info_module
|
||||
from src.services import memory_service as memory_service_module
|
||||
from src.services.memory_service import memory_service
|
||||
|
||||
|
||||
DATA_FILE = Path(__file__).parent / "data" / "real_dialogues" / "private_alice_weekend.json"
|
||||
|
||||
|
||||
def _load_dialogue_fixture() -> Dict[str, Any]:
|
||||
return json.loads(DATA_FILE.read_text(encoding="utf-8"))
|
||||
|
||||
|
||||
class _FakeEmbeddingAdapter:
|
||||
def __init__(self, dimension: int = 16) -> None:
|
||||
self.dimension = dimension
|
||||
|
||||
async def _detect_dimension(self) -> int:
|
||||
return self.dimension
|
||||
|
||||
async def encode(self, texts, dimensions=None):
|
||||
dim = int(dimensions or self.dimension)
|
||||
if isinstance(texts, str):
|
||||
sequence = [texts]
|
||||
single = True
|
||||
else:
|
||||
sequence = list(texts)
|
||||
single = False
|
||||
|
||||
rows = []
|
||||
for text in sequence:
|
||||
vec = np.zeros(dim, dtype=np.float32)
|
||||
for ch in str(text or ""):
|
||||
vec[ord(ch) % dim] += 1.0
|
||||
if not vec.any():
|
||||
vec[0] = 1.0
|
||||
norm = np.linalg.norm(vec)
|
||||
if norm > 0:
|
||||
vec = vec / norm
|
||||
rows.append(vec)
|
||||
payload = np.vstack(rows)
|
||||
return payload[0] if single else payload
|
||||
|
||||
|
||||
class _KernelBackedRuntimeManager:
|
||||
def __init__(self, kernel: SDKMemoryKernel) -> None:
|
||||
self.kernel = kernel
|
||||
|
||||
async def invoke(self, component_name: str, args: Dict[str, Any] | None, *, timeout_ms: int = 30000):
|
||||
del timeout_ms
|
||||
payload = args or {}
|
||||
if component_name == "search_memory":
|
||||
return await self.kernel.search_memory(
|
||||
KernelSearchRequest(
|
||||
query=str(payload.get("query", "") or ""),
|
||||
limit=int(payload.get("limit", 5) or 5),
|
||||
mode=str(payload.get("mode", "hybrid") or "hybrid"),
|
||||
chat_id=str(payload.get("chat_id", "") or ""),
|
||||
person_id=str(payload.get("person_id", "") or ""),
|
||||
time_start=payload.get("time_start"),
|
||||
time_end=payload.get("time_end"),
|
||||
respect_filter=bool(payload.get("respect_filter", True)),
|
||||
user_id=str(payload.get("user_id", "") or ""),
|
||||
group_id=str(payload.get("group_id", "") or ""),
|
||||
)
|
||||
)
|
||||
|
||||
handler = getattr(self.kernel, component_name)
|
||||
result = handler(**payload)
|
||||
return await result if inspect.isawaitable(result) else result
|
||||
|
||||
|
||||
async def _wait_for_import_task(task_id: str, *, max_rounds: int = 100) -> Dict[str, Any]:
|
||||
for _ in range(max_rounds):
|
||||
detail = await memory_service.import_admin(action="get", task_id=task_id, include_chunks=True)
|
||||
task = detail.get("task") or {}
|
||||
status = str(task.get("status", "") or "")
|
||||
if status in {"completed", "completed_with_errors", "failed", "cancelled"}:
|
||||
return detail
|
||||
await asyncio.sleep(0.05)
|
||||
raise AssertionError(f"导入任务在等待窗口内未结束: {task_id}")
|
||||
|
||||
|
||||
def _join_hit_content(search_result) -> str:
|
||||
return "\n".join(hit.content for hit in search_result.hits)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def real_dialogue_env(monkeypatch, tmp_path):
    scenario = _load_dialogue_fixture()
    session_cfg = scenario["session"]
    session = SimpleNamespace(
        session_id=session_cfg["session_id"],
        platform=session_cfg["platform"],
        user_id=session_cfg["user_id"],
        group_id=session_cfg["group_id"],
    )
    fake_chat_manager = SimpleNamespace(
        get_session_by_session_id=lambda session_id: session if session_id == session.session_id else None,
        get_session_name=lambda session_id: session_cfg["display_name"] if session_id == session.session_id else session_id,
    )

    monkeypatch.setattr(kernel_module, "create_embedding_api_adapter", lambda **kwargs: _FakeEmbeddingAdapter())

    async def fake_self_check(**kwargs):
        return {"ok": True, "message": "ok"}

    monkeypatch.setattr(kernel_module, "run_embedding_runtime_self_check", fake_self_check)
    monkeypatch.setattr(summarizer_module, "_chat_manager", fake_chat_manager)
    monkeypatch.setattr(knowledge_module, "_chat_manager", fake_chat_manager)
    monkeypatch.setattr(person_info_module, "_chat_manager", fake_chat_manager)

    data_dir = (tmp_path / "a_memorix_data").resolve()
    kernel = SDKMemoryKernel(
        plugin_root=tmp_path / "plugin_root",
        config={
            "storage": {"data_dir": str(data_dir)},
            "advanced": {"enable_auto_save": False},
            "memory": {"base_decay_interval_hours": 24},
            "person_profile": {"refresh_interval_minutes": 5},
        },
    )
    manager = _KernelBackedRuntimeManager(kernel)
    monkeypatch.setattr(memory_service_module, "a_memorix_host_service", manager)

    await kernel.initialize()
    try:
        yield {
            "scenario": scenario,
            "kernel": kernel,
            "session": session,
        }
    finally:
        await kernel.shutdown()


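# End-to-end import flow: paste-import the fixture payload, wait for the task
# to finish, then confirm the expected keywords surface via unfiltered search.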
@pytest.mark.asyncio
async def test_real_dialogue_import_flow_makes_fixture_searchable(real_dialogue_env):
    scenario = real_dialogue_env["scenario"]

    created = await memory_service.import_admin(
        action="create_paste",
        name="private_alice.json",
        input_mode="json",
        llm_enabled=False,
        content=json.dumps(scenario["import_payload"], ensure_ascii=False),
    )

    assert created["success"] is True
    detail = await _wait_for_import_task(created["task"]["task_id"])
    assert detail["task"]["status"] == "completed"

    search = await memory_service.search(
        scenario["search_queries"]["direct"],
        mode="search",
        respect_filter=False,
    )

    assert search.hits
    joined = _join_hit_content(search)
    for keyword in scenario["expectations"]["search_keywords"]:
        assert keyword in joined


@pytest.mark.asyncio
async def test_real_dialogue_summarizer_flow_persists_summary_to_long_term_memory(real_dialogue_env):
    scenario = real_dialogue_env["scenario"]
    record = scenario["chat_history_record"]

    summarizer = summarizer_module.ChatHistorySummarizer(real_dialogue_env["session"].session_id)
    await summarizer._import_to_long_term_memory(
        record_id=record["record_id"],
        theme=record["theme"],
        summary=record["summary"],
        participants=record["participants"],
        start_time=record["start_time"],
        end_time=record["end_time"],
        original_text=record["original_text"],
    )

    search = await memory_service.search(
        scenario["search_queries"]["direct"],
        mode="search",
        chat_id=real_dialogue_env["session"].session_id,
    )

    assert search.hits
    joined = _join_hit_content(search)
    for keyword in scenario["expectations"]["search_keywords"]:
        assert keyword in joined


@pytest.mark.asyncio
async def test_real_dialogue_person_fact_writeback_is_searchable(real_dialogue_env, monkeypatch):
    scenario = real_dialogue_env["scenario"]

    class _KnownPerson:
        def __init__(self, person_id: str) -> None:
            self.person_id = person_id
            self.is_known = True
            self.person_name = scenario["person"]["person_name"]

    monkeypatch.setattr(
        person_info_module,
        "get_person_id_by_person_name",
        lambda person_name: scenario["person"]["person_id"],
    )
    monkeypatch.setattr(person_info_module, "Person", _KnownPerson)

    await person_info_module.store_person_memory_from_answer(
        scenario["person"]["person_name"],
        scenario["person_fact"]["memory_content"],
        real_dialogue_env["session"].session_id,
    )

    search = await memory_service.search(
        scenario["search_queries"]["direct"],
        mode="search",
        chat_id=real_dialogue_env["session"].session_id,
        person_id=scenario["person"]["person_id"],
    )

    assert search.hits
    joined = _join_hit_content(search)
    for keyword in scenario["expectations"]["search_keywords"]:
        assert keyword in joined


@pytest.mark.asyncio
async def test_real_dialogue_private_knowledge_fetcher_reads_long_term_memory(real_dialogue_env):
    scenario = real_dialogue_env["scenario"]

    await memory_service.ingest_text(
        external_id="fixture:knowledge_fetcher",
        source_type="dialogue_note",
        text=scenario["person_fact"]["memory_content"],
        chat_id=real_dialogue_env["session"].session_id,
        person_ids=[scenario["person"]["person_id"]],
        participants=[scenario["person"]["person_name"]],
        respect_filter=False,
    )

    fetcher = knowledge_module.KnowledgeFetcher(
        private_name=scenario["session"]["display_name"],
        stream_id=real_dialogue_env["session"].session_id,
    )
    knowledge_text, _ = await fetcher.fetch(scenario["search_queries"]["knowledge_fetcher"], [])

    for keyword in scenario["expectations"]["search_keywords"]:
        assert keyword in knowledge_text


@pytest.mark.asyncio
async def test_real_dialogue_person_profile_contains_stable_traits(real_dialogue_env, monkeypatch):
    scenario = real_dialogue_env["scenario"]

    class _KnownPerson:
        def __init__(self, person_id: str) -> None:
            self.person_id = person_id
            self.is_known = True
            self.person_name = scenario["person"]["person_name"]

    monkeypatch.setattr(
        person_info_module,
        "get_person_id_by_person_name",
        lambda person_name: scenario["person"]["person_id"],
    )
    monkeypatch.setattr(person_info_module, "Person", _KnownPerson)

    await person_info_module.store_person_memory_from_answer(
        scenario["person"]["person_name"],
        scenario["person_fact"]["memory_content"],
        real_dialogue_env["session"].session_id,
    )

    profile = await memory_service.get_person_profile(
        scenario["person"]["person_id"],
        chat_id=real_dialogue_env["session"].session_id,
    )

    assert profile.evidence
    assert any(keyword in profile.summary for keyword in scenario["expectations"]["profile_keywords"])


@pytest.mark.asyncio
async def test_real_dialogue_summary_flow_generates_queryable_episode(real_dialogue_env):
    scenario = real_dialogue_env["scenario"]
    record = scenario["chat_history_record"]

    summarizer = summarizer_module.ChatHistorySummarizer(real_dialogue_env["session"].session_id)
    await summarizer._import_to_long_term_memory(
        record_id=record["record_id"],
        theme=record["theme"],
        summary=record["summary"],
        participants=record["participants"],
        start_time=record["start_time"],
        end_time=record["end_time"],
        original_text=record["original_text"],
    )

    episodes = await memory_service.episode_admin(
        action="query",
        source=scenario["expectations"]["episode_source"],
        limit=5,
    )

    assert episodes["success"] is True
    assert int(episodes["count"]) >= 1
@@ -1,301 +0,0 @@
from __future__ import annotations

import asyncio
import inspect
import json
import os
from pathlib import Path
from types import SimpleNamespace
from typing import Any, Dict

import pytest
import pytest_asyncio

from A_memorix.core.runtime.sdk_memory_kernel import KernelSearchRequest, SDKMemoryKernel
from src.chat.brain_chat.PFC import pfc_KnowledgeFetcher as knowledge_module
from src.memory_system import chat_history_summarizer as summarizer_module
from src.person_info import person_info as person_info_module
from src.services import memory_service as memory_service_module
from src.services.memory_service import memory_service


pytestmark = pytest.mark.skipif(
    os.getenv("MAIBOT_RUN_LIVE_MEMORY_TESTS") != "1",
    reason="需要显式开启真实 embedding / self-check 集成测试",
)

DATA_FILE = Path(__file__).parent / "data" / "real_dialogues" / "private_alice_weekend.json"


def _load_dialogue_fixture() -> Dict[str, Any]:
    return json.loads(DATA_FILE.read_text(encoding="utf-8"))


class _KernelBackedRuntimeManager:
    def __init__(self, kernel: SDKMemoryKernel) -> None:
        self.kernel = kernel

    async def invoke(self, component_name: str, args: Dict[str, Any] | None, *, timeout_ms: int = 30000):
        del timeout_ms
        payload = args or {}
        if component_name == "search_memory":
            return await self.kernel.search_memory(
                KernelSearchRequest(
                    query=str(payload.get("query", "") or ""),
                    limit=int(payload.get("limit", 5) or 5),
                    mode=str(payload.get("mode", "hybrid") or "hybrid"),
                    chat_id=str(payload.get("chat_id", "") or ""),
                    person_id=str(payload.get("person_id", "") or ""),
                    time_start=payload.get("time_start"),
                    time_end=payload.get("time_end"),
                    respect_filter=bool(payload.get("respect_filter", True)),
                    user_id=str(payload.get("user_id", "") or ""),
                    group_id=str(payload.get("group_id", "") or ""),
                )
            )

        handler = getattr(self.kernel, component_name)
        result = handler(**payload)
        return await result if inspect.isawaitable(result) else result


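# Live variant waits on a wall-clock deadline instead of a fixed round count,
# since real embedding calls make task duration unpredictable.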
async def _wait_for_import_task(task_id: str, *, timeout_seconds: float = 60.0) -> Dict[str, Any]:
    deadline = asyncio.get_running_loop().time() + max(1.0, float(timeout_seconds))
    while asyncio.get_running_loop().time() < deadline:
        detail = await memory_service.import_admin(action="get", task_id=task_id, include_chunks=True)
        task = detail.get("task") or {}
        status = str(task.get("status", "") or "")
        if status in {"completed", "completed_with_errors", "failed", "cancelled"}:
            return detail
        await asyncio.sleep(0.2)
    raise AssertionError(f"导入任务在等待窗口内未结束: {task_id}")


def _join_hit_content(search_result) -> str:
    return "\n".join(hit.content for hit in search_result.hits)


@pytest_asyncio.fixture
async def live_dialogue_env(monkeypatch, tmp_path):
    scenario = _load_dialogue_fixture()
    session_cfg = scenario["session"]
    session = SimpleNamespace(
        session_id=session_cfg["session_id"],
        platform=session_cfg["platform"],
        user_id=session_cfg["user_id"],
        group_id=session_cfg["group_id"],
    )
    fake_chat_manager = SimpleNamespace(
        get_session_by_session_id=lambda session_id: session if session_id == session.session_id else None,
        get_session_name=lambda session_id: session_cfg["display_name"] if session_id == session.session_id else session_id,
    )

    monkeypatch.setattr(summarizer_module, "_chat_manager", fake_chat_manager)
    monkeypatch.setattr(knowledge_module, "_chat_manager", fake_chat_manager)
    monkeypatch.setattr(person_info_module, "_chat_manager", fake_chat_manager)

    data_dir = (tmp_path / "a_memorix_data").resolve()
    kernel = SDKMemoryKernel(
        plugin_root=tmp_path / "plugin_root",
        config={
            "storage": {"data_dir": str(data_dir)},
            "advanced": {"enable_auto_save": False},
            "memory": {"base_decay_interval_hours": 24},
            "person_profile": {"refresh_interval_minutes": 5},
        },
    )
    manager = _KernelBackedRuntimeManager(kernel)
    monkeypatch.setattr(memory_service_module, "a_memorix_host_service", manager)

    await kernel.initialize()
    try:
        yield {
            "scenario": scenario,
            "kernel": kernel,
            "session": session,
        }
    finally:
        await kernel.shutdown()


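# The live suite exercises the real embedding runtime: self-check first, then
# the same import / summarize / write-back flows as the mocked suite.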
@pytest.mark.asyncio
async def test_live_runtime_self_check_passes(live_dialogue_env):
    report = await memory_service.runtime_admin(action="refresh_self_check")

    assert report["success"] is True
    assert report["report"]["ok"] is True
    assert report["report"]["encoded_dimension"] > 0


@pytest.mark.asyncio
async def test_live_import_flow_makes_fixture_searchable(live_dialogue_env):
    scenario = live_dialogue_env["scenario"]

    created = await memory_service.import_admin(
        action="create_paste",
        name="private_alice.json",
        input_mode="json",
        llm_enabled=False,
        content=json.dumps(scenario["import_payload"], ensure_ascii=False),
    )

    assert created["success"] is True
    detail = await _wait_for_import_task(created["task"]["task_id"])
    assert detail["task"]["status"] == "completed"

    search = await memory_service.search(
        scenario["search_queries"]["direct"],
        mode="search",
        respect_filter=False,
    )

    assert search.hits
    joined = _join_hit_content(search)
    for keyword in scenario["expectations"]["search_keywords"]:
        assert keyword in joined


@pytest.mark.asyncio
async def test_live_summarizer_flow_persists_summary_to_long_term_memory(live_dialogue_env):
    scenario = live_dialogue_env["scenario"]
    record = scenario["chat_history_record"]

    summarizer = summarizer_module.ChatHistorySummarizer(live_dialogue_env["session"].session_id)
    await summarizer._import_to_long_term_memory(
        record_id=record["record_id"],
        theme=record["theme"],
        summary=record["summary"],
        participants=record["participants"],
        start_time=record["start_time"],
        end_time=record["end_time"],
        original_text=record["original_text"],
    )

    search = await memory_service.search(
        scenario["search_queries"]["direct"],
        mode="search",
        chat_id=live_dialogue_env["session"].session_id,
    )

    assert search.hits
    joined = _join_hit_content(search)
    for keyword in scenario["expectations"]["search_keywords"]:
        assert keyword in joined


@pytest.mark.asyncio
async def test_live_person_fact_writeback_is_searchable(live_dialogue_env, monkeypatch):
    scenario = live_dialogue_env["scenario"]

    class _KnownPerson:
        def __init__(self, person_id: str) -> None:
            self.person_id = person_id
            self.is_known = True
            self.person_name = scenario["person"]["person_name"]

    monkeypatch.setattr(
        person_info_module,
        "get_person_id_by_person_name",
        lambda person_name: scenario["person"]["person_id"],
    )
    monkeypatch.setattr(person_info_module, "Person", _KnownPerson)

    await person_info_module.store_person_memory_from_answer(
        scenario["person"]["person_name"],
        scenario["person_fact"]["memory_content"],
        live_dialogue_env["session"].session_id,
    )

    search = await memory_service.search(
        scenario["search_queries"]["direct"],
        mode="search",
        chat_id=live_dialogue_env["session"].session_id,
        person_id=scenario["person"]["person_id"],
    )

    assert search.hits
    joined = _join_hit_content(search)
    for keyword in scenario["expectations"]["search_keywords"]:
        assert keyword in joined


@pytest.mark.asyncio
async def test_live_private_knowledge_fetcher_reads_long_term_memory(live_dialogue_env):
    scenario = live_dialogue_env["scenario"]

    await memory_service.ingest_text(
        external_id="fixture:knowledge_fetcher",
        source_type="dialogue_note",
        text=scenario["person_fact"]["memory_content"],
        chat_id=live_dialogue_env["session"].session_id,
        person_ids=[scenario["person"]["person_id"]],
        participants=[scenario["person"]["person_name"]],
        respect_filter=False,
    )

    fetcher = knowledge_module.KnowledgeFetcher(
        private_name=scenario["session"]["display_name"],
        stream_id=live_dialogue_env["session"].session_id,
    )
    knowledge_text, _ = await fetcher.fetch(scenario["search_queries"]["knowledge_fetcher"], [])

    for keyword in scenario["expectations"]["search_keywords"]:
        assert keyword in knowledge_text


@pytest.mark.asyncio
async def test_live_person_profile_contains_stable_traits(live_dialogue_env, monkeypatch):
    scenario = live_dialogue_env["scenario"]

    class _KnownPerson:
        def __init__(self, person_id: str) -> None:
            self.person_id = person_id
            self.is_known = True
            self.person_name = scenario["person"]["person_name"]

    monkeypatch.setattr(
        person_info_module,
        "get_person_id_by_person_name",
        lambda person_name: scenario["person"]["person_id"],
    )
    monkeypatch.setattr(person_info_module, "Person", _KnownPerson)

    await person_info_module.store_person_memory_from_answer(
        scenario["person"]["person_name"],
        scenario["person_fact"]["memory_content"],
        live_dialogue_env["session"].session_id,
    )

    profile = await memory_service.get_person_profile(
        scenario["person"]["person_id"],
        chat_id=live_dialogue_env["session"].session_id,
    )

    assert profile.evidence
    assert any(keyword in profile.summary for keyword in scenario["expectations"]["profile_keywords"])


@pytest.mark.asyncio
async def test_live_summary_flow_generates_queryable_episode(live_dialogue_env):
    scenario = live_dialogue_env["scenario"]
    record = scenario["chat_history_record"]

    summarizer = summarizer_module.ChatHistorySummarizer(live_dialogue_env["session"].session_id)
    await summarizer._import_to_long_term_memory(
        record_id=record["record_id"],
        theme=record["theme"],
        summary=record["summary"],
        participants=record["participants"],
        start_time=record["start_time"],
        end_time=record["end_time"],
        original_text=record["original_text"],
    )

    episodes = await memory_service.episode_admin(
        action="query",
        source=scenario["expectations"]["episode_source"],
        limit=5,
    )

    assert episodes["success"] is True
    assert int(episodes["count"]) >= 1
@@ -1,23 +0,0 @@
from src.maisaka.chat_loop_service import MaisakaChatLoopService


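# The log-text builder must accept both OpenAI function-wrapped schemas and
# bare name/description dicts within the same definitions list.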
def test_build_tool_names_log_text_supports_openai_function_schema() -> None:
    tool_definitions = [
        {
            "type": "function",
            "function": {
                "name": "mute_user",
                "description": "禁言指定用户",
                "parameters": {
                    "type": "object",
                    "properties": {},
                },
            },
        },
        {
            "name": "reply",
            "description": "发送回复",
        },
    ]

    assert MaisakaChatLoopService._build_tool_names_log_text(tool_definitions) == "mute_user、reply"
@@ -1,339 +0,0 @@
"""MutePlugin SDK regression tests."""

from __future__ import annotations

from pathlib import Path
from types import SimpleNamespace
from typing import Any, Dict, List

import pytest

from maibot_sdk.context import PluginContext
from maibot_sdk.plugin import MaiBotPlugin

from plugins.MutePlugin.plugin import create_plugin
from src.core.tooling import ToolExecutionContext, ToolInvocation
from src.plugin_runtime.component_query import ComponentQueryService
from src.plugin_runtime.runner.manifest_validator import ManifestValidator


def _build_plugin() -> MaiBotPlugin:
    """Build a plugin instance pre-loaded with its default config."""

    plugin = create_plugin()
    plugin.set_plugin_config(plugin.get_default_config())
    return plugin


def test_mute_plugin_manifest_is_valid_v2() -> None:
    """The MutePlugin manifest should satisfy the current runtime requirements."""

    validator = ManifestValidator(host_version="1.0.0", sdk_version="2.3.0")
    manifest = validator.load_from_plugin_path(Path("plugins/MutePlugin"))

    assert manifest is not None
    assert manifest.id == "sengokucola.mute-plugin"
    assert manifest.manifest_version == 2


def test_create_plugin_returns_sdk_plugin() -> None:
    """The plugin entry point should return an SDK plugin instance."""

    plugin = create_plugin()

    assert isinstance(plugin, MaiBotPlugin)


@pytest.mark.asyncio
async def test_mute_command_calls_napcat_group_ban_api() -> None:
    """The manual mute command should go through the new NapCat Adapter API."""

    plugin = _build_plugin()
    plugin.set_plugin_config(
        {
            **plugin.get_default_config(),
            "components": {
                "enable_smart_mute": True,
                "enable_mute_command": True,
            },
        }
    )

    capability_calls: List[Dict[str, Any]] = []

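    # fake_rpc_call stands in for the host RPC channel: it records every
    # capability payload and answers with canned person / NapCat responses.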
    async def fake_rpc_call(method: str, plugin_id: str = "", payload: Dict[str, Any] | None = None) -> Dict[str, Any]:
        assert method == "cap.call"
        assert payload is not None
        capability_calls.append(payload)

        capability = payload["capability"]
        if capability == "person.get_id_by_name":
            return {"success": True, "person_id": "person-1"}
        if capability == "person.get_value":
            return {"success": True, "value": "123456"}
        if capability == "api.call" and payload["args"]["api_name"] == "adapter.napcat.group.get_group_member_info":
            return {"success": True, "result": {"role": "member"}}
        if capability == "api.call":
            return {"success": True, "result": {"status": "ok", "retcode": 0}}
        if capability == "send.text":
            return {"success": True}
        raise AssertionError(f"unexpected capability: {capability}")

    plugin._set_context(PluginContext(plugin_id="mute", rpc_call=fake_rpc_call))

    success, message, intercept = await plugin.handle_mute_command(
        stream_id="group-10001",
        group_id="10001",
        user_id="42",
        matched_groups={
            "target": "张三",
            "duration": "120",
            "reason": "刷屏",
        },
    )

    assert success is True
    assert message == "成功禁言 张三"
    assert intercept is True

    api_call = next(
        call
        for call in capability_calls
        if call["capability"] == "api.call"
        and call["args"]["api_name"] == "adapter.napcat.group.set_group_ban"
    )
    assert api_call["args"]["version"] == "1"
    assert api_call["args"]["args"] == {
        "group_id": "10001",
        "user_id": "123456",
        "duration": 120,
    }


@pytest.mark.asyncio
async def test_mute_tool_requires_target_person_name() -> None:
    """The mute tool should fail fast with a hint when no target is given."""

    plugin = _build_plugin()
    capability_calls: List[Dict[str, Any]] = []

    async def fake_rpc_call(method: str, plugin_id: str = "", payload: Dict[str, Any] | None = None) -> Dict[str, Any]:
        assert method == "cap.call"
        assert payload is not None
        capability_calls.append(payload)
        return {"success": True}

    plugin._set_context(PluginContext(plugin_id="mute", rpc_call=fake_rpc_call))

    success, message = await plugin.handle_mute_tool(
        stream_id="group-10001",
        group_id="10001",
        target="",
        duration="60",
        reason="测试",
    )

    assert success is False
    assert message == "禁言目标不能为空"
    assert capability_calls[-1]["capability"] == "send.text"
    assert capability_calls[-1]["args"]["text"] == "没有指定禁言对象哦"


@pytest.mark.asyncio
async def test_mute_tool_can_unwrap_nested_person_user_id_response() -> None:
    """The mute tool should unwrap capability results nested one level deep."""

    plugin = _build_plugin()
    capability_calls: List[Dict[str, Any]] = []

    async def fake_rpc_call(method: str, plugin_id: str = "", payload: Dict[str, Any] | None = None) -> Dict[str, Any]:
        assert method == "cap.call"
        assert payload is not None
        capability_calls.append(payload)

        capability = payload["capability"]
        if capability == "person.get_id_by_name":
            return {"success": True, "result": {"success": True, "person_id": "person-1"}}
        if capability == "person.get_value":
            return {"success": True, "result": {"success": True, "value": "123456"}}
        if capability == "api.call" and payload["args"]["api_name"] == "adapter.napcat.group.get_group_member_info":
            return {"success": True, "result": {"role": "member"}}
        if capability == "api.call":
            return {"success": True, "result": {"status": "ok"}}
        if capability == "send.text":
            return {"success": True}
        raise AssertionError(f"unexpected capability: {capability}")

    plugin._set_context(PluginContext(plugin_id="mute", rpc_call=fake_rpc_call))

    success, message = await plugin.handle_mute_tool(
        stream_id="group-10001",
        group_id="10001",
        target="张三",
        duration=60,
        reason="测试",
    )

    assert success is True
    assert message == "成功禁言 张三"

    api_call = next(
        call
        for call in capability_calls
        if call["capability"] == "api.call"
        and call["args"]["api_name"] == "adapter.napcat.group.set_group_ban"
    )
    assert api_call["args"]["args"]["user_id"] == "123456"


@pytest.mark.asyncio
async def test_mute_tool_rejects_owner_before_group_ban_call() -> None:
    """The mute tool should return a clear message before calling group ban when the target is the owner."""

    plugin = _build_plugin()
    capability_calls: List[Dict[str, Any]] = []

    async def fake_rpc_call(method: str, plugin_id: str = "", payload: Dict[str, Any] | None = None) -> Dict[str, Any]:
        assert method == "cap.call"
        assert payload is not None
        capability_calls.append(payload)

        capability = payload["capability"]
        if capability == "person.get_id_by_name":
            return {"success": True, "person_id": "person-1"}
        if capability == "person.get_value":
            return {"success": True, "value": "123456"}
        if capability == "api.call" and payload["args"]["api_name"] == "adapter.napcat.group.get_group_member_info":
            return {"success": True, "result": {"role": "owner"}}
        if capability == "send.text":
            return {"success": True}
        raise AssertionError(f"unexpected capability: {capability}")

    plugin._set_context(PluginContext(plugin_id="mute", rpc_call=fake_rpc_call))

    success, message = await plugin.handle_mute_tool(
        stream_id="group-10001",
        group_id="10001",
        target="张三",
        duration=60,
        reason="测试",
    )

    assert success is False
    assert message == "张三 是群主,不能被禁言"
    assert not any(
        call["capability"] == "api.call" and call["args"]["api_name"] == "adapter.napcat.group.set_group_ban"
        for call in capability_calls
    )


@pytest.mark.asyncio
async def test_mute_tool_maps_cannot_ban_owner_error_message() -> None:
    """A NapCat 'cannot ban owner' failure should map to a clear Chinese message."""

    plugin = _build_plugin()
    capability_calls: List[Dict[str, Any]] = []

    async def fake_rpc_call(method: str, plugin_id: str = "", payload: Dict[str, Any] | None = None) -> Dict[str, Any]:
        assert method == "cap.call"
        assert payload is not None
        capability_calls.append(payload)

        capability = payload["capability"]
        if capability == "person.get_id_by_name":
            return {"success": True, "person_id": "person-1"}
        if capability == "person.get_value":
            return {"success": True, "value": "123456"}
        if capability == "api.call" and payload["args"]["api_name"] == "adapter.napcat.group.get_group_member_info":
            return {"success": True, "result": {"role": "member"}}
        if capability == "api.call" and payload["args"]["api_name"] == "adapter.napcat.group.set_group_ban":
            return {"success": False, "error": "NapCat 动作返回失败: action=set_group_ban message=cannot ban owner"}
        if capability == "send.text":
            return {"success": True}
        raise AssertionError(f"unexpected capability: {capability}")

    plugin._set_context(PluginContext(plugin_id="mute", rpc_call=fake_rpc_call))

    success, message = await plugin.handle_mute_tool(
        stream_id="group-10001",
        group_id="10001",
        target="张三",
        duration=60,
        reason="测试",
    )

    assert success is False
    assert message == "张三 是群主,不能被禁言"


@pytest.mark.asyncio
async def test_mute_tool_accepts_nested_ok_api_result() -> None:
    """A nested success/result/status=ok payload should also count as success."""

    plugin = _build_plugin()

    async def fake_rpc_call(method: str, plugin_id: str = "", payload: Dict[str, Any] | None = None) -> Dict[str, Any]:
        assert method == "cap.call"
        assert payload is not None

        capability = payload["capability"]
        if capability == "person.get_id_by_name":
            return {"success": True, "person_id": "person-1"}
        if capability == "person.get_value":
            return {"success": True, "value": "123456"}
        if capability == "api.call" and payload["args"]["api_name"] == "adapter.napcat.group.get_group_member_info":
            return {"success": True, "result": {"role": "member"}}
        if capability == "api.call" and payload["args"]["api_name"] == "adapter.napcat.group.set_group_ban":
            return {
                "success": True,
                "result": {
                    "status": "ok",
                    "retcode": 0,
                    "data": None,
                    "message": "",
                    "wording": "",
                },
            }
        if capability == "send.text":
            return {"success": True}
        raise AssertionError(f"unexpected capability: {capability}")

    plugin._set_context(PluginContext(plugin_id="mute", rpc_call=fake_rpc_call))

    success, message = await plugin.handle_mute_tool(
        stream_id="group-10001",
        group_id="10001",
        target="张三",
        duration=60,
        reason="测试",
    )

    assert success is True
    assert message == "成功禁言 张三"


def test_tool_invocation_payload_injects_group_and_user_context() -> None:
    """Plugin tool execution should auto-fill the group chat context fields."""

    entry = SimpleNamespace(invoke_method="plugin.invoke_tool")
    anchor_message = SimpleNamespace(
        message_info=SimpleNamespace(
            group_info=SimpleNamespace(group_id="10001"),
            user_info=SimpleNamespace(user_id="20002"),
        )
    )
    invocation = ToolInvocation(tool_name="mute", arguments={"target": "张三"}, stream_id="session-1")
    context = ToolExecutionContext(
        session_id="session-1",
        stream_id="session-1",
        reasoning="test",
        metadata={"anchor_message": anchor_message},
    )

    payload = ComponentQueryService._build_tool_invocation_payload(entry, invocation, context)

    assert payload["target"] == "张三"
    assert payload["stream_id"] == "session-1"
    assert payload["chat_id"] == "session-1"
    assert payload["group_id"] == "10001"
    assert payload["user_id"] == "20002"
@@ -1,227 +0,0 @@
from pathlib import Path
from types import ModuleType, SimpleNamespace

import importlib.util
import sys


class DummyLogger:
    def __init__(self) -> None:
        self.warning_messages: list[str] = []

    def debug(self, _msg: str) -> None:
        return

    def info(self, _msg: str) -> None:
        return

    def warning(self, msg: str) -> None:
        self.warning_messages.append(msg)

    def error(self, _msg: str) -> None:
        return


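# Import src/chat/utils/utils.py in isolation: every heavy dependency is
# replaced with a stub module before the file is executed.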
def load_utils_module(monkeypatch, qq_account=123456, platforms=None):
    logger = DummyLogger()
    configured_platforms = platforms or []

    def _stub_module(name: str) -> ModuleType:
        module = ModuleType(name)
        monkeypatch.setitem(sys.modules, name, module)
        return module

    for package_name in [
        "src",
        "src.chat",
        "src.chat.message_receive",
        "src.chat.utils",
        "src.common",
        "src.config",
        "src.llm_models",
        "src.person_info",
    ]:
        if package_name not in sys.modules:
            package_module = ModuleType(package_name)
            package_module.__path__ = []
            monkeypatch.setitem(sys.modules, package_name, package_module)

    jieba_module = ModuleType("jieba")
    jieba_module.cut = lambda text: list(text)
    monkeypatch.setitem(sys.modules, "jieba", jieba_module)

    logger_module = _stub_module("src.common.logger")
    logger_module.get_logger = lambda _name: logger

    config_module = _stub_module("src.config.config")
    config_module.global_config = SimpleNamespace(
        bot=SimpleNamespace(
            qq_account=qq_account,
            platforms=configured_platforms,
            nickname="MaiBot",
            alias_names=[],
        ),
        chat=SimpleNamespace(
            at_bot_inevitable_reply=1,
            mentioned_bot_reply=1,
        ),
    )
    config_module.model_config = SimpleNamespace()

    message_module = _stub_module("src.chat.message_receive.message")

    class SessionMessage:
        pass

    message_module.SessionMessage = SessionMessage

    chat_manager_module = _stub_module("src.chat.message_receive.chat_manager")
    chat_manager_module.chat_manager = SimpleNamespace(get_session_by_session_id=lambda _chat_id: None)

    llm_module = _stub_module("src.llm_models.utils_model")

    class LLMRequest:
        def __init__(self, *args, **kwargs) -> None:
            del args, kwargs

    llm_module.LLMRequest = LLMRequest

    person_module = _stub_module("src.person_info.person_info")

    class Person:
        pass

    person_module.Person = Person

    typo_generator_module = _stub_module("src.chat.utils.typo_generator")

    class ChineseTypoGenerator:
        def __init__(self, *args, **kwargs) -> None:
            del args, kwargs

        def create_typo_sentence(self, sentence: str):
            return sentence, ""

    typo_generator_module.ChineseTypoGenerator = ChineseTypoGenerator

    file_path = Path(__file__).parent.parent.parent / "src" / "chat" / "utils" / "utils.py"
    spec = importlib.util.spec_from_file_location("src.chat.utils.utils", file_path)
    utils_module = importlib.util.module_from_spec(spec)
    utils_module.__package__ = "src.chat.utils"
    monkeypatch.setitem(sys.modules, "src.chat.utils.utils", utils_module)
    assert spec.loader is not None
    spec.loader.exec_module(utils_module)
    return utils_module, logger


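# The tests below cover platform-alias account resolution, canonical qq/webui
# identity, unknown-platform fallback behaviour, and mention detection.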
def test_platform_specific_bot_accounts(monkeypatch):
    utils_module, _logger = load_utils_module(
        monkeypatch,
        qq_account=123456,
        platforms=[" TG : tg_bot ", "discord: disc_bot"],
    )

    assert utils_module.get_bot_account("qq") == "123456"
    assert utils_module.get_bot_account("webui") == "123456"
    assert utils_module.get_bot_account("telegram") == "tg_bot"
    assert utils_module.get_bot_account("tg") == "tg_bot"
    assert utils_module.get_bot_account("discord") == "disc_bot"

    assert utils_module.is_bot_self("qq", "123456")
    assert utils_module.is_bot_self("webui", "123456")
    assert utils_module.is_bot_self("telegram", "tg_bot")
    assert utils_module.is_bot_self(" TG ", "tg_bot")


def test_get_all_bot_accounts_includes_runtime_aliases(monkeypatch):
    utils_module, _logger = load_utils_module(
        monkeypatch,
        qq_account=123456,
        platforms=["TG:tg_bot", "discord:disc_bot"],
    )

    assert utils_module.get_all_bot_accounts() == {
        "qq": "123456",
        "webui": "123456",
        "telegram": "tg_bot",
        "tg": "tg_bot",
        "discord": "disc_bot",
    }


def test_get_all_bot_accounts_keeps_canonical_qq_identity(monkeypatch):
    utils_module, _logger = load_utils_module(
        monkeypatch,
        qq_account=123456,
        platforms=["qq:999999", "webui:888888", "TG:tg_bot"],
    )

    assert utils_module.get_all_bot_accounts()["qq"] == "123456"
    assert utils_module.get_all_bot_accounts()["webui"] == "123456"


def test_unknown_platform_no_longer_falls_back_to_qq(monkeypatch):
    utils_module, logger = load_utils_module(monkeypatch, qq_account=123456, platforms=[])

    assert utils_module.is_bot_self("unknown_platform", "123456") is False
    assert logger.warning_messages
    assert "unknown_platform" in logger.warning_messages[-1]


def test_unknown_platform_warns_only_once(monkeypatch):
    utils_module, logger = load_utils_module(monkeypatch, qq_account=123456, platforms=[])

    assert utils_module.is_bot_self("unknown_platform", "first") is False
    assert utils_module.is_bot_self(" unknown_platform ", "second") is False
    assert len(logger.warning_messages) == 1


def test_unconfigured_qq_account_disables_qq_and_webui_identity(monkeypatch):
    utils_module, _logger = load_utils_module(monkeypatch, qq_account=0, platforms=["telegram:tg_bot"])

    assert utils_module.get_bot_account("qq") == ""
    assert utils_module.get_bot_account("webui") == ""
    assert utils_module.is_bot_self("qq", "0") is False
    assert utils_module.is_bot_self("webui", "0") is False


def test_is_mentioned_bot_in_message_uses_platform_account(monkeypatch):
    utils_module, _logger = load_utils_module(monkeypatch, qq_account=123456, platforms=["TG:tg_bot"])

    message = SimpleNamespace(
        processed_plain_text="@tg_bot 你好",
        platform="telegram",
        is_mentioned=False,
        message_segment=None,
        message_info=SimpleNamespace(
            additional_config={},
            user_info=SimpleNamespace(user_id="user_1"),
        ),
    )

    is_mentioned, is_at, reply_probability = utils_module.is_mentioned_bot_in_message(message)

    assert is_mentioned is True
    assert is_at is True
    assert reply_probability == 1.0


def test_is_mentioned_bot_in_message_normalizes_qq_platform(monkeypatch):
    utils_module, _logger = load_utils_module(monkeypatch, qq_account=123456, platforms=[])

    message = SimpleNamespace(
        processed_plain_text="@<MaiBot:123456> 你好",
        platform=" QQ ",
        is_mentioned=False,
        message_segment=None,
        message_info=SimpleNamespace(
            additional_config={},
            user_info=SimpleNamespace(user_id="user_1"),
        ),
    )

    is_mentioned, is_at, reply_probability = utils_module.is_mentioned_bot_in_message(message)

    assert is_mentioned is True
    assert is_at is True
    assert reply_probability == 1.0
@@ -612,13 +612,6 @@ class ChatBot:
             scope=scope,
         ) # 确保会话存在

-        try:
-            from src.services.memory_flow_service import memory_automation_service
-
-            await memory_automation_service.on_incoming_message(message)
-        except Exception as exc:
-            logger.warning(f"[{session_id}] 长期记忆自动摘要注册失败: {exc}")
-
         # message.update_chat_stream(chat)

         # 命令处理 - 使用新插件系统检查并处理命令。
@@ -1,9 +1,9 @@
-import time
 from dataclasses import dataclass, field
 from datetime import datetime
 from typing import Any, Awaitable, Callable, Dict, List, Literal, Optional, Tuple

 import random
+import time

 from rich.console import Group, RenderableType
 from rich.panel import Panel
@@ -13,7 +13,6 @@ from src.chat.message_receive.chat_manager import BotChatSession
 from src.chat.message_receive.message import SessionMessage
 from src.chat.utils.utils import get_chat_type_and_target_info
 from src.cli.console import console
-from src.common.data_models.message_component_data_model import MessageSequence, TextComponent
 from src.common.data_models.reply_generation_data_models import (
     GenerationMetrics,
     LLMCompletionResult,
@@ -32,9 +31,10 @@ from src.maisaka.context_messages import (
     ReferenceMessage,
     SessionBackedMessage,
     ToolResultMessage,
+    build_llm_message_from_context,
 )
 from src.maisaka.display.prompt_cli_renderer import PromptCLIVisualizer
-from src.maisaka.message_adapter import clone_message_sequence, parse_speaker_content
+from src.maisaka.message_adapter import parse_speaker_content
 from src.plugin_runtime.hook_payloads import serialize_prompt_messages

 from .maisaka_expression_selector import maisaka_expression_selector
@@ -110,11 +110,15 @@ class BaseMaisakaReplyGenerator:
         return ""

     def _extract_guided_bot_reply(self, message: SessionBackedMessage) -> str:
-        speaker_name, body = parse_speaker_content(message.processed_plain_text.strip())
-        bot_nickname = global_config.bot.nickname.strip() or "Bot"
-        if speaker_name == bot_nickname:
-            return self._normalize_content(body.strip())
-        return ""
+        # 只能根据结构化来源字段判断是否为 bot 自身写回的历史消息,
+        # 不能依赖昵称/群名片等可控文本,避免误判和提示注入。
+        if message.source_kind != "guided_reply":
+            return ""
+
+        plain_text = message.processed_plain_text.strip()
+        _, body = parse_speaker_content(plain_text)
+        normalized_body = body.strip() or plain_text
+        return self._normalize_content(normalized_body) if normalized_body else ""

     def _build_target_message_block(self, reply_message: Optional[SessionMessage]) -> str:
         if reply_message is None:
@@ -210,6 +214,7 @@ class BaseMaisakaReplyGenerator:
         self,
         reply_message: Optional[SessionMessage],
         reply_reason: str,
+        reference_info: str = "",
         expression_habits: str = "",
         stream_id: Optional[str] = None,
     ) -> str:
@@ -234,8 +239,13 @@ class BaseMaisakaReplyGenerator:
             sections.append(expression_habits.strip())
         if target_message_block:
             sections.append(target_message_block)
+        reply_reference_lines: List[str] = []
         if reply_reason.strip():
-            sections.append(f"【回复信息参考】\n{reply_reason}")
+            reply_reference_lines.append(f"【最新推理】\n{reply_reason.strip()}")
+        if reference_info.strip():
+            reply_reference_lines.append(f"【参考信息】\n{reference_info.strip()}")
+        if reply_reference_lines:
+            sections.append("【回复信息参考】\n" + "\n\n".join(reply_reference_lines))
         if not sections:
             return system_prompt
         return f"{system_prompt}\n\n" + "\n\n".join(sections)
@@ -243,28 +253,6 @@ class BaseMaisakaReplyGenerator:
     def _build_reply_instruction(self) -> str:
         return "请自然地回复。不要输出多余说明、括号、@ 或额外标记,只输出实际要发送的内容。"

-    def _build_visual_user_message(
-        self,
-        message: SessionBackedMessage,
-        enable_visual_message: bool,
-    ) -> Optional[Message]:
-        if not enable_visual_message:
-            return None
-
-        raw_message = clone_message_sequence(message.raw_message)
-        if not raw_message.components:
-            raw_message = MessageSequence([TextComponent(message.processed_plain_text)])
-
-        visual_message = SessionBackedMessage(
-            raw_message=raw_message,
-            visible_text=message.processed_plain_text,
-            timestamp=message.timestamp,
-            message_id=message.message_id,
-            original_message=message.original_message,
-            source_kind=message.source_kind,
-        )
-        return visual_message.to_llm_message()
-
     def _build_history_messages(
         self,
         chat_history: List[LLMContextMessage],
@@ -284,12 +272,10 @@ class BaseMaisakaReplyGenerator:
             )
             continue

-            visual_message = self._build_visual_user_message(message, enable_visual_message)
-            if visual_message is not None:
-                messages.append(visual_message)
-                continue
-
-            llm_message = message.to_llm_message()
+            llm_message = build_llm_message_from_context(
+                message,
+                enable_visual_message=enable_visual_message,
+            )
             if llm_message is not None:
                 messages.append(llm_message)
                 continue
@@ -308,6 +294,7 @@ class BaseMaisakaReplyGenerator:
         chat_history: List[LLMContextMessage],
         reply_message: Optional[SessionMessage],
         reply_reason: str,
+        reference_info: str = "",
         expression_habits: str = "",
         stream_id: Optional[str] = None,
         enable_visual_message: bool = False,
@@ -316,6 +303,7 @@ class BaseMaisakaReplyGenerator:
         system_prompt = self._build_system_prompt(
             reply_message=reply_message,
             reply_reason=reply_reason,
+            reference_info=reference_info,
             expression_habits=expression_habits,
             stream_id=stream_id,
         )
@@ -377,6 +365,7 @@ class BaseMaisakaReplyGenerator:
         self,
         extra_info: str = "",
         reply_reason: str = "",
+        reference_info: str = "",
         available_actions: Optional[Dict[str, ActionInfo]] = None,
         chosen_actions: Optional[List[object]] = None,
         from_plugin: bool = True,
@@ -461,6 +450,7 @@ class BaseMaisakaReplyGenerator:
             chat_history=filtered_history,
             reply_message=reply_message,
             reply_reason=reply_reason or "",
+            reference_info=reference_info or "",
             expression_habits=merged_expression_habits,
             stream_id=stream_id,
         )
@@ -486,6 +476,7 @@ class BaseMaisakaReplyGenerator:
             chat_history=filtered_history,
             reply_message=reply_message,
             reply_reason=reply_reason or "",
+            reference_info=reference_info or "",
             expression_habits=merged_expression_habits,
             stream_id=stream_id,
             enable_visual_message=self._resolve_enable_visual_message(model_info),
@@ -504,7 +495,6 @@ class BaseMaisakaReplyGenerator:
                 chat_id=preview_chat_id,
                 request_kind="replyer",
                 selection_reason=f"ID: {preview_chat_id}",
                 image_display_mode="path_link" if global_config.maisaka.show_image_path else "legacy",
             ),
             title="Reply Prompt",
             border_style="bright_yellow",
@@ -159,7 +159,6 @@ MODULE_ALIASES = {
     "planner": "规划器",
     "config": "配置",
     "main": "主程序",
-    "chat_history_summarizer": "聊天概括器",
     "plugin_runtime.integration": "IPC插件系统",
     "plugin_runtime.host.supervisor": "插件监督器",
     "plugin_runtime.host.runner_manager": "插件监督器",
@@ -55,7 +55,7 @@ BOT_CONFIG_PATH: Path = (CONFIG_DIR / "bot_config.toml").resolve().absolute()
 MODEL_CONFIG_PATH: Path = (CONFIG_DIR / "model_config.toml").resolve().absolute()
 LEGACY_ENV_PATH: Path = (PROJECT_ROOT / ".env").resolve().absolute()
 MMC_VERSION: str = "1.0.0"
-CONFIG_VERSION: str = "8.7.1"
+CONFIG_VERSION: str = "8.8.0"
 MODEL_CONFIG_VERSION: str = "1.14.0"

 logger = get_logger("config")
@@ -414,15 +414,6 @@ class MemoryConfig(ConfigBase):
     )
     """Maisaka 内置长期记忆检索工具 query_memory 的默认返回条数"""

-    long_term_auto_summary_enabled: bool = Field(
-        default=True,
-        json_schema_extra={
-            "x-widget": "switch",
-            "x-icon": "book-open",
-        },
-    )
-    """是否自动启动聊天总结并导入长期记忆"""
-
     person_fact_writeback_enabled: bool = Field(
         default=True,
         json_schema_extra={
@@ -578,77 +569,9 @@ class MemoryConfig(ConfigBase):
         },
     )
     """反馈纠错二阶段一致性每轮处理 profile/episode 队列的批大小"""
-    chat_history_topic_check_message_threshold: int = Field(
-        default=80,
-        ge=1,
-        json_schema_extra={
-            "x-widget": "input",
-            "x-icon": "hash",
-        },
-    )
-    """聊天历史话题检查的消息数量阈值,当累积消息数达到此值时触发话题检查"""
-
-    chat_history_topic_check_time_hours: float = Field(
-        default=8.0,
-        json_schema_extra={
-            "x-widget": "input",
-            "x-icon": "clock",
-        },
-    )
-    """聊天历史话题检查的时间阈值(小时),当距离上次检查超过此时间且消息数达到最小阈值时触发话题检查"""
-
-    chat_history_topic_check_min_messages: int = Field(
-        default=20,
-        ge=1,
-        json_schema_extra={
-            "x-widget": "input",
-            "x-icon": "hash",
-        },
-    )
-    """聊天历史话题检查的时间触发模式下的最小消息数阈值"""
-
-    chat_history_finalize_no_update_checks: int = Field(
-        default=3,
-        ge=1,
-        json_schema_extra={
-            "x-widget": "input",
-            "x-icon": "check-circle",
-        },
-    )
-    """聊天历史话题打包存储的连续无更新检查次数阈值,当话题连续N次检查无新增内容时触发打包存储"""
-
-    chat_history_finalize_message_count: int = Field(
-        default=5,
-        ge=1,
-        json_schema_extra={
-            "x-widget": "input",
-            "x-icon": "package",
-        },
-    )
-    """聊天历史话题打包存储的消息条数阈值,当话题的消息条数超过此值时触发打包存储"""

     def model_post_init(self, context: Optional[dict] = None) -> None:
         """验证配置值"""
-        if self.chat_history_topic_check_message_threshold < 1:
-            raise ValueError(
-                f"chat_history_topic_check_message_threshold 必须至少为1,当前值: {self.chat_history_topic_check_message_threshold}"
-            )
-        if self.chat_history_topic_check_time_hours <= 0:
-            raise ValueError(
-                f"chat_history_topic_check_time_hours 必须大于0,当前值: {self.chat_history_topic_check_time_hours}"
-            )
-        if self.chat_history_topic_check_min_messages < 1:
-            raise ValueError(
-                f"chat_history_topic_check_min_messages 必须至少为1,当前值: {self.chat_history_topic_check_min_messages}"
-            )
-        if self.chat_history_finalize_no_update_checks < 1:
-            raise ValueError(
-                f"chat_history_finalize_no_update_checks 必须至少为1,当前值: {self.chat_history_finalize_no_update_checks}"
-            )
-        if self.chat_history_finalize_message_count < 1:
-            raise ValueError(
-                f"chat_history_finalize_message_count 必须至少为1,当前值: {self.chat_history_finalize_message_count}"
-            )
         if self.feedback_correction_window_hours <= 0:
             raise ValueError(
                 f"feedback_correction_window_hours 必须大于0,当前值: {self.feedback_correction_window_hours}"
@@ -335,6 +335,8 @@ async def send_emoji_for_maisaka(
             storage_message=True,
             set_reply=False,
             reply_message=None,
+            sync_to_maisaka_history=True,
+            maisaka_source_kind="guided_reply",
         )
         sent = sent_message is not None
     except Exception as exc:
@@ -6,6 +6,7 @@ import json
 import re
 from dataclasses import dataclass, field
 from typing import Any, Callable, Coroutine, Dict, List, Tuple, cast
+from uuid import uuid4

 from json_repair import repair_json
 from openai import APIConnectionError, APIStatusError, AsyncOpenAI, AsyncStream
@@ -119,6 +120,13 @@ OpenAIResponseParser = Callable[[ChatCompletion], Tuple[APIResponse, UsageTuple
 OpenAIResponseParser = Callable[[ChatCompletion], Tuple[APIResponse, UsageTuple]]
 """OpenAI 非流式响应解析函数类型。"""


+def _build_fallback_tool_call_id(prefix: str) -> str:
+    """为缺失原始调用 ID 的工具调用生成唯一兜底标识。"""
+
+    normalized_prefix = str(prefix).strip() or "tool_call"
+    return f"{normalized_prefix}_{uuid4().hex}"
+
+
 def _normalize_reasoning_parse_mode(parse_mode: str | ReasoningParseMode) -> ReasoningParseMode:
     """将配置中的推理解析模式收敛为枚举值。
@@ -609,7 +617,7 @@ def _extract_xml_tool_calls(
         arguments = _parse_tool_arguments(raw_arguments, parse_mode, response) if raw_arguments else {}
         tool_calls.append(
             ToolCall(
-                call_id=f"xml_tool_call_{len(tool_calls) + 1}",
+                call_id=_build_fallback_tool_call_id("xml_tool_call"),
                 func_name=function_name,
                 args=arguments,
             )
@@ -855,7 +863,7 @@ class _OpenAIStreamAccumulator:
                 if raw_arguments
                 else None
             )
-            call_id = state.call_id or f"tool_call_{index}"
+            call_id = state.call_id or _build_fallback_tool_call_id(f"tool_call_{index}")
             response.tool_calls.append(ToolCall(call_id=call_id, func_name=state.function_name, args=arguments))

         response.raw_data = {"model": self.model_name} if self.model_name else None
@@ -4,7 +4,7 @@ from __future__ import annotations

 from base64 import b64decode
 from datetime import datetime
-from typing import Any, Dict, List, Optional, TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, Dict, List, Optional

 from src.chat.utils.utils import process_llm_response
 from src.common.data_models.message_component_data_model import EmojiComponent, MessageSequence, TextComponent
@@ -12,13 +12,10 @@ from src.config.config import global_config
 from src.core.tooling import ToolExecutionResult

 from ..context_messages import SessionBackedMessage
-from ..history_utils import build_prefixed_message_sequence, build_session_message_visible_text
-from ..message_adapter import format_speaker_content
 from ..planner_message_utils import build_planner_prefix, build_session_backed_text_message

 if TYPE_CHECKING:
-    from src.chat.message_receive.message import SessionMessage

     from ..reasoning_engine import MaisakaReasoningEngine
     from ..runtime import MaisakaHeartFlowChatting

@@ -139,37 +136,6 @@ class BuiltinToolRuntimeContext:
         return self.engine._get_runtime_manager()

-    @staticmethod
-    def _build_visible_text_from_sent_message(message: "SessionMessage") -> str:
-        """将已发送消息转换为 Maisaka 可见文本。"""
-
-        return build_session_message_visible_text(message)
-
-    def append_sent_message_to_chat_history(
-        self,
-        message: "SessionMessage",
-        *,
-        source_kind: str = "guided_reply",
-    ) -> None:
-        """将真实已发送消息同步到 Maisaka 历史。"""
-
-        user_info = message.message_info.user_info
-        speaker_name = user_info.user_cardname or user_info.user_nickname or user_info.user_id
-        planner_prefix = build_planner_prefix(
-            timestamp=message.timestamp,
-            user_name=speaker_name,
-            group_card=user_info.user_cardname or "",
-            message_id=message.message_id,
-            include_message_id=not message.is_notify and bool(message.message_id),
-        )
-        history_message = SessionBackedMessage.from_session_message(
-            message,
-            raw_message=build_prefixed_message_sequence(message.raw_message, planner_prefix),
-            visible_text=self._build_visible_text_from_sent_message(message),
-            source_kind=source_kind,
-        )
-        self.runtime._chat_history.append(history_message)
-
     def append_guided_reply_to_chat_history(self, reply_text: str) -> None:
         """将引导回复写回 Maisaka 历史。"""

@@ -36,7 +36,8 @@ def get_tool_spec() -> ToolSpec:
         detailed_description=(
             "参数说明:\n"
             "- msg_id:string,必填。要回复的目标用户消息编号。\n"
-            "- set_quote:boolean,可选。以引用回复的方式发送,默认 true。"
+            "- set_quote:boolean,可选。以引用回复的方式发送,默认 true。\n"
+            "- reference_info:string,可选。上文中有助于回复的所有参考信息,使用平文本格式。"
         ),
         parameters_schema={
             "type": "object",
@@ -50,6 +51,11 @@ def get_tool_spec() -> ToolSpec:
                     "description": "以引用回复的方式发送这条回复,不用每句都引用。",
                     "default": True,
                 },
+                "reference_info": {
+                    "type": "string",
+                    "description": "有助于回复的信息,之前搜集得到的事实性信息,记忆等,使用平文本格式。",
+                    "default": "",
+                },
             },
             "required": ["msg_id"],
         },
@@ -75,6 +81,7 @@ async def handle_tool(
     """执行 reply 内置工具。"""

     latest_thought = context.reasoning if context is not None else invocation.reasoning
+    reference_info = str(invocation.arguments.get("reference_info") or "").strip()
     target_message_id = str(invocation.arguments.get("msg_id") or "").strip()
     set_quote = bool(invocation.arguments.get("set_quote", True))

@@ -117,6 +124,7 @@ async def handle_tool(
     try:
         success, reply_result = await replyer.generate_reply_with_context(
             reply_reason=latest_thought,
+            reference_info=reference_info,
             stream_id=tool_ctx.runtime.session_id,
             reply_message=target_message,
             chat_history=tool_ctx.runtime._chat_history,
@@ -152,7 +160,6 @@ async def handle_tool(
     combined_reply_text = "".join(reply_segments)
     try:
         sent = False
-        sent_messages = []
         if tool_ctx.runtime.chat_stream.platform == CLI_PLATFORM_NAME:
             for segment in reply_segments:
                 render_cli_message(segment)
@@ -166,11 +173,12 @@ async def handle_tool(
|
||||
reply_message=target_message if set_quote and index == 0 else None,
|
||||
selected_expressions=reply_result.selected_expression_ids or None,
|
||||
typing=index > 0,
|
||||
sync_to_maisaka_history=True,
|
||||
maisaka_source_kind="guided_reply",
|
||||
)
|
||||
sent = sent_message is not None
|
||||
if not sent:
|
||||
break
|
||||
sent_messages.append(sent_message)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"{tool_ctx.runtime.log_prefix} 发送文字消息时发生异常,目标消息编号={target_message_id}"
|
||||
@@ -198,9 +206,6 @@ async def handle_tool(
|
||||
|
||||
if tool_ctx.runtime.chat_stream.platform == CLI_PLATFORM_NAME:
|
||||
tool_ctx.append_guided_reply_to_chat_history(combined_reply_text)
|
||||
else:
|
||||
for sent_message in sent_messages:
|
||||
tool_ctx.append_sent_message_to_chat_history(sent_message)
|
||||
tool_ctx.runtime._record_reply_sent()
|
||||
return tool_ctx.build_success_result(
|
||||
invocation.tool_name,
|
||||
|
||||
@@ -53,40 +53,16 @@ def get_tool_spec() -> ToolSpec:
    return ToolSpec(
        name="send_emoji",
        brief_description="发送一个合适的表情包来辅助表达情绪。",
        detailed_description="参数说明:\n- emotion:string,可选。希望表达的情绪,例如 happy、sad、angry 等。",
        detailed_description="无需参数,直接发送一个合适的表情包。",
        parameters_schema={
            "type": "object",
            "properties": {
                "emotion": {
                    "type": "string",
                    "description": "希望表达的情绪,例如 happy、sad、angry 等。",
                },
            },
            "properties": {},
        },
        provider_name="maisaka_builtin",
        provider_type="builtin",
    )


def _normalize_candidate_emotions(emoji: MaiEmoji) -> list[str]:
    """清洗候选表情上的情绪标签。"""

    raw_emotions = getattr(emoji, "emotion", None)
    if isinstance(raw_emotions, list) and raw_emotions:
        return [str(item).strip() for item in raw_emotions if str(item).strip()]

    description = str(getattr(emoji, "description", "") or "").strip()
    if not description:
        return []

    normalized_description = (
        description.replace(",", ",")
        .replace("、", ",")
        .replace(";", ",")
    )
    return [item.strip() for item in normalized_description.split(",") if item.strip()]


async def _load_emoji_bytes(emoji: MaiEmoji) -> bytes:
    """读取单个表情包图片字节。"""

@@ -232,18 +208,6 @@ async def _build_emoji_candidate_message(emojis: list[MaiEmoji]) -> SessionBacke
    )


def _build_emoji_candidate_summary(emojis: list[MaiEmoji]) -> str:
    """构建供监控展示使用的候选表情摘要。"""

    summary_lines: list[str] = []
    for index, emoji in enumerate(emojis, start=1):
        description = emoji.description.strip() or "(无描述)"
        emotions = "、".join(_normalize_candidate_emotions(emoji)) or "无"
        summary_lines.append(f"{index}. 描述:{description}")
        summary_lines.append(f"   情绪:{emotions}")
    return "\n".join(summary_lines).strip()


def _build_send_emoji_monitor_detail(
    *,
    request_messages: Optional[list[dict[str, Any]]] = None,
@@ -252,7 +216,7 @@ def _build_send_emoji_monitor_detail(
    metrics: Optional[Dict[str, Any]] = None,
    extra_sections: Optional[list[dict[str, str]]] = None,
) -> Dict[str, Any]:
    """构建 emotion tool 统一监控详情。"""
    """构建 send_emoji 工具统一监控详情。"""

    detail: Dict[str, Any] = {}
    if isinstance(request_messages, list) and request_messages:
@@ -281,7 +245,6 @@ def _build_send_emoji_monitor_detail(
def _build_send_emoji_monitor_metadata(
    selection_metadata: Dict[str, Any],
    *,
    requested_emotion: str,
    send_result: Optional[Any] = None,
    error_message: str = "",
) -> Dict[str, Any]:
@@ -293,7 +256,6 @@ def _build_send_emoji_monitor_metadata(

    if send_result is not None:
        result_lines = [
            f"请求情绪:{requested_emotion or '未指定'}",
            f"命中情绪:{send_result.matched_emotion or '未命中'}",
            f"表情描述:{send_result.description or '无描述'}",
            f"情绪标签:{'、'.join(send_result.emotions) if send_result.emotions else '无'}",
@@ -306,10 +268,7 @@ def _build_send_emoji_monitor_metadata(
    elif error_message.strip():
        extra_sections.append({
            "title": "表情发送结果",
            "content": (
                f"请求情绪:{requested_emotion or '未指定'}\n"
                f"发送结果:{error_message.strip()}"
            ),
            "content": f"发送结果:{error_message.strip()}",
        })

    if extra_sections:
@@ -322,7 +281,6 @@ def _build_send_emoji_monitor_metadata(

async def _select_emoji_with_sub_agent(
    tool_ctx: BuiltinToolRuntimeContext,
    requested_emotion: str,
    reasoning: str,
    context_texts: list[str],
    sample_size: int,
@@ -347,14 +305,12 @@ async def _select_emoji_with_sub_agent(
        f"一共 {len(sampled_emojis)} 个位置。\n"
        f"每张小图左上角都有一个较大的序号,范围是 1 到 {len(sampled_emojis)}。\n"
        f"你的任务是根据上下文和当前语气,从这 {len(sampled_emojis)} 张图里选出最合适的一张表情包。\n"
        "如果提供了 requested_emotion,请优先考虑与其接近的候选;如果没有完全匹配,则选择最符合上下文语气的候选。\n"
        "你必须返回一个 JSON 对象(json object),不要输出任何 JSON 之外的内容。\n"
        '返回格式固定为:{"emoji_index":1,"reason":"简短理由"}'
    )
    prompt_message = ReferenceMessage(
        content=(
            f"[选择任务]\n"
            f"requested_emotion: {requested_emotion or '未指定'}\n"
            f"候选总数: {len(sampled_emojis)}\n"
            f"拼图布局: {grid_rows}x{grid_columns}\n"
            "请只输出 JSON。"
@@ -439,7 +395,6 @@ async def handle_tool(
    """执行 send_emoji 内置工具。"""

    del context
    emotion = str(invocation.arguments.get("emotion") or "").strip()
    context_texts = [
        message.processed_plain_text.strip()
        for message in tool_ctx.runtime._chat_history[-5:]
@@ -450,23 +405,20 @@ async def handle_tool(
        "message": "",
        "description": "",
        "emotion": [],
        "requested_emotion": emotion,
        "matched_emotion": "",
        "reason": "",
    }
    selection_metadata: Dict[str, Any] = {"reason": "", "monitor_detail": {}}

    logger.info(f"{tool_ctx.runtime.log_prefix} 触发表情包发送工具,请求情绪={emotion!r}")
    logger.info(f"{tool_ctx.runtime.log_prefix} 触发表情包发送工具")

    try:
        send_result = await send_emoji_for_maisaka(
            stream_id=tool_ctx.runtime.session_id,
            requested_emotion=emotion,
            reasoning=tool_ctx.engine.last_reasoning_content,
            context_texts=context_texts,
            emoji_selector=lambda requested_emotion, reasoning, context_texts, sample_size: _select_emoji_with_sub_agent(
            emoji_selector=lambda _requested_emotion, reasoning, context_texts, sample_size: _select_emoji_with_sub_agent(
                tool_ctx,
                requested_emotion,
                reasoning,
                list(context_texts or []),
                sample_size,
@@ -482,7 +434,6 @@ async def handle_tool(
            structured_content=structured_result,
            metadata=_build_send_emoji_monitor_metadata(
                selection_metadata,
                requested_emotion=emotion,
                error_message=structured_result["message"],
            ),
        )
@@ -493,11 +444,9 @@ async def handle_tool(
        logger.info(
            f"{tool_ctx.runtime.log_prefix} 表情包发送成功 "
            f"描述={send_result.description!r} 情绪标签={send_result.emotions} "
            f"请求情绪={emotion!r} 命中情绪={send_result.matched_emotion!r}"
            f"命中情绪={send_result.matched_emotion!r}"
        )
        if send_result.sent_message is not None:
            tool_ctx.append_sent_message_to_chat_history(send_result.sent_message)
        else:
        if send_result.sent_message is None:
            tool_ctx.append_sent_emoji_to_chat_history(
                emoji_base64=send_result.emoji_base64,
                success_message=_EMOJI_SUCCESS_MESSAGE,
@@ -509,7 +458,6 @@ async def handle_tool(
            structured_content=structured_result,
            metadata=_build_send_emoji_monitor_metadata(
                selection_metadata,
                requested_emotion=emotion,
                send_result=send_result,
            ),
        )
@@ -521,7 +469,7 @@ async def handle_tool(

        logger.warning(
            f"{tool_ctx.runtime.log_prefix} 表情包发送失败 "
            f"请求情绪={emotion!r} 错误信息={send_result.message}"
            f"错误信息={send_result.message}"
        )
        return tool_ctx.build_failure_result(
            invocation.tool_name,
@@ -529,7 +477,6 @@ async def handle_tool(
            structured_content=structured_result,
            metadata=_build_send_emoji_monitor_metadata(
                selection_metadata,
                requested_emotion=emotion,
                send_result=send_result,
            ),
        )

@@ -30,9 +30,15 @@ from src.plugin_runtime.host.hook_spec_registry import HookSpec, HookSpecRegistr
from src.services.llm_service import LLMServiceClient

from .builtin_tool import get_builtin_tools
from .context_messages import AssistantMessage, LLMContextMessage, ToolResultMessage
from .context_messages import (
    AssistantMessage,
    LLMContextMessage,
    ToolResultMessage,
    build_llm_message_from_context,
)
from .history_utils import drop_orphan_tool_results
from .display.prompt_cli_renderer import PromptCLIVisualizer
from .visual_mode_utils import resolve_enable_visual_planner

TIMING_GATE_TOOL_NAMES = {"continue", "no_reply", "wait"}

@@ -395,6 +401,7 @@ class MaisakaChatLoopService:
        self,
        selected_history: List[LLMContextMessage],
        *,
        enable_visual_message: bool,
        injected_user_messages: Sequence[str] | None = None,
        system_prompt: Optional[str] = None,
    ) -> List[Message]:
@@ -413,7 +420,10 @@ class MaisakaChatLoopService:
            messages.append(system_msg.build())

        for msg in selected_history:
            llm_message = msg.to_llm_message()
            llm_message = build_llm_message_from_context(
                msg,
                enable_visual_message=enable_visual_message,
            )
            if llm_message is not None:
                messages.append(llm_message)

@@ -475,12 +485,15 @@ class MaisakaChatLoopService:

        if not self._prompts_loaded:
            await self.ensure_chat_prompt_loaded()
        enable_visual_message = self._resolve_enable_visual_message(request_kind)
        selected_history, selection_reason = self.select_llm_context_messages(
            chat_history,
            request_kind=request_kind,
            enable_visual_message=enable_visual_message,
        )
        built_messages = self._build_request_messages(
            selected_history,
            enable_visual_message=enable_visual_message,
            injected_user_messages=injected_user_messages,
        )

@@ -528,14 +541,12 @@ class MaisakaChatLoopService:

        prompt_section: RenderableType | None = None
        if global_config.debug.show_maisaka_thinking:
            image_display_mode: str = "path_link" if global_config.maisaka.show_image_path else "legacy"
            prompt_section = PromptCLIVisualizer.build_prompt_section(
                built_messages,
                category="planner" if request_kind != "timing_gate" else "timing_gate",
                chat_id=self._session_id,
                request_kind=request_kind,
                selection_reason=selection_reason,
                image_display_mode=image_display_mode,
                folded=global_config.debug.fold_maisaka_thinking,
                tool_definitions=list(all_tools),
            )
@@ -604,6 +615,7 @@ class MaisakaChatLoopService:
    def select_llm_context_messages(
        chat_history: List[LLMContextMessage],
        *,
        enable_visual_message: Optional[bool] = None,
        request_kind: str = "planner",
        max_context_size: Optional[int] = None,
    ) -> tuple[List[LLMContextMessage], str]:
@@ -617,9 +629,21 @@ class MaisakaChatLoopService:
        selected_indices: List[int] = []
        counted_message_count = 0

        active_enable_visual_message = (
            enable_visual_message
            if enable_visual_message is not None
            else MaisakaChatLoopService._resolve_enable_visual_message(request_kind)
        )

        for index in range(len(filtered_history) - 1, -1, -1):
            message = filtered_history[index]
            if message.to_llm_message() is None:
            if (
                build_llm_message_from_context(
                    message,
                    enable_visual_message=active_enable_visual_message,
                )
                is None
            ):
                continue

            selected_indices.append(index)
@@ -629,18 +653,18 @@ class MaisakaChatLoopService:
                break

        if not selected_indices:
            return [], f"没有选择到上下文消息,实际发送 {effective_context_size} 条 user/assistant 消息"
            return [], "实际发送 0 条消息(tool 0 条,普通消息 0 条)"

        selected_indices.reverse()
        selected_history = [filtered_history[index] for index in selected_indices]
        selected_history, hidden_assistant_count = MaisakaChatLoopService._hide_early_assistant_messages(selected_history)
        selected_history, _ = MaisakaChatLoopService._hide_early_assistant_messages(selected_history)
        selected_history, _ = drop_orphan_tool_results(selected_history)
        tool_message_count = sum(1 for message in selected_history if isinstance(message, ToolResultMessage))
        normal_message_count = len(selected_history) - tool_message_count
        selection_reason = (
            f"上下文裁剪:最近 {effective_context_size} 条 user/assistant 消息,"
            f"实际发送 {len(selected_history)} 条"
            f"实际发送 {len(selected_history)} 条消息"
            f"|消息 {normal_message_count} 条|tool {tool_message_count} 条"
        )
        if hidden_assistant_count > 0:
            selection_reason += f",已隐藏最早 {hidden_assistant_count} 条 assistant 消息"
        return (
            selected_history,
            selection_reason,
@@ -685,6 +709,12 @@ class MaisakaChatLoopService:

        return filtered_history

    @staticmethod
    def _resolve_enable_visual_message(request_kind: str) -> bool:
        if request_kind in {"planner", "timing_gate"}:
            return resolve_enable_visual_planner()
        return True

    @staticmethod
    def _hide_early_assistant_messages(
        selected_history: List[LLMContextMessage],

@@ -40,10 +40,15 @@ def _guess_image_format(image_bytes: bytes) -> Optional[str]:
    return None


def _append_emoji_component(builder: MessageBuilder, component: EmojiComponent) -> bool:
def _append_emoji_component(
    builder: MessageBuilder,
    component: EmojiComponent,
    *,
    enable_visual_message: bool,
) -> bool:
    """将表情组件追加到 LLM 消息构建器。"""
    image_format = _guess_image_format(component.binary_data)
    if image_format and component.binary_data:
    if enable_visual_message and image_format and component.binary_data:
        builder.add_text_content("[消息类型]表情包")
        builder.add_image_content(image_format, base64.b64encode(component.binary_data).decode("utf-8"))
        return True
@@ -56,10 +61,15 @@ def _append_emoji_component(builder: MessageBuilder, component: EmojiComponent)
    return True


def _append_image_component(builder: MessageBuilder, component: ImageComponent) -> bool:
def _append_image_component(
    builder: MessageBuilder,
    component: ImageComponent,
    *,
    enable_visual_message: bool,
) -> bool:
    """将图片组件追加到 LLM 消息构建器。"""
    image_format = _guess_image_format(component.binary_data)
    if image_format and component.binary_data:
    if enable_visual_message and image_format and component.binary_data:
        builder.add_text_content("[消息类型]图片")
        builder.add_image_content(image_format, base64.b64encode(component.binary_data).decode("utf-8"))
        return True
@@ -216,6 +226,7 @@ def _build_message_from_sequence(
    message_sequence: MessageSequence,
    fallback_text: str,
    *,
    enable_visual_message: bool = True,
    tool_call_id: Optional[str] = None,
    tool_name: Optional[str] = None,
    tool_calls: Optional[list[ToolCall]] = None,
@@ -238,11 +249,25 @@ def _build_message_from_sequence(
            continue

        if isinstance(component, EmojiComponent):
            has_content = _append_emoji_component(builder, component) or has_content
            has_content = (
                _append_emoji_component(
                    builder,
                    component,
                    enable_visual_message=enable_visual_message,
                )
                or has_content
            )
            continue

        if isinstance(component, ImageComponent):
            has_content = _append_image_component(builder, component) or has_content
            has_content = (
                _append_image_component(
                    builder,
                    component,
                    enable_visual_message=enable_visual_message,
                )
                or has_content
            )
            continue

        if isinstance(component, AtComponent):
@@ -297,7 +322,7 @@ class LLMContextMessage(ABC):
        return self.__class__.__name__

    @abstractmethod
    def to_llm_message(self) -> Optional[Message]:
    def to_llm_message(self, enable_visual_message: bool = True) -> Optional[Message]:
        """转换为统一 LLM 消息。"""

    def consume_once(self) -> bool:
@@ -328,11 +353,12 @@ class SessionBackedMessage(LLMContextMessage):
    def source(self) -> str:
        return self.source_kind

    def to_llm_message(self) -> Optional[Message]:
    def to_llm_message(self, enable_visual_message: bool = True) -> Optional[Message]:
        return _build_message_from_sequence(
            RoleType.User,
            self.raw_message,
            self.processed_plain_text,
            enable_visual_message=enable_visual_message,
        )

    @classmethod
@@ -366,7 +392,8 @@ class ComplexSessionMessage(SessionBackedMessage):
    def source(self) -> str:
        return f"{self.source_kind}:{self.complex_message_type}"

    def to_llm_message(self) -> Optional[Message]:
    def to_llm_message(self, enable_visual_message: bool = True) -> Optional[Message]:
        del enable_visual_message
        message_sequence = MessageSequence([TextComponent(self.prompt_text)])
        return _build_message_from_sequence(
            RoleType.User,
@@ -426,7 +453,8 @@ class ReferenceMessage(LLMContextMessage):
    def source(self) -> str:
        return self.reference_type.value

    def to_llm_message(self) -> Optional[Message]:
    def to_llm_message(self, enable_visual_message: bool = True) -> Optional[Message]:
        del enable_visual_message
        message_sequence = MessageSequence([TextComponent(self.processed_plain_text)])
        return _build_message_from_sequence(RoleType.User, message_sequence, self.processed_plain_text)

@@ -463,7 +491,8 @@ class AssistantMessage(LLMContextMessage):
    def source(self) -> str:
        return self.source_kind

    def to_llm_message(self) -> Optional[Message]:
    def to_llm_message(self, enable_visual_message: bool = True) -> Optional[Message]:
        del enable_visual_message
        message_sequence = MessageSequence([])
        if self.content:
            message_sequence.text(self.content)
@@ -501,7 +530,8 @@ class ToolResultMessage(LLMContextMessage):
    def source(self) -> str:
        return self.tool_name or "tool"

    def to_llm_message(self) -> Optional[Message]:
    def to_llm_message(self, enable_visual_message: bool = True) -> Optional[Message]:
        del enable_visual_message
        message_sequence = MessageSequence([TextComponent(self.content)])
        return _build_message_from_sequence(
            RoleType.Tool,
@@ -510,3 +540,13 @@ class ToolResultMessage(LLMContextMessage):
            tool_call_id=self.tool_call_id,
            tool_name=self.tool_name,
        )


def build_llm_message_from_context(
    context_message: LLMContextMessage,
    *,
    enable_visual_message: bool = True,
) -> Optional[Message]:
    """将 Maisaka 内部上下文消息转换为发给 LLM 的统一消息。"""

    return context_message.to_llm_message(enable_visual_message=enable_visual_message)

@@ -799,7 +799,7 @@ class PromptCLIVisualizer:
        chat_id: str,
        request_kind: str,
        selection_reason: str,
        image_display_mode: Literal["legacy", "path_link"],
        image_display_mode: Literal["legacy", "path_link"] = "path_link",
        tool_definitions: list[dict[str, Any]] | None = None,
    ) -> RenderableType:
        """构建用于查看完整 prompt 的折叠入口内容。"""
@@ -864,7 +864,7 @@ class PromptCLIVisualizer:
        chat_id: str,
        request_kind: str,
        selection_reason: str,
        image_display_mode: Literal["legacy", "path_link"],
        image_display_mode: Literal["legacy", "path_link"] = "path_link",
        folded: bool,
        tool_definitions: list[dict[str, Any]] | None = None,
    ) -> Panel:
@@ -878,14 +878,10 @@ class PromptCLIVisualizer:
                chat_id=chat_id,
                request_kind=request_kind,
                selection_reason=selection_reason,
                image_display_mode=image_display_mode,
                tool_definitions=tool_definitions,
            )
        else:
            ordered_panels = cls.build_prompt_panels(
                messages,
                image_display_mode=image_display_mode,
            )
            ordered_panels = cls.build_prompt_panels(messages)
        prompt_renderable = Group(*ordered_panels)

        return Panel(
@@ -1102,11 +1098,9 @@ class PromptCLIVisualizer:
        cls,
        messages: list[Any],
        *,
        image_display_mode: Literal["legacy", "path_link"],
        image_display_mode: Literal["legacy", "path_link"] = "path_link",
    ) -> List[Panel]:
        """构建完整 prompt 可视化面板。"""
        if image_display_mode not in {mode.value for mode in PromptImageDisplayMode}:
            image_display_mode = PromptImageDisplayMode.LEGACY
        settings = PromptImageDisplaySettings(
            display_mode=PromptImageDisplayMode(image_display_mode),
        )

@@ -14,7 +14,7 @@ from src.chat.message_receive.message import SessionMessage
from src.common.data_models.message_component_data_model import EmojiComponent, ImageComponent, MessageSequence
from src.common.logger import get_logger
from src.common.prompt_i18n import load_prompt
from src.config.config import config_manager, global_config
from src.config.config import global_config
from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolSpec
from src.llm_models.exceptions import ReqAbortException
from src.llm_models.payload_content.tool_option import ToolCall
@@ -44,6 +44,7 @@ from .monitor_events import (
    emit_timing_gate_result,
)
from .planner_message_utils import build_planner_user_prefix_from_session_message
from .visual_mode_utils import resolve_enable_visual_planner

if TYPE_CHECKING:
    from .runtime import MaisakaHeartFlowChatting
@@ -739,47 +740,10 @@ class MaisakaReasoningEngine:
        planner_prefix: str,
    ) -> MessageSequence:
        message_sequence = build_prefixed_message_sequence(message.raw_message, planner_prefix)
        if self._resolve_enable_visual_planner():
        if resolve_enable_visual_planner():
            await self._hydrate_visual_components(message_sequence.components)
        return message_sequence

    @staticmethod
    def _resolve_enable_visual_planner() -> bool:
        planner_mode = global_config.visual.planner_mode
        planner_task_config = config_manager.get_model_config().model_task_config.planner
        models_by_name = {model.name: model for model in config_manager.get_model_config().models}

        if planner_mode == "text":
            return False

        planner_models: list[str] = list(planner_task_config.model_list)
        missing_models = [model_name for model_name in planner_models if model_name not in models_by_name]
        non_visual_models = [
            model_name for model_name in planner_models if model_name in models_by_name and not models_by_name[model_name].visual
        ]

        if planner_mode == "multimodal":
            if missing_models:
                raise ValueError(
                    "planner_mode=multimodal,但 planner 任务存在未定义的模型:"
                    f"{', '.join(missing_models)}"
                )
            if non_visual_models:
                raise ValueError(
                    "planner_mode=multimodal,但 planner 任务存在未开启 visual 的模型:"
                    f"{', '.join(non_visual_models)}"
                )
            return True

        if missing_models:
            logger.warning(
                "planner_mode=auto 时发现 planner 任务存在未定义模型:"
                f"{', '.join(missing_models)},将退化为纯文本 planner"
            )
            return False

        return bool(planner_models) and not non_visual_models

    async def _hydrate_visual_components(self, planner_components: list[object]) -> None:
        """在 Maisaka 真正需要图片或表情时,按需回填二进制数据。"""
        load_tasks: list[asyncio.Task[None]] = []

@@ -183,6 +183,43 @@ class MaisakaHeartFlowChatting:
        self._talk_frequency_adjust = max(0.01, float(frequency))
        self._schedule_message_turn()

    def append_sent_message_to_chat_history(
        self,
        message: SessionMessage,
        *,
        source_kind: str = "guided_reply",
    ) -> bool:
        """将一条已发送成功的消息同步到 Maisaka 内部历史。"""

        try:
            from .context_messages import SessionBackedMessage
            from .history_utils import build_prefixed_message_sequence, build_session_message_visible_text
            from .planner_message_utils import build_planner_prefix

            user_info = message.message_info.user_info
            speaker_name = user_info.user_cardname or user_info.user_nickname or user_info.user_id
            planner_prefix = build_planner_prefix(
                timestamp=message.timestamp,
                user_name=speaker_name,
                group_card=user_info.user_cardname or "",
                message_id=message.message_id,
                include_message_id=not message.is_notify and bool(message.message_id),
            )
            history_message = SessionBackedMessage.from_session_message(
                message,
                raw_message=build_prefixed_message_sequence(message.raw_message, planner_prefix),
                visible_text=build_session_message_visible_text(message),
                source_kind=source_kind,
            )
            self._chat_history.append(history_message)
            return True
        except Exception as exc:
            logger.warning(
                f"{self.log_prefix} 同步已发送消息到 Maisaka 历史失败: "
                f"message_id={message.message_id} error={exc}"
            )
            return False

    async def register_message(self, message: SessionMessage) -> None:
        """缓存一条新消息并唤醒主循环。"""
        if self._running:
@@ -1151,7 +1188,6 @@ class MaisakaHeartFlowChatting:
                chat_id=self.session_id,
                request_kind=labels["request_kind"],
                selection_reason=subtitle,
                image_display_mode="path_link" if global_config.maisaka.show_image_path else "legacy",
            ),
            title=labels["prompt_title"],
            border_style=border_style,

43 src/maisaka/visual_mode_utils.py Normal file
@@ -0,0 +1,43 @@
from src.common.logger import get_logger
from src.config.config import config_manager, global_config

logger = get_logger("maisaka_visual_mode")


def resolve_enable_visual_planner() -> bool:
    """根据 planner 配置解析当前是否应启用视觉消息。"""

    planner_mode = global_config.visual.planner_mode
    planner_task_config = config_manager.get_model_config().model_task_config.planner
    models_by_name = {model.name: model for model in config_manager.get_model_config().models}

    if planner_mode == "text":
        return False

    planner_models: list[str] = list(planner_task_config.model_list)
    missing_models = [model_name for model_name in planner_models if model_name not in models_by_name]
    non_visual_models = [
        model_name for model_name in planner_models if model_name in models_by_name and not models_by_name[model_name].visual
    ]

    if planner_mode == "multimodal":
        if missing_models:
            raise ValueError(
                "planner_mode=multimodal,但 planner 任务存在未定义的模型:"
                f"{', '.join(missing_models)}"
            )
        if non_visual_models:
            raise ValueError(
                "planner_mode=multimodal,但 planner 任务存在未开启 visual 的模型:"
                f"{', '.join(non_visual_models)}"
            )
        return True

    if missing_models:
        logger.warning(
            "planner_mode=auto 时发现 planner 任务存在未定义模型:"
            f"{', '.join(missing_models)},将退化为纯文本 planner"
        )
        return False

    return bool(planner_models) and not non_visual_models
File diff suppressed because it is too large
@@ -75,6 +75,8 @@ class RuntimeCoreCapabilityMixin:

        text = str(args.get("text", ""))
        stream_id = str(args.get("stream_id", ""))
        sync_to_maisaka_history = bool(args.get("sync_to_maisaka_history", False))
        maisaka_source_kind = str(args.get("maisaka_source_kind", "plugin_send") or "plugin_send")
        if not text or not stream_id:
            return {"success": False, "error": "缺少必要参数 text 或 stream_id"}

@@ -85,6 +87,8 @@ class RuntimeCoreCapabilityMixin:
                typing=bool(args.get("typing", False)),
                set_reply=bool(args.get("set_reply", False)),
                storage_message=bool(args.get("storage_message", True)),
                sync_to_maisaka_history=sync_to_maisaka_history,
                maisaka_source_kind=maisaka_source_kind,
            )
            return {"success": result}
        except Exception as exc:
@@ -107,6 +111,8 @@ class RuntimeCoreCapabilityMixin:

        emoji_base64 = str(args.get("emoji_base64", ""))
        stream_id = str(args.get("stream_id", ""))
        sync_to_maisaka_history = bool(args.get("sync_to_maisaka_history", False))
        maisaka_source_kind = str(args.get("maisaka_source_kind", "plugin_send") or "plugin_send")
        if not emoji_base64 or not stream_id:
            return {"success": False, "error": "缺少必要参数 emoji_base64 或 stream_id"}

@@ -115,6 +121,8 @@ class RuntimeCoreCapabilityMixin:
                emoji_base64=emoji_base64,
                stream_id=stream_id,
                storage_message=bool(args.get("storage_message", True)),
                sync_to_maisaka_history=sync_to_maisaka_history,
                maisaka_source_kind=maisaka_source_kind,
            )
            return {"success": result}
        except Exception as exc:
@@ -137,6 +145,8 @@ class RuntimeCoreCapabilityMixin:

        image_base64 = str(args.get("image_base64", ""))
        stream_id = str(args.get("stream_id", ""))
        sync_to_maisaka_history = bool(args.get("sync_to_maisaka_history", False))
        maisaka_source_kind = str(args.get("maisaka_source_kind", "plugin_send") or "plugin_send")
        if not image_base64 or not stream_id:
            return {"success": False, "error": "缺少必要参数 image_base64 或 stream_id"}

@@ -145,6 +155,8 @@ class RuntimeCoreCapabilityMixin:
                image_base64=image_base64,
                stream_id=stream_id,
                storage_message=bool(args.get("storage_message", True)),
                sync_to_maisaka_history=sync_to_maisaka_history,
                maisaka_source_kind=maisaka_source_kind,
            )
            return {"success": result}
        except Exception as exc:
@@ -167,6 +179,8 @@ class RuntimeCoreCapabilityMixin:

        command = str(args.get("command", ""))
        stream_id = str(args.get("stream_id", ""))
        sync_to_maisaka_history = bool(args.get("sync_to_maisaka_history", False))
        maisaka_source_kind = str(args.get("maisaka_source_kind", "plugin_send") or "plugin_send")
        if not command or not stream_id:
            return {"success": False, "error": "缺少必要参数 command 或 stream_id"}

@@ -177,6 +191,8 @@ class RuntimeCoreCapabilityMixin:
                stream_id=stream_id,
                storage_message=bool(args.get("storage_message", True)),
                display_message=str(args.get("display_message", "")),
                sync_to_maisaka_history=sync_to_maisaka_history,
                maisaka_source_kind=maisaka_source_kind,
            )
            return {"success": result}
        except Exception as exc:
@@ -202,6 +218,8 @@ class RuntimeCoreCapabilityMixin:
        if content is None:
            content = args.get("data", "")
        stream_id = str(args.get("stream_id", ""))
        sync_to_maisaka_history = bool(args.get("sync_to_maisaka_history", False))
        maisaka_source_kind = str(args.get("maisaka_source_kind", "plugin_send") or "plugin_send")
        if not message_type or not stream_id:
            return {"success": False, "error": "缺少必要参数 message_type 或 stream_id"}

@@ -213,6 +231,8 @@ class RuntimeCoreCapabilityMixin:
                display_message=str(args.get("display_message", "")),
                typing=bool(args.get("typing", False)),
                storage_message=bool(args.get("storage_message", True)),
                sync_to_maisaka_history=sync_to_maisaka_history,
                maisaka_source_kind=maisaka_source_kind,
            )
            return {"success": result}
        except Exception as exc:

@@ -4,16 +4,18 @@ import json
import time
import traceback
from datetime import datetime
from typing import Any, Optional, cast
from typing import TYPE_CHECKING, Any, Optional, cast

from sqlalchemy import delete, func, select
from sqlmodel import SQLModel
from sqlalchemy import delete, func
from sqlmodel import SQLModel, select

from src.chat.message_receive.chat_manager import BotChatSession
from src.common.database.database import get_db_session
from src.common.database.database_model import ToolRecord
from src.common.logger import get_logger

if TYPE_CHECKING:
    from src.chat.message_receive.chat_manager import BotChatSession

logger = get_logger("database_service")


@@ -158,7 +160,7 @@ async def db_count(model_class: type[SQLModel], filters: Optional[dict[str, Any]


async def store_tool_info(
    chat_stream: BotChatSession,
    chat_stream: "BotChatSession",
    builtin_prompt: Optional[str] = None,
    display_prompt: str = "",
    tool_id: str = "",
@@ -191,7 +193,7 @@ async def store_tool_info(


async def store_action_info(
    chat_stream: BotChatSession,
    chat_stream: "BotChatSession",
    builtin_prompt: Optional[str] = None,
    display_prompt: str = "",
    thinking_id: str = "",

@@ -2,54 +2,20 @@ from __future__ import annotations

import asyncio
import json
from typing import Any, Dict, List, Optional
from typing import Any, List, Optional

from json_repair import repair_json

from src.chat.utils.utils import is_bot_self
from src.common.message_repository import find_messages
from src.common.logger import get_logger
from src.common.message_repository import find_messages
from src.config.config import global_config
from src.memory_system.chat_history_summarizer import ChatHistorySummarizer
from src.person_info.person_info import Person, get_person_id, store_person_memory_from_answer
from src.services.llm_service import LLMServiceClient

logger = get_logger("memory_flow_service")


class LongTermMemorySessionManager:
    def __init__(self) -> None:
        self._lock = asyncio.Lock()
        self._summarizers: Dict[str, ChatHistorySummarizer] = {}

    async def on_message(self, message: Any) -> None:
        if not bool(getattr(global_config.memory, "long_term_auto_summary_enabled", True)):
            return
        session_id = str(getattr(message, "session_id", "") or "").strip()
        if not session_id:
            return

        created = False
        async with self._lock:
            summarizer = self._summarizers.get(session_id)
            if summarizer is None:
                summarizer = ChatHistorySummarizer(session_id=session_id)
                self._summarizers[session_id] = summarizer
                created = True
        if created:
            await summarizer.start()

    async def shutdown(self) -> None:
        async with self._lock:
            items = list(self._summarizers.items())
            self._summarizers.clear()
        for session_id, summarizer in items:
            try:
                await summarizer.stop()
            except Exception as exc:
                logger.warning("停止聊天总结器失败: session=%s err=%s", session_id, exc)


class PersonFactWritebackService:
    def __init__(self) -> None:
        self._queue: asyncio.Queue[Any] = asyncio.Queue(maxsize=256)
@@ -123,7 +89,11 @@ class PersonFactWritebackService:
        if not session_id:
            return

        person_name = str(getattr(target_person, "person_name", "") or getattr(target_person, "nickname", "") or "").strip()
        person_name = str(
            getattr(target_person, "person_name", "")
            or getattr(target_person, "nickname", "")
            or ""
        ).strip()
        if not person_name:
            return

@@ -242,7 +212,6 @@ class PersonFactWritebackService:

class MemoryAutomationService:
    def __init__(self) -> None:
        self.session_manager = LongTermMemorySessionManager()
        self.fact_writeback = PersonFactWritebackService()
        self._started = False

@@ -255,15 +224,9 @@ class MemoryAutomationService:
    async def shutdown(self) -> None:
        if not self._started:
            return
        await self.session_manager.shutdown()
        await self.fact_writeback.shutdown()
        self._started = False

    async def on_incoming_message(self, message: Any) -> None:
        if not self._started:
            await self.start()
        await self.session_manager.on_message(message)

    async def on_message_sent(self, message: Any) -> None:
        if not self._started:
            await self.start()

@@ -707,6 +707,28 @@ async def _notify_memory_automation_on_message_sent(message: SessionMessage) ->
        logger.warning(f"[{session_id}] 长期记忆人物事实写回注册失败: {exc}")


def _sync_sent_message_to_maisaka_history(
    message: SessionMessage,
    *,
    source_kind: str,
) -> None:
    """将已发送成功的消息同步到当前会话对应的 Maisaka 历史。"""

    session_id = str(message.session_id or "").strip()
    if not session_id:
        return

    try:
        from src.chat.heart_flow.heartflow_manager import heartflow_manager

        runtime = heartflow_manager.heartflow_chat_list.get(session_id)
        if runtime is None:
            return
        runtime.append_sent_message_to_chat_history(message, source_kind=source_kind)
    except Exception as exc:
        logger.warning(f"[SendService] 同步消息到 Maisaka 历史失败: session_id={session_id} error={exc}")


def _log_platform_io_failures(delivery_batch: DeliveryBatch) -> None:
    """输出 Platform IO 批量发送失败详情。

@@ -837,13 +859,15 @@ async def send_session_message_with_message(
    reply_message_id: Optional[str] = None,
    storage_message: bool = True,
    show_log: bool = True,
    sync_to_maisaka_history: bool = False,
    maisaka_source_kind: str = "outbound_send",
) -> Optional[SessionMessage]:
    """统一发送一条内部消息,并返回最终发送成功的消息对象。"""
    if not message.message_id:
        logger.error("[SendService] 消息缺少 message_id,无法发送")
        raise ValueError("消息缺少 message_id,无法发送")

    return await _send_via_platform_io(
    sent_message = await _send_via_platform_io(
        message,
        typing=typing,
        set_reply=set_reply,
@@ -851,6 +875,12 @@ async def send_session_message_with_message(
        storage_message=storage_message,
        show_log=show_log,
    )
    if sent_message is not None and sync_to_maisaka_history:
        _sync_sent_message_to_maisaka_history(
            sent_message,
            source_kind=str(maisaka_source_kind or "outbound_send"),
        )
    return sent_message

async def send_session_message(
@@ -861,6 +891,8 @@ async def send_session_message(
    reply_message_id: Optional[str] = None,
    storage_message: bool = True,
    show_log: bool = True,
    sync_to_maisaka_history: bool = False,
    maisaka_source_kind: str = "outbound_send",
) -> bool:
    """统一发送一条内部消息。

@@ -893,6 +925,8 @@ async def send_session_message(
            reply_message_id=reply_message_id,
            storage_message=storage_message,
            show_log=show_log,
            sync_to_maisaka_history=sync_to_maisaka_history,
            maisaka_source_kind=maisaka_source_kind,
        )
        is not None
    )
@@ -908,6 +942,8 @@ async def _send_to_target(
    storage_message: bool = True,
    show_log: bool = True,
    selected_expressions: Optional[List[int]] = None,
    sync_to_maisaka_history: bool = False,
    maisaka_source_kind: str = "outbound_send",
) -> bool:
    """向指定目标构建并发送消息,并返回是否发送成功。"""
    return (
@@ -921,6 +957,8 @@ async def _send_to_target(
            storage_message=storage_message,
            show_log=show_log,
            selected_expressions=selected_expressions,
            sync_to_maisaka_history=sync_to_maisaka_history,
            maisaka_source_kind=maisaka_source_kind,
        )
        is not None
    )
@@ -936,6 +974,8 @@ async def _send_to_target_with_message(
    storage_message: bool = True,
    show_log: bool = True,
    selected_expressions: Optional[List[int]] = None,
    sync_to_maisaka_history: bool = False,
    maisaka_source_kind: str = "outbound_send",
) -> Optional[SessionMessage]:
    """向指定目标构建并发送消息。

@@ -998,6 +1038,8 @@ async def _send_to_target_with_message(
        reply_message_id=reply_message.message_id if reply_message is not None else None,
        storage_message=storage_message,
        show_log=show_log,
        sync_to_maisaka_history=sync_to_maisaka_history,
        maisaka_source_kind=maisaka_source_kind,
    )
    if sent_message is not None:
        logger.debug(f"[SendService] 成功发送消息到 {stream_id}")
@@ -1019,6 +1061,8 @@ async def text_to_stream_with_message(
    reply_message: Optional[MaiMessage] = None,
    storage_message: bool = True,
    selected_expressions: Optional[List[int]] = None,
    sync_to_maisaka_history: bool = False,
    maisaka_source_kind: str = "outbound_send",
) -> Optional[SessionMessage]:
    """向指定流发送文本消息,并返回发送成功后的消息对象。"""
    return await _send_to_target_with_message(
@@ -1030,6 +1074,8 @@ async def text_to_stream_with_message(
        reply_message=reply_message,
        storage_message=storage_message,
        selected_expressions=selected_expressions,
        sync_to_maisaka_history=sync_to_maisaka_history,
        maisaka_source_kind=maisaka_source_kind,
    )


@@ -1041,6 +1087,8 @@ async def text_to_stream(
    reply_message: Optional[MaiMessage] = None,
    storage_message: bool = True,
    selected_expressions: Optional[List[int]] = None,
    sync_to_maisaka_history: bool = False,
    maisaka_source_kind: str = "outbound_send",
) -> bool:
    """向指定流发送文本消息。

@@ -1065,6 +1113,8 @@ async def text_to_stream(
        reply_message=reply_message,
        storage_message=storage_message,
        selected_expressions=selected_expressions,
        sync_to_maisaka_history=sync_to_maisaka_history,
        maisaka_source_kind=maisaka_source_kind,
        )
        is not None
    )
@@ -1076,6 +1126,8 @@ async def emoji_to_stream_with_message(
    storage_message: bool = True,
    set_reply: bool = False,
    reply_message: Optional[MaiMessage] = None,
    sync_to_maisaka_history: bool = False,
    maisaka_source_kind: str = "outbound_send",
) -> Optional[SessionMessage]:
    """向指定流发送表情消息,并返回发送成功后的消息对象。"""
    return await _send_to_target_with_message(
@@ -1086,6 +1138,8 @@ async def emoji_to_stream_with_message(
        storage_message=storage_message,
        set_reply=set_reply,
        reply_message=reply_message,
        sync_to_maisaka_history=sync_to_maisaka_history,
        maisaka_source_kind=maisaka_source_kind,
    )


@@ -1095,6 +1149,8 @@ async def emoji_to_stream(
    storage_message: bool = True,
    set_reply: bool = False,
    reply_message: Optional[MaiMessage] = None,
    sync_to_maisaka_history: bool = False,
    maisaka_source_kind: str = "outbound_send",
) -> bool:
    """向指定流发送表情消息。

@@ -1115,6 +1171,8 @@ async def emoji_to_stream(
        storage_message=storage_message,
        set_reply=set_reply,
        reply_message=reply_message,
        sync_to_maisaka_history=sync_to_maisaka_history,
        maisaka_source_kind=maisaka_source_kind,
        )
        is not None
    )

@@ -1126,6 +1184,8 @@ async def image_to_stream(
    storage_message: bool = True,
    set_reply: bool = False,
    reply_message: Optional[MaiMessage] = None,
    sync_to_maisaka_history: bool = False,
    maisaka_source_kind: str = "outbound_send",
) -> bool:
    """向指定流发送图片消息。

@@ -1147,6 +1207,8 @@ async def image_to_stream(
        storage_message=storage_message,
        set_reply=set_reply,
        reply_message=reply_message,
        sync_to_maisaka_history=sync_to_maisaka_history,
        maisaka_source_kind=maisaka_source_kind,
    )


@@ -1160,6 +1222,8 @@ async def custom_to_stream(
    set_reply: bool = False,
    storage_message: bool = True,
    show_log: bool = True,
    sync_to_maisaka_history: bool = False,
    maisaka_source_kind: str = "outbound_send",
) -> bool:
    """向指定流发送自定义类型消息。

@@ -1186,6 +1250,8 @@ async def custom_to_stream(
        set_reply=set_reply,
        storage_message=storage_message,
        show_log=show_log,
        sync_to_maisaka_history=sync_to_maisaka_history,
        maisaka_source_kind=maisaka_source_kind,
    )


@@ -1198,6 +1264,8 @@ async def custom_reply_set_to_stream(
    set_reply: bool = False,
    storage_message: bool = True,
    show_log: bool = True,
    sync_to_maisaka_history: bool = False,
    maisaka_source_kind: str = "outbound_send",
) -> bool:
    """向指定流发送消息组件序列。

@@ -1223,4 +1291,6 @@ async def custom_reply_set_to_stream(
        set_reply=set_reply,
        storage_message=storage_message,
        show_log=show_log,
        sync_to_maisaka_history=sync_to_maisaka_history,
        maisaka_source_kind=maisaka_source_kind,
    )