更好的Prompt管理系统,增加用户自定义Prompt与覆盖功能

This commit is contained in:
UnCLAS-Prommer
2026-02-02 20:53:42 +08:00
parent 0d0f5a9cdb
commit b793a3d62b
3 changed files with 402 additions and 46 deletions

View File

@@ -1,4 +1,4 @@
# File: tests/test_prompt_manager.py
# File: pytests/prompt_test/test_prompt_manager.py
import asyncio
import inspect
@@ -12,7 +12,15 @@ PROJECT_ROOT: Path = Path(__file__).parent.parent.parent.absolute().resolve()
sys.path.insert(0, str(PROJECT_ROOT))
sys.path.insert(0, str(PROJECT_ROOT / "src" / "config"))
from src.prompt.prompt_manager import SUFFIX_PROMPT, Prompt, PromptManager, prompt_manager # noqa
from src.prompt.prompt_manager import ( # noqa
SUFFIX_PROMPT,
Prompt,
PromptManager,
prompt_manager,
)
# ========= Prompt 基础行为 =========
@pytest.mark.parametrize(
@@ -20,7 +28,11 @@ from src.prompt.prompt_manager import SUFFIX_PROMPT, Prompt, PromptManager, prom
[
pytest.param("simple", "Hello {name}", id="simple-template-with-field"),
pytest.param("no-fields", "Just a static template", id="template-without-fields"),
pytest.param("brace-escaping", "Use {{ and }} around {field}", id="template-with-escaped-braces"),
pytest.param(
"brace-escaping",
"Use {{ and }} around {field}",
id="template-with-escaped-braces",
),
],
)
def test_prompt_init_happy_paths(prompt_name: str, template: str):
@@ -53,7 +65,12 @@ def test_prompt_init_happy_paths(prompt_name: str, template: str):
),
],
)
def test_prompt_init_error_cases(prompt_name, template, expected_exception, expected_msg_substring):
def test_prompt_init_error_cases(
prompt_name,
template,
expected_exception,
expected_msg_substring,
):
# Act / Assert
with pytest.raises(expected_exception) as exc_info:
Prompt(prompt_name=prompt_name, template=template)
@@ -123,6 +140,25 @@ def test_prompt_add_context(
assert result == expected_value
def test_prompt_clone_independent_instance():
# Arrange
prompt = Prompt(prompt_name="p", template="T {x}")
prompt.add_context("x", "X")
# Act
cloned = prompt.clone()
# Assert
assert cloned is not prompt
assert cloned.prompt_name == prompt.prompt_name
assert cloned.template == prompt.template
# 当前实现 clone 不复制 context
assert cloned.prompt_render_context == {}
# ========= PromptManager添加/获取/删除/替换 =========
def test_prompt_manager_add_prompt_happy_and_error():
# Arrange
manager = PromptManager()
@@ -147,6 +183,59 @@ def test_prompt_manager_add_prompt_happy_and_error():
# Assert
assert "Prompt name 'p1' 已存在" in str(exc_info.value)
def test_prompt_manager_remove_prompt_happy_and_error():
# Arrange
manager = PromptManager()
p1 = Prompt(prompt_name="p1", template="T")
manager.add_prompt(p1, need_save=True)
# Act
manager.remove_prompt("p1")
# Assert
assert "p1" not in manager.prompts
assert "p1" not in manager._prompt_to_save
# Act / Assert
with pytest.raises(KeyError) as exc_info:
manager.remove_prompt("no_such")
assert "Prompt name 'no_such' 不存在" in str(exc_info.value)
def test_prompt_manager_replace_prompt_happy_and_error():
# sourcery skip: extract-duplicate-method
# Arrange
manager = PromptManager()
p1 = Prompt(prompt_name="p", template="Old")
manager.add_prompt(p1, need_save=True)
p_new = Prompt(prompt_name="p", template="New")
# Act: 替换且保持 need_save
manager.replace_prompt(p_new, need_save=True)
# Assert
assert manager.prompts["p"].template == "New"
assert "p" in manager._prompt_to_save
# Act: 再次替换,且不需要保存
p_new2 = Prompt(prompt_name="p", template="New2")
manager.replace_prompt(p_new2, need_save=False)
# Assert
assert manager.prompts["p"].template == "New2"
assert "p" not in manager._prompt_to_save
# Error: 不存在的 prompt
p_unknown = Prompt(prompt_name="unknown", template="T")
with pytest.raises(KeyError) as exc_info:
manager.replace_prompt(p_unknown)
assert "Prompt name 'unknown' 不存在,无法替换" in str(exc_info.value)
def test_prompt_manager_get_prompt_is_copy():
# Arrange
manager = PromptManager()
@@ -162,6 +251,7 @@ def test_prompt_manager_get_prompt_is_copy():
assert retrieved_prompt.template == prompt.template
assert retrieved_prompt.prompt_render_context == prompt.prompt_render_context
def test_prompt_manager_add_prompt_conflict_with_context_name():
# Arrange
manager = PromptManager()
@@ -230,6 +320,9 @@ def test_prompt_manager_get_prompt_not_exist():
assert "Prompt name 'no_such_prompt' 不存在" in str(exc_info.value)
# ========= 渲染逻辑 =========
@pytest.mark.parametrize(
"template, inner_context, global_context, expected, case_id",
[
@@ -264,7 +357,13 @@ def test_prompt_manager_get_prompt_not_exist():
],
)
@pytest.mark.asyncio
async def test_prompt_manager_render_contexts(template, inner_context, global_context, expected, case_id):
async def test_prompt_manager_render_contexts(
template,
inner_context,
global_context,
expected,
case_id,
):
# Arrange
manager = PromptManager()
tmp_prompt = Prompt(prompt_name="main", template=template)
@@ -274,7 +373,6 @@ async def test_prompt_manager_render_contexts(template, inner_context, global_co
prompt.add_context(name, fn)
for name, fn in global_context.items():
manager.add_context_construct_function(name, fn)
# Act
rendered = await manager.render_prompt(prompt)
@@ -396,6 +494,20 @@ async def test_prompt_manager_render_with_coroutine_global_context_function():
assert rendered == "g-main"
@pytest.mark.asyncio
async def test_prompt_manager_render_only_cloned_instance():
# Arrange
manager = PromptManager()
p = Prompt(prompt_name="p", template="T")
manager.add_prompt(p)
# Act / Assert: 直接用原始 p 渲染会报错
with pytest.raises(ValueError) as exc_info:
await manager.render_prompt(p)
assert "只能渲染通过 PromptManager.get_prompt 方法获取的 Prompt 实例" in str(exc_info.value)
@pytest.mark.parametrize(
"is_prompt_context, use_coroutine, case_id",
[
@@ -406,7 +518,12 @@ async def test_prompt_manager_render_with_coroutine_global_context_function():
],
)
@pytest.mark.asyncio
async def test_prompt_manager_get_function_result_error_logging(monkeypatch, is_prompt_context, use_coroutine, case_id):
async def test_prompt_manager_get_function_result_error_logging(
monkeypatch,
is_prompt_context,
use_coroutine,
case_id,
):
# Arrange
manager = PromptManager()
@@ -449,6 +566,9 @@ async def test_prompt_manager_get_function_result_error_logging(monkeypatch, is_
assert "调用上下文构造函数 'field' 时出错,所属模块: 'mod'" in log
# ========= add_context_construct_function 边界 =========
def test_prompt_manager_add_context_construct_function_unknown_frame(monkeypatch):
# Arrange
manager = PromptManager()
@@ -496,50 +616,68 @@ def test_prompt_manager_add_context_construct_function_unknown_caller_frame(monk
monkeypatch.setattr("inspect.currentframe", real_currentframe)
def test_prompt_manager_save_and_load_prompts(tmp_path, monkeypatch):
# Arrange
test_dir = tmp_path / "prompts_dir"
test_dir.mkdir()
# ========= save/load & 目录逻辑 =========
monkeypatch.setattr("src.prompt.prompt_manager.PROMPTS_DIR", test_dir, raising=False)
def test_prompt_manager_save_prompts_io_error_on_unlink(tmp_path, monkeypatch):
"""
save_prompts 现在的逻辑:
1. 先删除 CUSTOM_PROMPTS_DIR 下的所有 *.prompt 文件;
2. 再将 _prompt_to_save 中的 prompt 写入 CUSTOM_PROMPTS_DIR。
这里模拟删除已有自定义 prompt 文件时发生 IO 错误。
"""
# Arrange
prompts_dir = tmp_path / "prompts"
custom_dir = tmp_path / "data" / "custom_prompts"
prompts_dir.mkdir(parents=True)
custom_dir.mkdir(parents=True)
monkeypatch.setattr("src.prompt.prompt_manager.PROMPTS_DIR", prompts_dir, raising=False)
monkeypatch.setattr("src.prompt.prompt_manager.CUSTOM_PROMPTS_DIR", custom_dir, raising=False)
# 先在自定义目录写入一个 prompt 文件,触发 unlink 路径
old_file = custom_dir / f"old{SUFFIX_PROMPT}"
old_file.write_text("old", encoding="utf-8")
manager = PromptManager()
p1 = Prompt(prompt_name="save_me", template="Template {x}")
p1.add_context("x", "X")
p1 = Prompt(prompt_name="save_error", template="T")
manager.add_prompt(p1, need_save=True)
# Act
manager.save_prompts()
# 打桩 Path.unlink使删除文件时报错
def fake_unlink(self):
raise OSError("disk unlink error")
monkeypatch.setattr("pathlib.Path.unlink", fake_unlink)
# Act / Assert
with pytest.raises(OSError) as exc_info:
manager.save_prompts()
# Assert
saved_file = test_dir / f"save_me{SUFFIX_PROMPT}"
assert saved_file.exists()
assert saved_file.read_text(encoding="utf-8") == "Template {x}"
assert "disk unlink error" in str(exc_info.value)
def test_prompt_manager_save_prompts_io_error_on_write(tmp_path, monkeypatch):
"""
模拟 save_prompts 在写入新 prompt 文件时发生 IO 错误。
"""
# Arrange
new_manager = PromptManager()
prompts_dir = tmp_path / "prompts"
custom_dir = tmp_path / "data" / "custom_prompts"
prompts_dir.mkdir(parents=True)
custom_dir.mkdir(parents=True)
# Act
new_manager.load_prompts()
monkeypatch.setattr("src.prompt.prompt_manager.PROMPTS_DIR", prompts_dir, raising=False)
monkeypatch.setattr("src.prompt.prompt_manager.CUSTOM_PROMPTS_DIR", custom_dir, raising=False)
# Assert
loaded = new_manager.get_prompt("save_me")
assert loaded.template == "Template {x}"
assert "save_me" in new_manager._prompt_to_save
def test_prompt_manager_save_prompts_io_error(tmp_path, monkeypatch):
# Arrange
test_dir = tmp_path / "prompts_dir"
test_dir.mkdir()
monkeypatch.setattr("src.prompt.prompt_manager.PROMPTS_DIR", test_dir, raising=False)
manager = PromptManager()
p1 = Prompt(prompt_name="save_error", template="T")
manager.add_prompt(p1, need_save=True)
class FakeFile:
def __enter__(self):
raise OSError("disk error")
raise OSError("disk write error")
def __exit__(self, exc_type, exc, tb):
return False
@@ -554,15 +692,23 @@ def test_prompt_manager_save_prompts_io_error(tmp_path, monkeypatch):
manager.save_prompts()
# Assert
assert "disk error" in str(exc_info.value)
assert "disk write error" in str(exc_info.value)
def test_prompt_manager_load_prompts_io_error(tmp_path, monkeypatch):
def test_prompt_manager_load_prompts_io_error_from_default_dir(tmp_path, monkeypatch):
"""
模拟从 PROMPTS_DIR 读取 prompt 时发生 IO 错误。
"""
# Arrange
test_dir = tmp_path / "prompts_dir"
test_dir.mkdir()
monkeypatch.setattr("src.prompt.prompt_manager.PROMPTS_DIR", test_dir, raising=False)
prompt_file = test_dir / f"bad{SUFFIX_PROMPT}"
prompts_dir = tmp_path / "prompts"
custom_dir = tmp_path / "data" / "custom_prompts"
prompts_dir.mkdir(parents=True)
custom_dir.mkdir(parents=True)
monkeypatch.setattr("src.prompt.prompt_manager.PROMPTS_DIR", prompts_dir, raising=False)
monkeypatch.setattr("src.prompt.prompt_manager.CUSTOM_PROMPTS_DIR", custom_dir, raising=False)
prompt_file = prompts_dir / f"bad{SUFFIX_PROMPT}"
prompt_file.write_text("content", encoding="utf-8")
class FakeFile:
@@ -572,8 +718,12 @@ def test_prompt_manager_load_prompts_io_error(tmp_path, monkeypatch):
def __exit__(self, exc_type, exc, tb):
return False
def fake_open(*_args, **_kwargs):
return FakeFile()
def fake_open(*args, **kwargs):
# 只对 default 目录下的文件触发错误,其余正常(如果有)
file_path = Path(args[0])
if file_path == prompt_file:
return FakeFile()
return open(*args, **kwargs)
monkeypatch.setattr("builtins.open", fake_open)
manager = PromptManager()
@@ -586,6 +736,151 @@ def test_prompt_manager_load_prompts_io_error(tmp_path, monkeypatch):
assert "read error" in str(exc_info.value)
def test_prompt_manager_load_prompts_io_error_from_custom_dir(tmp_path, monkeypatch):
"""
模拟从 CUSTOM_PROMPTS_DIR 读取 prompt 时发生 IO 错误。
包含两种路径:
1. default 与 custom 同名load_prompts 会优先读取 custom
2. 仅 custom 有文件,且 default 无同名文件。
"""
# Arrange
prompts_dir = tmp_path / "prompts"
custom_dir = tmp_path / "data" / "custom_prompts"
prompts_dir.mkdir(parents=True)
custom_dir.mkdir(parents=True)
monkeypatch.setattr("src.prompt.prompt_manager.PROMPTS_DIR", prompts_dir, raising=False)
monkeypatch.setattr("src.prompt.prompt_manager.CUSTOM_PROMPTS_DIR", custom_dir, raising=False)
# default 与 custom 同名的文件
same_name = f"same{SUFFIX_PROMPT}"
base_file = prompts_dir / same_name
base_file.write_text("base", encoding="utf-8")
custom_file_same = custom_dir / same_name
custom_file_same.write_text("custom", encoding="utf-8")
# 仅 custom 下存在的文件
only_custom_file = custom_dir / f"only_custom{SUFFIX_PROMPT}"
only_custom_file.write_text("only", encoding="utf-8")
class FakeFile:
def __enter__(self):
raise OSError("custom read error")
def __exit__(self, exc_type, exc, tb):
return False
def fake_open(*args, **kwargs):
file_path = Path(args[0])
# 对 custom 目录下的 prompt 文件统一触发错误
if file_path.parent == custom_dir:
return FakeFile()
return open(*args, **kwargs)
monkeypatch.setattr("builtins.open", fake_open)
manager = PromptManager()
# Act / Assert
with pytest.raises(OSError) as exc_info:
manager.load_prompts()
# Assert
assert "custom read error" in str(exc_info.value)
def test_prompt_manager_load_prompts_custom_overrides_default(tmp_path, monkeypatch):
"""
load_prompts 逻辑:
- 遍历 PROMPTS_DIR/*.prompt
- 如果 CUSTOM_PROMPTS_DIR 下存在同名文件,则优先使用自定义目录
"""
# Arrange
prompts_dir = tmp_path / "prompts"
custom_dir = tmp_path / "data" / "custom_prompts"
prompts_dir.mkdir(parents=True)
custom_dir.mkdir(parents=True)
monkeypatch.setattr("src.prompt.prompt_manager.PROMPTS_DIR", prompts_dir, raising=False)
monkeypatch.setattr("src.prompt.prompt_manager.CUSTOM_PROMPTS_DIR", custom_dir, raising=False)
# 默认目录 prompt
base_file = prompts_dir / f"testp{SUFFIX_PROMPT}"
base_file.write_text("BaseTemplate {x}", encoding="utf-8")
# 自定义目录同名 prompt应当覆盖默认
custom_file = custom_dir / base_file.name
custom_file.write_text("CustomTemplate {x}", encoding="utf-8")
manager = PromptManager()
# Act
manager.load_prompts()
# Assert
p = manager.get_prompt("testp")
assert p.template == "CustomTemplate {x}"
# 从自定义目录加载的 prompt 应标记为 need_save加入 _prompt_to_save
assert "testp" in manager._prompt_to_save
def test_prompt_manager_load_prompts_default_dir_not_mark_need_save(tmp_path, monkeypatch):
"""
从 PROMPTS_DIR 加载、且没有同名自定义 prompt 时need_save 应为 False不进入 _prompt_to_save
"""
# Arrange
prompts_dir = tmp_path / "prompts"
custom_dir = tmp_path / "data" / "custom_prompts"
prompts_dir.mkdir(parents=True)
custom_dir.mkdir(parents=True)
monkeypatch.setattr("src.prompt.prompt_manager.PROMPTS_DIR", prompts_dir, raising=False)
monkeypatch.setattr("src.prompt.prompt_manager.CUSTOM_PROMPTS_DIR", custom_dir, raising=False)
# 仅默认目录有 prompt自定义目录中无同名文件
base_file = prompts_dir / f"only_default{SUFFIX_PROMPT}"
base_file.write_text("DefaultTemplate {x}", encoding="utf-8")
manager = PromptManager()
# Act
manager.load_prompts()
# Assert
p = manager.get_prompt("only_default")
assert p.template == "DefaultTemplate {x}"
# 从默认目录加载的 prompt 不应标记为 need_save
assert "only_default" not in manager._prompt_to_save
def test_prompt_manager_save_prompts_use_custom_dir(tmp_path, monkeypatch):
"""
save_prompts 使用 CUSTOM_PROMPTS_DIR 进行保存。
"""
prompts_dir = tmp_path / "prompts"
custom_dir = tmp_path / "data" / "custom_prompts"
prompts_dir.mkdir(parents=True)
custom_dir.mkdir(parents=True)
monkeypatch.setattr("src.prompt.prompt_manager.PROMPTS_DIR", prompts_dir, raising=False)
monkeypatch.setattr("src.prompt.prompt_manager.CUSTOM_PROMPTS_DIR", custom_dir, raising=False)
manager = PromptManager()
p1 = Prompt(prompt_name="save_me", template="Template {x}")
p1.add_context("x", "X")
manager.add_prompt(p1, need_save=True)
# Act
manager.save_prompts()
# Assert: 文件应保存在 custom_dir 中
saved_file = custom_dir / f"save_me{SUFFIX_PROMPT}"
assert saved_file.exists()
assert saved_file.read_text(encoding="utf-8") == "Template {x}"
# ========= 其它 =========
def test_prompt_manager_global_instance_access():
# Act
pm = prompt_manager