refactor: 优化 Prompt 类的克隆逻辑,添加克隆标记属性
This commit is contained in:
@@ -11,8 +11,8 @@ from src.common.prompt_i18n import list_prompt_templates, load_prompt
|
||||
|
||||
logger = get_logger("Prompt")
|
||||
|
||||
_LEFT_BRACE = "\ufde9"
|
||||
_RIGHT_BRACE = "\ufdea"
|
||||
_LEFT_BRACE = chr(0xFDE9)
|
||||
_RIGHT_BRACE = chr(0xFDEA)
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[2]
|
||||
PROMPTS_DIR = PROJECT_ROOT / "prompts"
|
||||
@@ -47,7 +47,14 @@ class Prompt:
|
||||
def clone(self) -> "Prompt":
|
||||
return Prompt(self.prompt_name, self.template)
|
||||
|
||||
def __post_init__(self):
|
||||
@property
|
||||
def is_cloned(self) -> bool:
|
||||
return self._is_cloned
|
||||
|
||||
def mark_as_cloned(self) -> None:
|
||||
self._is_cloned = True
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if not self.prompt_name:
|
||||
raise ValueError("prompt_name 不能为空")
|
||||
if not self.template:
|
||||
@@ -157,7 +164,7 @@ class PromptManager:
|
||||
if prompt_name not in self.prompts:
|
||||
raise KeyError(f"Prompt name '{prompt_name}' 不存在")
|
||||
prompt = self.prompts[prompt_name].clone()
|
||||
prompt._is_cloned = True
|
||||
prompt.mark_as_cloned()
|
||||
return prompt
|
||||
|
||||
async def render_prompt(self, prompt: Prompt) -> str:
|
||||
@@ -171,7 +178,7 @@ class PromptManager:
|
||||
Raises:
|
||||
ValueError: 如果传入的 Prompt 实例不是通过 get_prompt 方法获取的克隆实例则引发该异常
|
||||
"""
|
||||
if not prompt._is_cloned:
|
||||
if not prompt.is_cloned:
|
||||
raise ValueError(
|
||||
"只能渲染通过 PromptManager.get_prompt 方法获取的 Prompt 实例,你可能对原始实例进行了修改和渲染操作"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user