feat:支持用户自定义prompt
This commit is contained in:
@@ -5,8 +5,10 @@ from typing import Any, Optional
|
||||
|
||||
import inspect
|
||||
|
||||
from src.common.i18n import get_locale
|
||||
from src.common.i18n.loaders import DEFAULT_LOCALE, normalize_locale
|
||||
from src.common.logger import get_logger
|
||||
from src.common.prompt_i18n import list_prompt_templates, load_prompt
|
||||
from src.common.prompt_i18n import list_prompt_templates
|
||||
|
||||
|
||||
logger = get_logger("Prompt")
|
||||
@@ -23,6 +25,44 @@ CUSTOM_PROMPTS_DIR.mkdir(parents=True, exist_ok=True)
|
||||
SUFFIX_PROMPT = ".prompt"
|
||||
|
||||
|
||||
def _normalize_prompt_locale(locale: str | None = None) -> str:
|
||||
return normalize_locale(locale or get_locale())
|
||||
|
||||
|
||||
def _get_prompt_locale_from_path(prompt_path: Path) -> str | None:
|
||||
try:
|
||||
relative_path = prompt_path.resolve().relative_to(PROMPTS_DIR.resolve())
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
return relative_path.parts[0] if len(relative_path.parts) > 1 else None
|
||||
|
||||
|
||||
def _custom_prompt_path(prompt_name: str, locale: str | None = None) -> Path:
|
||||
return CUSTOM_PROMPTS_DIR / _normalize_prompt_locale(locale) / f"{prompt_name}{SUFFIX_PROMPT}"
|
||||
|
||||
|
||||
def _legacy_custom_prompt_path(prompt_name: str) -> Path:
|
||||
return CUSTOM_PROMPTS_DIR / f"{prompt_name}{SUFFIX_PROMPT}"
|
||||
|
||||
|
||||
def _iter_custom_prompt_candidates(prompt_name: str, locale: str | None = None) -> list[Path]:
|
||||
candidates: list[Path] = []
|
||||
if locale:
|
||||
candidates.append(_custom_prompt_path(prompt_name, locale))
|
||||
candidates.append(_legacy_custom_prompt_path(prompt_name))
|
||||
return candidates
|
||||
|
||||
|
||||
def _iter_active_custom_prompt_dirs() -> list[Path]:
|
||||
prompt_dirs = [
|
||||
CUSTOM_PROMPTS_DIR / DEFAULT_LOCALE,
|
||||
CUSTOM_PROMPTS_DIR / _normalize_prompt_locale(),
|
||||
CUSTOM_PROMPTS_DIR,
|
||||
]
|
||||
return list(dict.fromkeys(prompt_dirs))
|
||||
|
||||
|
||||
class Prompt:
|
||||
def __init__(self, prompt_name: str, template: str) -> None:
|
||||
self.prompt_name = prompt_name
|
||||
@@ -74,8 +114,10 @@ class PromptManager:
|
||||
"""模板解析器"""
|
||||
self._prompt_to_save: set[str] = set()
|
||||
"""需要保存的 Prompt 名称集合"""
|
||||
self._prompt_save_locales: dict[str, str] = {}
|
||||
"""Prompt 保存时使用的语言目录"""
|
||||
|
||||
def add_prompt(self, prompt: Prompt, need_save: bool = False) -> None:
|
||||
def add_prompt(self, prompt: Prompt, need_save: bool = False, prompt_locale: str | None = None) -> None:
|
||||
"""
|
||||
添加一个新的 Prompt 实例
|
||||
|
||||
@@ -91,6 +133,7 @@ class PromptManager:
|
||||
self.prompts[prompt.prompt_name] = prompt
|
||||
if need_save:
|
||||
self._prompt_to_save.add(prompt.prompt_name)
|
||||
self._prompt_save_locales[prompt.prompt_name] = _normalize_prompt_locale(prompt_locale)
|
||||
|
||||
def remove_prompt(self, prompt_name: str) -> None:
|
||||
"""
|
||||
@@ -105,8 +148,9 @@ class PromptManager:
|
||||
del self.prompts[prompt_name]
|
||||
if prompt_name in self._prompt_to_save:
|
||||
self._prompt_to_save.remove(prompt_name)
|
||||
self._prompt_save_locales.pop(prompt_name, None)
|
||||
|
||||
def replace_prompt(self, prompt: Prompt, need_save: bool = False) -> None:
|
||||
def replace_prompt(self, prompt: Prompt, need_save: bool = False, prompt_locale: str | None = None) -> None:
|
||||
"""
|
||||
替换一个已存在的 Prompt 实例
|
||||
Args:
|
||||
@@ -120,8 +164,10 @@ class PromptManager:
|
||||
self.prompts[prompt.prompt_name] = prompt
|
||||
if need_save:
|
||||
self._prompt_to_save.add(prompt.prompt_name)
|
||||
self._prompt_save_locales[prompt.prompt_name] = _normalize_prompt_locale(prompt_locale)
|
||||
elif prompt.prompt_name in self._prompt_to_save:
|
||||
self._prompt_to_save.remove(prompt.prompt_name)
|
||||
self._prompt_save_locales.pop(prompt.prompt_name, None)
|
||||
|
||||
def add_context_construct_function(self, name: str, func: Callable[[str], str | Coroutine[Any, Any, str]]) -> None:
|
||||
"""
|
||||
@@ -245,27 +291,33 @@ class PromptManager:
|
||||
Raises:
|
||||
Exception: 如果在保存过程中出现任何文件操作错误则引发该异常
|
||||
"""
|
||||
# 先清空自定义目录下的所有 Prompt 文件
|
||||
for prompt_file in CUSTOM_PROMPTS_DIR.glob(f"*{SUFFIX_PROMPT}"):
|
||||
try:
|
||||
prompt_file.unlink()
|
||||
except Exception as exc:
|
||||
logger.error(f"删除自定义 Prompt 文件 '{prompt_file}' 时出错,错误信息: {exc}")
|
||||
raise
|
||||
# 只清理当前加载语言层的 Prompt 文件,避免误删其它语言的用户自定义模板。
|
||||
for prompt_dir in _iter_active_custom_prompt_dirs():
|
||||
if not prompt_dir.exists():
|
||||
continue
|
||||
for prompt_file in prompt_dir.glob(f"*{SUFFIX_PROMPT}"):
|
||||
try:
|
||||
prompt_file.unlink()
|
||||
except Exception as exc:
|
||||
logger.error(f"删除自定义 Prompt 文件 '{prompt_file}' 时出错,错误信息: {exc}")
|
||||
raise
|
||||
for prompt_name in self._prompt_to_save:
|
||||
prompt = self.prompts[prompt_name]
|
||||
file_path = CUSTOM_PROMPTS_DIR / f"{prompt_name}{SUFFIX_PROMPT}"
|
||||
prompt_locale = self._prompt_save_locales.get(prompt_name, _normalize_prompt_locale())
|
||||
file_path = _custom_prompt_path(prompt_name, prompt_locale)
|
||||
try:
|
||||
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
file_path.write_text(prompt.template, encoding="utf-8")
|
||||
except Exception as exc:
|
||||
logger.error(f"保存 Prompt '{prompt_name}' 时出错,文件路径: '{file_path}',错误信息: {exc}")
|
||||
raise
|
||||
|
||||
def _load_prompt_template(self, prompt_name: str) -> tuple[str, bool]:
|
||||
custom_prompt_path = CUSTOM_PROMPTS_DIR / f"{prompt_name}{SUFFIX_PROMPT}"
|
||||
if custom_prompt_path.exists():
|
||||
return custom_prompt_path.read_text(encoding="utf-8"), True
|
||||
return load_prompt(prompt_name, prompts_root=PROMPTS_DIR), False
|
||||
def _load_prompt_template(self, prompt_name: str, source_path: Path) -> tuple[str, bool, str | None]:
|
||||
prompt_locale = _get_prompt_locale_from_path(source_path)
|
||||
for custom_prompt_path in _iter_custom_prompt_candidates(prompt_name, prompt_locale):
|
||||
if custom_prompt_path.exists():
|
||||
return custom_prompt_path.read_text(encoding="utf-8"), True, prompt_locale
|
||||
return source_path.read_text(encoding="utf-8"), False, prompt_locale
|
||||
|
||||
def load_prompts(self) -> None:
|
||||
"""
|
||||
@@ -276,20 +328,34 @@ class PromptManager:
|
||||
prompt_templates = list_prompt_templates(prompts_root=PROMPTS_DIR)
|
||||
for prompt_name, prompt_template in prompt_templates.items():
|
||||
try:
|
||||
template, need_save = self._load_prompt_template(prompt_name)
|
||||
self.add_prompt(Prompt(prompt_name=prompt_name, template=template), need_save=need_save)
|
||||
template, need_save, prompt_locale = self._load_prompt_template(prompt_name, prompt_template.path)
|
||||
self.add_prompt(
|
||||
Prompt(prompt_name=prompt_name, template=template),
|
||||
need_save=need_save,
|
||||
prompt_locale=prompt_locale,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(f"加载 Prompt 文件 '{prompt_template.path}' 时出错,错误信息: {exc}")
|
||||
raise
|
||||
for prompt_file in CUSTOM_PROMPTS_DIR.glob(f"*{SUFFIX_PROMPT}"):
|
||||
if prompt_file.stem in prompt_templates:
|
||||
continue # 已经加载过了,跳过
|
||||
try:
|
||||
template = prompt_file.read_text(encoding="utf-8")
|
||||
self.add_prompt(Prompt(prompt_name=prompt_file.stem, template=template), need_save=True)
|
||||
except Exception as exc:
|
||||
logger.error(f"加载自定义 Prompt 文件 '{prompt_file}' 时出错,错误信息: {exc}")
|
||||
raise
|
||||
loaded_custom_prompts = set(prompt_templates)
|
||||
for prompt_dir in _iter_active_custom_prompt_dirs():
|
||||
if not prompt_dir.exists():
|
||||
continue
|
||||
prompt_locale = prompt_dir.name if prompt_dir.parent == CUSTOM_PROMPTS_DIR else None
|
||||
for prompt_file in prompt_dir.glob(f"*{SUFFIX_PROMPT}"):
|
||||
if prompt_file.stem in loaded_custom_prompts:
|
||||
continue # 已经加载过了,跳过
|
||||
try:
|
||||
template = prompt_file.read_text(encoding="utf-8")
|
||||
self.add_prompt(
|
||||
Prompt(prompt_name=prompt_file.stem, template=template),
|
||||
need_save=True,
|
||||
prompt_locale=prompt_locale,
|
||||
)
|
||||
loaded_custom_prompts.add(prompt_file.stem)
|
||||
except Exception as exc:
|
||||
logger.error(f"加载自定义 Prompt 文件 '{prompt_file}' 时出错,错误信息: {exc}")
|
||||
raise
|
||||
|
||||
async def _get_function_result(
|
||||
self,
|
||||
|
||||
@@ -55,6 +55,7 @@ PromptContentBody = Annotated[str, Body(embed=True)]
|
||||
router = APIRouter(prefix="/config", tags=["config"], dependencies=[Depends(require_auth)])
|
||||
|
||||
PROMPTS_DIR = PROJECT_ROOT / "prompts"
|
||||
CUSTOM_PROMPTS_DIR = PROJECT_ROOT / "data" / "custom_prompts"
|
||||
MAISAKA_PROMPT_PREVIEW_DIR = (PROJECT_ROOT / "logs" / "maisaka_prompt").resolve()
|
||||
|
||||
|
||||
@@ -67,6 +68,7 @@ class PromptFileInfo(BaseModel):
|
||||
display_name: str = Field(default="", description="Prompt 展示名称")
|
||||
advanced: bool = Field(default=False, description="是否为高级 Prompt")
|
||||
description: str = Field(default="", description="Prompt 描述")
|
||||
customized: bool = Field(default=False, description="是否存在用户自定义覆盖")
|
||||
|
||||
|
||||
class PromptCatalogResponse(BaseModel):
|
||||
@@ -84,6 +86,7 @@ class PromptFileResponse(BaseModel):
|
||||
language: str
|
||||
filename: str
|
||||
content: str
|
||||
customized: bool = False
|
||||
|
||||
|
||||
def _safe_prompt_path(language: str, filename: str) -> Path:
|
||||
@@ -106,6 +109,26 @@ def _safe_prompt_path(language: str, filename: str) -> Path:
|
||||
return prompt_path
|
||||
|
||||
|
||||
def _safe_custom_prompt_path(language: str, filename: str) -> Path:
|
||||
"""校验并解析 data/custom_prompts 下的用户覆盖文件路径。"""
|
||||
|
||||
normalized_language = language.strip()
|
||||
normalized_filename = filename.strip()
|
||||
|
||||
if not normalized_language or any(part in normalized_language for part in ("..", "/", "\\")):
|
||||
raise HTTPException(status_code=400, detail="无效的 Prompt 语言目录")
|
||||
if not normalized_filename.endswith(".prompt") or any(part in normalized_filename for part in ("..", "/", "\\")):
|
||||
raise HTTPException(status_code=400, detail="无效的 Prompt 文件名")
|
||||
|
||||
prompt_path = (CUSTOM_PROMPTS_DIR / normalized_language / normalized_filename).resolve()
|
||||
custom_prompts_root = CUSTOM_PROMPTS_DIR.resolve()
|
||||
try:
|
||||
prompt_path.relative_to(custom_prompts_root)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=400, detail="Prompt 路径越界") from exc
|
||||
return prompt_path
|
||||
|
||||
|
||||
def _safe_maisaka_prompt_preview_path(relative_path: str) -> Path:
|
||||
"""校验并解析 MaiSaka Prompt HTML 预览路径。"""
|
||||
|
||||
@@ -219,7 +242,9 @@ async def list_prompt_files():
|
||||
prompt_template_infos = list_prompt_templates(locale=language, prompts_root=PROMPTS_DIR)
|
||||
prompt_files: List[PromptFileInfo] = []
|
||||
for prompt_file in sorted(language_dir.glob("*.prompt"), key=lambda item: item.name):
|
||||
stat = prompt_file.stat()
|
||||
custom_prompt_file = _safe_custom_prompt_path(language, prompt_file.name)
|
||||
effective_prompt_file = custom_prompt_file if custom_prompt_file.exists() else prompt_file
|
||||
stat = effective_prompt_file.stat()
|
||||
template_info = prompt_template_infos.get(prompt_file.stem)
|
||||
metadata = template_info.metadata if template_info and template_info.path == prompt_file else None
|
||||
prompt_files.append(
|
||||
@@ -230,6 +255,7 @@ async def list_prompt_files():
|
||||
display_name=metadata.display_name if metadata else "",
|
||||
advanced=metadata.advanced if metadata else False,
|
||||
description=metadata.description if metadata else "",
|
||||
customized=custom_prompt_file.exists(),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -248,16 +274,39 @@ async def list_prompt_files():
|
||||
async def get_prompt_file(language: str, filename: str):
|
||||
"""读取指定语言下的 Prompt 文件内容。"""
|
||||
|
||||
prompt_path = _safe_prompt_path(language, filename)
|
||||
custom_prompt_path = _safe_custom_prompt_path(language, filename)
|
||||
if not prompt_path.exists() or not prompt_path.is_file():
|
||||
raise HTTPException(status_code=404, detail="Prompt 文件不存在")
|
||||
|
||||
try:
|
||||
effective_prompt_path = custom_prompt_path if custom_prompt_path.exists() else prompt_path
|
||||
content = effective_prompt_path.read_text(encoding="utf-8")
|
||||
return PromptFileResponse(
|
||||
language=language,
|
||||
filename=filename,
|
||||
content=content,
|
||||
customized=custom_prompt_path.exists(),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"读取 Prompt 文件失败: {prompt_path} {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"读取 Prompt 文件失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.get("/prompts/{language}/{filename}/default", response_model=PromptFileResponse)
|
||||
async def get_default_prompt_file(language: str, filename: str):
|
||||
"""只读获取内置 Prompt 模板内容,不读取或修改用户自定义覆盖。"""
|
||||
|
||||
prompt_path = _safe_prompt_path(language, filename)
|
||||
if not prompt_path.exists() or not prompt_path.is_file():
|
||||
raise HTTPException(status_code=404, detail="Prompt 文件不存在")
|
||||
|
||||
try:
|
||||
content = prompt_path.read_text(encoding="utf-8")
|
||||
return PromptFileResponse(language=language, filename=filename, content=content)
|
||||
return PromptFileResponse(language=language, filename=filename, content=content, customized=False)
|
||||
except Exception as e:
|
||||
logger.error(f"读取 Prompt 文件失败: {prompt_path} {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"读取 Prompt 文件失败: {str(e)}") from e
|
||||
logger.error(f"读取默认 Prompt 文件失败: {prompt_path} {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"读取默认 Prompt 文件失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.put("/prompts/{language}/{filename}", response_model=PromptFileResponse)
|
||||
@@ -265,19 +314,40 @@ async def update_prompt_file(language: str, filename: str, content: PromptConten
|
||||
"""更新指定语言下的 Prompt 文件内容。"""
|
||||
|
||||
prompt_path = _safe_prompt_path(language, filename)
|
||||
custom_prompt_path = _safe_custom_prompt_path(language, filename)
|
||||
if not prompt_path.parent.exists() or not prompt_path.parent.is_dir():
|
||||
raise HTTPException(status_code=404, detail="Prompt 语言目录不存在")
|
||||
if not prompt_path.exists() or not prompt_path.is_file():
|
||||
raise HTTPException(status_code=404, detail="Prompt 文件不存在")
|
||||
|
||||
try:
|
||||
prompt_path.write_text(content, encoding="utf-8", newline="\n")
|
||||
return PromptFileResponse(language=language, filename=filename, content=content)
|
||||
custom_prompt_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
custom_prompt_path.write_text(content, encoding="utf-8", newline="\n")
|
||||
return PromptFileResponse(language=language, filename=filename, content=content, customized=True)
|
||||
except Exception as e:
|
||||
logger.error(f"保存 Prompt 文件失败: {prompt_path} {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"保存 Prompt 文件失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.delete("/prompts/{language}/{filename}", response_model=PromptFileResponse)
|
||||
async def reset_prompt_file(language: str, filename: str):
|
||||
"""删除用户自定义覆盖,恢复使用内置 Prompt 模板。"""
|
||||
|
||||
prompt_path = _safe_prompt_path(language, filename)
|
||||
custom_prompt_path = _safe_custom_prompt_path(language, filename)
|
||||
if not prompt_path.exists() or not prompt_path.is_file():
|
||||
raise HTTPException(status_code=404, detail="Prompt 文件不存在")
|
||||
|
||||
try:
|
||||
if custom_prompt_path.exists():
|
||||
custom_prompt_path.unlink()
|
||||
content = prompt_path.read_text(encoding="utf-8")
|
||||
return PromptFileResponse(language=language, filename=filename, content=content, customized=False)
|
||||
except Exception as e:
|
||||
logger.error(f"恢复 Prompt 默认模板失败: {prompt_path} {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"恢复 Prompt 默认模板失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.get("/maisaka-prompt-preview", response_class=FileResponse)
|
||||
async def get_maisaka_prompt_preview(path: str = Query(..., description="logs/maisaka_prompt 下的相对 HTML 路径")):
|
||||
"""打开 MaiSaka 监控中生成的 Prompt HTML 预览。"""
|
||||
|
||||
Reference in New Issue
Block a user