feat;私聊回滚wait,添加记忆总结模型配置
This commit is contained in:
@@ -12,12 +12,24 @@ from src.services import llm_service as llm_api
|
||||
|
||||
def _fake_available_models() -> dict[str, TaskConfig]:
|
||||
return {
|
||||
"memory": TaskConfig(
|
||||
model_list=["memory-model"],
|
||||
max_tokens=512,
|
||||
temperature=0.4,
|
||||
selection_strategy="random",
|
||||
),
|
||||
"utils": TaskConfig(
|
||||
model_list=["utils-model"],
|
||||
max_tokens=256,
|
||||
temperature=0.5,
|
||||
selection_strategy="random",
|
||||
),
|
||||
"replyer": TaskConfig(
|
||||
model_list=["test-model"],
|
||||
model_list=["replyer-model"],
|
||||
max_tokens=128,
|
||||
temperature=0.7,
|
||||
selection_strategy="priority",
|
||||
)
|
||||
selection_strategy="random",
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@@ -35,7 +47,63 @@ def test_resolve_summary_model_config_uses_auto_list_when_summarization_missing(
|
||||
resolved = importer._resolve_summary_model_config()
|
||||
|
||||
assert resolved is not None
|
||||
assert resolved.model_list == ["test-model"]
|
||||
assert resolved.model_list == ["memory-model"]
|
||||
|
||||
|
||||
def test_resolve_summary_model_config_auto_falls_back_to_utils_then_planner(monkeypatch):
|
||||
importer = SummaryImporter(
|
||||
vector_store=None,
|
||||
graph_store=None,
|
||||
metadata_store=None,
|
||||
embedding_manager=None,
|
||||
plugin_config={},
|
||||
)
|
||||
|
||||
monkeypatch.setattr(
|
||||
llm_api,
|
||||
"get_available_models",
|
||||
lambda: {
|
||||
"utils": TaskConfig(model_list=["utils-model"]),
|
||||
"planner": TaskConfig(model_list=["planner-model"]),
|
||||
"replyer": TaskConfig(model_list=["replyer-model"]),
|
||||
},
|
||||
)
|
||||
resolved = importer._resolve_summary_model_config()
|
||||
assert resolved is not None
|
||||
assert resolved.model_list == ["utils-model"]
|
||||
|
||||
monkeypatch.setattr(
|
||||
llm_api,
|
||||
"get_available_models",
|
||||
lambda: {
|
||||
"planner": TaskConfig(model_list=["planner-model"]),
|
||||
"replyer": TaskConfig(model_list=["replyer-model"]),
|
||||
},
|
||||
)
|
||||
resolved = importer._resolve_summary_model_config()
|
||||
assert resolved is not None
|
||||
assert resolved.model_list == ["planner-model"]
|
||||
|
||||
|
||||
def test_resolve_summary_model_config_auto_does_not_fallback_to_replyer(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
llm_api,
|
||||
"get_available_models",
|
||||
lambda: {
|
||||
"replyer": TaskConfig(model_list=["replyer-model"]),
|
||||
"embedding": TaskConfig(model_list=["embedding-model"]),
|
||||
},
|
||||
)
|
||||
|
||||
importer = SummaryImporter(
|
||||
vector_store=None,
|
||||
graph_store=None,
|
||||
metadata_store=None,
|
||||
embedding_manager=None,
|
||||
plugin_config={},
|
||||
)
|
||||
|
||||
assert importer._resolve_summary_model_config() is None
|
||||
|
||||
|
||||
def test_resolve_summary_model_config_rejects_legacy_string_selector(monkeypatch):
|
||||
|
||||
@@ -4,7 +4,7 @@ from types import SimpleNamespace
|
||||
import asyncio
|
||||
import pytest
|
||||
|
||||
from src.core.tooling import ToolExecutionResult, ToolInvocation
|
||||
from src.core.tooling import ToolAvailabilityContext, ToolExecutionResult, ToolInvocation
|
||||
from src.llm_models.payload_content.tool_option import ToolCall
|
||||
from src.maisaka.builtin_tool import get_timing_tools
|
||||
from src.maisaka.chat_loop_service import ChatResponse, MaisakaChatLoopService
|
||||
@@ -33,20 +33,42 @@ def _build_chat_response(tool_calls: list[ToolCall]) -> ChatResponse:
|
||||
)
|
||||
|
||||
|
||||
def test_timing_gate_tools_only_expose_continue_and_no_reply() -> None:
|
||||
tool_names = {tool_definition["name"] for tool_definition in get_timing_tools()}
|
||||
def _build_runtime_stub(*, is_group_chat: bool) -> SimpleNamespace:
|
||||
return SimpleNamespace(
|
||||
_force_next_timing_continue=False,
|
||||
_chat_history=[],
|
||||
session_id="test-session",
|
||||
chat_stream=SimpleNamespace(
|
||||
session_id="test-session",
|
||||
stream_id="test-stream",
|
||||
is_group_session=is_group_chat,
|
||||
group_id="group-1" if is_group_chat else "",
|
||||
user_id="user-1",
|
||||
platform="qq",
|
||||
),
|
||||
_chat_loop_service=SimpleNamespace(build_prompt_template_context=lambda: {}),
|
||||
log_prefix="[test]",
|
||||
stopped=False,
|
||||
)
|
||||
|
||||
assert tool_names == {"continue", "no_reply"}
|
||||
|
||||
def test_timing_gate_tools_expose_wait_only_in_private_chat() -> None:
|
||||
private_tool_names = {
|
||||
tool_definition["name"]
|
||||
for tool_definition in get_timing_tools(ToolAvailabilityContext(is_group_chat=False))
|
||||
}
|
||||
group_tool_names = {
|
||||
tool_definition["name"]
|
||||
for tool_definition in get_timing_tools(ToolAvailabilityContext(is_group_chat=True))
|
||||
}
|
||||
|
||||
assert private_tool_names == {"continue", "no_reply", "wait"}
|
||||
assert group_tool_names == {"continue", "no_reply"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_timing_gate_invalid_tool_defaults_to_no_reply(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
runtime = SimpleNamespace(
|
||||
_force_next_timing_continue=False,
|
||||
_chat_history=[],
|
||||
log_prefix="[test]",
|
||||
stopped=False,
|
||||
)
|
||||
runtime = _build_runtime_stub(is_group_chat=True)
|
||||
|
||||
def _enter_stop_state() -> None:
|
||||
runtime.stopped = True
|
||||
@@ -90,12 +112,7 @@ async def test_timing_gate_invalid_tool_defaults_to_no_reply(monkeypatch: pytest
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_timing_gate_invalid_tool_retries_until_valid(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
runtime = SimpleNamespace(
|
||||
_force_next_timing_continue=False,
|
||||
_chat_history=[],
|
||||
log_prefix="[test]",
|
||||
stopped=False,
|
||||
)
|
||||
runtime = _build_runtime_stub(is_group_chat=True)
|
||||
|
||||
def _enter_stop_state() -> None:
|
||||
runtime.stopped = True
|
||||
@@ -148,6 +165,37 @@ async def test_timing_gate_invalid_tool_retries_until_valid(monkeypatch: pytest.
|
||||
assert tool_monitor_results[0]["tool_name"] == "continue"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_timing_gate_group_chat_treats_wait_as_invalid(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
runtime = _build_runtime_stub(is_group_chat=True)
|
||||
|
||||
def _enter_stop_state() -> None:
|
||||
runtime.stopped = True
|
||||
|
||||
runtime._enter_stop_state = _enter_stop_state
|
||||
engine = MaisakaReasoningEngine(runtime) # type: ignore[arg-type]
|
||||
|
||||
async def _fake_timing_gate_sub_agent(**kwargs: object) -> ChatResponse:
|
||||
tool_definitions = kwargs["tool_definitions"]
|
||||
assert {tool_definition["name"] for tool_definition in tool_definitions} == {"continue", "no_reply"}
|
||||
return _build_chat_response([
|
||||
ToolCall(call_id="disabled-wait", func_name="wait", args={"seconds": 3}),
|
||||
])
|
||||
|
||||
async def _fail_invoke_tool_call(*args: object, **kwargs: object) -> None:
|
||||
del args, kwargs
|
||||
raise AssertionError("群聊中禁用的 wait 不应被执行")
|
||||
|
||||
monkeypatch.setattr(engine, "_run_timing_gate_sub_agent", _fake_timing_gate_sub_agent)
|
||||
monkeypatch.setattr(engine, "_invoke_tool_call", _fail_invoke_tool_call)
|
||||
|
||||
action, _, tool_results, _ = await engine._run_timing_gate(object()) # type: ignore[arg-type]
|
||||
|
||||
assert action == "no_reply"
|
||||
assert runtime.stopped is True
|
||||
assert tool_results[-1] == "- no_reply [非法 Timing 工具]: 返回了 wait,已停止本轮并等待新消息"
|
||||
|
||||
|
||||
def test_timing_gate_invalid_tool_hint_keeps_only_latest() -> None:
|
||||
old_hint = SimpleNamespace(source=TIMING_GATE_INVALID_TOOL_HINT_SOURCE)
|
||||
runtime = SimpleNamespace(_chat_history=[old_hint])
|
||||
@@ -184,6 +232,7 @@ def test_timing_gate_invalid_tool_hint_only_visible_to_timing_gate() -> None:
|
||||
|
||||
def test_forced_timing_trigger_bypasses_message_frequency_threshold() -> None:
|
||||
runtime = SimpleNamespace(
|
||||
_STATE_WAIT="wait",
|
||||
_agent_state="stop",
|
||||
_message_turn_scheduled=False,
|
||||
_internal_turn_queue=asyncio.Queue(),
|
||||
|
||||
Reference in New Issue
Block a user