feat: 增强国际化验证功能,添加对共享翻译字符串的支持,优化提示模板加载逻辑

This commit is contained in:
春河晴
2026-03-13 01:42:22 +09:00
parent 55eb911dd3
commit d418a8451b
6 changed files with 217 additions and 178 deletions

View File

@@ -21,8 +21,8 @@ def write_locale_file(locales_root: Path, locale: str, file_name: str, payload:
def test_validate_json_locales_rejects_han_characters_in_english_locale(tmp_path: Path) -> None: def test_validate_json_locales_rejects_han_characters_in_english_locale(tmp_path: Path) -> None:
locales_root = tmp_path / "locales" locales_root = tmp_path / "locales"
write_locale_file(locales_root, "zh-CN", "core.json", {"consent.prompt": "输入\"同意\"继续"}) write_locale_file(locales_root, "zh-CN", "core.json", {"consent.prompt": '输入"同意"继续'})
write_locale_file(locales_root, "en-US", "core.json", {"consent.prompt": "Type \"confirmed\" or \"同意\" to continue"}) write_locale_file(locales_root, "en-US", "core.json", {"consent.prompt": 'Type "confirmed" or "同意" to continue'})
errors = I18N_VALIDATE.validate_json_locales(locales_root) errors = I18N_VALIDATE.validate_json_locales(locales_root)
@@ -37,3 +37,34 @@ def test_validate_json_locales_rejects_untranslated_han_source_in_other_target_l
errors = I18N_VALIDATE.validate_json_locales(locales_root) errors = I18N_VALIDATE.validate_json_locales(locales_root)
assert any("greeting" in error and "直接保留了包含中文字符的 source 文案" in error for error in errors) assert any("greeting" in error and "直接保留了包含中文字符的 source 文案" in error for error in errors)
def test_validate_json_locales_avoids_false_positive_when_plural_categories_do_not_align(tmp_path: Path) -> None:
locales_root = tmp_path / "locales"
write_locale_file(
locales_root,
"zh-CN",
"core.json",
{
"tasks.cancelled": {
"one": "中文单数",
"other": "中文复数",
}
},
)
write_locale_file(
locales_root,
"ja",
"core.json",
{
"tasks.cancelled": {
"many": "中文单数",
"other": "已翻译",
}
},
)
errors = I18N_VALIDATE.validate_json_locales(locales_root)
assert any("tasks.cancelled" in error and "plural category 不一致" in error for error in errors)
assert not any("tasks.cancelled" in error and "直接保留了包含中文字符的 source 文案" in error for error in errors)

View File

@@ -18,9 +18,10 @@ from src.common.i18n.loaders import ( # noqa: E402
) )
from src.common.i18n.loaders import extract_placeholders # noqa: E402 from src.common.i18n.loaders import extract_placeholders # noqa: E402
from src.common.prompt_i18n import ( # noqa: E402 from src.common.prompt_i18n import ( # noqa: E402
PROMPT_EXTENSIONS, discover_prompt_locales,
extract_prompt_placeholders, extract_prompt_placeholders,
get_prompts_root, get_prompts_root,
iter_prompt_files,
) )
HAN_CHARACTER_PATTERN = re.compile(r"[\u3400-\u4DBF\u4E00-\u9FFF\uF900-\uFAFF]") HAN_CHARACTER_PATTERN = re.compile(r"[\u3400-\u4DBF\u4E00-\u9FFF\uF900-\uFAFF]")
@@ -36,6 +37,18 @@ def iter_translation_strings(value: TranslationValue) -> list[str]:
return [value[category] for category in sorted(value.keys())] return [value[category] for category in sorted(value.keys())]
def iter_shared_translation_strings(
source_value: TranslationValue, target_value: TranslationValue
) -> list[tuple[str, str]]:
if isinstance(source_value, str) or isinstance(target_value, str):
if isinstance(source_value, str) and isinstance(target_value, str):
return [(source_value, target_value)]
return []
shared_categories = sorted(set(source_value.keys()) & set(target_value.keys()))
return [(source_value[category], target_value[category]) for category in shared_categories]
def locale_requires_latin_only_validation(locale: str) -> bool: def locale_requires_latin_only_validation(locale: str) -> bool:
normalized_locale = locale.lower() normalized_locale = locale.lower()
return normalized_locale == "en" or normalized_locale.startswith("en-") return normalized_locale == "en" or normalized_locale.startswith("en-")
@@ -48,14 +61,15 @@ def validate_locale_content(
locale: str, locale: str,
errors: list[str], errors: list[str],
) -> None: ) -> None:
source_texts = iter_translation_strings(source_value)
target_texts = iter_translation_strings(target_value) target_texts = iter_translation_strings(target_value)
if any( if any(
source_text == target_text and contains_han_characters(source_text) source_text == target_text and contains_han_characters(source_text)
for source_text, target_text in zip(source_texts, target_texts, strict=False) for source_text, target_text in iter_shared_translation_strings(source_value, target_value)
): ):
errors.append(f"[{locale}] key '{key}' 直接保留了包含中文字符的 source 文案(仓库级校验策略),请提供目标语言翻译") errors.append(
f"[{locale}] key '{key}' 直接保留了包含中文字符的 source 文案(仓库级校验策略),请提供目标语言翻译"
)
if locale_requires_latin_only_validation(locale) and any(contains_han_characters(text) for text in target_texts): if locale_requires_latin_only_validation(locale) and any(contains_han_characters(text) for text in target_texts):
errors.append(f"[{locale}] key '{key}' 仍包含中文字符,请移除源语言残留后再提交") errors.append(f"[{locale}] key '{key}' 仍包含中文字符,请移除源语言残留后再提交")
@@ -121,12 +135,9 @@ def validate_json_locales(locales_root: Path | None = None) -> list[str]:
continue continue
locale_keys = set(catalog.keys()) locale_keys = set(catalog.keys())
missing_keys = sorted(source_keys - locale_keys) for key in sorted(source_keys - locale_keys):
extra_keys = sorted(locale_keys - source_keys)
for key in missing_keys:
errors.append(f"[{locale}] 缺少 key: {key}") errors.append(f"[{locale}] 缺少 key: {key}")
for key in extra_keys: for key in sorted(locale_keys - source_keys):
errors.append(f"[{locale}] 存在多余 key: {key}") errors.append(f"[{locale}] 存在多余 key: {key}")
for key in sorted(source_keys & locale_keys): for key in sorted(source_keys & locale_keys):
@@ -139,25 +150,13 @@ def validate_json_locales(locales_root: Path | None = None) -> list[str]:
return errors return errors
def discover_prompt_locales(prompts_root: Path | None = None) -> list[str]: def build_prompt_catalog(locale_dir: Path) -> dict[Path, Path]:
resolved_prompts_root = get_prompts_root(prompts_root) return {path.relative_to(locale_dir): path for path in iter_prompt_files(locale_dir)}
if not resolved_prompts_root.exists():
return []
locale_names = [path.name for path in resolved_prompts_root.iterdir() if path.is_dir()]
return sorted(locale_names)
def iter_prompt_files(locale_dir: Path) -> list[Path]:
prompt_files: list[Path] = []
for extension in PROMPT_EXTENSIONS:
prompt_files.extend(path for path in locale_dir.rglob(f"*{extension}") if path.is_file())
return sorted(set(prompt_files))
def validate_prompt_templates(prompts_root: Path | None = None) -> tuple[list[str], list[str]]: def validate_prompt_templates(prompts_root: Path | None = None) -> tuple[list[str], list[str]]:
resolved_prompts_root = get_prompts_root(prompts_root) resolved_prompts_root = get_prompts_root(prompts_root)
prompt_locales = discover_prompt_locales(resolved_prompts_root) prompt_locales = set(discover_prompt_locales(resolved_prompts_root))
known_locales = [locale for locale in discover_locales(get_locales_root()) if locale != DEFAULT_LOCALE] known_locales = [locale for locale in discover_locales(get_locales_root()) if locale != DEFAULT_LOCALE]
errors: list[str] = [] errors: list[str] = []
warnings: list[str] = [] warnings: list[str] = []
@@ -167,7 +166,8 @@ def validate_prompt_templates(prompts_root: Path | None = None) -> tuple[list[st
return errors, warnings return errors, warnings
source_dir = resolved_prompts_root / DEFAULT_LOCALE source_dir = resolved_prompts_root / DEFAULT_LOCALE
source_files = {path.relative_to(source_dir): path for path in iter_prompt_files(source_dir)} source_files = build_prompt_catalog(source_dir)
source_relative_paths = set(source_files.keys())
for locale in known_locales: for locale in known_locales:
locale_dir = resolved_prompts_root / locale locale_dir = resolved_prompts_root / locale
@@ -175,8 +175,7 @@ def validate_prompt_templates(prompts_root: Path | None = None) -> tuple[list[st
warnings.append(f"[prompt:{locale}] 缺少 locale 目录,运行时将回退到 {DEFAULT_LOCALE}") warnings.append(f"[prompt:{locale}] 缺少 locale 目录,运行时将回退到 {DEFAULT_LOCALE}")
continue continue
locale_files = {path.relative_to(locale_dir): path for path in iter_prompt_files(locale_dir)} locale_files = build_prompt_catalog(locale_dir)
source_relative_paths = set(source_files.keys())
locale_relative_paths = set(locale_files.keys()) locale_relative_paths = set(locale_files.keys())
for relative_path in sorted(source_relative_paths - locale_relative_paths): for relative_path in sorted(source_relative_paths - locale_relative_paths):

View File

@@ -3,18 +3,16 @@ from __future__ import annotations
from collections.abc import Iterator from collections.abc import Iterator
from datetime import date, datetime from datetime import date, datetime
from decimal import Decimal from decimal import Decimal
from functools import lru_cache
from .loaders import DEFAULT_LOCALE from .loaders import DEFAULT_LOCALE
@lru_cache(maxsize=1)
def _get_manager(): def _get_manager():
from .manager import I18nManager from .manager import I18nManager
manager = getattr(_get_manager, "_manager", None) return I18nManager()
if manager is None:
manager = I18nManager()
_get_manager._manager = manager
return manager
def set_locale(locale: str) -> str: def set_locale(locale: str) -> str:

View File

@@ -67,22 +67,12 @@ class I18nManager:
self._catalog_cache.pop(normalize_locale(locale), None) self._catalog_cache.pop(normalize_locale(locale), None)
def t(self, key: str, locale: str | None = None, **kwargs: object) -> str: def t(self, key: str, locale: str | None = None, **kwargs: object) -> str:
translation_value, _ = self._get_translation_value(key, locale) translation_value, translation_locale = self._get_translation_value(key, locale)
if translation_value is None: template = self._get_standard_template(key, translation_value, translation_locale)
if template is None:
return key return key
if isinstance(translation_value, dict): return self._format_translation(key, template, kwargs)
template = translation_value.get("other")
if template is None:
self._log_once(
("plural_missing_other", self.get_locale(), key),
logging.WARNING,
"翻译 key '%s' 缺少 other plural category已回退到 key 本身",
key,
)
return key
return self._format_translation(key, template, kwargs)
return self._format_translation(key, translation_value, kwargs)
def tn(self, key: str, count: int | float, locale: str | None = None, **kwargs: object) -> str: def tn(self, key: str, count: int | float, locale: str | None = None, **kwargs: object) -> str:
translation_value, translation_locale = self._get_translation_value(key, locale) translation_value, translation_locale = self._get_translation_value(key, locale)
@@ -118,6 +108,27 @@ class I18nManager:
formatting_kwargs["count"] = count formatting_kwargs["count"] = count
return self._format_translation(key, template, formatting_kwargs) return self._format_translation(key, template, formatting_kwargs)
def _get_standard_template(
self,
key: str,
translation_value: TranslationValue | None,
translation_locale: str,
) -> str | None:
if translation_value is None:
return None
if not isinstance(translation_value, dict):
return translation_value
template = translation_value.get("other")
if template is None:
self._log_once(
("plural_missing_other", translation_locale, key),
logging.WARNING,
"翻译 key '%s' 缺少 other plural category已回退到 key 本身",
key,
)
return template
def _format_translation(self, key: str, template: str, kwargs: dict[str, object]) -> str: def _format_translation(self, key: str, template: str, kwargs: dict[str, object]) -> str:
try: try:
return format_template(template, **kwargs) return format_template(template, **kwargs)
@@ -161,14 +172,15 @@ class I18nManager:
try: try:
return normalize_locale(locale) return normalize_locale(locale)
except InvalidLocaleError: except InvalidLocaleError:
current_locale = self.get_locale()
self._log_once( self._log_once(
("invalid_locale", "explicit", locale), ("invalid_locale", "explicit", locale),
logging.WARNING, logging.WARNING,
"检测到非法 locale='%s',已回退到当前默认 locale %s", "检测到非法 locale='%s',已回退到当前默认 locale %s",
locale, locale,
self.get_locale(), current_locale,
) )
return self.get_locale() return current_locale
def _get_catalog(self, locale: str) -> dict[str, TranslationValue]: def _get_catalog(self, locale: str) -> dict[str, TranslationValue]:
normalized_locale = normalize_locale(locale) normalized_locale = normalize_locale(locale)

View File

@@ -1,14 +1,14 @@
from __future__ import annotations from __future__ import annotations
from functools import lru_cache
from pathlib import Path from pathlib import Path
import logging import logging
import os import os
import re import re
import threading
from .i18n import get_locale, t from .i18n import get_locale, t
from .i18n.loaders import DEFAULT_LOCALE, extract_placeholders as extract_prompt_placeholders, normalize_locale from .i18n.loaders import DEFAULT_LOCALE, extract_placeholders, normalize_locale
logger = logging.getLogger("maibot.prompt_i18n") logger = logging.getLogger("maibot.prompt_i18n")
@@ -17,9 +17,9 @@ PROMPTS_ROOT = (PROJECT_ROOT / "prompts").resolve()
PROMPT_EXTENSIONS = (".prompt",) PROMPT_EXTENSIONS = (".prompt",)
SAFE_SEGMENT_PATTERN = re.compile(r"^[A-Za-z0-9_.-]+$") SAFE_SEGMENT_PATTERN = re.compile(r"^[A-Za-z0-9_.-]+$")
STRICT_ENV_KEYS = ("MAIBOT_PROMPT_I18N_STRICT", "MAIBOT_I18N_STRICT") STRICT_ENV_KEYS = ("MAIBOT_PROMPT_I18N_STRICT", "MAIBOT_I18N_STRICT")
STRICT_ENV_VALUES = {"1", "true", "yes", "on"}
_prompt_cache: dict[Path, str] = {} extract_prompt_placeholders = extract_placeholders
_cache_lock = threading.RLock()
def get_prompts_root(prompts_root: Path | None = None) -> Path: def get_prompts_root(prompts_root: Path | None = None) -> Path:
@@ -56,91 +56,100 @@ def is_strict_prompt_i18n_mode() -> bool:
if os.getenv("PYTEST_CURRENT_TEST"): if os.getenv("PYTEST_CURRENT_TEST"):
return True return True
return any(os.getenv(env_key, "").strip().lower() in {"1", "true", "yes", "on"} for env_key in STRICT_ENV_KEYS) return any(os.getenv(env_key, "").strip().lower() in STRICT_ENV_VALUES for env_key in STRICT_ENV_KEYS)
def _supported_prompt_files(directory: Path, recursive: bool = True) -> list[Path]: def discover_prompt_locales(prompts_root: Path | None = None) -> list[str]:
resolved_prompts_root = get_prompts_root(prompts_root)
if not resolved_prompts_root.exists():
return []
locale_names = [path.name for path in resolved_prompts_root.iterdir() if path.is_dir()]
return sorted(locale_names)
def iter_prompt_files(directory: Path, recursive: bool = True) -> list[Path]:
if not directory.exists():
return []
search = directory.rglob if recursive else directory.glob search = directory.rglob if recursive else directory.glob
matched_files: list[Path] = [] prompt_files: list[Path] = []
for suffix in PROMPT_EXTENSIONS: for suffix in PROMPT_EXTENSIONS:
matched_files.extend(path for path in search(f"*{suffix}") if path.is_file()) prompt_files.extend(path for path in search(f"*{suffix}") if path.is_file())
return sorted(set(matched_files)) return sorted(set(prompt_files))
def _scan_prompt_directory(directory: Path, prompts_root: Path) -> dict[str, Path]: def _raise_duplicate_prompt_name(name: str, first_path: Path, second_path: Path, prompts_root: Path) -> None:
raise ValueError(
t(
"prompt.duplicate_template_name",
name=name,
path_a=first_path.relative_to(prompts_root),
path_b=second_path.relative_to(prompts_root),
)
)
def _scan_prompt_directory(directory: Path, prompts_root: Path, recursive: bool = True) -> dict[str, Path]:
prompt_paths: dict[str, Path] = {} prompt_paths: dict[str, Path] = {}
if not directory.exists(): for prompt_path in iter_prompt_files(directory, recursive=recursive):
return prompt_paths
for prompt_path in _supported_prompt_files(directory):
prompt_name = prompt_path.stem prompt_name = prompt_path.stem
if prompt_name in prompt_paths: existing_path = prompt_paths.get(prompt_name)
raise ValueError( if existing_path is not None:
t( _raise_duplicate_prompt_name(prompt_name, existing_path, prompt_path, prompts_root)
"prompt.duplicate_template_name",
name=prompt_name,
path_a=prompt_paths[prompt_name].relative_to(prompts_root),
path_b=prompt_path.relative_to(prompts_root),
)
)
prompt_paths[prompt_name] = prompt_path prompt_paths[prompt_name] = prompt_path
return prompt_paths return prompt_paths
def _scan_legacy_prompt_directory(directory: Path, prompts_root: Path) -> dict[str, Path]: def _iter_prompt_template_layers(prompts_root: Path, requested_locale: str) -> list[tuple[Path, bool]]:
prompt_paths: dict[str, Path] = {} prompt_layers: list[tuple[Path, bool]] = [
if not directory.exists(): (prompts_root, False),
return prompt_paths (prompts_root / DEFAULT_LOCALE, True),
]
if requested_locale != DEFAULT_LOCALE:
prompt_layers.append((prompts_root / requested_locale, True))
return prompt_layers
for prompt_path in _supported_prompt_files(directory, recursive=False):
prompt_name = prompt_path.stem def _iter_locale_candidates(requested_locale: str) -> list[str | None]:
if prompt_name in prompt_paths: locale_candidates: list[str | None] = [requested_locale]
raise ValueError( if requested_locale != DEFAULT_LOCALE:
t( locale_candidates.append(DEFAULT_LOCALE)
"prompt.duplicate_template_name", locale_candidates.append(None)
name=prompt_name, return locale_candidates
path_a=prompt_paths[prompt_name].relative_to(prompts_root),
path_b=prompt_path.relative_to(prompts_root),
)
)
prompt_paths[prompt_name] = prompt_path
return prompt_paths
def list_prompt_templates(locale: str | None = None, prompts_root: Path | None = None) -> dict[str, Path]: def list_prompt_templates(locale: str | None = None, prompts_root: Path | None = None) -> dict[str, Path]:
resolved_prompts_root = get_prompts_root(prompts_root) resolved_prompts_root = get_prompts_root(prompts_root)
requested_locale = normalize_locale(locale or get_locale()) requested_locale = normalize_locale(locale or get_locale())
prompt_paths = _scan_legacy_prompt_directory(resolved_prompts_root, resolved_prompts_root) prompt_paths: dict[str, Path] = {}
prompt_paths.update(_scan_prompt_directory(resolved_prompts_root / DEFAULT_LOCALE, resolved_prompts_root)) for directory, recursive in _iter_prompt_template_layers(resolved_prompts_root, requested_locale):
prompt_paths.update(_scan_prompt_directory(directory, resolved_prompts_root, recursive=recursive))
if requested_locale != DEFAULT_LOCALE:
prompt_paths.update(_scan_prompt_directory(resolved_prompts_root / requested_locale, resolved_prompts_root))
return prompt_paths return prompt_paths
def resolve_prompt_path(name: str, locale: str | None = None, category: str | None = None, prompts_root: Path | None = None) -> Path: def resolve_prompt_path(
name: str, locale: str | None = None, category: str | None = None, prompts_root: Path | None = None
) -> Path:
resolved_prompts_root = get_prompts_root(prompts_root) resolved_prompts_root = get_prompts_root(prompts_root)
normalized_name = normalize_prompt_name(name) normalized_name = normalize_prompt_name(name)
normalized_category = normalize_prompt_category(category) normalized_category = normalize_prompt_category(category)
requested_locale = normalize_locale(locale or get_locale()) requested_locale = normalize_locale(locale or get_locale())
locale_candidates: list[str | None] = [requested_locale]
if requested_locale != DEFAULT_LOCALE:
locale_candidates.append(DEFAULT_LOCALE)
locale_candidates.append(None)
if normalized_category is not None: if normalized_category is not None:
for locale_candidate in locale_candidates: for locale_candidate in _iter_locale_candidates(requested_locale):
base_dir = resolved_prompts_root if locale_candidate is None else resolved_prompts_root / locale_candidate base_dir = resolved_prompts_root if locale_candidate is None else resolved_prompts_root / locale_candidate
for suffix in PROMPT_EXTENSIONS: for suffix in PROMPT_EXTENSIONS:
candidate_paths = [(base_dir / normalized_category / f"{normalized_name}{suffix}").resolve()] candidate_path = (base_dir / normalized_category / f"{normalized_name}{suffix}").resolve()
if candidate_path.is_file():
return candidate_path
# 允许带 category 的调用在旧版平铺目录或未迁移完的 locale 目录中继续工作。 # 允许带 category 的调用在旧版平铺目录或未迁移完的 locale 目录中继续工作。
candidate_paths.append((base_dir / f"{normalized_name}{suffix}").resolve()) fallback_path = (base_dir / f"{normalized_name}{suffix}").resolve()
for candidate_path in candidate_paths: if fallback_path.is_file():
if candidate_path.is_file(): return fallback_path
return candidate_path
else: else:
prompt_paths = list_prompt_templates(locale=requested_locale, prompts_root=resolved_prompts_root) prompt_paths = list_prompt_templates(locale=requested_locale, prompts_root=resolved_prompts_root)
if normalized_name in prompt_paths: if normalized_name in prompt_paths:
@@ -149,6 +158,31 @@ def resolve_prompt_path(name: str, locale: str | None = None, category: str | No
raise FileNotFoundError(t("prompt.template_not_found", locale=requested_locale, name=normalized_name)) raise FileNotFoundError(t("prompt.template_not_found", locale=requested_locale, name=normalized_name))
@lru_cache(maxsize=None)
def _read_prompt_template(prompt_path: Path) -> str:
return prompt_path.read_text(encoding="utf-8")
def _format_prompt_template(name: str, template: str, **kwargs: object) -> str:
if not kwargs:
return template
try:
return template.format(**kwargs)
except KeyError as exc:
missing_placeholder = exc.args[0]
error = KeyError(t("prompt.missing_placeholder", name=name, placeholder=missing_placeholder))
if is_strict_prompt_i18n_mode():
raise error from exc
logger.error("%s", error)
return template
except Exception as exc:
logger.error(t("prompt.format_failed", name=name, error=exc))
if is_strict_prompt_i18n_mode():
raise
return template
def load_prompt( def load_prompt(
name: str, name: str,
locale: str | None = None, locale: str | None = None,
@@ -156,40 +190,11 @@ def load_prompt(
prompts_root: Path | None = None, prompts_root: Path | None = None,
**kwargs: object, **kwargs: object,
) -> str: ) -> str:
prompt_path = resolve_prompt_path(name=name, locale=locale, category=category, prompts_root=prompts_root) normalized_name = normalize_prompt_name(name)
with _cache_lock: prompt_path = resolve_prompt_path(name=normalized_name, locale=locale, category=category, prompts_root=prompts_root)
template = _prompt_cache.get(prompt_path) template = _read_prompt_template(prompt_path)
if template is None: return _format_prompt_template(normalized_name, template, **kwargs)
template = prompt_path.read_text(encoding="utf-8")
with _cache_lock:
_prompt_cache.setdefault(prompt_path, template)
template = _prompt_cache[prompt_path]
if not kwargs:
return template
try:
return template.format(**kwargs)
except KeyError as exc:
missing_placeholder = exc.args[0]
error = KeyError(
t(
"prompt.missing_placeholder",
name=normalize_prompt_name(name),
placeholder=missing_placeholder,
)
)
if is_strict_prompt_i18n_mode():
raise error from exc
logger.error("%s", error)
return template
except Exception as exc:
logger.error(t("prompt.format_failed", name=normalize_prompt_name(name), error=exc))
if is_strict_prompt_i18n_mode():
raise
return template
def clear_prompt_cache() -> None: def clear_prompt_cache() -> None:
with _cache_lock: _read_prompt_template.cache_clear()
_prompt_cache.clear()

View File

@@ -1,12 +1,12 @@
from collections.abc import Callable, Coroutine
from pathlib import Path from pathlib import Path
from string import Formatter from string import Formatter
from typing import Any, Optional from typing import Any, Optional
from collections.abc import Callable, Coroutine
import inspect import inspect
from src.common.prompt_i18n import list_prompt_templates, load_prompt
from src.common.logger import get_logger from src.common.logger import get_logger
from src.common.prompt_i18n import list_prompt_templates, load_prompt
logger = get_logger("Prompt") logger = get_logger("Prompt")
@@ -14,7 +14,7 @@ logger = get_logger("Prompt")
_LEFT_BRACE = "\ufde9" _LEFT_BRACE = "\ufde9"
_RIGHT_BRACE = "\ufdea" _RIGHT_BRACE = "\ufdea"
PROJECT_ROOT = Path(__file__).parent.parent.parent.absolute().resolve() PROJECT_ROOT = Path(__file__).resolve().parents[2]
PROMPTS_DIR = PROJECT_ROOT / "prompts" PROMPTS_DIR = PROJECT_ROOT / "prompts"
DATA_DIR = PROJECT_ROOT / "data" DATA_DIR = PROJECT_ROOT / "data"
CUSTOM_PROMPTS_DIR = DATA_DIR / "custom_prompts" CUSTOM_PROMPTS_DIR = DATA_DIR / "custom_prompts"
@@ -240,18 +240,23 @@ class PromptManager:
for prompt_file in CUSTOM_PROMPTS_DIR.glob(f"*{SUFFIX_PROMPT}"): for prompt_file in CUSTOM_PROMPTS_DIR.glob(f"*{SUFFIX_PROMPT}"):
try: try:
prompt_file.unlink() prompt_file.unlink()
except Exception as e: except Exception as exc:
logger.error(f"删除自定义 Prompt 文件 '{prompt_file}' 时出错,错误信息: {e}") logger.error(f"删除自定义 Prompt 文件 '{prompt_file}' 时出错,错误信息: {exc}")
raise e raise
for prompt_name in self._prompt_to_save: for prompt_name in self._prompt_to_save:
prompt = self.prompts[prompt_name] prompt = self.prompts[prompt_name]
file_path = CUSTOM_PROMPTS_DIR / f"{prompt_name}{SUFFIX_PROMPT}" file_path = CUSTOM_PROMPTS_DIR / f"{prompt_name}{SUFFIX_PROMPT}"
try: try:
with open(file_path, "w", encoding="utf-8") as f: file_path.write_text(prompt.template, encoding="utf-8")
f.write(prompt.template) except Exception as exc:
except Exception as e: logger.error(f"保存 Prompt '{prompt_name}' 时出错,文件路径: '{file_path}',错误信息: {exc}")
logger.error(f"保存 Prompt '{prompt_name}' 时出错,文件路径: '{file_path}',错误信息: {e}") raise
raise e
def _load_prompt_template(self, prompt_name: str) -> tuple[str, bool]:
custom_prompt_path = CUSTOM_PROMPTS_DIR / f"{prompt_name}{SUFFIX_PROMPT}"
if custom_prompt_path.exists():
return custom_prompt_path.read_text(encoding="utf-8"), True
return load_prompt(prompt_name, prompts_root=PROMPTS_DIR), False
def load_prompts(self) -> None: def load_prompts(self) -> None:
""" """
@@ -259,34 +264,23 @@ class PromptManager:
Raises: Raises:
Exception: 如果在加载过程中出现任何文件操作错误则引发该异常 Exception: 如果在加载过程中出现任何文件操作错误则引发该异常
""" """
prompt_files = list_prompt_templates(prompts_root=PROMPTS_DIR) prompt_templates = list_prompt_templates(prompts_root=PROMPTS_DIR)
for prompt_name, prompt_file in prompt_files.items(): for prompt_name, prompt_file in prompt_templates.items():
try: try:
prompt_to_load = prompt_file template, need_save = self._load_prompt_template(prompt_name)
need_save = False
custom_prompt_path = CUSTOM_PROMPTS_DIR / f"{prompt_name}{SUFFIX_PROMPT}"
if custom_prompt_path.exists():
# 优先加载自定义目录下的 Prompt 文件
prompt_to_load = custom_prompt_path
need_save = True
with open(prompt_to_load, "r", encoding="utf-8") as f:
template = f.read()
else:
template = load_prompt(prompt_name, prompts_root=PROMPTS_DIR)
self.add_prompt(Prompt(prompt_name=prompt_name, template=template), need_save=need_save) self.add_prompt(Prompt(prompt_name=prompt_name, template=template), need_save=need_save)
except Exception as e: except Exception as exc:
logger.error(f"加载 Prompt 文件 '{prompt_file}' 时出错,错误信息: {e}") logger.error(f"加载 Prompt 文件 '{prompt_file}' 时出错,错误信息: {exc}")
raise e raise
for prompt_file in CUSTOM_PROMPTS_DIR.glob(f"*{SUFFIX_PROMPT}"): for prompt_file in CUSTOM_PROMPTS_DIR.glob(f"*{SUFFIX_PROMPT}"):
if prompt_file.stem in prompt_files: if prompt_file.stem in prompt_templates:
continue # 已经加载过了,跳过 continue # 已经加载过了,跳过
try: try:
with open(prompt_file, "r", encoding="utf-8") as f: template = prompt_file.read_text(encoding="utf-8")
template = f.read()
self.add_prompt(Prompt(prompt_name=prompt_file.stem, template=template), need_save=True) self.add_prompt(Prompt(prompt_name=prompt_file.stem, template=template), need_save=True)
except Exception as e: except Exception as exc:
logger.error(f"加载自定义 Prompt 文件 '{prompt_file}' 时出错,错误信息: {e}") logger.error(f"加载自定义 Prompt 文件 '{prompt_file}' 时出错,错误信息: {exc}")
raise e raise
async def _get_function_result( async def _get_function_result(
self, self,
@@ -301,12 +295,12 @@ class PromptManager:
if isinstance(res, Coroutine): if isinstance(res, Coroutine):
res = await res res = await res
return res return res
except Exception as e: except Exception as exc:
if is_prompt_context: if is_prompt_context:
logger.error(f"调用 Prompt '{prompt_name}' 内部上下文构造函数 '{field_name}' 时出错,错误信息: {e}") logger.error(f"调用 Prompt '{prompt_name}' 内部上下文构造函数 '{field_name}' 时出错,错误信息: {exc}")
else: else:
logger.error(f"调用上下文构造函数 '{field_name}' 时出错,所属模块: '{module}',错误信息: {e}") logger.error(f"调用上下文构造函数 '{field_name}' 时出错,所属模块: '{module}',错误信息: {exc}")
raise e raise
prompt_manager = PromptManager() prompt_manager = PromptManager()