Files
mai-bot/src/prompt/prompt_manager.py
2026-03-13 00:09:08 +09:00

313 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from pathlib import Path
from string import Formatter
from typing import Any, Optional
from collections.abc import Callable, Coroutine
import inspect
from src.common.prompt_i18n import list_prompt_templates, load_prompt
from src.common.logger import get_logger
logger = get_logger("Prompt")
_LEFT_BRACE = "\ufde9"
_RIGHT_BRACE = "\ufdea"
PROJECT_ROOT = Path(__file__).parent.parent.parent.absolute().resolve()
PROMPTS_DIR = PROJECT_ROOT / "prompts"
DATA_DIR = PROJECT_ROOT / "data"
CUSTOM_PROMPTS_DIR = DATA_DIR / "custom_prompts"
PROMPTS_DIR.mkdir(parents=True, exist_ok=True)
CUSTOM_PROMPTS_DIR.mkdir(parents=True, exist_ok=True)
SUFFIX_PROMPT = ".prompt"
class Prompt:
def __init__(self, prompt_name: str, template: str) -> None:
self.prompt_name = prompt_name
self.template = template
self.prompt_render_context: dict[str, Callable[[str], str | Coroutine[Any, Any, str]]] = {}
self._is_cloned = False
self.__post_init__()
def add_context(self, name: str, func_or_str: Callable[[str], str | Coroutine[Any, Any, str]] | str) -> None:
if name in self.prompt_render_context:
raise KeyError(f"Context function name '{name}' 已存在于 Prompt '{self.prompt_name}'")
if isinstance(func_or_str, str):
def tmp_func(_: str) -> str:
return func_or_str
render_function = tmp_func
else:
render_function = func_or_str
self.prompt_render_context[name] = render_function
def clone(self) -> "Prompt":
return Prompt(self.prompt_name, self.template)
def __post_init__(self):
if not self.prompt_name:
raise ValueError("prompt_name 不能为空")
if not self.template:
raise ValueError("template 不能为空")
tmp = self.template.replace("{{", _LEFT_BRACE).replace("}}", _RIGHT_BRACE)
if "{}" in tmp:
raise ValueError(r"模板中不允许使用未命名的占位符 '{}'")
class PromptManager:
def __init__(self) -> None:
self.prompts: dict[str, Prompt] = {}
"""存储 Prompt 实例,禁止直接从外部访问,否则将引起不可知后果"""
self._context_construct_functions: dict[str, tuple[Callable[[str], str | Coroutine[Any, Any, str]], str]] = {}
"""存储上下文构造函数及其所属模块"""
self._formatter = Formatter() # 仅用来解析模板
"""模板解析器"""
self._prompt_to_save: set[str] = set()
"""需要保存的 Prompt 名称集合"""
def add_prompt(self, prompt: Prompt, need_save: bool = False) -> None:
"""
添加一个新的 Prompt 实例
Args:
prompt (Prompt): 要添加的 Prompt 实例
need_save (bool): 是否需要保存该 Prompt默认为 False
Raises:
KeyError: 如果 Prompt 名称已存在则引发该异常
"""
if prompt.prompt_name in self.prompts or prompt.prompt_name in self._context_construct_functions:
# 确保名称无冲突
raise KeyError(f"Prompt name '{prompt.prompt_name}' 已存在")
self.prompts[prompt.prompt_name] = prompt
if need_save:
self._prompt_to_save.add(prompt.prompt_name)
def remove_prompt(self, prompt_name: str) -> None:
"""
移除一个已存在的 Prompt 实例
Args:
prompt_name (str): 要移除的 Prompt 名称
Raises:
KeyError: 如果 Prompt 名称不存在则引发该异常
"""
if prompt_name not in self.prompts:
raise KeyError(f"Prompt name '{prompt_name}' 不存在")
del self.prompts[prompt_name]
if prompt_name in self._prompt_to_save:
self._prompt_to_save.remove(prompt_name)
def replace_prompt(self, prompt: Prompt, need_save: bool = False) -> None:
"""
替换一个已存在的 Prompt 实例
Args:
prompt (Prompt): 要替换的 Prompt 实例
need_save (bool): 是否需要保存该 Prompt默认为 False
Raises:
KeyError: 如果 Prompt 名称不存在则引发该异常
"""
if prompt.prompt_name not in self.prompts:
raise KeyError(f"Prompt name '{prompt.prompt_name}' 不存在,无法替换")
self.prompts[prompt.prompt_name] = prompt
if need_save:
self._prompt_to_save.add(prompt.prompt_name)
elif prompt.prompt_name in self._prompt_to_save:
self._prompt_to_save.remove(prompt.prompt_name)
def add_context_construct_function(self, name: str, func: Callable[[str], str | Coroutine[Any, Any, str]]) -> None:
"""
添加一个上下文构造函数
Args:
name (str): 上下文名称
func (Callable[[str], str | Coroutine[Any, Any, str]]): 构造函数,接受 Prompt 名称作为参数,返回字符串或返回字符串的协程
Raises:
KeyError: 如果上下文名称已存在则引发该异常
"""
if name in self._context_construct_functions or name in self.prompts:
raise KeyError(f"Construct function name '{name}' 已存在")
# 获取调用栈
frame = inspect.currentframe()
if not frame:
# 不应该出现的情况
raise RuntimeError("无法获取调用栈")
caller_frame = frame.f_back
if not caller_frame:
# 不应该出现的情况
raise RuntimeError("无法获取调用栈的上一级")
caller_module = caller_frame.f_globals.get("__name__", "unknown")
if caller_module == "unknown":
logger.warning("无法获取调用函数的模块名,使用 'unknown' 作为默认值")
self._context_construct_functions[name] = func, caller_module
def get_prompt(self, prompt_name: str) -> Prompt:
"""
获取指定名称的 Prompt 实例的克隆
Args:
prompt_name (str): 要获取的 Prompt 名称
Returns:
return (Prompt): 指定名称的 Prompt 实例的克隆
Raises:
KeyError: 如果 Prompt 名称不存在则引发该异常
"""
if prompt_name not in self.prompts:
raise KeyError(f"Prompt name '{prompt_name}' 不存在")
prompt = self.prompts[prompt_name].clone()
prompt._is_cloned = True
return prompt
async def render_prompt(self, prompt: Prompt) -> str:
"""
渲染一个 Prompt 实例
Args:
prompt (Prompt): 要渲染的 Prompt 实例
Returns:
return (str): 渲染后的字符串
Raises:
ValueError: 如果传入的 Prompt 实例不是通过 get_prompt 方法获取的克隆实例则引发该异常
"""
if not prompt._is_cloned:
raise ValueError(
"只能渲染通过 PromptManager.get_prompt 方法获取的 Prompt 实例,你可能对原始实例进行了修改和渲染操作"
)
return await self._render(prompt)
async def _render(
self,
prompt: Prompt,
recursive_level: int = 0,
additional_construction_function_dict: dict[str, Callable[[str], str | Coroutine[Any, Any, str]]] = {}, # noqa: B006
) -> str:
prompt.template = prompt.template.replace("{{", _LEFT_BRACE).replace("}}", _RIGHT_BRACE)
if recursive_level > 10:
raise RecursionError("递归层级过深,可能存在循环引用")
field_block = {field_name for _, field_name, _, _ in self._formatter.parse(prompt.template) if field_name}
rendered_fields: dict[str, str] = {}
for field_name in field_block:
if field_name in self.prompts:
nested_prompt = self.get_prompt(field_name)
additional_construction_function_dict |= prompt.prompt_render_context
rendered_fields[field_name] = await self._render(
nested_prompt,
recursive_level + 1,
additional_construction_function_dict,
)
elif field_name in prompt.prompt_render_context:
# 优先使用内部构造函数
func = prompt.prompt_render_context[field_name]
rendered_fields[field_name] = await self._get_function_result(
func,
prompt.prompt_name,
field_name,
is_prompt_context=True,
)
elif field_name in self._context_construct_functions:
# 随后查找全局构造函数
func, module = self._context_construct_functions[field_name]
rendered_fields[field_name] = await self._get_function_result(
func,
prompt.prompt_name,
field_name,
is_prompt_context=False,
module=module,
)
elif field_name in additional_construction_function_dict:
# 最后查找额外传入的构造函数
func = additional_construction_function_dict[field_name]
rendered_fields[field_name] = await self._get_function_result(
func,
prompt.prompt_name,
field_name,
is_prompt_context=True,
)
else:
raise KeyError(f"Prompt '{prompt.prompt_name}' 中缺少必要的内容块或构建函数: '{field_name}'")
rendered_template = prompt.template.format(**rendered_fields)
return rendered_template.replace(_LEFT_BRACE, "{").replace(_RIGHT_BRACE, "}")
def save_prompts(self) -> None:
"""
保存需要保存的 Prompt 实例到自定义目录,将清空未注册的自定义 Prompt 文件
Raises:
Exception: 如果在保存过程中出现任何文件操作错误则引发该异常
"""
# 先清空自定义目录下的所有 Prompt 文件
for prompt_file in CUSTOM_PROMPTS_DIR.glob(f"*{SUFFIX_PROMPT}"):
try:
prompt_file.unlink()
except Exception as e:
logger.error(f"删除自定义 Prompt 文件 '{prompt_file}' 时出错,错误信息: {e}")
raise e
for prompt_name in self._prompt_to_save:
prompt = self.prompts[prompt_name]
file_path = CUSTOM_PROMPTS_DIR / f"{prompt_name}{SUFFIX_PROMPT}"
try:
with open(file_path, "w", encoding="utf-8") as f:
f.write(prompt.template)
except Exception as e:
logger.error(f"保存 Prompt '{prompt_name}' 时出错,文件路径: '{file_path}',错误信息: {e}")
raise e
def load_prompts(self) -> None:
"""
加载全部 Prompt 实例,优先加载自定义目录下的文件,支持覆盖加载
Raises:
Exception: 如果在加载过程中出现任何文件操作错误则引发该异常
"""
prompt_files = list_prompt_templates(prompts_root=PROMPTS_DIR)
for prompt_name, prompt_file in prompt_files.items():
try:
prompt_to_load = prompt_file
need_save = False
custom_prompt_path = CUSTOM_PROMPTS_DIR / f"{prompt_name}{SUFFIX_PROMPT}"
if custom_prompt_path.exists():
# 优先加载自定义目录下的 Prompt 文件
prompt_to_load = custom_prompt_path
need_save = True
with open(prompt_to_load, "r", encoding="utf-8") as f:
template = f.read()
else:
template = load_prompt(prompt_name, prompts_root=PROMPTS_DIR)
self.add_prompt(Prompt(prompt_name=prompt_name, template=template), need_save=need_save)
except Exception as e:
logger.error(f"加载 Prompt 文件 '{prompt_file}' 时出错,错误信息: {e}")
raise e
for prompt_file in CUSTOM_PROMPTS_DIR.glob(f"*{SUFFIX_PROMPT}"):
if prompt_file.stem in prompt_files:
continue # 已经加载过了,跳过
try:
with open(prompt_file, "r", encoding="utf-8") as f:
template = f.read()
self.add_prompt(Prompt(prompt_name=prompt_file.stem, template=template), need_save=True)
except Exception as e:
logger.error(f"加载自定义 Prompt 文件 '{prompt_file}' 时出错,错误信息: {e}")
raise e
async def _get_function_result(
self,
func: Callable[[str], str | Coroutine[Any, Any, str]],
prompt_name: str,
field_name: str,
is_prompt_context: bool,
module: Optional[str] = None,
) -> str:
try:
res = func(prompt_name)
if isinstance(res, Coroutine):
res = await res
return res
except Exception as e:
if is_prompt_context:
logger.error(f"调用 Prompt '{prompt_name}' 内部上下文构造函数 '{field_name}' 时出错,错误信息: {e}")
else:
logger.error(f"调用上下文构造函数 '{field_name}' 时出错,所属模块: '{module}',错误信息: {e}")
raise e
prompt_manager = PromptManager()