fix:解决 upstream 合并冲突

在临时整合分支中合并 upstream/r-dev
保留人物事实写回与反馈纠错配置
移除已下线的聊天总结配置并同步测试
This commit is contained in:
A-Dawn
2026-04-16 10:27:25 +08:00
31 changed files with 349 additions and 5128 deletions

View File

@@ -1,148 +0,0 @@
from types import SimpleNamespace
import pytest
from src.memory_system import chat_history_summarizer as summarizer_module
def _build_summarizer() -> summarizer_module.ChatHistorySummarizer:
summarizer = summarizer_module.ChatHistorySummarizer.__new__(summarizer_module.ChatHistorySummarizer)
summarizer.session_id = "session-1"
summarizer.log_prefix = "[session-1]"
return summarizer
@pytest.mark.asyncio
async def test_import_to_long_term_memory_uses_summary_payload(monkeypatch):
calls = []
summarizer = _build_summarizer()
async def fake_ingest_summary(**kwargs):
calls.append(kwargs)
return SimpleNamespace(success=True, detail="", stored_ids=["p1"])
monkeypatch.setattr(
summarizer_module,
"_chat_manager",
SimpleNamespace(get_session_by_session_id=lambda session_id: SimpleNamespace(user_id="user-1", group_id="")),
)
monkeypatch.setattr(summarizer_module, "global_config", SimpleNamespace(memory=SimpleNamespace(chat_history_topic_check_message_threshold=8)))
monkeypatch.setattr("src.services.memory_service.memory_service.ingest_summary", fake_ingest_summary)
await summarizer._import_to_long_term_memory(
record_id=1,
theme="旅行计划",
summary="我们讨论了春游安排",
participants=["Alice", "Bob"],
start_time=1.0,
end_time=2.0,
original_text="long text",
)
assert len(calls) == 1
payload = calls[0]
assert payload["external_id"] == "chat_history:1"
assert payload["chat_id"] == "session-1"
assert payload["participants"] == ["Alice", "Bob"]
assert payload["respect_filter"] is True
assert payload["user_id"] == "user-1"
assert payload["group_id"] == ""
assert "主题:旅行计划" in payload["text"]
assert "概括:我们讨论了春游安排" in payload["text"]
@pytest.mark.asyncio
async def test_import_to_long_term_memory_falls_back_when_content_empty(monkeypatch):
summarizer = _build_summarizer()
fallback_calls = []
async def fake_fallback(**kwargs):
fallback_calls.append(kwargs)
summarizer._fallback_import_to_long_term_memory = fake_fallback
monkeypatch.setattr(
summarizer_module,
"_chat_manager",
SimpleNamespace(get_session_by_session_id=lambda session_id: SimpleNamespace(user_id="user-1", group_id="")),
)
await summarizer._import_to_long_term_memory(
record_id=2,
theme="",
summary="",
participants=[],
start_time=10.0,
end_time=20.0,
original_text="raw chat",
)
assert len(fallback_calls) == 1
assert fallback_calls[0]["record_id"] == 2
assert fallback_calls[0]["original_text"] == "raw chat"
@pytest.mark.asyncio
async def test_import_to_long_term_memory_falls_back_when_ingest_fails(monkeypatch):
summarizer = _build_summarizer()
fallback_calls = []
async def fake_ingest_summary(**kwargs):
return SimpleNamespace(success=False, detail="boom", stored_ids=[])
async def fake_fallback(**kwargs):
fallback_calls.append(kwargs)
summarizer._fallback_import_to_long_term_memory = fake_fallback
monkeypatch.setattr(
summarizer_module,
"_chat_manager",
SimpleNamespace(get_session_by_session_id=lambda session_id: SimpleNamespace(user_id="user-1", group_id="group-1")),
)
monkeypatch.setattr("src.services.memory_service.memory_service.ingest_summary", fake_ingest_summary)
await summarizer._import_to_long_term_memory(
record_id=3,
theme="电影",
summary="聊了电影推荐",
participants=["Alice"],
start_time=3.0,
end_time=4.0,
original_text="raw",
)
assert len(fallback_calls) == 1
assert fallback_calls[0]["theme"] == "电影"
@pytest.mark.asyncio
async def test_fallback_import_to_long_term_memory_sets_generate_from_chat(monkeypatch):
calls = []
summarizer = _build_summarizer()
async def fake_ingest_summary(**kwargs):
calls.append(kwargs)
return SimpleNamespace(success=True, detail="chat_filtered", stored_ids=[])
monkeypatch.setattr(
summarizer_module,
"_chat_manager",
SimpleNamespace(get_session_by_session_id=lambda session_id: SimpleNamespace(user_id="user-2", group_id="group-2")),
)
monkeypatch.setattr(summarizer_module, "global_config", SimpleNamespace(memory=SimpleNamespace(chat_history_topic_check_message_threshold=12)))
monkeypatch.setattr("src.services.memory_service.memory_service.ingest_summary", fake_ingest_summary)
await summarizer._fallback_import_to_long_term_memory(
record_id=4,
theme="工作",
participants=["Alice"],
start_time=5.0,
end_time=6.0,
original_text="a" * 128,
)
assert len(calls) == 1
metadata = calls[0]["metadata"]
assert metadata["generate_from_chat"] is True
assert metadata["context_length"] == 12
assert calls[0]["respect_filter"] is True

View File

@@ -1,687 +0,0 @@
from __future__ import annotations
import asyncio
import inspect
import json
from datetime import datetime
from pathlib import Path
from types import SimpleNamespace
from typing import Any, Dict, List
import numpy as np
import pytest
import pytest_asyncio
from A_memorix.core.runtime import sdk_memory_kernel as kernel_module
from A_memorix.core.runtime.sdk_memory_kernel import KernelSearchRequest, SDKMemoryKernel
from src.chat.brain_chat.PFC import pfc_KnowledgeFetcher as knowledge_module
from src.memory_system import chat_history_summarizer as summarizer_module
from src.memory_system.retrieval_tools.query_long_term_memory import query_long_term_memory
from src.person_info import person_info as person_info_module
from src.services import memory_service as memory_service_module
from src.services.memory_service import MemorySearchResult, memory_service
DATA_FILE = Path(__file__).parent / "data" / "benchmarks" / "long_novel_memory_benchmark.json"
REPORT_FILE = Path(__file__).parent / "data" / "benchmarks" / "results" / "long_novel_memory_benchmark_report.json"
def _load_benchmark_fixture() -> Dict[str, Any]:
return json.loads(DATA_FILE.read_text(encoding="utf-8"))
class _FakeEmbeddingAdapter:
def __init__(self, dimension: int = 32) -> None:
self.dimension = dimension
async def _detect_dimension(self) -> int:
return self.dimension
async def encode(self, texts, dimensions=None):
dim = int(dimensions or self.dimension)
if isinstance(texts, str):
sequence = [texts]
single = True
else:
sequence = list(texts)
single = False
rows = []
for text in sequence:
vec = np.zeros(dim, dtype=np.float32)
for ch in str(text or ""):
code = ord(ch)
vec[code % dim] += 1.0
vec[(code * 7) % dim] += 0.5
if not vec.any():
vec[0] = 1.0
norm = np.linalg.norm(vec)
if norm > 0:
vec = vec / norm
rows.append(vec)
payload = np.vstack(rows)
return payload[0] if single else payload
class _KnownPerson:
def __init__(self, person_id: str, registry: Dict[str, str], reverse_registry: Dict[str, str]) -> None:
self.person_id = person_id
self.is_known = person_id in reverse_registry
self.person_name = reverse_registry.get(person_id, "")
self._registry = registry
class _KernelBackedRuntimeManager:
def __init__(self, kernel: SDKMemoryKernel) -> None:
self.kernel = kernel
async def invoke(self, component_name: str, args: Dict[str, Any] | None, *, timeout_ms: int = 30000):
del timeout_ms
payload = args or {}
if component_name == "search_memory":
return await self.kernel.search_memory(
KernelSearchRequest(
query=str(payload.get("query", "") or ""),
limit=int(payload.get("limit", 5) or 5),
mode=str(payload.get("mode", "hybrid") or "hybrid"),
chat_id=str(payload.get("chat_id", "") or ""),
person_id=str(payload.get("person_id", "") or ""),
time_start=payload.get("time_start"),
time_end=payload.get("time_end"),
respect_filter=bool(payload.get("respect_filter", True)),
user_id=str(payload.get("user_id", "") or ""),
group_id=str(payload.get("group_id", "") or ""),
)
)
handler = getattr(self.kernel, component_name)
result = handler(**payload)
return await result if inspect.isawaitable(result) else result
async def _wait_for_import_task(task_id: str, *, max_rounds: int = 200, sleep_seconds: float = 0.05) -> Dict[str, Any]:
for _ in range(max_rounds):
detail = await memory_service.import_admin(action="get", task_id=task_id, include_chunks=True)
task = detail.get("task") or {}
status = str(task.get("status", "") or "")
if status in {"completed", "completed_with_errors", "failed", "cancelled"}:
return detail
await asyncio.sleep(max(0.01, float(sleep_seconds)))
raise AssertionError(f"导入任务在等待窗口内未结束: {task_id}")
def _join_hit_content(search_result: MemorySearchResult) -> str:
return "\n".join(hit.content for hit in search_result.hits)
def _keyword_hits(text: str, keywords: List[str]) -> int:
haystack = str(text or "")
return sum(1 for keyword in keywords if keyword in haystack)
def _keyword_recall(text: str, keywords: List[str]) -> float:
if not keywords:
return 1.0
return _keyword_hits(text, keywords) / float(len(keywords))
def _hit_blob(hit) -> str:
meta = hit.metadata if isinstance(hit.metadata, dict) else {}
return "\n".join(
[
str(hit.content or ""),
str(hit.title or ""),
str(hit.source or ""),
json.dumps(meta, ensure_ascii=False),
]
)
def _first_relevant_rank(search_result: MemorySearchResult, keywords: List[str], minimum_keyword_hits: int) -> int:
for index, hit in enumerate(search_result.hits[:5], start=1):
if _keyword_hits(_hit_blob(hit), keywords) >= max(1, int(minimum_keyword_hits or len(keywords))):
return index
return 0
def _episode_blob_from_items(items: List[Dict[str, Any]]) -> str:
return "\n".join(
(
f"{item.get('title', '')}\n"
f"{item.get('summary', '')}\n"
f"{json.dumps(item.get('keywords', []), ensure_ascii=False)}\n"
f"{json.dumps(item.get('participants', []), ensure_ascii=False)}"
)
for item in items
)
def _episode_blob_from_hits(search_result: MemorySearchResult) -> str:
chunks = []
for hit in search_result.hits:
meta = hit.metadata if isinstance(hit.metadata, dict) else {}
chunks.append(
"\n".join(
[
str(hit.title or ""),
str(hit.content or ""),
json.dumps(meta.get("keywords", []) or [], ensure_ascii=False),
json.dumps(meta.get("participants", []) or [], ensure_ascii=False),
]
)
)
return "\n".join(chunks)
async def _evaluate_episode_generation(*, session_id: str, episode_cases: List[Dict[str, Any]]) -> Dict[str, Any]:
episode_source = f"chat_summary:{session_id}"
payload = await memory_service.episode_admin(
action="query",
source=episode_source,
limit=20,
)
items = payload.get("items") or []
blob = _episode_blob_from_items(items)
reports: List[Dict[str, Any]] = []
success_rate = 0.0
keyword_recall = 0.0
for case in episode_cases:
recall = _keyword_recall(blob, case["expected_keywords"])
success = bool(items) and recall >= float(case.get("minimum_keyword_recall", 1.0))
success_rate += 1.0 if success else 0.0
keyword_recall += recall
reports.append(
{
"query": case["query"],
"success": success,
"keyword_recall": recall,
"episode_count": len(items),
"top_episode": items[0] if items else None,
}
)
total = max(1, len(episode_cases))
return {
"success_rate": round(success_rate / total, 4),
"keyword_recall": round(keyword_recall / total, 4),
"episode_count": len(items),
"reports": reports,
}
async def _evaluate_episode_admin_query(*, session_id: str, episode_cases: List[Dict[str, Any]]) -> Dict[str, Any]:
reports: List[Dict[str, Any]] = []
success_rate = 0.0
keyword_recall = 0.0
episode_source = f"chat_summary:{session_id}"
for case in episode_cases:
payload = await memory_service.episode_admin(
action="query",
source=episode_source,
query=case["query"],
limit=5,
)
items = payload.get("items") or []
blob = "\n".join(
f"{item.get('title', '')}\n{item.get('summary', '')}\n{json.dumps(item.get('keywords', []), ensure_ascii=False)}"
for item in items
)
recall = _keyword_recall(blob, case["expected_keywords"])
success = bool(items) and recall >= float(case.get("minimum_keyword_recall", 1.0))
success_rate += 1.0 if success else 0.0
keyword_recall += recall
reports.append(
{
"query": case["query"],
"success": success,
"keyword_recall": recall,
"episode_count": len(items),
"top_episode": items[0] if items else None,
}
)
total = max(1, len(episode_cases))
return {
"success_rate": round(success_rate / total, 4),
"keyword_recall": round(keyword_recall / total, 4),
"reports": reports,
}
async def _evaluate_episode_search_mode(*, session_id: str, episode_cases: List[Dict[str, Any]]) -> Dict[str, Any]:
reports: List[Dict[str, Any]] = []
success_rate = 0.0
keyword_recall = 0.0
for case in episode_cases:
result = await memory_service.search(
case["query"],
mode="episode",
chat_id=session_id,
respect_filter=False,
limit=5,
)
blob = _episode_blob_from_hits(result)
recall = _keyword_recall(blob, case["expected_keywords"])
success = bool(result.hits) and recall >= float(case.get("minimum_keyword_recall", 1.0))
success_rate += 1.0 if success else 0.0
keyword_recall += recall
reports.append(
{
"query": case["query"],
"success": success,
"keyword_recall": recall,
"episode_count": len(result.hits),
"top_episode": result.hits[0].to_dict() if result.hits else None,
}
)
total = max(1, len(episode_cases))
return {
"success_rate": round(success_rate / total, 4),
"keyword_recall": round(keyword_recall / total, 4),
"reports": reports,
}
async def _evaluate_tool_modes(*, session_id: str, dataset: Dict[str, Any]) -> Dict[str, Any]:
search_case = dataset["search_cases"][0]
episode_case = dataset["episode_cases"][0]
aggregate_case = dataset["knowledge_fetcher_cases"][0]
first_record = (dataset.get("chat_history_records") or [{}])[0]
reference_ts = first_record.get("end_time") or first_record.get("start_time") or 0
if reference_ts:
time_expression = datetime.fromtimestamp(float(reference_ts)).strftime("%Y/%m/%d")
else:
time_expression = "最近7天"
tool_cases = [
{
"name": "search",
"kwargs": {
"query": "蓝漆铁盒 北塔木梯",
"mode": "search",
"chat_id": session_id,
"limit": 5,
},
"expected_keywords": ["蓝漆铁盒", "北塔木梯", "海潮图"],
"minimum_keyword_recall": 0.67,
},
{
"name": "time",
"kwargs": {
"query": "蓝漆铁盒 北塔",
"mode": "time",
"chat_id": session_id,
"limit": 5,
"time_expression": time_expression,
},
"expected_keywords": ["蓝漆铁盒", "北塔木梯"],
"minimum_keyword_recall": 0.67,
},
{
"name": "episode",
"kwargs": {
"query": episode_case["query"],
"mode": "episode",
"chat_id": session_id,
"limit": 5,
},
"expected_keywords": episode_case["expected_keywords"],
"minimum_keyword_recall": 0.67,
},
{
"name": "aggregate",
"kwargs": {
"query": aggregate_case["query"],
"mode": "aggregate",
"chat_id": session_id,
"limit": 5,
},
"expected_keywords": aggregate_case["expected_keywords"],
"minimum_keyword_recall": 0.67,
},
]
reports: List[Dict[str, Any]] = []
success_rate = 0.0
keyword_recall = 0.0
for case in tool_cases:
text = await query_long_term_memory(**case["kwargs"])
recall = _keyword_recall(text, case["expected_keywords"])
success = (
"失败" not in text
and "无法解析" not in text
and "未找到" not in text
and recall >= float(case["minimum_keyword_recall"])
)
success_rate += 1.0 if success else 0.0
keyword_recall += recall
reports.append(
{
"name": case["name"],
"success": success,
"keyword_recall": recall,
"preview": text[:320],
}
)
total = max(1, len(tool_cases))
return {
"success_rate": round(success_rate / total, 4),
"keyword_recall": round(keyword_recall / total, 4),
"reports": reports,
}
@pytest_asyncio.fixture
async def benchmark_env(monkeypatch, tmp_path):
dataset = _load_benchmark_fixture()
session_cfg = dataset["session"]
session = SimpleNamespace(
session_id=session_cfg["session_id"],
platform=session_cfg["platform"],
user_id=session_cfg["user_id"],
group_id=session_cfg["group_id"],
)
fake_chat_manager = SimpleNamespace(
get_session_by_session_id=lambda session_id: session if session_id == session.session_id else None,
get_session_name=lambda session_id: session_cfg["display_name"] if session_id == session.session_id else session_id,
)
registry = {item["person_name"]: item["person_id"] for item in dataset["person_writebacks"]}
reverse_registry = {value: key for key, value in registry.items()}
monkeypatch.setattr(kernel_module, "create_embedding_api_adapter", lambda **kwargs: _FakeEmbeddingAdapter())
async def fake_self_check(**kwargs):
return {"ok": True, "message": "ok", "encoded_dimension": 32}
monkeypatch.setattr(kernel_module, "run_embedding_runtime_self_check", fake_self_check)
monkeypatch.setattr(summarizer_module, "_chat_manager", fake_chat_manager)
monkeypatch.setattr(knowledge_module, "_chat_manager", fake_chat_manager)
monkeypatch.setattr(person_info_module, "_chat_manager", fake_chat_manager)
monkeypatch.setattr(person_info_module, "get_person_id_by_person_name", lambda person_name: registry.get(str(person_name or "").strip(), ""))
monkeypatch.setattr(
person_info_module,
"Person",
lambda person_id: _KnownPerson(person_id=str(person_id or ""), registry=registry, reverse_registry=reverse_registry),
)
data_dir = (tmp_path / "a_memorix_benchmark_data").resolve()
kernel = SDKMemoryKernel(
plugin_root=tmp_path / "plugin_root",
config={
"storage": {"data_dir": str(data_dir)},
"advanced": {"enable_auto_save": False},
"memory": {"base_decay_interval_hours": 24},
"person_profile": {"refresh_interval_minutes": 5},
},
)
manager = _KernelBackedRuntimeManager(kernel)
monkeypatch.setattr(memory_service_module, "a_memorix_host_service", manager)
await kernel.initialize()
try:
yield {
"dataset": dataset,
"kernel": kernel,
"session": session,
"person_registry": registry,
}
finally:
await kernel.shutdown()
@pytest.mark.asyncio
async def test_long_novel_memory_benchmark(benchmark_env):
dataset = benchmark_env["dataset"]
session_id = benchmark_env["session"].session_id
created = await memory_service.import_admin(
action="create_paste",
name="long_novel_memory_benchmark.json",
input_mode="json",
llm_enabled=False,
content=json.dumps(dataset["import_payload"], ensure_ascii=False),
)
assert created["success"] is True
import_detail = await _wait_for_import_task(created["task"]["task_id"])
assert import_detail["task"]["status"] == "completed"
for record in dataset["chat_history_records"]:
summarizer = summarizer_module.ChatHistorySummarizer(session_id)
await summarizer._import_to_long_term_memory(
record_id=record["record_id"],
theme=record["theme"],
summary=record["summary"],
participants=record["participants"],
start_time=record["start_time"],
end_time=record["end_time"],
original_text=record["original_text"],
)
for payload in dataset["person_writebacks"]:
await person_info_module.store_person_memory_from_answer(
payload["person_name"],
payload["memory_content"],
session_id,
)
await memory_service.episode_admin(action="process_pending", limit=100, max_retry=2)
search_case_reports: List[Dict[str, Any]] = []
search_accuracy_at_1 = 0.0
search_recall_at_5 = 0.0
search_precision_at_5 = 0.0
search_mrr = 0.0
search_keyword_recall = 0.0
for case in dataset["search_cases"]:
result = await memory_service.search(case["query"], mode="search", respect_filter=False, limit=5)
joined = _join_hit_content(result)
rank = _first_relevant_rank(result, case["expected_keywords"], case.get("minimum_keyword_hits", len(case["expected_keywords"])))
relevant_hits = sum(
1
for hit in result.hits[:5]
if _keyword_hits(_hit_blob(hit), case["expected_keywords"]) >= max(1, int(case.get("minimum_keyword_hits", len(case["expected_keywords"]))))
)
keyword_recall = _keyword_recall(joined, case["expected_keywords"])
search_accuracy_at_1 += 1.0 if rank == 1 else 0.0
search_recall_at_5 += 1.0 if rank > 0 else 0.0
search_precision_at_5 += relevant_hits / float(max(1, min(5, len(result.hits))))
search_mrr += 1.0 / float(rank) if rank > 0 else 0.0
search_keyword_recall += keyword_recall
search_case_reports.append(
{
"query": case["query"],
"rank_of_first_relevant": rank,
"relevant_hits_top5": relevant_hits,
"keyword_recall_top5": keyword_recall,
"top_hit": result.hits[0].to_dict() if result.hits else None,
}
)
search_total = max(1, len(dataset["search_cases"]))
writeback_reports: List[Dict[str, Any]] = []
writeback_success_rate = 0.0
writeback_keyword_recall = 0.0
for payload in dataset["person_writebacks"]:
query = " ".join(payload["expected_keywords"])
result = await memory_service.search(
query,
mode="search",
chat_id=session_id,
person_id=payload["person_id"],
respect_filter=False,
limit=5,
)
joined = _join_hit_content(result)
recall = _keyword_recall(joined, payload["expected_keywords"])
success = bool(result.hits) and recall >= 0.67
writeback_success_rate += 1.0 if success else 0.0
writeback_keyword_recall += recall
writeback_reports.append(
{
"person_id": payload["person_id"],
"success": success,
"keyword_recall": recall,
"hit_count": len(result.hits),
}
)
writeback_total = max(1, len(dataset["person_writebacks"]))
knowledge_reports: List[Dict[str, Any]] = []
knowledge_success_rate = 0.0
knowledge_keyword_recall = 0.0
fetcher = knowledge_module.KnowledgeFetcher(
private_name=dataset["session"]["display_name"],
stream_id=session_id,
)
for case in dataset["knowledge_fetcher_cases"]:
knowledge_text, _ = await fetcher.fetch(case["query"], [])
recall = _keyword_recall(knowledge_text, case["expected_keywords"])
success = recall >= float(case.get("minimum_keyword_recall", 1.0))
knowledge_success_rate += 1.0 if success else 0.0
knowledge_keyword_recall += recall
knowledge_reports.append(
{
"query": case["query"],
"success": success,
"keyword_recall": recall,
"preview": knowledge_text[:300],
}
)
knowledge_total = max(1, len(dataset["knowledge_fetcher_cases"]))
profile_reports: List[Dict[str, Any]] = []
profile_success_rate = 0.0
profile_keyword_recall = 0.0
profile_evidence_rate = 0.0
for case in dataset["profile_cases"]:
profile = await memory_service.get_person_profile(case["person_id"], chat_id=session_id)
recall = _keyword_recall(profile.summary, case["expected_keywords"])
has_evidence = bool(profile.evidence)
success = recall >= float(case.get("minimum_keyword_recall", 1.0)) and has_evidence
profile_success_rate += 1.0 if success else 0.0
profile_keyword_recall += recall
profile_evidence_rate += 1.0 if has_evidence else 0.0
profile_reports.append(
{
"person_id": case["person_id"],
"success": success,
"keyword_recall": recall,
"evidence_count": len(profile.evidence),
"summary_preview": profile.summary[:240],
}
)
profile_total = max(1, len(dataset["profile_cases"]))
episode_generation_auto = await _evaluate_episode_generation(session_id=session_id, episode_cases=dataset["episode_cases"])
episode_admin_query_auto = await _evaluate_episode_admin_query(session_id=session_id, episode_cases=dataset["episode_cases"])
episode_search_mode_auto = await _evaluate_episode_search_mode(session_id=session_id, episode_cases=dataset["episode_cases"])
episode_rebuild = await memory_service.episode_admin(
action="rebuild",
source=f"chat_summary:{session_id}",
)
episode_generation_after_rebuild = await _evaluate_episode_generation(session_id=session_id, episode_cases=dataset["episode_cases"])
episode_admin_query_after_rebuild = await _evaluate_episode_admin_query(session_id=session_id, episode_cases=dataset["episode_cases"])
episode_search_mode_after_rebuild = await _evaluate_episode_search_mode(session_id=session_id, episode_cases=dataset["episode_cases"])
tool_modes = await _evaluate_tool_modes(session_id=session_id, dataset=dataset)
report = {
"dataset": dataset["meta"],
"import": {
"task_id": created["task"]["task_id"],
"status": import_detail["task"]["status"],
"paragraph_count": len(dataset["import_payload"]["paragraphs"]),
},
"metrics": {
"search": {
"accuracy_at_1": round(search_accuracy_at_1 / search_total, 4),
"recall_at_5": round(search_recall_at_5 / search_total, 4),
"precision_at_5": round(search_precision_at_5 / search_total, 4),
"mrr": round(search_mrr / search_total, 4),
"keyword_recall_at_5": round(search_keyword_recall / search_total, 4),
},
"writeback": {
"success_rate": round(writeback_success_rate / writeback_total, 4),
"keyword_recall": round(writeback_keyword_recall / writeback_total, 4),
},
"knowledge_fetcher": {
"success_rate": round(knowledge_success_rate / knowledge_total, 4),
"keyword_recall": round(knowledge_keyword_recall / knowledge_total, 4),
},
"profile": {
"success_rate": round(profile_success_rate / profile_total, 4),
"keyword_recall": round(profile_keyword_recall / profile_total, 4),
"evidence_rate": round(profile_evidence_rate / profile_total, 4),
},
"tool_modes": {
"success_rate": tool_modes["success_rate"],
"keyword_recall": tool_modes["keyword_recall"],
},
"episode_generation_auto": {
"success_rate": episode_generation_auto["success_rate"],
"keyword_recall": episode_generation_auto["keyword_recall"],
"episode_count": episode_generation_auto["episode_count"],
},
"episode_generation_after_rebuild": {
"success_rate": episode_generation_after_rebuild["success_rate"],
"keyword_recall": episode_generation_after_rebuild["keyword_recall"],
"episode_count": episode_generation_after_rebuild["episode_count"],
"rebuild_success": bool(episode_rebuild.get("success", False)),
},
"episode_admin_query_auto": {
"success_rate": episode_admin_query_auto["success_rate"],
"keyword_recall": episode_admin_query_auto["keyword_recall"],
},
"episode_admin_query_after_rebuild": {
"success_rate": episode_admin_query_after_rebuild["success_rate"],
"keyword_recall": episode_admin_query_after_rebuild["keyword_recall"],
"rebuild_success": bool(episode_rebuild.get("success", False)),
},
"episode_search_mode_auto": {
"success_rate": episode_search_mode_auto["success_rate"],
"keyword_recall": episode_search_mode_auto["keyword_recall"],
},
"episode_search_mode_after_rebuild": {
"success_rate": episode_search_mode_after_rebuild["success_rate"],
"keyword_recall": episode_search_mode_after_rebuild["keyword_recall"],
"rebuild_success": bool(episode_rebuild.get("success", False)),
},
},
"cases": {
"search": search_case_reports,
"writeback": writeback_reports,
"knowledge_fetcher": knowledge_reports,
"profile": profile_reports,
"tool_modes": tool_modes["reports"],
"episode_generation_auto": episode_generation_auto["reports"],
"episode_generation_after_rebuild": episode_generation_after_rebuild["reports"],
"episode_admin_query_auto": episode_admin_query_auto["reports"],
"episode_admin_query_after_rebuild": episode_admin_query_after_rebuild["reports"],
"episode_search_mode_auto": episode_search_mode_auto["reports"],
"episode_search_mode_after_rebuild": episode_search_mode_after_rebuild["reports"],
},
}
REPORT_FILE.parent.mkdir(parents=True, exist_ok=True)
REPORT_FILE.write_text(json.dumps(report, ensure_ascii=False, indent=2), encoding="utf-8")
print(json.dumps(report["metrics"], ensure_ascii=False, indent=2))
assert report["import"]["status"] == "completed"
assert report["metrics"]["search"]["accuracy_at_1"] >= 0.35
assert report["metrics"]["search"]["recall_at_5"] >= 0.6
assert report["metrics"]["search"]["keyword_recall_at_5"] >= 0.8
assert report["metrics"]["writeback"]["success_rate"] >= 0.66
assert report["metrics"]["knowledge_fetcher"]["success_rate"] >= 0.66
assert report["metrics"]["knowledge_fetcher"]["keyword_recall"] >= 0.75
assert report["metrics"]["profile"]["success_rate"] >= 0.66
assert report["metrics"]["profile"]["evidence_rate"] >= 1.0
assert report["metrics"]["tool_modes"]["success_rate"] >= 0.75
assert report["metrics"]["episode_generation_after_rebuild"]["rebuild_success"] is True
assert report["metrics"]["episode_generation_after_rebuild"]["episode_count"] >= report["metrics"]["episode_generation_auto"]["episode_count"]

View File

@@ -1,342 +0,0 @@
from __future__ import annotations
import json
import os
from pathlib import Path
from types import SimpleNamespace
from typing import Any, Dict, List
import pytest
import pytest_asyncio
from A_memorix.core.runtime.sdk_memory_kernel import SDKMemoryKernel
from pytests.A_memorix_test.test_long_novel_memory_benchmark import (
_evaluate_episode_admin_query,
_evaluate_episode_generation,
_evaluate_episode_search_mode,
_evaluate_tool_modes,
_KernelBackedRuntimeManager,
_KnownPerson,
_first_relevant_rank,
_hit_blob,
_join_hit_content,
_keyword_hits,
_keyword_recall,
_load_benchmark_fixture,
_wait_for_import_task,
)
from src.chat.brain_chat.PFC import pfc_KnowledgeFetcher as knowledge_module
from src.memory_system import chat_history_summarizer as summarizer_module
from src.person_info import person_info as person_info_module
from src.services import memory_service as memory_service_module
from src.services.memory_service import memory_service
pytestmark = pytest.mark.skipif(
os.getenv("MAIBOT_RUN_LIVE_MEMORY_TESTS") != "1",
reason="需要显式开启真实 external embedding benchmark",
)
REPORT_FILE = Path(__file__).parent / "data" / "benchmarks" / "results" / "long_novel_memory_benchmark_live_report.json"
@pytest_asyncio.fixture
async def benchmark_live_env(monkeypatch, tmp_path):
dataset = _load_benchmark_fixture()
session_cfg = dataset["session"]
session = SimpleNamespace(
session_id=session_cfg["session_id"],
platform=session_cfg["platform"],
user_id=session_cfg["user_id"],
group_id=session_cfg["group_id"],
)
fake_chat_manager = SimpleNamespace(
get_session_by_session_id=lambda session_id: session if session_id == session.session_id else None,
get_session_name=lambda session_id: session_cfg["display_name"] if session_id == session.session_id else session_id,
)
registry = {item["person_name"]: item["person_id"] for item in dataset["person_writebacks"]}
reverse_registry = {value: key for key, value in registry.items()}
monkeypatch.setattr(summarizer_module, "_chat_manager", fake_chat_manager)
monkeypatch.setattr(knowledge_module, "_chat_manager", fake_chat_manager)
monkeypatch.setattr(person_info_module, "_chat_manager", fake_chat_manager)
monkeypatch.setattr(person_info_module, "get_person_id_by_person_name", lambda person_name: registry.get(str(person_name or "").strip(), ""))
monkeypatch.setattr(
person_info_module,
"Person",
lambda person_id: _KnownPerson(person_id=str(person_id or ""), registry=registry, reverse_registry=reverse_registry),
)
data_dir = (tmp_path / "a_memorix_live_benchmark_data").resolve()
kernel = SDKMemoryKernel(
plugin_root=tmp_path / "plugin_root",
config={
"storage": {"data_dir": str(data_dir)},
"advanced": {"enable_auto_save": False},
"memory": {"base_decay_interval_hours": 24},
"person_profile": {"refresh_interval_minutes": 5},
},
)
manager = _KernelBackedRuntimeManager(kernel)
monkeypatch.setattr(memory_service_module, "a_memorix_host_service", manager)
await kernel.initialize()
try:
yield {
"dataset": dataset,
"kernel": kernel,
"session": session,
}
finally:
await kernel.shutdown()
@pytest.mark.asyncio
async def test_long_novel_memory_benchmark_live(benchmark_live_env):
dataset = benchmark_live_env["dataset"]
session_id = benchmark_live_env["session"].session_id
self_check = await memory_service.runtime_admin(action="refresh_self_check")
assert self_check["success"] is True
assert self_check["report"]["ok"] is True
created = await memory_service.import_admin(
action="create_paste",
name="long_novel_memory_benchmark.live.json",
input_mode="json",
llm_enabled=False,
content=json.dumps(dataset["import_payload"], ensure_ascii=False),
)
assert created["success"] is True
import_detail = await _wait_for_import_task(
created["task"]["task_id"],
max_rounds=2400,
sleep_seconds=0.25,
)
assert import_detail["task"]["status"] == "completed"
for record in dataset["chat_history_records"]:
summarizer = summarizer_module.ChatHistorySummarizer(session_id)
await summarizer._import_to_long_term_memory(
record_id=record["record_id"],
theme=record["theme"],
summary=record["summary"],
participants=record["participants"],
start_time=record["start_time"],
end_time=record["end_time"],
original_text=record["original_text"],
)
for payload in dataset["person_writebacks"]:
await person_info_module.store_person_memory_from_answer(
payload["person_name"],
payload["memory_content"],
session_id,
)
await memory_service.episode_admin(action="process_pending", limit=100, max_retry=2)
search_case_reports: List[Dict[str, Any]] = []
search_accuracy_at_1 = 0.0
search_recall_at_5 = 0.0
search_precision_at_5 = 0.0
search_mrr = 0.0
search_keyword_recall = 0.0
for case in dataset["search_cases"]:
result = await memory_service.search(case["query"], mode="search", respect_filter=False, limit=5)
joined = _join_hit_content(result)
rank = _first_relevant_rank(result, case["expected_keywords"], case.get("minimum_keyword_hits", len(case["expected_keywords"])))
relevant_hits = sum(
1
for hit in result.hits[:5]
if _keyword_hits(_hit_blob(hit), case["expected_keywords"]) >= max(1, int(case.get("minimum_keyword_hits", len(case["expected_keywords"]))))
)
keyword_recall = _keyword_recall(joined, case["expected_keywords"])
search_accuracy_at_1 += 1.0 if rank == 1 else 0.0
search_recall_at_5 += 1.0 if rank > 0 else 0.0
search_precision_at_5 += relevant_hits / float(max(1, min(5, len(result.hits))))
search_mrr += 1.0 / float(rank) if rank > 0 else 0.0
search_keyword_recall += keyword_recall
search_case_reports.append(
{
"query": case["query"],
"rank_of_first_relevant": rank,
"relevant_hits_top5": relevant_hits,
"keyword_recall_top5": keyword_recall,
"top_hit": result.hits[0].to_dict() if result.hits else None,
}
)
search_total = max(1, len(dataset["search_cases"]))
writeback_reports: List[Dict[str, Any]] = []
writeback_success_rate = 0.0
writeback_keyword_recall = 0.0
for payload in dataset["person_writebacks"]:
query = " ".join(payload["expected_keywords"])
result = await memory_service.search(
query,
mode="search",
chat_id=session_id,
person_id=payload["person_id"],
respect_filter=False,
limit=5,
)
joined = _join_hit_content(result)
recall = _keyword_recall(joined, payload["expected_keywords"])
success = bool(result.hits) and recall >= 0.67
writeback_success_rate += 1.0 if success else 0.0
writeback_keyword_recall += recall
writeback_reports.append(
{
"person_id": payload["person_id"],
"success": success,
"keyword_recall": recall,
"hit_count": len(result.hits),
}
)
writeback_total = max(1, len(dataset["person_writebacks"]))
knowledge_reports: List[Dict[str, Any]] = []
knowledge_success_rate = 0.0
knowledge_keyword_recall = 0.0
fetcher = knowledge_module.KnowledgeFetcher(
private_name=dataset["session"]["display_name"],
stream_id=session_id,
)
for case in dataset["knowledge_fetcher_cases"]:
knowledge_text, _ = await fetcher.fetch(case["query"], [])
recall = _keyword_recall(knowledge_text, case["expected_keywords"])
success = recall >= float(case.get("minimum_keyword_recall", 1.0))
knowledge_success_rate += 1.0 if success else 0.0
knowledge_keyword_recall += recall
knowledge_reports.append(
{
"query": case["query"],
"success": success,
"keyword_recall": recall,
"preview": knowledge_text[:300],
}
)
knowledge_total = max(1, len(dataset["knowledge_fetcher_cases"]))
profile_reports: List[Dict[str, Any]] = []
profile_success_rate = 0.0
profile_keyword_recall = 0.0
profile_evidence_rate = 0.0
for case in dataset["profile_cases"]:
profile = await memory_service.get_person_profile(case["person_id"], chat_id=session_id)
recall = _keyword_recall(profile.summary, case["expected_keywords"])
has_evidence = bool(profile.evidence)
success = recall >= float(case.get("minimum_keyword_recall", 1.0)) and has_evidence
profile_success_rate += 1.0 if success else 0.0
profile_keyword_recall += recall
profile_evidence_rate += 1.0 if has_evidence else 0.0
profile_reports.append(
{
"person_id": case["person_id"],
"success": success,
"keyword_recall": recall,
"evidence_count": len(profile.evidence),
"summary_preview": profile.summary[:240],
}
)
profile_total = max(1, len(dataset["profile_cases"]))
episode_generation_auto = await _evaluate_episode_generation(session_id=session_id, episode_cases=dataset["episode_cases"])
episode_admin_query_auto = await _evaluate_episode_admin_query(session_id=session_id, episode_cases=dataset["episode_cases"])
episode_search_mode_auto = await _evaluate_episode_search_mode(session_id=session_id, episode_cases=dataset["episode_cases"])
episode_rebuild = await memory_service.episode_admin(
action="rebuild",
source=f"chat_summary:{session_id}",
)
episode_generation_after_rebuild = await _evaluate_episode_generation(session_id=session_id, episode_cases=dataset["episode_cases"])
episode_admin_query_after_rebuild = await _evaluate_episode_admin_query(session_id=session_id, episode_cases=dataset["episode_cases"])
episode_search_mode_after_rebuild = await _evaluate_episode_search_mode(session_id=session_id, episode_cases=dataset["episode_cases"])
tool_modes = await _evaluate_tool_modes(session_id=session_id, dataset=dataset)
report = {
"dataset": dataset["meta"],
"runtime_self_check": self_check["report"],
"import": {
"task_id": created["task"]["task_id"],
"status": import_detail["task"]["status"],
"paragraph_count": len(dataset["import_payload"]["paragraphs"]),
},
"metrics": {
"search": {
"accuracy_at_1": round(search_accuracy_at_1 / search_total, 4),
"recall_at_5": round(search_recall_at_5 / search_total, 4),
"precision_at_5": round(search_precision_at_5 / search_total, 4),
"mrr": round(search_mrr / search_total, 4),
"keyword_recall_at_5": round(search_keyword_recall / search_total, 4),
},
"writeback": {
"success_rate": round(writeback_success_rate / writeback_total, 4),
"keyword_recall": round(writeback_keyword_recall / writeback_total, 4),
},
"knowledge_fetcher": {
"success_rate": round(knowledge_success_rate / knowledge_total, 4),
"keyword_recall": round(knowledge_keyword_recall / knowledge_total, 4),
},
"profile": {
"success_rate": round(profile_success_rate / profile_total, 4),
"keyword_recall": round(profile_keyword_recall / profile_total, 4),
"evidence_rate": round(profile_evidence_rate / profile_total, 4),
},
"tool_modes": {
"success_rate": tool_modes["success_rate"],
"keyword_recall": tool_modes["keyword_recall"],
},
"episode_generation_auto": {
"success_rate": episode_generation_auto["success_rate"],
"keyword_recall": episode_generation_auto["keyword_recall"],
"episode_count": episode_generation_auto["episode_count"],
},
"episode_generation_after_rebuild": {
"success_rate": episode_generation_after_rebuild["success_rate"],
"keyword_recall": episode_generation_after_rebuild["keyword_recall"],
"episode_count": episode_generation_after_rebuild["episode_count"],
"rebuild_success": bool(episode_rebuild.get("success", False)),
},
"episode_admin_query_auto": {
"success_rate": episode_admin_query_auto["success_rate"],
"keyword_recall": episode_admin_query_auto["keyword_recall"],
},
"episode_admin_query_after_rebuild": {
"success_rate": episode_admin_query_after_rebuild["success_rate"],
"keyword_recall": episode_admin_query_after_rebuild["keyword_recall"],
"rebuild_success": bool(episode_rebuild.get("success", False)),
},
"episode_search_mode_auto": {
"success_rate": episode_search_mode_auto["success_rate"],
"keyword_recall": episode_search_mode_auto["keyword_recall"],
},
"episode_search_mode_after_rebuild": {
"success_rate": episode_search_mode_after_rebuild["success_rate"],
"keyword_recall": episode_search_mode_after_rebuild["keyword_recall"],
"rebuild_success": bool(episode_rebuild.get("success", False)),
},
},
"cases": {
"search": search_case_reports,
"writeback": writeback_reports,
"knowledge_fetcher": knowledge_reports,
"profile": profile_reports,
"tool_modes": tool_modes["reports"],
"episode_generation_auto": episode_generation_auto["reports"],
"episode_generation_after_rebuild": episode_generation_after_rebuild["reports"],
"episode_admin_query_auto": episode_admin_query_auto["reports"],
"episode_admin_query_after_rebuild": episode_admin_query_after_rebuild["reports"],
"episode_search_mode_auto": episode_search_mode_auto["reports"],
"episode_search_mode_after_rebuild": episode_search_mode_after_rebuild["reports"],
},
}
REPORT_FILE.parent.mkdir(parents=True, exist_ok=True)
REPORT_FILE.write_text(json.dumps(report, ensure_ascii=False, indent=2), encoding="utf-8")
print(json.dumps(report["metrics"], ensure_ascii=False, indent=2))
assert report["import"]["status"] == "completed"
assert report["runtime_self_check"]["ok"] is True

View File

@@ -5,68 +5,6 @@ import pytest
from src.services import memory_flow_service as memory_flow_module
@pytest.mark.asyncio
async def test_long_term_memory_session_manager_reuses_single_summarizer(monkeypatch):
starts: list[str] = []
summarizers: list[object] = []
class FakeSummarizer:
def __init__(self, session_id: str):
self.session_id = session_id
summarizers.append(self)
async def start(self):
starts.append(self.session_id)
async def stop(self):
starts.append(f"stop:{self.session_id}")
monkeypatch.setattr(
memory_flow_module,
"global_config",
SimpleNamespace(memory=SimpleNamespace(long_term_auto_summary_enabled=True)),
)
monkeypatch.setattr(memory_flow_module, "ChatHistorySummarizer", FakeSummarizer)
manager = memory_flow_module.LongTermMemorySessionManager()
message = SimpleNamespace(session_id="session-1")
await manager.on_message(message)
await manager.on_message(message)
assert len(summarizers) == 1
assert starts == ["session-1"]
@pytest.mark.asyncio
async def test_long_term_memory_session_manager_shutdown_stops_all(monkeypatch):
stopped: list[str] = []
class FakeSummarizer:
def __init__(self, session_id: str):
self.session_id = session_id
async def start(self):
return None
async def stop(self):
stopped.append(self.session_id)
monkeypatch.setattr(
memory_flow_module,
"global_config",
SimpleNamespace(memory=SimpleNamespace(long_term_auto_summary_enabled=True)),
)
monkeypatch.setattr(memory_flow_module, "ChatHistorySummarizer", FakeSummarizer)
manager = memory_flow_module.LongTermMemorySessionManager()
await manager.on_message(SimpleNamespace(session_id="session-a"))
await manager.on_message(SimpleNamespace(session_id="session-b"))
await manager.shutdown()
assert stopped == ["session-a", "session-b"]
def test_person_fact_parse_fact_list_deduplicates_and_filters_short_items():
raw = '["他喜欢猫", "他喜欢猫", "", "", "他会弹吉他"]'
@@ -101,16 +39,9 @@ def test_person_fact_resolve_target_person_for_private_chat(monkeypatch):
@pytest.mark.asyncio
async def test_memory_automation_service_auto_starts_and_delegates(monkeypatch):
async def test_memory_automation_service_auto_starts_and_delegates():
events: list[tuple[str, str]] = []
class FakeSessionManager:
async def on_message(self, message):
events.append(("incoming", message.session_id))
async def shutdown(self):
events.append(("shutdown", "session"))
class FakeFactWriteback:
async def start(self):
events.append(("start", "fact"))
@@ -122,17 +53,13 @@ async def test_memory_automation_service_auto_starts_and_delegates(monkeypatch):
events.append(("shutdown", "fact"))
service = memory_flow_module.MemoryAutomationService()
service.session_manager = FakeSessionManager()
service.fact_writeback = FakeFactWriteback()
await service.on_incoming_message(SimpleNamespace(session_id="session-1"))
await service.on_message_sent(SimpleNamespace(session_id="session-1"))
await service.shutdown()
assert events == [
("start", "fact"),
("incoming", "session-1"),
("sent", "session-1"),
("shutdown", "session"),
("shutdown", "fact"),
]

View File

@@ -1,324 +0,0 @@
from __future__ import annotations
import asyncio
import inspect
import json
from pathlib import Path
from types import SimpleNamespace
from typing import Any, Dict
import numpy as np
import pytest
import pytest_asyncio
from A_memorix.core.runtime import sdk_memory_kernel as kernel_module
from A_memorix.core.runtime.sdk_memory_kernel import KernelSearchRequest, SDKMemoryKernel
from src.chat.brain_chat.PFC import pfc_KnowledgeFetcher as knowledge_module
from src.memory_system import chat_history_summarizer as summarizer_module
from src.person_info import person_info as person_info_module
from src.services import memory_service as memory_service_module
from src.services.memory_service import memory_service
DATA_FILE = Path(__file__).parent / "data" / "real_dialogues" / "private_alice_weekend.json"
def _load_dialogue_fixture() -> Dict[str, Any]:
return json.loads(DATA_FILE.read_text(encoding="utf-8"))
class _FakeEmbeddingAdapter:
def __init__(self, dimension: int = 16) -> None:
self.dimension = dimension
async def _detect_dimension(self) -> int:
return self.dimension
async def encode(self, texts, dimensions=None):
dim = int(dimensions or self.dimension)
if isinstance(texts, str):
sequence = [texts]
single = True
else:
sequence = list(texts)
single = False
rows = []
for text in sequence:
vec = np.zeros(dim, dtype=np.float32)
for ch in str(text or ""):
vec[ord(ch) % dim] += 1.0
if not vec.any():
vec[0] = 1.0
norm = np.linalg.norm(vec)
if norm > 0:
vec = vec / norm
rows.append(vec)
payload = np.vstack(rows)
return payload[0] if single else payload
class _KernelBackedRuntimeManager:
def __init__(self, kernel: SDKMemoryKernel) -> None:
self.kernel = kernel
async def invoke(self, component_name: str, args: Dict[str, Any] | None, *, timeout_ms: int = 30000):
del timeout_ms
payload = args or {}
if component_name == "search_memory":
return await self.kernel.search_memory(
KernelSearchRequest(
query=str(payload.get("query", "") or ""),
limit=int(payload.get("limit", 5) or 5),
mode=str(payload.get("mode", "hybrid") or "hybrid"),
chat_id=str(payload.get("chat_id", "") or ""),
person_id=str(payload.get("person_id", "") or ""),
time_start=payload.get("time_start"),
time_end=payload.get("time_end"),
respect_filter=bool(payload.get("respect_filter", True)),
user_id=str(payload.get("user_id", "") or ""),
group_id=str(payload.get("group_id", "") or ""),
)
)
handler = getattr(self.kernel, component_name)
result = handler(**payload)
return await result if inspect.isawaitable(result) else result
async def _wait_for_import_task(task_id: str, *, max_rounds: int = 100) -> Dict[str, Any]:
for _ in range(max_rounds):
detail = await memory_service.import_admin(action="get", task_id=task_id, include_chunks=True)
task = detail.get("task") or {}
status = str(task.get("status", "") or "")
if status in {"completed", "completed_with_errors", "failed", "cancelled"}:
return detail
await asyncio.sleep(0.05)
raise AssertionError(f"导入任务在等待窗口内未结束: {task_id}")
def _join_hit_content(search_result) -> str:
return "\n".join(hit.content for hit in search_result.hits)
@pytest_asyncio.fixture
async def real_dialogue_env(monkeypatch, tmp_path):
scenario = _load_dialogue_fixture()
session_cfg = scenario["session"]
session = SimpleNamespace(
session_id=session_cfg["session_id"],
platform=session_cfg["platform"],
user_id=session_cfg["user_id"],
group_id=session_cfg["group_id"],
)
fake_chat_manager = SimpleNamespace(
get_session_by_session_id=lambda session_id: session if session_id == session.session_id else None,
get_session_name=lambda session_id: session_cfg["display_name"] if session_id == session.session_id else session_id,
)
monkeypatch.setattr(kernel_module, "create_embedding_api_adapter", lambda **kwargs: _FakeEmbeddingAdapter())
async def fake_self_check(**kwargs):
return {"ok": True, "message": "ok"}
monkeypatch.setattr(kernel_module, "run_embedding_runtime_self_check", fake_self_check)
monkeypatch.setattr(summarizer_module, "_chat_manager", fake_chat_manager)
monkeypatch.setattr(knowledge_module, "_chat_manager", fake_chat_manager)
monkeypatch.setattr(person_info_module, "_chat_manager", fake_chat_manager)
data_dir = (tmp_path / "a_memorix_data").resolve()
kernel = SDKMemoryKernel(
plugin_root=tmp_path / "plugin_root",
config={
"storage": {"data_dir": str(data_dir)},
"advanced": {"enable_auto_save": False},
"memory": {"base_decay_interval_hours": 24},
"person_profile": {"refresh_interval_minutes": 5},
},
)
manager = _KernelBackedRuntimeManager(kernel)
monkeypatch.setattr(memory_service_module, "a_memorix_host_service", manager)
await kernel.initialize()
try:
yield {
"scenario": scenario,
"kernel": kernel,
"session": session,
}
finally:
await kernel.shutdown()
@pytest.mark.asyncio
async def test_real_dialogue_import_flow_makes_fixture_searchable(real_dialogue_env):
scenario = real_dialogue_env["scenario"]
created = await memory_service.import_admin(
action="create_paste",
name="private_alice.json",
input_mode="json",
llm_enabled=False,
content=json.dumps(scenario["import_payload"], ensure_ascii=False),
)
assert created["success"] is True
detail = await _wait_for_import_task(created["task"]["task_id"])
assert detail["task"]["status"] == "completed"
search = await memory_service.search(
scenario["search_queries"]["direct"],
mode="search",
respect_filter=False,
)
assert search.hits
joined = _join_hit_content(search)
for keyword in scenario["expectations"]["search_keywords"]:
assert keyword in joined
@pytest.mark.asyncio
async def test_real_dialogue_summarizer_flow_persists_summary_to_long_term_memory(real_dialogue_env):
scenario = real_dialogue_env["scenario"]
record = scenario["chat_history_record"]
summarizer = summarizer_module.ChatHistorySummarizer(real_dialogue_env["session"].session_id)
await summarizer._import_to_long_term_memory(
record_id=record["record_id"],
theme=record["theme"],
summary=record["summary"],
participants=record["participants"],
start_time=record["start_time"],
end_time=record["end_time"],
original_text=record["original_text"],
)
search = await memory_service.search(
scenario["search_queries"]["direct"],
mode="search",
chat_id=real_dialogue_env["session"].session_id,
)
assert search.hits
joined = _join_hit_content(search)
for keyword in scenario["expectations"]["search_keywords"]:
assert keyword in joined
@pytest.mark.asyncio
async def test_real_dialogue_person_fact_writeback_is_searchable(real_dialogue_env, monkeypatch):
scenario = real_dialogue_env["scenario"]
class _KnownPerson:
def __init__(self, person_id: str) -> None:
self.person_id = person_id
self.is_known = True
self.person_name = scenario["person"]["person_name"]
monkeypatch.setattr(
person_info_module,
"get_person_id_by_person_name",
lambda person_name: scenario["person"]["person_id"],
)
monkeypatch.setattr(person_info_module, "Person", _KnownPerson)
await person_info_module.store_person_memory_from_answer(
scenario["person"]["person_name"],
scenario["person_fact"]["memory_content"],
real_dialogue_env["session"].session_id,
)
search = await memory_service.search(
scenario["search_queries"]["direct"],
mode="search",
chat_id=real_dialogue_env["session"].session_id,
person_id=scenario["person"]["person_id"],
)
assert search.hits
joined = _join_hit_content(search)
for keyword in scenario["expectations"]["search_keywords"]:
assert keyword in joined
@pytest.mark.asyncio
async def test_real_dialogue_private_knowledge_fetcher_reads_long_term_memory(real_dialogue_env):
scenario = real_dialogue_env["scenario"]
await memory_service.ingest_text(
external_id="fixture:knowledge_fetcher",
source_type="dialogue_note",
text=scenario["person_fact"]["memory_content"],
chat_id=real_dialogue_env["session"].session_id,
person_ids=[scenario["person"]["person_id"]],
participants=[scenario["person"]["person_name"]],
respect_filter=False,
)
fetcher = knowledge_module.KnowledgeFetcher(
private_name=scenario["session"]["display_name"],
stream_id=real_dialogue_env["session"].session_id,
)
knowledge_text, _ = await fetcher.fetch(scenario["search_queries"]["knowledge_fetcher"], [])
for keyword in scenario["expectations"]["search_keywords"]:
assert keyword in knowledge_text
@pytest.mark.asyncio
async def test_real_dialogue_person_profile_contains_stable_traits(real_dialogue_env, monkeypatch):
scenario = real_dialogue_env["scenario"]
class _KnownPerson:
def __init__(self, person_id: str) -> None:
self.person_id = person_id
self.is_known = True
self.person_name = scenario["person"]["person_name"]
monkeypatch.setattr(
person_info_module,
"get_person_id_by_person_name",
lambda person_name: scenario["person"]["person_id"],
)
monkeypatch.setattr(person_info_module, "Person", _KnownPerson)
await person_info_module.store_person_memory_from_answer(
scenario["person"]["person_name"],
scenario["person_fact"]["memory_content"],
real_dialogue_env["session"].session_id,
)
profile = await memory_service.get_person_profile(
scenario["person"]["person_id"],
chat_id=real_dialogue_env["session"].session_id,
)
assert profile.evidence
assert any(keyword in profile.summary for keyword in scenario["expectations"]["profile_keywords"])
@pytest.mark.asyncio
async def test_real_dialogue_summary_flow_generates_queryable_episode(real_dialogue_env):
scenario = real_dialogue_env["scenario"]
record = scenario["chat_history_record"]
summarizer = summarizer_module.ChatHistorySummarizer(real_dialogue_env["session"].session_id)
await summarizer._import_to_long_term_memory(
record_id=record["record_id"],
theme=record["theme"],
summary=record["summary"],
participants=record["participants"],
start_time=record["start_time"],
end_time=record["end_time"],
original_text=record["original_text"],
)
episodes = await memory_service.episode_admin(
action="query",
source=scenario["expectations"]["episode_source"],
limit=5,
)
assert episodes["success"] is True
assert int(episodes["count"]) >= 1

View File

@@ -1,301 +0,0 @@
from __future__ import annotations
import asyncio
import inspect
import json
import os
from pathlib import Path
from types import SimpleNamespace
from typing import Any, Dict
import pytest
import pytest_asyncio
from A_memorix.core.runtime.sdk_memory_kernel import KernelSearchRequest, SDKMemoryKernel
from src.chat.brain_chat.PFC import pfc_KnowledgeFetcher as knowledge_module
from src.memory_system import chat_history_summarizer as summarizer_module
from src.person_info import person_info as person_info_module
from src.services import memory_service as memory_service_module
from src.services.memory_service import memory_service
pytestmark = pytest.mark.skipif(
os.getenv("MAIBOT_RUN_LIVE_MEMORY_TESTS") != "1",
reason="需要显式开启真实 embedding / self-check 集成测试",
)
DATA_FILE = Path(__file__).parent / "data" / "real_dialogues" / "private_alice_weekend.json"
def _load_dialogue_fixture() -> Dict[str, Any]:
return json.loads(DATA_FILE.read_text(encoding="utf-8"))
class _KernelBackedRuntimeManager:
def __init__(self, kernel: SDKMemoryKernel) -> None:
self.kernel = kernel
async def invoke(self, component_name: str, args: Dict[str, Any] | None, *, timeout_ms: int = 30000):
del timeout_ms
payload = args or {}
if component_name == "search_memory":
return await self.kernel.search_memory(
KernelSearchRequest(
query=str(payload.get("query", "") or ""),
limit=int(payload.get("limit", 5) or 5),
mode=str(payload.get("mode", "hybrid") or "hybrid"),
chat_id=str(payload.get("chat_id", "") or ""),
person_id=str(payload.get("person_id", "") or ""),
time_start=payload.get("time_start"),
time_end=payload.get("time_end"),
respect_filter=bool(payload.get("respect_filter", True)),
user_id=str(payload.get("user_id", "") or ""),
group_id=str(payload.get("group_id", "") or ""),
)
)
handler = getattr(self.kernel, component_name)
result = handler(**payload)
return await result if inspect.isawaitable(result) else result
async def _wait_for_import_task(task_id: str, *, timeout_seconds: float = 60.0) -> Dict[str, Any]:
deadline = asyncio.get_running_loop().time() + max(1.0, float(timeout_seconds))
while asyncio.get_running_loop().time() < deadline:
detail = await memory_service.import_admin(action="get", task_id=task_id, include_chunks=True)
task = detail.get("task") or {}
status = str(task.get("status", "") or "")
if status in {"completed", "completed_with_errors", "failed", "cancelled"}:
return detail
await asyncio.sleep(0.2)
raise AssertionError(f"导入任务在等待窗口内未结束: {task_id}")
def _join_hit_content(search_result) -> str:
return "\n".join(hit.content for hit in search_result.hits)
@pytest_asyncio.fixture
async def live_dialogue_env(monkeypatch, tmp_path):
scenario = _load_dialogue_fixture()
session_cfg = scenario["session"]
session = SimpleNamespace(
session_id=session_cfg["session_id"],
platform=session_cfg["platform"],
user_id=session_cfg["user_id"],
group_id=session_cfg["group_id"],
)
fake_chat_manager = SimpleNamespace(
get_session_by_session_id=lambda session_id: session if session_id == session.session_id else None,
get_session_name=lambda session_id: session_cfg["display_name"] if session_id == session.session_id else session_id,
)
monkeypatch.setattr(summarizer_module, "_chat_manager", fake_chat_manager)
monkeypatch.setattr(knowledge_module, "_chat_manager", fake_chat_manager)
monkeypatch.setattr(person_info_module, "_chat_manager", fake_chat_manager)
data_dir = (tmp_path / "a_memorix_data").resolve()
kernel = SDKMemoryKernel(
plugin_root=tmp_path / "plugin_root",
config={
"storage": {"data_dir": str(data_dir)},
"advanced": {"enable_auto_save": False},
"memory": {"base_decay_interval_hours": 24},
"person_profile": {"refresh_interval_minutes": 5},
},
)
manager = _KernelBackedRuntimeManager(kernel)
monkeypatch.setattr(memory_service_module, "a_memorix_host_service", manager)
await kernel.initialize()
try:
yield {
"scenario": scenario,
"kernel": kernel,
"session": session,
}
finally:
await kernel.shutdown()
@pytest.mark.asyncio
async def test_live_runtime_self_check_passes(live_dialogue_env):
report = await memory_service.runtime_admin(action="refresh_self_check")
assert report["success"] is True
assert report["report"]["ok"] is True
assert report["report"]["encoded_dimension"] > 0
@pytest.mark.asyncio
async def test_live_import_flow_makes_fixture_searchable(live_dialogue_env):
scenario = live_dialogue_env["scenario"]
created = await memory_service.import_admin(
action="create_paste",
name="private_alice.json",
input_mode="json",
llm_enabled=False,
content=json.dumps(scenario["import_payload"], ensure_ascii=False),
)
assert created["success"] is True
detail = await _wait_for_import_task(created["task"]["task_id"])
assert detail["task"]["status"] == "completed"
search = await memory_service.search(
scenario["search_queries"]["direct"],
mode="search",
respect_filter=False,
)
assert search.hits
joined = _join_hit_content(search)
for keyword in scenario["expectations"]["search_keywords"]:
assert keyword in joined
@pytest.mark.asyncio
async def test_live_summarizer_flow_persists_summary_to_long_term_memory(live_dialogue_env):
scenario = live_dialogue_env["scenario"]
record = scenario["chat_history_record"]
summarizer = summarizer_module.ChatHistorySummarizer(live_dialogue_env["session"].session_id)
await summarizer._import_to_long_term_memory(
record_id=record["record_id"],
theme=record["theme"],
summary=record["summary"],
participants=record["participants"],
start_time=record["start_time"],
end_time=record["end_time"],
original_text=record["original_text"],
)
search = await memory_service.search(
scenario["search_queries"]["direct"],
mode="search",
chat_id=live_dialogue_env["session"].session_id,
)
assert search.hits
joined = _join_hit_content(search)
for keyword in scenario["expectations"]["search_keywords"]:
assert keyword in joined
@pytest.mark.asyncio
async def test_live_person_fact_writeback_is_searchable(live_dialogue_env, monkeypatch):
scenario = live_dialogue_env["scenario"]
class _KnownPerson:
def __init__(self, person_id: str) -> None:
self.person_id = person_id
self.is_known = True
self.person_name = scenario["person"]["person_name"]
monkeypatch.setattr(
person_info_module,
"get_person_id_by_person_name",
lambda person_name: scenario["person"]["person_id"],
)
monkeypatch.setattr(person_info_module, "Person", _KnownPerson)
await person_info_module.store_person_memory_from_answer(
scenario["person"]["person_name"],
scenario["person_fact"]["memory_content"],
live_dialogue_env["session"].session_id,
)
search = await memory_service.search(
scenario["search_queries"]["direct"],
mode="search",
chat_id=live_dialogue_env["session"].session_id,
person_id=scenario["person"]["person_id"],
)
assert search.hits
joined = _join_hit_content(search)
for keyword in scenario["expectations"]["search_keywords"]:
assert keyword in joined
@pytest.mark.asyncio
async def test_live_private_knowledge_fetcher_reads_long_term_memory(live_dialogue_env):
scenario = live_dialogue_env["scenario"]
await memory_service.ingest_text(
external_id="fixture:knowledge_fetcher",
source_type="dialogue_note",
text=scenario["person_fact"]["memory_content"],
chat_id=live_dialogue_env["session"].session_id,
person_ids=[scenario["person"]["person_id"]],
participants=[scenario["person"]["person_name"]],
respect_filter=False,
)
fetcher = knowledge_module.KnowledgeFetcher(
private_name=scenario["session"]["display_name"],
stream_id=live_dialogue_env["session"].session_id,
)
knowledge_text, _ = await fetcher.fetch(scenario["search_queries"]["knowledge_fetcher"], [])
for keyword in scenario["expectations"]["search_keywords"]:
assert keyword in knowledge_text
@pytest.mark.asyncio
async def test_live_person_profile_contains_stable_traits(live_dialogue_env, monkeypatch):
scenario = live_dialogue_env["scenario"]
class _KnownPerson:
def __init__(self, person_id: str) -> None:
self.person_id = person_id
self.is_known = True
self.person_name = scenario["person"]["person_name"]
monkeypatch.setattr(
person_info_module,
"get_person_id_by_person_name",
lambda person_name: scenario["person"]["person_id"],
)
monkeypatch.setattr(person_info_module, "Person", _KnownPerson)
await person_info_module.store_person_memory_from_answer(
scenario["person"]["person_name"],
scenario["person_fact"]["memory_content"],
live_dialogue_env["session"].session_id,
)
profile = await memory_service.get_person_profile(
scenario["person"]["person_id"],
chat_id=live_dialogue_env["session"].session_id,
)
assert profile.evidence
assert any(keyword in profile.summary for keyword in scenario["expectations"]["profile_keywords"])
@pytest.mark.asyncio
async def test_live_summary_flow_generates_queryable_episode(live_dialogue_env):
scenario = live_dialogue_env["scenario"]
record = scenario["chat_history_record"]
summarizer = summarizer_module.ChatHistorySummarizer(live_dialogue_env["session"].session_id)
await summarizer._import_to_long_term_memory(
record_id=record["record_id"],
theme=record["theme"],
summary=record["summary"],
participants=record["participants"],
start_time=record["start_time"],
end_time=record["end_time"],
original_text=record["original_text"],
)
episodes = await memory_service.episode_admin(
action="query",
source=scenario["expectations"]["episode_source"],
limit=5,
)
assert episodes["success"] is True
assert int(episodes["count"]) >= 1

View File

@@ -1,23 +0,0 @@
from src.maisaka.chat_loop_service import MaisakaChatLoopService
def test_build_tool_names_log_text_supports_openai_function_schema() -> None:
tool_definitions = [
{
"type": "function",
"function": {
"name": "mute_user",
"description": "禁言指定用户",
"parameters": {
"type": "object",
"properties": {},
},
},
},
{
"name": "reply",
"description": "发送回复",
},
]
assert MaisakaChatLoopService._build_tool_names_log_text(tool_definitions) == "mute_user、reply"

View File

@@ -1,339 +0,0 @@
"""MutePlugin SDK 回归测试。"""
from __future__ import annotations
from pathlib import Path
from types import SimpleNamespace
from typing import Any, Dict, List
import pytest
from maibot_sdk.context import PluginContext
from maibot_sdk.plugin import MaiBotPlugin
from plugins.MutePlugin.plugin import create_plugin
from src.core.tooling import ToolExecutionContext, ToolInvocation
from src.plugin_runtime.component_query import ComponentQueryService
from src.plugin_runtime.runner.manifest_validator import ManifestValidator
def _build_plugin() -> MaiBotPlugin:
"""构造已注入默认配置的插件实例。"""
plugin = create_plugin()
plugin.set_plugin_config(plugin.get_default_config())
return plugin
def test_mute_plugin_manifest_is_valid_v2() -> None:
"""MutePlugin 的 manifest 应符合当前运行时要求。"""
validator = ManifestValidator(host_version="1.0.0", sdk_version="2.3.0")
manifest = validator.load_from_plugin_path(Path("plugins/MutePlugin"))
assert manifest is not None
assert manifest.id == "sengokucola.mute-plugin"
assert manifest.manifest_version == 2
def test_create_plugin_returns_sdk_plugin() -> None:
"""插件入口应返回 SDK 插件实例。"""
plugin = create_plugin()
assert isinstance(plugin, MaiBotPlugin)
@pytest.mark.asyncio
async def test_mute_command_calls_napcat_group_ban_api() -> None:
"""手动禁言命令应通过 NapCat Adapter 新 API 执行。"""
plugin = _build_plugin()
plugin.set_plugin_config(
{
**plugin.get_default_config(),
"components": {
"enable_smart_mute": True,
"enable_mute_command": True,
},
}
)
capability_calls: List[Dict[str, Any]] = []
async def fake_rpc_call(method: str, plugin_id: str = "", payload: Dict[str, Any] | None = None) -> Dict[str, Any]:
assert method == "cap.call"
assert payload is not None
capability_calls.append(payload)
capability = payload["capability"]
if capability == "person.get_id_by_name":
return {"success": True, "person_id": "person-1"}
if capability == "person.get_value":
return {"success": True, "value": "123456"}
if capability == "api.call" and payload["args"]["api_name"] == "adapter.napcat.group.get_group_member_info":
return {"success": True, "result": {"role": "member"}}
if capability == "api.call":
return {"success": True, "result": {"status": "ok", "retcode": 0}}
if capability == "send.text":
return {"success": True}
raise AssertionError(f"unexpected capability: {capability}")
plugin._set_context(PluginContext(plugin_id="mute", rpc_call=fake_rpc_call))
success, message, intercept = await plugin.handle_mute_command(
stream_id="group-10001",
group_id="10001",
user_id="42",
matched_groups={
"target": "张三",
"duration": "120",
"reason": "刷屏",
},
)
assert success is True
assert message == "成功禁言 张三"
assert intercept is True
api_call = next(
call
for call in capability_calls
if call["capability"] == "api.call"
and call["args"]["api_name"] == "adapter.napcat.group.set_group_ban"
)
assert api_call["args"]["version"] == "1"
assert api_call["args"]["args"] == {
"group_id": "10001",
"user_id": "123456",
"duration": 120,
}
@pytest.mark.asyncio
async def test_mute_tool_requires_target_person_name() -> None:
"""禁言工具在缺少目标时应直接失败并提示。"""
plugin = _build_plugin()
capability_calls: List[Dict[str, Any]] = []
async def fake_rpc_call(method: str, plugin_id: str = "", payload: Dict[str, Any] | None = None) -> Dict[str, Any]:
assert method == "cap.call"
assert payload is not None
capability_calls.append(payload)
return {"success": True}
plugin._set_context(PluginContext(plugin_id="mute", rpc_call=fake_rpc_call))
success, message = await plugin.handle_mute_tool(
stream_id="group-10001",
group_id="10001",
target="",
duration="60",
reason="测试",
)
assert success is False
assert message == "禁言目标不能为空"
assert capability_calls[-1]["capability"] == "send.text"
assert capability_calls[-1]["args"]["text"] == "没有指定禁言对象哦"
@pytest.mark.asyncio
async def test_mute_tool_can_unwrap_nested_person_user_id_response() -> None:
"""禁言工具应能兼容解包多层 capability 返回结果。"""
plugin = _build_plugin()
capability_calls: List[Dict[str, Any]] = []
async def fake_rpc_call(method: str, plugin_id: str = "", payload: Dict[str, Any] | None = None) -> Dict[str, Any]:
assert method == "cap.call"
assert payload is not None
capability_calls.append(payload)
capability = payload["capability"]
if capability == "person.get_id_by_name":
return {"success": True, "result": {"success": True, "person_id": "person-1"}}
if capability == "person.get_value":
return {"success": True, "result": {"success": True, "value": "123456"}}
if capability == "api.call" and payload["args"]["api_name"] == "adapter.napcat.group.get_group_member_info":
return {"success": True, "result": {"role": "member"}}
if capability == "api.call":
return {"success": True, "result": {"status": "ok"}}
if capability == "send.text":
return {"success": True}
raise AssertionError(f"unexpected capability: {capability}")
plugin._set_context(PluginContext(plugin_id="mute", rpc_call=fake_rpc_call))
success, message = await plugin.handle_mute_tool(
stream_id="group-10001",
group_id="10001",
target="张三",
duration=60,
reason="测试",
)
assert success is True
assert message == "成功禁言 张三"
api_call = next(
call
for call in capability_calls
if call["capability"] == "api.call"
and call["args"]["api_name"] == "adapter.napcat.group.set_group_ban"
)
assert api_call["args"]["args"]["user_id"] == "123456"
@pytest.mark.asyncio
async def test_mute_tool_rejects_owner_before_group_ban_call() -> None:
"""禁言工具应在检测到群主时提前返回明确提示。"""
plugin = _build_plugin()
capability_calls: List[Dict[str, Any]] = []
async def fake_rpc_call(method: str, plugin_id: str = "", payload: Dict[str, Any] | None = None) -> Dict[str, Any]:
assert method == "cap.call"
assert payload is not None
capability_calls.append(payload)
capability = payload["capability"]
if capability == "person.get_id_by_name":
return {"success": True, "person_id": "person-1"}
if capability == "person.get_value":
return {"success": True, "value": "123456"}
if capability == "api.call" and payload["args"]["api_name"] == "adapter.napcat.group.get_group_member_info":
return {"success": True, "result": {"role": "owner"}}
if capability == "send.text":
return {"success": True}
raise AssertionError(f"unexpected capability: {capability}")
plugin._set_context(PluginContext(plugin_id="mute", rpc_call=fake_rpc_call))
success, message = await plugin.handle_mute_tool(
stream_id="group-10001",
group_id="10001",
target="张三",
duration=60,
reason="测试",
)
assert success is False
assert message == "张三 是群主,不能被禁言"
assert not any(
call["capability"] == "api.call" and call["args"]["api_name"] == "adapter.napcat.group.set_group_ban"
for call in capability_calls
)
@pytest.mark.asyncio
async def test_mute_tool_maps_cannot_ban_owner_error_message() -> None:
"""NapCat 返回 cannot ban owner 时应转成明确中文提示。"""
plugin = _build_plugin()
capability_calls: List[Dict[str, Any]] = []
async def fake_rpc_call(method: str, plugin_id: str = "", payload: Dict[str, Any] | None = None) -> Dict[str, Any]:
assert method == "cap.call"
assert payload is not None
capability_calls.append(payload)
capability = payload["capability"]
if capability == "person.get_id_by_name":
return {"success": True, "person_id": "person-1"}
if capability == "person.get_value":
return {"success": True, "value": "123456"}
if capability == "api.call" and payload["args"]["api_name"] == "adapter.napcat.group.get_group_member_info":
return {"success": True, "result": {"role": "member"}}
if capability == "api.call" and payload["args"]["api_name"] == "adapter.napcat.group.set_group_ban":
return {"success": False, "error": "NapCat 动作返回失败: action=set_group_ban message=cannot ban owner"}
if capability == "send.text":
return {"success": True}
raise AssertionError(f"unexpected capability: {capability}")
plugin._set_context(PluginContext(plugin_id="mute", rpc_call=fake_rpc_call))
success, message = await plugin.handle_mute_tool(
stream_id="group-10001",
group_id="10001",
target="张三",
duration=60,
reason="测试",
)
assert success is False
assert message == "张三 是群主,不能被禁言"
@pytest.mark.asyncio
async def test_mute_tool_accepts_nested_ok_api_result() -> None:
"""嵌套的 success/result/status=ok 返回值也应判定为成功。"""
plugin = _build_plugin()
async def fake_rpc_call(method: str, plugin_id: str = "", payload: Dict[str, Any] | None = None) -> Dict[str, Any]:
assert method == "cap.call"
assert payload is not None
capability = payload["capability"]
if capability == "person.get_id_by_name":
return {"success": True, "person_id": "person-1"}
if capability == "person.get_value":
return {"success": True, "value": "123456"}
if capability == "api.call" and payload["args"]["api_name"] == "adapter.napcat.group.get_group_member_info":
return {"success": True, "result": {"role": "member"}}
if capability == "api.call" and payload["args"]["api_name"] == "adapter.napcat.group.set_group_ban":
return {
"success": True,
"result": {
"status": "ok",
"retcode": 0,
"data": None,
"message": "",
"wording": "",
},
}
if capability == "send.text":
return {"success": True}
raise AssertionError(f"unexpected capability: {capability}")
plugin._set_context(PluginContext(plugin_id="mute", rpc_call=fake_rpc_call))
success, message = await plugin.handle_mute_tool(
stream_id="group-10001",
group_id="10001",
target="张三",
duration=60,
reason="测试",
)
assert success is True
assert message == "成功禁言 张三"
def test_tool_invocation_payload_injects_group_and_user_context() -> None:
"""插件工具执行时应自动补齐群聊上下文字段。"""
entry = SimpleNamespace(invoke_method="plugin.invoke_tool")
anchor_message = SimpleNamespace(
message_info=SimpleNamespace(
group_info=SimpleNamespace(group_id="10001"),
user_info=SimpleNamespace(user_id="20002"),
)
)
invocation = ToolInvocation(tool_name="mute", arguments={"target": "张三"}, stream_id="session-1")
context = ToolExecutionContext(
session_id="session-1",
stream_id="session-1",
reasoning="test",
metadata={"anchor_message": anchor_message},
)
payload = ComponentQueryService._build_tool_invocation_payload(entry, invocation, context)
assert payload["target"] == "张三"
assert payload["stream_id"] == "session-1"
assert payload["chat_id"] == "session-1"
assert payload["group_id"] == "10001"
assert payload["user_id"] == "20002"

View File

@@ -1,227 +0,0 @@
from pathlib import Path
from types import ModuleType, SimpleNamespace
import importlib.util
import sys
class DummyLogger:
def __init__(self) -> None:
self.warning_messages: list[str] = []
def debug(self, _msg: str) -> None:
return
def info(self, _msg: str) -> None:
return
def warning(self, msg: str) -> None:
self.warning_messages.append(msg)
def error(self, _msg: str) -> None:
return
def load_utils_module(monkeypatch, qq_account=123456, platforms=None):
logger = DummyLogger()
configured_platforms = platforms or []
def _stub_module(name: str) -> ModuleType:
module = ModuleType(name)
monkeypatch.setitem(sys.modules, name, module)
return module
for package_name in [
"src",
"src.chat",
"src.chat.message_receive",
"src.chat.utils",
"src.common",
"src.config",
"src.llm_models",
"src.person_info",
]:
if package_name not in sys.modules:
package_module = ModuleType(package_name)
package_module.__path__ = []
monkeypatch.setitem(sys.modules, package_name, package_module)
jieba_module = ModuleType("jieba")
jieba_module.cut = lambda text: list(text)
monkeypatch.setitem(sys.modules, "jieba", jieba_module)
logger_module = _stub_module("src.common.logger")
logger_module.get_logger = lambda _name: logger
config_module = _stub_module("src.config.config")
config_module.global_config = SimpleNamespace(
bot=SimpleNamespace(
qq_account=qq_account,
platforms=configured_platforms,
nickname="MaiBot",
alias_names=[],
),
chat=SimpleNamespace(
at_bot_inevitable_reply=1,
mentioned_bot_reply=1,
),
)
config_module.model_config = SimpleNamespace()
message_module = _stub_module("src.chat.message_receive.message")
class SessionMessage:
pass
message_module.SessionMessage = SessionMessage
chat_manager_module = _stub_module("src.chat.message_receive.chat_manager")
chat_manager_module.chat_manager = SimpleNamespace(get_session_by_session_id=lambda _chat_id: None)
llm_module = _stub_module("src.llm_models.utils_model")
class LLMRequest:
def __init__(self, *args, **kwargs) -> None:
del args, kwargs
llm_module.LLMRequest = LLMRequest
person_module = _stub_module("src.person_info.person_info")
class Person:
pass
person_module.Person = Person
typo_generator_module = _stub_module("src.chat.utils.typo_generator")
class ChineseTypoGenerator:
def __init__(self, *args, **kwargs) -> None:
del args, kwargs
def create_typo_sentence(self, sentence: str):
return sentence, ""
typo_generator_module.ChineseTypoGenerator = ChineseTypoGenerator
file_path = Path(__file__).parent.parent.parent / "src" / "chat" / "utils" / "utils.py"
spec = importlib.util.spec_from_file_location("src.chat.utils.utils", file_path)
utils_module = importlib.util.module_from_spec(spec)
utils_module.__package__ = "src.chat.utils"
monkeypatch.setitem(sys.modules, "src.chat.utils.utils", utils_module)
assert spec.loader is not None
spec.loader.exec_module(utils_module)
return utils_module, logger
def test_platform_specific_bot_accounts(monkeypatch):
utils_module, _logger = load_utils_module(
monkeypatch,
qq_account=123456,
platforms=[" TG : tg_bot ", "discord: disc_bot"],
)
assert utils_module.get_bot_account("qq") == "123456"
assert utils_module.get_bot_account("webui") == "123456"
assert utils_module.get_bot_account("telegram") == "tg_bot"
assert utils_module.get_bot_account("tg") == "tg_bot"
assert utils_module.get_bot_account("discord") == "disc_bot"
assert utils_module.is_bot_self("qq", "123456")
assert utils_module.is_bot_self("webui", "123456")
assert utils_module.is_bot_self("telegram", "tg_bot")
assert utils_module.is_bot_self(" TG ", "tg_bot")
def test_get_all_bot_accounts_includes_runtime_aliases(monkeypatch):
utils_module, _logger = load_utils_module(
monkeypatch,
qq_account=123456,
platforms=["TG:tg_bot", "discord:disc_bot"],
)
assert utils_module.get_all_bot_accounts() == {
"qq": "123456",
"webui": "123456",
"telegram": "tg_bot",
"tg": "tg_bot",
"discord": "disc_bot",
}
def test_get_all_bot_accounts_keeps_canonical_qq_identity(monkeypatch):
utils_module, _logger = load_utils_module(
monkeypatch,
qq_account=123456,
platforms=["qq:999999", "webui:888888", "TG:tg_bot"],
)
assert utils_module.get_all_bot_accounts()["qq"] == "123456"
assert utils_module.get_all_bot_accounts()["webui"] == "123456"
def test_unknown_platform_no_longer_falls_back_to_qq(monkeypatch):
utils_module, logger = load_utils_module(monkeypatch, qq_account=123456, platforms=[])
assert utils_module.is_bot_self("unknown_platform", "123456") is False
assert logger.warning_messages
assert "unknown_platform" in logger.warning_messages[-1]
def test_unknown_platform_warns_only_once(monkeypatch):
utils_module, logger = load_utils_module(monkeypatch, qq_account=123456, platforms=[])
assert utils_module.is_bot_self("unknown_platform", "first") is False
assert utils_module.is_bot_self(" unknown_platform ", "second") is False
assert len(logger.warning_messages) == 1
def test_unconfigured_qq_account_disables_qq_and_webui_identity(monkeypatch):
utils_module, _logger = load_utils_module(monkeypatch, qq_account=0, platforms=["telegram:tg_bot"])
assert utils_module.get_bot_account("qq") == ""
assert utils_module.get_bot_account("webui") == ""
assert utils_module.is_bot_self("qq", "0") is False
assert utils_module.is_bot_self("webui", "0") is False
def test_is_mentioned_bot_in_message_uses_platform_account(monkeypatch):
utils_module, _logger = load_utils_module(monkeypatch, qq_account=123456, platforms=["TG:tg_bot"])
message = SimpleNamespace(
processed_plain_text="@tg_bot 你好",
platform="telegram",
is_mentioned=False,
message_segment=None,
message_info=SimpleNamespace(
additional_config={},
user_info=SimpleNamespace(user_id="user_1"),
),
)
is_mentioned, is_at, reply_probability = utils_module.is_mentioned_bot_in_message(message)
assert is_mentioned is True
assert is_at is True
assert reply_probability == 1.0
def test_is_mentioned_bot_in_message_normalizes_qq_platform(monkeypatch):
utils_module, _logger = load_utils_module(monkeypatch, qq_account=123456, platforms=[])
message = SimpleNamespace(
processed_plain_text="@<MaiBot:123456> 你好",
platform=" QQ ",
is_mentioned=False,
message_segment=None,
message_info=SimpleNamespace(
additional_config={},
user_info=SimpleNamespace(user_id="user_1"),
),
)
is_mentioned, is_at, reply_probability = utils_module.is_mentioned_bot_in_message(message)
assert is_mentioned is True
assert is_at is True
assert reply_probability == 1.0

View File

@@ -612,13 +612,6 @@ class ChatBot:
scope=scope,
) # 确保会话存在
try:
from src.services.memory_flow_service import memory_automation_service
await memory_automation_service.on_incoming_message(message)
except Exception as exc:
logger.warning(f"[{session_id}] 长期记忆自动摘要注册失败: {exc}")
# message.update_chat_stream(chat)
# 命令处理 - 使用新插件系统检查并处理命令。

View File

@@ -1,9 +1,9 @@
import time
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Awaitable, Callable, Dict, List, Literal, Optional, Tuple
import random
import time
from rich.console import Group, RenderableType
from rich.panel import Panel
@@ -13,7 +13,6 @@ from src.chat.message_receive.chat_manager import BotChatSession
from src.chat.message_receive.message import SessionMessage
from src.chat.utils.utils import get_chat_type_and_target_info
from src.cli.console import console
from src.common.data_models.message_component_data_model import MessageSequence, TextComponent
from src.common.data_models.reply_generation_data_models import (
GenerationMetrics,
LLMCompletionResult,
@@ -32,9 +31,10 @@ from src.maisaka.context_messages import (
ReferenceMessage,
SessionBackedMessage,
ToolResultMessage,
build_llm_message_from_context,
)
from src.maisaka.display.prompt_cli_renderer import PromptCLIVisualizer
from src.maisaka.message_adapter import clone_message_sequence, parse_speaker_content
from src.maisaka.message_adapter import parse_speaker_content
from src.plugin_runtime.hook_payloads import serialize_prompt_messages
from .maisaka_expression_selector import maisaka_expression_selector
@@ -110,11 +110,15 @@ class BaseMaisakaReplyGenerator:
return ""
def _extract_guided_bot_reply(self, message: SessionBackedMessage) -> str:
speaker_name, body = parse_speaker_content(message.processed_plain_text.strip())
bot_nickname = global_config.bot.nickname.strip() or "Bot"
if speaker_name == bot_nickname:
return self._normalize_content(body.strip())
return ""
# 只能根据结构化来源字段判断是否为 bot 自身写回的历史消息,
# 不能依赖昵称/群名片等可控文本,避免误判和提示注入。
if message.source_kind != "guided_reply":
return ""
plain_text = message.processed_plain_text.strip()
_, body = parse_speaker_content(plain_text)
normalized_body = body.strip() or plain_text
return self._normalize_content(normalized_body) if normalized_body else ""
def _build_target_message_block(self, reply_message: Optional[SessionMessage]) -> str:
if reply_message is None:
@@ -210,6 +214,7 @@ class BaseMaisakaReplyGenerator:
self,
reply_message: Optional[SessionMessage],
reply_reason: str,
reference_info: str = "",
expression_habits: str = "",
stream_id: Optional[str] = None,
) -> str:
@@ -234,8 +239,13 @@ class BaseMaisakaReplyGenerator:
sections.append(expression_habits.strip())
if target_message_block:
sections.append(target_message_block)
reply_reference_lines: List[str] = []
if reply_reason.strip():
sections.append(f"回复信息参考\n{reply_reason}")
reply_reference_lines.append(f"最新推理\n{reply_reason.strip()}")
if reference_info.strip():
reply_reference_lines.append(f"【参考信息】\n{reference_info.strip()}")
if reply_reference_lines:
sections.append("【回复信息参考】\n" + "\n\n".join(reply_reference_lines))
if not sections:
return system_prompt
return f"{system_prompt}\n\n" + "\n\n".join(sections)
@@ -243,28 +253,6 @@ class BaseMaisakaReplyGenerator:
def _build_reply_instruction(self) -> str:
return "请自然地回复。不要输出多余说明、括号、@ 或额外标记,只输出实际要发送的内容。"
def _build_visual_user_message(
self,
message: SessionBackedMessage,
enable_visual_message: bool,
) -> Optional[Message]:
if not enable_visual_message:
return None
raw_message = clone_message_sequence(message.raw_message)
if not raw_message.components:
raw_message = MessageSequence([TextComponent(message.processed_plain_text)])
visual_message = SessionBackedMessage(
raw_message=raw_message,
visible_text=message.processed_plain_text,
timestamp=message.timestamp,
message_id=message.message_id,
original_message=message.original_message,
source_kind=message.source_kind,
)
return visual_message.to_llm_message()
def _build_history_messages(
self,
chat_history: List[LLMContextMessage],
@@ -284,12 +272,10 @@ class BaseMaisakaReplyGenerator:
)
continue
visual_message = self._build_visual_user_message(message, enable_visual_message)
if visual_message is not None:
messages.append(visual_message)
continue
llm_message = message.to_llm_message()
llm_message = build_llm_message_from_context(
message,
enable_visual_message=enable_visual_message,
)
if llm_message is not None:
messages.append(llm_message)
continue
@@ -308,6 +294,7 @@ class BaseMaisakaReplyGenerator:
chat_history: List[LLMContextMessage],
reply_message: Optional[SessionMessage],
reply_reason: str,
reference_info: str = "",
expression_habits: str = "",
stream_id: Optional[str] = None,
enable_visual_message: bool = False,
@@ -316,6 +303,7 @@ class BaseMaisakaReplyGenerator:
system_prompt = self._build_system_prompt(
reply_message=reply_message,
reply_reason=reply_reason,
reference_info=reference_info,
expression_habits=expression_habits,
stream_id=stream_id,
)
@@ -377,6 +365,7 @@ class BaseMaisakaReplyGenerator:
self,
extra_info: str = "",
reply_reason: str = "",
reference_info: str = "",
available_actions: Optional[Dict[str, ActionInfo]] = None,
chosen_actions: Optional[List[object]] = None,
from_plugin: bool = True,
@@ -461,6 +450,7 @@ class BaseMaisakaReplyGenerator:
chat_history=filtered_history,
reply_message=reply_message,
reply_reason=reply_reason or "",
reference_info=reference_info or "",
expression_habits=merged_expression_habits,
stream_id=stream_id,
)
@@ -486,6 +476,7 @@ class BaseMaisakaReplyGenerator:
chat_history=filtered_history,
reply_message=reply_message,
reply_reason=reply_reason or "",
reference_info=reference_info or "",
expression_habits=merged_expression_habits,
stream_id=stream_id,
enable_visual_message=self._resolve_enable_visual_message(model_info),
@@ -504,7 +495,6 @@ class BaseMaisakaReplyGenerator:
chat_id=preview_chat_id,
request_kind="replyer",
selection_reason=f"ID: {preview_chat_id}",
image_display_mode="path_link" if global_config.maisaka.show_image_path else "legacy",
),
title="Reply Prompt",
border_style="bright_yellow",

View File

@@ -159,7 +159,6 @@ MODULE_ALIASES = {
"planner": "规划器",
"config": "配置",
"main": "主程序",
"chat_history_summarizer": "聊天概括器",
"plugin_runtime.integration": "IPC插件系统",
"plugin_runtime.host.supervisor": "插件监督器",
"plugin_runtime.host.runner_manager": "插件监督器",

View File

@@ -55,7 +55,7 @@ BOT_CONFIG_PATH: Path = (CONFIG_DIR / "bot_config.toml").resolve().absolute()
MODEL_CONFIG_PATH: Path = (CONFIG_DIR / "model_config.toml").resolve().absolute()
LEGACY_ENV_PATH: Path = (PROJECT_ROOT / ".env").resolve().absolute()
MMC_VERSION: str = "1.0.0"
CONFIG_VERSION: str = "8.7.1"
CONFIG_VERSION: str = "8.8.0"
MODEL_CONFIG_VERSION: str = "1.14.0"
logger = get_logger("config")

View File

@@ -414,15 +414,6 @@ class MemoryConfig(ConfigBase):
)
"""Maisaka 内置长期记忆检索工具 query_memory 的默认返回条数"""
long_term_auto_summary_enabled: bool = Field(
default=True,
json_schema_extra={
"x-widget": "switch",
"x-icon": "book-open",
},
)
"""是否自动启动聊天总结并导入长期记忆"""
person_fact_writeback_enabled: bool = Field(
default=True,
json_schema_extra={
@@ -578,77 +569,9 @@ class MemoryConfig(ConfigBase):
},
)
"""反馈纠错二阶段一致性每轮处理 profile/episode 队列的批大小"""
chat_history_topic_check_message_threshold: int = Field(
default=80,
ge=1,
json_schema_extra={
"x-widget": "input",
"x-icon": "hash",
},
)
"""聊天历史话题检查的消息数量阈值,当累积消息数达到此值时触发话题检查"""
chat_history_topic_check_time_hours: float = Field(
default=8.0,
json_schema_extra={
"x-widget": "input",
"x-icon": "clock",
},
)
"""聊天历史话题检查的时间阈值(小时),当距离上次检查超过此时间且消息数达到最小阈值时触发话题检查"""
chat_history_topic_check_min_messages: int = Field(
default=20,
ge=1,
json_schema_extra={
"x-widget": "input",
"x-icon": "hash",
},
)
"""聊天历史话题检查的时间触发模式下的最小消息数阈值"""
chat_history_finalize_no_update_checks: int = Field(
default=3,
ge=1,
json_schema_extra={
"x-widget": "input",
"x-icon": "check-circle",
},
)
"""聊天历史话题打包存储的连续无更新检查次数阈值当话题连续N次检查无新增内容时触发打包存储"""
chat_history_finalize_message_count: int = Field(
default=5,
ge=1,
json_schema_extra={
"x-widget": "input",
"x-icon": "package",
},
)
"""聊天历史话题打包存储的消息条数阈值,当话题的消息条数超过此值时触发打包存储"""
def model_post_init(self, context: Optional[dict] = None) -> None:
"""验证配置值"""
if self.chat_history_topic_check_message_threshold < 1:
raise ValueError(
f"chat_history_topic_check_message_threshold 必须至少为1当前值: {self.chat_history_topic_check_message_threshold}"
)
if self.chat_history_topic_check_time_hours <= 0:
raise ValueError(
f"chat_history_topic_check_time_hours 必须大于0当前值: {self.chat_history_topic_check_time_hours}"
)
if self.chat_history_topic_check_min_messages < 1:
raise ValueError(
f"chat_history_topic_check_min_messages 必须至少为1当前值: {self.chat_history_topic_check_min_messages}"
)
if self.chat_history_finalize_no_update_checks < 1:
raise ValueError(
f"chat_history_finalize_no_update_checks 必须至少为1当前值: {self.chat_history_finalize_no_update_checks}"
)
if self.chat_history_finalize_message_count < 1:
raise ValueError(
f"chat_history_finalize_message_count 必须至少为1当前值: {self.chat_history_finalize_message_count}"
)
if self.feedback_correction_window_hours <= 0:
raise ValueError(
f"feedback_correction_window_hours 必须大于0当前值: {self.feedback_correction_window_hours}"

View File

@@ -335,6 +335,8 @@ async def send_emoji_for_maisaka(
storage_message=True,
set_reply=False,
reply_message=None,
sync_to_maisaka_history=True,
maisaka_source_kind="guided_reply",
)
sent = sent_message is not None
except Exception as exc:

View File

@@ -6,6 +6,7 @@ import json
import re
from dataclasses import dataclass, field
from typing import Any, Callable, Coroutine, Dict, List, Tuple, cast
from uuid import uuid4
from json_repair import repair_json
from openai import APIConnectionError, APIStatusError, AsyncOpenAI, AsyncStream
@@ -119,6 +120,13 @@ OpenAIResponseParser = Callable[[ChatCompletion], Tuple[APIResponse, UsageTuple
"""OpenAI 非流式响应解析函数类型。"""
def _build_fallback_tool_call_id(prefix: str) -> str:
"""为缺失原始调用 ID 的工具调用生成唯一兜底标识。"""
normalized_prefix = str(prefix).strip() or "tool_call"
return f"{normalized_prefix}_{uuid4().hex}"
def _normalize_reasoning_parse_mode(parse_mode: str | ReasoningParseMode) -> ReasoningParseMode:
"""将配置中的推理解析模式收敛为枚举值。
@@ -609,7 +617,7 @@ def _extract_xml_tool_calls(
arguments = _parse_tool_arguments(raw_arguments, parse_mode, response) if raw_arguments else {}
tool_calls.append(
ToolCall(
call_id=f"xml_tool_call_{len(tool_calls) + 1}",
call_id=_build_fallback_tool_call_id("xml_tool_call"),
func_name=function_name,
args=arguments,
)
@@ -855,7 +863,7 @@ class _OpenAIStreamAccumulator:
if raw_arguments
else None
)
call_id = state.call_id or f"tool_call_{index}"
call_id = state.call_id or _build_fallback_tool_call_id(f"tool_call_{index}")
response.tool_calls.append(ToolCall(call_id=call_id, func_name=state.function_name, args=arguments))
response.raw_data = {"model": self.model_name} if self.model_name else None

View File

@@ -4,7 +4,7 @@ from __future__ import annotations
from base64 import b64decode
from datetime import datetime
from typing import Any, Dict, List, Optional, TYPE_CHECKING
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from src.chat.utils.utils import process_llm_response
from src.common.data_models.message_component_data_model import EmojiComponent, MessageSequence, TextComponent
@@ -12,13 +12,10 @@ from src.config.config import global_config
from src.core.tooling import ToolExecutionResult
from ..context_messages import SessionBackedMessage
from ..history_utils import build_prefixed_message_sequence, build_session_message_visible_text
from ..message_adapter import format_speaker_content
from ..planner_message_utils import build_planner_prefix, build_session_backed_text_message
if TYPE_CHECKING:
from src.chat.message_receive.message import SessionMessage
from ..reasoning_engine import MaisakaReasoningEngine
from ..runtime import MaisakaHeartFlowChatting
@@ -139,37 +136,6 @@ class BuiltinToolRuntimeContext:
return self.engine._get_runtime_manager()
@staticmethod
def _build_visible_text_from_sent_message(message: "SessionMessage") -> str:
"""将已发送消息转换为 Maisaka 可见文本。"""
return build_session_message_visible_text(message)
def append_sent_message_to_chat_history(
self,
message: "SessionMessage",
*,
source_kind: str = "guided_reply",
) -> None:
"""将真实已发送消息同步到 Maisaka 历史。"""
user_info = message.message_info.user_info
speaker_name = user_info.user_cardname or user_info.user_nickname or user_info.user_id
planner_prefix = build_planner_prefix(
timestamp=message.timestamp,
user_name=speaker_name,
group_card=user_info.user_cardname or "",
message_id=message.message_id,
include_message_id=not message.is_notify and bool(message.message_id),
)
history_message = SessionBackedMessage.from_session_message(
message,
raw_message=build_prefixed_message_sequence(message.raw_message, planner_prefix),
visible_text=self._build_visible_text_from_sent_message(message),
source_kind=source_kind,
)
self.runtime._chat_history.append(history_message)
def append_guided_reply_to_chat_history(self, reply_text: str) -> None:
"""将引导回复写回 Maisaka 历史。"""

View File

@@ -36,7 +36,8 @@ def get_tool_spec() -> ToolSpec:
detailed_description=(
"参数说明:\n"
"- msg_idstring必填。要回复的目标用户消息编号。\n"
"- set_quoteboolean可选。以引用回复的方式发送默认 true。"
"- set_quoteboolean可选。以引用回复的方式发送默认 true。\n"
"- reference_infostring可选。上文中有助于回复的所有参考信息使用平文本格式。"
),
parameters_schema={
"type": "object",
@@ -50,6 +51,11 @@ def get_tool_spec() -> ToolSpec:
"description": "以引用回复的方式发送这条回复,不用每句都引用。",
"default": True,
},
"reference_info": {
"type": "string",
"description": "有助于回复的信息,之前搜集得到的事实性信息,记忆等,使用平文本格式。",
"default": True,
},
},
"required": ["msg_id"],
},
@@ -75,6 +81,7 @@ async def handle_tool(
"""执行 reply 内置工具。"""
latest_thought = context.reasoning if context is not None else invocation.reasoning
reference_info = str(invocation.arguments.get("reference_info") or "").strip()
target_message_id = str(invocation.arguments.get("msg_id") or "").strip()
set_quote = bool(invocation.arguments.get("set_quote", True))
@@ -117,6 +124,7 @@ async def handle_tool(
try:
success, reply_result = await replyer.generate_reply_with_context(
reply_reason=latest_thought,
reference_info=reference_info,
stream_id=tool_ctx.runtime.session_id,
reply_message=target_message,
chat_history=tool_ctx.runtime._chat_history,
@@ -152,7 +160,6 @@ async def handle_tool(
combined_reply_text = "".join(reply_segments)
try:
sent = False
sent_messages = []
if tool_ctx.runtime.chat_stream.platform == CLI_PLATFORM_NAME:
for segment in reply_segments:
render_cli_message(segment)
@@ -166,11 +173,12 @@ async def handle_tool(
reply_message=target_message if set_quote and index == 0 else None,
selected_expressions=reply_result.selected_expression_ids or None,
typing=index > 0,
sync_to_maisaka_history=True,
maisaka_source_kind="guided_reply",
)
sent = sent_message is not None
if not sent:
break
sent_messages.append(sent_message)
except Exception:
logger.exception(
f"{tool_ctx.runtime.log_prefix} 发送文字消息时发生异常,目标消息编号={target_message_id}"
@@ -198,9 +206,6 @@ async def handle_tool(
if tool_ctx.runtime.chat_stream.platform == CLI_PLATFORM_NAME:
tool_ctx.append_guided_reply_to_chat_history(combined_reply_text)
else:
for sent_message in sent_messages:
tool_ctx.append_sent_message_to_chat_history(sent_message)
tool_ctx.runtime._record_reply_sent()
return tool_ctx.build_success_result(
invocation.tool_name,

View File

@@ -53,40 +53,16 @@ def get_tool_spec() -> ToolSpec:
return ToolSpec(
name="send_emoji",
brief_description="发送一个合适的表情包来辅助表达情绪。",
detailed_description="参数说明:\n- emotionstring可选。希望表达的情绪例如 happy、sad、angry 等",
detailed_description="无需参数,直接发送一个合适的表情包",
parameters_schema={
"type": "object",
"properties": {
"emotion": {
"type": "string",
"description": "希望表达的情绪,例如 happy、sad、angry 等。",
},
},
"properties": {},
},
provider_name="maisaka_builtin",
provider_type="builtin",
)
def _normalize_candidate_emotions(emoji: MaiEmoji) -> list[str]:
"""清洗候选表情上的情绪标签。"""
raw_emotions = getattr(emoji, "emotion", None)
if isinstance(raw_emotions, list) and raw_emotions:
return [str(item).strip() for item in raw_emotions if str(item).strip()]
description = str(getattr(emoji, "description", "") or "").strip()
if not description:
return []
normalized_description = (
description.replace("", ",")
.replace("", ",")
.replace("", ",")
)
return [item.strip() for item in normalized_description.split(",") if item.strip()]
async def _load_emoji_bytes(emoji: MaiEmoji) -> bytes:
"""读取单个表情包图片字节。"""
@@ -232,18 +208,6 @@ async def _build_emoji_candidate_message(emojis: list[MaiEmoji]) -> SessionBacke
)
def _build_emoji_candidate_summary(emojis: list[MaiEmoji]) -> str:
"""构建供监控展示使用的候选表情摘要。"""
summary_lines: list[str] = []
for index, emoji in enumerate(emojis, start=1):
description = emoji.description.strip() or "(无描述)"
emotions = "".join(_normalize_candidate_emotions(emoji)) or ""
summary_lines.append(f"{index}. 描述:{description}")
summary_lines.append(f" 情绪:{emotions}")
return "\n".join(summary_lines).strip()
def _build_send_emoji_monitor_detail(
*,
request_messages: Optional[list[dict[str, Any]]] = None,
@@ -252,7 +216,7 @@ def _build_send_emoji_monitor_detail(
metrics: Optional[Dict[str, Any]] = None,
extra_sections: Optional[list[dict[str, str]]] = None,
) -> Dict[str, Any]:
"""构建 emotion tool 统一监控详情。"""
"""构建 send_emoji 工具统一监控详情。"""
detail: Dict[str, Any] = {}
if isinstance(request_messages, list) and request_messages:
@@ -281,7 +245,6 @@ def _build_send_emoji_monitor_detail(
def _build_send_emoji_monitor_metadata(
selection_metadata: Dict[str, Any],
*,
requested_emotion: str,
send_result: Optional[Any] = None,
error_message: str = "",
) -> Dict[str, Any]:
@@ -293,7 +256,6 @@ def _build_send_emoji_monitor_metadata(
if send_result is not None:
result_lines = [
f"请求情绪:{requested_emotion or '未指定'}",
f"命中情绪:{send_result.matched_emotion or '未命中'}",
f"表情描述:{send_result.description or '无描述'}",
f"情绪标签:{''.join(send_result.emotions) if send_result.emotions else ''}",
@@ -306,10 +268,7 @@ def _build_send_emoji_monitor_metadata(
elif error_message.strip():
extra_sections.append({
"title": "表情发送结果",
"content": (
f"请求情绪:{requested_emotion or '未指定'}\n"
f"发送结果:{error_message.strip()}"
),
"content": f"发送结果:{error_message.strip()}",
})
if extra_sections:
@@ -322,7 +281,6 @@ def _build_send_emoji_monitor_metadata(
async def _select_emoji_with_sub_agent(
tool_ctx: BuiltinToolRuntimeContext,
requested_emotion: str,
reasoning: str,
context_texts: list[str],
sample_size: int,
@@ -347,14 +305,12 @@ async def _select_emoji_with_sub_agent(
f"一共 {len(sampled_emojis)} 个位置。\n"
f"每张小图左上角都有一个较大的序号,范围是 1 到 {len(sampled_emojis)}\n"
f"你的任务是根据上下文和当前语气,从这 {len(sampled_emojis)} 张图里选出最合适的一张表情包。\n"
"如果提供了 requested_emotion请优先考虑与其接近的候选如果没有完全匹配则选择最符合上下文语气的候选。\n"
"你必须返回一个 JSON 对象json object不要输出任何 JSON 之外的内容。\n"
'返回格式固定为:{"emoji_index":1,"reason":"简短理由"}'
)
prompt_message = ReferenceMessage(
content=(
f"[选择任务]\n"
f"requested_emotion: {requested_emotion or '未指定'}\n"
f"候选总数: {len(sampled_emojis)}\n"
f"拼图布局: {grid_rows}x{grid_columns}\n"
"请只输出 JSON。"
@@ -439,7 +395,6 @@ async def handle_tool(
"""执行 send_emoji 内置工具。"""
del context
emotion = str(invocation.arguments.get("emotion") or "").strip()
context_texts = [
message.processed_plain_text.strip()
for message in tool_ctx.runtime._chat_history[-5:]
@@ -450,23 +405,20 @@ async def handle_tool(
"message": "",
"description": "",
"emotion": [],
"requested_emotion": emotion,
"matched_emotion": "",
"reason": "",
}
selection_metadata: Dict[str, Any] = {"reason": "", "monitor_detail": {}}
logger.info(f"{tool_ctx.runtime.log_prefix} 触发表情包发送工具,请求情绪={emotion!r}")
logger.info(f"{tool_ctx.runtime.log_prefix} 触发表情包发送工具")
try:
send_result = await send_emoji_for_maisaka(
stream_id=tool_ctx.runtime.session_id,
requested_emotion=emotion,
reasoning=tool_ctx.engine.last_reasoning_content,
context_texts=context_texts,
emoji_selector=lambda requested_emotion, reasoning, context_texts, sample_size: _select_emoji_with_sub_agent(
emoji_selector=lambda _requested_emotion, reasoning, context_texts, sample_size: _select_emoji_with_sub_agent(
tool_ctx,
requested_emotion,
reasoning,
list(context_texts or []),
sample_size,
@@ -482,7 +434,6 @@ async def handle_tool(
structured_content=structured_result,
metadata=_build_send_emoji_monitor_metadata(
selection_metadata,
requested_emotion=emotion,
error_message=structured_result["message"],
),
)
@@ -493,11 +444,9 @@ async def handle_tool(
logger.info(
f"{tool_ctx.runtime.log_prefix} 表情包发送成功 "
f"描述={send_result.description!r} 情绪标签={send_result.emotions} "
f"请求情绪={emotion!r} 命中情绪={send_result.matched_emotion!r}"
f"命中情绪={send_result.matched_emotion!r}"
)
if send_result.sent_message is not None:
tool_ctx.append_sent_message_to_chat_history(send_result.sent_message)
else:
if send_result.sent_message is None:
tool_ctx.append_sent_emoji_to_chat_history(
emoji_base64=send_result.emoji_base64,
success_message=_EMOJI_SUCCESS_MESSAGE,
@@ -509,7 +458,6 @@ async def handle_tool(
structured_content=structured_result,
metadata=_build_send_emoji_monitor_metadata(
selection_metadata,
requested_emotion=emotion,
send_result=send_result,
),
)
@@ -521,7 +469,7 @@ async def handle_tool(
logger.warning(
f"{tool_ctx.runtime.log_prefix} 表情包发送失败 "
f"请求情绪={emotion!r} 错误信息={send_result.message}"
f"错误信息={send_result.message}"
)
return tool_ctx.build_failure_result(
invocation.tool_name,
@@ -529,7 +477,6 @@ async def handle_tool(
structured_content=structured_result,
metadata=_build_send_emoji_monitor_metadata(
selection_metadata,
requested_emotion=emotion,
send_result=send_result,
),
)

View File

@@ -30,9 +30,15 @@ from src.plugin_runtime.host.hook_spec_registry import HookSpec, HookSpecRegistr
from src.services.llm_service import LLMServiceClient
from .builtin_tool import get_builtin_tools
from .context_messages import AssistantMessage, LLMContextMessage, ToolResultMessage
from .context_messages import (
AssistantMessage,
LLMContextMessage,
ToolResultMessage,
build_llm_message_from_context,
)
from .history_utils import drop_orphan_tool_results
from .display.prompt_cli_renderer import PromptCLIVisualizer
from .visual_mode_utils import resolve_enable_visual_planner
TIMING_GATE_TOOL_NAMES = {"continue", "no_reply", "wait"}
@@ -395,6 +401,7 @@ class MaisakaChatLoopService:
self,
selected_history: List[LLMContextMessage],
*,
enable_visual_message: bool,
injected_user_messages: Sequence[str] | None = None,
system_prompt: Optional[str] = None,
) -> List[Message]:
@@ -413,7 +420,10 @@ class MaisakaChatLoopService:
messages.append(system_msg.build())
for msg in selected_history:
llm_message = msg.to_llm_message()
llm_message = build_llm_message_from_context(
msg,
enable_visual_message=enable_visual_message,
)
if llm_message is not None:
messages.append(llm_message)
@@ -475,12 +485,15 @@ class MaisakaChatLoopService:
if not self._prompts_loaded:
await self.ensure_chat_prompt_loaded()
enable_visual_message = self._resolve_enable_visual_message(request_kind)
selected_history, selection_reason = self.select_llm_context_messages(
chat_history,
request_kind=request_kind,
enable_visual_message=enable_visual_message,
)
built_messages = self._build_request_messages(
selected_history,
enable_visual_message=enable_visual_message,
injected_user_messages=injected_user_messages,
)
@@ -528,14 +541,12 @@ class MaisakaChatLoopService:
prompt_section: RenderableType | None = None
if global_config.debug.show_maisaka_thinking:
image_display_mode: str = "path_link" if global_config.maisaka.show_image_path else "legacy"
prompt_section = PromptCLIVisualizer.build_prompt_section(
built_messages,
category="planner" if request_kind != "timing_gate" else "timing_gate",
chat_id=self._session_id,
request_kind=request_kind,
selection_reason=selection_reason,
image_display_mode=image_display_mode,
folded=global_config.debug.fold_maisaka_thinking,
tool_definitions=list(all_tools),
)
@@ -604,6 +615,7 @@ class MaisakaChatLoopService:
def select_llm_context_messages(
chat_history: List[LLMContextMessage],
*,
enable_visual_message: Optional[bool] = None,
request_kind: str = "planner",
max_context_size: Optional[int] = None,
) -> tuple[List[LLMContextMessage], str]:
@@ -617,9 +629,21 @@ class MaisakaChatLoopService:
selected_indices: List[int] = []
counted_message_count = 0
active_enable_visual_message = (
enable_visual_message
if enable_visual_message is not None
else MaisakaChatLoopService._resolve_enable_visual_message(request_kind)
)
for index in range(len(filtered_history) - 1, -1, -1):
message = filtered_history[index]
if message.to_llm_message() is None:
if (
build_llm_message_from_context(
message,
enable_visual_message=active_enable_visual_message,
)
is None
):
continue
selected_indices.append(index)
@@ -629,18 +653,18 @@ class MaisakaChatLoopService:
break
if not selected_indices:
return [], f"没有选择到上下文消息,实际发送 {effective_context_size} 条 user/assistant 消息"
return [], "实际发送 0 条消息tool 0 条,普通消息 0 条)"
selected_indices.reverse()
selected_history = [filtered_history[index] for index in selected_indices]
selected_history, hidden_assistant_count = MaisakaChatLoopService._hide_early_assistant_messages(selected_history)
selected_history, _ = MaisakaChatLoopService._hide_early_assistant_messages(selected_history)
selected_history, _ = drop_orphan_tool_results(selected_history)
tool_message_count = sum(1 for message in selected_history if isinstance(message, ToolResultMessage))
normal_message_count = len(selected_history) - tool_message_count
selection_reason = (
f"上下文裁剪:最近 {effective_context_size} 条 user/assistant 消息"
f"实际发送 {len(selected_history)}"
f"实际发送 {len(selected_history)}消息"
f"|消息 {normal_message_count} 条|tool {tool_message_count}"
)
if hidden_assistant_count > 0:
selection_reason += f",已隐藏最早 {hidden_assistant_count} 条 assistant 消息"
return (
selected_history,
selection_reason,
@@ -685,6 +709,12 @@ class MaisakaChatLoopService:
return filtered_history
@staticmethod
def _resolve_enable_visual_message(request_kind: str) -> bool:
if request_kind in {"planner", "timing_gate"}:
return resolve_enable_visual_planner()
return True
@staticmethod
def _hide_early_assistant_messages(
selected_history: List[LLMContextMessage],

View File

@@ -40,10 +40,15 @@ def _guess_image_format(image_bytes: bytes) -> Optional[str]:
return None
def _append_emoji_component(builder: MessageBuilder, component: EmojiComponent) -> bool:
def _append_emoji_component(
builder: MessageBuilder,
component: EmojiComponent,
*,
enable_visual_message: bool,
) -> bool:
"""将表情组件追加到 LLM 消息构建器。"""
image_format = _guess_image_format(component.binary_data)
if image_format and component.binary_data:
if enable_visual_message and image_format and component.binary_data:
builder.add_text_content("[消息类型]表情包")
builder.add_image_content(image_format, base64.b64encode(component.binary_data).decode("utf-8"))
return True
@@ -56,10 +61,15 @@ def _append_emoji_component(builder: MessageBuilder, component: EmojiComponent)
return True
def _append_image_component(builder: MessageBuilder, component: ImageComponent) -> bool:
def _append_image_component(
builder: MessageBuilder,
component: ImageComponent,
*,
enable_visual_message: bool,
) -> bool:
"""将图片组件追加到 LLM 消息构建器。"""
image_format = _guess_image_format(component.binary_data)
if image_format and component.binary_data:
if enable_visual_message and image_format and component.binary_data:
builder.add_text_content("[消息类型]图片")
builder.add_image_content(image_format, base64.b64encode(component.binary_data).decode("utf-8"))
return True
@@ -216,6 +226,7 @@ def _build_message_from_sequence(
message_sequence: MessageSequence,
fallback_text: str,
*,
enable_visual_message: bool = True,
tool_call_id: Optional[str] = None,
tool_name: Optional[str] = None,
tool_calls: Optional[list[ToolCall]] = None,
@@ -238,11 +249,25 @@ def _build_message_from_sequence(
continue
if isinstance(component, EmojiComponent):
has_content = _append_emoji_component(builder, component) or has_content
has_content = (
_append_emoji_component(
builder,
component,
enable_visual_message=enable_visual_message,
)
or has_content
)
continue
if isinstance(component, ImageComponent):
has_content = _append_image_component(builder, component) or has_content
has_content = (
_append_image_component(
builder,
component,
enable_visual_message=enable_visual_message,
)
or has_content
)
continue
if isinstance(component, AtComponent):
@@ -297,7 +322,7 @@ class LLMContextMessage(ABC):
return self.__class__.__name__
@abstractmethod
def to_llm_message(self) -> Optional[Message]:
def to_llm_message(self, enable_visual_message: bool = True) -> Optional[Message]:
"""转换为统一 LLM 消息。"""
def consume_once(self) -> bool:
@@ -328,11 +353,12 @@ class SessionBackedMessage(LLMContextMessage):
def source(self) -> str:
return self.source_kind
def to_llm_message(self) -> Optional[Message]:
def to_llm_message(self, enable_visual_message: bool = True) -> Optional[Message]:
return _build_message_from_sequence(
RoleType.User,
self.raw_message,
self.processed_plain_text,
enable_visual_message=enable_visual_message,
)
@classmethod
@@ -366,7 +392,8 @@ class ComplexSessionMessage(SessionBackedMessage):
def source(self) -> str:
return f"{self.source_kind}:{self.complex_message_type}"
def to_llm_message(self) -> Optional[Message]:
def to_llm_message(self, enable_visual_message: bool = True) -> Optional[Message]:
del enable_visual_message
message_sequence = MessageSequence([TextComponent(self.prompt_text)])
return _build_message_from_sequence(
RoleType.User,
@@ -426,7 +453,8 @@ class ReferenceMessage(LLMContextMessage):
def source(self) -> str:
return self.reference_type.value
def to_llm_message(self) -> Optional[Message]:
def to_llm_message(self, enable_visual_message: bool = True) -> Optional[Message]:
del enable_visual_message
message_sequence = MessageSequence([TextComponent(self.processed_plain_text)])
return _build_message_from_sequence(RoleType.User, message_sequence, self.processed_plain_text)
@@ -463,7 +491,8 @@ class AssistantMessage(LLMContextMessage):
def source(self) -> str:
return self.source_kind
def to_llm_message(self) -> Optional[Message]:
def to_llm_message(self, enable_visual_message: bool = True) -> Optional[Message]:
del enable_visual_message
message_sequence = MessageSequence([])
if self.content:
message_sequence.text(self.content)
@@ -501,7 +530,8 @@ class ToolResultMessage(LLMContextMessage):
def source(self) -> str:
return self.tool_name or "tool"
def to_llm_message(self) -> Optional[Message]:
def to_llm_message(self, enable_visual_message: bool = True) -> Optional[Message]:
del enable_visual_message
message_sequence = MessageSequence([TextComponent(self.content)])
return _build_message_from_sequence(
RoleType.Tool,
@@ -510,3 +540,13 @@ class ToolResultMessage(LLMContextMessage):
tool_call_id=self.tool_call_id,
tool_name=self.tool_name,
)
def build_llm_message_from_context(
context_message: LLMContextMessage,
*,
enable_visual_message: bool = True,
) -> Optional[Message]:
"""将 Maisaka 内部上下文消息转换为发给 LLM 的统一消息。"""
return context_message.to_llm_message(enable_visual_message=enable_visual_message)

View File

@@ -799,7 +799,7 @@ class PromptCLIVisualizer:
chat_id: str,
request_kind: str,
selection_reason: str,
image_display_mode: Literal["legacy", "path_link"],
image_display_mode: Literal["legacy", "path_link"] = "path_link",
tool_definitions: list[dict[str, Any]] | None = None,
) -> RenderableType:
"""构建用于查看完整 prompt 的折叠入口内容。"""
@@ -864,7 +864,7 @@ class PromptCLIVisualizer:
chat_id: str,
request_kind: str,
selection_reason: str,
image_display_mode: Literal["legacy", "path_link"],
image_display_mode: Literal["legacy", "path_link"] = "path_link",
folded: bool,
tool_definitions: list[dict[str, Any]] | None = None,
) -> Panel:
@@ -878,14 +878,10 @@ class PromptCLIVisualizer:
chat_id=chat_id,
request_kind=request_kind,
selection_reason=selection_reason,
image_display_mode=image_display_mode,
tool_definitions=tool_definitions,
)
else:
ordered_panels = cls.build_prompt_panels(
messages,
image_display_mode=image_display_mode,
)
ordered_panels = cls.build_prompt_panels(messages)
prompt_renderable = Group(*ordered_panels)
return Panel(
@@ -1102,11 +1098,9 @@ class PromptCLIVisualizer:
cls,
messages: list[Any],
*,
image_display_mode: Literal["legacy", "path_link"],
image_display_mode: Literal["legacy", "path_link"] = "path_link",
) -> List[Panel]:
"""构建完整 prompt 可视化面板。"""
if image_display_mode not in {mode.value for mode in PromptImageDisplayMode}:
image_display_mode = PromptImageDisplayMode.LEGACY
settings = PromptImageDisplaySettings(
display_mode=PromptImageDisplayMode(image_display_mode),
)

View File

@@ -14,7 +14,7 @@ from src.chat.message_receive.message import SessionMessage
from src.common.data_models.message_component_data_model import EmojiComponent, ImageComponent, MessageSequence
from src.common.logger import get_logger
from src.common.prompt_i18n import load_prompt
from src.config.config import config_manager, global_config
from src.config.config import global_config
from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolSpec
from src.llm_models.exceptions import ReqAbortException
from src.llm_models.payload_content.tool_option import ToolCall
@@ -44,6 +44,7 @@ from .monitor_events import (
emit_timing_gate_result,
)
from .planner_message_utils import build_planner_user_prefix_from_session_message
from .visual_mode_utils import resolve_enable_visual_planner
if TYPE_CHECKING:
from .runtime import MaisakaHeartFlowChatting
@@ -739,47 +740,10 @@ class MaisakaReasoningEngine:
planner_prefix: str,
) -> MessageSequence:
message_sequence = build_prefixed_message_sequence(message.raw_message, planner_prefix)
if self._resolve_enable_visual_planner():
if resolve_enable_visual_planner():
await self._hydrate_visual_components(message_sequence.components)
return message_sequence
@staticmethod
def _resolve_enable_visual_planner() -> bool:
planner_mode = global_config.visual.planner_mode
planner_task_config = config_manager.get_model_config().model_task_config.planner
models_by_name = {model.name: model for model in config_manager.get_model_config().models}
if planner_mode == "text":
return False
planner_models: list[str] = list(planner_task_config.model_list)
missing_models = [model_name for model_name in planner_models if model_name not in models_by_name]
non_visual_models = [
model_name for model_name in planner_models if model_name in models_by_name and not models_by_name[model_name].visual
]
if planner_mode == "multimodal":
if missing_models:
raise ValueError(
"planner_mode=multimodal但 planner 任务存在未定义的模型:"
f"{', '.join(missing_models)}"
)
if non_visual_models:
raise ValueError(
"planner_mode=multimodal但 planner 任务存在未开启 visual 的模型:"
f"{', '.join(non_visual_models)}"
)
return True
if missing_models:
logger.warning(
"planner_mode=auto 时发现 planner 任务存在未定义模型:"
f"{', '.join(missing_models)},将退化为纯文本 planner"
)
return False
return bool(planner_models) and not non_visual_models
async def _hydrate_visual_components(self, planner_components: list[object]) -> None:
"""在 Maisaka 真正需要图片或表情时,按需回填二进制数据。"""
load_tasks: list[asyncio.Task[None]] = []

View File

@@ -183,6 +183,43 @@ class MaisakaHeartFlowChatting:
self._talk_frequency_adjust = max(0.01, float(frequency))
self._schedule_message_turn()
def append_sent_message_to_chat_history(
self,
message: SessionMessage,
*,
source_kind: str = "guided_reply",
) -> bool:
"""将一条已发送成功的消息同步到 Maisaka 内部历史。"""
try:
from .context_messages import SessionBackedMessage
from .history_utils import build_prefixed_message_sequence, build_session_message_visible_text
from .planner_message_utils import build_planner_prefix
user_info = message.message_info.user_info
speaker_name = user_info.user_cardname or user_info.user_nickname or user_info.user_id
planner_prefix = build_planner_prefix(
timestamp=message.timestamp,
user_name=speaker_name,
group_card=user_info.user_cardname or "",
message_id=message.message_id,
include_message_id=not message.is_notify and bool(message.message_id),
)
history_message = SessionBackedMessage.from_session_message(
message,
raw_message=build_prefixed_message_sequence(message.raw_message, planner_prefix),
visible_text=build_session_message_visible_text(message),
source_kind=source_kind,
)
self._chat_history.append(history_message)
return True
except Exception as exc:
logger.warning(
f"{self.log_prefix} 同步已发送消息到 Maisaka 历史失败: "
f"message_id={message.message_id} error={exc}"
)
return False
async def register_message(self, message: SessionMessage) -> None:
"""缓存一条新消息并唤醒主循环。"""
if self._running:
@@ -1151,7 +1188,6 @@ class MaisakaHeartFlowChatting:
chat_id=self.session_id,
request_kind=labels["request_kind"],
selection_reason=subtitle,
image_display_mode="path_link" if global_config.maisaka.show_image_path else "legacy",
),
title=labels["prompt_title"],
border_style=border_style,

View File

@@ -0,0 +1,43 @@
from src.common.logger import get_logger
from src.config.config import config_manager, global_config
logger = get_logger("maisaka_visual_mode")
def resolve_enable_visual_planner() -> bool:
"""根据 planner 配置解析当前是否应启用视觉消息。"""
planner_mode = global_config.visual.planner_mode
planner_task_config = config_manager.get_model_config().model_task_config.planner
models_by_name = {model.name: model for model in config_manager.get_model_config().models}
if planner_mode == "text":
return False
planner_models: list[str] = list(planner_task_config.model_list)
missing_models = [model_name for model_name in planner_models if model_name not in models_by_name]
non_visual_models = [
model_name for model_name in planner_models if model_name in models_by_name and not models_by_name[model_name].visual
]
if planner_mode == "multimodal":
if missing_models:
raise ValueError(
"planner_mode=multimodal但 planner 任务存在未定义的模型:"
f"{', '.join(missing_models)}"
)
if non_visual_models:
raise ValueError(
"planner_mode=multimodal但 planner 任务存在未开启 visual 的模型:"
f"{', '.join(non_visual_models)}"
)
return True
if missing_models:
logger.warning(
"planner_mode=auto 时发现 planner 任务存在未定义模型:"
f"{', '.join(missing_models)},将退化为纯文本 planner"
)
return False
return bool(planner_models) and not non_visual_models

File diff suppressed because it is too large Load Diff

View File

@@ -75,6 +75,8 @@ class RuntimeCoreCapabilityMixin:
text = str(args.get("text", ""))
stream_id = str(args.get("stream_id", ""))
sync_to_maisaka_history = bool(args.get("sync_to_maisaka_history", False))
maisaka_source_kind = str(args.get("maisaka_source_kind", "plugin_send") or "plugin_send")
if not text or not stream_id:
return {"success": False, "error": "缺少必要参数 text 或 stream_id"}
@@ -85,6 +87,8 @@ class RuntimeCoreCapabilityMixin:
typing=bool(args.get("typing", False)),
set_reply=bool(args.get("set_reply", False)),
storage_message=bool(args.get("storage_message", True)),
sync_to_maisaka_history=sync_to_maisaka_history,
maisaka_source_kind=maisaka_source_kind,
)
return {"success": result}
except Exception as exc:
@@ -107,6 +111,8 @@ class RuntimeCoreCapabilityMixin:
emoji_base64 = str(args.get("emoji_base64", ""))
stream_id = str(args.get("stream_id", ""))
sync_to_maisaka_history = bool(args.get("sync_to_maisaka_history", False))
maisaka_source_kind = str(args.get("maisaka_source_kind", "plugin_send") or "plugin_send")
if not emoji_base64 or not stream_id:
return {"success": False, "error": "缺少必要参数 emoji_base64 或 stream_id"}
@@ -115,6 +121,8 @@ class RuntimeCoreCapabilityMixin:
emoji_base64=emoji_base64,
stream_id=stream_id,
storage_message=bool(args.get("storage_message", True)),
sync_to_maisaka_history=sync_to_maisaka_history,
maisaka_source_kind=maisaka_source_kind,
)
return {"success": result}
except Exception as exc:
@@ -137,6 +145,8 @@ class RuntimeCoreCapabilityMixin:
image_base64 = str(args.get("image_base64", ""))
stream_id = str(args.get("stream_id", ""))
sync_to_maisaka_history = bool(args.get("sync_to_maisaka_history", False))
maisaka_source_kind = str(args.get("maisaka_source_kind", "plugin_send") or "plugin_send")
if not image_base64 or not stream_id:
return {"success": False, "error": "缺少必要参数 image_base64 或 stream_id"}
@@ -145,6 +155,8 @@ class RuntimeCoreCapabilityMixin:
image_base64=image_base64,
stream_id=stream_id,
storage_message=bool(args.get("storage_message", True)),
sync_to_maisaka_history=sync_to_maisaka_history,
maisaka_source_kind=maisaka_source_kind,
)
return {"success": result}
except Exception as exc:
@@ -167,6 +179,8 @@ class RuntimeCoreCapabilityMixin:
command = str(args.get("command", ""))
stream_id = str(args.get("stream_id", ""))
sync_to_maisaka_history = bool(args.get("sync_to_maisaka_history", False))
maisaka_source_kind = str(args.get("maisaka_source_kind", "plugin_send") or "plugin_send")
if not command or not stream_id:
return {"success": False, "error": "缺少必要参数 command 或 stream_id"}
@@ -177,6 +191,8 @@ class RuntimeCoreCapabilityMixin:
stream_id=stream_id,
storage_message=bool(args.get("storage_message", True)),
display_message=str(args.get("display_message", "")),
sync_to_maisaka_history=sync_to_maisaka_history,
maisaka_source_kind=maisaka_source_kind,
)
return {"success": result}
except Exception as exc:
@@ -202,6 +218,8 @@ class RuntimeCoreCapabilityMixin:
if content is None:
content = args.get("data", "")
stream_id = str(args.get("stream_id", ""))
sync_to_maisaka_history = bool(args.get("sync_to_maisaka_history", False))
maisaka_source_kind = str(args.get("maisaka_source_kind", "plugin_send") or "plugin_send")
if not message_type or not stream_id:
return {"success": False, "error": "缺少必要参数 message_type 或 stream_id"}
@@ -213,6 +231,8 @@ class RuntimeCoreCapabilityMixin:
display_message=str(args.get("display_message", "")),
typing=bool(args.get("typing", False)),
storage_message=bool(args.get("storage_message", True)),
sync_to_maisaka_history=sync_to_maisaka_history,
maisaka_source_kind=maisaka_source_kind,
)
return {"success": result}
except Exception as exc:

View File

@@ -4,16 +4,18 @@ import json
import time
import traceback
from datetime import datetime
from typing import Any, Optional, cast
from typing import TYPE_CHECKING, Any, Optional, cast
from sqlalchemy import delete, func, select
from sqlmodel import SQLModel
from sqlalchemy import delete, func
from sqlmodel import SQLModel, select
from src.chat.message_receive.chat_manager import BotChatSession
from src.common.database.database import get_db_session
from src.common.database.database_model import ToolRecord
from src.common.logger import get_logger
if TYPE_CHECKING:
from src.chat.message_receive.chat_manager import BotChatSession
logger = get_logger("database_service")
@@ -158,7 +160,7 @@ async def db_count(model_class: type[SQLModel], filters: Optional[dict[str, Any]
async def store_tool_info(
chat_stream: BotChatSession,
chat_stream: "BotChatSession",
builtin_prompt: Optional[str] = None,
display_prompt: str = "",
tool_id: str = "",
@@ -191,7 +193,7 @@ async def store_tool_info(
async def store_action_info(
chat_stream: BotChatSession,
chat_stream: "BotChatSession",
builtin_prompt: Optional[str] = None,
display_prompt: str = "",
thinking_id: str = "",

View File

@@ -2,54 +2,20 @@ from __future__ import annotations
import asyncio
import json
from typing import Any, Dict, List, Optional
from typing import Any, List, Optional
from json_repair import repair_json
from src.chat.utils.utils import is_bot_self
from src.common.message_repository import find_messages
from src.common.logger import get_logger
from src.common.message_repository import find_messages
from src.config.config import global_config
from src.memory_system.chat_history_summarizer import ChatHistorySummarizer
from src.person_info.person_info import Person, get_person_id, store_person_memory_from_answer
from src.services.llm_service import LLMServiceClient
logger = get_logger("memory_flow_service")
class LongTermMemorySessionManager:
def __init__(self) -> None:
self._lock = asyncio.Lock()
self._summarizers: Dict[str, ChatHistorySummarizer] = {}
async def on_message(self, message: Any) -> None:
if not bool(getattr(global_config.memory, "long_term_auto_summary_enabled", True)):
return
session_id = str(getattr(message, "session_id", "") or "").strip()
if not session_id:
return
created = False
async with self._lock:
summarizer = self._summarizers.get(session_id)
if summarizer is None:
summarizer = ChatHistorySummarizer(session_id=session_id)
self._summarizers[session_id] = summarizer
created = True
if created:
await summarizer.start()
async def shutdown(self) -> None:
async with self._lock:
items = list(self._summarizers.items())
self._summarizers.clear()
for session_id, summarizer in items:
try:
await summarizer.stop()
except Exception as exc:
logger.warning("停止聊天总结器失败: session=%s err=%s", session_id, exc)
class PersonFactWritebackService:
def __init__(self) -> None:
self._queue: asyncio.Queue[Any] = asyncio.Queue(maxsize=256)
@@ -123,7 +89,11 @@ class PersonFactWritebackService:
if not session_id:
return
person_name = str(getattr(target_person, "person_name", "") or getattr(target_person, "nickname", "") or "").strip()
person_name = str(
getattr(target_person, "person_name", "")
or getattr(target_person, "nickname", "")
or ""
).strip()
if not person_name:
return
@@ -242,7 +212,6 @@ class PersonFactWritebackService:
class MemoryAutomationService:
def __init__(self) -> None:
self.session_manager = LongTermMemorySessionManager()
self.fact_writeback = PersonFactWritebackService()
self._started = False
@@ -255,15 +224,9 @@ class MemoryAutomationService:
async def shutdown(self) -> None:
if not self._started:
return
await self.session_manager.shutdown()
await self.fact_writeback.shutdown()
self._started = False
async def on_incoming_message(self, message: Any) -> None:
if not self._started:
await self.start()
await self.session_manager.on_message(message)
async def on_message_sent(self, message: Any) -> None:
if not self._started:
await self.start()

View File

@@ -707,6 +707,28 @@ async def _notify_memory_automation_on_message_sent(message: SessionMessage) ->
logger.warning(f"[{session_id}] 长期记忆人物事实写回注册失败: {exc}")
def _sync_sent_message_to_maisaka_history(
message: SessionMessage,
*,
source_kind: str,
) -> None:
"""将已发送成功的消息同步到当前会话对应的 Maisaka 历史。"""
session_id = str(message.session_id or "").strip()
if not session_id:
return
try:
from src.chat.heart_flow.heartflow_manager import heartflow_manager
runtime = heartflow_manager.heartflow_chat_list.get(session_id)
if runtime is None:
return
runtime.append_sent_message_to_chat_history(message, source_kind=source_kind)
except Exception as exc:
logger.warning(f"[SendService] 同步消息到 Maisaka 历史失败: session_id={session_id} error={exc}")
def _log_platform_io_failures(delivery_batch: DeliveryBatch) -> None:
"""输出 Platform IO 批量发送失败详情。
@@ -837,13 +859,15 @@ async def send_session_message_with_message(
reply_message_id: Optional[str] = None,
storage_message: bool = True,
show_log: bool = True,
sync_to_maisaka_history: bool = False,
maisaka_source_kind: str = "outbound_send",
) -> Optional[SessionMessage]:
"""统一发送一条内部消息,并返回最终发送成功的消息对象。"""
if not message.message_id:
logger.error("[SendService] 消息缺少 message_id无法发送")
raise ValueError("消息缺少 message_id无法发送")
return await _send_via_platform_io(
sent_message = await _send_via_platform_io(
message,
typing=typing,
set_reply=set_reply,
@@ -851,6 +875,12 @@ async def send_session_message_with_message(
storage_message=storage_message,
show_log=show_log,
)
if sent_message is not None and sync_to_maisaka_history:
_sync_sent_message_to_maisaka_history(
sent_message,
source_kind=str(maisaka_source_kind or "outbound_send"),
)
return sent_message
async def send_session_message(
@@ -861,6 +891,8 @@ async def send_session_message(
reply_message_id: Optional[str] = None,
storage_message: bool = True,
show_log: bool = True,
sync_to_maisaka_history: bool = False,
maisaka_source_kind: str = "outbound_send",
) -> bool:
"""统一发送一条内部消息。
@@ -893,6 +925,8 @@ async def send_session_message(
reply_message_id=reply_message_id,
storage_message=storage_message,
show_log=show_log,
sync_to_maisaka_history=sync_to_maisaka_history,
maisaka_source_kind=maisaka_source_kind,
)
is not None
)
@@ -908,6 +942,8 @@ async def _send_to_target(
storage_message: bool = True,
show_log: bool = True,
selected_expressions: Optional[List[int]] = None,
sync_to_maisaka_history: bool = False,
maisaka_source_kind: str = "outbound_send",
) -> bool:
"""向指定目标构建并发送消息,并返回是否发送成功。"""
return (
@@ -921,6 +957,8 @@ async def _send_to_target(
storage_message=storage_message,
show_log=show_log,
selected_expressions=selected_expressions,
sync_to_maisaka_history=sync_to_maisaka_history,
maisaka_source_kind=maisaka_source_kind,
)
is not None
)
@@ -936,6 +974,8 @@ async def _send_to_target_with_message(
storage_message: bool = True,
show_log: bool = True,
selected_expressions: Optional[List[int]] = None,
sync_to_maisaka_history: bool = False,
maisaka_source_kind: str = "outbound_send",
) -> Optional[SessionMessage]:
"""向指定目标构建并发送消息。
@@ -998,6 +1038,8 @@ async def _send_to_target_with_message(
reply_message_id=reply_message.message_id if reply_message is not None else None,
storage_message=storage_message,
show_log=show_log,
sync_to_maisaka_history=sync_to_maisaka_history,
maisaka_source_kind=maisaka_source_kind,
)
if sent_message is not None:
logger.debug(f"[SendService] 成功发送消息到 {stream_id}")
@@ -1019,6 +1061,8 @@ async def text_to_stream_with_message(
reply_message: Optional[MaiMessage] = None,
storage_message: bool = True,
selected_expressions: Optional[List[int]] = None,
sync_to_maisaka_history: bool = False,
maisaka_source_kind: str = "outbound_send",
) -> Optional[SessionMessage]:
"""向指定流发送文本消息,并返回发送成功后的消息对象。"""
return await _send_to_target_with_message(
@@ -1030,6 +1074,8 @@ async def text_to_stream_with_message(
reply_message=reply_message,
storage_message=storage_message,
selected_expressions=selected_expressions,
sync_to_maisaka_history=sync_to_maisaka_history,
maisaka_source_kind=maisaka_source_kind,
)
@@ -1041,6 +1087,8 @@ async def text_to_stream(
reply_message: Optional[MaiMessage] = None,
storage_message: bool = True,
selected_expressions: Optional[List[int]] = None,
sync_to_maisaka_history: bool = False,
maisaka_source_kind: str = "outbound_send",
) -> bool:
"""向指定流发送文本消息。
@@ -1065,6 +1113,8 @@ async def text_to_stream(
reply_message=reply_message,
storage_message=storage_message,
selected_expressions=selected_expressions,
sync_to_maisaka_history=sync_to_maisaka_history,
maisaka_source_kind=maisaka_source_kind,
)
is not None
)
@@ -1076,6 +1126,8 @@ async def emoji_to_stream_with_message(
storage_message: bool = True,
set_reply: bool = False,
reply_message: Optional[MaiMessage] = None,
sync_to_maisaka_history: bool = False,
maisaka_source_kind: str = "outbound_send",
) -> Optional[SessionMessage]:
"""向指定流发送表情消息,并返回发送成功后的消息对象。"""
return await _send_to_target_with_message(
@@ -1086,6 +1138,8 @@ async def emoji_to_stream_with_message(
storage_message=storage_message,
set_reply=set_reply,
reply_message=reply_message,
sync_to_maisaka_history=sync_to_maisaka_history,
maisaka_source_kind=maisaka_source_kind,
)
@@ -1095,6 +1149,8 @@ async def emoji_to_stream(
storage_message: bool = True,
set_reply: bool = False,
reply_message: Optional[MaiMessage] = None,
sync_to_maisaka_history: bool = False,
maisaka_source_kind: str = "outbound_send",
) -> bool:
"""向指定流发送表情消息。
@@ -1115,6 +1171,8 @@ async def emoji_to_stream(
storage_message=storage_message,
set_reply=set_reply,
reply_message=reply_message,
sync_to_maisaka_history=sync_to_maisaka_history,
maisaka_source_kind=maisaka_source_kind,
)
is not None
)
@@ -1126,6 +1184,8 @@ async def image_to_stream(
storage_message: bool = True,
set_reply: bool = False,
reply_message: Optional[MaiMessage] = None,
sync_to_maisaka_history: bool = False,
maisaka_source_kind: str = "outbound_send",
) -> bool:
"""向指定流发送图片消息。
@@ -1147,6 +1207,8 @@ async def image_to_stream(
storage_message=storage_message,
set_reply=set_reply,
reply_message=reply_message,
sync_to_maisaka_history=sync_to_maisaka_history,
maisaka_source_kind=maisaka_source_kind,
)
@@ -1160,6 +1222,8 @@ async def custom_to_stream(
set_reply: bool = False,
storage_message: bool = True,
show_log: bool = True,
sync_to_maisaka_history: bool = False,
maisaka_source_kind: str = "outbound_send",
) -> bool:
"""向指定流发送自定义类型消息。
@@ -1186,6 +1250,8 @@ async def custom_to_stream(
set_reply=set_reply,
storage_message=storage_message,
show_log=show_log,
sync_to_maisaka_history=sync_to_maisaka_history,
maisaka_source_kind=maisaka_source_kind,
)
@@ -1198,6 +1264,8 @@ async def custom_reply_set_to_stream(
set_reply: bool = False,
storage_message: bool = True,
show_log: bool = True,
sync_to_maisaka_history: bool = False,
maisaka_source_kind: str = "outbound_send",
) -> bool:
"""向指定流发送消息组件序列。
@@ -1223,4 +1291,6 @@ async def custom_reply_set_to_stream(
set_reply=set_reply,
storage_message=storage_message,
show_log=show_log,
sync_to_maisaka_history=sync_to_maisaka_history,
maisaka_source_kind=maisaka_source_kind,
)