feat;私聊回滚wait,添加记忆总结模型配置

This commit is contained in:
SengokuCola
2026-05-08 02:21:27 +08:00
parent 197351d469
commit 7bdbdec157
13 changed files with 371 additions and 92 deletions

View File

@@ -12,12 +12,24 @@ from src.services import llm_service as llm_api
def _fake_available_models() -> dict[str, TaskConfig]:
return {
"memory": TaskConfig(
model_list=["memory-model"],
max_tokens=512,
temperature=0.4,
selection_strategy="random",
),
"utils": TaskConfig(
model_list=["utils-model"],
max_tokens=256,
temperature=0.5,
selection_strategy="random",
),
"replyer": TaskConfig(
model_list=["test-model"],
model_list=["replyer-model"],
max_tokens=128,
temperature=0.7,
selection_strategy="priority",
)
selection_strategy="random",
),
}
@@ -35,7 +47,63 @@ def test_resolve_summary_model_config_uses_auto_list_when_summarization_missing(
resolved = importer._resolve_summary_model_config()
assert resolved is not None
assert resolved.model_list == ["test-model"]
assert resolved.model_list == ["memory-model"]
def test_resolve_summary_model_config_auto_falls_back_to_utils_then_planner(monkeypatch):
importer = SummaryImporter(
vector_store=None,
graph_store=None,
metadata_store=None,
embedding_manager=None,
plugin_config={},
)
monkeypatch.setattr(
llm_api,
"get_available_models",
lambda: {
"utils": TaskConfig(model_list=["utils-model"]),
"planner": TaskConfig(model_list=["planner-model"]),
"replyer": TaskConfig(model_list=["replyer-model"]),
},
)
resolved = importer._resolve_summary_model_config()
assert resolved is not None
assert resolved.model_list == ["utils-model"]
monkeypatch.setattr(
llm_api,
"get_available_models",
lambda: {
"planner": TaskConfig(model_list=["planner-model"]),
"replyer": TaskConfig(model_list=["replyer-model"]),
},
)
resolved = importer._resolve_summary_model_config()
assert resolved is not None
assert resolved.model_list == ["planner-model"]
def test_resolve_summary_model_config_auto_does_not_fallback_to_replyer(monkeypatch):
monkeypatch.setattr(
llm_api,
"get_available_models",
lambda: {
"replyer": TaskConfig(model_list=["replyer-model"]),
"embedding": TaskConfig(model_list=["embedding-model"]),
},
)
importer = SummaryImporter(
vector_store=None,
graph_store=None,
metadata_store=None,
embedding_manager=None,
plugin_config={},
)
assert importer._resolve_summary_model_config() is None
def test_resolve_summary_model_config_rejects_legacy_string_selector(monkeypatch):

View File

@@ -4,7 +4,7 @@ from types import SimpleNamespace
import asyncio
import pytest
from src.core.tooling import ToolExecutionResult, ToolInvocation
from src.core.tooling import ToolAvailabilityContext, ToolExecutionResult, ToolInvocation
from src.llm_models.payload_content.tool_option import ToolCall
from src.maisaka.builtin_tool import get_timing_tools
from src.maisaka.chat_loop_service import ChatResponse, MaisakaChatLoopService
@@ -33,20 +33,42 @@ def _build_chat_response(tool_calls: list[ToolCall]) -> ChatResponse:
)
def test_timing_gate_tools_only_expose_continue_and_no_reply() -> None:
tool_names = {tool_definition["name"] for tool_definition in get_timing_tools()}
def _build_runtime_stub(*, is_group_chat: bool) -> SimpleNamespace:
return SimpleNamespace(
_force_next_timing_continue=False,
_chat_history=[],
session_id="test-session",
chat_stream=SimpleNamespace(
session_id="test-session",
stream_id="test-stream",
is_group_session=is_group_chat,
group_id="group-1" if is_group_chat else "",
user_id="user-1",
platform="qq",
),
_chat_loop_service=SimpleNamespace(build_prompt_template_context=lambda: {}),
log_prefix="[test]",
stopped=False,
)
assert tool_names == {"continue", "no_reply"}
def test_timing_gate_tools_expose_wait_only_in_private_chat() -> None:
private_tool_names = {
tool_definition["name"]
for tool_definition in get_timing_tools(ToolAvailabilityContext(is_group_chat=False))
}
group_tool_names = {
tool_definition["name"]
for tool_definition in get_timing_tools(ToolAvailabilityContext(is_group_chat=True))
}
assert private_tool_names == {"continue", "no_reply", "wait"}
assert group_tool_names == {"continue", "no_reply"}
@pytest.mark.asyncio
async def test_timing_gate_invalid_tool_defaults_to_no_reply(monkeypatch: pytest.MonkeyPatch) -> None:
runtime = SimpleNamespace(
_force_next_timing_continue=False,
_chat_history=[],
log_prefix="[test]",
stopped=False,
)
runtime = _build_runtime_stub(is_group_chat=True)
def _enter_stop_state() -> None:
runtime.stopped = True
@@ -90,12 +112,7 @@ async def test_timing_gate_invalid_tool_defaults_to_no_reply(monkeypatch: pytest
@pytest.mark.asyncio
async def test_timing_gate_invalid_tool_retries_until_valid(monkeypatch: pytest.MonkeyPatch) -> None:
runtime = SimpleNamespace(
_force_next_timing_continue=False,
_chat_history=[],
log_prefix="[test]",
stopped=False,
)
runtime = _build_runtime_stub(is_group_chat=True)
def _enter_stop_state() -> None:
runtime.stopped = True
@@ -148,6 +165,37 @@ async def test_timing_gate_invalid_tool_retries_until_valid(monkeypatch: pytest.
assert tool_monitor_results[0]["tool_name"] == "continue"
@pytest.mark.asyncio
async def test_timing_gate_group_chat_treats_wait_as_invalid(monkeypatch: pytest.MonkeyPatch) -> None:
runtime = _build_runtime_stub(is_group_chat=True)
def _enter_stop_state() -> None:
runtime.stopped = True
runtime._enter_stop_state = _enter_stop_state
engine = MaisakaReasoningEngine(runtime) # type: ignore[arg-type]
async def _fake_timing_gate_sub_agent(**kwargs: object) -> ChatResponse:
tool_definitions = kwargs["tool_definitions"]
assert {tool_definition["name"] for tool_definition in tool_definitions} == {"continue", "no_reply"}
return _build_chat_response([
ToolCall(call_id="disabled-wait", func_name="wait", args={"seconds": 3}),
])
async def _fail_invoke_tool_call(*args: object, **kwargs: object) -> None:
del args, kwargs
raise AssertionError("群聊中禁用的 wait 不应被执行")
monkeypatch.setattr(engine, "_run_timing_gate_sub_agent", _fake_timing_gate_sub_agent)
monkeypatch.setattr(engine, "_invoke_tool_call", _fail_invoke_tool_call)
action, _, tool_results, _ = await engine._run_timing_gate(object()) # type: ignore[arg-type]
assert action == "no_reply"
assert runtime.stopped is True
assert tool_results[-1] == "- no_reply [非法 Timing 工具]: 返回了 wait已停止本轮并等待新消息"
def test_timing_gate_invalid_tool_hint_keeps_only_latest() -> None:
old_hint = SimpleNamespace(source=TIMING_GATE_INVALID_TOOL_HINT_SOURCE)
runtime = SimpleNamespace(_chat_history=[old_hint])
@@ -184,6 +232,7 @@ def test_timing_gate_invalid_tool_hint_only_visible_to_timing_gate() -> None:
def test_forced_timing_trigger_bypasses_message_frequency_threshold() -> None:
runtime = SimpleNamespace(
_STATE_WAIT="wait",
_agent_state="stop",
_message_turn_scheduled=False,
_internal_turn_queue=asyncio.Queue(),