feat:支持用户自定义prompt
This commit is contained in:
@@ -55,6 +55,7 @@ PromptContentBody = Annotated[str, Body(embed=True)]
|
||||
router = APIRouter(prefix="/config", tags=["config"], dependencies=[Depends(require_auth)])
|
||||
|
||||
PROMPTS_DIR = PROJECT_ROOT / "prompts"
|
||||
CUSTOM_PROMPTS_DIR = PROJECT_ROOT / "data" / "custom_prompts"
|
||||
MAISAKA_PROMPT_PREVIEW_DIR = (PROJECT_ROOT / "logs" / "maisaka_prompt").resolve()
|
||||
|
||||
|
||||
@@ -67,6 +68,7 @@ class PromptFileInfo(BaseModel):
|
||||
display_name: str = Field(default="", description="Prompt 展示名称")
|
||||
advanced: bool = Field(default=False, description="是否为高级 Prompt")
|
||||
description: str = Field(default="", description="Prompt 描述")
|
||||
customized: bool = Field(default=False, description="是否存在用户自定义覆盖")
|
||||
|
||||
|
||||
class PromptCatalogResponse(BaseModel):
|
||||
@@ -84,6 +86,7 @@ class PromptFileResponse(BaseModel):
|
||||
language: str
|
||||
filename: str
|
||||
content: str
|
||||
customized: bool = False
|
||||
|
||||
|
||||
def _safe_prompt_path(language: str, filename: str) -> Path:
|
||||
@@ -106,6 +109,26 @@ def _safe_prompt_path(language: str, filename: str) -> Path:
|
||||
return prompt_path
|
||||
|
||||
|
||||
def _safe_custom_prompt_path(language: str, filename: str) -> Path:
|
||||
"""校验并解析 data/custom_prompts 下的用户覆盖文件路径。"""
|
||||
|
||||
normalized_language = language.strip()
|
||||
normalized_filename = filename.strip()
|
||||
|
||||
if not normalized_language or any(part in normalized_language for part in ("..", "/", "\\")):
|
||||
raise HTTPException(status_code=400, detail="无效的 Prompt 语言目录")
|
||||
if not normalized_filename.endswith(".prompt") or any(part in normalized_filename for part in ("..", "/", "\\")):
|
||||
raise HTTPException(status_code=400, detail="无效的 Prompt 文件名")
|
||||
|
||||
prompt_path = (CUSTOM_PROMPTS_DIR / normalized_language / normalized_filename).resolve()
|
||||
custom_prompts_root = CUSTOM_PROMPTS_DIR.resolve()
|
||||
try:
|
||||
prompt_path.relative_to(custom_prompts_root)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=400, detail="Prompt 路径越界") from exc
|
||||
return prompt_path
|
||||
|
||||
|
||||
def _safe_maisaka_prompt_preview_path(relative_path: str) -> Path:
|
||||
"""校验并解析 MaiSaka Prompt HTML 预览路径。"""
|
||||
|
||||
@@ -219,7 +242,9 @@ async def list_prompt_files():
|
||||
prompt_template_infos = list_prompt_templates(locale=language, prompts_root=PROMPTS_DIR)
|
||||
prompt_files: List[PromptFileInfo] = []
|
||||
for prompt_file in sorted(language_dir.glob("*.prompt"), key=lambda item: item.name):
|
||||
stat = prompt_file.stat()
|
||||
custom_prompt_file = _safe_custom_prompt_path(language, prompt_file.name)
|
||||
effective_prompt_file = custom_prompt_file if custom_prompt_file.exists() else prompt_file
|
||||
stat = effective_prompt_file.stat()
|
||||
template_info = prompt_template_infos.get(prompt_file.stem)
|
||||
metadata = template_info.metadata if template_info and template_info.path == prompt_file else None
|
||||
prompt_files.append(
|
||||
@@ -230,6 +255,7 @@ async def list_prompt_files():
|
||||
display_name=metadata.display_name if metadata else "",
|
||||
advanced=metadata.advanced if metadata else False,
|
||||
description=metadata.description if metadata else "",
|
||||
customized=custom_prompt_file.exists(),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -248,16 +274,39 @@ async def list_prompt_files():
|
||||
async def get_prompt_file(language: str, filename: str):
|
||||
"""读取指定语言下的 Prompt 文件内容。"""
|
||||
|
||||
prompt_path = _safe_prompt_path(language, filename)
|
||||
custom_prompt_path = _safe_custom_prompt_path(language, filename)
|
||||
if not prompt_path.exists() or not prompt_path.is_file():
|
||||
raise HTTPException(status_code=404, detail="Prompt 文件不存在")
|
||||
|
||||
try:
|
||||
effective_prompt_path = custom_prompt_path if custom_prompt_path.exists() else prompt_path
|
||||
content = effective_prompt_path.read_text(encoding="utf-8")
|
||||
return PromptFileResponse(
|
||||
language=language,
|
||||
filename=filename,
|
||||
content=content,
|
||||
customized=custom_prompt_path.exists(),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"读取 Prompt 文件失败: {prompt_path} {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"读取 Prompt 文件失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.get("/prompts/{language}/{filename}/default", response_model=PromptFileResponse)
|
||||
async def get_default_prompt_file(language: str, filename: str):
|
||||
"""只读获取内置 Prompt 模板内容,不读取或修改用户自定义覆盖。"""
|
||||
|
||||
prompt_path = _safe_prompt_path(language, filename)
|
||||
if not prompt_path.exists() or not prompt_path.is_file():
|
||||
raise HTTPException(status_code=404, detail="Prompt 文件不存在")
|
||||
|
||||
try:
|
||||
content = prompt_path.read_text(encoding="utf-8")
|
||||
return PromptFileResponse(language=language, filename=filename, content=content)
|
||||
return PromptFileResponse(language=language, filename=filename, content=content, customized=False)
|
||||
except Exception as e:
|
||||
logger.error(f"读取 Prompt 文件失败: {prompt_path} {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"读取 Prompt 文件失败: {str(e)}") from e
|
||||
logger.error(f"读取默认 Prompt 文件失败: {prompt_path} {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"读取默认 Prompt 文件失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.put("/prompts/{language}/{filename}", response_model=PromptFileResponse)
|
||||
@@ -265,19 +314,40 @@ async def update_prompt_file(language: str, filename: str, content: PromptConten
|
||||
"""更新指定语言下的 Prompt 文件内容。"""
|
||||
|
||||
prompt_path = _safe_prompt_path(language, filename)
|
||||
custom_prompt_path = _safe_custom_prompt_path(language, filename)
|
||||
if not prompt_path.parent.exists() or not prompt_path.parent.is_dir():
|
||||
raise HTTPException(status_code=404, detail="Prompt 语言目录不存在")
|
||||
if not prompt_path.exists() or not prompt_path.is_file():
|
||||
raise HTTPException(status_code=404, detail="Prompt 文件不存在")
|
||||
|
||||
try:
|
||||
prompt_path.write_text(content, encoding="utf-8", newline="\n")
|
||||
return PromptFileResponse(language=language, filename=filename, content=content)
|
||||
custom_prompt_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
custom_prompt_path.write_text(content, encoding="utf-8", newline="\n")
|
||||
return PromptFileResponse(language=language, filename=filename, content=content, customized=True)
|
||||
except Exception as e:
|
||||
logger.error(f"保存 Prompt 文件失败: {prompt_path} {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"保存 Prompt 文件失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.delete("/prompts/{language}/{filename}", response_model=PromptFileResponse)
|
||||
async def reset_prompt_file(language: str, filename: str):
|
||||
"""删除用户自定义覆盖,恢复使用内置 Prompt 模板。"""
|
||||
|
||||
prompt_path = _safe_prompt_path(language, filename)
|
||||
custom_prompt_path = _safe_custom_prompt_path(language, filename)
|
||||
if not prompt_path.exists() or not prompt_path.is_file():
|
||||
raise HTTPException(status_code=404, detail="Prompt 文件不存在")
|
||||
|
||||
try:
|
||||
if custom_prompt_path.exists():
|
||||
custom_prompt_path.unlink()
|
||||
content = prompt_path.read_text(encoding="utf-8")
|
||||
return PromptFileResponse(language=language, filename=filename, content=content, customized=False)
|
||||
except Exception as e:
|
||||
logger.error(f"恢复 Prompt 默认模板失败: {prompt_path} {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"恢复 Prompt 默认模板失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.get("/maisaka-prompt-preview", response_class=FileResponse)
|
||||
async def get_maisaka_prompt_preview(path: str = Query(..., description="logs/maisaka_prompt 下的相对 HTML 路径")):
|
||||
"""打开 MaiSaka 监控中生成的 Prompt HTML 预览。"""
|
||||
|
||||
Reference in New Issue
Block a user