feat:支持用户自定义prompt

This commit is contained in:
SengokuCola
2026-05-07 00:45:58 +08:00
parent 8edf13df14
commit 57100797a5
5 changed files with 306 additions and 36 deletions

View File

@@ -5,8 +5,10 @@ from typing import Any, Optional
import inspect
from src.common.i18n import get_locale
from src.common.i18n.loaders import DEFAULT_LOCALE, normalize_locale
from src.common.logger import get_logger
from src.common.prompt_i18n import list_prompt_templates, load_prompt
from src.common.prompt_i18n import list_prompt_templates
logger = get_logger("Prompt")
@@ -23,6 +25,44 @@ CUSTOM_PROMPTS_DIR.mkdir(parents=True, exist_ok=True)
SUFFIX_PROMPT = ".prompt"
def _normalize_prompt_locale(locale: str | None = None) -> str:
return normalize_locale(locale or get_locale())
def _get_prompt_locale_from_path(prompt_path: Path) -> str | None:
try:
relative_path = prompt_path.resolve().relative_to(PROMPTS_DIR.resolve())
except ValueError:
return None
return relative_path.parts[0] if len(relative_path.parts) > 1 else None
def _custom_prompt_path(prompt_name: str, locale: str | None = None) -> Path:
return CUSTOM_PROMPTS_DIR / _normalize_prompt_locale(locale) / f"{prompt_name}{SUFFIX_PROMPT}"
def _legacy_custom_prompt_path(prompt_name: str) -> Path:
return CUSTOM_PROMPTS_DIR / f"{prompt_name}{SUFFIX_PROMPT}"
def _iter_custom_prompt_candidates(prompt_name: str, locale: str | None = None) -> list[Path]:
candidates: list[Path] = []
if locale:
candidates.append(_custom_prompt_path(prompt_name, locale))
candidates.append(_legacy_custom_prompt_path(prompt_name))
return candidates
def _iter_active_custom_prompt_dirs() -> list[Path]:
prompt_dirs = [
CUSTOM_PROMPTS_DIR / DEFAULT_LOCALE,
CUSTOM_PROMPTS_DIR / _normalize_prompt_locale(),
CUSTOM_PROMPTS_DIR,
]
return list(dict.fromkeys(prompt_dirs))
class Prompt:
def __init__(self, prompt_name: str, template: str) -> None:
self.prompt_name = prompt_name
@@ -74,8 +114,10 @@ class PromptManager:
"""模板解析器"""
self._prompt_to_save: set[str] = set()
"""需要保存的 Prompt 名称集合"""
self._prompt_save_locales: dict[str, str] = {}
"""Prompt 保存时使用的语言目录"""
def add_prompt(self, prompt: Prompt, need_save: bool = False) -> None:
def add_prompt(self, prompt: Prompt, need_save: bool = False, prompt_locale: str | None = None) -> None:
"""
添加一个新的 Prompt 实例
@@ -91,6 +133,7 @@ class PromptManager:
self.prompts[prompt.prompt_name] = prompt
if need_save:
self._prompt_to_save.add(prompt.prompt_name)
self._prompt_save_locales[prompt.prompt_name] = _normalize_prompt_locale(prompt_locale)
def remove_prompt(self, prompt_name: str) -> None:
"""
@@ -105,8 +148,9 @@ class PromptManager:
del self.prompts[prompt_name]
if prompt_name in self._prompt_to_save:
self._prompt_to_save.remove(prompt_name)
self._prompt_save_locales.pop(prompt_name, None)
def replace_prompt(self, prompt: Prompt, need_save: bool = False) -> None:
def replace_prompt(self, prompt: Prompt, need_save: bool = False, prompt_locale: str | None = None) -> None:
"""
替换一个已存在的 Prompt 实例
Args:
@@ -120,8 +164,10 @@ class PromptManager:
self.prompts[prompt.prompt_name] = prompt
if need_save:
self._prompt_to_save.add(prompt.prompt_name)
self._prompt_save_locales[prompt.prompt_name] = _normalize_prompt_locale(prompt_locale)
elif prompt.prompt_name in self._prompt_to_save:
self._prompt_to_save.remove(prompt.prompt_name)
self._prompt_save_locales.pop(prompt.prompt_name, None)
def add_context_construct_function(self, name: str, func: Callable[[str], str | Coroutine[Any, Any, str]]) -> None:
"""
@@ -245,27 +291,33 @@ class PromptManager:
Raises:
Exception: 如果在保存过程中出现任何文件操作错误则引发该异常
"""
# 先清空自定义目录下的所有 Prompt 文件
for prompt_file in CUSTOM_PROMPTS_DIR.glob(f"*{SUFFIX_PROMPT}"):
try:
prompt_file.unlink()
except Exception as exc:
logger.error(f"删除自定义 Prompt 文件 '{prompt_file}' 时出错,错误信息: {exc}")
raise
# 只清理当前加载语言层的 Prompt 文件,避免误删其它语言的用户自定义模板。
for prompt_dir in _iter_active_custom_prompt_dirs():
if not prompt_dir.exists():
continue
for prompt_file in prompt_dir.glob(f"*{SUFFIX_PROMPT}"):
try:
prompt_file.unlink()
except Exception as exc:
logger.error(f"删除自定义 Prompt 文件 '{prompt_file}' 时出错,错误信息: {exc}")
raise
for prompt_name in self._prompt_to_save:
prompt = self.prompts[prompt_name]
file_path = CUSTOM_PROMPTS_DIR / f"{prompt_name}{SUFFIX_PROMPT}"
prompt_locale = self._prompt_save_locales.get(prompt_name, _normalize_prompt_locale())
file_path = _custom_prompt_path(prompt_name, prompt_locale)
try:
file_path.parent.mkdir(parents=True, exist_ok=True)
file_path.write_text(prompt.template, encoding="utf-8")
except Exception as exc:
logger.error(f"保存 Prompt '{prompt_name}' 时出错,文件路径: '{file_path}',错误信息: {exc}")
raise
def _load_prompt_template(self, prompt_name: str) -> tuple[str, bool]:
custom_prompt_path = CUSTOM_PROMPTS_DIR / f"{prompt_name}{SUFFIX_PROMPT}"
if custom_prompt_path.exists():
return custom_prompt_path.read_text(encoding="utf-8"), True
return load_prompt(prompt_name, prompts_root=PROMPTS_DIR), False
def _load_prompt_template(self, prompt_name: str, source_path: Path) -> tuple[str, bool, str | None]:
prompt_locale = _get_prompt_locale_from_path(source_path)
for custom_prompt_path in _iter_custom_prompt_candidates(prompt_name, prompt_locale):
if custom_prompt_path.exists():
return custom_prompt_path.read_text(encoding="utf-8"), True, prompt_locale
return source_path.read_text(encoding="utf-8"), False, prompt_locale
def load_prompts(self) -> None:
"""
@@ -276,20 +328,34 @@ class PromptManager:
prompt_templates = list_prompt_templates(prompts_root=PROMPTS_DIR)
for prompt_name, prompt_template in prompt_templates.items():
try:
template, need_save = self._load_prompt_template(prompt_name)
self.add_prompt(Prompt(prompt_name=prompt_name, template=template), need_save=need_save)
template, need_save, prompt_locale = self._load_prompt_template(prompt_name, prompt_template.path)
self.add_prompt(
Prompt(prompt_name=prompt_name, template=template),
need_save=need_save,
prompt_locale=prompt_locale,
)
except Exception as exc:
logger.error(f"加载 Prompt 文件 '{prompt_template.path}' 时出错,错误信息: {exc}")
raise
for prompt_file in CUSTOM_PROMPTS_DIR.glob(f"*{SUFFIX_PROMPT}"):
if prompt_file.stem in prompt_templates:
continue # 已经加载过了,跳过
try:
template = prompt_file.read_text(encoding="utf-8")
self.add_prompt(Prompt(prompt_name=prompt_file.stem, template=template), need_save=True)
except Exception as exc:
logger.error(f"加载自定义 Prompt 文件 '{prompt_file}' 时出错,错误信息: {exc}")
raise
loaded_custom_prompts = set(prompt_templates)
for prompt_dir in _iter_active_custom_prompt_dirs():
if not prompt_dir.exists():
continue
prompt_locale = prompt_dir.name if prompt_dir.parent == CUSTOM_PROMPTS_DIR else None
for prompt_file in prompt_dir.glob(f"*{SUFFIX_PROMPT}"):
if prompt_file.stem in loaded_custom_prompts:
continue # 已经加载过了,跳过
try:
template = prompt_file.read_text(encoding="utf-8")
self.add_prompt(
Prompt(prompt_name=prompt_file.stem, template=template),
need_save=True,
prompt_locale=prompt_locale,
)
loaded_custom_prompts.add(prompt_file.stem)
except Exception as exc:
logger.error(f"加载自定义 Prompt 文件 '{prompt_file}' 时出错,错误信息: {exc}")
raise
async def _get_function_result(
self,

View File

@@ -55,6 +55,7 @@ PromptContentBody = Annotated[str, Body(embed=True)]
router = APIRouter(prefix="/config", tags=["config"], dependencies=[Depends(require_auth)])
PROMPTS_DIR = PROJECT_ROOT / "prompts"
CUSTOM_PROMPTS_DIR = PROJECT_ROOT / "data" / "custom_prompts"
MAISAKA_PROMPT_PREVIEW_DIR = (PROJECT_ROOT / "logs" / "maisaka_prompt").resolve()
@@ -67,6 +68,7 @@ class PromptFileInfo(BaseModel):
display_name: str = Field(default="", description="Prompt 展示名称")
advanced: bool = Field(default=False, description="是否为高级 Prompt")
description: str = Field(default="", description="Prompt 描述")
customized: bool = Field(default=False, description="是否存在用户自定义覆盖")
class PromptCatalogResponse(BaseModel):
@@ -84,6 +86,7 @@ class PromptFileResponse(BaseModel):
language: str
filename: str
content: str
customized: bool = False
def _safe_prompt_path(language: str, filename: str) -> Path:
@@ -106,6 +109,26 @@ def _safe_prompt_path(language: str, filename: str) -> Path:
return prompt_path
def _safe_custom_prompt_path(language: str, filename: str) -> Path:
"""校验并解析 data/custom_prompts 下的用户覆盖文件路径。"""
normalized_language = language.strip()
normalized_filename = filename.strip()
if not normalized_language or any(part in normalized_language for part in ("..", "/", "\\")):
raise HTTPException(status_code=400, detail="无效的 Prompt 语言目录")
if not normalized_filename.endswith(".prompt") or any(part in normalized_filename for part in ("..", "/", "\\")):
raise HTTPException(status_code=400, detail="无效的 Prompt 文件名")
prompt_path = (CUSTOM_PROMPTS_DIR / normalized_language / normalized_filename).resolve()
custom_prompts_root = CUSTOM_PROMPTS_DIR.resolve()
try:
prompt_path.relative_to(custom_prompts_root)
except ValueError as exc:
raise HTTPException(status_code=400, detail="Prompt 路径越界") from exc
return prompt_path
def _safe_maisaka_prompt_preview_path(relative_path: str) -> Path:
"""校验并解析 MaiSaka Prompt HTML 预览路径。"""
@@ -219,7 +242,9 @@ async def list_prompt_files():
prompt_template_infos = list_prompt_templates(locale=language, prompts_root=PROMPTS_DIR)
prompt_files: List[PromptFileInfo] = []
for prompt_file in sorted(language_dir.glob("*.prompt"), key=lambda item: item.name):
stat = prompt_file.stat()
custom_prompt_file = _safe_custom_prompt_path(language, prompt_file.name)
effective_prompt_file = custom_prompt_file if custom_prompt_file.exists() else prompt_file
stat = effective_prompt_file.stat()
template_info = prompt_template_infos.get(prompt_file.stem)
metadata = template_info.metadata if template_info and template_info.path == prompt_file else None
prompt_files.append(
@@ -230,6 +255,7 @@ async def list_prompt_files():
display_name=metadata.display_name if metadata else "",
advanced=metadata.advanced if metadata else False,
description=metadata.description if metadata else "",
customized=custom_prompt_file.exists(),
)
)
@@ -248,16 +274,39 @@ async def list_prompt_files():
async def get_prompt_file(language: str, filename: str):
"""读取指定语言下的 Prompt 文件内容。"""
prompt_path = _safe_prompt_path(language, filename)
custom_prompt_path = _safe_custom_prompt_path(language, filename)
if not prompt_path.exists() or not prompt_path.is_file():
raise HTTPException(status_code=404, detail="Prompt 文件不存在")
try:
effective_prompt_path = custom_prompt_path if custom_prompt_path.exists() else prompt_path
content = effective_prompt_path.read_text(encoding="utf-8")
return PromptFileResponse(
language=language,
filename=filename,
content=content,
customized=custom_prompt_path.exists(),
)
except Exception as e:
logger.error(f"读取 Prompt 文件失败: {prompt_path} {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"读取 Prompt 文件失败: {str(e)}") from e
@router.get("/prompts/{language}/{filename}/default", response_model=PromptFileResponse)
async def get_default_prompt_file(language: str, filename: str):
"""只读获取内置 Prompt 模板内容,不读取或修改用户自定义覆盖。"""
prompt_path = _safe_prompt_path(language, filename)
if not prompt_path.exists() or not prompt_path.is_file():
raise HTTPException(status_code=404, detail="Prompt 文件不存在")
try:
content = prompt_path.read_text(encoding="utf-8")
return PromptFileResponse(language=language, filename=filename, content=content)
return PromptFileResponse(language=language, filename=filename, content=content, customized=False)
except Exception as e:
logger.error(f"读取 Prompt 文件失败: {prompt_path} {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"读取 Prompt 文件失败: {str(e)}") from e
logger.error(f"读取默认 Prompt 文件失败: {prompt_path} {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"读取默认 Prompt 文件失败: {str(e)}") from e
@router.put("/prompts/{language}/{filename}", response_model=PromptFileResponse)
@@ -265,19 +314,40 @@ async def update_prompt_file(language: str, filename: str, content: PromptConten
"""更新指定语言下的 Prompt 文件内容。"""
prompt_path = _safe_prompt_path(language, filename)
custom_prompt_path = _safe_custom_prompt_path(language, filename)
if not prompt_path.parent.exists() or not prompt_path.parent.is_dir():
raise HTTPException(status_code=404, detail="Prompt 语言目录不存在")
if not prompt_path.exists() or not prompt_path.is_file():
raise HTTPException(status_code=404, detail="Prompt 文件不存在")
try:
prompt_path.write_text(content, encoding="utf-8", newline="\n")
return PromptFileResponse(language=language, filename=filename, content=content)
custom_prompt_path.parent.mkdir(parents=True, exist_ok=True)
custom_prompt_path.write_text(content, encoding="utf-8", newline="\n")
return PromptFileResponse(language=language, filename=filename, content=content, customized=True)
except Exception as e:
logger.error(f"保存 Prompt 文件失败: {prompt_path} {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"保存 Prompt 文件失败: {str(e)}") from e
@router.delete("/prompts/{language}/{filename}", response_model=PromptFileResponse)
async def reset_prompt_file(language: str, filename: str):
"""删除用户自定义覆盖,恢复使用内置 Prompt 模板。"""
prompt_path = _safe_prompt_path(language, filename)
custom_prompt_path = _safe_custom_prompt_path(language, filename)
if not prompt_path.exists() or not prompt_path.is_file():
raise HTTPException(status_code=404, detail="Prompt 文件不存在")
try:
if custom_prompt_path.exists():
custom_prompt_path.unlink()
content = prompt_path.read_text(encoding="utf-8")
return PromptFileResponse(language=language, filename=filename, content=content, customized=False)
except Exception as e:
logger.error(f"恢复 Prompt 默认模板失败: {prompt_path} {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"恢复 Prompt 默认模板失败: {str(e)}") from e
@router.get("/maisaka-prompt-preview", response_class=FileResponse)
async def get_maisaka_prompt_preview(path: str = Query(..., description="logs/maisaka_prompt 下的相对 HTML 路径")):
"""打开 MaiSaka 监控中生成的 Prompt HTML 预览。"""