feat:实际应用自定义prompt,修复docker同级目录问题

This commit is contained in:
SengokuCola
2026-05-08 13:05:39 +08:00
parent 2c14fd8d49
commit fb3f4c28ef
8 changed files with 216 additions and 18 deletions

View File

@@ -18,10 +18,12 @@ logger = logging.getLogger("maibot.prompt_i18n")
PROJECT_ROOT = Path(__file__).resolve().parents[2]
PROMPTS_ROOT = (PROJECT_ROOT / "prompts").resolve()
CUSTOM_PROMPTS_ROOT = (PROJECT_ROOT / "data" / "custom_prompts").resolve()
PROMPT_EXTENSIONS = (".prompt",)
SAFE_SEGMENT_PATTERN = re.compile(r"^[A-Za-z0-9_.-]+$")
STRICT_ENV_KEYS = ("MAIBOT_PROMPT_I18N_STRICT", "MAIBOT_I18N_STRICT")
STRICT_ENV_VALUES = {"1", "true", "yes", "on"}
_PROMPT_CACHE_REVISION = 0
extract_prompt_placeholders = extract_placeholders
@@ -43,6 +45,17 @@ def get_prompts_root(prompts_root: Path | None = None) -> Path:
return (prompts_root or PROMPTS_ROOT).resolve()
def get_custom_prompts_root(
custom_prompts_root: Path | None = None,
prompts_root: Path | None = None,
) -> Path:
if custom_prompts_root is not None:
return custom_prompts_root.resolve()
if prompts_root is not None:
return (prompts_root.resolve().parent / "data" / "custom_prompts").resolve()
return CUSTOM_PROMPTS_ROOT
def normalize_prompt_name(name: str) -> str:
candidate_name = name.strip()
for suffix in PROMPT_EXTENSIONS:
@@ -194,6 +207,28 @@ def _iter_locale_candidates(requested_locale: str) -> list[str]:
return locale_candidates
def _iter_prompt_path_candidates(base_dir: Path, name: str, category: str | None = None) -> list[Path]:
candidates: list[Path] = []
for suffix in PROMPT_EXTENSIONS:
if category is not None:
candidates.append((base_dir / category / f"{name}{suffix}").resolve())
candidates.append((base_dir / f"{name}{suffix}").resolve())
return candidates
def _resolve_custom_prompt_path(
name: str,
locale: str,
category: str | None,
custom_prompts_root: Path,
) -> Path | None:
custom_locale_dir = custom_prompts_root / locale
for candidate_path in _iter_prompt_path_candidates(custom_locale_dir, name, category):
if candidate_path.is_file():
return candidate_path
return None
def list_prompt_templates(locale: str | None = None, prompts_root: Path | None = None) -> dict[str, PromptTemplateInfo]:
resolved_prompts_root = get_prompts_root(prompts_root)
requested_locale = normalize_locale(locale or get_locale())
@@ -206,15 +241,29 @@ def list_prompt_templates(locale: str | None = None, prompts_root: Path | None =
def resolve_prompt_path(
name: str, locale: str | None = None, category: str | None = None, prompts_root: Path | None = None
name: str,
locale: str | None = None,
category: str | None = None,
prompts_root: Path | None = None,
custom_prompts_root: Path | None = None,
) -> Path:
resolved_prompts_root = get_prompts_root(prompts_root)
resolved_custom_prompts_root = get_custom_prompts_root(custom_prompts_root, prompts_root)
normalized_name = normalize_prompt_name(name)
normalized_category = normalize_prompt_category(category)
requested_locale = normalize_locale(locale or get_locale())
if normalized_category is not None:
for locale_candidate in _iter_locale_candidates(requested_locale):
custom_path = _resolve_custom_prompt_path(
normalized_name,
locale_candidate,
normalized_category,
resolved_custom_prompts_root,
)
if custom_path is not None:
return custom_path
base_dir = resolved_prompts_root / locale_candidate
for suffix in PROMPT_EXTENSIONS:
candidate_path = (base_dir / normalized_category / f"{normalized_name}{suffix}").resolve()
@@ -226,9 +275,20 @@ def resolve_prompt_path(
if fallback_path.is_file():
return fallback_path
else:
prompt_paths = list_prompt_templates(locale=requested_locale, prompts_root=resolved_prompts_root)
if normalized_name in prompt_paths:
return prompt_paths[normalized_name].path
for locale_candidate in _iter_locale_candidates(requested_locale):
custom_path = _resolve_custom_prompt_path(
normalized_name,
locale_candidate,
None,
resolved_custom_prompts_root,
)
if custom_path is not None:
return custom_path
base_dir = resolved_prompts_root / locale_candidate
for candidate_path in _iter_prompt_path_candidates(base_dir, normalized_name):
if candidate_path.is_file():
return candidate_path
raise FileNotFoundError(t("prompt.template_not_found", locale=requested_locale, name=normalized_name))
@@ -263,13 +323,26 @@ def load_prompt(
locale: str | None = None,
category: str | None = None,
prompts_root: Path | None = None,
custom_prompts_root: Path | None = None,
**kwargs: object,
) -> str:
normalized_name = normalize_prompt_name(name)
prompt_path = resolve_prompt_path(name=normalized_name, locale=locale, category=category, prompts_root=prompts_root)
prompt_path = resolve_prompt_path(
name=normalized_name,
locale=locale,
category=category,
prompts_root=prompts_root,
custom_prompts_root=custom_prompts_root,
)
template = _read_prompt_template(prompt_path)
return _format_prompt_template(normalized_name, template, **kwargs)
def clear_prompt_cache() -> None:
global _PROMPT_CACHE_REVISION
_PROMPT_CACHE_REVISION += 1
_read_prompt_template.cache_clear()
def get_prompt_cache_revision() -> int:
return _PROMPT_CACHE_REVISION