PromptManager再修改,测试更新;将主程序的prompt独立到文件(部分)
This commit is contained in:
@@ -19,19 +19,28 @@ SUFFIX_PROMPT = ".prompt"
|
||||
|
||||
|
||||
class Prompt:
|
||||
prompt_name: str
|
||||
template: str
|
||||
prompt_render_context: dict[str, Callable[[str], str | Coroutine[Any, Any, str]]] = {}
|
||||
|
||||
def __init__(self, prompt_name: str, template: str) -> None:
|
||||
self.prompt_name = prompt_name
|
||||
self.template = template
|
||||
self.prompt_render_context: dict[str, Callable[[str], str | Coroutine[Any, Any, str]]] = {}
|
||||
self._is_cloned = False
|
||||
self.__post_init__()
|
||||
|
||||
def add_context(self, name: str, func: Callable[[str], str | Coroutine[Any, Any, str]]) -> None:
|
||||
def add_context(self, name: str, func_or_str: Callable[[str], str | Coroutine[Any, Any, str]] | str) -> None:
|
||||
if name in self.prompt_render_context:
|
||||
raise KeyError(f"Context function name '{name}' 已存在于 Prompt '{self.prompt_name}' 中")
|
||||
self.prompt_render_context[name] = func
|
||||
if isinstance(func_or_str, str):
|
||||
|
||||
def tmp_func(_: str) -> str:
|
||||
return func_or_str
|
||||
|
||||
render_function = tmp_func
|
||||
else:
|
||||
render_function = func_or_str
|
||||
self.prompt_render_context[name] = render_function
|
||||
|
||||
def clone(self) -> "Prompt":
|
||||
return Prompt(self.prompt_name, self.template)
|
||||
|
||||
def __post_init__(self):
|
||||
if not self.prompt_name:
|
||||
@@ -47,7 +56,7 @@ class PromptManager:
|
||||
def __init__(self) -> None:
|
||||
PROMPTS_DIR.mkdir(parents=True, exist_ok=True) # 确保提示词目录存在
|
||||
self.prompts: dict[str, Prompt] = {}
|
||||
"""存储 Prompt 实例"""
|
||||
"""存储 Prompt 实例,禁止直接从外部访问,否则将引起不可知后果"""
|
||||
self._context_construct_functions: dict[str, tuple[Callable[[str], str | Coroutine[Any, Any, str]], str]] = {}
|
||||
"""存储上下文构造函数及其所属模块"""
|
||||
self._formatter = Formatter() # 仅用来解析模板
|
||||
@@ -57,6 +66,7 @@ class PromptManager:
|
||||
|
||||
def add_prompt(self, prompt: Prompt, need_save: bool = False) -> None:
|
||||
if prompt.prompt_name in self.prompts or prompt.prompt_name in self._context_construct_functions:
|
||||
# 确保名称无冲突
|
||||
raise KeyError(f"Prompt name '{prompt.prompt_name}' 已存在")
|
||||
self.prompts[prompt.prompt_name] = prompt
|
||||
if need_save:
|
||||
@@ -81,14 +91,26 @@ class PromptManager:
|
||||
self._context_construct_functions[name] = func, caller_module
|
||||
|
||||
def get_prompt(self, prompt_name: str) -> Prompt:
|
||||
"""获取指定名称的 Prompt 实例的克隆"""
|
||||
if prompt_name not in self.prompts:
|
||||
raise KeyError(f"Prompt name '{prompt_name}' 不存在")
|
||||
return self.prompts[prompt_name]
|
||||
prompt = self.prompts[prompt_name].clone()
|
||||
prompt._is_cloned = True
|
||||
return prompt
|
||||
|
||||
async def render_prompt(self, prompt: Prompt) -> str:
|
||||
if not prompt._is_cloned:
|
||||
raise ValueError(
|
||||
"只能渲染通过 PromptManager.get_prompt 方法获取的 Prompt 实例,你可能对原始实例进行了修改和渲染操作"
|
||||
)
|
||||
return await self._render(prompt)
|
||||
|
||||
async def _render(self, prompt: Prompt, recursive_level: int = 0) -> str:
|
||||
async def _render(
|
||||
self,
|
||||
prompt: Prompt,
|
||||
recursive_level: int = 0,
|
||||
additional_construction_function_dict: dict[str, Callable[[str], str | Coroutine[Any, Any, str]]] = {}, # noqa: B006
|
||||
) -> str:
|
||||
prompt.template = prompt.template.replace("{{", _LEFT_BRACE).replace("}}", _RIGHT_BRACE)
|
||||
if recursive_level > 10:
|
||||
raise RecursionError("递归层级过深,可能存在循环引用")
|
||||
@@ -96,16 +118,40 @@ class PromptManager:
|
||||
rendered_fields: dict[str, str] = {}
|
||||
for field_name in field_block:
|
||||
if field_name in self.prompts:
|
||||
rendered_fields[field_name] = await self._render(self.prompts[field_name], recursive_level + 1)
|
||||
nested_prompt = self.get_prompt(field_name)
|
||||
additional_construction_function_dict |= prompt.prompt_render_context
|
||||
rendered_fields[field_name] = await self._render(
|
||||
nested_prompt,
|
||||
recursive_level + 1,
|
||||
additional_construction_function_dict,
|
||||
)
|
||||
elif field_name in prompt.prompt_render_context:
|
||||
# 优先使用内部构造函数
|
||||
func = prompt.prompt_render_context[field_name]
|
||||
rendered_fields[field_name] = await self._get_function_result(
|
||||
func, prompt.prompt_name, field_name, is_prompt_context=True
|
||||
func,
|
||||
prompt.prompt_name,
|
||||
field_name,
|
||||
is_prompt_context=True,
|
||||
)
|
||||
elif field_name in self._context_construct_functions:
|
||||
# 随后查找全局构造函数
|
||||
func, module = self._context_construct_functions[field_name]
|
||||
rendered_fields[field_name] = await self._get_function_result(
|
||||
func, prompt.prompt_name, field_name, is_prompt_context=False, module=module
|
||||
func,
|
||||
prompt.prompt_name,
|
||||
field_name,
|
||||
is_prompt_context=False,
|
||||
module=module,
|
||||
)
|
||||
elif field_name in additional_construction_function_dict:
|
||||
# 最后查找额外传入的构造函数
|
||||
func = additional_construction_function_dict[field_name]
|
||||
rendered_fields[field_name] = await self._get_function_result(
|
||||
func,
|
||||
prompt.prompt_name,
|
||||
field_name,
|
||||
is_prompt_context=True,
|
||||
)
|
||||
else:
|
||||
raise KeyError(f"Prompt '{prompt.prompt_name}' 中缺少必要的内容块或构建函数: '{field_name}'")
|
||||
|
||||
Reference in New Issue
Block a user