feat: 增强国际化验证功能,添加对共享翻译字符串的支持,优化提示模板加载逻辑
This commit is contained in:
@@ -21,8 +21,8 @@ def write_locale_file(locales_root: Path, locale: str, file_name: str, payload:
|
|||||||
|
|
||||||
def test_validate_json_locales_rejects_han_characters_in_english_locale(tmp_path: Path) -> None:
|
def test_validate_json_locales_rejects_han_characters_in_english_locale(tmp_path: Path) -> None:
|
||||||
locales_root = tmp_path / "locales"
|
locales_root = tmp_path / "locales"
|
||||||
write_locale_file(locales_root, "zh-CN", "core.json", {"consent.prompt": "输入\"同意\"继续"})
|
write_locale_file(locales_root, "zh-CN", "core.json", {"consent.prompt": '输入"同意"继续'})
|
||||||
write_locale_file(locales_root, "en-US", "core.json", {"consent.prompt": "Type \"confirmed\" or \"同意\" to continue"})
|
write_locale_file(locales_root, "en-US", "core.json", {"consent.prompt": 'Type "confirmed" or "同意" to continue'})
|
||||||
|
|
||||||
errors = I18N_VALIDATE.validate_json_locales(locales_root)
|
errors = I18N_VALIDATE.validate_json_locales(locales_root)
|
||||||
|
|
||||||
@@ -37,3 +37,34 @@ def test_validate_json_locales_rejects_untranslated_han_source_in_other_target_l
|
|||||||
errors = I18N_VALIDATE.validate_json_locales(locales_root)
|
errors = I18N_VALIDATE.validate_json_locales(locales_root)
|
||||||
|
|
||||||
assert any("greeting" in error and "直接保留了包含中文字符的 source 文案" in error for error in errors)
|
assert any("greeting" in error and "直接保留了包含中文字符的 source 文案" in error for error in errors)
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_json_locales_avoids_false_positive_when_plural_categories_do_not_align(tmp_path: Path) -> None:
|
||||||
|
locales_root = tmp_path / "locales"
|
||||||
|
write_locale_file(
|
||||||
|
locales_root,
|
||||||
|
"zh-CN",
|
||||||
|
"core.json",
|
||||||
|
{
|
||||||
|
"tasks.cancelled": {
|
||||||
|
"one": "中文单数",
|
||||||
|
"other": "中文复数",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
write_locale_file(
|
||||||
|
locales_root,
|
||||||
|
"ja",
|
||||||
|
"core.json",
|
||||||
|
{
|
||||||
|
"tasks.cancelled": {
|
||||||
|
"many": "中文单数",
|
||||||
|
"other": "已翻译",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
errors = I18N_VALIDATE.validate_json_locales(locales_root)
|
||||||
|
|
||||||
|
assert any("tasks.cancelled" in error and "plural category 不一致" in error for error in errors)
|
||||||
|
assert not any("tasks.cancelled" in error and "直接保留了包含中文字符的 source 文案" in error for error in errors)
|
||||||
|
|||||||
@@ -18,9 +18,10 @@ from src.common.i18n.loaders import ( # noqa: E402
|
|||||||
)
|
)
|
||||||
from src.common.i18n.loaders import extract_placeholders # noqa: E402
|
from src.common.i18n.loaders import extract_placeholders # noqa: E402
|
||||||
from src.common.prompt_i18n import ( # noqa: E402
|
from src.common.prompt_i18n import ( # noqa: E402
|
||||||
PROMPT_EXTENSIONS,
|
discover_prompt_locales,
|
||||||
extract_prompt_placeholders,
|
extract_prompt_placeholders,
|
||||||
get_prompts_root,
|
get_prompts_root,
|
||||||
|
iter_prompt_files,
|
||||||
)
|
)
|
||||||
|
|
||||||
HAN_CHARACTER_PATTERN = re.compile(r"[\u3400-\u4DBF\u4E00-\u9FFF\uF900-\uFAFF]")
|
HAN_CHARACTER_PATTERN = re.compile(r"[\u3400-\u4DBF\u4E00-\u9FFF\uF900-\uFAFF]")
|
||||||
@@ -36,6 +37,18 @@ def iter_translation_strings(value: TranslationValue) -> list[str]:
|
|||||||
return [value[category] for category in sorted(value.keys())]
|
return [value[category] for category in sorted(value.keys())]
|
||||||
|
|
||||||
|
|
||||||
|
def iter_shared_translation_strings(
|
||||||
|
source_value: TranslationValue, target_value: TranslationValue
|
||||||
|
) -> list[tuple[str, str]]:
|
||||||
|
if isinstance(source_value, str) or isinstance(target_value, str):
|
||||||
|
if isinstance(source_value, str) and isinstance(target_value, str):
|
||||||
|
return [(source_value, target_value)]
|
||||||
|
return []
|
||||||
|
|
||||||
|
shared_categories = sorted(set(source_value.keys()) & set(target_value.keys()))
|
||||||
|
return [(source_value[category], target_value[category]) for category in shared_categories]
|
||||||
|
|
||||||
|
|
||||||
def locale_requires_latin_only_validation(locale: str) -> bool:
|
def locale_requires_latin_only_validation(locale: str) -> bool:
|
||||||
normalized_locale = locale.lower()
|
normalized_locale = locale.lower()
|
||||||
return normalized_locale == "en" or normalized_locale.startswith("en-")
|
return normalized_locale == "en" or normalized_locale.startswith("en-")
|
||||||
@@ -48,14 +61,15 @@ def validate_locale_content(
|
|||||||
locale: str,
|
locale: str,
|
||||||
errors: list[str],
|
errors: list[str],
|
||||||
) -> None:
|
) -> None:
|
||||||
source_texts = iter_translation_strings(source_value)
|
|
||||||
target_texts = iter_translation_strings(target_value)
|
target_texts = iter_translation_strings(target_value)
|
||||||
|
|
||||||
if any(
|
if any(
|
||||||
source_text == target_text and contains_han_characters(source_text)
|
source_text == target_text and contains_han_characters(source_text)
|
||||||
for source_text, target_text in zip(source_texts, target_texts, strict=False)
|
for source_text, target_text in iter_shared_translation_strings(source_value, target_value)
|
||||||
):
|
):
|
||||||
errors.append(f"[{locale}] key '{key}' 直接保留了包含中文字符的 source 文案(仓库级校验策略),请提供目标语言翻译")
|
errors.append(
|
||||||
|
f"[{locale}] key '{key}' 直接保留了包含中文字符的 source 文案(仓库级校验策略),请提供目标语言翻译"
|
||||||
|
)
|
||||||
|
|
||||||
if locale_requires_latin_only_validation(locale) and any(contains_han_characters(text) for text in target_texts):
|
if locale_requires_latin_only_validation(locale) and any(contains_han_characters(text) for text in target_texts):
|
||||||
errors.append(f"[{locale}] key '{key}' 仍包含中文字符,请移除源语言残留后再提交")
|
errors.append(f"[{locale}] key '{key}' 仍包含中文字符,请移除源语言残留后再提交")
|
||||||
@@ -121,12 +135,9 @@ def validate_json_locales(locales_root: Path | None = None) -> list[str]:
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
locale_keys = set(catalog.keys())
|
locale_keys = set(catalog.keys())
|
||||||
missing_keys = sorted(source_keys - locale_keys)
|
for key in sorted(source_keys - locale_keys):
|
||||||
extra_keys = sorted(locale_keys - source_keys)
|
|
||||||
|
|
||||||
for key in missing_keys:
|
|
||||||
errors.append(f"[{locale}] 缺少 key: {key}")
|
errors.append(f"[{locale}] 缺少 key: {key}")
|
||||||
for key in extra_keys:
|
for key in sorted(locale_keys - source_keys):
|
||||||
errors.append(f"[{locale}] 存在多余 key: {key}")
|
errors.append(f"[{locale}] 存在多余 key: {key}")
|
||||||
|
|
||||||
for key in sorted(source_keys & locale_keys):
|
for key in sorted(source_keys & locale_keys):
|
||||||
@@ -139,25 +150,13 @@ def validate_json_locales(locales_root: Path | None = None) -> list[str]:
|
|||||||
return errors
|
return errors
|
||||||
|
|
||||||
|
|
||||||
def discover_prompt_locales(prompts_root: Path | None = None) -> list[str]:
|
def build_prompt_catalog(locale_dir: Path) -> dict[Path, Path]:
|
||||||
resolved_prompts_root = get_prompts_root(prompts_root)
|
return {path.relative_to(locale_dir): path for path in iter_prompt_files(locale_dir)}
|
||||||
if not resolved_prompts_root.exists():
|
|
||||||
return []
|
|
||||||
|
|
||||||
locale_names = [path.name for path in resolved_prompts_root.iterdir() if path.is_dir()]
|
|
||||||
return sorted(locale_names)
|
|
||||||
|
|
||||||
|
|
||||||
def iter_prompt_files(locale_dir: Path) -> list[Path]:
|
|
||||||
prompt_files: list[Path] = []
|
|
||||||
for extension in PROMPT_EXTENSIONS:
|
|
||||||
prompt_files.extend(path for path in locale_dir.rglob(f"*{extension}") if path.is_file())
|
|
||||||
return sorted(set(prompt_files))
|
|
||||||
|
|
||||||
|
|
||||||
def validate_prompt_templates(prompts_root: Path | None = None) -> tuple[list[str], list[str]]:
|
def validate_prompt_templates(prompts_root: Path | None = None) -> tuple[list[str], list[str]]:
|
||||||
resolved_prompts_root = get_prompts_root(prompts_root)
|
resolved_prompts_root = get_prompts_root(prompts_root)
|
||||||
prompt_locales = discover_prompt_locales(resolved_prompts_root)
|
prompt_locales = set(discover_prompt_locales(resolved_prompts_root))
|
||||||
known_locales = [locale for locale in discover_locales(get_locales_root()) if locale != DEFAULT_LOCALE]
|
known_locales = [locale for locale in discover_locales(get_locales_root()) if locale != DEFAULT_LOCALE]
|
||||||
errors: list[str] = []
|
errors: list[str] = []
|
||||||
warnings: list[str] = []
|
warnings: list[str] = []
|
||||||
@@ -167,7 +166,8 @@ def validate_prompt_templates(prompts_root: Path | None = None) -> tuple[list[st
|
|||||||
return errors, warnings
|
return errors, warnings
|
||||||
|
|
||||||
source_dir = resolved_prompts_root / DEFAULT_LOCALE
|
source_dir = resolved_prompts_root / DEFAULT_LOCALE
|
||||||
source_files = {path.relative_to(source_dir): path for path in iter_prompt_files(source_dir)}
|
source_files = build_prompt_catalog(source_dir)
|
||||||
|
source_relative_paths = set(source_files.keys())
|
||||||
|
|
||||||
for locale in known_locales:
|
for locale in known_locales:
|
||||||
locale_dir = resolved_prompts_root / locale
|
locale_dir = resolved_prompts_root / locale
|
||||||
@@ -175,8 +175,7 @@ def validate_prompt_templates(prompts_root: Path | None = None) -> tuple[list[st
|
|||||||
warnings.append(f"[prompt:{locale}] 缺少 locale 目录,运行时将回退到 {DEFAULT_LOCALE}")
|
warnings.append(f"[prompt:{locale}] 缺少 locale 目录,运行时将回退到 {DEFAULT_LOCALE}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
locale_files = {path.relative_to(locale_dir): path for path in iter_prompt_files(locale_dir)}
|
locale_files = build_prompt_catalog(locale_dir)
|
||||||
source_relative_paths = set(source_files.keys())
|
|
||||||
locale_relative_paths = set(locale_files.keys())
|
locale_relative_paths = set(locale_files.keys())
|
||||||
|
|
||||||
for relative_path in sorted(source_relative_paths - locale_relative_paths):
|
for relative_path in sorted(source_relative_paths - locale_relative_paths):
|
||||||
|
|||||||
@@ -3,18 +3,16 @@ from __future__ import annotations
|
|||||||
from collections.abc import Iterator
|
from collections.abc import Iterator
|
||||||
from datetime import date, datetime
|
from datetime import date, datetime
|
||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
|
from functools import lru_cache
|
||||||
|
|
||||||
from .loaders import DEFAULT_LOCALE
|
from .loaders import DEFAULT_LOCALE
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=1)
|
||||||
def _get_manager():
|
def _get_manager():
|
||||||
from .manager import I18nManager
|
from .manager import I18nManager
|
||||||
|
|
||||||
manager = getattr(_get_manager, "_manager", None)
|
return I18nManager()
|
||||||
if manager is None:
|
|
||||||
manager = I18nManager()
|
|
||||||
_get_manager._manager = manager
|
|
||||||
return manager
|
|
||||||
|
|
||||||
|
|
||||||
def set_locale(locale: str) -> str:
|
def set_locale(locale: str) -> str:
|
||||||
|
|||||||
@@ -67,22 +67,12 @@ class I18nManager:
|
|||||||
self._catalog_cache.pop(normalize_locale(locale), None)
|
self._catalog_cache.pop(normalize_locale(locale), None)
|
||||||
|
|
||||||
def t(self, key: str, locale: str | None = None, **kwargs: object) -> str:
|
def t(self, key: str, locale: str | None = None, **kwargs: object) -> str:
|
||||||
translation_value, _ = self._get_translation_value(key, locale)
|
translation_value, translation_locale = self._get_translation_value(key, locale)
|
||||||
if translation_value is None:
|
template = self._get_standard_template(key, translation_value, translation_locale)
|
||||||
|
if template is None:
|
||||||
return key
|
return key
|
||||||
|
|
||||||
if isinstance(translation_value, dict):
|
return self._format_translation(key, template, kwargs)
|
||||||
template = translation_value.get("other")
|
|
||||||
if template is None:
|
|
||||||
self._log_once(
|
|
||||||
("plural_missing_other", self.get_locale(), key),
|
|
||||||
logging.WARNING,
|
|
||||||
"翻译 key '%s' 缺少 other plural category,已回退到 key 本身",
|
|
||||||
key,
|
|
||||||
)
|
|
||||||
return key
|
|
||||||
return self._format_translation(key, template, kwargs)
|
|
||||||
return self._format_translation(key, translation_value, kwargs)
|
|
||||||
|
|
||||||
def tn(self, key: str, count: int | float, locale: str | None = None, **kwargs: object) -> str:
|
def tn(self, key: str, count: int | float, locale: str | None = None, **kwargs: object) -> str:
|
||||||
translation_value, translation_locale = self._get_translation_value(key, locale)
|
translation_value, translation_locale = self._get_translation_value(key, locale)
|
||||||
@@ -118,6 +108,27 @@ class I18nManager:
|
|||||||
formatting_kwargs["count"] = count
|
formatting_kwargs["count"] = count
|
||||||
return self._format_translation(key, template, formatting_kwargs)
|
return self._format_translation(key, template, formatting_kwargs)
|
||||||
|
|
||||||
|
def _get_standard_template(
|
||||||
|
self,
|
||||||
|
key: str,
|
||||||
|
translation_value: TranslationValue | None,
|
||||||
|
translation_locale: str,
|
||||||
|
) -> str | None:
|
||||||
|
if translation_value is None:
|
||||||
|
return None
|
||||||
|
if not isinstance(translation_value, dict):
|
||||||
|
return translation_value
|
||||||
|
|
||||||
|
template = translation_value.get("other")
|
||||||
|
if template is None:
|
||||||
|
self._log_once(
|
||||||
|
("plural_missing_other", translation_locale, key),
|
||||||
|
logging.WARNING,
|
||||||
|
"翻译 key '%s' 缺少 other plural category,已回退到 key 本身",
|
||||||
|
key,
|
||||||
|
)
|
||||||
|
return template
|
||||||
|
|
||||||
def _format_translation(self, key: str, template: str, kwargs: dict[str, object]) -> str:
|
def _format_translation(self, key: str, template: str, kwargs: dict[str, object]) -> str:
|
||||||
try:
|
try:
|
||||||
return format_template(template, **kwargs)
|
return format_template(template, **kwargs)
|
||||||
@@ -161,14 +172,15 @@ class I18nManager:
|
|||||||
try:
|
try:
|
||||||
return normalize_locale(locale)
|
return normalize_locale(locale)
|
||||||
except InvalidLocaleError:
|
except InvalidLocaleError:
|
||||||
|
current_locale = self.get_locale()
|
||||||
self._log_once(
|
self._log_once(
|
||||||
("invalid_locale", "explicit", locale),
|
("invalid_locale", "explicit", locale),
|
||||||
logging.WARNING,
|
logging.WARNING,
|
||||||
"检测到非法 locale='%s',已回退到当前默认 locale %s",
|
"检测到非法 locale='%s',已回退到当前默认 locale %s",
|
||||||
locale,
|
locale,
|
||||||
self.get_locale(),
|
current_locale,
|
||||||
)
|
)
|
||||||
return self.get_locale()
|
return current_locale
|
||||||
|
|
||||||
def _get_catalog(self, locale: str) -> dict[str, TranslationValue]:
|
def _get_catalog(self, locale: str) -> dict[str, TranslationValue]:
|
||||||
normalized_locale = normalize_locale(locale)
|
normalized_locale = normalize_locale(locale)
|
||||||
|
|||||||
@@ -1,14 +1,14 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from functools import lru_cache
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import threading
|
|
||||||
|
|
||||||
from .i18n import get_locale, t
|
from .i18n import get_locale, t
|
||||||
from .i18n.loaders import DEFAULT_LOCALE, extract_placeholders as extract_prompt_placeholders, normalize_locale
|
from .i18n.loaders import DEFAULT_LOCALE, extract_placeholders, normalize_locale
|
||||||
|
|
||||||
logger = logging.getLogger("maibot.prompt_i18n")
|
logger = logging.getLogger("maibot.prompt_i18n")
|
||||||
|
|
||||||
@@ -17,9 +17,9 @@ PROMPTS_ROOT = (PROJECT_ROOT / "prompts").resolve()
|
|||||||
PROMPT_EXTENSIONS = (".prompt",)
|
PROMPT_EXTENSIONS = (".prompt",)
|
||||||
SAFE_SEGMENT_PATTERN = re.compile(r"^[A-Za-z0-9_.-]+$")
|
SAFE_SEGMENT_PATTERN = re.compile(r"^[A-Za-z0-9_.-]+$")
|
||||||
STRICT_ENV_KEYS = ("MAIBOT_PROMPT_I18N_STRICT", "MAIBOT_I18N_STRICT")
|
STRICT_ENV_KEYS = ("MAIBOT_PROMPT_I18N_STRICT", "MAIBOT_I18N_STRICT")
|
||||||
|
STRICT_ENV_VALUES = {"1", "true", "yes", "on"}
|
||||||
|
|
||||||
_prompt_cache: dict[Path, str] = {}
|
extract_prompt_placeholders = extract_placeholders
|
||||||
_cache_lock = threading.RLock()
|
|
||||||
|
|
||||||
|
|
||||||
def get_prompts_root(prompts_root: Path | None = None) -> Path:
|
def get_prompts_root(prompts_root: Path | None = None) -> Path:
|
||||||
@@ -56,91 +56,100 @@ def is_strict_prompt_i18n_mode() -> bool:
|
|||||||
if os.getenv("PYTEST_CURRENT_TEST"):
|
if os.getenv("PYTEST_CURRENT_TEST"):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return any(os.getenv(env_key, "").strip().lower() in {"1", "true", "yes", "on"} for env_key in STRICT_ENV_KEYS)
|
return any(os.getenv(env_key, "").strip().lower() in STRICT_ENV_VALUES for env_key in STRICT_ENV_KEYS)
|
||||||
|
|
||||||
|
|
||||||
def _supported_prompt_files(directory: Path, recursive: bool = True) -> list[Path]:
|
def discover_prompt_locales(prompts_root: Path | None = None) -> list[str]:
|
||||||
|
resolved_prompts_root = get_prompts_root(prompts_root)
|
||||||
|
if not resolved_prompts_root.exists():
|
||||||
|
return []
|
||||||
|
|
||||||
|
locale_names = [path.name for path in resolved_prompts_root.iterdir() if path.is_dir()]
|
||||||
|
return sorted(locale_names)
|
||||||
|
|
||||||
|
|
||||||
|
def iter_prompt_files(directory: Path, recursive: bool = True) -> list[Path]:
|
||||||
|
if not directory.exists():
|
||||||
|
return []
|
||||||
|
|
||||||
search = directory.rglob if recursive else directory.glob
|
search = directory.rglob if recursive else directory.glob
|
||||||
matched_files: list[Path] = []
|
prompt_files: list[Path] = []
|
||||||
for suffix in PROMPT_EXTENSIONS:
|
for suffix in PROMPT_EXTENSIONS:
|
||||||
matched_files.extend(path for path in search(f"*{suffix}") if path.is_file())
|
prompt_files.extend(path for path in search(f"*{suffix}") if path.is_file())
|
||||||
return sorted(set(matched_files))
|
return sorted(set(prompt_files))
|
||||||
|
|
||||||
|
|
||||||
def _scan_prompt_directory(directory: Path, prompts_root: Path) -> dict[str, Path]:
|
def _raise_duplicate_prompt_name(name: str, first_path: Path, second_path: Path, prompts_root: Path) -> None:
|
||||||
|
raise ValueError(
|
||||||
|
t(
|
||||||
|
"prompt.duplicate_template_name",
|
||||||
|
name=name,
|
||||||
|
path_a=first_path.relative_to(prompts_root),
|
||||||
|
path_b=second_path.relative_to(prompts_root),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _scan_prompt_directory(directory: Path, prompts_root: Path, recursive: bool = True) -> dict[str, Path]:
|
||||||
prompt_paths: dict[str, Path] = {}
|
prompt_paths: dict[str, Path] = {}
|
||||||
if not directory.exists():
|
for prompt_path in iter_prompt_files(directory, recursive=recursive):
|
||||||
return prompt_paths
|
|
||||||
|
|
||||||
for prompt_path in _supported_prompt_files(directory):
|
|
||||||
prompt_name = prompt_path.stem
|
prompt_name = prompt_path.stem
|
||||||
if prompt_name in prompt_paths:
|
existing_path = prompt_paths.get(prompt_name)
|
||||||
raise ValueError(
|
if existing_path is not None:
|
||||||
t(
|
_raise_duplicate_prompt_name(prompt_name, existing_path, prompt_path, prompts_root)
|
||||||
"prompt.duplicate_template_name",
|
|
||||||
name=prompt_name,
|
|
||||||
path_a=prompt_paths[prompt_name].relative_to(prompts_root),
|
|
||||||
path_b=prompt_path.relative_to(prompts_root),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
prompt_paths[prompt_name] = prompt_path
|
prompt_paths[prompt_name] = prompt_path
|
||||||
return prompt_paths
|
return prompt_paths
|
||||||
|
|
||||||
|
|
||||||
def _scan_legacy_prompt_directory(directory: Path, prompts_root: Path) -> dict[str, Path]:
|
def _iter_prompt_template_layers(prompts_root: Path, requested_locale: str) -> list[tuple[Path, bool]]:
|
||||||
prompt_paths: dict[str, Path] = {}
|
prompt_layers: list[tuple[Path, bool]] = [
|
||||||
if not directory.exists():
|
(prompts_root, False),
|
||||||
return prompt_paths
|
(prompts_root / DEFAULT_LOCALE, True),
|
||||||
|
]
|
||||||
|
if requested_locale != DEFAULT_LOCALE:
|
||||||
|
prompt_layers.append((prompts_root / requested_locale, True))
|
||||||
|
return prompt_layers
|
||||||
|
|
||||||
for prompt_path in _supported_prompt_files(directory, recursive=False):
|
|
||||||
prompt_name = prompt_path.stem
|
def _iter_locale_candidates(requested_locale: str) -> list[str | None]:
|
||||||
if prompt_name in prompt_paths:
|
locale_candidates: list[str | None] = [requested_locale]
|
||||||
raise ValueError(
|
if requested_locale != DEFAULT_LOCALE:
|
||||||
t(
|
locale_candidates.append(DEFAULT_LOCALE)
|
||||||
"prompt.duplicate_template_name",
|
locale_candidates.append(None)
|
||||||
name=prompt_name,
|
return locale_candidates
|
||||||
path_a=prompt_paths[prompt_name].relative_to(prompts_root),
|
|
||||||
path_b=prompt_path.relative_to(prompts_root),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
prompt_paths[prompt_name] = prompt_path
|
|
||||||
return prompt_paths
|
|
||||||
|
|
||||||
|
|
||||||
def list_prompt_templates(locale: str | None = None, prompts_root: Path | None = None) -> dict[str, Path]:
|
def list_prompt_templates(locale: str | None = None, prompts_root: Path | None = None) -> dict[str, Path]:
|
||||||
resolved_prompts_root = get_prompts_root(prompts_root)
|
resolved_prompts_root = get_prompts_root(prompts_root)
|
||||||
requested_locale = normalize_locale(locale or get_locale())
|
requested_locale = normalize_locale(locale or get_locale())
|
||||||
|
|
||||||
prompt_paths = _scan_legacy_prompt_directory(resolved_prompts_root, resolved_prompts_root)
|
prompt_paths: dict[str, Path] = {}
|
||||||
prompt_paths.update(_scan_prompt_directory(resolved_prompts_root / DEFAULT_LOCALE, resolved_prompts_root))
|
for directory, recursive in _iter_prompt_template_layers(resolved_prompts_root, requested_locale):
|
||||||
|
prompt_paths.update(_scan_prompt_directory(directory, resolved_prompts_root, recursive=recursive))
|
||||||
if requested_locale != DEFAULT_LOCALE:
|
|
||||||
prompt_paths.update(_scan_prompt_directory(resolved_prompts_root / requested_locale, resolved_prompts_root))
|
|
||||||
|
|
||||||
return prompt_paths
|
return prompt_paths
|
||||||
|
|
||||||
|
|
||||||
def resolve_prompt_path(name: str, locale: str | None = None, category: str | None = None, prompts_root: Path | None = None) -> Path:
|
def resolve_prompt_path(
|
||||||
|
name: str, locale: str | None = None, category: str | None = None, prompts_root: Path | None = None
|
||||||
|
) -> Path:
|
||||||
resolved_prompts_root = get_prompts_root(prompts_root)
|
resolved_prompts_root = get_prompts_root(prompts_root)
|
||||||
normalized_name = normalize_prompt_name(name)
|
normalized_name = normalize_prompt_name(name)
|
||||||
normalized_category = normalize_prompt_category(category)
|
normalized_category = normalize_prompt_category(category)
|
||||||
requested_locale = normalize_locale(locale or get_locale())
|
requested_locale = normalize_locale(locale or get_locale())
|
||||||
|
|
||||||
locale_candidates: list[str | None] = [requested_locale]
|
|
||||||
if requested_locale != DEFAULT_LOCALE:
|
|
||||||
locale_candidates.append(DEFAULT_LOCALE)
|
|
||||||
locale_candidates.append(None)
|
|
||||||
|
|
||||||
if normalized_category is not None:
|
if normalized_category is not None:
|
||||||
for locale_candidate in locale_candidates:
|
for locale_candidate in _iter_locale_candidates(requested_locale):
|
||||||
base_dir = resolved_prompts_root if locale_candidate is None else resolved_prompts_root / locale_candidate
|
base_dir = resolved_prompts_root if locale_candidate is None else resolved_prompts_root / locale_candidate
|
||||||
for suffix in PROMPT_EXTENSIONS:
|
for suffix in PROMPT_EXTENSIONS:
|
||||||
candidate_paths = [(base_dir / normalized_category / f"{normalized_name}{suffix}").resolve()]
|
candidate_path = (base_dir / normalized_category / f"{normalized_name}{suffix}").resolve()
|
||||||
|
if candidate_path.is_file():
|
||||||
|
return candidate_path
|
||||||
|
|
||||||
# 允许带 category 的调用在旧版平铺目录或未迁移完的 locale 目录中继续工作。
|
# 允许带 category 的调用在旧版平铺目录或未迁移完的 locale 目录中继续工作。
|
||||||
candidate_paths.append((base_dir / f"{normalized_name}{suffix}").resolve())
|
fallback_path = (base_dir / f"{normalized_name}{suffix}").resolve()
|
||||||
for candidate_path in candidate_paths:
|
if fallback_path.is_file():
|
||||||
if candidate_path.is_file():
|
return fallback_path
|
||||||
return candidate_path
|
|
||||||
else:
|
else:
|
||||||
prompt_paths = list_prompt_templates(locale=requested_locale, prompts_root=resolved_prompts_root)
|
prompt_paths = list_prompt_templates(locale=requested_locale, prompts_root=resolved_prompts_root)
|
||||||
if normalized_name in prompt_paths:
|
if normalized_name in prompt_paths:
|
||||||
@@ -149,6 +158,31 @@ def resolve_prompt_path(name: str, locale: str | None = None, category: str | No
|
|||||||
raise FileNotFoundError(t("prompt.template_not_found", locale=requested_locale, name=normalized_name))
|
raise FileNotFoundError(t("prompt.template_not_found", locale=requested_locale, name=normalized_name))
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=None)
|
||||||
|
def _read_prompt_template(prompt_path: Path) -> str:
|
||||||
|
return prompt_path.read_text(encoding="utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
def _format_prompt_template(name: str, template: str, **kwargs: object) -> str:
|
||||||
|
if not kwargs:
|
||||||
|
return template
|
||||||
|
|
||||||
|
try:
|
||||||
|
return template.format(**kwargs)
|
||||||
|
except KeyError as exc:
|
||||||
|
missing_placeholder = exc.args[0]
|
||||||
|
error = KeyError(t("prompt.missing_placeholder", name=name, placeholder=missing_placeholder))
|
||||||
|
if is_strict_prompt_i18n_mode():
|
||||||
|
raise error from exc
|
||||||
|
logger.error("%s", error)
|
||||||
|
return template
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error(t("prompt.format_failed", name=name, error=exc))
|
||||||
|
if is_strict_prompt_i18n_mode():
|
||||||
|
raise
|
||||||
|
return template
|
||||||
|
|
||||||
|
|
||||||
def load_prompt(
|
def load_prompt(
|
||||||
name: str,
|
name: str,
|
||||||
locale: str | None = None,
|
locale: str | None = None,
|
||||||
@@ -156,40 +190,11 @@ def load_prompt(
|
|||||||
prompts_root: Path | None = None,
|
prompts_root: Path | None = None,
|
||||||
**kwargs: object,
|
**kwargs: object,
|
||||||
) -> str:
|
) -> str:
|
||||||
prompt_path = resolve_prompt_path(name=name, locale=locale, category=category, prompts_root=prompts_root)
|
normalized_name = normalize_prompt_name(name)
|
||||||
with _cache_lock:
|
prompt_path = resolve_prompt_path(name=normalized_name, locale=locale, category=category, prompts_root=prompts_root)
|
||||||
template = _prompt_cache.get(prompt_path)
|
template = _read_prompt_template(prompt_path)
|
||||||
if template is None:
|
return _format_prompt_template(normalized_name, template, **kwargs)
|
||||||
template = prompt_path.read_text(encoding="utf-8")
|
|
||||||
with _cache_lock:
|
|
||||||
_prompt_cache.setdefault(prompt_path, template)
|
|
||||||
template = _prompt_cache[prompt_path]
|
|
||||||
|
|
||||||
if not kwargs:
|
|
||||||
return template
|
|
||||||
|
|
||||||
try:
|
|
||||||
return template.format(**kwargs)
|
|
||||||
except KeyError as exc:
|
|
||||||
missing_placeholder = exc.args[0]
|
|
||||||
error = KeyError(
|
|
||||||
t(
|
|
||||||
"prompt.missing_placeholder",
|
|
||||||
name=normalize_prompt_name(name),
|
|
||||||
placeholder=missing_placeholder,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
if is_strict_prompt_i18n_mode():
|
|
||||||
raise error from exc
|
|
||||||
logger.error("%s", error)
|
|
||||||
return template
|
|
||||||
except Exception as exc:
|
|
||||||
logger.error(t("prompt.format_failed", name=normalize_prompt_name(name), error=exc))
|
|
||||||
if is_strict_prompt_i18n_mode():
|
|
||||||
raise
|
|
||||||
return template
|
|
||||||
|
|
||||||
|
|
||||||
def clear_prompt_cache() -> None:
|
def clear_prompt_cache() -> None:
|
||||||
with _cache_lock:
|
_read_prompt_template.cache_clear()
|
||||||
_prompt_cache.clear()
|
|
||||||
|
|||||||
@@ -1,12 +1,12 @@
|
|||||||
|
from collections.abc import Callable, Coroutine
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from string import Formatter
|
from string import Formatter
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from collections.abc import Callable, Coroutine
|
|
||||||
import inspect
|
import inspect
|
||||||
|
|
||||||
from src.common.prompt_i18n import list_prompt_templates, load_prompt
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
from src.common.prompt_i18n import list_prompt_templates, load_prompt
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger("Prompt")
|
logger = get_logger("Prompt")
|
||||||
@@ -14,7 +14,7 @@ logger = get_logger("Prompt")
|
|||||||
_LEFT_BRACE = "\ufde9"
|
_LEFT_BRACE = "\ufde9"
|
||||||
_RIGHT_BRACE = "\ufdea"
|
_RIGHT_BRACE = "\ufdea"
|
||||||
|
|
||||||
PROJECT_ROOT = Path(__file__).parent.parent.parent.absolute().resolve()
|
PROJECT_ROOT = Path(__file__).resolve().parents[2]
|
||||||
PROMPTS_DIR = PROJECT_ROOT / "prompts"
|
PROMPTS_DIR = PROJECT_ROOT / "prompts"
|
||||||
DATA_DIR = PROJECT_ROOT / "data"
|
DATA_DIR = PROJECT_ROOT / "data"
|
||||||
CUSTOM_PROMPTS_DIR = DATA_DIR / "custom_prompts"
|
CUSTOM_PROMPTS_DIR = DATA_DIR / "custom_prompts"
|
||||||
@@ -240,18 +240,23 @@ class PromptManager:
|
|||||||
for prompt_file in CUSTOM_PROMPTS_DIR.glob(f"*{SUFFIX_PROMPT}"):
|
for prompt_file in CUSTOM_PROMPTS_DIR.glob(f"*{SUFFIX_PROMPT}"):
|
||||||
try:
|
try:
|
||||||
prompt_file.unlink()
|
prompt_file.unlink()
|
||||||
except Exception as e:
|
except Exception as exc:
|
||||||
logger.error(f"删除自定义 Prompt 文件 '{prompt_file}' 时出错,错误信息: {e}")
|
logger.error(f"删除自定义 Prompt 文件 '{prompt_file}' 时出错,错误信息: {exc}")
|
||||||
raise e
|
raise
|
||||||
for prompt_name in self._prompt_to_save:
|
for prompt_name in self._prompt_to_save:
|
||||||
prompt = self.prompts[prompt_name]
|
prompt = self.prompts[prompt_name]
|
||||||
file_path = CUSTOM_PROMPTS_DIR / f"{prompt_name}{SUFFIX_PROMPT}"
|
file_path = CUSTOM_PROMPTS_DIR / f"{prompt_name}{SUFFIX_PROMPT}"
|
||||||
try:
|
try:
|
||||||
with open(file_path, "w", encoding="utf-8") as f:
|
file_path.write_text(prompt.template, encoding="utf-8")
|
||||||
f.write(prompt.template)
|
except Exception as exc:
|
||||||
except Exception as e:
|
logger.error(f"保存 Prompt '{prompt_name}' 时出错,文件路径: '{file_path}',错误信息: {exc}")
|
||||||
logger.error(f"保存 Prompt '{prompt_name}' 时出错,文件路径: '{file_path}',错误信息: {e}")
|
raise
|
||||||
raise e
|
|
||||||
|
def _load_prompt_template(self, prompt_name: str) -> tuple[str, bool]:
|
||||||
|
custom_prompt_path = CUSTOM_PROMPTS_DIR / f"{prompt_name}{SUFFIX_PROMPT}"
|
||||||
|
if custom_prompt_path.exists():
|
||||||
|
return custom_prompt_path.read_text(encoding="utf-8"), True
|
||||||
|
return load_prompt(prompt_name, prompts_root=PROMPTS_DIR), False
|
||||||
|
|
||||||
def load_prompts(self) -> None:
|
def load_prompts(self) -> None:
|
||||||
"""
|
"""
|
||||||
@@ -259,34 +264,23 @@ class PromptManager:
|
|||||||
Raises:
|
Raises:
|
||||||
Exception: 如果在加载过程中出现任何文件操作错误则引发该异常
|
Exception: 如果在加载过程中出现任何文件操作错误则引发该异常
|
||||||
"""
|
"""
|
||||||
prompt_files = list_prompt_templates(prompts_root=PROMPTS_DIR)
|
prompt_templates = list_prompt_templates(prompts_root=PROMPTS_DIR)
|
||||||
for prompt_name, prompt_file in prompt_files.items():
|
for prompt_name, prompt_file in prompt_templates.items():
|
||||||
try:
|
try:
|
||||||
prompt_to_load = prompt_file
|
template, need_save = self._load_prompt_template(prompt_name)
|
||||||
need_save = False
|
|
||||||
custom_prompt_path = CUSTOM_PROMPTS_DIR / f"{prompt_name}{SUFFIX_PROMPT}"
|
|
||||||
if custom_prompt_path.exists():
|
|
||||||
# 优先加载自定义目录下的 Prompt 文件
|
|
||||||
prompt_to_load = custom_prompt_path
|
|
||||||
need_save = True
|
|
||||||
with open(prompt_to_load, "r", encoding="utf-8") as f:
|
|
||||||
template = f.read()
|
|
||||||
else:
|
|
||||||
template = load_prompt(prompt_name, prompts_root=PROMPTS_DIR)
|
|
||||||
self.add_prompt(Prompt(prompt_name=prompt_name, template=template), need_save=need_save)
|
self.add_prompt(Prompt(prompt_name=prompt_name, template=template), need_save=need_save)
|
||||||
except Exception as e:
|
except Exception as exc:
|
||||||
logger.error(f"加载 Prompt 文件 '{prompt_file}' 时出错,错误信息: {e}")
|
logger.error(f"加载 Prompt 文件 '{prompt_file}' 时出错,错误信息: {exc}")
|
||||||
raise e
|
raise
|
||||||
for prompt_file in CUSTOM_PROMPTS_DIR.glob(f"*{SUFFIX_PROMPT}"):
|
for prompt_file in CUSTOM_PROMPTS_DIR.glob(f"*{SUFFIX_PROMPT}"):
|
||||||
if prompt_file.stem in prompt_files:
|
if prompt_file.stem in prompt_templates:
|
||||||
continue # 已经加载过了,跳过
|
continue # 已经加载过了,跳过
|
||||||
try:
|
try:
|
||||||
with open(prompt_file, "r", encoding="utf-8") as f:
|
template = prompt_file.read_text(encoding="utf-8")
|
||||||
template = f.read()
|
|
||||||
self.add_prompt(Prompt(prompt_name=prompt_file.stem, template=template), need_save=True)
|
self.add_prompt(Prompt(prompt_name=prompt_file.stem, template=template), need_save=True)
|
||||||
except Exception as e:
|
except Exception as exc:
|
||||||
logger.error(f"加载自定义 Prompt 文件 '{prompt_file}' 时出错,错误信息: {e}")
|
logger.error(f"加载自定义 Prompt 文件 '{prompt_file}' 时出错,错误信息: {exc}")
|
||||||
raise e
|
raise
|
||||||
|
|
||||||
async def _get_function_result(
|
async def _get_function_result(
|
||||||
self,
|
self,
|
||||||
@@ -301,12 +295,12 @@ class PromptManager:
|
|||||||
if isinstance(res, Coroutine):
|
if isinstance(res, Coroutine):
|
||||||
res = await res
|
res = await res
|
||||||
return res
|
return res
|
||||||
except Exception as e:
|
except Exception as exc:
|
||||||
if is_prompt_context:
|
if is_prompt_context:
|
||||||
logger.error(f"调用 Prompt '{prompt_name}' 内部上下文构造函数 '{field_name}' 时出错,错误信息: {e}")
|
logger.error(f"调用 Prompt '{prompt_name}' 内部上下文构造函数 '{field_name}' 时出错,错误信息: {exc}")
|
||||||
else:
|
else:
|
||||||
logger.error(f"调用上下文构造函数 '{field_name}' 时出错,所属模块: '{module}',错误信息: {e}")
|
logger.error(f"调用上下文构造函数 '{field_name}' 时出错,所属模块: '{module}',错误信息: {exc}")
|
||||||
raise e
|
raise
|
||||||
|
|
||||||
|
|
||||||
prompt_manager = PromptManager()
|
prompt_manager = PromptManager()
|
||||||
|
|||||||
Reference in New Issue
Block a user