Prompt管理器与测试;调整部分方法名称

This commit is contained in:
UnCLAS-Prommer
2026-01-20 14:00:15 +08:00
parent 3a66bfeac1
commit c15b77907e
3 changed files with 768 additions and 4 deletions

View File

@@ -0,0 +1,607 @@
import asyncio
import sys
from collections.abc import Callable
from pathlib import Path
from typing import Optional
import pytest
PROJECT_ROOT: Path = Path(__file__).parent.parent.parent.absolute().resolve()
sys.path.insert(0, str(PROJECT_ROOT))
sys.path.insert(0, str(PROJECT_ROOT / "src" / "config"))
from src.prompt.prompt_manager import PromptManager
# --- Minimal stubs / constants matching the production module ---
# These imports/definitions are here only to make the tests selfcontained
# In the real project, they already exist in `prompt_manager.py`'s module.
# We mirror them here to control behavior via monkeypatch.
class Prompt:
def __init__(self, prompt_name: str, template: str, prompt_render_context: Optional[dict[str, Callable]] = None):
self.prompt_name = prompt_name
self.template = template
self.prompt_render_context = prompt_render_context or {}
class DummyLogger:
def __init__(self):
self.errors: list[str] = []
self.warnings: list[str] = []
def error(self, msg: str) -> None:
self.errors.append(msg)
def warning(self, msg: str) -> None:
self.warnings.append(msg)
# --- Fixtures to patch module-level objects in prompt_manager ---
@pytest.fixture
def dummy_logger(monkeypatch):
from src.prompt import prompt_manager as pm
logger = DummyLogger()
monkeypatch.setattr(pm, "logger", logger, raising=False)
return logger
@pytest.fixture
def temp_prompts_dir(tmp_path, monkeypatch):
from src.prompt import prompt_manager as pm
prompts_dir = tmp_path / "prompts"
monkeypatch.setattr(pm, "PROMPTS_DIR", prompts_dir, raising=False)
return prompts_dir
@pytest.fixture
def brace_constants(monkeypatch):
from src.prompt import prompt_manager as pm
# emulate the placeholders used in the manager
monkeypatch.setattr(pm, "_LEFT_BRACE", "__LEFT__", raising=False)
monkeypatch.setattr(pm, "_RIGHT_BRACE", "__RIGHT__", raising=False)
@pytest.fixture
def suffix_prompt(monkeypatch):
from src.prompt import prompt_manager as pm
monkeypatch.setattr(pm, "SUFFIX_PROMPT", ".prompt", raising=False)
@pytest.fixture
def manager(temp_prompts_dir, brace_constants, suffix_prompt):
# PromptManager.__init__ uses patched PROMPTS_DIR
return PromptManager()
# --- Helper to run async methods in tests (for non-async pytest) ---
def run(coro):
return asyncio.get_event_loop().run_until_complete(coro)
# --- add_prompt tests --------------------------------------------------------
@pytest.mark.parametrize(
"existing_prompts, existing_funcs, name_to_add, need_save, expect_in_save",
[
pytest.param(
{},
{},
"greeting",
False,
False,
id="add_prompt_simple_not_saved",
),
pytest.param(
{},
{},
"system",
True,
True,
id="add_prompt_marked_for_save",
),
pytest.param(
{"existing": Prompt("existing", "tmpl")},
{},
"new",
True,
True,
id="add_prompt_with_existing_other_prompt",
),
],
)
def test_add_prompt_happy_path(manager, existing_prompts, existing_funcs, name_to_add, need_save, expect_in_save):
# Arrange
manager.prompts.update(existing_prompts)
manager._context_construct_functions.update(existing_funcs)
prompt = Prompt(name_to_add, "template")
# Act
manager.add_prompt(prompt, need_save=need_save)
# Assert
assert manager.prompts[name_to_add] is prompt
assert (name_to_add in manager._prompt_to_save) is expect_in_save
@pytest.mark.parametrize(
"existing_prompts, existing_funcs, new_name, conflict_type",
[
pytest.param(
{"dup": Prompt("dup", "tmpl")},
{},
"dup",
"prompt_conflict",
id="add_prompt_conflict_with_existing_prompt",
),
pytest.param(
{},
{"dup": (lambda x: x, "mod")},
"dup",
"func_conflict",
id="add_prompt_conflict_with_existing_context_function",
),
],
)
def test_add_prompt_conflict_raises_key_error(manager, existing_prompts, existing_funcs, new_name, conflict_type):
# Arrange
manager.prompts.update(existing_prompts)
manager._context_construct_functions.update(existing_funcs)
prompt = Prompt(new_name, "template")
# Act / Assert
with pytest.raises(KeyError) as exc:
manager.add_prompt(prompt)
assert new_name in str(exc.value)
# --- add_context_construct_function tests -----------------------------------
def test_add_context_construct_function_happy_path(manager):
# Arrange
def builder(prompt_name: str) -> str:
return f"ctx_for_{prompt_name}"
# Act
manager.add_context_construct_function("ctx", builder)
# Assert
assert "ctx" in manager._context_construct_functions
stored_func, module = manager._context_construct_functions["ctx"]
assert stored_func is builder
# module is caller's module name
assert isinstance(module, str)
assert module != ""
def test_add_context_construct_function_logs_unknown_module(manager, dummy_logger, monkeypatch):
# Arrange
def builder(prompt_name: str) -> str:
return f"v_{prompt_name}"
def fake_currentframe():
class FakeCallerFrame:
f_globals = {"__name__": "unknown"}
class FakeFrame:
f_back = FakeCallerFrame()
return FakeFrame()
from src.prompt import prompt_manager as pm
monkeypatch.setattr(pm.inspect, "currentframe", fake_currentframe)
# Act
manager.add_context_construct_function("unknown_ctx", builder)
# Assert
assert any("无法获取调用函数的模块名" in msg for msg in dummy_logger.warnings)
assert "unknown_ctx" in manager._context_construct_functions
@pytest.mark.parametrize(
"existing_prompts, existing_funcs, name_to_add",
[
pytest.param(
{"p": Prompt("p", "tmpl")},
{},
"p",
id="add_context_construct_function_conflict_with_prompt",
),
pytest.param(
{},
{"f": (lambda x: x, "mod")},
"f",
id="add_context_construct_function_conflict_with_existing_func",
),
],
)
def test_add_context_construct_function_conflict_raises_key_error(
manager, existing_prompts, existing_funcs, name_to_add
):
# Arrange
manager.prompts.update(existing_prompts)
manager._context_construct_functions.update(existing_funcs)
def func(prompt_name: str) -> str:
return "x"
# Act / Assert
with pytest.raises(KeyError) as exc:
manager.add_context_construct_function(name_to_add, func)
assert name_to_add in str(exc.value)
def test_add_context_construct_function_no_frame_raises_runtime_error(manager, monkeypatch):
# Arrange
from src.prompt import prompt_manager as pm
monkeypatch.setattr(pm.inspect, "currentframe", lambda: None)
def func(prompt_name: str) -> str:
return "x"
# Act / Assert
with pytest.raises(RuntimeError) as exc:
manager.add_context_construct_function("ctx", func)
assert "无法获取调用栈" in str(exc.value)
def test_add_context_construct_function_no_caller_frame_raises_runtime_error(manager, monkeypatch):
# Arrange
from src.prompt import prompt_manager as pm
class FakeFrame:
f_back = None
monkeypatch.setattr(pm.inspect, "currentframe", lambda: FakeFrame())
def func(prompt_name: str) -> str:
return "x"
# Act / Assert
with pytest.raises(RuntimeError) as exc:
manager.add_context_construct_function("ctx", func)
assert "无法获取调用栈的上一级" in str(exc.value)
# --- get_prompt tests --------------------------------------------------------
@pytest.mark.parametrize(
"existing_name, requested_name, should_raise",
[
pytest.param("p1", "p1", False, id="get_existing_prompt"),
pytest.param("p1", "missing", True, id="get_missing_prompt_raises"),
],
)
def test_get_prompt(manager, existing_name, requested_name, should_raise):
# Arrange
manager.prompts[existing_name] = Prompt(existing_name, "tmpl")
# Act / Assert
if should_raise:
with pytest.raises(KeyError) as exc:
manager.get_prompt(requested_name)
assert requested_name in str(exc.value)
else:
prompt = manager.get_prompt(requested_name)
assert prompt.prompt_name == existing_name
# --- render_prompt and _render tests ----------------------------------------
@pytest.mark.parametrize(
"template, prompts_setup, ctx_funcs_setup, prompt_ctx, expected",
[
pytest.param(
"Hello {name}",
{},
{},
{"name": lambda p: "World"},
"Hello World",
id="render_with_prompt_context_sync",
),
pytest.param(
"Hello {name}",
{},
{},
{
"name": lambda p: asyncio.sleep(0, result=f"Async-{p}"),
},
"Hello Async-main",
id="render_with_prompt_context_async",
),
pytest.param(
"Outer {inner}",
{
"inner": Prompt("inner", "Inner {value}", {"value": lambda p: "42"}),
},
{},
{},
"Outer Inner 42",
id="render_with_nested_prompt_reference",
),
pytest.param(
"Module says {ext}",
{},
{
"ext": (lambda p: f"external-{p}", "test_module"),
},
{},
"Module says external-main",
id="render_with_external_context_function_sync",
),
pytest.param(
"Module async {ext}",
{},
{
"ext": (lambda p: asyncio.sleep(0, result=f"ext_async-{p}"), "test_module"),
},
{},
"Module async ext_async-main",
id="render_with_external_context_function_async",
),
pytest.param(
"Escaped {{ and }} literal plus {value}",
{},
{},
{"value": lambda p: "X"},
"Escaped { and } literal plus X",
id="render_with_escaped_braces",
),
],
)
def test_render_prompt_happy_path(
manager,
template,
prompts_setup,
ctx_funcs_setup,
prompt_ctx,
expected,
):
# Arrange
main_prompt = Prompt("main", template, prompt_ctx)
manager.add_prompt(main_prompt)
for name, prompt in prompts_setup.items():
manager.add_prompt(prompt)
manager._context_construct_functions.update(ctx_funcs_setup)
# Act
rendered = run(manager.render_prompt(main_prompt))
# Assert
assert rendered == expected
def test_render_prompt_missing_field_raises_key_error(manager):
# Arrange
prompt = Prompt("main", "Hello {missing}")
manager.add_prompt(prompt)
# Act / Assert
with pytest.raises(KeyError) as exc:
run(manager.render_prompt(prompt))
assert "缺少必要的内容块或构建函数" in str(exc.value)
assert "missing" in str(exc.value)
def test_render_prompt_recursion_limit_exceeded(manager):
# Arrange
# Create mutual recursion between two prompts
p1 = Prompt("p1", "P1 uses {p2}")
p2 = Prompt("p2", "P2 uses {p1}")
manager.add_prompt(p1)
manager.add_prompt(p2)
# Act / Assert
with pytest.raises(RecursionError):
run(manager.render_prompt(p1))
# --- _get_function_result tests ---------------------------------------------
@pytest.mark.parametrize(
"func, is_prompt_context, expect_async",
[
pytest.param(
lambda p: f"sync_{p}",
True,
False,
id="get_function_result_sync_prompt_context",
),
pytest.param(
lambda p: asyncio.sleep(0, result=f"async_{p}"),
False,
True,
id="get_function_result_async_external_context",
),
],
)
def test_get_function_result_happy_path(manager, dummy_logger, func, is_prompt_context, expect_async):
# Act
res = run(
manager._get_function_result(
func=func,
prompt_name="prompt",
field_name="f",
is_prompt_context=is_prompt_context,
module="mod",
)
)
# Assert
assert res in {"sync_prompt", "async_prompt"}
@pytest.mark.parametrize(
"is_prompt_context, expected_message_part",
[
pytest.param(True, "内部上下文构造函数", id="get_function_result_error_prompt_context_logs_internal_msg"),
pytest.param(False, "上下文构造函数", id="get_function_result_error_external_logs_external_msg"),
],
)
def test_get_function_result_error_logging(manager, dummy_logger, is_prompt_context, expected_message_part):
# Arrange
def bad_func(prompt_name: str) -> str:
raise ValueError("bad")
# Act / Assert
with pytest.raises(ValueError):
run(
manager._get_function_result(
func=bad_func,
prompt_name="promptX",
field_name="fieldX",
is_prompt_context=is_prompt_context,
module="modX",
)
)
# Assert
assert any(expected_message_part in msg for msg in dummy_logger.errors)
assert any("promptX" in msg for msg in dummy_logger.errors) ^ (not is_prompt_context)
assert any("modX" in msg for msg in dummy_logger.errors) ^ is_prompt_context
assert any("fieldX" in msg for msg in dummy_logger.errors)
# --- save_prompts tests ------------------------------------------------------
def test_save_prompts_happy_path(manager, temp_prompts_dir):
# Arrange
p1 = Prompt("p1", "Hello {{name}}")
p2 = Prompt("p2", "Bye {{value}}")
manager.add_prompt(p1, need_save=True)
manager.add_prompt(p2, need_save=True)
# Act
manager.save_prompts()
# Assert
files = sorted(temp_prompts_dir.glob("*.prompt"))
assert len(files) == 2
contents = {f.stem: f.read_text(encoding="utf-8") for f in files}
assert contents["p1"] == "Hello {{name}}"
assert contents["p2"] == "Bye {{value}}"
def test_save_prompts_io_error(manager, temp_prompts_dir, dummy_logger, monkeypatch):
# Arrange
prompt = Prompt("p1", "Hi")
manager.add_prompt(prompt, need_save=True)
def bad_open(*args, **kwargs):
raise OSError("disk full")
monkeypatch.setattr("builtins.open", bad_open)
# Act / Assert
with pytest.raises(OSError):
manager.save_prompts()
# Assert
assert any("保存 Prompt 'p1' 时出错" in msg for msg in dummy_logger.errors)
# --- load_prompts tests ------------------------------------------------------
def test_load_prompts_happy_path(manager, temp_prompts_dir):
# Arrange
file1 = temp_prompts_dir / "greet.prompt"
file2 = temp_prompts_dir / "farewell.prompt"
temp_prompts_dir.mkdir(parents=True, exist_ok=True)
file1.write_text("Hello {{name}}", encoding="utf-8")
file2.write_text("Bye {{name}}", encoding="utf-8")
# Act
manager.load_prompts()
# Assert
assert "greet" in manager.prompts
assert "farewell" in manager.prompts
assert "greet" in manager._prompt_to_save
assert "farewell" in manager._prompt_to_save
assert manager.prompts["greet"].template == "Hello {{name}}"
assert manager.prompts["farewell"].template == "Bye {{name}}"
def test_load_prompts_error(manager, temp_prompts_dir, dummy_logger, monkeypatch):
# Arrange
file1 = temp_prompts_dir / "broken.prompt"
temp_prompts_dir.mkdir(parents=True, exist_ok=True)
file1.write_text("whatever", encoding="utf-8")
def bad_open(*args, **kwargs):
raise OSError("cannot read")
monkeypatch.setattr("builtins.open", bad_open)
# Act / Assert
with pytest.raises(OSError):
manager.load_prompts()
# Assert
assert any("加载 Prompt 文件" in msg for msg in dummy_logger.errors)